class VariationalLoss(nn.Module):
    """Monte-Carlo variational objective for a 2-D normalizing flow.

    Combines the base log-density of the initial samples, the (negated)
    target energy of their flow images, and the accumulated log|det J|.
    """

    def __init__(self, distribution: TargetDistribution):
        super().__init__()
        self.distr = distribution
        # Standard 2-D Gaussian as the flow's base distribution.
        self.base_distr = MultivariateNormal(torch.zeros(2), torch.eye(2))

    def forward(self, z0: Tensor, z: Tensor, sum_log_det_J: float) -> float:
        """Return the mean variational gap over the batch."""
        log_q0 = self.base_distr.log_prob(z0)
        # self.distr(z) returns an energy, so its negation is the log-density.
        log_p = -self.distr(z)
        gap = log_q0 - log_p - sum_log_det_J
        return gap.mean()
def evaluate(self, states, actions): action_means = -states * self.agent[0].weight[:, 0] / self.agent[ 0].weight[:, 1] # action_var = torch.full((action_dim,), 0.5 * self.tau) action_var = 0.5 * self.tau / self.agent[0].weight[:, 1] * self.agent[ 0].weight[:, 1] action_var = action_var.expand_as(action_means) cov_mat = torch.diag_embed(action_var).to(device) dist = MultivariateNormal(action_means, cov_mat) action_logprobs = dist.log_prob(actions) action_values = self.agent( torch.cat((states, actions), dim=1).squeeze()) dist_entropy = dist.entropy() return action_logprobs, torch.squeeze(action_values), dist_entropy
def get_action(self, state):
    """Sample a tanh-squashed action for *state* using the reparameterization trick."""
    obs = torch.FloatTensor(state).unsqueeze(0)
    mean, log_std = self.forward(obs)
    std = log_std.exp()
    noise_loc = torch.zeros(mean.size())
    noise_scale = torch.ones(mean.size())
    # 1-D actions use a univariate Normal; higher dims use an isotropic MVN.
    if mean.size()[1] == 1:
        eps = Normal(noise_loc, noise_scale).sample()
    else:
        eps = MultivariateNormal(noise_loc,
                                 torch.diag_embed(noise_scale)).sample()
    squashed = torch.tanh(mean + std * eps)
    return squashed.cpu()[0]
def evaluate(self, state, action):
    """Score *action* under the current two-headed (alpha/beta) Gaussian
    policy and value the state with the critic.

    Returns (action log-probs, squeezed state value, distribution entropy).
    """
    # torch.tanh replaces the deprecated torch.nn.functional.tanh.
    x = torch.tanh(self.affine1(state))
    x = torch.tanh(self.affine2(x))
    alpha = self.alpha_action_mean(x)
    beta = self.beta_action_mean(x)
    action_mean = torch.cat((alpha, beta), dim=1)
    # Shared, state-independent variance broadcast to every action dim.
    action_var = self.action_var.expand_as(action_mean)
    cov_mat = torch.diag_embed(action_var).to(device)
    dist = MultivariateNormal(action_mean, cov_mat)
    action_logprobs = dist.log_prob(action)
    dist_entropy = dist.entropy()
    state_value = self.critic(state)
    return action_logprobs, torch.squeeze(state_value), dist_entropy
def test_shapes(self):
    """Batched Normal-vs-MVN KL matches the same KL computed on flattened batches."""
    B1 = 100
    B2 = 50
    K1 = 20
    K2 = 4
    D = 16
    scale = torch.randn((K1, K2, D, D))
    # Random SPD covariance: scale @ scale^T plus diagonal jitter.
    cov = scale @ scale.transpose(-2, -1) + torch.diag(0.1 * torch.ones(D))
    p = MultivariateNormal(torch.randn((K1, K2, D)), covariance_matrix=cov)
    # Same distributions with batch dims flattened, as the reference.
    p_ = MultivariateNormal(loc=p.loc.view(-1, D),
                            scale_tril=p.scale_tril.view(-1, D, D))
    q = Normal(loc=torch.randn((B1, B2, D)), scale=torch.rand((B1, B2, D)))
    q_ = Normal(loc=q.loc.view(-1, D), scale=q.scale.view(-1, D))
    actual_loss = pt.ops.kl_divergence(q, p)
    reference_loss = pt.ops.kl_divergence(q_, p_).view(B1, B2, K1, K2)
    np.testing.assert_allclose(actual_loss, reference_loss, rtol=1e-4)
def test_against_multivariate_multivariate(self):
    """Factorized-Normal-vs-MVN KL agrees with torch's MVN-vs-MVN KL."""
    B = 500
    K = 100
    D = 16
    scale = torch.randn((K, D, D))
    # Random SPD covariance: scale @ scale^T plus diagonal jitter.
    cov = scale @ scale.transpose(1, 2) + torch.diag(0.1 * torch.ones(D))
    p = MultivariateNormal(torch.randn((K, D)), covariance_matrix=cov)
    # q is an MVN with purely diagonal scale_tril, broadcastable against p.
    q = MultivariateNormal(loc=torch.randn((B, 1, D)),
                           scale_tril=torch.Tensor(
                               np.broadcast_to(np.diag(np.random.rand(D)),
                                               (B, 1, D, D))))
    # Equivalent factorized Normal built from q's diagonal.
    q_ = Normal(loc=q.loc[:, 0],
                scale=pt.ops.losses.loss._batch_diag(q.scale_tril[:, 0]))
    actual_loss = pt.ops.kl_divergence(q_, p)
    reference_loss = kl_divergence(q, p)
    np.testing.assert_allclose(actual_loss, reference_loss, rtol=1e-4)
def evaluate(self, state, action):
    """Score stored multi-vehicle actions and value their states.

    *state* holds one entry per vehicle; each entry is indexed as
    (state[i][0], state[i][1]) and fed to the actor/critic as
    (state[i][1], state[i][0]) — presumably (other_vehicles, target);
    TODO confirm against the actor's signature.
    """
    action_mean = []
    state_value = []
    no_vehicles = len(state)
    # Run actor and critic once per vehicle, then stack the results.
    for idx_1 in range(no_vehicles):
        action_mean.append(self.actor(state[idx_1][1], state[idx_1][0]))
        state_value.append(self.critic(state[idx_1][1], state[idx_1][0]))
    action_mean = torch.stack(action_mean).view(-1, 1, 2)
    state_value = torch.stack(state_value).view(-1, 1, 1)
    #action_mean = self.actor(state)
    # Shared diagonal covariance across all vehicles.
    dist = MultivariateNormal(torch.squeeze(action_mean),
                              torch.diag(self.action_var))
    action_logprobs = dist.log_prob(torch.squeeze(action))
    dist_entropy = dist.entropy()
    #state_value = self.critic(state)
    return action_logprobs, torch.squeeze(state_value), dist_entropy
class GaussDataset(Dataset):
    """Torch dataset of i.i.d. multivariate-normal draws.

    Items are sampled lazily on every access (the index is ignored), so
    repeated reads of the same index return different values; *n* only
    fixes the nominal dataset length.
    """

    def __init__(self, n, mean, cov):
        self.n = n
        self.dist = MultivariateNormal(loc=mean, covariance_matrix=cov)

    def __len__(self):
        return self.n

    def __getitem__(self, item):
        # Index is intentionally unused: every access is a fresh sample.
        return self.dist.sample()
def dist(self, device):
    """Return the Gaussian N(bias, W W^T + sigma^2 I) induced by the linear generator."""
    weight = self.gen.g.weight.data
    noise_var = self.gen.logsigma.exp() ** 2
    gram = weight @ weight.t()
    covariance = gram + noise_var * torch.eye(gram.size(0)).to(device)
    return MultivariateNormal(self.gen.g.bias, covariance)
def evaluate(self, state, state_randomized, action, randomize):
    """Evaluate *action* under the policy, optionally on the randomized state.

    Returns (action log-probs, squeezed state value, distribution entropy).
    """
    # Domain randomization toggles which observation feeds the network.
    net_input = state_randomized if randomize else state
    actor_output, critic_output = self.forward(net_input)
    mean = torch.squeeze(actor_output)
    variance = self.action_var.expand_as(mean)
    policy = MultivariateNormal(mean, torch.diag_embed(variance).to(device))
    log_probs = policy.log_prob(torch.squeeze(action))
    entropy = policy.entropy()
    return log_probs, torch.squeeze(critic_output), entropy
def get_training_params(self, frame, mes, action):
    """Batch stored rollout tensors and score them with the actor/critic.

    Returns (action log-probs, squeezed state values, policy entropy),
    each moved to the module-level ``device``.
    """
    frames = torch.stack(frame)
    measurements = torch.stack(mes)
    # Drop singleton dims introduced when stacking per-step tensors.
    if frames.dim() > 4:
        frames = torch.squeeze(frames)
    if measurements.dim() > 2:
        measurements = torch.squeeze(measurements)
    actions = torch.stack(action)
    mean = self.actor(frames, measurements)
    covariance = torch.diag_embed(self.action_var.expand_as(mean)).to(device)
    policy = MultivariateNormal(mean, covariance)
    log_probs = policy.log_prob(actions).to(device)
    entropy = policy.entropy().to(device)
    values = torch.squeeze(self.critic(frames, measurements)).to(device)
    return log_probs, values, entropy
def _get_init_dist(self):
    """Gaussian initial state: stationary GP covariance plus trend-noise block."""
    loc = self.z_trans_matrix.new_zeros(self.full_state_dim)
    covar = self.z_trans_matrix.new_zeros(self.full_state_dim,
                                          self.full_state_dim)
    # Top-left block: stationary covariance of the GP state, block-diagonal
    # across kernel components.
    covar[:self.full_gp_state_dim, :self.full_gp_state_dim] = block_diag_embed(
        self.kernel.stationary_covariance())
    # Bottom-right block: diagonal covariance for the remaining state dims.
    covar[self.full_gp_state_dim:, self.full_gp_state_dim:] = \
        self.init_noise_scale_sq.diag_embed()
    return MultivariateNormal(loc, covar)
def act(self, state):
    """Sample an action and its log-prob from the current policy."""
    # state1, state2 = state
    # state1 = torch.FloatTensor(state1).to(device)
    # state2 = torch.FloatTensor(state2).to(device)
    if self.has_continuous_action_space:
        action_mean, action_sigma = self.actor(state)
        action_var = action_sigma**2
        cov_mat = torch.diag_embed(action_var).unsqueeze(dim=0)
        dist = MultivariateNormal(action_mean, cov_mat)
        # print(dist)
    else:
        action_probs = self.actor(state)
        dist = Categorical(action_probs)
    action = dist.sample()
    # NOTE(review): the clamp runs in BOTH branches; for Categorical the
    # sample is an integer index, so clamping it to [-1, 1] looks wrong.
    # Also, log_prob below is evaluated on the CLAMPED action, not the raw
    # sample — confirm both are intended.
    action = action.clamp(-1, 1)
    action_logprob = dist.log_prob(action)
    return action.detach(), action_logprob.detach()
def _move(self, drift, index):
    """Move a walker.

    Args:
        drift (torch.tensor): drift velocity
        index (int): index of the electron to move

    Returns:
        torch.tensor: position of the walkers
    """
    d = drift.view(self.nwalkers, self.nelec, self.ndim)
    # NOTE(review): MultivariateNormal's second argument is a COVARIANCE
    # matrix, so sqrt(step_size)*I gives a per-component std of
    # step_size**0.25. If the intent was std = sqrt(step_size), the
    # covariance should be step_size * I — confirm against the derivation.
    mv = MultivariateNormal(torch.zeros(self.ndim), np.sqrt(
        self.step_size) * torch.eye(self.ndim))
    # Drift step for the selected electron plus isotropic Gaussian diffusion.
    return self.step_size * d[range(self.nwalkers), index, :] \
        + mv.sample((self.nwalkers, 1)).squeeze()
def forward(self, state):
    """Sample an action from a Gaussian policy head; returns
    (detached action, log-prob, mean entropy)."""
    # Two ReLU hidden layers shared by the mean and std heads.
    output_1 = F.relu(self.linear1(state))
    output_2 = F.relu(self.linear2(output_1))
    # Original note said "has positive and negative values", but
    # 2*sigmoid is bounded in (0, 2).
    mu = 2 * torch.sigmoid(self.mu(output_2))
    sigma = F.relu(self.sigma(output_2)) + 0.001  # avoid exactly-0 std
    # NOTE(review): `output` is referenced before assignment here — this
    # line (and `action_mean` below, which is never used afterwards) looks
    # like dead leftover code that would raise NameError; confirm.
    output = F.softmax(output, dim=-1)
    action_mean = self.linear3(output)
    #cov_mat = torch.diag(self.action_var).to(device)
    # NOTE(review): diag_embed turns the mean VECTOR into a matrix; passing
    # a matrix `loc` to MultivariateNormal adds extra batch dims —
    # presumably unintended, verify against the caller.
    mu = torch.diag_embed(mu).to(device)
    sigma = torch.diag_embed(sigma).to(device)  # change to 2D
    # N(mu, sigma^2); sigma is treated as a hyper-parameter, not trained.
    dist = MultivariateNormal(mu, sigma)
    #distribution = Categorical(F.softmax(output, dim=-1))
    entropy = dist.entropy().mean()
    action = dist.sample()
    action_logprob = dist.log_prob(action)
    return action.detach(), action_logprob, entropy
def select_action(self,state,actor): self.action_var = torch.full((2,), 0.6*0.6).to(device) #manually change action_dim action_std no_vehicles = len(state) action_list = [] for idx_1 in range(no_vehicles): state1 = state[idx_1] target1 = state1[0] other_vehicles1 = state1[1] action_mean = actor(other_vehicles1,target1) dist = MultivariateNormal(action_mean, torch.diag(self.action_var).to(device)) ##ATTENTION: manually change variance in torch.diag(var) action = dist.sample() action_logprob = dist.log_prob(action) self.agentmemory.memory_list[idx_1].states.append(state1) self.agentmemory.memory_list[idx_1].actions.append(action) self.agentmemory.memory_list[idx_1].logprobs.append(action_logprob) action_list.append(action.detach().cpu().data.numpy().flatten()) return action_list
def y_dist(self):
    """
    Returns the current Y-distribution.

    Scalar observations yield a univariate ``Normal``; otherwise a
    ``MultivariateNormal`` parameterised by the Cholesky factor of ``ycov``.
    :rtype: Normal|MultivariateNormal
    """
    if self._model.obs_ndim < 2:
        return Normal(self.ymean[..., 0], self.ycov[..., 0, 0].sqrt())

    # torch.linalg.cholesky replaces the deprecated torch.cholesky.
    return MultivariateNormal(self.ymean,
                              scale_tril=torch.linalg.cholesky(self.ycov))
def distribution(
    self, distr_args, scale: Optional[torch.Tensor] = None
) -> Distribution:
    """Build the output distribution from (loc, scale_tril) args,
    optionally rescaled by *scale* via an affine transform."""
    mean, chol_factor = distr_args
    base = MultivariateNormal(loc=mean, scale_tril=chol_factor)
    if scale is None:
        return base
    # Undo data normalisation by affinely rescaling the base distribution.
    return TransformedDistribution(base,
                                   [AffineTransform(loc=0, scale=scale)])
def marginal_posterior_divergence(self, z, mean, logv, num_samples):
    """Minibatch estimate of sum_b [log q(z_b) - log p(z_b)].

    Each latent z_b is scored under its own diagonal-Gaussian posterior and
    under the posteriors of the other batch elements, combined with a
    logsumexp (minibatch-weighted sampling), then compared against the
    standard-normal prior.
    """
    batch_size, n = mean.shape
    # Identity covariance with a leading batch dim of 1, used by the prior.
    diag = to_cuda_var(torch.eye(n).repeat(1, 1, 1))
    logq_zb_lst = []
    logp_zb_lst = []
    for b in range(batch_size):
        zb = z[b, :].unsqueeze(0)
        mu_b = mean[b, :].unsqueeze(0)
        logv_b = logv[b, :].unsqueeze(0)
        diag_b = to_cuda_var(torch.eye(n).repeat(1, 1, 1))
        # Diagonal covariance built from the log-variances.
        cov_b = torch.exp(logv_b).unsqueeze(dim=2) * diag_b
        # removing b-th mean and logv
        zr = zb.repeat(batch_size - 1, 1)
        mu_r = torch.cat((mean[:b, :], mean[b + 1:, :]))
        logv_r = torch.cat((logv[:b, :], logv[b + 1:, :]))
        diag_r = to_cuda_var(torch.eye(n).repeat(batch_size - 1, 1, 1))
        cov_r = torch.exp(logv_r).unsqueeze(dim=2) * diag_r
        # E[log q(zb)] = - H(q(z))
        zb_xb_posterior_pdf = MultivariateNormal(mu_b, cov_b)
        logq_zb_xb = zb_xb_posterior_pdf.log_prob(zb)
        zb_xr_posterior_pdf = MultivariateNormal(mu_r, cov_r)
        logq_zb_xr = zb_xr_posterior_pdf.log_prob(zr)
        # Log importance weights: own posterior vs the other posteriors.
        yb1 = logq_zb_xb - torch.log(
            to_cuda_var(torch.tensor(num_samples).float()))
        yb2 = logq_zb_xr + torch.log(
            to_cuda_var(
                torch.tensor((num_samples - 1) /
                             ((batch_size - 1) * num_samples)).float()))
        yb = torch.cat([yb1, yb2], dim=0)
        logq_zb = torch.logsumexp(yb, dim=0)
        # E[log p(zb)] under the standard-normal prior.
        zb_prior_pdf = MultivariateNormal(to_cuda_var(torch.zeros(n)), diag)
        logp_zb = zb_prior_pdf.log_prob(zb)
        logq_zb_lst.append(logq_zb)
        logp_zb_lst.append(logp_zb)
    logq_zb = torch.stack(logq_zb_lst, dim=0)
    logp_zb = torch.stack(logp_zb_lst, dim=0).squeeze(-1)
    return (logq_zb - logp_zb).sum()
def act(self, state, memory):
    """Sample an action from the alpha/beta-headed Gaussian policy and
    record the transition in *memory*.

    Returns the detached sampled action.
    """
    # torch.tanh replaces the deprecated torch.nn.functional.tanh.
    x = torch.tanh(self.affine1(state))
    x = torch.tanh(self.affine2(x))
    alpha = self.alpha_action_mean(x)
    beta = self.beta_action_mean(x)
    action_mean = torch.cat((alpha, beta), dim=1)
    cov_mat = torch.diag(self.action_var).to(device)
    dist = MultivariateNormal(action_mean, cov_mat)
    action = dist.sample()
    # action = F.softmax(action.reshape(2,-1)).reshape(1,-1)
    action_logprob = dist.log_prob(action)
    # Store the transition for the later policy update.
    memory.states.append(state)
    memory.actions.append(action)
    memory.logprobs.append(action_logprob)
    return action.detach()
def act(self, state):
    """Sample an action for a numpy *state*; returns (numpy action, log-prob)."""
    with torch.no_grad():
        obs = torch.from_numpy(state).float().to(device)
        mean = self.action_layer(obs)
        covariance = torch.diag(self.action_var).to(device)
        policy = MultivariateNormal(mean, covariance)
        chosen = policy.sample()
        log_prob = policy.log_prob(chosen)
        # Keep the distribution around for a later update step.
        self.dist = policy
        return chosen.detach().cpu().numpy(), log_prob
def forward(self, state):
    """LSTM policy head: sample an action; returns (action, log-prob, mean entropy).

    NOTE(review): the flattened source makes the comment/code split here
    ambiguous — the `output_4` and `mu` lines appear after `#` markers but
    must be live code, since `sigma` and `dist` depend on them; confirm
    against the original file.
    """
    output_1 = F.relu(self.linear1(state))
    output_2 = F.relu(self.linear2(output_1))
    # LSTM expects a sequence dimension; add it in front.
    output_2 = output_2.unsqueeze(0)
    output_3, self.hidden_cell = self.LSTM_layer_3(output_2)  # ,self.hidden_cell
    a, b, c = output_3.shape
    output_4 = F.relu(self.linear4(output_3.view(-1, c)))
    # Signed mean via tanh (sigmoid would give 0-1), scaled to (-2, 2).
    mu = 2 * torch.tanh(self.mu(output_4))
    sigma = F.relu(self.sigma(output_4)) + 0.001  # avoid exactly-0 std
    # NOTE(review): diag_embed turns the mean VECTOR into a matrix; passing
    # a matrix `loc` to MultivariateNormal adds extra batch dims —
    # presumably unintended, verify against the caller.
    mu = torch.diag_embed(mu).to(device)
    sigma = torch.diag_embed(sigma).to(device)  # change to 2D
    dist = MultivariateNormal(mu, sigma)  # N(mu, sigma^2)
    entropy = dist.entropy().mean()
    action = dist.sample()
    action_logprob = dist.log_prob(action)
    return action, action_logprob, entropy
def get_action_distribution(self, states, in_train=True):
    """Build the Gaussian action distribution for *states*.

    In-train mode runs the live policy with gradients enabled; otherwise
    the previous policy is evaluated under ``no_grad`` in eval mode and
    restored to train mode afterwards.
    """
    if in_train:
        mean = self.policy(states.to(self.device))
        return MultivariateNormal(mean, self.action_variances)
    frozen = self.prev_policy
    frozen.eval()
    with torch.no_grad():
        mean = frozen(states.to(self.device))
        dist = MultivariateNormal(mean, self.action_variances)
    frozen.train()
    return dist
def test_c2st_multi_round_snl_on_linearGaussian(set_seed):
    """Test SNL on linear Gaussian, comparing to ground truth posterior via c2st.

    Args:
        set_seed: fixture for manual seeding
    """
    num_dim = 2
    x_o = zeros((1, num_dim))
    num_samples = 500
    # likelihood_mean will be likelihood_shift+theta
    likelihood_shift = -1.0 * ones(num_dim)
    likelihood_cov = 0.3 * eye(num_dim)
    prior_mean = zeros(num_dim)
    prior_cov = eye(num_dim)
    prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov)
    # Analytic reference posterior for the linear-Gaussian model.
    gt_posterior = true_posterior_linear_gaussian_mvn_prior(
        x_o[0], likelihood_shift, likelihood_cov, prior_mean, prior_cov)
    target_samples = gt_posterior.sample((num_samples, ))
    simulator = lambda theta: linear_gaussian(theta, likelihood_shift,
                                              likelihood_cov)
    simulator, prior = prepare_for_sbi(simulator, prior)
    inference = SNL(
        prior,
        show_progress_bars=False,
    )
    # Round 1: simulate from the prior and train.
    theta, x = simulate_for_sbi(simulator, prior, 750,
                                simulation_batch_size=50)
    _ = inference.append_simulations(theta, x).train()
    posterior1 = inference.build_posterior(mcmc_method="slice_np_vectorized",
                                           mcmc_parameters={
                                               "thin": 5,
                                               "num_chains": 20
                                           }).set_default_x(x_o)
    # Round 2: simulate from the round-1 posterior and retrain.
    theta, x = simulate_for_sbi(simulator, posterior1, 750,
                                simulation_batch_size=50)
    _ = inference.append_simulations(theta, x).train()
    posterior = inference.build_posterior().copy_hyperparameters_from(
        posterior1)
    samples = posterior.sample(sample_shape=(num_samples, ),
                               mcmc_parameters={"thin": 3})
    # Check performance based on c2st accuracy.
    check_c2st(samples, target_samples, alg="multi-round-snl")
def test_training_and_mcmc_on_device(method, model, device):
    """Test training on devices.

    This test does not check training speeds.
    """
    device = process_device(device)
    num_dim = 2
    num_samples = 10
    num_simulations = 500
    max_num_epochs = 5
    x_o = zeros(1, num_dim)
    likelihood_shift = -1.0 * ones(num_dim)
    likelihood_cov = 0.3 * eye(num_dim)
    prior_mean = zeros(num_dim)
    prior_cov = eye(num_dim)
    prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov)

    def simulator(theta):
        return linear_gaussian(theta, likelihood_shift, likelihood_cov)

    # Method-specific density estimator and MCMC settings.
    if method == SNPE:
        kwargs = dict(density_estimator=utils.posterior_nn(model=model), )
        mcmc_kwargs = dict(
            sample_with_mcmc=True,
            mcmc_method="slice_np",
        )
    elif method == SNLE:
        kwargs = dict(density_estimator=utils.likelihood_nn(model=model), )
        mcmc_kwargs = dict(mcmc_method="slice")
    elif method == SNRE:
        kwargs = dict(classifier=utils.classifier_nn(model=model), )
        mcmc_kwargs = dict(mcmc_method="slice_np_vectorized", )
    else:
        raise ValueError()
    inferer = method(prior, show_progress_bars=False, device=device, **kwargs)
    proposals = [prior]
    # Test for two rounds.
    # NOTE(review): both rounds simulate from `prior`, not from the latest
    # entry of `proposals` — possibly intentional for test speed; confirm.
    for r in range(2):
        theta, x, = simulate_for_sbi(simulator, proposal=prior,
                                     num_simulations=num_simulations)
        _ = inferer.append_simulations(theta,
                                       x).train(training_batch_size=100,
                                                max_num_epochs=max_num_epochs)
        posterior = inferer.build_posterior(**mcmc_kwargs).set_default_x(x_o)
        proposals.append(posterior)
    proposals[-1].sample(sample_shape=(num_samples, ), x=x_o, **mcmc_kwargs)
def test_MultivariateNormalLinear(get_MultivariateNormalLinear):
    """Exercise priors, sampling and the forward pass of MultivariateNormalLinear
    for each (in_features, out_features, bias) example from the fixture."""
    for example in get_MultivariateNormalLinear:
        i, o, b = example
        mnl = MultivariateNormalLinear(*example)
        # Expected priors: each weight row shares an i-dim standard MVN;
        # the bias prior is an o-dim standard MVN when bias is enabled.
        wp = MultivariateNormal(torch.zeros(o, i),
                                torch.eye(i).repeat(o, 1, 1))
        bp = None if not b else MultivariateNormal(
            torch.zeros(o), torch.eye(o))
        assert eq_dist(mnl.weight_prior, wp)
        if b:
            assert eq_dist(mnl.bias_prior, bp)
        assert isinstance(mnl.weight, WeightMultivariateNormal)
        assert mnl.weight.shape == (o, i)
        assert hasattr(mnl, 'sample')
        assert hasattr(mnl, 'sampled')
        assert isinstance(mnl.sampled, tuple)
        assert len(mnl.sampled) == 2
        if b:
            assert mnl.bias.shape == (o,)
        else:
            assert mnl.bias is None
        init.constant_(mnl.weight.mean, 1)
        # todo: use lower triangular matrix
        # A large negative raw scale makes the sampled parameters
        # effectively deterministic around their means.
        init.constant_(mnl.weight.scale, -100)
        if b:
            init.constant_(mnl.bias.mean, 3)
            # todo: use lower triangular matrix
            init.constant_(mnl.bias.scale, -100)
        mnl.sample()
        x = ones_like(mnl.weight.mean)
        result = mnl(x)
        # With weights ~1 and bias ~3, each output is i (plus 3 with bias).
        if b:
            assert allclose(result, full_like(result, i + 3))
        else:
            assert allclose(result, full_like(result, i))
def __init__(self, n_filt=8, q=8):
    """Conv encoder/decoder VAE with a Bayesian-NN dynamics function over a
    2q-dimensional latent state and a standard-normal latent prior."""
    super(ODE2VAE, self).__init__()
    h_dim = n_filt*4**3  # encoder output is [4*n_filt,4,4]
    # encoder
    self.encoder = nn.Sequential(
        nn.Conv2d(1, n_filt, kernel_size=5, stride=2, padding=(2,2)),  # 14,14
        nn.BatchNorm2d(n_filt),
        nn.ReLU(),
        nn.Conv2d(n_filt, n_filt*2, kernel_size=5, stride=2, padding=(2,2)),  # 7,7
        nn.BatchNorm2d(n_filt*2),
        nn.ReLU(),
        nn.Conv2d(n_filt*2, n_filt*4, kernel_size=5, stride=2, padding=(2,2)),
        nn.ReLU(),
        Flatten()
    )
    # Two 2q-dim heads on the encoder features, plus a latent-to-feature map.
    self.fc1 = nn.Linear(h_dim, 2*q)
    self.fc2 = nn.Linear(h_dim, 2*q)
    self.fc3 = nn.Linear(q, h_dim)
    # differential function
    # to use a deterministic differential function, set bnn=False and self.beta=0.0
    self.bnn = BNN(2*q, q, n_hid_layers=2, n_hidden=50, act='celu',
                   layer_norm=True, bnn=True)
    # downweighting the BNN KL term is helpful if self.bnn is heavily overparameterized
    self.beta = 1.0  # 2*q/self.bnn.kl().numel()
    # decoder
    self.decoder = nn.Sequential(
        UnFlatten(4),
        nn.ConvTranspose2d(h_dim//16, n_filt*8, kernel_size=3, stride=1,
                           padding=(0,0)),
        nn.BatchNorm2d(n_filt*8),
        nn.ReLU(),
        nn.ConvTranspose2d(n_filt*8, n_filt*4, kernel_size=5, stride=2,
                           padding=(1,1)),
        nn.BatchNorm2d(n_filt*4),
        nn.ReLU(),
        nn.ConvTranspose2d(n_filt*4, n_filt*2, kernel_size=5, stride=2,
                           padding=(1,1), output_padding=(1,1)),
        nn.BatchNorm2d(n_filt*2),
        nn.ReLU(),
        nn.ConvTranspose2d(n_filt*2, 1, kernel_size=5, stride=1,
                           padding=(2,2)),
        nn.Sigmoid(),
    )
    # Standard-normal prior over the 2q-dim latent state.
    self._zero_mean = torch.zeros(2*q).to(device)
    self._eye_covar = torch.eye(2*q).to(device)
    self.mvn = MultivariateNormal(self._zero_mean, self._eye_covar)
def get_net_log_prob(self, net_input_state, net_input_onehot_action,
                     net_input_multihot_action, net_input_continuous_action):
    """Log-probability of a mixed (one-hot / multi-hot / continuous) action.

    The network output is sliced as [one-hot logits | multi-hot logits |
    continuous means] and each slice is scored under its own distribution.

    NOTE(review): ``net_name`` and ``action_name`` are free variables here —
    not parameters and not obviously module globals — so this function looks
    like it is meant to be defined inside a factory that binds them; confirm.
    """
    net = getattr(self, net_name)
    n_action_dim = getattr(self, 'n_' + action_name)
    onehot_action_dim = getattr(self, 'onehot_' + action_name + '_dim')
    multihot_action_dim = getattr(self, 'multihot_' + action_name + '_dim')
    sections = getattr(self, 'onehot_' + action_name + '_sections')
    continuous_action_log_std = getattr(
        self, net_name + '_' + action_name + '_std')
    onehot_action_probs_with_continuous_mean = net(net_input_state)
    onehot_actions_log_prob = 0
    multihot_actions_log_prob = 0
    continuous_actions_log_prob = 0
    # Leading slice: categorical (one-hot) logits, split into sections.
    if onehot_action_dim != 0:
        dist = MultiOneHotCategorical(
            onehot_action_probs_with_continuous_mean[
                ..., :onehot_action_dim], sections)
        onehot_actions_log_prob = dist.log_prob(net_input_onehot_action)
    # Middle slice: independent Bernoulli probabilities (multi-hot).
    if multihot_action_dim != 0:
        multihot_actions_prob = torch.sigmoid(
            onehot_action_probs_with_continuous_mean[
                ..., onehot_action_dim:onehot_action_dim +
                multihot_action_dim])
        dist = torch.distributions.bernoulli.Bernoulli(
            probs=multihot_actions_prob)
        multihot_actions_log_prob = dist.log_prob(
            net_input_multihot_action).sum(dim=1)
    # Trailing slice: Gaussian means for the continuous action dims.
    if n_action_dim - onehot_action_dim - multihot_action_dim != 0:
        continuous_actions_mean = onehot_action_probs_with_continuous_mean[
            ..., onehot_action_dim + multihot_action_dim:]
        continuous_log_std = continuous_action_log_std.expand_as(
            continuous_actions_mean)
        continuous_actions_std = torch.exp(continuous_log_std)
        # NOTE(review): diag_embed receives the STD vector, so the MVN's
        # covariance is diag(std), not diag(std**2) — verify intended.
        continuous_dist = MultivariateNormal(
            continuous_actions_mean,
            torch.diag_embed(continuous_actions_std))
        continuous_actions_log_prob = continuous_dist.log_prob(
            net_input_continuous_action)
    return FloatTensor(onehot_actions_log_prob + multihot_actions_log_prob +
                       continuous_actions_log_prob).unsqueeze(-1)
def dist_init(self, true_type='Gaussian', cont_type='Gaussian',
              cont_mean=None, cont_var=1, cont_covmat=None):
    """
    Set parameters for distribution under Huber contamination models.
    We assume the center parameter of the true distribution mu is 0 and
    the covariance is the identity matrix.

    Args:
        true_type : Type of real distribution P. 'Gaussian', 'Cauchy'.
        cont_type : Type of contamination distribution Q, 'Gaussian', 'Cauchy'.
        cont_mean: center parameter for Q
        cont_var: If scatter (covariance) matrix of Q is diagonal, cont_var
            gives the diagonal element.
        cont_covmat: Other scatter matrix can be provided (as torch.tensor
            format). If cont_covmat is not None, cont_var will be ignored.
    """
    self.true_type = true_type
    self.cont_type = cont_type
    ## settings for true distribution sampler
    self.true_mean = torch.zeros(self.p)
    if true_type == 'Gaussian':
        self.t_d = MultivariateNormal(torch.zeros(self.p),
                                      covariance_matrix=torch.eye(self.p))
    elif true_type == 'Cauchy':
        # Cauchy sampling presumably composes these two draws elsewhere
        # (Normal / sqrt(Chi2(1))) — confirm against the sampler.
        self.t_normal_d = MultivariateNormal(torch.zeros(self.p),
                                             covariance_matrix=torch.eye(
                                                 self.p))
        self.t_chi2_d = Chi2(df=1)
    else:
        raise NameError('True type must be Gaussian or Cauchy!')
    ## settings for contamination distribution sampler
    if cont_covmat is not None:
        self.cont_covmat = cont_covmat
    else:
        self.cont_covmat = torch.eye(self.p) * cont_var
    # NOTE(review): with the default cont_mean=None this multiplication
    # raises TypeError — callers apparently must always pass cont_mean;
    # confirm.
    self.cont_mean = torch.ones(self.p) * cont_mean
    if cont_type == 'Gaussian':
        # Sampler is zero-centered; cont_mean is stored separately and is
        # presumably added at sampling time — verify against the sampler.
        self.c_d = MultivariateNormal(torch.zeros(self.p),
                                      covariance_matrix=self.cont_covmat)
    elif cont_type == 'Cauchy':
        self.c_normal_d = MultivariateNormal(
            torch.zeros(self.p), covariance_matrix=self.cont_covmat)
        self.c_chi2_d = Chi2(df=1)
    else:
        raise NameError('Cont type must be Gaussian or Cauchy!')
def act(self, observation, device, grad=False, return_dist=False):
    """Sample an action (and log-prob) from the actor's Gaussian head.

    The actor emits concatenated means and raw variances; raw variances
    are affinely mapped into (0, 1] (floored at 1e-8) before the
    distribution is built. With ``return_dist`` the distribution itself
    is returned; without ``grad`` the outputs are detached.
    """
    raw = self.actor(observation)
    # A flat vector of length 2*action_dim means a single, unbatched obs.
    unbatched = raw.size() == torch.Size([self.action_dim * 2])
    split_dim = 0 if unbatched else 1
    means, raw_vars = torch.split(raw, self.action_dim, dim=split_dim)
    # Map raw outputs from [-1, 1] into (0, 1], flooring at 1e-8.
    variances = torch.clamp_min((raw_vars + 1) / 2, 1e-8)
    per_sample_vars = [variances] if unbatched else variances
    covariances = torch.stack(
        [torch.diag(v) for v in per_sample_vars])
    try:
        dist = MultivariateNormal(means, covariances)
    except Exception as e:
        # Dump the offending tensors before bailing out; a singular or
        # non-finite covariance is the usual culprit.
        print(e)
        print("Action Means")
        print(means)
        print("Action Variances")
        print(covariances)
        print("Observations")
        print(observation)
        exit()
    if return_dist:
        return dist
    action = dist.sample()
    action_logprob = dist.log_prob(action)
    if not grad:
        action = action.detach()
        action_logprob = action_logprob.detach()
    return action, action_logprob
def run(setting='discrete_discrete'):
    """Learn an OT plan and a barycentric mapping between a Gaussian source
    and a circle target, in either discrete-discrete or continuous-discrete
    mode; returns (x, y, mapped, plan_objectives, map_objectives) as numpy.

    Relies on module-level settings: n_target_samples, batch_size, lr,
    n_plan_iter, n_map_iter.
    """
    if setting == 'discrete_discrete':
        y, wy = make_circle(radius=4, n_samples=n_target_samples)
        x, wx = make_circle(radius=2, n_samples=n_target_samples)
        x = torch.from_numpy(x).float()
        y = torch.from_numpy(y).float()
        wy = torch.from_numpy(wy).float()
        wx = torch.from_numpy(wx).float()
        # NOTE(review): the circle source built just above is immediately
        # overwritten by uniform-weighted Gaussian samples — looks like a
        # leftover experiment; confirm which source is intended.
        x = MultivariateNormal(torch.zeros(2), torch.eye(2) / 4)
        x = x.sample((n_target_samples, ))
        wx = np.full(len(x), 1 / len(x))
        wx = torch.from_numpy(wx).float()
        ot_plan = OTPlan(source_type='discrete', target_type='discrete',
                         target_length=len(y), source_length=len(x))
    elif setting == 'continuous_discrete':
        # Continuous source: sample fresh Gaussian minibatches each step.
        x = MultivariateNormal(torch.zeros(2), torch.eye(2) / 4)
        y, wy = make_circle(radius=4, n_samples=n_target_samples)
        y = torch.from_numpy(y).float()
        wy = torch.from_numpy(wy).float()
        ot_plan = OTPlan(source_type='continuous', target_type='discrete',
                         target_length=len(y), source_dim=2)
    else:
        raise ValueError
    mapping = Mapping(ot_plan, dim=2)
    optimizer = Adam(ot_plan.parameters(), amsgrad=True, lr=lr)
    # optimizer = SGD(ot_plan.parameters(), lr=lr)
    plan_objectives = []
    map_objectives = []
    print('Learning OT plan')
    for i in range(n_plan_iter):
        optimizer.zero_grad()
        # Draw a minibatch: weighted indices in discrete mode, fresh
        # samples in continuous mode.
        if setting == 'discrete_discrete':
            this_yidx = torch.multinomial(wy, batch_size)
            this_y = y[this_yidx]
            this_xidx = torch.multinomial(wx, batch_size)
            this_x = x[this_xidx]
        else:
            this_x = x.sample((batch_size,))
            this_yidx = torch.multinomial(wy, batch_size)
            this_y = y[this_yidx]
            this_xidx = None
        loss = ot_plan.loss(this_x, this_y, yidx=this_yidx, xidx=this_xidx)
        loss.backward()
        optimizer.step()
        plan_objectives.append(-loss.item())
        if i % 100 == 0:
            print(f'Iter {i}, loss {-loss.item():.3f}')
    # Second phase: fit the barycentric mapping with the plan fixed.
    optimizer = Adam(mapping.parameters(), amsgrad=True, lr=lr)
    # optimizer = SGD(mapping.parameters(), lr=1e-5)
    print('Learning barycentric mapping')
    for i in range(n_map_iter):
        optimizer.zero_grad()
        if setting == 'discrete_discrete':
            this_yidx = torch.multinomial(wy, batch_size)
            this_y = y[this_yidx]
            this_xidx = torch.multinomial(wx, batch_size)
            this_x = x[this_xidx]
        else:
            this_x = x.sample((batch_size,))
            this_yidx = torch.multinomial(wy, batch_size)
            this_y = y[this_yidx]
            this_xidx = None
        loss = mapping.loss(this_x, this_y, yidx=this_yidx, xidx=this_xidx)
        loss.backward()
        optimizer.step()
        map_objectives.append(loss.item())
        if i % 100 == 0:
            print(f'Iter {i}, loss {loss.item():.3f}')
    # Materialize the continuous source so it can be mapped and returned.
    if setting == 'continuous_discrete':
        x = x.sample((len(y),))
    with torch.no_grad():
        mapped = mapping(x)
    x = x.numpy()
    y = y.numpy()
    mapped = mapped.numpy()
    return x, y, mapped, plan_objectives, map_objectives