def forward(self, input):
    # Apply dropout to the weights (DropConnect-style) when a dropout
    # function is configured; otherwise convolve with the raw weights.
    if self.dropout_fn is not None:
        dropped_w = self.dropout_fn.forward(self.weight, self.training)
    else:
        dropped_w = self.weight

    if self.padding_mode == 'circular':
        # Wrap the input around its edges first, then convolve with zero
        # padding so the result matches a circular convolution.
        expanded_padding = ((self.padding[1] + 1) // 2, self.padding[1] // 2,
                            (self.padding[0] + 1) // 2, self.padding[0] // 2)
        return F.conv2d(F.pad(input, expanded_padding, mode='circular'),
                        dropped_w, self.bias, self.stride,
                        _pair(0), self.dilation, self.groups)

    return F.conv2d(input, dropped_w, self.bias, self.stride,
                    self.padding, self.dilation, self.groups)
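For reference, the 'circular' branch works by wrapping the input with F.pad before running an unpadded convolution. A minimal standalone sketch of that behaviour (not part of the original class):

import torch
import torch.nn.functional as F

x = torch.arange(9.).view(1, 1, 3, 3)
# The pad tuple is (left, right, top, bottom); 'circular' wraps values around.
padded = F.pad(x, (1, 1, 1, 1), mode='circular')
print(padded.shape)        # torch.Size([1, 1, 5, 5])
print(padded[0, 0, 0, 1])  # tensor(6.) -- the top padding row comes from the bottom row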
Example #2
def conv_ws_2d(input,
               weight,
               bias=None,
               stride=1,
               padding=0,
               dilation=1,
               groups=1,
               eps=1e-5):
    """Conv2d with weight standarlization.

    :param input: input feature map
    :type input: torch.Tensor
    :param weight: weight of conv layer
    :type weight: torch.Tensor
    :param bias: bias
    :type bias: torch.Tensor
    :param stride: conv stride
    :type stride: int
    :param padding: num of padding
    :type padding: int
    :param dilation: num of dilation
    :type dilation: int
    :param groups: num of group
    :type groups: int
    :param eps: weight eps
    :type eps: float
    :return: feature map after weight standardization
    :rtype: torch.Tensor
    """
    c_in = weight.size(0)
    weight_flat = weight.view(c_in, -1)
    mean = weight_flat.mean(dim=1, keepdim=True).view(c_in, 1, 1, 1)
    std = weight_flat.std(dim=1, keepdim=True).view(c_in, 1, 1, 1)
    weight = (weight - mean) / (std + eps)
    return F.conv2d(input, weight, bias, stride, padding, dilation, groups)
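A quick usage sketch for conv_ws_2d, reusing the weights of a standard nn.Conv2d (the surrounding setup is illustrative, not from the source):

import torch
import torch.nn as nn

conv = nn.Conv2d(16, 32, kernel_size=3, padding=1)
x = torch.randn(4, 16, 56, 56)
out = conv_ws_2d(x, conv.weight, conv.bias, stride=1, padding=1)
print(out.shape)  # torch.Size([4, 32, 56, 56])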
Example #3
def apply_cdna_kernel(image, kernel):
    """
        Inputs:
            image -> tensor[B, C, H, W]
            kernel -> tensor[B, N, K, K]
        Outputs:
            new_image -> tensor[B, N, C, H, W]
    """
    batch_size = image.shape[0]
    image_channel = image.shape[1]
    image_height = image.shape[2]
    image_width = image.shape[3]
    num_kernel = kernel.shape[1]
    kernel_size = kernel.shape[2]
    padding = kernel_size // 2

    _image = image.transpose(0, 1)  # [C, B, H, W]
    _kernel = kernel.view(batch_size * num_kernel, 1, kernel_size,
                          kernel_size)  # [B * N, 1, K, K]

    output = F.conv2d(_image,
                      _kernel,
                      stride=1,
                      padding=padding,
                      groups=batch_size)  # [C, B * N, H, W]

    output = output.view(image_channel, batch_size, num_kernel, image_height,
                         image_width)  # [C, B, N, H, W]

    return output.permute(1, 2, 0, 3, 4)  # [B, N, C, H, W]
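Usage sketch: CDNA kernels are typically normalized to sum to one over their spatial taps, so the softmax below is an illustrative assumption rather than part of the source:

import torch

image = torch.randn(2, 3, 64, 64)                         # [B, C, H, W]
logits = torch.randn(2, 10, 5 * 5)
kernel = torch.softmax(logits, dim=-1).view(2, 10, 5, 5)  # [B, N, K, K]
out = apply_cdna_kernel(image, kernel)
print(out.shape)  # torch.Size([2, 10, 3, 64, 64])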
Example #4
def forward(self, input):
    out = F.conv2d(input,
                   self.weight * self.scale,
                   bias=self.bias,
                   stride=self.stride,
                   padding=self.padding)
    return out
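This multiply-by-scale-at-runtime pattern matches the equalized learning rate trick from the StyleGAN family, where weights are sampled from N(0, 1) and rescaled on every forward pass. A hedged sketch of how such a scale is commonly derived (the variable names and values are assumptions, not taken from the snippet):

import math

in_channels, kernel_size = 512, 3        # illustrative values
fan_in = in_channels * kernel_size ** 2  # inputs feeding each output unit
scale = 1 / math.sqrt(fan_in)            # He-style constant applied in forward()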
Example #5
def _ssim(img1, img2, window, window_size, channel, size_average = True):
    mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
    mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1*mu2

    sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
    sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
    sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2

    C1 = 0.01**2
    C2 = 0.03**2

    ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)
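The window argument is normally a depthwise Gaussian filter of shape [channel, 1, K, K]. A self-contained sketch of building one and calling _ssim (the gaussian_window helper is hypothetical, not part of the source):

import torch

def gaussian_window(window_size, sigma, channel):
    # Separable Gaussian: outer product of a 1-D Gaussian with itself.
    coords = torch.arange(window_size, dtype=torch.float32) - window_size // 2
    g = torch.exp(-coords.pow(2) / (2 * sigma ** 2))
    g = (g / g.sum()).unsqueeze(0)                     # [1, K]
    window_2d = (g.t() @ g).unsqueeze(0).unsqueeze(0)  # [1, 1, K, K]
    return window_2d.expand(channel, 1, window_size, window_size).contiguous()

img1 = torch.rand(1, 3, 64, 64)
img2 = img1.clone()
window = gaussian_window(11, 1.5, channel=3)
print(_ssim(img1, img2, window, 11, 3))  # ~1.0 for identical images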
Example #6
def blur(self, tensor_image):
    # Note: this first kernel is defined but immediately overwritten by
    # `kernel_1` below; only the corner-tap kernel is actually used.
    kernel = [[0., 1., 1.],
              [0., 3., 1.],
              [0., 0., 0.]]
    kernel_1 = [[1., 0., 1.],
                [0., 0., 0.],
                [1., 0., 1.]]

    min_batch = tensor_image.size()[0]  # unused
    channels = tensor_image.size()[1]
    out_channel = channels
    kernel = torch.tensor(kernel_1).expand(out_channel, channels, 3, 3)
    weight = nn.Parameter(data=kernel, requires_grad=False)

    return F.conv2d(tensor_image, weight, padding=1)
Example #7
def forward(self, input, eps=1e-4):
    weight = self.standardize_weight(eps)
    return F.conv2d(input, weight, self.bias, self.stride, self.padding,
                    self.dilation, self.groups)
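standardize_weight is not shown in this snippet; based on the weight-standardization pattern in Example #2, a plausible method body would be the following (an assumption, not the original implementation):

import torch

def standardize_weight(self, eps):
    # Per-output-channel statistics over (in_channels, kH, kW).
    mean = self.weight.mean(dim=(1, 2, 3), keepdim=True)
    var = self.weight.var(dim=(1, 2, 3), keepdim=True)
    return (self.weight - mean) / torch.sqrt(var + eps)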
Example #8
def forward(self, x):
    return F.conv2d(x, self.get_weight(), self.bias, self.stride,
                    self.padding, self.dilation, self.groups)
Example #9
def forward(self, img):
    img = F.conv2d(img, self.weight, None, stride=self.stride, padding=self.padding)
    # Broadcast the bias over the whole batch; the original `img[0] += ...`
    # only added the bias to the first sample.
    img += self.bias.view(1, -1, 1, 1)
    return img
Example #10
    def forward(self, input, style):
        batch, in_channel, height, width = input.shape

        # Modulate the style
        style = self.modulation(style).view(batch, 1, in_channel, 1, 1)

        # Modulate the weights with the per-sample style
        weight = self.scale * self.weight * style

        if self.demodulate:
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
            weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
        # A good explanation of modulation/demodulation: https://youtu.be/MYCTn80qSk0?t=142

        weight = weight.view(batch * self.out_channel, in_channel,
                             self.kernel_size, self.kernel_size)

        if self.up_sample:

            # Reshape input for deconvolution
            input = input.view(1, batch * in_channel, height, width)

            # Reshape weights for deconvolution
            weight = weight.view(batch, self.out_channel, in_channel,
                                 self.kernel_size, self.kernel_size)
            weight = weight.transpose(1, 2).reshape(batch * in_channel,
                                                    self.out_channel,
                                                    self.kernel_size,
                                                    self.kernel_size)

            # Upsampling convolution
            out = F.conv_transpose2d(input,
                                     weight,
                                     padding=0,
                                     stride=2,
                                     groups=batch)
            _, _, height, width = out.shape

            # Reshape output to initial shape
            out = out.view(batch, self.out_channel, height, width)

            # Blur the output
            out = self.blur(out)

        elif self.down_sample:
            # Blur the input before convolution
            input = self.blur(input)

            # Reshape input for convolution
            _, _, height, width = input.shape
            input = input.view(1, batch * in_channel, height, width)

            # Downsampling convolution
            out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
            _, _, height, width = out.shape

            # Reshape output to initial shape
            out = out.view(batch, self.out_channel, height, width)

        else:
            # Reshape input for convolution
            input = input.view(1, batch * in_channel, height, width)

            # Plain convolution (no up- or down-sampling)
            out = F.conv2d(input, weight, padding=self.padding, groups=batch)
            _, _, height, width = out.shape

            # Reshape output to initial shape
            out = out.view(batch, self.out_channel, height, width)

        return out
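All three branches rely on the same grouped-convolution trick: folding the batch dimension into the channel dimension and setting groups=batch gives every sample its own filter bank in a single conv2d call. A minimal standalone demonstration (shapes chosen for illustration):

import torch
import torch.nn.functional as F

batch, in_ch, out_ch, k = 4, 8, 16, 3
x = torch.randn(batch, in_ch, 32, 32)
w = torch.randn(batch, out_ch, in_ch, k, k)   # one weight tensor per sample

out = F.conv2d(x.view(1, batch * in_ch, 32, 32),
               w.view(batch * out_ch, in_ch, k, k),
               padding=1, groups=batch).view(batch, out_ch, 32, 32)

# Same result as convolving each sample with its own weights:
ref = torch.cat([F.conv2d(x[i:i + 1], w[i], padding=1) for i in range(batch)])
print(torch.allclose(out, ref, atol=1e-5))  # True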