Files
very_short_lightning/3_train_v4.py
2025-07-28 11:08:04 +08:00

192 lines
7.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import pickle
import lightgbm as lgb
import numpy as np
import pandas as pd
import xgboost as xgb
from sklearn.metrics import (accuracy_score, balanced_accuracy_score,
classification_report, mean_absolute_error,
mean_squared_error, r2_score)
from sklearn.model_selection import train_test_split
def load_station_ids(station_csv_path: str) -> list:
    """Read every station's EFID from the station-info CSV.

    Args:
        station_csv_path (str): Path to the station information CSV file.

    Returns:
        list: All values of the "EFID" column, in file order.
    """
    station_info: pd.DataFrame = pd.read_csv(station_csv_path)
    return station_info["EFID"].tolist()
def load_and_merge_data_for_station(efid: int) -> pd.DataFrame | None:
"""
加载指定站点的flash和ef数据并按时间合并若有一个不存在则返回None。
Args:
efid (int): 站点EFID
Returns:
pd.DataFrame | None: 合并后的数据框若数据缺失则为None
"""
# 构建文件路径
flash_path: str = f"data/preprocessed_flash/{efid}/flash_{efid}.parquet"
ef_path: str = f"data/preprocessed_ef/ef_{efid}.parquet"
# 检查文件是否存在
if not (os.path.exists(flash_path) and os.path.exists(ef_path)):
# 如果有一个文件不存在返回None
print(f"站点{efid}的flash或ef数据不存在跳过。")
return None
# 读取数据
df_flash: pd.DataFrame = pd.read_parquet(flash_path)
df_ef: pd.DataFrame = pd.read_parquet(ef_path)
# 转换时间列为datetime
df_flash['time'] = pd.to_datetime(df_flash['time'])
df_ef['time'] = pd.to_datetime(df_ef['time'])
# 合并数据
merged_df: pd.DataFrame = pd.merge(df_flash, df_ef, on='time', how='inner')
# 删除time列
merged_df = merged_df.drop('time', axis=1)
# 增加站点标识列
merged_df['EFID'] = efid
return merged_df
def prepare_full_dataset(station_csv_path: str) -> pd.DataFrame:
    """Concatenate the merged flash/ef data of every listed station.

    Stations whose data files are missing are silently skipped (the loader
    returns None for them).

    Args:
        station_csv_path (str): Path to the station information CSV file.

    Returns:
        pd.DataFrame: All available stations' rows with a fresh index.

    Raises:
        ValueError: If not a single station yielded usable data.
    """
    per_station = (
        load_and_merge_data_for_station(efid)
        for efid in load_station_ids(station_csv_path)
    )
    frames: list = [df for df in per_station if df is not None]
    if not frames:
        raise ValueError("没有可用的站点数据。")
    return pd.concat(frames, ignore_index=True)
# ---------------------------------------------------------------------------
# --- 主流程 ---
# ---------------------------------------------------------------------------
if __name__ == "__main__":
# 读取并合并所有站点数据
# station_id.csv为站点信息文件
full_merged_df: pd.DataFrame = prepare_full_dataset("station_id.csv")
print("合并后数据框形状:", full_merged_df.shape)
print(full_merged_df.head())
# 分离特征和目标变量
# 确保列名是字符串,以避免 LightGBM 的问题
X: pd.DataFrame = full_merged_df.drop('flash_count', axis=1)
if "EFID" in X.columns:
X = X.drop("EFID", axis=1)
X.columns = ["".join(c if c.isalnum() else "_" for c in str(x)) for x in X.columns]
y: pd.Series = full_merged_df['flash_count']
# 1. 先训练分类模型判断flash_count是否为0
y_cls: pd.Series = (y > 0).astype(int) # 0为无闪电1为有闪电
X_train_cls, X_test_cls, y_train_cls, y_test_cls = train_test_split(
X, y_cls, test_size=0.2, random_state=42, stratify=y_cls
)
# 1. 训练分类模型判断flash_count是否为0
# 由于数据量较大适当增加树的数量调整max_depth防止过拟合
classifier: xgb.XGBClassifier = xgb.XGBClassifier(
n_estimators=300, # 增加树的数量以提升表现
learning_rate=0.05, # 学习率
max_depth=6, # 控制树的最大深度,防止过拟合
subsample=0.8, # 随机采样部分样本,提升泛化能力
colsample_bytree=0.8, # 随机采样部分特征,提升泛化能力
random_state=42, # 随机种子
use_label_encoder=False, # 关闭label encoder警告
eval_metric='logloss', # 评估指标
n_jobs=-1 # 使用所有CPU核心加速训练
)
# 拟合分类器
classifier.fit(X_train_cls, y_train_cls)
# 2. 只用flash_count>0的数据训练回归模型
# 选取flash_count大于0的样本作为回归数据
# 注意X[y > 0] 可能返回DataFrame或Series需确保类型正确
X_reg = X.loc[y > 0] # 保证返回DataFrame
y_reg = y.loc[y > 0] # 保证返回Series
# 划分回归训练集和测试集
X_train_reg, X_test_reg, y_train_reg, y_test_reg = train_test_split(
X_reg, y_reg, test_size=0.2, random_state=42
)
# 创建XGBoost回归器设置参数适应大数据量
regressor: xgb.XGBRegressor = xgb.XGBRegressor(
n_estimators=400, # 增加树的数量以提升表现
learning_rate=0.05, # 学习率
max_depth=6, # 控制树的最大深度
subsample=0.8, # 随机采样部分样本
colsample_bytree=0.8, # 随机采样部分特征
random_state=42, # 随机种子
n_jobs=-1 # 使用所有CPU核心加速训练
)
# 拟合回归器
regressor.fit(X_train_reg, y_train_reg)
# ---------------------------------------------------------------------------
# --- 模型保存 ---
# ---------------------------------------------------------------------------
os.makedirs("model", exist_ok=True)
with open("model/classifier_model.pkl", "wb") as f:
pickle.dump(classifier, f)
print("分类器模型已保存到 model/classifier_model.pkl")
with open("model/regressor_model.pkl", "wb") as f:
pickle.dump(regressor, f)
print("回归器模型已保存到 model/regressor_model.pkl")
# =====
# 性能评估
# ===========================================
# 分类器预测
y_pred_cls: list[int] = classifier.predict(X_test_cls)
# 计算准确率
cls_accuracy: float = accuracy_score(y_test_cls, y_pred_cls)
print(f"分类器在测试集上的准确率: {cls_accuracy:.4f}")
# 输出详细分类报告
print("分类器详细分类报告:")
print(classification_report(y_test_cls, y_pred_cls, digits=4))
# 评估回归器在测试集上的性能
# 回归器只在flash_count>0的样本上评估
y_pred_reg: list[float] = regressor.predict(X_test_reg)
# 计算均方误差
reg_mse: float = mean_squared_error(y_test_reg, y_pred_reg)
# 计算平均绝对误差
reg_mae: float = mean_absolute_error(y_test_reg, y_pred_reg)
# 计算R2分数
reg_r2: float = r2_score(y_test_reg, y_pred_reg)
print(f"回归器在测试集上的均方误差(MSE): {reg_mse:.4f}")
print(f"回归器在测试集上的平均绝对误差(MAE): {reg_mae:.4f}")
print(f"回归器在测试集上的R2分数: {reg_r2:.4f}")