Example 1
    def __call__(self, input_chunk: np.ndarray):
        """
        args:
            input_chunk (Chunk): input chunk with global offset
        """
        assert isinstance(input_chunk, Chunk)

        self._update_parameters_for_input_chunk(input_chunk)
        output_buffer = self._get_output_buffer(input_chunk)

        if not self.mask_output_chunk:
            self._check_alignment()

        if self.dry_run:
            print('dry run, return a special artificial chunk.')
            size = output_buffer.shape

            if self.mask_myelin_threshold:
                # eliminate the myelin channel
                size = (size[0] - 1, *size[1:])

            return Chunk.create(size=size,
                                dtype=output_buffer.dtype,
                                voxel_offset=output_buffer.global_offset)

        if np.all(input_chunk == 0):
            print('input is all zero, return zero buffer directly')
            if self.mask_myelin_threshold:
                assert output_buffer.shape[0] == 4
                return output_buffer[:-1, ...]
            else:
                return output_buffer

        if np.issubdtype(input_chunk.dtype, np.integer):
            # normalize to 0-1 value range
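            # e.g. for uint8 input, dtype_max is 255, so values map into [0, 1]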
            dtype_max = np.iinfo(input_chunk.dtype).max
            input_chunk = input_chunk.astype(self.dtype) / dtype_max

        if self.verbose:
            chunk_time_start = time.time()

        # iterate the offset list
        for i in tqdm(range(0, len(self.patch_slices_list), self.batch_size),
                      disable=not self.verbose,
                      desc='ConvNet inference for patches: '):
            if self.verbose:
                start = time.time()

            batch_slices = self.patch_slices_list[i:i + self.batch_size]
            for batch_idx, slices in enumerate(batch_slices):
                self.input_patch_buffer[batch_idx,
                                        0, :, :, :] = input_chunk.cutout(
                                            slices[0]).array

            if self.verbose > 1:
                end = time.time()
                print('preparing %d input patches takes %.3f sec' %
                      (self.batch_size, end - start))
                start = end

            # the input and output patches are 5D numpy arrays of float32;
            # the dimensions are (batch, channel, z, y, x).
            # the input image should be normalized to [0, 1]
            output_patch = self.patch_inferencer(self.input_patch_buffer)

            if self.verbose > 1:
                assert output_patch.ndim == 5
                end = time.time()
                print('running inference for %d patches takes %.3f sec' %
                      (self.batch_size, end - start))
                start = end

            for batch_idx, slices in enumerate(batch_slices):
                # only use the required number of channels
                # the remaining channels are dropped
                # the slices[0] is for input patch slice
                # the slices[1] is for output patch slice
                offset = (0, ) + tuple(s.start for s in slices[1])
                output_chunk = Chunk(output_patch[batch_idx, :, :, :, :],
                                     global_offset=offset)

                ## save some patches for debugging
                #bbox = output_chunk.bbox
                #if bbox.minpt[-1] < 94066 and bbox.maxpt[-1] > 94066 and \
                #        bbox.minpt[-2]<81545 and bbox.maxpt[-2]>81545 and \
                #        bbox.minpt[-3]<17298 and bbox.maxpt[-3]>17298:
                #    print('save patch: ', output_chunk.bbox)
                #    breakpoint()
                #    output_chunk.to_tif()
                #    #input_chunk.cutout(slices[0]).to_tif()

                output_buffer.blend(output_chunk)

            if self.verbose > 1:
                end = time.time()
                print('blending patches takes %.3f sec' % (end - start))

        if self.verbose:
            print("Inference of whole chunk takes %3f sec" %
                  (time.time() - chunk_time_start))

        if self.mask_output_chunk:
            output_buffer *= self.output_chunk_mask

        # theoretically, no value in output_buffer should be greater than 1;
        # we test against a slightly higher bound to accommodate numerical
        # precision issues
        np.testing.assert_array_less(
            output_buffer,
            1.0001,
            err_msg='output buffer should not be greater than 1')

        if self.mask_myelin_threshold:
            # currently only for masking out affinity map
            assert output_buffer.shape[0] == 4
            output_chunk = output_buffer.mask_using_last_channel(
                threshold=self.mask_myelin_threshold)

            # currently neuroglancer only support float32, not float16
            if output_chunk.dtype == np.dtype('float16'):
                output_chunk = output_chunk.astype('float32')

            return output_chunk
        else:
            return output_buffer
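The heavy lifting in this loop is output_buffer.blend(output_chunk): overlapping output patches are accumulated into the buffer, typically after each patch has been multiplied by a weight mask so that contributions in overlapping regions sum to one. Below is a minimal 1D sketch of that blending idea in plain NumPy, with a hypothetical triangular weight; it illustrates the technique only and is not chunkflow's actual blend implementation.

    import numpy as np

    # minimal sketch of weighted patch blending (illustration only):
    # patch size 8, stride 4, i.e. 50% overlap between neighboring patches
    patch_size, stride, length = 8, 4, 20
    predictions = np.random.rand(length).astype(np.float32)  # stand-in model output

    # hypothetical triangular weight so patches fade in and out at their edges
    weight = np.minimum(np.arange(1, patch_size + 1),
                        np.arange(patch_size, 0, -1)).astype(np.float32)

    output = np.zeros(length, dtype=np.float32)
    norm = np.zeros(length, dtype=np.float32)
    for start in range(0, length - patch_size + 1, stride):
        patch = predictions[start:start + patch_size]
        output[start:start + patch_size] += patch * weight
        norm[start:start + patch_size] += weight

    # normalize by the accumulated weight wherever patches overlapped
    output /= np.maximum(norm, 1e-6)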
Example 2
    def predict_chunk(self, input_chunk: np.ndarray):
        """
        args:
            input_chunk (Chunk): input chunk with global offset
        """
        assert isinstance(input_chunk, Chunk)

        self._update_parameters_for_input_chunk(input_chunk)
        output_buffer = self._get_output_buffer(input_chunk)

        if not self.mask_output_chunk:
            self._check_alignment()

        if self.dry_run:
            print('dry run, return a special artificial chunk.')
            size = output_buffer.shape

            if self.mask_myelin_threshold:
                # eliminate myelin channel
                size = (size[0] - 1, *size[1:])

            return Chunk.create(size=size,
                                dtype=output_buffer.dtype,
                                voxel_offset=output_buffer.global_offset)

        if np.all(input_chunk == 0):
            print('input is all zero, return zero buffer directly')
            if self.mask_myelin_threshold:
                assert output_buffer.shape[0] == 4
                return output_buffer[:-1, ...]
            else:
                return output_buffer

        if np.issubdtype(input_chunk.dtype, np.integer):
            # normalize to 0-1 value range
            dtype_max = np.iinfo(input_chunk.dtype).max
            input_chunk = input_chunk.astype(self.dtype) / dtype_max

        if self.verbose:
            chunk_time_start = time.time()

        # set model to evaluation mode
        self.model.eval()

        # send model to device
        self.model.cuda()

        with torch.no_grad():
            for i in range(0, len(self.patch_slices_list), self.batch_size):
                if self.verbose:
                    start = time.time()

                batch_slices = self.patch_slices_list[i:i + self.batch_size]
                for batch_idx, slices in enumerate(batch_slices):
                    self.input_patch_buffer[batch_idx,
                                            0, :, :, :] = input_chunk.cutout(
                                                slices[0]).array

                if self.verbose > 1:
                    end = time.time()
                    print('preparing %d input patches takes %.3f sec' %
                          (self.batch_size, end - start))
                    start = end

                # the input and output patches are 5D numpy arrays of float32;
                # the dimensions are (batch, channel, z, y, x).
                # the input image should be normalized to [0, 1]
                patch = torch.from_numpy(
                    self.input_patch_buffer).float().cuda()

                output_patch = self.model(patch)

                assert output_patch.ndim == 5

                net_out = output_patch.cpu().numpy()
                #net_out_mask = np.where(net_out >= 0.9, 1, 0)
                #print(net_out.shape)
                for batch_idx, slices in enumerate(batch_slices):
                    # slices[0] is for input patch slice
                    # slices[1] is for output patch slice
                    offset = (0, ) + tuple(s.start for s in slices[1])
                    #print(offset)
                    output_chunk = Chunk(net_out[batch_idx, 1:, :, :, :],
                                         global_offset=offset)
                    output_buffer.blend(output_chunk)

        return output_buffer
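Example 2 inlines the PyTorch boilerplate that Example 1 delegates to self.patch_inferencer: the model is switched to evaluation mode (disabling dropout and freezing batch-norm statistics), moved to the GPU, and run under torch.no_grad() so no autograd graph is built during inference. Here is a minimal self-contained sketch of the same pattern, with a hypothetical Conv3d standing in for self.model:

    import numpy as np
    import torch

    # hypothetical 3D model standing in for self.model
    model = torch.nn.Conv3d(in_channels=1, out_channels=3,
                            kernel_size=3, padding=1)
    model.eval()  # evaluation mode: no dropout, no batch-norm updates
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)

    # a (batch, channel, z, y, x) float32 array, like input_patch_buffer
    batch = np.random.rand(2, 1, 16, 64, 64).astype(np.float32)

    with torch.no_grad():  # skip autograd bookkeeping during inference
        patch = torch.from_numpy(batch).to(device)
        output_patch = model(patch)  # still (batch, channel, z, y, x)
    net_out = output_patch.cpu().numpy()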