def _sql_literal(value):
    """Render *value* as a single-quoted SQL string literal.

    Embedded single quotes are doubled so that a value such as ``it's`` does
    not terminate the literal early.
    """
    return "'" + f"{value}".replace("'", "''") + "'"


def save_to_database(sqlitedb, df, dbname, end_time):
    """Insert or update the rows of *df* in table *dbname*.

    A row counts as already stored when the table has a record whose ``ds``
    equals the row's ``ds`` and whose ``created_dt`` equals *end_time*; such
    records are updated, all other rows are inserted. When the table does not
    exist yet, the whole frame is written with ``DataFrame.to_sql``.

    Args:
        sqlitedb: Handler exposing ``check_table_exists``,
            ``add_column_if_not_exists``, ``select_data``, ``update_data``,
            ``insert_data`` and a ``connection`` attribute.
        df: Frame to persist. It must contain a ``ds`` column. NOTE: a
            datetime ``ds`` column is converted to ``YYYY-MM-DD`` strings
            in place, so the caller's frame is modified.
        dbname: Name of the target table.
        end_time: Value compared against the ``created_dt`` column when
            looking for an existing record. This function does not add a
            ``created_dt`` column itself.
    """
    # Store dates as plain 'YYYY-MM-DD' text; skip when already non-datetime.
    if pd.api.types.is_datetime64_any_dtype(df['ds']):
        df['ds'] = df['ds'].dt.strftime('%Y-%m-%d')

    if not sqlitedb.check_table_exists(dbname):
        df.to_sql(dbname, sqlitedb.connection, index=False)
        return

    column_names = list(df.columns)
    # Make sure every column of the frame exists before writing row by row.
    for column in column_names:
        sqlitedb.add_column_if_not_exists(dbname, column, 'TEXT')

    # name=None yields plain tuples: namedtuples would rename columns that
    # are not valid Python identifiers to positional names such as `_1`.
    for values in df.itertuples(index=False, name=None):
        record = dict(zip(column_names, values))
        # The same, correctly quoted condition is used for lookup and update.
        match = (
            f"ds = {_sql_literal(record['ds'])} "
            f"and created_dt = {_sql_literal(end_time)}"
        )
        existing = sqlitedb.select_data(dbname, where_condition=match)
        # len() rather than truthiness: the result may be a DataFrame.
        if len(existing) > 0:
            set_clause = ", ".join(
                f"{key} = {_sql_literal(value)}" for key, value in record.items()
            )
            sqlitedb.update_data(dbname, set_clause, where_condition=match)
        else:
            sqlitedb.insert_data(
                dbname, tuple(record.values()), columns=record.keys()
            )
# 取第一行数据存储到数据库中 first_row = df[['ds', 'y']].tail(1) + # 判断ds是否与ent_time 一致且 y 不为空 + if len(first_row) > 0 and first_row['y'].values[0] is not None: + pass + else: + logger.info('{end_time}预测目标数据为空,跳过') + return # 将最新真实值保存到数据库 if not sqlitedb.check_table_exists('trueandpredict'): first_row.to_sql('trueandpredict', sqlitedb.connection, index=False) @@ -257,10 +263,11 @@ if __name__ == '__main__': global end_time is_on = True # 遍历2024-11-25 到 2024-12-3 之间的工作日日期 - for i_time in pd.date_range('2024-10-29', '2024-12-16', freq='B'): + for i_time in pd.date_range('2024-12-25', '2024-12-26', freq='B'): end_time = i_time.strftime('%Y-%m-%d') predict_main() if is_on: is_train = False is_on = False - is_fivemodels = True \ No newline at end of file + is_fivemodels = True + is_eta = False \ No newline at end of file diff --git a/models/nerulforcastmodels.py b/models/nerulforcastmodels.py index d4160b9..8608490 100644 --- a/models/nerulforcastmodels.py +++ b/models/nerulforcastmodels.py @@ -7,6 +7,7 @@ import matplotlib.pyplot as plt import matplotlib.dates as mdates import datetime from lib.tools import Graphs,mse,rmse,mae,exception_logger +from lib.tools import save_to_database from lib.dataread import * from neuralforecast import NeuralForecast from neuralforecast.models import NHITS,Informer, NBEATSx,LSTM,PatchTST, iTransformer, TSMixer @@ -204,25 +205,7 @@ def ex_Model(df,horizon,input_size,train_steps,val_check_steps,early_stop_patien df_predict.to_csv(os.path.join(dataset,"predict.csv"),index=False) # 将预测结果保存到数据库 - def save_to_database(df): - # ds列转为日期字符串 - df['ds'] = df['ds'].dt.strftime('%Y-%m-%d') - if not sqlitedb.check_table_exists('predict'): - df.to_sql('predict',sqlitedb.connection,index=False) - else: - for col in df.columns: - sqlitedb.add_column_if_not_exists('predict',col,'TEXT') - for row in df.itertuples(index=False): - row_dict = row._asdict() - columns=row_dict.keys() - check_query = sqlitedb.select_data('predict',where_condition = f"ds = '{row.ds}' and 
created_dt = '{end_time}'") - if len(check_query) > 0: - set_clause = ", ".join([f"{key} = '{value}'" for key, value in row_dict.items()]) - sqlitedb.update_data('predict',set_clause,where_condition = f"ds = '{row.ds} and created_dt = {end_time}'") - continue - else: - sqlitedb.insert_data('predict',tuple(row_dict.values()),columns=columns) - save_to_database(df_predict) + save_to_database(sqlitedb,df_predict,'predict',end_time) # 把预测值上传到eta if is_update_eta: @@ -252,21 +235,23 @@ def model_losss(sqlitedb,end_time): most_model_name = most_model[0] # 预测数据处理 predict - df_combined = loadcsv(os.path.join(dataset,"cross_validation.csv")) - df_combined = dateConvert(df_combined) - # 删除空列 - df_combined.dropna(axis=1,inplace=True) - # 删除缺失值,预测过程不能有缺失值 + # df_combined = loadcsv(os.path.join(dataset,"cross_validation.csv")) + # df_combined = dateConvert(df_combined) + df_combined = sqlitedb.select_data('accuracy') + # 删除缺失值大于80%的列 + df_combined = df_combined.loc[:, df_combined.isnull().mean() < 0.8] + # 删除缺失值 df_combined.dropna(inplace=True) # 其他列转为数值类型 - df_combined = df_combined.astype({col: 'float32' for col in df_combined.columns if col not in ['cutoff','ds'] }) + df_combined = df_combined.astype({col: 'float32' for col in df_combined.columns if col not in ['CREAT_DATE','ds','created_dt'] }) # 使用 groupby 和 transform 结合 lambda 函数来获取每个分组中 cutoff 的最小值,并创建一个新的列来存储这个最大值 - df_combined['max_cutoff'] = df_combined.groupby('ds')['cutoff'].transform('max') + df_combined['max_cutoff'] = df_combined.groupby('ds')['CREAT_DATE'].transform('min') # 然后筛选出那些 cutoff 等于 max_cutoff 的行,这样就得到了每个分组中 cutoff 最大的行,并保留了其他列 - df_combined = df_combined[df_combined['cutoff'] == df_combined['max_cutoff']] + df_combined = df_combined[df_combined['CREAT_DATE'] == df_combined['max_cutoff']] + df_combined4 = df_combined.copy() # 备份df_combined,后面画图需要 # 删除模型生成的cutoff列 - df_combined.drop(columns=['cutoff', 'max_cutoff'], inplace=True) + df_combined.drop(columns=['CREAT_DATE', 
def _plt_modeltopten_predict_ture(df):
    """Plot the true price against the mean of the first ten listed models.

    Draws ``y`` and the row-wise mean of the columns named by
    ``allmodelnames[:10]``, shades the band between ``max_price`` and
    ``min_price``, labels every ``y`` point with its value, and saves the
    figure as '历史价格-预测值1.png' in the ``dataset`` directory.

    ``allmodelnames``, ``horizon``, ``dataset``, ``plt`` and ``os`` are read
    from the enclosing scope.

    Args:
        df: Frame with ``ds``, ``y``, ``max_price``, ``min_price`` and the
            model columns. It is not modified; a copy of its tail is plotted.
    """
    # Keep every row when there are fewer than 180, otherwise the last 90.
    lens = df.shape[0] if df.shape[0] < 180 else 90
    # Copy the tail so the new column is added to an independent frame rather
    # than to a slice of the caller's data (avoids SettingWithCopyWarning).
    df = df.iloc[-lens:].copy()
    df['mean_price'] = df[allmodelnames[:10]].mean(axis=1)

    plt.figure(figsize=(20, 10))
    plt.plot(df['ds'], df['y'], label='真实值')
    plt.plot(df['ds'], df['mean_price'], label='模型前十均值', linestyle='--', color='orange')
    # Shade the band between the max_price and min_price columns.
    plt.fill_between(df['ds'], df['max_price'], df['min_price'], alpha=0.2)
    plt.grid(True)

    # Annotate each true value just above its point.
    for i, j in zip(df['ds'], df['y']):
        plt.text(i, j, str(j), ha='center', va='bottom')

    # Dashed vertical line at the row `horizon` positions from the end;
    # presumably the current date, i.e. where the forecast begins - confirm.
    plt.axvline(x=df['ds'].iloc[-horizon], color='r', linestyle='--')
    plt.legend()
    plt.xlabel('日期')
    plt.ylabel('价格')

    plt.savefig(os.path.join(dataset, '历史价格-预测值1.png'), bbox_inches='tight')
    plt.close()