Ejemplo n.º 1
0
    def _postprocess(self, cls_outputs, box_outputs, scales, mode='global'):
        """Postprocess class and box predictions."""
        if not mode:
            return cls_outputs, box_outputs

        if mode == 'global':
            return postprocess.postprocess_global(self.config.as_dict(),
                                                  cls_outputs, box_outputs,
                                                  scales)
        if mode == 'per_class':
            return postprocess.postprocess_per_class(self.config.as_dict(),
                                                     cls_outputs, box_outputs,
                                                     scales)
        if mode == 'combined':
            return postprocess.postprocess_combined(self.config.as_dict(),
                                                    cls_outputs, box_outputs,
                                                    scales)
        if mode == 'tflite':
            if scales is not None:
                # pre_mode should be None for TFLite.
                raise ValueError(
                    'scales not supported for TFLite post-processing')
            return postprocess.postprocess_tflite(self.config.as_dict(),
                                                  cls_outputs, box_outputs)
        raise ValueError('Unsupported postprocess mode {}'.format(mode))
Ejemplo n.º 2
0
 def _postprocess(self, cls_outputs, box_outputs, scales, mode='global'):
   if not mode:
     return cls_outputs, box_outputs
   if mode == 'global':
     return postprocess.postprocess_global(self.config.as_dict(), cls_outputs,
                                           box_outputs, scales)
   if mode == 'per_class':
     return postprocess.postprocess_per_class(self.config.as_dict(),
                                              cls_outputs, box_outputs, scales)
   raise ValueError('Unsupported postprocess mode {}'.format(mode))
Ejemplo n.º 3
0
  def _postprocess(self, cls_outputs, box_outputs, scales, mode='global'):
    if not mode:
      return cls_outputs, box_outputs

    # TODO(tanmingxing): remove this cast once FP16 works postprocessing.
    cls_outputs = [tf.cast(i, tf.float32) for i in cls_outputs]
    box_outputs = [tf.cast(i, tf.float32) for i in box_outputs]

    if mode == 'global':
      return postprocess.postprocess_global(self.config.as_dict(), cls_outputs,
                                            box_outputs, scales)
    if mode == 'per_class':
      return postprocess.postprocess_per_class(self.config.as_dict(),
                                               cls_outputs, box_outputs, scales)
    raise ValueError('Unsupported postprocess mode {}'.format(mode))
Ejemplo n.º 4
0
  def _postprocess(self, cls_outputs, box_outputs, scales, mode='global'):
    """Postprocess class and box predictions."""
    if not mode:
      return cls_outputs, box_outputs

    # TODO(tanmingxing): remove this cast once FP16 works postprocessing.
    cls_outputs = [tf.cast(i, tf.float32) for i in cls_outputs]
    box_outputs = [tf.cast(i, tf.float32) for i in box_outputs]

    if mode == 'global':
      return postprocess.postprocess_global(self.config.as_dict(), cls_outputs,
                                            box_outputs, scales)
    if mode == 'per_class':
      return postprocess.postprocess_per_class(self.config.as_dict(),
                                               cls_outputs, box_outputs, scales)
    if mode == 'tflite':
      if scales is not None:
        # pre_mode should be None for TFLite.
        raise ValueError('scales not supported for TFLite post-processing')
      return postprocess.postprocess_tflite(self.config.as_dict(), cls_outputs,
                                            box_outputs)
    raise ValueError('Unsupported postprocess mode {}'.format(mode))