def evaluate(self, dev_data, prefix_metric=None, save_dir=None, save_result=False, file_name=None):
    '''
    Evaluate the model on a validation set
    '''
    all_batch_list = []
    eval_dataloader = self.build_eval_dataloader(dev_data)
    self.build_record_tracker()
    self.reset_metrics()
    pbar = ProgressBar(n_total=len(eval_dataloader), desc='Evaluating')
    for step, batch in enumerate(eval_dataloader):
        batch = self.predict_forward(batch)
        if 'loss' in batch and batch['loss'] is not None:
            self.records['loss_meter'].update(batch['loss'], n=1)
        all_batch_list.append(batch)
        pbar.step(step)
    self.records['result']['eval_loss'] = self.records['loss_meter'].avg
    self.update_metrics(all_batch_list, prefix_metric)
    self.print_evaluate_result()
    if save_result:
        if file_name is None:
            file_name = "dev_eval_results.pkl"
        self.save_predict_result(data=all_batch_list, file_name=file_name, save_dir=save_dir)
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
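
# NOTE: records['loss_meter'] is only used through update(val, n=...) and .avg.
# Below is a minimal running-average meter with that interface; it is an
# assumption about the record tracker, not the project's actual class.
class AverageMeter:
    """Hypothetical running-average tracker matching the loss_meter usage above."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        # Accumulate a value observed over n samples/batches.
        self.sum += float(val) * n
        self.count += n

    @property
    def avg(self):
        return self.sum / max(self.count, 1)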
def predict(self, test_data, save_result=True, file_name=None, save_dir=None):
    '''
    Run prediction on the test dataset
    '''
    all_batch_list = []
    test_dataloader = self.build_test_dataloader(test_data)
    pbar = ProgressBar(n_total=len(test_dataloader), desc='Predicting')
    for step, batch in enumerate(test_dataloader):
        batch = self.predict_forward(batch)
        all_batch_list.append(batch)
        pbar.step(step)
    if save_result:
        if file_name is None:
            file_name = f"test_predict_results.{os.path.splitext(os.path.basename(test_data.data_name))[0]}.pkl"
        self.save_predict_result(data=all_batch_list, file_name=file_name, save_dir=save_dir)
    return all_batch_list
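
# NOTE: the ".pkl" file names suggest the collected batch outputs are serialized
# with pickle. The standalone function below is a hedged sketch of what
# save_predict_result might do, not the trainer's actual method; its name and
# save_dir handling are assumptions.
import os
import pickle

def save_predict_result_sketch(data, file_name, save_dir):
    # Persist the accumulated batch outputs to <save_dir>/<file_name> as pickle.
    os.makedirs(save_dir, exist_ok=True)
    file_path = os.path.join(save_dir, file_name)
    with open(file_path, 'wb') as f:
        pickle.dump(data, f)
    return file_path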
def _predict_forward(self, model, data_loader, do_eval=True, **kwargs):
    self.build_record_object()
    pbar = ProgressBar(n_total=len(data_loader), desc='Evaluating' if do_eval else 'Predicting')
    for step, batch in enumerate(data_loader):
        model.eval()
        inputs = self.build_inputs(batch)
        with torch.no_grad():
            outputs = model.module(**inputs) if isinstance(model, nn.DataParallel) else model(**inputs)
        if do_eval:
            loss, logits = outputs[:2]
            self.records['loss_meter'].update(loss.item(), n=1)
            self.records['target'].extend(tensor_to_list(inputs['labels']))
        else:
            if outputs[0].dim() == 1 and outputs[0].size(0) == 1:
                logits = outputs[1]
            else:
                logits = outputs[0]
        if self.args.use_crf:
            crf_model = model.module.crf if isinstance(model, nn.DataParallel) else model.crf
            tags = crf_model.decode(logits, inputs['attention_mask'])
            self.records['preds'].extend(tensor_to_list(tags.squeeze(0)))
        else:
            self.records['preds'].extend(tensor_to_list(torch.argmax(logits, dim=2)))
        self.records['input_lens'].extend(tensor_to_list(torch.sum(inputs['attention_mask'], 1)))
        pbar(step)
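
# NOTE: tensor_to_list is used throughout to move predictions off the GPU for
# bookkeeping. A plausible implementation is sketched below, assuming it simply
# detaches, copies to host memory, and converts to nested Python lists.
def tensor_to_list(tensor):
    # Hypothetical helper: detach from the graph, move to CPU, return Python lists.
    return tensor.detach().cpu().tolist()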
def _predict_forward(self, model, data_loader, do_eval=True, **kwargs):
    self.build_record_object()
    pbar = ProgressBar(n_total=len(data_loader), desc='Evaluating' if do_eval else 'Predicting')
    for step, batch in enumerate(data_loader):
        model.eval()
        inputs = self.build_inputs(batch)
        with torch.no_grad():
            outputs = model(**inputs)
        if do_eval:
            loss, start_logits, end_logits = outputs[:3]
            start_positions = tensor_to_list(inputs['start_positions'])
            end_positions = tensor_to_list(inputs['end_positions'])
            self.records['loss_meter'].update(loss.mean().item(), n=1)
            self.records['target'].extend(zip(start_positions, end_positions))
        else:
            if outputs[0].dim() == 1 and outputs[0].size(0) == 1:
                _, start_logits, end_logits = outputs[:3]
            else:
                start_logits, end_logits = outputs[:2]
        start_logits = tensor_to_list(torch.argmax(start_logits, -1))
        end_logits = tensor_to_list(torch.argmax(end_logits, -1))
        self.records['preds'].extend(zip(start_logits, end_logits))
        self.records['input_lens'].extend(tensor_to_list(torch.sum(inputs['attention_mask'], 1)))
        pbar(step)
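
# NOTE: each prediction above is a (start_position, end_position) index pair per
# example. The helper below is an illustrative sketch of how such a pair might be
# mapped back to an answer span over the tokenized input; the name and arguments
# are hypothetical, not from the project.
def decode_span(tokens, start_position, end_position):
    # Inconsistent pairs (start after end) are treated as "no span".
    if start_position > end_position:
        return []
    return tokens[start_position:end_position + 1]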
def predict_step(self, model, data_loader, do_eval, **kwargs):
    self.build_record_object()
    pbar = ProgressBar(n_total=len(data_loader), desc='Evaluating' if do_eval else 'Predicting')
    for step, batch in enumerate(data_loader):
        model.eval()
        inputs = self.build_inputs(batch)
        with torch.no_grad():
            outputs = model(**inputs)
        if do_eval:
            loss, logits = outputs[:2]
            loss = loss.mean()
            self.records['target'].append(tensor_to_cpu(inputs['labels']))
            self.records['loss_meter'].update(loss.item(), n=1)
        else:
            if outputs[0].dim() == 1 and outputs[0].size(0) == 1:
                logits = outputs[1]
            else:
                logits = outputs[0]
        anchor, positive, negative = logits
        distance_metric = DISTANCE2METRIC[self.args.distance_metric]
        distance_positive = distance_metric(anchor, positive)
        distance_negative = distance_metric(anchor, negative)
        diff_dist = 1 - (distance_positive > distance_negative).int()
        self.records['preds'].append(tensor_to_cpu(diff_dist))
        pbar(step)
    self.records['preds'] = torch.cat(self.records['preds'], dim=0)
    if do_eval:
        self.records['target'] = torch.cat(self.records['target'], dim=0)
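
# NOTE: DISTANCE2METRIC maps the configured args.distance_metric name to a
# pairwise distance function over embedding batches. A minimal sketch of such a
# registry is shown below; the key names ('cosine', 'euclidean', 'manhattan')
# are assumptions, not necessarily the project's actual choices.
import torch
import torch.nn.functional as F

DISTANCE2METRIC = {
    'cosine': lambda a, b: 1.0 - F.cosine_similarity(a, b, dim=-1),
    'euclidean': lambda a, b: F.pairwise_distance(a, b, p=2),
    'manhattan': lambda a, b: F.pairwise_distance(a, b, p=1),
}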
def _predict_forward(self, model, data_loader, do_eval, **kwargs):
    self.build_record_object()
    pbar = ProgressBar(n_total=len(data_loader), desc='Evaluating' if do_eval else 'Predicting')
    for step, batch in enumerate(data_loader):
        model.eval()
        inputs = self.build_inputs(batch)
        with torch.no_grad():
            outputs = model(**inputs)
        if do_eval:
            loss, logits = outputs[:2]
            loss = loss.mean()
            labels = inputs['labels']
            self.records['target'].append(tensor_to_cpu(labels))
            self.records['loss_meter'].update(loss.item(), n=1)
        else:
            if outputs[0].dim() == 1 and outputs[0].size(0) == 1:
                logits = outputs[1]
            else:
                logits = outputs[0]
        self.records['preds'].append(tensor_to_cpu(logits))
        pbar(step)
    self.records['preds'] = torch.cat(self.records['preds'], dim=0)
    if do_eval:
        self.records['target'] = torch.cat(self.records['target'], dim=0)
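
# NOTE: tensor_to_cpu, like tensor_to_list above, is presumably a thin helper for
# moving tensors off the device before accumulation; a minimal sketch under that
# assumption follows.
def tensor_to_cpu(tensor):
    # Hypothetical helper: detach from the graph and copy to host memory so the
    # per-batch records can be concatenated with torch.cat after the loop.
    return tensor.detach().cpu()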
def train(self, train_data, dev_data=None, resume_path=None, start_epoch=1, state_to_save=dict()):
    train_dataloader = self.build_train_dataloader(train_data)
    num_training_steps = len(train_dataloader) // self.gradient_accumulation_steps * self.num_train_epochs
    self.steps_in_epoch = len(train_dataloader)
    if self.scheduler is None:
        self.scheduler = self.build_lr_scheduler(num_training_steps)
    self.resume_from_checkpoint(resume_path=resume_path)
    self.build_model_warp()
    self.print_summary(len(train_data), num_training_steps)
    self.optimizer.zero_grad()
    seed_everything(self.opts.seed, verbose=False)  # Added here for reproducibility (even between python 2 and 3)
    if self.opts.logging_steps < 0:
        self.opts.logging_steps = len(train_dataloader) // self.gradient_accumulation_steps
        self.opts.logging_steps = max(1, self.opts.logging_steps)
    if self.opts.save_steps < 0:
        self.opts.save_steps = len(train_dataloader) // self.gradient_accumulation_steps
        self.opts.save_steps = max(1, self.opts.save_steps)
    self.build_record_tracker()
    self.reset_metrics()
    pbar = ProgressBar(n_total=len(train_dataloader), desc='Training', num_epochs=self.num_train_epochs)
    for epoch in range(start_epoch, int(self.num_train_epochs) + 1):
        pbar.epoch(current_epoch=epoch)
        for step, batch in enumerate(train_dataloader):
            outputs, should_logging, should_save = self.train_step(step, batch)
            if outputs is not None:
                if self.opts.ema_enable:
                    self.model_ema.update(self.model)
                pbar.step(step, {'loss': outputs['loss'].item()})
            if (self.opts.logging_steps > 0 and self.global_step > 0) and \
                    should_logging and self.opts.evaluate_during_training:
                self.evaluate(dev_data)
                if self.opts.ema_enable and self.model_ema is not None:
                    self.evaluate(dev_data, prefix_metric='ema')
                if hasattr(self.writer, 'save'):
                    self.writer.save()
            if (self.opts.save_steps > 0 and self.global_step > 0) and should_save:
                # model checkpoint
                if self.model_checkpoint:
                    state = self.build_state_object(**state_to_save)
                    if self.opts.evaluate_during_training:
                        if self.model_checkpoint.monitor not in self.records['result']:
                            msg = (
                                f"Monitored key '{self.model_checkpoint.monitor}' was not found in the eval result; "
                                f"available keys are: {', '.join(list(self.records['result'].keys()))}."
                            )
                            raise TypeError(msg)
                        self.model_checkpoint.step(
                            state=state,
                            current=self.records['result'][self.model_checkpoint.monitor])
                    else:
                        self.model_checkpoint.step(state=state, current=None)
        # early_stopping
        if self.early_stopping:
            if self.early_stopping.monitor not in self.records['result']:
                msg = (
                    f"Monitored key '{self.early_stopping.monitor}' was not found in the eval result; "
                    f"available keys are: {', '.join(list(self.records['result'].keys()))}."
                )
                raise TypeError(msg)
            self.early_stopping.step(current=self.records['result'][self.early_stopping.monitor])
            if self.early_stopping.stop_training:
                break
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    if self.writer:
        self.writer.close()
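
# NOTE: the trainer above only relies on early_stopping.monitor,
# early_stopping.step(current=...), and the stop_training flag. Below is a
# minimal patience-based sketch with that interface; any attribute beyond those
# three (patience, mode, best, wait) is an assumption.
class EarlyStopping:
    """Hypothetical patience-based early stopping matching the usage in train()."""

    def __init__(self, monitor='eval_loss', patience=3, mode='min'):
        self.monitor = monitor
        self.patience = patience
        self.mode = mode
        self.best = None
        self.wait = 0
        self.stop_training = False

    def step(self, current):
        # An improvement resets the patience counter; otherwise count down to stopping.
        improved = (self.best is None or
                    (current < self.best if self.mode == 'min' else current > self.best))
        if improved:
            self.best = current
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stop_training = True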
def train(self, model, train_dataset, eval_dataset):
    """
    Main training entry point.
    """
    train_dataloader = self.build_train_dataloader(train_dataset)
    t_total = len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs
    optimizer = self.build_optimizer(model)
    scheduler = self.build_scheduler(optimizer, t_total)
    optimizer, scheduler = self.restore_optimizer(optimizer, scheduler)
    model, optimizer = self.build_apex_and_distribute(model, optimizer)
    # Train!
    self.print_training_parameters(model, len(train_dataset), t_total)
    model.zero_grad()
    # ema
    if self.args.do_ema:
        ema = EMA(model, decay=self.args.ema_decay)
    seed_everything(self.args.seed)  # Added here for reproducibility (even between python 2 and 3)
    print('Start training.')
    for epoch in range(0, int(self.args.num_train_epochs)):
        self.build_record_object()
        pbar = ProgressBar(n_total=len(train_dataloader), desc='Training')
        for step, batch in enumerate(train_dataloader):
            loss = self._train_step(model, batch, optimizer)
            if (step + 1) % self.args.gradient_accumulation_steps == 0:
                self._train_update(model, optimizer, loss, scheduler)
                if self.args.do_ema:
                    ema.update(model)
                pbar(step, {'loss': loss})
                if (self.args.local_rank in [-1, 0] and self.args.logging_steps > 0
                        and self.global_step % self.args.logging_steps == 0):
                    if self.args.do_ema:
                        ema.apply_shadow(model)
                    self.tb_writer.add_scalar(
                        'Loss/train_epoch_loss',
                        self.records['loss_meter'].avg,
                        int(self.global_step / self.args.logging_steps))
                    self.evaluate(model, eval_dataset)
                    if self.args.do_ema:
                        ema.restore(model)
                    if hasattr(self.tb_writer, 'save'):
                        self.tb_writer.save()
                if (self.args.local_rank in [-1, 0] and self.args.save_steps > 0
                        and self.global_step % self.args.save_steps == 0):
                    # model checkpoint
                    if self.model_checkpoint:
                        state = self.build_state_object(model, optimizer, scheduler, self.global_step)
                        self.model_checkpoint.step(
                            state=state,
                            current=self.records['result'][self.model_checkpoint.monitor])
        if not self.scheduler_on_batch:
            # epoch scheduler
            scheduler.step()
        # early_stopping
        if self.early_stopping:
            self.early_stopping.step(current=self.records['result'][self.early_stopping.monitor])
            if self.early_stopping.stop_training:
                break
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    if self.tb_writer:
        self.tb_writer.close()
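
# NOTE: the EMA helper above is used only through update(), apply_shadow(), and
# restore(). Below is a minimal exponential-moving-average-of-weights sketch with
# that interface; it is a plausible implementation, not necessarily the project's.
class EMA:
    """Hypothetical weight EMA matching the update/apply_shadow/restore calls in train()."""

    def __init__(self, model, decay=0.999):
        self.decay = decay
        # Shadow copy of trainable parameters, updated as a moving average after each optimizer step.
        self.shadow = {n: p.detach().clone() for n, p in model.named_parameters() if p.requires_grad}
        self.backup = {}

    def update(self, model):
        for n, p in model.named_parameters():
            if p.requires_grad:
                self.shadow[n].mul_(self.decay).add_(p.detach(), alpha=1.0 - self.decay)

    def apply_shadow(self, model):
        # Swap live weights for their EMA versions (e.g. before evaluation).
        for n, p in model.named_parameters():
            if p.requires_grad:
                self.backup[n] = p.detach().clone()
                p.data.copy_(self.shadow[n])

    def restore(self, model):
        # Put the original training weights back after evaluation.
        for n, p in model.named_parameters():
            if p.requires_grad:
                p.data.copy_(self.backup[n])
        self.backup = {}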