def broadcast_mp_parameters(model, hcg):
    # Broadcast parameters and buffers from the source rank of the model
    # parallel group; is_model_parallel=True lets the sync helper skip
    # tensors that are already sharded across the group.
    model_parallel_group = hcg.get_model_parallel_group()
    src_rank = hcg.get_model_parallel_group_src_rank()
    sync_params_buffers(
        model, model_parallel_group, src_rank, is_model_parallel=True
    )
def broadcast_dp_parameters(model, hcg):
    # Broadcast parameters and buffers from the source rank of the data
    # parallel group; every data-parallel replica holds a full copy, so
    # nothing is skipped (is_model_parallel=False).
    data_parallel_group = hcg.get_data_parallel_group()
    src_rank = hcg.get_data_parallel_group_src_rank()
    sync_params_buffers(
        model, data_parallel_group, src_rank, is_model_parallel=False
    )
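# Illustrative sketch only, not the actual sync_params_buffers implementation:
# all the broadcast helpers in this file reduce to iterating the model's
# parameters and buffers and broadcasting each tensor from src_rank across the
# given communication group. The helper name below is hypothetical;
# paddle.distributed.broadcast, Layer.parameters(), and Layer.buffers() are
# real Paddle APIs, and the is_distributed check is an assumption about how
# sharded tensors are marked.
def _sync_params_buffers_unfused_sketch(
    model, comm_group, src_rank, is_model_parallel=True
):
    import paddle.distributed as dist

    for tensor in list(model.parameters()) + list(model.buffers()):
        # Under model parallelism, tensors already sharded across the group
        # must not be overwritten by a broadcast from the source rank.
        if is_model_parallel and getattr(tensor, "is_distributed", False):
            continue
        # Un-fused: one collective per tensor, trading launch overhead for
        # lower peak memory than a fused/coalesced broadcast.
        dist.broadcast(tensor, src=src_rank, group=comm_group)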
def broadcast_sharding_parameters(model, hcg):
    # TODO: to save memory, use un-fused broadcast to avoid potential OOM.
    logger.debug("sharding start init parameters sync")
    sharding_parallel_group = hcg.get_sharding_parallel_group()
    src_rank = hcg.get_sharding_parallel_group_src_rank()
    sync_params_buffers(
        model, sharding_parallel_group, src_rank, is_model_parallel=False
    )
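# Hedged usage sketch: these broadcasts would typically run once right after
# fleet initialization so every rank starts training from identical weights.
# fleet.init(), fleet.DistributedStrategy(), and
# fleet.get_hybrid_communicate_group() are Paddle's standard entry points for
# obtaining the hcg used above; the Linear model and the parallel degrees are
# placeholders, and this assumes launch via `python -m
# paddle.distributed.launch`.
if __name__ == "__main__":
    import paddle
    from paddle.distributed import fleet

    strategy = fleet.DistributedStrategy()
    # Example degrees only; real values depend on the cluster layout.
    strategy.hybrid_configs = {"dp_degree": 2, "mp_degree": 2, "pp_degree": 1}
    fleet.init(is_collective=True, strategy=strategy)

    hcg = fleet.get_hybrid_communicate_group()
    model = paddle.nn.Linear(1024, 1024)  # placeholder model

    broadcast_mp_parameters(model, hcg)        # sync within model-parallel group
    broadcast_dp_parameters(model, hcg)        # sync within data-parallel group
    broadcast_sharding_parameters(model, hcg)  # sync within sharding group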