def run_benches(self, device):
        """Benchmark Polus ``masks_to_flows`` against Cellpose on ``self.labels``.

        This is just a hack for running benchmarks: the two closing assertions
        contradict each other, so one of them always fails and its message
        reports which implementation was slower.

        Cellpose is only timed when ``self.labels`` is 2-d; otherwise its time
        is recorded as infinity.

        Args:
            device: ``None`` to run on the cpu; otherwise a value formatted
                into ``cuda:{device}`` for Cellpose and handed unchanged to
                the Polus implementation.
        """
        num_samples = 10
        torch_device = None if device is None else torch.device(
            f'cuda:{device}')

        def mean_seconds(run):
            # perf_counter is a monotonic, high-resolution clock; the wall
            # clock returned by time.time() can be adjusted mid-measurement.
            start = time.perf_counter()
            for _ in range(num_samples):
                run()
            return (time.perf_counter() - start) / num_samples

        if self.labels.ndim == 2:
            cellpose_time = mean_seconds(
                lambda: cellpose.dynamics.masks_to_flows(
                    self.labels,
                    use_gpu=(device is not None),
                    device=torch_device))
        else:
            # cellpose often bugs out on 3d images
            cellpose_time = float('inf')

        polus_time = mean_seconds(
            lambda: dynamics.masks_to_flows(self.labels, device=device))

        # Deliberately contradictory pair: whichever assertion fails names the
        # slower implementation in the test output.
        self.assertLess(polus_time, cellpose_time,
                        'Polus slower than Cellpose :(')
        self.assertLess(cellpose_time, polus_time,
                        'Cellpose slower than Polus :)')
    def test_polus_errors(self):
        """Check the shape of the Polus flows computed on the cpu and on the gpu.

        For both ``device=None`` and ``device=0`` the flows returned by
        ``dynamics.masks_to_flows`` must have shape
        ``(labels.ndim, *labels.shape)``.
        """
        expected_shape = (self.labels.ndim, *self.labels.shape)
        cases = (
            (None, 'cpu shapes were different'),
            (0, 'gpu shapes were different'),
        )
        for device, message in cases:
            flows = dynamics.masks_to_flows(self.labels, device=device)
            self.assertEqual(expected_shape, flows.shape, message)
        return
    def image_test(self, image, device):
        """Compare the Polus flows for ``image`` against the Cellpose flows.

        The two flow arrays must have the same shape and a mean squared
        difference below 0.05.

        Args:
            image: label image handed to both implementations.
            device: ``None`` for the cpu; otherwise a value formatted into
                ``cuda:{device}`` for Cellpose and passed unchanged to Polus.
        """
        on_gpu = device is not None
        torch_device = torch.device(f'cuda:{device}') if on_gpu else None

        # Cellpose returns two values; only the first (the flows) is compared.
        reference, _ = cellpose.dynamics.masks_to_flows(
            image, use_gpu=on_gpu, device=torch_device)
        if image.ndim == 3:
            # 3d cellpose flows need to be normalized to unit-norm.
            # The tiny epsilon guards the division where the norm is zero.
            magnitudes = numpy.linalg.norm(reference, axis=0) + 1e-20
            foreground = image != 0
            reference = (reference / magnitudes) * foreground

        candidate = dynamics.masks_to_flows(image, device=device)

        self.assertEqual(reference.shape, candidate.shape,
                         'flows had different shapes')

        difference = reference - candidate
        error = numpy.mean(numpy.square(difference))
        self.assertLess(error, 0.05, f'error was too large {error:.3f}')
        return
    def transform_labels(self, label_image):
        """Build the training target for a single-channel label image.

        Parameters
        --------------
        label_image: array [width, height, channel]
            a label image; must be 3-dimensional with exactly one channel

        Returns
        ------------------
        array [width, height, channel]
            float32 array whose channels are, in order: the labels, the
            ``label > 0.5`` cell probability, and the channels of the first
            item returned by ``dynamics.masks_to_flows``
        """
        assert label_image.ndim == 3
        assert label_image.shape[2] == 1

        plane = label_image[:, :, 0]
        flow_channels = dynamics.masks_to_flows(plane)[0]

        # Channel-first while assembling: labels, cell probability, flows.
        label_and_probability = np.stack([plane, plane > 0.5], axis=0)
        target = np.concatenate((label_and_probability, flow_channels),
                                axis=0)

        # Move the channel axis to the end to match the input layout.
        return target.astype(np.float32).transpose(1, 2, 0)
class VectorToLabelTest(unittest.TestCase):
    """Checks that masks can be recovered from flows, with Cellpose as reference.

    Every fixture below is a class attribute computed once, when the class
    body executes: one label image is read from disk and relabelled, its flows
    are computed, and a small two-rectangle toy image gets the same treatment.
    """
    # Test data lives under <data_path>/input, or <data_path>/input/images
    # when that sub-directory exists.
    data_path = Path('/data/axle/tests')
    input_path = data_path.joinpath('input')
    if input_path.joinpath('images').is_dir():
        input_path = input_path.joinpath('images')

    # Take one file from the first group yielded by the pattern.
    fp = filepattern.FilePattern(input_path, '.+')
    infile = next(Path(files.pop()['file']).resolve() for files in fp)
    with BioReader(infile) as reader:
        labels = numpy.squeeze(reader[:, :, :, 0, 0])
    # Replace each label by its index among the sorted unique values, so the
    # labels become consecutive integers starting at 0.
    labels = numpy.reshape(
        numpy.unique(labels, return_inverse=True)[1], labels.shape)
    if labels.ndim == 3:
        # Move the last axis to the front (presumably z-first -- TODO confirm).
        labels = numpy.transpose(labels, (2, 0, 1))
    device = 0 if torch.cuda.is_available() else None
    flows = dynamics.masks_to_flows(labels, device=device)

    # Two 13x5 rectangles, labelled 1 and 2, inside a 17x17 background.
    toy_labels = numpy.zeros((17, 17), dtype=numpy.uint8)
    toy_labels[2:15, 2:7] = 1
    toy_labels[2:15, 10:15] = 2
    toy_flows = dynamics.masks_to_flows(toy_labels, device=device)

    @unittest.skip(
        'benchmark only: one of the two closing assertions always fails')
    def test_benches(self):
        """Time follow_flows + get_masks for Cellpose and for Polus.

        This is a benchmark rather than a test: the two closing assertions
        contradict each other, so one always fails and its message reports
        which implementation was slower.
        """
        masks = self.labels > 0
        flows = -self.flows * masks
        num_samples = 10

        def run_cellpose():
            locations = cellpose.dynamics.follow_flows(flows,
                                                       niter=200,
                                                       interp=False,
                                                       use_gpu=True)
            cellpose.dynamics.get_masks(locations,
                                        iscell=masks,
                                        flows=self.flows,
                                        use_gpu=True)

        def run_polus():
            locations = dynamics.follow_flows(flows,
                                              num_iterations=200,
                                              interpolate=False,
                                              device=self.device)
            dynamics.get_masks(locations,
                               is_cell=masks,
                               flows=self.flows,
                               device=self.device)

        def mean_seconds(run):
            # perf_counter is a monotonic, high-resolution clock; the wall
            # clock returned by time.time() can be adjusted mid-measurement.
            start = time.perf_counter()
            for _ in range(num_samples):
                run()
            return round((time.perf_counter() - start) / num_samples, 12)

        # Let's warm up...
        run_cellpose()
        run_polus()

        cellpose_time = mean_seconds(run_cellpose)
        polus_time = mean_seconds(run_polus)

        self.assertLess(polus_time, cellpose_time,
                        'Polus slower than Cellpose :(')
        self.assertLess(cellpose_time, polus_time,
                        'Cellpose slower than Polus :)')

    def recover_masks_test(self, flows, labels):
        """Follow ``flows`` and check that the Polus masks recover ``labels``.

        Asserts that the Cellpose and Polus convergent locations have the
        same shape and a clamped mean squared difference below 0.1, and that
        the Polus foreground disagrees with ``labels`` on fewer than 5% of
        the pixels.

        Args:
            flows: flows computed from ``labels``.
            labels: label image; zero is treated as background.
        """
        cellpose_locations = cellpose.dynamics.follow_flows(
            -flows * (labels != 0),
            niter=200,
            interp=False,
            use_gpu=True,
        )

        # NOTE(review): follow_flows runs with device=None here while
        # get_masks below uses self.device -- confirm this is intentional.
        polus_locations = dynamics.follow_flows(
            -flows * (labels != 0),
            num_iterations=200,
            interpolate=False,
            device=None,
        )
        polus_masks = dynamics.get_masks(
            polus_locations,
            is_cell=(labels != 0),
            flows=flows,
            device=self.device,
        )
        polus_masks = dynamics.fill_holes_and_remove_small_masks(polus_masks)

        self.assertEqual(
            cellpose_locations.shape,
            polus_locations.shape,
            'locations had different shapes',
        )

        # Some cellpose-locations contain horizontal artifacts but polus-locations do not.
        # Thus we clamp the error here. If there were a programmatic way to detect those artifacts,
        # we could deal with the error in a different way and present a fairer test.
        # As things stand, I believe my implementation to be correct.
        locations_diff = numpy.clip(
            numpy.abs(cellpose_locations - polus_locations), 0, 1)
        self.assertLess(numpy.mean(locations_diff**2), 0.1,
                        'error in convergent locations was too large...')

        # Only foreground/background agreement is compared, not label ids.
        masks_diff = (polus_masks == 0) != (labels == 0)
        self.assertLess(numpy.mean(masks_diff), 0.05,
                        'error in polus masks was too large...')

    def test_masks(self):
        """Run the mask-recovery check on the toy image and on the real image."""
        self.recover_masks_test(self.toy_flows, self.toy_labels)
        self.recover_masks_test(self.flows, self.labels)
def flow_thread(input_path: Path, zfile: Path, use_gpu: bool,
                dev: torch.device, x: int, y: int, z: int) -> bool:
    """ Converts labels to flows

    This function converts the labels in one tile to a vector field and
    writes the result into the zarr file.

    The tile is read with up to TILE_OVERLAP extra pixels on each side
    (clipped to the image bounds) and the flows are computed on that padded
    tile; only the un-padded TILE_SIZE region is written. For the written
    region, channel 0 holds ``labels > 0``, channels 1-2 hold the flows and
    channel 3 holds the original labels as float32.

    Args:
        input_path(path): Path of the input label image opened with BioReader
        zfile(path): Path of the zarr file the output is written to
        use_gpu(bool): Forwarded to dynamics.masks_to_flows
        dev(torch.device): Forwarded to dynamics.masks_to_flows
        x(int): Start index of the tile in x dimension of image
        y(int): Start index of the tile in y dimension of image
        z(int): Z slice of the  image

    Returns:
        bool: Always True.

    """

    logging.basicConfig(
        format='%(asctime)s - %(name)-8s - %(levelname)-8s - %(message)s',
        datefmt='%d-%b-%y %H:%M:%S')
    logger = logging.getLogger("flow")
    logger.setLevel(logging.INFO)

    root = zarr.open(str(zfile))[0]

    with BioReader(input_path) as br:
        # Padded tile: the region that is read and fed to the flow computation.
        x_pad_min = max(0, x - TILE_OVERLAP)
        x_pad_max = min(br.X, x + TILE_SIZE + TILE_OVERLAP)
        y_pad_min = max(0, y - TILE_OVERLAP)
        y_pad_max = min(br.Y, y + TILE_SIZE + TILE_OVERLAP)
        padded_shape = (y_pad_max - y_pad_min, x_pad_max - x_pad_min)

        # Reshape rather than squeeze: squeeze would also drop a y or x axis
        # of length 1 and leave a 1-d array that cannot be indexed below.
        tile_labels = np.reshape(
            br[y_pad_min:y_pad_max, x_pad_min:x_pad_max, z:z + 1, 0, 0],
            padded_shape)

        # Normalize: replace every label by its index among the unique values.
        _, image = np.unique(tile_labels, return_inverse=True)
        image = image.reshape(padded_shape)

        flow = dynamics.masks_to_flows(image, use_gpu, dev)[0]

        # With the logger level set to INFO above, this message is suppressed.
        logger.debug('Computed flows on slice %d tile(y,x) %d:%d %d:%d ', z, y,
                     y_pad_max, x, x_pad_max)

        # Core tile: the region this call writes, and where it sits inside
        # the padded tile.
        x_end = min(br.X, x + TILE_SIZE)
        y_end = min(br.Y, y + TILE_SIZE)
        y_overlap = y - y_pad_min
        x_overlap = x - x_pad_min
        rows = slice(y_overlap, y_overlap + (y_end - y))
        cols = slice(x_overlap, x_overlap + (x_end - x))

        core_labels = tile_labels[rows, cols]
        # Three leading length-1 axes turn a (y, x) array into (1, 1, 1, y, x).
        lead = (np.newaxis, np.newaxis, np.newaxis)

        root[0:1, 0:1, z:z + 1, y:y_end, x:x_end] = (core_labels > 0)[lead]
        # flow is (channel, y, x); insert the leading and z axes around it.
        root[0:1, 1:3, z:z + 1, y:y_end,
             x:x_end] = flow[:, rows, cols][np.newaxis, :, np.newaxis]
        root[0:1, 3:4, z:z + 1, y:y_end,
             x:x_end] = core_labels.astype(np.float32)[lead]

    return True