def main(config_file, update, _run, _log, _config):
    working_dir = _run.observers[0].dir

    # load the config file
    config = yaml_load(config_file)
    recursive_update(config, update)
    yaml_dump(config, path.join(working_dir, 'config.yaml'))
    _config = config
    print(_config)
    print(working_dir)

    dataset = _config['dataset']
    dataset['device'] = update['device']

    # load the dataset and the vocab
    train_loader, dev_loader, test_loader, vocab = load_data(**dataset)
    vocab_size = len(vocab.itos)

    # model: the first feature dimension must match the vocabulary size
    _config['hidden']['features'][0] = vocab_size

    # trainer batch: evaluate with a single sample during training and
    # restore the configured sample count for the final test pass
    test_sample = _config['trainer_batch']['test_sample']
    _config['trainer_batch']['test_sample'] = 1

    config = extend_config_reference(_config)
    trainer = config['trainer']
    trainer['evaluate_interval'] = len(train_loader) * trainer['evaluate_interval']
    trainer['save_checkpoint_interval'] = trainer['evaluate_interval']
    trainer['base_dir'] = working_dir
    yaml_dump(trainer, path.join(working_dir, 'trainer.yaml'))
    trainer['train_iterator'] = train_loader
    trainer['dev_iterator'] = dev_loader
    trainer['test_iterator'] = None

    callback = EvaluationCallback(working_dir,
                                  vocab,
                                  corpus_dir=path.join(dataset['data_dir'], 'corpus'),
                                  **config['callback'])
    trainer['callbacks'] = callback
    trainer['logger'] = _log
    print(config)

    trainer = Trainer.from_config(trainer)
    _log.info("model architecture")
    print(trainer.trainer_batch.model)

    # train the model
    trainer.train()

    # test and save the results
    trainer.dev_iterator = test_loader
    # test using many samples, unlike on the development dataset
    trainer.trainer_batch.test_sample = test_sample
    trainer.restore_from_basedir(best=True)
    stat = trainer._evaluate_epoch().get_dict()
    # topic coherence of the best checkpoint
    callback.evaluate_topic_coherence()
    stat.update(callback.get_dict())
    yaml_dump(stat, path.join(working_dir, 'result.yaml'))
    _log.info('test result of best evaluation {}'.format(stat))
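# A minimal sketch of the nested-dict merge that `recursive_update` is assumed
# to perform in main() above: values from `update` override `config` at any
# depth, so an override like {'dataset': {'device': 'cuda'}} touches only that
# leaf. The real helper lives elsewhere in this repo; this name and body are
# illustrative only, not the repo's implementation.
def recursive_update_sketch(config: dict, update: dict) -> None:
    for key, value in update.items():
        if isinstance(value, dict) and isinstance(config.get(key), dict):
            recursive_update_sketch(config[key], value)  # descend into nested dicts
        else:
            config[key] = value  # leaves (and new keys) are overwritten in place

# Example: only dataset.device changes; the rest of the config is preserved.
cfg = {'dataset': {'device': 'cpu', 'batch_size': 32}}
recursive_update_sketch(cfg, {'dataset': {'device': 'cuda'}})
assert cfg == {'dataset': {'device': 'cuda', 'batch_size': 32}}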
def run_training_batch(self, batch, batch_idx):
    """
    :param batch: dict with three keys: input_ids, attention_mask, decoder_input_ids.
        Example for `batch`:
        batch: {'input_ids': tensor([[0, 36, 230, ..., 8, 41, 2]]),
                'attention_mask': tensor([[1, 1, 1, ..., 1, 1, 1]]),
                'decoder_input_ids': tensor([[0, 287, 10, 2107, 111, 10468, 226, 47385,
                    11579, 1012, 2156, 5, 5302, 47385, 281, 47385, 10003, 255, 47385,
                    347, 111, 2107, 47385, 574, 47385, 1000, 47385, 398, 47385, 245,
                    16, 10, 205, 1374, 12576, 479, 646, 1000, 1215, 3388, 510, 742,
                    85, 128, 579, 65, 9, 5, 357, 3092, 23, 63, 1836, 11, 5, 3555,
                    111, 672, 2156, 26180, 47385, 642, 111, 3547, 4120, 479, 646,
                    1000, 1215, 3388, 510, 742, 7192, 8806, 10262, 3444, 7951,
                    2170, 1318, 2]])}
    :param batch_idx: index of the batch
    :return:
    """
    # the batch is used immediately below, so guard against an empty batch first
    if batch is None:
        return AttributeDict(signal=0, grad_norm_dic={})

    # load tokenizer
    tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')

    # load config for GSM
    config = yaml_load(f"{self.default_root_dir}/data/config/gsm.yaml")

    # load dictionary
    dictionary = Dictionary.load(datapath('dict-www-cnndm-unigram'))

    # separator tokens to remove before topic modeling
    sep_list = [
        '[SEP_0]', '[SEP_1]', '[SEP_2]', '[SEP_3]', '[SEP_4]',
        '[SEP_5]', '[SEP_6]', '[SEP_7]', '[SEP_8]', '[SEP_9]', '<S_SEP>'
    ]

    # vocab size for topic modeling
    vocab_size = len(dictionary)

    # model: the first feature dimension must match the vocabulary size
    config['hidden']['features'][0] = vocab_size

    # trainer batch
    config['trainer_batch']['test_sample'] = 1
    config = extend_config_reference(config)
    gsm_trainer = config['GSMtrainer']
    gsm_trainer['base_dir'] = f"{self.default_root_dir}/log/bart-large-cnn-finetune"
    gsm_trainer = GSMTrainer.from_config(gsm_trainer)

    # number of topics
    K = config['gsmtopic']['k']

    # yaml_dump(gsm_trainer,
    #           os.path.join(f"{self.default_root_dir}/log/bart-large-cnn-finetune",
    #                        "gsm_trainer.yaml"))

    # -----------------------------------------
    # Topic Modeling - GSM
    # -----------------------------------------
    batch_size = batch['input_ids'].size(0)

    docs = []
    for batch_num in range(batch_size):
        # decode the ids back to a sentence
        batch_sentence = tokenizer.decode(batch['input_ids'][batch_num].tolist(),
                                          skip_special_tokens=True)
        # split into tokens
        batch_sentence_list = batch_sentence.split(" ")
        # remove the [SEP] markers
        batch_sentence_list_nosep = [
            item for item in batch_sentence_list if item not in sep_list
        ]
        text = ' '.join(batch_sentence_list_nosep)
        # merge word pieces and lowercase
        fine_text = text.replace(' ##', '').lower()
        # strip punctuation: batch_sentence is now the cleaned news text for topic modeling
        batch_sentence = re.sub(r'[^\w\s]', '', fine_text)
        # convert to the bag-of-words training format used by the topic model
        gsm_data_bow = dictionary.doc2bow(batch_sentence.split(" "))
        docs.append(gsm_data_bow)

    # gsm_data: data for topic modeling
    gsm_data = DataLoader(DocDataset(docs, len(dictionary), device='cuda'),
                          batch_size=config['dataset']['batch_size'],
                          drop_last=False,
                          num_workers=0)
    gsm_trainer.__dict__['train_iterator'] = gsm_data

    gsm_loss, gsm_p = gsm_trainer.co_train(vocab_size, training=True)
    del gsm_data

    # track grad norms
    grad_norm_dic = {}
    # track all metrics for callbacks
    batch_callback_metrics = []
    # track metrics to log
    batch_log_metrics = []

    # Batch start events
    with self.profiler.profile('on_batch_start'):
        # callbacks
        self.on_batch_start()
        # hooks
        if self.is_function_implemented('on_batch_start'):
            response = self.get_model().on_batch_start(batch)
            if response == -1:
                return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic)

    splits = [batch]
    if self.truncated_bptt_steps is not None:
        model_ref = self.get_model()
        with self.profiler.profile('tbptt_split_batch'):
            splits = model_ref.tbptt_split_batch(batch, self.truncated_bptt_steps)

    self.hiddens = None
    for split_idx, split_batch in enumerate(splits):
        self.split_idx = split_idx

        for opt_idx, optimizer in self._get_optimizers_iterable():
            # make sure only the gradients of the current optimizer's parameters
            # are calculated in the training step to prevent dangling gradients
            # in a multiple-optimizer setup.
            if len(self.optimizers) > 1:
                for param in self.get_model().parameters():
                    param.requires_grad = False
                for group in optimizer.param_groups:
                    for param in group['params']:
                        param.requires_grad = True

            # -------------------
            # calculate loss
            # -------------------
            beta = 0.01
            opt_closure_result = self.optimizer_closure(
                split_batch,
                batch_idx,
                opt_idx,
                optimizer,
                self.hiddens,
                gsm_p,     # topic distribution
                gsm_loss,  # loss of the topic model
                K,         # number of topics
                beta,      # weight of the topic-model loss
            )

            # ------------------------------
            # POST forward bookkeeping
            # ------------------------------
            batch_callback_metrics.append(
                opt_closure_result.training_step_output.callback_metrics)
            batch_log_metrics.append(
                opt_closure_result.training_step_output.log_metrics)
            self.add_progress_bar_metrics(
                opt_closure_result.training_step_output.pbar_on_batch_end)

            # track hiddens
            self.hiddens = opt_closure_result.hiddens

            # check if loss or model weights are nan
            if self.terminate_on_nan:
                self.detect_nan_tensors(opt_closure_result.loss)

            # track total loss for logging (avoid mem leaks)
            self.batch_loss_value.append(opt_closure_result.loss)

            # ------------------------------
            # BACKWARD PASS
            # ------------------------------
            # gradient update with accumulated gradients
            if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
                # backward
                grad_norm_dic = self.run_batch_backward_pass(
                    split_batch, batch_idx, opt_idx, optimizer)

                # calculate running loss for display
                self.running_loss.append(self.batch_loss_value.mean())

                # reset for next set of accumulated grads
                self.batch_loss_value.reset()

    # Batch end events
    with self.profiler.profile('on_batch_end'):
        # callbacks
        self.on_batch_end()
        # model hooks
        if self.is_function_implemented('on_batch_end'):
            self.get_model().on_batch_end()

    # collapse all metrics into one dict
    batch_log_metrics = {k: v for d in batch_log_metrics for k, v in d.items()}

    # track all metrics for callbacks
    self.callback_metrics.update(
        {k: v for d in batch_callback_metrics for k, v in d.items()})

    result = AttributeDict(
        signal=0,
        grad_norm_dic=grad_norm_dic,
        batch_log_metrics=batch_log_metrics,
        training_step_output_for_epoch_end=opt_closure_result.
        training_step_output_for_epoch_end)
    return result
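# Standalone sketch of the text-cleaning pipeline used in run_training_batch:
# decode -> drop [SEP_*] markers -> merge ' ##' word pieces -> lowercase ->
# strip punctuation -> gensim bag-of-words. A toy gensim Dictionary stands in
# for 'dict-www-cnndm-unigram'; the steps themselves mirror the loop above.
import re
from gensim.corpora import Dictionary

toy_dictionary = Dictionary([['the', 'cat', 'sat', 'on', 'the', 'mat']])
decoded = "The cat [SEP_0] sat on the ma ##t ."

tokens = [tok for tok in decoded.split(" ") if tok not in {'[SEP_0]', '<S_SEP>'}]
cleaned = re.sub(r'[^\w\s]', '', ' '.join(tokens).replace(' ##', '').lower())
bow = toy_dictionary.doc2bow(cleaned.split(" "))
# bow is a list of (token_id, count) pairs, e.g. 'the' gets count 2; tokens
# missing from the dictionary (and the empty string) are silently dropped.
print(bow)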
def _evaluate(self,
              model: LightningModule,
              dataloaders: List[DataLoader],
              max_batches: Union[int, List[int]],
              test_mode: bool = False):
    """Run evaluation code.

    Args:
        model: The model to evaluate.
        dataloaders: A list of PyTorch dataloaders.
        max_batches: An integer, or a list with one entry per dataloader giving
            the number of batches to process in the corresponding dataloader.
        test_mode: Whether to run the test hooks instead of the validation hooks.
    """
    # enable eval mode
    model.zero_grad()
    model.eval()

    # copy properties for forward overrides
    self.copy_trainer_model_properties(model)

    # bookkeeping
    outputs = []

    # convert max_batches to list
    if isinstance(max_batches, int):
        max_batches = [max_batches] * len(dataloaders)

    # run validation
    for dataloader_idx, dataloader in enumerate(dataloaders):
        dl_outputs = []

        # on TPU we have to wrap it under the ParallelLoader
        if self.use_tpu:
            device = xm.xla_device(self.tpu_id)
            dataloader = xla_pl.ParallelLoader(dataloader, [device])
            dataloader = dataloader.per_device_loader(device)

        # each dataloader has a max num batches
        dl_max_batches = max_batches[dataloader_idx]

        # load tokenizer
        tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')

        # load dictionary
        dictionary = Dictionary.load(datapath('dict-www-cnndm-unigram'))

        # vocab size for topic modeling
        vocab_size = len(dictionary)

        for batch_idx, batch in enumerate(dataloader):
            if batch is None:
                continue

            # stop short when on fast_dev_run (sets max_batch=1)
            if batch_idx >= dl_max_batches:
                break

            # -----------------
            # Topic Modeling
            # -----------------
            # load config for GSM
            config = yaml_load(f"{self.default_root_dir}/data/config/gsm.yaml")

            # separator tokens to remove before topic modeling
            sep_list = [
                '[SEP_0]', '[SEP_1]', '[SEP_2]', '[SEP_3]', '[SEP_4]',
                '[SEP_5]', '[SEP_6]', '[SEP_7]', '[SEP_8]', '[SEP_9]'
            ]

            # model: the first feature dimension must match the vocabulary size
            config['hidden']['features'][0] = vocab_size

            # trainer batch
            config['trainer_batch']['test_sample'] = 1
            config = extend_config_reference(config)
            gsm_trainer = config['GSMtrainer']
            gsm_trainer['base_dir'] = f"{self.default_root_dir}/log/bart-large-cnn-finetune"
            gsm_trainer = GSMTrainer.from_config(gsm_trainer)

            # -----------------------------------------
            # Topic Modeling - GSM
            # -----------------------------------------
            batch_size = batch['input_ids'].size(0)

            docs = []
            for batch_num in range(batch_size):
                # decode the ids back to a sentence
                batch_sentence = tokenizer.decode(
                    batch['input_ids'][batch_num].tolist(),
                    skip_special_tokens=True)
                # split into tokens
                batch_sentence_list = batch_sentence.split(" ")
                # remove the [SEP] markers
                batch_sentence_list_nosep = [
                    item for item in batch_sentence_list if item not in sep_list
                ]
                text = ' '.join(batch_sentence_list_nosep)
                # merge word pieces and lowercase
                fine_text = text.replace(' ##', '').lower()
                # strip punctuation: batch_sentence is the cleaned news text
                batch_sentence = re.sub(r'[^\w\s]', '', fine_text)
                # convert to the bag-of-words format used by the topic model
                gsm_data_bow = dictionary.doc2bow(batch_sentence.split(" "))
                docs.append(gsm_data_bow)

            # gsm_data: data for topic modeling
            gsm_data = DataLoader(DocDataset(docs, len(dictionary), device='cuda'),
                                  batch_size=config['dataset']['batch_size'],
                                  drop_last=False,
                                  num_workers=0)
            gsm_trainer.__dict__['train_iterator'] = gsm_data

            gsm_loss, gsm_p = gsm_trainer.co_train(vocab_size, training=False)
            del gsm_data

            topic_p = Variable(gsm_p.data, requires_grad=False)
            # tm_loss = Variable(gsm_loss.data, requires_grad=False)

            # disable gradients to save memory
            torch.set_grad_enabled(False)

            # callbacks
            if test_mode:
                self.on_test_batch_start()
            else:
                self.on_validation_batch_start()

            # -----------------
            # RUN EVALUATION STEP
            # -----------------
            beta = 0.01
            if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu:
                with torch.cuda.amp.autocast():
                    output = self.evaluation_forward(model, batch, batch_idx,
                                                     dataloader_idx, topic_p,
                                                     gsm_loss, beta, test_mode)
            else:
                output = self.evaluation_forward(model, batch, batch_idx,
                                                 dataloader_idx, topic_p,
                                                 gsm_loss, beta, test_mode)

            # on dp / ddp2 we might still want to do something with the batch parts
            if test_mode:
                if self.is_overridden('test_step_end'):
                    model_ref = self.get_model()
                    with self.profiler.profile('test_step_end'):
                        output = model_ref.test_step_end(output)
                self.on_test_batch_end()
            else:
                if self.is_overridden('validation_step_end'):
                    model_ref = self.get_model()
                    with self.profiler.profile('validation_step_end'):
                        output = model_ref.validation_step_end(output)
                self.on_validation_batch_end()

            # track outputs for collation
            dl_outputs.append(output)

            # re-enable gradients before the next topic-model pass
            torch.set_grad_enabled(True)

        outputs.append(dl_outputs)

    eval_results = {}

    # with a single dataloader don't pass an array
    if len(dataloaders) == 1:
        outputs = outputs[0]

    # give the model a chance to do something with the outputs (if the method is defined)
    if isinstance(model, (LightningDistributedDataParallel, LightningDataParallel)):
        model = model.module

    if test_mode:
        if self.is_overridden('test_end', model=model):
            eval_results = model.test_end(outputs)
            rank_zero_warn(
                'Method `test_end` was deprecated in v0.7 and will be removed in v1.0.'
                ' Use `test_epoch_end` instead.', DeprecationWarning)
        elif self.is_overridden('test_epoch_end', model=model):
            eval_results = model.test_epoch_end(outputs)
    else:
        if self.is_overridden('validation_end', model=model):
            eval_results = model.validation_end(outputs)
            rank_zero_warn(
                'Method `validation_end` was deprecated in v0.7 and will be removed in v1.0.'
                ' Use `validation_epoch_end` instead.', DeprecationWarning)
        elif self.is_overridden('validation_epoch_end', model=model):
            eval_results = model.validation_epoch_end(outputs)

    # enable train mode again
    model.train()

    return eval_results
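# _evaluate toggles autograd with torch.set_grad_enabled(...) inside the batch
# loop instead of wrapping the whole loop in a single `with torch.no_grad():`,
# presumably so the GSMTrainer.co_train call at the top of each iteration still
# runs with gradients enabled. A minimal, self-contained sketch of the same
# toggle pattern (toy tensors, no Lightning involved):
import torch

x = torch.ones(2, requires_grad=True)
for _ in range(2):
    (x * 3).sum().backward()       # grad-requiring step (stands in for co_train)

    torch.set_grad_enabled(False)  # evaluation forward: no graph is built
    z = (x * 5).sum()
    assert not z.requires_grad
    torch.set_grad_enabled(True)   # restore autograd before the next iteration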
def generate_summaries(
    examples: list,
    out_file: str,
    model_name: str,
    batch_size: int = 8,
    device: str = DEFAULT_DEVICE,
    fp16=True,
    task="summarization",
    decoder_start_token_id=None,
    finetune_flag: int = 0,
    checkpoint_path: str = "",
    **gen_kwargs,
) -> None:
    fout = Path(out_file).open("w", encoding="utf-8")

    # initialize tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    if finetune_flag < 1:
        # evaluate the original checkpoint
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
    else:
        # evaluate our fine-tuned checkpoint
        model = AutoModelForSeq2SeqLM.from_pretrained(
            f"{checkpoint_path}/best_tfmr").to(device)

    if fp16:
        model = model.half()
    if decoder_start_token_id is None:
        decoder_start_token_id = gen_kwargs.pop("decoder_start_token_id", None)

    # update config with summarization-specific params
    use_task_specific_params(model, task)

    for batch in tqdm(list(chunks(examples, batch_size))):
        batch = tokenizer(batch,
                          return_tensors="pt",
                          truncation=True,
                          padding="max_length").to(device)
        input_ids, attention_mask = trim_batch(**batch,
                                               pad_token_id=tokenizer.pad_token_id)

        # -----------------------------------------
        # Topic Modeling - GSM
        # -----------------------------------------
        docs = []

        # load dictionary
        dictionary = Dictionary.load(datapath('dict-www-cnndm-unigram'))

        # separator tokens to remove before topic modeling
        sep_list = [
            '[SEP_0]', '[SEP_1]', '[SEP_2]', '[SEP_3]', '[SEP_4]',
            '[SEP_5]', '[SEP_6]', '[SEP_7]', '[SEP_8]', '[SEP_9]'
        ]

        # vocab size for topic modeling
        vocab_size = len(dictionary)

        # load config for GSM
        config = yaml_load("data/config/gsm.yaml")

        # model: the first feature dimension must match the vocabulary size
        config['hidden']['features'][0] = vocab_size

        # trainer batch
        config['trainer_batch']['test_sample'] = 1
        config = extend_config_reference(config)
        gsm_trainer = config['GSMtrainer']
        gsm_trainer['base_dir'] = "log/bart-large-cnn-finetune"
        gsm_trainer = GSMTrainer.from_config(gsm_trainer)

        total_sample = len(batch['input_ids'])
        for batch_num in range(total_sample):
            # decode the ids back to a sentence
            batch_sentence = tokenizer.decode(batch['input_ids'][batch_num].tolist(),
                                              skip_special_tokens=True)
            # split into tokens
            batch_sentence_list = batch_sentence.split(" ")
            # remove the [SEP] markers
            batch_sentence_list_nosep = [
                item for item in batch_sentence_list if item not in sep_list
            ]
            text = ' '.join(batch_sentence_list_nosep)
            # merge word pieces and lowercase
            fine_text = text.replace(' ##', '').lower()
            # strip punctuation: batch_sentence is the cleaned news text
            batch_sentence = re.sub(r'[^\w\s]', '', fine_text)
            # convert to the bag-of-words format used by the topic model
            gsm_data_bow = dictionary.doc2bow(batch_sentence.split(" "))
            docs.append(gsm_data_bow)

        # gsm_data: data for topic modeling
        gsm_data = DataLoader(DocDataset(docs, len(dictionary), device='cuda'),
                              batch_size=config['dataset']['batch_size'],
                              drop_last=False,
                              num_workers=0)
        gsm_trainer.__dict__['train_iterator'] = gsm_data

        gsm_loss, gsm_p = gsm_trainer.co_train(vocab_size=vocab_size, training=False)
        del gsm_data

        topic_p = gsm_p.cuda()

        summaries = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_start_token_id=decoder_start_token_id,
            topic_p=topic_p,
            **gen_kwargs,
        )
        dec = tokenizer.batch_decode(summaries,
                                     skip_special_tokens=True,
                                     clean_up_tokenization_spaces=False)
        for hypothesis in dec:
            fout.write(hypothesis + "\n")
            fout.flush()
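# `chunks` above is assumed to yield consecutive batch_size-sized slices of the
# example list (the usual utility from the transformers summarization examples).
# A minimal sketch of that assumption, plus a hypothetical call; the paths and
# example strings below are placeholders, not files shipped with this repo.
def chunks_sketch(lst: list, n: int):
    for i in range(0, len(lst), n):
        yield lst[i:i + n]

generate_summaries(
    examples=["First news article ...", "Second news article ..."],
    out_file="hypotheses.txt",
    model_name="facebook/bart-large-cnn",
    batch_size=2,
    finetune_flag=1,
    checkpoint_path="log/bart-large-cnn-finetune",
    num_beams=4,  # forwarded to model.generate via **gen_kwargs
)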