Example #1
0
    def negative_iwae_bound(self, x, iw):
        """
        Computes the Importance Weighted Autoencoder Bound
        Additionally, we also compute the ELBO KL and reconstruction terms

        Args:
            x: tensor: (batch, dim): Observations
            iw: int: (): Number of importance weighted samples

        Returns:
            niwae: tensor: (): Negative IWAE bound
            kl: tensor: (): ELBO KL divergence to prior
            rec: tensor: (): ELBO Reconstruction term
        """
        # Compute the mixture of Gaussian prior
        prior_m, prior_v = ut.gaussian_parameters(self.z_pre, dim=1)

        # Encode once, then replicate the posterior parameters and the
        # observations so that every observation gets iw latent samples.
        # NOTE(review): assumes ut.duplicate tiles along dim 0 with the copy
        # index varying slowest, i.e. the result can be viewed as
        # (iw, batch, ...) -- TODO confirm against ut.duplicate.
        qm, qv = self.enc.encode(x)
        qm = ut.duplicate(qm, iw)
        qv = ut.duplicate(qv, iw)
        x_rep = ut.duplicate(x, iw)

        z = ut.sample_gaussian(qm, qv)
        logits = self.dec.decode(z)

        # Per-sample Monte Carlo estimates of the two ELBO terms.
        kl_each = ut.log_normal(z, qm, qv) - ut.log_normal_mixture(
            z, prior_m, prior_v)
        rec_each = -ut.log_bernoulli_with_logits(x_rep, logits)

        # log w = log p(x|z) + log p(z) - log q(z|x); average the weights over
        # the iw samples of each observation in log space.
        log_weights = -(kl_each + rec_each)
        niwae = -ut.log_mean_exp(log_weights.reshape(iw, -1), dim=0).mean()

        kl = kl_each.mean()
        rec = rec_each.mean()
        return niwae, kl, rec
Example #2
0
    def encode(self, x, y=None, c=None):
        """
        Encode a sequence with the recurrent encoder.

        Args:
            x: tensor: (batch, x_dim): Sequence values (one scalar per step)
            y: tensor: (batch, y_dim): Optional conditioning, appended to
                every step
            c: tensor: (batch, c_dim): Optional conditioning, appended to
                every step

        Returns:
            m, v: the pair produced by ut.gaussian_parameters from the
                regressor output (presumably mean and variance)
        """
        batch_size = x.size()[0]
        num_states = self.NUM_LAYERS * (1 + self.BIDIRECTIONAL)

        # The initial hidden/cell states are predicted from the flat inputs.
        xy = x if y is None else torch.cat((x, y), dim=-1)
        xyc = xy if c is None else torch.cat((xy, c), dim=-1)
        states = self.elu(self.initializer(xyc)).view(
            batch_size, num_states, 2 * self.HIDDEN_DIM)
        (h_0, c_0) = states.split(self.HIDDEN_DIM, dim=-1)

        # Per-step features: the value plus a position in [-1, 1].
        # The position tensor is created on x's device; building it on the
        # default device makes torch.cat fail when x is not on that device.
        x = x.unsqueeze(2)
        position = torch.arange(
            self.x_dim, dtype=torch.float,
            device=x.device) / (0.5 * (self.x_dim - 1)) - 1.0
        position = position.view(1, self.x_dim).expand(
            batch_size, self.x_dim).unsqueeze(2)
        x = torch.cat((x, position), dim=2)

        if y is not None:
            y = y.unsqueeze(1).expand(batch_size, self.x_dim, self.y_dim)
            x = torch.cat((x, y), dim=2)
        if c is not None:
            c = c.unsqueeze(1).expand(batch_size, self.x_dim, self.c_dim)
            x = torch.cat((x, c), dim=2)

        embedded = self.elu(self.embedder(x))
        # The RNN is fed (seq, batch, feature) and (layers, batch, hidden).
        _, (h_n, _) = self.rnn(
            embedded.transpose(1, 0),
            (h_0.transpose(1, 0), c_0.transpose(1, 0)))

        # h_n of shape (num_layers * num_directions, batch, hidden_size)
        # reshape to (batch, last_hidden_outputs)
        h = h_n.transpose(1, 0).reshape(-1, num_states * self.HIDDEN_DIM)
        out = self.regressor(h)

        m, v = ut.gaussian_parameters(out, dim=-1)
        return m, v
Example #3
0
 def encode(self, x, y=None):
     """
     Flatten the (optionally label-augmented) input and run the encoder net.

     x (and y, when given, concatenated along dim 1) is viewed as rows of
     self.channel * 96 * 96 values before being passed to self.net.

     Returns:
         m, v: the pair produced by ut.gaussian_parameters (dim=1)
     """
     joined = torch.cat((x, y), dim=1) if y is not None else x
     flat = joined.view(-1, self.channel * 96 * 96)
     hidden = self.net(flat)
     mean, var = ut.gaussian_parameters(hidden, dim=1)
     return mean, var
Example #4
0
    def decode(self, z, y=None, c=None):
        """
        Decode a latent code into per-step output parameters with the RNN.

        Args:
            z: tensor: (batch, z_dim): Latent codes (at most 2-D)
            y: tensor: Optional conditioning, concatenated to z on the last dim
            c: tensor: Optional conditioning, concatenated to z on the last dim

        Returns:
            m, v: the pair produced by ut.gaussian_parameters from the
                flattened regressor output (presumably mean and variance)
        """
        # Note: not designed for IW!!
        assert len(z.size()) <= 2

        batch_size = z.size()[0]

        zy = z if y is None else torch.cat((z, y), dim=-1)
        zyc = zy if c is None else torch.cat((zy, c), dim=-1)

        # The initial hidden/cell states are predicted from the latent input.
        states = self.elu(self.initializer(zyc)).view(
            batch_size, self.NUM_LAYERS * (1 + self.BIDIRECTIONAL),
            2 * self.HIDDEN_DIM)
        (h_0, c_0) = states.split(self.HIDDEN_DIM, dim=-1)

        # Repeat the latent input at each of the x_dim output steps.
        step_input = zyc.unsqueeze(1).expand(*zyc.size()[0:-1], self.x_dim,
                                             zyc.size()[-1])

        # Step position in [-1, 1]. It is created on z's device; building it
        # on the default device makes torch.cat fail when z lives elsewhere.
        position = torch.arange(
            self.x_dim, dtype=torch.float,
            device=z.device) / (0.5 * (self.x_dim - 1)) - 1.0
        position = position.view(1, self.x_dim).expand(
            batch_size, self.x_dim).unsqueeze(2)
        step_input = torch.cat((step_input, position), dim=2)

        output, (_, _) = self.rnn(step_input.transpose(1, 0),
                                  (h_0.transpose(1, 0), c_0.transpose(1, 0)))
        # Back to batch-first, then move the regressor channels in front of
        # the step axis before flattening each sequence to one row.
        out = self.regressor(output.transpose(1, 0)).transpose(-2, -1).reshape(
            batch_size, -1)
        m, v = ut.gaussian_parameters(out, dim=-1)
        return m, v
Example #5
0
def mse_rnn(model, inputs, targets, n_samples):
    """Summed squared error between the targets and the predicted mean.

    The model output is detached. With a constant-variance model the raw
    prediction is the mean; otherwise the mean is split off with
    ut.gaussian_parameters. The squared error is summed over the last axis
    and then over axis 0. `n_samples` is accepted but not used.
    """
    prediction = model.forward(inputs).detach()
    if model.constant_var:
        mean, var = prediction, model.pred_var
    else:
        mean, var = ut.gaussian_parameters(prediction, dim=-1)
    squared_error = (targets - mean) ** 2
    return squared_error.sum(-1).sum(0)
Example #6
0
def wse_rnn(model, inputs, targets):
    """Summed squared error between the targets and one sampled trajectory.

    The model output is detached, turned into Gaussian parameters (using
    model.pred_var when the model has constant variance), and a single draw
    from ut.sample_gaussian is compared with the targets. The squared error
    is summed over the last axis and then over axis 0.
    """
    prediction = model.forward(inputs).detach()
    if model.constant_var:
        mean, var = prediction, model.pred_var
    else:
        mean, var = ut.gaussian_parameters(prediction, dim=-1)
    sampled = ut.sample_gaussian(mean, var)
    squared_error = (targets - sampled) ** 2
    return squared_error.sum(-1).sum(0)
Example #7
0
 def conditional_encode(self, x, l):
     """
     Encode an input together with a 4-value label vector.

     x is flattened to rows of self.channel * 96 * 96 values, l to rows of 4
     values; the label is concatenated after the first layer.

     Returns:
         m, v: the pair produced by ut.gaussian_parameters (dim=1)
     """
     flat_x = x.view(-1, self.channel * 96 * 96)
     hidden = F.elu(self.fc1(flat_x))
     label = l.view(-1, 4)
     hidden = F.elu(self.fc2(torch.cat([hidden, label], dim=1)))
     hidden = F.elu(self.fc3(hidden))
     mean, var = ut.gaussian_parameters(self.fc4(hidden), dim=1)
     return mean, var
def get_mse(model, full_true_trajs, n_samples=100):
    """Mean squared error of the model's mean prediction.

    The first `model.n_input_steps` steps of `full_true_trajs` are fed to the
    model; the first two features of the remaining steps are the targets.
    The squared error is summed over all elements and divided by the number
    of sequences (axis 1 of `full_true_trajs`).

    Args:
        model: object exposing `BBB`, `constant_var`, `n_input_steps` and
            `forward(inputs)`.
        full_true_trajs: tensor indexed as (step, sequence, feature).
        n_samples: number of stochastic forward passes averaged when
            `model.BBB` is true; ignored otherwise.

    Returns:
        Scalar tensor holding the error.

    Raises:
        ValueError: if `model.BBB` is true and `n_samples` is below 1.
    """
    n_seqs = full_true_trajs.shape[1]
    inputs = full_true_trajs[:model.n_input_steps, :, :].detach()
    targets = full_true_trajs[model.n_input_steps:, :, :2].detach()

    if model.BBB:
        if n_samples < 1:
            raise ValueError("n_samples must be at least 1")
        # not using sharpening; every forward pass is one output sample
        preds = [model.forward(inputs).detach() for _ in range(n_samples)]
        mean_pred = torch.stack(preds, dim=-1).mean(dim=-1)
        if not model.constant_var:
            # The raw output carries one extra trailing channel besides the
            # mean; drop it.
            mean_pred = mean_pred[:, :, :-1]
    else:
        pred = model.forward(inputs).detach()
        if model.constant_var:
            mean_pred = pred
        else:
            # The mean returned here is compared with the targets as-is.
            # The previous code sliced it a second time with [:, :, :-1],
            # which dropped a target channel and silently broadcast.
            mean_pred, _ = ut.gaussian_parameters(pred, dim=-1)
        # The previous code drew a sample per iteration and immediately
        # overwrote it with the mean, so the deterministic mean is used
        # directly and no random numbers are consumed.

    return ((mean_pred - targets) ** 2).sum() / n_seqs
Example #9
0
    def sample_z(self, batch):
        """
        Draw `batch` latent samples from the mixture prior.

        A mixture component is picked per sample from a categorical
        distribution over self.pi, then a Gaussian sample is drawn with that
        component's parameters.
        """
        means, variances = ut.gaussian_parameters(self.z_pre.squeeze(0), dim=0)

        # One component index per requested sample.
        component_dist = torch.distributions.categorical.Categorical(self.pi)
        component = component_dist.sample((batch, ))

        return ut.sample_gaussian(means[component], variances[component])
Example #10
0
    def kl_elem(self, z, qm, qv):
        """
        Per-sample KL estimate at z: log q(z) under the Gaussian (qm, qv)
        minus log p(z) under the mixture-of-Gaussians prior.
        """
        # Parameters of the mixture of Gaussian prior
        mixture_m, mixture_v = ut.gaussian_parameters(self.z_pre, dim=1)

        log_q = ut.log_normal(z, qm, qv)
        log_p = ut.log_normal_mixture(z, mixture_m, mixture_v)
        return log_q - log_p
Example #11
0
    def negative_iwae_bound(self, x, iw):
        """
        Computes the Importance Weighted Autoencoder Bound
        Additionally, we also compute the ELBO KL and reconstruction terms

        Args:
            x: tensor: (batch, dim): Observations
            iw: int: (): Number of importance weighted samples

        Returns:
            niwae: tensor: (): Negative IWAE bound
            kl: tensor: (): ELBO KL divergence to prior
            rec: tensor: (): ELBO Reconstruction term
        """
        # Compute the mixture of Gaussian prior
        prior_m, prior_v = ut.gaussian_parameters(self.z_pre, dim=1)

        m, v = self.enc.encode(x)

        # iw reparameterized draws per observation; the sample axis is
        # prepended, so z_samples[i] has the shape of m.
        z_samples = Normal(loc=m, scale=torch.sqrt(v)).rsample(
            sample_shape=torch.Size([iw]))

        log_px_given_z = []
        log_q_minus_log_p = []
        for z in z_samples:
            logits = self.dec.decode(z)
            log_px_given_z.append(ut.log_bernoulli_with_logits(x, logits))
            log_q_minus_log_p.append(
                ut.log_normal(z, m, v) -
                ut.log_normal_mixture(z, prior_m, prior_v))

        # Stack the per-sample values along a new axis 1
        # (presumably giving (batch, iw)).
        log_px_given_z = torch.stack(log_px_given_z, dim=1)
        log_q_minus_log_p = torch.stack(log_q_minus_log_p, dim=1)

        # Average the importance weights over the sample axis in log space,
        # then average over the batch.
        niwae = -ut.log_mean_exp(log_px_given_z - log_q_minus_log_p,
                                 dim=1).mean(dim=0)

        # Reduce over every axis: the documented contract is scalar outputs,
        # whereas a mean over dim 0 alone leaves one value per sample.
        rec = -torch.mean(log_px_given_z)
        kl = torch.mean(log_q_minus_log_p)
        return niwae, kl, rec
Example #12
0
    def negative_iwae_bound(self, x, iw):
        """
        Computes the Importance Weighted Autoencoder Bound
        Additionally, we also compute the ELBO KL and reconstruction terms

        Args:
            x: tensor: (batch, dim): Observations
            iw: int: (): Number of importance weighted samples

        Returns:
            niwae: tensor: (): Negative IWAE bound
            kl: tensor: (): ELBO KL divergence to prior
            rec: tensor: (): ELBO Reconstruction term
        """
        # Compute the mixture of Gaussian prior
        pm, pv = ut.gaussian_parameters(self.z_pre, dim=1)

        qm, qv = self.enc.encode(x)

        log_weights = []
        recs = []
        kls = []
        for _ in range(iw):
            z_sample = ut.sample_gaussian(qm, qv).view(-1, qm.shape[1])
            logits = self.dec.decode(z_sample)

            # Each log-density is evaluated once and reused for the
            # importance weight and for the ELBO terms.
            log_px_given_z = ut.log_bernoulli_with_logits(x, logits)
            log_pz = ut.log_normal_mixture(z_sample, pm, pv)
            log_qz_given_x = ut.log_normal(z_sample, qm, qv)

            log_weights.append(log_px_given_z + log_pz - log_qz_given_x)
            recs.append(-log_px_given_z)
            kls.append(log_qz_given_x - log_pz)

        # Average the importance weights over the sample axis in log space.
        niwae = ut.log_mean_exp(torch.stack(log_weights, -1), -1)
        kl = torch.stack(kls, -1)
        rec = torch.stack(recs, -1)

        return -niwae.mean(), kl.mean(), rec.mean()
Example #13
0
    def negative_elbo_bound(self, x):
        """
        Computes the Evidence Lower Bound, KL and, Reconstruction costs

        Args:
            x: tensor: (batch, dim): Observations

        Returns:
            nelbo: tensor: (): Negative evidence lower bound
            kl: tensor: (): ELBO KL divergence to prior
            rec: tensor: (): ELBO Reconstruction term
        """
        # Posterior parameters from the encoder
        q_mean, q_var = self.enc.encode(x)

        # Mixture of Gaussians prior parameters
        prior_mean, prior_var = ut.gaussian_parameters(self.z_pre, dim=1)

        # One latent draw per observation, decoded to Bernoulli logits
        latent = ut.sample_gaussian(q_mean, q_var)
        decoded = self.dec.decode(latent)

        # Reconstruction: negated mean over the last axis, then overall mean
        log_likelihood = ut.log_bernoulli_with_logits(x, decoded)
        rec = torch.mean(-torch.mean(log_likelihood, -1))

        # Single-sample KL estimate: log q(z|x) - log p(z)
        log_ratio = ut.log_normal(latent, q_mean, q_var)
        log_ratio -= ut.log_normal_mixture(latent, prior_mean, prior_var)
        kl = torch.mean(log_ratio)

        nelbo = kl + rec
        return nelbo, kl, rec
Example #14
0
    def negative_elbo_bound(self, x):
        """
        Computes the Evidence Lower Bound, KL and, Reconstruction costs

        Args:
            x: tensor: (batch, dim): Observations

        Returns:
            nelbo: tensor: (): Negative evidence lower bound
            kl: tensor: (): ELBO KL divergence to prior
            rec: tensor: (): ELBO Reconstruction term
        """
        # Mixture of Gaussians prior parameters
        prior_m, prior_v = ut.gaussian_parameters(self.z_pre, dim=1)

        # Sample z ~ q(z|x) and decode it to Bernoulli logits
        post_m, post_v = self.enc.encode(x)
        z = ut.sample_gaussian(post_m, post_v)
        logits = self.dec.decode(z)

        # Per-observation terms
        rec_each = -ut.log_bernoulli_with_logits(x, logits)
        kl_each = ut.log_normal(z, post_m, post_v) - ut.log_normal_mixture(
            z, prior_m, prior_v)

        # Reduce to scalars
        nelbo = (kl_each + rec_each).mean()
        return nelbo, kl_each.mean(), rec_each.mean()
 def get_nll(self, outputs, targets):
     """
     Negative log-likelihood of a minibatch.

     With likelihood_cost_form == 'mse' the result of self.mse_fn is
     returned. With 'gaussian' the mean of the negated ut.log_normal values
     is returned, using either the variance split from `outputs` or the
     constant self.pred_var.

     :return: negative log-likelihood of a minibatch
     :raises ValueError: if self.likelihood_cost_form is neither 'mse' nor
         'gaussian' (previously this silently returned None)
     """
     if self.likelihood_cost_form == 'mse':
         # This method is not validated
         return self.mse_fn(outputs, targets)
     if self.likelihood_cost_form == 'gaussian':
         if self.constant_var:
             var = self.pred_var * torch.ones_like(outputs)
             return -torch.mean(ut.log_normal(targets, outputs, var))
         mean, var = ut.gaussian_parameters(outputs, dim=-1)
         return -torch.mean(ut.log_normal(targets, mean, var))
     raise ValueError("unsupported likelihood_cost_form: {!r}".format(
         self.likelihood_cost_form))
Example #16
0
    def negative_iwae_bound(self, x, iw):
        """
        Computes the Importance Weighted Autoencoder Bound
        Additionally, we also compute the ELBO KL and reconstruction terms

        Args:
            x: tensor: (batch, dim): Observations
            iw: int: (): Number of importance weighted samples

        Returns:
            niwae: tensor: (): Negative IWAE bound
            kl: tensor: (): ELBO KL divergence to prior
            rec: tensor: (): ELBO Reconstruction term
        """
        # Compute the mixture of Gaussian prior
        prior = ut.gaussian_parameters(self.z_pre, dim=1)

        # Encode once, then replicate the posterior parameters iw times.
        q_m, q_v = self.enc.encode(x)
        q_m_, q_v_ = ut.duplicate(q_m, rep=iw), ut.duplicate(q_v, rep=iw)

        z_given_x = ut.sample_gaussian(q_m_, q_v_)
        decoded_bernoulli_logits = self.dec.decode(z_given_x)

        # Replicate x the same way so it lines up with the samples.
        x_dup = ut.duplicate(x, rep=iw)
        log_px_given_z = ut.log_bernoulli_with_logits(
            x_dup, decoded_bernoulli_logits)

        log_p_theta = ut.log_normal_mixture(z_given_x, prior[0], prior[1])
        log_q_phi = ut.log_normal(z_given_x, q_m_, q_v_)
        kl_each = log_q_phi - log_p_theta

        # NOTE(review): the (iw, -1) view assumes ut.duplicate stacks the
        # copies with the copy index varying slowest -- TODO confirm.
        log_weights = (log_px_given_z - kl_each).reshape(iw, -1)
        niwae = -torch.mean(ut.log_mean_exp(log_weights, dim=0))

        # The documented outputs are scalars and rec is a cost, so reduce
        # both terms and negate the log-likelihood.
        kl = torch.mean(kl_each)
        rec = -torch.mean(log_px_given_z)
        return niwae, kl, rec
Example #17
0
    def negative_elbo_bound(self, x):
        """
        Computes the Evidence Lower Bound, KL and, Reconstruction costs

        Args:
            x: tensor: (batch, dim): Observations

        Returns:
            nelbo: tensor: (): Negative evidence lower bound
            kl: tensor: (): ELBO KL divergence to prior
            rec: tensor: (): ELBO Reconstruction term
        """
        # Mixture of Gaussians prior parameters
        mix_m, mix_v = ut.gaussian_parameters(self.z_pre, dim=1)

        n_obs = x.shape[0]

        # Draw one z per observation from the approximate posterior
        post_m, post_v = self.enc.encode(x)
        latent = ut.sample_gaussian(post_m, post_v)

        # KL estimate: log q(z|x) - log p(z), with the prior parameters
        # expanded along the leading axis to the batch size
        log_q = ut.log_normal(latent, post_m, post_v)
        log_p = ut.log_normal_mixture(
            latent,
            mix_m.expand(n_obs, *mix_m.shape[1:]),
            mix_v.expand(n_obs, *mix_v.shape[1:]))
        kl = torch.mean(log_q - log_p)

        # Reconstruction cost: negated mean Bernoulli log-likelihood
        logits = self.dec.decode(latent)
        rec = -1.0 * torch.mean(ut.log_bernoulli_with_logits(x, logits))

        return kl + rec, kl, rec
Example #18
0
    def negative_elbo_bound(self, x):
        """
        Computes the Evidence Lower Bound, KL and, Reconstruction costs

        Args:
            x: tensor: (batch, dim): Observations

        Returns:
            nelbo: tensor: (): Negative evidence lower bound
            kl: tensor: (): ELBO KL divergence to prior
            rec: tensor: (): ELBO Reconstruction term
        """
        # Mixture of Gaussians prior parameters
        prior_mean, prior_var = ut.gaussian_parameters(self.z_pre, dim=1)

        # Sample from the approximate posterior and decode
        post_mean, post_var = self.enc.encode(x)
        latent = ut.sample_gaussian(post_mean, post_var)
        logits = self.dec.decode(latent)

        # Per-observation reconstruction cost and single-sample KL estimate
        rec_each = -ut.log_bernoulli_with_logits(x, logits)
        log_q = ut.log_normal(latent, post_mean, post_var)
        log_p = ut.log_normal_mixture(latent, prior_mean, prior_var)
        kl_each = log_q - log_p

        # Average over the batch
        nelbo = (kl_each + rec_each).mean()
        return nelbo, kl_each.mean(), rec_each.mean()
    def negative_iwae_bound(self, x, iw):
        """
        Computes the Importance Weighted Autoencoder Bound
        Additionally, we also compute the ELBO KL and reconstruction terms

        Args:
            x: tensor: (batch, dim): Observations
            iw: int: (): Number of importance weighted samples

        Returns:
            niwae: tensor: (): Negative IWAE bound
            kl: tensor: (): ELBO KL divergence to prior
            rec: tensor: (): ELBO Reconstruction term
        """
        # Compute the mixture of Gaussian prior
        prior = ut.gaussian_parameters(self.z_pre, dim=1)

        N_batches, _ = x.size()

        # Replicate every observation iw times before encoding.
        x = ut.duplicate(x, iw)
        q_mu, q_var = self.enc.encode(x)
        z_samp = ut.sample_gaussian(q_mu, q_var)
        logits = self.dec.decode(z_samp)

        log_px_given_z = ut.log_bernoulli_with_logits(x, logits)
        # log p(z) - log q(z|x), i.e. the negated per-sample KL estimate
        neg_kl_each = -ut.log_normal(z_samp, q_mu, q_var) + ut.log_normal_mixture(z_samp, *prior)

        log_weights = log_px_given_z + neg_kl_each

        # NOTE(review): assumes ut.duplicate stacks the copies with the copy
        # index varying slowest, so rows of the (iw, batch) view are samples
        # and columns are observations -- TODO confirm. A (batch, iw) view of
        # that layout would average weights across different observations.
        niwae = torch.mean(-ut.log_mean_exp(log_weights.reshape(iw, N_batches), 0))

        # Report the ELBO decomposition instead of placeholder zeros.
        kl = -torch.mean(neg_kl_each)
        rec = -torch.mean(log_px_given_z)
        return niwae, kl, rec
Example #20
0
def get_rwse(model, full_true_trajs, n_samples=100):
    """ root-weighted square error (RWSE) captures
        the deviation of a model's probability
        mass from real-world trajectories

    The first `model.n_input_steps` steps of `full_true_trajs` are fed to the
    model; the first two features of the remaining steps are the targets.
    For each of `n_samples` sampled trajectories the squared error is summed
    over all elements and divided by the number of sequences; the square root
    of the mean of these values is returned.

    Raises:
        ValueError: if `n_samples` is below 1.
    """
    if n_samples < 1:
        raise ValueError("n_samples must be at least 1")

    n_seqs = full_true_trajs.shape[1]
    inputs = full_true_trajs[:model.n_input_steps, :, :].detach()
    targets = full_true_trajs[model.n_input_steps:, :, :2].detach()

    sq_errs = []
    if model.BBB:
        for _ in range(n_samples):
            # not using sharpening; each forward pass is one sample
            pred = model.forward(inputs).detach()
            if not model.constant_var:
                # drop the extra trailing channel of the raw output
                pred = pred[:, :, :-1]
            sq_errs.append(((targets - pred)**2).sum() / n_seqs)
    else:
        pred = model.forward(inputs).detach()
        if model.constant_var:
            mean, var = pred, model.pred_var
        else:
            mean, var = ut.gaussian_parameters(pred, dim=-1)

        for _ in range(n_samples):
            sample_trajs = ut.sample_gaussian(mean, var)
            sq_errs.append(((targets - sample_trajs)**2).sum() / n_seqs)

    # A single stack replaces the repeated torch.cat, which copied the
    # growing tensor on every iteration.
    return torch.stack(sq_errs).mean().sqrt()
    def negative_elbo_bound(self, x):
        """
        Computes the Evidence Lower Bound, KL and, Reconstruction costs

        Args:
            x: tensor: (batch, dim): Observations

        Returns:
            nelbo: tensor: (): Negative evidence lower bound
            kl: tensor: (): ELBO KL divergence to prior
            rec: tensor: (): ELBO Reconstruction term
        """
        # Mixture of Gaussians prior parameters
        prior_params = ut.gaussian_parameters(self.z_pre, dim=1)

        # x must be two-dimensional; the sizes themselves are not needed.
        _n_obs, _n_dim = x.size()

        # Sample z ~ q(z|x) and decode it
        post_mean, post_var = self.enc.encode(x)
        latent = ut.sample_gaussian(post_mean, post_var)
        decoded = self.dec.decode(latent)

        # Reconstruction cost: negated mean Bernoulli log-likelihood
        log_likelihood = ut.log_bernoulli_with_logits(x, decoded)
        rec = -torch.mean(log_likelihood)

        # Single-sample KL estimate averaged over the batch
        log_q = ut.log_normal(latent, post_mean, post_var)
        log_p = ut.log_normal_mixture(latent, *prior_params)
        kl = torch.mean(log_q - log_p)

        return kl + rec, kl, rec
Example #22
0
    def negative_iwae_bound(self, x, iw):
        """
        Computes the Importance Weighted Autoencoder Bound
        Additionally, we also compute the ELBO KL and reconstruction terms

        Args:
            x: tensor: (batch, dim): Observations
            iw: int: (): Number of importance weighted samples

        Returns:
            niwae: tensor: (): Negative IWAE bound
            kl: tensor: (): ELBO KL divergence to prior
            rec: tensor: (): ELBO Reconstruction term
        """
        # Mixture of Gaussians prior parameters
        prior_params = ut.gaussian_parameters(self.z_pre, dim=1)

        # Encode once, then replicate parameters and observations iw times
        post_m, post_v = self.enc.encode(x)
        post_m = ut.duplicate(post_m, iw)
        post_v = ut.duplicate(post_v, iw)
        x_rep = ut.duplicate(x, iw)

        latent = ut.sample_gaussian(post_m, post_v)
        decoded = self.dec.decode(latent)

        # Per-sample ELBO terms
        log_q = ut.log_normal(latent, post_m, post_v)
        log_p = ut.log_normal_mixture(latent, *prior_params)
        kl_each = log_q - log_p
        rec_each = -ut.log_bernoulli_with_logits(x_rep, decoded)

        # Log importance weights, viewed as (iw, -1) and averaged over the
        # sample axis in log space
        log_weights = -(kl_each + rec_each)
        niwae_each = -ut.log_mean_exp(log_weights.reshape(iw, -1), dim=0)

        return niwae_each.mean(), kl_each.mean(), rec_each.mean()
Example #23
0
 def decode(self, z, y=None, c=None):
     """
     Run the decoder net on z, with y and c (when given) concatenated on
     the last axis, and split the output with ut.gaussian_parameters.
     """
     pieces = [t for t in (z, y, c) if t is not None] if z is not None else None
     if pieces is None:
         # Keep the original failure mode for a missing z.
         zy = z if y is None else torch.cat((z, y), dim=-1)
         decoder_in = zy if c is None else torch.cat((zy, c), dim=-1)
     else:
         decoder_in = z
         if y is not None:
             decoder_in = torch.cat((decoder_in, y), dim=-1)
         if c is not None:
             decoder_in = torch.cat((decoder_in, c), dim=-1)
     mean, var = ut.gaussian_parameters(self.net(decoder_in), dim=-1)
     return mean, var
Example #24
0
 def encode(self, x, y=None, c=None):
     """
     Run the encoder net on x, with y and c (when given) concatenated on
     the last axis, and split the output with ut.gaussian_parameters.
     """
     encoder_in = x
     if y is not None:
         encoder_in = torch.cat((encoder_in, y), dim=-1)
     if c is not None:
         encoder_in = torch.cat((encoder_in, c), dim=-1)
     hidden = self.net(encoder_in)
     mean, var = ut.gaussian_parameters(hidden, dim=-1)
     return mean, var
Example #25
0
 def encode(self, x, y=None):
     """
     Run the encoder net on x (joined with y along dim 1 when y is given),
     cast to float, and split the output with ut.gaussian_parameters.
     """
     if y is None:
         encoder_in = x
     else:
         encoder_in = torch.cat((x, y), dim=1)
     hidden = self.net(encoder_in.float())
     mean, var = ut.gaussian_parameters(hidden, dim=1)
     return mean, var
Example #26
0
 def encode_simple(self, x):
     """
     Apply self.conv6 to x and split the result along dim 1 with
     ut.gaussian_parameters.
     """
     features = self.conv6(x)
     mean, var = ut.gaussian_parameters(features, dim=1)
     return mean, var
Example #27
0
 def sample_z(self, batch):
     """
     Draw `batch` latent samples from the mixture prior: pick a component
     per sample from a categorical over self.pi, then sample a Gaussian
     with that component's parameters.
     """
     means, variances = ut.gaussian_parameters(self.z_pre.squeeze(0), dim=0)
     component_dist = torch.distributions.categorical.Categorical(self.pi)
     component = component_dist.sample((batch, ))
     return ut.sample_gaussian(means[component], variances[component])
Example #28
0
def train(model, train_data, batch_size, n_batches,
            lr=1e-3,
            clip_grad=None,
            iter_max=np.inf,
            iter_save=np.inf,
            iter_plot=np.inf,
            reinitialize=False,
            kernel=None):
    """
    Optimize ``model`` with Adam, cycling over ``train_data`` until
    ``iter_max`` gradient steps have been taken.

    Each step slices the batch into inputs (first ``model.n_input_steps``
    entries along axis 0) and targets (the remaining entries, first two
    features), builds the loss from ``model.get_loss`` and takes one
    optimizer step. Loss and MSE histories are kept for plotting.

    Args:
        model: module exposing ``forward``, ``get_loss``, ``n_input_steps``,
            ``n_pred_steps``, ``sharpen``, ``likelihood_cost_form``,
            ``constant_var``, ``rnn_cell_type``, ``input_feat_dim`` and ``BBB``
        train_data: iterable of 3-D batches; re-iterated on every pass
        batch_size: currently unused by this function
        n_batches: divisor applied to the KL terms (minibatch re-weighting)
        lr: learning rate handed to Adam
        clip_grad: max gradient norm; ``None`` disables clipping
        iter_max: number of gradient steps after which the function returns;
            with the default (infinity) the loop never terminates on its own
        iter_save: save the model every ``iter_save`` steps
        iter_plot: plot diagnostics every ``iter_plot`` steps
        reinitialize: if true, apply ``ut.reset_weights`` before training
        kernel: forwarded to ``ut.test_plot``

    Returns:
        None
    """
    # Optimization
    if reinitialize:
        model.apply(ut.reset_weights)

    optimizer = optim.Adam(model.parameters(), lr=lr)

    mse = nn.MSELoss()

    i = 0  # i is num of gradient steps taken by end of loop iteration
    loss_list = []
    mse_list = []
    with tqdm.tqdm(total=iter_max) as pbar:
        while True:
            for batch in train_data:
                i += 1
                optimizer.zero_grad()

                inputs = batch[:model.n_input_steps, :, :]
                targets = batch[model.n_input_steps:, :, :2]

                # Since the data is not continued from batch to batch,
                # reinit hidden every batch. (using zeros)
                outputs = model.forward(inputs, targets=targets)
                batch_mean_nll, KL, KL_sharp = model.get_loss(outputs, targets)

                # Re-weighting for minibatches
                NLL_term = batch_mean_nll * model.n_pred_steps

                # Here B = n_batchs, C = 1 (since each sequence is complete)
                KL_term = KL / n_batches

                loss = NLL_term + KL_term

                if model.sharpen:
                    KL_sharp /= n_batches
                    loss += KL_sharp

                loss_list.append(loss.cpu().detach())

                # Print progress
                # NOTE(review): mse_val is only assigned for the 'gaussian'
                # and 'mse' cost forms; any other value leaves it unbound.
                if model.likelihood_cost_form == 'gaussian':
                    if model.constant_var:
                        mse_val = mse(outputs, targets) * model.n_pred_steps
                    else:
                        if model.rnn_cell_type == 'FF':
                            mean, var = ut.gaussian_parameters_ff(outputs, dim=0)

                        else:
                            mean, var = ut.gaussian_parameters(outputs, dim=-1)
                        mse_val = mse(mean, targets) * model.n_pred_steps

                elif model.likelihood_cost_form == 'mse':
                    mse_val = batch_mean_nll * model.n_pred_steps

                mse_list.append(mse_val.cpu().detach())

                if i % iter_plot == 0:
                    with torch.no_grad():
                        model.eval()
                        if model.input_feat_dim <= 2:
                            ut.test_plot(model, i, kernel)

                        elif model.input_feat_dim == 4:
                            rand_idx = random.sample(range(batch.shape[1]), 4)
                            full_true_traj = batch[:, rand_idx, :]
                            if not model.BBB:
                                if model.constant_var:
                                    pred_traj = outputs[:, rand_idx, :]
                                    std_pred = None
                                else:
                                    pred_traj = mean[:, rand_idx, :]
                                    std_pred = var.sqrt()

                                ut.plot_highd_traj(model, i, full_true_traj,
                                    pred_traj, std_pred=std_pred)
                            else:
                                # resample a few forward passes
                                ut.plot_highd_traj_BBB(model, i, full_true_traj,
                                                        n_resample_weights=10)

                        ut.plot_history(model, loss_list, i, obj='loss')
                        ut.plot_history(model, mse_list, i, obj='mse')
                        model.train()

                loss.backward()

                # Clip only after backward() has populated the gradients and
                # before step() consumes them; clipping right after
                # zero_grad() would act on cleared gradients and do nothing.
                if clip_grad is not None:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)

                optimizer.step()

                pbar.set_postfix(loss='{:.2e}'.format(loss),
                                 mse='{:.2e}'.format(mse_val))
                pbar.update(1)

                # Save model
                if i % iter_save == 0:
                    ut.save_model_by_name(model, i, only_latest=True)

                if i == iter_max:
                    return
Example #29
0
    def negative_iwae_bound(self, x, iw):
        """
        Evaluate the negative IWAE bound together with the KL and
        reconstruction terms of the ELBO.

        Args:
            x: tensor: (batch, dim): Observations
            iw: int: (): Number of importance weighted samples

        Returns:
            niwae: tensor: (): Negative IWAE bound
            kl: tensor: (): ELBO KL divergence to prior
            rec: tensor: (): ELBO Reconstruction term
        """
        ################################################################################
        # Negative IWAE with iw importance samples, plus the KL / Rec
        # decomposition of the Evidence Lower Bound. All outputs are scalars.
        ################################################################################
        # Parameters of the mixture of Gaussian prior
        prior_mean, prior_var = ut.gaussian_parameters(self.z_pre, dim=1)

        n_obs = x.shape[0]
        n_total = n_obs * iw

        # Encode once, then replicate observation and posterior parameters iw
        # times. The reshape(iw, n_obs) below assumes ut.duplicate stacks iw
        # whole copies of the batch back to back -- TODO confirm.
        post_mean, post_var = self.enc.encode(x)
        x_rep = ut.duplicate(x, iw)
        post_mean_rep = ut.duplicate(post_mean, iw)
        post_var_rep = ut.duplicate(post_var, iw)

        # One latent draw per replicated row
        z = ut.sample_gaussian(post_mean_rep, post_var_rep)

        # Log-likelihood of each replicated observation under the decoder
        logits = self.dec.decode(z)
        log_px_z = ut.log_bernoulli_with_logits(x_rep, logits)

        # Log-density of z under the prior, with the prior parameters
        # broadcast to one copy per replicated row
        prior_mean_rep = prior_mean.expand(n_total, *prior_mean.shape[1:])
        prior_var_rep = prior_var.expand(n_total, *prior_var.shape[1:])
        log_pz = ut.log_normal_mixture(z, prior_mean_rep, prior_var_rep)

        # Log-density of z under the approximate posterior
        log_qz = ut.log_normal(z, post_mean_rep, post_var_rep)

        # ELBO pieces, averaged over every sample
        rec = -1.0 * torch.mean(log_px_z)
        kl = torch.mean(log_qz - log_pz)

        # Importance log-weights, arranged as one row per importance sample
        log_w = log_pz + log_px_z - log_qz
        log_w_grid = log_w.reshape(iw, n_obs)

        # Log-mean-exp over the importance axis, then average over the batch
        per_obs_bound = ut.log_mean_exp(log_w_grid, 0)
        niwae = -1.0 * torch.mean(per_obs_bound)

        ################################################################################
        # End of code modification
        ################################################################################
        return niwae, kl, rec
Example #30
0
 def encode(self, x):
     """
     Run the encoder network on an observation.

     Args:
         x: tensor: input passed directly to self.net

     Returns:
         m, v: the pair returned by ut.gaussian_parameters (called with
             dim=1) on the network output -- presumably mean and variance
     """
     hidden = self.net(x)
     mean, var = ut.gaussian_parameters(hidden, dim=1)
     return mean, var