Add request data for fetching from the market information platform

workpc 2024-12-26 11:37:18 +08:00
parent bfb981d486
commit c074f1eeae
5 changed files with 92 additions and 49 deletions


@ -237,7 +237,7 @@ table_name = 'v_tbl_crude_oil_warning'
### Switches
is_train = False # whether to train
is_debug = False # whether to run in debug mode
is_eta = False # whether to use the eta API
is_eta = True # whether to use the eta API
is_timefurture = True # whether to use time features
is_fivemodels = False # whether to reuse the previously saved best 5 models
is_edbcode = False # only use features from the edbcoding list


@ -352,6 +352,29 @@ def dateConvert(df, datecol='ds'):
df[datecol] = pd.to_datetime(df[datecol],format=r'%Y/%m/%d')
return df
def save_to_database(sqlitedb,df,dbname,end_time):
    '''
    Use created_dt and ds to check whether a row already exists: insert it if not, update it if it does.
    '''
    # If ds is a datetime column, normalise it to a YYYY-MM-DD string
    if pd.api.types.is_datetime64_any_dtype(df['ds']):
        df['ds'] = df['ds'].dt.strftime('%Y-%m-%d')
    if not sqlitedb.check_table_exists(dbname):
        df.to_sql(dbname,sqlitedb.connection,index=False)
    else:
        for col in df.columns:
            sqlitedb.add_column_if_not_exists(dbname,col,'TEXT')
        for row in df.itertuples(index=False):
            row_dict = row._asdict()
            columns=row_dict.keys()
            check_query = sqlitedb.select_data(dbname,where_condition = f"ds = '{row.ds}' and created_dt = '{end_time}'")
            if len(check_query) > 0:
                set_clause = ", ".join([f"{key} = '{value}'" for key, value in row_dict.items()])
                sqlitedb.update_data(dbname,set_clause,where_condition = f"ds = '{row.ds}' and created_dt = '{end_time}'")
                continue
            else:
                sqlitedb.insert_data(dbname,tuple(row_dict.values()),columns=columns)
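Not part of this commit: a minimal usage sketch of the new helper, assuming the SQLiteHandler defined below exposes an open connection after construction; the database file, table name and sample values are made up.

    import pandas as pd

    sqlitedb = SQLiteHandler('forecast.db')   # hypothetical database file name
    df_demo = pd.DataFrame({
        'ds': pd.to_datetime(['2024-12-25', '2024-12-26']),
        'NHITS': [74.1, 73.8],
    })
    df_demo['created_dt'] = '2024-12-26'      # matches the (ds, created_dt) upsert key used above
    # The first call creates the table; later calls update or insert row by row.
    save_to_database(sqlitedb, df_demo, 'predict', end_time='2024-12-26')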
class SQLiteHandler:
def __init__(self, db_name):
@ -444,6 +467,9 @@ class SQLiteHandler:
else:
print(f"Column '{column_name}' already exists in table '{table_name}'.")
import logging
class MySQLDB:
def __init__(self, host, user, password, database):
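Only the new import logging and the head of MySQLDB are visible in this hunk; purely as an assumption, a minimal sketch of a logging-enabled wrapper with the same constructor signature, using pymysql (the real driver and methods are not shown in this diff).

    import logging
    import pymysql

    logger = logging.getLogger(__name__)

    class MySQLDBSketch:
        # Illustrative only; mirrors the __init__ signature shown above.
        def __init__(self, host, user, password, database):
            self.connection = pymysql.connect(host=host, user=user,
                                              password=password, database=database)
            logger.info('Connected to MySQL database %s on %s', database, host)

        def execute_query(self, sql, params=None):
            # Run a query and return all rows.
            with self.connection.cursor() as cursor:
                cursor.execute(sql, params or ())
                return cursor.fetchall()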


@ -93,6 +93,12 @@ def predict_main():
# Save the y value for the latest date to the database
# Take the most recent row to store in the database
first_row = df[['ds', 'y']].tail(1)
# Check that ds matches end_time and that y is not empty
if len(first_row) > 0 and first_row['y'].values[0] is not None:
    pass
else:
    logger.info(f'{end_time}预测目标数据为空,跳过')
    return
# Save the latest actual value to the database
if not sqlitedb.check_table_exists('trueandpredict'):
    first_row.to_sql('trueandpredict', sqlitedb.connection, index=False)
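The comment above asks for ds to match end_time as well, but the committed guard only verifies that y is present; for illustration (not part of this commit), a sketch of the fuller check, reusing the same df, end_time and logger as the surrounding function:

    latest = df[['ds', 'y']].tail(1)
    if latest.empty:
        return
    latest_ds = pd.to_datetime(latest['ds'].iloc[0]).strftime('%Y-%m-%d')
    if latest_ds != end_time or pd.isna(latest['y'].iloc[0]):
        logger.info(f'{end_time}预测目标数据为空,跳过')
        return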
@ -257,10 +263,11 @@ if __name__ == '__main__':
global end_time
is_on = True
# Iterate over the business days in the configured backfill window
for i_time in pd.date_range('2024-10-29', '2024-12-16', freq='B'):
for i_time in pd.date_range('2024-12-25', '2024-12-26', freq='B'):
end_time = i_time.strftime('%Y-%m-%d')
predict_main()
if is_on:
is_train = False
is_on = False
is_fivemodels = True
is_eta = False
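For reference (not part of this commit), a standalone sketch of what the backfill loop above iterates over:

    import pandas as pd

    # freq='B' yields Monday-to-Friday dates only; public holidays are not excluded.
    for i_time in pd.date_range('2024-12-25', '2024-12-26', freq='B'):
        print(i_time.strftime('%Y-%m-%d'))
    # prints:
    # 2024-12-25
    # 2024-12-26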


@ -7,6 +7,7 @@ import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import datetime
from lib.tools import Graphs,mse,rmse,mae,exception_logger
from lib.tools import save_to_database
from lib.dataread import *
from neuralforecast import NeuralForecast
from neuralforecast.models import NHITS,Informer, NBEATSx,LSTM,PatchTST, iTransformer, TSMixer
@ -204,25 +205,7 @@ def ex_Model(df,horizon,input_size,train_steps,val_check_steps,early_stop_patien
df_predict.to_csv(os.path.join(dataset,"predict.csv"),index=False)
# Save the prediction results to the database
def save_to_database(df):
# Convert the ds column to date strings
df['ds'] = df['ds'].dt.strftime('%Y-%m-%d')
if not sqlitedb.check_table_exists('predict'):
df.to_sql('predict',sqlitedb.connection,index=False)
else:
for col in df.columns:
sqlitedb.add_column_if_not_exists('predict',col,'TEXT')
for row in df.itertuples(index=False):
row_dict = row._asdict()
columns=row_dict.keys()
check_query = sqlitedb.select_data('predict',where_condition = f"ds = '{row.ds}' and created_dt = '{end_time}'")
if len(check_query) > 0:
set_clause = ", ".join([f"{key} = '{value}'" for key, value in row_dict.items()])
sqlitedb.update_data('predict',set_clause,where_condition = f"ds = '{row.ds} and created_dt = {end_time}'")
continue
else:
sqlitedb.insert_data('predict',tuple(row_dict.values()),columns=columns)
save_to_database(df_predict)
save_to_database(sqlitedb,df_predict,'predict',end_time)
# Upload the predicted values to eta
if is_update_eta:
@ -252,21 +235,23 @@ def model_losss(sqlitedb,end_time):
most_model_name = most_model[0]
# Prediction data processing (predict)
df_combined = loadcsv(os.path.join(dataset,"cross_validation.csv"))
df_combined = dateConvert(df_combined)
# Drop empty columns
df_combined.dropna(axis=1,inplace=True)
# Drop missing values; the prediction step cannot handle missing values
# df_combined = loadcsv(os.path.join(dataset,"cross_validation.csv"))
# df_combined = dateConvert(df_combined)
df_combined = sqlitedb.select_data('accuracy')
# Drop columns where more than 80% of the values are missing
df_combined = df_combined.loc[:, df_combined.isnull().mean() < 0.8]
# Drop remaining rows with missing values
df_combined.dropna(inplace=True)
# Convert the remaining columns to numeric types
df_combined = df_combined.astype({col: 'float32' for col in df_combined.columns if col not in ['cutoff','ds'] })
df_combined = df_combined.astype({col: 'float32' for col in df_combined.columns if col not in ['CREAT_DATE','ds','created_dt'] })
# Use groupby and transform to get the earliest CREAT_DATE within each ds group and store it in a new column
df_combined['max_cutoff'] = df_combined.groupby('ds')['cutoff'].transform('max')
df_combined['max_cutoff'] = df_combined.groupby('ds')['CREAT_DATE'].transform('min')
# Then keep only the rows whose CREAT_DATE equals that per-group minimum, preserving the other columns
df_combined = df_combined[df_combined['cutoff'] == df_combined['max_cutoff']]
df_combined = df_combined[df_combined['CREAT_DATE'] == df_combined['max_cutoff']]
df_combined4 = df_combined.copy() # back up df_combined; needed later for plotting
# Drop the bookkeeping columns that are not model predictions
df_combined.drop(columns=['cutoff', 'max_cutoff'], inplace=True)
df_combined.drop(columns=['CREAT_DATE', 'max_cutoff','created_dt','min_within_quantile','max_within_quantile','id','min_price','max_price'], inplace=True)
# Get the model names
modelnames = df_combined.columns.to_list()[1:]
if 'y' in modelnames:
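Not part of this commit: a small self-contained sketch of the groupby/transform idiom used above, which keeps, for each ds, only the row with the earliest CREAT_DATE (the sample values are made up):

    import pandas as pd

    df = pd.DataFrame({
        'ds':         ['2024-12-25', '2024-12-25', '2024-12-26'],
        'CREAT_DATE': ['2024-12-20', '2024-12-23', '2024-12-23'],
        'NHITS':      [74.1, 74.5, 73.8],
    })
    # Despite the name, max_cutoff holds the minimum CREAT_DATE per ds, as in the code above.
    df['max_cutoff'] = df.groupby('ds')['CREAT_DATE'].transform('min')
    df = df[df['CREAT_DATE'] == df['max_cutoff']].drop(columns='max_cutoff')
    # keeps the 2024-12-20 row for 2024-12-25 and the single 2024-12-26 row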
@ -432,17 +417,9 @@ def model_losss(sqlitedb,end_time):
df_predict2['id'] = range(max_id + 1, max_id + 1 + len(df_predict2))
else:
df_predict2['id'] = range(1, 1 + len(df_predict2))
# df_predict2['CREAT_DATE'] = now if end_time == '' else end_time
df_predict2['CREAT_DATE'] = end_time
def get_common_columns(df1, df2):
# Return the column names shared by both DataFrames
return list(set(df1.columns).intersection(df2.columns))
common_columns = get_common_columns(df_predict2, existing_data)
try:
df_predict2[common_columns].to_sql("accuracy", con=sqlitedb.connection, if_exists='append', index=False)
except:
df_predict2.to_sql("accuracy", con=sqlitedb.connection, if_exists='append', index=False)
save_to_database(sqlitedb,df_predict2,"accuracy",end_time)
# Compute last week's accuracy
predict_y = sqlitedb.select_data(table_name = "accuracy")
@ -487,8 +464,8 @@ def model_losss(sqlitedb,end_time):
return 0
# Define a function to compute accuracy
# Compare the actual high/low with the predicted high/low to compute accuracy
def calculate_accuracy(row):
# Compare the actual high/low with the predicted high/low to compute accuracy
# Full-containment case:
if (row['max_price'] >= row['HIGH_PRICE'] and row['min_price'] <= row['LOW_PRICE']) or \
(row['max_price'] <= row['HIGH_PRICE'] and row['min_price'] >= row['LOW_PRICE']):
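The branch above covers the case where one range fully contains the other; the partial-overlap branches fall outside this hunk, so the following overlap-ratio helper is only an assumption about the general idea, not the committed formula:

    def overlap_ratio(pred_min, pred_max, true_low, true_high):
        # Share of the actual [LOW_PRICE, HIGH_PRICE] range covered by the predicted range.
        overlap = min(pred_max, true_high) - max(pred_min, true_low)
        if overlap <= 0:
            return 0.0                      # ranges do not intersect
        true_span = true_high - true_low
        return 1.0 if true_span == 0 else min(overlap / true_span, 1.0)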
@ -607,6 +584,39 @@ def model_losss(sqlitedb,end_time):
plt.savefig(os.path.join(dataset,'历史价格-预测值.png'), bbox_inches='tight')
plt.close()
def _plt_modeltopten_predict_ture(df):
lens = df.shape[0] if df.shape[0] < 180 else 90
df = df[-lens:] # keep at most the last 90 points for plotting (all of them if fewer than 180)
df['mean_price'] = df[allmodelnames[:10]].mean(axis=1)
# Historical prices
plt.figure(figsize=(20, 10))
plt.plot(df['ds'], df['y'], label='真实值')
plt.plot(df['ds'], df['mean_price'], label='模型前十均值', linestyle='--', color='orange')
# Shaded band between max_price and min_price
plt.fill_between(df['ds'], df['max_price'], df['min_price'], alpha=0.2)
# markers = ['o', 's', '^', 'D', 'v', '*', 'p', 'h', 'H', '+', 'x', 'd']
# random_marker = random.choice(markers)
# for model in allmodelnames:
# for model in ['BiTCN','RNN']:
# plt.plot(df['ds'], df[model], label=model,marker=random_marker)
# plt.plot(df_combined3['ds'], df_combined3['min_abs_error_rate_prediction'], label='最小绝对误差', linestyle='--', color='orange')
# Grid lines
plt.grid(True)
# Annotate each historical value on the chart
for i, j in zip(df['ds'], df['y']):
plt.text(i, j, str(j), ha='center', va='bottom')
# Draw a vertical dashed line at the current date
plt.axvline(x=df['ds'].iloc[-horizon], color='r', linestyle='--')
plt.legend()
plt.xlabel('日期')
plt.ylabel('价格')
plt.savefig(os.path.join(dataset,'历史价格-预测值1.png'), bbox_inches='tight')
plt.close()
def _plt_predict_table(df):
# Table of predicted values
fig, ax = plt.subplots(figsize=(20, 6))
@ -641,6 +651,7 @@ def model_losss(sqlitedb,end_time):
plt.close()
_plt_predict_ture(df_combined3)
_plt_modeltopten_predict_ture(df_combined4)
_plt_predict_table(df_combined3)
_plt_model_results3()


@ -163,13 +163,12 @@
],
"source": [
"query_data_list_item_nos_data = {\n",
" \"funcModule\":'数据项编码集合',\n",
" \"funcOperation\":'数据项编码集合',\n",
" \"funcModule\": \"数据项\",\n",
" \"funcOperation\": \"查询\",\n",
" \"data\": {\n",
" \"dataItemNoList\":['EXCHANGE|RATE|MIDDLE_PRICE'],\n",
" \"dateEnd\":'20240101',\n",
" \"dateStart\":'20241024'\n",
" \n",
" \"dateStart\":\"20240101\",\n",
" \"dateEnd\":\"20241231\",\n",
" \"dataItemNoList\":[\"EXCHANGE|RATE|MIDDLE_PRICE\",\"/|250290802|PRICE\"]\n",
" }\n",
"}\n",
"\n",