Example #1
    # Module-level setup assumed by these test methods (the import path and the
    # TF1-style session handling reflect older rlcard releases and may differ
    # in newer versions):
    #
    #   import unittest
    #   import numpy as np
    #   import tensorflow as tf
    #
    #   import rlcard
    #   from rlcard.agents.deep_cfr import DeepCFR
    #
    #   class TestDeepCFR(unittest.TestCase):

    def test_init(self):
        # Build a small DeepCFR agent on Leduc Hold'em; step-back must be
        # enabled because CFR-style tree traversal undoes environment steps.
        sess = tf.InteractiveSession()
        env = rlcard.make('leduc-holdem', allow_step_back=True)
        agent = DeepCFR(session=sess,
                        env=env,
                        policy_network_layers=(4, 4),
                        advantage_network_layers=(4, 4),
                        num_traversals=1,
                        num_step=1,
                        learning_rate=1e-4,
                        batch_size_advantage=10,
                        batch_size_strategy=10,
                        memory_capacity=int(1e7))

        # The constructor should store the hyperparameters on the agent.
        self.assertEqual(agent._num_traversals, 1)
        self.assertEqual(agent._num_step, 1)
        self.assertEqual(agent._batch_size_advantage, 10)
        self.assertEqual(agent._batch_size_strategy, 10)

        # Release the TF1 session and graph so later tests start clean.
        sess.close()
        tf.reset_default_graph()
Example #2
    def test_train(self):
        # Uses the same module-level setup as Example #1.
        num_iterations = 10

        sess = tf.InteractiveSession()
        env = rlcard.make('leduc-holdem', {'allow_step_back': True})
        agent = DeepCFR(session=sess,
                        scope='deepcfr',
                        env=env,
                        policy_network_layers=(128, 128),
                        advantage_network_layers=(128, 128),
                        num_traversals=1,
                        num_step=1,
                        learning_rate=1e-4,
                        batch_size_advantage=64,
                        batch_size_strategy=64,
                        memory_capacity=int(1e5))

        # Test train: each call runs one DeepCFR training iteration.
        for _ in range(num_iterations):
            agent.train()

        # Test eval_step: the agent should return a legal action for a state
        # dict holding an observation and the list of legal actions.
        state = {
            'obs': np.random.random_sample(env.state_shape),
            'legal_actions': list(range(env.action_num))
        }
        action, _ = agent.eval_step(state)
        self.assertIn(action, range(env.action_num))

        # Test simulate_other: the sampled action must also be legal.
        action = agent.simulate_other(0, state)
        self.assertIn(action, range(env.action_num))

        # Test action_advantage: one advantage estimate per action for player 0.
        advantages = agent.action_advantage(state, 0)
        self.assertEqual(advantages.shape[0], env.action_num)

        sess.close()
        tf.reset_default_graph()
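
Taken together, the two tests show the agent's full lifecycle: construct DeepCFR against a step-back-enabled environment, call train() repeatedly, then query it with eval_step(). The sketch below collapses that into a minimal standalone script. It is an illustration only, assuming the same older TensorFlow-1.x rlcard API as the tests; the DeepCFR import path and the rlcard.make() configuration form may differ in other rlcard releases, and the hyperparameters are the test values, not tuned settings.

# Minimal standalone sketch mirroring the tests above. Assumes the TF1-era
# rlcard API; the import path below is an assumption and may vary by version.
import numpy as np
import tensorflow as tf

import rlcard
from rlcard.agents.deep_cfr import DeepCFR  # assumed import path

sess = tf.InteractiveSession()
env = rlcard.make('leduc-holdem', {'allow_step_back': True})

agent = DeepCFR(session=sess,
                scope='deepcfr',
                env=env,
                policy_network_layers=(128, 128),
                advantage_network_layers=(128, 128),
                num_traversals=1,
                num_step=1,
                learning_rate=1e-4,
                batch_size_advantage=64,
                batch_size_strategy=64,
                memory_capacity=int(1e5))

# Run a handful of DeepCFR training iterations.
for _ in range(10):
    agent.train()

# Ask the trained agent for an action on an arbitrary observation.
state = {
    'obs': np.random.random_sample(env.state_shape),
    'legal_actions': list(range(env.action_num))
}
action, _ = agent.eval_step(state)
print('chosen action:', action)

sess.close()
tf.reset_default_graph()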