This commit is contained in:
lvhao
2025-07-28 11:08:04 +08:00
parent 07ab95ff51
commit 47a6cc00e7
11 changed files with 1470 additions and 0 deletions

64
3_train_v3.py Normal file
View File

@@ -0,0 +1,64 @@
import pandas as pd
import lightgbm as lgb
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score, classification_report, balanced_accuracy_score
import numpy as np
import os
import pickle
from imblearn.over_sampling import SMOTE
from imblearn.pipeline import Pipeline # 引入 imblearn Pipeline 以防止数据泄露
# ---------------------------------------------------------------------------
# --- 主流程 ---
# ---------------------------------------------------------------------------
df_flash = pd.read_parquet("data/preprocessed_flash/212/flash_212.parquet")
df_ef = pd.read_parquet("data/preprocessed_ef/212/ef_212.parquet")
df_ef['time'] = pd.to_datetime(df_ef['time'])
df_flash['time'] = pd.to_datetime(df_flash['time'])
merged_df = pd.merge(df_flash, df_ef, on='time', how='inner')
merged_df = merged_df.drop('time', axis=1)
print("数据框形状:", merged_df.shape)
print(merged_df.head())
# 分离特征和目标变量
# 确保列名是字符串,以避免 LightGBM 的问题
X = merged_df.drop('flash_count', axis=1)
X.columns = ["".join (c if c.isalnum() else "_" for c in str(x)) for x in X.columns]
y = merged_df['flash_count']
# 1. 先训练分类模型判断flash_count是否为0
y_cls = (y > 0).astype(int) # 0为无闪电1为有闪电
X_train_cls, X_test_cls, y_train_cls, y_test_cls = train_test_split(
X, y_cls, test_size=0.2, random_state=42, stratify=y_cls
)
classifier = lgb.LGBMClassifier(n_estimators=100, learning_rate=0.05, random_state=42)
classifier.fit(X_train_cls, y_train_cls)
# 2. 只用flash_count>0的数据训练回归模型
X_reg = X[y > 0]
y_reg = y[y > 0]
X_train_reg, X_test_reg, y_train_reg, y_test_reg = train_test_split(
X_reg, y_reg, test_size=0.2, random_state=42
)
regressor = lgb.LGBMRegressor(n_estimators=150, learning_rate=0.05, random_state=42)
regressor.fit(X_train_reg, y_train_reg)
# ---------------------------------------------------------------------------
# --- 模型保存 ---
# ---------------------------------------------------------------------------
os.makedirs("model", exist_ok=True)
with open("model/212_classifier_model.pkl", "wb") as f:
pickle.dump(classifier, f)
print("分类器模型已保存到 model/212/212_classifier_model.pkl")
with open("model/212_regressor_model.pkl", "wb") as f:
pickle.dump(regressor, f)
print("回归器模型已保存到 model/212/212_regressor_model.pkl")