快来试试!教你从零搭建动物姿态估计模型,鼠鼠看了都说简单!

零基础搭建动物姿态估计模型(附相关Python代码)

适合人群:给各位正在使用BehaviorAtlas相关产品,想自己尝试动手编写代码的同学。
所需工具:Roboflow标注平台 + YOLOv8姿态估计模型 + Python行为分析


一、环境准备

首先让我们准备所需的Python环境

# 创建虚拟环境
conda create -n myenv 
conda activate myenv
#安装所需Python包
pip install ultralytics roboflow numpy pandas matplotlib -i https://pypi.tuna.tsinghua.edu.cn/simple

二、数据准备流程

1. 获取Kaggle数据集

推荐使用包含清晰动物全身的公开数据集,可以是各位实验中拍摄的鼠鼠的视频、图片。这里用kaggle上的猫品种分类数据集

2. Roboflow标注步骤

  1. 访问 roboflow.com 新建一个项目,命名项目,选择”Keypoint Detection“

2. 定义要标注的关键点,这里以鼻子、左耳、右耳为例(如果是标注鼠鼠,建议标注12-17个关键点,如:鼻尖、四肢关节、尾根等)。点击添加关键点,以骨架形式排列并命名好关键点。

  1. 上传图片数据集,可以选择单独图片,也可以选择文件夹。将我们已下载好的图片上传,点击保存和继续。随后选择手动标注,下一步可以选择将标注任务量自定义分配给当前账号或者加入该项目的其他人。这里我就全分配给自己。



  1. 现在可以开始标注啦,选择右边工具栏的”Skeleton Tool“,即可在图片上拖拽出我们在第二步中预设的关键点骨架,然后手动拖拽关键点,确保其对应动物的正确部位,最后确认右上角的label名称。




  大家在遇到遮挡情况时,不要着急,我们只需预标出它的位置,然后右键选择”Occluded“即可。

  1. 导出为YOLO格式(包含图片路径的data.yaml文件)。大家在标注完后,选择将图片加入到dataset,再点击”New version“,Preprocessing和Augmentation我们选择默认即可。最后点击”create“,选择下载YOLOv8格式到本机。





三、模型训练代码

能坚持到这一步的同学已经非常棒了,接下来我们只需在创建好的Python环境中,运行以下代码,即可训练我们的姿态估计模型啦。

from ultralytics import YOLO

# 加载预训练模型
model = YOLO('yolov8n-pose.pt')  # 基础版轻量模型

# 开始训练(需替换实际路径)
results = model.train(
    data='your_dataset_folder/data.yaml',
    epochs=100,
    imgsz=640,
    batch=8,
    name='cat_pose_v1'
)

参数说明

  • imgsz:输入图片尺寸(根据显存调整)
  • batch:批大小(RTX3060建议8-16)
  • device:可指定使用GPU编号

四、模型推理代码

# 加载训练好的模型
best_model = YOLO('runs/pose/animal_pose_v1/weights/best.pt')

# 单张图片推理
results = best_model.predict(
    source='test_image.jpg',
    save=True,
    conf=0.5  # 置信度阈值
)

# 可视化结果
results[0].plot()

五、行为分析示例

1. 关键点数据提取

# 提取首个检测目标的坐标
keypoints = results[0].keypoints.xy.cpu().numpy()[0]

# 示例:获取头部和尾部坐标
head = keypoints[0]  # 假设索引0为头部
tail_base = keypoints[-1]  # 假设最后一个为尾根

2. 基础行为分析示例

import matplotlib.pyplot as plt

# 计算移动距离
def calculate_distance(point1, point2):
    return ((point1[0]-point2[0])**2 + (point1[1]-point2[1])**2)**0.5

# 轨迹可视化(需连续帧数据)
trajectory_x = [pose[0] for pose in keypoints_sequence]  # 关键点x坐标序列
trajectory_y = [pose[1] for pose in keypoints_sequence]  # y坐标序列

plt.figure(figsize=(10,6))
plt.scatter(trajectory_x, trajectory_y, c=range(len(trajectory_x)), cmap='viridis')
plt.colorbar(label='Frame Index')
plt.title('Animal Movement Trajectory')
plt.show()

常见问题解答

Q:标注多少张图片合适?
A:建议至少500张含不同姿势的图片,复杂场景需1000+张

Q:训练出现过拟合怎么办?
A:尝试以下方法:

  1. 增加数据增强(旋转、模糊、明暗变化)
  2. 减小模型尺寸(改用yolov8n-pose)
  3. 添加L2正则化参数

Q:如何分析社交行为?
A:可计算多个个体间的:

  • 相对距离
  • 面向角度
  • 接触时间比例

延伸学习

  1. 使用deeplabcut进行更专业的行为分析
  2. 结合OpenCV计算运动速度
  3. 用Pandas分析时间序列特征
  4. 更多阅读:
      a.Ultralytics Pose Estimation
      b.How to Train a Custom Ultralytics YOLOv8 Pose Estimation Model
      c.How to train Deeplab on Custom Dataset

(注:实际使用请替换路径和关键点索引为您的真实数据配置)

5 个赞