Exemple #1
def train_func(config):
    """Run the training loop for one worker.

    ``config`` must provide ``batch_size``, ``lr`` and ``epochs``. The batch
    size is divided by ``train.world_size()`` to get the per-worker batch
    size, the data loaders and the model are passed through the
    ``train.torch`` prepare helpers, and every epoch runs one training pass
    followed by one validation pass.
    """
    global_batch_size = config["batch_size"]
    learning_rate = config["lr"]
    num_epochs = config["epochs"]

    # Every worker processes an equal slice of the global batch.
    per_worker_batch = global_batch_size // train.world_size()

    # Build the plain loaders first, then hand them to Ray Train.
    raw_train_loader = DataLoader(training_data, batch_size=per_worker_batch)
    raw_test_loader = DataLoader(test_data, batch_size=per_worker_batch)
    train_loader = train.torch.prepare_data_loader(raw_train_loader)
    test_loader = train.torch.prepare_data_loader(raw_test_loader)

    # Model, loss and optimizer.
    net = train.torch.prepare_model(NeuralNetwork())
    criterion = nn.CrossEntropyLoss()
    sgd = torch.optim.SGD(net.parameters(), lr=learning_rate)

    for _ in range(num_epochs):
        train_epoch(train_loader, net, criterion, sgd)
        # The validation result is not used any further in this function.
        validate_epoch(test_loader, net, criterion)
def validate_epoch(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` and return the average loss.

    The model is switched to evaluation mode for the duration of the pass so
    that mode-dependent layers (dropout, batch normalization) do not distort
    the metrics, and its previous train/eval mode is restored afterwards.

    Args:
        dataloader: Iterable of ``(X, y)`` batches that supports ``len()``
            and exposes a sized ``dataset`` attribute.
        model: ``torch.nn.Module`` called as ``model(X)``.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar tensor.

    Returns:
        ``{"loss": <mean of the per-batch losses>}``.
    """
    # Presumably each worker evaluates roughly 1/world_size of the dataset,
    # hence the scaled-down denominator for the accuracy - TODO confirm.
    size = len(dataloader.dataset) // train.world_size()
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for X, y in dataloader:
                pred = model(X)
                test_loss += loss_fn(pred, y).item()
                correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n "
          f"Accuracy: {(100 * correct):>0.1f}%, "
          f"Avg loss: {test_loss:>8f} \n")
    return {"loss": test_loss}
def train_epoch(dataloader, model, loss_fn, optimizer):
    """Train ``model`` for one pass over ``dataloader``.

    For every batch the loss is computed, gradients are reset,
    back-propagated and applied through ``optimizer``. Progress is printed
    every 100 batches.

    Args:
        dataloader: Iterable of ``(X, y)`` batches exposing a sized
            ``dataset`` attribute.
        model: ``torch.nn.Module`` called as ``model(X)``.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar tensor.
        optimizer: Optimizer that updates the parameters of ``model``.
    """
    # Used only for the progress message below.
    size = len(dataloader.dataset) // train.world_size()
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss_value, current = loss.item(), batch * len(X)
            print(f"loss: {loss_value:>7f}  [{current:>5d}/{size:>5d}]")
Exemple #4
    def prepare_model(
        self,
        model: torch.nn.Module,
        move_to_device: bool = True,
        wrap_ddp: bool = True,
        ddp_kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.nn.Module:
        """Prepares the model for distributed execution.

        This allows you to use the same exact code regardless of number of
        workers or the device type being used (CPU, GPU).

        Args:
            model (torch.nn.Module): A torch model to prepare.
            move_to_device (bool): Whether to move the model to the correct
                device. If set to False, the model needs to manually be moved
                to the correct device.
            wrap_ddp (bool): Whether to wrap models in
                ``DistributedDataParallel``.
            ddp_kwargs (Dict[str, Any]): Args to pass into
                ``DistributedDataParallel`` initialization if ``wrap_ddp`` is
                set to True.

        Returns:
            The prepared model: moved to the worker's device if requested
            and wrapped in ``DistributedDataParallel`` when ``wrap_ddp`` is
            set and more than one worker is used.
        """
        # NOTE(review): the body of the first CUDA check and the arguments of
        # the CUDA ``DistributedDataParallel`` call were truncated in the
        # reviewed copy and have been restored - confirm against upstream.
        ddp_kwargs = ddp_kwargs or {}

        rank = train.local_rank()

        device = self.get_device()

        if torch.cuda.is_available():
            # Make this worker's device the current CUDA device.
            torch.cuda.set_device(device)

        if move_to_device:
            logger.info(f"Moving model to device: {device}")
            model = model.to(device)
        if wrap_ddp and train.world_size() > 1:
            logger.info("Wrapping provided model in DDP.")
            if torch.cuda.is_available():
                # Pin the replica to the local GPU of this worker.
                model = DistributedDataParallel(
                    model,
                    device_ids=[rank],
                    output_device=rank,
                    **ddp_kwargs,
                )
            else:
                model = DistributedDataParallel(model, **ddp_kwargs)

        return model
def _huggingface_train_loop_per_worker(config):
    """Per-worker training loop for HuggingFace Transformers."""
    # NOTE(review): several statements in this function are truncated in the
    # reviewed copy (unclosed calls, empty branch/loop bodies). Each spot is
    # marked below; only comments have been added, the code is untouched.

    # The trainer factory is removed from ``config`` so that the remaining
    # entries can be forwarded to it as keyword arguments further down.
    trainer_init_per_worker = config.pop("_trainer_init_per_worker")

    # Env vars necessary for HF to setup DDP
    os.environ["RANK"] = str(train.world_rank())
    os.environ["WORLD_SIZE"] = str(train.world_size())
    os.environ["LOCAL_RANK"] = str(train.local_rank())

    train_dataset = train.get_dataset_shard(TRAIN_DATASET_KEY)
    eval_dataset = train.get_dataset_shard(EVALUATION_DATASET_KEY)

    # NOTE(review): the argument list of ``process_datasets`` is missing and
    # the call is never closed - presumably it takes the two shards above;
    # confirm against upstream.
    train_torch_dataset, eval_torch_dataset = process_datasets(

    trainer: transformers.trainer.Trainer = trainer_init_per_worker(
        train_torch_dataset, eval_torch_dataset, **config)

    # NOTE(review): the statement that opens this message (it ends with a
    # closing parenthesis, so some call such as a warning) is missing.
    if trainer.args.push_to_hub and not trainer.args.hub_token:
            "You have set `push_to_hub=True` but didn't specify `hub_token`. "
            "Pushing to hub will most likely fail, as the credentials will not "
            "be automatically propagated from the local enviroment to the Ray Actors. "
            "If that happens, specify `hub_token` in `TrainingArguments`.")

    # Step-based evaluation, saving and logging are rejected up front.
    if (trainer.args.evaluation_strategy == "steps"
            or trainer.args.save_strategy == "steps"
            or trainer.args.logging_strategy == "steps"):
        raise ValueError(
            "'steps' value for `evaluation_strategy`, `logging_strategy` "
            "or `save_strategy` is not yet supported.")

    trainer = wrap_transformers_trainer(trainer)

    # ensure no HF logging callbacks are added
    # aside from doubling functionality with our callbacks,
    # the Wandb callbacks causes training to freeze
    # NOTE(review): the call below is not closed and the loop body is empty;
    # per the comment above the loop presumably removes each callback from
    # the trainer - confirm against upstream.
    integration_callbacks = transformers.trainer.get_reporting_integration_callbacks(
    for callback in integration_callbacks:


    # Resolve a local directory for the checkpoint, if the session has one.
    checkpoint = session.get_checkpoint()
    checkpoint_path = None
    remove_checkpoint_path = False
    if checkpoint:
        assert isinstance(checkpoint, Checkpoint)
        checkpoint_dict = checkpoint.to_dict()
        source_ip = checkpoint_dict[NODE_IP_KEY]
        source_path = checkpoint_dict[CHECKPOINT_PATH_ON_NODE_KEY]
        target_ip = get_node_ip_address()
        if source_ip == target_ip:
            # Checkpoint already lives on this node: use it in place.
            checkpoint_path = source_path
            # NOTE(review): an ``else:`` appears to be missing here, and the
            # ``mkdtemp`` call is not closed. As written the temporary
            # directory would overwrite ``source_path``; whatever copied the
            # checkpoint into it is not visible either.
            checkpoint_path = tempfile.mkdtemp(
            remove_checkpoint_path = True
    # NOTE(review): no statement that consumes ``checkpoint_path`` (e.g. the
    # actual training call) is visible before the cleanup below.
    if remove_checkpoint_path:
        shutil.rmtree(checkpoint_path, ignore_errors=True)
def train_func(config):
    """Train and validate a ``ResNet18`` on CIFAR10 on one worker.

    Args:
        config: Dict of settings. ``epochs`` (default 3) is popped from it;
            ``lr`` (default 0.1), ``momentum`` (default 0.9), ``test_mode``
            and the required ``batch_size`` are read from it. The dict is
            also passed to ``ResNet18``.

    Returns:
        A list holding the ``validate_epoch`` result of every epoch.
    """
    # NOTE(review): several calls were truncated in the reviewed copy (the
    # optimizer parameters, the transform lists, the CIFAR10 and validation
    # DataLoader arguments). They have been restored with the minimum needed
    # to run - confirm against upstream, which may use extra augmentations.
    epochs = config.pop("epochs", 3)
    model = ResNet18(config)
    model = train.torch.prepare_model(model)

    # Create optimizer.
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=config.get("lr", 0.1),
        momentum=config.get("momentum", 0.9),
    )

    # Load in training and validation data.
    # ``Normalize`` operates on tensors, so the images are converted first.
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])  # meanstd transformation

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    # The file lock keeps workers on the same node from downloading at once.
    with FileLock(".ray.lock"):
        train_dataset = CIFAR10(root="~/data",
                                train=True,
                                download=True,
                                transform=transform_train)
        validation_dataset = CIFAR10(root="~/data",
                                     train=False,
                                     download=False,
                                     transform=transform_test)

    if config.get("test_mode"):
        # Smoke-test mode: only use the first 64 samples of each split.
        train_dataset = Subset(train_dataset, list(range(64)))
        validation_dataset = Subset(validation_dataset, list(range(64)))

    worker_batch_size = config["batch_size"] // train.world_size()

    train_loader = DataLoader(train_dataset, batch_size=worker_batch_size)
    validation_loader = DataLoader(validation_dataset,
                                   batch_size=worker_batch_size)

    train_loader = train.torch.prepare_data_loader(train_loader)
    validation_loader = train.torch.prepare_data_loader(validation_loader)

    # Create loss.
    criterion = nn.CrossEntropyLoss()

    results = []

    for _ in range(epochs):
        train_epoch(train_loader, model, criterion, optimizer)
        result = validate_epoch(validation_loader, model, criterion)
        # Collect the per-epoch result; previously the list stayed empty.
        results.append(result)

    return results
def train_loop_per_worker(train_loop_config):
    """Train a ``SAGE`` model on this worker's share of the training nodes.

    Args:
        train_loop_config: Dict with ``dataset_fn`` (zero-argument callable
            returning the dataset), ``batch_size`` and ``num_epochs``.

    The training node indices are split across workers by world rank.
    Accuracy is computed on the rank 0 worker only.
    """
    # NOTE(review): several statements were truncated in the reviewed copy
    # (the rank index of the split, the training sampler arguments, the
    # optimizer step, the accuracy denominators). They have been restored
    # from the surrounding code - confirm against upstream.
    dataset = train_loop_config["dataset_fn"]()
    batch_size = train_loop_config["batch_size"]
    num_epochs = train_loop_config["num_epochs"]

    data = dataset[0]
    train_idx = data.train_mask.nonzero(as_tuple=False).view(-1)
    # Keep only the chunk of training nodes that belongs to this worker.
    train_idx = train_idx.split(train_idx.size(0) // train.world_size())[
        train.world_rank()
    ]

    train_loader = NeighborSampler(
        data.edge_index,
        node_idx=train_idx,
        sizes=[25, 10],
        batch_size=batch_size,
        shuffle=True,
    )

    # Disable distributed sampler since the train_loader has already been split above.
    train_loader = train.torch.prepare_data_loader(train_loader, add_dist_sampler=False)

    # Do validation on rank 0 worker only.
    if train.world_rank() == 0:
        subgraph_loader = NeighborSampler(
            data.edge_index, node_idx=None, sizes=[-1], batch_size=2048, shuffle=False
        )
        subgraph_loader = train.torch.prepare_data_loader(
            subgraph_loader, add_dist_sampler=False
        )

    model = SAGE(dataset.num_features, 256, dataset.num_classes)
    model = train.torch.prepare_model(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    x, y = data.x.to(train.torch.get_device()), data.y.to(train.torch.get_device())

    for epoch in range(num_epochs):
        model.train()

        # ``batch_size`` is the number of samples in the current batch.
        # ``n_id`` are the ids of all the nodes used in the computation. This is
        # needed to pull in the necessary features just for the current batch that is
        # being trained on.
        # ``adjs`` is a list of 3 element tuple consisting of ``(edge_index, e_id,
        # size)`` for each sample in the batch, where ``edge_index``represent the
        # edges of the sampled subgraph, ``e_id`` are the ids of the edges in the
        # sample, and ``size`` holds the shape of the subgraph.
        # See ``torch_geometric.loader.neighbor_sampler.NeighborSampler`` for more info.
        for batch_size, n_id, adjs in train_loader:
            optimizer.zero_grad()
            out = model(x[n_id], adjs)
            loss = F.nll_loss(out, y[n_id[:batch_size]])
            loss.backward()
            optimizer.step()

        if train.world_rank() == 0:
            print(f"Epoch: {epoch:03d}, Loss: {loss:.4f}")

        train_accuracy = validation_accuracy = test_accuracy = None

        # Do validation on rank 0 worker only.
        if train.world_rank() == 0:
            model.eval()
            # ``prepare_model`` only wraps the model (exposing ``.module``)
            # when more than one worker is used; fall back to the bare model.
            eval_model = getattr(model, "module", model)
            with torch.no_grad():
                out = eval_model.test(x, subgraph_loader)
            res = out.argmax(dim=-1) == data.y
            train_accuracy = int(res[data.train_mask].sum()) / int(
                data.train_mask.sum()
            )
            validation_accuracy = int(res[data.val_mask].sum()) / int(
                data.val_mask.sum()
            )
            test_accuracy = int(res[data.test_mask].sum()) / int(data.test_mask.sum())

        # NOTE(review): the three accuracies are not reported anywhere in the
        # reviewed copy; a reporting call presumably followed - confirm.

Exemple #8
def prepare_data_loader(
    data_loader: torch.utils.data.DataLoader,
    add_dist_sampler: bool = True,
    move_to_device: bool = True,
) -> torch.utils.data.DataLoader:
    """Prepares DataLoader for distributed execution.

    This allows you to use the same exact code regardless of number of
    workers or the device type being used (CPU, GPU).

    Args:
        data_loader (torch.utils.data.DataLoader): The DataLoader to
            prepare.
        add_dist_sampler (bool): Whether to add a DistributedSampler to
            the provided DataLoader.
        move_to_device (bool): If set, automatically move the data
            returned by the data loader to the correct device.

    Returns:
        The prepared loader: rebuilt with a ``DistributedSampler`` when the
        conditions below hold, and wrapped in ``_WrappedDataLoader`` when
        ``move_to_device`` is set.
    """

    # Only add Distributed Sampler if the following conditions hold:
    # 1. More than one training worker is being used.
    # 2. A DistributedSampler has not already been added by the user.
    # 3. The dataset is not an IterableDataset. Samplers do not work with
    # IterableDatasets.
    if (train.world_size() > 1
            and not isinstance(data_loader.sampler, DistributedSampler)
            and not (hasattr(data_loader, "dataset")
                     and isinstance(data_loader.dataset, IterableDataset))
            and add_dist_sampler):

        def with_sampler(loader):
            # Automatically set the DistributedSampler

            # If using a sampler, the shuffle attribute in the
            # DataLoader must be set to False.
            # Instead the shuffling is determined by the shuffle attribute
            # in the DistributedSampler.
            # We identify if shuffling is enabled in the passed in
            # DataLoader by seeing if the sampler for the DataLoader is a
            # SequentialSampler.
            shuffle = not isinstance(loader.sampler, SequentialSampler)

            data_loader_args = {
                "dataset": loader.dataset,
                "batch_size": loader.batch_size,
                "shuffle": False,
                "num_workers": loader.num_workers,
                "collate_fn": loader.collate_fn,
                "pin_memory": loader.pin_memory,
                "drop_last": loader.drop_last,
                "timeout": loader.timeout,
                "worker_init_fn": loader.worker_init_fn,
                "sampler": DistributedSampler(loader.dataset, shuffle=shuffle),
            }
            return DataLoader(**data_loader_args)

        data_loader = with_sampler(data_loader)

    if move_to_device:
        device = get_device()
        data_loader = _WrappedDataLoader(data_loader, device)

    return data_loader
Exemple #9
    def prepare_model(
        self,
        model: torch.nn.Module,
        move_to_device: bool = True,
        wrap_ddp: bool = True,
        ddp_kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.nn.Module:
        """Prepares the model for distributed execution.

        This allows you to use the same exact code regardless of number of
        workers or the device type being used (CPU, GPU).

        Args:
            model (torch.nn.Module): A torch model to prepare.
            move_to_device (bool): Whether to move the model to the correct
                device. If set to False, the model needs to manually be moved
                to the correct device.
            wrap_ddp (bool): Whether to wrap models in
                ``DistributedDataParallel``.
            ddp_kwargs (Dict[str, Any]): Args to pass into
                ``DistributedDataParallel`` initialization if ``wrap_ddp`` is
                set to True.

        Returns:
            The prepared model. When ``self.amp_is_enabled`` is set, its
            ``forward`` runs under ``autocast`` and returns a float tensor.
        """
        # NOTE(review): the body of the first CUDA check and the arguments of
        # the CUDA ``DistributedDataParallel`` call were truncated in the
        # reviewed copy and have been restored - confirm against upstream.
        ddp_kwargs = ddp_kwargs or {}

        rank = train.local_rank()

        device = self.get_device()

        if torch.cuda.is_available():
            # Make this worker's device the current CUDA device.
            torch.cuda.set_device(device)

        if move_to_device:
            logger.info(f"Moving model to device: {device}")
            model = model.to(device)

        def wrap_forward(forward):
            # Run ``forward`` under autocast and cast the result to float32.
            def wrapper(*args, **kwargs):
                with autocast():
                    outputs = forward(*args, **kwargs)
                assert isinstance(outputs, torch.Tensor)
                return outputs.float()

            return wrapper

        def model_get_state(self):
            # `__getstate__` is an special method that informs pickle which attributes
            # to serialize. This custom implementation ensures that the wrapped forward
            # method and custom `__getstate__` method aren't serialized.
            state = self.__dict__.copy()
            state["forward"] = state["_unwrapped_forward"]
            del state["_unwrapped_forward"]
            del state["__getstate__"]
            return state

        if self.amp_is_enabled:
            # Pickle cannot serialize the wrapped forward method. As a workaround,
            # define a custom `__getstate__` method that unwraps the forward method.
            model._unwrapped_forward = model.forward
            model.forward = wrap_forward(model.forward)
            # `__getstate__` must be a bound method rather than an callable attribute.
            # See https://stackoverflow.com/questions/972/adding-a-method-to-an-existing-object-instance.  # noqa: E501
            # NOTE(review): Python 3.11+ gives every object a default
            # `__getstate__`, so this assertion presumably always fails there
            # - confirm and relax the check if newer interpreters matter.
            assert not hasattr(model, "__getstate__")
            model.__getstate__ = types.MethodType(model_get_state, model)

        if wrap_ddp and train.world_size() > 1:
            logger.info("Wrapping provided model in DDP.")
            if torch.cuda.is_available():
                # Pin the replica to the local GPU of this worker.
                model = DistributedDataParallel(
                    model,
                    device_ids=[rank],
                    output_device=rank,
                    **ddp_kwargs,
                )
            else:
                model = DistributedDataParallel(model, **ddp_kwargs)

        return model
Exemple #10
    def prepare_data_loader(
        self,
        data_loader: torch.utils.data.DataLoader,
        add_dist_sampler: bool = True,
        move_to_device: bool = True,
        auto_transfer: bool = True,
    ) -> torch.utils.data.DataLoader:
        """Prepares DataLoader for distributed execution.

        This allows you to use the same exact code regardless of number of
        workers or the device type being used (CPU, GPU).

        Args:
            data_loader (torch.utils.data.DataLoader): The DataLoader to
                prepare.
            add_dist_sampler (bool): Whether to add a DistributedSampler to
                the provided DataLoader.
            move_to_device (bool): If set, automatically move the data
                returned by the data loader to the correct device.
            auto_transfer (bool): If set and device is GPU, another CUDA stream
                is created to automatically copy data from host (CPU) memory
                to device (GPU) memory (the default CUDA stream still runs the
                training procedure). If device is CPU, it will be disabled
                regardless of the setting. This configuration will be ignored
                if ``move_to_device`` is False.

        Returns:
            The prepared loader: rebuilt with a ``DistributedSampler`` when
            the conditions below hold, and wrapped in ``_WrappedDataLoader``
            when ``move_to_device`` is set.
        """
        # NOTE(review): the seeding wrapper body, the generator seeding, the
        # warning call, the sampler arguments and the ``_WrappedDataLoader``
        # arguments were truncated in the reviewed copy and have been
        # restored - confirm against upstream.

        # Only add Distributed Sampler if the following conditions hold:
        # 1. More than one training worker is being used.
        # 2. A DistributedSampler has not already been added by the user.
        # 3. The dataset is not an IterableDataset. Samplers do not work with
        # IterableDatasets.
        if (train.world_size() > 1
                and not isinstance(data_loader.sampler, DistributedSampler)
                and not (hasattr(data_loader, "dataset")
                         and isinstance(data_loader.dataset, IterableDataset))
                and add_dist_sampler):

            def with_sampler(loader):
                # Automatically set the DistributedSampler

                # If you're using a sampler, the DataLoader shuffle flag must be set to
                # False. Shuffling is instead determined by the shuffle argument passed
                # to the DistributedSampler constructor.

                # If no sampler is passed to the DataLoader constructor, Torch
                # constructs a default sampler. The default sampler is a RandomSampler
                # if shuffling is enabled and a SequentialSampler otherwise. DataLoader
                # does not have a shuffle attribute, so we instead identify whether
                # shuffling is enabled by checking the default sampler type.
                shuffle = not isinstance(loader.sampler, SequentialSampler)

                def seeded_worker_init_fn(worker_init_fn):
                    # Seed NumPy and ``random`` in each DataLoader worker
                    # from the torch seed, then run the user's init function.
                    import random

                    import numpy as np

                    def wrapper(worker_id):
                        worker_seed = torch.initial_seed() % 2**32
                        np.random.seed(worker_seed)
                        random.seed(worker_seed)
                        # The loader may have no user-provided init function.
                        if worker_init_fn is not None:
                            worker_init_fn(worker_id)

                    return wrapper

                worker_init_fn = loader.worker_init_fn
                generator = loader.generator
                if self._seed is not None:
                    worker_init_fn = seeded_worker_init_fn(
                        loader.worker_init_fn)
                    generator = torch.Generator()
                    generator.manual_seed(self._seed)

                using_default_sampler = isinstance(
                    loader.sampler, (SequentialSampler, RandomSampler))
                if not using_default_sampler and train.world_rank() == 0:
                    # The message names ``add_dist_sampler``, the actual
                    # argument that disables the replacement.
                    logger.warning(
                        f"The {loader.sampler.__class__.__name__} will be overwritten "
                        "with a DistributedSampler. You can disable this by setting "
                        "`add_dist_sampler` to False in `prepare_data_loader`.")

                data_loader_args = {
                    "dataset": loader.dataset,
                    "batch_size": loader.batch_size,
                    "shuffle": False,
                    "num_workers": loader.num_workers,
                    "collate_fn": loader.collate_fn,
                    "pin_memory": loader.pin_memory,
                    "drop_last": loader.drop_last,
                    "timeout": loader.timeout,
                    "worker_init_fn": worker_init_fn,
                    "generator": generator,
                    "sampler": DistributedSampler(loader.dataset,
                                                  shuffle=shuffle),
                }
                return DataLoader(**data_loader_args)

            data_loader = with_sampler(data_loader)

        if move_to_device:
            device = self.get_device()
            data_loader = _WrappedDataLoader(data_loader, device,
                                             auto_transfer)

        return data_loader
Exemple #11
    def prepare_model(
        self,
        model: torch.nn.Module,
        move_to_device: bool = True,
        wrap_ddp: bool = True,
        ddp_kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.nn.Module:
        """Prepares the model for distributed execution.

        This allows you to use the same exact code regardless of number of
        workers or the device type being used (CPU, GPU).

        Args:
            model (torch.nn.Module): A torch model to prepare.
            move_to_device: Whether to move the model to the correct
                device. If set to False, the model needs to manually be moved
                to the correct device.
            wrap_ddp: Whether to wrap models in
                ``DistributedDataParallel``.
            ddp_kwargs (Dict[str, Any]): Args to pass into
                ``DistributedDataParallel`` initialization if ``wrap_ddp`` is
                set to True.

        Returns:
            The prepared model. When ``self.amp_is_enabled`` is set, its
            ``forward`` is wrapped with ``autocast``.
        """
        # NOTE(review): the body of the first CUDA check, the ``else``
        # branches and the arguments of the CUDA ``DistributedDataParallel``
        # call were truncated in the reviewed copy and have been restored -
        # confirm against upstream.
        ddp_kwargs = ddp_kwargs or {}

        rank = train.local_rank()

        device = self.get_device()

        if torch.cuda.is_available():
            # Make this worker's device the current CUDA device.
            torch.cuda.set_device(device)

        if move_to_device:
            logger.info(f"Moving model to device: {device}")
            model = model.to(device)

        def model_get_state(self):
            # `__getstate__` is an special method that informs pickle which attributes
            # to serialize. This custom implementation ensures that the wrapped forward
            # method and custom `__getstate__` method aren't serialized.
            if hasattr(self, "_original_get_state"):
                state = self._original_get_state()
                state["__getstate__"] = state["_original_get_state"]
                del state["_original_get_state"]
            else:
                # If model does not have a `__getstate__` already defined, use default
                # implementation.
                state = self.__dict__.copy()
                del state["__getstate__"]
            state["forward"] = state["_unwrapped_forward"]
            del state["_unwrapped_forward"]

            return state

        if self.amp_is_enabled:
            # Pickle cannot serialize the wrapped forward method. As a workaround,
            # define a custom `__getstate__` method that unwraps the forward method.
            model._unwrapped_forward = model.forward
            model.forward = autocast()(model.forward)

            # TODO(amogkam): Replace below logic with a generic "unpack model" method.
            # Replacing the `model.forward` method makes the model no longer
            # serializable. When serializing the model, we have to override the
            # `__getstate__` method to set back the original forward method.
            if hasattr(model, "__getstate__"):
                model._original_get_state = model.__getstate__
            # `__getstate__` must be a bound method rather than an callable attribute.
            # See https://stackoverflow.com/questions/972/adding-a-method-to-an-existing-object-instance.  # noqa: E501
            model.__getstate__ = types.MethodType(model_get_state, model)

        if wrap_ddp and train.world_size() > 1:
            logger.info("Wrapping provided model in DDP.")
            if torch.cuda.is_available():
                # Pin the replica to the local GPU of this worker.
                model = DistributedDataParallel(
                    model,
                    device_ids=[rank],
                    output_device=rank,
                    **ddp_kwargs,
                )
            else:
                model = DistributedDataParallel(model, **ddp_kwargs)

        return model