Example #1
    def __init__(
        self, net, device, training_data, validation_data, batch_size, energy_evaluator
    ):
        self.net = net
        self.device = device
        self.training_data = training_data
        self.validation_data = validation_data
        self.batch_size = batch_size
        self.energy_evaluator = energy_evaluator

        self.training_indices = np.arange(self.training_data.shape[0])
        self.validation_indices = np.arange(self.validation_data.shape[0])

        # Set up the latent Gaussian distribution
        mu = torch.zeros(self.training_data.shape[-1] - 6, device=device)
        cov = torch.eye(self.training_data.shape[-1] - 6, device=device)
        self.latent_distribution = distributions.MultivariateNormal(
            mu, covariance_matrix=cov
        ).expand((self.batch_size,))

        # These statistics are updated during training.
        self.forward_loss = None
        self.forward_ml = None
        self.forward_jac = None
        self.val_forward_loss = None
        self.val_forward_ml = None
        self.val_forward_jac = None
        self.inverse_loss = None
        self.inverse_kl = None
        self.inverse_jac = None
        self.mean_energy = None
        self.median_energy = None
        self.min_energy = None
        self.acceptance_probs = []
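A note on the `.expand((self.batch_size,))` call above: expanding a MultivariateNormal broadcasts its parameters into a batch of identical distributions, so one `log_prob` call scores a whole batch of latent vectors at once. A minimal standalone sketch with made-up sizes (not part of the trainer above):

import torch
from torch import distributions

batch_size, latent_dim = 8, 5
mu = torch.zeros(latent_dim)
cov = torch.eye(latent_dim)

# The base distribution has an empty batch_shape; expand replicates it batch_size times.
latent = distributions.MultivariateNormal(mu, covariance_matrix=cov).expand((batch_size,))
print(latent.batch_shape, latent.event_shape)  # torch.Size([8]) torch.Size([5])

z = latent.sample()         # shape (8, 5): one latent vector per batch element
log_p = latent.log_prob(z)  # shape (8,): one log-density per batch element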
Example #2
def pre_train_unconditional_nsf(
    net, device, training_data, batch_size, epochs, lr, out_freq
):
    mu = torch.zeros(training_data.shape[-1] - 6, device=device)
    cov = torch.eye(training_data.shape[-1] - 6, device=device)
    dist = distributions.MultivariateNormal(mu, covariance_matrix=cov).expand(
        (batch_size,)
    )

    indices = np.arange(training_data.shape[0])
    optimizer = setup_optimizer(net, lr, 0.0)
    with tqdm(range(epochs)) as progress:
        for epoch in progress:
            net.train()

            index_batch = np.random.choice(
                indices, batch_size, replace=True  # use the batch_size argument; `args` is not in scope here
            )
            x_batch = training_data[index_batch, :]
            loss, _, _ = get_ml_loss(net, x_batch, 1.0, dist)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if epoch % out_freq == 0:
                progress.set_postfix(loss=f"{loss.item():8.3f}")
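`get_ml_loss` is not shown in this snippet; for a normalizing-flow net the maximum-likelihood objective it presumably computes is the negative of log p(x) = log p_z(f(x)) + log |det J_f(x)|. A hypothetical sketch, assuming the net's forward pass returns the latent codes and the summed log-determinant (the interface here is illustrative only, not the actual `get_ml_loss`):

def ml_loss_sketch(net, x_batch, latent_dist):
    # Hypothetical interface: net(x) -> (z, log_det_jacobian), both batched.
    z, log_det = net(x_batch)
    log_pz = latent_dist.log_prob(z)  # log-density of the codes under the latent Gaussian
    log_px = log_pz + log_det         # change-of-variables formula
    return -log_px.mean()             # negative log-likelihood loss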
Example #3
    def score_shapes_type(self, subid, shapes):
        """
        Compute the log-probability of the control points for each sub-stroke
        under the prior

        Parameters
        ----------
        subid : (nsub,) tensor
            sub-stroke ID sequence
        shapes : (ncpt, 2, nsub) tensor
            shapes of bsplines

        Returns
        -------
        ll : (nsub,) tensor
            vector of log-likelihood scores
        """
        if self.isunif:
            raise NotImplementedError
        # check that subid is a vector
        assert len(subid.shape) == 1
        # record vector length
        nsub = len(subid)
        assert shapes.shape[-1] == nsub
        # reshape tensor (ncpt, 2, nsub) -> (ncpt*2, nsub)
        shapes = shapes.view(self.ncpt * 2, nsub)
        # transpose axes (ncpt*2, nsub) -> (nsub, ncpt*2)
        shapes = shapes.transpose(0, 1)
        # create multivariate normal distribution
        mvn = dist.MultivariateNormal(self.shapes_mu[subid],
                                      self.shapes_Cov[subid])
        # score points using the multivariate normal distribution
        ll = mvn.log_prob(shapes)

        return ll
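The indexing `self.shapes_mu[subid]` / `self.shapes_Cov[subid]` is what makes this a batched distribution: each sub-stroke ID selects its own mean and covariance, and `log_prob` then returns one score per sub-stroke. A small self-contained sketch of the same pattern with toy parameters (not the library's actual ones):

import torch
import torch.distributions as dist

n_ids, ncpt = 4, 3      # toy numbers: 4 sub-stroke IDs, 3 control points
dim = ncpt * 2          # each shape flattened to (x, y) pairs

shapes_mu = torch.randn(n_ids, dim)
shapes_Cov = torch.diag_embed(torch.rand(n_ids, dim) + 0.1)  # one valid covariance per ID

subid = torch.tensor([0, 2, 2, 1])     # sub-stroke ID sequence
shapes = torch.randn(len(subid), dim)  # (nsub, ncpt*2), already transposed as above

mvn = dist.MultivariateNormal(shapes_mu[subid], shapes_Cov[subid])
ll = mvn.log_prob(shapes)  # (nsub,) log-likelihood per sub-stroke
print(ll.shape)            # torch.Size([4])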
Example #4
    def conditional_distribution(self, xti, T=1, reverse=False):
        '''
        Conditional Distribution will compute the distribution in the Polar Space
        '''
        _mu = xti
        _var = torch.zeros(xti.shape[0], self.dim, self.dim).to(xti)
        if not reverse:
            for i in range(T):
                Ad = self.first_Taylor_dyn(_mu) * self.dt + torch.eye(self.dim).to(xti)
                _mu = self.velocity(_mu) * self.dt + _mu
                # propagate covariance as Ad @ var @ Ad^T under the linearized dynamics
                _var = torch.bmm(torch.bmm(Ad, _var), Ad.transpose(1, 2)) + self.var * self.dt
        else:
            for i in range(T):
                Ad = -self.first_Taylor_dyn(_mu) * self.dt + torch.eye(self.dim).to(xti)
                _mu = -self.velocity(_mu) * self.dt + _mu
                # propagate covariance as Ad @ var @ Ad^T under the linearized dynamics
                _var = torch.bmm(torch.bmm(Ad, _var), Ad.transpose(1, 2)) + self.var * self.dt

        dists = []
        dist_r = tdist.Normal(loc=_mu[:,0], scale=torch.sqrt(_var[:,0,0]))
        dists.append(dist_r)
        dist_w = AngleNormal(loc=_mu[:,1], scale=torch.sqrt(_var[:,1,1]))
        dists.append(dist_w)
        if self.dim == 3:
            # Normal's scale is a standard deviation, matching the sqrt used for r and w above
            dist_z = tdist.Normal(loc=_mu[:, 2], scale=torch.sqrt(_var[:, 2, 2]))
            dists.append(dist_z)
        elif self.dim > 3:
            # MultivariateNormal takes covariance_matrix (or scale_tril), not `scale`
            dist_z = tdist.MultivariateNormal(loc=_mu[:, 2:], covariance_matrix=_var[:, 2:, 2:])
            dists.append(dist_z)
        return dists
Example #5
File: dim_model.py Project: Czworldy/GP
 def to(self, *args, **kwargs):
     self = super().to(*args, **kwargs)
     self._base_dist = D.MultivariateNormal(
         loc=self._base_dist.mean.to(*args, **kwargs),
         scale_tril=self._base_dist.scale_tril.to(*args, **kwargs),
     )
     return self
Example #6
    def sample_shapes_type(self, subid):
        """
        Sample the control points for each sub-stroke ID in a given sequence

        Parameters
        ----------
        subid : (nsub,) tensor
            sub-stroke ID sequence

        Returns
        -------
        shapes : (ncpt, 2, nsub) tensor
            sampled shapes of bsplines
        """
        if self.isunif:
            raise NotImplementedError
        # check that subid is a vector
        assert len(subid.shape) == 1
        # record vector length
        nsub = len(subid)
        # create multivariate normal distribution
        mvn = dist.MultivariateNormal(self.shapes_mu[subid],
                                      self.shapes_Cov[subid])
        # sample points from the multivariate normal distribution
        shapes = mvn.sample()
        # transpose axes (nsub, ncpt*2) -> (ncpt*2, nsub)
        shapes = shapes.transpose(0, 1)
        # reshape tensor (ncpt*2, nsub) -> (ncpt, 2, nsub)
        shapes = shapes.view(self.ncpt, 2, nsub)

        return shapes
Example #7
File: Models.py Project: kevin-w-li/al-ws
 def sample_logp0(self, n):
     
     logp0 = 0
     
     #pz_pi= td.Categorical(self.alpha.softmax(0).clamp(1e-5,1-1e-5))
     pz_pi= td.Categorical(logits=self.alpha)
     z    = pz_pi.sample([n])
     logp0 = logp0 + pz_pi.log_prob(z)
     
     mean = self.mu[z,:]
     cov  = torch.einsum("ijk,ilk->ijl", self.chol, self.chol)
     
     #cov = self.chol
     cov = cov + torch.eye(cov.shape[-1], device=cov.device)  * 1e-6
     cov = cov[z,...]
     
     py_z = td.MultivariateNormal(mean, cov)
     y    = py_z.sample([])
     logp0 += py_z.log_prob(y).sum(-1)
             
     p   = self.obs.conditional_param(y)
     x    = p.sample([])
     nat  = self.obs.nat(y)
     norm = self.obs.norm(y)       
     x    = x.reshape(x.shape[0], -1)  # flatten samples to (batch, features)
     
 
     return (None, z, None, None, y, p.mean.detach()), x, norm - logp0, nat
Example #8
File: dim_model.py Project: Czworldy/GP
    def __init__(self, output_shape, hidden_size: int = 64):
        """
        Args:
          output_shape: The shape of the base and data distribution (a.k.a. event_shape).
          hidden_size: The dimensionality of the GRU hidden state.
        """
        super(AutoregressiveFlow, self).__init__()
        self._output_shape = output_shape

        # Initialises the base distribution.
        self._base_dist = D.MultivariateNormal(
            loc=torch.zeros(self._output_shape[-2] * self._output_shape[-1]),
            scale_tril=torch.eye(self._output_shape[-2] *
                                 self._output_shape[-1]),
        )

        # The decoder recurrent network used for the sequence generation.
        self._decoder = nn.GRUCell(
            input_size=self._output_shape[-1],
            hidden_size=hidden_size,
        )

        # The output head.
        self._locscale = MLP(
            input_size=hidden_size,
            # output_sizes=[32, self._output_shape[0]],
            output_sizes=[32, 4],
            activation_fn=nn.ReLU,
            dropout_rate=None,
            activate_final=False,
        )
Example #9
 def forward(self, observation: torch.FloatTensor, multivariate=False):
     # TODO: get this from hw1
     if self.discrete:
         tmp = self.logits_na(observation)
         actions_probability = F.softmax(tmp, dim=-1)
         # inspect gradient of discrete actions
         if torch.sum(torch.isnan(actions_probability)).item() > 0:
             params = self.logits_na.parameters()
             print(actions_probability)
             for param in params:
                 print(param)
             print(actions_probability)
             assert False, "Gradient Error"
         actions_distribution = distributions.categorical.Categorical(
             actions_probability)
         return actions_distribution
     else:
         if not multivariate:
             loc = self.mean_net(observation)
             scale = torch.exp(self.logstd)
             actions_distribution = distributions.normal.Normal(loc, scale)
             return actions_distribution
         else:
             # Multivariate Gaussian distribution version
             batch_mean = self.mean_net(observation)
             batch_scale_tril = torch.diag(torch.exp(self.logstd))
             # batch_dim = batch_mean.shape[0]
             # batch_scale_tril = scale_tril.repeat(batch_dim, 1, 1)
             actions_distribution = distributions.MultivariateNormal(
                 batch_mean,
                 scale_tril=batch_scale_tril,
             )
             return actions_distribution
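For a diagonal covariance like the one built from `torch.exp(self.logstd)`, the two continuous branches above agree: summing the per-dimension `Normal` log-probabilities equals the `MultivariateNormal` log-probability with the matching diagonal `scale_tril`. A quick standalone check with illustrative values:

import torch
from torch import distributions

batch, act_dim = 5, 3
mean = torch.randn(batch, act_dim)
scale = torch.exp(torch.randn(act_dim))
actions = torch.randn(batch, act_dim)

lp_indep = distributions.Normal(mean, scale).log_prob(actions).sum(dim=-1)
lp_mvn = distributions.MultivariateNormal(mean, scale_tril=torch.diag(scale)).log_prob(actions)

print(torch.allclose(lp_indep, lp_mvn, atol=1e-5))  # True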
Example #10
    def test_log_normal_spherical(self):
        """
        Test the log-normal probabilities for spherical covariance.
        """
        N = 100
        S = 50
        D = 10

        means = torch.randn(S, D)
        covs = torch.rand(S)
        x = torch.randn(N, D)

        distributions = [
            dist.MultivariateNormal(means[i],
                                    torch.diag(covs[i].clone().expand(D)))
            for i in range(S)
        ]

        expected = []
        for item in x:
            e_item = []
            for d in distributions:
                e_item.append(d.log_prob(item))
            expected.append(e_item)
        expected = torch.as_tensor(expected)

        predicted = log_normal(x, means, covs, 'spherical')

        self.assertTrue(
            torch.allclose(expected, predicted, atol=1e-03, rtol=1e-05))
Example #11
    def forward(self, h, Q, u):
        batch_size = h.size()[0]
        v, r = self.trans(h).chunk(2, dim=1)
        v1 = v.unsqueeze(2)
        rT = r.unsqueeze(1)
        I = torch.eye(self.dim_z).repeat(batch_size, 1, 1)
        if rT.is_cuda:
            I = I.to(rT.device)
        A = I.add(v1.bmm(rT))

        B = self.fc_B(h).view(-1, self.dim_z, self.dim_u)
        o = self.fc_o(h)

        # need to compute the parameters for distributions
        # as well as for the samples
        u = u.unsqueeze(2)

        d = A.bmm(Q.mean.unsqueeze(2)).add(B.bmm(u)).add(
            o.unsqueeze(2)).squeeze(2)
        sample = A.bmm(h.unsqueeze(2)).add(B.bmm(u)).add(
            o.unsqueeze(2)).squeeze(2)

        z_cov = Q.covariance_matrix
        z_next_cov = A.bmm(z_cov).bmm(A.transpose(1, 2))
        # return sample, NormalDistribution(d, Q.sigma, Q.logsigma, v=v, r=r)
        return sample, distributions.MultivariateNormal(d, z_next_cov)
Example #12
def sample_obj(lib, model, tau, seq_to_x, X_all, observed=[],
               its=1000, n=100, return_all=False):
    num_greater = torch.zeros(its)
    lib = helpers.seqs_from_set(lib, 4)
    unseen_lib = np.array(sorted(set(lib) - set(observed)))
    if len(unseen_lib) > 1:
        X_test = torch.tensor(X_all[[seq_to_x[s] for s in unseen_lib]]).float()
        mu, K = model(X_test.double())

        for i in range(its):
            rand_inds = np.random.choice(len(lib), n, replace=True)
            rand_inds = np.unique(rand_inds)
            rand_inds = rand_inds[np.where(rand_inds < len(unseen_lib))]
            if len(rand_inds) < 1:
                continue
            else:
                mu_chosen = mu[rand_inds].squeeze()
                K_chosen = K[rand_inds][:, rand_inds]
                if len(rand_inds) == 1:
                    sample = dist.Normal(mu_chosen,
                                         torch.sqrt(K_chosen.squeeze())).sample()
                else:
                    sample = dist.MultivariateNormal(mu_chosen, K_chosen).sample()
            num_greater[i] = torch.sum(sample > tau)
    elif len(unseen_lib) == 1:
        X_test = torch.tensor(X_all[[seq_to_x[s] for s in unseen_lib]]).float()
        mu, var = model(X_test.double())
        std = torch.sqrt(var).detach()
        mu = mu.detach()
        num_greater = torch.ones(its) * (1 - dist.Normal(mu, std).cdf(tau))
    if return_all:
        return -torch.mean(num_greater), num_greater
    else:
        return -torch.mean(num_greater)
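Example #13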
 def get_action(self, action_prob, context='train'):
     if self.action_space_type == 'discrete':
         if context == 'train':
             dist = distributions.Categorical(action_prob)
             action = dist.sample()
             log_prob_action = dist.log_prob(action)
             return action, log_prob_action  # Both have to be tensors, or the processing gets weird later on.
         else:
             action = torch.argmax(action_prob).item()
             return action
     elif self.action_space_type == 'continuous':
         if self.family == 'stochastic':
             action_mean_vector = action_prob * self.action_upper_limit  # Scaling up to ensure that the value works.
             # Covariance matrix is defined in init. Does not change.
             dist = distributions.MultivariateNormal(
                 action_mean_vector, self.covariance_matrix)
             action = dist.sample()
             # Clip action to ensure no erroneous action taken by accident.
             action = torch.clamp(action,
                                  min=self.action_lower_limit,
                                  max=self.action_upper_limit)
             log_prob_action = dist.log_prob(action)
             if context == 'train':
                 return action, log_prob_action
             else:
                 return action.detach().numpy()
         elif self.family == 'deterministic':
             # Action Prob here is actually just action itself. No probability.
             # Logic adapted from TD3 repo actors. Scaled up to max_action_possible, since tanh before.
             action = action_prob * self.action_upper_limit
             if context == 'train':
                 return action, None
             else:
                 return action.detach().numpy().reshape(self.dim[-1], )
Example #14
def get_chi_star(args, Cstar, Qstar, eps=0.01):
    ww = args.SigmaW**2
    bb = args.SigmaB**2
    mu = args.Mu
    n = args.num_samples  ## num samples for covariance measurement
    N = args.output_size
    M = np.zeros(N)
    mx = min(Cstar + eps, 1.0)
    mn = max(Cstar - eps, 0.0)
    Cv = torch.linspace(mn, mx, N)
    for i in range(N):
        c = Cv[i]
        Qu = ww * Qstar + bb
        Cu = (ww * Qstar * c + bb) / Qu
        ##print("Qu = %.02f, Cu = %.02f" % (Qu,Cu))
        U = tdist.Normal(torch.tensor([mu]), torch.tensor([np.sqrt(Qu)]))
        u = args.act(U.sample([n]))
        Eu = torch.mean(u)

        Uc = tdist.MultivariateNormal(torch.tensor([0.0, 0.0]),
                                      torch.eye(2)).sample([n])
        u1 = Uc[:, 0]
        u2 = Uc[:, 1]
        ch1 = np.sqrt(Qu) * u1 + mu
        ch2 = np.sqrt(Qu) * (Cu * u1 + np.sqrt(1.0 - Cu * Cu) * u2) + mu
        tmp = torch.mean(args.act(ch1) * args.act(ch2)) - Eu**2
        mc = tmp / Qstar
        if (mc >= 1.0):
            mc = 1.0
        M[i] = mc

    slope, intercept, r_value, p_value, std_err = stats.linregress(Cv, M)

    return slope
Example #15
def get_ff_star(args):
    ww = args.SigmaW**2
    bb = args.SigmaB**2
    mu = args.Mu
    n = args.num_samples  ## num samples for covariance measurement
    N = args.output_size

    Qstar = get_ff_qstar(args)
    c = 0.5
    for i in range(N):
        Qu = ww * Qstar + bb
        Cu = (ww * Qstar * c + bb) / Qu
        ##print("Qu = %.02f, Cu = %.02f" % (Qu,Cu))
        U = tdist.Normal(torch.tensor([mu]), torch.tensor([np.sqrt(Qu)]))
        u = args.act(U.sample([n]))
        Eu = torch.mean(u)

        Uc = tdist.MultivariateNormal(torch.tensor([0.0, 0.0]),
                                      torch.eye(2)).sample([n])
        u1 = Uc[:, 0]
        u2 = Uc[:, 1]
        ch1 = np.sqrt(Qu) * u1 + mu
        ch2 = np.sqrt(Qu) * (Cu * u1 + np.sqrt(1.0 - Cu * Cu) * u2) + mu
        tmp = torch.mean(args.act(ch1) * args.act(ch2)) - Eu**2
        c = tmp / Qstar
        if (c >= 1.0):
            return 1.0, Qstar
        elif (c <= 0.0):
            return 0.0, Qstar
    return c, Qstar
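The `ch1`/`ch2` construction in the last two examples is the standard trick for turning two independent standard normals into a pair with target correlation `Cu`: keep u1, and mix u2 as Cu*u1 + sqrt(1 - Cu^2)*u2. A minimal sketch checking the empirical correlation (illustrative values, independent of `args`):

import torch
import torch.distributions as tdist

n, Cu = 200_000, 0.7
Uc = tdist.MultivariateNormal(torch.zeros(2), torch.eye(2)).sample([n])
u1, u2 = Uc[:, 0], Uc[:, 1]

x1 = u1
x2 = Cu * u1 + (1.0 - Cu * Cu) ** 0.5 * u2

corr = torch.corrcoef(torch.stack([x1, x2]))[0, 1]
print(corr)  # close to 0.7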
Example #16
    def __init__(self, k, b, m, dt, x0, T, num_examples):
        super(SyntheticDataset, self).__init__()
        # x0: (4, ) tensor
        # state = (2D velocity, 2D position)
        self.A = torch.tensor([[-b / m, 0., -k / m, 0.],
                               [0., -b / m, 0., -k / m], [1., 0., 0., 0.],
                               [0., 1., 0., 0.]])
        self.A = self.A * dt + torch.eye(4)
        self.A.requires_grad = False
        # measurement = 2D position
        self.C = torch.tensor([[0., 0., 1., 0.], [0., 0., 0., 1.]])
        self.C.requires_grad = False
        # dynamics noise: IID zero mean Gaussian (only applied to velocity)
        self.Bw = torch.tensor([[1., 0.], [0., 1.], [0., 0.], [0., 0.]])
        self.Bw.requires_grad = False
        self.Q = torch.eye(2)
        self.mvnormal_process = tdist.MultivariateNormal(
            torch.zeros(2), self.Q)
        self.Q.requires_grad = False
        self.L = torch.tensor([1.1, 0.15, 1.1])
        self.L.requires_grad = False
        L = torch.tensor([[torch.exp(self.L[0]), 0.0],
                          [self.L[1], torch.exp(self.L[2])]])
        self.R = L @ L.t()
        self.R.requires_grad = False
        self.mvnormal_measurement = tdist.MultivariateNormal(
            torch.zeros(2), self.R)
        self.T = T
        self.num_examples = num_examples

        # Data generation
        self.data = []
        for ii in range(self.num_examples):
            xs = [x0]
            zs = []
            ys = []
            for t in range(self.T):
                x_output = self.A @ xs[
                    -1] + self.Bw @ self.mvnormal_process.sample()
                xs.append(x_output)
                z_output = self.C @ x_output + self.mvnormal_measurement.sample(
                )
                zs.append(z_output)
                y_output = self.C @ x_output
                ys.append(y_output)
            xs.pop(0)
            self.data.append((torch.stack(zs, 0), torch.stack(ys, 0)))
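The rollout above is a discretized linear-Gaussian state-space model: x_{t+1} = A x_t + B_w w_t with w_t ~ N(0, Q), and measurements z_t = C x_t + v_t with v_t ~ N(0, R). A stripped-down sketch of a single step using the same MultivariateNormal noise pattern (toy matrices, independent of the class above):

import torch
import torch.distributions as tdist

A = torch.eye(4) + 0.01 * torch.randn(4, 4)  # toy discrete-time dynamics matrix
C = torch.tensor([[0., 0., 1., 0.], [0., 0., 0., 1.]])
Bw = torch.tensor([[1., 0.], [0., 1.], [0., 0.], [0., 0.]])

process_noise = tdist.MultivariateNormal(torch.zeros(2), torch.eye(2))
measurement_noise = tdist.MultivariateNormal(torch.zeros(2), 0.1 * torch.eye(2))

x = torch.zeros(4)                            # initial state
x_next = A @ x + Bw @ process_noise.sample()  # dynamics step; noise enters the velocity only
z = C @ x_next + measurement_noise.sample()   # noisy 2D position measurement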
Example #17
File: trainer.py Project: riddhishb/dpVAE
    def __init__(self, config):
        super(Trainer, self).__init__()
        self.use_cuda = torch.cuda.is_available()
        self.device = 'cuda' if self.use_cuda else 'cpu'
        # self.device ='cuda:1'

        # model
        self.modef = config['model']
        self.model = get_model(config)
        self.input_dims = config['input_dims']
        self.z_dims = config['z_dims']
        self.prior = distributions.MultivariateNormal(torch.zeros(self.z_dims),
                                                      torch.eye(self.z_dims))

        # train
        self.max_iter = config['max_iter']
        self.global_iter = 1
        self.mseWeight = config['mse_weight']
        self.lr = config['lr']
        self.beta1 = config['beta1']
        self.beta2 = config['beta2']
        self.optim = optim.Adam(self.model.parameters(),
                                lr=self.lr,
                                betas=(self.beta1, self.beta2))
        self.implicit = 'implicit' in config and config['implicit']
        if self.implicit:
            self.train_inst = self.implicit_inst

        # saving
        self.ckpt_dir = config['ckpt_dir']
        os.makedirs(self.ckpt_dir, exist_ok=True)
        self.ckpt_name = config['ckpt_name']
        self.save_output = config['save_output']
        self.output_dir = config['output_dir']
        os.makedirs(self.output_dir, exist_ok=True)
        # saving
        if config['cont'] and self.ckpt_name is not None:
            self.load_checkpoint(self.ckpt_name)

        self.meta = defaultdict(list)

        self.gather_step = config['gather_step']
        self.display_step = config['display_step']
        self.save_step = config['save_step']

        # data
        self.dset_dir = config['dset_dir']
        self.dataset = config['dataset']
        self.data_type = config['data_type']
        if self.data_type == 'linear':
            self.draw_reconstruction = self.linear_reconstruction
            self.draw_generated = self.linear_generated
            self.visualize_traverse = self.linear_traverse
            self.traversal = self.linear_traversal
        self.batch_size = config['batch_size']
        self.img_size = 32 if 'image_size' not in config else config[
            'image_size']
        self.data_loader = get_dataset(config)
        self.val_loader = get_dataset(config, train=False)
Example #18
    def decoder(self,
                z,
                encoded_history,
                current_state,
                y_e=None,
                train=False):
        bs = encoded_history.shape[0]
        a_0 = F.dropout(self.action(current_state.reshape(bs, -1)),
                        self.dropout_p)
        state = F.dropout(self.state(encoded_history.reshape(bs, -1)),
                          self.dropout_p)

        current_state = current_state.unsqueeze(1)
        gauses = []
        inp = F.dropout(
            torch.cat((encoded_history.reshape(bs, -1), a_0), dim=-1),
            self.dropout_p)
        for i in range(12):
            h_state = self.gru(inp.reshape(bs, -1), state)

            _, deltas, log_sigmas, corrs = self.project_to_GMM_params(h_state)
            deltas = torch.clamp(deltas, max=1.5, min=-1.5)
            deltas = deltas.reshape(bs, -1, 2)
            log_sigmas = log_sigmas.reshape(bs, -1, 2)
            corrs = corrs.reshape(bs, -1, 1)

            mus = deltas + current_state
            current_state = mus
            variance = torch.clamp(torch.exp(log_sigmas).unsqueeze(2)**2,
                                   max=1e3)

            m_diag = variance * torch.eye(2).to(variance.device)
            sigma_xy = torch.clamp(torch.prod(torch.exp(log_sigmas), dim=-1),
                                   min=1e-8,
                                   max=1e3)

            if train:
                # log_pis = z.reshape(bs, 1) * torch.ones(bs, self.num_modes).cuda()
                log_pis = to_one_hot(z, n_dims=self.num_modes).cuda()

            else:
                log_pis = to_one_hot(z, n_dims=self.num_modes).cuda()
            log_pis = log_pis - torch.logsumexp(log_pis, dim=-1, keepdim=True)
            mix = D.Categorical(logits=log_pis)
            comp = D.MultivariateNormal(mus, m_diag)
            gmm = D.MixtureSameFamily(mix, comp)
            t = (sigma_xy * corrs.squeeze()).reshape(-1, 1, 1)
            cov_matrix = m_diag  # + anti_diag
            gauses.append(gmm)
            a_t = gmm.sample()  # possible grad problems?
            a_tt = F.dropout(self.action(a_t.reshape(bs, -1)), self.dropout_p)
            state = h_state
            inp = F.dropout(
                torch.cat((encoded_history.reshape(bs, -1), a_tt), dim=-1),
                self.dropout_p)

        return gauses
Example #19
 def gaussians(self):
     gaussians = [
         dist.MultivariateNormal(
             mean,
             F.softplus(inv_std)**2 * torch.eye(self.d).to(self.device))
         for mean, inv_std in zip(self.means, self.inv_cov_stds)
     ]
     return gaussians
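Example #20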
    def sample(self, n) -> Iterable[Tuple[Individual, t.Tensor]]:
        dist = d.MultivariateNormal(loc=self.means, scale_tril=t.exp(self.log_stds).tril())
        with t.no_grad():
            samples = dist.sample((n,))
        log_probs = dist.log_prob(samples)
        individuals = [self.constructor(self._to_shapes(s)) for s in samples]

        return zip(individuals, log_probs.unbind(0))
Example #21
 def log_prob(self, x):
     z, sum_log_jacobians = self.forward(x)
     log_prob_z = tdist.MultivariateNormal(
         self.mean.repeat(z.size(1), 1, 1).permute(1, 0, 2),
         self.cov.repeat(z.size(1), 1, 1, 1).permute(1, 0, 2, 3)
     ).log_prob(z)
     log_prob_x = log_prob_z + sum_log_jacobians  # [batch_size]
     return log_prob_x
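Example #22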
 def sample(self, n) -> Iterable[Tuple[Individual, t.Tensor]]:
     for _ in range(n):
         dist = d.MultivariateNormal(loc=self.means,
                                     scale_tril=t.exp(self.log_stds).tril())
         with t.no_grad():
             sample = dist.sample()
         log_prob = dist.log_prob(sample)
         yield self.constructor(self._to_shapes(sample)), log_prob
Example #23
 def cartesian_conditional_distribution(self, xti, T=1, reverse=False):
     _mu = xti
     _var = torch.zeros(xti.shape[0], self.dim, self.dim).to(xti)
     if not reverse:
         _mu = self.evolve(_mu, T=T)
         var = self.var*self.dt*T
         _var = var
     return tdist.MultivariateNormal(_mu, _var)
Example #24
 def to(self, *args, **kwargs):
     """Handles non-parameter tensors when moved to a new device."""
     self = super().to(*args, **kwargs)
     self._base_dist = D.MultivariateNormal(
         loc=self._base_dist.mean.to(*args, **kwargs),
         scale_tril=self._base_dist.scale_tril.to(*args, **kwargs),
     )
     return self
Example #25
def x_clf_cross_loss(mu1, sig1, w1, mu2, sig2, w2, z, preds=None):
    f1 = dist.MultivariateNormal(mu1, sig1)
    f2 = dist.MultivariateNormal(mu2, sig2)

    p_z_m1 = f1.log_prob(z) + tr.log(w1)
    p_z_m2 = f2.log_prob(z) + tr.log(w2)

    p_z = log_prob_sum(p_z_m1, p_z_m2)

    if preds is None:
        loss = -tr.max(p_z_m1, p_z_m2) + p_z
    else:
        preds = tr.tensor(preds, dtype=tr.int32).cuda()
        loss = -tr.where(preds == 0, p_z_m1, p_z_m2) + p_z

    loss = loss.sum()
    return loss
Example #26
def ppo_update(config, f_actor, diff_actor_opt, critic, critic_opt, memory_cache, update_type='meta'):
    # Actor is functional in meta, and normal in rl.
    summed_policy_loss = torch.zeros(1)
    summed_value_loss = torch.zeros(1)

    states, next_states, actions_init, rewards, dones, log_prob_actions_init = get_shaped_memory_sample(config, memory_cache)
    # Using critic to predict last reward. Just as a placeholder in case the trajectory is incomplete in the batch-mode.
    final_predicted_reward = 0.
    if dones[-1] == 0.:  # Then last step is not done. Last value has to be predicted.
        final_state = next_states[-1]
        with torch.no_grad():
            final_predicted_reward = critic(final_state).detach().item()
    returns = calculate_returns(config, rewards, dones, predicted_end_reward=final_predicted_reward) #Returns(samples,1)
    # At this point, they should always be tensors and output a tensor based solution.
    values_init = critic(states)
    advantages = returns - values_init
    if config.normalize_rewards_and_advantages:
        advantages = (advantages - advantages.mean()) / advantages.std()
    advantages = advantages.detach()  # Necessary to keep the advantages from have a connection to the value model.
    # Now the actor makes steps and recalculates actions and log_probs based on the current values for k epochs.

    for ppo_step in range(config.num_ppo_steps):
        action_prob = f_actor(states)
        # print('action_prob', type(action_prob), action_prob.shape, action_prob)
        values_pred = critic(states)
        if config.env_config.action_space_type == 'discrete':
            dist = distributions.Categorical(action_prob) ## Stupido
            actions_init = actions_init.squeeze(-1)
            new_log_prob_actions = dist.log_prob(actions_init)
            new_log_prob_actions = new_log_prob_actions.view(-1, 1)
        elif config.env_config.action_space_type == 'continuous':
            action_mean_vector = action_prob * f_actor.action_upper_limit  # Direct code from actor get_action, refer there
            dist = distributions.MultivariateNormal(action_mean_vector, f_actor.covariance_matrix)
            actions_init = actions_init.view(-1, config.action_dim)
            new_log_prob_actions = dist.log_prob(actions_init)
            new_log_prob_actions = new_log_prob_actions.view(-1, 1)

        policy_ratio = (new_log_prob_actions - log_prob_actions_init).exp()
        policy_loss_1 = policy_ratio * advantages
        policy_loss_2 = torch.clamp(policy_ratio, min=1.0 - config.ppo_clip, max=1.0 + config.ppo_clip) * advantages
        if config.include_entropy_in_ppo:
            inner_policy_loss = (
                        -torch.min(policy_loss_1, policy_loss_2) - config.entropy_coefficient * dist.entropy()).sum()
        else:
            inner_policy_loss = -torch.min(policy_loss_1, policy_loss_2).sum()
        if update_type == 'meta':
            diff_actor_opt.step(inner_policy_loss)
        else:
            # In this case, it's normal RL, and so there is no updating that happens outside in the main function.
            diff_actor_opt.zero_grad()
            inner_policy_loss.backward()
            diff_actor_opt.step()
        inner_value_loss = F.smooth_l1_loss(values_pred, returns).sum()
        inner_value_loss.backward()
        critic_opt.step()
        summed_policy_loss += inner_policy_loss
        summed_value_loss += inner_value_loss
    return summed_policy_loss, summed_value_loss.item()
Example #27
File: losses.py Project: indyfree/CARLA
def csvae_loss(csvae, x_train, y_train):
    x = x_train.clone()
    x = x.float()
    y = y_train.clone()
    y = y.float()

    (
        x_mu,
        x_logvar,
        zw,
        y_pred,
        w_mu_encoder,
        w_logvar_encoder,
        w_mu_prior,
        w_logvar_prior,
        z_mu,
        z_logvar,
    ) = csvae.forward(x, y)

    x_recon = nn.MSELoss()(x_mu, x)

    w_dist = dists.MultivariateNormal(
        w_mu_encoder.flatten(), torch.diag(w_logvar_encoder.flatten().exp()))
    w_prior = dists.MultivariateNormal(
        w_mu_prior.flatten(), torch.diag(w_logvar_prior.flatten().exp()))
    w_kl = dists.kl.kl_divergence(w_dist, w_prior)

    z_dist = dists.MultivariateNormal(z_mu.flatten(),
                                      torch.diag(z_logvar.flatten().exp()))
    z_prior = dists.MultivariateNormal(
        torch.zeros(csvae.z_dim * z_mu.size()[0]),
        torch.eye(csvae.z_dim * z_mu.size()[0]),
    )
    z_kl = dists.kl.kl_divergence(z_dist, z_prior)

    y_pred_negentropy = (y_pred.log() * y_pred + (1 - y_pred).log() *
                         (1 - y_pred)).mean()

    class_label = torch.argmax(y, dim=1)
    y_recon = (100.0 * torch.where(class_label == 1, -torch.log(y_pred[:, 1]),
                                   -torch.log(y_pred[:, 0]))).mean()

    ELBO = 40 * x_recon + 0.2 * z_kl + 1 * w_kl + 110 * y_pred_negentropy

    return ELBO, x_recon, w_kl, z_kl, y_pred_negentropy, y_recon
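Since both arguments are MultivariateNormal objects, `dists.kl.kl_divergence` uses the Gaussian closed form; for the diagonal covariances built with `torch.diag(... .exp())` against a standard-normal prior it reduces to the familiar per-dimension VAE KL term. A small standalone sketch checking that equivalence (toy sizes):

import torch
import torch.distributions as dists

d = 6
mu = torch.randn(d)
logvar = torch.randn(d)

q = dists.MultivariateNormal(mu, torch.diag(logvar.exp()))
p = dists.MultivariateNormal(torch.zeros(d), torch.eye(d))

kl_closed = dists.kl.kl_divergence(q, p)
kl_manual = 0.5 * (logvar.exp() + mu**2 - 1.0 - logvar).sum()  # diagonal-Gaussian KL vs N(0, I)

print(torch.allclose(kl_closed, kl_manual, atol=1e-5))  # True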
Example #28
 def sample(self, batch_size):
     if self._diag_type == "cholesky":
         return dists.MultivariateNormal(self.loc, self.scale).rsample(
             (batch_size, ))
     elif self._diag_type == 'diag':
         return dists.Normal(self.loc, self.std).rsample((batch_size, ))
     else:
         raise NotImplementedError(
             "_diag_type can only be cholesky or diag")
Example #29
 def log_prob(self, value):
     if self._diag_type == "cholesky":
         return dists.MultivariateNormal(self.loc,
                                         self.scale).log_prob(value)
     elif self._diag_type == 'diag':
         return dists.Normal(self.loc, self.std).log_prob(value).sum(dim=-1)
     else:
         raise NotImplementedError(
             "_diag_type can only be cholesky or diag")
Example #30
    def _sample(self, num_samples):

        weights = self.module.soft_max(self.module.soft_weights)
        idx = dist.Categorical(probs=weights).sample([num_samples])
        X = dist.MultivariateNormal(loc=self.module.means,
                                    scale_tril=self.module.L).sample(
                                        [num_samples])

        return X[torch.arange(num_samples, device=self.device), idx, :]
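The Categorical-then-index pattern above is one way to draw from a Gaussian mixture; `torch.distributions.MixtureSameFamily` packages the same two ingredients into a single distribution object, which is convenient when a mixture `log_prob` is also needed. A brief sketch under the same component parameterization (toy values only):

import torch
import torch.distributions as dist

k, d, num_samples = 3, 2, 10
weights = torch.softmax(torch.randn(k), dim=0)
means = torch.randn(k, d)
L = torch.diag_embed(torch.rand(k, d) + 0.1)  # one lower-triangular scale factor per component

gmm = dist.MixtureSameFamily(
    dist.Categorical(probs=weights),
    dist.MultivariateNormal(loc=means, scale_tril=L),
)

X = gmm.sample([num_samples])  # shape (num_samples, d)
log_p = gmm.log_prob(X)        # mixture log-density per sample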