Ejemplo n.º 1
0
def clone_graph(root, clone_id, shared_queues_idx, parallel_axis, num_clones):
    """
    Clone the graph rooted at ``root`` via serde (serialization) and retarget
    the copy at clone ``clone_id``.

    Arguments:
        root: root op of the subgraph to clone (expected to be a GatherSendOp
            in the intended usage).
        clone_id: id of this clone; ``str(clone_id)`` is appended to each op's
            'device' metadata to form its 'transformer' and is stored as its
            'device_id'.
        shared_queues_idx: value assigned to ``idx`` on cloned ScatterRecvOp /
            GatherSendOp ops.
        parallel_axis: axis split across clones; wherever it appears in an
            op's axes (or a DotOp's out axes) it is replaced using
            ``calculate_scatter_axes``.
        num_clones: number of clones ``parallel_axis`` is split across.

    Returns:
        (new_root, new_send_nodes, replaced_send_nodes): the cloned root op,
        the CPUQueueSendOps created for cloned queue-recv ops, and the
        original send nodes those new senders stand in for.

    Raises:
        ValueError: if the deserialized graph does not contain ``root``, or if
            a cloned DotOp has ``parallel_axis`` in neither ``x_out_axes`` nor
            ``y_out_axes``.
    """
    # A serialize/deserialize round trip copies every op reachable from root;
    # the copies keep the original UUIDs until they are reassigned below.
    ser_cloned_nodes = deserialize_graph(serialize_graph([root]))
    new_root = next((o for o in ser_cloned_nodes if o.uuid == root.uuid), None)
    if new_root is None:
        raise ValueError(
            "Cloned graph does not contain root op {}".format(root.uuid))

    # UUID -> original op, so each copy can reach the op it was cloned from.
    orig_ops = {op.uuid: op for op in Op.ordered_ops([root])}
    # Deserialization may include extra referenced nodes; walking from
    # new_root keeps only the ops the cloned root actually depends on.
    cloned_graph = Op.ordered_ops([new_root])

    new_send_nodes = OrderedSet()
    replaced_send_nodes = OrderedSet()
    device_id = str(clone_id)

    # Update newly cloned op metadata and generate new UUIDs.
    for op in cloned_graph:
        op.metadata['transformer'] = op.metadata['device'] + device_id
        op.metadata['device_id'] = device_id

        if isinstance(op, (ScatterRecvOp, GatherSendOp)):
            orig_op = orig_ops[op.uuid]
            # The clone reuses the original's queues; presumably idx selects
            # which shared queue this clone uses (TODO confirm).
            op._shared_queues = orig_op._shared_queues
            op.idx = shared_queues_idx
            if isinstance(op, ScatterRecvOp):
                op._send_node = orig_op.send_node()
        elif isinstance(op, (CPUQueueRecvOp, GPUQueueRecvOp)):
            # Cloning a recv node means we need a broadcast, so simulate one by adding an
            # additional sender with the same input data as the original sender.
            # TODO replace with real broadcast #1398 #1399
            orig_send = orig_ops[op.uuid].send_node()
            send_op = CPUQueueSendOp(orig_send.args[0])
            op._queue = send_op.queue
            op._send_node = send_op
            new_send_nodes.add(send_op)
            replaced_send_nodes.add(orig_send)

        if hasattr(op, '_axes') and parallel_axis in op._axes:
            op._axes = calculate_scatter_axes(op.axes, parallel_axis,
                                              num_clones)
            # TODO: Revisit to handle axes updation better. Github Ticket #1355
            if isinstance(op, DotOp):
                if parallel_axis in op.x_out_axes:
                    op.x_out_axes = calculate_scatter_axes(
                        op.x_out_axes, parallel_axis, num_clones)
                elif parallel_axis in op.y_out_axes:
                    op.y_out_axes = calculate_scatter_axes(
                        op.y_out_axes, parallel_axis, num_clones)
                else:
                    raise ValueError(
                        "Missing parallel_axis in Op's x_out_axes or y_out_axes"
                    )
        # Done with the original-op lookups; give the clone its own identity.
        op.uuid = uuid.uuid4()

    return new_root, new_send_nodes, replaced_send_nodes
Ejemplo n.º 2
0
def test_calculate_new_axes_no_remainder(axis, num):
    """
    Scattering ``axis`` over ``num`` devices divides that axis's length by
    ``num`` and keeps the other axes of ``axes`` as they are.

    ``axis`` and ``num`` are presumably supplied by a pytest parametrize
    decorator outside this view (TODO confirm), chosen so that
    ``axis.length`` is divisible by ``num``.
    """
    new_axes = calculate_scatter_axes(axes=axes,
                                      scatter_axis=axis,
                                      num_devices=num)
    # Integer division: the split is exact by construction ("no remainder"),
    # and axis lengths should stay integral instead of becoming floats.
    expected_axes = ng.make_axes([
        a if a != axis else ng.make_axis(length=axis.length // num, name=a.name)
        for a in axes
    ])
    assert new_axes.full_lengths == expected_axes.full_lengths
Ejemplo n.º 3
0
def test_calculate_new_axes_null_parallel_axis():
    """A ``None`` scatter axis must leave every axis length unchanged."""
    scattered_lengths = calculate_scatter_axes(
        axes=axes, scatter_axis=None, num_devices=1).full_lengths
    # Nothing is split, so the lengths must match the original axes.
    assert scattered_lengths == axes.full_lengths
Ejemplo n.º 4
0
def test_calculate_new_axes_null_axes():
    """Passing ``axes=None`` to calculate_scatter_axes raises TypeError."""
    bad_call_kwargs = dict(axes=None, scatter_axis=ax_B, num_devices=2)
    with pytest.raises(TypeError):
        calculate_scatter_axes(**bad_call_kwargs)
Ejemplo n.º 5
0
def test_calculate_new_axes_zero_devices():
    """Scattering across zero devices raises ZeroDivisionError."""
    bad_call_kwargs = dict(axes=axes, scatter_axis=ax_B, num_devices=0)
    with pytest.raises(ZeroDivisionError):
        calculate_scatter_axes(**bad_call_kwargs)
Ejemplo n.º 6
0
def tests_calculate_new_axes_has_remainder(axis, num):
    """
    An axis length that does not divide evenly by ``num`` must trip an
    AssertionError inside calculate_scatter_axes.

    ``axis``/``num`` are presumably supplied by a pytest parametrize decorator
    outside this view (TODO confirm). The ``tests_`` prefix differs from the
    sibling ``test_`` functions, but it still starts with ``test``, which
    pytest's default ``python_functions`` pattern collects.
    """
    call_kwargs = {'axes': axes, 'scatter_axis': axis, 'num_devices': num}
    with pytest.raises(AssertionError):
        calculate_scatter_axes(**call_kwargs)
Ejemplo n.º 7
0
def test_calculate_new_axes_single_device():
    """Scattering over a single device must not change any axis length."""
    scattered_lengths = calculate_scatter_axes(
        axes=axes, scatter_axis=ax_B, num_devices=1).full_lengths
    assert scattered_lengths == axes.full_lengths
Ejemplo n.º 8
0
def clone_graph(root, clone_id, shared_queues_idx, parallel_axis, num_clones):
    """
    Clone the graph rooted at ``root`` via serde (serialization) and retarget
    the copy at clone ``clone_id``, reusing clones made by earlier calls.

    Arguments:
        root: root op of the subgraph to clone (expected to be a GatherSendOp
            in the intended usage).
        clone_id: id of this clone; ``str(clone_id)`` is appended to each op's
            'device' metadata to form its 'transformer', stored as its
            'device_id', and used as the key in the originals'
            ``metadata['clones']`` dict.
        shared_queues_idx: value assigned to ``idx`` on cloned ScatterRecvOp /
            GatherSendOp / AllReduceOp / BroadcastRecvOp ops.
        parallel_axis: axis split across clones; it is replaced using
            ``calculate_scatter_axes`` in an op's axes, a DotOp's out axes,
            ``reduction_axes``, and a TensorValueOp's tensor axes.
        num_clones: number of clones ``parallel_axis`` is split across.

    Returns:
        (new_root, new_send_nodes, replaced_send_nodes): the cloned root op,
        the CPUQueueSendOps created for cloned queue-recv ops, and the
        original send nodes those new senders stand in for.

    Side effects:
        Every processed non-root clone is recorded on its ORIGINAL op as
        ``metadata['clones'][str(clone_id)]``. Original ops that already have
        such an entry are not cloned again; their consumers are rewired to the
        recorded clone instead.

    Raises:
        ValueError: if the deserialized graph does not contain ``root``, or if
            a cloned DotOp has ``parallel_axis`` in neither ``x_out_axes`` nor
            ``y_out_axes``.
    """
    # A serialize/deserialize round trip copies every op reachable from root;
    # the copies keep the original UUIDs until they are reassigned below.
    ser_cloned_nodes = deserialize_graph(serialize_graph([root]))
    new_root = next((o for o in ser_cloned_nodes if o.uuid == root.uuid), None)
    if new_root is None:
        raise ValueError(
            "Cloned graph does not contain root op {}".format(root.uuid))

    # UUID -> original op, so each copy can reach the op it was cloned from.
    orig_ops = {op.uuid: op for op in Op.ordered_ops([root])}
    # Deserialization may include extra referenced nodes; walking from
    # new_root keeps only the ops the cloned root actually depends on.
    cloned_graph = Op.ordered_ops([new_root])

    new_send_nodes = OrderedSet()
    replaced_send_nodes = OrderedSet()
    clone_key = str(clone_id)

    for op in cloned_graph:
        orig_op = orig_ops[op.uuid]
        existing_clones = orig_op.metadata.get('clones')
        if existing_clones is not None and \
                existing_clones.get(clone_key) is not None:
            # The original already has a clone for this clone_id (e.g. from
            # an earlier call). Leave this copy untouched, including its
            # original UUID, so the args rewrite below redirects its
            # consumers to the recorded clone.
            continue

        op.metadata['transformer'] = op.metadata['device'] + clone_key
        op.metadata['device_id'] = clone_key

        if isinstance(
                op,
            (ScatterRecvOp, GatherSendOp, AllReduceOp, BroadcastRecvOp)):
            # The clone reuses the original's queues; presumably idx selects
            # which shared queue this clone uses (TODO confirm).
            op._shared_queues = orig_op._shared_queues
            op.idx = shared_queues_idx
            if isinstance(op, (ScatterRecvOp, BroadcastRecvOp)):
                op._send_node = orig_op.send_node()
        elif isinstance(op, (CPUQueueRecvOp, GPUQueueRecvOp)):
            # Cloning a recv node means we need a broadcast, so simulate one by adding an
            # additional sender with the same input data as the original sender.
            orig_send = orig_op.send_node()
            send_op = CPUQueueSendOp(orig_send.args[0])
            op._queue = send_op.queue
            op._send_node = send_op
            new_send_nodes.add(send_op)
            replaced_send_nodes.add(orig_send)

        if hasattr(op, '_axes') and parallel_axis in op._axes:
            op._axes = calculate_scatter_axes(op.axes, parallel_axis,
                                              num_clones)
            # TODO: Revisit to handle axes updation better. Github Ticket #1355
            if isinstance(op, DotOp):
                if parallel_axis in op.x_out_axes:
                    op.x_out_axes = calculate_scatter_axes(
                        op.x_out_axes, parallel_axis, num_clones)
                elif parallel_axis in op.y_out_axes:
                    op.y_out_axes = calculate_scatter_axes(
                        op.y_out_axes, parallel_axis, num_clones)
                else:
                    raise ValueError(
                        "Missing parallel_axis in Op's x_out_axes or y_out_axes"
                    )

        if hasattr(op, 'reduction_axes') and \
                parallel_axis in op.reduction_axes:
            op.reduction_axes = calculate_scatter_axes(
                op.reduction_axes, parallel_axis, num_clones)

        if isinstance(op, TensorValueOp) and parallel_axis in op.tensor.axes:
            op.tensor._axes = calculate_scatter_axes(
                op.tensor.axes, parallel_axis, num_clones)

        # Args that were processed in this call already carry fresh UUIDs and
        # miss orig_ops. Args still holding an original UUID whose original
        # has a recorded clone for this clone_id are swapped for that clone.
        args_list = list(op.args)
        for arg_idx, arg_op in enumerate(args_list):
            orig_arg = orig_ops.get(arg_op.uuid)
            if orig_arg is None:
                continue
            arg_clones = orig_arg.metadata.get('clones')
            if arg_clones and arg_clones.get(clone_key):
                args_list[arg_idx] = arg_clones[clone_key]
        op.invalidate_property_cache('all_deps')
        op._args = tuple(args_list)

        # Record this clone on its original so later calls can reuse it.
        if op is not new_root:
            clones = orig_op.metadata.get('clones')
            if clones is None:
                clones = orig_op.metadata['clones'] = dict()
            clones[clone_key] = op

        # Done with the original-op lookups; give the clone its own identity.
        op.uuid = uuid.uuid4()

    return new_root, new_send_nodes, replaced_send_nodes