Example #1
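# This snippet presumably relies on module-level globals: the nmt package,
# a use_cuda flag, and a report_func progress callback.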
def train_model(opt, model, train_iter, valid_iter, fields, optim,
                lr_scheduler, start_epoch_at):
    train_loss = nmt.NMTLossCompute(model.generator, fields['tgt'].vocab)
    valid_loss = nmt.NMTLossCompute(model.generator, fields['tgt'].vocab)

    if use_cuda:
        train_loss = train_loss.cuda()
        valid_loss = valid_loss.cuda()

    shard_size = opt.train_shard_size
    trainer = nmt.Trainer(opt, model, train_iter, valid_iter, train_loss,
                          valid_loss, optim, shard_size)

    num_train_epochs = opt.num_train_epochs
    print('start training...')
    for step_epoch in range(start_epoch_at + 1, num_train_epochs):

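        # Decay the learning rate once the configured start epoch is reached.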
        if step_epoch >= opt.start_decay_at:
            lr_scheduler.step()
        # 1. Train for one epoch on the training set.
        train_stats = trainer.train(step_epoch, report_func)
        print('Train perplexity: %g' % train_stats.ppl())

        # 2. Validate on the validation set.
        valid_stats = trainer.validate()
        print('Validation perplexity: %g' % valid_stats.ppl())

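        # epoch_step presumably saves a checkpoint for this epoch to opt.out_dir.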
        trainer.epoch_step(step_epoch, out_dir=opt.out_dir)

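        # validate() presumably leaves the model in eval mode; switch back
        # to train mode for the next epoch.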
        model.train()
Example #2
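# This variant presumably relies on module-level globals: opt, use_cuda,
# summary_writer, report_func, test_bleu, and the make_*_data_iter helpers.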
def train_model(model, train_data, valid_data, fields, optim, lr_scheduler,
                start_epoch_at):

    train_iter = make_train_data_iter(train_data, opt)
    valid_iter = make_valid_data_iter(valid_data, opt)

    train_loss = nmt.NMTLossCompute(model.generator, fields['tgt'].vocab)
    valid_loss = nmt.NMTLossCompute(model.generator, fields['tgt'].vocab)

    if use_cuda:
        train_loss = train_loss.cuda()
        valid_loss = valid_loss.cuda()

    shard_size = opt.train_shard_size
    trainer = nmt.Trainer(opt, model, train_iter, valid_iter, train_loss,
                          valid_loss, optim, lr_scheduler, shard_size)

    num_train_epochs = opt.num_train_epochs
    print('start training...')
    for step_epoch in range(start_epoch_at + 1, num_train_epochs):

        if step_epoch >= opt.start_decay_at:
            trainer.lr_scheduler.step()
        # 1. Train for one epoch on the training set.
        train_stats = trainer.train(step_epoch, report_func)
        print('Train perplexity: %g' % train_stats.ppl())

        # 2. Validate on the validation set.
        valid_stats = trainer.validate()
        print('Validation perplexity: %g' % valid_stats.ppl())
        trainer.epoch_step(step_epoch, out_dir=opt.out_dir)
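        # Optionally score the epoch with BLEU; decoding runs in eval mode.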
        if opt.test_bleu:
            model.eval()
            valid_bleu = test_bleu(model, fields, step_epoch)
            model.train()

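        # Log per-epoch statistics (e.g., to TensorBoard) through the assumed
        # summary_writer.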
        train_stats.log("train",
                        summary_writer,
                        step_epoch,
                        ppl=train_stats.ppl(),
                        learning_rate=optim.lr,
                        accuracy=train_stats.accuracy())
        valid_stats.log("valid",
                        summary_writer,
                        step_epoch,
                        ppl=valid_stats.ppl(),
                        learning_rate=optim.lr,
                        bleu=valid_bleu if opt.test_bleu else 0.0,
                        accuracy=valid_stats.accuracy())
Example #3
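# GAN-style variant: G_turn/D_turn update the generator and discriminator.
# Presumably relies on module-level imports/globals (sys, nmt, use_cuda,
# G_turn, D_turn, save_per_epoch).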
def train_model(opt, model, train_iter, valid_iter, fields, optimG,
                lr_schedulerG, optimD, lr_schedulerD, start_epoch_at):
    num_train_epochs = opt.num_train_epochs
    num_updates = 0
    print('start training...')
    valid_loss = nmt.NMTLossCompute(model.generator.generator,
                                    fields['tgt'].vocab)
    if use_cuda:
        valid_loss = valid_loss.cuda()
    shard_size = opt.train_shard_size
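    # This trainer is only used for validation below, so valid_loss fills
    # both loss slots and no train loss is constructed.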
    trainer = nmt.Trainer(opt,
                          model.generator,
                          train_iter,
                          valid_iter,
                          valid_loss,
                          valid_loss,
                          optimG,
                          lr_schedulerG,
                          shard_size,
                          train_loss_b=None)

    for step_epoch in range(start_epoch_at + 1, num_train_epochs):
        for batch in train_iter:
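            # Schedule: one G (generator) update after every opt.D_turns
            # D (discriminator) updates.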
            if num_updates % (opt.D_turns + 1) == opt.D_turns:
                G_turn(model, batch, optimG, opt)
            else:
                D_turn(model, batch, optimD, opt)
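            # Periodically run an extra D turn that also prints a generated sample.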
            if num_updates % opt.show_sample_every == opt.show_sample_every - 1:
                D_turn(model, batch, optimD, opt, show_sample=True)
            num_updates += 1
            sys.stdout.flush()
        valid_stats = trainer.validate()
        print('Validation perplexity: %g' % valid_stats.ppl())
        sys.stdout.flush()
        if step_epoch >= opt.start_decay_at:
            lr_schedulerD.step()
            lr_schedulerG.step()
        save_per_epoch(model, step_epoch, opt)
        model.train()
Example #4
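# Actor-critic variant with a response generator (R), a template generator (T),
# and a critic (C). Presumably relies on module-level imports/globals
# (os, numpy as np, torch, the legacy torch.autograd.Variable, Statistics,
# sample, report_func, nmt, use_cuda).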
def train_model(opt, model, critic, train_iter, valid_iter, fields, optimR,
                lr_schedulerR, optimT, lr_schedulerT, optimC, lr_schedulerC,
                start_epoch_at):
    train_loss = nmt.NMTLossCompute(model.generator, fields['tgt'].vocab)
    valid_loss = nmt.NMTLossCompute(model.generator, fields['tgt'].vocab)

    if use_cuda:
        train_loss = train_loss.cuda()
        valid_loss = valid_loss.cuda()

    shard_size = opt.train_shard_size
    trainer = nmt.Trainer(opt, model, train_iter, valid_iter, train_loss,
                          valid_loss, optimR, shard_size)

    scorer = nmt.Scorer(model, fields['tgt'].vocab, fields['src'].vocab,
                        train_loss, opt)
    num_train_epochs = opt.num_train_epochs
    print('start training...')
    global_step = 0
    for step_epoch in range(start_epoch_at + 1, num_train_epochs):
        if step_epoch >= opt.start_decay_at:
            lr_schedulerR.step()
            if lr_schedulerT is not None:
                lr_schedulerT.step()
            if lr_schedulerC is not None:
                lr_schedulerC.step()

        total_stats = Statistics()
        report_stats = Statistics()
        for step_batch, batch in enumerate(train_iter):
            global_step += 1
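            # NOTE: both branches below pick the R (response-generator) turn,
            # so the T and C turns are effectively disabled in this schedule.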
            if global_step % 6 == 5:
                T_turn = False
                C_turn = False
                R_turn = True
            else:
                T_turn = False
                C_turn = False
                R_turn = True

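            # C turn: put both generators in eval mode and update only the
            # critic on gold targets, sampled responses, and shuffled
            # (mismatched) targets.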
            if C_turn:
                model.template_generator.eval()
                model.response_generator.eval()
                critic.train()
                optimC.optimizer.zero_grad()
                src_inputs, src_lengths = batch.src
                tgt_inputs, tgt_lengths = batch.tgt
                ref_src_inputs, ref_src_lengths = batch.ref_src
                ref_tgt_inputs, ref_tgt_lengths = batch.ref_tgt
                I_word, I_word_length = batch.I
                D_word, D_word_length = batch.D
                preds, ev = model.template_generator(I_word,
                                                     I_word_length,
                                                     D_word,
                                                     D_word_length,
                                                     ref_tgt_inputs,
                                                     ref_tgt_lengths,
                                                     return_ev=True)
                preds = preds.squeeze(2)
                template, template_lengths = model.template_generator.do_mask_and_clean(
                    preds, ref_tgt_inputs, ref_tgt_lengths)

                # (debug) masked templates can be decoded for inspection via
                # fields['tgt'].vocab.itos
                (response,
                 response_length), logp = sample(model.response_generator,
                                                 src_inputs,
                                                 None,
                                                 template,
                                                 src_lengths,
                                                 None,
                                                 template_lengths,
                                                 max_len=20)

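                # Feed the critic embedded source, gold target, sampled
                # response, and a randomly shuffled target as a mismatched
                # negative example.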
                enc_embedding = model.response_generator.enc_embedding
                dec_embedding = model.response_generator.dec_embedding
                inds = np.arange(len(tgt_lengths))
                np.random.shuffle(inds)
                inds_tensor = Variable(torch.LongTensor(inds).cuda())
                random_tgt = tgt_inputs.index_select(1, inds_tensor)
                random_tgt_len = [tgt_lengths[i] for i in inds]

                # (debug) the source / target / sampled response / shuffled
                # target token sequences can be decoded here for inspection

                x, y, z = critic(enc_embedding(src_inputs), src_lengths,
                                 dec_embedding(tgt_inputs), tgt_lengths,
                                 dec_embedding(response), response_length,
                                 dec_embedding(random_tgt), random_tgt_len)
                loss = torch.mean(-x)
                loss.backward()
                optimC.step()
                stats = Statistics()
            elif T_turn:
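                # T turn: train the template generator against the frozen
                # critic via the scorer.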
                model.template_generator.train()
                model.response_generator.eval()
                critic.eval()
                stats = scorer.update(batch, optimT, 'T', sample, critic)
            elif R_turn:
                # R turn: train the response generator, alternating between a
                # maximum-likelihood update (trainer) and a critic-guided
                # update (scorer); a joint model always uses the trainer.
                if model.__class__.__name__ != "jointTemplateResponseGenerator":
                    model.template_generator.eval()
                    model.response_generator.train()
                    critic.eval()
                    if global_step % 2 == 0:
                        stats = trainer.update(batch)
                    else:
                        stats = scorer.update(batch, optimR, 'R', sample,
                                              critic)
                else:
                    stats = trainer.update(batch)
            report_stats.update(stats)
            total_stats.update(stats)
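            # Periodic progress reporting via the assumed report_func callback.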
            report_func(opt, global_step, step_epoch, step_batch,
                        len(train_iter), total_stats.start_time, optimR.lr,
                        report_stats)

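        # Save a per-epoch critic checkpoint alongside the model.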
        if critic is not None:
            critic.save_checkpoint(
                step_epoch, opt,
                os.path.join(opt.out_dir,
                             "checkpoint_epoch_critic%d.pkl" % step_epoch))
        print("\nEpoch : {} ______________________________".format(step_epoch))
        print('Train perplexity: %g' % total_stats.ppl())

        # 2. Validate on the validation set.
        valid_stats = trainer.validate()
        print('Validation perplexity: %g' % valid_stats.ppl())

        trainer.epoch_step(step_epoch, out_dir=opt.out_dir)

        model.train()