def pytorch_mnist_task(hp: Hyperparameters) -> TrainingOutputs:
    """Run MNIST training with one spawned worker process per visible CUDA device.

    ``torch.multiprocessing.spawn`` invokes ``train_mnist(rank, world_size, hp)``
    in each worker. Because ``join=True``, it blocks until every worker has
    finished and re-raises a worker failure. Afterwards the per-epoch
    accuracies are read back from ``ACCURACIES_FILE``.

    Args:
        hp: training hyperparameters, forwarded unchanged to every worker.

    Returns:
        ``TrainingOutputs`` with the accuracies loaded from ``ACCURACIES_FILE``
        and the model artifact at ``MODEL_FILE`` wrapped as a
        ``PythonPickledFile``.

    Raises:
        RuntimeError: if no CUDA device is visible, because spawning zero
            workers would train nothing.
    """
    print("Start MNIST training:")

    world_size = torch.cuda.device_count()
    print(f"Device count: {world_size}")
    if world_size < 1:
        # With nprocs=0, mp.spawn starts no workers and returns immediately.
        # The accuracies file read below would then be missing, or worse,
        # left over from an earlier run and silently returned.
        raise RuntimeError(
            "pytorch_mnist_task requires at least one CUDA device, "
            "but torch.cuda.device_count() returned 0."
        )

    download_mnist(DATA_DIR)
    mp.spawn(
        train_mnist,
        args=(world_size, hp),
        nprocs=world_size,
        join=True,
    )
    print("Training Complete")

    # NOTE(review): ACCURACIES_FILE and MODEL_FILE are presumably written by
    # train_mnist inside the workers (not visible here) — confirm.
    with open(ACCURACIES_FILE, encoding="utf-8") as fp:
        accuracies = json.load(fp)
    return TrainingOutputs(epoch_accuracies=accuracies,
                           model_state=PythonPickledFile(MODEL_FILE))
Example #2
0
def pytorch_mnist_task(hp: Hyperparameters) -> TrainingOutputs:
    """Train the MNIST ``Net`` on a single device and log the run to ``wandb``.

    The hyperparameters are stored in the ``wandb`` run config. The model is
    trained with SGD for ``hp.epochs`` epochs, and the value returned by
    ``test`` after each epoch is collected.

    Args:
        hp: hyperparameters (seed, batch size, learning rate, momentum,
            epoch count, log interval).

    Returns:
        ``TrainingOutputs`` holding the per-epoch accuracies and the trained
        model's ``state_dict`` saved to ``mnist_cnn.pt``. The file is wrapped
        in ``PythonPickledFile``, a ``FlyteFile`` variant that records the
        serialization format as ``pickled``.
    """
    wandb_setup()
    # Record the full hyperparameter set in the wandb run config.
    wandb.config.update(json.loads(hp.to_json()))

    # Seed before building the model so weight initialization is reproducible.
    torch.manual_seed(hp.seed)

    # A GPU is preferred, but we deliberately fall back to the CPU instead of
    # raising so the task also runs on local machines without CUDA.
    use_cuda = torch.cuda.is_available()
    print(f"Use cuda {use_cuda}")
    if use_cuda:
        device = torch.device("cuda")
        loader_opts = {"num_workers": 1, "pin_memory": True}
    else:
        device = torch.device("cpu")
        loader_opts = {}

    train_loader = mnist_dataloader(hp.batch_size, train=True, **loader_opts)
    test_loader = mnist_dataloader(hp.batch_size, train=False, **loader_opts)

    model = Net().to(device)
    optimizer = optim.SGD(
        model.parameters(),
        lr=hp.learning_rate,
        momentum=hp.sgd_momentum,
    )

    # One training pass per epoch, followed by an evaluation whose result is kept.
    accuracies = []
    for epoch in range(1, hp.epochs + 1):
        train(model, device, train_loader, optimizer, epoch, hp.log_interval)
        epoch_accuracy = test(model, device, test_loader)
        accuracies.append(epoch_accuracy)

    checkpoint_path = "mnist_cnn.pt"
    torch.save(model.state_dict(), checkpoint_path)

    return TrainingOutputs(
        epoch_accuracies=accuracies,
        model_state=PythonPickledFile(checkpoint_path),
    )
Example #3
0
def _mnist_data_loaders(hp: Hyperparameters, use_cuda: bool):
    """Return ``(train_loader, test_loader)`` for MNIST stored under ``../data``."""
    kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}
    # Both splits share the same tensor conversion and normalization.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307, ), (0.3081, )),
    ])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST("../data", train=True, download=True, transform=transform),
        batch_size=hp.batch_size,
        shuffle=True,
        **kwargs,
    )
    test_loader = torch.utils.data.DataLoader(
        # download=True is a no-op once the files exist; it keeps the test split
        # loadable even if it is fetched independently of the training split.
        datasets.MNIST("../data", train=False, download=True, transform=transform),
        batch_size=hp.test_batch_size,
        shuffle=False,
        **kwargs,
    )
    return train_loader, test_loader


def mnist_pytorch_job(hp: Hyperparameters) -> TrainingOutputs:
    """Train the MNIST ``Net``, optionally with distributed data parallelism.

    When ``should_distribute()`` is true, a process group is initialized with
    ``hp.backend``. The model is then wrapped in ``DistributedDataParallel``
    whenever ``is_distributed()`` reports an active group. Per-epoch results
    from ``epoch_step`` are collected, and TensorBoard events are written to
    ``logs``.

    Args:
        hp: hyperparameters (seed, backend, batch sizes, learning rate,
            momentum, epoch count, log interval).

    Returns:
        ``TrainingOutputs`` with the per-epoch accuracies, the unwrapped
        model's ``state_dict`` saved to ``mnist_cnn.pt``, and the TensorBoard
        log directory.
    """
    log_dir = "logs"

    torch.manual_seed(hp.seed)

    use_cuda = torch.cuda.is_available()
    print(f"Use cuda {use_cuda}")
    device = torch.device("cuda" if use_cuda else "cpu")

    print("Using device: {}, world size: {}".format(device, WORLD_SIZE))

    if should_distribute():
        print("Using distributed PyTorch with {} backend".format(hp.backend))
        dist.init_process_group(backend=hp.backend)

    # NOTE(review): in distributed mode every rank iterates the full training
    # set (no DistributedSampler) — confirm this is intended.
    train_loader, test_loader = _mnist_data_loaders(hp, use_cuda)

    # Keep a handle to the bare network so the saved state_dict has the same
    # keys (no "module." prefix) whether or not DDP wrapping is applied.
    net = Net().to(device)
    model = net
    if is_distributed():
        # DistributedDataParallelCPU was deprecated and removed from PyTorch;
        # DistributedDataParallel handles both CUDA and CPU (e.g. gloo) modules.
        model = nn.parallel.DistributedDataParallel(net)

    optimizer = optim.SGD(model.parameters(),
                          lr=hp.learning_rate,
                          momentum=hp.sgd_momentum)

    writer = SummaryWriter(log_dir)
    try:
        accuracies = [
            epoch_step(
                model,
                device,
                train_loader,
                test_loader,
                optimizer,
                epoch,
                writer,
                hp.log_interval,
            ) for epoch in range(1, hp.epochs + 1)
        ]
    finally:
        # Flush and close the event file so the returned log dir is complete.
        writer.close()

    model_file = "mnist_cnn.pt"
    torch.save(net.state_dict(), model_file)

    return TrainingOutputs(
        epoch_accuracies=accuracies,
        model_state=PythonPickledFile(model_file),
        logs=TensorboardLogs(log_dir),
    )