Exemplo n.º 1
0
def _init_allreduce_operators(length):
    """Build one AllReduce operator per gradient, tagged with fusion/index attrs.

    When the auto-parallel context reports ``enable_parallel_optimizer`` and
    the all-reduce fusion split list is non-empty, operators are spread over
    fusion groups starting at 1: the operator at 1-based position ``k`` is
    placed in the current group, and if ``split_indices[group - 1] <= k`` and
    a later group exists, the following operators move to the next group.
    Operators past the last split point stay in the final group, so ids never
    exceed ``len(split_indices)``. In that mode the ``index`` attribute is the
    1-based position.

    Otherwise every operator gets ``fusion=1`` and ``index=0``.

    Args:
        length: number of operators to create.

    Returns:
        tuple: ``length`` ``AllReduce('sum', GlobalComm.WORLD_COMM_GROUP)``
        operators in position order.
    """
    parallel_optimizer_on = context.get_auto_parallel_context(
        "enable_parallel_optimizer")
    fusion_splits = auto_parallel_context(
    ).get_all_reduce_fusion_split_indices()

    if not (parallel_optimizer_on and fusion_splits):
        # One shared fusion bucket and no ordering index.
        fusion_ids = [1] * length
        index_ids = [0] * length
    else:
        fusion_ids = []
        current_group = 1
        for position in range(1, length + 1):
            fusion_ids.append(current_group)
            # The split lookup is evaluated first; the group only advances
            # while a later split point still exists.
            reached_split = fusion_splits[current_group - 1] <= position
            if reached_split and current_group < len(fusion_splits):
                current_group += 1
        index_ids = list(range(1, length + 1))

    operators = []
    for fusion_id, index_id in zip(fusion_ids, index_ids):
        operator = AllReduce('sum', GlobalComm.WORLD_COMM_GROUP)
        operator.add_prim_attr('fusion', fusion_id)
        operator.add_prim_attr('index', index_id)
        operators.append(operator)
    return tuple(operators)
Exemplo n.º 2
0
def _init_optimizer_communication():
    """Create the optimizer's world-group communication operators.

    Rebinds the module globals ``_all_reduce`` to an
    ``AllReduce(ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP)`` operator tagged
    with ``fusion=1``, and ``_all_gather`` to an
    ``AllGather(GlobalComm.WORLD_COMM_GROUP)`` operator. Returns ``None``.
    """
    global _all_reduce, _all_gather
    world_group = GlobalComm.WORLD_COMM_GROUP

    _all_reduce = AllReduce(ReduceOp.SUM, world_group)
    _all_reduce.add_prim_attr('fusion', 1)
    _all_gather = AllGather(world_group)
Exemplo n.º 3
0
def _init_allreduce_operators(length, split_indices):
    """Build one AllReduce operator per gradient, grouped by fusion split points.

    Walking the operators in order, the operator at 1-based position ``k`` is
    placed in the current fusion group (starting at 1). If
    ``split_indices[group - 1] <= k`` and a later group exists, the following
    operators move to the next group. Operators past the last split point
    stay in the final group, so ids never exceed ``len(split_indices)``. The
    ``index`` attribute is the 1-based position.

    Args:
        length: number of operators to create.
        split_indices: fusion split points; presumably ascending cumulative
            operator counts (TODO confirm against caller).

    Returns:
        tuple: ``length`` ``AllReduce('sum', GlobalComm.WORLD_COMM_GROUP)``
        operators in position order.

    Raises:
        IndexError: if ``split_indices`` is empty and ``length > 0``.
    """
    fusion_ids = []
    current_group = 1
    for position in range(1, length + 1):
        fusion_ids.append(current_group)
        if (split_indices[current_group - 1] <= position
                and current_group < len(split_indices)):
            current_group += 1

    op_list = []
    for position, fusion_id in enumerate(fusion_ids, start=1):
        op = AllReduce('sum', GlobalComm.WORLD_COMM_GROUP)
        op.add_prim_attr('fusion', fusion_id)
        op.add_prim_attr('index', position)
        op_list.append(op)
    return tuple(op_list)
Exemplo n.º 4
0
def _init_allreduce_operators(length, split_indices):
    """Build ``length`` AllReduce operators split into consecutive fusion groups.

    Args:
        length: number of operators to create.
        split_indices: a pair ``(boundaries, first_fusion)``. ``boundaries``
            is a sequence of 0-based positions; ``first_fusion`` is the
            fusion id given to operators before the first boundary.

    Each time the 0-based position reaches the current boundary, the fusion
    id is increased by one and the next boundary becomes current. Once every
    boundary has been consumed, the id stays fixed. At most one increment
    happens per operator. No ``index`` attribute is set in this variant.

    Returns:
        tuple: the ``AllReduce('sum', GlobalComm.WORLD_COMM_GROUP)``
        operators in position order.
    """
    boundaries, fusion_id = split_indices[0], split_indices[1]
    op_list = []
    cursor = 0
    for position in range(length):
        # Past the last boundary, use `length` so no further bump can occur.
        threshold = boundaries[cursor] if cursor < len(boundaries) else length
        if position >= threshold:
            cursor += 1
            # Rebind rather than `+=` so a mutable start id is never
            # modified in place.
            fusion_id = fusion_id + 1
        op = AllReduce('sum', GlobalComm.WORLD_COMM_GROUP)
        op.add_prim_attr('fusion', fusion_id)
        op_list.append(op)
    return tuple(op_list)
Exemplo n.º 5
0
def _init_optimizer_allreduce():
    """Create the optimizer's world-group AllReduce operator.

    Rebinds the module global ``_all_reduce`` to an
    ``AllReduce(ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP)`` operator tagged
    with ``fusion=1``. Returns ``None``.
    """
    global _all_reduce
    world_group = GlobalComm.WORLD_COMM_GROUP
    _all_reduce = AllReduce(ReduceOp.SUM, world_group)
    _all_reduce.add_prim_attr('fusion', 1)