Example 1
# Imports assumed by the examples below; the project-specific building
# blocks (NormConv2d, VUnetResnetBlock, Downsample, Upsample,
# SpaceToDepth, DepthToSpace, FLAGS) are defined elsewhere in the
# repository.
import torch
import torch.nn as nn
from torch.nn import ModuleDict, ModuleList
from torch.distributions import MultivariateNormal


class VUnetEncoder(nn.Module):
    def __init__(
        self,
        n_stages,
        nf_in=3,
        nf_start=64,
        nf_max=128,
        n_rnb=2,
        conv_layer=NormConv2d,
        dropout_prob=0.0,
    ):
        super().__init__()
        self.in_op = conv_layer(nf_in, nf_start, kernel_size=1)
        nf = nf_start
        self.blocks = ModuleDict()
        self.downs = ModuleDict()
        self.n_rnb = n_rnb
        self.n_stages = n_stages
        for i_s in range(self.n_stages):
            # prepare resnet blocks per stage
            if i_s > 0:
                self.downs.update(
                    {
                        f"s{i_s+1}": Downsample(
                            nf, min(2 * nf, nf_max), conv_layer=conv_layer
                        )
                    }
                )
                nf = min(2 * nf, nf_max)

            for ir in range(self.n_rnb):
                stage = f"s{i_s+1}_{ir+1}"
                self.blocks.update(
                    {
                        stage: VUnetResnetBlock(
                            nf, conv_layer=conv_layer, dropout_prob=dropout_prob
                        )
                    }
                )

    def forward(self, x):
        out = {}
        h = self.in_op(x)
        for ir in range(self.n_rnb):
            h = self.blocks[f"s1_{ir+1}"](h)
            out[f"s1_{ir+1}"] = h

        for i_s in range(1, self.n_stages):

            h = self.downs[f"s{i_s+1}"](h)

            for ir in range(self.n_rnb):
                stage = f"s{i_s+1}_{ir+1}"
                h = self.blocks[stage](h)
                out[stage] = h

        return out
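
A minimal usage sketch for the encoder above (hypothetical sizes; it assumes the building blocks are importable and that Downsample halves the spatial resolution): the encoder returns one feature map per resnet block, keyed by stage, with channels doubling up to nf_max while the spatial size shrinks.

import torch

# Hypothetical usage sketch; concrete sizes are illustrative only.
enc = VUnetEncoder(n_stages=4, nf_in=3, nf_start=32, nf_max=64)
feats = enc(torch.randn(1, 3, 8, 8))
# Expected keys: s1_1, s1_2 at 8x8 (32 ch), s2_* at 4x4 (64 ch),
# s3_* at 2x2 (64 ch), s4_* at 1x1 (64 ch).
print({k: tuple(v.shape) for k, v in feats.items()})
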
Example 2
class VUnetBottleneck(nn.Module):
    def __init__(
        self,
        n_stages,
        nf,
        device,
        n_rnb=2,
        n_auto_groups=4,
        conv_layer=NormConv2d,
        dropout_prob=0.0,
    ):
        super().__init__()
        self.device = device
        self.blocks = ModuleDict()
        self.channel_norm = ModuleDict()
        self.conv1x1 = conv_layer(nf, nf, 1)
        self.up = Upsample(in_channels=nf, out_channels=nf, conv_layer=conv_layer)
        self.depth_to_space = DepthToSpace(block_size=2)
        self.space_to_depth = SpaceToDepth(block_size=2)
        self.n_stages = n_stages
        self.n_rnb = n_rnb
        # number of autoregressively modeled groups
        self.n_auto_groups = n_auto_groups
        for i_s in range(self.n_stages, self.n_stages - 2, -1):
            self.channel_norm.update({f"s{i_s}": conv_layer(2 * nf, nf, 1)})
            for ir in range(self.n_rnb):
                self.blocks.update(
                    {
                        f"s{i_s}_{ir+1}": VUnetResnetBlock(
                            nf,
                            use_skip=True,
                            conv_layer=conv_layer,
                            dropout_prob=dropout_prob,
                        )
                    }
                )

        self.auto_blocks = ModuleList()
        # resnet blocks for the autoregressively modeled groups
        for i_a in range(4):
            if i_a < 1:
                self.auto_blocks.append(
                    VUnetResnetBlock(
                        nf, conv_layer=conv_layer, dropout_prob=dropout_prob
                    )
                )
                self.param_converter = conv_layer(4 * nf, nf, kernel_size=1)
            else:
                self.auto_blocks.append(
                    VUnetResnetBlock(
                        nf,
                        use_skip=True,
                        conv_layer=conv_layer,
                        dropout_prob=dropout_prob,
                    )
                )

    def forward(self, x_e, z_post, mode="train"):
        """

        Parameters
        ----------
        x_e : dict(str: torch.Tensor)
            The per-stage outputs of the encoder E_theta.
        z_post : dict(str: torch.Tensor)
            The per-stage outputs of the encoder F_phi (posterior latents).
        mode : str
            Determines the mode of the bottleneck; must be in
            ["train", "appearance_transfer", "sample_appearance"].

        Returns
        -------
        h : torch.Tensor
            The output of the last layer of the bottleneck, which is
            subsequently used by the decoder.
        p_params : dict(str: torch.Tensor)
            The means of the prior distributions p(z|ŷ) of the two
            bottleneck stages.
        z_prior : dict(str: torch.Tensor)
            The current samples of the prior distributions of the two
            bottleneck stages, flattened.
        """
        p_params = {}
        z_prior = {}

        use_z = mode == "train" or mode == "appearance_transfer"

        h = self.conv1x1(x_e[f"s{self.n_stages}_2"])
        for i_s in range(self.n_stages, self.n_stages - 2, -1):
            stage = f"s{i_s}"
            spatial_size = x_e[stage + "_2"].shape[-1]

            h = self.blocks[stage + "_2"](h, x_e[stage + "_2"])

            if spatial_size == 1:
                p_params[stage] = h
                # posterior_params[stage] = z_post[stage + "_2"]
                prior_samples = self._latent_sample(p_params[stage])
                z_prior[stage] = torch.squeeze(
                    torch.squeeze(prior_samples, dim=-1), dim=-1
                )
                # posterior_samples = self._latent_sample(posterior_params[stage])
            else:

                if use_z:
                    z_flat = (
                        self.space_to_depth(z_post[stage])
                        if z_post[stage].shape[2] > 1
                        else z_post[stage]
                    )
                    sec_size = z_flat.shape[1] // 4
                    z_groups = torch.split(
                        z_flat, [sec_size, sec_size, sec_size, sec_size], dim=1
                    )

                param_groups = []
                sample_groups = []

                param_features = self.auto_blocks[0](h)
                param_features = self.space_to_depth(param_features)
                # convert to the matching channel depth
                param_features = self.param_converter(param_features)

                for i_a in range(len(self.auto_blocks)):
                    param_groups.append(param_features)

                    prior_samples = self._latent_sample(param_groups[-1])

                    sample_groups.append(prior_samples)

                    if i_a + 1 < len(self.auto_blocks):
                        if use_z:
                            feedback = z_groups[i_a]
                        else:
                            feedback = prior_samples

                        param_features = self.auto_blocks[i_a](param_features, feedback)

                p_params_stage = torch.cat(param_groups, dim=1)
                prior_samples = self.__merge_groups(sample_groups)
                p_params[stage] = p_params_stage
                z_prior[stage] = (
                    self.space_to_depth(prior_samples).squeeze(dim=-1).squeeze(dim=-1)
                )

            if use_z:
                z = (
                    self.depth_to_space(z_post[stage])
                    if z_post[stage].shape[-1] != h.shape[-1]
                    else z_post[stage]
                )
            else:
                z = prior_samples

            h = torch.cat([h, z], dim=1)
            h = self.channel_norm[stage](h)
            h = self.blocks[stage + "_1"](h, x_e[stage + "_1"])

            if i_s == self.n_stages:
                h = self.up(h)

        return h, p_params, z_prior

    def __split_groups(self, x):
        # split along channel axis
        sec_size = x.shape[1] // 4
        return torch.split(
            self.space_to_depth(x), [sec_size, sec_size, sec_size, sec_size], dim=1,
        )

    def __merge_groups(self, x):
        # merge groups along channel axis
        return self.depth_to_space(torch.cat(x, dim=1))

    def _latent_sample(self, mean):
        sample_mean = torch.squeeze(torch.squeeze(mean, dim=-1), dim=-1)

        sampled = MultivariateNormal(
            loc=torch.zeros_like(sample_mean, device=self.device),
            covariance_matrix=torch.eye(sample_mean.shape[-1], device=self.device),
        ).sample()

        return (sampled + sample_mean).unsqueeze(dim=-1).unsqueeze(dim=-1)
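
The _latent_sample helper above draws a unit-variance Gaussian sample around each predicted mean (the MultivariateNormal with identity covariance adds standard-normal noise per coordinate). A minimal self-contained equivalent, as a sketch:

import torch

def latent_sample(mean):
    # Drop the trailing 1x1 spatial dims, add standard-normal noise
    # (z ~ N(mean, I)), then restore the spatial dims.
    flat = torch.squeeze(torch.squeeze(mean, dim=-1), dim=-1)
    z = flat + torch.randn_like(flat)
    return z.unsqueeze(dim=-1).unsqueeze(dim=-1)
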
Example 3
class VUnetBottleneckOld(nn.Module):
    def __init__(
        self, n_stages, nf, device, n_rnb=2, n_auto_groups=4, conv_layer=NormConv2d,
    ):
        super().__init__()
        self.device = device
        self.blocks = ModuleDict()
        self.channel_norm = ModuleDict()
        self.conv1x1 = conv_layer(nf, nf, 1)
        self.up = Upsample(in_channels=nf, out_channels=nf, conv_layer=conv_layer)
        self.depth_to_space = DepthToSpace(block_size=2)
        self.space_to_depth = SpaceToDepth(block_size=2)
        self.n_stages = n_stages
        self.n_rnb = n_rnb
        # number of autoregressively modeled groups
        self.n_auto_groups = n_auto_groups
        for i_s in range(self.n_stages, self.n_stages - 2, -1):
            self.channel_norm.update({f"s{i_s}": conv_layer(2 * nf, nf, 1)})
            for ir in range(self.n_rnb):
                self.blocks.update(
                    {
                        f"s{i_s}_{ir+1}": VUnetResnetBlock(
                            nf, use_skip=True, conv_layer=conv_layer
                        )
                    }
                )

        if FLAGS.group_auto:
            self.auto_blocks = ModuleList()
            # resnet blocks for the autoregressively modeled groups
            for i_a in range(4):
                if i_a < 1:
                    self.auto_blocks.append(VUnetResnetBlock(nf, conv_layer=conv_layer))
                    self.param_converter = conv_layer(4 * nf, nf, kernel_size=1)
                else:
                    self.auto_blocks.append(
                        VUnetResnetBlock(nf, use_skip=True, conv_layer=conv_layer)
                    )

    def forward(self, x_e, x_f, mode="train"):
        """
        :param x_e: The output from the encoder E_theta
        :param x_f:  The output from the encoder F_phi
        :param mode: Determines the mode of the bottleneck, must be in ["train","appearance_transfer","sample_appearance"]
        :return:    h: the output of the last layer of the bottleneck which is subsequently used by the decoder
                    posterior_params: The flattened means of the posterior distributions p(z|ŷ,x) of the two bottleneck stages
                    prior_params: The flattened means of the prior distributions p(z|ŷ) of the two bottleneck stages
                    z_prior: The current samples of the two stages of the prior distributions of both two bottleneck stages, flattened
        """
        # posterior_samples = {}
        # prior_samples = {}
        prior_params = {}
        posterior_params = {}
        z_prior = {}
        h = self.conv1x1(x_e[f"s{self.n_stages}_2"])
        for i_s in range(self.n_stages, self.n_stages - 2, -1):
            stage = f"s{i_s}"
            spatial_size = x_e[stage + "_2"].shape[-1]

            h = self.blocks[stage + "_2"](h, x_e[stage + "_2"])

            if spatial_size == 1:
                prior_params[stage] = x_e[stage + "_2"]
                posterior_params[stage] = x_f[stage + "_2"]

                prior_samples = self._latent_sample(prior_params[stage])
                z_prior[stage] = torch.squeeze(
                    torch.squeeze(prior_samples, dim=-1), dim=-1
                )
                posterior_samples = self._latent_sample(posterior_params[stage])
            else:

                post_params = self.space_to_depth(x_f[stage + "_2"])
                posterior_params[stage] = post_params

                if FLAGS.group_auto:
                    if mode == "train" or mode == "appearance_transfer":
                        posterior_samples = self._latent_sample(post_params)

                        sec_size = posterior_samples.shape[1] // 4
                        posterior_sample_groups = torch.split(
                            posterior_samples,
                            [sec_size, sec_size, sec_size, sec_size],
                            dim=1,
                        )
                        posterior_samples = self.depth_to_space(posterior_samples)

                    param_groups = []
                    sample_groups = []

                    param_features = self.auto_blocks[0](h)
                    param_features = self.space_to_depth(param_features)
                        # convert to the matching channel depth
                    param_features = self.param_converter(param_features)

                    for i_a in range(len(self.auto_blocks)):
                        param_groups.append(param_features)
                        # with torch.cuda.device(self.device):
                        prior_samples = self._latent_sample(param_groups[-1])

                        sample_groups.append(prior_samples)

                        if i_a + 1 < len(self.auto_blocks):
                            if mode == "train" or mode == "appearance_transfer":
                                feedback = posterior_sample_groups[i_a]
                            else:
                                feedback = prior_samples

                            param_features = self.auto_blocks[i_a](
                                param_features, feedback
                            )

                    pri_params = torch.cat(param_groups, dim=1)
                    prior_samples = self.__merge_groups(sample_groups)

                else:

                    pri_params = self.space_to_depth(x_e[stage + "_2"])

                    prior_samples = self.depth_to_space(self._latent_sample(pri_params))
                    posterior_samples = self.depth_to_space(
                        self._latent_sample(post_params)
                    )

                prior_params[stage] = pri_params
                z_prior[stage] = (
                    self.space_to_depth(prior_samples).squeeze(dim=-1).squeeze(dim=-1)
                )

            if mode == "train" or mode == "appearance_transfer":
                # training and appearance transfer: sample from posterior
                z = posterior_samples
            elif mode == "sample_appearance":
                # appearance sampling: sample from prior
                z = prior_samples
            else:
                raise ValueError(
                    'The \'mode\' parameter in VUnetBottleneck must be in ["train","appearance_transfer","sample_appearance"]'
                )

            h = torch.cat([h, z], dim=1)
            h = self.channel_norm[stage](h)
            h = self.blocks[stage + "_1"](h, x_e[stage + "_1"])

            if i_s == self.n_stages:
                h = self.up(h)

        return h, prior_params, posterior_params, z_prior

    def __split_groups(self, x):
        # split along channel axis
        sec_size = x.shape[1] // 4
        return torch.split(
            self.space_to_depth(x), [sec_size, sec_size, sec_size, sec_size], dim=1,
        )

    def __merge_groups(self, x):
        # merge groups along channel axis
        return self.depth_to_space(torch.cat(x, dim=1))

    def _latent_sample(self, mean):
        sample_mean = torch.squeeze(torch.squeeze(mean, dim=-1), dim=-1)

        sampled = MultivariateNormal(
            loc=torch.zeros_like(sample_mean, device=self.device),
            covariance_matrix=torch.eye(sample_mean.shape[-1], device=self.device),
        ).sample()

        return (sampled + sample_mean).unsqueeze(dim=-1).unsqueeze(dim=-1)
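
Both bottleneck variants pack a 2x2 spatial neighbourhood into the channel axis and back via SpaceToDepth/DepthToSpace (block_size=2). A self-contained sketch of that shape contract using torch's pixel_unshuffle/pixel_shuffle (whether the repository's classes are implemented exactly this way is an assumption; the shapes are what the forward passes rely on):

import torch
import torch.nn.functional as F

x = torch.randn(1, 16, 8, 8)
packed = F.pixel_unshuffle(x, 2)       # space to depth: (1, 64, 4, 4)
restored = F.pixel_shuffle(packed, 2)  # depth to space: (1, 16, 8, 8)
assert torch.equal(restored, x)
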
Example 4
class VUnetDecoder(nn.Module):
    def __init__(
        self,
        n_stages,
        nf=128,
        nf_out=3,
        n_rnb=2,
        conv_layer=NormConv2d,
        spatial_size=256,
        final_act=True,
        dropout_prob=0.0,
    ):
        super().__init__()
        assert (2 ** (n_stages - 1)) == spatial_size
        self.final_act = final_act
        self.blocks = ModuleDict()
        self.ups = ModuleDict()
        self.n_stages = n_stages
        self.n_rnb = n_rnb
        for i_s in range(self.n_stages - 2, 0, -1):
            # for the final stage, halve the number of filters
            if i_s == 1:
                # upsampling operations
                self.ups.update(
                    {
                        f"s{i_s+1}": Upsample(
                            in_channels=nf, out_channels=nf // 2, conv_layer=conv_layer,
                        )
                    }
                )
                nf = nf // 2
            else:
                # upsampling operations
                self.ups.update(
                    {
                        f"s{i_s+1}": Upsample(
                            in_channels=nf, out_channels=nf, conv_layer=conv_layer,
                        )
                    }
                )

            # resnet blocks
            for ir in range(self.n_rnb, 0, -1):
                stage = f"s{i_s}_{ir}"
                self.blocks.update(
                    {
                        stage: VUnetResnetBlock(
                            nf,
                            use_skip=True,
                            conv_layer=conv_layer,
                            dropout_prob=dropout_prob,
                        )
                    }
                )

        # final 1x1 convolution
        self.final_layer = conv_layer(nf, nf_out, kernel_size=1)

        # optionally apply a final activation
        if self.final_act:
            self.final_act = nn.Tanh()

    def forward(self, x, skips):
        """

        Parameters
        ----------
        x : torch.Tensor
            Latent representation to decode.
        skips : dict
            The skip connections of the VUnet

        Returns
        -------
        out : torch.Tensor
            An image as described by :attr:`x` and :attr:`skips`
        """
        out = x
        for i_s in range(self.n_stages - 2, 0, -1):
            out = self.ups[f"s{i_s+1}"](out)

            for ir in range(self.n_rnb, 0, -1):
                stage = f"s{i_s}_{ir}"
                out = self.blocks[stage](out, skips[stage])

        out = self.final_layer(out)
        if self.final_act:
            out = self.final_act(out)
        return out
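
A hypothetical end-to-end sketch wiring the encoder from Example 1 to this decoder (the bottleneck is bypassed and the deepest usable encoder feature is fed in directly; it assumes the custom blocks are importable and that Upsample doubles the spatial resolution):

import torch

n_stages, size = 5, 16   # the decoder asserts 2 ** (n_stages - 1) == spatial_size
enc = VUnetEncoder(n_stages=n_stages, nf_in=3, nf_start=64, nf_max=128)
dec = VUnetDecoder(n_stages=n_stages, nf=128, nf_out=3, spatial_size=size)

skips = enc(torch.randn(1, 3, size, size))
# The decoder starts at the resolution of stage n_stages - 1 (here 2x2)
# and upsamples back to the full spatial size.
out = dec(skips[f"s{n_stages - 1}_2"], skips)
print(out.shape)   # expected: torch.Size([1, 3, 16, 16])
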
Example 5
class VUnetDecoder(nn.Module):
    def __init__(self, n_stages, nf=128, nf_out=3, n_rnb=2, conv_layer=NormConv2d):
        super().__init__()
        assert (2 ** (n_stages - 1)) == FLAGS.spatial_size
        self.blocks = ModuleDict()
        self.ups = ModuleDict()
        self.n_stages = n_stages
        self.n_rnb = n_rnb
        for i_s in range(self.n_stages - 2, 0, -1):
            # for the final stage, halve the number of filters
            if i_s == 1:
                # upsampling operations
                self.ups.update(
                    {
                        f"s{i_s+1}": Upsample(
                            in_channels=nf, out_channels=nf // 2, conv_layer=conv_layer,
                        )
                    }
                )
                nf = nf // 2
            else:
                # upsampling operations
                self.ups.update(
                    {
                        f"s{i_s+1}": Upsample(
                            in_channels=nf, out_channels=nf, conv_layer=conv_layer,
                        )
                    }
                )

            # resnet blocks
            for ir in range(self.n_rnb, 0, -1):
                stage = f"s{i_s}_{ir}"
                self.blocks.update(
                    {stage: VUnetResnetBlock(nf, use_skip=True, conv_layer=conv_layer)}
                )

        # final 1x1 convolution
        self.final_layer = conv_layer(nf, nf_out, kernel_size=1)

        # optionally apply a final activation
        if FLAGS.final_act:
            self.final_act = nn.Tanh()

    def forward(self, x, skips):
        """

        :param x: Latent representation to decode.
        :param skips: The skip connections of the VUnet.
        :return: The decoded image.
        """
        out = x
        for i_s in range(self.n_stages - 2, 0, -1):
            out = self.ups[f"s{i_s+1}"](out)

            for ir in range(self.n_rnb, 0, -1):
                stage = f"s{i_s}_{ir}"
                out = self.blocks[stage](out, skips[stage])

        out = self.final_layer(out)
        if FLAGS.final_act:
            out = self.final_act(out)
        return out
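
This older decoder reads its configuration from a global FLAGS object instead of constructor arguments (compare Example 4). A hypothetical stand-in for FLAGS that would let the snippet be instantiated, provided it lives in the same module as the class so the global lookup succeeds (the custom building blocks are still assumed to be importable):

from types import SimpleNamespace

# Hypothetical stand-in; the real project presumably supplies FLAGS
# through its own configuration module.
FLAGS = SimpleNamespace(spatial_size=16, final_act=True)
dec_old = VUnetDecoder(n_stages=5, nf=128, nf_out=3)
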