def get_venv(args, env_id, num_env, seed, hps):
    env_type = args.env
    if env_type == 'pass':
        env = make_multi_pass_env(env_id, env_type, num_env, seed, args)
    elif env_type == 'threepass':
        env = make_m_three_pass_env(env_id, env_type, num_env, seed, args)
    elif env_type == 'island':
        env = make_m_island_env(env_id, env_type, num_env, seed, args)
    elif env_type == 'x_island':
        env = make_m_x_island_env(env_id, env_type, num_env, seed, args)
    elif env_type == 'pushball':
        env = make_m_pushball_env(env_id, env_type, num_env, seed, args)
    else:
        # Guard against an unrecognized env type; previously this fell through
        # and raised a NameError on the undefined `env` below.
        raise ValueError("Unknown env type: %s" % env_type)
    venv = VecFrameStack(env, hps.pop('frame_stack'))
    return venv
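# A minimal usage sketch for get_venv (not from the original source): SimpleNamespace
# stands in for the real argparse namespace, the env id and hyperparameter values are
# placeholders, and only the keys consumed above are supplied.
from types import SimpleNamespace

def _get_venv_example():
    args = SimpleNamespace(env='pass')   # selects make_multi_pass_env above
    hps = {'frame_stack': 4}             # popped by get_venv
    venv = get_venv(args, env_id='pass-v0', num_env=8, seed=0, hps=hps)
    print(venv.observation_space, venv.action_space)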
def train(*, env_id, num_env, hps, num_timesteps, seed, use_reward, ep_path, dmlab):
    venv = VecFrameStack(
        DummyVecEnv([
            lambda: CollectGymDataset(
                Minecraft('MineRLTreechop-v0', 'treechop'), ep_path, atari=False)
        ]),
        hps.pop('frame_stack'))
    venv.score_multiple = 1
    venv.record_obs = True if env_id == 'SolarisNoFrameskip-v4' else False
    ob_space = venv.observation_space
    ac_space = venv.action_space
    gamma = hps.pop('gamma')
    policy = {'rnn': CnnGruPolicy, 'cnn': CnnPolicy}[hps.pop('policy')]
    print('Running train ==========================================')
    agent = PpoAgent(
        scope='ppo',
        ob_space=ob_space,
        ac_space=ac_space,
        stochpol_fn=functools.partial(
            policy,
            scope='pol',
            ob_space=ob_space,
            ac_space=ac_space,
            update_ob_stats_independently_per_gpu=hps.pop('update_ob_stats_independently_per_gpu'),
            proportion_of_exp_used_for_predictor_update=hps.pop('proportion_of_exp_used_for_predictor_update'),
            dynamics_bonus=hps.pop("dynamics_bonus")),
        gamma=gamma,
        gamma_ext=hps.pop('gamma_ext'),
        lam=hps.pop('lam'),
        nepochs=hps.pop('nepochs'),
        nminibatches=hps.pop('nminibatches'),
        lr=hps.pop('lr'),
        cliprange=0.1,
        nsteps=128,
        ent_coef=0.001,
        max_grad_norm=hps.pop('max_grad_norm'),
        use_news=hps.pop("use_news"),
        comm=MPI.COMM_WORLD if MPI.COMM_WORLD.Get_size() > 1 else None,
        update_ob_stats_every_step=hps.pop('update_ob_stats_every_step'),
        int_coeff=hps.pop('int_coeff'),
        ext_coeff=hps.pop('ext_coeff'))
    agent.start_interaction([venv])
    if hps.pop('update_ob_stats_from_random_agent'):
        agent.collect_random_statistics(num_timesteps=128 * 50)
    assert len(hps) == 0, "Unused hyperparameters: %s" % list(hps.keys())
    counter = 0
    while True:
        info = agent.step()
        if info['update']:
            logger.logkvs(info['update'])
            logger.dumpkvs()
            counter += 1
        if agent.I.stats['tcount'] > num_timesteps:
            break
    agent.stop_interaction()
def train(*, env_id, num_env, hps, num_timesteps, seed):
    venv = VecFrameStack(
        make_atari_env(env_id, num_env, seed, wrapper_kwargs=dict(),
                       start_index=num_env * MPI.COMM_WORLD.Get_rank(),
                       max_episode_steps=hps.pop('max_episode_steps')),
        hps.pop('frame_stack'))
    # venv.score_multiple = {'Mario': 500,
    #                        'MontezumaRevengeNoFrameskip-v4': 100,
    #                        'GravitarNoFrameskip-v4': 250,
    #                        'PrivateEyeNoFrameskip-v4': 500,
    #                        'SolarisNoFrameskip-v4': None,
    #                        'VentureNoFrameskip-v4': 200,
    #                        'PitfallNoFrameskip-v4': 100,
    #                        }[env_id]
    venv.score_multiple = 1
    venv.record_obs = True
    # venv.record_obs = True if env_id == 'SolarisNoFrameskip-v4' else False
    ob_space = venv.observation_space
    ac_space = venv.action_space
    gamma = hps.pop('gamma')
    policy = {'rnn': CnnGruPolicy, 'cnn': CnnPolicy}[hps.pop('policy')]
    agent = PpoAgent(
        scope='ppo',
        ob_space=ob_space,
        ac_space=ac_space,
        stochpol_fn=functools.partial(
            policy,
            scope='pol',
            ob_space=ob_space,
            ac_space=ac_space,
            update_ob_stats_independently_per_gpu=hps.pop('update_ob_stats_independently_per_gpu'),
            proportion_of_exp_used_for_predictor_update=hps.pop('proportion_of_exp_used_for_predictor_update'),
            dynamics_bonus=hps.pop("dynamics_bonus")),
        gamma=gamma,
        gamma_ext=hps.pop('gamma_ext'),
        lam=hps.pop('lam'),
        nepochs=hps.pop('nepochs'),
        nminibatches=hps.pop('nminibatches'),
        lr=hps.pop('lr'),
        cliprange=0.1,
        nsteps=128,
        ent_coef=0.001,
        max_grad_norm=hps.pop('max_grad_norm'),
        use_news=hps.pop("use_news"),
        comm=MPI.COMM_WORLD if MPI.COMM_WORLD.Get_size() > 1 else None,
        update_ob_stats_every_step=hps.pop('update_ob_stats_every_step'),
        int_coeff=hps.pop('int_coeff'),
        ext_coeff=hps.pop('ext_coeff'),
    )
    agent.start_interaction([venv])
    if hps.pop('update_ob_stats_from_random_agent'):
        agent.collect_random_statistics(num_timesteps=128 * 50)
    assert len(hps) == 0, "Unused hyperparameters: %s" % list(hps.keys())
    counter = 0
    while True:
        info = agent.step()
        if info['update']:
            logger.logkvs(info['update'])
            logger.dumpkvs()
            counter += 1
        if agent.I.stats['tcount'] > num_timesteps:
            break
    agent.stop_interaction()
    return agent
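# A minimal hps sketch for the train() above. Every key listed here is popped exactly
# once inside the function, so the final `assert len(hps) == 0` passes; the values are
# placeholders, not the original experiment settings.
example_hps = dict(
    frame_stack=4,
    max_episode_steps=4500,
    gamma=0.99,
    gamma_ext=0.999,
    lam=0.95,
    policy='rnn',                                   # 'rnn' -> CnnGruPolicy, 'cnn' -> CnnPolicy
    nepochs=4,
    nminibatches=4,
    lr=1e-4,
    max_grad_norm=0.0,
    use_news=False,
    update_ob_stats_every_step=False,
    update_ob_stats_independently_per_gpu=False,
    update_ob_stats_from_random_agent=True,
    proportion_of_exp_used_for_predictor_update=1.0,
    int_coeff=1.0,
    ext_coeff=2.0,
    dynamics_bonus=False,
)
# agent = train(env_id='MontezumaRevengeNoFrameskip-v4', num_env=32,
#               hps=dict(example_hps), num_timesteps=int(1e8), seed=0)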
def train(*, env_id, num_env, hps, num_timesteps, seed):
    venv = VecFrameStack(
        make_atari_env(env_id, num_env, seed, wrapper_kwargs=dict(),
                       start_index=num_env * MPI.COMM_WORLD.Get_rank(),
                       max_episode_steps=hps.pop('max_episode_steps')),
        hps.pop('frame_stack'))
    venv.score_multiple = 1
    venv.record_obs = False
    ob_space = venv.observation_space
    ac_space = venv.action_space
    gamma = hps.pop('gamma')
    policy = {'rnn': CnnGruPolicy, 'cnn': CnnPolicy}[hps.pop('policy')]
    agent = PpoAgent(
        scope='ppo',
        ob_space=ob_space,
        ac_space=ac_space,
        stochpol_fn=functools.partial(
            policy,
            scope='pol',
            ob_space=ob_space,
            ac_space=ac_space,
            update_ob_stats_independently_per_gpu=hps.pop('update_ob_stats_independently_per_gpu'),
            proportion_of_exp_used_for_predictor_update=hps.pop('proportion_of_exp_used_for_predictor_update'),
            exploration_type=hps.pop("exploration_type"),
            beta=hps.pop("beta"),
        ),
        gamma=gamma,
        gamma_ext=hps.pop('gamma_ext'),
        lam=hps.pop('lam'),
        nepochs=hps.pop('nepochs'),
        nminibatches=hps.pop('nminibatches'),
        lr=hps.pop('lr'),
        cliprange=0.1,
        nsteps=128,
        ent_coef=0.001,
        max_grad_norm=hps.pop('max_grad_norm'),
        use_news=hps.pop("use_news"),
        comm=MPI.COMM_WORLD if MPI.COMM_WORLD.Get_size() > 1 else None,
        update_ob_stats_every_step=hps.pop('update_ob_stats_every_step'),
        int_coeff=hps.pop('int_coeff'),
        ext_coeff=hps.pop('ext_coeff'),
        noise_type=hps.pop('noise_type'),
        noise_p=hps.pop('noise_p'),
        use_sched=hps.pop('use_sched'),
        num_env=num_env,
        exp_name=hps.pop('exp_name'),
    )
    agent.start_interaction([venv])
    if hps.pop('update_ob_stats_from_random_agent'):
        agent.collect_random_statistics(num_timesteps=128 * 50)
    assert len(hps) == 0, "Unused hyperparameters: %s" % list(hps.keys())
    counter = 0
    while True:
        info = agent.step()
        n_updates = 0
        if info['update']:
            logger.logkvs(info['update'])
            logger.dumpkvs()
            if NSML:
                n_updates = int(info['update']['n_updates'])
                nsml_dict = {k: np.float64(v) for k, v in info['update'].items()
                             if isinstance(v, Number)}
                nsml.report(step=n_updates, **nsml_dict)
            counter += 1
        # if n_updates >= 40*1000:  # 40K updates
        #     break
        if agent.I.stats['tcount'] > num_timesteps:
            break
    agent.stop_interaction()
def train(*, env_id, num_env, hps, num_timesteps, seed):
    if "NoFrameskip" in env_id:
        env_factory = make_atari_env
    else:
        env_factory = make_non_atari_env
    venv = VecFrameStack(
        env_factory(
            env_id,
            num_env,
            seed,
            wrapper_kwargs=dict(),
            start_index=num_env * MPI.COMM_WORLD.Get_rank(),
            max_episode_steps=hps.pop("max_episode_steps"),
        ),
        hps.pop("frame_stack"),
    )
    # venv.score_multiple = {'Mario': 500,
    #                        'MontezumaRevengeNoFrameskip-v4': 100,
    #                        'GravitarNoFrameskip-v4': 250,
    #                        'PrivateEyeNoFrameskip-v4': 500,
    #                        'SolarisNoFrameskip-v4': None,
    #                        'VentureNoFrameskip-v4': 200,
    #                        'PitfallNoFrameskip-v4': 100,
    #                        }[env_id]
    venv.score_multiple = 1
    venv.record_obs = True if env_id == "SolarisNoFrameskip-v4" else False
    ob_space = venv.observation_space
    ac_space = venv.action_space
    gamma = hps.pop("gamma")
    policy = {
        "rnn": CnnGruPolicy,
        "cnn": CnnPolicy,
        "ffnn": GruPolicy,
    }[hps.pop("policy")]
    agent = PpoAgent(
        scope="ppo",
        ob_space=ob_space,
        ac_space=ac_space,
        stochpol_fn=functools.partial(
            policy,
            scope="pol",
            ob_space=ob_space,
            ac_space=ac_space,
            update_ob_stats_independently_per_gpu=hps.pop("update_ob_stats_independently_per_gpu"),
            proportion_of_exp_used_for_predictor_update=hps.pop("proportion_of_exp_used_for_predictor_update"),
            dynamics_bonus=hps.pop("dynamics_bonus"),
            meta_rl=hps['meta_rl']),
        gamma=gamma,
        gamma_ext=hps.pop("gamma_ext"),
        lam=hps.pop("lam"),
        nepochs=hps.pop("nepochs"),
        nminibatches=hps.pop("nminibatches"),
        lr=hps.pop("lr"),
        cliprange=0.1,
        nsteps=128,
        ent_coef=0.001,
        max_grad_norm=hps.pop("max_grad_norm"),
        use_news=hps.pop("use_news"),
        comm=MPI.COMM_WORLD if MPI.COMM_WORLD.Get_size() > 1 else None,
        update_ob_stats_every_step=hps.pop("update_ob_stats_every_step"),
        int_coeff=hps.pop("int_coeff"),
        ext_coeff=hps.pop("ext_coeff"),
        meta_rl=hps.pop('meta_rl'))
    agent.start_interaction([venv])
    if hps.pop("update_ob_stats_from_random_agent"):
        agent.collect_random_statistics(num_timesteps=128 * 50)
    assert len(hps) == 0, f"Unused hyperparameters: {list(hps.keys())}"
    counter = 0
    while True:
        info = agent.step()
        if info["update"]:
            logger.logkvs(info["update"])
            logger.dumpkvs()
            counter += 1
        if agent.I.stats["tcount"] > num_timesteps:
            break
    agent.stop_interaction()
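# A small companion sketch for the variant above (not from the original source): the
# environment factory is chosen purely by the "NoFrameskip" substring in env_id, and the
# hps dict must also carry a 'meta_rl' key, which is read once for the policy and popped
# again for the agent.
def pick_env_factory(env_id):
    # Mirrors the dispatch at the top of train(): Atari ids contain "NoFrameskip".
    return make_atari_env if "NoFrameskip" in env_id else make_non_atari_env

# pick_env_factory("MontezumaRevengeNoFrameskip-v4") is make_atari_env
# pick_env_factory("MyGridWorld-v0") is make_non_atari_env   # env id is a placeholder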
def train(*, env_id, num_env, hps, num_timesteps, seed):
    venv = VecFrameStack(
        make_atari_env(env_id, num_env, seed, wrapper_kwargs=dict(),
                       start_index=num_env * MPI.COMM_WORLD.Get_rank(),
                       max_episode_steps=hps.pop('max_episode_steps')),
        hps.pop('frame_stack'))

    # Size of states when stored in the memory.
    only_train_r = hps.pop('only_train_r')
    online_r_training = hps.pop('online_train_r') or only_train_r
    r_network_trainer = None
    save_path = hps.pop('save_path')
    r_network_weights_path = hps.pop('r_path')
    '''
    ec_type = 'none'  # hps.pop('ec_type')
    venv = CuriosityEnvWrapperFrameStack(
        make_atari_env(env_id, num_env, seed, wrapper_kwargs=dict(),
                       start_index=num_env * MPI.COMM_WORLD.Get_rank(),
                       max_episode_steps=hps.pop('max_episode_steps')),
        vec_episodic_memory=None,
        observation_embedding_fn=None,
        exploration_reward=ec_type,
        exploration_reward_min_step=0,
        nstack=hps.pop('frame_stack'),
        only_train_r=only_train_r)
    '''
    # venv.score_multiple = {'Mario': 500,
    #                        'MontezumaRevengeNoFrameskip-v4': 100,
    #                        'GravitarNoFrameskip-v4': 250,
    #                        'PrivateEyeNoFrameskip-v4': 500,
    #                        'SolarisNoFrameskip-v4': None,
    #                        'VentureNoFrameskip-v4': 200,
    #                        'PitfallNoFrameskip-v4': 100,
    #                        }[env_id]
    venv.score_multiple = 1
    venv.record_obs = True if env_id == 'SolarisNoFrameskip-v4' else False
    ob_space = venv.observation_space
    ac_space = venv.action_space
    gamma = hps.pop('gamma')
    log_interval = hps.pop('log_interval')
    nminibatches = hps.pop('nminibatches')

    play = hps.pop('play')
    if play:
        nsteps = 1
    rnd_type = hps.pop('rnd_type')
    div_type = hps.pop('div_type')
    num_agents = hps.pop('num_agents')
    load_ram = hps.pop('load_ram')
    debug = hps.pop('debug')
    rnd_mask_prob = hps.pop('rnd_mask_prob')
    rnd_mask_type = hps.pop('rnd_mask_type')
    indep_rnd = hps.pop('indep_rnd')
    logger.info("indep_rnd:", indep_rnd)
    indep_policy = hps.pop('indep_policy')
    sd_type = hps.pop('sd_type')
    from_scratch = hps.pop('from_scratch')
    use_kl = hps.pop('use_kl')
    save_interval = 100

    policy = {'rnn': CnnGruPolicy, 'cnn': CnnPolicy}[hps.pop('policy')]
    agent = PpoAgent(
        scope='ppo',
        ob_space=ob_space,
        ac_space=ac_space,
        stochpol_fn=functools.partial(
            policy,
            scope='pol',
            ob_space=ob_space,
            ac_space=ac_space,
            update_ob_stats_independently_per_gpu=hps.pop('update_ob_stats_independently_per_gpu'),
            proportion_of_exp_used_for_predictor_update=hps.pop('proportion_of_exp_used_for_predictor_update'),
            dynamics_bonus=hps.pop("dynamics_bonus"),
            num_agents=num_agents,
            rnd_type=rnd_type,
            div_type=div_type,
            indep_rnd=indep_rnd,
            indep_policy=indep_policy,
            sd_type=sd_type,
            rnd_mask_prob=rnd_mask_prob),
        gamma=gamma,
        gamma_ext=hps.pop('gamma_ext'),
        gamma_div=hps.pop('gamma_div'),
        lam=hps.pop('lam'),
        nepochs=hps.pop('nepochs'),
        nminibatches=nminibatches,
        lr=hps.pop('lr'),
        cliprange=0.1,
        nsteps=5 if debug else 128,
        ent_coef=0.001,
        max_grad_norm=hps.pop('max_grad_norm'),
        use_news=hps.pop("use_news"),
        comm=MPI.COMM_WORLD if MPI.COMM_WORLD.Get_size() > 1 else None,
        update_ob_stats_every_step=hps.pop('update_ob_stats_every_step'),
        int_coeff=hps.pop('int_coeff'),
        ext_coeff=hps.pop('ext_coeff'),
        log_interval=log_interval,
        only_train_r=only_train_r,
        rnd_type=rnd_type,
        reset=hps.pop('reset'),
        dynamics_sample=hps.pop('dynamics_sample'),
        save_path=save_path,
        num_agents=num_agents,
        div_type=div_type,
        load_ram=load_ram,
        debug=debug,
        rnd_mask_prob=rnd_mask_prob,
        rnd_mask_type=rnd_mask_type,
        sd_type=sd_type,
        from_scratch=from_scratch,
        use_kl=use_kl,
        indep_rnd=indep_rnd)

    load_path = hps.pop('load_path')
    base_load_path = hps.pop('base_load_path')

    agent.start_interaction([venv])

    if load_path is not None:
        if play:
            agent.load(load_path)
        else:
            # agent.load(load_path)
            # agent.load_help_info(0, load_path)
            # agent.load_help_info(1, load_path)

            # load diversity agent
            # base_agent_idx = 1
            # logger.info("load base agents weights from {} agent {}".format(base_load_path, str(base_agent_idx)))
            # agent.load_agent(base_agent_idx, base_load_path)
            # agent.clone_baseline_agent(base_agent_idx)
            # agent.load_help_info(0, dagent_load_path)
            # agent.clone_agent(0)

            # load main agent
            src_agent_idx = 1
            logger.info("load main agent weights from {} agent {}".format(load_path, str(src_agent_idx)))
            agent.load_agent(src_agent_idx, load_path)

            if indep_rnd == False:
                rnd_agent_idx = 1
            else:
                rnd_agent_idx = src_agent_idx
            # rnd_agent_idx = 0
            logger.info("load rnd weights from {} agent {}".format(load_path, str(rnd_agent_idx)))
            agent.load_rnd(rnd_agent_idx, load_path)
            agent.clone_agent(rnd_agent_idx, rnd=True, policy=False, help_info=False)

            logger.info("load help info from {} agent {}".format(load_path, str(src_agent_idx)))
            agent.load_help_info(src_agent_idx, load_path)
            agent.clone_agent(src_agent_idx, rnd=False, policy=True, help_info=True)

            # logger.info("load main agent weights from {} agent {}".format(load_path, str(2)))
            # load_path = '/data/xupeifrom7700_1000/seed1_log0.5_clip-0.5~0.5_3agent_hasint4_2divrew_-1~1/models'
            # agent.load_agent(1, load_path)
            # agent.clone_baseline_agent()

        # if sd_type == 'sd':
        #     agent.load_sd("save_dir/models_sd_trained")
        # agent.initialize_discriminator()

    update_ob_stats_from_random_agent = hps.pop('update_ob_stats_from_random_agent')

    if play == False:
        if load_path is not None:
            pass
            # agent.collect_statistics_from_model()
        else:
            if update_ob_stats_from_random_agent and rnd_type == 'rnd':
                agent.collect_random_statistics(num_timesteps=128 * 5 if debug else 128 * 50)

        assert len(hps) == 0, "Unused hyperparameters: %s" % list(hps.keys())
        # agent.collect_rnd_info(128*50)
        '''
        if sd_type == 'sd':
            agent.train_sd(max_nepoch=300, max_neps=5)
            path = '{}_sd_trained'.format(save_path)
            logger.log("save model:", path)
            agent.save(path)
            return

        # agent.update_diverse_agent(max_nepoch=1000)
        # path = '{}_divupdated'.format(save_path)
        # logger.log("save model:", path)
        # agent.save(path)
        '''
        counter = 0
        while True:
            info = agent.step()
            n_updates = agent.I.stats["n_updates"]
            if info['update']:
                logger.logkvs(info['update'])
                logger.dumpkvs()
                counter += 1
            if info['update'] and save_path is not None and (n_updates % save_interval == 0 or n_updates == 1):
                path = '{}_{}'.format(save_path, str(n_updates))
                logger.log("save model:", path)
                agent.save(path)
                agent.save_help_info(save_path, n_updates)
            if agent.I.stats['tcount'] > num_timesteps:
                path = '{}_{}'.format(save_path, str(n_updates))
                logger.log("save model:", path)
                agent.save(path)
                agent.save_help_info(save_path, n_updates)
                break
        agent.stop_interaction()
    else:
        '''
        check_point_rews_list_path = '{}_rewslist'.format(load_path)
        check_point_rnd_path = '{}_rnd'.format(load_path)

        oracle_rnd = oracle.OracleExplorationRewardForAllEpisodes()
        oracle_rnd.load(check_point_rnd_path)
        # print(oracle_rnd._collected_positions_writer)
        # print(oracle_rnd._collected_positions_reader)

        rews_list = load_rews_list(check_point_rews_list_path)
        print(rews_list)
        '''
        istate = agent.stochpol.initial_state(1)
        # ph_mean, ph_std = agent.stochpol.get_ph_mean_std()
        last_obs, prevrews, ec_rews, news, infos, ram_states, _ = agent.env_get(0)
        agent.I.step_count += 1

        flag = False
        show_cam = True
        last_xr = 0

        restore = None
        '''
        # path = 'ram_state_500_7room'
        # path = 'ram_state_400_6room'
        # path = 'ram_state_6700'
        path = 'ram_state_7700_10room'
        f = open(path, 'rb')
        restore = pickle.load(f)
        f.close()
        last_obs[0] = agent.I.venvs[0].restore_full_state_by_idx(restore, 0)
        print(last_obs.shape)

        # path = 'ram_state_400_monitor_rews_6room'
        # path = 'ram_state_500_monitor_rews_7room'
        # path = 'ram_state_6700_monitor_rews'
        path = 'ram_state_7700_monitor_rews_10room'
        f = open(path, 'rb')
        monitor_rews = pickle.load(f)
        f.close()
        agent.I.venvs[0].set_cur_monitor_rewards_by_idx(monitor_rews, 0)
        '''
        agent_idx = np.asarray([0])
        sample_agent_prob = np.asarray([0.5])

        ph_mean = agent.stochpol.ob_rms_list[0].mean
        ph_std = agent.stochpol.ob_rms_list[0].var ** 0.5

        buf_ph_mean = np.zeros(([1, 1] + list(agent.stochpol.ob_space.shape[:2]) + [1]), np.float32)
        buf_ph_std = np.zeros(([1, 1] + list(agent.stochpol.ob_space.shape[:2]) + [1]), np.float32)
        buf_ph_mean[0, 0] = ph_mean
        buf_ph_std[0, 0] = ph_std

        vpreds_ext_list = []
        ep_rews = np.zeros((1))

        divexp_flag = False
        step_count = 0
        stage_prob = True

        last_rew_ob = np.full_like(last_obs, 128)

        clusters = Clusters(1.0)

        # path = '{}_sd_rms'.format(load_path)
        # agent.I.sd_rms.load(path)

        while True:
            dict_obs = agent.stochpol.ensure_observation_is_dict(last_obs)
            # acs = np.random.randint(low=0, high=15, size=(1))
            acs, vpreds_int, vpreds_ext, nlps, istate, ent = agent.stochpol.call(
                dict_obs, news, istate, agent_idx[:, None])
            step_acs = acs

            t = ''
            # if show_cam == True:
            t = input("input:")
            if t != '':
                t = int(t)
                if t <= 17:
                    step_acs = [t]

            agent.env_step(0, step_acs)
            obs, prevrews, ec_rews, news, infos, ram_states, monitor_rews = agent.env_get(0)

            if news[0] and restore is not None:
                obs[0] = agent.I.venvs[0].restore_full_state_by_idx(restore, 0)
                agent.I.venvs[0].set_cur_monitor_rewards_by_idx(monitor_rews, 0)

            ep_rews = ep_rews + prevrews
            print(ep_rews)

            last_rew_ob[prevrews > 0] = obs[prevrews > 0]

            room = infos[0]['position'][2]
            vpreds_ext_list.append([vpreds_ext, room])

            # print(monitor_rews[0])
            # print(len(monitor_rews[0]))
            # print(infos[0]['open_door_type'])

            stack_obs = np.concatenate([last_obs[:, None], obs[:, None]], 1)
            fd = {}
            fd[agent.stochpol.ph_ob[None]] = stack_obs
            fd.update({agent.stochpol.sep_ph_mean: buf_ph_mean,
                       agent.stochpol.sep_ph_std: buf_ph_std})
            fd[agent.stochpol.ph_agent_idx] = agent_idx[:, None]
            fd[agent.stochpol.sample_agent_prob] = sample_agent_prob[:, None]
            fd[agent.stochpol.last_rew_ob] = last_rew_ob[:, None]
            fd[agent.stochpol.game_score] = ep_rews[:, None]
            fd[agent.stochpol.sd_ph_mean] = agent.I.sd_rms.mean
            fd[agent.stochpol.sd_ph_std] = agent.I.sd_rms.var ** 0.5

            div_prob = 0
            all_div_prob = tf_util.get_session().run([agent.stochpol.all_div_prob], fd)
            '''
            if prevrews[0] > 0:
                clusters.update(rnd_em, room)
                num_clusters = len(clusters._cluster_list)
                for i in range(num_clusters):
                    print("{} {}".format(str(i), list(clusters._room_set[i])))
            '''
            print("vpreds_int: ", vpreds_int, "vpreds_ext:", vpreds_ext, "ent:", ent,
                  "all_div_prob:", all_div_prob, "room:", room, "step_count:", step_count)

            # aaaa = np.asarray(vpreds_ext_list)
            # print(aaaa[-100:])
            '''
            if step_acs[0] == 0:
                ram_state = ram_states[0]
                path = 'ram_state_7700_10room'
                f = open(path, 'wb')
                pickle.dump(ram_state, f)
                f.close()
                path = 'ram_state_7700_monitor_rews_10room'
                f = open(path, 'wb')
                pickle.dump(monitor_rews[0], f)
                f.close()
            '''
            '''
            if restore is None:
                restore = ram_states[0]
            if np.random.rand() < 0.1:
                print("restore")
                obs = agent.I.venvs[0].restore_full_state_by_idx(restore, 0)
                prevrews = None
                ec_rews = None
                news = True
                infos = {}
                ram_states = ram_states[0]
            # restore = ram_states[0]
            '''
            img = agent.I.venvs[0].render()

            last_obs = obs
            step_count = step_count + 1
            time.sleep(0.04)
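# A standalone sketch of the manual-control hook in the play loop above (names are new,
# illustrative only): a typed integer overrides the policy's action for that step, and the
# 0..17 bound corresponds to the full ALE action set of 18 actions.
def parse_manual_action(text, policy_action, num_actions=18):
    # Empty input keeps the policy's action; a small integer overrides it.
    if text.strip() == '':
        return policy_action
    t = int(text)
    return [t] if t < num_actions else policy_action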
def load_test(*, env_id, num_env, hps, num_timesteps, seed, fname):
    venv = VecFrameStack(
        make_atari_env(env_id, num_env, seed, wrapper_kwargs=dict(),
                       start_index=num_env * MPI.COMM_WORLD.Get_rank(),
                       max_episode_steps=hps.pop('max_episode_steps')),
        hps.pop('frame_stack'))
    # venv.score_multiple = {'Mario': 500,
    #                        'MontezumaRevengeNoFrameskip-v4': 100,
    #                        'GravitarNoFrameskip-v4': 250,
    #                        'PrivateEyeNoFrameskip-v4': 500,
    #                        'SolarisNoFrameskip-v4': None,
    #                        'VentureNoFrameskip-v4': 200,
    #                        'PitfallNoFrameskip-v4': 100,
    #                        }[env_id]
    venv.score_multiple = 1
    venv.record_obs = True
    ob_space = venv.observation_space
    ac_space = venv.action_space
    gamma = hps.pop('gamma')
    policy = {'rnn': CnnGruPolicy, 'cnn': CnnPolicy}[hps.pop('policy')]
    agent = PpoAgent(
        scope='ppo',
        ob_space=ob_space,
        ac_space=ac_space,
        stochpol_fn=functools.partial(
            policy,
            scope='pol',
            ob_space=ob_space,
            ac_space=ac_space,
            update_ob_stats_independently_per_gpu=hps.pop('update_ob_stats_independently_per_gpu'),
            proportion_of_exp_used_for_predictor_update=hps.pop('proportion_of_exp_used_for_predictor_update'),
            dynamics_bonus=hps.pop("dynamics_bonus")),
        gamma=gamma,
        gamma_ext=hps.pop('gamma_ext'),
        lam=hps.pop('lam'),
        nepochs=hps.pop('nepochs'),
        nminibatches=hps.pop('nminibatches'),
        lr=hps.pop('lr'),
        cliprange=0.1,
        nsteps=128,
        ent_coef=0.001,
        max_grad_norm=hps.pop('max_grad_norm'),
        use_news=hps.pop("use_news"),
        comm=MPI.COMM_WORLD if MPI.COMM_WORLD.Get_size() > 1 else None,
        update_ob_stats_every_step=hps.pop('update_ob_stats_every_step'),
        int_coeff=hps.pop('int_coeff'),
        ext_coeff=hps.pop('ext_coeff'),
        obs_save_flag=True)
    tf_util.load_state("saved_states/save1")
    agent.start_interaction([venv])
    counter = 0
    while True:
        info = agent.step()
        if agent.I.stats['epcount'] > 1:
            with open("obs_acs.pickle", 'wb') as f1:
                pickle.dump(agent.obs_rec, f1)
            break
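# A small companion sketch for load_test (not from the original source): reading back the
# rollout record that was pickled above. The structure of `agent.obs_rec` is not
# documented here, so this only assumes the file unpickles to some Python object.
import pickle

def load_recorded_rollout(path="obs_acs.pickle"):
    # Load whatever observation/action record load_test() dumped.
    with open(path, "rb") as f:
        return pickle.load(f)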
def train(*, env_id, num_env, hps, num_timesteps, seed, test=False, e_greedy=0, **kwargs):
    # import pdb; pdb.set_trace()
    print('_________________________________________________________')
    pprint.pprint(f'hyperparams: {hps}', width=1)
    pprint.pprint(f'additional hyperparams: {kwargs}', width=1)
    print('_________________________________________________________')

    # TODO: this is just for debugging
    # tmp_env = make_atari_env(env_id, num_env, seed, wrapper_kwargs=dict(clip_rewards=kwargs['clip_rewards']))
    venv = VecFrameStack(
        make_atari_env(
            env_id, num_env, seed,
            wrapper_kwargs=dict(clip_rewards=kwargs['clip_rewards']),
            start_index=num_env * MPI.COMM_WORLD.Get_rank(),
            max_episode_steps=hps.pop('max_episode_steps'),
            action_space=kwargs['action_space']),
        hps.pop('frame_stack'))
    # venv.score_multiple = {'Mario': 500,
    #                        'MontezumaRevengeNoFrameskip-v4': 100,
    #                        'GravitarNoFrameskip-v4': 250,
    #                        'PrivateEyeNoFrameskip-v4': 500,
    #                        'SolarisNoFrameskip-v4': None,
    #                        'VentureNoFrameskip-v4': 200,
    #                        'PitfallNoFrameskip-v4': 100,
    #                        }[env_id]
    venv.score_multiple = 1  # TODO: understand what score_multiple is
    venv.record_obs = True if env_id == 'SolarisNoFrameskip-v4' else False
    ob_space = venv.observation_space
    ac_space = venv.action_space
    gamma = hps.pop('gamma')
    policy = {'rnn': CnnGruPolicy, 'cnn': CnnPolicy}[hps.pop('policy')]
    agent = PpoAgent(
        scope='ppo',
        ob_space=ob_space,
        ac_space=ac_space,
        stochpol_fn=functools.partial(
            policy,
            scope='pol',
            ob_space=ob_space,
            ac_space=ac_space,
            update_ob_stats_independently_per_gpu=hps.pop('update_ob_stats_independently_per_gpu'),
            proportion_of_exp_used_for_predictor_update=hps.pop('proportion_of_exp_used_for_predictor_update'),
            dynamics_bonus=hps.pop("dynamics_bonus")),
        gamma=gamma,
        gamma_ext=hps.pop('gamma_ext'),
        lam=hps.pop('lam'),
        nepochs=hps.pop('nepochs'),
        nminibatches=hps.pop('nminibatches'),
        lr=hps.pop('lr'),
        cliprange=0.1,
        nsteps=128,
        ent_coef=0.001,
        max_grad_norm=hps.pop('max_grad_norm'),
        use_news=hps.pop("use_news"),
        comm=MPI.COMM_WORLD if MPI.COMM_WORLD.Get_size() > 1 else None,
        update_ob_stats_every_step=hps.pop('update_ob_stats_every_step'),
        int_coeff=hps.pop('int_coeff'),
        ext_coeff=hps.pop('ext_coeff'),
    )
    saver = tf.train.Saver()
    if test:
        agent.restore_model(saver, kwargs['load_dir'], kwargs['exp_name'], mtype=kwargs['load_mtype'])
    agent.start_interaction([venv])

    import time
    st = time.time()
    if hps.pop('update_ob_stats_from_random_agent'):
        if os.path.exists('./data/ob_rms.pkl'):
            with open('./data/ob_rms.pkl', 'rb') as handle:
                agent.stochpol.ob_rms.mean, agent.stochpol.ob_rms.var, agent.stochpol.ob_rms.count = pickle.load(handle)
        else:
            ob_rms = agent.collect_random_statistics(num_timesteps=128 * 50)  # original number 128*50
            with open('./data/ob_rms.pkl', 'wb') as handle:
                pickle.dump([ob_rms.mean, ob_rms.var, ob_rms.count], handle, protocol=2)
    assert len(hps) == 0, "Unused hyperparameters: %s" % list(hps.keys())
    print(f'Time duration {time.time() - st}')

    if not test:
        counter = 1
        while True:
            info = agent.step()
            if info['update']:
                # import pdb; pdb.set_trace()
                logger.logkvs(info['update'])
                logger.dumpkvs()
            if agent.I.stats['tcount'] // 1e7 == counter:
                agent.save_model(saver, kwargs['save_dir'], kwargs['exp_name'], mtype=str(counter * 1e7))
                counter += 1
            if agent.I.stats['tcount'] > num_timesteps:
                break

    if test:
        for pkls in range(0, 10):
            print('collecting the', pkls, 'th pickle', ' ' * 20)
            all_rollout = agent.evaluate(env_id=env_id, seed=seed, episodes=1000,
                                         save_image=kwargs['save_image'], e_greedy=e_greedy)
            with open('./data/newworld21_' + str(pkls).zfill(1) + '.pkl', 'wb') as handle:
                pickle.dump(all_rollout, handle)

    agent.stop_interaction()
    if not test:
        agent.save_model(saver, kwargs['save_dir'], kwargs['exp_name'])
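# A hedged call sketch for the train() variant above, listing the kwargs it actually
# dereferences; the paths, experiment name, and action-space value are placeholders, not
# the original settings.
example_kwargs = dict(
    clip_rewards=True,         # forwarded to the Atari wrappers
    action_space=None,         # forwarded to make_atari_env
    save_dir='./checkpoints',  # used by agent.save_model
    load_dir='./checkpoints',  # used by agent.restore_model when test=True
    exp_name='rnd_run0',
    load_mtype='latest',
    save_image=False,          # forwarded to agent.evaluate in test mode
)
# train(env_id='MontezumaRevengeNoFrameskip-v4', num_env=32, hps=dict(...),
#       num_timesteps=int(1e8), seed=0, test=False, **example_kwargs)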
def train(*, env_id, num_env, hps, num_timesteps, seed):
    experiment = os.environ.get('EXPERIMENT_LVL')
    if experiment == 'ego':
        # for the ego experiment we needed a higher intrinsic coefficient
        hps['int_coeff'] = 3.0

    hyperparams = copy(hps)
    hyperparams.update({'seed': seed})
    logger.info("Hyperparameters:")
    logger.info(hyperparams)

    venv = VecFrameStack(
        make_atari_env(env_id, num_env, seed, wrapper_kwargs={},
                       start_index=num_env * MPI.COMM_WORLD.Get_rank(),
                       max_episode_steps=hps.pop('max_episode_steps')),
        hps.pop('frame_stack'))
    venv.score_multiple = {
        'Mario': 500,
        'MontezumaRevengeNoFrameskip-v4': 1,
        'GravitarNoFrameskip-v4': 250,
        'PrivateEyeNoFrameskip-v4': 500,
        'SolarisNoFrameskip-v4': None,
        'VentureNoFrameskip-v4': 200,
        'PitfallNoFrameskip-v4': 100,
    }[env_id]
    venv.record_obs = True if env_id == 'SolarisNoFrameskip-v4' else False
    ob_space = venv.observation_space
    ac_space = venv.action_space
    gamma = hps.pop('gamma')
    policy = {'rnn': CnnGruPolicy, 'cnn': CnnPolicy}[hps.pop('policy')]
    agent = PpoAgent(
        scope='ppo',
        ob_space=ob_space,
        ac_space=ac_space,
        stochpol_fn=functools.partial(
            policy,
            scope='pol',
            ob_space=ob_space,
            ac_space=ac_space,
            update_ob_stats_independently_per_gpu=hps.pop('update_ob_stats_independently_per_gpu'),
            proportion_of_exp_used_for_predictor_update=hps.pop('proportion_of_exp_used_for_predictor_update'),
            dynamics_bonus=hps.pop("dynamics_bonus")),
        gamma=gamma,
        gamma_ext=hps.pop('gamma_ext'),
        lam=hps.pop('lam'),
        nepochs=hps.pop('nepochs'),
        nminibatches=hps.pop('nminibatches'),
        lr=hps.pop('lr'),
        cliprange=0.1,
        nsteps=128,
        ent_coef=0.001,
        max_grad_norm=hps.pop('max_grad_norm'),
        use_news=hps.pop("use_news"),
        comm=MPI.COMM_WORLD if MPI.COMM_WORLD.Get_size() > 1 else None,
        update_ob_stats_every_step=hps.pop('update_ob_stats_every_step'),
        int_coeff=hps.pop('int_coeff'),
        ext_coeff=hps.pop('ext_coeff'),
        restore_model_path=hps.pop('restore_model_path'))
    agent.start_interaction([venv])
    if hps.pop('update_ob_stats_from_random_agent'):
        agent.collect_random_statistics(num_timesteps=128 * 50)
    save_model = hps.pop('save_model')
    assert len(hps) == 0, "Unused hyperparameters: %s" % list(hps.keys())

    # profiler = cProfile.Profile()
    # profiler.enable()
    # tracemalloc.start()
    # prev_snap = tracemalloc.take_snapshot()
    counter = 0
    while True:
        info = agent.step()
        if info['update']:
            logger.logkvs(info['update'])
            logger.dumpkvs()
            counter += 1
            # if (counter % 10) == 0:
            #     snapshot = tracemalloc.take_snapshot()
            #     top_stats = snapshot.compare_to(prev_snap, 'lineno')
            #     for stat in top_stats[:10]:
            #         print(stat)
            #     prev_snap = snapshot
            # profiler.dump_stats("profile_rnd")
            if (counter % 100) == 0 and save_model:
                agent.save_model(agent.I.step_count)
        if agent.I.stats['tcount'] > num_timesteps:
            break
    agent.stop_interaction()
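# The EXPERIMENT_LVL switch above can also be exercised programmatically; a minimal
# sketch (the 'ego' value is the one the code checks for, everything else is unchanged):
import os

os.environ['EXPERIMENT_LVL'] = 'ego'   # train() will then raise int_coeff to 3.0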