Example 1
    def test_tensor_shape_checking_in_graph_mode(self):
        """Test for shape checking of tensor with partially defined shape."""
        labels_placeholder = tf.compat.v1.placeholder(dtype=tf.dtypes.float32,
                                                      shape=(None, 1))
        logits_placeholder = tf.compat.v1.placeholder(dtype=tf.dtypes.float32,
                                                      shape=(None, 1))
        labels_input = np.array([[-10.], [10.]], dtype=np.float32)
        logits_input = np.array([[1.], [0.]], dtype=np.float32)

        loss = np.array([[1.], [2.]], dtype=np.float32)

        def _loss_fn(labels, logits):
            check_labels = tf.debugging.Assert(
                tf.reduce_all(tf.math.equal(labels, labels_input)), data=[labels])
            check_logits = tf.debugging.Assert(
                tf.reduce_all(tf.math.equal(logits, logits_input)), data=[logits])
            with tf.control_dependencies([check_labels, check_logits]):
                return tf.constant(loss)

        unweighted_loss = base_head.call_loss_fn(
            loss_fn=_loss_fn,
            labels=labels_placeholder,
            logits=logits_placeholder,
            features={'x': np.array(((42, ), ), dtype=np.int32)})
        with self.cached_session():
            self.assertAllClose(
                unweighted_loss.eval({
                    labels_placeholder: labels_input,
                    logits_placeholder: logits_input
                }), loss)
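
Example 1 exercises the contract that call_loss_fn appears to enforce: invoke the user-supplied loss_fn (forwarding features only when the function declares that argument) and validate that the returned loss has shape [batch_size, expected_loss_dim]. The sketch below reconstructs that presumed contract for readability only; it is not the tensorflow_estimator implementation, and the name call_loss_fn_sketch is illustrative.

import inspect

import tensorflow as tf


def call_loss_fn_sketch(loss_fn, labels, logits, features, expected_loss_dim=1):
    """Calls loss_fn and checks the shape of the returned loss (illustrative)."""
    # Forward features only if the user-supplied loss_fn declares the argument.
    if 'features' in inspect.signature(loss_fn).parameters:
        unweighted_loss = loss_fn(labels=labels, logits=logits, features=features)
    else:
        unweighted_loss = loss_fn(labels=labels, logits=logits)
    # The loss is expected to have shape [batch_size, expected_loss_dim].
    expected_shape = tf.stack([tf.shape(logits)[0], expected_loss_dim])
    shape_check = tf.debugging.assert_equal(
        tf.shape(unweighted_loss), expected_shape,
        message='loss_fn must return a Tensor of shape [batch_size, %d]'
                % expected_loss_dim)
    with tf.control_dependencies([shape_check]):
        return tf.identity(unweighted_loss)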
Example 2
 def _unweighted_loss_and_weights(self, logits, labels, features):
   """Computes unweighted loss and weights."""
   if self._loss_fn:
     unweighted_loss = base_head.call_loss_fn(
         loss_fn=self._loss_fn, labels=labels, logits=logits,
         features=features, expected_loss_dim=1)
   else:
     unweighted_loss = nn.sigmoid_cross_entropy_with_logits(
         labels=labels, logits=logits)
   weights = base_head.get_weights_and_check_match_logits(
       features=features, weight_column=self._weight_column, logits=logits)
   return unweighted_loss, weights
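
In these heads the loss_fn reaching _unweighted_loss_and_weights comes from the head's constructor. A minimal usage sketch, assuming tf.estimator.BinaryClassHead accepts a loss_fn argument and that the custom function must return a per-example loss of shape [batch_size, 1], as validated above:

import tensorflow as tf


def custom_binary_loss(labels, logits):
    # Per-example loss of shape [batch_size, 1], matching expected_loss_dim=1.
    return tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.cast(labels, tf.float32), logits=logits)


# Sketch only: the head is assumed to route custom_binary_loss through
# base_head.call_loss_fn inside _unweighted_loss_and_weights.
head = tf.estimator.BinaryClassHead(loss_fn=custom_binary_loss)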
Example 3
 def _unweighted_loss_and_weights(self, logits, labels, features):
   """Computes loss spec."""
   if self._loss_fn:
     unweighted_loss = base_head.call_loss_fn(
         loss_fn=self._loss_fn, labels=labels, logits=logits,
         features=features, expected_loss_dim=self._logits_dimension)
   else:
     unweighted_loss = losses.mean_squared_error(
         labels=labels, predictions=logits, reduction=losses.Reduction.NONE)
   weights = base_head.get_weights_and_check_match_logits(
       features=features, weight_column=self._weight_column, logits=logits,
       allow_per_logit_weights=True)
   return unweighted_loss, weights
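
Note that the regression head validates against expected_loss_dim=self._logits_dimension, so a custom loss_fn there is expected to return one loss value per output dimension. A hedged sketch, assuming tf.estimator.RegressionHead exposes label_dimension and loss_fn as used in this example:

import tensorflow as tf


def per_logit_squared_error(labels, logits):
    # Shape [batch_size, label_dimension]: one loss per output dimension,
    # matching expected_loss_dim=self._logits_dimension above.
    return tf.math.squared_difference(labels, logits)


# Sketch only, not the library's documented example.
head = tf.estimator.RegressionHead(
    label_dimension=3, loss_fn=per_logit_squared_error)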
Example 4
 def _unweighted_loss_and_weights(self, logits, label_ids, features):
   """Computes loss spec."""
   if self._loss_fn:
     unweighted_loss = base_head.call_loss_fn(
         loss_fn=self._loss_fn,
         labels=label_ids,
         logits=logits,
         features=features,
         expected_loss_dim=1)
   else:
     unweighted_loss = tf.compat.v1.losses.sparse_softmax_cross_entropy(
         labels=label_ids,
         logits=logits,
         reduction=tf.compat.v1.losses.Reduction.NONE)
     # Restore the squeezed dim, so unweighted_loss matches the weights shape.
     unweighted_loss = tf.compat.v1.expand_dims(unweighted_loss, axis=-1)
   weights = base_head.get_weights_and_check_match_logits(
       features=features, weight_column=self._weight_column, logits=logits)
   return unweighted_loss, weights
Example 5
 def _unweighted_loss_and_weights(self, logits, processed_labels, features):
   """Computes loss spec."""
   if self._loss_fn:
     unweighted_loss = base_head.call_loss_fn(
         loss_fn=self._loss_fn,
         labels=processed_labels,
         logits=logits,
         features=features,
         expected_loss_dim=1)
   else:
     unweighted_loss = tf.compat.v1.losses.sigmoid_cross_entropy(
         multi_class_labels=processed_labels,
         logits=logits,
         reduction=tf.compat.v1.losses.Reduction.NONE)
     # Averages loss over classes.
     unweighted_loss = tf.math.reduce_mean(
         unweighted_loss, axis=-1, keepdims=True)
   weights = base_head.get_weights_and_check_match_logits(
       features=features, weight_column=self._weight_column, logits=logits)
   return unweighted_loss, weights
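
Examples 4 and 5 follow the same pattern for the multi-class and multi-label heads: when no custom loss_fn is given, the built-in loss is expanded or averaged so the result has shape [batch_size, 1], and a custom loss_fn is expected to return that shape directly. A hedged sketch for the multi-label case, assuming tf.estimator.MultiLabelHead accepts loss_fn as used in Example 5:

import tensorflow as tf


def mean_sigmoid_cross_entropy(labels, logits):
    # Per-class sigmoid cross-entropy averaged over classes, so the result
    # already has shape [batch_size, 1] (expected_loss_dim=1 above).
    per_class = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.cast(labels, tf.float32), logits=logits)
    return tf.math.reduce_mean(per_class, axis=-1, keepdims=True)


# Sketch only, not the library's documented example.
head = tf.estimator.MultiLabelHead(
    n_classes=5, loss_fn=mean_sigmoid_cross_entropy)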