def forward(self, input):
    """Run `upfirdn2d` on `input` with an up-factor of `self.factor`.

    The call uses `self.kernel` as the filter, `up=self.factor`, `down=1`
    and `pad=self.pad`; its result is returned unchanged.
    """
    return upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
def conv_downsample_2d(x, w, k=None, factor=2, gain=1):
    """FIR-filter `x` with `upfirdn2d`, then apply a strided `F.conv2d`.

    All padding is applied in the FIR stage; the convolution itself runs
    with `padding=0` and a stride of `factor` in both spatial dimensions.

    Args:
        x: Input tensor handed to `upfirdn2d` and then `F.conv2d`
            (presumably `[N, C, H, W]` -- confirm against callers).
        w: 4-D weight tensor unpacked as
            `[outChannels, inChannels, filterH, filterW]`; the filter must
            be square (`filterH == filterW`).
        k: FIR filter passed to `_setup_kernel`. Defaults to `[1] * factor`.
        factor: Integer downsampling factor (default: 2), must be >= 1.
        gain: Scaling factor multiplied into the FIR kernel (default: 1).

    Returns:
        The tensor produced by `F.conv2d` on the filtered input.

    Raises:
        AssertionError: If `factor` is not an int >= 1 or `w` is not
            spatially square.
    """
    assert isinstance(factor, int) and factor >= 1
    _, _, kernel_h, kernel_w = w.shape
    assert kernel_w == kernel_h

    fir = _setup_kernel([1] * factor if k is None else k) * gain

    # Total padding compensates for both the FIR taps and the conv window.
    total_pad = (fir.shape[0] - factor) + (kernel_w - 1)
    pad_before = (total_pad + 1) // 2
    pad_after = total_pad // 2

    filtered = upfirdn2d(
        x,
        torch.tensor(fir, device=x.device),
        pad=(pad_before, pad_after),
    )
    return F.conv2d(filtered, w, stride=[factor, factor], padding=0)
def forward(self, input):
    """Filter `input` with `self.kernel` via `upfirdn2d`, padded by `self.pad`."""
    return upfirdn2d(input, self.kernel, pad=self.pad)
def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
    """Fused transposed convolution (upsampling) followed by FIR filtering.

    The weights are spatially flipped and applied with
    `F.conv_transpose2d` at a stride of `factor`; the result is then passed
    through `upfirdn2d` with the FIR kernel, so padding is applied only
    once, in the FIR stage.

    Args:
        x: Input tensor; channels are read from dimension 1 and spatial
            sizes from dimensions 2 and 3, i.e. `[N, C, H, W]`.
        w: Weight tensor of the shape
            `[outChannels, inChannels, filterH, filterW]` with
            `filterH == filterW`. Grouped convolution is selected by
            `inChannels = x.shape[1] // numGroups`.
        k: FIR filter passed to `_setup_kernel`. Defaults to `[1] * factor`.
        factor: Integer upsampling factor (default: 2), must be >= 1.
        gain: Scaling factor for signal magnitude (default: 1). The kernel
            is additionally scaled by `factor ** 2`.

    Returns:
        The tensor returned by `upfirdn2d` (presumably
        `[N, outChannels, H * factor, W * factor]` for the default
        `factor=2` -- confirm against `upfirdn2d`).

    Raises:
        AssertionError: If `factor` is not an int >= 1, `w` is not 4-D, or
            the filter is not spatially square.
    """
    assert isinstance(factor, int) and factor >= 1

    # Check weight shape.
    assert len(w.shape) == 4
    inC = w.shape[1]
    convH = w.shape[2]
    convW = w.shape[3]
    assert convW == convH

    # Setup filter kernel.
    if k is None:
        k = [1] * factor
    k = _setup_kernel(k) * (gain * (factor**2))
    p = (k.shape[0] - factor) - (convW - 1)

    num_groups = _shape(x, 1) // inC

    # Rearrange weights into the (in_channels, out_channels / groups, kH, kW)
    # layout expected by conv_transpose2d, flipping the spatial axes.
    # torch tensors do not support negative-step slicing (`[::-1]`), so the
    # flip has to go through `torch.flip`.
    w = torch.reshape(w, (num_groups, -1, inC, convH, convW))
    w = torch.flip(w, dims=(3, 4)).permute(0, 2, 1, 3, 4)
    w = torch.reshape(w, (num_groups * inC, -1, convH, convW))

    # conv_transpose2d takes a 2-element spatial stride (not the 4-element
    # TF-style one). With stride == factor the natural output size is
    # (H - 1) * factor + convH, so no output_padding is required.
    x = F.conv_transpose2d(
        x,
        w,
        stride=(factor, factor),
        padding=0,
        groups=num_groups,
    )

    return upfirdn2d(
        x,
        torch.tensor(k, device=x.device),
        pad=((p + 1) // 2 + factor - 1, p // 2 + 1),
    )
def forward(self, input):
    """Return `upfirdn2d(input, self.kernel, pad=self.pad)`.

    No `up`/`down` arguments are passed, so `upfirdn2d` uses its own
    defaults for resampling.
    """
    kernel, pad = self.kernel, self.pad
    return upfirdn2d(input, kernel, pad=pad)
def forward(self, input):
    """Apply `upfirdn2d` to `input` using the stored kernel and padding."""
    filtered = upfirdn2d(input, self.kernel, pad=self.pad)
    return filtered
def forward(self, input):
    """Run `upfirdn2d` on `input` with a down-factor of `self.factor`.

    The call uses `self.kernel` as the filter, `up=1`, `down=self.factor`
    and `pad=self.pad`; its result is returned unchanged.
    """
    out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
    return out