예제 #1
0
def _ht_fft2(z):
    """Full (two-sided) 2-D FFT of real ``z`` over its last two dimensions.

    Returns a real tensor with a trailing axis of size 2 holding (real, imag),
    the layout of the legacy ``torch.rfft(z, 2, onesided=False)``. Uses the
    ``torch.fft`` module when it is available and falls back to the legacy
    call on old PyTorch releases.
    """
    fft_mod = getattr(torch, "fft", None)
    if hasattr(fft_mod, "fft2"):
        return torch.view_as_real(fft_mod.fft2(z))
    return torch.rfft(z, 2, onesided=False)


def _ht_ifft2(spec):
    """Real 2-D inverse FFT of a (..., H, W, 2) (real, imag) spectrum.

    Uses ``torch.fft.irfft2`` with the full output size ``(H, W)`` when
    available, otherwise the legacy ``torch.irfft(spec, 2, onesided=False)``.
    For Hermitian-symmetric spectra (e.g. products of FFTs of real signals)
    the result is the inverse FFT, which is real.
    """
    fft_mod = getattr(torch, "fft", None)
    if hasattr(fft_mod, "irfft2"):
        spec_c = torch.view_as_complex(spec.contiguous())
        return fft_mod.irfft2(spec_c, s=tuple(spec.shape[-3:-1]))
    return torch.irfft(spec, 2, onesided=False)


def HT_y(y, sf, fft_BT):
    """Zero-upsample ``y`` by ``sf`` and filter it with the spectrum ``fft_BT``.

    Each input pixel is placed at the top-left corner of an ``sf x sf`` cell
    (the rest zeros). The result is multiplied in the Fourier domain by
    ``fft_BT``, i.e. circularly convolved, and transformed back. Judging by the
    name this is the adjoint H^T y of a blur-then-downsample operator H.

    Args:
        y: tensor of shape (ch, w, h) or (bs, ch, w, h).
        sf: integer upsampling factor.
        fft_BT: spectrum of shape (w*sf, h*sf, 2), last axis = (real, imag).

    Returns:
        Real tensor of shape (ch, w*sf, h*sf) or (bs, ch, w*sf, h*sf).

    Raises:
        ValueError: if ``y`` is neither 3-D nor 4-D.
    """
    if y.dim() == 3:
        # Unbatched input: reuse the batched path on a batch of one.
        return HT_y(y.unsqueeze(0), sf, fft_BT).squeeze(0)
    if y.dim() != 4:
        raise ValueError(
            "HT_y expects a 3-D (ch, w, h) or 4-D (bs, ch, w, h) tensor, "
            "got shape {}".format(tuple(y.shape)))

    bs, ch, w, h = y.shape
    # Give every channel its own group of sf*sf channels: the data first,
    # zeros after. pixel_shuffle then puts each pixel at the top-left corner
    # of its sf x sf output cell.
    z = y.reshape(-1, 1, w, h)
    z = F.pad(z, [0, 0, 0, 0, 0, sf * sf - 1, 0, 0], "constant", value=0)
    z = F.pixel_shuffle(z, upscale_factor=sf).view(bs, ch, w * sf, h * sf)

    f = _ht_fft2(z)  # (bs, ch, w*sf, h*sf, 2)
    # Broadcast the (w*sf, h*sf, 2) spectrum over the batch and channel axes.
    bt = fft_BT.unsqueeze(0).unsqueeze(0)
    f_re, f_im = f[..., 0], f[..., 1]
    b_re, b_im = bt[..., 0], bt[..., 1]
    # Complex product f * fft_BT, stored as (real, imag) in the last axis.
    M = torch.stack((f_re * b_re - f_im * b_im,
                     f_re * b_im + f_im * b_re), dim=-1)
    return _ht_ifft2(M)
예제 #2
0
    def img_sum(self, input_patch, input_img, kernel_size, weights_out):
        """Fold pixel-shuffled patches back into an image and average overlaps.

        For every offset (row, col) in [0, kernel_size)^2 the strided slice
        ``input_patch[:, :, row::kernel_size, col::kernel_size]`` is
        pixel-shuffled by ``kernel_size`` and accumulated into the output,
        starting at (row, col). The sum is then divided by a per-pixel count map.

        Args:
            input_patch: 4-D tensor; its channel count must be divisible by
                ``kernel_size ** 2`` for ``pixel_shuffle``.
            input_img: 4-D tensor whose size and device the output takes.
            kernel_size: slicing stride and shuffle factor.
            weights_out: count map from a previous call. It is reused when its
                spatial size matches ``input_img``; otherwise (or when None) a
                fresh map is accumulated on the CPU.

        Returns:
            (output, weights_output): the averaged image and the count map
            (moved to the output's device) so it can be passed back in.
        """
        output = torch.zeros(input_img.size()).to(input_img.device)

        rebuild_weights = (weights_out is None
                           or weights_out.size(2) != input_img.size(2)
                           or weights_out.size(3) != input_img.size(3))
        if rebuild_weights:
            ones = torch.ones(1, kernel_size * kernel_size,
                              input_patch.size(2), input_patch.size(3))
            counts = torch.zeros(1, 1, input_img.size(2), input_img.size(3))
        else:
            counts = weights_out

        offsets = [(row, col)
                   for row in range(kernel_size)
                   for col in range(kernel_size)]
        for row, col in offsets:
            shuffled = F.pixel_shuffle(
                input_patch[:, :, row::kernel_size, col::kernel_size],
                kernel_size)
            output[:, :, row:row + shuffled.size(2),
                   col:col + shuffled.size(3)] += shuffled

            if rebuild_weights:
                # Count how many shuffled patches cover each output pixel.
                cover = F.pixel_shuffle(
                    ones[:, :, row::kernel_size, col::kernel_size],
                    kernel_size)
                counts[:, :, row:row + cover.size(2),
                       col:col + cover.size(3)] += cover

        counts = counts.to(output.device)
        return output / counts, counts
예제 #3
0
    def forward(self, in_tensor):
        """Upsample ``in_tensor`` by ``self.delta`` using predicted per-pixel kernels.

        A kernel map is predicted (``down`` -> ``encoder`` -> pixel shuffle by
        ``delta`` -> softmax over the ``Kup**2`` taps). Every
        ``delta x delta`` sub-pixel cell is regrouped onto its source pixel.
        Each input pixel's zero-padded ``Kup x Kup`` neighbourhood is then
        combined with those weights by a batched matmul. The
        ``C * delta**2`` result is pixel-shuffled back and passed through
        ``self.out``.
        """
        n, c, height, width = in_tensor.size()
        k_up, scale = self.Kup, self.delta

        kernels = self.encoder(self.down(in_tensor))
        kernels = F.softmax(F.pixel_shuffle(kernels, scale), dim=1)
        # (N, Kup^2, H*delta, W*delta) -> (N, H, W, Kup^2, delta^2)
        kernels = (kernels.unfold(2, scale, step=scale)
                   .unfold(3, scale, step=scale)
                   .reshape(n, k_up ** 2, height, width, scale ** 2)
                   .permute(0, 2, 3, 1, 4))

        half = k_up // 2
        padded = F.pad(in_tensor, pad=(half, half, half, half),
                       mode='constant', value=0)
        # (N, C, H, W, Kup, Kup) -> (N, H, W, C, Kup^2)
        patches = (padded.unfold(dimension=2, size=k_up, step=1)
                   .unfold(3, k_up, step=1)
                   .reshape(n, c, height, width, -1)
                   .permute(0, 2, 3, 1, 4))

        # (N, H, W, C, Kup^2) @ (N, H, W, Kup^2, delta^2) -> (N, H, W, C, delta^2)
        reassembled = torch.matmul(patches, kernels)
        reassembled = reassembled.reshape(n, height, width, -1).permute(0, 3, 1, 2)
        return self.out(F.pixel_shuffle(reassembled, scale))
예제 #4
0
 def forward(self, input):
     """Residual 2x pixel-shuffle upsampling followed by ``self.relu``.

     Main path: conv1 -> relu -> pixel_shuffle(2) -> conv2.
     Skip path: upsample -> pixel_shuffle(2). The two are summed.
     """
     shuffled = F.pixel_shuffle(self.relu(self.conv1(input)), 2)
     skip = F.pixel_shuffle(self.upsample(input), 2)
     return self.relu(skip + self.conv2(shuffled))
예제 #5
0
 def forward(self, z):
     """Generate an image from latent ``z``.

     With ``up_type == 'shuffle'``, layer 0 maps ``z.view(-1, self.nz)``
     through ``convs[0]`` and reshapes it to an 8x8 map with
     ``self.ngf * 2**(gen_layer_num - 2)`` channels, followed by BN + ReLU.
     Every later layer applies its conv, then BN + ReLU (tanh without BN on
     the last layer), then a 2x pixel shuffle. For any other ``up_type``,
     each layer is conv -> BN -> ReLU, and the last one is conv -> tanh.
     """
     last = self.gen_layer_num - 1

     if self.up_type != 'shuffle':
         for il in range(self.gen_layer_num):
             conv_out = self.convs[il](z)
             if il == last:
                 z = functional.tanh(conv_out)
             else:
                 z = functional.relu(self.BNs[il](conv_out))
         return z

     for il in range(self.gen_layer_num):
         if il == 0:
             channels = self.ngf * (2**(self.gen_layer_num - 2 - il))
             z = self.convs[il](z.view(-1, self.nz)).view(-1, channels, 8, 8)
             z = functional.relu(self.BNs[il](z))
             continue
         z = self.convs[il](z)
         z = functional.tanh(z) if il == last else functional.relu(self.BNs[il](z))
         z = functional.pixel_shuffle(z, 2)
     return z
예제 #6
0
    def forward(self, x):
        """Three-branch gated fusion network, cropped back to the input size.

        ``x`` is first passed through ``fill`` (defined elsewhere; presumably
        padding, since the result is cropped to the original ``h x w`` at the
        end). The ``down4`` branch (top1..top3) is pixel-shuffled by 2 and
        concatenated into the ``down2`` branch before ``bottom_gate``. That
        branch (bottom2, bottom3) is pixel-shuffled by 2 and concatenated into
        the full-input branch before ``main_gate``, then main2, main3 and
        ``end`` follow.
        """
        _, _, height, width = x.size()
        x = fill(x)

        top = self.down4(x)
        bottom = self.down2(x)

        for layer in (self.top1, self.top2, self.top3):
            top = layer(top)
        top = F.pixel_shuffle(top, 2)

        bottom = self.bottom_gate(torch.cat((self.bottom1(bottom), top), 1))
        bottom = F.pixel_shuffle(self.bottom3(self.bottom2(bottom)), 2)

        main = self.main_gate(torch.cat((self.main1(x), bottom), 1))
        main = self.end(self.main3(self.main2(main)))
        return main[:, :, :height, :width]
    def forward(self, x, h0, h1, h2, h3, h4):
        """Decode through conv1, five recurrent stages and conv2.

        Each ``self.rnnK(x, hK)`` returns a state whose first element is
        pixel-shuffled by 2 and fed to the next stage.

        Returns:
            (tanh(conv2(x)) / 2, h0_new, h1_new, h2_new, h3_new, h4_new)
        """
        x = self.conv1(x)
        stages = ((self.rnn0, h0), (self.rnn1, h1), (self.rnn2, h2),
                  (self.rnn3, h3), (self.rnn4, h4))
        new_states = []
        for rnn, state in stages:
            updated = rnn(x, state)
            new_states.append(updated)
            x = F.pixel_shuffle(updated[0], 2)
        return (self.conv2(x).tanh() / 2, *new_states)
예제 #8
0
 def forward(self, x, HR):
     """Predict a residual from the 5-D input ``x`` and add it to ``HR``.

     Trunk: three conv3d -> BN -> ReLU stages, ``conv3d_2_2``, ``head`` and
     ``middle_1..4``, then ``last`` (its second output is discarded), the
     fusion modules, ``compress`` and ``middle_6..12``. Residual branch: two
     rounds of conv plus a 2x pixel shuffle on axis 2 squeezed away.
     """
     B, C, T, H, W = x.shape  # kept: requires a 5-D input
     for conv, bn in ((self.conv3d_1, self.bn3d_1),
                      (self.conv3d_2, self.bn3d_2),
                      (self.conv3d_2_1, self.bn3d_2_1)):
         x = F.relu(bn(conv(x)))
     x = self.conv3d_2_2(x)
     for stage in (self.head, self.middle_1, self.middle_2, self.middle_3,
                   self.middle_4):
         x = stage(x)
     x, _ = self.last(x)
     x = self.Fusion_last(self.fusion_head(x))
     x = self.compress(F.relu(self.bn3d_2_2(x)))
     for stage in (self.middle_6, self.middle_7, self.middle_8, self.middle_9,
                   self.middle_10, self.middle_11, self.middle_12):
         x = stage(x)
     x = F.relu(self.conv3d_3(F.relu(self.bn3d_3(x))))

     res = self.conv3d_r2(F.relu(self.conv3d_r1(x)))
     res = F.relu(F.pixel_shuffle(res.squeeze_(2), 2))
     res = self.conv3d_r3(F.relu(self.conv3d_r4(res.unsqueeze(2))))
     return HR + F.pixel_shuffle(res.squeeze_(2), 2)
예제 #9
0
    def forward(self, xs):
        """Fuse four feature levels and apply the three output heads.

        Levels 1, 3 and 4 go through ``dp1``/``dp3``/``dp4``; level 2 is used
        as is. Each level then goes arm -> conv -> PReLU and is pixel-shuffled
        by 2, 4, 8 and 16 respectively before ``self.fusenet``.

        Returns:
            dict with keys 'cls', 'cks' and 'dist'.
        """
        x1, x2, x3, x4 = xs
        levels = (
            (self.dp1(x1), self.arm1, self.conv1, self.prelu1, 2),
            (x2, self.arm2, self.conv2, self.prelu2, 4),
            (self.dp3(x3), self.arm3, self.conv3, self.prelu3, 8),
            (self.dp4(x4), self.arm4, self.conv4, self.prelu4, 16),
        )
        shuffled = [
            F.pixel_shuffle(prelu(conv(arm(feat))), upscale_factor=factor)
            for feat, arm, conv, prelu, factor in levels
        ]
        fusion = self.fusenet(*shuffled)
        return {'cls': self.conv_cls(fusion),
                'cks': self.conv_ck(fusion),
                'dist': self.dist_conv(fusion)}
예제 #10
0
    def forward(self, ims, tList):
        """Super-resolve ``ims[3]`` and synthesize frames between ``ims[3]`` and ``ims[4]``.

        Args:
            ims: list of input images, each (B, C, H, W). Indices 3 and 4 are
                the reference pair, so at least five images are expected.
            tList: target time indices between ims[3] and ims[4]
                (e.g. [1/4, 2/4, 3/4]).

        Returns:
            Tensor of shape (len(tList) + 1, B, C, H*sf, W*sf) on the inputs'
            device/dtype. Index 0 is the super-resolved ims[3]; index l + 1 is
            the frame synthesized for tList[l].
        """
        b, c, h, w = ims[0].size()
        # Allocate next to the inputs instead of forcing CUDA.
        outs = ims[0].new_zeros((len(tList) + 1, b, c, h * self.sf, w * self.sf))

        # Feature representation of each image, fused with EFST.
        enc_s = [self.encoder(im) for im in ims]
        enc_sf = self.efst(enc_s)

        # Bilinear upsampling of the reference pair, shared by the spatial
        # decoder output and the flow estimator.
        uI3 = F.interpolate(ims[3], scale_factor=self.sf, mode='bilinear',
                            align_corners=False)
        uI4 = F.interpolate(ims[4], scale_factor=self.sf, mode='bilinear',
                            align_corners=False)

        # Spatial decoder: pixel-shuffled residual on top of the upsampled ims[3].
        dec_feat, rimg = self.decoder(enc_s[3], enc_sf)
        outs[0] = uI3 + F.pixel_shuffle(rimg, self.sf)

        # Bidirectional flows between the upsampled frames 3 and 4.
        flow34 = self.pwcnet(uI3, uI4)
        flow43 = self.pwcnet(uI4, uI3)

        for slot, t in enumerate(tList, start=1):
            # Flows used below to warp ims[3] (flowt0) and ims[4] (flowt1),
            # built as quadratic-in-t combinations of flow34 and flow43.
            flowt0 = -t * (1 - t) * flow34 + t * t * flow43
            flowt1 = (1 - t) * (1 - t) * flow34 - t * (1 - t) * flow43

            # Feature interpolation network, applied entry by entry.
            featI = [self.fi(f3, f4, flowt0, flowt1)
                     for f3, f4 in zip(enc_s[3], enc_s[4])]

            # LR intermediate frame: mean of the two warped inputs.
            dwI = (warp(ims[3], flowt0) + warp(ims[4], flowt1)) / 2.

            # Spatio-temporal decoder residual on the upsampled warped frame.
            _, trimg = self.decoder(featI, dec_feat)
            wI = F.interpolate(dwI, scale_factor=self.sf, mode='bilinear',
                               align_corners=False)
            outs[slot] = wI + F.pixel_shuffle(trimg, self.sf)

        return outs
예제 #11
0
 def forward(self, x):
     """Return (edge, body) maps, each pixel-shuffled by ``self.scale_factor``.

     The backbone's first output (``cat_feats``) goes through ``self.msfe``.
     The body head is conv1 -> bn1 -> relu and the edge head is
     conv2 -> bn2 -> sigmoid.
     """
     cat_feats, _ = self.backbone(x)
     fused = self.msfe(cat_feats)
     body = self.relu(self.bn1(self.conv1(fused)))
     edge = self.sigmoid(self.bn2(self.conv2(fused)))

     def upscale(t):
         return F.pixel_shuffle(t, upscale_factor=self.scale_factor)

     return upscale(edge), upscale(body)
    def forward(self, input, h_1, h_2, h_3, h_4):
        """Decode through conv1, four recurrent stages, and tanh(conv2).

        The first element of each updated state is pixel-shuffled by 2 before
        it is fed to the next stage (and before conv2 after the last one).

        Returns:
            (tanh(conv2(.)), h_1, h_2, h_3, h_4) with the updated states.
        """
        states = [self.rnn1(self.conv1(input), h_1)]
        for rnn, hidden in ((self.rnn2, h_2), (self.rnn3, h_3), (self.rnn4, h_4)):
            states.append(rnn(F.pixel_shuffle(states[-1][0], 2), hidden))
        image = torch.tanh(self.conv2(F.pixel_shuffle(states[-1][0], 2)))
        return (image, *states)
예제 #13
0
    def forward(self, x):
        """Return (encoder, decoder) index maps, each pixel-shuffled by 2.

        The decoder map is ``sigmoid(indexnet(x))``. The encoder map is that
        same map additionally softmax-normalised over channels.
        """
        gate = torch.sigmoid(self.indexnet(x))
        normalised = F.softmax(gate, dim=1)
        idx_en, idx_de = (F.pixel_shuffle(m, 2) for m in (normalised, gate))
        return idx_en, idx_de
예제 #14
0
 def forward(self, input, iter):
     """Run conv1, then rnn1..rnn4 (each given ``iter``) with a 2x pixel shuffle after each, then conv2."""
     output = self.conv1(input)
     for rnn in (self.rnn1, self.rnn2, self.rnn3, self.rnn4):
         output = F.pixel_shuffle(rnn(output, iter), 2)
     return self.conv2(output)
예제 #15
0
 def forward(self, noise):
     """Map ``noise`` to an image via two FC layers and two pixel-shuffle conv stages.

     The FC output is viewed as (-1, 128, 7, 7). Each 2x pixel shuffle
     divides the channels by 4, and the last conv is followed by tanh.
     """
     hidden = F.leaky_relu(self.fc_1(noise), 0.2)
     hidden = F.leaky_relu(self.fc_2(hidden), 0.2)
     img = F.pixel_shuffle(hidden.view([-1, 128, 7, 7]), 2)
     img = F.leaky_relu(self.bn_3(self.conv_3(img)), 0.2)
     img = F.pixel_shuffle(img, 2)
     return F.tanh(self.conv_4(img))
예제 #16
0
    def forward(self, x):
        """Alternate a 2x pixel shuffle with conv1..conv4, finishing with a fifth shuffle."""
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = conv(F.pixel_shuffle(x, 2))
        return F.pixel_shuffle(x, 2)
예제 #17
0
    def forward(self, input, hidden1, hidden2, hidden3, hidden4, unet_output1,
                unet_output2, wdec):
        """Decode ``input`` through four recurrent stages with per-call weight offsets.

        ``wdec`` unpacks to (init_conv, rnn1_i, rnn1_h, ..., rnn4_i, rnn4_h,
        final_conv). ``init_conv`` and ``final_conv`` are added to
        ``self.conv1.weight`` and ``self.conv2.weight`` and used as stride-1,
        padding-0 conv kernels. Each ``rnnK_i``/``rnnK_h`` pair is passed to
        ``self.rnnK``. Between stages, the first element of the previous state
        is pixel-shuffled by 2. When ``self.v_compress`` is set, the matching
        entries of ``unet_output1``/``unet_output2`` are concatenated on
        channels before rnn2 (if fuse_level >= 3), rnn3 (if fuse_level >= 2)
        and rnn4 (always).

        Returns:
            (tanh(conv) / 2, hidden1, hidden2, hidden3, hidden4) with the
            updated recurrent states.
        """
        (init_conv, rnn1_i, rnn1_h, rnn2_i, rnn2_h, rnn3_i, rnn3_h, rnn4_i,
         rnn4_h, final_conv) = wdec

        x = F.conv2d(input, init_conv + self.conv1.weight, stride=1, padding=0)
        states = [self.rnn1(x, rnn1_i, rnn1_h, hidden1)]

        stages = ((self.rnn2, rnn2_i, rnn2_h, hidden2),
                  (self.rnn3, rnn3_i, rnn3_h, hidden3),
                  (self.rnn4, rnn4_i, rnn4_h, hidden4))
        fuse_flags = (self.v_compress and self.fuse_level >= 3,
                      self.v_compress and self.fuse_level >= 2,
                      self.v_compress)
        for level, ((rnn, w_i, w_h, state), fuse) in enumerate(zip(stages, fuse_flags)):
            x = F.pixel_shuffle(states[-1][0], 2)
            if fuse:
                x = torch.cat([x, unet_output1[level], unet_output2[level]], dim=1)
            states.append(rnn(x, w_i, w_h, state))

        x = F.pixel_shuffle(states[-1][0], 2)
        x = F.conv2d(x, final_conv + self.conv2.weight, stride=1, padding=0)
        return (x.tanh() / 2, *states)
예제 #18
0
 def forward(self, input, shape):
     """Apply conv + ``self.lrule`` to a clip and 2x pixel-shuffle every frame.

     ``shape`` is (N, C, L, H, W). The conv output is regrouped so each of
     the L frames is shuffled on its own, then restored to (N, C', L, 2H, 2W).
     """
     N, C, L, H, W = shape
     feats = self.lrule(self.conv(input))
     frames = feats.permute(0, 2, 1, 3, 4).reshape(N * L, -1, H, W)
     frames = F.pixel_shuffle(frames, 2)
     clip = frames.reshape(N, L, -1, H * 2, W * 2)
     return clip.permute(0, 2, 1, 3, 4)
예제 #19
0
 def forward(self, x):
     """Encoder-decoder pass over a stack of scaled copies of ``x``.

     ``x`` is concatenated on channels as (0.8x, x, 1.2x, 1.5x) and fed to
     ``inc``. Three encoder layers follow, then three decoder stages. Each
     decoder stage upsamples, concatenates the matching skip (layer2, layer1,
     inc) and decodes. The result goes through ``output`` and a 2x pixel
     shuffle.
     """
     stacked = torch.cat((0.8 * x, x, 1.2 * x, 1.5 * x), dim=1)
     inc = self.inc(stacked)

     skips = [inc]
     feat = inc
     for encode in (self.layer1, self.layer2, self.layer3):
         feat = encode(feat)
         skips.append(feat)

     # skips == [inc, layer1, layer2, layer3]; feat starts at layer3.
     decoder = ((self.up1, self.layer4, skips[2]),
                (self.up2, self.layer5, skips[1]),
                (self.up3, self.layer6, skips[0]))
     for up, decode, skip in decoder:
         feat = decode(torch.cat((up(feat), skip), dim=1))

     return F.pixel_shuffle(self.output(feat), 2)
예제 #20
0
파일: DUF_arch.py 프로젝트: JMU2021/DESRGAN
    def forward(self, x):
        """Dynamic-upsampling-filter forward pass.

        ``x`` is (B, T, C, H, W) and is permuted to (B, C, T, H, W) for the
        3-D convs. The centre frame ``T // 2`` is filtered by
        ``self.dynamic_filter`` with per-pixel filters of shape
        (B, 25, scale**2, H, W), softmax-normalised over the 25 taps. The
        residual from ``conv3d_r2`` (axis 2 squeezed away) is added, and the
        sum is pixel-shuffled by ``self.scale``. When ``self.adapt_official``
        is set, the residual is first passed to ``adapt_official``.
        """
        B, T, C, H, W = x.size()
        volume = x.permute(0, 2, 1, 3, 4)  # (B, C, T, H, W) for Conv3D
        center = volume[:, :, T // 2, :, :]

        feat = self.dense_block_2(self.dense_block_1(self.conv3d_1(volume)))
        feat = F.relu(self.conv3d_2(F.relu(self.bn3d_2(feat), inplace=True)),
                      inplace=True)

        residual = self.conv3d_r2(F.relu(self.conv3d_r1(feat), inplace=True))
        filters = self.conv3d_f2(F.relu(self.conv3d_f1(feat), inplace=True))
        filters = F.softmax(filters.view(B, 25, self.scale**2, H, W), dim=1)

        if self.adapt_official:
            adapt_official(residual, scale=self.scale)

        out = self.dynamic_filter(center, filters)
        out += residual.squeeze_(2)
        return F.pixel_shuffle(out, self.scale)
예제 #21
0
  def forward(self, x):
    """Predict a per-channel min-max normalised map for ``x``.

    Runs prepLayer and layer1..layer3, then ``self.upsample`` and a 2x pixel
    shuffle, then conv1/bn1/ReLU, conv2/bn2/ReLU and conv3. Each
    (sample, channel) map of the conv3 output is rescaled to [0, 1] using its
    own spatial min and max.

    Returns:
      Tensor with the shape of the conv3 output. A map whose values are all
      equal becomes all zeros instead of NaN (0 / 0).
    """
    out = self.layer3(self.layer2(self.layer1(self.prepLayer(x))))
    out = F.pixel_shuffle(self.upsample(out), 2)
    # Soft [0, 1] output rather than a hard mask; binary logic may be
    # restored later.
    out = F.relu(self.bn1(self.conv1(out)))
    out = F.relu(self.bn2(self.conv2(out)))
    out = self.conv3(out)

    # Min-max scale each (sample, channel) map over its spatial positions.
    flat = out.reshape(out.size(0), out.size(1), -1)
    flat = flat - flat.min(2, keepdim=True)[0]
    span = flat.max(2, keepdim=True)[0]
    # A zero span means the shifted map is all zeros: divide by 1, not 0.
    span = span.masked_fill(span == 0, 1)
    return (flat / span).reshape(out.size())
    def forward(self, x):
        """U-Net pass followed by a 2x pixel shuffle.

        Encoder: conv1..conv4, each followed by a 2x2 max-pool, then the
        conv5 bottleneck. Decoder: for up6..up9, upsample, concatenate the
        matching encoder output (conv4, conv3, conv2, conv1) on channels and
        apply conv6..conv9. conv10 is the head.
        """
        skips = []
        feat = x
        for encode in (self.conv1, self.conv2, self.conv3, self.conv4):
            feat = encode(feat)
            skips.append(feat)
            feat = F.max_pool2d(feat, kernel_size=2)
        feat = self.conv5(feat)

        decoder = ((self.up6, self.conv6), (self.up7, self.conv7),
                   (self.up8, self.conv8), (self.up9, self.conv9))
        for (up, decode), skip in zip(decoder, reversed(skips)):
            feat = decode(torch.cat([up(feat), skip], 1))

        return F.pixel_shuffle(self.conv10(feat), 2)
 def forward(self, sample):
   """Decode ``sample`` into pairwise distance, rotation and direction maps.

   The ``self.preprocess`` output is viewed as (B, 512, 4, 4). Before each
   block, two channels are appended: the symmetric and antisymmetric half-sums
   of a ramp ``arange(S) / 100``. Each block's output is pixel-shuffled by 2.
   Heads:
   - distances: softplus, averaged with their transpose, zero diagonal.
   - rotation, direction: mixed with their transpose through a strictly
     lower-triangular mask, then sine-normalised over dim 1.

   Returns:
     1-tuple holding cat((distances, rotation, direction), dim=1).
   """
   out = self.preprocess(sample).view(-1, 512, 4, 4)
   for block in self.blocks:
     side = out.size(-1)
     ramp = torch.arange(side, dtype=out.dtype, device=out.device) / 100
     ramp = ramp[None].expand(out.size(0), side)
     row, col = ramp[:, None, :], ramp[:, :, None]
     extra = ((row + col) / 2)[:, None], ((row - col) / 2)[:, None]
     out = func.pixel_shuffle(block(torch.cat((out, *extra), dim=1)), 2)

   size = out.size(-1)
   idx = torch.arange(size, device=out.device)
   # lower[..., i, j] == 1 where i > j, else 0
   lower = ((idx[:, None] - idx[None, :]) > 0).float()[None, None]

   def unit_sin(t):
     s = t.sin()
     return s / (s.norm(dim=1, keepdim=True).detach() + 1e-6)

   distances = func.softplus(self.distances(out))
   distances = (distances + distances.permute(0, 1, 3, 2)) / 2
   rot = self.rotation(out)
   rotation = unit_sin(lower * rot + (1 - lower) * rot.permute(0, 1, 3, 2))
   dirs = self.direction(out)
   direction = unit_sin(lower * dirs.permute(0, 1, 3, 2) + (1 - lower) * dirs)
   distances[:, :, idx, idx] = 0
   return (torch.cat((distances, rotation, direction), dim=1),)
예제 #24
0
 def forward(self, x):
     """Optionally preprocess, apply ``self.trns``, add two singleton axes at dim 2, pixel-shuffle by 2."""
     if self.preprocess:
         x = self.preprocess(x)
     lifted = self.trns(x).unsqueeze(2).unsqueeze(2)
     return F.pixel_shuffle(lifted, 2)
예제 #25
0
def upsample(img, scale, border='reflect'):
    """Bicubic upsampling via a single ``conv2d`` + ``pixel_shuffle`` (PIL's kernel).

    Args:
      img: a torch tensor (2/3/4-D); ``_push_shape_4d`` / ``_pop_shape``
        convert it to 4-D and back.
      scale: upsampling factor, passed to ``weights_upsample``.
      border: ``F.pad`` mode used for the border, e.g. 'reflect'
        (recommended), 'replicate' or 'constant'.

    Returns:
      The upsampled tensor, restored to the input's rank. If
      ``weights_upsample`` reports a factor of 1, ``img`` is returned unchanged.
    """
    device, dtype = img.device, img.dtype
    kernels, s = weights_upsample(scale)
    if s == 1:
        return img  # bypass
    kernels = [torch.from_numpy(k.astype('float32')) for k in kernels]
    p1 = 1 + s // 2
    p2 = 3
    img, shape = _push_shape_4d(img)
    img_ex = F.pad(img, [p1, p2, p1, p2], mode=border)
    c = int(img_ex.shape[1])
    # One 5x5 kernel per entry of ``kernels``, applied channel-wise: identity
    # across channels scaled by the kernel. The stack is reordered so each
    # channel's outputs are adjacent, which is the grouping pixel_shuffle reads.
    eye = torch.reshape(torch.eye(c, c), [c, c, 1, 1])
    filters = [eye * k for k in kernels]
    weights = torch.stack(filters, dim=0).transpose(0, 1).reshape([-1, c, 5, 5])
    # Match the image's dtype as well as its device so conv2d accepts
    # non-float32 input.
    img_s = F.conv2d(img_ex, weights.to(device=device, dtype=dtype))
    img_s = F.pixel_shuffle(img_s, s)
    more = s // 2 * s
    crop = slice(more - s // 2, -(s // 2))
    return _pop_shape(img_s[..., crop, crop], shape)
예제 #26
0
    def forward(self, x):
        """Dynamic-upsampling-filter forward pass.

        Args:
            x (Tensor): input of shape (b, num_imgs, c, h, w); frame
                ``num_imgs // 2`` is the one that gets filtered.

        Returns:
            Tensor: ``dynamic_filter(center, filters)`` plus the squeezed
            residual, pixel-shuffled by ``self.scale``, so of shape
            (b, K / scale**2, h * scale, w * scale) for a K-channel filter
            output.
        """
        num_batches, num_imgs, _, h, w = x.size()

        volume = x.permute(0, 2, 1, 3, 4)  # (b, c, num_imgs, h, w) for Conv3D
        center = volume[:, :, num_imgs // 2, :, :]

        feat = self.dense_block2(self.dense_block1(self.conv3d1(volume)))
        feat = F.relu(self.bn3d2(feat), inplace=True)
        feat = F.relu(self.conv3d2(feat), inplace=True)

        res = self.conv3d_r2(F.relu(self.conv3d_r1(feat), inplace=True))

        # Per-pixel filters: (b, 25, scale**2, h, w), softmax over the 25 taps.
        kernels = self.conv3d_f2(F.relu(self.conv3d_f1(feat), inplace=True))
        kernels = F.softmax(
            kernels.view(num_batches, 25, self.scale**2, h, w), dim=1)

        out = self.dynamic_filter(center, kernels)
        out += res.squeeze_(2)
        return F.pixel_shuffle(out, self.scale)
예제 #27
0
    def sample(self, x):
        """Generate a new datapoint by unrolling the decoder from the prior.

        For ``self.T`` steps, z is drawn from Normal(mu, exp(log_std)) given
        by ``self.prior(h_dec)``. The decoder consumes z concatenated with
        ``read_head(canvas)``, and ``write_head(h_dec)``, pixel-shuffled by 2,
        is added to the canvas.

        :param x: tensor representing shape of sample; only its batch size and
            dtype/device (via ``new_zeros``) are used.
        :return: canvas of shape (batch, self.x_dim, h, w) with
            (h, w) = self.x_shape.
        """
        h, w = self.x_shape
        batch_size = x.size(0)

        state_shape = (batch_size, self.h_dim, h // 2, w // 2)
        h_dec = x.new_zeros(state_shape)
        c_dec = x.new_zeros(state_shape)
        canvas = x.new_zeros((batch_size, self.x_dim, h, w))

        for _ in range(self.T):
            mu, log_std = torch.split(self.prior(h_dec), self.z_dim, dim=1)
            z = Normal(mu, torch.exp(log_std)).sample()
            decoder_in = torch.cat([z, self.read_head(canvas)], dim=1)
            h_dec, c_dec = self.decoder(decoder_in, [h_dec, c_dec])
            canvas = canvas + F.pixel_shuffle(self.write_head(h_dec), 2)

        return canvas
예제 #28
0
    def forward(self, input, tmp_FLAG=False):
        """4x super-resolution refined by three conv-LSTM loop iterations.

        Each iteration combines the features of ``input`` with the previous
        recurrent output ``y``. It updates the 32-channel LSTM-style state,
        upsamples ``y`` by two 2x pixel shuffles, and adds the result to a
        4x interpolation of ``input`` (mode ``self.interpolate``).

        Args:
            input: (batch, C, row, col) tensor.
            tmp_FLAG: if True, also return the outputs of the earlier
                iterations.

        Returns:
            ``output`` from the last iteration or, when ``tmp_FLAG`` is True,
            ``(output, intermediates)`` where ``intermediates`` holds the
            outputs of all but the last iteration.
        """
        batch_size, row, col = input.size(0), input.size(2), input.size(3)
        # Recurrent output y and state c live on the input's device and dtype
        # instead of being forced onto CUDA.
        y = input.new_zeros(batch_size, 32, row, col)
        c = input.new_zeros(batch_size, 32, row, col)

        x = self.resbk1(self.relu1(self.conv1(input)))
        x = self.resbk2(self.relu2(self.conv2(x)))

        # Loop-invariant interpolated base the predicted detail is added to.
        base = F.interpolate(input, scale_factor=4, mode=self.interpolate)

        tmp_list = []
        for _ in range(3):
            feat = self.relu_concate1(self.conv_concate1(torch.cat([x, y], dim=1)))
            feat = self.relu_concate2(self.conv_concate2(self.resbk_concate(feat)))
            i = self.conv_i(feat)
            f = self.conv_f(feat)
            g = self.conv_g(feat)
            o = self.conv_o(feat)
            c = f * c + i * g  # c: recurrent cell state
            h = o * torch.tanh(c)  # h: LSTM output
            y = self.mrc(h) + h

            # Two 2x pixel-shuffle stages with conv/relu/resblock in between.
            up = F.pixel_shuffle(self.ps1_conv(y), upscale_factor=2)
            up = self.relu_inter_ps1(self.conv_inter_ps1(up))
            up = self.relu_inter_ps2(self.conv_inter_ps2(self.resbk_inter_ps(up)))
            y_upsampled = F.pixel_shuffle(self.ps2_conv(up), upscale_factor=2)

            sr = self.relu_sr2(self.conv_sr2(self.relu_sr1(self.conv_sr1(y_upsampled))))
            output = self.conv_final(sr) + base
            if tmp_FLAG:
                tmp_list.append(output)

        if tmp_FLAG:
            return output, tmp_list[:-1]
        return output
예제 #29
0
 def forward(self, x):
     """Apply ``self.conv_layers`` with same padding, then a 4x pixel shuffle.

     Every layer except the last is followed by ReLU; the last by tanh. (A
     leftover debug ``print`` of the input size on every call was removed.)
     """
     last = len(self.conv_layers) - 1
     for layer_idx, conv in enumerate(self.conv_layers):
         x = same_padding_conv(x, conv)
         x = x.tanh() if layer_idx == last else F.relu(x)
     return F.pixel_shuffle(x, 4)
 def forward(self, x):
     """Compress channels, encode context, pixel-shuffle by ``enlarge_rate``, then normalise."""
     encoded = self.context_encoder(self.channel_compressor(x))
     return self.kernel_normalizer(F.pixel_shuffle(encoded, self.enlarge_rate))
    def forward(self, input, hidden1, hidden2, hidden3, hidden4):
        """Decode through conv1, rnn1..rnn4 and tanh(conv2) / 2.

        After each recurrent stage, the first element of its new state is
        pixel-shuffled by 2 and fed forward.

        Returns:
            (image, hidden1, hidden2, hidden3, hidden4) with updated states.
        """
        x = self.conv1(input)
        updated = []
        for rnn, state in ((self.rnn1, hidden1), (self.rnn2, hidden2),
                           (self.rnn3, hidden3), (self.rnn4, hidden4)):
            state = rnn(x, state)
            updated.append(state)
            x = F.pixel_shuffle(state[0], 2)
        image = self.conv2(x).tanh() / 2
        return (image, *updated)