Files
very_short_lightning/1_get_ef_data.py
2025-07-28 11:08:04 +08:00

234 lines
9.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import pandas as pd
import os, glob
import numpy as np
from scipy.stats import skew, kurtosis
import pywt
from scipy.signal import butter, filtfilt
from concurrent.futures import ProcessPoolExecutor, as_completed
from loguru import logger
import warnings
import traceback
# Rotating file log for this preprocessing run (DEBUG and above; 10 MB files kept 10 days).
logger.add("1_get_ef_data.log", rotation="10 MB", retention="10 days", level="DEBUG")
# Suppress specific RuntimeWarnings that scipy/numpy emit for (near-)constant
# rolling windows; safe_skew/safe_kurtosis below already guard those windows,
# so these warnings would only be noise.
warnings.filterwarnings("ignore", category=RuntimeWarning,
message=".*Precision loss occurred in moment calculation.*")
warnings.filterwarnings("ignore", category=RuntimeWarning,
message=".*invalid value encountered in divide.*")
def safe_skew(arr, axis=1, bias=False, batch_size=50000):
    """Skewness along ``axis`` that tolerates degenerate (constant) windows.

    Rows/columns whose variance is ~0 are left as NaN instead of letting
    scipy raise precision-loss warnings; the remaining rows are fed to
    ``scipy.stats.skew`` in chunks of at most ``batch_size`` so very large
    strided arrays are never processed in a single call.

    Returns a 1-D float array; NaN marks undefined or failed entries.
    """
    try:
        logger.debug(f"safe_skew: input array shape: {arr.shape}")
        out_len = arr.shape[0] if axis == 1 else arr.shape[1]
        out = np.full(out_len, np.nan)
        # Only compute skew where the variance is meaningfully non-zero.
        variances = np.var(arr, axis=axis, ddof=0 if bias else 1)
        usable = variances > 1e-10
        logger.debug(f"safe_skew: valid_mask count: {np.sum(usable)}/{len(usable)}")
        usable_idx = np.where(usable)[0]
        for chunk_start in range(0, len(usable_idx), batch_size):
            chunk = usable_idx[chunk_start:chunk_start + batch_size]
            block = arr[chunk] if axis == 1 else arr[:, chunk]
            try:
                out[chunk] = skew(block, axis=axis, bias=bias, nan_policy='omit')
            except Exception as e:
                # A failed chunk keeps its NaN placeholders; others still proceed.
                logger.warning(f"safe_skew batch error: {e}")
        logger.debug(f"safe_skew: result shape: {out.shape}, nan count: {np.sum(np.isnan(out))}")
        return out
    except Exception as e:
        logger.error(f"safe_skew error: {e}")
        logger.error(traceback.format_exc())
        return np.full(arr.shape[0] if axis == 1 else arr.shape[1], np.nan)
def safe_kurtosis(arr, axis=1, bias=False, batch_size=50000):
    """Kurtosis along ``axis`` that tolerates degenerate (constant) windows.

    Mirrors ``safe_skew``: entries with ~0 variance stay NaN (kurtosis is
    undefined there and would trigger precision-loss warnings), and the
    valid entries are handed to ``scipy.stats.kurtosis`` in batches of at
    most ``batch_size``.

    Returns a 1-D float array; NaN marks undefined or failed entries.
    """
    try:
        logger.debug(f"safe_kurtosis: input array shape: {arr.shape}")
        out_len = arr.shape[0] if axis == 1 else arr.shape[1]
        out = np.full(out_len, np.nan)
        # Skip (near-)constant windows entirely.
        variances = np.var(arr, axis=axis, ddof=0 if bias else 1)
        usable = variances > 1e-10
        logger.debug(f"safe_kurtosis: valid_mask count: {np.sum(usable)}/{len(usable)}")
        usable_idx = np.where(usable)[0]
        for chunk_start in range(0, len(usable_idx), batch_size):
            chunk = usable_idx[chunk_start:chunk_start + batch_size]
            block = arr[chunk] if axis == 1 else arr[:, chunk]
            try:
                out[chunk] = kurtosis(block, axis=axis, bias=bias, nan_policy='omit')
            except Exception as e:
                # Leave this chunk as NaN and continue with the rest.
                logger.warning(f"safe_kurtosis batch error: {e}")
        logger.debug(f"safe_kurtosis: result shape: {out.shape}, nan count: {np.sum(np.isnan(out))}")
        return out
    except Exception as e:
        logger.error(f"safe_kurtosis error: {e}")
        logger.error(traceback.format_exc())
        return np.full(arr.shape[0] if axis == 1 else arr.shape[1], np.nan)
# --- 1. Data loading & preprocessing ---
def single_station(station_id="212"):
    """Preprocess one station's electric-field data into rolling features.

    Loads all ``./data/ef/{station_id}/{station_id}_*.parquet`` files,
    merges and de-duplicates them on timestamp, reindexes to a continuous
    1-second grid (linearly interpolating the gaps), then extracts rolling
    statistics (mean/std/var/min/max/ptp/skew/kurt and first-difference
    mean/std) over windows of 2..1200 seconds, and writes the result to
    ``./data/preprocessed_ef/ef_{station_id}.parquet``.

    Returns True on success, False on any failure (all errors are logged,
    never raised).
    """
    try:
        logger.info(f"开始处理站点 {station_id}")
        # Bail out early if the station's input directory is missing.
        input_dir = f"./data/ef/{station_id}"
        if not os.path.exists(input_dir):
            logger.error(f"输入目录不存在: {input_dir}")
            return False
        # Load and merge all parquet files for this station.
        df_list = []
        file_pattern = f"./data/ef/{station_id}/{station_id}_*.parquet"
        files = glob.glob(file_pattern)
        if not files:
            logger.error(f"没有找到匹配的文件: {file_pattern}")
            return False
        logger.info(f"找到 {len(files)} 个文件: {files}")
        for file_path in files:
            try:
                df_temp = pd.read_parquet(file_path)
                df_list.append(df_temp)
                logger.info(f"成功加载文件: {file_path}, 数据量: {len(df_temp)}")
            except Exception as e:
                # A single unreadable file should not abort the whole station.
                logger.error(f"加载文件失败 {file_path}: {e}")
                continue
        if not df_list:
            logger.error(f"没有成功加载任何数据文件")
            return False
        df = pd.concat(df_list, ignore_index=True)
        logger.info(f"合并后数据量: {len(df)}")
        df = df.rename(columns={"HappenTime": "time", "ElectricField": "ef"})
        df["time"] = pd.to_datetime(df["time"])  # ensure 'time' is datetime
        # Sort chronologically and drop duplicate timestamps, keeping the
        # first occurrence.  (Fix: the original called drop_duplicates twice
        # in a row; the second call was a redundant no-op.)
        df = df.sort_values('time').drop_duplicates(subset='time', keep='first')
        if len(df) == 0:
            logger.error(f"处理后数据为空")
            return False
        # Build a continuous 1-second index spanning the full time range.
        full_time_index = pd.date_range(start=df["time"].min(), end=df["time"].max(), freq="s")
        df = df.set_index("time").reindex(full_time_index).rename_axis("time")
        # Fill the NaNs introduced by reindexing via linear interpolation
        # (limit_direction='both' also fills leading/trailing gaps).
        df['ef'] = df['ef'].interpolate(method='linear', limit_direction='both')
        df = df.reset_index()
        # --- 2. Statistical feature extraction ---
        window_sizes = [2, 5, 10, 30, 60, 300, 600, 1200]

        def rolling_window(a, window):
            # sliding_window_view yields the same (n-window+1, window) strided
            # layout as the old hand-rolled as_strided, but is memory-safe
            # (read-only view, bounds-checked).
            return np.lib.stride_tricks.sliding_window_view(a, window)

        def pad_result(res, win):
            # Left-pad with NaN so every feature column aligns with the raw series.
            return np.concatenate([np.full(win - 1, np.nan), res])

        stat_features_all = pd.DataFrame({'time': df['time']})
        ef = df['ef'].values
        # First-difference series is window-independent: compute it once
        # (the original recomputed it inside every window iteration).
        diff = np.diff(ef, prepend=ef[0])
        # Feature names in output-column order.
        feature_names = ['mean', 'std', 'var', 'min', 'max', 'ptp',
                         'skew', 'kurt', 'diff_mean', 'diff_std']
        for win in window_sizes:
            logger.info(f"正在处理{win}")
            if len(ef) < win:
                # Series shorter than the window: emit all-NaN columns so the
                # output schema stays identical across stations.
                n = len(ef)
                for feat in feature_names:
                    stat_features_all[f'{feat}_{win}s'] = np.full(n, np.nan)
                continue
            win_arr = rolling_window(ef, win)
            stat_features_all[f'mean_{win}s'] = pad_result(win_arr.mean(axis=1), win)
            stat_features_all[f'std_{win}s'] = pad_result(win_arr.std(axis=1, ddof=1), win)
            stat_features_all[f'var_{win}s'] = pad_result(win_arr.var(axis=1, ddof=1), win)
            stat_features_all[f'min_{win}s'] = pad_result(win_arr.min(axis=1), win)
            stat_features_all[f'max_{win}s'] = pad_result(win_arr.max(axis=1), win)
            stat_features_all[f'ptp_{win}s'] = pad_result(np.ptp(win_arr, axis=1), win)
            stat_features_all[f'skew_{win}s'] = pad_result(safe_skew(win_arr, axis=1, bias=False), win)
            stat_features_all[f'kurt_{win}s'] = pad_result(safe_kurtosis(win_arr, axis=1, bias=False), win)
            # Rolling statistics of the first difference.
            diff_win_arr = rolling_window(diff, win)
            stat_features_all[f'diff_mean_{win}s'] = pad_result(diff_win_arr.mean(axis=1), win)
            stat_features_all[f'diff_std_{win}s'] = pad_result(diff_win_arr.std(axis=1, ddof=1), win)
        save_path = f"./data/preprocessed_ef/ef_{station_id}.parquet"
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        logger.info(f"开始保存{save_path}")
        stat_features_all.to_parquet(save_path, index=False)
        logger.info(f"成功保存到 {save_path}, 数据量: {len(stat_features_all)}")
        logger.info(f"数据预览:\n{stat_features_all.head()}")
        return True
    except Exception as e:
        logger.error(f"处理站点 {station_id} 时发生错误: {e}")
        logger.error(traceback.format_exc())
        return False
if __name__ == "__main__":
    # All station ids to preprocess in this run.
    stations = ["213","249","251","252","253","254","261","262","263","266","267","268","269","270","271","272","276","281","282","283","285","286","212"]
    failed_stations = []
    success_count = 0
    for station_id in stations:
        logger.info(f"\n{'='*50}")
        logger.info(f"开始处理站点: {station_id}")
        ok = single_station(station_id)
        if not ok:
            failed_stations.append(station_id)
            logger.error(f"站点 {station_id} 处理失败")
        else:
            success_count += 1
            logger.info(f"站点 {station_id} 处理成功")
    # Summary of the whole run.
    logger.info(f"\n{'='*50}")
    logger.info(f"处理完成!")
    logger.info(f"成功处理: {success_count}/{len(stations)} 个站点")
    if failed_stations:
        logger.error(f"失败站点: {failed_stations}")
    # List what actually landed in the output directory.
    output_dir = "./data/preprocessed_ef"
    if not os.path.exists(output_dir):
        logger.error(f"输出目录不存在: {output_dir}")
    else:
        output_files = [f for f in os.listdir(output_dir) if f.endswith('.parquet')]
        logger.info(f"输出目录中共有 {len(output_files)} 个parquet文件")
        for file in output_files:
            logger.info(f" - {file}")