Example #1
    def predict(self, obs, deterministic=False) -> np.ndarray:
        # Run inference in eval mode, then restore training mode once.
        self.policy_network.eval()
        with th.no_grad():
            if deterministic:
                action = self.policy_network(obs).detach().cpu().numpy()
            else:
                distr = MultivariateNormal(self.policy_network(obs),
                                           self.action_std)
                action = distr.sample().detach().cpu().numpy()
        self.policy_network.train()

        return action
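In this snippet, self.action_std is passed as the second positional argument of MultivariateNormal, which is covariance_matrix, so it must be a full (or diagonal) covariance matrix rather than a vector of standard deviations. A minimal sketch of building such a matrix from a hypothetical per-dimension std:

import torch
from torch.distributions import MultivariateNormal

# Hypothetical 4-dim action space with a fixed per-dimension std of 0.5.
action_dim = 4
action_std = torch.full((action_dim,), 0.5)
action_cov = torch.diag(action_std ** 2)  # covariance = diag(std^2)

dist = MultivariateNormal(torch.zeros(action_dim), action_cov)
print(dist.sample().shape)  # torch.Size([4])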
Example #2
 def act(self, state, memory):
     action_mean = self.actor(state)
     cov_mat = torch.diag(self.action_var).to(device)
     
     dist = MultivariateNormal(action_mean, cov_mat)
     action = dist.sample()
     action_logprob = dist.log_prob(action)
     
     memory.states.append(state)
     memory.actions.append(action)
     memory.logprobs.append(action_logprob)
     
     return action.detach()
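Many of these PPO-style snippets build a diagonal covariance with torch.diag. A diagonal-covariance MultivariateNormal factorizes into independent Normals, which this small check (my own sketch, not part of the example) illustrates:

import torch
from torch.distributions import Independent, MultivariateNormal, Normal

mean = torch.zeros(3)
var = torch.tensor([0.1, 0.2, 0.3])

mvn = MultivariateNormal(mean, torch.diag(var))
ind = Independent(Normal(mean, var.sqrt()), 1)

x = mvn.sample()
# Log-probabilities agree: a diagonal MVN is a product of independent Normals.
print(torch.allclose(mvn.log_prob(x), ind.log_prob(x)))  # True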
Example #3
    def sample(self, num_samples):
        if self.prior is None:
            norm_prior = MultivariateNormal(torch.zeros(self.num_vars).to(self.s_net.device),
                                            torch.eye(self.num_vars).to(self.s_net.device))

            x = norm_prior.sample((num_samples, ))
        else:
            x = self.prior.sample(num_samples)

        for layer in self.layers:
            x = layer.f(x)

        return x
Example #4
class FuzzedExpansion(object):
    def __init__(self, new_dims, fuzz_scale):
        self.new_dims = new_dims
        if new_dims > 0:
            self._fuzz_distribution = MultivariateNormal(
                torch.zeros(new_dims),
                torch.eye(new_dims) * fuzz_scale)

    def __call__(self, X):
        if self.new_dims < 1:
            return X
        fuzz = self._fuzz_distribution.sample((X.size(0), ))
        return torch.cat([X, fuzz], dim=1)
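A usage sketch for the class above (inputs are illustrative): appending noisy padding dimensions to a batch.

import torch

# Pad a batch of 5 three-dimensional samples with 2 low-variance noise dims.
expand = FuzzedExpansion(new_dims=2, fuzz_scale=0.1)
X = torch.randn(5, 3)
print(expand(X).shape)  # torch.Size([5, 5])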
Example #5
 def act(self, state):
     state = torch.from_numpy(state).float().to(device)
     action_mean = self.action_layer(state)  # network output is the mean, not probabilities
     cov_mat = torch.diag(self.action_var).to(device)

     dist = MultivariateNormal(action_mean, cov_mat)
     action = dist.sample()
     log_prob = dist.log_prob(action)

     return action, log_prob
Example #6
    def choose_action(self, observation, memory):
        action_mean = self.actor(observation)
        cov_mat = torch.diag(self.agent_action).to(device)

        dist = MultivariateNormal(action_mean, cov_mat)
        action = dist.sample()
        action_log_prob = dist.log_prob(action)

        memory.observations.append(observation)
        memory.actions.append(action)
        memory.log_probs.append(action_log_prob)

        return action.detach()
Example #7
def get_action(policy_new, obs):
    global cov_mat
    mean = policy_new(obs)
    dist = MultivariateNormal(mean, cov_mat)

    # Sample an action from the distribution
    action = dist.sample()

    # Calculate the log probability for that action
    log_prob = dist.log_prob(action)

    # Return the sampled action and the log probability of that action in our distribution
    return action.detach().numpy(), log_prob.detach()
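get_action above reads a module-level cov_mat. A typical setup, as in PPO-from-scratch tutorials, is a fixed diagonal covariance; the names below are illustrative assumptions.

import torch

act_dim = 2
cov_var = torch.full((act_dim,), 0.5)  # fixed per-dimension variance
cov_mat = torch.diag(cov_var)          # shared across all get_action calls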
Example #8
    def act(self, state):
        if self.has_continuous_action_space:
            action_mean = self.actor(state)
            cov_mat = torch.diag(self.action_var).unsqueeze(dim=0)
            dist = MultivariateNormal(action_mean, cov_mat)
        else:
            action_probs = self.actor(state)
            dist = Categorical(action_probs)

        action = dist.sample()
        action_logprob = dist.log_prob(action)

        return action.detach(), action_logprob.detach()
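Why the unsqueeze(dim=0) above: it turns the (D, D) covariance into (1, D, D) so it broadcasts against a batched action mean of shape (B, D). A quick standalone check:

import torch
from torch.distributions import MultivariateNormal

B, D = 8, 3
mean = torch.randn(B, D)
cov = torch.diag(torch.full((D,), 0.25)).unsqueeze(dim=0)  # (1, D, D)
dist = MultivariateNormal(mean, cov)
print(dist.sample().shape)  # torch.Size([8, 3])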
Example #9
    def _init_multivar(self):
        """Initialize the walkers in a sphere covering the molecule

        Returns:
            torch.tensor -- positions of the walkers
        """
        multi = MultivariateNormal(
            torch.tensor(self.init_domain['mean']),
            torch.tensor(self.init_domain['sigma']))
        pos = multi.sample((self.nwalkers, self.nelec)).type(
            torch.get_default_dtype())
        pos = pos.view(self.nwalkers, self.nelec * self.ndim)
        return pos.to(device=self.device)
Example #10
File: linear_gaussian.py Project: bkmi/sbi
def samples_true_posterior_linear_gaussian_uniform_prior(
    x_o: Tensor,
    likelihood_shift: Tensor,
    likelihood_cov: Tensor,
    prior: Union[Uniform, Independent],
    num_samples: int = 1000,
) -> Tensor:
    """
    Returns ground truth posterior samples for Gaussian likelihood and uniform prior.

    Args:
        x_o: The observation.
        likelihood_shift: Mean of the likelihood p(x|theta) is likelihood_shift+theta.
        likelihood_cov: Covariance matrix of likelihood.
        prior: Uniform prior distribution.
        num_samples: Desired number of samples.

    Returns: Samples from posterior.
    """

    # Let s denote the likelihood_shift:
    # The likelihood has the term (x-(s+theta))^2 in the exponent of the Gaussian.
    # In other words, as a function of x, the mean of the likelihood is s+theta.
    # For computing the posterior we need the likelihood as a function of theta. Hence:
    # (x-(s+theta))^2 = (theta-(-s+x))^2
    # We see that the mean is -s+x = x-s

    # Take into account iid trials
    x_o = atleast_2d(x_o)
    num_trials, *_ = x_o.shape
    x_o_mean = x_o.mean(0)
    likelihood_mean = x_o_mean - likelihood_shift

    posterior = MultivariateNormal(loc=likelihood_mean,
                                   covariance_matrix=1 / num_trials *
                                   likelihood_cov)

    # generate samples from ND Gaussian truncated by prior support
    num_remaining = num_samples
    samples = []

    while num_remaining > 0:
        candidate_samples = posterior.sample(
            sample_shape=torch.Size((num_remaining, )))
        is_in_prior = within_support(prior, candidate_samples)
        # accept if in prior
        if is_in_prior.sum():
            samples.append(candidate_samples[is_in_prior, :])
            num_remaining -= int(is_in_prior.sum().item())

    return torch.cat(samples)
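The while loop above implements truncation by rejection: Gaussian posterior samples are kept only where the prior has support. As an illustration (not sbi's actual helper), a simplified stand-in for the within_support call, assuming a box (independent Uniform) prior:

import torch

def within_box_support(low, high, samples):
    # Accept a sample when every coordinate lies inside the box.
    return ((samples >= low) & (samples <= high)).all(dim=-1)

low, high = -torch.ones(2), torch.ones(2)
candidates = torch.randn(5, 2)
print(within_box_support(low, high, candidates))  # boolean mask of shape (5,)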
Example #11
    def act(self, state):
        '''Choose action according to the policy.'''
        action_mu, action_sigma, state_value = self.forward(state)

        action_var = self.action_var.expand_as(action_mu)
        cov_mat = torch.diag_embed(action_var)
        dist = MultivariateNormal(action_mu, cov_mat)
        action = dist.sample()
        # Clamp to the valid action range (note: the log-prob below is
        # evaluated on the clamped action).
        action = torch.clamp(action, 0, 1)
        log_prob = dist.log_prob(action)

        return action.detach(), log_prob.detach()
Example #12
    def forward(self, state):
        value = self.critic(state)
        action_mean = self.actor(state)
        cov_mat = torch.diag(self.action_var).to(self.device)
        dist = MultivariateNormal(action_mean, cov_mat)

        if not self.random_action:
            action = action_mean
        else:
            action = dist.sample()

        action_logprobs = dist.log_prob(action)

        return action.detach(), action_logprobs, value
Example #13
File: ppo.py Project: RCAVelez/SumoRC
    def get_action(self, obs, actorIndex):
        mean = None
        if actorIndex == 1:
            mean = self.actor1(obs)
        if actorIndex == 2:
            mean = self.actor2(obs)

        dist = MultivariateNormal(mean, self.cov_mat)

        action = dist.sample()
        log_prob = dist.log_prob(action)

        return action.detach().numpy(), log_prob.detach()  # might break for me here
Example #14
    def act(self, state, memory):
        state_input = self.conv(state).view(-1, self.size)
        action_mean = self.actor(state_input)
        cov_mat = torch.diag(self.action_var)

        dist = MultivariateNormal(action_mean, cov_mat)
        action = dist.sample()
        action_logprob = dist.log_prob(action)

        memory.states.append(state)
        memory.actions.append(action)
        memory.logprobs.append(action_logprob)

        return action.detach()
Example #15
    def get_next_state(self, state, action):
        hidden = self.hidden_arr[3](torch.cat((state, action), dim=-1))
        next_state_mean = self.policy_arr[3](hidden)
        next_state_log_std = self.next_state_std.expand_as(next_state_mean)
        next_state_std = torch.exp(next_state_log_std)
        next_state_dist = MultivariateNormal(next_state_mean,
                                             torch.diag_embed(next_state_std))
        next_state = next_state_dist.sample()
        next_state_log_prob = next_state_dist.log_prob(next_state).reshape(
            -1, 1)

        continuous_state = state[..., -17 - 6:-17] + action[..., -6:]
        return torch.cat((action[..., :-6], continuous_state, next_state),
                         dim=-1), next_state_log_prob
Example #16
    def signature(self, X: torch.Tensor):
        """
        :return: signature matrix with shape (bands, rows, n_samples)
        """
        device = X.device
        N, D = X.shape

        distribution = MultivariateNormal(torch.zeros(D), torch.eye(D))
        random_planes = distribution.sample((self.bands * self.rows,)).to(device)

        # signature_matrix is (b*r) x N
        signature_matrix = (torch.mm(random_planes, X.t()) >= 0).int() * 2 - 1

        return signature_matrix.reshape(self.bands, self.rows, N)
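A standalone sketch of the signature computation above (shapes are illustrative): each sampled row is the normal vector of a random hyperplane, and the sign of the dot product with a data point yields one +/-1 bit, the basis of SimHash-style locality-sensitive hashing.

import torch
from torch.distributions import MultivariateNormal

N, D, n_planes = 100, 16, 32
X = torch.randn(N, D)
planes = MultivariateNormal(torch.zeros(D), torch.eye(D)).sample((n_planes,))
signature = (planes @ X.t() >= 0).int() * 2 - 1  # (n_planes, N), entries in {-1, +1}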
Example #17
 def forward(self, state):
     output_1 = F.relu(self.linear1(state))
     output_2 = F.relu(self.linear2(output_1))
     mu = 2 * torch.sigmoid(self.mu(output_2))     # mean squashed into (0, 2)
     sigma = F.relu(self.sigma(output_2)) + 0.001  # avoid a zero scale
     # Keep mu as a vector; only sigma is embedded as a diagonal matrix.
     mu = mu.to(device)
     cov_mat = torch.diag_embed(sigma).to(device)  # sigma used as the covariance diagonal
     dist = MultivariateNormal(mu, cov_mat)
     entropy = dist.entropy().mean()
     action = dist.sample()
     action_logprob = dist.log_prob(action)
     return action.detach(), action_logprob, entropy
Example #18
    def act(self, state, opponent_state):
        if self.has_continuous_action_space:
            pre_mean, pre_sigma = self.om(opponent_state)
            pre_var = pre_sigma**2
            pre_var = pre_var.repeat(1, 2).to(device)
            pre_mat = torch.diag_embed(pre_var).to(device)
            pre_dist = MultivariateNormal(pre_mean, pre_mat)
            pre_action = pre_dist.sample()
            pre_action = pre_action.clamp(-1, 1)
            action_mean, action_sigma = self.actor(state, pre_action[0])
            action_var = action_sigma**2
            action_var = action_var.repeat(1, 2).to(device)
            cov_mat = torch.diag_embed(action_var).to(device)
            dist = MultivariateNormal(action_mean, cov_mat)
        else:
            action_probs = self.actor(state)
            dist = Categorical(action_probs)

        action = dist.sample()
        if self.has_continuous_action_space:
            # Clamping only makes sense for the continuous action vector.
            action = action.clamp(-1, 1)
        action_logprob = dist.log_prob(action)

        # Note: pre_action is only defined on the continuous path above.
        return action.detach(), action_logprob.detach(), pre_action.detach()
Example #19
    def forward(self, x, a):
        x = torch.tensor(x, dtype=torch.float)

        x = torch.tanh(self.fc1(x))
        x = torch.tanh(self.fc2(x))

        means = self.means(x)
        dist = MultivariateNormal(means, self.eye)

        if a is None:
            a = dist.sample().detach().numpy()

        log_prob = dist.log_prob(torch.tensor(a, dtype=torch.float))
        return log_prob, a
Example #20
def test_c2st_snle_external_data_on_linearGaussian(set_seed):
    """Test whether SNLE with presimulated external data infers a simple example
    with available ground truth.

    Args:
        set_seed: fixture for manual seeding
    """

    num_dim = 2

    device = "cpu"
    configure_default_device(device)
    x_o = zeros(1, num_dim)
    num_samples = 1000

    # likelihood_mean will be likelihood_shift+theta
    likelihood_shift = -1.0 * ones(num_dim)
    likelihood_cov = 0.3 * eye(num_dim)

    prior_mean = zeros(num_dim)
    prior_cov = eye(num_dim)
    prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov)
    gt_posterior = true_posterior_linear_gaussian_mvn_prior(
        x_o[0], likelihood_shift, likelihood_cov, prior_mean, prior_cov)
    target_samples = gt_posterior.sample((num_samples, ))

    def simulator(theta):
        return linear_gaussian(theta, likelihood_shift, likelihood_cov)

    infer = SNL(
        *prepare_for_sbi(simulator, prior),
        simulation_batch_size=1000,
        show_progress_bars=False,
        device=device,
    )

    external_theta = prior.sample((1000, ))
    external_x = simulator(external_theta)

    infer.provide_presimulated(external_theta, external_x)

    posterior = infer(
        num_rounds=1,
        num_simulations_per_round=1000,
        training_batch_size=100,
    ).set_default_x(x_o)
    samples = posterior.sample((num_samples, ))

    # Compute the c2st and assert it is near chance level of 0.5.
    check_c2st(samples, target_samples, alg="snle")
Example #21
    def forward(self, input_, action=None):
        """NAF-style forward pass: returns a (noisy) action, Q(s, a) and V(s)."""

        x = torch.relu(self.fc1(input_))
        x = self.bn1(x)
        x = torch.relu(self.fc2(torch.cat([x, input_],dim=1)))
        x = self.bn2(x)
        x = torch.relu(self.fc3(torch.cat([x, input_],dim=1)))
        x = self.bn3(x)
        x = torch.relu(self.fc4(torch.cat([x, input_],dim=1)))

        action_value = torch.tanh(self.action_values(x))
        entries = torch.tanh(self.matrix_entries(x))
        V = self.value(x)
        
        action_value = action_value.unsqueeze(-1)
        
        # create lower-triangular matrix
        L = torch.zeros((input_.shape[0], self.action_size, self.action_size)).to(device)

        # get lower triagular indices
        tril_indices = torch.tril_indices(row=self.action_size, col=self.action_size, offset=0)  

        # fill matrix with entries
        L[:, tril_indices[0], tril_indices[1]] = entries
        L.diagonal(dim1=1,dim2=2).exp_()

        # calculate state-dependent, positive-definite matrix P = L @ L^T
        P = torch.bmm(L, L.transpose(2, 1))
        
        Q = None
        if action is not None:  

            # calculate Advantage:
            A = (-0.5 * torch.matmul(torch.matmul((action.unsqueeze(-1) - action_value).transpose(2, 1), P), (action.unsqueeze(-1) - action_value))).squeeze(-1)

            Q = A + V

        # add noise to the action mean by sampling from N(mu, P^-1)
        dist = MultivariateNormal(action_value.squeeze(-1), torch.inverse(P))
        action = dist.sample()
        action = torch.clamp(action, min=-1, max=1)

        return action, Q, V
Example #22
    def act(self, ob, state, memory):
        latent = self.encoder(ob)
        state = torch.cat((latent, state), 1)
        action_mean = self.actor(state)
        dist = MultivariateNormal(action_mean,
                                  torch.diag(self.action_var).to(device))
        action = dist.sample()
        action_logprob = dist.log_prob(action)

        memory.states.append(state)
        # memory.obs.append(ob)
        memory.actions.append(action)
        memory.logprobs.append(action_logprob)

        return action.detach()
Example #23
class XavierMVNPopulationInitializer(BasePopulationInitializer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        sd = torch.eye(self.lambda_, device=self.device).cpu()
        mean = (torch.zeros_like(self.initial_value).unsqueeze(1).repeat(
            1, self.lambda_).cpu())
        self.normal = MultivariateNormal(mean, sd)

    def get_new_population(self, lower, upper):
        population = self.normal.sample().to(self.device)
        population *= self.xavier_coeffs[:, None]
        population += self.initial_value[:, None]
        population[:, 0] = self.initial_value
        return bounce_back_boundary_2d(population, lower, upper)
Example #24
    def generate(self,
                 length=300,
                 batch=1,
                 bias=0.25,
                 device=torch.device("cpu")):
        """
        Get a random sample from the distribution (model)
        """
        samples = torch.zeros(length + 1, batch, 3,
                              device=device)  # batch_first=false
        lstm_states = None

        for i in range(1, length + 1):
            # get distribution parameters
            with torch.no_grad():
                e, log_pi, mu, sigma, rho, lstm_states = self.forward(
                    samples[i - 1:i], lstm_states)
            # sample from the distribution (returned parameters)
            # samples[i, :, 0] = e[-1, :] > 0.5
            distrbn1 = Bernoulli(e[-1, :])
            samples[i, :, 0] = distrbn1.sample()

            # selected_mode = torch.argmax(log_pi[-1, :, :], dim=1) # shape = (batch,)
            distrbn2 = Categorical((log_pi[-1, :, :] * (1 + bias)).exp())
            selected_mode = distrbn2.sample()

            index_1 = selected_mode.unsqueeze(1)  # shape (batch, 1)
            # shape (batch, 1, 2)
            index_2 = torch.stack([index_1, index_1], dim=2)

            mu = (mu[-1].view(batch, self.n_gaussians,
                              2).gather(dim=1, index=index_2).squeeze(dim=1))
            sigma = ((sigma[-1] / torch.exp(torch.tensor(1 + bias))).view(
                batch, self.n_gaussians,
                2).gather(dim=1, index=index_2).squeeze(dim=1))
            rho = rho[-1].gather(dim=1, index=index_1).squeeze(dim=1)

            sigma2d = sigma.new_zeros(batch, 2, 2)
            sigma2d[:, 0, 0] = sigma[:, 0]**2
            sigma2d[:, 1, 1] = sigma[:, 1]**2
            sigma2d[:, 0, 1] = rho[:] * sigma[:, 0] * sigma[:, 1]
            sigma2d[:, 1, 0] = sigma2d[:, 0, 1]

            distribn = MultivariateNormal(loc=mu, covariance_matrix=sigma2d)

            samples[i, :, 1:] = distribn.sample()

        return samples[1:, :, :]  # remove dummy first zeros
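The 2x2 covariance assembled above is positive definite exactly when both sigmas are positive and |rho| < 1. A quick standalone check with illustrative values:

import torch
from torch.distributions import MultivariateNormal

s1, s2, rho = 1.0, 0.5, 0.3
cov = torch.tensor([
    [s1 * s1, rho * s1 * s2],
    [rho * s1 * s2, s2 * s2],
])
dist = MultivariateNormal(torch.zeros(2), cov)
print(dist.sample((3,)).shape)  # torch.Size([3, 2])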
Example #25
    def get_net_action(self, state, num_trajs=1):
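        # Note: net_name and action_name are free variables in this excerpt;
        # they are defined in the enclosing scope of the original project.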
        net = getattr(self, net_name)
        n_action_dim = getattr(self, 'n_' + action_name)
        onehot_action_dim = getattr(self, 'onehot_' + action_name + '_dim')
        multihot_action_dim = getattr(self, 'multihot_' + action_name + '_dim')
        sections = getattr(self, 'onehot_' + action_name + '_sections')
        continuous_action_log_std = getattr(
            self, net_name + '_' + action_name + '_std')
        onehot_action_probs_with_continuous_mean = net(state)

        onehot_actions = torch.empty((num_trajs, 0), device=self.device)
        multihot_actions = torch.empty((num_trajs, 0), device=self.device)
        continuous_actions = torch.empty((num_trajs, 0), device=self.device)
        onehot_actions_log_prob = 0
        multihot_actions_log_prob = 0
        continuous_actions_log_prob = 0
        if onehot_action_dim != 0:
            dist = MultiOneHotCategorical(
                onehot_action_probs_with_continuous_mean[
                    ..., :onehot_action_dim], sections)
            onehot_actions = dist.sample()
            onehot_actions_log_prob = dist.log_prob(onehot_actions)
        if multihot_action_dim != 0:
            multihot_actions_prob = torch.sigmoid(
                onehot_action_probs_with_continuous_mean[
                    ...,
                    onehot_action_dim:onehot_action_dim + multihot_action_dim])
            dist = torch.distributions.bernoulli.Bernoulli(
                probs=multihot_actions_prob)
            multihot_actions = dist.sample()
            multihot_actions_log_prob = dist.log_prob(multihot_actions).sum(
                dim=1)
        if n_action_dim - onehot_action_dim - multihot_action_dim != 0:
            continuous_actions_mean = onehot_action_probs_with_continuous_mean[
                ..., onehot_action_dim + multihot_action_dim:]
            continuous_log_std = continuous_action_log_std.expand_as(
                continuous_actions_mean)
            continuous_actions_std = torch.exp(continuous_log_std)
            continuous_dist = MultivariateNormal(
                continuous_actions_mean,
                torch.diag_embed(continuous_actions_std))
            continuous_actions = continuous_dist.sample()
            continuous_actions_log_prob = continuous_dist.log_prob(
                continuous_actions)

        return onehot_actions, multihot_actions, continuous_actions, FloatTensor(
            onehot_actions_log_prob + multihot_actions_log_prob +
            continuous_actions_log_prob).unsqueeze(-1)
Example #26
    def forward(self, x, n_samples, reparam=True, squeeze=True):
        q_m = self.mean_encoder(x)
        l_mat = self.var_encoder
        q_v = l_mat.matmul(l_mat.T)

        variational_dist = MultivariateNormal(loc=q_m, scale_tril=l_mat)

        if squeeze and n_samples == 1:
            sample_shape = []
        else:
            sample_shape = (n_samples, )
        if reparam:
            latent = variational_dist.rsample(sample_shape=sample_shape)
        else:
            latent = variational_dist.sample(sample_shape=sample_shape)
        return dict(q_m=q_m, q_v=q_v, latent=latent)
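Passing scale_tril= (a lower-triangular Cholesky factor L, here the var_encoder parameter matrix) as above avoids the internal Cholesky decomposition that covariance_matrix= would trigger; the implied covariance is L @ L.T. A quick equivalence check, my own sketch:

import torch
from torch.distributions import MultivariateNormal

L = torch.tril(torch.randn(3, 3))
L.diagonal().abs_().add_(0.1)  # scale_tril needs a strictly positive diagonal
d1 = MultivariateNormal(torch.zeros(3), scale_tril=L)
d2 = MultivariateNormal(torch.zeros(3), covariance_matrix=L @ L.T)
x = d1.sample()
print(torch.allclose(d1.log_prob(x), d2.log_prob(x)))  # True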
Example #27
    def sample_tracking_direction_prob(self,
                                       learned_gaussian_params: Tuple[Tensor,
                                                                      Tensor]):
        """
        From the gaussian parameters, sample a direction.
        """
        means, sigmas = learned_gaussian_params

        # Sample a final function in the chosen Gaussian
        # One direction per time step per sequence
        distribution = MultivariateNormal(means,
                                          covariance_matrix=torch.diag_embed(
                                              sigmas**2))
        direction = distribution.sample()

        return direction
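What diag_embed does above: a (..., D) tensor of per-direction sigmas becomes a (..., D, D) batch of diagonal covariance matrices, one per time step per sequence. Shapes below are illustrative:

import torch
from torch.distributions import MultivariateNormal

sigmas = torch.rand(10, 4, 3) + 0.1   # (time, batch, 3)
cov = torch.diag_embed(sigmas ** 2)   # (time, batch, 3, 3)
dist = MultivariateNormal(torch.zeros(10, 4, 3), cov)
print(dist.sample().shape)            # torch.Size([10, 4, 3])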
Example #28
    def act(self, state, memory):
        action_mean = self.actor(state)
        cov_mat = torch.diag(self.action_var).to(device)

        dist = MultivariateNormal(action_mean, cov_mat)
        action = dist.sample()
        # Don't allow duckie to go backwards or spin in place
        action[:, 0] = torch.clamp(action[:, 0], min=0)
        action[:, 1] = torch.clamp(action[:, 1], min=-1, max=1)
        action_logprob = dist.log_prob(action)

        memory.states.append(state)
        memory.actions.append(action)
        memory.logprobs.append(action_logprob)

        return action.detach()
Example #29
def test_expectation_multivariate_norm():

    Ns = range(1, 50)
    analyticE = [expect_multivariate_norm(N) for N in Ns]

    E = []
    for N in Ns:
        dist = MultivariateNormal(torch.zeros(N), torch.eye(N))
        norms = [sample.norm().item() for sample in dist.sample((100, )).unbind(0)]
        E.append(sum(norms) / 100)

    plt.plot(Ns, E)
    plt.plot(Ns, analyticE)
    plt.show()
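A plausible implementation of the expect_multivariate_norm helper used above (an assumption; the original is not shown). For x ~ N(0, I_N), the norm follows a chi distribution with N degrees of freedom, so E[||x||] = sqrt(2) * Gamma((N + 1) / 2) / Gamma(N / 2), which tends to sqrt(N) as N grows.

import math

def expect_multivariate_norm(N):
    # lgamma keeps the ratio of Gamma functions numerically stable for large N.
    return math.sqrt(2.0) * math.exp(math.lgamma((N + 1) / 2) - math.lgamma(N / 2))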
Example #30
    def sample_tracking_directions(self,
                                   outputs: Tuple[torch.Tensor, torch.Tensor]) \
            -> torch.Tensor:
        """
        From the gaussian parameters, sample a direction.
        """
        means, variances = outputs

        # Sample a final function in the chosen Gaussian
        # One direction per time step per sequence
        distribution = MultivariateNormal(means,
                                          covariance_matrix=torch.diag_embed(
                                              variances**2))
        direction = distribution.sample()

        return direction
Example #31
def run(setting='discrete_discrete'):
    if setting == 'discrete_discrete':
        y, wy = make_circle(radius=4, n_samples=n_target_samples)
        y = torch.from_numpy(y).float()
        wy = torch.from_numpy(wy).float()

        # Source: uniformly weighted samples from a centered Gaussian (this
        # supersedes an earlier make_circle source that was overwritten).
        x = MultivariateNormal(torch.zeros(2), torch.eye(2) / 4)
        x = x.sample((n_target_samples, ))
        wx = torch.from_numpy(np.full(len(x), 1 / len(x))).float()

        ot_plan = OTPlan(source_type='discrete', target_type='discrete',
                         target_length=len(y), source_length=len(x))
    elif setting == 'continuous_discrete':
        x = MultivariateNormal(torch.zeros(2), torch.eye(2) / 4)
        y, wy = make_circle(radius=4, n_samples=n_target_samples)

        y = torch.from_numpy(y).float()
        wy = torch.from_numpy(wy).float()

        ot_plan = OTPlan(source_type='continuous', target_type='discrete',
                         target_length=len(y), source_dim=2)

    else:
        raise ValueError

    mapping = Mapping(ot_plan, dim=2)
    optimizer = Adam(ot_plan.parameters(), amsgrad=True, lr=lr)
    # optimizer = SGD(ot_plan.parameters(), lr=lr)

    plan_objectives = []
    map_objectives = []

    print('Learning OT plan')

    for i in range(n_plan_iter):
        optimizer.zero_grad()

        if setting == 'discrete_discrete':
            this_yidx = torch.multinomial(wy, batch_size)
            this_y = y[this_yidx]
            this_xidx = torch.multinomial(wx, batch_size)
            this_x = x[this_xidx]
        else:
            this_x = x.sample((batch_size,))
            this_yidx = torch.multinomial(wy, batch_size)
            this_y = y[this_yidx]
            this_xidx = None
        loss = ot_plan.loss(this_x, this_y, yidx=this_yidx, xidx=this_xidx)
        loss.backward()
        optimizer.step()
        plan_objectives.append(-loss.item())
        if i % 100 == 0:
            print(f'Iter {i}, loss {-loss.item():.3f}')

    optimizer = Adam(mapping.parameters(), amsgrad=True, lr=lr)
    # optimizer = SGD(mapping.parameters(), lr=1e-5)

    print('Learning barycentric mapping')
    for i in range(n_map_iter):
        optimizer.zero_grad()

        if setting == 'discrete_discrete':
            this_yidx = torch.multinomial(wy, batch_size)
            this_y = y[this_yidx]
            this_xidx = torch.multinomial(wx, batch_size)
            this_x = x[this_xidx]
        else:
            this_x = x.sample((batch_size,))
            this_yidx = torch.multinomial(wy, batch_size)
            this_y = y[this_yidx]
            this_xidx = None

        loss = mapping.loss(this_x, this_y, yidx=this_yidx, xidx=this_xidx)
        loss.backward()
        optimizer.step()
        map_objectives.append(loss.item())
        if i % 100 == 0:
            print(f'Iter {i}, loss {loss.item():.3f}')

    if setting == 'continuous_discrete':
        x = x.sample((len(y),))
    with torch.no_grad():
        mapped = mapping(x)
    x = x.numpy()
    y = y.numpy()
    mapped = mapped.numpy()

    return x, y, mapped, plan_objectives, map_objectives