def main():
    ray.init()
    # Hyperparameters of PPO are not well tuned. Most of them refer to
    # https://github.com/xtma/pytorch_car_caring/blob/master/train.py
    trainer = PPOTrainer(
        env=MyEnv,
        config={
            "use_pytorch": True,
            "model": {
                "custom_model": "mymodel",
                "custom_options": {
                    'encoder_path': args.encoder_path,
                    'train_encoder': args.train_encoder
                },
                "custom_action_dist": "mydist",
            },
            "env_config": {
                'game': 'CarRacing'
            },
            "num_workers": args.num_workers,
            "num_envs_per_worker": args.num_envs_per_worker,
            "num_gpus": args.num_gpus,
            "use_gae": args.use_gae,
            "batch_mode": args.batch_mode,
            "vf_loss_coeff": args.vf_loss_coeff,
            "vf_clip_param": args.vf_clip_param,
            "lr": args.lr,
            "kl_coeff": args.kl_coeff,
            "num_sgd_iter": args.num_sgd_iter,
            "grad_clip": args.grad_clip,
            "clip_param": args.clip_param,
            "rollout_fragment_length": args.rollout_fragment_length,
            "train_batch_size": args.train_batch_size,
            "sgd_minibatch_size": args.sgd_minibatch_size
        })

    for i in range(args.train_epochs):
        trainer.train()
        print("%d Train Done" % (i), "Save Freq: %d" % (args.model_save_freq))
        if (i + 1) % args.model_save_freq == 0:
            print("%d Episodes Done" % (i))
            weights = trainer.get_policy().get_weights()
            torch.save(weights, args.model_save_path + "%d-mode.pt" % (i + 1))

    trainer.save(args.trainer_save_path)
    print("Done All!")
    trainer.stop()
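# Hedged sketch (not part of the original script): how the policy weights saved
# above with torch.save() could be loaded back into a fresh trainer for a quick
# evaluation rollout. It assumes MyEnv, "mymodel" and "mydist" are registered as
# in main(); `weights_path` and the env_config dict passed to MyEnv are illustrative.
def evaluate(weights_path, config):
    eval_trainer = PPOTrainer(env=MyEnv, config=config)
    eval_trainer.get_policy().set_weights(torch.load(weights_path))
    env = MyEnv({'game': 'CarRacing'})
    obs, done, total_reward = env.reset(), False, 0.0
    while not done:
        action = eval_trainer.compute_action(obs)
        obs, reward, done, _ = env.step(action)
        total_reward += reward
    return total_reward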
def Hunter_trainer(config, reporter):
    # PPOTrainer takes the config as its first positional argument, so the
    # environment is passed explicitly via `env=`.
    multi_hunter_trainer = PPOTrainer(env=MultiHunterEnv, config=config)
    for _ in range(100):
        environment.simulate()  # `environment` is presumably defined elsewhere
        result = multi_hunter_trainer.train()
        result["phase"] = 1
        reporter(**result)
    phase1_time = result["timesteps_total"]
    state = multi_hunter_trainer.save()
    multi_hunter_trainer.stop()
def train(config, checkpoint_dir=None):
    trainer = PPOTrainer(config=config)
    if checkpoint_dir:
        trainer.load_checkpoint(checkpoint_dir)
    chk_freq = 10

    if useModelFromLowLevelTrain:
        config_low["num_workers"] = 0
        config_low["num_envs_per_worker"] = 1
        config_low["num_gpus"] = 1
        agentLow = PPOTrainer(config_low)
        agentLow.restore(
            "/home/aditya/ray_results/{}/{}/checkpoint_{}/checkpoint-{}".format(
                experiment_name, experiment_id, checkpoint_num, checkpoint_num))
        lowWeight = agentLow.get_policy().get_weights()
        highWeight = trainer.get_policy("low_level_policy").get_weights()
        lowState = agentLow.get_policy().get_state()
        importedOptState = OrderedDict([
            (k.replace("default_policy", "low_level_policy"), v)
            for k, v in lowState["_optimizer_variables"].items()
        ])
        importedPolicy = {
            hw: lowWeight[lw]
            for hw, lw in zip(highWeight.keys(), lowWeight.keys())
        }
        importedPolicy["_optimizer_variables"] = importedOptState
        trainer.get_policy("low_level_policy").set_state(importedPolicy)
        chk_freq = 1  # Only needed once at the start to save the imported model

    while True:
        result = trainer.train()
        tune.report(**result)
        if trainer._iteration % chk_freq == 0:
            with tune.checkpoint_dir(step=trainer._iteration) as checkpoint_dir:
                trainer.save(checkpoint_dir)
def train_ppo(config, reporter):
    agent = PPOTrainer(config)
    # agent.restore("/path/checkpoint_41/checkpoint-41")  # continue training
    i = 0
    while True:
        result = agent.train()
        if reporter is None:
            continue
        else:
            reporter(**result)
        if i % 10 == 0:  # save every 10th training iteration
            checkpoint_path = agent.save()
            print(checkpoint_path)
        i += 1
def train_model(args):
    # We are using a custom model and environment, which need to be registered
    # with ray/rllib. The names can be anything.
    register_env("DuckieTown-MultiMap", lambda _: DiscreteWrapper(MultiMapEnv()))

    # Define the trainer. env, config/framework and config/model are common to
    # all trainers; the remaining keys are algorithm-specific settings.
    trainer = PPOTrainer(
        env="DuckieTown-MultiMap",
        config={
            "framework": "torch",
            "model": {
                "custom_model": "image-ppo",
            },
            "sgd_minibatch_size": 64,
            "output": None,
            "compress_observations": True,
            "num_workers": 0,
        }
    )

    # Start training from a checkpoint, if available.
    if args.model_path:
        trainer.restore(args.model_path)

    plot = plotter.Plotter('ppo_agent')
    for i in range(args.epochs):  # Number of training iterations ("epochs")
        print(f'----------------------- Starting epoch {i} ----------------------- ')
        # train() runs a single training iteration.
        result = trainer.train()
        print(result)
        plot.add_results(result)

        # Save the model so far.
        checkpoint_path = trainer.save()
        print(f'Epoch {i}, checkpoint saved at: {checkpoint_path}')

        # Clean up CUDA memory to reduce memory usage.
        torch.cuda.empty_cache()
        # Debug log to monitor memory.
        print(torch.cuda.memory_summary(device=None, abbreviated=False))

    plot.plot('PPO DuckieTown-MultiMap')
def my_train_fn(config, reporter):
    # Train for n iterations with high LR
    agent1 = PPOTrainer(env="CartPole-v0", config=config)
    for _ in range(10):
        result = agent1.train()
        result["phase"] = 1
        reporter(**result)
        phase1_time = result["timesteps_total"]
    state = agent1.save()
    agent1.stop()

    # Train for n iterations with low LR
    config["lr"] = 0.0001
    agent2 = PPOTrainer(env="CartPole-v0", config=config)
    agent2.restore(state)
    for _ in range(10):
        result = agent2.train()
        result["phase"] = 2
        result["timesteps_total"] += phase1_time  # keep time moving forward
        reporter(**result)
    agent2.stop()
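# Hedged sketch (not part of the original snippet): a function trainable like
# my_train_fn is typically launched through Tune. The config values below are
# illustrative only; the PPOTrainer created inside the function may need more
# resources than the trial default when num_workers > 0.
from ray import tune

if __name__ == "__main__":
    config = {"lr": 0.01, "num_workers": 0}
    tune.run(my_train_fn, config=config)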
def single_test(defaultconfig, training_trials, evaluation_trials, check,
                lr=0.00005, num_workers=4, num_gpus=0.25):
    ray.shutdown()
    ray.init(**ray_init_kwargs)

    config = ppo.DEFAULT_CONFIG.copy()
    if num_gpus > 0:
        config["num_gpus"] = num_gpus
    config["num_workers"] = num_workers
    config["lr"] = lr
    config["train_batch_size"] = 8000
    config["num_sgd_iter"] = 5
    config["env_config"] = defaultconfig

    trainer = Trainer(config=config, env=qsdl.QSDEnv)
    for i in range(training_trials):
        result = trainer.train()
        print("train iteration", i + 1, "/", training_trials,
              " avg_reward =", result["episode_reward_mean"],
              " timesteps =", result["timesteps_total"])
        if i % check == check - 1:
            checkpoint = trainer.save()
            print("checkpoint saved at", checkpoint)

    avgR = 0
    for i in range(evaluation_trials):
        env = qsdl.QSDEnv(defaultconfig)
        obs = env.reset()
        done = False
        while not done:
            action = trainer.compute_action(obs)
            obs, r, done, _ = env.step(action)
            avgR += r
    return avgR / evaluation_trials
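# Example invocation (hedged: `my_env_config` is a hypothetical QSDEnv config dict
# and the trial counts are illustrative, not taken from the original).
my_env_config = {}
avg_reward = single_test(my_env_config, training_trials=20,
                         evaluation_trials=5, check=5, num_gpus=0)
print("average evaluation reward:", avg_reward)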
# best weights of this evaluation
best_weight = trainer_obj.callbacks.weights_history[0]
# set the best weights as player 1 policy weights
trainer_obj.get_policy("player1").set_weights(best_weight)
model_to_evaluate = trainer_obj.get_policy("player1").model

elo_diff, model_score, minimax_score, draw = model_vs_minimax_connect3(
    model_to_evaluate, minimax_depth, games_vs_minimax)
p1_win_rate = model_score / games_vs_minimax
additional_metrics["additional_metrics"]["minimax_depth"] = minimax_depth
print("Winrate against minimax: " + str(p1_win_rate))

if 0.45 <= p1_win_rate <= 0.55:
    trainer_obj.save(important_ckpt_path)
if p1_win_rate >= 0.6:
    minimax_depth += 1
    if minimax_depth == max_depth:
        break

additional_metrics["additional_metrics"]["player1_win_rate"] = p1_win_rate
additional_metrics["additional_metrics"]["minimax_score"] = minimax_score
additional_metrics["additional_metrics"]["player1_score"] = model_score
additional_metrics["additional_metrics"]["elo_difference"] = elo_diff
"sample_batch_size": 20, "sgd_minibatch_size": 500, "num_sgd_iter": 10, "num_workers": 1, # 32 "num_envs_per_worker": 1, #5 "num_gpus": 1, "model": { "dim": 64 } }) def env_creator(env_config): return PodWorldEnv(max_steps=10000, reward_factor=10000.0) register_env("podworld_env", env_creator) agent = PPOTrainer(config=config, env="podworld_env") agent_save_path = None for i in range(50): stats = agent.train() # print(pretty_print(stats)) if i % 5 == 0 and i > 0: path = agent.save() if agent_save_path is None: agent_save_path = path print('Saved agent at', agent_save_path) logger.write((i, stats['episode_reward_min'])) print('episode_reward_mean', stats['episode_reward_min'])
"multiagent": { "policies": policies, "policy_mapping_fn": select_policy, }, "clip_actions": True, "framework": "torch", #"num_sgd_iter": 4, "lr": 1e-4, #"kl_target": 0.03, #"train_batch_size": 1024, "rollout_fragment_length": 100, #"sgd_minibatch_size": 32 } trainer = PPOTrainer(env="wanderer_roborobo", config=config) print(trainer.config.get('no_final_linear')) print('model built') stop_iter = 2000 #%% import numpy as np for i in range(stop_iter): print("== Iteration", i, "==") result_ppo = trainer.train() pretty_print(result_ppo) if (i+1) % 200 == 0: trainer.save('model') trainer.save('model') del trainer ray.shutdown()
        result = train_agent.train()

        # Print the training status
        for field in results_fields_filter:
            if not isinstance(field, list):
                if field in result.keys():
                    print(f"{field}: {result[field]}")
            else:
                for subfield in field[1]:
                    if subfield in result[field[0]].keys():
                        print(f"{subfield} : {result[field[0]][subfield]}")
        print("============================")
except KeyboardInterrupt:
    print("Interrupting training...")
finally:
    checkpoint_path = train_agent.save()

# ================= Enjoy a trained agent =================

t_end = 10.0  # Total duration of the simulation(s) in seconds

try:
    env = env_creator(rllib_cfg["env_config"])
    test_agent = Trainer(agent_cfg, env="my_custom_env")
    test_agent.restore(checkpoint_path)
    t_init = time.time()
    t_prev = t_init
    while t_prev - t_init < t_end:
        observ = env.reset()
        done = False
        cumulative_reward = 0
print(f".. best checkpoint was: {best_checkpoint}") # Create a new dummy Trainer to "fix" our checkpoint. new_trainer = PPOTrainer(config=config) # Get untrained weights for all policies. untrained_weights = new_trainer.get_weights() # Restore all policies from checkpoint. new_trainer.restore(best_checkpoint) # Set back all weights (except for 1st agent) to original # untrained weights. new_trainer.set_weights( {pid: w for pid, w in untrained_weights.items() if pid != "policy_0"}) # Create the checkpoint from which tune can pick up the # experiment. new_checkpoint = new_trainer.save() new_trainer.stop() print(".. checkpoint to restore from (all policies reset, " f"except policy_0): {new_checkpoint}") print("Starting new tune.run") # Start our actual experiment. stop = { "episode_reward_mean": args.stop_reward, "timesteps_total": args.stop_timesteps, "training_iteration": args.stop_iters, } # Make sure, the non-1st policies are not updated anymore. config["multiagent"]["policies_to_train"] = [
#         if agent_id.startswith("fb_1_")
#         else random.choice(["fb_1", "fb_2"])
#         },
#     },
# )
trainer = PPOTrainer(
    env="rcrs_env",
    config={
        "env": "rcrs_env",
        "num_workers": 1,
        "multiagent": {
            "policies": {
                "fb_1": (None, obs_space, act_space, {}),
                "fb_2": (None, obs_space, act_space, {}),
            },
            "policy_mapping_fn": lambda agent_id: "fb_1"
            if agent_id.startswith("fb_1_")
            else random.choice(["fb_1", "fb_2"])
        },
    },
)

for i in range(2):
    result = trainer.train()
    print(pretty_print(result))
    if i % 1 == 0:
        checkpoint = trainer.save()
        print("checkpoint saved at", checkpoint)

statess = trainer.save()
trainer.stop()
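# Hedged sketch (not in the original): the checkpoint saved above can be restored
# into a new trainer and each fire-brigade policy queried via its policy_id.
# `config` is assumed to be the same multiagent config dict used above, and `obs`
# a valid observation from the rcrs environment.
def restore_and_query(checkpoint_path, config, obs):
    restored = PPOTrainer(env="rcrs_env", config=config)
    restored.restore(checkpoint_path)
    return (restored.compute_action(obs, policy_id="fb_1"),
            restored.compute_action(obs, policy_id="fb_2"))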
# trainer.restore('./checkpoints_flush/checkpoint_379/checkpoint-379')
step = 0
best_val = 0.0
if False:
    save_0 = trainer.save_to_object()
while True:
    if False:
        save_0 = trainer.save_to_object()
        result = trainer.train()
        while result['episode_reward_mean'] > best_val:
            print('UPENING')
            best_save = deepcopy(save_0)
            best_val = result['episode_reward_mean']
            save_0 = trainer.save_to_object()
            trainer.save('./checkpoints_flush')
            result = trainer.train()
        print('REVERTING')
        trainer.restore_from_object(best_save)
    else:
        result = trainer.train()
        if result['episode_reward_mean'] > best_val:
            improved = step
            best_val = result['episode_reward_mean']
            trainer.save('./checkpoints_iter_' + str(iteration))
        elif step > improved + last_improve:
            trainer.save('./checkpoints_iter_' + str(iteration))
            break
    step += 1
    print(step, best_val, result['episode_reward_mean'])
    sys.stdout.flush()
"policies": policies, "policy_mapping_fn": lambda agent_id: "ppo_policy", }, # "num_gpus": 0, # "num_gpus_per_worker": 0, "callbacks": PlayerScoreCallbacks }) if restore_checkpoint: trainer.restore(checkpoint_path) start = time.time() try: for i in range(num_iter): res = trainer.train() print("Iteration {}. policy result: {}".format(i, res)) if i % eval_every == 0: trainer_eval.set_weights(trainer.get_weights(["ppo_policy"])) res = trainer_eval.train() if i % checkpoint_every == 0: trainer.save() except: trainer.save() stop = time.time() train_duration = time.strftime('%H:%M:%S', time.gmtime(stop - start)) print( 'Training finished ({}), check the results in ~/ray_results/<dir>/'.format( train_duration))
result = agent1.train()
print(pretty_print(result))

config2 = DEFAULT_CONFIG.copy()
config2['num_workers'] = 4
config2['num_sgd_iter'] = 30
config2['sgd_minibatch_size'] = 128
config2['model']['fcnet_hiddens'] = [100, 100]
config2['num_cpus_per_worker'] = 0

agent2 = PPOTrainer(config2, 'CartPole-v0')
for i in range(2):
    result = agent2.train()
    print(pretty_print(result))

checkpoint_path = agent2.save()
print(checkpoint_path)

trained_config = config2.copy()
test_agent = PPOTrainer(trained_config, 'CartPole-v0')
test_agent.restore(checkpoint_path)

env = gym.make('CartPole-v0')
state = env.reset()
done = False
cumulative_reward = 0
while not done:
    action = test_agent.compute_action(state)
    state, reward, done, _ = env.step(action)
    cumulative_reward += reward
def main():
    ray.init()
    logging.getLogger().setLevel(logging.INFO)
    date = datetime.now().strftime('%Y%m%d_%H%M%S')

    parser = argparse.ArgumentParser()
    # parser.add_argument('--scenario', type=str, default='PongNoFrameskip-v4')
    parser.add_argument('--config', type=str, default='config/global_config.json',
                        help='config file')
    parser.add_argument('--algo', type=str, default='PPO',
                        choices=['PPO', 'DQN', 'DDQN', 'DuelDQN'],
                        help='choose an algorithm')
    parser.add_argument('--inference', action="store_true",
                        help='inference or training')
    parser.add_argument('--ckpt', type=str,
                        help='checkpoint to restore for inference')
    parser.add_argument('--epoch', type=int, default=10,
                        help='number of training epochs')
    parser.add_argument('--num_step', type=int, default=10**3,
                        help='number of timesteps for one episode, and for inference')
    parser.add_argument('--save_freq', type=int, default=100,
                        help='model saving frequency')
    parser.add_argument('--batch_size', type=int, default=128,
                        help='batch size')
    parser.add_argument('--state_time_span', type=int, default=5,
                        help='state interval to receive long term state')
    parser.add_argument('--time_span', type=int, default=30,
                        help='time interval to collect data')
    args = parser.parse_args()

    config_env = env_config(args)
    # ray.tune.register_env('gym_cityflow', lambda env_config: CityflowGymEnv(config_env))
    config_agent = agent_config(config_env)

    # build cityflow environment
    trainer = PPOTrainer(env=CityflowGymEnv, config=config_agent)
    for i in range(1000):
        # Perform one iteration of training the policy with PPO
        result = trainer.train()
        print(pretty_print(result))

        if i % 30 == 0:
            checkpoint = trainer.save()
            print("checkpoint saved at", checkpoint)
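# Hedged sketch (not part of the original): an inference helper matching the
# --inference/--ckpt flags parsed above. CityflowGymEnv is assumed to accept the
# env config dict produced by env_config(args).
def run_inference(trainer, ckpt, config_env, num_step):
    trainer.restore(ckpt)
    env = CityflowGymEnv(config_env)
    obs, done = env.reset(), False
    for _ in range(num_step):
        action = trainer.compute_action(obs)
        obs, reward, done, _ = env.step(action)
        if done:
            break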
class KandboxAgentRLLibPPO(KandboxAgentPlugin):
    title = "Kandbox Plugin - Agent - realtime - by rllib ppo"
    slug = "ri_agent_rl_ppo"
    author = "Kandbox"
    author_url = "https://github.com/qiyangduan"
    description = "RLLibPPO for GYM for RL."
    version = "0.1.0"

    default_config = {
        "nbr_of_actions": 4,
        "n_epochs": 1000,
        "nbr_of_days_planning_window": 6,
        "model_path": "default_model_path",
        "working_dir": "/tmp",
        "checkpoint_path_key": "ppo_checkpoint_path",
    }
    config_form_spec = {
        "type": "object",
        "properties": {},
    }

    def __init__(self, agent_config, kandbox_config):
        self.agent_config = agent_config
        self.current_best_episode_reward_mean = -99
        env_config = agent_config["env_config"]

        if "rules_slug_config_list" not in env_config.keys():
            if "rules" not in env_config.keys():
                log.error("no rules_slug_config_list and no rules ")
            else:
                env_config["rules_slug_config_list"] = [
                    [rule.slug, rule.config] for rule in env_config["rules"]
                ]
                env_config.pop("rules", None)

        # self.env_class = env_class = agent_config["env"]
        self.kandbox_config = self.default_config.copy()
        self.kandbox_config.update(kandbox_config)
        # self.trained_model = trained_model
        self.kandbox_config["create_datetime"] = datetime.now()

        # self.trainer = None
        self.env_config = env_config
        # self.load_model(env_config=self.env_config)

        print(
            f"KandboxAgentRLLibPPO __init__ called, at time {self.kandbox_config['create_datetime']}"
        )
        # import pdb
        # pdb.set_trace()

        if not ray.is_initialized():
            ray.init(ignore_reinit_error=True, log_to_driver=False)
            # ray.init(redis_address="localhost:6379")

    def build_model(self):
        trainer_config = DEFAULT_CONFIG.copy()
        trainer_config["num_workers"] = 0
        # trainer_config["train_batch_size"] = 640
        # trainer_config["sgd_minibatch_size"] = 160
        # trainer_config["num_sgd_iter"] = 100

        trainer_config["exploration_config"] = {
            "type": "Random",
        }  # EpsilonGreedy(Exploration):
        # trainer_config["exploration_config"] = {
        #     "type": "Curiosity",
        #     "eta": 0.2,
        #     "lr": 0.001,
        #     "feature_dim": 128,
        #     "feature_net_config": {
        #         "fcnet_hiddens": [],
        #         "fcnet_activation": "relu",
        #     },
        #     "sub_exploration": {
        #         "type": "StochasticSampling",
        #     }
        # }
        # trainer_config["log_level"] = "DEBUG"
        """
        if env_config is not None:
            for x in env_config.keys():
                trainer_config[x] = env_config[x]
        """
        # trainer_config["env_config"] = copy.deepcopy(env_config)  # {"rules": "qiyang_role"}
        trainer_config.update(self.agent_config)

        self.trainer = PPOTrainer(trainer_config, self.agent_config["env"])
        # self.config["trainer"] = self.trainer
        return self.trainer

    def load_model(self):  # , allow_empty = None
        env_config = self.agent_config["env_config"]
        self.trainer = self.build_model()
        # if (model_path is not None) & (os.path.exists(model_path)):
        if "ppo_checkpoint_path" in env_config.keys():
            # raise FileNotFoundError("can not find model at path: {}".format(model_path))
            if os.path.exists(env_config["ppo_checkpoint_path"]):
                self.trainer.restore(env_config["ppo_checkpoint_path"])
                print("Reloaded model from path: {} ".format(
                    env_config["ppo_checkpoint_path"]))
            else:
                print(
                    ("Env_config has ppo_checkpoint_path = {}, but no files found. "
                     "I am returning an initial model").format(
                         env_config["ppo_checkpoint_path"]))
        else:
            print(
                "Env_config has no ppo_checkpoint_path, returning an initial model"
            )

        # self.config["model_path"] = model_path
        # self.config["trainer"] = self.trainer
        # self.config["policy"] = self.trainer.workers.local_worker().get_policy()
        self.policy = self.trainer.workers.local_worker().get_policy()
        return self.trainer

    def train_model(self):
        # self.trainer = self.build_model()
        for i in range(self.kandbox_config["n_epochs"]):
            result = self.trainer.train()
            # print(pretty_print(result))
            print(
                "Finished training iteration {}, Result: episodes_this_iter:{}, "
                "policy_reward_max: {}, episode_reward_max {}, episode_reward_mean {}, "
                "info.num_steps_trained: {}...".format(
                    i,
                    result["episodes_this_iter"],
                    result["policy_reward_max"],
                    result["episode_reward_max"],
                    result["episode_reward_mean"],
                    result["info"]["num_steps_trained"],
                ))

            if result["episode_reward_mean"] > self.current_best_episode_reward_mean * 1.1:
                model_path = self.save_model()
                print(
                    "Model is saved after 10 percent increase, episode_reward_mean = {}, file = {}"
                    .format(result["episode_reward_mean"], model_path))
                self.current_best_episode_reward_mean = result["episode_reward_mean"]

        return self.save_model()

    def save_model(self):
        checkpoint_dir = "{}/model_checkpoint_org_{}_team_{}".format(
            self.agent_config["env_config"]["working_dir"],
            self.agent_config["env_config"]["org_code"],
            self.agent_config["env_config"]["team_id"],
        )
        _path = self.trainer.save(checkpoint_dir=checkpoint_dir)
        # exported_model_dir = "{}/exported_ppo_model_org_{}_team_{}".format(
        #     self.agent_config["env_config"]["working_dir"],
        #     self.agent_config["env_config"]["org_code"],
        #     self.agent_config["env_config"]["team_id"]
        # )
        # self.trainer.get_policy().export_model(exported_model_dir + "/1")
        return _path  # self.trainer

    def predict_action(self, observation=None):
        action = self.trainer.compute_action(observation)
        return action

    def predict_action_list(self, env=None, job_code=None, observation=None):
        actions = []
        if env is not None:
            self.env = env
        else:
            env = self.env
        if job_code is None:
            job_i = env.current_job_i
        else:
            job_i = env.jobs_dict[job_code].job_index

        observation = env._get_observation()
        # export_dir = "/Users/qiyangduan/temp/kandbox/exported_ppo_model_org_duan3_team_3/1"
        # loaded_policy = tf.saved_model.load(export_dir)
        # loaded_policy.signatures["serving_default"](observations=observation)
        predicted_action = self.trainer.compute_action(observation)
        # predicted_action = self.policy.compute_action(observation)

        for _ in range(len(env.workers)):  # hist_job_workers_ranked:
            if len(actions) >= self.config["nbr_of_actions"]:
                return actions
            actions.append(list(predicted_action).copy())
            max_i = np.argmax(predicted_action[0:len(env.workers)])
            predicted_action[max_i] = 0
        return actions

    def predict_action_dict_list(self, env=None, job_code=None, observation=None):
        if env is not None:
            self.env = env
        else:
            env = self.env
        curr_job = copy.deepcopy(env.jobs_dict[job_code])
        if job_code is None:
            job_i = env.current_job_i
        else:
            job_i = curr_job.job_index

        env.current_job_i = job_i
        observation = env._get_observation()
        action = self.predict_action(observation=observation)
        action_dict = env.decode_action_into_dict_native(action=action)

        action_day = int(action_dict.scheduled_start_minutes / 1440)
        curr_job.requested_start_min_minutes = action_day * 1440
        curr_job.requested_start_max_minutes = (action_day + 1) * 1440

        action_dict_list = self.env.recommendation_server.search_action_dict_on_worker_day(
            a_worker_code_list=action_dict.scheduled_worker_codes,
            curr_job=curr_job,
            max_number_of_matching=3,
        )
        return action_dict_list
ten_gig = 10737418240
trainer = PPOTrainer(
    env="ic20env",
    config=merge_dicts(
        DEFAULT_CONFIG,
        {
            # -- Rollout workers
            'num_gpus': 1,
            'num_workers': 10,
            "num_envs_per_worker": 1,
            "num_cpus_per_worker": 1,
            "memory_per_worker": ten_gig,
            'gamma': 0.99,
            'lambda': 0.95
        }))

# Attempt to restore from checkpoint if possible.
if os.path.exists(CHECKPOINT_FILE):
    checkpoint_path = open(CHECKPOINT_FILE).read()
    print("Restoring from checkpoint path", checkpoint_path)
    trainer.restore(checkpoint_path)

# Serving and training loop
while True:
    print(pretty_print(trainer.train()))
    checkpoint_path = trainer.save()
    print("Last checkpoint", checkpoint_path)
    with open(CHECKPOINT_FILE, "w") as f:
        f.write(checkpoint_path)
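# For context (assumed, not shown in the snippet above): CHECKPOINT_FILE is a plain
# text file holding the last checkpoint path, and the "ic20env" environment must be
# registered before the PPOTrainer is constructed. IC20Env and the file name are
# hypothetical placeholders.
from ray.tune.registry import register_env

CHECKPOINT_FILE = "last_checkpoint.out"
register_env("ic20env", lambda env_config: IC20Env(env_config))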
"interaction_hidden_size": 4, }, }, "clip_actions": True, "framework": "torch", "num_sgd_iter": 3, "lr": 1e-4, #"kl_target": 0.03, "no_done_at_end": False, "soft_horizon": True, "train_batch_size": 100, "rollout_fragment_length": 100, "sgd_minibatch_size": 32 } trainer = PPOTrainer(env="negotiate_roborobo", config=config) print(trainer.config.get('no_final_linear')) print('model built') stop_iter = 2000 #%% import numpy as np for i in range(stop_iter): print("== Iteration", i, "==") result_ppo = trainer.train() if (i + 1) % 1 == 0: trainer.save('model_nego') trainer.save('model_nego') del trainerii ray.shutdown()
len_moving_average = np.convolve(episode_len_mean,
                                 np.ones((20, )) / 20, mode='valid')
reward_moving_average = np.convolve(episode_reward_mean,
                                    np.ones((20, )) / 20, mode='valid')

print('Current ::: Len:: Mean: ' + str(episode_len_mean[-1]) +
      '; Reward:: Mean: ' + str(episode_reward_mean[-1]) +
      ', Max: ' + str(episode_reward_max[-1]) +
      ', Min: ' + str(episode_reward_min[-1]))
print('mAverage20 ::: Len:: Mean: ' + str(np.round(len_moving_average[-1], 1)) +
      '; Reward:: Mean: ' + str(np.round(reward_moving_average[-1], 1)))

if result['training_iteration'] % 10 == 0:
    checkpoint = trainer.save()
    print("checkpoint saved at", checkpoint)

output = {
    'episode_len_mean': episode_len_mean,
    'episode_reward_mean': episode_reward_mean,
    'episode_reward_max': episode_reward_max,
    'episode_reward_min': episode_reward_min,
    'num_steps_trained': num_steps_trained,
    'clock_time': clock_time,
    'training_iteration': training_iteration,
    'len_moving_average': len_moving_average,
    'reward_moving_average': reward_moving_average
}
output_path = trainer._logdir + '/_running_results.pkl'
with open(output_path, 'wb') as handle:
    pickle.dump(output, handle, protocol=pickle.HIGHEST_PROTOCOL)
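# Hedged sketch (not part of the original): the pickled running results written
# above can be read back later, e.g. for offline plotting.
with open(output_path, 'rb') as handle:
    running_results = pickle.load(handle)
print(running_results['reward_moving_average'][-1])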
                    custom_metrics_file)
else:
    best_ckpt = 0
    print("Starting training from scratch")

for epoch in tqdm(range(best_ckpt + 1, epochs)):
    print("Epoch " + str(epoch))
    results = trainer_obj.train()
    p1_score = results["custom_metrics"]["player1_score"]
    minimax_score = results["custom_metrics"]["player2_score"]
    score_difference = results["custom_metrics"]["score_difference"]
    actual_depth = trainer_obj.get_policy("minimax").depth

    if epoch % ckpt_step == 0 and epoch != 0:
        custom_metrics = results["custom_metrics"]
        save_checkpoint(trainer_obj, ckpt_dir, custom_metrics_file,
                        custom_metrics, ckpt_to_keep)

    if p1_score >= minimax_score:
        print("Player 1 was able to beat MiniMax algorithm with depth " +
              str(actual_depth))
        new_depth = actual_depth + 1
        print("Increasing Minimax depth to " + str(new_depth))
        trainer_obj.get_policy("minimax").depth = new_depth
        trainer_obj.save(Config.IMPORTANT_CKPT_PATH)
        if new_depth > max_depth:
            print("Max depth reached, training is over\n")
            break
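# Hedged sketch of what the save_checkpoint() helper called above might look like;
# the real implementation is not shown in this snippet, so the body below is an
# assumption (pruning of old checkpoints via ckpt_to_keep is omitted).
def save_checkpoint(trainer_obj, ckpt_dir, custom_metrics_file, custom_metrics,
                    ckpt_to_keep):
    path = trainer_obj.save(ckpt_dir)
    with open(custom_metrics_file, "a") as f:
        f.write(str(custom_metrics) + "\n")
    return path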