在老鼠瞳孔测量实验中,鼠眨眼会造成瞳孔直径(PD)曲线失真,这是由多种物理和生理机制共同导致的伪影。这里分享一种处理方法:获取眨眼区间,根据眨眼区间前后的瞳孔直径,构造线性插值,替换眨眼区间的数据。
准备工作
使用代码前请安装好python(演示的版本是3.8.12),以及代码依赖的库:numpy, pandas, h5py(可以通过pip install packagename 指令进行库的下载)
import glob
import os.path
from itertools import groupby
from operator import itemgetter
import h5py # 3.11.0
import pandas as pd # 2.0.3
import numpy as np # 1.24.4
from numpy.lib.stride_tricks import sliding_window_view
参数设置
代码中必须修改设置的是眼动工程文件夹路径project_path 和完成眼动追踪生成的h5文件名file_names。可以选择修改的是眨眼提取阈值blink_threshold、眨眼范围内的邻域点数neighbor_points,以及异常点相比于平均值的最大倍数max_mul(max_mul 是 fast_outlier_removal 函数的默认参数,可在调用处修改)等。
project_path = r'C:\Users\ASUS\Desktop\111'  # root folder of the analyzed eye-tracking project (batch mode)
# Glob pattern searched under <project_path>/results; narrow it to one exact
# file name to process a single sample.  Plain string: no placeholders, so no
# f-string prefix is needed.
file_names = '*-camera-eye-result.h5'
file_path = glob.glob(os.path.join(project_path, 'results', file_names))  # matching result files (may be empty)
blink_threshold = 0.90  # frames whose mean keypoint confidence is below this are treated as blinks
neighbor_points = 3  # neighbouring frames taken on each side of a blink for interpolation
完整代码
按需修改完上述参数后,即可运行下面的完整代码:
# coding=utf-8
# python3
"""
@FileName:论坛分享.py
@Author: FangXiang
@Company: Guangdong BayONE Scientific CO., Ltd.
@Email: x.fang@bayonesci.com
@CreateFileTime: 2026/1/22 10:42
@Description: 对移动眼动分析完的数据进行眨眼区间插值修复、异常点修复,最后导出为 csv 格式
"""
import glob
import os.path
from itertools import groupby
from operator import itemgetter
import h5py # 3.11.0
import pandas as pd # 2.0.3
import numpy as np # 1.24.4
from numpy.lib.stride_tricks import sliding_window_view
def linear_interpolation(df, column, start, end, neighbor_points=3, decimals=3):
    """
    Repair a blink-corrupted frame range (and any remaining NaNs) in place
    using linear interpolation.

    Values of `column` for frames in [start, end] are replaced by a straight
    line fitted through up to `neighbor_points` frames on each side of the
    range.  Afterwards, any NaN entries still present in the column are
    filled by linear interpolation over all valid frames.

    Parameters:
        df: DataFrame with a 'frame' column (1-based frame numbers)
        column: name of the column to repair (modified in place)
        start: first frame of the corrupted range (inclusive)
        end: last frame of the corrupted range (inclusive)
        neighbor_points: frames taken on each side of the range
        decimals: decimal places the repaired values are rounded to
    Returns:
        None (df is modified in place)
    """
    frames = df['frame']
    # Margin frames just before and just after the corrupted range,
    # clamped to the recording boundaries.
    before_mask = (frames >= max(start - neighbor_points, frames.min())) & (frames < start)
    after_mask = (frames > end) & (frames <= min(end + neighbor_points, frames.max()))
    margin_mask = before_mask | after_mask
    margin_indices = frames[margin_mask]
    margin_values = df.loc[margin_mask, column]
    # Frames inside the corrupted range that will be overwritten.
    range_mask = (frames >= start) & (frames <= end)
    if not margin_indices.empty:
        # np.interp clamps to the nearest margin value when only one side has
        # margin points (range touching the start or end of the recording).
        interpolated = np.interp(frames[range_mask], margin_indices, margin_values)
        df.loc[range_mask, column] = np.round(interpolated, decimals)
    else:
        # No usable neighbours at all (original code raised ValueError here):
        # mark the range as NaN so the global repair below fills it.
        df.loc[range_mask, column] = np.nan
    # Repair any NaN samples (failed tracking or the fallback above) using
    # all valid frames of the whole recording.
    if df[column].isna().any():
        nan_mask = df[column].isna()
        valid_indices = df.loc[~nan_mask, 'frame']
        valid_values = df.loc[~nan_mask, column]
        interpolated = np.interp(frames[nan_mask], valid_indices, valid_values)
        df.loc[nan_mask, column] = np.round(interpolated, decimals)
def fast_outlier_removal(pupil, window_size=3, k=2.0, max_mul=5):
    """
    Repair outlier samples with a sliding-window median/MAD test.

    A sample is flagged as an outlier when it exceeds median + k * MAD of
    its local window; flagged samples are replaced by the mean of the
    non-outlier values of their window.  Normal samples are untouched.
    Gross outliers (above max_mul times the global mean) are first clipped
    to the global mean so they cannot distort the window statistics.

    Parameters:
        pupil: 1-D sequence of pupil diameters (not modified)
        window_size: sliding-window length (odd values recommended)
        k: MAD multiplier for the outlier threshold
        max_mul: samples above max_mul * global mean are pre-clipped
    Returns:
        np.ndarray (float32) with outliers replaced
    """
    # Explicit copy: np.asarray can alias the caller's buffer (float32 input),
    # in which case the clipping below would mutate the input in place.
    data = np.array(pupil, dtype=np.float32, copy=True)
    mean_val = np.mean(data)
    # Pre-clip gross outliers so they do not dominate the window medians.
    data[data > mean_val * max_mul] = mean_val
    half = window_size // 2
    # Reflect-pad the edges so every sample owns a full window.
    padded = np.pad(data, (half, half), mode='reflect')
    windows = sliding_window_view(padded, window_size)  # shape (N, window_size)
    # Per-window median and MAD; epsilon keeps thresholds finite on flat signals.
    median = np.median(windows, axis=1)  # (N,)
    mad = np.median(np.abs(windows - median[:, None]), axis=1) + 1e-6  # (N,)
    # Mask of window values above the robust threshold (broadcast per row).
    is_outlier = windows > (median[:, None] + k * mad[:, None])
    # Mean of the non-outlier values in each window; fall back to the window
    # median when every value is flagged (possible with even window sizes),
    # avoiding a division by zero.
    windows_sum = np.sum(windows * (~is_outlier), axis=1)
    windows_count = np.sum(~is_outlier, axis=1)
    windows_mean = np.where(windows_count > 0,
                            windows_sum / np.maximum(windows_count, 1),
                            median)
    # Replace only the samples that are themselves outliers.
    point_is_outlier = data > (median + k * mad)
    result = data.copy()
    result[point_is_outlier] = windows_mean[point_is_outlier]
    return result
project_path = r'E:\Project Test\eye project\demo'  # root folder of the analyzed projects (batch mode)
# Glob pattern under <project_path>/results; set an explicit file name to
# process a single sample.  Plain string: no placeholders, no f-prefix needed.
file_names = '*-camera-eye-result.h5'
file_path = glob.glob(os.path.join(project_path, 'results', file_names))
blink_threshold = 0.90  # frames whose mean keypoint confidence is below this are treated as blinks
neighbor_points = 3  # neighbouring frames on each side used for interpolation
lens = len(file_path)
count = 0
for count, result_h5 in enumerate(file_path, start=1):
    print(f'开始处理:{count}/{lens}', result_h5)
    # Context manager guarantees the HDF5 handle is closed on every path,
    # including the early 'continue' below (the original leaked it there).
    with h5py.File(result_h5, 'r') as h5_file:
        if 'key_point_coordinate' in h5_file.keys():
            dataset_key = h5_file['key_point_coordinate']['coordinate_data']
            dataset_processed = h5_file['pupil_parameter_processed']['pupil_data'][:, 0]
        elif 'PointPosition' in h5_file.keys():
            dataset_key = h5_file['PointPosition']
            dataset_processed = h5_file['PupilData'][:, 2]
        else:
            print('样本未做瞳孔追踪,无坐标数据')
            continue
        # Mean confidence over the 9 tracked key points (columns 2, 5, 8, ...);
        # computed inside the 'with' because dataset_key is a lazy h5py dataset.
        likelihood = np.mean([dataset_key[:, 3 * i + 2] for i in range(9)], axis=0)
    indexes = np.where(likelihood < blink_threshold)[0]  # frames below the confidence threshold
    # Group consecutive low-confidence frame indices into (start, end) intervals.
    blinks = []
    for _, g in groupby(enumerate(indexes), lambda x: x[1] - x[0]):
        group = list(map(itemgetter(1), g))
        blinks.append((group[0], group[-1]))
    if blinks:
        # Merge blink intervals separated by at most gap_threshold frames.
        gap_threshold = 2
        blinks = sorted(blinks, key=lambda x: x[0])
        merged = [blinks[0]]
        for curr_start, curr_end in blinks[1:]:
            last_start, last_end = merged[-1]
            if curr_start - last_end - 1 <= gap_threshold:
                merged[-1] = (last_start, max(last_end, curr_end))
            else:
                merged.append((curr_start, curr_end))
        blinks = merged
        print('blinks:', blinks)
    if not blinks:
        print('没有眨眼事件,跳过此样本')
        continue
    data = pd.DataFrame({
        'frame': range(1, len(dataset_processed) + 1),
        'pupil_diameter': dataset_processed,
        'pupil_diameter_interp': dataset_processed.copy(),
    })
    for blink_start, blink_end in blinks:
        # Widen the interval by one frame on each side before interpolating.
        linear_interpolation(data, 'pupil_diameter_interp',
                             blink_start - 1, blink_end + 1, neighbor_points)
    # Clean up remaining spikes with the sliding-window outlier filter.
    data['pupil_diameter_interp'] = fast_outlier_removal(data['pupil_diameter_interp'])
    csv_path = result_h5.replace('-result.h5', '-interp.csv')
    data.to_csv(csv_path, index=False, float_format='%.6f')
    print(f'write csv success: {csv_path}!')