Beispiel #1
0
    def test_discrete(self):
        """Run the categorical DQN agent on the discrete MinimalTest task.

        Several independent trials are executed. In each trial a fresh
        environment, configuration, agent and runner are built and the
        runner is given a budget of ``max_episodes`` episodes. A trial
        counts as passed when the runner finishes with fewer episodes than
        that budget. The test succeeds when at least ``min_passed`` trials
        pass.
        """
        num_trials = 5
        min_passed = 4
        max_episodes = 1000
        window = 100  # number of most recent episodes inspected

        def episode_finished(r):
            # Returns True while fewer than `window` episodes have been
            # played, or while any of the last `window` episodes has a
            # reward-per-step ratio below reward_threshold.
            # NOTE(review): a False return presumably makes the Runner stop
            # early -- confirm against Runner.run.
            # NOTE(review): on Python 2 the ratio relies on true division
            # (``from __future__ import division``) or float rewards --
            # confirm at the top of the file.
            if r.episode < window:
                return True
            recent = zip(r.episode_rewards[-window:],
                         r.episode_lengths[-window:])
            return not all(reward / length >= reward_threshold
                           for reward, length in recent)

        passed = 0

        # `range` instead of `xrange`: identical for five iterations and
        # available on both Python 2 and Python 3 without a compat import.
        for _ in range(num_trials):
            environment = MinimalTest(definition=False)
            config = Configuration(batch_size=8,
                                   learning_rate=0.0005,
                                   memory_capacity=800,
                                   first_update=80,
                                   target_update_frequency=20,
                                   memory=dict(type='replay',
                                               random_sampling=True),
                                   states=environment.states,
                                   actions=environment.actions,
                                   network=layered_network_builder([
                                       dict(type='dense', size=32),
                                       dict(type='dense', size=32)
                                   ]))
            agent = CategoricalDQNAgent(config=config)
            runner = Runner(agent=agent, environment=environment)

            runner.run(episodes=max_episodes,
                       episode_finished=episode_finished)
            print('Categorical DQN agent: ' + str(runner.episode))
            if runner.episode < max_episodes:
                passed += 1

        print('Categorical DQN agent passed = {}'.format(passed))
        # assertGreaterEqual reports both operands when the check fails.
        self.assertGreaterEqual(passed, min_passed)
Beispiel #2
0
    def test_multi(self):
        """Run the categorical DQN agent on a multi-state/action MinimalTest.

        Several independent trials are executed. In each trial a fresh
        environment, configuration, agent and runner are built and the
        runner is given a budget of ``max_episodes`` episodes. A trial
        counts as passed when the runner finishes with fewer episodes than
        that budget. The test succeeds when at least ``min_passed`` trials
        pass.
        """
        num_trials = 5
        min_passed = 2
        max_episodes = 2000
        window = 15  # number of most recent episodes inspected

        def network_builder(inputs, **kwargs):
            # Each of the 'state0' and 'state1' inputs goes through two
            # stacked dense layers of size 32; the two results are
            # multiplied together. Extra keyword arguments are ignored.
            layer = layers['dense']
            state0 = layer(x=layer(x=inputs['state0'], size=32), size=32)
            state1 = layer(x=layer(x=inputs['state1'], size=32), size=32)
            return state0 * state1

        def episode_finished(r):
            # Returns True while fewer than `window` episodes have been
            # played, or while any of the last `window` episodes has a
            # reward-per-step ratio below reward_threshold.
            # NOTE(review): a False return presumably makes the Runner stop
            # early -- confirm against Runner.run.
            # NOTE(review): on Python 2 the ratio relies on true division
            # (``from __future__ import division``) or float rewards --
            # confirm at the top of the file.
            if r.episode < window:
                return True
            recent = zip(r.episode_rewards[-window:],
                         r.episode_lengths[-window:])
            return not all(reward / length >= reward_threshold
                           for reward, length in recent)

        passed = 0

        # `range` instead of `xrange`: identical for five iterations and
        # available on both Python 2 and Python 3 without a compat import.
        for _ in range(num_trials):
            environment = MinimalTest(definition=[False, (False, 2)])
            config = Configuration(
                batch_size=8,
                learning_rate=0.001,
                memory_capacity=800,
                first_update=80,
                target_update_frequency=20,
                states=environment.states,
                actions=environment.actions,
                network=network_builder
            )
            agent = CategoricalDQNAgent(config=config)
            runner = Runner(agent=agent, environment=environment)

            runner.run(episodes=max_episodes,
                       episode_finished=episode_finished)
            print('Categorical DQN agent (multi-state/action): ' + str(runner.episode))
            if runner.episode < max_episodes:
                passed += 1

        print('Categorical DQN agent (multi-state/action) passed = {}'.format(passed))
        # assertGreaterEqual reports both operands when the check fails.
        self.assertGreaterEqual(passed, min_passed)