def main(opt):
    """Build the datasets, model, loss and optimizer from ``opt`` and train.

    Args:
        opt: dict of options. Keys read here: ``model``, ``batch_size``,
            ``max_len``, ``dim_hidden``, ``dim_word``, ``dim_vid``,
            ``rnn_type``, ``rnn_dropout_p``, ``learning_rate``,
            ``weight_decay``, ``learning_rate_decay_every`` and
            ``learning_rate_decay_rate``; plus ``vocab_size`` and
            ``num_layers`` for ``'S2VTModel'``, or ``bidirectional`` and
            ``input_dropout_p`` for ``'S2VTAttModel'``. The dict is mutated:
            ``obj_vocab_size`` and ``rel_vocab_size`` are filled in from the
            training dataset.

    Raises:
        ValueError: if ``opt["model"]`` is not ``'S2VTModel'`` or
            ``'S2VTAttModel'``.
        KeyError: if a required option is missing from ``opt``.
    """
    # Fail fast on an unsupported model name. Previously this surfaced only
    # as an unbound-local error at ``model.cuda()``, after both datasets had
    # already been constructed.
    model_name = opt["model"]
    if model_name not in ("S2VTModel", "S2VTAttModel"):
        raise ValueError(
            "Unsupported model %r; expected 'S2VTModel' or 'S2VTAttModel'"
            % (model_name,))

    dataset = VideoDataset(opt, 'train')
    dataset_test = VideoDataset(opt, 'test')
    dataloader = DataLoader(dataset,
                            batch_size=opt["batch_size"],
                            shuffle=True)
    dataloader_test = DataLoader(dataset_test,
                                 batch_size=opt["batch_size"],
                                 shuffle=False)

    # Vocabulary sizes come from the training split and are published on
    # ``opt`` so that the decoder below (and ``train``) can read them.
    opt["obj_vocab_size"] = dataset.get_obj_vocab_size()
    opt["rel_vocab_size"] = dataset.get_rel_vocab_size()

    if model_name == 'S2VTModel':
        # NOTE(review): this branch reads opt["vocab_size"], which is not set
        # in this function -- it must already be present in ``opt``.
        model = S2VTModel(opt["vocab_size"],
                          opt["max_len"],
                          opt["dim_hidden"],
                          opt["dim_word"],
                          opt['dim_vid'],
                          rnn_cell=opt['rnn_type'],
                          n_layers=opt['num_layers'],
                          rnn_dropout_p=opt["rnn_dropout_p"])
    else:  # 'S2VTAttModel' (guaranteed by the validation above)
        encoder = EncoderRNN(opt["dim_vid"],
                             opt["dim_hidden"],
                             bidirectional=opt["bidirectional"],
                             input_dropout_p=opt["input_dropout_p"],
                             rnn_cell=opt['rnn_type'],
                             rnn_dropout_p=opt["rnn_dropout_p"])
        decoder = DecoderRNN(opt["obj_vocab_size"],
                             opt["rel_vocab_size"],
                             opt["max_len"],
                             opt["dim_hidden"],
                             opt["dim_word"],
                             input_dropout_p=opt["input_dropout_p"],
                             rnn_cell=opt['rnn_type'],
                             rnn_dropout_p=opt["rnn_dropout_p"],
                             bidirectional=opt["bidirectional"])
        model = S2VTAttModel(encoder, decoder)
    model = model.cuda()

    crit = utils.ObjRelCriterion()
    optimizer = optim.Adam(model.parameters(),
                           lr=opt["learning_rate"],
                           weight_decay=opt["weight_decay"])
    # Step decay: multiply the learning rate by ``gamma`` every ``step_size``
    # scheduler steps.
    exp_lr_scheduler = optim.lr_scheduler.StepLR(
        optimizer,
        step_size=opt["learning_rate_decay_every"],
        gamma=opt["learning_rate_decay_rate"])

    train(dataloader, model, crit, optimizer, exp_lr_scheduler, opt,
          dataloader_test)
# Example no. 2
# 0
def main(opt):
    """Rebuild the model from ``opt``, load a checkpoint and run ``test``.

    Args:
        opt: dict of options. Keys read here: ``model``, ``batch_size``,
            ``max_len``, ``dim_hidden``, ``dim_word``, ``dim_vid``,
            ``rnn_type``, ``rnn_dropout_p`` and ``ckpt_path``; plus
            ``vocab_size`` and ``num_layers`` for ``'S2VTModel'``, or
            ``bidirectional`` and ``input_dropout_p`` for ``'S2VTAttModel'``.
            The dict is mutated: ``obj_vocab_size`` and ``rel_vocab_size``
            are filled in from the test dataset.

    Raises:
        ValueError: if ``opt["model"]`` is not ``'S2VTModel'`` or
            ``'S2VTAttModel'``.
        KeyError: if a required option is missing from ``opt``.
    """
    # Fail fast on an unsupported model name. Previously this surfaced only
    # as an unbound-local error at ``model.cuda()``, after the dataset had
    # already been constructed.
    model_name = opt["model"]
    if model_name not in ("S2VTModel", "S2VTAttModel"):
        raise ValueError(
            "Unsupported model %r; expected 'S2VTModel' or 'S2VTAttModel'"
            % (model_name,))

    dataset_test = VideoDataset(opt, 'test')
    dataloader_test = DataLoader(dataset_test,
                                 batch_size=opt["batch_size"],
                                 shuffle=False)

    # Vocabulary sizes are taken from the test split here.
    # NOTE(review): they must match the sizes used when the checkpoint was
    # trained, otherwise ``load_state_dict`` will report shape mismatches.
    opt["obj_vocab_size"] = dataset_test.get_obj_vocab_size()
    opt["rel_vocab_size"] = dataset_test.get_rel_vocab_size()

    if model_name == 'S2VTModel':
        # NOTE(review): this branch reads opt["vocab_size"], which is not set
        # in this function -- it must already be present in ``opt``.
        model = S2VTModel(opt["vocab_size"],
                          opt["max_len"],
                          opt["dim_hidden"],
                          opt["dim_word"],
                          opt['dim_vid'],
                          rnn_cell=opt['rnn_type'],
                          n_layers=opt['num_layers'],
                          rnn_dropout_p=opt["rnn_dropout_p"])
    else:  # 'S2VTAttModel' (guaranteed by the validation above)
        encoder = EncoderRNN(opt["dim_vid"],
                             opt["dim_hidden"],
                             bidirectional=opt["bidirectional"],
                             input_dropout_p=opt["input_dropout_p"],
                             rnn_cell=opt['rnn_type'],
                             rnn_dropout_p=opt["rnn_dropout_p"])
        decoder = DecoderRNN(opt["obj_vocab_size"],
                             opt["rel_vocab_size"],
                             opt["max_len"],
                             opt["dim_hidden"],
                             opt["dim_word"],
                             input_dropout_p=opt["input_dropout_p"],
                             rnn_cell=opt['rnn_type'],
                             rnn_dropout_p=opt["rnn_dropout_p"],
                             bidirectional=opt["bidirectional"])
        model = S2VTAttModel(encoder, decoder)
    model = model.cuda()

    # Deserialize onto the CPU so the checkpoint loads regardless of which
    # CUDA device it was saved from; ``load_state_dict`` then copies the
    # values into the parameters that already live on the GPU.
    state_dict = torch.load(opt['ckpt_path'], map_location='cpu')
    model.load_state_dict(state_dict)

    crit = utils.ObjRelCriterion()
    test(model, crit, opt, dataloader_test)