Ejemplo n.º 1
0
    def eval_step(self, data, model_dict=None):
        ''' Performs an evaluation step.

        Runs ``self.model`` (in eval mode, without gradients) on the IoU
        evaluation points of ``data`` and returns the mean IoU between the
        ground-truth occupancy (``occ >= 0.5``) and the predicted occupancy
        (summed probability of every class except class 0, compared against
        ``self.threshold``).

        Args:
            data (dict): data dictionary; reads ``'points_iou'``,
                ``'points_iou.occ'``, ``'points_iou.scale'`` and, optionally,
                ``'inputs'`` (defaults to an empty ``(1, 1, 0)`` tensor).
            model_dict: unused here; kept for interface compatibility.

        Returns:
            dict: ``{'iou': mean IoU over the batch}``.

        Raises:
            ValueError: if the model's ``'logits'`` are neither 3-D nor 4-D.
        '''
        self.model.eval()

        device = self.device
        threshold = self.threshold

        points_iou = data.get('points_iou').to(device)
        occ_iou = data.get('points_iou.occ').to(device)
        scale = data.get('points_iou.scale').to(device)
        # The unpacking also enforces that points_iou is 3-D; the last size is unused.
        batch_size, T, _ = points_iou.size()

        kwargs = {'scale': scale.view(-1, 1, 1)}

        with torch.no_grad():
            # Encoder inputs
            inputs = data.get('inputs', torch.empty(1, 1, 0)).to(device)
            # All-ones mask: every evaluation point contributes to the IoU.
            mask = torch.ones(batch_size, T, dtype=points_iou.dtype, device=points_iou.device)

            # Decode occupancies
            out_dict = self.model(points_iou, inputs, **kwargs)
            logits = out_dict['logits']

            if logits.dim() == 4:
                # PTF-piecewise predictions: reduce dim 1 by max, leaving 3-D logits.
                logits = torch.max(logits, dim=1)[0]
            elif logits.dim() != 3:
                raise ValueError('Wrong logits shape')

            # IPNet/PTF predictions. NOTE(review): the transpose plus the
            # [:, :, 1:] indexing below assume logits are (batch, classes, points)
            # -- confirm against the model.
            p_out = dist.Multinomial(logits=logits.transpose(1, 2))

        # Compute iou
        occ_iou_np = ((occ_iou >= 0.5) * mask).cpu().numpy()
        # IoU for outer surface; we just want an easy-to-compute indicator for model selection
        occ_iou_hat_np = ((p_out.probs[:, :, 1:].sum(-1) >= threshold) * mask).cpu().numpy()

        iou = compute_iou(occ_iou_np, occ_iou_hat_np).mean()
        return {'iou': iou}
Ejemplo n.º 2
0
    def gen_data(self):
        """Generate a synthetic set of particles.

        Side effects:
          * ``self.ASV_rel_abundance`` -- one draw from a flat Dirichlet over
            ``self.numASVs`` entries.
          * ``self.w`` -- ``(numASVs, D)`` embedding, one standard-normal row per ASV.
          * ``self.data`` -- a ``(numParticles, numASVs)`` zero tensor.
            NOTE(review): this method never writes into it after allocating it;
            confirm it is populated elsewhere.
          * ``self.particles`` -- one ``Particle`` is appended per accepted proposal.

        Proposals (centre, radius) are drawn until ``self.numParticles`` have been
        accepted. A proposal is accepted when its summed boxcar membership
        exceeds 0.95.
        """
        # Global relative abundances of the ASVs.
        self.ASV_rel_abundance = tdist.Dirichlet(torch.ones(self.numASVs)).sample()

        # Spatial embedding: rows are sampled one at a time, in ASV order.
        self.w = torch.zeros(self.numASVs, self.D)
        embedding_prior = tdist.MultivariateNormal(torch.zeros(self.D), torch.eye(self.D))
        for idx in range(self.numASVs):
            self.w[idx] = embedding_prior.sample()

        self.data = torch.zeros(self.numParticles, self.numASVs)

        centre_prior = tdist.MultivariateNormal(torch.zeros(self.D), torch.eye(self.D))
        radius_prior = tdist.LogNormal(torch.tensor([self.mu_rad]),
                                       torch.tensor([self.mu_std]))
        # TODO: replace the Poisson read-count prior with a negative binomial.
        reads_prior = tdist.Poisson(torch.tensor([self.avgNumReadsParticle]))

        accepted = 0
        while accepted < self.numParticles:
            centre = centre_prior.sample()
            radius = radius_prior.sample()

            # Boxcar membership of every ASV, computed from the radius-scaled
            # distance between the proposed centre and the ASV's embedding.
            membership = torch.zeros(1, self.numASVs, dtype=torch.float64)
            for idx in range(self.numASVs):
                scaled_dist = (centre - self.w[idx]).pow(2.0).div(radius).sum().sqrt()
                membership[0, idx] = unitboxcar(scaled_dist, 0.0, 2.0, self.step_approx)

            # Written as `not (... > ...)` so that a NaN total is rejected too.
            if not (torch.sum(membership) > 0.95):
                continue

            particle = Particle(centre, self)
            particle.zr = membership
            self.particles.append(particle)

            # Restrict the global abundances to this particle's ASVs, then renormalise.
            weights = self.ASV_rel_abundance * membership
            weights = weights / torch.sum(weights)

            # Relative abundances of the ASVs within this particle.
            abundance = tdist.Dirichlet(weights * self.conc).sample()

            # Total read count for the particle (Poisson for now), then the reads themselves.
            num_reads = reads_prior.sample().long().item()
            particle.total_reads = num_reads
            particle.reads = tdist.Multinomial(num_reads, probs=abundance).sample()

            accepted += 1
Ejemplo n.º 3
0
def multinomial_sample(n: int, p_vec: Tensor) -> Tensor:
    r"""Draw a single sample from a multinomial distribution.

    Args:
        n: Total number of trials; must be non-negative.
        p_vec: Category probabilities, expected to be 1-D. ``Multinomial``
            normalises them along the last dimension.

    Returns:
        An ``int32`` tensor with the same shape as ``p_vec``, holding the count
        drawn for each category. The counts sum to ``n``.

    Raises:
        ValueError: if ``p_vec`` has no categories (empty or 0-d) or ``n`` is negative.
    """
    # Validate inputs with real exceptions: `assert` is stripped under `python -O`.
    if p_vec.dim() == 0 or p_vec.shape[0] == 0:
        raise ValueError("Multinomial size doesn't make sense")
    if n < 0:
        raise ValueError(f"Number of trials must be non-negative, got {n}")

    n_per_category = distributions.Multinomial(n, p_vec).sample().int()

    # Internal post-conditions (sanity checks on the sampler, not input validation).
    assert p_vec.shape == n_per_category.shape, "Dimension mismatch"
    assert int(n_per_category.sum().item()) == n, "Number of elements mismatch"
    return n_per_category
Ejemplo n.º 4
0
def multinomial_loss(logits, observations, reduction="mean"):
    """Negative log-likelihood of a multinomial distribution parameterised by logits.

    Args:
        logits: Unnormalised log-probabilities; the last dimension indexes categories.
        observations: Non-negative category counts, broadcastable against ``logits``.
            Different rows may have different total counts.
        reduction: ``"mean"`` (default) or ``"sum"`` over the per-observation
            NLL, or ``"none"`` to return it unreduced.

    Returns:
        The reduced (or per-observation) negative log-likelihood tensor.

    Raises:
        NotImplementedError: if ``reduction`` is not one of the values above.
    """
    import math

    if reduction not in ("mean", "sum", "none"):
        raise NotImplementedError(f"Unknown reduction {reduction}")

    # Multinomial.log_prob does not depend on total_count. However, with argument
    # validation on (torch's default), observations are checked against the
    # support `sum(counts) <= total_count`. The default total_count of 1 would
    # therefore reject any observation with more than one count, so use the
    # largest observed total instead (never below the original 1). Non-finite
    # totals keep 1, so validation still reports them.
    totals = observations.sum(-1)
    total_count = 1
    if totals.numel():
        largest = float(totals.max())
        if math.isfinite(largest):
            total_count = max(1, math.ceil(largest))

    p = D.Multinomial(total_count, logits=logits)
    nll = -p.log_prob(observations)
    if reduction == "mean":
        return nll.mean()
    if reduction == "sum":
        return nll.sum()
    return nll
Ejemplo n.º 5
0
def loss(probs, values):
    """Summed negative log-likelihood of ``values`` under single-trial multinomials.

    Args:
        probs: Category probabilities; the last dimension indexes categories.
        values: Observations (cast to ``float32`` before scoring), e.g. one-hot rows.

    Returns:
        Scalar tensor: the sum over all observations of ``-log p(value)``.
    """
    single_trial = D.Multinomial(1, probs=probs)
    log_likelihood = single_trial.log_prob(values.float())
    return torch.sum(log_likelihood.neg())
Ejemplo n.º 6
0
def sample_multinomial(*args, **kwargs):
    """Build a ``D.Multinomial`` from the given arguments and sample it with ``sample_``.

    All positional and keyword arguments are forwarded unchanged to the
    ``D.Multinomial`` constructor. Whatever ``sample_`` returns is returned as-is.
    """
    distribution = D.Multinomial(*args, **kwargs)
    return sample_(distribution)