def run_test(test_dataloader, jiant_model: JiantModel, task, device, local_rank, verbose=True):
    """Run label-free inference over the test set and accumulate predictions.

    Args:
        test_dataloader: iterable yielding (batch, batch_metadata) pairs.
        jiant_model: model to evaluate; switched to eval mode here.
        task: task object used to select the evaluation scheme.
        device: device that batches are moved onto.
        local_rank: distributed rank; evaluation runs only when it is -1.
        verbose: show a tqdm progress bar when True.

    Returns:
        dict with "preds" (from the evaluation scheme) and the raw
        "accumulator", or None when local_rank != -1 (distributed workers
        skip evaluation entirely).
    """
    # Only the non-distributed / main process (local_rank == -1) evaluates.
    if local_rank != -1:
        return
    jiant_model.eval()
    evaluation_scheme = evaluate.get_evaluation_scheme_for_task(task=task)
    eval_accumulator = evaluation_scheme.get_accumulator()
    for step, (batch, batch_metadata) in enumerate(
        maybe_tqdm(test_dataloader, desc=f"Eval ({task.name}, Test)", verbose=verbose)
    ):
        batch = batch.to(device)
        with torch.no_grad():
            model_output = wrap_jiant_forward(
                jiant_model=jiant_model,
                batch=batch,
                task=task,
                compute_loss=False,
            )
        batch_logits = model_output.logits.detach().cpu().numpy()
        # Test data has no labels, so no loss is computed (batch_loss=0).
        eval_accumulator.update(
            batch_logits=batch_logits,
            batch_loss=0,
            batch=batch,
            batch_metadata=batch_metadata,
        )
    return {
        "preds": evaluation_scheme.get_preds_from_accumulator(
            task=task,
            accumulator=eval_accumulator,
        ),
        "accumulator": eval_accumulator,
    }
def run_val(
    val_dataloader,
    val_labels,
    jiant_model: JiantModel,
    task,
    device,
    local_rank,
    return_preds=False,
    verbose=True,
):
    """Run validation: compute loss and task metrics over the val set.

    Reminder:
        val_dataloader contains mostly PyTorch-relevant info;
        val_labels may contain more detailed information needed for full
        evaluation.

    Args:
        val_dataloader: iterable yielding (batch, batch_metadata) pairs.
        val_labels: labels used by the evaluation scheme's metric computation.
        jiant_model: model to evaluate; switched to eval mode here.
        task: task object used to select the evaluation scheme.
        device: device that batches are moved onto.
        local_rank: distributed rank; evaluation runs only when it is -1.
        return_preds: also include "preds" in the returned dict when True.
        verbose: show a tqdm progress bar when True.

    Returns:
        dict with "accumulator", mean "loss", and "metrics" (plus "preds"
        when requested), or None when local_rank != -1.
    """
    # Only the non-distributed / main process (local_rank == -1) evaluates.
    if local_rank != -1:
        return
    jiant_model.eval()
    total_eval_loss = 0
    nb_eval_steps, nb_eval_examples = 0, 0
    evaluation_scheme = evaluate.get_evaluation_scheme_for_task(task=task)
    eval_accumulator = evaluation_scheme.get_accumulator()
    for step, (batch, batch_metadata) in enumerate(
        maybe_tqdm(val_dataloader, desc=f"Eval ({task.name}, Val)", verbose=verbose)
    ):
        batch = batch.to(device)
        with torch.no_grad():
            model_output = wrap_jiant_forward(
                jiant_model=jiant_model,
                batch=batch,
                task=task,
                compute_loss=True,
            )
        batch_logits = model_output.logits.detach().cpu().numpy()
        # .mean() collapses per-replica losses under DataParallel.
        batch_loss = model_output.loss.mean().item()
        total_eval_loss += batch_loss
        eval_accumulator.update(
            batch_logits=batch_logits,
            batch_loss=batch_loss,
            batch=batch,
            batch_metadata=batch_metadata,
        )
        nb_eval_examples += len(batch)
        nb_eval_steps += 1
    # NOTE(review): raises ZeroDivisionError for an empty dataloader —
    # presumably callers guarantee at least one batch; confirm upstream.
    eval_loss = total_eval_loss / nb_eval_steps
    # DataParallel wraps the model, so the tokenizer lives on .module there.
    tokenizer = (
        jiant_model.tokenizer
        if not torch_utils.is_data_parallel(jiant_model)
        else jiant_model.module.tokenizer
    )
    output = {
        "accumulator": eval_accumulator,
        "loss": eval_loss,
        "metrics": evaluation_scheme.compute_metrics_from_accumulator(
            task=task,
            accumulator=eval_accumulator,
            labels=val_labels,
            tokenizer=tokenizer,
        ),
    }
    if return_preds:
        output["preds"] = evaluation_scheme.get_preds_from_accumulator(
            task=task,
            accumulator=eval_accumulator,
        )
    return output
def run_test(
    test_dataloader,
    jiant_model: JiantModel,
    task,
    device,
    local_rank,
    verbose=True,
    return_preds=True,
    return_logits=True,
    return_encoder_output: bool = False,
):
    """Run label-free inference over the test set, optionally collecting
    predictions, raw logits, and encoder representations.

    Args:
        test_dataloader: iterable yielding (batch, batch_metadata) pairs.
        jiant_model: model to evaluate; switched to eval mode here.
        task: task object used to select the evaluation scheme.
        device: device that batches are moved onto.
        local_rank: distributed rank; evaluation runs only when it is -1.
        verbose: show a tqdm progress bar when True.
        return_preds: include "preds" in the returned dict when True.
        return_logits: include "logits" when the accumulator supports it.
        return_encoder_output: also collect pooled/unpooled encoder outputs.

    Returns:
        dict with "accumulator" and the requested optional keys, or None
        when local_rank != -1.
    """
    # Only the non-distributed / main process (local_rank == -1) evaluates.
    if local_rank != -1:
        return
    jiant_model.eval()
    evaluation_scheme = evaluate.get_evaluation_scheme_for_task(task=task)
    eval_accumulator = evaluation_scheme.get_accumulator()
    encoder_outputs = []
    for step, (batch, batch_metadata) in enumerate(
        maybe_tqdm(test_dataloader, desc=f"Eval ({task.name}, Test)", verbose=verbose)
    ):
        regular_log(logger, step, interval=10, tag='test')
        batch = batch.to(device)
        with torch.no_grad():
            model_outputs = wrap_jiant_forward(
                jiant_model=jiant_model,
                batch=batch,
                task=task,
                compute_loss=False,
                get_encoder_output=return_encoder_output,
            )
        # With get_encoder_output=True the wrapper returns a 2-tuple.
        if return_encoder_output:
            model_output, encoder_output = model_outputs
            encoder_outputs.append(encoder_output)
        else:
            model_output = model_outputs
        batch_logits = model_output.logits.detach().cpu().numpy()
        # Test data has no labels, so no loss is computed (batch_loss=0).
        eval_accumulator.update(
            batch_logits=batch_logits,
            batch_loss=0,
            batch=batch,
            batch_metadata=batch_metadata,
        )
    output = {
        "accumulator": eval_accumulator,
    }
    if return_preds:
        output["preds"] = evaluation_scheme.get_preds_from_accumulator(
            task=task,
            accumulator=eval_accumulator,
        )
    # Raw logits are only retrievable from a ConcatenateLogitsAccumulator.
    if isinstance(eval_accumulator, evaluate.ConcatenateLogitsAccumulator) and return_logits:
        output["logits"] = eval_accumulator.get_accumulated()
    if return_encoder_output:
        output["encoder_outputs_pooled"] = np.concatenate(
            [encoder_output.pooled for encoder_output in encoder_outputs])
        output["encoder_outputs_unpooled"] = np.concatenate(
            [encoder_output.unpooled for encoder_output in encoder_outputs])
    return output
def run_val(
    val_dataloader,
    val_labels,
    jiant_model: JiantModel,
    task,
    device,
    local_rank,
    tf_writer: SummaryWriter,
    global_step: Optional[int] = None,
    phase=None,
    return_preds=False,
    return_logits=True,
    return_encoder_output: bool = False,
    verbose=True,
    split='valid',
):
    """Run validation with metric computation and TensorBoard logging.

    Reminder:
        val_dataloader contains mostly PyTorch-relevant info;
        val_labels may contain more detailed information needed for full
        evaluation.

    Args:
        val_dataloader: iterable yielding (batch, batch_metadata) pairs.
        val_labels: labels used by the evaluation scheme's metric computation.
        jiant_model: model to evaluate; switched to eval mode here.
        task: task object used to select the evaluation scheme.
        device: device that batches are moved onto.
        local_rank: distributed rank; evaluation runs only when it is -1.
        tf_writer: TensorBoard SummaryWriter used when global_step is given.
        global_step: training step for scalar logging; None disables logging.
        phase: label used only in the progress-bar description.
        return_preds: include "preds" in the returned dict when True.
        return_logits: include "logits" when the accumulator supports it.
        return_encoder_output: also collect pooled/unpooled encoder outputs.
        verbose: show a tqdm progress bar when True.
        split: tag prefix for logging (e.g. 'valid').

    Returns:
        dict with "accumulator", plus "loss"/"metrics" when labels exist,
        plus the requested optional keys, or None when local_rank != -1.
    """
    has_labels = True  # TODO: auto-detect whether the dataset has labels.
    # Only the non-distributed / main process (local_rank == -1) evaluates.
    if local_rank != -1:
        return
    jiant_model.eval()
    evaluation_scheme = evaluate.get_evaluation_scheme_for_task(task=task)
    eval_accumulator = evaluation_scheme.get_accumulator()
    total_eval_loss = 0
    nb_eval_steps, nb_eval_examples = 0, 0
    encoder_outputs = []
    for step, (batch, batch_metadata) in enumerate(
        maybe_tqdm(val_dataloader, desc=f"Eval ({task.name}, {str(phase)})", verbose=verbose)
    ):
        regular_log(logger, step, interval=10, tag=split)
        batch = batch.to(device)
        with torch.no_grad():
            model_outputs = wrap_jiant_forward(
                jiant_model=jiant_model,
                batch=batch,
                task=task,
                compute_loss=has_labels,
                get_encoder_output=return_encoder_output,
            )
        # With get_encoder_output=True the wrapper returns a 2-tuple.
        if return_encoder_output:
            model_output, encoder_output = model_outputs
            encoder_outputs.append(encoder_output)
        else:
            model_output = model_outputs
        batch_logits = model_output.logits.detach().cpu().numpy()
        if has_labels:
            # .mean() collapses per-replica losses under DataParallel.
            batch_loss = model_output.loss.mean().item()
        else:
            batch_loss = 0
        total_eval_loss += batch_loss
        eval_accumulator.update(
            batch_logits=batch_logits,
            batch_loss=batch_loss,
            batch=batch,
            batch_metadata=batch_metadata,
        )
        nb_eval_examples += len(batch)
        nb_eval_steps += 1
    # NOTE(review): raises ZeroDivisionError for an empty dataloader —
    # presumably callers guarantee at least one batch; confirm upstream.
    eval_loss = total_eval_loss / nb_eval_steps
    output = {
        "accumulator": eval_accumulator,
    }
    if has_labels:
        # DataParallel wraps the model, so the tokenizer lives on .module.
        tokenizer = (jiant_model.tokenizer
                     if not torch_utils.is_data_parallel(jiant_model)
                     else jiant_model.module.tokenizer)
        metrics = evaluation_scheme.compute_metrics_from_accumulator(
            task=task,
            accumulator=eval_accumulator,
            labels=val_labels,
            tokenizer=tokenizer,
        )
        output.update({
            "loss": eval_loss,
            "metrics": metrics,
        })
        if global_step is not None:
            for metric_name, metric_value in metrics.minor.items():
                tf_writer.add_scalar(f'{split}/{metric_name}', metric_value,
                                     global_step=global_step)
    if return_preds:
        output["preds"] = evaluation_scheme.get_preds_from_accumulator(
            task=task,
            accumulator=eval_accumulator,
        )
    # Raw logits are only retrievable from a ConcatenateLogitsAccumulator.
    if isinstance(eval_accumulator, evaluate.ConcatenateLogitsAccumulator) and return_logits:
        output["logits"] = eval_accumulator.get_accumulated()
    if return_encoder_output:
        output["encoder_outputs_pooled"] = np.concatenate(
            [encoder_output.pooled for encoder_output in encoder_outputs])
        output["encoder_outputs_unpooled"] = np.concatenate(
            [encoder_output.unpooled for encoder_output in encoder_outputs])
    if global_step is not None:
        tf_writer.add_scalar(f'{split}/loss', eval_loss, global_step=global_step)
        tf_writer.flush()
    return output