Example #1
def log_reward_statistics(vec_env, num_last_eps=100, prefix=""):
    all_stats = None
    for _ in range(10):
        try:
            all_stats = load_results(osp.dirname(
                vec_env.results_writer.f.name))
            break  # stop retrying once the monitor file has been read
        except FileNotFoundError:
            time.sleep(1)
            continue
    if all_stats is not None:
        episode_rewards = all_stats["r"]
        episode_lengths = all_stats["l"]

        recent_episode_rewards = episode_rewards[-num_last_eps:]
        recent_episode_lengths = episode_lengths[-num_last_eps:]

        if len(recent_episode_rewards) > 0:
            kvs = {
                prefix + "AverageReturn": np.mean(recent_episode_rewards),
                prefix + "MinReturn": np.min(recent_episode_rewards),
                prefix + "MaxReturn": np.max(recent_episode_rewards),
                prefix + "StdReturn": np.std(recent_episode_rewards),
                prefix + "AverageEpisodeLength":
                np.mean(recent_episode_lengths),
                prefix + "MinEpisodeLength": np.min(recent_episode_lengths),
                prefix + "MaxEpisodeLength": np.max(recent_episode_lengths),
                prefix + "StdEpisodeLength": np.std(recent_episode_lengths),
            }
            logger.logkvs(kvs)
        logger.logkv(prefix + "TotalNEpisodes", len(episode_rewards))
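A minimal usage sketch, not part of the snippet above: it assumes baselines and gym are installed and that the helper is defined in a module whose own imports (osp, time, np, logger, load_results) are already in place; the DummyVecEnv/VecMonitor wiring is an assumption chosen so that load_results() finds a monitor CSV.

import gym
from baselines import logger
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from baselines.common.vec_env.vec_monitor import VecMonitor

logger.configure()  # default: temp dir + stdout output
# VecMonitor writes the *.monitor.csv file that load_results() reads
venv = VecMonitor(DummyVecEnv([lambda: gym.make("CartPole-v1")]),
                  filename=logger.get_dir())

obs = venv.reset()
for _ in range(2000):  # a few hundred random-action episodes
    obs, rewards, dones, infos = venv.step([venv.action_space.sample()])

log_reward_statistics(venv, num_last_eps=100, prefix="Eval")
logger.dumpkvs()  # flush the key/value pairs logged by the helper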
Example #2
    def train(self, saver, logger_dir):
        # initialize the computation graph and the rollout class
        self.agent.start_interaction(self.envs,
                                     nlump=self.hps['nlumps'],
                                     dynamics=self.dynamics)
        previous_saved_tcount = 0
        while True:
            info = self.agent.step()  # interact with the env for one cycle, collect samples, compute intrinsic rewards, and train
            if info['update']:
                logger.logkvs(info['update'])
                logger.dumpkvs()
            if self.hps["save_period"] and (int(
                    self.agent.rollout.stats['tcount'] / self.hps["save_freq"])
                                            > previous_saved_tcount):
                previous_saved_tcount += 1
                save_path = saver.save(
                    tf.get_default_session(),
                    os.path.join(
                        logger_dir,
                        "model_" + str(previous_saved_tcount) + ".ckpt"))
                print("Periodically model saved in path:", save_path)
            if self.agent.rollout.stats['tcount'] > self.num_timesteps:
                save_path = saver.save(
                    tf.get_default_session(),
                    os.path.join(logger_dir, "model_last.ckpt"))
                print("Model saved in path:", save_path)
                break

        self.agent.stop_interaction()
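A small sketch (numbers invented) of the periodic-save condition used above: a checkpoint is written each time tcount crosses another multiple of save_freq.

save_freq = 1000
previous_saved_tcount = 0
for tcount in range(0, 5001, 128):  # pretend step counts returned by rollouts
    if int(tcount / save_freq) > previous_saved_tcount:
        previous_saved_tcount += 1
        print("save checkpoint", previous_saved_tcount, "at tcount", tcount)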
Example #3
File: run.py Project: ijcai-261/ijcai-261
    def train(self):
        if self.save_checkpoint:
            params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
            saver = tf.train.Saver(var_list=params,
                                   max_to_keep=self.num_timesteps // 1000000 +
                                   1)

            periods = list(range(0, self.num_timesteps + 1, 1000000))
            idx = 0

        self.agent.start_interaction(self.envs,
                                     nlump=self.hps['nlumps'],
                                     dynamics=self.dynamics)
        while True:
            info = self.agent.step()

            if info['update']:
                logger.logkvs(info['update'])
                logger.dumpkvs()

            if self.save_checkpoint and idx < len(periods):
                if self.agent.rollout.stats['tcount'] >= periods[idx]:
                    self.save(saver,
                              logger.get_dir() + '/checkpoint/', periods[idx])
                    idx += 1

            if self.agent.rollout.stats['tcount'] > self.num_timesteps:
                break

        self.agent.stop_interaction()
Example #4
    def train(self, saver, sess, restore=False):
        from baselines import logger
        self.agent.start_interaction(self.envs, nlump=self.hps['nlumps'], dynamics=self.dynamics)
        if restore:
            print("Restoring model for training")
            saver.restore(sess, "models/" + self.hps['restore_model'] + ".ckpt")
            print("Loaded model", self.hps['restore_model'])
        write_meta_graph = False
        while True:
            info = self.agent.step()
            if info['update']:
                if info['update']['recent_best_ext_ret'] is None:
                    info['update']['recent_best_ext_ret'] = 0
                wandb.log(info['update'])
                logger.logkvs(info['update'])
                logger.dumpkvs()
            if self.agent.rollout.stats['tcount'] > self.num_timesteps:
                break
        if self.hps['tune_env']:
            filename = "models/" + self.hps['restore_model'] + "_tune_on_" + self.hps['tune_env'] + "_final.ckpt"
        else:
            filename = "models/" + self.hps['exp_name'] + "_final.ckpt"
        saver.save(sess, filename, write_meta_graph=False)
        self.policy.save_model(self.hps['exp_name'], 'final')
        self.agent.stop_interaction()
Example #5
    def train(self):
        import random

        self.agent.start_interaction(self.envs,
                                     nlump=self.hps["nlumps"],
                                     dynamics=self.dynamics)
        count = 0
        while True:
            count += 1
            info = self.agent.step()
            if info["update"]:
                logger.logkvs(info["update"])
                logger.dumpkvs()
            if self.hps["feat_learning"] == "pix2pix":
                making_video = random.choice(99 * [False] + [True])
            else:
                making_video = False
            self.agent.rollout.making_video = making_video
            for a_key in info.keys():
                wandb.log(info[a_key])
            wandb.log(
                {"average_sigma": np.mean(self.agent.rollout.buf_sigmas)})
            # going to have to log it here
            if self.agent.rollout.stats["tcount"] > self.num_timesteps:
                break

        self.agent.stop_interaction()
Example #6
    def train(self, saver, sess, restore=False):

        self.agent.start_interaction(self.envs,
                                     nlump=self.hps['nlumps'],
                                     dynamics=self.dynamics)
        write_meta_graph = False
        saves = 0
        loops = 0
        while True:

            info = self.agent.step(eval=False)

            if info is not None:
                if info['update'] and not restore:
                    logger.logkvs(info['update'])
                    logger.dumpkvs()

            steps = self.agent.rollout.stats['tcount']

            if loops % 10 == 0:
                filename = args.saved_model_dir + 'model.ckpt'
                saver.save(sess,
                           filename,
                           global_step=int(saves),
                           write_meta_graph=False)
                saves += 1
            loops += 1

            if steps > self.num_timesteps:
                break

        self.agent.stop_interaction()
Example #7
    def train(self):
        self.agent.start_interaction(self.envs,
                                     nlump=self.hps['nlumps'],
                                     dynamics=self.dynamics)
        save_path = 'models'
        tf_sess = tf.get_default_session()
        # Create a saver.
        saver = tf.train.Saver(save_relative_paths=True)
        # if self.hps['restore_latest_checkpoint']:
        # Restore latest checkpoint if set in arguments
        # saver.restore(tf_sess, tf.train.latest_checkpoint(save_path))
        while True:
            info = self.agent.step()
            if info['update']:
                logger.logkvs(info['update'])
                logger.dumpkvs()
            if self.agent.rollout.stats['tcount'] > self.num_timesteps:
                break
            # Saving the model every 1,000 steps.
            if info['n_updates'] % 1000 == 0:
                # Append the step number to the checkpoint name:
                saver.save(tf_sess,
                           save_path + '/obstacle_tower',
                           global_step=int(self.agent.rollout.stats['tcount']))

        # Append the step number to the last checkpoint name:
        saver.save(tf_sess,
                   save_path + '/obstacle_tower',
                   global_step=int(self.agent.rollout.stats['tcount']))
        self.agent.stop_interaction()
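A hedged, self-contained TF1 sketch of the checkpoint-naming behavior relied on above: passing global_step to Saver.save() appends the step number to the checkpoint name (the variable and paths here are invented).

import os
import tensorflow as tf

os.makedirs('models', exist_ok=True)
with tf.Session() as sess:
    v = tf.get_variable('v', shape=[1])
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver(save_relative_paths=True)
    path = saver.save(sess, 'models/obstacle_tower', global_step=128000)
    print('saved to', path)  # e.g. models/obstacle_tower-128000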
Example #8
    def train(self):
        self.agent.start_interaction(self.envs, nlump=self.hps['nlumps'], dynamics=self.action_dynamics)
        expdir = osp.join("/result", self.hps['env'], self.hps['exp_name'])
        save_checkpoints = []
        if self.hps['save_interval'] is not None:
            save_checkpoints = [i*self.hps['save_interval'] for i in range(1, self.hps['num_timesteps']//self.hps['save_interval'])]
        if self.hps['load_dir'] is not None:
            self.train_feature_extractor.load(self.hps['load_dir'])
            self.train_dynamics.load(self.hps['load_dir'])

        while True:
            info = self.agent.step()
            if info['update']:
                logger.logkvs(info['update'])
                logger.dumpkvs()
            if len(save_checkpoints) > 0:
                if self.agent.rollout.stats['tcount'] > save_checkpoints[0]:
                    self.train_feature_extractor.save(expdir, self.agent.rollout.stats['tcount'])
                    self.train_dynamics.save(expdir, self.agent.rollout.stats['tcount'])
                    save_checkpoints.remove(save_checkpoints[0])
            if self.agent.rollout.stats['tcount'] > self.num_timesteps:
                break

        if self.hps['save_dynamics'] and MPI.COMM_WORLD.Get_rank() == 0:       # save auxiliary task and dynamics parameters
            self.train_feature_extractor.save(expdir)
            self.train_dynamics.save(expdir)
        self.agent.stop_interaction()
Example #9
    def train(self):
        self.agent.start_interaction(self.envs, nlump=self.hps['nlumps'], dynamics=self.dynamics)
        while True:
            info = self.agent.step()
            if info['update']:
                logger.logkvs(info['update'])
                logger.dumpkvs()
            if self.agent.rollout.stats['tcount'] > self.num_timesteps:
                break

        self.agent.stop_interaction()
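For reference, a small stand-alone illustration (not taken from any of the snippets) of the baselines logger contract these loops rely on: logkvs() buffers key/value pairs and dumpkvs() writes one row to every configured output format, then clears the buffer.

from baselines import logger

logger.configure(dir="/tmp/logkvs_demo", format_strs=["stdout", "csv"])
for update in range(3):
    logger.logkvs({"n_updates": update, "eprewmean": 0.1 * update})
    logger.logkv("tcount", update * 128)  # single pairs can be added as well
    logger.dumpkvs()                      # flush this update's row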
Example #10
    def train(self):
        self.agent.start_interaction(self.envs,
                                     nlump=self.hps['nlumps'],
                                     intrinsic_model=self.intrinsic_model)

        sess = getsess()

        while True:
            info = self.agent.step()
            if info['update']:
                logger.logkvs(info['update'])
                logger.dumpkvs()
            if self.agent.rollout.stats['tcount'] > self.num_timesteps:
                break

        self.agent.stop_interaction()
Example #11
    def train(self):
        self.agent.start_interaction(self.envs, nlump=self.hps['nlumps'], dynamics=self.dynamics)
        while True:
            info = self.agent.step()
            if info['update']:
                logger.logkvs(info['update'])
                logger.dumpkvs()
            if self.agent.rollout.stats['tcount'] == 0:
                fname = os.path.join(self.hps['save_dir'], 'checkpoints')
                if os.path.exists(fname+'.index'):
                    load_state(fname)
                    print('load successfully')
                else:
                    print('fail to load')
            # note: the divisor below is num_timesteps / num_timesteps == 1, so the state is saved every step
            if self.agent.rollout.stats['tcount'] % int(self.num_timesteps / self.num_timesteps) == 0:
                fname = os.path.join(self.hps['save_dir'], 'checkpoints')
                save_state(fname)
            if self.agent.rollout.stats['tcount'] > self.num_timesteps:
                break
            # print(self.agent.rollout.stats['tcount'])

        self.agent.stop_interaction()
Example #12
def make_save_dir_and_log_basics(argdict):
    if not gflag.save_dir:
        assert not gflag.resumable, "You cannot set --resumable without setting --save-dir."
    else:
        assert gflag.resumable or (not os.path.exists(gflag.save_dir)), \
          "--save_dir '%s' already exists and resumable is False. " % (gflag.save_dir) + \
          "This might be because condor killed and rescheduled the original task."  + \
          "To prevent log.txt being overwritten/appended by a possibly different model's log, " + \
          "the program will terminate now."
        os.makedirs(gflag.save_dir, exist_ok=True)
        logger.configure(gflag.save_dir, format_strs=['log', 'stdout'])
        logger.logkvs(argdict)
        logger.dumpkvs()

        # copy related py files to save_dir to generate a snapshot of code being run
        snapshot_dir = gflag.save_dir + "/all_py_files_snapshot/"
        py_files = subprocess.check_output("find baselines | grep '\\.py$'",
                                           shell=True).decode('utf-8').split()
        py_files += subprocess.check_output(
            "ls *.py", shell=True).decode('utf-8').split()
        for py_file in py_files:
            os.makedirs(snapshot_dir + os.path.dirname(py_file), exist_ok=True)
            shutil.copyfile(py_file, snapshot_dir + py_file)
Example #13
    def test(self, saver, sess):
        self.agent.start_interaction(self.envs,
                                     nlump=self.hps['nlumps'],
                                     dynamics=self.dynamics)
        print('loading model')
        saver.restore(sess, args.saved_model_dir + args.model_name)
        print('loaded model,', args.saved_model_dir + args.model_name)

        include_images = args.include_images and eval
        info = self.agent.step(eval=True, include_images=include_images)

        if info['update']:
            logger.logkvs(info['update'])
            logger.dumpkvs()

        # save actions, news, and / or images
        np.save(args.env + '_data.npy', info)

        print('EVALUATION COMPLETED')
        print('SAVED DATA IN CURRENT DIRECTORY')
        print('FILENAME', args.env + '_data.npy')

        self.agent.stop_interaction()
Example #14
def train(*, env_id, num_env, hps, num_timesteps, seed):
    venv = VecFrameStack(
        make_atari_env(env_id, num_env, seed, wrapper_kwargs=dict(),
                       start_index=num_env * MPI.COMM_WORLD.Get_rank(),
                       max_episode_steps=hps.pop('max_episode_steps')),
        hps.pop('frame_stack'))
    # venv.score_multiple = {'Mario': 500,
    #                        'MontezumaRevengeNoFrameskip-v4': 100,
    #                        'GravitarNoFrameskip-v4': 250,
    #                        'PrivateEyeNoFrameskip-v4': 500,
    #                        'SolarisNoFrameskip-v4': None,
    #                        'VentureNoFrameskip-v4': 200,
    #                        'PitfallNoFrameskip-v4': 100,
    #                        }[env_id]
    venv.score_multiple = 1
    venv.record_obs = True
    # venv.record_obs = True if env_id == 'SolarisNoFrameskip-v4' else False
    ob_space = venv.observation_space
    ac_space = venv.action_space
    gamma = hps.pop('gamma')
    policy = {'rnn': CnnGruPolicy,
              'cnn': CnnPolicy}[hps.pop('policy')]
    agent = PpoAgent(
        scope='ppo',
        ob_space=ob_space,
        ac_space=ac_space,
        stochpol_fn=functools.partial(
            policy,
            scope='pol',
            ob_space=ob_space,
            ac_space=ac_space,
            update_ob_stats_independently_per_gpu=hps.pop('update_ob_stats_independently_per_gpu'),
            proportion_of_exp_used_for_predictor_update=hps.pop('proportion_of_exp_used_for_predictor_update'),
            dynamics_bonus=hps.pop("dynamics_bonus")),
        gamma=gamma,
        gamma_ext=hps.pop('gamma_ext'),
        lam=hps.pop('lam'),
        nepochs=hps.pop('nepochs'),
        nminibatches=hps.pop('nminibatches'),
        lr=hps.pop('lr'),
        cliprange=0.1,
        nsteps=128,
        ent_coef=0.001,
        max_grad_norm=hps.pop('max_grad_norm'),
        use_news=hps.pop("use_news"),
        comm=MPI.COMM_WORLD if MPI.COMM_WORLD.Get_size() > 1 else None,
        update_ob_stats_every_step=hps.pop('update_ob_stats_every_step'),
        int_coeff=hps.pop('int_coeff'),
        ext_coeff=hps.pop('ext_coeff'),
    )
    agent.start_interaction([venv])
    if hps.pop('update_ob_stats_from_random_agent'):
        agent.collect_random_statistics(num_timesteps=128*50)
    assert len(hps) == 0, "Unused hyperparameters: %s" % list(hps.keys())

    counter = 0
    while True:
        info = agent.step()
        if info['update']:
            logger.logkvs(info['update'])
            logger.dumpkvs()
            counter += 1
        if agent.I.stats['tcount'] > num_timesteps:
            break

    agent.stop_interaction()

    return agent
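A short sketch (keys invented) of the hps-dict convention used above: every hyperparameter is pop()'d exactly once, so the final assert catches typos and unused keys.

hps = dict(frame_stack=4, gamma=0.99, lr=1e-4)

frame_stack = hps.pop('frame_stack')
gamma = hps.pop('gamma')
lr = hps.pop('lr')
assert len(hps) == 0, "Unused hyperparameters: %s" % list(hps.keys())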
Example #15
def train(*, env_id, num_env, hps, num_timesteps, seed):
    experiment = os.environ.get('EXPERIMENT_LVL')
    if experiment == 'ego':
        # for the ego experiment we needed a higher intrinsic coefficient
        hps['int_coeff'] = 3.0

    hyperparams = copy(hps)
    hyperparams.update({'seed': seed})
    logger.info("Hyperparameters:")
    logger.info(hyperparams)

    venv = VecFrameStack(
        make_atari_env(env_id,
                       num_env,
                       seed,
                       wrapper_kwargs={},
                       start_index=num_env * MPI.COMM_WORLD.Get_rank(),
                       max_episode_steps=hps.pop('max_episode_steps')),
        hps.pop('frame_stack'))
    venv.score_multiple = {
        'Mario': 500,
        'MontezumaRevengeNoFrameskip-v4': 1,
        'GravitarNoFrameskip-v4': 250,
        'PrivateEyeNoFrameskip-v4': 500,
        'SolarisNoFrameskip-v4': None,
        'VentureNoFrameskip-v4': 200,
        'PitfallNoFrameskip-v4': 100,
    }[env_id]

    venv.record_obs = True if env_id == 'SolarisNoFrameskip-v4' else False
    ob_space = venv.observation_space
    ac_space = venv.action_space

    gamma = hps.pop('gamma')
    policy = {'rnn': CnnGruPolicy, 'cnn': CnnPolicy}[hps.pop('policy')]

    agent = PpoAgent(
        scope='ppo',
        ob_space=ob_space,
        ac_space=ac_space,
        stochpol_fn=functools.partial(
            policy,
            scope='pol',
            ob_space=ob_space,
            ac_space=ac_space,
            update_ob_stats_independently_per_gpu=hps.pop(
                'update_ob_stats_independently_per_gpu'),
            proportion_of_exp_used_for_predictor_update=hps.pop(
                'proportion_of_exp_used_for_predictor_update'),
            dynamics_bonus=hps.pop("dynamics_bonus")),
        gamma=gamma,
        gamma_ext=hps.pop('gamma_ext'),
        lam=hps.pop('lam'),
        nepochs=hps.pop('nepochs'),
        nminibatches=hps.pop('nminibatches'),
        lr=hps.pop('lr'),
        cliprange=0.1,
        nsteps=128,
        ent_coef=0.001,
        max_grad_norm=hps.pop('max_grad_norm'),
        use_news=hps.pop("use_news"),
        comm=MPI.COMM_WORLD if MPI.COMM_WORLD.Get_size() > 1 else None,
        update_ob_stats_every_step=hps.pop('update_ob_stats_every_step'),
        int_coeff=hps.pop('int_coeff'),
        ext_coeff=hps.pop('ext_coeff'),
        restore_model_path=hps.pop('restore_model_path'))
    agent.start_interaction([venv])
    if hps.pop('update_ob_stats_from_random_agent'):
        agent.collect_random_statistics(num_timesteps=128 * 50)

    save_model = hps.pop('save_model')
    assert len(hps) == 0, "Unused hyperparameters: %s" % list(hps.keys())

    #profiler = cProfile.Profile()
    #profiler.enable()
    #tracemalloc.start()
    #prev_snap = tracemalloc.take_snapshot()
    counter = 0
    while True:
        info = agent.step()
        if info['update']:
            logger.logkvs(info['update'])
            logger.dumpkvs()
            counter += 1

            # if (counter % 10) == 0:
            #     snapshot = tracemalloc.take_snapshot()
            #     top_stats = snapshot.compare_to(prev_snap, 'lineno')
            #     for stat in top_stats[:10]:
            #         print(stat)
            #     prev_snap = snapshot
            #     profiler.dump_stats("profile_rnd")

            if (counter % 100) == 0 and save_model:
                agent.save_model(agent.I.step_count)

        if agent.I.stats['tcount'] > num_timesteps:
            break

    agent.stop_interaction()
Example #16
File: run.py Project: sahanayvaz/rss
    def train(self):
        curr_iter = 0

        # train progress results logger
        format_strs = ['csv']
        format_strs = filter(None, format_strs)
        dirc = os.path.join(self.args['log_dir'], 'inter')
        if self.restore_iter > -1:
            dirc = os.path.join(self.args['log_dir'],
                                'inter-{}'.format(self.restore_iter))
        output_formats = [
            logger.make_output_format(f, dirc) for f in format_strs
        ]
        self.result_logger = logger.Logger(dir=dirc,
                                           output_formats=output_formats)

        # in case we are restoring the training
        if self.restore_iter > -1:
            self.agent.load(self.load_path)
            if not self.args['transfer_load']:
                curr_iter = self.restore_iter

        print('max_iter: {}'.format(self.max_iter))

        # interim saves to compare in the future
        # for 128M frames,

        inter_save = []
        for i in range(3):
            divisor = (2**(i + 1))
            inter_save.append(
                int(self.args['num_timesteps'] // divisor) //
                (self.args['nsteps'] * self.args['NUM_ENVS'] *
                 self.args['nframeskip']))
        print('inter_save: {}'.format(inter_save))

        total_time = 0.0
        # results_list = []

        while curr_iter < self.early_max_iter:
            frac = 1.0 - (float(curr_iter) / self.max_iter)

            # self.agent.update calls rollout
            start_time = time.time()

            ## linearly annealing
            curr_lr = self.lr(frac)
            curr_cr = self.cliprange(frac)

            ## removed within-training evaluation
            ## I could not get flag_sum to work properly
            ## evaluate every 100 runs on 20 training levels
            # only for mario (first evaluate, then update)
            # this change measures zero-shot generalization without extra effort
            if curr_iter % (self.args['save_interval']) == 0:
                save_video = False
                nlevels = 20 if self.args[
                    'env_kind'] == 'mario' else self.args['NUM_LEVELS']
                results, _ = self.agent.evaluate(nlevels, save_video)
                results['iter'] = curr_iter
                for (k, v) in results.items():
                    self.result_logger.logkv(k, v)
                self.result_logger.dumpkvs()

            # representation learning in each 25 steps
            info = self.agent.update(lr=curr_lr, cliprange=curr_cr)
            end_time = time.time()

            # additional info
            info['frac'] = frac
            info['curr_lr'] = curr_lr
            info['curr_cr'] = curr_cr
            info['curr_iter'] = curr_iter
            # info['max_iter'] = self.max_iter
            info['elapsed_time'] = end_time - start_time
            # info['total_time'] = total_time = (total_time + info['elapsed_time']) / 3600.0
            info['expected_time'] = self.max_iter * info[
                'elapsed_time'] / 3600.0

            ## logging results using baselines's logger
            logger.logkvs(info)
            logger.dumpkvs()

            if curr_iter % self.args['save_interval'] == 0:
                self.agent.save(curr_iter, cliprange=curr_cr)

            if curr_iter in inter_save:
                self.agent.save(curr_iter, cliprange=curr_cr)

            curr_iter += 1

        self.agent.save(curr_iter, cliprange=curr_cr)

        # final evaluation for mario
        save_video = False
        nlevels = 20 if self.args['env_kind'] == 'mario' else self.args[
            'NUM_LEVELS']
        results, _ = self.agent.evaluate(nlevels, save_video)
        results['iter'] = curr_iter
        for (k, v) in results.items():
            self.result_logger.logkv(k, v)
        self.result_logger.dumpkvs()
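A hedged sketch of the linear annealing above; self.lr and self.cliprange are assumed to be callables that scale an initial value by frac (that detail is not shown in the snippet).

lr0, cliprange0 = 2.5e-4, 0.1
lr = lambda frac: frac * lr0
cliprange = lambda frac: frac * cliprange0

max_iter = 4
for curr_iter in range(max_iter):
    frac = 1.0 - float(curr_iter) / max_iter
    print(curr_iter, lr(frac), cliprange(frac))  # both decay linearly toward 0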
Example #17
File: run_atari.py Project: Baichenjia/CB
def train(*, env_id, num_env, hps, num_timesteps, seed):
    venv = VecFrameStack(
        make_atari_env(env_id,
                       num_env,
                       seed,
                       wrapper_kwargs=dict(),
                       start_index=num_env * MPI.COMM_WORLD.Get_rank(),
                       max_episode_steps=hps.pop('max_episode_steps')),
        hps.pop('frame_stack'))
    venv.score_multiple = 1
    venv.record_obs = False
    ob_space = venv.observation_space
    ac_space = venv.action_space
    gamma = hps.pop('gamma')
    policy = {'rnn': CnnGruPolicy, 'cnn': CnnPolicy}[hps.pop('policy')]

    agent = PpoAgent(
        scope='ppo',
        ob_space=ob_space,
        ac_space=ac_space,
        stochpol_fn=functools.partial(
            policy,
            scope='pol',
            ob_space=ob_space,
            ac_space=ac_space,
            update_ob_stats_independently_per_gpu=hps.pop(
                'update_ob_stats_independently_per_gpu'),
            proportion_of_exp_used_for_predictor_update=hps.pop(
                'proportion_of_exp_used_for_predictor_update'),
            exploration_type=hps.pop("exploration_type"),
            beta=hps.pop("beta"),
        ),
        gamma=gamma,
        gamma_ext=hps.pop('gamma_ext'),
        lam=hps.pop('lam'),
        nepochs=hps.pop('nepochs'),
        nminibatches=hps.pop('nminibatches'),
        lr=hps.pop('lr'),
        cliprange=0.1,
        nsteps=128,
        ent_coef=0.001,
        max_grad_norm=hps.pop('max_grad_norm'),
        use_news=hps.pop("use_news"),
        comm=MPI.COMM_WORLD if MPI.COMM_WORLD.Get_size() > 1 else None,
        update_ob_stats_every_step=hps.pop('update_ob_stats_every_step'),
        int_coeff=hps.pop('int_coeff'),
        ext_coeff=hps.pop('ext_coeff'),
        noise_type=hps.pop('noise_type'),
        noise_p=hps.pop('noise_p'),
        use_sched=hps.pop('use_sched'),
        num_env=num_env,
        exp_name=hps.pop('exp_name'),
    )
    agent.start_interaction([venv])
    if hps.pop('update_ob_stats_from_random_agent'):
        agent.collect_random_statistics(num_timesteps=128 * 50)
    assert len(hps) == 0, "Unused hyperparameters: %s" % list(hps.keys())

    counter = 0
    while True:
        info = agent.step()
        n_updates = 0

        if info['update']:
            logger.logkvs(info['update'])
            logger.dumpkvs()

            if NSML:
                n_updates = int(info['update']['n_updates'])
                nsml_dict = {
                    k: np.float64(v)
                    for k, v in info['update'].items()
                    if isinstance(v, Number)
                }
                nsml.report(step=n_updates, **nsml_dict)

            counter += 1
        #if n_updates >= 40*1000: # 40K updates
        #    break
        if agent.I.stats['tcount'] > num_timesteps:
            break

    agent.stop_interaction()
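A tiny stand-alone sketch (values invented) of the numeric-only filtering applied before nsml.report above.

from numbers import Number
import numpy as np

update_info = {"n_updates": 3, "eprew": 1.5, "opt_note": "clipped"}
numeric_only = {k: np.float64(v) for k, v in update_info.items() if isinstance(v, Number)}
print(numeric_only)  # {'n_updates': 3.0, 'eprew': 1.5}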
Example #18
def train(*, env_id, num_env, hps, num_timesteps, seed):
    if "NoFrameskip" in env_id:
        env_factory = make_atari_env
    else:
        env_factory = make_non_atari_env

    venv = VecFrameStack(
        env_factory(
            env_id,
            num_env,
            seed,
            wrapper_kwargs=dict(),
            start_index=num_env * MPI.COMM_WORLD.Get_rank(),
            max_episode_steps=hps.pop("max_episode_steps"),
        ),
        hps.pop("frame_stack"),
    )
    # venv.score_multiple = {'Mario': 500,
    #                        'MontezumaRevengeNoFrameskip-v4': 100,
    #                        'GravitarNoFrameskip-v4': 250,
    #                        'PrivateEyeNoFrameskip-v4': 500,
    #                        'SolarisNoFrameskip-v4': None,
    #                        'VentureNoFrameskip-v4': 200,
    #                        'PitfallNoFrameskip-v4': 100,
    #                        }[env_id]
    venv.score_multiple = 1
    venv.record_obs = True if env_id == "SolarisNoFrameskip-v4" else False
    ob_space = venv.observation_space
    ac_space = venv.action_space
    gamma = hps.pop("gamma")
    policy = {
        "rnn": CnnGruPolicy,
        "cnn": CnnPolicy,
        'ffnn': GruPolicy
    }[hps.pop("policy")]
    agent = PpoAgent(
        scope="ppo",
        ob_space=ob_space,
        ac_space=ac_space,
        stochpol_fn=functools.partial(
            policy,
            scope="pol",
            ob_space=ob_space,
            ac_space=ac_space,
            update_ob_stats_independently_per_gpu=hps.pop(
                "update_ob_stats_independently_per_gpu"),
            proportion_of_exp_used_for_predictor_update=hps.pop(
                "proportion_of_exp_used_for_predictor_update"),
            dynamics_bonus=hps.pop("dynamics_bonus"),
            meta_rl=hps['meta_rl']),
        gamma=gamma,
        gamma_ext=hps.pop("gamma_ext"),
        lam=hps.pop("lam"),
        nepochs=hps.pop("nepochs"),
        nminibatches=hps.pop("nminibatches"),
        lr=hps.pop("lr"),
        cliprange=0.1,
        nsteps=128,
        ent_coef=0.001,
        max_grad_norm=hps.pop("max_grad_norm"),
        use_news=hps.pop("use_news"),
        comm=MPI.COMM_WORLD if MPI.COMM_WORLD.Get_size() > 1 else None,
        update_ob_stats_every_step=hps.pop("update_ob_stats_every_step"),
        int_coeff=hps.pop("int_coeff"),
        ext_coeff=hps.pop("ext_coeff"),
        meta_rl=hps.pop('meta_rl'))
    agent.start_interaction([venv])
    if hps.pop("update_ob_stats_from_random_agent"):
        agent.collect_random_statistics(num_timesteps=128 * 50)
    assert len(hps) == 0, f"Unused hyperparameters: {list(hps.keys())}"

    counter = 0
    while True:
        info = agent.step()
        if info["update"]:
            logger.logkvs(info["update"])
            logger.dumpkvs()
            counter += 1
        if agent.I.stats["tcount"] > num_timesteps:
            break

    agent.stop_interaction()
Example #19
File: train.py Project: minqi/auto-drac
def train(args):
    xpid = '-{}-{}-reproduce-s{}'.format(args.run_name, args.env_name,
                                         args.seed)

    wandb.init(project=args.wandb_project,
               entity=args.wandb_entity,
               config=args,
               name=xpid)
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    if args.cuda:
        print('Using CUDA')
    else:
        print('Not using CUDA')

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    log_dir = os.path.expanduser(args.log_dir)
    utils.cleanup_log_dir(log_dir)

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    log_file = xpid

    venv = ProcgenEnv(num_envs=args.num_processes, env_name=args.env_name, \
        num_levels=args.num_levels, start_level=args.start_level, \
        distribution_mode=args.distribution_mode)
    venv = VecExtractDictObs(venv, "rgb")
    venv = VecMonitor(venv=venv, filename=None, keep_buf=100)
    venv = VecNormalize(venv=venv, ob=False)
    envs = VecPyTorchProcgen(venv, device)

    obs_shape = envs.observation_space.shape
    actor_critic = Policy(obs_shape,
                          envs.action_space.n,
                          base_kwargs={
                              'recurrent': False,
                              'hidden_size': args.hidden_size
                          })
    actor_critic.to(device)

    rollouts = RolloutStorage(args.num_steps,
                              args.num_processes,
                              envs.observation_space.shape,
                              envs.action_space,
                              actor_critic.recurrent_hidden_state_size,
                              aug_type=args.aug_type,
                              split_ratio=args.split_ratio)

    batch_size = int(args.num_processes * args.num_steps / args.num_mini_batch)

    if args.use_ucb:
        print('Using UCB')
        aug_id = data_augs.Identity
        aug_list = [
            aug_to_func[t](batch_size=batch_size)
            for t in list(aug_to_func.keys())
        ]

        agent = algo.UCBDrAC(actor_critic,
                             args.clip_param,
                             args.ppo_epoch,
                             args.num_mini_batch,
                             args.value_loss_coef,
                             args.entropy_coef,
                             lr=args.lr,
                             eps=args.eps,
                             max_grad_norm=args.max_grad_norm,
                             aug_list=aug_list,
                             aug_id=aug_id,
                             aug_coef=args.aug_coef,
                             num_aug_types=len(list(aug_to_func.keys())),
                             ucb_exploration_coef=args.ucb_exploration_coef,
                             ucb_window_length=args.ucb_window_length)

    elif args.use_meta_learning:
        aug_id = data_augs.Identity
        aug_list = [aug_to_func[t](batch_size=batch_size) \
            for t in list(aug_to_func.keys())]

        aug_model = AugCNN()
        aug_model.to(device)

        agent = algo.MetaDrAC(actor_critic,
                              aug_model,
                              args.clip_param,
                              args.ppo_epoch,
                              args.num_mini_batch,
                              args.value_loss_coef,
                              args.entropy_coef,
                              meta_grad_clip=args.meta_grad_clip,
                              meta_num_train_steps=args.meta_num_train_steps,
                              meta_num_test_steps=args.meta_num_test_steps,
                              lr=args.lr,
                              eps=args.eps,
                              max_grad_norm=args.max_grad_norm,
                              aug_id=aug_id,
                              aug_coef=args.aug_coef)

    elif args.use_rl2:
        aug_id = data_augs.Identity
        aug_list = [
            aug_to_func[t](batch_size=batch_size)
            for t in list(aug_to_func.keys())
        ]

        rl2_obs_shape = [envs.action_space.n + 1]
        rl2_learner = Policy(rl2_obs_shape,
                             len(list(aug_to_func.keys())),
                             base_kwargs={
                                 'recurrent': True,
                                 'hidden_size': args.rl2_hidden_size
                             })
        rl2_learner.to(device)

        agent = algo.RL2DrAC(actor_critic,
                             rl2_learner,
                             args.clip_param,
                             args.ppo_epoch,
                             args.num_mini_batch,
                             args.value_loss_coef,
                             args.entropy_coef,
                             args.rl2_entropy_coef,
                             lr=args.lr,
                             eps=args.eps,
                             rl2_lr=args.rl2_lr,
                             rl2_eps=args.rl2_eps,
                             max_grad_norm=args.max_grad_norm,
                             aug_list=aug_list,
                             aug_id=aug_id,
                             aug_coef=args.aug_coef,
                             num_aug_types=len(list(aug_to_func.keys())),
                             recurrent_hidden_size=args.rl2_hidden_size,
                             num_actions=envs.action_space.n,
                             device=device)

    else:
        aug_id = data_augs.Identity
        aug_func = aug_to_func[args.aug_type](batch_size=batch_size)

        agent = algo.DrAC(actor_critic,
                          args.clip_param,
                          args.ppo_epoch,
                          args.num_mini_batch,
                          args.value_loss_coef,
                          args.entropy_coef,
                          lr=args.lr,
                          eps=args.eps,
                          max_grad_norm=args.max_grad_norm,
                          aug_id=aug_id,
                          aug_func=aug_func,
                          aug_coef=args.aug_coef,
                          env_name=args.env_name)

    checkpoint_path = os.path.join(args.save_dir, "agent" + log_file + ".pt")
    if os.path.exists(checkpoint_path) and args.preempt:
        checkpoint = torch.load(checkpoint_path)
        agent.actor_critic.load_state_dict(checkpoint['model_state_dict'])
        agent.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        init_epoch = checkpoint['epoch'] + 1
        logger.configure(dir=args.log_dir,
                         format_strs=['csv', 'stdout'],
                         log_suffix=log_file + "-e%s" % init_epoch)
    else:
        init_epoch = 0
        logger.configure(dir=args.log_dir,
                         format_strs=['csv', 'stdout'],
                         log_suffix=log_file)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)
    num_updates = int(
        args.num_env_steps) // args.num_steps // args.num_processes

    for j in range(init_epoch, num_updates):
        actor_critic.train()
        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                obs_id = aug_id(rollouts.obs[step])
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    obs_id, rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)

            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])

            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks, bad_masks)

        with torch.no_grad():
            obs_id = aug_id(rollouts.obs[-1])
            next_value = actor_critic.get_value(
                obs_id, rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        rollouts.compute_returns(next_value, args.gamma, args.gae_lambda)

        if args.use_ucb and j > 0:
            agent.update_ucb_values(rollouts)
        value_loss, action_loss, dist_entropy = agent.update(rollouts)
        rollouts.after_update()

        # save for every interval-th episode or for the last epoch
        total_num_steps = (j + 1) * args.num_processes * args.num_steps
        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            print(
                "\nUpdate {}, step {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, "
                "dist entropy {:.1f}, value loss {:.1f}, action loss {:.1f}"
                .format(j, total_num_steps, len(episode_rewards),
                        np.mean(episode_rewards), np.median(episode_rewards),
                        dist_entropy, value_loss, action_loss))

            ### Eval on the Full Distribution of Levels ###
            eval_episode_rewards = evaluate(args,
                                            actor_critic,
                                            device,
                                            aug_id=aug_id)

            stats = {
                "train/nupdates": j,
                "train/total_num_steps": total_num_steps,
                "losses/dist_entropy": dist_entropy,
                "losses/value_loss": value_loss,
                "losses/action_loss": action_loss,
                "train/mean_episode_reward": np.mean(episode_rewards),
                "train/median_episode_reward": np.median(episode_rewards),
                "test/mean_episode_reward": np.mean(eval_episode_rewards),
                "test/median_episode_reward": np.median(eval_episode_rewards)
            }

            logger.logkvs(stats)
            logger.dumpkvs()
            wandb.log(stats)

        # Save Model
        if (j > 0 and j % args.save_interval == 0
                or j == num_updates - 1) and args.save_dir != "":
            try:
                os.makedirs(args.save_dir)
            except OSError:
                pass

            torch.save(
                {
                    'epoch': j,
                    'model_state_dict': agent.actor_critic.state_dict(),
                    'optimizer_state_dict': agent.optimizer.state_dict(),
                }, os.path.join(args.save_dir, "agent" + log_file + ".pt"))
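A hedged, self-contained sketch of the preemption-safe checkpoint pattern used above: model, optimizer, and epoch are saved together, and training resumes from epoch + 1 when the file exists (the model and path here are placeholders).

import os
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
checkpoint_path = "/tmp/agent_demo.pt"

init_epoch = 0
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    init_epoch = checkpoint['epoch'] + 1

for epoch in range(init_epoch, init_epoch + 3):
    # ... one training epoch would run here ...
    torch.save({'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()},
               checkpoint_path)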
Example #20
def train(*, env_id, num_env, hps, num_timesteps, seed, use_reward, ep_path,
          dmlab):
    venv = VecFrameStack(
        DummyVecEnv([
            lambda: CollectGymDataset(
                Minecraft('MineRLTreechop-v0', 'treechop'),
                ep_path,
                atari=False)
        ]), hps.pop('frame_stack'))
    venv.score_multiple = 1
    venv.record_obs = True if env_id == 'SolarisNoFrameskip-v4' else False
    ob_space = venv.observation_space
    ac_space = venv.action_space
    gamma = hps.pop('gamma')
    policy = {'rnn': CnnGruPolicy, 'cnn': CnnPolicy}[hps.pop('policy')]
    print('Running train ==========================================')
    agent = PpoAgent(
        scope='ppo',
        ob_space=ob_space,
        ac_space=ac_space,
        stochpol_fn=functools.partial(
            policy,
            scope='pol',
            ob_space=ob_space,
            ac_space=ac_space,
            update_ob_stats_independently_per_gpu=hps.pop(
                'update_ob_stats_independently_per_gpu'),
            proportion_of_exp_used_for_predictor_update=hps.pop(
                'proportion_of_exp_used_for_predictor_update'),
            dynamics_bonus=hps.pop("dynamics_bonus")),
        gamma=gamma,
        gamma_ext=hps.pop('gamma_ext'),
        lam=hps.pop('lam'),
        nepochs=hps.pop('nepochs'),
        nminibatches=hps.pop('nminibatches'),
        lr=hps.pop('lr'),
        cliprange=0.1,
        nsteps=128,
        ent_coef=0.001,
        max_grad_norm=hps.pop('max_grad_norm'),
        use_news=hps.pop("use_news"),
        comm=MPI.COMM_WORLD if MPI.COMM_WORLD.Get_size() > 1 else None,
        update_ob_stats_every_step=hps.pop('update_ob_stats_every_step'),
        int_coeff=hps.pop('int_coeff'),
        ext_coeff=hps.pop('ext_coeff'))
    agent.start_interaction([venv])
    if hps.pop('update_ob_stats_from_random_agent'):
        agent.collect_random_statistics(num_timesteps=128 * 50)
    assert len(hps) == 0, "Unused hyperparameters: %s" % list(hps.keys())

    counter = 0
    while True:
        info = agent.step()
        if info['update']:
            logger.logkvs(info['update'])
            logger.dumpkvs()
            counter += 1
        if agent.I.stats['tcount'] > num_timesteps:
            break

    agent.stop_interaction()
Example #21
def train(*, env_id, num_env, hps, num_timesteps, seed):

    venv = VecFrameStack(
        make_atari_env(env_id,
                       num_env,
                       seed,
                       wrapper_kwargs=dict(),
                       start_index=num_env * MPI.COMM_WORLD.Get_rank(),
                       max_episode_steps=hps.pop('max_episode_steps')),
        hps.pop('frame_stack'))

    # Size of states when stored in the memory.
    only_train_r = hps.pop('only_train_r')

    online_r_training = hps.pop('online_train_r') or only_train_r

    r_network_trainer = None
    save_path = hps.pop('save_path')
    r_network_weights_path = hps.pop('r_path')
    '''
    ec_type = 'none' # hps.pop('ec_type')

    venv = CuriosityEnvWrapperFrameStack(
        make_atari_env(env_id, num_env, seed, wrapper_kwargs=dict(),
                       start_index=num_env * MPI.COMM_WORLD.Get_rank(),
                       max_episode_steps=hps.pop('max_episode_steps')),
        vec_episodic_memory = None,
        observation_embedding_fn = None,
        exploration_reward = ec_type,
        exploration_reward_min_step = 0,
        nstack = hps.pop('frame_stack'),
        only_train_r = only_train_r
        )
    '''

    # venv.score_multiple = {'Mario': 500,
    #                        'MontezumaRevengeNoFrameskip-v4': 100,
    #                        'GravitarNoFrameskip-v4': 250,
    #                        'PrivateEyeNoFrameskip-v4': 500,
    #                        'SolarisNoFrameskip-v4': None,
    #                        'VentureNoFrameskip-v4': 200,
    #                        'PitfallNoFrameskip-v4': 100,
    #                        }[env_id]
    venv.score_multiple = 1
    venv.record_obs = True if env_id == 'SolarisNoFrameskip-v4' else False
    ob_space = venv.observation_space
    ac_space = venv.action_space
    gamma = hps.pop('gamma')

    log_interval = hps.pop('log_interval')

    nminibatches = hps.pop('nminibatches')

    play = hps.pop('play')

    if play:
        nsteps = 1

    rnd_type = hps.pop('rnd_type')
    div_type = hps.pop('div_type')

    num_agents = hps.pop('num_agents')

    load_ram = hps.pop('load_ram')

    debug = hps.pop('debug')

    rnd_mask_prob = hps.pop('rnd_mask_prob')

    rnd_mask_type = hps.pop('rnd_mask_type')

    indep_rnd = hps.pop('indep_rnd')
    logger.info("indep_rnd:", indep_rnd)
    indep_policy = hps.pop('indep_policy')

    sd_type = hps.pop('sd_type')

    from_scratch = hps.pop('from_scratch')

    use_kl = hps.pop('use_kl')

    save_interval = 100

    policy = {'rnn': CnnGruPolicy, 'cnn': CnnPolicy}[hps.pop('policy')]
    agent = PpoAgent(
        scope='ppo',
        ob_space=ob_space,
        ac_space=ac_space,
        stochpol_fn=functools.partial(
            policy,
            scope='pol',
            ob_space=ob_space,
            ac_space=ac_space,
            update_ob_stats_independently_per_gpu=hps.pop(
                'update_ob_stats_independently_per_gpu'),
            proportion_of_exp_used_for_predictor_update=hps.pop(
                'proportion_of_exp_used_for_predictor_update'),
            dynamics_bonus=hps.pop("dynamics_bonus"),
            num_agents=num_agents,
            rnd_type=rnd_type,
            div_type=div_type,
            indep_rnd=indep_rnd,
            indep_policy=indep_policy,
            sd_type=sd_type,
            rnd_mask_prob=rnd_mask_prob),
        gamma=gamma,
        gamma_ext=hps.pop('gamma_ext'),
        gamma_div=hps.pop('gamma_div'),
        lam=hps.pop('lam'),
        nepochs=hps.pop('nepochs'),
        nminibatches=nminibatches,
        lr=hps.pop('lr'),
        cliprange=0.1,
        nsteps=5 if debug else 128,
        ent_coef=0.001,
        max_grad_norm=hps.pop('max_grad_norm'),
        use_news=hps.pop("use_news"),
        comm=MPI.COMM_WORLD if MPI.COMM_WORLD.Get_size() > 1 else None,
        update_ob_stats_every_step=hps.pop('update_ob_stats_every_step'),
        int_coeff=hps.pop('int_coeff'),
        ext_coeff=hps.pop('ext_coeff'),
        log_interval=log_interval,
        only_train_r=only_train_r,
        rnd_type=rnd_type,
        reset=hps.pop('reset'),
        dynamics_sample=hps.pop('dynamics_sample'),
        save_path=save_path,
        num_agents=num_agents,
        div_type=div_type,
        load_ram=load_ram,
        debug=debug,
        rnd_mask_prob=rnd_mask_prob,
        rnd_mask_type=rnd_mask_type,
        sd_type=sd_type,
        from_scratch=from_scratch,
        use_kl=use_kl,
        indep_rnd=indep_rnd)

    load_path = hps.pop('load_path')
    base_load_path = hps.pop('base_load_path')

    agent.start_interaction([venv])
    if load_path is not None:

        if play:
            agent.load(load_path)
        else:
            #agent.load(load_path)
            #agent.load_help_info(0, load_path)
            #agent.load_help_info(1, load_path)

            #load diversity agent
            #base_agent_idx = 1
            #logger.info("load base  agents weights from {}  agent {}".format(base_load_path, str(base_agent_idx)))
            #agent.load_agent(base_agent_idx, base_load_path)
            #agent.clone_baseline_agent(base_agent_idx)
            #agent.load_help_info(0, dagent_load_path)
            #agent.clone_agent(0)

            # load main agent 1
            src_agent_idx = 1

            logger.info("load main agent weights from {} agent {}".format(
                load_path, str(src_agent_idx)))
            agent.load_agent(src_agent_idx, load_path)

            if indep_rnd == False:
                rnd_agent_idx = 1
            else:
                rnd_agent_idx = src_agent_idx
            #rnd_agent_idx = 0
            logger.info("load rnd weights from {} agent {}".format(
                load_path, str(rnd_agent_idx)))
            agent.load_rnd(rnd_agent_idx, load_path)
            agent.clone_agent(rnd_agent_idx,
                              rnd=True,
                              policy=False,
                              help_info=False)

            logger.info("load help info from {} agent {}".format(
                load_path, str(src_agent_idx)))
            agent.load_help_info(src_agent_idx, load_path)

            agent.clone_agent(src_agent_idx,
                              rnd=False,
                              policy=True,
                              help_info=True)

            #logger.info("load main agent weights from {} agent {}".format(load_path, str(2)))

            #load_path = '/data/xupeifrom7700_1000/seed1_log0.5_clip-0.5~0.5_3agent_hasint4_2divrew_-1~1/models'
            #agent.load_agent(1, load_path)

            #agent.clone_baseline_agent()
            #if sd_type =='sd':
            #    agent.load_sd("save_dir/models_sd_trained")

        #agent.initialize_discriminator()

    update_ob_stats_from_random_agent = hps.pop(
        'update_ob_stats_from_random_agent')
    if play == False:

        if load_path is not None:
            pass  #agent.collect_statistics_from_model()
        else:
            if update_ob_stats_from_random_agent and rnd_type == 'rnd':
                agent.collect_random_statistics(num_timesteps=128 *
                                                5 if debug else 128 * 50)
        assert len(hps) == 0, "Unused hyperparameters: %s" % list(hps.keys())

        #agent.collect_rnd_info(128*50)
        '''
        if sd_type=='sd':
            agent.train_sd(max_nepoch=300, max_neps=5)
            path = '{}_sd_trained'.format(save_path)
            logger.log("save model:",path)
            agent.save(path)

            return
            #agent.update_diverse_agent(max_nepoch=1000)
            #path = '{}_divupdated'.format(save_path)
            #logger.log("save model:",path)
            #agent.save(path)
        '''
        counter = 0
        while True:
            info = agent.step()

            n_updates = agent.I.stats["n_updates"]
            if info['update']:
                logger.logkvs(info['update'])
                logger.dumpkvs()
                counter += 1

            if info['update'] and save_path is not None and (
                    n_updates % save_interval == 0 or n_updates == 1):
                path = '{}_{}'.format(save_path, str(n_updates))
                logger.log("save model:", path)
                agent.save(path)
                agent.save_help_info(save_path, n_updates)

            if agent.I.stats['tcount'] > num_timesteps:
                path = '{}_{}'.format(save_path, str(n_updates))
                logger.log("save model:", path)
                agent.save(path)
                agent.save_help_info(save_path, n_updates)
                break
        agent.stop_interaction()
    else:
        '''
        check_point_rews_list_path ='{}_rewslist'.format(load_path)
        check_point_rnd_path ='{}_rnd'.format(load_path)
        oracle_rnd = oracle.OracleExplorationRewardForAllEpisodes()
        oracle_rnd.load(check_point_rnd_path)
        #print(oracle_rnd._collected_positions_writer)
        #print(oracle_rnd._collected_positions_reader)

        rews_list = load_rews_list(check_point_rews_list_path)
        print(rews_list)
        '''

        istate = agent.stochpol.initial_state(1)
        #ph_mean, ph_std = agent.stochpol.get_ph_mean_std()

        last_obs, prevrews, ec_rews, news, infos, ram_states, _ = agent.env_get(
            0)
        agent.I.step_count += 1

        flag = False
        show_cam = True

        last_xr = 0

        restore = None
        '''
        #path = 'ram_state_500_7room'
        #path='ram_state_400_6room'
        #path='ram_state_6700' 
        path='ram_state_7700_10room'
        f = open(path,'rb')
        restore = pickle.load(f)
        f.close()
        last_obs[0] = agent.I.venvs[0].restore_full_state_by_idx(restore,0)
        print(last_obs.shape)

        #path = 'ram_state_400_monitor_rews_6room'
        #path = 'ram_state_500_monitor_rews_7room'
        #path='ram_state_6700_monitor_rews'
        path='ram_state_7700_monitor_rews_10room'
        f = open(path,'rb')
        monitor_rews = pickle.load(f)
        f.close()
        
        agent.I.venvs[0].set_cur_monitor_rewards_by_idx(monitor_rews,0)
        '''

        agent_idx = np.asarray([0])
        sample_agent_prob = np.asarray([0.5])

        ph_mean = agent.stochpol.ob_rms_list[0].mean
        ph_std = agent.stochpol.ob_rms_list[0].var**0.5

        buf_ph_mean = np.zeros(
            ([1, 1] + list(agent.stochpol.ob_space.shape[:2]) + [1]),
            np.float32)
        buf_ph_std = np.zeros(
            ([1, 1] + list(agent.stochpol.ob_space.shape[:2]) + [1]),
            np.float32)

        buf_ph_mean[0, 0] = ph_mean
        buf_ph_std[0, 0] = ph_std

        vpreds_ext_list = []

        ep_rews = np.zeros((1))
        divexp_flag = False
        step_count = 0
        stage_prob = True

        last_rew_ob = np.full_like(last_obs, 128)

        clusters = Clusters(1.0)

        #path = '{}_sd_rms'.format(load_path)
        #agent.I.sd_rms.load(path)

        while True:

            dict_obs = agent.stochpol.ensure_observation_is_dict(last_obs)

            #acs= np.random.randint(low=0, high=15, size=(1))
            acs, vpreds_int, vpreds_ext, nlps, istate, ent = agent.stochpol.call(
                dict_obs, news, istate, agent_idx[:, None])

            step_acs = acs
            t = ''
            #if show_cam==True:
            t = input("input:")
            if t != '':
                t = int(t)
                if t <= 17:
                    step_acs = [t]

            agent.env_step(0, step_acs)

            obs, prevrews, ec_rews, news, infos, ram_states, monitor_rews = agent.env_get(
                0)

            if news[0] and restore is not None:
                obs[0] = agent.I.venvs[0].restore_full_state_by_idx(restore, 0)
                agent.I.venvs[0].set_cur_monitor_rewards_by_idx(
                    monitor_rews, 0)

            ep_rews = ep_rews + prevrews

            print(ep_rews)

            last_rew_ob[prevrews > 0] = obs[prevrews > 0]

            room = infos[0]['position'][2]
            vpreds_ext_list.append([vpreds_ext, room])
            #print(monitor_rews[0])
            #print(len(monitor_rews[0]))
            #print(infos[0]['open_door_type'])

            stack_obs = np.concatenate([last_obs[:, None], obs[:, None]], 1)

            fd = {}

            fd[agent.stochpol.ph_ob[None]] = stack_obs

            fd.update({
                agent.stochpol.sep_ph_mean: buf_ph_mean,
                agent.stochpol.sep_ph_std: buf_ph_std
            })
            fd[agent.stochpol.ph_agent_idx] = agent_idx[:, None]
            fd[agent.stochpol.sample_agent_prob] = sample_agent_prob[:, None]

            fd[agent.stochpol.last_rew_ob] = last_rew_ob[:, None]
            fd[agent.stochpol.game_score] = ep_rews[:, None]

            fd[agent.stochpol.sd_ph_mean] = agent.I.sd_rms.mean
            fd[agent.stochpol.sd_ph_std] = agent.I.sd_rms.var**0.5

            div_prob = 0

            all_div_prob = tf_util.get_session().run(
                [agent.stochpol.all_div_prob], fd)
            '''
            if prevrews[0] > 0:
                clusters.update(rnd_em,room)
    
                num_clusters = len(clusters._cluster_list)
                for i in range(num_clusters):
                    print("{} {}".format(str(i),list(clusters._room_set[i])))
            '''
            print("vpreds_int: ", vpreds_int, "vpreds_ext:", vpreds_ext,
                  "ent:", ent, "all_div_prob:", all_div_prob, "room:", room,
                  "step_count:", step_count)

            #aaaa = np.asarray(vpreds_ext_list)
            #print(aaaa[-100:])
            '''
            if step_acs[0]==0:
                ram_state = ram_states[0]
                path='ram_state_7700_10room'
                f = open(path,'wb')
                pickle.dump(ram_state,f)
                f.close()

                path='ram_state_7700_monitor_rews_10room'
                f = open(path,'wb')
                pickle.dump(monitor_rews[0],f)
                f.close()
            '''
            '''
            if  restore is None:
                restore = ram_states[0]

            
            if np.random.rand() < 0.1:
                print("restore")
                obs = agent.I.venvs[0].restore_full_state_by_idx(restore,0)
                prevrews = None
                ec_rews = None
                news= True
                infos = {}
                ram_states = ram_states[0]

                #restore = ram_states[0]
            '''

            img = agent.I.venvs[0].render()

            last_obs = obs

            step_count = step_count + 1

            time.sleep(0.04)