def patch2im(x_patch, patchsize, stride, padding):
    """Fold a patch tensor back into an image and crop the padding border.

    The patches are folded with ``pyinn.col2im`` using a square
    ``patchsize`` kernel, a square ``stride`` and zero col2im padding.
    A tensor of ones with the same layout is folded the same way and the
    folded patches are divided by it.
    NOTE(review): presumably ``col2im`` accumulates overlapping patches, so
    this division yields a per-pixel average -- confirm against pyinn.

    :param x_patch: patch tensor accepted by ``pyinn.col2im``.
    :param patchsize: side length used for both kernel dimensions.
    :param stride: step used for both stride dimensions.
    :param padding: 4-sequence ``(top, bottom, left, right)`` of border
        widths removed from dims 2 and 3 of the folded result.
    :return: the normalised, cropped 4-D tensor.
    """
    pad_top, pad_bottom, pad_left, pad_right = padding

    def _fold(patches):
        # Same geometry for the data and for the normalisation map.
        return pyinn.col2im(patches, [patchsize, patchsize], [stride, stride], [0, 0])

    coverage = _fold(torch.ones_like(x_patch))
    folded = _fold(x_patch.contiguous())
    normalised = folded / coverage

    row_end = normalised.shape[2] - pad_bottom
    col_end = normalised.shape[3] - pad_right
    return normalised[:, :, pad_top:row_end, pad_left:col_end]
def test_im2col_batch(self):
    """Round-trip a batched CUDA tensor through im2col/col2im with a 1x1 kernel.

    With kernel size 1, stride (1, 1) and no padding the test expects
    ``col2im(im2col(src))`` to reproduce ``src`` exactly, i.e. the largest
    absolute element-wise difference must equal 0.
    """
    kernel, stride, padding = 1, (1, 1), 0
    original = Variable(torch.randn(4, 8, 7, 7).cuda())

    columns = P.im2col(original, kernel, stride, padding)
    restored = P.col2im(columns, kernel, stride, padding)

    worst_error = (original - restored).data.abs().max()
    self.assertEqual(worst_error, 0)
def forward(self, x):
    """Modulate convolved features by a patch-wise map computed from the input.

    ``self.block`` and ``self.conv1`` are both applied to the incoming
    ``x``. Each result is expanded with ``P.im2col`` (kernel 5, stride 1,
    padding 0), the two expansions are multiplied element-wise, and the
    product is folded back with ``P.col2im`` using the same parameters.

    :param x: input tensor fed to both ``self.block`` and ``self.conv1``.
    :return: the folded product tensor.
    """
    kernel, stride, padding = 5, 1, 0

    # Both branches read the untouched input; keep block() ahead of conv1().
    flow = self.block(x)
    features = self.conv1(x)

    feature_cols = P.im2col(features, kernel, stride, padding)
    flow_cols = P.im2col(flow, kernel, stride, padding)

    weighted_cols = feature_cols * flow_cols
    return P.col2im(weighted_cols, kernel, stride, padding)
def main():
    """Visual demo: cut a test image into overlapping patches and fold it back.

    Loads ``scipy.misc.face(True)``, moves it to the GPU as a float
    ``Variable`` with a leading singleton dimension, splits it with
    ``P.im2col`` using a 150x150 window and a stride of ``window - overlap``,
    shows the patches as a grid, then folds them back with ``P.col2im`` and
    shows the result.  Requires CUDA and an interactive matplotlib backend.

    NOTE(review): the positional ``True`` is presumably the ``gray`` flag of
    ``scipy.misc.face`` -- confirm against the installed SciPy version.
    """
    window = np.array([150, 150])
    overlap = np.array([75, 75])
    stride = window - overlap

    lena = scipy.misc.face(True)
    # Single-argument parenthesised form: prints the same under Python 2 and
    # is valid syntax under Python 3 (the bare print statement is not).
    print(lena.shape)
    lena = torch.from_numpy(lena).unsqueeze(0)
    lena = Variable(lena.float()).cuda()

    # NOTE: an earlier, commented-out experiment that split the image with
    # Tensor.unfold and re-placed patches via F.conv_transpose2d used to live
    # here; it was dead code and has been dropped in favour of pyinn below.

    ##################################################
    # pyinn
    #-------------------------------------------------
    # Author's original shape note (unverified):
    #   (768 x 1024) -> (1 x 150 x 150 x 7 x 9)
    lena = P.im2col(lena, window, stride, [0, 0])
    s = lena.data.size()

    # Take the patch size from `window` instead of repeating the literal 150.
    patch_h, patch_w = int(window[0]), int(window[1])
    patches = (lena.squeeze()
               .transpose(0, 2)
               .transpose(1, 3)
               .contiguous()
               .view(s[-1] * s[-2], 1, patch_h, patch_w)
               .data)
    plot = torchvision.utils.make_grid(patches, nrow=5, padding=10, normalize=True)
    plt.imshow(plot.cpu().numpy().transpose(1, 2, 0))
    plt.show()

    # Author's original shape note (unverified): (1 x 750 x 950)
    lena = P.col2im(lena, window, stride, [0, 0])
    plt.imshow(lena.cpu().data.squeeze().numpy())
    plt.show()
def forward(self, x1, x2):
    """
    Process the pair (x1, x2) patch by patch and return ``x2`` minus the
    reassembled network output.

    x2 should have better resolution

    Pipeline, as written below: concatenate the inputs along dim 2, cut the
    result into windows with ``Tensor.unfold``, run every window through
    conv1 -> fc1 -> conv2 -> fc2 -> deconv2 -> fc3 -> deconv1, reduce each
    window with its own ``nn.Linear(windows[0], 1)`` (created lazily and kept
    in ``self.linearModules``), fold the windows back with ``col2im``, pass
    the result through ``self.linear1`` and subtract it from ``x2``.

    :param x1: CUDA tensor, concatenated with ``x2`` along dim 2.
    :param x2: CUDA tensor; the returned tensor has its shape (``view_as``).
    :return: ``x2 - outX2`` where ``outX2`` is the reassembled prediction.
    :raises AssertionError: if either input is not on the GPU.
    """
    assert x1.is_cuda and x2.is_cuda, "Inputs are not in GPU!"

    debugplot = False
    windows = self.windows
    overlap = self.overlap
    x = torch.cat([x1, x2], 2).cuda()

    # Unfold
    #-----------
    # NOTE(review): unfold steps by `overlap[k]` here, while col2im further
    # down uses a stride of `windows[1:] - overlap[1:]`. The two agree only
    # when overlap == window / 2 -- confirm which one is intended.
    x.data = x.data.unfold(3, windows[1], overlap[1]).unfold(4, windows[2], overlap[2]).squeeze()
    # assumes s == (2, patch_rows, patch_cols, win1, win2) after the squeeze
    # -- TODO confirm against the shapes callers actually pass.
    s = x.data.size()
    x = x.contiguous().view(2, s[1]*s[2], windows[1], windows[2])
    x = x.transpose(0, 1)
    outX = x.contiguous().view(s[1]*s[2], 1, 2, windows[1], windows[2])

    if (debugplot):
        imglist = [outX[i,:,0].squeeze().data.unsqueeze(0) for i in range(outX.data.size()[0])]
        plot = torchvision.utils.make_grid(imglist, nrow=25, normalize=True)
        plt.ioff()
        fig = plt.figure(2)
        fig.clear()
        ax = fig.add_subplot(111)
        ax.cla()
        ax.imshow(plot.cpu().numpy().transpose(1, 2, 0))
        plt.show()

    x = self.conv1(outX)
    x = self.fc1(x)
    x = self.conv2(torch.squeeze(x))
    x = self.fc2(x)
    x = self.deconv2(x)
    x = self.fc3(x)
    s2 = x.data.size()
    x = self.deconv1(x.view(s2[0], s2[1], 1, s2[2], s2[3]))

    if (debugplot):
        imglist = [(x[i, 0, 0] + x[i, 0, 1]).view(1, windows[1], windows[2]).data for i in range(x.data.size()[0])]
        plot = torchvision.utils.make_grid(imglist, nrow=25, normalize=True)
        plt.ioff()
        fig = plt.figure(2)
        fig.clear()
        ax = fig.add_subplot(111)
        ax.cla()
        ax.imshow(plot.cpu().numpy().transpose(1, 2, 0))
        plt.show()

    V = None
    for i in range(s[1]):
        l_v = None
        for j in range(s[2]):
            # The view(2, s[1]*s[2], ...) above flattened the patch grid
            # row-major, so patch (i, j) sits at i*s[2] + j. The previous
            # i*s[1] + j picked wrong patches whenever s[1] != s[2].
            counter = i*s[2] + j
            try:
                lin = self.linearModules[counter]
            except (KeyError, IndexError):
                # Missing entry: which exception is raised depends on the
                # container type/version, so accept both. `counter` grows by
                # one per iteration, so append() lands at index `counter`.
                lin = nn.Linear(windows[0], 1)
                lin.weight.requires_grad = True
                lin.cuda()
                self.linearModules.append(lin)

            l_x = x[counter]
            l_s = l_x.data.size()
            # Swap the third-from-last and last dims so that the final view
            # has `windows[0]` columns, matching the Linear's input width.
            l_x = l_x.transpose(len(l_s) - 3, len(l_s) - 1).contiguous().view(windows[1]*windows[2], windows[0])
            l_x = lin(l_x)
            l_x = l_x.view(windows[1], windows[2]).contiguous()
            l_x = l_x.unsqueeze(0).unsqueeze(0)
            if (l_v is None):
                l_v = l_x
            else:
                l_v = torch.cat([l_v, l_x], 0)
        if (V is None):
            V = l_v
        else:
            V = torch.cat([V, l_v], 1)

    V = V.transpose(0, 3).transpose(1, 2).unsqueeze(0) # (1 x win1 x win2 x p1 x p2)

    if (debugplot):
        plot = torchvision.utils.make_grid(V.squeeze()
                                           .transpose(0, 2).transpose(1, 3)
                                           .contiguous()
                                           .view(s[1]*s[2], 1, windows[1], windows[2]).data
                                           , nrow = 25, padding=10, normalize=False)
        fig = plt.figure(2)
        fig.clear()
        ax = fig.add_subplot(111)
        ax.cla()
        ax.imshow(plot[0].cpu().numpy(), vmin=-10, vmax=10, cmap='Greys_r')
        plt.show()

    outX2 = col2im(V.contiguous()  # remember to contiguous() here
                   , windows[1:], windows[1:] - overlap[1:], [0,0])

    if (debugplot):
        fig = plt.figure(2)
        ax = fig.add_subplot(111)
        ax.imshow(outX2.cpu().data.numpy()[0], cmap="Greys_r")
        plt.ioff()
        plt.show()

    x2s = outX2.data.size()
    # Flatten to a single column for linear1, then restore x2's shape.
    outX2 = self.linear1(outX2.view(np.prod(x2s), 1))
    outX2 = outX2.view_as(x2)

    return x2 - outX2