Example #1
    def forward(self, input, style):
        batch, in_channel, height, width = input.shape

        style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
        weight = self.scale * self.weight * style

        if self.demodulate:
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
            weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)

        weight = weight.view(batch * self.out_channel, in_channel,
                             self.kernel_size, self.kernel_size)

        if self.upsample:
            input = input.view(1, batch * in_channel, height, width)
            weight = weight.view(batch, self.out_channel, in_channel,
                                 self.kernel_size, self.kernel_size)
            weight = weight.transpose(1, 2).reshape(batch * in_channel,
                                                    self.out_channel,
                                                    self.kernel_size,
                                                    self.kernel_size)
            out = F.conv_transpose2d(input,
                                     weight,
                                     padding=0,
                                     stride=2,
                                     groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)
            out = self.blur(out)

        elif self.downsample:
            input = self.blur(input)
            _, _, height, width = input.shape
            input = input.view(1, batch * in_channel, height, width)
            out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)

        else:
            input = input.view(1, batch * in_channel, height, width)
            out = F.conv2d(input, weight, padding=self.padding, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)

        return out
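A minimal, self-contained sketch of the batched-kernel trick the example above relies on (names and sizes here are illustrative, not part of the module): folding the batch into the channel axis and convolving with groups=batch applies a different kernel to every sample in a single call.

import torch
import torch.nn.functional as F

batch, in_ch, out_ch, k = 2, 3, 4, 3
x = torch.randn(batch, in_ch, 8, 8)
weight = torch.randn(batch, out_ch, in_ch, k, k)  # one kernel per sample

out = F.conv2d(x.view(1, batch * in_ch, 8, 8),
               weight.view(batch * out_ch, in_ch, k, k),
               padding=k // 2, groups=batch).view(batch, out_ch, 8, 8)

# matches convolving each sample with its own kernel separately
ref = torch.stack([F.conv2d(x[i:i + 1], weight[i], padding=k // 2)[0]
                   for i in range(batch)])
assert torch.allclose(out, ref, atol=1e-5)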
Example #2
    def forward(self, x, style):
        """Forward function.

        Args:
            x (Tensor): Tensor with shape (b, c, h, w).
            style (Tensor): Tensor with shape (b, num_style_feat).

        Returns:
            Tensor: Modulated tensor after convolution.
        """
        b, c, h, w = x.shape  # c = c_in
        # weight modulation
        style = self.modulation(style).view(b, 1, c, 1, 1)
        # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
        weight = self.scale * self.weight * style  # (b, c_out, c_in, k, k)

        if self.demodulate:
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
            weight = weight * demod.view(b, self.out_channels, 1, 1, 1)

        weight = weight.view(b * self.out_channels, c, self.kernel_size,
                             self.kernel_size)

        if self.sample_mode == 'upsample':
            x = x.view(1, b * c, h, w)
            weight = weight.view(b, self.out_channels, c, self.kernel_size,
                                 self.kernel_size)
            weight = weight.transpose(1, 2).reshape(b * c, self.out_channels,
                                                    self.kernel_size,
                                                    self.kernel_size)
            out = F.conv_transpose2d(x, weight, padding=0, stride=2, groups=b)
            out = out.view(b, self.out_channels, *out.shape[2:4])
            out = self.smooth(out)
        elif self.sample_mode == 'downsample':
            x = self.smooth(x)
            x = x.view(1, b * c, *x.shape[2:4])
            out = F.conv2d(x, weight, padding=0, stride=2, groups=b)
            out = out.view(b, self.out_channels, *out.shape[2:4])
        else:
            x = x.view(1, b * c, h, w)
            # weight: (b*c_out, c_in, k, k), groups=b
            out = F.conv2d(x, weight, padding=self.padding, groups=b)
            out = out.view(b, self.out_channels, *out.shape[2:4])

        return out
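A small standalone check (made-up sizes) of what the demodulation step in both modulated convolutions does: every per-sample output filter is rescaled to approximately unit L2 norm.

import torch

b, c_out, c_in, k = 2, 8, 4, 3
weight = torch.randn(b, c_out, c_in, k, k)
demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
weight = weight * demod.view(b, c_out, 1, 1, 1)
print(weight.pow(2).sum([2, 3, 4]))  # every entry is ~1.0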
Example #3
 def forward(self, A, S):
     A = A.reshape(A.shape[0], -1, self.dl, *A.shape[-2:])
     x = torch.empty(A.shape[0],
                     A.shape[1],
                     3,
                     *S.shape[-2:],
                     dtype=A.dtype).to(A.device)
     for ii in range(A.shape[0]):
         x[ii] = F.conv_transpose2d(A[ii],
                                    S[ii],
                                    stride=self.trans_stride,
                                    padding=self.trans_padding)
     x = x.reshape(x.shape[0], -1, *x.shape[-2:])
     x = self.cls(x)
     x = self.pool(x)
     x = x.view(x.shape[0], -1)
     x = self.lr(x)
     return x
 def forward(self, x):
     """
     Initialization
     """
     if not hasattr(self.weight, "latent_"):
         self.weight.latent_ = self.weight.data
     self.weight.data = binarize(self.weight.latent_)
     if self.bias is not None:
         self.bias.latent_ = self.bias.data.clone()
     return F.conv_transpose2d(
         input=x,
         weight=self.weight,
         bias=self.bias,
         stride=self.stride,
         padding=self.padding,
         groups=self.groups,
         dilation=self.dilation,
     )
def hollowUpSampleTensor(inputData):
    '''
    Sparse ("hollow") upsampling.
    Each group of four consecutive channels is upsampled onto the four
    positions of a 2x2 block, following the pattern
    [1 0  [0 1  [0 0  [0 0
     0 0]  0 0]  1 0]  0 1]

    :param inputData: batchSize * c * w * h
    :return: batchSize * c * 2w * 2h
    '''
    batchSize, c, w, h = inputData.shape
    kernel = torch.zeros(size=[c, 1, 2, 2], device=inputData.device)
    kernel[0::4, 0, 0, 0] = 1
    kernel[1::4, 0, 0, 1] = 1
    kernel[2::4, 0, 1, 0] = 1
    kernel[3::4, 0, 1, 1] = 1
    return F.conv_transpose2d(inputData, kernel, stride=2, groups=c)
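A quick usage sketch for the helper above (assuming torch and torch.nn.functional as F are already imported): the spatial size doubles, and channel i writes into one of the four positions of each 2x2 output block, cycling through the pattern every four channels.

x = torch.arange(16, dtype=torch.float32).view(1, 4, 2, 2)
y = hollowUpSampleTensor(x)
print(y.shape)   # torch.Size([1, 4, 4, 4])
print(y[0, 0])   # channel 0 lands on the top-left corner of each 2x2 block
print(y[0, 3])   # channel 3 lands on the bottom-right corner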
Example #6
    def forward(self, t, x):
        self.nfe += 1
        y = self.norm1_1(x[0])
        z = self.norm1_2(x[1])

        out_y = F.conv2d(
            z, self.kernel, stride=self.stride, padding=1,
            bias=self.bias_1) - self.gamma * y
        out_y = self.leakyrelu(out_y)
        out_z = -F.conv_transpose2d(
            y, self.kernel, stride=self.stride, padding=1,
            bias=self.bias_2) - self.gamma * z
        out_z = self.leakyrelu(out_z)

        out_y = self.norm2_1(out_y)
        out_z = self.norm2_2(out_z)

        return out_y, out_z
Example #7
    def forward(self, input, output_size=None):
        weight = kaiming_normal_scale(self.weight * self.scale,
                                      a=0.2,
                                      mode="fan_in",
                                      nonlinearity="leaky_relu")
        output_padding = self._output_padding(input, output_size, self.stride,
                                              self.padding, self.kernel_size)

        return F.conv_transpose2d(
            input,
            weight,
            self.bias,
            self.stride,
            self.padding,
            output_padding,
            self.groups,
            self.dilation,
        )
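The _output_padding call above exists because a strided transposed convolution has an ambiguous output size: several conv2d input sizes map to the same output size, and output_size/output_padding selects which one to reproduce. A standalone illustration with made-up shapes:

import torch
import torch.nn.functional as F

x = torch.randn(1, 8, 5, 5)
w = torch.randn(8, 4, 3, 3)  # conv_transpose2d weight layout: (in, out, kH, kW)

# output size: (5 - 1) * 2 - 2 * 1 + (3 - 1) + 1 + output_padding
y0 = F.conv_transpose2d(x, w, stride=2, padding=1, output_padding=0)
y1 = F.conv_transpose2d(x, w, stride=2, padding=1, output_padding=1)
print(y0.shape[-1], y1.shape[-1])  # 9 10 -- both map back to 5 under a stride-2 conv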
Example #8
    def forward(self, x):
        fused_scale = self.fused_scale
        if fused_scale == "auto":
            fused_scale = min(x.shape[2:]) * 2 >= 128

        if not fused_scale:
            x = upscale2d(x)
            x = F.conv2d(x, self.weight, padding=1)
        else:
            w = self.weight.permute(1, 0, 2, 3)
            w = F.pad(w, (1, 1, 1, 1))
            w = w[:, :, 1:, 1:] + w[:, :, :-1,
                                    1:] + w[:, :, 1:, :-1] + w[:, :, :-1, :-1]
            x = F.conv_transpose2d(x,
                                   w,
                                   stride=2,
                                   padding=(w.size(-1) - 1) // 2)
        return x
Example #9
    def forward(self, x):

        ih, iw = x.size()[-2:]
        kh, kw = self.kernel_size
        sh, sw = self.stride
        dh, dw = self.dilation
        oh, ow = math.ceil(ih * sh), math.ceil(iw * sw)
        pad_h = max(((ih - 1) * sh + (kh - 1) * dh + 1 - oh), 0)
        pad_w = max(((iw - 1) * sw + (kw - 1) * dw + 1 - ow), 0)
        x = F.conv_transpose2d(x, self.weight, self.bias, self.stride,
                               self.padding, self.output_padding, self.groups,
                               self.dilation)
        if pad_h > 0 or pad_w > 0:
            xh, xw = x.size()[-2:]
            x = x[:, :, pad_h // 2:xh - (pad_h - pad_h // 2),
                  pad_w // 2:xw - (pad_w - pad_w // 2)]

        return x
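A worked size check of the cropping above, using kernel 3, stride 2, dilation 1 as example values and assuming the layer was built with padding=0 and output_padding=0 (those attributes are not shown in the snippet):

import math

ih, kh, sh, dh = 8, 3, 2, 1
oh = math.ceil(ih * sh)                                  # target height: 16
raw = (ih - 1) * sh + dh * (kh - 1) + 1                  # raw conv_transpose2d height: 17
pad_h = max((ih - 1) * sh + (kh - 1) * dh + 1 - oh, 0)   # rows to crop: 1
print(raw - pad_h == oh)                                 # True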
def render_attn_frame(attn: np.ndarray, receptive_field: int, stride: int,
                      padding: int):
    attn = torch.Tensor(attn)
    b, h, w = attn.shape
    true_attn = torch.empty((
        b,
        receptive_field + (h - 1) * stride - 2 * padding,
        receptive_field + (w - 1) * stride - 2 * padding,
    ))
    flt = torch.ones((1, 1, receptive_field, receptive_field))
    for bb in range(b):
        true_attn[bb, :, :] = F.conv_transpose2d(
            attn[bb, :, :][None, None, ...],
            flt,
            stride=stride,
            padding=padding,
        )
    return true_attn.clamp(min=0).numpy()
Example #11
    def forward(self, x, rev=False, jac=True):
        if jac:
            warnings.warn(
                'Invertible Autoencoder layers do not have a tractable log-det-Jacobian. '
                'It approaches 0 at convergence, but the value may be incorrect during training.'
            )

        if not rev:
            out = self.conv2d(x[0])
            out += self.bias
        else:
            out = x[0] - self.bias
            out = f.conv_transpose2d(out,
                                     self.conv2d.weight,
                                     bias=None,
                                     padding=self.padding)

        return [out], 0.
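The reverse path above reuses the forward conv weight inside conv_transpose2d, which is the adjoint (transpose) of conv2d rather than its exact inverse; hence the warning that the log-det-Jacobian is only approximately zero. The adjoint identity itself is easy to verify in isolation:

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8, 8)
W = torch.randn(5, 3, 3, 3)
y = torch.randn(2, 5, 8, 8)

lhs = (F.conv2d(x, W, padding=1) * y).sum()            # <Conv(x), y>
rhs = (x * F.conv_transpose2d(y, W, padding=1)).sum()  # <x, Conv^T(y)>
print(lhs.item(), rhs.item())  # equal up to floating-point error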
    def test_weight_fake_quant_per_tensor(self):
        kernel_size = 8

        quant_conv_object = quant_conv.QuantConvTranspose2d(
            _NUM_IN_CHANNELS,
            _NUM_OUT_CHANNELS,
            kernel_size,
            bias=False,
            quant_desc_weight=QuantDescriptor())
        quant_conv_object.input_quantizer.disable()
        test_input = torch.randn(256, _NUM_IN_CHANNELS, 32, 32)

        weight_copy = quant_conv_object.weight.clone()
        quant_weight = tensor_quant.fake_tensor_quant(weight_copy, torch.max(torch.abs(weight_copy)))

        out1 = F.conv_transpose2d(test_input, quant_weight)
        out2 = quant_conv_object(test_input)
        np.testing.assert_array_equal(out1.detach().cpu().numpy(), out2.detach().cpu().numpy())
    def test_no_quant(self):
        kernel_size = 3

        quant_conv_object = quant_conv.QuantConvTranspose2d(
            _NUM_IN_CHANNELS,
            _NUM_OUT_CHANNELS,
            kernel_size,
            bias=False)
        quant_conv_object.input_quantizer.disable()
        quant_conv_object.weight_quantizer.disable()
        test_input = torch.randn(16, _NUM_IN_CHANNELS, 32, 32)

        weight_copy = quant_conv_object.weight.clone()
        quant_weight = weight_copy

        out1 = F.conv_transpose2d(test_input, quant_weight)
        out2 = quant_conv_object(test_input)
        np.testing.assert_array_equal(out1.detach().cpu().numpy(), out2.detach().cpu().numpy())
Example #14
    def forward(self, input):
        weight = F.pad(self.weight * self.multiplier, [1, 1, 1, 1])
        weight = (
            weight[:, :, 1:, 1:]
            + weight[:, :, :-1, 1:]
            + weight[:, :, 1:, :-1]
            + weight[:, :, :-1, :-1]
        ) / 4

        if hasattr(self.quant, "activation_post_process") and self.weight_fake_quant is None:
            self.weight_fake_quant = self.quant.qconfig.weight().to(self.weight.device)
        
        if self.weight_fake_quant:
            weight = self.weight_fake_quant(weight)
            
        out = F.conv_transpose2d(input, weight, self.bias, stride=2, padding=self.pad)

        return self.quant(out)
    def forward(self, x):  # call each layer in turn
        x = self.pixel_norm(x)
        x = self.upsample(x)

        if self.use_conv2d_transpose:
            kernel = self.weight * self.scale
            kernel = F.pad(kernel, (0, 0, 0, 0, 1, 1, 1, 1), 'constant', 0.0)
            kernel = (kernel[1:, 1:] + kernel[:-1, 1:] + kernel[1:, :-1] +
                      kernel[:-1, :-1])
            kernel = kernel.permute(2, 3, 0, 1)
            x = F.conv_transpose2d(x, kernel, stride=2, padding=1)  # transposed convolution
            x = x / self.scale
        else:
            x = self.conv(x)

        x = self.wscale(x)
        x = self.activate(x)
        return x
Example #16
    def forward(self, x, rev=False):
        if not rev:
            self.elements = x.shape[1] * x.shape[2] * x.shape[3]
            self.last_jac = self.elements / 4 * np.log(1/16.)

            out = F.conv2d(x, self.haar_weights, bias=None, stride=2, groups=self.channel_in) / 4.0
            out = out.reshape([x.shape[0], self.channel_in, 4, x.shape[2] // 2, x.shape[3] // 2])
            out = torch.transpose(out, 1, 2)
            out = out.reshape([x.shape[0], self.channel_in * 4, x.shape[2] // 2, x.shape[3] // 2])
            return out
        else:
            self.elements = x.shape[1] * x.shape[2] * x.shape[3]
            self.last_jac = self.elements / 4 * np.log(16.)

            out = x.reshape([x.shape[0], 4, self.channel_in, x.shape[2], x.shape[3]])
            out = torch.transpose(out, 1, 2)
            out = out.reshape([x.shape[0], self.channel_in * 4, x.shape[2], x.shape[3]])
            return F.conv_transpose2d(out, self.haar_weights, bias=None, stride=2, groups = self.channel_in)
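A standalone sketch of why the reverse branch above inverts the forward one. The Haar weight construction here is an assumption about the module's (unshown) haar_weights, but it matches the grouped layout used in the calls: the four 2x2 filters are mutually orthogonal with squared norm 4, so conv2d with a 1/4 scale followed by conv_transpose2d with the same grouped weights reconstructs the input.

import torch
import torch.nn.functional as F

c = 3
base = torch.tensor([[[1.,  1.], [ 1.,  1.]],
                     [[1.,  1.], [-1., -1.]],
                     [[1., -1.], [ 1., -1.]],
                     [[1., -1.], [-1.,  1.]]]).unsqueeze(1)  # (4, 1, 2, 2)
haar = base.repeat(c, 1, 1, 1)                               # (4*c, 1, 2, 2)

x = torch.randn(2, c, 8, 8)
fwd = F.conv2d(x, haar, bias=None, stride=2, groups=c) / 4.0
inv = F.conv_transpose2d(fwd, haar, bias=None, stride=2, groups=c)
print(torch.allclose(inv, x, atol=1e-5))  # True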
    def forward(self, x):
        weight = self.weight * self.weight_scale * self.lr_multiplier
        if self.scale_factor > 1:
            weight = weight.flip(0, 1).permute(2, 3, 0, 1)
            x = F.conv_transpose2d(x,
                                   weight,
                                   stride=self.scale_factor,
                                   padding=0)
            x = self.filter(x)
        else:
            weight = weight.permute(3, 2, 0, 1)
            x = F.conv2d(x, weight, stride=1, padding=self.conv_padding)

        if self.add_bias:
            bias = self.bias * self.lr_multiplier
            x = x + bias.view(1, -1, 1, 1)
        x = self.activate(x) * self.activate_scale
        return x
Example #18
def conv_power_method(D, image_size, num_iters=100, stride=1):
    """
    Finds the maximal eigenvalue of D.T.dot(D) using the iterative power method
    :param D:
    :param num_needles:
    :param image_size:
    :param patch_size:
    :param num_iters:
    :return:
    """
    needles_shape = [int(((image_size[0] - D.shape[-2])/stride)+1), int(((image_size[1] - D.shape[-1])/stride)+1)]
    x = torch.randn(1, D.shape[0], *needles_shape).type_as(D)
    for _ in range(num_iters):
        c = torch.norm(x.reshape(-1))
        x = x / c
        y = functional.conv_transpose2d(x, D, stride=stride)
        x = functional.conv2d(y, D, stride=stride)
    return torch.norm(x.reshape(-1))
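A usage sketch for the routine above, with a random dictionary and made-up sizes (`functional` is assumed to alias torch.nn.functional, as in the snippet): with enough iterations the power method stabilises, so two independent runs produce near-identical estimates.

import torch

D = torch.randn(16, 3, 5, 5)  # (num_filters, channels, kh, kw)
L1 = conv_power_method(D, image_size=(32, 32), num_iters=200, stride=1)
L2 = conv_power_method(D, image_size=(32, 32), num_iters=400, stride=1)
print(float(L1), float(L2))   # converged estimates of the maximal eigenvalue agree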
Example #19
def aten_convolution(inputs, attributes, scope):
    inp, weight, bias = inputs[:3]
    stride, pad, dilation = inputs[3:6]
    transposed, output_padding, groups = inputs[6:9]
    net = current_network()
    if net is not None and has_trt_tensor(inputs):
        assert all([e == 0 for e in output_padding
                    ]), "tensor rt don't support out padding"
        if transposed:
            I, O_groups, *ksize = weight.shape
            O = O_groups * groups
        else:
            O, I_groups, *ksize = weight.shape
            I = I_groups * groups
        ndim = len(ksize)
        assert ndim == 2, "tensorrt only support 2d conv"
        # trt weight format: GKCRS: [num_groups, O_groups, I, H, W]
        weight = weight.detach().cpu().numpy()
        if bias is not None:
            bias = bias.detach().cpu().numpy()
        else:
            bias = trt.Weights()
        if transposed:
            layer = net.add_deconvolution(inputs[0], O, tuple(ksize), weight,
                                          bias)
        else:
            layer = net.add_convolution(inputs[0], O, tuple(ksize), weight,
                                        bias)
            layer.dilation = tuple(dilation)
        layer.stride = tuple(stride)
        layer.padding = tuple(pad)
        layer.num_groups = groups
        output = layer.get_output(0)
        output.name = scope
        layer.name = scope
        return [output]
    ndim = len(inputs[3])
    assert ndim == 2
    if transposed:
        res = F.conv_transpose2d(inp, weight, bias, stride, pad,
                                 output_padding, groups, dilation)
    else:
        res = F.conv2d(inp, weight, bias, stride, pad, dilation, groups)
    return [res]
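The different unpacking of weight.shape for the transposed and the regular case above reflects the two weight layouts: conv2d weights are (out_channels, in_channels // groups, kH, kW) while conv_transpose2d weights are (in_channels, out_channels // groups, kH, kW). A quick check:

import torch.nn as nn

print(nn.Conv2d(6, 8, 3, groups=2).weight.shape)           # torch.Size([8, 3, 3, 3])
print(nn.ConvTranspose2d(6, 8, 3, groups=2).weight.shape)  # torch.Size([6, 4, 3, 3])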
    def forward(self, x):
        bias = self.bias
        if bias is not None:
            bias = bias * self.b_mul
        
        have_convolution = False
        if self.upscale is not None and min(x.shape[2:]) * 2 >= 128:
            # this is the fused upscale + conv from StyleGAN, sadly this seems incompatible with the non-fused way
            # this really needs to be cleaned up and go into the conv...
            w = self.weight * self.w_mul
            w = w.permute(1, 0, 2, 3)
            # probably applying a conv on w would be more efficient. also this quadruples the weight (average)?!
            w = F.pad(w, (1,1,1,1))
            w = w[:, :, 1:, 1:]+ w[:, :, :-1, 1:] + w[:, :, 1:, :-1] + w[:, :, :-1, :-1]
            x = F.conv_transpose2d(x, w, stride=2, padding=(w.size(-1)-1)//2)
            have_convolution = True
        elif self.upscale is not None:
            x = self.upscale(x)
        
        downscale = self.downscale
        intermediate = self.intermediate
        if downscale is not None and min(x.shape[2:]) >= 128:
            w = self.weight * self.w_mul
            w = F.pad(w, (1,1,1,1))
            # in contrast to upscale, this is a mean...
            w = (w[:, :, 1:, 1:]+ w[:, :, :-1, 1:] + w[:, :, 1:, :-1] + w[:, :, :-1, :-1])*0.25 # avg_pool?
            x = F.conv2d(x, w, stride=2, padding=(w.size(-1)-1)//2)
            have_convolution = True
            downscale = None
        elif downscale is not None:
            assert intermediate is None
            intermediate = downscale
            
        if not have_convolution and intermediate is None:
            return F.conv2d(x, self.weight * self.w_mul, bias, padding=self.kernel_size//2)
        elif not have_convolution:
            x = F.conv2d(x, self.weight * self.w_mul, None, padding=self.kernel_size//2)

        if intermediate is not None:
            x = intermediate(x)

        if bias is not None:
            x = x + bias.view(1, -1, 1, 1)
        return x
Example #21
    def forward(self, input, output_size=None):
        if self.train_scale:
            weight_scale = self.weight_scale
        else:
            weight_scale = Variable(self.weight_scale)
        # normalize weight matrix and linear projection [in x out x h x w]
        # for each output dimension, normalize through (in, h, w)  = (0, 2, 3) dims
        norm_weight = self.weight * (
            weight_scale[None, :, None, None] / torch.sqrt(
                (self.weight**2).sum(3, keepdim=True).sum(2, keepdim=True).sum(
                    0, keepdim=True) + 1e-6)).expand_as(self.weight)
        #norm_weight = self.weight * (weight_scale[None,:,None,None] / torch.sqrt((self.weight ** 2).sum(3).sum(2).sum(0) + 1e-6)).expand_as(self.weight)
        output_padding = self._output_padding(input, output_size)
        activation = F.conv_transpose2d(input,
                                        norm_weight,
                                        bias=None,
                                        stride=self.stride,
                                        padding=self.padding,
                                        output_padding=output_padding,
                                        groups=self.groups)

        if self.init_mode:
            mean_act = activation.mean(3).mean(2).mean(0).squeeze()
            activation = activation - mean_act[None, :, None,
                                               None].expand_as(activation)

            inv_stdv = self.init_stdv / torch.sqrt(
                (activation**2).mean(3).mean(2).mean(0) + 1e-6).squeeze()
            activation = activation * inv_stdv[None, :, None,
                                               None].expand_as(activation)

            if self.train_scale:
                self.weight_scale.data = self.weight_scale.data * inv_stdv.data
            else:
                self.weight_scale = self.weight_scale * inv_stdv.data
            self.bias.data = -mean_act.data * inv_stdv.data

        else:
            if self.bias is not None:
                activation = activation + self.bias[None, :, None,
                                                    None].expand_as(activation)

        return activation
 def conv2d_with_style_kernels(self, features, kernels, patch_size, deconv_flag=False):
     output = list()
     b, c, h, w = features.size()
     
     # padding
     pad = (patch_size - 1) // 2
     padding_size = (pad, pad, pad, pad)
     
     # batch-wise convolutions with style kernels
     for feature, kernel in zip(features, kernels):
         feature = F.pad(feature.unsqueeze(0), padding_size, 'constant', 0)
             
         if deconv_flag:
             # use a separate name so padding_size (used by F.pad above) is not clobbered
             deconv_padding = patch_size - 1
             output.append(F.conv_transpose2d(feature, kernel, padding=deconv_padding))
         else:
             output.append(F.conv2d(feature, kernel))
     
     return torch.cat(output, dim=0)
Example #23
    def patch_copy_deconv(self, attention_score, context_filter):
        """Copy patches using deconv.

        Args:
            attention_score (torch.Tensor): Tensor with shape of (n, l , h, w).
            context_filter (torch.Tensor): Filter kernel.

        Returns:
            torch.Tensor: Tensor with shape of (n, c, h, w).
        """
        n, num_context, h, w = attention_score.size()
        attention_score = attention_score.view(1, -1, h, w)
        output = F.conv_transpose2d(attention_score,
                                    context_filter,
                                    stride=self.unfold_raw_stride,
                                    padding=self.unfold_raw_padding,
                                    groups=n)
        h_out, w_out = output.size()[-2:]
        return output.view(n, -1, h_out, w_out)
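The view to (1, n*l, h, w) combined with groups=n above is the same per-sample trick as in Example #1, applied to transposed convolution: each sample's attention maps are deconvolved only against that sample's own bank of context patches. A shape-level sketch with made-up sizes (stride 1 and padding 0 assumed):

import torch
import torch.nn.functional as F

n, l, h, w, c, k = 2, 9, 16, 16, 3, 4
attention_score = torch.softmax(torch.randn(n, l, h, w), dim=1)
context_filter = torch.randn(n * l, c, k, k)  # l patches of size k per sample

score = attention_score.view(1, n * l, h, w)
out = F.conv_transpose2d(score, context_filter, stride=1, padding=0, groups=n)
print(out.shape)  # torch.Size([1, 6, 19, 19]) -> view(n, -1, 19, 19)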
Example #24
        def forward(ctx, input, weight, bias):
            if not transpose:
                out = F.conv2d(input=input,
                               weight=weight,
                               bias=bias,
                               **common_kwargs)

            else:
                out = F.conv_transpose2d(
                    input=input,
                    weight=weight,
                    bias=bias,
                    output_padding=output_padding,
                    **common_kwargs,
                )

            ctx.save_for_backward(input, weight)

            return out
def skeletonize(img):

    # binarize
    ret, th = cv2.threshold(255 - img, 0, 1, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    skeleton = skeletonize_ski(th)

    skeleton = torch.from_numpy(skeleton*255)[None,None,...]
    morph_kernel_dilate=3
    dilate_weights = torch.FloatTensor(1,1,morph_kernel_dilate,morph_kernel_dilate)
    r = morph_kernel_dilate//2
    for x in range(morph_kernel_dilate):
        for y in range(morph_kernel_dilate):
            dilate_weights[0,0,y,x] = float(((y-r)**2 + (x-r)**2) <= (r**2))
    out = F.conv_transpose2d(skeleton.float(),dilate_weights,stride=1,padding=1)#,padding=morph_padding)

    blur_kernel = 3
    blur_padding = blur_kernel // 2
    blur = torch.nn.AvgPool2d((blur_kernel,blur_kernel), stride=1, padding=(blur_padding,blur_padding))
    return 255-blur(out)[0,0].numpy()
Example #26
File: stn.py Project: yf817/istn
    def compute_displacement(self, params):
        # compute dense displacement
        displacement = F.conv_transpose2d(params,
                                          self.kernel,
                                          padding=self.padding,
                                          stride=self.stride,
                                          groups=2)

        # crop displacement
        displacement = displacement[:, :, self.control_point_spacing[0] +
                                    self.crop_start[0]:
                                    -self.control_point_spacing[0] -
                                    self.crop_end[0],
                                    self.control_point_spacing[1] +
                                    self.crop_start[1]:
                                    -self.control_point_spacing[1] -
                                    self.crop_end[1]]

        return displacement.permute(0, 2, 3, 1)
Example #27
 def forward(self, x):
     weight = self.weight * self.wscale
     bias = self.bias * self.bscale if self.bias is not None else None
     if self.use_conv2d_transpose:
         weight = weight.permute(1, 0, 2, 3).flip(2, 3)
         x = F.conv_transpose2d(x,
                                weight=weight,
                                bias=bias,
                                stride=self.scale_factor,
                                padding=self.padding)
         x = self.filter(x)
     else:
         x = F.conv2d(x,
                      weight=weight,
                      bias=bias,
                      stride=self.stride,
                      padding=self.padding)
     x = self.activate(x) * self.activate_scale
     return x
Example #28
    def forward(self, input, sample=False):
        if self.training or sample:
            # during training we sample from the model distribution
            # sample = True can also be set during testing if we
            # want to use the stochastic/ensemble predictors
            weight = self.weight.sample()
            bias = self.bias.sample()
        else:
            # otherwise we use the posterior mean
            weight = self.weight.mu
            bias = self.bias.mu
        if self.training:
            # sum of the KL computed for weights and biases
            self.kl_div = self.weight.compute_kl(
                self.weight_prior) + self.bias.compute_kl(self.bias_prior)

        return F.conv_transpose2d(input, weight, bias, self.stride,
                                  self.padding, self.output_padding,
                                  self.groups, self.dilation)
Example #29
    def compute_weight(self, module, do_power_iteration):
        r"""Where the deed is done.
        """
        A = getattr(module, self.name + "_orig")  # this is the kernel
        u = getattr(module, self.name + "_u")  # left eigenvector
        v = getattr(module, self.name + "_v")  # right eigenvector
        sigma = getattr(module, self.name + "_sigma")  # sigma, of course
        eps = torch.tensor(self.eps, device=A.device)
        stride = module.stride
        padding = module.padding
        dilation = module.dilation

        if do_power_iteration:
            with torch.no_grad():
                for _ in range(self.n_power_iterations):
                    v_ = F.conv2d(
                        u, A, stride=stride, padding=padding, dilation=dilation
                    )
                    beta = torch.max(v_.norm(), eps)
                    v = torch.div(v_, beta, out=v)

                    u_ = F.conv_transpose2d(
                        v, A, stride=stride, padding=padding, dilation=dilation
                    )

                    # this is the largest eigenvalue
                    sigma.copy_(torch.max(u_.norm(), eps))
                    u = torch.div(u_, sigma, out=u)

                    # See above on why we need to clone
                    if self.n_power_iterations > 0:
                        u = u.clone(memory_format=torch.contiguous_format)
                        v = v.clone(memory_format=torch.contiguous_format)

        if self._active:
            if self._leave_smaller:
                A = A / max(sigma.item() / self._lipschitz_k, 1)
            else:
                A = A / (sigma.item() / self._lipschitz_k)
        else:
            A = A + 0

        return A
    def forward(self, input, mask_in=None):
        assert len(input.shape) == 4
        if mask_in is not None or self.last_size != tuple(input.shape):
            self.last_size = tuple(input.shape)

            with torch.no_grad():
                if self.weight_maskUpdater.type() != input.type():
                    self.weight_maskUpdater = self.weight_maskUpdater.to(input)

                if mask_in is None:
                    # if mask is not provided, create a mask
                    if self.multi_channel:
                        mask = torch.ones(input.data.shape[0], input.data.shape[1], input.data.shape[2], input.data.shape[3]).to(input)
                    else:
                        mask = torch.ones(1, 1, input.data.shape[2], input.data.shape[3]).to(input)
                else:
                    mask = mask_in
                        
                self.update_mask = F.conv_transpose2d(mask, self.weight_maskUpdater, bias=None, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=1)

                self.mask_ratio = self.slide_winsize/(self.update_mask + 1e-8)
                # self.mask_ratio = torch.max(self.update_mask)/(self.update_mask + 1e-8)
                self.update_mask = torch.clamp(self.update_mask, 0, 1)
                self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask)

        # if self.update_mask.type() != input.type() or self.mask_ratio.type() != input.type():
        #     self.update_mask.to(input)
        #     self.mask_ratio.to(input)

        raw_out = super(PartialConvTranspose2d, self).forward(torch.mul(input, mask) if mask_in is not None else input)

        if self.bias is not None:
            bias_view = self.bias.view(1, self.out_channels, 1, 1)
            output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view
            output = torch.mul(output, self.update_mask)
        else:
            output = torch.mul(raw_out, self.mask_ratio)


        if self.return_mask:
            return output, self.update_mask
        else:
            return output
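A minimal sketch of the mask bookkeeping above, assuming weight_maskUpdater is an all-ones kernel (that attribute is not shown here): the transposed convolution of the validity mask counts how many valid input pixels feed each output pixel, mask_ratio rescales the output accordingly, and the clamped result becomes the new mask.

import torch
import torch.nn.functional as F

mask = torch.ones(1, 1, 4, 4)
mask[:, :, 2:, :] = 0                 # bottom half of the input is invalid
ones_kernel = torch.ones(1, 1, 3, 3)  # assumed weight_maskUpdater
slide_winsize = ones_kernel.numel()

update_mask = F.conv_transpose2d(mask, ones_kernel, stride=2, padding=1)
mask_ratio = slide_winsize / (update_mask + 1e-8)
update_mask = torch.clamp(update_mask, 0, 1)
mask_ratio = mask_ratio * update_mask  # zero wherever no valid pixel contributed
print(update_mask.shape)               # torch.Size([1, 1, 7, 7])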