def learner(*, game, config, actors, evaluators, broadcast_fn, logger):
  """A learner that consumes the replay buffer and trains the network."""
  logger.also_to_stdout = True
  replay_buffer = Buffer(config.replay_buffer_size)
  learn_rate = config.replay_buffer_size // config.replay_buffer_reuse
  logger.print("Initializing model")
  model = _init_model_from_config(config)
  logger.print("Model type: %s(%s, %s)" % (config.nn_model, config.nn_width,
                                           config.nn_depth))
  logger.print("Model size:", model.num_trainable_variables, "variables")
  save_path = model.save_checkpoint(0)
  logger.print("Initial checkpoint:", save_path)
  broadcast_fn(save_path)

  data_log = data_logger.DataLoggerJsonLines(config.path, "learner", True)

  stage_count = 7
  value_accuracies = [stats.BasicStats() for _ in range(stage_count)]
  value_predictions = [stats.BasicStats() for _ in range(stage_count)]
  game_lengths = stats.BasicStats()
  game_lengths_hist = stats.HistogramNumbered(game.max_game_length() + 1)
  outcomes = stats.HistogramNamed(["Player1", "Player2", "Draw"])
  evals = [Buffer(config.evaluation_window) for _ in range(config.eval_levels)]
  total_trajectories = 0

  def trajectory_generator():
    """Merge all the actor queues into a single generator."""
    while True:
      found = 0
      for actor_process in actors:
        try:
          yield actor_process.queue.get_nowait()
        except spawn.Empty:
          pass
        else:
          found += 1
      if found == 0:
        time.sleep(0.01)  # 10ms

  def collect_trajectories():
    """Collects the trajectories from actors into the replay buffer."""
    num_trajectories = 0
    num_states = 0
    for trajectory in trajectory_generator():
      num_trajectories += 1
      num_states += len(trajectory.states)
      game_lengths.add(len(trajectory.states))
      game_lengths_hist.add(len(trajectory.states))

      p1_outcome = trajectory.returns[0]
      if p1_outcome > 0:
        outcomes.add(0)
      elif p1_outcome < 0:
        outcomes.add(1)
      else:
        outcomes.add(2)

      replay_buffer.extend(
          model_lib.TrainInput(
              s.observation, s.legals_mask, s.policy, p1_outcome)
          for s in trajectory.states)

      for stage in range(stage_count):
        # Scale for the length of the game
        index = (len(trajectory.states) - 1) * stage // (stage_count - 1)
        n = trajectory.states[index]
        accurate = (n.value >= 0) == (trajectory.returns[n.current_player] >= 0)
        value_accuracies[stage].add(1 if accurate else 0)
        value_predictions[stage].add(abs(n.value))

      if num_states >= learn_rate:
        break
    return num_trajectories, num_states

  def learn(step):
    """Sample from the replay buffer, update weights and save a checkpoint."""
    losses = []
    for _ in range(len(replay_buffer) // config.train_batch_size):
      data = replay_buffer.sample(config.train_batch_size)
      losses.append(model.update(data))

    # Always save a checkpoint, either for keeping or for loading the weights to
    # the actors. It only allows numbers, so use -1 as "latest".
    save_path = model.save_checkpoint(
        step if step % config.checkpoint_freq == 0 else -1)
    losses = sum(losses, model_lib.Losses(0, 0, 0)) / len(losses)
    logger.print(losses)
    logger.print("Checkpoint saved:", save_path)
    return save_path, losses

  last_time = time.time() - 60
  for step in itertools.count(1):
    for value_accuracy in value_accuracies:
      value_accuracy.reset()
    for value_prediction in value_predictions:
      value_prediction.reset()
    game_lengths.reset()
    game_lengths_hist.reset()
    outcomes.reset()

    num_trajectories, num_states = collect_trajectories()
    total_trajectories += num_trajectories
    now = time.time()
    seconds = now - last_time
    last_time = now

    logger.print("Step:", step)
    logger.print(
        ("Collected {:5} states from {:3} games, {:.1f} states/s. "
         "{:.1f} states/(s*actor), game length: {:.1f}").format(
             num_states, num_trajectories, num_states / seconds,
             num_states / (config.actors * seconds),
             num_states / num_trajectories))
    logger.print("Buffer size: {}. States seen: {}".format(
        len(replay_buffer), replay_buffer.total_seen))

    save_path, losses = learn(step)

    for eval_process in evaluators:
      while True:
        try:
          difficulty, outcome = eval_process.queue.get_nowait()
          evals[difficulty].append(outcome)
        except spawn.Empty:
          break

    batch_size_stats = stats.BasicStats()  # Only makes sense in C++.
    batch_size_stats.add(1)
    data_log.write({
        "step": step,
        "total_states": replay_buffer.total_seen,
        "states_per_s": num_states / seconds,
        "states_per_s_actor": num_states / (config.actors * seconds),
        "total_trajectories": total_trajectories,
        "trajectories_per_s": num_trajectories / seconds,
        "queue_size": 0,  # Only available in C++.
        "game_length": game_lengths.as_dict,
        "game_length_hist": game_lengths_hist.data,
        "outcomes": outcomes.data,
        "value_accuracy": [v.as_dict for v in value_accuracies],
        "value_prediction": [v.as_dict for v in value_predictions],
        "eval": {
            "count": evals[0].total_seen,
            "results": [sum(e.data) / len(e) if e else 0 for e in evals],
        },
        "batch_size": batch_size_stats.as_dict,
        "batch_size_hist": [0, 1],
        "loss": {
            "policy": losses.policy,
            "value": losses.value,
            "l2reg": losses.l2,
            "sum": losses.total,
        },
        "cache": {  # Null stats because it's hard to report between processes.
            "size": 0,
            "max_size": 0,
            "usage": 0,
            "requests": 0,
            "requests_per_s": 0,
            "hits": 0,
            "misses": 0,
            "misses_per_s": 0,
            "hit_rate": 0,
        },
    })
    logger.print()

    if config.max_steps > 0 and step >= config.max_steps:
      break

    broadcast_fn(save_path)
def learner(*, game, config, config_mpv, actors_1, actors_2, evaluators_1,
            evaluators_2, broadcast_fn, logger):
  """A learner that consumes both replay buffers and trains both networks:
  the small network (config) and the bigger MPV network (config_mpv)."""
  stage_count = 7
  total_trajectories_1, total_trajectories_2 = 0, 0

  def learner_inner(config_inner):
    """Builds the model, replay buffer and stats for one of the two configs."""
    logger.also_to_stdout = True
    replay_buffer = Buffer(config_inner.replay_buffer_size)
    learn_rate = (config_inner.replay_buffer_size //
                  config_inner.replay_buffer_reuse)
    logger.print("Initializing model")
    model = _init_model_from_config(config_inner)
    logger.print("Model type: %s(%s, %s)" % (config_inner.nn_model,
                                             config_inner.nn_width,
                                             config_inner.nn_depth))
    logger.print("Model size:", model.num_trainable_variables, "variables")
    save_path = model.save_checkpoint(0)
    logger.print("Initial checkpoint:", save_path)
    broadcast_fn(save_path)

    value_accuracies = [stats.BasicStats() for _ in range(stage_count)]
    value_predictions = [stats.BasicStats() for _ in range(stage_count)]
    game_lengths = stats.BasicStats()
    game_lengths_hist = stats.HistogramNumbered(game.max_game_length() + 1)
    outcomes = stats.HistogramNamed(["Player1", "Player2", "Draw"])
    evals = [
        Buffer(config_inner.evaluation_window)
        for _ in range(config_inner.eval_levels)
    ]
    return (replay_buffer, learn_rate, model, save_path, value_accuracies,
            value_predictions, game_lengths, game_lengths_hist, outcomes,
            evals)

  (replay_buffer_1, learn_rate_1, model_1, save_path, value_accuracies_1,
   value_predictions_1, game_lengths_1, game_lengths_hist_1, outcomes_1,
   evals_1) = learner_inner(config)
  data_log_1 = data_logger.DataLoggerJsonLines(config.path, "learner_1", True)

  (replay_buffer_2, learn_rate_2, model_2, save_path, value_accuracies_2,
   value_predictions_2, game_lengths_2, game_lengths_hist_2, outcomes_2,
   evals_2) = learner_inner(config_mpv)
  data_log_2 = data_logger.DataLoggerJsonLines(config_mpv.path, "learner_2",
                                               True)

  def trajectory_generator(actors_gen):
    """Merge all the actor queues into a single generator."""
    while True:
      found = 0
      for actor_process in actors_gen:
        try:
          yield actor_process.queue.get_nowait()
        except spawn.Empty:
          pass
        else:
          found += 1
      if found == 0:
        time.sleep(0.01)  # 10ms

  def collect_trajectories(game_lengths, game_lengths_hist, outcomes,
                           replay_buffer, value_accuracies, value_predictions,
                           learn_rate, actors):
    """Collects the trajectories from actors into the replay buffer."""
    num_trajectories = 0
    num_states = 0
    for trajectory in trajectory_generator(actors):
      num_trajectories += 1
      num_states += len(trajectory.states)
      game_lengths.add(len(trajectory.states))
      game_lengths_hist.add(len(trajectory.states))

      p1_outcome = trajectory.returns[0]
      if p1_outcome > 0:
        outcomes.add(0)
      elif p1_outcome < 0:
        outcomes.add(1)
      else:
        outcomes.add(2)

      replay_buffer.extend(
          model_lib.TrainInput(s.observation, s.legals_mask, s.policy,
                               p1_outcome)
          for s in trajectory.states)

      for stage in range(stage_count):
        # Scale for the length of the game
        index = (len(trajectory.states) - 1) * stage // (stage_count - 1)
        n = trajectory.states[index]
        accurate = (n.value >= 0) == (trajectory.returns[n.current_player] >= 0)
        value_accuracies[stage].add(1 if accurate else 0)
        value_predictions[stage].add(abs(n.value))

      if num_states >= learn_rate:
        break
    return num_trajectories, num_states

  def learn(step, replay_buffer, model, config_learn, model_num):
    """Sample from the replay buffer, update weights and save a checkpoint."""
    losses = []
    mpv_upd = Buffer(len(replay_buffer) // 3)
    for i in range(len(replay_buffer) // config_learn.train_batch_size):
      data = replay_buffer.sample(config_learn.train_batch_size)
      losses.append(model.update(data))  # Weight update.
      if (i + 1) % 4 == 0:
        # Keep every 4th sampled batch to later feed the bigger MPV network.
        mpv_upd.append_buffer(data)

    # Always save a checkpoint, either for keeping or for loading the weights to
    # the actors. It only allows numbers, so use -1 as "latest".
    save_path = model.save_checkpoint(
        step if step % config_learn.checkpoint_freq == 0 else -1)
    losses = sum(losses, model_lib.Losses(0, 0, 0)) / len(losses)
    logger.print(losses)
    logger.print("Checkpoint saved:", save_path)
    if model_num == 1:
      # Only the small network exports samples for the MPV buffer.
      return save_path, losses, mpv_upd
    return save_path, losses

  def update_buffer(mpv_upd, replay_buffer, config_buffer):
    """Evicts a quarter of the buffer's batches and mixes in model-1 samples."""
    for _ in range(
        (len(replay_buffer) // config_buffer.train_batch_size) // 4):
      sampled_list = random.sample(
          range(len(replay_buffer)), config_buffer.train_batch_size)
      replay_buffer.remove_buffer(sampled_list)
    replay_buffer.append_buffer(mpv_upd)
    random.shuffle(replay_buffer.data)
    return replay_buffer

  last_time = time.time() - 60
  for step in itertools.count(1):
    for value_accuracy_1, value_accuracy_2 in zip(value_accuracies_1,
                                                  value_accuracies_2):
      value_accuracy_1.reset()
      value_accuracy_2.reset()
    for value_prediction_1, value_prediction_2 in zip(value_predictions_1,
                                                      value_predictions_2):
      value_prediction_1.reset()
      value_prediction_2.reset()
    game_lengths_1.reset()
    game_lengths_2.reset()
    game_lengths_hist_1.reset()
    game_lengths_hist_2.reset()
    outcomes_1.reset()
    outcomes_2.reset()

    # --- Model 1 (small network): collect, train and log. ---
    num_trajectories_1, num_states_1 = collect_trajectories(
        game_lengths_1, game_lengths_hist_1, outcomes_1, replay_buffer_1,
        value_accuracies_1, value_predictions_1, learn_rate_1, actors_1)
    total_trajectories_1 += num_trajectories_1
    now = time.time()
    seconds = now - last_time
    last_time = now

    logger.print("Step:", step)
    logger.print(
        ("Collected {:5} states from {:3} games, {:.1f} states/s. "
         "{:.1f} states/(s*actor), game length: {:.1f}").format(
             num_states_1, num_trajectories_1, num_states_1 / seconds,
             num_states_1 / (config.actors * seconds),
             num_states_1 / num_trajectories_1))
    logger.print("Buffer size: {}. States seen: {}".format(
        len(replay_buffer_1), replay_buffer_1.total_seen))

    save_path, losses_1, mpv_upd_1 = learn(step, replay_buffer_1, model_1,
                                           config, 1)

    for eval_process in evaluators_1:
      while True:
        try:
          difficulty, outcome = eval_process.queue.get_nowait()
          evals_1[difficulty].append(outcome)
        except spawn.Empty:
          break

    batch_size_stats = stats.BasicStats()  # Only makes sense in C++.
    batch_size_stats.add(1)
    data_log_1.write({
        "step": step,
        "total_states": replay_buffer_1.total_seen,
        "states_per_s": num_states_1 / seconds,
        "states_per_s_actor": num_states_1 / (config.actors * seconds),
        "total_trajectories": total_trajectories_1,
        "trajectories_per_s": num_trajectories_1 / seconds,
        "queue_size": 0,  # Only available in C++.
        "game_length": game_lengths_1.as_dict,
        "game_length_hist": game_lengths_hist_1.data,
        "outcomes": outcomes_1.data,
        "value_accuracy": [v.as_dict for v in value_accuracies_1],
        "value_prediction": [v.as_dict for v in value_predictions_1],
        "eval": {
            "count": evals_1[0].total_seen,
            "results": [sum(e.data) / len(e) if e else 0 for e in evals_1],
        },
        "batch_size": batch_size_stats.as_dict,
        "batch_size_hist": [0, 1],
        "loss": {
            "policy": losses_1.policy,
            "value": losses_1.value,
            "l2reg": losses_1.l2,
            "sum": losses_1.total,
        },
        "cache": {  # Null stats because it's hard to report between processes.
            "size": 0,
            "max_size": 0,
            "usage": 0,
            "requests": 0,
            "requests_per_s": 0,
            "hits": 0,
            "misses": 0,
            "misses_per_s": 0,
            "hit_rate": 0,
        },
    })
    logger.print()

    # --- Model 2 (MPV network): collect, mix in model-1 samples, train, log. ---
    num_trajectories_2, num_states_2 = collect_trajectories(
        game_lengths_2, game_lengths_hist_2, outcomes_2, replay_buffer_2,
        value_accuracies_2, value_predictions_2, learn_rate_2, actors_2)
    total_trajectories_2 += num_trajectories_2
    now = time.time()
    seconds = now - last_time
    last_time = now

    logger.print("Step:", step)
    logger.print(
        ("Collected {:5} states from {:3} games, {:.1f} states/s. "
         "{:.1f} states/(s*actor), game length: {:.1f}").format(
             num_states_2, num_trajectories_2, num_states_2 / seconds,
             num_states_2 / (config.actors * seconds),
             num_states_2 / num_trajectories_2))
    logger.print("Buffer size: {}. States seen: {}".format(
        len(replay_buffer_2), replay_buffer_2.total_seen))

    replay_buffer_2 = update_buffer(mpv_upd_1, replay_buffer_2, config_mpv)
    save_path, losses_2 = learn(step, replay_buffer_2, model_2, config_mpv, 2)

    for eval_process in evaluators_2:
      while True:
        try:
          difficulty, outcome = eval_process.queue.get_nowait()
          evals_2[difficulty].append(outcome)
        except spawn.Empty:
          break

    data_log_2.write({
        "step": step,
        "total_states": replay_buffer_2.total_seen,
        "states_per_s": num_states_2 / seconds,
        "states_per_s_actor": num_states_2 / (config.actors * seconds),
        "total_trajectories": total_trajectories_2,
        "trajectories_per_s": num_trajectories_2 / seconds,
        "queue_size": 0,  # Only available in C++.
        "game_length": game_lengths_2.as_dict,
        "game_length_hist": game_lengths_hist_2.data,
        "outcomes": outcomes_2.data,
        "value_accuracy": [v.as_dict for v in value_accuracies_2],
        "value_prediction": [v.as_dict for v in value_predictions_2],
        "eval": {
            "count": evals_2[0].total_seen,
            "results": [sum(e.data) / len(e) if e else 0 for e in evals_2],
        },
        "batch_size": batch_size_stats.as_dict,
        "batch_size_hist": [0, 1],
        "loss": {
            "policy": losses_2.policy,
            "value": losses_2.value,
            "l2reg": losses_2.l2,
            "sum": losses_2.total,
        },
        "cache": {  # Null stats because it's hard to report between processes.
            "size": 0,
            "max_size": 0,
            "usage": 0,
            "requests": 0,
            "requests_per_s": 0,
            "hits": 0,
            "misses": 0,
            "misses_per_s": 0,
            "hit_rate": 0,
        },
    })
    logger.print()

    if config.max_steps > 0 and step >= config.max_steps:
      break

    broadcast_fn(save_path)