Example #1
    def sample(self, inf_data_in):
        """Generate this layer's output, reusing the cached pre-batchnorm
        result in self.sample_stored_output when one is available."""
        if self.sample_stored_output is not None:
            output_data = self.sample_stored_output
        else:
            if inf_data_in is not None:
                inf_layer_data = self.inf_pre(inf_data_in)
            else:
                inf_layer_data = None

            # Recurse into the sublayer to obtain the top-down generative input.
            if self.sublayer is not None:
                next_inf_data = self.inf_post(inf_layer_data)

                if self.inf_batchnorm is not None:
                    next_inf_data = self.inf_batchnorm(next_inf_data)

                gen_data_in = self.sublayer.sample(next_inf_data)
            else:
                gen_data_in = None

            if gen_data_in is not None:
                gen_layer_data = self.gen_pre(gen_data_in)
            else:
                gen_layer_data = None

            # Defined unconditionally so the bypass check below cannot hit an
            # undefined name when no latent layers are configured.
            p_data = gen_layer_data

            if self.latent_q is not None or self.latent_p is not None:
                q_data = torch_ext.cat((inf_layer_data, gen_layer_data), 1)

                latent_q_dist = self.latent_q(q_data)

                # Top of the hierarchy: no top-down input, so fall back to a
                # standard-normal prior shaped like the posterior.
                if p_data is None:
                    latent_p_dist = tdist.Normal(
                        torch.zeros_like(latent_q_dist.loc),
                        torch.ones_like(latent_q_dist.loc))
                else:
                    latent_p_dist = self.latent_p(p_data)

                # latent_q_dist = torch_ext.combine(latent_q_dist, latent_p_dist)
                # The sampling path draws from the prior; with the combine step
                # commented out, the posterior above only supplies the shape of
                # the standard-normal fallback.
                latent_sample = latent_p_dist.rsample()
                latent_dec = self.latent_dec(latent_sample)
            else:
                latent_dec = None

            if p_data is not None and self.bypass is not None:
                bypass = self.bypass(p_data)
            else:
                bypass = None

            output_data = self.gen_post(bypass, latent_dec)

        # Cache the pre-batchnorm output so repeated calls reuse it.
        self.sample_stored_output = output_data

        if self.gen_batchnorm is not None:
            output_data = self.gen_batchnorm(output_data)

        return output_data
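
The examples call torch_ext.cat with operands that may be None (no inference input, or no sublayer output), so the helper presumably skips missing tensors. torch_ext is not shown in these snippets; the following is a minimal sketch of the behaviour the call sites appear to assume, not the actual helper:

import torch

def cat(tensors, dim=1):
    # Hypothetical None-tolerant concatenation, inferred from the call sites
    # above: drop None entries, return the lone remaining tensor unchanged,
    # or None when every entry is missing.
    present = [t for t in tensors if t is not None]
    if not present:
        return None
    if len(present) == 1:
        return present[0]
    return torch.cat(present, dim)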
Example #2
    def forward(self, x, *additional_input):
        # Concatenate the conditioning tensors onto x, then add a learned skip
        # projection of the concatenated input to the main branch output.
        x = torch_ext.cat((x, *additional_input), 1)
        return self.layers(x) + self.skip(x)
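
A self-contained module this forward could belong to; the class name, layer sizes, and the use of plain torch.cat in place of torch_ext.cat are illustrative assumptions, not the original definition:

import torch
import torch.nn as nn

class ConditionedResidualBlock(nn.Module):
    # Hypothetical container for the forward above: `layers` transforms the
    # concatenated input, `skip` projects the same input so both branches
    # can be summed.
    def __init__(self, in_dim, cond_dim, out_dim):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(in_dim + cond_dim, out_dim),
            nn.ReLU(),
            nn.Linear(out_dim, out_dim),
        )
        self.skip = nn.Linear(in_dim + cond_dim, out_dim)

    def forward(self, x, *additional_input):
        x = torch.cat((x, *additional_input), 1)
        return self.layers(x) + self.skip(x)

# Usage with illustrative shapes:
# block = ConditionedResidualBlock(in_dim=16, cond_dim=4, out_dim=32)
# y = block(torch.randn(8, 16), torch.randn(8, 4))   # y: (8, 32)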
Example #3
    def forward(self, inf_data_in, lengths=None):
        """Run the inference/generative pass for this layer and return the
        output together with the accumulated list of LatentState objects."""
        if inf_data_in is not None:
            if inf_data_in.dim() > 2:
                # Sequence input: pool over the last (time) dimension first.
                inf_layer_data = self.inf_pre(inf_data_in.mean(-1))
            else:
                inf_layer_data = self.inf_pre(inf_data_in)
        else:
            inf_layer_data = None

        if self.sublayer is not None:
            next_inf_data = self.inf_post(inf_layer_data)

            if self.inf_batchnorm is not None:
                next_inf_data = self.inf_batchnorm(next_inf_data)
            gen_data_in, ls = self.sublayer(next_inf_data)
        else:
            ls = []
            gen_data_in = None

        if gen_data_in is not None:
            gen_layer_data = self.gen_pre(gen_data_in)
        else:
            gen_layer_data = None

        # The posterior input combines bottom-up (inference) and top-down
        # (generative) features; p_data alone conditions the prior.
        p_data = gen_layer_data
        q_data = torch_ext.cat((inf_layer_data, gen_layer_data), 1)

        if self.latent_q is not None or self.latent_p is not None:
            latent_q_dist = self.latent_q(q_data)

            # Top of the hierarchy: no top-down input, so fall back to a
            # standard-normal prior shaped like the posterior.
            if p_data is None:
                latent_p_dist = tdist.Normal(
                    torch.zeros_like(latent_q_dist.loc),
                    torch.ones_like(latent_q_dist.loc))
            else:
                latent_p_dist = self.latent_p(p_data)

            # The forward pass samples from the posterior, unlike sample() in
            # Example #1, which draws from the prior.
            latent_q_sample = latent_q_dist.rsample()
            latent_dec = self.latent_dec(latent_q_sample)

            this_ls = [
                LatentState(latent_q_dist, latent_p_dist, latent_q_sample)
            ]
        else:
            latent_dec = None

            this_ls = []

        if p_data is not None and self.bypass is not None:
            bypass = self.bypass(p_data)
        else:
            bypass = None

        output_data = self.gen_post(bypass, latent_dec)

        new_ls = ls + this_ls

        if self.gen_batchnorm is not None:
            output_data = self.gen_batchnorm(output_data)

        if inf_data_in is not None and inf_data_in.dim() > 2:
            # Broadcast the pooled output back across the time dimension.
            output_data = output_data[..., None].repeat(1, 1, inf_data_in.size(2))

        return output_data, new_ls
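
LatentState is not defined in these snippets; judging from how it is constructed and collected into new_ls, it is presumably a plain container for the posterior, the prior, and the drawn sample. A hypothetical sketch with assumed field names:

from dataclasses import dataclass

import torch
import torch.distributions as tdist

@dataclass
class LatentState:
    # Posterior q(z|x), prior p(z), and the reparameterised posterior sample,
    # in the order they are passed in Example #3 above.
    q_dist: tdist.Distribution
    p_dist: tdist.Distribution
    sample: torch.Tensor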
Example #4
    def forward(self, x, *cond):
        # The main branch sees x concatenated with the conditioning tensors;
        # the skip projection is applied to x alone.
        return self.layers(torch_ext.cat((x, *cond), 1)) + self.skip(x)