Example 1
 def __init__(self, config: BertConfig, bert_name: str):
     super(SpatialContinuousBert, self).__init__()
     self.cliparts_embeddings = nn.Embedding(
         num_embeddings=config.vocab_size,
         embedding_dim=config.hidden_size,
         padding_idx=0,
     )
     self.x_embeddings = nn.Embedding(
         num_embeddings=X_PAD + 1,
         embedding_dim=config.hidden_size,
         padding_idx=X_PAD,
     )
     self.y_embeddings = nn.Embedding(
         num_embeddings=Y_PAD + 1,
         embedding_dim=config.hidden_size,
         padding_idx=Y_PAD,
     )
     self.o_embeddings = nn.Embedding(
         num_embeddings=O_PAD + 1,
         embedding_dim=config.hidden_size,
         padding_idx=O_PAD,
     )
     self.pos_layer_norm = torch.nn.LayerNorm(
         config.hidden_size, eps=config.layer_norm_eps
     )
     self.pos_dropout = nn.Dropout(config.hidden_dropout_prob)
     self.bert = BertModel.from_pretrained(bert_name)
     # Reuse the same config for the prediction heads below; only vocab_size changes
     config.vocab_size = 2
     self.xy_head = BertOnlyMLMHead(config)
     config.vocab_size = O_PAD + 1
     self.o_head = BertOnlyMLMHead(config)
     self.log_softmax = nn.LogSoftmax(dim=-1)
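
A note on the two vocab_size reassignments above: BertOnlyMLMHead ends in a linear decoder whose output width is config.vocab_size, so overwriting that field right before constructing each head yields a 2-way head for the x/y predictions and an (O_PAD + 1)-way head for the orientation. A minimal standalone check of that behaviour (illustrative only; it assumes transformers is installed and uses 8 as a stand-in for O_PAD + 1):

from transformers import BertConfig
from transformers.models.bert.modeling_bert import BertOnlyMLMHead

cfg = BertConfig()
cfg.vocab_size = 2                 # temporarily shrink the "vocabulary"
xy_head = BertOnlyMLMHead(cfg)     # head now emits 2 logits per position
print(xy_head.predictions.decoder.out_features)  # 2

cfg.vocab_size = 8                 # stand-in for O_PAD + 1
o_head = BertOnlyMLMHead(cfg)
print(o_head.predictions.decoder.out_features)   # 8
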
Example 2
def get_model(vocab_size=30000):
    config_encoder = BertConfig()
    config_decoder = BertConfig()

    config_encoder.vocab_size = vocab_size
    config_decoder.vocab_size = vocab_size

    config_decoder.is_decoder = True
    config_decoder.add_cross_attention = True

    config = EncoderDecoderConfig.from_encoder_decoder_configs(
        config_encoder, config_decoder)
    model = EncoderDecoderModel(config=config)

    return model
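
A quick way to exercise the returned encoder-decoder model (hedged sketch: the token IDs below are random placeholders rather than real tokenizer output, and only a teacher-forced forward pass is shown):

import torch

model = get_model(vocab_size=30000)
src = torch.randint(0, 30000, (2, 16))     # dummy source batch
tgt = torch.randint(0, 30000, (2, 12))     # dummy target batch
outputs = model(input_ids=src, decoder_input_ids=tgt, labels=tgt)
print(outputs.loss, outputs.logits.shape)  # scalar loss, (2, 12, 30000)
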
Example 3
def set_model_config(args, tokenizer):
    sentence_config = BertConfig()
    sentence_config.vocab_size = tokenizer.get_vocab_size()
    sentence_config.num_hidden_layers = args.num_layers1
    sentence_config.hidden_size = args.hidden_size1
    sentence_config.num_attention_heads = args.attention_heads1
    sentence_config.max_position_embeddings = args.block_length

    document_config = BertConfig()
    document_config.vocab_size = tokenizer.get_vocab_size()
    document_config.num_hidden_layers = args.num_layers2
    document_config.hidden_size = args.hidden_size2
    document_config.num_attention_heads = args.attention_heads2
    document_config.num_masked_blocks = args.max_blocks
    document_config.max_position_embeddings = args.max_blocks

    return sentence_config, document_config
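
Possible usage of the two configs (a sketch only: the args namespace and the tiny tokenizer stand-in are placeholders for whatever the surrounding training script actually provides):

from argparse import Namespace
from transformers import BertModel

class _Tok:                      # stand-in exposing the one method used above
    def get_vocab_size(self):
        return 30522

args = Namespace(num_layers1=6, hidden_size1=256, attention_heads1=4,
                 block_length=32,
                 num_layers2=3, hidden_size2=256, attention_heads2=4,
                 max_blocks=48)

sentence_config, document_config = set_model_config(args, _Tok())
sentence_encoder = BertModel(sentence_config)   # token-level encoder
document_encoder = BertModel(document_config)   # block-level encoder
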
Example 4
    def __init__(self, config, num=0):
        super(Bert, self).__init__()
        model_config = BertConfig()
        model_config.vocab_size = config.vocab_size
        # how the loss is computed ('binary', 'focal_loss', 'ghmc', or multi-class)
        self.loss_method = config.loss_method
        self.multi_drop = config.multi_drop
        # number of output classes for the multi-class branch below
        # (assumed to be carried on the same custom config object)
        self.num_labels = config.num_labels

        self.bert = BertModel(model_config)
        if config.requires_grad:
            for param in self.bert.parameters():
                param.requires_grad = True
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.hidden_size = config.hidden_size[num]
        if self.loss_method in ['binary', 'focal_loss', 'ghmc']:
            self.classifier = nn.Linear(self.hidden_size, 1)
        else:
            self.classifier = nn.Linear(self.hidden_size, self.num_labels)

        self.classifier.apply(self._init_weights)
        self.bert.apply(self._init_weights)
Example 5
tokenizer = BertWordPieceTokenizer(
    r'C:\Users\David\Documents\Machine_learning\NLP\CardioExplorer\vocab.txt',
    lowercase=True)

pretrain = True
sentence_block_length = 32
max_sentence_blocks = 48
hidden_size = 256
batch_size = 4
shuffle = True
drop_last = True

sentence_block_vector = torch.normal(mean=0.0, std=1.0, size=[hidden_size])

sentence_config = BertConfig()
sentence_config.vocab_size = tokenizer.get_vocab_size()
sentence_config.num_hidden_layers = 6
sentence_config.hidden_size = 256
sentence_config.num_attention_heads = 4
sentence_config.max_position_embeddings = sentence_block_length  # sentence_block_length

document_config = BertConfig()
document_config.vocab_size = tokenizer.get_vocab_size()
document_config.num_hidden_layers = 3
document_config.hidden_size = 256
document_config.num_attention_heads = 4
document_config.max_position_embeddings = max_sentence_blocks  # one position per sentence block

dataset = Dataset(file_path,
                  tokenizer,
                  sentence_block_length,
Example 6
def get_model(use_pretrained_weight):
    config = BertConfig()
    config.vocab_size = 37 + 5
    config.hidden_size = 120
    return ActorCritic(config)
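
Side note on the shrunken config (hedged: ActorCritic is project-specific, so the quick check below uses a plain BertModel instead): BertConfig keeps its default of 12 attention heads, and hidden_size must divide evenly by the head count, which 120 does (120 / 12 = 10 dims per head).

from transformers import BertConfig, BertModel

cfg = BertConfig()
cfg.vocab_size = 37 + 5      # same value as in get_model above
cfg.hidden_size = 120        # 120 / 12 default heads = 10 dims per head
backbone = BertModel(cfg)
print(sum(p.numel() for p in backbone.parameters()))
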
Example 7
from transformers import BertConfig, TrainingArguments, BertModel, Trainer, AutoModelForMaskedLM, BertTokenizer
import torch.nn as nn
import torch
import torch.nn.functional as F
import heapq
import numpy
from pypinyin import pinyin, Style

config = BertConfig()
config.vocab_size = 41460  # sentence dictionary (output vocabulary)
model = AutoModelForMaskedLM.from_config(config)
model.bert.embeddings.word_embeddings = nn.Embedding(1839, 768, padding_idx=0)
state_dict = torch.load('./results/checkpoint-00000/pytorch_model.bin',
                        map_location="cpu")
model.load_state_dict(state_dict)
model.eval()
pinyin_list = [
    i
    for tmp in pinyin('手机没电了', style=Style.TONE3, neutral_tone_with_five=True)
    for i in tmp
]
con_tokenizer = BertTokenizer.from_pretrained('y2d1')
lab_tokenizer = BertTokenizer.from_pretrained('z2d')
con = torch.tensor(
    con_tokenizer.convert_tokens_to_ids(pinyin_list)).unsqueeze(0)
out_top5 = torch.topk(F.softmax(model(con)[0].squeeze(0), dim=-1), k=10)
values = out_top5[0].detach().numpy().tolist()
indices = out_top5[1].detach().numpy().tolist()
for i, item in enumerate(indices):
    print(lab_tokenizer.convert_ids_to_tokens(item))
    print(values[i])
Example 8
def get_model(enable_model_name, is_pretraining, pretrained_path):
    # tile(37), menzen(2), reach_state(2), n_reach(3),
    # reach_ippatsu(2), dans(21), rates(19), oya(4),
    # scores(13), n_honba(3), n_round(12), sanma_or_yonma(2),
    # han_or_ton(2), aka_ari(2), kui_ari(2), special_token(4)
    # vocab_size = 37 + 2 + 2 + 3 + 2 + 21 + 19 + 4 + 13 + 3 + 12 + 2 + 2 + 2 + 2 + 4 + 2 + 4 + 6 + 8 # 130 + shanten_diff(2) + who(4) + sum_discards(6) + shanten(8)
    vocab_size = 37 + 2 + 2 + 3 + 2 + 21 + 19 + 4 + 13 + 3 + 12 + 2 + 2 + 2 + 2 + 4 + 4 + 6 + 8  # 130 + who(4) + sum_discards(6) + shanten(8)
    # hidden_size = 1024
    # num_attention_heads = 16
    hidden_size = 768
    num_attention_heads = 12
    max_position_embeddings = 239  # base + who(1) + sum_discards(1) + shanten(1)
    # intermediate_size = 64
    # intermediate_size = 3072
    # max_position_embeddings = 239 # base + pad(1) + who(1) + pad(1) + sum_discards(1) + pad(1) + shanten(1)
    # max_position_embeddings = 281 # 260 + pad(1) + shanten_diff(14) + pad(1) + who(1) + pad(1) + sum_discards(1) + pad(1) + shanten(1)

    if is_pretraining:
        config = BertConfig()
        config.vocab_size = vocab_size
        config.hidden_size = hidden_size
        config.num_attention_heads = num_attention_heads
        config.max_position_embeddings = max_position_embeddings
        config.num_hidden_layers = 12
        return MahjongPretrainingModel(config)

    model = None
    if enable_model_name == 'discard':
        discard_config = BertConfig()
        discard_config.vocab_size = vocab_size
        discard_config.hidden_size = hidden_size
        discard_config.num_attention_heads = num_attention_heads
        discard_config.max_position_embeddings = max_position_embeddings
        discard_config.num_hidden_layers = 12
        # discard_config.intermediate_size = intermediate_size
        # discard_config.num_hidden_layers = 24
        # discard_config.num_hidden_layers = 12
        model = MahjongDiscardModel(discard_config)
    elif enable_model_name == 'reach':
        reach_config = BertConfig()
        reach_config.vocab_size = vocab_size
        reach_config.hidden_size = hidden_size
        reach_config.num_attention_heads = num_attention_heads
        reach_config.max_position_embeddings = max_position_embeddings
        reach_config.num_hidden_layers = 24
        model = MahjongReachChowPongKongModel(reach_config)
    elif enable_model_name == 'chow':
        chow_config = BertConfig()
        chow_config.vocab_size = vocab_size
        chow_config.hidden_size = hidden_size
        chow_config.num_attention_heads = num_attention_heads
        chow_config.max_position_embeddings = max_position_embeddings
        chow_config.num_hidden_layers = 24
        model = MahjongReachChowPongKongModel(chow_config)
    elif enable_model_name == 'pong':
        pong_config = BertConfig()
        pong_config.vocab_size = vocab_size
        pong_config.hidden_size = hidden_size
        pong_config.num_attention_heads = num_attention_heads
        pong_config.max_position_embeddings = max_position_embeddings
        pong_config.num_hidden_layers = 24
        model = MahjongReachChowPongKongModel(pong_config)
    elif enable_model_name == 'kong':
        kong_config = BertConfig()
        kong_config.vocab_size = vocab_size
        kong_config.hidden_size = hidden_size
        kong_config.num_attention_heads = num_attention_heads
        kong_config.max_position_embeddings = max_position_embeddings
        kong_config.num_hidden_layers = 24
        model = MahjongReachChowPongKongModel(kong_config)

    if pretrained_path != '':
        checkpoint = torch.load(pretrained_path,
                                map_location=catalyst.utils.get_device())
        # print(checkpoint['model_state_dict'].keys())
        model.load_state_dict(checkpoint['model_state_dict'], strict=False)

    return model
Example 9
    def __init__(self, config, language_pretrained_model_path=None):
        super(VisualLinguisticBertDecoder, self).__init__(config)

        self.config = config

        # embeddings
        self.word_embeddings = nn.Embedding(config.vocab_size,
                                            config.hidden_size)
        self.end_embedding = nn.Embedding(1, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings,
                                                config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
                                                  config.hidden_size)
        self.embedding_LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.embedding_dropout = nn.Dropout(config.hidden_dropout_prob)

        # for compatibility of roberta
        self.position_padding_idx = config.position_padding_idx

        # visual transform
        self.visual_1x1_text = None
        self.visual_1x1_object = None
        if config.visual_size != config.hidden_size:
            self.visual_1x1_text = nn.Linear(config.visual_size,
                                             config.hidden_size)
            self.visual_1x1_object = nn.Linear(config.visual_size,
                                               config.hidden_size)
        if config.visual_ln:
            self.visual_ln_text = BertLayerNorm(config.hidden_size, eps=1e-12)
            self.visual_ln_object = BertLayerNorm(config.hidden_size,
                                                  eps=1e-12)
        else:
            visual_scale_text = nn.Parameter(torch.as_tensor(
                self.config.visual_scale_text_init, dtype=torch.float),
                                             requires_grad=True)
            self.register_parameter('visual_scale_text', visual_scale_text)
            visual_scale_object = nn.Parameter(torch.as_tensor(
                self.config.visual_scale_object_init, dtype=torch.float),
                                               requires_grad=True)
            self.register_parameter('visual_scale_object', visual_scale_object)

        # *********************************************
        # FM addition - Set-up decoder layer for MT
        #  Initializing a BERT bert-base-uncased style configuration
        configuration = BertConfig()
        configuration.vocab_size = config.vocab_size
        # FM edit: reduce size - 12 layers doesn't fit in single 12GB GPU
        configuration.num_hidden_layers = 6
        configuration.is_decoder = True
        # Initializing a model from the bert-base-uncased style configuration
        self.decoder = BertModel(configuration)
        # *********************************************

        if self.config.with_pooler:
            self.pooler = BertPooler(config)

        # init weights
        self.apply(self.init_weights)
        if config.visual_ln:
            self.visual_ln_text.weight.data.fill_(
                self.config.visual_scale_text_init)
            self.visual_ln_object.weight.data.fill_(
                self.config.visual_scale_object_init)

        # load language pretrained model
        if language_pretrained_model_path is not None:
            self.load_language_pretrained_model(language_pretrained_model_path)

        if config.word_embedding_frozen:
            for p in self.word_embeddings.parameters():
                p.requires_grad = False
            self.special_word_embeddings = nn.Embedding(
                NUM_SPECIAL_WORDS, config.hidden_size)
            self.special_word_embeddings.weight.data.copy_(
                self.word_embeddings.weight.data[:NUM_SPECIAL_WORDS])
Example 10
def train_process(config, train_load, valid_load, test_load, k, train_sampler):

    # load source bert weights
    # model_config = BertConfig.from_pretrained(pretrained_model_name_or_path="../user_data/bert_source/{}/config.json".format(config.model_name))
    model_config = BertConfig()
    model_config.vocab_size = len(
        pd.read_csv('../user_data/vocab', names=["score"]))
    model = BertForSequenceClassification(config=model_config)

    if os.path.isfile('save_model/{}_best_model_v1111.pth.tar'.format(
            config.model_name)):
        checkpoint = torch.load('save_model/{}_best_model_v1.pth.tar'.format(
            config.model_name),
                                map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint['status'], strict=False)
        best_dev_auc = 0
        print('***********load best model weight*************')
    else:
        checkpoint = torch.load(
            '../user_data/save_bert/{}_checkpoint.pth.tar'.format(
                config.model_name),
            map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint['status'], strict=False)
        best_dev_auc = 0
        print('***********load pretrained mlm model weight*************')

    for param in model.parameters():
        param.requires_grad = True

    # 4) move the model to its GPU before wrapping it with DDP
    model = model.to(config.device)

    no_decay = ["bias", "LayerNorm.weight"]

    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            config.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=config.learning_rate)

    #     t_total = len(train_load) * config.num_train_epochs
    #     scheduler = get_linear_schedule_with_warmup(
    #         optimizer, num_warmup_steps=t_total * config.warmup_proportion, num_training_steps=t_total
    #     )

    cudnn.benchmark = True

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        # 5) wrap the model with DistributedDataParallel
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[config.local_rank])

    model.train()
    if config.fgm:
        fgm = FGM(model)

    for epoch in range(config.num_train_epochs):
        train_sampler.set_epoch(epoch)
        is_best = False
        torch.cuda.empty_cache()

        for batch, (input_ids, token_type_ids, attention_mask,
                    label) in enumerate(train_load):
            input_ids = input_ids.cuda(config.local_rank, non_blocking=True)
            attention_mask = attention_mask.cuda(config.local_rank,
                                                 non_blocking=True)
            token_type_ids = token_type_ids.cuda(config.local_rank,
                                                 non_blocking=True)
            label = label.cuda(config.local_rank, non_blocking=True)

            outputs = model(input_ids=input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            labels=label)

            loss = outputs.loss
            model.zero_grad()
            loss.backward()
            #             torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)

            if config.fgm:
                fgm.attack()  # add adversarial perturbation to the embeddings
                loss_adv = model(input_ids=input_ids,
                                 attention_mask=attention_mask,
                                 token_type_ids=token_type_ids,
                                 labels=label).loss
                loss_adv.backward()  # backpropagate, accumulating the adversarial gradients on top of the normal ones
                fgm.restore()  # restore the original embedding parameters

            optimizer.step()
        #             scheduler.step()

        dev_auc = model_evaluate(config, model, valid_load)

        # synchronize all processes and reduce the dev metric across them
        torch.distributed.barrier()
        reduce_dev_auc = reduce_auc(dev_auc, config.nprocs).item()

        if reduce_dev_auc > best_dev_auc:
            best_dev_auc = reduce_dev_auc
            is_best = True

        now = strftime("%Y-%m-%d %H:%M:%S", localtime())
        msg = 'number {} fold,time:{},epoch:{}/{},reduce_dev_auc:{},best_dev_auc:{}'

        if config.local_rank in [0, -1]:
            print(
                msg.format(k, now, epoch + 1, config.num_train_epochs,
                           reduce_dev_auc, best_dev_auc))
            checkpoint = {
                "status": model.state_dict(),
                "epoch": epoch + 1,
                'reduce_dev_auc': reduce_dev_auc
            }
            if is_best:
                torch.save(
                    checkpoint, '../user_data/save_model' + os.sep +
                    '{}_best_model.pth.tar'.format(config.model_name))
            torch.save(
                checkpoint, '../user_data/save_model' + os.sep +
                '{}_checkpoint.pth.tar'.format(config.model_name))
            del checkpoint

    torch.distributed.barrier()
Example 11
def make_bert_multitask(pretrained_model_dir,
                        tasks,
                        num_labels_per_task,
                        init_args_dir,
                        mask_id,
                        encoder=None,
                        args=None):
    assert num_labels_per_task is not None and isinstance(num_labels_per_task, dict), \
        "ERROR : num_labels_per_task {} should be a dictionary".format(num_labels_per_task)
    assert isinstance(tasks, list) and len(
        tasks) >= 1, "ERROR tasks {} should be a list of len >=1".format(tasks)

    if init_args_dir is None:
        if pretrained_model_dir is None:

            pretrained_model_dir = args.bert_model
        # assert args.output_attentions is None or not args.output_attentions, "ERROR not supported "

        multitask_wrapper = BertMultiTask

        def get_state_dict_mapping(model):
            if model.startswith("xlm") or model.startswith(
                    "rob") or model.startswith("camembert"):
                return {
                    "roberta":
                    "encoder",  # "lm_head":,
                    "lm_head.decoder":
                    "head.mlm.predictions.decoder",
                    "lm_head.dense":
                    "head.mlm.predictions.transform.dense",
                    "lm_head.bias":
                    "head.mlm.predictions.bias",
                    "lm_head.layer_norm":
                    "head.mlm.predictions.transform.LayerNorm"
                }
            elif model.startswith("bert") or model.startswith(
                    "cahya") or model.startswith("KB"):
                return {"bert": "encoder", "cls": "head.mlm"}
            elif model.startswith("asafaya"):
                return {"bert": "encoder", "cls": "head.mlm"}
            else:
                raise (Exception(
                    f"not supported by {multitask_wrapper} needs to define a ")
                       )

        state_dict_mapping = get_state_dict_mapping(args.bert_model)

        model = multitask_wrapper.from_pretrained(
            pretrained_model_dir,
            tasks=tasks,
            mask_id=mask_id,
            output_attentions=args.output_attentions,
            output_hidden_states=args.output_all_encoded_layers,
            output_hidden_states_per_head=args.output_hidden_states_per_head,
            hard_skip_attention_layers=args.hard_skip_attention_layers,
            hard_skip_all_layers=args.hard_skip_all_layers,
            hard_skip_dense_layers=args.hard_skip_dense_layers,
            num_labels_per_task=num_labels_per_task,
            mapping_keys_state_dic=
            state_dict_mapping,  #DIR_2_STAT_MAPPING[pretrained_model_dir],
            encoder=eval(encoder) if encoder is not None else BertModel,
            dropout_classifier=args.dropout_classifier,
            hidden_dropout_prob=args.hidden_dropout_prob,
            random_init=args.random_init,
            load_params_only_ls=None,
            not_load_params_ls=args.not_load_params_ls)

    elif init_args_dir is not None:
        assert pretrained_model_dir is not None, "ERROR model_dir is needed here for reloading"
        init_args_dir = get_init_args_dir(init_args_dir)
        args_checkpoint = json.load(open(init_args_dir, "r"))
        assert "checkpoint_dir" in args_checkpoint, "ERROR checkpoint_dir not in {} ".format(
            args_checkpoint)

        checkpoint_dir = args_checkpoint["checkpoint_dir"]
        assert os.path.isfile(
            checkpoint_dir), "ERROR checkpoint {} not found ".format(
                checkpoint_dir)

        # redefining model and reloading
        def get_config_bert(bert_model, config_file_name="bert_config.json"):
            model_dir = BERT_MODEL_DIC[bert_model]["model"]
            #tempdir = tempfile.mkdtemp()
            #print("extracting archive file {} to temp dir {}".format(model_dir, tempdir))
            #with tarfile.open(model_dir, 'r:gz') as archive:
            #    archive.extractall(tempdir)
            #serialization_dir = tempdir
            serialization_dir = None
            config_file = os.path.join(model_dir, config_file_name)
            try:
                assert os.path.isfile(
                    config_file
                ), "ERROR {} not a file , extracted from {} : dir includes {} ".format(
                    config_file, model_dir,
                    [x[0] for x in os.walk(serialization_dir)])
            except Exception as e:
                config_file = os.path.join(model_dir, "config.json")
                assert os.path.isfile(config_file)
            return config_file

        config_file = get_config_bert(
            args_checkpoint["hyperparameters"]["bert_model"])
        encoder = eval(BERT_MODEL_DIC[args_checkpoint["hyperparameters"]
                                      ["bert_model"]]["encoder"])
        config = BertConfig(
            config_file,
            output_attentions=args.output_attentions,
            output_hidden_states=args.output_all_encoded_layers,
            output_hidden_states_per_head=args.output_hidden_states_per_head)
        #
        config.vocab_size = 119547

        model = BertMultiTask(
            config=config,
            tasks=[
                task for tasks in args_checkpoint["hyperparameters"]["tasks"]
                for task in tasks
            ],
            num_labels_per_task=args_checkpoint["info_checkpoint"]
            ["num_labels_per_task"],
            encoder=encoder,
            mask_id=mask_id)
        printing("MODEL : loading model from checkpoint {}",
                 var=[checkpoint_dir],
                 verbose=1,
                 verbose_level=1)
        model.load_state_dict(
            torch.load(checkpoint_dir,
                       map_location=lambda storage, loc: storage))
        model.append_extra_heads_model(downstream_tasks=tasks,
                                       num_labels_dic_new=num_labels_per_task)
    else:
        raise (Exception(
            "only one of pretrained_model_dir checkpoint_dir can be defined "))

    return model
Example 12
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    logger.info("device: {}".format(device))

    df_train = pd.read_csv(TRAIN_PATH)
    df_train.columns = ["doc_id", "sents"]
    sents_list = df_train["sents"].values.tolist()
    logger.info("len(sents_list): {}".format(len(sents_list)))

    config = BertConfig()
    config.num_hidden_layers = 3
    config.num_attention_heads = 12
    config.hidden_size = 768
    config.intermediate_size = 3072
    config.max_position_embeddings = 512
    config.vocab_size = 32000

    logger.info("USE_NSP: {}".format(USE_NSP))
    if USE_NSP:
        model = BertForPreTraining(config)
    else:
        model = BertForPreTrainingWithoutNSP(config)
    model.to(device)

    logger.info(config)
    logger.info(model)

    optimizer = AdamW(model.parameters(), lr=2e-5)
    model.train()
    train_losses = []
    for i in range(1, MAX_STEPS + 1):
Example 13
    model = GPT2LMHeadModel(config)
    model.resize_token_embeddings(len(tokenizer))
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    if train:
        train_dataset = datasets.Dataset.load_from_disk(os.path.join(data_dir, "lm_train"))


elif model_type == "bert":
    dataset_properties = json.load(open(os.path.join(data_dir, "dataset_properties.json")))
    special_tokens = dataset_properties["special_tokens"]
    tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
    tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})

    config = BertConfig()
    config.vocab_size = len(tokenizer)

    model = AutoModelForMaskedLM.from_config(config)
    model.resize_token_embeddings(len(tokenizer))
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer)

    # the NL inputs for the train dataset are the same for BERT and GPT-2 models, but they are tokenized
    # differently (using the corresponding BERT and GPT-2 tokenizers, respectively). The standard training
    # set is already tokenized with the BERT tokenizer, so we can reuse that set here.
    if train:
        train_dataset = datasets.Dataset.load_from_disk(os.path.join(data_dir, "arsenal_train"))


else:
    raise("unknown model type")