Example #1
0
    def encode(self, src_inputs, template_inputs, src_lengths,
               template_lengths, ev=None):
        """Encode the source sequence and every template sequence.

        The source goes through ``enc_embedding`` + ``encoder_src``; each
        template goes through ``dec_embedding`` + ``encoder_ref``. Template
        contexts are concatenated along dim 0 and template masks along dim 1.

        ``dist`` is ``self.bridge(ev)`` when both ``ev`` and ``self.bridge``
        are available, otherwise ``None``.

        Returns:
            (ref_contexts, enc_hidden, ref_mask, dist, src_contexts, src_mask)
        """
        source_embedded = self.enc_embedding(src_inputs)
        src_contexts, enc_hidden = self.encoder_src(source_embedded,
                                                    src_lengths, None)

        bridge_usable = ev is not None and self.bridge is not None
        dist = self.bridge(ev) if bridge_usable else None

        context_chunks = []
        mask_chunks = []
        for tokens, lengths in zip(template_inputs, template_lengths):
            encoded, _ = self.encoder_ref(self.dec_embedding(tokens), lengths)
            mask_chunks.append(sequence_mask(lengths))
            context_chunks.append(encoded)

        # Merge all templates into one memory (dim 0) with one mask (dim 1).
        merged_contexts = torch.cat(context_chunks, 0)
        merged_mask = torch.cat(mask_chunks, 1)
        src_mask = sequence_mask(src_lengths)
        return (merged_contexts, enc_hidden, merged_mask, dist, src_contexts,
                src_mask)
Example #2
0
    def encode(self, I_word, I_word_length, D_word, D_word_length,
               ref_tgt_inputs, ref_tgt_lengths, src_inputs, src_lengths):
        """Encode the reference and the source, and score reference positions.

        ``ev_generator`` yields a vector ``ev`` and the encoded reference;
        both pass through ``masker_dropout``. ``ev`` is then repeated for every
        reference position, concatenated with the encoded reference on dim 2
        and fed to ``self.masker`` to obtain ``preds``.

        Returns:
            (enc_outputs, enc_hidden, ref_mask, dist, src_contexts, src_mask,
            preds), where ``dist`` is ``self.bridge(ev)`` or ``None`` when the
            model has no bridge.
        """
        edit_vec, enc_outputs = self.ev_generator(
            I_word, I_word_length, D_word, D_word_length, ref_tgt_inputs,
            ref_tgt_lengths)
        edit_vec = self.masker_dropout(edit_vec)
        enc_outputs = self.masker_dropout(enc_outputs)

        # The unpacking requires a 2-D edit vector and a 3-D encoder output.
        _, vec_dim = edit_vec.size()
        n_steps, n_batch, _ = enc_outputs.size()

        dist = None if self.bridge is None else self.bridge(edit_vec)

        # Repeat the edit vector along a new leading axis so it can be
        # concatenated with the encoder output on the last dimension.
        tiled_vec = edit_vec.unsqueeze(0).expand(n_steps, n_batch, vec_dim)
        masker_input = torch.cat([tiled_vec, enc_outputs], 2)
        preds = self.masker(masker_input).squeeze(2)

        src_contexts, enc_hidden = self.encoder_src(
            self.enc_embedding(src_inputs), src_lengths, None)

        ref_mask = sequence_mask(ref_tgt_lengths)
        src_mask = sequence_mask(src_lengths)
        return (enc_outputs, enc_hidden, ref_mask, dist, src_contexts,
                src_mask, preds)
Example #3
0
 def encode(self, src_inputs, src_lengths, I_word, I_word_length, D_word,
            D_word_length, ref_tgt_inputs, ref_tgt_lengths):
     """Encode the source and the reference for decoding.

     Only the final hidden state of the source encoder is kept. The
     reference contexts and the vector passed to ``self.bridge`` both
     come from ``self.ev_generator``.

     Returns:
         (ref_contexts, enc_hidden, ref_mask, dist)
     """
     source_embedded = self.enc_embedding(src_inputs)
     # The per-step source outputs are discarded on purpose.
     enc_hidden = self.encoder_src(source_embedded, src_lengths, None)[1]

     generated = self.ev_generator(I_word, I_word_length, D_word,
                                   D_word_length, ref_tgt_inputs,
                                   ref_tgt_lengths)
     edit_vec, ref_contexts = generated

     dist = self.bridge(edit_vec)
     return ref_contexts, enc_hidden, sequence_mask(ref_tgt_lengths), dist
Example #4
0
    def encode(self,
               src_inputs,
               ref_src_inputs,
               ref_tgt_inputs,
               src_lengths,
               ref_src_lengths,
               ref_tgt_lengths,
               hidden=None):
        """Build the key/value reference memory and encode the source.

        For every (reference source, reference target) pair:
          * keys come from ``decoder_ref`` run over the reference target,
            conditioned on the encoded reference source; the last step is
            dropped;
          * values come from ``encoder_ref`` run over the reference target;
            the first step is dropped;
          * the mask is built from the target lengths minus one, matching
            the one-step trimming above.
        The per-reference pieces are concatenated (keys/values on dim 0,
        masks on dim 1).

        ``hidden`` is accepted but not used.

        Returns:
            (ref_values, enc_hidden, ref_keys, ref_mask, src_context,
            src_mask)
        """
        src_embedded = self.enc_embedding(src_inputs)
        ref_src_embedded = list(map(self.enc_embedding, ref_src_inputs))
        ref_tgt_embedded = list(map(self.dec_embedding, ref_tgt_inputs))

        value_chunks, key_chunks, mask_chunks = [], [], []
        references = zip(ref_src_embedded, ref_tgt_embedded, ref_src_lengths,
                         ref_tgt_lengths)
        for src_emb, tgt_emb, src_len, tgt_len in references:
            src_memory, src_state = self.encoder_src(src_emb, src_len, None)
            keys, _, _ = self.decoder_ref(tgt_emb, src_memory, src_state,
                                          sequence_mask(src_len))
            values, _ = self.encoder_ref(tgt_emb, tgt_len, None)

            value_chunks.append(values[1:])
            key_chunks.append(keys[:-1])
            mask_chunks.append(sequence_mask([n - 1 for n in tgt_len]))

        all_values = torch.cat(value_chunks, 0)
        all_keys = torch.cat(key_chunks, 0)
        all_masks = torch.cat(mask_chunks, 1)

        src_context, enc_hidden = self.encoder_src(src_embedded, src_lengths,
                                                   None)
        src_mask = sequence_mask(src_lengths)

        return (all_values, enc_hidden, all_keys, all_masks, src_context,
                src_mask)
Example #5
0
    def forward(self, I_word, I_word_length, D_word, D_word_length,
                ref_tgt_inputs, ref_tgt_lengths):
        """Attend from the encoded reference over the I and D word sets.

        The reference target is encoded with ``encoder_ref``; its final
        hidden state (leading dim squeezed away) is the attention query.
        The embedded ``I_word`` / ``D_word`` sequences are the attention
        memories for ``attention_src`` / ``attention_ref`` respectively.

        Returns:
            A pair: the two attention outputs concatenated on dim 1, and the
            per-step reference encoder outputs.
        """
        ref_embedded = self.dec_embedding(ref_tgt_inputs)
        enc_outputs, final_state = self.encoder_ref(ref_embedded,
                                                    ref_tgt_lengths, None)
        i_embedded = self.enc_embedding(I_word)
        d_embedded = self.enc_embedding(D_word)
        query = final_state.squeeze(0)

        # Dropout is applied in the order I, D, query; the memories are then
        # transposed (dims 0 and 1 swapped) for the attention modules.
        i_memory = self.dropout(i_embedded).transpose(0, 1).contiguous()
        d_memory = self.dropout(d_embedded).transpose(0, 1).contiguous()
        query = self.dropout(query)

        i_attended, _ = self.attention_src(
            query, i_memory, mask=sequence_mask(I_word_length))
        d_attended, _ = self.attention_ref(
            query, d_memory, mask=sequence_mask(D_word_length))

        return torch.cat([i_attended, d_attended], 1), enc_outputs
Example #6
0
 def do_mask_and_clean(self, preds, ref_tgt_inputs, ref_tgt_lengths):
     """Blank out rejected reference tokens and collapse the blanks.

     Positions whose prediction is below 0.5 are overwritten with id 0.
     Each sequence is then cut to its true length, every run of
     consecutive 0s is reduced to a single 0, and the result is wrapped
     as ``[1] + tokens + [2]`` (presumably BOS/EOS ids -- confirm against
     the vocabulary).

     NOTE: ``ref_tgt_inputs`` is modified in place (``masked_fill_`` on
     its ``.data``); callers must not rely on its original content
     afterwards.

     Args:
         preds: per-position scores, same layout as ``ref_tgt_inputs``.
         ref_tgt_inputs: token ids; dims 0 and 1 are swapped before being
             turned into per-sequence lists.
         ref_tgt_lengths: true length of each sequence.

     Returns:
         Whatever ``ListsToTensor`` produces for the cleaned id lists.
     """
     keep = torch.ge(preds, 0.5)
     # ``1 - keep`` selects the positions to drop.
     ref_tgt_inputs.data.masked_fill_(1 - keep.data, 0)
     rows = ref_tgt_inputs.transpose(0, 1).data.tolist()

     cleaned = []
     for row, length in zip(rows, ref_tgt_lengths):
         collapsed = []
         prev_blank = False
         for token in row[:length]:
             blank = token == 0
             # Keep the first 0 of a run, skip the following ones.
             if not (blank and prev_blank):
                 collapsed.append(token)
             prev_blank = blank
         cleaned.append([1] + collapsed + [2])
     return ListsToTensor(cleaned)
Example #7
0
    def update(self, batch):
        """Run one training step on ``batch`` and return its loss statistics.

        The forward call is chosen from the class name of ``self.model``,
        because each supported model expects a different argument list.
        For ``jointTemplateResponseGenerator`` an extra class-balanced binary
        cross-entropy loss on the mask predictions is back-propagated first
        (with ``retain_graph=True``, so the graph is still usable by the
        sharded loss afterwards).

        Args:
            batch: object exposing ``src``, ``tgt``, ``ref_src``, ``ref_tgt``
                and ``score``; ``I``, ``D`` and ``mask`` are read only for
                the model types that need them.

        Returns:
            The statistics object returned by
            ``self.train_loss.sharded_compute_loss``.

        Raises:
            ValueError: if the model class is not one of the supported types
                (previously this surfaced as a NameError on ``outputs``).
        """
        self.model.zero_grad()
        src_inputs, src_lengths = batch.src
        # Drop the last target step to form the decoder inputs.
        tgt_inputs = batch.tgt[0][:-1]

        ref_src_inputs, ref_src_lengths = batch.ref_src
        ref_tgt_inputs, ref_tgt_lengths = batch.ref_tgt

        model_type = self.model.__class__.__name__
        if model_type == "vanillaNMTModel":
            outputs, _ = self.model(src_inputs, tgt_inputs, src_lengths)
        elif model_type in ("bivanillaNMTModel", "responseGenerator"):
            outputs, _ = self.model(src_inputs, tgt_inputs, ref_tgt_inputs,
                                    src_lengths, ref_tgt_lengths)
        elif model_type == "refNMTModel":
            outputs, _, _ = self.model(src_inputs, tgt_inputs, ref_src_inputs,
                                       ref_tgt_inputs, src_lengths,
                                       ref_src_lengths, ref_tgt_lengths)
        elif model_type == "evNMTModel":
            I_word, I_word_length = batch.I
            D_word, D_word_length = batch.D
            outputs, _ = self.model(src_inputs, tgt_inputs, src_lengths,
                                    I_word, I_word_length, D_word,
                                    D_word_length, ref_tgt_inputs,
                                    ref_tgt_lengths)
        elif model_type == "tem_resNMTModel":
            I_word, I_word_length = batch.I
            D_word, D_word_length = batch.D
            outputs, _ = self.model(I_word, I_word_length, D_word,
                                    D_word_length, ref_tgt_inputs,
                                    ref_tgt_lengths, src_inputs, tgt_inputs,
                                    src_lengths)
        elif model_type == "jointTemplateResponseGenerator":
            I_word, I_word_length = batch.I
            D_word, D_word_length = batch.D
            target, _ = batch.mask

            outputs, _, preds = self.model(I_word, I_word_length, D_word,
                                           D_word_length, ref_tgt_inputs,
                                           ref_tgt_lengths, src_inputs,
                                           tgt_inputs, src_lengths)
            mask = sequence_mask(ref_tgt_lengths).transpose(0, 1)
            tot = mask.float().sum()

            # Class-balancing weights: positives and negatives each receive
            # half of the total weight mass.
            reserved = target.float().sum()
            w1 = (0.5 * tot / reserved).data[0]
            w2 = (0.5 * tot / (tot - reserved)).data[0]
            weight = torch.FloatTensor(mask.size()).zero_().cuda()
            weight.masked_fill_(mask, w2)
            weight.masked_fill_(torch.eq(target, 1).data, w1)

            loss = F.binary_cross_entropy(preds, target.float(), weight)
            loss.backward(retain_graph=True)
        else:
            raise ValueError(
                "update() does not support model type %r" % model_type)

        if batch.score is not None:
            score = Variable(torch.FloatTensor(batch.score)).cuda()
        else:
            score = None

        stats = self.train_loss.sharded_compute_loss(batch,
                                                     outputs,
                                                     self.shard_size,
                                                     weight=score)

        self.optim.step()
        return stats
Example #8
0
    def update(self, batch, optim, update_what, sample_func=None, critic=None):
        """Run one update step driven by per-example scores.

        The template generator first predicts which reference tokens to keep;
        ``do_mask_and_clean`` turns those predictions into a template. A score
        per example is then obtained in one of two ways:

        * ``sample_func is None``: ``self.score_batch`` scores the gold target
          given the template; the scores are centred by subtracting their mean.
        * otherwise: ``sample_func`` samples a response (and its ``logp``),
          the source / target / response / shuffled-target texts are printed,
          and ``critic`` is called; its second output becomes the scores.

        If ``update_what == "R"`` the loss is ``-(logp.sum(0) * exp(scores))``
        averaged over the batch. Otherwise the loss is a binary cross-entropy
        of ``preds`` against its own thresholded values, weighted by the
        sequence mask times each example's score.

        NOTE(review): ``logp`` is only assigned in the ``sample_func`` branch,
        so ``update_what == "R"`` with ``sample_func=None`` fails with an
        unbound-name error. ``critic`` is likewise required whenever
        ``sample_func`` is given.

        Returns:
            A freshly constructed, empty ``Statistics`` object.
        """
        optim.optimizer.zero_grad()
        src_inputs, src_lengths = batch.src
        tgt_inputs, tgt_lengths = batch.tgt
        ref_src_inputs, ref_src_lengths = batch.ref_src
        ref_tgt_inputs, ref_tgt_lengths = batch.ref_tgt
        I_word, I_word_length = batch.I
        D_word, D_word_length = batch.D
        # Predict a keep-score per reference token; ``ev`` is returned too but
        # is not used below.
        preds, ev = self.model.template_generator(I_word,
                                                  I_word_length,
                                                  D_word,
                                                  D_word_length,
                                                  ref_tgt_inputs,
                                                  ref_tgt_lengths,
                                                  return_ev=True)
        preds = preds.squeeze(2)
        # NOTE: do_mask_and_clean is called on ref_tgt_inputs; whether it
        # mutates that tensor depends on its implementation (not shown here).
        template, template_lengths = self.model.template_generator.do_mask_and_clean(
            preds, ref_tgt_inputs, ref_tgt_lengths)

        if sample_func is None:
            # Score the gold target under the generated template.
            outputs, scores = self.score_batch(src_inputs,
                                               tgt_inputs,
                                               None,
                                               template,
                                               src_lengths,
                                               tgt_lengths,
                                               None,
                                               template_lengths,
                                               normalization=True)
            # Centre the scores on their batch mean.
            avg = sum(scores) / len(scores)
            scores = [t - avg for t in scores]
        else:
            # Sample a response from the response generator.
            (response, response_length), logp = sample_func(
                self.model.response_generator,
                src_inputs,
                None,
                template,
                src_lengths,
                None,
                template_lengths,
                max_len=20,
                show_sample=False)
            enc_embedding = self.model.response_generator.enc_embedding
            dec_embedding = self.model.response_generator.dec_embedding
            # Build a randomly permuted copy of the targets (shuffled along
            # dim 1) to pass to the critic alongside the real ones.
            inds = np.arange(len(tgt_lengths))
            np.random.shuffle(inds)
            inds_tensor = Variable(torch.LongTensor(inds).cuda())
            random_tgt = tgt_inputs.index_select(1, inds_tensor)
            random_tgt_len = [tgt_lengths[i] for i in inds]

            vocab = self.tgt_vocab
            vocab_src = self.src_vocab
            w = src_inputs.t().data.tolist()
            x = tgt_inputs.t().data.tolist()
            y = response.t().data.tolist()
            z = random_tgt.t().data.tolist()
            # Debug dump: source ||||| target ||||| response ||||| random
            # target. The first and last token of the three target-side
            # sequences are stripped before printing.
            for tw, tx, ty, tz, ww, xx, yy, zz in zip(w, x, y, z, src_lengths,
                                                      tgt_lengths,
                                                      response_length,
                                                      random_tgt_len):
                print(' '.join([vocab_src.itos[tt]
                                for tt in tw[:ww]]), '|||||',
                      ' '.join([vocab.itos[tt]
                                for tt in tx[1:xx - 1]]), '|||||',
                      ' '.join([vocab.itos[tt]
                                for tt in ty[1:yy - 1]]), '|||||',
                      ' '.join([vocab.itos[tt] for tt in tz[1:zz - 1]]))

            # Only the critic's second output is used as the score; x and z
            # are ignored.
            x, y, z = critic(enc_embedding(src_inputs), src_lengths,
                             dec_embedding(tgt_inputs), tgt_lengths,
                             dec_embedding(response), response_length,
                             dec_embedding(random_tgt), random_tgt_len)
            scores = y.data.tolist()

        if update_what == "R":
            # Loss is the summed log-probability weighted by exp(score).
            logp = logp.sum(0)
            scores = torch.FloatTensor(scores)
            scores = torch.exp(Variable(scores.cuda()))
            #print (logp, scores)
            loss = -(logp * scores).mean()
            print(loss.data[0])
            loss.backward()
            optim.step()
            stats = Statistics()
            return stats

        # Otherwise: use the thresholded predictions as pseudo-targets and
        # weight each example's valid positions by its score.
        ans = torch.ge(preds, 0.5)
        mask = sequence_mask(ref_tgt_lengths).transpose(0, 1)
        weight = torch.FloatTensor(mask.size()).zero_().cuda()
        weight.masked_fill_(mask, 1.)

        # Column i of the weight matrix is scaled by scores[i].
        for i, x in enumerate(scores):
            weight[:, i] *= x

        loss = F.binary_cross_entropy(preds, Variable(ans.float().data),
                                      weight)

        stats = Statistics(
        )  #self.train_loss.monolithic_compute_loss(batch, outputs)
        loss.backward()
        optim.step()
        return stats
Example #9
0
 def encode(self, input, lengths=None, hidden=None):
     """Embed and encode ``input``, and build its length mask.

     ``hidden`` is accepted but not used; the encoder is always called
     with ``None`` as its third argument.

     Returns:
         (enc_outputs, enc_hidden, enc_mask)
     """
     embedded = self.enc_embedding(input)
     outputs, final_state = self.encoder(embedded, lengths, None)
     return outputs, final_state, sequence_mask(lengths)
def train_model(opt, model, train_iter, valid_iter, fields, optim,
                lr_scheduler, start_epoch_at):
    """Train the template (mask) generator and validate after every epoch.

    Epochs run over ``range(start_epoch_at + 1, opt.num_train_epochs)``.
    Each training step minimises a class-balanced binary cross-entropy
    between the model's per-token predictions and ``batch.mask``. After each
    epoch the validation loss, accuracy, precision and recall are printed,
    the learning-rate scheduler is stepped once ``step_epoch`` reaches
    ``opt.start_decay_at``, and the model is saved via ``save_per_epoch``.

    Metrics whose denominator is zero (e.g. no token predicted as kept, or an
    empty validation iterator) are reported as ``nan`` instead of aborting
    the run with a ZeroDivisionError.

    ``fields`` is accepted for interface compatibility but not used.
    """

    def _ratio(numerator, denominator):
        # An undefined ratio is reported as nan rather than raising.
        return numerator / denominator if denominator else float("nan")

    sys.stdout.flush()
    for step_epoch in range(start_epoch_at + 1, opt.num_train_epochs):
        for batch in train_iter:
            model.zero_grad()
            I_word, I_word_length = batch.I
            D_word, D_word_length = batch.D
            target, _ = batch.mask
            ref_tgt_inputs, ref_tgt_lengths = batch.ref_tgt
            preds = model(I_word, I_word_length, D_word, D_word_length,
                          ref_tgt_inputs, ref_tgt_lengths)
            preds = preds.squeeze(2)
            mask = sequence_mask(ref_tgt_lengths).transpose(0, 1)
            tot = mask.float().sum()

            # Class-balancing weights: positives and negatives each receive
            # half of the total weight mass.
            n_positive = target.float().sum()
            w1 = (0.5 * tot / n_positive).data[0]
            w2 = (0.5 * tot / (tot - n_positive)).data[0]
            weight = torch.FloatTensor(mask.size()).zero_().cuda()
            weight.masked_fill_(mask, w2)
            weight.masked_fill_(torch.eq(target, 1).data, w1)

            loss = F.binary_cross_entropy(preds, target.float(), weight)
            loss.backward()
            optim.step()

        # ---- validation ----
        loss = 0.
        acc = 0.
        ntokens = 0.
        # predicted positives, gold positives, true positives
        reserved, targeted, received = 0., 0., 0.
        model.eval()
        for batch in valid_iter:
            I_word, I_word_length = batch.I
            D_word, D_word_length = batch.D
            target, _ = batch.mask
            target = target.float()
            ref_tgt_inputs, ref_tgt_lengths = batch.ref_tgt
            preds = model(I_word, I_word_length, D_word, D_word_length,
                          ref_tgt_inputs, ref_tgt_lengths)
            preds = preds.squeeze(2)
            mask = sequence_mask(ref_tgt_lengths).transpose(0, 1).float()
            loss += F.binary_cross_entropy(preds,
                                           target,
                                           mask,
                                           size_average=False).data[0]
            ans = torch.ge(preds, 0.5).float()
            acc += (torch.eq(ans, target).float().data * mask).sum()
            received += (ans.data * target.data * mask).sum()
            reserved += (ans.data * mask).sum()
            targeted += (target.data * mask).sum()
            ntokens += mask.sum()
        print("epoch: ", step_epoch, "valid_loss: ", _ratio(loss, ntokens),
              "valid_acc: ", _ratio(acc, ntokens), "precision: ",
              _ratio(received, reserved), "recall: ",
              _ratio(received, targeted))

        if step_epoch >= opt.start_decay_at:
            lr_scheduler.step()
        model.train()
        save_per_epoch(model, step_epoch, opt)
        sys.stdout.flush()
def main():
    """Command-line entry point for the template generator.

    With ``-mode test`` a checkpoint is loaded, the test file is scored,
    per-batch results are written to ``-out_file`` and aggregate metrics are
    printed. Otherwise the model is built (or restored) and trained with
    ``train_model``.
    """

    def _str2bool(text):
        # argparse's ``type=bool`` turns every non-empty string -- including
        # "False" -- into True, so boolean flags are parsed explicitly.
        if isinstance(text, bool):
            return text
        lowered = text.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "0", ""):
            return False
        raise argparse.ArgumentTypeError(
            "expected a boolean value, got %r" % text)

    def _ratio(numerator, denominator):
        # An undefined ratio is reported as nan rather than raising.
        return numerator / denominator if denominator else float("nan")

    parser = argparse.ArgumentParser()
    parser.add_argument("-config", type=str)
    parser.add_argument("-nmt_dir", type=str)
    parser.add_argument('-gpuid', default=[0], nargs='+', type=int)
    parser.add_argument("-valid_file", type=str)
    parser.add_argument("-train_file", type=str)
    parser.add_argument("-test_file", type=str)
    parser.add_argument("-model", type=str)
    parser.add_argument("-src_vocab", type=str)
    parser.add_argument("-tgt_vocab", type=str)
    parser.add_argument("-mode", type=str)
    parser.add_argument("-out_file", type=str)
    parser.add_argument("-stop_words", type=str, default=None)
    parser.add_argument("-for_train", type=_str2bool, default=True)
    args = parser.parse_args()
    opt = utils.load_hparams(args.config)

    if opt.random_seed > 0:
        random.seed(opt.random_seed)
        torch.manual_seed(opt.random_seed)

    fields = dict()
    vocab_src = Vocab(args.src_vocab, noST=True)
    vocab_tgt = Vocab(args.tgt_vocab)
    fields['src'] = vocab_wrapper(vocab_src)
    fields['tgt'] = vocab_wrapper(vocab_tgt)

    if args.mode == "test":
        model = nmt.model_helper.create_template_generator(opt, fields)
        if use_cuda:
            model = model.cuda()
        model.load_checkpoint(args.model)
        model.eval()
        test = Data_Loader(args.test_file,
                           opt.train_batch_size,
                           train=False,
                           mask_end=True,
                           stop_words=args.stop_words)
        loss, acc, ntokens = 0., 0., 0.
        # predicted positives, gold positives, true positives
        reserved, targeted, received = 0., 0., 0.
        # The context manager closes the output file even if scoring fails.
        with open(args.out_file, 'w') as fo:
            for batch in test:
                I_word, I_word_length = batch.I
                D_word, D_word_length = batch.D
                target, _ = batch.mask
                target = target.float()
                ref_tgt_inputs, ref_tgt_lengths = batch.ref_tgt
                preds = model(I_word, I_word_length, D_word, D_word_length,
                              ref_tgt_inputs, ref_tgt_lengths)
                preds = preds.squeeze(2)
                mask = sequence_mask(ref_tgt_lengths).transpose(0, 1).float()
                loss += F.binary_cross_entropy(preds,
                                               target,
                                               mask,
                                               size_average=False).data[0]
                ans = torch.ge(preds, 0.5).float()
                output_results(ans, batch, fo, vocab_tgt, args.for_train)
                acc += (torch.eq(ans, target).float().data * mask).sum()
                received += (ans.data * target.data * mask).sum()
                reserved += (ans.data * mask).sum()
                targeted += (target.data * mask).sum()
                ntokens += mask.sum()
        print("test_loss: ", _ratio(loss, ntokens), "test_acc: ",
              _ratio(acc, ntokens), "precision:", _ratio(received, reserved),
              "recall: ", _ratio(received, targeted), "leave percentage",
              _ratio(targeted, ntokens))
        return

    train = Data_Loader(args.train_file,
                        opt.train_batch_size,
                        mask_end=True,
                        stop_words=args.stop_words)
    valid = Data_Loader(args.valid_file,
                        opt.train_batch_size,
                        mask_end=True,
                        stop_words=args.stop_words)

    # Build model.
    model, start_epoch_at = build_or_load_model(args, opt, fields)
    check_save_model_path(args, opt)

    # Build optimizer.
    optim = build_optim(model, opt)
    lr_scheduler = build_lr_scheduler(optim.optimizer, opt)

    if use_cuda:
        model = model.cuda()

    # Do training.
    train_model(opt, model, train, valid, fields, optim, lr_scheduler,
                start_epoch_at)
    print("DONE")