Example 1
    def gen_data(self):
        # sample overall relative abundances of ASVs from a Dirichlet distribution
        self.ASV_rel_abundance = tdist.Dirichlet(torch.ones(
            self.numASVs)).sample()

        # sample spatial embedding of ASVs
        self.w = torch.zeros(self.numASVs, self.D)
        w_prior = tdist.MultivariateNormal(torch.zeros(self.D),
                                           torch.eye(self.D))

        for o in range(0, self.numASVs):
            self.w[o, :] = w_prior.sample()

        self.data = torch.zeros(self.numParticles, self.numASVs)

        num_nonempty = 0

        mu_prior = tdist.MultivariateNormal(torch.zeros(self.D),
                                            torch.eye(self.D))
        rad_prior = tdist.LogNormal(torch.tensor([self.mu_rad]),
                                    torch.tensor([self.mu_std]))

        # TODO: replace with a negative binomial prior (see the sketch after this example)
        num_reads_prior = tdist.Poisson(
            torch.tensor([self.avgNumReadsParticle]))

        while (num_nonempty < self.numParticles):
            # sample center
            mu = mu_prior.sample()
            rad = rad_prior.sample()

            zr = torch.zeros(1, self.numASVs, dtype=torch.float64)
            for o in range(0, self.numASVs):
                p = mu - self.w[o, :]
                p = torch.pow(p, 2.0) / rad
                p = (torch.sum(p)).sqrt()
                zr[0, o] = unitboxcar(p, 0.0, 2.0, self.step_approx)

            if torch.sum(zr) > 0.95:
                particle = Particle(mu, self)
                particle.zr = zr
                self.particles.append(particle)

                # renormalize particle abundances
                rn = self.ASV_rel_abundance * zr
                rn = rn / torch.sum(rn)

                # sample relative abundances for particle
                part_rel_abundance = tdist.Dirichlet(rn * self.conc).sample()

                # sample number of reads for particle
                # (replace w/ neg bin instead of Poisson)
                num_reads = num_reads_prior.sample().long().item()
                particle.total_reads = num_reads

                particle.reads = tdist.Multinomial(
                    num_reads, probs=part_rel_abundance).sample()

                num_nonempty += 1
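The TODO comments above suggest swapping the Poisson read-count prior for a negative binomial. A minimal sketch of that swap, using a mean-matched parameterization; the mean and the dispersion r below are illustrative values, not from the original code:

import torch
import torch.distributions as tdist

avg_reads, r = 500.0, 10.0                     # assumed mean and dispersion
probs = avg_reads / (avg_reads + r)            # mean = r * probs / (1 - probs) = avg_reads
num_reads_prior = tdist.NegativeBinomial(total_count=torch.tensor(r),
                                         probs=torch.tensor(probs))
num_reads = num_reads_prior.sample().long().item()

Smaller r gives heavier overdispersion; as r grows, the distribution approaches the Poisson it replaces.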
Example 2
    def __init__(self, nb_states, obs_dim, act_dim,
                 prior, norm, device, **kwargs):
        super(ParametricAugmentationRegressor, self).__init__()

        self.device = device

        self.nb_states = nb_states
        self.obs_dim = obs_dim
        self.act_dim = act_dim

        # Dirichlet parameters
        self.prior = {'alpha': torch.as_tensor(prior['alpha'], dtype=torch.float32, device=self.device),
                      'kappa': torch.as_tensor(prior['kappa'], dtype=torch.float32, device=self.device)}

        # Normalization parameters
        self.norm = {'mean': torch.as_tensor(norm['mean'], dtype=torch.float32, device=self.device),
                     'std': torch.as_tensor(norm['std'], dtype=torch.float32, device=self.device)}

        self.dirichlets = []
        alphas = self.prior['alpha'] * torch.ones(self.nb_states, dtype=torch.float32, device=self.device)
        for k in range(self.nb_states):
            kappas = self.prior['kappa'] * torch.as_tensor(torch.arange(self.nb_states) == k,
                                                           dtype=torch.float32, device=self.device)
            self.dirichlets.append(dist.Dirichlet(alphas + kappas, validate_args=True))

        self.optim = None
Example 3
    def _expand_node(
            self,
            trees: _MCTSTree,
            n,  # n-th expansion, zero-based
            to_plays,
            model_output: ModelOutput,
            dirichlet_alpha=None,
            exploration_fraction=0.):
        if self._is_two_player_game:
            trees.to_play[:, n] = to_plays
        if trees.game_over is not None:
            trees.game_over[:, n] = model_output.game_over

        def _set_tree_state(ts, s):
            ts[:, n] = s

        nest.map_structure(_set_tree_state, trees.model_state,
                           model_output.state)
        if trees.reward is not None:
            trees.reward[:, n] = model_output.reward
        if trees.action is not None:
            trees.action[:, n] = model_output.actions
        prior = model_output.action_probs

        if exploration_fraction > 0.:
            batch_size = model_output.action_probs.shape[0]
            noise_dist = td.Dirichlet(
                dirichlet_alpha * torch.ones(trees.branch_factor))
            noise = noise_dist.sample((batch_size, ))
            # keep noise off actions the prior already rules out,
            # then renormalize so the mixture stays a distribution
            noise = noise * (prior != 0)
            noise = noise / noise.sum(dim=1, keepdim=True)
            prior = exploration_fraction * noise + (
                1 - exploration_fraction) * prior

        trees.prior[:, n] = prior
Example 4
 def prev(self) -> dist.Distribution:
     """
     Prevalance for each of the categories in a Dirichlet distribution so it adds up
     to 1.
     """
     return dist.Dirichlet(
         torch.ones(self.num_categories) * (1.0 / self.num_categories))
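A quick check of the docstring's claim; num_categories is an assumed value here:

import torch
import torch.distributions as dist

num_categories = 4
p = dist.Dirichlet(torch.ones(num_categories) * (1.0 / num_categories)).sample()
assert torch.isclose(p.sum(), torch.tensor(1.0))  # Dirichlet samples lie on the simplex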
Example 5
 def __init__(self, K, D):
     
     super().__init__()
     
     # Dirichlet concentration; note that nothing constrains it to stay
     # positive once the optimizer updates it (see the sketch after this example)
     self.alpha = nn.Parameter(torch.ones(K), requires_grad=True)
     self.mu    = nn.Parameter(torch.randn(K, D) * 0.2, requires_grad=True)
     self.chol  = nn.Parameter(torch.stack([torch.eye(D, D) * 0.3] * K), requires_grad=True)
     self.dir   = td.Dirichlet(self.alpha)
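td.Dirichlet requires strictly positive concentrations, so learning alpha directly can break after a gradient step drives an entry non-positive. A common workaround (a sketch under that assumption, not the author's code) keeps an unconstrained parameter and maps it through softplus, rebuilding the distribution after each update:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as td

K = 3
raw_alpha = nn.Parameter(torch.zeros(K))   # unconstrained, safe to optimize
alpha = F.softplus(raw_alpha) + 1e-4       # strictly positive concentration
dirichlet = td.Dirichlet(alpha)            # reconstruct whenever raw_alpha changes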
Example 6
    def encoder(self,
                fusion_input,
                enc_hx,
                enc_cx,
                lstm,
                classifier,
                test=False):
        # fusion_input = self.feature_extractor(camera_input, sensor_input)
        enc_hx, enc_cx = lstm(fusion_input, (enc_hx, enc_cx))
        enc_score = classifier(enc_hx)

        if self.dirichlet:
            if self.method == 'Mean':
                enc_score_soft = self.softplus(enc_score)
                dist = distributions.Dirichlet(enc_score_soft)
                enc_score = dist.mean

            elif self.method == 'Sample':
                enc_score_soft = self.softplus(enc_score)
                dist = distributions.Dirichlet(enc_score_soft)
                enc_score = dist.rsample()

        if test:
            # note: `dist` is only defined above when self.dirichlet is set
            if self.var_method == 'covariance':
                # full Dirichlet covariance: off-diagonal entries are
                # -a_i * a_j / (a0^2 * (a0 + 1)), marginal variances on the diagonal
                var = dist.variance
                diagonal = np.diag(var.cpu().numpy()[0])
                con = dist.concentration
                con0 = con.sum(-1, True)
                d = (con0.pow(2) * (con0 + 1)).item()
                con = con.cpu().numpy()[0]
                n = len(con)
                for i in range(n):
                    for j in range(n):
                        if i != j:
                            diagonal[i][j] = -con[i] * con[j] / d

                return enc_hx, enc_cx, enc_score, diagonal

            elif self.var_method == 'diagonal':
                var = dist.variance
                diagonal = np.diag(var.cpu().numpy()[0])
                # print(diagonal)
                return enc_hx, enc_cx, enc_score, diagonal

        return enc_hx, enc_cx, enc_score
Example 7
def test_masked_dirichlet(K=3):
    mask = make_faces(K)
    w, con = get_parameters(mask, dir_alpha=None, gamma_alpha=1, gamma_beta=1)
    p = MaskedDirichlet(mask, con)
    q = MaskedDirichlet(mask, torch.ones_like(con))
    assert (torch.where(torch.logical_not(mask), p.concentration, torch.zeros_like(con)) == 0).all(), "Masked concentration parameters should be 0.0"
    assert (torch.where(mask, p.concentration, torch.ones_like(con)) > 0).all(), "Unmasked concentration parameters should be strictly positive"
    for i, face in enumerate(mask):
        idx = tuple(k for k, b in enumerate(face) if b)
        alphas = con[i,idx]
        p_low = td.Dirichlet(alphas)
        q_low = td.Dirichlet(torch.ones_like(alphas))
        assert torch.isclose(p.mean[i,idx], p_low.mean).all(), f"The {i}th face's mean does not match that of td.Dirichlet"
        assert torch.isclose(p.variance[i,idx], p_low.variance).all(), f"The {i}th face's variance does not match that of td.Dirichlet"
        assert torch.isclose(p.entropy()[i], p_low.entropy()).all(), f"The {i}th face's entropy does not match that of td.Dirichlet"
        assert (p.dim[i] == len(idx)), f"The dimensionality of the {i}th face is incorrect: got {p.dim[i]}, expected {len(idx)}"
        x = p.rsample()
        assert torch.isclose(p.log_prob(x)[i], p_low.log_prob(x[i,idx])).all(), "The log_prob of a sample does not match that assigned by td.Dirichlet"
        assert torch.isclose(td.kl_divergence(p, q)[i], td.kl_divergence(p_low, q_low)).all(), f"KL(p||q) for the {i}th face does not match that of td.Dirichlet"
        assert torch.isclose(td.kl_divergence(q, p)[i], td.kl_divergence(q_low, p_low)).all(), f"KL(q||p) for the {i}th face does not match that of td.Dirichlet"
Example 8
    def test_calculate_exploration_policy(self):
        dim = 400
        batch_size = 1000
        tol = 1e-6

        dist = td.Dirichlet(torch.full([dim], 0.25))
        prior = dist.sample((batch_size, ))
        value = torch.rand([batch_size, dim])
        c = torch.rand([batch_size, 1]) + 0.01
        for i in range(10):
            t = time.time()
            p, iterations = calculate_exploration_policy(value, prior, c, tol)
            t = time.time() - t
            logging.info("time=%s iterations=%s" % (t, iterations))
        self.assertTrue(((p.sum(dim=1) - 1).abs() < tol).all())
Example 9
    def find_params(self, data: List[List[str]]) -> List[float]:
        # phi
        self.word_topics_distribution = dists.Dirichlet(
            torch.ones(self.num_topics, self.vocabulary_size)).sample()

        # theta
        self.document_topic_distribution = dists.Dirichlet(
            torch.ones(len(data), self.num_topics)).sample()

        # z
        self.topic_assignments = [
            dists.Categorical(probas[None].expand(
                [len(data[i]), self.num_topics]))
            for i, probas in enumerate(self.document_topic_distribution)
        ]

        history = []
        for index, document in enumerate(data):
            self.document_mapping[" ".join(document)] = index

        for _ in tqdm.trange(self.num_optim_steps, desc="Optim step"):
            self.run_gibbs_step(data)
            history.append(self.get_perplexity(data, -1))
        return history
Example 10
    def encoder(self, camera_input, sensor_input, enc_hx, enc_cx, enc_score):
        before_score = enc_score
        fusion_input = self.feature_extractor(camera_input, sensor_input)
        enc_hx, enc_cx = self.lstm(fusion_input, (enc_hx, enc_cx))

        before_score = before_score.unsqueeze(2)                #[32,22,1] 
        before_score = before_score * self.weight               #[32,22, 4096]
        before_score = torch.sum(before_score, dim=1)           #[32,4096]
        hx_enc = torch.add(enc_hx, before_score)

        enc_score = self.classifier(hx_enc)

        ## Dirichlet head (a distribution, not a Dirichlet process);
        ## use the distribution mean as the score
        enc_score_soft = self.softplus(enc_score)
        dist = distributions.Dirichlet(enc_score_soft)
        enc_score = dist.mean

        return enc_hx, enc_cx, enc_score
Example 11
    def forward(self, encoder_out: Dict[str, torch.LongTensor],
                salience_values) -> Dict[str, torch.Tensor]:
        mask = encoder_out['source_mask']
        seq_len = mask.sum(1)

        # shape: (batch_size, seq_len, 1)
        regression_output = self.regression(encoder_out['encoder_outputs'])
        # Sampling dirichlet
        alphas = torch.relu(regression_output).squeeze(dim=2) + 1e-6
        d_sample = lambda x: D.Dirichlet(x).rsample()
        loss = []
        for idx, alpha in enumerate(alphas):
            # rsample keeps the reparameterized gradient; rewrapping the
            # result in torch.Tensor would detach it from the graph
            predicted_salience = d_sample(alpha[:seq_len[idx]])
            loss.append(
                self._get_loss(predicted_salience,
                               salience_values[idx][:seq_len[idx]]))
        loss = torch.stack(loss).mean()
        if torch.isnan(loss):
            raise ValueError("nan loss encountered")
        output_dict = {'loss': loss}
        return output_dict
Example 12
def pick_cell_types(uni_labels, alpha, min_n_cells):
    '''
    Pick cell types to include in synthetic spots, with proportions drawn
    from a Dirichlet distribution.

    Parameters
    ----------
    uni_labels : np.array
        unique cell type labels
    alpha : np.array
        Dirichlet concentration values
        (can be from cell type proportions in ST)
    min_n_cells : int
        minimum number of cells per spot; caps the number of sampled types
        so a spot never holds more types than cells

    Returns
    -------
    tuple of picked cell types and their proportions

    '''
    # get number of different
    # cell types present
    n_labels = uni_labels.shape[0]

    # sample number of types to be present at current spot
    # w/o having more types than cells
    n_types = dists.uniform.Uniform(low=1,
                                    high=min([n_labels, min_n_cells])).sample()

    n_types = n_types.round().type(t.int)

    # select which types to include
    pick_types = t.randperm(n_labels)[0:n_types]
    alpha = t.Tensor(np.array(alpha[pick_types]))

    # select cell type proportions
    member_props = dists.Dirichlet(concentration=alpha * t.ones(n_types)).sample()
    return ((pick_types, member_props))
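A hypothetical call, assuming five cell types and spots of at least eight cells; the labels and alpha below are made up for illustration:

import numpy as np
import torch as t

uni_labels = np.array(["B", "T", "NK", "Mono", "DC"])
alpha = np.ones(len(uni_labels))               # symmetric concentration
types, props = pick_cell_types(uni_labels, alpha, min_n_cells=8)
assert abs(props.sum().item() - 1.0) < 1e-6    # proportions form a simplex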
Example 13
 def log_prob(self, value):
     return dists.Dirichlet(self.alpha).log_prob(value)
Example 14
    def encoder(self,
                camera_input,
                sensor_input,
                enc_hx,
                enc_cx,
                d_enc_hx,
                d_enc_cx,
                enc_score,
                delta,
                test=False):
        before_score = enc_score

        fusion_input = self.feature_extractor(camera_input, sensor_input)

        enc_hx, enc_cx = self.lstm_oad(fusion_input, (enc_hx, enc_cx))
        d_enc_hx, d_enc_cx = self.lstm_delta(fusion_input,
                                             (d_enc_hx, d_enc_cx))  # [32,4096]

        ## weighted embedding
        # print(before_score.sum(1))
        before_score = before_score.unsqueeze(2)  #[32,22,1]
        before_score = before_score * self.weight  #[32,22, 4096]
        before_score = torch.sum(before_score, dim=1)  #[32,4096]
        hx_enc = torch.add(d_enc_hx, before_score)  #[32,4096]

        # new_enc_hx = torch.add(enc_hx, before_score)
        # enc_score = self.classifier_oad(new_enc_hx)

        enc_score = self.classifier_oad(enc_hx)
        delta_score = self.classifier_delta(hx_enc)
        delta_var = self.classifier_deltav(hx_enc)

        if self.dirichlet:
            if self.method == 'Mean':
                enc_score_soft = self.softplus(enc_score)
                dist = distributions.Dirichlet(enc_score_soft)
                enc_score = dist.mean

            elif self.method == 'Sample':
                enc_score_soft = self.softplus(enc_score)
                dist = distributions.Dirichlet(enc_score_soft)
                enc_score = dist.rsample()

        ### diagonal
        delta_var_soft = self.softplus(delta_var)
        diagonal = []
        for i in range(len(delta_var_soft)):
            diag = torch.diag(delta_var_soft[i])
            diagonal.append(diag)
        diagonal = torch.stack(diagonal)
        norm_dist = distributions.MultivariateNormal(delta_score, diagonal)
        delta_score = norm_dist.rsample()

        if self.loss_method == 'state_before':
            var = dist.variance
            enc_var = [torch.diag(var_i) for var_i in var]  #(32,22,22)
            enc_var = torch.stack(enc_var, dim=0)

            delta_var = norm_dist.covariance_matrix  #(32,22,22)

        if test:
            if self.var_method == 'covariance':
                con = dist.concentration
                con0 = con.sum(-1, True)
                d = (con0.pow(2) * (con0 + 1))
                con = con.cpu().numpy()[0]
                # off-diagonal Dirichlet covariance: -a_i * a_j / (a0^2 * (a0 + 1))
                con_s = con.reshape(-1, 1)
                con_t = con.reshape(1, -1)
                diagonal = -con_s * con_t
                diagonal /= d.cpu().numpy()[0]
                var = dist.variance.cpu().numpy()[0]
                np.fill_diagonal(diagonal, var)
                return enc_hx, enc_cx, enc_score, diagonal

            elif self.var_method == 'diagonal':
                if not delta:
                    var = dist.variance
                    enc_diagonal = np.diag(var.cpu().numpy()[0])

                    return enc_hx, enc_cx, enc_score, enc_diagonal

                else:
                    delta_score = norm_dist.mean  # deterministic mean at test time
                    delta_vari = norm_dist.variance
                    delta_cov = np.diag(delta_vari.cpu().numpy()[0])
                    return d_enc_hx, d_enc_cx, delta_score, delta_cov

        if self.loss_method == 'oad_before':
            return enc_hx, enc_cx, enc_score, d_enc_hx, d_enc_cx, delta_score

        elif self.loss_method == 'state_before':
            return enc_hx, enc_cx, enc_score, enc_var, d_enc_hx, d_enc_cx, delta_score, delta_var
Example 15
        new.lambda1 = self.lambda1.expand(lambda1_shape)
        new.lambda2 = self.lambda2.expand(lambda2_shape)
        super(NaturalNormalWishart, new).__init__(batch_shape,
                                                  self.event_shape,
                                                  validate_args=False)
        new._validate_args = self._validate_args
        return new


if __name__ == '__main__':
    N, K, D = 1000, 3, 2
    mean = torch.zeros(D)
    nu = torch.tensor(1.)
    a = torch.tensor(float(D) - 1.)
    B = torch.eye(D)
    niw = NaturalNormalWishart.from_standard(mean, nu, a, B)
    data = torch.randn(N, D)
    post_niw = niw.posterior(data)
    print(post_niw.to_standard())
    mix = td.Dirichlet(torch.ones(K))
    weights = mix.sample((N, ))
    expanded_niw = niw.expand((K, ))
    post_niw = expanded_niw.posterior(data)
    print(post_niw.to_standard())
    post_niw = expanded_niw.posterior(data, weights)
    print(post_niw.to_standard())

    samples = niw.rsample((K, ))
    print(samples)
    print(samples.batch_shape, samples.event_shape)
Example 16
 def entropy(self):
     return dists.Dirichlet(self.alpha).entropy()
Example 17
 def confusion_matrix(self, j: int, c: int) -> dist.Distribution:
     """
     Confusion matrix for each labeler (j) and category (c), where each row is a
     Dirichlet distribution.
     """
     return dist.Dirichlet(self.alpha[c])
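An illustrative use, with an assumed alpha of shape [C, C] that mildly favors the diagonal:

import torch
import torch.distributions as dist

C = 3
alpha = torch.ones(C, C) + 2.0 * torch.eye(C)    # assumed hyperparameters
row = dist.Dirichlet(alpha[1]).sample()          # one confusion-matrix row
assert torch.isclose(row.sum(), torch.tensor(1.0))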
Example 18
 def sample(self, batch_size):
     return dists.Dirichlet(self.alpha).rsample((batch_size, ))
Example 19
 def __init__(self, a):
     dist = dists.Dirichlet(a[0])
     super().__init__(dist, "Dirichlet", -1, a)
Example 20
 def init_parameters(self, alpha=1.):
     for x in self.parameters(recurse=False):
         dirich = distr.Dirichlet(th.tensor([alpha] * x.shape[-1]))
         x.data = dirich.sample(x.shape[:-1]).log()
Example 21
def _assemble_spot(
    cnt: np.ndarray,
    labels: np.ndarray,
    alpha: float = 1.0,
    fraction: float = 0.1,
    bounds: List[int] = [10, 30],
) -> Dict[str, t.Tensor]:
    """Assemble single spot

    generates one synthetic ST-spot
    from provided single cell data

    Parameter:
    ---------
    cnt : np.ndarray
        single cell count data [n_cells x n_genes]
    labels : np.ndarray
        single cell annotations [n_cells]
    alpha : float
        dirichlet distribution
        concentration value
    fraction : float
        fraction of transcripts from each cell
        being observed in ST-spot

    Returns:
    -------
    Dictionary with expression data,
    proportion values and number of
    cells from each type at every
    spot

    """

    # sample the number of cells present
    # at the spot (bounds default to 10 and 30)
    n_cells = dists.uniform.Uniform(
        low=bounds[0], high=bounds[1]).sample().round().type(t.int)

    # get unique labels found in single cell data
    uni_labs, uni_counts = np.unique(labels, return_counts=True)

    # make sure a sufficient number
    # of cells is present within
    # every cell type
    assert np.all(uni_counts >= 30), \
            "Insufficient number of cells"

    # get number of different
    # cell types present
    n_labels = uni_labs.shape[0]

    # sample number of types to
    # be present at current spot
    n_types = dists.uniform.Uniform(low=1, high=n_labels).sample()

    n_types = n_types.round().type(t.int)

    # select which types to include
    pick_types = t.randperm(n_labels)[0:n_types]
    # pick at least one cell for spot
    members = t.zeros(n_labels).type(t.float)
    while members.sum() < 1:
        # draw proportion values from probability simplex
        member_props = dists.Dirichlet(concentration=alpha *
                                       t.ones(n_types)).sample()
        # get integer number of cells based on proportions
        members[pick_types] = (n_cells * member_props).round()

    # get proportion of each type
    props = members / members.sum()
    # convert member counts to ints
    members = members.type(t.int)

    # generate spot expression data
    spot_expr = t.zeros(cnt.shape[1]).type(t.float32)

    for z in range(n_types):
        # get indices of selected type
        idx = np.where(labels == uni_labs[pick_types[z]])[0]
        # pick random cells from type
        np.random.shuffle(idx)
        idx = idx[0:members[pick_types[z]]]
        # add fraction of transcripts to spot expression
        spot_expr += t.tensor(
            (cnt[idx, :] * fraction).sum(axis=0).round().astype(np.float32))

    return {
        'expr': spot_expr,
        'proportions': props,
        'members': members,
    }
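A hypothetical call with synthetic single-cell data; the counts and labels below are made up and only need to satisfy the >= 30 cells-per-type assertion:

import numpy as np

cnt = np.random.poisson(2.0, size=(300, 50))   # 300 cells x 50 genes
labels = np.repeat(np.arange(10), 30)          # 10 types, 30 cells each
spot = _assemble_spot(cnt, labels, alpha=1.0, fraction=0.1)
print(spot['proportions'].sum())               # ~1.0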