Files
very_short_lightning/3_train_v3.py
2025-07-28 11:08:04 +08:00

65 lines
2.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import pandas as pd
import lightgbm as lgb
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score, classification_report, balanced_accuracy_score
import numpy as np
import os
import pickle
from imblearn.over_sampling import SMOTE
from imblearn.pipeline import Pipeline # 引入 imblearn Pipeline 以防止数据泄露
# ---------------------------------------------------------------------------
# --- Main flow ---
# ---------------------------------------------------------------------------
# Load the flash-count table and the electric-field (EF) table, in that order.
df_flash, df_ef = (
    pd.read_parquet(parquet_path)
    for parquet_path in (
        "data/preprocessed_flash/212/flash_212.parquet",
        "data/preprocessed_ef/212/ef_212.parquet",
    )
)
# Give the join key a datetime dtype in both tables before merging on it.
for frame in (df_flash, df_ef):
    frame["time"] = pd.to_datetime(frame["time"])
# An inner join keeps only timestamps present in both tables. The timestamp
# column is then dropped so it does not end up among the model features.
merged_df = df_flash.merge(df_ef, on="time", how="inner").drop(columns="time")
print("数据框形状:", merged_df.shape)
print(merged_df.head())
# Separate the features from the target.
# Feature names are reduced to alphanumerics/underscores because LightGBM
# refuses feature names containing special (JSON) characters. str() also
# guards against non-string column labels.
X = merged_df.drop('flash_count', axis=1)
X.columns = ["".join(ch if ch.isalnum() else "_" for ch in str(col)) for col in X.columns]
y = merged_df['flash_count']

# Two-stage (hurdle) model:
#   1. a classifier predicts whether any flash occurs (flash_count > 0);
#   2. a regressor, trained only on rows with flashes, predicts the count.
# Both stages share one stratified split. The regressor trains on the positive
# rows of the classifier's training split, so neither model ever sees a
# test-split row and the combined model can be evaluated without leakage.
# Passing `y` alongside (X, y_cls) does not change which rows are selected,
# so the classifier's split is the same as splitting (X, y_cls) alone.
y_cls = (y > 0).astype(int)  # 0 = no lightning, 1 = lightning
(X_train_cls, X_test_cls,
 y_train_cls, y_test_cls,
 y_train_cnt, y_test_cnt) = train_test_split(
    X, y_cls, y, test_size=0.2, random_state=42, stratify=y_cls
)

# Stage 1: flash / no-flash classifier.
classifier = lgb.LGBMClassifier(n_estimators=100, learning_rate=0.05, random_state=42)
classifier.fit(X_train_cls, y_train_cls)

# Stage 2: count regressor, fitted only on rows that actually had flashes.
train_has_flash = y_train_cnt > 0
X_train_reg, y_train_reg = X_train_cls[train_has_flash], y_train_cnt[train_has_flash]
test_has_flash = y_test_cnt > 0
X_test_reg, y_test_reg = X_test_cls[test_has_flash], y_test_cnt[test_has_flash]
regressor = lgb.LGBMRegressor(n_estimators=150, learning_rate=0.05, random_state=42)
regressor.fit(X_train_reg, y_train_reg)

# --- Evaluation on the held-out split ---
cls_pred = classifier.predict(X_test_cls)
print("分类器测试集评估:")
print(classification_report(y_test_cls, cls_pred, digits=4))
print("平衡准确率:", balanced_accuracy_score(y_test_cls, cls_pred))

# A stratified split can still leave no positive rows in a tiny test set, and
# the regression metrics raise on empty input, so evaluate only if some exist.
if len(y_test_reg) > 0:
    reg_pred = regressor.predict(X_test_reg)
    print("回归器测试集评估 (flash_count > 0):")
    print("  RMSE:", np.sqrt(mean_squared_error(y_test_reg, reg_pred)))
    print("  R2:", r2_score(y_test_reg, reg_pred))

# End-to-end prediction: 0 where the classifier says "no flash", otherwise the
# regressor's predicted count.
combined_pred = np.where(cls_pred == 1, regressor.predict(X_test_cls), 0.0)
print("组合模型测试集评估:")
print("  RMSE:", np.sqrt(mean_squared_error(y_test_cnt, combined_pred)))
print("  R2:", r2_score(y_test_cnt, combined_pred))
# ---------------------------------------------------------------------------
# --- Model saving ---
# ---------------------------------------------------------------------------
# Each path is used for both the write and the log message, so the message
# always reports where the file was really written.
# NOTE: these are pickles; only load them from trusted locations.
os.makedirs("model", exist_ok=True)
for fitted_model, model_path, model_label in (
    (classifier, "model/212_classifier_model.pkl", "分类器模型"),
    (regressor, "model/212_regressor_model.pkl", "回归器模型"),
):
    with open(model_path, "wb") as f:
        pickle.dump(fitted_model, f)
    print(f"{model_label}已保存到 {model_path}")