def test_extract_sarsa_parametric_action(self):
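     """SARSA extraction with parametric actions: action and next_action are
     extracted as dense float feature matrices, and reward and
     action_probability are compared as column vectors. (Snippet context:
     these are test-case methods; `npt` is presumably numpy.testing and
     `ws`/`net` a Caffe2 workspace and net built by the test harness.)"""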
     extractor = TrainingFeatureExtractor(
         state_normalization_parameters=self.get_state_normalization_parameters(),
         action_normalization_parameters=self.get_action_normalization_parameters(),
         max_q_learning=False,
     )
     # Setup
     ws, net = self.create_ws_and_net(extractor)
     input_record = self.create_extra_input_record(net)
     self.setup_state_features(ws, input_record.state_features)
     self.setup_next_state_features(ws, input_record.next_state_features)
     self.setup_action_features(ws, input_record.action)
     self.setup_next_action_features(ws, input_record.next_action)
     reward = self.setup_reward(ws, input_record.reward)
     extra_data = self.setup_extra_data(ws, input_record)
     # Run
     ws.run(net)
     res = extractor.extract(ws, input_record, net.output_record())
     o = res.training_input
     npt.assert_array_equal(reward.reshape(-1, 1), o.reward.numpy())
     npt.assert_array_equal(
         extra_data.action_probability.reshape(-1, 1),
         res.extras.action_probability.numpy(),
     )
     npt.assert_array_equal(self.expected_action_features(),
                            o.action.float_features.numpy())
     npt.assert_array_equal(self.expected_next_action_features(),
                            o.next_action.float_features.numpy())
     npt.assert_array_equal(self.expected_state_features(),
                            o.state.float_features.numpy())
     npt.assert_array_equal(self.expected_next_state_features(),
                            o.next_state.float_features.numpy())
 def test_extract_max_q_discrete_action(self):
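     """Max-Q extraction with discrete actions: the action array passes
     through unchanged, and possible_next_actions is checked as a ragged
     (lengths, actions) record."""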
     extractor = TrainingFeatureExtractor(
         state_normalization_parameters=self.get_state_normalization_parameters(),
         max_q_learning=True,
     )
     # Setup
     ws, net = self.create_ws_and_net(extractor)
     input_record = self.create_extra_input_record(net)
     self.setup_state_features(ws, input_record.state_features)
     self.setup_next_state_features(ws, input_record.next_state_features)
     action = self.setup_action(ws, input_record.action)
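     # setup_possible_next_actions presumably returns a (lengths, actions)
     # pair for the ragged possible-next-action sets; both halves are
     # asserted against the extracted record below.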
     possible_next_actions = self.setup_possible_next_actions(
         ws, input_record.possible_next_actions)
     reward = self.setup_reward(ws, input_record.reward)
     extra_data = self.setup_extra_data(ws, input_record)
     # Run
     ws.run(net)
     res = extractor.extract(ws, input_record, net.output_record())
     o = res.training_input
     npt.assert_array_equal(reward.reshape(-1, 1), o.reward.numpy())
     npt.assert_array_equal(
         extra_data.action_probability.reshape(-1, 1),
         res.extras.action_probability.numpy(),
     )
     npt.assert_array_equal(action, o.action.numpy())
     npt.assert_array_equal(possible_next_actions[0],
                            o.possible_next_actions.lengths.numpy())
     npt.assert_array_equal(possible_next_actions[1],
                            o.possible_next_actions.actions.numpy())
     npt.assert_array_equal(self.expected_state_features(),
                            o.state.float_features.numpy())
     npt.assert_array_equal(self.expected_next_state_features(),
                            o.next_state.float_features.numpy())
# Example 3
 def test_extract_sarsa_parametric_action(self):
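     """Variant of the SARSA parametric-action test above that appends a
     reward scalar to the net's input record directly instead of calling
     create_extra_input_record, so no extras are asserted and reward is
     compared without reshaping."""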
     extractor = TrainingFeatureExtractor(
         state_normalization_parameters=self.get_state_normalization_parameters(),
         action_normalization_parameters=self.get_action_normalization_parameters(),
         max_q_learning=False,
     )
     # Setup
     ws, net = self.create_ws_and_net(extractor)
     input_record = net.input_record() + schema.NewRecord(
         net, schema.Struct(("reward", schema.Scalar()))
     )
     self.setup_state_features(ws, input_record.state_features)
     self.setup_next_state_features(ws, input_record.next_state_features)
     self.setup_action_features(ws, input_record.action)
     self.setup_next_action_features(ws, input_record.next_action)
     reward = self.setup_reward(ws, input_record.reward)
     # Run
     ws.run(net)
     res = extractor.extract(ws, input_record, net.output_record())
     o = res.training_input
     npt.assert_array_equal(reward, o.reward.numpy())
     npt.assert_array_equal(
         self.expected_action_features(), o.action.float_features.numpy()
     )
     npt.assert_array_equal(
         self.expected_next_action_features(), o.next_action.float_features.numpy()
     )
     npt.assert_array_equal(
         self.expected_state_features(), o.state.float_features.numpy()
     )
     npt.assert_array_equal(
         self.expected_next_state_features(), o.next_state.float_features.numpy()
     )
# Example 4
 def test_extract_max_q_discrete_action(self):
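     """Variant of the max-Q discrete-action test above with a manually
     extended input record and no extras assertions."""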
     extractor = TrainingFeatureExtractor(
         state_normalization_parameters=self.get_state_normalization_parameters(),
         max_q_learning=True,
     )
     # Setup
     ws, net = self.create_ws_and_net(extractor)
     input_record = net.input_record() + schema.NewRecord(
         net, schema.Struct(("reward", schema.Scalar()))
     )
     self.setup_state_features(ws, input_record.state_features)
     self.setup_next_state_features(ws, input_record.next_state_features)
     action = self.setup_action(ws, input_record.action)
     possible_next_actions = self.setup_possible_next_actions(
         ws, input_record.possible_next_actions
     )
     reward = self.setup_reward(ws, input_record.reward)
     # Run
     ws.run(net)
     res = extractor.extract(ws, input_record, net.output_record())
     o = res.training_input
     npt.assert_array_equal(reward, o.reward.numpy())
     npt.assert_array_equal(action, o.action.numpy())
     npt.assert_array_equal(
         possible_next_actions[0], o.possible_next_actions.lengths.numpy()
     )
     npt.assert_array_equal(
         possible_next_actions[1], o.possible_next_actions.actions.numpy()
     )
     npt.assert_array_equal(
         self.expected_state_features(), o.state.float_features.numpy()
     )
     npt.assert_array_equal(
         self.expected_next_state_features(), o.next_state.float_features.numpy()
     )
 def _test_extract_sarsa_parametric_action(self, normalize):
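     """Parametrized SARSA test with parametric actions: float features are
     compared with assert_allclose (rtol=1e-6) against expected values under
     the given normalize flag; not_terminal and time_diff are checked on the
     training input, and mdp_id, sequence_number, and action_probability on
     the extras."""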
     extractor = TrainingFeatureExtractor(
         state_normalization_parameters=self.get_state_normalization_parameters(),
         action_normalization_parameters=self.get_action_normalization_parameters(),
         include_possible_actions=False,
         normalize=normalize,
     )
     # Setup
     ws, net = self.create_ws_and_net(extractor)
     input_record = self.create_extra_input_record(net)
     self.setup_state_features(ws, input_record.state_features)
     self.setup_next_state_features(ws, input_record.next_state_features)
     self.setup_action_features(ws, input_record.action)
     self.setup_next_action_features(ws, input_record.next_action)
     reward = self.setup_reward(ws, input_record.reward)
     not_terminal = self.setup_not_terminal(ws, input_record.not_terminal)
     time_diff = self.setup_time_diff(ws, input_record.time_diff)
     mdp_id = self.setup_mdp_id(ws, input_record.mdp_id)
     sequence_number = self.setup_seq_num(ws, input_record.sequence_number)
     extra_data = self.setup_extra_data(ws, input_record)
     # Run
     ws.run(net)
     res = extractor.extract(ws, input_record, net.output_record())
     o = res.training_input
     e = res.extras
     npt.assert_array_equal(reward.reshape(-1, 1), o.reward.numpy())
     npt.assert_array_equal(time_diff.reshape(-1, 1), o.time_diff.numpy())
     npt.assert_array_equal(not_terminal.reshape(-1, 1),
                            o.not_terminal.numpy())
     npt.assert_array_equal(sequence_number.reshape(-1, 1),
                            e.sequence_number.numpy())
     npt.assert_array_equal(mdp_id.reshape(-1, 1), e.mdp_id)
     npt.assert_array_equal(
         extra_data.action_probability.reshape(-1, 1),
         res.extras.action_probability.numpy(),
     )
     npt.assert_allclose(
         self.expected_action_features(normalize),
         o.action.float_features.numpy(),
         rtol=1e-6,
     )
     npt.assert_allclose(
         self.expected_next_action_features(normalize),
         o.next_action.float_features.numpy(),
         rtol=1e-6,
     )
     npt.assert_allclose(
         self.expected_state_features(normalize),
         o.state.float_features.numpy(),
         rtol=1e-6,
     )
     npt.assert_allclose(
         self.expected_next_state_features(normalize),
         o.next_state.float_features.numpy(),
         rtol=1e-6,
     )
# Example 6
 def _test_extract_max_q_discrete_action(self, normalize):
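     """Parametrized multi-step max-Q test (max_num_actions=2, multi_steps=3):
     checks action/next_action pass-through, the flattened possible-action
     masks, and the per-transition step counts."""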
     extractor = TrainingFeatureExtractor(
         state_normalization_parameters=self.get_state_normalization_parameters(),
         include_possible_actions=True,
         normalize=normalize,
         max_num_actions=2,
         multi_steps=3,
     )
     # Setup
     ws, net = self.create_ws_and_net(extractor)
     input_record = self.create_extra_input_record(net)
     self.setup_state_features(ws, input_record.state_features)
     self.setup_next_state_features(ws, input_record.next_state_features)
     action = self.setup_action(ws, input_record.action)
     next_action = self.setup_next_action(ws, input_record.next_action)
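     # The mask setup helpers presumably return (lengths, values) pairs; only
     # the values (index 1) are compared against the flattened masks below.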
     possible_actions_mask = self.setup_possible_actions_mask(
         ws, input_record.possible_actions_mask)
     possible_next_actions_mask = self.setup_possible_next_actions_mask(
         ws, input_record.possible_next_actions_mask)
     reward = self.setup_reward(ws, input_record.reward)
     not_terminal = self.setup_not_terminal(ws, input_record.not_terminal)
     time_diff = self.setup_time_diff(ws, input_record.time_diff)
     step = self.setup_step(ws, input_record.step)
     extra_data = self.setup_extra_data(ws, input_record)
     # Run
     ws.run(net)
     res = extractor.extract(ws, input_record, net.output_record())
     o = res.training_input
     npt.assert_array_equal(reward.reshape(-1, 1), o.reward.numpy())
     npt.assert_array_equal(time_diff.reshape(-1, 1), o.time_diff.numpy())
     npt.assert_array_equal(not_terminal.reshape(-1, 1),
                            o.not_terminal.numpy())
     npt.assert_array_equal(step.reshape(-1, 1), o.step.numpy())
     npt.assert_array_equal(
         extra_data.action_probability.reshape(-1, 1),
         res.extras.action_probability.numpy(),
     )
     npt.assert_array_equal(action, o.action.numpy())
     npt.assert_array_equal(next_action, o.next_action.numpy())
     npt.assert_array_equal(possible_actions_mask[1],
                            o.possible_actions_mask.numpy().flatten())
     npt.assert_array_equal(
         possible_next_actions_mask[1],
         o.possible_next_actions_mask.numpy().flatten(),
     )
     npt.assert_allclose(
         self.expected_state_features(normalize),
         o.state.float_features.numpy(),
         rtol=1e-6,
     )
     npt.assert_allclose(
         self.expected_next_state_features(normalize),
         o.next_state.float_features.numpy(),
         rtol=1e-6,
     )
 def _test_extract_max_q_discrete_action(self, normalize):
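     """Variant of the multi-step max-Q test above without the time_diff
     setup and assertion."""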
     extractor = TrainingFeatureExtractor(
         state_normalization_parameters=self.get_state_normalization_parameters(),
         include_possible_actions=True,
         normalize=normalize,
         max_num_actions=2,
         multi_steps=3,
     )
     # Setup
     ws, net = self.create_ws_and_net(extractor)
     input_record = self.create_extra_input_record(net)
     self.setup_state_features(ws, input_record.state_features)
     self.setup_next_state_features(ws, input_record.next_state_features)
     action = self.setup_action(ws, input_record.action)
     next_action = self.setup_next_action(ws, input_record.next_action)
     possible_actions_mask = self.setup_possible_actions_mask(
         ws, input_record.possible_actions_mask
     )
     possible_next_actions_mask = self.setup_possible_next_actions_mask(
         ws, input_record.possible_next_actions_mask
     )
     reward = self.setup_reward(ws, input_record.reward)
     not_terminal = self.setup_not_terminal(ws, input_record.not_terminal)
     step = self.setup_step(ws, input_record.step)
     extra_data = self.setup_extra_data(ws, input_record)
     # Run
     ws.run(net)
     res = extractor.extract(ws, input_record, net.output_record())
     o = res.training_input
     npt.assert_array_equal(reward.reshape(-1, 1), o.reward.numpy())
     npt.assert_array_equal(not_terminal.reshape(-1, 1), o.not_terminal.numpy())
     npt.assert_array_equal(step.reshape(-1, 1), o.step.numpy())
     npt.assert_array_equal(
         extra_data.action_probability.reshape(-1, 1),
         res.extras.action_probability.numpy(),
     )
     npt.assert_array_equal(action, o.action.numpy())
     npt.assert_array_equal(next_action, o.next_action.numpy())
     npt.assert_array_equal(
         possible_actions_mask[1], o.possible_actions_mask.numpy().flatten()
     )
     npt.assert_array_equal(
         possible_next_actions_mask[1],
         o.possible_next_actions_mask.numpy().flatten(),
     )
     npt.assert_allclose(
         self.expected_state_features(normalize),
         o.state.float_features.numpy(),
         rtol=1e-6,
     )
     npt.assert_allclose(
         self.expected_next_state_features(normalize),
         o.next_state.float_features.numpy(),
         rtol=1e-6,
     )
 def _test_extract_sarsa_parametric_action(self, normalize):
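     """Variant of the parametrized SARSA parametric-action test above
     without the time_diff, mdp_id, and sequence_number checks."""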
     extractor = TrainingFeatureExtractor(
         state_normalization_parameters=self.get_state_normalization_parameters(),
         action_normalization_parameters=self.get_action_normalization_parameters(),
         include_possible_actions=False,
         normalize=normalize,
     )
     # Setup
     ws, net = self.create_ws_and_net(extractor)
     input_record = self.create_extra_input_record(net)
     self.setup_state_features(ws, input_record.state_features)
     self.setup_next_state_features(ws, input_record.next_state_features)
     self.setup_action_features(ws, input_record.action)
     self.setup_next_action_features(ws, input_record.next_action)
     reward = self.setup_reward(ws, input_record.reward)
     not_terminal = self.setup_not_terminal(ws, input_record.not_terminal)
     extra_data = self.setup_extra_data(ws, input_record)
     # Run
     ws.run(net)
     res = extractor.extract(ws, input_record, net.output_record())
     o = res.training_input
     npt.assert_array_equal(reward.reshape(-1, 1), o.reward.numpy())
     npt.assert_array_equal(not_terminal.reshape(-1, 1), o.not_terminal.numpy())
     npt.assert_array_equal(
         extra_data.action_probability.reshape(-1, 1),
         res.extras.action_probability.numpy(),
     )
     npt.assert_allclose(
         self.expected_action_features(normalize),
         o.action.float_features.numpy(),
         rtol=1e-6,
     )
     npt.assert_allclose(
         self.expected_next_action_features(normalize),
         o.next_action.float_features.numpy(),
         rtol=1e-6,
     )
     npt.assert_allclose(
         self.expected_state_features(normalize),
         o.state.float_features.numpy(),
         rtol=1e-6,
     )
     npt.assert_allclose(
         self.expected_next_state_features(normalize),
         o.next_state.float_features.numpy(),
         rtol=1e-6,
     )
# Example 9
 def _test_extract_sarsa_discrete_action(self, normalize):
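     """Parametrized SARSA test with discrete actions: action and next_action
     pass through unchanged (include_possible_actions=False), while state
     features are checked with assert_allclose under the normalize flag."""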
     extractor = TrainingFeatureExtractor(
         state_normalization_parameters=self.get_state_normalization_parameters(),
         include_possible_actions=False,
         normalize=normalize,
         max_num_actions=2,
     )
     # Setup
     ws, net = self.create_ws_and_net(extractor)
     input_record = self.create_extra_input_record(net)
     self.setup_state_features(ws, input_record.state_features)
     self.setup_next_state_features(ws, input_record.next_state_features)
     action = self.setup_action(ws, input_record.action)
     next_action = self.setup_next_action(ws, input_record.next_action)
     reward = self.setup_reward(ws, input_record.reward)
     not_terminal = self.setup_not_terminal(ws, input_record.not_terminal)
     extra_data = self.setup_extra_data(ws, input_record)
     # Run
     ws.run(net)
     res = extractor.extract(ws, input_record, net.output_record())
     o = res.training_input
     npt.assert_array_equal(reward.reshape(-1, 1), o.reward.numpy())
     npt.assert_array_equal(not_terminal.reshape(-1, 1),
                            o.not_terminal.numpy())
     npt.assert_array_equal(
         extra_data.action_probability.reshape(-1, 1),
         res.extras.action_probability.numpy(),
     )
     npt.assert_array_equal(action, o.action.numpy())
     npt.assert_array_equal(next_action, o.next_action.numpy())
     npt.assert_allclose(
         self.expected_state_features(normalize),
         o.state.float_features.numpy(),
         rtol=1e-6,
     )
     npt.assert_allclose(
         self.expected_next_state_features(normalize),
         o.next_state.float_features.numpy(),
         rtol=1e-6,
     )