Example #1
0
def test_scatter_gather_graph(transformer_factory):
    """Check device assignment and recv insertion for a scatter/gather graph.

    ``x`` and ``z`` are created on device '0'. ``y`` is created with
    device_id ('1', '2') and made parallel over a ``width`` axis of length 6.
    Adding the two device-'0' placeholders is expected to need no recv node.
    Adding ``x`` to the distributed ``y`` is expected to yield exactly one
    recv node, the ``x + y`` op itself.
    """
    width_axis = ng.make_axis(length=6, name='width')

    # Placeholders pinned to the single device '0'.
    with ng.metadata(device_id='0'):
        x = ng.placeholder(())
        z = ng.placeholder(())

    # Placeholder spread over devices '1' and '2' along the width axis.
    with ng.metadata(device_id=('1', '2'), parallel=width_axis):
        y = ng.placeholder(())

    x_plus_z = x + z  # both operands on device '0': no recv node
    x_plus_y = x + y  # y is distributed: a gather recv node is created

    ops = OrderedSet([x, y, z, x_plus_z, x_plus_y])

    # Expected metadata for every op, listed in the same order as ``ops``;
    # the second entry is the device id each op should end up on.
    expected_metadata = {
        x: ["cpu", '0'],
        y: ["cpu", ('1', '2')],
        z: ["cpu", '0'],
        x_plus_z: ["cpu", '0'],
        x_plus_y: ["cpu", '0'],
    }

    check_device_assign_pass("cpu", "0", expected_metadata, ops)

    check_communication_pass(ops_to_transform=ops,
                             expected_recv_nodes=[x_plus_y])
Example #2
0
def test_singleton_device_id(transformer_factory):
    """Check that a one-element device_id list resolves to that single id.

    The placeholder is created with ``device_id=['1']``. After
    ``check_device_assign_pass("cpu", "0", ...)``, its metadata is expected
    to be ``["cpu", '1']``. That is, the bare id, not the list.
    """
    # A single-element *list*; the extra parentheses in ``(['1'])`` never
    # made this a tuple, so they are omitted here.
    with ng.metadata(device_id=['1']):
        x = ng.placeholder(())
    ops = OrderedSet([x])

    expected_metadata = {x: ["cpu", '1']}

    check_device_assign_pass("cpu", "0", expected_metadata, ops)