def update_parallel_axis(root, parallel_axis):
    """
    Rewrite the axes that contain ``parallel_axis`` on every op under ``root``.

    The ops are visited in the order produced by ``Op.ordered_ops([root])``.
    For each op, every axes collection that contains ``parallel_axis`` is
    replaced by the result of ``set_parallel_axes`` on that collection:
    ``reduction_axes`` (when the op has that attribute), the op's own axes
    (written through ``_axes``), the ``x_out_axes`` or ``y_out_axes`` of a
    ``DotOp``, and the axes of the tensor wrapped by a ``TensorValueOp``.

    Arguments:
        root: Op whose graph is updated in place.
        parallel_axis: Axis to look for in the ops' axes.

    Raises:
        ValueError: A ``DotOp`` has ``parallel_axis`` in its axes but in
            neither ``x_out_axes`` nor ``y_out_axes``.
    """
    for node in Op.ordered_ops([root]):
        if hasattr(node, 'reduction_axes'):
            if parallel_axis in node.reduction_axes:
                node.reduction_axes = set_parallel_axes(node.reduction_axes, parallel_axis)

        if getattr(node, 'axes', None) is not None:
            if parallel_axis in Axes.as_flattened_list(node.axes):
                node._axes = set_parallel_axes(node.axes, parallel_axis)
                if isinstance(node, DotOp):
                    # Only one side of the dot is rewritten; x takes precedence over y.
                    if parallel_axis in node.x_out_axes:
                        node.x_out_axes = set_parallel_axes(node.x_out_axes, parallel_axis)
                    elif parallel_axis not in node.y_out_axes:
                        raise ValueError("Missing parallel_axis in Op's "
                                         "x_out_axes or y_out_axes")
                    else:
                        node.y_out_axes = set_parallel_axes(node.y_out_axes, parallel_axis)

        if isinstance(node, TensorValueOp):
            if parallel_axis in node.tensor.axes:
                node.tensor._axes = set_parallel_axes(node.tensor.axes, parallel_axis)
def clone_graph(root, clone_id, parallel_axis):
    """
    Clone the graph under ``root`` with serde (serialization round trip).

    The graph is serialized and deserialized, and the deserialized op that
    carries ``root``'s UUID becomes the new root. Every cloned op whose
    original has no clone registered yet under ``str(clone_id)`` is then
    updated in place:

    * ``metadata['transformer']`` becomes ``metadata['device'] + str(clone_id)``
      and ``metadata['device_id']`` becomes ``str(clone_id)``;
    * communication ops get ``idx = int(clone_id)``, and scatter/broadcast
      receivers are pointed at the original op's send node;
    * every axes collection containing ``parallel_axis`` is replaced by the
      result of ``set_parallel_axes``;
    * args whose original op already has a clone registered under this
      clone id are replaced by that clone;
    * the op is registered in its original's ``metadata['clones']`` (the new
      root is not registered) and receives a fresh UUID.

    Finally every op returned by ``Op.all_op_references([new_root])`` also
    receives a fresh UUID.

    NOTE: if no deserialized op carries ``root``'s UUID, ``new_root`` is
    ``None`` and is handed to ``Op.ordered_ops`` unchanged.

    Arguments:
        root: Op at the root of the graph to clone.
        clone_id: Identifier of this clone; used as ``str(clone_id)`` for
            metadata and registry keys and as ``int(clone_id)`` for ``idx``.
        parallel_axis: Axis looked up in the cloned ops' axes.

    Returns:
        Tuple ``(new_root, new_send_nodes, replaced_send_nodes)``. The two
        ``OrderedSet`` objects are never populated by this function and are
        returned empty.

    Raises:
        ValueError: A cloned ``DotOp`` has ``parallel_axis`` in its axes but
            in neither ``x_out_axes`` nor ``y_out_axes``.
    """
    # clone nodes with GatherSendOp as root using serde
    ser_cloned_nodes = deserialize_graph(serialize_graph([root]))
    new_root = next((o for o in ser_cloned_nodes if o.uuid == root.uuid), None)

    orig_ops = {op.uuid: op for op in Op.ordered_ops([root])}
    cloned_graph = Op.ordered_ops([new_root])

    new_send_nodes = OrderedSet()
    replaced_send_nodes = OrderedSet()

    # Metadata and the clone registry are keyed by the string form of the id.
    clone_key = str(clone_id)

    # update newly cloned op metadata, generate new UUIDs
    for op in cloned_graph:
        # The clone still carries the original's UUID at this point.
        orig_op = orig_ops[op.uuid]
        cloned_ops = orig_op.metadata.get('clones')
        if cloned_ops is not None and cloned_ops.get(clone_key) is not None:
            # A clone for this id is already registered; leave this op untouched.
            continue

        op.metadata['transformer'] = op.metadata['device'] + clone_key
        op.metadata['device_id'] = clone_key

        if isinstance(op, (ScatterRecvOp, GatherSendOp, AllReduceOp, BroadcastRecvOp)):
            # for gpu communication op buffer
            op.idx = int(clone_id)
            if isinstance(op, (ScatterRecvOp, BroadcastRecvOp)):
                op._send_node = orig_op.send_node()

        if hasattr(op, 'reduction_axes') and parallel_axis in op.reduction_axes:
            op.reduction_axes = set_parallel_axes(op.reduction_axes, parallel_axis)

        if getattr(op, 'axes', None) is not None \
                and parallel_axis in Axes.as_flattened_list(op.axes):
            op._axes = set_parallel_axes(op.axes, parallel_axis)
            if isinstance(op, DotOp):
                # Only one side of the dot is rewritten; x takes precedence over y.
                if parallel_axis in op.x_out_axes:
                    op.x_out_axes = set_parallel_axes(op.x_out_axes, parallel_axis)
                elif parallel_axis in op.y_out_axes:
                    op.y_out_axes = set_parallel_axes(op.y_out_axes, parallel_axis)
                else:
                    raise ValueError("Missing parallel_axis in Op's "
                                     "x_out_axes or y_out_axes")

        if isinstance(op, TensorValueOp) and parallel_axis in op.tensor.axes:
            op.tensor._axes = set_parallel_axes(op.tensor.axes, parallel_axis)

        # Rewire args to clones already registered under this clone id.
        args_list = list(op.args)
        for arg_idx, arg_op in enumerate(args_list):
            if arg_op.uuid not in orig_ops:
                continue
            arg_clones = orig_ops[arg_op.uuid].metadata.get('clones')
            registered = arg_clones.get(clone_key) if arg_clones else None
            if registered:
                args_list[arg_idx] = registered

        op.invalidate_property_cache('all_deps')
        op._args = tuple(args_list)

        if op != new_root:
            if orig_op.metadata.get('clones') is None:
                orig_op.metadata['clones'] = dict()
            orig_op.metadata['clones'][clone_key] = op

        # Must stay last: everything above looks the original up by the old UUID.
        op.uuid = uuid.uuid4()

    # create new uuids for all the ops that have references to the new root
    for _op in Op.all_op_references([new_root]):
        _op.uuid = uuid.uuid4()

    return new_root, new_send_nodes, replaced_send_nodes
def test_calculate_new_axes_null_parallel_axis():
    """With ``parallel_axis=None`` the returned axes keep the original full lengths."""
    result_axes = set_parallel_axes(parallel_axis=None, axes=axes)
    # A null parallel axis must leave the lengths of the axes untouched.
    assert result_axes.full_lengths == axes.full_lengths
def test_calculate_new_axes_null_axes():
    """Passing ``axes=None`` to ``set_parallel_axes`` must raise ``TypeError``."""
    missing_axes = None
    with pytest.raises(TypeError):
        set_parallel_axes(axes=missing_axes, parallel_axis=ax_B)
def test_calculate_new_axes_single_device():
    """
    Parallelizing over ``ax_B`` must not change the full lengths of the axes.

    NOTE(review): the test name suggests a single-device setup, which is
    presumably why the lengths are expected to stay the same -- confirm
    against the fixture that defines ``axes`` and ``ax_B``.
    """
    result_axes = set_parallel_axes(parallel_axis=ax_B, axes=axes)
    assert result_axes.full_lengths == axes.full_lengths
def clone_graph(root, clone_id, shared_queues_idx, parallel_axis, num_clones):
    """
    Clone the graph under ``root`` with serde (serialization round trip).

    The graph is serialized and deserialized, and the deserialized op that
    carries ``root``'s UUID becomes the new root. Every cloned op whose
    original has no clone registered yet under ``str(clone_id)`` is then
    updated in place:

    * ``metadata['transformer']`` becomes ``metadata['device'] + str(clone_id)``
      and ``metadata['device_id']`` becomes ``str(clone_id)``;
    * scatter/gather/all-reduce/broadcast ops share the original op's
      ``_shared_queues`` and get ``idx = shared_queues_idx``; scatter and
      broadcast receivers are also pointed at the original op's send node;
    * queue receive ops get a new ``CPUQueueSendOp`` fed by the first arg of
      the original send node; the new sender is added to ``new_send_nodes``
      and the original sender to ``replaced_send_nodes``;
    * every axes collection containing ``parallel_axis`` is replaced by the
      result of ``set_parallel_axes``;
    * args whose original op already has a clone registered under this
      clone id are replaced by that clone;
    * the op is registered in its original's ``metadata['clones']`` (the new
      root is not registered) and receives a fresh UUID.

    NOTE: if no deserialized op carries ``root``'s UUID, ``new_root`` is
    ``None`` and is handed to ``Op.ordered_ops`` unchanged.

    Arguments:
        root: Op at the root of the graph to clone.
        clone_id: Identifier of this clone; used as ``str(clone_id)`` for
            metadata and registry keys.
        shared_queues_idx: Value stored in ``idx`` of cloned communication ops.
        parallel_axis: Axis looked up in the cloned ops' axes.
        num_clones: Not used by this function.

    Returns:
        Tuple ``(new_root, new_send_nodes, replaced_send_nodes)`` where the
        last two are ``OrderedSet`` objects filled while cloning queue
        receive ops.

    Raises:
        ValueError: A cloned ``DotOp`` has ``parallel_axis`` in its axes but
            in neither ``x_out_axes`` nor ``y_out_axes``.
    """
    # clone nodes with GatherSendOp as root using serde
    ser_cloned_nodes = deserialize_graph(serialize_graph([root]))
    new_root = next((o for o in ser_cloned_nodes if o.uuid == root.uuid), None)

    orig_ops = {op.uuid: op for op in Op.ordered_ops([root])}
    # Prune ops that are not control_deps of new_gather_send_op
    # deserialize includes extra referenced nodes
    cloned_graph = Op.ordered_ops([new_root])

    new_send_nodes = OrderedSet()
    replaced_send_nodes = OrderedSet()

    # Metadata and the clone registry are keyed by the string form of the id.
    clone_key = str(clone_id)

    # update newly cloned op metadata, generate new UUIDs
    for op in cloned_graph:
        # The clone still carries the original's UUID at this point.
        orig_op = orig_ops[op.uuid]
        cloned_ops = orig_op.metadata.get('clones')
        if cloned_ops is not None and cloned_ops.get(clone_key) is not None:
            # A clone for this id is already registered; leave this op untouched.
            continue

        op.metadata['transformer'] = op.metadata['device'] + clone_key
        op.metadata['device_id'] = clone_key

        if isinstance(op, (ScatterRecvOp, GatherSendOp, AllReduceOp, BroadcastRecvOp)):
            op._shared_queues = orig_op._shared_queues
            op.idx = shared_queues_idx
            if isinstance(op, (ScatterRecvOp, BroadcastRecvOp)):
                op._send_node = orig_op.send_node()
        elif isinstance(op, (CPUQueueRecvOp, GPUQueueRecvOp)):
            # Cloning a recv node means we need a broadcast, so simulate one by adding an
            # additional sender with the same input data as the original sender.
            send_op = CPUQueueSendOp(orig_op.send_node().args[0])
            op._queue = send_op.queue
            op._send_node = send_op
            new_send_nodes.add(send_op)
            replaced_send_nodes.add(orig_op.send_node())

        if hasattr(op, 'reduction_axes') and parallel_axis in op.reduction_axes:
            op.reduction_axes = set_parallel_axes(op.reduction_axes, parallel_axis)

        if getattr(op, 'axes', None) is not None \
                and parallel_axis in Axes.as_flattened_list(op.axes):
            op._axes = set_parallel_axes(op.axes, parallel_axis)
            if isinstance(op, DotOp):
                # Only one side of the dot is rewritten; x takes precedence over y.
                if parallel_axis in op.x_out_axes:
                    op.x_out_axes = set_parallel_axes(op.x_out_axes, parallel_axis)
                elif parallel_axis in op.y_out_axes:
                    op.y_out_axes = set_parallel_axes(op.y_out_axes, parallel_axis)
                else:
                    raise ValueError("Missing parallel_axis in Op's "
                                     "x_out_axes or y_out_axes")

        if isinstance(op, TensorValueOp) and parallel_axis in op.tensor.axes:
            op.tensor._axes = set_parallel_axes(op.tensor.axes, parallel_axis)

        # Rewire args to clones already registered under this clone id.
        args_list = list(op.args)
        for arg_idx, arg_op in enumerate(args_list):
            if arg_op.uuid not in orig_ops:
                continue
            arg_clones = orig_ops[arg_op.uuid].metadata.get('clones')
            registered = arg_clones.get(clone_key) if arg_clones else None
            if registered:
                args_list[arg_idx] = registered

        op.invalidate_property_cache('all_deps')
        op._args = tuple(args_list)

        if op != new_root:
            if orig_op.metadata.get('clones') is None:
                orig_op.metadata['clones'] = dict()
            orig_op.metadata['clones'][clone_key] = op

        # Must stay last: everything above looks the original up by the old UUID.
        op.uuid = uuid.uuid4()

    return new_root, new_send_nodes, replaced_send_nodes