Пример #1
0
Файл: train.py Проект: Desny/ECO
def train(args):
    """Train the TSN network in PaddlePaddle dygraph mode.

    Builds the model and a Momentum optimizer, optionally restores weights
    from ``<save_dir>/tsn_model``, then iterates over the UCF training reader
    for ``args.epoch`` epochs. The model state is saved every 100 batches.

    Args:
        args: Parsed command-line namespace. This function reads
            ``use_gpu``, ``config``, ``batch_size``, ``pretrain``,
            ``save_dir`` and ``epoch``.
    """
    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    # Checkpoint prefix shared by the load and save calls below.
    model_path = args.save_dir + '/tsn_model'
    with fluid.dygraph.guard(place):
        # Parse the config file and merge in the command-line overrides.
        config = parse_config(args.config)
        train_config = merge_configs(config, 'train', vars(args))
        # Declare train_model from the user-defined network.
        train_model = TSN(args.batch_size, 32)
        train_model.train()
        opt = fluid.optimizer.Momentum(0.001,
                                       0.9,
                                       parameter_list=train_model.parameters())

        if args.pretrain:
            # Load the model saved by a previous run and continue training.
            model, _ = fluid.dygraph.load_dygraph(model_path)
            train_model.load_dict(model)

        if not os.path.exists(args.save_dir):
            os.makedirs(args.save_dir)

        # Reader that yields one batch per iteration; each sample in a batch
        # is an (imgs, label) pair.
        train_reader = UCFReader('train', train_config).create_reader()

        # Stays None when the reader never yields a batch, so the final
        # report below can be skipped instead of raising NameError.
        avg_loss = None
        for i in range(args.epoch):
            for batch_id, data in enumerate(train_reader()):
                # Images; per the original author's note the shape is
                # (batch_size, seg_num*seglen, 3, target_size, target_size).
                dy_x_data = np.array([x[0] for x in data]).astype('float32')
                # Labels; per the original author's note the shape is
                # (batch_size,).
                y_data = np.array([x[1] for x in data]).astype('int64')

                img = fluid.dygraph.to_variable(dy_x_data)
                label = fluid.dygraph.to_variable(y_data)
                # Infer the batch dimension so a batch whose size differs
                # from args.batch_size is still reshaped correctly.
                label = fluid.layers.reshape(label, [-1, 1])
                # Labels are targets only; exclude them from back-propagation.
                label.stop_gradient = True

                out, acc = train_model(img, label)

                loss = fluid.layers.cross_entropy(out, label)
                avg_loss = fluid.layers.mean(loss)

                avg_loss.backward()

                opt.minimize(avg_loss)
                train_model.clear_gradients()

                if batch_id % 100 == 0:
                    print("Loss at epoch {} step {}: {}, acc: {}".format(
                        i, batch_id, avg_loss.numpy(), acc.numpy()))
                    fluid.dygraph.save_dygraph(train_model.state_dict(),
                                               model_path)
        if avg_loss is not None:
            print("Final loss: {}".format(avg_loss.numpy()))
Пример #2
0
def main():
    """Train and/or evaluate a TSN model as configured on the command line.

    Parses arguments with the module-level ``parser``, builds the model and
    the train/validation data loaders, and optionally restores the checkpoint
    named by ``args.resume``. With ``args.evaluate`` set it runs a single
    validation pass and returns. Otherwise it trains from
    ``args.start_epoch`` to ``args.epochs``, validating every
    ``args.eval_freq`` epochs and on the last epoch, and calls
    ``save_checkpoint`` after each validation.

    Rebinds the module-level ``args`` and ``best_prec1``.

    Raises:
        ValueError: If ``args.dataset``, ``args.modality`` or
            ``args.loss_type`` is not one of the supported values.
    """
    global args, best_prec1
    args = parser.parse_args()

    if args.dataset == 'ucf101':
        num_class = 101
    elif args.dataset == 'hmdb51':
        num_class = 51
    elif args.dataset == 'kinetics':
        num_class = 400
    else:
        raise ValueError('Unknown dataset '+args.dataset)

    model = TSN(num_class, args.num_segments, args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type, dropout=args.dropout,
                partial_bn=not args.no_partialbn)

    # Query the bare model for its preprocessing parameters and optimizer
    # policies before it is wrapped in DataParallel.
    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation()

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            # Report the checkpoint path that was actually loaded.
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5
    else:
        # Fail with a clear message instead of a later NameError on
        # data_length.
        raise ValueError('Unknown modality ' + str(args.modality))

    # File-name template shared by the train and validation datasets.
    if args.modality in ["RGB", "RGBDiff"]:
        image_tmpl = "img_{:05d}.jpg"
    else:
        image_tmpl = args.flow_prefix + "{}_{:05d}.jpg"

    train_loader = torch.utils.data.DataLoader(
        TSNDataSet("", args.train_list, num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl=image_tmpl,
                   transform=torchvision.transforms.Compose([
                       train_augmentation,
                       Stack(roll=args.arch == 'BNInception'),
                       ToTorchFormatTensor(div=args.arch != 'BNInception'),
                       normalize,
                   ])),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        TSNDataSet("", args.val_list, num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl=image_tmpl,
                   random_shift=False,
                   transform=torchvision.transforms.Compose([
                       GroupScale(int(scale_size)),
                       GroupCenterCrop(crop_size),
                       Stack(roll=args.arch == 'BNInception'),
                       ToTorchFormatTensor(div=args.arch != 'BNInception'),
                       normalize,
                   ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult']))

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_steps)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1 = validate(val_loader, model, criterion,
                             (epoch + 1) * len(train_loader))

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            }, is_best)
Пример #3
0
    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        exit()

    log_training = open(
        os.path.join(args.root_log, '%s.csv' % args.store_name), 'w')
    for epoch in range(args.start_epoch, args.epochs):
        print("Epoch", epoch)
        adjust_learning_rate(optimizer, epoch, args.lr_steps)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, log_training)

        print("evaluate")
        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1 = validate(val_loader, model, criterion,
                             (epoch + 1) * len(train_loader), log_training)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)
#################################
# main loop
#################################

# Run args.num_experiments independent experiments. Each one trains a fresh
# TSN, saves the weights of the best validation epoch into its own log
# directory, and finally prints the stats of a 'test' pass.
# `p`, `expdir`, `ei`, `device`, the datasets and the dataloaders are defined
# outside this block.
for di in range(0, args.num_experiments):
    # Per-experiment output directory: ./<logdir>/<expdir>/<ei>/<di>/
    p['logdir'] = './%s/%s/%d/%d/' % (args.logdir, expdir, ei, di)
    if(not os.path.exists(p['logdir'])):
        os.makedirs(p['logdir'])

    # The empty-list assignment is immediately overwritten by the new model.
    model = []
    model = TSN(p, dataset_train)
    model = model.cuda(device)

    optim = get_optimizer(args, model)

    # Best validation scores seen so far in this experiment.
    max_perf_val = 0.0    
    max_perf_aux = 0.0
    for epoch in range(0, args.num_epochs):
        # The 'train' pass receives the optimizer; the 'val' pass does not.
        stats_train = process_epoch('train', epoch, p, dataloader_train, model, optim)
        stats_val = process_epoch('val', epoch, p, dataloader_val, model)
        
        # Primary score: sum of the cause/effect top-1 stats.
        # Secondary score: sum of the cause/effect top-2 stats.
        perf_val = stats_val['top1.cause'] + stats_val['top1.effect']
        perf_val_aux = stats_val['top2.cause'] + stats_val['top2.effect']
        # Save only when BOTH scores are at least as good as the recorded
        # best; the two maxima are updated together.
        if(perf_val >= max_perf_val):
            if(perf_val_aux >= max_perf_aux):
                max_perf_val = perf_val
                max_perf_aux = perf_val_aux
                torch.save(model.state_dict(), p['logdir'] + 'model_max.pth')        

    # NOTE(review): `epoch` here is the last value of the loop above and is
    # unbound if args.num_epochs is 0. As far as this block shows,
    # model_max.pth is not reloaded, so the test pass uses the final-epoch
    # weights -- confirm this is intended.
    stats_test = process_epoch('test', epoch, p, dataloader_test, model)
    print(stats_test)
Пример #5
0
def main():
    """Train and validate a TSN model on per-frame data, logging via ``log.l``.

    Parses arguments with ``Parse_args``, builds the model and the
    train/validation data loaders, and optionally restores the checkpoint
    named by ``args.resume``. With ``args.evaluate`` set, one validation pass
    runs first and training then continues (there is no early return).
    Training runs from ``args.start_epoch`` to ``args.epochs``, validating
    every ``args.eval_freq`` epochs and on the last epoch, and calls
    ``save_checkpoint`` after each validation.

    Rebinds the module-level ``args`` and ``best_prec1``.

    Raises:
        ValueError: If ``args.dataset``, ``args.modality`` or
            ``args.loss_type`` is not one of the supported values.
    """
    global args, best_prec1
    args = Parse_args()

    log.l.info('Input command:\n ===========> python ' + ' '.join(sys.argv) +
               '  ===========>')

    if args.dataset == 'ucf101':
        num_class = 101
    elif args.dataset == 'hmdb51':
        num_class = 51
    elif args.dataset == 'kinetics':
        num_class = 400
    elif args.dataset == 'mm':
        num_class = 500
    elif args.dataset == 'thumos14':
        num_class = 21
    else:
        raise ValueError('Unknown dataset ' + args.dataset)

    log.l.info(
        '============= prepare the model and model\'s parameters ============='
    )

    model = TSN(num_class,
                args.num_segments,
                args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                partial_bn=not args.no_partialbn)

    # Query the bare model for its preprocessing parameters and optimizer
    # policies before it is wrapped in DataParallel.
    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation()

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    if args.resume:
        log.l.info(
            '============== train from checkpoint (finetune mode) ================='
        )
        if os.path.isfile(args.resume):
            log.l.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            # Report the checkpoint path that was actually loaded.
            log.l.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            log.l.info("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    log.l.info('============== Now, loading data ... ==============\n')
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5
    else:
        # Fail with a clear message instead of a later NameError on
        # data_length.
        raise ValueError('Unknown modality ' + str(args.modality))

    # File-name template shared by the train and validation datasets.
    if args.modality in ["RGB", "RGBDiff"]:
        image_tmpl = "img_{:05d}.jpg"
    else:
        image_tmpl = args.flow_prefix + "{}_{:05d}.jpg"

    train_loader = torch.utils.data.DataLoader(
        PerFrameData(
            args.frames_root,
            args.train_list,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            data_gap=args.data_gap,
            test_mode=False,
            image_tmpl=image_tmpl,
            transform=torchvision.transforms.Compose([
                train_augmentation,
                Stack(roll=args.arch == 'BNInception'),
                ToTorchFormatTensor(div=args.arch != 'BNInception'),
                normalize,
            ])),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.data_workers,
        pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        PerFrameData(
            args.frames_root,
            args.val_list,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            data_gap=args.data_gap,
            test_mode=True,
            image_tmpl=image_tmpl,
            transform=torchvision.transforms.Compose([
                GroupScale(int(scale_size)),
                GroupCenterCrop(crop_size),
                Stack(roll=args.arch == 'BNInception'),
                ToTorchFormatTensor(div=args.arch != 'BNInception'),
                normalize,
            ])),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.data_workers,
        pin_memory=True)

    log.l.info(
        '================= Now, define loss function and optimizer =============='
    )
    # The cross-entropy criterion is constructed without class weights.
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        log.l.info(
            'group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
                group['name'], len(group['params']), group['lr_mult'],
                group['decay_mult']))

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.evaluate:
        # Validate once up front; training still proceeds afterwards.
        log.l.info('Need val the data first...')
        validate(val_loader, model, criterion, 0)

    log.l.info(
        '\n\n===================> TRAIN and VAL begins <===================\n')

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_steps)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1 = validate(val_loader, model, criterion,
                             (epoch + 1) * len(train_loader))

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)
Пример #6
0
def main():
    """Train and validate a TSN model while plotting and saving metric curves.

    Parses arguments with the module-level ``parser``, builds the model and
    the train/validation data loaders, and optionally restores the checkpoint
    named by ``args.resume``. With ``args.evaluate`` set it runs a single
    validation pass and returns. Otherwise it trains from
    ``args.start_epoch`` to ``args.epochs``. After every epoch it re-plots
    the train/validation loss and accuracy histories, saves the figure as
    ``./<snapshot_pref>/<epoch>.jpg`` and stores the histories as ``.npy``
    files in the same directory. When ``args.start_epoch`` is non-zero the
    histories are first loaded from those ``.npy`` files.

    Rebinds the module-level ``args`` and ``best_prec1``.

    Raises:
        ValueError: If ``args.dataset``, ``args.modality`` or
            ``args.loss_type`` is not one of the supported values.
    """
    global args, best_prec1
    args = parser.parse_args()

    if args.dataset == 'ucf101':
        num_class = 101
    elif args.dataset == 'hmdb51':
        num_class = 51
    elif args.dataset == 'kinetics':
        num_class = 400
    elif args.dataset == 'myDataset':
        num_class = 12
    else:
        raise ValueError('Unknown dataset ' + args.dataset)

    model = TSN(num_class,
                args.num_segments,
                args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                partial_bn=not args.no_partialbn)

    # Query the bare model for its preprocessing parameters and optimizer
    # policies before it is wrapped in DataParallel.
    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation()

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            # Report the checkpoint path that was actually loaded.
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5
    else:
        # Fail with a clear message instead of a later NameError on
        # data_length.
        raise ValueError('Unknown modality ' + str(args.modality))

    # File-name template shared by the train and validation datasets.
    if args.modality in ["RGB", "RGBDiff"]:
        image_tmpl = "img_{:05d}.jpg"
    else:
        image_tmpl = args.flow_prefix + "{}_{:05d}.jpg"

    train_loader = torch.utils.data.DataLoader(
        TSNDataSet(
            "",
            args.train_list,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            image_tmpl=image_tmpl,
            transform=torchvision.transforms.Compose([
                train_augmentation,
                Stack(roll=args.arch == 'BNInception'),
                ToTorchFormatTensor(div=args.arch != 'BNInception'),
                normalize,
            ])),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        TSNDataSet(
            "",
            args.val_list,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            image_tmpl=image_tmpl,
            random_shift=False,
            transform=torchvision.transforms.Compose([
                GroupScale(int(scale_size)),
                GroupCenterCrop(crop_size),
                Stack(roll=args.arch == 'BNInception'),
                ToTorchFormatTensor(div=args.arch != 'BNInception'),
                normalize,
            ])),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult']))

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    # One subplot per curve: val loss, val acc, train loss, train acc.
    f, axs = plt.subplots(4, 1, figsize=(10, 5))
    if args.start_epoch == 0:
        train_acc = []
        train_loss = []
        val_acc = []
        val_loss = []
        epochs = []
        val_epochs = []
    else:
        # Resuming: continue the histories stored by an earlier run.
        train_acc = np.load("./%s/train_acc.npy" % args.snapshot_pref).tolist()
        train_loss = np.load("./%s/train_loss.npy" %
                             args.snapshot_pref).tolist()
        val_acc = np.load("./%s/val_acc.npy" % args.snapshot_pref).tolist()
        val_loss = np.load("./%s/val_loss.npy" % args.snapshot_pref).tolist()
        epochs = np.load("./%s/epochs.npy" % args.snapshot_pref).tolist()
        val_epochs = np.load("./%s/val_epochs.npy" %
                             args.snapshot_pref).tolist()
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_steps)

        # train for one epoch
        acc, loss = train(train_loader, model, criterion, optimizer, epoch)
        train_acc.append(acc)
        train_loss.append(loss)
        epochs.append(epoch)
        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1, v_loss = validate(val_loader, model, criterion,
                                     (epoch + 1) * len(train_loader))

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)
            val_acc.append(prec1)
            val_loss.append(v_loss)
            val_epochs.append(epoch)
        axs[0].plot(val_epochs, val_loss, c='b', marker='.', label='val_loss')
        axs[1].plot(val_epochs, val_acc, c='r', marker='.', label='val_acc')
        axs[2].plot(epochs, train_loss, c='b', marker='.', label='train_loss')
        axs[3].plot(epochs, train_acc, c='r', marker='.', label='train_acc')
        plt.title('TSN_' + args.snapshot_pref)
        # Add the legends once, on the first epoch actually run, so that a
        # resumed run (start_epoch > 0) gets them too.
        if epoch == args.start_epoch:
            for ax in axs:
                ax.legend(loc='best')
        plt.pause(0.000001)
        if not os.path.exists(args.snapshot_pref):
            os.makedirs(args.snapshot_pref)
        plt.savefig('./%s/%s.jpg' % (args.snapshot_pref, str(epoch).zfill(5)))
        np.save("./%s/train_acc.npy" % args.snapshot_pref, train_acc)
        np.save("./%s/train_loss.npy" % args.snapshot_pref, train_loss)
        np.save("./%s/val_acc.npy" % args.snapshot_pref, val_acc)
        np.save("./%s/val_loss.npy" % args.snapshot_pref, val_loss)
        np.save("./%s/val_epochs.npy" % args.snapshot_pref, val_epochs)
        np.save("./%s/epochs.npy" % args.snapshot_pref, epochs)
Пример #7
0
def main():
    """Train and/or evaluate a TSN model and save its validation scores.

    Prints the environment versions and the parsed configuration, builds the
    model and the train/validation data loaders, and optionally restores the
    checkpoint named by ``args.resume``. With ``args.evaluate`` set it runs a
    single validation pass with ``temperature=100``, passes the score tensor
    to ``save_validation_score`` with ``filename='score.pt'`` and returns.
    Otherwise it trains from ``args.start_epoch`` to ``args.epochs``,
    validating every ``args.eval_freq`` epochs and on the last epoch with the
    temperature returned by ``train``. It calls ``save_checkpoint`` after each
    validation and passes all collected score tensors to
    ``save_validation_score`` at the end.

    Rebinds the module-level ``args`` and ``best_prec1``.

    Raises:
        ValueError: If ``args.dataset``, ``args.arch``, ``args.modality`` or
            ``args.loss_type`` is not one of the supported values.
    """
    global args, best_prec1
    args = parser.parse_args()

    print("------------------------------------")
    print("Environment Versions:")
    print("- Python: {}".format(sys.version))
    print("- PyTorch: {}".format(torch.__version__))
    print("- TorchVison: {}".format(torchvision.__version__))

    print("------------------------------------")
    print(args.arch + " Configurations:")
    for key, value in vars(args).items():
        print("- {}: {}".format(key, value))
    print("------------------------------------")
    print(args.mode)
    if args.dataset == 'ucf101':
        num_class = 101
        rgb_read_format = "{:05d}.jpg"
    elif args.dataset == 'hmdb51':
        num_class = 51
        rgb_read_format = "{:05d}.jpg"
    elif args.dataset == 'kinetics':
        num_class = 400
        rgb_read_format = "{:05d}.jpg"
    elif args.dataset == 'something':
        num_class = 174
        rgb_read_format = "{:05d}.jpg"
    elif args.dataset == 'somethingv2':
        num_class = 174
        rgb_read_format = "img_{:05d}.jpg"
    elif args.dataset == 'NTU_RGBD':
        num_class = 120
        rgb_read_format = "{:05d}.jpg"
    elif args.dataset == 'tinykinetics':
        num_class = 150
        rgb_read_format = "{:05d}.jpg"
    else:
        raise ValueError('Unknown dataset ' + args.dataset)

    model = TSN(num_class,
                args.num_segments,
                args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                partial_bn=not args.no_partialbn,
                non_local=args.non_local)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    # Optimizers also support per-parameter options: pass an iterable of
    # dicts, each defining a separate parameter group. Each dict holds a
    # `params` key with the list of parameters in that group; its other keys
    # match the optimizer's keyword arguments and act as options for the
    # group.
    policies = model.get_optim_policies(args.dataset)

    train_augmentation = model.get_augmentation()

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    # Per-architecture input conventions, forwarded to Stack(roll=...) and
    # ToTorchFormatTensor(div=...) in the transform pipelines below.
    if args.arch in ("resnet50", "resnet34"):
        div = False
        roll = True
    elif args.arch[:3] == "TCM":
        div = True
        roll = False
    else:
        # Fail with a clear message instead of a later NameError on
        # roll/div.
        raise ValueError('Unsupported arch ' + args.arch)

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    # Every supported modality reads a single frame per segment here.
    if args.modality in ['RGB', 'Flow', 'RGBDiff']:
        data_length = 1
    else:
        # Fail with a clear message instead of a later NameError on
        # data_length.
        raise ValueError('Unknown modality ' + str(args.modality))

    # File-name template shared by the train and validation datasets.
    if args.modality in ["RGB", "RGBDiff"]:
        image_tmpl = args.rgb_prefix + rgb_read_format
    else:
        image_tmpl = args.flow_prefix + rgb_read_format

    train_loader = torch.utils.data.DataLoader(
        TSNDataSet(
            "",
            args.train_list,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            mode=args.mode,
            image_tmpl=image_tmpl,
            img_start_idx=args.img_start_idx,
            transform=torchvision.transforms.Compose([
                GroupScale((240, 320)),
                train_augmentation,
                Stack(roll=roll),
                ToTorchFormatTensor(div=div),
                normalize,
            ])),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        TSNDataSet(
            "",
            args.val_list,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            mode=args.mode,
            image_tmpl=image_tmpl,
            img_start_idx=args.img_start_idx,
            random_shift=False,
            transform=torchvision.transforms.Compose([
                GroupScale((240, 320)),
                GroupCenterCrop(crop_size),
                Stack(roll=roll),
                ToTorchFormatTensor(div=div),
                normalize,
            ])),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult']))

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=args.nesterov)

    output_list = []
    if args.evaluate:
        prec1, score_tensor = validate(val_loader,
                                       model,
                                       criterion,
                                       temperature=100)
        output_list.append(score_tensor)
        save_validation_score(output_list, filename='score.pt')
        # Report the same file name that was passed to
        # save_validation_score above.
        print("validation score saved in {}".format('/'.join(
            (args.val_output_folder, 'score.pt'))))
        return

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_steps)
        # train for one epoch; train() returns the temperature that is then
        # handed to validate()
        temperature = train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1, score_tensor = validate(val_loader,
                                           model,
                                           criterion,
                                           temperature=temperature)

            output_list.append(score_tensor)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)

            output_best = 'Best Prec@1: %.3f\n' % (best_prec1)
            print(output_best)

            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)

    # save validation score
    save_validation_score(output_list)
    print("validation score saved in {}".format('/'.join(
        (args.val_output_folder, 'score.pt'))))
Пример #8
0
def main():
    """Train a TSN model, or run one validation pass when ``args.evaluate`` is set.

    Parses the command line into the module-level ``args``, dumps the options
    to ``opts.json`` under ``args.result_path``, builds the model, the data
    loaders, the loss and an SGD optimizer, optionally restores a checkpoint
    from ``args.resume``, and then either validates once or runs the epoch
    loop. While training, the model is validated every ``args.eval_freq``
    epochs (and after the final epoch) and a checkpoint is saved after each
    validation.

    Rebinds the globals ``args`` and ``best_prec1``.

    Raises:
        ValueError: if the dataset, the modality or the loss type is not
            one of the values handled below.
    """
    global args, best_prec1
    args = parser.parse_args()

    # Number of target classes for each supported dataset.
    class_counts = {
        'ucf101': 101,
        'hmdb51': 51,
        'kinetics': 400,
        'something': 174,
    }
    if args.dataset not in class_counts:
        raise ValueError('Unknown dataset ' + args.dataset)
    num_class = class_counts[args.dataset]

    # Frame file names carry no prefix for 'something', 'image_' otherwise.
    img_prefix = '' if args.dataset == 'something' else 'image_'

    with open(os.path.join(args.result_path, 'opts.json'), 'w') as opt_file:
        json.dump(vars(args), opt_file)

    # The LSTM output type is only meaningful for the LSTM consensus modules.
    if args.consensus_type not in ('lstm', 'conv_lstm'):
        args.lstm_out_type = None

    model = TSN(num_class,
                args.num_segments,
                args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                partial_bn=not args.no_partialbn,
                lstm_out_type=args.lstm_out_type,
                lstm_layers=args.lstm_layers,
                lstm_hidden_dims=args.lstm_hidden_dims,
                conv_lstm_kernel=args.conv_lstm_kernel,
                bi_add_clf=args.bi_add_clf,
                bi_out_dims=args.bi_out_dims,
                bi_rank=args.bi_rank,
                bi_att_softmax=args.bi_att_softmax,
                bi_filter_size=args.bi_filter_size,
                bi_dropout=args.bi_dropout,
                bi_conv_dropout=args.bi_conv_dropout,
                dataset=args.dataset)

    # Read everything needed from the bare model before it is wrapped.
    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation()

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    if args.resume:
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            # Report the checkpoint path (the original printed args.evaluate).
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 10
    else:
        # Without this branch data_length would be unbound further down.
        raise ValueError('Unknown modality ' + str(args.modality))

    # Both loaders use the same frame file-name template.
    if args.modality in ["RGB", "RGBDiff"]:
        image_tmpl = img_prefix + "{:05d}.jpg"
    else:
        image_tmpl = args.flow_prefix + "{}_{:05d}.jpg"

    is_bninception = args.arch == 'BNInception'
    frames_per_clip = data_length * args.num_segments

    if args.train_reverse:
        train_temp_transform = ReverseFrames(size=frames_per_clip)
    elif args.train_shuffle:
        train_temp_transform = ShuffleFrames(size=frames_per_clip)
    else:
        train_temp_transform = IdentityTransform()

    train_dataset = TSNDataSet(
        "",
        args.train_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=image_tmpl,
        temp_transform=train_temp_transform,
        transform=torchvision.transforms.Compose([
            train_augmentation,
            Stack(roll=is_bninception),
            ToTorchFormatTensor(div=not is_bninception),
            normalize,
        ]),
        contrastive_mode=args.contrastive_mode)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    if args.val_reverse:
        val_temp_transform = ReverseFrames(size=frames_per_clip)
        print('using reverse val')
    elif args.val_shuffle:
        val_temp_transform = ShuffleFrames(size=frames_per_clip)
        print('using shuffle val')
    else:
        val_temp_transform = IdentityTransform()
        print('using normal val')

    val_dataset = TSNDataSet(
        "",
        args.val_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=image_tmpl,
        random_shift=False,
        temp_transform=val_temp_transform,
        transform=torchvision.transforms.Compose([
            GroupScale(int(scale_size)),
            GroupCenterCrop(crop_size),
            Stack(roll=is_bninception),
            ToTorchFormatTensor(div=not is_bninception),
            normalize,
        ]))
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # define loss function (criterion) and optimizer
    if args.loss_type != 'nll':
        raise ValueError("Unknown loss type")
    if args.contrastive_mode:
        criterion = ContrastiveLoss(m1=args.contras_m1,
                                    m2=args.contras_m2).cuda()
    else:
        criterion = torch.nn.CrossEntropyLoss().cuda()
    # Validation always uses cross entropy, whatever the training loss is.
    val_criterion = torch.nn.CrossEntropyLoss().cuda()

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.evaluate:
        print('evaluating')
        # Loggers are passed around as file paths, not as open handles.
        val_logger = os.path.join(args.result_path, 'test.log')
        validate(val_loader, model, val_criterion, 0, val_logger=val_logger)
        return

    train_logger = os.path.join(args.result_path, 'train.log')
    val_logger = os.path.join(args.result_path, 'val.log')
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_steps)

        # train for one epoch
        if args.contrastive_mode:
            train_contrastive(train_loader,
                              model,
                              criterion,
                              optimizer,
                              epoch,
                              train_logger=train_logger,
                              args=args)
        else:
            train(train_loader,
                  model,
                  criterion,
                  optimizer,
                  epoch,
                  train_logger=train_logger)
        # Separate the epochs in the training log with a blank line.
        with open(train_logger, 'a') as f:
            f.write('\n')

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1 = validate(val_loader,
                             model,
                             val_criterion, (epoch + 1) * len(train_loader),
                             val_logger=val_logger)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)
Пример #9
0
def main():
    """Train a TRN model on 'thumos', then fine-tune it on 'thumos-fs'.

    Phase 1 trains (or, with ``args.evaluate``, only validates) a TSN model
    on the 'thumos' RGB dataset. Phase 2 replaces the consensus classifier
    with a new 14-way head and trains the same network on the 'thumos-fs'
    dataset, logging to ``fs-logging.csv`` and checkpointing to
    ``fs_checkpoint.pth.tar``.

    Rebinds the globals ``args`` and ``best_prec1``; ``args.dataset`` and
    ``args.modality`` are overwritten with fixed values.

    Raises:
        ValueError: if ``args.loss_type`` is not ``'nll'``.
    """
    global args, best_prec1
    args = parser.parse_args()
    check_rootfolders()

    # This script is hard-wired to the THUMOS RGB setting.
    args.dataset = "thumos"
    args.modality = "RGB"

    categories, args.train_list, args.val_list, args.root_path, prefix = datasets_video.return_dataset(
        'thumos', args.modality)
    num_class = len(categories)

    args.store_name = '_'.join([
        'TRN', args.dataset, args.modality, args.arch, args.consensus_type,
        'segment%d' % args.num_segments
    ])
    print('storing name: ' + args.store_name)

    model = TSN(num_class,
                args.num_segments,
                args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn)

    # Read everything needed from the bare model before it is wrapped.
    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation()

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    if args.resume:
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            # Report the checkpoint path (the original printed args.evaluate).
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5

    inception_input = args.arch in ['BNInception', 'InceptionV3']

    def build_loaders(root_path, train_list, val_list, image_tmpl):
        """Return ``(train_loader, val_loader)`` for the given file lists."""
        train_set = TSNDataSet(
            root_path,
            train_list,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            image_tmpl=image_tmpl,
            transform=torchvision.transforms.Compose([
                train_augmentation,
                Stack(roll=inception_input),
                ToTorchFormatTensor(div=not inception_input),
                normalize,
            ]))
        train_dl = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
        val_set = TSNDataSet(
            root_path,
            val_list,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            image_tmpl=image_tmpl,
            random_shift=False,
            transform=torchvision.transforms.Compose([
                GroupScale(int(scale_size)),
                GroupCenterCrop(crop_size),
                Stack(roll=inception_input),
                ToTorchFormatTensor(div=not inception_input),
                normalize,
            ]))
        val_dl = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)
        return train_dl, val_dl

    def print_policies(groups):
        """Print one summary line per optimizer parameter group."""
        for group in groups:
            print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
                group['name'], len(group['params']), group['lr_mult'],
                group['decay_mult'])))

    train_loader, val_loader = build_loaders(args.root_path, args.train_list,
                                             args.val_list, prefix)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    print_policies(policies)

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    # The context manager guarantees the CSV log is flushed and closed.
    log_path = os.path.join(args.root_log, '%s.csv' % args.store_name)
    with open(log_path, 'w') as log_training:
        for epoch in range(args.start_epoch, args.epochs):
            adjust_learning_rate(optimizer, epoch, args.lr_steps)

            # train for one epoch
            train(train_loader, model, criterion, optimizer, epoch,
                  log_training)

            # evaluate on validation set
            if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
                prec1 = validate(val_loader, model, criterion,
                                 (epoch + 1) * len(train_loader),
                                 log_training)

                # remember best prec@1 and save checkpoint
                is_best = prec1 > best_prec1
                best_prec1 = max(prec1, best_prec1)
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'best_prec1': best_prec1,
                    }, is_best)

    ###############################################################################
    # ALL LINES AFTER THIS REPRESENT NEW CODE WRITTEN TO TRAIN THE FEW-SHOT MODEL #
    ###############################################################################

    for _ in range(10):
        print("TRAINING FEW-SHOT MODEL")

    num_fs_class = 14  # number of few shot classes
    categories, args.train_list, args.val_list, args.root_path, prefix = datasets_video.return_dataset(
        'thumos-fs', args.modality)  # load few-shot dataset

    # modify the fully connected layers to fit our new task with 14 classes
    fs_model = model
    fs_model.module.consensus.classifier = nn.Sequential(
        nn.ReLU(), nn.Linear(in_features=768, out_features=512, bias=True),
        nn.ReLU(),
        nn.Linear(in_features=512, out_features=num_fs_class,
                  bias=True)).cuda()

    train_loader, val_loader = build_loaders(args.root_path, args.train_list,
                                             args.val_list, prefix)

    # NOTE: args.evaluate is always false here (phase 1 returns early when it
    # is set), so the original evaluate-only branch at this point was dead
    # code and has been dropped.

    # The policies collected before phase 1 still reference the old classifier
    # parameters and do not contain the new head, so an optimizer built from
    # them would never update it. Collect them again from the modified model.
    # NOTE(review): assumes get_optim_policies() accepts the new
    # Sequential/ReLU/Linear head - confirm against the TSN implementation.
    policies = fs_model.module.get_optim_policies()
    print_policies(policies)

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    args.store_name = '_'.join([
        'fs_TRN', args.dataset, args.modality, args.arch, args.consensus_type,
        'segment%d' % args.num_segments
    ])
    print('storing name: ' + args.store_name)

    best_prec1 = 0
    fs_log_path = os.path.join(args.root_log, '%s.csv' % "fs-logging")
    with open(fs_log_path, 'w') as log_fs_training:
        for epoch in range(args.start_epoch, args.epochs):
            torch.cuda.empty_cache()

            adjust_learning_rate(optimizer, epoch, args.lr_steps)

            # train for one epoch
            train(train_loader, fs_model, criterion, optimizer, epoch,
                  log_fs_training)

            # evaluate on validation set
            if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
                prec1 = validate(val_loader, fs_model, criterion,
                                 (epoch + 1) * len(train_loader),
                                 log_fs_training)

                # remember best prec@1 and save checkpoint
                is_best = prec1 > best_prec1
                best_prec1 = max(prec1, best_prec1)
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': fs_model.state_dict(),
                        'best_prec1': best_prec1,
                    },
                    is_best,
                    filename='fs_checkpoint.pth.tar')
Пример #10
0
def main():
    """Train a TSN model in PaddlePaddle dygraph mode on the first CUDA device.

    Parses the command line into the module-level ``args``, builds the model
    (optionally initialising every non-'fc' weight from
    ``args.pretrained_model``), then trains for ``args.epochs`` epochs.
    The model is validated every ``args.eval_freq`` epochs and after the last
    one; the weights are saved whenever validation Prec@1 improves, and the
    learning rate is adjusted after ``args.num_saturate`` consecutive
    validations without improvement.

    Raises:
        ValueError: if ``args.dataset`` is not ucf101, hmdb51 or kinetics.
    """
    place = fluid.CUDAPlace(0)
    with fluid.dygraph.guard(place):
        global args
        args = parser.parse_args()

        if args.dataset == 'ucf101':
            args.num_class = 101
        elif args.dataset == 'hmdb51':
            args.num_class = 51
        elif args.dataset == 'kinetics':
            args.num_class = 400
        else:
            raise ValueError('Unknown dataset '+args.dataset)

        model = TSN(args.num_class, args.num_segments,
                    args.modality, args.arch, dropout=args.dropout)

        # The readers take their preprocessing settings from args.
        args.short_size = model.scale_size
        args.target_size = model.crop_size
        args.input_mean = model.input_mean
        # NOTE(review): if input_std is a list this repeats it three times
        # rather than scaling it - confirm which is intended.
        args.input_std = model.input_std * 3

        if args.pretrained_parts == 'finetune':
            print('***Finetune model with {}***'.format(args.pretrained_model))

            state_dict = fluid.dygraph.load_dygraph(args.pretrained_model)[0]
            model_dict = model.state_dict()
            loaded_keys = set(state_dict.keys())
            model_keys = set(model_dict.keys())
            print('extra keys: {}'.format(loaded_keys - model_keys))
            print('missing keys: {}'.format(model_keys - loaded_keys))
            # Copy every pretrained entry except those whose key contains 'fc'.
            for k, v in state_dict.items():
                if 'fc' not in k:
                    model_dict.update({k:v})
            model.set_dict(model_dict)

        optimizer = fluid.optimizer.Momentum(
            args.lr, args.momentum, model.parameters(),
            regularization=fluid.regularizer.L2Decay(args.weight_decay),
            grad_clip=fluid.clip.GradientClipByGlobalNorm(args.clip_gradient))

        train_reader = KineticsReader('train', args, args.train_list).create_reader()
        val_reader = KineticsReader('val', args, args.val_list).create_reader()

        saturate_cnt = 0  # consecutive validations without improvement
        best_prec1 = 0
        log_file = os.path.join(args.log_path, args.save_name+'_train.csv')

        # The context manager closes the log even if training raises.
        with open(log_file, 'w') as log:
            for epoch in range(args.epochs):
                if saturate_cnt == args.num_saturate:
                    print('learning rate decay by 0.1.')
                    log.write('learning rate decay by 0.1. \n')
                    log.flush()
                    adjust_learing_rate(optimizer)
                    saturate_cnt = 0
                train(train_reader, model, optimizer, epoch, log)

                if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
                    prec1 = validate(val_reader, model, epoch, log)

                    is_best = prec1 > best_prec1
                    if is_best:
                        saturate_cnt = 0
                    else:
                        saturate_cnt = saturate_cnt + 1
                        output = "- Validation Prec@1 saturates for {} epochs. Best acc{}".format(saturate_cnt, best_prec1)
                        print(output)
                        log.write(output + '\n')
                        log.flush()
                    best_prec1 = max(prec1, best_prec1)

                    if is_best:
                        fluid.dygraph.save_dygraph(model.state_dict(), os.path.join(args.save_dir, args.save_name))
Пример #11
0
def main():
    """Train a TRN model and record the git state of the experiment.

    Parses the command line into the module-level ``args``, builds the model,
    data loaders, loss and SGD optimizer, and optionally restores a checkpoint
    from ``args.resume``. With ``args.evaluate`` it validates once and
    returns. Otherwise it creates a time-stamped directory under
    ``args.root_log/<exp_name>``, stores the current commit and working-tree
    diff in ``experiment_info.txt``, and runs the epoch loop, validating every
    ``args.eval_freq`` epochs (and after the last one) and saving a checkpoint
    after each validation.

    The experiment name comes from ``args.exp_name`` or, when that is empty,
    from a commit subject of the form ``experiment: <name>``.

    Rebinds the globals ``args`` and ``best_prec1``.

    Raises:
        ValueError: if the modality or the loss type is not handled below.
        SystemExit: with status 1 when no experiment name can be determined.
    """
    global args, best_prec1
    args = parser.parse_args()
    check_rootfolders()

    categories, args.train_list, args.val_list, args.root_path, prefix = datasets_video.return_dataset(
        args.dataset, args.modality)
    num_class = len(categories)
    print("num_class: " + str(num_class))

    args.store_name = '_'.join([
        'TRN', args.dataset, args.modality, args.arch, args.consensus_type,
        'segment%d' % args.num_segments
    ])
    print('storing name: ' + args.store_name)

    model = TSN(num_class,
                args.num_segments,
                args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn)

    # Read everything needed from the bare model before it is wrapped.
    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation()

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    if args.resume:
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            # Report the checkpoint path (the original printed args.evaluate).
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'Flow' or args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['RGBDiff']:
        data_length = 5
    else:
        # Without this branch data_length would be unbound further down.
        raise ValueError('Unknown modality ' + str(args.modality))

    inception_input = args.arch in ['BNInception', 'InceptionV3']

    train_dataset = TSNDataSet(
        args.root_path,
        args.train_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=prefix,
        transform=torchvision.transforms.Compose([
            train_augmentation,
            Stack(roll=inception_input),
            ToTorchFormatTensor(div=not inception_input),
            normalize,
        ]))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    print("Creating val_loader:")
    print("args.root_path: " + str(args.root_path))
    print("args.val_list: " + str(args.val_list))
    print("args.num_segments: " + str(args.num_segments))
    print("data_length: " + str(data_length))
    print("modality: " + str(args.modality))
    print("prefix: " + str(prefix))
    print("scale_size: " + str(int(scale_size)))
    print("crop_size: " + str(crop_size))
    print("args.arch: " + str(args.arch))
    print("args.batch_size: " + str(args.batch_size))
    print("args.workers: " + str(args.workers))
    print("")

    val_dataset = TSNDataSet(
        args.root_path,
        args.val_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=prefix,
        random_shift=False,
        transform=torchvision.transforms.Compose([
            GroupScale(int(scale_size)),
            GroupCenterCrop(crop_size),
            Stack(roll=inception_input),
            ToTorchFormatTensor(div=not inception_input),
            normalize,
        ]))
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    # Capture the current commit (hash, author, subject, body) and the
    # uncommitted diff so the run can be reproduced later.
    git_log_output = subprocess.run(
        [
            'git', 'log', '-n1',
            '--pretty=format:commit: %h%nauthor: %an%n%s%n%b'
        ],
        stdout=subprocess.PIPE).stdout.decode('utf-8').split('\n')
    git_diff_output = subprocess.run(
        ['git', 'diff'], stdout=subprocess.PIPE).stdout.decode('utf-8')

    if args.exp_name == '':
        # The commit subject is the third line of the format above; it is
        # missing when git printed nothing (e.g. outside a repository).
        subject = git_log_output[2] if len(git_log_output) > 2 else ''
        exp_name_match = re.match(r'experiment: *(.+)', subject)
        if exp_name_match is None:
            print(
                'Experiment name required:\n'
                '  current commit subject does not specify an experiment, and\n'
                '  --experiment_name was not specified')
            # This is an error exit, so report a non-zero status.
            sys.exit(1)
        args.exp_name = exp_name_match.group(1)
    print(f'experiment name: {args.exp_name}')

    # Named 'timestamp' so the stdlib module name 'time' is not shadowed.
    timestamp = str(datetime.now().strftime("%Y-%m-%d-%H-%M-%S"))
    exp_dir_path = os.path.join(args.root_log, args.exp_name, timestamp)
    log_file_path = os.path.join(exp_dir_path, f'{args.store_name}.csv')
    print("log_file_path:")
    print(log_file_path)
    os.makedirs(exp_dir_path)

    # The context manager guarantees the CSV log is flushed and closed.
    with open(log_file_path, 'w') as log_training:
        # store information about git status
        git_info_path = os.path.join(exp_dir_path, 'experiment_info.txt')
        with open(git_info_path, 'w') as f:
            f.write('\n'.join(git_log_output))
            f.write('\n\n' + ('=' * 80) + '\n')
            f.write(git_diff_output)

        for epoch in range(args.start_epoch, args.epochs):
            adjust_learning_rate(optimizer, epoch, args.lr_steps)

            # train for one epoch
            train(train_loader, model, criterion, optimizer, epoch,
                  log_training)

            # evaluate on validation set
            if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
                prec1 = validate(val_loader, model, criterion,
                                 (epoch + 1) * len(train_loader),
                                 log_training)

                # remember best prec@1 and save checkpoint
                is_best = prec1 > best_prec1
                best_prec1 = max(prec1, best_prec1)
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'best_prec1': best_prec1,
                    }, is_best, timestamp)
Пример #12
0
def main():
    """Train, or only evaluate, the TSN model described by the command line.

    Parses ``parser`` into the module-level ``args``, dumps the options to
    ``opts.json`` inside ``args.result_path``, builds the model, the train and
    validation loaders, the loss and the optimizer, and then either runs one
    validation pass (when ``args.evaluate`` is set) or the full training loop
    with periodic validation and checkpointing.

    Raises:
        ValueError: if ``args.modality`` or ``args.loss_type`` is unsupported,
            or if training is requested with an unsupported
            ``args.optimizer``.
    """
    global args, best_prec1, num_train_dataset, num_val_dataset, writer
    args = parser.parse_args()
    _fill_in_None_args()
    _join_result_path()
    check_rootfolders()
    # Keep a record of the exact options next to the results.
    with open(os.path.join(args.result_path, 'opts.json'), 'w') as opt_file:
        json.dump(vars(args), opt_file)

    (categories, args.train_list, args.val_list, args.root_path,
     prefix) = datasets_video.return_dataset(args.dataset, args.modality,
                                             args.root_path, args.file_type)
    num_class = len(categories)

    # NOTE(review): args.consensus_type appears twice in the name; kept as-is
    # because existing log/checkpoint file names depend on it.
    args.store_name = '_'.join([
        args.consensus_type,
        args.dataset,
        args.modality,
        args.arch,
        args.consensus_type,
        'segment%d' % args.num_segments,
        'key%d' % args.key_dim,
        'value%d' % args.value_dim,
        'query%d' % args.query_dim,
        'queryUpdateby%s' % args.query_update_method,
        'NoSoftmax%s' % args.no_softmax_on_p,
        'hopMethod%s' % args.hop_method,
    ])
    print('storing name: ' + args.store_name)

    model = TSN(
        num_class,
        args.num_segments,
        args.modality,
        base_model=args.arch,
        consensus_type=args.consensus_type,
        dropout=args.dropout,
        key_dim=args.key_dim,
        value_dim=args.value_dim,
        query_dim=args.query_dim,
        query_update_method=args.query_update_method,
        partial_bn=not args.no_partialbn,
        freezeBN_Eval=args.freezeBN_Eval,
        freezeBN_Require_Grad_True=args.freezeBN_Require_Grad_True,
        num_hop=args.hop,
        hop_method=args.hop_method,
        num_CNNs=args.num_CNNs,
        no_softmax_on_p=args.no_softmax_on_p,
        freezeBackbone=args.freezeBackbone,
        CustomPolicy=args.CustomPolicy,
        sorting=args.sorting,
        MultiStageLoss=args.MultiStageLoss,
        MultiStageLoss_MLP=args.MultiStageLoss_MLP,
        how_to_get_query=args.how_to_get_query,
        only_query=args.only_query,
        CC=args.CC,
        channel=args.channel,
        memory_dim=args.memory_dim,
        image_resolution=args.image_resolution,
        how_many_objects=args.how_many_objects,
        Each_Embedding=args.Each_Embedding,
        Curriculum=args.Curriculum,
        Curriculum_dim=args.Curriculum_dim,
        lr_steps=args.lr_steps,
    )

    # Read everything needed from the bare model before it is wrapped in
    # DataParallel below.
    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation()

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            # Report the checkpoint path (the original printed args.evaluate).
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5
    else:
        # Previously this fell through and failed later with a NameError.
        raise ValueError("Unknown modality: %s" % args.modality)

    inception_arch = args.arch in ['BNInception', 'InceptionV3']

    train_data = TSNDataSet(
        args.root_path,
        args.train_list,
        args.file_type,
        num_segments=args.num_segments,
        MoreAug_Rotation=args.MoreAug_Rotation,
        MoreAug_ColorJitter=args.MoreAug_ColorJitter,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=prefix,
        phase='train',
        transform1=torchvision.transforms.Compose([
            # Original author's note: GroupMultiScaleCrop[1, .875, .75, .66]
            # AND GroupRandomHorizontalFlip.
            train_augmentation,
        ]),
        transform2=torchvision.transforms.Compose([
            Stack(roll=inception_arch),
            ToTorchFormatTensor(div=not inception_arch),
            normalize,
        ]),
        image_resolution=args.image_resolution)
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=False,
                                               drop_last=True)

    val_data = TSNDataSet(
        args.root_path,
        args.val_list,
        args.file_type,
        num_segments=args.num_segments,
        MoreAug_Rotation=args.MoreAug_Rotation,
        MoreAug_ColorJitter=args.MoreAug_ColorJitter,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=prefix,
        random_shift=False,
        phase='test',
        transform1=torchvision.transforms.Compose(
            [GroupScale(int(scale_size)),
             GroupCenterCrop(crop_size)]),
        transform2=torchvision.transforms.Compose([
            Stack(roll=inception_arch),
            ToTorchFormatTensor(div=not inception_arch),
            normalize,
        ]),
        image_resolution=args.image_resolution)
    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=False,
                                             drop_last=True)
    num_train_dataset = len(train_data)
    num_val_dataset = len(val_data)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        # NOTE(review): `reduce=False` is deprecated in recent PyTorch in
        # favour of `reduction='none'`; left unchanged for compatibility.
        criterion = torch.nn.CrossEntropyLoss(reduce=False).cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult']))

    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(policies,
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(policies,
                                     lr=args.lr,
                                     weight_decay=args.weight_decay)
    else:
        # Only training needs an optimizer; rejected after the evaluate
        # branch so evaluation-only runs behave as before.
        optimizer = None

    if args.evaluate:
        json_file_path = os.path.join(
            args.result_path, 'results_epoch%d.json' % args.evaluation_epoch)
        validate(val_loader,
                 model,
                 criterion,
                 0,
                 json_file=json_file_path,
                 idx2class=categories,
                 epoch=args.evaluation_epoch)
        return

    if optimizer is None:
        # Previously this surfaced as a NameError inside the training loop.
        raise ValueError("Unknown optimizer: %s" % args.optimizer)

    writer = SummaryWriter(args.result_path)
    try:
        # The CSV log is opened in append mode, so reruns extend the file.
        with open(os.path.join(args.root_log, '%s.csv' % args.store_name),
                  'a') as log_training:
            for epoch in range(args.start_epoch, args.epochs):
                adjust_learning_rate(optimizer, epoch, args.lr_steps)

                # train for one epoch
                train(train_loader, model, criterion, optimizer, epoch,
                      log_training)

                # evaluate every eval_freq epochs and after the last epoch
                if ((epoch + 1) % args.eval_freq == 0
                        or epoch == args.epochs - 1):
                    json_file_path = os.path.join(
                        args.result_path,
                        'results_epoch%d.json' % (epoch + 1))
                    prec1 = validate(val_loader,
                                     model,
                                     criterion,
                                     (epoch + 1) * num_train_dataset,
                                     log=log_training,
                                     json_file=json_file_path,
                                     idx2class=categories,
                                     epoch=epoch)

                    # remember best prec@1 and save checkpoint
                    is_best = prec1 > best_prec1
                    best_prec1 = max(prec1, best_prec1)
                    save_checkpoint(
                        {
                            'epoch': epoch + 1,
                            'arch': args.arch,
                            'state_dict': model.state_dict(),
                            'best_prec1': best_prec1,
                        }, is_best)
    finally:
        # Flush and release the summary writer even if training raised.
        writer.close()
Пример #13
0
def main():
    """Train, or only evaluate, a two-class TSN model.

    Parses ``parser`` into the module-level ``args``, makes sure the
    per-modality record directory exists, builds the model, the data loaders,
    the loss and the SGD optimizer, and then either runs one validation pass
    (when ``args.evaluate`` is set) or trains for ``args.epochs`` epochs,
    validating and checkpointing after every epoch.  Whenever a new best
    prec@1 is reached, the predictions returned by ``validate`` are pickled
    into the record directory.

    Raises:
        ValueError: if ``args.modality`` is not RGB, Flow or RGBDiff.
    """
    global args, best_prec1
    args = parser.parse_args()

    # Plain string concatenation is kept on purpose: the resulting path
    # depends on whether args.record_path ends with a separator.
    record_dir = args.record_path + args.modality.lower()
    # exist_ok avoids the check-then-create race and also creates parents.
    os.makedirs(record_dir, exist_ok=True)

    num_class = 2

    model = TSN(num_class,
                args.num_segments,
                args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                partial_bn=not args.no_partialbn)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation()

    # This variant runs on a single device (no DataParallel wrapper).
    model = model.cuda()

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            # Report the checkpoint path (the original printed args.evaluate).
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5
    else:
        # Previously this fell through and failed later with a NameError.
        raise ValueError("Unknown modality: %s" % args.modality)

    if args.modality in ["RGB", "RGBDiff"]:
        image_tmpl = "frame_{:06d}.jpg"
    else:
        image_tmpl = "{:06d}.jpg"
    is_bninception = args.arch == 'BNInception'

    train_set = TSNDataSet(
        args.root_path,
        args.train_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=image_tmpl,
        transform=torchvision.transforms.Compose([
            train_augmentation,
            Stack(roll=is_bninception),
            ToTorchFormatTensor(div=not is_bninception),
            normalize,
        ]))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    val_set = TSNDataSet(
        args.root_path,
        args.val_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=image_tmpl,
        random_shift=False,
        transform=torchvision.transforms.Compose([
            GroupScale(int(scale_size)),
            GroupCenterCrop(crop_size),
            Stack(roll=is_bninception),
            ToTorchFormatTensor(div=not is_bninception),
            normalize,
        ]))
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss().cuda()

    for group in policies:
        print('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult']))

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    preds_path = (record_dir + '/' + args.snapshot_pref +
                  args.modality.lower() + '_video_preds.pickle')

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_steps)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1, pred_dict = validate(val_loader, model, criterion, epoch)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        if is_best:
            # Overwrite the stored predictions with those of the new best
            # epoch; the `with` block closes the file.
            with open(preds_path, 'wb') as f:
                pickle.dump(pred_dict, f)

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            }, is_best)
Пример #14
0
def main():
    """Fine-tune a pretrained TRN checkpoint as a two-output model.

    Parses ``parser`` into the module-level ``args``, builds a ``TSN`` with
    two outputs, initialises it from a fixed pretrained checkpoint (minus the
    listed ``consensus.fc_fusion_scales.*.3`` parameters), and trains it with
    a weighted ``BCEWithLogitsLoss``.  This variant has no validation pass;
    the model weights are saved at the start of every epoch.

    Raises:
        ValueError: if ``args.loss_type`` is not ``'nll'`` or
            ``args.modality`` is not RGB, Flow or RGBDiff.
        KeyError: if the pretrained checkpoint lacks one of the parameters
            that are dropped before loading.
    """
    global args, best_prec1
    args = parser.parse_args()
    check_rootfolders()

    (categories, args.train_list, args.val_list, args.root_path,
     prefix) = datasets_video.return_dataset(args.dataset, args.modality)
    # Not used below: the model is built with a fixed output size of 2.
    num_class = len(categories)

    args.store_name = '_'.join([
        'TRN', args.dataset, args.modality, args.arch, args.consensus_type,
        'segment%d' % args.num_segments
    ])
    print('storing name: ' + args.store_name)

    model = TSN(2, args.num_segments, args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn)

    checkpoint = torch.load(
        'pretrain/TRN_somethingv2_RGB_BNInception_TRNmultiscale_segment8_best.pth.tar',
        map_location='cpu')
    # Strip the first dotted component of every key (presumably the
    # DataParallel "module." prefix -- TODO confirm).
    base_dict = {
        '.'.join(k.split('.')[1:]): v
        for k, v in checkpoint['state_dict'].items()
    }
    # Drop layer 3 of each of the seven fusion scales so it is not loaded
    # (presumably the output layers, whose shape depends on the class count
    # -- TODO confirm).  Same keys, in the same order, as the original list.
    for param in ('bias', 'weight'):
        for scale in range(6, -1, -1):
            del base_dict['consensus.fc_fusion_scales.%d.3.%s' %
                          (scale, param)]
    # strict=False tolerates the parameters removed above.
    model.load_state_dict(base_dict, strict=False)
    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation()

    model = torch.nn.DataParallel(model).cuda()

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            # Report the checkpoint path (the original printed args.evaluate).
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5
    else:
        # Previously this fell through and failed later with a NameError.
        raise ValueError("Unknown modality: %s" % args.modality)

    inception_arch = args.arch in ['BNInception', 'InceptionV3']

    train_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.root_path, args.train_list,
                   num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl=prefix,
                   transform=torchvision.transforms.Compose([
                       train_augmentation,
                       Stack(roll=inception_arch),
                       ToTorchFormatTensor(div=not inception_arch),
                       normalize,
                   ])),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    # No validation loader is built in this variant; crop_size and scale_size
    # above are therefore unused.

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        # Despite the 'nll' option name, this variant uses a weighted
        # BCE-with-logits loss; output 0 gets weight 1.2.
        weight = torch.ones([2]).cuda()
        weight[0] = 1.2
        pos_weight = torch.ones([2]).cuda()
        criterion = torch.nn.BCEWithLogitsLoss(
            weight=weight, pos_weight=pos_weight).cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult']))

    # NOTE(review): the optimizer is constructed with a hard-coded learning
    # rate of 0.0001 rather than args.lr.
    optimizer = torch.optim.SGD(policies,
                                0.0001,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # The `with` block guarantees the CSV log is flushed and closed.
    with open(os.path.join(args.root_log, '%s.csv' % args.store_name),
              'w') as log_training:
        for epoch in range(args.start_epoch, args.epochs):
            adjust_learning_rate(optimizer, epoch, args.lr_steps)
            # Snapshot the weights before this epoch's training, under both
            # file names used by the original script.
            torch.save(model.state_dict(),
                       'checkpoint_bce_20_w12_{}.pth.tar'.format(epoch))
            torch.save(model.state_dict(),
                       'checkpoint_bce_20_w12_{}.pth'.format(epoch))
            # train for one epoch
            train(train_loader, model, criterion, optimizer, epoch,
                  log_training)
Пример #15
0
def main():
    """Parse the command line and train (or evaluate) a TRN model.

    Builds its own argument parser, forces ``args.consensus_type`` to
    ``"TRN"``, restricts the visible GPUs to ``--gpu`` and redirects the log,
    model and output roots to the directory reported by ``logger``.  It then
    builds the model, data loaders, loss and SGD optimizer, and either runs a
    single validation pass (``--evaluate``) or the training loop with
    periodic validation and checkpointing.

    At least two comma-separated ids must be passed via ``--gpu``; the
    default value ``'4'`` names a single device and therefore fails the
    assertion below.

    Raises:
        AssertionError: if fewer than two GPU ids are given.
        ValueError: if ``args.loss_type`` is not ``'nll'``.
    """
    logger.auto_set_dir()

    global args, best_prec1

    import argparse
    parser = argparse.ArgumentParser(description="PyTorch implementation of Temporal Segment Networks")
    parser.add_argument('--dataset', type=str, default="something", choices=['something', 'jester', 'moments'])
    parser.add_argument('--modality', type=str, default="RGB", choices=['RGB', 'Flow'])
    parser.add_argument('--train_list', type=str, default="")
    parser.add_argument('--val_list', type=str, default="")
    parser.add_argument('--root_path', type=str, default="")
    parser.add_argument('--store_name', type=str, default="")
    # ========================= Model Configs ==========================
    parser.add_argument('--arch', type=str, default="BNInception")
    parser.add_argument('--num_segments', type=int, default=3)
    parser.add_argument('--consensus_type', type=str, default='avg')
    parser.add_argument('--k', type=int, default=3)

    # Help texts below state the defaults actually configured here.
    parser.add_argument('--dropout', '--do', default=0.8, type=float,
                        metavar='DO', help='dropout ratio (default: 0.8)')
    parser.add_argument('--loss_type', type=str, default="nll",
                        choices=['nll'])
    parser.add_argument('--img_feature_dim', default=256, type=int, help="the feature dimension for each frame")

    # ========================= Learning Configs ==========================
    parser.add_argument('--epochs', default=120, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-b', '--batch_size', default=128, type=int,
                        metavar='N', help='mini-batch size (default: 128)')
    parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,
                        metavar='LR', help='initial learning rate')
    parser.add_argument('--lr_steps', default=[50, 100], type=float, nargs="+",
                        metavar='LRSteps', help='epochs to decay learning rate by 10')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float,
                        metavar='W', help='weight decay (default: 5e-4)')
    parser.add_argument('--clip-gradient', '--gd', default=20, type=float,
                        metavar='W', help='gradient norm clipping (default: 20)')
    parser.add_argument('--no_partialbn', '--npb', default=False, action="store_true")

    # ========================= Monitor Configs ==========================
    parser.add_argument('--print-freq', '-p', default=20, type=int,
                        metavar='N', help='print frequency (default: 20)')
    parser.add_argument('--eval-freq', '-ef', default=5, type=int,
                        metavar='N', help='evaluation frequency (default: 5)')

    # ========================= Runtime Configs ==========================
    parser.add_argument('-j', '--workers', default=30, type=int, metavar='N',
                        help='number of data loading workers (default: 30)')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                        help='evaluate model on validation set')
    parser.add_argument('--snapshot_pref', type=str, default="")
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('--gpu', type=str, default='4')
    parser.add_argument('--flow_prefix', default="", type=str)
    parser.add_argument('--root_log', type=str, default='log')
    parser.add_argument('--root_model', type=str, default='model')
    parser.add_argument('--root_output', type=str, default='output')

    args = parser.parse_args()

    # The consensus type given on the command line is always overridden.
    args.consensus_type = "TRN"
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    device_ids = [int(gpu_id) for gpu_id in args.gpu.split(',')]
    assert len(device_ids) >1, "TRN must run with GPU_num > 1"

    # All outputs go to the logger's directory, whatever --root_* said.
    args.root_log = logger.get_logger_dir()
    args.root_model = logger.get_logger_dir()
    args.root_output = logger.get_logger_dir()

    (categories, args.train_list, args.val_list, args.root_path,
     prefix) = datasets_video.return_dataset(args.dataset, args.modality)
    num_class = len(categories)

    args.store_name = '_'.join([
        'TRN', args.dataset, args.modality, args.arch, args.consensus_type,
        'segment%d' % args.num_segments
    ])
    print('storing name: ' + args.store_name)

    model = TSN(num_class, args.num_segments, args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn)

    # Read everything needed from the bare model before it may be wrapped
    # in DataParallel below.
    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation()

    if torch.cuda.device_count() > 1:
        # TODO (from the original): pass device_ids explicitly.
        model = torch.nn.DataParallel(model)

    if torch.cuda.is_available():
        model.cuda()

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            # Report the checkpoint path (the original printed args.evaluate).
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    # --modality is limited to RGB/Flow by the parser, so one branch always
    # assigns data_length.
    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5

    inception_arch = args.arch in ['BNInception', 'InceptionV3']

    train_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.root_path, args.train_list,
                   num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl=prefix,
                   transform=torchvision.transforms.Compose([
                       train_augmentation,
                       Stack(roll=inception_arch),
                       ToTorchFormatTensor(div=not inception_arch),
                       normalize,
                   ])),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.root_path, args.val_list,
                   num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl=prefix,
                   random_shift=False,
                   transform=torchvision.transforms.Compose([
                       GroupScale(int(scale_size)),
                       GroupCenterCrop(crop_size),
                       Stack(roll=inception_arch),
                       ToTorchFormatTensor(div=not inception_arch),
                       normalize,
                   ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        logger.info('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'], group['decay_mult']))

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    # The `with` block guarantees the CSV log is flushed and closed.
    with open(os.path.join(args.root_log, '%s.csv' % args.store_name),
              'w') as log_training:
        for epoch in range(args.start_epoch, args.epochs):
            adjust_learning_rate(optimizer, epoch, args.lr_steps)

            # train for one epoch
            train(train_loader, model, criterion, optimizer, epoch,
                  log_training)

            # evaluate every eval_freq epochs and after the last epoch
            if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
                prec1 = validate(val_loader, model, criterion,
                                 (epoch + 1) * len(train_loader),
                                 log_training)

                # remember best prec@1 and save checkpoint
                is_best = prec1 > best_prec1
                best_prec1 = max(prec1, best_prec1)
                save_checkpoint({
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)
Пример #16
0
def main():
    """Train a TSN model with a BCE-with-logits objective and track best mAP.

    Parses the command line into the module-level ``args``, builds the
    network, the train/validation loaders and an SGD optimizer, then runs the
    epoch loop.  Whenever the validation mAP improves, a checkpoint and the
    validation output matrix are saved; every validation round is appended to
    ``args.record_path``.

    Raises:
        ValueError: if ``args.dataset`` or ``args.modality`` is not supported.
    """
    global args, best_prec1
    args = parser.parse_args()

    class_counts = {'ucf101': 101, 'hmdb51': 51, 'kinetics': 400, 'movie': 21}
    if args.dataset not in class_counts:
        raise ValueError('Unknown dataset ' + args.dataset)
    num_class = class_counts[args.dataset]

    model = TSN(num_class,
                args.num_segments,
                args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                partial_bn=not args.no_partialbn)

    # Read everything needed from the bare model before it is wrapped in
    # DataParallel below.
    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation()

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # 'best_prec1' is not restored: the checkpoints written by this
            # function do not contain that key.
            model.load_state_dict(checkpoint['state_dict'])
            # Report the checkpoint path (the original printed args.evaluate).
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5
    else:
        # Fail here with a clear message instead of a NameError further down.
        raise ValueError('Unknown modality ' + args.modality)

    if args.modality in ["RGB", "RGBDiff"]:
        image_tmpl = "frame_{:04d}.jpg"
    else:
        image_tmpl = args.flow_prefix + "{}_{:05d}.jpg"
    is_bninception = args.arch == 'BNInception'

    train_dataset = TSNDataSetMovie(
        "",
        args.train_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=image_tmpl,
        transform=torchvision.transforms.Compose([
            train_augmentation,
            Stack(roll=is_bninception),
            ToTorchFormatTensor(div=not is_bninception),
            normalize,
        ]))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    val_dataset = TSNDataSetMovie(
        "",
        args.val_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=image_tmpl,
        random_shift=False,
        transform=torchvision.transforms.Compose([
            GroupScale(int(scale_size)),
            GroupCenterCrop(crop_size),
            Stack(roll=is_bninception),
            ToTorchFormatTensor(div=not is_bninception),
            normalize,
        ]))
    # Validation runs with half the training batch size; clamp to 1 because
    # DataLoader rejects a batch size of 0 (which batch_size == 1 would give).
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=max(
                                                 1, args.batch_size // 2),
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # define loss function (criterion) and optimizer
    criterion = torch.nn.BCEWithLogitsLoss().cuda()
    for group in policies:
        print('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult']))

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    zero_time = time.time()
    best_map = 0
    print('Start training...')
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_steps)

        # train for one epoch
        start_time = time.time()
        trainloss = train(train_loader, model, criterion, optimizer, epoch)
        print('Traing loss %4f Epoch %d' % (trainloss, epoch))

        is_eval_epoch = ((epoch + 1) % args.eval_freq == 0
                         or epoch == args.epochs - 1)
        if not is_eval_epoch:
            continue

        # evaluate on validation set
        valloss, mAP, wAP, output_mtx = validate(val_loader, model, criterion)
        end_time = time.time()
        epoch_time = end_time - start_time
        total_time = end_time - zero_time
        print('Total time used: %s Epoch %d time uesd: %s' %
              (str(datetime.timedelta(seconds=int(total_time))), epoch,
               str(datetime.timedelta(seconds=int(epoch_time)))))
        print(
            'Train loss: {0:.4f} val loss: {1:.4f} mAP: {2:.4f} wAP: {3:.4f}'
            .format(trainloss, valloss, mAP, wAP))

        # Checkpoint and dump the validation outputs only on a new best mAP.
        is_best = mAP > best_map
        if is_best:
            best_map = mAP
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, is_best, epoch)
            npy_name = str(epoch) + args.result_path
            np.save(npy_name, output_mtx)

        with open(args.record_path, 'a') as record_file:
            record_file.write(
                'Epoch:[{0}]'
                'Train loss: {1:.4f} val loss: {2:.4f} map: {3:.4f}\n'.
                format(epoch + 1, trainloss, valloss, mAP))

    print('************ Done!... ************')
Пример #17
0
def main():
    """Train/validate or test a TSN model, depending on ``args.run_for``.

    ``'train'`` builds the train/val loaders, then either runs a single
    validation pass (``args.evaluate``) or the full epoch loop with
    checkpointing of the best prec@1.  ``'test'`` loads the weights from
    ``args.root_weights`` and runs ``test`` on the test list.

    Reads the module-level ``args`` and updates the module-level
    ``best_prec1``.

    Raises:
        ValueError: if ``args.run_for``, ``args.modality`` or
            ``args.loss_type`` is not supported.
    """
    global best_prec1
    check_rootfolders()

    if args.run_for == 'train':
        (categories, args.train_list, args.val_list, args.root_path,
         prefix) = datasets_video.return_dataset(args.dataset, args.modality)
    elif args.run_for == 'test':
        (categories, args.test_list, args.root_path,
         prefix) = datasets_video.return_data(args.dataset, args.modality)
    else:
        # Previously this fell through and died with a NameError on
        # 'categories'.
        raise ValueError("Unknown run_for '{}'".format(args.run_for))

    num_class = len(categories)

    args.store_name = '_'.join([
        'STModeling', args.dataset, args.modality, args.arch,
        args.consensus_type,
        'segment%d' % args.num_segments
    ])
    print('storing name: ' + args.store_name)

    model = TSN(num_class, args)

    # Read everything needed from the bare model before it is wrapped in
    # DataParallel below.
    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    train_augmentation = model.get_augmentation()

    policies = model.get_optim_policies()
    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            # Report the checkpoint path (the original printed args.evaluate).
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    # NOTE(review): the original guard was
    #   (modality != 'RGBDiff') | (modality != 'RGBFlow')
    # which is true for every modality, so GroupNormalize was always used and
    # the IdentityTransform branch was unreachable.  That runtime behaviour is
    # kept here.  The intent was probably
    #   modality not in ('RGBDiff', 'RGBFlow')
    # -- confirm against the model before changing it, since it alters the
    # inputs seen by existing checkpoints.
    normalize = GroupNormalize(input_mean, input_std)

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5
    elif args.modality == 'RGBFlow':
        data_length = args.num_motion
    else:
        # Fail here with a clear message instead of a NameError further down.
        raise ValueError("Unknown modality '{}'".format(args.modality))

    # NOTE(review): the dataset root is hard-coded although args.root_path is
    # assigned above -- confirm which one is meant to be used.
    data_root = "/home/machine/PROJECTS/OTHER/DATASETS/kussaster/data"
    inception_like = args.arch in ['BNInception', 'InceptionV3']

    def tensor_stages():
        """Return the stack / to-tensor / normalize tail shared by all loaders."""
        return [
            Stack(roll=inception_like,
                  isRGBFlow=(args.modality == 'RGBFlow')),
            ToTorchFormatTensor(div=not inception_like),
            normalize,
        ]

    def eval_stages():
        """Return the deterministic scale + center-crop head for val/test."""
        return [GroupScale(int(scale_size)), GroupCenterCrop(crop_size)]

    def make_loader(list_file, leading_stages, shuffle, **dataset_extra):
        """Build a DataLoader over ``list_file`` with the given transform head."""
        dataset = TSNDataSet(
            data_root,
            list_file,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            image_tmpl=prefix,
            dataset=args.dataset,
            transform=torchvision.transforms.Compose(leading_stages +
                                                     tensor_stages()),
            **dataset_extra)
        return torch.utils.data.DataLoader(dataset,
                                           batch_size=args.batch_size,
                                           shuffle=shuffle,
                                           num_workers=args.workers,
                                           pin_memory=False)

    if args.run_for == 'train':
        train_loader = make_loader(args.train_list, [train_augmentation],
                                   shuffle=True)
        val_loader = make_loader(args.val_list,
                                 eval_stages(),
                                 shuffle=False,
                                 random_shift=False)

        # define loss function (criterion) and optimizer
        if args.loss_type == 'nll':
            criterion = torch.nn.CrossEntropyLoss().cuda()
        else:
            raise ValueError("Unknown loss type")

        for group in policies:
            print('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
                group['name'], len(group['params']), group['lr_mult'],
                group['decay_mult']))

        use_dndf = args.consensus_type == 'DNDF'
        optimizer = torch.optim.SGD(policies,
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        if use_dndf:
            # The DNDF consensus swaps SGD for Adam over all trainable
            # parameters and skips the step-wise LR adjustment below.
            params = [p for p in model.parameters() if p.requires_grad]
            optimizer = torch.optim.Adam(params, lr=args.lr, weight_decay=1e-5)

        if args.evaluate:
            validate(val_loader, model, criterion, 0)
            return

        log_path = os.path.join(args.root_log, '%s.csv' % args.store_name)
        # The context manager closes the log even if training raises; the
        # original never closed it.
        with open(log_path, 'w') as log_training:
            for epoch in range(args.start_epoch, args.epochs):
                if not use_dndf:
                    adjust_learning_rate(optimizer, epoch, args.lr_steps)

                # train for one epoch
                train(train_loader, model, criterion, optimizer, epoch,
                      log_training)

                # evaluate on validation set
                if ((epoch + 1) % args.eval_freq == 0
                        or epoch == args.epochs - 1):
                    prec1 = validate(val_loader, model, criterion,
                                     (epoch + 1) * len(train_loader),
                                     log_training)

                    # remember best prec@1 and save checkpoint
                    is_best = prec1 > best_prec1
                    best_prec1 = max(prec1, best_prec1)
                    save_checkpoint(
                        {
                            'epoch': epoch + 1,
                            'arch': args.arch,
                            'state_dict': model.state_dict(),
                            'best_prec1': best_prec1,
                        }, is_best)

    else:  # args.run_for == 'test' (validated at the top of the function)
        print("=> loading checkpoint '{}'".format(args.root_weights))
        checkpoint = torch.load(args.root_weights)
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'])
        model.cuda().eval()
        print("=> loaded checkpoint ")

        test_loader = make_loader(args.test_list,
                                  eval_stages(),
                                  shuffle=False,
                                  random_shift=False)

        test(test_loader, model, categories)
Пример #18
0
def main():
    """Train and validate a TSN model with an Adam optimizer.

    Parses the command line into the module-level ``args``, resolves the
    dataset lists through ``datasets_video.return_dataset``, builds the
    network and its loaders, then either runs a single validation pass
    (``args.evaluate``) or the epoch loop, checkpointing after each
    validation and tracking the best prec@1 in the module-level
    ``best_prec1``.

    Raises:
        ValueError: if ``args.modality`` or ``args.loss_type`` is not
            supported.
    """
    global args, best_prec1
    args = parser.parse_args()
    print("args args args")
    print(args)
    check_rootfolders()

    (categories, args.train_list, args.val_list, args.root_path,
     prefix) = datasets_video.return_dataset(args.dataset, args.modality)
    num_class = len(categories)

    args.store_name = '_'.join([
        'TRN', args.dataset, args.modality, args.arch, args.consensus_type,
        'segment%d' % args.num_segments
    ])
    print('storing name: ' + args.store_name)

    model = TSN(num_class,
                args.num_segments,
                args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn)

    # Read everything needed from the bare model before it is wrapped in
    # DataParallel below.
    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation()

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            # Report the checkpoint path (the original printed args.evaluate).
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5
    else:
        # Fail here with a clear message instead of a NameError further down.
        raise ValueError('Unknown modality ' + args.modality)

    inception_like = args.arch in ['BNInception', 'InceptionV3']

    train_dataset = TSNDataSet(
        args.root_path,
        args.train_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=prefix,
        transform=torchvision.transforms.Compose([
            train_augmentation,
            Stack(roll=inception_like),
            ToTorchFormatTensor(div=not inception_like),
            normalize,
        ]))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    val_dataset = TSNDataSet(
        args.root_path,
        args.val_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=prefix,
        random_shift=False,
        transform=torchvision.transforms.Compose([
            GroupScale(int(scale_size)),
            GroupCenterCrop(crop_size),
            Stack(roll=inception_like),
            ToTorchFormatTensor(div=not inception_like),
            normalize,
        ]))
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult']))

    # Adam is used here in place of SGD with momentum, so args.momentum is
    # not read by this function.
    optimizer = torch.optim.Adam(policies,
                                 lr=args.lr,
                                 betas=(0.9, 0.999),
                                 eps=1e-08,
                                 weight_decay=args.weight_decay)

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    log_path = os.path.join(args.root_log, '%s_adam.csv' % args.store_name)
    # The context manager closes the log even if training raises; the
    # original never closed it.
    with open(log_path, 'w') as log_training:
        for epoch in range(args.start_epoch, args.epochs):
            adjust_learning_rate(optimizer, epoch, args.lr_steps)

            # train for one epoch
            train(train_loader, model, criterion, optimizer, epoch,
                  log_training)

            # evaluate on validation set
            if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
                prec1 = validate(val_loader, model, criterion,
                                 (epoch + 1) * len(train_loader),
                                 log_training)

                # remember best prec@1 and save checkpoint
                is_best = prec1 > best_prec1
                best_prec1 = max(prec1, best_prec1)
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'best_prec1': best_prec1,
                    }, is_best)
Пример #19
0
def main():
    """Train and validate a TSN model over temporal x spatial segments.

    Parses the command line into the module-level ``args``, builds a TSN with
    ``args.num_segments * args.num_spacial_segments`` segments and
    ``args.num_class`` classes, creates the train/val loaders, then either
    runs a single validation pass (``args.evaluate``) or the epoch loop,
    checkpointing after each validation and tracking the best prec@1 in the
    module-level ``best_prec1``.

    Raises:
        ValueError: if ``args.modality`` or ``args.loss_type`` is not
            supported.
    """
    global args, best_prec1
    args = parser.parse_args()

    # The class count comes straight from the command line rather than from
    # a per-dataset table.
    num_class = args.num_class

    model = TSN(num_class,
                args.num_segments * args.num_spacial_segments,
                args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                partial_bn=not args.no_partialbn)

    # Read everything needed from the bare model before it is wrapped in
    # DataParallel below.
    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation()

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            # lz: comment out the next two assignments when fine-tuning
            # (as opposed to resuming).
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            # Report the checkpoint path (the original printed args.evaluate).
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5
    else:
        # Fail here with a clear message instead of a NameError further down.
        raise ValueError('Unknown modality ' + args.modality)

    if args.modality in ["RGB", "RGBDiff"]:
        image_tmpl = "-{:04d}.jpg"
    else:
        image_tmpl = args.flow_prefix + "_{}-{:04d}.jpg"
    is_bninception = args.arch == 'BNInception'

    # Other crop strategies that were tried in place of the GroupMbyN* crops
    # below: GroupCrop, GroupRandomCrop, ImgRandomCrop, GroupNRandomCrop
    # (r4TSSN), GroupMbyNCrop(4, 4, ...) (4x4TSSN) and GroupColorJitter.
    train_dataset = TSNDataSet(
        args.root_path,
        args.train_list,
        train_val_test='train',
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=image_tmpl,
        transform=torchvision.transforms.Compose([
            GroupMbyNRandomCrop(1, 2, 2, 720, 1280),  # one out of 2x2TSSN
            ImgColorJitter(brightness=0.4, contrast=0, saturation=0, hue=0),
            train_augmentation,
            Stack(roll=is_bninception),
            ToTorchFormatTensor(div=not is_bninception),
            normalize,
        ]))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    # NOTE(review): the validation set is also built from args.train_list;
    # presumably TSNDataSet picks the split via train_val_test -- confirm.
    val_dataset = TSNDataSet(
        args.root_path,
        args.train_list,
        train_val_test='val',
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=image_tmpl,
        random_shift=False,
        transform=torchvision.transforms.Compose([
            GroupMbyNCrop(2, 2, 720, 1280),  # 2x2TSSN
            GroupScale(int(scale_size)),
            GroupCenterCrop(crop_size),
            Stack(roll=is_bninception),
            ToTorchFormatTensor(div=not is_bninception),
            normalize,
        ]))
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        # This criterion expects a class index (0 to C-1)
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult']))

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_steps)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1 = validate(val_loader, model, criterion,
                             (epoch + 1) * len(train_loader))

            # remember best prec@1 and save checkpoint
            is_best = prec1 >= best_prec1  # need = to avoid lucky 100%
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)
Пример #20
0
def main():
    """Train and validate a TSN model on the module-level ``device``.

    Parses the command line into the module-level ``args``, builds the
    network (wrapped in DataParallel only when ``device`` is CUDA), the
    train/val loaders, an SGD optimizer and a MultiStepLR scheduler, then
    either runs a single validation pass (``args.evaluate``) or the epoch
    loop, checkpointing after each validation and tracking the best prec@1
    in the module-level ``best_prec1``.  The module-level ``writer`` is
    closed when training finishes.

    Raises:
        ValueError: if ``args.dataset`` is not supported.
    """
    global args, best_prec1
    args = parser.parse_args()

    class_counts = {'ucf101': 101, 'hmdb51': 51, 'kinetics': 400}
    if args.dataset not in class_counts:
        raise ValueError('Unknown dataset ' + args.dataset)
    num_class = class_counts[args.dataset]

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5
    else:
        data_length = 5  # generate 5 displacement map, using 6 RGB images

    model = TSN(num_class,
                args.num_segments,
                args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                new_length=data_length)
    model = model.to(device)

    # Read everything needed from the bare model before it may be wrapped in
    # DataParallel below.
    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    train_augmentation = model.get_augmentation()
    if device.type == 'cuda':
        model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # Remap stored tensors onto the active device so a checkpoint
            # saved on a GPU can also be loaded when running on CPU.
            checkpoint = torch.load(args.resume, map_location=device)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'], strict=True)
            # Report the checkpoint path (the original printed args.evaluate).
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality in ["RGB", "RGBDiff", "CV"]:
        image_tmpl = "img_{:05d}.jpg"
    else:
        image_tmpl = args.flow_prefix + "{}_{:05d}.jpg"
    is_bninception = args.arch == 'BNInception'

    train_dataset = TSNDataSet(
        "",
        args.train_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=image_tmpl,
        transform=torchvision.transforms.Compose([
            train_augmentation,
            Stack(roll=is_bninception),
            ToTorchFormatTensor(div=not is_bninception),
            normalize,
        ]))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    val_dataset = TSNDataSet(
        "",
        args.val_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=image_tmpl,
        random_shift=False,
        transform=torchvision.transforms.Compose([
            GroupScale(int(scale_size)),
            GroupCenterCrop(crop_size),
            Stack(roll=is_bninception),
            ToTorchFormatTensor(div=not is_bninception),
            normalize,
        ]))
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss().to(device)

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     args.lr_steps,
                                                     gamma=0.1)

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    for epoch in range(0, args.epochs):
        # The loop always starts at 0 and steps the scheduler for skipped
        # epochs too, so a resumed run reaches the same point in the LR
        # schedule.
        # NOTE(review): the scheduler is stepped before the epoch's optimizer
        # steps; whether that matches the intended schedule depends on the
        # PyTorch version in use -- confirm.
        scheduler.step()
        if epoch < args.start_epoch:
            continue

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1 = validate(val_loader, model, criterion, epoch)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)
    writer.close()
Пример #21
0
def main():
    """Train, or only evaluate, a TSN-wrapped ECO / C3D network.

    Parses the command line into the module-level ``args``, builds the model
    and both data loaders, fills the weights through the arch-specific
    ``init_*`` helper (initialising whatever that helper leaves out),
    optionally resumes from ``args.resume``, and then either runs a single
    validation pass (``args.evaluate``) or the training loop with periodic
    validation and checkpointing.

    Raises:
        ValueError: if the dataset, architecture, modality or loss type is
            not one of the values handled below.
    """
    global args, best_prec1
    args = parser.parse_args()

    print("------------------------------------")
    print("Environment Versions:")
    print("- Python: {}".format(sys.version))
    print("- PyTorch: {}".format(torch.__version__))
    print("- TorchVison: {}".format(torchvision.__version__))

    print("------------------------------------")
    print(args.arch + " Configurations:")
    for key, value in args.__dict__.items():
        print("- {}: {}".format(key, value))
    print("------------------------------------")

    # Every dataset branch must define both the class count and the frame
    # file-name template; the loaders below read rgb_read_format.
    if args.dataset == 'ucf101':
        num_class = 101
        rgb_read_format = "{:05d}.jpg"
    elif args.dataset == 'hmdb51':
        num_class = 51
        # Previously missing, which made the loaders fail with a NameError.
        # NOTE(review): assumes the same 5-digit naming as ucf101 - confirm
        # against the extracted hmdb51 frames.
        rgb_read_format = "{:05d}.jpg"
    elif args.dataset == 'kinetics':
        num_class = 400
        rgb_read_format = "{:05d}.jpg"
    elif args.dataset == 'something':
        num_class = 174
        rgb_read_format = "{:04d}.jpg"
    else:
        raise ValueError('Unknown dataset ' + args.dataset)

    model = TSN(num_class,
                args.num_segments,
                args.pretrained_parts,
                args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                partial_bn=not args.no_partialbn)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std

    # Optimizers accept an iterable of dicts, one per parameter group; each
    # dict holds a "params" list plus per-group optimizer options.
    policies = model.get_optim_policies()

    train_augmentation = model.get_augmentation()

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    model_dict = model.state_dict()

    print("pretrained_parts: ", args.pretrained_parts)

    # Exactly one initialiser applies; fail clearly instead of hitting an
    # unbound new_state_dict a few lines further down.
    if args.arch == "ECO":
        new_state_dict = init_ECO(model_dict)
    elif args.arch == "ECOfull":
        new_state_dict = init_ECOfull(model_dict)
    elif args.arch == "C3DRes18":
        new_state_dict = init_C3DRes18(model_dict)
    else:
        raise ValueError('Unknown arch ' + args.arch)

    un_init_dict_keys = [k for k in model_dict if k not in new_state_dict]
    print("un_init_dict_keys: ", un_init_dict_keys)
    print("\n------------------------------------")

    # Give every entry the initialiser did not provide a default value:
    # BN weights -> 1, other weights -> Xavier uniform, biases -> 0,
    # anything else stays zero.
    for k in un_init_dict_keys:
        new_state_dict[k] = torch.DoubleTensor(model_dict[k].size()).zero_()
        if 'weight' in k:
            if 'bn' in k:
                print("{} init as: 1".format(k))
                constant_(new_state_dict[k], 1)
            else:
                print("{} init as: xavier".format(k))
                xavier_uniform_(new_state_dict[k])
        elif 'bias' in k:
            print("{} init as: 0".format(k))
            constant_(new_state_dict[k], 0)

    print("------------------------------------")

    model.load_state_dict(new_state_dict)

    if args.resume:
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5
    else:
        raise ValueError('Unknown modality ' + args.modality)

    # Both loaders read frames with the same file-name template.
    if args.modality in ["RGB", "RGBDiff"]:
        image_tmpl = args.rgb_prefix + rgb_read_format
    else:
        image_tmpl = args.flow_prefix + rgb_read_format

    train_dataset = TSNDataSet(
        "",
        args.train_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=image_tmpl,
        transform=torchvision.transforms.Compose([
            train_augmentation,
            Stack(roll=True),
            ToTorchFormatTensor(div=False),
            normalize,
        ]))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    val_dataset = TSNDataSet(
        "",
        args.val_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=image_tmpl,
        random_shift=False,
        transform=torchvision.transforms.Compose([
            GroupScale(int(scale_size)),
            GroupCenterCrop(crop_size),
            Stack(roll=True),
            ToTorchFormatTensor(div=False),
            normalize,
        ]))
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=args.nesterov)

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_steps)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate every eval_freq epochs and always after the last epoch
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1 = validate(val_loader, model, criterion,
                             (epoch + 1) * len(train_loader))

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)
Пример #22
0
import torch
import paddle.fluid as fluid
from collections import OrderedDict
from models import TSN

# Build the Paddle model on CPU only to obtain the parameter names it expects.
place = fluid.CPUPlace()
with fluid.dygraph.guard(place):
    model = TSN(101, 16, 'finetune', 'RGB', 'ECOfull')
    model_dict = model.state_dict()

load_path = ''
save_path = ''
# The conversion runs on CPU, so map the checkpoint there regardless of the
# device it was saved from.
torch_weight = torch.load(load_path, map_location='cpu')
torch_dict = torch_weight['state_dict']
paddle_dict = OrderedDict()

# Prefix that torch.nn.DataParallel adds to every parameter name.
_DP_PREFIX = 'module.'

for k, v in torch_dict.items():
    # Strip the DataParallel prefix only when it is really there; slicing
    # unconditionally would corrupt keys saved from an unwrapped model.
    if k.startswith(_DP_PREFIX):
        k = k[len(_DP_PREFIX):]
    # Skip entries whose name contains 'tracked' or 'fc'; they are not
    # carried over to the Paddle dict.
    if 'tracked' in k or 'fc' in k:
        continue
    # Rename the batch-norm running statistics to the names used on the
    # Paddle side.
    if 'running_mean' in k:
        k = k.replace('running_mean', '_mean')
    if 'running_var' in k:
        k = k.replace('running_var', '_variance')
    # .cpu() keeps the numpy conversion valid even for tensors stored on GPU.
    paddle_dict[k] = v.detach().cpu().numpy()

keys1 = set(paddle_dict)
keys2 = set(model_dict)

# Report the name mismatch between the converted dict and the Paddle model.
print('extra keys: {}'.format(keys1 - keys2))
print('missing keys: {}'.format(keys2 - keys1))
Пример #23
0
def main():
    """Train, or only evaluate, a TSN model and plot the training history.

    Parses the command line into the module-level ``args``, resolves the
    dataset lists through ``datasets_video.return_dataset``, builds the model,
    loaders and optimizer, optionally resumes from ``args.resume``, and then
    either validates once (``args.evaluate``) or trains with periodic
    validation, checkpointing and a per-epoch history that is handed to
    ``plot_utils.plot_statistics`` at the end.

    Raises:
        ValueError: if the modality or loss type is not handled below.
    """
    global args, best_prec1
    args = parser.parse_args()
    check_rootfolders()

    categories, args.train_list, args.val_list, args.root_path, prefix = datasets_video.return_dataset(args.dataset,
                                                                                                       args.modality)
    num_class = len(categories)

    args.store_name = '_'.join(['STModeling', args.dataset, args.modality, args.arch, args.consensus_type,
                                'segment%d' % args.num_segments])
    print('storing name: ' + args.store_name)

    model = TSN(num_class, args)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    train_augmentation = model.get_augmentation()

    policies = model.get_optim_policies()
    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    if args.resume:
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            # Report the checkpoint path; the original printed args.evaluate.
            print(("=> loaded checkpoint '{}' (epoch {})"
                   .format(args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    print(model)
    cudnn.benchmark = True

    # Data loading code
    # NOTE(review): this condition is always True (a value cannot equal both
    # 'RGBDiff' and 'RGBFlow'), so IdentityTransform is never selected. It is
    # kept as written because changing it alters preprocessing for models
    # already trained with this script - confirm the intended behaviour.
    if ((args.modality != 'RGBDiff') | (args.modality != 'RGBFlow')):
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5
    elif args.modality == 'RGBFlow':
        data_length = args.num_motion
    else:
        raise ValueError('Unknown modality ' + args.modality)

    # Stack/ToTorchFormatTensor settings depend only on the backbone family.
    is_inception = args.arch in ['BNInception', 'InceptionV3']
    is_rgb_flow = args.modality == 'RGBFlow'

    train_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.root_path, args.train_list, num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl=prefix,
                   dataset=args.dataset,
                   transform=torchvision.transforms.Compose([
                       train_augmentation,
                       Stack(roll=is_inception, isRGBFlow=is_rgb_flow),
                       ToTorchFormatTensor(div=not is_inception),
                       normalize,
                   ])),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=False)

    val_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.root_path, args.val_list, num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl=prefix,
                   dataset=args.dataset,
                   random_shift=False,
                   transform=torchvision.transforms.Compose([
                       GroupScale(int(scale_size)),
                       GroupCenterCrop(crop_size),
                       Stack(roll=is_inception, isRGBFlow=is_rgb_flow),
                       ToTorchFormatTensor(div=not is_inception),
                       normalize,
                   ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=False)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'], group['decay_mult'])))

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # The DNDF consensus replaces SGD with Adam over all trainable parameters.
    if args.consensus_type == 'DNDF':
        params = [p for p in model.parameters() if p.requires_grad]
        optimizer = torch.optim.Adam(params, lr=args.lr, weight_decay=1e-5)

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    history = {
        'accuracy': [],
        'val_accuracy': [],
        'loss': [],
        'val_loss': []
    }
    model_details = {
        'backbone': args.arch,
        'transformer_arch': args.consensus_type,
        'lr': args.lr,
        'batch_size': args.batch_size
    }

    # Validation does not run every epoch when eval_freq > 1. Start from None
    # so the history append below never reads an unbound name; later epochs
    # repeat the most recent validation result, keeping all four lists the
    # same length.
    prec1, val_loss = None, None

    # The context manager guarantees the CSV log is flushed and closed, even
    # if training raises.
    with open(os.path.join(args.root_log, '%s.csv' % args.store_name), 'w') as log_training:
        for epoch in range(args.start_epoch, args.epochs):
            # The step schedule applies to SGD only; DNDF keeps Adam's rate.
            if not args.consensus_type == 'DNDF':
                adjust_learning_rate(optimizer, epoch, args.lr_steps)

            # train for one epoch
            acc, loss = train(train_loader, model, criterion, optimizer, epoch, log_training)

            # evaluate every eval_freq epochs and always after the last epoch
            if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
                prec1, val_loss = validate(val_loader, model, criterion, (epoch + 1) * len(train_loader), log_training)

                # remember best prec@1 and save checkpoint
                is_best = prec1 > best_prec1
                best_prec1 = max(prec1, best_prec1)
                save_checkpoint({
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)

            history['accuracy'].append(acc)
            history['loss'].append(loss)
            history['val_accuracy'].append(prec1)
            history['val_loss'].append(val_loss)
    plot_utils.plot_statistics(history, model_details)
Пример #24
0
def main():
    """Run a 4-class ECO model over the validation list and print predictions.

    Builds the TSN/ECO network, loads the weights stored at
    ``args.model_path`` (initialising any entries the checkpoint lacks),
    and prints the top-1 prediction next to the target for every sample of
    ``args.val_list``.

    NOTE(review): ``args`` is declared global but never assigned here, so it
    must already be set at module level before this function is called.
    """
    global args, best_prec1
    num_class = 4
    rgb_read_format = "{:d}.jpg"

    model = TSN(num_class,
                args.num_segments,
                args.pretrained_parts,
                'RGB',
                base_model='ECO',
                consensus_type='identity',
                dropout=0.3,
                partial_bn=True)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std

    # Parameter groups are only printed below; no optimizer is created here.
    policies = model.get_optim_policies()

    # Kept from the original: the call is made although its result is unused.
    train_augmentation = model.get_augmentation()

    model = torch.nn.DataParallel(model, device_ids=[0, 1]).cuda()

    model_dict = model.state_dict()

    print("pretrained_parts: ", args.pretrained_parts)

    model_dir = args.model_path
    new_state_dict = torch.load(model_dir)['state_dict']

    un_init_dict_keys = [k for k in model_dict if k not in new_state_dict]
    print("un_init_dict_keys: ", un_init_dict_keys)
    print("\n------------------------------------")

    # Give every entry the checkpoint lacks a default value:
    # BN weights -> 1, other weights -> Xavier uniform, biases -> 0,
    # anything else stays zero.
    for k in un_init_dict_keys:
        new_state_dict[k] = torch.DoubleTensor(model_dict[k].size()).zero_()
        if 'weight' in k:
            if 'bn' in k:
                print("{} init as: 1".format(k))
                constant_(new_state_dict[k], 1)
            else:
                print("{} init as: xavier".format(k))
                xavier_uniform_(new_state_dict[k])
        elif 'bias' in k:
            print("{} init as: 0".format(k))
            constant_(new_state_dict[k], 0)

    print("------------------------------------")

    model.load_state_dict(new_state_dict)

    cudnn.benchmark = True

    # Data loading code
    normalize = GroupNormalize(input_mean, input_std)
    data_length = 1

    val_dataset = TSNDataSet(
        "",
        args.val_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality='RGB',
        image_tmpl=rgb_read_format,
        random_shift=False,
        transform=torchvision.transforms.Compose([
            GroupScale(int(scale_size)),
            GroupCenterCrop(crop_size),
            Stack(roll=True),
            ToTorchFormatTensor(div=False),
            normalize,
        ]))
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=1,
                                             pin_memory=True)

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))

    model.eval()
    # Inference only: disable autograd so no graph is recorded per batch.
    with torch.no_grad():
        for frames, target in val_loader:
            target = target.cuda()
            output = model(frames)
            # Top-1 class index along the class dimension.
            _, pred = output.data.topk(1, 1, True, True)
            print(pred, target)
    print('done')
Пример #25
0
def main():
    """Train, or only evaluate, a TSN/TRN model as configured on the command line.

    Parses the command line into the module-level ``args``, resolves the
    dataset lists through ``datasets_video.return_dataset``, builds the model,
    loaders and SGD optimizer, optionally resumes from ``args.resume``, and
    then either validates once (``args.evaluate``) or trains with periodic
    validation and checkpointing while logging to a CSV file under
    ``args.root_log``.

    Raises:
        ValueError: if the modality or loss type is not handled below.
    """
    global args, best_prec1
    args = parser.parse_args()
    check_rootfolders()

    categories, args.train_list, args.val_list, args.test_list, args.root_path, prefix = datasets_video.return_dataset(
        args.dataset, args.modality)
    num_class = len(categories)

    args.store_name = '_'.join([
        'TRN', args.dataset, args.modality, args.arch, args.consensus_type,
        'segment%d' % args.num_segments
    ])
    print('storing name: ' + args.store_name)

    model = TSN(num_class,
                args.num_segments,
                args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation()

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    if args.resume:
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            # Report the checkpoint path; the original printed args.evaluate.
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    cudnn.benchmark = True

    # Data loading code
    # Four input modalities exist for two-stream ConvNets: a single RGB image,
    # stacked RGB difference, stacked optical flow and stacked warped optical
    # flow. The spatial stream works on single RGB images; the temporal stream
    # takes a stack of consecutive optical-flow fields.
    # A single RGB image encodes static appearance at one time point and lacks
    # context about neighbouring frames. RGB difference between consecutive
    # frames describes appearance change, which may correspond to the
    # motion-salient region. Optical flow may not concentrate on the human
    # action; warped optical flow suppresses background motion so that motion
    # concentrates on the actor.

    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5
    else:
        # Fail clearly instead of hitting an unbound data_length below.
        raise ValueError('Unknown modality ' + args.modality)

    # Stack/ToTorchFormatTensor settings depend only on the backbone family
    # (BNInception, or InceptionV3).
    is_inception = args.arch in ['BNInception', 'InceptionV3']

    # Division between train and val set
    train_loader = torch.utils.data.DataLoader(
        TSNDataSet(
            args.root_path,
            args.train_list,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            image_tmpl=prefix,
            transform=torchvision.transforms.Compose([
                train_augmentation,
                Stack(roll=is_inception),
                ToTorchFormatTensor(div=not is_inception),
                normalize,
            ])),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        TSNDataSet(
            args.root_path,
            args.val_list,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            image_tmpl=prefix,
            random_shift=False,
            transform=torchvision.transforms.Compose([
                GroupScale(int(scale_size)),
                GroupCenterCrop(crop_size),
                Stack(roll=is_inception),
                ToTorchFormatTensor(div=not is_inception),
                normalize,
            ])),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    # The context manager guarantees the CSV log is flushed and closed, even
    # if training raises.
    log_path = os.path.join(args.root_log, '%s.csv' % args.store_name)
    with open(log_path, 'w') as log_training:
        for epoch in range(args.start_epoch, args.epochs):
            adjust_learning_rate(optimizer, epoch, args.lr_steps)

            # train for one epoch
            train(train_loader, model, criterion, optimizer, epoch,
                  log_training)

            # evaluate every eval_freq epochs and always after the last epoch
            if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
                prec1 = validate(val_loader, model, criterion,
                                 (epoch + 1) * len(train_loader),
                                 log_training)

                # remember best prec@1 and save checkpoint
                is_best = prec1 > best_prec1
                best_prec1 = max(prec1, best_prec1)
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'best_prec1': best_prec1,
                    }, is_best)
Пример #26
0
def main():
    """Train, or only evaluate, a TSN model using the Adam optimizer.

    Parses the command line into the module-level ``args``, builds the model
    and both data loaders, optionally resumes from ``args.resume``, and then
    either validates once (``args.evaluate``) or trains with periodic
    validation and checkpointing.

    Raises:
        ValueError: if the dataset, modality or loss type is not handled
            below.
    """
    torch.set_printoptions(precision=6)

    global args, best_prec1
    args = parser.parse_args()
    # Pick the number of classes from the dataset argument.
    if args.dataset == 'ucf101':
        num_class = 101
    elif args.dataset == 'hmdb51':
        num_class = 51
    elif args.dataset == 'kinetics':
        num_class = 400
    elif args.dataset == 'cad':
        num_class = 8
    else:
        raise ValueError('Unknown dataset ' + args.dataset)

    # Build the model. Arguments (original author's notes, translated):
    #   num_class          - number of target classes;
    #   args.num_segments  - how many segments a video is split into
    #                        (K in the paper, default K=3);
    #   args.modality      - input type, e.g. RGB for ordinary frames or
    #                        Flow for optical flow;
    #   args.arch          - backbone, e.g. resnet101 or BNInception;
    #   args.consensus_type- how per-snippet outputs are fused, e.g. avg;
    #   args.dropout       - dropout parameter.
    model = TSN(num_class,
                args.num_segments,
                args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                partial_bn=not args.no_partialbn)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation()

    # Wrap the model for multi-GPU training.
    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    # args.resume controls whether training continues from a checkpoint: it
    # is either falsy or the path of a saved model file. torch.load reads the
    # checkpoint and load_state_dict copies its parameters into the model.
    if args.resume:
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            # Report the checkpoint path; the original printed args.evaluate.
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5
    else:
        # Fail clearly instead of hitting an unbound data_length below.
        raise ValueError('Unknown modality ' + args.modality)

    # Second part: data loading. TSNDataSet reads the raw data and
    # torch.utils.data.DataLoader batches it for the model.
    # (Original author's note, not verifiable from this file: each
    # TSNDataSet item is a data tensor plus an integer label, which the
    # loader collates into a batch tensor and a label tensor.)
    if args.modality in ["RGB", "RGBDiff"]:
        image_tmpl = "img_{:05d}.jpg"
    else:
        image_tmpl = args.flow_prefix + "{}_{:05d}.jpg"

    # Stack/ToTorchFormatTensor settings depend on whether the backbone is
    # BNInception.
    is_bninception = args.arch == 'BNInception'

    train_dataset = TSNDataSet(
        "",
        args.train_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=image_tmpl,
        transform=torchvision.transforms.Compose([
            train_augmentation,
            Stack(roll=is_bninception),
            ToTorchFormatTensor(div=not is_bninception),
            normalize,
        ]))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=3,
                                               pin_memory=True)

    val_dataset = TSNDataSet(
        "",
        args.val_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=image_tmpl,
        random_shift=False,
        transform=torchvision.transforms.Compose([
            GroupScale(int(scale_size)),
            GroupCenterCrop(crop_size),
            Stack(roll=is_bninception),
            ToTorchFormatTensor(div=not is_bninception),
            normalize,
        ]))
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=3,
                                             pin_memory=True)

    # Third part: training. Define the loss and the optimizer, then train and
    # validate/checkpoint at the configured epochs.
    #   adjust_learning_rate - applies the learning-rate schedule;
    #       args.lr_steps lists the epochs at which the rate changes
    #       (original author's note: by default it is multiplied by 0.1).
    #   train                - trains the model for one epoch.
    #   validate             - evaluates the model on the validation loader.
    #   save_checkpoint      - stores the model parameters and metadata.
    # args.eval_freq controls how often validation and saving happen.

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))

    # Adam is used in place of the original SGD(momentum=args.momentum,
    # weight_decay=args.weight_decay) setup.
    optimizer = torch.optim.Adam(policies, args.lr)

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_steps)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate every eval_freq epochs and always after the last epoch
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1 = validate(val_loader, model, criterion,
                             (epoch + 1) * len(train_loader))

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)
Пример #27
0
def main():
    """Train or evaluate a TSN model according to the command-line arguments.

    Parses the CLI arguments into the module-level ``args``, builds the TSN
    network and the train/validation data loaders, optionally restores a
    checkpoint from ``args.resume``, and then either runs a single validation
    pass (when ``args.evaluate`` is set) or the full training loop. During
    training the model is validated every ``args.eval_freq`` epochs (and
    after the final epoch), and a checkpoint is saved after each validation.

    Raises:
        ValueError: if ``args.dataset``, ``args.modality`` or
            ``args.loss_type`` holds an unsupported value.
    """
    global args, best_prec1
    args = parser.parse_args()

    if args.dataset == 'ucf101':
        num_class = 101
    elif args.dataset == 'hmdb51':
        num_class = 51
    elif args.dataset == 'kinetics':
        num_class = 400
    else:
        raise ValueError('Unknown dataset ' + args.dataset)

    # Original author's note on the settings used for this run:
    # consensus_type = avg, base_model = resnet_101, dropout = 0.5.
    model = TSN(num_class,
                args.num_segments,
                args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                partial_bn=not args.no_partialbn)

    # Pre-processing parameters are taken from the model itself.
    crop_size = model.crop_size  # author's note: 224
    scale_size = model.scale_size  # author's note: 256/224
    # Author's note: mean/std differ for each modality.
    input_mean = model.input_mean
    input_std = model.input_std

    # The optimizer policies must be read before the model is moved to GPU.
    policies = model.get_optim_policies()
    # Author's note (not verifiable from this function): the augmentation
    # consists of GroupMultiScaleCrop (crops at fixed positions, then resizes
    # to 224 using bilinear interpolation) and GroupRandomHorizontalFlip.
    train_augmentation = model.get_augmentation()
    print(args.gpus)
    model = model.cuda()

    if args.resume:
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            # NOTE: torch.load unpickles the file; only resume from
            # checkpoints that come from a trusted source.
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            # Report the checkpoint path (the original printed args.evaluate).
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5
    else:
        # Fail here with a clear message instead of an UnboundLocalError
        # when the loaders are built below.
        raise ValueError('Unknown modality ' + args.modality)

    # Author's note: `roll` depends on whether the backbone was trained on
    # BGR or RGB input; here it is enabled only for BNInception, which is
    # also the only architecture for which pixel values are not divided.
    is_bninception = args.arch == 'BNInception'

    train_loader = torch.utils.data.DataLoader(
        TSNDataSet(
            "",
            args.train_list,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            image_tmpl="im{}.jpg",
            transform=torchvision.transforms.Compose([
                train_augmentation,
                Stack(roll=is_bninception),
                ToTorchFormatTensor(div=not is_bninception),
                normalize,
            ])),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        TSNDataSet(
            "",
            args.val_list,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            image_tmpl="im{}.jpg",
            random_shift=False,
            transform=torchvision.transforms.Compose([
                GroupScale(int(scale_size)),
                GroupCenterCrop(crop_size),
                Stack(roll=is_bninception),
                ToTorchFormatTensor(div=not is_bninception),
                normalize,
            ])),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    # Show the per-group optimization policy.
    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))

    # Author's note: the learning rate is generally 1e-3 here.
    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # Evaluation-only run: validate once and stop; otherwise train below.
    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    viz = vis.Visualizer()
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_steps)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, viz)

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1 = validate(val_loader, model, criterion, epoch, viz=viz)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            # The keys must match what the resume branch above reads:
            # 'state_dict' for the weights and the running best (not the
            # current) precision under 'best_prec1'.
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)
Пример #28
0
def main():
    """Train or evaluate a TSN model on the frames under ``UCF-Frames``.

    Parses the CLI arguments into the module-level ``args``, builds the TSN
    network wrapped in ``DataParallel``, creates the train/validation data
    loaders, optionally restores a checkpoint from ``args.resume``, and then
    either runs a single validation pass (when ``args.evaluate`` is set) or
    the full training loop. During training the model is validated every
    ``args.eval_freq`` epochs (and after the final epoch), and a checkpoint
    is saved after each validation.

    Raises:
        ValueError: if ``args.dataset``, ``args.modality`` or
            ``args.loss_type`` holds an unsupported value.
    """
    global args, best_prec1
    args = parser.parse_args()

    if args.dataset == 'ucf101':
        num_class = 101
    elif args.dataset == 'hmdb51':
        num_class = 51
    elif args.dataset == 'kinetics':
        num_class = 400
    else:
        raise ValueError('Unknown dataset ' + args.dataset)

    model = TSN(num_class,
                args.num_segments,
                args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                partial_bn=not args.no_partialbn)

    # Read everything that lives on the bare model before it is wrapped in
    # DataParallel below.
    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation()

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    if args.resume:
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            # NOTE: torch.load unpickles the file; only resume from
            # checkpoints that come from a trusted source.
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            # Report the checkpoint path (the original printed args.evaluate).
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5
    else:
        # Fail here with a clear message instead of an UnboundLocalError
        # when the loaders are built below.
        raise ValueError('Unknown modality ' + args.modality)

    # Both loaders share the same frame file-name template and the same
    # architecture-dependent Stack / ToTorchFormatTensor switches.
    if args.modality in ["RGB", "RGBDiff"]:
        image_tmpl = "{:06d}.jpg"
    else:
        image_tmpl = args.flow_prefix + "{}_{:05d}.jpg"
    is_bninception = args.arch == 'BNInception'

    train_loader = torch.utils.data.DataLoader(
        TSNDataSet(
            "UCF-Frames",
            args.train_list,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            image_tmpl=image_tmpl,
            transform=torchvision.transforms.Compose([
                train_augmentation,
                Stack(roll=is_bninception),
                ToTorchFormatTensor(div=not is_bninception),
                normalize,
            ])),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        TSNDataSet(
            "UCF-Frames",
            args.val_list,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            image_tmpl=image_tmpl,
            random_shift=False,
            transform=torchvision.transforms.Compose([
                GroupScale(int(scale_size)),
                GroupCenterCrop(crop_size),
                Stack(roll=is_bninception),
                ToTorchFormatTensor(div=not is_bninception),
                normalize,
            ])),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # Evaluation-only run: validate once and stop.
    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_steps)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            # The last argument is the number of training iterations
            # completed so far.
            prec1 = validate(val_loader, model, criterion,
                             (epoch + 1) * len(train_loader))

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)
Пример #29
0
def main():
    """Fine-tune or evaluate a TSN/TRN model on a siamese dataset.

    Parses the CLI arguments into the module-level ``args``, resolves the
    dataset lists via ``datasets_video.return_dataset``, builds the model
    wrapped in ``DataParallel``, loads the hard-coded pretrained weights,
    optionally restores a checkpoint from ``args.resume``, and then either
    runs a single validation pass (when ``args.evaluate`` is set) or the
    full training loop. During training the model is validated every
    ``args.eval_freq`` epochs (and after the final epoch); the checkpoint
    with the lowest validation loss is tracked in the module-level
    ``best_loss``.

    Raises:
        ValueError: if ``args.modality`` or ``args.loss_type`` holds an
            unsupported value.
    """
    global args, best_loss
    args = parser.parse_args()
    check_rootfolders()

    categories, args.train_list, args.val_list, args.root_path, prefix = datasets_video.return_dataset(
        args.dataset, args.modality)
    # NOTE(review): num_class is computed but not used below; the model is
    # built with a hard-coded 64 instead - confirm this is intended.
    num_class = len(categories)

    args.store_name = '_'.join([
        'TRN', args.dataset, args.modality, args.arch, args.consensus_type,
        'segment%d' % args.num_segments
    ])
    print('storing name: ' + args.store_name)

    model = TSN(64,
                args.num_segments,
                args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn)
    # freeze cnn if neccessary
    #     _, cnn =list(model.named_children())[0]
    #     for p in cnn.parameters():
    #         p.requires_grad = False

    # Read everything that lives on the bare model before it is wrapped in
    # DataParallel below.
    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation()

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    # remove if not transfer learning (this is pretrained TRN model taken from here: http://relation.csail.mit.edu/models/TRN_moments_RGB_InceptionV3_TRNmultiscale_segment8_best_v0.4.pth.tar)
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(
        '/home/ubuntu/dvoitekh/TRN_custom_RGB_InceptionV3_TRNmultiscale_segment8_best.pth.tar'
    )
    model.load_state_dict(checkpoint['state_dict'])
    #     for module in list(list(model._modules['module'].children())[-1].children())[-1].children():
    #         module[-1] = nn.Linear(256, 64)

    if args.resume:
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_loss = checkpoint['best_loss']
            model.load_state_dict(checkpoint['state_dict'])
            # Report the checkpoint path (the original printed args.evaluate).
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    model.cuda()
    model.train(True)
    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5
    else:
        # Fail here with a clear message instead of an UnboundLocalError
        # when the loaders are built below.
        raise ValueError('Unknown modality ' + args.modality)

    # Inception-style backbones get rolled channels and undivided pixel
    # values; every other architecture gets the opposite.
    is_inception = args.arch in ['BNInception', 'InceptionV3']

    train_loader = torch.utils.data.DataLoader(
        SiameseDataset(
            args.root_path,
            args.train_list,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            image_tmpl=prefix,
            transform=torchvision.transforms.Compose([
                Augmentor(),
                train_augmentation,
                Stack(roll=is_inception),
                ToTorchFormatTensor(div=not is_inception),
                normalize,
            ])),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        SiameseDataset(
            args.root_path,
            args.val_list,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            image_tmpl=prefix,
            random_shift=False,
            transform=torchvision.transforms.Compose([
                GroupScale(int(scale_size)),
                GroupCenterCrop(crop_size),
                Stack(roll=is_inception),
                ToTorchFormatTensor(div=not is_inception),
                normalize,
            ])),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True)

    # define loss function (criterion) and optimizer
    # The only accepted loss_type value, 'nll', selects ContrastiveLoss here.
    if args.loss_type == 'nll':
        criterion = ContrastiveLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    # The policies are only printed; the optimizer below is built from
    # model.parameters() directly.
    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))

    optimizer = torch.optim.Adam(model.parameters(), args.lr)
    # Multiply the learning rate by 0.1 every 8 scheduler steps.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                8,
                                                gamma=0.1,
                                                last_epoch=-1)

    # Evaluation-only run: validate once and stop.
    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    log_path = os.path.join(args.root_log, '%s.csv' % args.store_name)
    # Use a context manager so the training log is flushed and closed even
    # if training raises (the original never closed the file).
    with open(log_path, 'w') as log_training:
        for epoch in range(args.start_epoch, args.epochs):
            # NOTE(review): PyTorch >= 1.1 expects scheduler.step() after the
            # epoch's optimizer steps; the call is left at the top of the
            # loop to keep the existing learning-rate schedule unchanged.
            scheduler.step()
            # train for one epoch
            train(train_loader, model, criterion, optimizer, epoch,
                  log_training)

            # evaluate on validation set
            if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
                loss = validate(val_loader, model, criterion,
                                (epoch + 1) * len(train_loader), log_training,
                                epoch)

                # remember best loss and save checkpoint
                is_best = loss < best_loss
                best_loss = min(loss, best_loss)
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'best_loss': best_loss,
                    }, is_best)
    # summary_writer is not defined in this function; presumably a
    # module-level writer - verify against the rest of the file.
    summary_writer.close()