Exemplo n.º 1
0
    def inference(data_iter, raw_style):
        """Decode every batch of ``data_iter`` in its own style and in the flipped style.

        Uses ``model_F``, ``vocab``, ``eos_idx`` and ``temperature`` from the
        enclosing scope.

        Args:
            data_iter: iterable of batches whose ``.text`` is a 2-D tensor of
                token ids with the batch along dim 0.
            raw_style: style id shared by every sentence in ``data_iter``; the
                reversed style is ``1 - raw_style``.

        Returns:
            (gold_text, raw_output, rev_output): parallel lists holding the
            decoded input sentences, the model output conditioned on the
            original style, and the model output conditioned on the reversed
            style.
        """
        gold_text = []
        raw_output = []
        rev_output = []
        for batch in data_iter:
            inp_tokens = batch.text
            inp_lengths = get_lengths(inp_tokens, eos_idx)
            raw_styles = torch.full_like(inp_tokens[:, 0], raw_style)
            rev_styles = 1 - raw_styles

            # Pure inference: no autograd graph is needed for either pass.
            with torch.no_grad():
                raw_log_probs = model_F(
                    inp_tokens,
                    None,
                    inp_lengths,
                    raw_styles,
                    generate=True,
                    differentiable_decode=False,
                    temperature=temperature,
                )
                rev_log_probs = model_F(
                    inp_tokens,
                    None,
                    inp_lengths,
                    rev_styles,
                    generate=True,
                    differentiable_decode=False,
                    temperature=temperature,
                )

            # Convert ids to text from CPU tensors, like the ratio-based
            # inference does. Calling ``.cuda()`` here raised on machines
            # without CUDA and copied CPU tensors to the GPU just to decode.
            gold_text += tensor2text(vocab, inp_tokens.cpu())
            raw_output += tensor2text(vocab, raw_log_probs.argmax(-1).cpu())
            rev_output += tensor2text(vocab, rev_log_probs.argmax(-1).cpu())

        return gold_text, raw_output, rev_output
Exemplo n.º 2
0
    def inference(data_iter, raw_style):
        """Decode ``data_iter`` in its own style and at interpolated target styles.

        For each ratio ``r`` in ``ratios`` (from the enclosing scope) the style
        fed to ``model_F`` is ``raw + (rev - raw) * r`` with ``rev = 1 - raw``.
        Also uses ``model_F``, ``vocab``, ``eos_idx`` and ``temperature`` from
        the enclosing scope.

        Returns:
            (gold_text, raw_output, rev_outputs) where ``rev_outputs[i]`` holds
            the decoded sentences produced for the i-th ratio.
        """
        def decode(tokens, lengths, styles):
            # Run model_F in generate mode without autograd; hand back the
            # argmax token ids on the CPU.
            with torch.no_grad():
                log_probs = model_F(
                    tokens,
                    None,
                    lengths,
                    styles,
                    generate=True,
                    differentiable_decode=False,
                    temperature=temperature,
                )
            return log_probs.argmax(-1).cpu()

        gold_text, raw_output = [], []
        rev_outputs = [[] for _ in ratios]
        for batch in data_iter:
            tokens = batch.text
            lengths = get_lengths(tokens, eos_idx)
            raw_styles = torch.full_like(tokens[:, 0], raw_style)
            # Difference between the reversed style (1 - raw) and the raw one.
            style_shift = (1 - raw_styles) - raw_styles

            raw_ids = decode(tokens, lengths, raw_styles)
            gold_text.extend(tensor2text(vocab, tokens.cpu()))
            raw_output.extend(tensor2text(vocab, raw_ids))

            for bucket, ratio in zip(rev_outputs, ratios):
                mixed_styles = raw_styles + style_shift * ratio
                bucket.extend(
                    tensor2text(vocab, decode(tokens, lengths, mixed_styles)))

        return gold_text, raw_output, rev_outputs
Exemplo n.º 3
0
        dataset=dataset,
        batch_size=config.batch_size,
        shuffle=train,
        repeat=train,
        sort_key=lambda x: len(x.text),
        sort_within_batch=False,
        device=config.device)

    train_pos_iter, train_neg_iter = map(lambda x: dataiter_fn(x, True),
                                         [train_pos_set, train_neg_set])
    dev_pos_iter, dev_neg_iter = map(lambda x: dataiter_fn(x, False),
                                     [dev_pos_set, dev_neg_set])
    test_pos_iter, test_neg_iter = map(lambda x: dataiter_fn(x, False),
                                       [test_pos_set, test_neg_set])

    train_iters = DatasetIterator(train_pos_iter, train_neg_iter)
    dev_iters = DatasetIterator(dev_pos_iter, dev_neg_iter)
    test_iters = DatasetIterator(test_pos_iter, test_neg_iter)

    return train_iters, dev_iters, test_iters, vocab


if __name__ == '__main__':
    # Smoke test: load the data, report the vocabulary size and dump the
    # first training batch as text together with its labels.
    # NOTE(review): the load_dataset body above reads ``config.batch_size``;
    # confirm that a bare path string is an acceptable argument here.
    train_iter, _, _, vocab = load_dataset('../data/swbd/')
    vocab_size = len(vocab)
    print(vocab_size)
    for sample_batch in train_iter:
        sentences = tensor2text(vocab, sample_batch.text)
        print('\n'.join(sentences))
        print(sample_batch.label)
        break  # a single batch is enough for a sanity check
Exemplo n.º 4
0
def _first_batch(data_iter):
    """Return the first batch yielded by ``data_iter``.

    Raises:
        ValueError: if ``data_iter`` yields no batches.
    """
    for batch in data_iter:
        return batch
    raise ValueError('data iterator yielded no batches')


def _collect_tsne_features(model_F, model_D, data_iter, raw_style, eos_idx,
                           raw_label, rev_label):
    """Collect discriminator features of original and style-transferred sentences.

    For each batch, the features of the input sentences are tagged
    ``raw_label``. The batch is then decoded by ``model_F`` in style
    ``1 - raw_style``, and the features of the generated sentences are tagged
    ``rev_label``.

    Returns:
        (features, labels): a list of per-sentence numpy feature vectors and a
        parallel list of integer labels, interleaved batch by batch
        (inputs first, then their transfers).
    """
    features = []
    labels = []
    for batch in data_iter:
        inp_tokens = batch.text
        inp_lengths = get_lengths(inp_tokens, eos_idx)
        rev_styles = 1 - torch.full_like(inp_tokens[:, 0], raw_style)

        # Everything here is inference only; avoid building autograd graphs.
        with torch.no_grad():
            _, raw_features = model_D(inp_tokens,
                                      inp_lengths,
                                      return_features=True)
            rev_log_probs = model_F(inp_tokens,
                                    None,
                                    inp_lengths,
                                    rev_styles,
                                    generate=True,
                                    differentiable_decode=False,
                                    temperature=1)
            rev_tokens = rev_log_probs.argmax(-1)
            # Generated sentences have their own <eos> positions, so their
            # lengths must be measured rather than borrowed from the input.
            rev_lengths = get_lengths(rev_tokens, eos_idx)
            _, rev_features = model_D(rev_tokens,
                                      rev_lengths,
                                      return_features=True)

        features.extend(raw_features.cpu().numpy())
        labels.extend([raw_label] * raw_features.shape[0])
        features.extend(rev_features.cpu().numpy())
        labels.extend([rev_label] * rev_features.shape[0])
    return features, labels


def part2(args):
    """Run the three "part 2" analyses on a trained style-transfer checkpoint.

    Weights are read from ``{args.part2_model_dir}/{args.part2_step}_F.pth``
    and ``..._D.pth``. These paths are also stored on ``args.preload_F`` and
    ``args.preload_D``. On the test split it then:

    * 2-1: saves the decoder attention maps (one per layer) for a random
      positive sample transferred to the negative style;
    * 2-2: saves a t-SNE plot of discriminator features for original and
      transferred sentences of both styles;
    * 2-3: replaces the input tokens of one positive sample one at a time with
      ``<unk>`` and logs the transfer of every masked variant.

    Figures and ``log.txt`` are written to ``part2_output/``.

    Raises:
        FileNotFoundError: if either checkpoint file is missing.
        ValueError: if the positive test iterator yields no batches.
    """
    ## load model
    model_prefix = os.path.join(args.part2_model_dir, str(args.part2_step))

    args.preload_F = f'{model_prefix}_F.pth'
    args.preload_D = f'{model_prefix}_D.pth'
    # Explicit checks (not assert) so they survive ``python -O``; done before
    # the dataset is loaded so a bad path fails fast.
    for ckpt in (args.preload_F, args.preload_D):
        if not os.path.isfile(ckpt):
            raise FileNotFoundError(f'checkpoint not found: {ckpt}')

    ## load data
    _, _, test_iters, vocab = load_dataset(args)

    ## output dir
    output_dir = 'part2_output'
    os.makedirs(output_dir, exist_ok=True)

    model_F = StyleTransformer(args, vocab).to(args.device)
    model_D = Discriminator(args, vocab).to(args.device)
    # map_location lets a GPU-saved checkpoint load on the current device.
    model_F.load_state_dict(torch.load(args.preload_F,
                                       map_location=args.device))
    model_D.load_state_dict(torch.load(args.preload_D,
                                       map_location=args.device))

    model_F.eval()
    model_D.eval()

    pos_iter = test_iters.pos_iter
    neg_iter = test_iters.neg_iter

    eos_idx = vocab.stoi['<eos>']
    unk_idx = vocab.stoi['<unk>']

    with open(os.path.join(output_dir, 'log.txt'), 'w') as log_f:
        ## 2-1 attention
        log(log_f, "***** 2-1: Attention *****")

        raw_style = 1  ## neg: 0, pos: 1

        # Only the first batch is analysed, to keep this step fast.
        batch = _first_batch(pos_iter)
        inp_tokens = batch.text
        inp_lengths = get_lengths(inp_tokens, eos_idx)
        rev_styles = 1 - torch.full_like(inp_tokens[:, 0], raw_style)

        with torch.no_grad():
            rev_log_probs = model_F(inp_tokens,
                                    None,
                                    inp_lengths,
                                    rev_styles,
                                    generate=True,
                                    differentiable_decode=False,
                                    temperature=1)

        # attn_weight[layer] = (Batch, Head, Source, Style+Target)
        attn_weight = model_F.get_decode_src_attn_weight()

        gold_text = tensor2text(vocab, inp_tokens.cpu())
        rev_idx = rev_log_probs.argmax(-1).cpu()
        rev_output = tensor2text(vocab, rev_idx)
        # .tolist() yields plain ints, avoiding one device sync per token.
        gold_token = [[vocab.itos[j] for j in row]
                      for row in inp_tokens.tolist()]
        rev_token = [[vocab.itos[j] for j in row] for row in rev_idx.tolist()]

        idx = np.random.randint(len(rev_output))
        log(log_f, '*' * 20, 'pos sample', '*' * 20)
        log(log_f, '[gold]', gold_text[idx])
        log(log_f, '[rev ]', rev_output[idx])
        for layer, layer_attn in enumerate(attn_weight):
            output_name = os.path.join(output_dir,
                                       f'problem1_attn_layer{layer}.png')
            show_attn(gold_token[idx], rev_token[idx], layer_attn[idx],
                      'attention map', output_name)
            log(log_f, f'save attention figure at {output_name}')

        log(log_f, '***** 2-1 end *****')
        log(log_f)

        ## 2-2. tsne
        log(log_f, "***** 2-2: T-sne *****")
        # Labels: 0 POS, 1 POS -> NEG, 2 NEG, 3 NEG -> POS (see `classes`).
        pos_features, pos_labels = _collect_tsne_features(
            model_F, model_D, pos_iter, raw_style=1, eos_idx=eos_idx,
            raw_label=0, rev_label=1)
        neg_features, neg_labels = _collect_tsne_features(
            model_F, model_D, neg_iter, raw_style=0, eos_idx=eos_idx,
            raw_label=2, rev_label=3)
        features = np.asarray(pos_features + neg_features)
        labels = np.array(pos_labels + neg_labels)

        colors = ['red', 'blue', 'orange', 'green']
        classes = ['POS', 'POS -> NEG', 'NEG', 'NEG -> POS']
        X_emb = TSNE(n_components=2).fit_transform(features)

        fig, ax = plt.subplots()
        for label, (color, class_name) in enumerate(zip(colors, classes)):
            idxs = labels == label
            ax.scatter(X_emb[idxs, 0],
                       X_emb[idxs, 1],
                       color=color,
                       label=class_name,
                       alpha=0.8,
                       edgecolors='none')
        ax.legend()
        ax.set_title('t-sne of four distributions')
        output_name = os.path.join(output_dir, 'problem2_tsne.png')
        fig.savefig(output_name)
        plt.close(fig)  # release the figure instead of leaking it
        log(log_f, f'save T-sne figure at {output_name}')
        log(log_f, "***** 2-2 end *****")
        log(log_f)

        # 2-3. mask input tokens
        log(log_f, '***** 2-3: mask input *****')
        raw_style = 1

        batch = _first_batch(pos_iter)  ## only select first batch
        inp_tokens = batch.text
        inp_lengths = get_lengths(inp_tokens, eos_idx)

        sample_idx = np.random.randint(inp_tokens.shape[0])
        inp_token = inp_tokens[sample_idx]
        # Plain int so it is usable as a repeat count and a fill value.
        inp_length = int(inp_lengths[sample_idx])

        # Row 0 is the original sentence; row i + 1 replaces token i with
        # <unk>. The last two positions are never masked (the original
        # author's note: mask until '. <eos>').
        masked_tokens = inp_token.repeat(inp_length - 1, 1)
        for i in range(inp_length - 2):
            masked_tokens[i + 1, i] = unk_idx

        masked_lengths = torch.full_like(masked_tokens[:, 0], inp_length)
        rev_styles = 1 - torch.full_like(masked_tokens[:, 0], raw_style)

        with torch.no_grad():
            rev_log_probs = model_F(masked_tokens,
                                    None,
                                    masked_lengths,
                                    rev_styles,
                                    generate=True,
                                    differentiable_decode=False,
                                    temperature=1)

        gold_text = tensor2text(vocab, masked_tokens.cpu(), remain_unk=True)
        rev_idx = rev_log_probs.argmax(-1).cpu()
        rev_output = tensor2text(vocab, rev_idx, remain_unk=True)

        for org_sentence, rev_sentence in zip(gold_text, rev_output):
            log(log_f, '-')
            log(log_f, '[ORG]', org_sentence)
            log(log_f, '[REV]', rev_sentence)

        log(log_f, '***** 2-3 end *****')
Exemplo n.º 5
0
    def forward(self,
                inp_tokens,
                gold_tokens,
                inp_lengths,
                style,
                generate=False,
                differentiable_decode=False,
                temperature=1.0):
        """Encode ``inp_tokens`` with a prepended style embedding, then decode.

        Args:
            inp_tokens: source token ids; dim 0 is the batch and dim 1 the
                sequence (read via ``size(0)`` / ``size(1)``).
            gold_tokens: target token ids used for teacher forcing; ignored
                when ``generate`` is true.
            inp_lengths: per-sample source lengths; source positions at or
                beyond a sample's length are set True in ``src_mask``
                (presumably padding — confirm against the encoder). The style
                slot at position 0 is never masked.
            style: style ids, embedded with ``self.style_embed`` and prepended
                to the source sequence.
            generate: if true, decode step by step for ``self.max_length``
                steps; otherwise decode ``gold_tokens[:, :-1]`` after the
                start token in one pass.
            differentiable_decode: when generating, feed the next step the
                embedding of ``log_prob.exp()`` instead of the argmax token.
            temperature: forwarded to the decoder.

        Returns:
            Decoder log-probabilities. In generate mode the per-step outputs
            are concatenated along dim 1.

        Raises:
            ValueError: if the source is longer than ``self.max_length``.
        """
        batch_size = inp_tokens.size(0)
        max_enc_len = inp_tokens.size(1)

        if max_enc_len > self.max_length:
            # Position indices only cover self.max_length slots. A longer
            # source cannot be embedded, and src_mask.view() below would fail
            # with an opaque shape error, so fail here with a clear message.
            raise ValueError(
                f'input length {max_enc_len} exceeds max_length '
                f'{self.max_length}: {tensor2text(self.vocab, inp_tokens)}')

        device = inp_lengths.device
        # Build index/mask tensors directly on the target device.
        pos_idx = torch.arange(self.max_length, device=device).unsqueeze(
            0).expand((batch_size, -1))

        src_mask = pos_idx[:, :max_enc_len] >= inp_lengths.unsqueeze(-1)
        # Prepend an always-False column for the style embedding slot.
        src_mask = torch.cat((torch.zeros_like(src_mask[:, :1]), src_mask), 1)
        src_mask = src_mask.view(batch_size, 1, 1, max_enc_len + 1)

        # True strictly above the diagonal: future target positions.
        tgt_mask = torch.ones((self.max_length, self.max_length),
                              device=device)
        tgt_mask = (tgt_mask.tril() == 0).view(1, 1, self.max_length,
                                               self.max_length)

        style_emb = self.style_embed(style).unsqueeze(1)

        enc_input = torch.cat(
            (style_emb, self.embed(inp_tokens, pos_idx[:, :max_enc_len])), 1)
        memory = self.encoder(enc_input, src_mask)

        sos_token = self.sos_token.view(1, 1, -1).expand(batch_size, -1, -1)

        if not generate:
            # Teacher forcing: start token followed by gold_tokens[:, :-1].
            dec_input = gold_tokens[:, :-1]
            max_dec_len = gold_tokens.size(1)
            dec_input_emb = torch.cat(
                (sos_token, self.embed(dec_input,
                                       pos_idx[:, :max_dec_len - 1])), 1)
            log_probs = self.decoder(
                dec_input_emb, memory, src_mask,
                tgt_mask[:, :, :max_dec_len, :max_dec_len], temperature)
        else:
            log_probs = []
            next_token = sos_token
            prev_states = None

            # Always runs the full max_length steps (no early stop on <eos>).
            for k in range(self.max_length):
                log_prob, prev_states = self.decoder.incremental_forward(
                    next_token, memory, src_mask,
                    tgt_mask[:, :, k:k + 1, :k + 1], temperature, prev_states)

                log_probs.append(log_prob)

                if differentiable_decode:
                    next_token = self.embed(log_prob.exp(),
                                            pos_idx[:, k:k + 1])
                else:
                    next_token = self.embed(log_prob.argmax(-1),
                                            pos_idx[:, k:k + 1])

            log_probs = torch.cat(log_probs, 1)

        return log_probs