def _finetuned_to_mtm(run_params, fishers):
    """Build the ``ModelToMerge`` record for a fine-tuned model.

    Args:
        run_params: Object exposing ``task``, ``finetuned_run_uuid``,
            ``run_uuid`` and ``finetuned_ckpt_uuid`` attributes.
        fishers: Subscriptable container indexed by ``run_params.run_uuid``;
            the value found there is passed on as ``fisher_matrix_uuid``.

    Returns:
        A ``ModelToMerge`` populated from ``run_params`` and ``fishers``.
    """
    # NOTE(review): another definition of `_finetuned_to_mtm` appears later in
    # this source; if both live in the same module, that one takes precedence.
    task = run_params.task
    train_run = run_params.finetuned_run_uuid
    fisher_run = run_params.run_uuid
    checkpoint = run_params.finetuned_ckpt_uuid
    fisher_matrix = fishers[fisher_run]
    return ModelToMerge(
        task=task,
        train_run_uuid=train_run,
        fisher_run_uuid=fisher_run,
        model_checkpoint_uuid=checkpoint,
        fisher_matrix_uuid=fisher_matrix,
    )
def _pretrained_to_mtm(run_params, fishers):
    """Build the ``ModelToMerge`` record for a model that has no training run.

    The checkpoint is taken from ``run_params.pretrained_model`` and
    ``train_run_uuid`` is always ``None``.

    Args:
        run_params: Object exposing ``task``, ``run_uuid`` and
            ``pretrained_model`` attributes.
        fishers: Subscriptable container indexed by ``run_params.run_uuid``;
            the value found there is passed on as ``fisher_matrix_uuid``.

    Returns:
        A ``ModelToMerge`` populated from ``run_params`` and ``fishers``.
    """
    task = run_params.task
    fisher_run = run_params.run_uuid
    checkpoint = run_params.pretrained_model
    fisher_matrix = fishers[fisher_run]
    return ModelToMerge(
        task=task,
        train_run_uuid=None,  # no fine-tuning run is associated with this model
        fisher_run_uuid=fisher_run,
        model_checkpoint_uuid=checkpoint,
        fisher_matrix_uuid=fisher_matrix,
    )
def _finetuned_to_mtm(run_params, fishers, additional_model_bindings=()):
    """Build the ``ModelToMerge`` record for a fine-tuned model.

    Args:
        run_params: Object exposing ``task``, ``finetuned_run_uuid``,
            ``run_uuid`` and ``finetuned_ckpt_uuid`` attributes.
        fishers: Subscriptable container indexed by ``run_params.run_uuid``;
            the value found there is passed on as ``fisher_matrix_uuid``.
        additional_model_bindings: Forwarded unchanged to ``ModelToMerge``;
            defaults to an empty tuple.

    Returns:
        A ``ModelToMerge`` populated from the arguments.
    """
    task = run_params.task
    train_run = run_params.finetuned_run_uuid
    fisher_run = run_params.run_uuid
    checkpoint = run_params.finetuned_ckpt_uuid
    fisher_matrix = fishers[fisher_run]
    return ModelToMerge(
        task=task,
        train_run_uuid=train_run,
        fisher_run_uuid=fisher_run,
        model_checkpoint_uuid=checkpoint,
        fisher_matrix_uuid=fisher_matrix,
        additional_model_bindings=additional_model_bindings,
    )
def _to_mtm(run_params, fishers):
    """Build a ``ModelToMerge`` from fine-tuned or downloaded run parameters.

    Supports both fine-tuned and downloaded models: fine-tuning attributes are
    used when ``run_params`` has them, otherwise the pretrained model is used
    as the checkpoint and ``train_run_uuid`` is ``None``.

    Args:
        run_params: Object exposing ``task`` and ``run_uuid``, plus either
            ``finetuned_ckpt_uuid`` (optionally with ``finetuned_run_uuid``)
            or ``pretrained_model``.
        fishers: Subscriptable container indexed by ``run_params.run_uuid``;
            the value found there is passed on as ``fisher_matrix_uuid``.

    Returns:
        A ``ModelToMerge`` populated from ``run_params`` and ``fishers``.

    Raises:
        AttributeError: If ``run_params`` has neither ``finetuned_ckpt_uuid``
            nor ``pretrained_model``.
    """
    train_run_uuid = getattr(run_params, "finetuned_run_uuid", None)

    # Resolve the fallback lazily. Passing `run_params.pretrained_model` as the
    # `getattr` default would evaluate it unconditionally and raise for
    # fine-tuned params that do not define `pretrained_model`.
    try:
        model_checkpoint_uuid = run_params.finetuned_ckpt_uuid
    except AttributeError:
        model_checkpoint_uuid = run_params.pretrained_model

    return ModelToMerge(
        task=run_params.task,
        train_run_uuid=train_run_uuid,
        fisher_run_uuid=run_params.run_uuid,
        model_checkpoint_uuid=model_checkpoint_uuid,
        fisher_matrix_uuid=fishers[run_params.run_uuid],
    )