    def get_dist(self):
        n = len(self.mean)
        mix = D.Categorical(torch.ones(n, ))
        comp = D.Independent(D.Normal(self.mean, self.var * torch.ones(n, 2)),
                             1)

        return D.MixtureSameFamily(mix, comp)
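A minimal standalone sketch of the same construction, assuming `self.mean` has shape `(n, 2)` and `self.var` is a scalar; the concrete values below are illustrative only:

import torch
import torch.distributions as D

mean = torch.randn(5, 2)                 # (n, 2) component means
var = torch.tensor(0.1)                  # shared scalar scale
mix = D.Categorical(torch.ones(5))       # equal mixture weights
comp = D.Independent(D.Normal(mean, var * torch.ones(5, 2)), 1)
gmm = D.MixtureSameFamily(mix, comp)

samples = gmm.sample((100,))             # shape (100, 2)
log_p = gmm.log_prob(samples)            # shape (100,)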
Example #2
    def log_prob(self, locations_3d, x_offset_3d, y_offset_3d, z_offset_3d,
                 intensities_3d):
        xyzi, counts, s_mask = get_true_labels(locations_3d, x_offset_3d,
                                               y_offset_3d, z_offset_3d,
                                               intensities_3d)
        x_mu, y_mu, z_mu, i_mu = (i.unsqueeze(1)
                                  for i in torch.unbind(self.xyzi_mu, dim=1))
        x_si, y_si, z_si, i_si = (
            i.unsqueeze(1) for i in torch.unbind(self.xyzi_sigma, dim=1))

        P = torch.sigmoid(self.logits) + 0.00001
        count_mean = P.sum(dim=[2, 3, 4]).squeeze(-1)
        count_var = (P - P**2).sum(dim=[2, 3, 4]).squeeze(
            -1)  # avoid the situation where we have a perfect match
        count_dist = D.Normal(count_mean, torch.sqrt(count_var))
        count_prob = count_dist.log_prob(counts)
        mixture_probs = P / P.sum(dim=[1, 2, 3], keepdim=True)

        xyz_mu_list, _, _, i_mu_list, x_sigma_list, y_sigma_list, z_sigma_list, i_sigma_list, mixture_probs_l = img_to_coord(
            P, x_mu, y_mu, z_mu, i_mu, x_si, y_si, z_si, i_si, mixture_probs)
        xyzi_mu = torch.cat((xyz_mu_list, i_mu_list), dim=-1)
        xyzi_sigma = torch.cat(
            (x_sigma_list, y_sigma_list, z_sigma_list, i_sigma_list),
            dim=-1)  # to avoid NaN
        mix = D.Categorical(mixture_probs_l.squeeze(-1))
        comp = D.Independent(D.Normal(xyzi_mu, xyzi_sigma), 1)
        spatial_gmm = D.MixtureSameFamily(mix, comp)
        spatial_prob = spatial_gmm.log_prob(xyzi.transpose(0,
                                                           1)).transpose(0, 1)
        spatial_prob = (spatial_prob * s_mask).sum(-1)
        log_prob = count_prob + spatial_prob
        return log_prob
Example #3
def gaussian_mixture_sampler(num_latent,
                             num_mixtures=4,
                             weights=None,
                             means=None,
                             cov=None):
    """

    :param num_latent:
    :param num_mixtures:
    :param weights:
    :param means:
    :param cov:
    :return:
    """

    if weights is None:
        weights = torch.randn(num_latent, num_mixtures).softmax(dim=1)

    if means is None:
        means = torch.randn(num_latent, num_mixtures, 1) * 2

    if cov is None:
        # D.Normal expects a positive scale, so take the absolute value
        cov = torch.randn(num_latent, num_mixtures, 1).abs()

    mix = dist.Categorical(weights)
    comp = dist.Independent(dist.Normal(means, cov), 1)

    gmm = dist.MixtureSameFamily(mix, comp)

    return lambda n: gmm.sample((n, )).squeeze()
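A quick usage sketch for the returned sampler (the argument values are illustrative):

sampler = gaussian_mixture_sampler(num_latent=8, num_mixtures=4)
z = sampler(128)  # gmm.sample((128,)) has shape (128, 8, 1); squeeze() -> (128, 8)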
Example #4
File: dim_model.py  Project: Czworldy/GP
    def _goal_likelihood(self, y: torch.Tensor, goal: torch.Tensor,
                         **hyperparams) -> torch.Tensor:
        """Returns the goal-likelihood of a plan `y`, given `goal`.
        Args:
          y: A plan under evaluation, with shape `[B, T, 2]`.
          goal: The goal locations, with shape `[B, K, 2]`.
          hyperparams: (keyword arguments) The goal-likelihood hyperparameters.

        Returns:
          The log-likelihood of the plan `y` under the `goal` distribution.
        """
        # Parses tensor dimensions.
        B, K, _ = goal.shape

        # Fetches goal-likelihood hyperparameters.
        epsilon = hyperparams.get("epsilon", 1.0)

        # TODO(filangel): implement other goal likelihoods from the DIM paper
        # Initializes the goal distribution.
        goal_distribution = D.MixtureSameFamily(
            mixture_distribution=D.Categorical(
                probs=torch.ones((B, K)).to(goal.device)),
            component_distribution=D.Independent(
                D.Normal(loc=goal, scale=torch.ones_like(goal) * epsilon),
                reinterpreted_batch_ndims=1,
            ))

        return torch.mean(goal_distribution.log_prob(y[:, -1, :]), dim=0)
Example #5
    def decoder(self,
                z,
                encoded_history,
                current_state,
                y_e=None,
                train=False):

        bs = encoded_history.shape[0]
        a_0 = F.dropout(self.action(current_state.reshape(bs, -1)),
                        self.dropout_p)
        state = F.dropout(self.state(encoded_history.reshape(bs, -1)),
                          self.dropout_p)

        current_state = current_state.unsqueeze(1)
        gauses = []
        inp = F.dropout(
            torch.cat((encoded_history.reshape(bs, -1), a_0), dim=-1),
            self.dropout_p)
        for i in range(12):
            h_state = self.gru(inp.reshape(bs, -1), state)

            _, deltas, log_sigmas, corrs = self.project_to_GMM_params(h_state)
            deltas = torch.clamp(deltas, max=1.5, min=-1.5)
            deltas = deltas.reshape(bs, -1, 2)
            log_sigmas = log_sigmas.reshape(bs, -1, 2)
            corrs = corrs.reshape(bs, -1, 1)

            mus = deltas + current_state
            current_state = mus
            variance = torch.clamp(torch.exp(log_sigmas).unsqueeze(2)**2,
                                   max=1e3)

            m_diag = variance * torch.eye(2).to(variance.device)
            sigma_xy = torch.clamp(torch.prod(torch.exp(log_sigmas), dim=-1),
                                   min=1e-8,
                                   max=1e3)

            if train:
                # log_pis = z.reshape(bs, 1) * torch.ones(bs, self.num_modes).cuda()
                log_pis = to_one_hot(z, n_dims=self.num_modes).cuda()

            else:
                log_pis = to_one_hot(z, n_dims=self.num_modes).cuda()
            log_pis = log_pis - torch.logsumexp(log_pis, dim=-1, keepdim=True)
            mix = D.Categorical(logits=log_pis)
            comp = D.MultivariateNormal(mus, m_diag)
            gmm = D.MixtureSameFamily(mix, comp)
            t = (sigma_xy * corrs.squeeze()).reshape(-1, 1, 1)
            cov_matrix = m_diag  # + anti_diag
            gauses.append(gmm)
            a_t = gmm.sample()  # possible grad problems?
            a_tt = F.dropout(self.action(a_t.reshape(bs, -1)), self.dropout_p)
            state = h_state
            inp = F.dropout(
                torch.cat((encoded_history.reshape(bs, -1), a_tt), dim=-1),
                self.dropout_p)

        return gauses
Example #6
File: layers.py  Project: jw9730/setvae
    def forward(self, output_sizes, hold_seed=None, hold_initial_set=False):
        """
        Sample from prior
        :param output_sizes: Tensor([B,])
        :param hold_seed
        :param hold_initial_set
        :return: Tensor([B, N, D])
        """
        bsize = output_sizes.shape[0]
        if hold_initial_set:  # [B, N]
            x_mask = get_mask(output_sizes, self.max_outputs)
        else:
            x_mask = sample_mask(output_sizes, self.max_outputs)

        if hold_seed is not None:  # [B, N, Ds]
            torch.random.manual_seed(hold_seed)
            eps = torch.randn([1, self.max_outputs, self.dim_seed
                               ]).to(x_mask.device).repeat(bsize, 1, 1)
        else:
            eps = torch.randn([bsize, self.max_outputs,
                               self.dim_seed]).to(x_mask.device)

        if self.n_mixtures == 1:
            x = self.mu + torch.exp(self.logvar / 2.) * eps
        else:
            if self.train_gmm:
                if hold_seed is not None:
                    torch.random.manual_seed(hold_seed)
                    logits = self.logits.reshape([1, 1,
                                                  self.n_mixtures]).repeat(
                                                      1, self.max_outputs,
                                                      1)  # [1, N, M]
                    onehot = F.gumbel_softmax(
                        logits, tau=self.tau,
                        hard=True).repeat(bsize, 1,
                                          1).unsqueeze(-1)  # [B, N, M, 1]
                else:
                    logits = self.logits.reshape([1, 1,
                                                  self.n_mixtures]).repeat(
                                                      bsize, self.max_outputs,
                                                      1)  # [B, N, M]
                    onehot = F.gumbel_softmax(logits, tau=self.tau,
                                              hard=True).unsqueeze(
                                                  -1)  # [B, N, M, 1]
                mu = self.mu.reshape([1, 1, self.n_mixtures,
                                      self.dim_seed])  # [1, 1, M, D]
                sig = self.sig.reshape([1, 1, self.n_mixtures,
                                        self.dim_seed])  # [1, 1, M, D]
                mu = (mu * onehot).sum(2)  # [B, N, D]
                sig = (sig * onehot).sum(2)  # [B, N, D]
                x = mu + sig * eps
            else:
                mix = D.Categorical(self.logits)
                comp = D.Independent(D.Normal(self.mu, self.sig.abs()), 1)
                mixture = D.MixtureSameFamily(mix, comp)
                x = mixture.sample((output_sizes.size(0), self.max_outputs))

        x = self.output(x)  # [B, N, D]
        return x, x_mask
Example #7
File: synth.py  Project: sinead/NPL
    def __init__(self, is_test):
        super().__init__()
        # self.flip_var_order = flip_var_order
        # if is_test:
        #     self.pX = D.Uniform(torch.tensor([0.0]), torch.tensor([1.0]))
        # else:
        mix = D.Categorical(torch.ones(2, ))
        comp = D.Uniform(torch.tensor([0.0, 0.35]), torch.tensor([0.45, 1.0]))
        self.pX = D.MixtureSameFamily(mix, comp)
        self.pY1 = D.Uniform(torch.tensor([0.0]), torch.tensor([1.0]))
        self.pY2 = lambda X: D.Normal(torch.sin(10 * X), 0.05)
Example #8
    def _parameterize_distribution(self,
                                   hidden: torch.Tensor) -> D.Distribution:
        mixture_logits = self.mixture_linear(hidden)
        mixture = F.softmax(mixture_logits, dim=-1)

        means = self.means_linear(hidden)[..., None]

        stddev = F.softplus(self.stddev_linear(hidden))[..., None]

        c = D.Categorical(probs=mixture)
        n = D.Independent(D.Normal(means, stddev), 1)

        return D.MixtureSameFamily(c, n)
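A self-contained sketch of the same parameterization with hypothetical sizes (batch B, hidden size H, K components over a scalar target); the layer names mirror the method above but are constructed ad hoc here:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as D

B, H, K = 8, 32, 5                                       # hypothetical sizes
mixture_linear = nn.Linear(H, K)
means_linear = nn.Linear(H, K)
stddev_linear = nn.Linear(H, K)

hidden = torch.randn(B, H)
mixture = F.softmax(mixture_linear(hidden), dim=-1)      # (B, K)
means = means_linear(hidden)[..., None]                  # (B, K, 1)
stddev = F.softplus(stddev_linear(hidden))[..., None]    # (B, K, 1)

dist = D.MixtureSameFamily(
    D.Categorical(probs=mixture),
    D.Independent(D.Normal(means, stddev), 1),
)
# dist.batch_shape == (B,), dist.event_shape == (1,)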
Example #9
File: gaussianmix.py  Project: sinead/NPL
    def __init__(self, pX=None):
        super().__init__()

        #self.flip_var_order = flip_var_order
        #if is_test:
        #self.pX = D.Uniform(torch.tensor([0.0]), torch.tensor([1.0]))
        #else:
        if pX is None:
            mix = D.Categorical(torch.ones(3, ))
            comp = D.Independent(
                D.Normal(torch.randn(3, 2), 0.3 * torch.ones(3, 2)), 1)
            self.pX = D.MixtureSameFamily(mix, comp)
        else:
            self.pX = pX
Example #10
File: eval.py  Project: alexpostnikov/tped
    def __call__(self, scene: torch.Tensor, train=False):
        gmms = []
        for model in self.models:
            gmm = model(scene)
            gmms.append(gmm)

        combined_gmm = []
        for timestamp in range(12):
            logits = torch.cat([gmms[i][timestamp].mixture_distribution.logits for i in range(len(gmms))], dim=1)
            mus = torch.cat([gmms[i][timestamp].component_distribution.mean for i in range(len(gmms))], dim=1)
            variance = torch.cat([gmms[i][timestamp].component_distribution.variance for i in range(len(gmms))], dim=1)
            m_diag = variance.unsqueeze(2) * torch.eye(2).to(variance.device)
            mix = D.Categorical(logits=logits)  # these are logits, not probabilities
            comp = D.MultivariateNormal(mus, m_diag)
            combined_gmm.append(D.MixtureSameFamily(mix, comp))
        return combined_gmm
Example #11
    def dist(
        self,
        batch: dict[str, Union[torch.Tensor, list[torch.Tensor]]],
    ) -> distributions.Distribution:
        """Возвращает распределение доходности."""
        logits, mean, std = self(batch)

        try:
            weights_dist = distributions.Categorical(logits=logits)
        except ValueError:
            raise GradientsError(
                "Error while updating gradients: NaN in Categorical distribution"
            )

        comp_dist = distributions.LogNormal(mean, std)

        return distributions.MixtureSameFamily(weights_dist, comp_dist)
Example #12
File: mixture.py  Project: stnkl/probflow
    def __call__(self):
        """Get the distribution object from the backend"""
        if get_backend() == "pytorch":
            import torch
            import torch.distributions as tod

            # Convert to pytorch distributions if probflow distributions
            if isinstance(self.distributions, BaseDistribution):
                self.distributions = self.distributions()

            # Broadcast probs/logits
            shape = self.distributions.batch_shape
            args = {"logits": None, "probs": None}
            if self.logits is not None:
                args["logits"], _ = torch.broadcast_tensors(
                    self["logits"], torch.zeros(shape)
                )
            else:
                args["probs"], _ = torch.broadcast_tensors(
                    self["probs"], torch.zeros(shape)
                )

            # Return torch distribution object
            return tod.MixtureSameFamily(
                tod.Categorical(**args), self.distributions
            )
        else:
            import tensorflow as tf
            from tensorflow_probability import distributions as tfd

            # Convert to tensorflow distributions if probflow distributions
            if isinstance(self.distributions, BaseDistribution):
                self.distributions = self.distributions()

            # Broadcast probs/logits
            shape = self.distributions.batch_shape
            args = {"logits": None, "probs": None}
            if self.logits is not None:
                args["logits"] = tf.broadcast_to(self["logits"], shape)
            else:
                args["probs"] = tf.broadcast_to(self["probs"], shape)

            # Return TFP distribution object
            return tfd.MixtureSameFamily(
                tfd.Categorical(**args), self.distributions
            )
Example #13
def SampleGMM_detach(nsamps):
    global mu_pr1
    global var_pr1

    K = mu_pr1.shape[0]
    alpha_pr = torch.full((K, ), 1.0 / K)  # uniform mixture weights

    mix = distributions.Categorical(alpha_pr)
    comp = distributions.MultivariateNormal(mu_pr1.detach(), var_pr1.detach())
    gmm = distributions.MixtureSameFamily(mix, comp)
    sample = gmm.sample((nsamps, )).to(device)

    return sample
Example #14
    def predict():
        mdrnn.eval()

        preds = []
        gt = []
        n_episodes = test_dataset[-1][-2] + 1
        predictions = [[] for _ in range(n_episodes)]
        with torch.no_grad():
            for batch_index, (states, actions, next_states, rewards, episode, timesteps) in enumerate(test_loader):

                states = states.to(device)
                next_states = next_states.to(device)
                rewards = rewards.to(device)
                actions = actions.to(device)

                latent_obs, _ = to_latent(states, next_states,
                                          batch_size=1, sequence_horizon=1)

                # Check model's next state predictions
                mus, sigmas, logpi, _, _, _ = mdrnn(actions, latent_obs)
                mix = D.Categorical(logits=logpi)  # logpi holds log mixture probabilities
                comp = D.Independent(D.Normal(mus, sigmas), 1)
                gmm = D.MixtureSameFamily(mix, comp)
                sample = gmm.sample()

                decoded_states = vae.decoder(sample).squeeze(0)
                decoded_states = decoded_states.cpu().detach().numpy()
                preds.append(decoded_states)

                for i in range(len(states)):
                    predictions[episode[i].int()].append(np.expand_dims(decoded_states[i], axis=0))


                gt.append(next_states.cpu().detach().numpy())
            predictions = [np.stack(p) for p in predictions]
            preds = np.asarray(preds)
            gt = np.asarray(gt).squeeze(1)
            error = (preds - gt)**2

        path = cfg.logdir + '/' + cfg.resname + '.pkl'
        pickle.dump(predictions, open(path, 'wb'))

        print("Mean Error: {}".format(error.mean(0)[0]))
        print("Min  Error: {}".format(error.min(0)[0]))
        print("Max  Error: {}".format(error.max(0)[0]))
Example #15
    def sample(self, n) -> Iterable[Tuple[Individual, t.Tensor]]:
        samples = []

        components = d.Normal(loc=self.component_means, scale=self.std)
        for i in range(n):
            log_p = 0.0
            params = {}
            for k, logits in self.mixing_logits.items():
                mix = d.Categorical(logits=logits)
                expanded = components.expand(mix.batch_shape +
                                             components.batch_shape)
                gmm = d.MixtureSameFamily(
                    mix, d.Independent(expanded,
                                       self.component_means.ndim - 1))
                with t.no_grad():
                    sample = gmm.sample()
                params[k] = sample
                log_p += gmm.log_prob(sample).sum()

            samples.append((self.constructor(params), log_p))

        return samples
Example #16
File: gmm.py  Project: rfeinman/Sketch-RNN
def sample_gmm(mix_logp, means, scales, corrs):
    covs = compute_cov2d(scales, corrs)
    mix = D.Categorical(mix_logp.exp())
    comp = D.MultivariateNormal(means, covs)
    gmm = D.MixtureSameFamily(mix, comp)
    return gmm.sample()
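`compute_cov2d` comes from elsewhere in the project; a plausible reconstruction, assuming `scales` holds the two standard deviations and `corrs` the correlation coefficient, could look like this:

def compute_cov2d(scales, corrs):
    # scales: (..., 2) standard deviations, corrs: (..., 1) correlation in (-1, 1)  [assumed layout]
    sx, sy = scales[..., 0], scales[..., 1]
    rho = corrs[..., 0]
    cov = torch.zeros(scales.shape[:-1] + (2, 2),
                      dtype=scales.dtype, device=scales.device)
    cov[..., 0, 0] = sx ** 2
    cov[..., 1, 1] = sy ** 2
    cov[..., 0, 1] = rho * sx * sy
    cov[..., 1, 0] = rho * sx * sy
    return cov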
Example #17
    def decoder(self, z, encoded_history, current_state, train=False):

        bs = encoded_history.shape[0]
        a_0 = F.dropout(self.action(current_state.reshape(bs, -1)),
                        self.dropout_p)

        # state = self.bn3(F.dropout(self.state(encoded_history.reshape(bs, -1)), self.dropout_p))
        # state = self.ln3(F.dropout(self.state(encoded_history.reshape(bs, -1)), self.dropout_p))
        state = F.dropout(self.state(encoded_history.reshape(bs, -1)),
                          self.dropout_p)

        current_state = current_state.unsqueeze(1)
        gauses = []
        lp = to_one_hot(z, n_dims=self.num_modes).to(encoded_history.device)
        # lp = z.reshape(bs, 1) * torch.ones(bs, self.num_modes).cuda()
        # lp = z
        inp = F.dropout(
            torch.cat((encoded_history.reshape(bs, -1), a_0, 0 * lp), dim=-1),
            self.dropout_p)

        for i in range(12):
            # h_state = self.ln4(self.gru(inp.reshape(bs, -1), state))
            input = inp.reshape(bs, -1)
            # input = self.gru_prep(inp.reshape(bs, -1))
            h_state = self.gru(input, state)
            # h_state = self.bn4(self.gru(inp.reshape(bs, -1), state))

            _, deltas, log_sigmas, corrs = self.project_to_gmm_params(h_state)
            deltas = torch.clamp(deltas, max=1.5, min=-1.5)
            deltas = deltas.reshape(bs, -1, 2)
            log_sigmas = log_sigmas.reshape(bs, -1, 2)
            corrs = corrs.reshape(bs, -1, 1)

            mus = deltas + current_state
            current_state = mus
            variance = torch.clamp(torch.exp(log_sigmas).unsqueeze(2)**2,
                                   max=1e3,
                                   min=1e-3)

            m_diag = variance * torch.eye(2).to(variance.device)
            sigma_xy = torch.clamp(torch.prod(torch.exp(log_sigmas), dim=-1),
                                   min=1e-3,
                                   max=1e3)

            if train:
                # log_pis = z.reshape(bs, 1) * torch.ones(bs, self.num_modes).cuda()
                log_pis = to_one_hot(z, n_dims=self.num_modes).to(
                    encoded_history.device)

                # log_pis = z

            else:
                log_pis = to_one_hot(z, n_dims=self.num_modes).to(
                    encoded_history.device)
                # log_pis = z

            log_pis = log_pis - torch.logsumexp(log_pis, dim=-1, keepdim=True)
            mix = D.Categorical(logits=log_pis)
            scale_tril = torch.linalg.cholesky(m_diag.cpu()).to(z.device)
            comp = D.MultivariateNormal(mus, scale_tril=scale_tril)

            gmm = D.MixtureSameFamily(mix, comp)
            t = (sigma_xy * corrs.squeeze()).reshape(-1, 1, 1)
            # cov_matrix = m_diag  # + anti_diag
            gauses.append(gmm)
            a_t = gmm.sample()  # TODO possible grad problems?
            a_tt = F.dropout(self.action(a_t.reshape(bs, -1)), self.dropout_p)
            state = h_state
            # input = self.gru_prep(torch.cat((encoded_history.reshape(bs, -1), a_tt, lp), dim=-1))
            input = torch.cat((encoded_history.reshape(bs, -1), a_tt, 0 * lp),
                              dim=-1)
            inp = F.dropout(input, self.dropout_p)
        return gauses
Example #18
    def create_gmm(self, log_w_t, mu_t, log_sig_t):
        mix = D.Categorical(logits=log_w_t)  # Batchsize x K
        # refer https://github.com/pytorch/pytorch/pull/22742/files
        comp = D.Independent(D.Normal(mu_t, torch.exp(log_sig_t)), 1)
        # Individual Distribution = Batchsize x K x Da
        return D.MixtureSameFamily(mix, comp)
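A usage sketch with illustrative shapes (batch size B, K components, action dimension Da); `model` stands in for an instance of the class above:

B, K, Da = 4, 3, 6
log_w_t = torch.randn(B, K)          # mixture logits
mu_t = torch.randn(B, K, Da)         # component means
log_sig_t = torch.randn(B, K, Da)    # component log standard deviations

gmm = model.create_gmm(log_w_t, mu_t, log_sig_t)
x = gmm.sample()                     # shape (B, Da)
log_p = gmm.log_prob(x)              # shape (B,)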
Example #19
    def forward(self, scene: torch.Tensor):
        """
        :param scene: tensor of shape num_peds, history_size, data_dim
        :return: predicted poses distributions for each agent at next 12 timesteps
        """
        bs = scene.shape[0]
        poses = scene[:, :, :2]
        pv = scene[:, :, 2:6]
        vel = scene[:, :, 2:4]
        acc = scene[:, :, 4:6]
        pav = scene[:, :, :6]

        lstm__poses_out, _ = self.node_hist_encoder_poses(
            poses)  # lstm_out shape num_peds, timestamps ,  2*hidden_dim
        lstm_out_acc, hid = self.node_hist_encoder_acc(
            acc)  # lstm_out shape num_peds, timestamps ,  2*hidden_dim
        lstm_out_vell, hid = self.node_hist_encoder_vel(
            vel)  # lstm_out shape num_peds, timestamps ,  2*hidden_dim
        # lstm_out_poses, hid = self.node_hist_encoder_poses(poses)
        lstm_out = lstm_out_vell + lstm_out_acc  # + lstm_out_poses

        current_state = poses[:, -1, :]
        # np, data_dim = current_pose.shape
        bs, seq, data_dim = poses.shape
        stacked = poses.permute(1, 0, 2).reshape(seq,
                                                 -1).repeat(1, bs).reshape(
                                                     seq, bs, bs * data_dim)
        deltas = (stacked - poses.permute(1, 0, 2).repeat(1, 1, bs))
        deltas = deltas.permute(1, 0, 2).reshape(bs, seq, bs, data_dim)
        deltas_flat = deltas.reshape(deltas.shape[0], deltas.shape[1],
                                     -1).cuda()
        max_size = 50  # TODO: fix
        prep_for_deltas = torch.zeros(bs, seq, max_size).cuda()
        if deltas_flat.shape[2] >= max_size:
            prep_for_deltas = deltas_flat[:, :, :max_size]
        else:
            prep_for_deltas[:, :, :deltas_flat.shape[2]] = deltas_flat
        at_hidden = self.att.init_hidden(bs=bs)
        for i in range(8):
            at_output, at_hidden, at_normalized_weights = self.att(
                at_hidden, lstm__poses_out[:, i:i + 1, :],
                prep_for_deltas[:, i:i + 1, :])
        # current_pose = scene[:, -1, :2]  # num_people, data_dim
        # stacked = current_pose.flatten().repeat(np).reshape(np, np * data_dim)
        # deltas = (stacked - current_pose.repeat(1, np)).reshape(np, np, data_dim)  # np, np, data_dim

        # distruction, _ = self.edge_encoder(deltas, poses, poses)
        catted = torch.cat((lstm_out[:, -1:, :], at_output[:, -1:, :]), dim=2)
        a_0 = F.dropout(self.action(current_state.reshape(bs, -1)),
                        self.dropout_p)
        state = F.dropout(self.state(catted.reshape(bs, -1)), self.dropout_p)

        current_state = current_state.unsqueeze(1)
        gauses = []
        inp = F.dropout(torch.cat((catted.reshape(bs, -1), a_0), dim=-1),
                        self.dropout_p)
        for i in range(12):
            h_state = self.gru(inp.reshape(bs, -1), state)

            log_pis, deltas, log_sigmas, corrs = self.project_to_GMM_params(
                h_state)
            deltas = torch.clamp(deltas, max=1.5, min=-1.5)

            log_pis = log_pis.reshape(bs, -1)
            log_pis = log_pis - torch.logsumexp(log_pis, dim=-1, keepdim=True)
            deltas = deltas.reshape(bs, -1, 2)
            log_sigmas = log_sigmas.reshape(bs, -1, 2)
            corrs = corrs.reshape(bs, -1, 1)

            mus = deltas + current_state
            current_state = mus
            variance = torch.clamp(torch.exp(log_sigmas).unsqueeze(2)**2,
                                   max=1e3)

            m_diag = variance * torch.eye(2).to(variance.device)
            sigma_xy = torch.clamp(torch.prod(torch.exp(log_sigmas), dim=-1),
                                   min=1e-8,
                                   max=1e3)

            mix = D.Categorical(logits=log_pis)  # log_pis are normalized log-probabilities
            comp = D.MultivariateNormal(mus, m_diag)
            gmm = D.MixtureSameFamily(mix, comp)
            t = (sigma_xy * corrs.squeeze()).reshape(-1, 1, 1)
            cov_matrix = m_diag  # + anti_diag
            gauses.append(gmm)
            a_t = gmm.sample()  # possible grad problems?
            state = h_state
            inp = F.dropout(torch.cat((catted.reshape(bs, -1), a_t), dim=-1),
                            self.dropout_p)
        return gauses
Example #20
    def position_log_prob(self, x):
        # Compute the log probability over only the position dimensions.
        component_dist = td.MultivariateNormal(
            loc=self.component_distribution.mean[..., :2],
            scale_tril=self.component_distribution.scale_tril[..., :2, :2])
        position_dist = td.MixtureSameFamily(self.mixture_distribution,
                                             component_dist)
        return position_dist.log_prob(x)
Example #21
    def forward(self, scene: torch.Tensor):
        """
        :param scene: tensor of shape num_peds, history_size, data_dim
        :return: predicted poses distributions for each agent at next 12 timesteps
        """
        bs = scene.shape[0]
        poses = scene[:, :, :2]
        pv = scene[:, :, 2:6]
        vel = scene[:, :, 2:4]
        acc = scene[:, :, 4:6]
        pav = scene[:, :, :6]

        # lstm_out, hid = self.node_hist_encoder(pav)  # lstm_out shape num_peds, timestamps ,  2*hidden_dim
        lstm_out_acc, hid = self.node_hist_encoder_acc(
            acc)  # lstm_out shape num_peds, timestamps ,  2*hidden_dim
        lstm_out_vell, hid = self.node_hist_encoder_vel(
            vel)  # lstm_out shape num_peds, timestamps ,  2*hidden_dim
        lstm_out_poses, hid = self.node_hist_encoder_poses(poses)
        lstm_out = lstm_out_vell + lstm_out_poses + lstm_out_acc
        # lstm_out = lstm_out_poses  # + lstm_out_poses

        current_pose = scene[:, -1, :2]  # num_people, data_dim
        current_state = poses[:, -1, :]
        np, data_dim = current_pose.shape
        stacked = current_pose.flatten().repeat(np).reshape(np, np * data_dim)
        deltas = (stacked - current_pose.repeat(1, np)).reshape(
            np, np, data_dim)  # np, np, data_dim

        distruction, _ = self.edge_encoder(deltas)
        catted = torch.cat((lstm_out[:, -1:, :], distruction[:, -1:, :]),
                           dim=1)
        a_0 = F.dropout(self.action(current_state.reshape(bs, -1)),
                        self.dropout_p)
        state = F.dropout(self.state(catted.reshape(bs, -1)), self.dropout_p)

        current_state = current_state.unsqueeze(1)
        gauses = []
        inp = F.dropout(torch.cat((catted.reshape(bs, -1), a_0), dim=-1),
                        self.dropout_p)
        for i in range(12):
            h_state = self.gru(inp.reshape(bs, -1), state)

            log_pis, deltas, log_sigmas, corrs = self.project_to_GMM_params(
                h_state)
            deltas = torch.clamp(deltas, max=1.5, min=-1.5)

            log_pis = log_pis.reshape(bs, -1)
            log_pis = log_pis - torch.logsumexp(log_pis, dim=-1, keepdim=True)
            deltas = deltas.reshape(bs, -1, 2)
            log_sigmas = log_sigmas.reshape(bs, -1, 2)
            corrs = corrs.reshape(bs, -1, 1)

            mus = deltas + current_state
            current_state = mus
            variance = torch.clamp(torch.exp(log_sigmas).unsqueeze(2)**2,
                                   max=1e3)

            m_diag = variance * torch.eye(2).to(variance.device)
            sigma_xy = torch.clamp(torch.prod(torch.exp(log_sigmas), dim=-1),
                                   min=1e-8,
                                   max=1e3)

            mix = D.Categorical(logits=log_pis)  # log_pis are normalized log-probabilities
            comp = D.MultivariateNormal(mus, m_diag)
            gmm = D.MixtureSameFamily(mix, comp)
            t = (sigma_xy * corrs.squeeze()).reshape(-1, 1, 1)
            cov_matrix = m_diag  # + anti_diag
            gauses.append(gmm)
            a_t = gmm.sample()  # possible grad problems?
            a_tt = F.dropout(self.action(a_t.reshape(bs, -1)), self.dropout_p)
            state = h_state
            inp = F.dropout(torch.cat((catted.reshape(bs, -1), a_tt), dim=-1),
                            self.dropout_p)

        return gauses
Example #22
    y = torch.pow(torch.abs(x), 1)
    y.mean().backward()
    print(x.grad)
    main()
    def log_prob(value, loc=0, var=1, p=0.7):
        # loc = torch.Tensor(loc)
        # var = torch.Tensor(var)
        # return -torch.log(2 * var) - torch.abs(var - loc)**p / var

        return -(torch.abs(value - loc) ** p) / var - math.log(2 * var)
    
    value = torch.linspace(-100,100,2000)
    x = torch.distributions.Laplace(0,1)
    y = torch.distributions.Normal(0,1)
    
    mix = dist.Categorical(torch.ones(5,))
    comp = dist.Normal(0, torch.rand(5,))
    # z = torch.distributions.MixtureSameFamily
    gmm = dist.MixtureSameFamily(mix, comp)
    plt.plot(value, torch.exp(x.log_prob(value)),color='blue')
    plt.plot(value, torch.exp(y.log_prob(value)),color='red')
    plt.plot(value, torch.exp(gmm.log_prob(value)),color='green')
    plt.plot(value, torch.exp(log_prob(value)),color='yellow')

    # plt.ylim(10**0, -10**2)
    plt.yscale('log')
    
    # plt.fill_between(value, torch.exp(x.log_prob(value)))
    plt.show()

    # main()
Example #23
        nclus, P,
        device=device) + args.clus_sep * torch.randn(nclus, P, device=device)
    var_pr1 = torch.zeros(nclus, P, P, device=device)
    #store centers of gaussian inputs to prior-encoder
    clscenlog = 'ClusterCenters_CELEBA_EPSWAE.txt'
    for i in range(nclus):
        var_pr1[i] = torch.eye(P, device=device).detach()
    np.savetxt(clscenlog, mu_pr1.detach().numpy())

    alpha_pr = torch.zeros(nclus, device=device)
    for k in range(nclus):
        alpha_pr[k] = 1.0 / nclus

    mix = distributions.Categorical(alpha_pr)
    comp = distributions.MultivariateNormal(mu_pr1, var_pr1)
    gmm = distributions.MixtureSameFamily(mix, comp)

    for epoch in range(1, args.epochs + 1):
        recon_batch, data = train(epoch)

        if (epoch > 0 and epoch % args.save_interval == 0):
            torch.save(model.state_dict(),
                       'models/CELEBA_EPSWAE_AE_e' + str(epoch))
            torch.save(prior.state_dict(),
                       'models/CELEBA_EPSWAE_PE_e' + str(epoch))

        #generate sample plots
        sample = prior.GenNsamples_NormPrior(64)
        sample = model.decode(sample).cpu()
        save_image(sample.view(64, 3, leng, leng),
                   'results/CelebA_EPSWAE_Sample_e_' + str(epoch) + '.png')
Example #24
def kde_pig_dl(
    dm: pl.LightningDataModule,
    batch_size: int,
    N_hat_multiplier: float = 1,
) -> DataLoader:
    # %
    gd_n_steps, gd_lr, gd_threshold = 5, 4e-1, 0.005

    # Spherical = each component has single variance.
    bgm = BayesianGaussianMixture(
        n_components=batch_size,
        covariance_type="spherical",
        warm_start=True,
    )

    x_hat = torch.Tensor()
    for idx, batch in enumerate(iter(dm.train_dataloader())):
        x, _ = batch
        device = x.device
        x = x.detach().cpu().numpy()
        # The last batch might have fewer elements than the original n_components
        if x.shape[0] < bgm.n_components:
            bgm = BayesianGaussianMixture(
                n_components=x.shape[0],
                covariance_type="spherical",
            )
        # Estimate KDE
        bgm.fit(x)
        # [N_components, 1], [N_components, N_features], [N_components, 1]
        weights, means, variances = (
            torch.Tensor(bgm.weights_).to(device),
            torch.Tensor(bgm.means_).to(device),
            torch.Tensor(bgm.covariances_).to(device),
        )
        filter_weights_idx = weights >= 1e-5
        weights, means, variances = (
            weights[filter_weights_idx],
            means[filter_weights_idx],
            variances[filter_weights_idx][:, None],
        )
        n_selected_components = weights.shape[0]
        p_x = D.Independent(D.Normal(means, torch.sqrt(variances)), 1)
        mix = D.Categorical(weights)
        p_x = D.MixtureSameFamily(mix, p_x)
        # Sample according to the multiplier (sample shapes must be integers)
        n_draws = int(
            n_selected_components
            * ((batch_size // n_selected_components) + 1)
            * N_hat_multiplier
        )
        x_start = p_x.sample((n_draws,)).reshape(-1, x.shape[1])
        # Use GD
        _x_hat = density_gradient_descent(
            p_x,
            x_start,
            {"N_steps": gd_n_steps, "lr": gd_lr, "threshold": gd_threshold},
        )
        # Ensure same device
        if x_hat.device != device:
            x_hat = x_hat.to(device)
        x_hat = torch.cat((x_hat, _x_hat.detach()))

    dl = DataLoader(TensorDataset(x_hat), batch_size=batch_size, shuffle=True)
    return dl
Example #25
def MixtureLogistic(logits, loc, scale):
    return D.MixtureSameFamily(
        D.Categorical(logits=logits),
        Logistic(loc, scale),
    )
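`Logistic` is not part of `torch.distributions`; one common construction (which may differ from the project's own helper) builds it from a transformed Uniform:

import torch
import torch.distributions as D
from torch.distributions.transforms import AffineTransform, SigmoidTransform

def Logistic(loc, scale):
    # logit(U) with U ~ Uniform(0, 1) is standard logistic; affine-transform to (loc, scale)
    base = D.Uniform(torch.zeros_like(loc), torch.ones_like(loc))
    return D.TransformedDistribution(
        base, [SigmoidTransform().inv, AffineTransform(loc=loc, scale=scale)])

# e.g. a mixture of 3 logistics per batch element (illustrative shapes):
m = MixtureLogistic(torch.randn(2, 3), torch.randn(2, 3), torch.rand(2, 3) + 0.1)
x = m.sample()          # shape (2,)
log_p = m.log_prob(x)   # shape (2,)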