class NetworkTest(unittest.TestCase):
  """Checks output shapes/values of Network's initial and recurrent inference."""

  def setUp(self):
    # Make tests reproducible.
    np.random.seed(0)
    tf.random.set_seed(0)
    self.env = TicTacToeEnv()
    self.network_initializer = TicTacToeInitializer()
    self.network = Network(self.network_initializer)
    self.model = MuZeroMctsModel(self.env, self.network)

  def test_initial_inference(self):
    game_state = np.array(self.env.get_states()).reshape(
        -1, TicTacToeConfig.action_size)
    output = self.network.initial_inference(game_state)
    # Initial inference carries no transition, so the reward is 0.
    # assertEqual (instead of assertTrue(a == b)) reports both values on failure.
    self.assertEqual(0, output.reward)
    self.assertEqual((1, TicTacToeConfig.value_size), output.value.shape)
    self.assertEqual((1, TicTacToeConfig.action_size),
                     output.policy_logits.shape)
    self.assertEqual((1, TicTacToeConfig.hidden_size),
                     output.hidden_state.shape)

  def test_recurrent_inference(self):
    game_state = np.array(self.env.get_states()).reshape(
        -1, TicTacToeConfig.action_size)
    action = Action(0)
    output = self.network.recurrent_inference(game_state, action)
    # NOTE(review): the recurrent reward is intentionally not asserted here;
    # the original test had it disabled.
    self.assertEqual((1, TicTacToeConfig.value_size), output.value.shape)
    self.assertEqual((1, TicTacToeConfig.action_size),
                     output.policy_logits.shape)
    self.assertEqual((1, TicTacToeConfig.hidden_size),
                     output.hidden_state.shape)
class TicTacToeNetworkInitializerTest(unittest.TestCase):
  """Checks the sub-networks and encoders produced by TicTacToeInitializer."""

  def setUp(self):
    # Make tests reproducible.
    np.random.seed(0)
    tf.random.set_seed(0)
    self.env = TicTacToeEnv()
    self.network_initializer = TicTacToeInitializer()
    (self.prediction_network, self.dynamics_network,
     self.representation_network, self.dynamics_encoder,
     self.representation_encoder) = self.network_initializer.initialize()

  def test_prediction_network(self):
    input_image = np.array(self.env.get_states()).reshape(
        -1, TicTacToeConfig.action_size)
    policy_logits, value = self.prediction_network(input_image)
    # assertEqual (instead of assertTrue(a == b)) reports both values on failure.
    self.assertEqual((1, TicTacToeConfig.value_size), value.shape)
    self.assertEqual((1, TicTacToeConfig.action_size), policy_logits.shape)

  def test_representation_network(self):
    input_image = np.array(self.env.get_states()).reshape(
        -1, TicTacToeConfig.action_size)
    hidden_state = self.representation_network(input_image)
    self.assertEqual([1, TicTacToeConfig.hidden_size], hidden_state.shape)

  def test_dynamics_network(self):
    hidden_state = np.zeros(
        (TicTacToeConfig.batch_size, TicTacToeConfig.hidden_size))
    action = Action(0)
    encoded_state = self.dynamics_encoder.encode(hidden_state, action)
    hidden_state, reward = self.dynamics_network(encoded_state)
    # NOTE(review): the reward value is intentionally not asserted here;
    # the original test had it disabled.
    self.assertEqual((1, TicTacToeConfig.hidden_size), hidden_state.shape)

  def test_encoded_dynamics_state(self):
    hidden_state = np.zeros(
        (TicTacToeConfig.batch_size, TicTacToeConfig.hidden_size))
    action = Action(0)
    encoded_state = self.dynamics_encoder.encode(hidden_state, action)
    # TODO(FJUR): This should encode 2 planes, a 1 hot plane with the selected
    # action, and a binary plane whether the move was valid or not (all 0's
    # or 1's).
    self.assertEqual((1, TicTacToeConfig.representation_size),
                     encoded_state.shape)
class BasicMctsModelTest(unittest.TestCase):
  """Tests BasicMctsModel rollout predictions and simulation stepping."""

  def setUp(self):
    self.env = TicTacToeEnv()
    self.dynamics_model = BasicMctsModel(self.env)

  def _make_checked_board(self):
    """Builds the shared fixture board and sanity-checks it first."""
    board = [0] * 9
    board[4] = 1
    board[0] = 4
    # Game is not over yet and the opponent responds with action 1.
    self.assertEqual((False, 0.0), self.env.check(board))
    self.assertEqual(1, self.env.opponent_play(board))
    return board

  def test_get_predicted_value_and_final_info_discounted(self):
    self.dynamics_model = BasicMctsModel(self.env, discount=0.9)
    board = self._make_checked_board()
    value, final_info = (
        self.dynamics_model.get_predicted_value_and_final_info(board))
    self.assertEqual(-.9, value)
    # Game ended for an illegal move.
    self.assertEqual([4, 4, 0, 0, 1, 1, 0, 0, 0], final_info[0])  # states
    self.assertEqual(5, final_info[1])  # action

  def test_get_predicted_value_and_final_info(self):
    board = self._make_checked_board()
    value, final_info = (
        self.dynamics_model.get_predicted_value_and_final_info(board))
    self.assertEqual(-1., value)
    # Game ended for an illegal move.
    self.assertEqual([4, 4, 0, 0, 1, 1, 0, 0, 0], final_info[0])  # states
    self.assertEqual(5, final_info[1])  # action

  def test_step(self):
    self.dynamics_model.step([0] * 9, 0)
    # The above is a simulation step, so it should not affect the real
    # environment.
    self.assertEqual([0] * 9, self.env.get_states())
class MctsCoreTicTacToeTest(unittest.TestCase):
  """Exercises MctsCore end-to-end against a real TicTacToeEnv.

  Uses BasicMctsModel as the dynamics model and inspects the search tree via
  its string representation ({v: visits, p: prior, v_sum: value sum,
  s: states, r: reward, c: children} — inferred from the expected strings
  below; confirm against the node's __repr__).
  """

  def setUp(self):
    self.env = TicTacToeEnv()
    self.model = BasicMctsModel(self.env)
    self.core = MctsCore(env=self.env, model=self.model)

  def test_rollout(self):
    self.core.initialize()
    # Freshly initialized root: zero visits, uniform priors, nine unexpanded
    # children (one per board cell).
    # TODO: test more succinctly.
    self.assertEqual('{v: 0, p: 1.0, v_sum: 0, s: [0, 0, 0, 0, 0, 0, 0, 0, 0], r: 0, c: {0: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 1: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 2: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 3: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 4: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 5: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 6: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 7: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 8: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}}}',
                     str(self.core.get_root_for_testing()))
    self.core.rollout()
    # After one rollout the search expanded child 8; the simulated value -1.0
    # was backpropagated to the root (v: 1, v_sum: -1.0).
    # TODO: test more succinctly.
    self.assertEqual('{v: 1, p: 1.0, v_sum: -1.0, s: [0, 0, 0, 0, 0, 0, 0, 0, 0], r: 0, c: {0: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 1: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 2: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 3: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 4: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 5: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 6: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 7: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 8: {v: 1, p: 1.0, v_sum: -1.0, s: [4, 0, 0, 0, 0, 0, 0, 0, 1], r: 0.0, c: {0: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 1: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 2: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 3: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 4: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 5: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 6: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 7: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 8: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}}}}}',
                     str(self.core.get_root_for_testing()))
    # rollout is a simulation, so it should not affect the real environment.
    self.assertEqual([0] * 9, self.env.get_states())
    self.core.rollout()
    # TODO: check nodes after rollout

  def get_ucb_distribution(self, node):
    """Returns the node's children's UCB scores, ordered by action index."""
    # A list of (ucb, action, child).
    l1 = list(self.core.get_ucb_distribution(node))
    l2 = [(action, ucb) for ucb, action, child in l1]
    l3 = sorted(l2)
    # Returns the list of ucb in the order of action.
    return [ucb for _, ucb in l3]

  def test_inside_initial_rollouts(self):
    # Walks through one rollout step by step: select -> expand -> evaluate
    # -> backpropagate, asserting the intermediate state at each stage.
    self.core.initialize()
    root = self.core.get_root_for_testing()
    # Unvisited children all score the same UCB of 0.
    self.assertEqual([0.] * 9, self.get_ucb_distribution(root))
    node1, search_path, last_action = self.core.select_node()
    self.assertNotEqual(root, node1)
    self.assertEqual([root, node1], search_path)
    # We can choose any action since the ucb distribution is uniform over
    # actions.
    self.assertEqual(8, last_action)
    self.core.expand_node(node1)
    self.assertTrue(node1.expanded())
    parent = root
    self.assertIsNotNone(parent.states)
    value = self.core.evaluate_node(node1, parent.states, last_action)
    self.assertEqual(0., node1.reward)
    # Opponent (4) placed X at the first empty space.
    self.assertEqual([4] + [0] * 7 + [1], node1.states)
    # It is likely to lose the game when simulating this.
    self.assertEqual(-1., value)
    self.assertFalse(node1.is_final)
    self.core.backpropagate(search_path, value)
    # Action of 8 yielded a reward of 0. The action has been discounted.
    # TODO: verify that the numbers are correct.
    np.testing.assert_almost_equal([1.2501018] * 8 + [-0.3749491],
                                   self.get_ucb_distribution(root))
    # We visited only action 8. The result is somewhat counter-intuitive so
    # far, but the policy is 100% on action 8.
    np.testing.assert_array_equal([0.] * 8 + [1.],
                                  self.core.get_policy_distribution())