Example #1
class AllGatherNet(Cell):
    def __init__(self, input_channel, out_channel):
        super(AllGatherNet, self).__init__()
        self.dense = Dense(input_channel, out_channel)
        # Pick the world communication group that matches the active backend.
        if GlobalComm.BACKEND is Backend.HCCL:
            self.allgather = AllGather(group=HCCL_WORLD_COMM_GROUP)
        elif GlobalComm.BACKEND is Backend.NCCL:
            self.allgather = AllGather(group=NCCL_WORLD_COMM_GROUP)
        else:
            self.allgather = AllGather()
        self.relu = ReLU()
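A rough usage sketch for a cell like this follows. It assumes the class also defines a construct that chains dense, allgather and relu, that the script is launched as a multi-device job (e.g. via mpirun for the NCCL backend), and that the import paths match recent MindSpore releases; the sizes are illustrative.

import numpy as np
from mindspore import Tensor, context
from mindspore.communication import init

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
init()  # on GPU this sets GlobalComm.BACKEND to Backend.NCCL

net = AllGatherNet(input_channel=8, out_channel=8)
x = Tensor(np.ones((2, 8)).astype(np.float32))
out = net(x)  # each rank receives the concatenation of every rank's activations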
Example #2
class DistributedGradReducer(Cell):
    def __init__(self, parameters, mean=True, degree=None):
        super(DistributedGradReducer, self).__init__(auto_prefix=False)
        self.map_ = C.Map()
        # Use the communication group size unless an explicit degree is given.
        if degree is None:
            self.degree = get_group_size()
        else:
            if not isinstance(degree, int) or degree <= 0:
                raise ValueError(
                    "Parameter 'degree' in DistributedGradReducer should be a positive int"
                )
            self.degree = degree
        self.mean = mean
        # Only reduce gradients of parameters that are not layer-wise parallel.
        self.allreduce_filter = tuple(x.layerwise_parallel is False
                                      for x in parameters)
        is_parallel_optimizer = context.get_auto_parallel_context(
            "enable_parallel_optimizer")
        split_indices = auto_parallel_context(
        ).get_all_reduce_fusion_split_indices()
        if is_parallel_optimizer and split_indices:
            # Build one AllReduce operator per fusion segment.
            self.split_fusion = True
            self.op_list = _init_allreduce_operators(len(parameters),
                                                     split_indices)
        else:
            self.split_fusion = False
            self.allreduce = AllReduce().add_prim_attr('fusion', 1)
        self.allgather = AllGather(GlobalComm.WORLD_COMM_GROUP)
        ps_filter = lambda x: x.is_param_ps
        self.ps_parameters = tuple(ps_filter(x) for x in parameters)
        self.enable_parameter_server = any(self.ps_parameters)
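A reducer like this is normally wrapped around a training step. Below is a minimal sketch of that pattern, assuming data-parallel training, an already-initialized communication backend, and a network that returns the loss from (data, label); TrainStep and its member names are illustrative, not taken from the examples above.

import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.ops import functional as F

class TrainStep(nn.Cell):
    def __init__(self, network, optimizer):
        super(TrainStep, self).__init__(auto_prefix=False)
        self.network = network
        self.optimizer = optimizer
        self.weights = optimizer.parameters
        self.grad = ops.GradOperation(get_by_list=True)
        # Averages each device's gradients across the group before the update.
        self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean=True)

    def construct(self, data, label):
        loss = self.network(data, label)
        grads = self.grad(self.network, self.weights)(data, label)
        grads = self.grad_reducer(grads)  # AllReduce over the gradient tuple
        return F.depend(loss, self.optimizer(grads))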
Example #3
def _init_optimizer_communication():
    """Create the module-level AllReduce and AllGather primitives used for optimizer communication."""
    global _all_reduce
    global _all_gather

    _all_reduce = AllReduce(ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP)
    _all_reduce.add_prim_attr('fusion', 1)
    _all_gather = AllGather(GlobalComm.WORLD_COMM_GROUP)
Example #4
class DistributedGradReducer(Cell):
    def __init__(self, parameters, mean=True, degree=None):
        super(DistributedGradReducer, self).__init__(auto_prefix=False)
        self.map_ = C.Map()
        # Use the communication group size unless an explicit degree is given.
        if degree is None:
            self.degree = get_group_size()
        else:
            if not isinstance(degree, int) or degree <= 0:
                raise ValueError(
                    "Parameter 'degree' in DistributedGradReducer should be a positive int"
                )
            self.degree = degree
        self.mean = mean
        # Only reduce gradients of parameters that are not layer-wise parallel.
        self.allreduce_filter = tuple(x.layerwise_parallel is False
                                      for x in parameters)
        self.opt_list = _init_allreduce_operators(len(parameters))
        self.allgather = AllGather(GlobalComm.WORLD_COMM_GROUP)
        ps_filter = lambda x: x.is_param_ps
        self.ps_parameters = tuple(ps_filter(x) for x in parameters)
Example #5
class SaveOptShardCkptCell(Cell):
    def __init__(self, group):
        super(SaveOptShardCkptCell, self).__init__(auto_prefix=False)
        self.allgather1 = AllGather(group)  # gathers within the given group
        self.allgather2 = AllGather()       # gathers across the default world group
Example #6
class AllGatherCell(Cell):
    def __init__(self):
        super(AllGatherCell, self).__init__(auto_prefix=False)
        self.allgather = AllGather()  # uses the default world communication group
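For reference, a minimal sketch of driving the AllGather primitive on its own, assuming a multi-device job has already been launched, that PyNative mode is used so the primitive can be called directly, and that the import paths match recent MindSpore releases; shapes and values are illustrative.

import numpy as np
from mindspore import Tensor, context
from mindspore.communication import init, get_group_size
from mindspore.ops.operations import AllGather

context.set_context(mode=context.PYNATIVE_MODE)
init()  # initialize HCCL/NCCL for the launched job

all_gather = AllGather()  # defaults to the world communication group
x = Tensor(np.ones((2, 8)).astype(np.float32))
y = all_gather(x)         # shape becomes (2 * get_group_size(), 8)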