def get_model(model_checkpoint_path):
    """Rebuild the single-channel ResNet18 from a checkpoint on disk.

    Loads the checkpoint dict saved by training, recreates the model with the
    same conv-stem modification used at train time, and restores its weights.

    Args:
        model_checkpoint_path: Path to a checkpoint readable by
            ``Trainer.load_checkpoint_from_path``; must contain a
            ``"model_state_dict"`` entry.

    Returns:
        The restored model, ready for inference.
    """
    checkpoint = Trainer.load_checkpoint_from_path(model_checkpoint_path)
    weights = checkpoint["model_state_dict"]

    net = ResNet18(None)
    # The stem must match training: a 1-input-channel conv (grayscale input)
    # replaces the stock 3-channel layer, otherwise load_state_dict fails.
    net.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=1, padding=3, bias=False)
    net.load_state_dict(weights)
    return net
def training_loop(config):
    """Ray Train worker loop for MNIST: train 2 epochs, checkpoint, report.

    Args:
        config: Dict with ``"batch_size"`` (required) and optional ``"lr"``
            (default 0.1), ``"momentum"`` (default 0.9), and ``"test_mode"``
            (truthy -> train/validate on a 64-sample subset).

    Side effects:
        Saves a model checkpoint and reports validation metrics via
        ``train.save_checkpoint`` / ``train.report`` each epoch.
    """
    # Create model; swap the conv stem to accept 1-channel MNIST images
    # before prepare_model moves/wraps it for distributed execution.
    model = ResNet18(config)
    model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=1, padding=3, bias=False)
    model = train.torch.prepare_model(model)

    # Create optimizer.
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=config.get("lr", 0.1),
        momentum=config.get("momentum", 0.9),
    )

    # Load in training and validation data.
    train_dataset = load_mnist_data(True, True)
    validation_dataset = load_mnist_data(False, False)

    # FIX: use .get() so a missing "test_mode" key means "disabled" instead of
    # raising KeyError — consistent with train_func in this file, which
    # already reads config.get("test_mode").
    if config.get("test_mode"):
        train_dataset = Subset(train_dataset, list(range(64)))
        validation_dataset = Subset(validation_dataset, list(range(64)))

    train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], num_workers=2)
    validation_loader = DataLoader(validation_dataset, batch_size=config["batch_size"], num_workers=2)
    train_loader = train.torch.prepare_data_loader(train_loader)
    validation_loader = train.torch.prepare_data_loader(validation_loader)

    # Create loss.
    criterion = nn.CrossEntropyLoss()

    for epoch_idx in range(2):
        train_epoch(train_loader, model, criterion, optimizer)
        validation_loss = validate_epoch(validation_loader, model, criterion)
        # FIX: prepare_model only wraps the model in DistributedDataParallel
        # when actually running distributed, so `.module` is not guaranteed to
        # exist; unwrap defensively instead of assuming the DDP wrapper.
        raw_state = (
            model.module.state_dict() if hasattr(model, "module") else model.state_dict()
        )
        train.save_checkpoint(model_state_dict=raw_state)
        train.report(**validation_loss)
def train_func(config):
    """Per-worker CIFAR-10 training loop for Ray Train.

    Args:
        config: Dict with ``"batch_size"`` (global; required) and optional
            ``"epochs"`` (default 3, popped), ``"lr"`` (default 0.1),
            ``"momentum"`` (default 0.9), ``"test_mode"`` (truthy -> use a
            64-sample subset of each split).

    Returns:
        List of per-epoch validation results; each is also pushed through
        ``train.report``.
    """
    epochs = config.pop("epochs", 3)

    model = train.torch.prepare_model(ResNet18(config))

    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=config.get("lr", 0.1),
        momentum=config.get("momentum", 0.9),
    )

    # CIFAR-10 per-channel mean/std used by both pipelines.
    cifar_mean = (0.4914, 0.4822, 0.4465)
    cifar_std = (0.2023, 0.1994, 0.2010)
    augment = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(cifar_mean, cifar_std),
    ])
    plain = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(cifar_mean, cifar_std),
    ])

    # Serialize the dataset download so co-located workers don't race on disk.
    with FileLock(".ray.lock"):
        train_set = CIFAR10(root="~/data", train=True, download=True, transform=augment)
        val_set = CIFAR10(root="~/data", train=False, download=False, transform=plain)

    if config.get("test_mode"):
        train_set = Subset(train_set, list(range(64)))
        val_set = Subset(val_set, list(range(64)))

    # Split the global batch size evenly across workers.
    worker_batch_size = config["batch_size"] // train.world_size()
    train_loader = train.torch.prepare_data_loader(
        DataLoader(train_set, batch_size=worker_batch_size)
    )
    val_loader = train.torch.prepare_data_loader(
        DataLoader(val_set, batch_size=worker_batch_size)
    )

    criterion = nn.CrossEntropyLoss()

    history = []
    for _ in range(epochs):
        train_epoch(train_loader, model, criterion, optimizer)
        result = validate_epoch(val_loader, model, criterion)
        train.report(**result)
        history.append(result)

    return history
def train_loop_per_worker(config):
    """Horovod-on-Ray worker loop: resumable CIFAR-style training to epoch 40.

    Args:
        config: Dict with ``"lr"``, ``"batch_size"``, and ``"data"`` (a Ray
            object ref to the training dataset, resolved via ``ray.get``).

    Side effects:
        Reports a running loss via ``train.report`` every batch and saves a
        checkpoint (model/optimizer state plus epoch index) via
        ``train.save_checkpoint``.
    """
    import horovod.torch as hvd
    hvd.init()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = ResNet18(None).to(device)
    optimizer = torch.optim.SGD(
        net.parameters(),
        lr=config["lr"],
    )
    # Resume from a prior checkpoint if Ray Train provides one; ``epoch``
    # becomes the starting index of the epoch loop below.
    epoch = 0
    checkpoint = train.load_checkpoint()
    if checkpoint:
        model_state = checkpoint["model_state"]
        optimizer_state = checkpoint["optimizer_state"]
        epoch = checkpoint["epoch"]
        net.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)
    criterion = nn.CrossEntropyLoss()
    # Order matters: the optimizer is wrapped *before* broadcasting its state,
    # and parameter/optimizer state is broadcast from rank 0 after any
    # checkpoint restore so all workers start identical.
    optimizer = hvd.DistributedOptimizer(optimizer)
    # Per-rank numpy seed (rank-dependent), fixed torch seed for init.
    np.random.seed(1 + hvd.rank())
    torch.manual_seed(1234)
    # To ensure consistent initialization across workers,
    hvd.broadcast_parameters(net.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)
    # Training data is shipped via the object store; config["data"] is a ref.
    trainset = ray.get(config["data"])
    trainloader = DataLoader(trainset, batch_size=int(config["batch_size"]), shuffle=True, num_workers=4)
    for epoch in range(epoch, 40):  # loop over the dataset multiple times
        running_loss = 0.0
        epoch_steps = 0
        for i, data in enumerate(trainloader):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # print statistics
            # ``running_loss / epoch_steps`` is the epoch-so-far average loss,
            # reported every batch.
            running_loss += loss.item()
            epoch_steps += 1
            train.report(loss=running_loss / epoch_steps)
            if i % 2000 == 1999:  # print every 2000 mini-batches
                print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1, running_loss / epoch_steps))
        # NOTE(review): reconstructed from flattened source as an end-of-epoch
        # checkpoint (after the batch loop) — confirm against the original
        # layout if available.
        train.save_checkpoint(
            model_state=net.state_dict(),
            optimizer_state=optimizer.state_dict(),
            epoch=epoch,
        )