def pytorch_mnist_task(hp: Hyperparameters) -> TrainingOutputs:
    print("Start MNIST training:")

    world_size = torch.cuda.device_count()
    print(f"Device count: {world_size}")
    download_mnist(DATA_DIR)

    # spawn one training process per available GPU; each process receives its rank as the
    # implicit first argument, followed by ``world_size`` and the hyperparameters
    mp.spawn(
        train_mnist,
        args=(world_size, hp),
        nprocs=world_size,
        join=True,
    )
    print("Training Complete")

    # the spawned processes persist the per-epoch accuracies to disk; read them back
    # so they can be returned from the task
    with open(ACCURACIES_FILE) as fp:
        accuracies = json.load(fp)

    return TrainingOutputs(epoch_accuracies=accuracies, model_state=PythonPickledFile(MODEL_FILE))
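# NOTE: illustrative sketch only. ``train_mnist`` is defined elsewhere in this example; the body
# below is an assumption of what a per-process worker for ``mp.spawn`` could look like
# (``mp.spawn`` passes the process rank as the implicit first argument). It reuses the helpers
# assumed to exist at module level (``Net``, ``train``, ``test``, ``mnist_dataloader``) and the
# ``ACCURACIES_FILE``/``MODEL_FILE`` constants that ``pytorch_mnist_task`` reads back; the
# example's real implementation may differ.
def train_mnist_sketch(rank: int, world_size: int, hp: Hyperparameters):
    import os

    # every process must agree on the rendezvous address before joining the process group
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    # pin this process to its GPU and wrap the model so gradients are synchronized across ranks
    torch.cuda.set_device(rank)
    device = torch.device("cuda", rank)
    model = nn.parallel.DistributedDataParallel(Net().to(device), device_ids=[rank])
    optimizer = optim.SGD(model.parameters(), lr=hp.learning_rate, momentum=hp.sgd_momentum)

    accuracies = []
    for epoch in range(1, hp.epochs + 1):
        train(model, device, mnist_dataloader(hp.batch_size, train=True), optimizer, epoch, hp.log_interval)
        accuracies.append(test(model, device, mnist_dataloader(hp.batch_size, train=False)))

    # only rank 0 writes the artifacts that the task reads back after ``mp.spawn`` returns
    if rank == 0:
        torch.save(model.module.state_dict(), MODEL_FILE)
        with open(ACCURACIES_FILE, "w") as fp:
            json.dump(accuracies, fp)

    dist.destroy_process_group()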
def pytorch_mnist_task(hp: Hyperparameters) -> TrainingOutputs:
    wandb_setup()

    # store the hyperparameters' config in ``wandb``
    wandb.config.update(json.loads(hp.to_json()))

    # set the random seed for reproducibility
    torch.manual_seed(hp.seed)

    # ideally, we would raise an exception when GPU training is requested but CUDA is unavailable;
    # however, since this example should also run locally (where most users don't have a GPU),
    # we fall back to the CPU instead
    use_cuda = torch.cuda.is_available()
    print(f"Use cuda {use_cuda}")
    device = torch.device("cuda" if use_cuda else "cpu")

    # load the training and test data
    kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}
    training_data_loader = mnist_dataloader(hp.batch_size, train=True, **kwargs)
    test_data_loader = mnist_dataloader(hp.batch_size, train=False, **kwargs)

    # instantiate the model and optimizer
    model = Net().to(device)
    optimizer = optim.SGD(model.parameters(), lr=hp.learning_rate, momentum=hp.sgd_momentum)

    # train the model: run multiple epochs and capture the accuracy for each epoch
    accuracies = []
    for epoch in range(1, hp.epochs + 1):
        train(model, device, training_data_loader, optimizer, epoch, hp.log_interval)
        accuracies.append(test(model, device, test_data_loader))

    # after training, save the model to disk and return it from the Flyte task as a
    # :py:class:`flytekit.types.file.FlyteFile`. ``PythonPickledFile`` is a ``FlyteFile``
    # variant that records the format of the serialized model as ``pickled``
    model_file = "mnist_cnn.pt"
    torch.save(model.state_dict(), model_file)

    return TrainingOutputs(epoch_accuracies=accuracies, model_state=PythonPickledFile(model_file))
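# ``mnist_dataloader`` is called above but not defined in this section. A plausible
# implementation, mirroring the inline ``DataLoader`` setup in the distributed variant below and
# assuming the torchvision ``datasets``/``transforms`` imports used there, could look like this
# sketch (the name and signature come from the call sites; the body is an assumption).
def mnist_dataloader(batch_size: int, train: bool = True, **kwargs) -> torch.utils.data.DataLoader:
    return torch.utils.data.DataLoader(
        datasets.MNIST(
            "../data",
            train=train,
            download=True,
            transform=transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
            ),
        ),
        batch_size=batch_size,
        # shuffle only the training split
        shuffle=train,
        **kwargs,
    )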
def mnist_pytorch_job(hp: Hyperparameters) -> TrainingOutputs:
    log_dir = "logs"
    writer = SummaryWriter(log_dir)

    torch.manual_seed(hp.seed)

    use_cuda = torch.cuda.is_available()
    print(f"Use cuda {use_cuda}")
    device = torch.device("cuda" if use_cuda else "cpu")
    print("Using device: {}, world size: {}".format(device, WORLD_SIZE))

    if should_distribute():
        print("Using distributed PyTorch with {} backend".format(hp.backend))
        dist.init_process_group(backend=hp.backend)

    # Load data
    kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            "../data",
            train=True,
            download=True,
            transform=transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
            ),
        ),
        batch_size=hp.batch_size,
        shuffle=True,
        **kwargs,
    )
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            "../data",
            train=False,
            transform=transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
            ),
        ),
        batch_size=hp.test_batch_size,
        shuffle=False,
        **kwargs,
    )

    # Train the model
    model = Net().to(device)

    if is_distributed():
        Distributor = (
            nn.parallel.DistributedDataParallel
            if use_cuda
            else nn.parallel.DistributedDataParallelCPU
        )
        model = Distributor(model)

    optimizer = optim.SGD(model.parameters(), lr=hp.learning_rate, momentum=hp.sgd_momentum)

    accuracies = [
        epoch_step(
            model,
            device,
            train_loader,
            test_loader,
            optimizer,
            epoch,
            writer,
            hp.log_interval,
        )
        for epoch in range(1, hp.epochs + 1)
    ]

    # Save the model
    model_file = "mnist_cnn.pt"
    torch.save(model.state_dict(), model_file)

    return TrainingOutputs(
        epoch_accuracies=accuracies,
        model_state=PythonPickledFile(model_file),
        logs=TensorboardLogs(log_dir),
    )
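# ``should_distribute`` and ``is_distributed`` are referenced above but not defined in this
# section. A minimal sketch of what they could look like, assuming ``WORLD_SIZE`` is the
# module-level constant already used above and is read from the environment set by the PyTorch
# operator (an assumption; the example's actual helpers may differ):
import os

WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1))


def should_distribute() -> bool:
    # only initialize the process group when more than one worker participates
    return dist.is_available() and WORLD_SIZE > 1


def is_distributed() -> bool:
    # True once ``dist.init_process_group`` has been called
    return dist.is_available() and dist.is_initialized()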