result_validate.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309
  1. import os
  2. import datetime
  3. import pandas as pd
  4. import argparse
  5. from data_loader import mongo_con_parse, validate_one_line, fill_hourly_crawl_date
  # Result columns appended to every prediction row, in output order.
  _RESULT_COLUMNS = (
      "change_prices",
      "change_hours",
      "last_hours_util",
      "last_update_hour",
      "first_drop_amount",
      "first_drop_hours",
      "drop_flag",
  )
  # Separator used to flatten the price/hour change lists into a single CSV cell.
  _RESULT_SEP = "; "
  # Prediction files are named future_predictions_<YYYYmmddHHMM>.csv.
  _PRED_PREFIX = "future_predictions_"
  _PRED_SUFFIX = ".csv"
  _STAMP_FORMAT = "%Y%m%d%H%M"


  def _prediction_dir(interval_hours):
      """Return the directory holding prediction CSVs for ``interval_hours``.

      Intervals 4 and 2 have dedicated directories; any other value (e.g. 8)
      uses ``./predictions``.
      """
      return {4: "./predictions_4", 2: "./predictions_2"}.get(interval_hours, "./predictions")


  def _load_predictions(csv_path):
      """Read a predictions CSV, returning an empty DataFrame if it cannot be read."""
      try:
          return pd.read_csv(csv_path)
      except Exception as e:
          # Best effort: an unreadable/missing file simply means nothing to validate.
          print(f"read {csv_path} error: {str(e)}")
          return pd.DataFrame()


  def _summarize_price_track(df_valid, drop_threshold=10):
      """Summarise the observed price track of one prediction row.

      Args:
          df_valid: genuinely crawled rows for one flight, in the order produced
              by the data loader (presumably chronological -- TODO confirm).
              Unless empty, it must have ``adult_total_price``,
              ``hours_until_departure`` and ``update_hour`` columns.
          drop_threshold: a price change counts as a drop only when it is
              strictly less than ``-drop_threshold``.

      Returns:
          A dict keyed by ``_RESULT_COLUMNS``. ``first_drop_amount`` is the
          positive magnitude of the first drop, rounded to 2 decimals. Values
          that cannot be computed are ``pd.NA``; the change lists are joined
          with ``"; "``.
      """
      if df_valid.empty:
          return {
              "change_prices": "",
              "change_hours": "",
              "last_hours_util": pd.NA,
              "last_update_hour": pd.NA,
              "first_drop_amount": pd.NA,
              "first_drop_hours": pd.NA,
              "drop_flag": 0,
          }

      prices = df_valid["adult_total_price"]
      # Keep the first row plus every row whose price differs from the previous row.
      df_changes = df_valid.loc[prices.shift() != prices]
      change_amount = df_changes["adult_total_price"].diff().fillna(0)
      is_drop = (change_amount < -drop_threshold).to_numpy()

      if is_drop.any():
          first = int(is_drop.argmax())  # position of the first qualifying drop
          drop_flag = 1
          first_drop_amount = change_amount.iloc[first].round(2) * -1  # negative -> positive
          first_drop_hours = df_changes["hours_until_departure"].iloc[first]
      else:
          drop_flag = 0
          first_drop_amount = pd.NA
          first_drop_hours = pd.NA

      last_row = df_valid.iloc[-1]
      return {
          "change_prices": _RESULT_SEP.join(map(str, df_changes["adult_total_price"].tolist())),
          "change_hours": _RESULT_SEP.join(map(str, df_changes["hours_until_departure"].tolist())),
          "last_hours_util": last_row["hours_until_departure"],
          "last_update_hour": last_row["update_hour"],
          "first_drop_amount": first_drop_amount,
          "first_drop_hours": first_drop_hours,
          "drop_flag": drop_flag,
      }


  def _annotate_predictions(df_predict, db):
      """Validate each prediction row against ``db`` and add the result columns in place.

      Results are collected first and assigned column-wise at the end, so
      ``drop_flag`` stays an integer column. Assigning cell by cell through
      ``.at`` would enlarge the frame and turn it into a float column.
      """
      summaries = []
      for count, (_, row) in enumerate(df_predict.iterrows(), start=1):
          df_val = validate_one_line(
              db,
              row['city_pair'],
              row['flight_day'],
              row['flight_number_1'],
              row['flight_number_2'],
              row['baggage'],
              row['valid_begin_hour'],
          )
          # At validation time the DB may hold no data after valid_begin_hour yet.
          if df_val.empty:
              df_valid = df_val
          else:
              df_filled = fill_hourly_crawl_date(df_val, rear_fill=2)
              # Keep only the original crawled rows, not the filled-in ones.
              df_valid = df_filled[df_filled['is_filled'] == 0]
          summaries.append(_summarize_price_track(df_valid))
          if count % 5 == 0:
              print(f"cal count: {count}")
      for column in _RESULT_COLUMNS:
          df_predict[column] = [summary[column] for summary in summaries]
      print("计算结束")


  def _validate_and_save(csv_path, output_dir, node, interval_hours, time_str):
      """Load ``csv_path``, validate it against MongoDB and write the result CSV.

      Returns the written output path, or None when there were no predictions.
      """
      df_predict = _load_predictions(csv_path)
      if df_predict.empty:
          print("预测数据为空")
          return None

      client, db = mongo_con_parse()
      try:
          _annotate_predictions(df_predict, db)
      finally:
          # Release the connection even if validating a row raises.
          client.close()

      timestamp_str = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
      save_csv = f"result_validate_{node}_{interval_hours}_{time_str}_{timestamp_str}.csv"
      output_path = os.path.join(output_dir, save_csv)
      df_predict.to_csv(output_path, index=False, encoding="utf-8-sig")
      print(f"保存完成: {output_path}")
      return output_path


  def validate_process(node, interval_hours, pred_time_str):
      """Validate one specific prediction file.

      Args:
          node: node label, used in the output directory and file name.
          interval_hours: prediction interval; selects the input directory.
          pred_time_str: timestamp in the input file name, ``YYYYmmddHHMM``.

      Results are written under ``./validate/{node}_{MMDD}``.
      """
      date = pred_time_str[4:8]  # MMDD part of the timestamp
      output_dir = f"./validate/{node}_{date}"
      os.makedirs(output_dir, exist_ok=True)
      csv_file = f'{_PRED_PREFIX}{pred_time_str}{_PRED_SUFFIX}'
      csv_path = os.path.join(_prediction_dir(interval_hours), csv_file)
      _validate_and_save(csv_path, output_dir, node, interval_hours, pred_time_str)


  def validate_process_auto(node, interval_hours):
      """Automatically validate the newest prediction file older than 24 hours.

      The current time is floored to the hour. The prediction CSV whose
      timestamp is the latest one strictly earlier than that hour minus 24h
      is validated, and the result is written under ``./validate/{node}``.
      """
      current_time = datetime.datetime.now()
      current_time_str = current_time.strftime(_STAMP_FORMAT)
      hourly_time = current_time.replace(minute=0, second=0, microsecond=0)
      vali_time_str = hourly_time.strftime(_STAMP_FORMAT)
      print(f"验证时间:{current_time_str}, (取整): {vali_time_str}")

      output_dir = f"./validate/{node}"
      os.makedirs(output_dir, exist_ok=True)
      object_dir = _prediction_dir(interval_hours)
      if not os.path.exists(object_dir):
          print(f"目录不存在: {object_dir}")
          return

      csv_files = [
          name for name in os.listdir(object_dir)
          if name.startswith(_PRED_PREFIX) and name.endswith(_PRED_SUFFIX)
      ]
      if not csv_files:
          print(f"在 {object_dir} 中没有找到 future_predictions_ 开头的 CSV 文件")
          return

      # e.g. future_predictions_202601151600.csv -> 202601151600 -> datetime
      file_times = []
      for name in csv_files:
          stamp = name[len(_PRED_PREFIX):-len(_PRED_SUFFIX)]
          try:
              file_times.append((name, datetime.datetime.strptime(stamp, _STAMP_FORMAT)))
          except ValueError as e:
              print(f"文件 {name} 的时间戳格式错误: {e}")
      if not file_times:
          print("没有找到有效的时间戳文件")
          return

      yesterday_time = hourly_time - datetime.timedelta(hours=24)
      yesterday_str = yesterday_time.strftime(_STAMP_FORMAT)
      print(f"昨天对应时间: {yesterday_str}")
      # Only predictions made strictly before the same hour yesterday qualify.
      valid_files = [(name, t) for name, t in file_times if t < yesterday_time]
      if not valid_files:
          print(f"没有找到小于昨天对应时间 {yesterday_str} 的文件")
          return

      last_valid_file, last_valid_time = max(valid_files, key=lambda item: item[1])
      last_valid_time_str = last_valid_time.strftime(_STAMP_FORMAT)
      print(f"找到符合条件的文件: {last_valid_file} (时间: {last_valid_time_str})")

      csv_path = os.path.join(object_dir, last_valid_file)
      if _validate_and_save(csv_path, output_dir, node, interval_hours, last_valid_time_str) is None:
          return
      print(f"验证完成: {node} {interval_hours} {last_valid_time_str}")
      print()
  if __name__ == "__main__":
      # Default node for each automatic interval; override with --node.
      default_nodes = {8: "node0112", 4: "node0114", 2: "node0115"}

      parser = argparse.ArgumentParser(description='验证脚本')
      # 0 (the default) selects manual validation of a single prediction file;
      # it must be in `choices` so that `--interval 0` is also accepted explicitly.
      parser.add_argument('--interval', type=int, choices=[0, 2, 4, 8],
                          default=0, help='间隔小时数(2, 4, 8); 0 = 手动验证')
      parser.add_argument('--node', default=None,
                          help='node label; defaults to node0112 (8h / manual), node0114 (4h), node0115 (2h)')
      parser.add_argument('--pred_time', default="202601151600",
                          help='manual mode only: prediction timestamp YYYYmmddHHMM')
      args = parser.parse_args()

      if args.interval == 0:
          # Manual validation of one 8-hour prediction file.
          validate_process(args.node or "node0112", 8, args.pred_time)
      else:
          # Automatic validation of the newest prediction older than 24h.
          validate_process_auto(args.node or default_nodes[args.interval], args.interval)