コード例 #1
0
ファイル: build_file_list.py プロジェクト: gyq716/mmaction2
def main():
    """Generate annotation file lists for a video dataset.

    Reads command-line options via ``parse_args()``, indexes
    ``args.src_folder`` (extracted rawframes via ``parse_directory`` or raw
    video files via ``glob``), loads the split definitions for
    ``args.dataset`` and writes the resulting lists under
    ``<out_root_path>/<dataset>`` in ``txt`` or ``json`` format.

    When more than one split is parsed, a train and a val file are written
    per split; otherwise a single file is written for ``args.subset``.

    Raises:
        ValueError: If ``args.level`` is not 1 or 2 (``videos`` format),
            ``args.dataset``, ``args.subset`` or ``args.output_format`` is
            unsupported, or the number of parsed splits differs from
            ``args.num_split``.
        NotImplementedError: If ``args.format`` is neither ``rawframes``
            nor ``videos``.
    """
    args = parse_args()

    if args.seed is not None:
        print(f'Set random seed to {args.seed}')
        set_random_seed(args.seed)

    if args.format == 'rawframes':
        frame_info = parse_directory(args.src_folder,
                                     rgb_prefix=args.rgb_prefix,
                                     flow_x_prefix=args.flow_x_prefix,
                                     flow_y_prefix=args.flow_y_prefix,
                                     level=args.level)
    elif args.format == 'videos':
        if args.level == 1:
            # search for one-level directory
            video_list = glob.glob(osp.join(args.src_folder, '*'))
        elif args.level == 2:
            # search for two-level directory
            video_list = glob.glob(osp.join(args.src_folder, '*', '*'))
        else:
            raise ValueError(f'level must be 1 or 2, but got {args.level}')
        frame_info = {}
        for video in video_list:
            video_path = osp.relpath(video, args.src_folder)
            # video_id (relative path without extension) ->
            # (video_relative_path, -1, -1); the -1 values are placeholders
            # for frame counts, presumably interpreted by build_file_list as
            # "raw video" — verify against that helper.
            frame_info[osp.splitext(video_path)[0]] = (video_path, -1, -1)
    else:
        raise NotImplementedError('only rawframes and videos are supported')

    if args.dataset == 'ucf101':
        splits = parse_ucf101_splits(args.level)
    elif args.dataset == 'sthv1':
        splits = parse_sthv1_splits(args.level)
    elif args.dataset == 'sthv2':
        splits = parse_sthv2_splits(args.level)
    elif args.dataset == 'mit':
        splits = parse_mit_splits()
    elif args.dataset == 'mmit':
        splits = parse_mmit_splits()
    elif args.dataset in ['kinetics400', 'kinetics600', 'kinetics700']:
        splits = parse_kinetics_splits(args.level, args.dataset)
    elif args.dataset == 'hmdb51':
        splits = parse_hmdb51_split(args.level)
    elif args.dataset == 'jester':
        splits = parse_jester_splits(args.level)
    elif args.dataset == 'diving48':
        splits = parse_diving48_splits()
    else:
        raise ValueError(
            "Supported datasets are 'ucf101', 'sthv1', 'sthv2', 'mit', "
            "'mmit', 'kinetics400', 'kinetics600', 'kinetics700', "
            f"'hmdb51', 'jester', 'diving48', but got {args.dataset}")

    # Explicit check rather than ``assert`` so it still runs under -O.
    if len(splits) != args.num_split:
        raise ValueError(f'expected {args.num_split} split(s) for '
                         f'{args.dataset}, but got {len(splits)}')

    if args.output_format not in ('txt', 'json'):
        raise ValueError("output_format must be 'txt' or 'json', "
                         f'but got {args.output_format}')

    # osp.join is correct whether or not out_root_path ends with a separator.
    out_path = osp.join(args.out_root_path, args.dataset)

    def write_list(lines, txt_name):
        """Write ``lines`` to ``out_path`` in the requested output format.

        ``txt_name`` is the ``.txt`` file name; for JSON output the lines
        are converted with ``lines2dictlist`` and the extension is swapped
        to ``.json``.
        """
        if args.output_format == 'txt':
            with open(osp.join(out_path, txt_name), 'w') as f:
                f.writelines(lines)
        else:
            data_list = lines2dictlist(lines, args.format)
            json_name = osp.splitext(txt_name)[0] + '.json'
            with open(osp.join(out_path, json_name), 'w') as f:
                json.dump(data_list, f)

    if len(splits) > 1:
        for i, split in enumerate(splits):
            file_lists = build_file_list(split,
                                         frame_info,
                                         shuffle=args.shuffle)
            # file_lists[0][0] / file_lists[0][1] are written as the
            # train / val lists of this split.
            write_list(file_lists[0][0],
                       f'{args.dataset}_train_split_{i+1}_{args.format}.txt')
            write_list(file_lists[0][1],
                       f'{args.dataset}_val_split_{i+1}_{args.format}.txt')
    else:
        # Validate the subset before building the lists so a bad value
        # fails fast.
        subset_index = {'train': 0, 'val': 1, 'test': 2}
        if args.subset not in subset_index:
            raise ValueError(f"subset must be in ['train', 'val', 'test'], "
                             f'but got {args.subset}.')
        lists = build_file_list(splits[0], frame_info, shuffle=args.shuffle)
        # NOTE(review): 'test' (index 2) requires build_file_list to return
        # at least three lists in lists[0] — confirm against that helper.
        write_list(lists[0][subset_index[args.subset]],
                   f'{args.dataset}_{args.subset}_list_{args.format}.txt')
コード例 #2
0
def main():
    """Generate txt annotation file lists for a video dataset.

    Reads command-line options via ``parse_args()``, indexes
    ``args.src_folder`` (extracted rawframes via ``parse_directory`` or raw
    video files via ``glob``), loads the split definitions for
    ``args.dataset`` and writes the resulting lists under
    ``<out_root_path>/<dataset>``.

    When more than one split is parsed, a train and a val file are written
    per split; otherwise a single file is written for ``args.subset``.

    Raises:
        ValueError: If ``args.level`` is not 1 or 2 (``videos`` format),
            ``args.dataset`` or ``args.subset`` is unsupported, or the number
            of parsed splits differs from ``args.num_split``.
        NotImplementedError: If ``args.format`` is neither ``rawframes``
            nor ``videos``.
    """
    args = parse_args()

    # key_func maps a frame-directory path to its video id: the last two
    # (level 2) or the last one path component(s).
    # NOTE(review): paths are split on '/', i.e. POSIX separators assumed.
    if args.level == 2:
        # search for two-level directory
        def key_func(x):
            return '/'.join(x.split('/')[-2:])
    else:
        # Only search for one-level directory
        def key_func(x):
            return x.split('/')[-1]

    if args.format == 'rawframes':
        frame_info = parse_directory(args.src_folder,
                                     key_func=key_func,
                                     rgb_prefix=args.rgb_prefix,
                                     flow_x_prefix=args.flow_x_prefix,
                                     flow_y_prefix=args.flow_y_prefix,
                                     level=args.level)
    elif args.format == 'videos':
        if args.level == 1:
            # search for one-level directory
            video_list = glob.glob(osp.join(args.src_folder, '*'))
        elif args.level == 2:
            # search for two-level directory
            video_list = glob.glob(osp.join(args.src_folder, '*', '*'))
        else:
            raise ValueError(f'level must be 1 or 2, but got {args.level}')
        frame_info = {}
        for video in video_list:
            video_path = osp.relpath(video, args.src_folder)
            # video_id -> (video_relative_path, -1, -1). osp.splitext only
            # strips the extension of the final component, so dots in
            # directory names and extensionless files are handled correctly.
            frame_info[osp.splitext(video_path)[0]] = (video_path, -1, -1)
    else:
        raise NotImplementedError('only rawframes and videos are supported')

    if args.dataset == 'ucf101':
        splits = parse_ucf101_splits(args.level)
    elif args.dataset == 'sthv1':
        splits = parse_sthv1_splits(args.level)
    elif args.dataset == 'sthv2':
        splits = parse_sthv2_splits(args.level)
    elif args.dataset == 'mit':
        splits = parse_mit_splits(args.level)
    elif args.dataset == 'mmit':
        splits = parse_mmit_splits(args.level)
    elif args.dataset == 'kinetics400':
        splits = parse_kinetics_splits(args.level)
    else:
        raise ValueError(
            "Supported datasets are 'ucf101', 'sthv1', 'sthv2', 'mmit', "
            f"'mit', 'kinetics400', but got {args.dataset}")

    # Explicit check rather than ``assert`` so it still runs under -O.
    if len(splits) != args.num_split:
        raise ValueError(f'expected {args.num_split} split(s) for '
                         f'{args.dataset}, but got {len(splits)}')

    # osp.join is correct whether or not out_root_path ends with a separator.
    out_path = osp.join(args.out_root_path, args.dataset)

    def write_list(lines, filename):
        """Write the already-formatted ``lines`` to ``out_path/filename``."""
        with open(osp.join(out_path, filename), 'w') as f:
            f.writelines(lines)

    if len(splits) > 1:
        for i, split in enumerate(splits):
            file_lists = build_file_list(split,
                                         frame_info,
                                         shuffle=args.shuffle)
            # file_lists[0][0] / file_lists[0][1] are written as the
            # train / val lists of this split.
            write_list(file_lists[0][0],
                       f'{args.dataset}_train_split_{i+1}_{args.format}.txt')
            write_list(file_lists[0][1],
                       f'{args.dataset}_val_split_{i+1}_{args.format}.txt')
    else:
        # Validate the subset before building the lists so a bad value
        # fails fast.
        subset_index = {'train': 0, 'val': 1, 'test': 2}
        if args.subset not in subset_index:
            raise ValueError(f"subset must be in ['train', 'val', 'test'], "
                             f'but got {args.subset}.')
        lists = build_file_list(splits[0], frame_info, shuffle=args.shuffle)
        # NOTE(review): 'test' (index 2) requires build_file_list to return
        # at least three lists in lists[0] — confirm against that helper.
        write_list(lists[0][subset_index[args.subset]],
                   f'{args.dataset}_{args.subset}_list_{args.format}.txt')