def __init__(self, trainer):
    self.trainer = trainer
    self.plugins = []
    self.ddp_plugin = DDPPlugin()
    self.cloud_environment = None
    self.amp_plugin = NativeAMPPlugin(trainer)
    self.apex_plugin = ApexPlugin(trainer)
def _setup_amp_backend(self, amp_type: str, plugins: Optional[list]):
    if self.trainer.precision != 16:
        # no AMP requested, so we can leave now
        return

    using_sharded_plugin = self._check_using_sharded_plugin(plugins)

    amp_type = amp_type.lower()
    assert amp_type in ('native', 'apex'), f'Unsupported amp type {amp_type}'
    if amp_type == 'native':
        if not NATIVE_AMP_AVAILABLE:
            rank_zero_warn(
                'You have asked for native AMP but your PyTorch version does not support it.'
                ' Consider upgrading with `pip install torch>=1.6`.'
                ' We will attempt to use NVIDIA Apex for this session.')
            amp_type = 'apex'
        else:
            self.trainer.amp_backend = AMPType.NATIVE
            if using_sharded_plugin:
                log.info('Using sharded 16bit precision.')
                self.backend = ShardedNativeAMPPlugin(self.trainer)
            else:
                log.info('Using native 16bit precision.')
                self.backend = NativeAMPPlugin(self.trainer)

    if amp_type == 'apex':
        if not APEX_AVAILABLE:
            rank_zero_warn(
                'You have asked for Apex AMP but you have not installed it yet.'
                ' Install apex first using this guide: https://github.com/NVIDIA/apex#linux'
            )
        elif using_sharded_plugin:
            raise MisconfigurationException(
                'Sharded Plugin is not supported with Apex AMP, please use native AMP for 16-bit precision.'
            )
        else:
            log.info('Using APEX 16bit precision.')
            self.trainer.amp_backend = AMPType.APEX
            self.backend = ApexPlugin(self.trainer)

    if not self.trainer.amp_backend:
        raise ModuleNotFoundError(
            f'You have asked for AMP support {amp_type}, but there is no support on your side yet.'
            f' Consider installing torch >= 1.6 or NVIDIA Apex.')
class PrecisionConnector:

    def __init__(self, trainer):
        self.trainer = trainer
        self.backend = None

    def on_trainer_init(self, precision: int, amp_level: str, amp_backend: str):
        # AMP init
        # These are the only lines needed after v0.8.0
        # we wrap the user's forward with autocast and give it back at the end of fit
        self.trainer.autocast_original_forward = None
        self.trainer.precision = precision
        self.trainer.scaler = None

        self.trainer.amp_level = amp_level
        self.init_amp(amp_backend)

    def init_amp(self, amp_type: str):
        assert self.trainer.precision in (16, 32), 'only 32 or 16 bit precision supported'
        self.trainer.amp_backend = None
        self._setup_amp_backend(amp_type)

    def _setup_amp_backend(self, amp_type: str):
        if self.trainer.precision != 16:
            # no AMP requested, so we can leave now
            return

        amp_type = amp_type.lower()
        assert amp_type in ('native', 'apex'), f'Unsupported amp type {amp_type}'
        if amp_type == 'native':
            if not NATIVE_AMP_AVAILABLE:
                rank_zero_warn('You have asked for native AMP but your PyTorch version does not support it.'
                               ' Consider upgrading with `pip install torch>=1.6`.'
                               ' We will attempt to use NVIDIA Apex for this session.')
                amp_type = 'apex'
            else:
                self.trainer.amp_backend = AMPType.NATIVE
                log.info('Using native 16bit precision.')
                self.backend = NativeAMPPlugin(self.trainer)

        if amp_type == 'apex':
            if not APEX_AVAILABLE:
                rank_zero_warn('You have asked for Apex AMP but you have not installed it yet.'
                               ' Install apex first using this guide: https://github.com/NVIDIA/apex#linux')
            else:
                log.info('Using APEX 16bit precision.')
                self.trainer.amp_backend = AMPType.APEX
                self.backend = ApexPlugin(self.trainer)
                log.warn("LightningOptimizer doesn't support Apex")

        if not self.trainer.amp_backend:
            raise ModuleNotFoundError(
                f'You have asked for AMP support {amp_type}, but there is no support on your side yet.'
                f' Consider installing torch >= 1.6 or NVIDIA Apex.'
            )

    def connect(self, model):
        if self.backend:
            model, optimizers = self.backend.connect(model, self.trainer.optimizers)
            self.trainer.optimizers = optimizers

        return model
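# A minimal usage sketch (assumption, not part of the source): drive PrecisionConnector
# with a stand-in trainer object. SimpleNamespace only mimics the attributes the real
# Lightning Trainer would carry; with precision=32 the connector returns early and no
# AMP backend or plugin is selected.
from types import SimpleNamespace

_fake_trainer = SimpleNamespace(optimizers=[])
_connector = PrecisionConnector(_fake_trainer)
_connector.on_trainer_init(precision=32, amp_level='O2', amp_backend='native')
assert _fake_trainer.precision == 32 and _fake_trainer.amp_backend is None
assert _connector.backend is None  # no plugin chosen for full precision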
class DDPBase(Accelerator):

    def __init__(self, trainer):
        super().__init__(trainer)
        self.precision_backend = None

    def training_step(self, args):
        if self.trainer.amp_backend == AMPType.NATIVE:
            with torch.cuda.amp.autocast():
                output = self.trainer.model(*args)
        else:
            output = self.trainer.model(*args)
        return output

    def validation_step(self, args):
        output = self.training_step(args)
        return output

    def test_step(self, args):
        output = self.training_step(args)
        return output

    def barrier(self, name: str = None):
        torch_distrib.barrier()

    def early_stopping_should_stop(self, pl_module):
        stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device)
        dist.all_reduce(stop, op=dist.reduce_op.SUM)
        dist.barrier()
        should_stop = stop == self.trainer.world_size
        return should_stop

    def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
        if self.trainer.distributed_backend.lower() not in ['ddp_spawn', 'ddp_cpu', 'tpu']:
            return

        # track the best model path
        best_model_path = None
        if self.trainer.checkpoint_callback is not None:
            best_model_path = self.trainer.checkpoint_callback.best_model_path

        if self.trainer.global_rank == 0 and mp_queue is not None:
            rank_zero_warn('cleaning up ddp environment...')
            # todo, pass complete checkpoint as state dictionary
            mp_queue.put(best_model_path)
            mp_queue.put(results)

            # save the last weights
            last_path = None
            if not self.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
                last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path)
                atomic_save(model.state_dict(), last_path)
            mp_queue.put(last_path)

    def ddp_train_tmp(self, process_idx, mp_queue, model, is_master=False, proc_offset=0):
        """
        Entry point for ddp

        Args:
            process_idx: index of the current process
            mp_queue: multiprocessing queue
            model: the LightningModule to train

        Returns:
            training/testing results on global rank 0, otherwise None
        """
        # offset the process id if requested
        process_idx = process_idx + proc_offset

        # show progressbar only on progress_rank 0
        if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None:
            self.trainer.progress_bar_callback.disable()

        # determine which process we are and world size
        self.set_world_ranks(process_idx)

        # set warning rank
        rank_zero_only.rank = self.trainer.global_rank

        # set up server using proc 0's ip address
        # try to init for 20 times at max in case ports are taken
        # where to store ip_table
        model.trainer = self.trainer
        model.init_ddp_connection(
            self.trainer.global_rank,
            self.trainer.world_size,
            self.trainer.is_slurm_managing_tasks
        )

        # call setup after the ddp process has connected
        self.trainer.call_setup_hook(model)

        # on world_size=0 let everyone know training is starting
        if self.trainer.is_global_zero:
            log.info('-' * 100)
            log.info(f'distributed_backend={self.trainer.distributed_backend}')
            log.info(f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes')
            log.info('-' * 100)

        # call sync_bn before .cuda(), configure_apex and configure_ddp
        if self.trainer.sync_batchnorm:
            model = model.configure_sync_batchnorm(model)

        # move the model to the correct device
        self.model_to_device(model, process_idx, is_master)

        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well
        optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model)
        self.trainer.optimizers = optimizers
        self.trainer.lr_schedulers = lr_schedulers
        self.trainer.optimizer_frequencies = optimizer_frequencies

        # set model properties before going into wrapper
        self.trainer.model_connector.copy_trainer_model_properties(model)

        # AMP - run through amp wrapper before going to distributed DP
        if self.trainer.amp_backend == AMPType.APEX:
            self.precision_backend = ApexPlugin(self.trainer)
            model, optimizers = self.precision_backend._init(model)

        # device ids change depending on the DDP setup
        device_ids = self.get_device_ids()

        # allow user to configure ddp
        model = model.configure_ddp(model, device_ids)

        # set up training routine
        self.trainer.train_loop.setup_training(model)

        # train or test
        results = self.train_or_test()

        # get original model
        model = self.trainer.get_model()

        # persist info in ddp_spawn
        self.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)

        # clean up memory
        torch.cuda.empty_cache()

        if self.trainer.global_rank == 0:
            return results

    def set_world_ranks(self, process_idx):
        raise NotImplementedError('to create a ddp backend, please implement set_world_ranks')

    def model_to_device(self, model, process_idx, is_master):
        raise NotImplementedError('to create a ddp backend, please implement model_to_device')

    def get_device_ids(self):
        raise NotImplementedError('to create a ddp backend, please implement get_device_ids')
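# Sketch (assumption, not part of the source) of the all-reduce consensus that
# early_stopping_should_stop implements: every rank contributes 0/1 and training stops
# only when all world_size ranks agree. Requires an initialized process group; the
# function name is illustrative. dist.ReduceOp.SUM is the modern spelling of the
# deprecated dist.reduce_op.SUM used above.
import torch
import torch.distributed as dist

def all_ranks_agree_to_stop(local_should_stop: bool, world_size: int, device) -> bool:
    stop = torch.tensor(int(local_should_stop), device=device)
    dist.all_reduce(stop, op=dist.ReduceOp.SUM)  # sum of per-rank stop flags
    dist.barrier()
    return int(stop.item()) == world_size       # unanimous decision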
class DataParallelBackend(Accelerator):

    def __init__(self, trainer):
        super().__init__(trainer)
        self.model_autocast_original_forward = None
        self.precision_backend = None

    def setup(self, model):
        # call setup after the ddp process has connected
        self.trainer.call_setup_hook(model)

        # put model on correct device
        model.cuda(self.trainer.root_gpu)

        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well
        optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model)
        self.trainer.optimizers = optimizers
        self.trainer.lr_schedulers = lr_schedulers
        self.trainer.optimizer_frequencies = optimizer_frequencies

        # init torch data parallel
        model = self.__init_torch_data_parallel(model)

        # hack forward to do autocast for the user
        self.model_autocast_original_forward = model.forward

        # init half precision
        if self.trainer.amp_backend:
            model = self.__init_half_precision(model)

        self.trainer.model = model

    def __init_torch_data_parallel(self, model):
        # create list of device ids
        device_ids = self.trainer.data_parallel_device_ids
        if isinstance(device_ids, int):
            device_ids = list(range(device_ids))

        # set dp device
        torch.cuda.set_device(self.trainer.root_gpu)
        model = LightningDataParallel(model, device_ids=device_ids)
        return model

    def __init_half_precision(self, model):
        if self.trainer.amp_backend == AMPType.NATIVE:
            self.__init_native_amp(model)
        else:
            model = self.__init_nvidia_apex(model)
        return model

    def __init_native_amp(self, model):
        model.forward = torch.cuda.amp.autocast()(model.forward)

    def __init_nvidia_apex(self, model):
        # check for this bug (amp + dp + !O1 doesn't work)
        # https://github.com/NVIDIA/apex/issues/227
        if self.trainer.amp_level == 'O2':
            raise MisconfigurationException(
                f'Amp level {self.trainer.amp_level} with DataParallel is not supported.'
                f' See this note from NVIDIA for more info: https://github.com/NVIDIA/apex/issues/227.'
                f' We recommend you switch to ddp if you want to use amp'
            )
        else:
            self.precision_backend = ApexPlugin(self.trainer)
            model, optimizers = self.precision_backend._init(model)

        return model

    def train(self):
        model = self.trainer.model
        # set up training routine
        self.trainer.train_loop.setup_training(model)

        # train or test
        results = self.train_or_test()
        return results

    def teardown(self):
        # replace the original fwd function
        self.trainer.model.forward = self.model_autocast_original_forward

    def training_step(self, args):
        if self.trainer.amp_backend == AMPType.NATIVE:
            with torch.cuda.amp.autocast():
                output = self.trainer.model(*args)
        else:
            output = self.trainer.model(*args)
        return output

    def validation_step(self, args):
        output = self.training_step(args)
        return output

    def test_step(self, args):
        output = self.training_step(args)
        return output

    def training_step_end(self, output):
        if isinstance(output, Result):
            output.dp_reduce()
        return output

    def validation_step_end(self, output):
        if isinstance(output, Result):
            output.dp_reduce()
        return output

    def test_step_end(self, output):
        if isinstance(output, Result):
            output.dp_reduce()
        return output

    def reinit_scheduler_properties(self, optimizers: list, schedulers: list):
        """
        Reinitialize optimizer.step properties added by schedulers
        """
        for scheduler in schedulers:
            scheduler = scheduler['scheduler']

            for optimizer in optimizers:
                # check that we dont mix users optimizers and schedulers
                if scheduler.optimizer == optimizer:
                    # Find the mro belonging to the base lr scheduler class
                    # (compare the mro entry itself; comparing the class objects is always truthy)
                    idx = None
                    state = None
                    for i, mro in enumerate(scheduler.__class__.__mro__):
                        is_regular_scheduler = mro is optim.lr_scheduler._LRScheduler
                        is_lr_reduce_on_plateau = mro is optim.lr_scheduler.ReduceLROnPlateau
                        if is_regular_scheduler or is_lr_reduce_on_plateau:
                            idx = i
                            state = scheduler.state_dict()
                            break

                    if idx is not None:
                        scheduler.__class__.__mro__[idx].__init__(scheduler, optimizer)
                        if state is not None:
                            scheduler.load_state_dict(state)
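# Sketch (assumption, not part of the source) of the forward-patching trick used by
# __init_native_amp and undone in teardown(): wrap model.forward with autocast for the
# duration of fit, then restore the original bound method afterwards. _TinyNet is a
# hypothetical stand-in module used only for illustration.
import torch

class _TinyNet(torch.nn.Module):
    def forward(self, x):
        return x * 2

_net = _TinyNet()
_original_forward = _net.forward                          # keep a handle, as the backend does
_net.forward = torch.cuda.amp.autocast()(_net.forward)    # patched: forward now runs under autocast
# ... fit(...) would run here ...
_net.forward = _original_forward                          # teardown(): give the plain forward back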
class GPUBackend(Accelerator):
    amp_backend: AMPType

    def __init__(self, trainer):
        super().__init__(trainer)
        self.precision_backend = None

    def setup(self, model):
        # call setup
        self.trainer.call_setup_hook(model)

        torch.cuda.set_device(self.trainer.root_gpu)
        model.cuda(self.trainer.root_gpu)

        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well
        optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model)
        self.trainer.optimizers = optimizers
        self.trainer.lr_schedulers = lr_schedulers
        self.trainer.optimizer_frequencies = optimizer_frequencies

        if self.trainer.amp_backend == AMPType.APEX:
            self.precision_backend = ApexPlugin(self.trainer)
            model, optimizers = self.precision_backend._init(model)

        self.trainer.model = model

    def train(self):
        model = self.trainer.model

        # set up training routine
        self.trainer.train_loop.setup_training(model)

        # train or test
        results = self.train_or_test()
        return results

    def training_step(self, args):
        if self.trainer.amp_backend == AMPType.NATIVE:
            with torch.cuda.amp.autocast():
                output = self.__training_step(args)
        else:
            output = self.__training_step(args)
        return output

    def __training_step(self, args):
        batch = args[0]
        batch = self.to_device(batch)
        args[0] = batch
        output = self.trainer.model.training_step(*args)
        return output

    def validation_step(self, args):
        if self.trainer.amp_backend == AMPType.NATIVE:
            with torch.cuda.amp.autocast():
                output = self.__validation_step(args)
        else:
            output = self.__validation_step(args)
        return output

    def __validation_step(self, args):
        batch = args[0]
        batch = self.to_device(batch)
        args[0] = batch
        output = self.trainer.model.validation_step(*args)
        return output

    def test_step(self, args):
        if self.trainer.amp_backend == AMPType.NATIVE:
            with torch.cuda.amp.autocast():
                output = self.__test_step(args)
        else:
            output = self.__test_step(args)
        return output

    def __test_step(self, args):
        batch = args[0]
        batch = self.to_device(batch)
        args[0] = batch
        output = self.trainer.model.test_step(*args)
        return output

    def to_device(self, batch):
        gpu_id = 0
        if isinstance(self.trainer.data_parallel_device_ids, list):
            gpu_id = self.trainer.data_parallel_device_ids[0]

        # Don't copy the batch since there is a single gpu that the batch could
        # be referenced from and if there are multiple optimizers the batch will
        # wind up copying it to the same device repeatedly.
        return self.batch_to_device(batch, gpu_id)
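# Sketch (assumption, not from the source) of the per-step pattern GPUBackend repeats
# for training/validation/test: place the batch on the root GPU, then run the
# LightningModule hook, optionally under native-AMP autocast. run_step_on_root_gpu and
# its parameters are illustrative names, not part of the library.
import torch

def run_step_on_root_gpu(step_fn, batch, root_gpu=0, use_native_amp=False):
    if torch.cuda.is_available():
        batch = batch.to(torch.device('cuda', root_gpu))  # assumes a tensor batch for simplicity
    if use_native_amp:
        with torch.cuda.amp.autocast():
            return step_fn(batch)
    return step_fn(batch)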
class HorovodBackend(Accelerator):
    amp_backend: AMPType

    def __init__(self, trainer):
        super().__init__(trainer)
        self.precision_backend = None

    def setup(self, model):
        # call setup after the ddp process has connected
        self.trainer.call_setup_hook(model)

        if torch.cuda.is_available() and self.trainer.on_gpu:
            # Horovod: pin GPU to local rank
            assert self.trainer.root_gpu == hvd.local_rank()
            torch.cuda.set_device(self.trainer.root_gpu)
            model.cuda(self.trainer.root_gpu)

        # avoid duplicating progress bar
        if hvd.rank() != 0 and self.trainer.progress_bar_callback is not None:
            self.trainer.progress_bar_callback.disable()

        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well
        optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model)
        self.trainer.optimizers = optimizers
        self.trainer.lr_schedulers = lr_schedulers
        self.trainer.optimizer_frequencies = optimizer_frequencies

        # Horovod: scale the learning rate by the number of workers to account for
        # increased total batch size
        for optimizer in self.trainer.optimizers:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= hvd.size()

        # Horovod: adjust base LR used by schedulers to match scaled optimizer initial LR
        for scheduler in self.trainer.lr_schedulers:
            scheduler = scheduler['scheduler']
            if isinstance(scheduler, _LRScheduler):
                scheduler.base_lrs = [lr * hvd.size() for lr in scheduler.base_lrs]

        # Horovod: broadcast parameters & optimizer state to ensure consistent initialization
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)
        for optimizer in self.trainer.optimizers:
            hvd.broadcast_optimizer_state(optimizer, root_rank=0)

        def filter_named_parameters(model, optimizer):
            opt_params = set([p for group in optimizer.param_groups for p in group.get('params', [])])
            return [(name, p) for name, p in model.named_parameters() if p in opt_params]

        # Horovod: wrap optimizers to perform gradient aggregation via allreduce
        self.trainer.optimizers = [
            hvd.DistributedOptimizer(optimizer, named_parameters=filter_named_parameters(model, optimizer))
            for optimizer in self.trainer.optimizers
        ]

        if self.trainer.amp_backend == AMPType.APEX:
            self.precision_backend = ApexPlugin(self.trainer)
            model, optimizers = self.precision_backend._init(model)

        # Update logger rank info from Horovod to avoid race conditions from different ranks
        # creating directories / writing files in the same locations.
        self.trainer.global_rank = hvd.rank()
        rank_zero_only.rank = self.trainer.global_rank

        self.trainer.model = model

    def train(self):
        with ExitStack() as stack:
            for optimizer in self.trainer.optimizers:
                # Synchronization will be performed explicitly following backward()
                stack.enter_context(optimizer.skip_synchronize())

            # set up training routine
            self.trainer.train_loop.setup_training(self.trainer.model)

            # train or test
            results = self.train_or_test()

        # Make sure all workers have finished training before returning to the user
        hvd.join()
        return results

    def teardown(self):
        pass

    def training_step(self, args):
        if self.trainer.on_gpu:
            batch = args[0]
            batch = self.batch_to_device(batch, hvd.local_rank())
            args[0] = batch

        if self.trainer.amp_backend == AMPType.NATIVE:
            with torch.cuda.amp.autocast():
                output = self.trainer.model.training_step(*args)
        else:
            output = self.trainer.model.training_step(*args)

        return output

    def validation_step(self, args):
        if self.trainer.on_gpu:
            batch = args[0]
            batch = self.batch_to_device(batch, hvd.local_rank())
            args[0] = batch

        if self.trainer.amp_backend == AMPType.NATIVE:
            with torch.cuda.amp.autocast():
                output = self.trainer.model.validation_step(*args)
        else:
            output = self.trainer.model.validation_step(*args)

        return output

    def test_step(self, args):
        if self.trainer.on_gpu:
            batch = args[0]
            batch = self.batch_to_device(batch, hvd.local_rank())
            args[0] = batch

        if self.trainer.amp_backend == AMPType.NATIVE:
            with torch.cuda.amp.autocast():
                output = self.trainer.model.test_step(*args)
        else:
            output = self.trainer.model.test_step(*args)

        return output

    def backward(self, closure_loss, optimizer, opt_idx):
        super().backward(closure_loss, optimizer, opt_idx)
        optimizer.synchronize()

    def on_train_epoch_end(self):
        hvd.join(hvd.local_rank() if self.trainer.on_gpu else -1)

    def barrier(self, name: str = None):
        hvd.join()
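# Sketch (assumption, not part of the source) of the linear LR-scaling rule that
# HorovodBackend.setup() applies: each worker sees 1/N of the global batch, so every
# param group's learning rate is multiplied by the worker count to keep the effective
# update size comparable. scale_lr_for_workers is an illustrative helper name.
import torch

def scale_lr_for_workers(optimizers, num_workers):
    for optimizer in optimizers:
        for param_group in optimizer.param_groups:
            param_group['lr'] *= num_workers

_params = [torch.nn.Parameter(torch.zeros(1))]
_opt = torch.optim.SGD(_params, lr=0.1)
scale_lr_for_workers([_opt], num_workers=4)   # lr: 0.1 -> 0.4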