Example #1
0
def update_parallel_axis(root, parallel_axis):
    """
    Rewrite ``parallel_axis`` in place on every op of the graph under ``root``.

    Each op from ``Op.ordered_ops([root])`` is visited and, wherever
    ``parallel_axis`` is present, the containing axes are replaced by the
    result of ``set_parallel_axes``:

    * ``reduction_axes``, when the op has that attribute;
    * ``_axes``, when the op's ``axes`` is not None and its flattened form
      contains ``parallel_axis``; for a ``DotOp`` the first of
      ``x_out_axes`` / ``y_out_axes`` containing the axis is rewritten too;
    * ``tensor._axes``, when the op is a ``TensorValueOp``.

    Arguments:
        root: root op of the graph to update.
        parallel_axis: axis handed to ``set_parallel_axes``.

    Raises:
        ValueError: if a ``DotOp`` has ``parallel_axis`` in its axes but in
            neither ``x_out_axes`` nor ``y_out_axes``.
    """
    for node in Op.ordered_ops([root]):

        if hasattr(node, 'reduction_axes'):
            if parallel_axis in node.reduction_axes:
                node.reduction_axes = set_parallel_axes(
                    node.reduction_axes, parallel_axis)

        has_axes = getattr(node, 'axes', None) is not None
        if has_axes and parallel_axis in Axes.as_flattened_list(node.axes):
            node._axes = set_parallel_axes(node.axes, parallel_axis)
            if isinstance(node, DotOp):
                # Only the first matching side is rewritten: x before y.
                for attr_name in ('x_out_axes', 'y_out_axes'):
                    out_axes = getattr(node, attr_name)
                    if parallel_axis in out_axes:
                        setattr(node, attr_name,
                                set_parallel_axes(out_axes, parallel_axis))
                        break
                else:
                    raise ValueError("Missing parallel_axis in Op's "
                                     "x_out_axes or y_out_axes")

        if isinstance(node, TensorValueOp):
            if parallel_axis in node.tensor.axes:
                node.tensor._axes = set_parallel_axes(node.tensor.axes,
                                                      parallel_axis)
Example #2
0
def clone_graph(root, clone_id, parallel_axis):
    """
    Clone the graph rooted at ``root`` with serde (serialization).

    The graph is round-tripped through ``serialize_graph`` /
    ``deserialize_graph``. Every cloned op whose original has no clone
    registered under ``str(clone_id)`` is then updated in place: its
    ``transformer`` / ``device_id`` metadata is set, ``parallel_axis`` is
    rewritten through ``set_parallel_axes``, args whose originals already
    have a registered clone are redirected to that clone, the op is recorded
    in the original's ``metadata['clones']`` (the new root is not recorded)
    and it is given a fresh UUID.

    Arguments:
        root: root op of the graph to clone.
        clone_id: identifier of this clone; ``str(clone_id)`` keys the
            ``clones`` metadata and ``int(clone_id)`` becomes the ``idx`` of
            communication ops.
        parallel_axis: axis handed to ``set_parallel_axes``.

    Returns:
        tuple: ``(new_root, new_send_nodes, replaced_send_nodes)``. Nothing
        is added to the two ``OrderedSet`` objects in this function, so they
        are returned empty.

    Raises:
        ValueError: if no deserialized op carries the UUID of ``root``, or if
            a ``DotOp`` has ``parallel_axis`` in its axes but in neither
            ``x_out_axes`` nor ``y_out_axes``.
    """
    clone_key = str(clone_id)

    # clone nodes with GatherSendOp as root using serde
    ser_cloned_nodes = deserialize_graph(serialize_graph([root]))

    new_root = next((o for o in ser_cloned_nodes if o.uuid == root.uuid), None)
    if new_root is None:
        # Fail here with a clear message instead of passing None on to
        # Op.ordered_ops below.
        raise ValueError("Cloned graph does not contain an op with the "
                         "uuid of root")

    orig_ops = {op.uuid: op for op in Op.ordered_ops([root])}
    cloned_graph = Op.ordered_ops([new_root])

    new_send_nodes = OrderedSet()
    replaced_send_nodes = OrderedSet()

    # update newly cloned op metadata, generate new UUIDs
    for op in cloned_graph:
        # Look the original up now: op.uuid is replaced at the end of the
        # iteration and would no longer match afterwards.
        orig_op = orig_ops[op.uuid]
        cloned_ops = orig_op.metadata.get('clones')
        if cloned_ops is not None and cloned_ops.get(clone_key) is not None:
            # A clone for this clone_id is already registered; leave as is.
            continue

        op.metadata['transformer'] = op.metadata['device'] + clone_key
        op.metadata['device_id'] = clone_key

        if isinstance(
                op,
                (ScatterRecvOp, GatherSendOp, AllReduceOp, BroadcastRecvOp)):
            # for gpu communication op buffer
            op.idx = int(clone_id)
            if isinstance(op, (ScatterRecvOp, BroadcastRecvOp)):
                op._send_node = orig_op.send_node()

        if hasattr(op, 'reduction_axes') \
                and parallel_axis in op.reduction_axes:
            op.reduction_axes = set_parallel_axes(op.reduction_axes,
                                                  parallel_axis)

        if getattr(op, 'axes', None) is not None \
                and parallel_axis in Axes.as_flattened_list(op.axes):
            op._axes = set_parallel_axes(op.axes, parallel_axis)
            if isinstance(op, DotOp):
                if parallel_axis in op.x_out_axes:
                    op.x_out_axes = set_parallel_axes(
                        op.x_out_axes, parallel_axis)
                elif parallel_axis in op.y_out_axes:
                    op.y_out_axes = set_parallel_axes(
                        op.y_out_axes, parallel_axis)
                else:
                    raise ValueError("Missing parallel_axis in Op's "
                                     "x_out_axes or y_out_axes")

        if isinstance(op, TensorValueOp) \
                and parallel_axis in op.tensor.axes:
            op.tensor._axes = set_parallel_axes(op.tensor.axes,
                                                parallel_axis)

        # Redirect args whose original op already has a clone registered for
        # this clone_id. Args re-keyed earlier in this loop carry a new UUID
        # and therefore no longer match orig_ops.
        args_list = list(op.args)
        for arg_idx, arg_op in enumerate(args_list):
            if arg_op.uuid in orig_ops:
                arg_clones = orig_ops[arg_op.uuid].metadata.get('clones')
                if arg_clones:
                    arg_clone = arg_clones.get(clone_key)
                    if arg_clone:
                        args_list[arg_idx] = arg_clone

        op.invalidate_property_cache('all_deps')
        op._args = tuple(args_list)
        if op != new_root:
            if orig_op.metadata.get('clones') is None:
                orig_op.metadata['clones'] = dict()
            orig_op.metadata['clones'][clone_key] = op

        op.uuid = uuid.uuid4()

    # create new uuids for all the ops that have references to the new root
    for _op in Op.all_op_references([new_root]):
        _op.uuid = uuid.uuid4()

    return new_root, new_send_nodes, replaced_send_nodes
Example #3
0
def test_calculate_new_axes_null_parallel_axis():
    """A ``None`` parallel axis must yield axes with the original full lengths."""
    result_axes = set_parallel_axes(axes=axes, parallel_axis=None)
    actual_lengths = result_axes.full_lengths
    expected_lengths = axes.full_lengths
    assert actual_lengths == expected_lengths
Example #4
0
def test_calculate_new_axes_null_axes():
    """Calling ``set_parallel_axes`` with ``axes=None`` must raise ``TypeError``."""
    invalid_kwargs = dict(axes=None, parallel_axis=ax_B)
    with pytest.raises(TypeError):
        set_parallel_axes(**invalid_kwargs)
Example #5
0
def test_calculate_new_axes_single_device():
    """
    Parallelizing over ``ax_B`` must yield axes with the original full lengths.

    NOTE(review): going by the test name this presumably relies on a
    single-device setup -- confirm against the module fixtures.
    """
    result_axes = set_parallel_axes(axes=axes, parallel_axis=ax_B)
    actual_lengths = result_axes.full_lengths
    expected_lengths = axes.full_lengths
    assert actual_lengths == expected_lengths
Example #6
0
def clone_graph(root, clone_id, shared_queues_idx, parallel_axis, num_clones):
    """
    Clone the graph rooted at ``root`` with serde (serialization).

    The graph is round-tripped through ``serialize_graph`` /
    ``deserialize_graph``. Every cloned op whose original has no clone
    registered under ``str(clone_id)`` is then updated in place: its
    ``transformer`` / ``device_id`` metadata is set, communication ops are
    wired to the original's shared queues, queue recv ops get a new sender,
    ``parallel_axis`` is rewritten through ``set_parallel_axes``, args whose
    originals already have a registered clone are redirected to that clone,
    the op is recorded in the original's ``metadata['clones']`` (the new root
    is not recorded) and it is given a fresh UUID.

    Arguments:
        root: root op of the graph to clone.
        clone_id: identifier of this clone; ``str(clone_id)`` keys the
            ``clones`` metadata.
        shared_queues_idx: value stored as ``idx`` on cloned communication
            ops.
        parallel_axis: axis handed to ``set_parallel_axes``.
        num_clones: not used by this function.

    Returns:
        tuple: ``(new_root, new_send_nodes, replaced_send_nodes)`` where
        ``new_send_nodes`` holds the ``CPUQueueSendOp`` objects created for
        cloned queue recv ops and ``replaced_send_nodes`` holds the send
        nodes of the corresponding original recv ops.

    Raises:
        ValueError: if no deserialized op carries the UUID of ``root``, or if
            a ``DotOp`` has ``parallel_axis`` in its axes but in neither
            ``x_out_axes`` nor ``y_out_axes``.
    """
    clone_key = str(clone_id)

    # clone nodes with GatherSendOp as root using serde
    ser_cloned_nodes = deserialize_graph(serialize_graph([root]))
    new_root = next((o for o in ser_cloned_nodes if o.uuid == root.uuid), None)
    if new_root is None:
        # Fail here with a clear message instead of passing None on to
        # Op.ordered_ops below.
        raise ValueError("Cloned graph does not contain an op with the "
                         "uuid of root")

    orig_ops = {op.uuid: op for op in Op.ordered_ops([root])}
    # Prune ops that are not control_deps of new_gather_send_op
    # deserialize includes extra referenced nodes
    cloned_graph = Op.ordered_ops([new_root])

    new_send_nodes = OrderedSet()
    replaced_send_nodes = OrderedSet()

    # update newly cloned op metadata, generate new UUIDs
    for op in cloned_graph:
        # Look the original up now: op.uuid is replaced at the end of the
        # iteration and would no longer match afterwards.
        orig_op = orig_ops[op.uuid]
        cloned_ops = orig_op.metadata.get('clones')
        if cloned_ops is not None and cloned_ops.get(clone_key) is not None:
            # A clone for this clone_id is already registered; leave as is.
            continue

        op.metadata['transformer'] = op.metadata['device'] + clone_key
        op.metadata['device_id'] = clone_key

        if isinstance(
                op,
                (ScatterRecvOp, GatherSendOp, AllReduceOp, BroadcastRecvOp)):
            op._shared_queues = orig_op._shared_queues
            op.idx = shared_queues_idx
            if isinstance(op, (ScatterRecvOp, BroadcastRecvOp)):
                op._send_node = orig_op.send_node()
        elif isinstance(op, (CPUQueueRecvOp, GPUQueueRecvOp)):
            # Cloning a recv node means we need a broadcast, so simulate one
            # by adding an additional sender with the same input data as the
            # original sender.
            # NOTE(review): a CPUQueueSendOp is created even when op is a
            # GPUQueueRecvOp -- confirm this is intended.
            send_op = CPUQueueSendOp(orig_op.send_node().args[0])
            op._queue = send_op.queue
            op._send_node = send_op
            new_send_nodes.add(send_op)
            replaced_send_nodes.add(orig_op.send_node())

        if hasattr(op, 'reduction_axes') \
                and parallel_axis in op.reduction_axes:
            op.reduction_axes = set_parallel_axes(op.reduction_axes,
                                                  parallel_axis)

        if getattr(op, 'axes', None) is not None \
                and parallel_axis in Axes.as_flattened_list(op.axes):
            op._axes = set_parallel_axes(op.axes, parallel_axis)
            if isinstance(op, DotOp):
                if parallel_axis in op.x_out_axes:
                    op.x_out_axes = set_parallel_axes(
                        op.x_out_axes, parallel_axis)
                elif parallel_axis in op.y_out_axes:
                    op.y_out_axes = set_parallel_axes(
                        op.y_out_axes, parallel_axis)
                else:
                    raise ValueError("Missing parallel_axis in Op's "
                                     "x_out_axes or y_out_axes")

        if isinstance(op, TensorValueOp) \
                and parallel_axis in op.tensor.axes:
            op.tensor._axes = set_parallel_axes(op.tensor.axes,
                                                parallel_axis)

        # Redirect args whose original op already has a clone registered for
        # this clone_id. Args re-keyed earlier in this loop carry a new UUID
        # and therefore no longer match orig_ops.
        args_list = list(op.args)
        for arg_idx, arg_op in enumerate(args_list):
            if arg_op.uuid in orig_ops:
                arg_clones = orig_ops[arg_op.uuid].metadata.get('clones')
                if arg_clones:
                    arg_clone = arg_clones.get(clone_key)
                    if arg_clone:
                        args_list[arg_idx] = arg_clone

        op.invalidate_property_cache('all_deps')
        op._args = tuple(args_list)
        if op != new_root:
            if orig_op.metadata.get('clones') is None:
                orig_op.metadata['clones'] = dict()
            orig_op.metadata['clones'][clone_key] = op

        op.uuid = uuid.uuid4()

    return new_root, new_send_nodes, replaced_send_nodes