Exemplo n.º 1
0
def setup_losses(
    gitapp: GetInputTargetAndPredictedParameters,
    name: str = None,
) -> Tuple[Dict[str, lt.LabeledTensor], Dict[str, lt.LabeledTensor]]:
    """Creates the itemized losses for the input and target predictions.

    Besides the losses, this registers summaries: an error-panel image for the
    input and for the target prediction, a scalar with the mean loss of each,
    and a histogram for every global variable.

    Args:
      gitapp: GetInputTargetAndPredictedParameters.
      name: Optional op name.

    Returns:
      A pair of dictionaries of loss tensors: the first holds the input
      reconstruction losses, the second the target prediction losses.
    """
    logging.info('Setting up losses')
    with tf.name_scope(name, 'setup_losses', []) as scope:
        (_, input_lt, target_lt, predict_input_lt,
         predict_target_lt) = get_input_target_and_predicted(gitapp)

        # (tag, actual tensor, predicted tensor, error-panel name); the input
        # stream is always processed before the target stream.
        streams = (
            ('input', input_lt, predict_input_lt, 'input_patch_error_panel'),
            ('target', target_lt, predict_target_lt,
             'target_patch_error_panel'),
        )

        # The actual tensors are center-cropped to the prediction's row count
        # so each panel compares like-sized regions.
        crop_size = len(predict_input_lt.axes['row'])
        for _, actual_lt, pred_lt, panel_name in streams:
            panel = visualize.error_panel(
                util.crop_center(crop_size, actual_lt),
                visualize.to_softmax(pred_lt),
                name=scope + panel_name)
            visualize.summarize_image(panel)

        def average_loss(loss_lts: Dict[str, lt.LabeledTensor]) -> tf.Tensor:
            # Unweighted mean of the itemized loss tensors.
            total = tf.add_n([item.tensor for item in loss_lts.values()])
            return total / float(len(loss_lts))

        per_stream_losses = []
        for tag, actual_lt, pred_lt, _ in streams:
            loss_lts = itemize_losses(gitapp.loss,
                                      actual_lt,
                                      pred_lt,
                                      name=scope + tag)
            tf.summary.scalar(name='loss/' + tag,
                              tensor=average_loss(loss_lts))
            per_stream_losses.append(loss_lts)

        for variable in tf.global_variables():
            tf.summary.histogram(name='variable/' + variable.name,
                                 values=variable)

        input_loss_lts, target_loss_lts = per_stream_losses
        return input_loss_lts, target_loss_lts
Exemplo n.º 2
0
 def get_statistics(t: lt.LabeledTensor) -> lt.LabeledTensor:
     """Computes `ops.distribution_statistics` of the softmaxed `t`.

     Every axis except the last is folded into a single 'batch' axis before
     the statistics op runs, and the result is unfolded back afterwards.
     """
     probabilities = visualize.to_softmax(t)
     folded_axes = list(probabilities.axes.keys())[:-1]
     coder = lt.ReshapeCoder(folded_axes, ['batch'])
     statistics = ops.distribution_statistics(coder.encode(probabilities))
     return coder.decode(statistics)
Exemplo n.º 3
0
def setup_stitch(
    gitapp: GetInputTargetAndPredictedParameters,
    name=None,
) -> Dict[str, lt.LabeledTensor]:
    """Creates diagnostic images by stitching patches back into images.

    The input and target patches are center-cropped to the predicted patch
    size. Then the input, target and both predictions are stitched with
    `ops.patches_to_image`. The stitched input, target and prediction
    statistics are exported through `entry_point_stitched_*` identity nodes
    created under an empty name scope. Input and target error panels are
    registered as image summaries.

    Args:
      gitapp: GetInputTargetAndPredictedParameters.
      name: Optional op name.

    Returns:
      A mapping from tag to tensor containing:
        'input', 'target': the stitched input and target images;
        'predict_input', 'predict_target': the stitched predictions after
          `visualize.to_softmax`;
        'input_error_panel', 'target_error_panel': the error panels that are
          registered as image summaries.

    Raises:
      ValueError: If the predicted patches are not square, or if
        `gitapp.stride` differs from the predicted patch size (overlapping or
        missing patches are not supported).
    """
    logging.info('Setting up stitch')
    with tf.name_scope(name, 'setup_stitch', []) as scope:
        (patch_centers, input_lt, target_lt, predict_input_lt,
         predict_target_lt) = get_input_target_and_predicted(gitapp)

        predicted_size = len(predict_input_lt.axes['row'])
        predicted_columns = len(predict_input_lt.axes['column'])
        # Explicit checks instead of `assert`: asserts are stripped under -O,
        # which would let misaligned patches be stitched silently.
        if predicted_size != predicted_columns:
            raise ValueError(
                'Predicted patches must be square; got {} rows and {} '
                'columns.'.format(predicted_size, predicted_columns))
        input_lt = util.crop_center(predicted_size, input_lt)
        target_lt = util.crop_center(predicted_size, target_lt)

        # For now, we're not handling overlap or missing data.
        if gitapp.stride != predicted_size:
            raise ValueError(
                'Stitching requires stride == predicted patch size; got '
                'stride {!r} and predicted size {}.'.format(
                    gitapp.stride, predicted_size))

        if gitapp.bp is not None:
            # Rebatch so a single tensor is all the patches in a single image.
            [input_lt, target_lt, predict_input_lt,
             predict_target_lt] = util.entry_point_batch(
                 [input_lt, target_lt, predict_input_lt, predict_target_lt],
                 bp=util.BatchParameters(size=len(patch_centers),
                                         num_threads=1,
                                         capacity=1),
                 enqueue_many=True,
                 entry_point_names=[
                     'input_stitch', 'target_stitch', 'predict_input_stitch',
                     'predict_target_stitch'
                 ],
                 name='stitch')

        def stitch_patches(axis_order, patches_lt):
            # Merge every axis after the first three of `axis_order` into a
            # single 'channel' axis for `ops.patches_to_image`, then split the
            # merged axis back out of the stitched result.
            rc = lt.ReshapeCoder(axis_order[3:], ['channel'])
            return rc.decode(
                ops.patches_to_image(patch_centers, rc.encode(patches_lt)))

        input_lt = stitch_patches(util.CANONICAL_AXIS_ORDER, input_lt)
        target_lt = stitch_patches(util.CANONICAL_AXIS_ORDER, target_lt)
        predict_input_lt = stitch_patches(
            util.CANONICAL_PREDICTION_AXIS_ORDER, predict_input_lt)
        predict_target_lt = stitch_patches(
            util.CANONICAL_PREDICTION_AXIS_ORDER, predict_target_lt)

        def get_statistics(t: lt.LabeledTensor) -> lt.LabeledTensor:
            # Distribution statistics of the softmaxed tensor, with all but
            # the last axis temporarily folded into 'batch'.
            t = visualize.to_softmax(t)
            rc = lt.ReshapeCoder(list(t.axes.keys())[:-1], ['batch'])
            return rc.decode(ops.distribution_statistics(rc.encode(t)))

        # C++ entry points. The empty name scope places these nodes outside
        # the 'setup_stitch' scope so they have fixed names.
        with tf.name_scope(''):
            input_lt = lt.identity(input_lt, name='entry_point_stitched_input')
            target_lt = lt.identity(target_lt,
                                    name='entry_point_stitched_target')
            # The nodes are used purely to export data to C++.
            lt.identity(get_statistics(predict_input_lt),
                        name='entry_point_stitched_predicted_input')
            lt.identity(get_statistics(predict_target_lt),
                        name='entry_point_stitched_predicted_target')

        predict_input_lt = visualize.to_softmax(predict_input_lt)
        predict_target_lt = visualize.to_softmax(predict_target_lt)

        input_summary_lt = visualize.error_panel(input_lt, predict_input_lt)
        target_summary_lt = visualize.error_panel(target_lt, predict_target_lt)

        if gitapp.bp is not None:
            input_summary_lt, target_summary_lt = lt.batch(
                [input_summary_lt, target_summary_lt],
                # We'll see 3 images in the visualizer.
                batch_size=3,
                enqueue_many=True,
                num_threads=1,
                capacity=1,
                name='group')

        input_summary_lt = lt.identity(input_summary_lt,
                                       name=scope + 'input_error_panel')
        target_summary_lt = lt.identity(target_summary_lt,
                                        name=scope + 'target_error_panel')

        visualize_op_dict = {
            'input': input_lt,
            'predict_input': predict_input_lt,
            'target': target_lt,
            'predict_target': predict_target_lt,
        }

        def summarize(tag, labeled_tensor):
            # Registers an image summary and exposes the tensor under `tag`.
            visualize.summarize_image(labeled_tensor,
                                      name=scope + 'summarize/' + tag)
            visualize_op_dict[tag] = labeled_tensor

        summarize('input_error_panel', input_summary_lt)
        summarize('target_error_panel', target_summary_lt)

        return visualize_op_dict