result_validate.py

import os
import datetime

import pandas as pd

from data_loader import mongo_con_parse, validate_one_line, fill_hourly_crawl_date

def validate_process(node, interval_hours, pred_time_str):
    date = pred_time_str[4:8]
    output_dir = f"./validate/{node}_{date}"
    os.makedirs(output_dir, exist_ok=True)

    # pick the prediction directory that matches the prediction interval
    object_dir = "./predictions"
    if interval_hours == 4:
        object_dir = "./predictions_4"
    elif interval_hours == 2:
        object_dir = "./predictions_2"
    csv_file = f'future_predictions_{pred_time_str}.csv'
    csv_path = os.path.join(object_dir, csv_file)
    try:
        df_predict = pd.read_csv(csv_path)
    except Exception as e:
        print(f"read {csv_path} error: {str(e)}")
        df_predict = pd.DataFrame()
    if df_predict.empty:
        print("prediction data is empty")
        return
    # fly_day = df_predict['flight_day'].unique()[0]
    client, db = mongo_con_parse()
    count = 0
    for idx, row in df_predict.iterrows():
        city_pair = row['city_pair']
        flight_day = row['flight_day']
        flight_number_1 = row['flight_number_1']
        flight_number_2 = row['flight_number_2']
        baggage = row['baggage']
        valid_begin_hour = row['valid_begin_hour']
        df_val = validate_one_line(db, city_pair, flight_day, flight_number_1, flight_number_2, baggage, valid_begin_hour)
        # at validation time the database may not yet contain any records after valid_begin_hour
        if not df_val.empty:
            df_val_f = fill_hourly_crawl_date(df_val, rear_fill=2)
            df_val_f = df_val_f[df_val_f['is_filled'] == 0]  # keep only original rows, drop the back-filled ones
            if df_val_f.empty:
                drop_flag = 0
                first_drop_amount = pd.NA
                first_drop_hours = pd.NA
                last_hours_util = pd.NA
                last_update_hour = pd.NA
                list_change_price = []
                list_change_hours = []
            else:
                # last row of the valid data
                last_row = df_val_f.iloc[-1]
                last_hours_util = last_row['hours_until_departure']
                last_update_hour = last_row['update_hour']
                # keep only the rows where the price differs from the previous row
                df_price_changes = df_val_f.loc[
                    df_val_f["adult_total_price"].shift() != df_val_f["adult_total_price"]
                ].copy()
                # magnitude of each price change
                df_price_changes['change_amount'] = df_price_changes['adult_total_price'].diff().fillna(0)
                # find the first row whose change_amount is below -10
                first_negative_change = df_price_changes[df_price_changes['change_amount'] < -10].head(1)
                # extract the required values
                if not first_negative_change.empty:
                    drop_flag = 1
                    first_drop_amount = first_negative_change['change_amount'].iloc[0].round(2)
                    first_drop_hours = first_negative_change['hours_until_departure'].iloc[0]
                else:
                    drop_flag = 0
                    first_drop_amount = pd.NA
                    first_drop_hours = pd.NA
                list_change_price = df_price_changes['adult_total_price'].tolist()
                list_change_hours = df_price_changes['hours_until_departure'].tolist()
        else:
            drop_flag = 0
            first_drop_amount = pd.NA
            first_drop_hours = pd.NA
            last_hours_util = pd.NA
            last_update_hour = pd.NA
            list_change_price = []
            list_change_hours = []
        safe_sep = "; "
        df_predict.at[idx, 'change_prices'] = safe_sep.join(map(str, list_change_price))
        df_predict.at[idx, 'change_hours'] = safe_sep.join(map(str, list_change_hours))
        df_predict.at[idx, 'last_hours_util'] = last_hours_util
        df_predict.at[idx, 'last_update_hour'] = last_update_hour
        df_predict.at[idx, 'first_drop_amount'] = first_drop_amount * -1  # flip the negative drop to a positive amount
        df_predict.at[idx, 'first_drop_hours'] = first_drop_hours
        df_predict.at[idx, 'drop_flag'] = drop_flag
        count += 1
        if count % 5 == 0:
            print(f"cal count: {count}")
    print("calculation finished")
    client.close()
    timestamp_str = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    save_csv = f"result_validate_{node}_{interval_hours}_{pred_time_str}_{timestamp_str}.csv"
    output_path = os.path.join(output_dir, save_csv)
    df_predict.to_csv(output_path, index=False, encoding="utf-8-sig")
    print(f"saved: {output_path}")
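
# Minimal sketch (not part of the original pipeline) illustrating the price-drop
# detection used in validate_process: keep only the rows where the price differs
# from the previous row, diff them, and flag the first drop larger than 10.
# Column names mirror the ones assumed above; the sample values are made up.
def _demo_price_drop_detection():
    df = pd.DataFrame({
        "hours_until_departure": [72, 60, 48, 36, 24],
        "adult_total_price": [500.0, 500.0, 480.0, 480.0, 470.0],
    })
    changes = df.loc[df["adult_total_price"].shift() != df["adult_total_price"]].copy()
    changes["change_amount"] = changes["adult_total_price"].diff().fillna(0)
    first_drop = changes[changes["change_amount"] < -10].head(1)
    # expected: a single row with change_amount -20.0 at 48 hours until departure
    print(first_drop[["hours_until_departure", "change_amount"]])
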
if __name__ == "__main__":
    node, interval_hours, pred_time_str = "node0112", 8, "202601141600"
    validate_process(node, interval_hours, pred_time_str)