Ejemplo n.º 1
0
def load_base_model_if_needed(learner: Learner,
                              lm_training_config: LMTrainingConfig,
                              model_file='best') -> None:
    """Load pretrained weights into ``learner`` when a base model is configured.

    When ``lm_training_config.base_model`` is truthy, the checkpoint named
    ``model_file`` inside that directory is loaded into ``learner`` in place.
    Otherwise nothing is loaded and a message is logged that training starts
    from scratch.

    Args:
        learner: Object whose ``load`` method receives the checkpoint path.
        lm_training_config: Configuration exposing ``base_model`` (directory
            of the pretrained model, or a falsy value for none).
        model_file: Checkpoint name inside ``base_model``; passed without an
            extension (the log message shows it with ``.pth`` appended).
    """
    if not lm_training_config.base_model:
        logger.info("Training from scratch")
        return

    model = os.path.join(lm_training_config.base_model, model_file)
    logger.info(f"Using pretrained model: {model}.pth")
    # Keep purge=False explicitly: the original note reports a pickle
    # serialization error tied to purging on load.
    # NOTE(review): the earlier comment wording ("not setting purge to True")
    # contradicted this call -- confirm against the fastai version in use.
    learner.load(model, purge=False)
            bert_train.iloc[val_idx, :],
            bert_test,
            tokenizer=fastai_tokenizer,
            vocab=fastai_bert_vocab,
            include_bos=False,
            include_eos=False,
            text_cols='comment_text',
            label_cols=label_cols,
            bs=BATCH_SIZE,
            collate_fn=partial(pad_collate, pad_first=False, pad_idx=0),
        )

        # Train one fold: build the learner, optionally resume from an earlier
        # checkpoint, fit, then record validation and test predictions.
        # NOTE(review): the enclosing loop header and the databunch creation
        # start above this fragment and are not visible here.
        learner = Learner(databunch, bert_model, loss_func=bert_custom_loss)
        # For every step after the first, load weights from a Kaggle input
        # path built from the previous step number (CUR_STEP - 1) and
        # MAKE_FOLD -- presumably the checkpoint saved by that earlier run.
        if CUR_STEP != 1:
            learner.load('/kaggle/input/freeze-bert-1-s-uc-260ml-3e-8f-s-' +
                         str(CUR_STEP - 1) + '-f-' + str(MAKE_FOLD) +
                         '/models/' + FILE_NAME)

        learner.fit_one_cycle(N_EPOCH, max_lr=MAX_LR)

        # Store this fold's validation predictions (as float32) in the rows of
        # `oof` selected by val_idx.
        oof[val_idx] = get_preds_as_nparray(DatasetType.Valid).astype(
            np.float32)
        # Add this fold's test predictions scaled by 1/NFOLDS; presumably this
        # accumulates an average over all folds -- confirm against the loop.
        predictions += get_preds_as_nparray(DatasetType.Test).astype(
            np.float32) / NFOLDS

        # Evaluate the first prediction column on this fold's validation rows.
        validate_df(train.iloc[val_idx], oof[val_idx, 0], verbose=True)

        # Save the trained learner under FILE_NAME.
        learner.save(FILE_NAME)

# Evaluate the first column of the predictions collected in `oof` against
# the full training frame.
# NOTE(review): why the printed label calls this AUC "biased" is not
# explained in the visible code.
print('CV BIASED AUC:')
validate_df(train, oof[:, 0], verbose=True)