コード例 #1
0
 def __init__(self, op, arg):
     """
     Initialize the constraint and record the flattened axis order of ``arg``.

     Arguments:
         op: Passed through unchanged to the parent constructor.
         arg: Object whose ``axes`` determine ``self.order``.

     Sets ``self.order`` to a list of flattened axis lists:
     - if ``arg.axes`` has exactly two entries, each axis is flattened on its
       own, giving two lists;
     - otherwise a single flattened list covering all of ``arg.axes``.
     """
     super(GPULutLayoutConstraint, self).__init__(op, arg)
     arg_axes = arg.axes
     if len(arg_axes) != 2:
         self.order = [Axes.as_flattened_list(arg_axes)]
         return
     # Exactly two axes: wrap each single axis in Axes and flatten it separately.
     self.order = [
         Axes.as_flattened_list(Axes(arg_axes[position]))
         for position in (0, 1)
     ]
コード例 #2
0
    def get_op_shape_and_layout(self, op, mkl_order, index=0):
        """
        Return the MKL-DNN shape and layout for one output of ``op``.

        Arguments:
            op: The op whose output is being described.
            mkl_order: Sequence of indices into ``op.axes`` giving the axis
                order used on the MKL-DNN side.
            index: Which entry of the exop's ``output_decls`` to use.

        Returns:
            tuple: ``(mkl_shape, mkl_layout)``.
            ``mkl_shape`` is the list of axis lengths in ``mkl_order``.
            ``mkl_layout`` has one of three sources:
            - the layout already recorded on the output's tensor view decl,
              when its axes match the op's axes in ``mkl_order``;
            - that recorded layout rotated via ``get_rotated_layout``, when
              the axis order differs;
            - the first element returned by ``get_native_layout``, when no
              layout has been recorded.

        Raises:
            ValueError: If a recorded layout exists whose axes are not the
                same set as the op's axes, so it cannot be rotated to match.
        """
        exop = self.get_exop(op)
        output_decl = exop.output_decls[index]
        mkl_layout = output_decl.tensor_view_decl.mkl_layout
        op_axes_mkl = [op.axes[idx] for idx in mkl_order]
        mkl_shape = [a.length for a in op_axes_mkl]

        if not mkl_layout:
            # TODO(jbobba): Need to change this to use tensor_decl
            native_layout = get_native_layout(
                self.mkldnn, output_decl.tensor_description, mkl_order)[0]
            return mkl_shape, native_layout

        # A recorded layout is stored as a (layout, axes) pair.
        (in_layout, in_axes) = mkl_layout
        if op_axes_mkl != in_axes:
            # Axis order differs: rotate the recorded layout into mkl_order.
            # Each axis list is flattened once and reused for both the check
            # and the rotation.
            # NOTE(review): assumes get_flattened_axes returns a reusable
            # sequence (not a one-shot iterator) -- confirm.
            in_axes_flat = get_flattened_axes(in_axes)
            op_axes_flat = get_flattened_axes(op_axes_mkl)
            # Explicit check rather than `assert`. An assert is stripped under
            # `python -O` and would let incompatible axes reach the rotation.
            if not Axes(in_axes_flat).is_equal_set(Axes(op_axes_flat)):
                raise ValueError(
                    "Recorded MKL layout axes {} are not the same set as op "
                    "axes {}".format(in_axes_flat, op_axes_flat))
            rotated_layout = get_rotated_layout(
                self.mkldnn,
                in_layout,
                in_axes_flat,
                op_axes_flat)
            return mkl_shape, rotated_layout

        # Recorded axis order already matches; reuse the layout as-is.
        return mkl_shape, in_layout