コード例 #1
0
def test_save_load_model():
    """Test that a fitted model scores identically after a save/load round trip.

    Fits a ``BertClassifier`` for one epoch on the data returned by
    ``sst2_test_data()``, scores it on the dev split, saves it to
    ``./test_model_save.bin``, reloads it with ``load_model`` and scores the
    reloaded model on the same dev split. The two scores must be equal.

    The saved file is removed in a ``finally`` clause so that a failure while
    saving, loading or scoring does not leave the file behind.
    """
    X_train, y_train, X_dev, y_dev = sst2_test_data()

    # Override the training settings on the instance before fitting.
    model = BertClassifier()
    model.max_seq_length = 64
    model.train_batch_size = 8
    model.epochs = 1

    model.fit(X_train, y_train)

    accy1 = model.score(X_dev, y_dev)

    savefile = './test_model_save.bin'
    print("\nSaving model to ", savefile)

    try:
        model.save(savefile)

        # load model from disk
        new_model = load_model(savefile)

        # predict with new model
        accy2 = new_model.score(X_dev, y_dev)
    finally:
        # Clean up even when save/load/score raised. The existence check keeps
        # a failed save (no file written) from being masked by a
        # FileNotFoundError raised here.
        if os.path.exists(savefile):
            print("Cleaning up model file: test_model_save.bin ")
            os.remove(savefile)

    # Compared only after cleanup so the file is gone whether or not this fails.
    assert accy1 == accy2
コード例 #2
0
def train_model(train,
                model_file_to_save,
                epochs=3,
                val_frac=0.1,
                class_weight=None,
                *,
                max_seq_length=128,
                train_batch_size=32,
                learning_rate=2e-5):
    """Fit a ``BertClassifier`` on ``train`` and save it to disk.

    Args:
        train: Object indexable by the keys ``'sentence'`` (model inputs) and
            ``'label'`` (targets) -- presumably a pandas DataFrame; verify
            against the caller.
        model_file_to_save: Destination passed to ``model.save``.
        epochs: Forwarded to ``BertClassifier`` as ``epochs``.
        val_frac: Forwarded to ``BertClassifier`` as ``validation_fraction``.
        class_weight: Forwarded to ``BertClassifier`` as ``class_weight``.
        max_seq_length: Forwarded to ``BertClassifier`` as ``max_seq_length``.
            Keyword-only; the default matches the previously hard-coded 128.
        train_batch_size: Forwarded to ``BertClassifier`` as
            ``train_batch_size``. Keyword-only; the default matches the
            previously hard-coded 32.
        learning_rate: Forwarded to ``BertClassifier`` as ``learning_rate``.
            Keyword-only; the default matches the previously hard-coded 2e-5.

    Returns:
        The ``BertClassifier`` instance after ``fit`` and ``save`` were called.
    """
    X_train = train['sentence']
    y_train = train['label']

    # BERT_MODEL and RANDOM_STATE are module-level settings defined outside
    # this function.
    model = BertClassifier(
        bert_model=BERT_MODEL,
        random_state=RANDOM_STATE,
        class_weight=class_weight,
        max_seq_length=max_seq_length,
        train_batch_size=train_batch_size,
        learning_rate=learning_rate,
        epochs=epochs,
        validation_fraction=val_frac,
    )
    print(model)
    model.fit(X_train, y_train)
    model.save(model_file_to_save)
    print(f'\n- model saved to: {model_file_to_save}\n')
    return model