def get_sparse_bslongformer_config(sparsity):
    """Build the BSLongformer sparse-attention config from the user ``sparsity`` dict.

    Every field is resolved with get_scalar_param and falls back to its
    module-level default when the key is absent.
    """
    return {
        SPARSE_MODE: SPARSE_BSLONGFORMER_MODE,
        SPARSE_BLOCK: get_scalar_param(sparsity, SPARSE_BLOCK, SPARSE_BLOCK_DEFAULT),
        SPARSE_DIFFERENT_LAYOUT_PER_HEAD: get_scalar_param(
            sparsity,
            SPARSE_DIFFERENT_LAYOUT_PER_HEAD,
            SPARSE_DIFFERENT_LAYOUT_PER_HEAD_DEFAULT),
        SPARSE_NUM_SLIDING_WINDOW_BLOCKS: get_scalar_param(
            sparsity,
            SPARSE_NUM_SLIDING_WINDOW_BLOCKS,
            SPARSE_NUM_SLIDING_WINDOW_BLOCKS_DEFAULT),
        SPARSE_GLOBAL_BLOCK_INDICES: get_scalar_param(
            sparsity,
            SPARSE_GLOBAL_BLOCK_INDICES,
            SPARSE_GLOBAL_BLOCK_INDICES_DEFAULT),
        SPARSE_GLOBAL_BLOCK_END_INDICES: get_scalar_param(
            sparsity,
            SPARSE_GLOBAL_BLOCK_END_INDICES,
            SPARSE_GLOBAL_BLOCK_END_INDICES_DEFAULT),
    }
def get_dynamic_loss_scale_args(param_dict):
    """Return the dynamic-loss-scale arguments from the fp16 section, or None.

    A dict is produced only when fp16 is enabled and at least one of the
    dynamic-loss-scale keys appears in the fp16 config section.
    """
    if not get_fp16_enabled(param_dict):
        return None
    fp16_dict = param_dict[FP16]
    dynamic_keys = (FP16_INITIAL_SCALE_POWER,
                    FP16_LOSS_SCALE_WINDOW,
                    FP16_MIN_LOSS_SCALE,
                    FP16_HYSTERESIS)
    if not any(key in fp16_dict for key in dynamic_keys):
        return None
    init_scale_power = get_scalar_param(fp16_dict,
                                        FP16_INITIAL_SCALE_POWER,
                                        FP16_INITIAL_SCALE_POWER_DEFAULT)
    return {
        # Stored as an exponent in the config; expanded to the actual scale here.
        INITIAL_LOSS_SCALE: 2**init_scale_power,
        SCALE_WINDOW: get_scalar_param(fp16_dict,
                                       FP16_LOSS_SCALE_WINDOW,
                                       FP16_LOSS_SCALE_WINDOW_DEFAULT),
        DELAYED_SHIFT: get_scalar_param(fp16_dict,
                                        FP16_HYSTERESIS,
                                        FP16_HYSTERESIS_DEFAULT),
        MIN_LOSS_SCALE: get_scalar_param(fp16_dict,
                                         FP16_MIN_LOSS_SCALE,
                                         FP16_MIN_LOSS_SCALE_DEFAULT),
    }
def get_sparse_bigbird_config(sparsity):
    """Build the BigBird sparse-attention config from the user ``sparsity`` dict."""
    return {
        SPARSE_MODE: SPARSE_BIGBIRD_MODE,
        SPARSE_BLOCK: get_scalar_param(sparsity, SPARSE_BLOCK, SPARSE_BLOCK_DEFAULT),
        SPARSE_DIFFERENT_LAYOUT_PER_HEAD: get_scalar_param(
            sparsity,
            SPARSE_DIFFERENT_LAYOUT_PER_HEAD,
            SPARSE_DIFFERENT_LAYOUT_PER_HEAD_DEFAULT),
        SPARSE_NUM_RANDOM_BLOCKS: get_scalar_param(
            sparsity,
            SPARSE_NUM_RANDOM_BLOCKS,
            SPARSE_NUM_RANDOM_BLOCKS_DEFAULT),
        SPARSE_NUM_SLIDING_WINDOW_BLOCKS: get_scalar_param(
            sparsity,
            SPARSE_NUM_SLIDING_WINDOW_BLOCKS,
            SPARSE_NUM_SLIDING_WINDOW_BLOCKS_DEFAULT),
        SPARSE_NUM_GLOBAL_BLOCKS: get_scalar_param(
            sparsity,
            SPARSE_NUM_GLOBAL_BLOCKS,
            SPARSE_NUM_GLOBAL_BLOCKS_DEFAULT),
    }
def get_allgather_size(param_dict):
    """Return the configured allgather size, or the default when non-positive.

    Fix: the original called get_scalar_param twice for the same key — once
    for the comparison and once for the returned value. Resolve it once and
    reuse the result; behavior is unchanged.
    """
    size = get_scalar_param(param_dict, ALLGATHER_SIZE, ALLGATHER_SIZE_DEFAULT)
    # Guard against zero/negative user values: fall back to the default.
    return size if size > 0 else ALLGATHER_SIZE_DEFAULT
def get_aio_config(param_dict):
    """Return the async-I/O config, each key resolved against its default.

    Falls back to AIO_DEFAULT_DICT when the aio section is absent or null.
    """
    aio_dict = param_dict.get(AIO)
    if aio_dict is None:
        return AIO_DEFAULT_DICT
    return {
        key: get_scalar_param(aio_dict, key, default)
        for key, default in ((AIO_BLOCK_SIZE, AIO_BLOCK_SIZE_DEFAULT),
                             (AIO_QUEUE_DEPTH, AIO_QUEUE_DEPTH_DEFAULT),
                             (AIO_THREAD_COUNT, AIO_THREAD_COUNT_DEFAULT),
                             (AIO_SINGLE_SUBMIT, AIO_SINGLE_SUBMIT_DEFAULT),
                             (AIO_OVERLAP_EVENTS, AIO_OVERLAP_EVENTS_DEFAULT))
    }
def get_sparse_variable_config(sparsity):
    """Build the variable sparse-attention config from the user ``sparsity`` dict."""
    return {
        SPARSE_MODE: SPARSE_VARIABLE_MODE,
        SPARSE_BLOCK: get_scalar_param(sparsity, SPARSE_BLOCK, SPARSE_BLOCK_DEFAULT),
        SPARSE_DIFFERENT_LAYOUT_PER_HEAD: get_scalar_param(
            sparsity,
            SPARSE_DIFFERENT_LAYOUT_PER_HEAD,
            SPARSE_DIFFERENT_LAYOUT_PER_HEAD_DEFAULT),
        SPARSE_NUM_RANDOM_BLOCKS: get_scalar_param(
            sparsity,
            SPARSE_NUM_RANDOM_BLOCKS,
            SPARSE_NUM_RANDOM_BLOCKS_DEFAULT),
        SPARSE_LOCAL_WINDOW_BLOCKS: get_scalar_param(
            sparsity,
            SPARSE_LOCAL_WINDOW_BLOCKS,
            SPARSE_LOCAL_WINDOW_BLOCKS_DEFAULT),
        SPARSE_GLOBAL_BLOCK_INDICES: get_scalar_param(
            sparsity,
            SPARSE_GLOBAL_BLOCK_INDICES,
            SPARSE_GLOBAL_BLOCK_INDICES_DEFAULT),
        SPARSE_GLOBAL_BLOCK_END_INDICES: get_scalar_param(
            sparsity,
            SPARSE_GLOBAL_BLOCK_END_INDICES,
            SPARSE_GLOBAL_BLOCK_END_INDICES_DEFAULT),
        SPARSE_ATTENTION_TYPE: get_scalar_param(
            sparsity,
            SPARSE_ATTENTION_TYPE,
            SPARSE_ATTENTION_TYPE_DEFAULT),
        SPARSE_HORIZONTAL_GLOBAL_ATTENTION: get_scalar_param(
            sparsity,
            SPARSE_HORIZONTAL_GLOBAL_ATTENTION,
            SPARSE_HORIZONTAL_GLOBAL_ATTENTION_DEFAULT),
    }
def get_tensorboard_output_path(param_dict):
    """Return the tensorboard output path; the default when tensorboard is disabled."""
    if not get_tensorboard_enabled(param_dict):
        return TENSORBOARD_OUTPUT_PATH_DEFAULT
    return get_scalar_param(param_dict[TENSORBOARD],
                            TENSORBOARD_OUTPUT_PATH,
                            TENSORBOARD_OUTPUT_PATH_DEFAULT)
def get_tensorboard_job_name(param_dict):
    """Return the tensorboard job name; the default when tensorboard is disabled."""
    if not get_tensorboard_enabled(param_dict):
        return TENSORBOARD_JOB_NAME_DEFAULT
    return get_scalar_param(param_dict[TENSORBOARD],
                            TENSORBOARD_JOB_NAME,
                            TENSORBOARD_JOB_NAME_DEFAULT)
def get_loss_scale(param_dict):
    """Return the fp16 loss scale; the default when fp16 is disabled."""
    if not get_fp16_enabled(param_dict):
        return FP16_LOSS_SCALE_DEFAULT
    return get_scalar_param(param_dict[FP16],
                            FP16_LOSS_SCALE,
                            FP16_LOSS_SCALE_DEFAULT)
def get_pld_enabled(param_dict):
    """True when progressive layer drop is configured and enabled."""
    if PROGRESSIVE_LAYER_DROP not in param_dict:
        return False
    return get_scalar_param(param_dict[PROGRESSIVE_LAYER_DROP],
                            PLD_ENABLED,
                            PLD_ENABLED_DEFAULT)
def get_tensorboard_enabled(param_dict):
    """True when a tensorboard section exists and its enabled flag resolves truthy."""
    if TENSORBOARD not in param_dict:
        return False
    return get_scalar_param(param_dict[TENSORBOARD],
                            TENSORBOARD_ENABLED,
                            TENSORBOARD_ENABLED_DEFAULT)
def get_sparse_fixed_config(sparsity):
    """Build the fixed sparse-attention config from the user ``sparsity`` dict."""
    return {
        SPARSE_MODE: SPARSE_FIXED_MODE,
        SPARSE_BLOCK: get_scalar_param(sparsity, SPARSE_BLOCK, SPARSE_BLOCK_DEFAULT),
        SPARSE_DIFFERENT_LAYOUT_PER_HEAD: get_scalar_param(
            sparsity,
            SPARSE_DIFFERENT_LAYOUT_PER_HEAD,
            SPARSE_DIFFERENT_LAYOUT_PER_HEAD_DEFAULT),
        SPARSE_NUM_LOCAL_BLOCKS: get_scalar_param(
            sparsity,
            SPARSE_NUM_LOCAL_BLOCKS,
            SPARSE_NUM_LOCAL_BLOCKS_DEFAULT),
        SPARSE_NUM_GLOBAL_BLOCKS: get_scalar_param(
            sparsity,
            SPARSE_NUM_GLOBAL_BLOCKS,
            SPARSE_NUM_GLOBAL_BLOCKS_DEFAULT),
        SPARSE_ATTENTION_TYPE: get_scalar_param(
            sparsity,
            SPARSE_ATTENTION_TYPE,
            SPARSE_ATTENTION_TYPE_DEFAULT),
        SPARSE_HORIZONTAL_GLOBAL_ATTENTION: get_scalar_param(
            sparsity,
            SPARSE_HORIZONTAL_GLOBAL_ATTENTION,
            SPARSE_HORIZONTAL_GLOBAL_ATTENTION_DEFAULT),
        SPARSE_NUM_DIFFERENT_GLOBAL_PATTERNS: get_scalar_param(
            sparsity,
            SPARSE_NUM_DIFFERENT_GLOBAL_PATTERNS,
            SPARSE_NUM_DIFFERENT_GLOBAL_PATTERNS_DEFAULT),
    }
def get_initial_dynamic_scale(param_dict):
    """Return 2**initial_scale_power; uses the default power when fp16 is disabled."""
    power = (get_scalar_param(param_dict[FP16],
                              FP16_INITIAL_SCALE_POWER,
                              FP16_INITIAL_SCALE_POWER_DEFAULT)
             if get_fp16_enabled(param_dict)
             else FP16_INITIAL_SCALE_POWER_DEFAULT)
    return 2**power
def get_model_info_config(param_dict):
    """Return the model_info section resolved against its key/default table.

    Returns None when the section is absent or null.
    """
    model_info = param_dict.get(MODEL_INFO)
    if model_info is None:
        return None
    return {
        key: get_scalar_param(model_info, key, default)
        for key, default in MODEL_INFO_KEY_DEFAULT_DICT.items()
    }
def read_zero_config_deprecated(param_dict):
    """Translate the deprecated boolean ZeRO format into a config dict.

    A truthy "zero_optimization" entry maps to stage 1; the legacy
    "allgather_size" key is carried over as "allgather_bucket_size".
    Always warns that the format is deprecated.
    """
    stage = 1 if param_dict[ZERO_OPTIMIZATION] else 0
    zero_config_dict = {"stage": stage}
    if stage > 0:
        zero_config_dict["allgather_bucket_size"] = get_scalar_param(
            param_dict,
            "allgather_size",
            5e8)
    logger.warning(
        "DeepSpeedConfig: this format of ZeRO optimization setup is deprecated. Please use the following format: {}"
        .format(ZERO_FORMAT))
    return zero_config_dict
def _initialize(self, flops_profiler_dict):
    """Populate flops-profiler settings, each resolved against its default."""
    for attr, key, default in (
        ("enabled", FLOPS_PROFILER_ENABLED, FLOPS_PROFILER_ENABLED_DEFAULT),
        ("profile_step", FLOPS_PROFILER_PROFILE_STEP,
         FLOPS_PROFILER_PROFILE_STEP_DEFAULT),
        ("module_depth", FLOPS_PROFILER_MODULE_DEPTH,
         FLOPS_PROFILER_MODULE_DEPTH_DEFAULT),
        ("top_modules", FLOPS_PROFILER_TOP_MODULES,
         FLOPS_PROFILER_TOP_MODULES_DEFAULT),
        ("detailed", FLOPS_PROFILER_DETAILED, FLOPS_PROFILER_DETAILED_DEFAULT),
    ):
        setattr(self, attr, get_scalar_param(flops_profiler_dict, key, default))
def _initialize(self, nebula_dict):
    """Populate Nebula checkpointing settings, each resolved against its default."""
    for attr, key, default in (
        ("enabled", NEBULA_ENABLED, NEBULA_ENABLED_DEFAULT),
        ("load_path", NEBULA_LOAD_PATH, NEBULA_LOAD_PATH_DEFAULT),
        ("enable_nebula_load", NEBULA_ENABLE_NEBULA_LOAD,
         NEBULA_ENABLE_NEBULA_LOAD_DEFAULT),
        ("persistent_storage_path", NEBULA_PERSISTENT_STORAGE_PATH,
         NEBULA_PERSISTENT_STORAGE_PATH_DEFAULT),
        ("persistent_time_interval", NEBULA_PERSISTENT_TIME_INTERVAL,
         NEBULA_PERSISTENT_TIME_INTERVAL_DEFAULT),
        ("num_of_version_in_retention", NEBULA_NUM_OF_VERSION_IN_RETENTION,
         NEBULA_NUM_OF_VERSION_IN_RETENTION_DEFAULT),
    ):
        setattr(self, attr, get_scalar_param(nebula_dict, key, default))
def _initialize(self, zero_config_dict):
    """Populate ZeRO optimization settings, each resolved against its default."""
    for attr, key, default in (
        ("stage", ZERO_OPTIMIZATION_STAGE, ZERO_OPTIMIZATION_STAGE_DEFAULT),
        ("contiguous_gradients", ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS,
         ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT),
        ("reduce_bucket_size", ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE,
         ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT),
        ("reduce_scatter", ZERO_OPTIMIZATION_REDUCE_SCATTER,
         ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT),
        ("overlap_comm", ZERO_OPTIMIZATION_OVERLAP_COMM,
         ZERO_OPTIMIZATION_OVERLAP_COMM_DEFAULT),
        ("allgather_partitions", ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS,
         ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS_DEFAULT),
        ("allgather_bucket_size", ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE,
         ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT),
        ("load_from_fp32_weights", ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS,
         ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT),
        ("cpu_offload", ZERO_OPTIMIZATION_CPU_OFFLOAD,
         ZERO_OPTIMIZATION_CPU_OFFLOAD_DEFAULT),
        ("elastic_checkpoint", ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT,
         ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT_DEFAULT),
    ):
        setattr(self, attr, get_scalar_param(zero_config_dict, key, default))
def _initialize(self, flops_profiler_dict):
    """Populate flops-profiler settings (start/end-step variant) from the config dict."""
    for attr, key, default in (
        ("enabled", FLOPS_PROFILER_ENABLED, FLOPS_PROFILER_ENABLED_DEFAULT),
        ("start_step", FLOPS_PROFILER_START_STEP, FLOPS_PROFILER_START_STEP_DEFAULT),
        ("end_step", FLOPS_PROFILER_END_STEP, FLOPS_PROFILER_END_STEP_DEFAULT),
        ("module_depth", FLOPS_PROFILER_MODULE_DEPTH,
         FLOPS_PROFILER_MODULE_DEPTH_DEFAULT),
        ("top_modules", FLOPS_PROFILER_TOP_MODULES,
         FLOPS_PROFILER_TOP_MODULES_DEFAULT),
    ):
        setattr(self, attr, get_scalar_param(flops_profiler_dict, key, default))
def read_zero_config_deprecated(self, param_dict):
    """Translate the deprecated boolean ZeRO format into a config dict.

    A truthy "zero_optimization" entry maps to stage 1; the deprecated
    allgather-size key is carried over under the current bucket-size key.
    Always warns that the format is deprecated.
    """
    stage = 1 if param_dict[ZERO_OPTIMIZATION] else 0
    zero_config_dict = {ZERO_OPTIMIZATION_STAGE: stage}
    if stage > 0:
        zero_config_dict[ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE] = get_scalar_param(
            param_dict,
            ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEPRECATED,
            ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT)
    logger.warning(
        'DeepSpeedConfig: this format of ZeRO optimization setup is deprecated. Please use the following format: {}'
        .format(ZERO_FORMAT))
    return zero_config_dict
def _initialize(self, act_chkpt_config_dict):
    """Populate activation-checkpointing settings, each resolved against its default."""
    for attr, key, default in (
        ("partition_activations", ACT_CHKPT_PARTITION_ACTIVATIONS,
         ACT_CHKPT_PARTITION_ACTIVATIONS_DEFAULT),
        ("contiguous_memory_optimization", ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION,
         ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION_DEFAULT),
        ("cpu_checkpointing", ACT_CHKPT_CPU_CHECKPOINTING,
         ACT_CHKPT_CPU_CHECKPOINTING_DEFAULT),
        ("number_checkpoints", ACT_CHKPT_NUMBER_CHECKPOINTS,
         ACT_CHKPT_NUMBER_CHECKPOINTS_DEFAULT),
        ("profile", ACT_CHKPT_PROFILE, ACT_CHKPT_PROFILE_DEFAULT),
        ("synchronize_checkpoint_boundary", ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY,
         ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY_DEFAULT),
    ):
        setattr(self, attr, get_scalar_param(act_chkpt_config_dict, key, default))
def _initialize(self, zero_config_dict):
    """Populate ZeRO optimization settings from ``zero_config_dict``.

    Every knob is resolved via get_scalar_param against its module-level
    default. ZeRO-3 (stage == ZERO_OPTIMIZATION_WEIGHTS) uses dedicated
    defaults for contiguous_gradients and overlap_comm. The deprecated
    boolean cpu_offload / cpu_offload_params flags, when present, take
    precedence over the structured offload_optimizer / offload_param
    sub-configs.

    Fix: ``round_robin_gradients`` was read with the unrelated
    ZERO3_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT as its fallback (a
    copy-paste slip); it now uses its own default constant, matching the
    naming convention of every other field.
    """
    self._sanity_check(zero_config_dict)
    self.stage = get_scalar_param(zero_config_dict,
                                  ZERO_OPTIMIZATION_STAGE,
                                  ZERO_OPTIMIZATION_STAGE_DEFAULT)
    self.contiguous_gradients = get_scalar_param(
        zero_config_dict,
        ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS,
        ZERO3_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT
        if self.stage == ZERO_OPTIMIZATION_WEIGHTS else
        ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT)
    self.reduce_bucket_size = get_scalar_param(
        zero_config_dict,
        ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE,
        ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT)
    self.reduce_scatter = get_scalar_param(
        zero_config_dict,
        ZERO_OPTIMIZATION_REDUCE_SCATTER,
        ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT)
    self.overlap_comm = get_scalar_param(
        zero_config_dict,
        ZERO_OPTIMIZATION_OVERLAP_COMM,
        ZERO3_OPTIMIZATION_OVERLAP_COMM_DEFAULT
        if self.stage == ZERO_OPTIMIZATION_WEIGHTS else
        ZERO_OPTIMIZATION_OVERLAP_COMM_DEFAULT)
    self.allgather_partitions = get_scalar_param(
        zero_config_dict,
        ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS,
        ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS_DEFAULT)
    self.allgather_bucket_size = get_scalar_param(
        zero_config_dict,
        ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE,
        ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT)
    self.load_from_fp32_weights = get_scalar_param(
        zero_config_dict,
        ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS,
        ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT)
    self.elastic_checkpoint = get_scalar_param(
        zero_config_dict,
        ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT,
        ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT_DEFAULT)
    # Deprecated boolean "cpu_offload" wins over the structured section;
    # only when it is absent do we read "offload_optimizer".
    if ZERO_OPTIMIZATION_CPU_OFFLOAD in zero_config_dict:
        cpu_offload_optimizer = get_scalar_param(
            zero_config_dict,
            ZERO_OPTIMIZATION_CPU_OFFLOAD,
            ZERO_OPTIMIZATION_CPU_OFFLOAD_DEFAULT)
        if cpu_offload_optimizer:
            self.offload_optimizer = get_default_offload_optimizer_config()
    else:
        self.offload_optimizer = get_offload_optimizer_config(zero_config_dict)
    # Same precedence rule for the deprecated "cpu_offload_params" flag.
    if ZERO_OPTIMIZATION_CPU_OFFLOAD_PARAMS in zero_config_dict:
        cpu_offload_params = get_scalar_param(
            zero_config_dict,
            ZERO_OPTIMIZATION_CPU_OFFLOAD_PARAMS,
            ZERO_OPTIMIZATION_CPU_OFFLOAD_PARAMS_DEFAULT)
        if cpu_offload_params:
            self.offload_param = get_default_offload_param_config()
    else:
        self.offload_param = get_offload_param_config(zero_config_dict)
    self.sub_group_size = get_scalar_param(
        zero_config_dict,
        ZERO_OPTIMIZATION_SUB_GROUP_SIZE,
        ZERO_OPTIMIZATION_SUB_GROUP_SIZE_DEFAULT)
    self.max_live_parameters = get_scalar_param(
        zero_config_dict,
        ZERO_OPTIMIZATION_MAX_LIVE_PARAMETERS,
        ZERO_OPTIMIZATION_MAX_LIVE_PARAMETERS_DEFAULT)
    self.max_reuse_distance = get_scalar_param(
        zero_config_dict,
        ZERO_OPTIMIZATION_MAX_REUSE_DISTANCE,
        ZERO_OPTIMIZATION_MAX_REUSE_DISTANCE_DEFAULT)
    self.prefetch_bucket_size = get_scalar_param(
        zero_config_dict,
        ZERO_OPTIMIZATION_PREFETCH_BUCKET_SIZE,
        ZERO_OPTIMIZATION_PREFETCH_BUCKET_SIZE_DEFAULT)
    self.param_persistence_threshold = get_scalar_param(
        zero_config_dict,
        ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD,
        ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD_DEFAULT)
    self.gather_fp16_weights_on_model_save = get_scalar_param(
        zero_config_dict,
        ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE,
        ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE_DEFAULT)
    self.ignore_unused_parameters = get_scalar_param(
        zero_config_dict,
        ZERO_OPTIMIZATION_IGNORE_UNUSED_PARAMETERS,
        ZERO_OPTIMIZATION_IGNORE_UNUSED_PARAMETERS_DEFAULT)
    self.legacy_stage1 = get_scalar_param(
        zero_config_dict,
        ZERO_OPTIMIZATION_LEGACY_STAGE1,
        ZERO_OPTIMIZATION_LEGACY_STAGE1_DEFAULT)
    # BUGFIX: was ZERO3_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT.
    self.round_robin_gradients = get_scalar_param(
        zero_config_dict,
        ZERO_OPTIMIZATION_ROUND_ROBIN_GRADIENTS,
        ZERO_OPTIMIZATION_ROUND_ROBIN_GRADIENTS_DEFAULT)
def _initialize(self, zero_config_dict):
    """Populate ZeRO optimization settings, each resolved against its default.

    ZeRO-3 (stage == ZERO_OPTIMIZATION_WEIGHTS) uses dedicated defaults for
    contiguous_gradients and overlap_comm.
    """
    self.stage = get_scalar_param(zero_config_dict,
                                  ZERO_OPTIMIZATION_STAGE,
                                  ZERO_OPTIMIZATION_STAGE_DEFAULT)
    # Stage must be known first: two defaults below depend on it.
    is_zero3 = self.stage == ZERO_OPTIMIZATION_WEIGHTS
    for attr, key, default in (
        ("contiguous_gradients", ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS,
         ZERO3_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT
         if is_zero3 else ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT),
        ("reduce_bucket_size", ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE,
         ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT),
        ("reduce_scatter", ZERO_OPTIMIZATION_REDUCE_SCATTER,
         ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT),
        ("overlap_comm", ZERO_OPTIMIZATION_OVERLAP_COMM,
         ZERO3_OPTIMIZATION_OVERLAP_COMM_DEFAULT
         if is_zero3 else ZERO_OPTIMIZATION_OVERLAP_COMM_DEFAULT),
        ("allgather_partitions", ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS,
         ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS_DEFAULT),
        ("allgather_bucket_size", ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE,
         ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT),
        ("load_from_fp32_weights", ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS,
         ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT),
        ("cpu_offload", ZERO_OPTIMIZATION_CPU_OFFLOAD,
         ZERO_OPTIMIZATION_CPU_OFFLOAD_DEFAULT),
        ("elastic_checkpoint", ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT,
         ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT_DEFAULT),
        ("cpu_offload_params", ZERO_OPTIMIZATION_CPU_OFFLOAD_PARAMS,
         ZERO_OPTIMIZATION_CPU_OFFLOAD_PARAMS_DEFAULT),
        ("cpu_offload_use_pin_memory", ZERO_OPTIMIZATION_CPU_OFFLOAD_USE_PIN_MEMORY,
         ZERO_OPTIMIZATION_CPU_OFFLOAD_USE_PIN_MEMORY_DEFAULT),
        ("sub_group_size", ZERO_OPTIMIZATION_SUB_GROUP_SIZE,
         ZERO_OPTIMIZATION_SUB_GROUP_SIZE_DEFAULT),
        ("max_live_parameters", ZERO_OPTIMIZATION_MAX_LIVE_PARAMETERS,
         ZERO_OPTIMIZATION_MAX_LIVE_PARAMETERS_DEFAULT),
        ("max_reuse_distance", ZERO_OPTIMIZATION_MAX_REUSE_DISTANCE,
         ZERO_OPTIMIZATION_MAX_REUSE_DISTANCE_DEFAULT),
        ("prefetch_bucket_size", ZERO_OPTIMIZATION_PREFETCH_BUCKET_SIZE,
         ZERO_OPTIMIZATION_PREFETCH_BUCKET_SIZE_DEFAULT),
        ("param_persistence_threshold", ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD,
         ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD_DEFAULT),
    ):
        setattr(self, attr, get_scalar_param(zero_config_dict, key, default))
def _get_offload_config(param_dict, key_default_dict):
    """Resolve every offload key in ``key_default_dict`` against ``param_dict``."""
    return {
        key: get_scalar_param(param_dict, key, default)
        for key, default in key_default_dict.items()
    }
def get_train_batch_size(param_dict):
    """Return the configured train batch size (default when absent)."""
    return get_scalar_param(param_dict,
                            TRAIN_BATCH_SIZE,
                            TRAIN_BATCH_SIZE_DEFAULT)
def get_fp16_enabled(param_dict):
    """True when an fp16 section exists and its enabled flag resolves truthy."""
    if FP16 not in param_dict:
        return False
    return get_scalar_param(param_dict[FP16], FP16_ENABLED, FP16_ENABLED_DEFAULT)
def get_amp_enabled(param_dict):
    """True when an amp section exists and its enabled flag resolves truthy."""
    if AMP not in param_dict:
        return False
    return get_scalar_param(param_dict[AMP], AMP_ENABLED, AMP_ENABLED_DEFAULT)
def get_train_micro_batch_size_per_gpu(param_dict):
    """Return the configured per-GPU micro batch size (default when absent)."""
    return get_scalar_param(param_dict,
                            TRAIN_MICRO_BATCH_SIZE_PER_GPU,
                            TRAIN_MICRO_BATCH_SIZE_PER_GPU_DEFAULT)
def get_wall_clock_breakdown(param_dict):
    """Return the wall-clock-breakdown flag (default when absent)."""
    return get_scalar_param(param_dict,
                            WALL_CLOCK_BREAKDOWN,
                            WALL_CLOCK_BREAKDOWN_DEFAULT)
def get_memory_breakdown(param_dict):
    """Return the memory-breakdown flag (default when absent)."""
    return get_scalar_param(param_dict,
                            MEMORY_BREAKDOWN,
                            MEMORY_BREAKDOWN_DEFAULT)