Example #1
def setup_multitask_ptt_model(model_type, config_path, tokenizer_path,
                              task_dict):
    model_arch = ModelArchitectures.from_model_type(model_type)
    assert ModelArchitectures.is_ptt_model_arch(model_arch)

    # 1. Retrieve class specs
    model_class_spec_dict = {}
    for task_name, task in task_dict.items():
        model_class_spec_dict[task_name] = model_resolution.resolve_model_setup_classes(
            model_type=model_type,
            task_type=task.TASK_TYPE,
        )

    # 2. Get tokenizer
    tokenizer_class_list = [
        model_class_spec.tokenizer_class
        for model_class_spec in model_class_spec_dict.values()
    ]
    # All tasks must resolve to the same tokenizer class: deduplicate, then
    # pyutils.take_one (presumably) asserts that exactly one class remains.
    tokenizer_class_list = list(set(tokenizer_class_list))
    tokenizer = model_setup.get_tokenizer(
        model_type=model_type,
        tokenizer_class=pyutils.take_one(tokenizer_class_list),
        tokenizer_path=tokenizer_path,
    )

    # 3. Get model
    shared_ptt_encoder = None
    model_dict = {}
    for task_name, task in task_dict.items():
        task_model = model_setup.get_model(
            model_class_spec=model_class_spec_dict[task_name],
            config_path=config_path,
            task=task,
        )
        encoder = get_ptt_encoder(task_model)
        if shared_ptt_encoder is None:
            # First task seen: adopt its encoder as the shared encoder.
            shared_ptt_encoder = encoder
        else:
            # Later tasks: discard their own encoder and point them at the
            # shared one, so every task head sits on a single transformer.
            set_ptt_encoder(task_model, shared_ptt_encoder)
        model_dict[task_name] = task_model

    multitask_model = multitask_modeling.MultiTaskModel(
        model_dict=model_dict,
        shared_ptt_encoder=shared_ptt_encoder,
    )

    return model_setup.ModelWrapper(
        model=multitask_model,
        tokenizer=tokenizer,
    )
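A hedged usage sketch: the task names, task objects, and paths below are illustrative rather than from the source, and each task object is only assumed to carry the TASK_TYPE attribute that resolve_model_setup_classes consumes.

model_wrapper = setup_multitask_ptt_model(
    model_type="bert-base-uncased",
    config_path="/path/to/bert_config.json",
    tokenizer_path="/path/to/vocab.txt",
    task_dict={"mnli": mnli_task, "qqp": qqp_task},  # hypothetical task objects
)
model = model_wrapper.model          # MultiTaskModel with one shared encoder
tokenizer = model_wrapper.tokenizer  # single tokenizer shared across tasks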
Example #2
def simple_load_model(model, state_dict, model_load_mode, verbose=True):
    if model_load_mode == "strict":
        model.load_state_dict(state_dict)
    elif model_load_mode == "safe":
        if ModelArchitectures.from_ptt_model(model) == ModelArchitectures.ALBERT:
            # TODO: add safer check for ALBERT models
            safe_load_model(
                model=model,
                state_dict=state_dict,
                verbose=verbose,
                max_miss_fraction=0.66,
            )
        else:
            safe_load_model(
                model=model,
                state_dict=state_dict,
                verbose=verbose,
            )
    elif model_load_mode == "base_weights":
        model.load_state_dict(
            load_model_base_weights(
                model=model,
                state_dict=state_dict,
            ))
    elif model_load_mode == "no_load":
        pass
    else:
        raise KeyError(model_load_mode)
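For context, a minimal loading sketch, assuming a plain PyTorch checkpoint on disk (the path is hypothetical):

import torch

state_dict = torch.load("/path/to/checkpoint.p", map_location="cpu")
# "strict" requires an exact key match; "safe" tolerates partial mismatches
# (with a looser threshold for ALBERT); "base_weights" keeps only the encoder
# weights; "no_load" skips loading entirely.
simple_load_model(model, state_dict, model_load_mode="safe")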
Example #3
def get_ptt_model_embedding_dim(ptt_model):
    model_arch = ModelArchitectures.from_ptt_model(ptt_model)
    if model_arch in (ModelArchitectures.BERT, ModelArchitectures.XLNET,
                      ModelArchitectures.XLM, ModelArchitectures.ROBERTA):
        return ptt_model.config.hidden_size
    elif model_arch == ModelArchitectures.GLOVE_LSTM:
        return ptt_model.model.hidden_dim
    else:
        raise KeyError(model_arch)
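For reference, a one-line usage sketch (bert_model is a placeholder for a loaded BERT model):

embedding_dim = get_ptt_model_embedding_dim(bert_model)  # e.g. 768 for bert-base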
Example #4
def load_model_base_weights(model, state_dict):
    arch = ModelArchitectures.from_ptt_model(model)
    if arch == ModelArchitectures.BERT:
        # Keep only the base encoder weights; drop the task-specific
        # "classifier." head so it can be re-initialized for a new task.
        new_state_dict = {
            k: v
            for k, v in state_dict.items()
            if not k.startswith("classifier.")
        }
    else:
        raise NotImplementedError()
    return new_state_dict
Example #5
def simple_model_setup(model_type, model_class_spec, config_path,
                       tokenizer_path, task):
    model_arch = ModelArchitectures.from_model_type(model_type)
    if ModelArchitectures.is_ptt_model_arch(model_arch):
        return simple_ptt_model_setup(
            model_type=model_type,
            model_class_spec=model_class_spec,
            config_path=config_path,
            tokenizer_path=tokenizer_path,
            task=task,
        )
    elif model_arch in [ModelArchitectures.GLOVE_LSTM]:
        return glove_lstm_setup(
            config_path=config_path,
            tokenizer_path=tokenizer_path,
            task=task,
        )
    else:
        raise KeyError(model_arch)
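A hedged single-task counterpart to the sketch under Example #1 (paths and the task object are illustrative):

model_wrapper = simple_model_setup(
    model_type="roberta-base",
    model_class_spec=model_resolution.resolve_model_setup_classes(
        model_type="roberta-base",
        task_type=task.TASK_TYPE,
    ),
    config_path="/path/to/config.json",
    tokenizer_path="/path/to/tokenizer",
    task=task,  # hypothetical task object
)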
Example #6
import torch.nn as nn  # assumed import; the snippet uses nn.Linear and nn.Dropout


class PttEmbeddingModule(nn.Module):
    # NOTE: hypothetical enclosing class; the source shows only this __init__.
    def __init__(self, ptt_model, embedding_dim, dropout_p=0.5):
        super().__init__()
        self.ptt_model = ptt_model
        self.embedding_dim = embedding_dim
        self.dropout_p = dropout_p

        self.model_arch = ModelArchitectures.from_ptt_model(ptt_model)
        # Project the architecture-specific hidden size (see Example #3) down
        # to a fixed embedding_dim.
        self.embedding_layer = nn.Linear(
            get_ptt_model_embedding_dim(ptt_model),
            embedding_dim,
        )
        self.dropout = nn.Dropout(p=dropout_p)
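Construction is then a one-liner (hedged; the class name is the assumed one above and bert_model is a placeholder):

embedder = PttEmbeddingModule(ptt_model=bert_model, embedding_dim=256, dropout_p=0.1)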
Example #7
def _get_ptt_encoder_attr(ptt_model):
    model_arch = ModelArchitectures.from_ptt_model(ptt_model)
    # Will probably need to refactor this out later
    if model_arch == ModelArchitectures.BERT:
        return "bert"
    elif model_arch in (ModelArchitectures.XLNET, ModelArchitectures.XLM):
        # Both XLNet and XLM expose their encoder as `.transformer`.
        return "transformer"
    elif model_arch == ModelArchitectures.ROBERTA:
        return "roberta"
    else:
        raise KeyError(model_arch)
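The get_ptt_encoder / set_ptt_encoder helpers used in Example #1 are not shown in the source; a plausible sketch of how they would consume this attribute name, assuming plain attribute access:

def get_ptt_encoder(ptt_model):
    # Fetch the transformer encoder submodule by its architecture-specific name.
    return getattr(ptt_model, _get_ptt_encoder_attr(ptt_model))


def set_ptt_encoder(ptt_model, encoder):
    # Replace a task model's encoder so every task head shares one instance.
    setattr(ptt_model, _get_ptt_encoder_attr(ptt_model), encoder)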
Example #8
def get_tokenizer(model_type, tokenizer_class, tokenizer_path):
    model_arch = ModelArchitectures.from_model_type(model_type)
    if model_arch in [ModelArchitectures.BERT]:
        # BERT tokenization must match the casing convention baked into the
        # pretrained checkpoint name.
        if "-cased" in model_type:
            do_lower_case = False
        elif "-uncased" in model_type:
            do_lower_case = True
        else:
            raise RuntimeError(model_type)
    elif model_arch in [
            ModelArchitectures.XLNET,
            ModelArchitectures.XLM,
            ModelArchitectures.ROBERTA,
            ModelArchitectures.ALBERT,
    ]:
        # These tokenizers are used without lower-casing.
        do_lower_case = False
    else:
        raise RuntimeError(model_type)
    tokenizer = tokenizer_class.from_pretrained(
        tokenizer_path,
        do_lower_case=do_lower_case,
    )
    return tokenizer
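An illustrative call, assuming the pytorch-transformers BertTokenizer (the vocab path is hypothetical); an "-uncased" model type resolves to do_lower_case=True:

from pytorch_transformers import BertTokenizer

tokenizer = get_tokenizer(
    model_type="bert-base-uncased",
    tokenizer_class=BertTokenizer,
    tokenizer_path="/path/to/uncased-vocab.txt",
)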