Ejemplo n.º 1
0
def fcnn_train(train_input, train_output, *, epochs=5, max_lr=1e-2):
    """Train a ``FullyConnectedNN`` classifier with the one-cycle schedule.

    NOTE(review): ``train_input`` and ``train_output`` are not used. The
    learner is built from a name ``data`` that is not defined in this
    function; presumably it is a module-level DataBunch. Confirm this, and
    consider building the data from the arguments instead.

    Args:
        train_input: Currently unused (see note above).
        train_output: Currently unused (see note above).
        epochs: Cycle length (number of epochs) passed to ``fit_one_cycle``.
        max_lr: Maximum learning rate passed to ``fit_one_cycle``.

    Returns:
        The trained ``Learner``, so callers can evaluate or predict with it.
    """
    fcnn_learner = Learner(
        data=data,
        model=FullyConnectedNN(),
        loss_func=nn.CrossEntropyLoss(),
        metrics=accuracy,
    )
    fcnn_learner.fit_one_cycle(epochs, max_lr)
    return fcnn_learner
Ejemplo n.º 2
0
def _make_array_dataset(inputs, outputs):
    """Wrap a numpy feature array and label array in an ``ArrayDataset``.

    Labels are taken from column index 1 of ``outputs`` and converted to
    int64 tensors, the target dtype ``nn.CrossEntropyLoss`` requires for
    class-index targets.
    """
    # NOTE(review): features keep their numpy dtype. A float64 array becomes
    # a double tensor, which may not match float32 model weights. Confirm
    # that FullyConnectedNN handles this.
    features = torch.from_numpy(inputs)
    # ``astype(int)`` alone yields int32 on some platforms (e.g. Windows with
    # NumPy < 2); ``.long()`` guarantees int64 everywhere.
    labels = torch.from_numpy(outputs[:, 1].astype(int)).long()
    return ArrayDataset(features, labels)


def fcnn_train(train_input, train_output, test_input, test_output,
               *, bs=10, epochs=50, max_lr=1e-2):
    """Train a ``FullyConnectedNN`` classifier on numpy train/test arrays.

    Args:
        train_input: Training features, converted with ``torch.from_numpy``.
        train_output: Training outputs; column index 1 is used as the label.
        test_input: Test features, used as the validation set.
        test_output: Test outputs; column index 1 is used as the label.
        bs: Batch size for the DataBunch.
        epochs: Cycle length (number of epochs) passed to ``fit_one_cycle``.
        max_lr: Maximum learning rate passed to ``fit_one_cycle``.

    Returns:
        The trained ``Learner``.
    """
    train_ds = _make_array_dataset(train_input, train_output)
    test_ds = _make_array_dataset(test_input, test_output)
    # The test split doubles as the validation set, so the reported
    # ``accuracy`` metric is measured on it after each epoch.
    databunch = DataBunch.create(train_ds, test_ds, bs=bs)

    fcnn_learner = Learner(
        data=databunch,
        model=FullyConnectedNN(),
        loss_func=nn.CrossEntropyLoss(),
        metrics=accuracy,
    )
    fcnn_learner.fit_one_cycle(epochs, max_lr)
    return fcnn_learner