Example 1
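This example tests the MuZero-style Network wrapper's initial and recurrent inference for tic-tac-toe. The listing omits its imports; a plausible header is sketched below, where the module path is an assumption (the source does not show it) and only the imported names actually appear in the listing:

import unittest

import numpy as np
import tensorflow as tf

# Assumed module path; only the imported names are visible in the listing.
from muzero_tictactoe import (Action, MuZeroMctsModel, Network,
                              TicTacToeConfig, TicTacToeEnv,
                              TicTacToeInitializer)
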
class NetworkTest(unittest.TestCase):
    def setUp(self):
        # Make tests reproducible.
        np.random.seed(0)
        tf.random.set_seed(0)
        self.env = TicTacToeEnv()
        self.network_initializer = TicTacToeInitializer()
        self.network = Network(self.network_initializer)
        self.model = MuZeroMctsModel(self.env, self.network)

    def test_initial_inference(self):
        game_state = np.array(self.env.get_states()).reshape(
            -1, TicTacToeConfig.action_size)
        output = self.network.initial_inference(game_state)
        self.assertEqual(0, output.reward)
        # An earlier revision used a categorical value head of shape
        # (1, 2 * TicTacToeConfig.support_size + 1); see the sketch after
        # this class.
        self.assertEqual((1, TicTacToeConfig.value_size), output.value.shape)
        self.assertEqual((1, TicTacToeConfig.action_size),
                         output.policy_logits.shape)
        self.assertEqual((1, TicTacToeConfig.hidden_size),
                         output.hidden_state.shape)

    def test_recurrent_inference(self):
        game_state = np.array(self.env.get_states()).reshape(
            -1, TicTacToeConfig.action_size)
        action = Action(0)
        output = self.network.recurrent_inference(game_state, action)
        # The reward check is disabled: for an untrained network the reward
        # head's output is not guaranteed to be exactly 0.
        # self.assertEqual(0, output.reward)
        self.assertEqual((1, TicTacToeConfig.value_size), output.value.shape)
        self.assertEqual((1, TicTacToeConfig.action_size),
                         output.policy_logits.shape)
        self.assertEqual((1, TicTacToeConfig.hidden_size),
                         output.hidden_state.shape)
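
The commented-out shape check in test_initial_inference refers to MuZero's categorical value head, where a scalar value is spread over the integer support [-support_size, support_size], i.e. 2 * support_size + 1 bins. A minimal sketch of that encoding, omitting MuZero's invertible value-scaling transform:

import numpy as np

def scalar_to_support(x, support_size):
    # Clip the scalar onto the support, then split its probability mass
    # between the two nearest integer bins.
    x = float(np.clip(x, -support_size, support_size))
    low = int(np.floor(x))
    probs = np.zeros(2 * support_size + 1)
    probs[low + support_size] = 1.0 - (x - low)
    if x - low > 0:
        probs[low + 1 + support_size] = x - low
    return probs

For example, scalar_to_support(0.5, 1) returns [0., 0.5, 0.5]: half the mass on bin 0 and half on bin 1.
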
Example 2

class TicTacToeNetworkInitializerTest(unittest.TestCase):
    def setUp(self):
        # Make tests reproducible.
        np.random.seed(0)
        tf.random.set_seed(0)

        self.env = TicTacToeEnv()
        self.network_initializer = TicTacToeInitializer()
        (self.prediction_network, self.dynamics_network,
         self.representation_network, self.dynamics_encoder,
         self.representation_encoder) = self.network_initializer.initialize()

    def test_prediction_network(self):
        input_image = np.array(self.env.get_states()).reshape(
            -1, TicTacToeConfig.action_size)
        policy_logits, value = self.prediction_network(input_image)
        self.assertEqual((1, TicTacToeConfig.value_size), value.shape)
        self.assertEqual((1, TicTacToeConfig.action_size),
                         policy_logits.shape)

    def test_representation_network(self):
        input_image = np.array(self.env.get_states()).reshape(
            -1, TicTacToeConfig.action_size)
        hidden_state = self.representation_network(input_image)

        self.assertEqual((1, TicTacToeConfig.hidden_size), hidden_state.shape)

        print('LAYERS: ', self.representation_network.layers)
        print('INPUTS: ', self.representation_network.inputs)
        print('OUTPUTS: ', self.representation_network.outputs)
        # summary() prints the architecture itself and returns None, so it is
        # called directly rather than printed.
        self.representation_network.summary()

    def test_dynamics_network(self):
        hidden_state = np.zeros(
            (TicTacToeConfig.batch_size, TicTacToeConfig.hidden_size))
        action = Action(0)
        encoded_state = self.dynamics_encoder.encode(hidden_state, action)
        hidden_state, reward = self.dynamics_network(encoded_state)
        # The exact reward value is not asserted for an untrained network.
        # The shape check assumes TicTacToeConfig.batch_size == 1.
        self.assertEqual((1, TicTacToeConfig.hidden_size), hidden_state.shape)

    def test_encoded_dynamics_state(self):
        hidden_state = np.zeros(
            (TicTacToeConfig.batch_size, TicTacToeConfig.hidden_size))
        action = Action(0)
        encoded_state = self.dynamics_encoder.encode(hidden_state, action)
        # TODO(FJUR): This should encode 2 planes: a one-hot plane with the
        # selected action, and a binary plane (all 0s or all 1s) indicating
        # whether the move was valid.
        self.assertEqual((1, TicTacToeConfig.representation_size),
                         encoded_state.shape)
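
The TODO above describes a plane-based action encoding. For orientation, here is a minimal sketch of a flat variant that appends a one-hot action vector to the hidden state; the function name and shapes are assumptions, since the listing does not show DynamicsEncoder:

import numpy as np

def encode(hidden_state, action_index, action_size):
    # Hypothetical flat encoder: concatenate the (batch, hidden) state
    # with a one-hot action vector per batch row.
    batch = hidden_state.shape[0]
    one_hot = np.zeros((batch, action_size), dtype=hidden_state.dtype)
    one_hot[:, action_index] = 1.0
    return np.concatenate([hidden_state, one_hot], axis=1)
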
Example 3
class BasicMctsModelTest(unittest.TestCase):
    def setUp(self):
        self.env = TicTacToeEnv()
        self.dynamics_model = BasicMctsModel(self.env)

    def test_get_predicted_value_and_final_info_discounted(self):
        self.dynamics_model = BasicMctsModel(self.env, discount=0.9)
        # Check some conditions first.
        states = [0] * 9
        states[4] = 1
        states[0] = 4
        self.assertEqual((False, 0.0), self.env.check(states))
        self.assertEqual(1, self.env.opponent_play(states))

        predicted_value, final_info = self.dynamics_model.get_predicted_value_and_final_info(
            states)
        # With discount=0.9, the simulated loss (-1) is discounted once:
        # -1 * 0.9 = -0.9.
        self.assertEqual(-0.9, predicted_value)
        # The game ended due to an illegal move.
        self.assertEqual([4, 4, 0, 0, 1, 1, 0, 0, 0], final_info[0])  # states
        self.assertEqual(5, final_info[1])  # action

    def test_get_predicted_value_and_final_info(self):
        # Check some conditions first.
        states = [0] * 9
        states[4] = 1
        states[0] = 4
        self.assertEqual((False, 0.0), self.env.check(states))
        self.assertEqual(1, self.env.opponent_play(states))

        predicted_value, final_info = self.dynamics_model.get_predicted_value_and_final_info(
            states)
        self.assertEqual(-1.0, predicted_value)
        # The game ended due to an illegal move.
        self.assertEqual([4, 4, 0, 0, 1, 1, 0, 0, 0], final_info[0])  # states
        self.assertEqual(5, final_info[1])  # action

    def test_step(self):
        self.dynamics_model.step([0] * 9, 0)
        # The above is a simulation step, so it should not affect the real environment.
        self.assertEqual([0] * 9, self.env.get_states())
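
The two expected values above differ only by the discount: a simulated loss of -1 reached after one discounted step. A minimal check of that arithmetic, assuming the model applies the discount once per simulated step:

def backed_up_value(terminal_value, discount, steps):
    # Value of a terminal outcome as seen from the starting state.
    return terminal_value * discount ** steps

assert backed_up_value(-1.0, 0.9, 1) == -0.9  # BasicMctsModel(env, discount=0.9)
assert backed_up_value(-1.0, 1.0, 1) == -1.0  # default, undiscounted
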
Example 4
class MctsCoreTicTacToeTest(unittest.TestCase):

    def setUp(self):
        self.env = TicTacToeEnv()
        self.model = BasicMctsModel(self.env)
        self.core = MctsCore(env=self.env, model=self.model)

    def test_rollout(self):
        self.core.initialize()
        # TODO: test more succinctly.
        # Node repr: v = visit count, p = prior, v_sum = accumulated value,
        # s = environment states, r = reward, c = children keyed by action.
        self.assertEqual('{v: 0, p: 1.0, v_sum: 0, s: [0, 0, 0, 0, 0, 0, 0, 0, 0], r: 0, c: {0: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 1: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 2: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 3: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 4: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 5: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 6: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 7: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 8: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}}}',
                         str(self.core.get_root_for_testing()))
        self.core.rollout()
        # TODO: test more succinctly.
        self.assertEqual('{v: 1, p: 1.0, v_sum: -1.0, s: [0, 0, 0, 0, 0, 0, 0, 0, 0], r: 0, c: {0: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 1: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 2: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 3: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 4: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 5: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 6: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 7: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 8: {v: 1, p: 1.0, v_sum: -1.0, s: [4, 0, 0, 0, 0, 0, 0, 0, 1], r: 0.0, c: {0: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 1: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 2: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 3: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 4: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 5: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 6: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 7: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 8: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}}}}}',
                         str(self.core.get_root_for_testing()))
        # A rollout is simulation only, so it should not affect the real
        # environment.
        self.assertEqual([0] * 9, self.env.get_states())
        self.core.rollout()
        # TODO: check nodes after rollout

    def get_ucb_distribution(self, node):
        # core.get_ucb_distribution yields (ucb, action, child) tuples.
        action_ucb_pairs = sorted(
            (action, ucb)
            for ucb, action, child in self.core.get_ucb_distribution(node))
        # Return the UCB scores ordered by action.
        return [ucb for _, ucb in action_ucb_pairs]

    def test_inside_initial_rollouts(self):
        self.core.initialize()
        root = self.core.get_root_for_testing()

        self.assertEqual([0.] * 9, self.get_ucb_distribution(root))

        node1, search_path, last_action = self.core.select_node()

        self.assertNotEqual(root, node1)
        self.assertEqual([root, node1], search_path)
        # All actions tie under the uniform UCB distribution; this
        # implementation happens to break the tie by picking action 8.
        self.assertEqual(8, last_action)

        self.core.expand_node(node1)

        self.assertTrue(node1.expanded())
        parent = root
        self.assertIsNotNone(parent.states)

        value = self.core.evaluate_node(node1, parent.states, last_action)

        self.assertEqual(0., node1.reward)
        # The opponent placed its mark (4) in the first empty cell.
        self.assertEqual([4] + [0] * 7 + [1], node1.states)
        # Simulating from this state is likely to end in a loss.
        self.assertEqual(-1., value)
        self.assertFalse(node1.is_final)

        self.core.backpropagate(search_path, value)

        # Action 8 yielded a reward of 0; its backed-up value has been
        # discounted.
        # TODO: verify that the numbers are correct.
        np.testing.assert_almost_equal([1.2501018] * 8 + [-0.3749491],
                                       self.get_ucb_distribution(root))

        # Only action 8 has been visited so far, so the visit-count policy
        # puts 100% of its mass on action 8 despite its negative value.
        np.testing.assert_array_equal([0.] * 8 + [1.],
                                      self.core.get_policy_distribution())
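
If the examples are saved as a standalone module, the standard unittest entry point applies:

if __name__ == '__main__':
    unittest.main()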