def do_pass(self, min_ops, transformer):
    """
    Visit the ops until nothing changes.

    Args:
        min_ops: The set of ops that must be computed.
        transformer: An InitGraph object.
    """
    assert isinstance(min_ops, Iterable), \
        "Ops passed into do_pass must be an iterable"
    has_work = True
    while True:
        ops = Op.ordered_ops(min_ops)

        # Check for ops that added state that needs to be initialized, so they can
        # be added to the initialization function.
        has_new_inits = transformer.add_initialization_ops(ops)
        if not has_work and not has_new_inits:
            return
        self.replacement_list = []

        # pass through the ops in an execution order collecting things to do
        ops = Op.ordered_ops(
            op.forwarded for op in transformer.state_initialization_ops + min_ops)
        for op in ops:
            op.update_forwards()
            self.visit(op)

        # Perform the gathered replacements
        for old, rep in self.replacement_list:
            old.forwarded.replace_self(rep.forwarded)
        has_work = len(self.replacement_list) > 0
        min_ops = list(op.forwarded for op in min_ops)
def clone_graph(root, clone_id, shared_queues_idx, parallel_axis, num_clones):
    """
    Clone a graph using serde (serialization).

    Arguments:
        root: root of the graph to clone (e.g. a GatherSendOp).
        clone_id: id of the device the clone is created for.
        shared_queues_idx: index into the shared queues of cloned comm ops.
        parallel_axis: axis to split across the clones.
        num_clones: total number of clones.

    Returns:
        new_root of the cloned graph, the new send nodes, and the send nodes
        they replace.
    """
    # clone nodes with GatherSendOp as root using serde
    ser_cloned_nodes = deserialize_graph(serialize_graph([root]))
    new_root = next((o for o in ser_cloned_nodes if o.uuid == root.uuid), None)

    orig_ops = {op.uuid: op for op in Op.ordered_ops([root])}
    # Prune ops that are not control_deps of new_gather_send_op;
    # deserialize includes extra referenced nodes.
    cloned_graph = Op.ordered_ops([new_root])

    new_send_nodes = OrderedSet()
    replaced_send_nodes = OrderedSet()

    # update newly cloned op metadata, generate new UUIDs
    for op in cloned_graph:
        op.metadata['transformer'] = op.metadata['device'] + str(clone_id)
        op.metadata['device_id'] = str(clone_id)

        if isinstance(op, (ScatterRecvOp, GatherSendOp)):
            op._shared_queues = orig_ops[op.uuid]._shared_queues
            op.idx = shared_queues_idx
            if isinstance(op, ScatterRecvOp):
                op._send_node = orig_ops[op.uuid].send_node()
        elif isinstance(op, (CPUQueueRecvOp, GPUQueueRecvOp)):
            # Cloning a recv node means we need a broadcast, so simulate one by adding an
            # additional sender with the same input data as the original sender.
            # TODO replace with real broadcast #1398 #1399
            send_op = CPUQueueSendOp(orig_ops[op.uuid].send_node().args[0])
            op._queue = send_op.queue
            op._send_node = send_op
            new_send_nodes.add(send_op)
            replaced_send_nodes.add(orig_ops[op.uuid].send_node())

        if hasattr(op, '_axes') and parallel_axis in op._axes:
            op._axes = calculate_scatter_axes(op.axes, parallel_axis, num_clones)
            # TODO: Revisit to handle axes updation better. Github Ticket #1355
        if isinstance(op, DotOp):
            if parallel_axis in op.x_out_axes:
                op.x_out_axes = calculate_scatter_axes(
                    op.x_out_axes, parallel_axis, num_clones)
            elif parallel_axis in op.y_out_axes:
                op.y_out_axes = calculate_scatter_axes(
                    op.y_out_axes, parallel_axis, num_clones)
            else:
                raise ValueError(
                    "Missing parallel_axis in Op's x_out_axes or y_out_axes")
        op.uuid = uuid.uuid4()

    return new_root, new_send_nodes, replaced_send_nodes
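# clone_graph above relies on a serialize/deserialize round trip to get an
# independent copy of every op reachable from root, then patches metadata,
# queues, and axes, and assigns fresh uuids. The same clone-by-serde idea can
# be sketched with the standard library alone; pickle and the ToyOp class
# below are stand-ins for ngraph's serde and Op types, not its real API.
import pickle
import uuid

class ToyOp(object):
    def __init__(self, name, args=()):
        self.uuid = uuid.uuid4()
        self.name = name
        self.args = tuple(args)

def toy_clone(root, clone_id):
    # Round-tripping through pickle yields a deep, independent copy ...
    new_root = pickle.loads(pickle.dumps(root))
    # ... whose ops still carry the original uuids, so they could be matched
    # back to the source graph before being given fresh identities.
    stack, seen = [new_root], []
    while stack:
        op = stack.pop()
        if op in seen:
            continue
        seen.append(op)
        op.name = '{}_clone{}'.format(op.name, clone_id)
        op.uuid = uuid.uuid4()
        stack.extend(op.args)
    return new_root

root = ToyOp('send', [ToyOp('dot', [ToyOp('x'), ToyOp('w')])])
clone = toy_clone(root, clone_id=1)
assert clone is not root and clone.name == 'send_clone1'
assert root.name == 'send'  # original graph is untouched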
def _transform_computations(self):
    """
    Transform computation graphs to a form that can be run.
    """
    # Run passes on the computation graphs
    all_results = []
    for comp in self.computations:
        all_results.append(comp.computation_op)

    all_ops = self.run_registered_graph_passes(ops=all_results)

    # Collect up all ops from the graph and obtain the init graph
    all_ops = OrderedSet(Op.ordered_ops(all_ops))

    def ensure_tensor(op):
        op = op.forwarded
        tensor_description = op.tensor_description()
        base = tensor_description.base
        tensor = self.op_tensors.get(base, None)
        if tensor is None:
            tensor = self.device_buffer_storage(
                base.tensor_size,
                base.dtype,
                base.name
            )
            self.op_tensors[base] = tensor
            self.device_buffers.add(tensor)
        tensor_view = tensor.device_tensor(tensor_description)
        self.op_tensor_views[tensor_description] = tensor_view

    self.ops = Op.ordered_ops(all_ops)
    for op in self.ops:
        if op.is_tensor_op:
            ensure_tensor(op)

    self.start_transform_allocate()
    for device_buffer in self.device_buffers:
        device_buffer.transform_allocate()
    self.finish_transform_allocate()

    # Compile the computations now that we know their storage
    for comp in self.computations:
        comp.computation_name = \
            self.transform_ordered_ops(comp,
                                       Op.ordered_ops([comp.computation_op]),
                                       name=comp.name)
    self.finish_transform()
    self.finalized = True
def _transform_computations(self):
    """
    Transform computation graphs to a form that can be run.
    """
    # Run passes on the computation graphs
    all_results = []
    for comp in self.computations:
        all_results.append(comp.computation)

    all_ops = self.run_registered_graph_passes(all_results)
    self.init_computation = \
        self.add_computation(computation(doall(self.state_initialization_ops)).named('init'))
    all_ops.append(self.init_computation.computation)

    # Collect up all ops from the graph and obtain the init graph
    all_ops = OrderedSet(Op.ordered_ops(all_ops))

    def init_tensor_description(tensor_description):
        if tensor_description.buffer is None:
            tensor_description.buffer = self.device_buffer_storage(
                tensor_description.base.tensor_size,
                tensor_description.dtype,
                tensor_description.name
            )
            self.device_buffers.add(tensor_description.buffer)
        tensor_description.value = \
            tensor_description.buffer.device_tensor(tensor_description)

    for state in self.init_states:
        init_tensor_description(state.tensor_description())
    self.ops = Op.ordered_ops(all_ops)
    for op in self.ops:
        if op.is_tensor_op:
            init_tensor_description(op.tensor_description())

    self.start_transform_allocate()
    for device_buffer in self.device_buffers:
        device_buffer.transform_allocate()
    self.finish_transform_allocate()

    # Compile the computations now that we know their storage
    for comp in self.computations:
        comp.computation_name = \
            self.transform_ordered_ops(Op.ordered_ops([comp.computation]),
                                       name=comp.name)
    self.finish_transform()
    self.finalized = True
def do_pass(self, ops, transformer):
    ops = OrderedSet(op.forwarded for op in ops)

    for op in reversed(Op.ordered_ops(ops)):
        if op.metadata.get('marker') == 'gather':
            # op is GatherRecvOp
            self.parallel_axes = op.metadata['parallel']
            gather_send_op = op.send_nodes[0]

            # clone nodes for each device_id
            replaced_send_ops = OrderedSet()
            new_gather_send_nodes = OrderedSet()
            for i, id in enumerate(op.from_id):
                new_gather_send_op, new_sends, replaced_sends = clone_graph(
                    root=gather_send_op,
                    clone_id=id,
                    shared_queues_idx=i,
                    parallel_axis=self.parallel_axes,
                    num_clones=len(op.from_id))
                new_gather_send_nodes.add(new_gather_send_op)

                new_sends.add(new_gather_send_op)
                for o in new_sends:
                    self.send_nodes.add(o)

                replaced_send_ops |= replaced_sends

            op.send_nodes = new_gather_send_nodes

            replaced_send_ops.add(gather_send_op)
            for o in replaced_send_ops:
                self.send_nodes.remove(o)
def do_pass(self, min_ops, transformer):
    """
    Visit the ops until nothing changes.

    Args:
        min_ops: The set of ops that must be computed.
        transformer: An InitGraph object.
    """
    assert isinstance(min_ops, Iterable), \
        "Ops passed into do_pass must be an iterable"
    has_work = True
    while has_work:
        self.replacement_list = []

        # pass through the ops in an execution order collecting things to do
        ops = Op.ordered_ops(op.forwarded for op in min_ops)
        for op in ops:
            op.update_forwards()
            self.visit(op)

        # Perform the gathered replacements
        for old, rep in self.replacement_list:
            old.forwarded.replace_self(rep.forwarded)
        has_work = len(self.replacement_list) > 0
        min_ops = list(op.forwarded for op in min_ops)
def Computation(self, request_iterator, context):
    logger.info("server: computation")
    if not self.transformer:
        return hetr_pb2.ComputationReply(
            comp_id=-1,
            message="build transformer before computation")
    try:
        comp_id = self.new_comp_id()
        pb_ops, pb_edges = [], []
        returns, placeholders = [], []
        reconstructed_returns, reconstructed_placeholders = [], []
        for request in request_iterator:
            pb_ops.extend(request.ops)
            pb_edges.extend(request.edges)
            returns.extend([protobuf_to_op(op) for op in request.returns])
            placeholders.extend(
                [protobuf_to_op(op) for op in request.placeholders])

        subgraph = _deserialize_graph_ops_edges(pb_ops, pb_edges)
        ops = Op.ordered_ops(subgraph)
        for r in returns:
            for op in ops:
                if op.uuid == r.uuid:
                    reconstructed_returns.append(op)
        for p in placeholders:
            for op in ops:
                if op.uuid == p.uuid:
                    reconstructed_placeholders.append(op)

        computation = self.transformer.computation(
            reconstructed_returns, *reconstructed_placeholders)
        self.computations[comp_id] = computation
        return hetr_pb2.ComputationReply(comp_id=comp_id)
    except Exception:
        return hetr_pb2.ComputationReply(comp_id=-1,
                                         message=traceback.format_exc())
def Computation(self, request, context):
    if not self.transformer:
        return hetr_pb2.ComputationReply(comp_id=-1)
    try:
        comp_id = self.new_comp_id()
        subgraph = _deserialize_graph(request.subgraph)
        returns = []
        placeholders = []
        for pb_op in request.returns:
            returns.append(protobuf_to_op(pb_op))
        for pb_op in request.placeholders:
            placeholders.append(protobuf_to_op(pb_op))

        return_list = []
        placeholder_list = []
        ops = Op.ordered_ops(subgraph)
        for op in ops:
            for r in returns:
                if op.uuid == r.uuid:
                    return_list.append(op)
        for op in ops:
            for p in placeholders:
                if op.uuid == p.uuid:
                    placeholder_list.append(op)

        computation = self.transformer.computation(return_list,
                                                   *placeholder_list)
        self.computations[comp_id] = computation
        return hetr_pb2.ComputationReply(comp_id=comp_id)
    except Exception:
        return hetr_pb2.ComputationReply(comp_id=-1)
def add_initialization_ops(self, ops):
    """
    Ensure initializations have been captured for state in ops.

    Args:
        ops: Collection of ops.

    Returns:
        True if new initializations were added.
    """
    did_work = False
    for op in ops:
        if op in self.init_checked_ops:
            continue
        self.init_checked_ops.add(op)
        new_inits = self.state_initializations(op.states_read)
        new_inits.update(self.state_initializations(op.states_written))
        if len(new_inits) > 0:
            did_work = True
            self.state_initialization_ops.update(new_inits)
            self.add_initialization_ops(Op.ordered_ops(new_inits))
    self.state_initialization_ops = \
        OrderedSet(op.forwarded for op in self.state_initialization_ops)
    return did_work
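# add_initialization_ops walks the states an op reads or writes, pulls in
# their initializers, and recurses into those initializers so that
# initializers-of-initializers are captured too; init_checked_ops prevents
# re-scanning the same op. A tiny self-contained sketch of that transitive
# collection follows; the names here are illustrative, not ngraph API.
def collect_inits(ops, initializers_of, checked=None, collected=None):
    checked = set() if checked is None else checked
    collected = [] if collected is None else collected
    did_work = False
    for op in ops:
        if op in checked:
            continue
        checked.add(op)
        new_inits = initializers_of.get(op, [])
        if new_inits:
            did_work = True
            collected.extend(new_inits)
            # Initializer ops may themselves need initialization.
            collect_inits(new_inits, initializers_of, checked, collected)
    return did_work, collected

# 'w_init' itself depends on 'rng_init' being initialized first.
deps = {'train_op': ['w_init'], 'w_init': ['rng_init']}
did_work, inits = collect_inits(['train_op'], deps)
assert did_work and inits == ['w_init', 'rng_init']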
def update_parallel_axis(root, parallel_axis):
    for op in Op.ordered_ops([root]):

        if hasattr(op, 'reduction_axes') and parallel_axis in op.reduction_axes:
            op.reduction_axes = set_parallel_axes(op.reduction_axes,
                                                  parallel_axis)

        if getattr(op, 'axes', None) is not None \
                and parallel_axis in Axes.as_flattened_list(op.axes):
            op._axes = set_parallel_axes(op.axes, parallel_axis)

        if isinstance(op, DotOp):
            if parallel_axis in op.x_out_axes:
                op.x_out_axes = set_parallel_axes(op.x_out_axes,
                                                  parallel_axis)
            elif parallel_axis in op.y_out_axes:
                op.y_out_axes = set_parallel_axes(op.y_out_axes,
                                                  parallel_axis)
            else:
                raise ValueError("Missing parallel_axis in Op's "
                                 "x_out_axes or y_out_axes")

        if isinstance(op, TensorValueOp) and parallel_axis in op.tensor.axes:
            op.tensor._axes = set_parallel_axes(op.tensor.axes,
                                                parallel_axis)
def _transform_computations(self):
    """
    Transform computation graphs to a form that can be run.
    """
    # Run passes on the computation graphs
    self.run_registered_graph_passes(self.all_results)

    # Collect up all ops from the graph and obtain the init graph
    all_ops = OrderedSet(Op.ordered_ops(self.all_results))
    init_op = doall(self.ordered_initializers(all_ops))

    # Run passes on the initialization graphs
    self.run_registered_graph_passes([init_op])

    # Union the init and computation graphs
    self.inits = Op.ordered_ops([init_op])
    all_ops.update(self.inits)

    # create computation which initializes values (called once per session)
    init_op.update_forwards()
    self.init_computation = self.computation(init_op, name="init")

    # Give ids
    for op in all_ops:
        if op not in self.opids:
            self.opids[op] = len(self.opids)

    self.dataflow, self.memory = assign_buffers(self, all_ops, self.fusion)

    # Initialize tensor descriptions
    for op in all_ops:
        self.initialize_tensor_descriptions(op)

    self.ops = self.dataflow.instructions

    self.start_transform_allocate()
    for device_buffer in self.device_buffers:
        device_buffer.transform_allocate()
    self.finish_transform_allocate()

    # Compile the computations now that we know their storage
    for computation in self.computations:
        computation.transform()
    self.finish_transform()
    self.finalized = True
def do_pass(self, min_ops, transformer):
    """
    Visit the ops until nothing changes.

    Args:
        min_ops: The set of ops that must be computed.
        transformer: An InitGraph object.
    """
    assert isinstance(min_ops, Iterable), \
        "Ops passed into do_pass must be an iterable"
    has_work = True
    while True:
        ops = Op.ordered_ops(min_ops)

        # Check for ops that added state that needs to be initialized, so they can
        # be added to the initialization function.
        has_new_inits = transformer.add_initialization_ops(ops)
        if not has_work and not has_new_inits:
            return
        self.replacement_list = []

        # Make control dependency adjustments for any added control blocks.
        ops = Op.ordered_ops(
            op.forwarded for op in transformer.state_initialization_ops + min_ops)
        for op in ops:
            for cop in op.control_deps:
                if isinstance(cop, ParallelOp):
                    op.remove_control_dep(cop)
                    for dep in cop.control_deps:
                        op.add_control_dep(dep)
            if isinstance(op, SequentialOp) and not op.control_dependencies_computed:
                op.compute_control_dependencies()

        # pass through the ops in an execution order collecting things to do
        ops = Op.ordered_ops(
            op.forwarded for op in transformer.state_initialization_ops + min_ops)
        for op in ops:
            op.update_forwards()
            self.visit(op)

        # Perform the gathered replacements
        for old, rep in self.replacement_list:
            old.forwarded.replace_self(rep.forwarded)
        has_work = len(self.replacement_list) > 0
        min_ops = list(op.forwarded for op in min_ops)
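# The do_pass above first flattens control dependencies on ParallelOp groups:
# an op that depended on the group comes to depend on each member instead.
# A plain-Python sketch of that rewrite follows; the dict-based "group" and
# control_deps mapping are illustrative stand-ins, not ngraph objects.
group = {'members': ['init_w', 'init_b']}
control_deps = {'train_op': [group, 'reader_op']}

deps = []
for dep in control_deps['train_op']:
    if isinstance(dep, dict):        # stand-in for isinstance(cop, ParallelOp)
        deps.extend(dep['members'])  # depend on each member directly
    else:
        deps.append(dep)
control_deps['train_op'] = deps
assert control_deps['train_op'] == ['init_w', 'init_b', 'reader_op']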
def Computation(self, request_iterator, context):
    logger.debug("server: computation")
    if not self.transformer:
        return hetr_pb2.ComputationReply(
            comp_id=-1,
            message="build transformer before computation")
    try:
        comp_id = self.new_comp_id()
        pb_ops, pb_edges = [], []
        returns, placeholders = [], []
        reconstructed_returns, reconstructed_placeholders = [], []
        for request in request_iterator:
            pb_ops.extend(request.ops)
            pb_edges.extend(request.edges)
            returns.extend([protobuf_to_op(op) for op in request.returns])
            placeholders.extend(
                [protobuf_to_op(op) for op in request.placeholders])

        subgraph = _deserialize_graph_ops_edges(pb_ops, pb_edges)

        # Add a dependency on the recv op to its send op in scenarios where the send
        # buffer is passed as an argument to the communication call (gather/scatter)
        # on the root device.
        # This ensures that the send buffer does not get reused before
        # the recv buffer gets access to its items.
        root_idx = 0
        for op in Op.all_op_references(subgraph):
            if isinstance(op, GatherRecvOp) and \
                    MPI.COMM_WORLD.Get_rank() == op.metadata['device_id']:
                args = list(op._args)
                args.extend(op.send_node().args)
                op._args = tuple(args)
                op.invalidate_property_cache('all_deps')
            elif (isinstance(op, ScatterRecvOp) and
                  MPI.COMM_WORLD.Get_rank() == root_idx and
                  MPI.COMM_WORLD.Get_rank() in op.metadata['device_id']):
                args = list(op._args)
                args.extend(op.send_node().args)
                op._args = tuple(args)
                op.invalidate_property_cache('all_deps')

        ops = Op.ordered_ops(subgraph)
        for r in returns:
            for op in ops:
                if op.uuid == r.uuid:
                    reconstructed_returns.append(op)
        for p in placeholders:
            for op in ops:
                if op.uuid == p.uuid:
                    reconstructed_placeholders.append(op)

        computation = self.transformer.computation(
            reconstructed_returns, *reconstructed_placeholders)
        self.computations[comp_id] = computation
        return hetr_pb2.ComputationReply(comp_id=comp_id)
    except Exception:
        return hetr_pb2.ComputationReply(comp_id=-1,
                                         message=traceback.format_exc())
def run_pass(self, process_op, ops, **kwargs):
    assert isinstance(ops, Iterable), \
        "Ops passed into do_pass must be an iterable"
    has_work = True
    while has_work:
        self.begin_batch()

        # pass through the ops in an execution order collecting things to do
        ops = Op.ordered_ops(op.forwarded for op in ops)
        for op in ops:
            op.update_forwards()
            process_op(op)

        has_work = self.end_batch()
        ops = list(op.forwarded for op in ops)
def do_pass(self, ops):
    assert isinstance(ops, Iterable), \
        "Ops passed into do_pass must be an iterable"
    has_work = True
    while has_work:
        self.replacement_list = []
        ops = set(op.forwarded for op in ops)
        for op in Op.ordered_ops(ops):
            op.update_forwards()
            self.visit(op)
        for old, rep in self.replacement_list:
            old.forwarded.replace_self(rep.forwarded)
        has_work = len(self.replacement_list) > 0
    return ops
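# The do_pass variants above all implement a fixed-point rewrite loop: visit
# every op, collect (old, new) replacement pairs, apply them as a batch, and
# repeat until a full sweep produces no replacements. A minimal,
# self-contained sketch of that pattern follows; Node, fold_add_zero, and
# run_to_fixed_point are illustrative names only, not part of the ngraph API.
class Node(object):
    def __init__(self, name, args=()):
        self.name = name
        self.args = list(args)

def fold_add_zero(node, replacements):
    # Collect a replacement when a node is "add(x, zero)".
    if node.name == 'add' and any(a.name == 'zero' for a in node.args):
        keep = [a for a in node.args if a.name != 'zero']
        replacements.append((node, keep[0] if keep else Node('zero')))

def run_to_fixed_point(root):
    has_work = True
    while has_work:
        replacements = []
        # Walk the graph collecting rewrites.
        stack, ordered, seen = [root], [], set()
        while stack:
            n = stack.pop()
            if id(n) in seen:
                continue
            seen.add(id(n))
            ordered.append(n)
            stack.extend(n.args)
        for n in ordered:
            fold_add_zero(n, replacements)
        # Apply the batched replacements in place.
        for old, new in replacements:
            for n in ordered:
                n.args = [new if a is old else a for a in n.args]
            if old is root:
                root = new
        has_work = len(replacements) > 0
    return root

# Example: add(x, zero) collapses to x after one sweep.
graph = Node('mul', [Node('add', [Node('x'), Node('zero')]), Node('y')])
graph = run_to_fixed_point(graph)
assert [a.name for a in graph.args] == ['x', 'y']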
def do_pass(self, ops, **kwargs):
    ops = OrderedSet(op.forwarded for op in ops)

    for op in reversed(Op.ordered_ops(ops)):
        if op.metadata.get('marker') == 'gather':
            # op is GatherRecvOp
            if self.parallel_axis is None:
                a = op.metadata['parallel']
                assert a.length % len(op.from_id) == 0, \
                    '{} can not be equally divided by {}'.format(a, len(op.from_id))
                self.parallel_axis = make_axis(
                    name=a.name,
                    length=a.length // len(op.from_id),
                    docstring='HeTr parallel axis')
            gather_send_op = op.send_node()
            update_parallel_axis(gather_send_op, self.parallel_axis)
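# The pass above carves the 'parallel' axis into equal per-device pieces: the
# assert requires the axis length to be divisible by the number of devices,
# and each clone then sees length // num_devices. A quick illustration of
# that arithmetic in plain Python (no ngraph objects; values are made up):
axis_length = 128          # length of the parallel axis, e.g. the batch size
device_ids = [0, 1, 2, 3]  # hypothetical op.from_id
assert axis_length % len(device_ids) == 0, \
    '{} can not be equally divided by {}'.format(axis_length, len(device_ids))
per_device_length = axis_length // len(device_ids)
assert per_device_length == 32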
def do_pass(self, ops, **kwargs):
    ops = OrderedSet(op.forwarded for op in ops)

    for op in reversed(Op.ordered_ops(ops)):
        if op.metadata.get('marker') == 'gather':
            # op is GatherRecvOp
            if self.parallel_axes is None:
                a = op.metadata['parallel']
                assert a.length % len(op.from_id) == 0, \
                    '{} can not be equally divided by {}'.format(a, len(op.from_id))
                self.parallel_axes = make_axis(
                    name=a.name,
                    length=a.length // len(op.from_id),
                    docstring='HeTr parallel axis')
            gather_send_op = op.send_nodes[0]

            # clone nodes for each device_id
            replaced_send_ops = OrderedSet()
            new_gather_send_nodes = OrderedSet()
            for i, id in enumerate(op.from_id):
                new_gather_send_op, new_sends, replaced_sends = clone_graph(
                    root=gather_send_op,
                    clone_id=id,
                    shared_queues_idx=i,
                    parallel_axis=self.parallel_axes,
                    num_clones=len(op.from_id))
                new_gather_send_nodes.add(new_gather_send_op)

                new_sends.add(new_gather_send_op)
                for o in new_sends:
                    self.send_nodes.add(o)

                replaced_send_ops |= replaced_sends

            op.send_nodes = new_gather_send_nodes

            replaced_send_ops.add(gather_send_op)
            for o in replaced_send_ops:
                self.send_nodes.remove(o)
def do_pass(self, ops, transformer):
    ops = OrderedSet(op.forwarded for op in ops)

    def set_new_axes(root, num_devices):
        visit = self.do_traversal(root)
        self.new_axes = calculate_new_axes(root.axes, self.parallel_axis,
                                           num_devices, False)
        while visit:
            node = visit.pop()
            if hasattr(node, 'axes'):
                node._TensorOp__axes = self.new_axes

    # Start traversal from the top to the bottom
    for op in reversed(Op.ordered_ops(ops)):
        args = list()
        for arg in op.args:
            if 'marker' in arg.metadata:
                if arg.metadata['marker'] == 'gather':
                    self.parallel_axis = arg.metadata['parallel']
                    set_new_axes(arg.send_node(), len(arg.from_id))

                    for d in range(1, len(arg.from_id)):
                        if d == (len(arg.from_id) - 1):
                            self.new_axes = calculate_new_axes(
                                arg.axes, self.parallel_axis,
                                len(arg.from_id), True)

                        nodes = self.do_traversal(arg.send_node())
                        self.clone_nodes(nodes, arg.from_id[d],
                                         self.new_axes,
                                         self.scatter_shared_queues[d],
                                         self.gather_shared_queues[d])
            args.append(arg)

        if isinstance(op.args, tuple):
            op._Op__args = tuple(args)
        else:
            op.args(args)
def do_pass(self, ops):
    return len(Op.ordered_ops(ops))
def clone_graph(root, clone_id, shared_queues_idx, parallel_axis, num_clones):
    """
    Clone a graph using serde (serialization).

    Arguments:
        root: root of the graph to clone (e.g. a GatherSendOp).
        clone_id: id of the device the clone is created for.
        shared_queues_idx: index into the shared queues of cloned comm ops.
        parallel_axis: axis to split across the clones.
        num_clones: total number of clones.

    Returns:
        new_root of the cloned graph, the new send nodes, and the send nodes
        they replace.
    """
    # clone nodes with GatherSendOp as root using serde
    ser_cloned_nodes = deserialize_graph(serialize_graph([root]))
    new_root = next((o for o in ser_cloned_nodes if o.uuid == root.uuid), None)

    orig_ops = {op.uuid: op for op in Op.ordered_ops([root])}
    # Prune ops that are not control_deps of new_gather_send_op;
    # deserialize includes extra referenced nodes.
    cloned_graph = Op.ordered_ops([new_root])

    new_send_nodes = OrderedSet()
    replaced_send_nodes = OrderedSet()

    # update newly cloned op metadata, generate new UUIDs
    for op in cloned_graph:
        cloned_ops = orig_ops[op.uuid].metadata.get('clones')
        if cloned_ops is None or cloned_ops.get(str(clone_id)) is None:
            op.metadata['transformer'] = op.metadata['device'] + str(clone_id)
            op.metadata['device_id'] = str(clone_id)

            if isinstance(op,
                          (ScatterRecvOp, GatherSendOp, AllReduceOp, BroadcastRecvOp)):
                op._shared_queues = orig_ops[op.uuid]._shared_queues
                op.idx = shared_queues_idx
                if isinstance(op, (ScatterRecvOp, BroadcastRecvOp)):
                    op._send_node = orig_ops[op.uuid].send_node()
            elif isinstance(op, (CPUQueueRecvOp, GPUQueueRecvOp)):
                # Cloning a recv node means we need a broadcast, so simulate one by
                # adding an additional sender with the same input data as the
                # original sender.
                send_op = CPUQueueSendOp(orig_ops[op.uuid].send_node().args[0])
                op._queue = send_op.queue
                op._send_node = send_op
                new_send_nodes.add(send_op)
                replaced_send_nodes.add(orig_ops[op.uuid].send_node())

            if hasattr(op, 'reduction_axes') and parallel_axis in op.reduction_axes:
                op.reduction_axes = set_parallel_axes(op.reduction_axes,
                                                      parallel_axis)

            if getattr(op, 'axes', None) is not None \
                    and parallel_axis in Axes.as_flattened_list(op.axes):
                op._axes = set_parallel_axes(op.axes, parallel_axis)

            if isinstance(op, DotOp):
                if parallel_axis in op.x_out_axes:
                    op.x_out_axes = set_parallel_axes(op.x_out_axes,
                                                      parallel_axis)
                elif parallel_axis in op.y_out_axes:
                    op.y_out_axes = set_parallel_axes(op.y_out_axes,
                                                      parallel_axis)
                else:
                    raise ValueError("Missing parallel_axis in Op's "
                                     "x_out_axes or y_out_axes")

            if isinstance(op, TensorValueOp) and parallel_axis in op.tensor.axes:
                op.tensor._axes = set_parallel_axes(op.tensor.axes,
                                                    parallel_axis)

            # Point args at already-created clones of the original ops.
            args_list = list(op.args)
            for arg_idx, arg_op in enumerate(args_list):
                if arg_op.uuid in orig_ops.keys():
                    if orig_ops[arg_op.uuid].metadata.get('clones') and \
                            orig_ops[arg_op.uuid].metadata['clones'].get(str(clone_id)):
                        args_list[arg_idx] = \
                            orig_ops[arg_op.uuid].metadata['clones'].get(str(clone_id))
            op.invalidate_property_cache('all_deps')
            op._args = tuple(args_list)

            # Record this clone on the original op so later lookups can reuse it.
            if op != new_root:
                if orig_ops[op.uuid].metadata.get('clones') is None:
                    orig_ops[op.uuid].metadata['clones'] = dict()
                orig_ops[op.uuid].metadata['clones'][str(clone_id)] = op

        op.uuid = uuid.uuid4()

    return new_root, new_send_nodes, replaced_send_nodes
def clone_graph(root, clone_id, parallel_axis):
    """
    Clone a graph using serde (serialization).

    Arguments:
        root: root of the graph to clone (e.g. a GatherSendOp).
        clone_id: id of the device the clone is created for.
        parallel_axis: axis to split across the clones.

    Returns:
        new_root of the cloned graph, the new send nodes, and the send nodes
        they replace.
    """
    # clone nodes with GatherSendOp as root using serde
    ser_cloned_nodes = deserialize_graph(serialize_graph([root]))
    new_root = next((o for o in ser_cloned_nodes if o.uuid == root.uuid), None)

    orig_ops = {op.uuid: op for op in Op.ordered_ops([root])}
    cloned_graph = Op.ordered_ops([new_root])

    new_send_nodes = OrderedSet()
    replaced_send_nodes = OrderedSet()

    # update newly cloned op metadata, generate new UUIDs
    for op in cloned_graph:
        cloned_ops = orig_ops[op.uuid].metadata.get('clones')
        if cloned_ops is None or cloned_ops.get(str(clone_id)) is None:
            op.metadata['transformer'] = op.metadata['device'] + str(clone_id)
            op.metadata['device_id'] = str(clone_id)

            if isinstance(op,
                          (ScatterRecvOp, GatherSendOp, AllReduceOp, BroadcastRecvOp)):
                # for gpu communication op buffer
                op.idx = int(clone_id)
                if isinstance(op, (ScatterRecvOp, BroadcastRecvOp)):
                    op._send_node = orig_ops[op.uuid].send_node()

            if hasattr(op, 'reduction_axes') and parallel_axis in op.reduction_axes:
                op.reduction_axes = set_parallel_axes(op.reduction_axes,
                                                      parallel_axis)

            if getattr(op, 'axes', None) is not None \
                    and parallel_axis in Axes.as_flattened_list(op.axes):
                op._axes = set_parallel_axes(op.axes, parallel_axis)

            if isinstance(op, DotOp):
                if parallel_axis in op.x_out_axes:
                    op.x_out_axes = set_parallel_axes(op.x_out_axes,
                                                      parallel_axis)
                elif parallel_axis in op.y_out_axes:
                    op.y_out_axes = set_parallel_axes(op.y_out_axes,
                                                      parallel_axis)
                else:
                    raise ValueError("Missing parallel_axis in Op's "
                                     "x_out_axes or y_out_axes")

            if isinstance(op, TensorValueOp) and parallel_axis in op.tensor.axes:
                op.tensor._axes = set_parallel_axes(op.tensor.axes,
                                                    parallel_axis)

            # Point args at already-created clones of the original ops.
            args_list = list(op.args)
            for arg_idx, arg_op in enumerate(args_list):
                if arg_op.uuid in orig_ops.keys():
                    if orig_ops[arg_op.uuid].metadata.get('clones') and \
                            orig_ops[arg_op.uuid].metadata['clones'].get(str(clone_id)):
                        args_list[arg_idx] = \
                            orig_ops[arg_op.uuid].metadata['clones'].get(str(clone_id))
            op.invalidate_property_cache('all_deps')
            op._args = tuple(args_list)

            # Record this clone on the original op so later lookups can reuse it.
            if op != new_root:
                if orig_ops[op.uuid].metadata.get('clones') is None:
                    orig_ops[op.uuid].metadata['clones'] = dict()
                orig_ops[op.uuid].metadata['clones'][str(clone_id)] = op

        op.uuid = uuid.uuid4()

    # create new uuids for all the ops that have references to the new root
    for _op in Op.all_op_references([new_root]):
        _op.uuid = uuid.uuid4()

    return new_root, new_send_nodes, replaced_send_nodes