def __init__(self, op, arg):
    """Initialise the constraint and record the flattened axis order of ``arg``.

    ``self.order`` is a list of flattened axis lists:

    * if ``arg`` has exactly two axes, each axis is wrapped in ``Axes`` and
      flattened on its own, giving two entries;
    * otherwise all of ``arg.axes`` is flattened into a single entry.

    Args:
        op: The op this constraint is attached to (forwarded to the base
            class).
        arg: The argument whose ``axes`` determine the recorded order.
    """
    super(GPULutLayoutConstraint, self).__init__(op, arg)
    axes = arg.axes
    if len(axes) != 2:
        # General case: one flattened list covering every axis.
        self.order = [Axes.as_flattened_list(axes)]
    else:
        # Two-axis case: keep each axis' flattened form as a separate group.
        self.order = [Axes.as_flattened_list(Axes(single))
                      for single in (axes[0], axes[1])]
def get_op_shape_and_layout(self, op, mkl_order, index=0):
    """Return ``(shape, layout)`` for output ``index`` of ``op`` in MKL order.

    ``shape`` is the list of axis lengths of ``op.axes`` taken in the order
    given by ``mkl_order`` (a sequence of indices into ``op.axes``).

    ``layout`` is chosen as follows:

    * if the exop's output tensor-view decl carries an ``mkl_layout`` (a
      ``(layout, axes)`` pair), that layout is used directly when its axes
      match the reordered op axes, and otherwise is passed through
      ``get_rotated_layout`` to reorder it to the op's axis order;
    * if there is no stored ``mkl_layout``, a native layout is built from the
      output's tensor description via ``get_native_layout``.

    Args:
        op: The op whose output layout is requested.
        mkl_order: Indices into ``op.axes`` giving the MKL axis order.
        index: Which entry of the exop's ``output_decls`` to use.

    Returns:
        A ``(shape, layout)`` tuple.
    """
    output_decl = self.get_exop(op).output_decls[index]
    stored = output_decl.tensor_view_decl.mkl_layout
    reordered_axes = [op.axes[i] for i in mkl_order]
    shape = [axis.length for axis in reordered_axes]

    if not stored:
        # TODO(jbobba): Need to change this to use tensor_decl
        native = get_native_layout(self.mkldnn,
                                   output_decl.tensor_description,
                                   mkl_order)
        return shape, native[0]

    layout, layout_axes = stored
    if reordered_axes != layout_axes:
        # The stored layout uses a different axis ordering; rotate it so it
        # lines up with the op's axes in MKL order.
        src_flat = get_flattened_axes(layout_axes)
        dst_flat = get_flattened_axes(reordered_axes)
        # Rotation is only valid when both orderings cover the same axes.
        assert Axes(src_flat).is_equal_set(Axes(dst_flat))
        layout = get_rotated_layout(self.mkldnn, layout, src_flat, dst_flat)
    return shape, layout