def test_discrete_action(self):
    """Check input-prototype and Q-value output shapes of a dueling network.

    Builds a DuelingQNetwork with batch norm enabled, verifies the shape of
    the float features in its input prototype, then runs a forward pass in
    eval mode and verifies the shape of the resulting Q-values.
    """
    num_state_features = 8
    num_actions = 4
    layer_sizes = [num_state_features, 8, 4, num_actions]

    network = DuelingQNetwork(
        layers=layer_sizes,
        activations=["relu", "relu", "linear"],
        use_batch_norm=True,
    )

    prototype = network.input_prototype()
    self.assertEqual(
        (1, num_state_features),
        prototype.state.float_features.shape,
    )

    # Batch norm needs more than one example while training, and the
    # prototype holds a single row, so run the forward pass in eval mode.
    network.eval()
    output = network(prototype)
    self.assertEqual((1, num_actions), output.q_values.shape)
def test_save_load_discrete_action_batch_norm(self):
    """Run check_save_load on a dueling Q-network that is in eval mode.

    NOTE(review): the test name mentions batch norm, yet the model is
    constructed with ``use_batch_norm=False`` -- confirm whether this is
    intentional before relying on it as batch norm coverage.
    """
    num_state_features = 8
    num_actions = 4

    network = DuelingQNetwork(
        layers=[num_state_features, 8, 4, num_actions],
        activations=["relu", "relu", "linear"],
        use_batch_norm=False,
    )

    # Freeze batch norm by switching the model to eval mode.
    network.eval()

    # The expected parameter count matches the non-batch-norm case because
    # DuelingQNetwork always initializes its batch norm layer, even when the
    # layer is not used.
    expected_num_params = 22
    expected_num_inputs = 1
    expected_num_outputs = 1
    check_save_load(
        self,
        network,
        expected_num_params,
        expected_num_inputs,
        expected_num_outputs,
    )