Code Example #1
                action = agent.act(ob, reward, done)
                ob, reward, done, _ = env.step(action)
                ob = crop_normalize(ob, crop)

                replay_buffer[i]['action'].append(action)
                replay_buffer[i]['next_obs'].append(
                    np.concatenate((ob, prev_ob), axis=0))

                if done:
                    break
        else:
            while True:
                replay_buffer[i]['obs'].append(ob[1])

                action = agent.act(ob, reward, done)
                ob, reward, done, _ = env.step(action)

                replay_buffer[i]['action'].append(action)
                replay_buffer[i]['next_obs'].append(ob[1])

                if done:
                    break

        if i % 10 == 0:
            print("iter " + str(i))

    env.close()

    # Save replay buffer to disk.
    utils.save_list_dict_h5py(replay_buffer, args.fname)
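
Both examples call a crop_normalize helper that is not shown in the excerpt. The sketch below is an assumption about what such a helper typically does for these Atari frames (vertical crop given by a (top, bottom) tuple, downsampling to a small square, channel-first layout, pixel values scaled to [0, 1]); the 50x50 target resolution and the use of PIL are illustrative, not taken from the project.

import numpy as np
from PIL import Image


def crop_normalize(img, crop):
    """Crop an Atari RGB frame vertically, downsample it and scale to [0, 1]."""
    img = img[crop[0]:crop[1]]                   # keep rows crop[0]..crop[1], e.g. (35, 190) for Pong
    img = Image.fromarray(img).resize((50, 50))  # downsample to an assumed 50x50 target
    # move channels first (C, H, W) and normalize pixel values to [0, 1]
    return np.transpose(np.array(img), (2, 0, 1)) / 255.0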
Code Example #2
File: collect.py  Project: ondrejba/baby-a3c
def main(args):

    hidden = 256
    max_episodes = args.max_episodes
    max_steps = args.num_steps
    dataset_save_path = args.save_path
    min_burnin = args.min_burnin
    max_burnin = args.max_burnin
    env_name = args.env_id
    seed = args.seed
    save_dir = "./{:s}/".format(env_name.lower())

    torch.manual_seed(seed)

    env, num_actions = init_env(env_name, seed)
    model = init_and_load_model(hidden, num_actions, save_dir)
    random_agent = init_random_agent(env)

    episode_length, epr, eploss, done = 0, 0, 0, True
    state = env.reset()
    prev_state = state
    hx = reset_rnn_state()

    if env_name == 'PongDeterministic-v4':
        crop = (35, 190)
    elif env_name == 'SpaceInvadersDeterministic-v4':
        crop = (30, 200)
    else:
        raise NotImplementedError(
            "Only Pong and Space were used in the original paper.")

    replay_buffer = []

    blacklist_state_ids = None
    if args.check_dup_paths:
        blacklist_state_ids = construct_start_states_set(args.check_dup_paths)

    # TODO: what are the max episodes in the envs, does the C-SWM repo change that?
    with torch.no_grad():

        burnin_steps = np.random.randint(min_burnin, max_burnin)
        replay_init_episode(replay_buffer)

        while True:

            episode_length += 1

            start_collection = episode_length > burnin_steps

            if start_collection:
                # store the raw ALE RAM contents as an integer identifier of the current state
                replay_buffer[-1]['state_ids'].append(
                    np.array(cp.deepcopy(env.unwrapped._get_ram()),
                             dtype=np.int32))

            if start_collection:
                # after burn-in: act uniformly at random to collect transitions
                action = random_agent.act(None, None, None)
            else:
                # during burn-in: follow the pretrained policy (eps-greedy)
                action = select_action(preprocess_state(state), model, hx,
                                       args.eps)

            next_state, reward, done, _ = env.step(action)
            # print(reward, done, start_collection, env.env.ale.lives())
            # import matplotlib.pyplot as plt
            # if start_collection:
            #     print(episode_length)
            #     plt.subplot(1, 2, 1)
            #     plt.imshow(state)
            #     plt.subplot(1, 2, 2)
            #     plt.imshow(next_state)
            #     plt.pause(0.05)

            if env_name == 'PongDeterministic-v4':
                # reset when we win/lose a round (pos/neg reward)
                # don't reset once we are collecting random data
                # if we do reset, the dataset is extremely limited
                # because we only allow full 10-step episodes
                if reward != 0 and not start_collection:
                    done = True
            elif env_name == 'SpaceInvadersDeterministic-v4':
                # reset when we lose life (we start with 3 lives)
                if env.env.ale.lives() != 3:
                    done = True

            if blacklist_state_ids is not None:
                # first step of data collection
                if episode_length == burnin_steps + 1:
                    # if this start state exists in the training set, go to the next episode
                    if replay_buffer[-1]['state_ids'][-1].tobytes(
                    ) in blacklist_state_ids:
                        print("duplicate start state, skip episode")
                        done = True

            if start_collection:
                # each stored observation is two consecutive cropped frames
                # stacked along the channel axis (previous frame first)
                state_replay = np.concatenate(
                    (crop_normalize(prev_state, crop), crop_normalize(state, crop)),
                    axis=0)
                next_state_replay = np.concatenate(
                    (crop_normalize(state, crop), crop_normalize(next_state, crop)),
                    axis=0)
                replay_buffer[-1]['obs'].append(state_replay)
                replay_buffer[-1]['next_obs'].append(next_state_replay)
                replay_buffer[-1]['action'].append(action)
                replay_buffer[-1]['next_state_ids'].append(
                    np.array(cp.deepcopy(env.unwrapped._get_ram()),
                             dtype=np.int32))

            epr += reward
            done = done or episode_length >= 1e4

            prev_state = state
            state = next_state

            num_samples = len(replay_buffer[-1]['obs'])
            if num_samples == max_steps:
                done = True

            if done:
                print("ep {:d}, length: {:d}".format(len(replay_buffer),
                                                     episode_length))

                hx = reset_rnn_state()
                episode_length, epr, eploss = 0, 0, 0
                state = env.reset()
                prev_state = state

                # check if episode was long enough
                if num_samples != max_steps:
                    del replay_buffer[-1]

                # termination condition
                if len(replay_buffer) == max_episodes:
                    break

                burnin_steps = np.random.randint(min_burnin, max_burnin)
                replay_init_episode(replay_buffer)

    env.close()
    utils.save_list_dict_h5py(replay_buffer, dataset_save_path)
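
Both snippets finish by writing the list of per-episode dictionaries to disk with utils.save_list_dict_h5py. That helper is not part of the excerpt; the following is a minimal sketch of one way to implement it, assuming each episode becomes its own HDF5 group with one dataset per key (the group naming scheme and the h5py usage are assumptions, not the project's actual utils module).

import h5py
import numpy as np


def save_list_dict_h5py(array_dict_list, fname):
    """Save a list of dictionaries of numpy arrays, one HDF5 group per episode."""
    with h5py.File(fname, 'w') as hf:
        for i, episode in enumerate(array_dict_list):
            grp = hf.create_group('episode_{:d}'.format(i))
            for key, value in episode.items():
                # lists of per-step arrays are stacked into a single dataset
                grp.create_dataset(key, data=np.asarray(value))

Reading the data back is then a matter of iterating over the groups and collecting each dataset into a dictionary again.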