Example #1
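A training loop that builds a DQNTrainer on a custom BomberMan-v0 environment and exports the policy_01 model every 250 iterations.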
import os

from ray.rllib.agents.dqn import DQNTrainer


def train(config, checkpoint_dir=None):
    trainer = DQNTrainer(config=config, env='BomberMan-v0')
    # trainer.restore('C:\\Users\\Florian\\ray_results\\PPO_BomberMan-v0_2021-03-16_09-20-44984tj3ip\\checkpoint_002770\\checkpoint-2770')
    iteration = 0

    # def update_phase(ev):
    #    ev.foreach_env(lambda e: e.set_phase(phase))

    while True:
        iteration += 1
        result = trainer.train()
        # Export the current policy model every 250 training iterations.
        if iteration % 250 == 0:
            if not os.path.exists(f'./model-{iteration}'):
                trainer.get_policy('policy_01').export_model(
                    f'./model-{iteration}')
            else:
                print("model already saved")
Example #2
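A unit test, in the style of RLlib's own test suite, that checks a DQN policy's state advances after one training call and can be restored exactly via get_state / set_state.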
from ray.rllib.agents.dqn import DQNTrainer, DEFAULT_CONFIG
from ray.rllib.utils.test_utils import check, framework_iterator


# Method of a unittest.TestCase subclass.
def test_policy_save_restore(self):
    config = DEFAULT_CONFIG.copy()
    for _ in framework_iterator(config):
        trainer = DQNTrainer(config=config, env="CartPole-v0")
        policy = trainer.get_policy()
        state1 = policy.get_state()
        trainer.train()
        state2 = policy.get_state()
        # After one training call, the timestep counters must have changed.
        check(state1["_exploration_state"]["last_timestep"],
              state2["_exploration_state"]["last_timestep"],
              false=True)
        check(state1["global_timestep"],
              state2["global_timestep"],
              false=True)
        # Reset policy to its original state and compare.
        policy.set_state(state1)
        state3 = policy.get_state()
        # Make sure everything is the same.
        check(state1, state3)
Example #3
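An evaluation script that reloads the training config from a pickled params file, rebuilds a DQNTrainer from a checkpoint, runs evaluation rollouts on a modified map, and flattens each sample batch into rows of observations, actions, Q-values, and rewards.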
import cloudpickle
import numpy as np

from ray.rllib.agents.dqn import DQNTrainer

#                      num_envs=NUM_ENVS)
# config = {"env": "warehouse_env",
#           "framework": "torch",
#           "num_gpus": 0.1,
#           "num_gpus_per_worker": 0.1,
#           'num_envs_per_worker': 6,
#           "evaluation_interval": 5, }

# params_path, MAP_WITH_EXCEPTION, NUM_EPISODES, path, checkpoint and
# process_info are defined elsewhere in the original script.
# Rebuild the training config from the trial's pickled params file.
with open(params_path, "rb") as f:
    config = cloudpickle.load(f)
config["explore"] = False
config['num_envs_per_worker'] = 1
print("Trained on map: \n", config["env_config"]["maps"])
config["env_config"]["maps"] = MAP_WITH_EXCEPTION
trainer = DQNTrainer(config=config)
trainer.restore(path.format(checkpoint, checkpoint))
policy = trainer.get_policy()
trainer._evaluate()  # private API; runs the configured evaluation
# Collect one sample batch per evaluation episode.
samples = (trainer.evaluation_workers.local_worker().sample()
           for _ in range(NUM_EPISODES))
# Flatten each batch into one row per timestep:
# [unroll_id, step, obs, action, q_values, reward, done, new_obs, info].
rows = map(lambda x: np.concatenate([
    x["unroll_id"][:, None],
    np.arange(0, x.count)[:, None],
    x["obs"],
    x["actions"][:, None],
    x["q_values"],
    x["rewards"][:, None],
    x["dones"][:, None],
    x["new_obs"],
    process_info(x["infos"])],
    -1),
           samples)
Example #4
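A multi-agent evaluation loop: a DQN checkpoint is restored and its per-agent policies are stepped through a PettingZoo Leduc Hold'em environment, computing masked actions with compute_actions_from_input_dict.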
import numpy as np
import ray
from ray.rllib.agents.dqn import DQNTrainer

ray.init(num_cpus=8, num_gpus=0)
# config, checkpoint_path and env are defined earlier in the original script.
DQNAgent = DQNTrainer(env="leduc_holdem", config=config)
DQNAgent.restore(checkpoint_path)

reward_sums = {a: 0 for a in env.possible_agents}
i = 0
env.reset()

for agent in env.agent_iter():
    observation, reward, done, info = env.last()
    obs = observation['observation']
    reward_sums[agent] += reward
    if done:
        action = None
    else:
        print(DQNAgent.get_policy(agent))
        policy = DQNAgent.get_policy(agent)
        # Batch the observation (and its action mask) so the policy can
        # consume it as an input dict.
        batch_obs = {
            'obs': {
                'observation': np.expand_dims(observation['observation'], 0),
                'action_mask': np.expand_dims(observation['action_mask'], 0)
            }
        }
        batched_action, state_out, info = policy.compute_actions_from_input_dict(
            batch_obs)
        # Unbatch the single action for this agent.
        action = batched_action[0]

    env.step(action)
    i += 1
    env.render()
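The env object above is not constructed in the snippet; it behaves like a PettingZoo AEC environment (agent_iter / last / step). A minimal sketch of how it might be created, assuming the PettingZoo Leduc Hold'em environment (the exact version suffix may differ, and older versions return the 4-tuple from last() used above):

from pettingzoo.classic import leduc_holdem_v4

# Hypothetical construction of the AEC environment used in the loop above.
env = leduc_holdem_v4.env()
env.reset()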