コード例 #1
0
 def _processed_labels(self, logits, labels):
     """Validates dense `labels` against `logits` and converts them to float.

     Args:
       logits: Logits tensor the labels are checked against.
       labels: Dense labels, checked and reshaped by
         `base_head.check_dense_labels_match_logits_and_reshape` using
         `self._logits_dimension` as the expected label dimension.

     Returns:
       The checked/reshaped labels passed through `math_ops.to_float`.
     """
     # NOTE(review): `math_ops.to_float` is deprecated in newer TensorFlow in
     # favor of an explicit cast to float32; consider migrating once the
     # module's available imports (e.g. `dtypes`) are confirmed.
     checked_labels = base_head.check_dense_labels_match_logits_and_reshape(
         labels=labels,
         logits=logits,
         expected_labels_dimension=self._logits_dimension,
     )
     return math_ops.to_float(checked_labels)
コード例 #2
0
 def _processed_labels(self, logits, labels):
     """Returns dense `labels`, validated against `logits`, as float32.

     Args:
       logits: Logits tensor the labels are checked against.
       labels: Dense labels; `self._logits_dimension` is used as the expected
         label dimension during the check/reshape.

     Returns:
       The checked/reshaped labels cast to `tf.dtypes.float32`.
     """
     return tf.cast(
         base_head.check_dense_labels_match_logits_and_reshape(
             labels=labels,
             logits=logits,
             expected_labels_dimension=self._logits_dimension),
         dtype=tf.dtypes.float32)
コード例 #3
0
 def _processed_labels(self, logits, labels):
   """Maps labels to float32 class ids for a two-class head.

   The labels are first checked against `logits` with an expected label
   dimension of 1. When a label vocabulary is configured they are translated
   through `self._class_id_table`. The ids are then cast to float32 and
   range-checked for `n_classes=2`.

   Args:
     logits: Logits tensor the labels are checked against.
     labels: Raw labels; looked up in `self._class_id_table` only when
       `self._label_vocabulary` is not None.

   Returns:
     The float32 label ids returned by `base_head.check_label_range`.
   """
   dense_labels = base_head.check_dense_labels_match_logits_and_reshape(
       labels=labels, logits=logits, expected_labels_dimension=1)
   has_vocabulary = self._label_vocabulary is not None
   label_ids = (
       self._class_id_table.lookup(dense_labels) if has_vocabulary
       else dense_labels)
   # The ids are cast to float32 (not kept integer) before the range check.
   float_ids = math_ops.cast(label_ids, dtype=dtypes.float32)
   return base_head.check_label_range(float_ids, n_classes=2)
コード例 #4
0
ファイル: multi_class_head.py プロジェクト: AdiosSora/FOCUS
 def _processed_labels(self, logits, labels):
   """Validates `labels` and returns them as class ids.

   Args:
     logits: Logits tensor the labels are checked against (expected label
       dimension 1).
     labels: Integer class ids when no label vocabulary is configured,
       otherwise string labels mapped through `self._class_id_table`.

   Returns:
     The class ids after `base_head.check_label_range` with
     `self._n_classes`.

   Raises:
     ValueError: If `labels` are not integers when there is no vocabulary,
       or not strings when there is one.
   """
   labels = base_head.check_dense_labels_match_logits_and_reshape(
       labels=labels, logits=logits, expected_labels_dimension=1)
   label_dtype = labels.dtype
   if self._label_vocabulary is not None:
     # With a vocabulary, only string labels can be looked up.
     if label_dtype != tf.dtypes.string:
       raise ValueError('Labels dtype should be string if there is a '
                        'vocabulary. Instead got {}'.format(label_dtype))
     label_ids = self._class_id_table.lookup(labels)
   elif label_dtype.is_integer:
     # Without a vocabulary, integer labels are used directly as class ids.
     label_ids = labels
   else:
     raise ValueError(
         'Labels dtype should be integer. Instead got {}.'.format(
             label_dtype))
   return base_head.check_label_range(label_ids, self._n_classes)
コード例 #5
0
    def _processed_labels(self, logits, labels):
        """Converts multi-label `labels` into an indicator tensor.

        Handling depends on the form of `labels`:

        * `SparseTensor` with string values: the values are mapped through
          `self._class_id_table` and turned into an indicator with
          `sparse_ops.sparse_to_indicator`.
        * `SparseTensor` with integer values: the values are range-checked
          against `self._n_classes` before the indicator is built.
        * Dense tensor: range-checked as a 0/1 indicator.

        The sparse result is cast to int64. The dense result is not cast.
        Either way, the final tensor is checked against `logits` with
        `self.logits_dimension` as the expected label dimension.

        Args:
          logits: Logits tensor the processed labels are checked against.
          labels: A `SparseTensor` of string or integer class labels, or a
            dense indicator tensor. Must not be None.

        Returns:
          The indicator labels returned by
          `base_head.check_dense_labels_match_logits_and_reshape`.

        Raises:
          ValueError: If `labels` is None, or if a non-string `SparseTensor`
            does not hold integer values.
        """
        if labels is None:
            raise ValueError(base_head._LABEL_NONE_ERR_MSG)  # pylint:disable=protected-access

        if not isinstance(labels, sparse_tensor.SparseTensor):
            # Dense input is expected to already be a 0/1 indicator tensor.
            indicator = base_head.check_label_range(
                labels,
                2,
                message=(
                    r'labels must be an integer indicator Tensor with values in [0, 1]'
                ))
        else:
            values = labels.values
            if labels.dtype == dtypes.string:
                # Map string labels to ids while keeping the sparse layout.
                looked_up_ids = self._class_id_table.lookup(values)
                id_labels = sparse_tensor.SparseTensor(
                    indices=labels.indices,
                    values=looked_up_ids,
                    dense_shape=labels.dense_shape)
                indicator = sparse_ops.sparse_to_indicator(
                    id_labels, self._n_classes)
            else:
                if not values.dtype.is_integer:
                    raise ValueError(
                        'Labels dtype should be integer. Instead got {}.'.
                        format(values.dtype))
                range_msg = (
                    r'labels must be an integer SparseTensor with values in '
                    r'[0, {})'.format(self._n_classes))
                checked_values = base_head.check_label_range(
                    labels.values, self._n_classes, message=range_msg)
                if context.executing_eagerly():
                    indicator = sparse_ops.sparse_to_indicator(
                        labels, self._n_classes)
                else:
                    # In graph mode, tie the indicator to the range check via
                    # a control dependency so the check is not skipped.
                    with ops.control_dependencies([checked_values]):
                        indicator = sparse_ops.sparse_to_indicator(
                            labels, self._n_classes)
            indicator = math_ops.cast(indicator, dtype=dtypes.int64)

        return base_head.check_dense_labels_match_logits_and_reshape(
            labels=indicator,
            logits=logits,
            expected_labels_dimension=self.logits_dimension)