Example No. 1
    def setUp(self):
        super(ErrorBase, self).setUp()

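        # ReadTableParameters positional args (named with keywords in Example
        # No. 3 below): shard_paths, is_recordio, bp, is_deterministic,
        # pad_size=0, crop_size=768.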
        rtp = data_provider.ReadTableParameters([self.recordio_path()], True,
                                                util.BatchParameters(2, 1, 2),
                                                True, 0, 768)
        dp = data_provider.DataParameters(rtp, self.input_z_values,
                                          self.input_channel_values,
                                          self.target_z_values,
                                          self.target_channel_values)
        _, self.target_lt = data_provider.cropped_input_and_target(dp)

        self.target_lt_0 = lt.select(self.target_lt, {
            'channel': 'DAPI_CONFOCAL',
            'mask': False
        })
        self.target_lt_1 = lt.select(self.target_lt, {
            'channel': 'NEURITE_CONFOCAL',
            'mask': False
        })

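        # Fold every axis after the first three (presumably batch, row, column)
        # back into a single 'channel' axis.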
        self.target_lt_0 = lt.reshape(self.target_lt_0,
                                      list(self.target_lt_0.axes.keys())[3:],
                                      ['channel'])
        self.target_lt_1 = lt.reshape(self.target_lt_1,
                                      list(self.target_lt_1.axes.keys())[3:],
                                      ['channel'])
Example No. 2
    def setUp(self):
        super(Base, self).setUp()

        self.input_z_values = [round(v, 4) for v in np.linspace(0.0, 1.0, 13)]
        self.input_channel_values = ['BRIGHTFIELD', 'PHASE_CONTRAST', 'DIC']
        self.target_z_values = ['MAXPROJECT']
        self.target_channel_values = [
            'DAPI_CONFOCAL',
            'DAPI_WIDEFIELD',
            'CELLMASK_CONFOCAL',
            'TUJ1_WIDEFIELD',
            'NFH_CONFOCAL',
            'MAP2_CONFOCAL',
            'ISLET_WIDEFIELD',
            'DEAD_CONFOCAL',
        ]

        rtp = data_provider.ReadTableParameters([self.recordio_path()], True,
                                                util.BatchParameters(2, 1, 2),
                                                True, 0, 256)
        dp = data_provider.DataParameters(rtp, self.input_z_values,
                                          self.input_channel_values,
                                          self.target_z_values,
                                          self.target_channel_values)
        # pylint: disable=line-too-long
        self.input_lt, self.target_lt = data_provider.cropped_input_and_target(
            dp)
        # pylint: enable=line-too-long

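        # Same table, but read with pad_size=51 and crop_size=110 (the last two
        # positional ReadTableParameters arguments) for the "minception" data.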
        minception_rtp = data_provider.ReadTableParameters(
            [self.recordio_path()], True, util.BatchParameters(2, 1, 2), True,
            51, 110)
        minception_dp = data_provider.DataParameters(
            minception_rtp, self.input_z_values, self.input_channel_values,
            self.target_z_values, self.target_channel_values)
        # pylint: disable=line-too-long
        self.minception_input_lt, self.minception_target_lt = data_provider.cropped_input_and_target(
            minception_dp)
Example No. 3
    def setUp(self):
        super(CroppedInputAndTargetTest, self).setUp()

        batch_size = 2
        rtp = data_provider.ReadTableParameters(
            shard_paths=[self.recordio_path()],
            is_recordio=True,
            bp=util.BatchParameters(batch_size, num_threads=1, capacity=2),
            is_deterministic=True,
            pad_size=0,
            crop_size=512)

        dp = data_provider.DataParameters(
            rtp,
            input_z_values=self.input_z_values,
            input_channel_values=self.input_channel_values,
            target_z_values=self.target_z_values,
            target_channel_values=self.target_channel_values)

        self.input_lt, self.target_lt = data_provider.cropped_input_and_target(
            dp)
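As a hypothetical companion to this setUp (not part of the original snippet), a test method could sanity-check the crop by inspecting the LabeledTensor axes; the method name and the axis names 'batch', 'row', and 'column' are assumptions here.

    def test_cropped_shapes(self):
        # Sketch only: assumes axes named 'batch', 'row', and 'column'. The
        # batch_size=2 and crop_size=512 configured in setUp should be
        # reflected in the corresponding axis lengths of both tensors.
        for tensor in [self.input_lt, self.target_lt]:
            self.assertEqual(len(tensor.axes['batch']), 2)
            self.assertEqual(len(tensor.axes['row']), 512)
            self.assertEqual(len(tensor.axes['column']), 512)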
Example No. 4
    def setUp(self):
        super(ErrorPanelTest, self).setUp()

        rtp = data_provider.ReadTableParameters([self.recordio_path()], True,
                                                util.BatchParameters(2, 1, 2),
                                                True, 0, 768)
        dp = data_provider.DataParameters(rtp, self.input_z_values,
                                          self.input_channel_values,
                                          self.target_z_values,
                                          self.target_channel_values)
        _, batch_target_lt = data_provider.cropped_input_and_target(dp)

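        # The "prediction" is a stand-in built from batch item 0 of the target
        # and one-hot discretized by util.onehot (presumably into 16 bins); the
        # target below is batch item 1, so the panel compares distinct images.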
        self.prediction_lt = lt.slice(
            lt.select(batch_target_lt, {'mask': False}),
            {'batch': slice(0, 1)})
        self.prediction_lt = util.onehot(16, self.prediction_lt)

        self.target_lt = lt.slice(batch_target_lt, {'batch': slice(1, 2)})

        self.error_panel_lt = visualize.error_panel(self.target_lt,
                                                    self.prediction_lt)
Example No. 5
def provide_preprocessed_data(
    dp: data_provider.DataParameters,
    ap: Optional[augment.AugmentParameters],
    extract_patch_size: int,
    stride: int,
    name: Optional[str] = None,
) -> Tuple[np.ndarray, lt.LabeledTensor, lt.LabeledTensor]:
    """Provide preprocessed input and target patches.

    Args:
      dp: DataParameters.
      ap: Optional AugmentParameters.
      extract_patch_size: Patch size for patch extraction.
      stride: Stride for patch extraction.
      name: Optional op name.

    Returns:
      An array containing patch center locations.
      The patches are extracted with padding, so that the stitched model outputs
      will form an image the same size as the input.

      A tensor with model inputs, possibly jittered and corrupted for data
      enrichment.

      A tensor with model targets.
    """
    with tf.name_scope(name, 'provide_preprocessed_data', []) as scope:
        input_lt, target_lt = data_provider.cropped_input_and_target(dp)
        visualize.summarize_image(
            visualize.canonical_image(input_lt, name=scope + 'input'))
        visualize.summarize_image(
            visualize.canonical_image(target_lt, name=scope + 'target'))

        if ap is not None:
            input_lt, target_lt = augment.augment(ap, input_lt, target_lt)
            visualize.summarize_image(
                visualize.canonical_image(input_lt,
                                          name=scope + 'input_jitter'))
            visualize.summarize_image(
                visualize.canonical_image(target_lt,
                                          name=scope + 'target_jitter'))

        rc = lt.ReshapeCoder(['z', 'channel', 'mask'], ['channel'])

        patch_centers, input_lt = ops.extract_patches_single_scale(
            extract_patch_size, stride, rc.encode(input_lt))
        input_lt = rc.decode(input_lt)
        input_lt = lt.reshape(input_lt, ['batch', 'patch_row', 'patch_column'],
                              ['batch'],
                              name=scope + 'input')

        rc = lt.ReshapeCoder(['z', 'channel', 'mask'], ['channel'])
        target_lt = rc.decode(
            ops.extract_patches_single_scale(extract_patch_size, stride,
                                             rc.encode(target_lt))[1])
        target_lt = lt.reshape(target_lt,
                               ['batch', 'patch_row', 'patch_column'],
                               ['batch'],
                               name=scope + 'target')

        return patch_centers, input_lt, target_lt
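A rough usage sketch follows (not from the original source); dp and ap are assumed to have been built as in the setUp examples above, and the patch size, stride, and name values are illustrative only.

# Illustrative call only; dp is a data_provider.DataParameters built elsewhere.
patch_centers, input_patches_lt, target_patches_lt = provide_preprocessed_data(
    dp,
    ap=None,                # skip augmentation for this sketch
    extract_patch_size=64,  # illustrative value
    stride=32,              # illustrative value
    name='preprocess')
# patch_centers can later be used to stitch per-patch model outputs back into
# an image the same size as the input, as noted in the docstring.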