def clone_graph(root, clone_id, shared_queues_idx, parallel_axis, num_clones):
    """
    Clone the subgraph rooted at ``root`` with serde (serialization).

    Every op reachable from ``root`` is copied by serializing and
    deserializing the graph. Each copy is then re-targeted at clone
    ``clone_id``: its ``transformer``/``device_id`` metadata is updated,
    communication ops are re-wired to the original's shared queues / send
    node, axes containing ``parallel_axis`` are shrunk with
    ``calculate_scatter_axes``, and a fresh UUID is assigned.

    Arguments:
        root: op at the root of the subgraph to clone.
        clone_id: identifier of this clone; its string form is appended to
            each op's ``metadata['device']`` to form ``metadata['transformer']``
            and stored as ``metadata['device_id']``.
        shared_queues_idx: value assigned to ``idx`` on cloned
            ScatterRecvOp / GatherSendOp ops.
        parallel_axis: axis that is split across clones.
        num_clones: number of clones ``parallel_axis`` is divided among.

    Returns:
        Tuple ``(new_root, new_send_nodes, replaced_send_nodes)``: the cloned
        root op, the CPUQueueSendOp nodes created to feed cloned queue-recv
        ops, and the original send nodes those new senders duplicate.

    Raises:
        ValueError: if the deserialized graph has no copy of ``root``, or if a
            cloned DotOp whose axes contain ``parallel_axis`` has it in neither
            ``x_out_axes`` nor ``y_out_axes``.
    """
    # clone nodes with GatherSendOp as root using serde
    ser_cloned_nodes = deserialize_graph(serialize_graph([root]))
    new_root = next((o for o in ser_cloned_nodes if o.uuid == root.uuid), None)
    if new_root is None:
        # Fail here with a clear message instead of letting
        # Op.ordered_ops([None]) raise an unrelated error further down.
        raise ValueError("Deserialized graph does not contain a copy of the root op")

    orig_ops = {op.uuid: op for op in Op.ordered_ops([root])}

    # Prune ops that are not control_deps of new_gather_send_op;
    # deserialize includes extra referenced nodes.
    cloned_graph = Op.ordered_ops([new_root])

    new_send_nodes = OrderedSet()
    replaced_send_nodes = OrderedSet()
    clone_key = str(clone_id)

    # update newly cloned op metadata, generate new UUIDs
    for op in cloned_graph:
        op.metadata['transformer'] = op.metadata['device'] + clone_key
        op.metadata['device_id'] = clone_key

        # Until op.uuid is reassigned at the end of this iteration, the copy
        # still carries the original's UUID, so orig_ops[op.uuid] is the op
        # it was cloned from.
        if isinstance(op, (ScatterRecvOp, GatherSendOp)):
            orig_op = orig_ops[op.uuid]
            op._shared_queues = orig_op._shared_queues
            op.idx = shared_queues_idx
            if isinstance(op, ScatterRecvOp):
                op._send_node = orig_op.send_node()
        elif isinstance(op, (CPUQueueRecvOp, GPUQueueRecvOp)):
            # Cloning a recv node means we need a broadcast, so simulate one by adding an
            # additional sender with the same input data as the original sender.
            # TODO replace with real broadcast #1398 #1399
            orig_send = orig_ops[op.uuid].send_node()
            send_op = CPUQueueSendOp(orig_send.args[0])
            op._queue = send_op.queue
            op._send_node = send_op
            new_send_nodes.add(send_op)
            replaced_send_nodes.add(orig_send)

        if hasattr(op, '_axes') and parallel_axis in op._axes:
            op._axes = calculate_scatter_axes(op.axes, parallel_axis, num_clones)
            # TODO: Revisit to handle axes updation better. Github Ticket #1355
            if isinstance(op, DotOp):
                if parallel_axis in op.x_out_axes:
                    op.x_out_axes = calculate_scatter_axes(
                        op.x_out_axes, parallel_axis, num_clones)
                elif parallel_axis in op.y_out_axes:
                    op.y_out_axes = calculate_scatter_axes(
                        op.y_out_axes, parallel_axis, num_clones)
                else:
                    raise ValueError(
                        "Missing parallel_axis in Op's x_out_axes or y_out_axes"
                    )

        op.uuid = uuid.uuid4()

    return new_root, new_send_nodes, replaced_send_nodes
def test_calculate_new_axes_no_remainder(axis, num):
    """Scattering an evenly divisible axis shrinks only that axis by ``num``.

    ``axis`` and ``num`` are presumably supplied by a parametrize decorator
    that is not visible here.
    """
    new_axes = calculate_scatter_axes(axes=axes, scatter_axis=axis, num_devices=num)
    # Floor division keeps the expected axis length an int; true division
    # would build an axis with a float length.
    expected_axes = ng.make_axes([
        a if a != axis else ng.make_axis(length=axis.length // num, name=a.name)
        for a in axes
    ])
    assert new_axes.full_lengths == expected_axes.full_lengths
def test_calculate_new_axes_null_parallel_axis():
    """With no parallel axis, no axis is split, so lengths stay unchanged."""
    call_kwargs = {'axes': axes, 'scatter_axis': None, 'num_devices': 1}
    scattered = calculate_scatter_axes(**call_kwargs)
    # A null parallel axis must yield axes with the same lengths as the input.
    assert scattered.full_lengths == axes.full_lengths
def test_calculate_new_axes_null_axes():
    """Passing ``None`` as the axes to scatter must raise TypeError."""
    bad_kwargs = dict(axes=None, scatter_axis=ax_B, num_devices=2)
    with pytest.raises(TypeError):
        calculate_scatter_axes(**bad_kwargs)
def test_calculate_new_axes_zero_devices():
    """Scattering across zero devices must raise ZeroDivisionError."""
    expected_error = ZeroDivisionError
    with pytest.raises(expected_error):
        calculate_scatter_axes(num_devices=0, scatter_axis=ax_B, axes=axes)
def tests_calculate_new_axes_has_remainder(axis, num):
    """Scattering an axis whose length ``num`` does not divide must fail.

    ``axis`` and ``num`` are presumably supplied by a parametrize decorator
    that is not visible here.
    """
    expected_error = AssertionError
    with pytest.raises(expected_error):
        calculate_scatter_axes(num_devices=num, scatter_axis=axis, axes=axes)
def test_calculate_new_axes_single_device():
    """Scattering across a single device leaves every axis length unchanged."""
    single = calculate_scatter_axes(num_devices=1, scatter_axis=ax_B, axes=axes)
    assert single.full_lengths == axes.full_lengths
def clone_graph(root, clone_id, shared_queues_idx, parallel_axis, num_clones):
    """
    Clone the subgraph rooted at ``root`` with serde (serialization).

    Every op reachable from ``root`` is copied by serializing and
    deserializing the graph. Each copy is then re-targeted at clone
    ``clone_id``: its ``transformer``/``device_id`` metadata is updated,
    communication ops are re-wired to the original's shared queues / send
    node, axes, reduction axes and TensorValueOp tensor axes containing
    ``parallel_axis`` are shrunk with ``calculate_scatter_axes``, and a fresh
    UUID is assigned.

    Every cloned op except the root is recorded in its original op's
    ``metadata['clones'][str(clone_id)]``. An op that already has such a
    record (e.g. from an earlier call for another root) is not re-cloned.
    Instead, the args of ops cloned here are re-pointed at the recorded
    clone, so that clone is shared.

    Arguments:
        root: op at the root of the subgraph to clone.
        clone_id: identifier of this clone; its string form is appended to
            each op's ``metadata['device']`` to form ``metadata['transformer']``,
            stored as ``metadata['device_id']``, and used as the key in
            ``metadata['clones']``.
        shared_queues_idx: value assigned to ``idx`` on cloned ScatterRecvOp /
            GatherSendOp / AllReduceOp / BroadcastRecvOp ops.
        parallel_axis: axis that is split across clones.
        num_clones: number of clones ``parallel_axis`` is divided among.

    Returns:
        Tuple ``(new_root, new_send_nodes, replaced_send_nodes)``: the cloned
        root op, the CPUQueueSendOp nodes created to feed cloned queue-recv
        ops, and the original send nodes those new senders duplicate.

    Raises:
        ValueError: if the deserialized graph has no copy of ``root``, or if a
            cloned DotOp whose axes contain ``parallel_axis`` has it in neither
            ``x_out_axes`` nor ``y_out_axes``.
    """
    # clone nodes with GatherSendOp as root using serde
    ser_cloned_nodes = deserialize_graph(serialize_graph([root]))
    new_root = next((o for o in ser_cloned_nodes if o.uuid == root.uuid), None)
    if new_root is None:
        # Fail here with a clear message instead of letting
        # Op.ordered_ops([None]) raise an unrelated error further down.
        raise ValueError("Deserialized graph does not contain a copy of the root op")

    orig_ops = {op.uuid: op for op in Op.ordered_ops([root])}

    # Prune ops that are not control_deps of new_gather_send_op;
    # deserialize includes extra referenced nodes.
    cloned_graph = Op.ordered_ops([new_root])

    new_send_nodes = OrderedSet()
    replaced_send_nodes = OrderedSet()
    clone_key = str(clone_id)

    def existing_clone(orig_op):
        """Return the clone recorded for ``clone_key`` on ``orig_op``, or None."""
        clones = orig_op.metadata.get('clones')
        return clones.get(clone_key) if clones else None

    # update newly cloned op metadata, generate new UUIDs
    for op in cloned_graph:
        # Until op.uuid is reassigned at the end of this iteration, the copy
        # still carries the original's UUID, so orig_ops[op.uuid] is the op
        # it was cloned from.
        orig_op = orig_ops[op.uuid]
        if existing_clone(orig_op) is not None:
            # Already cloned for this clone_id: leave this copy untouched,
            # including its UUID, so the arg-rewiring below can still map it
            # back to the original and substitute the recorded clone.
            continue

        op.metadata['transformer'] = op.metadata['device'] + clone_key
        op.metadata['device_id'] = clone_key

        if isinstance(op, (ScatterRecvOp, GatherSendOp, AllReduceOp, BroadcastRecvOp)):
            op._shared_queues = orig_op._shared_queues
            op.idx = shared_queues_idx
            if isinstance(op, (ScatterRecvOp, BroadcastRecvOp)):
                op._send_node = orig_op.send_node()
        elif isinstance(op, (CPUQueueRecvOp, GPUQueueRecvOp)):
            # Cloning a recv node means we need a broadcast, so simulate one by adding an
            # additional sender with the same input data as the original sender.
            orig_send = orig_op.send_node()
            send_op = CPUQueueSendOp(orig_send.args[0])
            op._queue = send_op.queue
            op._send_node = send_op
            new_send_nodes.add(send_op)
            replaced_send_nodes.add(orig_send)

        if hasattr(op, '_axes') and parallel_axis in op._axes:
            op._axes = calculate_scatter_axes(op.axes, parallel_axis, num_clones)
            # TODO: Revisit to handle axes updation better. Github Ticket #1355
            if isinstance(op, DotOp):
                if parallel_axis in op.x_out_axes:
                    op.x_out_axes = calculate_scatter_axes(
                        op.x_out_axes, parallel_axis, num_clones)
                elif parallel_axis in op.y_out_axes:
                    op.y_out_axes = calculate_scatter_axes(
                        op.y_out_axes, parallel_axis, num_clones)
                else:
                    raise ValueError(
                        "Missing parallel_axis in Op's x_out_axes or y_out_axes"
                    )

        if hasattr(op, 'reduction_axes') and parallel_axis in op.reduction_axes:
            op.reduction_axes = calculate_scatter_axes(
                op.reduction_axes, parallel_axis, num_clones)

        if isinstance(op, TensorValueOp) and parallel_axis in op.tensor.axes:
            op.tensor._axes = calculate_scatter_axes(
                op.tensor.axes, parallel_axis, num_clones)

        # Swap in recorded clones for args whose originals were already cloned
        # for this clone_id. Args processed earlier in this call already carry
        # fresh UUIDs, so they are not found in orig_ops and are kept as-is.
        new_args = []
        for arg_op in op.args:
            orig_arg = orig_ops.get(arg_op.uuid)
            shared = existing_clone(orig_arg) if orig_arg is not None else None
            new_args.append(arg_op if shared is None else shared)
        op.invalidate_property_cache('all_deps')
        op._args = tuple(new_args)

        # Record this clone on the original so later clones can share it.
        # The root is deliberately not recorded.
        if op is not new_root:
            if orig_op.metadata.get('clones') is None:
                orig_op.metadata['clones'] = dict()
            orig_op.metadata['clones'][clone_key] = op

        op.uuid = uuid.uuid4()

    return new_root, new_send_nodes, replaced_send_nodes