init
This commit is contained in:
51
3_train_v2.py
Normal file
51
3_train_v2.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import pandas as pd
|
||||
import lightgbm as lgb
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.metrics import mean_squared_error, r2_score, classification_report, balanced_accuracy_score
|
||||
import numpy as np
|
||||
import os
|
||||
import pickle
|
||||
from imblearn.over_sampling import SMOTE
|
||||
from imblearn.pipeline import Pipeline # 引入 imblearn Pipeline 以防止数据泄露
|
||||
|
||||
# ---------------------------------------------------------------------------
# --- Main pipeline ---
# ---------------------------------------------------------------------------
# Load the two preprocessed tables for station 212.
df_flash = pd.read_parquet("data/preprocessed_flash/212/flash_212.parquet")
df_ef = pd.read_parquet("data/preprocessed_ef/212/ef_212.parquet")

# Convert both `time` columns to datetimes before they are used as join keys.
for frame in (df_ef, df_flash):
    frame['time'] = pd.to_datetime(frame['time'])

# Inner join: only timestamps present in both tables survive. The key column
# is dropped afterwards so it is not part of the modelling table.
merged_df = (
    df_flash
    .merge(df_ef, on='time', how='inner')
    .drop(columns='time')
)

# Quick sanity check of the joined table.
print("数据框形状:", merged_df.shape)
print(merged_df.head())
|
||||
|
||||
# Separate the features from the target variable.
X = merged_df.drop(columns='flash_count')

# Make sure every column name is a string and replace each non-alphanumeric
# character with "_" to avoid LightGBM problems with feature names.
X.columns = [
    "".join(ch if ch.isalnum() else "_" for ch in str(name))
    for name in X.columns
]

# The replacement above is not one-to-one (e.g. "a-b" and "a_b" both become
# "a_b"), so two different source columns can end up with the same feature
# name. Fail early with a readable message instead of hitting an opaque
# failure (or ambiguous features) during training.
duplicated_names = X.columns[X.columns.duplicated()].unique().tolist()
if duplicated_names:
    raise ValueError(
        f"Sanitized feature names are not unique: {duplicated_names}"
    )

y = merged_df['flash_count']
|
||||
|
||||
# Split into training and test sets.
# Stratify on whether the target is positive so that the training and test
# sets keep a similar zero / non-zero ratio.
has_flash = y > 0
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
    stratify=has_flash,
)
|
||||
|
||||
|
||||
# Define the regressor; the hyper-parameters are gathered in one place.
regressor_params = {
    "n_estimators": 150,
    "learning_rate": 0.05,
    "random_state": 42,
}
regressor = lgb.LGBMRegressor(**regressor_params)

# Train on the training split only.
regressor.fit(X_train, y_train)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# --- Model persistence ---
# ---------------------------------------------------------------------------
# Create the output directory if it does not exist yet.
os.makedirs("model", exist_ok=True)

# Serialize the fitted regressor with pickle.
model_path = "model/212_regressor_model.pkl"
with open(model_path, "wb") as model_file:
    pickle.dump(regressor, model_file)

print("回归器模型已保存到 model/212_regressor_model.pkl")
|
||||
|
||||
Reference in New Issue
Block a user