def get_broadcast_vars_and_param_usage(self, block):
    """Collect variables this rank must broadcast and per-parameter usage counts.

    Scans *block* twice.  First pass counts, for every parameter in
    ``self.param_names``, how many non-optimizer ops consume it.  Second pass
    finds fp16 casts of those parameters: the fp16 cast output is broadcast
    instead of the fp32 master weight, so the cast no longer counts as a
    direct use of the master weight.  Any parameter still used directly
    (usage > 0) must itself be broadcast.

    Fix: the original built ``fp16_params`` and ``fp16_to_fp32`` locals that
    were never read or returned — dead code removed.

    Args:
        block: a program block whose ``ops`` are inspected (not mutated).

    Returns:
        tuple(set, dict): (names of variables to broadcast,
                           param name -> remaining direct-use count).
    """
    broadcast_vars = set()
    param_usage = {name: 0 for name in self.param_names}

    # Count direct consumers of each parameter, ignoring optimizer ops.
    for op in block.ops:
        if is_optimizer_op(op):
            continue
        for input_name in op.desc.input_arg_names():
            if input_name in self.param_names:
                param_usage[input_name] += 1

    # A param consumed through an fp16 cast is broadcast as its fp16 copy,
    # which inherits the owning rank of its fp32 master weight.
    for op in block.ops:
        if not _is_param_fp16_cast_op(block, op, self.param_names):
            continue
        input_name = op.input_arg_names[0]
        output_name = op.output_arg_names[0]
        broadcast_vars.add(output_name)
        param_usage[input_name] -= 1
        self.param_to_rank[output_name] = self.param_to_rank[input_name]

    # Params still consumed directly (not only through a cast) are broadcast
    # as-is.
    for param, usage in param_usage.items():
        if usage > 0:
            broadcast_vars.add(param)
    return broadcast_vars, param_usage
def find_broadcast_params(self, block):
    """Return the set of variable names that must be broadcast for *block*.

    Mirrors the two-pass scheme used elsewhere in this file: first count how
    many non-optimizer ops consume each parameter in ``self.global_params``;
    then, for every fp16 cast of a parameter, broadcast the fp16 cast output
    instead of the fp32 master weight (and record the fp16 var's owning
    device in ``self.global_param2device``).  Parameters still used directly
    (usage > 0) are broadcast themselves.

    Fix: the original built ``fp16_params`` and ``fp16_to_fp32`` locals that
    were never read or returned — dead code removed.

    Args:
        block: a program block whose ``ops`` are inspected (not mutated).

    Returns:
        set: names of variables to broadcast.
    """
    broadcast_vars = set()
    param_usage = {name: 0 for name in self.global_params}

    # Count direct consumers of each parameter, ignoring optimizer ops.
    for op in block.ops:
        if is_optimizer_op(op):
            continue
        for input_name in op.desc.input_arg_names():
            if input_name in self.global_params:
                param_usage[input_name] += 1

    # Broadcast the fp16 copy of a cast parameter; it lives on the same
    # device as its fp32 master weight.
    for op in block.ops:
        if not FP16Utils.is_fp16_cast_op(block, op, self.global_params):
            continue
        input_name = op.input_arg_names[0]
        output_name = op.output_arg_names[0]
        broadcast_vars.add(output_name)
        param_usage[input_name] -= 1
        self.global_param2device[output_name] = self.global_param2device[
            input_name]

    # Params still consumed directly are broadcast as-is.
    for param, usage in param_usage.items():
        if usage > 0:
            broadcast_vars.add(param)
    return broadcast_vars
def _shard_optimizer_ops_and_states(self, main_block, startup_block):
    """Remove optimizer ops — and the state vars they own — for every
    parameter that is not in this rank's local shard.

    Walks ``main_block`` backwards (optimizer ops sit at the tail, so the
    scan stops at the first non-optimizer op), drops the update op for each
    non-local parameter, then prunes the startup ops that initialize the
    corresponding optimizer-state variables and finally the variables
    themselves from both blocks.
    """
    stale_state_vars = []

    for idx, op in reversed(list(enumerate(main_block.ops))):
        if not is_optimizer_op(op):
            break
        if op.type not in _supported_optimizer_type:
            continue
        assert "Param" in op.input_names
        assert len(op.input("Param")) == 1
        param_name = op.input("Param")[0]
        if self._is_parameter_in_local_shard(param_name):
            continue
        # Every output of the update op except the param itself is
        # optimizer state (moments, LR-related vars, ...).
        stale_state_vars.extend(
            name for name in op.output_arg_names if name != param_name)
        main_block._remove_op(idx, sync=False)

    # Drop the startup ops that initialize those state variables.
    for idx, op in reversed(list(enumerate(startup_block.ops))):
        outputs = op.output_arg_names
        if len(outputs) == 1 and outputs[0] in stale_state_vars:
            startup_block._remove_op(idx, sync=False)

    # Finally remove the variables themselves wherever they exist.
    for name in stale_state_vars:
        if main_block.has_var(name):
            main_block._remove_var(name, sync=False)
        if startup_block.has_var(name):
            startup_block._remove_var(name, sync=False)

    main_block._sync_with_cpp()
    startup_block._sync_with_cpp()
def _is_param_fp16_cast_op(block, op, params):
    """Return True iff *op* is a non-optimizer cast (per
    ``_is_desired_cast_op``) whose first input is one of *params*.

    Short-circuits in the same order as the original guard chain: the
    input name is only read once the op is known to be a desired cast.
    """
    return (not is_optimizer_op(op)
            and _is_desired_cast_op(block, op)
            and op.desc.input_arg_names()[0] in params)
def is_fp32_cast_op(block, op):
    """Return True iff *op* is an optimizer-side ``cast`` whose single input
    is an FP16 variable and whose single output is an FP32 variable
    (i.e. the cast back from fp16 grads/params to fp32 master values).
    """
    if op.type != "cast" or not is_optimizer_op(op):
        return False
    assert (len(op.desc.input_arg_names()) == 1)
    assert (len(op.desc.output_arg_names()) == 1)
    src_var = block.var(op.desc.input_arg_names()[0])
    dst_var = block.var(op.desc.output_arg_names()[0])
    return (src_var.dtype == core.VarDesc.VarType.FP16
            and dst_var.dtype == core.VarDesc.VarType.FP32)
def fuse_opt_broadcast_param_ops(block,
                                 ring_id,
                                 shard,
                                 op_role=OpRole.Optimize,
                                 strategy=None):
    """
    Fuse the per-parameter ``c_broadcast`` ops emitted by optimizer sharding
    into a few broadcasts over coalesced buffers, one group of buffers per
    root rank.  No-op unless ``strategy.fuse_all_reduce_ops`` is enabled.
    """
    if strategy is None or not strategy.fuse_all_reduce_ops:
        return

    # NOTE(review): fuse buffer size reuses the gradient-fusion knob (MB) —
    # presumably intentional; confirm against the strategy definition.
    fuse_size = strategy.fuse_grad_size_in_MB

    nranks = shard.worker_num
    # Per-root-rank list of the param names broadcast from that rank.
    device_to_vars = [[] for _ in range(nranks)]

    # Broadcast ops form a contiguous tail of the block; walk backwards,
    # collect each op's var under its root rank (insert(0, ...) restores
    # original program order), and remove the op.  Stops at the first
    # non-broadcast / non-optimizer op.
    for idx, op in reversed(list(enumerate(block.ops))):
        if not is_optimizer_op(op) or op.type != 'c_broadcast':
            break
        var = op.input_arg_names[0]
        root_id = op.attr('root')
        device_to_vars[root_id].insert(0, var)
        block._remove_op(idx, sync=False)

    # idx now indexes the op just before the removed run; insert after it.
    insert_idx = idx + 1
    for root_id, vars_name in enumerate(device_to_vars):
        # Group same-dtype vars together so each coalesced buffer is uniform.
        vars_name = FuseHelper.sort_vars_by_dtype(block, vars_name)
        groups = FuseHelper.get_fused_groups(block, vars_name, fuse_size)

        # insert_num = number of coalesce ops inserted at insert_idx; the
        # broadcasts below must land after all of them.
        fused_vars, insert_num = FuseHelper.insert_coalesce_tensor(
            block, insert_idx, groups, op_role, prefix="Param")

        # NOTE(review): every broadcast is inserted at the same position, so
        # within one root_id the fused_vars end up in reverse order in the
        # block; broadcasts are independent, so this looks deliberate/benign.
        for fused_var in fused_vars:
            block._insert_op_without_sync(insert_idx + insert_num,
                                          type='c_broadcast',
                                          inputs={'X': fused_var},
                                          outputs={'Out': fused_var},
                                          attrs={
                                              'ring_id': ring_id,
                                              'root': root_id,
                                              'use_calc_stream': True,
                                              OP_ROLE_KEY: op_role
                                          })

    block._sync_with_cpp()
def _shard_parameter(self, main_block, startup_block):
    """Stage-3 sharding: keep only this rank's parameter shard and insert
    broadcast ops so remote shards are fetched on demand where still used.

    Only runs for sharding stage >= 3.  For each sharding group it:
    (1) finds vars needing broadcast and per-param usage counts,
    (2) rewrites consumers of remote params to read a fresh ``@BroadCast``
        var filled by an init+broadcast op pair,
    (3) deletes casts of params that are never used locally,
    (4) prunes/retargets startup initialization and broadcast ops,
    (5) removes non-local param vars from both blocks.
    """
    if self.stage < 3:
        return

    dp_ring_ids = [group.id for group in self.dp_groups]
    for sharding_info in self.sharding_infos:
        need_broadcast_vars, param_usage = sharding_info.get_broadcast_vars_and_param_usage(
            main_block)
        # Params with zero direct uses that live on another rank: their
        # local copies (and casts) can be dropped entirely.
        # NOTE(review): "nane" is a typo for "name" — kept to avoid a code change.
        not_used_param_nane = []
        for param_name in param_usage:
            if param_usage[param_name] == 0 and sharding_info.get_var_rank(
                    param_name) != sharding_info.local_rank:
                not_used_param_nane.append(param_name)

        # Rewrite forward/backward consumers of broadcast vars.  Reversed so
        # that inserting the init+broadcast ops at idx does not shift the
        # indices of ops not yet visited.
        for idx, op in reversed(list(enumerate(main_block.ops))):
            if is_optimizer_op(op):
                continue

            for input_name in op.desc.input_arg_names():
                # Cast ops are handled by the dedicated pass below.
                if op.type == "cast":
                    continue
                if input_name not in need_broadcast_vars:
                    continue
                root_rank = sharding_info.get_var_rank(input_name)
                if root_rank == sharding_info.local_rank:
                    # Local shard owner broadcasts the var itself.
                    broadcast_varname = input_name
                else:
                    # Remote shard: receive into a fresh temp var carrying
                    # the same shape/dtype and dist attributes, and point
                    # the consumer at it.
                    broadcast_varname = unique_name.generate(input_name +
                                                             "@BroadCast")
                    input_var = main_block.var(input_name)
                    new_var = main_block.create_var(name=broadcast_varname,
                                                    shape=input_var.shape,
                                                    dtype=input_var.dtype,
                                                    persistable=False)
                    ref_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
                        input_var)
                    out_var_dist_attr = set_var_dist_attr(
                        self._dist_context, new_var,
                        ref_dist_attr.dims_mapping,
                        ref_dist_attr.process_mesh)
                    op._rename_input(input_name, broadcast_varname)
                _insert_init_and_broadcast_op(main_block, idx,
                                              broadcast_varname,
                                              sharding_info.local_rank,
                                              root_rank,
                                              sharding_info.group.id,
                                              op.attr('op_role'),
                                              self._dist_context)

        # Drop casts (and their fp16 outputs) of params never used locally.
        for idx, op in reversed(list(enumerate(main_block.ops))):
            if op.type != "cast":
                continue
            input_name = op.input_arg_names[0]
            output_name = op.output_arg_names[0]
            if input_name in not_used_param_nane:
                main_block._remove_op(idx, sync=False)
                main_block._remove_var(output_name, sync=False)

        # Prune startup block: keep initialization only for locally owned
        # params; retarget DP broadcasts of local params to the outer DP
        # group, drop the rest.
        for idx, op in reversed(list(enumerate(startup_block.ops))):
            assert len(op.output_arg_names) == 1
            output_name = op.output_arg_names[0]
            if op.type == "c_broadcast" and op.attr(
                    "ring_id") in dp_ring_ids:
                if self.outer_dp_group and sharding_info.get_var_rank(
                        output_name) == sharding_info.local_rank:
                    op._set_attr("ring_id", self.outer_dp_group.id)
                else:
                    startup_block._remove_op(idx, sync=False)
                continue
            # Non-broadcast initializer for a param owned by another rank.
            if op.type != "c_broadcast" and output_name in param_usage and sharding_info.get_var_rank(
                    output_name) != sharding_info.local_rank:
                startup_block._remove_op(idx, sync=False)

        # Finally remove the non-local param vars from both blocks.
        for param_name in param_usage:
            if sharding_info.get_var_rank(
                    param_name) != sharding_info.local_rank:
                main_block._remove_var(param_name, sync=False)
                startup_block._remove_var(param_name, sync=False)

    main_block._sync_with_cpp()
    startup_block._sync_with_cpp()