コード例 #1
0
    def test_albert_google_weights_non_tfhub(self):
        # Fetch the "albert_base_v2" model into ".models" and derive both the
        # layer parameters and the checkpoint path from the returned directory.
        model_name = "albert_base_v2"
        fetched_dir = bert.fetch_google_albert_model(model_name, ".models")
        params = bert.albert_params(fetched_dir)
        checkpoint_path = os.path.join(fetched_dir, "model.ckpt-best")

        model, albert_layer = self.build_model(params)

        # Loading the checkpoint into the layer must not skip any weights.
        skipped = bert.load_albert_weights(albert_layer, checkpoint_path)
        self.assertEqual(0, len(skipped))
        model.summary()
def initialize_bert_tokenizer():
    """
    Build the tokenizer for the "albert_base_v2" model.

    The model is fetched into the ".models" directory with
    bert.fetch_google_albert_model; the "30k-clean.vocab" and
    "30k-clean.model" files inside the returned directory are handed to the
    tokenizer as the vocabulary and the spm model file respectively.

    Returns:
        The bert.albert_tokenization.FullTokenizer instance.
    """
    albert_dir = bert.fetch_google_albert_model("albert_base_v2", ".models")
    spm_path = os.path.join(albert_dir, "30k-clean.model")
    vocab_path = os.path.join(albert_dir, "30k-clean.vocab")
    return bert.albert_tokenization.FullTokenizer(
        vocab_path, spm_model_file=spm_path
    )
コード例 #3
0
def fetch_bert_layer():
    """
    Create the ALBERT layer for "albert_base_v2" and locate its checkpoint.

    The model is fetched into the ".models" directory with
    bert.fetch_google_albert_model. The layer is only constructed from the
    model parameters here; no weights are loaded by this function.

    Returns:
        tuple: (layer, checkpoint_path) where
            layer is the object returned by bert.BertModelLayer.from_params
                (created with name="albert"), and
            checkpoint_path (str) is "<model dir>/model.ckpt-best".
    """
    albert_dir = bert.fetch_google_albert_model("albert_base_v2", ".models")
    checkpoint_path = os.path.join(albert_dir, "model.ckpt-best")
    albert_layer = bert.BertModelLayer.from_params(
        bert.albert_params(albert_dir), name="albert"
    )
    return albert_layer, checkpoint_path
コード例 #4
0
ファイル: loader.py プロジェクト: xlzwhboy/bert_in_a_flask
def fetch_model(
    model_name,
    pretrained_root="/app/models/pretrained",
    trained_root="/app/models/trained",
):
    """Locate (or download) a pretrained BERT/ALBERT model and create an output dir.

    A fresh, timestamp-prefixed directory for the trained model is created
    under ``trained_root``. The pretrained checkpoint is then looked up under
    ``pretrained_root``; if no checkpoint file is found there, the model is
    downloaded into ``pretrained_root`` via the ``bert`` package.

    Args:
        model_name: String. Name of the model. See supported models at the top.
        pretrained_root: String. Directory that holds (or will receive) the
            pretrained models. Defaults to the previously hard-coded location.
        trained_root: String. Directory under which the per-run output
            directory is created. Defaults to the previously hard-coded
            location.

    Returns:
        pretrained_model_dir: String. Path to pretrained model.
        model_dir: String. Path to where the trained model is saved.
        model_type: String. Either "albert" or "bert".

    Raises:
        ValueError: If ``validate_model`` reports a model type other than
            "albert" or "bert". (Previously this surfaced as an
            UnboundLocalError at the return statement, after the output
            directory had already been created.)
    """
    model_type = validate_model(model_name)

    # Pick the checkpoint file that marks an already-downloaded model.
    # ALBERT archives unpack into an extra sub-directory, hence the "*".
    if model_type == "albert":
        ckpt_pattern = os.path.join(
            pretrained_root, model_name, "*", "model.ckpt-best.meta"
        )
    elif model_type == "bert":
        ckpt_pattern = os.path.join(
            pretrained_root, model_name, "bert_model.ckpt.data-00000-of-00001"
        )
    else:
        # Fail before touching the filesystem instead of crashing later.
        raise ValueError(
            "Unsupported model type {!r} for model {!r}".format(model_type, model_name)
        )

    model_dir_prefix = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    model_dir = os.path.join(trained_root, "{}_{}".format(model_dir_prefix, model_name))
    os.makedirs(model_dir, exist_ok=True)

    try:
        matches = glob.glob(ckpt_pattern)
    except OSError:
        # Keep the original best-effort behaviour: a failed lookup just
        # falls through to downloading.
        matches = []

    if matches:
        pretrained_model_dir = os.path.dirname(matches[0])
        logger.info("Located model at {}".format(pretrained_model_dir))
    else:
        logger.info("Downloading model:[{}]".format(model_name))
        if model_type == "albert":
            download = bert.fetch_google_albert_model
        else:
            download = bert.fetch_google_bert_model
        pretrained_model_dir = download(model_name, pretrained_root)

    return pretrained_model_dir, model_dir, model_type