Exemple #1
0
 def forward(self, state_features):
     """Map state features to an action distribution of type ``self.dist``."""
     h = self.feedforward_model(state_features)
     if self.dist == 'tanh_normal':
         mu, raw_std = th.chunk(h, 2, -1)
         # Soft-clip the mean into [-mean_scale, mean_scale].
         mu = self.mean_scale * th.tanh(mu / self.mean_scale)
         sigma = F.softplus(raw_std + self.raw_init_std) + self.min_std
         # TODO: fix nan problem
         base = td.Normal(mu, sigma)
         squashed = td.TransformedDistribution(
             base, td.TanhTransform(cache_size=1))
         return SampleDist(td.Independent(squashed, 1))
     if self.dist == 'trunc_normal':
         mu, raw_std = th.chunk(h, 2, -1)
         sigma = 2 * th.sigmoid((raw_std + self.raw_init_std) / 2) + self.min_std
         from rls.nn.dists.TruncatedNormal import \
             TruncatedNormal as TruncNormalDist
         return td.Independent(TruncNormalDist(th.tanh(mu), sigma, -1, 1), 1)
     if self.dist == 'one_hot':
         return td.OneHotCategoricalStraightThrough(logits=h)
     if self.dist == 'relaxed_one_hot':
         return td.RelaxedOneHotCategorical(th.tensor(0.1), logits=h)
     raise NotImplementedError(f"{self.dist} is not implemented.")
    def losses_clustering(self, x, x_hat, mu_z, logvar_z, z):
        """Return (BCE, KLD, KLD_c, 0, p(c|z)) terms of the clustering-VAE loss."""
        # When the network emits log-variances, std = exp(logvar / 2);
        # otherwise a single exp recovers the std directly.
        if not self.computes_std:
            std_z, std_c = torch.exp(logvar_z / 2), torch.exp(self.logvar_c / 2)
        else:
            std_z, std_c = torch.exp(logvar_z), torch.exp(self.logvar_c)

        mixture_weights = distributions.Categorical(torch.sigmoid(self.pi)).probs
        pc_given_z = self.pc_given_z(z)

        # Reconstruction term, rescaled from per-pixel mean to per-image sum.
        BCE = F.binary_cross_entropy_with_logits(
            x_hat, x, reduction='mean') * self.width * self.height

        posterior = distributions.Independent(
            distributions.Normal(mu_z[:, None, :], std_z[:, None, :]),
            reinterpreted_batch_ndims=1)
        cluster_prior = distributions.Independent(
            distributions.Normal(self.mu_c[None, :, :], std_c[None, :, :]),
            reinterpreted_batch_ndims=1)
        # KL to each cluster prior, weighted by the cluster responsibilities.
        per_cluster_kl = distributions.kl_divergence(posterior, cluster_prior)
        KLD = torch.sum(pc_given_z * per_cluster_kl, dim=1).mean()

        KLD_c = distributions.kl_divergence(
            distributions.Categorical(pc_given_z),
            distributions.Categorical(mixture_weights[None, :])).mean()

        return BCE, KLD, KLD_c, torch.tensor(0).float(), pc_given_z
Exemple #3
0
    def forward(self, x, beta=1.0, switch=1.0, iw_samples=1):
        """Adversarial VAE forward pass: encode, decode, score, assemble ELBO.

        Returns (iw_elbo - advert_loss, mean log px, mean kl, x_mu[0], x_std,
        z[0], z_mu, z_std).
        """
        # Encoder step
        z_mu, z_std = self.encoder(x)
        q_dist = D.Independent(D.Normal(z_mu, z_std), 1)
        z = q_dist.rsample([iw_samples])  # leading dim holds the IW samples

        # Decoder step
        x_mu, x_std = self.decoder(z, switch)
        if switch:
            # Adversarial pass: real inputs labeled 0 ("valid"),
            # reconstructions labeled 1 ("fake").
            valid = torch.zeros((x.shape[0], 1), device=x.device)
            fake = torch.ones((x.shape[0], 1), device=x.device)
            labels = torch.cat([valid, fake], dim=0)
            x_cat = torch.cat([x.repeat(iw_samples, 1, 1), x_mu], dim=1)

            prop = self.adverserial(x_cat)
            advert_loss = F.binary_cross_entropy(prop,
                                                 labels.repeat(
                                                     iw_samples, 1, 1),
                                                 reduction='sum')
            # Decoder std derived from discriminator scores of the real half.
            x_std = self.dec_std(prop[:, :x.shape[0]])
        else:
            advert_loss = 0

        p_dist = D.Independent(D.Normal(x_mu, x_std), 1)

        # Calculate loss
        prior = D.Independent(
            D.Normal(torch.zeros_like(z), torch.ones_like(z)), 1)
        log_px = p_dist.log_prob(x)
        kl = q_dist.log_prob(z) - prior.log_prob(z)
        elbo = (log_px - beta * kl).mean()
        # NOTE(review): `elbo` has already been reduced to a scalar by
        # .mean(), so logsumexp over dim 0 here does not implement an
        # importance-weighted bound (that would require logsumexp over the
        # iw_samples dimension *before* averaging) — confirm intent.
        iw_elbo = elbo.logsumexp(dim=0) - torch.tensor(float(iw_samples)).log()

        return iw_elbo.mean() - advert_loss, log_px.mean(), kl.mean(
        ), x_mu[0], x_std, z[0], z_mu, z_std
Exemple #4
0
 def reparameterize(self, mu, var):
     """Draw a reparameterized latent sample; cache posterior and prior."""
     # Posterior over the latent, kept on self for later loss computation.
     self.pred_dist = dist.Independent(dist.Normal(mu, var), 1)
     sample = self.pred_dist.rsample()
     # Class-conditional prior parameterized from the stored one-hot label.
     prior_loc = self.loc(self.onehot)
     prior_scale = self.sp(self.scale(self.onehot))
     self.prior = dist.Independent(dist.Normal(prior_loc, prior_scale), 1)
     return sample
    def __init__(self, embed_dims, num_chars, encoder_dims, decoder_dims,
                 n_mels, fft_bins, postnet_dims, encoder_K, lstm_dims,
                 postnet_K, num_highways, dropout, speaker_latent_dims,
                 speaker_encoder_dims, n_speakers, noise_latent_dims,
                 noise_encoder_dims):
        """Tacotron-style TTS model with VAE speaker/noise latents and
        domain-adversarial classifiers.

        Builds the standard Tacotron encoder/decoder/postnet stack, the
        speaker and noise variational encoders (CNN or CNN+RNN, selected by
        hp.encoder_model), cross classifiers for adversarial training, and
        fixed standard-normal priors for both latent spaces.
        """
        super().__init__()
        self.n_mels = n_mels
        self.lstm_dims = lstm_dims
        self.decoder_dims = decoder_dims

        # Standard Tacotron #############################################################
        self.encoder = Encoder(embed_dims, num_chars, encoder_dims, encoder_K,
                               num_highways, dropout)
        self.encoder_proj = nn.Linear(decoder_dims, decoder_dims, bias=False)
        self.decoder = Decoder(n_mels, decoder_dims, lstm_dims,
                               speaker_latent_dims, noise_latent_dims)
        # Postnet consumes the mels concatenated with the noise latent.
        self.postnet = CBHG(postnet_K, n_mels + noise_latent_dims,
                            postnet_dims, [256, n_mels + noise_latent_dims],
                            num_highways)
        self.post_proj = nn.Linear(postnet_dims * 2, fft_bins, bias=False)

        # VAE Domain Adversarial ########################################################
        # NOTE(review): no fallback branch — if hp.encoder_model is neither
        # "CNN" nor "CNNRNN", speaker/noise encoders are left undefined.
        if hp.encoder_model == "CNN":
            self.speaker_encoder = CNNEncoder(n_mels, speaker_latent_dims,
                                              speaker_encoder_dims)
            self.noise_encoder = CNNEncoder(n_mels, noise_latent_dims,
                                            noise_encoder_dims)
        elif hp.encoder_model == "CNNRNN":
            self.speaker_encoder = CNNRNNEncoder(n_mels, speaker_latent_dims,
                                                 speaker_encoder_dims)
            self.noise_encoder = CNNRNNEncoder(n_mels, noise_latent_dims,
                                               noise_encoder_dims)

        # Cross classifiers: each latent is scored against both the speaker
        # identities and a binary (noise) label for adversarial training.
        self.speaker_speaker = Classifier(speaker_latent_dims, n_speakers)
        self.speaker_noise = Classifier(speaker_latent_dims, 2)
        self.noise_speaker = Classifier(noise_latent_dims, n_speakers)
        self.noise_noise = Classifier(noise_latent_dims, 2)
        ## speaker encoder prior
        # Non-trainable N(0, I) prior parameters (registered so they follow
        # the module across devices).
        self.speaker_latent_loc = nn.Parameter(
            torch.zeros(speaker_latent_dims), requires_grad=False)
        self.speaker_latent_scale = nn.Parameter(
            torch.ones(speaker_latent_dims), requires_grad=False)
        self.speaker_latent_prior = dist.Independent(
            dist.Normal(self.speaker_latent_loc, self.speaker_latent_scale), 1)
        ## noise encoder prior
        self.noise_latent_loc = nn.Parameter(torch.zeros(noise_latent_dims),
                                             requires_grad=False)
        self.noise_latent_scale = nn.Parameter(torch.ones(noise_latent_dims),
                                               requires_grad=False)
        self.noise_latent_prior = dist.Independent(
            dist.Normal(self.noise_latent_loc, self.noise_latent_scale), 1)

        #################################################################################

        self.init_model()
        self.num_params()
        # Persistent training counters; "r" is presumably the Tacotron
        # reduction factor — confirm against the training loop.
        self.register_buffer("step", torch.zeros(1).long())
        self.register_buffer("r", torch.tensor(0).long())
Exemple #6
0
 def kl_penalty(self) -> torch.Tensor:
     """Compute the KL divergence prior penalty, used for constructing the ELBO."""
     # Variational posterior with a log-parameterized scale.
     posterior = dist.Independent(
         dist.Normal(self.q_mean, torch.exp(self.log_q_scale)),
         reinterpreted_batch_ndims=2,
     )
     # Standard-normal prior of matching shape.
     prior = dist.Independent(
         dist.Normal(torch.zeros_like(self.q_mean),
                     torch.ones_like(self.q_mean)),
         reinterpreted_batch_ndims=2,
     )
     return dist.kl_divergence(posterior, prior)
    def forward(self, image, **kwargs):
        """Predict a (low-rank) Gaussian distribution over per-pixel logits.

        Requires kwargs['sampling_mask'] to zero the covariance outside the
        region of interest. Returns (logit_mean, output_dict) where the dict
        carries detached mean/covariance views plus the full distribution.
        """
        logits = F.relu(super().forward(image, **kwargs)[0])
        batch_size = logits.shape[0]
        event_shape = (self.num_classes, ) + logits.shape[2:]

        # Per-pixel mean and (positive) diagonal covariance.
        mean = self.mean_l(logits)
        cov_diag = self.log_cov_diag_l(logits).exp() + self.epsilon
        mean = mean.view((batch_size, -1))
        cov_diag = cov_diag.view((batch_size, -1))

        # Low-rank factor reshaped to [B, classes*pixels, rank].
        cov_factor = self.cov_factor_l(logits)
        cov_factor = cov_factor.view(
            (batch_size, self.rank, self.num_classes, -1))
        cov_factor = cov_factor.flatten(2, 3)
        cov_factor = cov_factor.transpose(1, 2)

        # covariance in the background tends to blow up to infinity, hence set to 0 outside the ROI
        mask = kwargs['sampling_mask']
        mask = mask.unsqueeze(1).expand((batch_size, self.num_classes) +
                                        mask.shape[1:]).reshape(
                                            batch_size, -1)
        cov_factor = cov_factor * mask.unsqueeze(-1)
        cov_diag = cov_diag * mask + self.epsilon

        if self.diagonal:
            base_distribution = td.Independent(
                td.Normal(loc=mean, scale=torch.sqrt(cov_diag)), 1)
        else:
            try:
                base_distribution = td.LowRankMultivariateNormal(
                    loc=mean, cov_factor=cov_factor, cov_diag=cov_diag)
            # Narrowed from a bare `except:`, which would also swallow
            # KeyboardInterrupt/SystemExit; numerical failures surface as
            # RuntimeError (factorization) or ValueError (validation).
            except (RuntimeError, ValueError):
                print(
                    'Covariance became not invertible using independent normals for this batch!'
                )
                base_distribution = td.Independent(
                    td.Normal(loc=mean, scale=torch.sqrt(cov_diag)), 1)

        distribution = ReshapedDistribution(base_distribution, event_shape)

        # Detached views for logging / inspection.
        shape = (batch_size, ) + event_shape
        logit_mean = mean.view(shape)
        cov_diag_view = cov_diag.view(shape).detach()
        cov_factor_view = cov_factor.transpose(
            2, 1).view((batch_size, self.num_classes * self.rank) +
                       event_shape[1:]).detach()

        output_dict = {
            'logit_mean': logit_mean.detach(),
            'cov_diag': cov_diag_view,
            'cov_factor': cov_factor_view,
            'distribution': distribution
        }

        return logit_mean, output_dict
Exemple #8
0
 def clone_dist(self, dist, detach=False):
     """Rebuild `dist`, optionally cutting it off from the autograd graph."""
     if self._rssm_type == 'discrete':
         probs = dist.mean
         if detach:
             probs = th.detach(probs)
         return td.Independent(OneHotDistFlattenSample(probs), 1)
     # Continuous case: clone a diagonal Gaussian from its moments.
     mean, stddev = dist.mean, dist.stddev
     if detach:
         mean, stddev = th.detach(mean), th.detach(stddev)
     return td.Independent(td.Normal(mean, stddev), 1)
Exemple #9
0
    def _train(self, BATCH):
        """One TRPO-style update: natural-gradient actor step plus several
        critic regression steps.

        BATCH is expected to carry obs, action, log_prob, gae_adv,
        discounted_reward, begin_mask and, in the continuous case, the
        behavior policy's mu/log_std (logp_all in the discrete case).
        """
        output = self.actor(BATCH.obs,
                            begin_mask=BATCH.begin_mask)  # [T, B, A]
        if self.is_continuous:
            mu, log_std = output  # [T, B, A], [T, B, A]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            new_log_prob = dist.log_prob(BATCH.action).unsqueeze(
                -1)  # [T, B, 1]
            entropy = dist.entropy().mean()  # 1
        else:
            logits = output  # [T, B, A]
            logp_all = logits.log_softmax(-1)  # [T, B, A]
            new_log_prob = (BATCH.action * logp_all).sum(
                -1, keepdim=True)  # [T, B, 1]
            entropy = -(logp_all.exp() * logp_all).sum(-1).mean()  # 1
        # Importance ratio of current policy vs. the behavior policy.
        ratio = (new_log_prob - BATCH.log_prob).exp()  # [T, B, 1]
        actor_loss = -(ratio * BATCH.gae_adv).mean()  # 1

        flat_grads = grads_flatten(actor_loss, self.actor,
                                   retain_graph=True).detach()  # [1,]

        # KL(behavior || current), used as the trust-region metric.
        if self.is_continuous:
            kl = td.kl_divergence(
                td.Independent(td.Normal(BATCH.mu, BATCH.log_std.exp()), 1),
                td.Independent(td.Normal(mu, log_std.exp()), 1)).mean()
        else:
            kl = (BATCH.logp_all.exp() *
                  (BATCH.logp_all - logp_all)).sum(-1).mean()  # 1

        # Natural-gradient direction via conjugate gradients on the KL Hessian.
        flat_kl_grad = grads_flatten(kl, self.actor, create_graph=True)
        search_direction = -self._conjugate_gradients(
            flat_grads, flat_kl_grad, cg_iters=self._cg_iters)  # [1,]

        # Apply a fixed-size step directly to the flattened parameters
        # (no line search in this variant).
        with th.no_grad():
            flat_params = th.cat(
                [param.data.view(-1) for param in self.actor.parameters()])
            new_flat_params = flat_params + self.actor_step_size * search_direction
            set_from_flat_params(self.actor, new_flat_params)

        # Several critic fits per actor update.
        for _ in range(self._train_critic_iters):
            value = self.critic(BATCH.obs,
                                begin_mask=BATCH.begin_mask)  # [T, B, 1]
            td_error = BATCH.discounted_reward - value  # [T, B, 1]
            critic_loss = td_error.square().mean()  # 1
            self.critic_oplr.optimize(critic_loss)

        return {
            'LOSS/actor_loss': actor_loss,
            'LOSS/critic_loss': critic_loss,
            'Statistics/entropy': entropy.mean(),
            'LEARNING_RATE/critic_lr': self.critic_oplr.lr
        }
Exemple #10
0
 def forward(self, x, beta=1.0, epsilon=1e-5):
     """Single-sample VAE pass; returns ELBO stats plus intermediate tensors."""
     z_mu, z_var = self.encoder(x)
     # epsilon keeps the Gaussian scales strictly positive.
     q_dist = D.Independent(D.Normal(z_mu, z_var.sqrt() + epsilon), 1)
     z = q_dist.rsample()
     x_mu, x_var = self.decoder(z)
     p_dist = D.Independent(D.Normal(x_mu, x_var.sqrt() + epsilon), 1)

     prior = D.Independent(
         D.Normal(torch.zeros_like(z), torch.ones_like(z)), 1)
     log_px = p_dist.log_prob(x)
     kl = q_dist.log_prob(z) - prior.log_prob(z)
     elbo = log_px - beta * kl
     return elbo.mean(), log_px.mean(), kl.mean(), x_mu, x_var, z, z_mu, z_var
Exemple #11
0
 def forward(self, x, beta=1.0, epsilon=1e-5,Q=0.5):
     """Quantile-regression VAE forward pass.

     Returns (elbo.mean(), log_px.mean(), kl.mean(), x_mu, z, z_mu, z_var).
     NOTE(review): the `Q` and `beta` arguments are never used in the body —
     the quantiles (0.15/0.5/0.85) and the KL weight (0.28) are hard-coded;
     confirm whether they were meant to be parameterized.
     """
     z_mu, z_var = self.encoder(x)
     # epsilon keeps the posterior scale strictly positive.
     q_dist = D.Independent(D.Normal(z_mu, z_var.sqrt()+epsilon), 1)
     z = q_dist.rsample()
     x_mu = self.decoder(z) 
     prior = D.Independent(D.Normal(torch.zeros_like(z),
                                    torch.ones_like(z)), 1)
     # Pinball (quantile) losses at the 0.15 / 0.5 / 0.85 quantiles; the
     # decoder output is assumed to stack three 4-dim predictions, one per
     # quantile — TODO confirm the decoder's output layout.
     log_px_Q1 = torch.sum(torch.max(0.15 * (x-x_mu[:,0:4]), (0.15 - 1) * (x-x_mu[:,0:4])).view(-1, 4),(1))
     log_px_Q2 = torch.sum(torch.max(0.5 * (x-x_mu[:,4:8]), (0.5 - 1) * (x-x_mu[:,4:8])).view(-1, 4),(1))
     log_px_Q3= torch.sum(torch.max(0.85 * (x-x_mu[:,8:12] ), (0.85 - 1) * (x-x_mu[:,8:12] )).view(-1, 4),(1))
     log_px=(log_px_Q1+log_px_Q2+log_px_Q3)/3
     kl = q_dist.log_prob(z) - prior.log_prob(z)
     elbo = -log_px - 0.28*kl
     return elbo.mean(), log_px.mean(), kl.mean(), x_mu, z, z_mu, z_var
Exemple #12
0
def test_inv():
    """Inverting a flow must swap the arguments of log|det J| and round-trip."""
    flow = flowtorch.bijectors.AffineAutoregressive(
        flowtorch.params.DenseAutoregressive())
    base = dist.Independent(dist.Normal(torch.zeros(2), torch.ones(2)), 1)
    tdist, params = flow(base)
    inv_flow = flow.inv()
    inv_base = dist.Independent(dist.Normal(torch.zeros(2), torch.ones(2)), 1)
    inv_tdist, inv_params = inv_flow(inv_base)
    x = torch.zeros(1, 2)
    y = flow.forward(x, params, context=torch.empty(0))
    fwd_ldj = tdist.bijector.log_abs_det_jacobian(x, y, params,
                                                  context=torch.empty(0))
    inv_ldj = inv_tdist.bijector.log_abs_det_jacobian(y, x, inv_params,
                                                      context=torch.empty(0))
    assert fwd_ldj == inv_ldj
    # Double inversion gives back the original flow object.
    assert flow.inv().inv == flow
def test_conditional_2gmm():
    """Train a conditional flow so that context +1 maps to N(+5, 0.5) and
    context -1 to N(-5, 0.5), then check the conditioned sample means.

    NOTE: stochastic end-to-end training test; the exact statement order
    determines the RNG stream and hence the outcome of the assertions.
    """
    context_size = 2

    # Two stacked autoregressive layers, applied in the inverse direction.
    flow = flowtorch.bijectors.Compose(
        [
            flowtorch.bijectors.AffineAutoregressive(context_size=context_size)
            for _ in range(2)
        ],
        context_size=context_size,
    ).inv()

    base_dist = dist.Normal(torch.zeros(2), torch.ones(2))
    new_cond_dist, params_module = flow(base_dist)

    target_dist_0 = dist.Independent(
        dist.Normal(torch.zeros(2) + 5,
                    torch.ones(2) * 0.5), 1)
    target_dist_1 = dist.Independent(
        dist.Normal(torch.zeros(2) - 5,
                    torch.ones(2) * 0.5), 1)

    opt = torch.optim.Adam(params_module.parameters(), lr=5e-3)

    for idx in range(501):
        opt.zero_grad()

        # Alternate between the two targets/contexts every step.
        if idx % 2 == 0:
            target_dist = target_dist_0
            context = torch.ones(context_size)
        else:
            target_dist = target_dist_1
            context = -1 * torch.ones(context_size)

        # Reverse-KL objective: E_y~q [log q(y) - log p(y)].
        marginal = new_cond_dist.condition(context)
        y = marginal.rsample((1000, ))
        loss = -target_dist.log_prob(y) + marginal.log_prob(y)
        loss = loss.mean()

        if idx % 100 == 0:
            print("epoch", idx, "loss", loss)

        loss.backward()
        opt.step()

    # Conditioned means should land near +5 / -5 respectively.
    assert (new_cond_dist.condition(torch.ones(context_size)).sample(
        (1000, )).mean() - 5.0).norm().item() < 0.1
    assert (new_cond_dist.condition(-1 * torch.ones(context_size)).sample(
        (1000, )).mean() + 5.0).norm().item() < 0.1
Exemple #14
0
def gaussian_mixture_sampler(num_latent,
                             num_mixtures=4,
                             weights=None,
                             means=None,
                             cov=None):
    """Build a sampler over a batch of 1-D Gaussian mixture models.

    :param num_latent: number of independent mixtures (batch size).
    :param num_mixtures: components per mixture.
    :param weights: optional (num_latent, num_mixtures) mixture weights;
        drawn randomly and softmax-normalized when omitted.
    :param means: optional (num_latent, num_mixtures, 1) component means.
    :param cov: optional (num_latent, num_mixtures, 1) component stddevs;
        must be strictly positive.
    :return: callable mapping n -> tensor of n samples per mixture,
        shape (n, num_latent) after squeezing.
    """

    if weights is None:
        weights = torch.randn(num_latent, num_mixtures).softmax(dim=1)

    if means is None:
        means = torch.randn(num_latent, num_mixtures, 1) * 2

    if cov is None:
        # BUGFIX: torch.randn produces negative values, which are invalid
        # as the scale of a Normal (ValueError under argument validation).
        # Use the absolute value, padded away from zero.
        cov = torch.randn(num_latent, num_mixtures, 1).abs() + 1e-3

    mix = dist.Categorical(weights)
    comp = dist.Independent(dist.Normal(means, cov), 1)

    gmm = dist.MixtureSameFamily(mix, comp)

    return lambda n: gmm.sample((n, )).squeeze()
 def detach(self):
     """Detach the parameters from the autograd graph and rebuild the dists."""
     self.mean, self.log_std = self.mean.detach(), self.log_std.detach()
     # Clear the stale distribution objects, then rebuild them from the
     # detached tensors.
     self.normal = self.diagn = None
     self.normal = P.Normal(self.mean, (self.log_std.exp()))
     self.diagn = P.Independent(self.normal, 1)
Exemple #16
0
def test_neals_funnel_vi():
    """Fit a normalizing flow to Neal's funnel by minimizing a Monte-Carlo
    negative ELBO, then compare flow samples to true samples with a
    two-sample KS test.

    NOTE: seeded stochastic training — statement order fixes the RNG stream
    and therefore the outcome of the final assertions.
    """
    torch.manual_seed(42)
    nf = NealsFunnel()
    flow = flowtorch.bijectors.AffineAutoregressive(
        flowtorch.params.DenseAutoregressive())
    tdist, params = flow(
        dist.Independent(dist.Normal(torch.zeros(2), torch.ones(2)), 1))
    opt = torch.optim.Adam(params.parameters(), lr=1e-3)
    num_elbo_mc_samples = 100
    for _ in range(400):
        # Push base samples through the flow; the log|det J| corrects the
        # density of the transformed samples.
        z0 = tdist.base_dist.rsample(sample_shape=(num_elbo_mc_samples, ))
        zk = flow._forward(z0, params, context=torch.empty(0))
        ldj = flow._log_abs_det_jacobian(z0,
                                         zk,
                                         params,
                                         context=torch.empty(0))

        neg_elbo = -nf.log_prob(zk).sum()
        neg_elbo += tdist.base_dist.log_prob(z0).sum() - ldj.sum()
        neg_elbo /= num_elbo_mc_samples

        # Skip optimizer steps whose objective went NaN instead of
        # corrupting the parameters.
        if not torch.isnan(neg_elbo):
            neg_elbo.backward()
            opt.step()
            opt.zero_grad()

    nf_samples = NealsFunnel().sample((20, )).squeeze().numpy()
    vi_samples = tdist.sample((20, )).detach().numpy()

    # Per-dimension marginal agreement at the 5% significance level.
    assert scipy.stats.ks_2samp(nf_samples[:, 0], vi_samples[:,
                                                             0]).pvalue >= 0.05
    assert scipy.stats.ks_2samp(nf_samples[:, 1], vi_samples[:,
                                                             1]).pvalue >= 0.05
Exemple #17
0
 def select_action(self, obs):
     """Sample an action under the active option and choose the next option."""
     option_values = self.q_net(obs, rnncs=self.rnncs)  # [B, P]
     self.rnncs_ = self.q_net.get_rnncs()
     intra_pi = self.intra_option_net(obs, rnncs=self.rnncs)  # [B, P, A]
     terminations = self.termination_net(obs, rnncs=self.rnncs)  # [B, P]
     onehot = F.one_hot(self.options,
                        self.options_num).float()  # [B, P]
     # Select the policy head belonging to the currently active option.
     active_pi = (intra_pi * onehot.unsqueeze(-1)).sum(-2)  # [B, A]
     if self.is_continuous:
         mu = active_pi.tanh()  # [B, A]
         log_std = self.log_std[self.options]  # [B, A]
         act_dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
         actions = act_dist.sample().clamp(-1, 1)  # [B, A]
     else:
         act_dist = td.Categorical(
             logits=active_pi / self.boltzmann_temperature)
         actions = act_dist.sample()  # [B, ]
     greedy_options = option_values.argmax(-1).long()  # [B, P] => [B, ]
     if self.use_eps_greedy:
         # epsilon greedy over the option set
         explore = self._is_train_mode and self.expl_expt_mng.is_random(
             self._cur_train_step)
         self.new_options = (self._generate_random_options()
                             if explore else greedy_options)
     else:
         # Switch to the greedy option only where the termination head fires.
         term_probs = (terminations * onehot).sum(-1)  # [B, P] => [B,]
         terminated = td.Bernoulli(probs=term_probs).sample()
         self.new_options = th.where(terminated < 1, self.options,
                                     greedy_options)
     return actions, Data(action=actions,
                          last_options=self.options,
                          options=self.new_options)
Exemple #18
0
 def independent(self, reinterpreted_batch_ndims=1):
     """Wrap this distribution in ``td.Independent``.

     Folding trailing batch dimensions into the event — as if writing
     td.Independent(distribution=OurDistribution...) — is a common pattern,
     so expose it as a method.
     """
     return td.Independent(
         self, reinterpreted_batch_ndims=reinterpreted_batch_ndims)
Exemple #19
0
def create_gmm(system, gmm_scale=0.05):
    """
    Build a Gaussian-mixture distribution with one isotropic kernel per point.

    Arguments:
        system: tensor of points, shape (..., n_concepts, n_dim); singleton
            dimensions are squeezed away first
        gmm_scale: stdev of the kernel placed on each point

    Returns:
        the mixture probability distribution over points
    """
    system = torch.squeeze(system)
    dims = system.shape[-1]
    concepts = system.shape[-2]

    # Every kernel carries the same weight.
    weights = D.Categorical(torch.ones(concepts, ))

    # Isotropic diagonal covariance with stdev gmm_scale.
    kernels = D.Independent(
        D.Normal(system, gmm_scale * torch.ones(dims, )), 1)
    return D.mixture_same_family.MixtureSameFamily(weights, kernels)
Exemple #20
0
    def forward(self, input: torch.Tensor, proposal: distributions.Normal,
                reconstruction: torch.Tensor) -> torch.Tensor:
        """Negative ELBO for a VAE with a beta-weighted KL term.

        Returns (-elbo, -reconstruction_loss, regularization).
        """
        if self.likelihood == 'bernoulli':
            likelihood = distributions.Bernoulli(probs=reconstruction)
        else:
            # Fixed unit-variance Gaussian likelihood.
            likelihood = distributions.Normal(reconstruction,
                                              torch.ones_like(reconstruction))

        # NOTE(review): reinterpreted_batch_ndims=-1 is unusual — torch's
        # Independent expects a non-negative count of batch dimensions to
        # fold into the event; confirm this behaves as intended here.
        likelihood = distributions.Independent(likelihood,
                                               reinterpreted_batch_ndims=-1)
        reconstruction_loss = likelihood.log_prob(input).mean()

        assert proposal.loc.dim(
        ) == 2, "proposal.shape == [*, D], D is shape of isotopic gaussian"

        # Standard-normal prior matched elementwise to the proposal; the KL
        # is summed over the latent dimension, then averaged over the batch.
        prior = distributions.Normal(torch.zeros_like(proposal.loc),
                                     torch.ones_like(proposal.scale))
        regularization = distributions.kl_divergence(proposal,
                                                     prior).sum(dim=-1).mean()

        # evidence lower bound (maximize)
        total_loss = reconstruction_loss - self.beta * regularization

        return -total_loss, -reconstruction_loss, regularization
Exemple #21
0
 def _build_dist(self, output):
     """Turn the raw RSSM network output into a stochastic-state distribution."""
     if self._rssm_type == 'discrete':
         # Reshape the flat logits into (stochastic, discrete) categories.
         logits = output.view(output.shape[:-1] +
                              (self.stoch_dim, self._discretes))  # [B, s, d]
         return td.Independent(OneHotDistFlattenSample(logits=logits), 1)
     mean, stddev = th.chunk(output, 2, -1)  # [B, *]
     # Apply the configured positivity transform to the raw stddev.
     if self._std_act == 'softplus':
         stddev = F.softplus(stddev)
     elif self._std_act == 'sigmoid':
         stddev = th.sigmoid(stddev)
     elif self._std_act == 'sigmoid2':
         stddev = 2. * th.sigmoid(stddev / 2.)
     stddev = stddev + self._min_stddev  # [B, *]
     return td.Independent(td.Normal(mean, stddev), 1)
Exemple #22
0
 def update(self, observations, actions, adv_n=None):
     """One policy-gradient step; returns the scalar loss value."""
     # TODO: update the policy and return the loss
     if adv_n is None:
         # in which circumstances can adv_n be None?? seems no
         raise ValueError("adv_n is None!?")
     action_dist = self.forward(observations)
     if self.discrete:
         log_pi = action_dist.log_prob(actions)
     elif len(action_dist.batch_shape) == 1:
         log_pi = action_dist.log_prob(actions)
     else:
         # Fold the extra batch dimension into the event so that log_prob
         # reduces over the action dimension.
         log_pi = distributions.Independent(action_dist, 1).log_prob(actions)
     assert adv_n.ndim == log_pi.ndim
     # `optimizer.step()` MINIMIZES a loss but we want to MAXIMIZE expectation
     loss = -torch.sum(adv_n * log_pi)
     self.optimizer.zero_grad()
     loss.backward()
     self.optimizer.step()
     return loss.item()
Exemple #23
0
    def _train(self, BATCH):
        """One actor-critic update; returns a dict of training summaries."""
        # Critic: regress the value toward the discounted return.
        value = self.critic(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, 1]
        critic_loss = (BATCH.discounted_reward - value).square().mean()  # 1
        self.critic_oplr.optimize(critic_loss)

        if self.is_continuous:
            mu, log_std = self.actor(BATCH.obs,
                                     begin_mask=BATCH.begin_mask)  # [T, B, A]
            act_dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            log_act_prob = act_dist.log_prob(
                BATCH.action).unsqueeze(-1)  # [T, B, 1]
            entropy = act_dist.entropy().unsqueeze(-1)  # [T, B, 1]
        else:
            logits = self.actor(BATCH.obs,
                                begin_mask=BATCH.begin_mask)  # [T, B, A]
            logp_all = logits.log_softmax(-1)  # [T, B, A]
            log_act_prob = (BATCH.action * logp_all).sum(
                -1, keepdim=True)  # [T, B, 1]
            entropy = -(logp_all.exp() * logp_all).sum(
                -1, keepdim=True)  # [T, B, 1]
        # Actor: policy gradient on GAE advantages plus an entropy bonus.
        actor_loss = -(log_act_prob * BATCH.gae_adv +
                       self.beta * entropy).mean()  # 1
        self.actor_oplr.optimize(actor_loss)

        return {
            'LOSS/actor_loss': actor_loss,
            'LOSS/critic_loss': critic_loss,
            'Statistics/entropy': entropy.mean(),
            'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
            'LEARNING_RATE/critic_lr': self.critic_oplr.lr
        }
Exemple #24
0
    def select_action(self, obs):
        """Sample an action and return it with value / log-prob bookkeeping."""
        if self.is_continuous:
            if self._share_net:
                mu, log_std, value = self.net(obs, rnncs=self.rnncs)  # [B, A]
                self.rnncs_ = self.net.get_rnncs()
            else:
                mu, log_std = self.actor(obs, rnncs=self.rnncs)  # [B, A]
                self.rnncs_ = self.actor.get_rnncs()
                value = self.critic(obs, rnncs=self.rnncs)  # [B, 1]
            policy = td.Independent(td.Normal(mu, log_std.exp()), 1)
            action = policy.sample().clamp(-1, 1)  # [B, A]
            log_prob = policy.log_prob(action).unsqueeze(-1)  # [B, 1]
        else:
            if self._share_net:
                logits, value = self.net(obs, rnncs=self.rnncs)  # [B, A], [B, 1]
                self.rnncs_ = self.net.get_rnncs()
            else:
                logits = self.actor(obs, rnncs=self.rnncs)  # [B, A]
                self.rnncs_ = self.actor.get_rnncs()
                value = self.critic(obs, rnncs=self.rnncs)  # [B, 1]
            policy = td.Categorical(logits=logits)
            action = policy.sample()  # [B,]
            log_prob = policy.log_prob(action).unsqueeze(-1)  # [B, 1]

        # eps keeps downstream log-ratio computations away from log(0).
        acts_info = Data(action=action,
                         value=value,
                         log_prob=log_prob + th.finfo().eps)
        if self.use_rnn:
            acts_info.update(rnncs=self.rnncs)
        return action, acts_info
Exemple #25
0
 def squashed_diagonal_gaussian_head(x):
     """Split features into (mean, log-scale) halves and build a diagonal Gaussian."""
     mean, log_scale = torch.chunk(x, 2, dim=-1)
     # Clamp the log-scale for numerical stability, then recover the stddev
     # through the variance to mirror the original parameterization.
     log_scale = torch.clamp(log_scale, -20.0, 2.0)
     var = torch.exp(log_scale * 2)
     return distributions.Independent(
         distributions.Normal(loc=mean, scale=torch.sqrt(var)), 1)
    def __init__(self, x, layers, num_components=100, device=None, old=False):
        """VAE with residual MLP encoder/decoder built from `nnj` modules
        (which provide automatic Jacobian computation) and a standard-normal
        latent prior.

        layers: layer widths; layers[0] is the data dimension `p`,
        layers[-1] the latent dimension `d`.
        NOTE(review): the `x` argument is never used in this constructor —
        confirm whether it can be dropped at the call sites.
        """
        super(VAE_bodies, self).__init__()

        self.device = device

        self.p = int(layers[0])  # Dimension of x
        self.d = int(layers[-1])  # Dimension of z
        self.h = layers  # [1:-1] # Dimension of hidden layers
        self.num_components = num_components

        # Encoder: residual softplus blocks followed by a linear head that
        # emits 2*d outputs (the parameters of q(z|x)).
        enc = []
        for k in range(len(layers) - 1):
            in_features = int(layers[k])
            out_features = int(layers[k + 1])
            enc.append(
                nnj.ResidualBlock(nnj.Linear(in_features, out_features),
                                  nnj.Softplus()))
        enc.append(nnj.Linear(out_features, int(self.d * 2)))

        # Decoder: mirror of the encoder; the output layer uses a sigmoid so
        # reconstructions live in (0, 1). The `old` flag reproduces the
        # layer layout of previously-trained checkpoints.
        dec = []
        for k in reversed(range(len(layers) - 1)):
            in_features = int(layers[k + 1])
            out_features = int(layers[k])
            if not old:  # temporary to load old models TODO: delete
                if out_features != layers[0]:
                    dec.append(
                        nnj.ResidualBlock(
                            nnj.Linear(in_features, out_features),
                            nnj.Softplus()))
                else:
                    dec.append(
                        nnj.ResidualBlock(
                            nnj.Linear(in_features, out_features),
                            nnj.Sigmoid()))
            else:
                dec.append(
                    nnj.ResidualBlock(nnj.Linear(in_features, out_features),
                                      nnj.Softplus()))
                if out_features == layers[0]:
                    dec.append(nnj.Sigmoid())

        # Note how we use 'nnj' instead of 'nn' -- this gives automatic
        # computation of Jacobians of the implemented neural network.
        # The embed function is required to also return Jacobians if
        # requested; by using 'nnj' this becomes a trivial constraint.
        self.encoder = nnj.Sequential(*enc)

        self.decoder_loc = nnj.Sequential(*dec)
        self.init_decoder_scale = 0.01 * torch.ones(self.p, device=self.device)

        # Fixed standard-normal prior over the latent space.
        self.prior_loc = torch.zeros(self.d, device=self.device)
        self.prior_scale = torch.ones(self.d, device=self.device)
        self.prior = td.Independent(
            td.Normal(loc=self.prior_loc, scale=self.prior_scale), 1)

        # Create a blank std-network.
        # It is important to call init_std after training the mean, but before training the std
        self.dec_std = None

        self.to(self.device)
def miwae_loss(iota_x, mask, d, K, p_z, encoder, decoder):
    batch_size = iota_x.shape[0]
    p = iota_x.shape[1]
    out_encoder = encoder(iota_x)
    q_zgivenxobs = td.Independent(
        td.Normal(loc=out_encoder[..., :d],
                  scale=torch.nn.Softplus()(out_encoder[..., d:(2 * d)])), 1)

    zgivenx = q_zgivenxobs.rsample([K])
    zgivenx_flat = zgivenx.reshape([K * batch_size, d])

    out_decoder = decoder(zgivenx_flat)
    all_means_obs_model = out_decoder[..., :p]
    all_scales_obs_model = torch.nn.Softplus()(out_decoder[...,
                                                           p:(2 * p)]) + 0.001
    all_degfreedom_obs_model = torch.nn.Softplus()(
        out_decoder[..., (2 * p):(3 * p)]) + 3

    data_flat = torch.Tensor.repeat(iota_x, [K, 1]).reshape([-1, 1])
    tiledmask = torch.Tensor.repeat(mask, [K, 1])

    all_log_pxgivenz_flat = torch.distributions.StudentT(
        loc=all_means_obs_model.reshape([-1, 1]),
        scale=all_scales_obs_model.reshape([-1, 1]),
        df=all_degfreedom_obs_model.reshape([-1, 1])).log_prob(data_flat)
    all_log_pxgivenz = all_log_pxgivenz_flat.reshape([K * batch_size, p])

    logpxobsgivenz = torch.sum(all_log_pxgivenz * tiledmask,
                               1).reshape([K, batch_size])
    logpz = p_z.log_prob(zgivenx)
    logq = q_zgivenxobs.log_prob(zgivenx)

    neg_bound = -torch.mean(torch.logsumexp(logpxobsgivenz + logpz - logq, 0))

    return neg_bound
Example #28
0
    def forward(self, input: torch.Tensor,
                proposal: distributions.RelaxedOneHotCategorical,
                proposal_sample: torch.Tensor,
                reconstruction: torch.Tensor) -> torch.Tensor:
        """Compute the negative ELBO for a relaxed-categorical latent model.

        Args:
            input: Observed data, shape [B, *event_dims].
            proposal: Approximate posterior q(z|x), logits of shape [B, D].
            proposal_sample: A (reparameterized) sample drawn from `proposal`.
            reconstruction: Decoder output with the same shape as `input`;
                Bernoulli probabilities when `self.likelihood == 'bernoulli'`,
                otherwise the mean of a unit-variance Normal.

        Returns:
            Tuple `(-elbo, -reconstruction_term, regularization)`.
        """
        if self.likelihood == 'bernoulli':
            likelihood = distributions.Bernoulli(probs=reconstruction)
        else:
            likelihood = distributions.Normal(reconstruction,
                                              torch.ones_like(reconstruction))

        # Treat every dimension except the leading batch dimension as part of
        # the event, so log_prob sums over all non-batch dims.  The previous
        # code passed reinterpreted_batch_ndims=-1, which only worked through
        # an accident of an internal torch helper (_sum_rightmost) and is not
        # a documented value; the explicit count is the supported form and
        # yields identical results.
        likelihood = distributions.Independent(
            likelihood, reinterpreted_batch_ndims=reconstruction.dim() - 1)
        reconstruction_loss = likelihood.log_prob(input).mean()

        assert proposal.logits.dim(
        ) == 2, "proposal.shape == [*, D], D is shape of isotopic gaussian"

        # Uniform prior over categories: constant logits normalize to 1/D.
        prior = distributions.RelaxedOneHotCategorical(proposal.temperature,
                                                       logits=torch.ones_like(
                                                           proposal.logits))
        # Single-sample Monte-Carlo estimate of KL(q || prior).
        regularization = (proposal.log_prob(proposal_sample) - prior.log_prob(proposal_sample)) \
            .mean()

        # evidence lower bound (maximize)
        total_loss = reconstruction_loss - self.beta * regularization

        return -total_loss, -reconstruction_loss, regularization
Example #29
0
    def log_prob(self, locations_3d, x_offset_3d, y_offset_3d, z_offset_3d,
                 intensities_3d):
        """Log-likelihood of ground-truth emitters under the predicted maps.

        The likelihood factorizes into (a) a Normal approximation to the
        distribution over the number of emitters and (b) a Gaussian mixture
        over emitter positions and intensities, masked to the real emitters.

        Args:
            locations_3d: Ground-truth emitter location grid.
            x_offset_3d: Sub-voxel x-offset grid.
            y_offset_3d: Sub-voxel y-offset grid.
            z_offset_3d: Sub-voxel z-offset grid.
            intensities_3d: Ground-truth emitter intensity grid.

        NOTE(review): exact input shapes and semantics are defined by the
        external helper `get_true_labels` — confirm against its source.

        Returns:
            Per-sample log-probability (count term plus masked spatial term).
        """
        # Turn the dense label grids into per-emitter coordinates `xyzi`,
        # per-sample emitter counts and a validity mask for padded entries.
        # Presumably xyzi is (n_emitters, batch, 4) given the transpose on
        # the mixture log_prob below — TODO confirm.
        xyzi, counts, s_mask = get_true_labels(locations_3d, x_offset_3d,
                                               y_offset_3d, z_offset_3d,
                                               intensities_3d)
        # Split predicted means and sigmas into one tensor per channel
        # (x, y, z, intensity), re-inserting a singleton channel dim.
        x_mu, y_mu, z_mu, i_mu = (i.unsqueeze(1)
                                  for i in torch.unbind(self.xyzi_mu, dim=1))
        x_si, y_si, z_si, i_si = (
            i.unsqueeze(1) for i in torch.unbind(self.xyzi_sigma, dim=1))

        # Per-voxel detection probabilities; the epsilon keeps them strictly
        # inside (0, 1) so the count variance below cannot collapse to zero.
        P = torch.sigmoid(self.logits) + 0.00001
        # Normal approximation to the count: mean = sum(P) and
        # var = sum(P - P^2), matching a sum of independent Bernoulli(P).
        count_mean = P.sum(dim=[2, 3, 4]).squeeze(-1)
        count_var = (P - P**2).sum(dim=[2, 3, 4]).squeeze(
            -1)  # epsilon above avoids zero variance on a perfect match
        count_dist = D.Normal(count_mean, torch.sqrt(count_var))
        count_prob = count_dist.log_prob(counts)
        # Normalize the detection map into mixture weights over voxels.
        mixture_probs = P / P.sum(dim=[1, 2, 3], keepdim=True)

        # Map grid-indexed parameters to flat per-component GMM parameters.
        # NOTE(review): output shapes depend on the external `img_to_coord`
        # helper — verify before relying on them.
        xyz_mu_list, _, _, i_mu_list, x_sigma_list, y_sigma_list, z_sigma_list, i_sigma_list, mixture_probs_l = img_to_coord(
            P, x_mu, y_mu, z_mu, i_mu, x_si, y_si, z_si, i_si, mixture_probs)
        xyzi_mu = torch.cat((xyz_mu_list, i_mu_list), dim=-1)
        xyzi_sigma = torch.cat(
            (x_sigma_list, y_sigma_list, z_sigma_list, i_sigma_list),
            dim=-1)  # concatenated to align with xyzi_mu (original: avoid NaN)
        mix = D.Categorical(mixture_probs_l.squeeze(-1))
        comp = D.Independent(D.Normal(xyzi_mu, xyzi_sigma), 1)
        spatial_gmm = D.MixtureSameFamily(mix, comp)
        spatial_prob = spatial_gmm.log_prob(xyzi.transpose(0,
                                                           1)).transpose(0, 1)
        # Zero out padded emitters, then sum log-probs over the real ones.
        spatial_prob = (spatial_prob * s_mask).sum(-1)
        log_prob = count_prob + spatial_prob
        return log_prob
Example #30
0
    def _goal_likelihood(self, y: torch.Tensor, goal: torch.Tensor,
                         **hyperparams) -> torch.Tensor:
        """Scores how well the final waypoint of plan `y` matches `goal`.

        The plan's terminal position is evaluated under an equally-weighted
        mixture of isotropic Gaussians, one component per candidate goal.

        Args:
          y: A plan under evaluation, with shape `[B, T, 2]`.
          goal: The goal locations, with shape `[B, K, 2]`.
          hyperparams: (keyword arguments) The goal-likelihood hyperparameters.

        Returns:
          The log-likelihood of the plan `y` under the `goal` distribution,
          averaged over the batch.
        """
        batch_size, num_goals, _ = goal.shape

        # Per-component standard deviation of the goal mixture.
        epsilon = hyperparams.get("epsilon", 1.0)

        # TODO(filangel): implement other goal likelihoods from the DIM paper
        # Equal weights across the K goals; each goal contributes an
        # isotropic Gaussian component with scale `epsilon`.
        uniform_weights = torch.ones((batch_size, num_goals)).to(goal.device)
        per_goal_gaussian = D.Independent(
            D.Normal(loc=goal, scale=torch.ones_like(goal) * epsilon),
            reinterpreted_batch_ndims=1,
        )
        goal_distribution = D.MixtureSameFamily(
            mixture_distribution=D.Categorical(probs=uniform_weights),
            component_distribution=per_goal_gaussian,
        )

        # Score only the terminal waypoint of the plan, then batch-average.
        final_position = y[:, -1, :]
        return torch.mean(goal_distribution.log_prob(final_position), dim=0)