def forward(self, input):
    res = input
    # theta
    match_base = self.conv_match_L_base(input)
    shape_base = list(res.size())
    input_groups = torch.split(match_base, 1, dim=0)
    # patch size for matching
    kernel = self.ksize
    # raw_w is for reconstruction
    raw_w = []
    # w is for matching
    w = []
    # build feature pyramid
    for i in range(len(self.scale)):
        ref = input
        if self.scale[i] != 1:
            ref = F.interpolate(input, scale_factor=self.scale[i], mode='bicubic')
        # feature transformation function f
        base = self.conv_assembly(ref)
        shape_input = base.shape
        # sampling
        raw_w_i = extract_image_patches(base,
                                        ksizes=[kernel, kernel],
                                        strides=[self.stride, self.stride],
                                        rates=[1, 1],
                                        padding='same')  # [N, C*k*k, L]
        raw_w_i = raw_w_i.view(shape_input[0], shape_input[1], kernel, kernel, -1)
        raw_w_i = raw_w_i.permute(0, 4, 1, 2, 3)  # raw_shape: [N, L, C, k, k]
        raw_w_i_groups = torch.split(raw_w_i, 1, dim=0)
        raw_w.append(raw_w_i_groups)

        # feature transformation function g
        ref_i = self.conv_match(ref)
        shape_ref = ref_i.shape
        # sampling
        w_i = extract_image_patches(ref_i,
                                    ksizes=[self.ksize, self.ksize],
                                    strides=[self.stride, self.stride],
                                    rates=[1, 1],
                                    padding='same')
        w_i = w_i.view(shape_ref[0], shape_ref[1], self.ksize, self.ksize, -1)
        w_i = w_i.permute(0, 4, 1, 2, 3)  # w shape: [N, L, C, k, k]
        w_i_groups = torch.split(w_i, 1, dim=0)
        w.append(w_i_groups)

    y = []
    for idx, xi in enumerate(input_groups):
        # group in a filter
        wi = torch.cat([w[i][idx][0] for i in range(len(self.scale))], dim=0)  # [L, C, k, k]
        # normalize
        max_wi = torch.max(
            torch.sqrt(reduce_sum(torch.pow(wi, 2), axis=[1, 2, 3], keepdim=True)),
            self.escape_NaN)
        wi_normed = wi / max_wi
        # matching
        xi = same_padding(xi, [self.ksize, self.ksize], [1, 1], [1, 1])  # xi: 1*c*H*W
        yi = F.conv2d(xi, wi_normed, stride=1)  # [1, L, H, W]; L = shape_ref[2]*shape_ref[3]
        yi = yi.view(1, wi.shape[0], shape_base[2], shape_base[3])  # (B=1, C=32*32, H=32, W=32)
        # softmax matching score
        yi = F.softmax(yi * self.softmax_scale, dim=1)

        if self.average == False:
            yi = (yi == yi.max(dim=1, keepdim=True)[0]).float()

        # deconv for patch pasting
        raw_wi = torch.cat([raw_w[i][idx][0] for i in range(len(self.scale))], dim=0)
        yi = F.conv_transpose2d(yi, raw_wi, stride=self.stride, padding=1) / 4.
        y.append(yi)

    y = torch.cat(y, dim=0) + res * self.res_scale  # back to the mini-batch
    return y
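# --- Illustrative sketch (not part of the module above) ---
# A minimal, self-contained version of the single-scale matching-and-pasting step
# used inside the pyramid loop above, assuming extract_image_patches behaves like
# F.unfold with 'same' padding. The helper name, the 3x3 kernel, stride 1 and the
# tensor sizes are illustrative assumptions, not the module's actual configuration.
import torch
import torch.nn.functional as F

def match_and_paste(feat, ksize=3, softmax_scale=10.):
    # feat: [1, C, H, W], used both as the query map and as the patch bank
    n, c, h, w = feat.shape
    # extract k*k patches -> [1, C*k*k, L] -> conv filters of shape [L, C, k, k]
    patches = F.unfold(feat, kernel_size=ksize, padding=ksize // 2)
    patches = patches.view(c, ksize, ksize, -1).permute(3, 0, 1, 2)
    # L2-normalize each patch so conv2d yields a cosine-similarity matching score
    norm = torch.sqrt((patches ** 2).sum(dim=[1, 2, 3], keepdim=True)).clamp_min(1e-4)
    score = F.conv2d(feat, patches / norm, stride=1, padding=ksize // 2)  # [1, L, H, W]
    attn = F.softmax(score * softmax_scale, dim=1)
    # paste the un-normalized patches back, weighted by attention; dividing by
    # ksize**2 compensates for overlapping patches (the analogue of the /4. above)
    return F.conv_transpose2d(attn, patches, stride=1, padding=ksize // 2) / (ksize ** 2)

# usage: output keeps the input shape, each pixel is an attention-weighted patch paste
# print(match_and_paste(torch.randn(1, 8, 16, 16)).shape)  # torch.Size([1, 8, 16, 16])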
def forward(self, f, b, mask=None):
    """Contextual attention layer implementation.

    Contextual attention was first introduced in the publication:
        Generative Image Inpainting with Contextual Attention, Yu et al.

    Args:
        f: Input feature to match (foreground).
        b: Input feature to match against (background).
        mask: Input mask for b, indicating patches that are not available.
        ksize: Kernel size for contextual attention.
        stride: Stride for extracting patches from b.
        rate: Dilation for matching.
        softmax_scale: Scaled softmax for attention.
    Returns:
        torch.Tensor: output
    """
    # get shapes
    raw_int_fs = list(f.size())  # b*c*h*w
    raw_int_bs = list(b.size())  # b*c*h*w

    # extract patches from background with stride and rate
    kernel = 2 * self.rate
    # raw_w is extracted for reconstruction
    raw_w = extract_image_patches(b, ksizes=[kernel, kernel],
                                  strides=[self.rate, self.rate])  # b*hw*c*k*k
    raw_w_groups = torch.split(raw_w, 1, dim=0)

    # downscaling foreground option: downscale both foreground and background
    # for matching, and use the original background for reconstruction.
    f = F.interpolate(f, scale_factor=1 / self.rate, mode='nearest')
    b = F.interpolate(b, scale_factor=1 / self.rate, mode='nearest')
    int_fs = list(f.size())  # b*c*h*w
    int_bs = list(b.size())
    f_groups = torch.split(f, 1, dim=0)  # split tensors along the batch dimension
    w = extract_image_patches(b, ksizes=[self.ksize, self.ksize],
                              strides=[self.stride, self.stride])  # b*hw*c*k*k
    w_groups = torch.split(w, 1, dim=0)

    # process mask
    if mask is None:
        mask = torch.zeros([int_bs[0], 1, int_bs[2], int_bs[3]])
        if self.use_cuda:
            mask = mask.cuda()
    else:
        mask = F.interpolate(mask, scale_factor=1. / (4. * self.rate), mode='nearest')
    m_groups = extract_image_patches(mask, ksizes=[self.ksize, self.ksize],
                                     strides=[self.stride, self.stride])  # b*hw*c*k*k
    # m = m[0]  # hw*c*k*k
    # m = reduce_mean(m, axis=[1, 2, 3])  # hw*1*1*1
    # m = m.permute(1, 0, 2, 3).contiguous()  # 1*hw*1*1
    # mm = (m == 0).to(torch.float32)  # 1*hw*1*1

    y = []
    offsets = []
    k = self.fuse_k
    scale = self.softmax_scale * 255  # to fit the PyTorch tensor image value range
    fuse_weight = torch.eye(k).view(1, 1, k, k)  # 1*1*k*k
    if self.use_cuda:
        fuse_weight = fuse_weight.cuda()

    for xi, wi, raw_wi, mi in zip(f_groups, w_groups, raw_w_groups, m_groups):
        '''
        O => output channel as a conv filter
        I => input channel as a conv filter
        xi : separated tensor along batch dimension of front;
             (B=1, C=128, H=32, W=32)
        wi : separated patch tensor along batch dimension of back;
             (B=1, O=32*32, I=128, KH=3, KW=3)
        raw_wi : separated tensor along batch dimension of back;
             (B=1, I=32*32, O=128, KH=4, KW=4)
        '''
        # conv for compare
        escape_NaN = torch.FloatTensor([1e-4])
        if self.use_cuda:
            escape_NaN = escape_NaN.cuda()
        wi = wi[0]  # hw*c*k*k
        wi_normed = wi / torch.max(
            torch.sqrt(reduce_sum(torch.pow(wi, 2), axis=[1, 2, 3])),
            escape_NaN)
        xi_normed = same_padding(xi, [self.ksize, self.ksize], [1, 1])  # xi: 1*c*H*W
        yi = F.conv2d(xi_normed, wi_normed, stride=1)  # 1*hw*H*W

        # conv implementation for fusing scores to encourage large patches
        if self.fuse:
            # make all of depth to spatial resolution
            yi = yi.view(1, 1, int_bs[2] * int_bs[3], int_fs[2] * int_fs[3])  # (B=1, I=1, H=32*32, W=32*32)
            yi = same_padding(yi, [k, k], [1, 1])
            yi = F.conv2d(yi, fuse_weight, stride=1)  # (B=1, C=1, H=32*32, W=32*32)
            yi = yi.contiguous().view(1, int_bs[2], int_bs[3], int_fs[2], int_fs[3])  # (B=1, 32, 32, 32, 32)
            yi = yi.permute(0, 2, 1, 4, 3)
            yi = yi.contiguous().view(1, 1, int_bs[2] * int_bs[3], int_fs[2] * int_fs[3])
            yi = same_padding(yi, [k, k], [1, 1])
            yi = F.conv2d(yi, fuse_weight, stride=1)
            yi = yi.contiguous().view(1, int_bs[3], int_bs[2], int_fs[3], int_fs[2])
            yi = yi.permute(0, 2, 1, 4, 3)
        yi = yi.contiguous().view(1, int_bs[2] * int_bs[3], int_fs[2], int_fs[3])  # (B=1, C=32*32, H=32, W=32)

        # mi: hw*c*k*k
        mi = reduce_mean(mi, axis=[1, 2, 3])  # hw*1*1*1
        mi = mi.permute(1, 0, 2, 3).contiguous()  # 1*hw*1*1
        mm = (mi == 0).to(torch.float32)  # 1*hw*1*1

        # softmax to match
        yi = yi * mm
        yi = F.softmax(yi * scale, dim=1)
        yi = yi * mm  # 1*hw*H*W

        offset = torch.argmax(yi, dim=1, keepdim=True)  # 1*1*H*W
        if int_bs != int_fs:
            # normalize the offset value to match the foreground dimension
            times = float(int_fs[2] * int_fs[3]) / float(int_bs[2] * int_bs[3])
            offset = ((offset + 1).float() * times - 1).to(torch.int64)
        offset = torch.cat([offset // int_fs[3], offset % int_fs[3]], dim=1)  # 1*2*H*W

        # deconv for patch pasting
        wi_center = raw_wi[0]
        yi = F.conv_transpose2d(yi, wi_center, stride=self.rate, padding=1) / 4.  # (B=1, C=128, H=64, W=64)
        y.append(yi)
        offsets.append(offset)

    y = torch.cat(y, dim=0)  # back to the mini-batch
    y.contiguous().view(raw_int_fs)

    offsets = torch.cat(offsets, dim=0)
    offsets = offsets.view(int_fs[0], 2, *int_fs[2:])

    # case1: visualize optical flow: minus current position
    h_add = torch.arange(int_fs[2]).view([1, 1, int_fs[2], 1]).expand(
        int_fs[0], -1, -1, int_fs[3])
    w_add = torch.arange(int_fs[3]).view([1, 1, 1, int_fs[3]]).expand(
        int_fs[0], -1, int_fs[2], -1)
    ref_coordinate = torch.cat([h_add, w_add], dim=1)  # b*2*H*W
    if self.use_cuda:
        ref_coordinate = ref_coordinate.cuda()

    offsets = offsets - ref_coordinate
    # flow = pt_flow_to_image(offsets)
    flow = torch.from_numpy(
        flow_to_image(offsets.permute(0, 2, 3, 1).cpu().data.numpy())) / 255.
    flow = flow.permute(0, 3, 1, 2)
    if self.use_cuda:
        flow = flow.cuda()
    # case2: visualize which pixels are attended
    # flow = torch.from_numpy(highlight_flow((offsets * mask.long()).cpu().data.numpy()))

    if self.rate != 1:
        flow = F.interpolate(flow, scale_factor=self.rate * 4, mode='nearest')

    return y, flow
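# --- Illustrative sketch (not part of the module above) ---
# A minimal demonstration of the score "fuse" branch above: the (background-patch x
# foreground-location) score map is flattened and convolved with an identity kernel
# along both pairings of axes, so neighbouring patch/location pairs reinforce each
# other and the attention tends to select coherent, larger regions. The sizes below
# and the use of zero padding in place of same_padding are illustrative assumptions.
import torch
import torch.nn.functional as F

k = 3                                   # fuse kernel size (self.fuse_k in the module)
hb, wb, hf, wf = 8, 8, 8, 8             # background / foreground spatial sizes
yi = torch.randn(1, hb * wb, hf, wf)    # raw scores, one channel per background patch

fuse_weight = torch.eye(k).view(1, 1, k, k)               # identity kernel
yi = yi.view(1, 1, hb * wb, hf * wf)                       # flatten both location axes
yi = F.conv2d(yi, fuse_weight, stride=1, padding=k // 2)   # fuse along (h_b, h_f)
yi = yi.view(1, hb, wb, hf, wf).permute(0, 2, 1, 4, 3)     # swap the h and w axes
yi = yi.contiguous().view(1, 1, hb * wb, hf * wf)
yi = F.conv2d(yi, fuse_weight, stride=1, padding=k // 2)   # fuse along (w_b, w_f)
yi = yi.view(1, wb, hb, wf, hf).permute(0, 2, 1, 4, 3)     # undo the swap
yi = yi.contiguous().view(1, hb * wb, hf, wf)              # back to per-patch channels
# print(yi.shape)  # torch.Size([1, 64, 8, 8])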
def forward(self, input):
    # get embedding
    embed_w = self.conv_assembly(input)
    match_input = self.conv_match_1(input)

    # b*c*h*w
    shape_input = list(embed_w.size())  # b*c*h*w
    input_groups = torch.split(match_input, 1, dim=0)
    # kernel size on input for matching
    kernel = self.scale * self.ksize

    # raw_w is extracted for reconstruction
    raw_w = extract_image_patches(embed_w,
                                  ksizes=[kernel, kernel],
                                  strides=[self.stride * self.scale, self.stride * self.scale],
                                  rates=[1, 1],
                                  padding='same')  # [N, C*k*k, L]
    # raw_shape: [N, C, k, k, L]
    raw_w = raw_w.view(shape_input[0], shape_input[1], kernel, kernel, -1)
    raw_w = raw_w.permute(0, 4, 1, 2, 3)  # raw_shape: [N, L, C, k, k]
    raw_w_groups = torch.split(raw_w, 1, dim=0)

    # downscaling X to form Y for cross-scale matching
    ref = F.interpolate(input, scale_factor=1. / self.scale, mode='bilinear')
    ref = self.conv_match_2(ref)
    w = extract_image_patches(ref,
                              ksizes=[self.ksize, self.ksize],
                              strides=[self.stride, self.stride],
                              rates=[1, 1],
                              padding='same')
    shape_ref = ref.shape
    # w shape: [N, C, k, k, L]
    w = w.view(shape_ref[0], shape_ref[1], self.ksize, self.ksize, -1)
    w = w.permute(0, 4, 1, 2, 3)  # w shape: [N, L, C, k, k]
    w_groups = torch.split(w, 1, dim=0)

    y = []
    scale = self.softmax_scale
    # 1*1*k*k
    # fuse_weight = self.fuse_weight

    for xi, wi, raw_wi in zip(input_groups, w_groups, raw_w_groups):
        # normalize
        wi = wi[0]  # [L, C, k, k]
        max_wi = torch.max(
            torch.sqrt(reduce_sum(torch.pow(wi, 2), axis=[1, 2, 3], keepdim=True)),
            self.escape_NaN)
        wi_normed = wi / max_wi

        # compute correlation map
        xi = same_padding(xi, [self.ksize, self.ksize], [1, 1], [1, 1])  # xi: 1*c*H*W
        yi = F.conv2d(xi, wi_normed, stride=1)  # [1, L, H, W]; L = shape_ref[2]*shape_ref[3]
        yi = yi.view(1, shape_ref[2] * shape_ref[3], shape_input[2], shape_input[3])  # (B=1, C=32*32, H=32, W=32)
        # rescale matching score
        yi = F.softmax(yi * scale, dim=1)
        if self.average == False:
            yi = (yi == yi.max(dim=1, keepdim=True)[0]).float()

        # deconv for reconstruction
        wi_center = raw_wi[0]
        yi = F.conv_transpose2d(yi, wi_center, stride=self.stride * self.scale, padding=self.scale)
        yi = yi / 6.
        y.append(yi)

    y = torch.cat(y, dim=0)
    return y
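# --- Illustrative sketch (not part of the module above) ---
# A minimal demonstration of the cross-scale reconstruction above: matching scores
# are computed on the downscaled reference, but the pasted patches come from the
# full-resolution embedding, so F.conv_transpose2d with stride = stride*scale
# upsamples by `scale`. A hard one-hot attention map stands in for the softmax
# scores; all sizes and the random patch bank are illustrative assumptions.
import torch
import torch.nn.functional as F

scale, ksize, stride = 2, 3, 1
c, h, w = 8, 16, 16                              # low-resolution matching grid
L = h * w                                        # one candidate patch per grid cell
raw_patches = torch.randn(L, c, scale * ksize, scale * ksize)  # full-res patches [L, C, s*k, s*k]

# hard attention: every low-resolution location copies exactly one candidate patch
idx = torch.randint(0, L, (1, h, w))
attn = F.one_hot(idx, num_classes=L).permute(0, 3, 1, 2).float()  # [1, L, h, w]

# transposed conv pastes the selected full-resolution patches and upsamples by `scale`
out = F.conv_transpose2d(attn, raw_patches, stride=scale * stride, padding=scale)
# print(out.shape)  # torch.Size([1, 8, 32, 32])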