Example #1
def backward_p2b_parallel_cast(x):
    split_axis = _infer_split_axis(x)
    if split_axis < 0:
        raise RuntimeError("can't infer split axis")

    sbps = [f"S({split_axis})", "B"]
    parallel_dist = _gen_parallel_dist_by_2d_sbp(sbps)

    dist_util = get_dist_util()
    if dist_util.is_hybrid_parallel():
        # forward: [S(0), B] cast to [S(0), B], identity
        # backward: [S(0), P] cast to [S(0), B]; layernorm grad does not support P, so cast P -> B
        x = flow.hierarchical_parallel_cast(
            x,
            nd_sbp=parallel_dist,
            grad_mode="manual",
            grad_nd_sbp=parallel_dist,
        )
    elif dist_util.is_data_parallel():
        # parallel cast: S(0) -> S(0), identity
        pass
    elif dist_util.is_model_parallel():
        # auto cast picks P -> B or P -> S(0); by order value it should be the former
        pass
    elif dist_util.is_non_parallel():
        # no need to cast, identity
        pass
    else:
        raise NotImplementedError

    return x
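
The comments above turn on one detail: layernorm's gradient does not accept a partial-sum (P) input, so the backward path must cast P to B while the forward path stays an identity. A minimal sketch of a call site, assuming a hypothetical hybrid-parallel block whose hidden state h is [S(0), B] in the forward pass but would otherwise yield [S(0), P] gradients:

def normed_output(h):
    # forward: identity; backward: casts the [S(0), P] grad to [S(0), B]
    h = backward_p2b_parallel_cast(h)
    return flow.layers.layer_norm(h, name="ln_out")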
Example #2
def output_parallel_cast(x, device="gpu"):
    dist_util = get_dist_util()
    if dist_util.is_hybrid_parallel():
        with flow.scope.placement(device, dist_util.get_layer_placement(-1)):
            x = flow.hierarchical_parallel_cast(x, nd_sbp=["B"])

    return x
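
Under hybrid (pipeline) parallelism the job's output lives on the last stage, which is what dist_util.get_layer_placement(-1) selects; the cast then broadcasts it so every device returns the full result. A sketch of the tail of a job function (job_tail and logits are placeholder names, not from the example):

def job_tail(logits):
    # place on the last pipeline stage and broadcast before returning
    return output_parallel_cast(logits, device="gpu")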
Example #3
def forward_p2b_parallel_cast(x):
    split_axis = _infer_split_axis(x)
    if split_axis < 0:
        raise RuntimeError("can't infer split axis")

    sbps = [f"S({split_axis})", "B"]
    parallel_dist = _gen_parallel_dist_by_2d_sbp(sbps)

    dist_util = get_dist_util()
    if dist_util.is_hybrid_parallel() or dist_util.is_model_parallel():
        # hybrid: forward [S(0), P] cast to [S(0), B], allreduce;
        #         backward [S(0), B] cast to [S(0), B], identity
        # model parallel: forward P -> B, allreduce; backward B -> B, identity
        x = flow.hierarchical_parallel_cast(
            x,
            nd_sbp=parallel_dist,
            grad_mode="manual",
            grad_nd_sbp=parallel_dist,
        )
    elif dist_util.is_data_parallel():
        # parallel cast: S(0) -> S(0), identity
        pass
    elif dist_util.is_non_parallel():
        # no need to cast, identity
        pass
    else:
        raise NotImplementedError

    return x
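
A typical producer of the partial-sum input here is a row-parallel matmul: with the activation split on its last axis and the weight split on its first, each device holds a partial product, and this cast allreduces it back to broadcast. A sketch under those assumptions (row_parallel_linear and its shapes are illustrative, not from the example):

def row_parallel_linear(x, out_size):
    # x is assumed S(1) on the model axis, e.g. the output of a
    # preceding column-parallel layer; splitting w on axis 0 makes
    # the matmul output a partial sum (P) on each device.
    w = flow.get_variable(
        name="w",
        shape=(x.shape[-1], out_size),
        initializer=flow.random_normal_initializer(),
        distribute=flow.distribute.split(0),
    )
    y = flow.matmul(x, w)                 # sbp: P on the model axis
    return forward_p2b_parallel_cast(y)   # allreduce P -> B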
Example #4
def input_data_parallel_cast(x):
    dist_util = get_dist_util()
    if dist_util.is_hybrid_parallel():
        x = flow.hierarchical_parallel_cast(
            x, nd_sbp=get_data_parallel_dist(),
        )

    return x
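
This is the mirror image of output_parallel_cast: it re-distributes the raw input into the data-parallel layout before the first stage consumes it. A sketch of how it might sit at the top of a job, assuming get_data_parallel_dist() returns something like ["S(0)", "B"] for a (data, model) hierarchy:

def gpt_job(tokens):
    tokens = input_data_parallel_cast(tokens)  # no-op unless hybrid parallel
    # ... embedding lookup, transformer layers, loss ...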
Example #5
def test_fn(
    x: flow.typing.Numpy.Placeholder((1024, 4)),
    indices: flow.typing.Numpy.Placeholder(shape=(12,), dtype=flow.int32),
) -> flow.typing.Numpy:
    with flow.scope.placement("gpu", "0:0-3", (2, 2)):
        x = flow.hierarchical_parallel_cast(x, nd_sbp=["S(0)", "S(0)"])
        indices = flow.hierarchical_parallel_cast(indices, nd_sbp=["B", "B"])
        x = flow.hierarchical_parallel_cast(x, nd_sbp=["S(0)", "B"])
        v = flow.get_variable(
            name="v",
            shape=(1024, 4),
            nd_sbp=["S(0)", "B"],
            initializer=flow.zeros_initializer(),
        )
        x = x + v
        indices = flow.hierarchical_parallel_cast(indices, nd_sbp=["B", "S(0)"])
        x = flow.gather(x, indices)
        x = flow.hierarchical_parallel_cast(
            x,
            nd_sbp=["B", "S(0)"],
            grad_mode="manual",
            grad_nd_sbp=["B", "S(0)"],
        )
        x = flow.math.relu(x)
        x = flow.hierarchical_parallel_cast(x, nd_sbp=["B", "B"])
    x = flow.hierarchical_parallel_cast(x, nd_sbp=["B"])
    lr_scheduler = flow.optimizer.PiecewiseConstantScheduler([], [0.001])
    flow.optimizer.SGD(lr_scheduler, momentum=0).minimize(x)
    return x
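
To actually run, a function like this must be registered as a training job; the minimize call requires type="train". A sketch of the harness (the config values and the 4-GPU assumption are mine, not from the example):

import numpy as np

flow.config.gpu_device_num(4)  # the "0:0-3" placement needs 4 devices
func_config = flow.FunctionConfig()
train_job = flow.global_function(type="train", function_config=func_config)(test_fn)

x = np.random.rand(1024, 4).astype(np.float32)
indices = np.random.randint(0, 1024, size=(12,), dtype=np.int32)
out = train_job(x, indices)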
Example #6
def FlowJob(x: flow.typing.Numpy.Placeholder((4, 6), dtype=flow.float)):
    with flow.scope.placement("gpu", "0:0-3", (2, 2)):
        v = flow.get_variable(
            "x",
            shape=(4, 6),
            dtype=flow.float,
            initializer=flow.constant_initializer(0),
            trainable=True,
            nd_sbp=["S(0)", "S(1)"],
        )
        x = flow.hierarchical_parallel_cast(x, nd_sbp=["S(0)", "S(1)"])
        x += v
        loss = flow.reshape(x, (4, 2, 3))
    loss = flow.hierarchical_parallel_cast(loss, nd_sbp=["S(0)"])
    lr_scheduler = flow.optimizer.PiecewiseConstantScheduler([], [0.0001])
    flow.optimizer.SGD(lr_scheduler, momentum=0).minimize(loss)
    return loss
Example #7
def test_fn(
    a: flow.typing.Numpy.Placeholder(a_shape),
    b: flow.typing.Numpy.Placeholder(b_shape),
    c: flow.typing.Numpy.Placeholder(c_shape),
) -> flow.typing.Numpy:
    var_a = flow.get_variable(
        name="var_a",
        shape=a_shape,
        dtype=flow.float32,
        initializer=flow.ones_initializer(),
        distribute=flow.distribute.split(1),
    )
    a = flow.hierarchical_parallel_cast(a, nd_sbp=["S(1)"])
    a = var_a * a
    out = flow.matmul(a, b)
    out = flow.hierarchical_parallel_cast(out, nd_sbp=["B"])
    c = flow.hierarchical_parallel_cast(c, nd_sbp=["B"])
    out = flow.nn.bias_add(out, c)
    lr_scheduler = flow.optimizer.PiecewiseConstantScheduler([], [0.001])
    flow.optimizer.SGD(lr_scheduler, momentum=0).minimize(out)
    return out
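
The shapes here are free variables; any choice works as long as the matmul contracts, i.e. a_shape[-1] == b_shape[0] and c_shape matches the output's last dimension. Illustrative values (assumptions, not from the test):

a_shape = (64, 96)
b_shape = (96, 32)
c_shape = (32,)   # bias_add broadcasts c over the rows of matmul(a, b)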
Example #8
def gpt_loader_fn() -> flow.typing.Numpy:
    with flow.scope.placement("cpu", device_strs, parallel_hierachy):
        tokens = flow.data.megatron_gpt_mmap_data_loader(
            data_file_prefix=data_file_prefix,
            seq_length=seq_length,
            num_samples=num_samples,
            batch_size=batch_size,
            dtype=dtype,
            shuffle=shuffle,
            random_seed=random_seed,
            split_sizes=split_sizes,
            split_index=split_index,
            nd_sbp=nd_sbp,
            start_from_saved_progress=start_from_saved_progress,
            name="GPTDataLoader",
        )
        if isinstance(nd_sbp, list) and len(nd_sbp) > 1:
            tokens = flow.hierarchical_parallel_cast(tokens, nd_sbp=["B", "B"])
    tokens = flow.hierarchical_parallel_cast(tokens, nd_sbp=["B"])
    return tokens
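
The loader closes over a number of free variables. One illustrative assignment, purely as an assumption about what the surrounding test might pass (the data path is a placeholder):

device_strs = "0:0-3"
parallel_hierachy = (2, 2)           # identifier spelling kept from the example
nd_sbp = ["S(0)", "B"]               # split samples across the first hierarchy axis
data_file_prefix = "/data/gpt/megatron_text_document"
seq_length, num_samples, batch_size = 1024, 648, 8
dtype, shuffle, random_seed = flow.int64, True, 1234
split_sizes, split_index = (949, 50, 1), 0
start_from_saved_progress = False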
Example #9
def FlowJob(x: flow.typing.Numpy.Placeholder((4, 3, 2, 3), dtype=flow.float)):
    with flow.scope.placement("gpu", "0:0-3", (2, 2)):
        v1 = flow.get_variable(
            "v1",
            shape=(4, 3, 2, 3),
            dtype=flow.float,
            initializer=flow.constant_initializer(0),
            trainable=True,
            nd_sbp=["S(0)", "S(2)"],
        )
        v2 = flow.get_variable(
            "v2",
            shape=(4, 3, 6),
            dtype=flow.float,
            initializer=flow.constant_initializer(0),
            trainable=True,
            nd_sbp=["S(0)", "S(2)"],
        )
        x = flow.hierarchical_parallel_cast(x, nd_sbp=["S(0)", "S(2)"])
        x += v1
        loss = flow.reshape_like(x, v2)
    loss = flow.hierarchical_parallel_cast(loss, nd_sbp=["S(0)"])
    return loss
Example #10
def test_fn(
    x: flow.typing.Numpy.Placeholder((1024, 1024)),
    indices: flow.typing.Numpy.Placeholder(shape=(64,), dtype=flow.int32),
) -> flow.typing.Numpy:
    with flow.scope.placement("gpu", "0:0-3", (2, 2)):
        if src[0] == "S(0)":
            x = flow.hierarchical_parallel_cast(x, nd_sbp=["B", "B"])
            indices = flow.hierarchical_parallel_cast(indices, nd_sbp=["S(0)", "S(0)"])
            if src[1] == "S(0)":
                x = flow.hierarchical_parallel_cast(x, nd_sbp=["B", "B"])
                indices = flow.hierarchical_parallel_cast(indices, nd_sbp=["S(0)", "S(0)"])
            elif src[1] == "S(1)":
                x = flow.hierarchical_parallel_cast(x, nd_sbp=["B", "S(1)"])
                indices = flow.hierarchical_parallel_cast(indices, nd_sbp=["S(0)", "B"])
            elif src[1] == "P":
                x = flow.hierarchical_parallel_cast(x, nd_sbp=["B", "S(0)"])
                indices = flow.hierarchical_parallel_cast(indices, nd_sbp=["S(0)", "B"])
            elif src[1] == "B":
                x = flow.hierarchical_parallel_cast(x, nd_sbp=["B", "B"])
                indices = flow.hierarchical_parallel_cast(indices, nd_sbp=["S(0)", "B"])
        elif src[0] == "P":
            x = flow.hierarchical_parallel_cast(x, nd_sbp=["S(0)", "S(0)"])
            indices = flow.hierarchical_parallel_cast(indices, nd_sbp=["B", "B"])
            if src[1] == "S(0)":
                x = flow.hierarchical_parallel_cast(x, nd_sbp=["S(0)", "B"])
                indices = flow.hierarchical_parallel_cast(indices, nd_sbp=["B", "S(0)"])
            elif src[1] == "S(1)":
                x = flow.hierarchical_parallel_cast(x, nd_sbp=["S(0)", "S(1)"])
                indices = flow.hierarchical_parallel_cast(indices, nd_sbp=["B", "B"])
            elif src[1] == "P":
                x = flow.hierarchical_parallel_cast(x, nd_sbp=["S(0)", "S(0)"])
                indices = flow.hierarchical_parallel_cast(indices, nd_sbp=["B", "B"])
            elif src[1] == "B":
                x = flow.hierarchical_parallel_cast(x, nd_sbp=["S(0)", "B"])
                indices = flow.hierarchical_parallel_cast(indices, nd_sbp=["B", "B"])
        elif src[0] == "B":
            x = flow.hierarchical_parallel_cast(x, nd_sbp=["B", "B"])
            indices = flow.hierarchical_parallel_cast(indices, nd_sbp=["B", "B"])
            if src[1] == "S(0)":
                x = flow.hierarchical_parallel_cast(x, nd_sbp=["B", "B"])
                indices = flow.hierarchical_parallel_cast(indices, nd_sbp=["B", "S(0)"])
            elif src[1] == "S(1)":
                x = flow.hierarchical_parallel_cast(x, nd_sbp=["B", "S(1)"])
                indices = flow.hierarchical_parallel_cast(indices, nd_sbp=["B", "B"])
            elif src[1] == "P":
                x = flow.hierarchical_parallel_cast(x, nd_sbp=["B", "S(0)"])
                indices = flow.hierarchical_parallel_cast(indices, nd_sbp=["B", "B"])
            elif src[1] == "B":
                x = flow.hierarchical_parallel_cast(x, nd_sbp=["B", "B"])
                indices = flow.hierarchical_parallel_cast(indices, nd_sbp=["B", "B"])
            else:
                raise NotImplementedError
        x = flow.gather(x, indices)
        x = flow.hierarchical_parallel_cast(x, nd_sbp=dst, name="gather_cast")
        if dst[0] == "S(0)":
            x = flow.hierarchical_parallel_cast(x, nd_sbp=["S(0)", "S(0)"])
        elif dst[0] == "B":
            x = flow.hierarchical_parallel_cast(x, nd_sbp=["B", "B"])
        elif dst[0] == "S(1)":
            x = flow.hierarchical_parallel_cast(x, nd_sbp=["S(1)", "S(1)"])
        else:
            raise NotImplementedError
    x = flow.hierarchical_parallel_cast(x, nd_sbp=["B"])
    return x
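
src and dst are the test's parameters: the 2-D SBP the gather inputs start from and the 2-D SBP requested afterwards. One illustrative pairing (an assumption; the real test sweeps the combinations covered by the branches above):

src = ["S(0)", "S(1)"]  # 2-D sbp case exercised for the gather input
dst = ["B", "S(0)"]     # sbp requested from the "gather_cast"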