Пример #1
0
    def __init__(self, config_dir="./bert-for-tf2", config_name="config.json"):
        """Load the JSON config and class map, then derive every project path.

        Config keys are installed via item assignment, so the enclosing class
        is expected to support ``self[key] = value`` (dict-like).  The config
        must supply at least ``drive_path``, ``train_name`` and ``model_name``.
        """
        config_file = os.path.join(config_dir, config_name)
        with open(config_file, 'r') as cfg:
            for key, value in json.load(cfg).items():
                self[key] = value

        # Label-name mapping lives next to the main config.
        with open(os.path.join(config_dir, 'class.json'), 'r') as cls_file:
            self.classes = json.load(cls_file)

        # Directory layout rooted at the config-provided drive path.
        self.project_path = os.path.join(self.drive_path, "bert_sentiment")
        self.bert_model_path = os.path.join(self.project_path, "bert_model")
        self.data_path = os.path.join(self.project_path, "data")
        self.epoch_log_path = os.path.join(
            self.project_path, "epoch_logs", self.train_name)
        self.epoch_model_path = os.path.join(
            self.project_path, "epoch_models", self.train_name)
        self.tb_path = os.path.join(self.project_path, "logs", self.train_name)

        # Download (or reuse a cached copy of) the Google checkpoint.
        self.model_dir = bert.fetch_google_bert_model(self.model_name, ".model")
        self.model_ckpt = os.path.join(
            self.bert_model_path, self.model_dir, "bert_model.ckpt")

        # Fixed multilingual cased checkpoint used for fine-tuning.
        self.bert_ckpt_dir = os.path.join(
            self.bert_model_path, "multi_cased_L-12_H-768_A-12")
        self.bert_ckpt_file = os.path.join(self.bert_ckpt_dir, "bert_model.ckpt")
        self.bert_config_file = os.path.join(
            self.bert_ckpt_dir, "bert_config.json")
Пример #2
0
 def setUp(self) -> None:
     """Fetch the multilingual BERT checkpoint and record its key file paths."""
     self.bert_name = "multilingual_L-12_H-768_A-12"
     ckpt_dir = bert.fetch_google_bert_model(self.bert_name, fetch_dir=".models")
     self.bert_ckpt_dir = ckpt_dir
     self.bert_ckpt_file = os.path.join(ckpt_dir, "bert_model.ckpt")
     self.bert_config_file = os.path.join(ckpt_dir, "bert_config.json")
Пример #3
0
def init_bert_tokenizer():
    """Download the multilingual cased BERT model and build its FullTokenizer.

    Returns:
        A ``bert.bert_tokenization.FullTokenizer`` over the model's vocab.
    """
    model_name = "multi_cased_L-12_H-768_A-12"
    model_dir = bert.fetch_google_bert_model(model_name, ".model")
    model_ckpt = os.path.join(model_dir, "bert_model.ckpt")

    # Checkpoints whose names start with "cased"/"multi_cased" preserve case,
    # so lower-casing is only applied for the uncased variants.
    starts_cased = (model_name.find("cased") == 0
                    or model_name.find("multi_cased") == 0)
    do_lower_case = not starts_cased

    # Sanity-check the casing flag against the checkpoint name.
    bert.bert_tokenization.validate_case_matches_checkpoint(
        do_lower_case, model_ckpt)

    vocab_file = os.path.join(model_dir, "vocab.txt")
    return bert.bert_tokenization.FullTokenizer(vocab_file, do_lower_case)
    def test_bert_google_weights(self):
        """Weights from the Google checkpoint must load with zero skips."""
        model_name = "uncased_L-12_H-768_A-12"
        ckpt_dir = bert.fetch_google_bert_model(model_name, ".models")
        ckpt_file = os.path.join(ckpt_dir, "bert_model.ckpt")

        params = bert.params_from_pretrained_ckpt(ckpt_dir)
        model, l_bert = self.build_model(params)

        # Every checkpoint weight should map onto the layer.
        skipped = bert.load_bert_weights(l_bert, ckpt_file)
        self.assertEqual(0, len(skipped))
        model.summary()
Пример #5
0
def fetch_model(model_name):
    """Downloads BERT/ALBERT models and outputs their location.

    Args:
        model_name: String. Name of the model. See supported models at the top.

    Returns:
        pretrained_model_dir: String. Path to pretrained model.
        model_dir: String. Path to where the trained model is saved.
        model_type: String. Either "albert" or "bert".

    Raises:
        ValueError: If validate_model returns a type other than
            "albert"/"bert" (the original code hit UnboundLocalError here).
    """
    model_type = validate_model(model_name)

    # Fresh, timestamp-prefixed output directory for this training run.
    model_dir_prefix = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    model_dir = os.path.join("/app/models/trained",
                             "{}_{}".format(model_dir_prefix, model_name))
    os.makedirs(model_dir, exist_ok=True)

    # The two branches differ only in the cache-probe glob and the fetcher.
    if model_type == "albert":
        pretrained_model_dir = _locate_or_fetch(
            model_name,
            os.path.join("/app/models/pretrained", model_name, "*",
                         "model.ckpt-best.meta"),
            bert.fetch_google_albert_model,
        )
    elif model_type == "bert":
        pretrained_model_dir = _locate_or_fetch(
            model_name,
            os.path.join("/app/models/pretrained", model_name,
                         "bert_model.ckpt.data-00000-of-00001"),
            bert.fetch_google_bert_model,
        )
    else:
        raise ValueError("Unsupported model type: {!r}".format(model_type))

    return pretrained_model_dir, model_dir, model_type


def _locate_or_fetch(model_name, ckpt_glob, fetch_fn):
    """Return the pretrained model directory, downloading it if not cached.

    Args:
        model_name: String. Model to look up or download.
        ckpt_glob: String. Glob pattern that matches a checkpoint file iff
            the model is already cached under /app/models/pretrained.
        fetch_fn: Callable (model_name, fetch_dir) -> dir used on cache miss.
    """
    try:
        # IndexError on an empty glob result means "not cached yet".
        model_file = glob.glob(ckpt_glob)[0]
        pretrained_model_dir = os.path.dirname(model_file)
        logger.info("Located model at {}".format(pretrained_model_dir))
    except (OSError, IndexError):
        logger.info("Downloading model:[{}]".format(model_name))
        pretrained_model_dir = fetch_fn(model_name, "/app/models/pretrained")
    return pretrained_model_dir
Пример #6
0
def get_model_dir(model_name):
    """Download (if needed) the named Google BERT model; return its directory."""
    fetch_dir = ".models"
    return bert.fetch_google_bert_model(model_name, fetch_dir)
Пример #7
0
 def test_coverage(self):
     """Exercise the failure path of fetch_google_bert_model for coverage.

     The call is expected to fail for a non-existent model name; any
     Exception is deliberately swallowed.  A bare ``except:`` was narrowed
     to ``except Exception:`` so KeyboardInterrupt/SystemExit still propagate.
     """
     try:
         bert.fetch_google_bert_model("not-existent_bert_model", ".models")
     except Exception:
         pass
Пример #8
0
# Paths and hyper-parameters for fine-tuning BERT on the SMS spam dataset.
samples = './data/SMSSpam'
model_name = "uncased_L-12_H-768_A-12"
model_folder = "bert_models"

# Output file name for the adapter-fine-tuned model.
adapted_model = 'adapted' + model_name + '.h5'
output_folder = "."

batch_size = 16
max_seq_len = 100
adapter_size = 4  # bottleneck size for BERT adapter layers
total_epoch_count = 2
buffer_size = 47  # shuffle buffer size — presumably tied to the dataset; TODO confirm
expected_number_spam = 747  # expected count of spam rows — verify against the corpus
# Download (or reuse a cached copy of) the pre-trained checkpoint at import time.
print(f"\n fetching google pre-trained BERT model {model_name}")
model_dir = fetch_google_bert_model(model_name, model_folder)

bert_ckpt_file = join(model_dir, 'bert_model.ckpt')

def parse_raw_to_csv(raw_file='SMSSpam'):
	print(f"\n parsing {raw_file} into SMSSpam.csv ...")
	infile = open(raw_file, 'r')
	outfile = open('./data/SMSSpam.csv', 'w', newline='')
	columns = ['label', 'feature']
	spamwriter = csv.DictWriter(outfile, fieldnames=columns)
	spamwriter.writeheader()
	row = dict().fromkeys(columns)
	for line in infile:
		words = line.split()
		row['label'] = words[0]
Пример #9
0
import bert, os
import tensorflow as tf
import tensorflow_datasets as tfds

from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Input, Lambda
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy

# Maximum sequence length fed to BERT (includes the [CLS] and [SEP] tokens).
MAX_LEN = 512

# Get BERT and tokenizer
# Download the uncased base checkpoint and load its hyper-parameters.
bert_model_dir = bert.fetch_google_bert_model("uncased_L-12_H-768_A-12",
                                              "bert_models")
bert_model_ckpt = os.path.join(bert_model_dir, "bert_model.ckpt")
bert_params = bert.params_from_pretrained_ckpt(bert_model_dir)

# Tokenizer over the checkpoint's vocab; True -> lower-case (uncased model).
vocab_file = os.path.join(bert_model_dir, "vocab.txt")
tokenizer = bert.bert_tokenization.FullTokenizer(vocab_file, True)

# Data pipeline
# IMDB reviews as (text, label) pairs; the 'test' split serves as validation.
train, valid = tfds.load('imdb_reviews',
                         split=['train', 'test'],
                         as_supervised=True)


def tokenize(text, label):
    def _tokenize(text, label):
        tokens = tokenizer.tokenize(text.numpy())[:MAX_LEN - 2]
        tokens = ['[CLS]'] + tokens + ['[SEP]']
        token_ids = tokenizer.convert_tokens_to_ids(tokens)