init
This commit is contained in:
191
3_train_v4.py
Normal file
191
3_train_v4.py
Normal file
@@ -0,0 +1,191 @@
|
||||
import os
|
||||
import pickle
|
||||
|
||||
import lightgbm as lgb
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import xgboost as xgb
|
||||
from sklearn.metrics import (accuracy_score, balanced_accuracy_score,
|
||||
classification_report, mean_absolute_error,
|
||||
mean_squared_error, r2_score)
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
|
||||
def load_station_ids(station_csv_path: str) -> list:
|
||||
"""
|
||||
读取所有站点的EFID。
|
||||
|
||||
Args:
|
||||
station_csv_path (str): 站点信息csv文件路径
|
||||
|
||||
Returns:
|
||||
list: EFID列表
|
||||
"""
|
||||
# 读取站点信息
|
||||
df_station: pd.DataFrame = pd.read_csv(station_csv_path)
|
||||
# 获取EFID列并转为列表
|
||||
efid_list: list = df_station["EFID"].tolist()
|
||||
return efid_list
|
||||
|
||||
def load_and_merge_data_for_station(efid: int) -> pd.DataFrame | None:
|
||||
"""
|
||||
加载指定站点的flash和ef数据,并按时间合并,若有一个不存在则返回None。
|
||||
|
||||
Args:
|
||||
efid (int): 站点EFID
|
||||
|
||||
Returns:
|
||||
pd.DataFrame | None: 合并后的数据框,若数据缺失则为None
|
||||
"""
|
||||
# 构建文件路径
|
||||
flash_path: str = f"data/preprocessed_flash/{efid}/flash_{efid}.parquet"
|
||||
ef_path: str = f"data/preprocessed_ef/ef_{efid}.parquet"
|
||||
# 检查文件是否存在
|
||||
if not (os.path.exists(flash_path) and os.path.exists(ef_path)):
|
||||
# 如果有一个文件不存在,返回None
|
||||
print(f"站点{efid}的flash或ef数据不存在,跳过。")
|
||||
return None
|
||||
# 读取数据
|
||||
df_flash: pd.DataFrame = pd.read_parquet(flash_path)
|
||||
df_ef: pd.DataFrame = pd.read_parquet(ef_path)
|
||||
# 转换时间列为datetime
|
||||
df_flash['time'] = pd.to_datetime(df_flash['time'])
|
||||
df_ef['time'] = pd.to_datetime(df_ef['time'])
|
||||
# 合并数据
|
||||
merged_df: pd.DataFrame = pd.merge(df_flash, df_ef, on='time', how='inner')
|
||||
# 删除time列
|
||||
merged_df = merged_df.drop('time', axis=1)
|
||||
# 增加站点标识列
|
||||
merged_df['EFID'] = efid
|
||||
return merged_df
|
||||
|
||||
def prepare_full_dataset(station_csv_path: str) -> pd.DataFrame:
|
||||
"""
|
||||
合并所有站点的flash和ef数据,跳过缺失的站点。
|
||||
|
||||
Args:
|
||||
station_csv_path (str): 站点信息csv文件路径
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: 合并后的大数据框
|
||||
"""
|
||||
# 获取所有站点EFID
|
||||
efid_list: list = load_station_ids(station_csv_path)
|
||||
# 存储所有合并后的数据
|
||||
merged_list: list = []
|
||||
# 遍历所有站点
|
||||
for efid in efid_list:
|
||||
merged_df: pd.DataFrame | None = load_and_merge_data_for_station(efid)
|
||||
if merged_df is not None:
|
||||
merged_list.append(merged_df)
|
||||
# 合并所有站点数据
|
||||
if not merged_list:
|
||||
raise ValueError("没有可用的站点数据。")
|
||||
full_df: pd.DataFrame = pd.concat(merged_list, ignore_index=True)
|
||||
return full_df
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# --- 主流程 ---
|
||||
# ---------------------------------------------------------------------------
|
||||
if __name__ == "__main__":
|
||||
# 读取并合并所有站点数据
|
||||
# station_id.csv为站点信息文件
|
||||
full_merged_df: pd.DataFrame = prepare_full_dataset("station_id.csv")
|
||||
|
||||
print("合并后数据框形状:", full_merged_df.shape)
|
||||
print(full_merged_df.head())
|
||||
|
||||
# 分离特征和目标变量
|
||||
# 确保列名是字符串,以避免 LightGBM 的问题
|
||||
X: pd.DataFrame = full_merged_df.drop('flash_count', axis=1)
|
||||
if "EFID" in X.columns:
|
||||
X = X.drop("EFID", axis=1)
|
||||
X.columns = ["".join(c if c.isalnum() else "_" for c in str(x)) for x in X.columns]
|
||||
|
||||
y: pd.Series = full_merged_df['flash_count']
|
||||
|
||||
# 1. 先训练分类模型,判断flash_count是否为0
|
||||
y_cls: pd.Series = (y > 0).astype(int) # 0为无闪电,1为有闪电
|
||||
|
||||
X_train_cls, X_test_cls, y_train_cls, y_test_cls = train_test_split(
|
||||
X, y_cls, test_size=0.2, random_state=42, stratify=y_cls
|
||||
)
|
||||
|
||||
|
||||
# 1. 训练分类模型,判断flash_count是否为0
|
||||
# 由于数据量较大,适当增加树的数量,调整max_depth防止过拟合
|
||||
classifier: xgb.XGBClassifier = xgb.XGBClassifier(
|
||||
n_estimators=300, # 增加树的数量以提升表现
|
||||
learning_rate=0.05, # 学习率
|
||||
max_depth=6, # 控制树的最大深度,防止过拟合
|
||||
subsample=0.8, # 随机采样部分样本,提升泛化能力
|
||||
colsample_bytree=0.8, # 随机采样部分特征,提升泛化能力
|
||||
random_state=42, # 随机种子
|
||||
use_label_encoder=False, # 关闭label encoder警告
|
||||
eval_metric='logloss', # 评估指标
|
||||
n_jobs=-1 # 使用所有CPU核心加速训练
|
||||
)
|
||||
# 拟合分类器
|
||||
classifier.fit(X_train_cls, y_train_cls)
|
||||
|
||||
# 2. 只用flash_count>0的数据训练回归模型
|
||||
# 选取flash_count大于0的样本作为回归数据
|
||||
# 注意:X[y > 0] 可能返回DataFrame或Series,需确保类型正确
|
||||
X_reg = X.loc[y > 0] # 保证返回DataFrame
|
||||
y_reg = y.loc[y > 0] # 保证返回Series
|
||||
|
||||
# 划分回归训练集和测试集
|
||||
X_train_reg, X_test_reg, y_train_reg, y_test_reg = train_test_split(
|
||||
X_reg, y_reg, test_size=0.2, random_state=42
|
||||
)
|
||||
|
||||
# 创建XGBoost回归器,设置参数适应大数据量
|
||||
regressor: xgb.XGBRegressor = xgb.XGBRegressor(
|
||||
n_estimators=400, # 增加树的数量以提升表现
|
||||
learning_rate=0.05, # 学习率
|
||||
max_depth=6, # 控制树的最大深度
|
||||
subsample=0.8, # 随机采样部分样本
|
||||
colsample_bytree=0.8, # 随机采样部分特征
|
||||
random_state=42, # 随机种子
|
||||
n_jobs=-1 # 使用所有CPU核心加速训练
|
||||
)
|
||||
# 拟合回归器
|
||||
regressor.fit(X_train_reg, y_train_reg)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# --- 模型保存 ---
|
||||
# ---------------------------------------------------------------------------
|
||||
os.makedirs("model", exist_ok=True)
|
||||
|
||||
with open("model/classifier_model.pkl", "wb") as f:
|
||||
pickle.dump(classifier, f)
|
||||
print("分类器模型已保存到 model/classifier_model.pkl")
|
||||
|
||||
with open("model/regressor_model.pkl", "wb") as f:
|
||||
pickle.dump(regressor, f)
|
||||
print("回归器模型已保存到 model/regressor_model.pkl")
|
||||
|
||||
# =====
|
||||
# 性能评估
|
||||
# ===========================================
|
||||
# 分类器预测
|
||||
y_pred_cls: list[int] = classifier.predict(X_test_cls)
|
||||
# 计算准确率
|
||||
cls_accuracy: float = accuracy_score(y_test_cls, y_pred_cls)
|
||||
print(f"分类器在测试集上的准确率: {cls_accuracy:.4f}")
|
||||
# 输出详细分类报告
|
||||
print("分类器详细分类报告:")
|
||||
print(classification_report(y_test_cls, y_pred_cls, digits=4))
|
||||
|
||||
# 评估回归器在测试集上的性能
|
||||
# 回归器只在flash_count>0的样本上评估
|
||||
y_pred_reg: list[float] = regressor.predict(X_test_reg)
|
||||
# 计算均方误差
|
||||
reg_mse: float = mean_squared_error(y_test_reg, y_pred_reg)
|
||||
# 计算平均绝对误差
|
||||
reg_mae: float = mean_absolute_error(y_test_reg, y_pred_reg)
|
||||
# 计算R2分数
|
||||
reg_r2: float = r2_score(y_test_reg, y_pred_reg)
|
||||
print(f"回归器在测试集上的均方误差(MSE): {reg_mse:.4f}")
|
||||
print(f"回归器在测试集上的平均绝对误差(MAE): {reg_mae:.4f}")
|
||||
print(f"回归器在测试集上的R2分数: {reg_r2:.4f}")
|
||||
Reference in New Issue
Block a user