def get_device() -> torch.device:
    """Gets the correct torch device to use for training."""
    if torch.cuda.is_available():
        rank = train.local_rank()
        device = torch.device(f"cuda:{rank}")
    else:
        device = torch.device("cpu")

    return device
def train_func(config):
    batch_size = config.get("batch_size", 32)
    hidden_size = config.get("hidden_size", 1)
    lr = config.get("lr", 1e-2)
    epochs = config.get("epochs", 3)

    train_dataset_pipeline_shard = train.get_dataset_shard("train")
    validation_dataset_pipeline_shard = train.get_dataset_shard("validation")

    device = torch.device(
        f"cuda:{train.local_rank()}" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        torch.cuda.set_device(device)

    model = nn.Linear(1, hidden_size)
    model = model.to(device)
    model = DistributedDataParallel(
        model,
        device_ids=[train.local_rank()] if torch.cuda.is_available() else None)

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    results = []

    train_dataset_iterator = train_dataset_pipeline_shard.iter_datasets()
    validation_dataset_iterator = \
        validation_dataset_pipeline_shard.iter_datasets()

    for _ in range(epochs):
        train_dataset = next(train_dataset_iterator)
        validation_dataset = next(validation_dataset_iterator)

        train_torch_dataset = train_dataset.to_torch(
            label_column="y",
            feature_columns=["x"],
            label_column_dtype=torch.float,
            feature_column_dtypes=[torch.float],
            batch_size=batch_size,
        )
        validation_torch_dataset = validation_dataset.to_torch(
            label_column="y",
            feature_columns=["x"],
            label_column_dtype=torch.float,
            feature_column_dtypes=[torch.float],
            batch_size=batch_size)

        train_epoch(train_torch_dataset, model, loss_fn, optimizer, device)
        result = validate_epoch(validation_torch_dataset, model, loss_fn,
                                device)
        train.report(**result)
        results.append(result)

    return results
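# A minimal launch sketch for the training function above. It assumes the
# legacy ``ray.train.Trainer`` API and that the "train" and "validation"
# datasets are passed in as Ray DatasetPipelines (the ``iter_datasets()``
# calls above require pipelines, e.g. created with ``.repeat()``). Dataset
# contents and worker counts below are illustrative, and ``train_epoch`` /
# ``validate_epoch`` are assumed to be defined alongside ``train_func``.
import ray
from ray.train import Trainer

ray.init()

train_pipeline = ray.data.from_items(
    [{"x": float(i), "y": 2 * float(i)} for i in range(200)]).repeat()
validation_pipeline = ray.data.from_items(
    [{"x": float(i), "y": 2 * float(i)} for i in range(40)]).repeat()

trainer = Trainer(backend="torch", num_workers=2)
trainer.start()
results = trainer.run(
    train_func,
    config={"lr": 1e-2, "hidden_size": 1, "batch_size": 32, "epochs": 3},
    dataset={"train": train_pipeline, "validation": validation_pipeline},
)
trainer.shutdown()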
def prepare_model(
        self,
        model: torch.nn.Module,
        move_to_device: bool = True,
        wrap_ddp: bool = True,
        ddp_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.nn.Module:
    """Prepares the model for distributed execution.

    This allows you to use the same exact code regardless of
    number of workers or the device type being used (CPU, GPU).

    Args:
        model (torch.nn.Module): A torch model to prepare.
        move_to_device (bool): Whether to move the model to the correct
            device. If set to False, the model needs to manually be moved
            to the correct device.
        wrap_ddp (bool): Whether to wrap models in
            ``DistributedDataParallel``.
        ddp_kwargs (Dict[str, Any]): Args to pass into
            ``DistributedDataParallel`` initialization if ``wrap_ddp`` is
            set to True.
    """
    ddp_kwargs = ddp_kwargs or {}

    rank = train.local_rank()

    device = self.get_device()

    if torch.cuda.is_available():
        torch.cuda.set_device(device)

    if move_to_device:
        logger.info(f"Moving model to device: {device}")
        model = model.to(device)

    if wrap_ddp and train.world_size() > 1:
        logger.info("Wrapping provided model in DDP.")
        if torch.cuda.is_available():
            model = DistributedDataParallel(
                model, device_ids=[rank], output_device=rank, **ddp_kwargs)
        else:
            model = DistributedDataParallel(model, **ddp_kwargs)

    return model
def _huggingface_train_loop_per_worker(config):
    """Per-worker training loop for HuggingFace Transformers."""
    trainer_init_per_worker = config.pop("_trainer_init_per_worker")

    # Env vars necessary for HF to set up DDP
    os.environ["RANK"] = str(train.world_rank())
    os.environ["WORLD_SIZE"] = str(train.world_size())
    os.environ["LOCAL_RANK"] = str(train.local_rank())

    train_dataset = train.get_dataset_shard(TRAIN_DATASET_KEY)
    eval_dataset = train.get_dataset_shard(EVALUATION_DATASET_KEY)

    train_torch_dataset, eval_torch_dataset = process_datasets(
        train_dataset,
        eval_dataset,
    )

    trainer: transformers.trainer.Trainer = trainer_init_per_worker(
        train_torch_dataset, eval_torch_dataset, **config)

    if trainer.args.push_to_hub and not trainer.args.hub_token:
        warnings.warn(
            "You have set `push_to_hub=True` but didn't specify `hub_token`. "
            "Pushing to hub will most likely fail, as the credentials will not "
            "be automatically propagated from the local environment to the "
            "Ray Actors. If that happens, specify `hub_token` in "
            "`TrainingArguments`.")

    if (trainer.args.evaluation_strategy == "steps"
            or trainer.args.save_strategy == "steps"
            or trainer.args.logging_strategy == "steps"):
        raise ValueError(
            "'steps' value for `evaluation_strategy`, `logging_strategy` "
            "or `save_strategy` is not yet supported.")

    trainer = wrap_transformers_trainer(trainer)

    # Ensure no HF logging callbacks are added.
    # Aside from duplicating functionality with our callbacks,
    # the Wandb callback causes training to freeze.
    integration_callbacks = transformers.trainer.get_reporting_integration_callbacks(
        trainer.args.report_to)
    for callback in integration_callbacks:
        trainer.pop_callback(callback)

    trainer.add_callback(TrainReportCallback)

    checkpoint = session.get_checkpoint()
    checkpoint_path = None
    remove_checkpoint_path = False
    if checkpoint:
        assert isinstance(checkpoint, Checkpoint)
        checkpoint_dict = checkpoint.to_dict()
        source_ip = checkpoint_dict[NODE_IP_KEY]
        source_path = checkpoint_dict[CHECKPOINT_PATH_ON_NODE_KEY]
        target_ip = get_node_ip_address()
        if source_ip == target_ip:
            checkpoint_path = source_path
        else:
            checkpoint_path = tempfile.mkdtemp(
                suffix=Path(trainer.args.output_dir).name)
            remove_checkpoint_path = True
            sync_dir_between_nodes(
                source_ip=source_ip,
                source_path=source_path,
                target_ip=target_ip,
                target_path=checkpoint_path,
                return_futures=False,
                max_size_bytes=None,
            )

    trainer.train(resume_from_checkpoint=checkpoint_path)

    if remove_checkpoint_path:
        shutil.rmtree(checkpoint_path, ignore_errors=True)
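# A hypothetical ``trainer_init_per_worker`` sketch, for context only: the
# loop above assumes such a function takes the two torch datasets plus the
# remaining config and returns a ``transformers.Trainer``. The model name and
# argument values here are illustrative, not part of the original code; note
# that "steps"-based strategies are rejected by the check above.
import transformers


def trainer_init_per_worker(train_dataset, eval_dataset, **config):
    model = transformers.AutoModelForSequenceClassification.from_pretrained(
        "distilbert-base-uncased")
    args = transformers.TrainingArguments(
        output_dir="hf_output",
        evaluation_strategy="epoch",
        save_strategy="epoch",
        logging_strategy="epoch",
        num_train_epochs=config.get("epochs", 1),
    )
    return transformers.Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )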
def train_actor_failure():
    # Exit the worker process to simulate a training actor failure; the
    # return statement below is never reached.
    import sys

    sys.exit(0)
    return train.local_rank()
def train_func():
    return train.local_rank()
def prepare_model(
        self,
        model: torch.nn.Module,
        move_to_device: bool = True,
        wrap_ddp: bool = True,
        ddp_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.nn.Module:
    """Prepares the model for distributed execution.

    This allows you to use the same exact code regardless of
    number of workers or the device type being used (CPU, GPU).

    Args:
        model (torch.nn.Module): A torch model to prepare.
        move_to_device (bool): Whether to move the model to the correct
            device. If set to False, the model needs to manually be moved
            to the correct device.
        wrap_ddp (bool): Whether to wrap models in
            ``DistributedDataParallel``.
        ddp_kwargs (Dict[str, Any]): Args to pass into
            ``DistributedDataParallel`` initialization if ``wrap_ddp`` is
            set to True.
    """
    ddp_kwargs = ddp_kwargs or {}

    rank = train.local_rank()

    device = self.get_device()

    if torch.cuda.is_available():
        torch.cuda.set_device(device)

    if move_to_device:
        logger.info(f"Moving model to device: {device}")
        model = model.to(device)

    def wrap_forward(forward):
        @functools.wraps(forward)
        def wrapper(*args, **kwargs):
            with autocast():
                outputs = forward(*args, **kwargs)
            assert isinstance(outputs, torch.Tensor)
            return outputs.float()

        return wrapper

    def model_get_state(self):
        # `__getstate__` is a special method that informs pickle which
        # attributes to serialize. This custom implementation ensures that
        # the wrapped forward method and custom `__getstate__` method aren't
        # serialized.
        state = self.__dict__.copy()
        state["forward"] = state["_unwrapped_forward"]
        del state["_unwrapped_forward"]
        del state["__getstate__"]
        return state

    if self.amp_is_enabled:
        # Pickle cannot serialize the wrapped forward method. As a
        # workaround, define a custom `__getstate__` method that unwraps
        # the forward method.
        model._unwrapped_forward = model.forward
        model.forward = wrap_forward(model.forward)
        # `__getstate__` must be a bound method rather than a callable
        # attribute. See
        # https://stackoverflow.com/questions/972/adding-a-method-to-an-existing-object-instance.  # noqa: E501
        assert not hasattr(model, "__getstate__")
        model.__getstate__ = types.MethodType(model_get_state, model)

    if wrap_ddp and train.world_size() > 1:
        logger.info("Wrapping provided model in DDP.")
        if torch.cuda.is_available():
            model = DistributedDataParallel(
                model, device_ids=[rank], output_device=rank, **ddp_kwargs)
        else:
            model = DistributedDataParallel(model, **ddp_kwargs)

    return model
def prepare_model(
        self,
        model: torch.nn.Module,
        move_to_device: bool = True,
        wrap_ddp: bool = True,
        ddp_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.nn.Module:
    """Prepares the model for distributed execution.

    This allows you to use the same exact code regardless of
    number of workers or the device type being used (CPU, GPU).

    Args:
        model (torch.nn.Module): A torch model to prepare.
        move_to_device: Whether to move the model to the correct device.
            If set to False, the model needs to manually be moved to the
            correct device.
        wrap_ddp: Whether to wrap models in ``DistributedDataParallel``.
        ddp_kwargs (Dict[str, Any]): Args to pass into
            ``DistributedDataParallel`` initialization if ``wrap_ddp`` is
            set to True.
    """
    ddp_kwargs = ddp_kwargs or {}

    rank = train.local_rank()

    device = self.get_device()

    if torch.cuda.is_available():
        torch.cuda.set_device(device)

    if move_to_device:
        logger.info(f"Moving model to device: {device}")
        model = model.to(device)

    def model_get_state(self):
        # `__getstate__` is a special method that informs pickle which
        # attributes to serialize. This custom implementation ensures that
        # the wrapped forward method and custom `__getstate__` method aren't
        # serialized.
        if hasattr(self, "_original_get_state"):
            state = self._original_get_state()
            state["__getstate__"] = state["_original_get_state"]
            del state["_original_get_state"]
        else:
            # If the model does not have a `__getstate__` already defined,
            # use the default implementation.
            state = self.__dict__.copy()
            del state["__getstate__"]
        state["forward"] = state["_unwrapped_forward"]
        del state["_unwrapped_forward"]
        return state

    if self.amp_is_enabled:
        # Pickle cannot serialize the wrapped forward method. As a
        # workaround, define a custom `__getstate__` method that unwraps
        # the forward method.
        model._unwrapped_forward = model.forward
        model.forward = autocast()(model.forward)

        # TODO(amogkam): Replace below logic with a generic "unpack model"
        # method.
        # Replacing the `model.forward` method makes the model no longer
        # serializable. When serializing the model, we have to override the
        # `__getstate__` method to set back the original forward method.
        if hasattr(model, "__getstate__"):
            model._original_get_state = model.__getstate__
        # `__getstate__` must be a bound method rather than a callable
        # attribute. See
        # https://stackoverflow.com/questions/972/adding-a-method-to-an-existing-object-instance.  # noqa: E501
        model.__getstate__ = types.MethodType(model_get_state, model)

    if wrap_ddp and train.world_size() > 1:
        logger.info("Wrapping provided model in DDP.")
        if torch.cuda.is_available():
            model = DistributedDataParallel(
                model, device_ids=[rank], output_device=rank, **ddp_kwargs)
        else:
            model = DistributedDataParallel(model, **ddp_kwargs)

    return model
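# A minimal usage sketch for ``prepare_model``, assuming the method above is
# exposed to user code as ``ray.train.torch.prepare_model`` and that workers
# are launched with the legacy ``ray.train.Trainer`` API; the tiny model,
# synthetic data, and worker count are illustrative.
import torch
import torch.nn as nn
from ray.train import Trainer
from ray.train.torch import prepare_model


def train_func():
    model = nn.Linear(4, 1)
    # Moves the model to the right device and, with more than one worker,
    # wraps it in DistributedDataParallel.
    model = prepare_model(model)
    device = next(model.parameters()).device

    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    loss_fn = nn.MSELoss()
    x = torch.randn(8, 4, device=device)
    y = torch.randn(8, 1, device=device)
    for _ in range(5):
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()


trainer = Trainer(backend="torch", num_workers=2)
trainer.start()
trainer.run(train_func)
trainer.shutdown()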