添加获取市场信息平台请求数据
This commit is contained in:
parent
bfb981d486
commit
c074f1eeae
@ -237,7 +237,7 @@ table_name = 'v_tbl_crude_oil_warning'
|
|||||||
### 开关
|
### 开关
|
||||||
is_train = False # 是否训练
|
is_train = False # 是否训练
|
||||||
is_debug = False # 是否调试
|
is_debug = False # 是否调试
|
||||||
is_eta = False # 是否使用eta接口
|
is_eta = True # 是否使用eta接口
|
||||||
is_timefurture = True # 是否使用时间特征
|
is_timefurture = True # 是否使用时间特征
|
||||||
is_fivemodels = False # 是否使用之前保存的最佳的5个模型
|
is_fivemodels = False # 是否使用之前保存的最佳的5个模型
|
||||||
is_edbcode = False # 特征使用edbcoding列表中的
|
is_edbcode = False # 特征使用edbcoding列表中的
|
||||||
|
26
lib/tools.py
26
lib/tools.py
@ -352,6 +352,29 @@ def dateConvert(df, datecol='ds'):
|
|||||||
df[datecol] = pd.to_datetime(df[datecol],format=r'%Y/%m/%d')
|
df[datecol] = pd.to_datetime(df[datecol],format=r'%Y/%m/%d')
|
||||||
return df
|
return df
|
||||||
|
|
||||||
|
def save_to_database(sqlitedb,df,dbname,end_time):
|
||||||
|
'''
|
||||||
|
create_dt ,ds 判断数据是否存在,不存在则插入,存在则更新
|
||||||
|
'''
|
||||||
|
# 判断格式是否为日期时间类型
|
||||||
|
if pd.api.types.is_datetime64_any_dtype(df['ds']):
|
||||||
|
df['ds'] = df['ds'].dt.strftime('%Y-%m-%d')
|
||||||
|
|
||||||
|
if not sqlitedb.check_table_exists(dbname):
|
||||||
|
df.to_sql(dbname,sqlitedb.connection,index=False)
|
||||||
|
else:
|
||||||
|
for col in df.columns:
|
||||||
|
sqlitedb.add_column_if_not_exists(dbname,col,'TEXT')
|
||||||
|
for row in df.itertuples(index=False):
|
||||||
|
row_dict = row._asdict()
|
||||||
|
columns=row_dict.keys()
|
||||||
|
check_query = sqlitedb.select_data(dbname,where_condition = f"ds = '{row.ds}' and created_dt = '{end_time}'")
|
||||||
|
if len(check_query) > 0:
|
||||||
|
set_clause = ", ".join([f"{key} = '{value}'" for key, value in row_dict.items()])
|
||||||
|
sqlitedb.update_data(dbname,set_clause,where_condition = f"ds = '{row.ds} and created_dt = {end_time}'")
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
sqlitedb.insert_data(dbname,tuple(row_dict.values()),columns=columns)
|
||||||
|
|
||||||
class SQLiteHandler:
|
class SQLiteHandler:
|
||||||
def __init__(self, db_name):
|
def __init__(self, db_name):
|
||||||
@ -444,6 +467,9 @@ class SQLiteHandler:
|
|||||||
else:
|
else:
|
||||||
print(f"Column '{column_name}' already exists in table '{table_name}'.")
|
print(f"Column '{column_name}' already exists in table '{table_name}'.")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
class MySQLDB:
|
class MySQLDB:
|
||||||
def __init__(self, host, user, password, database):
|
def __init__(self, host, user, password, database):
|
||||||
|
@ -93,6 +93,12 @@ def predict_main():
|
|||||||
# 保存最新日期的y值到数据库
|
# 保存最新日期的y值到数据库
|
||||||
# 取第一行数据存储到数据库中
|
# 取第一行数据存储到数据库中
|
||||||
first_row = df[['ds', 'y']].tail(1)
|
first_row = df[['ds', 'y']].tail(1)
|
||||||
|
# 判断ds是否与ent_time 一致且 y 不为空
|
||||||
|
if len(first_row) > 0 and first_row['y'].values[0] is not None:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
logger.info('{end_time}预测目标数据为空,跳过')
|
||||||
|
return
|
||||||
# 将最新真实值保存到数据库
|
# 将最新真实值保存到数据库
|
||||||
if not sqlitedb.check_table_exists('trueandpredict'):
|
if not sqlitedb.check_table_exists('trueandpredict'):
|
||||||
first_row.to_sql('trueandpredict', sqlitedb.connection, index=False)
|
first_row.to_sql('trueandpredict', sqlitedb.connection, index=False)
|
||||||
@ -257,10 +263,11 @@ if __name__ == '__main__':
|
|||||||
global end_time
|
global end_time
|
||||||
is_on = True
|
is_on = True
|
||||||
# 遍历2024-11-25 到 2024-12-3 之间的工作日日期
|
# 遍历2024-11-25 到 2024-12-3 之间的工作日日期
|
||||||
for i_time in pd.date_range('2024-10-29', '2024-12-16', freq='B'):
|
for i_time in pd.date_range('2024-12-25', '2024-12-26', freq='B'):
|
||||||
end_time = i_time.strftime('%Y-%m-%d')
|
end_time = i_time.strftime('%Y-%m-%d')
|
||||||
predict_main()
|
predict_main()
|
||||||
if is_on:
|
if is_on:
|
||||||
is_train = False
|
is_train = False
|
||||||
is_on = False
|
is_on = False
|
||||||
is_fivemodels = True
|
is_fivemodels = True
|
||||||
|
is_eta = False
|
@ -7,6 +7,7 @@ import matplotlib.pyplot as plt
|
|||||||
import matplotlib.dates as mdates
|
import matplotlib.dates as mdates
|
||||||
import datetime
|
import datetime
|
||||||
from lib.tools import Graphs,mse,rmse,mae,exception_logger
|
from lib.tools import Graphs,mse,rmse,mae,exception_logger
|
||||||
|
from lib.tools import save_to_database
|
||||||
from lib.dataread import *
|
from lib.dataread import *
|
||||||
from neuralforecast import NeuralForecast
|
from neuralforecast import NeuralForecast
|
||||||
from neuralforecast.models import NHITS,Informer, NBEATSx,LSTM,PatchTST, iTransformer, TSMixer
|
from neuralforecast.models import NHITS,Informer, NBEATSx,LSTM,PatchTST, iTransformer, TSMixer
|
||||||
@ -204,25 +205,7 @@ def ex_Model(df,horizon,input_size,train_steps,val_check_steps,early_stop_patien
|
|||||||
df_predict.to_csv(os.path.join(dataset,"predict.csv"),index=False)
|
df_predict.to_csv(os.path.join(dataset,"predict.csv"),index=False)
|
||||||
|
|
||||||
# 将预测结果保存到数据库
|
# 将预测结果保存到数据库
|
||||||
def save_to_database(df):
|
save_to_database(sqlitedb,df_predict,'predict',end_time)
|
||||||
# ds列转为日期字符串
|
|
||||||
df['ds'] = df['ds'].dt.strftime('%Y-%m-%d')
|
|
||||||
if not sqlitedb.check_table_exists('predict'):
|
|
||||||
df.to_sql('predict',sqlitedb.connection,index=False)
|
|
||||||
else:
|
|
||||||
for col in df.columns:
|
|
||||||
sqlitedb.add_column_if_not_exists('predict',col,'TEXT')
|
|
||||||
for row in df.itertuples(index=False):
|
|
||||||
row_dict = row._asdict()
|
|
||||||
columns=row_dict.keys()
|
|
||||||
check_query = sqlitedb.select_data('predict',where_condition = f"ds = '{row.ds}' and created_dt = '{end_time}'")
|
|
||||||
if len(check_query) > 0:
|
|
||||||
set_clause = ", ".join([f"{key} = '{value}'" for key, value in row_dict.items()])
|
|
||||||
sqlitedb.update_data('predict',set_clause,where_condition = f"ds = '{row.ds} and created_dt = {end_time}'")
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
sqlitedb.insert_data('predict',tuple(row_dict.values()),columns=columns)
|
|
||||||
save_to_database(df_predict)
|
|
||||||
|
|
||||||
# 把预测值上传到eta
|
# 把预测值上传到eta
|
||||||
if is_update_eta:
|
if is_update_eta:
|
||||||
@ -252,21 +235,23 @@ def model_losss(sqlitedb,end_time):
|
|||||||
most_model_name = most_model[0]
|
most_model_name = most_model[0]
|
||||||
|
|
||||||
# 预测数据处理 predict
|
# 预测数据处理 predict
|
||||||
df_combined = loadcsv(os.path.join(dataset,"cross_validation.csv"))
|
# df_combined = loadcsv(os.path.join(dataset,"cross_validation.csv"))
|
||||||
df_combined = dateConvert(df_combined)
|
# df_combined = dateConvert(df_combined)
|
||||||
# 删除空列
|
df_combined = sqlitedb.select_data('accuracy')
|
||||||
df_combined.dropna(axis=1,inplace=True)
|
# 删除缺失值大于80%的列
|
||||||
# 删除缺失值,预测过程不能有缺失值
|
df_combined = df_combined.loc[:, df_combined.isnull().mean() < 0.8]
|
||||||
|
# 删除缺失值
|
||||||
df_combined.dropna(inplace=True)
|
df_combined.dropna(inplace=True)
|
||||||
# 其他列转为数值类型
|
# 其他列转为数值类型
|
||||||
df_combined = df_combined.astype({col: 'float32' for col in df_combined.columns if col not in ['cutoff','ds'] })
|
df_combined = df_combined.astype({col: 'float32' for col in df_combined.columns if col not in ['CREAT_DATE','ds','created_dt'] })
|
||||||
# 使用 groupby 和 transform 结合 lambda 函数来获取每个分组中 cutoff 的最小值,并创建一个新的列来存储这个最大值
|
# 使用 groupby 和 transform 结合 lambda 函数来获取每个分组中 cutoff 的最小值,并创建一个新的列来存储这个最大值
|
||||||
df_combined['max_cutoff'] = df_combined.groupby('ds')['cutoff'].transform('max')
|
df_combined['max_cutoff'] = df_combined.groupby('ds')['CREAT_DATE'].transform('min')
|
||||||
|
|
||||||
# 然后筛选出那些 cutoff 等于 max_cutoff 的行,这样就得到了每个分组中 cutoff 最大的行,并保留了其他列
|
# 然后筛选出那些 cutoff 等于 max_cutoff 的行,这样就得到了每个分组中 cutoff 最大的行,并保留了其他列
|
||||||
df_combined = df_combined[df_combined['cutoff'] == df_combined['max_cutoff']]
|
df_combined = df_combined[df_combined['CREAT_DATE'] == df_combined['max_cutoff']]
|
||||||
|
df_combined4 = df_combined.copy() # 备份df_combined,后面画图需要
|
||||||
# 删除模型生成的cutoff列
|
# 删除模型生成的cutoff列
|
||||||
df_combined.drop(columns=['cutoff', 'max_cutoff'], inplace=True)
|
df_combined.drop(columns=['CREAT_DATE', 'max_cutoff','created_dt','min_within_quantile','max_within_quantile','id','min_price','max_price'], inplace=True)
|
||||||
# 获取模型名称
|
# 获取模型名称
|
||||||
modelnames = df_combined.columns.to_list()[1:]
|
modelnames = df_combined.columns.to_list()[1:]
|
||||||
if 'y' in modelnames:
|
if 'y' in modelnames:
|
||||||
@ -432,17 +417,9 @@ def model_losss(sqlitedb,end_time):
|
|||||||
df_predict2['id'] = range(max_id + 1, max_id + 1 + len(df_predict2))
|
df_predict2['id'] = range(max_id + 1, max_id + 1 + len(df_predict2))
|
||||||
else:
|
else:
|
||||||
df_predict2['id'] = range(1, 1 + len(df_predict2))
|
df_predict2['id'] = range(1, 1 + len(df_predict2))
|
||||||
# df_predict2['CREAT_DATE'] = now if end_time == '' else end_time
|
|
||||||
df_predict2['CREAT_DATE'] = end_time
|
df_predict2['CREAT_DATE'] = end_time
|
||||||
def get_common_columns(df1, df2):
|
|
||||||
# 获取两个DataFrame的公共列名
|
|
||||||
return list(set(df1.columns).intersection(df2.columns))
|
|
||||||
|
|
||||||
common_columns = get_common_columns(df_predict2, existing_data)
|
save_to_database(sqlitedb,df_predict2,"accuracy",end_time)
|
||||||
try:
|
|
||||||
df_predict2[common_columns].to_sql("accuracy", con=sqlitedb.connection, if_exists='append', index=False)
|
|
||||||
except:
|
|
||||||
df_predict2.to_sql("accuracy", con=sqlitedb.connection, if_exists='append', index=False)
|
|
||||||
|
|
||||||
# 上周准确率计算
|
# 上周准确率计算
|
||||||
predict_y = sqlitedb.select_data(table_name = "accuracy")
|
predict_y = sqlitedb.select_data(table_name = "accuracy")
|
||||||
@ -487,8 +464,8 @@ def model_losss(sqlitedb,end_time):
|
|||||||
return 0
|
return 0
|
||||||
|
|
||||||
# 定义一个函数来计算准确率
|
# 定义一个函数来计算准确率
|
||||||
# 比较真实最高最低,和预测最高最低 计算准确率
|
|
||||||
def calculate_accuracy(row):
|
def calculate_accuracy(row):
|
||||||
|
# 比较真实最高最低,和预测最高最低 计算准确率
|
||||||
# 全子集情况:
|
# 全子集情况:
|
||||||
if (row['max_price'] >= row['HIGH_PRICE'] and row['min_price'] <= row['LOW_PRICE']) or \
|
if (row['max_price'] >= row['HIGH_PRICE'] and row['min_price'] <= row['LOW_PRICE']) or \
|
||||||
(row['max_price'] <= row['HIGH_PRICE'] and row['min_price'] >= row['LOW_PRICE']):
|
(row['max_price'] <= row['HIGH_PRICE'] and row['min_price'] >= row['LOW_PRICE']):
|
||||||
@ -607,6 +584,39 @@ def model_losss(sqlitedb,end_time):
|
|||||||
plt.savefig(os.path.join(dataset,'历史价格-预测值.png'), bbox_inches='tight')
|
plt.savefig(os.path.join(dataset,'历史价格-预测值.png'), bbox_inches='tight')
|
||||||
plt.close()
|
plt.close()
|
||||||
|
|
||||||
|
|
||||||
|
def _plt_modeltopten_predict_ture(df):
|
||||||
|
lens = df.shape[0] if df.shape[0] < 180 else 90
|
||||||
|
df = df[-lens:] # 取180个数据点画图
|
||||||
|
df['mean_price'] = df[allmodelnames[:10]].mean(axis=1)
|
||||||
|
# 历史价格
|
||||||
|
plt.figure(figsize=(20, 10))
|
||||||
|
plt.plot(df['ds'], df['y'], label='真实值')
|
||||||
|
plt.plot(df['ds'], df['mean_price'], label='模型前十均值', linestyle='--', color='orange')
|
||||||
|
# 颜色填充
|
||||||
|
plt.fill_between(df['ds'], df['max_price'], df['min_price'], alpha=0.2)
|
||||||
|
# markers = ['o', 's', '^', 'D', 'v', '*', 'p', 'h', 'H', '+', 'x', 'd']
|
||||||
|
# random_marker = random.choice(markers)
|
||||||
|
# for model in allmodelnames:
|
||||||
|
# for model in ['BiTCN','RNN']:
|
||||||
|
# plt.plot(df['ds'], df[model], label=model,marker=random_marker)
|
||||||
|
# plt.plot(df_combined3['ds'], df_combined3['min_abs_error_rate_prediction'], label='最小绝对误差', linestyle='--', color='orange')
|
||||||
|
# 网格
|
||||||
|
plt.grid(True)
|
||||||
|
# 显示历史值
|
||||||
|
for i, j in zip(df['ds'], df['y']):
|
||||||
|
plt.text(i, j, str(j), ha='center', va='bottom')
|
||||||
|
|
||||||
|
# 当前日期画竖虚线
|
||||||
|
plt.axvline(x=df['ds'].iloc[-horizon], color='r', linestyle='--')
|
||||||
|
plt.legend()
|
||||||
|
plt.xlabel('日期')
|
||||||
|
plt.ylabel('价格')
|
||||||
|
|
||||||
|
plt.savefig(os.path.join(dataset,'历史价格-预测值1.png'), bbox_inches='tight')
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
|
||||||
def _plt_predict_table(df):
|
def _plt_predict_table(df):
|
||||||
# 预测值表格
|
# 预测值表格
|
||||||
fig, ax = plt.subplots(figsize=(20, 6))
|
fig, ax = plt.subplots(figsize=(20, 6))
|
||||||
@ -641,6 +651,7 @@ def model_losss(sqlitedb,end_time):
|
|||||||
plt.close()
|
plt.close()
|
||||||
|
|
||||||
_plt_predict_ture(df_combined3)
|
_plt_predict_ture(df_combined3)
|
||||||
|
_plt_modeltopten_predict_ture(df_combined4)
|
||||||
_plt_predict_table(df_combined3)
|
_plt_predict_table(df_combined3)
|
||||||
_plt_model_results3()
|
_plt_model_results3()
|
||||||
|
|
||||||
|
@ -163,13 +163,12 @@
|
|||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"query_data_list_item_nos_data = {\n",
|
"query_data_list_item_nos_data = {\n",
|
||||||
" \"funcModule\":'数据项编码集合',\n",
|
" \"funcModule\": \"数据项\",\n",
|
||||||
" \"funcOperation\":'数据项编码集合',\n",
|
" \"funcOperation\": \"查询\",\n",
|
||||||
" \"data\": {\n",
|
" \"data\": {\n",
|
||||||
" \"dataItemNoList\":['EXCHANGE|RATE|MIDDLE_PRICE'],\n",
|
" \"dateStart\":\"20240101\",\n",
|
||||||
" \"dateEnd\":'20240101',\n",
|
" \"dateEnd\":\"20241231\",\n",
|
||||||
" \"dateStart\":'20241024'\n",
|
" \"dataItemNoList\":[\"EXCHANGE|RATE|MIDDLE_PRICE\",\"/|250290802|PRICE\"]\n",
|
||||||
" \n",
|
|
||||||
" }\n",
|
" }\n",
|
||||||
"}\n",
|
"}\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
Loading…
Reference in New Issue
Block a user