Esempio n. 1
0
def test_find_recvs():
    """
    test that find_recvs reports the recv ops feeding a node

    A node that is itself a recv op is included in its own result.
    """
    (z, recv_x, recv_x_plus_one, _send_x, x_plus_one,
     from_node, send_x_plus_one) = create_send_recv_graph()

    # x_plus_one and recv_x only see recv_x upstream
    for node in (x_plus_one, recv_x):
        assert {recv_x} == set(find_recvs(node))

    # from_node has no recv ops upstream of it
    assert len(find_recvs(from_node)) == 0

    assert {recv_x} == set(find_recvs(send_x_plus_one))

    # recv_x_plus_one and z see both recv ops
    for node in (recv_x_plus_one, z):
        assert {recv_x_plus_one, recv_x} == set(find_recvs(node))
Esempio n. 2
0
def test_hetr_send_recv_graph_serialization():
    """
    test serializing send/recv ops defined in comm_nodes for hetr communication

    The graph rooted at z is serialized and deserialized. The round-tripped
    graph must contain the same number of ops as the original, and each op
    must match its counterpart when both graphs are ordered by uuid.
    """
    z, recv_x, recv_x_plus_one, send_x, x_plus_one, from_node, send_x_plus_one = \
        create_send_recv_graph()
    ser_string = ser.serialize_graph([z])
    py_graph = ser.deserialize_graph(ser_string)
    orig_graph = Op.all_op_references([z])

    py_ops = sorted(py_graph, key=lambda op: op.uuid)
    orig_ops = sorted(orig_graph, key=lambda op: op.uuid)

    # zip() stops at the shorter sequence, so without this check a
    # deserialized graph that lost (or gained) ops would still pass.
    assert len(py_ops) == len(orig_ops)

    for o1, o2 in zip(py_ops, orig_ops):
        assert_object_equality(o1, o2)
Esempio n. 3
0
def test_update_comm_deps():
    """
    test that update_comm_deps makes z depend on recv_x_plus_one
    """
    # the graph is built under the cpu0 transformer metadata
    with ng.metadata(transformer='cpu0'):
        (z, _recv_x, recv_x_plus_one, send_x,
         _x_plus_one, _from_node, _send_x_plus_one) = create_send_recv_graph()

    comm_ops = (z, send_x)
    update_comm_deps(comm_ops)

    assert recv_x_plus_one in z.all_deps