Ejemplo n.º 1
0
def load_pretrained_model(model_name,
                          config,
                          cache_dir,
                          custom_model_class,
                          if_tf_model=False):
    """Load a pretrained model, dispatching on the resolved model class name.

    The class name is obtained from
    ``modelclass_dispatcher(model_name, custom_model_class)``. The name
    "GPT2ModelNoPastState" is loaded through the class of that name; any
    other name is looked up as an attribute of the ``transformers`` package,
    prefixed with "TF" when ``if_tf_model`` is true.

    NOTE: the flag is spelled ``if_tf_model`` (not ``is_tf_model``); the
    spelling is kept so that existing keyword callers keep working.

    Args:
        model_name: Forwarded to ``modelclass_dispatcher`` and passed as the
            first argument to ``from_pretrained``.
        config: Forwarded to ``from_pretrained`` as ``config=``.
        cache_dir: Forwarded to ``from_pretrained`` as ``cache_dir=``.
        custom_model_class: Forwarded to ``modelclass_dispatcher``.
        if_tf_model: When true, the "TF"-prefixed class is requested.

    Returns:
        Whatever the selected class's ``from_pretrained`` returns.

    Raises:
        NotImplementedError: If ``if_tf_model`` is true and the resolved
            class name is "GPT2ModelNoPastState".
    """
    model_class_name = modelclass_dispatcher(model_name, custom_model_class)

    if model_class_name == "GPT2ModelNoPastState":
        # Fail loudly instead of silently handing back the PyTorch model
        # when the caller asked for a TensorFlow one.
        if if_tf_model:
            raise NotImplementedError(
                "TFGPT2ModelNoPastState is currently not supported.")
        return GPT2ModelNoPastState.from_pretrained(model_name,
                                                    config=config,
                                                    cache_dir=cache_dir)

    if if_tf_model:
        model_class_name = 'TF' + model_class_name

    # Resolve the class by name from the transformers package at call time.
    transformers_module = __import__("transformers",
                                     fromlist=[model_class_name])
    model_class = getattr(transformers_module, model_class_name)

    return model_class.from_pretrained(model_name,
                                       config=config,
                                       cache_dir=cache_dir)
Ejemplo n.º 2
0
def load_pretrained_model(model_name,
                          config,
                          cache_dir,
                          custom_model_class,
                          is_tf_model=False):
    """Load a pretrained model for ``model_name``.

    The class name returned by
    ``modelclass_dispatcher(model_name, custom_model_class)`` selects the
    loader. "GPT2ModelNoPastState" maps to ``GPT2ModelNoPastState`` (or
    ``TFGPT2ModelNoPastState`` when ``is_tf_model`` is true). Any other name
    is fetched as an attribute of the ``transformers`` package, with a "TF"
    prefix when ``is_tf_model`` is true.

    Args:
        model_name: Passed to the dispatcher and to ``from_pretrained``.
        config: Forwarded to ``from_pretrained`` as ``config=``.
        cache_dir: Forwarded to ``from_pretrained`` as ``cache_dir=``.
        custom_model_class: Forwarded to ``modelclass_dispatcher``.
        is_tf_model: Selects the "TF"-prefixed class when true.

    Returns:
        The result of the selected class's ``from_pretrained`` call.
    """
    load_kwargs = {"config": config, "cache_dir": cache_dir}
    resolved_name = modelclass_dispatcher(model_name, custom_model_class)

    if resolved_name == "GPT2ModelNoPastState":
        # Only the selected wrapper name is evaluated.
        wrapper_cls = (TFGPT2ModelNoPastState
                       if is_tf_model else GPT2ModelNoPastState)
        return wrapper_cls.from_pretrained(model_name, **load_kwargs)

    target_name = 'TF' + resolved_name if is_tf_model else resolved_name

    hf_package = __import__("transformers", fromlist=[target_name])
    logger.info(f"Model class name: {target_name}")
    return getattr(hf_package, target_name).from_pretrained(model_name,
                                                            **load_kwargs)
Ejemplo n.º 3
0
def load_pretrained_model(model_name, config, cache_dir):
    """Load a pretrained model for ``model_name``.

    The names "gpt2", "distilgpt2" and "gpt2-large" are loaded through
    ``GPT2ModelNoPastState``; every other name goes through ``AutoModel``.
    ``config`` and ``cache_dir`` are forwarded to ``from_pretrained``.
    """
    gpt2_variants = ("gpt2", "distilgpt2", "gpt2-large")
    # Only the chosen loader name is evaluated.
    loader = GPT2ModelNoPastState if model_name in gpt2_variants else AutoModel
    return loader.from_pretrained(model_name,
                                  config=config,
                                  cache_dir=cache_dir)
Ejemplo n.º 4
0
def load_pretrained_model(model_name, config, cache_dir):
    """Load a pretrained model for ``model_name``.

    Names contained in ``PRETRAINED_GPT2_MODELS`` are loaded through
    ``GPT2ModelNoPastState``; all others go through ``AutoModel``.
    ``config`` and ``cache_dir`` are forwarded to ``from_pretrained``.
    """
    if model_name not in PRETRAINED_GPT2_MODELS:
        model_cls = AutoModel
    else:
        model_cls = GPT2ModelNoPastState
    return model_cls.from_pretrained(model_name,
                                     config=config,
                                     cache_dir=cache_dir)