def main(args):
    """Run the JointTraining strategy on a 5-experience PermutedMNIST benchmark.

    Builds a ``SimpleMLP`` with 10 output classes, trains it once on the
    whole training stream and then evaluates it on the whole test stream.

    Args:
        args: Object exposing an integer ``cuda`` attribute. When CUDA is
            available and ``args.cuda >= 0`` the device ``cuda:<args.cuda>``
            is used; otherwise everything runs on the CPU.
    """
    # Pick the compute device: the requested GPU if possible, else the CPU.
    use_gpu = torch.cuda.is_available() and args.cuda >= 0
    device = torch.device(f"cuda:{args.cuda}" if use_gpu else "cpu")

    # Network to be trained.
    model = SimpleMLP(num_classes=10)

    # Benchmark and its train / test streams.
    benchmark = PermutedMNIST(n_experiences=5)
    train_stream, test_stream = benchmark.train_stream, benchmark.test_stream

    # Optimisation setup.
    optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9)
    criterion = CrossEntropyLoss()

    # Strategy that receives the whole training stream in a single call.
    strategy_options = {
        "train_mb_size": 32,
        "train_epochs": 1,
        "eval_mb_size": 32,
        "device": device,
    }
    joint_train = JointTraining(model, optimizer, criterion, **strategy_options)

    # Train on the full stream, then collect the evaluation output.
    print("Starting training.")
    joint_train.train(train_stream)
    results = [joint_train.eval(test_stream)]
# Example no. 2
# 0
def main(args):
    """Train a Naive strategy with early stopping on an MNIST benchmark.

    Builds a ``SimpleMLP`` with 10 output classes, creates one of the
    5-experience MNIST benchmarks, and then, for each pair of train/test
    experiences, trains on the train experience (passing the paired test
    experience as an evaluation stream) and evaluates on the full test
    stream.

    Args:
        args: Object exposing
            ``cuda`` - integer GPU index; the device ``cuda:<args.cuda>`` is
            used when CUDA is available and the index is non-negative,
            otherwise the CPU is used;
            ``mnist_type`` - ``"permuted"`` or ``"rotated"`` pick the
            matching benchmark, any other value selects ``SplitMNIST``;
            ``patience`` - forwarded as the first argument of
            ``EarlyStoppingPlugin``.
    """
    # Device config: requested GPU when usable, CPU otherwise.
    gpu_usable = torch.cuda.is_available() and args.cuda >= 0
    if gpu_usable:
        device = torch.device(f"cuda:{args.cuda}")
    else:
        device = torch.device("cpu")

    # Network to be trained.
    model = SimpleMLP(num_classes=10)

    # All the MNIST variations offered by the "classic" benchmarks.
    mnist_type = args.mnist_type
    if mnist_type == "permuted":
        scenario = PermutedMNIST(n_experiences=5, seed=1)
    elif mnist_type == "rotated":
        angles = [30, 60, 90, 120, 150]
        scenario = RotatedMNIST(n_experiences=5, rotations_list=angles, seed=1)
    else:
        scenario = SplitMNIST(n_experiences=5, seed=1)

    # Then we can extract the parallel train and test streams.
    train_stream, test_stream = scenario.train_stream, scenario.test_stream

    # Optimisation setup.
    optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9)
    criterion = CrossEntropyLoss()

    # Continual learning strategy (default logger) with an early-stopping
    # plugin configured with the stream name "test_stream".
    early_stopping = EarlyStoppingPlugin(args.patience, "test_stream")
    cl_strategy = Naive(
        model,
        optimizer,
        criterion,
        train_mb_size=32,
        train_epochs=100,
        eval_mb_size=32,
        device=device,
        eval_every=1,
        plugins=[early_stopping],
    )

    # Walk the two streams in lockstep: train on one experience, then
    # evaluate on the complete test stream.
    results = []
    for train_exp, test_exp in zip(train_stream, test_stream):
        print("Current Classes: ", train_exp.classes_in_this_experience)
        cl_strategy.train(train_exp, eval_streams=[test_exp])
        outcome = cl_strategy.eval(test_stream)
        results.append(outcome)