示例#1
0
def main():
    """Parse arguments, build the model and data loaders, then run training.

    Side effects visible in this function:
      * rebinds the module-level ``args`` (from ``parser.parse_args()``) and
        ``best_prec1`` (best validation precision seen so far);
      * at every evaluation epoch writes ``train_history/train_history.npz``
        and ``train_history/valid_history.npz``;
      * calls ``save_checkpoint`` when validation precision improves or the
        epoch is a multiple of ``SAVE_FREQ``.

    Raises:
        ValueError: if ``args.data_name`` is not 'ucf101', 'hmdb51' or 'mine'.
    """
    import os  # only needed to make sure the history directory exists

    global args
    global best_prec1
    args = parser.parse_args()

    print('Training arguments:')
    for k, v in vars(args).items():
        print('\t{}: {}'.format(k, v))

    # Number of output classes for each supported dataset.
    num_class_by_dataset = {'ucf101': 101, 'hmdb51': 51, 'mine': 2}
    if args.data_name not in num_class_by_dataset:
        raise ValueError('Unknown dataset ' + args.data_name)
    num_class = num_class_by_dataset[args.data_name]

    model = Model(num_class,
                  args.num_segments,
                  args.representation,
                  base_model=args.arch)
    print(model)

    if 'resnet3D' in args.arch:
        # Alternative 3D-ResNet style pipelines. They are NOT used by the
        # loaders below, which use model.get_augmentation() and a
        # GroupScale/GroupCenterCrop pipeline instead. They are also only
        # defined for 'resnet3D' architectures, so wiring them into the
        # loaders requires defining them for every architecture first.
        train_crop_min_ratio = 0.75
        train_crop_min_scale = 0.25
        mean = [0.4345, 0.4051, 0.3775]
        std = [0.2768, 0.2713, 0.2737]
        value_scale = 1

        train_transform = Compose([
            RandomResizedCrop(
                model.crop_size, (train_crop_min_scale, 1.0),
                (train_crop_min_ratio, 1.0 / train_crop_min_ratio)),
            RandomHorizontalFlip(),
            ToTensor(),
            ScaleValue(value_scale),
            Normalize(mean, std)
        ])
        test_transform = Compose([
            Resize(model.crop_size),
            CenterCrop(model.crop_size),
            ToTensor(),  # range [0, 255] -> [0.0, 1.0]
            ScaleValue(1),
            Normalize(mean, std)
        ])

    train_loader = torch.utils.data.DataLoader(
        CoviarDataSet(
            args.data_root,
            args.data_name,
            video_list=args.train_list,
            num_segments=args.num_segments,
            representation=args.representation,
            transform=model.get_augmentation(),  # alternative: train_transform
            is_train=True,
            accumulate=(not args.no_accumulation),
            model_name=args.arch),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
        worker_init_fn=worker_init_fn)

    val_loader = torch.utils.data.DataLoader(
        CoviarDataSet(
            args.data_root,
            args.data_name,
            video_list=args.test_list,
            num_segments=args.num_segments,
            representation=args.representation,
            transform=torchvision.transforms.Compose([
                GroupScale(int(model.scale_size)),
                GroupCenterCrop(model.crop_size)
            ]),  # alternative: test_transform
            # NOTE(review): the validation set is built with is_train=True,
            # the same flag as the training set. If CoviarDataSet samples
            # segments randomly in train mode, validation precision (and hence
            # best-checkpoint selection) will be noisy — confirm this is
            # intended.
            is_train=True,
            accumulate=(not args.no_accumulation),
            model_name=args.arch),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
        worker_init_fn=worker_init_fn)

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()
    cudnn.benchmark = True

    # One parameter group per tensor so each can carry its own multipliers.
    # Adam itself ignores the extra 'lr_mult'/'decay_mult' keys; presumably
    # adjust_learning_rate() reads them every epoch — confirm.
    params = []
    for key, value in model.named_parameters():
        decay_mult = 0.0 if 'bias' in key else 1.0  # biases: decay multiplier 0

        is_input_layer = ('module.base_model.conv1' in key
                          or 'module.base_model.bn1' in key
                          or 'data_bn' in key)
        if is_input_layer and args.representation in ['mv', 'residual']:
            # NOTE(review): presumably the first layers get a larger multiplier
            # because mv/residual inputs differ from the backbone's pretraining
            # input — confirm.
            lr_mult = 0.1
        elif '.fc.' in key:
            lr_mult = 1.0   # classifier head
        else:
            lr_mult = 0.01  # remaining backbone layers

        params.append({
            'params': value,
            'lr': args.lr,
            'lr_mult': lr_mult,
            'decay_mult': decay_mult
        })

    optimizer = torch.optim.Adam(params,
                                 weight_decay=args.weight_decay,
                                 eps=0.001)
    criterion = torch.nn.CrossEntropyLoss().cuda()

    # Create the history directory up front. Otherwise saving the history
    # fails on the first evaluation epoch, after a full epoch of training.
    os.makedirs('train_history', exist_ok=True)

    for epoch in range(args.epochs):
        cur_lr = adjust_learning_rate(optimizer, epoch, args.lr_steps,
                                      args.lr_decay)

        train(train_loader, model, criterion, optimizer, epoch, cur_lr)

        if epoch % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1, _ = validate(val_loader, model, criterion)

            # Record the training history. train_loss/train_prec/train_lr and
            # valid_loss/valid_prec are module-level names, presumably filled
            # by train() and validate().
            np.savez("train_history/train_history.npz",
                     loss=np.array(train_loss),
                     top1=np.array(train_prec),
                     lr=np.array(train_lr))
            np.savez("train_history/valid_history.npz",
                     loss=np.array(valid_loss),
                     top1=np.array(valid_prec))

            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            if is_best or epoch % SAVE_FREQ == 0:
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'best_prec1': best_prec1,
                    },
                    is_best,
                    filename='checkpoint.pth.tar')
示例#2
0
文件: main.py 项目: shuxiao0312/STRG
def get_train_utils(opt, model_parameters):
    """Assemble the data pipeline, loggers, optimizer and LR scheduler for training.

    Builds the spatial and temporal augmentation pipelines from ``opt`` and
    wraps the training dataset in a ``DataLoader``. A
    ``DistributedSampler`` is added when ``opt.distributed`` is set. The
    epoch and batch loggers are opened only on the master node. Finally an
    SGD optimizer is paired with a plateau or multistep LR scheduler.

    Args:
        opt: parsed options; every ``opt.*`` attribute read below must exist.
        model_parameters: parameters handed to the SGD optimizer.

    Returns:
        tuple: ``(train_loader, train_sampler, train_logger,
        train_batch_logger, optimizer, scheduler)``. ``train_sampler`` is
        ``None`` unless distributed; both loggers are ``None`` off the
        master node.
    """
    assert opt.train_crop in ['random', 'corner', 'center']
    crop_mode = opt.train_crop
    size = opt.sample_size
    spatial_steps = []
    if crop_mode == 'random':
        min_scale = opt.train_crop_min_scale
        min_ratio = opt.train_crop_min_ratio
        spatial_steps.append(
            RandomResizedCrop(size, (min_scale, 1.0),
                              (min_ratio, 1.0 / min_ratio)))
    elif crop_mode == 'corner':
        # Five scales, each one 2**(-1/4) times the previous.
        step = 1 / (2**(1 / 4))
        corner_scales = [1.0]
        while len(corner_scales) < 5:
            corner_scales.append(corner_scales[-1] * step)
        spatial_steps.append(MultiScaleCornerCrop(size, corner_scales))
    elif crop_mode == 'center':
        spatial_steps += [Resize(size), CenterCrop(size)]

    normalize = get_normalize_method(opt.mean, opt.std, opt.no_mean_norm,
                                     opt.no_std_norm)
    if not opt.no_hflip:
        spatial_steps.append(RandomHorizontalFlip())
    if opt.colorjitter:
        spatial_steps.append(ColorJitter())
    spatial_steps.append(ToTensor())
    if opt.input_type == 'flow':
        # Flow input keeps only its first two channels.
        spatial_steps.append(PickFirstChannels(n=2))
    spatial_steps.extend([ScaleValue(opt.value_scale), normalize])
    spatial_transform = Compose(spatial_steps)

    assert opt.train_t_crop in ['random', 'center']
    temporal_steps = []
    if opt.sample_t_stride > 1:
        temporal_steps.append(TemporalSubsampling(opt.sample_t_stride))
    temporal_crop_cls = {
        'random': TemporalRandomCrop,
        'center': TemporalCenterCrop,
    }.get(opt.train_t_crop)
    if temporal_crop_cls is not None:
        temporal_steps.append(temporal_crop_cls(opt.sample_duration))
    temporal_transform = TemporalCompose(temporal_steps)

    train_data = get_training_data(opt.video_path, opt.annotation_path,
                                   opt.dataset, opt.input_type, opt.file_type,
                                   spatial_transform, temporal_transform)
    train_sampler = (
        torch.utils.data.distributed.DistributedSampler(train_data)
        if opt.distributed else None)
    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=opt.batch_size,
        # Shuffle only when no sampler is given; the sampler orders otherwise.
        shuffle=train_sampler is None,
        num_workers=opt.n_threads,
        pin_memory=True,
        sampler=train_sampler,
        worker_init_fn=worker_init_fn)

    train_logger = None
    train_batch_logger = None
    if opt.is_master_node:
        train_logger = Logger(opt.result_path / 'train.log',
                              ['epoch', 'loss', 'acc', 'lr'])
        train_batch_logger = Logger(
            opt.result_path / 'train_batch.log',
            ['epoch', 'batch', 'iter', 'loss', 'acc', 'lr'])

    # Nesterov momentum is configured with zero dampening.
    optimizer = SGD(model_parameters,
                    lr=opt.learning_rate,
                    momentum=opt.momentum,
                    dampening=0 if opt.nesterov else opt.dampening,
                    weight_decay=opt.weight_decay,
                    nesterov=opt.nesterov)

    assert opt.lr_scheduler in ['plateau', 'multistep']
    # Plateau scheduling is refused when validation is disabled, presumably
    # because it steps on a validation metric.
    assert opt.lr_scheduler != 'plateau' or not opt.no_val
    if opt.lr_scheduler == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(
            optimizer, 'min', patience=opt.plateau_patience)
    else:
        scheduler = lr_scheduler.MultiStepLR(optimizer,
                                             opt.multistep_milestones)

    return (train_loader, train_sampler, train_logger, train_batch_logger,
            optimizer, scheduler)
示例#3
0
def get_train_utils(opt, model_parameters):
    """Assemble the PaddlePaddle data pipeline, loggers, scheduler and optimizer.

    Builds the spatial and temporal augmentation pipelines from ``opt`` and
    batches the training dataset's reader with ``paddle.batch``. It opens the
    epoch and batch loggers under ``opt.result_path``. It then pairs a
    plateau or multistep learning-rate schedule with a fluid momentum
    optimizer that applies L2 weight decay.

    Args:
        opt: parsed options; every ``opt.*`` attribute read below must exist.
        model_parameters: parameters handed to the optimizer as
            ``parameter_list``.

    Returns:
        tuple: ``(train_loader, train_logger, train_batch_logger, optimizer,
        scheduler)``.
    """
    assert opt.train_crop in ['random', 'corner', 'center']
    crop_mode = opt.train_crop
    size = opt.sample_size
    spatial_steps = []
    if crop_mode == 'random':
        min_scale = opt.train_crop_min_scale
        min_ratio = opt.train_crop_min_ratio
        spatial_steps.append(
            RandomResizedCrop(size, (min_scale, 1.0),
                              (min_ratio, 1.0 / min_ratio)))
    elif crop_mode == 'corner':
        # Five scales, each one 2**(-1/4) times the previous.
        step = 1 / (2**(1 / 4))
        corner_scales = [1.0]
        while len(corner_scales) < 5:
            corner_scales.append(corner_scales[-1] * step)
        spatial_steps.append(MultiScaleCornerCrop(size, corner_scales))
    elif crop_mode == 'center':
        spatial_steps += [Resize(size), CenterCrop(size)]

    normalize = get_normalize_method(opt.mean, opt.std, opt.no_mean_norm,
                                     opt.no_std_norm)
    if not opt.no_hflip:
        spatial_steps.append(RandomHorizontalFlip())
    spatial_steps.append(ToArray())
    # NOTE(review): ColorJitter is appended after ToArray, so it receives the
    # ToArray output rather than the original image — confirm ColorJitter
    # supports that input type.
    if opt.colorjitter:
        spatial_steps.append(ColorJitter())
    if opt.input_type == 'flow':
        # Flow input keeps only its first two channels.
        spatial_steps.append(PickFirstChannels(n=2))
    spatial_steps.extend([ScaleValue(opt.value_scale), normalize])
    spatial_transform = Compose(spatial_steps)

    assert opt.train_t_crop in ['random', 'center']
    temporal_steps = []
    if opt.sample_t_stride > 1:
        temporal_steps.append(TemporalSubsampling(opt.sample_t_stride))
    temporal_crop_cls = {
        'random': TemporalRandomCrop,
        'center': TemporalCenterCrop,
    }.get(opt.train_t_crop)
    if temporal_crop_cls is not None:
        temporal_steps.append(temporal_crop_cls(opt.sample_duration))
    temporal_transform = TemporalCompose(temporal_steps)

    train_data = get_training_data(opt.video_path, opt.annotation_path,
                                   opt.dataset, opt.input_type, opt.file_type,
                                   spatial_transform, temporal_transform)
    # NOTE(review): no shuffling is requested here; unless train_data.reader
    # shuffles internally, samples arrive in a fixed order every epoch.
    train_loader = paddle.batch(train_data.reader, batch_size=opt.batch_size)

    log_dir = opt.result_path
    train_logger = Logger(log_dir / 'train.log',
                          ['epoch', 'loss', 'acc', 'lr'])
    train_batch_logger = Logger(
        log_dir / 'train_batch.log',
        ['epoch', 'batch', 'iter', 'loss', 'acc', 'lr'])

    assert opt.lr_scheduler in ['plateau', 'multistep']
    # Plateau scheduling is refused when validation is disabled, presumably
    # because it steps on a validation metric.
    assert opt.lr_scheduler != 'plateau' or not opt.no_val
    if opt.lr_scheduler == 'plateau':
        scheduler = ReduceLROnPlateau(learning_rate=opt.learning_rate,
                                      mode='min',
                                      patience=opt.plateau_patience)
    else:
        scheduler = MultiStepDecay(learning_rate=opt.learning_rate,
                                   milestones=opt.multistep_milestones)

    # The scheduler object itself is passed as the optimizer's learning rate.
    optimizer = fluid.optimizer.MomentumOptimizer(
        learning_rate=scheduler,
        momentum=opt.momentum,
        parameter_list=model_parameters,
        use_nesterov=opt.nesterov,
        regularization=fluid.regularizer.L2Decay(
            regularization_coeff=opt.weight_decay))

    return (train_loader, train_logger, train_batch_logger, optimizer,
            scheduler)