def test_jester_rgbdiff():
    """Smoke-test the Jester RGBDiff training dataset.

    Merges the Jester RGBDiff config into the module-level ``cfg``, builds
    the training transform and dataset from it, reads the sample at index
    20, prints its shape and target, and checks that the sample tensor has
    shape ``(3, 15, 224, 224)``.
    """
    cfg.merge_from_file('configs/tsn_r50_jester_rgbdiff_224x3_seg.yaml')

    train_transform = build_transform(cfg, is_train=True)
    train_dataset = build_dataset(cfg, transform=train_transform, is_train=True)

    # Index 20 is an arbitrary probe sample.
    sample, label = train_dataset[20]
    print(sample.shape)
    print(label)

    expected_shape = (3, 15, 224, 224)
    assert sample.shape == expected_shape
def test_hmdb51_rgb():
    """Smoke-test the HMDB51 RGB training dataset with 8 clips.

    Merges the HMDB51 RGB config into the module-level ``cfg``, overrides
    ``cfg.DATASETS.NUM_CLIPS`` to 8, builds the training transform and
    dataset, reads the sample at index 20, prints its shape and target, and
    checks that the sample tensor has shape ``(3, 8, 224, 224)``.
    """
    cfg.merge_from_file('configs/tsn_r50_hmdb51_rgb_224x3_seg.yaml')
    # Presumably the second axis of the sample follows NUM_CLIPS; the
    # assertion below expects 8 there -- confirm against the dataset code.
    cfg.DATASETS.NUM_CLIPS = 8

    train_transform = build_transform(cfg, is_train=True)
    train_dataset = build_dataset(cfg, transform=train_transform, is_train=True)

    # Index 20 is an arbitrary probe sample.
    sample, label = train_dataset[20]
    print(sample.shape)
    print(label)

    expected_shape = (3, 8, 224, 224)
    assert sample.shape == expected_shape
def main():
    """Print the first few batches produced by ``ShortCycleBatchSampler``.

    Builds the training transform and dataset from the module-level
    ``cfg``, wraps the dataset in a ``SequentialSampler`` and then a
    ``ShortCycleBatchSampler``, and prints the configured batch size, the
    indices and length of the first five batches, and finally the length
    of the batch sampler.
    """
    training = True
    train_transform = build_transform(cfg, is_train=training)
    train_dataset = build_dataset(cfg, transform=train_transform, is_train=training)

    index_sampler = SequentialSampler(train_dataset)

    # The multigrid default size is set to the training crop size before
    # the batch sampler is constructed from cfg.
    cfg.SAMPLER.MULTIGRID.DEFAULT_S = cfg.TRANSFORM.TRAIN.TRAIN_CROP_SIZE
    batch_size = cfg.DATALOADER.TRAIN_BATCH_SIZE
    # Third positional argument is False (presumably drop_last -- confirm
    # against the ShortCycleBatchSampler signature).
    batch_sampler = ShortCycleBatchSampler(index_sampler, batch_size, False, cfg)

    print('batch_size:', cfg.DATALOADER.TRAIN_BATCH_SIZE)

    for batch_no, batch_indices in enumerate(batch_sampler):
        print(batch_indices)
        print(len(batch_indices))
        # Stop once batches 0..4 have been shown.
        if batch_no >= 4:
            break

    print(len(batch_sampler))