Ejemplo n.º 1
0
def train():
    '''
    Fit a SoftmaxClassifier on CIFAR-10, evaluating after every epoch,
    and save the resulting model.
    '''
    # Data: fetch the CIFAR-10 splits, then wrap them into batch loaders.
    cifar_train, cifar_test = dp.load_cifar10()
    train_batches, test_batches = dp.batch_data(cifar_train, cifar_test)

    # Model: build the classifier and let set_device place it
    # (the second return value of set_device is not needed here).
    model, _ = set_device(SoftmaxClassifier())

    # Optimisation setup: loss, optimizer and the number of epochs to run.
    loss_fn, optim, n_epochs = set_optimization(model)

    # One training pass followed by one evaluation pass per epoch.
    for epoch_idx in range(n_epochs):
        train_model(model, train_batches, loss_fn, optim, epoch_idx)
        test_model(model, test_batches, epoch_idx)

    print('Finished Training')

    # Persist the trained model.
    save_model(model)
def train(dataset=dataset):
    '''
    Applies the train_model and test_model functions at each epoch,
    then saves the model.

    Parameters:
        dataset: name of the dataset to train on. Supported values are
            "cifar10" and "mnist" (loaded through ``dp`` and batched with
            ``dp.generate_batches``) and the meta-learning datasets
            "miniImageNet", "tieredImageNet", "CIFARFS", "FC100" and
            "Omniglot" (batches taken from ``data_generator``).

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
            Without this check an unknown name would skip every training
            loop and still save an untrained model.
    '''
    meta_datasets = (
        "miniImageNet", "tieredImageNet", "CIFARFS", "FC100", "Omniglot"
    )

    # This loads the dataset and partitions it into batches. The branches
    # are mutually exclusive, so exactly one pair of loaders is produced.
    if dataset == "cifar10":
        trainset, testset = dp.load_cifar10()
        train_data, test_data = dp.generate_batches(trainset, testset)
    elif dataset == "mnist":
        trainset, testset = dp.load_mnist()
        train_data, test_data = dp.generate_batches(trainset, testset)
    elif dataset in meta_datasets:
        train_data = data_generator.generate_batch(test=False)
        test_data = data_generator.generate_batch(test=True)
    else:
        # Fail before building the model so a typo in the dataset name
        # cannot end in "Finished Training" with nothing trained.
        supported = ("cifar10", "mnist") + meta_datasets
        raise ValueError(
            "Unsupported dataset {!r}; expected one of: {}".format(
                dataset, ", ".join(supported)
            )
        )

    # Loads the model and the training/testing functions:
    net = ResNet18(dataset)
    net, device = set_device(net)
    svi, epochs = inference()

    # Run one training pass and one evaluation pass per epoch; the same
    # loop serves both the standard and the meta-learning loaders.
    for epoch in range(epochs):
        train_model(train_data, svi, epoch, device)
        test_model(test_data, epoch, device)

    print('Finished Training')
    # Save the model:
    save_model(dataset, net)