Exemplo n.º 1
0
    def test_l1_l2_norm(self):
        """Compare parameter magnitudes of an L1/L2-penalised DQN with a reference DQN.

        A DQN whose Q-network layers carry ``L1_NORM``/``L2_NORM`` values of
        1000.0 and a second DQN obtained from ``self.create_dqn`` are both
        trained on the same 100-step Acrobot rollout. For each pair of
        parameter arrays the mean of ``|reference| - |penalised|`` must be
        positive.

        NOTE(review): this presumes ``create_dqn`` builds a network without
        such penalties and with a matching parameter layout -- confirm
        against that helper.
        """
        env = make('Acrobot-v1')
        env_spec = EnvSpec(obs_space=env.observation_space,
                           action_space=env.action_space)
        name = 'dqn'
        scope = name + '_mlp'

        # Both layers carry the same L1/L2 values.
        hidden_layer = {
            "ACT": "RELU",
            "B_INIT_VALUE": 0.0,
            "NAME": "1",
            "N_UNITS": 16,
            "TYPE": "DENSE",
            "W_NORMAL_STDDEV": 0.03,
            "L1_NORM": 1000.0,
            "L2_NORM": 1000.0
        }
        output_layer = {
            "ACT": "LINEAR",
            "B_INIT_VALUE": 0.0,
            "NAME": "OUPTUT",
            "N_UNITS": 1,
            "L1_NORM": 1000.0,
            "L2_NORM": 1000.0,
            "TYPE": "DENSE",
            "W_NORMAL_STDDEV": 0.03
        }
        mlp_q = MLPQValueFunction(env_spec=env_spec,
                                  name_scope=scope,
                                  name=scope,
                                  mlp_config=[hidden_layer, output_layer])

        dqn_config = {
            'REPLAY_BUFFER_SIZE': 1000,
            'GAMMA': 0.99,
            'BATCH_SIZE': 10,
            'LEARNING_RATE': 0.01,
            'TRAIN_ITERATION': 1,
            'DECAY': 0.5,
        }
        penalised = DQN(env_spec=env_spec,
                        config_or_config_dict=dqn_config,
                        name=name,
                        value_func=mlp_q)
        reference, _ = self.create_dqn(name='dqn_2')

        transitions = TransitionData(env_spec)
        obs = env.reset()
        penalised.init()
        reference.init()

        # Collect a 100-step rollout driven by the penalised agent.
        for _step in range(100):
            action = penalised.predict(obs=obs, sess=self.sess, batch_flag=False)
            next_obs, reward, done, _ = env.step(action=action)
            transitions.append(state=obs,
                               new_state=next_obs,
                               action=action,
                               done=done,
                               reward=reward)
            obs = next_obs
            penalised.append_to_memory(transitions)

        # Train both agents on the same batch, alternating between them.
        for _round in range(20):
            penalised_loss = penalised.train(batch_data=transitions,
                                             train_iter=10,
                                             sess=None,
                                             update_target=True)
            print('dqn1 loss: ', penalised_loss)
            reference_loss = reference.train(batch_data=transitions,
                                             train_iter=10,
                                             sess=None,
                                             update_target=True)
            print('dqn2 loss: ', reference_loss)

        penalised_params = self.sess.run(
            penalised.q_value_func.parameters('tf_var_list'))
        print(penalised_params)
        reference_params = self.sess.run(
            reference.q_value_func.parameters('tf_var_list'))
        print(reference_params)

        for small, large in zip(penalised_params, reference_params):
            magnitude_gap = np.abs(large) - np.abs(small)
            self.assertTrue(np.greater(np.mean(magnitude_gap), 0.0).all())
Exemplo n.º 2
0
    def test_integration_with_dqn(self):
        """Check that DQN training results are captured by its recorder and the Logger.

        Builds a DQN on Acrobot-v1, collects a 100-step rollout, and calls
        ``train`` twice (once with the collected batch, once with
        ``batch_data=None``). It then asserts that:

        * the recorder holds exactly two ``average_loss`` entries for the agent,
        * their ``log_val`` values equal the losses returned by ``train``,
        * the recorder is registered with ``Logger``.

        Finally the recorder is flushed through ``Logger().flush_recorder()``.
        """
        env = make('Acrobot-v1')
        env_spec = EnvSpec(obs_space=env.observation_space,
                           action_space=env.action_space)

        mlp_q = MLPQValueFunction(env_spec=env_spec,
                                  name='mlp_q',
                                  name_scope='mlp_q',
                                  mlp_config=[{
                                      "ACT": "RELU",
                                      "B_INIT_VALUE": 0.0,
                                      "NAME": "1",
                                      "N_UNITS": 16,
                                      "TYPE": "DENSE",
                                      "W_NORMAL_STDDEV": 0.03
                                  }, {
                                      "ACT": "LINEAR",
                                      "B_INIT_VALUE": 0.0,
                                      "NAME": "OUPTUT",
                                      "N_UNITS": 1,
                                      "TYPE": "DENSE",
                                      "W_NORMAL_STDDEV": 0.03
                                  }])
        dqn = DQN(env_spec=env_spec,
                  name='dqn_test',
                  config_or_config_dict=dict(REPLAY_BUFFER_SIZE=1000,
                                             GAMMA=0.99,
                                             BATCH_SIZE=10,
                                             LEARNING_RATE=0.001,
                                             TRAIN_ITERATION=1,
                                             DECAY=0.5),
                  value_func=mlp_q)
        dqn.init()
        st = env.reset()
        from baconian.common.sampler.sample_data import TransitionData
        a = TransitionData(env_spec)
        res = []
        for _ in range(100):
            ac = dqn.predict(obs=st, sess=self.sess, batch_flag=False)
            st_new, re, done, _ = env.step(action=ac)
            a.append(state=st,
                     new_state=st_new,
                     action=ac,
                     done=done,
                     reward=re)
            # Advance the observation; without this every prediction and
            # every stored transition would start from the reset state.
            st = st_new
            dqn.append_to_memory(a)
        # First update uses the collected batch explicitly.
        res.append(
            dqn.train(batch_data=a,
                      train_iter=10,
                      sess=None,
                      update_target=True)['average_loss'])
        # Second update passes no batch.
        # NOTE(review): presumably DQN then samples from its replay memory -- confirm.
        res.append(
            dqn.train(batch_data=None,
                      train_iter=10,
                      sess=None,
                      update_target=True)['average_loss'])

        # One log entry per train() call is expected.
        obj_log = dqn.recorder._obj_log
        self.assertIn(dqn, obj_log)
        self.assertIn('average_loss', obj_log[dqn])
        self.assertEqual(len(obj_log[dqn]['average_loss']), 2)
        # Logged values must match the returned losses exactly.
        self.assertTrue(
            np.equal(np.array(res), [
                x['log_val']
                for x in obj_log[dqn]['average_loss']
            ]).all())

        self.assertGreater(len(Logger()._registered_recorders), 0)
        self.assertIn(dqn.recorder, Logger()._registered_recorders)

        Logger().flush_recorder()