def test_throws_exception_when_weights_key_is_missing(self):
   """Checks that a weights dict lacking a key present in the losses fails."""
   constant_losses = {"a": -15, "b": .4, "c": .3}
   # Deliberately provides no weight for the loss keyed "c".
   incomplete_weights = dict(a=tf.constant(0.5), b=tf.constant(0.3))
   with self.assertRaises(ValueError):
     scalarization.LinearlyScalarizedOptimizer(
         problem=ProblemWithConstantLosses(constant_losses),
         weights=incomplete_weights)
 def test_check_weighted_value_on_constant_losses(self):
   """Checks the scalarized loss equals the weighted sum of constant losses."""
   # Plain-float copies of the weights let the expected value be computed in
   # Python, independently of the TF graph under test.
   weight_values = {"a": 0.5, "b": 0.3, "c": 0.4}
   weights = {key: tf.constant(value) for key, value in weight_values.items()}
   losses = {"a": -15, "b": .4, "c": .3}
   optimizer = scalarization.LinearlyScalarizedOptimizer(
       problem=ProblemWithConstantLosses(losses), weights=weights)
   # Only the loss tensor is checked; the update op is discarded, so a zero
   # learning rate is sufficient here.
   loss, _ = optimizer.compute_train_loss_and_update_op(
       inputs={}, base_optimizer=tf.train.GradientDescentOptimizer(0.))
   expected_loss = sum(
       weight_values[key] * losses[key] for key in weight_values)
   with self.cached_session() as session:
     session.run(tf.initializers.global_variables())
     # Evaluate inside the same session whose variables were just initialized,
     # rather than relying on a session being re-acquired after this block.
     self.assertAllClose(session.run(loss), expected_loss)
 def test_exception_thrown_when_weights_is_of_invalid_type(self):
   """Checks that a `weights` argument of an unsupported type is rejected."""
   losses = {"a": -15, "b": .4, "c": .3}
   # Should fail as `weights` is neither a dict nor in the enum.
   with self.assertRaises(TypeError):
     scalarization.LinearlyScalarizedOptimizer(
         problem=ProblemWithConstantLosses(losses), weights=123)