Exemplo n.º 1
0
 def _logistic_loss(labels, logits):
     """Return sigmoid cross-entropy between `labels` and `logits`.

     `labels` is first passed through `head_lib._assert_range` with
     `n_classes=2`, so it is expected to contain only values in [0, 1]
     (as the attached error message states). The range-checked labels are
     then fed to `nn.sigmoid_cross_entropy_with_logits`.

     Args:
       labels: Label tensor whose values should lie in [0, 1].
       logits: Logits tensor, forwarded unchanged.

     Returns:
       The output of `nn.sigmoid_cross_entropy_with_logits`; no reduction
       is applied here.
     """
     range_error = 'Labels must be in range [0, 1]'
     validated = head_lib._assert_range(  # pylint:disable=protected-access
         labels, n_classes=2, message=range_error)
     return nn.sigmoid_cross_entropy_with_logits(
         logits=logits, labels=validated)
Exemplo n.º 2
0
 def _process_labels(self, labels):
     """Validate `labels` and turn sparse labels into an int64 indicator.

     Handled inputs:

     * String `SparseTensor`: each value is mapped to a class id through a
       lookup table built from `self._label_vocabulary`. The ids are then
       expanded with `sparse_ops.sparse_to_indicator` to width
       `self._n_classes`.
     * Integer `SparseTensor`: values are asserted to be integers in
       `[0, self._n_classes)` before the same expansion.
     * Anything else: treated as a dense indicator tensor whose values are
       range-checked to `[0, 1]` via `head_lib._assert_range`.

     Args:
       labels: A `SparseTensor` of class names or class ids, or a dense
         indicator tensor.

     Returns:
       For sparse input, the indicator cast to int64. For dense input, the
       value returned by `head_lib._assert_range`.

     Raises:
       ValueError: If `labels` is None, or if `labels` is a string
         `SparseTensor` while `self._label_vocabulary` is None.
     """
     if labels is None:
         raise ValueError(
             'You must provide a labels Tensor. Given: None. '
             'Suggested troubleshooting steps: Check that your data contain '
             'your label feature. Check that your input_fn properly parses and '
             'returns labels.')
     if isinstance(labels, sparse_tensor.SparseTensor):
         if labels.dtype == dtypes.string:
             # Fail early with a clear message instead of the opaque
             # "'NoneType' object is not iterable" raised by tuple(None).
             if self._label_vocabulary is None:
                 raise ValueError(
                     'labels is a string SparseTensor, but no label '
                     'vocabulary was provided to map label strings to '
                     'class ids.')
             # NOTE(review): strings absent from the vocabulary presumably
             # get the lookup table's default id (-1 by default per TF docs),
             # which sparse_to_indicator should reject at run time - confirm.
             label_ids_values = lookup_ops.index_table_from_tensor(
                 vocabulary_list=tuple(self._label_vocabulary),
                 name='class_id_lookup').lookup(labels.values)
             label_ids = sparse_tensor.SparseTensor(
                 indices=labels.indices,
                 values=label_ids_values,
                 dense_shape=labels.dense_shape)
             return math_ops.to_int64(
                 sparse_ops.sparse_to_indicator(label_ids, self._n_classes))
         err_msg = (
             r'labels must be an integer SparseTensor with values in '
             r'[0, {})'.format(self._n_classes))
         assert_int = check_ops.assert_integer(labels.values,
                                               message=err_msg)
         # assert_integer runs first, so labels.dtype is integral here and
         # n_classes can be converted to the same dtype for the comparison.
         assert_less = check_ops.assert_less(
             labels.values,
             ops.convert_to_tensor(self._n_classes, dtype=labels.dtype),
             message=err_msg)
         assert_greater = check_ops.assert_non_negative(labels.values,
                                                        message=err_msg)
         # Make the indicator depend on the checks so they are evaluated
         # whenever the processed labels are.
         with ops.control_dependencies(
             [assert_int, assert_less, assert_greater]):
             return math_ops.to_int64(
                 sparse_ops.sparse_to_indicator(labels, self._n_classes))
     err_msg = (
         r'labels must be an integer indicator Tensor with values in [0, 1]'
     )
     # n_classes=2 restricts dense indicator values to [0, 1].
     return head_lib._assert_range(labels, 2, message=err_msg)  # pylint:disable=protected-access
Exemplo n.º 3
0
 def _process_labels(self, labels):
   """Validate `labels` and turn sparse labels into an int64 indicator.

   Handled inputs:

   * String `SparseTensor`: values are mapped to class ids with a lookup
     table built from `self._label_vocabulary`. They are then expanded with
     `sparse_ops.sparse_to_indicator` to width `self._n_classes`.
   * Integer `SparseTensor`: values are asserted to be integers in
     `[0, self._n_classes)` before the same expansion.
   * Anything else: treated as a dense indicator tensor whose values are
     range-checked to `[0, 1]` via `head_lib._assert_range`.

   Args:
     labels: A `SparseTensor` of class names or class ids, or a dense
       indicator tensor.

   Returns:
     For sparse input, the indicator cast to int64. For dense input, the
     value returned by `head_lib._assert_range`.

   Raises:
     ValueError: If `labels` is None, or if `labels` is a string
       `SparseTensor` while `self._label_vocabulary` is None.
   """
   if labels is None:
     raise ValueError(
         'You must provide a labels Tensor. Given: None. '
         'Suggested troubleshooting steps: Check that your data contain '
         'your label feature. Check that your input_fn properly parses and '
         'returns labels.')
   if isinstance(labels, sparse_tensor.SparseTensor):
     if labels.dtype == dtypes.string:
       # Fail early with a clear message instead of the opaque
       # "'NoneType' object is not iterable" raised by tuple(None).
       if self._label_vocabulary is None:
         raise ValueError(
             'labels is a string SparseTensor, but no label vocabulary was '
             'provided to map label strings to class ids.')
       # NOTE(review): strings absent from the vocabulary presumably get the
       # lookup table's default id (-1 by default per TF docs), which
       # sparse_to_indicator should reject at run time - confirm.
       label_ids_values = lookup_ops.index_table_from_tensor(
           vocabulary_list=tuple(self._label_vocabulary),
           name='class_id_lookup').lookup(labels.values)
       label_ids = sparse_tensor.SparseTensor(
           indices=labels.indices,
           values=label_ids_values,
           dense_shape=labels.dense_shape)
       return math_ops.to_int64(
           sparse_ops.sparse_to_indicator(label_ids, self._n_classes))
     err_msg = (
         r'labels must be an integer SparseTensor with values in '
         r'[0, {})'.format(self._n_classes))
     assert_int = check_ops.assert_integer(
         labels.values, message=err_msg)
     # assert_integer runs first, so labels.dtype is integral here and
     # n_classes can be converted to the same dtype for the comparison.
     assert_less = check_ops.assert_less(
         labels.values,
         ops.convert_to_tensor(self._n_classes, dtype=labels.dtype),
         message=err_msg)
     assert_greater = check_ops.assert_non_negative(
         labels.values, message=err_msg)
     # Make the indicator depend on the checks so they are evaluated
     # whenever the processed labels are.
     with ops.control_dependencies(
         [assert_int, assert_less, assert_greater]):
       return math_ops.to_int64(
           sparse_ops.sparse_to_indicator(labels, self._n_classes))
   err_msg = (
       r'labels must be an integer indicator Tensor with values in [0, 1]')
   # n_classes=2 restricts dense indicator values to [0, 1].
   return head_lib._assert_range(labels, 2, message=err_msg)  # pylint:disable=protected-access
Exemplo n.º 4
0
 def _logistic_loss(labels, logits):
   """Sigmoid cross-entropy loss for labels expected in [0, 1].

   Args:
     labels: Label tensor. It is validated via
       `head_lib._assert_range(..., n_classes=2)` with the message
       'Labels must be in range [0, 1]'.
     logits: Logits tensor, passed straight through.

   Returns:
     The output of `nn.sigmoid_cross_entropy_with_logits`; this function
     applies no reduction of its own.
   """
   # pylint:disable=protected-access
   checked = head_lib._assert_range(
       labels, n_classes=2, message='Labels must be in range [0, 1]')
   # pylint:enable=protected-access
   loss = nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=checked)
   return loss