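Training SlowFast on Your Own Dataset

1. Preparing the dataset video

This run is an experiment meant to learn the framework's training steps, so I picked a single video, longer than 30 seconds, of someone making a phone call.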

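2. Extracting frames from the video

Purpose:

(1) Extract 1 frame per second for annotation; the AVA dataset itself is annotated at 1 frame per second.

(2) Extract 30 frames per second for training. SlowFast builds each clip from densely sampled frames around a keyframe, and the Fast pathway sees many more of those frames than the Slow pathway. With the config in section 7 (clip_len=32, frame_interval=2, resample_rate=8), the Fast pathway receives 32 frames and the Slow pathway 4.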
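To make the sampling arithmetic concrete, here is a small sketch using the values from the config in section 7. It assumes the two pathways subsample the clip along time by resample_rate and by resample_rate // speed_ratio respectively, which is my understanding of mmaction2's ResNet3dSlowFast; the function name is my own.

```python
def clip_sampling(clip_len, frame_interval, resample_rate, speed_ratio, fps=30):
    """Return (raw-frame span, seconds covered, slow frames, fast frames) for one clip."""
    # The sampler picks clip_len frames, one every frame_interval raw frames.
    span = clip_len * frame_interval
    # Assumption: the Slow pathway keeps every resample_rate-th sampled frame,
    # the Fast pathway keeps every (resample_rate // speed_ratio)-th one.
    slow_frames = clip_len // resample_rate
    fast_frames = clip_len // (resample_rate // speed_ratio)
    return span, span / fps, slow_frames, fast_frames


span, secs, slow, fast = clip_sampling(32, 2, 8, 8)
print(span, round(secs, 2), slow, fast)
```

With these settings one clip covers 64 raw frames, a little over two seconds of video at 30 frames per second, which is why the training frames have to be much denser than the 1-per-second annotation frames.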

The code to run is as follows:

video2img.py

import os
import shutil
from tqdm import tqdm

start = 0
seconds = 30

video_path = './ava/videos'
labelframes_path = './ava/labelframes'
rawframes_path = './ava/rawframes'
cut_videos_sh_path = './cut_videos.sh'

# Start from a clean state: remove the output folders and everything in them.
if os.path.exists(labelframes_path):
    shutil.rmtree(labelframes_path)
if os.path.exists(rawframes_path):
    shutil.rmtree(rawframes_path)

fps = 30
raw_frames = seconds * fps

# Rewrite the tail of cut_videos.sh (from the 4-space-indented ffmpeg line
# to the end) so that it uses the start and seconds values set above.
with open(cut_videos_sh_path, 'r') as f:
    sh = f.read()
sh = sh.replace(sh[sh.find('    ffmpeg'):],
                f'    ffmpeg -ss {start} -t {seconds} -i "${{video}}" -r 30 -strict experimental "${{out_name}}"\n  fi\ndone\n')
with open(cut_videos_sh_path, 'w') as f:
    f.write(sh)

# 902 to 1798
os.system('bash cut_videos.sh')  # trim the videos
os.system('bash extract_rgb_frames_ffmpeg.sh')  # extract 30 frames per second

os.makedirs(labelframes_path, exist_ok=True)
video_ids = [os.path.splitext(name)[0] for name in os.listdir(video_path)]
for video_id in tqdm(video_ids):
    # Copy one frame per second, skipping the first and last 2 seconds.
    for img_id in range(2 * fps + 1, (seconds - 2) * fps, fps):
        shutil.copyfile(
            # The original hard-coded the prefix '08093_' (the name of my video);
            # video_id + '_' gives the same result and works for other videos too.
            os.path.join(rawframes_path, video_id, video_id + '_' + format(img_id, '05d') + '.jpg'),
            os.path.join(labelframes_path, video_id + '_' + format(start + img_id // fps, '05d') + '.jpg'))

# shutil.rmtree(): recursively deletes a folder and all of its subfolders and files
# os.path.join(): joins two or more path components
# shutil.copyfile(file1, file2): copies file1 to file2

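extract_rgb_frames_ffmpeg.sh (frame extraction)

IN_DATA_DIR="./ava/videos_cut"
OUT_DATA_DIR="./ava/rawframes"

if [[ ! -d "${OUT_DATA_DIR}" ]]; then
  echo "${OUT_DATA_DIR} doesn't exist. Creating it.";
  mkdir -p ${OUT_DATA_DIR}
fi

for video in $(ls -A1 -U ${IN_DATA_DIR}/*)
do
  video_name=${video##*/}

  if [[ $video_name = *".webm" ]]; then
    video_name=${video_name::-5}
  else
    video_name=${video_name::-4}
  fi

  out_video_dir=${OUT_DATA_DIR}/${video_name}
  mkdir -p "${out_video_dir}"

  # Frames are named <video_name>_00001.jpg, which is what video2img.py looks for.
  # (The file-name prefix must be the video name, not the full output directory path.)
  # As far as I know, mmaction2's AVADataset looks for img_{:05}.jpg by default,
  # so if you keep this naming you may need to set filename_tmpl in the config.
  out_name="${out_video_dir}/${video_name}_%05d.jpg"

  ffmpeg -i "${video}" -r 30 -q:v 1 "${out_name}"
done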

cut_videos.sh (video trimming)

IN_DATA_DIR="./ava/videos"
OUT_DATA_DIR="./ava/videos_cut"

if [[ ! -d "${OUT_DATA_DIR}" ]]; then
  echo "${OUT_DATA_DIR} doesn't exist. Creating it.";
  mkdir -p ${OUT_DATA_DIR}
fi

for video in $(ls -A1 -U ${IN_DATA_DIR}/*)
do
  out_name="${OUT_DATA_DIR}/${video##*/}"
  if [ ! -f "${out_name}" ]; then
    ffmpeg -ss 0 -t 30 -i "${video}" -r 30 -strict experimental "${out_name}"
  fi
done

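Note that the .sh scripts run on Linux. If you opened or saved them on Windows, convert them to Unix (LF) line endings first, otherwise they will fail when run on Linux.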
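If no conversion tool is at hand, the line endings can also be fixed with a few lines of Python. This is a minimal sketch of my own (the helper name is made up); it replaces Windows CRLF line endings with LF in place.

```python
from pathlib import Path


def to_unix_line_endings(path):
    """Rewrite a file in place with LF line endings; return True if it changed."""
    p = Path(path)
    data = p.read_bytes()
    fixed = data.replace(b"\r\n", b"\n")
    if fixed != data:
        p.write_bytes(fixed)
        return True
    return False


if __name__ == "__main__":
    # The two shell scripts used in this section; skipped if not present.
    for name in ("cut_videos.sh", "extract_rgb_frames_ffmpeg.sh"):
        if Path(name).exists():
            print(name, "converted" if to_unix_line_endings(name) else "already LF")
```

Working on bytes rather than text avoids any guess about the file encoding.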

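Put the three scripts above in the same directory, create an ava/videos folder there, and place the prepared video in it. Because the video is longer than 30 seconds, set seconds in video2img.py to 30. (With start = 0, seconds is effectively the end time of the trimmed clip, so every prepared video must be longer than 30 seconds.)


Then run: python video2img.py

When it finishes, three new folders appear under ava. labelframes holds the images to annotate (1 frame per second), rawframes holds every video's frames at 30 frames per second (used for SlowFast training), and videos_cut holds the trimmed videos (seconds 0 to 30 of each original). The videos folder still holds the original files. Neither videos_cut nor videos is needed in the later training steps, so both can simply be deleted.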
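To see exactly which raw frames end up in labelframes, the loop in video2img.py can be reproduced on its own. The sketch below uses the same range expression; the function name is mine.

```python
def label_frame_indices(start=0, seconds=30, fps=30):
    """Return (raw_frame_index, timestamp) pairs that video2img.py copies to labelframes."""
    return [(img_id, start + img_id // fps)
            for img_id in range(2 * fps + 1, (seconds - 2) * fps, fps)]


pairs = label_frame_indices()
print(len(pairs), pairs[0], pairs[-1])  # 26 (61, 2) (811, 27)
```

So a 30-second video yields 26 images to annotate, for timestamps 2 through 27. The first and last two seconds are skipped, which leaves raw frames on both sides of every annotated keyframe for the training clip to draw from.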

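3. Notes on image annotation

There are two ways to annotate the images: automatically or manually.

Automatic annotation: use Faster R-CNN to draw the boxes around the people in each image, then label each person's action yourself. When there are many images to annotate this is clearly the better option, because drawing boxes by hand is tiring.

Manual annotation: draw the boxes by hand, then label each person's action. This suits small sets of images.

The dataset for this run is small, so I annotated it manually.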
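For the automatic route, the detector's boxes eventually have to be stored in the proposal format described in section 5: a dictionary keyed by "videoID,timestamp" whose values are rows of [x1, y1, x2, y2, score] normalized to the range 0 to 1. The sketch below only shows that conversion; the detector itself is left out, and the function name, the example boxes, and the 0.9 threshold are illustrative assumptions of mine.

```python
def to_proposal_entry(video_id, timestamp, boxes_px, img_w, img_h, score_thr=0.9):
    """Turn pixel boxes (x1, y1, x2, y2, score) into one proposal-dict entry."""
    key = f"{video_id},{timestamp:04d}"
    rows = [[x1 / img_w, y1 / img_h, x2 / img_w, y2 / img_h, score]
            for x1, y1, x2, y2, score in boxes_px
            if score >= score_thr]  # drop low-confidence detections
    return key, rows


# Hypothetical detections on a 1280x720 frame: one confident box, one weak one.
key, rows = to_proposal_entry(
    "08093", 2, [(128, 72, 640, 648, 0.98), (10, 10, 50, 50, 0.3)], 1280, 720)
print(key, rows)
```

The real proposal files store each value as a NumPy array of shape N x 5 and are written with pickle; plain lists are used here only to keep the sketch dependency-free.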

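4. Annotating the images

SlowFast needs the dataset in AVA format. First annotate the actions in the images with the VIA tool, then use a script to convert the exported CSV file into the AVA format SlowFast expects. I used VIA version via-3.0.11.

VIA annotation tool download page

After downloading, double-click via_image_annotator.html to open it.

 (1) Click the plus icon and import all images from the labelframes folder.

 (2) Click the icon shown in the figure below to create an attribute.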

 

 

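 (3) For anchor, choose the second item; for input type, choose checkbox. In options, define the person's four actions: stand,sit,talk to,listen, separated by ASCII (half-width) commas. Then tick all four actions in preview.

 (4) Start annotating: draw a box around each person in the image, click the box, and tick the actions you think that person is performing, as shown in the figure below:

(5) When everything is annotated, click the icon shown in the figure below:

Keep the default options and click "Export" to export the CSV file. Do not open this CSV file in Excel to edit it!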

 

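 You now have a CSV file.

 5. Converting the VIA annotations to the SlowFast format

SlowFast requires the dataset in AVA format, along with the proposal pkl files. The Python script below generates all of the required files in one go.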

via2ava.py

r"""
Theme:ava format data transformer
author:Hongbo Jiang
time:2022/3/14/1:51:51
description:
    This is a data format converter. Following mmaction2's AVA data format rules,
    it converts annotated video-understanding CSV files exported from the website
    (.html) into the data format that mmaction2 expects.
    Conversion rules:

    # AVA Annotation Explained

    In this section, we explain the annotation format of AVA in details:

    ```
    mmaction2
    ├── data
    │   ├── ava
    │   │   ├── annotations
    │   │   |   ├── ava_dense_proposals_train.FAIR.recall_93.9.pkl
    │   │   |   ├── ava_dense_proposals_val.FAIR.recall_93.9.pkl
    │   │   |   ├── ava_dense_proposals_test.FAIR.recall_93.9.pkl
    │   │   |   ├── ava_train_v2.1.csv
    │   │   |   ├── ava_val_v2.1.csv
    │   │   |   ├── ava_train_excluded_timestamps_v2.1.csv
    │   │   |   ├── ava_val_excluded_timestamps_v2.1.csv
    │   │   |   ├── ava_action_list_v2.1.pbtxt
    ```

    ## The proposals generated by human detectors

    In the annotation folder, `ava_dense_proposals_[train/val/test].FAIR.recall_93.9.pkl`
    are human proposals generated by a human detector. They are used in training,
    validation and testing respectively. Take `ava_dense_proposals_train.FAIR.recall_93.9.pkl`
    as an example. It is a dictionary of size 203626. The key consists of the `videoID`
    and the `timestamp`. For example, the key `-5KQ66BBWC4,0902` means the values are the
    detection results for the frame at the $$902_{nd}$$ second in the video `-5KQ66BBWC4`.
    The values in the dictionary are numpy arrays with shape $$N \times 5$$, $$N$$ is the
    number of detected human bounding boxes in the corresponding frame. The format of
    bounding box is $$[x_1, y_1, x_2, y_2, score], 0 \le x_1, y_1, x_2, y_2, score \le 1$$.
    $$(x_1, y_1)$$ indicates the top-left corner of the bounding box, $$(x_2, y_2)$$
    indicates the bottom-right corner of the bounding box; $$(0, 0)$$ indicates the
    top-left corner of the image, while $$(1, 1)$$ indicates the bottom-right corner
    of the image.

    ## The ground-truth labels for spatio-temporal action detection

    In the annotation folder, `ava_[train/val]_v[2.1/2.2].csv` are ground-truth labels
    for spatio-temporal action detection, which are used during training & validation.
    Take `ava_train_v2.1.csv` as an example, it is a csv file with 837318 lines, each
    line is the annotation for a human instance in one frame. For example, the first
    line in `ava_train_v2.1.csv` is `'-5KQ66BBWC4,0902,0.077,0.151,0.283,0.811,80,1'`:
    the first two items `-5KQ66BBWC4` and `0902` indicate that it corresponds to the
    $$902_{nd}$$ second in the video `-5KQ66BBWC4`. The next four items
    ($$[0.077(x_1), 0.151(y_1), 0.283(x_2), 0.811(y_2)]$$) indicates the location of
    the bounding box, the bbox format is the same as human proposals. The next item
    `80` is the action label. The last item `1` is the ID of this bounding box.

    ## Excluded timestamps

    `ava_[train/val]_excludes_timestamps_v[2.1/2.2].csv` contains excluded timestamps
    which are not used during training or validation. The format is `video_id, second_idx` .

    ## Label map

    `ava_action_list_v[2.1/2.2]_for_activitynet_[2018/2019].pbtxt` contains the label
    map of the AVA dataset, which maps the action name to the label index.
"""

import csv
import os
import pickle

import cv2
import numpy as np


def transformer(origin_csv_path, frame_image_dir,
                train_output_pkl_path, train_output_csv_path,
                valid_output_pkl_path, valid_output_csv_path,
                exclude_train_output_csv_path, exclude_valid_output_csv_path,
                out_action_list, out_labelmap_path, dataset_percent=0.9):
    """
    Inputs:
        origin_csv_path: path of the CSV file exported from the annotation tool.
        frame_image_dir: folder of images named "videoname_second.jpg", one per second.
        *_output_pkl_path: output pkl file paths
        *_output_csv_path: output csv file paths
        out_labelmap_path: output labelmap.txt path
        dataset_percent: fraction of the annotations used for training;
            the rest goes to validation
    Output: none
    """
    get_label_map(origin_csv_path, out_action_list, out_labelmap_path)

    information_array = [[], [], []]
    # Read the annotation rows of the input CSV (the first 10 rows are header rows).
    # Note: eval() is only acceptable here because the file comes from your own export.
    with open(origin_csv_path, 'r') as csvfile:
        count = 0
        content = csv.reader(csvfile)
        for line in content:
            if count >= 10:
                frame_image_name = eval(line[1])[0]  # str
                location_info = eval(line[4])[1:]  # list: x, y, w, h in pixels
                action_list = list(eval(line[5]).values())[0].split(',')
                action_list = [int(x) for x in action_list]  # list
                information_array[0].append(frame_image_name)
                information_array[1].append(location_info)
                information_array[2].append(action_list)
            count += 1
    # Combine frame name, box location and action ids into one array, one row per box.
    information_array = np.array(information_array, dtype=object).transpose()

    num_train = int(dataset_percent * len(information_array))
    train_info_array = information_array[:num_train]
    valid_info_array = information_array[num_train:]
    get_pkl_csv(train_info_array, train_output_pkl_path, train_output_csv_path,
                exclude_train_output_csv_path, frame_image_dir)
    get_pkl_csv(valid_info_array, valid_output_pkl_path, valid_output_csv_path,
                exclude_valid_output_csv_path, frame_image_dir)


def get_label_map(origin_csv_path, out_action_list, out_labelmap_path):
    classes_list = 0
    classes_content = ""
    labelmap_strings = ""
    # Row 9 of the CSV holds the attribute definition with the action options.
    with open(origin_csv_path, 'r') as csvfile:
        count = 0
        content = csv.reader(csvfile)
        for line in content:
            if count == 8:
                classes_list = line
                break
            count += 1
    # Cut out the part of the row that holds the options dictionary.
    st = 0
    ed = 0
    for i in range(len(classes_list)):
        if classes_list[i].startswith('options'):
            st = i
        if classes_list[i].startswith('default_option_id'):
            ed = i
    for i in range(st, ed):
        if i == st:
            classes_content = classes_content + classes_list[i][len('options:'):] + ','
        else:
            classes_content = classes_content + classes_list[i] + ','
    classes_dict = eval(classes_content)[0]
    # Write the action list file (pbtxt); label ids start at 1.
    with open(out_action_list, 'w') as f:
        for v, k in classes_dict.items():
            labelmap_strings = labelmap_strings + "label {{\n  name: \"{}\"\n  label_id: {}\n  label_type: PERSON_MOVEMENT\n}}\n".format(k, int(v)+1)
        f.write(labelmap_strings)
    labelmap_strings = ""
    # Write the label map file, one "id: name" entry per line.
    with open(out_labelmap_path, 'w') as f:
        for v, k in classes_dict.items():
            labelmap_strings = labelmap_strings + "{}: {}\n".format(int(v)+1, k)
        f.write(labelmap_strings)


def get_pkl_csv(information_array, output_pkl_path, output_csv_path,
                exclude_output_csv_path, frame_image_dir):
    pkl_data = dict()  # pkl key -> list of boxes (plain lists)
    csv_data = []  # rows of the output csv
    read_data = {}  # pkl key -> numpy array, this is what gets pickled

    # Initialise one empty list per "video,timestamp" key before filling it.
    for i in range(len(information_array)):
        img_name = information_array[i][0]
        # My file names are "videoname_frame.jpg"; adapt this if yours differ.
        video_name, frame_name = '_'.join(img_name.split('_')[:-1]), format(int(img_name.split('_')[-1][:-4]), '04d')
        pkl_key = video_name + ',' + frame_name
        pkl_data[pkl_key] = []

    # Go through every annotated box and fill the pkl and csv data.
    for i in range(len(information_array)):
        img_name = information_array[i][0]
        # My file names are "videoname_frame.jpg"; adapt this if yours differ.
        video_name, frame_name = '_'.join(img_name.split('_')[:-1]), str(int(img_name.split('_')[-1][:-4]))
        imgpath = frame_image_dir + '/' + img_name
        location_list = information_array[i][1]
        action_info = information_array[i][2]
        image_array = cv2.imread(imgpath)
        if image_array is None:
            raise FileNotFoundError('Cannot read image: ' + imgpath)
        h, w = image_array.shape[:2]
        # Normalise x, y, w, h to [0, 1], then turn w, h into x2, y2.
        location_list[0] /= w
        location_list[1] /= h
        location_list[2] /= w
        location_list[3] /= h
        location_list[2] = location_list[2]+location_list[0]
        location_list[3] = location_list[3]+location_list[1]
        # One csv row per action of this box; the entity id is fixed to 1.
        for kind_idx in action_info:
            csv_info = [video_name, frame_name, *location_list, kind_idx+1, 1]
            csv_data.append(csv_info)
        # The proposal gets a confidence score of 1.
        location_list = location_list + [1]
        pkl_key = video_name + ',' + format(int(frame_name), '04d')
        pkl_value = location_list
        pkl_data[pkl_key].append(pkl_value)

    for k, v in pkl_data.items():
        read_data[k] = np.array(v)
    with open(output_pkl_path, 'wb') as f:  # write the pkl file
        pickle.dump(read_data, f)
    with open(output_csv_path, 'w', newline='') as f:  # newline='' avoids blank lines between rows
        f_csv = csv.writer(f)
        f_csv.writerows(csv_data)
    with open(exclude_output_csv_path, 'w', newline='') as f:  # empty excluded-timestamps file
        f_csv = csv.writer(f)
        f_csv.writerows([])


def showpkl(pkl_path):
    with open(pkl_path, 'rb') as f:
        content = pickle.load(f)
    return content


def showcsv(csv_path):
    output = []
    with open(csv_path, 'r') as f:
        content = csv.reader(f)
        for line in content:
            output.append(line)
    return output


def showlabelmap(labelmap_path):
    classes_dict = dict()
    with open(labelmap_path, 'r') as f:
        content = (f.read().split('\n'))[:-1]
        for item in content:
            mid_idx = -1
            for i in range(len(item)):
                if item[i] == ":":
                    mid_idx = i
            classes_dict[item[:mid_idx]] = item[mid_idx + 1:]
    return classes_dict


os.makedirs('./ava/annotations', exist_ok=True)
transformer("./Unnamed-VIA Project13Jul2022_16h01m30s_export.csv", './ava/labelframes',
            './ava/annotations/ava_dense_proposals_train.FAIR.recall_93.9.pkl', './ava/annotations/ava_train_v2.1.csv',
            './ava/annotations/ava_dense_proposals_val.FAIR.recall_93.9.pkl', './ava/annotations/ava_val_v2.1.csv',
            './ava/annotations/ava_train_excluded_timestamps_v2.1.csv', './ava/annotations/ava_val_excluded_timestamps_v2.1.csv',
            './ava/annotations/ava_action_list_v2.1.pbtxt', './ava/annotations/labelmap.txt', 0.9)
print(showpkl('./ava/annotations/ava_dense_proposals_train.FAIR.recall_93.9.pkl'))
print(showcsv('./ava/annotations/ava_train_v2.1.csv'))
print(showlabelmap('./ava/annotations/labelmap.txt'))

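Put via2ava.py and your CSV file in the same directory as the ava folder, as shown in the figure below:

Most importantly, replace "Unnamed-VIA Project13Jul2022_16h01m30s_export.csv" in the code with the name of your own CSV file, then run python via2ava.py. This generates every file SlowFast needs for training under ava/annotations.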
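The core of the conversion is small: VIA stores each box as x, y, width, height in pixels, while AVA wants the normalized corners x1, y1, x2, y2 and one CSV row per action. The sketch below walks through a single row. The sample row is made up to match the columns via2ava.py reads (column 1: file list, column 4: shape id followed by x, y, w, h, column 5: attribute id mapped to the ticked option ids), so treat its exact layout as an assumption about the export. It uses json.loads instead of eval, which is enough for these fields.

```python
import csv
import io
import json

# Hypothetical row in the layout via2ava.py expects from the VIA export.
SAMPLE_ROW = '"1_AbCdEfGh","[""08093_00002.jpg""]",0,"[]","[2,128,72,512,576]","{""1"":""0,2""}"\n'


def via_row_to_ava(row, img_w, img_h):
    """Convert one VIA annotation row into AVA csv rows (one per ticked action)."""
    file_name = json.loads(row[1])[0]
    x, y, w, h = json.loads(row[4])[1:]  # drop the leading shape id
    option_ids = list(json.loads(row[5]).values())[0].split(",")
    action_ids = [int(a) + 1 for a in option_ids]  # AVA label ids start at 1
    video_id, frame = file_name[:-4].rsplit("_", 1)
    box = [x / img_w, y / img_h, (x + w) / img_w, (y + h) / img_h]
    return [[video_id, int(frame), *box, action_id, 1] for action_id in action_ids]


row = next(csv.reader(io.StringIO(SAMPLE_ROW)))
for line in via_row_to_ava(row, 1280, 720):
    print(line)
```

For the 1280x720 frame assumed here, the box (128, 72, 512, 576) becomes (0.1, 0.1, 0.5, 0.9), and the two ticked options 0 and 2 become AVA labels 1 and 3.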

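 6. Setting up the SlowFast environment

MMAction2 is a video understanding toolbox that bundles many action recognition algorithms, SlowFast among them. Implementing each algorithm yourself means a lot of work on both environment setup and dataset preparation, so MMAction2 wraps them behind a common interface, with one shared environment and a simpler dataset layout.

MMAction2 source repository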

conda create -n open-mmlab python=3.8 
conda activate open-mmlab

conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge
pip3 install mmcv-full -f .8.0/index.html
git clone .git
cd mmaction2
pip3 install -e .

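Once the environment is set up, create a data folder inside the mmaction2 directory and move the ava folder (the one next to via2ava.py) into data.

7. Adjusting the config file

Go to mmaction2/configs/detection/ava, copy slowfast_kinetics_pretrained_r50_4x16x1_20e_ava_rgb.py, and rename the copy slowfast_kinetics_pretrained_demo_r50_4x16x1_20e_ava_rgb.py. The config file looks like this: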

# model setting
model = dict(
    type='FastRCNN',
    backbone=dict(
        type='ResNet3dSlowFast',
        pretrained=None,
        resample_rate=8,
        speed_ratio=8,
        channel_ratio=8,
        slow_pathway=dict(
            type='resnet3d',
            depth=50,
            pretrained=None,
            lateral=True,
            conv1_kernel=(1, 7, 7),
            dilations=(1, 1, 1, 1),
            conv1_stride_t=1,
            pool1_stride_t=1,
            inflate=(0, 0, 1, 1),
            spatial_strides=(1, 2, 2, 1)),
        fast_pathway=dict(
            type='resnet3d',
            depth=50,
            pretrained=None,
            lateral=False,
            base_channels=8,
            conv1_kernel=(5, 7, 7),
            conv1_stride_t=1,
            pool1_stride_t=1,
            spatial_strides=(1, 2, 2, 1))),
    roi_head=dict(
        type='AVARoIHead',
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor3D',
            roi_layer_type='RoIAlign',
            output_size=8,
            with_temporal_pool=True),
        bbox_head=dict(
            type='BBoxHeadAVA',
            in_channels=2304,
            num_classes=8,
            topk=(1, 7),
            multilabel=True,
            dropout_ratio=0.5)),
    train_cfg=dict(
        rcnn=dict(
            assigner=dict(
                type='MaxIoUAssignerAVA',
                pos_iou_thr=0.9,
                neg_iou_thr=0.9,
                min_pos_iou=0.9),
            sampler=dict(
                type='RandomSampler',
                num=32,
                pos_fraction=1,
                neg_pos_ub=-1,
                add_gt_as_proposals=True),
            pos_weight=1.0,
            debug=False)),
    test_cfg=dict(rcnn=dict(action_thr=0.002)))

dataset_type = 'AVADataset'
data_root = '/home/wzhou/way/llwang/mmaction2-master/input/ava/rawframes'
anno_root = '/home/wzhou/way/llwang/mmaction2-master/input/ava/annotations'

ann_file_train = f'{anno_root}/ava_train_v2.1.csv'
ann_file_val = f'{anno_root}/ava_val_v2.1.csv'

exclude_file_train = f'{anno_root}/ava_train_excluded_timestamps_v2.1.csv'
exclude_file_val = f'{anno_root}/ava_val_excluded_timestamps_v2.1.csv'

label_file = f'{anno_root}/ava_action_list_v2.1.pbtxt'

proposal_file_train = (f'{anno_root}/ava_dense_proposals_train.FAIR.'
                       'recall_93.9.pkl')
proposal_file_val = f'{anno_root}/ava_dense_proposals_val.FAIR.recall_93.9.pkl'

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False)

train_pipeline = [
    dict(type='SampleAVAFrames', clip_len=32, frame_interval=2),
    dict(type='RawFrameDecode'),
    dict(type='RandomRescale', scale_range=(256, 320)),
    dict(type='RandomCrop', size=256),
    dict(type='Flip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='FormatShape', input_format='NCTHW', collapse=True),
    # Rename is needed to use mmdet detectors
    dict(type='Rename', mapping=dict(imgs='img')),
    dict(type='ToTensor', keys=['img', 'proposals', 'gt_bboxes', 'gt_labels']),
    dict(
        type='ToDataContainer',
        fields=[
            dict(key=['proposals', 'gt_bboxes', 'gt_labels'], stack=False)
        ]),
    dict(
        type='Collect',
        keys=['img', 'proposals', 'gt_bboxes', 'gt_labels'],
        meta_keys=['scores', 'entity_ids'])
]
# The testing is w/o. any cropping / flipping
val_pipeline = [
    dict(
        type='SampleAVAFrames', clip_len=32, frame_interval=2, test_mode=True),
    dict(type='RawFrameDecode'),
    dict(type='Resize', scale=(-1, 256)),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='FormatShape', input_format='NCTHW', collapse=True),
    # Rename is needed to use mmdet detectors
    dict(type='Rename', mapping=dict(imgs='img')),
    dict(type='ToTensor', keys=['img', 'proposals']),
    dict(type='ToDataContainer', fields=[dict(key='proposals', stack=False)]),
    dict(
        type='Collect',
        keys=['img', 'proposals'],
        meta_keys=['scores', 'img_shape'],
        nested=True)
]

data = dict(
    videos_per_gpu=5,
    workers_per_gpu=2,
    val_dataloader=dict(videos_per_gpu=1),
    test_dataloader=dict(videos_per_gpu=1),
    train=dict(
        type=dataset_type,
        ann_file=ann_file_train,
        exclude_file=exclude_file_train,
        pipeline=train_pipeline,
        label_file=label_file,
        proposal_file=proposal_file_train,
        person_det_score_thr=0.9,
        num_classes=8,
        start_index=1,
        data_prefix=data_root),
    val=dict(
        type=dataset_type,
        ann_file=ann_file_val,
        exclude_file=exclude_file_val,
        pipeline=val_pipeline,
        label_file=label_file,
        proposal_file=proposal_file_val,
        person_det_score_thr=0.9,
        num_classes=8,
        start_index=1,
        data_prefix=data_root))
data['test'] = data['val']

optimizer = dict(type='SGD', lr=0.1125, momentum=0.9, weight_decay=0.00001)
# this lr is used for 8 gpus

optimizer_config = dict(grad_clip=dict(max_norm=40, norm_type=2))
# learning policy

lr_config = dict(
    policy='step',
    step=[10, 15],
    warmup='linear',
    warmup_by_epoch=True,
    warmup_iters=5,
    warmup_ratio=0.1)
total_epochs = 200
checkpoint_config = dict(interval=1)
workflow = [('train', 1)]
evaluation = dict(interval=1, save_best='mAP@0.5IOU')
log_config = dict(
    interval=20, hooks=[
        dict(type='TextLoggerHook'),
    ])
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = ('./work_dirs/ava/'
            'slowfast_kinetics_pretrained_r50_4x16x1_20e_ava_rgb')
load_from = ('/'
             'slowfast_r50_4x16x1_256e_kinetics400_rgb/'
             'slowfast_r50_4x16x1_256e_kinetics400_rgb_20200704-bcde7ed7.pth')
resume_from = None
find_unused_parameters = False

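Notes:

1. Replace every num_classes (in bbox_head, data.train and data.val). I defined 7 actions, so num_classes=8, because __background__ has to be counted as well.

2. In bbox_head, topk=(1, 7): keep the default 1; the 7 is the number of actions.

3. Check that data_root and anno_root point to your training data.

4. If you run out of GPU memory during training, lower videos_per_gpu in data.

5. Add start_index=1 to both data.train and data.val.

6. total_epochs sets the number of training epochs.

7. load_from can point to a pretrained model.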
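The values in notes 1 and 2 follow directly from the list of actions, and the learning rate is worth a second look when the GPU count or batch size changes, since the comment in the config says it was chosen for 8 GPUs. Below is a rough sketch of both calculations. The helper names and placeholder action names are mine. The linear scaling rule is a common heuristic rather than something this config enforces, and the reference batch size of 8 x 9 clips is my assumption about the upstream setup.

```python
def head_settings(action_names):
    """num_classes counts the actions plus one background slot; topk is capped at the action count."""
    n = len(action_names)
    return dict(num_classes=n + 1, topk=(1, n))


def scale_lr(base_lr, base_batch, gpus, videos_per_gpu):
    """Linear scaling heuristic: learning rate proportional to the total batch size."""
    return base_lr * (gpus * videos_per_gpu) / base_batch


# Seven placeholder action names, standing in for the labels defined in VIA.
actions = ["action_%d" % i for i in range(1, 8)]
print(head_settings(actions))

# Assumed reference: lr=0.1125 for 8 GPUs x 9 clips; this run: 4 GPUs x 5 clips.
print(scale_lr(0.1125, 8 * 9, gpus=4, videos_per_gpu=5))
```

Under those assumptions the scaled learning rate comes out at roughly 0.031, noticeably lower than the 0.1125 left in the config, so it may be worth experimenting with if training is unstable.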

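8. Starting training

The training scripts are in the tools directory. With a single GPU, check which arguments train.py takes, set them, and run python tools/train.py.

I trained on 4 GPUs, so I used the dist_train.sh script in the tools directory. From the mmaction2 directory: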

bash tools/dist_train.sh configs/detection/ava/slowfast_kinetics_pretrained_demo_r50_4x16x1_20e_ava_rgb.py 4

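9. Training results

SlowFast action recognition relies on an object detection model to draw the boxes around people first, so to see the training results you also need to install mmdetection for the detection step.

Go to mmaction2/demo and open webcam_demo_spatiotemporal_det.py to see which arguments it takes.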

# Copyright (c) OpenMMLab. All rights reserved.
"""Webcam Spatio-Temporal Action Detection Demo.Some codes are based on 
"""import argparse
import atexit
import copy
import logging
import queue
import threading
import time
from abc import ABCMeta, abstractmethodimport cv2
import mmcv
import numpy as np
import torch
from mmcv import Config, DictAction
from mmcv.runner import load_checkpointfrom mmaction.models import build_detectortry:from mmdet.apis import inference_detector, init_detector
except (ImportError, ModuleNotFoundError):raise ImportError('Failed to import `inference_detector` and ''`init_detector` form `mmdet.apis`. These apis are ''required in this demo! ')logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)def parse_args():parser = argparse.ArgumentParser(description='MMAction2 webcam spatio-temporal detection demo')parser.add_argument('--config',default=('/home/wzhou/way/llwang/mmaction2-master/configs/detection/ava/''slowfast_kinetics_pretrained_demo_r50_4x16x1_20e_ava_rgb.py'),help='spatio temporal detection config file path')parser.add_argument('--checkpoint',default=('/home/wzhou/way/llwang/mmaction2-master/work_dirs/ava/slowfast_kinetics_pretrained_r50_4x16x1_20e_ava_rgb/''latest.pth'),help='spatio temporal detection checkpoint file/url')parser.add_argument('--action-score-thr',type=float,default=0.4,help='the threshold of human action score')parser.add_argument('--det-config',default='/home/wzhou/way/llwang/mmaction2-master/demo/faster_rcnn_r50_fpn_2x_coco.py',help='human detection config file path (from mmdet)')parser.add_argument('--det-checkpoint',default=('/home/wzhou/way/llwang/mmaction2-master/weights/''faster_rcnn_r50_fpn_2x_coco_bbox_mAP-0.384_20200504_210434-a5d8aa15.pth'),help='human detection checkpoint file/url')parser.add_argument('--det-score-thr',type=float,default=0.9,help='the threshold of human detection score')parser.add_argument('--input-video',default='/home/wzhou/way/llwang/mmaction2-master/input/08093.mp4',type=str,help='webcam id or input video file/url')parser.add_argument('--label-map',default='/home/wzhou/way/llwang/mmaction2-master/tools/data/ava/label_map_demo.txt',help='label map file')parser.add_argument('--device', type=str, default='cuda:0', help='CPU/CUDA device option')parser.add_argument('--output-fps',default=15,type=int,help='the fps of demo video output')parser.add_argument('--out-filename',default='/home/wzhou/way/llwang/mmaction2-master/output/08093.mp4',type=str,help='the filename of output video')parser.add_argument('--show',action='store_true',help='Whether to show results with cv2.imshow')parser.add_argument('--display-height',type=int,default=0,help='Image height for human 
detector and draw frames.')parser.add_argument('--display-width',type=int,default=0,help='Image width for human detector and draw frames.')parser.add_argument('--predict-stepsize',default=8,type=int,help='give out a prediction per n frames')parser.add_argument('--clip-vis-length',default=8,type=int,help='Number of draw frames per clip.')parser.add_argument('--cfg-options',nargs='+',action=DictAction,default={},help='override some settings in the used config, the key-value pair ''in xxx=yyy format will be merged into config file. For example, '"'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'")args = parser.parse_args()return argsclass TaskInfo:"""Wapper for a clip.Transmit input around three threads.1) Read Thread: Create task and put task into read queue. Init `frames`,`processed_frames`, `img_shape`, `ratio`, `clip_vis_length`.2) Main Thread: Get input from read queue, predict human bboxes and stdetaction labels, draw predictions and put task into display queue. Init`display_bboxes`, `stdet_bboxes` and `action_preds`, update `frames`.3) Display Thread: Get input from display queue, show/write frames anddelete task."""def __init__(self):self.id = -1# raw frames, used as human detector input, draw predictions input# and output, display inputself.frames = None# stdet paramsself.processed_frames = None  # model inputsself.frames_inds = None  # select frames from processed framesself.img_shape = None  # model inputs, processed frame shape# `action_preds` is `list[list[tuple]]`. The outer brackets indicate# different bboxes and the intter brackets indicate different action# results for the same bbox. 
tuple contains `class_name` and `score`.self.action_preds = None  # stdet results# human bboxes with the format (xmin, ymin, xmax, ymax)self.display_bboxes = None  # bboxes coords for self.framesself.stdet_bboxes = None  # bboxes coords for self.processed_framesself.ratio = None  # processed_frames.shape[1::-1]/frames.shape[1::-1]# for each clip, draw predictions on clip_vis_length framesself.clip_vis_length = -1def add_frames(self, idx, frames, processed_frames):"""Add the clip and corresponding id.Args:idx (int): the current index of the clip.frames (list[ndarray]): list of images in "BGR" format.processed_frames (list[ndarray]): list of resize and normed imagesin "BGR" format."""self.frames = framesself.processed_frames = processed_framesself.id = idxself.img_shape = processed_frames[0].shape[:2]def add_bboxes(self, display_bboxes):"""Add correspondding bounding boxes."""self.display_bboxes = display_bboxesself.stdet_bboxes = display_bboxes.clone()self.stdet_bboxes[:, ::2] = self.stdet_bboxes[:, ::2] * self.ratio[0]self.stdet_bboxes[:, 1::2] = self.stdet_bboxes[:, 1::2] * self.ratio[1]def add_action_preds(self, preds):"""Add the corresponding action predictions."""self.action_preds = predsdef get_model_inputs(self, device):"""Convert preprocessed images to MMAction2 STDet model inputs."""cur_frames = [self.processed_frames[idx] for idx in self.frames_inds]input_array = np.stack(cur_frames).transpose((3, 0, 1, 2))[np.newaxis]input_tensor = torch.from_numpy(input_array).to(device)return dict(return_loss=False,img=[input_tensor],proposals=[[self.stdet_bboxes]],img_metas=[[dict(img_shape=self.img_shape)]])class BaseHumanDetector(metaclass=ABCMeta):"""Base class for Human Dector.Args:device (str): CPU/CUDA device option."""def __init__(self, device):self.device = torch.device(device)@abstractmethoddef _do_detect(self, image):"""Get human bboxes with shape [n, 4].The format of bboxes is (xmin, ymin, xmax, ymax) in pixels."""def predict(self, task):"""Add keyframe 
bboxes to task."""# keyframe idx == (clip_len * frame_interval) // 2keyframe = task.frames[len(task.frames) // 2]# call detectorbboxes = self._do_detect(keyframe)# convert bboxes to torch.Tensor and move to target deviceif isinstance(bboxes, np.ndarray):bboxes = torch.from_numpy(bboxes).to(self.device)elif isinstance(bboxes, torch.Tensor) and bboxes.device != self.device:bboxes = bboxes.to(self.device)# update tasktask.add_bboxes(bboxes)return taskclass MmdetHumanDetector(BaseHumanDetector):"""Wrapper for mmdetection human detector.Args:config (str): Path to mmdetection config.ckpt (str): Path to mmdetection checkpoint.device (str): CPU/CUDA device option.score_thr (float): The threshold of human detection score.person_classid (int): Choose class from detection results.Default: 0. Suitable for COCO pretrained models."""def __init__(self, config, ckpt, device, score_thr, person_classid=0):super().__init__(device)self.model = init_detector(config, ckpt, device)self.person_classid = person_classidself.score_thr = score_thrdef _do_detect(self, image):"""Get bboxes in shape [n, 4] and values in pixels."""result = inference_detector(self.model, image)[self.person_classid]result = result[result[:, 4] >= self.score_thr][:, :4]return resultclass StdetPredictor:"""Wrapper for MMAction2 spatio-temporal action models.Args:config (str): Path to stdet config.ckpt (str): Path to stdet checkpoint.device (str): CPU/CUDA device option.score_thr (float): The threshold of human action score.label_map_path (str): Path to label map file. 
The format for each lineis `{class_id}: {class_name}`."""def __init__(self, config, checkpoint, device, score_thr, label_map_path):self.score_thr = score_thr# load modelconfig.model.backbone.pretrained = Nonemodel = build_detector(config.model, test_cfg=config.get('test_cfg'))load_checkpoint(model, checkpoint, map_location='cpu')model.to(device)model.eval()self.model = modelself.device = device# init label map, aka class_id to class_name dictwith open(label_map_path) as f:lines = f.readlines()lines = [x.strip().split(': ') for x in lines]self.label_map = {int(x[0]): x[1] for x in lines}try:if config['input']['train']['custom_classes'] is not None:self.label_map = {id + 1: self.label_map[cls]for id, cls in enumerate(config['input']['train']['custom_classes'])}except KeyError:passdef predict(self, task):"""Spatio-temporval Action Detection model inference."""# No need to do inference if no one in keyframeif len(task.stdet_bboxes) == 0:return taskwith torch.no_grad():result = self.model(**task.get_model_inputs(self.device))[0]# pack results of human detector and stdetpreds = []for _ in range(task.stdet_bboxes.shape[0]):preds.append([])for class_id in range(len(result)):if class_id + 1 not in self.label_map:continuefor bbox_id in range(task.stdet_bboxes.shape[0]):if result[class_id][bbox_id, 4] > self.score_thr:preds[bbox_id].append((self.label_map[class_id + 1],result[class_id][bbox_id, 4]))# update task# `preds` is `list[list[tuple]]`. The outer brackets indicate# different bboxes and the intter brackets indicate different action# results for the same bbox. 
        # tuple contains `class_name` and `score`.
        task.add_action_preds(preds)

        return task


class ClipHelper:
    """Multithreading utils to manage the lifecycle of task."""

    def __init__(self,
                 config,
                 display_height=0,
                 display_width=0,
                 input_video=0,
                 predict_stepsize=40,
                 output_fps=25,
                 clip_vis_length=8,
                 out_filename=None,
                 show=True,
                 stdet_input_shortside=256):
        # stdet sampling strategy
        val_pipeline = config.data.val.pipeline
        sampler = [x for x in val_pipeline
                   if x['type'] == 'SampleAVAFrames'][0]
        clip_len, frame_interval = sampler['clip_len'], sampler[
            'frame_interval']
        self.window_size = clip_len * frame_interval

        # asserts
        assert (out_filename or show), \
            'out_filename and show cannot both be None'
        assert clip_len % 2 == 0, 'We would like to have an even clip_len'
        assert clip_vis_length <= predict_stepsize
        assert 0 < predict_stepsize <= self.window_size

        # source params
        try:
            self.cap = cv2.VideoCapture(int(input_video))
            self.webcam = True
        except ValueError:
            self.cap = cv2.VideoCapture(input_video)
            self.webcam = False
        assert self.cap.isOpened()

        # stdet input preprocessing params
        h = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        w = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        # np.inf replaces the np.Inf alias, which NumPy 2.0 removed
        self.stdet_input_size = mmcv.rescale_size(
            (w, h), (stdet_input_shortside, np.inf))
        img_norm_cfg = config['img_norm_cfg']
        if 'to_rgb' not in img_norm_cfg and 'to_bgr' in img_norm_cfg:
            to_bgr = img_norm_cfg.pop('to_bgr')
            img_norm_cfg['to_rgb'] = to_bgr
        img_norm_cfg['mean'] = np.array(img_norm_cfg['mean'])
        img_norm_cfg['std'] = np.array(img_norm_cfg['std'])
        self.img_norm_cfg = img_norm_cfg

        # task init params
        self.clip_vis_length = clip_vis_length
        self.predict_stepsize = predict_stepsize
        self.buffer_size = self.window_size - self.predict_stepsize
        frame_start = self.window_size // 2 - (clip_len // 2) * frame_interval
        self.frames_inds = [
            frame_start + frame_interval * i for i in range(clip_len)
        ]
        self.buffer = []
        self.processed_buffer = []

        # output/display params
        if display_height > 0 and display_width > 0:
            self.display_size = (display_width, display_height)
        elif display_height > 0 or display_width > 0:
            self.display_size = mmcv.rescale_size(
                (w, h), (np.inf, max(display_height, display_width)))
        else:
            self.display_size = (w, h)
        self.ratio = tuple(
            n / o for n, o in zip(self.stdet_input_size, self.display_size))
        if output_fps <= 0:
            self.output_fps = int(self.cap.get(cv2.CAP_PROP_FPS))
        else:
            self.output_fps = output_fps
        self.show = show
        self.video_writer = None
        if out_filename is not None:
            self.video_writer = self.get_output_video_writer(out_filename)
        display_start_idx = self.window_size // 2 - self.predict_stepsize // 2
        self.display_inds = [
            display_start_idx + i for i in range(self.predict_stepsize)
        ]

        # display multi-threading params
        self.display_id = -1  # task.id for display queue
        self.display_queue = {}
        self.display_lock = threading.Lock()
        self.output_lock = threading.Lock()

        # read multi-threading params
        self.read_id = -1  # task.id for read queue
        self.read_id_lock = threading.Lock()
        self.read_queue = queue.Queue()
        self.read_lock = threading.Lock()
        self.not_end = True  # cap.read() flag

        # program state
        self.stopped = False

        atexit.register(self.clean)

    def read_fn(self):
        """Main function for read thread.

        Contains three steps:

        1) Read and preprocess (resize + norm) frames from source.
        2) Create task by frames from previous step and buffer.
        3) Put task into read queue.
        """
        was_read = True
        start_time = time.time()
        while was_read and not self.stopped:
            # init task
            task = TaskInfo()
            task.clip_vis_length = self.clip_vis_length
            task.frames_inds = self.frames_inds
            task.ratio = self.ratio

            # read buffer
            frames = []
            processed_frames = []
            if len(self.buffer) != 0:
                frames = self.buffer
            if len(self.processed_buffer) != 0:
                processed_frames = self.processed_buffer

            # read and preprocess frames from source and update task
            with self.read_lock:
                before_read = time.time()
                read_frame_cnt = self.window_size - len(frames)
                while was_read and len(frames) < self.window_size:
                    was_read, frame = self.cap.read()
                    if not self.webcam:
                        # Reading frames too fast may lead to unexpected
                        # performance degradation. If you have enough
                        # resources, this line can be commented out.
                        time.sleep(1 / self.output_fps)
                    if was_read:
                        frames.append(mmcv.imresize(frame, self.display_size))
                        processed_frame = mmcv.imresize(
                            frame, self.stdet_input_size).astype(np.float32)
                        _ = mmcv.imnormalize_(processed_frame,
                                              **self.img_norm_cfg)
                        processed_frames.append(processed_frame)
            task.add_frames(self.read_id + 1, frames, processed_frames)

            # update buffer
            if was_read:
                self.buffer = frames[-self.buffer_size:]
                self.processed_buffer = processed_frames[-self.buffer_size:]

            # update read state
            with self.read_id_lock:
                self.read_id += 1
                self.not_end = was_read

            self.read_queue.put((was_read, copy.deepcopy(task)))
            cur_time = time.time()
            logger.debug(
                f'Read thread: {1000*(cur_time - start_time):.0f} ms, '
                f'{read_frame_cnt / (cur_time - before_read):.0f} fps')
            start_time = cur_time

    def display_fn(self):
        """Main function for display thread.

        Read input from display queue and display predictions.
        """
        start_time = time.time()
        while not self.stopped:
            # get the state of the read thread
            with self.read_id_lock:
                read_id = self.read_id
                not_end = self.not_end

            with self.display_lock:
                # If the video ended and we have displayed all frames.
                if not not_end and self.display_id == read_id:
                    break

                # If the next task is not available, wait.
                if (len(self.display_queue) == 0 or
                        self.display_queue.get(self.display_id + 1) is None):
                    time.sleep(0.02)
                    continue

                # get display input and update state
                self.display_id += 1
                was_read, task = self.display_queue[self.display_id]
                del self.display_queue[self.display_id]
                display_id = self.display_id

            # do display predictions
            with self.output_lock:
                if was_read and task.id == 0:
                    # the first task
                    cur_display_inds = range(self.display_inds[-1] + 1)
                elif not was_read:
                    # the last task
                    cur_display_inds = range(self.display_inds[0],
                                             len(task.frames))
                else:
                    cur_display_inds = self.display_inds

                for frame_id in cur_display_inds:
                    frame = task.frames[frame_id]
                    if self.show:
                        cv2.imshow('Demo', frame)
                        cv2.waitKey(int(1000 / self.output_fps))
                    if self.video_writer:
                        self.video_writer.write(frame)

            cur_time = time.time()
            logger.debug(
                f'Display thread: {1000*(cur_time - start_time):.0f} ms, '
                f'read id {read_id}, display id {display_id}')
            start_time = cur_time

    def __iter__(self):
        return self

    def __next__(self):
        """Get input from read queue.

        This function is part of the main thread.
        """
        if self.read_queue.qsize() == 0:
            time.sleep(0.02)
            return not self.stopped, None

        was_read, task = self.read_queue.get()
        if not was_read:
            # If we reach the end of the video, there aren't enough frames
            # in the task.processed_frames, so no need to model inference
            # and draw predictions. Put task into display queue.
            with self.read_id_lock:
                read_id = self.read_id
            with self.display_lock:
                self.display_queue[read_id] = was_read, copy.deepcopy(task)

            # main thread doesn't need to handle this task again
            task = None
        return was_read, task

    def start(self):
        """Start read thread and display thread."""
        self.read_thread = threading.Thread(
            target=self.read_fn, args=(), name='VidRead-Thread', daemon=True)
        self.read_thread.start()
        self.display_thread = threading.Thread(
            target=self.display_fn,
            args=(),
            name='VidDisplay-Thread',
            daemon=True)
        self.display_thread.start()

        return self

    def clean(self):
        """Close all threads and release all resources."""
        self.stopped = True

        self.read_lock.acquire()
        self.cap.release()
        self.read_lock.release()

        self.output_lock.acquire()
        cv2.destroyAllWindows()
        if self.video_writer:
            self.video_writer.release()
        self.output_lock.release()

    def join(self):
        """Waiting for the finalization of read and display thread."""
        self.read_thread.join()
        self.display_thread.join()

    def display(self, task):
        """Add the visualized task to the display queue.

        Args:
            task (TaskInfo object): task object that contain the necessary
                information for prediction visualization.
        """
        with self.display_lock:
            self.display_queue[task.id] = (True, task)

    def get_output_video_writer(self, path):
        """Return a video writer object.

        Args:
            path (str): path to the output video file.
        """
        return cv2.VideoWriter(
            filename=path,
            fourcc=cv2.VideoWriter_fourcc(*'mp4v'),
            fps=float(self.output_fps),
            frameSize=self.display_size,
            isColor=True)


class BaseVisualizer(metaclass=ABCMeta):
    """Base class for visualization tools."""

    def __init__(self, max_labels_per_bbox):
        self.max_labels_per_bbox = max_labels_per_bbox

    def draw_predictions(self, task):
        """Visualize stdet predictions on raw frames."""
        # read bboxes from task
        bboxes = task.display_bboxes.cpu().numpy()

        # draw predictions and update task
        keyframe_idx = len(task.frames) // 2
        draw_range = [
            keyframe_idx - task.clip_vis_length // 2,
            keyframe_idx + (task.clip_vis_length - 1) // 2
        ]
        assert draw_range[0] >= 0 and draw_range[1] < len(task.frames)
        task.frames = self.draw_clip_range(task.frames, task.action_preds,
                                           bboxes, draw_range)

        return task

    def draw_clip_range(self, frames, preds, bboxes, draw_range):
        """Draw a range of frames with the same bboxes and predictions."""
        # no predictions to be drawn
        if bboxes is None or len(bboxes) == 0:
            return frames

        # draw frames in `draw_range`
        left_frames = frames[:draw_range[0]]
        right_frames = frames[draw_range[1] + 1:]
        draw_frames = frames[draw_range[0]:draw_range[1] + 1]

        # get labels(texts) and draw predictions
        draw_frames = [
            self.draw_one_image(frame, bboxes, preds) for frame in draw_frames
        ]

        return list(left_frames) + draw_frames + list(right_frames)

    @abstractmethod
    def draw_one_image(self, frame, bboxes, preds):
        """Draw bboxes and corresponding texts on one frame."""

    @staticmethod
    def abbrev(name):
        """Get the abbreviation of label name:

        'take (an object) from (a person)' -> 'take ... from ...'
        """
        while name.find('(') != -1:
            st, ed = name.find('('), name.find(')')
            name = name[:st] + '...' + name[ed + 1:]
        return name


class DefaultVisualizer(BaseVisualizer):
    """Tools to visualize predictions.

    Args:
        max_labels_per_bbox (int): Max number of labels to visualize for a
            person box. Default: 5.
        plate (str): The color plate used for visualization. Two recommended
            plates are blue plate `03045e-023e8a-0077b6-0096c7-00b4d8-48cae4`
            and green plate `004b23-006400-007200-008000-38b000-70e000`. These
            plates are generated by /.
            Default: '03045e-023e8a-0077b6-0096c7-00b4d8-48cae4'.
        text_fontface (int): Fontface from OpenCV for texts.
            Default: cv2.FONT_HERSHEY_DUPLEX.
        text_fontscale (float): Fontscale from OpenCV for texts.
            Default: 0.5.
        text_fontcolor (tuple): Font color from OpenCV for texts.
            Default: (255, 255, 255).
        text_thickness (int): Thickness from OpenCV for texts.
            Default: 1.
        text_linetype (int): Line type from OpenCV for texts.
            Default: 1.
    """

    def __init__(
            self,
            max_labels_per_bbox=5,
            plate='03045e-023e8a-0077b6-0096c7-00b4d8-48cae4',
            text_fontface=cv2.FONT_HERSHEY_DUPLEX,
            text_fontscale=0.5,
            text_fontcolor=(255, 255, 255),  # white
            text_thickness=1,
            text_linetype=1):
        super().__init__(max_labels_per_bbox=max_labels_per_bbox)
        self.text_fontface = text_fontface
        self.text_fontscale = text_fontscale
        self.text_fontcolor = text_fontcolor
        self.text_thickness = text_thickness
        self.text_linetype = text_linetype

        def hex2color(h):
            """Convert the 6-digit hex string to tuple of 3 int value (RGB)"""
            return (int(h[:2], 16), int(h[2:4], 16), int(h[4:], 16))

        plate = plate.split('-')
        self.plate = [hex2color(h) for h in plate]

    def draw_one_image(self, frame, bboxes, preds):
        """Draw predictions on one image."""
        for bbox, pred in zip(bboxes, preds):
            # draw bbox
            box = bbox.astype(np.int64)
            st, ed = tuple(box[:2]), tuple(box[2:])
            cv2.rectangle(frame, st, ed, (0, 0, 255), 2)

            # draw texts
            for k, (label, score) in enumerate(pred):
                if k >= self.max_labels_per_bbox:
                    break
                text = f'{self.abbrev(label)}: {score:.4f}'
                location = (0 + st[0], 18 + k * 18 + st[1])
                textsize = cv2.getTextSize(text, self.text_fontface,
                                           self.text_fontscale,
                                           self.text_thickness)[0]
                textwidth = textsize[0]
                diag0 = (location[0] + textwidth, location[1] - 14)
                diag1 = (location[0], location[1] + 2)
                cv2.rectangle(frame, diag0, diag1, self.plate[k + 1], -1)
                cv2.putText(frame, text, location, self.text_fontface,
                            self.text_fontscale, self.text_fontcolor,
                            self.text_thickness, self.text_linetype)

        return frame


def main(args):
    # init human detector
    human_detector = MmdetHumanDetector(args.det_config, args.det_checkpoint,
                                        args.device, args.det_score_thr)

    # init action detector
    config = Config.fromfile(args.config)
    config.merge_from_dict(args.cfg_options)

    try:
        # In our spatiotemporal detection demo, different actions should have
        # the same number of bboxes.
        config['model']['test_cfg']['rcnn']['action_thr'] = .0
    except KeyError:
        pass
    stdet_predictor = StdetPredictor(
        config=config,
        checkpoint=args.checkpoint,
        device=args.device,
        score_thr=args.action_score_thr,
        label_map_path=args.label_map)

    # init clip helper
    clip_helper = ClipHelper(
        config=config,
        display_height=args.display_height,
        display_width=args.display_width,
        input_video=args.input_video,
        predict_stepsize=args.predict_stepsize,
        output_fps=args.output_fps,
        clip_vis_length=args.clip_vis_length,
        out_filename=args.out_filename,
        show=args.show)

    # init visualizer
    vis = DefaultVisualizer()

    # start read and display thread
    clip_helper.start()

    try:
        # Main thread main function contains:
        # 1) get input from read queue
        # 2) get human bboxes and stdet predictions
        # 3) draw stdet predictions and update task
        # 4) put task into display queue
        for able_to_read, task in clip_helper:
            # get input from read queue

            if not able_to_read:
                # read thread is dead and all tasks are processed
                break

            if task is None:
                # when no input in read queue, wait
                time.sleep(0.01)
                continue

            inference_start = time.time()

            # get human bboxes
            human_detector.predict(task)

            # get stdet predictions
            stdet_predictor.predict(task)

            # draw stdet predictions in raw frames
            vis.draw_predictions(task)
            logger.info(f'Stdet Results: {task.action_preds}')

            # add draw frames to display queue
            clip_helper.display(task)

            logger.debug('Main thread inference time '
                         f'{1000*(time.time() - inference_start):.0f} ms')

        # wait for display thread
        clip_helper.join()
    except KeyboardInterrupt:
        pass
    finally:
        # close read & display thread, release all resources
        clip_helper.clean()


if __name__ == '__main__':
    main(parse_args())
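
The index arithmetic in ClipHelper.__init__ above is easier to follow with concrete numbers. The sketch below repeats only that arithmetic in a standalone helper. The function name clip_indices is hypothetical, and clip_len=8 and frame_interval=8 are assumed values chosen for illustration; the real values come from the SampleAVAFrames entry in your own config. predict_stepsize=40 is the default from the script.

```python
def clip_indices(clip_len, frame_interval, predict_stepsize):
    """Repeat the window/buffer/index arithmetic of ClipHelper.__init__.

    Hypothetical helper for illustration only; it is not part of the script.
    """
    window_size = clip_len * frame_interval
    # Same sanity checks as in the script
    assert clip_len % 2 == 0, 'clip_len should be even'
    assert 0 < predict_stepsize <= window_size

    # Frames carried over from one task to the next
    buffer_size = window_size - predict_stepsize

    # Indices of the frames inside the window that are fed to the model
    frame_start = window_size // 2 - (clip_len // 2) * frame_interval
    frames_inds = [frame_start + frame_interval * i for i in range(clip_len)]

    # Indices of the frames inside the window that are shown for a middle task
    display_start_idx = window_size // 2 - predict_stepsize // 2
    display_inds = [display_start_idx + i for i in range(predict_stepsize)]

    return {
        'window_size': window_size,
        'buffer_size': buffer_size,
        'frames_inds': frames_inds,
        'display_inds': display_inds,
    }


# Assumed sampler values (clip_len=8, frame_interval=8), default step size 40
info = clip_indices(clip_len=8, frame_interval=8, predict_stepsize=40)
print(info['window_size'], info['buffer_size'])            # 64 24
print(info['frames_inds'])                                 # [0, 8, 16, 24, 32, 40, 48, 56]
print(info['display_inds'][0], info['display_inds'][-1])   # 12 51
```

With these assumed values each task holds 64 frames, reuses the last 24 frames of the previous task, and every new task reads 40 fresh frames, which is why predict_stepsize may not exceed the window size.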

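--config is the config file that was used to train SlowFast

--checkpoint is the weights file produced by the SlowFast training

--det-config is the mmdetection config file

--det-checkpoint is the mmdetection weights file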

Then run the script and check the recognition results.
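
As a rough sketch of how the four arguments fit together, the snippet below only assembles and prints a command line; it does not launch anything. The four flag names come from the list above. Everything else is an assumption: the helper build_demo_command is hypothetical, the script name and every file path are placeholders, and you should substitute the names used in your own checkout and work directory.

```python
import shlex


def build_demo_command(script, config, checkpoint, det_config,
                       det_checkpoint):
    """Assemble the demo command line (hypothetical helper, for illustration)."""
    return [
        'python', script,
        '--config', config,                  # SlowFast training config
        '--checkpoint', checkpoint,          # weights from SlowFast training
        '--det-config', det_config,          # mmdetection config
        '--det-checkpoint', det_checkpoint,  # mmdetection weights
    ]


# All paths below are placeholders, not real files from this article
cmd = build_demo_command(
    script='path/to/demo_script.py',
    config='path/to/slowfast_config.py',
    checkpoint='path/to/slowfast_weights.pth',
    det_config='path/to/mmdet_config.py',
    det_checkpoint='path/to/mmdet_weights.pth',
)
print(shlex.join(cmd))
```

Copy the printed line into a terminal, or pass cmd to subprocess.run, once the placeholders point at real files.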

Reference link 1, reference link 2, reference link 3.
