Code example #1
import os

import torch
import wandb

# StyleTransformer, Discriminator, load_dataset, train, and dev_eval are
# project-level helpers assumed to be importable from the repository.


def main(args):
    # logging
    if args.use_wandb:
        wandb.init(project="HW5-TextStyleTransfer", config=args)
        #wandb.config.update(vars(args))
        args = wandb.config
        print(args)
    
    train_iters, dev_iters, test_iters, vocab = load_dataset(args)
    print('Vocab size:', len(vocab))
    model_F = StyleTransformer(args, vocab).to(args.device)
    model_D = Discriminator(args, vocab).to(args.device)
    print(args.discriminator_method)

    if os.path.isfile(args.preload_F):
        temp = torch.load(args.preload_F)
        model_F.load_state_dict(temp)
    if os.path.isfile(args.preload_D):
        temp = torch.load(args.preload_D)
        model_D.load_state_dict(temp)
    
    if args.do_train:
        train(args, vocab, model_F, model_D, train_iters, dev_iters, test_iters)
    if args.do_test:
        dev_eval(args, vocab, model_F, test_iters, 0.5)
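
Example #1 reads everything from an args namespace, so it is typically driven by an argparse front end. Below is a minimal, hypothetical sketch of such a wrapper: the flag names are inferred from the attributes main(args) accesses, not taken from the actual repository, and the default for discriminator_method is only a guess.

import argparse

import torch

# Hypothetical CLI wrapper for Example #1 (flag names inferred from usage).
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--use_wandb', action='store_true')
    parser.add_argument('--preload_F', default='', help='path to a saved StyleTransformer state_dict')
    parser.add_argument('--preload_D', default='', help='path to a saved Discriminator state_dict')
    parser.add_argument('--do_train', action='store_true')
    parser.add_argument('--do_test', action='store_true')
    parser.add_argument('--discriminator_method', default='Multi')
    args = parser.parse_args()
    args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    main(args)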
Code example #2
def test():
    config = Config()
    train_iters, dev_iters, test_iters, vocab = load_dataset(config)

    step = 125
    model_F = StyleTransformer(config, vocab).to(config.device)
    model_F.load_state_dict(
        torch.load(f'./save/Jun15042756/ckpts/{step}_F.pth'))

    auto_eval(config, vocab, model_F, test_iters, 1, step)
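
Because checkpoints are written as {step}_F.pth, the same pattern extends naturally to scoring several checkpoints in one run. A short sketch, reusing the run folder from the snippet above; the step numbers are purely illustrative:

# Evaluate several saved checkpoints in sequence (sketch; steps are placeholders).
for step in (125, 250, 500):
    model_F.load_state_dict(
        torch.load(f'./save/Jun15042756/ckpts/{step}_F.pth'))
    auto_eval(config, vocab, model_F, test_iters, 1, step)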
Code example #3
def main():
    config = Config()
    train_iters, dev_iters, test_iters, vocab = load_dataset(config)
    print('Vocab size:', len(vocab))
    model_F = StyleTransformer(config, vocab).to(config.device)
    model_F.load_state_dict(torch.load('./save/Jun11115103/ckpts/9925_F.pth'))
    model_D = Discriminator(config, vocab).to(config.device)
    print(config.discriminator_method)

    train(config, vocab, model_F, model_D, train_iters, dev_iters, test_iters)
Code example #4
def main():
    config = Config()
    train_iters, dev_iters, test_iters, vocab = load_dataset(
        config,
        train_pos='cvet2.train', train_neg='push2.train',
        dev_pos='cvet2.dev', dev_neg='push2.dev',
        test_pos='cvet2.test', test_neg='push2.test')
    print('Vocab size:', len(vocab))
    model_F = StyleTransformer(config, vocab).to(config.device)
    model_D = Discriminator(config, vocab).to(config.device)
    print(config.discriminator_method)
    
    train(config, vocab, model_F, model_D, train_iters, dev_iters, test_iters)
    torch.save(model_F.state_dict(), 'modelF_trained')
    torch.save(model_D.state_dict(), 'modelD_trained')
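
The two torch.save calls above persist bare state_dicts, so restoring them later requires rebuilding the models first. A minimal sketch, assuming the same Config and vocab used for training:

# Reload the weights written by Example #4 (sketch).
model_F = StyleTransformer(config, vocab).to(config.device)
model_F.load_state_dict(torch.load('modelF_trained', map_location=config.device))
model_D = Discriminator(config, vocab).to(config.device)
model_D.load_state_dict(torch.load('modelD_trained', map_location=config.device))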
Code example #5
def main():
    config = Config()
    train_iters, dev_iters, test_iters, vocab = load_dataset(config)
    print('Vocab size:', len(vocab))
    model_F = StyleTransformer(config, vocab).to(config.device)
    model_D = Discriminator(config, vocab).to(config.device)
    print(config.discriminator_method)

    train(config, vocab, model_F, model_D, train_iters, dev_iters, test_iters)
Code example #6
File: main.py Project: daxborde/style-transformer
def main():
    config = Config()
    train_iters, test_iters, vocab = load_enron(config)
    print('Vocab size:', len(vocab))
    model_F = StyleTransformer(config, vocab).to(config.device)
    model_D = Discriminator(config, vocab).to(config.device)
    print(config.discriminator_method)

    # last_checkpoint = most_recent_path(most_recent_path(config.save_path), return_two=True)
    # if last_checkpoint:
    #     print(last_checkpoint)
    #     model_D.load_state_dict(torch.load(last_checkpoint[1]))
    #     model_F.load_state_dict(torch.load(last_checkpoint[0]))

    train(config, vocab, model_F, model_D, train_iters, test_iters)
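
The commented-out resume block relies on a most_recent_path helper that is not shown anywhere in the excerpt. A hypothetical reconstruction, based only on how it is called (newest entry of a directory; with return_two=True, the two newest files, here the latest F/D checkpoint pair):

import os

# Hypothetical reconstruction of most_recent_path; this is a guess at the
# helper's contract, not the project's actual code.
def most_recent_path(directory, return_two=False):
    entries = sorted((os.path.join(directory, name) for name in os.listdir(directory)),
                     key=os.path.getmtime, reverse=True)
    if not entries:
        return None
    return entries[:2] if return_two else entries[0]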
Code example #7
import os

import numpy as np
import torch
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

# StyleTransformer, Discriminator, load_dataset, tensor2text, get_lengths,
# log, and show_attn are project-level helpers assumed to be in scope.


def part2(args):

    ## load model
    #model_prefix = './save/Feb15203331/ckpts/1300'
    model_prefix = os.path.join(args.part2_model_dir, str(args.part2_step))

    args.preload_F = f'{model_prefix}_F.pth'
    args.preload_D = f'{model_prefix}_D.pth'

    ## load data
    train_iters, dev_iters, test_iters, vocab = load_dataset(args)

    ## output dir
    output_dir = 'part2_output'
    os.makedirs(output_dir, exist_ok=True)

    log_f = open(os.path.join(output_dir, 'log.txt'), 'w')

    model_F = StyleTransformer(args, vocab).to(args.device)
    model_D = Discriminator(args, vocab).to(args.device)

    assert os.path.isfile(args.preload_F)
    model_F.load_state_dict(torch.load(args.preload_F))
    assert os.path.isfile(args.preload_D)
    model_D.load_state_dict(torch.load(args.preload_D))

    model_F.eval()
    model_D.eval()

    dataset = test_iters
    pos_iter = dataset.pos_iter
    neg_iter = dataset.neg_iter

    pad_idx = vocab.stoi['<pad>']  # 1
    eos_idx = vocab.stoi['<eos>']  # 2
    unk_idx = vocab.stoi['<unk>']  # 0

    ## 2-1 attention
    log(log_f, "***** 2-1: Attention *****")

    gold_text = []
    gold_token = []
    rev_output = []
    rev_token = []
    attn_weight = None

    raw_style = 1  ## neg: 0, pos: 1

    for batch in pos_iter:

        inp_tokens = batch.text
        inp_lengths = get_lengths(inp_tokens, eos_idx)
        raw_styles = torch.full_like(inp_tokens[:, 0], raw_style)
        rev_styles = 1 - raw_styles

        with torch.no_grad():
            rev_log_probs = model_F(inp_tokens,
                                    None,
                                    inp_lengths,
                                    rev_styles,
                                    generate=True,
                                    differentiable_decode=False,
                                    temperature=1)

        rev_attn = model_F.get_decode_src_attn_weight()
        if attn_weight is None:
            attn_weight = rev_attn
        else:
            for layer in range(len(rev_attn)):
                attn_weight[layer] = torch.cat(
                    [attn_weight[layer], rev_attn[layer]])

        gold_text += tensor2text(vocab, inp_tokens.cpu())
        rev_idx = rev_log_probs.argmax(-1).cpu()
        rev_output += tensor2text(vocab, rev_idx)

        gold_token.extend([[vocab.itos[j] for j in i] for i in inp_tokens])
        rev_token.extend([[vocab.itos[j] for j in i] for i in rev_idx])

        break  ## select first batch to speed up

    # attn_weight[layer] = (Batch, Head, Source, Style+Target)

    idx = np.random.randint(len(rev_output))
    log(log_f, '*' * 20, 'pos sample', '*' * 20)
    log(log_f, '[gold]', gold_text[idx])
    log(log_f, '[rev ]', rev_output[idx])
    for l in range(len(attn_weight)):
        output_name = os.path.join(output_dir, f'problem1_attn_layer{l}.png')
        show_attn(gold_token[idx], rev_token[idx], attn_weight[l][idx],
                  'attention map', output_name)
        log(log_f, f'saved attention figure to {output_name}')

    log(log_f, '***** 2-1 end *****')
    log(log_f)

    ## 2-2. tsne
    log(log_f, "***** 2-2: T-sne *****")
    features = []
    labels = []

    for batch in pos_iter:

        inp_tokens = batch.text
        inp_lengths = get_lengths(inp_tokens, eos_idx)

        _, pos_features = model_D(inp_tokens,
                                  inp_lengths,
                                  return_features=True)
        features.extend(pos_features.detach().cpu().numpy())
        labels.extend([0 for i in range(pos_features.shape[0])])

        raw_style = 1
        raw_styles = torch.full_like(inp_tokens[:, 0], raw_style)
        rev_styles = 1 - raw_styles

        with torch.no_grad():
            rev_log_probs = model_F(inp_tokens,
                                    None,
                                    inp_lengths,
                                    rev_styles,
                                    generate=True,
                                    differentiable_decode=False,
                                    temperature=1)

        rev_tokens = rev_log_probs.argmax(-1)
        rev_lengths = get_lengths(rev_tokens, eos_idx)
        _, rev_features = model_D(rev_tokens,
                                  rev_lengths,
                                  return_features=True)
        features.extend(rev_features.detach().cpu().numpy())
        labels.extend([1 for i in range(rev_features.shape[0])])

    for batch in neg_iter:

        inp_tokens = batch.text
        inp_lengths = get_lengths(inp_tokens, eos_idx)

        _, neg_features = model_D(inp_tokens,
                                  inp_lengths,
                                  return_features=True)
        features.extend(neg_features.detach().cpu().numpy())
        labels.extend([2 for i in range(neg_features.shape[0])])

        raw_style = 0
        raw_styles = torch.full_like(inp_tokens[:, 0], raw_style)
        rev_styles = 1 - raw_styles

        with torch.no_grad():
            rev_log_probs = model_F(inp_tokens,
                                    None,
                                    inp_lengths,
                                    rev_styles,
                                    generate=True,
                                    differentiable_decode=False,
                                    temperature=1)

        rev_tokens = rev_log_probs.argmax(-1)
        rev_lengths = get_lengths(rev_tokens, eos_idx)
        _, rev_features = model_D(rev_tokens,
                                  rev_lengths,
                                  return_features=True)
        features.extend(rev_features.detach().cpu().numpy())
        labels.extend([3 for i in range(rev_features.shape[0])])

    labels = np.array(labels)
    colors = ['red', 'blue', 'orange', 'green']
    classes = ['POS', 'POS -> NEG', 'NEG', 'NEG -> POS']
    X_emb = TSNE(n_components=2).fit_transform(features)

    fig, ax = plt.subplots()
    for i in range(4):
        idxs = labels == i
        ax.scatter(X_emb[idxs, 0],
                   X_emb[idxs, 1],
                   color=colors[i],
                   label=classes[i],
                   alpha=0.8,
                   edgecolors='none')
    ax.legend()
    ax.set_title('t-SNE of four distributions')
    output_name = os.path.join(output_dir, 'problem2_tsne.png')
    plt.savefig(output_name)
    log(log_f, f'saved t-SNE figure to {output_name}')
    log(log_f, "***** 2-2 end *****")
    log(log_f)

    # 2-3. mask input tokens
    log(log_f, '***** 2-3: mask input *****')
    raw_style = 1

    for batch in pos_iter:
        inp_tokens = batch.text
        inp_lengths = get_lengths(inp_tokens, eos_idx)
        break  ## only select first batch

    sample_idx = np.random.randint(inp_tokens.shape[0])
    inp_token = inp_tokens[sample_idx]
    inp_length = inp_lengths[sample_idx]

    # one row per maskable position plus the unmasked original sentence;
    # the trailing '.' and '<eos>' are never masked
    inp_tokens = inp_token.repeat(inp_length - 2 + 1, 1)
    for i in range(inp_tokens.shape[0] - 1):
        inp_tokens[i + 1][i] = unk_idx

    inp_lengths = torch.full_like(inp_tokens[:, 0], inp_length)
    raw_styles = torch.full_like(inp_tokens[:, 0], raw_style)
    rev_styles = 1 - raw_styles

    with torch.no_grad():
        rev_log_probs = model_F(inp_tokens,
                                None,
                                inp_lengths,
                                rev_styles,
                                generate=True,
                                differentiable_decode=False,
                                temperature=1)

    gold_text = tensor2text(vocab, inp_tokens.cpu(), remain_unk=True)
    rev_idx = rev_log_probs.argmax(-1).cpu()
    rev_output = tensor2text(vocab, rev_idx, remain_unk=True)

    for i in range(len(gold_text)):
        log(log_f, '-')
        log(log_f, '[ORG]', gold_text[i])
        log(log_f, '[REV]', rev_output[i])

    log(log_f, '***** 2-3 end *****')
    log_f.close()
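
Example #7 leans on two small helpers whose definitions are not shown: get_lengths and log. The sketches below are consistent with how the snippet uses them (get_lengths mirrors the usual style-transformer formulation: count tokens up to and including the first <eos>); treat them as assumptions rather than the project's exact source.

import torch

# Length of each row of a (batch, seq_len) token tensor, counting up to
# and including the first <eos>.
def get_lengths(tokens, eos_idx):
    lengths = torch.cumsum((tokens == eos_idx).long(), dim=1)  # 0 before <eos>, >= 1 after
    lengths = (lengths == 0).long().sum(dim=-1)                # tokens strictly before <eos>
    return lengths + 1                                         # +1 for <eos> itself

# Print to stdout and mirror the line into the open log file; log(log_f)
# with no message emits a blank line, matching how it is called above.
def log(log_f, *args):
    print(*args)
    print(*args, file=log_f)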
Code example #8
import os
import time

import torch

# Config, load_dataset, StyleTransformer, tensor2text, calc_temperature, and
# my_eval_interp are project helpers assumed to be in scope; the tail of an
# evaluation helper was truncated from this excerpt.


if __name__ == '__main__':
    config = Config()
    config.save_folder = config.save_path + '/' + str(time.strftime('%b%d%H%M%S', time.localtime()))
    os.makedirs(config.save_folder)
    os.makedirs(config.save_folder + '/ckpts')
    print('Save Path:', config.save_folder)

    train_iters, dev_iters, test_iters, vocab = load_dataset(config)

    # print(len(vocab))
    # for batch in test_iters:
    #     text = tensor2text(vocab, batch[0])
    #     print('\n'.join(text))
    #     print(batch.label)
    #     break

    model_F = StyleTransformer(config, vocab).to(config.device)
    global_step = 1200
    save_path = f"save/Mar27144631/{global_step}_F.pth"
    state_dict = torch.load(save_path)
    model_F.load_state_dict(state_dict)
    temperature = calc_temperature(config.temperature_config, global_step)

    my_eval_interp(config, vocab, model_F, test_iters, global_step, temperature)
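
Example #8 maps the restored global step back to a decoding temperature via calc_temperature(config.temperature_config, global_step). A minimal sketch of such a step-indexed schedule, assuming temperature_config is a list of (temperature, step) pairs sorted by step; the repository's actual schedule may differ (e.g. it may interpolate between entries):

# Step-indexed temperature schedule (sketch): return the temperature of
# the last entry whose step threshold has been reached.
def calc_temperature(temperature_config, step):
    temperature = temperature_config[0][0]
    for t, s in temperature_config:
        if step >= s:
            temperature = t
    return temperature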