예제 #1
0
파일: passes.py 프로젝트: kkasravi/ngraph
    def do_pass(self, min_ops, transformer):
        """
        Sweep over the graph repeatedly until a sweep yields neither
        replacements nor new state initializations.

        Args:
            min_ops: Iterable of ops that must be computed.
            transformer: An InitGraph object; supplies
                ``add_initialization_ops`` and ``state_initialization_ops``.

        """
        assert isinstance(min_ops, Iterable), "Ops passed into do_pass must be an iterable"
        made_replacements = True
        while True:
            # Ops may have introduced state that still needs an initializer;
            # let the transformer pick those up before deciding whether to stop.
            found_new_inits = transformer.add_initialization_ops(Op.ordered_ops(min_ops))
            if not (made_replacements or found_new_inits):
                return

            self.replacement_list = []

            # Walk initializers plus required ops in execution order; the
            # replacements gathered during the walk are applied below.
            roots = transformer.state_initialization_ops + min_ops
            for node in Op.ordered_ops(op.forwarded for op in roots):
                node.update_forwards()
                self.visit(node)

            # Apply everything that was gathered during this sweep
            for old, rep in self.replacement_list:
                old.forwarded.replace_self(rep.forwarded)
            made_replacements = len(self.replacement_list) > 0
            min_ops = [op.forwarded for op in min_ops]
예제 #2
0
def test_all_op_references():
    """References collected from the graph op include the base op, not vice versa."""
    base_op, simple_graph = get_simple_graph()

    reachable_from_graph = Op.all_op_references([simple_graph])
    for expected in (base_op, simple_graph):
        assert expected in reachable_from_graph

    reachable_from_base = Op.all_op_references([base_op])
    assert base_op in reachable_from_base
    assert simple_graph not in reachable_from_base
예제 #3
0
    def Computation(self, request_iterator, context):
        """
        Rebuild a streamed subgraph and register a computation for it.

        Args:
            request_iterator: Iterable of requests, each carrying ``ops``,
                ``edges``, ``returns`` and ``placeholders``.
            context: RPC context (unused here).

        Returns:
            A ComputationReply with the new ``comp_id``, or ``comp_id=-1`` and
            a message when no transformer exists or any exception is raised.
        """
        logger.debug("server: computation")
        if not self.transformer:
            return hetr_pb2.ComputationReply(
                comp_id=-1, message="build transformer before computation")
        try:
            comp_id = self.new_comp_id()

            # Accumulate the pieces of the graph spread over the request stream
            pb_ops, pb_edges = [], []
            returns, placeholders = [], []
            for request in request_iterator:
                pb_ops.extend(request.ops)
                pb_edges.extend(request.edges)
                returns.extend(protobuf_to_op(pb) for pb in request.returns)
                placeholders.extend(
                    protobuf_to_op(pb) for pb in request.placeholders)

            subgraph = _deserialize_graph_ops_edges(pb_ops, pb_edges)

            # Add dependency on recv op to their send op in scenarios where the send buffer
            # is passed as an argument to the communication call (gather/scatter)
            # on the root device.
            # This ensures that the send buffer does not get reused before
            # the recv_buf gets access to items
            root_idx = 0
            for op in Op.all_op_references(subgraph):
                depends_on_send_args = (
                    isinstance(op, (GatherRecvOp)) and
                    MPI.COMM_WORLD.Get_rank() == op.metadata['device_id']
                ) or (
                    isinstance(op, (ScatterRecvOp)) and
                    MPI.COMM_WORLD.Get_rank() == root_idx and
                    MPI.COMM_WORLD.Get_rank() in op.metadata['device_id']
                )
                if depends_on_send_args:
                    op._args = tuple(op._args) + tuple(op.send_node().args)
                    op.invalidate_property_cache('all_deps')

            # Map the requested returns/placeholders onto the rebuilt ops by uuid,
            # keeping the order in which they were requested.
            ops = Op.ordered_ops(subgraph)
            reconstructed_returns = [
                op for r in returns for op in ops if op.uuid == r.uuid]
            reconstructed_placeholders = [
                op for p in placeholders for op in ops if op.uuid == p.uuid]

            self.computations[comp_id] = self.transformer.computation(
                reconstructed_returns, *reconstructed_placeholders)
            return hetr_pb2.ComputationReply(comp_id=comp_id)
        except Exception:
            return hetr_pb2.ComputationReply(comp_id=-1,
                                             message=traceback.format_exc())
예제 #4
0
파일: hetr_utils.py 프로젝트: ami-GS/ngraph
def clone_graph(root, clone_id, shared_queues_idx, parallel_axis, num_clones):
    """
    Clone the graph rooted at `root` through a serialize/deserialize round trip.

    Args:
        root: Root op of the graph to clone.
        clone_id: Identifier of the clone. Its string form is appended to each
            cloned op's 'device' metadata to form 'transformer', and is stored
            as 'device_id'.
        shared_queues_idx: Value assigned to `idx` on cloned ScatterRecvOp and
            GatherSendOp ops.
        parallel_axis: Axis that, when present in an op's `_axes`, causes the
            op's axes to be recomputed with calculate_scatter_axes.
        num_clones: Forwarded to calculate_scatter_axes.

    Returns:
        Tuple ``(new_root, new_send_nodes, replaced_send_nodes)``: the clone of
        `root`, the CPUQueueSendOp ops created for cloned queue-recv ops, and
        the original send nodes those new senders stand in for.

    Raises:
        ValueError: If a cloned DotOp has `parallel_axis` in its `_axes` but in
            neither `x_out_axes` nor `y_out_axes`.
    """
    # clone nodes with GatherSendOp as root using serde
    ser_cloned_nodes = deserialize_graph(serialize_graph([root]))
    # The clone still carries the original uuid at this point; new_root is None
    # if no deserialized op matches root's uuid.
    new_root = next((o for o in ser_cloned_nodes if o.uuid == root.uuid), None)

    # Lookup from uuid to the original op, used to copy non-serialized state
    orig_ops = {op.uuid: op for op in Op.ordered_ops([root])}
    # Prune ops that are not control_deps of new_gather_send_op
    # deserialize includes extra referenced nodes
    cloned_graph = Op.ordered_ops([new_root])

    new_send_nodes = OrderedSet()
    replaced_send_nodes = OrderedSet()

    # update newly cloned op metadata, generate new UUIDs
    for op in cloned_graph:
        op.metadata['transformer'] = op.metadata['device'] + str(clone_id)
        op.metadata['device_id'] = str(clone_id)

        if isinstance(op, (ScatterRecvOp, GatherSendOp)):
            # Reuse the original op's shared queues; idx selects this clone's slot
            op._shared_queues = orig_ops[op.uuid]._shared_queues
            op.idx = shared_queues_idx
            if isinstance(op, ScatterRecvOp):
                op._send_node = orig_ops[op.uuid].send_node()
        elif isinstance(op, (CPUQueueRecvOp, GPUQueueRecvOp)):
            # Cloning a recv node means we need a broadcast, so simulate one by adding an
            # additional sender with the same input data as the original sender.
            # TODO replace with real broadcast #1398 #1399
            send_op = CPUQueueSendOp(orig_ops[op.uuid].send_node().args[0])
            op._queue = send_op.queue
            op._send_node = send_op
            new_send_nodes.add(send_op)
            replaced_send_nodes.add(orig_ops[op.uuid].send_node())

        if hasattr(op, '_axes') and parallel_axis in op._axes:
            op._axes = calculate_scatter_axes(op.axes, parallel_axis,
                                              num_clones)
            # TODO: Revisit to handle axes updation better. Github Ticket #1355
            if isinstance(op, DotOp):
                if parallel_axis in op.x_out_axes:
                    op.x_out_axes = calculate_scatter_axes(
                        op.x_out_axes, parallel_axis, num_clones)
                elif parallel_axis in op.y_out_axes:
                    op.y_out_axes = calculate_scatter_axes(
                        op.y_out_axes, parallel_axis, num_clones)
                else:
                    raise ValueError(
                        "Missing parallel_axis in Op's x_out_axes or y_out_axes"
                    )
        # Assigned last: the orig_ops lookups above need the original uuid
        op.uuid = uuid.uuid4()

    return new_root, new_send_nodes, replaced_send_nodes
예제 #5
0
    def _transform_computations(self):
        """
        Run the graph passes, set up tensor storage for every tensor op and
        compile each registered computation.
        """

        # Run passes on the computation graphs
        roots = [comp.computation_op for comp in self.computations]
        passed_ops = self.run_registered_graph_passes(ops=roots)

        # Collect up all ops from the graph
        all_ops = OrderedSet(Op.ordered_ops(passed_ops))

        def ensure_tensor(op):
            # Create the storage backing the op's base description on first use,
            # then record a device view for the op's own description.
            description = op.forwarded.tensor_description()
            base = description.base
            storage = self.op_tensors.get(base)
            if storage is None:
                storage = self.device_buffer_storage(
                    base.tensor_size, base.dtype, base.name)
                self.op_tensors[base] = storage
                self.device_buffers.add(storage)
            self.op_tensor_views[description] = storage.device_tensor(description)

        self.ops = Op.ordered_ops(all_ops)
        for tensor_op in (o for o in self.ops if o.is_tensor_op):
            ensure_tensor(tensor_op)

        self.start_transform_allocate()
        for device_buffer in self.device_buffers:
            device_buffer.transform_allocate()
        self.finish_transform_allocate()

        # Compile the computations now that we know their storage
        for comp in self.computations:
            ordered = Op.ordered_ops([comp.computation_op])
            comp.computation_name = self.transform_ordered_ops(
                comp, ordered, name=comp.name)
        self.finish_transform()
        self.finalized = True
        self.finalized = True
예제 #6
0
    def _transform_computations(self):
        """
        Run the graph passes, register the 'init' computation, set up tensor
        storage and compile each registered computation.
        """

        # Run passes on the computation graphs
        roots = [comp.computation for comp in self.computations]
        all_ops = self.run_registered_graph_passes(roots)

        # The initialization ops become their own computation named 'init'
        init_comp = computation(doall(self.state_initialization_ops)).named('init')
        self.init_computation = self.add_computation(init_comp)
        all_ops.append(self.init_computation.computation)

        # Collect up all ops from the graph and obtain the init graph
        all_ops = OrderedSet(Op.ordered_ops(all_ops))

        def init_tensor_description(tensor_description):
            # Allocate a buffer on first use, then bind the description's value
            # to a device tensor carved out of that buffer.
            if tensor_description.buffer is None:
                tensor_description.buffer = self.device_buffer_storage(
                    tensor_description.base.tensor_size,
                    tensor_description.dtype,
                    tensor_description.name)
                self.device_buffers.add(tensor_description.buffer)
            device_tensor = tensor_description.buffer.device_tensor(tensor_description)
            tensor_description.value = device_tensor

        for state in self.init_states:
            init_tensor_description(state.tensor_description())

        self.ops = Op.ordered_ops(all_ops)
        for tensor_op in (o for o in self.ops if o.is_tensor_op):
            init_tensor_description(tensor_op.tensor_description())

        self.start_transform_allocate()
        for device_buffer in self.device_buffers:
            device_buffer.transform_allocate()
        self.finish_transform_allocate()

        # Compile the computations now that we know their storage
        for comp in self.computations:
            ordered = Op.ordered_ops([comp.computation])
            comp.computation_name = self.transform_ordered_ops(ordered, name=comp.name)
        self.finish_transform()
        self.finalized = True
예제 #7
0
    def Computation(self, request_iterator, context):
        """
        Rebuild a streamed subgraph and register a computation for it.

        Args:
            request_iterator: Iterable of requests, each carrying ``ops``,
                ``edges``, ``returns`` and ``placeholders``.
            context: RPC context (unused here).

        Returns:
            A ComputationReply with the new ``comp_id``, or ``comp_id=-1`` and
            a message when no transformer exists or any exception is raised.
        """
        logger.info("server: computation")
        if not self.transformer:
            return hetr_pb2.ComputationReply(
                comp_id=-1, message="build transformer before computation")
        try:
            comp_id = self.new_comp_id()

            # Accumulate the pieces of the graph spread over the request stream
            pb_ops, pb_edges = [], []
            returns, placeholders = [], []
            for request in request_iterator:
                pb_ops.extend(request.ops)
                pb_edges.extend(request.edges)
                returns.extend(protobuf_to_op(pb) for pb in request.returns)
                placeholders.extend(
                    protobuf_to_op(pb) for pb in request.placeholders)

            # Map the requested returns/placeholders onto the rebuilt ops by uuid,
            # keeping the order in which they were requested.
            ops = Op.ordered_ops(_deserialize_graph_ops_edges(pb_ops, pb_edges))
            reconstructed_returns = [
                op for r in returns for op in ops if op.uuid == r.uuid]
            reconstructed_placeholders = [
                op for p in placeholders for op in ops if op.uuid == p.uuid]

            self.computations[comp_id] = self.transformer.computation(
                reconstructed_returns, *reconstructed_placeholders)
            return hetr_pb2.ComputationReply(comp_id=comp_id)
        except Exception:
            return hetr_pb2.ComputationReply(comp_id=-1,
                                             message=traceback.format_exc())
예제 #8
0
    def do_pass(self, min_ops, transformer):
        """
        Sweep over the graph repeatedly until a sweep yields no replacements.

        Args:
            min_ops: Iterable of ops that must be computed.
            transformer: An InitGraph object (not used by this implementation).

        """
        assert isinstance(
            min_ops, Iterable), "Ops passed into do_pass must be an iterable"
        has_work = True
        while has_work:
            self.replacement_list = []

            # Walk the ops in execution order; the replacements gathered during
            # the walk are applied below.
            for node in Op.ordered_ops(op.forwarded for op in min_ops):
                node.update_forwards()
                self.visit(node)

            # Apply everything that was gathered during this sweep
            for old, rep in self.replacement_list:
                old.forwarded.replace_self(rep.forwarded)
            has_work = len(self.replacement_list) > 0
            min_ops = [op.forwarded for op in min_ops]
예제 #9
0
파일: hetrpasses.py 프로젝트: ami-GS/ngraph
    def do_pass(self, ops, transformer):
        """
        For every op marked 'gather', clone the graph feeding its first send
        node once per entry of ``op.from_id`` and update ``self.send_nodes``.
        """
        ops = OrderedSet(op.forwarded for op in ops)

        for op in reversed(Op.ordered_ops(ops)):
            if op.metadata.get('marker') != 'gather':
                continue

            # op is GatherRecvOp
            self.parallel_axes = op.metadata['parallel']
            gather_send_op = op.send_nodes[0]
            clone_count = len(op.from_id)

            # clone nodes for each device_id
            replaced_send_ops = OrderedSet()
            new_gather_send_nodes = OrderedSet()
            for queue_idx, device_id in enumerate(op.from_id):
                cloned_root, new_sends, replaced_sends = clone_graph(
                    root=gather_send_op,
                    clone_id=device_id,
                    shared_queues_idx=queue_idx,
                    parallel_axis=self.parallel_axes,
                    num_clones=clone_count)

                new_gather_send_nodes.add(cloned_root)

                # Track the cloned root together with any senders the clone created
                new_sends.add(cloned_root)
                for sender in new_sends:
                    self.send_nodes.add(sender)

                replaced_send_ops |= replaced_sends

            op.send_nodes = new_gather_send_nodes

            # The original root and every replaced sender are no longer tracked
            replaced_send_ops.add(gather_send_op)
            for stale in replaced_send_ops:
                self.send_nodes.remove(stale)
예제 #10
0
def update_parallel_axis(root, parallel_axis):
    """
    Pass every axes attribute that contains `parallel_axis`, on each op
    reachable from `root`, through set_parallel_axes.

    Raises:
        ValueError: If a DotOp has `parallel_axis` in its axes but in neither
            `x_out_axes` nor `y_out_axes`.
    """
    def reparallelize(axes):
        return set_parallel_axes(axes, parallel_axis)

    for op in Op.ordered_ops([root]):

        if hasattr(op, 'reduction_axes') and parallel_axis in op.reduction_axes:
            op.reduction_axes = reparallelize(op.reduction_axes)

        if getattr(op, 'axes', None) is not None \
                and parallel_axis in Axes.as_flattened_list(op.axes):
            op._axes = reparallelize(op.axes)
            if isinstance(op, DotOp):
                # Only the first of the two out-axes that contains the axis is updated
                for attr in ('x_out_axes', 'y_out_axes'):
                    if parallel_axis in getattr(op, attr):
                        setattr(op, attr, reparallelize(getattr(op, attr)))
                        break
                else:
                    raise ValueError("Missing parallel_axis in Op's "
                                     "x_out_axes or y_out_axes")

        if isinstance(op, TensorValueOp) and parallel_axis in op.tensor.axes:
            op.tensor._axes = reparallelize(op.tensor.axes)
예제 #11
0
    def Computation(self, request, context):
        """
        Rebuild the serialized subgraph in `request` and register a
        computation for it.

        Args:
            request: Request carrying ``subgraph``, ``returns`` and
                ``placeholders``.
            context: RPC context (unused here).

        Returns:
            A ComputationReply with the new ``comp_id``, or ``comp_id=-1`` when
            no transformer exists or building the computation raised.
        """
        if not self.transformer:
            return hetr_pb2.ComputationReply(comp_id=-1)

        try:
            comp_id = self.new_comp_id()
            subgraph = _deserialize_graph(request.subgraph)
            returns = [protobuf_to_op(pb_op) for pb_op in request.returns]
            placeholders = [protobuf_to_op(pb_op)
                            for pb_op in request.placeholders]

            # Match by uuid against the rebuilt graph.
            # NOTE(review): matches are emitted in graph order (outer loop is
            # over ops), not in the order of the request — confirm callers do
            # not rely on request order for placeholders.
            ops = Op.ordered_ops(subgraph)
            return_list = [op for op in ops for r in returns
                           if op.uuid == r.uuid]
            placeholder_list = [op for op in ops for p in placeholders
                                if op.uuid == p.uuid]

            computation = self.transformer.computation(return_list,
                                                       *placeholder_list)
            self.computations[comp_id] = computation
            return hetr_pb2.ComputationReply(comp_id=comp_id)
        except Exception:
            # A bare ``except:`` would also swallow KeyboardInterrupt and
            # SystemExit; only ordinary failures are reported as comp_id=-1.
            return hetr_pb2.ComputationReply(comp_id=-1)
예제 #12
0
    def add_initialization_ops(self, ops):
        """
        Make sure initializers are recorded for all state read or written by
        the given ops, recursing into the initializers themselves.

        Args:
            ops: Collection of ops.

        Returns:
            True if new initializations were added.

        """
        did_work = False
        for op in ops:
            # Each op is examined at most once across calls
            if op in self.init_checked_ops:
                continue
            self.init_checked_ops.add(op)

            new_inits = self.state_initializations(op.states_read)
            new_inits.update(self.state_initializations(op.states_written))
            if len(new_inits) == 0:
                continue

            did_work = True
            self.state_initialization_ops.update(new_inits)
            # The initializers may touch state of their own
            self.add_initialization_ops(Op.ordered_ops(new_inits))

        # Re-resolve forwarding on the recorded initializers
        forwarded_inits = (op.forwarded for op in self.state_initialization_ops)
        self.state_initialization_ops = OrderedSet(forwarded_inits)
        return did_work
예제 #13
0
파일: base.py 프로젝트: wanjinchang/ngraph
    def _transform_computations(self):
        """
        Transform computation graphs to a form that can be run.

        Runs the registered graph passes on the result graphs and on the
        initialization graph, assigns op ids and buffers, initializes tensor
        descriptions, allocates the device buffers, transforms every
        computation and finally marks the transformer as finalized.
        """

        # with Op.saved_user_deps():
        # Run passes on the computation graphs
        self.run_registered_graph_passes(self.all_results)

        # Collect up all ops from the graph and obtain the init graph
        all_ops = OrderedSet(Op.ordered_ops(self.all_results))
        init_op = doall(self.ordered_initializers(all_ops))

        # Run passes on the initialization graphs
        self.run_registered_graph_passes([init_op])

        # Union the init and computation graphs
        self.inits = Op.ordered_ops([init_op])
        all_ops.update(self.inits)

        # create computation which initializes values (called once per session)
        init_op.update_forwards()
        self.init_computation = self.computation(init_op, name="init")

        # Give ids: each op not seen before gets the next consecutive integer
        for op in all_ops:
            if op not in self.opids:
                self.opids[op] = len(self.opids)

        self.dataflow, self.memory = assign_buffers(self, all_ops, self.fusion)

        # Initialize tensor descriptions
        for op in all_ops:
            self.initialize_tensor_descriptions(op)

        # From here on self.ops is the instruction list produced by assign_buffers
        self.ops = self.dataflow.instructions

        self.start_transform_allocate()
        for device_buffer in self.device_buffers:
            device_buffer.transform_allocate()
        self.finish_transform_allocate()

        # Compile the computations now that we know their storage
        for computation in self.computations:
            computation.transform()
        self.finish_transform()
        self.finalized = True
예제 #14
0
파일: passes.py 프로젝트: kkasravi/ngraph
    def do_pass(self, min_ops, transformer):
        """
        Visit the ops until nothing changes.

        Each sweep registers missing state initializers, rewires control
        dependencies that point at ParallelOp blocks, visits every op and then
        applies the replacements gathered during the visit.

        Args:
            min_ops: The set of ops that must be computed.
            transformer: An InitGraph object.

        """
        assert isinstance(min_ops, Iterable), "Ops passed into do_pass must be an iterable"
        has_work = True
        while True:
            ops = Op.ordered_ops(min_ops)

            # Check for ops that added state that needs to be initialized, so they can
            # be added to the initialization function.
            has_new_inits = transformer.add_initialization_ops(ops)

            # Stop once a sweep produced neither replacements nor new initializers
            if not has_work and not has_new_inits:
                return

            self.replacement_list = []

            # Make control dependency adjustments for any added control blocks.
            ops = Op.ordered_ops(op.forwarded
                                 for op in transformer.state_initialization_ops + min_ops)
            for op in ops:
                # A control dep on a ParallelOp is replaced by deps on the
                # ParallelOp's own control deps.
                # NOTE(review): op.control_deps is modified (remove/add) while it
                # is being iterated; this is only safe if control_deps returns a
                # snapshot — confirm.
                for cop in op.control_deps:
                    if isinstance(cop, ParallelOp):
                        op.remove_control_dep(cop)
                        for dep in cop.control_deps:
                            op.add_control_dep(dep)
                if isinstance(op, SequentialOp) and not op.control_dependencies_computed:
                    op.compute_control_dependencies()

            # pass through the ops in an execution order collecting things to do
            ops = Op.ordered_ops(op.forwarded
                                 for op in transformer.state_initialization_ops + min_ops)
            for op in ops:
                op.update_forwards()
                self.visit(op)
            # Perform the gathered replacements
            for old, rep in self.replacement_list:
                old.forwarded.replace_self(rep.forwarded)
            has_work = len(self.replacement_list) > 0
            min_ops = list(_.forwarded for _ in min_ops)
예제 #15
0
    def do_pass(self, ops, **kwargs):
        """
        Render the graph referenced by `ops` with graphviz.

        When ``self.subgraph_attr`` is set, ops are grouped into graphviz
        clusters keyed by that metadata attribute; ops without a matching
        cluster are drawn in the top-level graph.

        Args:
            ops: Iterable of ops whose references are drawn.
            **kwargs: Ignored.

        Returns:
            `ops`, unchanged.

        Raises:
            ImportError: If the graphviz python library is not installed.
        """
        try:
            import graphviz
        except ImportError:
            raise ImportError(
                "You tried to use the ShowGraph transformer pass but did "
                "not have the python graphviz library installed")
        # Get all ops and edges from this set
        all_ops = Op.all_op_references(ops)
        all_edges = ser._serialize_graph(ops).edges

        vg = graphviz.Digraph(node_attr={'shape': 'box'},
                              graph_attr={
                                  'nodesep': '.5',
                                  'ranksep': '.5'
                              })
        if self.subgraph_attr is not None:
            # One filled, labelled cluster per distinct non-None subgraph name
            subgraphs = {}
            for subgraph_name in self.get_subgraphs(all_ops):
                if subgraph_name not in subgraphs and subgraph_name is not None:
                    sg = graphviz.Digraph(
                        name='cluster_{}'.format(subgraph_name))
                    sg.body.append('color="{}"'.format(self.random_color()))
                    sg.body.append('style=filled')
                    sg.body.append('label="{}"'.format(subgraph_name))
                    subgraphs[subgraph_name] = sg
            for op in all_ops:
                subgraph_name = op.metadata.get(self.subgraph_attr, '')
                # hack to show hetr graph, a tuple becomes a list after clone_graph
                if isinstance(subgraph_name, list):
                    subgraph_name = tuple(subgraph_name)
                if subgraph_name in subgraphs:
                    graph = subgraphs[subgraph_name]
                else:
                    graph = vg
                self.add_op_to_graph(op, graph)

            for sg in subgraphs.values():
                vg.subgraph(sg)

        else:
            for op in all_ops:
                self.add_op_to_graph(op, vg)

        for edge in all_edges:
            self.add_edge_to_graph(edge, vg)

        tmp_dir = tempfile.mkdtemp()
        vg.render(directory=tmp_dir, view=self.view, cleanup=True)
        if not self.view:
            # logging interpolates arguments with %-formatting; a '{}'
            # placeholder would make the record fail to format.
            logging.info("VizPass graph rendered to %s", tmp_dir)
        # Cleanup
        self.uuid_lookup_table.clear()

        return ops
예제 #16
0
파일: test_serde.py 프로젝트: ami-GS/ngraph
def test_hetr_send_recv_graph_serialization():
    """
    Round-trip a graph containing the hetr send/recv ops from comm_nodes
    through serde and compare the ops pairwise.
    """
    (z, recv_x, recv_x_plus_one, send_x, x_plus_one, from_node,
     send_x_plus_one) = create_send_recv_graph()

    def by_uuid(op):
        return op.uuid

    round_tripped = sorted(
        ser.deserialize_graph(ser.serialize_graph([z])), key=by_uuid)
    original = sorted(Op.all_op_references([z]), key=by_uuid)

    # Sorting by uuid lines each deserialized op up with its original
    for o1, o2 in zip(round_tripped, original):
        assert_object_equality(o1, o2)
예제 #17
0
파일: test_serde.py 프로젝트: ami-GS/ngraph
def test_full_graph_serialization_endtoend():
    """
    Round-trip the simple graph through serde and check that the result holds
    the same number of ops and that the ops are pairwise equal.
    """
    base_op, simple_graph = get_simple_graph()

    ser_string = ser.serialize_graph([simple_graph])

    def by_uuid(op):
        return op.uuid

    py_graph = sorted(ser.deserialize_graph(ser_string), key=by_uuid)
    orig_graph = sorted(Op.all_op_references([simple_graph]), key=by_uuid)

    # zip() stops at the shorter sequence, so a missing or extra op would go
    # unnoticed without an explicit size check.
    assert len(py_graph) == len(orig_graph)

    # This is actually overkill since the checks of the leaf nodes will recursively
    # check equality up the graph, but we also want to make sure the full set of nodes
    # returned is equal
    for o1, o2 in zip(py_graph, orig_graph):
        assert_object_equality(o1, o2)
예제 #18
0
    def run_pass(self, process_op, ops, **kwargs):
        """
        Apply `process_op` to every op in execution order, batch after batch,
        until ``self.end_batch()`` returns a falsy value.

        Args:
            process_op: Callable invoked with each op.
            ops: Iterable of ops to start from.
            **kwargs: Ignored.
        """
        assert isinstance(ops, Iterable), "Ops passed into do_pass must be an iterable"
        keep_going = True
        while keep_going:
            self.begin_batch()

            # Walk the ops in execution order, handing each one to process_op
            ordered = Op.ordered_ops(op.forwarded for op in ops)
            for node in ordered:
                node.update_forwards()
                process_op(node)

            keep_going = self.end_batch()
            ops = [node.forwarded for node in ordered]
예제 #19
0
        def generate_messages():
            """
            Yield computation requests covering every op referenced by the
            returns and placeholders, split into several messages.

            Only the first message carries the returns and placeholders; later
            messages carry empty lists for both.
            """
            pb_returns, pb_placeholders = generate_returns_placeholders()
            ops = Op.all_op_references(returns + list(placeholders))
            last_idx = len(ops) - 1

            pb_ops, pb_edges = [], []
            for idx, op in enumerate(ops):
                pb_ops.append(op_to_protobuf(op))
                add_edges(pb_edges, pb_ops, op)

                # Flush at every non-zero multiple of _OPS_PER_MSG and at the last op
                at_boundary = idx != 0 and idx % _OPS_PER_MSG == 0
                if at_boundary or idx == last_idx:
                    yield make_computation_request(pb_ops, pb_edges,
                                                   pb_returns, pb_placeholders)

                    # Start a fresh batch
                    pb_ops, pb_edges = [], []
                    pb_returns, pb_placeholders = [], []
예제 #20
0
 def do_pass(self, ops):
     """
     Sweep over the graph repeatedly until a sweep yields no replacements.

     Returns:
         The set of forwarded ops used by the final sweep.
     """
     assert isinstance(
         ops, Iterable), "Ops passed into do_pass must be an iterable"
     while True:
         self.replacement_list = []
         ops = {op.forwarded for op in ops}
         for node in Op.ordered_ops(ops):
             node.update_forwards()
             self.visit(node)
         # Apply everything that was gathered during this sweep
         for old, rep in self.replacement_list:
             old.forwarded.replace_self(rep.forwarded)
         if len(self.replacement_list) == 0:
             return ops
예제 #21
0
    def do_pass(self, ops, **kwargs):
        """
        For every op marked 'gather', update the graph feeding its send node
        to use the parallel axis, creating that axis on first use.
        """
        ops = OrderedSet(op.forwarded for op in ops)

        for op in reversed(Op.ordered_ops(ops)):
            if op.metadata.get('marker') != 'gather':
                continue

            # op is GatherRecvOp
            if self.parallel_axis is None:
                a = op.metadata['parallel']
                sender_count = len(op.from_id)
                assert a.length % sender_count == 0, \
                    '{} can not be equally divided by {}'.format(a, sender_count)
                # Same name as the original axis, length divided among the senders
                self.parallel_axis = make_axis(
                    name=a.name,
                    length=a.length // sender_count,
                    docstring='HeTr parallel axis')

            update_parallel_axis(op.send_node(), self.parallel_axis)
예제 #22
0
    def do_pass(self, ops, **kwargs):
        """
        For every op marked 'gather', clone the graph feeding its first send
        node once per entry of ``op.from_id`` and update ``self.send_nodes``.
        The parallel axis is created on first use.
        """
        ops = OrderedSet(op.forwarded for op in ops)

        for op in reversed(Op.ordered_ops(ops)):
            if op.metadata.get('marker') != 'gather':
                continue

            # op is GatherRecvOp
            clone_count = len(op.from_id)
            if self.parallel_axes is None:
                a = op.metadata['parallel']
                assert a.length % clone_count == 0, \
                    '{} can not be equally divided by {}'.format(a, clone_count)
                # Same name as the original axis, length divided among the clones
                self.parallel_axes = make_axis(
                    name=a.name,
                    length=a.length // clone_count,
                    docstring='HeTr parallel axis')
            gather_send_op = op.send_nodes[0]

            # clone nodes for each device_id
            replaced_send_ops = OrderedSet()
            new_gather_send_nodes = OrderedSet()
            for queue_idx, device_id in enumerate(op.from_id):
                cloned_root, new_sends, replaced_sends = clone_graph(
                    root=gather_send_op,
                    clone_id=device_id,
                    shared_queues_idx=queue_idx,
                    parallel_axis=self.parallel_axes,
                    num_clones=clone_count)

                new_gather_send_nodes.add(cloned_root)

                # Track the cloned root together with any senders the clone created
                new_sends.add(cloned_root)
                for sender in new_sends:
                    self.send_nodes.add(sender)

                replaced_send_ops |= replaced_sends

            op.send_nodes = new_gather_send_nodes

            # The original root and every replaced sender are no longer tracked
            replaced_send_ops.add(gather_send_op)
            for stale in replaced_send_ops:
                self.send_nodes.remove(stale)
예제 #23
0
파일: base.py 프로젝트: wanjinchang/ngraph
    def allocate(self):
        """
        Allocate storage, then initialize constants.

        Finalizes the transformer first when that has not happened yet.
        Does nothing when allocation was already performed.
        """
        if not self.allocated:
            # user_deps are disabled for the duration of the transformations
            with Op.saved_user_deps():
                if not self.finalized:
                    self._transform_computations()

                self.allocate_storage()

                for op in OrderedSet(self.inits + self.ops):
                    self.initialize_constant(op)

            self.allocated = True
예제 #24
0
    def do_pass(self, ops, transformer):
        """
        For every argument marked 'gather', assign new axes to the nodes
        feeding its send node and clone those nodes for the remaining entries
        of ``arg.from_id``.

        Args:
            ops: Iterable of ops to process.
            transformer: Unused by this implementation.
        """
        ops = OrderedSet(op.forwarded for op in ops)

        def set_new_axes(root, num_devices):
            # Compute the new axes once and assign them to every traversed
            # node that exposes an ``axes`` attribute.
            visit = self.do_traversal(root)
            self.new_axes = calculate_new_axes(root.axes, self.parallel_axis,
                                               num_devices, False)

            while visit:
                node = visit.pop()
                if hasattr(node, 'axes'):
                    node._TensorOp__axes = self.new_axes

        # Start traversal from the top to the bottom
        for op in reversed(Op.ordered_ops(ops)):
            args = list()
            for arg in op.args:
                # Compare by value: ``is`` on a string literal tests identity
                # and only works when the two strings happen to be interned.
                if arg.metadata.get('marker') == 'gather':
                    self.parallel_axis = arg.metadata['parallel']
                    set_new_axes(arg.send_node(), len(arg.from_id))

                    for d in range(1, len(arg.from_id)):
                        # The last clone gets axes computed with the final
                        # flag set to True
                        if d == (len(arg.from_id) - 1):
                            self.new_axes = calculate_new_axes(
                                arg.axes, self.parallel_axis,
                                len(arg.from_id), True)

                        nodes = self.do_traversal(arg.send_node())
                        self.clone_nodes(nodes, arg.from_id[d],
                                         self.new_axes,
                                         self.scatter_shared_queues[d],
                                         self.gather_shared_queues[d])

                args.append(arg)

            if isinstance(op.args, tuple):
                op._Op__args = tuple(args)
            else:
                op.args(args)
예제 #25
0
파일: serde.py 프로젝트: ami-GS/ngraph
def _serialize_graph(ops):
    """
    Build the protobuf `GraphDef` object for a graph.

    Unlike `serialize_graph`, the python protobuf object itself is returned
    instead of a serialized byte string.

    Args:
        ops: Iterable of ops; every op referenced from them is serialized.

    Returns:
        An `ops_pb.GraphDef` holding one entry per collected edge and op.
    """
    assert isinstance(ops, Iterable), \
        "Ops passed into `serialize_graph` must be an iterable"

    op_messages, edge_messages = [], []
    for current in Op.all_op_references(ops):
        # The op message is appended first because add_edges is handed the
        # list of op messages built so far.
        op_messages.append(op_to_protobuf(current))
        add_edges(edge_messages, op_messages, current)

    graph_def = ops_pb.GraphDef()
    for edge_message in edge_messages:
        graph_def.edges.add().CopyFrom(edge_message)
    for op_message in op_messages:
        graph_def.ops.add().CopyFrom(op_message)
    return graph_def
예제 #26
0
파일: hetr_utils.py 프로젝트: QiJune/ngraph
def clone_graph(root, clone_id, shared_queues_idx, parallel_axis, num_clones):
    """
    Clone the graph rooted at `root` through a serde round trip.

    Args:
        root: Op whose graph is cloned; its copy is found by matching uuid.
        clone_id: Identifier of this clone.  Its string form is appended to
            each cloned op's 'device' metadata and is the key under which
            clones are recorded in the original ops' 'clones' metadata.
        shared_queues_idx: Value stored as `idx` on cloned scatter/gather/
            all-reduce/broadcast ops.
        parallel_axis: Axis handed to `set_parallel_axes` wherever a cloned
            op's axes contain it.
        num_clones: Not used by this function.

    Returns:
        Tuple `(new_root, new_send_nodes, replaced_send_nodes)`: the cloned
        root, the send ops created for cloned queue-recv ops, and the
        original send nodes those stand in for.

    Raises:
        ValueError: A cloned DotOp has `parallel_axis` in its axes but in
            neither `x_out_axes` nor `y_out_axes`.
    """
    # Round trip through serde; the copies still carry the originals' uuids,
    # which is how each copy is matched back to its source op below.
    deserialized = deserialize_graph(serialize_graph([root]))
    new_root = next((o for o in deserialized if o.uuid == root.uuid), None)

    orig_ops = {op.uuid: op for op in Op.ordered_ops([root])}
    # Deserialization includes extra referenced nodes; only walk the ops
    # ordered from the new root.
    cloned_graph = Op.ordered_ops([new_root])

    new_send_nodes = OrderedSet()
    replaced_send_nodes = OrderedSet()

    clone_key = str(clone_id)
    collective_types = (ScatterRecvOp, GatherSendOp, AllReduceOp,
                        BroadcastRecvOp)

    for op in cloned_graph:
        source_op = orig_ops[op.uuid]
        known_clones = source_op.metadata.get('clones')
        if known_clones is not None and known_clones.get(clone_key) is not None:
            # A clone for this clone_id is already registered
            continue

        op.metadata['transformer'] = op.metadata['device'] + clone_key
        op.metadata['device_id'] = clone_key

        if isinstance(op, collective_types):
            op._shared_queues = source_op._shared_queues
            op.idx = shared_queues_idx
            if isinstance(op, (ScatterRecvOp, BroadcastRecvOp)):
                op._send_node = source_op.send_node()
        elif isinstance(op, (CPUQueueRecvOp, GPUQueueRecvOp)):
            # Cloning a recv node means we need a broadcast, so simulate one
            # by adding an additional sender with the same input data as the
            # original sender.
            send_op = CPUQueueSendOp(source_op.send_node().args[0])
            op._queue = send_op.queue
            op._send_node = send_op
            new_send_nodes.add(send_op)
            replaced_send_nodes.add(source_op.send_node())

        if hasattr(op, 'reduction_axes') \
                and parallel_axis in op.reduction_axes:
            op.reduction_axes = set_parallel_axes(op.reduction_axes,
                                                  parallel_axis)

        if getattr(op, 'axes', None) is not None \
                and parallel_axis in Axes.as_flattened_list(op.axes):
            op._axes = set_parallel_axes(op.axes, parallel_axis)
            if isinstance(op, DotOp):
                if parallel_axis in op.x_out_axes:
                    op.x_out_axes = set_parallel_axes(op.x_out_axes,
                                                      parallel_axis)
                elif parallel_axis in op.y_out_axes:
                    op.y_out_axes = set_parallel_axes(op.y_out_axes,
                                                      parallel_axis)
                else:
                    raise ValueError("Missing parallel_axis in Op's "
                                     "x_out_axes or y_out_axes")

        if isinstance(op, TensorValueOp) \
                and parallel_axis in op.tensor.axes:
            op.tensor._axes = set_parallel_axes(op.tensor.axes,
                                                parallel_axis)

        # Point arguments at the clones already registered for this clone_id
        rewired_args = list(op.args)
        for position, arg_op in enumerate(rewired_args):
            source_arg = orig_ops.get(arg_op.uuid)
            if source_arg is None:
                continue
            arg_clones = source_arg.metadata.get('clones')
            if not arg_clones:
                continue
            replacement = arg_clones.get(clone_key)
            if replacement:
                rewired_args[position] = replacement
        op.invalidate_property_cache('all_deps')
        op._args = tuple(rewired_args)

        # Register the clone on its source op (the root is not registered)
        if op != new_root:
            if source_op.metadata.get('clones') is None:
                source_op.metadata['clones'] = dict()
            source_op.metadata['clones'][clone_key] = op

        # The uuid is replaced last: every lookup above relies on the clone
        # still sharing the uuid of its source op.
        op.uuid = uuid.uuid4()

    return new_root, new_send_nodes, replaced_send_nodes
예제 #27
0
def clone_graph(root, clone_id, parallel_axis):
    """
    Clone the graph rooted at `root` through a serde round trip.

    Args:
        root: Op whose graph is cloned; its copy is found by matching uuid.
        clone_id: Identifier of this clone.  Its string form is appended to
            each cloned op's 'device' metadata and is the key under which
            clones are recorded in the original ops' 'clones' metadata; its
            integer form becomes `idx` on cloned communication ops.
        parallel_axis: Axis handed to `set_parallel_axes` wherever a cloned
            op's axes contain it.

    Returns:
        Tuple `(new_root, new_send_nodes, replaced_send_nodes)`.  The two
        sets are created empty and nothing in this function adds to them.

    Raises:
        ValueError: A cloned DotOp has `parallel_axis` in its axes but in
            neither `x_out_axes` nor `y_out_axes`.
    """
    # Round trip through serde; the copies still carry the originals' uuids,
    # which is how each copy is matched back to its source op below.
    deserialized = deserialize_graph(serialize_graph([root]))
    new_root = next((o for o in deserialized if o.uuid == root.uuid), None)

    orig_ops = {op.uuid: op for op in Op.ordered_ops([root])}
    cloned_graph = Op.ordered_ops([new_root])

    new_send_nodes = OrderedSet()
    replaced_send_nodes = OrderedSet()

    clone_key = str(clone_id)
    collective_types = (ScatterRecvOp, GatherSendOp, AllReduceOp,
                        BroadcastRecvOp)

    for op in cloned_graph:
        source_op = orig_ops[op.uuid]
        known_clones = source_op.metadata.get('clones')
        if known_clones is not None and known_clones.get(clone_key) is not None:
            # A clone for this clone_id is already registered
            continue

        op.metadata['transformer'] = op.metadata['device'] + clone_key
        op.metadata['device_id'] = clone_key

        if isinstance(op, collective_types):
            # for gpu communication op buffer
            op.idx = int(clone_id)
            if isinstance(op, (ScatterRecvOp, BroadcastRecvOp)):
                op._send_node = source_op.send_node()

        if hasattr(op, 'reduction_axes') \
                and parallel_axis in op.reduction_axes:
            op.reduction_axes = set_parallel_axes(op.reduction_axes,
                                                  parallel_axis)

        if getattr(op, 'axes', None) is not None \
                and parallel_axis in Axes.as_flattened_list(op.axes):
            op._axes = set_parallel_axes(op.axes, parallel_axis)
            if isinstance(op, DotOp):
                if parallel_axis in op.x_out_axes:
                    op.x_out_axes = set_parallel_axes(op.x_out_axes,
                                                      parallel_axis)
                elif parallel_axis in op.y_out_axes:
                    op.y_out_axes = set_parallel_axes(op.y_out_axes,
                                                      parallel_axis)
                else:
                    raise ValueError("Missing parallel_axis in Op's "
                                     "x_out_axes or y_out_axes")

        if isinstance(op, TensorValueOp) \
                and parallel_axis in op.tensor.axes:
            op.tensor._axes = set_parallel_axes(op.tensor.axes,
                                                parallel_axis)

        # Point arguments at the clones already registered for this clone_id
        rewired_args = list(op.args)
        for position, arg_op in enumerate(rewired_args):
            source_arg = orig_ops.get(arg_op.uuid)
            if source_arg is None:
                continue
            arg_clones = source_arg.metadata.get('clones')
            if not arg_clones:
                continue
            replacement = arg_clones.get(clone_key)
            if replacement:
                rewired_args[position] = replacement

        op.invalidate_property_cache('all_deps')
        op._args = tuple(rewired_args)

        # Register the clone on its source op (the root is not registered)
        if op != new_root:
            if source_op.metadata.get('clones') is None:
                source_op.metadata['clones'] = dict()
            source_op.metadata['clones'][clone_key] = op

        # The uuid is replaced last: every lookup above relies on the clone
        # still sharing the uuid of its source op.
        op.uuid = uuid.uuid4()

    # create new uuids for all the ops that have references to the new root
    for referenced in Op.all_op_references([new_root]):
        referenced.uuid = uuid.uuid4()

    return new_root, new_send_nodes, replaced_send_nodes
예제 #28
0
 def do_pass(self, ops):
     """Return the number of ops that `Op.ordered_ops` yields for `ops`."""
     ordered = Op.ordered_ops(ops)
     return len(ordered)
예제 #29
0
    def __init__(self, hetr, computation_op):
        """
        Run the graph passes, launch the child transformers and create one
        remote computation per child transformer.

        Args:
            hetr: The parent transformer.  This method reads its
                `send_nodes`, `graph_passes`, `child_transformers`,
                `default_device` and `mpilauncher` attributes and calls its
                `setup_child_transformers` method.
            computation_op: The computation being built; its `returns` and
                `parameters` are read.
        """
        self.child_computations = dict()
        self.transformer = hetr
        # clear send_nodes for multiple computations
        if hetr.send_nodes:
            hetr.send_nodes.clear()
        # Same object as hetr.send_nodes (not a copy), so whatever is added
        # to hetr.send_nodes later is visible through self.send_nodes too.
        self.send_nodes = hetr.send_nodes
        self.computation_op = computation_op

        # self.returns could be replaced by comp_op.returns if it were expressed as a set
        self.returns = OrderedSet()
        # NOTE(review): the `collections.Container` alias was removed in
        # Python 3.10; `collections.abc.Container` is its current home.
        if isinstance(computation_op.returns, collections.Container):
            self.returns.update(list(computation_op.returns))
        elif isinstance(computation_op.returns, Op):
            self.returns.update(list([computation_op.returns]))

        # if one of the requested results is marked as distributed across devices,
        # wrap it in a ResultOp to facilitate DistributedPass inserting a gather operation
        new_returns = OrderedSet()
        for op in self.returns:
            if 'device_id' in op.metadata and \
                    isinstance(op.metadata['device_id'], (list, tuple)):
                op.metadata['is_split_op'] = True
                new_result = ResultOp(device_id=0, args=tuple([op]))
                # Link both directions so the original op can be recovered
                # when the per-transformer returns are mapped further down.
                op.metadata['hetr_replaced_by'] = new_result
                new_result.metadata['replaces_op'] = op
                new_returns.add(new_result)
            else:
                new_returns.add(op)

        # Do Hetr passes
        # (the trailing comma after the logger call only turns the statement
        # into a discarded one-element tuple; the call itself still runs)
        logger.info('Running graph passes'),
        pass_ops = new_returns | OrderedSet(self.computation_op.parameters)
        for graph_pass in self.transformer.graph_passes:
            # Re-read hetr.send_nodes before every pass so each pass also
            # receives the send nodes present at that point.
            pass_ops = pass_ops | OrderedSet(hetr.send_nodes)
            graph_pass.do_pass(ops=pass_ops)

        # hack around new TensorValueOp that wraps AssignableTensorOp
        # autogenerated by creating a ComputationOp:
        for p in self.computation_op.parameters:
            if isinstance(p, TensorValueOp):
                p.metadata.update(p.states_read[0].metadata)

        logger.info('Launching child processes'),
        # assume all children are the same type
        # and all GPUs are in one chassis
        num_process = len(self.transformer.child_transformers)
        # Second launch argument: 1 for 'cpu', otherwise the process count
        ppn = 1 if self.transformer.default_device == 'cpu' else num_process
        self.transformer.mpilauncher.launch(num_process, ppn)
        self.transformer.setup_child_transformers(num_process)

        def is_my_op(op, name):
            # True when `name` equals, or is contained in, the op's
            # 'transformer' metadata entry.
            op_trans = op.metadata['transformer']
            return name == op_trans or name in op_trans

        logger.info('Serializaing computation graph'),
        # build whole_graph once to avoid slow serialization once per worker
        # split whole pb message into list of smaller chunks
        # gRPC prefers sending smaller messages
        placeholders = [p for p in self.computation_op.parameters]
        all_returns = [o for o in self.send_nodes | new_returns]
        # ResultOp wrappers are stripped: their first argument is used instead
        transform_returns = [
            o.args[0] if isinstance(o, ResultOp) else o for o in all_returns
        ]
        whole_graph = Op.all_op_references(transform_returns + placeholders)

        pb_whole_graph = []
        pb_ops, pb_edges = [], []
        for i, o in enumerate(whole_graph):
            pb_ops.append(op_to_protobuf(o))
            add_edges(pb_edges, pb_ops, o)
            # Close the current chunk at every non-zero multiple of
            # _OPS_PER_MSG and at the last op.
            if (i != 0 and i % _OPS_PER_MSG == 0) or (i
                                                      == len(whole_graph) - 1):
                pb_whole_graph.append((pb_ops, pb_edges))
                pb_ops, pb_edges = [], []

        # Per child transformer: the placeholders and returns that belong to it
        t_placeholders, t_returns = {}, {}
        for t_name in self.transformer.child_transformers.keys():
            t_placeholders[t_name] = [
                p for p in placeholders if is_my_op(p, t_name)
            ]
            t_returns[t_name] = [r for r in all_returns if is_my_op(r, t_name)]

        # create_computation is an async call using gRPC future
        # allowing child transformers to create computation simultaneously
        # get_computation waits the corresponding request to finish
        logger.info('Creating remote computations'),
        for t_name, trans in iteritems(self.transformer.child_transformers):
            logger.debug('child transformer: {}'.format(t_name))
            trans.build_transformer()
            transform_ops = [
                r.args[0] if isinstance(r, ResultOp) else r
                for r in t_returns[t_name]
            ]
            trans.create_computation(pb_whole_graph, transform_ops,
                                     t_placeholders[t_name])

        for t_name, trans in iteritems(self.transformer.child_transformers):
            comp = trans.get_computation()
            # Positions, within computation_op.parameters, of the parameters
            # that belong to this child transformer
            comp.param_idx = [
                g_pos for g_pos, p in enumerate(self.computation_op.parameters)
                if is_my_op(p, t_name)
            ]

            # when there is a ResultOp, hack around it
            # comp.returns maps a requested op to its index in this child's
            # return list; wrapped ops are keyed by the op they replaced.
            comp.returns = dict()
            for i, op in enumerate(t_returns[t_name]):
                if op in self.returns and 'hetr_replaced_by' not in op.metadata:
                    comp.returns[op] = i
                elif 'replaces_op' in op.metadata and op.metadata[
                        'replaces_op'] in self.returns:
                    comp.returns[op.metadata['replaces_op']] = i
            self.child_computations[t_name] = comp