Esempio n. 1
0
    def setUp(self):
        """Builds the data, augmentation, batching and model parameters.

        Populates `self.dp`, `self.ap`, `self.bp`, the patch geometry
        (`extract_patch_size`, `stride`, `stitch_patch_size`), the model
        callables (`core_model`, `add_head`), `num_classes` and `shuffle`.
        """
        super(Base, self).setUp()

        read_params = data_provider.ReadTableParameters(
            shard_paths=[self.recordio_path()],
            is_recordio=True,
            bp=util.BatchParameters(2, num_threads=1, capacity=2),
            is_deterministic=True,
            pad_size=51,
            crop_size=110 + 24)
        self.dp = data_provider.DataParameters(
            read_params,
            input_z_values=self.input_z_values,
            input_channel_values=self.input_channel_values,
            target_z_values=self.target_z_values,
            target_channel_values=self.target_channel_values)

        self.ap = augment.AugmentParameters(
            offset_standard_deviation=0.1,
            multiplier_standard_deviation=0.05,
            noise_standard_deviation=0.1)

        # Patch geometry: extract size, stride and stitch size, in that order.
        self.extract_patch_size, self.stride, self.stitch_patch_size = 80, 8, 8

        self.bp = util.BatchParameters(size=4, num_threads=1, capacity=4)

        self.core_model = functools.partial(fovea.core, 50)
        self.add_head = functools.partial(
            model_util.add_head, is_residual_conv=True)

        self.num_classes, self.shuffle = 16, False
Esempio n. 2
0
    def setUp(self):
        """Builds single-channel target tensors for the error tests.

        Reads one batch (batch size 2, pad 0, crop size 768) of cropped
        targets and keeps the unmasked DAPI_CONFOCAL (`self.target_lt_0`) and
        NEURITE_CONFOCAL (`self.target_lt_1`) channels. In each, the axes from
        position 3 onward are replaced by a single 'channel' axis via
        `lt.reshape`.
        """
        super(ErrorBase, self).setUp()

        rtp = data_provider.ReadTableParameters([self.recordio_path()], True,
                                                util.BatchParameters(2, 1, 2),
                                                True, 0, 768)
        dp = data_provider.DataParameters(rtp, self.input_z_values,
                                          self.input_channel_values,
                                          self.target_z_values,
                                          self.target_channel_values)
        _, self.target_lt = data_provider.cropped_input_and_target(dp)

        self.target_lt_0 = lt.select(self.target_lt, {
            'channel': 'DAPI_CONFOCAL',
            'mask': False
        })
        self.target_lt_1 = lt.select(self.target_lt, {
            'channel': 'NEURITE_CONFOCAL',
            'mask': False
        })

        # `axes.keys()` returns a keys view under Python 3, which cannot be
        # sliced; materialize it as a list before taking the trailing axes.
        self.target_lt_0 = lt.reshape(self.target_lt_0,
                                      list(self.target_lt_0.axes.keys())[3:],
                                      ['channel'])
        self.target_lt_1 = lt.reshape(self.target_lt_1,
                                      list(self.target_lt_1.axes.keys())[3:],
                                      ['channel'])
Esempio n. 3
0
    def setUp(self):
        """Sets up z/channel values and two cropped input/target pairs.

        Input z values are 13 evenly spaced points in [0, 1], each rounded to
        four decimals. Two read configurations share the same recordio file
        and batching (size 2, one thread, capacity 2):
          * pad 0, crop size 256 -> `self.input_lt`, `self.target_lt`;
          * pad 51, crop size 110 -> `self.minception_input_lt`,
            `self.minception_target_lt`.
        """
        super(Base, self).setUp()

        self.input_z_values = [round(v, 4) for v in np.linspace(0.0, 1.0, 13)]
        self.input_channel_values = ['BRIGHTFIELD', 'PHASE_CONTRAST', 'DIC']
        self.target_z_values = ['MAXPROJECT']
        self.target_channel_values = [
            'DAPI_CONFOCAL',
            'DAPI_WIDEFIELD',
            'CELLMASK_CONFOCAL',
            'TUJ1_WIDEFIELD',
            'NFH_CONFOCAL',
            'MAP2_CONFOCAL',
            'ISLET_WIDEFIELD',
            'DEAD_CONFOCAL',
        ]

        rtp = data_provider.ReadTableParameters([self.recordio_path()], True,
                                                util.BatchParameters(2, 1, 2),
                                                True, 0, 256)
        dp = data_provider.DataParameters(rtp, self.input_z_values,
                                          self.input_channel_values,
                                          self.target_z_values,
                                          self.target_channel_values)
        self.input_lt, self.target_lt = data_provider.cropped_input_and_target(
            dp)

        minception_rtp = data_provider.ReadTableParameters(
            [self.recordio_path()], True, util.BatchParameters(2, 1, 2), True,
            51, 110)
        minception_dp = data_provider.DataParameters(
            minception_rtp, self.input_z_values, self.input_channel_values,
            self.target_z_values, self.target_channel_values)
        # Unpack via a local so the line fits without a pylint suppression.
        minception_lts = data_provider.cropped_input_and_target(minception_dp)
        self.minception_input_lt, self.minception_target_lt = minception_lts
Esempio n. 4
0
    def setUp(self):
        """Builds `self.input_lt` and `self.target_lt` from the test file.

        Uses deterministic reads with batch size 2, one thread, capacity 2,
        no padding and crop size 512.
        """
        super(CroppedInputAndTargetTest, self).setUp()

        batch_size = 2
        shard_paths = [self.recordio_path()]
        batch_params = util.BatchParameters(batch_size,
                                            num_threads=1,
                                            capacity=2)
        read_params = data_provider.ReadTableParameters(
            shard_paths=shard_paths,
            is_recordio=True,
            bp=batch_params,
            is_deterministic=True,
            pad_size=0,
            crop_size=512)

        data_params = data_provider.DataParameters(
            read_params,
            input_z_values=self.input_z_values,
            input_channel_values=self.input_channel_values,
            target_z_values=self.target_z_values,
            target_channel_values=self.target_channel_values)

        (self.input_lt,
         self.target_lt) = data_provider.cropped_input_and_target(data_params)
Esempio n. 5
0
    def setUp(self):
        """Builds a one-hot prediction, a target and their error panel.

        Batch element 0 of the unmasked target is one-hot encoded (16 classes)
        and used as `self.prediction_lt`; batch element 1 of the full target
        is `self.target_lt`. `self.error_panel_lt` compares the two.
        """
        super(ErrorPanelTest, self).setUp()

        read_params = data_provider.ReadTableParameters(
            [self.recordio_path()], True, util.BatchParameters(2, 1, 2), True,
            0, 768)
        data_params = data_provider.DataParameters(
            read_params, self.input_z_values, self.input_channel_values,
            self.target_z_values, self.target_channel_values)
        _, batch_target_lt = data_provider.cropped_input_and_target(
            data_params)

        unmasked_lt = lt.select(batch_target_lt, {'mask': False})
        first_example_lt = lt.slice(unmasked_lt, {'batch': slice(0, 1)})
        self.prediction_lt = util.onehot(16, first_example_lt)

        self.target_lt = lt.slice(batch_target_lt, {'batch': slice(1, 2)})

        self.error_panel_lt = visualize.error_panel(
            self.target_lt, self.prediction_lt)
Esempio n. 6
0
def parameters() -> controller.GetInputTargetAndPredictedParameters:
    """Creates the network parameters for the given inputs and flags.

    Reads FLAGS.metric, FLAGS.mode, FLAGS.model and FLAGS.loss (plus the
    preprocessing and augmentation flags) to choose the stride, shuffling,
    model functions, augmentation, batching and loss.

    Returns:
      A GetInputTargetAndPredictedParameters containing network parameters for
      the given mode, metric, and other flags.

    Raises:
      NotImplementedError: If FLAGS.model is not MODEL_CONCORDANCE.
      ValueError: If FLAGS.loss is not a supported loss and `logging.fatal`
        returned instead of terminating the process.
    """
    if FLAGS.metric == METRIC_LOSS:
        stride = FLAGS.loss_patch_stride
        shuffle = True
    else:
        if FLAGS.model == MODEL_CONCORDANCE:
            stride = CONCORDANCE_STITCH_STRIDE
        else:
            raise NotImplementedError('Unsupported model: %s' % FLAGS.model)
        # Shuffling breaks stitching.
        shuffle = False

    is_train = FLAGS.mode == MODE_TRAIN

    if FLAGS.model == MODEL_CONCORDANCE:
        core_model = functools.partial(concordance.core,
                                       base_depth=FLAGS.base_depth,
                                       num_gpus=FLAGS.num_gpus)
        add_head = functools.partial(model_util.add_head,
                                     is_residual_conv=True)
        extract_patch_size = CONCORDANCE_EXTRACT_PATCH_SIZE
        stitch_patch_size = CONCORDANCE_STITCH_PATCH_SIZE
    else:
        raise NotImplementedError('Unsupported model: %s' % FLAGS.model)

    dp = data_parameters()

    if shuffle:
        preprocess_num_threads = FLAGS.preprocess_shuffle_batch_num_threads
    else:
        # Thread racing is an additional source of shuffling, so we can only
        # use 1 thread per queue.
        preprocess_num_threads = 1

    # Augmentation is used when training and for the jitter-stitch metric.
    if is_train or FLAGS.metric == METRIC_JITTER_STITCH:
        ap = augment.AugmentParameters(FLAGS.augment_offset_std,
                                       FLAGS.augment_multiplier_std,
                                       FLAGS.augment_noise_std)
    else:
        ap = None

    # No preprocessing BatchParameters for full-image inference.
    if FLAGS.metric == METRIC_INFER_FULL:
        bp = None
    else:
        bp = util.BatchParameters(FLAGS.preprocess_batch_size,
                                  preprocess_num_threads,
                                  FLAGS.preprocess_batch_capacity)

    if FLAGS.loss == LOSS_CROSS_ENTROPY:
        loss = util.softmax_cross_entropy
    elif FLAGS.loss == LOSS_RANKED_PROBABILITY_SCORE:
        loss = util.ranked_probability_score
    else:
        logging.fatal('Invalid loss: %s', FLAGS.loss)
        # absl's logging.fatal terminates the process, but the standard
        # library's logging.fatal only logs and returns. Fail explicitly
        # rather than hitting an UnboundLocalError on `loss` below.
        raise ValueError('Invalid loss: %s' % FLAGS.loss)

    return controller.GetInputTargetAndPredictedParameters(
        dp, ap, extract_patch_size, stride, stitch_patch_size, bp, core_model,
        add_head, shuffle, NUM_CLASSES, loss, is_train)
Esempio n. 7
0
def data_parameters() -> data_provider.DataParameters:
    """Creates the DataParameters from FLAGS.

    Data comes either from PNG directories (when FLAGS.read_pngs is set) or
    from sharded tables. In both cases the crop size is FLAGS.loss_crop_size
    for METRIC_LOSS and FLAGS.stitch_crop_size otherwise, and the training
    split is used for MODE_TRAIN and MODE_EVAL_TRAIN.

    Returns:
      DataParameters built from the I/O parameters, `get_z_values()`,
      INPUT_CHANNEL_VALUES, TARGET_Z_VALUES and TARGET_CHANNEL_VALUES.

    Raises:
      NotImplementedError: If reading tables and FLAGS.model is not
        MODEL_CONCORDANCE.
    """
    use_train_split = FLAGS.mode in (MODE_TRAIN, MODE_EVAL_TRAIN)

    if FLAGS.metric == METRIC_LOSS:
        crop_size = FLAGS.loss_crop_size
    else:
        crop_size = FLAGS.stitch_crop_size

    if FLAGS.read_pngs:
        if use_train_split:
            directory = FLAGS.dataset_train_directory
        else:
            directory = FLAGS.dataset_eval_directory

        io_parameters = data_provider.ReadPNGsParameters(directory, None, None,
                                                         crop_size)
    else:
        # Use an eighth of the dataset for validation: shards whose index is
        # a multiple of 8 are held out, unless FLAGS.train_on_full_dataset
        # puts them back into the training split.
        shard_indices = range(FLAGS.dataset_num_shards)
        if use_train_split:
            selected = [
                i for i in shard_indices
                if (i % 8 != 0) or FLAGS.train_on_full_dataset
            ]
        else:
            selected = [i for i in shard_indices if i % 8 == 0]
        dataset = [FLAGS.dataset_pattern % i for i in selected]

        if FLAGS.model == MODEL_CONCORDANCE:
            extract_patch_size = CONCORDANCE_EXTRACT_PATCH_SIZE
            stitch_patch_size = CONCORDANCE_STITCH_PATCH_SIZE
        else:
            raise NotImplementedError('Unsupported model: %s' % FLAGS.model)

        if FLAGS.mode == MODE_EXPORT:
            # Any padding will be done by the C++ caller.
            pad_width = 0
        else:
            pad_width = (extract_patch_size - stitch_patch_size) // 2

        io_parameters = data_provider.ReadTableParameters(
            dataset,
            FLAGS.is_recordio,
            util.BatchParameters(FLAGS.data_batch_size,
                                 FLAGS.data_batch_num_threads,
                                 FLAGS.data_batch_capacity),
            # Do non-deterministic data fetching, to increase the variety of
            # what we see in the visualizer.
            False,
            pad_width,
            crop_size)

    z_values = get_z_values()
    return data_provider.DataParameters(io_parameters, z_values,
                                        INPUT_CHANNEL_VALUES, TARGET_Z_VALUES,
                                        TARGET_CHANNEL_VALUES)
Esempio n. 8
0
def setup_stitch(
    gitapp: GetInputTargetAndPredictedParameters,
    name=None,
) -> Dict[str, lt.LabeledTensor]:
    """Stitches predicted patches into full images and builds diagnostics.

    Calls `get_input_target_and_predicted` to obtain the patch centers, the
    input and target patches, and the corresponding predictions. The input
    and target are cropped to the predicted patch size, and every patch
    tensor is reassembled into a full image with `ops.patches_to_image`.
    The stitched tensors and prediction statistics are exposed as
    `entry_point_*` identity ops for C++ consumers, and input and target
    error panels are built. Only the two error panels are registered as
    image summaries.

    Args:
      gitapp: GetInputTargetAndPredictedParameters. `gitapp.stride` must equal
        the predicted patch size (overlap and missing data are not handled).
        When `gitapp.bp` is not None, patches are re-batched so that one
        batch holds all patches of a single image.
      name: Optional op name.

    Returns:
      A mapping from name to labeled tensor with the keys 'input',
      'predict_input', 'target', 'predict_target', 'input_error_panel' and
      'target_error_panel'. The 'predict_*' entries have been passed through
      `visualize.to_softmax`.
    """
    logging.info('Setting up stitch')
    with tf.name_scope(name, 'setup_stitch', []) as scope:
        (patch_centers, input_lt, target_lt, predict_input_lt,
         predict_target_lt) = get_input_target_and_predicted(gitapp)

        # Predicted patches must be square; input and target are cropped with
        # util.crop_center to the same size as the predictions.
        # NOTE(review): these are `assert`s, so they vanish under `python -O`.
        predicted_size = len(predict_input_lt.axes['row'])
        assert predicted_size == len(predict_input_lt.axes['column'])
        input_lt = util.crop_center(predicted_size, input_lt)
        target_lt = util.crop_center(predicted_size, target_lt)

        # For now, we're not handling overlap or missing data.
        assert gitapp.stride == predicted_size

        if gitapp.bp is not None:
            # Rebatch so a single tensor is all the patches in a single image.
            [input_lt, target_lt, predict_input_lt,
             predict_target_lt] = util.entry_point_batch(
                 [input_lt, target_lt, predict_input_lt, predict_target_lt],
                 bp=util.BatchParameters(size=len(patch_centers),
                                         num_threads=1,
                                         capacity=1),
                 enqueue_many=True,
                 entry_point_names=[
                     'input_stitch', 'target_stitch', 'predict_input_stitch',
                     'predict_target_stitch'
                 ],
                 name='stitch')

        # For each tensor: fold every axis after the first three of the
        # canonical order into one 'channel' axis, stitch the patches at
        # `patch_centers` into an image, then decode back to the original
        # axes. A fresh coder is built per tensor -- presumably because a
        # ReshapeCoder remembers the axes it encoded; verify before sharing.
        rc = lt.ReshapeCoder(util.CANONICAL_AXIS_ORDER[3:], ['channel'])
        input_lt = rc.decode(
            ops.patches_to_image(patch_centers, rc.encode(input_lt)))

        rc = lt.ReshapeCoder(util.CANONICAL_AXIS_ORDER[3:], ['channel'])
        target_lt = rc.decode(
            ops.patches_to_image(patch_centers, rc.encode(target_lt)))

        # Predictions use the prediction axis order instead.
        rc = lt.ReshapeCoder(util.CANONICAL_PREDICTION_AXIS_ORDER[3:],
                             ['channel'])
        predict_input_lt = rc.decode(
            ops.patches_to_image(patch_centers, rc.encode(predict_input_lt)))

        rc = lt.ReshapeCoder(util.CANONICAL_PREDICTION_AXIS_ORDER[3:],
                             ['channel'])
        predict_target_lt = rc.decode(
            ops.patches_to_image(patch_centers, rc.encode(predict_target_lt)))

        def get_statistics(t: lt.LabeledTensor) -> lt.LabeledTensor:
            """Applies softmax, then ops.distribution_statistics on last axis.

            All axes but the last are folded into 'batch' for the op and
            restored afterwards.
            """
            t = visualize.to_softmax(t)
            rc = lt.ReshapeCoder(list(t.axes.keys())[:-1], ['batch'])
            return rc.decode(ops.distribution_statistics(rc.encode(t)))

        # C++ entry points .
        # tf.name_scope('') resets to the root scope (TF1 semantics), so these
        # op names do not carry the `setup_stitch` scope prefix.
        with tf.name_scope(''):
            input_lt = lt.identity(input_lt, name='entry_point_stitched_input')
            target_lt = lt.identity(target_lt,
                                    name='entry_point_stitched_target')
            # The nodes are used purely to export data to C++.
            lt.identity(get_statistics(predict_input_lt),
                        name='entry_point_stitched_predicted_input')
            lt.identity(get_statistics(predict_target_lt),
                        name='entry_point_stitched_predicted_target')

        predict_input_lt = visualize.to_softmax(predict_input_lt)
        predict_target_lt = visualize.to_softmax(predict_target_lt)

        input_summary_lt = visualize.error_panel(input_lt, predict_input_lt)
        target_summary_lt = visualize.error_panel(target_lt, predict_target_lt)

        if gitapp.bp is not None:
            input_summary_lt, target_summary_lt = lt.batch(
                [input_summary_lt, target_summary_lt],
                # We'll see 3 images in the visualizer.
                batch_size=3,
                enqueue_many=True,
                num_threads=1,
                capacity=1,
                name='group')

        input_summary_lt = lt.identity(input_summary_lt,
                                       name=scope + 'input_error_panel')
        target_summary_lt = lt.identity(target_summary_lt,
                                        name=scope + 'target_error_panel')

        visualize_op_dict = {}
        visualize_op_dict['input'] = input_lt
        visualize_op_dict['predict_input'] = predict_input_lt
        visualize_op_dict['target'] = target_lt
        visualize_op_dict['predict_target'] = predict_target_lt

        # Registers an image summary and records the tensor under `tag`.
        def summarize(tag, labeled_tensor):
            visualize.summarize_image(labeled_tensor,
                                      name=scope + 'summarize/' + tag)
            visualize_op_dict[tag] = labeled_tensor

        summarize('input_error_panel', input_summary_lt)
        summarize('target_error_panel', target_summary_lt)

        return visualize_op_dict
Esempio n. 9
0
    def setUp(self):
        """Builds `self.example_lt` with data_provider.read_serialized_example.

        Reads from the test recordio path with batch size 4, one thread and
        capacity 2.
        """
        super(ReadSerializedExampleTest, self).setUp()

        shard_paths = [self.recordio_path()]
        batch_params = util.BatchParameters(4, num_threads=1, capacity=2)
        self.example_lt = data_provider.read_serialized_example(
            shard_paths, True, batch_params, True)