            action = agent.act(ob, reward, done)
            ob, reward, done, _ = env.step(action)
            ob = crop_normalize(ob, crop)

            replay_buffer[i]['action'].append(action)
            replay_buffer[i]['next_obs'].append(
                np.concatenate((ob, prev_ob), axis=0))

            if done:
                break
    else:
        while True:
            replay_buffer[i]['obs'].append(ob[1])

            action = agent.act(ob, reward, done)
            ob, reward, done, _ = env.step(action)

            replay_buffer[i]['action'].append(action)
            replay_buffer[i]['next_obs'].append(ob[1])

            if done:
                break

    if i % 10 == 0:
        print("iter " + str(i))

env.close()

# Save replay buffer to disk.
utils.save_list_dict_h5py(replay_buffer, args.fname)
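
# ---------------------------------------------------------------------------
# The helpers `crop_normalize` and `utils.save_list_dict_h5py` used above are
# defined elsewhere in the repo. The sketches below only illustrate
# implementations that would be consistent with the call sites; the 50x50
# output resolution and the one-HDF5-group-per-episode layout are assumptions,
# not the original code.
# ---------------------------------------------------------------------------
import os

import h5py
import numpy as np
from PIL import Image


def crop_normalize(img, crop):
    """Crop rows crop[0]:crop[1], resize, and return a CHW float image in [0, 1]."""
    img = img[crop[0]:crop[1]]
    img = Image.fromarray(img).resize((50, 50), Image.LANCZOS)  # assumed resolution
    return np.transpose(np.array(img), (2, 0, 1)) / 255.0


def save_list_dict_h5py(array_dict, fname):
    """Save a list of per-episode dicts of arrays, one HDF5 group per episode."""
    if os.path.dirname(fname):
        os.makedirs(os.path.dirname(fname), exist_ok=True)
    with h5py.File(fname, 'w') as hf:
        for i, episode in enumerate(array_dict):
            grp = hf.create_group(str(i))
            for key, value in episode.items():
                grp.create_dataset(key, data=np.asarray(value))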
def main(args):
    hidden = 256
    max_episodes = args.max_episodes
    max_steps = args.num_steps
    dataset_save_path = args.save_path
    min_burnin = args.min_burnin
    max_burnin = args.max_burnin
    env_name = args.env_id
    seed = args.seed
    save_dir = "./{:s}/".format(env_name.lower())

    torch.manual_seed(seed)

    env, num_actions = init_env(env_name, seed)
    model = init_and_load_model(hidden, num_actions, save_dir)
    random_agent = init_random_agent(env)

    episode_length, epr, eploss, done = 0, 0, 0, True
    state = env.reset()
    prev_state = state
    hx = reset_rnn_state()

    if env_name == 'PongDeterministic-v4':
        crop = (35, 190)
    elif env_name == 'SpaceInvadersDeterministic-v4':
        crop = (30, 200)
    else:
        raise NotImplementedError(
            "Only Pong and Space Invaders were used in the original paper.")

    replay_buffer = []

    blacklist_state_ids = None
    if args.check_dup_paths:
        blacklist_state_ids = construct_start_states_set(args.check_dup_paths)

    # TODO: what are the max episodes in the envs, does the C-SWM repo change that?
    with torch.no_grad():
        burnin_steps = np.random.randint(min_burnin, max_burnin)
        replay_init_episode(replay_buffer)

        while True:
            episode_length += 1
            start_collection = episode_length > burnin_steps

            if start_collection:
                replay_buffer[-1]['state_ids'].append(
                    np.array(cp.deepcopy(env.unwrapped._get_ram()),
                             dtype=np.int32))

            if start_collection:
                action = random_agent.act(None, None, None)
            else:
                action = select_action(
                    preprocess_state(state), model, hx, args.eps)

            next_state, reward, done, _ = env.step(action)

            # print(reward, done, start_collection, env.env.ale.lives())
            # import matplotlib.pyplot as plt
            # if start_collection:
            #     print(episode_length)
            #     plt.subplot(1, 2, 1)
            #     plt.imshow(state)
            #     plt.subplot(1, 2, 2)
            #     plt.imshow(next_state)
            #     plt.pause(0.05)

            if env_name == 'PongDeterministic-v4':
                # Reset when we win/lose a round (pos/neg reward), but don't
                # reset once we are collecting random data; if we did, the
                # dataset would be extremely limited because we only allow
                # full 10-step episodes.
                if reward != 0 and not start_collection:
                    done = True
            elif env_name == 'SpaceInvadersDeterministic-v4':
                # Reset when we lose a life (we start with 3 lives).
                if env.env.ale.lives() != 3:
                    done = True

            if blacklist_state_ids is not None:
                # First step of data collection.
                if episode_length == burnin_steps + 1:
                    # If this start state already exists in the training set,
                    # go to the next episode.
                    if (replay_buffer[-1]['state_ids'][-1].tobytes()
                            in blacklist_state_ids):
                        print("duplicate start state, skip episode")
                        done = True

            if start_collection:
                state_replay = np.concatenate(
                    (crop_normalize(prev_state, crop),
                     crop_normalize(state, crop)), axis=0)
                next_state_replay = np.concatenate(
                    (crop_normalize(state, crop),
                     crop_normalize(next_state, crop)), axis=0)
                replay_buffer[-1]['obs'].append(state_replay)
                replay_buffer[-1]['next_obs'].append(next_state_replay)
                replay_buffer[-1]['action'].append(action)
                replay_buffer[-1]['next_state_ids'].append(
                    np.array(cp.deepcopy(env.unwrapped._get_ram()),
                             dtype=np.int32))

            epr += reward
            done = done or episode_length >= 1e4
            prev_state = state
            state = next_state

            num_samples = len(replay_buffer[-1]['obs'])
            if num_samples == max_steps:
                done = True

            if done:
                print("ep {:d}, length: {:d}".format(
                    len(replay_buffer), episode_length))
                hx = reset_rnn_state()
                episode_length, epr, eploss = 0, 0, 0
                state = env.reset()
                prev_state = state

                # Check if the episode was long enough.
                if num_samples != max_steps:
                    del replay_buffer[-1]

                # Termination condition.
                if len(replay_buffer) == max_episodes:
                    break

                burnin_steps = np.random.randint(min_burnin, max_burnin)
                replay_init_episode(replay_buffer)

    env.close()
    utils.save_list_dict_h5py(replay_buffer, dataset_save_path)
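
# ---------------------------------------------------------------------------
# `replay_init_episode` and `construct_start_states_set` are referenced in
# main() but defined elsewhere. The sketches below are inferred from how they
# are used: the episode dict needs exactly the keys appended to in the main
# loop, and the blacklist is a set of RAM-state byte strings matched against
# `.tobytes()` of int32 arrays. Treating `check_dup_paths` as a list of
# previously generated HDF5 datasets is an assumption.
# ---------------------------------------------------------------------------
import h5py
import numpy as np


def replay_init_episode(replay_buffer):
    """Append an empty episode dict with the keys filled in during collection."""
    replay_buffer.append({
        'obs': [],
        'action': [],
        'next_obs': [],
        'state_ids': [],
        'next_state_ids': [],
    })


def construct_start_states_set(dataset_paths):
    """Return the set of first-step RAM states (as bytes) across the given datasets."""
    if isinstance(dataset_paths, str):
        dataset_paths = [dataset_paths]
    blacklist = set()
    for path in dataset_paths:
        with h5py.File(path, 'r') as hf:
            for episode in hf.values():
                start_state = np.asarray(episode['state_ids'][0], dtype=np.int32)
                blacklist.add(start_state.tobytes())
    return blacklist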