示例#1
0
文件: image.py 项目: roromaniac/xfuse
    def _sample_image(self, x, decoded):
        """Sample (condition on) the observed image under a pixel-wise Normal.

        Two small 1x1-convolution heads map ``decoded`` to a location (``mu``,
        squashed by ``Tanh``) and a scale (``sd``, made positive by
        ``Softplus``). The heads are obtained through ``get_module`` with
        ``checkpoint=True``, and the factories below are passed to it.
        Presumably ``get_module`` only calls a factory the first time a name
        is requested; confirm against its definition.

        The observed ``x["image"]`` is center-cropped to the distribution's
        shape and passed as ``obs`` to ``pyro.sample("image", ...)``.

        Args:
            x: Mapping with an ``"image"`` tensor. It is reduced over dims
                (0, 2, 3) and its dim 1 is used as the output channel count,
                so it is 4-D with channels on dim 1. Values presumably lie in
                [-1, 1] to match the ``Tanh`` output; TODO confirm.
            decoded: Feature tensor with ``self.num_channels`` channels. It
                must be 4-D because it is fed through ``BatchNorm2d``.

        Returns:
            The ``Normal(mu, 1e-8 + sd).to_event(3)`` distribution used for
            the ``"image"`` sample site.
        """
        # Keeps the data-dependent bias initialisation below finite when an
        # image channel is constant (std == 0) or sits at the Tanh bounds.
        eps = 1e-6

        def _create_decoder(final_activation):
            # Shared 1x1-conv head; only the output nonlinearity differs
            # between the location and scale decoders.
            decoder = torch.nn.Sequential(
                torch.nn.Conv2d(self.num_channels,
                                self.num_channels,
                                kernel_size=1),
                torch.nn.BatchNorm2d(self.num_channels, momentum=0.05),
                torch.nn.LeakyReLU(0.2, inplace=True),
                torch.nn.Conv2d(self.num_channels,
                                x["image"].shape[1],
                                kernel_size=1),
                final_activation,
            )
            # Zero the final conv weights so that, at initialisation, the head
            # outputs final_activation(bias) independently of its input.
            torch.nn.init.constant_(decoder[-2].weight, 0.0)
            return decoder

        def _create_mu_decoder():
            decoder = _create_decoder(torch.nn.Tanh())
            mean = x["image"].mean((0, 2, 3))
            # atanh(mean), so the initial Tanh output equals the per-channel
            # image mean. Clamp so a mean of exactly +-1 cannot produce an
            # infinite (or NaN) bias.
            mean = mean.clamp(-1 + eps, 1 - eps)
            decoder[-2].bias.data = ((1 + mean) / (1 - mean)).log() / 2
            return decoder

        def _create_sd_decoder():
            decoder = _create_decoder(torch.nn.Softplus())
            std = x["image"].std((0, 2, 3))
            # Inverse softplus, log(exp(std) - 1), so the initial Softplus
            # output equals the per-channel image std. expm1 stays accurate
            # for small std; the lower clamp avoids log(0) = -inf (which would
            # leave the head with zero gradient) for a constant channel.
            decoder[-2].bias.data = torch.expm1(std.clamp(min=eps)).log()
            return decoder

        img_mu = get_module("img_mu", _create_mu_decoder, checkpoint=True)
        img_sd = get_module("img_sd", _create_sd_decoder, checkpoint=True)
        mu = img_mu(decoded)
        sd = img_sd(decoded)

        # The small constant keeps the scale strictly positive even if the
        # Softplus output underflows to zero.
        image_distr = Normal(mu, 1e-8 + sd).to_event(3)
        pyro.sample(
            "image",
            image_distr,
            obs=center_crop(x["image"], image_distr.shape()),
        )
        return image_distr