Example #1
0
    def needs_transform(self, arg_layout, op_layout):
        """
        Given the op layout and argument layout, check if a DimshuffleOp is needed to convert
        the argument to a suitable layout.

        Arguments:
            arg_layout (GPULayoutAssignment): layout of the argument
            op_layout: (GPULayoutAssignment): layout required by the op

        Returns:
            True if a DimshuffleOp is needed to convert the arg
        """
        # Flattened arg layout axes list used to determine arg contiguity
        mem_order = flatten(arg_layout.axes)

        # Contiguity requirements come from this op's layout groupings;
        # bail out as soon as any group fails the strided-validity check.
        for axis_group in op_layout.axes:
            ng_group = [op_layout.ng_axes[i] for i in axis_group]
            # The one-hot axis is produced by the op itself, so its layout
            # places no constraint on the argument.
            if isinstance(self.op, OneHotOp) and ng_group[0] == self.op.axis:
                continue
            if not self.group_axis_strided_valid(mem_order, ng_group):
                return True

        # Reduction axes must also form a valid strided group
        if isinstance(self.op, ReductionOp):
            if not self.group_axis_strided_valid(mem_order, list(self.red_axes)):
                return True

        return False
Example #2
0
    def get_layout_transform(self, arg_layout, op_layout, arg):
        """
        Generates either a DimshuffleOp or GPUIndexOp for the argument that produces a view
        which satisfies the dot op layout.

        Arguments:
            arg_layout (GPULayoutAssignment): layout of the argument
            op_layout: (GPULayoutAssignment): layout required by the op
            arg (TensorOp): Op producing the argument

        Returns:
            Either a GPUIndexOp if no transform is needed, or a DimshuffleOp which satisfies
            the requirements of the op_layout
        """
        mem_order = flatten(arg_layout.axes)
        ng_axes = arg_layout.ng_axes
        first_arg = list(self.op.args)[0]

        red_group = list(self.reduction_axes)
        out_group = list(self.out_axes)

        # The first dot argument is grouped (out, reduction); the second
        # argument takes the reverse ordering.
        if self.arg.forwarded is first_arg.forwarded:
            out_groups = [out_group, red_group]
        else:
            out_groups = [red_group, out_group]

        if self.needs_transform(arg_layout, op_layout):
            return self.get_dimshuffle(mem_order, ng_axes, out_groups, arg)
        return self.get_reshape(mem_order, ng_axes, out_groups, arg)
Example #3
0
    def get_layout_transform(self, arg_layout, op_layout, arg):
        """
        Return a view of the argument grouped by self.order: a DimshuffleOp
        when the data must be reordered, otherwise a reshape view.
        """
        mem_order = flatten(arg_layout.axes)
        # Pick the transform builder, then make a single call.
        if self.needs_transform(arg_layout, op_layout):
            builder = self.get_dimshuffle
        else:
            builder = self.get_reshape
        return builder(mem_order, arg_layout.ng_axes, self.order, arg)
Example #4
0
    def needs_transform(self, arg_layout, op_layout):
        """
        Checks if reduction_axes and out_axes are contiguous and if the argument
        meets the transpose requirements of the op

        Arguments:
            arg_layout (GPULayoutAssignment): layout of the argument
            op_layout: (GPULayoutAssignment): layout required by the op

        Returns:
            True if a DimshuffleOp is needed to convert the arg

        Raises:
            ValueError: if self.operand is not 'A' or 'B'
        """
        arg_mem_order = flatten(arg_layout.axes)
        out_mem_order = flatten(op_layout.axes)

        reduction_group = [a for a in self.reduction_axes]
        out_group = [
            self.op_axes[i] for i in out_mem_order
            if self.op_axes[i] in self.out_axes
        ]

        # Check if this argument can be transposed.
        # NOTE: previously an operand other than 'A'/'B' left can_trans
        # unbound and crashed with a NameError below; fail explicitly.
        if self.operand == 'A':
            can_trans = op_layout.A_trans
        elif self.operand == 'B':
            can_trans = op_layout.B_trans
        else:
            raise ValueError("Unknown operand {}".format(self.operand))

        # Each arg must have two contiguous axes where one matches
        # reduction axes and the other matches one of the output axes
        if len(reduction_group) == 0 or self.group_axis_contig(
                arg_mem_order, reduction_group):
            if can_trans:
                if self.group_axis_contig(arg_mem_order, out_group):
                    return False
            else:
                # Make sure operand is not transposed: A is (out, reduction),
                # B is (reduction, out)
                if self.operand == 'A':
                    required_layout = out_group + reduction_group
                else:
                    required_layout = reduction_group + out_group

                if self.group_axis_contig(arg_mem_order, required_layout):
                    return False

        return True
Example #5
0
    def get_layout_transform(self, arg_layout, op_layout, arg):
        """
        Returns a reshape view of the argument strided to match the AssignOp axes

        Arguments:
            arg_layout (GPULayoutAssignment): layout of the argument
            op_layout: (GPULayoutAssignment): layout required by the op
            arg (TensorOp): Op producing the argument

        Returns:
            A GPUIndexOp which satisfies the requirements of the op_layout
        """
        mem_order = flatten(arg_layout.axes)
        # One singleton group per axis of the argument's view
        view_axes = Axes.as_flattened_list(arg.axes)
        groups = [[axis] for axis in view_axes]
        return self.get_reshape(mem_order, arg_layout.ng_axes, groups, arg)
Example #6
0
    def needs_transform(self, arg_layout, op_layout):
        """
        Checks if all axes in self.order are contiguous in the argument.

        Arguments:
            arg_layout (GPULayoutAssignment): layout of the argument
            op_layout: (GPULayoutAssignment): layout required by the op

        Returns:
            True if a DimshuffleOp is needed to convert the arg
        """
        # A transform is required exactly when the ordered group is not
        # contiguous in the argument's memory layout.
        mem_order = flatten(arg_layout.axes)
        return not self.group_axis_contig(mem_order, self.order)
Example #7
0
    def get_layout_transform(self, arg_layout, op_layout, arg):
        """
        Generates either a DimshuffleOp or GPUIndexOp for the argument that produces a view
        which satisfies contiguous order requirement.

        Arguments:
            arg_layout (GPULayoutAssignment): layout of the argument
            op_layout: (GPULayoutAssignment): layout required by the op
            arg (TensorOp): Op producing the argument

        Returns:
            Either a GPUIndexOp if no transform is needed, or a DimshuffleOp which satisfies
            the requirements of the op_layout
        """
        mem_order = flatten(arg_layout.axes)
        ng_axes = arg_layout.ng_axes
        groups = [self.order]

        # Reorder data only when the required group is not already contiguous
        if self.needs_transform(arg_layout, op_layout):
            return self.get_dimshuffle(mem_order, ng_axes, groups, arg)
        return self.get_reshape(mem_order, ng_axes, groups, arg)
Example #8
0
    def get_layout_transform(self, arg_layout, op_layout, arg):
        """
        Given the op layout and argument layout, check if a DimshuffleOp is needed to convert
        the argument to a suitable layout. Generates either a DimshuffleOp or GPUIndexOp for
        the argument which produces a view which satisfies the op_layout assignment.

        Arguments:
            arg_layout (GPULayoutAssignment): layout of the argument
            op_layout: (GPULayoutAssignment): layout required by the op
            arg (TensorOp): Op producing the argument

        Returns:
            Either a GPUIndexOp if no transform is needed, or a DimshuffleOp which satisfies
            the requirements of the op_layout
        """
        # Flattened arg layout axes list used to determine arg contiguity
        arg_mem_order = flatten(arg_layout.axes)
        arg_axes = arg_layout.ng_axes

        # The axis grouping is identical for both the dimshuffle and the
        # reshape path; compute it once. (The original duplicated this
        # computation in both branches and ended with an unreachable
        # "return arg".)
        out_groups = self._op_layout_groups(op_layout)

        if self.needs_transform(arg_layout, op_layout):
            # Physically reorder the data to match the op's layout groups
            return self.get_dimshuffle(arg_mem_order, arg_axes, out_groups, arg)
        # Compute derived (reshape-only) layout for arg
        return self.get_reshape(arg_mem_order, arg_axes, out_groups, arg)

    def _op_layout_groups(self, op_layout):
        """
        Build the list of axis groups required by op_layout for this op.

        For a ReductionOp the reduction axes (if any) form an extra leading
        group; for a OneHotOp the group holding the one-hot axis is omitted
        because the op produces that axis itself.
        """
        groups = []
        # Reduction axes form their own leading group
        if isinstance(self.op, ReductionOp) and self.red_axes:
            groups.append([a for a in self.red_axes])
        for op_axis in op_layout.axes:
            group = [op_layout.ng_axes[i] for i in op_axis]
            if isinstance(self.op, OneHotOp) and self.op.axis in group:
                # The one-hot axis must be a group of its own and is skipped
                assert len(group) == 1
                continue
            groups.append(group)
        return groups