def forward(self, input_sequence, length):
    batch_size = input_sequence.size(0)
    sorted_lengths, sorted_idx = torch.sort(length, descending=True)
    input_sequence = input_sequence[sorted_idx]

    # ENCODER
    input_embedding = self.embedding(input_sequence)
    packed_input = rnn_utils.pack_padded_sequence(
        input_embedding, sorted_lengths.data.tolist(), batch_first=True)
    _, hidden = self.encoder_rnn(packed_input)

    if self.bidirectional or self.num_layers > 1:
        # flatten hidden state
        hidden = hidden.view(batch_size, self.hidden_size * self.hidden_factor)
    else:
        hidden = hidden.squeeze()

    # REPARAMETERIZATION
    mean = self.hidden2mean(hidden)
    logv = self.hidden2logv(hidden)
    std = torch.exp(0.5 * logv)
    z = to_device(torch.randn([batch_size, self.latent_size]))
    z = z * std + mean

    # DECODER
    hidden = self.latent2hidden(z)
    if self.bidirectional or self.num_layers > 1:
        # unflatten hidden state
        hidden = hidden.view(self.hidden_factor, batch_size, self.hidden_size)
    else:
        hidden = hidden.unsqueeze(0)

    # decoder input
    if self.word_dropout_rate > 0:
        # randomly replace decoder input with <unk>
        prob = torch.rand(input_sequence.size())
        if torch.cuda.is_available():
            prob = prob.cuda()
        # never drop <sos> or <pad> tokens
        prob[(input_sequence.data - self.sos_idx) *
             (input_sequence.data - self.pad_idx) == 0] = 1
        decoder_input_sequence = input_sequence.clone()
        decoder_input_sequence[prob < self.word_dropout_rate] = self.unk_idx
        input_embedding = self.embedding(decoder_input_sequence)
    input_embedding = self.embedding_dropout(input_embedding)
    packed_input = rnn_utils.pack_padded_sequence(
        input_embedding, sorted_lengths.data.tolist(), batch_first=True)

    # decoder forward pass
    outputs, _ = self.decoder_rnn(packed_input, hidden)

    # process outputs
    padded_outputs = rnn_utils.pad_packed_sequence(outputs, batch_first=True)[0]
    padded_outputs = padded_outputs.contiguous()
    _, reversed_idx = torch.sort(sorted_idx)
    padded_outputs = padded_outputs[reversed_idx]
    b, s, _ = padded_outputs.size()

    # project outputs to vocab
    logp = nn.functional.log_softmax(
        self.outputs2vocab(padded_outputs.view(-1, padded_outputs.size(2))), dim=-1)
    logp = logp.view(b, s, self.embedding.num_embeddings)

    return logp, mean, logv, None, z, None
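# Note (illustrative, not part of the original file): forward() above flattens the
# encoder's final hidden state, of shape (num_layers * num_directions, batch, hidden_size),
# into (batch, hidden_size * hidden_factor) before projecting to the latent space.
# Assuming hidden_factor follows the usual Sentence-VAE convention, its definition
# in __init__ would be something like:
#
#     self.hidden_factor = (2 if bidirectional else 1) * num_layers
#
# which makes hidden.view(batch_size, self.hidden_size * self.hidden_factor) and the
# later hidden.view(self.hidden_factor, batch_size, self.hidden_size) consistent with
# the encoder RNN's output layout.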
def train_and_eval_cvae(args):
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    # output folder
    if args.pickle is not None:
        # drop the .pkl extension (rstrip('.pkl') would also strip trailing p/k/l characters)
        pickle_path = Path(args.pickle).with_suffix('')
        pickle_name = pickle_path.stem
        run_dir = pickle_path
    else:
        output_dir = Path(args.output_folder)
        if not output_dir.exists():
            output_dir.mkdir()
        current_time = datetime.now().strftime('%b%d_%H-%M-%S')
        run_dir = output_dir / current_time
    if not run_dir.exists():
        run_dir.mkdir()

    # data handling
    if args.cosine_threshold is not None and args.none_intents is not None:
        raise ValueError("None intents cannot be specified while using a "
                         "cosine similarity selection")
    data_folder = Path(args.data_folder)
    dataset_folder = data_folder / args.dataset_type
    none_folder = data_folder / args.none_type
    none_idx = NONE_COLUMN_MAPPING[args.none_type]
    dataset = create_dataset(
        dataset_type=args.dataset_type,
        dataset_folder=dataset_folder,
        dataset_size=args.dataset_size,
        restrict_intents=args.restrict_intents,
        none_folder=none_folder,
        none_size=args.none_size,
        none_intents=args.none_intents,
        none_idx=none_idx,
        infersent_selection=args.infersent_selection,
        cosine_threshold=args.cosine_threshold,
        input_type=args.input_type,
        tokenizer_type=args.tokenizer_type,
        preprocessing_type=args.preprocessing_type,
        max_sequence_length=args.max_sequence_length,
        embedding_type=args.embedding_type,
        embedding_dimension=args.embedding_dimension,
        max_vocab_size=args.max_vocab_size,
        slot_embedding=args.slot_embedding,
        run_dir=run_dir
    )
    if args.load_folder:
        original_vocab_size = dataset.update(args.load_folder)
        LOGGER.info('Loaded vocab from %s' % args.load_folder)

    # training
    if args.conditioning == NO_CONDITIONING:
        args.conditioning = None
    if not args.load_folder:
        model = CVAE(
            conditional=args.conditioning,
            compute_bow=args.bow_loss,
            vocab_size=dataset.vocab_size,
            embedding_size=args.embedding_dimension,
            rnn_type=args.rnn_type,
            hidden_size_encoder=args.hidden_size_encoder,
            hidden_size_decoder=args.hidden_size_decoder,
            word_dropout_rate=args.word_dropout_rate,
            embedding_dropout_rate=args.embedding_dropout_rate,
            z_size=args.latent_size,
            n_classes=dataset.n_classes,
            cat_size=dataset.n_classes if args.cat_size is None else args.cat_size,
            sos_idx=dataset.sos_idx,
            eos_idx=dataset.eos_idx,
            pad_idx=dataset.pad_idx,
            unk_idx=dataset.unk_idx,
            max_sequence_length=args.max_sequence_length,
            num_layers_encoder=args.num_layers_encoder,
            num_layers_decoder=args.num_layers_decoder,
            bidirectional=args.bidirectional,
            temperature=args.temperature,
            force_cpu=args.force_cpu
        )
    else:
        model = CVAE.from_folder(args.load_folder)
        LOGGER.info('Loaded model from %s' % args.load_folder)
        model.n_classes = dataset.n_classes
        model.update_embedding(dataset.vectors)
        model.update_outputs2vocab(original_vocab_size, dataset.vocab_size)
    model = to_device(model, args.force_cpu)

    # only optimize parameters that require gradients
    parameters = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = getattr(torch.optim, args.optimizer_type)(
        parameters,
        lr=args.learning_rate
    )
    trainer = Trainer(
        dataset,
        model,
        optimizer,
        batch_size=args.batch_size,
        annealing_strategy=args.annealing_strategy,
        kl_anneal_rate=args.kl_anneal_rate,
        kl_anneal_time=args.kl_anneal_time,
        kl_anneal_target=args.kl_anneal_target,
        label_anneal_rate=args.label_anneal_rate,
        label_anneal_time=args.label_anneal_time,
        label_anneal_target=args.label_anneal_target,
        add_bow_loss=args.bow_loss,
        force_cpu=args.force_cpu,
        run_dir=run_dir / "tensorboard",
        alpha=args.alpha
    )
    trainer.run(args.n_epochs, dev_step_every_n_epochs=1)

    if args.pickle is not None:
        model_path = run_dir / "{}_load".format(pickle_name)
    else:
        model_path = run_dir / "load"
    dataset.save(model_path)
    model.save(model_path)

    # evaluation
    run_dict = dict()

    # generate queries
    generated_sentences, logp = generate_vae_sentences(
        model=model,
        n_to_generate=args.n_generated,
        input_type=args.input_type,
        i2int=dataset.i2int,
        i2w=dataset.i2w,
        eos_idx=dataset.eos_idx,
        slotdic=dataset.slotdic if args.input_type == 'delexicalised' else None,
        verbose=True
    )
    run_dict['generated'] = generated_sentences
    run_dict['metrics'] = compute_generation_metrics(
        dataset,
        generated_sentences['utterances'],
        generated_sentences['intents'],
        logp
    )
    for k, v in run_dict['metrics'].items():
        LOGGER.info((k, v))

    if args.input_type == "delexicalised":
        run_dict['delexicalised_metrics'] = compute_generation_metrics(
            dataset,
            generated_sentences['delexicalised'],
            generated_sentences['intents'],
            logp,
            input_type='delexicalised'
        )
        for k, v in run_dict['delexicalised_metrics'].items():
            LOGGER.info((k, v))

    save_augmented_dataset(generated_sentences, args.n_generated,
                           dataset.train_path, run_dir)

    run_dict['args'] = vars(args)
    run_dict['logs'] = trainer.run_logs
    run_dict['latent_rep'] = trainer.latent_rep
    run_dict['i2w'] = dataset.i2w
    run_dict['w2i'] = dataset.w2i
    run_dict['i2int'] = dataset.i2int
    run_dict['int2i'] = dataset.int2i
    run_dict['vectors'] = {
        'before': dataset.vocab.vectors,
        'after': model.embedding.weight.data
    }

    if args.pickle is not None:
        run_dict_path = run_dir.parents[0] / "{}.pkl".format(pickle_name)
    else:
        run_dict_path = run_dir / "run.pkl"
    torch.save(run_dict, str(run_dict_path))
def inference(self, n=4, z=None):
    if z is None:
        batch_size = n
        z = to_device(torch.randn([batch_size, self.latent_size]))
    else:
        batch_size = z.size(0)

    hidden = self.latent2hidden(z)
    if self.bidirectional or self.num_layers > 1:
        # unflatten hidden state
        hidden = hidden.view(self.hidden_factor, batch_size, self.hidden_size)
    else:
        # only add the layer dimension when it was not restored by the view above,
        # mirroring forward()
        hidden = hidden.unsqueeze(0)

    # required for dynamic stopping of sentence generation
    sequence_idx = torch.arange(0, batch_size, out=self.tensor()).long()      # all idx of batch
    sequence_running = torch.arange(0, batch_size, out=self.tensor()).long()  # all idx of batch which are still generating
    sequence_mask = torch.ones(batch_size, out=self.tensor()).byte()
    running_seqs = torch.arange(0, batch_size, out=self.tensor()).long()      # idx of still generating sequences with respect to current loop

    generations = self.tensor(batch_size, self.max_sequence_length).fill_(self.pad_idx).long()

    t = 0
    while t < self.max_sequence_length and len(running_seqs) > 0:
        if t == 0:
            input_sequence = to_device(torch.Tensor(batch_size).fill_(self.sos_idx).long())

        input_sequence = input_sequence.unsqueeze(1)
        input_embedding = self.embedding(input_sequence)
        output, hidden = self.decoder_rnn(input_embedding, hidden)
        logits = self.outputs2vocab(output)
        input_sequence = self._sample(logits)

        # save next input
        generations = self._save_sample(generations, input_sequence, sequence_running, t)

        # update global running sequence
        sequence_mask[sequence_running] = (input_sequence != self.eos_idx).data
        sequence_running = sequence_idx.masked_select(sequence_mask)

        # update local running sequences
        running_mask = (input_sequence != self.eos_idx).data
        running_seqs = running_seqs.masked_select(running_mask)

        # prune input and hidden state according to local update
        if len(running_seqs) > 0:
            input_sequence = input_sequence[running_seqs]
            hidden = hidden[:, running_seqs]
            running_seqs = torch.arange(0, len(running_seqs), out=self.tensor()).long()
        t += 1

    return generations, z, None
def train(model, datasets, args):
    train_iter, val_iter = datasets.get_iterators(batch_size=args.batch_size)
    opt = getattr(torch.optim, args.optimizer)(model.parameters(), lr=args.learning_rate)
    # opt = torch.optim.Adam([
    #     {"params": model.encoder_rnn.parameters(), "lr": args.learning_rate},
    #     {"params": model.hidden2mean.parameters(), "lr": args.learning_rate},
    #     {"params": model.hidden2logv.parameters(), "lr": args.learning_rate},
    #     {"params": model.hidden2cat.parameters(), "lr": args.learning_rate},
    #     {"params": model.latent2hidden.parameters(), "lr": args.learning_rate},
    #     {"params": model.latent2bow.parameters(), "lr": args.learning_rate},
    #     {"params": model.outputs2vocab.parameters(), "lr": args.learning_rate}])

    step = 0
    NLL_hist = []
    KL_hist = []
    BOW_hist = []
    NMI_hist = []
    acc_hist = []
    latent_rep = {i: [] for i in range(model.n_classes)}

    for epoch in range(1, args.epochs + 1):
        tr_loss = 0.0
        NLL_tr_loss = 0.0
        KL_tr_loss = 0.0
        BOW_tr_loss = 0.0
        NMI_tr = 0.0
        n_correct_tr = 0.0
        acc_tr = 0.0

        model.train()  # turn on training mode
        for iteration, batch in enumerate(tqdm(train_iter)):
            step += 1
            opt.zero_grad()
            # model.word_dropout_rate = anneal_fn(args.anneal_function, step, args.k3, args.x3, args.m3)

            x, lengths = getattr(batch, args.input_type)
            input = x[:, :-1]   # remove <eos>
            target = x[:, 1:]   # remove <sos>
            lengths -= 1        # account for the removal
            input, target = to_device(input), to_device(target)

            if args.conditional != 'none':
                y = batch.intent.squeeze()
                y = to_device(y)
                sorted_lengths, sorted_idx = torch.sort(lengths, descending=True)
                y = y[sorted_idx]

            logp, mean, logv, logc, z, bow = model(input, lengths)

            if epoch == args.epochs and args.conditional != 'none':
                for i, intent in enumerate(y):
                    latent_rep[int(intent)].append(z[i].cpu().detach().numpy())

            # loss calculation
            NLL_loss, KL_losses, KL_weight, BOW_loss = loss_fn(
                logp, bow, target, lengths, mean, logv,
                args.anneal_function, step, args.k1, args.x1, args.m1)
            KL_loss = torch.sum(KL_losses)
            NLL_hist.append(NLL_loss.detach().cpu().numpy() / args.batch_size)
            KL_hist.append(KL_losses.detach().cpu().numpy() / args.batch_size)
            BOW_hist.append(BOW_loss.detach().cpu().numpy() / args.batch_size)

            # NOTE: the label loss is added only in the supervised branch below;
            # computing it here would fail when args.conditional == 'none' (y and
            # logc are undefined) and would count it twice otherwise.
            loss = (NLL_loss + KL_weight * KL_loss)  # /args.batch_size
            if args.bow_loss:
                loss += BOW_loss

            if args.conditional == 'none':
                pred_labels = 0
                n_correct = 0
                NMI = 0
            else:
                if args.conditional == 'supervised':
                    label_loss, label_weight = loss_labels(
                        logc, y, args.anneal_function, step, args.k2, args.x2, args.m2)
                    loss += label_weight * label_loss
                elif args.conditional == 'unsupervised':
                    entropy = torch.sum(
                        torch.exp(logc) * torch.log(model.n_classes * torch.exp(logc)))
                    loss += entropy
                pred_labels = logc.data.max(1)[1].long()
                n_correct = pred_labels.eq(y.data).cpu().sum().float().item()
                acc_hist.append(n_correct / args.batch_size)
                NMI = normalized_mutual_info_score(
                    y.cpu().detach().numpy(),
                    torch.exp(logc).cpu().max(1)[1].numpy())
                NMI_hist.append(NMI)

            loss.backward()
            # CLIPPING
            # for p in model.parameters():
            #     p.register_hook(lambda grad: torch.clamp(grad, -1, 1))
            # torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=1)
            # torch.nn.utils.clip_grad_norm(model.parameters(), clipping_value=1)
            opt.step()

            tr_loss += loss.item()
            NLL_tr_loss += NLL_loss.item()
            KL_tr_loss += KL_loss.item()
            BOW_tr_loss += BOW_loss.item()
            NMI_tr += NMI
            n_correct_tr += n_correct

            # if iteration % 100 == 0:
            #     print("Loss %9.4f, NLL-Loss %9.4f, KL-Loss %9.4f, KL-Weight %6.3f"
            #           % (loss.data, NLL_loss.item()/args.batch_size, KL_loss.item()/args.batch_size, KL_weight))
            #     x_sentences = input[:3].cpu().numpy()
            #     print('\nInput sentences :')
            #     print(*idx2word(x_sentences, i2w=i2w, eos_idx=eos_idx), sep='\n')
            #     _, y_sentences = torch.topk(logp, 1, dim=-1)
            #     y_sentences = y_sentences[:3].squeeze().cpu().numpy()
            #     print('\nOutput sentences : ')
            #     print(*idx2word(y_sentences, i2w=i2w, eos_idx=eos_idx), sep='\n')
            #     print('\n')

        tr_loss = tr_loss / len(datasets.train)
        NLL_tr_loss = NLL_tr_loss / len(datasets.train)
        KL_tr_loss = KL_tr_loss / len(datasets.train)
        BOW_tr_loss = BOW_tr_loss / len(datasets.train)
        NMI_tr = NMI_tr / len(datasets.train)
        acc_tr = n_correct_tr / len(datasets.train)

        # calculate the validation loss for this epoch
        val_loss = 0.0
        NLL_val_loss = 0.0
        KL_val_loss = 0.0
        BOW_val_loss = 0.0
        NMI_val = 0.0
        n_correct_val = 0.0
        acc_val = 0.0

        model.eval()  # turn on evaluation mode
        for batch in tqdm(val_iter):
            x, lengths = getattr(batch, args.input_type)
            target = x[:, 1:]   # remove <sos>
            input = x[:, :-1]   # remove <eos>
            lengths -= 1        # account for the removal
            input, target = to_device(input), to_device(target)

            if args.conditional != 'none':
                y = batch.intent.squeeze()
                y = to_device(y)
                sorted_lengths, sorted_idx = torch.sort(lengths, descending=True)
                y = y[sorted_idx]

            logp, mean, logv, logc, z, bow = model(input, lengths)

            # loss calculation
            NLL_loss, KL_losses, KL_weight, BOW_loss = loss_fn(
                logp, bow, target, lengths, mean, logv,
                args.anneal_function, step, args.k1, args.x1, args.m1)
            KL_loss = torch.sum(KL_losses)
            loss = (NLL_loss + KL_weight * KL_loss)  # /args.batch_size
            if args.bow_loss:
                loss += BOW_loss

            if args.conditional == 'none':
                pred_labels = 0
                n_correct = 0
                NMI = 0
            else:
                if args.conditional == 'supervised':
                    label_loss, label_weight = loss_labels(
                        logc, y, args.anneal_function, step, args.k2, args.x2, args.m2)
                    loss += label_weight * label_loss
                elif args.conditional == 'unsupervised':
                    entropy = torch.sum(
                        torch.exp(logc) * torch.log(model.n_classes * torch.exp(logc)))
                    loss += entropy
                pred_labels = logc.data.max(1)[1].long()
                n_correct = pred_labels.eq(y.data).cpu().sum().float().item()
                NMI = normalized_mutual_info_score(
                    y.cpu().detach().numpy(),
                    torch.exp(logc).cpu().max(1)[1].numpy())

            val_loss += loss.item()
            NLL_val_loss += NLL_loss.item()
            KL_val_loss += KL_loss.item()
            BOW_val_loss += BOW_loss.item()
            NMI_val += NMI
            n_correct_val += n_correct

        val_loss = val_loss / len(datasets.valid)
        NLL_val_loss = NLL_val_loss / len(datasets.valid)
        KL_val_loss = KL_val_loss / len(datasets.valid)
        BOW_val_loss = BOW_val_loss / len(datasets.valid)
        NMI_val = NMI_val / len(datasets.valid)
        acc_val = n_correct_val / len(datasets.valid)

        print('Epoch {} : train {:.6f} valid {:.6f}'.format(epoch, tr_loss, val_loss))
        print('Training : NLL loss : {:.6f}, KL loss : {:.6f}, BOW : {:.6f}, acc : {:.6f}'
              .format(NLL_tr_loss, KL_tr_loss, BOW_tr_loss, acc_tr))
        print('Validation : NLL loss : {:.6f}, KL loss : {:.6f}, BOW : {:.6f}, acc : {:.6f}'
              .format(NLL_val_loss, KL_val_loss, BOW_val_loss, acc_val))

    # `run` is assumed to be a module-level dict collecting training logs
    run['NLL_hist'] = NLL_hist
    run['KL_hist'] = KL_hist
    run['NLL_val'] = NLL_val_loss
    run['KL_val'] = KL_val_loss
    run['NMI_hist'] = NMI_hist
    run['acc_hist'] = acc_hist
    run['latent'] = latent_rep
    return
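# Illustrative sketch (not part of the original file): loss_fn above receives
# (anneal_function, step, k1, x1, m1), which suggests a KL-weight annealing schedule
# in the style of Bowman et al. The exact parameterisation used by loss_fn is an
# assumption; a common logistic / linear schedule with m as the target weight would be:
import numpy as np


def kl_anneal_weight(anneal_function, step, k, x, m=1.0):
    """Return the KL weight at a given optimisation step (hypothetical helper)."""
    if anneal_function == 'logistic':
        # sigmoid ramp centred at step x, steepness k, saturating at m
        return float(m / (1 + np.exp(-k * (step - x))))
    elif anneal_function == 'linear':
        # linear ramp reaching m after x steps
        return m * min(1.0, step / x)
    return m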
def do_one_sweep(self, iter, is_last_epoch, train_or_dev):
    if train_or_dev not in ['train', 'dev']:
        raise TypeError("train_or_dev should be either train or dev")

    if train_or_dev == "train":
        self.model.train()
    else:
        self.model.eval()

    sweep_loss = 0
    sweep_recon_loss = 0
    sweep_kl_loss = 0
    sweep_accuracy = 0
    n_batches = 0
    for iteration, batch in enumerate(tqdm(iter)):
        # if len(batch) < self.batch_size:
        #     continue
        if train_or_dev == "train":
            self.step += 1
            self.optimizer.zero_grad()

        # forward pass
        x, lengths = getattr(batch, self.dataset.input_type)
        input = x[:, :-1]   # remove <eos>
        target = x[:, 1:]   # remove <sos>
        lengths -= 1        # account for the removal
        input, target = to_device(input, self.force_cpu), to_device(target, self.force_cpu)

        y = None
        if self.model.conditional is not None:
            y = batch.intent.squeeze()
            y = to_device(y, self.force_cpu)
            sorted_lengths, sorted_idx = torch.sort(lengths, descending=True)
            y = y[sorted_idx]

        logp, mean, logv, logc, z, bow = self.model(input, lengths)

        if is_last_epoch:
            _, reversed_idx = torch.sort(sorted_idx)
            y = y[reversed_idx]
            logc = logc[reversed_idx]
            real_labels = [self.i2int[label] for label in y]
            pred_labels = [
                self.i2int[label] if label < len(self.i2int) else 'None'
                for label in logc.max(1)[1]
            ]
            for real_label, pred_label in zip(real_labels, pred_labels):
                self.run_logs[train_or_dev]['classifications'][real_label][pred_label] += 1
            for real_label in real_labels:
                self.run_logs[train_or_dev]['transfer'][real_label] += \
                    logc.sum(dim=0).cpu().detach()

        # save latent representation
        if train_or_dev == "train" and self.model.conditional:
            for i, intent in enumerate(y):
                self.latent_rep[self.i2int[intent]].append(z[i].cpu().detach().numpy())

        # loss calculation
        loss, recon_loss, kl_loss, accuracy = self.compute_loss(
            logp, bow, target, lengths, mean, logv, logc, y, train_or_dev)
        sweep_loss += loss
        sweep_recon_loss += recon_loss
        sweep_kl_loss += kl_loss
        sweep_accuracy += accuracy
        n_batches += 1

        if train_or_dev == "train":
            loss.backward()
            self.optimizer.step()

    if is_last_epoch:
        for intent1 in self.i2int:
            n_sentences = sum(
                self.run_logs[train_or_dev]['classifications'][intent1].values())
            self.run_logs[train_or_dev]['transfer'][intent1] /= n_sentences
            for intent2 in self.i2int:
                self.run_logs[train_or_dev]['classifications'][intent1][intent2] /= n_sentences

    return sweep_loss / n_batches, sweep_recon_loss / n_batches, \
        sweep_kl_loss / n_batches, sweep_accuracy / n_batches
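# Illustrative sketch (not part of the original file): do_one_sweep() is presumably
# driven by Trainer.run(n_epochs, dev_step_every_n_epochs=...) as called from
# train_and_eval_cvae. A minimal epoch loop consistent with that signature could
# look like the following; self.train_iter and self.dev_iter are hypothetical
# attribute names, not taken from the project.
#
#     def run(self, n_epochs, dev_step_every_n_epochs=1):
#         for epoch in range(1, n_epochs + 1):
#             is_last_epoch = (epoch == n_epochs)
#             train_stats = self.do_one_sweep(self.train_iter, is_last_epoch, "train")
#             if epoch % dev_step_every_n_epochs == 0:
#                 with torch.no_grad():
#                     dev_stats = self.do_one_sweep(self.dev_iter, is_last_epoch, "dev")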
def inference(self, n=10, z=None, y_onehot=None, temperature=0):
    if z is None:
        batch_size = n
        z = torch.randn(batch_size, self.z_size)
    else:
        batch_size = z.size(0)

    if self.conditional is not None:
        if y_onehot is None:
            # sample a random class and one-hot encode it
            y = torch.LongTensor(batch_size, 1).random_() % self.n_classes
            y_onehot = torch.FloatTensor(batch_size, self.cat_size)
            y_onehot.fill_(0)
            y_onehot.scatter_(dim=1, index=y, value=1)
        latent = to_device(torch.cat((z, y_onehot), dim=1), self.force_cpu)
    else:
        y_onehot = None
        latent = to_device(z, self.force_cpu)

    hidden = self.latent2hidden(latent)
    if self.bidirectional or self.num_layers_decoder > 1:
        # unflatten hidden state
        hidden = hidden.view(self.num_layers_decoder, batch_size, self.hidden_size)
    else:
        hidden = hidden.unsqueeze(0)

    # required for dynamic stopping of sentence generation
    sequence_idx = torch.arange(0, batch_size, out=self.tensor()).long()      # all idx of batch
    sequence_running = torch.arange(0, batch_size, out=self.tensor()).long()  # all idx of batch which are still generating
    sequence_mask = torch.ones(batch_size, out=self.tensor()).byte()
    running_seqs = torch.arange(0, batch_size, out=self.tensor()).long()      # idx of still generating sequences with respect to current loop

    generations = self.tensor(batch_size, self.max_sequence_length).fill_(self.pad_idx).long()

    t = 0
    while t < self.max_sequence_length and len(running_seqs) > 0:
        if t == 0:
            input_sequence = torch.Tensor(batch_size).fill_(self.sos_idx).long()
            # input_sequence = torch.randint(0, self.vocab_size, (batch_size,))

        input_sequence = to_device(input_sequence.unsqueeze(1), self.force_cpu)
        input_embedding = self.embedding(input_sequence)
        output, hidden = self.decoder_rnn(input_embedding, hidden)
        logits = self.outputs2vocab(output)
        logp = nn.functional.log_softmax(logits / self.temperature, dim=-1)
        input_sequence = self._sample(logits)

        # save next input
        generations = self._save_sample(generations, input_sequence, sequence_running, t)

        # update global running sequence
        sequence_mask[sequence_running] = (input_sequence != self.eos_idx).data
        sequence_running = sequence_idx.masked_select(sequence_mask)

        # update local running sequences
        running_mask = (input_sequence != self.eos_idx).data
        running_seqs = running_seqs.masked_select(running_mask)

        # prune input and hidden state according to local update
        if len(running_seqs) > 0:
            try:
                input_sequence = input_sequence[running_seqs]
            except IndexError:
                break
            hidden = hidden[:, running_seqs]
            running_seqs = torch.arange(0, len(running_seqs), out=self.tensor()).long()
        t += 1

    return generations, z, y_onehot, logp
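# Illustrative usage sketch (not part of the original file): sampling from a trained
# CVAE with inference() and mapping the returned token ids back to words. It assumes
# `model` and `dataset` as constructed in train_and_eval_cvae, and that dataset.i2w
# maps integer ids to tokens (some codebases key this dict by str(idx) instead; the
# decoding loop below is an assumption, not the project's generate_vae_sentences).
import torch


def sample_sentences(model, dataset, n=10):
    """Hypothetical helper: decode n sampled sequences into whitespace-joined strings."""
    model.eval()
    with torch.no_grad():
        generations, z, y_onehot, logp = model.inference(n=n)
    sentences = []
    for row in generations.cpu().tolist():
        words = []
        for idx in row:
            # stop at the end-of-sentence token and skip trailing padding
            if idx in (dataset.eos_idx, dataset.pad_idx):
                break
            words.append(dataset.i2w[idx])
        sentences.append(' '.join(words))
    return sentences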
def forward(self, input_sequence, lengths):
    batch_size = input_sequence.size(0)
    sorted_lengths, sorted_idx = torch.sort(lengths, descending=True)
    input_sequence = input_sequence[sorted_idx]

    # ENCODER
    input_embedding = self.embedding(input_sequence)
    packed_input = rnn_utils.pack_padded_sequence(
        input_embedding, sorted_lengths.data.tolist(), batch_first=True)
    _, hidden = self.encoder_rnn(packed_input)

    if self.bidirectional or self.num_layers_encoder > 1:
        # flatten hidden state
        hidden = hidden.view(
            batch_size, self.hidden_size_encoder * self.hidden_factor_encoder)
    else:
        hidden = hidden.squeeze()

    # REPARAMETERIZATION
    mean = self.hidden2mean(hidden)
    logv = self.hidden2logv(hidden)
    std = torch.exp(0.5 * logv)
    z = to_device(torch.randn(batch_size, self.z_size), self.force_cpu)
    z = z * std + mean

    if self.conditional is not None:
        logc = nn.functional.log_softmax(self.hidden2cat(hidden), dim=-1)
        y_onehot = nn.functional.gumbel_softmax(logc)
        latent = torch.cat((z, y_onehot), dim=-1)
    else:
        logc = None
        latent = z

    # DECODER
    hidden = self.latent2hidden(latent)
    if self.bidirectional or self.num_layers_decoder > 1:
        # unflatten hidden state
        hidden = hidden.view(self.num_layers_decoder, batch_size, self.hidden_size_decoder)
    else:
        hidden = hidden.unsqueeze(0)

    # decoder input
    if self.word_dropout_rate > 0:
        # randomly replace decoder input with <unk>
        prob = torch.rand(input_sequence.size())
        prob = to_device(prob)
        # never drop <sos> or <pad> tokens
        prob[(input_sequence.data - self.sos_idx) *
             (input_sequence.data - self.pad_idx) == 0] = 1
        decoder_input_sequence = input_sequence.clone()
        decoder_input_sequence[prob < self.word_dropout_rate] = self.unk_idx
        input_embedding = self.embedding(decoder_input_sequence)
    input_embedding = self.embedding_dropout(input_embedding)
    packed_input = rnn_utils.pack_padded_sequence(
        input_embedding, sorted_lengths.data.tolist(), batch_first=True)

    outputs, _ = self.decoder_rnn(packed_input, hidden)

    # process outputs
    padded_outputs = rnn_utils.pad_packed_sequence(outputs, batch_first=True)[0]
    padded_outputs = padded_outputs.contiguous()
    _, reversed_idx = torch.sort(sorted_idx)
    padded_outputs = padded_outputs[reversed_idx]
    bs, seqlen, hs = padded_outputs.size()

    # project outputs to vocab
    logits = self.outputs2vocab(padded_outputs.view(-1, hs))
    logp = nn.functional.log_softmax(logits / self.temperature, dim=-1)
    logp = logp.view(bs, seqlen, self.embedding.num_embeddings)

    if self.bow:
        bow = nn.functional.log_softmax(self.z2bow(z), dim=0)
        bow = bow[reversed_idx]
    else:
        bow = None

    return logp, mean, logv, logc, z, bow
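# Illustrative sketch (not part of the original file): the mean / logv returned by
# forward() are typically combined into the analytical KL divergence between the
# approximate posterior N(mean, exp(logv)) and the standard normal prior. Assuming
# the project's loss_fn / compute_loss follow this standard formulation, the KL term
# per example would be:
import torch


def gaussian_kl(mean, logv):
    """KL( N(mean, exp(logv)) || N(0, I) ), summed over latent dimensions."""
    return -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp(), dim=-1)

# Example: for mean and logv of shape (batch_size, z_size), gaussian_kl returns a
# (batch_size,) tensor that the training loop would scale by the annealed KL weight.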