Example #1
    def to_air_checkpoint(self) -> Optional[Checkpoint]:
        from ray.tune.trainable.util import TrainableUtil

        checkpoint_data = self.dir_or_data

        if not checkpoint_data:
            return None

        if isinstance(checkpoint_data, ray.ObjectRef):
            checkpoint_data = ray.get(checkpoint_data)

        if isinstance(checkpoint_data, str):
            checkpoint_dir = TrainableUtil.find_checkpoint_dir(checkpoint_data)
            checkpoint = Checkpoint.from_directory(checkpoint_dir)
        elif isinstance(checkpoint_data, bytes):
            with tempfile.TemporaryDirectory() as tmpdir:
                TrainableUtil.create_from_pickle(checkpoint_data, tmpdir)
                # Double wrap in checkpoint so we hold the data in memory and
                # can remove the temp directory
                checkpoint = Checkpoint.from_dict(
                    Checkpoint.from_directory(tmpdir).to_dict())
        elif isinstance(checkpoint_data, dict):
            checkpoint = Checkpoint.from_dict(checkpoint_data)
        else:
            raise RuntimeError(
                f"Unknown checkpoint data type: {type(checkpoint_data)}")

        return checkpoint
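The "double wrap" above is the subtle part: to_dict() eagerly reads the directory contents into memory, so the temporary directory can be removed while the checkpoint remains usable. A minimal sketch of that round trip, assuming the Ray AIR Checkpoint API:

import os
import tempfile

from ray.air.checkpoint import Checkpoint

with tempfile.TemporaryDirectory() as tmpdir:
    with open(os.path.join(tmpdir, "data.txt"), "w") as f:
        f.write("hello")
    # to_dict() pulls the directory contents into memory, so the
    # checkpoint survives deletion of tmpdir.
    in_memory = Checkpoint.from_dict(Checkpoint.from_directory(tmpdir).to_dict())

# tmpdir is gone here, but the data can still be materialized.
with in_memory.as_directory() as restored:
    print(os.listdir(restored))  # includes 'data.txt'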
Example #2
    def train_func(config):
        itr = 0
        ckpt = session.get_checkpoint()
        if ckpt is not None:
            # Resume the iteration counter from the previous checkpoint.
            ckpt = ckpt.to_dict()
            itr = ckpt["iter"] + 1

        for i in range(itr, config["max_iter"]):
            session.report(
                dict(test=i, training_iteration=i),
                checkpoint=Checkpoint.from_dict(dict(iter=i)),
            )
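A hedged sketch of driving this function with Ray Tune; Tuner is the Ray 2.x entry point, and max_iter matches the config key the function reads:

from ray import tune

# `train_func` is the function defined above; Tune passes `param_space`
# entries through as its `config`.
tuner = tune.Tuner(train_func, param_space={"max_iter": 4})
results = tuner.fit()
print(results[0].metrics["test"])  # 3 on a clean run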
Example #3
    def train_func():
        ckpt = session.get_checkpoint()
        restored = bool(ckpt)  # Does a previous checkpoint exist?
        itr = 0
        if ckpt:
            ckpt = ckpt.to_dict()
            itr = ckpt["iter"] + 1

        for i in range(itr, 4):
            if i == 2 and not restored:
                raise Exception("try to fail me")
            session.report(
                dict(test=i, training_iteration=i),
                checkpoint=Checkpoint.from_dict(dict(iter=i)),
            )
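This function is written to fail once at i == 2 and succeed after restoring. A sketch of running it with one retry allowed, assuming the Ray AIR RunConfig/FailureConfig names:

from ray import tune
from ray.air.config import FailureConfig, RunConfig

# The first attempt fails at i == 2; Tune retries once, the retry
# restores from the iter=1 checkpoint, and `restored` suppresses the
# injected failure.
tuner = tune.Tuner(
    train_func,
    run_config=RunConfig(failure_config=FailureConfig(max_failures=1)),
)
tuner.fit()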
Example #4
def test_resume_from_checkpoint():
    # Invalid values are rejected eagerly at construction time.
    with pytest.raises(ValueError):
        DummyTrainer(resume_from_checkpoint="invalid")

    with pytest.raises(ValueError):
        DummyTrainer(resume_from_checkpoint=False)

    with pytest.raises(ValueError):
        DummyTrainer(resume_from_checkpoint=True)

    with pytest.raises(ValueError):
        DummyTrainer(resume_from_checkpoint={})

    # Succeeds: resuming from no checkpoint is allowed.
    DummyTrainer(resume_from_checkpoint=None)

    # Succeeds: a valid Checkpoint object is accepted.
    DummyTrainer(resume_from_checkpoint=Checkpoint.from_dict({"empty": ""}))
Example #5
    def commit(self, path: Optional[Path] = None) -> None:
        """Commit checkpoint to disk, if needed.

        Args:
            path: Path to commit checkpoint to.
        """
        if self.storage_mode == CheckpointStorage.MEMORY:
            # Do not persist memory checkpoints
            return

        if not path:
            # If no path is given, skip
            return

        if not isinstance(self.dir_or_data, dict):
            # Only persist dictionaries
            return

        checkpoint = Checkpoint.from_dict(self.dir_or_data)
        self.dir_or_data = checkpoint.to_directory(str(path))
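For dict data, commit() boils down to a from_dict/to_directory round trip; a minimal sketch, assuming the AIR Checkpoint API:

import tempfile

from ray.air.checkpoint import Checkpoint

with tempfile.TemporaryDirectory() as path:
    # Persist an in-memory dict checkpoint to disk...
    Checkpoint.from_dict({"iter": 3}).to_directory(path)
    # ...and read it back.
    assert Checkpoint.from_directory(path).to_dict()["iter"] == 3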
Example #6
    def to_air_checkpoint(self) -> Optional[Checkpoint]:
        from ray.tune.trainable.util import TrainableUtil

        checkpoint_data = self.dir_or_data

        if not checkpoint_data:
            return None

        if isinstance(checkpoint_data, ray.ObjectRef):
            checkpoint_data = ray.get(checkpoint_data)

        if isinstance(checkpoint_data, str):
            try:
                checkpoint_dir = TrainableUtil.find_checkpoint_dir(
                    checkpoint_data)
            except FileNotFoundError:
                if log_once("checkpoint_not_available"):
                    logger.error(
                        f"The requested checkpoint is not available on this node, "
                        f"most likely because you are using Ray client or disabled "
                        f"checkpoint synchronization. To avoid this, enable checkpoint "
                        f"synchronization to cloud storage by specifying a "
                        f"`SyncConfig`. The checkpoint may be available on a different "
                        f"node - please check this location on worker nodes: "
                        f"{checkpoint_data}")
                return None
            checkpoint = Checkpoint.from_directory(checkpoint_dir)
        elif isinstance(checkpoint_data, bytes):
            checkpoint = Checkpoint.from_bytes(checkpoint_data)
        elif isinstance(checkpoint_data, dict):
            checkpoint = Checkpoint.from_dict(checkpoint_data)
        else:
            raise RuntimeError(
                f"Unknown checkpoint data type: {type(checkpoint_data)}")

        return checkpoint
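Note how the bytes branch replaces Example #1's temp-directory dance with the direct from_bytes constructor; the corresponding round trip, assuming the same API:

from ray.air.checkpoint import Checkpoint

# to_bytes()/from_bytes() round-trip a checkpoint without touching disk.
blob = Checkpoint.from_dict({"iter": 7}).to_bytes()
assert Checkpoint.from_bytes(blob).to_dict()["iter"] == 7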
Example #7
def function_trainable_dict(config):
    session.report({"metric": 2},
                   checkpoint=Checkpoint.from_dict({"checkpoint_data": 3}))
Example #8
def train_func():
    for i in range(9):
        session.report(dict(test=i))
    # Only the final report attaches a checkpoint.
    session.report(
        dict(test=i + 1), checkpoint=Checkpoint.from_dict(dict(hello="world"))
    )
Example #9
def train_func(config):
    use_gpu = config["use_gpu"]
    num_epochs = config["num_epochs"]
    batch_size = config["batch_size"]
    num_layers = config["num_layers"]
    num_hidden = config["num_hidden"]
    dropout_every = config["dropout_every"]
    dropout_prob = config["dropout_prob"]
    num_features = config["num_features"]

    print("Defining model, loss, and optimizer...")

    # Setup device.
    device = torch.device(f"cuda:{session.get_local_rank()}"
                          if use_gpu and torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    # Setup data.
    train_dataset_pipeline = session.get_dataset_shard("train")
    train_dataset_epoch_iterator = train_dataset_pipeline.iter_epochs()
    test_dataset = session.get_dataset_shard("test")
    test_torch_dataset = test_dataset.to_torch(label_column="label",
                                               batch_size=batch_size,
                                               drop_last=True)

    net = Net(
        n_layers=num_layers,
        n_features=num_features,
        num_hidden=num_hidden,
        dropout_every=dropout_every,
        drop_prob=dropout_prob,
    ).to(device)
    print(net.parameters)

    net = train.torch.prepare_model(net)

    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(net.parameters(), weight_decay=0.0001)

    print("Starting training...")
    for epoch in range(num_epochs):
        train_dataset = next(train_dataset_epoch_iterator)

        train_torch_dataset = train_dataset.to_torch(label_column="label",
                                                     batch_size=batch_size)

        train_running_loss, train_num_correct, train_num_total = train_epoch(
            train_torch_dataset, net, device, criterion, optimizer)
        train_acc = train_num_correct / train_num_total
        print(f"epoch [{epoch + 1}]: training accuracy: "
              f"{train_num_correct} / {train_num_total} = {train_acc:.4f}")

        test_running_loss, test_num_correct, test_num_total = test_epoch(
            test_torch_dataset, net, device, criterion)
        test_acc = test_num_correct / test_num_total
        print(f"epoch [{epoch + 1}]: testing accuracy: "
              f"{test_num_correct} / {test_num_total} = {test_acc:.4f}")

        # Checkpoint model.
        module = net.module if isinstance(net,
                                          DistributedDataParallel) else net
        checkpoint = Checkpoint.from_dict(dict(model=module.state_dict()))

        # Record and log stats.
        print(f"session report on {session.get_world_rank()}")
        session.report(
            dict(
                train_acc=train_acc,
                train_loss=train_running_loss,
                test_acc=test_acc,
                test_loss=test_running_loss,
            ),
            checkpoint=checkpoint,
        )
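After training, the reported checkpoint can be turned back into a model. A hedged sketch, where result stands for the object returned by the trainer run and the Net constructor arguments stand in for the config values used above:

# `result` would come from e.g. `result = trainer.fit()`; the dict key
# "model" matches what train_func checkpoints each epoch.
state_dict = result.checkpoint.to_dict()["model"]

net = Net(
    n_layers=2,
    n_features=16,
    num_hidden=64,
    dropout_every=2,
    drop_prob=0.2,
)
net.load_state_dict(state_dict)
net.eval()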