Ejemplo n.º 1
0
 def test_embed_mws_rand(self):
     """Check EmbeddingMWSRandMetric on random 6-channel embeddings.

     The metric is built with ``delta=2.0`` and the fixture's offsets and
     minimum segment size, then validated through ``self._test_metric`` with
     an upper bound of 1.0.
     """
     from torch_em.metric import EmbeddingMWSRandMetric
     # random stand-in for a (batch, 6, 128, 128) embedding prediction
     embed = torch.from_numpy(np.random.rand(self.batch_size, 6, 128, 128))
     gt = make_gt((128, 128), n_batches=self.batch_size, with_channels=True)
     metric = EmbeddingMWSRandMetric(delta=2.0,
                                     offsets=self.offsets,
                                     min_seg_size=self.min_size)
     self._test_metric(embed, gt, metric, upper_bound=1.0)
Ejemplo n.º 2
0
 def test_mws_sbd(self):
     """Validate MWSSBDMetric on random affinities and background-aware GT.

     Uses the fixture's offsets and minimum size; the score is checked by
     ``self._test_metric`` against an upper bound of 1.0.
     """
     from torch_em.metric import MWSSBDMetric
     spatial_shape = (128, 128)
     # five random channels standing in for affinity predictions
     affinities = torch.from_numpy(
         np.random.rand(self.batch_size, 5, *spatial_shape)
     )
     ground_truth = make_gt(spatial_shape, n_batches=self.batch_size,
                            with_channels=True, with_background=True)
     sbd_metric = MWSSBDMetric(self.offsets, self.min_size)
     self._test_metric(affinities, ground_truth, sbd_metric, upper_bound=1.0)
Ejemplo n.º 3
0
 def test_hdbscan_sbd(self):
     """Validate HDBScanSBDMetric on random 6-channel embeddings.

     The metric uses ``min_size=50`` and ``eps=1.0e-4``; ground truth includes
     background, and the score is checked against an upper bound of 1.0.
     """
     from torch_em.metric import HDBScanSBDMetric
     spatial_shape = (128, 128)
     embeddings = torch.from_numpy(
         np.random.rand(self.batch_size, 6, *spatial_shape)
     )
     ground_truth = make_gt(spatial_shape, n_batches=self.batch_size,
                            with_channels=True, with_background=True)
     sbd_metric = HDBScanSBDMetric(min_size=50, eps=1.0e-4)
     self._test_metric(embeddings, ground_truth, sbd_metric, upper_bound=1.0)
Ejemplo n.º 4
0
    def _test_spoco(self, aux_loss):
        """Run SPOCOLoss on random inputs and check that it backpropagates.

        Args:
            aux_loss: forwarded unchanged to ``SPOCOLoss(aux_loss=...)``.

        Asserts that the loss value is non-zero and that a gradient with the
        same shape as ``input1``, and not all close to zero, reaches
        ``input1``.
        """
        from torch_em.loss import SPOCOLoss
        loss = SPOCOLoss(delta_var=0.75, delta_dist=2.0, aux_loss=aux_loss)
        # input1 is a leaf tensor with requires_grad=True, so backward()
        # populates input1.grad directly; no retain_grad() call is needed.
        input1 = torch.from_numpy(np.random.rand(2, 8, 64, 64))
        input1.requires_grad = True
        input2 = torch.from_numpy(np.random.rand(2, 8, 64, 64))
        # shift labels down by one; the assertion below requires a minimum of 0
        target = make_gt(
            (64, 64), n_batches=2, with_channels=True, dtype="int64") - 1
        self.assertEqual(int(target.min()), 0)
        lval = loss((input1, input2), target)
        self.assertNotEqual(lval.item(), 0.0)

        lval.backward()
        grads = input1.grad
        # fail with a clear message instead of AttributeError on None.shape
        self.assertIsNotNone(grads)
        self.assertEqual(grads.shape, input1.shape)
        self.assertFalse(np.allclose(grads.numpy(), 0))
Ejemplo n.º 5
0
 def test_multicut_voi(self):
     """Validate MulticutVOIMetric on a random single-channel input.

     No upper bound is passed to ``self._test_metric`` here.
     """
     from torch_em.metric import MulticutVOIMetric
     spatial_shape = (128, 128)
     # presumably a stand-in for a boundary map (one channel) -- confirm
     boundaries = torch.from_numpy(
         np.random.rand(self.batch_size, 1, *spatial_shape)
     )
     ground_truth = make_gt(spatial_shape, n_batches=self.batch_size,
                            with_channels=True)
     voi_metric = MulticutVOIMetric(self.min_size)
     self._test_metric(boundaries, ground_truth, voi_metric)
Ejemplo n.º 6
0
 def test_mws_voi(self):
     """Validate MWSVOIMetric on random 4-channel affinities.

     Uses the fixture's offsets and minimum segment size; no upper bound is
     passed to ``self._test_metric``.
     """
     from torch_em.metric import MWSVOIMetric
     spatial_shape = (128, 128)
     # NOTE(review): channel count (4) presumably must match len(self.offsets)
     affinities = torch.from_numpy(
         np.random.rand(self.batch_size, 4, *spatial_shape)
     )
     ground_truth = make_gt(spatial_shape, n_batches=self.batch_size,
                            with_channels=True)
     voi_metric = MWSVOIMetric(self.offsets, min_seg_size=self.min_size)
     self._test_metric(affinities, ground_truth, voi_metric)