def test_extract_sarsa_parametric_action(self):
    extractor = TrainingFeatureExtractor(
        state_normalization_parameters=self.get_state_normalization_parameters(),
        action_normalization_parameters=self.get_action_normalization_parameters(),
        max_q_learning=False,
    )
    # Setup
    ws, net = self.create_ws_and_net(extractor)
    input_record = self.create_extra_input_record(net)
    self.setup_state_features(ws, input_record.state_features)
    self.setup_next_state_features(ws, input_record.next_state_features)
    self.setup_action_features(ws, input_record.action)
    self.setup_next_action_features(ws, input_record.next_action)
    reward = self.setup_reward(ws, input_record.reward)
    extra_data = self.setup_extra_data(ws, input_record)
    # Run
    ws.run(net)
    res = extractor.extract(ws, input_record, net.output_record())
    o = res.training_input
    npt.assert_array_equal(reward.reshape(-1, 1), o.reward.numpy())
    npt.assert_array_equal(
        extra_data.action_probability.reshape(-1, 1),
        res.extras.action_probability.numpy(),
    )
    npt.assert_array_equal(
        self.expected_action_features(), o.action.float_features.numpy()
    )
    npt.assert_array_equal(
        self.expected_next_action_features(), o.next_action.float_features.numpy()
    )
    npt.assert_array_equal(
        self.expected_state_features(), o.state.float_features.numpy()
    )
    npt.assert_array_equal(
        self.expected_next_state_features(), o.next_state.float_features.numpy()
    )

def test_extract_max_q_discrete_action(self):
    extractor = TrainingFeatureExtractor(
        state_normalization_parameters=self.get_state_normalization_parameters(),
        max_q_learning=True,
    )
    # Setup
    ws, net = self.create_ws_and_net(extractor)
    input_record = self.create_extra_input_record(net)
    self.setup_state_features(ws, input_record.state_features)
    self.setup_next_state_features(ws, input_record.next_state_features)
    action = self.setup_action(ws, input_record.action)
    possible_next_actions = self.setup_possible_next_actions(
        ws, input_record.possible_next_actions
    )
    reward = self.setup_reward(ws, input_record.reward)
    extra_data = self.setup_extra_data(ws, input_record)
    # Run
    ws.run(net)
    res = extractor.extract(ws, input_record, net.output_record())
    o = res.training_input
    npt.assert_array_equal(reward.reshape(-1, 1), o.reward.numpy())
    npt.assert_array_equal(
        extra_data.action_probability.reshape(-1, 1),
        res.extras.action_probability.numpy(),
    )
    npt.assert_array_equal(action, o.action.numpy())
    npt.assert_array_equal(
        possible_next_actions[0], o.possible_next_actions.lengths.numpy()
    )
    npt.assert_array_equal(
        possible_next_actions[1], o.possible_next_actions.actions.numpy()
    )
    npt.assert_array_equal(
        self.expected_state_features(), o.state.float_features.numpy()
    )
    npt.assert_array_equal(
        self.expected_next_state_features(), o.next_state.float_features.numpy()
    )

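# Note that possible_next_actions above is a (lengths, actions) pair encoding a
# ragged per-example candidate list, which is why the assertions index into it
# and compare lengths and actions separately. A self-contained illustration of
# that encoding (the values here are made up for the example):
#
#     import numpy as np
#
#     lengths = np.array([2, 1], dtype=np.int32)    # example 0 has 2 candidates, example 1 has 1
#     actions = np.array([0, 1, 1], dtype=np.int64) # all candidates, concatenated
#     assert lengths.sum() == len(actions)
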
def test_extract_sarsa_parametric_action(self):
    extractor = TrainingFeatureExtractor(
        state_normalization_parameters=self.get_state_normalization_parameters(),
        action_normalization_parameters=self.get_action_normalization_parameters(),
        max_q_learning=False,
    )
    # Setup
    ws, net = self.create_ws_and_net(extractor)
    input_record = net.input_record() + schema.NewRecord(
        net, schema.Struct(("reward", schema.Scalar()))
    )
    self.setup_state_features(ws, input_record.state_features)
    self.setup_next_state_features(ws, input_record.next_state_features)
    self.setup_action_features(ws, input_record.action)
    self.setup_next_action_features(ws, input_record.next_action)
    reward = self.setup_reward(ws, input_record.reward)
    # Run
    ws.run(net)
    res = extractor.extract(ws, input_record, net.output_record())
    o = res.training_input
    npt.assert_array_equal(reward, o.reward.numpy())
    npt.assert_array_equal(
        self.expected_action_features(), o.action.float_features.numpy()
    )
    npt.assert_array_equal(
        self.expected_next_action_features(), o.next_action.float_features.numpy()
    )
    npt.assert_array_equal(
        self.expected_state_features(), o.state.float_features.numpy()
    )
    npt.assert_array_equal(
        self.expected_next_state_features(), o.next_state.float_features.numpy()
    )

def test_extract_max_q_discrete_action(self):
    extractor = TrainingFeatureExtractor(
        state_normalization_parameters=self.get_state_normalization_parameters(),
        max_q_learning=True,
    )
    # Setup
    ws, net = self.create_ws_and_net(extractor)
    input_record = net.input_record() + schema.NewRecord(
        net, schema.Struct(("reward", schema.Scalar()))
    )
    self.setup_state_features(ws, input_record.state_features)
    self.setup_next_state_features(ws, input_record.next_state_features)
    action = self.setup_action(ws, input_record.action)
    possible_next_actions = self.setup_possible_next_actions(
        ws, input_record.possible_next_actions
    )
    reward = self.setup_reward(ws, input_record.reward)
    # Run
    ws.run(net)
    res = extractor.extract(ws, input_record, net.output_record())
    o = res.training_input
    npt.assert_array_equal(reward, o.reward.numpy())
    npt.assert_array_equal(action, o.action.numpy())
    npt.assert_array_equal(
        possible_next_actions[0], o.possible_next_actions.lengths.numpy()
    )
    npt.assert_array_equal(
        possible_next_actions[1], o.possible_next_actions.actions.numpy()
    )
    npt.assert_array_equal(
        self.expected_state_features(), o.state.float_features.numpy()
    )
    npt.assert_array_equal(
        self.expected_next_state_features(), o.next_state.float_features.numpy()
    )

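# The two tests directly above attach the extra "reward" column to
# net.input_record() inline via schema.NewRecord; the other tests delegate
# this to create_extra_input_record. A minimal sketch of such a helper,
# assuming it only needs to add the extra scalar columns the assertions read
# back (the real helper likely also adds whichever of not_terminal,
# time_diff, mdp_id, sequence_number, and step a given test expects):
def create_extra_input_record(self, net):
    return net.input_record() + schema.NewRecord(
        net,
        schema.Struct(
            ("reward", schema.Scalar()),
            ("action_probability", schema.Scalar()),
        ),
    )
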
def _test_extract_sarsa_parametric_action(self, normalize):
    extractor = TrainingFeatureExtractor(
        state_normalization_parameters=self.get_state_normalization_parameters(),
        action_normalization_parameters=self.get_action_normalization_parameters(),
        include_possible_actions=False,
        normalize=normalize,
    )
    # Setup
    ws, net = self.create_ws_and_net(extractor)
    input_record = self.create_extra_input_record(net)
    self.setup_state_features(ws, input_record.state_features)
    self.setup_next_state_features(ws, input_record.next_state_features)
    self.setup_action_features(ws, input_record.action)
    self.setup_next_action_features(ws, input_record.next_action)
    reward = self.setup_reward(ws, input_record.reward)
    not_terminal = self.setup_not_terminal(ws, input_record.not_terminal)
    time_diff = self.setup_time_diff(ws, input_record.time_diff)
    mdp_id = self.setup_mdp_id(ws, input_record.mdp_id)
    sequence_number = self.setup_seq_num(ws, input_record.sequence_number)
    extra_data = self.setup_extra_data(ws, input_record)
    # Run
    ws.run(net)
    res = extractor.extract(ws, input_record, net.output_record())
    o = res.training_input
    e = res.extras
    npt.assert_array_equal(reward.reshape(-1, 1), o.reward.numpy())
    npt.assert_array_equal(time_diff.reshape(-1, 1), o.time_diff.numpy())
    npt.assert_array_equal(not_terminal.reshape(-1, 1), o.not_terminal.numpy())
    npt.assert_array_equal(
        sequence_number.reshape(-1, 1), e.sequence_number.numpy()
    )
    npt.assert_array_equal(mdp_id.reshape(-1, 1), e.mdp_id)
    npt.assert_array_equal(
        extra_data.action_probability.reshape(-1, 1),
        res.extras.action_probability.numpy(),
    )
    npt.assert_allclose(
        self.expected_action_features(normalize),
        o.action.float_features.numpy(),
        rtol=1e-6,
    )
    npt.assert_allclose(
        self.expected_next_action_features(normalize),
        o.next_action.float_features.numpy(),
        rtol=1e-6,
    )
    npt.assert_allclose(
        self.expected_state_features(normalize),
        o.state.float_features.numpy(),
        rtol=1e-6,
    )
    npt.assert_allclose(
        self.expected_next_state_features(normalize),
        o.next_state.float_features.numpy(),
        rtol=1e-6,
    )

def _test_extract_max_q_discrete_action(self, normalize):
    extractor = TrainingFeatureExtractor(
        state_normalization_parameters=self.get_state_normalization_parameters(),
        include_possible_actions=True,
        normalize=normalize,
        max_num_actions=2,
        multi_steps=3,
    )
    # Setup
    ws, net = self.create_ws_and_net(extractor)
    input_record = self.create_extra_input_record(net)
    self.setup_state_features(ws, input_record.state_features)
    self.setup_next_state_features(ws, input_record.next_state_features)
    action = self.setup_action(ws, input_record.action)
    next_action = self.setup_next_action(ws, input_record.next_action)
    possible_actions_mask = self.setup_possible_actions_mask(
        ws, input_record.possible_actions_mask
    )
    possible_next_actions_mask = self.setup_possible_next_actions_mask(
        ws, input_record.possible_next_actions_mask
    )
    reward = self.setup_reward(ws, input_record.reward)
    not_terminal = self.setup_not_terminal(ws, input_record.not_terminal)
    time_diff = self.setup_time_diff(ws, input_record.time_diff)
    step = self.setup_step(ws, input_record.step)
    extra_data = self.setup_extra_data(ws, input_record)
    # Run
    ws.run(net)
    res = extractor.extract(ws, input_record, net.output_record())
    o = res.training_input
    npt.assert_array_equal(reward.reshape(-1, 1), o.reward.numpy())
    npt.assert_array_equal(time_diff.reshape(-1, 1), o.time_diff.numpy())
    npt.assert_array_equal(not_terminal.reshape(-1, 1), o.not_terminal.numpy())
    npt.assert_array_equal(step.reshape(-1, 1), o.step.numpy())
    npt.assert_array_equal(
        extra_data.action_probability.reshape(-1, 1),
        res.extras.action_probability.numpy(),
    )
    npt.assert_array_equal(action, o.action.numpy())
    npt.assert_array_equal(next_action, o.next_action.numpy())
    npt.assert_array_equal(
        possible_actions_mask[1], o.possible_actions_mask.numpy().flatten()
    )
    npt.assert_array_equal(
        possible_next_actions_mask[1],
        o.possible_next_actions_mask.numpy().flatten(),
    )
    npt.assert_allclose(
        self.expected_state_features(normalize),
        o.state.float_features.numpy(),
        rtol=1e-6,
    )
    npt.assert_allclose(
        self.expected_next_state_features(normalize),
        o.next_state.float_features.numpy(),
        rtol=1e-6,
    )

def _test_extract_max_q_discrete_action(self, normalize):
    extractor = TrainingFeatureExtractor(
        state_normalization_parameters=self.get_state_normalization_parameters(),
        include_possible_actions=True,
        normalize=normalize,
        max_num_actions=2,
        multi_steps=3,
    )
    # Setup
    ws, net = self.create_ws_and_net(extractor)
    input_record = self.create_extra_input_record(net)
    self.setup_state_features(ws, input_record.state_features)
    self.setup_next_state_features(ws, input_record.next_state_features)
    action = self.setup_action(ws, input_record.action)
    next_action = self.setup_next_action(ws, input_record.next_action)
    possible_actions_mask = self.setup_possible_actions_mask(
        ws, input_record.possible_actions_mask
    )
    possible_next_actions_mask = self.setup_possible_next_actions_mask(
        ws, input_record.possible_next_actions_mask
    )
    reward = self.setup_reward(ws, input_record.reward)
    not_terminal = self.setup_not_terminal(ws, input_record.not_terminal)
    step = self.setup_step(ws, input_record.step)
    extra_data = self.setup_extra_data(ws, input_record)
    # Run
    ws.run(net)
    res = extractor.extract(ws, input_record, net.output_record())
    o = res.training_input
    npt.assert_array_equal(reward.reshape(-1, 1), o.reward.numpy())
    npt.assert_array_equal(not_terminal.reshape(-1, 1), o.not_terminal.numpy())
    npt.assert_array_equal(step.reshape(-1, 1), o.step.numpy())
    npt.assert_array_equal(
        extra_data.action_probability.reshape(-1, 1),
        res.extras.action_probability.numpy(),
    )
    npt.assert_array_equal(action, o.action.numpy())
    npt.assert_array_equal(next_action, o.next_action.numpy())
    npt.assert_array_equal(
        possible_actions_mask[1], o.possible_actions_mask.numpy().flatten()
    )
    npt.assert_array_equal(
        possible_next_actions_mask[1],
        o.possible_next_actions_mask.numpy().flatten(),
    )
    npt.assert_allclose(
        self.expected_state_features(normalize),
        o.state.float_features.numpy(),
        rtol=1e-6,
    )
    npt.assert_allclose(
        self.expected_next_state_features(normalize),
        o.next_state.float_features.numpy(),
        rtol=1e-6,
    )

def _test_extract_sarsa_parametric_action(self, normalize):
    extractor = TrainingFeatureExtractor(
        state_normalization_parameters=self.get_state_normalization_parameters(),
        action_normalization_parameters=self.get_action_normalization_parameters(),
        include_possible_actions=False,
        normalize=normalize,
    )
    # Setup
    ws, net = self.create_ws_and_net(extractor)
    input_record = self.create_extra_input_record(net)
    self.setup_state_features(ws, input_record.state_features)
    self.setup_next_state_features(ws, input_record.next_state_features)
    self.setup_action_features(ws, input_record.action)
    self.setup_next_action_features(ws, input_record.next_action)
    reward = self.setup_reward(ws, input_record.reward)
    not_terminal = self.setup_not_terminal(ws, input_record.not_terminal)
    extra_data = self.setup_extra_data(ws, input_record)
    # Run
    ws.run(net)
    res = extractor.extract(ws, input_record, net.output_record())
    o = res.training_input
    npt.assert_array_equal(reward.reshape(-1, 1), o.reward.numpy())
    npt.assert_array_equal(not_terminal.reshape(-1, 1), o.not_terminal.numpy())
    npt.assert_array_equal(
        extra_data.action_probability.reshape(-1, 1),
        res.extras.action_probability.numpy(),
    )
    npt.assert_allclose(
        self.expected_action_features(normalize),
        o.action.float_features.numpy(),
        rtol=1e-6,
    )
    npt.assert_allclose(
        self.expected_next_action_features(normalize),
        o.next_action.float_features.numpy(),
        rtol=1e-6,
    )
    npt.assert_allclose(
        self.expected_state_features(normalize),
        o.state.float_features.numpy(),
        rtol=1e-6,
    )
    npt.assert_allclose(
        self.expected_next_state_features(normalize),
        o.next_state.float_features.numpy(),
        rtol=1e-6,
    )

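# Every variant reads extra_data.action_probability back as a numpy array, so
# setup_extra_data presumably feeds the workspace blob and returns the raw
# values for the assertions. A minimal sketch under that assumption; the
# container, the two-row batch values, and the ws.feed_blob call are all
# hypothetical, mirroring what the other setup_* helpers would need to do:
import collections

import numpy as np

ExtraData = collections.namedtuple("ExtraData", ["action_probability"])

def setup_extra_data(self, ws, input_record):
    # Feed the action_probability column and hand the raw array back so the
    # test can compare it against the extractor output.
    extra_data = ExtraData(
        action_probability=np.array([0.11, 0.21], dtype=np.float32)
    )
    ws.feed_blob(
        str(input_record.action_probability()), extra_data.action_probability
    )
    return extra_data
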
def _test_extract_sarsa_discrete_action(self, normalize):
    extractor = TrainingFeatureExtractor(
        state_normalization_parameters=self.get_state_normalization_parameters(),
        include_possible_actions=False,
        normalize=normalize,
        max_num_actions=2,
    )
    # Setup
    ws, net = self.create_ws_and_net(extractor)
    input_record = self.create_extra_input_record(net)
    self.setup_state_features(ws, input_record.state_features)
    self.setup_next_state_features(ws, input_record.next_state_features)
    action = self.setup_action(ws, input_record.action)
    next_action = self.setup_next_action(ws, input_record.next_action)
    reward = self.setup_reward(ws, input_record.reward)
    not_terminal = self.setup_not_terminal(ws, input_record.not_terminal)
    extra_data = self.setup_extra_data(ws, input_record)
    # Run
    ws.run(net)
    res = extractor.extract(ws, input_record, net.output_record())
    o = res.training_input
    npt.assert_array_equal(reward.reshape(-1, 1), o.reward.numpy())
    npt.assert_array_equal(not_terminal.reshape(-1, 1), o.not_terminal.numpy())
    npt.assert_array_equal(
        extra_data.action_probability.reshape(-1, 1),
        res.extras.action_probability.numpy(),
    )
    npt.assert_array_equal(action, o.action.numpy())
    npt.assert_array_equal(next_action, o.next_action.numpy())
    npt.assert_allclose(
        self.expected_state_features(normalize),
        o.state.float_features.numpy(),
        rtol=1e-6,
    )
    npt.assert_allclose(
        self.expected_next_state_features(normalize),
        o.next_state.float_features.numpy(),
        rtol=1e-6,
    )
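
# The underscore-prefixed helpers above are parameterized on `normalize`; a
# thin pair of entry points per helper would let the test runner exercise
# both modes. The names below are hypothetical, following the naming
# convention of the earlier tests:
def test_extract_sarsa_discrete_action(self):
    self._test_extract_sarsa_discrete_action(normalize=False)

def test_extract_sarsa_discrete_action_normalize(self):
    self._test_extract_sarsa_discrete_action(normalize=True)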