def backward_p2b_parallel_cast(x):
    """Insert a cast whose *backward* pass converts P (partial-sum) to B.

    Under hybrid parallelism the forward cast is an identity
    ([S(axis), B] -> [S(axis), B]); the manual backward cast turns the
    partial-sum gradient into broadcast, because the layernorm grad op
    cannot consume a P-distributed gradient. Every other parallel mode
    is an identity here; an unrecognized mode raises NotImplementedError.
    """
    axis = _infer_split_axis(x)
    if axis < 0:
        raise RuntimeError("can't infer split axis")

    dist = _gen_parallel_dist_by_2d_sbp([f"S({axis})", "B"])
    util = get_dist_util()

    if util.is_hybrid_parallel():
        # forward: [S(0), B] cast to [S(0), B], identity
        # backward: [S(0), P] cast to [S(0), B], for layernorm grad not
        # supporting P, cast from P to B
        return flow.hierarchical_parallel_cast(
            x,
            nd_sbp=dist,
            grad_mode="manual",
            grad_nd_sbp=dist,
        )
    if util.is_data_parallel():
        # parallel cast: S(0) -> S(0), identity
        return x
    if util.is_model_parallel():
        # auto cast by choosing P -> B or P -> S(0); according to the order
        # value it should be the former
        return x
    if util.is_non_parallel():
        # no need to cast, identity
        return x
    raise NotImplementedError
def output_parallel_cast(x, device="gpu"):
    """Broadcast-cast ``x`` on the last layer's placement under hybrid
    parallelism; a no-op for every other parallel mode.
    """
    util = get_dist_util()
    if not util.is_hybrid_parallel():
        return x
    # Cast on the final pipeline stage so the output is replicated (B).
    last_stage = util.get_layer_placement(-1)
    with flow.scope.placement(device, last_stage):
        return flow.hierarchical_parallel_cast(x, nd_sbp=["B"])
def forward_p2b_parallel_cast(x):
    """Insert a cast whose *forward* pass converts P (partial-sum) to B.

    Hybrid/model parallel:
        forward:  [S(axis), P] -> [S(axis), B]  (allreduce)
        backward: [S(axis), B] -> [S(axis), B]  (identity)
    Data-parallel and non-parallel modes are identities; an unrecognized
    mode raises NotImplementedError.
    """
    axis = _infer_split_axis(x)
    if axis < 0:
        raise RuntimeError("can't infer split axis")

    dist = _gen_parallel_dist_by_2d_sbp([f"S({axis})", "B"])
    util = get_dist_util()

    if util.is_hybrid_parallel() or util.is_model_parallel():
        # forward: P -> B, allreduce; backward: B -> B, identity
        return flow.hierarchical_parallel_cast(
            x,
            nd_sbp=dist,
            grad_mode="manual",
            grad_nd_sbp=dist,
        )
    if util.is_data_parallel():
        # parallel cast: S(0) -> S(0), identity
        return x
    if util.is_non_parallel():
        # no need to cast, identity
        return x
    raise NotImplementedError
def input_data_parallel_cast(x):
    """Cast ``x`` to the data-parallel distribution when running hybrid
    parallel; otherwise return it unchanged.
    """
    if not get_dist_util().is_hybrid_parallel():
        return x
    return flow.hierarchical_parallel_cast(
        x,
        nd_sbp=get_data_parallel_dist(),
    )
def test_fn(
    x: flow.typing.Numpy.Placeholder((1024, 4)),
    indices: flow.typing.Numpy.Placeholder(shape=(12, ), dtype=flow.int32),
) -> flow.typing.Numpy:
    """Gather + relu under a (2, 2) device hierarchy, exercising
    hierarchical_parallel_cast with manual backward SBP, then train with SGD.
    """
    with flow.scope.placement("gpu", "0:0-3", (2, 2)):
        # Enter the 2D hierarchy: x split on axis 0 at both levels,
        # indices replicated at both levels.
        x = flow.hierarchical_parallel_cast(x, nd_sbp=["S(0)", "S(0)"])
        indices = flow.hierarchical_parallel_cast(indices, nd_sbp=["B", "B"])
        # Re-cast x to match the variable's ["S(0)", "B"] distribution.
        x = flow.hierarchical_parallel_cast(x, nd_sbp=["S(0)", "B"])
        v = flow.get_variable(
            name="v",
            shape=(1024, 4),
            nd_sbp=["S(0)", "B"],
            initializer=flow.zeros_initializer(),
        )
        x = x + v
        # Split indices on the inner level before the gather.
        indices = flow.hierarchical_parallel_cast(indices, nd_sbp=["B", "S(0)"])
        x = flow.gather(x, indices)
        # grad_mode="manual": backward keeps the gradient at ["B", "S(0)"]
        # instead of letting the system derive it.
        x = flow.hierarchical_parallel_cast(
            x,
            nd_sbp=["B", "S(0)"],
            grad_mode="manual",
            grad_nd_sbp=["B", "S(0)"],
        )
        x = flow.math.relu(x)
        x = flow.hierarchical_parallel_cast(x, nd_sbp=["B", "B"])
    # Leave the 2D hierarchy: collapse to a 1D broadcast distribution.
    x = flow.hierarchical_parallel_cast(x, nd_sbp=["B"])
    flow.optimizer.SGD(flow.optimizer.PiecewiseConstantScheduler([], [0.001]),
                       momentum=0).minimize(x)
    return x
def FlowJob(x: flow.typing.Numpy.Placeholder((4, 6), dtype=flow.float)):
    """Reshape a 2D-SBP tensor ((4, 6) -> (4, 2, 3)) under a (2, 2)
    hierarchy and train the ["S(0)", "S(1)"]-distributed variable with SGD.
    """
    with flow.scope.placement("gpu", "0:0-3", (2, 2)):
        v = flow.get_variable(
            "x",
            shape=(4, 6),
            dtype=flow.float,
            initializer=flow.constant_initializer(0),
            trainable=True,
            nd_sbp=["S(0)", "S(1)"],
        )
        # Match the input's distribution to the variable before adding.
        x = flow.hierarchical_parallel_cast(x, nd_sbp=["S(0)", "S(1)"])
        x += v
        loss = flow.reshape(x, (4, 2, 3))
    # Collapse the 2D distribution back to 1D S(0) outside the hierarchy.
    loss = flow.hierarchical_parallel_cast(loss, nd_sbp=["S(0)"])
    flow.optimizer.SGD(flow.optimizer.PiecewiseConstantScheduler([], [0.0001]),
                       momentum=0).minimize(loss)
    return loss
def test_fn(
    a: flow.typing.Numpy.Placeholder(a_shape),
    b: flow.typing.Numpy.Placeholder(b_shape),
    c: flow.typing.Numpy.Placeholder(c_shape),
) -> flow.typing.Numpy:
    """matmul + bias_add with a column-split (S(1)) variable, trained with SGD.

    NOTE(review): a_shape/b_shape/c_shape come from the enclosing scope —
    presumably (m, k), (k, n) and (n,); confirm against the caller.
    """
    var_a = flow.get_variable(
        name="var_a",
        shape=a_shape,
        dtype=flow.float32,
        initializer=flow.ones_initializer(),
        # Legacy 1D distribute API: split on axis 1, matching the cast below.
        distribute=flow.distribute.split(1),
    )
    a = flow.hierarchical_parallel_cast(a, nd_sbp=["S(1)"])
    a = var_a * a
    out = flow.matmul(a, b)
    # Broadcast both the matmul result and the bias before bias_add.
    out = flow.hierarchical_parallel_cast(out, nd_sbp=["B"])
    c = flow.hierarchical_parallel_cast(c, nd_sbp=["B"])
    out = flow.nn.bias_add(out, c)
    lr_scheduler = flow.optimizer.PiecewiseConstantScheduler([], [0.001])
    flow.optimizer.SGD(lr_scheduler, momentum=0).minimize(out)
    return out
def gpt_loader_fn() -> flow.typing.Numpy:
    """Load one batch of GPT tokens on CPU under ``parallel_hierachy`` and
    return it broadcast (B) on a flat 1D distribution.

    All loader knobs (data_file_prefix, seq_length, nd_sbp, ...) come from
    the enclosing scope.
    """
    with flow.scope.placement("cpu", device_strs, parallel_hierachy):
        tokens = flow.data.megatron_gpt_mmap_data_loader(
            data_file_prefix=data_file_prefix,
            seq_length=seq_length,
            num_samples=num_samples,
            batch_size=batch_size,
            dtype=dtype,
            shuffle=shuffle,
            random_seed=random_seed,
            split_sizes=split_sizes,
            split_index=split_index,
            nd_sbp=nd_sbp,
            start_from_saved_progress=start_from_saved_progress,
            name="GPTDataLoader",
        )
        # A multi-entry nd_sbp means the loader output is 2D-distributed;
        # replicate it on both hierarchy levels first.
        is_2d_sbp = isinstance(nd_sbp, list) and len(nd_sbp) > 1
        if is_2d_sbp:
            tokens = flow.hierarchical_parallel_cast(tokens, nd_sbp=["B", "B"])
    # Collapse to a flat broadcast distribution outside the hierarchy.
    tokens = flow.hierarchical_parallel_cast(tokens, nd_sbp=["B"])
    return tokens
def FlowJob(x: flow.typing.Numpy.Placeholder((4, 3, 2, 3), dtype=flow.float)):
    """reshape_like a 2D-SBP tensor ((4, 3, 2, 3) -> shape of v2 (4, 3, 6))
    under a (2, 2) hierarchy; no optimizer, forward only.
    """
    with flow.scope.placement("gpu", "0:0-3", (2, 2)):
        v1 = flow.get_variable(
            "v1",
            shape=(4, 3, 2, 3),
            dtype=flow.float,
            initializer=flow.constant_initializer(0),
            trainable=True,
            nd_sbp=["S(0)", "S(2)"],
        )
        # v2 only supplies the target shape for reshape_like.
        v2 = flow.get_variable(
            "v2",
            shape=(4, 3, 6),
            dtype=flow.float,
            initializer=flow.constant_initializer(0),
            trainable=True,
            nd_sbp=["S(0)", "S(2)"],
        )
        # Match the input's distribution to v1 before adding.
        x = flow.hierarchical_parallel_cast(x, nd_sbp=["S(0)", "S(2)"])
        x += v1
        loss = flow.reshape_like(x, v2)
    # Collapse the 2D distribution back to 1D S(0) outside the hierarchy.
    loss = flow.hierarchical_parallel_cast(loss, nd_sbp=["S(0)"])
    return loss
def test_fn(
    x: flow.typing.Numpy.Placeholder((1024, 1024)),
    indices: flow.typing.Numpy.Placeholder(shape=(64, ), dtype=flow.int32),
) -> flow.typing.Numpy:
    """Gather sweep over source/destination 2D SBP signatures.

    ``src`` and ``dst`` come from the enclosing scope: sequences of SBP
    strings ("S(0)", "S(1)", "P", "B"), one per hierarchy level (they are
    indexed as src[0]/src[1] and dst[0] throughout). For each (src[0],
    src[1]) pair, x and indices are cast to a compatible pre-gather
    layout, gathered, cast to ``dst``, then collapsed to 1D broadcast.

    BUGFIX: in the src[0] == "B" branch the inner chain compared the whole
    ``src`` sequence against single SBP strings ("S(1)"/"P"/"B"), which can
    never be true, leaving those branches dead. It now tests src[1] like
    the other two outer branches.
    """
    with flow.scope.placement("gpu", "0:0-3", (2, 2)):
        if src[0] == "S(0)":
            x = flow.hierarchical_parallel_cast(x, nd_sbp=["B", "B"])
            indices = flow.hierarchical_parallel_cast(
                indices, nd_sbp=["S(0)", "S(0)"])
            if src[1] == "S(0)":
                x = flow.hierarchical_parallel_cast(x, nd_sbp=["B", "B"])
                indices = flow.hierarchical_parallel_cast(
                    indices, nd_sbp=["S(0)", "S(0)"])
            elif src[1] == "S(1)":
                x = flow.hierarchical_parallel_cast(x, nd_sbp=["B", "S(1)"])
                indices = flow.hierarchical_parallel_cast(
                    indices, nd_sbp=["S(0)", "B"])
            elif src[1] == "P":
                x = flow.hierarchical_parallel_cast(x, nd_sbp=["B", "S(0)"])
                indices = flow.hierarchical_parallel_cast(
                    indices, nd_sbp=["S(0)", "B"])
            elif src[1] == "B":
                x = flow.hierarchical_parallel_cast(x, nd_sbp=["B", "B"])
                indices = flow.hierarchical_parallel_cast(
                    indices, nd_sbp=["S(0)", "B"])
        elif src[0] == "P":
            x = flow.hierarchical_parallel_cast(x, nd_sbp=["S(0)", "S(0)"])
            indices = flow.hierarchical_parallel_cast(indices, nd_sbp=["B", "B"])
            if src[1] == "S(0)":
                x = flow.hierarchical_parallel_cast(x, nd_sbp=["S(0)", "B"])
                indices = flow.hierarchical_parallel_cast(
                    indices, nd_sbp=["B", "S(0)"])
            elif src[1] == "S(1)":
                x = flow.hierarchical_parallel_cast(
                    x, nd_sbp=["S(0)", "S(1)"])
                indices = flow.hierarchical_parallel_cast(
                    indices, nd_sbp=["B", "B"])
            elif src[1] == "P":
                x = flow.hierarchical_parallel_cast(
                    x, nd_sbp=["S(0)", "S(0)"])
                indices = flow.hierarchical_parallel_cast(
                    indices, nd_sbp=["B", "B"])
            elif src[1] == "B":
                x = flow.hierarchical_parallel_cast(x, nd_sbp=["S(0)", "B"])
                indices = flow.hierarchical_parallel_cast(
                    indices, nd_sbp=["B", "B"])
        elif src[0] == "B":
            x = flow.hierarchical_parallel_cast(x, nd_sbp=["B", "B"])
            indices = flow.hierarchical_parallel_cast(indices, nd_sbp=["B", "B"])
            if src[1] == "S(0)":
                x = flow.hierarchical_parallel_cast(x, nd_sbp=["B", "B"])
                indices = flow.hierarchical_parallel_cast(
                    indices, nd_sbp=["B", "S(0)"])
            elif src[1] == "S(1)":  # was: src == "S(1)" (always False)
                x = flow.hierarchical_parallel_cast(x, nd_sbp=["B", "S(1)"])
                indices = flow.hierarchical_parallel_cast(
                    indices, nd_sbp=["B", "B"])
            elif src[1] == "P":  # was: src == "P" (always False)
                x = flow.hierarchical_parallel_cast(x, nd_sbp=["B", "S(0)"])
                indices = flow.hierarchical_parallel_cast(
                    indices, nd_sbp=["B", "B"])
            elif src[1] == "B":  # was: src == "B" (always False)
                x = flow.hierarchical_parallel_cast(x, nd_sbp=["B", "B"])
                indices = flow.hierarchical_parallel_cast(
                    indices, nd_sbp=["B", "B"])
        else:
            raise NotImplementedError
        x = flow.gather(x, indices)
        # Cast the gather output to the requested destination signature.
        x = flow.hierarchical_parallel_cast(x, nd_sbp=dst, name="gather_cast")
        if dst[0] == "S(0)":
            x = flow.hierarchical_parallel_cast(x, nd_sbp=["S(0)", "S(0)"])
        elif dst[0] == "B":
            x = flow.hierarchical_parallel_cast(x, nd_sbp=["B", "B"])
        elif dst[0] == "S(1)":
            x = flow.hierarchical_parallel_cast(x, nd_sbp=["S(1)", "S(1)"])
        else:
            raise NotImplementedError
    # Collapse to a flat 1D broadcast distribution outside the hierarchy.
    x = flow.hierarchical_parallel_cast(x, nd_sbp=["B"])
    return x