Example #1
def get_prediction(args, test_list_name, image_tmpl, filename_seperator,
                   filter_video, model, augmentor):
    data_list = test_list_name
    val_dataset = VideoDataSet(args.datadir,
                               data_list,
                               args.groups,
                               args.frames_per_group,
                               num_clips=args.num_clips,
                               modality=args.modality,
                               image_tmpl=image_tmpl,
                               dense_sampling=args.dense_sampling,
                               fixed_offset=not args.random_sampling,
                               transform=augmentor,
                               is_train=False,
                               test_mode=not args.evaluate,
                               seperator=filename_seperator,
                               filter_video=filter_video)

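    # build_dataflow wraps the dataset in an evaluation DataLoader
    # (no shuffling; batch size and worker count taken from args)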
    data_loader = build_dataflow(val_dataset,
                                 is_train=False,
                                 batch_size=args.batch_size,
                                 workers=args.workers)

    # switch to evaluate mode
    model.eval()

    # load the label encoder once instead of re-reading the pickle for every prediction
    label_encoder = pickle.load(open("dataset_dir/label_encoder.pkl", 'rb'))

    actual_label = None
    with torch.no_grad():
        for i, (video, label) in enumerate(data_loader):
            output = eval_a_batch(video, model)
            output = output.data.cpu().numpy().copy()
            # argsort is ascending, so reverse each row to get the top-5 class indices
            predictions = np.argsort(output, axis=1)
            for ii in range(len(predictions)):
                temp = predictions[ii][::-1][:5]
                preds = [str(pred) for pred in temp]
                # map the top-1 index back to its class name
                actual_label = label_encoder.classes_[int(preds[0])]

    return actual_label


def main():
    global args
    parser = arg_parser()
    args = parser.parse_args()
    cudnn.benchmark = True

    num_classes, train_list_name, val_list_name, test_list_name, filename_seperator, image_tmpl, filter_video, label_file = get_dataset_config(
        args.dataset, args.use_lmdb)

    data_list_name = val_list_name if args.evaluate else test_list_name

    args.num_classes = num_classes
    if args.dataset == 'st2stv1':
        id_to_label, label_to_id = load_categories(os.path.join(args.datadir, label_file))

    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    if args.modality == 'rgb':
        args.input_channels = 3
    elif args.modality == 'flow':
        args.input_channels = 2 * 5

    model, arch_name = build_model(args, test_mode=True)
    mean = model.mean(args.modality)
    std = model.std(args.modality)

    # overwrite mean and std if they are given on the command line
    if args.mean is not None:
        if args.modality == 'rgb':
            if len(args.mean) != 3:
                raise ValueError("When training with rgb, dim of mean must be three.")
        elif args.modality == 'flow':
            if len(args.mean) != 1:
                raise ValueError("When training with flow, dim of mean must be three.")
        mean = args.mean

    if args.std is not None:
        if args.modality == 'rgb':
            if len(args.std) != 3:
                raise ValueError("When training with rgb, dim of std must be three.")
        elif args.modality == 'flow':
            if len(args.std) != 1:
                raise ValueError("When training with flow, dim of std must be three.")
        std = args.std

    model = model.cuda()
    model.eval()

    if args.threed_data:
        dummy_data = (args.input_channels, args.groups, args.input_size, args.input_size)
    else:
        dummy_data = (args.input_channels * args.groups, args.input_size, args.input_size)

    model_summary = torchsummary.summary(model, input_size=dummy_data)

    flops, params = extract_total_flops_params(model_summary)
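    # total inference FLOPs scale linearly with the number of clips and crops evaluated per video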
    flops = int(flops.replace(',', '')) * (args.num_clips * args.num_crops)
    model = torch.nn.DataParallel(model).cuda()
    if args.pretrained is not None:
        print("=> using pre-trained model '{}'".format(arch_name))
        checkpoint = torch.load(args.pretrained)
        model.load_state_dict(checkpoint['state_dict'])
    else:
        print("=> creating model '{}'".format(arch_name))

    # augmentor
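    # resize so the final center crop covers ~87.5% of the shorter side
    # (ImageNet-style evaluation resizing), unless scale-up is disabled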
    if args.disable_scaleup:
        scale_size = args.input_size
    else:
        scale_size = int(args.input_size / 0.875 + 0.5)

    augments = []
    if args.num_crops == 1:
        augments += [
            GroupScale(scale_size),
            GroupCenterCrop(args.input_size)
        ]
    else:
        flip = args.num_crops == 10
        augments += [
            GroupOverSample(args.input_size, scale_size, num_crops=args.num_crops, flip=flip),
        ]
    augments += [
        Stack(threed_data=args.threed_data),
        ToTorchFormatTensor(num_clips_crops=args.num_clips * args.num_crops),
        GroupNormalize(mean=mean, std=std, threed_data=args.threed_data)
    ]

    augmentor = transforms.Compose(augments)

    # Data loading code
    data_list = os.path.join(args.datadir, data_list_name)
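    # symmetric clip offsets around the center clip, e.g. num_clips=3 gives [-1, 0, 1]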
    sample_offsets = list(range(-args.num_clips // 2 + 1, args.num_clips // 2 + 1))
    print("Image is scaled to {} and crop {}".format(scale_size, args.input_size))
    print("Number of crops: {}".format(args.num_crops))
    print("Number of clips: {}, offset from center with {}".format(args.num_clips, sample_offsets))

    video_data_cls = VideoDataSetLMDB if args.use_lmdb else VideoDataSet
    val_dataset = video_data_cls(args.datadir, data_list, args.groups, args.frames_per_group,
                                 num_clips=args.num_clips, modality=args.modality,
                                 image_tmpl=image_tmpl, dense_sampling=args.dense_sampling,
                                 fixed_offset=not args.random_sampling,
                                 transform=augmentor, is_train=False, test_mode=not args.evaluate,
                                 seperator=filename_seperator, filter_video=filter_video)

    data_loader = build_dataflow(val_dataset, is_train=False, batch_size=args.batch_size,
                                 workers=args.workers)

    log_folder = os.path.join(args.logdir, arch_name)
    if not os.path.exists(log_folder):
        os.makedirs(log_folder)

    batch_time = AverageMeter()
    if args.evaluate:
        logfile = open(os.path.join(log_folder, 'evaluate_log.log'), 'a')
        top1 = AverageMeter()
        top5 = AverageMeter()
    else:
        logfile = open(os.path.join(log_folder,
                                    'test_{}crops_{}clips_{}.csv'.format(args.num_crops,
                                                                         args.num_clips,
                                                                         args.input_size)),
                       'w')

    total_outputs = 0
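    # preallocate for the largest possible output count; trimmed to total_outputs after the loop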
    outputs = np.zeros((len(data_loader) * args.batch_size, num_classes))
    # switch to evaluate mode
    model.eval()
    total_batches = len(data_loader)
    with torch.no_grad(), tqdm(total=total_batches) as t_bar:
        end = time.time()
        for i, (video, label) in enumerate(data_loader):
            output = eval_a_batch(video, model, args.input_channels, num_clips=args.num_clips,
                                  num_crops=args.num_crops,
                                  modality=args.modality, softmax=True, threed_data=args.threed_data)
            if args.evaluate:
                label = label.cuda(non_blocking=True)
                # measure accuracy
                prec1, prec5 = accuracy(output, label, topk=(1, 5))
                top1.update(prec1[0], video.size(0))
                top5.update(prec5[0], video.size(0))
                output = output.data.cpu().numpy().copy()
                batch_size = output.shape[0]
                outputs[total_outputs:total_outputs + batch_size, :] = output
            else:
                # testing, store output to prepare csv file
                # measure elapsed time
                output = output.data.cpu().numpy().copy()
                batch_size = output.shape[0]
                outputs[total_outputs:total_outputs + batch_size, :] = output
                predictions = np.argsort(output, axis=1)
                for ii in range(len(predictions)):
                    # argsort is ascending, so reverse to get the top-5 class indices
                    temp = predictions[ii][::-1][:5]
                    preds = [str(pred) for pred in temp]
                    if args.dataset == 'st2stv1':
                        # st2stv1 uses text labels, so map the top-1 index to its name
                        print("{};{}".format(label[ii], id_to_label[int(preds[0])]), file=logfile)
                    else:
                        print("{};{}".format(label[ii], ";".join(preds)), file=logfile)
            total_outputs += video.shape[0]
            batch_time.update(time.time() - end)
            end = time.time()
            t_bar.update(1)

        outputs = outputs[:total_outputs]
        print("Predict {} videos.".format(total_outputs), flush=True)
        np.save(os.path.join(log_folder, '{}_{}crops_{}clips_{}_details.npy'.format("val" if args.evaluate else "test", args.num_crops, args.num_clips, args.input_size)), outputs)

    if args.evaluate:
        print('Val@{}({}) (# crops = {}, # clips = {}): \tTop@1: {:.4f}\tTop@5: {:.4f}\tFLOPs: {:,}\tParams:{} '.format(
            args.input_size, scale_size, args.num_crops, args.num_clips, top1.avg, top5.avg, flops, params), flush=True)
        print('Val@{}({}) (# crops = {}, # clips = {}): \tTop@1: {:.4f}\tTop@5: {:.4f}\tFLOPs: {:,}\tParams:{} '.format(
            args.input_size, scale_size, args.num_crops, args.num_clips, top1.avg, top5.avg, flops, params), flush=True, file=logfile)

    logfile.close()
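
# Offline post-processing sketch (not part of the example above): assuming the
# saved '*_details.npy' file holds one row of softmax scores per video, the
# top-5 indices written to the csv can be reproduced as follows (filename is
# hypothetical):
#
#   scores = np.load('test_10crops_1clips_224_details.npy')
#   top5 = np.argsort(scores, axis=1)[:, ::-1][:, :5]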

Example #3
def main():
    global args
    parser = arg_parser()
    args = parser.parse_args()
    cudnn.benchmark = True

    num_classes, train_list_name, val_list_name, test_list_name, filename_seperator, image_tmpl, filter_video, label_file = get_dataset_config(
        args.dataset, args.use_lmdb)

    args.num_classes = num_classes

    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    if args.modality == 'rgb':
        args.input_channels = 3
    elif args.modality == 'flow':
        args.input_channels = 2 * 5

    model, arch_name = build_model(args)
    mean = model.mean(args.modality)
    std = model.std(args.modality)

    # overwrite mean and std if they are given on the command line
    if args.mean is not None:
        if args.modality == 'rgb':
            if len(args.mean) != 3:
                raise ValueError(
                    "When training with rgb, dim of mean must be three.")
        elif args.modality == 'flow':
            if len(args.mean) != 1:
                raise ValueError(
                    "When training with flow, dim of mean must be three.")
        mean = args.mean

    if args.std is not None:
        if args.modality == 'rgb':
            if len(args.std) != 3:
                raise ValueError(
                    "When training with rgb, dim of std must be three.")
        elif args.modality == 'flow':
            if len(args.std) != 1:
                raise ValueError(
                    "When training with flow, dim of std must be three.")
        std = args.std

    model = model.cuda()
    model.eval()

    if args.threed_data:
        dummy_data = (args.input_channels, args.groups, args.input_size,
                      args.input_size)
    else:
        dummy_data = (args.input_channels * args.groups, args.input_size,
                      args.input_size)

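    # torchsummary runs a dummy forward pass to size each layer; clear the cached activations afterwards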
    model_summary = torchsummary.summary(model, input_size=dummy_data)
    torch.cuda.empty_cache()

    if args.show_model:
        print(model)
        print(model_summary)
        return 0

    model = torch.nn.DataParallel(model).cuda()

    if args.pretrained is not None:
        print("=> using pre-trained model '{}'".format(arch_name))
        checkpoint = torch.load(args.pretrained, map_location='cpu')
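        # when transferring to a new dataset, drop the classifier ("fc") weights
        # so the new head is re-initialized for the target classes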
        if args.transfer:
            new_dict = {}
            for k, v in checkpoint['state_dict'].items():
                # TODO: a better approach:
                if k.replace("module.", "").startswith("fc"):
                    continue
                new_dict[k] = v
        else:
            new_dict = checkpoint['state_dict']
        model.load_state_dict(new_dict, strict=False)
    else:
        print("=> creating model '{}'".format(arch_name))

    # define loss function (criterion) and optimizer
    train_criterion = nn.CrossEntropyLoss().cuda()
    val_criterion = nn.CrossEntropyLoss().cuda()

    # Data loading code
    video_data_cls = VideoDataSetLMDB if args.use_lmdb else VideoDataSet
    val_list = os.path.join(args.datadir, val_list_name)
    val_augmentor = get_augmentor(False,
                                  args.input_size,
                                  mean,
                                  std,
                                  args.disable_scaleup,
                                  threed_data=args.threed_data,
                                  version=args.augmentor_ver,
                                  scale_range=args.scale_range)
    val_dataset = video_data_cls(args.datadir,
                                 val_list,
                                 args.groups,
                                 args.frames_per_group,
                                 num_clips=args.num_clips,
                                 modality=args.modality,
                                 image_tmpl=image_tmpl,
                                 dense_sampling=args.dense_sampling,
                                 transform=val_augmentor,
                                 is_train=False,
                                 test_mode=False,
                                 seperator=filename_seperator,
                                 filter_video=filter_video)

    val_loader = build_dataflow(val_dataset,
                                is_train=False,
                                batch_size=args.batch_size,
                                workers=args.workers)

    log_folder = os.path.join(args.logdir, arch_name)
    if not os.path.exists(log_folder):
        os.makedirs(log_folder)

    if args.evaluate:
        logfile = open(os.path.join(log_folder, 'evaluate_log.log'), 'a')
        flops, params = extract_total_flops_params(model_summary)
        print(model_summary)
        val_top1, val_top5, val_losses, val_speed = validate(
            val_loader, model, val_criterion)
        print(
            'Val@{}: \tLoss: {:4.4f}\tTop@1: {:.4f}\tTop@5: {:.4f}\tSpeed: {:.2f} ms/batch\tFlops: {}\tParams: {}'
            .format(args.input_size, val_losses, val_top1, val_top5,
                    val_speed * 1000.0, flops, params),
            flush=True)
        print(
            'Val@{}: \tLoss: {:4.4f}\tTop@1: {:.4f}\tTop@5: {:.4f}\tSpeed: {:.2f} ms/batch\tFlops: {}\tParams: {}'
            .format(args.input_size, val_losses, val_top1, val_top5,
                    val_speed * 1000.0, flops, params),
            flush=True,
            file=logfile)
        return

    train_list = os.path.join(args.datadir, train_list_name)

    train_augmentor = get_augmentor(True,
                                    args.input_size,
                                    mean,
                                    std,
                                    threed_data=args.threed_data,
                                    version=args.augmentor_ver,
                                    scale_range=args.scale_range)
    train_dataset = video_data_cls(args.datadir,
                                   train_list,
                                   args.groups,
                                   args.frames_per_group,
                                   num_clips=args.num_clips,
                                   modality=args.modality,
                                   image_tmpl=image_tmpl,
                                   dense_sampling=args.dense_sampling,
                                   transform=train_augmentor,
                                   is_train=True,
                                   test_mode=False,
                                   seperator=filename_seperator,
                                   filter_video=filter_video)

    train_loader = build_dataflow(train_dataset,
                                  is_train=True,
                                  batch_size=args.batch_size,
                                  workers=args.workers)

    sgd_policies = model.parameters()
    optimizer = torch.optim.SGD(sgd_policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=args.nesterov)

    if args.lr_scheduler == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, args.lr_steps[0], gamma=0.1)
    elif args.lr_scheduler == 'multisteps':
        scheduler = lr_scheduler.MultiStepLR(optimizer,
                                             args.lr_steps,
                                             gamma=0.1)
    elif args.lr_scheduler == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
                                                   args.epochs,
                                                   eta_min=0)
    elif args.lr_scheduler == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                   'min',
                                                   verbose=True)

    best_top1 = 0.0
    tensorboard_logger.configure(log_folder)
    # optionally resume from a checkpoint
    if args.resume:
        logfile = open(os.path.join(log_folder, 'log.log'), 'a')
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_top1 = checkpoint['best_top1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            try:
                scheduler.load_state_dict(checkpoint['scheduler'])
            except Exception:
                # older checkpoints may not include a scheduler state
                pass
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            raise ValueError("Checkpoint is not found: {}".format(args.resume))
    else:
        if os.path.exists(os.path.join(log_folder, 'log.log')):
            shutil.copyfile(
                os.path.join(log_folder, 'log.log'),
                os.path.join(log_folder,
                             'log.log.{}'.format(int(time.time()))))
        logfile = open(os.path.join(log_folder, 'log.log'), 'w')

    command = " ".join(sys.argv)
    print(command, flush=True)
    print(args, flush=True)
    print(model, flush=True)
    print(model_summary, flush=True)

    print(command, file=logfile, flush=True)
    print(args, file=logfile, flush=True)

    if args.resume == '':
        print(model, file=logfile, flush=True)
        print(model_summary, flush=True, file=logfile)

    for epoch in range(args.start_epoch, args.epochs):
        try:
            # assume the lr is identical across layers; read it from the first param group
            lr = scheduler.optimizer.param_groups[0]['lr']
        except Exception:
            lr = None
        # train for one epoch
        train_top1, train_top5, train_losses, train_speed, speed_data_loader, train_steps = \
            train(train_loader, model, train_criterion, optimizer, epoch + 1,
                  display=args.print_freq,
                  label_smoothing=args.label_smoothing, clip_gradient=args.clip_gradient)
        print(
            'Train: [{:03d}/{:03d}]\tLoss: {:4.4f}\tTop@1: {:.4f}\tTop@5: {:.4f}\tSpeed: {:.2f} ms/batch\tData loading: {:.2f} ms/batch'
            .format(epoch + 1, args.epochs, train_losses, train_top1,
                    train_top5, train_speed * 1000.0,
                    speed_data_loader * 1000.0),
            file=logfile,
            flush=True)
        print(
            'Train: [{:03d}/{:03d}]\tLoss: {:4.4f}\tTop@1: {:.4f}\tTop@5: {:.4f}\tSpeed: {:.2f} ms/batch\tData loading: {:.2f} ms/batch'
            .format(epoch + 1, args.epochs, train_losses, train_top1,
                    train_top5, train_speed * 1000.0,
                    speed_data_loader * 1000.0),
            flush=True)

        # evaluate on validation set
        val_top1, val_top5, val_losses, val_speed = validate(
            val_loader, model, val_criterion)

        # step the scheduler after validation (mirroring main_worker below) so the
        # plateau scheduler always sees a defined val_losses; stepping before the
        # first validation would raise a NameError
        if args.lr_scheduler == 'plateau':
            scheduler.step(val_losses)
        else:
            scheduler.step(epoch + 1)
        print(
            'Val  : [{:03d}/{:03d}]\tLoss: {:4.4f}\tTop@1: {:.4f}\tTop@5: {:.4f}\tSpeed: {:.2f} ms/batch'
            .format(epoch + 1, args.epochs, val_losses, val_top1, val_top5,
                    val_speed * 1000.0),
            file=logfile,
            flush=True)
        print(
            'Val  : [{:03d}/{:03d}]\tLoss: {:4.4f}\tTop@1: {:.4f}\tTop@5: {:.4f}\tSpeed: {:.2f} ms/batch'
            .format(epoch + 1, args.epochs, val_losses, val_top1, val_top5,
                    val_speed * 1000.0),
            flush=True)
        # remember best prec@1 and save checkpoint
        is_best = val_top1 > best_top1
        best_top1 = max(val_top1, best_top1)

        save_dict = {
            'epoch': epoch + 1,
            'arch': arch_name,
            'state_dict': model.state_dict(),
            'best_top1': best_top1,
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict()
        }

        save_checkpoint(save_dict, is_best, filepath=log_folder)

        if lr is not None:
            tensorboard_logger.log_value('learning-rate', lr, epoch + 1)
        tensorboard_logger.log_value('val-top1', val_top1, epoch + 1)
        tensorboard_logger.log_value('val-loss', val_losses, epoch + 1)
        tensorboard_logger.log_value('train-top1', train_top1, epoch + 1)
        tensorboard_logger.log_value('train-loss', train_losses, epoch + 1)
        tensorboard_logger.log_value('best-val-top1', best_top1, epoch + 1)

    logfile.close()


def main_worker(gpu, ngpus_per_node, args):
    cudnn.benchmark = args.cudnn_benchmark
    args.gpu = gpu

    num_classes, train_list_name, val_list_name, test_list_name, filename_seperator, image_tmpl, filter_video, label_file = get_dataset_config(
        args.dataset, args.use_lmdb)
    args.num_classes = num_classes

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    if args.modality == 'rgb':
        args.input_channels = 3
    elif args.modality == 'flow':
        args.input_channels = 2 * 5

    model, arch_name = build_model(args)
    mean = model.mean(args.modality)
    std = model.std(args.modality)

    # overwrite mean and std if they are given on the command line
    if args.mean is not None:
        if args.modality == 'rgb':
            if len(args.mean) != 3:
                raise ValueError(
                    "When training with rgb, dim of mean must be three.")
        elif args.modality == 'flow':
            if len(args.mean) != 1:
                raise ValueError(
                    "When training with flow, dim of mean must be three.")
        mean = args.mean

    if args.std is not None:
        if args.modality == 'rgb':
            if len(args.std) != 3:
                raise ValueError(
                    "When training with rgb, dim of std must be three.")
        elif args.modality == 'flow':
            if len(args.std) != 1:
                raise ValueError(
                    "When training with flow, dim of std must be three.")
        std = args.std

    model = model.cuda(args.gpu)
    model.eval()

    if args.threed_data:
        dummy_data = (args.input_channels, args.groups, args.input_size,
                      args.input_size)
    else:
        dummy_data = (args.input_channels * args.groups, args.input_size,
                      args.input_size)

    if args.rank == 0:
        model_summary = torchsummary.summary(model, input_size=dummy_data)
        torch.cuda.empty_cache()

    if args.show_model and args.rank == 0:
        print(model)
        print(model_summary)
        return 0

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            # the batch size should be divided by number of nodes as well
            args.batch_size = int(args.batch_size / args.world_size)
            args.workers = int(args.workers / ngpus_per_node)

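            # optionally convert BatchNorm layers to synchronized BN across the whole process group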
            if args.sync_bn:
                process_group = torch.distributed.new_group(
                    list(range(args.world_size)))
                model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                    model, process_group)

            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        # assign rank to 0
        model = torch.nn.DataParallel(model).cuda()
        args.rank = 0

    if args.pretrained is not None:
        if args.rank == 0:
            print("=> using pre-trained model '{}'".format(arch_name))
        if args.gpu is None:
            checkpoint = torch.load(args.pretrained, map_location='cpu')
        else:
            checkpoint = torch.load(args.pretrained,
                                    map_location='cuda:{}'.format(args.gpu))
        if args.transfer:
            new_dict = {}
            for k, v in checkpoint['state_dict'].items():
                # TODO: a better approach:
                if k.replace("module.", "").startswith("fc"):
                    continue
                new_dict[k] = v
        else:
            new_dict = checkpoint['state_dict']
        model.load_state_dict(new_dict, strict=False)
        del checkpoint  # drop the reference so the checkpoint memory can be reclaimed
        torch.cuda.empty_cache()
    else:
        if args.rank == 0:
            print("=> creating model '{}'".format(arch_name))

    # define loss function (criterion) and optimizer
    train_criterion = nn.CrossEntropyLoss().cuda(args.gpu)
    val_criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    # Data loading code
    val_list = os.path.join(args.datadir, val_list_name)

    val_augmentor = get_augmentor(
        False,
        args.input_size,
        scale_range=args.scale_range,
        mean=mean,
        std=std,
        disable_scaleup=args.disable_scaleup,
        threed_data=args.threed_data,
        is_flow=args.modality == 'flow',
        version=args.augmentor_ver)

    video_data_cls = VideoDataSetLMDB if args.use_lmdb else VideoDataSet
    val_dataset = video_data_cls(args.datadir,
                                 val_list,
                                 args.groups,
                                 args.frames_per_group,
                                 num_clips=args.num_clips,
                                 modality=args.modality,
                                 image_tmpl=image_tmpl,
                                 dense_sampling=args.dense_sampling,
                                 transform=val_augmentor,
                                 is_train=False,
                                 test_mode=False,
                                 seperator=filename_seperator,
                                 filter_video=filter_video)

    val_loader = build_dataflow(val_dataset,
                                is_train=False,
                                batch_size=args.batch_size,
                                workers=args.workers,
                                is_distributed=args.distributed)

    log_folder = os.path.join(args.logdir, arch_name)
    if args.rank == 0:
        if not os.path.exists(log_folder):
            os.makedirs(log_folder)

    if args.evaluate:
        val_top1, val_top5, val_losses, val_speed = validate(val_loader,
                                                             model,
                                                             val_criterion,
                                                             gpu_id=args.gpu)
        if args.rank == 0:
            logfile = open(os.path.join(log_folder, 'evaluate_log.log'), 'a')
            flops, params = extract_total_flops_params(model_summary)
            print(
                'Val@{}: \tLoss: {:4.4f}\tTop@1: {:.4f}\tTop@5: {:.4f}\tSpeed: {:.2f} ms/batch\tFlops: {}\tParams: {}'
                .format(args.input_size, val_losses, val_top1, val_top5,
                        val_speed * 1000.0, flops, params),
                flush=True)
            print(
                'Val@{}: \tLoss: {:4.4f}\tTop@1: {:.4f}\tTop@5: {:.4f}\tSpeed: {:.2f} ms/batch\tFlops: {}\tParams: {}'
                .format(args.input_size, val_losses, val_top1, val_top5,
                        val_speed * 1000.0, flops, params),
                flush=True,
                file=logfile)
        return

    train_list = os.path.join(args.datadir, train_list_name)

    train_augmentor = get_augmentor(
        True,
        args.input_size,
        scale_range=args.scale_range,
        mean=mean,
        std=std,
        disable_scaleup=args.disable_scaleup,
        threed_data=args.threed_data,
        is_flow=args.modality == 'flow',
        version=args.augmentor_ver)

    train_dataset = video_data_cls(args.datadir,
                                   train_list,
                                   args.groups,
                                   args.frames_per_group,
                                   num_clips=args.num_clips,
                                   modality=args.modality,
                                   image_tmpl=image_tmpl,
                                   dense_sampling=args.dense_sampling,
                                   transform=train_augmentor,
                                   is_train=True,
                                   test_mode=False,
                                   seperator=filename_seperator,
                                   filter_video=filter_video)

    train_loader = build_dataflow(train_dataset,
                                  is_train=True,
                                  batch_size=args.batch_size,
                                  workers=args.workers,
                                  is_distributed=args.distributed)

    sgd_policies = model.parameters()
    optimizer = torch.optim.SGD(sgd_policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=args.nesterov)

    if args.lr_scheduler == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, args.lr_steps[0], gamma=0.1)
    elif args.lr_scheduler == 'multisteps':
        scheduler = lr_scheduler.MultiStepLR(optimizer,
                                             args.lr_steps,
                                             gamma=0.1)
    elif args.lr_scheduler == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
                                                   args.epochs,
                                                   eta_min=0)
    elif args.lr_scheduler == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                   'min',
                                                   verbose=True)

    best_top1 = 0.0
    # optionally resume from a checkpoint
    if args.resume:
        if args.rank == 0:
            logfile = open(os.path.join(log_folder, 'log.log'), 'a')
        if os.path.isfile(args.resume):
            if args.rank == 0:
                print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume, map_location='cpu')
            else:
                checkpoint = torch.load(args.resume,
                                        map_location='cuda:{}'.format(
                                            args.gpu))
            args.start_epoch = checkpoint['epoch']
            # TODO: handle distributed version
            best_top1 = checkpoint['best_top1']
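            # best_top1 may be stored as a tensor in the checkpoint; move it to the current device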
            if args.gpu is not None:
                if not isinstance(best_top1, float):
                    best_top1 = best_top1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            try:
                scheduler.load_state_dict(checkpoint['scheduler'])
            except Exception:
                # older checkpoints may not include a scheduler state
                pass
            if args.rank == 0:
                print("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
            del checkpoint  # drop the reference so the checkpoint memory can be reclaimed
            torch.cuda.empty_cache()
        else:
            raise ValueError("Checkpoint is not found: {}".format(args.resume))
    else:
        if os.path.exists(os.path.join(log_folder,
                                       'log.log')) and args.rank == 0:
            shutil.copyfile(
                os.path.join(log_folder, 'log.log'),
                os.path.join(log_folder,
                             'log.log.{}'.format(int(time.time()))))
        if args.rank == 0:
            logfile = open(os.path.join(log_folder, 'log.log'), 'w')

    if args.rank == 0:
        command = " ".join(sys.argv)
        tensorboard_logger.configure(log_folder)
        print(command, flush=True)
        print(args, flush=True)
        print(model, flush=True)
        print(command, file=logfile, flush=True)
        print(model_summary, flush=True)
        print(args, file=logfile, flush=True)

    if args.resume == '' and args.rank == 0:
        print(model, file=logfile, flush=True)
        print(model_summary, flush=True, file=logfile)

    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        train_top1, train_top5, train_losses, train_speed, speed_data_loader, train_steps = \
            train(train_loader, model, train_criterion, optimizer, epoch + 1,
                  display=args.print_freq, label_smoothing=args.label_smoothing,
                  clip_gradient=args.clip_gradient, gpu_id=args.gpu, rank=args.rank)
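        # wait for all ranks to finish the training epoch before validating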
        if args.distributed:
            dist.barrier()

        # evaluate on validation set
        val_top1, val_top5, val_losses, val_speed = validate(val_loader,
                                                             model,
                                                             val_criterion,
                                                             gpu_id=args.gpu)

        # update current learning rate
        if args.lr_scheduler == 'plateau':
            scheduler.step(val_losses)
        else:
            scheduler.step(epoch + 1)

        if args.distributed:
            dist.barrier()

        # only logging at rank 0
        if args.rank == 0:
            print(
                'Train: [{:03d}/{:03d}]\tLoss: {:4.4f}\tTop@1: {:.4f}\tTop@5: {:.4f}\tSpeed: {:.2f} ms/batch\tData loading: {:.2f} ms/batch'
                .format(epoch + 1, args.epochs, train_losses, train_top1,
                        train_top5, train_speed * 1000.0,
                        speed_data_loader * 1000.0),
                file=logfile,
                flush=True)
            print(
                'Train: [{:03d}/{:03d}]\tLoss: {:4.4f}\tTop@1: {:.4f}\tTop@5: {:.4f}\tSpeed: {:.2f} ms/batch\tData loading: {:.2f} ms/batch'
                .format(epoch + 1, args.epochs, train_losses, train_top1,
                        train_top5, train_speed * 1000.0,
                        speed_data_loader * 1000.0),
                flush=True)
            print(
                'Val  : [{:03d}/{:03d}]\tLoss: {:4.4f}\tTop@1: {:.4f}\tTop@5: {:.4f}\tSpeed: {:.2f} ms/batch'
                .format(epoch + 1, args.epochs, val_losses, val_top1, val_top5,
                        val_speed * 1000.0),
                file=logfile,
                flush=True)
            print(
                'Val  : [{:03d}/{:03d}]\tLoss: {:4.4f}\tTop@1: {:.4f}\tTop@5: {:.4f}\tSpeed: {:.2f} ms/batch'
                .format(epoch + 1, args.epochs, val_losses, val_top1, val_top5,
                        val_speed * 1000.0),
                flush=True)

            # remember best prec@1 and save checkpoint
            is_best = val_top1 > best_top1
            best_top1 = max(val_top1, best_top1)

            save_dict = {
                'epoch': epoch + 1,
                'arch': arch_name,
                'state_dict': model.state_dict(),
                'best_top1': best_top1,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict()
            }

            save_checkpoint(save_dict, is_best, filepath=log_folder)
            try:
                # assume the lr is identical across layers; read it from the first param group
                lr = scheduler.optimizer.param_groups[0]['lr']
            except Exception:
                lr = None
            if lr is not None:
                tensorboard_logger.log_value('learning-rate', lr, epoch + 1)
            tensorboard_logger.log_value('val-top1', val_top1, epoch + 1)
            tensorboard_logger.log_value('val-loss', val_losses, epoch + 1)
            tensorboard_logger.log_value('train-top1', train_top1, epoch + 1)
            tensorboard_logger.log_value('train-loss', train_losses, epoch + 1)
            tensorboard_logger.log_value('best-val-top1', best_top1, epoch + 1)

        if args.distributed:
            dist.barrier()

    if args.rank == 0:
        logfile.close()