    def _configure_zero_optimizer(self, optimizer):
        logging.info('Creating fp16 zero optimizer')
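        # FP16_DeepSpeedZeroOptimizer wraps the base optimizer with ZeRO-style
        # partitioning of optimizer state across the data-parallel group and
        # handles fp16 loss scaling and gradient clipping inside the wrapper.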
        optimizer = FP16_DeepSpeedZeroOptimizer(
            optimizer,
            static_loss_scale=self.loss_scale(),
            dynamic_loss_scale=self.dynamic_loss_scale(),
            dynamic_loss_args=self.dynamic_loss_scale_args(),
            dp_process_group=self.data_parallel_group,
            clip_grad=self.gradient_clipping(),
            all_gather_partitions=not self.disable_allgather(),
            allgather_size=self.allgather_size(),
            mpu=self.mpu)

        return optimizer
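
The self.loss_scale(), self.dynamic_loss_scale(), self.gradient_clipping() and related calls are config accessors defined elsewhere on the engine. As a rough sketch (the _config attribute names are assumptions for illustration), they are thin getters over the parsed DeepSpeed config:

    # Hypothetical sketch of the accessors used above -- the real engine reads
    # these values from its parsed DeepSpeed config object.
    def loss_scale(self):
        # Static fp16 loss scale (0 conventionally means dynamic scaling).
        return self._config.loss_scale

    def gradient_clipping(self):
        # Max global gradient norm, forwarded to the ZeRO optimizer as clip_grad.
        return self._config.gradient_clipping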
Example #2
    def _configure_zero_optimizer(self, optimizer):
        zero_stage = self.zero_optimization_stage()
        logger.info('Creating fp16 ZeRO stage {} optimizer'.format(zero_stage))

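        # Stage constants: ZERO_OPTIMIZATION_OPTIMIZER_STATES (stage 1) partitions
        # only optimizer state across data-parallel ranks, while
        # ZERO_OPTIMIZATION_GRADIENTS (stage 2) additionally partitions gradients.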
        if zero_stage == ZERO_OPTIMIZATION_OPTIMIZER_STATES:
            assert self.zero_reduce_scatter(), \
                'Stage 1 only supports reduce scatter mode'
            optimizer = FP16_DeepSpeedZeroOptimizer_Stage1(
                optimizer,
                static_loss_scale=self.loss_scale(),
                dynamic_loss_scale=self.dynamic_loss_scale(),
                dynamic_loss_args=self.dynamic_loss_scale_args(),
                clip_grad=self.gradient_clipping(),
                all_gather_partitions=self.zero_allgather_partitions(),
                allgather_size=self.zero_allgather_bucket_size(),
                max_elements_per_comm=self.zero_reduce_bucket_size(),
                dp_process_group=self.data_parallel_group,
                mpu=self.mpu)
        elif zero_stage == ZERO_OPTIMIZATION_GRADIENTS:
            assert self.gradient_accumulation_steps() == 1, \
                "ZeRO stage 2 does not support gradient accumulation; if you need gradient accumulation, please use stage 1"
            optimizer = FP16_DeepSpeedZeroOptimizer(
                optimizer,
                timers=self.timers,
                static_loss_scale=self.loss_scale(),
                dynamic_loss_scale=self.dynamic_loss_scale(),
                dynamic_loss_args=self.dynamic_loss_scale_args(),
                clip_grad=self.gradient_clipping(),
                contiguous_gradients=self.zero_contiguous_gradients(),
                reduce_bucket_size=self.zero_reduce_bucket_size(),
                allgather_bucket_size=self.zero_allgather_bucket_size(),
                dp_process_group=self.data_parallel_group,
                reduce_scatter=self.zero_reduce_scatter(),
                overlap_comm=self.zero_overlap_comm(),
                mpu=self.mpu,
                postscale_gradients=self.postscale_gradients(),
                gradient_predivide_factor=self.gradient_predivide_factor())
        else:
            raise NotImplementedError(
                "ZeRO stage {} not implemented".format(zero_stage))

        return optimizer
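
For orientation, below is a sketch of the user-facing side that exercises the stage-2 branch above: a ZeRO stage 2 section in the DeepSpeed config plus the usual initialize/backward/step loop. The zero_optimization keys mirror the zero_* accessors used in the method; the model, batch size, and bucket sizes are placeholder assumptions, and older DeepSpeed releases pass the dict as config_params= instead of config=.

import torch
import deepspeed

# Illustrative config; the zero_optimization keys correspond to the accessors
# used above (zero_contiguous_gradients, zero_reduce_bucket_size, ...).
ds_config = {
    "train_batch_size": 4,
    "gradient_clipping": 1.0,                    # -> clip_grad
    "fp16": {"enabled": True, "loss_scale": 0},  # 0 -> dynamic loss scaling
    "zero_optimization": {
        "stage": 2,                              # selects the stage-2 branch above
        "contiguous_gradients": True,
        "reduce_scatter": True,
        "reduce_bucket_size": 5e8,
        "allgather_bucket_size": 5e8,
        "overlap_comm": True,
    },
}

model = torch.nn.Linear(1024, 1024)              # placeholder model
engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config)                            # config_params= on older releases

for _ in range(10):
    x = torch.randn(4, 1024, device=engine.device, dtype=torch.half)
    loss = engine(x).float().mean()
    engine.backward(loss)                        # loss scaling handled by the engine
    engine.step()                                # partitioned optimizer update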