def test(self):
        """Run the test-time forward pass and save the last sampled output.

        ``self.img_m`` is reversed along dimension 0 and encoded, a latent code
        is drawn ``self.opt.nsampling`` times, and every draw is decoded,
        blended with the input through ``self.mask`` and scored by
        ``self.net_D``. The results are stored on ``self.img_g``,
        ``self.img_out`` and ``self.score``; only the output of the final draw
        is handed to ``save_results`` (under the name ``'flipped'``).
        """
        # Only the ground truth is written out before sampling starts.
        self.save_results(self.img_truth, data_name='truth')

        # Encoder pass on the input reversed along dim 0.
        # NOTE(review): dim 0 is presumably the batch axis, so this reorders
        # the samples rather than mirroring each image, while the blend below
        # uses the unflipped mask and input -- confirm this is intended.
        reversed_input = torch.flip(self.img_m, [0])
        distribution, f = self.net_E(reversed_input)
        mean = distribution[-1][0]
        std = distribution[-1][1]
        q_distribution = torch.distributions.Normal(mean, std)

        # Resize the mask to the spatial size of f[2]; the generator receives
        # the first of three chunks taken along dim 1.
        target_size = [f[2].size(2), f[2].size(3)]
        scale_mask = task.scale_img(self.mask, size=target_size)
        generator_mask = scale_mask.chunk(3, dim=1)[0]

        # Blend terms that do not depend on the sampled latent code.
        generated_weight = 1 - self.mask
        input_part = self.mask * self.img_m

        # Decoder pass, once per requested sample.
        final_round = self.opt.nsampling - 1
        for round_idx in range(self.opt.nsampling):
            latent = q_distribution.sample()
            self.img_g, _ = self.net_G(latent,
                                       f_m=f[-1],
                                       f_e=f[2],
                                       mask=generator_mask)
            generated = self.img_g[-1].detach()
            # The mask weights the input image, (1 - mask) the generated one.
            self.img_out = generated_weight * generated + input_part
            self.score = self.net_D(self.img_out)
            # Earlier rounds are scored but never written out.
            if round_idx == final_round:
                self.save_results(self.img_out, data_name='flipped')
Esempio n. 2
0
 def get_G_inputs(self, p_distribution, q_distribution, f):
     """Build the generator inputs, stacking two data flows along dim 0.

     The first half (``chunk(2)[0]``, i.e. along dim 0) of ``f[-1]`` and of
     ``f[2]`` is duplicated and concatenated with itself. ``self.mask`` is
     resized to the spatial size of the resulting ``f_e`` and the first of
     three dim-1 chunks is duplicated the same way. One reparameterised
     sample is drawn from each distribution and the two are concatenated.

     Returns:
         Tuple ``(z, f_m, f_e, mask)``.
     """
     # Encoder features: take the first half once, then stack it twice.
     last_feature_half = f[-1].chunk(2)[0]
     f_m = torch.cat([last_feature_half, last_feature_half], dim=0)

     mid_feature_half = f[2].chunk(2)[0]
     f_e = torch.cat([mid_feature_half, mid_feature_half], dim=0)

     # Mask resized to f_e's spatial size; only the first dim-1 chunk is used.
     target_size = [f_e.size(2), f_e.size(3)]
     resized_mask = task.scale_img(self.mask, size=target_size)
     first_mask_chunk = resized_mask.chunk(3, dim=1)[0]
     mask = torch.cat([first_mask_chunk, first_mask_chunk], dim=0)

     # List items evaluate left to right, so p is sampled before q.
     z = torch.cat(
         [p_distribution.rsample(), q_distribution.rsample()],
         dim=0,
     )
     return z, f_m, f_e, mask
Esempio n. 3
0
    def test(self, mark=None):
        """Run the text-conditioned test-time forward pass and save each sample.

        The ground truth and ``self.img_m`` are saved first. The encoder is then
        run on the input together with the sentence/word embeddings and masks,
        and a latent code is drawn ``self.opt.nsampling`` times. Every draw is
        decoded, blended with the input through ``self.mask``, scored by
        ``self.net_D`` and saved under the name ``'out'`` with its sample index.
        Results are also stored on ``self.img_g``, ``self.img_out`` and
        ``self.score``.

        Args:
            mark: Passed through unchanged to ``save_results`` for each output.
        """
        # Save the ground truth and the encoder input image.
        self.save_results(self.img_truth, data_name='truth')
        self.save_results(self.img_m, data_name='mask')

        # encoder process
        distribution, f, f_text = self.net_E(
            self.img_m, self.sentence_embedding, self.word_embeddings, self.text_mask, self.mask)

        # With ``no_variance`` the scale is multiplied by zero so every sample
        # equals the mean. A zero scale violates Normal's ``scale > 0``
        # constraint and raises ValueError when argument validation is active
        # (the default in recent PyTorch), so validation is switched off for
        # that deliberate degenerate case only; ``None`` keeps the default.
        no_variance = self.opt.no_variance
        variation_factor = 0. if no_variance else 1.
        q_distribution = torch.distributions.Normal(
            distribution[-1][0],
            distribution[-1][1] * variation_factor,
            validate_args=False if no_variance else None)

        # Resize the mask to the spatial size of f[2]; the generator receives
        # the first of three chunks along dim 1 (same for every sample, so it
        # is computed once outside the loop).
        scale_mask = task.scale_img(self.mask, size=[f[2].size(2), f[2].size(3)])
        generator_mask = scale_mask.chunk(3, dim=1)[0]

        # decoder process
        for i in range(self.opt.nsampling):
            z = q_distribution.sample()

            self.img_g, attn = self.net_G(z, f_text, f_e=f[2], mask=generator_mask)
            # The mask weights the input image, (1 - mask) the generated one.
            self.img_out = (1 - self.mask) * self.img_g[-1].detach() + self.mask * self.img_m
            self.score = self.net_D(self.img_out)
            self.save_results(self.img_out, i, data_name='out', mark=mark)