Code Example #1
File: alpha_zero.py  Project: ngrupen/open_spiel
def learner(*, game, config, actors, evaluators, broadcast_fn, logger):
  """A learner that consumes the replay buffer and trains the network."""
  logger.also_to_stdout = True
  replay_buffer = Buffer(config.replay_buffer_size)
  learn_rate = config.replay_buffer_size // config.replay_buffer_reuse
  logger.print("Initializing model")
  model = _init_model_from_config(config)
  logger.print("Model type: %s(%s, %s)" % (config.nn_model, config.nn_width,
                                           config.nn_depth))
  logger.print("Model size:", model.num_trainable_variables, "variables")
  save_path = model.save_checkpoint(0)
  logger.print("Initial checkpoint:", save_path)
  broadcast_fn(save_path)

  data_log = data_logger.DataLoggerJsonLines(config.path, "learner", True)

  stage_count = 7
  value_accuracies = [stats.BasicStats() for _ in range(stage_count)]
  value_predictions = [stats.BasicStats() for _ in range(stage_count)]
  game_lengths = stats.BasicStats()
  game_lengths_hist = stats.HistogramNumbered(game.max_game_length() + 1)
  outcomes = stats.HistogramNamed(["Player1", "Player2", "Draw"])
  evals = [Buffer(config.evaluation_window) for _ in range(config.eval_levels)]
  total_trajectories = 0

  def trajectory_generator():
    """Merge all the actor queues into a single generator."""
    while True:
      found = 0
      for actor_process in actors:
        try:
          yield actor_process.queue.get_nowait()
        except spawn.Empty:
          pass
        else:
          found += 1
      if found == 0:
        time.sleep(0.01)  # 10ms

  def collect_trajectories():
    """Collects the trajectories from actors into the replay buffer."""
    num_trajectories = 0
    num_states = 0
    for trajectory in trajectory_generator():
      num_trajectories += 1
      num_states += len(trajectory.states)
      game_lengths.add(len(trajectory.states))
      game_lengths_hist.add(len(trajectory.states))

      p1_outcome = trajectory.returns[0]
      if p1_outcome > 0:
        outcomes.add(0)
      elif p1_outcome < 0:
        outcomes.add(1)
      else:
        outcomes.add(2)

      replay_buffer.extend(
          model_lib.TrainInput(
              s.observation, s.legals_mask, s.policy, p1_outcome)
          for s in trajectory.states)

      for stage in range(stage_count):
        # Scale for the length of the game
        index = (len(trajectory.states) - 1) * stage // (stage_count - 1)
        n = trajectory.states[index]
        accurate = (n.value >= 0) == (trajectory.returns[n.current_player] >= 0)
        value_accuracies[stage].add(1 if accurate else 0)
        value_predictions[stage].add(abs(n.value))

      if num_states >= learn_rate:
        break
    return num_trajectories, num_states

  def learn(step):
    """Sample from the replay buffer, update weights and save a checkpoint."""
    losses = []
    for _ in range(len(replay_buffer) // config.train_batch_size):
      data = replay_buffer.sample(config.train_batch_size)
      losses.append(model.update(data))

    # Always save a checkpoint, either for keeping or for loading the weights to
    # the actors. It only allows numbers, so use -1 as "latest".
    save_path = model.save_checkpoint(
        step if step % config.checkpoint_freq == 0 else -1)
    losses = sum(losses, model_lib.Losses(0, 0, 0)) / len(losses)
    logger.print(losses)
    logger.print("Checkpoint saved:", save_path)
    return save_path, losses

  last_time = time.time() - 60
  for step in itertools.count(1):
    for value_accuracy in value_accuracies:
      value_accuracy.reset()
    for value_prediction in value_predictions:
      value_prediction.reset()
    game_lengths.reset()
    game_lengths_hist.reset()
    outcomes.reset()

    num_trajectories, num_states = collect_trajectories()
    total_trajectories += num_trajectories
    now = time.time()
    seconds = now - last_time
    last_time = now

    logger.print("Step:", step)
    logger.print(
        ("Collected {:5} states from {:3} games, {:.1f} states/s. "
         "{:.1f} states/(s*actor), game length: {:.1f}").format(
             num_states, num_trajectories, num_states / seconds,
             num_states / (config.actors * seconds),
             num_states / num_trajectories))
    logger.print("Buffer size: {}. States seen: {}".format(
        len(replay_buffer), replay_buffer.total_seen))

    save_path, losses = learn(step)

    for eval_process in evaluators:
      while True:
        try:
          difficulty, outcome = eval_process.queue.get_nowait()
          evals[difficulty].append(outcome)
        except spawn.Empty:
          break

    batch_size_stats = stats.BasicStats()  # Only makes sense in C++.
    batch_size_stats.add(1)
    data_log.write({
        "step": step,
        "total_states": replay_buffer.total_seen,
        "states_per_s": num_states / seconds,
        "states_per_s_actor": num_states / (config.actors * seconds),
        "total_trajectories": total_trajectories,
        "trajectories_per_s": num_trajectories / seconds,
        "queue_size": 0,  # Only available in C++.
        "game_length": game_lengths.as_dict,
        "game_length_hist": game_lengths_hist.data,
        "outcomes": outcomes.data,
        "value_accuracy": [v.as_dict for v in value_accuracies],
        "value_prediction": [v.as_dict for v in value_predictions],
        "eval": {
            "count": evals[0].total_seen,
            "results": [sum(e.data) / len(e) if e else 0 for e in evals],
        },
        "batch_size": batch_size_stats.as_dict,
        "batch_size_hist": [0, 1],
        "loss": {
            "policy": losses.policy,
            "value": losses.value,
            "l2reg": losses.l2,
            "sum": losses.total,
        },
        "cache": {  # Null stats because it's hard to report between processes.
            "size": 0,
            "max_size": 0,
            "usage": 0,
            "requests": 0,
            "requests_per_s": 0,
            "hits": 0,
            "misses": 0,
            "misses_per_s": 0,
            "hit_rate": 0,
        },
    })
    logger.print()

    if config.max_steps > 0 and step >= config.max_steps:
      break

    broadcast_fn(save_path)
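
Both examples rely on a Buffer helper for the replay buffer and the per-level evaluation windows. The sketch below is not taken from either project; it is a minimal illustration, written only to make the examples easier to follow, of a buffer consistent with how it is used in example #1 (constructed with a maximum size, filled with append/extend, sampled with sample, and exposing len(), truthiness and a total_seen counter).

import random


class Buffer(object):
  """A fixed-size buffer keeping only the newest values (illustrative sketch)."""

  def __init__(self, max_size):
    self.max_size = max_size
    self.data = []
    self.total_seen = 0  # Count of every item ever added, even if since evicted.

  def __len__(self):
    return len(self.data)

  def __bool__(self):
    return bool(self.data)

  def append(self, val):
    return self.extend([val])

  def extend(self, batch):
    batch = list(batch)
    self.total_seen += len(batch)
    self.data.extend(batch)
    self.data[:-self.max_size] = []  # Keep only the newest max_size items.

  def sample(self, count):
    return random.sample(self.data, count)
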
Code Example #2
File: dual_mcts.py  Project: prashankkadam/Dual-MCTS
def learner(*, game, config, config_mpv, actors_1, actors_2, evaluators_1,
            evaluators_2, broadcast_fn, logger):
    """A learner that consumes the replay buffer and trains the network."""

    stage_count = 7
    total_trajectories_1, total_trajectories_2 = 0, 0

    def learner_inner(config_inner):
        logger.also_to_stdout = True
        replay_buffer = Buffer(config_inner.replay_buffer_size)
        learn_rate = config_inner.replay_buffer_size // config_inner.replay_buffer_reuse
        logger.print("Initializing model")
        model = _init_model_from_config(config_inner)
        logger.print("Model type: %s(%s, %s)" %
                     (config_inner.nn_model, config_inner.nn_width,
                      config_inner.nn_depth))
        logger.print("Model size:", model.num_trainable_variables, "variables")
        save_path = model.save_checkpoint(0)
        logger.print("Initial checkpoint:", save_path)
        broadcast_fn(save_path)

        value_accuracies = [stats.BasicStats() for _ in range(stage_count)]
        value_predictions = [stats.BasicStats() for _ in range(stage_count)]
        game_lengths = stats.BasicStats()
        game_lengths_hist = stats.HistogramNumbered(game.max_game_length() + 1)
        outcomes = stats.HistogramNamed(["Player1", "Player2", "Draw"])
        evals = [
            Buffer(config_inner.evaluation_window)
            for _ in range(config_inner.eval_levels)
        ]

        return (replay_buffer, learn_rate, model, save_path, value_accuracies,
                value_predictions, game_lengths, game_lengths_hist, outcomes,
                evals)

    (replay_buffer_1, learn_rate_1, model_1, save_path, value_accuracies_1,
     value_predictions_1, game_lengths_1, game_lengths_hist_1, outcomes_1,
     evals_1) = learner_inner(config)

    data_log_1 = data_logger.DataLoggerJsonLines(config.path, "learner_1",
                                                 True)

    (replay_buffer_2, learn_rate_2, model_2, save_path, value_accuracies_2,
     value_predictions_2, game_lengths_2, game_lengths_hist_2, outcomes_2,
     evals_2) = learner_inner(config_mpv)

    data_log_2 = data_logger.DataLoggerJsonLines(config_mpv.path, "learner_2",
                                                 True)

    def trajectory_generator(actors_gen):
        """Merge all the actor queues into a single generator."""
        while True:
            found = 0
            for actor_process in actors_gen:
                try:
                    yield actor_process.queue.get_nowait()
                except spawn.Empty:
                    pass
                else:
                    found += 1
            if found == 0:
                time.sleep(0.01)  # 10ms

    def collect_trajectories(game_lengths, game_lengths_hist, outcomes,
                             replay_buffer, value_accuracies,
                             value_predictions, learn_rate, actors):
        """Collects the trajectories from actors into the replay buffer."""
        num_trajectories = 0
        num_states = 0
        for trajectory in trajectory_generator(actors):
            num_trajectories += 1
            num_states += len(trajectory.states)
            game_lengths.add(len(trajectory.states))
            game_lengths_hist.add(len(trajectory.states))

            p1_outcome = trajectory.returns[0]
            if p1_outcome > 0:
                outcomes.add(0)
            elif p1_outcome < 0:
                outcomes.add(1)
            else:
                outcomes.add(2)

            replay_buffer.extend(
                model_lib.TrainInput(s.observation, s.legals_mask, s.policy,
                                     p1_outcome) for s in trajectory.states)

            for stage in range(stage_count):
                # Scale for the length of the game
                index = ((len(trajectory.states) - 1) * stage //
                         (stage_count - 1))
                n = trajectory.states[index]
                accurate = ((n.value >= 0) ==
                            (trajectory.returns[n.current_player] >= 0))
                value_accuracies[stage].add(1 if accurate else 0)
                value_predictions[stage].add(abs(n.value))

            if num_states >= learn_rate:
                break
        return num_trajectories, num_states

    def learn(step, replay_buffer, model, config_learn, model_num):
        """Sample from the replay buffer, update weights and save a checkpoint."""
        losses = []
        mpv_upd = Buffer(len(replay_buffer) // 3)
        for i in range(len(replay_buffer) // config_learn.train_batch_size):
            data = replay_buffer.sample(config_learn.train_batch_size)
            losses.append(model.update(data))  # weight update
            if (i + 1) % 4 == 0:
                # Keep every 4th batch as replay samples for the bigger network.
                mpv_upd.append_buffer(data)

        # Always save a checkpoint, either for keeping or for loading the weights to
        # the actors. It only allows numbers, so use -1 as "latest".
        save_path = model.save_checkpoint(
            step if step % config_learn.checkpoint_freq == 0 else -1)
        losses = sum(losses, model_lib.Losses(0, 0, 0)) / len(losses)
        logger.print(losses)
        logger.print("Checkpoint saved:", save_path)

        if model_num == 1:
            return save_path, losses, mpv_upd
        else:
            return save_path, losses

    last_time = time.time() - 60
    for step in itertools.count(1):
        for value_accuracy_1, value_accuracy_2 in zip(value_accuracies_1,
                                                      value_accuracies_2):
            value_accuracy_1.reset()
            value_accuracy_2.reset()
        for value_prediction_1, value_prediction_2 in zip(
                value_predictions_1, value_predictions_2):
            value_prediction_1.reset()
            value_prediction_2.reset()
        game_lengths_1.reset()
        game_lengths_2.reset()
        game_lengths_hist_1.reset()
        game_lengths_hist_2.reset()
        outcomes_1.reset()
        outcomes_2.reset()

        num_trajectories_1, num_states_1 = collect_trajectories(
            game_lengths_1, game_lengths_hist_1, outcomes_1, replay_buffer_1,
            value_accuracies_1, value_predictions_1, learn_rate_1, actors_1)
        total_trajectories_1 += num_trajectories_1
        now = time.time()
        seconds = now - last_time
        last_time = now

        logger.print("Step:", step)
        logger.print(
            ("Collected {:5} states from {:3} games, {:.1f} states/s. "
             "{:.1f} states/(s*actor), game length: {:.1f}").format(
                 num_states_1, num_trajectories_1, num_states_1 / seconds,
                 num_states_1 / (config.actors * seconds),
                 num_states_1 / num_trajectories_1))
        logger.print("Buffer size: {}. States seen: {}".format(
            len(replay_buffer_1), replay_buffer_1.total_seen))

        save_path, losses_1, mpv_upd_1 = learn(step, replay_buffer_1, model_1,
                                               config, 1)

        def update_buffer(mpv_upd, replay_buffer, config_buffer):
            """Swaps part of the buffer for the samples kept aside in learn()."""
            # Evict roughly a quarter of the buffer: one random batch for every
            # four batches trained on the smaller network above.
            for _ in range(
                (len(replay_buffer) // config_buffer.train_batch_size) // 4):
                sampled_list = random.sample(
                    range(len(replay_buffer)), config_buffer.train_batch_size)
                replay_buffer.remove_buffer(sampled_list)

            # Merge in the batches set aside for the bigger network and reshuffle.
            replay_buffer.append_buffer(mpv_upd)
            random.shuffle(replay_buffer)

            return replay_buffer

        for eval_process in evaluators_1:
            while True:
                try:
                    difficulty, outcome = eval_process.queue.get_nowait()
                    evals_1[difficulty].append(outcome)
                except spawn.Empty:
                    break

        batch_size_stats = stats.BasicStats()  # Only makes sense in C++.
        batch_size_stats.add(1)

        data_log_1.write({
            "step": step,
            "total_states": replay_buffer_1.total_seen,
            "states_per_s": num_states_1 / seconds,
            "states_per_s_actor": num_states_1 / (config.actors * seconds),
            "total_trajectories": total_trajectories_1,
            "trajectories_per_s": num_trajectories_1 / seconds,
            "queue_size": 0,  # Only available in C++.
            "game_length": game_lengths_1.as_dict,
            "game_length_hist": game_lengths_hist_1.data,
            "outcomes": outcomes_1.data,
            "value_accuracy": [v.as_dict for v in value_accuracies_1],
            "value_prediction": [v.as_dict for v in value_predictions_1],
            "eval": {
                "count": evals_1[0].total_seen,
                "results": [sum(e.data) / len(e) for e in evals_1],
            },
            "batch_size": batch_size_stats.as_dict,
            "batch_size_hist": [0, 1],
            "loss": {
                "policy": losses_1.policy,
                "value": losses_1.value,
                "l2reg": losses_1.l2,
                "sum": losses_1.total,
            },
            "cache": {  # Null stats because it's hard to report between processes.
                "size": 0,
                "max_size": 0,
                "usage": 0,
                "requests": 0,
                "requests_per_s": 0,
                "hits": 0,
                "misses": 0,
                "misses_per_s": 0,
                "hit_rate": 0,
            },
        })
        logger.print()

        num_trajectories_2, num_states_2 = collect_trajectories(
            game_lengths_2, game_lengths_hist_2, outcomes_2, replay_buffer_2,
            value_accuracies_2, value_predictions_2, learn_rate_2, actors_2)
        total_trajectories_2 += num_trajectories_2
        now = time.time()
        seconds = now - last_time
        last_time = now

        logger.print("Step:", step)
        logger.print(
            ("Collected {:5} states from {:3} games, {:.1f} states/s. "
             "{:.1f} states/(s*actor), game length: {:.1f}").format(
                 num_states_2, num_trajectories_2, num_states_2 / seconds,
                 num_states_2 / (config.actors * seconds),
                 num_states_2 / num_trajectories_2))
        logger.print("Buffer size: {}. States seen: {}".format(
            len(replay_buffer_2), replay_buffer_2.total_seen))

        replay_buffer_2 = update_buffer(mpv_upd_1, replay_buffer_2, config_mpv)

        save_path, losses_2 = learn(step, replay_buffer_2, model_2, config_mpv,
                                    2)

        for eval_process in evaluators_2:
            while True:
                try:
                    difficulty, outcome = eval_process.queue.get_nowait()
                    evals_2[difficulty].append(outcome)
                except spawn.Empty:
                    break

        data_log_2.write({
            "step": step,
            "total_states": replay_buffer_2.total_seen,
            "states_per_s": num_states_2 / seconds,
            "states_per_s_actor": num_states_2 / (config.actors * seconds),
            "total_trajectories": total_trajectories_2,
            "trajectories_per_s": num_trajectories_2 / seconds,
            "queue_size": 0,  # Only available in C++.
            "game_length": game_lengths_2.as_dict,
            "game_length_hist": game_lengths_hist_2.data,
            "outcomes": outcomes_2.data,
            "value_accuracy": [v.as_dict for v in value_accuracies_2],
            "value_prediction": [v.as_dict for v in value_predictions_2],
            "eval": {
                "count": evals_2[0].total_seen,
                "results": [sum(e.data) / len(e) for e in evals_2],
            },
            "batch_size": batch_size_stats.as_dict,
            "batch_size_hist": [0, 1],
            "loss": {
                "policy": losses_2.policy,
                "value": losses_2.value,
                "l2reg": losses_2.l2,
                "sum": losses_2.total,
            },
            "cache": {  # Null stats because it's hard to report between processes.
                "size": 0,
                "max_size": 0,
                "usage": 0,
                "requests": 0,
                "requests_per_s": 0,
                "hits": 0,
                "misses": 0,
                "misses_per_s": 0,
                "hit_rate": 0,
            },
        })
        logger.print()

        if config.max_steps > 0 and step >= config.max_steps:
            break

        broadcast_fn(save_path)
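
In addition to the operations shown after example #1, example #2 iterates over its buffers, indexes into them, shuffles them with random.shuffle, and calls append_buffer and remove_buffer, so the Dual-MCTS project evidently uses an extended buffer. The class below is a hypothetical illustration (building on the Buffer sketch after example #1) of what those extra operations might look like; in the real project they are presumably defined on Buffer itself, and the exact semantics may differ.

class DualBuffer(Buffer):  # Hypothetical name, used here only for illustration.
    """Adds the extra operations that example #2 calls on its buffers."""

    def __iter__(self):
        return iter(self.data)

    def __getitem__(self, i):
        return self.data[i]

    def __setitem__(self, i, val):
        # __len__, __getitem__ and __setitem__ let random.shuffle() work in place.
        self.data[i] = val

    def append_buffer(self, items):
        # Accept either another buffer or a plain list of training samples.
        self.extend(items.data if hasattr(items, "data") else items)

    def remove_buffer(self, indices):
        # Remove the items at the given positions, highest index first.
        for i in sorted(indices, reverse=True):
            del self.data[i]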