def baselineNN_search(parameters):
    """Train and evaluate one baseline-NN hyperparameter draw.

    Builds a skorch NeuralNet from the sampled ``parameters``, fits it with
    5-fold internal validation, appends the best-epoch metrics to the output
    CSV, and returns the minimum validation loss for the search driver.

    :param parameters: dict with keys ``hidden_units``, ``dropout``,
        ``use_batch_norm``, ``learning_rate`` and ``optimizer``
    :return: dict holding the best validation loss, the parameters used,
        the epoch index of that best loss, and the hyperopt status flag
    """
    # Resolve the optimizer class from the sampled setting
    optimizer_cls = (optim.Adam if parameters["optimizer"] == "Adam"
                     else optim.RMSprop)
    # One scoring callback per tracked validation metric (order preserved)
    metric_callbacks = [
        skorch.callbacks.EpochScoring(scorer, use_caching=True, name=label)
        for scorer, label in [
            (f1_score, "valid_f1"),
            (precision_score, "valid_precision"),
            (recall_score, "valid_recall"),
            (accuracy_score, "valid_accuracy"),
        ]
    ]
    net = NeuralNet(
        module=BaselineNN,
        module__hidden_dim=parameters["hidden_units"],
        module__p_dropout=parameters["dropout"],
        module__use_batch_norm=parameters["use_batch_norm"],
        module__weights=FTEMB,  # pre-trained word embeddings
        module__num_classes=len(category_map),
        max_epochs=25,
        lr=parameters["learning_rate"],
        optimizer=optimizer_cls,
        criterion=nn.CrossEntropyLoss,
        criterion__weight=cw,  # class weights for the loss
        iterator_train__shuffle=True,  # reshuffle training data each epoch
        batch_size=128,
        train_split=CVSplit(cv=5),  # 5-fold internal validation split
        device=device,
        callbacks=metric_callbacks)
    # Epoch-level progress output enabled (1, not 0)
    net.verbose = 1
    # Fit; skorch's fit() returns the net itself
    net = net.fit(WD)
    # Per-epoch histories recorded during training
    history = net.history
    train_loss = history[:, "train_loss"]
    val_loss = history[:, "valid_loss"]
    val_accuracy = history[:, "valid_accuracy"]
    val_f1 = history[:, "valid_f1"]
    val_precision = history[:, "valid_precision"]
    val_recall = history[:, "valid_recall"]
    # Epoch with the lowest validation loss
    which_min = np.argmin(val_loss)
    # Append one summary row for this hyperparameter draw
    summary_row = [
        parameters, which_min,
        np.round(train_loss[which_min], 4),
        np.round(val_accuracy[which_min], 4),
        np.round(val_loss[which_min], 4),
        np.round(val_f1[which_min], 4),
        np.round(val_precision[which_min], 4),
        np.round(val_recall[which_min], 4),
    ]
    with open(args.out_file, 'a') as of_connection:
        csv.writer(of_connection).writerow(summary_row)
    # Hyperopt expects the loss to minimise plus bookkeeping fields
    return {
        "loss": val_loss[which_min],
        "parameters": parameters,
        "iteration": which_min,
        'status': STATUS_OK,
    }
        skorch.callbacks.EpochScoring(f1_score,
                                      use_caching=True,
                                      name="valid_f1"),
        skorch.callbacks.EpochScoring(precision_score,
                                      use_caching=True,
                                      name="valid_precision"),
        skorch.callbacks.EpochScoring(recall_score,
                                      use_caching=True,
                                      name="valid_recall"),
        skorch.callbacks.EpochScoring(accuracy_score,
                                      use_caching=True,
                                      name="valid_accuracy")
    ])

# Enable epoch-level progress output (1 = verbose on; the old comment
# saying "false" was inaccurate)
net.verbose = 1

#%% Fit the model

# skorch's fit() returns the net itself. NOTE(review): the name `io`
# shadows the stdlib `io` module — confirm nothing later relies on it.
io = net.fit(WD)

# Persist the learned parameters to disk
net.save_params(f_params='models/baselineNN.pkl')

#%% Or load it from disk

# Alternative entry point: re-initialize the module and restore saved
# weights — this overwrites whatever was just fitted above
net.initialize()
net.load_params(f_params="models/baselineNN.pkl")

#%% Predict on train