Example #1
def forward(self, y, large_z, context):  # train time
    y, lengths = append(truncate(y, 'eos'), 'sos')
    if self.word_drop > 0.:
        y = word_drop(y, self.word_drop)
    embedded = self.embedding(y)  # (B, l, 300)
    embedded = torch.cat(
        [embedded, context.repeat(1, embedded.size(1), 1)], dim=-1)
    packed = pack_padded_sequence(embedded, lengths, batch_first=True)
    init_hidden = self._transform_hidden(large_z)
    packed_output, _ = self.lstm(packed, init_hidden)
    total_length = embedded.size(1)
    output, _ = pad_packed_sequence(packed_output,
                                    batch_first=True,
                                    total_length=total_length)
    recon_logits = self.out(output)
    return recon_logits  # (B, L, vocab_size)
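
Note: the word_drop call above is Bowman-style word dropout applied to the decoder input ids before embedding. The helper itself is not shown in this example; below is a minimal sketch, assuming it replaces each non-pad token id with <unk> with probability drop_prob (the function name and the unk_idx/pad_idx arguments are assumptions, not this repository's API).

import torch

def word_drop_sketch(tokens, drop_prob, unk_idx, pad_idx):
    # tokens: (B, L) LongTensor of token ids.
    if drop_prob <= 0.:
        return tokens
    # Sample a Bernoulli mask and only drop real (non-pad) positions.
    drop_mask = ((torch.rand(tokens.shape, device=tokens.device) < drop_prob)
                 & (tokens != pad_idx))
    return tokens.masked_fill(drop_mask, unk_idx)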
Example #2
def forward(self, orig, para, z):  # train time
    orig, orig_lengths = orig  # (B, l), (B,)
    orig = self.embedding(orig)  # (B, l, 300)
    orig_packed = pack_padded_sequence(orig,
                                       orig_lengths,
                                       batch_first=True)
    _, orig_hidden = self.lstm_orig(orig_packed)
    para, _ = append(truncate(para, 'eos'), 'sos')
    if self.word_drop > 0.:
        para = word_drop(para, self.word_drop)  # from Bowman's paper
    para = self.embedding(para)
    L = para.size(1)
    para_z = torch.cat([para, z.repeat(1, L, 1)],
                       dim=-1)  # (B, L, 1100+300)
    para_output, _ = self.lstm_para(para_z, orig_hidden)  # no packing
    logits = self.linear(para_output)
    return logits  # (B, L, vocab_size)
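
Note: the logits returned by both decoders above are usually scored against the unshifted target ids with a padding-masked cross entropy. A minimal sketch of that step follows; the helper name is an assumption and not taken from these repositories.

import torch.nn.functional as F

def reconstruction_loss(logits, targets, pad_idx):
    # logits: (B, L, vocab_size), targets: (B, L) token ids including <eos>.
    # Sum over non-pad tokens, then average over the batch.
    return F.cross_entropy(logits.transpose(1, 2), targets,
                           ignore_index=pad_idx, reduction='sum') / logits.size(0)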
Example #3
def f_step(config,
           vocab,
           model_F,
           model_D,
           optimizer_F,
           batch,
           temperature,
           drop_decay,
           cyc_rec_enable=True):
    model_D.eval()

    pad_idx = vocab.stoi['<pad>']
    eos_idx = vocab.stoi['<eos>']
    unk_idx = vocab.stoi['<unk>']
    vocab_size = len(vocab)
    loss_fn = nn.NLLLoss(reduction='none')

    inp_tokens, inp_lengths, raw_styles = batch_preprocess(
        batch, pad_idx, eos_idx)
    rev_styles = 1 - raw_styles
    batch_size = inp_tokens.size(0)
    token_mask = (inp_tokens != pad_idx).float()

    optimizer_F.zero_grad()

    # self reconstruction loss

    noise_inp_tokens = word_drop(inp_tokens, inp_lengths,
                                 config.inp_drop_prob * drop_decay, vocab)
    noise_inp_lengths = get_lengths(noise_inp_tokens, eos_idx)

    slf_log_probs = model_F(
        noise_inp_tokens,
        inp_tokens,
        noise_inp_lengths,
        raw_styles,
        generate=False,
        differentiable_decode=False,
        temperature=temperature,
    )

    slf_rec_loss = loss_fn(slf_log_probs.transpose(1, 2),
                           inp_tokens) * token_mask
    slf_rec_loss = slf_rec_loss.sum() / batch_size
    slf_rec_loss *= config.slf_factor

    slf_rec_loss.backward()

    # cycle consistency loss

    if not cyc_rec_enable:
        optimizer_F.step()
        model_D.train()
        return slf_rec_loss.item(), 0, 0

    gen_log_probs = model_F(
        inp_tokens,
        None,
        inp_lengths,
        rev_styles,
        generate=True,
        differentiable_decode=True,
        temperature=temperature,
    )

    gen_soft_tokens = gen_log_probs.exp()
    gen_lengths = get_lengths(gen_soft_tokens.argmax(-1), eos_idx)

    cyc_log_probs = model_F(
        gen_soft_tokens,
        inp_tokens,
        gen_lengths,
        raw_styles,
        generate=False,
        differentiable_decode=False,
        temperature=temperature,
    )

    cyc_rec_loss = loss_fn(cyc_log_probs.transpose(1, 2),
                           inp_tokens) * token_mask
    cyc_rec_loss = cyc_rec_loss.sum() / batch_size
    cyc_rec_loss *= config.cyc_factor

    # style consistency loss

    adv_log_probs = model_D(gen_soft_tokens, gen_lengths, rev_styles)
    if config.discriminator_method == 'Multi':
        adv_labels = rev_styles + 1
    else:
        adv_labels = torch.ones_like(rev_styles)
    adv_loss = loss_fn(adv_log_probs, adv_labels)
    adv_loss = adv_loss.sum() / batch_size
    adv_loss *= config.adv_factor

    (cyc_rec_loss + adv_loss).backward()

    # update parameters

    clip_grad_norm_(model_F.parameters(), 5)
    optimizer_F.step()

    model_D.train()

    return slf_rec_loss.item(), cyc_rec_loss.item(), adv_loss.item()
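
Note: get_lengths and batch_preprocess are project-local helpers that are not shown here. A plausible sketch of get_lengths, assuming a sequence's length is counted up to and including its first <eos> (the real implementation may differ):

import torch

def get_lengths_sketch(tokens, eos_idx):
    # Index of the first <eos> plus one; rows without <eos> fall back to
    # the full padded length.
    is_eos = (tokens == eos_idx)
    first_eos = is_eos.int().argmax(dim=1)
    full_len = torch.full_like(first_eos, tokens.size(1))
    return torch.where(is_eos.any(dim=1), first_eos + 1, full_len)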
Example #4
def f_step(config,
           model_F,
           model_D,
           optimizer_F,
           batch,
           temperature,
           drop_decay,
           cyc_rec_enable=True):
    model_D.eval()

    pad_idx = model_F.tokenizer.pad_token_id
    eos_idx = model_F.tokenizer.eos_token_id
    unk_idx = model_F.tokenizer.unk_token_id
    vocab_size = model_F.tokenizer.vocab_size
    loss_fn = nn.NLLLoss(reduction='none')

    # unpack the batch tensors
    inp_tokens = batch["source_ids"].to(config.device)
    inp_lengths = get_lengths(inp_tokens, eos_idx).to(config.device)
    raw_styles = batch["source_style"].to(config.device)
    raw_styles = raw_styles.unsqueeze(1)

    rev_styles = 1 - raw_styles
    batch_size = inp_tokens.size(0)
    token_mask = batch['source_mask'].to(config.device)
    target_ids = batch['target_ids'].to(config.device)

    y_ids = target_ids[:, :-1].contiguous()
    lm_labels = target_ids[:, 1:].clone()
    lm_labels[target_ids[:, 1:] == pad_idx] = -100

    optimizer_F.zero_grad()

    # self reconstruction loss

    noise_inp_tokens, noise_token_mask = word_drop(
        inp_tokens, token_mask, inp_lengths, config.inp_drop_prob * drop_decay)
    noise_inp_lengths = get_lengths(noise_inp_tokens, eos_idx)

    noise_inp_tokens = torch.cat((raw_styles, noise_inp_tokens),
                                 1).to(config.device)
    noise_token_mask = torch.cat(
        (torch.ones_like(raw_styles), noise_token_mask), 1).to(config.device)

    outputs = model_F(input_ids=noise_inp_tokens,
                      attention_mask=noise_token_mask,
                      decoder_input_ids=y_ids,
                      lm_labels=lm_labels)
    slf_log_probs = outputs[1]  # outputs: 0 = LM loss, 1 = logits, 2 = hidden, 3 = attention?

    # slf_rec_loss = loss_fn(slf_log_probs.transpose(1, 2), inp_tokens) * token_mask
    # slf_rec_loss = slf_rec_loss.sum() / batch_size

    slf_rec_loss = outputs[0]
    slf_rec_loss *= config.slf_factor

    slf_rec_loss.backward()

    # cycle consistency loss
    # return early if cycle reconstruction is disabled
    if not cyc_rec_enable:
        optimizer_F.step()
        model_D.train()
        return slf_rec_loss.item(), 0, 0

    inp_tokens_rev = torch.cat((rev_styles, inp_tokens), 1).to(config.device)
    token_mask_rev = torch.cat((torch.ones_like(rev_styles), token_mask),
                               1).to(config.device)

    # outputs = model_F(
    #     input_ids=inp_tokens_rev, attention_mask=token_mask_rev, decoder_input_ids=None, lm_labels=None
    # )

    # pass y_ids and lm_labels so the model also returns a loss value
    outputs = model_F(input_ids=inp_tokens_rev,
                      attention_mask=token_mask_rev,
                      decoder_input_ids=y_ids,
                      lm_labels=lm_labels)

    gen_log_probs = outputs[1]
    # gen_soft_tokens = gen_log_probs.exp()
    # gen_lengths = get_lengths(gen_soft_tokens.argmax(-1), eos_idx)

    gen_soft_tokens = gen_log_probs.argmax(-1)
    gen_lengths = get_lengths(gen_soft_tokens, eos_idx)

    # pos_idx = torch.arange(max_seq_len).unsqueeze(0).expand((batch_size, -1)).to(config.device)
    # gen_token_mask = pos_idx >= gen_lengths.unsqueeze(-1)
    # pdb.set_trace()
    gen_token_mask = torch.zeros_like(gen_soft_tokens)
    for i, length in enumerate(gen_lengths):
        # logger.info(length, gen_soft_tokens.size())
        gen_token_mask[i] = torch.LongTensor(
            [1] * length + [0] * (gen_soft_tokens.size(1) - length)).to(
                config.device)

    raw_styles = raw_styles.type(torch.LongTensor).to(config.device)
    gen_soft_tokens = torch.cat((raw_styles, gen_soft_tokens), 1)
    gen_token_mask = torch.cat((torch.ones_like(raw_styles), gen_token_mask),
                               1)

    rev_y_ids = inp_tokens[:, :-1].contiguous()
    rev_lm_labels = inp_tokens[:, 1:].clone()
    rev_lm_labels[inp_tokens[:, 1:] == pad_idx] = -100

    outputs = model_F(input_ids=gen_soft_tokens,
                      attention_mask=gen_token_mask,
                      decoder_input_ids=rev_y_ids,
                      lm_labels=rev_lm_labels)

    cyc_rec_loss = outputs[0]
    cyc_rec_loss *= config.cyc_factor

    # style consistency loss

    adv_log_probs = model_D(gen_log_probs.exp(), gen_lengths)
    if config.discriminator_method == 'Multi':
        adv_labels = rev_styles + 1
    else:
        adv_labels = torch.ones_like(rev_styles)
    # pdb.set_trace()

    adv_labels = adv_labels.squeeze(1)
    adv_loss = loss_fn(adv_log_probs, adv_labels)
    adv_loss = adv_loss.sum()
    adv_loss *= config.adv_factor

    (cyc_rec_loss + adv_loss).backward()

    # update parameters

    clip_grad_norm_(model_F.parameters(), 5)

    optimizer_F.step()
    model_F.lr_scheduler.step()

    model_D.train()

    return slf_rec_loss.item(), cyc_rec_loss.item(), adv_loss.item()
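
Note: the Python loop that fills gen_token_mask row by row can be vectorized, much like the commented-out pos_idx variant above. A sketch (the helper name is an assumption):

import torch

def length_mask(lengths, max_len, device):
    # (B, max_len) mask: 1 for positions before each sequence's length, else 0.
    pos = torch.arange(max_len, device=device).unsqueeze(0)
    return (pos < lengths.unsqueeze(1)).long()

Something like gen_token_mask = length_mask(gen_lengths, gen_soft_tokens.size(1), config.device) would then replace the per-example loop.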