Code example #1
    def generate_default_dot_layout(op):
        """
        Generates the default layout assignment for a dot operation on GPU. By default
        we allow the first operand to be transposed but not the second.

        The output layout of the dot operation is defined by rows, which are the
        non-reduction axes from the first operand, and columns, which are the
        non-reduction axes from the second operand.

        Arguments:
            op (DotOp): op to generate layout for

        Returns:
            A list containing the GPUDotLayoutAssignment for this op
        """
        axes_list = Axes.as_flattened_list(op.axes)
        rows_axis = [
            axes_list.index(a) for a in Axes.as_flattened_list(op.x_out_axes)
        ]
        cols_axis = [
            axes_list.index(a) for a in Axes.as_flattened_list(op.y_out_axes)
        ]
        # By default allow the first argument to be transposed, but not the second
        # TODO: this could be bad for perf; consider a heuristic?
        return [
            GPUDotLayoutAssignment(True, False, axes_list,
                                   [rows_axis, cols_axis])
        ]
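
To make the index arithmetic concrete, here is a plain-Python sketch of the
rows/cols computation. The axis names N and M are illustrative assumptions, not
taken from the source; per the docstring, a dot op's output axes are the
non-reduction axes of both operands, so no reduction axis appears in the list.

    axes_list = ['N', 'M']          # stands in for Axes.as_flattened_list(op.axes)
    x_out_axes, y_out_axes = ['N'], ['M']
    rows_axis = [axes_list.index(a) for a in x_out_axes]   # -> [0]
    cols_axis = [axes_list.index(a) for a in y_out_axes]   # -> [1]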
Code example #2
    def __init__(self, op, arg):
        super(GPULutLayoutConstraint, self).__init__(op, arg)
        if len(arg.axes) == 2:
            self.order = [Axes.as_flattened_list(Axes(arg.axes[0])),
                          Axes.as_flattened_list(Axes(arg.axes[1]))]
        else:
            self.order = [Axes.as_flattened_list(arg.axes)]
Code example #3
    def get_op_shape_and_layout(self, op, mkl_order, index=0):
        exop = self.get_exop(op)
        mkl_layout = exop.output_decls[index].tensor_view_decl.mkl_layout
        op_axes_mkl = [op.axes[idx] for idx in mkl_order]
        mkl_shape = [a.length for a in op_axes_mkl]
        if mkl_layout:
            (in_layout, in_axes) = mkl_layout
            # Check if we need to rotate axes in the MKL layout object
            if op_axes_mkl != in_axes:
                assert Axes(get_flattened_axes(in_axes)).is_equal_set(
                    Axes(get_flattened_axes(op_axes_mkl)))
                mkl_layout = get_rotated_layout(
                    self.mkldnn,
                    in_layout,
                    get_flattened_axes(in_axes),
                    get_flattened_axes(op_axes_mkl))
            else:
                mkl_layout = in_layout
        else:
            # TODO(jbobba): Need to change this to use tensor_decl
            mkl_layout = get_native_layout(
                self.mkldnn,
                exop.output_decls[index].tensor_description,
                mkl_order)[0]

        return mkl_shape, mkl_layout
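
In short, this helper returns the op's shape in the requested MKL axis order
together with an MKL layout: an existing layout is rotated when its axis order
differs from the requested one, and a fresh native layout is constructed when
the output has no MKL layout yet.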
Code example #4
    def __init__(self, op, arg):
        super(GPUDotLayoutConstraint, self).__init__(op, arg)

        args = list(self.op.args)
        self.op_axes = Axes.as_flattened_list(self.op.axes)
        if self.arg.forwarded is args[0].forwarded:
            self.operand = 'A'
            self.reduction_axes = Axes.as_flattened_list(self.op.reduction_axes)
            self.out_axes = Axes.as_flattened_list(self.op.x_out_axes)
        elif self.arg.forwarded is args[1].forwarded:
            self.operand = 'B'
            self.reduction_axes = Axes.as_flattened_list(self.op.reduction_axes)
            self.out_axes = Axes.as_flattened_list(self.op.y_out_axes)
        else:
            raise ValueError("Invalid argument for constraint")
Code example #5
def get_flattened_axes(x):
    """
    Ordered list of axis visible to MKLDNN
    """
    return [
        axis for axis in Axes.as_flattened_list(x) if axis.name != '__NG_DEPTH'
    ]
Code example #6
File: hetr_utils.py Project: leonllm/ngraph
def update_parallel_axis(root, parallel_axis):
    for op in Op.ordered_ops([root]):

        if hasattr(op, 'reduction_axes') and \
                parallel_axis in op.reduction_axes:
            op.reduction_axes = set_parallel_axes(op.reduction_axes,
                                                  parallel_axis)

        if getattr(op, 'axes', None) is not None \
                and parallel_axis in Axes.as_flattened_list(op.axes):
            op._axes = set_parallel_axes(op.axes, parallel_axis)
            if isinstance(op, DotOp):
                if parallel_axis in op.x_out_axes:
                    op.x_out_axes = set_parallel_axes(op.x_out_axes,
                                                      parallel_axis)
                elif parallel_axis in op.y_out_axes:
                    op.y_out_axes = set_parallel_axes(op.y_out_axes,
                                                      parallel_axis)
                else:
                    raise ValueError("Missing parallel_axis in Op's "
                                     "x_out_axes or y_out_axes")

        if isinstance(op, TensorValueOp) and parallel_axis in op.tensor.axes:
            op.tensor._axes = set_parallel_axes(op.tensor.axes, parallel_axis)
Code example #7
    def generate_default_onehot_layout(op):
        """
        Generates the default layout assignment for a onehot operation on GPU.

        Arguments:
            op (OneHotOp): op to generate layout for

        Returns:
            A list containing the GPULayoutAssignment for this op
        """
        axes_list = Axes.as_flattened_list(op.axes)
        oh_axis = axes_list.index(op.axis)
        other_group = [i for i, a in enumerate(axes_list) if a is not op.axis]

        if oh_axis == 0:
            return [
                GPUDotLayoutAssignment(True, False, axes_list,
                                       [[oh_axis], other_group])
            ]
        elif oh_axis == (len(axes_list) - 1):
            return [
                GPUDotLayoutAssignment(True, False, axes_list,
                                       [other_group, [oh_axis]])
            ]
        else:
            group0 = [i for i in other_group if i < oh_axis]
            group1 = [i for i in other_group if i > oh_axis]
            return [
                GPUDotLayoutAssignment(True, False, axes_list,
                                       [group0, [oh_axis], group1])
            ]
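
For example, with flattened axes [N, H, W] and op.axis = H, oh_axis is 1 and the
layout splits into [[0], [1], [2]]; if op.axis were N (index 0), the one-hot axis
forms its own leading group and the rest are grouped together, giving [[0], [1, 2]].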
Code example #8
    def generate_default_layout(clss, axes, max_out_axes):
        """
        Generates a default layout assignment for an elementwise operation

        Arguments:
            axes: List of axes in the output of the operation
            max_out_axes: The maximum number of strided axes supported by
                the kernel for this operation

        Returns:
            A list containing a single layout assignment
        """
        axes_list = Axes.as_flattened_list(axes)

        # Need to divide op axes into `max_out_axes` sets
        if len(axes_list) > max_out_axes:
            num_splits = max_out_axes - 1
            num_axes = len(axes_list)
            split_points = list(
                reversed([(num_axes - (i + 1)) for i in range(num_splits)]))
            layout = split_points_to_groups(split_points, len(axes_list))
        else:
            layout = [[i] for i in range(len(axes_list))]

        return [clss(axes_list, layout)]
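
split_points_to_groups is used above but not shown. Here is a minimal sketch
consistent with the call sites (a hypothetical reconstruction, not the project's
actual implementation): given split points [3, 4] and five axes it should produce
[[0, 1, 2], [3], [4]], i.e. the leading axes collapse into the first group.

    def split_points_to_groups(split_points, num_axes):
        # Partition indices 0..num_axes-1 into contiguous groups,
        # starting a new group at each split point.
        boundaries = [0] + list(split_points) + [num_axes]
        return [list(range(boundaries[i], boundaries[i + 1]))
                for i in range(len(boundaries) - 1)]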
Code example #9
    def generate_default_lut_layout(op):
        """
        Generates the default layout assignment for a lookup table operation on GPU.

        Arguments:
            op (LookupTableOp): op to generate layout for

        Returns:
            A list containing the GPULayoutAssignment for this op
        """
        axes_list = Axes.as_flattened_list(op.axes)
        groups = Axes.as_nested_list(op.axes)
        layout = []
        for group in groups:
            if isinstance(group, list):
                layout.append([axes_list.index(a) for a in group])
            else:
                layout.append([axes_list.index(group)])
        return [GPULayoutAssignment(axes_list, layout)]
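
Here Axes.as_nested_list preserves flattened (grouped) axes as sub-lists, so each
pre-existing axis group maps directly to one layout group of buffer indices.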
Code example #10
File: comm_nodes.py Project: leonllm/ngraph
def set_parallel_axes(axes, parallel_axis):
    new_axes = []
    for axis in Axes.as_nested_list(axes):
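        # An axis that compares equal to parallel_axis is replaced by the
        # parallel_axis object itself, which may carry a different length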
        if axis == parallel_axis:
            axis = parallel_axis
        elif isinstance(axis, collections.abc.Iterable):  # collections.Iterable was removed in Python 3.10
            # flattened axis
            axis = [parallel_axis if a == parallel_axis else a for a in axis]
        new_axes.append(axis)

    return make_axes(new_axes)
Code example #11
File: comm_nodes.py Project: rsumner31/ngraph
def set_parallel_axes(axes, parallel_axis):
    new_axes = []
    flat_names = dict()
    for i, axis in enumerate(Axes.as_nested_list(axes)):
        if axis == parallel_axis:
            axis = parallel_axis
        elif isinstance(axis, collections.abc.Iterable):  # collections.Iterable was removed in Python 3.10
            flat_names[i] = axes[i].name
            axis = [parallel_axis if a == parallel_axis else a for a in axis]
        new_axes.append(axis)
    new_axes = make_axes(new_axes)

    for i in flat_names:
        new_axes[i].name = flat_names[i]
    return new_axes
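
This variant extends code example #10: rebuilding a flattened axis through
make_axes produces a fresh axis whose generated name may differ, so the original
names are saved in flat_names and restored on the new Axes object.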
Code example #12
    def generate_comms_layout(op, max_out_axes):
        """
        Generates layout assignment for communication operations

        Arguments:
            op (CommunicationOp): op to generate layout for
            max_out_axes: The maximum number of strided axes supported by
                the kernel for this operation

        Returns:
            A list containing the GPULayoutAssignment for this op
        """
        parallel_axis = op.metadata['parallel']
        axes_list = Axes.as_flattened_list(op.axes)
        if parallel_axis not in axes_list:
            return GPULayoutAssignment.generate_default_layout(
                op.axes, max_out_axes)

        parallel_axis_index = axes_list.index(parallel_axis)

        if len(axes_list) > max_out_axes:
            num_splits = max_out_axes - 1
            num_axes = len(axes_list)
            split_points = list(
                reversed([(num_axes - (i + 1)) for i in range(num_splits)]))
            layout = split_points_to_groups(split_points, len(axes_list))

            group_index = -1
            # Find which group contains the parallel axis
            for idx, group in enumerate(layout):
                if parallel_axis_index in group:
                    group_index = idx
                    break
            # Move the parallel_axis group to the first position
            parallel_group = layout[0]
            if group_index > 0:
                parallel_group = layout.pop(group_index)
                layout.insert(0, parallel_group)
            # If parallel_axis is not the first in its group make it the first
            if parallel_axis_index != parallel_group[0]:
                parallel_group.remove(parallel_axis_index)
                parallel_group.insert(0, parallel_axis_index)
        else:
            layout = [[i] for i in range(len(axes_list))
                      if i != parallel_axis_index]
            layout.insert(0, [parallel_axis_index])

        return [GPULayoutAssignment(axes_list, layout)]
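
As a concrete trace: with five flattened axes, max_out_axes = 3, and the parallel
axis at index 3, the default grouping [[0, 1, 2], [3], [4]] is reordered to
[[3], [0, 1, 2], [4]], so the parallel axis leads the layout.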
Code example #13
File: comm_nodes.py Project: rsumner31/ngraph
    def __init__(self, from_node, parallel_axis=None):
        super(SendOp,
              self).__init__(node=from_node,
                             args=tuple([from_node]),
                             axes=self.hetr_axes(from_node.axes,
                                                 parallel_axis),
                             dtype=from_node.dtype)

        # Add native/original axes to the op, and ensure that the native
        # axes keep the original length of the parallel_axis
        self.native_axes = Axes.as_flattened_list(from_node.axes)
        if parallel_axis is not None and parallel_axis in self.native_axes:
            p_axis_idx = self.native_axes.index(parallel_axis)
            self.native_axes[p_axis_idx] = parallel_axis
        self.native_axes = make_axes(self.native_axes)
Code example #14
    def get_layout_transform(self, arg_layout, op_layout, arg):
        """
        Returns a reshape view of the argument strided to match the AssignOp axes

        Arguments:
            arg_layout (GPULayoutAssignment): layout of the argument
            op_layout: (GPULayoutAssignment): layout required by the op
            arg (TensorOp): Op producing the argument

        Returns:
            A GPUIndexOp which satisfies the requirements of the op_layout
        """
        arg_mem_order = flatten(arg_layout.axes)
        arg_view_axes = Axes.as_flattened_list(arg.axes)
        arg_axes = arg_layout.ng_axes
        out_groups = [[a] for a in arg_view_axes]
        return self.get_reshape(arg_mem_order, arg_axes, out_groups, arg)
Code example #15
File: layout_common.py Project: rsumner31/ngraph
    def generate_ew_layouts(clss, axes, max_out_axes):
        """
        Generates a set of possible layouts for an elementwise operation.

        Arguments:
            axes: List of axes in the output of the operation
            max_out_axes: The maximum number of strided axes supported by
                the kernel for this operation

        Returns:
            A list of layout possibilities for this operation
        """
        # Get list of individual axes
        axes_list = Axes.as_flattened_list(axes)

        # Need to divide op axes into `max_out_axes` sets
        if len(axes_list) > max_out_axes:
            groups = get_split_groups(len(axes_list), max_out_axes)
            num_groups = max_out_axes
        else:
            groups = [[[i] for i in range(len(axes_list))]]
            num_groups = len(axes_list)

        # Find all permutations of these axis groups
        permutations = enumerate_axis_orders(tuple(range(num_groups)))

        if permutations:
            # Create EW layouts
            layouts = []
            for group in groups:
                for order in permutations:
                    layout_spec = [group[i] for i in order]
                    layouts.append(clss(axes_list, layout_spec))
        else:
            layouts = [clss(axes_list, [])]

        return layouts
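
enumerate_axis_orders is likewise referenced but not shown. A plausible sketch
under the assumption that it simply enumerates permutations (the real helper may
prune or cap the orders, since the code above guards against an empty result):

    import itertools

    def enumerate_axis_orders(axes):
        # All orderings of the given axis-group indices.
        return list(itertools.permutations(axes))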
Code example #16
File: hetr_utils.py Project: leonllm/ngraph
def clone_graph(root, clone_id, parallel_axis):
    """
    clone graph with serde (serialization)
    input:
    output: new_root of the cloned graph
    """

    # clone nodes with GatherSendOp as root using serde
    ser_cloned_nodes = deserialize_graph(serialize_graph([root]))

    new_root = next((o for o in ser_cloned_nodes if o.uuid == root.uuid), None)

    orig_ops = {op.uuid: op for op in Op.ordered_ops([root])}
    cloned_graph = Op.ordered_ops([new_root])

    new_send_nodes = OrderedSet()
    replaced_send_nodes = OrderedSet()

    # update newly cloned op metadata, generate new UUIDs
    for op in cloned_graph:
        cloned_ops = orig_ops[op.uuid].metadata.get('clones')
        if cloned_ops is None or cloned_ops.get(str(clone_id)) is None:
            op.metadata['transformer'] = op.metadata['device'] + str(clone_id)
            op.metadata['device_id'] = str(clone_id)

            if isinstance(op, (ScatterRecvOp, GatherSendOp,
                               AllReduceOp, BroadcastRecvOp)):
                # for gpu communication op buffer
                op.idx = int(clone_id)
                if isinstance(op, (ScatterRecvOp, BroadcastRecvOp)):
                    op._send_node = orig_ops[op.uuid].send_node()

            if hasattr(op, 'reduction_axes') and \
                    parallel_axis in op.reduction_axes:
                op.reduction_axes = set_parallel_axes(op.reduction_axes,
                                                      parallel_axis)

            if getattr(op, 'axes', None) is not None \
                    and parallel_axis in Axes.as_flattened_list(op.axes):
                op._axes = set_parallel_axes(op.axes, parallel_axis)
                if isinstance(op, DotOp):
                    if parallel_axis in op.x_out_axes:
                        op.x_out_axes = set_parallel_axes(
                            op.x_out_axes, parallel_axis)
                    elif parallel_axis in op.y_out_axes:
                        op.y_out_axes = set_parallel_axes(
                            op.y_out_axes, parallel_axis)
                    else:
                        raise ValueError("Missing parallel_axis in Op's "
                                         "x_out_axes or y_out_axes")

            if isinstance(op, TensorValueOp) and \
                    parallel_axis in op.tensor.axes:
                op.tensor._axes = set_parallel_axes(op.tensor.axes,
                                                    parallel_axis)

            args_list = list(op.args)
            for arg_idx, arg_op in enumerate(args_list):
                if arg_op.uuid in orig_ops.keys():
                    if orig_ops[arg_op.uuid].metadata.get('clones') and \
                       orig_ops[arg_op.uuid].metadata['clones'].get(str(clone_id)):
                        args_list[arg_idx] = \
                            orig_ops[arg_op.uuid].metadata['clones'].get(str(clone_id))

            op.invalidate_property_cache('all_deps')
            op._args = tuple(args_list)
            if op != new_root:
                if orig_ops[op.uuid].metadata.get('clones') is None:
                    orig_ops[op.uuid].metadata['clones'] = dict()
                orig_ops[op.uuid].metadata['clones'][str(clone_id)] = op

            op.uuid = uuid.uuid4()

    # create new uuids for all the ops that have references to the new root
    for _op in Op.all_op_references([new_root]):
        _op.uuid = uuid.uuid4()

    return new_root, new_send_nodes, replaced_send_nodes
Code example #17
    def __init__(self, op, arg, axes):
        super(GPUFixedLayoutConstraint, self).__init__(op, arg)
        self.order = Axes.as_flattened_list(axes)
Code example #18
    def __init__(self, op, arg):
        super(GPUEWLayoutConstraint, self).__init__(op, arg)
        if isinstance(op, ReductionOp):
            self.red_axes = Axes.as_flattened_list(op.reduction_axes)
        else:
            self.red_axes = None
Code example #19
def get_flattened_axes(x):
    """
    Ordered list of axis visible to MKLDNN
    """
    return Axes.as_flattened_list(x)
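
Unlike the variant in code example #5, this version does not filter out axes
named '__NG_DEPTH'.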
Code example #20
File: hetr_utils.py Project: QiJune/ngraph
def clone_graph(root, clone_id, shared_queues_idx, parallel_axis, num_clones):
    """
    clone graph with serde (serialization)
    input:
    output: new_root of the cloned graph
    """
    # clone nodes with GatherSendOp as root using serde
    ser_cloned_nodes = deserialize_graph(serialize_graph([root]))
    new_root = next((o for o in ser_cloned_nodes if o.uuid == root.uuid), None)

    orig_ops = {op.uuid: op for op in Op.ordered_ops([root])}
    # Prune ops that are not control_deps of new_gather_send_op
    # deserialize includes extra referenced nodes
    cloned_graph = Op.ordered_ops([new_root])

    new_send_nodes = OrderedSet()
    replaced_send_nodes = OrderedSet()

    # update newly cloned op metadata, generate new UUIDs
    for op in cloned_graph:
        cloned_ops = orig_ops[op.uuid].metadata.get('clones')
        if cloned_ops is None or cloned_ops.get(str(clone_id)) is None:
            op.metadata['transformer'] = op.metadata['device'] + str(clone_id)
            op.metadata['device_id'] = str(clone_id)

            if isinstance(op, (ScatterRecvOp, GatherSendOp,
                               AllReduceOp, BroadcastRecvOp)):
                op._shared_queues = orig_ops[op.uuid]._shared_queues
                op.idx = shared_queues_idx
                if isinstance(op, (ScatterRecvOp, BroadcastRecvOp)):
                    op._send_node = orig_ops[op.uuid].send_node()
            elif isinstance(op, (CPUQueueRecvOp, GPUQueueRecvOp)):
                # Cloning a recv node means we need a broadcast, so simulate one by adding an
                # additional sender with the same input data as the original sender.
                send_op = CPUQueueSendOp(orig_ops[op.uuid].send_node().args[0])
                op._queue = send_op.queue
                op._send_node = send_op
                new_send_nodes.add(send_op)
                replaced_send_nodes.add(orig_ops[op.uuid].send_node())

            if hasattr(op, 'reduction_axes') and \
                    parallel_axis in op.reduction_axes:
                op.reduction_axes = set_parallel_axes(op.reduction_axes,
                                                      parallel_axis)

            if getattr(op, 'axes', None) is not None \
                    and parallel_axis in Axes.as_flattened_list(op.axes):
                op._axes = set_parallel_axes(op.axes, parallel_axis)
                if isinstance(op, DotOp):
                    if parallel_axis in op.x_out_axes:
                        op.x_out_axes = set_parallel_axes(
                            op.x_out_axes, parallel_axis)
                    elif parallel_axis in op.y_out_axes:
                        op.y_out_axes = set_parallel_axes(
                            op.y_out_axes, parallel_axis)
                    else:
                        raise ValueError("Missing parallel_axis in Op's "
                                         "x_out_axes or y_out_axes")

            if isinstance(op, TensorValueOp) and \
                    parallel_axis in op.tensor.axes:
                op.tensor._axes = set_parallel_axes(op.tensor.axes,
                                                    parallel_axis)

            args_list = list(op.args)
            for arg_idx, arg_op in enumerate(args_list):
                if arg_op.uuid in orig_ops.keys():
                    if orig_ops[arg_op.uuid].metadata.get('clones') and \
                       orig_ops[arg_op.uuid].metadata['clones'].get(str(clone_id)):
                        args_list[arg_idx] = \
                            orig_ops[arg_op.uuid].metadata['clones'].get(str(clone_id))
            op.invalidate_property_cache('all_deps')
            op._args = tuple(args_list)
            if op != new_root:
                if orig_ops[op.uuid].metadata.get('clones') is None:
                    orig_ops[op.uuid].metadata['clones'] = dict()
                orig_ops[op.uuid].metadata['clones'][str(clone_id)] = op

            op.uuid = uuid.uuid4()

    return new_root, new_send_nodes, replaced_send_nodes
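
Compared with clone_graph in code example #16, this variant also wires the
cloned communication ops to the original ops' shared queues and, when a recv
node is cloned, simulates a broadcast by creating an additional send op fed
with the same input data as the original sender.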
Code example #21
File: layout_common.py Project: rsumner31/ngraph
    def __init__(self, op, arg):
        self.op = op
        self.arg = arg

        # Build mapping of arg axis position to axis position in buffer.
        # Arg axes may be re-ordered, cast, broadcast, or sliced between the
        # original buffer and the op being used for this constraint.
        predecessor_op = arg
        while isinstance(predecessor_op, SequentialOp):
            predecessor_op = predecessor_op.value_tensor

        self.arg_axes_list = Axes.as_flattened_list(arg.axes)
        self.mappings = {}
        self.sliced_out = []
        for i in range(len(self.arg_axes_list)):
            self.mappings[i] = i

        while not (predecessor_op.is_device_op or isinstance(predecessor_op, TensorValueOp)):
            pred_axes = Axes.as_flattened_list(predecessor_op.axes)
            pred_arg_axes = Axes.as_flattened_list(predecessor_op.args[0].axes)
            if isinstance(predecessor_op, (BroadcastOp, ExpandDims)):
                bcast_axes = [pred_axes.index(a) for a in pred_axes if a not in pred_arg_axes]
                bcast_mappings = [a for a in self.mappings if self.mappings[a] in bcast_axes]
                for bcast in bcast_mappings:
                    self.mappings[bcast] = "bcast"
                for a, p in self.mappings.items():
                    if isinstance(p, int):
                        offset = 0
                        for bcast_axis in bcast_axes:
                            if p > bcast_axis:
                                offset += 1
                        self.mappings[a] = p - offset

                for i in range(len(self.sliced_out)):
                    if self.sliced_out[i][0] in bcast_axes:
                        self.sliced_out[i] = (self.sliced_out[i][0], "bcast")
                    else:
                        new_axis_index = pred_arg_axes.index(pred_axes[self.sliced_out[i][0]])
                        self.sliced_out[i] = (new_axis_index, self.sliced_out[i][1])
            elif isinstance(predecessor_op, TensorSliceOp):
                old_index = 0
                new_indexes = []
                for index, axis in enumerate(pred_arg_axes):
                    if isinstance(predecessor_op.slices[index], int):
                        self.sliced_out.append((index, predecessor_op.slices[index]))
                    else:
                        if predecessor_op.slices[index] != slice(None, None, None):
                            new_index = ("slice", index, predecessor_op.slices[index])
                        else:
                            new_index = index

                        # Remap this axis
                        for a, p in self.mappings.items():
                            if isinstance(p, int) and p == old_index:
                                new_indexes.append((a, new_index))
                        old_index += 1

                for a, p in new_indexes:
                    self.mappings[a] = p
            elif isinstance(predecessor_op, (ReorderAxes, Transpose)):
                new_indexes = []
                for a, p in self.mappings.items():
                    if isinstance(p, int):
                        new_indexes.append((a, pred_arg_axes.index(pred_axes[p])))
                for a, p in new_indexes:
                    self.mappings[a] = p

                for i in range(len(self.sliced_out)):
                    new_axis_index = pred_arg_axes.index(pred_axes[self.sliced_out[i][0]])
                    self.sliced_out[i] = (new_axis_index, self.sliced_out[i][1])
            elif isinstance(predecessor_op, (AxesCastOp, Flatten, Unflatten)):
                pass
            else:
                raise RuntimeError(
                    "Unexpected op {} between constraint arg and its "
                    "buffer".format(type(predecessor_op)))
            predecessor_op = predecessor_op.args[0]
            while isinstance(predecessor_op, SequentialOp):
                predecessor_op = predecessor_op.value_tensor
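
In short, the constructor walks the chain of view ops between the constraint's
argument and the underlying device buffer, recording for each arg axis where it
lives in the buffer: "bcast" for broadcast axes, slice records for sliced-out
axes, and remapped indices for reordered ones.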