Пример #1
0
    def encode_decode(self, r, itrs):
        """Encode and decode a batch of images over ``itrs`` residual stages.

        The images are split with ``f.sliding_window`` (kernel and stride both
        ``self.p_s``). Stage ``i`` runs ``self.ae_sys[i]`` on the current
        residual, subtracts its output from that residual, and adds it to the
        running reconstruction. The summed reconstruction is clamped to
        [-1, 1] and reassembled into images of the input's spatial size.

        Args:
            r: batch of images; the original comment describes it as
                (B, C, h, w).
            itrs: number of stages to run; must not exceed ``self.itrs``.

        Returns:
            The reconstructed images. When no stage runs (``itrs <= 0``) the
            reconstruction is all zeros instead of crashing on ``float.clamp``.

        Raises:
            AssertionError: if ``itrs > self.itrs`` (only while assertions
                are enabled).
        """
        # NOTE(review): ``assert`` is stripped under ``python -O``; kept as an
        # assert so callers still see AssertionError as before.
        assert (self.itrs >= itrs), "itrs > Training Iterations"

        with torch.no_grad():

            # spatial dims (h, w), needed to reassemble the patches at the end
            img_size = r.size()[2:]

            # images -> patches (kernel_size == stride == p_s)
            r = f.sliding_window(input=r,
                                 kernel_size=self.p_s,
                                 stride=self.p_s)

            # Start from a zero tensor rather than the float 0.0 so that
            # ``clamp`` below also works when the loop body never runs.
            patches = torch.zeros_like(r)

            for i in range(itrs):
                # stage i predicts the current residual
                dec = self.ae_sys[i](r)
                r = r - dec
                # the reconstruction is the sum of all stage predictions;
                # out-of-place add keeps dec's dtype promotion semantics
                patches = patches + dec

            # clamp patch values to [-1, 1]
            patches = patches.clamp(-1, 1)

            # patches -> images of the original spatial size
            image = f.refactor_windows(windows=patches, output_size=img_size)

        return image
Пример #2
0
    def encode_decode(self, r0, itrs, h_e=None, h_d=None):
        """Recurrently encode and decode a batch of images for ``itrs`` steps.

        The images are split with ``f.sliding_window`` (kernel and stride both
        ``self.p_s``). Each iteration does three things:

        * encodes the current residual;
        * binarizes the code;
        * decodes it.

        The decoder output is treated as the full reconstruction, so the next
        residual is ``patches - dec``. Encoder and decoder recurrent states
        are threaded through the iterations.

        Args:
            r0: batch of images; presumably (B, C, h, w) — TODO confirm.
            itrs: number of iterations. A warning is issued (not an error)
                when it exceeds ``self.itrs``.
            h_e: initial encoder state, ``None`` by default.
            h_d: initial decoder state, ``None`` by default.

        Returns:
            The reconstruction after the last iteration, reassembled into
            images of the input's spatial size. All zeros when no iteration
            runs (``itrs <= 0``).
        """
        if self.itrs < itrs:
            # stacklevel=2 attributes the warning to the caller's line
            warnings.warn('itrs > Training Iterations', stacklevel=2)

        with torch.no_grad():

            # spatial dims (h, w), needed to reassemble the patches at the end
            img_size = r0.size()[2:]

            # images -> patches (kernel_size == stride == p_s)
            r0 = f.sliding_window(input=r0,
                                  kernel_size=self.p_s,
                                  stride=self.p_s)

            # Zero reconstruction: returned as-is when the loop never runs,
            # instead of passing None to refactor_windows.
            dec = torch.zeros_like(r0)

            r = r0

            for _ in range(itrs):

                # encode & decode
                enc, h_e = self.encoder(r, h_e)
                b = self.binarizer(enc)
                dec, h_d = self.decoder(b, h_d)

                # residual w.r.t. the original patches (dec is the full
                # reconstruction, not an increment)
                r = r0 - dec

            # patches -> images of the original spatial size
            dec = f.refactor_windows(windows=dec, output_size=img_size)

        return dec
Пример #3
0
    def inpaint_encode(self, x, h_e):
        """Run the inpainting encoder and binarizer, returning bit images.

        The input is split with ``f.sliding_window`` (kernel and stride both
        ``self.p_s``). The patches then go through two stages:

        * ``self.inp_encoder``, which also advances the recurrent state;
        * ``self.inp_bin``.

        The bit patches are reassembled into an image whose height and width
        are the input's divided (floor) by 16.

        Args:
            x: batch of images; dims 2 and 3 are read as height and width.
            h_e: encoder recurrent state passed to ``self.inp_encoder``.

        Returns:
            Tuple ``(bit_images, h_e)`` with the updated encoder state.
        """
        height, width = x.size()[2:]

        patches = f.sliding_window(input=x, kernel_size=self.p_s, stride=self.p_s)
        codes, h_e = self.inp_encoder(patches, h_e)
        bits = self.inp_bin(codes)

        # NOTE(review): 16 is presumably the encoder's spatial downsampling
        # factor — confirm against the encoder definition.
        bit_size = (height // 16, width // 16)
        return f.refactor_windows(windows=bits, output_size=bit_size), h_e
Пример #4
0
    def encode_decode(self, r0, itrs):
        """Progressively encode and decode a batch of images.

        Iteration 0 uses the binary-inpainting path (``inpaint_encode`` /
        ``inpaint_decode``) on the full images. Later iterations take the
        current residual patches through ``res_encoder``, ``res_bin`` and
        ``res_decoder``. The encoder and decoder states ``h_e`` / ``h_d`` are
        threaded through all iterations. The decoder output is treated as the
        full reconstruction, so the residual is ``patches - dec``.

        Args:
            r0: batch of images; the original comment describes it as
                (B, C, h, w).
            itrs: number of iterations. A warning is issued (not an error)
                when it exceeds ``self.itrs``.

        Returns:
            The reconstruction after the last iteration, reassembled into
            images of the input's spatial size. All zeros when no iteration
            runs (``itrs <= 0``).
        """
        if self.itrs < itrs:
            # warnings.WarningMessage is a record type, not an emitter, and
            # its constructor needs (message, category, filename, lineno):
            # calling it with one argument raised TypeError. Emit instead.
            warnings.warn("Specified Iterations > Training Iterations",
                          stacklevel=2)

        # run in inference mode
        with torch.no_grad():

            # Keep the un-patched images: iteration 0 feeds them to
            # inpaint_encode, which does its own image -> patch conversion.
            r = r0
            h_e = h_d = None

            # spatial dims (h, w), needed to reassemble the patches at the end
            img_size = r0.size()[2:]

            # images -> patches (kernel_size == stride == p_s)
            r0 = f.sliding_window(input=r0,
                                  kernel_size=self.p_s,
                                  stride=self.p_s)

            # Zero reconstruction: returned as-is when the loop never runs,
            # instead of passing None to refactor_windows.
            dec = torch.zeros_like(r0)

            for i in range(itrs):
                if i == 0:
                    # binary inpainting on the full images
                    enc, h_e = self.inpaint_encode(r, h_e)
                    dec, h_d = self.inpaint_decode(enc, h_d)
                else:
                    # residual coding on the patch residual
                    enc, h_e = self.res_encoder(r, h_e)
                    b = self.res_bin(enc)
                    dec, h_d = self.res_decoder(b, h_d)

                # residual w.r.t. the original patches
                r = r0 - dec

            # patches -> images of the original spatial size
            dec = f.refactor_windows(windows=dec, output_size=img_size)

        return dec