Example #1
def test_round_trip():
    # set up an op and assign a value to it so we can read it out
    axes = ng.make_axes([
        ng.make_axis(name='A', length=2),
        ng.make_axis(name='B', length=3),
    ])
    x_op = ng.variable(axes)

    assign_op = ng.AssignOp(x_op, 1)

    with executor(assign_op) as assign_computation:
        t = assign_computation.transformer

        # Set initial value
        assign_computation()

        # Test value
        np.testing.assert_allclose(serde_weights.extract_op(t, x_op), 1)

        # write out values in x and graph
        f = BytesIO()

        # ## EXAMPLE OF HOW TO FULLY SERIALIZE A GRAPH ###
        serde_weights.serialize_weights(t, [x_op], f)
        graph_string = serde.serialize_graph([x_op])
        # ## /EXAMPLE OF HOW TO FULLY SERIALIZE A GRAPH ###

        f.seek(0)

        # ## EXAMPLE OF HOW TO FULLY DESERIALIZE A GRAPH ###
        new_ops = serde.deserialize_graph(graph_string)
        serde_weights.deserialize_weights(t, new_ops, f)
        # ## /EXAMPLE OF HOW TO FULLY DESERIALIZE A GRAPH ###

        np.testing.assert_allclose(serde_weights.extract_op(t, new_ops[0]), 1)
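
A note on the two calls marked above: serde.serialize_graph captures the graph topology, while serde_weights.serialize_weights captures the tensor values, so both are needed for a full round trip. The same pattern can persist a graph to disk instead of a BytesIO buffer; here is a minimal sketch, assuming the transformer t, the op x_op, and the serde / serde_weights imports from the test above (the file names are hypothetical):

# write graph topology and weight values to separate files
with open('graph.pb', 'wb') as graph_file, open('weights.bin', 'wb') as weights_file:
    graph_file.write(serde.serialize_graph([x_op]))
    serde_weights.serialize_weights(t, [x_op], weights_file)

# later: rebuild the ops from the topology, then load the weights back in
with open('graph.pb', 'rb') as graph_file, open('weights.bin', 'rb') as weights_file:
    restored_ops = serde.deserialize_graph(graph_file.read())
    serde_weights.deserialize_weights(t, restored_ops, weights_file)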
Example #2
def do_pass(self, ops, **kwargs):
    assert isinstance(ops, Iterable), \
        "Ops passed into do_pass must be an iterable"
    data = serde.serialize_graph(ops)
    self.tmpfile.write(data)
    logging.info("Wrote serialized graph to %s", self.tmpfile.name)
    self.tmpfile.close()
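
self.tmpfile is created outside this snippet; a minimal, self-contained sketch of how such a pass might be set up (the class name GraphSerializationPass, the use of tempfile.NamedTemporaryFile, and the serde import path are assumptions, not the original implementation):

import logging
import tempfile

try:
    from collections.abc import Iterable  # Python 3.3+
except ImportError:
    from collections import Iterable

from ngraph.op_graph.serde import serde


class GraphSerializationPass(object):
    """Hypothetical pass that dumps the serialized graph to a temp file."""

    def __init__(self):
        # serde.serialize_graph returns protobuf bytes, so keep the file binary
        self.tmpfile = tempfile.NamedTemporaryFile(suffix='.pb', delete=False)

    def do_pass(self, ops, **kwargs):
        assert isinstance(ops, Iterable), \
            "Ops passed into do_pass must be an iterable"
        self.tmpfile.write(serde.serialize_graph(ops))
        logging.info("Wrote serialized graph to %s", self.tmpfile.name)
        self.tmpfile.close()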
Example #3
def monkey_add_computation(self, comp):
    if comp.name.startswith('init'):
        return original_computation(self, comp)
    ser_comp = serde.serialize_graph([comp], only_return_handle_ops=True)
    deser_comp = serde.deserialize_graph(ser_comp)
    assert len(deser_comp) == 1
    return original_computation(self, deser_comp[0])
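
Here original_computation is presumably a reference to the unpatched method, saved before patching. A minimal sketch of the wiring, assuming the method being wrapped is the transformer's add_computation (both the import path and the patched attribute are assumptions about the ngraph transformer API):

from ngraph.transformers.base import Transformer

# keep a handle to the unpatched method so the wrapper can delegate to it
original_computation = Transformer.add_computation

# install the wrapper defined above; every non-init computation is now
# round-tripped through serde before being handed to the transformer
Transformer.add_computation = monkey_add_computation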
Example #4
def clone_graph(root, clone_id, shared_queues_idx, parallel_axis, num_clones):
    """
    clone graph with serde (serialization)
    input:
    output: new_root of the cloned graph
    """
    # clone nodes with GatherSendOp as root using serde
    ser_cloned_nodes = deserialize_graph(serialize_graph([root]))
    new_root = next((o for o in ser_cloned_nodes if o.uuid == root.uuid), None)

    orig_ops = {op.uuid: op for op in Op.ordered_ops([root])}
    # Prune ops that are not control_deps of new_gather_send_op
    # deserialize includes extra referenced nodes
    cloned_graph = Op.ordered_ops([new_root])

    new_send_nodes = OrderedSet()
    replaced_send_nodes = OrderedSet()

    # update newly cloned op metadata, generate new UUIDs
    for op in cloned_graph:
        op.metadata['transformer'] = op.metadata['device'] + str(clone_id)
        op.metadata['device_id'] = str(clone_id)

        if isinstance(op, (ScatterRecvOp, GatherSendOp)):
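            # reuse the original op's shared queues for the cloned op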
            op._shared_queues = orig_ops[op.uuid]._shared_queues
            op.idx = shared_queues_idx
            if isinstance(op, ScatterRecvOp):
                op._send_node = orig_ops[op.uuid].send_node()
        elif isinstance(op, (CPUQueueRecvOp, GPUQueueRecvOp)):
            # Cloning a recv node means we need a broadcast, so simulate one by adding an
            # additional sender with the same input data as the original sender.
            # TODO replace with real broadcast #1398 #1399
            send_op = CPUQueueSendOp(orig_ops[op.uuid].send_node().args[0])
            op._queue = send_op.queue
            op._send_node = send_op
            new_send_nodes.add(send_op)
            replaced_send_nodes.add(orig_ops[op.uuid].send_node())

        if hasattr(op, '_axes') and parallel_axis in op._axes:
            op._axes = calculate_scatter_axes(op.axes, parallel_axis,
                                              num_clones)
            # TODO: Revisit to handle axes updates better. Github Ticket #1355
            if isinstance(op, DotOp):
                if parallel_axis in op.x_out_axes:
                    op.x_out_axes = calculate_scatter_axes(
                        op.x_out_axes, parallel_axis, num_clones)
                elif parallel_axis in op.y_out_axes:
                    op.y_out_axes = calculate_scatter_axes(
                        op.y_out_axes, parallel_axis, num_clones)
                else:
                    raise ValueError(
                        "Missing parallel_axis in Op's x_out_axes or y_out_axes"
                    )
        op.uuid = uuid.uuid4()

    return new_root, new_send_nodes, replaced_send_nodes
Example #5
def test_op_references():
    # test op references in arbitrary attributes
    orig_op = ng.placeholder(())
    other_op = ng.placeholder(()).named("foo")
    orig_op.op_ref = other_op
    orig_op.many_op_refs = [other_op]
    ser_string = ser.serialize_graph([orig_op], only_return_handle_ops=True)
    py_op = ser.deserialize_graph(ser_string)[0]
    assert py_op.op_ref.name.startswith('foo')
    assert py_op.many_op_refs[0].name.startswith('foo')
Example #6
def test_hetr_send_recv_graph_serialization():
    """
    test serializing send/recv ops defined in comm_nodes for hetr communication
    """
    z, recv_x, recv_x_plus_one, send_x, x_plus_one, from_node, send_x_plus_one = \
        create_send_recv_graph()
    ser_string = ser.serialize_graph([z])
    py_graph = ser.deserialize_graph(ser_string)
    orig_graph = Op.all_op_references([z])

    for o1, o2 in zip(sorted(py_graph, key=lambda x: x.uuid),
                      sorted(orig_graph, key=lambda x: x.uuid)):
        assert_object_equality(o1, o2)
Example #7
def test_op_handle_selection():
    """
    When serializing graphs, we can optionally add metadata to
    those nodes we pass in, and return only those nodes when deserializing.

    This is useful for transparent testing of ngraph, since it is common to
    use the final op as the 'handle' to the entire graph.
    """
    base_op, simple_graph = get_simple_graph()
    ser_string = ser.serialize_graph([simple_graph], only_return_handle_ops=True)
    py_graph = ser.deserialize_graph(ser_string)
    assert len(py_graph) == 1
    assert_object_equality(simple_graph, py_graph[0])
Example #8
def test_full_graph_serialization_endtoend():
    base_op, simple_graph = get_simple_graph()

    ser_string = ser.serialize_graph([simple_graph])
    py_graph = ser.deserialize_graph(ser_string)
    orig_graph = Op.all_op_references([simple_graph])

    # This is actually overkill since the checks of the leaf nodes will recursively
    # check equality up the graph, but we also want to make sure the full set of nodes
    # returned is equal
    for o1, o2 in zip(sorted(py_graph, key=lambda x: x.uuid),
                      sorted(orig_graph, key=lambda x: x.uuid)):
        assert_object_equality(o1, o2)
Example #9
def clone_graph(root, clone_id, parallel_axis):
    """
    clone graph with serde (serialization)
    input:
    output: new_root of the cloned graph
    """

    # clone nodes with GatherSendOp as root using serde
    ser_cloned_nodes = deserialize_graph(serialize_graph([root]))

    new_root = next((o for o in ser_cloned_nodes if o.uuid == root.uuid), None)

    orig_ops = {op.uuid: op for op in Op.ordered_ops([root])}
    cloned_graph = Op.ordered_ops([new_root])

    new_send_nodes = OrderedSet()
    replaced_send_nodes = OrderedSet()

    # update newly cloned op metadata, generate new UUIDs
    for op in cloned_graph:
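        # only clone ops that have not already been cloned for this clone_id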
        cloned_ops = orig_ops[op.uuid].metadata.get('clones')
        if cloned_ops is None or cloned_ops.get(str(clone_id)) is None:
            op.metadata['transformer'] = op.metadata['device'] + str(clone_id)
            op.metadata['device_id'] = str(clone_id)

            if isinstance(op, (ScatterRecvOp, GatherSendOp,
                               AllReduceOp, BroadcastRecvOp)):
                # for gpu communication op buffer
                op.idx = int(clone_id)
                if isinstance(op, (ScatterRecvOp, BroadcastRecvOp)):
                    op._send_node = orig_ops[op.uuid].send_node()

            if hasattr(
                    op,
                    'reduction_axes') and parallel_axis in op.reduction_axes:
                op.reduction_axes = set_parallel_axes(op.reduction_axes,
                                                      parallel_axis)

            if getattr(op, 'axes', None) is not None \
                    and parallel_axis in Axes.as_flattened_list(op.axes):
                op._axes = set_parallel_axes(op.axes, parallel_axis)
                if isinstance(op, DotOp):
                    if parallel_axis in op.x_out_axes:
                        op.x_out_axes = set_parallel_axes(
                            op.x_out_axes, parallel_axis)
                    elif parallel_axis in op.y_out_axes:
                        op.y_out_axes = set_parallel_axes(
                            op.y_out_axes, parallel_axis)
                    else:
                        raise ValueError("Missing parallel_axis in Op's "
                                         "x_out_axes or y_out_axes")

            if isinstance(op,
                          TensorValueOp) and parallel_axis in op.tensor.axes:
                op.tensor._axes = set_parallel_axes(op.tensor.axes,
                                                    parallel_axis)

            args_list = list(op.args)
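            # rewire args to point at ops already cloned for this clone_id,
            # so shared subgraphs are reused rather than duplicated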
            for arg_idx, arg_op in enumerate(args_list):
                if arg_op.uuid in orig_ops.keys():
                    if orig_ops[arg_op.uuid].metadata.get('clones') and \
                       orig_ops[arg_op.uuid].metadata['clones'].get(str(clone_id)):
                        args_list[arg_idx] = \
                            orig_ops[arg_op.uuid].metadata['clones'].get(str(clone_id))

            op.invalidate_property_cache('all_deps')
            op._args = tuple(args_list)
            if op != new_root:
                if orig_ops[op.uuid].metadata.get('clones') is None:
                    orig_ops[op.uuid].metadata['clones'] = dict()
                orig_ops[op.uuid].metadata['clones'][str(clone_id)] = op

            op.uuid = uuid.uuid4()

    # create new uuids for all the ops that have references to the new root
    for _op in Op.all_op_references([new_root]):
        _op.uuid = uuid.uuid4()

    return new_root, new_send_nodes, replaced_send_nodes
Example #10
def clone_graph(root, clone_id, shared_queues_idx, parallel_axis, num_clones):
    """
    clone graph with serde (serialization)
    input:
    output: new_root of the cloned graph
    """
    # clone nodes with GatherSendOp as root using serde
    ser_cloned_nodes = deserialize_graph(serialize_graph([root]))
    new_root = next((o for o in ser_cloned_nodes if o.uuid == root.uuid), None)

    orig_ops = {op.uuid: op for op in Op.ordered_ops([root])}
    # Prune ops that are not control_deps of new_gather_send_op
    # deserialize includes extra referenced nodes
    cloned_graph = Op.ordered_ops([new_root])

    new_send_nodes = OrderedSet()
    replaced_send_nodes = OrderedSet()

    # update newly cloned op metadata, generate new UUIDs
    for op in cloned_graph:
        cloned_ops = orig_ops[op.uuid].metadata.get('clones')
        if cloned_ops is None or cloned_ops.get(str(clone_id)) is None:
            op.metadata['transformer'] = op.metadata['device'] + str(clone_id)
            op.metadata['device_id'] = str(clone_id)

            if isinstance(op, (ScatterRecvOp, GatherSendOp,
                               AllReduceOp, BroadcastRecvOp)):
                op._shared_queues = orig_ops[op.uuid]._shared_queues
                op.idx = shared_queues_idx
                if isinstance(op, (ScatterRecvOp, BroadcastRecvOp)):
                    op._send_node = orig_ops[op.uuid].send_node()
            elif isinstance(op, (CPUQueueRecvOp, GPUQueueRecvOp)):
                # Cloning a recv node means we need a broadcast, so simulate one by adding an
                # additional sender with the same input data as the original sender.
                send_op = CPUQueueSendOp(orig_ops[op.uuid].send_node().args[0])
                op._queue = send_op.queue
                op._send_node = send_op
                new_send_nodes.add(send_op)
                replaced_send_nodes.add(orig_ops[op.uuid].send_node())

            if hasattr(
                    op,
                    'reduction_axes') and parallel_axis in op.reduction_axes:
                op.reduction_axes = set_parallel_axes(op.reduction_axes,
                                                      parallel_axis)

            if getattr(op, 'axes', None) is not None \
                    and parallel_axis in Axes.as_flattened_list(op.axes):
                op._axes = set_parallel_axes(op.axes, parallel_axis)
                if isinstance(op, DotOp):
                    if parallel_axis in op.x_out_axes:
                        op.x_out_axes = set_parallel_axes(
                            op.x_out_axes, parallel_axis)
                    elif parallel_axis in op.y_out_axes:
                        op.y_out_axes = set_parallel_axes(
                            op.y_out_axes, parallel_axis)
                    else:
                        raise ValueError("Missing parallel_axis in Op's "
                                         "x_out_axes or y_out_axes")

            if isinstance(op,
                          TensorValueOp) and parallel_axis in op.tensor.axes:
                op.tensor._axes = set_parallel_axes(op.tensor.axes,
                                                    parallel_axis)

            args_list = list(op.args)
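            # substitute args with previously registered clones for this
            # clone_id (see the 'clones' metadata bookkeeping below)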
            for arg_idx, arg_op in enumerate(args_list):
                if arg_op.uuid in orig_ops.keys():
                    if orig_ops[arg_op.uuid].metadata.get('clones') and \
                       orig_ops[arg_op.uuid].metadata['clones'].get(str(clone_id)):
                        args_list[arg_idx] = \
                            orig_ops[arg_op.uuid].metadata['clones'].get(str(clone_id))
            op.invalidate_property_cache('all_deps')
            op._args = tuple(args_list)
            if op != new_root:
                if orig_ops[op.uuid].metadata.get('clones') is None:
                    orig_ops[op.uuid].metadata['clones'] = dict()
                orig_ops[op.uuid].metadata['clones'][str(clone_id)] = op

            op.uuid = uuid.uuid4()

    return new_root, new_send_nodes, replaced_send_nodes