def inference(model, dev_data, save_predictions=False, verbose=False):
    predictions = []
    bos_token_id = dev_data.tokenizer.bos_token_id
    for idx, batch in enumerate(dev_data.dataloader.inference_dataloader()):
        if torch.cuda.is_available():
            batch = [b.to(torch.device("cuda")) for b in batch]
        pad_token_id = dev_data.tokenizer.pad_token_id

        batch[0], batch[1] = trim_batch(batch[0].unsqueeze(0), pad_token_id, batch[1].unsqueeze(0))
        batch[2], batch[3] = trim_batch(batch[2], pad_token_id, batch[3])

        with torch.no_grad():
            model.set_relation(batch[0], batch[1])
            outputs = model.model.generate(
                input_ids=batch[2],
                attention_mask=batch[3],
                num_beams=dev_data.args.num_beams,
                max_length=dev_data.args.max_output_length,
                decoder_start_token_id=model.config.bos_token_id,
                early_stopping=dev_data.gen_early_stop,
            )
        for input_, output in zip(batch[2], outputs):
            pred = dev_data.decode(output)
            predictions.append(pred)

    if save_predictions:
        dev_data.save_predictions(predictions)
    return dev_data.evaluate(predictions, verbose=verbose)
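# Most snippets in this file call a `trim_batch` helper to drop columns that contain
# only padding before running the model. A minimal sketch, assuming the signature used
# in the Hugging Face seq2seq examples; the exact upstream implementation may differ,
# and some snippets below (the GPT-2 question-generation code) appear to use a
# repo-specific variant that trims an entire tuple of tensors at once.
import torch


def trim_batch(input_ids, pad_token_id, attention_mask=None):
    """Remove columns that are populated exclusively by pad_token_id."""
    keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
    if attention_mask is None:
        return input_ids[:, keep_column_mask]
    return input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask]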
def __call__(self, batch) -> Dict[str, torch.Tensor]:
    if hasattr(self.tokenizer, "prepare_seq2seq_batch"):
        batch = self._encode(batch)
        input_ids, attention_mask, labels = (
            batch["input_ids"],
            batch["attention_mask"],
            batch["labels"],
        )
    else:
        input_ids = torch.stack([x["input_ids"] for x in batch])
        attention_mask = torch.stack([x["attention_mask"] for x in batch])
        labels = torch.stack([x["labels"] for x in batch])

        labels = trim_batch(labels, self.pad_token_id)
        input_ids, attention_mask = trim_batch(
            input_ids, self.pad_token_id, attention_mask=attention_mask)

    if isinstance(self.tokenizer, T5Tokenizer):
        decoder_input_ids = self._shift_right_t5(labels)
    else:
        decoder_input_ids = shift_tokens_right(labels, self.pad_token_id)

    batch = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "decoder_input_ids": decoder_input_ids,
        "labels": labels,
    }
    return batch
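# The collator above builds decoder inputs for BART-style models with
# `shift_tokens_right`. A minimal sketch, assuming the behaviour of the helper that
# shipped with older transformers BART code (wrap the last non-pad token, usually
# </s>, to position 0 and shift everything else right by one); the version actually
# imported here may differ.
import torch


def shift_tokens_right(input_ids, pad_token_id):
    """Shift input ids one token to the right, wrapping the last non-pad token to position 0."""
    prev_output_tokens = input_ids.clone()
    index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
    prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze()
    prev_output_tokens[:, 1:] = input_ids[:, :-1]
    return prev_output_tokens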
def generate(self, queries, decode_method="beam", num_generate=5):
    with torch.no_grad():
        examples = queries
        decs = []
        for batch in list(chunks(examples, self.batch_size)):
            batch = self.tokenizer(batch, return_tensors="pt", truncation=True,
                                   padding="max_length").to(self.device)
            input_ids, attention_mask = trim_batch(
                **batch, pad_token_id=self.tokenizer.pad_token_id)

            summaries = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_start_token_id=self.decoder_start_token_id,
                num_beams=num_generate,
                num_return_sequences=num_generate,
            )
            dec = self.tokenizer.batch_decode(
                summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
            decs.append(dec)
        return decs
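# This method and the batch-generation functions below slice the example list into
# batch-sized pieces with a `chunks` helper. A minimal sketch of the assumed helper:
def chunks(lst, n):
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i:i + n]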
def generate_summaries_or_translations(
    examples: list,
    out_file: str,
    model_name: str,
    batch_size: int = 8,
    device: str = DEFAULT_DEVICE,
    fp16=False,
    task="summarization",
    **gen_kwargs,
) -> None:
    fout = Path(out_file).open("w", encoding="utf-8")
    model_name = str(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
    if fp16:
        model = model.half()
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # update config with summarization specific params
    use_task_specific_params(model, task)

    for batch in tqdm(list(chunks(examples, batch_size))):
        if "t5" in model_name:
            batch = [model.config.prefix + text for text in batch]
        batch = tokenizer(batch, max_length=1024, return_tensors="pt",
                          truncation=True, padding="max_length").to(device)
        input_ids, attention_mask = trim_batch(**batch, pad_token_id=tokenizer.pad_token_id)
        summaries = model.generate(input_ids=input_ids, attention_mask=attention_mask, **gen_kwargs)
        dec = tokenizer.batch_decode(summaries, skip_special_tokens=True,
                                     clean_up_tokenization_spaces=False)
        for hypothesis in dec:
            fout.write(hypothesis + "\n")
            fout.flush()
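# `use_task_specific_params` copies task-specific generation defaults (e.g. the
# "summarization" entry) from the model config into the top-level config before
# generation. A minimal sketch, assuming the behaviour of the helper in the
# transformers seq2seq examples; the exact implementation may differ.
def use_task_specific_params(model, task):
    """Update model.config with the task-specific parameters stored in the config, if any."""
    task_specific_params = model.config.task_specific_params
    if task_specific_params is not None:
        pars = task_specific_params.get(task, {})
        model.config.update(pars)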
def update(engine, batch):
    # remove extra pad from batches
    batch = trim_batch(batch, pad)
    qgen.train()
    loss = 0.0

    ###################################
    # MLE training with teacher forcing
    ###################################
    if 'sl' in args.learning:
        input_ids, lm_labels, token_type_ids, _, _, _ = tuple(
            input_tensor.to(args.device) for input_tensor in batch)
        loss_ce = qgen(input_ids=input_ids, labels=lm_labels,
                       token_type_ids=token_type_ids)[0]
        loss = apply_loss(engine.state.iteration, qgen_optimizer, loss_ce, args)

    return loss.item()
def train(args, logger, model, train_data, dev_data, optimizer, scheduler):
    model.train()
    global_batch = 0
    global_step = 0
    train_losses = []
    dev_losses = []
    best_accuracy = -1.0
    stop_training = False

    logger.info("Starting training!")
    for epoch in range(int(args.num_train_epochs)):
        for batch in tqdm(train_data.dataloader, desc="Epoch {}".format(epoch)):
            global_batch += 1
            if torch.cuda.is_available():
                batch = [b.to(torch.device("cuda")) for b in batch[0]]
            pad_token_id = train_data.tokenizer.pad_token_id

            # train batch
            batch[0], batch[1] = trim_batch(batch[0], pad_token_id, batch[1])
            batch[2], batch[3] = trim_batch(batch[2], pad_token_id, batch[3])
            # dev batch
            batch[4], batch[5] = trim_batch(batch[4], pad_token_id, batch[5])
            batch[6], batch[7] = trim_batch(batch[6], pad_token_id, batch[7])

            inner_opt = torch.optim.SGD(model.parameters(), lr=args.inner_lr)
            with higher.innerloop_ctx(model, inner_opt, copy_initial_weights=False) as (fnet, diffopt):
                # print("train batch")
                train_loss = fnet(input_ids=batch[0], attention_mask=batch[1],
                                  decoder_input_ids=batch[2], decoder_attention_mask=batch[3],
                                  is_training=True)
                if torch.isnan(train_loss).data:
                    logger.info("Stop training because loss=%s" % (train_loss.data))
                    stop_training = True
                    break  # does this ever happen?
                train_losses.append(train_loss.detach().cpu())
                diffopt.step(train_loss)

                # print("dev batch")
                dev_loss = fnet(input_ids=batch[4], attention_mask=batch[5],
                                decoder_input_ids=batch[6], decoder_attention_mask=batch[7],
                                is_training=True)
                dev_losses.append(dev_loss.detach().cpu())
                dev_loss.backward()

            if global_batch % args.gradient_accumulation_steps == 0:
                global_step += 1
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()  # We have accumulated enough gradients
                scheduler.step()
                model.zero_grad()

                if global_step % args.eval_period == 0:
                    # model.eval()
                    # curr_em = inference(model if args.n_gpu==1 else model.module, dev_data)
                    # logger.info("Step %d Train loss %.2f %s %s on epoch=%d" % (
                    #     global_step,
                    #     np.mean(train_losses),
                    #     dev_data.metric,
                    #     curr_em,
                    #     epoch))
                    logger.info("train loss: {}; dev loss: {}".format(
                        np.mean(train_losses), np.mean(dev_losses)))
                    train_losses = []
                    dev_losses = []
                    # if best_accuracy < curr_em:
                    #     model_state_dict = {k: v.cpu() for (k, v) in model.state_dict().items()}
                    #     torch.save(model_state_dict, os.path.join(args.output_dir, "best-model.pt"))
                    #     logger.info("Saving model with best %s: %s -> %s on epoch=%d, global_step=%d" % \
                    #         (dev_data.metric, best_accuracy, curr_em, epoch, global_step))
                    #     best_accuracy = curr_em
                    #     wait_step = 0
                    #     stop_training = False
                    # else:
                    #     wait_step += 1
                    #     if wait_step >= args.wait_step:
                    #         stop_training = True
                    #         break
                    # model.train()

                if global_step >= args.total_steps:
                    stop_training = True
                    break
        if stop_training:
            break

    model_state_dict = {k: v.cpu() for (k, v) in model.state_dict().items()}
    torch.save(model_state_dict, os.path.join(args.output_dir, "last-model.pt"))
def train(args, logger, model, train_data, dev_data, optimizer, scheduler):
    model.train()
    global_step = 0
    train_losses = []
    best_accuracy = -1.0
    wait_step = 0  # number of evaluations without improvement (early stopping)
    stop_training = False

    logger.info("Starting training!")
    for epoch in range(int(args.num_train_epochs)):
        for batch in tqdm(train_data.dataloader, desc="Epoch {}".format(epoch)):
            global_step += 1
            if torch.cuda.is_available():
                batch = [b.to(torch.device("cuda")) for b in batch]
            pad_token_id = train_data.tokenizer.pad_token_id

            batch[0], batch[1] = trim_batch(batch[0], pad_token_id, batch[1])
            batch[2], batch[3] = trim_batch(batch[2], pad_token_id, batch[3])
            loss = model(input_ids=batch[0], attention_mask=batch[1],
                         decoder_input_ids=batch[2], decoder_attention_mask=batch[3],
                         is_training=True)
            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu.
            if torch.isnan(loss).data:
                logger.info("Stop training because loss=%s" % (loss.data))
                stop_training = True
                break
            train_losses.append(loss.detach().cpu())
            loss.backward()

            if global_step % args.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()  # We have accumulated enough gradients
                scheduler.step()
                model.zero_grad()

            if global_step % args.eval_period == 0:
                model.eval()
                curr_em = inference(model if args.n_gpu == 1 else model.module, dev_data)
                logger.info("Step %d Train loss %.2f %s %s on epoch=%d" % (
                    global_step, np.mean(train_losses), dev_data.metric, curr_em, epoch))
                train_losses = []
                if best_accuracy < curr_em:
                    model_state_dict = {k: v.cpu() for (k, v) in model.state_dict().items()}
                    torch.save(model_state_dict, os.path.join(args.output_dir, "best-model.pt"))
                    logger.info("Saving model with best %s: %s -> %s on epoch=%d, global_step=%d" %
                                (dev_data.metric, best_accuracy, curr_em, epoch, global_step))
                    best_accuracy = curr_em
                    wait_step = 0
                    stop_training = False
                else:
                    wait_step += 1
                    if wait_step >= args.wait_step:
                        stop_training = True
                        break
                model.train()
        if stop_training:
            break

    model_state_dict = {k: v.cpu() for (k, v) in model.state_dict().items()}
    torch.save(model_state_dict, os.path.join(args.output_dir, "last-model.pt"))
def main(args):
    debug = False

    # Load a pre-defined tokenizer (GPT-2), create config and model
    logger.info("Prepare tokenizer, pretrained model and optimizer - "
                "add special tokens for fine-tuning")
    tokenizer = GPT2Tokenizer.from_pretrained(args.model_path, cache_dir=args.dataset_cache)
    tokenizer.add_tokens(SPECIAL_TOKENS)
    tokenizer.sep_token = '<sep>'

    if 'amr' in args.dataset_type:
        qgen = GPT2LMHeadModel.from_pretrained(args.model_path, cache_dir=args.dataset_cache)
    else:
        qgen = GPT2ConditionalLMHeadModel.from_pretrained(args.model_path, cache_dir=args.dataset_cache)
    qgen.resize_token_embeddings(len(tokenizer))
    qgen.to(args.device)
    qgen.eval()

    logsoftmax = nn.LogSoftmax(dim=0)
    bos, eos, ctx, ans, que, pad, gen = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS)

    if args.n_gpu > 1:
        logger.info("Setting model to DataParallel.")
        qgen = torch.nn.DataParallel(qgen)

    logger.info("Prepare datasets")
    if "amr" in args.dataset_type:
        logger.info("Decoding with AMR dev set")
        dataloader = get_data_loaders(args, tokenizer, qgen,
                                      dataset_name=args.output_data, shuffle=False)
    else:
        dataloader = get_data_loaders(args, tokenizer, qgen, shuffle=False)

    if 'amr' in args.dataset_type:
        if args.output_data.lower() == "test":
            ref = os.path.join(args.dataset_path, "test.tok.text")
        else:
            ref = os.path.join(args.dataset_path, "dev.tok.text")
        ref = open(ref).readlines()

    logger.info("Decode: " + args.decoder)

    # Output file name
    f = open(os.path.join(args.checkpoint, 'output.txt'), 'w')
    text_outputs = list()

    # beam search variables
    beam_size = 1 if args.beam_size is None else args.beam_size
    output_size = 1 if args.output_size is None else args.output_size
    beam_candidates = args.beam_candidates

    # General variables
    max_length = args.max_input_length
    instance = 0

    for batch in tqdm(dataloader):
        # remove extra pad from the batch
        batch = trim_batch(batch, pad)
        _, _, _, _, input_ids, _, token_type_ids, attention_mask = \
            tuple(input_tensor.to(args.device) for input_tensor in batch)
        past = None
        o = 0
        all_probs = torch.zeros(beam_size, 1).to(args.device)
        original_input_len = input_ids.shape[1]
        start = True

        # general variables
        questions = []
        for idx in range(max_length):
            ###################
            # Greedy decoding
            ###################
            if args.decoder == "greedy":
                with torch.no_grad():
                    logits, past = qgen(input_ids=input_ids,
                                        token_type_ids=token_type_ids,
                                        past=past)
                outputs = torch.argmax(logits[0, -1, :])
                outputs = outputs.unsqueeze(0).unsqueeze(0)
            ###################
            # Nucleus sampling
            ###################
            elif args.decoder == "sampling":
                with torch.no_grad():
                    logits, past = qgen(input_ids=input_ids,
                                        token_type_ids=token_type_ids,
                                        past=past)
                # bs x seq_len x V
                logits = logits[:, -1, :] / args.temperature
                logits = top_k_top_p_filtering(logits, top_k=args.top_k, top_p=args.top_p)
                # bs x V
                probs = F.softmax(logits, dim=-1)
                # bs x 1
                outputs = torch.multinomial(probs, num_samples=1)
                # once a sequence has emitted <eos>, keep repeating <eos>
                outputs = torch.where(input_ids[:, -1:] == eos, input_ids[:, -1:], outputs)
            ###################
            # Beam search
            ###################
            elif args.decoder == "beam":
                with torch.no_grad():
                    logits = qgen(input_ids)[0]
                out_paths = None
                probs = None
                for k in range(logits.shape[0]):
                    log_p = logsoftmax(logits[k, -1, :])
                    p = log_p + all_probs[k]
                    if start:
                        predicted_top_k = torch.topk(p, beam_size)
                        start = False
                    else:
                        predicted_top_k = torch.topk(p, beam_candidates)
                    p_top_k_tokens = predicted_top_k.indices[:, None]
                    p_top_k_probs = predicted_top_k.values[:, None]
                    # Store paths
                    if out_paths is None:
                        out_paths = torch.cat(
                            (input_ids[k].expand(p_top_k_tokens.shape[0], input_ids.shape[1]),
                             p_top_k_tokens), 1)
                    else:
                        out_paths = torch.cat(
                            (out_paths,
                             torch.cat((input_ids[k].expand(p_top_k_tokens.shape[0], input_ids.shape[1]),
                                        p_top_k_tokens), 1)), 0)
                    if probs is None:
                        probs = p_top_k_probs
                    else:
                        probs = torch.cat((probs, p_top_k_probs), 0)
                # keep the beam_size best paths across all candidate expansions
                global_top_k = torch.topk(probs, k=beam_size, dim=0)
                input_ids = out_paths[global_top_k.indices[:, 0], :]
                all_probs = global_top_k.values
                o += 1
            else:
                raise Exception('Not valid decoder ' + args.decoder)

            #######################
            # Termination condition
            #######################
            if not args.decoder == 'beam':
                # correctly shape inputs for next round
                input_ids = outputs
                token_type_ids = token_type_ids[:, -1:]
                # if all the outputs are special tokens
                questions.append(outputs)
                if (outputs == eos).all():
                    break
            else:
                outputs = input_ids[:, original_input_len:]
                if ((outputs == eos).sum(dim=1) > 0).all():
                    break

        ################
        # Output to file
        ################
        if args.decoder != 'beam':
            # append an extra <eos> in case max length is reached
            questions.append(torch.zeros_like(outputs).fill_(eos))
            questions = torch.cat(questions, dim=1)
        else:
            questions.append(outputs)
            questions = questions[0]

        for i, question in enumerate(questions):
            question = question.tolist()
            if eos in question:
                idx = question.index(eos)
            else:
                idx = -1
            question = tokenizer.decode(question[:idx])
            if '<generate>' in question:
                question = question.split('<generate>')[1]
            # Print outputs to file and save in text_outputs
            print(question.replace('\n', ' '), file=f)
            f.flush()
            text_outputs.append(question.replace('\n', ' ').lower())
            # Limit number of outputs to output_size
            if i >= output_size - 1:
                break

        if 'amr' in args.dataset_type and debug:
            print("GOLD: ", ref[instance])
            print("final: ", text_outputs[instance])
        instance += 1
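# The "sampling" branch above filters next-token logits with `top_k_top_p_filtering`
# before drawing from the distribution. A minimal sketch of the widely used batched
# top-k / nucleus (top-p) filtering routine; the version imported in this repo may differ.
import torch
import torch.nn.functional as F


def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
    """Filter a (batch_size x vocab) logits tensor using top-k and/or nucleus (top-p) filtering."""
    top_k = min(top_k, logits.size(-1))
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right so the first token above the threshold is kept as well
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        # Scatter the removal mask back to the original (unsorted) vocabulary order
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits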
def train(args, logger, model, train_data, dev_data, optimizer, scheduler):
    model.train()
    global_step = 0
    train_losses = []
    best_accuracy = (-1.0, -1.0, -1.0) if args.dataset == "zest_grouped" else -1.0
    wait_step = 0  # number of evaluations without improvement (early stopping)
    stop_training = False

    # curr_em = inference(model if args.n_gpu==1 else model.module, dev_data)
    # logger.info("[Before Training] %s %s" % (dev_data.metric, curr_em))

    logger.info("Starting training!")
    model.model.backup_layer_norm_parameters()
    for epoch in range(int(args.num_train_epochs)):
        for batch in tqdm(train_data.dataloader, desc="Epoch {}".format(epoch)):
            global_step += 1
            if torch.cuda.is_available():
                batch = [b.to(torch.device("cuda")) for b in batch[0]]

            rel_ids, rel_masks = batch[0].unsqueeze(0), batch[1].unsqueeze(0)
            input_ids, input_masks = batch[2], batch[3]
            output_ids, output_masks = batch[4], batch[5]

            pad_token_id = train_data.tokenizer.pad_token_id
            rel_ids, rel_masks = trim_batch(rel_ids, pad_token_id, rel_masks)
            input_ids, input_masks = trim_batch(input_ids, pad_token_id, input_masks)
            output_ids, output_masks = trim_batch(output_ids, pad_token_id, output_masks)

            loss = model.forward(rel_ids=rel_ids, rel_masks=rel_masks,
                                 input_ids=input_ids, input_masks=input_masks,
                                 output_ids=output_ids, output_masks=output_masks,
                                 is_training=True)
            train_losses.append(loss.detach().cpu())
            loss.backward()
            model.model.restore_layer_norm_parameters()

            if global_step % args.gradient_accumulation_steps == 0:
                # for p in model.meta_model.decoders.parameters():
                #     print(p)
                #     print(p.grad)
                #     break
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                # for p in model.meta_model.decoders.parameters():
                #     print()
                #     print(p)
                #     print(p.grad)
                #     break
                optimizer.step()  # We have accumulated enough gradients
                scheduler.step()
                model.zero_grad()
                # print(model.meta_model.decoders[-1].linear1.weight)

            if global_step % args.eval_period == 0:
                model.eval()
                # curr_em = 0.0
                curr_em = inference(model if args.n_gpu == 1 else model.module,
                                    dev_data, save_predictions=True)
                logger.info("Step %d Train loss %.2f %s %s on epoch=%d" % (
                    global_step, np.mean(train_losses), dev_data.metric, curr_em, epoch))
                train_losses = []
                if best_accuracy < curr_em:
                    model_state_dict = {k: v.cpu() for (k, v) in model.state_dict().items()}
                    torch.save(model_state_dict, os.path.join(args.output_dir, "best-model.pt"))
                    logger.info("Saving model with best %s: %s -> %s on epoch=%d, global_step=%d" %
                                (dev_data.metric, best_accuracy, curr_em, epoch, global_step))
                    best_accuracy = curr_em
                    wait_step = 0
                    stop_training = False
                else:
                    wait_step += 1
                    if wait_step >= args.wait_step:
                        stop_training = True
                        break
                model.train()
        if stop_training:
            break

    model_state_dict = {k: v.cpu() for (k, v) in model.state_dict().items()}
    torch.save(model_state_dict, os.path.join(args.output_dir, "last-model.pt"))
def generate_summaries(
    examples: list,
    out_file: str,
    model_name: str,
    batch_size: int = 8,
    device: str = DEFAULT_DEVICE,
    fp16=True,
    task="summarization",
    decoder_start_token_id=None,
    finetune_flag: int = 0,
    checkpoint_path: str = "",
    **gen_kwargs,
) -> None:
    fout = Path(out_file).open("w", encoding="utf-8")

    # initialize tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # if our goal is to evaluate the original checkpoint
    if finetune_flag < 1:
        # initialize the model checkpoints
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
    # if our goal is to evaluate our fine-tuned checkpoint
    else:
        # load the finetuned checkpoints
        model = AutoModelForSeq2SeqLM.from_pretrained(f"{checkpoint_path}/best_tfmr").to(device)

    if fp16:
        model = model.half()
    if decoder_start_token_id is None:
        decoder_start_token_id = gen_kwargs.pop("decoder_start_token_id", None)

    # update config with summarization specific params
    use_task_specific_params(model, task)

    for batch in tqdm(list(chunks(examples, batch_size))):
        batch = tokenizer(batch, return_tensors="pt", truncation=True,
                          padding="max_length").to(device)
        input_ids, attention_mask = trim_batch(**batch, pad_token_id=tokenizer.pad_token_id)

        # -----------------------------------------
        # Topic Modeling - GSM
        # -----------------------------------------
        docs = []
        # load dict
        dictionary = Dictionary.load(datapath('dict-www-cnndm-unigram'))
        # remove [SEP]
        sep_list = ['[SEP_0]', '[SEP_1]', '[SEP_2]', '[SEP_3]', '[SEP_4]',
                    '[SEP_5]', '[SEP_6]', '[SEP_7]', '[SEP_8]', '[SEP_9]']
        # vocab size for topic modeling
        vocab_size = len(dictionary)
        # load config for GSM
        config = yaml_load("data/config/gsm.yaml")
        # model
        config['hidden']['features'][0] = vocab_size
        # trainer batch
        config['trainer_batch']['test_sample'] = 1
        config = extend_config_reference(config)
        gsm_trainer = config['GSMtrainer']
        gsm_trainer['base_dir'] = "log/bart-large-cnn-finetune"
        gsm_trainer = GSMTrainer.from_config(gsm_trainer)

        total_sample = len(batch['input_ids'])
        for batch_num in range(total_sample):
            # extract the batch_sentence
            batch_sentence = tokenizer.decode(batch['input_ids'][batch_num].tolist(),
                                              skip_special_tokens=True)
            # change to lowercase and split to list
            batch_sentence_list = batch_sentence.split(" ")
            # remove [SEP]
            batch_sentence_list_nosep = [item for item in batch_sentence_list if item not in sep_list]
            text = ' '.join([x for x in batch_sentence_list_nosep])
            fine_text = text.replace(' ##', '').lower()
            # batch_sentence: the cleaned news text for topic modeling
            batch_sentence = re.sub(r'[^\w\s]', '', fine_text)
            # change to training data format in topic modeling
            gsm_data_bow = dictionary.doc2bow(batch_sentence.split(" "))
            docs.append(gsm_data_bow)

        # gsm_data: data for topic modeling
        gsm_data = DataLoader(DocDataset(docs, len(dictionary), device='cuda'),
                              batch_size=config['dataset']['batch_size'],
                              drop_last=False, num_workers=0)
        gsm_trainer.__dict__['train_iterator'] = gsm_data
        gsm_loss, gsm_p = gsm_trainer.co_train(vocab_size=vocab_size, training=False)
        del gsm_data

        topic_p = gsm_p.cuda()

        summaries = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_start_token_id=decoder_start_token_id,
            topic_p=topic_p,
            **gen_kwargs,
        )
        dec = tokenizer.batch_decode(summaries, skip_special_tokens=True,
                                     clean_up_tokenization_spaces=False)
        for hypothesis in dec:
            fout.write(hypothesis + "\n")
            fout.flush()