def _compute_loss(self, prediction_tensor, target_tensor, weights,
                  class_indices=None):
    """Compute the weighted per-entry sigmoid cross-entropy loss.

    Args:
      prediction_tensor: A float tensor of shape [batch_size, num_anchors,
        num_classes] holding the predicted logits for each class.
      target_tensor: A float tensor of shape [batch_size, num_anchors,
        num_classes] holding one-hot encoded classification targets.
      weights: A float tensor of shape [batch_size, num_anchors].
      class_indices: (Optional) A 1-D integer tensor of class indices. When
        given, the weights are additionally multiplied by the per-class
        vector built from these indices, so only those classes contribute.

    Returns:
      loss: A float tensor of shape [batch_size, num_anchors, num_classes]
        containing the weighted loss for every entry.
    """
    # Add a trailing class axis so the per-anchor weights broadcast over
    # the class dimension of the cross-entropy tensor.
    anchor_weights = tf.expand_dims(weights, 2)

    if class_indices is not None:
        num_classes = tf.shape(prediction_tensor)[2]
        class_mask = ops.indices_to_dense_vector(class_indices, num_classes)
        # Shape the mask as [1, 1, num_classes] so it applies to every anchor.
        anchor_weights = anchor_weights * tf.reshape(class_mask, [1, 1, -1])

    cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=target_tensor, logits=prediction_tensor)
    return cross_entropy * anchor_weights
def subsample_indicator(indicator, num_samples):
    """Randomly keep at most `num_samples` of the `True` entries.

    Given a boolean indicator vector with M elements set to `True`, all but
    `num_samples` of those elements are switched to `False`. When
    `num_samples` is greater than M, the original indicator vector is
    returned.

    Args:
      indicator: a 1-dimensional boolean tensor marking which elements may be
        sampled.
      num_samples: int32 scalar tensor.

    Returns:
      a boolean tensor with the same shape as the input (indicator) tensor.
    """
    # Positions of the True entries, in random order, flattened to 1-D.
    candidate_indices = tf.reshape(
        tf.random_shuffle(tf.where(indicator)), [-1])

    # Never request more entries than there are candidates.
    sample_count = tf.minimum(tf.size(candidate_indices), num_samples)
    kept_indices = tf.slice(
        candidate_indices, [0], tf.reshape(sample_count, [1]))

    # Scatter the kept positions back into a dense vector and compare
    # against 1 to get the boolean result.
    dense_selection = ops.indices_to_dense_vector(
        kept_indices, tf.shape(indicator)[0])
    return tf.equal(dense_selection, 1)
def test_indices_to_dense_vector_empty_indices_as_input(self):
    """An empty index list must produce an all-zero float32 vector."""
    vector_size = 500
    no_indices = []
    want = np.zeros(vector_size, dtype=np.float32)

    dense = ops.indices_to_dense_vector(tf.constant(no_indices), vector_size)

    with self.test_session() as sess:
        got = sess.run(dense)
        self.assertAllEqual(got, want)
        self.assertEqual(got.dtype, want.dtype)
def test_indices_to_dense_vector_all_indices_as_input(self):
    """Selecting every index must produce an all-ones float32 vector."""
    vector_size = 500
    index_count = 500
    # A shuffled permutation of every valid index, so order is arbitrary.
    shuffled = np.random.permutation(np.arange(vector_size))
    chosen = shuffled[:index_count]
    want = np.ones(vector_size, dtype=np.float32)

    dense = ops.indices_to_dense_vector(tf.constant(chosen), vector_size)

    with self.test_session() as sess:
        got = sess.run(dense)
        self.assertAllEqual(got, want)
        self.assertEqual(got.dtype, want.dtype)
def test_indices_to_dense_vector_int(self):
    """Requesting dtype int64 must yield an int64 vector of zeros and ones."""
    vector_size = 500
    index_count = 25
    shuffled = np.random.permutation(np.arange(vector_size))
    chosen = shuffled[:index_count]

    want = np.zeros(vector_size, dtype=np.int64)
    want[chosen] = 1

    dense = ops.indices_to_dense_vector(
        tf.constant(chosen), vector_size, 1, dtype=tf.int64)

    with self.test_session() as sess:
        got = sess.run(dense)
        self.assertAllEqual(got, want)
        self.assertEqual(got.dtype, want.dtype)
def test_indices_to_dense_vector_size_at_inference(self):
    """The output size may be a tensor whose value is only fed at run time."""
    vector_size = 5000
    index_count = 250
    every_index = np.arange(vector_size)
    chosen = np.random.permutation(every_index)[:index_count]

    want = np.zeros(vector_size, dtype=np.float32)
    want[chosen] = 1.

    # Only the length of the fed array matters: the size passed to the op is
    # read from the placeholder's dynamic shape.
    size_source = tf.placeholder(tf.int32)
    chosen_tensor = tf.constant(chosen)
    dense = ops.indices_to_dense_vector(
        chosen_tensor, tf.shape(size_source)[0])

    with self.test_session() as sess:
        got = sess.run(dense, feed_dict={size_source: every_index})
        self.assertAllEqual(got, want)
        self.assertEqual(got.dtype, want.dtype)
def _compute_loss(self, prediction_tensor, target_tensor, weights,
                  class_indices=None):
    """Compute the weighted focal sigmoid cross-entropy loss.

    Args:
      prediction_tensor: A float tensor of shape [batch_size, num_anchors,
        num_classes] holding the predicted logits for each class.
      target_tensor: A float tensor of shape [batch_size, num_anchors,
        num_classes] holding one-hot encoded classification targets.
      weights: A float tensor of shape either [batch_size, num_anchors,
        num_classes] or [batch_size, num_anchors, 1]. With the latter shape
        all the classes are equally weighted.
      class_indices: (Optional) A 1-D integer tensor of class indices. When
        given, the weights are additionally multiplied by the per-class
        vector built from these indices, so only those classes contribute.

    Returns:
      loss: A float tensor of shape [batch_size, num_anchors, num_classes]
        containing the weighted loss for every entry.
    """
    if class_indices is not None:
        num_classes = tf.shape(prediction_tensor)[2]
        class_mask = ops.indices_to_dense_vector(class_indices, num_classes)
        # Shape the mask as [1, 1, num_classes] so it applies to every anchor.
        weights *= tf.reshape(class_mask, [1, 1, -1])

    cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=target_tensor, logits=prediction_tensor)

    # p_t is the probability the model assigns to the target value of each
    # entry: p where the target is 1, (1 - p) where the target is 0.
    probabilities = tf.sigmoid(prediction_tensor)
    prob_where_positive = target_tensor * probabilities
    prob_where_negative = (1 - target_tensor) * (1 - probabilities)
    p_t = prob_where_positive + prob_where_negative

    # (1 - p_t) ** gamma; skipped when gamma is unset or zero.
    if self._gamma:
        modulating_factor = tf.pow(1.0 - p_t, self._gamma)
    else:
        modulating_factor = 1.0

    # alpha where the target is 1, (1 - alpha) where it is 0; skipped when
    # alpha is None.
    if self._alpha is not None:
        alpha_weight_factor = (target_tensor * self._alpha +
                               (1 - target_tensor) * (1 - self._alpha))
    else:
        alpha_weight_factor = 1.0

    focal_loss = modulating_factor * alpha_weight_factor * cross_entropy
    return focal_loss * weights
def test_indices_to_dense_vector_custom_values(self):
    """Custom `indices_value` and `default_value` must fill the vector."""
    vector_size = 100
    index_count = 10
    shuffled = np.random.permutation(np.arange(vector_size))
    chosen = shuffled[:index_count]

    # Random fill values: one for the selected positions, one for the rest.
    indices_value = np.random.rand(1)
    default_value = np.random.rand(1)

    want = (np.ones(vector_size) * default_value).astype(np.float32)
    want[chosen] = indices_value

    dense = ops.indices_to_dense_vector(
        tf.constant(chosen),
        vector_size,
        indices_value=indices_value,
        default_value=default_value)

    with self.test_session() as sess:
        got = sess.run(dense)
        self.assertAllClose(got, want)
        self.assertEqual(got.dtype, want.dtype)