def save_best_response_checkpoint(trainer: DQNTrainer, player: int, save_dir: str, timesteps_training_br: int,
                                  episodes_training_br: int, active_policy_num: int = None):
    """Persist the trainer's "best_response" policy weights as a new .h5 checkpoint.

    The file is created inside ``save_dir`` and named
    ``policy_<label>_<datetime>.h5``. ``<label>`` is ``active_policy_num``,
    or the string ``"unclaimed"`` when that argument is ``None``. The file
    stores the weights together with the player, the policy number, the
    timestamps, and the best-response training counters.

    Args:
        trainer: Trainer exposing a policy registered as "best_response".
        player: Player index recorded in the checkpoint.
        save_dir: Directory in which the checkpoint file is created.
        timesteps_training_br: Timesteps spent training the best response.
        episodes_training_br: Episodes spent training the best response.
        active_policy_num: Policy number used in the filename and stored in
            the file. ``None`` produces an "unclaimed" filename.

    Returns:
        The path of the checkpoint file that was written.

    Raises:
        HDF5ExtError: If every one of the 5 save attempts fails.
    """
    label = "unclaimed" if active_policy_num is None else active_policy_num
    date_time = datetime_str()
    checkpoint_path = os.path.join(save_dir, f"policy_{label}_{date_time}.h5")

    raw_weights = trainer.get_weights(["best_response"])["best_response"]
    # Periods in key names cause HDF5 NaturalNaming warnings, so encode them.
    br_weights = {}
    for key, value in raw_weights.items():
        br_weights[key.replace(".", "_dot_")] = value

    ensure_dir(file_path=checkpoint_path)

    max_attempts = 5
    attempt = 0
    while True:
        attempt += 1
        try:
            # The payload is rebuilt on each attempt, so seconds_since_epoch
            # reflects the attempt that actually writes the file.
            deepdish.io.save(path=checkpoint_path, data={
                "weights": br_weights,
                "player": player,
                "policy_num": active_policy_num,
                "date_time_str": date_time,
                "seconds_since_epoch": time.time(),
                "timesteps_training_br": timesteps_training_br,
                "episodes_training_br": episodes_training_br,
            })
        except HDF5ExtError:
            # Retry transient HDF5 write failures, giving up after the last attempt.
            if attempt >= max_attempts:
                raise
            time.sleep(1.0)
        else:
            return checkpoint_path
# dqn_policy: X
# ppo_policy: Y
import sys

# Alternate one training iteration of each trainer, then exchange the freshly
# trained policy weights so each trainer sees the other's latest policy.
for i in range(args.stop_iters):
    print("== Iteration", i, "==")

    # improve the DQN policy
    print("-- DQN --")
    result_dqn = dqn_trainer.train()
    print(pretty_print(result_dqn))

    # improve the PPO policy
    print("-- PPO --")
    result_ppo = ppo_trainer.train()
    print(pretty_print(result_ppo))

    # Test passed gracefully: in test mode, stop as soon as both trainers'
    # mean episode reward exceeds the requested threshold.
    if args.as_test and \
            result_dqn["episode_reward_mean"] > args.stop_reward and \
            result_ppo["episode_reward_mean"] > args.stop_reward:
        print("test passed (both agents above requested reward)")
        # sys.exit instead of quit(): quit() is a site-module helper intended
        # for the interactive shell and is unavailable when run with -S.
        sys.exit(0)

    # swap weights to synchronize: the DQN trainer receives the PPO trainer's
    # "ppo_policy" weights and the PPO trainer receives "dqn_policy".
    dqn_trainer.set_weights(ppo_trainer.get_weights(["ppo_policy"]))
    ppo_trainer.set_weights(dqn_trainer.get_weights(["dqn_policy"]))

# Desired reward not reached within args.stop_iters iterations.
if args.as_test:
    raise ValueError("Desired reward ({}) not reached!".format(
        args.stop_reward))