Example #1
def test(v):
    # logger.record_tabular('tabular', 1)
    # logger.dump_tabular(with_prefix=False)
    # logger.log('Try this log')
    logger.log(logger.get_snapshot_dir())
    # logger.log(resource_manager.get_file('policy_validation_inits_swimmer_rllab.save'))
    # logger.log(os.path.join(config.PROJECT_PATH,'sandbox/thanard/bootstrapping/data/policy_validation_inits_swimmer_rllab.save'))
    # print('how about print?')
    import pickle
    dictionary = {'a': 1, 'b': 2}

    # Loading
    filename = os.path.join(config.PROJECT_PATH, 'data_upload/policy_validation_inits_swimmer_rllab.save')
    # filename = os.path.join(logger.get_snapshot_dir(), 'sandbox/thanard/bootstrapping/data/policy_validation_inits_swimmer_rllab.save')
    # filename =  resource_manager.get_file('policy_validation_inits_swimmer_rllab.save')
    with open(filename, 'rb') as f:
        x = pickle.load(f)
    logger.log(str(x))

    # Saving
    os.makedirs(os.path.join(logger.get_snapshot_dir(), 'my_dict'), exist_ok=True)
    with open(os.path.join(logger.get_snapshot_dir(), 'my_dict/mydict.pkl'), 'wb') as f:
        pickle.dump(dictionary, f)

    # Loading
    with open(os.path.join(logger.get_snapshot_dir(), 'my_dict/mydict.pkl'), 'rb') as f:
        x = pickle.load(f)
    logger.log(str(x))
Example #2
    def __init__(self, env_name, record_video=True, video_schedule=None, log_dir=None, record_log=True,
                 force_reset=False):
        if log_dir is None:
            if logger.get_snapshot_dir() is None:
                logger.log("Warning: skipping Gym environment monitoring since snapshot_dir not configured.")
            else:
                log_dir = os.path.join(logger.get_snapshot_dir(), "gym_log")
        Serializable.quick_init(self, locals())

        env = gym.envs.make(env_name)
        self.env = env
        self.env_id = env.spec.id

        monitor_manager.logger.setLevel(logging.WARNING)

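        # Video recording requires log recording to be enabled (see the assert below).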
        assert not (not record_log and record_video)

        if log_dir is None or record_log is False:
            self.monitoring = False
        else:
            if not record_video:
                video_schedule = NoVideoSchedule()
            else:
                if video_schedule is None:
                    video_schedule = CappedCubicVideoSchedule()
            self.env = gym.wrappers.Monitor(self.env, log_dir, video_callable=video_schedule, force=True)
            self.monitoring = True

        self._observation_space = convert_gym_space(env.observation_space)
        self._action_space = convert_gym_space(env.action_space)
        self._horizon = env.spec.timestep_limit
        self._log_dir = log_dir
        self._force_reset = force_reset
Example #3
    def __init__(self, env, record_video=True, video_schedule=None, log_dir=None, record_log=True,
                 force_reset=False):
        if log_dir is None:
            if logger.get_snapshot_dir() is None:
                logger.log("Warning: skipping Gym environment monitoring since snapshot_dir not configured.")
            else:
                log_dir = os.path.join(logger.get_snapshot_dir(), "gym_log")
        Serializable.quick_init(self, locals())

        self.env = env
        # self.env_id = env.spec.id

        assert not (not record_log and record_video)

        if log_dir is None or record_log is False:
            self.monitoring = False
        else:
            if not record_video:
                video_schedule = NoVideoSchedule()
            else:
                if video_schedule is None:
                    video_schedule = CappedCubicVideoSchedule()
            self.env = gym.wrappers.Monitor(self.env, log_dir, video_callable=video_schedule, force=True)
            self.monitoring = True

        self._observation_space = convert_gym_space(env.observation_space)
        logger.log("observation space: {}".format(self._observation_space))
        self._action_space = convert_gym_space(env.action_space)
        logger.log("action space: {}".format(self._action_space))
        # self._horizon = env.spec.tags['wrapper_config.TimeLimit.max_episode_steps']
        self._log_dir = log_dir
        self._force_reset = force_reset
Example #4
    def __init__(self, env_name, record_video=True, video_schedule=None, log_dir=None, record_log=True,
                 force_reset=False):
        if log_dir is None:
            if logger.get_snapshot_dir() is None:
                logger.log("Warning: skipping Gym environment monitoring since snapshot_dir not configured.")
            else:
                log_dir = os.path.join(logger.get_snapshot_dir(), "gym_log")
        Serializable.quick_init(self, locals())

        env = gym.envs.make(env_name)
        self.env = env
        self.env_id = env.spec.id

        monitor_manager.logger.setLevel(logging.WARNING)

        assert not (not record_log and record_video)

        if log_dir is None or record_log is False:
            self.monitoring = False
        else:
            if not record_video:
                video_schedule = NoVideoSchedule()
            else:
                if video_schedule is None:
                    video_schedule = CappedCubicVideoSchedule()
            self.env = gym.wrappers.Monitor(self.env, log_dir, video_callable=video_schedule, force=True)
            self.monitoring = True

        self._observation_space = convert_gym_space(env.observation_space)
        self._action_space = convert_gym_space(env.action_space)
        self._horizon = env.spec.timestep_limit
        self._log_dir = log_dir
        self._force_reset = force_reset
Example #5
    def __init__(self,
                 env_name,
                 record_video=False,
                 video_schedule=None,
                 log_dir=None,
                 record_log=False,
                 force_reset=True):
        if log_dir is None:
            if logger.get_snapshot_dir() is None:
                logger.log(
                    "Warning: skipping Gym environment monitoring since snapshot_dir not configured."
                )
            else:
                log_dir = os.path.join(logger.get_snapshot_dir(), "gym_log")
        Serializable.quick_init(self, locals())

        if env_name == 'humanoid-rllab':
            from sac.envs.fully_observable_humanoid import FullyObservableHumanoid
            env = FullyObservableHumanoid()
        elif env_name == 'ant':
            from sac.envs.fully_observable_ant import FullyObservableAnt
            env = FullyObservableAnt()
        elif env_name == 'hc':
            from sac.envs.fully_observable_half_cheetah import FullyObservableHalfCheetah
            env = FullyObservableHalfCheetah()
        else:
            raise ValueError('case {} not handled'.format(env_name))

        # HACK: Gets rid of the TimeLimit wrapper that sets 'done = True' when
        # the time limit specified for each environment has been passed and
        # therefore the environment is not Markovian (terminal condition depends
        # on time rather than state).
        # env = env.env

        self._mjcenv = env

        assert not (not record_log and record_video)

        if log_dir is None or record_log is False:
            self.monitoring = False
        else:
            if not record_video:
                video_schedule = NoVideoSchedule()
            else:
                if video_schedule is None:
                    video_schedule = CappedCubicVideoSchedule()
            self._mjcenv = gym.wrappers.Monitor(self._mjcenv,
                                                log_dir,
                                                video_callable=video_schedule,
                                                force=True)
            self.monitoring = True

        self._observation_space = convert_gym_space(env.observation_space)
        logger.log("observation space: {}".format(self._observation_space))
        self._action_space = convert_gym_space(env.action_space)
        logger.log("action space: {}".format(self._action_space))
        # env.spec.tags['wrapper_config.TimeLimit.max_episode_steps']
        self._horizon = 1000
        self._log_dir = log_dir
        self._force_reset = force_reset
Example #6
    def __init__(self,
                 env_name,
                 record_video=True,
                 video_schedule=None,
                 log_dir=None,
                 record_log=True,
                 force_reset=False,
                 screen_width=84,
                 screen_height=84):
        if log_dir is None:
            if logger.get_snapshot_dir() is None:
                logger.log(
                    "Warning: skipping Gym environment monitoring since snapshot_dir not configured."
                )
            else:
                log_dir = os.path.join(logger.get_snapshot_dir(), "gym_log")
        Serializable.quick_init(self, locals())

        env = gym.envs.make(env_name)
        if 'Doom' in env_name:
            from ppaquette_gym_doom.wrappers.action_space import ToDiscrete
            wrapper = ToDiscrete('minimal')
            env = wrapper(env)

        self.env = env
        self.env_id = env.spec.id

        monitor_manager.logger.setLevel(logging.WARNING)

        assert not (not record_log and record_video)

        if log_dir is None or record_log is False:
            self.monitoring = False
        else:
            if not record_video:
                video_schedule = NoVideoSchedule()
            else:
                if video_schedule is None:
                    video_schedule = CappedCubicVideoSchedule()
            self.env = gym.wrappers.Monitor(self.env,
                                            log_dir,
                                            video_callable=video_schedule,
                                            force=True)
            self.monitoring = True

        self._observation_space = convert_gym_space(env.observation_space)
        self._action_space = convert_gym_space(env.action_space)
        self._horizon = env.spec.timestep_limit
        self._log_dir = log_dir
        self._force_reset = force_reset
        self.screen_width = screen_width
        self.screen_height = screen_height
        self._observation_space = Box(low=0,
                                      high=1,
                                      shape=(screen_width, screen_height, 1))
Example #7
    def __init__(self,
                 env_name,
                 record_video=False,
                 video_schedule=None,
                 log_dir=None,
                 record_log=False,
                 force_reset=True):
        super(GymEnv, self).__init__()
        if log_dir is None:
            if logger.get_snapshot_dir() is None:
                logger.log(
                    "Warning: skipping Gym environment monitoring since snapshot_dir not configured."
                )
            else:
                log_dir = os.path.join(logger.get_snapshot_dir(), "gym_log")
        Serializable.quick_init(self, locals())

        env = gym.envs.make(env_name)

        # HACK: Gets rid of the TimeLimit wrapper that sets 'done = True' when
        # the time limit specified for each environment has been passed and
        # therefore the environment is not Markovian (terminal condition depends
        # on time rather than state).
        env = env.env

        self.env = env
        self.env_id = env.spec.id

        assert not (not record_log and record_video)

        if log_dir is None or record_log is False:
            self.monitoring = False
        else:
            if not record_video:
                video_schedule = NoVideoSchedule()
            else:
                if video_schedule is None:
                    video_schedule = CappedCubicVideoSchedule()
            self.env = gym.wrappers.Monitor(self.env,
                                            log_dir,
                                            video_callable=video_schedule,
                                            force=True)
            self.monitoring = True

        self._observation_space = convert_gym_space(env.observation_space)
        logger.log("observation space: {}".format(self._observation_space))
        self._action_space = convert_gym_space(env.action_space)
        logger.log("action space: {}".format(self._action_space))
        self._horizon = env.spec.tags[
            'wrapper_config.TimeLimit.max_episode_steps']
        self._log_dir = log_dir
        self._force_reset = force_reset

        self._obs_space = None
        self._current_obs_dim = None
Example #8
    def __init__(self,
                 env_name,
                 wrappers=(),
                 wrapper_args=(),
                 record_video=True,
                 video_schedule=None,
                 log_dir=None,
                 record_log=True,
                 post_create_env_seed=None,
                 force_reset=False):
        if log_dir is None:
            if logger.get_snapshot_dir() is None:
                logger.log(
                    "Warning: skipping Gym environment monitoring since snapshot_dir not configured."
                )
            else:
                log_dir = os.path.join(logger.get_snapshot_dir(), "gym_log")
        Serializable.quick_init(self, locals())

        env = gym.envs.make(env_name)
        if post_create_env_seed is not None:
            env.set_env_seed(post_create_env_seed)
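        # Apply each wrapper in order; per-wrapper kwargs are used only when one kwargs dict is supplied per wrapper.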
        for i, wrapper in enumerate(wrappers):
            if wrapper_args and len(wrapper_args) == len(wrappers):
                env = wrapper(env, **wrapper_args[i])
            else:
                env = wrapper(env)
        self.env = env
        self.env_id = env.spec.id

        assert not (not record_log and record_video)

        if log_dir is None or record_log is False:
            self.monitoring = False
        else:
            if not record_video:
                video_schedule = NoVideoSchedule()
            else:
                if video_schedule is None:
                    video_schedule = CappedCubicVideoSchedule()
            # self.env = gym.wrappers.Monitor(self.env, log_dir, video_callable=video_schedule, force=True)
            self.env = CustomGymMonitorEnv(self.env,
                                           log_dir,
                                           video_callable=video_schedule,
                                           force=True)
            self.monitoring = True

        self._observation_space = convert_gym_space(env.observation_space)
        logger.log("observation space: {}".format(self._observation_space))
        self._action_space = convert_gym_space(env.action_space)
        logger.log("action space: {}".format(self._action_space))
        self._horizon = env.spec.tags.get(
            'wrapper_config.TimeLimit.max_episode_steps')
        self._log_dir = log_dir
        self._force_reset = force_reset
Example #9
    def __init__(self,
                 env_names,
                 record_video=True,
                 video_schedule=None,
                 log_dir=None,
                 record_log=True,
                 force_reset=False):
        if log_dir is None:
            if logger.get_snapshot_dir() is None:
                logger.log(
                    "Warning: skipping Gym environment monitoring since snapshot_dir not configured."
                )
            else:
                log_dir = os.path.join(logger.get_snapshot_dir(), "gym_log")
        Serializable.quick_init(self, locals())

        envs = []
        for name in env_names:
            envs.append(gym.envs.make(name))
        self.envs = envs
        self.envs_id = []
        for env in envs:
            self.envs_id.append(env.spec.id)

        assert not (not record_log and record_video)

        if log_dir is None or record_log is False:
            self.monitoring = False
        else:
            if not record_video:
                video_schedule = NoVideoSchedule()
            else:
                if video_schedule is None:
                    video_schedule = CappedCubicVideoSchedule()
            self.envs = [gym.wrappers.Monitor(env,
                                               log_dir,
                                               video_callable=video_schedule,
                                               force=True)
                         for env in self.envs]
            self.monitoring = True

        # NOTE: the observation_space, action_space and horizon are assumed to be the same across the multiple loaded envs
        self._observation_space = convert_gym_space(envs[0].observation_space,
                                                    len(self.envs))
        logger.log("observation space: {}".format(self._observation_space))
        self._action_space = convert_gym_space(envs[0].action_space)
        logger.log("action space: {}".format(self._action_space))
        self._horizon = envs[0].spec.tags[
            'wrapper_config.TimeLimit.max_episode_steps']
        self._log_dir = log_dir
        self._force_reset = force_reset
        self._current_activated_env = 0

        self.split_task_test = True
        self.avg_div = 0
Example #10
    def restore(self, checkpoint_dir=None):
        if checkpoint_dir is None: checkpoint_dir = logger.get_snapshot_dir()
        checkpoint_file = os.path.join(checkpoint_dir, 'params.chk')
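        # tf.train.Saver writes a '.meta' graph file next to each checkpoint; its presence indicates a restorable checkpoint.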
        if os.path.isfile(checkpoint_file + '.meta'):
            sess = tf.get_default_session()
            saver = tf.train.Saver()
            saver.restore(sess, checkpoint_file)

            tabular_chk_file = os.path.join(checkpoint_dir, 'progress.csv.chk')
            if os.path.isfile(tabular_chk_file):
                tabular_file = os.path.join(checkpoint_dir, 'progress.csv')
                logger.remove_tabular_output(tabular_file)
                shutil.copy(tabular_chk_file, tabular_file)
                logger.add_tabular_output(tabular_file)

            if self.qf is not None:
                pool_file = os.path.join(checkpoint_dir, 'pool.chk')
                if self.save_format == 'pickle':
                    self.pool = pickle_load(pool_file)
                elif self.save_format == 'joblib':
                    self.pool = joblib.load(pool_file)
                else:
                    raise NotImplementedError

            logger.log('Restored from checkpoint %s' % checkpoint_file)
        else:
            logger.log('No checkpoint %s' % checkpoint_file)
Example #11
    def _save_rollouts_file(self, itr, rollouts, eval=False):
        if eval:
            fname = 'itr_{0}_rollouts_eval.pkl'.format(itr)
        else:
            fname = 'itr_{0}_rollouts.pkl'.format(itr)
        fname = os.path.join(logger.get_snapshot_dir(), fname)
        joblib.dump({'rollouts': rollouts}, fname, compress=3)
Example #12
    def __init__(self, algo, args, exp_name):
        self.args = args
        self.algo = algo

        env = algo.env
        baseline = algo.baseline

        # Logger
        default_log_dir = config.LOG_DIR
        if args.log_dir is None:
            log_dir = osp.join(default_log_dir, exp_name)
        else:
            log_dir = args.log_dir

        tabular_log_file = osp.join(log_dir, args.tabular_log_file)
        text_log_file = osp.join(log_dir, args.text_log_file)
        params_log_file = osp.join(log_dir, args.params_log_file)

        logger.log_parameters_lite(params_log_file, args)
        logger.add_text_output(text_log_file)
        logger.add_tabular_output(tabular_log_file)
        logger.set_snapshot_dir(log_dir)
        logger.set_snapshot_mode(args.snapshot_mode)
        logger.set_log_tabular_only(args.log_tabular_only)
        logger.push_prefix("[%s] " % exp_name)

        prev_snapshot_dir = logger.get_snapshot_dir()
        prev_mode = logger.get_snapshot_mode()
Example #13
    def save(self, checkpoint_dir=None):
        if checkpoint_dir is None: checkpoint_dir = logger.get_snapshot_dir()

        if self.qf is not None:
            pool_file = os.path.join(checkpoint_dir, 'pool.chk')
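            # Dump to a '.tmp' file first and move it into place afterwards, so an interrupted save cannot corrupt the existing pool checkpoint.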
            if self.save_format == 'pickle':
                pickle_dump(pool_file + '.tmp', self.pool)
            elif self.save_format == 'joblib':
                joblib.dump(self.pool,
                            pool_file + '.tmp',
                            compress=1,
                            cache_size=1e9)
            else:
                raise NotImplementedError
            shutil.move(pool_file + '.tmp', pool_file)

        checkpoint_file = os.path.join(checkpoint_dir, 'params.chk')
        sess = tf.get_default_session()
        saver = tf.train.Saver()
        saver.save(sess, checkpoint_file)

        tabular_file = os.path.join(checkpoint_dir, 'progress.csv')
        if os.path.isfile(tabular_file):
            tabular_chk_file = os.path.join(checkpoint_dir, 'progress.csv.chk')
            shutil.copy(tabular_file, tabular_chk_file)

        logger.log('Saved to checkpoint %s' % checkpoint_file)
Example #14
def run_task(vv, log_dir=None, exp_name=None):

    # Load environment
    radius = vv['radius']
    target_velocity = vv['target_velocity']
    env = normalize(CircleEnv(radius, target_velocity))

    # Save variant information for comparison plots
    variant_file = logger.get_snapshot_dir() + '/variant.json'
    logger.log_variant(variant_file, vv)

    # Train policy using TRPO
    policy = GaussianMLPPolicy(
        env_spec=env.spec,
        hidden_sizes=(32, 32)
    )
    baseline = LinearFeatureBaseline(env_spec=env.spec)
    algo = TRPO(
        env=env,
        policy=policy,
        baseline=baseline,
        batch_size=1000,
        max_path_length=env.horizon,
        n_itr=1000,
        discount=0.99,
        step_size=0.01,
        plot=False,
    )
    algo.train()
Example #15
    def _evaluate(self, epoch):

        logger.log("Collecting samples for evaluation")
        snapshot_dir = logger.get_snapshot_dir()

        paths = rollouts(self._env, self._eval_policy, self._max_path_length,
                         self._n_eval_episodes)

        average_discounted_return = np.mean([
            special.discount_return(path["rewards"], self._discount)
            for path in paths
        ])
        returns = np.asarray([sum(path["rewards"]) for path in paths])

        statistics = OrderedDict([
            ('Epoch', epoch),
            ('AverageDiscountedReturn', average_discounted_return),
            ('Alpha', self._alpha), ('returns', returns)
        ])

        for key, value in statistics.items():
            logger.record_tabular(key, value)

        self._env.log_diagnostics(paths)

        # Plot test paths.
        if (hasattr(self._env, 'plot_paths')
                and self._env_plot_settings is not None):
            img_file = os.path.join(snapshot_dir, 'env_itr_%05d.png' % epoch)
            # Remove previous paths.
            if self._env_lines is not None:
                [path.remove() for path in self._env_lines]
            self._env_lines = self._env.plot_paths(paths, self._ax_env)
            plt.pause(0.001)
            plt.draw()
            self._fig_env.savefig(img_file, dpi=100)

        # Plot the Q-function level curves and action samples.
        if (hasattr(self._qf_eval, 'plot_level_curves')
                and self._q_plot_settings is not None):
            img_file = os.path.join(snapshot_dir, 'q_itr_%05d.png' % epoch)
            [ax.clear() for ax in self._ax_q_lst]
            self._qf_eval.plot_level_curves(
                ax_lst=self._ax_q_lst,
                observations=self._q_plot_settings['obs_lst'],
                action_dims=self._q_plot_settings['action_dims'],
                xlim=self._q_plot_settings['xlim'],
                ylim=self._q_plot_settings['ylim'],
            )
            self._visualization_policy.plot_samples(
                self._ax_q_lst, self._q_plot_settings['obs_lst'])
            for ax in self._ax_q_lst:
                ax.set_xlim(self._q_plot_settings['xlim'])
                ax.set_ylim(self._q_plot_settings['ylim'])
            plt.pause(0.001)
            plt.draw()
            self._fig_q.savefig(img_file, dpi=100)

        gc.collect()
Example #16
    def __init__(self,
                 env_name,
                 adv_fraction=1.0,
                 record_video=True,
                 video_schedule=None,
                 log_dir=None,
                 record_log=True):
        if log_dir is None:
            if logger.get_snapshot_dir() is None:
                logger.log(
                    "Warning: skipping Gym environment monitoring since snapshot_dir not configured."
                )
            else:
                log_dir = os.path.join(logger.get_snapshot_dir(), "gym_log")
        Serializable.quick_init(self, locals())

        env = gym.envs.make(env_name)
        def_adv = env.adv_action_space.high[0]
        new_adv = def_adv * adv_fraction
        env.update_adversary(new_adv)
        self.env = env
        self.env_id = env.spec.id

        monitor.logger.setLevel(logging.WARNING)

        assert not (not record_log and record_video)

        if log_dir is None or record_log is False:
            self.monitoring = False
        else:
            if not record_video:
                video_schedule = NoVideoSchedule()
            else:
                if video_schedule is None:
                    video_schedule = CappedCubicVideoSchedule()
            self.env.monitor.start(
                log_dir, video_schedule,
                force=True)  # add 'force=True' if want overwrite dirs
            self.monitoring = True

        self._observation_space = convert_gym_space(env.observation_space)
        self._pro_action_space = convert_gym_space(env.pro_action_space)
        self._adv_action_space = convert_gym_space(env.adv_action_space)
        self._horizon = env.spec.timestep_limit
        self._log_dir = log_dir
Example #17
def run_gcg(params):
    # copy yaml for posterity
    try:
        yaml_path = os.path.join(logger.get_snapshot_dir(),
                                 '{0}.yaml'.format(params['exp_name']))
        with open(yaml_path, 'w') as f:
            f.write(params['txt'])
    except:
        pass

    os.environ["CUDA_VISIBLE_DEVICES"] = str(
        params['policy']['gpu_device'])  # TODO: hack so don't double GPU
    config.USE_TF = True

    normalize_env = params['alg'].pop('normalize_env')

    env_str = params['alg'].pop('env')
    env = create_env(env_str, is_normalize=normalize_env, seed=params['seed'])

    env_eval_str = params['alg'].pop('env_eval', env_str)
    env_eval = create_env(env_eval_str,
                          is_normalize=normalize_env,
                          seed=params['seed'])

    env.reset()
    env_eval.reset()

    #####################
    ### Create policy ###
    #####################

    policy_class = params['policy']['class']
    PolicyClass = eval(policy_class)
    policy_params = params['policy'][policy_class]

    policy = PolicyClass(
        env_spec=env.spec,
        exploration_strategies=params['alg'].pop('exploration_strategies'),
        **policy_params,
        **params['policy'])

    ########################
    ### Create algorithm ###
    ########################

    if 'max_path_length' in params['alg']:
        max_path_length = params['alg'].pop('max_path_length')
    else:
        max_path_length = env.horizon
    algo = GCG(env=env,
               env_eval=env_eval,
               policy=policy,
               max_path_length=max_path_length,
               env_str=env_str,
               **params['alg'])
    algo.train()
Example #18
def rllab_logdir(algo=None, dirname=None):
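    # Presumably used as a context manager (wrapped with contextlib.contextmanager where it is defined):
    # it configures tabular logging, yields the log directory, then removes the tabular output on exit.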
    if dirname:
        rllablogger.set_snapshot_dir(dirname)
    dirname = rllablogger.get_snapshot_dir()
    rllablogger.add_tabular_output(os.path.join(dirname, 'progress.csv'))
    if algo:
        with open(os.path.join(dirname, 'params.json'), 'w') as f:
            params = extract_hyperparams(algo)
            json.dump(params, f)
    yield dirname
    rllablogger.remove_tabular_output(os.path.join(dirname, 'progress.csv'))
Example #19
def rllab_logdir(algo=None, dirname=None):
    if dirname:
        rllablogger.set_snapshot_dir(dirname)
    dirname = rllablogger.get_snapshot_dir()
    rllablogger.add_tabular_output(os.path.join(dirname, 'progress.csv'))
    if algo:
        with open(os.path.join(dirname, 'params.json'), 'w') as f:
            params = extract_hyperparams(algo)
            json.dump(params, f)
    yield dirname
    rllablogger.remove_tabular_output(os.path.join(dirname, 'progress.csv'))
Example #20
    def __init__(self,
                 env_name,
                 record_video=True,
                 video_schedule=None,
                 log_dir=None):

        ## following lines modified by me (correspondingly commented out below) to suppress the warning messages
        if log_dir is None and logger.get_snapshot_dir() is not None:
            log_dir = os.path.join(logger.get_snapshot_dir(), "gym_log")

# *********************
#        if log_dir is None:
#            if logger.get_snapshot_dir() is None:
#                logger.log("Warning: skipping Gym environment monitoring since snapshot_dir not configured.")
#            else:
#                log_dir = os.path.join(logger.get_snapshot_dir(), "gym_log")
# *********************

        Serializable.quick_init(self, locals())

        env = gym.envs.make(env_name)
        self.env = env
        self.env_id = env.spec.id

        monitor.logger.setLevel(logging.CRITICAL)

        if log_dir is None:
            self.monitoring = False
        else:
            if not record_video:
                video_schedule = NoVideoSchedule()
            else:
                if video_schedule is None:
                    video_schedule = CappedCubicVideoSchedule()
            self.env.monitor.start(log_dir, video_schedule)
            self.monitoring = True

        self._observation_space = convert_gym_space(env.observation_space)
        self._action_space = convert_gym_space(env.action_space)
        self._horizon = env.spec.timestep_limit
        self._log_dir = log_dir
Example #21
    def __init__(self, env_name, record_video=False, video_schedule=None, log_dir=None, record_log=False,
                 force_reset=True):
        if log_dir is None:
            if logger.get_snapshot_dir() is None:
                logger.log("Warning: skipping Gym environment monitoring since snapshot_dir not configured.")
            else:
                log_dir = os.path.join(logger.get_snapshot_dir(), "gym_log")
        Serializable.quick_init(self, locals())

        env = gym.envs.make(env_name)

        # HACK: Gets rid of the TimeLimit wrapper that sets 'done = True' when
        # the time limit specified for each environment has been passed and
        # therefore the environment is not Markovian (terminal condition depends
        # on time rather than state).
        env = env.env

        self.env = env
        self.env_id = env.spec.id

        assert not (not record_log and record_video)

        if log_dir is None or record_log is False:
            self.monitoring = False
        else:
            if not record_video:
                video_schedule = NoVideoSchedule()
            else:
                if video_schedule is None:
                    video_schedule = CappedCubicVideoSchedule()
            self.env = gym.wrappers.Monitor(self.env, log_dir, video_callable=video_schedule, force=True)
            self.monitoring = True

        self._observation_space = convert_gym_space(env.observation_space)
        logger.log("observation space: {}".format(self._observation_space))
        self._action_space = convert_gym_space(env.action_space)
        logger.log("action space: {}".format(self._action_space))
        self._horizon = env.spec.tags['wrapper_config.TimeLimit.max_episode_steps']
        self._log_dir = log_dir
        self._force_reset = force_reset
Example #22
    def __init__(self, env_name, record_video=True, video_schedule=None, log_dir=None):

## following lines modified by me (correspondingly commented out below) to suppress the warning messages
        if log_dir is None and logger.get_snapshot_dir() is not None:
            log_dir = os.path.join(logger.get_snapshot_dir(), "gym_log")

# *********************
#        if log_dir is None:
#            if logger.get_snapshot_dir() is None:
#                logger.log("Warning: skipping Gym environment monitoring since snapshot_dir not configured.")
#            else:
#                log_dir = os.path.join(logger.get_snapshot_dir(), "gym_log")
# *********************

        Serializable.quick_init(self, locals())

        env = gym.envs.make(env_name)
        self.env = env
        self.env_id = env.spec.id

        #monitor.logger.setLevel(logging.CRITICAL)

        if log_dir is None:
            self.monitoring = False
        else:
            if not record_video:
                video_schedule = NoVideoSchedule()
            else:
                if video_schedule is None:
                    video_schedule = CappedCubicVideoSchedule()
            self.env.monitor.start(log_dir, video_schedule)
            self.monitoring = True

        self._observation_space = convert_gym_space(env.observation_space)
        self._action_space = convert_gym_space(env.action_space)
        self._horizon = env.spec.timestep_limit
        self._log_dir = log_dir
        self.injure_idx = None
Example #23
def run_task(v):
    random.seed(v['seed'])
    np.random.seed(v['seed'])
    logger.log("Initializing report...")
    log_dir = logger.get_snapshot_dir()
    if log_dir is None:
        log_dir = "/home/michael/"
    report = HTMLReport(osp.join(log_dir, 'report.html'), images_per_row=1)

    fixed_goal_generator = FixedStateGenerator(state=v['ultimate_goal'])
    fixed_start_generator = FixedStateGenerator(state=v['ultimate_goal'])

    inner_env = normalize(SwimmerMazeEnv(maze_size_scaling=3))
    env = GoalStartExplorationEnv(
        env=inner_env,
        start_generator=fixed_start_generator,
        obs2start_transform=lambda x: x[:v['start_size']],
        goal_generator=fixed_goal_generator,
        obs2goal_transform=lambda x: x[-3:-1],
        terminal_eps=v['terminal_eps'],
        distance_metric=v['distance_metric'],
        extend_dist_rew=v['extend_dist_rew'],
        inner_weight=v['inner_weight'],
        goal_weight=v['goal_weight'],
        terminate_env=True,
    )

    seed_starts = generate_starts(
        env,
        starts=[v['start_goal']],
        horizon=v['initial_brownian_horizon'],
        size=500,  #TODO: increase to 2000
        # size speeds up training a bit
        variance=v['brownian_variance'],
        subsample=v['num_new_starts'],
    )  # , animated=True, speedup=1)
    np.random.shuffle(seed_starts)

    # with env.set_kill_outside():
    feasible_states = find_all_feasible_states_plotting(
        env,
        seed_starts,
        report,
        distance_threshold=0.2,
        brownian_variance=1,
        animate=True,
        limit=v['goal_range'],
        check_feasible=False,
        center=v['goal_center'])
    return
Example #24
def run_experiment(algo,
                   n_parallel=0,
                   seed=0,
                   plot=False,
                   log_dir=None,
                   exp_name=None,
                   snapshot_mode='last',
                   snapshot_gap=1,
                   exp_prefix='experiment',
                   log_tabular_only=False):
    default_log_dir = config.LOG_DIR + "/local/" + exp_prefix
    set_seed(seed)
    if exp_name is None:
        now = datetime.datetime.now(dateutil.tz.tzlocal())
        timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')
        exp_name = 'experiment_%s' % (timestamp)
    if n_parallel > 0:
        from rllab.sampler import parallel_sampler
        parallel_sampler.initialize(n_parallel=n_parallel)
        parallel_sampler.set_seed(seed)
    if plot:
        from rllab.plotter import plotter
        plotter.init_worker()
    if log_dir is None:
        log_dir = osp.join(default_log_dir, exp_name)
    tabular_log_file = osp.join(log_dir, 'progress.csv')
    text_log_file = osp.join(log_dir, 'debug.log')
    #params_log_file = osp.join(log_dir, 'params.json')

    #logger.log_parameters_lite(params_log_file, args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
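    # Save the previous logger state so it can be restored once training finishes.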
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(snapshot_mode)
    logger.set_snapshot_gap(snapshot_gap)
    logger.set_log_tabular_only(log_tabular_only)
    logger.push_prefix("[%s] " % exp_name)

    algo.train()

    logger.set_snapshot_mode(prev_mode)
    logger.set_snapshot_dir(prev_snapshot_dir)
    logger.remove_tabular_output(tabular_log_file)
    logger.remove_text_output(text_log_file)
    logger.pop_prefix()
Example #25
def find_all_feasible_states(env,
                             seed_starts,
                             distance_threshold=0.1,
                             brownian_variance=1,
                             animate=False,
                             speedup=10,
                             max_states=None,
                             horizon=1000,
                             states_transform=None):
    # states_transform is optional transform of states
    # print('the seed_starts are of shape: ', seed_starts.shape)
    log_dir = logger.get_snapshot_dir()
    if states_transform is not None:
        all_feasible_starts = StateCollection(
            distance_threshold=distance_threshold,
            states_transform=states_transform)
    else:
        all_feasible_starts = StateCollection(
            distance_threshold=distance_threshold)
    all_feasible_starts.append(seed_starts)
    logger.log('finish appending all seed_starts')
    no_new_states = 0
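    # Stop after five rounds that each add fewer than 10 new states (the counter is never reset).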
    while no_new_states < 5:
        total_num_starts = all_feasible_starts.size
        if max_states is not None:
            if total_num_starts > max_states:
                return
        starts = all_feasible_starts.sample(100)
        new_starts = generate_starts(env,
                                     starts=starts,
                                     horizon=horizon,
                                     size=10000,
                                     variance=brownian_variance,
                                     animated=animate,
                                     speedup=speedup)
        logger.log("Done generating new starts")
        all_feasible_starts.append(new_starts, n_process=1)
        num_new_starts = all_feasible_starts.size - total_num_starts
        logger.log("number of new states: {}, total_states: {}".format(
            num_new_starts, all_feasible_starts.size))
        if num_new_starts < 10:
            no_new_states += 1
        with open(osp.join(log_dir, 'all_feasible_states.pkl'), 'wb') as f:
            cloudpickle.dump(all_feasible_starts, f, protocol=3)
Example #26
def setup_logging(log_dir, algo, env, novice, expert, DR):
    tabular_log_file = os.path.join(log_dir, "progress.csv")
    text_log_file = os.path.join(log_dir, "debug.log")
    params_log_file = os.path.join(log_dir, "params.json")
    snapshot_mode = "last"
    snapshot_gap = 1
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(snapshot_mode)
    logger.set_snapshot_gap(snapshot_gap)
    print("Finished setting up logging.")

    # log some stuff
    logger.log("Created algorithm object of type: %s", type(algo))
    logger.log("env of type: %s" % type(env))
    logger.log("novice of type: %s" % type(novice))
    logger.log("expert of type: %s" % type(expert))
    logger.log("decision_rule of type: %s" % type(DR))
    logger.log("DAgger beta decay: %s" % DR.beta_decay)
    logger.log("numtrajs per epoch/itr: %s" % algo.numtrajs)
    logger.log("n_iter: %s" % algo.n_itr)
    logger.log("max path length: %s" % algo.max_path_length)
    logger.log("Optimizer info - ")
    logger.log("Optimizer of type: %s" % type(algo.optimizer))
    if type(algo.optimizer) == FirstOrderOptimizer:
        logger.log("Optimizer of class: %s" %
                   type(algo.optimizer._tf_optimizer))
        logger.log("optimizer learning rate: %s" %
                   algo.optimizer._tf_optimizer._lr)
        logger.log("optimizer max epochs: %s" % algo.optimizer._max_epochs)
        logger.log("optimizer batch size: %s" % algo.optimizer._batch_size)
    elif type(algo.optimizer) == PenaltyLbfgsOptimizer:
        logger.log("initial_penalty %s" % algo.optimizer._initial_penalty)
        logger.log("max_opt_itr %s" % algo.optimizer._max_opt_itr)
        logger.log("max_penalty %s" % algo.optimizer._max_penalty)

    return True
Example #27
def run_task(variant):

    log_dir = logger.get_snapshot_dir()
    report = HTMLReport(os.path.join(log_dir, 'report.html'),
                        images_per_row=2,
                        default_image_width=500)
    report.add_header('Simple Circle Sampling')
    report.add_text(format_dict(variant))
    report.save()

    gan = SimpleGAN(noise_size=5, tf_session=tf.Session())

    rand_theta = np.random.uniform(0, 2 * np.pi, size=(5000, 1))
    data = np.hstack([0.5 * np.cos(rand_theta), 0.5 * np.sin(rand_theta)])
    data = data + np.random.normal(scale=0.05, size=data.shape)

    report.add_image(plot_samples(data[:500, :]), 'Real data')

    generated_samples, _ = gan.sample_generator(100)
    report.add_image(plot_samples(generated_samples))

    for outer_iter in range(30):
        dloss, gloss = gan.train(
            data,
            outer_iters=variant['outer_iters'],
        )
        logger.log('Outer iteration: {}, disc loss: {}, gen loss: {}'.format(
            outer_iter, dloss, gloss))
        report.add_text(
            'Outer iteration: {}, disc loss: {}, gen loss: {}'.format(
                outer_iter, dloss, gloss))
        generated_samples, _ = gan.sample_generator(50)
        report.add_image(plot_samples(generated_samples))
        report.add_image(plot_dicriminator(gan))

        report.save()
Example #28
def run_experiment(argv):

    default_log_dir = config.LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')

    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
    parser = argparse.ArgumentParser()
    parser.add_argument('--n_parallel',
                        type=int,
                        default=1,
                        help='Number of parallel workers to perform rollouts.')
    parser.add_argument('--exp_name',
                        type=str,
                        default=default_exp_name,
                        help='Name of the experiment.')
    parser.add_argument('--log_dir',
                        type=str,
                        default=default_log_dir,
                        help='Path to save the log and iteration snapshot.')
    parser.add_argument('--snapshot_mode',
                        type=str,
                        default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                        '(all iterations will be saved), "last" (only '
                        'the last iteration will be saved), or "none" '
                        '(do not save snapshots)')
    parser.add_argument('--snapshot_gap',
                        type=int,
                        default=1,
                        help='Gap between snapshot iterations.')
    parser.add_argument('--tabular_log_file',
                        type=str,
                        default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file',
                        type=str,
                        default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file',
                        type=str,
                        default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--plot',
                        type=ast.literal_eval,
                        default=False,
                        help='Whether to plot the iteration results')
    parser.add_argument(
        '--log_tabular_only',
        type=ast.literal_eval,
        default=False,
        help=
        'Whether to only print the tabular log information (in a horizontal format)'
    )
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data',
                        type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--use_cloudpickle',
                        type=ast.literal_eval,
                        default=False,
                        help='Whether to plot the iteration results')

    args = parser.parse_args(argv[1:])

    if args.seed is not None:
        set_seed(args.seed)

    if args.n_parallel > 0:
        from sandbox.vase.sampler import parallel_sampler_expl as parallel_sampler
        parallel_sampler.initialize(n_parallel=args.n_parallel)

        if args.seed is not None:
            set_seed(args.seed)
            parallel_sampler.set_seed(args.seed)

    if args.plot:
        from rllab.plotter import plotter
        plotter.init_worker()

    # decode the pickled stub objects passed in via --args_data
    data = pickle.loads(base64.b64decode(args.args_data))

    log_dir = args.log_dir
    # exp_dir = osp.join(log_dir, args.exp_name)
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    logger.log_parameters_lite(params_log_file, args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_gap(args.snapshot_gap)
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    maybe_iter = concretize(data)
    if is_iterable(maybe_iter):
        for _ in maybe_iter:
            pass

    logger.set_snapshot_mode(prev_mode)
    logger.set_snapshot_dir(prev_snapshot_dir)
    logger.remove_tabular_output(tabular_log_file)
    logger.remove_text_output(text_log_file)
    logger.pop_prefix()
Example #29
def main():
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')
    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--exp_name', type=str, default=default_exp_name, help='Name of the experiment.')

    parser.add_argument('--discount', type=float, default=0.99)
    parser.add_argument('--gae_lambda', type=float, default=1.0)
    parser.add_argument('--reward_scale', type=float, default=1.0)

    parser.add_argument('--n_iter', type=int, default=250)
    parser.add_argument('--sampler_workers', type=int, default=1)
    parser.add_argument('--max_traj_len', type=int, default=250)
    parser.add_argument('--update_curriculum', action='store_true', default=False)
    parser.add_argument('--n_timesteps', type=int, default=8000)
    parser.add_argument('--control', type=str, default='centralized')

    parser.add_argument('--rectangle', type=str, default='10,10')
    parser.add_argument('--map_type', type=str, default='rectangle')
    parser.add_argument('--n_evaders', type=int, default=5)
    parser.add_argument('--n_pursuers', type=int, default=2)
    parser.add_argument('--obs_range', type=int, default=3)
    parser.add_argument('--n_catch', type=int, default=2)
    parser.add_argument('--urgency', type=float, default=0.0)
    parser.add_argument('--pursuit', dest='train_pursuit', action='store_true')
    parser.add_argument('--evade', dest='train_pursuit', action='store_false')
    parser.set_defaults(train_pursuit=True)
    parser.add_argument('--surround', action='store_true', default=False)
    parser.add_argument('--constraint_window', type=float, default=1.0)
    parser.add_argument('--sample_maps', action='store_true', default=False)
    parser.add_argument('--map_file', type=str, default='../maps/map_pool.npy')
    parser.add_argument('--flatten', action='store_true', default=False)
    parser.add_argument('--reward_mech', type=str, default='global')
    parser.add_argument('--catchr', type=float, default=0.1)
    parser.add_argument('--term_pursuit', type=float, default=5.0)

    parser.add_argument('--recurrent', type=str, default=None)
    parser.add_argument('--policy_hidden_sizes', type=str, default='128,128')
    parser.add_argument('--baselin_hidden_sizes', type=str, default='128,128')
    parser.add_argument('--baseline_type', type=str, default='linear')

    parser.add_argument('--conv', action='store_true', default=False)

    parser.add_argument('--max_kl', type=float, default=0.01)

    parser.add_argument('--log_dir', type=str, required=False)
    parser.add_argument('--tabular_log_file', type=str, default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file', type=str, default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file', type=str, default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--seed', type=int,
                        help='Random seed for numpy')
    parser.add_argument('--args_data', type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--snapshot_mode', type=str, default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                             '(all iterations will be saved), "last" (only '
                             'the last iteration will be saved), or "none" '
                             '(do not save snapshots)')
    parser.add_argument('--log_tabular_only', type=ast.literal_eval, default=False,
                        help='Whether to only print the tabular log information (in a horizontal format)')


    args = parser.parse_args()

    parallel_sampler.initialize(n_parallel=args.sampler_workers)

    if args.seed is not None:
        set_seed(args.seed)
        parallel_sampler.set_seed(args.seed)

    args.hidden_sizes = tuple(map(int, args.policy_hidden_sizes.split(',')))

    if args.sample_maps:
        map_pool = np.load(args.map_file)
    else:
        if args.map_type == 'rectangle':
            env_map = TwoDMaps.rectangle_map(*map(int, args.rectangle.split(',')))
        elif args.map_type == 'complex':
            env_map = TwoDMaps.complex_map(*map(int, args.rectangle.split(',')))
        else:
            raise NotImplementedError()
        map_pool = [env_map]

    env = PursuitEvade(map_pool, n_evaders=args.n_evaders, n_pursuers=args.n_pursuers,
                       obs_range=args.obs_range, n_catch=args.n_catch,
                       train_pursuit=args.train_pursuit, urgency_reward=args.urgency,
                       surround=args.surround, sample_maps=args.sample_maps,
                       constraint_window=args.constraint_window,
                       flatten=args.flatten,
                       reward_mech=args.reward_mech,
                       catchr=args.catchr,
                       term_pursuit=args.term_pursuit)

    env = RLLabEnv(
            StandardizedEnv(env, scale_reward=args.reward_scale, enable_obsnorm=False),
            mode=args.control)

    if args.recurrent:
        if args.conv:
            feature_network = ConvNetwork(
                input_shape=env.spec.observation_space.shape,
                output_dim=5, 
                conv_filters=(8,16,16),
                conv_filter_sizes=(3,3,3),
                conv_strides=(1,1,1),
                conv_pads=('VALID','VALID','VALID'),
                hidden_sizes=(64,), 
                hidden_nonlinearity=NL.rectify,
                output_nonlinearity=NL.softmax)
        else:
            feature_network = MLP(
                input_shape=(env.spec.observation_space.flat_dim + env.spec.action_space.flat_dim,),
                output_dim=5, hidden_sizes=(128,128,128), hidden_nonlinearity=NL.tanh,
                output_nonlinearity=None)
        if args.recurrent == 'gru':
            policy = CategoricalGRUPolicy(env_spec=env.spec, feature_network=feature_network,
                                       hidden_dim=int(args.policy_hidden_sizes))
    elif args.conv:
        feature_network = ConvNetwork(
            input_shape=env.spec.observation_space.shape,
            output_dim=5, 
            conv_filters=(8,16,16),
            conv_filter_sizes=(3,3,3),
            conv_strides=(1,1,1),
            conv_pads=('valid','valid','valid'),
            hidden_sizes=(64,), 
            hidden_nonlinearity=NL.rectify,
            output_nonlinearity=NL.softmax)
        policy = CategoricalMLPPolicy(env_spec=env.spec, prob_network=feature_network)
    else:
        policy = CategoricalMLPPolicy(env_spec=env.spec, hidden_sizes=args.hidden_sizes)

    if args.baseline_type == 'linear':
        baseline = LinearFeatureBaseline(env_spec=env.spec)
    else:
        baseline = ZeroBaseline(env_spec=env.spec)

    # logger
    default_log_dir = config.LOG_DIR
    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    logger.log_parameters_lite(params_log_file, args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    algo = TRPO(
        env=env,
        policy=policy,
        baseline=baseline,
        batch_size=args.n_timesteps,
        max_path_length=args.max_traj_len,
        n_itr=args.n_iter,
        discount=args.discount,
        gae_lambda=args.gae_lambda,
        step_size=args.max_kl,
        mode=args.control,)

    algo.train()
Example #30
    def train(self):
        # TODO - make this a util
        flatten_list = lambda l: [item for sublist in l for item in sublist]

        with tf.Session() as sess:
            # Code for loading a previous policy. Somewhat hacky because needs to be in sess.
            if self.load_policy is not None:
                import joblib
                self.policy = joblib.load(self.load_policy)['policy']
            self.init_opt()
            # initialize uninitialized vars  (only initialize vars that were not loaded)
            uninit_vars = []
            for var in tf.global_variables():
                # note - this is hacky, may be better way to do this in newer TF.
                try:
                    sess.run(var)
                except tf.errors.FailedPreconditionError:
                    uninit_vars.append(var)
            sess.run(tf.variables_initializer(uninit_vars))

            self.start_worker()
            start_time = time.time()
            for itr in range(self.start_itr, self.n_itr):
                itr_start_time = time.time()
                with logger.prefix('itr #%d | ' % itr):
                    logger.log("Sampling set of tasks/goals for this meta-batch...")

                    env = self.env
                    while 'sample_goals' not in dir(env):
                        env = env.wrapped_env
                    learner_env_goals = env.sample_goals(self.meta_batch_size)

                    self.policy.switch_to_init_dist()  # Switch to pre-update policy

                    all_samples_data, all_paths = [], []
                    for step in range(self.num_grad_updates+1):
                        #if step > 0:
                        #    import pdb; pdb.set_trace() # test param_vals functions.
                        logger.log('** Step ' + str(step) + ' **')
                        logger.log("Obtaining samples...")
                        paths = self.obtain_samples(itr, reset_args=learner_env_goals, log_prefix=str(step))
                        all_paths.append(paths)
                        logger.log("Processing samples...")
                        samples_data = {}
                        for key in paths.keys():  # the keys are the tasks
                            # don't log here because it would spam the console with every task.
                            samples_data[key] = self.process_samples(itr, paths[key], log=False)
                        all_samples_data.append(samples_data)
                        # for logging purposes only
                        self.process_samples(itr, flatten_list(paths.values()), prefix=str(step), log=True)
                        logger.log("Logging diagnostics...")
                        self.log_diagnostics(flatten_list(paths.values()), prefix=str(step))
                        if step < self.num_grad_updates:
                            logger.log("Computing policy updates...")
                            self.policy.compute_updated_dists(samples_data)


                    logger.log("Optimizing policy...")
                    # This needs to take all samples_data so that it can construct graph for meta-optimization.
                    self.optimize_policy(itr, all_samples_data)
                    logger.log("Saving snapshot...")
                    params = self.get_itr_snapshot(itr, all_samples_data[-1])  # , **kwargs)
                    if self.store_paths:
                        params["paths"] = all_samples_data[-1]["paths"]
                    logger.save_itr_params(itr, params)
                    logger.log("Saved")
                    logger.record_tabular('Time', time.time() - start_time)
                    logger.record_tabular('ItrTime', time.time() - itr_start_time)

                    logger.dump_tabular(with_prefix=False)

                    # The rest is some example plotting code.
                    # Plotting code is useful for visualizing trajectories across a few different tasks.
                    if False and itr % 2 == 0 and self.env.observation_space.shape[0] <= 4: # point-mass
                        logger.log("Saving visualization of paths")
                        for ind in range(min(5, self.meta_batch_size)):
                            plt.clf()
                            plt.plot(learner_env_goals[ind][0], learner_env_goals[ind][1], 'k*', markersize=10)
                            plt.hold(True)

                            preupdate_paths = all_paths[0]
                            postupdate_paths = all_paths[-1]

                            pre_points = preupdate_paths[ind][0]['observations']
                            post_points = postupdate_paths[ind][0]['observations']
                            plt.plot(pre_points[:,0], pre_points[:,1], '-r', linewidth=2)
                            plt.plot(post_points[:,0], post_points[:,1], '-b', linewidth=1)

                            pre_points = preupdate_paths[ind][1]['observations']
                            post_points = postupdate_paths[ind][1]['observations']
                            plt.plot(pre_points[:,0], pre_points[:,1], '--r', linewidth=2)
                            plt.plot(post_points[:,0], post_points[:,1], '--b', linewidth=1)

                            pre_points = preupdate_paths[ind][2]['observations']
                            post_points = postupdate_paths[ind][2]['observations']
                            plt.plot(pre_points[:,0], pre_points[:,1], '-.r', linewidth=2)
                            plt.plot(post_points[:,0], post_points[:,1], '-.b', linewidth=1)

                            plt.plot(0,0, 'k.', markersize=5)
                            plt.xlim([-0.8, 0.8])
                            plt.ylim([-0.8, 0.8])
                            plt.legend(['goal', 'preupdate path', 'postupdate path'])
                            plt.savefig(osp.join(logger.get_snapshot_dir(), 'prepost_path'+str(ind)+'.png'))
                    elif False and itr % 2 == 0:  # swimmer or cheetah
                        logger.log("Saving visualization of paths")
                        for ind in range(min(5, self.meta_batch_size)):
                            plt.clf()
                            goal_vel = learner_env_goals[ind]
                            plt.title('Swimmer paths, goal vel='+str(goal_vel))
                            plt.hold(True)

                            prepathobs = all_paths[0][ind][0]['observations']
                            postpathobs = all_paths[-1][ind][0]['observations']
                            plt.plot(prepathobs[:,0], prepathobs[:,1], '-r', linewidth=2)
                            plt.plot(postpathobs[:,0], postpathobs[:,1], '--b', linewidth=1)
                            plt.plot(prepathobs[-1,0], prepathobs[-1,1], 'r*', markersize=10)
                            plt.plot(postpathobs[-1,0], postpathobs[-1,1], 'b*', markersize=10)
                            plt.xlim([-1.0, 5.0])
                            plt.ylim([-1.0, 1.0])

                            plt.legend(['preupdate path', 'postupdate path'], loc=2)
                            plt.savefig(osp.join(logger.get_snapshot_dir(), 'swim1d_prepost_itr'+str(itr)+'_id'+str(ind)+'.pdf'))
        self.shutdown_worker()
    def train(self):
        # TODO - make this a util
        flatten_list = lambda l: [item for sublist in l for item in sublist]

        with tf.Session() as sess:
            # Code for loading a previous policy. Somewhat hacky because needs to be in sess.
            if self.load_policy is not None:
                import joblib
                self.policy = joblib.load(self.load_policy)['policy']
            self.init_opt()
            # initialize uninitialized vars  (only initialize vars that were not loaded)
            uninit_vars = []
            for var in tf.global_variables():
                # note - this is hacky, may be better way to do this in newer TF.
                try:
                    sess.run(var)
                except tf.errors.FailedPreconditionError:
                    uninit_vars.append(var)
            sess.run(tf.variables_initializer(uninit_vars))

            self.start_worker()
            start_time = time.time()
            for itr in range(self.start_itr, self.n_itr):
                itr_start_time = time.time()
                with logger.prefix('itr #%d | ' % itr):
                    logger.log(
                        "Sampling set of tasks/goals for this meta-batch...")

                    env = self.env
                    while 'sample_goals' not in dir(env):
                        env = env.wrapped_env
                    learner_env_goals = env.sample_goals(self.meta_batch_size)

                    self.policy.switch_to_init_dist(
                    )  # Switch to pre-update policy

                    all_samples_data, all_paths = [], []
                    for step in range(self.num_grad_updates + 1):
                        #if step > 0:
                        #    import pdb; pdb.set_trace() # test param_vals functions.
                        logger.log('** Step ' + str(step) + ' **')
                        logger.log("Obtaining samples...")
                        paths = self.obtain_samples(
                            itr,
                            reset_args=learner_env_goals,
                            log_prefix=str(step))
                        all_paths.append(paths)
                        logger.log("Processing samples...")
                        samples_data = {}
                        for key in paths.keys():  # the keys are the tasks
                            # don't log here because it would spam the console with every task.
                            samples_data[key] = self.process_samples(
                                itr, paths[key], log=False)
                        all_samples_data.append(samples_data)
                        # for logging purposes only
                        self.process_samples(itr,
                                             flatten_list(paths.values()),
                                             prefix=str(step),
                                             log=True)
                        logger.log("Logging diagnostics...")
                        self.log_diagnostics(flatten_list(paths.values()),
                                             prefix=str(step))
                        if step < self.num_grad_updates:
                            logger.log("Computing policy updates...")
                            self.policy.compute_updated_dists(samples_data)

                    logger.log("Optimizing policy...")
                    # This needs to take all samples_data so that it can construct graph for meta-optimization.
                    self.optimize_policy(itr, all_samples_data)
                    logger.log("Saving snapshot...")
                    params = self.get_itr_snapshot(
                        itr, all_samples_data[-1])  # , **kwargs)
                    if self.store_paths:
                        params["paths"] = all_samples_data[-1]["paths"]
                    #logger.save_itr_params(itr, params)
                    logger.log("Saved")
                    logger.record_tabular('Time', time.time() - start_time)
                    logger.record_tabular('ItrTime',
                                          time.time() - itr_start_time)

                    logger.dump_tabular(with_prefix=False)

                    # The rest is some example plotting code.
                    # Plotting code is useful for visualizing trajectories across a few different tasks.
                    if False and itr % 2 == 0 and self.env.observation_space.shape[
                            0] <= 4:  # point-mass
                        logger.log("Saving visualization of paths")
                        for ind in range(min(5, self.meta_batch_size)):
                            plt.clf()
                            plt.plot(learner_env_goals[ind][0],
                                     learner_env_goals[ind][1],
                                     'k*',
                                     markersize=10)
                            plt.hold(True)

                            preupdate_paths = all_paths[0]
                            postupdate_paths = all_paths[-1]

                            pre_points = preupdate_paths[ind][0][
                                'observations']
                            post_points = postupdate_paths[ind][0][
                                'observations']
                            plt.plot(pre_points[:, 0],
                                     pre_points[:, 1],
                                     '-r',
                                     linewidth=2)
                            plt.plot(post_points[:, 0],
                                     post_points[:, 1],
                                     '-b',
                                     linewidth=1)

                            pre_points = preupdate_paths[ind][1][
                                'observations']
                            post_points = postupdate_paths[ind][1][
                                'observations']
                            plt.plot(pre_points[:, 0],
                                     pre_points[:, 1],
                                     '--r',
                                     linewidth=2)
                            plt.plot(post_points[:, 0],
                                     post_points[:, 1],
                                     '--b',
                                     linewidth=1)

                            pre_points = preupdate_paths[ind][2][
                                'observations']
                            post_points = postupdate_paths[ind][2][
                                'observations']
                            plt.plot(pre_points[:, 0],
                                     pre_points[:, 1],
                                     '-.r',
                                     linewidth=2)
                            plt.plot(post_points[:, 0],
                                     post_points[:, 1],
                                     '-.b',
                                     linewidth=1)

                            plt.plot(0, 0, 'k.', markersize=5)
                            plt.xlim([-0.8, 0.8])
                            plt.ylim([-0.8, 0.8])
                            plt.legend(
                                ['goal', 'preupdate path', 'postupdate path'])
                            plt.savefig(
                                osp.join(logger.get_snapshot_dir(),
                                         'prepost_path' + str(ind) + '.png'))
                    elif False and itr % 2 == 0:  # swimmer or cheetah
                        logger.log("Saving visualization of paths")
                        for ind in range(min(5, self.meta_batch_size)):
                            plt.clf()
                            goal_vel = learner_env_goals[ind]
                            plt.title('Swimmer paths, goal vel=' +
                                      str(goal_vel))
                            plt.hold(True)

                            prepathobs = all_paths[0][ind][0]['observations']
                            postpathobs = all_paths[-1][ind][0]['observations']
                            plt.plot(prepathobs[:, 0],
                                     prepathobs[:, 1],
                                     '-r',
                                     linewidth=2)
                            plt.plot(postpathobs[:, 0],
                                     postpathobs[:, 1],
                                     '--b',
                                     linewidth=1)
                            plt.plot(prepathobs[-1, 0],
                                     prepathobs[-1, 1],
                                     'r*',
                                     markersize=10)
                            plt.plot(postpathobs[-1, 0],
                                     postpathobs[-1, 1],
                                     'b*',
                                     markersize=10)
                            plt.xlim([-1.0, 5.0])
                            plt.ylim([-1.0, 1.0])

                            plt.legend(['preupdate path', 'postupdate path'],
                                       loc=2)
                            plt.savefig(
                                osp.join(
                                    logger.get_snapshot_dir(),
                                    'swim1d_prepost_itr' + str(itr) + '_id' +
                                    str(ind) + '.pdf'))
        self.shutdown_worker()
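The two train() variants above share the same per-iteration logging skeleton. Below is a minimal, hedged sketch of that pattern, assuming the rllab-style logger used throughout these examples; the sampling and optimization steps are elided, and n_itr and params are placeholders.

import time
from rllab.misc import logger  # assumed import path for the rllab-style logger

n_itr = 3  # placeholder iteration count
start_time = time.time()
for itr in range(n_itr):
    itr_start_time = time.time()
    with logger.prefix('itr #%d | ' % itr):
        # ... obtain samples, process them, and optimize the policy here ...
        params = dict(itr=itr)  # stand-in for get_itr_snapshot(itr, samples_data)
        logger.save_itr_params(itr, params)  # writes only if a snapshot dir/mode is configured
        logger.record_tabular('Time', time.time() - start_time)
        logger.record_tabular('ItrTime', time.time() - itr_start_time)
        logger.dump_tabular(with_prefix=False)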
Exemple #32
0
def run_task(v):
    random.seed(v['seed'])
    np.random.seed(v['seed'])

    # Log performance of randomly initialized policy with FIXED goal [0.1, 0.1]
    logger.log("Initializing report...")
    log_dir = logger.get_snapshot_dir()  # problem with logger module here!!
    if log_dir is None:
        log_dir = "/home/michael/"
    report = HTMLReport(osp.join(log_dir, 'report.html'), images_per_row=3)

    report.add_header("{}".format(EXPERIMENT_TYPE))
    report.add_text(format_dict(v))

    inner_env = normalize(AntMazeEnv())

    fixed_goal_generator = FixedStateGenerator(state=v['ultimate_goal'])
    fixed_start_generator = FixedStateGenerator(state=v['ultimate_goal'])

    env = GoalStartExplorationEnv(
        env=inner_env,
        start_generator=fixed_start_generator,
        obs2start_transform=lambda x: x[:v['start_size']],
        goal_generator=fixed_goal_generator,
        obs2goal_transform=lambda x: x[-3:-1],
        terminal_eps=v['terminal_eps'],
        distance_metric=v['distance_metric'],
        extend_dist_rew=v['extend_dist_rew'],
        inner_weight=v['inner_weight'],
        goal_weight=v['goal_weight'],
        terminate_env=True,
    )

    policy = GaussianMLPPolicy(
        env_spec=env.spec,
        hidden_sizes=(64, 64),
        # Fix the variance since different goals will require different variances, making this parameter hard to learn.
        learn_std=v['learn_std'],
        adaptive_std=v['adaptive_std'],
        std_hidden_sizes=(16, 16),  # this is only used if adaptive_std is true!
        output_gain=v['output_gain'],
        init_std=v['policy_init_std'],
    )

    #baseline = LinearFeatureBaseline(env_spec=env.spec)
    if v["baseline"] == "MLP":
        baseline = GaussianMLPBaseline(env_spec=env.spec)
    else:
        baseline = LinearFeatureBaseline(env_spec=env.spec)

    # load the state collection from data_upload

    all_starts = StateCollection(distance_threshold=v['coll_eps'], states_transform=lambda x: x[:, :2])

    # can also filter these starts optionally

    load_dir = 'sandbox/young_clgan/experiments/starts/maze/maze_ant/'
    all_feasible_starts = pickle.load(
        open(osp.join(config.PROJECT_PATH, load_dir, 'good_all_feasible_starts.pkl'), 'rb'))
    logger.log("We have %d feasible starts" % all_feasible_starts.size)

    min_reward = 0.1
    max_reward = 0.9
    improvement_threshold = 0
    old_rewards = None

    # hardest to easiest
    init_pos = [[0, 0],
                [1, 0],
                [2, 0],
                [3, 0],
                [4, 0],
                [4, 1],
                [4, 2],
                [4, 3],
                [4, 4],
                [3, 4],
                [2, 4],
                [1, 4]
                ][::-1]
    for pos in init_pos:
        pos.extend([0.55, 1, 0, 0, 0, 0, 1, 0, -1, 0, -1, 0, 1, ])
    array_init_pos = np.array(init_pos)
    init_pos = [tuple(pos) for pos in init_pos]
    online_start_generator = Online_TCSL(init_pos)


    for outer_iter in range(1, v['outer_iters']):

        logger.log("Outer itr # %i" % outer_iter)
        logger.log("Sampling starts")

        report.save()

        # generate starts from the previous seed starts, which are defined below
        dist = online_start_generator.get_distribution() # added
        logger.log(np.array_str(online_start_generator.get_q()))
        # how to log Q values?
        # with logger.tabular_prefix("General: "):
        #     logger.record_tabular("Q values:", online_start_generator.get_q())
        logger.log(np.array_str(dist))

        # Following code should be indented
        with ExperimentLogger(log_dir, outer_iter // 50, snapshot_mode='last', hold_outter_log=True):
            logger.log("Updating the environment start generator")
            #TODO: might be faster to sample if we just create a roughly representative UniformListStateGenerator?
            env.update_start_generator(
                ListStateGenerator(
                    init_pos, dist
                )
            )

            logger.log("Training the algorithm")
            algo = TRPO(
                env=env,
                policy=policy,
                baseline=baseline,
                batch_size=v['pg_batch_size'],
                max_path_length=v['horizon'],
                n_itr=v['inner_iters'],
                step_size=0.01,
                discount=v['discount'],
                plot=False,
            )

            trpo_paths = algo.train()



        logger.log("Labeling the starts")
        [starts, labels, mean_rewards, updated] = label_states_from_paths(trpo_paths, n_traj=v['n_traj'], key='goal_reached',  # using the min n_traj
                                                   as_goal=False, env=env, return_mean_rewards=True, order_of_states=init_pos)

        start_classes, text_labels = convert_label(labels)
        plot_labeled_states(starts, labels, report=report, itr=outer_iter, limit=v['goal_range'],
                            center=v['goal_center'], maze_id=v['maze_id'])

        online_start_generator.update_q(np.array(mean_rewards), np.array(updated)) # added
        labels = np.logical_and(labels[:, 0], labels[:, 1]).astype(int).reshape((-1, 1))

        # append new states to list of all starts (replay buffer): Not the low reward ones!!
        filtered_raw_starts = [start for start, label in zip(starts, labels) if label[0] == 1]

        if v['seed_with'] == 'only_goods':
            if len(filtered_raw_starts) > 0:  # add a ton of noise if all the states I had ended up being high_reward!
                logger.log("We have {} good starts!".format(len(filtered_raw_starts)))
                seed_starts = filtered_raw_starts
            elif np.sum(start_classes == 0) > np.sum(start_classes == 1):  # if more low reward than high reward
                logger.log("More bad starts than good starts, sampling seeds from replay buffer")
                seed_starts = all_starts.sample(300)  # sample them from the replay
            else:
                logger.log("More good starts than bad starts, resampling")
                seed_starts = generate_starts(env, starts=starts, horizon=v['horizon'] * 2, subsample=v['num_new_starts'], size=10000,
                                              variance=v['brownian_variance'] * 10)
        elif v['seed_with'] == 'all_previous':
            seed_starts = starts
        else:
            raise Exception

        all_starts.append(filtered_raw_starts)

        # need to put this last! otherwise labels variable gets confused
        logger.log("Labeling on uniform starts")
        with logger.tabular_prefix("Uniform_"):
            unif_starts = all_feasible_starts.sample(100)
            mean_reward, paths = evaluate_states(unif_starts, env, policy, v['horizon'], n_traj=v['n_traj'], key='goal_reached',
                                                 as_goals=False, full_path=True)
            env.log_diagnostics(paths)
            mean_rewards = mean_reward.reshape(-1, 1)
            labels = compute_labels(mean_rewards, old_rewards=old_rewards, min_reward=min_reward, max_reward=max_reward,
                                    improvement_threshold=improvement_threshold)
            logger.log("Starts labelled")
            plot_labeled_states(unif_starts, labels, report=report, itr=outer_iter, limit=v['goal_range'],
                                center=v['goal_center'], maze_id=v['maze_id'],
                                summary_string_base='initial starts labels:\n')
            # report.add_text("Success: " + str(np.mean(mean_reward)))

        with logger.tabular_prefix("Fixed_"):
            mean_reward, paths = evaluate_states(array_init_pos, env, policy, v['horizon'], n_traj=5, key='goal_reached',
                                                 as_goals=False, full_path=True)
            env.log_diagnostics(paths)
            mean_rewards = mean_reward.reshape(-1, 1)
            labels = compute_labels(mean_rewards, old_rewards=old_rewards, min_reward=min_reward, max_reward=max_reward,
                                    improvement_threshold=improvement_threshold)
            logger.log("Starts labelled")
            plot_labeled_states(array_init_pos, labels, report=report, itr=outer_iter, limit=v['goal_range'],
                                center=v['goal_center'], maze_id=v['maze_id'],
                                summary_string_base='initial starts labels:\n')
            report.add_text("Fixed Success: " + str(np.mean(mean_reward)))

        report.new_row()
        report.save()
        logger.record_tabular("Fixed test set_success: ", np.mean(mean_reward))
        logger.dump_tabular()
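The commented-out question above ("how to log Q values?") can be handled with the logger's tabular_prefix context that this run_task already uses for the Uniform_ and Fixed_ evaluations: record scalar summaries of the array under a prefix and dump them with the rest of the row. A hedged sketch, assuming the rllab-style logger; the prefix name and the zero array are placeholders.

import numpy as np
from rllab.misc import logger  # assumed import path for the rllab-style logger

q_values = np.zeros(12)  # stand-in for online_start_generator.get_q()
with logger.tabular_prefix("General_"):
    # record_tabular expects scalar-like values, so log summaries of the array
    logger.record_tabular("QMean", float(np.mean(q_values)))
    logger.record_tabular("QMax", float(np.max(q_values)))
logger.dump_tabular()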
def plot_state(self, name='sensors', state=None):
    if state:
        self.wrapped_env.reset(state)

    structure = self.__class__.MAZE_STRUCTURE
    size_scaling = self.__class__.MAZE_SIZE_SCALING
    # duplicate cells to plot the maze
    structure_plot = np.zeros(
        ((len(structure) - 1) * 2, (len(structure[0]) - 1) * 2))
    for i in range(len(structure)):
        for j in range(len(structure[0])):
            cell = structure[i][j]
            if type(cell) is not int:
                cell = 0.3 if cell == 'r' else 0.7
            if i == 0:
                if j == 0:
                    structure_plot[i, j] = cell
                elif j == len(structure[0]) - 1:
                    structure_plot[i, 2 * j - 1] = cell
                else:
                    structure_plot[i, 2 * j - 1:2 * j + 1] = cell
            elif i == len(structure) - 1:
                if j == 0:
                    structure_plot[2 * i - 1, j] = cell
                elif j == len(structure[0]) - 1:
                    structure_plot[2 * i - 1, 2 * j - 1] = cell
                else:
                    structure_plot[2 * i - 1, 2 * j - 1:2 * j + 1] = cell
            else:
                if j == 0:
                    structure_plot[2 * i - 1:2 * i + 1, j] = cell
                elif j == len(structure[0]) - 1:
                    structure_plot[2 * i - 1:2 * i + 1, 2 * j - 1] = cell
                else:
                    structure_plot[2 * i - 1:2 * i + 1,
                                   2 * j - 1:2 * j + 1] = cell

    fig, ax = plt.subplots()
    im = ax.pcolor(-np.array(structure_plot),
                   cmap='gray',
                   edgecolor='black',
                   linestyle=':',
                   lw=1)
    x_labels = list(range(len(structure[0])))
    y_labels = list(range(len(structure)))
    ax.grid(True)  # remove this to avoid the inner grid lines

    ax.xaxis.set(ticks=2 * np.arange(len(x_labels)), ticklabels=x_labels)
    ax.yaxis.set(ticks=2 * np.arange(len(y_labels)), ticklabels=y_labels)

    obs = self.get_current_maze_obs()

    robot_xy = np.array(self.wrapped_env.get_body_com("torso")
                        [:2])  # coordinates relative to the initial torso position
    ori = self.get_ori(
    )  # for Ant this is computed with atan2, which gives [-pi, pi]

    # compute the origin cell (i_o, j_o) and its center (x_o, y_o), with (0, 0) in the top-right corner of the structure
    o_xy = np.array(self._find_robot(
    ))  # this is self.init_torso_x, self.init_torso_y: center of the cell xy!
    o_ij = (o_xy / size_scaling).astype(
        int)  # this is the position in the grid

    o_xy_plot = o_xy / size_scaling * 2
    robot_xy_plot = o_xy_plot + robot_xy / size_scaling * 2

    plt.scatter(*robot_xy_plot)

    for ray_idx in range(self._n_bins):
        length_wall = self._sensor_range - obs[
            ray_idx] * self._sensor_range if obs[ray_idx] else 1e-6
        ray_ori = ori - self._sensor_span * 0.5 + ray_idx / (
            self._n_bins - 1) * self._sensor_span
        if ray_ori > math.pi:
            ray_ori -= 2 * math.pi
        elif ray_ori < -math.pi:
            ray_ori += 2 * math.pi
        # find the end point wall
        end_xy = (
            robot_xy + length_wall *
            np.array([math.cos(ray_ori), math.sin(ray_ori)]))
        end_xy_plot = (o_ij + end_xy / size_scaling) * 2
        plt.plot([robot_xy_plot[0], end_xy_plot[0]],
                 [robot_xy_plot[1], end_xy_plot[1]], 'r')

        length_goal = self._sensor_range - obs[
            ray_idx + self._n_bins] * self._sensor_range if obs[
                ray_idx + self._n_bins] else 1e-6
        ray_ori = ori - self._sensor_span * 0.5 + ray_idx / (
            self._n_bins - 1) * self._sensor_span
        # find the end point goal
        end_xy = (
            robot_xy + length_goal *
            np.array([math.cos(ray_ori), math.sin(ray_ori)]))
        end_xy_plot = (o_ij + end_xy / size_scaling) * 2
        plt.plot([robot_xy_plot[0], end_xy_plot[0]],
                 [robot_xy_plot[1], end_xy_plot[1]], 'g')

    log_dir = logger.get_snapshot_dir()
    ax.set_title('sensors: ' + name)

    plt.savefig(osp.join(
        log_dir,
        name + '_sensors.png'))  # saves the current figure (fig)
    plt.close()
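Several snippets on this page write figures and reports under logger.get_snapshot_dir(), which returns None when no snapshot directory has been configured, and fall back to a hard-coded directory. A small self-contained sketch of that guard, assuming the rllab-style logger; the fallback path is arbitrary.

import os.path as osp
import matplotlib.pyplot as plt
from rllab.misc import logger  # assumed import path for the rllab-style logger

log_dir = logger.get_snapshot_dir()
if log_dir is None:
    log_dir = "/tmp"  # arbitrary fallback, analogous to the hard-coded home directory above

plt.plot([0, 1], [0, 1])
plt.savefig(osp.join(log_dir, 'sensors_example.png'))  # saved next to the other snapshot files
plt.close()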
Exemple #34
0
             start_state=start_state,
             goal_state=goal_state))
env._wrapped_env.generate_grid = False
env._wrapped_env.generate_b0_start_goal = False
env.reset()

log_dir = "./Data/obs_1goal20step0stay_1_gru"

tabular_log_file = osp.join(log_dir, "progress.csv")
text_log_file = osp.join(log_dir, "debug.log")
params_log_file = osp.join(log_dir, "params.json")
pkl_file = osp.join(log_dir, "params.pkl")

logger.add_text_output(text_log_file)
logger.add_tabular_output(tabular_log_file)
prev_snapshot_dir = logger.get_snapshot_dir()
prev_mode = logger.get_snapshot_mode()
logger.set_snapshot_dir(log_dir)
logger.set_snapshot_mode("gaplast")
logger.set_snapshot_gap(1000)
logger.set_log_tabular_only(False)
logger.push_prefix("[%s] " % "FixMapStartState")

from Algo import parallel_sampler
parallel_sampler.initialize(n_parallel=1)
parallel_sampler.set_seed(0)

policy = QMDPPolicy(env_spec=env.spec,
                    name="QMDP",
                    qmdp_param=env._wrapped_env.params)
Exemple #35
0
def run_task(v):
    random.seed(v['seed'])
    np.random.seed(v['seed'])
    logger.log("Initializing report...")
    log_dir = logger.get_snapshot_dir()
    if log_dir is None:
        log_dir = "/home/michael/"
    report = HTMLReport(osp.join(log_dir, 'report.html'), images_per_row=1)

    fixed_goal_generator = FixedStateGenerator(state=v['ultimate_goal'])
    fixed_start_generator = FixedStateGenerator(state=v['ultimate_goal'])

    inner_env = normalize(AntMazeEnv())
    env = GoalStartExplorationEnv(
        env=inner_env,
        start_generator=fixed_start_generator,
        obs2start_transform=lambda x: x[:v['start_size']],
        goal_generator=fixed_goal_generator,
        obs2goal_transform=lambda x: x[-3:-1],
        terminal_eps=v['terminal_eps'],
        distance_metric=v['distance_metric'],
        extend_dist_rew=v['extend_dist_rew'],
        inner_weight=v['inner_weight'],
        goal_weight=v['goal_weight'],
        terminate_env=True,
    )

    seed_starts = generate_starts(
        env,
        starts=[v['start_goal']],
        horizon=v['initial_brownian_horizon'],  #TODO: increase to 2000
        size=1000,
        # size speeds up training a bit
        variance=v['brownian_variance'],
        animated=True,
        subsample=v['num_new_starts'],
    )  # , animated=True, speedup=1)
    np.random.shuffle(seed_starts)

    logger.log("Prefilter seed starts: {}".format(len(seed_starts)))
    # breaks code
    # starts = seed_starts
    # # hack to not print code
    # seed_starts = [start for start in seed_starts if check_feasibility(start, env, 10)]
    # starts = np.array(starts)
    # seed_starts = starts
    logger.log("Filtered seed starts: {}".format(len(seed_starts)))

    # with env.set_kill_outside():
    feasible_states = find_all_feasible_states_plotting(
        env,
        seed_starts,
        report,
        distance_threshold=0.1,
        brownian_variance=1,
        size=8000,
        animate=True,
        limit=v['goal_range'],
        check_feasible=True,
        check_feasible_path_length=500,
        center=v['goal_center'])
    return
Exemple #36
0
def run_experiment(argv):

    default_log_dir = config.LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')

    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
    parser = argparse.ArgumentParser()
    parser.add_argument('--n_parallel', type=int, default=1,
                        help='Number of parallel workers to perform rollouts.')
    parser.add_argument(
        '--exp_name', type=str, default=default_exp_name, help='Name of the experiment.')
    parser.add_argument('--log_dir', type=str, default=default_log_dir,
                        help='Path to save the log and iteration snapshot.')
    parser.add_argument('--snapshot_mode', type=str, default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                             '(all iterations will be saved), "last" (only '
                             'the last iteration will be saved), or "none" '
                             '(do not save snapshots)')
    parser.add_argument('--tabular_log_file', type=str, default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file', type=str, default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file', type=str, default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--plot', type=ast.literal_eval, default=False,
                        help='Whether to plot the iteration results')
    parser.add_argument('--log_tabular_only', type=ast.literal_eval, default=False,
                        help='Whether to only print the tabular log information (in a horizontal format)')
    parser.add_argument('--seed', type=int,
                        help='Random seed for numpy')
    parser.add_argument('--args_data', type=str,
                        help='Pickled data for stub objects')

    args = parser.parse_args(argv[1:])

    from sandbox.vime.sampler import parallel_sampler_expl as parallel_sampler
    parallel_sampler.initialize(n_parallel=args.n_parallel)

    if args.seed is not None:
        set_seed(args.seed)
        parallel_sampler.set_seed(args.seed)

    if args.plot:
        from rllab.plotter import plotter
        plotter.init_worker()

    # decode the pickled stub objects passed via --args_data
    data = pickle.loads(base64.b64decode(args.args_data))

    log_dir = args.log_dir
    # exp_dir = osp.join(log_dir, args.exp_name)
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    logger.log_parameters_lite(params_log_file, args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    maybe_iter = concretize(data)
    if is_iterable(maybe_iter):
        for _ in maybe_iter:
            pass

    logger.set_snapshot_mode(prev_mode)
    logger.set_snapshot_dir(prev_snapshot_dir)
    logger.remove_tabular_output(tabular_log_file)
    logger.remove_text_output(text_log_file)
    logger.pop_prefix()
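run_experiment above pairs the logger setup with an explicit restore of the previous snapshot directory and mode once the experiment finishes. Below is a hedged variant of the same calls wrapped in try/finally, so the outputs are detached even if the experiment body raises; log_dir and the experiment body are placeholders, and the rllab-style logger is assumed.

import os
import os.path as osp
from rllab.misc import logger  # assumed import path for the rllab-style logger

log_dir = "/tmp/example_exp"  # placeholder log directory
os.makedirs(log_dir, exist_ok=True)
tabular_log_file = osp.join(log_dir, 'progress.csv')
text_log_file = osp.join(log_dir, 'debug.log')

logger.add_text_output(text_log_file)
logger.add_tabular_output(tabular_log_file)
prev_snapshot_dir = logger.get_snapshot_dir()
prev_mode = logger.get_snapshot_mode()
logger.set_snapshot_dir(log_dir)
logger.push_prefix("[example] ")
try:
    pass  # the experiment body would run here
finally:
    logger.set_snapshot_mode(prev_mode)
    logger.set_snapshot_dir(prev_snapshot_dir)
    logger.remove_tabular_output(tabular_log_file)
    logger.remove_text_output(text_log_file)
    logger.pop_prefix()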
Exemple #37
0
def find_all_feasible_states_plotting(env,
                                      seed_starts,
                                      report,
                                      distance_threshold=0.1,
                                      size=10000,
                                      horizon=300,
                                      brownian_variance=1,
                                      animate=False,
                                      num_samples=100,
                                      limit=None,
                                      center=None,
                                      fast=True,
                                      check_feasible=True,
                                      check_feasible_path_length=50):
    """
    Generates states for two maze environments (ant and swimmer)
    :param env:
    :param seed_starts:
    :param report:
    :param distance_threshold: min distance between states
    :param size:
    :param horizon:
    :param brownian_variance:
    :param animate:
    :param num_samples: number of samples produced every iteration
    :param limit:
    :param center:
    :param fast:
    :param check_feasible:
    :param check_feasible_path_length:
    :return:
    """
    # If fast is True, roughly half of the seed states come from the set generated in the last
    # iteration and half are sampled from all previously generated states.
    log_dir = logger.get_snapshot_dir()
    if log_dir is None:
        log_dir = "/home/michael/"

    iteration = 0
    # use only the first two coordinates (i.e. work in the transformed space)
    all_feasible_starts = StateCollection(
        distance_threshold=distance_threshold,
        states_transform=lambda x: x[:, :2])
    all_feasible_starts.append(seed_starts)
    all_starts_samples = all_feasible_starts.sample(num_samples)
    text_labels = OrderedDict({
        0: 'New starts',
        1: 'Old sampled starts',
        2: 'Other'
    })
    img = plot_labeled_samples(
        samples=all_starts_samples[:, :2],  # first two are COM
        sample_classes=np.zeros(num_samples, dtype=int),
        text_labels=text_labels,
        limit=limit,
        center=center,
        maze_id=0,
    )
    report.add_image(img, 'itr: {}\n'.format(iteration), width=500)
    report.save()

    no_new_states = 0
    while no_new_states < 30:
        iteration += 1
        logger.log("Iteration: {}".format(iteration))
        total_num_starts = all_feasible_starts.size
        starts = all_feasible_starts.sample(num_samples)

        # definitely want to initialize from newly generated states, roughly half from each set
        if fast and iteration > 1:
            print(len(added_states))
            if len(added_states) > 0:
                while len(starts) < 1.5 * num_samples:
                    starts = np.concatenate((starts, added_states), axis=0)
        new_starts = generate_starts(env,
                                     starts=starts,
                                     horizon=horizon,
                                     size=size,
                                     variance=brownian_variance,
                                     animated=animate,
                                     speedup=50)
        # filters starts so that we only keep the good starts
        if check_feasible:  # used for the ant maze environment, where we want to run no_action
            logger.log("Prefilteredstarts: {}".format(len(new_starts)))
            new_starts = parallel_check_feasibility(
                env=env,
                starts=new_starts,
                max_path_length=check_feasible_path_length)
            # new_starts = [start for start in new_starts if check_feasibility(env, start, check_feasible_path_length)]
            logger.log("Filtered starts: {}".format(len(new_starts)))
        all_starts_samples = all_feasible_starts.sample(num_samples)
        added_states = all_feasible_starts.append(new_starts)
        num_new_starts = len(added_states)
        logger.log("number of new states: " + str(num_new_starts))
        if num_new_starts < 3:
            no_new_states += 1
        with open(osp.join(log_dir, 'all_feasible_states.pkl'), 'wb') as f:
            cloudpickle.dump(all_feasible_starts, f, protocol=3)

        # want to plot added_states and old sampled starts
        img = plot_labeled_samples(
            samples=np.concatenate(
                (added_states[:, :2], all_starts_samples[:, :2]),
                axis=0),  # first two are COM
            sample_classes=np.concatenate((np.zeros(
                num_new_starts, dtype=int), np.ones(num_samples, dtype=int)),
                                          axis=0),
            text_labels=text_labels,
            limit=limit,
            center=center,
            maze_id=0,
        )  # fine if sample classes is longer

        report.add_image(img, 'itr: {}\n'.format(iteration), width=500)
        report.add_text("number of new states: " + str(num_new_starts))
        report.save()
        # break

    all_starts_samples = all_feasible_starts.sample(all_feasible_starts.size)
    img = plot_labeled_samples(
        samples=all_starts_samples,
        # first two are COM
        sample_classes=np.ones(all_feasible_starts.size, dtype=int),
        text_labels=text_labels,
        limit=limit,
        center=center,
        maze_id=0,
    )  # fine if sample classes is longer

    report.add_image(img, 'itr: {}\n'.format(iteration), width=500)
    report.add_text("Total number of states: " + str(all_feasible_starts.size))
    report.save()
Exemple #38
0
    def setup(self, env, policy, start_itr):

        if not self.args.algo == 'thddpg':
            # Baseline
            if self.args.baseline_type == 'linear':
                baseline = LinearFeatureBaseline(env_spec=env.spec)
            elif self.args.baseline_type == 'zero':
                baseline = ZeroBaseline(env_spec=env.spec)
            else:
                raise NotImplementedError(self.args.baseline_type)

            if self.args.control == 'concurrent':
                baseline = [baseline for _ in range(len(env.agents))]
        # Logger
        default_log_dir = config.LOG_DIR
        if self.args.log_dir is None:
            log_dir = osp.join(default_log_dir, self.args.exp_name)
        else:
            log_dir = self.args.log_dir

        tabular_log_file = osp.join(log_dir, self.args.tabular_log_file)
        text_log_file = osp.join(log_dir, self.args.text_log_file)
        params_log_file = osp.join(log_dir, self.args.params_log_file)

        logger.log_parameters_lite(params_log_file, self.args)
        logger.add_text_output(text_log_file)
        logger.add_tabular_output(tabular_log_file)
        prev_snapshot_dir = logger.get_snapshot_dir()
        prev_mode = logger.get_snapshot_mode()
        logger.set_snapshot_dir(log_dir)
        logger.set_snapshot_mode(self.args.snapshot_mode)
        logger.set_log_tabular_only(self.args.log_tabular_only)
        logger.push_prefix("[%s] " % self.args.exp_name)

        if self.args.algo == 'tftrpo':
            algo = MATRPO(env=env, policy_or_policies=policy, baseline_or_baselines=baseline,
                          batch_size=self.args.batch_size, start_itr=start_itr,
                          max_path_length=self.args.max_path_length, n_itr=self.args.n_iter,
                          discount=self.args.discount, gae_lambda=self.args.gae_lambda,
                          step_size=self.args.step_size, optimizer=ConjugateGradientOptimizer(
                              hvp_approach=FiniteDifferenceHvp(base_eps=1e-5)) if
                          self.args.recurrent else None, ma_mode=self.args.control)
        elif self.args.algo == 'thddpg':
            qfunc = thContinuousMLPQFunction(env_spec=env.spec)
            if self.args.exp_strategy == 'ou':
                es = OUStrategy(env_spec=env.spec)
            elif self.args.exp_strategy == 'gauss':
                es = GaussianStrategy(env_spec=env.spec)
            else:
                raise NotImplementedError()

            algo = thDDPG(env=env, policy=policy, qf=qfunc, es=es, batch_size=self.args.batch_size,
                          max_path_length=self.args.max_path_length,
                          epoch_length=self.args.epoch_length,
                          min_pool_size=self.args.min_pool_size,
                          replay_pool_size=self.args.replay_pool_size, n_epochs=self.args.n_iter,
                          discount=self.args.discount, scale_reward=0.01,
                          qf_learning_rate=self.args.qfunc_lr,
                          policy_learning_rate=self.args.policy_lr,
                          eval_samples=self.args.eval_samples, mode=self.args.control)
        return algo
Exemple #39
0
def main():
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')
    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)

    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_name', type=str, default=default_exp_name,
                        help='Name of the experiment.')

    parser.add_argument('--discount', type=float, default=0.95)
    parser.add_argument('--gae_lambda', type=float, default=0.99)
    parser.add_argument('--reward_scale', type=float, default=1.0)
    parser.add_argument('--enable_obsnorm', action='store_true', default=False)
    parser.add_argument('--chunked', action='store_true', default=False)

    parser.add_argument('--n_iter', type=int, default=250)
    parser.add_argument('--sampler_workers', type=int, default=1)
    parser.add_argument('--max_traj_len', type=int, default=250)
    parser.add_argument('--update_curriculum', action='store_true', default=False)
    parser.add_argument('--anneal_step_size', type=int, default=0)

    parser.add_argument('--n_timesteps', type=int, default=8000)

    parser.add_argument('--control', type=str, default='centralized')
    parser.add_argument('--buffer_size', type=int, default=1)
    parser.add_argument('--radius', type=float, default=0.015)
    parser.add_argument('--n_evaders', type=int, default=10)
    parser.add_argument('--n_pursuers', type=int, default=8)
    parser.add_argument('--n_poison', type=int, default=10)
    parser.add_argument('--n_coop', type=int, default=4)
    parser.add_argument('--n_sensors', type=int, default=30)
    parser.add_argument('--sensor_range', type=str, default='0.2')
    parser.add_argument('--food_reward', type=float, default=5)
    parser.add_argument('--poison_reward', type=float, default=-1)
    parser.add_argument('--encounter_reward', type=float, default=0.05)
    parser.add_argument('--reward_mech', type=str, default='local')

    parser.add_argument('--recurrent', type=str, default=None)
    parser.add_argument('--baseline_type', type=str, default='linear')
    parser.add_argument('--policy_hidden_sizes', type=str, default='128,128')
    parser.add_argument('--baseline_hidden_sizes', type=str, default='128,128')

    parser.add_argument('--max_kl', type=float, default=0.01)

    parser.add_argument('--log_dir', type=str, required=False)
    parser.add_argument('--tabular_log_file', type=str, default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file', type=str, default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file', type=str, default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data', type=str, help='Pickled data for stub objects')
    parser.add_argument('--snapshot_mode', type=str, default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                        '(all iterations will be saved), "last" (only '
                        'the last iteration will be saved), or "none" '
                        '(do not save snapshots)')
    parser.add_argument(
        '--log_tabular_only', type=ast.literal_eval, default=False,
        help='Whether to only print the tabular log information (in a horizontal format)')

    args = parser.parse_args()

    parallel_sampler.initialize(n_parallel=args.sampler_workers)

    if args.seed is not None:
        set_seed(args.seed)
        parallel_sampler.set_seed(args.seed)

    args.hidden_sizes = tuple(map(int, args.policy_hidden_sizes.split(',')))

    centralized = True if args.control == 'centralized' else False

    sensor_range = np.array(list(map(float, args.sensor_range.split(','))))
    if len(sensor_range) == 1:
        sensor_range = sensor_range[0]
    else:
        assert sensor_range.shape == (args.n_pursuers,)

    env = MAWaterWorld(args.n_pursuers, args.n_evaders, args.n_coop, args.n_poison,
                       radius=args.radius, n_sensors=args.n_sensors, food_reward=args.food_reward,
                       poison_reward=args.poison_reward, encounter_reward=args.encounter_reward,
                       reward_mech=args.reward_mech, sensor_range=sensor_range, obstacle_loc=None)

    env = TfEnv(
        RLLabEnv(
            StandardizedEnv(env, scale_reward=args.reward_scale,
                            enable_obsnorm=args.enable_obsnorm), mode=args.control))

    if args.buffer_size > 1:
        env = ObservationBuffer(env, args.buffer_size)

    if args.recurrent:
        feature_network = MLP(
            name='feature_net',
            input_shape=(env.spec.observation_space.flat_dim + env.spec.action_space.flat_dim,),
            output_dim=16, hidden_sizes=(128, 64, 32), hidden_nonlinearity=tf.nn.tanh,
            output_nonlinearity=None)
        if args.recurrent == 'gru':
            policy = GaussianGRUPolicy(env_spec=env.spec, feature_network=feature_network,
                                       hidden_dim=int(args.policy_hidden_sizes), name='policy')
        elif args.recurrent == 'lstm':
            policy = GaussianLSTMPolicy(env_spec=env.spec, feature_network=feature_network,
                                        hidden_dim=int(args.policy_hidden_sizes), name='policy')
    else:
        policy = GaussianMLPPolicy(
            name='policy', env_spec=env.spec,
            hidden_sizes=tuple(map(int, args.policy_hidden_sizes.split(','))), min_std=10e-5)

    if args.baseline_type == 'linear':
        baseline = LinearFeatureBaseline(env_spec=env.spec)
    elif args.baseline_type == 'mlp':
        raise NotImplementedError()
        # baseline = GaussianMLPBaseline(
        #     env_spec=env.spec, hidden_sizes=tuple(map(int, args.baseline_hidden_sizes.split(','))))
    else:
        baseline = ZeroBaseline(env_spec=env.spec)

    # logger
    default_log_dir = config.LOG_DIR
    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    logger.log_parameters_lite(params_log_file, args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    algo = TRPO(
        env=env,
        policy=policy,
        baseline=baseline,
        batch_size=args.n_timesteps,
        max_path_length=args.max_traj_len,
        #max_path_length_limit=args.max_path_length_limit,
        update_max_path_length=args.update_curriculum,
        anneal_step_size=args.anneal_step_size,
        n_itr=args.n_iter,
        discount=args.discount,
        gae_lambda=args.gae_lambda,
        step_size=args.max_kl,
        optimizer=ConjugateGradientOptimizer(hvp_approach=FiniteDifferenceHvp(base_eps=1e-5)) if
        args.recurrent else None,
        mode=args.control if not args.chunked else 'chunk_{}'.format(args.control),)

    algo.train()
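The scripts above pass layer sizes as comma-separated strings and convert them with tuple(map(int, ...)). A tiny stdlib-only illustration of that parsing step; the flag name and default mirror the ones used above.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--policy_hidden_sizes', type=str, default='128,128')
args = parser.parse_args([])  # use the default for this demo

hidden_sizes = tuple(map(int, args.policy_hidden_sizes.split(',')))
print(hidden_sizes)  # (128, 128)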
Exemple #40
0
def main():
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')
    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--exp_name', type=str, default=default_exp_name, help='Name of the experiment.')

    parser.add_argument('--discount', type=float, default=0.95)
    parser.add_argument('--gae_lambda', type=float, default=0.99)

    parser.add_argument('--n_iter', type=int, default=250)
    parser.add_argument('--sampler_workers', type=int, default=1)
    parser.add_argument('--max_traj_len', type=int, default=250)
    parser.add_argument('--update_curriculum', action='store_true', default=False)
    parser.add_argument('--n_timesteps', type=int, default=8000)

    parser.add_argument('--control', type=str, default='centralized')
    parser.add_argument('--buffer_size', type=int, default=1)
    parser.add_argument('--n_good', type=int, default=3)
    parser.add_argument('--n_hostage', type=int, default=5)
    parser.add_argument('--n_bad', type=int, default=5)
    parser.add_argument('--n_coop_save', type=int, default=2)
    parser.add_argument('--n_coop_avoid', type=int, default=2)
    parser.add_argument('--n_sensors', type=int, default=20)
    parser.add_argument('--sensor_range', type=float, default=0.2)
    parser.add_argument('--save_reward', type=float, default=3)
    parser.add_argument('--hit_reward', type=float, default=-1)
    parser.add_argument('--encounter_reward', type=float, default=0.01)
    parser.add_argument('--bomb_reward', type=float, default=-10.)

    parser.add_argument('--recurrent', action='store_true', default=False)
    parser.add_argument('--baseline_type', type=str, default='linear')
    parser.add_argument('--policy_hidden_sizes', type=str, default='128,128')
    parser.add_argument('--baseline_hidden_sizes', type=str, default='128,128')

    parser.add_argument('--max_kl', type=float, default=0.01)

    parser.add_argument('--log_dir', type=str, required=False)
    parser.add_argument('--tabular_log_file', type=str, default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file', type=str, default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file', type=str, default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--seed', type=int,
                        help='Random seed for numpy')
    parser.add_argument('--args_data', type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--snapshot_mode', type=str, default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                             '(all iterations will be saved), "last" (only '
                             'the last iteration will be saved), or "none" '
                             '(do not save snapshots)')
    parser.add_argument('--log_tabular_only', type=ast.literal_eval, default=False,
                        help='Whether to only print the tabular log information (in a horizontal format)')


    args = parser.parse_args()

    parallel_sampler.initialize(n_parallel=args.sampler_workers)

    if args.seed is not None:
        set_seed(args.seed)
        parallel_sampler.set_seed(args.seed)

    args.hidden_sizes = tuple(map(int, args.policy_hidden_sizes.split(',')))

    centralized = True if args.control == 'centralized' else False

    # sensor_range is a single float for this environment (see the --sensor_range argument above)
    sensor_range = args.sensor_range

    env = ContinuousHostageWorld(args.n_good, args.n_hostage, args.n_bad, args.n_coop_save,
                                 args.n_coop_avoid, n_sensors=args.n_sensors,
                                 sensor_range=args.sensor_range, save_reward=args.save_reward,
                                 hit_reward=args.hit_reward, encounter_reward=args.encounter_reward,
                                 bomb_reward=args.bomb_reward)

    env = RLLabEnv(StandardizedEnv(env), mode=args.control)

    if args.buffer_size > 1:
        env = ObservationBuffer(env, args.buffer_size)

    if args.recurrent:
        policy = GaussianGRUPolicy(env_spec=env.spec, hidden_sizes=args.hidden_sizes)
    else:
        policy = GaussianMLPPolicy(env_spec=env.spec, hidden_sizes=args.hidden_sizes)

    if args.baseline_type == 'linear':
        baseline = LinearFeatureBaseline(env_spec=env.spec)
    else:
        baseline = ZeroBaseline(env_spec=env.spec)

    # logger
    default_log_dir = config.LOG_DIR
    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    logger.log_parameters_lite(params_log_file, args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    algo = TRPO(env=env,
            policy=policy,
            baseline=baseline,
            batch_size=args.n_timesteps,
            max_path_length=args.max_traj_len,
            n_itr=args.n_iter,
            discount=args.discount,
            step_size=args.max_kl,
            mode=args.control,)

    algo.train()


def run_experiment(argv):
    default_log_dir = config.LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')

    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
    parser = argparse.ArgumentParser()
    parser.add_argument('--n_parallel', type=int, default=1,
                        help='Number of parallel workers to perform rollouts. 0 => don\'t start any workers')
    parser.add_argument(
        '--exp_name', type=str, default=default_exp_name, help='Name of the experiment.')
    parser.add_argument('--log_dir', type=str, default=None,
                        help='Path to save the log and iteration snapshot.')
    parser.add_argument('--snapshot_mode', type=str, default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                             '(all iterations will be saved), "last" (only '
                             'the last iteration will be saved), "gap" (every'
                             '`snapshot_gap` iterations are saved), or "none" '
                             '(do not save snapshots)')
    parser.add_argument('--snapshot_gap', type=int, default=1,
                        help='Gap between snapshot iterations.')
    parser.add_argument('--tabular_log_file', type=str, default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file', type=str, default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file', type=str, default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--variant_log_file', type=str, default='variant.json',
                        help='Name of the variant log file (in json).')
    parser.add_argument('--resume_from', type=str, default=None,
                        help='Name of the pickle file to resume experiment from.')
    parser.add_argument('--plot', type=ast.literal_eval, default=False,
                        help='Whether to plot the iteration results')
    parser.add_argument('--log_tabular_only', type=ast.literal_eval, default=False,
                        help='Whether to only print the tabular log information (in a horizontal format)')
    parser.add_argument('--seed', type=int,
                        help='Random seed for numpy')
    parser.add_argument('--args_data', type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--variant_data', type=str,
                        help='Pickled data for variant configuration')
    parser.add_argument('--use_cloudpickle', type=ast.literal_eval, default=False,
                        help='Whether args_data was serialized with cloudpickle.')

    args = parser.parse_args(argv[1:])

    if args.seed is not None:
        set_seed(args.seed)

    if args.n_parallel > 0:
        from rllab.sampler import parallel_sampler
        parallel_sampler.initialize(n_parallel=args.n_parallel)
        if args.seed is not None:
            parallel_sampler.set_seed(args.seed)

    if args.plot:
        from rllab.plotter import plotter
        plotter.init_worker()

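    # Resolve the output directory and the per-file log paths.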
    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

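    # Variant data arrives base64-encoded and pickled; decode and log it if present.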
    if args.variant_data is not None:
        variant_data = pickle.loads(base64.b64decode(args.variant_data))
        variant_log_file = osp.join(log_dir, args.variant_log_file)
        logger.log_variant(variant_log_file, variant_data)
    else:
        variant_data = None

    if not args.use_cloudpickle:
        logger.log_parameters_lite(params_log_file, args)

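    # Redirect logger outputs to this experiment's files and remember the previous
    # snapshot settings so they can be restored after training.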
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_snapshot_gap(args.snapshot_gap)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

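    # Either resume a previously pickled run or execute the pickled call / stub objects.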
    if args.resume_from is not None:
        data = joblib.load(args.resume_from)
        assert 'algo' in data
        algo = data['algo']
        algo.train()
    else:
        # read from stdin
        if args.use_cloudpickle:
            import cloudpickle
            method_call = cloudpickle.loads(base64.b64decode(args.args_data))
            method_call(variant_data)
        else:
            data = pickle.loads(base64.b64decode(args.args_data))
            maybe_iter = concretize(data)
            if is_iterable(maybe_iter):
                for _ in maybe_iter:
                    pass

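    # Restore the logger to its previous state and detach this experiment's outputs.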
    logger.set_snapshot_mode(prev_mode)
    logger.set_snapshot_dir(prev_snapshot_dir)
    logger.remove_tabular_output(tabular_log_file)
    logger.remove_text_output(text_log_file)
    logger.pop_prefix()
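

# A minimal, hypothetical entry point (not part of the original snippet) showing
# how run_experiment is typically invoked when the module is executed directly:
#
#     if __name__ == '__main__':
#         import sys
#         run_experiment(sys.argv)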