def prepare_model(
    self,
    model: torch.nn.Module,
    move_to_device: bool = True,
    wrap_ddp: bool = True,
    ddp_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.nn.Module:
    """Prepares the model for distributed execution.

    This allows you to use the same exact code regardless of number of
    workers or the device type being used (CPU, GPU).

    Args:
        model (torch.nn.Module): A torch model to prepare.
        move_to_device (bool): Whether to move the model to the correct
            device. If set to False, the model needs to be moved to the
            correct device manually.
        wrap_ddp (bool): Whether to wrap models in
            ``DistributedDataParallel``.
        ddp_kwargs (Dict[str, Any]): Args to pass into
            ``DistributedDataParallel`` initialization if ``wrap_ddp`` is
            set to True.
    """
    ddp_kwargs = ddp_kwargs or {}

    rank = train.local_rank()

    device = self.get_device()

    if torch.cuda.is_available():
        # Make the worker's assigned GPU the default CUDA device.
        torch.cuda.set_device(device)

    if move_to_device:
        logger.info(f"Moving model to device: {device}")
        model = model.to(device)

    def wrap_forward(forward):
        @functools.wraps(forward)
        def wrapper(*args, **kwargs):
            # Run the forward pass under autocast, then cast the (possibly
            # reduced-precision) outputs back to float32.
            with autocast():
                outputs = forward(*args, **kwargs)
            assert isinstance(outputs, torch.Tensor)
            return outputs.float()

        return wrapper

    def model_get_state(self):
        # `__getstate__` is a special method that informs pickle which
        # attributes to serialize. This custom implementation ensures that
        # the wrapped forward method and custom `__getstate__` method
        # aren't serialized.
        state = self.__dict__.copy()
        state["forward"] = state["_unwrapped_forward"]
        del state["_unwrapped_forward"]
        del state["__getstate__"]
        return state

    if self.amp_is_enabled:
        # Pickle cannot serialize the wrapped forward method. As a
        # workaround, define a custom `__getstate__` method that unwraps
        # the forward method.
        model._unwrapped_forward = model.forward
        model.forward = wrap_forward(model.forward)
        # `__getstate__` must be a bound method rather than a callable
        # attribute.
        # See https://stackoverflow.com/questions/972/adding-a-method-to-an-existing-object-instance.  # noqa: E501
        assert not hasattr(model, "__getstate__")
        model.__getstate__ = types.MethodType(model_get_state, model)

    if wrap_ddp and train.world_size() > 1:
        logger.info("Wrapping provided model in DDP.")
        if torch.cuda.is_available():
            model = DistributedDataParallel(
                model,
                device_ids=[rank],
                output_device=rank,
                **ddp_kwargs,
            )
        else:
            model = DistributedDataParallel(model, **ddp_kwargs)

    return model
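# Illustrative sketch (not part of this module): a minimal, self-contained
# demonstration of the pickling workaround used above. Assigning a locally
# defined wrapper to ``model.forward`` makes the module unpicklable, and
# attaching a bound ``__getstate__`` that restores the original forward fixes
# that. ``ToyModel`` and the helpers below are hypothetical stand-ins, not the
# functions defined inside this method.

import functools
import pickle
import types

import torch.nn as nn


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)

    def forward(self, x):
        return self.linear(x)


def wrap_forward(forward):
    @functools.wraps(forward)
    def wrapper(*args, **kwargs):
        return forward(*args, **kwargs)

    return wrapper


def model_get_state(self):
    # Drop the wrapped forward and this method itself before pickling.
    state = self.__dict__.copy()
    state["forward"] = state["_unwrapped_forward"]
    del state["_unwrapped_forward"]
    del state["__getstate__"]
    return state


model = ToyModel()
model._unwrapped_forward = model.forward
model.forward = wrap_forward(model.forward)  # pickle.dumps(model) would fail here
model.__getstate__ = types.MethodType(model_get_state, model)
pickle.dumps(model)  # succeeds: the wrapped forward is swapped out while pickling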
def prepare_model(
    self,
    model: torch.nn.Module,
    move_to_device: bool = True,
    wrap_ddp: bool = True,
    ddp_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.nn.Module:
    """Prepares the model for distributed execution.

    This allows you to use the same exact code regardless of number of
    workers or the device type being used (CPU, GPU).

    Args:
        model (torch.nn.Module): A torch model to prepare.
        move_to_device: Whether to move the model to the correct device.
            If set to False, the model needs to be moved to the correct
            device manually.
        wrap_ddp: Whether to wrap models in ``DistributedDataParallel``.
        ddp_kwargs (Dict[str, Any]): Args to pass into
            ``DistributedDataParallel`` initialization if ``wrap_ddp`` is
            set to True.
    """
    ddp_kwargs = ddp_kwargs or {}

    rank = train.local_rank()

    device = self.get_device()

    if torch.cuda.is_available():
        # Make the worker's assigned GPU the default CUDA device.
        torch.cuda.set_device(device)

    if move_to_device:
        logger.info(f"Moving model to device: {device}")
        model = model.to(device)

    def model_get_state(self):
        # `__getstate__` is a special method that informs pickle which
        # attributes to serialize. This custom implementation ensures that
        # the wrapped forward method and custom `__getstate__` method
        # aren't serialized.
        if hasattr(self, "_original_get_state"):
            state = self._original_get_state()
            state["__getstate__"] = state["_original_get_state"]
            del state["_original_get_state"]
        else:
            # If the model does not already define `__getstate__`, use the
            # default implementation.
            state = self.__dict__.copy()
            del state["__getstate__"]
        state["forward"] = state["_unwrapped_forward"]
        del state["_unwrapped_forward"]
        return state

    if self.amp_is_enabled:
        # Pickle cannot serialize the wrapped forward method. As a
        # workaround, define a custom `__getstate__` method that unwraps
        # the forward method.
        model._unwrapped_forward = model.forward
        model.forward = autocast()(model.forward)

        # TODO(amogkam): Replace below logic with a generic "unpack model"
        # method.
        # Replacing the `model.forward` method makes the model no longer
        # serializable. When serializing the model, we have to override the
        # `__getstate__` method to set back the original forward method.
        if hasattr(model, "__getstate__"):
            model._original_get_state = model.__getstate__
        # `__getstate__` must be a bound method rather than a callable
        # attribute.
        # See https://stackoverflow.com/questions/972/adding-a-method-to-an-existing-object-instance.  # noqa: E501
        model.__getstate__ = types.MethodType(model_get_state, model)

    if wrap_ddp and train.world_size() > 1:
        logger.info("Wrapping provided model in DDP.")
        if torch.cuda.is_available():
            model = DistributedDataParallel(
                model,
                device_ids=[rank],
                output_device=rank,
                **ddp_kwargs,
            )
        else:
            model = DistributedDataParallel(model, **ddp_kwargs)

    return model
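# Minimal usage sketch (illustrative, not part of this module). It assumes the
# method above backs the public ``ray.train.torch.prepare_model`` helper and
# that ``get_device`` from the same module reports the worker's device;
# ``train_func`` and the toy linear model are hypothetical. The function would
# be launched on each worker by a Ray Train trainer.

import torch
import torch.nn as nn

from ray.train.torch import get_device, prepare_model


def train_func():
    device = get_device()

    # Moves the model to the worker's device and, when running with more
    # than one worker, wraps it in DistributedDataParallel.
    model = prepare_model(nn.Linear(4, 1))
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    for _ in range(10):
        inputs = torch.randn(8, 4).to(device)
        loss = model(inputs).sum()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()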