一、决策树回归模型的机器学习
决策树回归主要用于预测连续型的目标变量,例如可以用于股票价格的滤波(平滑)与预测。下图是将该方法应用于股票指数后生成的走势图。
二、决策树回归模型的数学原理
三、决策树模型python源代码
将下面的代码复制粘贴到文本文件中,把后缀 .txt 改为 .py 即可使用(粘贴后请按 Python 语法恢复缩进,并先安装 pandas、numpy、akshare、matplotlib、requests 和 scikit-learn),股票价格滤波效果一级棒。
import pandas as pd
import numpy as np
import akshare as ak
import matplotlib.pyplot as plt
import json
import requests
# Global matplotlib text settings, applied once before any figure is drawn.
plt.rcParams.update({
    # Preferred sans-serif font; presumably chosen so CJK labels render
    # (requires SimHei to be installed on the machine).
    'font.sans-serif': ['SimHei'],
    # Draw negative tick labels with an ASCII hyphen instead of the Unicode
    # minus sign.
    'axes.unicode_minus': False,
})
def _strip_jsonp(text):
    """Return the JSON document carried by a JSONP ``callback(...);`` reply."""
    body = text.strip()
    if body.startswith(('{', '[')):
        # Already plain JSON (no callback wrapper was applied).
        return body
    start = body.find('(')
    end = body.rfind(')')
    if start == -1 or end < start:
        # Not a recognisable wrapper; let json.loads report the problem.
        return body
    return body[start + 1:end]


def _klines_to_frame(klines):
    """Turn comma-separated kline strings into a typed DataFrame."""
    columns = ['date', 'open', 'close', 'high', 'low', 'volume', '成交额', '振幅', '涨跌幅', '涨跌额', '换手率']
    rows = [line.split(',') for line in klines]
    frame = pd.DataFrame(rows, columns=columns)
    # Every column except 'date' arrives as text and is converted to numbers.
    for name in columns[1:]:
        frame[name] = pd.to_numeric(frame[name])
    return frame


def get_stock_hist_data_em(stock='0.399300', start_date='20210101', end_date='20500101', data_type='D'):
    """Download historical kline bars from the Eastmoney push2his endpoint.

    Args:
        stock: Value sent as ``secid`` (default ``'0.399300'``); presumably
            "<market>.<code>" -- confirm against the Eastmoney API.
        start_date: Value sent as ``beg``, e.g. ``'20210101'``.
        end_date: Value sent as ``end``, e.g. ``'20500101'``.
        data_type: Bar period key: ``'1'``, ``'5'``, ``'15'``, ``'30'``,
            ``'60'``, ``'D'``, ``'W'`` or ``'M'``.

    Returns:
        A DataFrame with columns ``date, open, close, high, low, volume,
        成交额, 振幅, 涨跌幅, 涨跌额, 换手率`` (all numeric except ``date``),
        or ``None`` when the reply carries no usable kline rows.

    Raises:
        KeyError: If ``data_type`` is not one of the supported keys.
        requests.RequestException: On network failure or timeout.
        ValueError: If the reply body is not valid JSON/JSONP.
    """
    data_dict = {'1': '1', '5': '5', '15': '15', '30': '30', '60': '60', 'D': '101', 'W': '102', 'M': '103'}
    # Looked up before any network access so a bad period fails fast.
    klt = data_dict[data_type]
    fq = '1'  # sent as 'fqt'; presumably the price-adjustment mode -- TODO confirm
    url = 'http://push2his.eastmoney.com/api/qt/stock/kline/'
    params = {
        'fields1': 'f1,f2,f3,f4,f5,f6,f7,f8,f9,f10,f11,f12,f13',
        'fields2': 'f51,f52,f53,f54,f55,f56,f57,f58,f59,f60,f61',
        'beg': start_date,
        'end': end_date,
        'ut': 'fa5fd1943c7b386f172d6893dbfba10b',
        # NOTE(review): 'rtntype' is sent as end_date, which looks like a
        # copy/paste slip; kept unchanged until the expected value is confirmed.
        'rtntype': end_date,
        'secid': stock,
        'klt': klt,
        'fqt': fq,
        'cb': 'jsonp1668432946680',
    }
    # A timeout keeps a stalled connection from hanging the caller forever.
    res = requests.get(url=url, params=params, timeout=10)
    # Locate the wrapper's parentheses rather than slicing at fixed offsets,
    # so the parse does not depend on the callback name's length.
    json_text = json.loads(_strip_jsonp(res.text))

    data = json_text.get('data') if isinstance(json_text, dict) else None
    klines = data.get('klines') if isinstance(data, dict) else None
    if not klines:
        # No 'data' / 'klines' payload or an empty list: nothing to return.
        return None
    try:
        frame = _klines_to_frame(klines)
    except (AttributeError, TypeError, ValueError):
        # Rows with an unexpected shape or non-numeric fields: keep the
        # best-effort contract of returning None instead of raising.
        return None
    frame.sort_index(ascending=True, ignore_index=True, inplace=True)
    return frame
from sklearn.tree import DecisionTreeRegressor


def main():
    """Fetch 30-minute bars for secid 0.399300, fit a depth-4 regression tree
    to the close prices, and plot the raw closes against the fitted steps."""
    df = get_stock_hist_data_em(stock='0.399300', start_date='20100101', end_date='20500101', data_type='30')
    # The fetch helper returns None when no kline rows come back.
    if df is None or len(df) == 0:
        raise SystemExit('No kline data returned; nothing to fit.')

    close = df['close'].values
    # Use the bar position as the single feature. A tree's splits depend only
    # on the order of the feature values, so this gives the same partition of
    # the bars as a sorted random feature would, with a meaningful, evenly
    # spaced x-axis.
    X = np.arange(len(close)).reshape(-1, 1)

    # random_state pins the estimator's internal RNG explicitly instead of
    # relying on a global numpy seed.
    tree_model = DecisionTreeRegressor(max_depth=4, random_state=0)
    tree_model.fit(X, close)
    y_pred = tree_model.predict(X)

    plt.scatter(X, close, s=20, edgecolor="black", c="darkorange", label="data")
    plt.plot(X, y_pred, color="cornflowerblue", label="prediction")
    plt.xlabel("bar index")
    plt.ylabel("target")
    plt.title("Decision Tree Regression")
    plt.legend()
    plt.show()


if __name__ == "__main__":
    main()