Example 1
def test_fail_state(ray_start_2_cpus):  # noqa: F811
    """Tests if state of training with failure is same as training without."""
    if not dist.is_available():
        return

    torch.manual_seed(0)

    def single_loader(config):
        dataset = LinearDataset(2, 5, size=1000000)
        return DataLoader(dataset, batch_size=config.get("batch_size", 32))

    TestOperator = TrainingOperator.from_creators(
        model_creator,
        optimizer_creator,
        single_loader,
        loss_creator=lambda config: nn.MSELoss(),
    )

    def init_hook():
        torch.manual_seed(0)

    trainer1 = TorchTrainer(
        training_operator_cls=TestOperator,
        config={"batch_size": 100000},
        timeout_s=5,
        initialization_hook=init_hook,
        num_workers=2,
    )
    initial_state = trainer1.state_dict()
    trainer1.train()
    trainer1_state = trainer1.state_dict()
    assert trainer1_state != initial_state
    trainer1.shutdown()

    trainer2 = TorchTrainer(
        training_operator_cls=TestOperator,
        config={"batch_size": 100000},
        timeout_s=5,
        initialization_hook=init_hook,
        num_workers=2,
    )
    trainer2.load_state_dict(initial_state)
    trainer2.train()
    assert trainer2.state_dict() == trainer1_state
    trainer2.shutdown()

    start_with_fail = gen_start_with_fail(1)
    with patch.object(TorchTrainer, "_start_workers", start_with_fail):
        trainer3 = TorchTrainer(
            training_operator_cls=TestOperator,
            config={"batch_size": 100000},
            timeout_s=5,
            initialization_hook=init_hook,
            num_workers=2,
        )
        trainer3.load_state_dict(initial_state)
        trainer3.train()
        assert trainer3.state_dict() == trainer1_state
        trainer3.shutdown()
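
The creator helpers and the failure injector referenced above (model_creator, optimizer_creator, LinearDataset, gen_start_with_fail) are imported from the surrounding test module and are not shown in this snippet. A minimal sketch of the two creators, assuming the same single-layer linear-regression setup that LinearDataset suggests (names are from the snippet, but bodies and defaults here are assumptions, not the source's definitions):

import torch
import torch.nn as nn


def model_creator(config):
    # One-layer linear model matching LinearDataset's (x, y) pairs (assumed).
    return nn.Linear(1, 1)


def optimizer_creator(model, config):
    # Plain SGD; the learning-rate default is an assumption.
    return torch.optim.SGD(model.parameters(), lr=config.get("lr", 1e-4))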
Example 2
def test_resize(ray_8_node_2_cpu):
    """Tests if placement group is removed when trainer is resized."""
    assert ray.available_resources()["CPU"] == 16
    placement_group_table = ray.state.state.placement_group_table()
    assert len(placement_group_table) == 0

    trainer = TorchTrainer(
        training_operator_cls=Operator,
        num_workers=7,
        use_gpu=False,
    )

    assert ray.available_resources()["CPU"] == 9
    placement_group_table = ray.state.state.placement_group_table()
    assert len(placement_group_table) == 1
    placement_group_id = list(placement_group_table)[0]
    placement_group = placement_group_table[placement_group_id]
    assert placement_group["state"] == "CREATED"

    trainer._resize_worker_group(trainer.state_dict())

    assert ray.available_resources()["CPU"] == 9
    placement_group_table = ray.state.state.placement_group_table()
    assert len(placement_group_table) == 2
    placement_group = placement_group_table[placement_group_id]
    assert placement_group["state"] == "REMOVED"
    placement_group_table_keys = list(placement_group_table)
    placement_group_table_keys.remove(placement_group_id)
    second_placement_group_id = placement_group_table_keys[0]
    second_placement_group = placement_group_table[second_placement_group_id]
    assert second_placement_group["state"] == "CREATED"

    trainer.shutdown()

    assert ray.available_resources()["CPU"] == 16
    placement_group_table = ray.state.state.placement_group_table()
    assert len(placement_group_table) == 2
    placement_group = placement_group_table[placement_group_id]
    assert placement_group["state"] == "REMOVED"
    second_placement_group = placement_group_table[second_placement_group_id]
    assert second_placement_group["state"] == "REMOVED"
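
Operator is likewise defined elsewhere in the test module. A hypothetical stand-in, built with TrainingOperator.from_creators as in Example 1 and reusing the creator sketches shown after that example:

from torch.utils.data import DataLoader
from ray.util.sgd.torch import TrainingOperator

# Hypothetical stand-in for the test module's Operator class; the real one may differ.
Operator = TrainingOperator.from_creators(
    model_creator,
    optimizer_creator,
    lambda config: DataLoader(LinearDataset(2, 5), batch_size=32),
    loss_creator=lambda config: nn.MSELoss(),
)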
Example 3
def test_multi_model(ray_start_2_cpus, num_workers, use_local):
    def train(*, model=None, criterion=None, optimizer=None, iterator=None):
        model.train()
        train_loss = 0
        correct = 0
        total = 0
        for batch_idx, (inputs, targets) in enumerate(iterator):
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
        return {
            "accuracy": correct / total,
            "train_loss": train_loss / (batch_idx + 1)
        }

    def train_epoch(self, iterator, info):
        result = {}
        data = list(iterator)
        for i, (model, optimizer) in enumerate(
                zip(self.models, self.optimizers)):
            result[f"model_{i}"] = train(
                model=model,
                criterion=self.criterion,
                optimizer=optimizer,
                iterator=iter(data))
        return result

    class MultiModelOperator(TrainingOperator):
        def setup(self, config):
            models = nn.Linear(1, 1), nn.Linear(1, 1)
            opts = [
                torch.optim.SGD(model.parameters(), lr=0.0001)
                for model in models
            ]
            loss = nn.MSELoss()
            train_dataloader, val_dataloader = data_creator(config)
            self.models, self.optimizers, self.criterion = self.register(
                models=models, optimizers=opts, criterion=loss)
            self.register_data(
                train_loader=train_dataloader,
                validation_loader=val_dataloader)

    TestOperator = get_test_operator(MultiModelOperator)

    trainer1 = TorchTrainer(
        config={"custom_func": train_epoch},
        training_operator_cls=TestOperator,
        num_workers=num_workers,
        use_local=use_local,
        use_gpu=False,
    )
    trainer1.train()
    state = trainer1.state_dict()

    models1 = trainer1.get_model()

    trainer1.shutdown()

    trainer2 = TorchTrainer(
        config={"custom_func": train_epoch},
        training_operator_cls=TestOperator,
        num_workers=num_workers,
        use_local=use_local,
        use_gpu=False,
    )
    trainer2.load_state_dict(state)

    models2 = trainer2.get_model()

    for model_1, model_2 in zip(models1, models2):

        model1_state_dict = model_1.state_dict()
        model2_state_dict = model_2.state_dict()

        assert set(model1_state_dict.keys()) == set(model2_state_dict.keys())

        for k in model1_state_dict:
            assert torch.equal(model1_state_dict[k], model2_state_dict[k])

    trainer2.shutdown()
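
data_creator and get_test_operator are also helpers from the test module whose definitions are not shown here. The sketch below assumes get_test_operator lets the custom_func entry in the config replace train_epoch, which is what the config={"custom_func": train_epoch} usage above implies, and that data_creator returns a pair of loaders over the synthetic LinearDataset; both bodies are assumptions:

def get_test_operator(operator_cls):
    class _TestingOperator(operator_cls):
        def train_epoch(self, iterator, info):
            # Delegate to the custom training function when one is provided.
            func = self.config.get("custom_func")
            if callable(func):
                return func(self, iterator, info)
            return super().train_epoch(iterator, info)

    return _TestingOperator


def data_creator(config):
    # Assumed train/validation loaders over LinearDataset; sizes are illustrative.
    batch_size = config.get("batch_size", 32)
    train_loader = DataLoader(LinearDataset(2, 5), batch_size=batch_size)
    val_loader = DataLoader(LinearDataset(2, 5, size=400), batch_size=batch_size)
    return train_loader, val_loader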
Example 4
def test_multi_model(ray_start_2_cpus, num_workers):
    def train(*, model=None, criterion=None, optimizer=None, dataloader=None):
        model.train()
        train_loss = 0
        correct = 0
        total = 0
        for batch_idx, (inputs, targets) in enumerate(dataloader):
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
        return {
            "accuracy": correct / total,
            "train_loss": train_loss / (batch_idx + 1)
        }

    def train_epoch(self, iterator, info):
        result = {}
        for i, (model,
                optimizer) in enumerate(zip(self.models, self.optimizers)):
            result["model_{}".format(i)] = train(model=model,
                                                 criterion=self.criterion,
                                                 optimizer=optimizer,
                                                 dataloader=iterator)
        return result

    def multi_model_creator(config):
        return nn.Linear(1, 1), nn.Linear(1, 1)

    def multi_optimizer_creator(models, config):
        opts = [
            torch.optim.SGD(model.parameters(), lr=0.0001) for model in models
        ]
        return opts[0], opts[1]

    trainer1 = TorchTrainer(model_creator=multi_model_creator,
                            data_creator=data_creator,
                            optimizer_creator=multi_optimizer_creator,
                            loss_creator=lambda config: nn.MSELoss(),
                            config={"custom_func": train_epoch},
                            training_operator_cls=_TestingOperator,
                            num_workers=num_workers)
    trainer1.train()
    state = trainer1.state_dict()

    models1 = trainer1.get_model()

    trainer1.shutdown()

    trainer2 = TorchTrainer(model_creator=multi_model_creator,
                            data_creator=data_creator,
                            optimizer_creator=multi_optimizer_creator,
                            loss_creator=lambda config: nn.MSELoss(),
                            config={"custom_func": train_epoch},
                            training_operator_cls=_TestingOperator,
                            num_workers=num_workers)
    trainer2.load_state_dict(state)

    models2 = trainer2.get_model()

    for model_1, model_2 in zip(models1, models2):

        model1_state_dict = model_1.state_dict()
        model2_state_dict = model_2.state_dict()

        assert set(model1_state_dict.keys()) == set(model2_state_dict.keys())

        for k in model1_state_dict:
            assert torch.equal(model1_state_dict[k], model2_state_dict[k])

    trainer2.shutdown()
                                   batch_size=config["batch"])
    return train_loader, validation_loader


def optimizer_creator(model, config):
    """Returns an optimizer (or multiple)"""
    return torch.optim.SGD(model.parameters(), lr=config["lr"])


ray.init(address="auto")

trainer = TorchTrainer(
    model_creator=ResNet18,  # A function that returns a nn.Module
    data_creator=cifar_creator,  # A function that returns dataloaders
    optimizer_creator=optimizer_creator,  # A function that returns an optimizer
    loss_creator=torch.nn.CrossEntropyLoss,  # A loss function
    config={
        "lr": 0.01,
        "batch": 64
    },  # parameters
    num_workers=4,  # amount of parallelism
    use_gpu=torch.cuda.is_available(),
    use_tqdm=True)

stats = trainer.train()
print(trainer.validate())

torch.save(trainer.state_dict(), "checkpoint.pt")
trainer.shutdown()
print("success!")