示例#1
0
class VariationalLoss(nn.Module):
    """Variational free-energy loss for a normalizing flow.

    Combines the base-distribution log-density of the initial sample `z0`,
    the negated target energy at the transformed sample `z`, and the
    accumulated log-determinant of the flow's Jacobians.
    """

    def __init__(self, distribution: TargetDistribution):
        super().__init__()
        # Callable returning the target energy at z.
        self.distr = distribution
        # Standard 2-D Gaussian base distribution of the flow.
        self.base_distr = MultivariateNormal(torch.zeros(2), torch.eye(2))

    def forward(self, z0: Tensor, z: Tensor, sum_log_det_J: float) -> float:
        """Return the batch-mean variational loss."""
        log_q0 = self.base_distr.log_prob(z0)
        neg_energy = -self.distr(z)
        loss = log_q0 - neg_energy - sum_log_det_J
        return loss.mean()
    def evaluate(self, states, actions):
        """Score `actions` under a Gaussian policy derived analytically from
        the first linear layer of `self.agent`.

        Returns (action log-probs, squeezed action values, entropies).
        """
        action_means = -states * self.agent[0].weight[:, 0] / self.agent[
            0].weight[:, 1]

        # action_var = torch.full((action_dim,), 0.5 * self.tau)
        # NOTE(review): `/ w * w` evaluates left-to-right, so the two weight
        # factors cancel and this is just 0.5 * self.tau broadcast to the
        # weight's shape — confirm whether `/ (w * w)` was intended.
        action_var = 0.5 * self.tau / self.agent[0].weight[:, 1] * self.agent[
            0].weight[:, 1]
        action_var = action_var.expand_as(action_means)

        # Diagonal covariance; `device` is a module-level global here.
        cov_mat = torch.diag_embed(action_var).to(device)

        dist = MultivariateNormal(action_means, cov_mat)

        action_logprobs = dist.log_prob(actions)
        action_values = self.agent(
            torch.cat((states, actions), dim=1).squeeze())
        dist_entropy = dist.entropy()

        return action_logprobs, torch.squeeze(action_values), dist_entropy
示例#3
0
    def get_action(self, state):
        state = torch.FloatTensor(state).unsqueeze(0)
        mean, log_std = self.forward(state)
        std = log_std.exp()

        loc = torch.zeros(mean.size())
        scale = torch.ones(mean.size())
        if mean.size()[1] == 1:
            normal = Normal(loc, scale)
            z = normal.sample()
        else:
            scale = torch.diag_embed(scale)
            mvn = MultivariateNormal(loc, scale)
            z = mvn.sample()

        action = torch.tanh(mean + std * z)

        action = action.cpu()  # .detach().cpu().numpy()
        return action[0]
示例#4
0
    def evaluate(self, state, action):
        """Return (log-probs of `action`, squeezed critic values, entropies)
        under a diagonal Gaussian centered on the two-headed actor output."""
        # import pdb; pdb.set_trace()
        x = F.tanh(self.affine1(state))
        x = F.tanh(self.affine2(x))
        # Two separate heads produce the (alpha, beta) mean components.
        alpha = self.alpha_action_mean(x)
        beta = self.beta_action_mean(x)
        action_mean = torch.cat((alpha, beta), dim=1)
        # action_mean = torch.squeeze(x)

        # State-independent diagonal covariance from the stored variance;
        # `device` is a module-level global here.
        action_var = self.action_var.expand_as(action_mean)  #action_log_std
        cov_mat = torch.diag_embed(action_var).to(device)
        dist = MultivariateNormal(action_mean, cov_mat)

        # action_logprobs = dist.log_prob(torch.squeeze(action))
        action_logprobs = dist.log_prob(action)
        dist_entropy = dist.entropy()
        state_value = self.critic(state)
        # import pdb; pdb.set_trace()
        return action_logprobs, torch.squeeze(state_value), dist_entropy
示例#5
0
    def test_shapes(self):
        """Broadcasting check: KL between a (B1, B2)-batched Normal q and a
        (K1, K2)-batched MVN p must match the flattened pairwise computation
        reshaped back to (B1, B2, K1, K2)."""
        B1 = 100
        B2 = 50
        K1 = 20
        K2 = 4
        D = 16

        # Random SPD covariance via scale @ scale^T plus diagonal jitter.
        scale = torch.randn((K1, K2, D, D))
        cov = scale @ scale.transpose(-2, -1) + torch.diag(0.1 * torch.ones(D))
        p = MultivariateNormal(torch.randn((K1, K2, D)), covariance_matrix=cov)
        # Flattened copy of p for the pairwise reference computation.
        p_ = MultivariateNormal(loc=p.loc.view(-1, D),
                                scale_tril=p.scale_tril.view(-1, D, D))

        q = Normal(loc=torch.randn((B1, B2, D)), scale=torch.rand((B1, B2, D)))
        q_ = Normal(loc=q.loc.view(-1, D), scale=q.scale.view(-1, D))

        actual_loss = pt.ops.kl_divergence(q, p)
        reference_loss = pt.ops.kl_divergence(q_, p_).view(B1, B2, K1, K2)
        np.testing.assert_allclose(actual_loss, reference_loss, rtol=1e-4)
示例#6
0
    def test_against_multivariate_multivariate(self):
        """A diagonal-covariance MVN q expressed as a factorized Normal must
        give the same KL against MVN p as torch's native MVN-vs-MVN KL."""
        B = 500
        K = 100
        D = 16

        # Random SPD covariance for p (K-batch).
        scale = torch.randn((K, D, D))
        cov = scale @ scale.transpose(1, 2) + torch.diag(0.1 * torch.ones(D))
        p = MultivariateNormal(torch.randn((K, D)), covariance_matrix=cov)

        # q: diagonal scale_tril with a broadcast dim against the K-batch.
        q = MultivariateNormal(loc=torch.randn((B, 1, D)),
                               scale_tril=torch.Tensor(
                                   np.broadcast_to(np.diag(np.random.rand(D)),
                                                   (B, 1, D, D))))
        # The same q re-expressed as an independent Normal over the diagonal.
        q_ = Normal(loc=q.loc[:, 0],
                    scale=pt.ops.losses.loss._batch_diag(q.scale_tril[:, 0]))

        actual_loss = pt.ops.kl_divergence(q_, p)
        reference_loss = kl_divergence(q, p)
        np.testing.assert_allclose(actual_loss, reference_loss, rtol=1e-4)
示例#7
0
    def evaluate(self, state, action):
        """Score per-vehicle actions under Gaussians centered on the actor
        outputs; returns (log-probs, squeezed state values, entropies)."""
        action_mean = []
        state_value = []
        no_vehicles = len(state)
        # Separate actor/critic forward pass per vehicle; each state[i] holds
        # two components that are fed to the networks in reversed order.
        for idx_1 in range(no_vehicles):
            action_mean.append(self.actor(state[idx_1][1], state[idx_1][0]))
            state_value.append(self.critic(state[idx_1][1], state[idx_1][0]))
        action_mean = torch.stack(action_mean).view(-1, 1, 2)
        state_value = torch.stack(state_value).view(-1, 1, 1)

        #action_mean = self.actor(state)
        # Single shared diagonal covariance (torch.diag gives one (2, 2)
        # matrix broadcast over the vehicle batch).
        dist = MultivariateNormal(torch.squeeze(action_mean),
                                  torch.diag(self.action_var))

        action_logprobs = dist.log_prob(torch.squeeze(action))
        dist_entropy = dist.entropy()
        #state_value = self.critic(state)

        return action_logprobs, torch.squeeze(state_value), dist_entropy
示例#8
0
文件: gauss.py 项目: baofff/BiSM
class GaussDataset(Dataset):
    """Dataset yielding i.i.d. draws from a multivariate Gaussian.

    Every item access returns a fresh sample from N(mean, cov); the index
    is ignored, so this is effectively an infinite sampler truncated to a
    nominal length of `n`.
    """

    def __init__(self, n, mean, cov):
        self.n = n
        self.dist = MultivariateNormal(loc=mean, covariance_matrix=cov)

    def __len__(self):
        return self.n

    def __getitem__(self, item):
        # `item` is deliberately unused: each access is a new random draw.
        return self.dist.sample()
示例#9
0
 def dist(self, device):
     """
     The distribution induced by the gen.

     Returns a MultivariateNormal with mean equal to the generator's bias
     and covariance W W^T + sigma^2 I — the marginal of a linear Gaussian
     generator with isotropic observation noise.
     """
     W = self.gen.g.weight.data
     WtW = W @ W.t()
     # Add isotropic noise variance sigma^2 (sigma stored as log-sigma).
     cov = WtW + torch.eye(
         WtW.size(0)).to(device) * self.gen.logsigma.exp()**2
     mu = self.gen.g.bias
     return MultivariateNormal(mu, cov)
示例#10
0
    def evaluate(self, state, state_randomized, action, randomize):
        """Score `action` under the current policy and value the state.

        The domain-randomized observation is used when `randomize` is
        truthy. Returns (log-probs, squeezed state values, entropies).
        """
        net_input = state_randomized if randomize else state
        actor_output, critic_output = self.forward(net_input)

        mean = torch.squeeze(actor_output)
        # Diagonal Gaussian built from the state-independent action variance;
        # `device` is a module-level global.
        covariance = torch.diag_embed(self.action_var.expand_as(mean)).to(device)
        policy_dist = MultivariateNormal(mean, covariance)

        logprobs = policy_dist.log_prob(torch.squeeze(action))
        entropy = policy_dist.entropy()

        return logprobs, torch.squeeze(critic_output), entropy
示例#11
0
    def get_training_params(self,frame,mes, action):
        """Recompute log-probs, state values and entropies for stored
        (frame, measurement, action) rollout lists."""
        frame = torch.stack(frame)
        mes = torch.stack(mes)
        # Drop singleton dimensions introduced by stacking single samples.
        if len(list(frame.size())) > 4:
            frame = torch.squeeze(frame)
        if len(list(mes.size())) > 2:
            mes = torch.squeeze(mes)

        action = torch.stack(action)

        mean = self.actor(frame,mes)
        # Fixed diagonal covariance from the stored action variance;
        # `device` is a module-level global here.
        action_expanded = self.action_var.expand_as(mean)
        cov_matrix = torch.diag_embed(action_expanded).to(device)

        gauss_dist = MultivariateNormal(mean,cov_matrix)
        action_log_prob = gauss_dist.log_prob(action).to(device)
        entropy = gauss_dist.entropy().to(device)
        state_value = torch.squeeze(self.critic(frame,mes)).to(device)
        return action_log_prob, state_value, entropy
示例#12
0
 def _get_init_dist(self):
     """Initial state distribution: zero mean with a block-diagonal
     covariance — the GP's stationary covariance in the leading block and
     the diagonal init-noise variances in the trailing block."""
     loc = self.z_trans_matrix.new_zeros(self.full_state_dim)
     covar = self.z_trans_matrix.new_zeros(self.full_state_dim,
                                           self.full_state_dim)
     covar[:self.full_gp_state_dim, :self.
           full_gp_state_dim] = block_diag_embed(
               self.kernel.stationary_covariance())
     covar[self.full_gp_state_dim:,
           self.full_gp_state_dim:] = self.init_noise_scale_sq.diag_embed()
     return MultivariateNormal(loc, covar)
示例#13
0
    def act(self, state):
        """Sample an action and its log-prob from the current policy."""
        # state1, state2 = state
        # state1 = torch.FloatTensor(state1).to(device)
        # state2 = torch.FloatTensor(state2).to(device)
        if self.has_continuous_action_space:
            # Actor outputs mean and per-dimension sigma; build a diagonal
            # Gaussian with an added batch dimension.
            action_mean, action_sigma = self.actor(state)
            action_var = action_sigma**2
            cov_mat = torch.diag_embed(action_var).unsqueeze(dim=0)
            dist = MultivariateNormal(action_mean, cov_mat)
            # print(dist)
        else:
            action_probs = self.actor(state)
            dist = Categorical(action_probs)

        action = dist.sample()
        # NOTE(review): the log-prob below is taken of the *clamped* action,
        # not the raw sample, and the clamp is also applied to discrete
        # Categorical samples — confirm both are intended.
        action = action.clamp(-1, 1)
        action_logprob = dist.log_prob(action)

        return action.detach(), action_logprob.detach()
示例#14
0
    def _move(self, drift, index):
        """Move a walker.

        Args:
            drift (torch.tensor): drift velocity
            index (int): indx of the electron to move

        Returns:
            torch.tensor: position of the walkers
        """

        d = drift.view(self.nwalkers,
                       self.nelec, self.ndim)

        # Isotropic Gaussian proposal noise.
        # NOTE(review): the covariance is sqrt(step_size) * I, so the noise
        # variance is sqrt(step_size) rather than step_size — confirm this
        # matches the intended discretization.
        mv = MultivariateNormal(torch.zeros(self.ndim), np.sqrt(
            self.step_size) * torch.eye(self.ndim))

        # Drift step for the selected electron plus one noise draw per walker.
        return self.step_size * d[range(self.nwalkers), index, :] \
            + mv.sample((self.nwalkers, 1)).squeeze()
示例#15
0
 def forward(self, state):
     """Policy head: map `state` to a sampled action, its log-prob and the
     mean entropy of the action distribution."""
     output_1 = F.relu(self.linear1(state))
     output_2 = F.relu(self.linear2(output_1))
     mu = 2 * torch.sigmoid(self.mu(output_2))  # mean squashed into (0, 2)
     sigma = F.relu(
         self.sigma(output_2)
     ) + 0.001  # small offset keeps the scale strictly positive
     #cov_mat = torch.diag(self.action_var).to(device)
     # NOTE(review): diag_embed turns both mu and sigma into 2-D matrices;
     # MultivariateNormal then treats the rows of mu as a batch of means —
     # confirm this batching is actually intended.
     mu = torch.diag_embed(mu).to(device)
     sigma = torch.diag_embed(sigma).to(device)  # change to 2D
     dist = MultivariateNormal(
         mu,
         sigma)  # N(mu, sigma^2); sigma treated as a fixed hyperparameter
     #distribution = Categorical(F.softmax(output, dim=-1))
     entropy = dist.entropy().mean()
     action = dist.sample()
     action_logprob = dist.log_prob(action)
     return action.detach(
     ), action_logprob, entropy  # distribution .detach()
    def select_action(self,state,actor):
        """Sample one action per vehicle from `actor` and record states,
        actions and log-probs in each vehicle's memory.

        Returns a list of flattened numpy actions, one per vehicle.
        """
        # Fixed exploration variance: std 0.6 per action dimension.
        self.action_var = torch.full((2,), 0.6*0.6).to(device)   #manually change action_dim action_std
        no_vehicles = len(state)
        action_list = []
        for idx_1 in range(no_vehicles):
            # state1 = (target, other_vehicles); the actor takes them reversed.
            state1 = state[idx_1]
            target1 = state1[0]
            other_vehicles1 = state1[1] 
            action_mean = actor(other_vehicles1,target1)
            dist = MultivariateNormal(action_mean, torch.diag(self.action_var).to(device))      ##ATTENTION: manually change variance in torch.diag(var)
            action = dist.sample()
            action_logprob = dist.log_prob(action)
            self.agentmemory.memory_list[idx_1].states.append(state1)
            self.agentmemory.memory_list[idx_1].actions.append(action)
            self.agentmemory.memory_list[idx_1].logprobs.append(action_logprob)

            action_list.append(action.detach().cpu().data.numpy().flatten())
            
        return action_list   
示例#17
0
    def y_dist(self):
        """
        Returns the current Y-distribution.

        Scalar observations (obs_ndim < 2) yield a univariate Normal built
        from the leading mean/covariance entries; otherwise a full
        MultivariateNormal parameterized by the Cholesky factor of ycov.

        :rtype: Normal|MultivariateNormal
        """
        if self._model.obs_ndim < 2:
            return Normal(self.ymean[..., 0], self.ycov[..., 0, 0].sqrt())

        # torch.linalg.cholesky replaces the deprecated/removed
        # torch.cholesky; both return the lower-triangular factor.
        return MultivariateNormal(self.ymean,
                                  scale_tril=torch.linalg.cholesky(self.ycov))
    def distribution(
        self, distr_args, scale: Optional[torch.Tensor] = None
    ) -> Distribution:
        """Build the output distribution from (loc, scale_tril) arguments.

        When `scale` is given, the Gaussian is passed through an affine
        transform with zero shift, e.g. to undo input scaling.
        """
        loc, scale_tril = distr_args
        base = MultivariateNormal(loc=loc, scale_tril=scale_tril)
        if scale is None:
            return base
        return TransformedDistribution(base, [AffineTransform(loc=0, scale=scale)])
示例#19
0
    def marginal_posterior_divergence(self, z, mean, logv, num_samples):
        """Estimate sum_b [log q(z_b) - log p(z_b)]: the divergence between
        the aggregated (marginal) posterior q and the standard-normal prior
        p, via a minibatch-weighted log-sum-exp estimator.

        Args:
            z: latent samples, shape (batch, n).
            mean, logv: per-sample posterior means and log-variances.
            num_samples: total dataset size used in the estimator weights.
        """
        batch_size, n = mean.shape
        diag = to_cuda_var(torch.eye(n).repeat(1, 1, 1))

        logq_zb_lst = []
        logp_zb_lst = []
        for b in range(batch_size):
            zb = z[b, :].unsqueeze(0)
            mu_b = mean[b, :].unsqueeze(0)
            logv_b = logv[b, :].unsqueeze(0)
            diag_b = to_cuda_var(torch.eye(n).repeat(1, 1, 1))
            # Diagonal covariance from the b-th posterior's variances.
            cov_b = torch.exp(logv_b).unsqueeze(dim=2) * diag_b

            # removing b-th mean and logv
            zr = zb.repeat(batch_size - 1, 1)
            mu_r = torch.cat((mean[:b, :], mean[b + 1:, :]))
            logv_r = torch.cat((logv[:b, :], logv[b + 1:, :]))
            diag_r = to_cuda_var(torch.eye(n).repeat(batch_size - 1, 1, 1))
            cov_r = torch.exp(logv_r).unsqueeze(dim=2) * diag_r

            # E[log q(zb)] = - H(q(z))
            zb_xb_posterior_pdf = MultivariateNormal(mu_b, cov_b)
            logq_zb_xb = zb_xb_posterior_pdf.log_prob(zb)

            # Density of zb under every *other* sample's posterior.
            zb_xr_posterior_pdf = MultivariateNormal(mu_r, cov_r)
            logq_zb_xr = zb_xr_posterior_pdf.log_prob(zr)

            # Stratified minibatch weights for the marginal-q estimate.
            yb1 = logq_zb_xb - torch.log(
                to_cuda_var(torch.tensor(num_samples).float()))
            yb2 = logq_zb_xr + torch.log(
                to_cuda_var(
                    torch.tensor((num_samples - 1) /
                                 ((batch_size - 1) * num_samples)).float()))
            yb = torch.cat([yb1, yb2], dim=0)
            logq_zb = torch.logsumexp(yb, dim=0)

            # E[log p(zb)]
            zb_prior_pdf = MultivariateNormal(to_cuda_var(torch.zeros(n)),
                                              diag)
            logp_zb = zb_prior_pdf.log_prob(zb)

            logq_zb_lst.append(logq_zb)
            logp_zb_lst.append(logp_zb)

        logq_zb = torch.stack(logq_zb_lst, dim=0)
        logp_zb = torch.stack(logp_zb_lst, dim=0).squeeze(-1)

        return (logq_zb - logp_zb).sum()
示例#20
0
    def act(self, state, memory):
        """Sample an action, record the transition in `memory`, and return
        the detached action tensor."""
        x = F.tanh(self.affine1(state))
        x = F.tanh(self.affine2(x))
        # Separate heads for the (alpha, beta) action mean components.
        alpha = self.alpha_action_mean(x)
        beta = self.beta_action_mean(x)
        action_mean = torch.cat((alpha, beta), dim=1)

        # Single shared (unbatched) diagonal covariance; `device` is a
        # module-level global here.
        cov_mat = torch.diag(self.action_var).to(device)
        dist = MultivariateNormal(action_mean, cov_mat)

        action = dist.sample()
        # action = F.softmax(action.reshape(2,-1)).reshape(1,-1)
        action_logprob = dist.log_prob(action)

        memory.states.append(state)
        memory.actions.append(action)
        memory.logprobs.append(action_logprob)

        return action.detach()
    def act(self, state):
        """Sample an action without tracking gradients; returns the numpy
        action and its log-prob tensor."""
        # state = torch.from_numpy(state).float().to(device)
        # action_probs = self.old_actor.forward(state)
        # dist = Categorical(action_probs)
        # action = dist.sample()
        # print(state)

        with torch.no_grad():
            state = torch.from_numpy(state).float().to(device)
            action_probs = self.action_layer(state)
            # Fixed diagonal covariance from the stored action variance.
            cov_mat = torch.diag(self.action_var).to(device)
            # print(cov_mat)
            # print(action_probs)
            dist = MultivariateNormal(action_probs, cov_mat)
            action = dist.sample()
            # print(action, action_probs)
            log_prob = dist.log_prob(action)
            # Keep the distribution around for callers that inspect it.
            self.dist = dist
            return action.detach().cpu().numpy(), log_prob
 def forward(self, state):
     """LSTM policy head: map `state` to a sampled action, its log-prob and
     the mean entropy of the action distribution."""
     output_1 = F.relu(self.linear1(state))
     output_2 = F.relu(self.linear2(output_1))
     # LSTM layer expects a leading sequence dimension.
     output_2 = output_2.unsqueeze(0)
     output_3, self.hidden_cell = self.LSTM_layer_3(
         output_2)  #,self.hidden_cell
     a, b, c = output_3.shape
     # Flatten the sequence/batch dims before the final linear layer.
     output_4 = F.relu(self.linear4(output_3.view(-1, c)))  #
     mu = 2 * torch.tanh(self.mu(output_4))  # mean squashed into (-2, 2)
     sigma = F.relu(self.sigma(output_4)) + 0.001
     # NOTE(review): diag_embed turns both mu and sigma into 2-D matrices;
     # MultivariateNormal then treats the rows of mu as a batch of means —
     # confirm this batching is actually intended.
     mu = torch.diag_embed(mu).to(device)
     sigma = torch.diag_embed(sigma).to(device)  # change to 2D
     dist = MultivariateNormal(mu, sigma)  # N(mu, sigma^2)
     entropy = dist.entropy().mean()
     action = dist.sample()
     action_logprob = dist.log_prob(action)
     return action, action_logprob, entropy
示例#23
0
    def get_action_distribution(self, states, in_train=True):
        """Return the Gaussian action distribution for ``states``.

        In-train mode queries the current policy with gradients enabled;
        otherwise the previous policy is evaluated under ``no_grad`` and
        restored to train mode afterwards.
        """
        if in_train:
            mean = self.policy(states.to(self.device))
            return MultivariateNormal(mean, self.action_variances)

        frozen = self.prev_policy
        frozen.eval()
        with torch.no_grad():
            mean = frozen(states.to(self.device))
            dist = MultivariateNormal(mean, self.action_variances)
        frozen.train()
        return dist
示例#24
0
def test_c2st_multi_round_snl_on_linearGaussian(set_seed):
    """Test SNL on linear Gaussian, comparing to ground truth posterior via c2st.

    Args:
        set_seed: fixture for manual seeding
    """

    num_dim = 2
    x_o = zeros((1, num_dim))
    num_samples = 500

    # likelihood_mean will be likelihood_shift+theta
    likelihood_shift = -1.0 * ones(num_dim)
    likelihood_cov = 0.3 * eye(num_dim)

    # Standard-normal prior makes the true posterior analytically available.
    prior_mean = zeros(num_dim)
    prior_cov = eye(num_dim)
    prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov)
    gt_posterior = true_posterior_linear_gaussian_mvn_prior(
        x_o[0], likelihood_shift, likelihood_cov, prior_mean, prior_cov)
    target_samples = gt_posterior.sample((num_samples, ))

    simulator = lambda theta: linear_gaussian(theta, likelihood_shift,
                                              likelihood_cov)

    simulator, prior = prepare_for_sbi(simulator, prior)
    inference = SNL(
        prior,
        show_progress_bars=False,
    )

    # Round 1: train on prior simulations, build an MCMC posterior.
    theta, x = simulate_for_sbi(simulator,
                                prior,
                                750,
                                simulation_batch_size=50)
    _ = inference.append_simulations(theta, x).train()
    posterior1 = inference.build_posterior(mcmc_method="slice_np_vectorized",
                                           mcmc_parameters={
                                               "thin": 5,
                                               "num_chains": 20
                                           }).set_default_x(x_o)

    # Round 2: simulate from the round-1 posterior and retrain.
    theta, x = simulate_for_sbi(simulator,
                                posterior1,
                                750,
                                simulation_batch_size=50)
    _ = inference.append_simulations(theta, x).train()
    posterior = inference.build_posterior().copy_hyperparameters_from(
        posterior1)

    samples = posterior.sample(sample_shape=(num_samples, ),
                               mcmc_parameters={"thin": 3})

    # Check performance based on c2st accuracy.
    check_c2st(samples, target_samples, alg="multi-round-snl")
示例#25
0
def test_training_and_mcmc_on_device(method, model, device):
    """Test training on devices.

    This test does not check training speeds.

    """
    device = process_device(device)

    num_dim = 2
    num_samples = 10
    num_simulations = 500
    max_num_epochs = 5

    x_o = zeros(1, num_dim)
    likelihood_shift = -1.0 * ones(num_dim)
    likelihood_cov = 0.3 * eye(num_dim)

    prior_mean = zeros(num_dim)
    prior_cov = eye(num_dim)
    prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov)

    def simulator(theta):
        return linear_gaussian(theta, likelihood_shift, likelihood_cov)

    # Choose the density estimator / MCMC configuration per inference method.
    if method == SNPE:
        kwargs = dict(density_estimator=utils.posterior_nn(model=model), )
        mcmc_kwargs = dict(
            sample_with_mcmc=True,
            mcmc_method="slice_np",
        )
    elif method == SNLE:
        kwargs = dict(density_estimator=utils.likelihood_nn(model=model), )
        mcmc_kwargs = dict(mcmc_method="slice")
    elif method == SNRE:
        kwargs = dict(classifier=utils.classifier_nn(model=model), )
        mcmc_kwargs = dict(mcmc_method="slice_np_vectorized", )
    else:
        raise ValueError()

    inferer = method(prior, show_progress_bars=False, device=device, **kwargs)

    proposals = [prior]

    # Test for two rounds.
    # NOTE(review): both rounds simulate with `proposal=prior`, not
    # `proposals[-1]`, so the second round is not truly sequential — confirm.
    for r in range(2):
        theta, x, = simulate_for_sbi(simulator,
                                     proposal=prior,
                                     num_simulations=num_simulations)
        _ = inferer.append_simulations(theta,
                                       x).train(training_batch_size=100,
                                                max_num_epochs=max_num_epochs)
        posterior = inferer.build_posterior(**mcmc_kwargs).set_default_x(x_o)
        proposals.append(posterior)

    proposals[-1].sample(sample_shape=(num_samples, ), x=x_o, **mcmc_kwargs)
def test_MultivariateNormalLinear(get_MultivariateNormalLinear):
    """Exercise MultivariateNormalLinear construction, priors, sampling and
    the forward pass for each (in, out, bias) example fixture."""
    for example in get_MultivariateNormalLinear:
        i, o, b = example
        mnl = MultivariateNormalLinear(*example)

        # Expected standard-normal priors over weight rows and bias.
        wp = MultivariateNormal(torch.zeros(o, i),
                                torch.eye(i).repeat(o, 1, 1))
        bp = None if not b else MultivariateNormal(
            torch.zeros(o), torch.eye(o))

        assert eq_dist(mnl.weight_prior, wp)
        if b:
            assert eq_dist(mnl.bias_prior, bp)

        assert isinstance(mnl.weight, WeightMultivariateNormal)
        assert mnl.weight.shape == (o, i)
        assert hasattr(mnl, 'sample')
        assert hasattr(mnl, 'sampled')
        assert isinstance(mnl.sampled, tuple)
        assert len(mnl.sampled) == 2

        if b:
            assert mnl.bias.shape == (o,)
        else:
            assert mnl.bias is None

        # With the scale driven far down, samples collapse to the constant
        # means, making the forward output deterministic and checkable.
        init.constant_(mnl.weight.mean, 1)
        # todo: use lower triangular matrix
        init.constant_(mnl.weight.scale, -100)
        if b:
            init.constant_(mnl.bias.mean, 3)
            # todo: use lower triangular matrix
            init.constant_(mnl.bias.scale, -100)
        mnl.sample()

        x = ones_like(mnl.weight.mean)
        result = mnl(x)

        # All-ones input times all-ones weights sums to i (plus bias 3).
        if b:
            assert allclose(result, full_like(result, i + 3))
        else:
            assert allclose(result, full_like(result, i))
示例#27
0
    def __init__(self, n_filt=8, q=8):
        """Build the encoder/decoder CNNs, the BNN differential function and
        a standard-normal prior over the 2q-dimensional latent state.

        Args:
            n_filt: base number of convolution filters.
            q: latent dimensionality (doubled to 2q for position+velocity).
        """
        super(ODE2VAE, self).__init__()
        h_dim = n_filt*4**3 # encoder output is [4*n_filt,4,4]
        # encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(1, n_filt, kernel_size=5, stride=2, padding=(2,2)), # 14,14
            nn.BatchNorm2d(n_filt),
            nn.ReLU(),
            nn.Conv2d(n_filt, n_filt*2, kernel_size=5, stride=2, padding=(2,2)), # 7,7
            nn.BatchNorm2d(n_filt*2),
            nn.ReLU(),
            nn.Conv2d(n_filt*2, n_filt*4, kernel_size=5, stride=2, padding=(2,2)),
            nn.ReLU(),
            Flatten()
        )

        # Latent heads: fc1/fc2 produce the 2q-dim mean/log-variance; fc3
        # maps a q-dim latent back to the decoder input.
        self.fc1 = nn.Linear(h_dim, 2*q)
        self.fc2 = nn.Linear(h_dim, 2*q)
        self.fc3 = nn.Linear(q, h_dim)
        # differential function
        # to use a deterministic differential function, set bnn=False and self.beta=0.0
        self.bnn = BNN(2*q, q, n_hid_layers=2, n_hidden=50, act='celu', layer_norm=True, bnn=True)
        # downweighting the BNN KL term is helpful if self.bnn is heavily overparameterized
        self.beta = 1.0 # 2*q/self.bnn.kl().numel()
        # decoder
        self.decoder = nn.Sequential(
            UnFlatten(4),
            nn.ConvTranspose2d(h_dim//16, n_filt*8, kernel_size=3, stride=1, padding=(0,0)),
            nn.BatchNorm2d(n_filt*8),
            nn.ReLU(),
            nn.ConvTranspose2d(n_filt*8, n_filt*4, kernel_size=5, stride=2, padding=(1,1)),
            nn.BatchNorm2d(n_filt*4),
            nn.ReLU(),
            nn.ConvTranspose2d(n_filt*4, n_filt*2, kernel_size=5, stride=2, padding=(1,1), output_padding=(1,1)),
            nn.BatchNorm2d(n_filt*2),
            nn.ReLU(),
            nn.ConvTranspose2d(n_filt*2, 1, kernel_size=5, stride=1, padding=(2,2)),
            nn.Sigmoid(),
        )
        # Standard-normal prior over the full 2q latent state; `device` is a
        # module-level global here.
        self._zero_mean = torch.zeros(2*q).to(device)
        self._eye_covar = torch.eye(2*q).to(device)
        self.mvn = MultivariateNormal(self._zero_mean, self._eye_covar)
示例#28
0
 def get_net_log_prob(self, net_input_state, net_input_onehot_action,
                      net_input_multihot_action,
                      net_input_continuous_action):
     """Total log-probability of a mixed (one-hot / multi-hot / continuous)
     action under the network's output distribution.

     NOTE(review): `net_name` and `action_name` are free variables here —
     they are neither parameters nor obvious attributes, so this method
     relies on them existing in an enclosing scope. Confirm.
     """
     net = getattr(self, net_name)
     n_action_dim = getattr(self, 'n_' + action_name)
     onehot_action_dim = getattr(self, 'onehot_' + action_name + '_dim')
     multihot_action_dim = getattr(self, 'multihot_' + action_name + '_dim')
     sections = getattr(self, 'onehot_' + action_name + '_sections')
     continuous_action_log_std = getattr(
         self, net_name + '_' + action_name + '_std')
     onehot_action_probs_with_continuous_mean = net(net_input_state)
     onehot_actions_log_prob = 0
     multihot_actions_log_prob = 0
     continuous_actions_log_prob = 0
     # One-hot (categorical) component.
     if onehot_action_dim != 0:
         dist = MultiOneHotCategorical(
             onehot_action_probs_with_continuous_mean[
                 ..., :onehot_action_dim], sections)
         onehot_actions_log_prob = dist.log_prob(net_input_onehot_action)
     # Multi-hot (independent Bernoulli) component.
     if multihot_action_dim != 0:
         multihot_actions_prob = torch.sigmoid(
             onehot_action_probs_with_continuous_mean[
                 ...,
                 onehot_action_dim:onehot_action_dim + multihot_action_dim])
         dist = torch.distributions.bernoulli.Bernoulli(
             probs=multihot_actions_prob)
         multihot_actions_log_prob = dist.log_prob(
             net_input_multihot_action).sum(dim=1)
     # Continuous (diagonal Gaussian) component.
     if n_action_dim - onehot_action_dim - multihot_action_dim != 0:
         continuous_actions_mean = onehot_action_probs_with_continuous_mean[
             ..., onehot_action_dim + multihot_action_dim:]
         continuous_log_std = continuous_action_log_std.expand_as(
             continuous_actions_mean)
         continuous_actions_std = torch.exp(continuous_log_std)
         # NOTE(review): diag_embed receives the std, not the variance, so
         # the covariance is diag(std) — confirm intent.
         continuous_dist = MultivariateNormal(
             continuous_actions_mean,
             torch.diag_embed(continuous_actions_std))
         continuous_actions_log_prob = continuous_dist.log_prob(
             net_input_continuous_action)
     return FloatTensor(onehot_actions_log_prob +
                        multihot_actions_log_prob +
                        continuous_actions_log_prob).unsqueeze(-1)
示例#29
0
    def dist_init(self,
                  true_type='Gaussian',
                  cont_type='Gaussian',
                  cont_mean=None,
                  cont_var=1,
                  cont_covmat=None):
        """
        Set parameters for distribution under Huber contaminaton models. We assume
        the center parameter of the true distribution mu is 0 and the covariance
        is indentity martix. 

        Args:
            true_type : Type of real distribution P. 'Gaussian', 'Cauchy'.
            cont_type : Type of contamination distribution Q, 'Gaussian', 'Cauchy'.
            cont_mean: center parameter for Q
            cont_var: If scatter (covariance) matrix of Q is diagonal, cont_var gives 
                      the diagonal element.
            cont_covmat: Other scatter matrix can be provided (as torch.tensor format).
                         If cont_covmat is not None, cont_var will be ignored. 
        """

        self.true_type = true_type
        self.cont_type = cont_type

        ## settings for true distribution sampler
        self.true_mean = torch.zeros(self.p)
        if true_type == 'Gaussian':
            self.t_d = MultivariateNormal(torch.zeros(self.p),
                                          covariance_matrix=torch.eye(self.p))
        elif true_type == 'Cauchy':
            # Cauchy samples are generated as Normal / sqrt(Chi2(df=1)).
            self.t_normal_d = MultivariateNormal(torch.zeros(self.p),
                                                 covariance_matrix=torch.eye(
                                                     self.p))
            self.t_chi2_d = Chi2(df=1)
        else:
            raise NameError('True type must be Gaussian or Cauchy!')

        ## settings for contamination distribution sampler
        if cont_covmat is not None:
            self.cont_covmat = cont_covmat
        else:
            self.cont_covmat = torch.eye(self.p) * cont_var
        # NOTE(review): the default cont_mean=None would fail on the next
        # line — callers are expected to pass a numeric cont_mean.
        self.cont_mean = torch.ones(self.p) * cont_mean
        if cont_type == 'Gaussian':
            self.c_d = MultivariateNormal(torch.zeros(self.p),
                                          covariance_matrix=self.cont_covmat)
        elif cont_type == 'Cauchy':
            self.c_normal_d = MultivariateNormal(
                torch.zeros(self.p), covariance_matrix=self.cont_covmat)
            self.c_chi2_d = Chi2(df=1)
        else:
            raise NameError('Cont type must be Gaussian or Cauchy!')
示例#30
0
    def act(self, observation, device, grad=False, return_dist=False):
        """Sample (action, log-prob) from a Gaussian whose mean/variance are
        both emitted by the actor; optionally return the distribution."""
        # Sample from a distribution of actions
        output = self.actor(observation)

        # Actor emits means and variances concatenated; a single (unbatched)
        # observation produces a flat vector of length 2 * action_dim.
        single_process = output.size() == torch.Size([self.action_dim * 2])

        if single_process:
            action_means, action_variances = torch.split(output,
                                                         self.action_dim,
                                                         dim=0)
        else:
            action_means, action_variances = torch.split(output,
                                                         self.action_dim,
                                                         dim=1)
        # Scale action variance between 0 and 1
        action_variances = torch.clamp_min((action_variances + 1) / 2, 1e-8)

        if single_process:
            action_variances = [action_variances]
        # One diagonal covariance matrix per batch element.
        action_variances = torch.stack([
            torch.diag(action_variance) for action_variance in action_variances
        ])
        try:
            dist = MultivariateNormal(action_means, action_variances)
        except Exception as e:
            # NOTE(review): dumps diagnostics and kills the process when the
            # distribution cannot be built — consider re-raising instead.
            print(e)
            print("Action Means")
            print(action_means)
            print("Action Variances")
            print(action_variances)
            print("Observations")
            print(observation)
            exit()

        if return_dist:
            return dist
        action = dist.sample()
        action_logprob = dist.log_prob(action)
        if not grad:
            action = action.detach()
            action_logprob = action_logprob.detach()
        return action, action_logprob
示例#31
0
def run(setting='discrete_discrete'):
    """Learn an optimal-transport plan between two 2-D distributions, then
    fit a barycentric mapping on top of it.

    Returns (x, y, mapped, plan_objectives, map_objectives) as numpy arrays
    / lists for plotting. `setting` selects whether the source is a
    discrete sample cloud or a continuous Gaussian.
    """
    if setting == 'discrete_discrete':
        # NOTE(review): the circle source x/wx computed here are immediately
        # overwritten by Gaussian samples below — the first make_circle call
        # for x appears to be dead code.
        y, wy = make_circle(radius=4, n_samples=n_target_samples)
        x, wx = make_circle(radius=2, n_samples=n_target_samples)

        x = torch.from_numpy(x).float()
        y = torch.from_numpy(y).float()
        wy = torch.from_numpy(wy).float()
        wx = torch.from_numpy(wx).float()

        # Source: samples from a centered Gaussian with uniform weights.
        x = MultivariateNormal(torch.zeros(2), torch.eye(2) / 4)
        x = x.sample((n_target_samples, ))
        wx = np.full(len(x), 1 / len(x))
        wx = torch.from_numpy(wx).float()

        ot_plan = OTPlan(source_type='discrete', target_type='discrete',
                         target_length=len(y), source_length=len(x))
    elif setting == 'continuous_discrete':
        # Continuous Gaussian source; discrete circle target.
        x = MultivariateNormal(torch.zeros(2), torch.eye(2) / 4)
        y, wy = make_circle(radius=4, n_samples=n_target_samples)

        y = torch.from_numpy(y).float()
        wy = torch.from_numpy(wy).float()

        ot_plan = OTPlan(source_type='continuous', target_type='discrete',
                         target_length=len(y), source_dim=2)

    else:
        raise ValueError

    mapping = Mapping(ot_plan, dim=2)
    optimizer = Adam(ot_plan.parameters(), amsgrad=True, lr=lr)
    # optimizer = SGD(ot_plan.parameters(), lr=lr)


    plan_objectives = []
    map_objectives = []

    print('Learning OT plan')

    # Phase 1: optimize the OT plan on minibatches of (source, target).
    for i in range(n_plan_iter):
        optimizer.zero_grad()

        if setting == 'discrete_discrete':
            this_yidx = torch.multinomial(wy, batch_size)
            this_y = y[this_yidx]
            this_xidx = torch.multinomial(wx, batch_size)
            this_x = x[this_xidx]
        else:
            this_x = x.sample((batch_size,))
            this_yidx = torch.multinomial(wy, batch_size)
            this_y = y[this_yidx]
            this_xidx = None
        loss = ot_plan.loss(this_x, this_y, yidx=this_yidx, xidx=this_xidx)
        loss.backward()
        optimizer.step()
        plan_objectives.append(-loss.item())
        if i % 100 == 0:
            print(f'Iter {i}, loss {-loss.item():.3f}')

    optimizer = Adam(mapping.parameters(), amsgrad=True, lr=lr)
    # optimizer = SGD(mapping.parameters(), lr=1e-5)


    print('Learning barycentric mapping')
    # Phase 2: fit the barycentric mapping with the plan held fixed.
    for i in range(n_map_iter):
        optimizer.zero_grad()

        if setting == 'discrete_discrete':
            this_yidx = torch.multinomial(wy, batch_size)
            this_y = y[this_yidx]
            this_xidx = torch.multinomial(wx, batch_size)
            this_x = x[this_xidx]
        else:
            this_x = x.sample((batch_size,))
            this_yidx = torch.multinomial(wy, batch_size)
            this_y = y[this_yidx]
            this_xidx = None

        loss = mapping.loss(this_x, this_y, yidx=this_yidx, xidx=this_xidx)
        loss.backward()
        optimizer.step()
        map_objectives.append(loss.item())
        if i % 100 == 0:
            print(f'Iter {i}, loss {loss.item():.3f}')

    # Materialize continuous sources as samples before mapping/plotting.
    if setting == 'continuous_discrete':
        x = x.sample((len(y),))
    with torch.no_grad():
        mapped = mapping(x)
    x = x.numpy()
    y = y.numpy()
    mapped = mapped.numpy()

    return x, y, mapped, plan_objectives, map_objectives