def learn_pca(device='cpu', stuff=None):
    if stuff is None:
        tokenizer, model = load_model(device=device)
        corpus = loader.Corpus(DATA_DIR, CACHE_DIR, less_mem=True)
        # Cache on the numpy module so reruns in the same session can reuse it.
        np.stuff = (tokenizer, model, corpus)
    else:
        tokenizer, model, corpus = stuff
    lines_dataset = LinesDataset(corpus, CONTEXT_MIN_SIZE, tokenizer.eos_token)
    lines_loader = torch.utils.data.DataLoader(lines_dataset, BATCH_SIZE,
                                               shuffle=True)
    model.transformer.output_hidden_states = True

    mean = None
    var = None
    count = 0
    for i in range(PCA_BATCHES):
        # A fresh iterator each time draws an independent shuffled batch.
        ldata = next(iter(lines_loader))
        ltids_list, lmasks_list = prepare_batch(tokenizer, ldata, device)
        for id_bit, m_bit in zip(ltids_list, lmasks_list):
            with torch.no_grad():
                outputs = model(id_bit, attention_mask=m_bit)
                embed_dim = outputs[2][0].shape[-1]
                # Concatenate the masked hidden states of every layer except
                # the embedding layer into one long vector per token position.
                pieces = [
                    (outputs[2][i] * m_bit[:, :, None]).view(-1, embed_dim)
                    for i in range(1, len(outputs[2]))
                ]
                vec = torch.cat(pieces, dim=1)
                # Signed-log squashing tames heavy-tailed activations.
                vec = torch.sign(vec) * torch.log(1 + torch.abs(vec))
                del outputs
                del pieces
                if mean is None:
                    mean = vec.new_zeros(vec.shape[1])
                    var = vec.new_zeros(vec.shape[1])
                mean += vec.sum(dim=0)
                # Accumulate the uncentered second moment: var += vec.T @ vec.
                var.addmm_(vec.T, vec)
                piece_count = m_bit.sum()
                count += piece_count
        print('%d ' % (i * BATCH_SIZE), end='', flush=True)
    mean /= count
    var /= count
    # Subtract the outer product of the mean to turn the second moment into
    # a covariance matrix.
    var.addmm_(-mean[:, None], mean[None, :])
    var_np = var.cpu().data.numpy()
    learner = TruncatedSVD(n_components=1024, n_iter=8)
    learner.fit(var_np)
    np.save('../data/gpt2/mean.npy', mean.cpu().data.numpy())
    np.save('../data/gpt2/e_vecs.npy', learner.components_)
    np.save('../data/gpt2/e_vals.npy', learner.singular_values_)
    # Print the fraction of total variance left unexplained by the
    # 1024 retained components.
    explained_var = learner.singular_values_.sum()
    total_var = torch.diag(var).sum().item()
    print(1 - (explained_var / total_var))
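def project_hidden(vecs, device='cpu'):
    """Minimal sketch of the whitening that learn_pca's artifacts define.

    Illustrative only and not called elsewhere in this file: `vecs` is
    assumed to be an (n, d) batch of signed-log hidden states. It mirrors
    the transform applied in memoize_gpt2_data below.
    """
    mean = torch.tensor(np.load('../data/gpt2/mean.npy')).to(device)
    e_vecs = torch.tensor(np.load('../data/gpt2/e_vecs.npy')).to(device)
    sqrt_e_vals = torch.sqrt(
        torch.tensor(np.load('../data/gpt2/e_vals.npy')).to(device))
    # Center, rotate onto the principal axes, and scale to unit variance.
    return (vecs - mean) @ e_vecs.T / sqrt_e_vals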
def evaluate(checkpt, count, sc_changer=None, stuff=None, device='cpu',
             seed=1337):
    if stuff is None:
        corpus = loader.Corpus(DATA_DIR, CACHE_DIR, less_mem=True)
        tokenizer, gpt2_model = load_model(device=device)
        # Cache on the numpy module so reruns in the same session can reuse it.
        np.stuff = (corpus, tokenizer, gpt2_model)
    else:
        corpus, tokenizer, gpt2_model = stuff
    lines_dir = '../data/eval_lines'
    files = sorted([os.path.join(lines_dir, p) for p in os.listdir(lines_dir)])
    lines = []
    for file in files:
        with open(file, 'r') as f:
            for line in f:
                line = loader.fully_clean_line(line)
                if line != '':
                    lines.append(line)
    parser = loader.construct_parser('../data/parser.tar.gz')
    lines_dataset = LinesListDataset(corpus, parser, lines, CONTEXT_MIN_SIZE,
                                     tokenizer.eos_token, split=True,
                                     sc_changer=sc_changer)
    mean = torch.tensor(np.load('../data/gpt2/mean.npy')).to(device)
    e_vecs = torch.tensor(np.load('../data/gpt2/e_vecs.npy')).to(device)
    encoder_config = sc_gpt2_model.GPT2Config(n_layer=2, n_head=16,
                                              n_embd=1024)
    decoder_config = sc_gpt2_model.GPT2Config(n_layer=4, n_head=16,
                                              n_embd=1024)
    config = sc_gpt2_model.SCGPT2Config(corpus.bitvec_size(), e_vecs.shape[0],
                                        encoder_config, decoder_config)
    model = sc_gpt2_model.SCGPT2HeadModel(config)
    model.load_state_dict(torch.load(checkpt, map_location='cpu'))
    model = model.to(device)
    model.attach_gpt2_model(gpt2_model, mean, e_vecs)
    model.eval()

    np.random.seed(seed)
    sample_ids = np.random.choice(len(lines_dataset), size=count)
    eos_len = len(tokenizer.eos_token)
    avg_sentence_bleu = 0.0
    avg_sc_bleu = 0.0
    avg_loss = 0.0
    num_losses = 0
    eval_logfile = open('../data/gpt2/eval_logs.txt', 'w')

    def log_print(*args, **kwargs):
        print(*args, **kwargs)
        print(*args, **kwargs, file=eval_logfile, flush=True)

    torch.random.manual_seed(seed)
    for i in sample_ids:
        raw_ldata = lines_dataset[i]
        # Split off the final line: everything before it is context.
        last_line_ind = raw_ldata[0][:-eos_len].rindex(tokenizer.eos_token)
        past = raw_ldata[0][:last_line_ind] + tokenizer.eos_token
        context = raw_ldata[0][:last_line_ind] + (tokenizer.eos_token * 2)
        target = raw_ldata[0][last_line_ind:]
        parse = loader.get_parse(parser, [target[eos_len:-eos_len]])[0]
        ref_constr = loader.clean_up_bits(loader.process(parse))
        ref_constr_tagged = loader.replace_with_tags(corpus, *ref_constr,
                                                     train=False)
        ldata = [(context, raw_ldata[1])]
        all_sentences = []
        if i == sample_ids[0]:
            print(ldata[0][1][-1])
            print(ref_constr_tagged)
        with torch.no_grad():
            # Sample 5 rounds of 5 continuations each.
            for j in range(5):
                split_pieces = prepare_sc_batch(corpus, tokenizer, ldata, None,
                                                device, peturb=False,
                                                single=True,
                                                remove_last_eos=True)
                del split_pieces[0]['labels']
                model.prepare_forward(**split_pieces[0], num_copies=5)
                output = model.generate(max_length=100,
                                        num_return_sequences=5,
                                        do_sample=True,
                                        pad_token_id=tokenizer.eos_token_id,
                                        temperature=0.5)
                for sentence in output:
                    eos_places = (sentence == tokenizer.eos_token_id).nonzero()
                    if eos_places.shape[0] > 1:
                        end = eos_places[1, 0]
                    else:
                        end = sentence.shape[0]
                    sentence = tokenizer.decode(sentence[1:end])
                    all_sentences.append(sentence)
        # Keep the sample whose parsed structure best matches the reference.
        best_similarity = None
        best_sentence = None
        for sentence in all_sentences:
            parse = loader.get_parse(parser, [sentence])[0]
            hypo_constr = loader.clean_up_bits(loader.process(parse))
            hypo_constr_tagged = loader.replace_with_tags(corpus,
                                                          *hypo_constr,
                                                          train=False)
            similarity = loader.sc_similarity(ref_constr_tagged,
                                              hypo_constr_tagged)
            if best_similarity is None or similarity > best_similarity:
                best_similarity = similarity
                best_sentence = sentence
        sentence_bleu = loader.bleu(target[eos_len:-eos_len], best_sentence)
        avg_sentence_bleu += sentence_bleu / count
        sc_bleu = best_similarity
        avg_sc_bleu += sc_bleu / count

        # Score the target line under the model, conditioned on the past.
        ldata = [(past, raw_ldata[1][:-1])]
        split_pieces = prepare_sc_batch(corpus, tokenizer, ldata, None, device,
                                        peturb=False, single=True)
        model_outputs = model.forward_with_gpt2(**split_pieces[0])
        gpt2_past = model_outputs[2]
        past_eam = split_pieces[0]['encoder_attention_mask']
        past_dam = split_pieces[0]['decoder_attention_mask']
        ldata = [(target, raw_ldata[1][-1:])]
        split_pieces = prepare_sc_batch(corpus, tokenizer, ldata, None, device,
                                        peturb=False, single=True)
        split_pieces[0]['encoder_attention_mask'] = torch.cat(
            [past_eam, split_pieces[0]['encoder_attention_mask']], dim=1)
        split_pieces[0]['decoder_attention_mask'] = torch.cat(
            [past_dam, split_pieces[0]['decoder_attention_mask']], dim=1)
        loss = model.forward_with_gpt2(**split_pieces[0], past=gpt2_past)[0]
        loss = loss.item()
        loss_count = split_pieces[0]['decoder_attention_mask'].sum().item()
        avg_loss += loss * loss_count
        num_losses += loss_count
        orig_sentences = raw_ldata[0].replace(tokenizer.eos_token, '|')
        log_print("\"%s\" ==> \"%s\", (%0.4f, %0.4f, %0.4f)" %
                  (orig_sentences, best_sentence, sentence_bleu, sc_bleu,
                   loss))
    avg_loss /= num_losses
    log_print("Averages: (%0.4f, %0.4f, %0.4f)" %
              (avg_sentence_bleu, avg_sc_bleu, avg_loss))
    eval_logfile.close()
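# Usage sketch (the checkpoint path is hypothetical; any file written by the
# save hook in train_model below works):
#
#     evaluate('../data/gpt2/checkpt/model_0_50000.pth', count=100,
#              device='cuda')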
def train_model(device='cpu', stuff=None):
    if stuff is None:
        corpus = loader.Corpus(DATA_DIR, CACHE_DIR, less_mem=True)
        # Cache on the numpy module so reruns in the same session can reuse it.
        np.stuff = corpus
    else:
        corpus = stuff
    sqrt_e_vals = torch.sqrt(
        torch.tensor(np.load('../data/gpt2/e_vals.npy')).to(device))
    lines_perm = np.load('../data/gpt2/memo/lines_perm.npy')
    pos_arr = np.load('../data/gpt2/memo/pos.npy')
    tokenizer, gpt2_model = load_model(device=device)
    lines_dataset = LinesDataset(corpus, CONTEXT_MIN_SIZE, tokenizer.eos_token,
                                 split=True)
    encoder_config = sc_gpt2_model.GPT2Config(n_layer=2, n_head=16,
                                              n_embd=1024)
    decoder_config = sc_gpt2_model.GPT2Config(n_layer=4, n_head=16,
                                              n_embd=1024)
    config = sc_gpt2_model.SCGPT2Config(corpus.bitvec_size(),
                                        sqrt_e_vals.shape[0], encoder_config,
                                        decoder_config)
    model = sc_gpt2_model.SCGPT2HeadModel(config).to(device)
    model.embedding.sabotage_gpt2(gpt2_model)
    # model.load_state_dict(torch.load(
    #     '../data/gpt2/checkpt/nouns_model_0_50000.pth'))  # !!!
    model.train()
    del gpt2_model

    # Standard weight-decay split: no decay on biases or LayerNorm weights.
    # Materialize the list so it can be iterated twice below.
    novel_named_parameters = list(model.novel_named_parameters())
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in novel_named_parameters
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": 1e-2,
        },
        {
            "params": [
                p for n, p in novel_named_parameters
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]
    num_epochs = 4
    starting_epoch = 0
    starting_lr = 2e-4  # !!!
    optim = torch.optim.AdamW(optimizer_grouped_parameters, lr=starting_lr)
    # log = open('../data/gpt2/checkpt/log.txt', 'w')

    def log_print(*args, **kwargs):
        print(*args, **kwargs)
        # print(*args, **kwargs, file=log)

    for epoch in range(starting_epoch, num_epochs + 1):
        log_print('Now on epoch %d' % epoch)
        if epoch == num_epochs:
            break
        prev_arena = -1
        arena_data = None
        mean_loss = None
        kill_run = False
        for batch_num in range(len(pos_arr)):
            # Sentinel files let an outside shell steer the run: decay the
            # learning rate, checkpoint, or stop cleanly.
            if os.path.exists('../data/gpt2/checkpt/decay_%d_%d.txt' %
                              (epoch, batch_num)):
                lr = optim.param_groups[0]['lr'] / 2
                for group in optim.param_groups:
                    group['lr'] = lr
                log_print("\nDecayed learning rate to %0.3e" % lr)
            if os.path.exists('../data/gpt2/checkpt/save_%d_%d.txt' %
                              (epoch, batch_num)):
                torch.save(model.state_dict(),
                           '../data/gpt2/checkpt/model_%d_%d.pth' %
                           (epoch, batch_num))
                torch.save(optim.state_dict(),
                           '../data/gpt2/checkpt/optim_%d_%d.pth' %
                           (epoch, batch_num))
            if os.path.exists('../data/gpt2/checkpt/kill_%d_%d.txt' %
                              (epoch, batch_num)):
                kill_run = True
                break
            arena, pointer = pos_arr[batch_num]
            if arena != prev_arena:
                # Memory-map the arena so only the touched rows are read.
                arena_data = np.load('../data/gpt2/memo/%d.npy' % arena,
                                     mmap_mode='r')
                prev_arena = arena
            # Slice this batch's rows, stopping before the separator row.
            if batch_num == (len(pos_arr) - 1):
                transf_batch = arena_data[pointer:-1]
            else:
                next_arena, next_pointer = pos_arr[batch_num + 1]
                if next_arena != arena:
                    transf_batch = arena_data[pointer:-1]
                else:
                    transf_batch = arena_data[pointer:next_pointer - 1]
            transf_batch = torch.tensor(decompress(transf_batch)).to(device)
            # Undo the 1/sqrt(e_vals) scaling applied during memoization.
            transf_batch = transf_batch * sqrt_e_vals
            entries = lines_perm[batch_num * BATCH_SIZE:(batch_num + 1) *
                                 BATCH_SIZE]
            ldata = [lines_dataset[i] for i in entries]
            if batch_num == 0:
                print(ldata[0])
            split_pieces = prepare_sc_batch(corpus, tokenizer, ldata,
                                            transf_batch, device)
            counts = [p['decoder_attention_mask'].sum() for p in split_pieces]
            total_count = sum(counts)
            if mean_loss is not None:
                log_print('%0.3f ' % mean_loss.item(), end='', flush=True)
            optim.zero_grad()
            mean_loss = torch.zeros(1, device=device)
            # Accumulate gradients over the split pieces, each weighted by
            # its share of the decoder tokens in the batch.
            for i in range(len(split_pieces)):
                outputs = model.forward_full(**split_pieces[i])
                loss = outputs[0]
                scaled_loss = loss * counts[i] / total_count
                scaled_loss.backward()
                mean_loss += scaled_loss.data
            optim.step()
        log_print('')
        if kill_run:
            break
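# The sentinel files checked above allow steering a running job from a shell,
# e.g. (the epoch and batch numbers are whatever the loss printout is up to):
#
#     touch ../data/gpt2/checkpt/decay_0_12000.txt   # halve the lr
#     touch ../data/gpt2/checkpt/save_0_12000.txt    # checkpoint model+optim
#     touch ../data/gpt2/checkpt/kill_0_12000.txt    # stop cleanly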
def memoize_gpt2_data(device='cpu', stuff=None):
    if stuff is None:
        tokenizer, model = load_model(device=device)
        corpus = loader.Corpus(DATA_DIR, CACHE_DIR, less_mem=True)
        # Cache on the numpy module so reruns in the same session can reuse it.
        np.stuff = (tokenizer, model, corpus)
    else:
        tokenizer, model, corpus = stuff
    lines_dataset = LinesDataset(corpus, CONTEXT_MIN_SIZE, tokenizer.eos_token)
    total_size = len(lines_dataset)
    model.transformer.output_hidden_states = True
    model.eval()
    mean = torch.tensor(np.load('../data/gpt2/mean.npy')).to(device)
    e_vecs = torch.tensor(np.load('../data/gpt2/e_vecs.npy')).to(device)
    sqrt_e_vals = torch.sqrt(
        torch.tensor(np.load('../data/gpt2/e_vals.npy')).to(device))
    lines_perm = np.random.permutation(total_size)
    start = 0
    arena_num = 0
    pos_arr = []
    lines_perm_dir = '../data/gpt2/memo/lines_perm.npy'
    start_dir = '../data/gpt2/memo/start.npy'
    arena_num_dir = '../data/gpt2/memo/arena_num.npy'
    pos_arr_dir = '../data/gpt2/memo/pos.npy'
    # Resume from the saved position if a previous run left state behind.
    if os.path.exists(lines_perm_dir):
        lines_perm = np.load(lines_perm_dir)
        start = np.load(start_dir)
        arena_num = np.load(arena_num_dir)
        pos_arr = np.load(pos_arr_dir).tolist()
    else:
        # np.save(lines_perm_dir, lines_perm)
        # np.save(start_dir, start)
        # np.save(arena_num_dir, arena_num)
        # np.save(pos_arr_dir, pos_arr)
        pass
    arena = []
    pointer = 0
    num_added = 0
    print("Starting from instance %d for arena %d" % (start, arena_num))
    for batch_start in range(start, total_size, BATCH_SIZE):
        elems = lines_perm[batch_start:batch_start + BATCH_SIZE]
        ldata = [lines_dataset[j] for j in elems]
        ltids_list, lmasks_list = prepare_batch(tokenizer, ldata, device)
        new_pointer = pointer
        for id_bit, m_bit in zip(ltids_list, lmasks_list):
            with torch.no_grad():
                outputs = model(id_bit, attention_mask=m_bit)
                embed_dim = outputs[2][0].shape[-1]
                pieces = [
                    (outputs[2][i] * m_bit[:, :, None]).view(-1, embed_dim)
                    for i in range(1, len(outputs[2]))
                ]
                vec = torch.cat(pieces, dim=1)
                vec = torch.sign(vec) * torch.log(1 + torch.abs(vec))
                del outputs
                del pieces
                # Keep only real (unmasked) token positions, then whiten.
                vec = vec[m_bit.view(-1) > 0.5]
                transf = (vec - mean) @ e_vecs.T / sqrt_e_vals
                packed = transf.cpu().data.numpy()
                packed = compress(packed)
                arena.append(packed)
                new_pointer += packed.shape[0]
        pos_arr.append((arena_num, pointer))
        pointer = new_pointer
        # A row of -128s marks the boundary between batches in the arena.
        arena.append(np.ones((1, e_vecs.shape[0]), dtype=np.int8) * (-128))
        pointer += 1
        num_added += 1
        print('%d' % ((batch_start % 10000) // 1000), end='', flush=True)
        if num_added >= 3000 or batch_start + BATCH_SIZE >= total_size:
            joined = np.concatenate(arena, axis=0)
            # np.save('../data/gpt2/memo/%d.npy' % arena_num, joined)
            # np.save(pos_arr_dir, np.array(pos_arr))
            # np.save(arena_num_dir, arena_num + 1)
            # np.save(start_dir, batch_start + BATCH_SIZE)
            print("\nSaved up to instance %d for arena %d" %
                  (batch_start + BATCH_SIZE, arena_num))
            arena = []
            arena_num += 1
            pointer = 0
            num_added = 0
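# compress/decompress are defined elsewhere in the project; the int8 arena
# dtype and the reserved -128 separator row only constrain their shape. A
# minimal sketch of one pair consistent with those constraints (the scale
# factor 24.0 is an assumption, not the project's actual choice):

def compress_sketch(x, scale=24.0):
    """Quantize whitened float rows to int8, keeping -128 free as a marker."""
    return np.clip(np.round(x * scale), -127, 127).astype(np.int8)


def decompress_sketch(q, scale=24.0):
    """Invert compress_sketch up to rounding error."""
    return q.astype(np.float32) / scale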
if not cuda:
    print("WARNING: You have a CUDA device, "
          "so you should probably run with --cuda")
else:
    torch.cuda.manual_seed(seed)

###############################################################################
# Load data
###############################################################################

text_field = data.Field(lower=True)
label_field = data.Field(lower=True)
corpus = loader.Corpus(data_set, text_field, label_field, batch_size,
                       max_size=vocab_size)
# Earlier embedding-loading variants, kept for reference:
# text_field.vocab.load_vectors(embed, wv_type='glove.6B2', wv_dim=200)
# text_field.vocab.load_vectors(embed_path, wv_dim=200)
# text_field.vocab.load_vectors(embed_path)
# text_field.vocab.load_vectors('glove.6B.200d')
text_field.vocab.load_vectors(embed)
'''
def batchify(data, bsz):
    nbatch = data.size(0) // bsz
    data = data.narrow(0, 0, nbatch * bsz)
    data = data.view(bsz, -1).t().contiguous()
    if cuda: