# very_short_lightning/3_train.py — 2025-07-28 11:08:04 +08:00
# (181 lines, 7.1 KiB, Python)
# NOTE: the web viewer warned that this file contains ambiguous Unicode
# characters (full-width punctuation in Chinese comments/strings).
import pandas as pd
import lightgbm as lgb
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score, classification_report, balanced_accuracy_score
import numpy as np
import os
import pickle
from imblearn.over_sampling import SMOTE
from imblearn.pipeline import Pipeline # 引入 imblearn Pipeline 以防止数据泄露
# ---------------------------------------------------------------------------
# --- Main pipeline: load data, align on timestamp, split train/test ---
# ---------------------------------------------------------------------------
df_flash = pd.read_parquet("data/preprocessed_flash/212/flash_212.parquet")
df_ef = pd.read_parquet("data/preprocessed_ef/212/ef_212.parquet")
# Normalize both time columns to datetime so the inner-join keys match exactly.
for frame in (df_flash, df_ef):
    frame['time'] = pd.to_datetime(frame['time'])
merged_df = pd.merge(df_flash, df_ef, on='time', how='inner').drop(columns='time')
print("数据框形状:", merged_df.shape)
print(merged_df.head())
# Separate features from the target; sanitize column names to alphanumerics
# and underscores (LightGBM rejects certain special characters).
X = merged_df.drop(columns='flash_count')
X.columns = ["".join(ch if ch.isalnum() else "_" for ch in str(col)) for col in X.columns]
y = merged_df['flash_count']
# Stratify on the zero/non-zero indicator so both splits keep a similar class mix.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=(y > 0)
)
# ---------------------------------------------------------------------------
# --- Step 1: train the imbalanced binary classifier (with SMOTE) ---
# ---------------------------------------------------------------------------
print("\n--- 步骤 1: 正在训练分类模型 (使用 SMOTE) ---")
# Binary target for the hurdle model: 0 = no flashes, 1 = at least one flash.
y_train_binary = (y_train > 0).astype(int)
y_test_binary = (y_test > 0).astype(int)
# Classifier for the zero / non-zero decision.
classifier = lgb.LGBMClassifier(n_estimators=100, random_state=42)
# SMOTE oversampler with a fixed seed for reproducibility.
# (A previous comment claimed n_jobs=-1 was used; it is not passed here.)
smote = SMOTE(random_state=42,)
# Use the imblearn Pipeline: it applies SMOTE during fit only (never at
# predict time), which is the standard way to avoid oversampled data
# leaking into validation folds during cross-validation.
smote_pipeline = Pipeline([
    ('smote', smote),
    ('classifier', classifier)
])
# Train the classifier on the oversampled training data.
print("正在对训练数据进行SMOTE过采样并训练分类器...")
smote_pipeline.fit(X_train, y_train_binary)
print("分类器训练完成。")
# ---------------------------------------------------------------------------
# --- Step 2: train the regressor on the non-zero samples only ---
# ---------------------------------------------------------------------------
print("\n--- 步骤 2: 正在训练回归模型 (仅在非零数据上) ---")
regressor = lgb.LGBMRegressor(n_estimators=150, learning_rate=0.05, random_state=42)
# Restrict the regression task to rows where at least one flash occurred.
mask_pos_train = (y_train > 0)
X_train_pos = X_train[mask_pos_train]
y_train_pos = y_train[mask_pos_train]
# Guard: the regressor can only be fitted when non-zero samples exist.
if len(X_train_pos) == 0:
    print("警告:训练集中没有非零数据,回归器未被训练。")
else:
    # log1p compresses the right-skewed count distribution before fitting.
    y_train_pos_log = np.log1p(y_train_pos)
    print(f"正在使用 {X_train_pos.shape[0]} 个非零样本训练回归器...")
    regressor.fit(X_train_pos, y_train_pos_log)
    print("回归器训练完成。")
# ---------------------------------------------------------------------------
# --- Persist both models ---
# ---------------------------------------------------------------------------
os.makedirs("model", exist_ok=True)
# Save the two models separately so each can be loaded independently later.
for model_path, model_obj, saved_msg in (
    ("model/212_classifier_model.pkl", smote_pipeline,
     "\n分类器模型已保存到 model/212_classifier_model.pkl"),
    ("model/212_regressor_model.pkl", regressor,
     "回归器模型已保存到 model/212_regressor_model.pkl"),
):
    with open(model_path, "wb") as fh:
        pickle.dump(model_obj, fh)
    print(saved_msg)
# ---------------------------------------------------------------------------
# --- Prediction and evaluation ---
# ---------------------------------------------------------------------------
print("\n--- 正在加载模型并进行预测 ---")
# Reload both models from disk (demonstrates how a separate script would use them).
with open("model/212_classifier_model.pkl", "rb") as f:
    loaded_classifier_pipeline = pickle.load(f)
with open("model/212_regressor_model.pkl", "rb") as f:
    loaded_regressor = pickle.load(f)
# --- Combined (hurdle-style) prediction ---
# 1. P(flash_count > 0) from the classifier.
prob_positive = loaded_classifier_pipeline.predict_proba(X_test)[:, 1]
# 2. Conditional magnitude from the regressor (trained on the log1p scale).
#    `n_features_in_` only exists on a fitted estimator, so this guards
#    against the "no non-zero training data" case above.
if hasattr(loaded_regressor, 'n_features_in_'):
    predictions_log = loaded_regressor.predict(X_test)
    # Back to the original count scale. BUGFIX: the regressor can extrapolate
    # below log1p(0) = 0, which made expm1 produce negative flash counts;
    # clamp to 0 since the target is a non-negative count.
    predictions_pos = np.maximum(np.expm1(predictions_log), 0.0)
else:
    # Regressor was never trained: predict zero magnitude everywhere.
    predictions_pos = np.zeros(X_test.shape[0])
# 3. Expected count = P(non-zero) * E[count | non-zero].
final_predictions = prob_positive * predictions_pos
# --- Internal binary classifier evaluation (SMOTE pipeline) ---
print("\n--- 内部二元分类器评估 (使用SMOTE) ---")
y_pred_binary = loaded_classifier_pipeline.predict(X_test)
report = classification_report(
    y_test_binary, y_pred_binary,
    target_names=['是 零 (class 0)', '非 零 (class 1)'],
)
print(report)
bal_acc = balanced_accuracy_score(y_test_binary, y_pred_binary)
print(f"平衡准确率 (Balanced Accuracy): {bal_acc:.4f}")
# --- Final regression-task evaluation (on the original count scale) ---
print("\n--- 最终回归任务评估 ---")
mse = mean_squared_error(y_test, final_predictions)
rmse = mse ** 0.5
r2 = r2_score(y_test, final_predictions)
print(f"均方误差 (MSE): {mse:.4f}")
print(f"均方根误差 (RMSE): {rmse:.4f}")
print(f"决定系数 (R²): {r2:.4f}")
# ---------------------------------------------------------------------------
# --- Feature importances ---
# ---------------------------------------------------------------------------
print("\n--- 特征重要性排序 ---")
# Pull the fitted classifier back out of the imblearn pipeline.
final_classifier = smote_pipeline.named_steps['classifier']
# Classifier feature importances.
clf_imp = pd.DataFrame({
    'feature': X_train.columns,
    'importance_classifier': final_classifier.feature_importances_
})
# `n_features_in_` only exists once the regressor has actually been fitted;
# compute the check once instead of repeating hasattr below.
regressor_is_fitted = hasattr(regressor, 'n_features_in_')
if regressor_is_fitted:
    # Regressor feature importances (column names from the regressor's training set).
    reg_imp = pd.DataFrame({
        'feature': X_train_pos.columns,
        'importance_regressor': regressor.feature_importances_
    })
    # Outer-merge so a feature seen by only one model still appears (missing side -> 0).
    importances = pd.merge(clf_imp, reg_imp, on='feature', how='outer').fillna(0)
    # Cast to int for readability (LightGBM split counts are integral).
    importances['importance_classifier'] = importances['importance_classifier'].astype(int)
    importances['importance_regressor'] = importances['importance_regressor'].astype(int)
else:
    importances = clf_imp
    importances['importance_regressor'] = 0
# BUGFIX: header string was garbled ("'''非零'"); restored to "'非零'".
print("\n分类器重要性 (用于预测'非零'):")
print(importances.sort_values('importance_classifier', ascending=False).head(10))
if regressor_is_fitted:
    print("\n回归器重要性 (用于预测'非零值'的大小):")
    print(importances.sort_values('importance_regressor', ascending=False).head(10))