Example #1
0
def save_bert(model, optimizer, args, save_path, save_mode="all", verbose=True):
    """Serialize a BERT model (and optionally its optimizer state) to disk.

    Args:
        model: the model to save; a DataParallel/DDP wrapper is unwrapped
            via its ``.module`` attribute before saving.
        optimizer: optimizer whose state is saved for the "all"/"tunable"
            modes; may be None, in which case None is stored.
        args: run configuration; saved as ``vars(args)``.
        save_path: destination path passed to ``torch.save``.
        save_mode: one of "all", "tunable", "model_all", "model_tunable".
            The "*_tunable" modes save only the tunable parameter subset
            (via ``get_tunable_state_dict``); the "model_*" modes skip the
            optimizer state.
        verbose: if True, print element counts as they are saved.

    Raises:
        KeyError: if save_mode is not one of the allowed values (defensive;
            the assert above already enforces this).
    """
    assert save_mode in [
        "all", "tunable", "model_all", "model_tunable",
    ]

    save_dict = dict()

    # Save args
    save_dict["args"] = vars(args)

    # Save model
    # Unwrap (Distributed)DataParallel so only the bare model is saved.
    model_to_save = model.module if hasattr(model, 'module') else model
    if save_mode in ["all", "model_all"]:
        model_state_dict = model_to_save.state_dict()
    elif save_mode in ["tunable", "model_tunable"]:
        model_state_dict = get_tunable_state_dict(model_to_save)
    else:
        raise KeyError(save_mode)
    if verbose:
        print("Saving {} model elems:".format(len(model_state_dict)))
    save_dict["model"] = utils.to_cpu(model_state_dict)

    # Save optimizer
    if save_mode in ["all", "tunable"]:
        optimizer_state_dict = utils.to_cpu(optimizer.state_dict()) if optimizer is not None else None
        if verbose:
            # len(None) would raise TypeError when optimizer is None, so
            # report 0 elements in that case.
            n_elems = len(optimizer_state_dict) if optimizer_state_dict is not None else 0
            print("Saving {} optimizer elems:".format(n_elems))
        # BUG FIX: the optimizer state was computed but never written into
        # save_dict, so it was silently dropped from the checkpoint.
        save_dict["optimizer"] = optimizer_state_dict

    torch.save(save_dict, save_path)
Example #2
0
def save_bert(glue_lm_model,
              optimizer,
              args,
              save_path,
              save_mode="all",
              verbose=True):
    """Serialize a combined GLUE + LM model (and optionally its optimizer).

    Args:
        glue_lm_model: composite model exposing ``.glue_model`` and
            ``.lm_model`` attributes; both sub-models are saved.
        optimizer: optimizer whose state is saved for the "all"/"tunable"
            modes; may be None, in which case None is stored.
        args: run configuration; saved as ``vars(args)``.
        save_path: destination path passed to ``torch.save``.
        save_mode: one of "all", "tunable", "model_all", "model_tunable".
            The "*_tunable" modes save only tunable parameters; the
            "model_*" modes skip the optimizer state.
        verbose: if True, print element counts as they are saved.

    Raises:
        KeyError: if save_mode is not one of the allowed values (defensive;
            the assert already enforces this).
    """
    glue_model = glue_lm_model.glue_model
    lm_model = glue_lm_model.lm_model
    assert save_mode in [
        "all",
        "tunable",
        "model_all",
        "model_tunable",
    ]

    save_dict = dict()

    # Save args
    save_dict["args"] = vars(args)

    # Save model
    # Unwrap (Distributed)DataParallel so only the bare model is saved.
    glue_model_to_save = glue_model.module \
        if hasattr(glue_model, 'module') \
        else glue_model
    if save_mode in ["all", "model_all"]:
        glue_model_state_dict = glue_model_to_save.state_dict()
        lm_state_dict = get_lm_cls_state_dict(lm_model)
    elif save_mode in ["tunable", "model_tunable"]:
        glue_model_state_dict = get_tunable_state_dict(glue_model_to_save)
        lm_state_dict = get_tunable_state_dict(get_lm_cls_state_dict(lm_model))
    else:
        raise KeyError(save_mode)
    if verbose:
        print("Saving {} glue model elems:".format(len(glue_model_state_dict)))
        print("Saving {} lm model elems:".format(len(lm_state_dict)))
    save_dict["model"] = utils.to_cpu(glue_model_state_dict)
    save_dict["lm_model"] = utils.to_cpu(lm_state_dict)

    # Save optimizer
    if save_mode in ["all", "tunable"]:
        optimizer_state_dict = utils.to_cpu(
            optimizer.state_dict()) if optimizer is not None else None
        if verbose:
            # len(None) would raise TypeError when optimizer is None, so
            # report 0 elements in that case.
            n_elems = (len(optimizer_state_dict)
                       if optimizer_state_dict is not None else 0)
            print("Saving {} optimizer elems:".format(n_elems))
        # BUG FIX: the optimizer state was computed but never written into
        # save_dict, so it was silently dropped from the checkpoint.
        save_dict["optimizer"] = optimizer_state_dict

    torch.save(save_dict, save_path)
Example #3
0
    def save_best_model(self, save_mode="model_all", verbose=True):
        """Save the current best classification + LM model state to disk.

        Writes a dict with "model", "lm_model" (and, for "all"/"tunable"
        modes, "optimizer") entries to ``<self.output_path>/all_state.p``.

        Args:
            save_mode: one of "all", "tunable", "model_all",
                "model_tunable". The tunable variants are not implemented.
            verbose: if True, print element counts as they are saved.

        Raises:
            NotImplementedError: for "tunable"/"model_tunable" modes.
            KeyError: if save_mode is not an allowed value (defensive; the
                assert already enforces this).
        """
        classification_model = self.classification_lm_model.classification_model
        lm_model = self.classification_lm_model.lm_model
        assert save_mode in [
            "all",
            "tunable",
            "model_all",
            "model_tunable",
        ]
        save_dict = dict()
        # Save model
        # Unwrap (Distributed)DataParallel so only the bare model is saved.
        classification_model_to_save = classification_model.module \
            if hasattr(classification_model, 'module') \
            else classification_model
        if save_mode in ["all", "model_all"]:
            classification_model_state_dict = classification_model_to_save.state_dict(
            )
            lm_state_dict = lm_model.state_dict()
        elif save_mode in ["tunable", "model_tunable"]:
            raise NotImplementedError
        else:
            raise KeyError(save_mode)
        if verbose:
            print("Saving {} classification model elems:".format(
                len(classification_model_state_dict)))
            print("Saving {} lm model elems:".format(len(lm_state_dict)))
        save_dict["model"] = utils.to_cpu(classification_model_state_dict)
        save_dict["lm_model"] = utils.to_cpu(lm_state_dict)

        # Save optimizer
        if save_mode in ["all", "tunable"]:
            optimizer_state_dict = utils.to_cpu(self.optimizer.state_dict(
            )) if self.optimizer is not None else None
            if verbose:
                # len(None) would raise TypeError when self.optimizer is
                # None, so report 0 elements in that case.
                n_elems = (len(optimizer_state_dict)
                           if optimizer_state_dict is not None else 0)
                print("Saving {} optimizer elems:".format(n_elems))
            # BUG FIX: the optimizer state was computed but never written
            # into save_dict, so it was silently dropped from the checkpoint.
            save_dict["optimizer"] = optimizer_state_dict

        torch.save(save_dict, os.path.join(self.output_path, "all_state.p"))