Example #1
    # requires module-level imports: os, tqdm.tqdm, an image writer such as
    # skimage.io.imsave or imageio.imwrite, and the chunkflow Chunk class
    def __call__(self, chunk: Chunk):
        assert isinstance(chunk, Chunk)
        assert chunk.ndim == 3
        # iterate over the z axis and save each section as a PNG image
        for z in tqdm(range(chunk.voxel_offset[0], chunk.bbox.maxpt[0])):
            img = chunk.cutout((slice(z, z + 1), chunk.slices[1], chunk.slices[2]))
            img = img.array[0, :, :]
            imsave(os.path.join(self.output_path, '{:05d}.png'.format(z)), img)
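These snippets omit the surrounding operator class. Below is a minimal usage sketch showing how such an operator might be assembled and invoked; the SaveImageOperator name, the skimage.io.imsave choice, and the chunkflow.chunk import path are assumptions, while the Chunk construction mirrors the calls that appear in Example #3.

import os
import numpy as np
from tqdm import tqdm
from skimage.io import imsave           # assumed source of `imsave`
from chunkflow.chunk import Chunk       # assumed import path for the Chunk class

class SaveImageOperator:
    """Hypothetical wrapper holding the output_path attribute used above."""
    def __init__(self, output_path: str):
        os.makedirs(output_path, exist_ok=True)
        self.output_path = output_path

    def __call__(self, chunk: Chunk):
        # same logic as Example #1: write one PNG per z section
        for z in tqdm(range(chunk.voxel_offset[0], chunk.bbox.maxpt[0])):
            img = chunk.cutout((slice(z, z + 1), chunk.slices[1], chunk.slices[2]))
            imsave(os.path.join(self.output_path, f'{z:05d}.png'), img.array[0])

# a small (z, y, x) chunk with a voxel offset, saved one section at a time
array = np.random.randint(0, 255, size=(4, 64, 64), dtype=np.uint8)
chunk = Chunk(array, voxel_offset=(100, 0, 0))
SaveImageOperator('/tmp/sections')(chunk)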
Example #2
    # requires module-level imports: os, numpy as np, tqdm.tqdm, pyspng,
    # and the chunkflow Chunk class
    def __call__(self, chunk: Chunk):
        assert isinstance(chunk, Chunk)
        if chunk.is_affinity_map:
            # reduce the multi-channel affinity map to a 3D grayscale chunk:
            # average two affinity channels, rescale to 0-255, cast to uint8,
            # then rewrap as a Chunk and restore the original properties
            properties = chunk.properties
            chunk = (chunk[1, ...] + chunk[2, ...]) / 2. * 255.
            chunk = chunk.astype(np.uint8)
            chunk = Chunk(chunk)
            chunk.set_properties(properties)

        assert chunk.ndim == 3
        # iterate over the z axis and encode each section as a PNG with pyspng
        for z in tqdm(range(chunk.voxel_offset[0], chunk.bbox.maxpt[0])):
            img = chunk.cutout(
                (slice(z, z + 1), chunk.slices[1], chunk.slices[2]))
            img = img.array[0, :, :]
            filename = os.path.join(self.output_path, f"{z:05d}.png")
            with open(filename, "wb") as f:
                f.write(pyspng.encode(img))
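Example #2 first collapses the affinity map into a single grayscale volume before writing PNGs. A standalone numpy sketch of that conversion, assuming a (channel, z, y, x) layout with values in [0, 1]:

import numpy as np

# hypothetical affinity map with layout (channel, z, y, x), values in [0, 1]
affinity = np.random.rand(3, 4, 64, 64).astype(np.float32)

# average two affinity channels and rescale to the uint8 range,
# mirroring the conversion in Example #2
gray = ((affinity[1] + affinity[2]) / 2. * 255.).astype(np.uint8)
assert gray.shape == (4, 64, 64)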
Example #3
    # requires module-level imports: logging, time, numpy as np, tqdm.tqdm,
    # and the chunkflow Chunk class
    def __call__(self, input_chunk: Chunk):
        """
        Args:
            input_chunk (Chunk): input chunk with voxel offset and voxel size
        """
        assert isinstance(input_chunk, Chunk)

        self._update_parameters_for_input_chunk(input_chunk)
        output_buffer = self._get_output_buffer(input_chunk)
        if not self.mask_output_chunk:
            self._check_alignment()

        if self.dry_run:
            logging.info('dry run, return a special artificial chunk.')
            size = output_buffer.shape

            if self.mask_myelin_threshold:
                # eliminate the myelin channel
                size = (size[0] - 1, *size[1:])

            return Chunk.create(
                size=size,
                dtype=output_buffer.dtype,
                voxel_offset=output_buffer.voxel_offset,
                voxel_size=input_chunk.voxel_size,
            )

        if np.all(input_chunk == 0):
            logging.info('input is all zero, return zero buffer directly')
            if self.mask_myelin_threshold:
                assert output_buffer.shape[0] == 4
                return output_buffer[:-1, ...]
            else:
                return output_buffer

        if np.issubdtype(input_chunk.dtype, np.integer):
            # normalize to 0-1 value range
            dtype_max = np.iinfo(input_chunk.dtype).max
            input_chunk = input_chunk.astype(self.dtype)
            input_chunk /= dtype_max

        chunk_time_start = time.time()

        # iterate over the patch slices in batches
        for i in tqdm(range(0, len(self.patch_slices_list), self.batch_size),
                      disable=(self.verbose <= 0),
                      desc='ConvNet inference for patches: '):
            start = time.time()

            batch_slices = self.patch_slices_list[i:i + self.batch_size]
            # copy each input patch from the chunk into the preallocated batch buffer
            for batch_idx, slices in enumerate(batch_slices):
                self.input_patch_buffer[batch_idx, 0, :, :, :] = \
                    input_chunk.cutout(slices[0]).array

            end = time.time()
            logging.debug(
                f'preparing {self.batch_size:d} input patches took {end-start:.3f} sec'
            )
            start = end

            # the input and output patch is a 5d numpy array with
            # datatype of float32, the dimensions are batch/channel/z/y/x.
            # the input image should be normalized to [0,1]
            if not self.test_time_augmentation:
                output_patch = self.patch_inferencer(self.input_patch_buffer)
            else:
                # the test-time-augmentation branch is not included in this
                # snippet; raise instead of leaving `output_patch` unbound
                raise NotImplementedError(
                    'test time augmentation is not included in this example')

            end = time.time()
            logging.debug(
                f'running inference for {self.batch_size:d} patches took {end-start:.3f} sec'
            )
            start = end

            for batch_idx, slices in enumerate(batch_slices):
                # slices[0] is the input patch slice, slices[1] the output patch slice;
                # only the required number of channels is kept, the rest are dropped
                offset = tuple(s.start for s in slices[1])
                output_patch_chunk = Chunk(output_patch[batch_idx, :, :, :, :],
                                           voxel_offset=offset,
                                           voxel_size=input_chunk.voxel_size)

                ## save some patch for debug
                #bbox = output_chunk.bbox
                #if bbox.minpt[-1] < 94066 and bbox.maxpt[-1] > 94066 and \
                #        bbox.minpt[-2]<81545 and bbox.maxpt[-2]>81545 and \
                #        bbox.minpt[-3]<17298 and bbox.maxpt[-3]>17298:
                #    print('save patch: ', output_chunk.bbox)
                #    output_chunk.to_tif()
                #    #input_chunk.cutout(slices[0]).to_tif()

                output_buffer.blend(output_patch_chunk)

            end = time.time()
            logging.debug(f'blending patches took {end - start:.3f} sec')
            logging.debug(
                f'inference of the whole chunk has taken {time.time() - chunk_time_start:.3f} sec so far')

        if self.mask_output_chunk:
            output_buffer *= self.output_chunk_mask

        # theoretically, no value in output_buffer should be greater than 1;
        # we use a slightly larger bound to accommodate numerical precision issues
        np.testing.assert_array_less(
            output_buffer,
            1.0001,
            err_msg='output buffer should not be greater than 1')

        if self.mask_myelin_threshold:
            # currently only used to mask out the affinity map
            assert output_buffer.shape[0] == 4
            output_chunk = output_buffer.mask_using_last_channel(
                threshold=self.mask_myelin_threshold)

            # '<f4' is little-endian float32; normalize the dtype name
            if output_chunk.dtype == np.dtype('<f4'):
                output_chunk = output_chunk.astype('float32')

            return output_chunk
        else:
            return output_buffer
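Example #3 rescales integer inputs to the [0, 1] float range expected by the ConvNet before inference. A minimal standalone sketch of that normalization step, using an illustrative uint8 volume:

import numpy as np

# rescale an integer image to [0, 1] floats, as done before running the ConvNet
img = np.random.randint(0, 256, size=(4, 64, 64), dtype=np.uint8)
dtype_max = np.iinfo(img.dtype).max      # 255 for uint8, 65535 for uint16
img = img.astype(np.float32) / dtype_max
assert 0.0 <= img.min() and img.max() <= 1.0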