Example #1
0
 def test_save_load_batch_norm(self):
     """Round-trip a batch-norm FullyConnectedActor through check_save_load.

     Builds an actor with an 8-dim state, a 4-dim action, two hidden layers
     (7 and 6 units, both ReLU) and batch norm enabled. It then puts the
     model in eval mode and passes it to ``check_save_load`` together with
     the expected parameter/input/output counts.
     """
     state_dim, action_dim = 8, 4
     actor = FullyConnectedActor(
         state_dim,
         action_dim,
         sizes=[7, 6],
         activations=["relu"] * 2,
         use_batch_norm=True,
     )
     # Eval mode freezes batch norm before the save/load round trip.
     actor.eval()
     # NOTE(review): the expected counts below (21 params, 1 input,
     # 1 output) are tied to this exact architecture; update them if
     # sizes/activations/use_batch_norm change.
     num_params, num_inputs, num_outputs = 21, 1, 1
     check_save_load(
         self,
         actor,
         num_params,
         num_inputs,
         num_outputs,
     )
Example #2
0
 def test_basic(self):
     """Smoke-test FullyConnectedActor's input prototype and forward pass.

     Checks two things:
     - ``input_prototype()`` yields ``float_features`` of shape
       ``(1, state_dim)``.
     - A forward pass in eval mode yields an ``action`` of shape
       ``(1, action_dim)``.
     """
     state_dim = 8
     action_dim = 4
     model = FullyConnectedActor(
         state_dim,
         action_dim,
         sizes=[7, 6],
         activations=["relu", "relu"],
         use_batch_norm=True,
     )
     # Named ``model_input`` rather than ``input`` to avoid shadowing the
     # ``input`` builtin.
     model_input = model.input_prototype()
     self.assertEqual((1, state_dim), model_input.float_features.shape)
     # The prototype holds a single example (batch dim of 1, asserted above).
     # Batch norm needs more than one example in training mode, so switch
     # to eval mode before the forward pass.
     model.eval()
     action = model(model_input)
     self.assertEqual((1, action_dim), action.action.shape)