Example #1
    def _self_play_single(self):
        """Play a single game and add it to the replay buffer."""
        state = self.game.new_initial_state()
        trajectory = []

        while not state.is_terminal():
            root = self.bot.mcts_search(state)
            # Policy target: the root's visit-count distribution, normalized.
            target_policy = np.zeros(self.game.num_distinct_actions(),
                                     dtype=np.float32)
            for child in root.children:
                target_policy[child.action] = child.explore_count
            target_policy /= sum(target_policy)

            # The value stored here is the MCTS estimate at the root; it is
            # overwritten with the true game outcome once the game ends.
            trajectory.append(
                model_lib.TrainInput(state.observation_tensor(),
                                     state.legal_actions_mask(), target_policy,
                                     root.total_reward / root.explore_count))

            action = self._select_action(root.children, len(trajectory))
            state.apply_action(action)

        # Relabel every stored position with the terminal reward from player 0's
        # perspective before adding it to the replay buffer.
        terminal_rewards = state.rewards()
        for train_input in trajectory:
            self.replay_buffer.add(
                model_lib.TrainInput(train_input.observation,
                                     train_input.legals_mask,
                                     train_input.policy, terminal_rewards[0]))
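
The method above calls self._select_action, which is not included in this snippet. A minimal sketch of what such a helper could look like, assuming AlphaZero-style move selection (sample proportionally to visit counts early in the game, then play greedily); the name, the 10-move cutoff, and the sampling scheme are assumptions, not the library's actual code:

    def _select_action(self, children, game_step):
        # Hypothetical helper: sample proportionally to visit counts for the
        # first few moves (exploration), then play the most-visited child.
        counts = np.array([child.explore_count for child in children],
                          dtype=np.float64)
        if game_step < 10:  # Assumed temperature cutoff, not a library value.
            probs = counts / counts.sum()
            chosen = np.random.choice(len(children), p=probs)
        else:
            chosen = int(np.argmax(counts))
        return children[chosen].action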
Example #2
  def collect_trajectories():
    """Collects the trajectories from actors into the replay buffer."""
    num_trajectories = 0
    num_states = 0
    for trajectory in trajectory_generator():
      num_trajectories += 1
      num_states += len(trajectory.states)
      game_lengths.add(len(trajectory.states))
      game_lengths_hist.add(len(trajectory.states))

      # Outcome histogram buckets: 0 = player 1 win, 1 = loss, 2 = draw.
      p1_outcome = trajectory.returns[0]
      if p1_outcome > 0:
        outcomes.add(0)
      elif p1_outcome < 0:
        outcomes.add(1)
      else:
        outcomes.add(2)

      replay_buffer.extend(
          model_lib.TrainInput(
              s.observation, s.legals_mask, s.policy, p1_outcome)
          for s in trajectory.states)

      for stage in range(stage_count):
        # Sample states evenly spaced through the game, scaled to its length,
        # to track how value accuracy evolves from opening to endgame.
        index = (len(trajectory.states) - 1) * stage // (stage_count - 1)
        n = trajectory.states[index]
        # A prediction counts as accurate if its sign matches the final return
        # from the current player's perspective.
        accurate = (n.value >= 0) == (trajectory.returns[n.current_player] >= 0)
        value_accuracies[stage].add(1 if accurate else 0)
        value_predictions[stage].add(abs(n.value))

      if num_states >= learn_rate:  # learn_rate: states to collect per learning step.
        break
    return num_trajectories, num_states
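
For intuition about the stage indexing above: with a stage_count of 4 and a 10-state trajectory, the index formula picks states spread evenly from the opening to the final position. A small standalone check in plain Python:

stage_count = 4
num_game_states = 10
indices = [(num_game_states - 1) * stage // (stage_count - 1)
           for stage in range(stage_count)]
print(indices)  # [0, 3, 6, 9]: opening, two mid-game points, final position.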
Example #3
  def test_model_learns_simple(self, model_type):
    game = pyspiel.load_game("tic_tac_toe")
    model = build_model(game, model_type)
    print("Num variables:", model.num_trainable_variables)
    model.print_trainable_variables()

    # Build a tiny supervised dataset from one scripted game: always play the
    # first legal action and label every position with value 1.
    train_inputs = []
    state = game.new_initial_state()
    while not state.is_terminal():
      obs = state.observation_tensor()
      act_mask = state.legal_actions_mask()
      action = state.legal_actions()[0]
      policy = np.zeros(len(act_mask), dtype=float)
      policy[action] = 1
      train_inputs.append(model_lib.TrainInput(obs, act_mask, policy, value=1))
      state.apply_action(action)
      # Sanity-check the shapes returned by a single-state inference call.
      value, policy = model.inference([obs], [act_mask])
      self.assertLen(policy, 1)
      self.assertLen(value, 1)
      self.assertLen(policy[0], game.num_distinct_actions())
      self.assertLen(value[0], 1)

    losses = []
    # Train until both heads fit the data, or give up after 1000 steps.
    for i in range(1000):
      loss = model.update(train_inputs)
      print(i, loss)
      losses.append(loss)
      if loss.policy < 0.05 and loss.value < 0.05:
        break

    self.assertGreater(losses[0].total, losses[-1].total)
    self.assertGreater(losses[0].policy, losses[-1].policy)
    self.assertGreater(losses[0].value, losses[-1].value)
    self.assertLess(losses[-1].value, 0.05)
    self.assertLess(losses[-1].policy, 0.05)
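
The build_model helper used by this test is not shown. One plausible sketch, assuming it wraps Model.build_model from OpenSpiel's alpha_zero model module; the width, depth, and other hyperparameters below are arbitrary placeholders, and the exact signature may differ across OpenSpiel versions:

def build_model(game, model_type):
  # Hypothetical helper; every numeric argument is a placeholder.
  return model_lib.Model.build_model(
      model_type,
      game.observation_tensor_shape(),
      game.num_distinct_actions(),
      nn_width=64,
      nn_depth=2,
      weight_decay=1e-4,
      learning_rate=0.01,
      path=None)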
Example #4
  def collect_trajectories():
    """Collects the trajectories from actors into the replay buffer."""
    num_trajectories = 0
    num_states = 0
    for actor_process in actors:
      # Drain each actor's queue without blocking; move on once it is empty.
      while True:
        try:
          trajectory = actor_process.queue.get_nowait()
        except spawn.Empty:
          break
        num_trajectories += 1
        num_states += len(trajectory.states)
        replay_buffer.extend(
            model_lib.TrainInput(s.observation, s.legals_mask,
                                 s.policy, trajectory.returns[0])
            for s in trajectory.states)
    return num_trajectories, num_states
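
The producer side of these queues is not shown. A minimal sketch of an actor loop under that assumption: play_one_game stands in for whatever self-play routine produces a finished trajectory, and the learner drains the queue with get_nowait() as above:

def actor_loop(queue, play_one_game):
  # Hypothetical actor body: `play_one_game` is an assumed callable returning
  # one finished self-play trajectory per call.
  while True:
    queue.put(play_one_game())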
Example #5
def solve_game(state):
  """Recursively computes the minimax value of `state` for player 0.

  Memoizes a TrainInput (observation, mask, one-hot best-action policy, value)
  for every non-terminal state in the module-level `solved` dict.
  """
  state_str = str(state)
  if state_str in solved:
    return solved[state_str].value
  if state.is_terminal():
    return state.returns()[0]

  max_player = state.current_player() == 0
  obs = state.observation_tensor()
  act_mask = np.array(state.legal_actions_mask())
  # Initialize with sentinels outside [-1, 1] so illegal (unvisited) actions
  # can never win the max/min below.
  values = np.full(act_mask.shape, -2 if max_player else 2)
  for action in state.legal_actions():
    values[action] = solve_game(state.child(action))
  value = values.max() if max_player else values.min()
  best_actions = np.where((values == value) & act_mask)
  policy = np.zeros_like(act_mask)
  policy[best_actions[0][0]] = 1  # Choose the first for a deterministic policy.
  solved[state_str] = model_lib.TrainInput(obs, act_mask, policy, value)
  return value
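
A minimal usage sketch (assumed, not part of the library): populate the module-level solved table starting from the initial tic-tac-toe state and treat its entries as a supervised training set:

solved = {}
game = pyspiel.load_game("tic_tac_toe")
solve_game(game.new_initial_state())
train_inputs = list(solved.values())  # One TrainInput per reachable non-terminal state.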
Example #6
  def test_evaluator_caching(self):
    game = pyspiel.load_game("tic_tac_toe")
    model = build_model(game)
    evaluator = evaluator_lib.AlphaZeroEvaluator(game, model)

    state = game.new_initial_state()
    obs = state.observation_tensor()
    act_mask = state.legal_actions_mask()
    action = state.legal_actions()[0]
    policy = np.zeros(len(act_mask), dtype=float)
    policy[action] = 1
    train_inputs = [model_lib.TrainInput(obs, act_mask, policy, value=1)]

    value = evaluator.evaluate(state)
    # Zero-sum game: the two players' values are negatives of each other.
    self.assertEqual(value[0], -value[1])
    value = value[0]

    value2 = evaluator.evaluate(state)[0]
    self.assertEqual(value, value2)

    prior = evaluator.prior(state)
    prior2 = evaluator.prior(state)
    np.testing.assert_array_equal(prior, prior2)

    info = evaluator.cache_info()
    self.assertEqual(info.misses, 1)
    self.assertEqual(info.hits, 3)

    for _ in range(20):
      model.update(train_inputs)

    # Still equal: the cache has not been cleared, even though the model changed.
    value3 = evaluator.evaluate(state)[0]
    self.assertEqual(value, value3)

    info = evaluator.cache_info()
    self.assertEqual(info.misses, 1)
    self.assertEqual(info.hits, 4)

    evaluator.clear_cache()

    info = evaluator.cache_info()
    self.assertEqual(info.misses, 0)
    self.assertEqual(info.hits, 0)

    # After clearing, fresh evaluations from the updated model differ from before.
    value4 = evaluator.evaluate(state)[0]
    value5 = evaluator.evaluate(state)[0]
    self.assertNotEqual(value, value4)
    self.assertEqual(value4, value5)

    info = evaluator.cache_info()
    self.assertEqual(info.misses, 1)
    self.assertEqual(info.hits, 1)

    value6 = evaluator.evaluate(game.new_initial_state())[0]
    self.assertEqual(value4, value6)

    info = evaluator.cache_info()
    self.assertEqual(info.misses, 1)
    self.assertEqual(info.hits, 2)
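
The behavior exercised above suggests a simple rule for training loops: clear the evaluator cache after each network update so later searches see fresh evaluations. A minimal sketch under that assumption, using only the methods shown in this test (num_training_steps is a placeholder):

for _ in range(num_training_steps):
  model.update(train_inputs)
  evaluator.clear_cache()  # Stale cached values would otherwise hide the update.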