def encode(self, src_inputs, template_inputs, src_lengths, template_lengths, ev=None):
    # Encode the source sequence.
    emb_src = self.enc_embedding(src_inputs)
    src_contexts, enc_hidden = self.encoder_src(emb_src, src_lengths, None)
    # Optionally map the edit vector to a latent distribution via the bridge.
    if ev is not None and self.bridge is not None:
        dist = self.bridge(ev)
    else:
        dist = None
    # Encode every template and collect the contexts and padding masks.
    ref_contexts, ref_mask = [], []
    for template_input, template_length in zip(template_inputs, template_lengths):
        emb_ref = self.dec_embedding(template_input)
        ref_context, _ = self.encoder_ref(emb_ref, template_length)
        ref_contexts.append(ref_context)
        ref_mask.append(sequence_mask(template_length))
    ref_contexts = torch.cat(ref_contexts, 0)  # concatenate along the time axis
    ref_mask = torch.cat(ref_mask, 1)
    src_mask = sequence_mask(src_lengths)
    return ref_contexts, enc_hidden, ref_mask, dist, src_contexts, src_mask
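
# NOTE: `sequence_mask` is a helper assumed to be defined elsewhere in this
# repo (it is used throughout the functions below but not shown here). A
# minimal, hypothetical sketch of the behaviour the call sites rely on --
# a (batch, max_len) byte/bool mask that is 1 inside each sequence and 0 on
# padding -- written against the plain PyTorch API:
def _sequence_mask_sketch(lengths, max_len=None):
    import torch
    if not torch.is_tensor(lengths):
        lengths = torch.LongTensor(lengths)
    max_len = max_len or int(lengths.max())
    positions = torch.arange(0, max_len).long().unsqueeze(0)   # (1, max_len)
    return positions.expand(len(lengths), max_len) < lengths.unsqueeze(1)
# e.g. _sequence_mask_sketch([3, 1]) ->
#   [[1, 1, 1],
#    [1, 0, 0]]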

def encode(self, I_word, I_word_length, D_word, D_word_length, ref_tgt_inputs,
           ref_tgt_lengths, src_inputs, src_lengths):
    # Edit vector and reference encodings from the edit-vector generator.
    ev, enc_outputs = self.ev_generator(I_word, I_word_length, D_word, D_word_length,
                                        ref_tgt_inputs, ref_tgt_lengths)
    ev = self.masker_dropout(ev)
    enc_outputs = self.masker_dropout(enc_outputs)
    _, _dim = ev.size()
    _len, _batch, _ = enc_outputs.size()
    dist = self.bridge(ev) if self.bridge is not None else None
    # Broadcast the edit vector along the time axis and predict, for every
    # reference position, whether the token should be kept.
    ev = ev.unsqueeze(0).expand(_len, _batch, _dim)
    preds = self.masker(torch.cat([ev, enc_outputs], 2)).squeeze(2)
    # Encode the source sequence.
    emb_src = self.enc_embedding(src_inputs)
    src_contexts, enc_hidden = self.encoder_src(emb_src, src_lengths, None)
    ref_mask = sequence_mask(ref_tgt_lengths)
    src_mask = sequence_mask(src_lengths)
    return enc_outputs, enc_hidden, ref_mask, dist, src_contexts, src_mask, preds
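
# Shape illustration (not the repo's code): the masker above scores each
# reference position by pairing it with the same edit vector -- ev of shape
# (batch, dim) is broadcast along the time axis and concatenated with the
# encoder states of shape (len, batch, hidden) before a position-wise binary
# classifier. A self-contained sketch with dummy tensors and a stand-in
# `masker` (a single Linear + Sigmoid applied over the last dimension,
# modern PyTorch API; the real module may differ):
def _masker_shape_sketch():
    import torch
    import torch.nn as nn
    _len, _batch, _hid, _dim = 7, 4, 16, 8
    enc_outputs = torch.randn(_len, _batch, _hid)
    ev = torch.randn(_batch, _dim)
    ev_tiled = ev.unsqueeze(0).expand(_len, _batch, _dim)            # (len, batch, dim)
    masker = nn.Sequential(nn.Linear(_dim + _hid, 1), nn.Sigmoid())  # stand-in
    preds = masker(torch.cat([ev_tiled, enc_outputs], 2)).squeeze(2) # (len, batch)
    return preds.size()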

def encode(self, src_inputs, src_lengths, I_word, I_word_length, D_word,
           D_word_length, ref_tgt_inputs, ref_tgt_lengths):
    emb_src = self.enc_embedding(src_inputs)
    _, enc_hidden = self.encoder_src(emb_src, src_lengths, None)
    ev, ref_contexts = self.ev_generator(I_word, I_word_length, D_word, D_word_length,
                                         ref_tgt_inputs, ref_tgt_lengths)
    dist = self.bridge(ev)
    ref_mask = sequence_mask(ref_tgt_lengths)
    return ref_contexts, enc_hidden, ref_mask, dist

def encode(self, src_inputs, ref_src_inputs, ref_tgt_inputs, src_lengths,
           ref_src_lengths, ref_tgt_lengths, hidden=None):
    emb_src = self.enc_embedding(src_inputs)
    embs_ref_src = [self.enc_embedding(ref_src_input) for ref_src_input in ref_src_inputs]
    embs_ref_tgt = [self.dec_embedding(ref_tgt_input) for ref_tgt_input in ref_tgt_inputs]
    ref_values, ref_keys, ref_mask = [], [], []
    for emb_ref_src, emb_ref_tgt, ref_src_length, ref_tgt_length in zip(
            embs_ref_src, embs_ref_tgt, ref_src_lengths, ref_tgt_lengths):
        # Encode the reference source, then run the reference decoder over the
        # reference target to obtain attention keys.
        ref_src_context, enc_ref_hidden = self.encoder_src(emb_ref_src, ref_src_length, None)
        ref_src_mask = sequence_mask(ref_src_length)
        ref_key, _, _ = self.decoder_ref(emb_ref_tgt, ref_src_context,
                                         enc_ref_hidden, ref_src_mask)
        ref_value, _ = self.encoder_ref(emb_ref_tgt, ref_tgt_length, None)
        # Shift by one position: ref_key[:-1] is paired with ref_value[1:],
        # so the mask is built from lengths reduced by one.
        ref_msk = sequence_mask([x - 1 for x in ref_tgt_length])
        ref_values.append(ref_value[1:])
        ref_keys.append(ref_key[:-1])
        ref_mask.append(ref_msk)
    ref_values = torch.cat(ref_values, 0)
    ref_keys = torch.cat(ref_keys, 0)
    ref_mask = torch.cat(ref_mask, 1)
    src_context, enc_hidden = self.encoder_src(emb_src, src_lengths, None)
    src_mask = sequence_mask(src_lengths)
    return ref_values, enc_hidden, ref_keys, ref_mask, src_context, src_mask
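
# Index illustration for the shift above (not model code): for a reference
# y_1 ... y_T, ref_key[:-1] holds the decoder state after reading y_1..y_t,
# while ref_value[1:] holds the encoding of y_{t+1}, so a query that matches
# key t retrieves what the reference said next. A toy stand-alone version:
def _key_value_shift_sketch(states, encodings):
    # states, encodings: lists of per-position vectors for one reference
    keys = states[:-1]      # positions 1 .. T-1
    values = encodings[1:]  # positions 2 .. T
    return list(zip(keys, values))
# _key_value_shift_sketch(['s1', 's2', 's3'], ['e1', 'e2', 'e3'])
#   -> [('s1', 'e2'), ('s2', 'e3')]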

def forward(self, I_word, I_word_length, D_word, D_word_length, ref_tgt_inputs, ref_tgt_lengths):
    # Encode the retrieved reference response.
    enc_outputs, enc_hidden = self.encoder_ref(
        self.dec_embedding(ref_tgt_inputs), ref_tgt_lengths, None)
    # Embed the I and D word sets.
    I_context = self.enc_embedding(I_word)
    D_context = self.enc_embedding(D_word)
    enc_hidden = enc_hidden.squeeze(0)
    I_context = self.dropout(I_context)
    D_context = self.dropout(D_context)
    enc_hidden = self.dropout(enc_hidden)
    I_context = I_context.transpose(0, 1).contiguous()
    D_context = D_context.transpose(0, 1).contiguous()
    # Attend over the two word sets with the reference encoding as the query,
    # then concatenate the attended vectors into the edit vector.
    I, _ = self.attention_src(enc_hidden, I_context, mask=sequence_mask(I_word_length))
    D, _ = self.attention_ref(enc_hidden, D_context, mask=sequence_mask(D_word_length))
    return torch.cat([I, D], 1), enc_outputs

def do_mask_and_clean(self, preds, ref_tgt_inputs, ref_tgt_lengths):
    # Keep a token wherever the masker predicts >= 0.5, zero out the rest.
    ans = torch.ge(preds, 0.5)
    ref_tgt_inputs.data.masked_fill_(1 - ans.data, 0)
    y = ref_tgt_inputs.transpose(0, 1).data.tolist()
    data = [z[:l] for z, l in zip(y, ref_tgt_lengths)]
    # Collapse runs of dropped positions (0) into a single blank and wrap each
    # cleaned template with BOS (1) and EOS (2).
    new_data = []
    for z in data:
        new_z = []
        iszero = False
        for w in z:
            if iszero and w == 0:
                continue
            new_z.append(w)
            iszero = (w == 0)
        new_data.append([1] + new_z + [2])
    return ListsToTensor(new_data)
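
# Illustration (pure Python, independent of the model): the cleaning loop
# above keeps every surviving token, merges a run of dropped positions into a
# single 0 "blank", and wraps the result with BOS (1) and EOS (2). A
# hypothetical stand-alone version of that inner loop:
def _collapse_blanks_sketch(tokens, bos=1, eos=2):
    cleaned, prev_was_blank = [], False
    for w in tokens:
        if prev_was_blank and w == 0:
            continue  # merge consecutive blanks
        cleaned.append(w)
        prev_was_blank = (w == 0)
    return [bos] + cleaned + [eos]
# _collapse_blanks_sketch([5, 0, 0, 0, 7, 9, 0, 4]) -> [1, 5, 0, 7, 9, 0, 4, 2]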

def update(self, batch):
    self.model.zero_grad()
    src_inputs, src_lengths = batch.src
    tgt_inputs = batch.tgt[0][:-1]
    ref_src_inputs, ref_src_lengths = batch.ref_src
    ref_tgt_inputs, ref_tgt_lengths = batch.ref_tgt

    model_type = self.model.__class__.__name__
    if model_type == "vanillaNMTModel":
        outputs, attn = self.model(src_inputs, tgt_inputs, src_lengths)
    elif model_type == "bivanillaNMTModel":
        outputs, attn = self.model(src_inputs, tgt_inputs, ref_tgt_inputs,
                                   src_lengths, ref_tgt_lengths)
    elif model_type == "refNMTModel":
        outputs, attn, outputs_f = self.model(src_inputs, tgt_inputs, ref_src_inputs,
                                              ref_tgt_inputs, src_lengths,
                                              ref_src_lengths, ref_tgt_lengths)
    elif model_type == "evNMTModel":
        I_word, I_word_length = batch.I
        D_word, D_word_length = batch.D
        outputs, attn = self.model(src_inputs, tgt_inputs, src_lengths,
                                   I_word, I_word_length, D_word, D_word_length,
                                   ref_tgt_inputs, ref_tgt_lengths)
    elif model_type == "responseGenerator":
        outputs, attn = self.model(src_inputs, tgt_inputs, ref_tgt_inputs,
                                   src_lengths, ref_tgt_lengths)
    elif model_type == "tem_resNMTModel":
        I_word, I_word_length = batch.I
        D_word, D_word_length = batch.D
        outputs, attn = self.model(I_word, I_word_length, D_word, D_word_length,
                                   ref_tgt_inputs, ref_tgt_lengths,
                                   src_inputs, tgt_inputs, src_lengths)
    elif model_type == "jointTemplateResponseGenerator":
        I_word, I_word_length = batch.I
        D_word, D_word_length = batch.D
        target, _ = batch.mask
        outputs, attn, preds = self.model(I_word, I_word_length, D_word, D_word_length,
                                          ref_tgt_inputs, ref_tgt_lengths,
                                          src_inputs, tgt_inputs, src_lengths)
        # Class-balanced weights for the binary keep/drop loss: positions with
        # target == 1 get w1, the remaining valid positions get w2, so both
        # classes contribute roughly half of the total weight.
        mask = sequence_mask(ref_tgt_lengths).transpose(0, 1)
        tot = mask.float().sum()
        reserved = target.float().sum()
        w1 = (0.5 * tot / reserved).data[0]
        w2 = (0.5 * tot / (tot - reserved)).data[0]
        weight = torch.FloatTensor(mask.size()).zero_().cuda()
        weight.masked_fill_(mask, w2)
        weight.masked_fill_(torch.eq(target, 1).data, w1)
        loss = F.binary_cross_entropy(preds, target.float(), weight)
        loss.backward(retain_graph=True)

    if batch.score is not None:
        score = Variable(torch.FloatTensor(batch.score)).cuda()
    else:
        score = None
    stats = self.train_loss.sharded_compute_loss(batch, outputs, self.shard_size,
                                                 weight=score)
    self.optim.step()
    return stats
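
# The w1 / w2 weighting used for "jointTemplateResponseGenerator" above (and
# again in train_model below) rebalances the binary keep/drop loss so that the
# two classes contribute roughly equally however skewed the targets are. A
# toy numeric sketch of the same arithmetic (illustration only):
def _balanced_bce_weights_sketch(total_tokens, positives):
    negatives = total_tokens - positives
    w1 = 0.5 * total_tokens / positives   # applied to target == 1 positions
    w2 = 0.5 * total_tokens / negatives   # applied to target == 0 positions
    # each class now carries a total weight of 0.5 * total_tokens
    assert abs(w1 * positives - w2 * negatives) < 1e-6
    return w1, w2
# _balanced_bce_weights_sketch(100, 20) -> (2.5, 0.625)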

def update(self, batch, optim, update_what, sample_func=None, critic=None):
    optim.optimizer.zero_grad()
    src_inputs, src_lengths = batch.src
    tgt_inputs, tgt_lengths = batch.tgt
    ref_src_inputs, ref_src_lengths = batch.ref_src
    ref_tgt_inputs, ref_tgt_lengths = batch.ref_tgt
    I_word, I_word_length = batch.I
    D_word, D_word_length = batch.D

    # Template generator: predict, per reference token, whether to keep it,
    # then build the cleaned template.
    preds, ev = self.model.template_generator(I_word, I_word_length, D_word,
                                              D_word_length, ref_tgt_inputs,
                                              ref_tgt_lengths, return_ev=True)
    preds = preds.squeeze(2)
    template, template_lengths = self.model.template_generator.do_mask_and_clean(
        preds, ref_tgt_inputs, ref_tgt_lengths)

    if sample_func is None:
        # Score the gold response under the current response generator and use
        # the centred scores as rewards.
        outputs, scores = self.score_batch(src_inputs, tgt_inputs, None, template,
                                           src_lengths, tgt_lengths, None,
                                           template_lengths, normalization=True)
        avg = sum(scores) / len(scores)
        scores = [t - avg for t in scores]
    else:
        # Sample a response from the generator and score it with the critic.
        (response, response_length), logp = sample_func(
            self.model.response_generator, src_inputs, None, template,
            src_lengths, None, template_lengths, max_len=20, show_sample=False)
        enc_embedding = self.model.response_generator.enc_embedding
        dec_embedding = self.model.response_generator.dec_embedding
        # Shuffled gold responses serve as mismatched examples for the critic.
        inds = np.arange(len(tgt_lengths))
        np.random.shuffle(inds)
        inds_tensor = Variable(torch.LongTensor(inds).cuda())
        random_tgt = tgt_inputs.index_select(1, inds_tensor)
        random_tgt_len = [tgt_lengths[i] for i in inds]
        vocab = self.tgt_vocab
        vocab_src = self.src_vocab
        # Debug printout: source ||||| gold ||||| sample ||||| shuffled gold.
        w = src_inputs.t().data.tolist()
        x = tgt_inputs.t().data.tolist()
        y = response.t().data.tolist()
        z = random_tgt.t().data.tolist()
        for tw, tx, ty, tz, ww, xx, yy, zz in zip(w, x, y, z, src_lengths,
                                                  tgt_lengths, response_length,
                                                  random_tgt_len):
            print(' '.join([vocab_src.itos[tt] for tt in tw[:ww]]), '|||||',
                  ' '.join([vocab.itos[tt] for tt in tx[1:xx - 1]]), '|||||',
                  ' '.join([vocab.itos[tt] for tt in ty[1:yy - 1]]), '|||||',
                  ' '.join([vocab.itos[tt] for tt in tz[1:zz - 1]]))
        x, y, z = critic(enc_embedding(src_inputs), src_lengths,
                         dec_embedding(tgt_inputs), tgt_lengths,
                         dec_embedding(response), response_length,
                         dec_embedding(random_tgt), random_tgt_len)
        scores = y.data.tolist()

    if update_what == "R":
        # Update the response generator with a REINFORCE-style loss.
        logp = logp.sum(0)
        scores = torch.exp(Variable(torch.FloatTensor(scores).cuda()))
        loss = -(logp * scores).mean()
        print(loss.data[0])
        loss.backward()
        optim.step()
        return Statistics()

    # Otherwise update the template generator: treat its own thresholded
    # predictions as targets and weight each example by its score.
    ans = torch.ge(preds, 0.5)
    mask = sequence_mask(ref_tgt_lengths).transpose(0, 1)
    weight = torch.FloatTensor(mask.size()).zero_().cuda()
    weight.masked_fill_(mask, 1.)
    for i, s in enumerate(scores):
        weight[:, i] *= s
    loss = F.binary_cross_entropy(preds, Variable(ans.float().data), weight)
    loss.backward()
    optim.step()
    return Statistics()
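
# When update_what == "R" above, the response generator is trained with a
# REINFORCE-style objective: the summed log-probability of each sampled
# response is scaled by its (here exponentiated) reward and the negated mean
# is minimised. A minimal sketch of that loss on dummy tensors (modern
# PyTorch API, not the repo's exact code):
def _policy_gradient_loss_sketch(logp_per_sample, rewards):
    # logp_per_sample: (batch,) summed log-probs of the sampled responses
    # rewards:         (batch,) scalar reward per sample, treated as constant
    return -(logp_per_sample * rewards.detach()).mean()
# e.g. with logp = tensor([-3.0, -5.0]) and rewards = tensor([1.5, 0.2]),
# the higher-reward sample dominates the gradient.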

def encode(self, input, lengths=None, hidden=None):
    emb = self.enc_embedding(input)
    enc_outputs, enc_hidden = self.encoder(emb, lengths, None)
    enc_mask = sequence_mask(lengths)
    return enc_outputs, enc_hidden, enc_mask

def train_model(opt, model, train_iter, valid_iter, fields, optim, lr_scheduler, start_epoch_at):
    sys.stdout.flush()
    for step_epoch in range(start_epoch_at + 1, opt.num_train_epochs):
        # Training pass over the epoch.
        for batch in train_iter:
            model.zero_grad()
            I_word, I_word_length = batch.I
            D_word, D_word_length = batch.D
            target, _ = batch.mask
            ref_tgt_inputs, ref_tgt_lengths = batch.ref_tgt
            preds = model(I_word, I_word_length, D_word, D_word_length,
                          ref_tgt_inputs, ref_tgt_lengths)
            preds = preds.squeeze(2)
            # Class-balanced weights: kept (target == 1) and dropped positions
            # each contribute roughly half of the total loss weight.
            mask = sequence_mask(ref_tgt_lengths).transpose(0, 1)
            tot = mask.float().sum()
            reserved = target.float().sum()
            w1 = (0.5 * tot / reserved).data[0]
            w2 = (0.5 * tot / (tot - reserved)).data[0]
            weight = torch.FloatTensor(mask.size()).zero_().cuda()
            weight.masked_fill_(mask, w2)
            weight.masked_fill_(torch.eq(target, 1).data, w1)
            loss = F.binary_cross_entropy(preds, target.float(), weight)
            loss.backward()
            optim.step()

        # Validation pass: loss, accuracy, precision and recall over valid tokens.
        loss, acc, ntokens = 0., 0., 0.
        reserved, targeted, received = 0., 0., 0.
        model.eval()
        for batch in valid_iter:
            I_word, I_word_length = batch.I
            D_word, D_word_length = batch.D
            target, _ = batch.mask
            target = target.float()
            ref_tgt_inputs, ref_tgt_lengths = batch.ref_tgt
            preds = model(I_word, I_word_length, D_word, D_word_length,
                          ref_tgt_inputs, ref_tgt_lengths)
            preds = preds.squeeze(2)
            mask = sequence_mask(ref_tgt_lengths).transpose(0, 1).float()
            loss += F.binary_cross_entropy(preds, target, mask,
                                           size_average=False).data[0]
            ans = torch.ge(preds, 0.5).float()
            acc += (torch.eq(ans, target).float().data * mask).sum()
            received += (ans.data * target.data * mask).sum()   # true positives
            reserved += (ans.data * mask).sum()                 # predicted positives
            targeted += (target.data * mask).sum()              # gold positives
            ntokens += mask.sum()
        print("epoch: ", step_epoch,
              "valid_loss: ", loss / ntokens,
              "valid_acc: ", acc / ntokens,
              "precision: ", received / reserved,
              "recall: ", received / targeted)

        if step_epoch >= opt.start_decay_at:
            lr_scheduler.step()
        model.train()
        save_per_epoch(model, step_epoch, opt)
        sys.stdout.flush()

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("-config", type=str)
    parser.add_argument("-nmt_dir", type=str)
    parser.add_argument('-gpuid', default=[0], nargs='+', type=int)
    parser.add_argument("-valid_file", type=str)
    parser.add_argument("-train_file", type=str)
    parser.add_argument("-test_file", type=str)
    parser.add_argument("-model", type=str)
    parser.add_argument("-src_vocab", type=str)
    parser.add_argument("-tgt_vocab", type=str)
    parser.add_argument("-mode", type=str)
    parser.add_argument("-out_file", type=str)
    parser.add_argument("-stop_words", type=str, default=None)
    # Caveat: argparse's type=bool treats any non-empty string as True.
    parser.add_argument("-for_train", type=bool, default=True)
    args = parser.parse_args()

    opt = utils.load_hparams(args.config)
    if opt.random_seed > 0:
        random.seed(opt.random_seed)
        torch.manual_seed(opt.random_seed)

    fields = dict()
    vocab_src = Vocab(args.src_vocab, noST=True)
    vocab_tgt = Vocab(args.tgt_vocab)
    fields['src'] = vocab_wrapper(vocab_src)
    fields['tgt'] = vocab_wrapper(vocab_tgt)

    if args.mode == "test":
        # Evaluate a trained template generator and write its outputs.
        model = nmt.model_helper.create_template_generator(opt, fields)
        if use_cuda:
            model = model.cuda()
        model.load_checkpoint(args.model)
        model.eval()
        test = Data_Loader(args.test_file, opt.train_batch_size, train=False,
                           mask_end=True, stop_words=args.stop_words)
        fo = open(args.out_file, 'w')
        loss, acc, ntokens = 0., 0., 0.
        reserved, targeted, received = 0., 0., 0.
        for batch in test:
            I_word, I_word_length = batch.I
            D_word, D_word_length = batch.D
            target, _ = batch.mask
            target = target.float()
            ref_tgt_inputs, ref_tgt_lengths = batch.ref_tgt
            preds = model(I_word, I_word_length, D_word, D_word_length,
                          ref_tgt_inputs, ref_tgt_lengths)
            preds = preds.squeeze(2)
            mask = sequence_mask(ref_tgt_lengths).transpose(0, 1).float()
            loss += F.binary_cross_entropy(preds, target, mask,
                                           size_average=False).data[0]
            ans = torch.ge(preds, 0.5).float()
            output_results(ans, batch, fo, vocab_tgt, args.for_train)
            acc += (torch.eq(ans, target).float().data * mask).sum()
            received += (ans.data * target.data * mask).sum()
            reserved += (ans.data * mask).sum()
            targeted += (target.data * mask).sum()
            ntokens += mask.sum()
        print("test_loss: ", loss / ntokens,
              "test_acc: ", acc / ntokens,
              "precision:", received / reserved,
              "recall: ", received / targeted,
              "leave percentage", targeted / ntokens)
        fo.close()
        return

    # Training mode.
    train = Data_Loader(args.train_file, opt.train_batch_size, mask_end=True,
                        stop_words=args.stop_words)
    valid = Data_Loader(args.valid_file, opt.train_batch_size, mask_end=True,
                        stop_words=args.stop_words)

    # Build model.
    model, start_epoch_at = build_or_load_model(args, opt, fields)
    check_save_model_path(args, opt)

    # Build optimizer and learning-rate scheduler.
    optim = build_optim(model, opt)
    lr_scheduler = build_lr_scheduler(optim.optimizer, opt)

    if use_cuda:
        model = model.cuda()

    # Do training.
    train_model(opt, model, train, valid, fields, optim, lr_scheduler, start_epoch_at)
    print("DONE")