def get_statistics(t: lt.LabeledTensor) -> lt.LabeledTensor:
  """Returns distribution statistics of the softmax of `t`.

  Every axis except the last is folded into a single 'batch' axis,
  `ops.distribution_statistics` is applied, and the folded axes are then
  restored.

  NOTE(review): SOURCE redefines `get_statistics` (without the softmax step)
  immediately after this definition; if both live in the same module, that
  later one shadows this version — confirm which is intended.

  Args:
    t: Input labeled tensor; the softmax is taken via `visualize.to_softmax`.

  Returns:
    The statistics tensor with the original leading axes restored.
  """
  probabilities = visualize.to_softmax(t)
  leading_axis_names = list(probabilities.axes.keys())[:-1]
  coder = lt.ReshapeCoder(leading_axis_names, ['batch'])
  flattened = coder.encode(probabilities)
  flat_statistics = ops.distribution_statistics(flattened)
  return coder.decode(flat_statistics)
def get_statistics(tensor):
  """Returns distribution statistics of `tensor`.

  Every axis except the last is folded into a single 'batch' axis,
  `ops.distribution_statistics` is applied, and the folded axes are then
  restored. Unlike the variant that precedes it in SOURCE, no softmax is
  applied here.

  Args:
    tensor: Input labeled tensor.

  Returns:
    The statistics tensor with the original leading axes restored.
  """
  leading_axis_names = list(tensor.axes.keys())[:-1]
  coder = lt.ReshapeCoder(leading_axis_names, ['batch'])
  flattened = coder.encode(tensor)
  flat_statistics = ops.distribution_statistics(flattened)
  return coder.decode(flat_statistics)
def error_panel(
    target_lt: lt.LabeledTensor,
    predicted_lt: lt.LabeledTensor,
    name: str = None,
) -> lt.LabeledTensor:
  """Creates a big informative error panel image.

  The panel is the concatenation, along the 'column' axis, of these columns
  in order:
    1. the prediction statistics mode, median, mean, standard_deviation and
       entropy (PURPLE border);
    2. the ground truth target image (TURQUOISE border);
    3. the cross entropy error (RED border);
    4. the additive errors for mode, median and mean (WHITE border);
    5. the subtractive errors for mode, median and mean (BLACK border);
    6. the pixel presence / absence mask (MANGO border).

  Args:
    target_lt: The ground truth values in canonical order.
    predicted_lt: The predicted values in canonical prediction order as a
      probability distribution.
    name: Optional op name.

  Returns:
    The error panel.

  Raises:
    ValueError: If, after transposing to canonical order, all but the last
      axis of `target_lt` and `predicted_lt` do not match.
  """
  with tf.name_scope(name, 'error_panel', [target_lt, predicted_lt]) as scope:
    target_lt = lt.transpose(target_lt, util.CANONICAL_AXIS_ORDER)
    predicted_lt = lt.transpose(predicted_lt,
                                util.CANONICAL_PREDICTION_AXIS_ORDER)

    # Explicit check instead of `assert`, so it is not stripped by `python -O`.
    target_leading_axes = list(target_lt.axes.items())[:-1]
    predicted_leading_axes = list(predicted_lt.axes.items())[:-1]
    if target_leading_axes != predicted_leading_axes:
      raise ValueError(
          'target_lt and predicted_lt leading axes differ: %r vs %r' %
          (target_lt.axes, predicted_lt.axes))

    # Fold all but the last (class) axis into 'batch' to compute statistics.
    rc = lt.ReshapeCoder(list(predicted_lt.axes.keys())[:-1], ['batch'])
    statistic_lt = rc.decode(
        ops.distribution_statistics(rc.encode(predicted_lt)))

    columns = []

    def get_column(
        labeled_tensor: lt.LabeledTensor,
        color: Tuple[float, float, float],
    ):
      """Borders `labeled_tensor` with `color` and appends it as a column.

      The 'z', 'channel' and 'row' axes are flattened into a single 'row'
      axis so that each column stacks its images vertically.
      """
      labeled_tensor = add_border(color, PAD_WIDTH, labeled_tensor)
      labeled_tensor = lt.transpose(
          labeled_tensor, ['batch', 'z', 'channel', 'row', 'column', 'color'])
      columns.append(
          lt.reshape(labeled_tensor, ['z', 'channel', 'row'], ['row']))

    # We only show these statistics.
    statistics = ['mode', 'median', 'mean', 'standard_deviation', 'entropy']

    # Show the statistics on the predictions.
    for s in statistics:
      get_column(lt.select(statistic_lt, {'statistic': s}), PURPLE)

    # The unmasked ground truth feeds several columns below; select it once.
    unmasked_target_lt = lt.select(target_lt, {'mask': False})

    # Show the ground truth target image.
    get_column(unmasked_target_lt, TURQUOISE)

    # Show the cross entropy error.
    num_classes = len(predicted_lt.axes['class'])
    cross_entropy_lt = cross_entropy_error(
        util.onehot(num_classes, unmasked_target_lt), predicted_lt)
    get_column(cross_entropy_lt, RED)

    # Show the additive error visualizations (all of them first), then the
    # subtractive ones, for the mode / median / mean statistics.
    for error_fn, color in ((additive_error, WHITE),
                            (subtractive_error, BLACK)):
      for s in statistics[:3]:
        # Fold 'z' into 'channel' for the error helper, then unfold again.
        rc = lt.ReshapeCoder(['z', 'channel'], ['channel'])
        error_lt = rc.decode(
            error_fn(
                rc.encode(unmasked_target_lt),
                rc.encode(lt.select(statistic_lt, {'statistic': s}))))
        get_column(error_lt, color)

    # Show the pixel presence / absence masks.
    get_column(lt.select(target_lt, {'mask': True}), MANGO)

    panel_lt = lt.concat(columns, 'column', name=scope)

    return panel_lt