Esempio n. 1
0
def test_scatter_gather_node_axes(config):
    """Check the axes and slices exposed by scatter/gather comm ops.

    ``config`` is indexed with four keys: ``'axes'`` (lengths used to build
    the axes), ``'parallel_axis'`` (index selecting the parallel axis),
    ``'device_id'`` (placed in the metadata of the parallel placeholder) and
    ``'slices'`` (the value compared against the ``slices`` attribute of the
    scatter-send and gather-recv ops).
    """
    full_axes = ng.make_axes([ng.make_axis(n) for n in config['axes']])
    split_axis = full_axes[config['parallel_axis']]
    # Expected ordering: the parallel axis first, then the remaining axes.
    expected_axes = split_axis + (full_axes - split_axis)

    single_meta = {'device': None, 'device_id': '0',
                   'transformer': 'cpu0', 'host_transformer': None}
    with ng.metadata(**single_meta):
        src = ng.placeholder(axes=full_axes)
        dst = ng.placeholder(axes=full_axes)

    parallel_meta = {'device': None, 'device_id': config['device_id'],
                     'transformer': None, 'parallel': split_axis,
                     'host_transformer': None}
    with ng.metadata(**parallel_meta):
        parallel_node = ng.placeholder(axes=full_axes)

    # Scatter: send side.
    send_scatter = ScatterSendOp(from_node=src, to_node=parallel_node)
    assert expected_axes == send_scatter.axes
    assert config['slices'] == send_scatter.slices

    # Scatter: receive side; only the per-axis lengths are compared.
    recv_scatter = ScatterRecvOp(to_node=parallel_node, send_node=send_scatter)
    for got_axis, want_axis in zip(recv_scatter.axes, expected_axes):
        assert got_axis.length == want_axis.length

    # Gather: send side.
    send_gather = GatherSendOp(from_node=recv_scatter)
    assert_axes_eq_len(recv_scatter.axes, send_gather.axes)

    # Gather: receive side.
    recv_gather = GatherRecvOp(from_node=parallel_node,
                               to_node=dst,
                               send_node=send_gather)
    assert_axes_eq_len(expected_axes, recv_gather.axes)

    assert config['slices'] == recv_gather.slices
Esempio n. 2
0
def test_update_comm_deps_scatter_gather():
    """Check control deps set by ``update_comm_deps`` on scatter/gather ops.

    Builds scatter/gather ops inside ``'cpu0'`` and ``'cpu1'`` transformer
    metadata scopes, calls ``update_comm_deps`` once per group of ops, and
    then asserts the ``control_deps`` of the first scatter-recv op and of the
    gather-recv op.
    """
    axis_a = ng.make_axis(length=10, name='A')
    axis_b = ng.make_axis(length=15, name='B')
    graph_axes = ng.make_axes([axis_a, axis_b])

    par_meta = {
        'parallel': axis_a,
        'device_id': (0, 1),
        'transformer': None,
        'host_transformer': None,
        'device': None,
    }

    with ng.metadata(transformer='cpu0'):
        with ng.metadata(**par_meta):
            src_a = ng.placeholder(graph_axes)
            dst_a = ng.placeholder(graph_axes)
        send_x = ScatterSendOp(from_node=src_a, to_node=dst_a)
        recv_a = ScatterRecvOp(to_node=dst_a, send_node=send_x)
        with ng.metadata(**par_meta):
            incremented_a = recv_a + 1
        gather_send_a = GatherSendOp(from_node=incremented_a)

    with ng.metadata(transformer='cpu1'):
        with ng.metadata(**par_meta):
            dst_b = ng.placeholder(graph_axes)
        # The second receiver shares the same scatter-send op.
        recv_b = ScatterRecvOp(to_node=dst_b, send_node=send_x)
        with ng.metadata(**par_meta):
            incremented_b = recv_b + 1
        gather_send_b = GatherSendOp(from_node=incremented_b)

    with ng.metadata(transformer='cpu0'), ng.metadata(**par_meta):
        gather_recv_a = GatherRecvOp(from_node=src_a,
                                     to_node=dst_a,
                                     send_node=gather_send_a)
        result_a = gather_recv_a + 1

    update_comm_deps((send_x, gather_send_a, result_a))
    update_comm_deps((gather_send_b,))

    assert {send_x} == set(recv_a.control_deps)
    assert {send_x, gather_send_a} == set(gather_recv_a.control_deps)
Esempio n. 3
0
def create_scatter_gather_graph():
    """Build a set of scatter/gather ops over two placeholders.

    Returns:
        A 6-tuple, in order: the scatter-send op, two scatter-recv ops that
        both receive from it, the gather-send op for each scatter-recv op,
        and a gather-recv op whose ``send_node`` is the first gather-send op.
    """
    axis_a = ng.make_axis(length=10, name='A')
    axis_b = ng.make_axis(length=20, name='B')
    graph_axes = ng.make_axes([axis_a, axis_b])

    placeholder_meta = {'parallel': axis_b, 'device': (0, 1), 'device_id': (0, 1),
                        'transformer': None, 'host_transformer': None}
    with ng.metadata(**placeholder_meta):
        src = ng.placeholder(graph_axes)
        dst = ng.placeholder(graph_axes)

    send_scatter = ScatterSendOp(from_node=src, to_node=dst)
    recv_scatter_a = ScatterRecvOp(to_node=dst, send_node=send_scatter)
    recv_scatter_b = ScatterRecvOp(to_node=dst, send_node=send_scatter)
    send_gather_a = GatherSendOp(from_node=recv_scatter_a)
    send_gather_b = GatherSendOp(from_node=recv_scatter_b)
    recv_gather = GatherRecvOp(from_node=src, to_node=dst,
                               send_node=send_gather_a)

    return (send_scatter, recv_scatter_a, recv_scatter_b,
            send_gather_a, send_gather_b, recv_gather)