Example #1
def get_device() -> torch.device:
    """Gets the correct torch device to use for training."""
    if torch.cuda.is_available():
        rank = train.local_rank()
        device = torch.device(f"cuda:{rank}")
    else:
        device = torch.device("cpu")

    return device
Example #2
def train_func(config):
    batch_size = config.get("batch_size", 32)
    hidden_size = config.get("hidden_size", 1)
    lr = config.get("lr", 1e-2)
    epochs = config.get("epochs", 3)

    train_dataset_pipeline_shard = train.get_dataset_shard("train")
    validation_dataset_pipeline_shard = train.get_dataset_shard("validation")

    device = torch.device(
        f"cuda:{train.local_rank()}" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        torch.cuda.set_device(device)

    model = nn.Linear(1, hidden_size)
    model = model.to(device)
    model = DistributedDataParallel(
        model,
        device_ids=[train.local_rank()] if torch.cuda.is_available() else None)

    loss_fn = nn.MSELoss()

    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    results = []

    train_dataset_iterator = train_dataset_pipeline_shard.iter_datasets()
    validation_dataset_iterator = \
        validation_dataset_pipeline_shard.iter_datasets()

    for _ in range(epochs):
        train_dataset = next(train_dataset_iterator)
        validation_dataset = next(validation_dataset_iterator)

        train_torch_dataset = train_dataset.to_torch(
            label_column="y",
            feature_columns=["x"],
            label_column_dtype=torch.float,
            feature_column_dtypes=[torch.float],
            batch_size=batch_size,
        )
        validation_torch_dataset = validation_dataset.to_torch(
            label_column="y",
            feature_columns=["x"],
            label_column_dtype=torch.float,
            feature_column_dtypes=[torch.float],
            batch_size=batch_size)

        train_epoch(train_torch_dataset, model, loss_fn, optimizer, device)
        result = validate_epoch(validation_torch_dataset, model, loss_fn,
                                device)
        train.report(**result)
        results.append(result)

    return results
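For context, a hedged sketch of the driver-side setup that could feed this function, assuming the legacy Ray 1.x `Trainer` API in which `run` accepts a `dataset` mapping keyed by the names passed to `get_dataset_shard`; the column values and worker count below are illustrative:

import ray
from ray.train import Trainer

# Tiny "x"/"y" datasets repeated into pipelines so each epoch pulls a fresh window.
train_pipeline = ray.data.from_items(
    [{"x": float(i), "y": 2 * float(i)} for i in range(128)]).repeat()
validation_pipeline = ray.data.from_items(
    [{"x": float(i), "y": 2 * float(i)} for i in range(32)]).repeat()

trainer = Trainer(backend="torch", num_workers=2)
trainer.start()
results = trainer.run(
    train_func,
    config={"batch_size": 32, "epochs": 3},
    dataset={"train": train_pipeline, "validation": validation_pipeline},
)
trainer.shutdown()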
Example #3
    def prepare_model(
        self,
        model: torch.nn.Module,
        move_to_device: bool = True,
        wrap_ddp: bool = True,
        ddp_kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.nn.Module:
        """Prepares the model for distributed execution.

        This allows you to use the same exact code regardless of number of
        workers or the device type being used (CPU, GPU).

        Args:
            model (torch.nn.Module): A torch model to prepare.
            move_to_device (bool): Whether to move the model to the correct
                device. If set to False, the model needs to manually be moved
                to the correct device.
            wrap_ddp (bool): Whether to wrap models in
                ``DistributedDataParallel``.
            ddp_kwargs (Dict[str, Any]): Args to pass into
                ``DistributedDataParallel`` initialization if ``wrap_ddp`` is
                set to True.
        """
        ddp_kwargs = ddp_kwargs or {}

        rank = train.local_rank()

        device = self.get_device()

        if torch.cuda.is_available():
            torch.cuda.set_device(device)

        if move_to_device:
            logger.info(f"Moving model to device: {device}")
            model = model.to(device)
        if wrap_ddp and train.world_size() > 1:
            logger.info("Wrapping provided model in DDP.")
            if torch.cuda.is_available():
                model = DistributedDataParallel(model,
                                                device_ids=[rank],
                                                output_device=rank,
                                                **ddp_kwargs)
            else:
                model = DistributedDataParallel(model, **ddp_kwargs)

        return model
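In user code this preparation step is normally reached through Ray Train's public helper rather than by calling the method above directly. A minimal sketch inside a training function, assuming `ray.train.torch.prepare_model` delegates to logic like this:

import torch.nn as nn
import ray.train.torch as train_torch

def train_func():
    model = nn.Linear(4, 1)
    # Moves the model to this worker's device and wraps it in
    # DistributedDataParallel when more than one worker is running.
    model = train_torch.prepare_model(model)
    return type(model).__name__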
Example #4
def _huggingface_train_loop_per_worker(config):
    """Per-worker training loop for HuggingFace Transformers."""
    trainer_init_per_worker = config.pop("_trainer_init_per_worker")

    # Env vars necessary for HF to set up DDP
    os.environ["RANK"] = str(train.world_rank())
    os.environ["WORLD_SIZE"] = str(train.world_size())
    os.environ["LOCAL_RANK"] = str(train.local_rank())

    train_dataset = train.get_dataset_shard(TRAIN_DATASET_KEY)
    eval_dataset = train.get_dataset_shard(EVALUATION_DATASET_KEY)

    train_torch_dataset, eval_torch_dataset = process_datasets(
        train_dataset,
        eval_dataset,
    )

    trainer: transformers.trainer.Trainer = trainer_init_per_worker(
        train_torch_dataset, eval_torch_dataset, **config)

    if trainer.args.push_to_hub and not trainer.args.hub_token:
        warnings.warn(
            "You have set `push_to_hub=True` but didn't specify `hub_token`. "
            "Pushing to hub will most likely fail, as the credentials will not "
            "be automatically propagated from the local enviroment to the Ray Actors. "
            "If that happens, specify `hub_token` in `TrainingArguments`.")

    if (trainer.args.evaluation_strategy == "steps"
            or trainer.args.save_strategy == "steps"
            or trainer.args.logging_strategy == "steps"):
        raise ValueError(
            "'steps' value for `evaluation_strategy`, `logging_strategy` "
            "or `save_strategy` is not yet supported.")

    trainer = wrap_transformers_trainer(trainer)

    # Ensure no HF logging callbacks are added: aside from duplicating
    # functionality with our own callbacks, the Wandb callback causes
    # training to freeze.
    integration_callbacks = transformers.trainer.get_reporting_integration_callbacks(
        trainer.args.report_to)
    for callback in integration_callbacks:
        trainer.pop_callback(callback)

    trainer.add_callback(TrainReportCallback)

    checkpoint = session.get_checkpoint()
    checkpoint_path = None
    remove_checkpoint_path = False
    if checkpoint:
        assert isinstance(checkpoint, Checkpoint)
        checkpoint_dict = checkpoint.to_dict()
        source_ip = checkpoint_dict[NODE_IP_KEY]
        source_path = checkpoint_dict[CHECKPOINT_PATH_ON_NODE_KEY]
        target_ip = get_node_ip_address()
        if source_ip == target_ip:
            checkpoint_path = source_path
        else:
            checkpoint_path = tempfile.mkdtemp(
                suffix=Path(trainer.args.output_dir).name)
            remove_checkpoint_path = True
            sync_dir_between_nodes(
                source_ip=source_ip,
                source_path=source_path,
                target_ip=target_ip,
                target_path=checkpoint_path,
                return_futures=False,
                max_size_bytes=None,
            )
    trainer.train(resume_from_checkpoint=checkpoint_path)
    if remove_checkpoint_path:
        shutil.rmtree(checkpoint_path, ignore_errors=True)
Example #5
def train_actor_failure():
    import sys
    # Exit abruptly so the actor dies before a local rank is returned;
    # the final return statement is intentionally unreachable.
    sys.exit(0)
    return train.local_rank()
Example #6
def train_func():
    return train.local_rank()
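Even a function this small has to be launched by a Ray Train backend to be meaningful. A sketch of doing so with the legacy `ray.train.Trainer` API these snippets appear to target (the worker count is illustrative):

from ray.train import Trainer

trainer = Trainer(backend="torch", num_workers=2)
trainer.start()
# Each worker executes train_func; the returned list holds one local
# rank per worker, e.g. [0, 1] when both workers share a single node.
local_ranks = trainer.run(train_func)
trainer.shutdown()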
Example #7
    def prepare_model(
        self,
        model: torch.nn.Module,
        move_to_device: bool = True,
        wrap_ddp: bool = True,
        ddp_kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.nn.Module:
        """Prepares the model for distributed execution.

        This allows you to use the same exact code regardless of number of
        workers or the device type being used (CPU, GPU).

        Args:
            model (torch.nn.Module): A torch model to prepare.
            move_to_device (bool): Whether to move the model to the correct
                device. If set to False, the model needs to manually be moved
                to the correct device.
            wrap_ddp (bool): Whether to wrap models in
                ``DistributedDataParallel``.
            ddp_kwargs (Dict[str, Any]): Args to pass into
                ``DistributedDataParallel`` initialization if ``wrap_ddp`` is
                set to True.
        """
        ddp_kwargs = ddp_kwargs or {}

        rank = train.local_rank()

        device = self.get_device()

        if torch.cuda.is_available():
            torch.cuda.set_device(device)

        if move_to_device:
            logger.info(f"Moving model to device: {device}")
            model = model.to(device)

        def wrap_forward(forward):
            @functools.wraps(forward)
            def wrapper(*args, **kwargs):
                with autocast():
                    outputs = forward(*args, **kwargs)
                assert isinstance(outputs, torch.Tensor)
                return outputs.float()

            return wrapper

        def model_get_state(self):
            # `__getstate__` is a special method that informs pickle which attributes
            # to serialize. This custom implementation ensures that the wrapped forward
            # method and custom `__getstate__` method aren't serialized.
            state = self.__dict__.copy()
            state["forward"] = state["_unwrapped_forward"]
            del state["_unwrapped_forward"]
            del state["__getstate__"]
            return state

        if self.amp_is_enabled:
            # Pickle cannot serialize the wrapped forward method. As a workaround,
            # define a custom `__getstate__` method that unwraps the forward method.
            model._unwrapped_forward = model.forward
            model.forward = wrap_forward(model.forward)
            # `__getstate__` must be a bound method rather than a callable attribute.
            # See https://stackoverflow.com/questions/972/adding-a-method-to-an-existing-object-instance.  # noqa: E501
            assert not hasattr(model, "__getstate__")
            model.__getstate__ = types.MethodType(model_get_state, model)

        if wrap_ddp and train.world_size() > 1:
            logger.info("Wrapping provided model in DDP.")
            if torch.cuda.is_available():
                model = DistributedDataParallel(model,
                                                device_ids=[rank],
                                                output_device=rank,
                                                **ddp_kwargs)
            else:
                model = DistributedDataParallel(model, **ddp_kwargs)

        return model
Example #8
    def prepare_model(
        self,
        model: torch.nn.Module,
        move_to_device: bool = True,
        wrap_ddp: bool = True,
        ddp_kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.nn.Module:
        """Prepares the model for distributed execution.

        This allows you to use the same exact code regardless of number of
        workers or the device type being used (CPU, GPU).

        Args:
            model (torch.nn.Module): A torch model to prepare.
            move_to_device (bool): Whether to move the model to the correct
                device. If set to False, the model needs to manually be moved
                to the correct device.
            wrap_ddp (bool): Whether to wrap models in
                ``DistributedDataParallel``.
            ddp_kwargs (Dict[str, Any]): Args to pass into
                ``DistributedDataParallel`` initialization if ``wrap_ddp`` is
                set to True.
        """
        ddp_kwargs = ddp_kwargs or {}

        rank = train.local_rank()

        device = self.get_device()

        if torch.cuda.is_available():
            torch.cuda.set_device(device)

        if move_to_device:
            logger.info(f"Moving model to device: {device}")
            model = model.to(device)

        def model_get_state(self):
            # `__getstate__` is a special method that informs pickle which attributes
            # to serialize. This custom implementation ensures that the wrapped forward
            # method and custom `__getstate__` method aren't serialized.
            if hasattr(self, "_original_get_state"):
                state = self._original_get_state()
                state["__getstate__"] = state["_original_get_state"]
                del state["_original_get_state"]
            else:
                # If model does not have a `__getstate__` already defined, use default
                # implementation.
                state = self.__dict__.copy()
                del state["__getstate__"]
            state["forward"] = state["_unwrapped_forward"]
            del state["_unwrapped_forward"]

            return state

        if self.amp_is_enabled:
            # Pickle cannot serialize the wrapped forward method. As a workaround,
            # define a custom `__getstate__` method that unwraps the forward method.
            model._unwrapped_forward = model.forward
            model.forward = autocast()(model.forward)

            # TODO(amogkam): Replace below logic with a generic "unpack model" method.
            # Replacing the `model.forward` method makes the model no longer
            # serializable. When serializing the model, we have to override the
            # `__getstate__` method to set back the original forward method.
            if hasattr(model, "__getstate__"):
                model._original_get_state = model.__getstate__
            # `__getstate__` must be a bound method rather than a callable attribute.
            # See https://stackoverflow.com/questions/972/adding-a-method-to-an-existing-object-instance.  # noqa: E501
            model.__getstate__ = types.MethodType(model_get_state, model)

        if wrap_ddp and train.world_size() > 1:
            logger.info("Wrapping provided model in DDP.")
            if torch.cuda.is_available():
                model = DistributedDataParallel(model,
                                                device_ids=[rank],
                                                output_device=rank,
                                                **ddp_kwargs)
            else:
                model = DistributedDataParallel(model, **ddp_kwargs)

        return model