移动眼动瞳孔直径数据预处理1(瞳孔眨眼提取及插值修复、异常点修复)

在老鼠瞳孔测量实验中,鼠眨眼会造成瞳孔直径(PD)曲线失真,这是由多种物理和生理机制共同导致的伪影。这里分享一种处理方法:获取眨眼区间,根据眨眼区间前后的瞳孔直径,构造线性插值,替换眨眼区间的数据。

准备工作

使用代码前请安装好python(演示的版本是3.8.12),以及代码依赖的库:numpy, pandas, h5py(可以通过pip install packagename 指令进行库的下载)

import glob
import os.path
from itertools import groupby
from operator import itemgetter

import h5py   # 3.11.0
import pandas as pd  # 2.0.3
import numpy as np  # 1.24.4
from numpy.lib.stride_tricks import sliding_window_view

参数设置

代码中必须修改设置的是眼动工程文件夹路径project_path 和完成眼动追踪生成的h5文件名file_names。可以选择修改的是眨眼提取阈值blink_threshold、眨眼范围内的邻域点数neighbor_points、异常点相比于平均值的最大倍数max_mul等。

# Root folder of the analyzed eye-tracking project (batch mode).
project_path = r'C:\Users\ASUS\Desktop\111'     # path of the finished project to batch-process
# Glob pattern for result files under <project>/results; replace the wildcard
# with a concrete file name to process a single sample only.
# (Plain literal: the original f-string had no placeholders.)
file_names = '*-camera-eye-result.h5'
file_path = glob.glob(os.path.join(project_path, 'results', file_names))
blink_threshold = 0.90  # frames whose mean keypoint confidence is below this count as blinks
neighbor_points = 3  # number of neighbouring frames used on each side of a blink interval

完整代码

对应修改完上述参数后即可运行完整代码:

# coding=utf-8
# python3
"""
@FileName:论坛分享.py
@Author: FangXiang
@Company: Guangdong BayONE Scientific CO., Ltd.
@Email: x.fang@bayonesci.com
@CreateFileTime: 2026/1/22 10:42
@Description: 对移动眼动分析完的数据进行眨眼区间插值修复、异常点修复,最后导出为 csv 格式
"""
import glob
import os.path
from itertools import groupby
from operator import itemgetter

import h5py   # 3.11.0
import pandas as pd  # 2.0.3
import numpy as np  # 1.24.4
from numpy.lib.stride_tricks import sliding_window_view


def linear_interpolation(df, column, start, end, neighbor_points=3, decimals=3):
    """Repair a corrupted frame interval (and leftover NaNs) by linear interpolation.

    Values of ``column`` for frames in ``[start, end]`` are replaced by a
    linear interpolation built from up to ``neighbor_points`` frames on each
    side of the interval.  Afterwards any NaN still present in the column is
    filled by interpolating over all non-NaN frames.  ``df`` is modified
    in place.

    Parameters:
        df: DataFrame with a numeric ``'frame'`` column (assumed 1-based and
            sorted ascending — confirm against the caller).
        column: name of the column to repair.
        start: first frame of the corrupted interval.
        end: last frame of the corrupted interval (inclusive).
        neighbor_points: neighbouring frames taken on each side of the interval.
        decimals: interpolated values are rounded to this many decimals.

    Returns:
        None.  The repair happens in place on ``df``.
    """
    # Clamp the neighbourhood to the available frame range.
    lo = max(start - neighbor_points, df['frame'].min())
    hi = min(end + neighbor_points, df['frame'].max())

    # Valid points just before and just after the corrupted interval.
    left = df[(df['frame'] >= lo) & (df['frame'] < start)]
    right = df[(df['frame'] > end) & (df['frame'] <= hi)]

    # Support points for the interpolation (already in ascending frame order).
    margin_indices = pd.concat([left['frame'], right['frame']])
    margin_values = pd.concat([left[column], right[column]])

    # Points inside the interval that get replaced.
    inside = (df['frame'] >= start) & (df['frame'] <= end)

    # Guard: with no valid neighbours at all (interval covers everything),
    # np.interp would raise ValueError on an empty sample-point array.
    if not margin_indices.empty:
        interpolated_values = np.interp(df.loc[inside, 'frame'],
                                        margin_indices, margin_values)
        df.loc[inside, column] = np.round(interpolated_values, decimals)

    # Second pass: fill any remaining NaN from all valid samples in the column.
    if df[column].isna().any():
        valid = df[df[column].notna()]
        if not valid.empty:  # guard against an all-NaN column
            nan_mask = df[column].isna()
            filled = np.interp(df.loc[nan_mask, 'frame'],
                               valid['frame'], valid[column])
            df.loc[nan_mask, column] = np.round(filled, decimals)


def fast_outlier_removal(pupil, window_size=3, k=2.0, max_mul=5):
    """Repair upward outliers with a sliding median/MAD window.

    A sample is an outlier when it exceeds its local median by more than
    ``k`` local MADs; it is replaced by the mean of the non-outlier values in
    its window.  Normal samples are left untouched.  As a pre-pass, values
    larger than ``max_mul`` times the global mean are clamped to the global
    mean.

    Parameters:
        pupil: 1-D array-like of pupil diameters (e.g. a pandas Series).
        window_size: sliding-window length (odd values centre the window).
        k: outlier threshold in units of the local MAD.
        max_mul: cap, in multiples of the global mean, for the pre-pass clamp.

    Returns:
        np.ndarray of float32 — the repaired signal.  The input is not
        modified (the original could mutate a float32 input in place because
        np.asarray does not copy; np.array forces a copy).
    """
    data = np.array(pupil, dtype=np.float32)
    mean_val = np.mean(data)
    print(f'mean_val: {mean_val}')
    # Pre-clamp: physically implausible values are reset to the global mean.
    data[data > mean_val * max_mul] = mean_val

    half = window_size // 2

    # Reflect-pad so every sample owns a full window.
    padded = np.pad(data, (half, half), mode='reflect')
    windows = sliding_window_view(padded, window_size)  # shape (N, window_size)

    # Per-window median and MAD (epsilon avoids a zero MAD on flat windows).
    median = np.median(windows, axis=1)  # (N,)
    mad = np.median(np.abs(windows - median[:, None]), axis=1) + 1e-6  # (N,)

    # Which values inside each window are outliers (strict >, so the median
    # element itself is never flagged and every window keeps >= 1 value).
    is_outlier = windows > (median[:, None] + k * mad[:, None])

    # Mean of the non-outlier values per window.
    windows_sum = np.sum(windows * (~is_outlier), axis=1)
    windows_count = np.sum(~is_outlier, axis=1)
    windows_mean = windows_sum / windows_count

    # Replace only the samples that are themselves outliers.
    point_is_outlier = data > (median + k * mad)
    result = data.copy()
    result[point_is_outlier] = windows_mean[point_is_outlier]

    return result


# Root folder of the analyzed eye-tracking project (batch mode).
project_path = r'E:\Project Test\eye project\demo' # path of the finished project to batch-process
# Glob pattern for result files under <project>/results; replace the wildcard
# with a concrete file name to process a single sample only.
# (Plain literal: the original f-string had no placeholders.)
file_names = '*-camera-eye-result.h5'
file_path = glob.glob(os.path.join(project_path, 'results', file_names))
blink_threshold = 0.90  # frames whose mean keypoint confidence is below this count as blinks
neighbor_points = 3  # number of neighbouring frames used on each side of a blink interval

# --- batch loop: blink-interval interpolation + outlier repair per result file ---
lens = len(file_path)
count = 0
for result_h5 in file_path:
    count += 1
    print(f'开始处理:{count}/{lens}', result_h5)

    # Context manager guarantees the HDF5 handle is released on every path,
    # including the early `continue` below (the original leaked it there).
    with h5py.File(result_h5, 'r') as h5_file:
        if 'key_point_coordinate' in h5_file:
            # newer file layout: keypoints + processed pupil data
            dataset_key = h5_file['key_point_coordinate']['coordinate_data']
            dataset_processed = h5_file['pupil_parameter_processed']['pupil_data'][:, 0]
        elif 'PointPosition' in h5_file:
            # legacy file layout
            dataset_key = h5_file['PointPosition']
            dataset_processed = h5_file['PupilData'][:, 2]
        else:
            print('样本未做瞳孔追踪,无坐标数据')
            continue

        # Mean likelihood over the 9 keypoints (every 3rd column, offset 2).
        # Must run inside the `with`: dataset_key is a lazy h5py dataset.
        likelihood = np.mean([dataset_key[:, 3 * i + 2] for i in range(9)], axis=0)

    # 0-based frame indices whose mean confidence falls below the threshold.
    indexes = np.where(likelihood < blink_threshold)[0]

    # Group consecutive indices into (start, end) blink intervals.
    blinks = []
    for k, g in groupby(enumerate(indexes), lambda x: x[1] - x[0]):
        group = list(map(itemgetter(1), g))
        blinks.append((group[0], group[-1]))
    if blinks:
        # Merge intervals separated by at most gap_threshold frames.
        gap_threshold = 2
        blinks = sorted(blinks, key=lambda x: x[0])
        merged = [blinks[0]]
        for curr in blinks[1:]:
            last_start, last_end = merged[-1]
            curr_start, curr_end = curr
            if curr_start - last_end - 1 <= gap_threshold:
                merged[-1] = (last_start, max(last_end, curr_end))
            else:
                merged.append(curr)
        blinks = merged
    print('blinks:', blinks)

    if not blinks:
        print('没有眨眼事件,跳过此样本')
        continue

    # Already materialized as a NumPy array by the [:, n] slice above.
    dataset_processed_numpy = np.asarray(dataset_processed)

    data = pd.DataFrame({
        'frame': range(1, len(dataset_processed_numpy) + 1),
        'pupil_diameter': dataset_processed_numpy,
        'pupil_diameter_interp': dataset_processed_numpy,
    })

    for blink in blinks:
        # Repair window in frame numbers, padded around the blink indices
        # (the -1/+1 also bridges the 0-based index vs 1-based frame offset).
        interference_start, interference_end = blink[0] - 1, blink[-1] + 1
        linear_interpolation(data, 'pupil_diameter_interp', interference_start,
                             interference_end, neighbor_points)
    # Final pass: smooth remaining single-point spikes.
    data['pupil_diameter_interp'] = fast_outlier_removal(data['pupil_diameter_interp'])
    csv_path = result_h5.replace('-result.h5', '-interp.csv')
    data.to_csv(csv_path, index=False, float_format='%.6f')
    print(f'write csv success: {csv_path}!')