def data_parallel(self):
    """Wrap the agent's model(s) for distributed training.

    rlpyt drives each GPU (or CPU group, MPI-like) from its own Python
    process; this wraps ``self.model`` — and ``self.model_int`` when a
    dual model is configured — in the matching DistributedDataParallel
    flavor.  Agents with further trainable components should extend
    this.  Typically called by the runner during startup.
    """
    if self.device.type == "cpu":
        wrap = DDPC  # CPU flavor takes no device arguments.
        note = "Initialized DistributedDataParallelCPU agent model."
    else:
        def wrap(module):
            # Pin replication and output to this process's assigned GPU.
            return DDP(module, device_ids=[self.device.index],
                output_device=self.device.index)
        note = ("Initialized DistributedDataParallel agent model on "
            f"device {self.device}.")
    self.model = wrap(self.model)
    if self.dual_model:
        self.model_int = wrap(self.model_int)
    logger.log(note)
def data_parallel(self):
    """Wrap ``self.model`` with DistributedDataParallel (CPU or GPU
    flavor, chosen from ``self.device``).  Overwrite/extend for agents
    whose gradient-bearing networks live under other attributes.
    """
    on_cpu = self.device.type == "cpu"
    if on_cpu:
        self.model = DDPC(self.model)
        logger.log("Initialized DistributedDataParallelCPU agent model.")
    else:
        gpu = self.device.index  # This process's assigned GPU.
        self.model = DDP(self.model, device_ids=[gpu], output_device=gpu)
        logger.log("Initialized DistributedDataParallel agent model on "
            f"device {self.device}.")
def data_parallel(self):
    """Wrap the intrinsic bonus model with DistributedDataParallel.

    Delegates to ``super()`` for the base model(s) first, then wraps
    ``self.bonus_model`` in the flavor matching ``self.device``.  rlpyt
    runs one Python process per GPU (or CPU group); typically called by
    the runner during startup.
    """
    super().data_parallel()  # Base model(s) handled by the parent class.
    if self.device.type != "cpu":
        gpu = self.device.index
        self.bonus_model = DDP(self.bonus_model, device_ids=[gpu],
            output_device=gpu)
        logger.log(
            f"Initialized DistributedDataParallel intrinsic bonus model on device {self.device}."
        )
    else:
        self.bonus_model = DDPC(self.bonus_model)
        logger.log(
            "Initialized DistributedDataParallelCPU intrinsic bonus model."
        )
def data_parallel(self):
    """Wrap ``self.q2_model`` with DistributedDataParallel, after the
    parent class wraps the base model(s).

    Fix: pass ``device_ids``/``output_device`` on the GPU branch, as the
    other agents in this file do — without them DDP may replicate the
    module across every visible device instead of only this process's
    assigned GPU (rlpyt runs one process per GPU).
    """
    super().data_parallel()
    if self.device.type == "cpu":
        self.q2_model = DDPC(self.q2_model)
    else:
        self.q2_model = DDP(self.q2_model, device_ids=[self.device.index],
            output_device=self.device.index)
def data_parallel(self):
    """Wrap ``self.q_model`` with DistributedDataParallel, after the
    parent class wraps ``self.model``.

    Fix: pass ``device_ids``/``output_device`` on the GPU branch, as the
    other agents in this file do — without them DDP may replicate the
    module across every visible device instead of only this process's
    assigned GPU (rlpyt runs one process per GPU).
    """
    super().data_parallel()  # Takes care of self.model.
    if self.device.type == "cpu":
        self.q_model = DDPC(self.q_model)
    else:
        self.q_model = DDP(self.q_model, device_ids=[self.device.index],
            output_device=self.device.index)