def _init_allreduce_operators(length):
    """Create ``length`` AllReduce ('sum') operators on the world group.

    When the auto-parallel context reports "enable_parallel_optimizer" as
    truthy AND ``get_all_reduce_fusion_split_indices()`` returns a non-empty
    sequence, each operator gets a fusion group derived from the split
    indices (starting at 1, advancing by one once position ``i + 1`` reaches
    the current split index, never exceeding ``len(split_indices)``) and a
    1-based ``index`` attribute. Otherwise every operator gets fusion group 1
    and index 0.

    Args:
        length: number of operators to create.

    Returns:
        tuple: the created AllReduce primitives, in position order.
    """
    enabled = context.get_auto_parallel_context(
        "enable_parallel_optimizer")
    splits = auto_parallel_context().get_all_reduce_fusion_split_indices()

    if not (enabled and splits):
        fusion_ids = (1, ) * length
        positions = (0, ) * length
    else:
        current_group = 1
        assigned = []
        for pos in range(length):
            assigned.append(current_group)
            # Step to the next group once this split point is reached, but
            # stay on the last group when the split list is exhausted.
            if splits[current_group - 1] <= pos + 1 and current_group < len(splits):
                current_group += 1
        fusion_ids = tuple(assigned)
        positions = tuple(range(1, length + 1))

    operators = []
    for fusion_id, position in zip(fusion_ids, positions):
        operator = AllReduce('sum', GlobalComm.WORLD_COMM_GROUP)
        operator.add_prim_attr('fusion', fusion_id)
        operator.add_prim_attr('index', position)
        operators.append(operator)
    return tuple(operators)
def _init_optimizer_communication():
    """Initialize the module-level ``_all_reduce`` and ``_all_gather`` primitives.

    ``_all_reduce`` becomes a SUM AllReduce on the world communication group
    tagged with fusion attribute 1; ``_all_gather`` becomes an AllGather on the
    same group.
    """
    global _all_reduce, _all_gather
    world_group = GlobalComm.WORLD_COMM_GROUP
    reducer = AllReduce(ReduceOp.SUM, world_group)
    _all_reduce = reducer
    reducer.add_prim_attr('fusion', 1)
    _all_gather = AllGather(world_group)
def _init_allreduce_operators(length, split_indices):
    """Create ``length`` AllReduce ('sum') operators grouped by split indices.

    Fusion groups start at 1; after assigning position ``i`` the group advances
    by one if ``i + 1`` has reached ``split_indices[group - 1]``, and it never
    exceeds ``len(split_indices)``. Each operator also gets a 1-based ``index``
    attribute equal to its position plus one.

    Args:
        length: number of operators to create.
        split_indices: indexable sequence of split points. An empty sequence
            raises IndexError when ``length > 0``.

    Returns:
        tuple: the created AllReduce primitives, in position order.
    """
    current_group = 1
    groups = []
    for pos in range(length):
        groups.append(current_group)
        # Advance at a split point unless already on the last group.
        if split_indices[current_group - 1] <= pos + 1 and current_group < len(split_indices):
            current_group += 1

    operators = []
    for position, fusion_id in enumerate(groups, start=1):
        operator = AllReduce('sum', GlobalComm.WORLD_COMM_GROUP)
        operator.add_prim_attr('fusion', fusion_id)
        operator.add_prim_attr('index', position)
        operators.append(operator)
    return tuple(operators)
def _init_allreduce_operators(length, split_indices):
    """Create ``length`` AllReduce ('sum') operators with stepped fusion ids.

    ``split_indices[0]`` is read as a sequence of boundary positions and
    ``split_indices[1]`` as the starting fusion id. Walking positions in order,
    whenever a position is >= the current boundary the fusion id increases by
    one and the next boundary becomes current (at most one step per position).
    Once the boundaries are exhausted, the fusion id stays fixed.

    Args:
        length: number of operators to create.
        split_indices: pair of (boundary positions, starting fusion id).

    Returns:
        tuple: the created AllReduce primitives, in position order.
    """
    boundaries, fusion_id = split_indices[0], split_indices[1]
    cursor = 0
    operators = []
    for pos in range(length):
        # Past the last boundary, use ``length`` so no further step can occur.
        limit = boundaries[cursor] if cursor < len(boundaries) else length
        if pos >= limit:
            cursor = cursor + 1
            fusion_id = fusion_id + 1
        operator = AllReduce('sum', GlobalComm.WORLD_COMM_GROUP)
        operator.add_prim_attr('fusion', fusion_id)
        operators.append(operator)
    return tuple(operators)
def _init_optimizer_allreduce():
    """Initialize the module-level ``_all_reduce`` primitive.

    It is a SUM AllReduce on the world communication group, tagged with
    fusion attribute 1.
    """
    global _all_reduce
    reducer = AllReduce(ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP)
    _all_reduce = reducer
    reducer.add_prim_attr('fusion', 1)