Exemplo n.º 1
0
def get_model(model_checkpoint_path):
    """Rebuild the single-channel ResNet18 and restore its weights.

    Args:
        model_checkpoint_path: Path handed to
            ``Trainer.load_checkpoint_from_path``.

    Returns:
        A ``ResNet18`` whose ``conv1`` takes 1 input channel, with the
        checkpoint's ``"model_state_dict"`` entry loaded into it.
    """
    saved = Trainer.load_checkpoint_from_path(model_checkpoint_path)
    weights = saved["model_state_dict"]

    net = ResNet18(None)
    # The stem must match the saved weights' shape before loading; the
    # saved model presumably used a 1-channel (grayscale) input stem.
    net.conv1 = nn.Conv2d(
        in_channels=1,
        out_channels=64,
        kernel_size=7,
        stride=1,
        padding=3,
        bias=False,
    )
    net.load_state_dict(weights)
    return net
Exemplo n.º 2
0
def training_loop(config):
    """Train a single-channel ResNet18 on MNIST data with Ray Train.

    Args:
        config: Dict of hyperparameters. Reads ``"batch_size"`` (required),
            ``"lr"`` (default 0.1), ``"momentum"`` (default 0.9),
            ``"test_mode"`` (default falsy; when truthy both datasets are
            truncated to 64 samples) and ``"epochs"`` (default 2). The dict
            is also passed to ``ResNet18``.

    Each epoch trains, validates, saves a checkpoint holding the unwrapped
    model's ``state_dict`` under ``model_state_dict`` and reports the
    validation metrics returned by ``validate_epoch``.
    """
    # Create model; replace the stem so it accepts 1-channel input.
    model = ResNet18(config)
    model.conv1 = nn.Conv2d(1,
                            64,
                            kernel_size=7,
                            stride=1,
                            padding=3,
                            bias=False)
    model = train.torch.prepare_model(model)
    # `prepare_model` may return the network wrapped (e.g. in
    # DistributedDataParallel, which exposes the original as `.module`) or
    # unwrapped. Accessing `.module` unconditionally raises AttributeError in
    # the unwrapped case, so fall back to the model itself. Saving the inner
    # module also keeps state-dict keys free of a "module." prefix.
    unwrapped_model = getattr(model, "module", model)

    # Create optimizer.
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=config.get("lr", 0.1),
        momentum=config.get("momentum", 0.9),
    )

    # Load in training and validation data.
    train_dataset = load_mnist_data(True, True)
    validation_dataset = load_mnist_data(False, False)

    # `.get` so a config without "test_mode" runs normally instead of
    # raising KeyError.
    if config.get("test_mode"):
        train_dataset = Subset(train_dataset, list(range(64)))
        validation_dataset = Subset(validation_dataset, list(range(64)))

    batch_size = config["batch_size"]
    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              num_workers=2)
    validation_loader = DataLoader(validation_dataset,
                                   batch_size=batch_size,
                                   num_workers=2)

    train_loader = train.torch.prepare_data_loader(train_loader)
    validation_loader = train.torch.prepare_data_loader(validation_loader)

    # Create loss.
    criterion = nn.CrossEntropyLoss()

    num_epochs = config.get("epochs", 2)
    for _ in range(num_epochs):
        train_epoch(train_loader, model, criterion, optimizer)
        validation_metrics = validate_epoch(validation_loader, model,
                                            criterion)

        train.save_checkpoint(
            model_state_dict=unwrapped_model.state_dict())
        train.report(**validation_metrics)
def train_func(config):
    """Train ResNet18 on CIFAR-10 with Ray Train and return per-epoch metrics.

    Args:
        config: Dict of hyperparameters. Reads ``"epochs"`` (default 3),
            ``"batch_size"`` (required; the global batch, divided evenly
            across workers), ``"lr"`` (default 0.1), ``"momentum"``
            (default 0.9) and ``"test_mode"`` (optional; when truthy both
            datasets are truncated to 64 samples). The caller's dict is not
            modified.

    Returns:
        List with one validation result (as returned by ``validate_epoch``)
        per epoch.
    """
    # Work on a shallow copy: popping "epochs" from the caller's dict would
    # leak the mutation to whoever owns it. The model still receives a config
    # without "epochs", exactly as before.
    config = dict(config)
    epochs = config.pop("epochs", 3)
    model = ResNet18(config)
    model = train.torch.prepare_model(model)

    # Create optimizer.
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=config.get("lr", 0.1),
        momentum=config.get("momentum", 0.9),
    )

    # Per-channel mean/std normalization shared by train and test pipelines.
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010))
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    # File lock so concurrent processes sharing ~/data do not download or
    # read the dataset at the same time.
    with FileLock(".ray.lock"):
        train_dataset = CIFAR10(root="~/data",
                                train=True,
                                download=True,
                                transform=transform_train)
        validation_dataset = CIFAR10(root="~/data",
                                     train=False,
                                     download=False,
                                     transform=transform_test)

    if config.get("test_mode"):
        train_dataset = Subset(train_dataset, list(range(64)))
        validation_dataset = Subset(validation_dataset, list(range(64)))

    # Split the global batch across workers. Clamp to 1 so a global batch
    # smaller than the worker count does not produce batch_size=0, which
    # DataLoader rejects.
    worker_batch_size = max(1, config["batch_size"] // train.world_size())

    train_loader = DataLoader(train_dataset, batch_size=worker_batch_size)
    validation_loader = DataLoader(validation_dataset,
                                   batch_size=worker_batch_size)

    train_loader = train.torch.prepare_data_loader(train_loader)
    validation_loader = train.torch.prepare_data_loader(validation_loader)

    # Create loss.
    criterion = nn.CrossEntropyLoss()

    results = []
    for _ in range(epochs):
        train_epoch(train_loader, model, criterion, optimizer)
        result = validate_epoch(validation_loader, model, criterion)
        train.report(**result)
        results.append(result)

    return results
Exemplo n.º 4
0
def train_loop_per_worker(config):
    """Horovod training loop for ResNet18, resumable from a Ray checkpoint.

    Args:
        config: Dict with ``"lr"``, ``"batch_size"``, ``"data"`` (resolved
            with ``ray.get`` to the training dataset) and optional
            ``"epochs"`` (total number of epochs, default 40).

    After every epoch a checkpoint is saved with ``model_state``,
    ``optimizer_state`` and ``epoch`` (the index of the epoch just
    completed). On restart, training continues with the following epoch.
    """
    import horovod.torch as hvd

    hvd.init()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = ResNet18(None).to(device)
    optimizer = torch.optim.SGD(
        net.parameters(),
        lr=config["lr"],
    )
    num_epochs = config.get("epochs", 40)
    start_epoch = 0

    checkpoint = train.load_checkpoint()
    if checkpoint:
        model_state = checkpoint["model_state"]
        optimizer_state = checkpoint["optimizer_state"]
        last_completed_epoch = checkpoint["epoch"]

        net.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)
        # The saved epoch was already fully trained before the checkpoint was
        # written; resume at the next one instead of repeating it.
        start_epoch = last_completed_epoch + 1

    criterion = nn.CrossEntropyLoss()
    optimizer = hvd.DistributedOptimizer(optimizer)
    # NumPy is seeded differently per rank, torch identically on all ranks.
    np.random.seed(1 + hvd.rank())
    torch.manual_seed(1234)
    # Broadcast rank 0's weights and optimizer state so every worker starts
    # from identical parameters (including after a resume).
    hvd.broadcast_parameters(net.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)

    trainset = ray.get(config["data"])
    trainloader = DataLoader(trainset,
                             batch_size=int(config["batch_size"]),
                             shuffle=True,
                             num_workers=4)

    for epoch in range(start_epoch, num_epochs):
        running_loss = 0.0
        epoch_steps = 0
        for i, (inputs, labels) in enumerate(trainloader):
            inputs, labels = inputs.to(device), labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Report the running mean loss of this epoch after every batch.
            running_loss += loss.item()
            epoch_steps += 1
            train.report(loss=running_loss / epoch_steps)
            if i % 2000 == 1999:  # print every 2000 mini-batches
                print(f"[{epoch + 1}, {i + 1:5d}] loss: "
                      f"{running_loss / epoch_steps:.3f}")

        train.save_checkpoint(
            model_state=net.state_dict(),
            optimizer_state=optimizer.state_dict(),
            epoch=epoch,
        )