예제 #1
0
    def _postprocess(self, cls_outputs, box_outputs, scales, mode='global'):
        """Postprocess class and box predictions."""
        if not mode:
            return cls_outputs, box_outputs

        if mode == 'global':
            return postprocess.postprocess_global(self.config.as_dict(),
                                                  cls_outputs, box_outputs,
                                                  scales)
        if mode == 'per_class':
            return postprocess.postprocess_per_class(self.config.as_dict(),
                                                     cls_outputs, box_outputs,
                                                     scales)
        if mode == 'combined':
            return postprocess.postprocess_combined(self.config.as_dict(),
                                                    cls_outputs, box_outputs,
                                                    scales)
        if mode == 'tflite':
            if scales is not None:
                # pre_mode should be None for TFLite.
                raise ValueError(
                    'scales not supported for TFLite post-processing')
            return postprocess.postprocess_tflite(self.config.as_dict(),
                                                  cls_outputs, box_outputs)
        raise ValueError('Unsupported postprocess mode {}'.format(mode))
예제 #2
0
    def _postprocess(self, cls_outputs, box_outputs, scales, mode='global'):
        """Postprocess class and box predictions."""
        if not mode:
            return cls_outputs, box_outputs

        # TODO(tanmingxing): remove this cast once FP16 works postprocessing.
        cls_outputs = [tf.cast(i, tf.float32) for i in cls_outputs]
        box_outputs = [tf.cast(i, tf.float32) for i in box_outputs]

        if mode == 'global':
            return postprocess.postprocess_global(self.config.as_dict(),
                                                  cls_outputs, box_outputs,
                                                  scales)
        if mode == 'per_class':
            return postprocess.postprocess_per_class(self.config.as_dict(),
                                                     cls_outputs, box_outputs,
                                                     scales)
        if mode == 'tflite':
            if scales is not None:
                # pre_mode should be None for TFLite.
                raise ValueError(
                    'scales not supported for TFLite post-processing')
            return postprocess.postprocess_tflite(self.config.as_dict(),
                                                  cls_outputs, box_outputs)
        raise ValueError('Unsupported postprocess mode {}'.format(mode))