def test_all_op_references():
    """References walked from a derived op reach its base, but not vice versa."""
    base, graph = get_simple_graph()

    # Starting from the derived op we should see both it and its base.
    from_graph = Op.all_op_references([graph])
    assert base in from_graph
    assert graph in from_graph

    # Starting from the base op alone must not pick up the derived op.
    from_base = Op.all_op_references([base])
    assert base in from_base
    assert graph not in from_base
Example #2
0
    def Computation(self, request_iterator, context):
        """
        gRPC handler: reconstruct a computation from a streamed, chunked graph.

        The client streams the serialized graph as several messages
        (`request_iterator`); ops, edges, returns and placeholders are
        accumulated across all chunks before the graph is deserialized.

        Returns:
            hetr_pb2.ComputationReply with the new computation id on success,
            or comp_id=-1 plus an error message on failure.
        """
        logger.debug("server: computation")
        # A transformer must have been built by an earlier request.
        if not self.transformer:
            return hetr_pb2.ComputationReply(
                comp_id=-1, message="build transformer before computation")
        try:
            comp_id = self.new_comp_id()
            pb_ops, pb_edges = [], []
            returns, placeholders = [], []
            reconstructed_returns, reconstructed_placeholders = [], []
            # Accumulate graph pieces from every chunk of the request stream.
            for request in request_iterator:
                pb_ops.extend(request.ops)
                pb_edges.extend(request.edges)
                returns.extend([protobuf_to_op(op) for op in request.returns])
                placeholders.extend(
                    [protobuf_to_op(op) for op in request.placeholders])

            subgraph = _deserialize_graph_ops_edges(pb_ops, pb_edges)

            # Add dependency on recv ops to their send op in scenarios where the
            # send buffer is passed as an argument to the communication call
            # (gather/scatter) on the root device.
            # This ensures the send buffer does not get reused before the recv
            # buffer has consumed its items.
            root_idx = 0
            for op in Op.all_op_references(subgraph):
                if isinstance(op, (GatherRecvOp)) and \
                   MPI.COMM_WORLD.Get_rank() == op.metadata['device_id']:
                    # Append the send op's args to the recv op's args so the
                    # dependency tracker orders the send before buffer reuse.
                    args = list(op._args)
                    args.extend(op.send_node().args)
                    op._args = tuple(args)
                    op.invalidate_property_cache('all_deps')
                elif (isinstance(op, (ScatterRecvOp))
                      and MPI.COMM_WORLD.Get_rank() == root_idx and
                      MPI.COMM_WORLD.Get_rank() in op.metadata['device_id']):
                    args = list(op._args)
                    args.extend(op.send_node().args)
                    op._args = tuple(args)
                    op.invalidate_property_cache('all_deps')

            # Deserialized ops are distinct objects; match the requested
            # returns/placeholders back to their deserialized counterparts by uuid.
            ops = Op.ordered_ops(subgraph)
            for r in returns:
                for op in ops:
                    if op.uuid == r.uuid:
                        reconstructed_returns.append(op)
            for p in placeholders:
                for op in ops:
                    if op.uuid == p.uuid:
                        reconstructed_placeholders.append(op)

            computation = self.transformer.computation(
                reconstructed_returns, *reconstructed_placeholders)
            self.computations[comp_id] = computation
            return hetr_pb2.ComputationReply(comp_id=comp_id)
        except Exception:
            # Report any failure back to the client instead of crashing the RPC.
            return hetr_pb2.ComputationReply(comp_id=-1,
                                             message=traceback.format_exc())
    def do_pass(self, ops, **kwargs):
        """
        Render the graph reachable from `ops` with graphviz, returning `ops`
        unchanged.

        Ops may be grouped into colored clusters keyed by the metadata
        attribute named in `self.subgraph_attr`. The rendered output goes to a
        fresh temporary directory; when `self.view` is false the directory is
        logged instead of opened.

        Raises:
            ImportError: if the python graphviz library is not installed.
        """
        try:
            import graphviz
        except ImportError:
            raise ImportError(
                "You tried to use the ShowGraph transformer pass but did "
                "not have the python graphviz library installed")
        # Get all ops and edges from this set
        all_ops = Op.all_op_references(ops)
        all_edges = ser._serialize_graph(ops).edges

        vg = graphviz.Digraph(node_attr={'shape': 'box'},
                              graph_attr={
                                  'nodesep': '.5',
                                  'ranksep': '.5'
                              })
        if self.subgraph_attr is not None:
            # One colored cluster per distinct subgraph name.
            subgraphs = {}
            for subgraph_name in self.get_subgraphs(all_ops):
                if subgraph_name not in subgraphs and subgraph_name is not None:
                    sg = graphviz.Digraph(
                        name='cluster_{}'.format(subgraph_name))
                    sg.body.append('color="{}"'.format(self.random_color()))
                    sg.body.append('style=filled')
                    sg.body.append('label="{}"'.format(subgraph_name))
                    subgraphs[subgraph_name] = sg
            for op in all_ops:
                subgraph_name = op.metadata.get(self.subgraph_attr, '')
                # hack to show hetr graph, a tuple becomes a list after clone_graph
                if isinstance(subgraph_name, list):
                    subgraph_name = tuple(subgraph_name)
                # Ops without a known cluster go into the top-level graph.
                if subgraph_name in subgraphs:
                    graph = subgraphs[subgraph_name]
                else:
                    graph = vg
                self.add_op_to_graph(op, graph)

            for sg in subgraphs.values():
                vg.subgraph(sg)

        else:
            for op in all_ops:
                self.add_op_to_graph(op, vg)

        for edge in all_edges:
            self.add_edge_to_graph(edge, vg)

        tmp_dir = tempfile.mkdtemp()
        vg.render(directory=tmp_dir, view=self.view, cleanup=True)
        if not self.view:
            # Fix: logging uses %-style lazy formatting, not str.format; the
            # original "{}" placeholder was never substituted with tmp_dir.
            logging.info("VizPass graph rendered to %s", tmp_dir)
        # Cleanup
        self.uuid_lookup_table.clear()

        return ops
Example #4
0
def test_hetr_send_recv_graph_serialization():
    """Round-trip serialization of the comm_nodes send/recv graph used by hetr."""
    # Only the final op `z` is needed here; the other graph handles are unused.
    z, *_ = create_send_recv_graph()

    restored = ser.deserialize_graph(ser.serialize_graph([z]))
    original = Op.all_op_references([z])

    # Sort both sides by uuid so corresponding ops line up, then compare deeply.
    def by_uuid(op):
        return op.uuid

    for restored_op, original_op in zip(sorted(restored, key=by_uuid),
                                        sorted(original, key=by_uuid)):
        assert_object_equality(restored_op, original_op)
Example #5
0
def test_full_graph_serialization_endtoend():
    """Serialize a whole graph to a byte string, deserialize it, and compare
    every op of the round-tripped graph against the original."""
    base_op, simple_graph = get_simple_graph()

    ser_string = ser.serialize_graph([simple_graph])
    py_graph = ser.deserialize_graph(ser_string)
    orig_graph = Op.all_op_references([simple_graph])

    # This is actually overkill since the checks of the leaf nodes will recursively
    # check equality up the graph, but we also want to make sure the full set of nodes
    # returned is equal
    for o1, o2 in zip(sorted(py_graph, key=lambda x: x.uuid),
                      sorted(orig_graph, key=lambda x: x.uuid)):
        assert_object_equality(o1, o2)
        def generate_messages():
            # NOTE(review): this nested generator appears to have been spliced
            # in from a different snippet by the scraper; it closes over
            # `returns`, `placeholders`, `generate_returns_placeholders`, and
            # `_OPS_PER_MSG`, none of which are defined in the surrounding
            # code shown here — verify against the original source.
            pb_ops, pb_edges = [], []
            pb_returns, pb_placeholders = generate_returns_placeholders()
            ops = Op.all_op_references(returns + list(placeholders))
            for i, op in enumerate(ops):
                pb_ops.append(op_to_protobuf(op))
                add_edges(pb_edges, pb_ops, op)
                # Flush a request every _OPS_PER_MSG ops and at the end, so
                # each streamed gRPC message stays small.
                if (i != 0 and i % _OPS_PER_MSG == 0) or (i == len(ops) - 1):
                    msg = make_computation_request(pb_ops, pb_edges,
                                                   pb_returns, pb_placeholders)
                    yield msg

                    # Start a fresh chunk; returns/placeholders are only
                    # populated for the first message.
                    pb_ops, pb_edges = [], []
                    pb_returns, pb_placeholders = [], []
Example #7
0
def _serialize_graph(ops):
    """
    Serialize a graph into a GraphDef protobuf object.

    Unlike `serialize_graph`, which produces a serialized byte string, this
    returns the actual protobuf python object.
    """
    assert isinstance(
        ops, Iterable), "Ops passed into `serialize_graph` must be an iterable"

    # Expand the requested ops into the full set of ops they reference.
    all_ops = Op.all_op_references(ops)

    proto_ops, proto_edges = [], []
    for node in all_ops:
        proto_ops.append(op_to_protobuf(node))
        add_edges(proto_edges, proto_ops, node)

    # Copy the accumulated edges and ops into a single GraphDef message.
    graph_def = ops_pb.GraphDef()
    for pb_edge in proto_edges:
        graph_def.edges.add().CopyFrom(pb_edge)
    for pb_op in proto_ops:
        graph_def.ops.add().CopyFrom(pb_op)
    return graph_def
Example #8
0
def clone_graph(root, clone_id, parallel_axis):
    """
    Clone the graph rooted at `root` via serde (serialize then deserialize).

    Arguments:
        root: root op of the graph to clone.
        clone_id: device id for this clone; stamped into each cloned op's
            metadata and used to look up / record per-device clones.
        parallel_axis: axis being parallelized; occurrences in cloned ops'
            axes are replaced via `set_parallel_axes`.

    Returns:
        (new_root, new_send_nodes, replaced_send_nodes) — note the two
        OrderedSets are created here but never populated in this function.
    """

    # clone nodes with GatherSendOp as root using serde
    ser_cloned_nodes = deserialize_graph(serialize_graph([root]))

    # The deserialized ops keep their original uuids, so the clone of `root`
    # can be located by uuid match.
    new_root = next((o for o in ser_cloned_nodes if o.uuid == root.uuid), None)

    orig_ops = {op.uuid: op for op in Op.ordered_ops([root])}
    cloned_graph = Op.ordered_ops([new_root])

    new_send_nodes = OrderedSet()
    replaced_send_nodes = OrderedSet()

    # update newly cloned op metadata, generate new UUIDs
    for op in cloned_graph:
        cloned_ops = orig_ops[op.uuid].metadata.get('clones')
        # Skip ops that were already cloned for this clone_id on an earlier call.
        if cloned_ops is None or cloned_ops.get(str(clone_id)) is None:
            op.metadata['transformer'] = op.metadata['device'] + str(clone_id)
            op.metadata['device_id'] = str(clone_id)

            if isinstance(
                    op,
                (ScatterRecvOp, GatherSendOp, AllReduceOp, BroadcastRecvOp)):
                # for gpu communication op buffer
                op.idx = int(clone_id)
                if isinstance(op, (ScatterRecvOp, BroadcastRecvOp)):
                    # Recv clones keep pointing at the ORIGINAL op's send node.
                    op._send_node = orig_ops[op.uuid].send_node()

            # Rewrite the parallel axis wherever it appears in reduction axes...
            if hasattr(
                    op,
                    'reduction_axes') and parallel_axis in op.reduction_axes:
                op.reduction_axes = set_parallel_axes(op.reduction_axes,
                                                      parallel_axis)

            # ...and in the op's own axes (including DotOp output-axis splits).
            if getattr(op, 'axes', None) is not None \
                    and parallel_axis in Axes.as_flattened_list(op.axes):
                # if parallel_axis in Axes.as_flattened_list(op.axes):
                op._axes = set_parallel_axes(op.axes, parallel_axis)
                if isinstance(op, DotOp):
                    if parallel_axis in op.x_out_axes:
                        op.x_out_axes = set_parallel_axes(
                            op.x_out_axes, parallel_axis)
                    elif parallel_axis in op.y_out_axes:
                        op.y_out_axes = set_parallel_axes(
                            op.y_out_axes, parallel_axis)
                    else:
                        raise ValueError("Missing parallel_axis in Op's "
                                         "x_out_axes or y_out_axes")

            # TensorValueOp wraps a tensor whose axes also need rewriting.
            if isinstance(op,
                          TensorValueOp) and parallel_axis in op.tensor.axes:
                op.tensor._axes = set_parallel_axes(op.tensor.axes,
                                                    parallel_axis)

            # Re-point args at previously recorded clones (from earlier calls
            # with the same clone_id), where available.
            args_list = list(op.args)
            for arg_idx, arg_op in enumerate(args_list):
                if arg_op.uuid in orig_ops.keys():
                    if orig_ops[arg_op.uuid].metadata.get('clones') and \
                       orig_ops[arg_op.uuid].metadata['clones'].get(str(clone_id)):
                        args_list[arg_idx] = \
                            orig_ops[arg_op.uuid].metadata['clones'].get(str(clone_id))

            op.invalidate_property_cache('all_deps')
            op._args = tuple(args_list)
            # Record this clone on the original op (root itself is not recorded).
            if op != new_root:
                if orig_ops[op.uuid].metadata.get('clones') is None:
                    orig_ops[op.uuid].metadata['clones'] = dict()
                    orig_ops[op.uuid].metadata['clones'][str(clone_id)] = op
                else:
                    orig_ops[op.uuid].metadata['clones'][str(clone_id)] = op

            op.uuid = uuid.uuid4()

    # create new uuids for all the ops that have references to the new root
    for _op in Op.all_op_references([new_root]):
        _op.uuid = uuid.uuid4()

    return new_root, new_send_nodes, replaced_send_nodes
Example #9
0
    def __init__(self, hetr, computation_op):
        """
        Build a distributed computation: run the hetr graph passes, launch the
        child worker processes, stream the serialized graph to each child
        transformer, and collect the resulting per-child computations.

        Arguments:
            hetr: the parent hetr transformer (provides send_nodes,
                graph_passes, child transformers and the MPI launcher).
            computation_op: the ComputationOp describing returns/parameters.
        """
        self.child_computations = dict()
        self.transformer = hetr
        # clear send_nodes for multiple computations
        if hetr.send_nodes:
            hetr.send_nodes.clear()
        self.send_nodes = hetr.send_nodes
        self.computation_op = computation_op

        # self.returns could be replaced by comp_op.returns if it were expressed as a set
        # NOTE(review): collections.Container is deprecated (moved to
        # collections.abc and renamed Container) — confirm the targeted
        # Python version still exposes it here.
        self.returns = OrderedSet()
        if isinstance(computation_op.returns, collections.Container):
            self.returns.update(list(computation_op.returns))
        elif isinstance(computation_op.returns, Op):
            self.returns.update(list([computation_op.returns]))

        # if one of the requested results is marked as distributed across devices,
        # wrap it in a ResultOp to facilitate DistributedPass inserting a gather operation
        new_returns = OrderedSet()
        for op in self.returns:
            if 'device_id' in op.metadata and \
                    isinstance(op.metadata['device_id'], (list, tuple)):
                op.metadata['is_split_op'] = True
                new_result = ResultOp(device_id=0, args=tuple([op]))
                # Link the original op and its wrapper in both directions so
                # results can be mapped back after transformation.
                op.metadata['hetr_replaced_by'] = new_result
                new_result.metadata['replaces_op'] = op
                new_returns.add(new_result)
            else:
                new_returns.add(op)

        # Do Hetr passes
        # NOTE(review): the trailing comma makes this statement a tuple
        # expression — harmless, but likely unintended (also below).
        logger.info('Running graph passes'),
        pass_ops = new_returns | OrderedSet(self.computation_op.parameters)
        for graph_pass in self.transformer.graph_passes:
            # Re-merge send nodes each iteration: passes may add to hetr.send_nodes.
            pass_ops = pass_ops | OrderedSet(hetr.send_nodes)
            graph_pass.do_pass(ops=pass_ops)

        # hack around new TensorValueOp that wraps AssignableTensorOp
        # autogenerated by creating a ComputationOp:
        for p in self.computation_op.parameters:
            if isinstance(p, TensorValueOp):
                p.metadata.update(p.states_read[0].metadata)

        logger.info('Launching child processes'),
        # assume all children are the same type
        # and all GPUs are in one chassis
        num_process = len(self.transformer.child_transformers)
        ppn = 1 if self.transformer.default_device == 'cpu' else num_process
        self.transformer.mpilauncher.launch(num_process, ppn)
        self.transformer.setup_child_transformers(num_process)

        def is_my_op(op, name):
            # An op belongs to a child if its transformer name matches exactly
            # or appears within the op's transformer designation.
            op_trans = op.metadata['transformer']
            return name == op_trans or name in op_trans

        # NOTE(review): 'Serializaing' is a typo in this runtime log message.
        logger.info('Serializaing computation graph'),
        # build whole_graph once to avoid slow serialization once per worker
        # split whole pb message into list of smaller chunks
        # gRPC prefers sending smaller messages
        placeholders = [p for p in self.computation_op.parameters]
        all_returns = [o for o in self.send_nodes | new_returns]
        # ResultOp wrappers are unwrapped before serialization.
        transform_returns = [
            o.args[0] if isinstance(o, ResultOp) else o for o in all_returns
        ]
        whole_graph = Op.all_op_references(transform_returns + placeholders)

        # Chunk the serialized ops/edges so each gRPC message stays small.
        pb_whole_graph = []
        pb_ops, pb_edges = [], []
        for i, o in enumerate(whole_graph):
            pb_ops.append(op_to_protobuf(o))
            add_edges(pb_edges, pb_ops, o)
            if (i != 0 and i % _OPS_PER_MSG == 0) or (i
                                                      == len(whole_graph) - 1):
                pb_whole_graph.append((pb_ops, pb_edges))
                pb_ops, pb_edges = [], []

        # Partition placeholders and returns per child transformer.
        t_placeholders, t_returns = {}, {}
        for t_name in self.transformer.child_transformers.keys():
            t_placeholders[t_name] = [
                p for p in placeholders if is_my_op(p, t_name)
            ]
            t_returns[t_name] = [r for r in all_returns if is_my_op(r, t_name)]

        # create_computation is an async call using gPRC future
        # allowing child transformers to create computation simultaneously
        # get_computation waits the corresponding request to finish
        logger.info('Creating remote computations'),
        for t_name, trans in iteritems(self.transformer.child_transformers):
            logger.debug('child transformer: {}'.format(t_name))
            trans.build_transformer()
            transform_ops = [
                r.args[0] if isinstance(r, ResultOp) else r
                for r in t_returns[t_name]
            ]
            trans.create_computation(pb_whole_graph, transform_ops,
                                     t_placeholders[t_name])

        for t_name, trans in iteritems(self.transformer.child_transformers):
            comp = trans.get_computation()
            # Remember which global parameter positions feed this child.
            comp.param_idx = [
                g_pos for g_pos, p in enumerate(self.computation_op.parameters)
                if is_my_op(p, t_name)
            ]

            # when there is a ResultOp, hack around it
            comp.returns = dict()
            for i, op in enumerate(t_returns[t_name]):
                if op in self.returns and 'hetr_replaced_by' not in op.metadata:
                    comp.returns[op] = i
                elif 'replaces_op' in op.metadata and op.metadata[
                        'replaces_op'] in self.returns:
                    comp.returns[op.metadata['replaces_op']] = i
            self.child_computations[t_name] = comp