Example #1
def main(args):
    # create run dirs
    run_name = get_run_name(args)
    run_dir = os.path.join(DATA_DIR, 'predict_data', run_name)
    preprocessed_folder = os.path.join(run_dir, 'preprocessed')
    tfrecord_folder = os.path.join(run_dir, 'tfrecord')
    for _dir in [preprocessed_folder, tfrecord_folder]:
        if not os.path.isdir(_dir):
            os.makedirs(_dir)
    # set up parallel
    if args.run_in_parallel:
        num_cpus = max(multiprocessing.cpu_count() - 1, 1)
    else:
        num_cpus = 1
    logger.info(f'Running with {num_cpus} CPUs...')
    t_s = time.time()
    parallel = joblib.Parallel(n_jobs=num_cpus)
    process_fn_delayed = joblib.delayed(process)
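    # run process() over every input file in parallel, one file per worker job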
    res = parallel((process_fn_delayed(f_name, tfrecord_folder,
                                       preprocessed_folder, args)
                    for f_name in tqdm(args.input_txt_files,
                                       desc='Processing input files')))
    # save config
    f_config = os.path.join(run_dir, 'create_predict_config.json')
    logger.info(f'Saving config to {f_config}')
    data = dict(vars(args))
    save_to_json(data, f_config)
Example #2
def main(args):
    input_files = get_input_files(args.input_data)
    logger.info(f'Found {len(input_files):,} input text files')

    # preprocess fn
    preprocess_fn = preprocess_bert
    do_lower_case = PRETRAINED_MODELS[args.model_class]['lower_case']

    # create run dirs
    run_name = get_run_name(args)
    output_folder = os.path.join(DATA_FOLDER, 'pretrain', run_name,
                                 'preprocessed')
    if not os.path.isdir(output_folder):
        os.makedirs(output_folder)

    # set up parallel processing
    if args.run_in_parallel:
        num_cpus = max(multiprocessing.cpu_count() - 1, 1)
    else:
        num_cpus = 1
    parallel = joblib.Parallel(n_jobs=num_cpus)
    preprocess_fn_delayed = joblib.delayed(preprocess_file)

    # run
    t_s = time.time()
    res = parallel((preprocess_fn_delayed(input_file, preprocess_fn,
                                          output_folder, do_lower_case, args)
                    for input_file in tqdm(input_files)))
    t_e = time.time()
    time_taken_min = (t_e - t_s) / 60
    logger.info(f'Finished after {time_taken_min:.1f} min')
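    # each worker returned a tuple: (num_sentences, num_tokens, num_tweets, num_examples, num_examples_single_sentence)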
    num_sentences = sum(r[0] for r in res)
    num_tokens = sum(r[1] for r in res)
    num_tweets = sum(r[2] for r in res)
    num_examples = sum(r[3] for r in res)
    num_examples_single_sentence = sum(r[4] for r in res)
    logger.info(
        f'Collected a total of {num_sentences:,} sentences, {num_tokens:,} tokens from {num_tweets:,} tweets'
    )
    logger.info(
        f'Collected a total of {num_examples:,} examples, {num_examples_single_sentence:,} examples only contain a single sentence.'
    )
    logger.info(f'All output files can be found under {output_folder}')
    # save config
    f_config = os.path.join(output_folder, 'prepare_pretrain_config.json')
    logger.info(f'Saving config to {f_config}')
    data = {
        'num_sentences': num_sentences,
        'num_tokens': num_tokens,
        'num_tweets': num_tweets,
        'num_examples': num_examples,
        'num_examples_single_sentence': num_examples_single_sentence,
        'time_taken_min': time_taken_min,
        **vars(args)
    }
    save_to_json(data, f_config)
Example #3
def main(args):
    # create run dirs
    run_name = get_run_name(args)
    run_dir = os.path.join(DATA_DIR, 'finetune', run_name)
    if not os.path.isdir(run_dir):
        os.makedirs(run_dir)
    # find input data
    originals_dir = os.path.join(DATA_DIR, 'finetune', 'originals')
    if args.finetune_datasets is None or len(args.finetune_datasets) == 0:
        finetune_datasets = os.listdir(originals_dir)
    else:
        finetune_datasets = args.finetune_datasets
    do_lower_case = PRETRAINED_MODELS[args.model_class]['lower_case']
    for dataset in finetune_datasets:
        logger.info(f'Processing dataset {dataset}...')
        preprocessed_folder = os.path.join(run_dir, dataset, 'preprocessed')
        if not os.path.isdir(preprocessed_folder):
            os.makedirs(preprocessed_folder)
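        # collect the label set across train/dev splits so label ids stay consistent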
        labels = set()
        for _type in ['train', 'dev']:
            f_name = f'{_type}.tsv'
            logger.info(f'Reading data for type {_type}...')
            f_path = os.path.join(originals_dir, dataset, f_name)
            if not os.path.isfile(f_path):
                logger.info(f'Could not find file {f_path}. Skipping.')
                continue
            df = pd.read_csv(f_path, usecols=REQUIRED_COLUMNS, sep='\t')
            logger.info('Creating preprocessed files...')
            df.loc[:, 'text'] = df.text.apply(preprocess_bert,
                                              args=(args, do_lower_case))
            df.to_csv(os.path.join(preprocessed_folder, f_name),
                      columns=REQUIRED_COLUMNS,
                      header=False,
                      index=False,
                      sep='\t')
            # add labels
            labels.update(df.label.unique().tolist())
        logger.info('Creating tfrecords files...')
        # we sort the labels alphabetically in order to maintain consistent label ids
        labels = sorted(list(labels))
        dataset_dir = os.path.join(run_dir, dataset)
        generate_tfrecords(args, dataset_dir, labels)
    # saving config
    f_path_config = os.path.join(run_dir, 'create_finetune_config.json')
    logger.info(f'Saving config to {f_path_config}')
    save_to_json(vars(args), f_path_config)
Example #4
def main(args):
    rng = random.Random(args.random_seed)
    run_folder = os.path.join(DATA_DIR, 'pretrain', args.run_name)
    input_files = get_input_files(run_folder)

    logger.info(f'Processing the following {len(input_files):,} input files:')
    for input_file in input_files:
        logger.info(f'{input_file}')

    logger.info(f'Setting up tokenizer for model class {args.model_class}')
    tokenizer = get_tokenizer(args.model_class)

    if args.run_in_parallel:
        num_cpus = max(min(multiprocessing.cpu_count() - 1, args.max_num_cpus),
                       1)
    else:
        num_cpus = 1
    logger.info(f'Running with {num_cpus} CPUs...')
    t_s = time.time()
    parallel = joblib.Parallel(n_jobs=num_cpus)
    process_fn_delayed = joblib.delayed(process)
    res = parallel(
        (process_fn_delayed(input_file, tokenizer, rng, args)
         for input_file in tqdm(input_files, desc='Processing input files')))
    t_e = time.time()
    time_taken_min = (t_e - t_s) / 60
    logger.info(f'Finished after {time_taken_min:.1f} min')
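    # aggregate per-worker results; each result is a tuple of (num_documents, num_instances, type)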
    counts = {}
    for _r in res:
        _type = _r[2]
        if _type not in counts:
            counts[_type] = defaultdict(int)
        counts[_type]['num_documents'] += _r[0]
        counts[_type]['num_instances'] += _r[1]
    for _type, c in counts.items():
        num_instances = c['num_instances']
        num_documents = c['num_documents']
        logger.info(
            f'Type {_type}: Generated a total of {num_instances:,} training examples from {num_documents:,} documents'
        )
    f_config = os.path.join(run_folder, 'create_pretrain_config.json')
    logger.info(f'Saving config to {f_config}')
    data = {'counts': counts, 'time_taken_min': time_taken_min, **vars(args)}
    save_to_json(data, f_config)
Example #5
def run(args, strategy):
    """Pretrains model using TF2. Adapted from the tensorflow/models Github"""
    # CONFIG
    # Use timestamp to generate a unique run name
    run_name = get_run_name(args)
    logger.info(f'*** Starting run {run_name} ***')
    output_dir = f'gs://{args.bucket_name}/{args.project_name}/pretrain/runs/{run_name}'

    # pretrained model path
    try:
        pretrained_model_path = PRETRAINED_MODELS[args.model_class]['bucket_location']
    except KeyError:
        raise ValueError(f'Could not find a pretrained model matching the model class {args.model_class}')
    pretrained_model_config_path = f'gs://{args.bucket_name}/{pretrained_model_path}/bert_config.json'
    if args.init_checkpoint is None:
        pretrained_model_checkpoint_path = f'gs://{args.bucket_name}/{pretrained_model_path}/bert_model.ckpt'
    else:
        pretrained_model_checkpoint_path = f'gs://{args.bucket_name}/{args.project_name}/pretrain/runs/{args.init_checkpoint}'

    # some logging
    logger.info(f'Running pretraining of model {args.model_class} on pretrain data {args.pretrain_data}')
    logger.info(f'Initializing model from checkpoint {pretrained_model_checkpoint_path}')

    # load model config based on model_class
    model_config = get_model_config(pretrained_model_config_path)

    # input data function
    train_input_fn = get_dataset_fn(args, _type='train')
    eval_input_fn = None
    eval_metric_fn = None
    if args.do_eval:
        logger.info(f'Setting up evaluation dataset')
        eval_metric_fn = get_eval_metric_fn
        eval_input_fn = get_dataset_fn(args, _type='dev')

    # model_fn
    def _get_pretrained_model(end_lr=0.0):
        """Gets a pretraining model."""
        pretrain_model, core_model = bert_models.pretrain_model(model_config, args.max_seq_length, args.max_predictions_per_seq)
        if args.warmup_proportion is None:
            warmup_steps = args.warmup_steps
            warmup_proportion_perc = 100 * args.warmup_steps/(args.num_epochs * args.num_steps_per_epoch)
        else:
            warmup_steps = int(args.num_epochs * args.num_steps_per_epoch * args.warmup_proportion)
            warmup_proportion_perc = args.warmup_proportion * 100
        logger.info(f'Running {warmup_steps:,} warmup steps ({warmup_proportion_perc:.2f}% warmup)')
        optimizer = utils.optimizer.create_optimizer(
                args.learning_rate,
                args.num_steps_per_epoch * args.num_epochs,
                warmup_steps,
                args.end_lr,
                args.optimizer_type)
        pretrain_model.optimizer = configure_optimizer(optimizer, use_float16=args.dtype == 'fp16', use_graph_rewrite=False)
        return pretrain_model, core_model

    # custom callbacks
    summary_dir = os.path.join(output_dir, 'summaries')
    time_history_callback = keras_utils.TimeHistory(
        batch_size=args.train_batch_size,
        log_steps=args.time_history_log_steps,
        logdir=summary_dir)
    custom_callbacks = [time_history_callback]

    # Save an initial version of the log file
    data = {
        'created_at': datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        'run_name': run_name,
        'num_train_steps': args.num_steps_per_epoch * args.num_epochs,
        'eval_steps': args.eval_steps,
        'model_dir': output_dir,
        'output_dir': output_dir,
        **vars(args),
    }
    # write initial training log
    f_path_training_log = os.path.join(output_dir, 'run_logs.json')
    logger.info(f'Writing training preliminary log to {f_path_training_log}...')
    save_to_json(data, f_path_training_log)

    # run training loop
    logger.info(f'Run training for {args.num_epochs:,} epochs, {args.num_steps_per_epoch:,} steps each, processing {args.num_epochs*args.num_steps_per_epoch*args.train_batch_size:,} training examples in total...')
    time_start = time.time()
    model_training_utils.run_customized_training_loop(
        strategy=strategy,
        model_fn=_get_pretrained_model,
        loss_fn=get_loss_fn(),
        scale_loss=True,
        model_dir=output_dir,
        train_input_fn=train_input_fn,
        steps_per_epoch=args.num_steps_per_epoch,
        steps_per_loop=args.steps_per_loop,
        epochs=args.num_epochs,
        eval_input_fn=eval_input_fn,
        eval_steps=args.eval_steps,
        metric_fn=eval_metric_fn,
        init_checkpoint=pretrained_model_checkpoint_path,
        load_mlm_nsp_weights=args.load_mlm_nsp_weights,
        set_trainstep=args.set_trainstep,
        custom_callbacks=custom_callbacks,
        run_eagerly=False,
        sub_model_export_name='pretrained/bert_model',
        explicit_allreduce=False,
        pre_allreduce_callbacks=None,
        post_allreduce_callbacks=None)
    time_end = time.time()
    training_time_min = (time_end-time_start)/60
    data['training_time_min'] = training_time_min
    logger.info(f'Finished training after {training_time_min:.1f} min')
    # Write to run directory
    logger.info(f'Writing final training log to {f_path_training_log}...')
    save_to_json(data, f_path_training_log)
    # Write bert config
    f_path_bert_config = os.path.join(output_dir, 'bert_config.json')
    logger.info(f'Writing BERT config to {f_path_bert_config}...')
    save_to_json(model_config.to_dict(), f_path_bert_config)
Example #6
def run(args):
    # start time
    s_time = time.time()
    # paths
    run_dir = f'gs://{args.bucket_name}/{args.project_name}/finetune/runs/{args.run_name}'
    ts = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S_%f')
    output_folder = os.path.join('data', 'predictions', f'predictions_{ts}')
    predictions_output_folder = os.path.join(output_folder, 'predictions')
    if not os.path.isdir(predictions_output_folder):
        os.makedirs(predictions_output_folder)
    # read configs
    logger.info(f'Reading run configs...')
    run_log = read_run_log(run_dir)
    pretrained_model_config_path = get_model_config_path(args)
    model_config = get_model_config(pretrained_model_config_path)
    max_seq_length = run_log['max_seq_length']
    label_mapping = run_log['label_mapping']
    num_labels = len(label_mapping)
    # load tokenizer
    logger.info(f'Loading tokenizer...')
    tokenizer = get_tokenizer(args.model_class)
    # load model
    logger.info(f'Loading model...')
    model = get_model(args, model_config, num_labels, max_seq_length)
    # restore fine-tuned run
    checkpoint_path = os.path.join(run_dir, 'checkpoint')
    logger.info(f'Restore run checkpoint {checkpoint_path}...')
    # load weights (expect partial state because we don't need the optimizer state)
    try:
        model.load_weights(checkpoint_path).expect_partial()
    except Exception:
        logger.error(
            f'Restoring from checkpoint unsuccessful. Use the flag --use_tf_hub if TF Hub was used to initialize the model.'
        )
        return
    else:
        logger.info(f'... successfully restored checkpoint')
    # predict
    num_predictions = 0
    prediction_time_min = 0
    predictions = []
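    # four input modes: a single text, interactive prompt, plain-text files, or tfrecord files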
    if args.input_text:
        example = generate_single_example(args.input_text, tokenizer,
                                          max_seq_length)
        preds = model.predict(example)
        preds = format_prediction(preds, label_mapping, args.label_name)
        print(json.dumps(preds, indent=4))
        return
    elif args.interactive_mode:
        while True:
            text = input('Type text to predict. Quit by typing "q".\n>>> ')
            if text.lower() == 'q':
                break
            example = generate_single_example(text, tokenizer, max_seq_length)
            preds = model.predict(example)
            preds = format_prediction(preds, label_mapping, args.label_name)
            print(json.dumps(preds, indent=4))
        return
    elif args.input_txt_files:
        s_time_predict = time.time()
        for input_file in args.input_txt_files:
            with open(input_file, 'r') as f_in:
                num_lines = sum(1 for _ in f_in)
            num_batches = int(num_lines / args.eval_batch_size) + 1
            f_out_name = os.path.basename(input_file).split('.')[-2]
            f_out = os.path.join(predictions_output_folder,
                                 f'{f_out_name}.jsonl')
            logger.info(f'Predicting file {input_file}...')
            for batch in tqdm(generate_examples_from_txt_file(
                    input_file, tokenizer, max_seq_length,
                    args.eval_batch_size),
                              total=num_batches,
                              unit='batch'):
                preds = model.predict(batch)
                preds = format_prediction(preds, label_mapping,
                                          args.label_name)
                num_predictions += len(preds)
                with open(f_out, 'a') as f:
                    for pred in preds:
                        f.write(json.dumps(pred) + '\n')
        e_time_predict = time.time()
        prediction_time_min = (e_time_predict - s_time_predict) / 60
        logger.info(
            f'Wrote {num_predictions:,} predictions in {prediction_time_min:.1f} min ({num_predictions/prediction_time_min:.1f} predictions per min)'
        )
    elif args.input_tfrecord_files:
        s_time_predict = time.time()
        for input_file_pattern in args.input_tfrecord_files:
            for input_file in tf.io.gfile.glob(input_file_pattern):
                logger.info(f'Processing file {input_file}')
                dataset = get_tfrecord_dataset(input_file,
                                               args.eval_batch_size,
                                               max_seq_length)()
                num_batches = sum(1 for _ in tf.data.TFRecordDataset(
                    input_file).batch(args.eval_batch_size))
                f_out_name = os.path.basename(input_file).split('.')[-2]
                f_out = os.path.join(predictions_output_folder,
                                     f'{f_out_name}.jsonl')
                for batch in tqdm(dataset, total=num_batches, unit='batch'):
                    preds = model.predict(batch)
                    preds = format_prediction(preds, label_mapping,
                                              args.label_name)
                    num_predictions += len(preds)
                    with open(f_out, 'a') as f:
                        for pred in preds:
                            f.write(json.dumps(pred) + '\n')
        e_time_predict = time.time()
        prediction_time_min = (e_time_predict - s_time_predict) / 60
        logger.info(
            f'Wrote {num_predictions:,} predictions in {prediction_time_min:.1f} min ({num_predictions/prediction_time_min:.1f} predictions per min)'
        )
    e_time = time.time()
    total_time_min = (e_time - s_time) / 60
    f_config = os.path.join(output_folder, 'predict_config.json')
    logger.info(f'Saving config to {f_config}')
    data = {
        'prediction_time_min': prediction_time_min,
        'total_time_min': total_time_min,
        'num_predictions': num_predictions,
        **vars(args)
    }
    save_to_json(data, f_config)
Example #7
def run(args):
    """Train using the Keras/TF 2.0. Adapted from the tensorflow/models Github"""
    # CONFIG
    run_name = get_run_name(args)
    logger.info(f'*** Starting run {run_name} ***')
    data_dir = f'gs://{args.bucket_name}/{args.project_name}/finetune/finetune_data/{args.finetune_data}'
    output_dir = f'gs://{args.bucket_name}/{args.project_name}/finetune/runs/{run_name}'

    # Get configs
    pretrained_model_config_path = get_model_config_path(args)
    model_config = get_model_config(pretrained_model_config_path)

    # Meta data/label mapping
    input_meta_data = get_input_meta_data(data_dir)
    label_mapping = get_label_mapping(data_dir)
    logger.info(f'Loaded training data meta.json file: {input_meta_data}')

    # Calculate steps, warmup steps and eval steps
    train_data_size = input_meta_data['train_data_size']
    num_labels = input_meta_data['num_labels']
    max_seq_length = input_meta_data['max_seq_length']
    if args.limit_train_steps is None:
        steps_per_epoch = int(train_data_size / args.train_batch_size)
    else:
        steps_per_epoch = args.limit_train_steps
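    # warmup steps = warmup_proportion of the total training steps over the full training data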
    warmup_steps = int(args.num_epochs * train_data_size *
                       args.warmup_proportion / args.train_batch_size)
    if args.limit_eval_steps is None:
        eval_steps = int(
            math.ceil(input_meta_data['eval_data_size'] /
                      args.eval_batch_size))
    else:
        eval_steps = args.limit_eval_steps

    # some logging
    if args.init_checkpoint is None:
        logger.info(
            f'Finetuning on dataset {args.finetune_data} using default pretrained model {args.model_class}'
        )
    else:
        logger.info(
            f'Finetuning on dataset {args.finetune_data} using pretrained model in {args.init_checkpoint} of type {args.model_class}'
        )
    logger.info(
        f'Running {args.num_epochs} epochs with {steps_per_epoch:,} steps per epoch'
    )
    logger.info(
        f'Using warmup proportion of {args.warmup_proportion}, resulting in {warmup_steps:,} warmup steps'
    )
    logger.info(
        f'Using learning rate: {args.learning_rate}, training batch size: {args.train_batch_size}, num_epochs: {args.num_epochs}'
    )

    # Get model
    classifier_model, core_model = get_model(args, model_config,
                                             steps_per_epoch, warmup_steps,
                                             num_labels, max_seq_length)
    optimizer = classifier_model.optimizer
    loss_fn = get_loss_fn(num_labels)

    # Restore checkpoint
    if args.init_checkpoint:
        checkpoint_path = f'gs://{args.bucket_name}/{args.project_name}/pretrain/runs/{args.init_checkpoint}'
        checkpoint = tf.train.Checkpoint(model=core_model)
        checkpoint.restore(checkpoint_path).assert_existing_objects_matched()
        logger.info(f'Successfully restored checkpoint from {checkpoint_path}')

    # Run keras compile
    logger.info(f'Compiling keras model...')
    classifier_model.compile(optimizer=optimizer,
                             loss=loss_fn,
                             metrics=get_metrics())
    logger.info(f'... done')

    # Create all custom callbacks
    summary_dir = os.path.join(output_dir, 'summaries')
    summary_callback = tf.keras.callbacks.TensorBoard(summary_dir,
                                                      profile_batch=0)
    checkpoint_path = os.path.join(output_dir, 'checkpoint')
    checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        checkpoint_path, save_weights_only=True)
    time_history_callback = keras_utils.TimeHistory(
        batch_size=args.train_batch_size,
        log_steps=args.time_history_log_steps,
        logdir=summary_dir)
    custom_callbacks = [summary_callback, time_history_callback]
    if args.save_model:
        custom_callbacks.append(checkpoint_callback)
    if args.early_stopping_epochs > 0:
        logger.info(
            f'Using early stopping after {args.early_stopping_epochs} epochs without improvement in val_loss'
        )
        early_stopping_callback = tf.keras.callbacks.EarlyStopping(
            patience=args.early_stopping_epochs, monitor='val_loss')
        custom_callbacks.append(early_stopping_callback)

    # Generate dataset_fn
    train_input_fn = get_dataset_fn(os.path.join(data_dir, 'tfrecords',
                                                 'train.tfrecords'),
                                    max_seq_length,
                                    args.train_batch_size,
                                    is_training=True)
    eval_input_fn = get_dataset_fn(os.path.join(data_dir, 'tfrecords',
                                                'dev.tfrecords'),
                                   max_seq_length,
                                   args.eval_batch_size,
                                   is_training=False)

    # Add metrics callback to calculate performance metrics at the end of each epoch
    performance_metrics_callback = Metrics(
        eval_input_fn, label_mapping, os.path.join(summary_dir, 'metrics'),
        eval_steps, args.eval_batch_size, args.validation_freq)
    custom_callbacks.append(performance_metrics_callback)

    # Run keras fit
    time_start = time.time()
    logger.info('Run training...')
    history = classifier_model.fit(x=train_input_fn(),
                                   validation_data=eval_input_fn(),
                                   steps_per_epoch=steps_per_epoch,
                                   epochs=args.num_epochs,
                                   validation_steps=eval_steps,
                                   validation_freq=args.validation_freq,
                                   callbacks=custom_callbacks,
                                   verbose=1)
    time_end = time.time()
    training_time_min = (time_end - time_start) / 60
    logger.info(f'Finished training after {training_time_min:.1f} min')

    # Write training log
    all_scores = performance_metrics_callback.scores
    all_predictions = performance_metrics_callback.predictions
    if len(all_scores) > 0:
        final_scores = all_scores[-1]
        logger.info(f'Final eval scores: {final_scores}')
    else:
        final_scores = {}
    full_history = history.history
    if len(full_history) > 0:
        final_val_loss = full_history['val_loss'][-1]
        final_loss = full_history['loss'][-1]
        logger.info(
            f'Final training loss: {final_loss:.2f}, Final validation loss: {final_val_loss:.2f}'
        )
    else:
        final_val_loss = None
        final_loss = None
    data = {
        'created_at': datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        'run_name': run_name,
        'final_loss': final_loss,
        'final_val_loss': final_val_loss,
        'max_seq_length': max_seq_length,
        'num_train_steps': steps_per_epoch * args.num_epochs,
        'eval_steps': eval_steps,
        'steps_per_epoch': steps_per_epoch,
        'training_time_min': training_time_min,
        'data_dir': data_dir,
        'output_dir': output_dir,
        'all_scores': all_scores,
        'all_predictions': all_predictions,
        'num_labels': num_labels,
        'label_mapping': label_mapping,
        **full_history,
        **final_scores,
        **vars(args),
    }
    # Write run_log
    f_path_training_log = os.path.join(output_dir, 'run_logs.json')
    logger.info(f'Writing training log to {f_path_training_log}...')
    save_to_json(data, f_path_training_log)
    # Write bert config
    model_config.id2label = label_mapping
    model_config.label2id = {v: k for k, v in label_mapping.items()}
    model_config.max_seq_length = max_seq_length
    model_config.num_labels = num_labels
    f_path_bert_config = os.path.join(output_dir, 'bert_config.json')
    logger.info(f'Writing BERT config to {f_path_bert_config}...')
    save_to_json(model_config.to_dict(), f_path_bert_config)