Code example #1
    def _setup_amp_backend(self, amp_type: str):
        if self.trainer.precision != 16:
            # no AMP requested, so we can leave now
            return

        amp_type = amp_type.lower()
        assert amp_type in ('native', 'apex'), f'Unsupported amp type {amp_type}'
        if amp_type == 'native':
            if not NATIVE_AMP_AVAILABLE:
                rank_zero_warn('You have asked for native AMP but your PyTorch version does not support it.'
                               ' Consider upgrading with `pip install torch>=1.6`.'
                               ' We will attempt to use NVIDIA Apex for this session.')
                amp_type = 'apex'
            else:
                self.trainer.amp_backend = AMPType.NATIVE
                log.info('Using native 16bit precision.')
                self.backend = NativeAMPPlugin(self.trainer)

        if amp_type == 'apex':
            if not APEX_AVAILABLE:
                rank_zero_warn('You have asked for Apex AMP but you have not installed it yet.'
                               ' Install apex first using this guide: https://github.com/NVIDIA/apex#linux')
            else:
                log.info('Using APEX 16bit precision.')
                self.trainer.amp_backend = AMPType.APEX
                self.backend = ApexPlugin(self.trainer)
                log.warn("LightningOptimizer doesn't support Apex")

        if not self.trainer.amp_backend:
            raise ModuleNotFoundError(
                f'You have asked for AMP support {amp_type}, but there is no support on your side yet.'
                f' Consider installing torch >= 1.6 or NVIDIA Apex.'
            )
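For context, NATIVE_AMP_AVAILABLE and APEX_AVAILABLE are module-level flags evaluated at import time. A minimal sketch of how such flags can be derived (an assumption for illustration, not the library's verbatim definitions):

import importlib.util
import torch

# native AMP (torch.cuda.amp.autocast) shipped with PyTorch 1.6
NATIVE_AMP_AVAILABLE = hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast')

# NVIDIA Apex is an optional, separately installed package
APEX_AVAILABLE = importlib.util.find_spec('apex') is not None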
Code example #2
    def __init_nvidia_apex(self, model):
        # check for this Apex bug: amp + DP with opt levels other than O1 doesn't work
        # https://github.com/NVIDIA/apex/issues/227
        if self.trainer.amp_level == 'O2':
            raise MisconfigurationException(
                f'Amp level {self.trainer.amp_level} with DataParallel is not supported.'
                f' See this note from NVIDIA for more info: https://github.com/NVIDIA/apex/issues/227.'
                f' We recommend you switch to ddp if you want to use amp')
        else:
            self.precision_backend = ApexPlugin(self.trainer)
            model, optimizers = self.precision_backend._init(model)

        return model
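For context, ApexPlugin._init ultimately reduces to Apex's standard initialization call. A minimal sketch of that call (assuming the usual amp.initialize API, not the plugin's verbatim code):

from apex import amp  # NVIDIA Apex, installed separately

# O1 patches torch functions to cast inputs per-op; O2 casts the whole model
# to FP16, which is exactly the level the DataParallel check above rejects.
model, optimizers = amp.initialize(model, optimizers, opt_level='O1')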
Code example #3
 def __init__(self, trainer):
     self.trainer = trainer
     self.plugins = []
     self.ddp_plugin = DDPPlugin()
     self.cloud_environment = None
     self.amp_plugin = NativeAMPPlugin(trainer)
     self.apex_plugin = ApexPlugin(trainer)
Code example #4
    def _setup_amp_backend(self, amp_type: str, plugins: Optional[list]):
        if self.trainer.precision != 16:
            # no AMP requested, so we can leave now
            return

        using_sharded_plugin = self._check_using_sharded_plugin(plugins)
        amp_type = amp_type.lower()
        assert amp_type in ('native', 'apex'), f'Unsupported amp type {amp_type}'
        if amp_type == 'native':
            if not NATIVE_AMP_AVAILABLE:
                rank_zero_warn(
                    'You have asked for native AMP but your PyTorch version does not support it.'
                    ' Consider upgrading with `pip install torch>=1.6`.'
                    ' We will attempt to use NVIDIA Apex for this session.')
                amp_type = 'apex'
            else:
                self.trainer.amp_backend = AMPType.NATIVE
                if using_sharded_plugin:
                    log.info('Using sharded 16bit precision.')
                    self.backend = ShardedNativeAMPPlugin(self.trainer)
                else:
                    log.info('Using native 16bit precision.')
                    self.backend = NativeAMPPlugin(self.trainer)

        if amp_type == 'apex':
            if not APEX_AVAILABLE:
                rank_zero_warn(
                    'You have asked for Apex AMP but you have not installed it yet.'
                    ' Install apex first using this guide: https://github.com/NVIDIA/apex#linux'
                )
            elif using_sharded_plugin:
                raise MisconfigurationException(
                    'Sharded Plugin is not supported with Apex AMP, please use native AMP for 16-bit precision.'
                )
            else:
                log.info('Using APEX 16bit precision.')
                self.trainer.amp_backend = AMPType.APEX
                self.backend = ApexPlugin(self.trainer)

        if not self.trainer.amp_backend:
            raise ModuleNotFoundError(
                f'You have asked for AMP support {amp_type}, but there is no support on your side yet.'
                f' Consider installing torch >= 1.6 or NVIDIA Apex.')
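_check_using_sharded_plugin is referenced above but not shown. A plausible minimal implementation, inferred from how it is used (the DDPShardedPlugin class name is an assumption for illustration):

    def _check_using_sharded_plugin(self, plugins: Optional[list]) -> bool:
        # True if the user passed the sharded-training plugin in the plugins list
        if plugins is None:
            return False
        return any(isinstance(plugin, DDPShardedPlugin) for plugin in plugins)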
Code example #5
    def setup(self, model):

        # call setup
        self.trainer.call_setup_hook(model)

        torch.cuda.set_device(self.trainer.root_gpu)
        model.cuda(self.trainer.root_gpu)

        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well
        optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(
            model)
        self.trainer.optimizers = optimizers
        self.trainer.lr_schedulers = lr_schedulers
        self.trainer.optimizer_frequencies = optimizer_frequencies

        if self.trainer.amp_backend == AMPType.APEX:
            self.precision_backend = ApexPlugin(self.trainer)
            model, optimizers = self.precision_backend._init(model)

        self.trainer.model = model
Code example #6
class PrecisionConnector:

    def __init__(self, trainer):
        self.trainer = trainer
        self.backend = None

    def on_trainer_init(self, precision: int, amp_level: str, amp_backend: str):
        # AMP init
        # These are the only lines needed after v0.8.0
        # we wrap the user's forward with autocast and give it back at the end of fit
        self.trainer.autocast_original_forward = None
        self.trainer.precision = precision
        self.trainer.scaler = None

        self.trainer.amp_level = amp_level
        self.init_amp(amp_backend)

    def init_amp(self, amp_type: str):
        assert self.trainer.precision in (16, 32), 'only 32 or 16 bit precision supported'
        self.trainer.amp_backend = None
        self._setup_amp_backend(amp_type)

    def _setup_amp_backend(self, amp_type: str):
        if self.trainer.precision != 16:
            # no AMP requested, so we can leave now
            return

        amp_type = amp_type.lower()
        assert amp_type in ('native', 'apex'), f'Unsupported amp type {amp_type}'
        if amp_type == 'native':
            if not NATIVE_AMP_AVAILABLE:
                rank_zero_warn('You have asked for native AMP but your PyTorch version does not support it.'
                               ' Consider upgrading with `pip install torch>=1.6`.'
                               ' We will attempt to use NVIDIA Apex for this session.')
                amp_type = 'apex'
            else:
                self.trainer.amp_backend = AMPType.NATIVE
                log.info('Using native 16bit precision.')
                self.backend = NativeAMPPlugin(self.trainer)

        if amp_type == 'apex':
            if not APEX_AVAILABLE:
                rank_zero_warn('You have asked for Apex AMP but you have not installed it yet.'
                               ' Install apex first using this guide: https://github.com/NVIDIA/apex#linux')
            else:
                log.info('Using APEX 16bit precision.')
                self.trainer.amp_backend = AMPType.APEX
                self.backend = ApexPlugin(self.trainer)
                log.warn("LightningOptimizer doesn't support Apex")

        if not self.trainer.amp_backend:
            raise ModuleNotFoundError(
                f'You have asked for AMP support {amp_type}, but there is no support on your side yet.'
                f' Consider installing torch >= 1.6 or NVIDIA Apex.'
            )

    def connect(self, model):
        if self.backend:
            model, optimizers = self.backend.connect(model, self.trainer.optimizers)
            self.trainer.optimizers = optimizers

        return model
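A hypothetical driving sequence for PrecisionConnector, assuming a trainer object with the attributes the connector sets (illustrative only, not the library's public API):

connector = PrecisionConnector(trainer)

# at Trainer construction time: record the precision settings and pick a backend
connector.on_trainer_init(precision=16, amp_level='O1', amp_backend='native')

# later, once trainer.optimizers exist, let the backend wrap model and optimizers
model = connector.connect(model)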
Code example #7
    def ddp_train_tmp(self,
                      process_idx,
                      mp_queue,
                      model,
                      is_master=False,
                      proc_offset=0):
        """
        Entry point for ddp

        Args:
            process_idx:
            mp_queue: multiprocessing queue
            model:

        Returns:

        """
        # offset the process id if requested
        process_idx = process_idx + proc_offset

        # show progressbar only on progress_rank 0
        if (self.trainer.node_rank != 0 or process_idx != 0
            ) and self.trainer.progress_bar_callback is not None:
            self.trainer.progress_bar_callback.disable()

        # determine which process we are and world size
        self.set_world_ranks(process_idx)

        # set warning rank
        rank_zero_only.rank = self.trainer.global_rank

        # set up server using proc 0's ip address
        # try to init for 20 times at max in case ports are taken
        # where to store ip_table
        model.trainer = self.trainer
        model.init_ddp_connection(self.trainer.global_rank,
                                  self.trainer.world_size,
                                  self.trainer.is_slurm_managing_tasks)

        # call setup after the ddp process has connected
        self.trainer.call_setup_hook(model)

        # on global rank 0 let everyone know training is starting
        if self.trainer.is_global_zero:
            log.info('-' * 100)
            log.info(f'distributed_backend={self.trainer.distributed_backend}')
            log.info(
                f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes'
            )
            log.info('-' * 100)

        # call sync_bn before .cuda(), configure_apex and configure_ddp
        if self.trainer.sync_batchnorm:
            model = model.configure_sync_batchnorm(model)

        # move the model to the correct device
        self.model_to_device(model, process_idx, is_master)

        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well
        optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(
            model)
        self.trainer.optimizers = optimizers
        self.trainer.lr_schedulers = lr_schedulers
        self.trainer.optimizer_frequencies = optimizer_frequencies

        # set model properties before going into wrapper
        self.trainer.model_connector.copy_trainer_model_properties(model)

        # AMP: run the model through the amp wrapper before wrapping it in DDP
        if self.trainer.amp_backend == AMPType.APEX:
            self.precision_backend = ApexPlugin(self.trainer)
            model, optimizers = self.precision_backend._init(model)

        # device ids change depending on the DDP setup
        device_ids = self.get_device_ids()

        # allow user to configure ddp
        model = model.configure_ddp(model, device_ids)

        # set up training routine
        self.trainer.train_loop.setup_training(model)

        # train or test
        results = self.train_or_test()

        # get original model
        model = self.trainer.get_model()

        # persist info in ddp_spawn
        self.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)

        # clean up memory
        torch.cuda.empty_cache()

        if self.trainer.global_rank == 0:
            return results
Code example #8
class DDPBase(Accelerator):
    def __init__(self, trainer):
        super().__init__(trainer)
        self.precision_backend = None

    def training_step(self, args):
        if self.trainer.amp_backend == AMPType.NATIVE:
            with torch.cuda.amp.autocast():
                output = self.trainer.model(*args)
        else:
            output = self.trainer.model(*args)
        return output

    def validation_step(self, args):
        output = self.training_step(args)
        return output

    def test_step(self, args):
        output = self.training_step(args)
        return output

    def barrier(self, name: str = None):
        torch_distrib.barrier()

    def early_stopping_should_stop(self, pl_module):
        stop = torch.tensor(int(self.trainer.should_stop),
                            device=pl_module.device)
        torch_distrib.all_reduce(stop, op=torch_distrib.ReduceOp.SUM)
        torch_distrib.barrier()
        should_stop = stop == self.trainer.world_size
        return should_stop

    def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue,
                                                results):
        if self.trainer.distributed_backend.lower() not in [
                'ddp_spawn', 'ddp_cpu', 'tpu'
        ]:
            return

        # track the best model path
        best_model_path = None
        if self.trainer.checkpoint_callback is not None:
            best_model_path = self.trainer.checkpoint_callback.best_model_path

        if self.trainer.global_rank == 0 and mp_queue is not None:
            rank_zero_warn('cleaning up ddp environment...')
            # todo, pass complete checkpoint as state dictionary
            mp_queue.put(best_model_path)
            mp_queue.put(results)

            # save the last weights
            last_path = None
            if not self.trainer.testing and best_model_path is not None and len(
                    best_model_path) > 0:
                last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path)
                atomic_save(model.state_dict(), last_path)
            mp_queue.put(last_path)

    def ddp_train_tmp(self,
                      process_idx,
                      mp_queue,
                      model,
                      is_master=False,
                      proc_offset=0):
        """
        Entry point for ddp

        Args:
            process_idx:
            mp_queue: multiprocessing queue
            model:

        Returns:

        """
        # offset the process id if requested
        process_idx = process_idx + proc_offset

        # show progressbar only on progress_rank 0
        if (self.trainer.node_rank != 0 or process_idx != 0
            ) and self.trainer.progress_bar_callback is not None:
            self.trainer.progress_bar_callback.disable()

        # determine which process we are and world size
        self.set_world_ranks(process_idx)

        # set warning rank
        rank_zero_only.rank = self.trainer.global_rank

        # set up server using proc 0's ip address
        # try to init for 20 times at max in case ports are taken
        # where to store ip_table
        model.trainer = self.trainer
        model.init_ddp_connection(self.trainer.global_rank,
                                  self.trainer.world_size,
                                  self.trainer.is_slurm_managing_tasks)

        # call setup after the ddp process has connected
        self.trainer.call_setup_hook(model)

        # on global rank 0 let everyone know training is starting
        if self.trainer.is_global_zero:
            log.info('-' * 100)
            log.info(f'distributed_backend={self.trainer.distributed_backend}')
            log.info(
                f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes'
            )
            log.info('-' * 100)

        # call sync_bn before .cuda(), configure_apex and configure_ddp
        if self.trainer.sync_batchnorm:
            model = model.configure_sync_batchnorm(model)

        # move the model to the correct device
        self.model_to_device(model, process_idx, is_master)

        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well
        optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(
            model)
        self.trainer.optimizers = optimizers
        self.trainer.lr_schedulers = lr_schedulers
        self.trainer.optimizer_frequencies = optimizer_frequencies

        # set model properties before going into wrapper
        self.trainer.model_connector.copy_trainer_model_properties(model)

        # AMP: run the model through the amp wrapper before wrapping it in DDP
        if self.trainer.amp_backend == AMPType.APEX:
            self.precision_backend = ApexPlugin(self.trainer)
            model, optimizers = self.precision_backend._init(model)

        # device ids change depending on the DDP setup
        device_ids = self.get_device_ids()

        # allow user to configure ddp
        model = model.configure_ddp(model, device_ids)

        # set up training routine
        self.trainer.train_loop.setup_training(model)

        # train or test
        results = self.train_or_test()

        # get original model
        model = self.trainer.get_model()

        # persist info in ddp_spawn
        self.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)

        # clean up memory
        torch.cuda.empty_cache()

        if self.trainer.global_rank == 0:
            return results

    def set_world_ranks(self, process_idx):
        raise NotImplementedError(
            'to create a ddp backend, please implement set_world_ranks')

    def model_to_device(self, model, process_idx, is_master):
        raise NotImplementedError(
            'to create a ddp backend, please implement model_to_device')

    def get_device_ids(self):
        raise NotImplementedError(
            'to create a ddp backend, please implement get_device_ids')
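DDPBase leaves three hooks abstract. A minimal single-node subclass sketch showing how they might be filled in (the class name and the device logic are illustrative assumptions, not one of the library's own subclasses):

import torch

class SimpleDDPBackend(DDPBase):

    def set_world_ranks(self, process_idx):
        # single node: the process index doubles as local and global rank
        self.trainer.local_rank = process_idx
        self.trainer.global_rank = process_idx
        self.trainer.world_size = self.trainer.num_processes

    def model_to_device(self, model, process_idx, is_master):
        # pin this process to its own GPU and move the model there
        self.trainer.root_gpu = process_idx
        torch.cuda.set_device(self.trainer.root_gpu)
        model.cuda(self.trainer.root_gpu)

    def get_device_ids(self):
        return [self.trainer.root_gpu]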
Code example #9
class DataParallelBackend(Accelerator):

    def __init__(self, trainer):
        super().__init__(trainer)
        self.model_autocast_original_forward = None
        self.precision_backend = None

    def setup(self, model):
        # call setup after the ddp process has connected
        self.trainer.call_setup_hook(model)

        # put model on correct device
        model.cuda(self.trainer.root_gpu)

        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well
        optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model)
        self.trainer.optimizers = optimizers
        self.trainer.lr_schedulers = lr_schedulers
        self.trainer.optimizer_frequencies = optimizer_frequencies

        # init torch data parallel
        model = self.__init_torch_data_parallel(model)

        # hack forward to do autocast for the user
        self.model_autocast_original_forward = model.forward

        # init half precision
        if self.trainer.amp_backend:
            model = self.__init_half_precision(model)

        self.trainer.model = model

    def __init_torch_data_parallel(self, model):
        # create list of device ids
        device_ids = self.trainer.data_parallel_device_ids
        if isinstance(device_ids, int):
            device_ids = list(range(device_ids))

        # set dp device
        torch.cuda.set_device(self.trainer.root_gpu)
        model = LightningDataParallel(model, device_ids=device_ids)
        return model

    def __init_half_precision(self, model):
        if self.trainer.amp_backend == AMPType.NATIVE:
            self.__init_native_amp(model)
        else:
            model = self.__init_nvidia_apex(model)
        return model

    def __init_native_amp(self, model):
        model.forward = torch.cuda.amp.autocast()(model.forward)

    def __init_nvidia_apex(self, model):
        # check for this Apex bug: amp + DP with opt levels other than O1 doesn't work
        # https://github.com/NVIDIA/apex/issues/227
        if self.trainer.amp_level == 'O2':
            raise MisconfigurationException(
                f'Amp level {self.trainer.amp_level} with DataParallel is not supported.'
                f' See this note from NVIDIA for more info: https://github.com/NVIDIA/apex/issues/227.'
                f' We recommend you switch to ddp if you want to use amp')
        else:
            self.precision_backend = ApexPlugin(self.trainer)
            model, optimizers = self.precision_backend._init(model)

        return model

    def train(self):
        model = self.trainer.model
        # set up training routine
        self.trainer.train_loop.setup_training(model)

        # train or test
        results = self.train_or_test()

        return results

    def teardown(self):
        # replace the original fwd function
        self.trainer.model.forward = self.model_autocast_original_forward

    def training_step(self, args):
        if self.trainer.amp_backend == AMPType.NATIVE:
            with torch.cuda.amp.autocast():
                output = self.trainer.model(*args)
        else:
            output = self.trainer.model(*args)
        return output

    def validation_step(self, args):
        output = self.training_step(args)
        return output

    def test_step(self, args):
        output = self.training_step(args)
        return output

    def training_step_end(self, output):
        if isinstance(output, Result):
            output.dp_reduce()
        return output

    def validation_step_end(self, output):
        if isinstance(output, Result):
            output.dp_reduce()
        return output

    def test_step_end(self, output):
        if isinstance(output, Result):
            output.dp_reduce()
        return output

    def reinit_scheduler_properties(self, optimizers: list, schedulers: list):
        """
        Reinitialize optimizer.step properties added by schedulers
        """
        for scheduler in schedulers:
            scheduler = scheduler['scheduler']

            for optimizer in optimizers:
                # check that we don't mix users' optimizers and schedulers
                if scheduler.optimizer == optimizer:
                    # Find the mro entry belonging to the base lr scheduler class
                    for i, mro in enumerate(scheduler.__class__.__mro__):
                        if mro in (optim.lr_scheduler._LRScheduler,
                                   optim.lr_scheduler.ReduceLROnPlateau):
                            idx = i
                            state = scheduler.state_dict()
                        else:
                            state = None

                    # re-run the base scheduler's __init__ against the optimizer,
                    # then restore the scheduler state captured above
                    scheduler.__class__.__mro__[idx].__init__(scheduler, optimizer)
                    if state is not None:
                        scheduler.load_state_dict(state)
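For context on the forward-patching trick in __init_native_amp above: torch.cuda.amp.autocast works both as a context manager and as a decorator, which is why rebinding model.forward is sufficient. A standalone illustration (requires a CUDA device):

import torch

@torch.cuda.amp.autocast()
def fp16_matmul(a, b):
    # matmul is on autocast's FP16 op list, so the result is half precision
    return a @ b

a = torch.randn(4, 4, device='cuda')
b = torch.randn(4, 4, device='cuda')
assert fp16_matmul(a, b).dtype == torch.float16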
Code example #10
class GPUBackend(Accelerator):
    amp_backend: AMPType

    def __init__(self, trainer):
        super().__init__(trainer)
        self.precision_backend = None

    def setup(self, model):

        # call setup
        self.trainer.call_setup_hook(model)

        torch.cuda.set_device(self.trainer.root_gpu)
        model.cuda(self.trainer.root_gpu)

        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well
        optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(
            model)
        self.trainer.optimizers = optimizers
        self.trainer.lr_schedulers = lr_schedulers
        self.trainer.optimizer_frequencies = optimizer_frequencies

        if self.trainer.amp_backend == AMPType.APEX:
            self.precision_backend = ApexPlugin(self.trainer)
            model, optimizers = self.precision_backend._init(model)

        self.trainer.model = model

    def train(self):
        model = self.trainer.model

        # set up training routine
        self.trainer.train_loop.setup_training(model)

        # train or test
        results = self.train_or_test()

        return results

    def training_step(self, args):
        if self.trainer.amp_backend == AMPType.NATIVE:
            with torch.cuda.amp.autocast():
                output = self.__training_step(args)
        else:
            output = self.__training_step(args)

        return output

    def __training_step(self, args):
        batch = args[0]
        batch = self.to_device(batch)
        args[0] = batch
        output = self.trainer.model.training_step(*args)
        return output

    def validation_step(self, args):
        if self.trainer.amp_backend == AMPType.NATIVE:
            with torch.cuda.amp.autocast():
                output = self.__validation_step(args)
        else:
            output = self.__validation_step(args)

        return output

    def __validation_step(self, args):
        batch = args[0]
        batch = self.to_device(batch)
        args[0] = batch
        output = self.trainer.model.validation_step(*args)
        return output

    def test_step(self, args):
        if self.trainer.amp_backend == AMPType.NATIVE:
            with torch.cuda.amp.autocast():
                output = self.__test_step(args)
        else:
            output = self.__test_step(args)

        return output

    def __test_step(self, args):
        batch = args[0]
        batch = self.to_device(batch)
        args[0] = batch
        output = self.trainer.model.test_step(*args)
        return output

    def to_device(self, batch):
        gpu_id = 0
        if isinstance(self.trainer.data_parallel_device_ids, list):
            gpu_id = self.trainer.data_parallel_device_ids[0]

        # Don't copy the batch: with a single GPU it can be referenced directly,
        # and with multiple optimizers it would otherwise be copied to the same
        # device repeatedly.
        return self.batch_to_device(batch, gpu_id)
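batch_to_device is inherited from the Accelerator base class and not shown here. A plausible minimal version built on the public move_data_to_device utility (a sketch; the real method may also defer to the LightningModule's transfer_batch_to_device hook):

import torch
from pytorch_lightning.utilities import move_data_to_device

def batch_to_device(batch, gpu_id):
    # recursively move tensors, including ones nested in lists/dicts/tuples
    device = torch.device('cuda', gpu_id)
    return move_data_to_device(batch, device)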
Code example #11
    def setup(self, model):
        # call setup after the ddp process has connected
        self.trainer.call_setup_hook(model)

        if torch.cuda.is_available() and self.trainer.on_gpu:
            # Horovod: pin GPU to local rank
            assert self.trainer.root_gpu == hvd.local_rank()
            torch.cuda.set_device(self.trainer.root_gpu)
            model.cuda(self.trainer.root_gpu)

        # avoid duplicating progress bar
        if hvd.rank() != 0 and self.trainer.progress_bar_callback is not None:
            self.trainer.progress_bar_callback.disable()

        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well
        optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(
            model)
        self.trainer.optimizers = optimizers
        self.trainer.lr_schedulers = lr_schedulers
        self.trainer.optimizer_frequencies = optimizer_frequencies

        # Horovod: scale the learning rate by the number of workers to account for
        # increased total batch size
        for optimizer in self.trainer.optimizers:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= hvd.size()

        # Horovod: adjust base LR used by schedulers to match scaled optimizer initial LR
        for scheduler in self.trainer.lr_schedulers:
            scheduler = scheduler['scheduler']
            if isinstance(scheduler, _LRScheduler):
                scheduler.base_lrs = [
                    lr * hvd.size() for lr in scheduler.base_lrs
                ]

        # Horovod: broadcast parameters & optimizer state to ensure consistent initialization
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)
        for optimizer in self.trainer.optimizers:
            hvd.broadcast_optimizer_state(optimizer, root_rank=0)

        def filter_named_parameters(model, optimizer):
            opt_params = set([
                p for group in optimizer.param_groups
                for p in group.get('params', [])
            ])
            return [(name, p) for name, p in model.named_parameters()
                    if p in opt_params]

        # Horovod: wrap optimizers to perform gradient aggregation via allreduce
        self.trainer.optimizers = [
            hvd.DistributedOptimizer(optimizer,
                                     named_parameters=filter_named_parameters(
                                         model, optimizer))
            for optimizer in self.trainer.optimizers
        ]

        if self.trainer.amp_backend == AMPType.APEX:
            self.precision_backend = ApexPlugin(self.trainer)
            model, optimizers = self.precision_backend._init(model)

        # Update logger rank info from Horovod to avoid race conditions from different ranks
        # creating directories / writing files in the same locations.
        self.trainer.global_rank = hvd.rank()
        rank_zero_only.rank = self.trainer.global_rank

        self.trainer.model = model
Code example #12
class HorovodBackend(Accelerator):
    amp_backend: AMPType

    def __init__(self, trainer):
        super().__init__(trainer)
        self.precision_backend = None

    def setup(self, model):
        # call setup after the ddp process has connected
        self.trainer.call_setup_hook(model)

        if torch.cuda.is_available() and self.trainer.on_gpu:
            # Horovod: pin GPU to local rank
            assert self.trainer.root_gpu == hvd.local_rank()
            torch.cuda.set_device(self.trainer.root_gpu)
            model.cuda(self.trainer.root_gpu)

        # avoid duplicating progress bar
        if hvd.rank() != 0 and self.trainer.progress_bar_callback is not None:
            self.trainer.progress_bar_callback.disable()

        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well
        optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(
            model)
        self.trainer.optimizers = optimizers
        self.trainer.lr_schedulers = lr_schedulers
        self.trainer.optimizer_frequencies = optimizer_frequencies

        # Horovod: scale the learning rate by the number of workers to account for
        # increased total batch size
        for optimizer in self.trainer.optimizers:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= hvd.size()

        # Horovod: adjust base LR used by schedulers to match scaled optimizer initial LR
        for scheduler in self.trainer.lr_schedulers:
            scheduler = scheduler['scheduler']
            if isinstance(scheduler, _LRScheduler):
                scheduler.base_lrs = [
                    lr * hvd.size() for lr in scheduler.base_lrs
                ]

        # Horovod: broadcast parameters & optimizer state to ensure consistent initialization
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)
        for optimizer in self.trainer.optimizers:
            hvd.broadcast_optimizer_state(optimizer, root_rank=0)

        def filter_named_parameters(model, optimizer):
            opt_params = set([
                p for group in optimizer.param_groups
                for p in group.get('params', [])
            ])
            return [(name, p) for name, p in model.named_parameters()
                    if p in opt_params]

        # Horovod: wrap optimizers to perform gradient aggregation via allreduce
        self.trainer.optimizers = [
            hvd.DistributedOptimizer(optimizer,
                                     named_parameters=filter_named_parameters(
                                         model, optimizer))
            for optimizer in self.trainer.optimizers
        ]

        if self.trainer.amp_backend == AMPType.APEX:
            self.precision_backend = ApexPlugin(self.trainer)
            model, optimizers = self.precision_backend._init(model)

        # Update logger rank info from Horovod to avoid race conditions from different ranks
        # creating directories / writing files in the same locations.
        self.trainer.global_rank = hvd.rank()
        rank_zero_only.rank = self.trainer.global_rank

        self.trainer.model = model

    def train(self):
        with ExitStack() as stack:
            for optimizer in self.trainer.optimizers:
                # Synchronization will be performed explicitly following backward()
                stack.enter_context(optimizer.skip_synchronize())

            # set up training routine
            self.trainer.train_loop.setup_training(self.trainer.model)

            # train or test
            results = self.train_or_test()

        # Make sure all workers have finished training before returning to the user
        hvd.join()
        return results

    def teardown(self):
        pass

    def training_step(self, args):
        if self.trainer.on_gpu:
            batch = args[0]
            batch = self.batch_to_device(batch, hvd.local_rank())
            args[0] = batch

        if self.trainer.amp_backend == AMPType.NATIVE:
            with torch.cuda.amp.autocast():
                output = self.trainer.model.training_step(*args)
        else:
            output = self.trainer.model.training_step(*args)

        return output

    def validation_step(self, args):
        if self.trainer.on_gpu:
            batch = args[0]
            batch = self.batch_to_device(batch, hvd.local_rank())
            args[0] = batch

        if self.trainer.amp_backend == AMPType.NATIVE:
            with torch.cuda.amp.autocast():
                output = self.trainer.model.validation_step(*args)
        else:
            output = self.trainer.model.validation_step(*args)

        return output

    def test_step(self, args):
        if self.trainer.on_gpu:
            batch = args[0]
            batch = self.batch_to_device(batch, hvd.local_rank())
            args[0] = batch

        if self.trainer.amp_backend == AMPType.NATIVE:
            with torch.cuda.amp.autocast():
                output = self.trainer.model.test_step(*args)
        else:
            output = self.trainer.model.test_step(*args)
        return output

    def backward(self, closure_loss, optimizer, opt_idx):
        super().backward(closure_loss, optimizer, opt_idx)
        optimizer.synchronize()

    def on_train_epoch_end(self):
        hvd.join(hvd.local_rank() if self.trainer.on_gpu else -1)

    def barrier(self, name: str = None):
        hvd.join()
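Taken together, the skip_synchronize() contexts entered in train() and the explicit optimizer.synchronize() call in backward() implement Horovod's documented pattern for decoupling the gradient allreduce from the optimizer step. A sketch of the canonical per-step form:

loss = loss_fn(model(data), target)
loss.backward()
optimizer.synchronize()            # explicit allreduce of the gradients
with optimizer.skip_synchronize():
    optimizer.step()               # step without a second, implicit sync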