class AllGatherNet(Cell):
    def __init__(self, input_channel, out_channel):
        super(AllGatherNet, self).__init__()
        self.dense = Dense(input_channel, out_channel)
        # Pick the world communication group that matches the active backend.
        if GlobalComm.BACKEND is Backend.HCCL:
            self.allgather = AllGather(group=HCCL_WORLD_COMM_GROUP)
        elif GlobalComm.BACKEND is Backend.NCCL:
            self.allgather = AllGather(group=NCCL_WORLD_COMM_GROUP)
        else:
            self.allgather = AllGather()
        self.relu = ReLU()
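    # A plausible forward pass for the cell above (a sketch, assuming the
    # layers are applied in declaration order): dense projection, gather the
    # activations from every device, then ReLU.
    def construct(self, x):
        x = self.dense(x)
        x = self.allgather(x)
        return self.relu(x)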
class DistributedGradReducer(Cell):
    def __init__(self, parameters, mean=True, degree=None):
        super(DistributedGradReducer, self).__init__(auto_prefix=False)
        self.map_ = C.Map()
        if degree is None:
            self.degree = get_group_size()
        else:
            if not isinstance(degree, int) or degree <= 0:
                raise ValueError("Parameter 'degree' in DistributedGradReducer "
                                 "should be a positive int")
            self.degree = degree
        self.mean = mean
        self.allreduce_filter = tuple(x.layerwise_parallel is False for x in parameters)
        is_parallel_optimizer = context.get_auto_parallel_context("enable_parallel_optimizer")
        split_indices = auto_parallel_context().get_all_reduce_fusion_split_indices()
        if is_parallel_optimizer and split_indices:
            # One AllReduce operator per parameter, fused according to the split indices.
            self.split_fusion = True
            self.op_list = _init_allreduce_operators(len(parameters), split_indices)
        else:
            # A single AllReduce shared by all gradients, fused as one group.
            self.split_fusion = False
            self.allreduce = AllReduce().add_prim_attr('fusion', 1)
        self.allgather = AllGather(GlobalComm.WORLD_COMM_GROUP)
        ps_filter = lambda x: x.is_param_ps
        self.ps_parameters = tuple(ps_filter(x) for x in parameters)
        self.enable_parameter_server = any(self.ps_parameters)
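# A minimal usage sketch (the training-step wiring here is assumed, and
# `net`, `grad_fn` and `optimizer` are hypothetical): the reducer is called
# on the gradient tuple so every device averages gradients before the
# optimizer update.
grad_reducer = DistributedGradReducer(net.trainable_params(), mean=True)
grads = grad_fn(data, label)    # per-device gradients (hypothetical helper)
grads = grad_reducer(grads)     # allreduce and, with mean=True, divide by degree
optimizer(grads)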
def _init_optimizer_communication():
    global _all_reduce
    global _all_gather
    _all_reduce = AllReduce(ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP)
    _all_reduce.add_prim_attr('fusion', 1)
    _all_gather = AllGather(GlobalComm.WORLD_COMM_GROUP)
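# A sketch of how the module-level singletons above might be consumed
# (assumed call sites; `local_grad` and `local_shard` are hypothetical
# tensors): initialize once, then sum gradients and reassemble sharded
# parameters across the world group.
_init_optimizer_communication()
summed_grad = _all_reduce(local_grad)   # elementwise SUM across all devices
full_param = _all_gather(local_shard)   # concatenate each device's shard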
class DistributedGradReducer(Cell):
    def __init__(self, parameters, mean=True, degree=None):
        super(DistributedGradReducer, self).__init__(auto_prefix=False)
        self.map_ = C.Map()
        if degree is None:
            self.degree = get_group_size()
        else:
            if not isinstance(degree, int) or degree <= 0:
                raise ValueError("Parameter 'degree' in DistributedGradReducer "
                                 "should be a positive int")
            self.degree = degree
        self.mean = mean
        self.allreduce_filter = tuple(x.layerwise_parallel is False for x in parameters)
        # Variant without split fusion: one AllReduce operator per parameter.
        self.opt_list = _init_allreduce_operators(len(parameters))
        self.allgather = AllGather(GlobalComm.WORLD_COMM_GROUP)
        ps_filter = lambda x: x.is_param_ps
        self.ps_parameters = tuple(ps_filter(x) for x in parameters)
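# A minimal sketch of the mean/degree semantics (assumed, simplified from the
# per-tensor reduction the class maps over the gradient tuple): AllReduce sums
# each gradient over `degree` devices, and dividing by `degree` turns that sum
# into an average.
def _reduce_one_grad(allreduce_op, grad, mean, degree):
    reduced = allreduce_op(grad)     # sum of this gradient across devices
    if mean:
        reduced = reduced / degree   # convert the sum into a mean
    return reduced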
class SaveOptShardCkptCell(Cell):
    def __init__(self, group):
        super(SaveOptShardCkptCell, self).__init__(auto_prefix=False)
        self.allgather1 = AllGather(group)  # gather within the optimizer-shard group
        self.allgather2 = AllGather()       # gather across the default world group
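    # A plausible construct for this cell (a sketch, assuming the two gathers
    # are chained so a parameter sharded by the parallel optimizer is rebuilt
    # in full before it is saved to a checkpoint):
    def construct(self, x):
        x = self.allgather1(x)
        return self.allgather2(x)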
class AllGatherCell(Cell):
    def __init__(self):
        super(AllGatherCell, self).__init__(auto_prefix=False)
        self.allgather = AllGather()
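    # A minimal construct to match (a sketch; the real cell may add casts or
    # reshapes): simply gather the input tensor from every device in the
    # default world group.
    def construct(self, x):
        return self.allgather(x)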