def test_tensor_shape_checking_in_graph_mode(self):
    """Exercises call_loss_fn with labels/logits of partially defined shape.

    Both placeholders are declared with shape (None, 1), so the batch
    dimension is unknown while the graph is built. The custom loss function
    asserts that it sees exactly the fed values, then returns a fixed loss
    that is compared against the evaluated output of call_loss_fn.
    """
    labels_ph = tf.compat.v1.placeholder(
        dtype=tf.dtypes.float32, shape=(None, 1))
    logits_ph = tf.compat.v1.placeholder(
        dtype=tf.dtypes.float32, shape=(None, 1))
    fed_labels = np.array([[-10.], [10.]], dtype=np.float32)
    fed_logits = np.array([[1.], [0.]], dtype=np.float32)
    expected_loss = np.array([[1.], [2.]], dtype=np.float32)

    def _assert_all_equal(tensor, reference):
        # Builds an assertion op that fails unless every element of `tensor`
        # equals `reference`; `tensor` is attached as the assertion's data.
        return tf.debugging.Assert(
            tf.reduce_all(tf.math.equal(tensor, reference)), data=[tensor])

    def _loss_fn(labels, logits):
        label_check = _assert_all_equal(labels, fed_labels)
        logit_check = _assert_all_equal(logits, fed_logits)
        # The returned constant is created under control dependencies on both
        # checks, tying the assertions to the evaluation of the loss.
        with tf.control_dependencies([label_check, logit_check]):
            return tf.constant(expected_loss)

    unweighted_loss = base_head.call_loss_fn(
        loss_fn=_loss_fn,
        labels=labels_ph,
        logits=logits_ph,
        features={'x': np.array(((42,),), dtype=np.int32)})
    feed_dict = {labels_ph: fed_labels, logits_ph: fed_logits}
    with self.cached_session():
        self.assertAllClose(unweighted_loss.eval(feed_dict), expected_loss)
def _unweighted_loss_and_weights(self, logits, labels, features):
    """Returns the unweighted loss together with the example weights.

    Without a custom `self._loss_fn`, the loss is
    `nn.sigmoid_cross_entropy_with_logits` of `labels` and `logits`.
    Otherwise the custom function is invoked via `base_head.call_loss_fn`
    with `expected_loss_dim=1`.

    Args:
      logits: Logits tensor.
      labels: Labels tensor.
      features: Input features; forwarded to the custom loss function (if
        any) and to `base_head.get_weights_and_check_match_logits`.

    Returns:
      A tuple `(unweighted_loss, weights)`.
    """
    if not self._loss_fn:
        per_example_loss = nn.sigmoid_cross_entropy_with_logits(
            labels=labels, logits=logits)
    else:
        per_example_loss = base_head.call_loss_fn(
            loss_fn=self._loss_fn,
            labels=labels,
            logits=logits,
            features=features,
            expected_loss_dim=1)
    example_weights = base_head.get_weights_and_check_match_logits(
        features=features, weight_column=self._weight_column, logits=logits)
    return per_example_loss, example_weights
def _unweighted_loss_and_weights(self, logits, labels, features):
    """Returns the unweighted loss together with the example weights.

    Without a custom `self._loss_fn`, the loss is the element-wise
    `losses.mean_squared_error` between `labels` and `logits`
    (reduction `NONE`). Otherwise the custom function is invoked via
    `base_head.call_loss_fn` with `expected_loss_dim=self._logits_dimension`.

    Args:
      logits: Logits (predictions) tensor.
      labels: Labels tensor.
      features: Input features; forwarded to the custom loss function (if
        any) and to `base_head.get_weights_and_check_match_logits`.

    Returns:
      A tuple `(unweighted_loss, weights)`. Weights are looked up with
      `allow_per_logit_weights=True`.
    """
    if not self._loss_fn:
        per_example_loss = losses.mean_squared_error(
            labels=labels,
            predictions=logits,
            reduction=losses.Reduction.NONE)
    else:
        per_example_loss = base_head.call_loss_fn(
            loss_fn=self._loss_fn,
            labels=labels,
            logits=logits,
            features=features,
            expected_loss_dim=self._logits_dimension)
    example_weights = base_head.get_weights_and_check_match_logits(
        features=features,
        weight_column=self._weight_column,
        logits=logits,
        allow_per_logit_weights=True)
    return per_example_loss, example_weights
def _unweighted_loss_and_weights(self, logits, label_ids, features):
    """Returns the unweighted loss together with the example weights.

    Without a custom `self._loss_fn`, the loss is
    `tf.compat.v1.losses.sparse_softmax_cross_entropy` (reduction `NONE`).
    A trailing dimension is then re-added via `expand_dims(axis=-1)`.
    Otherwise the custom function is invoked via `base_head.call_loss_fn`
    with `expected_loss_dim=1`.

    Args:
      logits: Logits tensor.
      label_ids: Label id tensor passed as `labels` to the loss.
      features: Input features; forwarded to the custom loss function (if
        any) and to `base_head.get_weights_and_check_match_logits`.

    Returns:
      A tuple `(unweighted_loss, weights)`.
    """
    if not self._loss_fn:
        # The sparse softmax loss drops the last dim; re-adding it keeps the
        # loss shape aligned with the weights returned below.
        per_example_loss = tf.compat.v1.expand_dims(
            tf.compat.v1.losses.sparse_softmax_cross_entropy(
                labels=label_ids,
                logits=logits,
                reduction=tf.compat.v1.losses.Reduction.NONE),
            axis=-1)
    else:
        per_example_loss = base_head.call_loss_fn(
            loss_fn=self._loss_fn,
            labels=label_ids,
            logits=logits,
            features=features,
            expected_loss_dim=1)
    example_weights = base_head.get_weights_and_check_match_logits(
        features=features, weight_column=self._weight_column, logits=logits)
    return per_example_loss, example_weights
def _unweighted_loss_and_weights(self, logits, processed_labels, features):
    """Returns the unweighted loss together with the example weights.

    Without a custom `self._loss_fn`, the loss is
    `tf.compat.v1.losses.sigmoid_cross_entropy` (reduction `NONE`). It is
    then averaged over the last axis with `keepdims=True`. Otherwise the
    custom function is invoked via `base_head.call_loss_fn` with
    `expected_loss_dim=1`.

    Args:
      logits: Logits tensor.
      processed_labels: Labels tensor, passed as `multi_class_labels` to the
        built-in loss or as `labels` to the custom loss function.
      features: Input features; forwarded to the custom loss function (if
        any) and to `base_head.get_weights_and_check_match_logits`.

    Returns:
      A tuple `(unweighted_loss, weights)`.
    """
    if not self._loss_fn:
        # Average the per-class losses into one value per example, keeping
        # the reduced axis.
        per_example_loss = tf.math.reduce_mean(
            tf.compat.v1.losses.sigmoid_cross_entropy(
                multi_class_labels=processed_labels,
                logits=logits,
                reduction=tf.compat.v1.losses.Reduction.NONE),
            axis=-1,
            keepdims=True)
    else:
        per_example_loss = base_head.call_loss_fn(
            loss_fn=self._loss_fn,
            labels=processed_labels,
            logits=logits,
            features=features,
            expected_loss_dim=1)
    example_weights = base_head.get_weights_and_check_match_logits(
        features=features, weight_column=self._weight_column, logits=logits)
    return per_example_loss, example_weights