diff --git a/lib/pydantic_models.py b/lib/pydantic_models.py index 4b1c9ea..ef3cf3e 100644 --- a/lib/pydantic_models.py +++ b/lib/pydantic_models.py @@ -9,8 +9,8 @@ from decimal import Decimal class PredictionResult(BaseModel): feature_factor_frequency: str strategy_id: int - oil_code: Optional[str] = None - oil_name: Optional[str] = None + oil_code: Optional[str] = 'CRUDE' + oil_name: Optional[str] = '原油' data_date: Optional[datetime] = None market_price: Optional[Decimal] = None day_price: Optional[Decimal] = None diff --git a/lib/tools.py b/lib/tools.py index c2f7b9b..d11671d 100644 --- a/lib/tools.py +++ b/lib/tools.py @@ -686,7 +686,6 @@ def get_modelsname(df, global_config): model_name_list = [row['model_name'] for row in modelsname] model_name_list = set(columns) & set(model_name_list) model_name_list = list(model_name_list) - global_config['db_mysql'].close() return model_name_list, model_id_name_dict diff --git a/main_yuanyou.py b/main_yuanyou.py index 1e56d43..583306d 100644 --- a/main_yuanyou.py +++ b/main_yuanyou.py @@ -381,15 +381,15 @@ def predict_main(): logger.info(f'要更新y的信息:{update_y}') # try: for row in update_y.itertuples(index=False): - # try: - row_dict = row._asdict() - yy = df[df['ds'] == row_dict['ds']]['y'].values[0] - LOW = df[df['ds'] == row_dict['ds']]['Brentzdj'].values[0] - HIGH = df[df['ds'] == row_dict['ds']]['Brentzgj'].values[0] - sqlitedb.update_data( - 'accuracy', f"y = {yy},LOW_PRICE = {LOW},HIGH_PRICE = {HIGH}", where_condition=f"ds = '{row_dict['ds']}'") - # except: - # logger.info(f'更新accuracy表的y值失败:{row_dict}') + try: + row_dict = row._asdict() + yy = df[df['ds'] == row_dict['ds']]['y'].values[0] + LOW = df[df['ds'] == row_dict['ds']]['Brentzdj'].values[0] + HIGH = df[df['ds'] == row_dict['ds']]['Brentzgj'].values[0] + sqlitedb.update_data( + 'accuracy', f"y = {yy},LOW_PRICE = {LOW},HIGH_PRICE = {HIGH}", where_condition=f"ds = '{row_dict['ds']}'") + except: + logger.info(f'更新accuracy表的y值失败:{row_dict}') # except Exception as e: # logger.info(f'更新accuracy表的y值失败:{e}') @@ -581,10 +581,11 @@ def predict_main(): if __name__ == '__main__': # global end_time # # 遍历2024-11-25 到 2024-12-3 之间的工作日日期 - # for i_time in pd.date_range('2024-12-1', '2025-2-26', freq='W'): - # end_time = i_time.strftime('%Y-%m-%d') - # predict_main() + for i_time in pd.date_range('2025-6-11', '2025-6-28', freq='B'): + global_config['end_time'] = i_time.strftime('%Y-%m-%d') + global_config['db_mysql'].connect() + predict_main() # predict_main() # push_market_value() - sql_inset_predict(global_config=global_config) + # sql_inset_predict(global_config=global_config) diff --git a/models/nerulforcastmodels.py b/models/nerulforcastmodels.py index 23bbc48..36b17bc 100644 --- a/models/nerulforcastmodels.py +++ b/models/nerulforcastmodels.py @@ -241,37 +241,6 @@ def ex_Model(df, horizon, input_size, train_steps, val_check_steps, early_stop_p if config.is_update_eta: df_predict['ds'] = pd.to_datetime(df_predict['ds']) - # 按行遍历df_predict - IndexName = data['IndexName'] - for index, row in df_predict.iterrows(): - try: - yy = config.bdwdname[index] - except IndexError as e: - break - for m in modelsindex[index].keys(): - if m not in df_predict.columns: - continue - if m == 'FEDformer': - continue - list = [{'Date': config.end_time, 'Value': round(row[m], 2)}] - data['DataList'] = list - data['IndexCode'] = modelsindex[index][m] - # data['IndexName'] = f'价格预测{m}模型' - data['IndexName'] = data['IndexName'].replace('xx', m) - data['IndexName'] = data['IndexName'].replace('yy', yy) - data['Remark'] = m - print('预测数据上传到eta:') - etadata.push_data(data) - # print(data) - data['IndexName'] = IndexName - - # 把预测值上传到市场信息平台 - if config.is_update_market: - ''' - 预测结果整理,写入到数据表 v_tbl_predict_prediction_results - ''' - df_predict['ds'] = pd.to_datetime(df_predict['ds']) - # 按行遍历df_predict IndexName = data['IndexName'] for index, row in df_predict.iterrows():