def initialize_model_parallel_for_nemo(
    world_size, global_rank, local_rank, tensor_model_parallel_size=1, seed=1234,
):
    """Populate NeMo's AppState and apex.mpu tensor-parallel globals.

    Records rank/world-size bookkeeping on the shared ``AppState`` singleton,
    mirrors the tensor-parallel settings into apex.mpu, pins pipeline
    parallelism to a single stage, and seeds the RNGs.

    NOTE(review): a later definition with the same name appears below in this
    file and shadows this one at import time — confirm whether this older
    variant is still reachable.
    """
    # updating NeMo globals
    state = AppState()
    state.global_rank = global_rank
    state.world_size = world_size
    state.model_parallel_size = tensor_model_parallel_size
    state.model_parallel_rank = compute_model_parallel_rank(
        local_rank, tensor_model_parallel_size
    )

    # update apex.mpu globals
    set_tensor_model_parallel_world_size(tensor_model_parallel_size)
    set_tensor_model_parallel_rank(state.model_parallel_rank)

    # pipeline model parallelism not implemented in NeMo yet
    set_pipeline_model_parallel_rank(0)
    set_pipeline_model_parallel_world_size(1)

    _set_random_seed(seed)

    state._is_megatron_initialized = True
def initialize_model_parallel_for_nemo(
    world_size,
    global_rank,
    local_rank,
    tensor_model_parallel_size=1,
    pipeline_model_parallel_size=1,
    micro_batch_size=None,
    global_batch_size=None,
    seed=1234,
    apex_transformer_log_level=30,
):
    """Initialize model-parallel state for NeMo without torch.distributed.

    Fills in the shared ``AppState`` singleton with tensor/pipeline parallel
    ranks and sizes (computed via ``fake_initialize_model_parallel``), mirrors
    them into apex.transformer's parallel-state globals, seeds the RNGs, and
    optionally configures apex's microbatch calculator.

    Args:
        world_size: total number of ranks in the job.
        global_rank: this process's global rank.
        local_rank: this process's rank on its node.
        tensor_model_parallel_size: tensor-parallel group size.
        pipeline_model_parallel_size: pipeline-parallel group size.
        micro_batch_size: per-rank micro batch size; both this and
            ``global_batch_size`` must be given to set up the microbatch
            calculator.
        global_batch_size: global batch size across data-parallel replicas.
        seed: RNG seed passed to ``_set_random_seed``.
        apex_transformer_log_level: logging level forwarded to apex.transformer.
    """
    # updating NeMo globals
    app_state = AppState()
    app_state.global_rank = global_rank
    app_state.world_size = world_size
    app_state.local_rank = local_rank
    app_state.tensor_model_parallel_size = tensor_model_parallel_size
    app_state.pipeline_model_parallel_size = pipeline_model_parallel_size
    (
        app_state.tensor_model_parallel_rank,
        app_state.pipeline_model_parallel_rank,
        app_state.model_parallel_size,
        app_state.data_parallel_size,
    ) = fake_initialize_model_parallel(
        world_size=world_size,
        rank=global_rank,
        tensor_model_parallel_size_=tensor_model_parallel_size,
        pipeline_model_parallel_size_=pipeline_model_parallel_size,
    )

    # update apex.transformer globals
    set_tensor_model_parallel_world_size(app_state.tensor_model_parallel_size)
    set_tensor_model_parallel_rank(app_state.tensor_model_parallel_rank)

    # mirror the pipeline-parallel rank/size computed above into apex
    set_pipeline_model_parallel_rank(app_state.pipeline_model_parallel_rank)
    set_pipeline_model_parallel_world_size(app_state.pipeline_model_parallel_size)

    _set_random_seed(seed)

    # Bugfix: the original condition was
    # `global_batch_size and micro_batch_size is not None`, which by operator
    # precedence truth-tested global_batch_size instead of None-testing it.
    # Require both values to be explicitly provided.
    if micro_batch_size is not None and global_batch_size is not None:
        # TODO: add rampup_batch_size here when we have it implemented
        setup_microbatch_calculator(
            rank=global_rank,
            global_batch_size=global_batch_size,
            micro_batch_size=micro_batch_size,
            data_parallel_size=app_state.data_parallel_size,
            rampup_batch_size=None,
        )

    app_state._is_megatron_initialized = True
    set_logging_level(apex_transformer_log_level)