def downsample(img, scale, border='reflect'):
    """Bicubic downsample via **CONV2D**, using PIL's kernel (PyTorch).

    Args:
        img: a torch tensor of 2/3/4-D. After ``_push_shape_4d`` the channel
            count is read from dim 1, i.e. the (N, C, H, W) layout that
            ``F.conv2d`` consumes.
        scale: n or 1/n. `n` must be integer >= 2.
        border: padding mode forwarded to ``torch.nn.functional.pad``
            ('reflect' recommended). Matched case-insensitively, so the
            TensorFlow-style spelling 'REFLECT' is accepted as well.

    Returns:
        The strided-convolution result passed through ``_pop_shape`` with the
        shape saved by ``_push_shape_4d`` (presumably restoring the input's
        original rank -- confirm against those helpers). When the stride
        returned by ``weights_downsample`` is 1, ``img`` is returned as-is.
    """
    kernel, s = weights_downsample(scale)
    if s == 1:
        return img  # bypass: nothing to downsample
    kernel = torch.from_numpy(kernel.astype('float32'))
    # Asymmetric padding: p1 before and p2 after each spatial axis, 4*s total.
    p1 = int(s * 3 / 2)
    p2 = 4 * s - p1
    img, shape = _push_shape_4d(img)
    # torch pad modes are lower-case ('reflect', 'replicate', ...); normalise
    # so an upper-case TF-style mode name does not raise.
    img_ex = F.pad(img, [p1, p2, p1, p2], mode=border.lower())
    c = int(img_ex.shape[1])
    # Identity over channels times the spatial kernel: weight laid out as
    # (c_out, c_in, kH, kW) that filters each channel independently
    # (assumes `kernel` is 2-D -- TODO confirm with weights_downsample).
    filters = torch.reshape(torch.eye(c, c), [c, c, 1, 1]) * kernel
    # Follow the image's floating dtype so float16/float64 inputs do not hit
    # a dtype mismatch in conv2d; non-float inputs keep float32 as before.
    dtype = img_ex.dtype if img_ex.is_floating_point() else torch.float32
    filters = filters.to(device=img_ex.device, dtype=dtype)
    img_s = F.conv2d(img_ex, filters, stride=s)
    return _pop_shape(img_s, shape)
def downsample(img, scale, border='REFLECT'):
    """Bicubic downsample via **CONV2D**, using PIL's kernel (TensorFlow).

    Args:
        img: a tf tensor of 2/3/4-D. After ``_push_shape_4d`` the channel
            count is read from the last axis, matching the default NHWC
            layout of ``tf.nn.conv2d`` used below.
        scale: n or 1/n. `n` must be integer >= 2.
        border: padding mode forwarded to ``tf.pad``; 'REFLECT' is
            recommended.

    Returns:
        The strided-convolution result passed through ``_pop_shape`` with the
        shape saved by ``_push_shape_4d`` (presumably restoring the input's
        original rank -- confirm against those helpers). When the stride
        returned by ``weights_downsample`` is 1, ``img`` is returned as-is.

    Raises:
        ValueError: if the channel dimension is not statically known.
    """
    kernel, s = weights_downsample(scale)
    if s == 1:
        return img  # bypass: nothing to downsample
    kernel = tf.convert_to_tensor(kernel, dtype='float32')
    # Asymmetric padding: p1 before and p2 after each spatial axis, 4*s total.
    p1 = int(s * 3 / 2)
    p2 = 4 * s - p1
    img, shape = _push_shape_4d(img)
    img_ex = tf.pad(img, [[0, 0], [p1, p2], [p1, p2], [0, 0]], border)
    # dimension_value yields None for an unknown dim under both TF1
    # (Dimension objects) and TF2 (plain ints/None); an explicit raise also
    # survives `python -O`, unlike an assert.
    c = tf.compat.dimension_value(img_ex.shape[-1])
    if c is None:
        raise ValueError("img must define channel number")
    c = int(c)
    # Identity over channels times the spatial kernel, then moved to the
    # [kH, kW, in_channels, out_channels] layout tf.nn.conv2d expects
    # (assumes `kernel` is 2-D -- TODO confirm with weights_downsample).
    filters = tf.reshape(tf.eye(c, c), [c, c, 1, 1]) * kernel
    filters = tf.transpose(filters, [2, 3, 0, 1])
    # Follow the image's floating dtype so float16/float64/bfloat16 inputs
    # do not hit a dtype mismatch in conv2d; other dtypes keep float32.
    dtype = img_ex.dtype if img_ex.dtype.is_floating else tf.float32
    filters = tf.cast(filters, dtype)
    img_s = tf.nn.conv2d(img_ex, filters, [1, s, s, 1], 'VALID')
    return _pop_shape(img_s, shape)