def forward(self, input):
        res = input
        # theta: query embedding used for matching
        match_base = self.conv_match_L_base(input)
        shape_base = list(res.size())
        input_groups = torch.split(match_base, 1, dim=0)
        # patch size for matching
        kernel = self.ksize
        # raw_w is for reconstruction
        raw_w = []
        # w is for matching
        w = []
        #build feature pyramid
        for i in range(len(self.scale)):
            ref = input
            if self.scale[i] != 1:
                ref = F.interpolate(input,
                                    scale_factor=self.scale[i],
                                    mode='bicubic')
            #feature transformation function f
            base = self.conv_assembly(ref)
            shape_input = base.shape
            #sampling
            raw_w_i = extract_image_patches(base,
                                            ksizes=[kernel, kernel],
                                            strides=[self.stride, self.stride],
                                            rates=[1, 1],
                                            padding='same')  # [N, C*k*k, L]
            raw_w_i = raw_w_i.view(shape_input[0], shape_input[1], kernel,
                                   kernel, -1)
            raw_w_i = raw_w_i.permute(0, 4, 1, 2,
                                      3)  # raw_shape: [N, L, C, k, k]
            raw_w_i_groups = torch.split(raw_w_i, 1, dim=0)
            raw_w.append(raw_w_i_groups)

            #feature transformation function g
            ref_i = self.conv_match(ref)
            shape_ref = ref_i.shape
            #sampling
            w_i = extract_image_patches(ref_i,
                                        ksizes=[self.ksize, self.ksize],
                                        strides=[self.stride, self.stride],
                                        rates=[1, 1],
                                        padding='same')
            w_i = w_i.view(shape_ref[0], shape_ref[1], self.ksize, self.ksize,
                           -1)
            w_i = w_i.permute(0, 4, 1, 2, 3)  # w shape: [N, L, C, k, k]
            w_i_groups = torch.split(w_i, 1, dim=0)
            w.append(w_i_groups)

        y = []
        for idx, xi in enumerate(input_groups):
            # group matching patches from all scales into one filter bank
            wi = torch.cat([w[i][idx][0] for i in range(len(self.scale))],
                           dim=0)  # [L, C, k, k]
            #normalize
            max_wi = torch.max(
                torch.sqrt(
                    reduce_sum(torch.pow(wi, 2), axis=[1, 2, 3],
                               keepdim=True)), self.escape_NaN)
            wi_normed = wi / max_wi
            #matching
            xi = same_padding(xi, [self.ksize, self.ksize], [1, 1],
                              [1, 1])  # xi: 1*c*H*W
            yi = F.conv2d(
                xi, wi_normed,
                stride=1)  # [1, L, H, W], L = total number of patches over all scales
            yi = yi.view(1, wi.shape[0], shape_base[2],
                         shape_base[3])  # (B=1, C=32*32, H=32, W=32)
            # softmax matching score
            yi = F.softmax(yi * self.softmax_scale, dim=1)

            if not self.average:
                yi = (yi == yi.max(dim=1, keepdim=True)[0]).float()

            # deconv for patch pasting
            raw_wi = torch.cat(
                [raw_w[i][idx][0] for i in range(len(self.scale))], dim=0)
            yi = F.conv_transpose2d(yi, raw_wi, stride=self.stride,
                                    padding=1) / 4.
            y.append(yi)

        y = torch.cat(y,
                      dim=0) + res * self.res_scale  # back to the mini-batch
        return y
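The helper functions used above (extract_image_patches, same_padding, reduce_sum) are assumed to ship with the surrounding repository. The following is only a minimal sketch of their presumed behaviour under the [N, C*k*k, L] shape convention noted in the comments above (the later examples appear to use slightly different variants); the names, signatures and defaults here are assumptions, not the repository's actual implementation.

import torch
import torch.nn.functional as F


def same_padding(images, ksizes, strides, rates=(1, 1)):
    # Pad H and W so a sliding window with the given kernel/stride/dilation
    # covers the input the way TensorFlow-style 'SAME' padding would.
    _, _, rows, cols = images.size()
    out_rows = (rows + strides[0] - 1) // strides[0]
    out_cols = (cols + strides[1] - 1) // strides[1]
    eff_rows = (ksizes[0] - 1) * rates[0] + 1
    eff_cols = (ksizes[1] - 1) * rates[1] + 1
    pad_rows = max(0, (out_rows - 1) * strides[0] + eff_rows - rows)
    pad_cols = max(0, (out_cols - 1) * strides[1] + eff_cols - cols)
    # F.pad expects [left, right, top, bottom] for a 4-D tensor
    return F.pad(images, [pad_cols // 2, pad_cols - pad_cols // 2,
                          pad_rows // 2, pad_rows - pad_rows // 2])


def extract_image_patches(images, ksizes, strides, rates=(1, 1), padding='same'):
    # Flatten every sliding-window patch into a column: output is [N, C*k*k, L].
    if padding == 'same':
        images = same_padding(images, ksizes, strides, rates)
    return F.unfold(images, kernel_size=ksizes, dilation=rates, stride=strides)


def reduce_sum(x, axis, keepdim=False):
    # Sum over several axes at once, mirroring tf.reduce_sum.
    return torch.sum(x, dim=axis, keepdim=keepdim)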
Example #2
    def forward(self, f, b, mask=None):
        """ Contextual attention layer implementation.
        Contextual attention was first introduced in the publication:
            Generative Image Inpainting with Contextual Attention, Yu et al.
        Args:
            f: Input feature to match (foreground).
            b: Input feature to match against (background).
            mask: Input mask for b, indicating patches not available.
            ksize: Kernel size for contextual attention.
            stride: Stride for extracting patches from b.
            rate: Dilation for matching.
            softmax_scale: Scaled softmax for attention.
        Returns:
            torch.Tensor: output feature.
            torch.Tensor: attention offset flow visualization.
        """
        # get shapes
        raw_int_fs = list(f.size())  # b*c*h*w
        raw_int_bs = list(b.size())  # b*c*h*w

        # extract patches from background with stride and rate
        kernel = 2 * self.rate
        # raw_w is extracted for reconstruction
        raw_w = extract_image_patches(b,
                                      ksizes=[kernel, kernel],
                                      strides=[self.rate,
                                               self.rate])  # b*hw*c*k*k
        raw_w_groups = torch.split(raw_w, 1, dim=0)

        # downscaling option: downscale both foreground and background for
        # matching, and use the original background for reconstruction.
        f = F.interpolate(f, scale_factor=1 / self.rate, mode='nearest')
        b = F.interpolate(b, scale_factor=1 / self.rate, mode='nearest')
        int_fs = list(f.size())  # b*c*h*w
        int_bs = list(b.size())
        f_groups = torch.split(
            f, 1, dim=0)  # split tensors along the batch dimension

        w = extract_image_patches(b,
                                  ksizes=[self.ksize, self.ksize],
                                  strides=[self.stride,
                                           self.stride])  # b*hw*c*k*k
        w_groups = torch.split(w, 1, dim=0)

        # process mask
        if mask is None:
            mask = torch.zeros([int_bs[0], 1, int_bs[2], int_bs[3]])
            if self.use_cuda:
                mask = mask.cuda()
        else:
            mask = F.interpolate(mask,
                                 scale_factor=1. / (4. * self.rate),
                                 mode='nearest')
        m_groups = extract_image_patches(mask,
                                         ksizes=[self.ksize, self.ksize],
                                         strides=[self.stride,
                                                  self.stride])  # b*hw*c*k*k


        y = []
        offsets = []
        k = self.fuse_k
        scale = self.softmax_scale * 255  # to fit the PyTorch tensor image value range
        fuse_weight = torch.eye(k).view(1, 1, k, k)  # 1*1*k*k
        if self.use_cuda:
            fuse_weight = fuse_weight.cuda()

        for xi, wi, raw_wi, mi in zip(f_groups, w_groups, raw_w_groups,
                                      m_groups):
            '''
            O => output channel as a conv filter
            I => input channel as a conv filter
            xi : separated tensor along batch dimension of front; (B=1, C=128, H=32, W=32)
            wi : separated patch tensor along batch dimension of back; (B=1, O=32*32, I=128, KH=3, KW=3)
            raw_wi : separated tensor along batch dimension of back; (B=1, I=32*32, O=128, KH=4, KW=4)
            '''
            # conv for compare
            escape_NaN = torch.FloatTensor([1e-4])
            if self.use_cuda:
                escape_NaN = escape_NaN.cuda()
            wi = wi[0]  # hw*c*k*k
            wi_normed = wi / torch.max(
                torch.sqrt(reduce_sum(torch.pow(wi, 2), axis=[1, 2, 3],
                                      keepdim=True)),
                escape_NaN)
            xi_normed = same_padding(xi, [self.ksize, self.ksize],
                                     [1, 1])  # xi: 1*c*H*W
            yi = F.conv2d(xi_normed, wi_normed, stride=1)  # 1*hw*H*W

            # conv-based score fusion to encourage attention to large coherent patches
            if self.fuse:
                # make all of depth to spatial resolution
                yi = yi.view(1, 1, int_bs[2] * int_bs[3], int_fs[2] *
                             int_fs[3])  # (B=1, I=1, H=32*32, W=32*32)
                yi = same_padding(yi, [k, k], [1, 1])
                yi = F.conv2d(yi, fuse_weight,
                              stride=1)  # (B=1, C=1, H=32*32, W=32*32)

                yi = yi.contiguous().view(1, int_bs[2], int_bs[3], int_fs[2],
                                          int_fs[3])  # (B=1, 32, 32, 32, 32)
                yi = yi.permute(0, 2, 1, 4, 3)
                yi = yi.contiguous().view(1, 1, int_bs[2] * int_bs[3],
                                          int_fs[2] * int_fs[3])
                yi = same_padding(yi, [k, k], [1, 1])
                yi = F.conv2d(yi, fuse_weight, stride=1)
                yi = yi.contiguous().view(1, int_bs[3], int_bs[2], int_fs[3],
                                          int_fs[2])
                yi = yi.permute(0, 2, 1, 4, 3)
                yi = yi.contiguous().view(
                    1, int_bs[2] * int_bs[3], int_fs[2],
                    int_fs[3])  # (B=1, C=32*32, H=32, W=32)

            # mi: hw*c*k*k
            mi = reduce_mean(mi, axis=[1, 2, 3], keepdim=True)  # hw*1*1*1
            mi = mi.permute(1, 0, 2, 3).contiguous()  # 1*hw*1*1
            mm = (mi == 0).to(torch.float32)  # 1*hw*1*1

            # softmax to match
            yi = yi * mm
            yi = F.softmax(yi * scale, dim=1)
            yi = yi * mm  # 1*hw*H*W

            offset = torch.argmax(yi, dim=1, keepdim=True)  # 1*1*H*W
            if int_bs != int_fs:
                # Normalize the offset value to match foreground dimension
                times = float(int_fs[2] * int_fs[3]) / float(
                    int_bs[2] * int_bs[3])
                offset = ((offset + 1).float() * times - 1).to(torch.int64)
            offset = torch.cat([offset // int_fs[3], offset % int_fs[3]],
                               dim=1)  # 1*2*H*W

            # deconv for patch pasting
            wi_center = raw_wi[0]
            yi = F.conv_transpose2d(yi, wi_center, stride=self.rate,
                                    padding=1) / 4.  # (B=1, C=128, H=64, W=64)
            y.append(yi)
            offsets.append(offset)

        y = torch.cat(y, dim=0)  # back to the mini-batch
        y = y.contiguous().view(raw_int_fs)

        offsets = torch.cat(offsets, dim=0)
        offsets = offsets.view(int_fs[0], 2, *int_fs[2:])

        # case1: visualize optical flow: minus current position
        h_add = torch.arange(int_fs[2]).view([1, 1, int_fs[2], 1]).expand(
            int_fs[0], -1, -1, int_fs[3])
        w_add = torch.arange(int_fs[3]).view([1, 1, 1, int_fs[3]]).expand(
            int_fs[0], -1, int_fs[2], -1)
        ref_coordinate = torch.cat([h_add, w_add], dim=1)  # b*2*H*W
        if self.use_cuda:
            ref_coordinate = ref_coordinate.cuda()

        offsets = offsets - ref_coordinate
        # flow = pt_flow_to_image(offsets)

        flow = torch.from_numpy(
            flow_to_image(offsets.permute(0, 2, 3,
                                          1).cpu().data.numpy())) / 255.
        flow = flow.permute(0, 3, 1, 2)
        if self.use_cuda:
            flow = flow.cuda()
        # case2: visualize which pixels are attended
        # flow = torch.from_numpy(highlight_flow((offsets * mask.long()).cpu().data.numpy()))

        if self.rate != 1:
            flow = F.interpolate(flow,
                                 scale_factor=self.rate * 4,
                                 mode='nearest')

        return y, flow
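As a self-contained illustration of the matching step described in the docstring above, the sketch below (all sizes and the softmax temperature are made up for illustration) shows how background patches become convolution filters whose responses on the foreground feature, after a softmax over the patch axis, form the attention map:

import torch
import torch.nn.functional as F

N, C, H, W, k = 1, 8, 16, 16, 3                 # toy sizes, illustration only
f = torch.randn(N, C, H, W)                     # foreground feature to match
b = torch.randn(N, C, H, W)                     # background feature to match against

# Turn every kxk background patch into a conv filter: [L, C, k, k] with L = H*W.
w = F.unfold(b, kernel_size=k, padding=k // 2)  # [N, C*k*k, L]
w = w.view(N, C, k, k, -1).permute(0, 4, 1, 2, 3)[0]
w = w / torch.sqrt((w ** 2).sum(dim=[1, 2, 3], keepdim=True)).clamp_min(1e-4)

# One convolution scores every foreground position against every patch;
# softmax over the patch axis (dim=1) turns the scores into attention weights.
scores = F.conv2d(f, w, stride=1, padding=k // 2)   # [N, L, H, W]
attn = F.softmax(scores * 10., dim=1)
print(attn.shape)                                   # torch.Size([1, 256, 16, 16])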
Example #3
    def forward(self, input):
        #get embedding
        embed_w = self.conv_assembly(input)
        match_input = self.conv_match_1(input)

        shape_input = list(embed_w.size())  # b*c*h*w
        input_groups = torch.split(match_input, 1, dim=0)
        # kernel size on input for matching
        kernel = self.scale * self.ksize

        # raw_w is extracted for reconstruction
        raw_w = extract_image_patches(
            embed_w,
            ksizes=[kernel, kernel],
            strides=[self.stride * self.scale, self.stride * self.scale],
            rates=[1, 1],
            padding='same')  # [N, C*k*k, L]
        # raw_shape: [N, C, k, k, L]
        raw_w = raw_w.view(shape_input[0], shape_input[1], kernel, kernel, -1)
        raw_w = raw_w.permute(0, 4, 1, 2, 3)  # raw_shape: [N, L, C, k, k]
        raw_w_groups = torch.split(raw_w, 1, dim=0)

        # downscaling X to form Y for cross-scale matching
        ref = F.interpolate(input,
                            scale_factor=1. / self.scale,
                            mode='bilinear')
        ref = self.conv_match_2(ref)
        w = extract_image_patches(ref,
                                  ksizes=[self.ksize, self.ksize],
                                  strides=[self.stride, self.stride],
                                  rates=[1, 1],
                                  padding='same')
        shape_ref = ref.shape
        # w shape: [N, C, k, k, L]
        w = w.view(shape_ref[0], shape_ref[1], self.ksize, self.ksize, -1)
        w = w.permute(0, 4, 1, 2, 3)  # w shape: [N, L, C, k, k]
        w_groups = torch.split(w, 1, dim=0)

        y = []
        scale = self.softmax_scale
        # 1*1*k*k
        #fuse_weight = self.fuse_weight

        for xi, wi, raw_wi in zip(input_groups, w_groups, raw_w_groups):
            # normalize
            wi = wi[0]  # [L, C, k, k]
            max_wi = torch.max(
                torch.sqrt(
                    reduce_sum(torch.pow(wi, 2), axis=[1, 2, 3],
                               keepdim=True)), self.escape_NaN)
            wi_normed = wi / max_wi

            # Compute correlation map
            xi = same_padding(xi, [self.ksize, self.ksize], [1, 1],
                              [1, 1])  # xi: 1*c*H*W
            yi = F.conv2d(
                xi, wi_normed,
                stride=1)  # [1, L, H, W] L = shape_ref[2]*shape_ref[3]

            yi = yi.view(1, shape_ref[2] * shape_ref[3], shape_input[2],
                         shape_input[3])  # (B=1, C=32*32, H=32, W=32)
            # rescale matching score
            yi = F.softmax(yi * scale, dim=1)
            if not self.average:
                yi = (yi == yi.max(dim=1, keepdim=True)[0]).float()

            # deconv for reconstruction
            wi_center = raw_wi[0]
            yi = F.conv_transpose2d(yi,
                                    wi_center,
                                    stride=self.stride * self.scale,
                                    padding=self.scale)

            yi = yi / 6.
            y.append(yi)

        y = torch.cat(y, dim=0)
        return y
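The final F.conv_transpose2d in each example pastes the selected patches back into image space. A tiny sketch (all sizes hypothetical) of why that works: when the attention map is one-hot along the patch axis, the transposed convolution copies the chosen patch, centered at that position, and overlapping copies from neighbouring positions are presumably why the results above are divided by an approximate overlap count (4. or 6.).

import torch
import torch.nn.functional as F

L, C, k = 4, 3, 3                          # 4 candidate patches, 3 channels, 3x3 patches
patches = torch.randn(L, C, k, k)          # raw patches used as deconv filters
attn = torch.zeros(1, L, 2, 2)             # [1, L, H, W] attention map
attn[0, 2, 0, 0] = 1.                      # position (0, 0) attends only to patch #2

# Each position adds attention-weighted copies of its patches around itself;
# neighbouring positions overlap, hence the normalization in the code above.
out = F.conv_transpose2d(attn, patches, stride=1, padding=1)
print(out.shape)                           # torch.Size([1, 3, 2, 2])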