def diagonal_fisher_computer(
    _initializer,
    _builder,
    _loader,
    num_examples,
    finetuned_ckpt_uuid=None,
    fisher_class_chunk_size=4096,
    y_samples=None,
    max_grad_value=None,
):
    if finetuned_ckpt_uuid is None:
        # No checkpoint binding needed, so use a no-op context manager.
        ctx = contextlib.nullcontext()
    else:
        ctx = scopes.binding_by_name_scope("checkpoint", finetuned_ckpt_uuid)

    with ctx:
        ft_model = _initializer()
        with scopes.binding_by_name_scope("model", ft_model):
            ft_model = _builder(ft_model)
            ft_model = _loader(ft_model)

    computer = diagonal.DiagonalFisherComputer(
        ft_model,
        total_examples=num_examples,
        y_samples=y_samples,
        class_chunk_size=fisher_class_chunk_size,
        max_grad_value=max_grad_value,
    )
    # NOTE: I don't think the next binding will ever be used, but I'm
    # putting it here out of paranoia.
    with scopes.binding_by_name_scope("model", computer):
        computer.compile()
    return computer
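
# Illustrative sketch of the `_initializer`/`_builder`/`_loader` contract
# used throughout this file. These are injected callables; everything below
# is an assumption for illustration (a toy Keras model), not part of the
# framework.
def _example_initializer():
    # Create the (unbuilt) model object.
    return tf.keras.Sequential([tf.keras.layers.Dense(2)])


def _example_builder(model):
    # Build the model's variables so that weights can be loaded into them.
    model.build(input_shape=(None, 16))
    return model


def _example_loader(model):
    # A real loader would restore checkpoint weights; this stand-in is a no-op.
    return model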
def variational_diag_fisher_computer(
    _initializer,
    _builder,
    _loader,
    optimizer,
    finetuned_ckpt_uuid=None,
    variational_fisher_beta=1e-8,
):
    if finetuned_ckpt_uuid is None:
        # No checkpoint binding needed, so use a no-op context manager.
        ctx = contextlib.nullcontext()
    else:
        ctx = scopes.binding_by_name_scope("checkpoint", finetuned_ckpt_uuid)

    with ctx:
        ft_model = _initializer()
        with scopes.binding_by_name_scope("model", ft_model):
            ft_model = _builder(ft_model)
            ft_model = _loader(ft_model)

    computer = vardiag.VariationalDiagFisherComputer(
        ft_model,
        beta=variational_fisher_beta,
    )
    # NOTE: I don't think the next binding will ever be used, but I'm
    # putting it here out of paranoia.
    with scopes.binding_by_name_scope("model", computer):
        computer.compile(optimizer=optimizer)
    return computer
def call(
    self,
    _dataset,
    compiled_model,
    _fit_kwargs,
    _train_history_saver,
    with_validation=True,
):
    with self.dataset_scope("train"):
        train_ds = _dataset()
    if with_validation:
        with self.dataset_scope("validation"):
            validation_ds = _dataset()
    else:
        validation_ds = None

    with scopes.binding_by_name_scope("model", compiled_model):
        fit_kwargs = _fit_kwargs()

        # NOTE: This can be uncommented to assist with debugging.
        # compiled_model.run_eagerly = True

        history = compiled_model.fit(
            train_ds,
            validation_data=validation_ds,
            **fit_kwargs,
        )
        if _train_history_saver:
            return _train_history_saver(history)
        else:
            return None
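
# Hypothetical `_fit_kwargs` factory matching the contract used in `call`
# above: it is invoked inside the model binding scope and must return keyword
# arguments forwarded to `tf.keras.Model.fit`. The argument names are standard
# Keras; the values are placeholders.
def _example_fit_kwargs():
    return {
        "epochs": 3,
        "steps_per_epoch": 1000,
        "verbose": 2,
    }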
def diagonal_model_merger(
    mergeable_models,
    weightings,
    _initializer,
    _builder,
    _metrics=None,
    min_fisher=1e-6,
    normalize_fishers=False,
    multitask_merge=False,
):
    to_be_merged = _initializer()
    with scopes.binding_by_name_scope("model", to_be_merged):
        to_be_merged = _builder(to_be_merged)

    merged_models = diagonal.merge_models_with_weightings(
        to_be_merged,
        mergeable_models,
        weightings,
        single_task=not multitask_merge,
        min_fisher=min_fisher,
        normalize_fishers=normalize_fishers,
    )
    for merged in merged_models:
        compile_kwargs = {}
        if _metrics:
            compile_kwargs["metrics"] = _metrics(merged)
        merged.compile(**compile_kwargs)
        yield merged
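
# Sketch of consuming the merger above (the function name is a placeholder):
# because `diagonal_model_merger` is a generator, each weighting lazily
# produces one compiled merged model, keeping at most one merged model live
# at a time.
def _example_consume_merged_models(merger, weightings):
    for weighting, merged in zip(weightings, merger):
        # Evaluate or save `merged` for this `weighting` here.
        logging.info(f"Merged model ready for weighting {weighting}")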
def diagonal_model_merge_weighting_search(
    mergeable_models,
    merge_weighting_search_steps,
    merge_weighting_num_inits,
    _initializer,
    _builder,
    _model_scorer,
    min_fisher=1e-6,
    multitask_merge=False,
    merge_on_cpu=False,
):
    to_be_merged = _initializer()
    with scopes.binding_by_name_scope("model", to_be_merged):
        to_be_merged = _builder(to_be_merged)

    (
        merged_model,
        weighting,
        trial_weightings,
        trial_scores,
    ) = diagonal.merge_search_best_weighting(
        to_be_merged,
        mergeable_models=mergeable_models,
        score_fn=_model_scorer,
        max_evals=merge_weighting_search_steps,
        num_inits=merge_weighting_num_inits,
        min_fisher=min_fisher,
        single_task=not multitask_merge,
        merge_on_cpu=merge_on_cpu,
    )
    merged_model.compile()
    return merged_model, weighting, trial_weightings, trial_scores
def bert_finetuning_model(
    _initializer,
    _builder,
    _loader,
    _regularizer=None,
    _optimizer=None,
    _loss=None,
    _metrics=None,
):
    model = _initializer()
    with scopes.binding_by_name_scope("model", model):
        model = _builder(model)
        model = _loader(model)
        if _regularizer:
            model = _regularizer(model)

        kwargs = {}
        if _optimizer:
            kwargs["optimizer"] = _optimizer()
        if _loss:
            kwargs["loss"] = _loss()
        if _metrics:
            kwargs["metrics"] = _metrics()
        model.compile(**kwargs)
    return model
def dummy_fisher_model_merger(
    mergeable_models,
    weightings,
    _initializer,
    _builder,
    _metrics=None,
    multitask_merge=False,
):
    to_be_merged = _initializer()
    with scopes.binding_by_name_scope("model", to_be_merged):
        to_be_merged = _builder(to_be_merged)

    for weighting in weightings:
        merged = dummy_merge_models(
            to_be_merged,
            mergeable_models,
            weighting,
            single_task=not multitask_merge,
        )
        compile_kwargs = {}
        if _metrics:
            compile_kwargs["metrics"] = _metrics(merged)
        merged.compile(**compile_kwargs)
        logging.info("DUMMY MERGING!!!")
        yield merged
def worker_run(
    *,
    global_binding_specs,
    storage_params,
    group_cls,
    experiment_cls,
    executable_cls,
    init_kwargs=None,
    call_kwargs=None,
    preload_blob_uuids=None,
    # The run_params are used purely for storage at the start of the
    # experiment and do not affect execution.
    run_params=None,
):
    # NOTE: Should only be called on the worker. Users probably won't call
    # this method directly.
    if not init_kwargs:
        init_kwargs = {}
    if not call_kwargs:
        call_kwargs = {}

    run_uuid = storage_params.get_storage_cls().new_uuid()

    # Due to the class decorators returning an instance, we should not
    # call these.
    group = group_cls
    experiment = experiment_cls

    extra_global_binding_specs = [
        scopes.ArgNameBindingSpec("group", group),
        scopes.ArgNameBindingSpec("experiment", experiment),
        scopes.ArgNameBindingSpec("run_uuid", run_uuid),
    ]
    total_global_binding_specs = (
        list(global_binding_specs) + extra_global_binding_specs
    )

    with scopes.binding_scope(total_global_binding_specs):
        with storage_params.instantiate_storage() as storage:
            if preload_blob_uuids and storage.can_preload_blobs():
                storage.preload_blobs(preload_blob_uuids)

            # NOTE: I might want to avoid injecting storage directly and
            # instead mediate interactions with storage via injected
            # instances of ExperimentGroup, Experiment, and Procedure.
            #
            # I'd need to think about how reusable executables that interact
            # with storage, such as the checkpoint saver, would work in that
            # framework, though.
            with scopes.binding_by_name_scope("storage", storage):
                # These executables are constructed and then immediately
                # invoked under the current binding scopes.
                set_run_state()(RunState.STARTED)
                if run_params:
                    save_params_at_run_start()(run_params)

                executable_cls(**init_kwargs)(**call_kwargs)

                set_run_state()(RunState.FINISHED)
def merge_weighting_search_scorer(
    merged_model,
    _evaluate_model,
    _single_score_from_results,
):
    merged_model.compile()
    # `_evaluate_model` returns whatever the bound evaluation results saver
    # returns, so bind the saver to the identity function to get the raw
    # results back.
    with scopes.binding_by_name_scope(
        "evaluation_results_saver", pass_evaluation_results
    ):
        results = _evaluate_model(merged_model)
    return _single_score_from_results(results)
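
# Hypothetical reducer matching the `_single_score_from_results` contract:
# it collapses a {task: {metric_name: value}} results dict (the shape
# produced by `robust_evaluate_model` below) into a single scalar for the
# weighting search to optimize.
def _example_single_score_from_results(results):
    values = [
        value
        for task_results in results.values()
        for value in task_results.values()
    ]
    return sum(values) / len(values)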
def diagonal_mergeable_model_from_checkpoint(
    checkpoint,
    checkpoint_to_fisher_matrix_uuid,
    _initializer,
    _builder,
    _loader,
    storage,
):
    with tf.device("/cpu"):
        with scopes.binding_by_name_scope("checkpoint", checkpoint):
            ft_model = _initializer()
            with scopes.binding_by_name_scope("model", ft_model):
                ft_model = _builder(ft_model)
                ft_model = _loader(ft_model)

        fisher_matrix_uuid = checkpoint_to_fisher_matrix_uuid[checkpoint]
        logging.info(f"Retrieving saved fisher matrix: {fisher_matrix_uuid}")
        with storage.retrieve_blob_as_tempfile(fisher_matrix_uuid) as f:
            logging.info(f"Loading retrieved fisher matrix: {fisher_matrix_uuid}")
            fisher_matrix = diagonal.DiagonalFisherMatrix.load(f.name)

    return MergableModel(model=ft_model, fisher_matrix=fisher_matrix)
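
# Example shape of `checkpoint_to_fisher_matrix_uuid` (the UUID strings are
# made-up placeholders): it maps a fine-tuned checkpoint's blob UUID to the
# blob UUID of the Fisher matrix computed from that checkpoint.
#
# checkpoint_to_fisher_matrix_uuid = {
#     "<checkpoint-blob-uuid>": "<fisher-matrix-blob-uuid>",
# }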
def robust_evaluate_model(
    compiled_model,
    # Dict from task name to eval dataset. Also has a "{task}_labels" key
    # holding a tensor of the labels for each task.
    robust_evaluate_dataset,
    # Dict from task name to a metric or list of metrics.
    metrics_for_tasks,
    _process_task_logits,
    _evaluation_results_saver,
):
    results = {}
    for task, dataset in robust_evaluate_dataset.items():
        if task.endswith("_labels"):
            continue
        labels = robust_evaluate_dataset[f"{task}_labels"]

        og_task = task
        task = _handle_mnli(task)

        start_time = time.time()
        task_logits = _get_task_logits(
            compiled_model.compute_task_logits,
            dataset,
            task,
            num_classes=compiled_model.get_num_classes_for_task(task),
        )
        elapsed_seconds = time.time() - start_time
        elapsed_nice = str(datetime.timedelta(seconds=elapsed_seconds))
        logging.info(f"Evaluation took {elapsed_nice}")

        with scopes.binding_by_name_scope("task", og_task):
            prediction_outputs = _process_task_logits(task_logits)

        metrics = metrics_for_tasks[og_task]
        if not isinstance(metrics, (list, tuple)):
            metrics = [metrics]

        task_results = {}
        for metric in metrics:
            task_results.update(metric(labels, prediction_outputs, return_dict=True))
        results[og_task] = task_results

    logging.info(f"Evaluation results: {results}")
    return _evaluation_results_saver(results)
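
# Hypothetical metric callable satisfying the contract used above: it is
# invoked as metric(labels, prediction_outputs, return_dict=True) and must
# return a dict of named scalar results. This sketch assumes the prediction
# outputs are predicted class ids.
def _example_accuracy_metric(labels, predictions, return_dict=False):
    import numpy as np  # Local import to keep the sketch self-contained.

    accuracy = float(np.mean(np.asarray(labels) == np.asarray(predictions)))
    return {"accuracy": accuracy} if return_dict else accuracy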
def evaluate_from_checkpoints_summary(
    checkpoints_summary,
    _compiled_model,
    _evaluate_model,
    should_clear_session=True,
):
    retvals = []
    for i, checkpoint_blob_uuid in enumerate(checkpoints_summary.checkpoint_uuids):
        bindings = [
            ("checkpoint", checkpoint_blob_uuid),
            ("checkpoint_index", i),
        ]
        with scopes.binding_by_name_scopes(bindings):
            compiled_model = _compiled_model()
            with scopes.binding_by_name_scope("compiled_model", compiled_model):
                retval = _evaluate_model(compiled_model)
                retvals.append(retval)
        if should_clear_session:
            tf.keras.backend.clear_session()
    return retvals
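
# The `checkpoints_summary` argument above only needs a `checkpoint_uuids`
# sequence. A hypothetical stand-in for tests (this class is an assumption,
# not the project's actual summary type):
class _ExampleCheckpointsSummary:
    def __init__(self, checkpoint_uuids):
        self.checkpoint_uuids = checkpoint_uuids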
def bert_mlm_model(
    _initializer,
    _builder,
    _loader,
    _metrics=None,
):
    model = _initializer()
    with scopes.binding_by_name_scope("model", model):
        model = _builder(model)
        model = _loader(model)

        kwargs = {}
        if _metrics:
            kwargs["metrics"] = _metrics()
        model.compile(**kwargs)
    return model
def diagonal_mergeable_model_from_checkpoint_or_pretrained(
    checkpoint,
    checkpoint_to_fisher_matrix_uuid,
    pretrained_model,
    _initializer,
    _builder,
    _loader,
    storage,
    pretrained_full_model=True,
    mergeable_model_pretrained_model=None,
):
    with tf.device("/cpu"):
        if checkpoint is None or _is_uuid(checkpoint):
            bindings = [("checkpoint", checkpoint)]
            if mergeable_model_pretrained_model:
                bindings.append(
                    ("pretrained_model", mergeable_model_pretrained_model)
                )
        else:
            # The "checkpoint" is actually the name of a pretrained model.
            bindings = [
                ("pretrained_model", checkpoint),
                ("checkpoint", None),
                ("pretrained_body_only", not pretrained_full_model),
            ]

        with scopes.binding_by_name_scopes(bindings):
            ft_model = _initializer()
            with scopes.binding_by_name_scope("model", ft_model):
                ft_model = _builder(ft_model)
                ft_model = _loader(ft_model)

        if checkpoint is not None:
            fisher_matrix_uuid = checkpoint_to_fisher_matrix_uuid[checkpoint]
        else:
            fisher_matrix_uuid = checkpoint_to_fisher_matrix_uuid[pretrained_model]

        logging.info(f"Retrieving saved fisher matrix: {fisher_matrix_uuid}")
        with storage.retrieve_blob_as_tempfile(fisher_matrix_uuid) as f:
            logging.info(f"Loading retrieved fisher matrix: {fisher_matrix_uuid}")
            fisher_matrix = diagonal.DiagonalFisherMatrix.load(f.name)

    return MergableModel(model=ft_model, fisher_matrix=fisher_matrix)
def merge_and_evaluate_from_checkpoints(
    checkpoints,
    tasks,
    # Sequence of merge weightings to evaluate. Its length is independent of
    # the number of checkpoints and tasks; each weighting yields one merged
    # model.
    weightings,
    _mergeable_model,
    _model_merger,
    _evaluate_model,
    multitask_merge=False,
    additional_model_bindings=None,
):
    assert len(checkpoints) == len(tasks)
    additional_model_bindings = additional_model_bindings or len(tasks) * [[]]

    mergeable_models = []
    for checkpoint, task, extra_bindings in zip(
        checkpoints, tasks, additional_model_bindings
    ):
        bindings = [
            ("checkpoint", checkpoint),
            ("tasks", [task]),
            ("task", task),
        ]
        bindings.extend(extra_bindings)
        with scopes.binding_by_name_scopes(bindings):
            mergeable_model = _mergeable_model()
            mergeable_models.append(mergeable_model)

    # NOTE: A single-task merge means we only care about the performance of
    # the first task.
    bindings = [
        ("tasks", tasks if multitask_merge else tasks[:1]),
        ("task", tasks[0]),
    ]
    with scopes.binding_by_name_scopes(bindings):
        merged_models = _model_merger(
            mergeable_models=mergeable_models, weightings=weightings
        )
        for merged_model, weighting in zip(merged_models, weightings):
            logging.info(f"Evaluating task weighting {weighting}")
            with scopes.binding_by_name_scope("weighting", weighting):
                _evaluate_model(merged_model)
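
# Hedged end-to-end sketch of calling the function above. All literals are
# placeholders, and in practice the underscore-prefixed arguments are
# injected by the `scopes` framework rather than passed explicitly:
#
# merge_and_evaluate_from_checkpoints(
#     checkpoints=["<uuid-a>", "<uuid-b>"],
#     tasks=["rte", "mnli"],
#     weightings=[(0.9, 0.1), (0.5, 0.5), (0.1, 0.9)],
# )
#
# With multitask_merge=False, only the first task ("rte" here) is evaluated
# on each merged model.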