def run_test(test_dataloader, jiant_model: JiantModel, task, device, local_rank, verbose=True):
    """Run inference over the test dataloader and collect predictions.

    Returns None on non-primary distributed ranks (local_rank != -1);
    otherwise a dict with "preds" and the raw evaluation "accumulator".
    """
    if local_rank != -1:
        return
    jiant_model.eval()
    scheme = evaluate.get_evaluation_scheme_for_task(task=task)
    accumulator = scheme.get_accumulator()
    loader = maybe_tqdm(test_dataloader, desc=f"Eval ({task.name}, Test)", verbose=verbose)
    for step, (batch, batch_metadata) in enumerate(loader):
        batch = batch.to(device)
        with torch.no_grad():
            model_output = wrap_jiant_forward(
                jiant_model=jiant_model,
                batch=batch,
                task=task,
                compute_loss=False,
            )
        # No labels at test time, so losses are recorded as 0.
        accumulator.update(
            batch_logits=model_output.logits.detach().cpu().numpy(),
            batch_loss=0,
            batch=batch,
            batch_metadata=batch_metadata,
        )
    return {
        "preds": scheme.get_preds_from_accumulator(task=task, accumulator=accumulator),
        "accumulator": accumulator,
    }
def run_val(
    val_dataloader,
    val_labels,
    jiant_model: JiantModel,
    task,
    device,
    local_rank,
    return_preds=False,
    verbose=True,
):
    """Evaluate on the validation set and compute metrics.

    val_dataloader carries the PyTorch-side batch info, while val_labels may
    carry richer label information needed for full metric computation.
    Returns None on non-primary distributed ranks; otherwise a dict with the
    "accumulator", the mean "loss", "metrics", and optionally "preds".
    """
    if local_rank != -1:
        return
    jiant_model.eval()
    scheme = evaluate.get_evaluation_scheme_for_task(task=task)
    accumulator = scheme.get_accumulator()
    loss_sum = 0
    steps, examples = 0, 0
    loader = maybe_tqdm(val_dataloader, desc=f"Eval ({task.name}, Val)", verbose=verbose)
    for step, (batch, batch_metadata) in enumerate(loader):
        batch = batch.to(device)
        with torch.no_grad():
            model_output = wrap_jiant_forward(
                jiant_model=jiant_model,
                batch=batch,
                task=task,
                compute_loss=True,
            )
        batch_loss = model_output.loss.mean().item()
        loss_sum += batch_loss
        accumulator.update(
            batch_logits=model_output.logits.detach().cpu().numpy(),
            batch_loss=batch_loss,
            batch=batch,
            batch_metadata=batch_metadata,
        )
        examples += len(batch)
        steps += 1
    # Under DataParallel the tokenizer lives on the wrapped .module.
    if torch_utils.is_data_parallel(jiant_model):
        tokenizer = jiant_model.module.tokenizer
    else:
        tokenizer = jiant_model.tokenizer
    output = {
        "accumulator": accumulator,
        "loss": loss_sum / steps,
        "metrics": scheme.compute_metrics_from_accumulator(
            task=task,
            accumulator=accumulator,
            labels=val_labels,
            tokenizer=tokenizer,
        ),
    }
    if return_preds:
        output["preds"] = scheme.get_preds_from_accumulator(task=task, accumulator=accumulator)
    return output
def load_state_dict_for_jiant_model_with_adapters(jiant_model: JiantModel, state_dict: Dict):
    """Load state_dict into a jiant model non-strictly.

    The checkpoint is expected to contain only a subset of the model's weights
    (adapters / taskmodel heads), so missing keys — the untrained shared
    encoder — are expected, while unexpected keys indicate a checkpoint/model
    mismatch. The checks aren't very rigorous.

    Args:
        jiant_model: model to load weights into (mutated in place).
        state_dict: partial state dict, e.g. produced by
            get_optimized_state_dict_for_jiant_model_with_adapters.
    """
    mismatched = jiant_model.load_state_dict(state_dict, strict=False)
    # Messages added so an assertion failure says which invariant broke.
    assert mismatched.missing_keys, (
        "expected missing keys (untrained encoder weights), but the checkpoint covered everything"
    )
    assert not mismatched.unexpected_keys, (
        f"checkpoint contains keys absent from the model: {mismatched.unexpected_keys}"
    )
def get_optimized_state_dict_for_jiant_model_with_adapters(jiant_model: JiantModel):
    """Get the state_dict of the relevant weights for a JiantModel with adapters.

    Basically, the tensors for the adapters and taskmodel heads.

    Returns:
        (kept, dropped): dict of name -> tensor to save, and the list of
        encoder weight names that were excluded.
    """

    def _is_plain_encoder_weight(name):
        # Shared encoder weights and taskmodel encoder weights are dropped
        # UNLESS they belong to adapter modules.
        if "adapter" in name:
            return False
        return name.startswith("encoder.") or ".encoder." in name

    kept = {}
    dropped = []
    for name, tensor in jiant_model.state_dict().items():
        if _is_plain_encoder_weight(name):
            dropped.append(name)
        else:
            # Adapters and taskmodel heads.
            kept[name] = tensor
    return kept, dropped
def get_optimized_named_parameters_for_jiant_model_with_adapters(jiant_model: JiantModel):
    """Select the parameters to optimize for a JiantModel with adapters.

    Does a couple things:
    1. Finds the adapter parameters and taskmodel heads (the only params
       to be optimized).
    2. Sets the other parameters to not require gradients.

    Returns:
        (optimized_named_parameters, set_to_no_grad_list)
    """
    frozen_names = []
    trainable = []
    for name, param in jiant_model.named_parameters():
        shared_encoder = name.startswith("encoder.") and "adapter" not in name
        taskmodel_encoder = ".encoder." in name and "adapter" not in name
        if shared_encoder or taskmodel_encoder:
            # Freeze encoder weights unless they are adapter modules.
            # NOTE: .named_parameters() doesn't return duplicates, so the
            # taskmodel-encoder branch is likely redundant — kept for safety.
            torch_utils.set_requires_grad_single(param, requires_grad=False)
            frozen_names.append(name)
        else:
            # Adapters and taskmodel heads.
            trainable.append((name, param))
    return trainable, frozen_names
def run_test(
    test_dataloader,
    jiant_model: JiantModel,
    task,
    device,
    local_rank,
    verbose=True,
    return_preds=True,
    return_logits=True,
    return_encoder_output: bool = False,
):
    """Run inference over the test dataloader.

    Returns None on non-primary distributed ranks; otherwise a dict with the
    "accumulator" and, depending on flags, "preds", raw "logits", and pooled /
    unpooled encoder outputs concatenated across batches.
    """
    if local_rank != -1:
        return
    jiant_model.eval()
    scheme = evaluate.get_evaluation_scheme_for_task(task=task)
    accumulator = scheme.get_accumulator()
    encoder_outputs = []
    loader = maybe_tqdm(test_dataloader, desc=f"Eval ({task.name}, Test)", verbose=verbose)
    for step, (batch, batch_metadata) in enumerate(loader):
        regular_log(logger, step, interval=10, tag='test')
        batch = batch.to(device)
        with torch.no_grad():
            forward_result = wrap_jiant_forward(
                jiant_model=jiant_model,
                batch=batch,
                task=task,
                compute_loss=False,
                get_encoder_output=return_encoder_output,
            )
        # With get_encoder_output the forward returns a (output, encoder) pair.
        if return_encoder_output:
            model_output, encoder_output = forward_result
            encoder_outputs.append(encoder_output)
        else:
            model_output = forward_result
        accumulator.update(
            batch_logits=model_output.logits.detach().cpu().numpy(),
            batch_loss=0,
            batch=batch,
            batch_metadata=batch_metadata,
        )
    output = {"accumulator": accumulator}
    if return_preds:
        output["preds"] = scheme.get_preds_from_accumulator(task=task, accumulator=accumulator)
    if return_logits and isinstance(accumulator, evaluate.ConcatenateLogitsAccumulator):
        output["logits"] = accumulator.get_accumulated()
    if return_encoder_output:
        output["encoder_outputs_pooled"] = np.concatenate(
            [eo.pooled for eo in encoder_outputs])
        output["encoder_outputs_unpooled"] = np.concatenate(
            [eo.unpooled for eo in encoder_outputs])
    return output
def run_val(
    val_dataloader,
    val_labels,
    jiant_model: JiantModel,
    task,
    device,
    local_rank,
    tf_writer: SummaryWriter,
    global_step: Optional[int] = None,
    phase=None,
    return_preds=False,
    return_logits=True,
    return_encoder_output: bool = False,
    verbose=True,
    split='valid',
):
    """Evaluate on a labeled split, computing metrics and logging to TensorBoard.

    val_dataloader carries the PyTorch-side batch info, while val_labels may
    carry richer label information needed for full metric computation.
    Returns None on non-primary distributed ranks; otherwise a dict with the
    "accumulator" and, depending on flags, "loss"/"metrics", "preds", raw
    "logits", and concatenated encoder outputs.
    """
    has_labels = True  # TODO: detect automatically whether the dataset has labels.
    if local_rank != -1:
        return
    jiant_model.eval()
    scheme = evaluate.get_evaluation_scheme_for_task(task=task)
    accumulator = scheme.get_accumulator()
    loss_sum = 0
    steps, examples = 0, 0
    encoder_outputs = []
    loader = maybe_tqdm(val_dataloader, desc=f"Eval ({task.name}, {str(phase)})", verbose=verbose)
    for step, (batch, batch_metadata) in enumerate(loader):
        regular_log(logger, step, interval=10, tag=split)
        batch = batch.to(device)
        with torch.no_grad():
            forward_result = wrap_jiant_forward(
                jiant_model=jiant_model,
                batch=batch,
                task=task,
                compute_loss=has_labels,
                get_encoder_output=return_encoder_output,
            )
        # With get_encoder_output the forward returns a (output, encoder) pair.
        if return_encoder_output:
            model_output, encoder_output = forward_result
            encoder_outputs.append(encoder_output)
        else:
            model_output = forward_result
        batch_loss = model_output.loss.mean().item() if has_labels else 0
        loss_sum += batch_loss
        accumulator.update(
            batch_logits=model_output.logits.detach().cpu().numpy(),
            batch_loss=batch_loss,
            batch=batch,
            batch_metadata=batch_metadata,
        )
        examples += len(batch)
        steps += 1
    eval_loss = loss_sum / steps
    output = {"accumulator": accumulator}
    if has_labels:
        # Under DataParallel the tokenizer lives on the wrapped .module.
        if torch_utils.is_data_parallel(jiant_model):
            tokenizer = jiant_model.module.tokenizer
        else:
            tokenizer = jiant_model.tokenizer
        metrics = scheme.compute_metrics_from_accumulator(
            task=task,
            accumulator=accumulator,
            labels=val_labels,
            tokenizer=tokenizer,
        )
        output.update({"loss": eval_loss, "metrics": metrics})
        if global_step is not None:
            for metric_name, metric_value in metrics.minor.items():
                tf_writer.add_scalar(
                    f'{split}/{metric_name}', metric_value, global_step=global_step)
    if return_preds:
        output["preds"] = scheme.get_preds_from_accumulator(task=task, accumulator=accumulator)
    if return_logits and isinstance(accumulator, evaluate.ConcatenateLogitsAccumulator):
        output["logits"] = accumulator.get_accumulated()
    if return_encoder_output:
        output["encoder_outputs_pooled"] = np.concatenate(
            [eo.pooled for eo in encoder_outputs])
        output["encoder_outputs_unpooled"] = np.concatenate(
            [eo.unpooled for eo in encoder_outputs])
    if global_step is not None:
        tf_writer.add_scalar(f'{split}/loss', eval_loss, global_step=global_step)
        tf_writer.flush()
    return output
def main(args: RunConfiguration):
    """Smoke-test a multi-task JiantModel.

    Builds a shared roberta-base encoder with NLI/RTE taskmodels, then runs a
    single forward pass per task on its training data and prints the outputs.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # === Shared model components setup === #
    model_type = "roberta-base"
    model_arch = ModelArchitectures.from_model_type(model_type=model_type)
    transformers_class_spec = model_setup.TRANSFORMERS_CLASS_SPEC_DICT[model_arch]
    ancestor_model = model_setup.get_ancestor_model(
        transformers_class_spec=transformers_class_spec,
        model_config_path=args.model_config_path,
    )
    encoder = model_setup.get_encoder(
        model_arch=model_arch,
        ancestor_model=ancestor_model,
    )
    tokenizer = shared_model_setup.get_tokenizer(
        model_type=model_type,
        tokenizer_path=args.model_tokenizer_path,
    )

    # === Taskmodels setup === #
    task_dict = {
        "mnli": tasks.create_task_from_config_path(
            os.path.join(args.task_config_base_path, "mnli.json")),
        "qnli": tasks.create_task_from_config_path(
            os.path.join(args.task_config_base_path, "qnli.json")),
        # BUG FIX: "rte" previously loaded "qnli.json" (copy-paste error).
        "rte": tasks.create_task_from_config_path(
            os.path.join(args.task_config_base_path, "rte.json")),
    }
    taskmodels_dict = {
        # "nli" head is shared between MNLI and QNLI (see task_to_taskmodel_map).
        "nli": taskmodels.ClassificationModel(
            encoder=encoder,
            classification_head=heads.ClassificationHead(
                hidden_size=encoder.config.hidden_size,
                hidden_dropout_prob=encoder.config.hidden_dropout_prob,
                num_labels=len(task_dict["mnli"].LABELS),
            ),
        ),
        "rte": taskmodels.ClassificationModel(
            encoder=encoder,
            classification_head=heads.ClassificationHead(
                hidden_size=encoder.config.hidden_size,
                hidden_dropout_prob=encoder.config.hidden_dropout_prob,
                num_labels=len(task_dict["rte"].LABELS),
            ),
        ),
    }
    task_to_taskmodel_map = {
        "mnli": "nli",
        "qnli": "nli",
        "rte": "rte",
    }

    # === Final === #
    jiant_model = JiantModel(
        task_dict=task_dict,
        encoder=encoder,
        taskmodels_dict=taskmodels_dict,
        task_to_taskmodel_map=task_to_taskmodel_map,
        tokenizer=tokenizer,
    )
    jiant_model = jiant_model.to(device)

    # === Run === #
    task_dataloader_dict = {}
    for task_name, task in task_dict.items():
        train_cache = caching.ChunkedFilesDataCache(
            cache_fol_path=os.path.join(args.task_cache_base_path, task_name, "train"),
        )
        train_dataset = train_cache.get_iterable_dataset(buffer_size=10000, shuffle=True)
        train_dataloader = torch_utils.DataLoaderWithLength(
            dataset=train_dataset,
            batch_size=4,
            collate_fn=task.collate_fn,
        )
        task_dataloader_dict[task_name] = train_dataloader

    # One forward pass per task as a sanity check.
    for task_name, task in task_dict.items():
        batch, batch_metadata = next(iter(task_dataloader_dict[task_name]))
        batch = batch.to(device)
        with torch.no_grad():
            model_output = wrap_jiant_forward(
                jiant_model=jiant_model,
                batch=batch,
                task=task,
                compute_loss=True,
            )
        print(task_name)
        print(model_output)
        print()