Example #1
0
    def inpaint_decode(self, x, h_d):
        """Decode a bit image from the inpainting encoder into patch outputs.

        The bit image is replication-padded by one patch (measured in bit
        units, i.e. patch size // 16) on every side. It is cut into grids of
        ``(g_h // 16, g_w // 16)`` bits with a stride of one bit patch, and
        each grid is cut into ``(p_h // 16, p_w // 16)`` bit patches. The
        result is reshaped so that each sample has
        ``self.inp_decoder.bnd`` channels before being decoded.

        Args:
            x: bit image; presumably the output of ``inpaint_encode``,
                i.e. (B, C, H // 16, W // 16) -- TODO confirm.
            h_d: decoder hidden state forwarded to ``self.inp_decoder``
                (``None`` on a first call).

        Returns:
            Tuple ``(x, h_d)``: the decoder output and its updated state.
        """
        # patch / grid sizes expressed in bit-image units (image size // 16)
        bit_p_h, bit_p_w = self.p_h // 16, self.p_w // 16
        bit_g_h, bit_g_w = self.g_h // 16, self.g_w // 16

        # replication-pad one bit patch on every side; F.pad orders the
        # padding last-dim first: (left, right, top, bottom)
        x = F.pad(input=x,
                  pad=(bit_p_w, bit_p_w, bit_p_h, bit_p_h),
                  mode="replicate")

        # bit image -> bit grids (one grid per patch position)
        x = f.sliding_window(input=x,
                             kernel_size=(bit_g_h, bit_g_w),
                             stride=(bit_p_h, bit_p_w))

        # bit grids -> bit patches
        x = f.sliding_window(input=x,
                             kernel_size=(bit_p_h, bit_p_w),
                             stride=(bit_p_h, bit_p_w))

        # regroup the bit patches of each grid along the channel dimension
        x = x.view(-1, self.inp_decoder.bnd, *x.size()[2:])

        # decode and return the updated decoder state
        x, h_d = self.inp_decoder(x, h_d)

        return x, h_d
Example #2
0
    def encode_decode(self, r, itrs):
        """Encode and decode a batch of images over ``itrs`` residual passes.

        The images are cut into ``self.p_s`` patches. Pass ``i`` feeds the
        current residual to ``self.ae_sys[i]``. That output is subtracted
        from the residual and added to the running reconstruction. The sum is
        clamped to [-1, 1] and reassembled into images. Runs under
        ``torch.no_grad()``.

        Args:
            r: batch of images (B, C, h, w).
            itrs: number of residual passes; must not exceed ``self.itrs``.
                With ``itrs <= 0`` an all-zero reconstruction is returned.

        Returns:
            Reconstructed images with the same spatial size as ``r``.

        Raises:
            ValueError: if ``itrs`` is greater than ``self.itrs``.
        """
        # explicit check instead of ``assert`` so it is not stripped by -O
        if itrs > self.itrs:
            raise ValueError("itrs > Training Iterations")

        with torch.no_grad():

            # extract original image dimensions
            img_size = r.size()[2:]

            # convert images to patches
            r = f.sliding_window(input=r,
                                 kernel_size=self.p_s,
                                 stride=self.p_s)

            # start from a zero *tensor* so that itrs == 0 still produces a
            # valid (blank) image instead of failing on float.clamp
            patches = torch.zeros_like(r)

            for i in range(itrs):
                # encode & decode the current residual
                dec = self.ae_sys[i](r)
                r = r - dec
                # accumulate the residual predictions
                patches = patches + dec

            # clamp patch values to [-1, 1]
            patches = patches.clamp(-1, 1)

            # reshape patches to images
            image = f.refactor_windows(windows=patches, output_size=img_size)

        return image
Example #3
0
    def encode_decode(self, r0, itrs, h_e=None, h_d=None):
        """Iteratively encode, binarize and decode a batch of images.

        The images are cut into ``self.p_s`` patches. Each iteration
        encodes the current residual, binarizes it and decodes it. The
        decoder output ``dec`` is compared against the *original* patches
        (``r = r0 - dec``), so ``dec`` acts as a full reconstruction. The last
        ``dec`` is reassembled into images. Runs under ``torch.no_grad()``.

        Args:
            r0: batch of images (B, C, h, w).
            itrs: number of iterations; a warning is issued if it exceeds
                ``self.itrs``. With ``itrs <= 0`` a zero image is returned.
            h_e: initial encoder hidden state (default ``None``).
            h_d: initial decoder hidden state (default ``None``).

        Returns:
            Decoded images with the same spatial size as ``r0``.
        """
        if self.itrs < itrs:
            warnings.warn('itrs > Training Iterations')

        with torch.no_grad():

            # extract original image dimensions
            img_size = r0.size()[2:]

            # convert images to patches
            r0 = f.sliding_window(input=r0,
                                  kernel_size=self.p_s,
                                  stride=self.p_s)

            # zero reconstruction, so itrs <= 0 does not hand None to
            # refactor_windows
            dec = torch.zeros_like(r0)
            r = r0

            for _ in range(itrs):

                # encode & decode
                enc, h_e = self.encoder(r, h_e)
                b = self.binarizer(enc)
                dec, h_d = self.decoder(b, h_d)

                # residual is always taken against the original patches
                r = r0 - dec

            # reshape patches to images
            dec = f.refactor_windows(windows=dec, output_size=img_size)

        return dec
Example #4
0
    def encode(self, x):
        """Split ``x`` into ``self.p_s`` patches, encode them and binarize.

        Args:
            x: input grids to be cut into non-overlapping patches.

        Returns:
            The output of ``self.binarizer`` applied to the encoded patches.
        """
        patches = f.sliding_window(input=x, kernel_size=self.p_s, stride=self.p_s)
        codes = self.encoder(patches)
        return self.binarizer(codes)
Example #5
0
    def _inp_encode(self, x, h_e):
        """Encode image patches with the inpainting encoder and binarize them.

        Args:
            x: images to be cut into ``self.p_s`` patches.
            h_e: encoder hidden state forwarded to ``self.inp_encoder``.

        Returns:
            Tuple ``(bits, states)``. ``bits`` is the binarized encoding.
            ``states`` holds the encoder states, each sliced to the entries
            at positions 4, 13, 22, ... of its first dimension.
        """
        patches = f.sliding_window(input=x, kernel_size=self.p_s, stride=self.p_s)
        codes, states = self.inp_encoder(patches, h_e)
        bits = self.inp_bin(codes)

        # keep every 9th entry starting at index 4 of each state.
        # NOTE(review): presumably the centre patch of each 3x3 patch grid --
        # confirm against how grids are laid out in the batch.
        kept = [state[4::9] for state in states]
        return bits, kept
Example #6
0
    def inpaint_encode(self, x, h_e):
        """Encode images into a bit image at 1/16 of the input resolution.

        Args:
            x: 4-D batch of images; the last two dims are height and width.
            h_e: encoder hidden state forwarded to ``self.inp_encoder``.

        Returns:
            Tuple ``(bit_image, h_e)``. ``bit_image`` is the binarized
            patches reassembled to ``(H // 16, W // 16)``. ``h_e`` is the
            updated encoder state.
        """
        height, width = x.size()[2:]
        bit_size = (height // 16, width // 16)

        patches = f.sliding_window(input=x, kernel_size=self.p_s, stride=self.p_s)
        codes, h_e = self.inp_encoder(patches, h_e)
        bits = self.inp_bin(codes)

        return f.refactor_windows(windows=bits, output_size=bit_size), h_e
Example #7
0
    def display_inpainting(self, save=False):
        """Display (and optionally save) one inpainting example.

        The first batch of ``self.patch_dl`` is taken as the context. The
        model is run on it with ``itrs=1``. The function then shows a patch
        grid of the (masked) context, the model's inpainting, and the
        ground-truth centre region of the first sample. Models named
        "ConvAR" or "ConvRNN" are rejected with a printed message.

        Args:
            save: if True, also write ``./context.png``, ``./inpaint.png``
                and ``./g_truth.png`` via ``im_t.save_img``.
        """
        if self.model.name in ["ConvAR", "ConvRNN"]:
            # sanity check
            print("Specified model is not an inpainting network!")
            return

        # model patch height & width
        p_h, p_w = self.model.p_s

        # context region, ground truth, inpainting.
        # Use the builtin next(): the Python-2 style ``.next()`` method is not
        # available on DataLoader iterators in newer PyTorch releases.
        context = next(iter(self.patch_dl))
        inpaint = self.apply_model(context, itrs=1)[0].cpu()
        g_truth = context[0][:, p_h:-p_h, p_w:-p_w].contiguous()

        if self.model.name in ["MaskedBINet", "SINet"]:
            # mask center patch
            context[:, :, p_h:-p_h, p_w:-p_w] = -1.0
        if self.model.name == "SINet":
            # white out patches not in context
            context[:, :, -p_h:] = 1.0
            context[:, :, p_h:-p_h, -p_w:] = 1.0

        # tile the context into full (p_h, p_w) patches; kernel and stride
        # are both the patch size so non-square patches are handled too
        context = f.sliding_window(input=context,
                                   kernel_size=(p_h, p_w),
                                   stride=(p_h, p_w))
        context = make_grid(context,
                            nrow=3,
                            padding=1,
                            pad_value=1,
                            normalize=True)

        inpaint = self.inv_norm(inpaint)
        g_truth = self.inv_norm(g_truth)

        # display inpainting
        im_t.display_inpaint(context, inpaint, g_truth)

        if save:
            # save images
            im_t.save_img(context, save_loc="./context.png")
            im_t.save_img(inpaint, save_loc="./inpaint.png")
            im_t.save_img(g_truth, save_loc="./g_truth.png")
Example #8
0
    def encode_decode(self, r0, itrs):
        """Encode and decode images with one inpainting pass plus residual passes.

        Iteration 0 runs ``self.inpaint_encode`` / ``self.inpaint_decode`` on
        the whole images. Later iterations run ``res_encoder`` ->
        ``res_bin`` -> ``res_decoder`` on the patch residual ``r0 - dec``. The
        final decoder output is reassembled into images. Runs under
        ``torch.no_grad()``.

        Args:
            r0: images (B, C, h, w).
            itrs: total number of passes. A warning is emitted if it exceeds
                ``self.itrs``. With ``itrs <= 0`` a zero image is returned.

        Returns:
            Decoded images with the same spatial size as ``r0``.
        """
        if self.itrs < itrs:
            # warnings.warn emits the warning; warnings.WarningMessage only
            # builds a record object and requires more arguments
            warnings.warn("Specified Iterations > Training Iterations")

        # run in inference mode
        with torch.no_grad():

            h_e = h_d = None

            # extract original image dimensions
            img_size = r0.size()[2:]

            # The first (inpainting) pass receives the *whole images*:
            # inpaint_encode is expected to do its own image -> patch split.
            # So ``r`` must be bound before r0 is converted to patches.
            r = r0

            # convert images to patches
            r0 = f.sliding_window(input=r0,
                                  kernel_size=self.p_s,
                                  stride=self.p_s)

            # zero reconstruction so itrs <= 0 does not pass None on
            dec = torch.zeros_like(r0)

            for i in range(itrs):
                if i == 0:
                    # binary inpainting
                    enc, h_e = self.inpaint_encode(r, h_e)
                    dec, h_d = self.inpaint_decode(enc, h_d)
                else:
                    enc, h_e = self.res_encoder(r, h_e)
                    b = self.res_bin(enc)
                    dec, h_d = self.res_decoder(b, h_d)

                # residual error against the original patches
                r = r0 - dec

            # reshape patches to images
            dec = f.refactor_windows(windows=dec, output_size=img_size)

        return dec