def test_discrete_action(self):
     """Forward a DuelingQNetwork's input prototype and check the output shape.

     The network is built with batch norm enabled. The prototype holds a
     single example, so the model is switched to eval mode before the
     forward pass.
     """
     state_dim = 8
     action_dim = 4
     model = DuelingQNetwork(
         layers=[state_dim, 8, 4, action_dim],
         activations=["relu", "relu", "linear"],
         use_batch_norm=True,
     )
     # Named ``model_input`` rather than ``input`` to avoid shadowing the builtin.
     model_input = model.input_prototype()
     self.assertEqual((1, state_dim), model_input.state.float_features.shape)
     # Batch norm in training mode requires more than one example, and the
     # prototype checked above has a single row, so evaluate in eval mode.
     model.eval()
     q_values = model(model_input)
     self.assertEqual((1, action_dim), q_values.q_values.shape)
 def test_save_load_discrete_action_batch_norm(self):
     """Save/load round trip for a DuelingQNetwork with batch norm enabled.

     The model is put in eval mode to freeze batch norm. It is then passed to
     ``check_save_load`` together with the expected parameter, input and
     output counts.
     """
     state_dim = 8
     action_dim = 4
     model = DuelingQNetwork(
         layers=[state_dim, 8, 4, action_dim],
         activations=["relu", "relu", "linear"],
         # Must be True: this is the batch-norm variant of the save/load test.
         use_batch_norm=True,
     )
     # Freezing batch_norm
     model.eval()
     # The number of expected params is the same as without batch norm,
     # because DuelingQNetwork always initializes its batch norm layer even
     # when it doesn't use it.
     # NOTE(review): 22 relies on that claim; confirm if this count fails.
     expected_num_params, expected_num_inputs, expected_num_outputs = 22, 1, 1
     check_save_load(
         self, model, expected_num_params, expected_num_inputs, expected_num_outputs
     )