Example #1
 def evaluate_true(self, X: Tensor) -> Tensor:
     r"""Evaluate the GMMs."""
     # This needs to be reinstantiated because MVN apparently does not
     # have a `to` method to make it device/dtype agnostic.
     mvn = MultivariateNormal(loc=self.gmm_pos, covariance_matrix=self.gmm_covar)
     view_shape = (
         X.shape[:-1]
         + torch.Size([1] * (self.gmm_pos.ndim - 1))
         + self.gmm_pos.shape[-1:]
     )
     expand_shape = X.shape[:-1] + self.gmm_pos.shape
     pdf_X = mvn.log_prob(X.view(view_shape).expand(expand_shape)).exp()
     # Multiply by -1 to make this a minimization problem by default
     return -(self.gmm_norm * pdf_X).sum(dim=-1)
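The view/expand dance above broadcasts a batch of query points against every mixture component so that a single `log_prob` call scores all point/component pairs. A minimal standalone sketch of the same broadcasting, with assumed shapes:

import torch
from torch.distributions import MultivariateNormal

num_components, dim = 3, 2
gmm_pos = torch.randn(num_components, dim)                    # component means
gmm_covar = torch.eye(dim).expand(num_components, dim, dim)   # component covariances
mvn = MultivariateNormal(loc=gmm_pos, covariance_matrix=gmm_covar)

X = torch.randn(5, dim)                                       # 5 query points
# (5, dim) -> (5, 1, dim) -> (5, num_components, dim)
pdf_X = mvn.log_prob(X.unsqueeze(-2).expand(5, num_components, dim)).exp()
print(pdf_X.shape)  # torch.Size([5, 3])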
Example #2
    def act(self, state):

        if self.has_continuous_action_space:
            action_mean = self.actor(state)
            cov_mat = torch.diag(self.action_var).unsqueeze(dim=0)
            dist = MultivariateNormal(action_mean, cov_mat)
        else:
            action_probs = self.actor(state)
            dist = Categorical(action_probs)

        action = dist.sample()
        action_logprob = dist.log_prob(action)

        return action.detach(), action_logprob.detach()
Example #3
 def _get_dist(self):
     """
     Get the :class:`~pyro.distributions.GaussianHMM` distribution that corresponds
     to a :class:`LinearlyCoupledMaternGP`.
     """
     trans_matrix, process_covar = self.kernel.transition_matrix_and_covariance(
         dt=self.dt)
     trans_matrix = block_diag_embed(trans_matrix)
     process_covar = block_diag_embed(process_covar)
     loc = self.A.new_zeros(self.full_state_dim)
     trans_dist = MultivariateNormal(loc, process_covar)
     return dist.GaussianHMM(self._get_init_dist(), trans_matrix,
                             trans_dist, self._get_obs_matrix(),
                             self._get_obs_dist())
Example #4
    def test_multivariate_normal_prior_log_prob(self, cuda=False):
        device = torch.device("cuda") if cuda else torch.device("cpu")
        mean = torch.tensor([0.0, 1.0], device=device)
        cov = torch.eye(2, device=device)
        prior = MultivariateNormalPrior(mean, covariance_matrix=cov)
        dist = MultivariateNormal(mean, covariance_matrix=cov)

        self.assertFalse(prior.log_transform)
        t = torch.tensor([-1, 0.5], device=device)
        self.assertTrue(torch.equal(prior.log_prob(t), dist.log_prob(t)))
        t = torch.tensor([[-1, 0.5], [1.5, -2.0]], device=device)
        self.assertTrue(torch.equal(prior.log_prob(t), dist.log_prob(t)))
        with self.assertRaises(RuntimeError):
            prior.log_prob(torch.zeros(3, device=device))
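The final assertion in this test relies on shape checking: the trailing dimension of the value passed to `log_prob` must match the distribution's event size. A quick standalone illustration with a plain torch distribution:

import torch
from torch.distributions import MultivariateNormal

d = MultivariateNormal(torch.zeros(2), torch.eye(2))
print(d.log_prob(torch.zeros(2)).shape)     # torch.Size([])
print(d.log_prob(torch.zeros(4, 2)).shape)  # torch.Size([4])
# d.log_prob(torch.zeros(3))                # fails: trailing dim != event size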
Example #5
def test_api_snpe_c_posterior_correction(
    sample_with_mcmc, mcmc_method, prior_str, set_seed
):
    """Test that leakage correction applied to sampling works, with both MCMC and
    rejection.

    Args:
        set_seed: fixture for manual seeding
    """

    num_dim = 2
    x_o = zeros(1, num_dim)

    # likelihood_mean will be likelihood_shift+theta
    likelihood_shift = -1.0 * ones(num_dim)
    likelihood_cov = 0.3 * eye(num_dim)

    if prior_str == "gaussian":
        prior_mean = zeros(num_dim)
        prior_cov = eye(num_dim)
        prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov)
    else:
        prior = utils.BoxUniform(-2.0 * ones(num_dim), 2.0 * ones(num_dim))

    def simulator(theta):
        return linear_gaussian(theta, likelihood_shift, likelihood_cov)

    simulator, prior = prepare_for_sbi(simulator, prior)
    inference = SNPE_C(
        prior,
        density_estimator="maf",
        simulation_batch_size=50,
        sample_with_mcmc=sample_with_mcmc,
        mcmc_method=mcmc_method,
        show_progress_bars=False,
    )

    theta, x = simulate_for_sbi(simulator, prior, 1000)
    _ = inference.append_simulations(theta, x).train(max_num_epochs=5)
    posterior = inference.build_posterior()
    posterior = posterior.set_sample_with_mcmc(sample_with_mcmc).set_mcmc_method(
        mcmc_method
    )

    # Posterior should be corrected for leakage even if num_rounds is just 1.
    samples = posterior.sample((10,), x=x_o)

    # Evaluate the samples to check correction factor.
    posterior.log_prob(samples, x=x_o)
Example #6
def test_api_snl_sampling_methods(sampling_method: str, prior_str: str,
                                  set_seed):
    """Runs SNL on linear Gaussian and tests sampling from posterior via mcmc.

    Args:
        mcmc_method: which mcmc method to use for sampling
        prior_str: use gaussian or uniform prior
        set_seed: fixture for manual seeding
    """

    num_dim = 2
    num_samples = 10
    num_trials = 2
    # HMC with uniform prior needs good likelihood.
    num_simulations = 10000 if sampling_method == "hmc" else 1000
    x_o = zeros((num_trials, num_dim))
    # Test for multiple chains is cheap when vectorized.
    num_chains = 3 if sampling_method == "slice_np_vectorized" else 1
    if sampling_method == "rejection":
        sample_with = "rejection"
    else:
        sample_with = "mcmc"

    if prior_str == "gaussian":
        prior = MultivariateNormal(loc=zeros(num_dim),
                                   covariance_matrix=eye(num_dim))
    else:
        prior = utils.BoxUniform(-1.0 * ones(num_dim), ones(num_dim))

    simulator, prior = prepare_for_sbi(diagonal_linear_gaussian, prior)
    inference = SNL(prior, show_progress_bars=False)

    theta, x = simulate_for_sbi(simulator,
                                prior,
                                num_simulations,
                                simulation_batch_size=50)
    _ = inference.append_simulations(theta, x).train(max_num_epochs=5)
    posterior = inference.build_posterior(
        sample_with=sample_with,
        mcmc_method=sampling_method).set_default_x(x_o)

    posterior.sample(
        sample_shape=(num_samples, ),
        x=x_o,
        mcmc_parameters={
            "thin": 3,
            "num_chains": num_chains
        },
    )
Example #7
File: snpe_c.py Project: bkmi/sbi
    def _set_maybe_z_scored_prior(self) -> None:
        r"""Compute and store potentially standardized prior (if z-scoring was done).

        The proposal posterior is:
        $pp(\theta|x) = 1/Z \cdot q(\theta|x) \cdot prop(\theta) / p(\theta)$

        Let's denote the z-scored theta by `a`: $a = (\theta - mean) / std$.
        Then $pp'(a|x) = 1/Z_2 \cdot q'(a|x) \cdot prop'(a) / p'(a)$.

        The ' indicates that the evaluation occurs in standardized space, and the
        constant scaling factor has been absorbed into $Z_2$. From the above
        equation, we see that we need to evaluate the prior **in standardized
        space**. We build the standardized prior in this function.

        The standardizing transform that is applied to the samples theta does not
        use the exact prior mean and std (due to implementation issues). Hence,
        the z-scored prior will not exactly have mean=0 and std=1.
        """

        if self.z_score_theta:
            scale = self._neural_net._transform._transforms[0]._scale
            shift = self._neural_net._transform._transforms[0]._shift

            # Following the definition of the linear transform in
            # `standardizing_transform` in `sbiutils.py`:
            # shift=-mean / std
            # scale=1 / std
            # Solving these equations for mean and std:
            estim_prior_std = 1 / scale
            estim_prior_mean = -shift * estim_prior_std

            # Compute the discrepancy of the true prior mean and std and the mean and
            # std that was empirically estimated from samples.
            # N(theta|m,s) = N((theta-m_e)/s_e|(m-m_e)/s_e, s/s_e)
            # Above: m,s are true prior mean and std. m_e,s_e are estimated prior mean
            # and std (estimated from samples and used to build standardize transform).
            almost_zero_mean = (self._prior.mean -
                                estim_prior_mean) / estim_prior_std
            almost_one_std = torch.sqrt(self._prior.variance) / estim_prior_std

            if isinstance(self._prior, MultivariateNormal):
                self._maybe_z_scored_prior = MultivariateNormal(
                    almost_zero_mean, torch.diag(almost_one_std))
            else:
                range_ = torch.sqrt(almost_one_std * 3.0)
                self._maybe_z_scored_prior = utils.BoxUniform(
                    almost_zero_mean - range_, almost_zero_mean + range_)
        else:
            self._maybe_z_scored_prior = self._prior
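A minimal standalone sketch (with made-up sample data) of the shift/scale inversion used above: the standardizing affine transform stores scale = 1/std and shift = -mean/std, so mean and std can be recovered from the stored parameters.

import torch

theta = torch.randn(1000, 2) * 2.0 + 3.0  # hypothetical prior samples
std, mean = theta.std(0), theta.mean(0)
scale, shift = 1.0 / std, -mean / std     # as stored by the standardizing transform

estim_prior_std = 1.0 / scale
estim_prior_mean = -shift * estim_prior_std
assert torch.allclose(estim_prior_std, std)
assert torch.allclose(estim_prior_mean, mean)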
Example #8
    def forward(self, input_, action=None):
        """Run the NAF-style head: returns a sampled, clamped action, the Q-value
        of `action` (None if no action is given), and the state value V.
        """

        x = torch.relu(self.fc1(input_))
        x = self.bn1(x)
        x = torch.relu(self.fc2(torch.cat([x, input_],dim=1)))
        x = self.bn2(x)
        x = torch.relu(self.fc3(torch.cat([x, input_],dim=1)))
        x = self.bn3(x)
        x = torch.relu(self.fc4(torch.cat([x, input_],dim=1)))

        action_value = torch.tanh(self.action_values(x))
        entries = torch.tanh(self.matrix_entries(x))
        V = self.value(x)
        
        action_value = action_value.unsqueeze(-1)
        
        # create lower-triangular matrix
        L = torch.zeros((input_.shape[0], self.action_size, self.action_size)).to(device)

        # get lower-triangular indices
        tril_indices = torch.tril_indices(row=self.action_size, col=self.action_size, offset=0)

        # fill matrix with entries; exponentiating the diagonal keeps L invertible
        L[:, tril_indices[0], tril_indices[1]] = entries
        L.diagonal(dim1=1, dim2=2).exp_()

        # state-dependent, positive-definite matrix P = L @ L^T (batched matmul)
        P = torch.bmm(L, L.transpose(2, 1))
        
        Q = None
        if action is not None:  

            # calculate Advantage:
            A = (-0.5 * torch.matmul(torch.matmul((action.unsqueeze(-1) - action_value).transpose(2, 1), P), (action.unsqueeze(-1) - action_value))).squeeze(-1)

            Q = A + V   
        
        
        # add noise to action mu:
        dist = MultivariateNormal(action_value.squeeze(-1), torch.inverse(P))
        #dist = Normal(action_value.squeeze(-1), 1)
        action = dist.sample()
        action = torch.clamp(action, min=-1, max=1)
        #wandb.log({"Action Noise": action.detach().cpu().numpy() - action_value.squeeze(-1).detach().cpu().numpy()})

        return action, Q, V
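A standalone sketch (assumed sizes) of the Cholesky-style construction used above: `torch.tril_indices` scatters a flat vector into a lower-triangular matrix, and exponentiating the diagonal makes L invertible, so L @ L^T is symmetric positive definite.

import torch

batch, action_size = 4, 3
entries = torch.randn(batch, action_size * (action_size + 1) // 2)
L = torch.zeros(batch, action_size, action_size)
idx = torch.tril_indices(row=action_size, col=action_size, offset=0)
L[:, idx[0], idx[1]] = entries
L.diagonal(dim1=1, dim2=2).exp_()
P = torch.bmm(L, L.transpose(2, 1))         # batched, positive definite
print(torch.linalg.eigvalsh(P).min() > 0)   # tensor(True)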
Example #9
def test_c2st_snle_external_data_on_linearGaussian(set_seed):
    """Test whether SNPE C infers well a simple example with available ground truth.

    Args:
        set_seed: fixture for manual seeding
    """

    num_dim = 2

    device = "cpu"
    configure_default_device(device)
    x_o = zeros(1, num_dim)
    num_samples = 1000

    # likelihood_mean will be likelihood_shift+theta
    likelihood_shift = -1.0 * ones(num_dim)
    likelihood_cov = 0.3 * eye(num_dim)

    prior_mean = zeros(num_dim)
    prior_cov = eye(num_dim)
    prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov)
    gt_posterior = true_posterior_linear_gaussian_mvn_prior(
        x_o[0], likelihood_shift, likelihood_cov, prior_mean, prior_cov)
    target_samples = gt_posterior.sample((num_samples, ))

    def simulator(theta):
        return linear_gaussian(theta, likelihood_shift, likelihood_cov)

    infer = SNL(
        *prepare_for_sbi(simulator, prior),
        simulation_batch_size=1000,
        show_progress_bars=False,
        device=device,
    )

    external_theta = prior.sample((1000, ))
    external_x = simulator(external_theta)

    infer.provide_presimulated(external_theta, external_x)

    posterior = infer(
        num_rounds=1,
        num_simulations_per_round=1000,
        training_batch_size=100,
    ).set_default_x(x_o)
    samples = posterior.sample((num_samples, ))

    # Compute the c2st and assert it is near chance level of 0.5.
    check_c2st(samples, target_samples, alg="snle")
Example #10
    def loss(
            self,
            oh,
            ce,
            mask,
            recon_oh,
            recon_ce_mean,
            recon_ce_log_var,
            gamma_d,
            z_mean,
            z_log_var,
            avg=True):
        # NL1 for oh
        NL1 = -(oh * (recon_oh + self.det).log()
                ).sum(1)  # cross entropy loss
        # NL2 for ce
        dist = MultivariateNormal(
            loc=recon_ce_mean,
            covariance_matrix=torch.diag_embed(
                recon_ce_log_var.exp().sqrt()))
        NL2 = (-dist.log_prob(ce.transpose(0, 1)).transpose(0, 1) * mask).sum(1)
        NL = NL1 + NL2

        # KLD_for pi
        KLD1 = -torch.sum(gamma_d *
                          torch.log(self.pi.unsqueeze(0) / gamma_d + self.det), 1)
        # KLD2 for all domains
        logvar_division = self.log_var_d.unsqueeze(0)
        var_division = torch.exp(
            z_log_var.unsqueeze(1) -
            self.log_var_d.unsqueeze(0))
        diff = z_mean.unsqueeze(1) - self.mean_d.unsqueeze(0)
        diff_term = diff.pow(2) / torch.exp(self.log_var_d.unsqueeze(0))
        KLD21 = torch.sum(
            logvar_division + var_division + diff_term,
            2)
        KLD21 = 0.5 * torch.sum(gamma_d * KLD21, 1)
        KLD22 = -0.5 * torch.sum(1 + z_log_var, 1)
        KLD2 = KLD21 + KLD22
        KLD = KLD1 + KLD2

        loss = NL + KLD

        # in training mode, return averaged loss. In testing mode, return
        # individual loss
        if avg:
            return loss.mean()
        else:
            return loss
Example #11
def test_conditional_corrcoeff(corr):
    """
    Test whether the conditional correlation coefficient is computed correctly.
    """
    # With variances 0.1 and 10.0, sqrt(0.1 * 10.0) = 1, so the off-diagonal
    # covariance entry is exactly the correlation coefficient.
    d = MultivariateNormal(
        torch.tensor([0.6, 5.0]), torch.tensor([[0.1, corr], [corr, 10.0]])
    )
    estimated_corr = conditional_corrcoeff(
        density=d,
        condition=torch.ones(1, 2),
        limits=torch.tensor([[-2.0, 3.0], [-70, 90]]),
        resolution=500,
    )[0, 1]

    assert torch.abs(corr - estimated_corr) < 1e-3
Example #12
def test_conditional_pairplot():
    """
    This only tests whether `conditional.pairplot()` runs without errors. If does not
    test its correctness. See `test_conditional_density_2d` for a test on
    `eval_conditional_density`, which is the core building block of
    `conditional.pairplot()`
    """
    d = MultivariateNormal(
        torch.tensor([0.6, 5.0]), torch.tensor([[0.1, 0.99], [0.99, 10.0]])
    )
    _ = conditional_pairplot(
        density=d,
        condition=torch.ones(1, 2),
        limits=torch.tensor([[-1.0, 1.0], [-30, 30]]),
    )
Example #13
    def _pred_point(self, goal_embed, im_shape, min_std=0.03):
        if self._2_point is None:
            return
        
        point_dist = self._2_point(goal_embed[:, 0])
        mu = point_dist[:, :2]
        # softplus keeps the diagonal entries of the Cholesky factor positive
        c1 = F.softplus(point_dist[:, 2])[:, None]
        c2 = point_dist[:, 3][:, None]
        c3 = F.softplus(point_dist[:, 4])[:, None]
        # lower-triangular factor [[c1 + min_std, 0], [c2, c3 + min_std]]
        scale_tril = torch.cat(
            (c1 + min_std, torch.zeros_like(c2), c2, c3 + min_std), dim=1
        ).reshape((-1, 2, 2))
        mu, scale_tril = [x.unsqueeze(1).unsqueeze(1) for x in (mu, scale_tril)]
        point_dist = MultivariateNormal(mu, scale_tril=scale_tril)

        h = torch.linspace(-1, 1, im_shape[0]).reshape((1, -1, 1, 1)).repeat((1, 1, im_shape[1], 1))
        w = torch.linspace(-1, 1, im_shape[1]).reshape((1, 1, -1, 1)).repeat((1, im_shape[0], 1, 1))
        hw = torch.cat((h, w), 3).repeat((goal_embed.shape[0], 1, 1, 1)).to(goal_embed.device)
        return point_dist.log_prob(hw)
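For reference, the linspace/reshape/repeat chain above builds a coordinate grid normalized to [-1, 1]; `torch.meshgrid` produces the same (H, W, 2) grid more directly. A sketch with assumed sizes (the `indexing="ij"` argument requires a recent PyTorch):

import torch

H, W = 4, 6
h, w = torch.meshgrid(
    torch.linspace(-1, 1, H), torch.linspace(-1, 1, W), indexing="ij"
)
hw = torch.stack((h, w), dim=-1)
print(hw.shape)  # torch.Size([4, 6, 2])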
Example #14
    def act(self, state, memory=None):
        # action_mean = self.actor(state)
        action_mean = self.actor(torch.FloatTensor(state).to(device))
        cov_mat = torch.diag(self.action_var).to(device)

        dist = MultivariateNormal(action_mean, cov_mat)
        action = dist.sample()
        action_logprob = dist.log_prob(action)

        if memory:
            memory.states.append(state)
            memory.actions.append(action)
            memory.logprobs.append(action_logprob)

        return action.detach()
Example #15
    def evaluate(self, states, actions):
        #states = torch.stack(states)
        states_input = states.view(-1, *states.shape[-3:])
        actor_critic_input = self.conv(states_input).view(-1, self.size)
        action_mean = self.actor(actor_critic_input)
        action_var = self.action_var.expand_as(action_mean)
        cov_mat = torch.diag_embed(action_var).to(device)
        dist = MultivariateNormal(action_mean, cov_mat)

        actions = actions.view(-1, self.action_dim)
        action_logprobs = dist.log_prob(actions).view(states.shape[:-3])

        dist_entropy = dist.entropy().view(states.shape[:-3])
        state_value = self.critic(actor_critic_input).view(states.shape[:-3])
        return action_logprobs, torch.squeeze(state_value), dist_entropy
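Note the use of `torch.diag_embed` here: `torch.diag` builds a single matrix from a 1-D vector, while `diag_embed` builds one diagonal matrix per batch row, which is what a batched MultivariateNormal needs. A quick comparison:

import torch

action_var = torch.tensor([0.1, 0.2])
print(torch.diag(action_var).shape)         # torch.Size([2, 2])
batched = action_var.expand(5, 2)
print(torch.diag_embed(batched).shape)      # torch.Size([5, 2, 2])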
Example #16
    def log_prob(self, x):
        # invert the flow layer by layer, accumulating log-det-Jacobian terms
        log_prob = 0
        for layer in self.layers[::-1]:
            x, log_prob_change = layer.g(x)
            log_prob = log_prob_change + log_prob

        # evaluate the base density: a standard normal unless a prior was given
        if self.prior is None:
            norm_prior = MultivariateNormal(torch.zeros(self.num_vars).to(x.device),
                                            torch.eye(self.num_vars).to(x.device))

            log_prob += norm_prior.log_prob(x)
        else:
            log_prob += self.prior.log_prob(x)

        return log_prob
Example #17
def _construct_mvn(x, w):
    """
    Constructs a multivariate normal distribution of weighted samples.
    :param x: The samples
    :type x: torch.Tensor
    :param w: The weights
    :type w: torch.Tensor
    :rtype: MultivariateNormal
    """

    # weighted mean and covariance; assumes the weights are normalized to sum to 1
    mean = (x * w.unsqueeze(-1)).sum(0)
    centralized = x - mean
    cov = torch.matmul(w * centralized.t(), centralized)

    return MultivariateNormal(mean, scale_tril=torch.cholesky(cov))
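Hypothetical usage of `_construct_mvn`, with uniform weights normalized to sum to one:

import torch

x = torch.randn(500, 3)            # 500 samples in 3 dimensions
w = torch.full((500,), 1.0 / 500)  # uniform, normalized weights
mvn = _construct_mvn(x, w)
print(mvn.mean)
print(mvn.covariance_matrix)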
Example #18
File: joints.py Project: AlexImmer/VIND
 def __init__(self, data, mu_prior, alpha_prior, W_df_prior, W_prior,
              G_df_prior, rate_prior):
     d = W_prior.shape[0]
     self.W_prior = Wishart({
         'df': torch.tensor([W_df_prior], dtype=torch.float64),
         'W': torch.from_numpy(W_prior.astype(np.float64)),
     })
     self.nu_prior = Gamma(torch.tensor([G_df_prior], dtype=torch.float64),
                           torch.tensor([rate_prior], dtype=torch.float64))
     self.mu_prior = MultivariateNormal(
         loc=mu_prior * torch.ones(d, dtype=torch.float64),
         covariance_matrix=alpha_prior * torch.eye(d, dtype=torch.float64))
     self.data = data
Example #19
    def get_dist(self, duration=None):
        """
        Get the :class:`~pyro.distributions.GaussianHMM` distribution that corresponds
        to ``obs_dim``-many independent Matern GPs.

        :param int duration: Optional size of the time axis ``event_shape[0]``.
            This is required when sampling from homogeneous HMMs whose parameters
            are not expanded along the time axis.
        """
        trans_matrix, process_covar = self.kernel.transition_matrix_and_covariance(dt=self.dt)
        trans_dist = MultivariateNormal(self.obs_matrix.new_zeros(self.obs_dim, 1, self.kernel.state_dim),
                                        process_covar.unsqueeze(-3))
        trans_matrix = trans_matrix.unsqueeze(-3)
        return dist.GaussianHMM(self._get_init_dist(), trans_matrix, trans_dist,
                                self.obs_matrix, self._get_obs_dist(), duration=duration)
Example #20
    def __init__(self, **kwargs):
        super(CPGGAN, self).__init__(**kwargs)
        if self.mode == 'train':
            # path to numpy arrays containing the calculated mean and cov matrices for calculating a multivariate gaussian
            path_to_lm_mean = kwargs.get('lm_mean', ARRAY_LANDMARKS_28_MEAN)
            path_to_lm_cov = kwargs.get('lm_cov', ARRAY_LANDMARKS_28_COV)
            path_to_lr_mean = kwargs.get('lr_mean', ARRAY_LOWRES_4_MEAN)
            path_to_lr_cov = kwargs.get('lr_cov', ARRAY_LOWRES_4_COV)

            # ==================================================
            # Currently only preparation for extensions
            # ==================================================
            # gaussian distribution of our landmarks
            self.landmarks_mean = np.load(path_to_lm_mean)
            self.landmarks_cov = np.load(path_to_lm_cov)
            self.landmarks_mean = torch.from_numpy(self.landmarks_mean)
            self.landmarks_cov = torch.from_numpy(self.landmarks_cov)
            self.distribution_landmarks = MultivariateNormal(
                loc=self.landmarks_mean.type(torch.float64),
                covariance_matrix=self.landmarks_cov.type(torch.float64))
            # gaussian distribution of our low res pixel map
            self.lowres_mean = np.load(path_to_lr_mean)
            self.lowres_cov = np.load(path_to_lr_cov)
            self.lowres_mean = torch.from_numpy(self.lowres_mean)
            self.lowres_cov = torch.from_numpy(self.lowres_cov)
            self.distribution_lowres = MultivariateNormal(
                loc=self.lowres_mean.type(torch.float64),
                covariance_matrix=self.lowres_cov.type(torch.float64))
            # static noise used for validation
            self.static_landmarks = 2 * (self.distribution_landmarks.sample(
                (self.batch_size, )).type(torch.float32) - 0.5)
            self.static_lowres = 2 * (self.distribution_lowres.sample(
                (self.batch_size, )).type(torch.float32) - 0.5)

        # Static noise for anonymization
        self.anonymization_noise = self.noise(1)
Example #21
    def evaluate(self, state, action):
        action_mean = torch.squeeze(self.actor(state))

        action_var = self.action_var.expand_as(action_mean)

        cov_mat = torch.diag_embed(action_var).to(device)
        #cov_mat = torch.diag(action_var).to(device)

        dist = MultivariateNormal(action_mean, cov_mat)

        action_logprobs = dist.log_prob(torch.squeeze(action))
        dist_entropy = dist.entropy()
        state_value = self.critic(state)

        return action_logprobs, torch.squeeze(state_value), dist_entropy
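As a side note, the `dist.entropy()` used above has a closed form for Gaussians, H = 0.5 * log((2*pi*e)^d * det(cov)); a minimal check:

import math
import torch
from torch.distributions import MultivariateNormal

cov = torch.diag(torch.tensor([0.1, 0.2]))
d = MultivariateNormal(torch.zeros(2), cov)
closed_form = 0.5 * torch.log((2 * math.pi * math.e) ** 2 * torch.det(cov))
print(torch.allclose(d.entropy(), closed_form))  # True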
Example #22
File: gp.py Project: jamestwebber/pyro
 def forecast(self, targets, dts):
     """
     :param torch.Tensor targets: A 2-dimensional tensor of real-valued targets
         of shape `(T, obs_dim)`, where `T` is the length of the time series and `obs_dim`
         is the dimension of the real-valued targets at each time step. These
         represent the training data that are conditioned on for the purpose of making
         forecasts.
     :param torch.Tensor dts: A 1-dimensional tensor of times to forecast into the future,
         with zero corresponding to the time of the final target `targets[-1]`.
     :returns torch.distributions.MultivariateNormal: Returns a predictive MultivariateNormal
         distribution with batch shape `(S,)` and event shape `(obs_dim,)`, where `S` is the size of `dts`.
     """
     filtering_state = self._filter(targets)
     predicted_mean, predicted_covar = self._forecast(dts, filtering_state)
     return MultivariateNormal(predicted_mean, predicted_covar)
Example #23
    def act(self, state, memory):
        action_mean = self.actor(state)
        cov_mat = torch.diag(self.action_var).to(device)

        dist = MultivariateNormal(action_mean, cov_mat)

        # reparameterized noise; equivalent to dist.sample() when
        # self.action_var == self.action_std ** 2
        action = action_mean + self.action_std * torch.randn(action_mean.shape)
        action_logprob = dist.log_prob(action)

        memory.states.append(state)
        memory.actions.append(action)
        memory.logprobs.append(action_logprob)

        return action.detach()
Example #24
    def act(self, ob, state, memory):
        latent = self.encoder(ob)
        state = torch.cat((latent, state), 1)
        action_mean = self.actor(state)
        dist = MultivariateNormal(action_mean,
                                  torch.diag(self.action_var).to(device))
        action = dist.sample()
        action_logprob = dist.log_prob(action)

        memory.states.append(state)
        # memory.obs.append(ob)
        memory.actions.append(action)
        memory.logprobs.append(action_logprob)

        return action.detach()
Example #25
    def log_ratio(self, x, px_mean, px_var, qz_m, qz_v, z, return_full=False):
        if self.learn_var:
            log_px_z = (Normal(px_mean, torch.sqrt(px_var)).log_prob(
                x.repeat(px_mean.shape[0], 1, 1)).sum(dim=-1).view(
                    (px_mean.shape[0], -1)))
            log_pz = (Normal(torch.zeros_like(qz_m),
                             torch.ones_like(qz_m)).log_prob(z).sum(dim=-1))
            if not self.linear_encoder:
                log_qz_given_x = Normal(qz_m,
                                        qz_v.sqrt()).log_prob(z).sum(dim=-1)
            else:
                log_qz_given_x = MultivariateNormal(qz_m, qz_v).log_prob(
                    z)  # No need to sum over latent dim
            log_pxz = log_px_z + log_pz

        else:
            zx = torch.cat([z, x.repeat(px_mean.shape[0], 1, 1)], dim=-1)
            reshape_dim = x.shape[-1] + z.shape[-1]

            if not self.linear_encoder:
                log_qz_given_x = Normal(qz_m,
                                        qz_v.sqrt()).log_prob(z).sum(dim=-1)
            else:
                log_qz_given_x = MultivariateNormal(qz_m, qz_v).log_prob(z)
            log_pxz = self.log_normal_full(
                zx.view((-1, reshape_dim)),
                torch.zeros_like(zx.view((-1, reshape_dim))),
                self.log_det_pxz,
                self.inv_sqrt_pxz,
            ).view((px_mean.shape[0], -1))

        log_ratio = log_pxz - log_qz_given_x
        if return_full:
            return log_ratio, log_pxz, log_qz_given_x
        else:
            return log_ratio
Example #26
    def act(self, state, memory):
        action_mean = self._action_mean(state)

        cov_mat = torch.diag(self.action_var).to(device)
        dist = MultivariateNormal(action_mean, cov_mat)

        action = dist.sample()
        # action = F.softmax(action.reshape(2,-1)).reshape(1,-1)
        action_logprob = dist.log_prob(action)

        memory.states.append(state)
        memory.actions.append(action)
        memory.logprobs.append(action_logprob)

        return action.detach()
Example #27
def test_multivariate_normal() -> None:
    num_samples = 2000
    dim = 2

    mu = np.arange(0, dim) / float(dim)

    L_diag = np.ones((dim, ))
    L_low = 0.1 * np.ones((dim, dim)) * np.tri(dim, k=-1)
    L = np.diag(L_diag) + L_low
    Sigma = L.dot(L.transpose())

    distr = MultivariateNormal(loc=torch.Tensor(mu),
                               scale_tril=torch.Tensor(L))

    samples = distr.sample((num_samples, ))

    mu_hat, L_hat = maximum_likelihood_estimate_sgd(
        MultivariateNormalOutput(dim=dim),
        samples,
        init_biases=None,  # TODO: rework biases to use them in the multivariate case
        learning_rate=0.01,
        num_epochs=10,
    )

    distr = MultivariateNormal(loc=torch.tensor(mu_hat),
                               scale_tril=torch.tensor(L_hat))

    Sigma_hat = distr.covariance_matrix.numpy()

    assert np.allclose(
        mu_hat, mu, atol=0.1,
        rtol=0.1), f"mu did not match: mu = {mu}, mu_hat = {mu_hat}"
    assert np.allclose(
        Sigma_hat, Sigma, atol=0.1, rtol=0.1
    ), f"Sigma did not match: sigma = {Sigma}, sigma_hat = {Sigma_hat}"
Example #28
    def act(self, state, memory, stochastic=True):
        action_mean = self.actor(state)

        cov_mat = torch.diag(self.action_var).to(device)
        dist = MultivariateNormal(action_mean, cov_mat)
        action = dist.sample()
        if not stochastic:
            action = action_mean
        # evaluate the log-prob of the action actually stored and returned
        action_logprob = dist.log_prob(action)
            
        memory.states.append(state)
        memory.actions.append(action)
        memory.logprobs.append(action_logprob)
        
        return action.detach()
Example #29
def _construct_mvn(x: torch.Tensor, w: torch.Tensor, scale=1.0):
    """
    Constructs a multivariate normal distribution of weighted samples.
    """

    mean = (x * w.unsqueeze(-1)).sum(0)
    centralized = x - mean
    cov = torch.matmul(w * centralized.t(), centralized)

    # fall back to a diagonal factor when the covariance is singular
    if cov.det() == 0.0:
        chol = cov.diag().sqrt().diag()
    else:
        chol = cov.cholesky()

    return MultivariateNormal(mean, scale_tril=scale * chol)
Example #30
    def act(self, state, opponent_state):
        if self.has_continuous_action_space:
            pre_mean, pre_sigma = self.om(opponent_state)
            # print('sigma: ', pre_sigma)
            pre_var = pre_sigma**2
            pre_var = pre_var.repeat(1, 2).to(device)
            pre_mat = torch.diag_embed(pre_var).to(device)
            pre_dist = MultivariateNormal(pre_mean, pre_mat)
            pre_action = pre_dist.sample()
            pre_action = pre_action.clamp(-1, 1)
            action_mean, action_sigma = self.actor(state, pre_action[0])
            action_var = action_sigma**2
            action_var = action_var.repeat(1, 2).to(device)
            cov_mat = torch.diag_embed(action_var).to(device)
            dist = MultivariateNormal(action_mean, cov_mat)
        else:
            action_probs = self.actor(state)
            dist = Categorical(action_probs)

        action = dist.sample()
        if self.has_continuous_action_space:
            # clamping only makes sense for continuous actions
            action = action.clamp(-1, 1)
        action_logprob = dist.log_prob(action)

        # note: `pre_action` is only defined on the continuous branch above
        return action.detach(), action_logprob.detach(), pre_action.detach()