def test_albert_google_weights_non_tfhub(self):
    """Check that the non-TFHub ``albert_base_v2`` checkpoint loads completely.

    The model directory is obtained through ``bert.fetch_google_albert_model``
    (using ``.models`` as the target folder), a model is built from the
    parameters found there, and the ``model.ckpt-best`` weights are loaded
    into the ALBERT layer. The test passes only when the loader reports no
    skipped weight/value tuples.
    """
    checkpoint_dir = bert.fetch_google_albert_model("albert_base_v2", ".models")
    checkpoint_path = os.path.join(checkpoint_dir, "model.ckpt-best")

    # build_model is expected to return the full model plus its ALBERT layer.
    model, albert_layer = self.build_model(bert.albert_params(checkpoint_dir))

    skipped = bert.load_albert_weights(albert_layer, checkpoint_path)
    self.assertEqual(0, len(skipped))

    model.summary()
def initialize_bert_tokenizer():
    """Create the tokenizer for the ``albert_base_v2`` model.

    The model directory is obtained through ``bert.fetch_google_albert_model``
    with ``.models`` as the target folder; the ``30k-clean.vocab`` and
    ``30k-clean.model`` files inside it are handed to the tokenizer.

    Returns:
        The object built by ``bert.albert_tokenization.FullTokenizer``.
    """
    albert_dir = bert.fetch_google_albert_model("albert_base_v2", ".models")

    vocab_path = os.path.join(albert_dir, "30k-clean.vocab")
    spm_path = os.path.join(albert_dir, "30k-clean.model")

    return bert.albert_tokenization.FullTokenizer(vocab_path, spm_model_file=spm_path)
def fetch_bert_layer():
    """Build the ALBERT layer for ``albert_base_v2`` and locate its checkpoint.

    The model directory is obtained through ``bert.fetch_google_albert_model``
    with ``.models`` as the target folder. The layer is only constructed from
    the model parameters here; no weights are loaded by this function.

    Returns:
        tuple: ``(l_bert, model_ckpt)`` where ``l_bert`` is the layer created
        by ``bert.BertModelLayer.from_params`` (named ``"albert"``) and
        ``model_ckpt`` is the path string ``<model_dir>/model.ckpt-best``.
    """
    albert_dir = bert.fetch_google_albert_model("albert_base_v2", ".models")
    checkpoint_path = os.path.join(albert_dir, "model.ckpt-best")

    albert_layer = bert.BertModelLayer.from_params(
        bert.albert_params(albert_dir),
        name="albert",
    )
    return albert_layer, checkpoint_path
def fetch_model(model_name):
    """Locate (or download) a pretrained BERT/ALBERT model and prepare an output dir.

    A fresh, timestamped directory is created under ``/app/models/trained``
    for the model that will be trained. The pretrained checkpoint is looked up
    under ``/app/models/pretrained``; if no matching checkpoint file is found
    there, the model is downloaded into that folder through the ``bert``
    package.

    Args:
        model_name: String. Name of the model. It is passed to
            ``validate_model`` (defined elsewhere) to determine the model type.

    Returns:
        pretrained_model_dir: String. Path to pretrained model.
        model_dir: String. Path to where the trained model is saved.
        model_type: String. Either "albert" or "bert".

    Raises:
        ValueError: If ``validate_model`` returns a type other than "albert"
            or "bert". Raised before any directory is created.
    """
    model_type = validate_model(model_name)

    pretrained_root = "/app/models/pretrained"

    # Pick the marker file whose presence means the checkpoint is already on
    # disk. ALBERT archives unpack into one extra sub-directory level, hence
    # the "*" path component.
    if model_type == "albert":
        marker_pattern = os.path.join(
            pretrained_root, model_name, "*", "model.ckpt-best.meta"
        )
    elif model_type == "bert":
        marker_pattern = os.path.join(
            pretrained_root, model_name, "bert_model.ckpt.data-00000-of-00001"
        )
    else:
        # Previously this fell through and crashed at the return statement
        # with an UnboundLocalError, after already creating the output dir.
        raise ValueError(
            "Unsupported model type {!r} for model {!r}".format(model_type, model_name)
        )

    model_dir_prefix = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    model_dir = os.path.join(
        "/app/models/trained", "{}_{}".format(model_dir_prefix, model_name)
    )
    os.makedirs(model_dir, exist_ok=True)

    try:
        # IndexError here simply means "no local copy found".
        pretrained_model_dir = os.path.dirname(glob.glob(marker_pattern)[0])
    except (OSError, IndexError):
        logger.info("Downloading model:[{}]".format(model_name))
        if model_type == "albert":
            pretrained_model_dir = bert.fetch_google_albert_model(
                model_name, pretrained_root
            )
        else:
            pretrained_model_dir = bert.fetch_google_bert_model(
                model_name, pretrained_root
            )
    else:
        logger.info("Located model at {}".format(pretrained_model_dir))

    return pretrained_model_dir, model_dir, model_type