# Source: very_short_lightning/4_predict_v2.py
# Snapshot retrieved 2025-07-28 11:08:04 +08:00 (170 lines, 7.6 KiB).
# NOTE(review): the original scrape included file-viewer chrome and a
# "ambiguous Unicode characters" warning from the hosting UI; that text was
# not part of the program and has been reduced to this comment header.
import pandas as pd
import pickle
import os
import glob
import numpy as np
from scipy.stats import skew, kurtosis
from loguru import logger
# --- Configuration parameters (kept consistent with the original script) ---
# Rolling-window lengths, in seconds, for which features are computed.
WINDOW_SIZES: list[int] = [2, 5, 10, 30, 60, 300, 600, 1200]
# Computing the 1200 s window needs at least 20 minutes of data; to be safe we
# load about 25 minutes. Assuming each CSV file holds 5 minutes of data, load
# the latest 5 files.
NUM_FILES_TO_LOAD: int = 5
def get_current_features(station_id: str) -> tuple[pd.DataFrame, pd.Timestamp] | tuple[None, None]:
"""
加载指定站点的最新实时数据,并计算最后一个时间点的滚动窗口特征。
Args:
station_id (str): 站点ID用于定位数据目录。
Returns:
tuple[pd.DataFrame, pd.Timestamp] | tuple[None, None]:
包含单个时间点及其所有特征的DataFrame和最新时间点如果数据不足则返回(None, None)。
"""
logger.info(f"\n--- 开始为站点 {station_id} 计算当前特征 ---")
# 1. 加载近期数据
file_list: list[str] = glob.glob(f"./realtime_data/{station_id}/*.csv")
if not file_list:
logger.info(f"错误: 在 './realtime_data/{station_id}/' 目录下未找到任何CSV文件。")
return None, None
file_list = sorted(file_list)
files_to_process: list[str] = file_list[-NUM_FILES_TO_LOAD:]
logger.info(f"准备加载最新的 {len(files_to_process)} 个文件: {files_to_process}")
df_list: list[pd.DataFrame] = [pd.read_csv(f) for f in files_to_process]
if not df_list:
logger.info("错误: 未能成功读取任何文件。")
return None, None
df: pd.DataFrame = pd.concat(df_list, ignore_index=True)
logger.info(df.head())
# 2. 数据预处理
df["time"] = pd.to_datetime(df["time"], format="%Y-%m-%d %H:%M:%S")
df = df.sort_values('time').drop_duplicates(subset='time', keep='first')
logger.info("line49 df, %s", df)
if df.empty:
logger.info("错误: 处理后的数据为空。")
return None, None
# 创建连续时间索引并插值
full_time_index: pd.DatetimeIndex = pd.date_range(start=df["time"].min(), end=df["time"].max(), freq="s")
df = df.set_index("time").reindex(full_time_index).rename_axis("time")
df['ef'] = df['ef'].interpolate(method='linear', limit_direction='both')
df = df.reset_index()
# 3. 为最后一个时间点计算特征
latest_time: pd.Timestamp = df['time'].iloc[-1]
logger.info(f"数据已预处理完毕,将为最新时间点 {latest_time} 计算特征。")
ef_values: np.ndarray = df['ef'].values
diff_values: np.ndarray = np.diff(ef_values, prepend=ef_values[0])
current_features: dict = {'time': latest_time}
for win in WINDOW_SIZES:
# 检查是否有足够的数据来形成窗口
if len(ef_values) < win:
logger.info(f"数据长度 ({len(ef_values)}s) 不足以计算 {win}s 窗口特征将填充为NaN。")
current_features[f'mean_{win}s'] = np.nan
current_features[f'std_{win}s'] = np.nan
current_features[f'var_{win}s'] = np.nan
current_features[f'min_{win}s'] = np.nan
current_features[f'max_{win}s'] = np.nan
current_features[f'ptp_{win}s'] = np.nan
current_features[f'skew_{win}s'] = np.nan
current_features[f'kurt_{win}s'] = np.nan
current_features[f'diff_mean_{win}s'] = np.nan
current_features[f'diff_std_{win}s'] = np.nan
continue
# 只取信号末尾的 `win` 个点进行计算
last_window_ef: np.ndarray = ef_values[-win:]
last_window_diff: np.ndarray = diff_values[-win:]
# 计算统计特征
current_features[f'mean_{win}s'] = last_window_ef.mean()
current_features[f'std_{win}s'] = last_window_ef.std(ddof=1)
current_features[f'var_{win}s'] = last_window_ef.var(ddof=1)
current_features[f'min_{win}s'] = last_window_ef.min()
current_features[f'max_{win}s'] = last_window_ef.max()
current_features[f'ptp_{win}s'] = np.ptp(last_window_ef)
current_features[f'skew_{win}s'] = skew(last_window_ef, bias=False)
current_features[f'kurt_{win}s'] = kurtosis(last_window_ef, bias=False)
# 计算差分特征
current_features[f'diff_mean_{win}s'] = last_window_diff.mean()
current_features[f'diff_std_{win}s'] = last_window_diff.std(ddof=1)
# 4. 返回单行DataFrame
return pd.DataFrame([current_features]), latest_time
# Per-path cache so each pickled model file is read from disk at most once
# per run, instead of once per station iteration.
_MODEL_CACHE: dict = {}


def _load_model(path: str):
    """Load a pickled model from *path*, caching by path; None if missing.

    NOTE(security): ``pickle.load`` executes arbitrary code from the file —
    only load model files from trusted locations.
    """
    if path not in _MODEL_CACHE:
        if os.path.exists(path):
            with open(path, "rb") as f:
                _MODEL_CACHE[path] = pickle.load(f)
        else:
            _MODEL_CACHE[path] = None
    return _MODEL_CACHE[path]


if __name__ == "__main__":
    # Every sub-directory of realtime_data/ is treated as one station ID.
    all_paths: list = glob.glob("realtime_data/*")
    station_id_list: list = [os.path.basename(path) for path in all_paths if os.path.isdir(path)]
    print(f"station_id_list:{station_id_list}")
    for sid in station_id_list:
        # Compute the latest feature row and its timestamp for this station.
        latest_features_df, latest_time = get_current_features(sid)
        if latest_features_df is None or latest_time is None:
            logger.info("未能获取到有效的特征数据,跳过该站点。")
            continue
        # The models expect features only, not the 'time' column.
        features_for_pred: pd.DataFrame = latest_features_df.drop('time', axis=1)
        # 1. The classifier decides whether the target is zero at all.
        classifier_model_path: str = "model/classifier_model.pkl"
        classifier_model = _load_model(classifier_model_path)
        if classifier_model is None:
            logger.error(f"分类模型文件不存在: {classifier_model_path}")
            continue
        try:
            classifier_pred = classifier_model.predict(features_for_pred)[0]
        except Exception as e:
            logger.error(f"分类模型预测出错: {e}")
            continue
        logger.info(f"分类模型预测结果: {classifier_pred}")
        # 2. A zero class short-circuits the regressor; otherwise predict.
        if classifier_pred == 0:
            prediction = 0
            logger.info("分类结果为0回归预测直接设为0。")
        else:
            regressor_model_path: str = "model/regressor_model.pkl"
            # Loaded lazily so a missing regressor only matters when needed.
            regressor_model = _load_model(regressor_model_path)
            if regressor_model is None:
                logger.error(f"回归模型文件不存在: {regressor_model_path}")
                continue
            try:
                prediction = regressor_model.predict(features_for_pred)[0]
            except Exception as e:
                logger.error(f"回归模型预测出错: {e}")
                continue
            logger.info(f"回归模型预测结果: {prediction}")
        # 3. Append the result to an hourly CSV; the header is written only
        # when the file is new or empty (same effect as the old tell()==0 check).
        result: pd.DataFrame = pd.DataFrame({'time': [latest_time], 'prediction': [prediction]})
        file_name: str = f"result/{sid}/{latest_time:%Y%m%d%H}.csv"
        os.makedirs(os.path.dirname(file_name), exist_ok=True)
        write_header: bool = not os.path.exists(file_name) or os.path.getsize(file_name) == 0
        result.to_csv(file_name, mode='a', index=False, header=write_header, encoding='utf-8')
        print(f"结果已通过文本追加方式写入: {file_name}")