Example #1
class NetworkTest(unittest.TestCase):
    def setUp(self):
        # Make tests reproducible.
        np.random.seed(0)
        tf.random.set_seed(0)
        self.env = TicTacToeEnv()
        self.network_initializer = TicTacToeInitializer()
        self.network = Network(self.network_initializer)
        self.model = MuZeroMctsModel(self.env, self.network)

    def test_initial_inference(self):
        game_state = np.array(self.env.get_states()).reshape(
            -1, TicTacToeConfig.action_size)
        output = self.network.initial_inference(game_state)
        self.assertTrue(output.reward == 0)
        # self.assertTrue(output.value.shape == (1, 2*TicTacToeConfig.support_size + 1))
        self.assertTrue(output.value.shape == (1, TicTacToeConfig.value_size))
        self.assertTrue(
            output.policy_logits.shape == (1, TicTacToeConfig.action_size))
        self.assertTrue(
            output.hidden_state.shape == (1, TicTacToeConfig.hidden_size))

    def test_recurrent_inference(self):
        game_state = np.array(self.env.get_states()).reshape(
            -1, TicTacToeConfig.action_size)
        action = Action(0)
        output = self.network.recurrent_inference(game_state, action)
        #self.assertEqual(0, output.reward)
        # self.assertTrue(output.value.shape == (1, 2*TicTacToeConfig.support_size + 1))
        self.assertTrue(output.value.shape == (1, TicTacToeConfig.value_size))
        self.assertTrue(
            output.policy_logits.shape == (1, TicTacToeConfig.action_size))
        self.assertTrue(
            output.hidden_state.shape == (1, TicTacToeConfig.hidden_size))
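The assertions in Example #1 only pin down tensor shapes through TicTacToeConfig; the constants themselves never appear in this listing. A minimal stand-in that would satisfy those shape checks might look like the sketch below, where every concrete value is an assumption for illustration rather than a value from the project.

class TicTacToeConfig:
    # All values are assumed for illustration only.
    action_size = 9           # one action per cell of the 3x3 board
    value_size = 1            # scalar value head; the commented-out checks hint
                              # at an alternative 2 * support_size + 1 categorical head
    support_size = 1          # assumed; only relevant to the categorical head
    hidden_size = 64          # assumed width of the latent (hidden) state
    batch_size = 1            # assumed
    representation_size = 73  # assumed width of the encoded (hidden state, action) input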
Example #2
def setUp(self):
    # Make tests reproducible.
    np.random.seed(0)
    tf.random.set_seed(0)
    self.env = TicTacToeEnv()
    self.network_initializer = TicTacToeInitializer()
    self.network = Network(self.network_initializer)
    self.model = MuZeroMctsModel(self.env, self.network)
Example #3
def setUp(self):
    # Make tests reproducible.
    np.random.seed(0)
    tf.random.set_seed(0)

    self.env = TicTacToeEnv()
    self.network_initializer = TicTacToeInitializer()
    self.prediction_network, self.dynamics_network, self.representation_network, self.dynamics_encoder, self.representation_encoder = self.network_initializer.initialize(
    )
Example #4
class TicTacToeNetworkInitializerTest(unittest.TestCase):
    def setUp(self):
        # Make tests reproducible.
        np.random.seed(0)
        tf.random.set_seed(0)

        self.env = TicTacToeEnv()
        self.network_initializer = TicTacToeInitializer()
        self.prediction_network, self.dynamics_network, self.representation_network, self.dynamics_encoder, self.representation_encoder = self.network_initializer.initialize(
        )

    def test_prediction_network(self):
        input_image = np.array(self.env.get_states()).reshape(
            -1, TicTacToeConfig.action_size)
        policy_logits, value = self.prediction_network(input_image)
        # self.assertTrue(value.shape == (1, 2*TicTacToeConfig.support_size + 1))
        self.assertTrue(value.shape == (1, TicTacToeConfig.value_size))
        self.assertTrue(policy_logits.shape == (1,
                                                TicTacToeConfig.action_size))

    def test_representation_network(self):
        input_image = np.array(self.env.get_states()).reshape(
            -1, TicTacToeConfig.action_size)
        hidden_state = self.representation_network(input_image)

        self.assertEqual([1, TicTacToeConfig.hidden_size], hidden_state.shape)

        print('LAYERS: ', self.representation_network.layers)
        print('INPUTS: ', self.representation_network.inputs)
        print('OUTPUTS: ', self.representation_network.outputs)
        print('SUMMARY: ', self.representation_network.summary())

        # self.assertEqual(output.value, np.zeros([1, 2*support_size + 1]))
        # self.assertTrue(output.reward == 0)
        # self.assertTrue(output.reward == 0)

    def test_dynamics_network(self):
        hidden_state = np.zeros(
            (TicTacToeConfig.batch_size, TicTacToeConfig.hidden_size))
        action = Action(0)
        encoded_state = self.dynamics_encoder.encode(hidden_state, action)
        hidden_state, reward = self.dynamics_network(encoded_state)
        #self.assertTrue(reward == 0)
        self.assertTrue(hidden_state.shape == (1, TicTacToeConfig.hidden_size))

    def test_encoded_dynamics_state(self):
        # hidden_state = np.zeros((TicTacToeConfig.batch_size, TicTacToeConfig.hidden_size))
        hidden_state = np.zeros(TicTacToeConfig.hidden_size)
        action = Action(0)
        hidden_state = np.zeros(
            (TicTacToeConfig.batch_size, TicTacToeConfig.hidden_size))
        encoded_state = self.dynamics_encoder.encode(hidden_state, action)
        #TODO(FJUR): This should encode 2 planes, a 1 hot plane with the selected action,
        # and a binary plane whether the move was valid or not (all 0's or 1's)
        self.assertTrue(
            encoded_state.shape == (1, TicTacToeConfig.representation_size))
Example #5
def initialize(self, use_random, r_seed):
    self.env = TicTacToeEnv(use_random=use_random, r_seed=r_seed)
    self.network_initializer = TicTacToeInitializer()
    self.network = Network(self.network_initializer)
    self.replay_buffer = ReplayBuffer()
    self.rng = np.random.RandomState(0)
    self.policy = MuZeroCollectionPolicy(self.env,
                                         self.network,
                                         self.replay_buffer,
                                         num_simulations=100,
                                         discount=1.,
                                         rng=self.rng)
Example #6
def play_game_once(self, r_seed):
    self.env = TicTacToeEnv(use_random=True, r_seed=r_seed)
    self.model = BasicMctsModel(self.env, r_seed=r_seed)
    self.policy = MctsPolicy(self.env,
                             self.model,
                             num_simulations=100,
                             r_seed=r_seed)
    while True:
        action = self.policy.action()
        states_isfinal_reward = self.env.step(action)
        states, is_final, reward = states_isfinal_reward
        if is_final:
            return states, is_final, reward
Example #7
def get_agents(args: argparse.Namespace = get_args(),
               agent_learn: Optional[BasePolicy] = None,
               agent_opponent: Optional[BasePolicy] = None,
               optim: Optional[torch.optim.Optimizer] = None,
               ) -> Tuple[BasePolicy, torch.optim.Optimizer]:
    env = TicTacToeEnv(args.board_size, args.win_size)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    if agent_learn is None:
        # model
        net = Net(args.layer_num, args.state_shape, args.action_shape,
                  args.device).to(args.device)
        if optim is None:
            optim = torch.optim.Adam(net.parameters(), lr=args.lr)
        agent_learn = DQNPolicy(
            net, optim, args.gamma, args.n_step,
            target_update_freq=args.target_update_freq)
        if args.resume_path:
            agent_learn.load_state_dict(torch.load(args.resume_path))

    if agent_opponent is None:
        if args.opponent_path:
            agent_opponent = deepcopy(agent_learn)
            agent_opponent.load_state_dict(torch.load(args.opponent_path))
        else:
            agent_opponent = RandomPolicy()

    if args.agent_id == 1:
        agents = [agent_learn, agent_opponent]
    else:
        agents = [agent_opponent, agent_learn]
    policy = MultiAgentPolicyManager(agents)
    return policy, optim
Example #8
def watch(args: argparse.Namespace = get_args(),
          agent_learn: Optional[BasePolicy] = None,
          agent_opponent: Optional[BasePolicy] = None,
          ) -> None:
    env = TicTacToeEnv(args.board_size, args.win_size)
    policy, optim = get_agents(
        args, agent_learn=agent_learn, agent_opponent=agent_opponent)
    collector = Collector(policy, env)
    result = collector.collect(n_episode=1, render=args.render)
    print(f'Final reward: {result["rew"]}, length: {result["len"]}')
Example #9
class BasicMctsModelTest(unittest.TestCase):
    def setUp(self):
        self.env = TicTacToeEnv()
        self.dynamics_model = BasicMctsModel(self.env)

    def test_get_predicted_value_and_final_info_discounted(self):
        self.dynamics_model = BasicMctsModel(self.env, discount=0.9)
        # Check some conditions first.
        states = [0] * 9
        states[4] = 1
        states[0] = 4
        self.assertEqual((False, 0.0), self.env.check(states))
        self.assertEqual(1, self.env.opponent_play(states))

        predicted_value, final_info = self.dynamics_model.get_predicted_value_and_final_info(
            states)
        self.assertEqual(-.9, predicted_value)
        # Game ended due to an illegal move.
        self.assertEqual([4, 4, 0, 0, 1, 1, 0, 0, 0], final_info[0])  # states
        self.assertEqual(5, final_info[1])  # action

    def test_get_predicted_value_and_final_info(self):
        # Check some conditions first.
        states = [0] * 9
        states[4] = 1
        states[0] = 4
        self.assertEqual((False, 0.0), self.env.check(states))
        self.assertEqual(1, self.env.opponent_play(states))

        predicted_value, final_info = self.dynamics_model.get_predicted_value_and_final_info(
            states)
        self.assertEqual(-1., predicted_value)
        # Game ended due to an illegal move.
        self.assertEqual([4, 4, 0, 0, 1, 1, 0, 0, 0], final_info[0])  # states
        self.assertEqual(5, final_info[1])  # action

    def test_step(self):
        self.dynamics_model.step([0] * 9, 0)
        # The above is a simulation step, so it should not affect the real environment.
        self.assertEqual([0] * 9, self.env.get_states())
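The two expected values in this test class differ only by the discount: the simulation reaches a terminal reward of -1 after a single simulated step, so the prediction behaves like discount ** steps * final_reward. This reading is inferred from the asserted numbers alone, not from BasicMctsModel's source.

final_reward, steps = -1.0, 1
assert 0.9 ** steps * final_reward == -0.9  # BasicMctsModel(self.env, discount=0.9)
assert 1.0 ** steps * final_reward == -1.0  # default model from setUp (discount assumed to be 1.0)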
Example #10
def watch(
    args: argparse.Namespace = get_args(),
    agent_learn: Optional[BasePolicy] = None,
    agent_opponent: Optional[BasePolicy] = None,
) -> None:
    env = TicTacToeEnv(args.board_size, args.win_size)
    policy, optim = get_agents(
        args, agent_learn=agent_learn, agent_opponent=agent_opponent)
    policy.eval()
    policy.policies[args.agent_id - 1].set_eps(args.eps_test)
    collector = Collector(policy, env, exploration_noise=True)
    result = collector.collect(n_episode=1, render=args.render)
    rews, lens = result["rews"], result["lens"]
    print(f"Final reward: {rews[:, args.agent_id - 1].mean()}, length: {lens.mean()}")
Example #11
class MctsPolicyTicTacToeTest(unittest.TestCase):
    def setUp(self):
        self.env = TicTacToeEnv()
        self.model = BasicMctsModel(self.env)
        self.policy = MctsPolicy(self.env, self.model, num_simulations=100)

    def test_action_start(self):
        action = self.policy.action()
        states_isfinal_reward = self.env.step(action)
        self.assertEqual(0, action)
        self.assertEqual(([1, 4, 0, 0, 0, 0, 0, 0, 0], False, 0.0),
                         states_isfinal_reward)

    def test_action_win(self):
        self.env.set_states([1, 0, 1, 1, 0, 4, 4, 4, 0])
        action = self.policy.action()
        states_isfinal_reward = self.env.step(action)
        self.assertEqual(1, action)
        self.assertEqual(([1, 1, 1, 1, 0, 4, 4, 4, 0], True, 1.0),
                         states_isfinal_reward)

    def test_action_win_2(self):
        self.env.set_states([1, 1, 4, 0, 0, 4, 1, 4, 0])
        action = self.policy.action()
        states_isfinal_reward = self.env.step(action)
        self.assertEqual(3, action)
        self.assertEqual(([1, 1, 4, 1, 0, 4, 1, 4, 0], True, 1.0),
                         states_isfinal_reward)

    def test_policy_logits(self):
        logits = self.policy.get_policy_logits()
        tf.assert_equal(
            tf.constant([0.14, 0.09, 0.13, 0.09, 0.13, 0.11, 0.09, 0.11, 0.11],
                        dtype=tf.float64), logits)

    def test_choose_action(self):
        self.assertEqual(
            1,
            self.policy.choose_action(
                tf.constant([
                    0.11, 0.116, 0.11, 0.11, 0.11, 0.111, 0.111, 0.111, 0.111
                ])))

    def test_game_deterministic(self):
        while True:
            action = self.policy.action()
            states_isfinal_reward = self.env.step(action)
            states, is_final, reward = states_isfinal_reward
            if is_final:
                break
        self.assertEqual(1.0, reward)

    def play_game_once(self, r_seed):
        self.env = TicTacToeEnv(use_random=True, r_seed=r_seed)
        self.model = BasicMctsModel(self.env, r_seed=r_seed)
        self.policy = MctsPolicy(self.env,
                                 self.model,
                                 num_simulations=100,
                                 r_seed=r_seed)
        while True:
            action = self.policy.action()
            states_isfinal_reward = self.env.step(action)
            states, is_final, reward = states_isfinal_reward
            if is_final:
                return states, is_final, reward

    def test_game_random(self):
        reward_dict = collections.defaultdict(int)
        for r_seed in range(100):
            _, _, reward = self.play_game_once(r_seed)
            reward_dict[reward] += 1
        print('reward distribution: ', reward_dict)
        # 96% winning ratio.
        self.assertEqual({1.0: 96, 0.0: 1, -1.0: 3}, reward_dict)
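In test_choose_action the returned action is simply the index of the largest logit (0.116 at index 1), so a greedy argmax stand-in reproduces the expectation; whether the real MctsPolicy samples from the distribution with a seeded RNG instead is not visible from this listing.

import tensorflow as tf

def greedy_choose_action(logits):
    # Hypothetical stand-in: pick the action with the highest logit.
    return int(tf.argmax(logits).numpy())

assert greedy_choose_action(
    tf.constant([0.11, 0.116, 0.11, 0.11, 0.11, 0.111, 0.111, 0.111, 0.111])) == 1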
Example #12
class TicTacToeEnvTest(unittest.TestCase):
    def setUp(self):
        self.env = TicTacToeEnv()

    def test_check(self):
        self.assertEqual((False, 0.), self.env.check([0] * 8 + [1]))
        self.assertEqual((True, -1.),
                         self.env.check([4, 4, 4, 0, 1, 1, 0, 0, 0]))
        self.assertEqual((True, 1.),
                         self.env.check([4, 4, 0, 1, 1, 1, 0, 0, 0]))
        self.assertEqual((False, 0.),
                         self.env.check([4, 4, 1, 1, 4, 1, 1, 0, 0]))
        self.assertEqual((True, 0.),
                         self.env.check([4, 1, 4, 1, 4, 1, 1, 4, 1]))

    def test_legal_actions(self):
        states = [0] * 9
        states[3] = 1
        states[7] = 4
        states[8] = 1
        self.assertEqual([0, 1, 2, 4, 5, 6], self.env.legal_actions(states))

    def test_opponent_play(self):
        # Chooses the first available space.
        self.assertEqual(0, self.env.opponent_play([0] * 8 + [1]))
        self.assertEqual(8, self.env.opponent_play([1] * 8 + [0]))

    def test_opponent_play_random(self):
        self.env = TicTacToeEnv(r_seed=0, use_random=True)
        s = set()
        for i in range(100):
            s.add(self.env.opponent_play([0, 1, 4, 0, 0, 0, 0, 0, 0]))
        self.assertEqual([0] + list(range(3, 9)), list(s))

    def test_step(self):
        self.env.set_states([4, 4, 0, 0, 1, 1, 0, 0, 0])
        states, is_final, reward = self.env.step(3)
        self.assertEqual([4, 4, 0, 1, 1, 1, 0, 0, 0], states)
        self.assertTrue(is_final)
        self.assertEqual(1., reward)
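The expected tuples in test_check fully determine the win/draw logic for this board encoding (0 = empty, 1 = agent, 4 = opponent). A minimal check function consistent with those cases is sketched below; it is written only from the asserted outputs, not from the project's actual TicTacToeEnv implementation.

LINES = [(0, 1, 2), (3, 4, 5), (6, 7, 8),  # rows
         (0, 3, 6), (1, 4, 7), (2, 5, 8),  # columns
         (0, 4, 8), (2, 4, 6)]             # diagonals

def check(states):
    # Returns (is_final, reward) from the agent's point of view.
    for a, b, c in LINES:
        if states[a] == states[b] == states[c] != 0:
            return True, 1.0 if states[a] == 1 else -1.0
    if all(v != 0 for v in states):  # board full with no line: draw
        return True, 0.0
    return False, 0.0                # game still in progress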
Example #13
class MuZeroCollectionPolicyTicTacToeTest(unittest.TestCase):
    def setUp(self):
        # Make tests reproducible.
        np.random.seed(0)
        tf.random.set_seed(0)

        self.initialize(False, 0)

    def initialize(self, use_random, r_seed):
        self.env = TicTacToeEnv(use_random=use_random, r_seed=r_seed)
        self.network_initializer = TicTacToeInitializer()
        self.network = Network(self.network_initializer)
        self.replay_buffer = ReplayBuffer()
        self.rng = np.random.RandomState(0)
        self.policy = MuZeroCollectionPolicy(self.env,
                                             self.network,
                                             self.replay_buffer,
                                             num_simulations=100,
                                             discount=1.,
                                             rng=self.rng)

    def test_action_start(self):
        action = self.policy.action()
        # All corners are optimal first actions.
        # TODO: fix this
        #self.assertIn(action, [0, 2, 6, 8])
        self.assertEqual(1, action)

    def test_action_win(self):
        self.env.set_states([1, 0, 1, 1, 0, 4, 4, 4, 0])
        action = self.policy.action()
        # TODO: fix this to be 1.
        # self.assertEqual(1, action)
        # self.assertEqual(5, action)

    def test_action_win_2(self):
        self.env.set_states([1, 1, 4, 0, 0, 4, 1, 4, 0])
        action = self.policy.action()
        # TODO: fix this to be 3.
        # self.assertEqual(3, action)
        # self.assertEqual(4, action)

    def test_policy_logits(self):
        pass
        # TODO: fix this to provide correct logits.
        logits = self.policy.get_policy_logits()
        # tf.assert_equal(tf.constant([0.14, 0.09, 0.13, 0.09, 0.13, 0.11, 0.09, 0.11, 0.11],
        #                             dtype=tf.float64), logits)

    def test_choose_action(self):
        self.assertEqual(
            1,
            self.policy.choose_action(
                tf.constant([
                    0.11, 0.116, 0.11, 0.11, 0.11, 0.111, 0.111, 0.111, 0.111
                ])))

    def test_game_deterministic(self):
        while True:
            action = self.policy.action()
            states_isfinal_reward = self.env.step(action)
            states, is_final, reward = states_isfinal_reward
            if is_final:
                break
        # TODO: fix this to win.
        self.assertEqual(-1.0, reward)

    def test_run_self_play(self):
        self.policy.run_self_play()
        self.assertEqual(1, len(self.replay_buffer.buffer))
        traj = self.replay_buffer.buffer[0]
        self.assertEqual([1, 0], traj.action_history)
        self.assertEqual([0., -1.], traj.rewards)

    def play_game_once(self, r_seed):
        self.initialize(True, r_seed)
        while True:
            action = self.policy.action()
            states_isfinal_reward = self.env.step(action)
            states, is_final, reward = states_isfinal_reward
            if is_final:
                return states, is_final, reward
Example #14
def setUp(self):
    self.env = TicTacToeEnv()
    self.dynamics_model = BasicMctsModel(self.env)
Example #15
def test_opponent_play_random(self):
    self.env = TicTacToeEnv(r_seed=0, use_random=True)
    s = set()
    for i in range(100):
        s.add(self.env.opponent_play([0, 1, 4, 0, 0, 0, 0, 0, 0]))
    self.assertEqual([0] + list(range(3, 9)), list(s))
Example #16
def setUp(self):
    self.env = TicTacToeEnv()
Example #17
def setUp(self):
    self.env = TicTacToeEnv()
    self.model = BasicMctsModel(self.env)
    self.policy = MctsPolicy(self.env, self.model, num_simulations=100)
Example #18
def env_func():
    return TicTacToeEnv(args.board_size, args.win_size)
Example #19
def setUp(self):
    self.env = TicTacToeEnv()
    self.model = BasicMctsModel(self.env)
    self.core = MctsCore(env=self.env, model=self.model)
Example #20
class MctsCoreTicTacToeTest(unittest.TestCase):

    def setUp(self):
        self.env = TicTacToeEnv()
        self.model = BasicMctsModel(self.env)
        self.core = MctsCore(env=self.env, model=self.model)

    def test_rollout(self):
        self.core.initialize()
        # TODO: test more succinctly.
        self.assertEqual('{v: 0, p: 1.0, v_sum: 0, s: [0, 0, 0, 0, 0, 0, 0, 0, 0], r: 0, c: {0: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 1: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 2: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 3: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 4: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 5: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 6: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 7: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 8: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}}}',
                         str(self.core.get_root_for_testing()))
        self.core.rollout()
        # TODO: test more succinctly.
        self.assertEqual('{v: 1, p: 1.0, v_sum: -1.0, s: [0, 0, 0, 0, 0, 0, 0, 0, 0], r: 0, c: {0: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 1: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 2: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 3: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 4: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 5: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 6: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 7: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 8: {v: 1, p: 1.0, v_sum: -1.0, s: [4, 0, 0, 0, 0, 0, 0, 0, 1], r: 0.0, c: {0: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 1: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 2: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 3: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 4: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 5: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 6: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 7: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}, 8: {v: 0, p: 1.0, v_sum: 0, s: None, r: 0, c: {}}}}}}',
                         str(self.core.get_root_for_testing()))
        # Rollout should not affect the actual environment.
        self.assertEqual([0] * 9, self.env.get_states())
        self.core.rollout()
        # TODO: check nodes after rollout

    def get_ucb_distribution(self, node):
        # A list of (ucb, action, child).
        l1 = list(self.core.get_ucb_distribution(node))
        l2 = [(action, ucb) for ucb, action, child in l1]
        l3 = sorted(l2)
        # Returns the list of ucb in the order of action.
        return [ucb for _, ucb in l3]

    def test_inside_initial_rollouts(self):
        self.core.initialize()
        root = self.core.get_root_for_testing()

        self.assertEqual([0.] * 9, self.get_ucb_distribution(root))

        node1, search_path, last_action = self.core.select_node()

        self.assertNotEqual(root, node1)
        self.assertEqual([root, node1], search_path)
        # We can choose any action since the ucb distribution is uniform over actions.
        self.assertEqual(8, last_action)

        self.core.expand_node(node1)

        self.assertTrue(node1.expanded())
        parent = root
        self.assertIsNotNone(parent.states)

        value = self.core.evaluate_node(node1, parent.states, last_action)

        self.assertEqual(0., node1.reward)
        # Opponent (4) placed X at the first empty space.
        self.assertEqual([4] + [0] * 7 + [1], node1.states)
        # It is likely to lose the game when simulating this.
        self.assertEqual(-1., value)
        self.assertFalse(node1.is_final)

        self.core.backpropagate(search_path, value)

        # Action of 8 yielded a reward of 0. The action has been discounted.
        # TODO: verify that the numbers are correct.
        np.testing.assert_almost_equal([1.2501018] * 8 + [-0.3749491],
                                       self.get_ucb_distribution(root))

        # We visited only action 8. The result is somewhat counter-intuitive so
        # far, but the policy is 100% on action 8.
        np.testing.assert_array_equal([0.] * 8 + [1.],
                                      self.core.get_policy_distribution())
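The two UCB values asserted in test_inside_initial_rollouts are consistent with MuZero's pUCT rule under the reference constants pb_c_base = 19652 and pb_c_init = 1.25, a discount of 1.0, and no value normalization. Those constants are an assumption about this implementation, but the arithmetic reproduces the test's numbers:

import math

pb_c_base, pb_c_init = 19652, 1.25
parent_visits = 1  # the root has been visited once after the first rollout
prior = 1.0        # every child has p = 1.0 in the tree dump above
pb_c = math.log((parent_visits + pb_c_base + 1) / pb_c_base) + pb_c_init

unvisited = pb_c * math.sqrt(parent_visits) / (0 + 1) * prior         # no value term yet
visited = pb_c * math.sqrt(parent_visits) / (1 + 1) * prior + (-1.0)  # backed-up value of -1

print(round(unvisited, 7), round(visited, 7))  # 1.2501018 -0.3749491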
Example #21
import time

import tensorflow as tf

from network import Network
from muzero_collection_policy import MuZeroCollectionPolicy
from muzero_eval_policy import MuZeroEvalPolicy
from replay_buffer import ReplayBuffer
from tic_tac_toe_env import TicTacToeEnv
# TicTacToeInitializer is instantiated below but not imported in the original
# snippet; import it from wherever the project defines it.

TRAIN_ITERATIONS = 200

PLAY_ITERATIONS = 20

NUM_TRAIN_STEPS = 20
NUM_UNROLL_STEPS = 5

initializer = TicTacToeInitializer()
env = TicTacToeEnv()

#env = PacmanDetEnv()
network = Network(initializer)
replay_buffer = ReplayBuffer()
col_policy = MuZeroCollectionPolicy(env, network, replay_buffer)
eval_policy = MuZeroEvalPolicy(env, network, replay_buffer)

for train_iter in range(TRAIN_ITERATIONS):
    print('STARTING TRAINING ITERATION #{}'.format(train_iter))
    tf.summary.experimental.set_step(train_iter)
    for play_iter in range(PLAY_ITERATIONS):
        print('STARTING PLAY ITERATION #{}'.format(play_iter))
        start_time = time.time()
        col_policy.run_self_play()
        end_time = time.time()