def setup_losses(
    gitapp: GetInputTargetAndPredictedParameters,
    name: str = None,
) -> Tuple[Dict[str, lt.LabeledTensor], Dict[str, lt.LabeledTensor]]:
    """Creates the input-reconstruction and target-prediction losses.

    The per-item losses are built by `itemize_losses` using `gitapp.loss`.
    As side effects this also registers:
      * image summaries of the input and target patch error panels,
      * scalar summaries 'loss/input' and 'loss/target', each holding the
        unweighted mean of the corresponding loss dictionary,
      * a histogram summary for every global variable.

    Args:
      gitapp: GetInputTargetAndPredictedParameters.
      name: Optional op name.

    Returns:
      A pair of dictionaries of tensors: the input reconstruction losses and
      the target prediction losses, in that order.
    """
    logging.info('Setting up losses')
    with tf.name_scope(name, 'setup_losses', []) as scope:
        parts = get_input_target_and_predicted(gitapp)
        _, input_lt, target_lt, predict_input_lt, predict_target_lt = parts

        # Ground truth is center-cropped to the predictions' row extent so
        # the two can be compared side by side in the error panels.
        crop_size = len(predict_input_lt.axes['row'])
        panels = (
            (input_lt, predict_input_lt, 'input_patch_error_panel'),
            (target_lt, predict_target_lt, 'target_patch_error_panel'),
        )
        for truth_lt, prediction_lt, panel_name in panels:
            cropped_lt = util.crop_center(crop_size, truth_lt)
            softmax_lt = visualize.to_softmax(prediction_lt)
            panel_lt = visualize.error_panel(cropped_lt, softmax_lt,
                                             name=scope + panel_name)
            visualize.summarize_image(panel_lt)

        def average(loss_lts: Dict[str, lt.LabeledTensor]) -> tf.Tensor:
            # Unweighted mean over every entry of the loss dictionary.
            total = tf.add_n([loss_lt.tensor for loss_lt in loss_lts.values()])
            return total / float(len(loss_lts))

        losses = {}
        for tag, truth_lt, prediction_lt in (
                ('input', input_lt, predict_input_lt),
                ('target', target_lt, predict_target_lt)):
            losses[tag] = itemize_losses(gitapp.loss,
                                         truth_lt,
                                         prediction_lt,
                                         name=scope + tag)
            tf.summary.scalar(name='loss/' + tag, tensor=average(losses[tag]))

        for variable in tf.global_variables():
            tf.summary.histogram(name='variable/' + variable.name,
                                 values=variable)

        return losses['input'], losses['target']
Example #2
0
    def setUp(self):
        """Builds a prediction, a target and the error panel comparing them.

        Reads a batch from the test recordio file. The prediction is the
        `mask=False` entry of the first example, one-hot encoded. The target
        is the second example.
        """
        super().setUp()

        recordio_paths = [self.recordio_path()]
        batch_params = util.BatchParameters(2, 1, 2)
        read_params = data_provider.ReadTableParameters(
            recordio_paths, True, batch_params, True, 0, 768)
        data_params = data_provider.DataParameters(
            read_params,
            self.input_z_values,
            self.input_channel_values,
            self.target_z_values,
            self.target_channel_values,
        )
        _, batch_target_lt = data_provider.cropped_input_and_target(data_params)

        unmasked_lt = lt.select(batch_target_lt, {'mask': False})
        first_example_lt = lt.slice(unmasked_lt, {'batch': slice(0, 1)})
        # 16 is presumably the one-hot depth expected by util.onehot; confirm.
        self.prediction_lt = util.onehot(16, first_example_lt)

        self.target_lt = lt.slice(batch_target_lt, {'batch': slice(1, 2)})

        self.error_panel_lt = visualize.error_panel(self.target_lt,
                                                    self.prediction_lt)
def setup_stitch(
    gitapp: GetInputTargetAndPredictedParameters,
    name=None,
) -> Dict[str, lt.LabeledTensor]:
    """Creates diagnostic images.

    All diagnostic images are registered as summaries. The patch-level
    inputs, targets and predictions are stitched back into whole images with
    `ops.patches_to_image`. Identity ops named 'entry_point_stitched_input',
    'entry_point_stitched_target', 'entry_point_stitched_predicted_input' and
    'entry_point_stitched_predicted_target' are created in the graph's
    top-level name scope so the stitched data can be fetched from C++.

    Args:
      gitapp: GetInputTargetAndPredictedParameters.
      name: Optional op name.

    Returns:
      A mapping where the keys are names of summary images and the values
      are image tensors. Keys: 'input', 'predict_input', 'target',
      'predict_target', 'input_error_panel', 'target_error_panel'.

    Raises:
      ValueError: If the predicted patches are not square, or if
        `gitapp.stride` differs from the predicted patch size (overlapping
        or gapped patches are not supported).
    """
    logging.info('Setting up stitch')
    with tf.name_scope(name, 'setup_stitch', []) as scope:
        (patch_centers, input_lt, target_lt, predict_input_lt,
         predict_target_lt) = get_input_target_and_predicted(gitapp)

        predicted_size = len(predict_input_lt.axes['row'])
        predicted_columns = len(predict_input_lt.axes['column'])
        # Explicit checks instead of `assert`: asserts are stripped under
        # `python -O`, which would let mis-sized patches be stitched silently.
        if predicted_size != predicted_columns:
            raise ValueError(
                'Predicted patches must be square, got %d rows and %d '
                'columns' % (predicted_size, predicted_columns))
        input_lt = util.crop_center(predicted_size, input_lt)
        target_lt = util.crop_center(predicted_size, target_lt)

        # For now, we're not handling overlap or missing data.
        if gitapp.stride != predicted_size:
            raise ValueError(
                'Stitching requires the stride to equal the predicted patch '
                'size, got stride %r and patch size %d' %
                (gitapp.stride, predicted_size))

        if gitapp.bp is not None:
            # Rebatch so a single tensor is all the patches in a single image.
            [input_lt, target_lt, predict_input_lt,
             predict_target_lt] = util.entry_point_batch(
                 [input_lt, target_lt, predict_input_lt, predict_target_lt],
                 bp=util.BatchParameters(size=len(patch_centers),
                                         num_threads=1,
                                         capacity=1),
                 enqueue_many=True,
                 entry_point_names=[
                     'input_stitch', 'target_stitch', 'predict_input_stitch',
                     'predict_target_stitch'
                 ],
                 name='stitch')

        def stitch(patches_lt: lt.LabeledTensor,
                   axis_order) -> lt.LabeledTensor:
            # Merge the axes after the first three of `axis_order` into one
            # 'channel' axis for ops.patches_to_image, then split them back.
            rc = lt.ReshapeCoder(axis_order[3:], ['channel'])
            return rc.decode(
                ops.patches_to_image(patch_centers, rc.encode(patches_lt)))

        input_lt = stitch(input_lt, util.CANONICAL_AXIS_ORDER)
        target_lt = stitch(target_lt, util.CANONICAL_AXIS_ORDER)
        predict_input_lt = stitch(predict_input_lt,
                                  util.CANONICAL_PREDICTION_AXIS_ORDER)
        predict_target_lt = stitch(predict_target_lt,
                                   util.CANONICAL_PREDICTION_AXIS_ORDER)

        def get_statistics(t: lt.LabeledTensor) -> lt.LabeledTensor:
            # Softmax the prediction, fold every axis but the last into
            # 'batch', and compute ops.distribution_statistics on that
            # (presumably per row, along the last axis), then unfold.
            t = visualize.to_softmax(t)
            rc = lt.ReshapeCoder(list(t.axes.keys())[:-1], ['batch'])
            return rc.decode(ops.distribution_statistics(rc.encode(t)))

        # C++ entry points. An empty name scope resets to the graph's top
        # level, so these ops get fixed, unprefixed names.
        with tf.name_scope(''):
            input_lt = lt.identity(input_lt, name='entry_point_stitched_input')
            target_lt = lt.identity(target_lt,
                                    name='entry_point_stitched_target')
            # The nodes are used purely to export data to C++.
            lt.identity(get_statistics(predict_input_lt),
                        name='entry_point_stitched_predicted_input')
            lt.identity(get_statistics(predict_target_lt),
                        name='entry_point_stitched_predicted_target')

        predict_input_lt = visualize.to_softmax(predict_input_lt)
        predict_target_lt = visualize.to_softmax(predict_target_lt)

        input_summary_lt = visualize.error_panel(input_lt, predict_input_lt)
        target_summary_lt = visualize.error_panel(target_lt, predict_target_lt)

        if gitapp.bp is not None:
            input_summary_lt, target_summary_lt = lt.batch(
                [input_summary_lt, target_summary_lt],
                # We'll see 3 images in the visualizer.
                batch_size=3,
                enqueue_many=True,
                num_threads=1,
                capacity=1,
                name='group')

        input_summary_lt = lt.identity(input_summary_lt,
                                       name=scope + 'input_error_panel')
        target_summary_lt = lt.identity(target_summary_lt,
                                        name=scope + 'target_error_panel')

        visualize_op_dict = {}
        visualize_op_dict['input'] = input_lt
        visualize_op_dict['predict_input'] = predict_input_lt
        visualize_op_dict['target'] = target_lt
        visualize_op_dict['predict_target'] = predict_target_lt

        def summarize(tag, labeled_tensor):
            # Register an image summary and expose the tensor under `tag`.
            visualize.summarize_image(labeled_tensor,
                                      name=scope + 'summarize/' + tag)
            visualize_op_dict[tag] = labeled_tensor

        summarize('input_error_panel', input_summary_lt)
        summarize('target_error_panel', target_summary_lt)

        return visualize_op_dict