Example #1
0
    def validate(self, agent: Agent, *args, **kwargs):
        """Score every dynamics model by rolling out ``agent`` inside it.

        For each index ``a`` of ``self._dynamics_model``, the model
        ``self._dynamics_model.model[a]`` is turned into an environment via
        ``return_as_env()``. ``agent`` samples
        ``self.parameters('validation_trajectory_count')`` trajectories in it
        (``store_flag=False``), and ``batch_data.get_mean_of('reward')`` is
        written to ``self.result[a]``. A model counts as improved when its new
        score is strictly greater than the score held before this call.

        :param agent: agent used to sample trajectories in each model env.
        :return: fraction of models whose score improved. The same value is
            stored in ``self.validation_result``.
        :raises ZeroDivisionError: if ``len(self._dynamics_model)`` is 0.
        """
        import copy

        # Snapshot the previous scores. A plain ``old_result = self.result``
        # would only alias the same container, so the in-place writes to
        # ``self.result[a]`` below would also overwrite the "old" value and
        # the improvement check could never succeed.
        old_result = copy.copy(self.result)
        n_models = len(self._dynamics_model)
        improved = 0
        for a in range(n_models):
            individual_model = self._dynamics_model.model[a]
            env = individual_model.return_as_env()
            batch_data = agent.sample(
                env=env,
                sample_count=self.parameters('validation_trajectory_count'),
                sample_type='trajectory',
                store_flag=False)

            self.result[a] = batch_data.get_mean_of('reward')
            if self.result[a] > old_result[a]:
                improved += 1

        self.validation_result = improved / n_models
        return self.validation_result
Example #2
0
    def test_integration_with_dqn(self):
        """Check that DQN training and agent sampling are recorded end to end.

        Builds a two-layer MLP Q-function for ``Acrobot-v1``, wraps it in a
        DQN driven by an ``Agent``, and samples two runs of 100 transitions
        in ``TRAIN`` status. It then runs two ``dqn.train`` calls and verifies
        the ``average_loss`` entries on the DQN recorder, the registration
        with ``Logger``, and the ``sum_reward`` entries on the agent recorder.
        """
        env = make('Acrobot-v1')
        env_spec = EnvSpec(obs_space=env.observation_space,
                           action_space=env.action_space)

        mlp_q = MLPQValueFunction(env_spec=env_spec,
                                  name='mlp_q',
                                  name_scope='mlp_q',
                                  mlp_config=[
                                      {
                                          "ACT": "RELU",
                                          "B_INIT_VALUE": 0.0,
                                          "NAME": "1",
                                          "N_UNITS": 16,
                                          "TYPE": "DENSE",
                                          "W_NORMAL_STDDEV": 0.03
                                      },
                                      {
                                          "ACT": "LINEAR",
                                          "B_INIT_VALUE": 0.0,
                                          "NAME": "OUPTUT",
                                          "N_UNITS": 1,
                                          "TYPE": "DENSE",
                                          "W_NORMAL_STDDEV": 0.03
                                      }
                                  ])
        dqn = DQN(env_spec=env_spec,
                  name='dqn_test',
                  config_or_config_dict=dict(REPLAY_BUFFER_SIZE=1000,
                                             GAMMA=0.99,
                                             BATCH_SIZE=10,
                                             LEARNING_RATE=0.001,
                                             TRAIN_ITERATION=1,
                                             DECAY=0.5),
                  value_func=mlp_q)
        agent = Agent(env=env, env_spec=env_spec,
                      algo=dqn,
                      name='agent')
        # No separate dqn.init() call: presumably agent.init() initialises
        # the wrapped algorithm as well -- confirm against Agent.init.
        agent.init()
        env.reset()
        from baconian.common.sampler.sample_data import TransitionData
        # NOTE(review): this batch is constructed but never filled before it
        # is handed to dqn.train() below; confirm that is intended.
        unfilled_batch = TransitionData(env_spec)

        # Two TRAIN-status sampling runs; the agent-recorder checks at the end
        # expect exactly two 'sum_reward' TRAIN entries.
        for _ in range(2):
            agent.sample(env=env,
                         sample_count=100,
                         in_which_status='TRAIN',
                         store_flag=True,
                         sample_type='transition')

        losses = []
        losses.append(dqn.train(batch_data=unfilled_batch, train_iter=10, sess=None,
                                update_target=True)['average_loss'])
        losses.append(dqn.train(batch_data=None, train_iter=10, sess=None,
                                update_target=True)['average_loss'])

        obj_log = dqn.recorder._obj_log
        self.assertIn(dqn, obj_log)
        self.assertIn('average_loss', obj_log[dqn])
        self.assertEqual(len(obj_log[dqn]['average_loss']), 2)
        recorded_losses = [x['value'] for x in obj_log[dqn]['average_loss']]
        self.assertTrue(np.equal(np.array(losses), recorded_losses).all())

        self.assertGreater(len(Logger()._registered_recorders), 0)
        self.assertIn(dqn.recorder, Logger()._registered_recorders)
        loss_log = dqn.recorder.get_log(attr_name='average_loss', filter_by_status=dict())
        self.assertEqual(len(loss_log), 2)
        train_rewards = agent.recorder.get_log(attr_name='sum_reward', filter_by_status={'status': 'TRAIN'})
        self.assertEqual(len(train_rewards), 2)
        test_rewards = agent.recorder.get_log(attr_name='sum_reward', filter_by_status={'status': 'TEST'})
        self.assertEqual(len(test_rewards), 0)
        Logger().flush_recorder()