Example 1
 def get_statistics(t: lt.LabeledTensor) -> lt.LabeledTensor:
     """Compute distribution statistics over the softmax of `t`.

     All axes except the last are temporarily folded into a single 'batch'
     axis so `ops.distribution_statistics` sees a 2-D input; the fold is
     undone on the result, restoring the original axes.
     """
     softmax_lt = visualize.to_softmax(t)
     folded_axes = list(softmax_lt.axes.keys())[:-1]
     coder = lt.ReshapeCoder(folded_axes, ['batch'])
     encoded_stats = ops.distribution_statistics(coder.encode(softmax_lt))
     return coder.decode(encoded_stats)
 def get_statistics(tensor):
   rc = lt.ReshapeCoder(list(tensor.axes.keys())[:-1], ['batch'])
   return rc.decode(ops.distribution_statistics(rc.encode(tensor)))
Example 3
def error_panel(
    target_lt: lt.LabeledTensor,
    predicted_lt: lt.LabeledTensor,
    name: str = None,
) -> lt.LabeledTensor:
    """Creates a big informative error panel image.

    The panel is a horizontal concatenation of columns, one per
    visualization: prediction statistics, the ground-truth image, the
    cross-entropy error, additive and subtractive error images, and the
    pixel presence mask. Each column gets a colored border so it can be
    identified at a glance.

    Args:
      target_lt: The ground truth values in canonical order.
      predicted_lt: The predicted values in canonical prediction order as a
        probability distribution.
      name: Optional op name.

    Returns:
      The error panel as a single LabeledTensor.
    """
    with tf.name_scope(name, 'error_panel',
                       [target_lt, predicted_lt]) as scope:
        # Put both inputs into their canonical axis orders so the axis
        # comparisons and reshapes below line up.
        target_lt = lt.transpose(target_lt, util.CANONICAL_AXIS_ORDER)
        predicted_lt = lt.transpose(predicted_lt,
                                    util.CANONICAL_PREDICTION_AXIS_ORDER)

        # All axes except the last must match between target and prediction
        # (presumably the last axes are 'mask' vs. 'class' — TODO confirm
        # against util.CANONICAL_*_AXIS_ORDER).
        assert list(target_lt.axes.items())[:-1] == list(
            predicted_lt.axes.items())[:-1], (target_lt.axes,
                                              predicted_lt.axes)

        # Fold all leading axes into 'batch' to compute per-distribution
        # statistics, then restore the original axes on the result.
        rc = lt.ReshapeCoder(list(predicted_lt.axes.keys())[:-1], ['batch'])
        statistic_lt = rc.decode(
            ops.distribution_statistics(rc.encode(predicted_lt)))

        columns = []

        # Helper that borders a tensor with `color`, normalizes its axis
        # order, flattens it into a tall single-column image, and appends
        # it to `columns` (mutated closure state).
        def get_column(
            labeled_tensor: lt.LabeledTensor,
            color: Tuple[float, float, float],
        ):
            labeled_tensor = add_border(color, PAD_WIDTH, labeled_tensor)
            labeled_tensor = lt.transpose(
                labeled_tensor,
                ['batch', 'z', 'channel', 'row', 'column', 'color'])
            columns.append(
                lt.reshape(labeled_tensor, ['z', 'channel', 'row'], ['row']))

        # We only show these statistics.
        statistics = [
            'mode', 'median', 'mean', 'standard_deviation', 'entropy'
        ]

        # Show the statistics on the predictions.
        for s in statistics:
            get_column(lt.select(statistic_lt, {'statistic': s}), PURPLE)

        # Show the ground truth target image.
        get_column(lt.select(target_lt, {'mask': False}), TURQUOISE)

        # Show the cross entropy error between the one-hot-encoded target
        # and the predicted distribution.
        num_classes = len(predicted_lt.axes['class'])
        cross_entropy_lt = cross_entropy_error(
            util.onehot(num_classes, lt.select(target_lt, {'mask': False})),
            predicted_lt)
        get_column(cross_entropy_lt, RED)

        # Show the additive error visualizations, only for the first three
        # (central-tendency) statistics: mode, median, mean.
        for s in statistics[:3]:
            rc = lt.ReshapeCoder(['z', 'channel'], ['channel'])
            error_lt = rc.decode(
                additive_error(
                    rc.encode(lt.select(target_lt, {'mask': False})),
                    rc.encode(lt.select(statistic_lt, {'statistic': s}))))
            get_column(error_lt, WHITE)

        # Show the subtractive error visualizations (same three statistics).
        for s in statistics[:3]:
            rc = lt.ReshapeCoder(['z', 'channel'], ['channel'])
            error_lt = rc.decode(
                subtractive_error(
                    rc.encode(lt.select(target_lt, {'mask': False})),
                    rc.encode(lt.select(statistic_lt, {'statistic': s}))))
            get_column(error_lt, BLACK)

        # Show the pixel presence / absence masks.
        get_column(lt.select(target_lt, {'mask': True}), MANGO)

        # Concatenate all columns side-by-side; `scope` names the final op.
        panel_lt = lt.concat(columns, 'column', name=scope)

        return panel_lt