Example #1
def test_uniform_clip_sampler_insufficient_clips(self, tmpdir):
    video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[10, 25, 25])
    video_clips = VideoClips(video_list, 5, 5)
    sampler = UniformClipSampler(video_clips, 3)
    assert len(sampler) == 3 * 3
    indices = torch.tensor(list(iter(sampler)))
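    # the 10-frame video yields only 2 clips, so sampling 3 uniformly repeats
    # clip index 0; each 25-frame video yields 5 clips to sample from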
    assert_equal(indices, torch.tensor([0, 0, 1, 2, 4, 6, 7, 9, 11]))
Example #2
    def test_distributed_sampler_and_uniform_clip_sampler(self):
        with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
            video_clips = VideoClips(video_list, 5, 5)
            clip_sampler = UniformClipSampler(video_clips, 3)

            distributed_sampler_rank0 = DistributedSampler(
                clip_sampler,
                num_replicas=2,
                rank=0,
                group_size=3,
            )
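            # the 9 uniformly sampled indices are padded to 12 by wrapping
            # around, then split into groups of 3 that alternate between ranks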
            indices = torch.tensor(list(iter(distributed_sampler_rank0)))
            self.assertEqual(len(distributed_sampler_rank0), 6)
            self.assertTrue(indices.equal(torch.tensor([0, 2, 4, 10, 12, 14])))

            distributed_sampler_rank1 = DistributedSampler(
                clip_sampler,
                num_replicas=2,
                rank=1,
                group_size=3,
            )
            indices = torch.tensor(list(iter(distributed_sampler_rank1)))
            self.assertEqual(len(distributed_sampler_rank1), 6)
            self.assertTrue(indices.equal(torch.tensor([5, 7, 9, 0, 2, 4])))
Example #3
    def test_distributed_sampler_and_uniform_clip_sampler(self, tmpdir):
        video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[25, 25, 25])
        video_clips = VideoClips(video_list, 5, 5)
        clip_sampler = UniformClipSampler(video_clips, 3)

        distributed_sampler_rank0 = DistributedSampler(
            clip_sampler,
            num_replicas=2,
            rank=0,
            group_size=3,
        )
        indices = torch.tensor(list(iter(distributed_sampler_rank0)))
        assert len(distributed_sampler_rank0) == 6
        assert_equal(indices, torch.tensor([0, 2, 4, 10, 12, 14]))

        distributed_sampler_rank1 = DistributedSampler(
            clip_sampler,
            num_replicas=2,
            rank=1,
            group_size=3,
        )
        indices = torch.tensor(list(iter(distributed_sampler_rank1)))
        assert len(distributed_sampler_rank1) == 6
        assert_equal(indices, torch.tensor([5, 7, 9, 0, 2, 4]))
Example #4
def test_uniform_clip_sampler_insufficient_clips(self):
    with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list:
        video_clips = VideoClips(video_list, 5, 5)
        sampler = UniformClipSampler(video_clips, 3)
        self.assertEqual(len(sampler), 3 * 3)
        indices = torch.tensor(list(iter(sampler)))
        assert_equal(indices, torch.tensor([0, 0, 1, 2, 4, 6, 7, 9, 11]))
Example #5
def test_uniform_clip_sampler(self, tmpdir):
    video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[25, 25, 25])
    video_clips = VideoClips(video_list, 5, 5)
    sampler = UniformClipSampler(video_clips, 3)
    assert len(sampler) == 3 * 3
    indices = torch.tensor(list(iter(sampler)))
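    # each 25-frame video contributes 5 clips, so floor-dividing a clip index
    # by 5 recovers the index of its source video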
    videos = torch.div(indices, 5, rounding_mode="floor")
    v_idxs, count = torch.unique(videos, return_counts=True)
    assert_equal(v_idxs, torch.tensor([0, 1, 2]))
    assert_equal(count, torch.tensor([3, 3, 3]))
    assert_equal(indices, torch.tensor([0, 2, 4, 5, 7, 9, 10, 12, 14]))
Example #6
def video_uniform_3(dataset, collate_fn):
    clips_per_video = 100
    batch_size = 80
    workers = 40
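    # UniformClipSampler yields clips_per_video evenly spaced clips from every
    # video (repeating clip indices when a video has fewer clips available)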
    sampler = UniformClipSampler(dataset.video_clips, clips_per_video)
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=batch_size,
                                              sampler=sampler,
                                              num_workers=workers,
                                              pin_memory=True,
                                              collate_fn=collate_fn)
    return data_loader
Example #7
def test_uniform_clip_sampler(self):
    with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
        video_clips = VideoClips(video_list, 5, 5)
        sampler = UniformClipSampler(video_clips, 3)
        self.assertEqual(len(sampler), 3 * 3)
        indices = torch.tensor(list(iter(sampler)))
        videos = torch.div(indices, 5, rounding_mode='floor')
        v_idxs, count = torch.unique(videos, return_counts=True)
        assert_equal(v_idxs, torch.tensor([0, 1, 2]))
        assert_equal(count, torch.tensor([3, 3, 3]))
        assert_equal(indices, torch.tensor([0, 2, 4, 5, 7, 9, 10, 12, 14]))
Example #8
def main(args):
    if args.prototype and prototype is None:
        raise ImportError(
            "The prototype module couldn't be found. Please install the latest torchvision nightly."
        )
    if not args.prototype and args.weights:
        raise ValueError(
            "The weights parameter works only in prototype mode. Please pass the --prototype argument."
        )
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)
    print("torch version: ", torch.__version__)
    print("torchvision version: ", torchvision.__version__)

    device = torch.device(args.device)

    torch.backends.cudnn.benchmark = True

    # Data loading code
    print("Loading data")
    traindir = os.path.join(args.data_path, args.train_dir)
    valdir = os.path.join(args.data_path, args.val_dir)

    print("Loading training data")
    st = time.time()
    cache_path = _get_cache_path(traindir)
    transform_train = presets.VideoClassificationPresetTrain((128, 171),
                                                             (112, 112))

    if args.cache_dataset and os.path.exists(cache_path):
        print(f"Loading dataset_train from {cache_path}")
        dataset, _ = torch.load(cache_path)
        dataset.transform = transform_train
    else:
        if args.distributed:
            print(
                "It is recommended to pre-compute the dataset cache on a single-gpu first, as it will be faster"
            )
        dataset = torchvision.datasets.Kinetics400(
            traindir,
            frames_per_clip=args.clip_len,
            step_between_clips=1,
            transform=transform_train,
            frame_rate=15,
            extensions=(
                "avi",
                "mp4",
            ),
        )
        if args.cache_dataset:
            print(f"Saving dataset_train to {cache_path}")
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset, traindir), cache_path)

    print("Took", time.time() - st)

    print("Loading validation data")
    cache_path = _get_cache_path(valdir)

    if not args.prototype:
        transform_test = presets.VideoClassificationPresetEval(
            resize_size=(128, 171), crop_size=(112, 112))
    else:
        if args.weights:
            weights = prototype.models.get_weight(args.weights)
            transform_test = weights.transforms()
        else:
            transform_test = prototype.transforms.Kinect400Eval(
                crop_size=(112, 112), resize_size=(128, 171))

    if args.cache_dataset and os.path.exists(cache_path):
        print(f"Loading dataset_test from {cache_path}")
        dataset_test, _ = torch.load(cache_path)
        dataset_test.transform = transform_test
    else:
        if args.distributed:
            print(
                "It is recommended to pre-compute the dataset cache on a single-gpu first, as it will be faster"
            )
        dataset_test = torchvision.datasets.Kinetics400(
            valdir,
            frames_per_clip=args.clip_len,
            step_between_clips=1,
            transform=transform_test,
            frame_rate=15,
            extensions=(
                "avi",
                "mp4",
            ),
        )
        if args.cache_dataset:
            print(f"Saving dataset_test to {cache_path}")
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset_test, valdir), cache_path)

    print("Creating data loaders")
    train_sampler = RandomClipSampler(dataset.video_clips,
                                      args.clips_per_video)
    test_sampler = UniformClipSampler(dataset_test.video_clips,
                                      args.clips_per_video)
    if args.distributed:
        train_sampler = DistributedSampler(train_sampler)
        test_sampler = DistributedSampler(test_sampler)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=args.batch_size,
        sampler=test_sampler,
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )

    print("Creating model")
    if not args.prototype:
        model = torchvision.models.video.__dict__[args.model](
            pretrained=args.pretrained)
    else:
        model = prototype.models.video.__dict__[args.model](
            weights=args.weights)
    model.to(device)
    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    criterion = nn.CrossEntropyLoss()

    # scale the learning rate linearly with the number of distributed workers
    lr = args.lr * args.world_size
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scaler = torch.cuda.amp.GradScaler() if args.amp else None

    # Step the scheduler per iteration rather than per epoch so that warmup can
    # span epoch boundaries; the milestones are shifted by the warmup epochs
    # because SequentialLR starts the main scheduler only after warmup ends.
    iters_per_epoch = len(data_loader)
    lr_milestones = [
        iters_per_epoch * (m - args.lr_warmup_epochs)
        for m in args.lr_milestones
    ]
    main_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=lr_milestones, gamma=args.lr_gamma)

    if args.lr_warmup_epochs > 0:
        warmup_iters = iters_per_epoch * args.lr_warmup_epochs
        args.lr_warmup_method = args.lr_warmup_method.lower()
        if args.lr_warmup_method == "linear":
            warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer,
                start_factor=args.lr_warmup_decay,
                total_iters=warmup_iters)
        elif args.lr_warmup_method == "constant":
            warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
                optimizer,
                factor=args.lr_warmup_decay,
                total_iters=warmup_iters)
        else:
            raise RuntimeError(
                f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
            )

        lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer,
            schedulers=[warmup_lr_scheduler, main_lr_scheduler],
            milestones=[warmup_iters])
    else:
        lr_scheduler = main_lr_scheduler

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module

    if args.resume:
        checkpoint = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        args.start_epoch = checkpoint["epoch"] + 1
        if args.amp:
            scaler.load_state_dict(checkpoint["scaler"])

    if args.test_only:
        evaluate(model, criterion, data_loader_test, device=device)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader,
                        device, epoch, args.print_freq, scaler)
        evaluate(model, criterion, data_loader_test, device=device)
        if args.output_dir:
            checkpoint = {
                "model": model_without_ddp.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "epoch": epoch,
                "args": args,
            }
            if args.amp:
                checkpoint["scaler"] = scaler.state_dict()
            utils.save_on_master(
                checkpoint, os.path.join(args.output_dir,
                                         f"model_{epoch}.pth"))
            utils.save_on_master(
                checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")
Example #9
def main(args):
    if args.apex and amp is None:
        raise RuntimeError(
            "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
            "to enable mixed-precision training.")

    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)
    print("torch version: ", torch.__version__)
    print("torchvision version: ", torchvision.__version__)

    device = torch.device(args.device)

    torch.backends.cudnn.benchmark = True

    # Data loading code
    print("Loading data")
    traindir = os.path.join(args.data_path, args.train_dir)
    valdir = os.path.join(args.data_path, args.val_dir)
    normalize = T.Normalize(mean=[0.43216, 0.394666, 0.37645],
                            std=[0.22803, 0.22145, 0.216989])

    print("Loading training data")
    st = time.time()
    cache_path = _get_cache_path(traindir)
    transform_train = torchvision.transforms.Compose([
        T.ToFloatTensorInZeroOne(),
        T.Resize((128, 171)),
        T.RandomHorizontalFlip(), normalize,
        T.RandomCrop((112, 112))
    ])

    if args.cache_dataset and os.path.exists(cache_path):
        print("Loading dataset_train from {}".format(cache_path))
        dataset, _ = torch.load(cache_path)
        dataset.transform = transform_train
    else:
        if args.distributed:
            print("It is recommended to pre-compute the dataset cache "
                  "on a single-gpu first, as it will be faster")
        dataset = torchvision.datasets.Kinetics400(
            traindir,
            frames_per_clip=args.clip_len,
            step_between_clips=1,
            transform=transform_train,
            frame_rate=15)
        if args.cache_dataset:
            print("Saving dataset_train to {}".format(cache_path))
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset, traindir), cache_path)

    print("Took", time.time() - st)

    print("Loading validation data")
    cache_path = _get_cache_path(valdir)

    transform_test = torchvision.transforms.Compose([
        T.ToFloatTensorInZeroOne(),
        T.Resize((128, 171)), normalize,
        T.CenterCrop((112, 112))
    ])

    if args.cache_dataset and os.path.exists(cache_path):
        print("Loading dataset_test from {}".format(cache_path))
        dataset_test, _ = torch.load(cache_path)
        dataset_test.transform = transform_test
    else:
        if args.distributed:
            print("It is recommended to pre-compute the dataset cache "
                  "on a single-gpu first, as it will be faster")
        dataset_test = torchvision.datasets.Kinetics400(
            valdir,
            frames_per_clip=args.clip_len,
            step_between_clips=1,
            transform=transform_test,
            frame_rate=15)
        if args.cache_dataset:
            print("Saving dataset_test to {}".format(cache_path))
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset_test, valdir), cache_path)

    print("Creating data loaders")
    train_sampler = RandomClipSampler(dataset.video_clips,
                                      args.clips_per_video)
    test_sampler = UniformClipSampler(dataset_test.video_clips,
                                      args.clips_per_video)
    if args.distributed:
        train_sampler = DistributedSampler(train_sampler)
        test_sampler = DistributedSampler(test_sampler)

    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=args.batch_size,
                                              sampler=train_sampler,
                                              num_workers=args.workers,
                                              pin_memory=True,
                                              collate_fn=collate_fn)

    data_loader_test = torch.utils.data.DataLoader(dataset_test,
                                                   batch_size=args.batch_size,
                                                   sampler=test_sampler,
                                                   num_workers=args.workers,
                                                   pin_memory=True,
                                                   collate_fn=collate_fn)

    print("Creating model")
    model = torchvision.models.video.__dict__[args.model](
        pretrained=args.pretrained)
    model.to(device)
    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    criterion = nn.CrossEntropyLoss()

    lr = args.lr * args.world_size
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.apex:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.apex_opt_level)

    # Step the scheduler per iteration rather than per epoch so that warmup can
    # span epoch boundaries.
    warmup_iters = args.lr_warmup_epochs * len(data_loader)
    lr_milestones = [len(data_loader) * m for m in args.lr_milestones]
    lr_scheduler = WarmupMultiStepLR(optimizer,
                                     milestones=lr_milestones,
                                     gamma=args.lr_gamma,
                                     warmup_iters=warmup_iters,
                                     warmup_factor=1e-5)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch'] + 1

    if args.test_only:
        evaluate(model, criterion, data_loader_test, device=device)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader,
                        device, epoch, args.print_freq, args.apex)
        evaluate(model, criterion, data_loader_test, device=device)
        if args.output_dir:
            checkpoint = {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'args': args
            }
            utils.save_on_master(
                checkpoint,
                os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
            utils.save_on_master(
                checkpoint, os.path.join(args.output_dir, 'checkpoint.pth'))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
Example #10
def main_worker(gpu, ngpus_per_node, args):
    global best_acc1
    args.gpu = gpu

    # suppress printing if not master
    if args.multiprocessing_distributed and args.gpu != 0:

        def print_pass(*args):
            pass

        builtins.print = print_pass

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)
    # create model
    print("=============> creating model '{}'".format(args.arch))
    model = models.__dict__[args.arch]()
    print(model)
    # freeze all layers but the last fc
    #     for name, param in model.named_parameters():
    #         if name not in ['fc.weight', 'fc.bias']:
    #             param.requires_grad = False
    # init the fc layer
    model.fc = nn.Linear(512, args.num_class, bias=True)
    model.fc.weight.data.normal_(mean=0.0, std=0.01)
    model.fc.bias.data.zero_()

    # load from pre-trained, before DistributedDataParallel constructor
    if args.pretrained:
        if os.path.isfile(args.pretrained):
            print("=> loading checkpoint '{}'".format(args.pretrained))
            checkpoint = torch.load(args.pretrained, map_location="cpu")

            # rename moco pre-trained keys
            state_dict = checkpoint['state_dict']
            for k in list(state_dict.keys()):
                # retain only encoder_q up to before the embedding layer
                if (k.startswith('module.encoder_q')
                        and not k.startswith('module.encoder_q.fc')):
                    # remove prefix
                    state_dict[k[len("module.encoder_q."):]] = state_dict[k]
                # delete renamed or unused k
                del state_dict[k]

            args.start_epoch = 0
            msg = model.load_state_dict(state_dict, strict=False)
            assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}

            print("=> loaded pre-trained model '{}'".format(args.pretrained))
        else:
            print("=> no checkpoint found at '{}'".format(args.pretrained))

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(
                (args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model)  #.cuda() for debug on cpu
    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    # optimize only the linear classifier
    parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
    # assert len(parameters) == 2  # fc.weight, fc.bias
    optimizer = torch.optim.SGD(parameters,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    normalize_video = transforms_video.NormalizeVideo(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    video_augmentation_train = transforms.Compose([
        transforms_video.ToTensorVideo(),
        transforms_video.RandomResizedCropVideo(args.crop_size),
        transforms_video.RandomHorizontalFlipVideo(),
        normalize_video,
    ])
    video_augmentation_val = transforms.Compose([
        transforms_video.ToTensorVideo(),
        transforms_video.CenterCropVideo(args.crop_size),
        normalize_video,
    ])
    data_dir = os.path.join(args.data, 'data')
    anno_dir = os.path.join(args.data, 'anno')
    audio_augmentation = moco.loader.DummyAudioTransform()
    train_augmentation = {
        'video': video_augmentation_train,
        'audio': audio_augmentation
    }
    val_augmentation = {
        'video': video_augmentation_val,
        'audio': audio_augmentation
    }

    train_dataset = UCF101(data_dir,
                           anno_dir,
                           args.frame_per_clip,
                           args.step_between_clips,
                           fold=1,
                           train=True,
                           transform=train_augmentation,
                           num_workers=16)
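    # randomly sample at most 10 clips from each training video every epoch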
    train_sampler = RandomClipSampler(train_dataset.video_clips, 10)
    if args.distributed:
        train_sampler = DistributedSampler(train_sampler)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               multiprocessing_context="fork")

    val_dataset = UCF101(data_dir,
                         anno_dir,
                         args.frame_per_clip,
                         args.step_between_clips,
                         fold=1,
                         train=False,
                         transform=val_augmentation,
                         num_workers=16)
    # Do not wrap val_sampler in DistributedSampler: splitting a video's clips
    # across processes would break per-video evaluation
    val_sampler = UniformClipSampler(val_dataset.video_clips,
                                     args.clip_per_video)
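    # batch_size equals clip_per_video, so each batch contains all sampled
    # clips of a single video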
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.clip_per_video,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True,
                                             sampler=val_sampler,
                                             multiprocessing_context="fork")

    if args.evaluate:
        validate(val_loader, model, criterion, args)
        return
    if args.multiprocessing_distributed and args.gpu == 0:
        log_dir = "{}_bs={}_lr={}_cs={}_fpc={}".format(args.log_dir,
                                                       args.batch_size,
                                                       args.lr, args.crop_size,
                                                       args.frame_per_clip)
        writer = SummaryWriter(log_dir)
    else:
        writer = None
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args, writer)

        # evaluate on validation set
        val_loss, acc1, acc5 = validate(val_loader, model, criterion, args)
        if writer is not None:
            writer.add_scalar('lincls_val/loss', val_loss, epoch)
            writer.add_scalar('lincls_val/acc1', acc1, epoch)
            writer.add_scalar('lincls_val/acc5', acc5, epoch)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        if not args.multiprocessing_distributed or (
                args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            ckp_dir = "{}_bs={}_lr={}_cs={}_fpc={}".format(
                args.ckp_dir, args.batch_size, args.lr, args.crop_size,
                args.frame_per_clip)
            save_checkpoint(epoch, {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            },
                            ckp_dir,
                            max_save=1,
                            is_best=is_best)