init
This commit is contained in:
233
1_get_ef_data.py
Normal file
233
1_get_ef_data.py
Normal file
@@ -0,0 +1,233 @@
|
||||
import pandas as pd
|
||||
import os, glob
|
||||
import numpy as np
|
||||
from scipy.stats import skew, kurtosis
|
||||
import pywt
|
||||
from scipy.signal import butter, filtfilt
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
from loguru import logger
|
||||
import warnings
|
||||
import traceback
|
||||
|
||||
# Log to a rotating file: rotate at 10 MB, keep rotated files for 10 days.
logger.add("1_get_ef_data.log", rotation="10 MB", retention="10 days", level="DEBUG")

# Suppress specific warnings raised by scipy moment calculations on
# near-constant windows (we handle those cases explicitly via variance checks).
warnings.filterwarnings("ignore", category=RuntimeWarning,
message=".*Precision loss occurred in moment calculation.*")
warnings.filterwarnings("ignore", category=RuntimeWarning,
message=".*invalid value encountered in divide.*")
|
||||
|
||||
def safe_skew(arr, axis=1, bias=False, batch_size=50000):
    """Skewness along *axis*, computed in batches with low-variance guarding.

    Entries whose variance along *axis* is <= 1e-10 are left as NaN (scipy
    would emit precision warnings / nonsense there). The remaining entries
    are processed in chunks of *batch_size* via ``scipy.stats.skew``.
    On any unexpected failure a NaN-filled array of the proper length is
    returned instead of raising.
    """
    try:
        logger.debug(f"safe_skew: input array shape: {arr.shape}")
        # ddof mirrors the bias convention: unbiased variance unless bias=True.
        variances = np.var(arr, axis=axis, ddof=0 if bias else 1)
        out_len = arr.shape[0] if axis == 1 else arr.shape[1]
        result = np.full(out_len, np.nan)
        valid_mask = variances > 1e-10
        logger.debug(f"safe_skew: valid_mask count: {np.sum(valid_mask)}/{len(valid_mask)}")

        ok_idx = np.where(valid_mask)[0]
        for offset in range(0, len(ok_idx), batch_size):
            chunk = ok_idx[offset:offset + batch_size]
            block = arr[chunk] if axis == 1 else arr[:, chunk]
            try:
                result[chunk] = skew(block, axis=axis, bias=bias, nan_policy='omit')
            except Exception as e:
                # A failing chunk only loses its own entries (they stay NaN).
                logger.warning(f"safe_skew batch error: {e}")
                continue

        logger.debug(f"safe_skew: result shape: {result.shape}, nan count: {np.sum(np.isnan(result))}")
        return result
    except Exception as e:
        logger.error(f"safe_skew error: {e}")
        logger.error(traceback.format_exc())
        return np.full(arr.shape[0] if axis == 1 else arr.shape[1], np.nan)
|
||||
|
||||
|
||||
def safe_kurtosis(arr, axis=1, bias=False, batch_size=50000):
    """Kurtosis along *axis*, computed in batches with low-variance guarding.

    Entries whose variance along *axis* is <= 1e-10 are left as NaN (scipy
    would emit precision warnings / nonsense there). The remaining entries
    are processed in chunks of *batch_size* via ``scipy.stats.kurtosis``.
    On any unexpected failure a NaN-filled array of the proper length is
    returned instead of raising.
    """
    try:
        logger.debug(f"safe_kurtosis: input array shape: {arr.shape}")
        # ddof mirrors the bias convention: unbiased variance unless bias=True.
        variances = np.var(arr, axis=axis, ddof=0 if bias else 1)
        out_len = arr.shape[0] if axis == 1 else arr.shape[1]
        result = np.full(out_len, np.nan)
        valid_mask = variances > 1e-10
        logger.debug(f"safe_kurtosis: valid_mask count: {np.sum(valid_mask)}/{len(valid_mask)}")

        ok_idx = np.where(valid_mask)[0]
        for offset in range(0, len(ok_idx), batch_size):
            chunk = ok_idx[offset:offset + batch_size]
            block = arr[chunk] if axis == 1 else arr[:, chunk]
            try:
                result[chunk] = kurtosis(block, axis=axis, bias=bias, nan_policy='omit')
            except Exception as e:
                # A failing chunk only loses its own entries (they stay NaN).
                logger.warning(f"safe_kurtosis batch error: {e}")
                continue

        logger.debug(f"safe_kurtosis: result shape: {result.shape}, nan count: {np.sum(np.isnan(result))}")
        return result
    except Exception as e:
        logger.error(f"safe_kurtosis error: {e}")
        logger.error(traceback.format_exc())
        return np.full(arr.shape[0] if axis == 1 else arr.shape[1], np.nan)
|
||||
|
||||
|
||||
# --- 1. 数据加载与预处理 ---
|
||||
def single_station(station_id="212"):
    """Load, clean and feature-engineer electric-field data for one station.

    Reads every ``./data/ef/{station_id}/{station_id}_*.parquet`` file, merges
    them into a per-second continuous series (gaps filled by linear
    interpolation), computes rolling statistical features for several window
    sizes and writes the result to
    ``./data/preprocessed_ef/ef_{station_id}.parquet``.

    Parameters
    ----------
    station_id : str
        Station identifier used to locate input files and name the output.

    Returns
    -------
    bool
        True on success; False on any failure (missing input, empty data,
        unexpected exception). Failures are logged, never raised.
    """
    try:
        logger.info(f"开始处理站点 {station_id}")

        # Verify the input directory exists before globbing.
        input_dir = f"./data/ef/{station_id}"
        if not os.path.exists(input_dir):
            logger.error(f"输入目录不存在: {input_dir}")
            return False

        # Load and concatenate every matching parquet file.
        df_list = []
        file_pattern = f"./data/ef/{station_id}/{station_id}_*.parquet"
        files = glob.glob(file_pattern)

        if not files:
            logger.error(f"没有找到匹配的文件: {file_pattern}")
            return False

        logger.info(f"找到 {len(files)} 个文件: {files}")

        for file_path in files:
            try:
                df_temp = pd.read_parquet(file_path)
                df_list.append(df_temp)
                logger.info(f"成功加载文件: {file_path}, 数据量: {len(df_temp)}")
            except Exception as e:
                # A single unreadable file should not abort the whole station.
                logger.error(f"加载文件失败 {file_path}: {e}")
                continue

        if not df_list:
            logger.error(f"没有成功加载任何数据文件")
            return False

        df = pd.concat(df_list, ignore_index=True)
        logger.info(f"合并后数据量: {len(df)}")

        df = df.rename(columns={"HappenTime": "time", "ElectricField": "ef"})
        df["time"] = pd.to_datetime(df["time"])  # ensure 'time' is datetime
        # FIX: the original de-duplicated on 'time' twice in a row; a single
        # sort + drop_duplicates pass is sufficient and equivalent.
        df = df.sort_values('time').drop_duplicates(subset='time', keep='first')

        if len(df) == 0:
            logger.error(f"处理后数据为空")
            return False

        # Build a continuous one-second index spanning the whole recording.
        full_time_index = pd.date_range(start=df["time"].min(), end=df["time"].max(), freq="s")
        df = df.set_index("time").reindex(full_time_index).rename_axis("time")

        # Fill the NaNs introduced by reindexing via linear interpolation
        # (limit_direction='both' also fills leading/trailing gaps).
        df['ef'] = df['ef'].interpolate(method='linear', limit_direction='both')
        df = df.reset_index()

        # --- 2. Statistical feature extraction ---
        window_sizes = [2, 5, 10, 30, 60, 300, 600, 1200]

        def rolling_window(a, window):
            # Zero-copy sliding-window view of a 1-D array (rows = windows).
            shape = (a.size - window + 1, window)
            strides = (a.strides[0], a.strides[0])
            return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)

        def pad_result(res, win):
            # Left-pad with NaN so windowed results align with the raw series.
            return np.concatenate([np.full(win - 1, np.nan), res])

        stat_features_all = pd.DataFrame({'time': df['time']})
        ef = df['ef'].values
        # FIX: the first-difference series does not depend on the window size;
        # the original recomputed it inside the loop for every window.
        diff = np.diff(ef, prepend=ef[0])

        # Column order must match the original output schema.
        feature_names = ['mean', 'std', 'var', 'min', 'max', 'ptp',
                         'skew', 'kurt', 'diff_mean', 'diff_std']

        for win in window_sizes:
            logger.info(f"正在处理{win}")
            if len(ef) < win:
                # Series shorter than the window: emit all-NaN columns.
                n = len(ef)
                for name in feature_names:
                    stat_features_all[f'{name}_{win}s'] = np.full(n, np.nan)
                continue

            win_arr = rolling_window(ef, win)
            stat_features_all[f'mean_{win}s'] = pad_result(win_arr.mean(axis=1), win)
            stat_features_all[f'std_{win}s'] = pad_result(win_arr.std(axis=1, ddof=1), win)
            stat_features_all[f'var_{win}s'] = pad_result(win_arr.var(axis=1, ddof=1), win)
            stat_features_all[f'min_{win}s'] = pad_result(win_arr.min(axis=1), win)
            stat_features_all[f'max_{win}s'] = pad_result(win_arr.max(axis=1), win)
            stat_features_all[f'ptp_{win}s'] = pad_result(np.ptp(win_arr, axis=1), win)
            stat_features_all[f'skew_{win}s'] = pad_result(safe_skew(win_arr, axis=1, bias=False), win)
            stat_features_all[f'kurt_{win}s'] = pad_result(safe_kurtosis(win_arr, axis=1, bias=False), win)

            # First-difference features over the same window.
            diff_win_arr = rolling_window(diff, win)
            stat_features_all[f'diff_mean_{win}s'] = pad_result(diff_win_arr.mean(axis=1), win)
            stat_features_all[f'diff_std_{win}s'] = pad_result(diff_win_arr.std(axis=1, ddof=1), win)

        save_path = f"./data/preprocessed_ef/ef_{station_id}.parquet"
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        logger.info(f"开始保存{save_path}")
        stat_features_all.to_parquet(save_path, index=False)
        logger.info(f"成功保存到 {save_path}, 数据量: {len(stat_features_all)}")
        logger.info(f"数据预览:\n{stat_features_all.head()}")

        return True

    except Exception as e:
        logger.error(f"处理站点 {station_id} 时发生错误: {e}")
        logger.error(traceback.format_exc())
        return False
|
||||
|
||||
if __name__ == "__main__":
    # Stations to preprocess, in run order.
    stations = ["213","249","251","252","253","254","261","262","263","266","267","268","269","270","271","272","276","281","282","283","285","286","212"]

    failed_stations = []
    success_count = 0

    for sid in stations:
        logger.info(f"\n{'='*50}")
        logger.info(f"开始处理站点: {sid}")

        ok = single_station(sid)
        if ok:
            success_count += 1
            logger.info(f"站点 {sid} 处理成功")
        else:
            failed_stations.append(sid)
            logger.error(f"站点 {sid} 处理失败")

    # Final summary.
    logger.info(f"\n{'='*50}")
    logger.info(f"处理完成!")
    logger.info(f"成功处理: {success_count}/{len(stations)} 个站点")

    if failed_stations:
        logger.error(f"失败站点: {failed_stations}")

    # Report the parquet files actually present in the output directory.
    output_dir = "./data/preprocessed_ef"
    if not os.path.exists(output_dir):
        logger.error(f"输出目录不存在: {output_dir}")
    else:
        output_files = [name for name in os.listdir(output_dir) if name.endswith('.parquet')]
        logger.info(f"输出目录中共有 {len(output_files)} 个parquet文件")
        for file in output_files:
            logger.info(f" - {file}")
|
||||
Reference in New Issue
Block a user