Example #1
    def do_pass(self, min_ops, transformer):
        """
        Visit the ops until nothing changes.

        Args:
            min_ops: The set of ops that must be computed.
            transformer: An InitGraph object.

        """
        assert isinstance(min_ops, Iterable), "Ops passed into do_pass must be an iterable"
        has_work = True
        while True:
            ops = Op.ordered_ops(min_ops)

            # Check for ops that added state that needs to be initialized, so they can
            # be added to the initialization function.
            has_new_inits = transformer.add_initialization_ops(ops)
            if not has_work and not has_new_inits:
                return

            self.replacement_list = []

            # pass through the ops in an execution order collecting things to do
            ops = Op.ordered_ops(op.forwarded
                                 for op in transformer.state_initialization_ops + min_ops)
            for op in ops:
                op.update_forwards()
                self.visit(op)

            # Perform the gathered replacements
            for old, rep in self.replacement_list:
                old.forwarded.replace_self(rep.forwarded)
            has_work = len(self.replacement_list) > 0
            min_ops = list(op.forwarded for op in min_ops)
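This is a fixed-point rewrite loop: sweep the ops in execution order, gather replacements, apply them as a batch, and repeat until a sweep changes nothing. A self-contained sketch of the same shape on a toy expression tree (all names below are hypothetical stand-ins, not ngraph APIs):

def fold_once(expr):
    """One sweep: fold one level of ('add', int, int) nodes; report if anything changed."""
    if isinstance(expr, tuple):
        name, lhs, rhs = expr
        if isinstance(lhs, int) and isinstance(rhs, int):
            return lhs + rhs, True            # a "replacement" happened
        lhs, lhs_changed = fold_once(lhs)
        rhs, rhs_changed = fold_once(rhs)
        return (name, lhs, rhs), lhs_changed or rhs_changed
    return expr, False

def fold_to_fixed_point(expr):
    """Repeat sweeps until one performs no replacement, like do_pass above."""
    has_work = True
    while has_work:
        expr, has_work = fold_once(expr)
    return expr

assert fold_to_fixed_point(('add', ('add', 1, 2), ('add', 3, 4))) == 10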
Example #2
def clone_graph(root, clone_id, shared_queues_idx, parallel_axis, num_clones):
    """
    clone graph with serde (serialization)
    input:
    output: new_root of the cloned graph
    """
    # clone nodes with GatherSendOp as root using serde
    ser_cloned_nodes = deserialize_graph(serialize_graph([root]))
    new_root = next((o for o in ser_cloned_nodes if o.uuid == root.uuid), None)

    orig_ops = {op.uuid: op for op in Op.ordered_ops([root])}
    # Deserialization pulls in extra referenced nodes; taking ordered_ops from
    # new_root keeps only the ops the cloned graph actually depends on.
    cloned_graph = Op.ordered_ops([new_root])

    new_send_nodes = OrderedSet()
    replaced_send_nodes = OrderedSet()

    # update newly cloned op metadata, generate new UUIDs
    for op in cloned_graph:
        op.metadata['transformer'] = op.metadata['device'] + str(clone_id)
        op.metadata['device_id'] = str(clone_id)

        if isinstance(op, (ScatterRecvOp, GatherSendOp)):
            op._shared_queues = orig_ops[op.uuid]._shared_queues
            op.idx = shared_queues_idx
            if isinstance(op, ScatterRecvOp):
                op._send_node = orig_ops[op.uuid].send_node()
        elif isinstance(op, (CPUQueueRecvOp, GPUQueueRecvOp)):
            # Cloning a recv node means we need a broadcast, so simulate one by adding an
            # additional sender with the same input data as the original sender.
            # TODO replace with real broadcast #1398 #1399
            send_op = CPUQueueSendOp(orig_ops[op.uuid].send_node().args[0])
            op._queue = send_op.queue
            op._send_node = send_op
            new_send_nodes.add(send_op)
            replaced_send_nodes.add(orig_ops[op.uuid].send_node())

        if hasattr(op, '_axes') and parallel_axis in op._axes:
            op._axes = calculate_scatter_axes(op.axes, parallel_axis,
                                              num_clones)
            # TODO: Revisit to handle axes updation better. Github Ticket #1355
            if isinstance(op, DotOp):
                if parallel_axis in op.x_out_axes:
                    op.x_out_axes = calculate_scatter_axes(
                        op.x_out_axes, parallel_axis, num_clones)
                elif parallel_axis in op.y_out_axes:
                    op.y_out_axes = calculate_scatter_axes(
                        op.y_out_axes, parallel_axis, num_clones)
                else:
                    raise ValueError(
                        "Missing parallel_axis in Op's x_out_axes or y_out_axes"
                    )
        op.uuid = uuid.uuid4()

    return new_root, new_send_nodes, replaced_send_nodes
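The serde round-trip is what makes the bookkeeping work: deserializing a serialized graph yields deep copies whose uuids still match the originals, so `orig_ops` can relate each clone back to its source before fresh uuids are assigned. Roughly the same trick with `copy.deepcopy` and a toy op type (a hypothetical sketch, not the ngraph serde module):

import copy
import uuid

class ToyOp:
    """Hypothetical stand-in for an Op: identity is carried by a uuid."""
    def __init__(self, args=()):
        self.uuid = uuid.uuid4()
        self.args = tuple(args)

def walk(op, seen=None):
    """Yield op and everything reachable through args, each exactly once."""
    seen = set() if seen is None else seen
    if id(op) in seen:
        return
    seen.add(id(op))
    for arg in op.args:
        yield from walk(arg, seen)
    yield op

def toy_clone(root):
    cloned_root = copy.deepcopy(root)           # copies keep the original uuids
    orig_by_uuid = {op.uuid: op for op in walk(root)}
    for op in walk(cloned_root):
        assert orig_by_uuid[op.uuid] is not op  # matched, but a distinct object
        op.uuid = uuid.uuid4()                  # now give the clone its own identity
    return cloned_root

clone = toy_clone(ToyOp([ToyOp()]))             # root with one argument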
Example #3
    def _transform_computations(self):
        """
        Transform computation graphs to a form that can be run.
        """

        # Run passes on the computation graphs
        all_results = []
        for comp in self.computations:
            all_results.append(comp.computation_op)

        all_ops = self.run_registered_graph_passes(ops=all_results)

        # Collect up all ops from the graph and obtain the init graph
        all_ops = OrderedSet(Op.ordered_ops(all_ops))

        def ensure_tensor(op):
            op = op.forwarded
            tensor_description = op.tensor_description()
            base = tensor_description.base
            tensor = self.op_tensors.get(base, None)
            if tensor is None:
                tensor = self.device_buffer_storage(
                    base.tensor_size,
                    base.dtype,
                    base.name
                )
                self.op_tensors[base] = tensor
                self.device_buffers.add(tensor)
            tensor_view = tensor.device_tensor(tensor_description)
            self.op_tensor_views[tensor_description] = tensor_view

        self.ops = Op.ordered_ops(all_ops)
        for op in self.ops:
            if op.is_tensor_op:
                ensure_tensor(op)

        self.start_transform_allocate()
        for device_buffer in self.device_buffers:
            device_buffer.transform_allocate()
        self.finish_transform_allocate()

        # Compile the computations now that we know their storage
        for comp in self.computations:
            comp.computation_name = \
                self.transform_ordered_ops(comp,
                                           Op.ordered_ops([comp.computation_op]),
                                           name=comp.name)
        self.finish_transform()
        self.finalized = True
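`ensure_tensor` is a get-or-create cache: tensor descriptions sharing a `base` map to one device buffer, and each description gets its own view into that buffer. The same memoization shape in isolation (`BufferPool` and its keys are hypothetical):

class BufferPool:
    """Get-or-create storage keyed by base, with views per consumer."""
    def __init__(self):
        self.buffers = {}                       # base key -> backing storage

    def view(self, base_key, size, offset=0):
        buf = self.buffers.get(base_key)
        if buf is None:                         # allocate on first use only
            buf = bytearray(size)
            self.buffers[base_key] = buf
        return memoryview(buf)[offset:]         # a view into shared storage

pool = BufferPool()
a = pool.view('weights', 16, 0)
b = pool.view('weights', 16, 8)
assert len(pool.buffers) == 1                   # both views share one allocation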
Example #4
    def _transform_computations(self):
        """
        Transform computation graphs to a form that can be run.
        """

        # Run passes on the computation graphs
        all_results = []
        for comp in self.computations:
            all_results.append(comp.computation)

        all_ops = self.run_registered_graph_passes(all_results)
        self.init_computation = \
            self.add_computation(computation(doall(self.state_initialization_ops)).named('init'))
        all_ops.append(self.init_computation.computation)

        # Collect up all ops from the graph and obtain the init graph
        all_ops = OrderedSet(Op.ordered_ops(all_ops))

        def init_tensor_description(tensor_description):
            if tensor_description.buffer is None:
                tensor_description.buffer = self.device_buffer_storage(
                    tensor_description.base.tensor_size,
                    tensor_description.dtype,
                    tensor_description.name
                )
                self.device_buffers.add(tensor_description.buffer)
            tensor_description.value = \
                tensor_description.buffer.device_tensor(tensor_description)

        for state in self.init_states:
            init_tensor_description(state.tensor_description())
        self.ops = Op.ordered_ops(all_ops)
        for op in self.ops:
            if op.is_tensor_op:
                init_tensor_description(op.tensor_description())

        self.start_transform_allocate()
        for device_buffer in self.device_buffers:
            device_buffer.transform_allocate()
        self.finish_transform_allocate()

        # Compile the computations now that we know their storage
        for comp in self.computations:
            comp.computation_name = \
                self.transform_ordered_ops(Op.ordered_ops([comp.computation]),
                                           name=comp.name)
        self.finish_transform()
        self.finalized = True
Example #5
    def do_pass(self, ops, transformer):

        ops = OrderedSet(op.forwarded for op in ops)

        for op in reversed(Op.ordered_ops(ops)):
            if op.metadata.get('marker') == 'gather':
                # op is GatherRecvOp
                self.parallel_axes = op.metadata['parallel']

                gather_send_op = op.send_nodes[0]

                # clone nodes for each device_id
                replaced_send_ops = OrderedSet()
                new_gather_send_nodes = OrderedSet()
                for i, id in enumerate(op.from_id):
                    new_gather_send_op, new_sends, replaced_sends = clone_graph(
                        root=gather_send_op,
                        clone_id=id,
                        shared_queues_idx=i,
                        parallel_axis=self.parallel_axes,
                        num_clones=len(op.from_id))

                    new_gather_send_nodes.add(new_gather_send_op)

                    new_sends.add(new_gather_send_op)
                    for o in new_sends:
                        self.send_nodes.add(o)

                    replaced_send_ops |= replaced_sends

                op.send_nodes = new_gather_send_nodes

                replaced_send_ops.add(gather_send_op)
                for o in replaced_send_ops:
                    self.send_nodes.remove(o)
Example #6
    def do_pass(self, min_ops, transformer):
        """
        Visit the ops until nothing changes.

        Args:
            min_ops: The set of ops that must be computed.
            transformer: An InitGraph object.

        """
        assert isinstance(
            min_ops, Iterable), "Ops passed into do_pass must be an iterable"
        has_work = True
        while True:
            if not has_work:
                return

            self.replacement_list = []

            # pass through the ops in an execution order collecting things to do
            ops = Op.ordered_ops(op.forwarded for op in min_ops)
            for op in ops:
                op.update_forwards()
                self.visit(op)

            # Perform the gathered replacements
            for old, rep in self.replacement_list:
                old.forwarded.replace_self(rep.forwarded)
            has_work = len(self.replacement_list) > 0
            min_ops = list(op.forwarded for op in min_ops)
Example #7
    def Computation(self, request_iterator, context):
        logger.info("server: computation")
        if not self.transformer:
            return hetr_pb2.ComputationReply(
                comp_id=-1, message="build transformer before computation")
        try:
            comp_id = self.new_comp_id()
            pb_ops, pb_edges = [], []
            returns, placeholders = [], []
            reconstructed_returns, reconstructed_placeholders = [], []
            for request in request_iterator:
                pb_ops.extend(request.ops)
                pb_edges.extend(request.edges)
                returns.extend([protobuf_to_op(op) for op in request.returns])
                placeholders.extend(
                    [protobuf_to_op(op) for op in request.placeholders])

            subgraph = _deserialize_graph_ops_edges(pb_ops, pb_edges)
            ops = Op.ordered_ops(subgraph)
            for r in returns:
                for op in ops:
                    if op.uuid == r.uuid:
                        reconstructed_returns.append(op)
            for p in placeholders:
                for op in ops:
                    if op.uuid == p.uuid:
                        reconstructed_placeholders.append(op)

            computation = self.transformer.computation(
                reconstructed_returns, *reconstructed_placeholders)
            self.computations[comp_id] = computation
            return hetr_pb2.ComputationReply(comp_id=comp_id)
        except Exception:
            return hetr_pb2.ComputationReply(comp_id=-1,
                                             message=traceback.format_exc())
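Reconstruction is done purely by uuid, and the nested loops rescan all ops once per requested return or placeholder. A dict keyed by uuid performs the same matching in a single pass; a sketch, not the HeTr server code:

def match_by_uuid(ops, wanted):
    """Return the ops whose uuid matches an entry in wanted, in wanted's order."""
    by_uuid = {op.uuid: op for op in ops}
    return [by_uuid[w.uuid] for w in wanted if w.uuid in by_uuid]

# e.g. reconstructed_returns = match_by_uuid(ops, returns)
#      reconstructed_placeholders = match_by_uuid(ops, placeholders)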
Example #8
    def Computation(self, request, context):
        if not self.transformer:
            return hetr_pb2.ComputationReply(comp_id=-1)

        try:
            comp_id = self.new_comp_id()
            subgraph = _deserialize_graph(request.subgraph)
            returns = []
            placeholders = []
            for pb_op in request.returns:
                returns.append(protobuf_to_op(pb_op))
            for pb_op in request.placeholders:
                placeholders.append(protobuf_to_op(pb_op))
            return_list = []
            placeholder_list = []
            ops = Op.ordered_ops(subgraph)
            for op in ops:
                for r in returns:
                    if op.uuid == r.uuid:
                        return_list.append(op)
            for op in ops:
                for p in placeholders:
                    if op.uuid == p.uuid:
                        placeholder_list.append(op)
            computation = self.transformer.computation(return_list,
                                                       *placeholder_list)
            self.computations[comp_id] = computation
            return hetr_pb2.ComputationReply(comp_id=comp_id)
        except Exception:
            return hetr_pb2.ComputationReply(comp_id=-1)
Example #9
    def add_initialization_ops(self, ops):
        """
        Ensure initializations have been captured for state in ops.

        Args:
            ops: Collection of ops.

        Returns:
            True if new initializations were added.

        """
        did_work = False
        for op in ops:
            if op in self.init_checked_ops:
                continue
            self.init_checked_ops.add(op)
            new_inits = self.state_initializations(op.states_read)
            new_inits.update(self.state_initializations(op.states_written))
            if len(new_inits) > 0:
                did_work = True
                self.state_initialization_ops.update(new_inits)
                self.add_initialization_ops(Op.ordered_ops(new_inits))
        self.state_initialization_ops = \
            OrderedSet(op.forwarded for op in self.state_initialization_ops)
        return did_work
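The recursion matters because an initialization op can itself read or write state that needs initializing, and `init_checked_ops` keeps the search from revisiting ops. Underneath, this is a transitive closure over a dependency function, for example:

def transitive_closure(roots, deps):
    """Collect roots plus everything reachable through deps, visiting each once."""
    seen, result, stack = set(), [], list(roots)
    while stack:
        node = stack.pop()
        if node in seen:                 # plays the role of init_checked_ops
            continue
        seen.add(node)
        result.append(node)
        stack.extend(deps(node))
    return result

graph = {'a': ['b'], 'b': ['c'], 'c': []}
assert set(transitive_closure(['a'], graph.__getitem__)) == {'a', 'b', 'c'}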
Example #10
def update_parallel_axis(root, parallel_axis):
    for op in Op.ordered_ops([root]):

        if hasattr(op,
                   'reduction_axes') and parallel_axis in op.reduction_axes:
            op.reduction_axes = set_parallel_axes(op.reduction_axes,
                                                  parallel_axis)

        if getattr(op, 'axes', None) is not None \
                and parallel_axis in Axes.as_flattened_list(op.axes):
            op._axes = set_parallel_axes(op.axes, parallel_axis)
            if isinstance(op, DotOp):
                if parallel_axis in op.x_out_axes:
                    op.x_out_axes = set_parallel_axes(op.x_out_axes,
                                                      parallel_axis)
                elif parallel_axis in op.y_out_axes:
                    op.y_out_axes = set_parallel_axes(op.y_out_axes,
                                                      parallel_axis)
                else:
                    raise ValueError("Missing parallel_axis in Op's "
                                     "x_out_axes or y_out_axes")

        if isinstance(op, TensorValueOp) and parallel_axis in op.tensor.axes:
            op.tensor._axes = set_parallel_axes(op.tensor.axes, parallel_axis)
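Every branch performs the same substitution: rebuild an axes collection with the parallel axis swapped for its per-device replacement. The substitution in isolation, with tuples standing in for `Axes` (a sketch; `set_parallel_axes` itself is not shown in this listing):

def replace_axis(axes, old_axis, new_axis):
    """Return axes with every occurrence of old_axis swapped for new_axis."""
    return tuple(new_axis if ax == old_axis else ax for ax in axes)

assert replace_axis(('N', 'H', 'W'), 'N', 'N_0') == ('N_0', 'H', 'W')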
Example #11
    def _transform_computations(self):
        """
        Transform computation graphs to a form that can be run.
        """

        # Run passes on the computation graphs
        self.run_registered_graph_passes(self.all_results)

        # Collect up all ops from the graph and obtain the init graph
        all_ops = OrderedSet(Op.ordered_ops(self.all_results))
        init_op = doall(self.ordered_initializers(all_ops))

        # Run passes on the initialization graphs
        self.run_registered_graph_passes([init_op])

        # Union the init and computation graphs
        self.inits = Op.ordered_ops([init_op])
        all_ops.update(self.inits)

        # create computation which initializes values (called once per session)
        init_op.update_forwards()
        self.init_computation = self.computation(init_op, name="init")

        # Give ids
        for op in all_ops:
            if op not in self.opids:
                self.opids[op] = len(self.opids)

        self.dataflow, self.memory = assign_buffers(self, all_ops, self.fusion)

        # Initialize tensor descriptions
        for op in all_ops:
            self.initialize_tensor_descriptions(op)

        self.ops = self.dataflow.instructions

        self.start_transform_allocate()
        for device_buffer in self.device_buffers:
            device_buffer.transform_allocate()
        self.finish_transform_allocate()

        # Compile the computations now that we know their storage
        for computation in self.computations:
            computation.transform()
        self.finish_transform()
        self.finalized = True
Example #12
    def do_pass(self, min_ops, transformer):
        """
        Visit the ops until nothing changes.

        Args:
            min_ops: The set of ops that must be computed.
            transformer: An InitGraph object.

        """
        assert isinstance(min_ops, Iterable), "Ops passed into do_pass must be an iterable"
        has_work = True
        while True:
            ops = Op.ordered_ops(min_ops)

            # Check for ops that added state that needs to be initialized, so they can
            # be added to the initialization function.
            has_new_inits = transformer.add_initialization_ops(ops)

            if not has_work and not has_new_inits:
                return

            self.replacement_list = []

            # Make control dependency adjustments for any added control blocks.
            ops = Op.ordered_ops(op.forwarded
                                 for op in transformer.state_initialization_ops + min_ops)
            for op in ops:
                for cop in op.control_deps:
                    if isinstance(cop, ParallelOp):
                        op.remove_control_dep(cop)
                        for dep in cop.control_deps:
                            op.add_control_dep(dep)
                if isinstance(op, SequentialOp) and not op.control_dependencies_computed:
                    op.compute_control_dependencies()

            # pass through the ops in an execution order collecting things to do
            ops = Op.ordered_ops(op.forwarded
                                 for op in transformer.state_initialization_ops + min_ops)
            for op in ops:
                op.update_forwards()
                self.visit(op)
            for old, rep in self.replacement_list:
                old.forwarded.replace_self(rep.forwarded)
            has_work = len(self.replacement_list) > 0
            min_ops = list(op.forwarded for op in min_ops)
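The control-dependency fixup above replaces a dependency on a `ParallelOp` container with direct dependencies on the container's members. Reduced to sets (toy names, not ngraph types):

def flatten_group_deps(deps, groups):
    """Replace any dep that names a group with that group's members."""
    flat = set()
    for dep in deps:
        flat.update(groups.get(dep, {dep}))     # groups: name -> member set
    return flat

assert flatten_group_deps({'par', 'x'}, {'par': {'a', 'b'}}) == {'a', 'b', 'x'}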
Example #13
    def Computation(self, request_iterator, context):
        logger.debug("server: computation")
        if not self.transformer:
            return hetr_pb2.ComputationReply(
                comp_id=-1, message="build transformer before computation")
        try:
            comp_id = self.new_comp_id()
            pb_ops, pb_edges = [], []
            returns, placeholders = [], []
            reconstructed_returns, reconstructed_placeholders = [], []
            for request in request_iterator:
                pb_ops.extend(request.ops)
                pb_edges.extend(request.edges)
                returns.extend([protobuf_to_op(op) for op in request.returns])
                placeholders.extend(
                    [protobuf_to_op(op) for op in request.placeholders])

            subgraph = _deserialize_graph_ops_edges(pb_ops, pb_edges)

            # Add a dependency from each recv op to its send op in scenarios where
            # the send buffer is passed as an argument to the communication call
            # (gather/scatter) on the root device.
            # This ensures that the send buffer is not reused before the recv
            # buffer has read its items.
            root_idx = 0
            for op in Op.all_op_references(subgraph):
                if isinstance(op, (GatherRecvOp)) and \
                   MPI.COMM_WORLD.Get_rank() == op.metadata['device_id']:
                    args = list(op._args)
                    args.extend(op.send_node().args)
                    op._args = tuple(args)
                    op.invalidate_property_cache('all_deps')
                elif (isinstance(op, (ScatterRecvOp))
                      and MPI.COMM_WORLD.Get_rank() == root_idx and
                      MPI.COMM_WORLD.Get_rank() in op.metadata['device_id']):
                    args = list(op._args)
                    args.extend(op.send_node().args)
                    op._args = tuple(args)
                    op.invalidate_property_cache('all_deps')

            ops = Op.ordered_ops(subgraph)
            for r in returns:
                for op in ops:
                    if op.uuid == r.uuid:
                        reconstructed_returns.append(op)
            for p in placeholders:
                for op in ops:
                    if op.uuid == p.uuid:
                        reconstructed_placeholders.append(op)

            computation = self.transformer.computation(
                reconstructed_returns, *reconstructed_placeholders)
            self.computations[comp_id] = computation
            return hetr_pb2.ComputationReply(comp_id=comp_id)
        except Exception:
            return hetr_pb2.ComputationReply(comp_id=-1,
                                             message=traceback.format_exc())
Example #14
    def run_pass(self, process_op, ops, **kwargs):
        assert isinstance(ops, Iterable), "Ops passed into run_pass must be an iterable"
        has_work = True
        while has_work:
            self.begin_batch()

            # pass through the ops in an execution order collecting things to do
            ops = Op.ordered_ops(op.forwarded for op in ops)
            for op in ops:
                op.update_forwards()
                process_op(op)

            has_work = self.end_batch()
            ops = list(op.forwarded for op in ops)
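`run_pass` factors the sweep driver out of the pass itself: `process_op` is the per-op visitor, and the gather/apply bookkeeping hides behind `begin_batch`/`end_batch`. A self-contained sketch of that protocol (a hypothetical class, not the ngraph base pass):

class SweepDriver:
    """Minimal begin_batch / end_batch protocol around a visitor callback."""
    def begin_batch(self):
        self.replacements = []                  # visitors may append (old, new)

    def end_batch(self):
        for old, new in self.replacements:      # apply the gathered batch
            old.replace_self(new)
        return bool(self.replacements)          # True means run another sweep

    def run_pass(self, process_op, ops):
        has_work = True
        while has_work:
            self.begin_batch()
            for op in list(ops):
                process_op(op)
            has_work = self.end_batch()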
Example #15
    def do_pass(self, ops):
        assert isinstance(
            ops, Iterable), "Ops passed into do_pass must be an iterable"
        has_work = True
        while has_work:
            self.replacement_list = []
            ops = set(op.forwarded for op in ops)
            for op in Op.ordered_ops(ops):
                op.update_forwards()
                self.visit(op)
            for old, rep in self.replacement_list:
                old.forwarded.replace_self(rep.forwarded)
            has_work = len(self.replacement_list) > 0
        return ops
Example #16
    def do_pass(self, ops, **kwargs):

        ops = OrderedSet(op.forwarded for op in ops)

        for op in reversed(Op.ordered_ops(ops)):
            if op.metadata.get('marker') == 'gather':
                # op is GatherRecvOp
                if self.parallel_axis is None:
                    a = op.metadata['parallel']
                    assert a.length % len(op.from_id) == 0, '{} cannot be evenly divided by {}'\
                        .format(a, len(op.from_id))
                    self.parallel_axis = make_axis(
                        name=a.name,
                        length=a.length // len(op.from_id),
                        docstring='HeTr parallel axis')
                gather_send_op = op.send_node()
                update_parallel_axis(gather_send_op, self.parallel_axis)
Example #17
    def do_pass(self, ops, **kwargs):

        ops = OrderedSet(op.forwarded for op in ops)

        for op in reversed(Op.ordered_ops(ops)):
            if op.metadata.get('marker') == 'gather':
                # op is GatherRecvOp
                if self.parallel_axes is None:
                    a = op.metadata['parallel']
                    assert a.length % len(op.from_id) == 0, '{} cannot be evenly divided by {}'\
                        .format(a, len(op.from_id))
                    self.parallel_axes = make_axis(
                        name=a.name,
                        length=a.length // len(op.from_id),
                        docstring='HeTr parallel axis')
                gather_send_op = op.send_nodes[0]

                # clone nodes for each device_id
                replaced_send_ops = OrderedSet()
                new_gather_send_nodes = OrderedSet()
                for i, id in enumerate(op.from_id):
                    new_gather_send_op, new_sends, replaced_sends = clone_graph(
                        root=gather_send_op,
                        clone_id=id,
                        shared_queues_idx=i,
                        parallel_axis=self.parallel_axes,
                        num_clones=len(op.from_id))

                    new_gather_send_nodes.add(new_gather_send_op)

                    new_sends.add(new_gather_send_op)
                    for o in new_sends:
                        self.send_nodes.add(o)

                    replaced_send_ops |= replaced_sends

                op.send_nodes = new_gather_send_nodes

                replaced_send_ops.add(gather_send_op)
                for o in replaced_send_ops:
                    self.send_nodes.remove(o)
Example #18
    def do_pass(self, ops, transformer):
        ops = OrderedSet(op.forwarded for op in ops)

        def set_new_axes(root, num_devices):
            visit = self.do_traversal(root)
            self.new_axes = calculate_new_axes(root.axes, self.parallel_axis,
                                               num_devices, False)

            while visit:
                node = visit.pop()
                if hasattr(node, 'axes'):
                    node._TensorOp__axes = self.new_axes

        # Start traversal from the top to the bottom
        for op in reversed(Op.ordered_ops(ops)):
            args = list()
            for arg in op.args:
                if 'marker' in arg.metadata:
                    if arg.metadata['marker'] == 'gather':
                        self.parallel_axis = arg.metadata['parallel']
                        set_new_axes(arg.send_node(), len(arg.from_id))

                        for d in range(1, len(arg.from_id)):
                            if d == (len(arg.from_id) - 1):
                                self.new_axes = calculate_new_axes(
                                    arg.axes, self.parallel_axis,
                                    len(arg.from_id), True)

                            nodes = self.do_traversal(arg.send_node())
                            self.clone_nodes(nodes, arg.from_id[d],
                                             self.new_axes,
                                             self.scatter_shared_queues[d],
                                             self.gather_shared_queues[d])

                args.append(arg)

            # op.args is always a tuple; write the updated args back through
            # the name-mangled slot
            op._Op__args = tuple(args)
Example #19
    def do_pass(self, ops):
        return len(Op.ordered_ops(ops))
Example #20
def clone_graph(root, clone_id, shared_queues_idx, parallel_axis, num_clones):
    """
    clone graph with serde (serialization)
    input:
    output: new_root of the cloned graph
    """
    # clone nodes with GatherSendOp as root using serde
    ser_cloned_nodes = deserialize_graph(serialize_graph([root]))
    new_root = next((o for o in ser_cloned_nodes if o.uuid == root.uuid), None)

    orig_ops = {op.uuid: op for op in Op.ordered_ops([root])}
    # Deserialization pulls in extra referenced nodes; taking ordered_ops from
    # new_root keeps only the ops the cloned graph actually depends on.
    cloned_graph = Op.ordered_ops([new_root])

    new_send_nodes = OrderedSet()
    replaced_send_nodes = OrderedSet()

    # update newly cloned op metadata, generate new UUIDs
    for op in cloned_graph:
        cloned_ops = orig_ops[op.uuid].metadata.get('clones')
        if cloned_ops is None or cloned_ops.get(str(clone_id)) is None:
            op.metadata['transformer'] = op.metadata['device'] + str(clone_id)
            op.metadata['device_id'] = str(clone_id)

            if isinstance(
                    op,
                (ScatterRecvOp, GatherSendOp, AllReduceOp, BroadcastRecvOp)):
                op._shared_queues = orig_ops[op.uuid]._shared_queues
                op.idx = shared_queues_idx
                if isinstance(op, (ScatterRecvOp, BroadcastRecvOp)):
                    op._send_node = orig_ops[op.uuid].send_node()
            elif isinstance(op, (CPUQueueRecvOp, GPUQueueRecvOp)):
                # Cloning a recv node means we need a broadcast, so simulate one by adding an
                # additional sender with the same input data as the original sender.
                send_op = CPUQueueSendOp(orig_ops[op.uuid].send_node().args[0])
                op._queue = send_op.queue
                op._send_node = send_op
                new_send_nodes.add(send_op)
                replaced_send_nodes.add(orig_ops[op.uuid].send_node())

            if hasattr(
                    op,
                    'reduction_axes') and parallel_axis in op.reduction_axes:
                op.reduction_axes = set_parallel_axes(op.reduction_axes,
                                                      parallel_axis)

            if getattr(op, 'axes', None) is not None \
                    and parallel_axis in Axes.as_flattened_list(op.axes):
                op._axes = set_parallel_axes(op.axes, parallel_axis)
                if isinstance(op, DotOp):
                    if parallel_axis in op.x_out_axes:
                        op.x_out_axes = set_parallel_axes(
                            op.x_out_axes, parallel_axis)
                    elif parallel_axis in op.y_out_axes:
                        op.y_out_axes = set_parallel_axes(
                            op.y_out_axes, parallel_axis)
                    else:
                        raise ValueError("Missing parallel_axis in Op's "
                                         "x_out_axes or y_out_axes")

            if isinstance(op,
                          TensorValueOp) and parallel_axis in op.tensor.axes:
                op.tensor._axes = set_parallel_axes(op.tensor.axes,
                                                    parallel_axis)

            args_list = list(op.args)
            for arg_idx, arg_op in enumerate(args_list):
                if arg_op.uuid in orig_ops.keys():
                    if orig_ops[arg_op.uuid].metadata.get('clones') and \
                       orig_ops[arg_op.uuid].metadata['clones'].get(str(clone_id)):
                        args_list[arg_idx] = \
                            orig_ops[arg_op.uuid].metadata['clones'].get(str(clone_id))
            op.invalidate_property_cache('all_deps')
            op._args = tuple(args_list)
            if op != new_root:
                clones = orig_ops[op.uuid].metadata.setdefault('clones', dict())
                clones[str(clone_id)] = op

            op.uuid = uuid.uuid4()

    return new_root, new_send_nodes, replaced_send_nodes
Example #21
def clone_graph(root, clone_id, parallel_axis):
    """
    clone graph with serde (serialization)
    input:
    output: new_root of the cloned graph
    """

    # clone nodes with GatherSendOp as root using serde
    ser_cloned_nodes = deserialize_graph(serialize_graph([root]))

    new_root = next((o for o in ser_cloned_nodes if o.uuid == root.uuid), None)

    orig_ops = {op.uuid: op for op in Op.ordered_ops([root])}
    cloned_graph = Op.ordered_ops([new_root])

    new_send_nodes = OrderedSet()
    replaced_send_nodes = OrderedSet()

    # update newly cloned op metadata, generate new UUIDs
    for op in cloned_graph:
        cloned_ops = orig_ops[op.uuid].metadata.get('clones')
        if cloned_ops is None or cloned_ops.get(str(clone_id)) is None:
            op.metadata['transformer'] = op.metadata['device'] + str(clone_id)
            op.metadata['device_id'] = str(clone_id)

            if isinstance(
                    op,
                (ScatterRecvOp, GatherSendOp, AllReduceOp, BroadcastRecvOp)):
                # for gpu communication op buffer
                op.idx = int(clone_id)
                if isinstance(op, (ScatterRecvOp, BroadcastRecvOp)):
                    op._send_node = orig_ops[op.uuid].send_node()

            if hasattr(
                    op,
                    'reduction_axes') and parallel_axis in op.reduction_axes:
                op.reduction_axes = set_parallel_axes(op.reduction_axes,
                                                      parallel_axis)

            if getattr(op, 'axes', None) is not None \
                    and parallel_axis in Axes.as_flattened_list(op.axes):
                op._axes = set_parallel_axes(op.axes, parallel_axis)
                if isinstance(op, DotOp):
                    if parallel_axis in op.x_out_axes:
                        op.x_out_axes = set_parallel_axes(
                            op.x_out_axes, parallel_axis)
                    elif parallel_axis in op.y_out_axes:
                        op.y_out_axes = set_parallel_axes(
                            op.y_out_axes, parallel_axis)
                    else:
                        raise ValueError("Missing parallel_axis in Op's "
                                         "x_out_axes or y_out_axes")

            if isinstance(op,
                          TensorValueOp) and parallel_axis in op.tensor.axes:
                op.tensor._axes = set_parallel_axes(op.tensor.axes,
                                                    parallel_axis)

            args_list = list(op.args)
            for arg_idx, arg_op in enumerate(args_list):
                if arg_op.uuid in orig_ops.keys():
                    if orig_ops[arg_op.uuid].metadata.get('clones') and \
                       orig_ops[arg_op.uuid].metadata['clones'].get(str(clone_id)):
                        args_list[arg_idx] = \
                            orig_ops[arg_op.uuid].metadata['clones'].get(str(clone_id))

            op.invalidate_property_cache('all_deps')
            op._args = tuple(args_list)
            if op != new_root:
                clones = orig_ops[op.uuid].metadata.setdefault('clones', dict())
                clones[str(clone_id)] = op

            op.uuid = uuid.uuid4()

    # create new uuids for all the ops that have references to the new root
    for _op in Op.all_op_references([new_root]):
        _op.uuid = uuid.uuid4()

    return new_root, new_send_nodes, replaced_send_nodes
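Examples #20 and #21 add memoization to the cloning of Example #2: each original op records its per-device clone in `metadata['clones']`, so argument edges can be rewired to existing clones and no op is cloned twice for the same `clone_id`. The cache reduced to its essentials (a hypothetical helper):

clone_cache = {}                                # (orig_uuid, clone_id) -> clone

def get_or_make_clone(orig, clone_id, make_clone):
    """Clone orig at most once per clone_id, mirroring metadata['clones']."""
    key = (orig.uuid, clone_id)
    if key not in clone_cache:
        clone_cache[key] = make_clone(orig)
    return clone_cache[key]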