def testCustomTransform(self):
    """A custom metric transform is applied before the linear weighting.

    The transform passes the 2nd and 4th objectives through a sigmoid and
    leaves the others unchanged. The scalarizer then applies the weights
    [1, 2, 3, -1].
    """
    apply_sigmoid = [False, True, False, True]
    num_metrics = len(apply_sigmoid)

    def selective_sigmoid(metrics: tf.Tensor):
        first_leaf = tf.nest.flatten(metrics)[0]
        num_rows = tf.shape(first_leaf)[0]
        # Repeat the per-metric mask once per row, then lay it out as
        # [num_rows, num_metrics] so it lines up element-wise with `metrics`
        # (assumes metrics is [batch_size, num_metrics]).
        repeated_mask = tf.tile(apply_sigmoid, [num_rows])
        row_mask = tf.reshape(repeated_mask, [num_rows, num_metrics])
        return tf.where(row_mask, tf.sigmoid(metrics), metrics)

    scalarizer = multi_objective_scalarizer.LinearScalarizer(
        [1, 2, 3, -1], selective_sigmoid)
    expected = [10.77958, 26.995390, -9.7795804]
    self.assertAllClose(scalarizer(self._batch_multi_objectives), expected)
 def testInvalidWeights(self):
     """LinearScalarizer rejects weight lists with fewer than two entries."""
     # Checked in order; the first case that fails to raise aborts the test.
     for too_few_weights in ([], [1]):
         with self.assertRaisesRegex(ValueError, 'at least two objectives'):
             multi_objective_scalarizer.LinearScalarizer(too_few_weights)
 def setUp(self):
     """Builds the fixtures shared by every test in this class.

     Sets:
       self._scalarizer: a LinearScalarizer with weights [1, 2, 3, -1].
       self._batch_multi_objectives: a 3x4 float32 constant; each row is one
         example with four objective values.
     """
     # Run the base-class setup first, before creating any instance state.
     super(LinearScalarizerTest, self).setUp()
     self._scalarizer = multi_objective_scalarizer.LinearScalarizer(
         [1, 2, 3, -1])
     # Kept in the single setUp: an earlier duplicate setUp left this out,
     # and testCustomTransform reads it.
     self._batch_multi_objectives = tf.constant(
         [[1, 2, 3, 4], [5, 6, 7, 8], [-1, -2, -3, -4]], dtype=tf.float32)