def forward(self, y, large_z, context):  # train time
    y, lengths = append(truncate(y, 'eos'), 'sos')
    if self.word_drop > 0.:
        y = word_drop(y, self.word_drop)
    embedded = self.embedding(y)  # (B, l, 300)
    embedded = torch.cat(
        [embedded, context.repeat(1, embedded.size(1), 1)], dim=-1)
    packed = pack_padded_sequence(embedded, lengths, batch_first=True)
    init_hidden = self._transform_hidden(large_z)
    packed_output, _ = self.lstm(packed, init_hidden)
    total_length = embedded.size(1)
    output, _ = pad_packed_sequence(packed_output,
                                    batch_first=True,
                                    total_length=total_length)
    recon_logits = self.out(output)
    return recon_logits  # (B, L, vocab_size)
def forward(self, orig, para, z):  # train time
    orig, orig_lengths = orig  # (B, l), (B,)
    orig = self.embedding(orig)  # (B, l, 300)
    orig_packed = pack_padded_sequence(orig, orig_lengths, batch_first=True)
    _, orig_hidden = self.lstm_orig(orig_packed)

    para, _ = append(truncate(para, 'eos'), 'sos')
    if self.word_drop > 0.:
        para = word_drop(para, self.word_drop)  # from Bowman's paper
    para = self.embedding(para)
    L = para.size(1)
    para_z = torch.cat([para, z.repeat(1, L, 1)], dim=-1)  # (B, L, 1100+300)
    para_output, _ = self.lstm_para(para_z, orig_hidden)  # no packing
    logits = self.linear(para_output)
    return logits  # (B, L, vocab_size)
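# Note: `word_drop` is not defined in this listing. Below is a minimal sketch of
# the Bowman-style word dropout referenced above, assuming the simple "replace a
# token id with <unk> with probability p" behaviour used by the two decoders
# (the f_step variants further down call differently parameterized versions).
# `word_drop_sketch`, `unk_idx`, and `pad_idx` are illustrative names, not part
# of the original code; the real helper may treat special tokens differently.
import torch

def word_drop_sketch(tokens, dropout, unk_idx, pad_idx):
    # tokens: (B, L) LongTensor of token ids
    keep = torch.rand(tokens.shape, device=tokens.device) >= dropout
    dropped = torch.where(keep, tokens, torch.full_like(tokens, unk_idx))
    # never corrupt padding positions
    return torch.where(tokens == pad_idx, tokens, dropped)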
def f_step(config, vocab, model_F, model_D, optimizer_F, batch, temperature,
           drop_decay, cyc_rec_enable=True):
    model_D.eval()

    pad_idx = vocab.stoi['<pad>']
    eos_idx = vocab.stoi['<eos>']
    unk_idx = vocab.stoi['<unk>']
    vocab_size = len(vocab)
    loss_fn = nn.NLLLoss(reduction='none')

    inp_tokens, inp_lengths, raw_styles = batch_preprocess(
        batch, pad_idx, eos_idx)
    rev_styles = 1 - raw_styles
    batch_size = inp_tokens.size(0)
    token_mask = (inp_tokens != pad_idx).float()

    optimizer_F.zero_grad()

    # self reconstruction loss
    noise_inp_tokens = word_drop(inp_tokens, inp_lengths,
                                 config.inp_drop_prob * drop_decay, vocab)
    noise_inp_lengths = get_lengths(noise_inp_tokens, eos_idx)
    slf_log_probs = model_F(
        noise_inp_tokens,
        inp_tokens,
        noise_inp_lengths,
        raw_styles,
        generate=False,
        differentiable_decode=False,
        temperature=temperature,
    )

    slf_rec_loss = loss_fn(slf_log_probs.transpose(1, 2), inp_tokens) * token_mask
    slf_rec_loss = slf_rec_loss.sum() / batch_size
    slf_rec_loss *= config.slf_factor
    slf_rec_loss.backward()

    # cycle consistency loss
    if not cyc_rec_enable:
        optimizer_F.step()
        model_D.train()
        return slf_rec_loss.item(), 0, 0

    gen_log_probs = model_F(
        inp_tokens,
        None,
        inp_lengths,
        rev_styles,
        generate=True,
        differentiable_decode=True,
        temperature=temperature,
    )

    gen_soft_tokens = gen_log_probs.exp()
    gen_lengths = get_lengths(gen_soft_tokens.argmax(-1), eos_idx)

    cyc_log_probs = model_F(
        gen_soft_tokens,
        inp_tokens,
        gen_lengths,
        raw_styles,
        generate=False,
        differentiable_decode=False,
        temperature=temperature,
    )

    cyc_rec_loss = loss_fn(cyc_log_probs.transpose(1, 2), inp_tokens) * token_mask
    cyc_rec_loss = cyc_rec_loss.sum() / batch_size
    cyc_rec_loss *= config.cyc_factor

    # style consistency loss
    adv_log_probs = model_D(gen_soft_tokens, gen_lengths, rev_styles)
    if config.discriminator_method == 'Multi':
        adv_labels = rev_styles + 1
    else:
        adv_labels = torch.ones_like(rev_styles)
    adv_loss = loss_fn(adv_log_probs, adv_labels)
    adv_loss = adv_loss.sum() / batch_size
    adv_loss *= config.adv_factor
    (cyc_rec_loss + adv_loss).backward()

    # update parameters
    clip_grad_norm_(model_F.parameters(), 5)
    optimizer_F.step()

    model_D.train()

    return slf_rec_loss.item(), cyc_rec_loss.item(), adv_loss.item()
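# Note: both f_step variants rely on a `get_lengths(tokens, eos_idx)` helper that
# is not shown here. The sketch below illustrates the assumed contract only: for
# each sequence in the batch, return the number of tokens up to and including the
# first eos (or the full padded length when no eos occurs). `get_lengths_sketch`
# is an illustrative name; the actual implementation may differ.
import torch

def get_lengths_sketch(tokens, eos_idx):
    # tokens: (B, L) LongTensor of token ids
    is_eos = tokens == eos_idx                     # (B, L) bool, True at eos positions
    has_eos = is_eos.any(dim=-1)                   # (B,) whether an eos appears at all
    first_eos = is_eos.long().argmax(dim=-1)       # index of the first eos (0 if none)
    full_len = torch.full_like(first_eos, tokens.size(1))
    return torch.where(has_eos, first_eos + 1, full_len)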
def f_step(config, model_F, model_D, optimizer_F, batch, temperature,
           drop_decay, cyc_rec_enable=True):
    model_D.eval()

    pad_idx = model_F.tokenizer.pad_token_id
    eos_idx = model_F.tokenizer.eos_token_id
    unk_idx = model_F.tokenizer.unk_token_id
    vocab_size = model_F.tokenizer.vocab_size
    loss_fn = nn.NLLLoss(reduction='none')

    # unpack the batch
    inp_tokens = batch["source_ids"].to(config.device)
    inp_lengths = get_lengths(inp_tokens, eos_idx).to(config.device)
    raw_styles = batch["source_style"].to(config.device)
    raw_styles = raw_styles.unsqueeze(1)
    rev_styles = 1 - raw_styles
    batch_size = inp_tokens.size(0)
    token_mask = batch['source_mask'].to(config.device)

    target_ids = batch['target_ids'].to(config.device)
    y_ids = target_ids[:, :-1].contiguous()
    lm_labels = target_ids[:, 1:].clone()
    lm_labels[target_ids[:, 1:] == pad_idx] = -100

    optimizer_F.zero_grad()

    # self reconstruction loss
    noise_inp_tokens, noise_token_mask = word_drop(
        inp_tokens, token_mask, inp_lengths, config.inp_drop_prob * drop_decay)
    noise_inp_lengths = get_lengths(noise_inp_tokens, eos_idx)
    noise_inp_tokens = torch.cat((raw_styles, noise_inp_tokens), 1).to(config.device)
    noise_token_mask = torch.cat(
        (torch.ones_like(raw_styles), noise_token_mask), 1).to(config.device)

    outputs = model_F(input_ids=noise_inp_tokens,
                      attention_mask=noise_token_mask,
                      decoder_input_ids=y_ids,
                      lm_labels=lm_labels)
    slf_log_probs = outputs[1]  # 0: LM loss, 1: logits, 2: hidden, 3: attention?
    # slf_rec_loss = loss_fn(slf_log_probs.transpose(1, 2), inp_tokens) * token_mask
    # slf_rec_loss = slf_rec_loss.sum() / batch_size
    slf_rec_loss = outputs[0]
    slf_rec_loss *= config.slf_factor
    slf_rec_loss.backward()

    # cycle consistency loss
    if not cyc_rec_enable:
        optimizer_F.step()
        model_D.train()
        return slf_rec_loss.item(), 0, 0

    inp_tokens_rev = torch.cat((rev_styles, inp_tokens), 1).to(config.device)
    token_mask_rev = torch.cat((torch.ones_like(rev_styles), token_mask),
                               1).to(config.device)
    # outputs = model_F(
    #     input_ids=inp_tokens_rev, attention_mask=token_mask_rev,
    #     decoder_input_ids=None, lm_labels=None
    # )
    # add y_ids to get loss value....
    outputs = model_F(input_ids=inp_tokens_rev,
                      attention_mask=token_mask_rev,
                      decoder_input_ids=y_ids,
                      lm_labels=lm_labels)
    gen_log_probs = outputs[1]

    # gen_soft_tokens = gen_log_probs.exp()
    # gen_lengths = get_lengths(gen_soft_tokens.argmax(-1), eos_idx)
    gen_soft_tokens = gen_log_probs.argmax(-1)
    gen_lengths = get_lengths(gen_soft_tokens, eos_idx)

    # pos_idx = torch.arange(max_seq_len).unsqueeze(0).expand((batch_size, -1)).to(config.device)
    # gen_token_mask = pos_idx >= gen_lengths.unsqueeze(-1)
    gen_token_mask = torch.zeros_like(gen_soft_tokens)
    for i, length in enumerate(gen_lengths):
        gen_token_mask[i] = torch.LongTensor(
            [1] * length + [0] * (gen_soft_tokens.size(1) - length)).to(config.device)

    raw_styles = raw_styles.type(torch.LongTensor).to(config.device)
    gen_soft_tokens = torch.cat((raw_styles, gen_soft_tokens), 1)
    gen_token_mask = torch.cat((torch.ones_like(raw_styles), gen_token_mask), 1)

    rev_y_ids = inp_tokens[:, :-1].contiguous()
    rev_lm_labels = inp_tokens[:, 1:].clone()
    rev_lm_labels[inp_tokens[:, 1:] == pad_idx] = -100

    outputs = model_F(input_ids=gen_soft_tokens,
                      attention_mask=gen_token_mask,
                      decoder_input_ids=rev_y_ids,
                      lm_labels=rev_lm_labels)
    cyc_rec_loss = outputs[0]
    cyc_rec_loss *= config.cyc_factor

    # style consistency loss
    adv_log_probs = model_D(gen_log_probs.exp(), gen_lengths)
    if config.discriminator_method == 'Multi':
        adv_labels = rev_styles + 1
    else:
        adv_labels = torch.ones_like(rev_styles)
    adv_labels = adv_labels.squeeze(1)
    adv_loss = loss_fn(adv_log_probs, adv_labels)
    adv_loss = adv_loss.sum()
    adv_loss *= config.adv_factor
    (cyc_rec_loss + adv_loss).backward()

    # update parameters
    clip_grad_norm_(model_F.parameters(), 5)
    optimizer_F.step()
    model_F.lr_scheduler.step()

    model_D.train()

    return slf_rec_loss.item(), cyc_rec_loss.item(), adv_loss.item()