Example 1
def learn_pca(device='cpu', stuff=None):
    if stuff is None:
        tokenizer, model = load_model(device=device)
        corpus = loader.Corpus(DATA_DIR, CACHE_DIR, less_mem=True)
        # Cache the heavy objects on the numpy module so a later call can
        # reuse them via the `stuff` argument (a deliberate global hack).
        np.stuff = (tokenizer, model, corpus)
    else:
        tokenizer, model, corpus = stuff

    lines_dataset = LinesDataset(corpus, CONTEXT_MIN_SIZE, tokenizer.eos_token)
    lines_loader = torch.utils.data.DataLoader(lines_dataset,
                                               BATCH_SIZE,
                                               shuffle=True)
    model.transformer.output_hidden_states = True

    mean = None
    var = None
    count = 0

    for i in range(PCA_BATCHES):
        # Building a fresh iterator each step draws one random batch
        # (shuffle=True) instead of walking the loader sequentially.
        ldata = next(iter(lines_loader))
        ltids_list, lmasks_list = prepare_batch(tokenizer, ldata, device)

        for id_bit, m_bit in zip(ltids_list, lmasks_list):
            with torch.no_grad():
                outputs = model(id_bit, attention_mask=m_bit)
                embed_dim = outputs[2][0].shape[-1]
                pieces = [
                    (outputs[2][i] * m_bit[:, :, None]).view(-1, embed_dim)
                    for i in range(1, len(outputs[2]))
                ]
                vec = torch.cat(pieces, dim=1)
                # Signed-log transform to tame heavy-tailed activations.
                vec = torch.sign(vec) * torch.log(1 + torch.abs(vec))
                del outputs, pieces

                if mean is None:
                    mean = vec.new_zeros(vec.shape[1])
                    # The second-moment accumulator must be a D x D matrix.
                    var = vec.new_zeros(vec.shape[1], vec.shape[1])
                mean += vec.sum(dim=0)
                # Accumulate the unnormalized second moment E[x x^T].
                var.addmm_(vec.T, vec)
                piece_count = m_bit.sum()
                count += piece_count
        print('%d ' % (i * BATCH_SIZE), end='', flush=True)

    mean /= count
    var /= count
    # Covariance: E[x x^T] - mu mu^T.
    var.addmm_(-mean[:, None], mean[None, :])

    var_np = var.cpu().data.numpy()
    learner = TruncatedSVD(n_components=1024, n_iter=8)
    learner.fit(var_np)
    np.save('../data/gpt2/mean.npy', mean.cpu().data.numpy())
    np.save('../data/gpt2/e_vecs.npy', learner.components_)
    np.save('../data/gpt2/e_vals.npy', learner.singular_values_)

    # For a PSD covariance matrix the singular values equal the eigenvalues,
    # so this prints the fraction of variance NOT captured by the top 1024.
    explained_var = learner.singular_values_.sum()
    total_var = torch.diag(var).sum().item()
    print(1 - (explained_var / total_var))
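
The saved mean, eigenvector, and eigenvalue arrays are consumed later (see Examples 3 and 4) to whiten hidden-state features. A minimal sketch of that projection, where `vec` stands in for an (N, D) tensor built with the same signed-log transform as above:

import numpy as np
import torch

# PCA artifacts written by learn_pca above.
mean = torch.tensor(np.load('../data/gpt2/mean.npy'))
e_vecs = torch.tensor(np.load('../data/gpt2/e_vecs.npy'))
sqrt_e_vals = torch.sqrt(torch.tensor(np.load('../data/gpt2/e_vals.npy')))

vec = torch.randn(8, mean.shape[0])  # stand-in for real signed-log features

# Center, rotate into the eigenbasis, and scale to unit variance.
transf = (vec - mean) @ e_vecs.T / sqrt_e_vals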
Example 2
def evaluate(checkpt,
             count,
             sc_changer=None,
             stuff=None,
             device='cpu',
             seed=1337):
    if stuff is None:
        corpus = loader.Corpus(DATA_DIR, CACHE_DIR, less_mem=True)
        tokenizer, gpt2_model = load_model(device=device)
        np.stuff = (corpus, tokenizer, gpt2_model)
    else:
        corpus, tokenizer, gpt2_model = stuff

    lines_dir = '../data/eval_lines'
    files = sorted([os.path.join(lines_dir, p) for p in os.listdir(lines_dir)])
    lines = []
    for file in files:
        with open(file, 'r') as f:
            for line in f:
                line = loader.fully_clean_line(line)
                if line != '':
                    lines.append(line)

    parser = loader.construct_parser('../data/parser.tar.gz')
    lines_dataset = LinesListDataset(corpus,
                                     parser,
                                     lines,
                                     CONTEXT_MIN_SIZE,
                                     tokenizer.eos_token,
                                     split=True,
                                     sc_changer=sc_changer)

    mean = torch.tensor(np.load('../data/gpt2/mean.npy')).to(device)
    e_vecs = torch.tensor(np.load('../data/gpt2/e_vecs.npy')).to(device)

    encoder_config = sc_gpt2_model.GPT2Config(n_layer=2,
                                              n_head=16,
                                              n_embd=1024)
    decoder_config = sc_gpt2_model.GPT2Config(n_layer=4,
                                              n_head=16,
                                              n_embd=1024)
    config = sc_gpt2_model.SCGPT2Config(corpus.bitvec_size(), e_vecs.shape[0],
                                        encoder_config, decoder_config)
    model = sc_gpt2_model.SCGPT2HeadModel(config)
    model.load_state_dict(torch.load(checkpt, map_location='cpu'))
    model = model.to(device)
    model.attach_gpt2_model(gpt2_model, mean, e_vecs)
    model.eval()

    np.random.seed(seed)
    sample_ids = np.random.choice(len(lines_dataset), size=count)
    eos_len = len(tokenizer.eos_token)

    avg_sentence_bleu = 0.0
    avg_sc_bleu = 0.0
    avg_loss = 0.0
    num_losses = 0

    eval_logfile = open('../data/gpt2/eval_logs.txt', 'w')

    def log_print(*args, **kwargs):
        print(*args, **kwargs)
        print(*args, **kwargs, file=eval_logfile, flush=True)

    torch.random.manual_seed(seed)

    for i in sample_ids:
        raw_ldata = lines_dataset[i]
        last_line_ind = raw_ldata[0][:-eos_len].rindex(tokenizer.eos_token)
        past = raw_ldata[0][:last_line_ind] + tokenizer.eos_token
        context = raw_ldata[0][:last_line_ind] + (tokenizer.eos_token * 2)
        target = raw_ldata[0][last_line_ind:]
        parse = loader.get_parse(parser, [target[eos_len:-eos_len]])[0]
        ref_constr = loader.clean_up_bits(loader.process(parse))
        ref_constr_tagged = loader.replace_with_tags(corpus,
                                                     *ref_constr,
                                                     train=False)
        ldata = [(context, raw_ldata[1])]
        all_sentences = []

        if i == sample_ids[0]:
            print(ldata[0][1][-1])
            print(ref_constr_tagged)

        with torch.no_grad():
            for j in range(5):
                split_pieces = prepare_sc_batch(corpus,
                                                tokenizer,
                                                ldata,
                                                None,
                                                device,
                                                peturb=False,
                                                single=True,
                                                remove_last_eos=True)
                del split_pieces[0]['labels']
                model.prepare_forward(**split_pieces[0], num_copies=5)
                output = model.generate(max_length=100,
                                        num_return_sequences=5,
                                        do_sample=True,
                                        pad_token_id=tokenizer.eos_token_id,
                                        temperature=0.5)
                for sentence in output:
                    eos_places = (sentence == tokenizer.eos_token_id).nonzero()
                    if eos_places.shape[0] > 1:
                        end = eos_places[1, 0]
                    else:
                        end = sentence.shape[0]
                    sentence = tokenizer.decode(sentence[1:end])
                    all_sentences.append(sentence)

        best_similarity = None
        best_sentence = None
        for sentence in all_sentences:
            parse = loader.get_parse(parser, [sentence])[0]
            hypo_constr = loader.clean_up_bits(loader.process(parse))
            hypo_constr_tagged = loader.replace_with_tags(corpus,
                                                          *hypo_constr,
                                                          train=False)
            similarity = loader.sc_similarity(ref_constr_tagged,
                                              hypo_constr_tagged)
            if best_similarity is None or similarity > best_similarity:
                best_similarity = similarity
                best_sentence = sentence

        sentence_bleu = loader.bleu(target[eos_len:-eos_len], best_sentence)
        avg_sentence_bleu += sentence_bleu / count
        sc_bleu = best_similarity
        avg_sc_bleu += sc_bleu / count

        # Score the gold target line under the model: first run the context
        # to build the GPT-2 past, then compute the loss on the target only.
        ldata = [(past, raw_ldata[1][:-1])]
        split_pieces = prepare_sc_batch(corpus,
                                        tokenizer,
                                        ldata,
                                        None,
                                        device,
                                        peturb=False,
                                        single=True)
        model_outputs = model.forward_with_gpt2(**split_pieces[0])
        gpt2_past = model_outputs[2]
        past_eam = split_pieces[0]['encoder_attention_mask']
        past_dam = split_pieces[0]['decoder_attention_mask']

        ldata = [(target, raw_ldata[1][-1:])]
        split_pieces = prepare_sc_batch(corpus,
                                        tokenizer,
                                        ldata,
                                        None,
                                        device,
                                        peturb=False,
                                        single=True)
        split_pieces[0]['encoder_attention_mask'] = torch.cat(
            [past_eam, split_pieces[0]['encoder_attention_mask']], dim=1)
        split_pieces[0]['decoder_attention_mask'] = torch.cat(
            [past_dam, split_pieces[0]['decoder_attention_mask']], dim=1)
        loss = model.forward_with_gpt2(**split_pieces[0], past=gpt2_past)[0]
        loss = loss.item()
        loss_count = split_pieces[0]['decoder_attention_mask'].sum().item()
        avg_loss += loss * loss_count
        num_losses += loss_count

        orig_sentences = raw_ldata[0].replace(tokenizer.eos_token, '|')
        log_print(
            "\"%s\" ==> \"%s\", (%0.4f, %0.4f, %0.4f)" %
            (orig_sentences, best_sentence, sentence_bleu, sc_bleu, loss))

    avg_loss /= num_losses
    log_print("Averages: (%0.4f, %0.4f, %0.4f)" %
              (avg_sentence_bleu, avg_sc_bleu, avg_loss))
    eval_logfile.close()
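
The loop above implements best-of-N reranking: 25 sampled continuations (5 batches of 5 return sequences) are scored by syntactic-construction similarity against the reference, and only the best one goes into the BLEU averages. A minimal, self-contained sketch of that pattern, with `score_fn` standing in for loader.sc_similarity:

def pick_best(candidates, reference, score_fn):
    # Return the candidate with the highest similarity to the reference.
    best, best_score = None, float('-inf')
    for cand in candidates:
        score = score_fn(reference, cand)
        if score > best_score:
            best, best_score = cand, score
    return best, best_score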
Example 3
def train_model(device='cpu', stuff=None):
    if stuff is None:
        corpus = loader.Corpus(DATA_DIR, CACHE_DIR, less_mem=True)
        np.stuff = corpus
    else:
        corpus = stuff

    sqrt_e_vals = torch.sqrt(
        torch.tensor(np.load('../data/gpt2/e_vals.npy')).to(device))
    lines_perm = np.load('../data/gpt2/memo/lines_perm.npy')
    pos_arr = np.load('../data/gpt2/memo/pos.npy')

    tokenizer, gpt2_model = load_model(device=device)
    lines_dataset = LinesDataset(corpus,
                                 CONTEXT_MIN_SIZE,
                                 tokenizer.eos_token,
                                 split=True)

    encoder_config = sc_gpt2_model.GPT2Config(n_layer=2,
                                              n_head=16,
                                              n_embd=1024)
    decoder_config = sc_gpt2_model.GPT2Config(n_layer=4,
                                              n_head=16,
                                              n_embd=1024)
    config = sc_gpt2_model.SCGPT2Config(corpus.bitvec_size(),
                                        sqrt_e_vals.shape[0], encoder_config,
                                        decoder_config)
    model = sc_gpt2_model.SCGPT2HeadModel(config).to(device)
    model.embedding.sabotage_gpt2(gpt2_model)
    # model.load_state_dict(torch.load(
    #     '../data/gpt2/checkpt/nouns_model_0_50000.pth'))  # !!!
    model.train()
    del gpt2_model

    novel_named_parameters = model.novel_named_parameters()
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in novel_named_parameters
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": 1e-2,
        },
        {
            "params": [
                p for n, p in novel_named_parameters
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]
    num_epochs = 4
    starting_epoch = 0
    starting_lr = 2e-4  # !!!
    optim = torch.optim.AdamW(optimizer_grouped_parameters, lr=starting_lr)

    # log = open('../data/gpt2/checkpt/log.txt', 'w')

    def log_print(*args, **kwargs):
        print(*args, **kwargs)
        # print(*args, **kwargs, file=log)

    # ipdb.set_trace()

    # with torch.autograd.profiler.profile(use_cuda=True) as prof:
    # The range runs one past num_epochs so the final epoch number is logged
    # before the loop stops.
    for epoch in range(starting_epoch, num_epochs + 1):
        log_print('Now on epoch %d' % epoch)
        if epoch == num_epochs:
            break

        prev_arena = -1
        arena_data = None
        mean_loss = None
        kill_run = False
        for batch_num in range(len(pos_arr)):
            if os.path.exists('../data/gpt2/checkpt/decay_%d_%d.txt' %
                              (epoch, batch_num)):
                lr = optim.param_groups[0]['lr'] / 2
                for group in optim.param_groups:
                    group['lr'] = lr
                log_print("\nDecayed learning rate to %0.3e" % lr)

            if os.path.exists('../data/gpt2/checkpt/save_%d_%d.txt' %
                              (epoch, batch_num)):
                torch.save(
                    model.state_dict(),
                    '../data/gpt2/checkpt/model_%d_%d.pth' %
                    (epoch, batch_num))
                torch.save(
                    optim.state_dict(),
                    '../data/gpt2/checkpt/optim_%d_%d.pth' %
                    (epoch, batch_num))

            if os.path.exists('../data/gpt2/checkpt/kill_%d_%d.txt' %
                              (epoch, batch_num)):
                kill_run = True
                break

            arena, pointer = pos_arr[batch_num]
            if arena != prev_arena:
                arena_data = np.load('../data/gpt2/memo/%d.npy' % arena,
                                     mmap_mode='r')
                prev_arena = arena
            if batch_num == (len(pos_arr) - 1):
                # Slicing to -1 drops the trailing -128 sentinel row that
                # memoize_gpt2_data appends after each batch.
                transf_batch = arena_data[pointer:-1]
            else:
                next_arena, next_pointer = pos_arr[batch_num + 1]
                if next_arena != arena:
                    transf_batch = arena_data[pointer:-1]
                else:
                    transf_batch = arena_data[pointer:next_pointer - 1]
            transf_batch = torch.tensor(decompress(transf_batch)).to(device)
            transf_batch = transf_batch * sqrt_e_vals

            # if batch_num == 50:
            #     kill_run = True
            #     break

            entries = lines_perm[batch_num * BATCH_SIZE:(batch_num + 1) *
                                 BATCH_SIZE]
            ldata = [lines_dataset[i] for i in entries]
            if batch_num == 0:
                print(ldata[0])

            split_pieces = prepare_sc_batch(corpus, tokenizer, ldata,
                                            transf_batch, device)
            counts = [p['decoder_attention_mask'].sum() for p in split_pieces]
            total_count = sum(counts)

            if mean_loss is not None:
                log_print('%0.3f ' % mean_loss.item(), end='', flush=True)

            optim.zero_grad()
            mean_loss = torch.zeros(1, device=device)
            # Accumulate gradients over the sub-batches, weighting each
            # sub-batch loss by its share of the total token count.
            for i in range(len(split_pieces)):
                outputs = model.forward_full(**split_pieces[i])
                loss = outputs[0]
                scaled_loss = loss * counts[i] / total_count
                scaled_loss.backward()
                mean_loss += scaled_loss.data
            optim.step()
        log_print('')

        if kill_run:
            break
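
train_model polls for sentinel files at each batch, so a long run can be steered from another process without restarting it. For example (the epoch and batch numbers here are illustrative):

# Halve the learning rate once batch 5000 of epoch 0 is reached:
open('../data/gpt2/checkpt/decay_0_5000.txt', 'w').close()
# Checkpoint the model and optimizer state at that point:
open('../data/gpt2/checkpt/save_0_5000.txt', 'w').close()
# Stop the run cleanly:
open('../data/gpt2/checkpt/kill_0_5000.txt', 'w').close()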
Example 4
def memoize_gpt2_data(device='cpu', stuff=None):
    if stuff is None:
        tokenizer, model = load_model(device=device)
        corpus = loader.Corpus(DATA_DIR, CACHE_DIR, less_mem=True)
        np.stuff = (tokenizer, model, corpus)
    else:
        tokenizer, model, corpus = stuff

    lines_dataset = LinesDataset(corpus, CONTEXT_MIN_SIZE, tokenizer.eos_token)
    total_size = len(lines_dataset)
    model.transformer.output_hidden_states = True
    model.eval()

    mean = torch.tensor(np.load('../data/gpt2/mean.npy')).to(device)
    e_vecs = torch.tensor(np.load('../data/gpt2/e_vecs.npy')).to(device)
    sqrt_e_vals = torch.sqrt(
        torch.tensor(np.load('../data/gpt2/e_vals.npy')).to(device))

    lines_perm = np.random.permutation(total_size)
    start = 0
    arena_num = 0
    pos_arr = []
    lines_perm_dir = '../data/gpt2/memo/lines_perm.npy'
    start_dir = '../data/gpt2/memo/start.npy'
    arena_num_dir = '../data/gpt2/memo/arena_num.npy'
    pos_arr_dir = '../data/gpt2/memo/pos.npy'
    if os.path.exists(lines_perm_dir):
        lines_perm = np.load(lines_perm_dir)
        start = np.load(start_dir)
        arena_num = np.load(arena_num_dir)
        pos_arr = np.load(pos_arr_dir).tolist()
    else:
        # np.save(lines_perm_dir, lines_perm)
        # np.save(start_dir, start)
        # np.save(arena_num_dir, arena_num)
        # np.save(pos_arr_dir, pos_arr)
        pass

    arena = []
    pointer = 0
    num_added = 0
    print("Starting from instance %d for arena %d" % (start, arena_num))

    for batch_start in range(start, total_size, BATCH_SIZE):
        elems = lines_perm[batch_start:batch_start + BATCH_SIZE]
        ldata = [lines_dataset[j] for j in elems]
        ltids_list, lmasks_list = prepare_batch(tokenizer, ldata, device)
        new_pointer = pointer

        for id_bit, m_bit in zip(ltids_list, lmasks_list):
            with torch.no_grad():
                outputs = model(id_bit, attention_mask=m_bit)
                embed_dim = outputs[2][0].shape[-1]
                pieces = [
                    (outputs[2][i] * m_bit[:, :, None]).view(-1, embed_dim)
                    for i in range(1, len(outputs[2]))
                ]
                vec = torch.cat(pieces, dim=1)
                # Same signed-log transform as in learn_pca.
                vec = torch.sign(vec) * torch.log(1 + torch.abs(vec))
                del outputs, pieces
                # Keep only rows belonging to real (unmasked) tokens.
                vec = vec[m_bit.view(-1) > 0.5]

                # Center, rotate into the PCA basis, and scale to unit
                # variance, then quantize for on-disk storage.
                transf = (vec - mean) @ e_vecs.T / sqrt_e_vals
                packed = transf.cpu().data.numpy()
                packed = compress(packed)

                arena.append(packed)
                new_pointer += packed.shape[0]

        pos_arr.append((arena_num, pointer))
        pointer = new_pointer
        # Append a row of -128s as a sentinel marking the batch boundary.
        arena.append(np.ones((1, e_vecs.shape[0]), dtype=np.int8) * (-128))
        pointer += 1
        num_added += 1
        print('%d' % ((batch_start % 10000) // 1000), end='', flush=True)

        if num_added >= 3000 or batch_start + BATCH_SIZE >= total_size:
            joined = np.concatenate(arena, axis=0)
            # np.save('../data/gpt2/memo/%d.npy' % arena_num, joined)
            # np.save(pos_arr_dir, np.array(pos_arr))
            # np.save(arena_num_dir, arena_num + 1)
            # np.save(start_dir, batch_start + BATCH_SIZE)
            print("\nSaved up to instance %d for arena %d" %
                  (batch_start + BATCH_SIZE, arena_num))
            arena = []
            arena_num += 1
            pointer = 0
            num_added = 0
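
The compress/decompress helpers used in Examples 3 and 4 are not part of this excerpt. A minimal int8 quantization pair consistent with the -128 sentinel rows above might look like the following; the scale factor is an assumption, not the repository's actual implementation:

import numpy as np

SCALE = 16.0  # hypothetical quantization step; the real value is not shown here

def compress(x):
    # float (N, D) -> int8, keeping -128 free for use as a sentinel.
    return np.clip(np.round(x * SCALE), -127, 127).astype(np.int8)

def decompress(q):
    # int8 -> float32 inverse of compress.
    return q.astype(np.float32) / SCALE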
Example 5
        # Note: this excerpt starts mid-function; the warning below only makes
        # sense if an availability check (e.g. torch.cuda.is_available())
        # guards this block further up.
        if not cuda:
            print(
                "WARNING: You have a CUDA device, so you should probably run with --cuda"
            )
        else:
            torch.cuda.manual_seed(seed)

    ###############################################################################
    # Load data
    ###############################################################################

    text_field = data.Field(lower=True)
    label_field = data.Field(lower=True)
    corpus = loader.Corpus(data_set,
                           text_field,
                           label_field,
                           batch_size,
                           max_size=vocab_size)
    #text_field.vocab.load_vectors(embed, wv_type='glove.6B2', wv_dim=200)
    #embed()
    #help(text_field.vocab.load_vectors)
    #text_field.vocab.load_vectors(embed_path, wv_dim=200)
    #text_field.vocab.load_vectors(embed_path)
    text_field.vocab.load_vectors(embed)
    #text_field.vocab.load_vectors('glove.6B.200d')
    '''
    def batchify(data, bsz):
        nbatch = data.size(0) // bsz
        data = data.narrow(0, 0, nbatch * bsz)
        data = data.view(bsz, -1).t().contiguous()
        if cuda: