Example #1
def build_optim(model, optim_opt):
    optim = nmt.Optim(optim_opt.optim_method, optim_opt.learning_rate,
                      optim_opt.max_grad_norm, optim_opt.learning_rate_decay,
                      optim_opt.weight_decay, optim_opt.start_decay_at)

    optim.set_parameters(model.parameters())
    return optim
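
A minimal usage sketch follows, assuming nmt is the project package whose Optim class is used above; the option values and the stand-in model are illustrative only.

# Usage sketch (assumptions: `nmt` exposes the Optim constructor used above;
# the option values and the Linear stand-in model are illustrative).
import torch.nn as nn
from types import SimpleNamespace

opt = SimpleNamespace(optim_method='adam', learning_rate=1e-3,
                      max_grad_norm=5.0, learning_rate_decay=0.5,
                      weight_decay=0.0, start_decay_at=8)
model = nn.Linear(16, 16)        # stand-in for the real NMT model
optim = build_optim(model, opt)  # wraps nmt.Optim around model.parameters()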
Example #2
def build_optims_and_lr_schedulers(model, opt):
    optimG = nmt.Optim(opt.optim_method, opt.learning_rate, opt.max_grad_norm,
                       opt.learning_rate_decay, opt.weight_decay,
                       opt.start_decay_at)

    optimG.set_parameters(model.generator.parameters())

    lr_lambda = lambda epoch: opt.learning_rate_decay**epoch
    schedulerG = torch.optim.lr_scheduler.LambdaLR(optimizer=optimG.optimizer,
                                                   lr_lambda=[lr_lambda])
    optimD = nmt.Optim(opt.optim_method, opt.learning_rate_D,
                       opt.max_grad_norm, opt.learning_rate_decay,
                       opt.weight_decay, opt.start_decay_at)
    optimD.set_parameters(list(model.discriminator.parameters()) +
                          list(model.critic.parameters()))
    schedulerD = torch.optim.lr_scheduler.LambdaLR(optimizer=optimD.optimizer,
                                                   lr_lambda=[lr_lambda])
    return optimG, schedulerG, optimD, schedulerD
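
Examples #2 and #3 pair each nmt.Optim with a torch.optim.lr_scheduler.LambdaLR whose multiplier is learning_rate_decay**epoch. Below is a hedged sketch of a per-epoch driver for the objects returned above; the training routines and the start_epoch/epochs fields are assumptions, not part of the source.

# Per-epoch driver sketch: LambdaLR multiplies the wrapped optimizer's base LR
# by lr_lambda(epoch), i.e. learning_rate * learning_rate_decay**epoch.
optimG, schedulerG, optimD, schedulerD = build_optims_and_lr_schedulers(model, opt)
for epoch in range(opt.start_epoch, opt.epochs + 1):  # field names assumed
    train_generator_step(model, optimG)      # hypothetical training routines
    train_discriminator_step(model, optimD)
    schedulerG.step()                        # apply the epoch-indexed decay
    schedulerD.step()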
Example #3
def build_optims_and_schedulers(model, critic, opt):
    if model.__class__.__name__ == "jointTemplateResponseGenerator":
        optimR = nmt.Optim(opt.optim_method, opt.learning_rate_R,
                           opt.max_grad_norm, opt.learning_rate_decay,
                           opt.weight_decay, opt.start_decay_at)
        optimR.set_parameters(model.parameters())
        lr_lambda = lambda epoch: opt.learning_rate_decay**epoch
        schedulerR = torch.optim.lr_scheduler.LambdaLR(
            optimizer=optimR.optimizer, lr_lambda=[lr_lambda])
        return optimR, schedulerR, None, None, None, None

    optimR = nmt.Optim(opt.optim_method, opt.learning_rate_R,
                       opt.max_grad_norm, opt.learning_rate_decay,
                       opt.weight_decay, opt.start_decay_at)

    optimR.set_parameters(model.response_generator.parameters())

    lr_lambda = lambda epoch: opt.learning_rate_decay**epoch
    schedulerR = torch.optim.lr_scheduler.LambdaLR(optimizer=optimR.optimizer,
                                                   lr_lambda=[lr_lambda])
    optimT = nmt.Optim(opt.optim_method, opt.learning_rate_T,
                       opt.max_grad_norm, opt.learning_rate_decay,
                       opt.weight_decay, opt.start_decay_at)
    optimT.set_parameters(model.template_generator.parameters())
    schedulerT = torch.optim.lr_scheduler.LambdaLR(optimizer=optimT.optimizer,
                                                   lr_lambda=[lr_lambda])

    if critic is not None:
        optimC = nmt.Optim(opt.optim_method, opt.learning_rate_C,
                           opt.max_grad_norm, opt.learning_rate_decay,
                           opt.weight_decay, opt.start_decay_at)
        optimC.set_parameters(critic.parameters())
        schedulerC = torch.optim.lr_scheduler.LambdaLR(
            optimizer=optimC.optimizer, lr_lambda=[lr_lambda])
    else:
        optimC, schedulerC = None, None
    return optimR, schedulerR, optimT, schedulerT, optimC, schedulerC
Example #4
def main():
    print("Loading data from '%s'" % opt.data)

    dataset = torch.load(opt.data)
    print("Done")
    dict_checkpoint = (opt.train_from
                       if opt.train_from else opt.train_from_state_dict)
    if dict_checkpoint:
        print('Loading dicts from checkpoint at %s' % dict_checkpoint)
        checkpoint = torch.load(dict_checkpoint)
        dataset['dicts'] = checkpoint['dicts']

    trainData = nmt.Dataset(dataset['train']['src'],
                            dataset['train']['tgt'],
                            opt.batch_size,
                            opt.gpus,
                            data_type=dataset.get("type", "text"),
                            balance=(opt.balance_batch == 1))
    validData = nmt.Dataset(dataset['valid']['src'],
                            dataset['valid']['tgt'],
                            opt.eval_batch_size,
                            opt.gpus,
                            volatile=True,
                            data_type=dataset.get("type", "text"))

    dicts = dataset['dicts']
    print(' * vocabulary size. source = %d; target = %d' %
          (dicts['src'].size(), dicts['tgt'].size()))
    print(' * number of training sentences. %d' % len(dataset['train']['src']))
    print(' * maximum batch size. %d' % opt.batch_size)

    print('Building model...')

    criterion = NMTCriterion(dataset['dicts']['tgt'].size())

    embeddings = nmt.utils.createEmbeddings(opt, dicts)

    model = nmt.utils.createNMT(opt, dicts, embeddings)
    print "Neural Machine Translation Model"

    print(model)

    if opt.train_from_state_dict:

        print('Loading model from checkpoint at %s' %
              opt.train_from_state_dict)
        model_state_dict = {
            k: v
            for k, v in checkpoint['model'].items() if 'criterion' not in k
        }
        model.load_state_dict(model_state_dict)
        opt.start_epoch = int(math.floor(checkpoint['epoch'] + 1))
        del checkpoint['model']

    if not opt.train_from_state_dict and not opt.train_from:
        # initialize parameters for the nmt model

        model.init_weights(opt.param_init)
        #~ for p in model.parameters():
        #~ p.data.uniform_(-opt.param_init, opt.param_init)

        model.encoder.load_pretrained_vectors(opt)
        model.decoder.load_pretrained_vectors(opt)

    if len(opt.gpus) >= 1:
        model.cuda()
    else:
        model.cpu()

    if opt.tie_weights:
        print("Share weights between decoder input and output embeddings")
        model.tie_weights()

    if opt.join_vocab:
        print("Share weights between source and target embeddings")
        model.tie_join_embeddings()

    if len(opt.gpus) > 1:
        model = nn.DataParallel(model, device_ids=opt.gpus, dim=1)

    if opt.reset_optim or not opt.train_from_state_dict:

        optim = nmt.Optim(opt.optim,
                          opt.learning_rate,
                          opt.max_grad_norm,
                          lr_decay=opt.learning_rate_decay,
                          start_decay_at=opt.start_decay_at)

    else:
        print('Loading optimizer from checkpoint:')
        optim = checkpoint['optim']
        # Force change learning rate
        optim.lr = opt.learning_rate
        optim.start_decay_at = opt.start_decay_at
        optim.start_decay = False
        del checkpoint['optim']

    optim.set_parameters(model.parameters())
    optim.setLearningRate(opt.learning_rate)

    nParams = sum([p.nelement() for p in model.parameters()])
    print('* number of parameters: %d' % nParams)

    evaluator = Evaluator(model,
                          dataset,
                          opt.valid_src,
                          opt.valid_tgt,
                          cuda=(len(opt.gpus) >= 1))

    valid_loss = evaluator.eval_perplexity(validData, criterion)
    valid_ppl = math.exp(min(valid_loss, 100))
    print('* Initial Perplexity : %.2f' % valid_ppl)

    print('* Start training ... ')

    trainer = XETrainer(model, criterion, optim, trainData, validData,
                        evaluator, dicts, opt)

    trainer.run()
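
The script above references a module-level opt and ends at trainer.run(); a conventional entry guard (a sketch, assuming opt is parsed elsewhere in the surrounding module) would be:

if __name__ == "__main__":
    main()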