def log_reward_statistics(vec_env, num_last_eps=100, prefix=""):
    all_stats = None
    # Retry for up to 10 seconds in case the monitor file has not been written yet.
    for _ in range(10):
        try:
            all_stats = load_results(osp.dirname(vec_env.results_writer.f.name))
            break
        except FileNotFoundError:
            time.sleep(1)
            continue
    if all_stats is not None:
        episode_rewards = all_stats["r"]
        episode_lengths = all_stats["l"]
        recent_episode_rewards = episode_rewards[-num_last_eps:]
        recent_episode_lengths = episode_lengths[-num_last_eps:]
        if len(recent_episode_rewards) > 0:
            kvs = {
                prefix + "AverageReturn": np.mean(recent_episode_rewards),
                prefix + "MinReturn": np.min(recent_episode_rewards),
                prefix + "MaxReturn": np.max(recent_episode_rewards),
                prefix + "StdReturn": np.std(recent_episode_rewards),
                prefix + "AverageEpisodeLength": np.mean(recent_episode_lengths),
                prefix + "MinEpisodeLength": np.min(recent_episode_lengths),
                prefix + "MaxEpisodeLength": np.max(recent_episode_lengths),
                prefix + "StdEpisodeLength": np.std(recent_episode_lengths),
            }
            logger.logkvs(kvs)
        logger.logkv(prefix + "TotalNEpisodes", len(episode_rewards))
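# Usage sketch (an assumption, not from the original source): this helper expects
# `vec_env` to be wrapped by baselines' VecMonitor (or an equivalent monitor exposing
# `results_writer`), so that `load_results` from baselines.bench can read the monitor
# CSV written next to `vec_env.results_writer.f`. Something along these lines:
#
#     venv = VecMonitor(venv=venv, filename=os.path.join(log_dir, "monitor"))
#     ...train for a while...
#     log_reward_statistics(venv, num_last_eps=100, prefix="Eval")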
def train(self, saver, logger_dir):
    # Build the computation graph and initialize the rollout class.
    self.agent.start_interaction(self.envs, nlump=self.hps['nlumps'], dynamics=self.dynamics)
    previous_saved_tcount = 0
    while True:
        # Interact with the environment for one rollout, collect samples,
        # compute the intrinsic reward, and train.
        info = self.agent.step()
        if info['update']:
            logger.logkvs(info['update'])
            logger.dumpkvs()
        if self.hps["save_period"] and (int(self.agent.rollout.stats['tcount'] / self.hps["save_freq"]) > previous_saved_tcount):
            previous_saved_tcount += 1
            save_path = saver.save(
                tf.get_default_session(),
                os.path.join(logger_dir, "model_" + str(previous_saved_tcount) + ".ckpt"))
            print("Periodically saved model in path:", save_path)
        if self.agent.rollout.stats['tcount'] > self.num_timesteps:
            save_path = saver.save(
                tf.get_default_session(),
                os.path.join(logger_dir, "model_last.ckpt"))
            print("Model saved in path:", save_path)
            break
    self.agent.stop_interaction()
def train(self):
    if self.save_checkpoint:
        params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        saver = tf.train.Saver(var_list=params,
                               max_to_keep=self.num_timesteps // 1000000 + 1)
        # Checkpoint boundaries: one every 1M timesteps.
        periods = list(range(0, self.num_timesteps + 1, 1000000))
        idx = 0
    self.agent.start_interaction(self.envs, nlump=self.hps['nlumps'], dynamics=self.dynamics)
    while True:
        info = self.agent.step()
        if info['update']:
            logger.logkvs(info['update'])
            logger.dumpkvs()
        if self.save_checkpoint:
            if self.agent.rollout.stats['tcount'] >= periods[idx]:
                self.save(saver, logger.get_dir() + '/checkpoint/', periods[idx])
                idx += 1
        if self.agent.rollout.stats['tcount'] > self.num_timesteps:
            break
    self.agent.stop_interaction()
def train(self, saver, sess, restore=False):
    from baselines import logger
    self.agent.start_interaction(self.envs, nlump=self.hps['nlumps'], dynamics=self.dynamics)
    if restore:
        print("Restoring model for training")
        saver.restore(sess, "models/" + self.hps['restore_model'] + ".ckpt")
        print("Loaded model", self.hps['restore_model'])
    write_meta_graph = False
    while True:
        info = self.agent.step()
        if info['update']:
            if info['update']['recent_best_ext_ret'] is None:
                info['update']['recent_best_ext_ret'] = 0
            wandb.log(info['update'])
            logger.logkvs(info['update'])
            logger.dumpkvs()
        if self.agent.rollout.stats['tcount'] > self.num_timesteps:
            break
    if self.hps['tune_env']:
        filename = "models/" + self.hps['restore_model'] + "_tune_on_" + self.hps['tune_env'] + "_final.ckpt"
    else:
        filename = "models/" + self.hps['exp_name'] + "_final.ckpt"
    saver.save(sess, filename, write_meta_graph=False)
    self.policy.save_model(self.hps['exp_name'], 'final')
    self.agent.stop_interaction()
def train(self):
    import random
    self.agent.start_interaction(self.envs, nlump=self.hps["nlumps"], dynamics=self.dynamics)
    count = 0
    while True:
        count += 1
        info = self.agent.step()
        if info["update"]:
            logger.logkvs(info["update"])
            logger.dumpkvs()
            if self.hps["feat_learning"] == "pix2pix":
                # Record a video on roughly one out of every 100 updates.
                making_video = random.choice(99 * [False] + [True])
            else:
                making_video = False
            self.agent.rollout.making_video = making_video
            for a_key in info.keys():
                wandb.log(info[a_key])
            # going to have to log it here
            wandb.log({"average_sigma": np.mean(self.agent.rollout.buf_sigmas)})
        if self.agent.rollout.stats["tcount"] > self.num_timesteps:
            break
    self.agent.stop_interaction()
def train(self, saver, sess, restore=False):
    self.agent.start_interaction(self.envs, nlump=self.hps['nlumps'], dynamics=self.dynamics)
    write_meta_graph = False
    saves = 0
    loops = 0
    while True:
        info = self.agent.step(eval=False)
        if info is not None:
            if info['update'] and not restore:
                logger.logkvs(info['update'])
                logger.dumpkvs()
        steps = self.agent.rollout.stats['tcount']
        if loops % 10 == 0:
            filename = args.saved_model_dir + 'model.ckpt'
            saver.save(sess, filename, global_step=int(saves), write_meta_graph=False)
            saves += 1
        loops += 1
        if steps > self.num_timesteps:
            break
    self.agent.stop_interaction()
def train(self):
    self.agent.start_interaction(self.envs, nlump=self.hps['nlumps'], dynamics=self.dynamics)
    save_path = 'models'
    tf_sess = tf.get_default_session()
    # Create a saver.
    saver = tf.train.Saver(save_relative_paths=True)
    # if self.hps['restore_latest_checkpoint']:
    #     # Restore the latest checkpoint if set in arguments.
    #     saver.restore(tf_sess, tf.train.latest_checkpoint(save_path))
    while True:
        info = self.agent.step()
        if info['update']:
            logger.logkvs(info['update'])
            logger.dumpkvs()
        if self.agent.rollout.stats['tcount'] > self.num_timesteps:
            break
        # Save the model every 1,000 updates, appending the step number to the checkpoint name.
        if info['n_updates'] % 1000 == 0:
            saver.save(tf_sess, save_path + '/obstacle_tower',
                       global_step=int(self.agent.rollout.stats['tcount']))
    # Save a final checkpoint, appending the step number to the last checkpoint name.
    saver.save(tf_sess, save_path + '/obstacle_tower',
               global_step=int(self.agent.rollout.stats['tcount']))
    self.agent.stop_interaction()
def train(self):
    self.agent.start_interaction(self.envs, nlump=self.hps['nlumps'], dynamics=self.action_dynamics)
    expdir = osp.join("/result", self.hps['env'], self.hps['exp_name'])
    save_checkpoints = []
    if self.hps['save_interval'] is not None:
        save_checkpoints = [i * self.hps['save_interval']
                            for i in range(1, self.hps['num_timesteps'] // self.hps['save_interval'])]
    if self.hps['load_dir'] is not None:
        self.train_feature_extractor.load(self.hps['load_dir'])
        self.train_dynamics.load(self.hps['load_dir'])
    while True:
        info = self.agent.step()
        if info['update']:
            logger.logkvs(info['update'])
            logger.dumpkvs()
        if len(save_checkpoints) > 0:
            if self.agent.rollout.stats['tcount'] > save_checkpoints[0]:
                self.train_feature_extractor.save(expdir, self.agent.rollout.stats['tcount'])
                self.train_dynamics.save(expdir, self.agent.rollout.stats['tcount'])
                save_checkpoints.remove(save_checkpoints[0])
        if self.agent.rollout.stats['tcount'] > self.num_timesteps:
            break
    if self.hps['save_dynamics'] and MPI.COMM_WORLD.Get_rank() == 0:
        # Save the auxiliary-task and dynamics parameters.
        self.train_feature_extractor.save(expdir)
        self.train_dynamics.save(expdir)
    self.agent.stop_interaction()
def train(self):
    self.agent.start_interaction(self.envs, nlump=self.hps['nlumps'], dynamics=self.dynamics)
    while True:
        info = self.agent.step()
        if info['update']:
            logger.logkvs(info['update'])
            logger.dumpkvs()
        if self.agent.rollout.stats['tcount'] > self.num_timesteps:
            break
    self.agent.stop_interaction()
def train(self):
    self.agent.start_interaction(self.envs, nlump=self.hps['nlumps'], intrinsic_model=self.intrinsic_model)
    sess = getsess()
    while True:
        info = self.agent.step()
        if info['update']:
            logger.logkvs(info['update'])
            logger.dumpkvs()
        if self.agent.rollout.stats['tcount'] > self.num_timesteps:
            break
    self.agent.stop_interaction()
def train(self):
    self.agent.start_interaction(self.envs, nlump=self.hps['nlumps'], dynamics=self.dynamics)
    while True:
        info = self.agent.step()
        if info['update']:
            logger.logkvs(info['update'])
            logger.dumpkvs()
        if self.agent.rollout.stats['tcount'] == 0:
            fname = os.path.join(self.hps['save_dir'], 'checkpoints')
            if os.path.exists(fname + '.index'):
                load_state(fname)
                print('loaded checkpoint successfully')
            else:
                print('failed to load checkpoint')
        # Note: num_timesteps / num_timesteps == 1, so a checkpoint is saved on every iteration.
        if self.agent.rollout.stats['tcount'] % int(self.num_timesteps / self.num_timesteps) == 0:
            fname = os.path.join(self.hps['save_dir'], 'checkpoints')
            save_state(fname)
        if self.agent.rollout.stats['tcount'] > self.num_timesteps:
            break
        # print(self.agent.rollout.stats['tcount'])
    self.agent.stop_interaction()
def make_save_dir_and_log_basics(argdict):
    if not gflag.save_dir:
        assert not gflag.resumable, "You cannot set --resumable without setting --save-dir."
    else:
        assert gflag.resumable or (not os.path.exists(gflag.save_dir)), \
            "--save_dir '%s' already exists and resumable is False. " % (gflag.save_dir) + \
            "This might be because condor killed and rescheduled the original task. " + \
            "To prevent log.txt being overwritten or appended to by a possibly different model's log, " + \
            "the program will terminate now."
        os.makedirs(gflag.save_dir, exist_ok=True)
        logger.configure(gflag.save_dir, format_strs=['log', 'stdout'])
        logger.logkvs(argdict)
        logger.dumpkvs()

        # Copy the relevant .py files into save_dir to snapshot the code being run.
        snapshot_dir = gflag.save_dir + "/all_py_files_snapshot/"
        py_files = subprocess.check_output("find baselines | grep '\\.py$'", shell=True).decode('utf-8').split()
        py_files += subprocess.check_output("ls *.py", shell=True).decode('utf-8').split()
        for py_file in py_files:
            os.makedirs(snapshot_dir + os.path.dirname(py_file), exist_ok=True)
            shutil.copyfile(py_file, snapshot_dir + py_file)
def test(self, saver, sess):
    self.agent.start_interaction(self.envs, nlump=self.hps['nlumps'], dynamics=self.dynamics)
    print('loading model')
    saver.restore(sess, args.saved_model_dir + args.model_name)
    print('loaded model,', args.saved_model_dir + args.model_name)
    # Note: `eval` here is the Python built-in (always truthy), so this reduces to args.include_images.
    include_images = args.include_images and eval
    info = self.agent.step(eval=True, include_images=include_images)
    if info['update']:
        logger.logkvs(info['update'])
        logger.dumpkvs()
    # Save actions, news, and/or images.
    np.save(args.env + '_data.npy', info)
    print('EVALUATION COMPLETED')
    print('SAVED DATA IN CURRENT DIRECTORY')
    print('FILENAME', args.env + '_data.npy')
    self.agent.stop_interaction()
def train(*, env_id, num_env, hps, num_timesteps, seed):
    venv = VecFrameStack(
        make_atari_env(env_id, num_env, seed, wrapper_kwargs=dict(),
                       start_index=num_env * MPI.COMM_WORLD.Get_rank(),
                       max_episode_steps=hps.pop('max_episode_steps')),
        hps.pop('frame_stack'))
    # venv.score_multiple = {'Mario': 500,
    #                        'MontezumaRevengeNoFrameskip-v4': 100,
    #                        'GravitarNoFrameskip-v4': 250,
    #                        'PrivateEyeNoFrameskip-v4': 500,
    #                        'SolarisNoFrameskip-v4': None,
    #                        'VentureNoFrameskip-v4': 200,
    #                        'PitfallNoFrameskip-v4': 100,
    #                        }[env_id]
    venv.score_multiple = 1
    venv.record_obs = True
    # venv.record_obs = True if env_id == 'SolarisNoFrameskip-v4' else False
    ob_space = venv.observation_space
    ac_space = venv.action_space
    gamma = hps.pop('gamma')
    policy = {'rnn': CnnGruPolicy, 'cnn': CnnPolicy}[hps.pop('policy')]
    agent = PpoAgent(
        scope='ppo',
        ob_space=ob_space,
        ac_space=ac_space,
        stochpol_fn=functools.partial(
            policy,
            scope='pol',
            ob_space=ob_space,
            ac_space=ac_space,
            update_ob_stats_independently_per_gpu=hps.pop('update_ob_stats_independently_per_gpu'),
            proportion_of_exp_used_for_predictor_update=hps.pop('proportion_of_exp_used_for_predictor_update'),
            dynamics_bonus=hps.pop("dynamics_bonus"),
        ),
        gamma=gamma,
        gamma_ext=hps.pop('gamma_ext'),
        lam=hps.pop('lam'),
        nepochs=hps.pop('nepochs'),
        nminibatches=hps.pop('nminibatches'),
        lr=hps.pop('lr'),
        cliprange=0.1,
        nsteps=128,
        ent_coef=0.001,
        max_grad_norm=hps.pop('max_grad_norm'),
        use_news=hps.pop("use_news"),
        comm=MPI.COMM_WORLD if MPI.COMM_WORLD.Get_size() > 1 else None,
        update_ob_stats_every_step=hps.pop('update_ob_stats_every_step'),
        int_coeff=hps.pop('int_coeff'),
        ext_coeff=hps.pop('ext_coeff'),
    )
    agent.start_interaction([venv])
    if hps.pop('update_ob_stats_from_random_agent'):
        agent.collect_random_statistics(num_timesteps=128 * 50)
    assert len(hps) == 0, "Unused hyperparameters: %s" % list(hps.keys())

    counter = 0
    while True:
        info = agent.step()
        if info['update']:
            logger.logkvs(info['update'])
            logger.dumpkvs()
            counter += 1
        if agent.I.stats['tcount'] > num_timesteps:
            break
    agent.stop_interaction()
    return agent
def train(*, env_id, num_env, hps, num_timesteps, seed):
    experiment = os.environ.get('EXPERIMENT_LVL')
    if experiment == 'ego':
        # The ego experiment needs a higher intrinsic coefficient.
        hps['int_coeff'] = 3.0

    hyperparams = copy(hps)
    hyperparams.update({'seed': seed})
    logger.info("Hyperparameters:")
    logger.info(hyperparams)

    venv = VecFrameStack(
        make_atari_env(env_id, num_env, seed, wrapper_kwargs={},
                       start_index=num_env * MPI.COMM_WORLD.Get_rank(),
                       max_episode_steps=hps.pop('max_episode_steps')),
        hps.pop('frame_stack'))
    venv.score_multiple = {
        'Mario': 500,
        'MontezumaRevengeNoFrameskip-v4': 1,
        'GravitarNoFrameskip-v4': 250,
        'PrivateEyeNoFrameskip-v4': 500,
        'SolarisNoFrameskip-v4': None,
        'VentureNoFrameskip-v4': 200,
        'PitfallNoFrameskip-v4': 100,
    }[env_id]
    venv.record_obs = True if env_id == 'SolarisNoFrameskip-v4' else False
    ob_space = venv.observation_space
    ac_space = venv.action_space
    gamma = hps.pop('gamma')
    policy = {'rnn': CnnGruPolicy, 'cnn': CnnPolicy}[hps.pop('policy')]
    agent = PpoAgent(
        scope='ppo',
        ob_space=ob_space,
        ac_space=ac_space,
        stochpol_fn=functools.partial(
            policy,
            scope='pol',
            ob_space=ob_space,
            ac_space=ac_space,
            update_ob_stats_independently_per_gpu=hps.pop('update_ob_stats_independently_per_gpu'),
            proportion_of_exp_used_for_predictor_update=hps.pop('proportion_of_exp_used_for_predictor_update'),
            dynamics_bonus=hps.pop("dynamics_bonus")),
        gamma=gamma,
        gamma_ext=hps.pop('gamma_ext'),
        lam=hps.pop('lam'),
        nepochs=hps.pop('nepochs'),
        nminibatches=hps.pop('nminibatches'),
        lr=hps.pop('lr'),
        cliprange=0.1,
        nsteps=128,
        ent_coef=0.001,
        max_grad_norm=hps.pop('max_grad_norm'),
        use_news=hps.pop("use_news"),
        comm=MPI.COMM_WORLD if MPI.COMM_WORLD.Get_size() > 1 else None,
        update_ob_stats_every_step=hps.pop('update_ob_stats_every_step'),
        int_coeff=hps.pop('int_coeff'),
        ext_coeff=hps.pop('ext_coeff'),
        restore_model_path=hps.pop('restore_model_path'))
    agent.start_interaction([venv])
    if hps.pop('update_ob_stats_from_random_agent'):
        agent.collect_random_statistics(num_timesteps=128 * 50)
    save_model = hps.pop('save_model')
    assert len(hps) == 0, "Unused hyperparameters: %s" % list(hps.keys())

    # profiler = cProfile.Profile()
    # profiler.enable()
    # tracemalloc.start()
    # prev_snap = tracemalloc.take_snapshot()
    counter = 0
    while True:
        info = agent.step()
        if info['update']:
            logger.logkvs(info['update'])
            logger.dumpkvs()
            counter += 1
        # if (counter % 10) == 0:
        #     snapshot = tracemalloc.take_snapshot()
        #     top_stats = snapshot.compare_to(prev_snap, 'lineno')
        #     for stat in top_stats[:10]:
        #         print(stat)
        #     prev_snap = snapshot
        #     profiler.dump_stats("profile_rnd")
        if (counter % 100) == 0 and save_model:
            agent.save_model(agent.I.step_count)
        if agent.I.stats['tcount'] > num_timesteps:
            break
    agent.stop_interaction()
def train(self):
    curr_iter = 0

    # Logger for training-progress results.
    format_strs = ['csv']
    format_strs = filter(None, format_strs)
    dirc = os.path.join(self.args['log_dir'], 'inter')
    if self.restore_iter > -1:
        dirc = os.path.join(self.args['log_dir'], 'inter-{}'.format(self.restore_iter))
    output_formats = [logger.make_output_format(f, dirc) for f in format_strs]
    self.result_logger = logger.Logger(dir=dirc, output_formats=output_formats)

    # In case we are restoring the training.
    if self.restore_iter > -1:
        self.agent.load(self.load_path)
        if not self.args['transfer_load']:
            curr_iter = self.restore_iter

    print('max_iter: {}'.format(self.max_iter))

    # Interim saves to compare against in the future (e.g. for 128M frames).
    inter_save = []
    for i in range(3):
        divisor = (2 ** (i + 1))
        inter_save.append(
            int(self.args['num_timesteps'] // divisor) //
            (self.args['nsteps'] * self.args['NUM_ENVS'] * self.args['nframeskip']))
    print('inter_save: {}'.format(inter_save))

    total_time = 0.0
    # results_list = []
    while curr_iter < self.early_max_iter:
        frac = 1.0 - (float(curr_iter) / self.max_iter)
        # self.agent.update calls rollout.
        start_time = time.time()
        # Linearly anneal the learning rate and clip range.
        curr_lr = self.lr(frac)
        curr_cr = self.cliprange(frac)

        # Within-training evaluation was removed (flag_sum could not be made to work properly).
        # Instead, evaluate every `save_interval` iterations on 20 training levels
        # (only for Mario: first evaluate, then update) to get zero-shot generalization
        # numbers without extra effort.
        if curr_iter % (self.args['save_interval']) == 0:
            save_video = False
            nlevels = 20 if self.args['env_kind'] == 'mario' else self.args['NUM_LEVELS']
            results, _ = self.agent.evaluate(nlevels, save_video)
            results['iter'] = curr_iter
            for (k, v) in results.items():
                self.result_logger.logkv(k, v)
            self.result_logger.dumpkvs()

        # Representation learning happens every 25 steps.
        info = self.agent.update(lr=curr_lr, cliprange=curr_cr)
        end_time = time.time()

        # Additional info.
        info['frac'] = frac
        info['curr_lr'] = curr_lr
        info['curr_cr'] = curr_cr
        info['curr_iter'] = curr_iter
        # info['max_iter'] = self.max_iter
        info['elapsed_time'] = end_time - start_time
        # info['total_time'] = total_time = (total_time + info['elapsed_time']) / 3600.0
        info['expected_time'] = self.max_iter * info['elapsed_time'] / 3600.0

        # Log results using baselines' logger.
        logger.logkvs(info)
        logger.dumpkvs()

        if curr_iter % self.args['save_interval'] == 0:
            self.agent.save(curr_iter, cliprange=curr_cr)
        if curr_iter in inter_save:
            self.agent.save(curr_iter, cliprange=curr_cr)
        curr_iter += 1

    self.agent.save(curr_iter, cliprange=curr_cr)

    # Final evaluation (for Mario).
    save_video = False
    nlevels = 20 if self.args['env_kind'] == 'mario' else self.args['NUM_LEVELS']
    results, _ = self.agent.evaluate(nlevels, save_video)
    results['iter'] = curr_iter
    for (k, v) in results.items():
        self.result_logger.logkv(k, v)
    self.result_logger.dumpkvs()
def train(*, env_id, num_env, hps, num_timesteps, seed):
    venv = VecFrameStack(
        make_atari_env(env_id, num_env, seed, wrapper_kwargs=dict(),
                       start_index=num_env * MPI.COMM_WORLD.Get_rank(),
                       max_episode_steps=hps.pop('max_episode_steps')),
        hps.pop('frame_stack'))
    venv.score_multiple = 1
    venv.record_obs = False
    ob_space = venv.observation_space
    ac_space = venv.action_space
    gamma = hps.pop('gamma')
    policy = {'rnn': CnnGruPolicy, 'cnn': CnnPolicy}[hps.pop('policy')]
    agent = PpoAgent(
        scope='ppo',
        ob_space=ob_space,
        ac_space=ac_space,
        stochpol_fn=functools.partial(
            policy,
            scope='pol',
            ob_space=ob_space,
            ac_space=ac_space,
            update_ob_stats_independently_per_gpu=hps.pop('update_ob_stats_independently_per_gpu'),
            proportion_of_exp_used_for_predictor_update=hps.pop('proportion_of_exp_used_for_predictor_update'),
            exploration_type=hps.pop("exploration_type"),
            beta=hps.pop("beta"),
        ),
        gamma=gamma,
        gamma_ext=hps.pop('gamma_ext'),
        lam=hps.pop('lam'),
        nepochs=hps.pop('nepochs'),
        nminibatches=hps.pop('nminibatches'),
        lr=hps.pop('lr'),
        cliprange=0.1,
        nsteps=128,
        ent_coef=0.001,
        max_grad_norm=hps.pop('max_grad_norm'),
        use_news=hps.pop("use_news"),
        comm=MPI.COMM_WORLD if MPI.COMM_WORLD.Get_size() > 1 else None,
        update_ob_stats_every_step=hps.pop('update_ob_stats_every_step'),
        int_coeff=hps.pop('int_coeff'),
        ext_coeff=hps.pop('ext_coeff'),
        noise_type=hps.pop('noise_type'),
        noise_p=hps.pop('noise_p'),
        use_sched=hps.pop('use_sched'),
        num_env=num_env,
        exp_name=hps.pop('exp_name'),
    )
    agent.start_interaction([venv])
    if hps.pop('update_ob_stats_from_random_agent'):
        agent.collect_random_statistics(num_timesteps=128 * 50)
    assert len(hps) == 0, "Unused hyperparameters: %s" % list(hps.keys())

    counter = 0
    while True:
        info = agent.step()
        n_updates = 0
        if info['update']:
            logger.logkvs(info['update'])
            logger.dumpkvs()
            if NSML:
                n_updates = int(info['update']['n_updates'])
                nsml_dict = {k: np.float64(v) for k, v in info['update'].items()
                             if isinstance(v, Number)}
                nsml.report(step=n_updates, **nsml_dict)
            counter += 1
        # if n_updates >= 40 * 1000:  # 40K updates
        #     break
        if agent.I.stats['tcount'] > num_timesteps:
            break
    agent.stop_interaction()
def train(*, env_id, num_env, hps, num_timesteps, seed):
    if "NoFrameskip" in env_id:
        env_factory = make_atari_env
    else:
        env_factory = make_non_atari_env
    venv = VecFrameStack(
        env_factory(
            env_id,
            num_env,
            seed,
            wrapper_kwargs=dict(),
            start_index=num_env * MPI.COMM_WORLD.Get_rank(),
            max_episode_steps=hps.pop("max_episode_steps"),
        ),
        hps.pop("frame_stack"),
    )
    # venv.score_multiple = {'Mario': 500,
    #                        'MontezumaRevengeNoFrameskip-v4': 100,
    #                        'GravitarNoFrameskip-v4': 250,
    #                        'PrivateEyeNoFrameskip-v4': 500,
    #                        'SolarisNoFrameskip-v4': None,
    #                        'VentureNoFrameskip-v4': 200,
    #                        'PitfallNoFrameskip-v4': 100,
    #                        }[env_id]
    venv.score_multiple = 1
    venv.record_obs = True if env_id == "SolarisNoFrameskip-v4" else False
    ob_space = venv.observation_space
    ac_space = venv.action_space
    gamma = hps.pop("gamma")
    policy = {"rnn": CnnGruPolicy, "cnn": CnnPolicy, 'ffnn': GruPolicy}[hps.pop("policy")]
    agent = PpoAgent(
        scope="ppo",
        ob_space=ob_space,
        ac_space=ac_space,
        stochpol_fn=functools.partial(
            policy,
            scope="pol",
            ob_space=ob_space,
            ac_space=ac_space,
            update_ob_stats_independently_per_gpu=hps.pop("update_ob_stats_independently_per_gpu"),
            proportion_of_exp_used_for_predictor_update=hps.pop("proportion_of_exp_used_for_predictor_update"),
            dynamics_bonus=hps.pop("dynamics_bonus"),
            meta_rl=hps['meta_rl']),
        gamma=gamma,
        gamma_ext=hps.pop("gamma_ext"),
        lam=hps.pop("lam"),
        nepochs=hps.pop("nepochs"),
        nminibatches=hps.pop("nminibatches"),
        lr=hps.pop("lr"),
        cliprange=0.1,
        nsteps=128,
        ent_coef=0.001,
        max_grad_norm=hps.pop("max_grad_norm"),
        use_news=hps.pop("use_news"),
        comm=MPI.COMM_WORLD if MPI.COMM_WORLD.Get_size() > 1 else None,
        update_ob_stats_every_step=hps.pop("update_ob_stats_every_step"),
        int_coeff=hps.pop("int_coeff"),
        ext_coeff=hps.pop("ext_coeff"),
        meta_rl=hps.pop('meta_rl'))
    agent.start_interaction([venv])
    if hps.pop("update_ob_stats_from_random_agent"):
        agent.collect_random_statistics(num_timesteps=128 * 50)
    assert len(hps) == 0, f"Unused hyperparameters: {list(hps.keys())}"

    counter = 0
    while True:
        info = agent.step()
        if info["update"]:
            logger.logkvs(info["update"])
            logger.dumpkvs()
            counter += 1
        if agent.I.stats["tcount"] > num_timesteps:
            break
    agent.stop_interaction()
def train(args):
    xpid = '-{}-{}-reproduce-s{}'.format(args.run_name, args.env_name, args.seed)
    wandb.init(project=args.wandb_project, entity=args.wandb_entity, config=args, name=xpid)

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    if args.cuda:
        print('Using CUDA')
    else:
        print('Not using CUDA')

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    log_dir = os.path.expanduser(args.log_dir)
    utils.cleanup_log_dir(log_dir)

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    log_file = xpid

    venv = ProcgenEnv(num_envs=args.num_processes, env_name=args.env_name,
                      num_levels=args.num_levels, start_level=args.start_level,
                      distribution_mode=args.distribution_mode)
    venv = VecExtractDictObs(venv, "rgb")
    venv = VecMonitor(venv=venv, filename=None, keep_buf=100)
    venv = VecNormalize(venv=venv, ob=False)
    envs = VecPyTorchProcgen(venv, device)

    obs_shape = envs.observation_space.shape
    actor_critic = Policy(obs_shape,
                          envs.action_space.n,
                          base_kwargs={'recurrent': False, 'hidden_size': args.hidden_size})
    actor_critic.to(device)

    rollouts = RolloutStorage(args.num_steps,
                              args.num_processes,
                              envs.observation_space.shape,
                              envs.action_space,
                              actor_critic.recurrent_hidden_state_size,
                              aug_type=args.aug_type,
                              split_ratio=args.split_ratio)

    batch_size = int(args.num_processes * args.num_steps / args.num_mini_batch)

    if args.use_ucb:
        print('Using UCB')
        aug_id = data_augs.Identity
        aug_list = [aug_to_func[t](batch_size=batch_size) for t in list(aug_to_func.keys())]
        agent = algo.UCBDrAC(actor_critic,
                             args.clip_param,
                             args.ppo_epoch,
                             args.num_mini_batch,
                             args.value_loss_coef,
                             args.entropy_coef,
                             lr=args.lr,
                             eps=args.eps,
                             max_grad_norm=args.max_grad_norm,
                             aug_list=aug_list,
                             aug_id=aug_id,
                             aug_coef=args.aug_coef,
                             num_aug_types=len(list(aug_to_func.keys())),
                             ucb_exploration_coef=args.ucb_exploration_coef,
                             ucb_window_length=args.ucb_window_length)
    elif args.use_meta_learning:
        aug_id = data_augs.Identity
        aug_list = [aug_to_func[t](batch_size=batch_size) for t in list(aug_to_func.keys())]
        aug_model = AugCNN()
        aug_model.to(device)
        agent = algo.MetaDrAC(actor_critic,
                              aug_model,
                              args.clip_param,
                              args.ppo_epoch,
                              args.num_mini_batch,
                              args.value_loss_coef,
                              args.entropy_coef,
                              meta_grad_clip=args.meta_grad_clip,
                              meta_num_train_steps=args.meta_num_train_steps,
                              meta_num_test_steps=args.meta_num_test_steps,
                              lr=args.lr,
                              eps=args.eps,
                              max_grad_norm=args.max_grad_norm,
                              aug_id=aug_id,
                              aug_coef=args.aug_coef)
    elif args.use_rl2:
        aug_id = data_augs.Identity
        aug_list = [aug_to_func[t](batch_size=batch_size) for t in list(aug_to_func.keys())]
        rl2_obs_shape = [envs.action_space.n + 1]
        rl2_learner = Policy(rl2_obs_shape,
                             len(list(aug_to_func.keys())),
                             base_kwargs={'recurrent': True, 'hidden_size': args.rl2_hidden_size})
        rl2_learner.to(device)
        agent = algo.RL2DrAC(actor_critic,
                             rl2_learner,
                             args.clip_param,
                             args.ppo_epoch,
                             args.num_mini_batch,
                             args.value_loss_coef,
                             args.entropy_coef,
                             args.rl2_entropy_coef,
                             lr=args.lr,
                             eps=args.eps,
                             rl2_lr=args.rl2_lr,
                             rl2_eps=args.rl2_eps,
                             max_grad_norm=args.max_grad_norm,
                             aug_list=aug_list,
                             aug_id=aug_id,
                             aug_coef=args.aug_coef,
                             num_aug_types=len(list(aug_to_func.keys())),
                             recurrent_hidden_size=args.rl2_hidden_size,
                             num_actions=envs.action_space.n,
                             device=device)
    else:
        aug_id = data_augs.Identity
        aug_func = aug_to_func[args.aug_type](batch_size=batch_size)
        agent = algo.DrAC(actor_critic,
                          args.clip_param,
                          args.ppo_epoch,
                          args.num_mini_batch,
                          args.value_loss_coef,
                          args.entropy_coef,
                          lr=args.lr,
                          eps=args.eps,
                          max_grad_norm=args.max_grad_norm,
                          aug_id=aug_id,
                          aug_func=aug_func,
                          aug_coef=args.aug_coef,
                          env_name=args.env_name)

    checkpoint_path = os.path.join(args.save_dir, "agent" + log_file + ".pt")
    if os.path.exists(checkpoint_path) and args.preempt:
        checkpoint = torch.load(checkpoint_path)
        agent.actor_critic.load_state_dict(checkpoint['model_state_dict'])
        agent.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        init_epoch = checkpoint['epoch'] + 1
        logger.configure(dir=args.log_dir, format_strs=['csv', 'stdout'],
                         log_suffix=log_file + "-e%s" % init_epoch)
    else:
        init_epoch = 0
        logger.configure(dir=args.log_dir, format_strs=['csv', 'stdout'], log_suffix=log_file)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)

    num_updates = int(args.num_env_steps) // args.num_steps // args.num_processes

    for j in range(init_epoch, num_updates):
        actor_critic.train()
        for step in range(args.num_steps):
            # Sample actions.
            with torch.no_grad():
                obs_id = aug_id(rollouts.obs[step])
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    obs_id, rollouts.recurrent_hidden_states[step], rollouts.masks[step])

            # Observe reward and next obs.
            obs, reward, done, infos = envs.step(action)

            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done])
            bad_masks = torch.FloatTensor([[0.0] if 'bad_transition' in info.keys() else [1.0]
                                           for info in infos])

            rollouts.insert(obs, recurrent_hidden_states, action, action_log_prob,
                            value, reward, masks, bad_masks)

        with torch.no_grad():
            obs_id = aug_id(rollouts.obs[-1])
            next_value = actor_critic.get_value(
                obs_id, rollouts.recurrent_hidden_states[-1], rollouts.masks[-1]).detach()

        rollouts.compute_returns(next_value, args.gamma, args.gae_lambda)

        if args.use_ucb and j > 0:
            agent.update_ucb_values(rollouts)
        value_loss, action_loss, dist_entropy = agent.update(rollouts)
        rollouts.after_update()

        # Save for every interval-th episode or for the last epoch.
        total_num_steps = (j + 1) * args.num_processes * args.num_steps

        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            print("\nUpdate {}, step {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}"
                  .format(j, total_num_steps, len(episode_rewards),
                          np.mean(episode_rewards), np.median(episode_rewards),
                          dist_entropy, value_loss, action_loss))

            ### Eval on the Full Distribution of Levels ###
            eval_episode_rewards = evaluate(args, actor_critic, device, aug_id=aug_id)

            stats = {
                "train/nupdates": j,
                "train/total_num_steps": total_num_steps,
                "losses/dist_entropy": dist_entropy,
                "losses/value_loss": value_loss,
                "losses/action_loss": action_loss,
                "train/mean_episode_reward": np.mean(episode_rewards),
                "train/median_episode_reward": np.median(episode_rewards),
                "test/mean_episode_reward": np.mean(eval_episode_rewards),
                "test/median_episode_reward": np.median(eval_episode_rewards),
            }
            logger.logkvs(stats)
            logger.dumpkvs()
            wandb.log(stats)

        # Save the model.
        if (j > 0 and j % args.save_interval == 0 or j == num_updates - 1) and args.save_dir != "":
            try:
                os.makedirs(args.save_dir)
            except OSError:
                pass

            torch.save(
                {
                    'epoch': j,
                    'model_state_dict': agent.actor_critic.state_dict(),
                    'optimizer_state_dict': agent.optimizer.state_dict(),
                },
                os.path.join(args.save_dir, "agent" + log_file + ".pt"))
def train(*, env_id, num_env, hps, num_timesteps, seed, use_reward, ep_path, dmlab):
    venv = VecFrameStack(
        DummyVecEnv([
            lambda: CollectGymDataset(Minecraft('MineRLTreechop-v0', 'treechop'),
                                      ep_path, atari=False)
        ]),
        hps.pop('frame_stack'))
    venv.score_multiple = 1
    venv.record_obs = True if env_id == 'SolarisNoFrameskip-v4' else False
    ob_space = venv.observation_space
    ac_space = venv.action_space
    gamma = hps.pop('gamma')
    policy = {'rnn': CnnGruPolicy, 'cnn': CnnPolicy}[hps.pop('policy')]
    print('Running train ==========================================')
    agent = PpoAgent(
        scope='ppo',
        ob_space=ob_space,
        ac_space=ac_space,
        stochpol_fn=functools.partial(
            policy,
            scope='pol',
            ob_space=ob_space,
            ac_space=ac_space,
            update_ob_stats_independently_per_gpu=hps.pop('update_ob_stats_independently_per_gpu'),
            proportion_of_exp_used_for_predictor_update=hps.pop('proportion_of_exp_used_for_predictor_update'),
            dynamics_bonus=hps.pop("dynamics_bonus")),
        gamma=gamma,
        gamma_ext=hps.pop('gamma_ext'),
        lam=hps.pop('lam'),
        nepochs=hps.pop('nepochs'),
        nminibatches=hps.pop('nminibatches'),
        lr=hps.pop('lr'),
        cliprange=0.1,
        nsteps=128,
        ent_coef=0.001,
        max_grad_norm=hps.pop('max_grad_norm'),
        use_news=hps.pop("use_news"),
        comm=MPI.COMM_WORLD if MPI.COMM_WORLD.Get_size() > 1 else None,
        update_ob_stats_every_step=hps.pop('update_ob_stats_every_step'),
        int_coeff=hps.pop('int_coeff'),
        ext_coeff=hps.pop('ext_coeff'))
    agent.start_interaction([venv])
    if hps.pop('update_ob_stats_from_random_agent'):
        agent.collect_random_statistics(num_timesteps=128 * 50)
    assert len(hps) == 0, "Unused hyperparameters: %s" % list(hps.keys())

    counter = 0
    while True:
        info = agent.step()
        if info['update']:
            logger.logkvs(info['update'])
            logger.dumpkvs()
            counter += 1
        if agent.I.stats['tcount'] > num_timesteps:
            break
    agent.stop_interaction()
def train(*, env_id, num_env, hps, num_timesteps, seed):
    venv = VecFrameStack(
        make_atari_env(env_id, num_env, seed, wrapper_kwargs=dict(),
                       start_index=num_env * MPI.COMM_WORLD.Get_rank(),
                       max_episode_steps=hps.pop('max_episode_steps')),
        hps.pop('frame_stack'))

    # Size of states when stored in the memory.
    only_train_r = hps.pop('only_train_r')
    online_r_training = hps.pop('online_train_r') or only_train_r
    r_network_trainer = None
    save_path = hps.pop('save_path')
    r_network_weights_path = hps.pop('r_path')
    '''
    ec_type = 'none'  # hps.pop('ec_type')
    venv = CuriosityEnvWrapperFrameStack(
        make_atari_env(env_id, num_env, seed, wrapper_kwargs=dict(),
                       start_index=num_env * MPI.COMM_WORLD.Get_rank(),
                       max_episode_steps=hps.pop('max_episode_steps')),
        vec_episodic_memory=None,
        observation_embedding_fn=None,
        exploration_reward=ec_type,
        exploration_reward_min_step=0,
        nstack=hps.pop('frame_stack'),
        only_train_r=only_train_r
    )
    '''
    # venv.score_multiple = {'Mario': 500,
    #                        'MontezumaRevengeNoFrameskip-v4': 100,
    #                        'GravitarNoFrameskip-v4': 250,
    #                        'PrivateEyeNoFrameskip-v4': 500,
    #                        'SolarisNoFrameskip-v4': None,
    #                        'VentureNoFrameskip-v4': 200,
    #                        'PitfallNoFrameskip-v4': 100,
    #                        }[env_id]
    venv.score_multiple = 1
    venv.record_obs = True if env_id == 'SolarisNoFrameskip-v4' else False
    ob_space = venv.observation_space
    ac_space = venv.action_space
    gamma = hps.pop('gamma')
    log_interval = hps.pop('log_interval')
    nminibatches = hps.pop('nminibatches')

    play = hps.pop('play')
    if play:
        nsteps = 1

    rnd_type = hps.pop('rnd_type')
    div_type = hps.pop('div_type')
    num_agents = hps.pop('num_agents')
    load_ram = hps.pop('load_ram')
    debug = hps.pop('debug')
    rnd_mask_prob = hps.pop('rnd_mask_prob')
    rnd_mask_type = hps.pop('rnd_mask_type')
    indep_rnd = hps.pop('indep_rnd')
    logger.info("indep_rnd:", indep_rnd)
    indep_policy = hps.pop('indep_policy')
    sd_type = hps.pop('sd_type')
    from_scratch = hps.pop('from_scratch')
    use_kl = hps.pop('use_kl')

    save_interval = 100

    policy = {'rnn': CnnGruPolicy, 'cnn': CnnPolicy}[hps.pop('policy')]
    agent = PpoAgent(
        scope='ppo',
        ob_space=ob_space,
        ac_space=ac_space,
        stochpol_fn=functools.partial(
            policy,
            scope='pol',
            ob_space=ob_space,
            ac_space=ac_space,
            update_ob_stats_independently_per_gpu=hps.pop('update_ob_stats_independently_per_gpu'),
            proportion_of_exp_used_for_predictor_update=hps.pop('proportion_of_exp_used_for_predictor_update'),
            dynamics_bonus=hps.pop("dynamics_bonus"),
            num_agents=num_agents,
            rnd_type=rnd_type,
            div_type=div_type,
            indep_rnd=indep_rnd,
            indep_policy=indep_policy,
            sd_type=sd_type,
            rnd_mask_prob=rnd_mask_prob),
        gamma=gamma,
        gamma_ext=hps.pop('gamma_ext'),
        gamma_div=hps.pop('gamma_div'),
        lam=hps.pop('lam'),
        nepochs=hps.pop('nepochs'),
        nminibatches=nminibatches,
        lr=hps.pop('lr'),
        cliprange=0.1,
        nsteps=5 if debug else 128,
        ent_coef=0.001,
        max_grad_norm=hps.pop('max_grad_norm'),
        use_news=hps.pop("use_news"),
        comm=MPI.COMM_WORLD if MPI.COMM_WORLD.Get_size() > 1 else None,
        update_ob_stats_every_step=hps.pop('update_ob_stats_every_step'),
        int_coeff=hps.pop('int_coeff'),
        ext_coeff=hps.pop('ext_coeff'),
        log_interval=log_interval,
        only_train_r=only_train_r,
        rnd_type=rnd_type,
        reset=hps.pop('reset'),
        dynamics_sample=hps.pop('dynamics_sample'),
        save_path=save_path,
        num_agents=num_agents,
        div_type=div_type,
        load_ram=load_ram,
        debug=debug,
        rnd_mask_prob=rnd_mask_prob,
        rnd_mask_type=rnd_mask_type,
        sd_type=sd_type,
        from_scratch=from_scratch,
        use_kl=use_kl,
        indep_rnd=indep_rnd)

    load_path = hps.pop('load_path')
    base_load_path = hps.pop('base_load_path')

    agent.start_interaction([venv])

    if load_path is not None:
        if play:
            agent.load(load_path)
        else:
            # agent.load(load_path)
            # agent.load_help_info(0, load_path)
            # agent.load_help_info(1, load_path)

            # Load the diversity agent.
            # base_agent_idx = 1
            # logger.info("load base agents weights from {} agent {}".format(base_load_path, str(base_agent_idx)))
            # agent.load_agent(base_agent_idx, base_load_path)
            # agent.clone_baseline_agent(base_agent_idx)
            # agent.load_help_info(0, dagent_load_path)
            # agent.clone_agent(0)

            # Load the main agent.
            src_agent_idx = 1
            logger.info("load main agent weights from {} agent {}".format(load_path, str(src_agent_idx)))
            agent.load_agent(src_agent_idx, load_path)

            if indep_rnd == False:
                rnd_agent_idx = 1
            else:
                rnd_agent_idx = src_agent_idx
            # rnd_agent_idx = 0
            logger.info("load rnd weights from {} agent {}".format(load_path, str(rnd_agent_idx)))
            agent.load_rnd(rnd_agent_idx, load_path)
            agent.clone_agent(rnd_agent_idx, rnd=True, policy=False, help_info=False)

            logger.info("load help info from {} agent {}".format(load_path, str(src_agent_idx)))
            agent.load_help_info(src_agent_idx, load_path)
            agent.clone_agent(src_agent_idx, rnd=False, policy=True, help_info=True)

            # logger.info("load main agent weights from {} agent {}".format(load_path, str(2)))
            # load_path = '/data/xupeifrom7700_1000/seed1_log0.5_clip-0.5~0.5_3agent_hasint4_2divrew_-1~1/models'
            # agent.load_agent(1, load_path)
            # agent.clone_baseline_agent()

    # if sd_type == 'sd':
    #     agent.load_sd("save_dir/models_sd_trained")
    # agent.initialize_discriminator()

    update_ob_stats_from_random_agent = hps.pop('update_ob_stats_from_random_agent')

    if play == False:
        if load_path is not None:
            pass
            # agent.collect_statistics_from_model()
        else:
            if update_ob_stats_from_random_agent and rnd_type == 'rnd':
                agent.collect_random_statistics(num_timesteps=128 * 5 if debug else 128 * 50)

        assert len(hps) == 0, "Unused hyperparameters: %s" % list(hps.keys())

        # agent.collect_rnd_info(128 * 50)
        '''
        if sd_type == 'sd':
            agent.train_sd(max_nepoch=300, max_neps=5)
            path = '{}_sd_trained'.format(save_path)
            logger.log("save model:", path)
            agent.save(path)
            return
        # agent.update_diverse_agent(max_nepoch=1000)
        # path = '{}_divupdated'.format(save_path)
        # logger.log("save model:", path)
        # agent.save(path)
        '''

        counter = 0
        while True:
            info = agent.step()
            n_updates = agent.I.stats["n_updates"]
            if info['update']:
                logger.logkvs(info['update'])
                logger.dumpkvs()
                counter += 1

            if info['update'] and save_path is not None and (n_updates % save_interval == 0 or n_updates == 1):
                path = '{}_{}'.format(save_path, str(n_updates))
                logger.log("save model:", path)
                agent.save(path)
                agent.save_help_info(save_path, n_updates)

            if agent.I.stats['tcount'] > num_timesteps:
                path = '{}_{}'.format(save_path, str(n_updates))
                logger.log("save model:", path)
                agent.save(path)
                agent.save_help_info(save_path, n_updates)
                break

        agent.stop_interaction()
    else:
        # Play / inspection mode.
        '''
        check_point_rews_list_path = '{}_rewslist'.format(load_path)
        check_point_rnd_path = '{}_rnd'.format(load_path)
        oracle_rnd = oracle.OracleExplorationRewardForAllEpisodes()
        oracle_rnd.load(check_point_rnd_path)
        # print(oracle_rnd._collected_positions_writer)
        # print(oracle_rnd._collected_positions_reader)
        rews_list = load_rews_list(check_point_rews_list_path)
        print(rews_list)
        '''
        istate = agent.stochpol.initial_state(1)
        # ph_mean, ph_std = agent.stochpol.get_ph_mean_std()
        last_obs, prevrews, ec_rews, news, infos, ram_states, _ = agent.env_get(0)
        agent.I.step_count += 1

        flag = False
        show_cam = True
        last_xr = 0

        restore = None
        '''
        # path = 'ram_state_500_7room'
        # path = 'ram_state_400_6room'
        # path = 'ram_state_6700'
        path = 'ram_state_7700_10room'
        f = open(path, 'rb')
        restore = pickle.load(f)
        f.close()
        last_obs[0] = agent.I.venvs[0].restore_full_state_by_idx(restore, 0)
        print(last_obs.shape)

        # path = 'ram_state_400_monitor_rews_6room'
        # path = 'ram_state_500_monitor_rews_7room'
        # path = 'ram_state_6700_monitor_rews'
        path = 'ram_state_7700_monitor_rews_10room'
        f = open(path, 'rb')
        monitor_rews = pickle.load(f)
        f.close()
        agent.I.venvs[0].set_cur_monitor_rewards_by_idx(monitor_rews, 0)
        '''
        agent_idx = np.asarray([0])
        sample_agent_prob = np.asarray([0.5])

        ph_mean = agent.stochpol.ob_rms_list[0].mean
        ph_std = agent.stochpol.ob_rms_list[0].var ** 0.5

        buf_ph_mean = np.zeros(([1, 1] + list(agent.stochpol.ob_space.shape[:2]) + [1]), np.float32)
        buf_ph_std = np.zeros(([1, 1] + list(agent.stochpol.ob_space.shape[:2]) + [1]), np.float32)
        buf_ph_mean[0, 0] = ph_mean
        buf_ph_std[0, 0] = ph_std

        vpreds_ext_list = []
        ep_rews = np.zeros((1))

        divexp_flag = False
        step_count = 0
        stage_prob = True

        last_rew_ob = np.full_like(last_obs, 128)

        clusters = Clusters(1.0)

        # path = '{}_sd_rms'.format(load_path)
        # agent.I.sd_rms.load(path)

        while True:
            dict_obs = agent.stochpol.ensure_observation_is_dict(last_obs)
            # acs = np.random.randint(low=0, high=15, size=(1))
            acs, vpreds_int, vpreds_ext, nlps, istate, ent = agent.stochpol.call(
                dict_obs, news, istate, agent_idx[:, None])
            step_acs = acs

            t = ''
            # if show_cam == True:
            t = input("input:")
            if t != '':
                t = int(t)
                if t <= 17:
                    step_acs = [t]

            agent.env_step(0, step_acs)
            obs, prevrews, ec_rews, news, infos, ram_states, monitor_rews = agent.env_get(0)

            if news[0] and restore is not None:
                obs[0] = agent.I.venvs[0].restore_full_state_by_idx(restore, 0)
                agent.I.venvs[0].set_cur_monitor_rewards_by_idx(monitor_rews, 0)

            ep_rews = ep_rews + prevrews
            print(ep_rews)

            last_rew_ob[prevrews > 0] = obs[prevrews > 0]

            room = infos[0]['position'][2]
            vpreds_ext_list.append([vpreds_ext, room])

            # print(monitor_rews[0])
            # print(len(monitor_rews[0]))
            # print(infos[0]['open_door_type'])

            stack_obs = np.concatenate([last_obs[:, None], obs[:, None]], 1)
            fd = {}
            fd[agent.stochpol.ph_ob[None]] = stack_obs
            fd.update({agent.stochpol.sep_ph_mean: buf_ph_mean,
                       agent.stochpol.sep_ph_std: buf_ph_std})
            fd[agent.stochpol.ph_agent_idx] = agent_idx[:, None]
            fd[agent.stochpol.sample_agent_prob] = sample_agent_prob[:, None]
            fd[agent.stochpol.last_rew_ob] = last_rew_ob[:, None]
            fd[agent.stochpol.game_score] = ep_rews[:, None]
            fd[agent.stochpol.sd_ph_mean] = agent.I.sd_rms.mean
            fd[agent.stochpol.sd_ph_std] = agent.I.sd_rms.var ** 0.5

            div_prob = 0
            all_div_prob = tf_util.get_session().run([agent.stochpol.all_div_prob], fd)

            '''
            if prevrews[0] > 0:
                clusters.update(rnd_em, room)
                num_clusters = len(clusters._cluster_list)
                for i in range(num_clusters):
                    print("{} {}".format(str(i), list(clusters._room_set[i])))
            '''

            print("vpreds_int: ", vpreds_int, "vpreds_ext:", vpreds_ext, "ent:", ent,
                  "all_div_prob:", all_div_prob, "room:", room, "step_count:", step_count)

            # aaaa = np.asarray(vpreds_ext_list)
            # print(aaaa[-100:])

            '''
            if step_acs[0] == 0:
                ram_state = ram_states[0]
                path = 'ram_state_7700_10room'
                f = open(path, 'wb')
                pickle.dump(ram_state, f)
                f.close()
                path = 'ram_state_7700_monitor_rews_10room'
                f = open(path, 'wb')
                pickle.dump(monitor_rews[0], f)
                f.close()
            '''
            '''
            if restore is None:
                restore = ram_states[0]
            if np.random.rand() < 0.1:
                print("restore")
                obs = agent.I.venvs[0].restore_full_state_by_idx(restore, 0)
                prevrews = None
                ec_rews = None
                news = True
                infos = {}
                ram_states = ram_states[0]
                # restore = ram_states[0]
            '''
            img = agent.I.venvs[0].render()
            last_obs = obs
            step_count = step_count + 1
            time.sleep(0.04)