        except KeyboardInterrupt:
            print("Interrupt got, saving the model...")
            torch.save(prep.state_dict(), save_path / "prep.dat")
            torch.save(cmd.state_dict(), save_path / "cmd.dat")

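    # Switch the trained preprocessor and command generator to evaluation mode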
    print("Using preprocessor and command generator")
    prep.train(False)
    cmd.train(False)

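    # Validation env: admissible commands are kept in the info dict but not used
    # as the action space; a non-admissible last command is penalized with -0.1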
    val_env = gym.make(val_env_id)
    val_env = preproc.TextWorldPreproc(val_env,
                                       use_admissible_commands=False,
                                       keep_admissible_commands=True,
                                       reward_wrong_last_command=-0.1)

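    # DQN scores encoded (observation, command) pairs; a separate encoder embeds
    # the generated commands, and TargetNet copies provide stable Q-targets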
    net = model.DQNModel(obs_size=prep.obs_enc_size,
                         cmd_size=prep.obs_enc_size).to(device)
    tgt_net = ptan.agent.TargetNet(net)
    cmd_encoder = preproc.Encoder(params.embeddings,
                                  prep.obs_enc_size).to(device)
    tgt_cmd_encoder = ptan.agent.TargetNet(cmd_encoder)
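    # Agent proposes commands with the generator and selects among them with the
    # DQN; epsilon=1 means fully random choices at the start of training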
    agent = model.CmdDQNAgent(env,
                              net,
                              cmd,
                              cmd_encoder,
                              prep,
                              epsilon=1,
                              device=device)
    exp_source = ptan.experience.ExperienceSourceFirstLast(env,
                                                           agent,
                                                           gamma=GAMMA,
                                                           steps_count=1)
Example #2
    print("Game %s, with file %s will be used for validation" %
          (val_env_id, val_game_file))

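    # Training and validation TextWorld games, both wrapped with the
    # observation preprocessor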
    env = gym.make(env_id)
    env = preproc.TextWorldPreproc(env)

    val_env = gym.make(val_env_id)
    val_env = preproc.TextWorldPreproc(val_env)

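    # Preprocessor embeds and encodes each observation field into a fixed-size
    # vector; a target copy is kept for computing stable Q-targets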
    prep = preproc.Preprocessor(dict_size=env.observation_space.vocab_size,
                                emb_size=params.embeddings,
                                num_sequences=env.num_fields,
                                enc_output_size=params.encoder_size).to(device)
    tgt_prep = ptan.agent.TargetNet(prep)

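    # The DQN's observation input is the concatenation of all encoded fields;
    # commands are encoded into a single vector of the same encoder size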
    net = model.DQNModel(obs_size=env.num_fields * params.encoder_size,
                         cmd_size=params.encoder_size)
    net = net.to(device)
    tgt_net = ptan.agent.TargetNet(net)

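    # Epsilon-greedy DQN agent (fully random at first), one-step transitions
    # and a replay buffer of params.replay_size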
    agent = model.DQNAgent(net, prep, epsilon=1, device=device)
    exp_source = ptan.experience.ExperienceSourceFirstLast(env,
                                                           agent,
                                                           gamma=GAMMA,
                                                           steps_count=1)
    buffer = ptan.experience.ExperienceReplayBuffer(exp_source,
                                                    params.replay_size)

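    # One optimizer updates the DQN and the preprocessor jointly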
    optimizer = optim.RMSprop(itertools.chain(net.parameters(),
                                              prep.parameters()),
                              lr=LEARNING_RATE,
                              eps=1e-5)
Example #3
    deer_handle, tiger_handle = m_env.get_handles()

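    # Episode reset: clear the shared world, then scatter random walls, deer
    # and tigers over the map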
    def reset_env():
        m_env.reset()
        m_env.add_walls(method="random", n=MAP_SIZE * MAP_SIZE * WALLS_DENSITY)
        m_env.add_agents(deer_handle, method="random", n=COUNT_DEERS)
        m_env.add_agents(tiger_handle, method="random", n=COUNT_TIGERS)

    env = data.MAgentGroupsEnv(m_env, [deer_handle, tiger_handle],
                               reset_env_func=reset_env)

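    # Each group observes a spatial view grid plus a feature vector; a separate
    # DQN with a target copy is built per group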
    deer_obs = data.MAgentEnv.handle_obs_space(m_env, deer_handle)
    tiger_obs = data.MAgentEnv.handle_obs_space(m_env, tiger_handle)

    net_deer = model.DQNModel(
        deer_obs.spaces[0].shape, deer_obs.spaces[1].shape,
        m_env.get_action_space(deer_handle)[0]).to(device)
    tgt_net_deer = ptan.agent.TargetNet(net_deer)
    print(net_deer)

    net_tiger = model.DQNModel(
        tiger_obs.spaces[0].shape, tiger_obs.spaces[1].shape,
        m_env.get_action_space(tiger_handle)[0]).to(device)
    tgt_net_tiger = ptan.agent.TargetNet(net_tiger)
    print(net_tiger)

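    # Epsilon-greedy exploration with a decaying schedule; the preprocessor
    # turns batched MAgent observations into tensors on the chosen device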
    action_selector = ptan.actions.EpsilonGreedyActionSelector(
        epsilon=PARAMS.epsilon_start)
    epsilon_tracker = common.EpsilonTracker(action_selector, PARAMS)
    preproc = model.MAgentPreprocessor(device)
Example #4
    args = parser.parse_args()

    env = magent.GridWorld('battle', map_size=args.map_size)
    env.set_render_dir(args.render)
    a_handle, b_handle = env.get_handles()

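    # Populate the map: random walls plus the two opposing groups of agents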
    env.reset()
    env.add_walls(method="random",
                  n=args.map_size * args.map_size * args.walls_density)
    env.add_agents(a_handle, method="random", n=args.count_a)
    env.add_agents(b_handle, method="random", n=args.count_b)

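    # MAgent reports the view space as (H, W, C); reorder it to (C, H, W) to
    # match the model's convolutional input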
    v = env.get_view_space(a_handle)
    v = (v[-1], ) + v[:2]
    net_a = model.DQNModel(v, env.get_feature_space(a_handle),
                           env.get_action_space(a_handle)[0])
    net_a.load_state_dict(torch.load(args.model_a))
    print(net_a)

    v = env.get_view_space(b_handle)
    v = (v[-1], ) + v[:2]
    net_b = model.DQNModel(v, env.get_feature_space(b_handle),
                           env.get_action_space(b_handle)[0])
    net_b.load_state_dict(torch.load(args.model_b))
    print(net_b)

    a_total_reward = b_total_reward = 0.0
    total_steps = 0

    while True:
        # A actions
Example #5
    args = parser.parse_args()

    env = magent.GridWorld(args.mode, map_size=args.map_size)
    env.set_render_dir(args.render)
    deer_handle, tiger_handle = env.get_handles()

    env.reset()
    env.add_walls(method="random",
                  n=args.map_size * args.map_size * args.walls_density)
    env.add_agents(deer_handle, method="random", n=args.deers)
    env.add_agents(tiger_handle, method="random", n=args.tigers)

    v = env.get_view_space(tiger_handle)
    v = (v[-1], ) + v[:2]
    net = model.DQNModel(v, env.get_feature_space(tiger_handle),
                         env.get_action_space(tiger_handle)[0])
    net.load_state_dict(torch.load(args.model))
    print(net)
    total_reward = 0.0

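    # Greedy play loop: batch the tiger observations, convert the view from
    # NHWC to NCHW and act on the argmax of the Q-values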
    while True:
        view_obs, feats_obs = env.get_observation(tiger_handle)
        view_obs = np.array(view_obs)
        feats_obs = np.array(feats_obs)
        view_obs = np.moveaxis(view_obs, 3, 1)
        view_t = torch.tensor(view_obs, dtype=torch.float32)
        feats_t = torch.tensor(feats_obs, dtype=torch.float32)
        qvals = net((view_t, feats_t))
        actions = torch.max(qvals, dim=1)[1].cpu().numpy()
        actions = actions.astype(np.int32)
        env.set_action(tiger_handle, actions)
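
    # Tiger training setup: episode reset helper, environment wrapper and DQN
    # construction (with an optional noisy-network variant)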
    def reset_env():
        m_env.reset()
        m_env.add_walls(method="random", n=MAP_SIZE * MAP_SIZE * WALLS_DENSITY)
        m_env.add_agents(deer_handle, method="random", n=COUNT_DEERS)
        m_env.add_agents(tiger_handle, method="random", n=COUNT_TIGERS)

    env = data.MAgentEnv(m_env, tiger_handle, reset_env_func=reset_env)

    if args.mode == 'double_attack_nn':
        net = model.DQNNoisyModel(
            env.single_observation_space.spaces[0].shape,
            env.single_observation_space.spaces[1].shape,
            m_env.get_action_space(tiger_handle)[0]).to(device)
    else:
        net = model.DQNModel(
            env.single_observation_space.spaces[0].shape,
            env.single_observation_space.spaces[1].shape,
            m_env.get_action_space(tiger_handle)[0]).to(device)
    tgt_net = ptan.agent.TargetNet(net)
    print(net)

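    # The noisy-network variant explores through its noisy layers, so greedy
    # (argmax) selection is enough; the plain DQN needs an epsilon-greedy schedule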
    if args.mode == 'double_attack_nn':
        action_selector = ptan.actions.ArgmaxActionSelector()
        epsilon_tracker = None
    else:
        action_selector = ptan.actions.EpsilonGreedyActionSelector(
            epsilon=PARAMS.epsilon_start)
        epsilon_tracker = common.EpsilonTracker(action_selector, PARAMS)
    preproc = model.MAgentPreprocessor(device)
    agent = ptan.agent.DQNAgent(net,
                                action_selector,
                                device,
                                preprocessor=preproc)
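
    # Two-group battle training: both groups share one MAgent world, exposed as
    # two gym-style environments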
    def reset_env():
        m_env.reset()
        m_env.add_walls(method="random", n=MAP_SIZE * MAP_SIZE * WALLS_DENSITY)
        m_env.add_agents(a_handle, method="random", n=COUNT_AGENTS_1)
        m_env.add_agents(b_handle, method="random", n=COUNT_AGENTS_2)

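    # Group 'a' is wrapped as a slave of the shared world; group 'b' is the
    # master environment and enforces the per-episode step limit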
    a_env = data.MAgentEnv(m_env,
                           a_handle,
                           reset_env_func=lambda: None,
                           is_slave=True)
    b_env = data.MAgentEnv(m_env,
                           b_handle,
                           reset_env_func=reset_env,
                           is_slave=False,
                           steps_limit=MAX_EPISODE)

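    # DQN built from group 'a' observation and action spaces, with a target
    # copy for Q-learning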
    obs = data.MAgentEnv.handle_obs_space(m_env, a_handle)

    net = model.DQNModel(obs.spaces[0].shape, obs.spaces[1].shape,
                         m_env.get_action_space(a_handle)[0]).to(device)
    tgt_net = ptan.agent.TargetNet(net)
    print(net)

    action_selector = ptan.actions.EpsilonGreedyActionSelector(
        epsilon=PARAMS.epsilon_start)
    epsilon_tracker = common.EpsilonTracker(action_selector, PARAMS)
    preproc = model.MAgentPreprocessor(device)

    agent = ptan.agent.DQNAgent(net,
                                action_selector,
                                device,
                                preprocessor=preproc)
    a_exp_source = ptan.experience.ExperienceSourceFirstLast(a_env,
                                                             agent,
                                                             PARAMS.gamma,
                                                             steps_count=1)