Example #1
# Imports this snippet needs; project-specific helpers (ANCHOR_LIST, set_device,
# ModelSaver, construct_model, ANetDataSample, ANetDataFull, collate_fn, train,
# eval) are assumed to come from the project's own modules.
import logging
import os

import torch
from torch.utils.data import DataLoader
def main(params):
    params['anchor_list'] = ANCHOR_LIST
    logger = logging.getLogger(params['alias'])
    set_device(logger, params['gpu_id'])
    saver = ModelSaver(params, os.path.abspath('./third_party/densevid_eval'))
    model = construct_model(params, saver, logger)

    model = model.cuda()

    training_set = ANetDataSample(params['train_data'], params['feature_path'],
                                  params['translator_path'], params['video_sample_rate'], logger)
    val_set = ANetDataFull(params['val_data'], params['feature_path'],
                           params['translator_path'], params['video_sample_rate'], logger)
    train_loader = DataLoader(training_set, batch_size=params['batch_size'], shuffle=True,
                              num_workers=params['num_workers'], collate_fn=collate_fn, drop_last=True)
    # integer division: DataLoader's batch_size must be an int
    val_loader = DataLoader(val_set, batch_size=params['batch_size'] // 3, shuffle=True,
                            num_workers=params['num_workers'], collate_fn=collate_fn, drop_last=True)

    optimizer = torch.optim.SGD(model.get_parameter_group(params),
                                lr=params['lr'], weight_decay=params['weight_decay'], momentum=params['momentum'])

    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                        milestones=params['lr_step'], gamma=params["lr_decay_rate"])
    eval(model, val_loader, params, logger, 0, saver)
    for step in range(params['training_epoch']):
        train(model, train_loader, params, logger, step, optimizer)
        lr_scheduler.step()  # step the scheduler after the epoch's optimizer updates (PyTorch >= 1.1)

        # validation and saving
        if step % params['test_interval'] == 0:
            eval(model, val_loader, params, logger, step, saver)
        if step % params['save_model_interval'] == 0 and step != 0:
            saver.save_model(model, step, {'step': step})
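
Note on scheduler ordering: since PyTorch 1.1, scheduler.step() is meant to be called after the epoch's optimizer updates, which is why the loop above trains first and steps the scheduler second. A minimal, self-contained sketch of that ordering follows; the tiny Linear model and dummy loss are placeholders, not part of the project above.

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[2, 4], gamma=0.1)

for epoch in range(6):
    x = torch.randn(8, 4)
    loss = model(x).pow(2).mean()   # dummy loss, stands in for train(...)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()                # update the weights first
    scheduler.step()                # then decay the learning rate at the milestones
    print(epoch, scheduler.get_last_lr())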
Example #2
# Imports this snippet needs; project-specific helpers (set_device, ModelSaver,
# construct_model, ANetDataSample, ANetDataFull, collate_fn, CaptionEvaluator,
# pretrain_cg, train_cg, train_sl) are assumed to come from the project's modules.
import logging
import os
from itertools import chain

import torch
from torch.utils.data import DataLoader
def main(params):
    logger = logging.getLogger(params['alias'])
    gpu_id = set_device(logger, params['gpu_id'])
    # recreate the logger so its name records the selected GPU
    logger = logging.getLogger(params['alias'] + '(%d)' % gpu_id)
    saver = ModelSaver(params, os.path.abspath('./third_party/densevid_eval'))
    model_sl, model_cg = construct_model(params, saver, logger)

    model_sl, model_cg = model_sl.cuda(), model_cg.cuda()

    training_set = ANetDataSample(params['train_data'], params['feature_path'],
                                  params['translator_path'],
                                  params['video_sample_rate'], logger)
    val_set = ANetDataFull(params['val_data'], params['feature_path'],
                           params['translator_path'],
                           params['video_sample_rate'], logger)
    train_loader_cg = DataLoader(training_set,
                                 batch_size=params['batch_size'],
                                 shuffle=True,
                                 num_workers=params['num_workers'],
                                 collate_fn=collate_fn,
                                 drop_last=True)
    train_loader_sl = DataLoader(training_set,
                                 batch_size=6,
                                 shuffle=True,
                                 num_workers=params['num_workers'],
                                 collate_fn=collate_fn,
                                 drop_last=True)
    val_loader = DataLoader(val_set,
                            batch_size=8,
                            shuffle=True,
                            num_workers=params['num_workers'],
                            collate_fn=collate_fn,
                            drop_last=True)

    optimizer_sl = torch.optim.SGD(model_sl.get_parameter_group(params),
                                   lr=params['lr'],
                                   weight_decay=params['weight_decay'],
                                   momentum=params['momentum'])

    optimizer_cg = torch.optim.SGD(
        list(chain(model_sl.get_parameter_group_c(params),
                   model_cg.get_parameter_group(params))),
        lr=params['lr'],
        weight_decay=params['weight_decay'],
        momentum=params['momentum'])

    optimizer_cg_n = torch.optim.SGD(model_cg.get_parameter_group(params),
                                     lr=params['lr'],
                                     weight_decay=params['weight_decay'],
                                     momentum=params['momentum'])

    lr_scheduler_sl = torch.optim.lr_scheduler.MultiStepLR(
        optimizer_sl,
        milestones=params['lr_step'],
        gamma=params["lr_decay_rate"])

    lr_scheduler_cg = torch.optim.lr_scheduler.MultiStepLR(
        optimizer_cg,
        milestones=params['lr_step'],
        gamma=params["lr_decay_rate"])

    evaluator = CaptionEvaluator(training_set)
    saver.save_model(
        model_sl, 0, {
            'step': 0,
            'model_sl': model_sl.state_dict(),
            'model_cg': model_cg.state_dict()
        })
    # eval(model_sl, model_cg, val_loader, logger, saver, params, -1)
    for step in range(params['training_epoch']):
        if step < params['pretrain_epoch']:
            pretrain_cg(model_cg, train_loader_cg, params, logger, step,
                        optimizer_cg_n)
        elif step % params['alter_step'] != 0:
            train_cg(model_cg, model_sl, train_loader_cg, params, logger, step,
                     optimizer_cg)
        else:
            train_sl(model_cg, model_sl, train_loader_sl, evaluator, params,
                     logger, step, optimizer_sl)

        # step the schedulers after the epoch's optimizer updates (PyTorch >= 1.1)
        lr_scheduler_cg.step()
        lr_scheduler_sl.step()

        # validation and saving
        # if step % params['test_interval'] == 0:
        #     eval(model_sl, model_cg, val_loader, logger, saver, params, step)
        if step % params['save_model_interval'] == 0 and step != 0:
            saver.save_model(
                model_sl, step, {
                    'step': step,
                    'model_sl': model_sl.state_dict(),
                    'model_cg': model_cg.state_dict()
                })
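
The epoch schedule in Example #2 can be hard to read off the if/elif/else branch: the cg model (presumably the caption generator) pretrains alone for the first pretrain_epoch epochs, after which the sl model (presumably the sentence localizer) trains only on every alter_step-th epoch while cg trains on the rest. A hypothetical helper, not in the source, that makes the schedule explicit:

def phase_for_epoch(step, pretrain_epoch, alter_step):
    """Return which training routine Example #2 runs at a given epoch."""
    if step < pretrain_epoch:
        return 'pretrain_cg'   # generator pretraining phase
    if step % alter_step != 0:
        return 'train_cg'      # generator epochs between localizer epochs
    return 'train_sl'          # localizer epoch every alter_step-th step

for step in range(8):
    print(step, phase_for_epoch(step, pretrain_epoch=3, alter_step=2))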