Code example #1
File: train_v3.py  Project: nicoladainese96/SC2-RL
import numpy as np

# Helpers such as reset_and_skip_first_frame, merge_screen_and_minimap and
# get_action_mask are defined elsewhere in train_v3.py / the SC2-RL repo.


def test(step_idx, agent, test_env, process_ID, op, action_dict, num_test,
         save_path):
    """Run num_test evaluation episodes and append the mean score to a log file."""
    score = 0.0
    done = False

    for _ in range(num_test):

        obs = reset_and_skip_first_frame(test_env)
        s_dict, _ = op.get_state(obs)
        s = merge_screen_and_minimap(s_dict)
        s = s[np.newaxis, ...]  # add batch dim
        available_actions = obs[0].observation.available_actions
        a_mask = get_action_mask(available_actions,
                                 action_dict)[np.newaxis, ...]  # add batch dim

        while not done:
            a, log_prob, probs = agent.step(s, a_mask)
            obs = test_env.step(a)
            s_prime_dict, _ = op.get_state(obs)
            s_prime = merge_screen_and_minimap(s_prime_dict)
            s_prime = s_prime[np.newaxis, ...]  # add batch dim
            reward = obs[0].reward
            done = obs[0].last()
            available_actions = obs[0].observation.available_actions
            a_mask = get_action_mask(available_actions,
                                     action_dict)[np.newaxis, ...]  # add batch dim

            s = s_prime
            score += reward
        done = False  # reset the flag for the next test episode

    with open(save_path + '/Logging/' + process_ID + '.txt', 'a+') as f:
        print(f"{step_idx},{score/num_test:.1f}", file=f)
    return score / num_test
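None of the helper functions called above are shown on this page. As one example, here is a minimal sketch of what reset_and_skip_first_frame plausibly does: reset the PySC2 environment, then advance one no-op frame so the returned observation is a regular in-game frame. The FunctionCall/no_op usage is standard PySC2 API; treating the first frame as skippable is an assumption about the repo's helper.

from pysc2.lib import actions as sc2_actions

def reset_and_skip_first_frame(env):
    # Sketch; the real helper lives in the SC2-RL repo.
    obs = env.reset()
    # Step once with no_op so available_actions reflects normal gameplay.
    obs = env.step([sc2_actions.FunctionCall(sc2_actions.FUNCTIONS.no_op.id, [])])
    return obs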
Code example #2
File: train_v3.py  Project: nicoladainese96/SC2-RL
def inspection_test(step_idx, agent, test_env, process_ID, op, action_dict):
    """Play one test episode while logging per-step data into an InspectionDict."""
    inspector = InspectionDict(step_idx, process_ID, action_dict, test_env)

    obs = reset_and_skip_first_frame(test_env)
    s_dict, _ = op.get_state(obs)
    s = merge_screen_and_minimap(s_dict)
    s = s[np.newaxis, ...]  # add batch dim
    available_actions = obs[0].observation.available_actions
    a_mask = get_action_mask(available_actions,
                             action_dict)[np.newaxis, ...]  # add batch dim

    done = False
    G = 0.0
    # lists used for the update
    s_lst, r_lst, done_lst, bootstrap_lst, s_trg_lst = [], [], [], [], []
    log_probs = []
    entropies = []
    while not done:
        a, log_prob, entropy = inspection_step(agent, inspector, s, a_mask)
        log_probs.append(log_prob)
        entropies.append(entropy)
        obs = test_env.step(a)
        s_prime_dict, _ = op.get_state(obs)
        s_prime = merge_screen_and_minimap(s_prime_dict)
        s_prime = s_prime[np.newaxis, ...]  # add batch dim
        reward = obs[0].reward
        done = obs[0].last()
        available_actions = obs[0].observation.available_actions
        a_mask = get_action_mask(available_actions,
                                 action_dict)[np.newaxis, ...]  # add batch dim
        # always bootstrap at episode end (in MoveToBeacon there is no real end)
        bootstrap = done

        inspector.dict['state_traj'].append(s)
        s_lst.append(s)
        r_lst.append(reward)
        done_lst.append(done)
        bootstrap_lst.append(bootstrap)
        s_trg_lst.append(s_prime)

        s = s_prime
        G += reward

    inspector.dict['rewards'] = r_lst
    s_lst = np.array(s_lst).transpose(1, 0, 2, 3, 4)  # (time, 1, C, H, W) -> (1, time, C, H, W)
    r_lst = np.array(r_lst).reshape(1, -1)
    done_lst = np.array(done_lst).reshape(1, -1)
    bootstrap_lst = np.array(bootstrap_lst).reshape(1, -1)
    s_trg_lst = np.array(s_trg_lst).transpose(1, 0, 2, 3, 4)  # same layout as s_lst
    update_dict = inspection_update(agent, r_lst, log_probs, entropies, s_lst,
                                    done_lst, bootstrap_lst, s_trg_lst)
    inspector.store_update(update_dict)
    return inspector
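get_action_mask is also defined elsewhere in the repo. A plausible sketch, assuming numpy is imported as np (as above) and that action_dict maps the agent's action indices to PySC2 function ids; whether True marks available or unavailable actions is an assumed convention that depends on how agent.step applies the mask:

def get_action_mask(available_actions, action_dict):
    # One boolean per agent action: True if its PySC2 function id is
    # currently available (assumed convention, see note above).
    return np.array([action_dict[k] in available_actions
                     for k in action_dict.keys()])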
Code example #3
File: train_v3.py  Project: nicoladainese96/SC2-RL
def worker(worker_id, master_end, worker_end, game_params, map_name,
           obs_proc_params, action_dict):
    """Environment process: executes 'step'/'reset'/'close' commands sent by the master."""
    master_end.close()  # forbid the worker to use the master end for messaging
    np.random.seed()  # re-seed NumPy per process so each worker draws a different env seed
    env = init_game(game_params,
                    map_name,
                    random_seed=np.random.randint(10000))
    op = ObsProcesser(**obs_proc_params)

    while True:
        cmd, data = worker_end.recv()
        if cmd == 'step':
            obs = env.step([data])
            state_trg_dict, _ = op.get_state(obs)  # returns (state_dict, names_dict)
            state_trg = merge_screen_and_minimap(state_trg_dict)
            reward = obs[0].reward
            done = obs[0].last()

            # Always bootstrap when episode finishes (in MoveToBeacon there is no real end)
            bootstrap = done

            # state_trg is the state used as next state for the update
            # state is the new state used to decide the next action
            # (different if the episode ends and another one begins)
            if done:
                obs = reset_and_skip_first_frame(env)
                state_dict, _ = op.get_state(obs)  # returns (state_dict, names_dict)
                state = merge_screen_and_minimap(state_dict)
            else:
                state = state_trg

            available_actions = obs[0].observation.available_actions
            action_mask = get_action_mask(available_actions, action_dict)
            worker_end.send(
                (state, reward, done, bootstrap, state_trg, action_mask))

        elif cmd == 'reset':
            obs = reset_and_skip_first_frame(env)
            state_dict, _ = op.get_state(obs)  # returns (state_dict, names_dict)
            state = merge_screen_and_minimap(state_dict)
            available_actions = obs[0].observation.available_actions
            action_mask = get_action_mask(available_actions, action_dict)

            worker_end.send((state, action_mask))
        elif cmd == 'close':
            worker_end.close()
            break
        else:
            raise NotImplementedError
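For context, worker() is meant to run in its own process and talk to the master over a multiprocessing.Pipe. The sketch below shows a minimal master side driving the 'step'/'reset'/'close' protocol above; the ParallelEnv name and its constructor signature are assumptions, and only the message format comes from worker().

import multiprocessing as mp
import numpy as np

class ParallelEnv:
    """Minimal master side for the worker() pipe protocol (sketch)."""

    def __init__(self, n_workers, game_params, map_name, obs_proc_params,
                 action_dict):
        self.master_ends, worker_ends = zip(*[mp.Pipe() for _ in range(n_workers)])
        for worker_id, (master_end, worker_end) in enumerate(
                zip(self.master_ends, worker_ends)):
            p = mp.Process(target=worker,
                           args=(worker_id, master_end, worker_end, game_params,
                                 map_name, obs_proc_params, action_dict))
            p.daemon = True  # workers die with the master process
            p.start()
            worker_end.close()  # the master keeps only the master ends

    def reset(self):
        for master_end in self.master_ends:
            master_end.send(('reset', None))
        states, action_masks = zip(*[m.recv() for m in self.master_ends])
        return np.stack(states), np.stack(action_masks)

    def step(self, actions):
        for master_end, action in zip(self.master_ends, actions):
            master_end.send(('step', action))
        results = [m.recv() for m in self.master_ends]
        states, rewards, dones, bootstraps, trg_states, masks = zip(*results)
        return (np.stack(states), np.array(rewards), np.array(dones),
                np.array(bootstraps), np.stack(trg_states), np.stack(masks))

    def close(self):
        for master_end in self.master_ends:
            master_end.send(('close', None))

A training loop would then call env.reset() once, and env.step(actions) with one action per worker on every iteration, batching the returned states for the agent.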