コード例 #1
0
def diagonal_fisher_computer(
    _initializer,
    _builder,
    _loader,
    num_examples,
    finetuned_ckpt_uuid=None,
    fisher_class_chunk_size=4096,
    y_samples=None,
    max_grad_value=None,
):
    """Create and compile a diagonal Fisher computer for a fine-tuned model.

    Args:
        _initializer: injected callable returning a fresh (unbuilt) model.
        _builder: injected callable that builds the model.
        _loader: injected callable that loads weights into the model.
        num_examples: total number of examples the Fisher is estimated over.
        finetuned_ckpt_uuid: optional checkpoint uuid; when provided it is
            bound to the "checkpoint" name while the model is initialized,
            built, and loaded. When None, no checkpoint binding is made.
        fisher_class_chunk_size: number of classes processed per chunk by the
            Fisher computation.
        y_samples: forwarded to DiagonalFisherComputer; presumably the number
            of label samples per example (None meaning an exact expectation) —
            TODO confirm against DiagonalFisherComputer.
        max_grad_value: optional value forwarded for gradient clipping.

    Returns:
        A compiled `diagonal.DiagonalFisherComputer` wrapping the loaded model.
    """
    if finetuned_ckpt_uuid is None:
        # No checkpoint to bind: use an explicit no-op context manager.
        # (`contextlib.nullcontext()` states the intent directly, unlike an
        # empty `contextlib.suppress()`.)
        ctx = contextlib.nullcontext()
    else:
        ctx = scopes.binding_by_name_scope("checkpoint", finetuned_ckpt_uuid)
    with ctx:
        ft_model = _initializer()
        with scopes.binding_by_name_scope("model", ft_model):
            ft_model = _builder(ft_model)
            ft_model = _loader(ft_model)

        computer = diagonal.DiagonalFisherComputer(
            ft_model,
            total_examples=num_examples,
            y_samples=y_samples,
            class_chunk_size=fisher_class_chunk_size,
            max_grad_value=max_grad_value,
        )
        # NOTE: I don't really think the next binding will ever be used, but I'm
        # putting here out of paranoia.
        with scopes.binding_by_name_scope("model", computer):
            computer.compile()

    return computer
コード例 #2
0
def variational_diag_fisher_computer(
    _initializer,
    _builder,
    _loader,
    optimizer,
    finetuned_ckpt_uuid=None,
    variational_fisher_beta=1e-8,
):
    """Create and compile a variational diagonal Fisher computer.

    Args:
        _initializer: injected callable returning a fresh (unbuilt) model.
        _builder: injected callable that builds the model.
        _loader: injected callable that loads weights into the model.
        optimizer: optimizer passed to `computer.compile`.
        finetuned_ckpt_uuid: optional checkpoint uuid; when provided it is
            bound to the "checkpoint" name while the model is initialized,
            built, and loaded. When None, no checkpoint binding is made.
        variational_fisher_beta: beta forwarded to
            VariationalDiagFisherComputer — presumably a regularization
            strength; TODO confirm against that class.

    Returns:
        A compiled `vardiag.VariationalDiagFisherComputer` wrapping the model.
    """
    if finetuned_ckpt_uuid is None:
        # No checkpoint to bind: use an explicit no-op context manager.
        # (`contextlib.nullcontext()` states the intent directly, unlike an
        # empty `contextlib.suppress()`.)
        ctx = contextlib.nullcontext()
    else:
        ctx = scopes.binding_by_name_scope("checkpoint", finetuned_ckpt_uuid)
    with ctx:
        ft_model = _initializer()
        with scopes.binding_by_name_scope("model", ft_model):
            ft_model = _builder(ft_model)
            ft_model = _loader(ft_model)

        computer = vardiag.VariationalDiagFisherComputer(
            ft_model,
            beta=variational_fisher_beta,
        )
        # NOTE: I don't really think the next binding will ever be used, but I'm
        # putting here out of paranoia.
        with scopes.binding_by_name_scope("model", computer):
            computer.compile(optimizer=optimizer)

    return computer
コード例 #3
0
    def call(
        self,
        _dataset,
        compiled_model,
        _fit_kwargs,
        _train_history_saver,
        with_validation=True,
    ):
        """Fit the compiled model on the train (and optionally validation) set.

        Returns the result of `_train_history_saver(history)`, or None when no
        saver is provided.
        """
        with self.dataset_scope("train"):
            train_ds = _dataset()

        validation_ds = None
        if with_validation:
            with self.dataset_scope("validation"):
                validation_ds = _dataset()

        with scopes.binding_by_name_scope("model", compiled_model):
            fit_kwargs = _fit_kwargs()

            history = compiled_model.fit(
                train_ds,
                validation_data=validation_ds,
                **fit_kwargs,
            )

            if not _train_history_saver:
                return None
            return _train_history_saver(history)
コード例 #4
0
def diagonal_model_merger(
    mergeable_models,
    weightings,
    _initializer,
    _builder,
    _metrics=None,
    min_fisher=1e-6,
    normalize_fishers=False,
    multitask_merge=False,
):
    """Yield one compiled Fisher-merged model per entry in `weightings`."""
    to_be_merged = _initializer()
    with scopes.binding_by_name_scope("model", to_be_merged):
        to_be_merged = _builder(to_be_merged)

        merged_models = diagonal.merge_models_with_weightings(
            to_be_merged,
            mergeable_models,
            weightings,
            single_task=not multitask_merge,
            min_fisher=min_fisher,
            normalize_fishers=normalize_fishers,
        )

        for model in merged_models:
            # Only attach metrics when a metrics factory was injected.
            kwargs = {"metrics": _metrics(model)} if _metrics else {}
            model.compile(**kwargs)
            yield model
コード例 #5
0
def diagonal_model_merge_weighting_search(
    mergeable_models,
    merge_weighting_search_steps,
    merge_weighting_num_inits,
    _initializer,
    _builder,
    _model_scorer,
    min_fisher=1e-6,
    multitask_merge=False,
    merge_on_cpu=False,
):
    """Search over merge weightings and return the best merged model.

    Returns a 4-tuple: (compiled merged model, best weighting, all trial
    weightings, all trial scores).
    """
    to_be_merged = _initializer()
    with scopes.binding_by_name_scope("model", to_be_merged):
        to_be_merged = _builder(to_be_merged)

        search_result = diagonal.merge_search_best_weighting(
            to_be_merged,
            mergeable_models=mergeable_models,
            score_fn=_model_scorer,
            max_evals=merge_weighting_search_steps,
            num_inits=merge_weighting_num_inits,
            min_fisher=min_fisher,
            single_task=not multitask_merge,
            merge_on_cpu=merge_on_cpu,
        )

    merged_model, weighting, trial_weightings, trial_scores = search_result
    merged_model.compile()

    return merged_model, weighting, trial_weightings, trial_scores
コード例 #6
0
def bert_finetuning_model(
    _initializer,
    _builder,
    _loader,
    _regularizer=None,
    _optimizer=None,
    _loss=None,
    _metrics=None,
):
    """Initialize, build, load, optionally regularize, and compile a model.

    Each optional injected factory contributes its corresponding keyword to
    `model.compile` only when it is provided.
    """
    model = _initializer()

    with scopes.binding_by_name_scope("model", model):
        model = _builder(model)
        model = _loader(model)
        if _regularizer:
            model = _regularizer(model)

        # Build compile kwargs from whichever factories were injected;
        # factories are invoked in optimizer, loss, metrics order.
        compile_kwargs = {
            name: factory()
            for name, factory in (
                ("optimizer", _optimizer),
                ("loss", _loss),
                ("metrics", _metrics),
            )
            if factory
        }
        model.compile(**compile_kwargs)

    return model
コード例 #7
0
def dummy_fisher_model_merger(
    mergeable_models,
    weightings,
    _initializer,
    _builder,
    _metrics=None,
    multitask_merge=False,
):
    """Yield one compiled dummy-merged model per entry in `weightings`."""
    to_be_merged = _initializer()
    with scopes.binding_by_name_scope("model", to_be_merged):
        to_be_merged = _builder(to_be_merged)

        for weighting in weightings:
            model = dummy_merge_models(
                to_be_merged,
                mergeable_models,
                weighting,
                single_task=not multitask_merge,
            )
            # Only attach metrics when a metrics factory was injected.
            kwargs = {"metrics": _metrics(model)} if _metrics else {}
            model.compile(**kwargs)

            logging.info("DUMMY MERGING!!!")

            yield model
コード例 #8
0
ファイル: entrypoint.py プロジェクト: mmatena/del8
def worker_run(
    *,
    global_binding_specs,
    storage_params,
    group_cls,
    experiment_cls,
    executable_cls,
    init_kwargs=None,
    call_kwargs=None,
    preload_blob_uuids=None,
    run_params=None,
):
    """Run a single executable on a worker under the full binding scope.

    NOTE: Should only be called on the worker; users probably won't call this
    method directly. The `run_params` are used purely for storage at the start
    of the experiment and do not affect any execution.
    """
    init_kwargs = init_kwargs or {}
    call_kwargs = call_kwargs or {}

    run_uuid = storage_params.get_storage_cls().new_uuid()

    # The class decorators return instances, so these must not be called.
    group = group_cls
    experiment = experiment_cls

    all_binding_specs = list(global_binding_specs) + [
        scopes.ArgNameBindingSpec("group", group),
        scopes.ArgNameBindingSpec("experiment", experiment),
        scopes.ArgNameBindingSpec("run_uuid", run_uuid),
    ]

    with scopes.binding_scope(all_binding_specs):
        with storage_params.instantiate_storage() as storage:
            if preload_blob_uuids and storage.can_preload_blobs():
                storage.preload_blobs(preload_blob_uuids)

            # NOTE: I might want to avoid injecting storage directly and
            # instead mediate interactions with storage via injected instances
            # of ExperimentGroup, Experiment, and Procedure. I'd need to think
            # how re-usable executables that interact with storage such as the
            # checkpoint saver would work in that framework, though.
            with scopes.binding_by_name_scope("storage", storage):
                set_run_state()(RunState.STARTED)
                if run_params:
                    save_params_at_run_start()(run_params)
                executable_cls(**init_kwargs)(**call_kwargs)
                set_run_state()(RunState.FINISHED)
コード例 #9
0
ファイル: merging_execs.py プロジェクト: mmatena/m251
def merge_weighting_search_scorer(
    merged_model,
    _evaluate_model,
    _single_score_from_results,
):
    """Compile, evaluate, and reduce a merged model to a single score."""
    merged_model.compile()
    # Bind the identity function as the results saver so that
    # `_evaluate_model` returns the raw evaluation results.
    saver_binding = scopes.binding_by_name_scope(
        "evaluation_results_saver", pass_evaluation_results)
    with saver_binding:
        results = _evaluate_model(merged_model)
    return _single_score_from_results(results)
コード例 #10
0
def diagonal_mergeable_model_from_checkpoint(
    checkpoint,
    checkpoint_to_fisher_matrix_uuid,
    _initializer,
    _builder,
    _loader,
    storage,
):
    """Load a checkpoint's model and its diagonal Fisher as a MergableModel.

    All work is pinned to the CPU device.
    """
    with tf.device("/cpu"):
        with scopes.binding_by_name_scope("checkpoint", checkpoint):
            model = _initializer()
            with scopes.binding_by_name_scope("model", model):
                model = _builder(model)
                model = _loader(model)

        fisher_uuid = checkpoint_to_fisher_matrix_uuid[checkpoint]
        logging.info(f"Retrieving saved fisher matrix: {fisher_uuid}")
        with storage.retrieve_blob_as_tempfile(fisher_uuid) as tmp:
            logging.info(f"Loading retrieved fisher matrix: {fisher_uuid}")
            fisher = diagonal.DiagonalFisherMatrix.load(tmp.name)

        return MergableModel(model=model, fisher_matrix=fisher)
コード例 #11
0
ファイル: eval_execs.py プロジェクト: mmatena/del8
def robust_evaluate_model(
    compiled_model,
    # Dict from task name to eval dataset. Also has "{}_labels" key with a
    # tensor of the labels.
    robust_evaluate_dataset,
    # Dict from task name to metric or list of metrics.
    metrics_for_tasks,
    _process_task_logits,
    _evaluation_results_saver,
):
    """Evaluate the model on every task in the dataset dict and save results.

    Keys ending in "_labels" are label tensors, not tasks, and are skipped.
    Returns whatever `_evaluation_results_saver(results)` returns.
    """
    results = {}
    for og_task, dataset in robust_evaluate_dataset.items():
        if og_task.endswith("_labels"):
            continue

        labels = robust_evaluate_dataset[f"{og_task}_labels"]
        # May rewrite the task name (e.g. mnli variants) for model lookup.
        task = _handle_mnli(og_task)

        start_time = time.time()
        task_logits = _get_task_logits(
            compiled_model.compute_task_logits,
            dataset,
            task,
            num_classes=compiled_model.get_num_classes_for_task(task),
        )
        elapsed_nice = str(
            datetime.timedelta(seconds=time.time() - start_time))
        logging.info(f"Evaluation took {elapsed_nice}")

        with scopes.binding_by_name_scope("task", og_task):
            prediction_outputs = _process_task_logits(task_logits)

        # Normalize a single metric into a list of metrics.
        metrics = metrics_for_tasks[og_task]
        if not isinstance(metrics, (list, tuple)):
            metrics = [metrics]

        task_results = {}
        for metric in metrics:
            task_results.update(
                metric(labels, prediction_outputs, return_dict=True))

        results[og_task] = task_results

    logging.info(f"Evaluation results: {results}")

    return _evaluation_results_saver(results)
コード例 #12
0
ファイル: eval_execs.py プロジェクト: mmatena/del8
def evaluate_from_checkpoints_summary(checkpoints_summary,
                                      _compiled_model,
                                      _evaluate_model,
                                      should_clear_session=True):
    """Evaluate every checkpoint in the summary, returning one result each."""
    results = []
    for index, ckpt_uuid in enumerate(
            checkpoints_summary.checkpoint_uuids):
        ckpt_bindings = [
            ("checkpoint", ckpt_uuid),
            ("checkpoint_index", index),
        ]
        with scopes.binding_by_name_scopes(ckpt_bindings):
            model = _compiled_model()
            with scopes.binding_by_name_scope("compiled_model", model):
                results.append(_evaluate_model(model))
            if should_clear_session:
                # Drop graph/session state between checkpoints to bound
                # memory growth.
                tf.keras.backend.clear_session()
    return results
コード例 #13
0
ファイル: bert_mlm_execs.py プロジェクト: mmatena/m251
def bert_mlm_model(
    _initializer,
    _builder,
    _loader,
    _metrics=None,
):
    """Initialize, build, load, and compile a BERT MLM model."""
    model = _initializer()

    with scopes.binding_by_name_scope("model", model):
        model = _builder(model)
        model = _loader(model)

        # Only attach metrics when a metrics factory was injected.
        compile_kwargs = {"metrics": _metrics()} if _metrics else {}
        model.compile(**compile_kwargs)

    return model
コード例 #14
0
def diagonal_mergeable_model_from_checkpoint_or_pretrained(
    checkpoint,
    checkpoint_to_fisher_matrix_uuid,
    pretrained_model,
    _initializer,
    _builder,
    _loader,
    storage,
    pretrained_full_model=True,
    mergeable_model_pretrained_model=None,
):
    """Load a model (from checkpoint uuid or pretrained name) plus its Fisher.

    When `checkpoint` is None or a uuid it is bound as the checkpoint;
    otherwise it is treated as a pretrained model name. The Fisher matrix uuid
    is looked up under the checkpoint when one was given, otherwise under
    `pretrained_model`. All work is pinned to the CPU device.
    """
    with tf.device("/cpu"):
        if checkpoint is None or _is_uuid(checkpoint):
            # Checkpoint branch: bind the (possibly None) checkpoint uuid.
            bindings = [("checkpoint", checkpoint)]
            if mergeable_model_pretrained_model:
                bindings.append(
                    ("pretrained_model", mergeable_model_pretrained_model))
        else:
            # `checkpoint` actually names a pretrained model, not a uuid.
            bindings = [
                ("pretrained_model", checkpoint),
                ("checkpoint", None),
                ("pretrained_body_only", not pretrained_full_model),
            ]

        with scopes.binding_by_name_scopes(bindings):
            model = _initializer()
            with scopes.binding_by_name_scope("model", model):
                model = _builder(model)
                model = _loader(model)

        lookup_key = checkpoint if checkpoint is not None else pretrained_model
        fisher_matrix_uuid = checkpoint_to_fisher_matrix_uuid[lookup_key]

        logging.info(f"Retrieving saved fisher matrix: {fisher_matrix_uuid}")
        with storage.retrieve_blob_as_tempfile(fisher_matrix_uuid) as f:
            logging.info(f"Loading retrieved fisher matrix: {fisher_matrix_uuid}")
            fisher_matrix = diagonal.DiagonalFisherMatrix.load(f.name)

        return MergableModel(model=model, fisher_matrix=fisher_matrix)
コード例 #15
0
ファイル: merging_execs.py プロジェクト: mmatena/m251
def merge_and_evaluate_from_checkpoints(
    checkpoints,
    tasks,
    # TODO: Describe weightings, length is independent of checkpoints and tasks.
    weightings,
    _mergeable_model,
    _model_merger,
    _evaluate_model,
    multitask_merge=False,
    additional_model_bindings=None,
):
    """Build one mergeable model per checkpoint, merge per weighting, evaluate.

    `checkpoints` and `tasks` must have equal lengths; each pair (plus any
    extra bindings) produces one mergeable model, which are then merged under
    each weighting and evaluated.
    """
    assert len(checkpoints) == len(tasks)
    # A falsy value (None or empty) means "no extra bindings for any model".
    additional_model_bindings = (
        additional_model_bindings or len(tasks) * [[]])

    mergeable_models = []
    per_model = zip(checkpoints, tasks, additional_model_bindings)
    for checkpoint, task, extra_bindings in per_model:
        model_bindings = [
            ("checkpoint", checkpoint),
            ("tasks", [task]),
            ("task", task),
        ] + list(extra_bindings)
        with scopes.binding_by_name_scopes(model_bindings):
            mergeable_models.append(_mergeable_model())

    # NOTE: Single task merge means we only care about the performance of
    # the first task.
    merge_bindings = [
        ("tasks", tasks if multitask_merge else tasks[:1]),
        ("task", tasks[0]),
    ]
    with scopes.binding_by_name_scopes(merge_bindings):
        merged_models = _model_merger(
            mergeable_models=mergeable_models, weightings=weightings)
        for merged_model, weighting in zip(merged_models, weightings):
            logging.info(f"Evaluating task weighting {weighting}")
            with scopes.binding_by_name_scope("weighting", weighting):
                _evaluate_model(merged_model)