コード例 #1
0
ファイル: log_utils.py プロジェクト: saadmahboob/inverse_rl
def rllab_logdir(algo=None, dirname=None):
    """Route rllab tabular logging into the snapshot directory for a scope.

    Generator intended to be driven as a context manager (presumably wrapped
    with ``contextlib.contextmanager`` where it is defined — the decorator is
    not visible here; confirm). The ``progress.csv`` tabular output is
    detached again when the scope exits, including when it exits by an
    exception.

    Args:
        algo: Optional algorithm object. When given, its hyperparameters (as
            returned by ``extract_hyperparams``) are written as JSON to
            ``params.json`` in the snapshot directory.
        dirname: Optional directory. When given, it becomes the rllab snapshot
            directory before logging is configured.

    Yields:
        The snapshot directory in use (``rllablogger.get_snapshot_dir()``).
    """
    if dirname:
        rllablogger.set_snapshot_dir(dirname)
    dirname = rllablogger.get_snapshot_dir()
    progress_csv = os.path.join(dirname, 'progress.csv')
    rllablogger.add_tabular_output(progress_csv)
    try:
        if algo:
            # Extract before opening the file, so a failure here does not
            # leave an empty params.json behind.
            params = extract_hyperparams(algo)
            with open(os.path.join(dirname, 'params.json'), 'w') as f:
                json.dump(params, f)
        yield dirname
    finally:
        # Detach even if the caller's body raised. Otherwise the CSV output
        # would stay registered with rllablogger after an error.
        rllablogger.remove_tabular_output(progress_csv)
                    default=1e-5,
                    type=float,
                    help="Regularization coefficient for TRPO")
# Logging destinations: where the text and tabular logger outputs go.
parser.add_argument("--text_log_file",
                    default="./data/debug.log",
                    help="Where text output will go")
parser.add_argument("--tabular_log_file",
                    default="./data/progress.csv",
                    help="Where tabular output will go")
args = parser.parse_args()

# stub(globals())

# ext.set_seed(1)
# Attach both outputs for the rest of the script; the visible code never
# removes them again.
logger.add_text_output(args.text_log_file)
logger.add_tabular_output(args.tabular_log_file)
logger.set_log_tabular_only(False)

# (env_name, wrapped env) pairs, one per requested gym environment id.
envs = []

for env_name in args.envs:
    # Gym env with forced resets and no video/monitor recording, then
    # wrapped with normalize() and TfEnv.
    gymenv = GymEnv(env_name,
                    force_reset=True,
                    record_video=False,
                    record_log=False)
    env = TfEnv(normalize(gymenv))
    envs.append((env_name, env))
# NOTE(review): `env` leaks out of the loop and is the LAST wrapped env. The
# policy constructed next uses `env.spec`. That assumes every env in
# `args.envs` shares the same spaces, and that `args.envs` is non-empty
# (otherwise `env` is undefined) — TODO confirm.
policy = GaussianMLPPolicy(
    name="policy",
    env_spec=env.spec,
コード例 #3
0
from rllab.misc import logger
import os.path as osp
import tensorflow as tf
from sandbox.rocky.tf.samplers.batch_sampler import BatchSampler
import joblib

# Hard-coded run configuration: every log/snapshot file lives under log_dir.
env_path = "./TrainEnv"
log_dir = "./Data/obs_1goal20step0stay_1_kdist_01_keep5"

tabular_log_file = osp.join(log_dir, "progress.csv")
text_log_file = osp.join(log_dir, "debug.log")
params_log_file = osp.join(log_dir, "params.json")
pkl_file = osp.join(log_dir, "params.pkl")

# Route logger output into log_dir. Snapshots use mode "gaplast" with a gap
# of 100 iterations. The previous snapshot dir/mode are kept in
# prev_snapshot_dir / prev_mode.
logger.add_text_output(text_log_file)
logger.add_tabular_output(tabular_log_file)
prev_snapshot_dir = logger.get_snapshot_dir()
prev_mode = logger.get_snapshot_mode()
logger.set_snapshot_dir(log_dir)
logger.set_snapshot_mode("gaplast")
logger.set_snapshot_gap(100)
logger.set_log_tabular_only(False)
logger.push_prefix("[%s] " % "FixMapStartState")

# A single sampler worker with a fixed seed.
from Algo import parallel_sampler
parallel_sampler.initialize(n_parallel=1)
parallel_sampler.set_seed(0)

with tf.Session() as sess:
    # Load the snapshot through pkl_file instead of rebuilding the same path
    # by string concatenation.
    params = joblib.load(pkl_file)
    # Presumably the iteration the snapshot was saved at — TODO confirm.
    itr = params['itr']
コード例 #4
0
        # Seed the global RNG helper and the parallel sampler when a seed is given.
        if seed is not None:
            set_seed(seed)
            parallel_sampler.set_seed(seed)

        # Start the plotter worker only when plotting was requested.
        if plot:
            from rllab.plotter import plotter
            plotter.init_worker()

        # Resolve the configured log file names relative to log_dir.
        tabular_log_file_fullpath = osp.join(log_dir, tabular_log_file)
        text_log_file_fullpath = osp.join(log_dir, text_log_file)
        # params_log_file_fullpath = osp.join(log_dir, params_log_file)
        params_all_log_file_fullpath = osp.join(log_dir, params_all_log_file)

        # logger.log_parameters_lite(params_log_file, args)
        # Send logger output into log_dir and tag lines with the experiment
        # name. prev_snapshot_dir / prev_mode are captured, presumably so they
        # can be restored after the run. No restore is visible in this
        # fragment — TODO confirm.
        logger.add_text_output(text_log_file_fullpath)
        logger.add_tabular_output(tabular_log_file_fullpath)
        prev_snapshot_dir = logger.get_snapshot_dir()
        prev_mode = logger.get_snapshot_mode()
        logger.set_snapshot_dir(log_dir)
        logger.set_snapshot_mode(snapshot_mode)
        logger.set_log_tabular_only(log_tabular_only)
        logger.push_prefix("[%s] " % exp_name)

        ############################################################
        ## Dumping config
        # Write the whole `params` config as block-style YAML next to the logs.
        with open(params_all_log_file_fullpath, 'w') as yaml_file:
            yaml_file.write(yaml.dump(params, default_flow_style=False))

        ############################################################
        ## RUNNING THE EXPERIMENT
        logger.log('Running the experiment ...')
コード例 #5
0
def run_experiment(argv):
    """Parse launcher arguments and run a single rllab experiment.

    Optionally seeds (``--seed``) and starts the parallel sampler
    (``--n_parallel``) and the plotter (``--plot``). It records the variant
    and CLI parameters, then attaches text/tabular outputs to ``logger`` and
    points its snapshots at the run's log directory.

    Finally it either resumes ``algo.train()`` from a joblib snapshot
    (``--resume_from``) or executes the base64-encoded pickled call passed
    in ``--args_data``. The logger outputs, snapshot dir/mode and prefix
    installed here are undone when the run ends, including when it raises.

    Args:
        argv: Full argument vector; ``argv[0]`` (the program name) is skipped.
    """
    default_log_dir = config.LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')

    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
    parser = argparse.ArgumentParser()
    parser.add_argument('--n_parallel', type=int, default=1,
                        help='Number of parallel workers to perform rollouts. 0 => don\'t start any workers. '
                             'If you use GPU then you probably have only one thus set n_parallel=1. Remember, this '
                             'number is not a number of processors you want to use but number of parallel '
                             'samplers of rollouts.')
    parser.add_argument(
        '--exp_name', type=str, default=default_exp_name, help='Name of the experiment.')
    parser.add_argument('--log_dir', type=str, default=None,
                        help='Path to save the log and iteration snapshot.')
    parser.add_argument('--snapshot_mode', type=str, default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                             '(all iterations will be saved), "last" (only '
                             'the last iteration will be saved), "gap" (every '
                             '`snapshot_gap` iterations are saved), or "none" '
                             '(do not save snapshots)')
    parser.add_argument('--snapshot_gap', type=int, default=1,
                        help='Gap between snapshot iterations.')
    parser.add_argument('--tabular_log_file', type=str, default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file', type=str, default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file', type=str, default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--variant_log_file', type=str, default='variant.json',
                        help='Name of the variant log file (in json).')
    parser.add_argument('--resume_from', type=str, default=None,
                        help='Name of the pickle file to resume experiment from.')
    parser.add_argument('--plot', type=ast.literal_eval, default=False,
                        help='Whether to plot the iteration results')
    parser.add_argument('--log_tabular_only', type=ast.literal_eval, default=False,
                        help='Whether to only print the tabular log information (in a horizontal format)')
    parser.add_argument('--seed', type=int,
                        help='Random seed for numpy')
    parser.add_argument('--args_data', type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--variant_data', type=str,
                        help='Pickled data for variant configuration')
    parser.add_argument('--use_cloudpickle', type=ast.literal_eval, default=False)

    args = parser.parse_args(argv[1:])

    # Without a snapshot or a serialized call there is nothing to run. Fail
    # here, before any sampler/logger side effects, rather than with a
    # TypeError from base64.b64decode(None) further down.
    if args.resume_from is None and args.args_data is None:
        parser.error('one of --resume_from or --args_data is required')

    if args.seed is not None:
        set_seed(args.seed)

    if args.n_parallel > 0:
        from rllab.sampler import parallel_sampler
        parallel_sampler.initialize(n_parallel=args.n_parallel)
        if args.seed is not None:
            parallel_sampler.set_seed(args.seed)

    if args.plot:
        from rllab.plotter import plotter
        plotter.init_worker()

    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    # NOTE(review): the pickle/cloudpickle payloads decoded in this function
    # execute arbitrary code when loaded. They must only come from a trusted
    # launcher.
    if args.variant_data is not None:
        variant_data = pickle.loads(base64.b64decode(args.variant_data))
        variant_log_file = osp.join(log_dir, args.variant_log_file)
        logger.log_variant(variant_log_file, variant_data)
    else:
        variant_data = None

    if not args.use_cloudpickle:
        logger.log_parameters_lite(params_log_file, args)

    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_snapshot_gap(args.snapshot_gap)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    try:
        if args.resume_from is not None:
            data = joblib.load(args.resume_from)
            assert 'algo' in data
            algo = data['algo']
            algo.train()
        elif args.use_cloudpickle:
            # The serialized callable comes from --args_data and receives the
            # decoded variant.
            import cloudpickle
            method_call = cloudpickle.loads(base64.b64decode(args.args_data))
            method_call(variant_data)
        else:
            data = pickle.loads(base64.b64decode(args.args_data))
            maybe_iter = concretize(data)
            if is_iterable(maybe_iter):
                # Drain generator-style experiments so they run to completion.
                for _ in maybe_iter:
                    pass
    finally:
        # Undo the logger configuration above even if the experiment raised.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
コード例 #6
0
ファイル: rurllab.py プロジェクト: TJUSCS-RLLAB/MADRL
    def setup(self, env, policy, start_itr):
        """Configure the rllab logger and build the training algorithm.

        All settings are read from ``self.args``. For ``tftrpo`` a ``linear``
        or ``zero`` baseline is built, one independent instance per agent
        when ``control == 'concurrent'``. For ``thddpg`` a Q-function and an
        ``ou``/``gauss`` exploration strategy are built instead.

        Args:
            env: Environment exposing ``spec`` (and ``agents`` for concurrent
                control).
            policy: Policy, or per-agent policies, handed to the algorithm.
            start_itr: Iteration to start from (used by ``tftrpo`` only).

        Returns:
            The constructed ``MATRPO`` or ``thDDPG`` instance.

        Raises:
            NotImplementedError: For an unknown ``algo``, ``baseline_type`` or
                ``exp_strategy``.
        """
        # Reject unknown algorithms before writing any log files. Otherwise no
        # algorithm would be built and there would be nothing to return.
        if self.args.algo not in ('tftrpo', 'thddpg'):
            raise NotImplementedError(self.args.algo)

        if self.args.algo == 'tftrpo':
            # Baseline
            if self.args.baseline_type == 'linear':
                baseline_cls = LinearFeatureBaseline
            elif self.args.baseline_type == 'zero':
                baseline_cls = ZeroBaseline
            else:
                raise NotImplementedError(self.args.baseline_type)

            if self.args.control == 'concurrent':
                # One independent baseline per agent. A list repeating a single
                # instance would make all agents share its fitted state.
                baseline = [baseline_cls(env_spec=env.spec) for _ in range(len(env.agents))]
            else:
                baseline = baseline_cls(env_spec=env.spec)

        # Logger
        default_log_dir = config.LOG_DIR
        if self.args.log_dir is None:
            log_dir = osp.join(default_log_dir, self.args.exp_name)
        else:
            log_dir = self.args.log_dir

        tabular_log_file = osp.join(log_dir, self.args.tabular_log_file)
        text_log_file = osp.join(log_dir, self.args.text_log_file)
        params_log_file = osp.join(log_dir, self.args.params_log_file)

        logger.log_parameters_lite(params_log_file, self.args)
        logger.add_text_output(text_log_file)
        logger.add_tabular_output(tabular_log_file)
        logger.set_snapshot_dir(log_dir)
        logger.set_snapshot_mode(self.args.snapshot_mode)
        logger.set_log_tabular_only(self.args.log_tabular_only)
        logger.push_prefix("[%s] " % self.args.exp_name)

        if self.args.algo == 'tftrpo':
            # Finite-difference HVP conjugate-gradient optimizer for recurrent
            # policies only. With None, presumably MATRPO uses its default.
            optimizer = ConjugateGradientOptimizer(
                hvp_approach=FiniteDifferenceHvp(base_eps=1e-5)) if self.args.recurrent else None
            return MATRPO(env=env, policy_or_policies=policy, baseline_or_baselines=baseline,
                          batch_size=self.args.batch_size, start_itr=start_itr,
                          max_path_length=self.args.max_path_length, n_itr=self.args.n_iter,
                          discount=self.args.discount, gae_lambda=self.args.gae_lambda,
                          step_size=self.args.step_size, optimizer=optimizer,
                          ma_mode=self.args.control)

        # thddpg
        qfunc = thContinuousMLPQFunction(env_spec=env.spec)
        if self.args.exp_strategy == 'ou':
            es = OUStrategy(env_spec=env.spec)
        elif self.args.exp_strategy == 'gauss':
            es = GaussianStrategy(env_spec=env.spec)
        else:
            raise NotImplementedError(self.args.exp_strategy)

        return thDDPG(env=env, policy=policy, qf=qfunc, es=es, batch_size=self.args.batch_size,
                      max_path_length=self.args.max_path_length,
                      epoch_length=self.args.epoch_length,
                      min_pool_size=self.args.min_pool_size,
                      replay_pool_size=self.args.replay_pool_size, n_epochs=self.args.n_iter,
                      discount=self.args.discount, scale_reward=0.01,
                      qf_learning_rate=self.args.qfunc_lr,
                      policy_learning_rate=self.args.policy_lr,
                      eval_samples=self.args.eval_samples, mode=self.args.control)
コード例 #7
0
def run_experiment(argv):
    """Parse launcher arguments and run a single rllab experiment.

    Optionally seeds (``--seed``) and starts the parallel sampler
    (``--n_parallel``) and the plotter (``--plot``). It records the variant
    and CLI parameters, then attaches text/tabular outputs to ``logger`` and
    points its snapshots at the run's log directory.

    Finally it either resumes ``algo.train()`` from a joblib snapshot
    (``--resume_from``) or executes the base64-encoded pickled call passed
    in ``--args_data``. The logger outputs, snapshot dir/mode and prefix
    installed here are undone when the run ends, including when it raises.

    Args:
        argv: Full argument vector; ``argv[0]`` (the program name) is skipped.
    """
    default_log_dir = config.LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')

    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
    parser = argparse.ArgumentParser()
    parser.add_argument('--n_parallel', type=int, default=1,
                        help='Number of parallel workers to perform rollouts. 0 => don\'t start any workers')
    parser.add_argument(
        '--exp_name', type=str, default=default_exp_name, help='Name of the experiment.')
    parser.add_argument('--log_dir', type=str, default=None,
                        help='Path to save the log and iteration snapshot.')
    parser.add_argument('--snapshot_mode', type=str, default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                             '(all iterations will be saved), "last" (only '
                             'the last iteration will be saved), "gap" (every '
                             '`snapshot_gap` iterations are saved), or "none" '
                             '(do not save snapshots)')
    parser.add_argument('--snapshot_gap', type=int, default=1,
                        help='Gap between snapshot iterations.')
    parser.add_argument('--tabular_log_file', type=str, default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file', type=str, default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file', type=str, default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--variant_log_file', type=str, default='variant.json',
                        help='Name of the variant log file (in json).')
    parser.add_argument('--resume_from', type=str, default=None,
                        help='Name of the pickle file to resume experiment from.')
    parser.add_argument('--plot', type=ast.literal_eval, default=False,
                        help='Whether to plot the iteration results')
    parser.add_argument('--log_tabular_only', type=ast.literal_eval, default=False,
                        help='Whether to only print the tabular log information (in a horizontal format)')
    parser.add_argument('--seed', type=int,
                        help='Random seed for numpy')
    parser.add_argument('--args_data', type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--variant_data', type=str,
                        help='Pickled data for variant configuration')
    parser.add_argument('--use_cloudpickle', type=ast.literal_eval, default=False)

    args = parser.parse_args(argv[1:])

    # Without a snapshot or a serialized call there is nothing to run. Fail
    # here, before any sampler/logger side effects, rather than with a
    # TypeError from base64.b64decode(None) further down.
    if args.resume_from is None and args.args_data is None:
        parser.error('one of --resume_from or --args_data is required')

    if args.seed is not None:
        set_seed(args.seed)

    if args.n_parallel > 0:
        from rllab.sampler import parallel_sampler
        parallel_sampler.initialize(n_parallel=args.n_parallel)
        if args.seed is not None:
            parallel_sampler.set_seed(args.seed)

    if args.plot:
        from rllab.plotter import plotter
        plotter.init_worker()

    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    # NOTE(review): the pickle/cloudpickle payloads decoded in this function
    # execute arbitrary code when loaded. They must only come from a trusted
    # launcher.
    if args.variant_data is not None:
        variant_data = pickle.loads(base64.b64decode(args.variant_data))
        variant_log_file = osp.join(log_dir, args.variant_log_file)
        logger.log_variant(variant_log_file, variant_data)
    else:
        variant_data = None

    if not args.use_cloudpickle:
        logger.log_parameters_lite(params_log_file, args)

    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_snapshot_gap(args.snapshot_gap)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    try:
        if args.resume_from is not None:
            data = joblib.load(args.resume_from)
            assert 'algo' in data
            algo = data['algo']
            algo.train()
        elif args.use_cloudpickle:
            # The serialized callable comes from --args_data and receives the
            # decoded variant.
            import cloudpickle
            method_call = cloudpickle.loads(base64.b64decode(args.args_data))
            method_call(variant_data)
        else:
            data = pickle.loads(base64.b64decode(args.args_data))
            maybe_iter = concretize(data)
            if is_iterable(maybe_iter):
                # Drain generator-style experiments so they run to completion.
                for _ in maybe_iter:
                    pass
    finally:
        # Undo the logger configuration above even if the experiment raised.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
コード例 #8
0
    def setup(self, env, policy, start_itr):
        """Configure the rllab logger and build the training algorithm.

        All settings are read from ``self.args``. For ``tftrpo`` a ``linear``
        or ``zero`` baseline is built, one independent instance per agent
        when ``control == 'concurrent'``. For ``thddpg`` a Q-function and an
        ``ou``/``gauss`` exploration strategy are built instead.

        Args:
            env: Environment exposing ``spec`` (and ``agents`` for concurrent
                control).
            policy: Policy, or per-agent policies, handed to the algorithm.
            start_itr: Iteration to start from (used by ``tftrpo`` only).

        Returns:
            The constructed ``MATRPO`` or ``thDDPG`` instance.

        Raises:
            NotImplementedError: For an unknown ``algo``, ``baseline_type`` or
                ``exp_strategy``.
        """
        # Reject unknown algorithms before writing any log files. Otherwise no
        # algorithm would be built and there would be nothing to return.
        if self.args.algo not in ('tftrpo', 'thddpg'):
            raise NotImplementedError(self.args.algo)

        if self.args.algo == 'tftrpo':
            # Baseline
            if self.args.baseline_type == 'linear':
                baseline_cls = LinearFeatureBaseline
            elif self.args.baseline_type == 'zero':
                baseline_cls = ZeroBaseline
            else:
                raise NotImplementedError(self.args.baseline_type)

            if self.args.control == 'concurrent':
                # One independent baseline per agent. A list repeating a single
                # instance would make all agents share its fitted state.
                baseline = [
                    baseline_cls(env_spec=env.spec)
                    for _ in range(len(env.agents))
                ]
            else:
                baseline = baseline_cls(env_spec=env.spec)

        # Logger
        default_log_dir = config.LOG_DIR
        if self.args.log_dir is None:
            log_dir = osp.join(default_log_dir, self.args.exp_name)
        else:
            log_dir = self.args.log_dir

        tabular_log_file = osp.join(log_dir, self.args.tabular_log_file)
        text_log_file = osp.join(log_dir, self.args.text_log_file)
        params_log_file = osp.join(log_dir, self.args.params_log_file)

        logger.log_parameters_lite(params_log_file, self.args)
        logger.add_text_output(text_log_file)
        logger.add_tabular_output(tabular_log_file)
        logger.set_snapshot_dir(log_dir)
        logger.set_snapshot_mode(self.args.snapshot_mode)
        logger.set_log_tabular_only(self.args.log_tabular_only)
        logger.push_prefix("[%s] " % self.args.exp_name)

        if self.args.algo == 'tftrpo':
            # Finite-difference HVP conjugate-gradient optimizer for recurrent
            # policies only. With None, presumably MATRPO uses its default.
            if self.args.recurrent:
                optimizer = ConjugateGradientOptimizer(
                    hvp_approach=FiniteDifferenceHvp(base_eps=1e-5))
            else:
                optimizer = None
            return MATRPO(
                env=env,
                policy_or_policies=policy,
                baseline_or_baselines=baseline,
                batch_size=self.args.batch_size,
                start_itr=start_itr,
                max_path_length=self.args.max_path_length,
                n_itr=self.args.n_iter,
                discount=self.args.discount,
                gae_lambda=self.args.gae_lambda,
                step_size=self.args.step_size,
                optimizer=optimizer,
                ma_mode=self.args.control)

        # thddpg
        qfunc = thContinuousMLPQFunction(env_spec=env.spec)
        if self.args.exp_strategy == 'ou':
            es = OUStrategy(env_spec=env.spec)
        elif self.args.exp_strategy == 'gauss':
            es = GaussianStrategy(env_spec=env.spec)
        else:
            raise NotImplementedError(self.args.exp_strategy)

        return thDDPG(env=env,
                      policy=policy,
                      qf=qfunc,
                      es=es,
                      batch_size=self.args.batch_size,
                      max_path_length=self.args.max_path_length,
                      epoch_length=self.args.epoch_length,
                      min_pool_size=self.args.min_pool_size,
                      replay_pool_size=self.args.replay_pool_size,
                      n_epochs=self.args.n_iter,
                      discount=self.args.discount,
                      scale_reward=0.01,
                      qf_learning_rate=self.args.qfunc_lr,
                      policy_learning_rate=self.args.policy_lr,
                      eval_samples=self.args.eval_samples,
                      mode=self.args.control)
コード例 #9
0
                            'eps-opt', 'all', 'self'
                        ],
                        default='all')
    parser.add_argument('-l', '--log_dir', type=str, default="./data")
    args = parser.parse_args()
    log_dir = args.log_dir

    num_eval_traj = args.num_eval_traj

    # Text log and info-theory table both live under log_dir (when given).
    text_output_file = None if log_dir is None else osp.join(log_dir, "text")
    info_theory_tabular_output = None if log_dir is None else osp.join(
        log_dir, "info_table.csv")

    if text_output_file is not None:
        logger.add_text_output(text_output_file)
        logger.add_tabular_output(info_theory_tabular_output)

    with tf.Session() as sess:
        # The snapshot must provide 'policy', 'env' and 'baseline' entries.
        data = joblib.load(args.file)
        policy = data['policy']
        env = data['env']
        algo = PPOSGD(env=env, baseline=data['baseline'], policy=policy)
        bandit = env.wrapped_env
        logger.log("Loaded policy trained against {} for {} timesteps".format(
            env.pi_H.__class__.__name__, bandit.horizon))
        # Never roll out past the bandit's horizon.
        max_timesteps = min(bandit.horizon, args.max_path_length)

        # Build the (name, policy) pairs to evaluate against.
        if args.eval_against == 'all':
            itr = human_policy_dict.items()
        elif args.eval_against == 'self':
            # A single 'self' entry with no accompanying policy object.
            itr = [(args.eval_against, None)]
コード例 #10
0
def main():
    """Train a categorical policy with TRPO on the PursuitEvade environment.

    Steps:
    - Parse command-line options.
    - Build a PursuitEvade environment wrapped in ``StandardizedEnv`` and
      ``RLLabEnv``. It uses a single ``rectangle``/``complex`` map, or maps
      loaded from ``--map_file`` when ``--sample_maps`` is set.
    - Build a categorical policy: MLP, conv, or GRU when ``--recurrent gru``.
    - Build a linear or zero baseline.
    - Point the rllab ``logger`` at the run's log directory and run
      ``TRPO.train()``.

    The logger outputs, snapshot dir/mode and prefix installed here are
    undone afterwards, including when training raises.

    Raises:
        NotImplementedError: For an unknown ``--map_type`` or an unsupported
            ``--recurrent`` cell type (only ``gru`` is handled).
    """
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')
    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--exp_name', type=str, default=default_exp_name, help='Name of the experiment.')

    parser.add_argument('--discount', type=float, default=0.99)
    parser.add_argument('--gae_lambda', type=float, default=1.0)
    parser.add_argument('--reward_scale', type=float, default=1.0)

    parser.add_argument('--n_iter', type=int, default=250)
    parser.add_argument('--sampler_workers', type=int, default=1)
    parser.add_argument('--max_traj_len', type=int, default=250)
    parser.add_argument('--update_curriculum', action='store_true', default=False)
    parser.add_argument('--n_timesteps', type=int, default=8000)
    parser.add_argument('--control', type=str, default='centralized')

    parser.add_argument('--rectangle', type=str, default='10,10')
    parser.add_argument('--map_type', type=str, default='rectangle')
    parser.add_argument('--n_evaders', type=int, default=5)
    parser.add_argument('--n_pursuers', type=int, default=2)
    parser.add_argument('--obs_range', type=int, default=3)
    parser.add_argument('--n_catch', type=int, default=2)
    parser.add_argument('--urgency', type=float, default=0.0)
    parser.add_argument('--pursuit', dest='train_pursuit', action='store_true')
    parser.add_argument('--evade', dest='train_pursuit', action='store_false')
    parser.set_defaults(train_pursuit=True)
    parser.add_argument('--surround', action='store_true', default=False)
    parser.add_argument('--constraint_window', type=float, default=1.0)
    parser.add_argument('--sample_maps', action='store_true', default=False)
    parser.add_argument('--map_file', type=str, default='../maps/map_pool.npy')
    parser.add_argument('--flatten', action='store_true', default=False)
    parser.add_argument('--reward_mech', type=str, default='global')
    parser.add_argument('--catchr', type=float, default=0.1)
    parser.add_argument('--term_pursuit', type=float, default=5.0)

    parser.add_argument('--recurrent', type=str, default=None)
    parser.add_argument('--policy_hidden_sizes', type=str, default='128,128')
    parser.add_argument('--baselin_hidden_sizes', type=str, default='128,128')
    parser.add_argument('--baseline_type', type=str, default='linear')

    parser.add_argument('--conv', action='store_true', default=False)

    parser.add_argument('--max_kl', type=float, default=0.01)

    parser.add_argument('--log_dir', type=str, required=False)
    parser.add_argument('--tabular_log_file', type=str, default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file', type=str, default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file', type=str, default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--seed', type=int,
                        help='Random seed for numpy')
    parser.add_argument('--args_data', type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--snapshot_mode', type=str, default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                             '(all iterations will be saved), "last" (only '
                             'the last iteration will be saved), or "none" '
                             '(do not save snapshots)')
    parser.add_argument('--log_tabular_only', type=ast.literal_eval, default=False,
                        help='Whether to only print the tabular log information (in a horizontal format)')

    args = parser.parse_args()

    parallel_sampler.initialize(n_parallel=args.sampler_workers)

    if args.seed is not None:
        set_seed(args.seed)
        parallel_sampler.set_seed(args.seed)

    # e.g. '128,128' -> (128, 128)
    args.hidden_sizes = tuple(map(int, args.policy_hidden_sizes.split(',')))

    if args.sample_maps:
        map_pool = np.load(args.map_file)
    else:
        if args.map_type == 'rectangle':
            env_map = TwoDMaps.rectangle_map(*map(int, args.rectangle.split(',')))
        elif args.map_type == 'complex':
            env_map = TwoDMaps.complex_map(*map(int, args.rectangle.split(',')))
        else:
            raise NotImplementedError(args.map_type)
        map_pool = [env_map]

    env = PursuitEvade(map_pool, n_evaders=args.n_evaders, n_pursuers=args.n_pursuers,
                       obs_range=args.obs_range, n_catch=args.n_catch,
                       train_pursuit=args.train_pursuit, urgency_reward=args.urgency,
                       surround=args.surround, sample_maps=args.sample_maps,
                       constraint_window=args.constraint_window,
                       flatten=args.flatten,
                       reward_mech=args.reward_mech,
                       catchr=args.catchr,
                       term_pursuit=args.term_pursuit)

    env = RLLabEnv(
        StandardizedEnv(env, scale_reward=args.reward_scale, enable_obsnorm=False),
        mode=args.control)

    def make_conv_feature_network():
        """Build the conv feature extractor shared by both conv policy variants."""
        # Lower-case 'valid' pads. Lasagne's Conv2DLayer only recognises
        # lower-case pad strings (assuming ConvNetwork is the Lasagne-based
        # rllab network, as the NL.* nonlinearities suggest).
        return ConvNetwork(
            input_shape=env.spec.observation_space.shape,
            output_dim=5,
            conv_filters=(8, 16, 16),
            conv_filter_sizes=(3, 3, 3),
            conv_strides=(1, 1, 1),
            conv_pads=('valid', 'valid', 'valid'),
            hidden_sizes=(64,),
            hidden_nonlinearity=NL.rectify,
            output_nonlinearity=NL.softmax)

    if args.recurrent:
        # Only GRU policies are implemented. Fail loudly instead of leaving
        # `policy` unbound.
        if args.recurrent != 'gru':
            raise NotImplementedError(args.recurrent)
        if args.conv:
            feature_network = make_conv_feature_network()
        else:
            # Input is observation + action, presumably matching the GRU
            # policy's action-augmented state — TODO confirm.
            feature_network = MLP(
                input_shape=(env.spec.observation_space.flat_dim + env.spec.action_space.flat_dim,),
                output_dim=5, hidden_sizes=(128, 128, 128), hidden_nonlinearity=NL.tanh,
                output_nonlinearity=None)
        # GRU width = first entry of --policy_hidden_sizes. This is identical
        # to int() for a single number, and does not crash on the '128,128'
        # default.
        policy = CategoricalGRUPolicy(env_spec=env.spec, feature_network=feature_network,
                                      hidden_dim=args.hidden_sizes[0])
    elif args.conv:
        policy = CategoricalMLPPolicy(env_spec=env.spec,
                                      prob_network=make_conv_feature_network())
    else:
        policy = CategoricalMLPPolicy(env_spec=env.spec, hidden_sizes=args.hidden_sizes)

    if args.baseline_type == 'linear':
        baseline = LinearFeatureBaseline(env_spec=env.spec)
    else:
        # Any other --baseline_type falls back to a zero baseline.
        baseline = ZeroBaseline(env_spec=env.spec)

    # logger
    default_log_dir = config.LOG_DIR
    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    logger.log_parameters_lite(params_log_file, args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    try:
        algo = TRPO(
            env=env,
            policy=policy,
            baseline=baseline,
            batch_size=args.n_timesteps,
            max_path_length=args.max_traj_len,
            n_itr=args.n_iter,
            discount=args.discount,
            gae_lambda=args.gae_lambda,
            step_size=args.max_kl,
            mode=args.control,)

        algo.train()
    finally:
        # Undo the logger configuration above even if training raised.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
コード例 #11
0
ファイル: resume_vrep.py プロジェクト: flyers/rllab
import base64
import joblib

# Log into a fresh, uniquely named experiment directory, then load an earlier
# snapshot (presumably to resume training from it).
# NOTE(review): both base paths below are hard-coded to one user's machine.
default_log_dir = '/home/sliay/Documents/rllab/data/local/experiment'
now = datetime.datetime.now(dateutil.tz.tzlocal())
# avoid name clashes when running distributed jobs
rand_id = str(uuid.uuid4())[:5]
timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')
default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
log_dir = os.path.join(default_log_dir, default_exp_name)
tabular_log_file = os.path.join(log_dir, 'progress.csv')
text_log_file = os.path.join(log_dir, 'debug.log')
# Computed but not used in the visible part of this script.
params_log_file = os.path.join(log_dir, 'params.json')

# Send logger output to the new directory and keep only the latest snapshot
# (snapshot mode 'last'). Prefix log lines with the experiment name.
logger.add_text_output(text_log_file)
logger.add_tabular_output(tabular_log_file)
logger.set_snapshot_dir(log_dir)
logger.set_snapshot_mode('last')
logger.set_log_tabular_only(False)
logger.push_prefix("[%s] " % default_exp_name)

# Snapshot to start from. Its params.pkl must contain the 'policy', 'env'
# and 'baseline' entries read below.
last_snapshot_dir = '/home/sliay/Documents/rllab/data/local/experiment/experiment_2016_07_07_498itr'
data = joblib.load(os.path.join(last_snapshot_dir, 'params.pkl'))
policy = data['policy']
env = data['env']
baseline = data['baseline']
# env = normalize(GymEnv("VREP-v0", record_video=False))

# policy = GaussianMLPPolicy(
#     env_spec=env.spec,
#     The neural network policy should have two hidden layers, each with 32 hidden units.
コード例 #12
0
def main():
    """Train a TRPO policy on the multi-agent pursuit-evasion environment.

    Parses command-line arguments, builds a ``PursuitEvade`` environment
    (wrapped in ``StandardizedEnv`` and ``RLLabEnv``), a categorical policy
    (MLP, convolutional, or GRU on top of an MLP/conv feature network) and a
    baseline, points the rllab logger at the run's log directory and runs
    ``TRPO.train()``. The logger's snapshot settings, outputs and prefix are
    restored afterwards, even if training raises.

    Raises:
        NotImplementedError: for an unknown ``--map_type`` or an unsupported
            ``--recurrent`` cell type (only ``gru`` is implemented).
    """
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    # Random suffix avoids name clashes between jobs started at the same time.
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')
    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)

    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_name',
                        type=str,
                        default=default_exp_name,
                        help='Name of the experiment.')

    parser.add_argument('--discount', type=float, default=0.99)
    parser.add_argument('--gae_lambda', type=float, default=1.0)
    parser.add_argument('--reward_scale', type=float, default=1.0)

    parser.add_argument('--n_iter', type=int, default=250)
    parser.add_argument('--sampler_workers', type=int, default=1)
    parser.add_argument('--max_traj_len', type=int, default=250)
    parser.add_argument('--update_curriculum',
                        action='store_true',
                        default=False)
    parser.add_argument('--n_timesteps', type=int, default=8000)
    parser.add_argument('--control', type=str, default='centralized')

    parser.add_argument('--rectangle', type=str, default='10,10')
    parser.add_argument('--map_type', type=str, default='rectangle')
    parser.add_argument('--n_evaders', type=int, default=5)
    parser.add_argument('--n_pursuers', type=int, default=2)
    parser.add_argument('--obs_range', type=int, default=3)
    parser.add_argument('--n_catch', type=int, default=2)
    parser.add_argument('--urgency', type=float, default=0.0)
    parser.add_argument('--pursuit', dest='train_pursuit', action='store_true')
    parser.add_argument('--evade', dest='train_pursuit', action='store_false')
    parser.set_defaults(train_pursuit=True)
    parser.add_argument('--surround', action='store_true', default=False)
    parser.add_argument('--constraint_window', type=float, default=1.0)
    parser.add_argument('--sample_maps', action='store_true', default=False)
    parser.add_argument('--map_file', type=str, default='../maps/map_pool.npy')
    parser.add_argument('--flatten', action='store_true', default=False)
    parser.add_argument('--reward_mech', type=str, default='global')
    parser.add_argument('--catchr', type=float, default=0.1)
    parser.add_argument('--term_pursuit', type=float, default=5.0)

    parser.add_argument('--recurrent', type=str, default=None)
    parser.add_argument('--policy_hidden_sizes', type=str, default='128,128')
    parser.add_argument('--baselin_hidden_sizes', type=str, default='128,128')
    parser.add_argument('--baseline_type', type=str, default='linear')

    parser.add_argument('--conv', action='store_true', default=False)

    parser.add_argument('--max_kl', type=float, default=0.01)

    parser.add_argument('--log_dir', type=str, required=False)
    parser.add_argument('--tabular_log_file',
                        type=str,
                        default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file',
                        type=str,
                        default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file',
                        type=str,
                        default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data',
                        type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--snapshot_mode',
                        type=str,
                        default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                        '(all iterations will be saved), "last" (only '
                        'the last iteration will be saved), or "none" '
                        '(do not save snapshots)')
    parser.add_argument(
        '--log_tabular_only',
        type=ast.literal_eval,
        default=False,
        help=
        'Whether to only print the tabular log information (in a horizontal format)'
    )

    args = parser.parse_args()

    parallel_sampler.initialize(n_parallel=args.sampler_workers)

    if args.seed is not None:
        set_seed(args.seed)
        parallel_sampler.set_seed(args.seed)

    args.hidden_sizes = tuple(map(int, args.policy_hidden_sizes.split(',')))

    if args.sample_maps:
        map_pool = np.load(args.map_file)
    else:
        if args.map_type == 'rectangle':
            env_map = TwoDMaps.rectangle_map(
                *map(int, args.rectangle.split(',')))
        elif args.map_type == 'complex':
            env_map = TwoDMaps.complex_map(
                *map(int, args.rectangle.split(',')))
        else:
            raise NotImplementedError()
        map_pool = [env_map]

    env = PursuitEvade(map_pool,
                       n_evaders=args.n_evaders,
                       n_pursuers=args.n_pursuers,
                       obs_range=args.obs_range,
                       n_catch=args.n_catch,
                       train_pursuit=args.train_pursuit,
                       urgency_reward=args.urgency,
                       surround=args.surround,
                       sample_maps=args.sample_maps,
                       constraint_window=args.constraint_window,
                       flatten=args.flatten,
                       reward_mech=args.reward_mech,
                       catchr=args.catchr,
                       term_pursuit=args.term_pursuit)

    env = RLLabEnv(StandardizedEnv(env,
                                   scale_reward=args.reward_scale,
                                   enable_obsnorm=False),
                   mode=args.control)

    if args.recurrent:
        if args.conv:
            feature_network = ConvNetwork(
                input_shape=env.spec.observation_space.shape,
                output_dim=5,
                conv_filters=(8, 16, 16),
                conv_filter_sizes=(3, 3, 3),
                conv_strides=(1, 1, 1),
                # Lower-case padding mode, matching the non-recurrent conv
                # branch below (Lasagne's `pad` expects 'valid'/'full'/'same').
                conv_pads=('valid', 'valid', 'valid'),
                hidden_sizes=(64, ),
                hidden_nonlinearity=NL.rectify,
                output_nonlinearity=NL.softmax)
        else:
            feature_network = MLP(
                input_shape=(env.spec.observation_space.flat_dim +
                             env.spec.action_space.flat_dim, ),
                output_dim=5,
                hidden_sizes=(128, 128, 128),
                hidden_nonlinearity=NL.tanh,
                output_nonlinearity=None)
        if args.recurrent == 'gru':
            # The GRU takes a single hidden size. Use the first parsed entry so
            # comma-separated values (e.g. the '128,128' default) do not crash.
            policy = CategoricalGRUPolicy(env_spec=env.spec,
                                          feature_network=feature_network,
                                          hidden_dim=args.hidden_sizes[0])
        else:
            raise NotImplementedError(
                'Unsupported recurrent policy type: %s' % args.recurrent)
    elif args.conv:
        feature_network = ConvNetwork(
            input_shape=env.spec.observation_space.shape,
            output_dim=5,
            conv_filters=(8, 16, 16),
            conv_filter_sizes=(3, 3, 3),
            conv_strides=(1, 1, 1),
            conv_pads=('valid', 'valid', 'valid'),
            hidden_sizes=(64, ),
            hidden_nonlinearity=NL.rectify,
            output_nonlinearity=NL.softmax)
        policy = CategoricalMLPPolicy(env_spec=env.spec,
                                      prob_network=feature_network)
    else:
        policy = CategoricalMLPPolicy(env_spec=env.spec,
                                      hidden_sizes=args.hidden_sizes)

    if args.baseline_type == 'linear':
        baseline = LinearFeatureBaseline(env_spec=env.spec)
    else:
        baseline = ZeroBaseline(env_spec=env.spec)

    # logger
    default_log_dir = config.LOG_DIR
    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    logger.log_parameters_lite(params_log_file, args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    try:
        algo = TRPO(
            env=env,
            policy=policy,
            baseline=baseline,
            batch_size=args.n_timesteps,
            max_path_length=args.max_traj_len,
            n_itr=args.n_iter,
            discount=args.discount,
            gae_lambda=args.gae_lambda,
            step_size=args.max_kl,
            mode=args.control,
        )

        algo.train()
    finally:
        # Undo the global logger configuration made above so later code in the
        # same process does not keep writing into this run's files.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
コード例 #13
0
def experiment(variant):
    """Fine-tune a pretrained policy on one Sawyer goal with VPG.

    Seeds TensorFlow, NumPy and ``random``, loads the goal task list from a
    pickle under the multiworld goals directory, builds a Push or PickPlace
    environment, and trains VPG starting from ``initial_params_file``.
    Progress is written to ``<saveDir>/progress.csv``.

    Args:
        variant: Dict of settings. Keys read here: ``seed``,
            ``initial_params_file``, ``goalIndex``, ``init_step_size``,
            ``regionSize``, ``mode``, ``valRegionSize`` (may be ``None``),
            ``envType`` (``'Push'`` or ``'PickPlace'``) and ``saveDir``.

    Raises:
        ValueError: if ``envType`` is neither ``'Push'`` nor ``'PickPlace'``.
    """
    import os

    seed = variant['seed']
    tf.set_random_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    initial_params_file = variant['initial_params_file']
    goalIndex = variant['goalIndex']
    init_step_size = variant['init_step_size']

    regionSize = variant['regionSize']

    mode = variant['mode']

    # Inside docker containers the code is checked out under /root/code.
    taskFilePrefix = '/root/code' if 'docker' in mode else '/home/russellm'
    goalsPrefix = taskFilePrefix + '/multiworld/multiworld/envs/goals/pickPlace_'

    if variant['valRegionSize'] is not None:
        tasksFile = goalsPrefix + variant['valRegionSize'] + '_val.pkl'
    else:
        tasksFile = goalsPrefix + regionSize + '.pkl'

    # NOTE: pickle can execute arbitrary code on load; only use trusted task files.
    with open(tasksFile, 'rb') as f:
        tasks = pickle.load(f)

    envType = variant['envType']
    if envType == 'Push':
        baseEnv = SawyerPushEnv(tasks=tasks)
    elif envType == 'PickPlace':
        baseEnv = SawyerPickPlaceEnv(tasks=tasks)
    else:
        raise ValueError(
            "Unsupported envType %r; expected 'Push' or 'PickPlace'" % (envType,))

    env = FinnMamlEnv(
        FlatGoalEnv(baseEnv,
                    obs_keys=['state_observation', 'state_desired_goal']))
    env = TfEnv(NormalizedBoxEnv(env))
    baseline = LinearFeatureBaseline(env_spec=env.spec)

    algo = VPG(
        env=env,
        policy=None,
        load_policy=initial_params_file,
        baseline=baseline,
        batch_size=7500,  # 2x
        max_path_length=150,
        n_itr=10,
        reset_arg=goalIndex,
        optimizer_args={
            'init_learning_rate': init_step_size,
            'tf_optimizer_args': {
                'learning_rate': 0.1 * init_step_size
            },
            'tf_optimizer_cls': tf.train.GradientDescentOptimizer
        })

    saveDir = variant['saveDir']
    # Create the whole directory chain (like `mkdir -p`).
    if saveDir:
        os.makedirs(saveDir, exist_ok=True)

    # os.path.join works whether or not saveDir ends with a separator.
    progressFile = os.path.join(saveDir, 'progress.csv')
    logger.set_snapshot_dir(saveDir)
    logger.add_tabular_output(progressFile)
    try:
        algo.train()
    finally:
        # Detach this run's CSV so a later experiment in the same process
        # does not keep appending to it.
        logger.remove_tabular_output(progressFile)
コード例 #14
0
def run_experiment(argv):
    """Execute a pickled rllab experiment with logging configured for one run.

    Parses ``argv[1:]``, initializes the VIME parallel sampler (optionally
    seeded), decodes the base64/pickled stub data from ``--args_data``,
    configures the global rllab logger for ``--log_dir`` and concretizes the
    stub. If the result is iterable it is drained. The logger's snapshot
    mode/dir, outputs and prefix are restored afterwards, even on error.

    Args:
        argv: Full argument vector; ``argv[0]`` (the program name) is skipped.
    """
    default_log_dir = config.LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')

    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
    parser = argparse.ArgumentParser()
    parser.add_argument('--n_parallel',
                        type=int,
                        default=1,
                        help='Number of parallel workers to perform rollouts.')
    parser.add_argument('--exp_name',
                        type=str,
                        default=default_exp_name,
                        help='Name of the experiment.')
    parser.add_argument('--log_dir',
                        type=str,
                        default=default_log_dir,
                        help='Path to save the log and iteration snapshot.')
    parser.add_argument('--snapshot_mode',
                        type=str,
                        default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                        '(all iterations will be saved), "last" (only '
                        'the last iteration will be saved), or "none" '
                        '(do not save snapshots)')
    parser.add_argument('--tabular_log_file',
                        type=str,
                        default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file',
                        type=str,
                        default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file',
                        type=str,
                        default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--plot',
                        type=ast.literal_eval,
                        default=False,
                        help='Whether to plot the iteration results')
    parser.add_argument(
        '--log_tabular_only',
        type=ast.literal_eval,
        default=False,
        help=
        'Whether to only print the tabular log information (in a horizontal format)'
    )
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data',
                        type=str,
                        help='Pickled data for stub objects')
    # Accepted so launchers that always pass it keep working; the value is ignored.
    parser.add_argument('--use_cloudpickle', type=bool, help='NOT USED')

    args = parser.parse_args(argv[1:])
    if args.args_data is None:
        # Fail with a clear usage error instead of a TypeError from b64decode.
        parser.error('--args_data is required')

    from sandbox.vime.sampler import parallel_sampler_expl as parallel_sampler
    parallel_sampler.initialize(n_parallel=args.n_parallel)

    if args.seed is not None:
        set_seed(args.seed)
        parallel_sampler.set_seed(args.seed)

    if args.plot:
        from rllab.plotter import plotter
        plotter.init_worker()

    # NOTE: unpickling executes arbitrary code; --args_data must come from a
    # trusted launcher, never from user-controlled input.
    data = pickle.loads(base64.b64decode(args.args_data))

    log_dir = args.log_dir
    # exp_dir = osp.join(log_dir, args.exp_name)
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    logger.log_parameters_lite(params_log_file, args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    try:
        maybe_iter = concretize(data)
        # Generator-style experiments only run when iterated; drain them.
        if is_iterable(maybe_iter):
            for _ in maybe_iter:
                pass
    finally:
        # Restore global logger state even if the experiment raised.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
コード例 #15
0
ファイル: run_experiment_lite.py プロジェクト: jpdoyle/vime
def run_experiment(argv):
    """Execute a pickled rllab experiment with logging configured for one run.

    Parses ``argv[1:]``, initializes the VIME parallel sampler (optionally
    seeded), decodes the base64/pickled stub data from ``--args_data``,
    configures the global rllab logger for ``--log_dir`` and concretizes the
    stub. If the result is iterable it is drained. The logger's snapshot
    mode/dir, outputs and prefix are restored afterwards, even on error.

    Args:
        argv: Full argument vector; ``argv[0]`` (the program name) is skipped.
    """
    default_log_dir = config.LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')

    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
    parser = argparse.ArgumentParser()
    parser.add_argument('--n_parallel', type=int, default=1,
                        help='Number of parallel workers to perform rollouts.')
    parser.add_argument(
        '--exp_name', type=str, default=default_exp_name, help='Name of the experiment.')
    parser.add_argument('--log_dir', type=str, default=default_log_dir,
                        help='Path to save the log and iteration snapshot.')
    parser.add_argument('--snapshot_mode', type=str, default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                             '(all iterations will be saved), "last" (only '
                             'the last iteration will be saved), or "none" '
                             '(do not save snapshots)')
    parser.add_argument('--tabular_log_file', type=str, default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file', type=str, default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file', type=str, default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--plot', type=ast.literal_eval, default=False,
                        help='Whether to plot the iteration results')
    parser.add_argument('--log_tabular_only', type=ast.literal_eval, default=False,
                        help='Whether to only print the tabular log information (in a horizontal format)')
    parser.add_argument('--seed', type=int,
                        help='Random seed for numpy')
    parser.add_argument('--args_data', type=str,
                        help='Pickled data for stub objects')

    args = parser.parse_args(argv[1:])
    if args.args_data is None:
        # Fail with a clear usage error instead of a TypeError from b64decode.
        parser.error('--args_data is required')

    from sandbox.vime.sampler import parallel_sampler_expl as parallel_sampler
    parallel_sampler.initialize(n_parallel=args.n_parallel)

    if args.seed is not None:
        set_seed(args.seed)
        parallel_sampler.set_seed(args.seed)

    if args.plot:
        from rllab.plotter import plotter
        plotter.init_worker()

    # NOTE: unpickling executes arbitrary code; --args_data must come from a
    # trusted launcher, never from user-controlled input.
    data = pickle.loads(base64.b64decode(args.args_data))

    log_dir = args.log_dir
    # exp_dir = osp.join(log_dir, args.exp_name)
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    logger.log_parameters_lite(params_log_file, args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    try:
        maybe_iter = concretize(data)
        # Generator-style experiments only run when iterated; drain them.
        if is_iterable(maybe_iter):
            for _ in maybe_iter:
                pass
    finally:
        # Restore global logger state even if the experiment raised.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
コード例 #16
0
def main():
    """Train a TRPO Gaussian policy on the continuous hostage-rescue world.

    Parses command-line arguments, builds a ``ContinuousHostageWorld``
    (standardized, wrapped in ``RLLabEnv`` and optionally an
    ``ObservationBuffer``), an MLP or GRU Gaussian policy and a baseline,
    points the rllab logger at the run's log directory and runs
    ``TRPO.train()``. The logger's snapshot settings, outputs and prefix are
    restored afterwards, even if training raises.
    """
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    # Random suffix avoids name clashes between jobs started at the same time.
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')
    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)

    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_name',
                        type=str,
                        default=default_exp_name,
                        help='Name of the experiment.')

    parser.add_argument('--discount', type=float, default=0.95)
    parser.add_argument('--gae_lambda', type=float, default=0.99)

    parser.add_argument('--n_iter', type=int, default=250)
    parser.add_argument('--sampler_workers', type=int, default=1)
    parser.add_argument('--max_traj_len', type=int, default=250)
    parser.add_argument('--update_curriculum',
                        action='store_true',
                        default=False)
    parser.add_argument('--n_timesteps', type=int, default=8000)
    # Registered once: a second add_argument('--control') makes argparse
    # raise "conflicting option string" before any argument is parsed.
    parser.add_argument('--control', type=str, default='centralized')

    parser.add_argument('--buffer_size', type=int, default=1)
    parser.add_argument('--n_good', type=int, default=3)
    parser.add_argument('--n_hostage', type=int, default=5)
    parser.add_argument('--n_bad', type=int, default=5)
    parser.add_argument('--n_coop_save', type=int, default=2)
    parser.add_argument('--n_coop_avoid', type=int, default=2)
    parser.add_argument('--n_sensors', type=int, default=20)
    parser.add_argument('--sensor_range', type=float, default=0.2)
    parser.add_argument('--save_reward', type=float, default=3)
    parser.add_argument('--hit_reward', type=float, default=-1)
    parser.add_argument('--encounter_reward', type=float, default=0.01)
    parser.add_argument('--bomb_reward', type=float, default=-10.)

    parser.add_argument('--recurrent', action='store_true', default=False)
    parser.add_argument('--baseline_type', type=str, default='linear')
    parser.add_argument('--policy_hidden_sizes', type=str, default='128,128')
    parser.add_argument('--baselin_hidden_sizes', type=str, default='128,128')

    parser.add_argument('--max_kl', type=float, default=0.01)

    parser.add_argument('--log_dir', type=str, required=False)
    parser.add_argument('--tabular_log_file',
                        type=str,
                        default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file',
                        type=str,
                        default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file',
                        type=str,
                        default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data',
                        type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--snapshot_mode',
                        type=str,
                        default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                        '(all iterations will be saved), "last" (only '
                        'the last iteration will be saved), or "none" '
                        '(do not save snapshots)')
    parser.add_argument(
        '--log_tabular_only',
        type=ast.literal_eval,
        default=False,
        help=
        'Whether to only print the tabular log information (in a horizontal format)'
    )

    args = parser.parse_args()

    parallel_sampler.initialize(n_parallel=args.sampler_workers)

    if args.seed is not None:
        set_seed(args.seed)
        parallel_sampler.set_seed(args.seed)

    args.hidden_sizes = tuple(map(int, args.policy_hidden_sizes.split(',')))

    # --sensor_range is a single float shared by all sensors and is passed
    # straight to the environment.
    env = ContinuousHostageWorld(args.n_good,
                                 args.n_hostage,
                                 args.n_bad,
                                 args.n_coop_save,
                                 args.n_coop_avoid,
                                 n_sensors=args.n_sensors,
                                 sensor_range=args.sensor_range,
                                 save_reward=args.save_reward,
                                 hit_reward=args.hit_reward,
                                 encounter_reward=args.encounter_reward,
                                 bomb_reward=args.bomb_reward)

    env = RLLabEnv(StandardizedEnv(env), mode=args.control)

    if args.buffer_size > 1:
        env = ObservationBuffer(env, args.buffer_size)

    if args.recurrent:
        policy = GaussianGRUPolicy(env_spec=env.spec,
                                   hidden_sizes=args.hidden_sizes)
    else:
        policy = GaussianMLPPolicy(env_spec=env.spec,
                                   hidden_sizes=args.hidden_sizes)

    if args.baseline_type == 'linear':
        baseline = LinearFeatureBaseline(env_spec=env.spec)
    else:
        baseline = ZeroBaseline(env_spec=env.spec)

    # logger
    default_log_dir = config.LOG_DIR
    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    logger.log_parameters_lite(params_log_file, args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    try:
        algo = TRPO(
            env=env,
            policy=policy,
            baseline=baseline,
            batch_size=args.n_timesteps,
            max_path_length=args.max_traj_len,
            n_itr=args.n_iter,
            discount=args.discount,
            # Honor --gae_lambda; it was previously parsed but never used.
            gae_lambda=args.gae_lambda,
            step_size=args.max_kl,
            mode=args.control,
        )

        algo.train()
    finally:
        # Undo the global logger configuration made above.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
コード例 #17
0
ファイル: run_trpo.py プロジェクト: PangYanbo/gym-extensions
# Command-line interface: train on each listed Gym environment in order.
parser = argparse.ArgumentParser()
parser.add_argument("envs", nargs='+', help="The list of environments to train on in order. Eval rollouts will be run on all environments at the end.")
parser.add_argument("--num_epochs", default=100, type=int, help="Number of epochs to run.")
parser.add_argument("--num_final_rollouts", default=20, type=int, help="Number of rollouts to run on final evaluation of environments.")
parser.add_argument("--batch_size", default=25000, type=int, help="Batch_size per epoch (this is the number of (state, action) samples, not the number of rollouts)")
parser.add_argument("--step_size", default=0.01, type=float, help="Step size for TRPO (i.e. the maximum KL bound)")
parser.add_argument("--reg_coeff", default=1e-5, type=float, help="Regularization coefficient for TRPO")
parser.add_argument("--text_log_file", default="./data/debug.log", help="Where text output will go")
parser.add_argument("--tabular_log_file", default="./data/progress.csv", help="Where tabular output will go")
args = parser.parse_args()

# stub(globals())

# ext.set_seed(1)
# Send rllab text and tabular logs to the requested files; console output is
# not restricted to the tabular data only.
logger.add_text_output(args.text_log_file)
logger.add_tabular_output(args.tabular_log_file)
logger.set_log_tabular_only(False)

# (env_name, env) pairs: each Gym environment is normalized and wrapped in TfEnv.
envs = []

for env_name in args.envs:
    gymenv = GymEnv(env_name, force_reset=True, record_video=False, record_log=False)
    env = TfEnv(normalize(gymenv))
    envs.append((env_name, env))
# NOTE(review): after the loop `env` is the LAST environment, and the policy
# that follows is built from its spec; presumably all listed environments
# share observation/action spaces — confirm.

policy = GaussianMLPPolicy(
name="policy",
env_spec=env.spec,
# The neural network policy should have two hidden layers, each with 32 hidden units.
hidden_sizes=(100, 50, 25),
hidden_nonlinearity=tf.nn.relu,
コード例 #18
0
def run_task(v, num_cpu=8, log_dir="./data", ename=None, **kwargs):
    """Train a GRU assistant policy with PPO on an assistive bandit, then evaluate it.

    Trains against the human policy named by ``v["human_policy"]``. Then, for
    every entry of ``human_policy_dict``, it does the following:

    * samples evaluation trajectories and logs return statistics;
    * writes per-timestep agreement / mutual-information rows to a separate
      tabular file;
    * prints a few example trajectories.

    Logger outputs and the TensorFlow session are released on exit, even if
    training or evaluation raises.

    Args:
        v: Variant dict of hyperparameters (e.g. ``"n_arms"``, ``"n_episodes"``,
            ``"human_policy"``, ``"seed"``, ``"batch_size"``, ...).
        num_cpu: Unused.
        log_dir: Root log directory; output goes to
            ``<log_dir>/<n_episodes>/<human_policy>``. ``None`` disables file logging.
        ename: Unused (see the commented-out ``scope`` argument).
        **kwargs: Ignored.
    """
    from scipy.stats import bernoulli, uniform, beta

    import tensorflow as tf
    from assistive_bandits.experiments.l2_rnn_baseline import L2RNNBaseline
    from assistive_bandits.experiments.tbptt_optimizer import TBPTTOptimizer
    from sandbox.rocky.tf.policies.categorical_gru_policy import CategoricalGRUPolicy
    from assistive_bandits.experiments.pposgd_clip_ratio import PPOSGD

    # NOTE(review): `local_test` and `force_remote` are module-level flags
    # defined outside this view.
    if not local_test and force_remote:
        import rl_algs.logger as rl_algs_logger
        log_dir = rl_algs_logger.get_dir()

    if log_dir is not None:
        log_dir = osp.join(log_dir, str(v["n_episodes"]))
        log_dir = osp.join(log_dir, v["human_policy"])

    text_output_file = None if log_dir is None else osp.join(log_dir, "text")
    tabular_output_file = None if log_dir is None else osp.join(
        log_dir, "train_table.csv")
    info_theory_tabular_output = None if log_dir is None else osp.join(
        log_dir, "info_table.csv")
    rag = uniform_bernoulli_iterator()
    bandit = BanditEnv(n_arms=v["n_arms"],
                       reward_dist=bernoulli,
                       reward_args_generator=rag,
                       horizon=v["n_episodes"])
    pi_H = human_policy_dict[v["human_policy"]](bandit)

    h_wrapper = human_wrapper_dict[v["human_wrapper"]]

    env = h_wrapper(bandit, pi_H, penalty=v["intervention_penalty"])

    # Track what is registered with the global logger so the `finally` below
    # removes exactly those outputs. The tabular file switches from the
    # training table to the evaluation table after training.
    text_added = False
    active_tabular = None
    try:
        if text_output_file is not None:
            logger.add_text_output(text_output_file)
            text_added = True
            logger.add_tabular_output(tabular_output_file)
            active_tabular = tabular_output_file

        logger.log("Training against {}".format(v["human_policy"]))
        logger.log("Setting seed to {}".format(v["seed"]))
        env.seed(v["seed"])

        baseline = L2RNNBaseline(
            name="vf",
            env_spec=env.spec,
            log_loss_before=False,
            log_loss_after=False,
            hidden_nonlinearity=getattr(tf.nn, v["nonlinearity"]),
            weight_normalization=v["weight_normalization"],
            layer_normalization=v["layer_normalization"],
            state_include_action=False,
            hidden_dim=v["hidden_dim"],
            optimizer=TBPTTOptimizer(
                batch_size=v["opt_batch_size"],
                n_steps=v["opt_n_steps"],
                n_epochs=v["min_epochs"],
            ),
            batch_size=v["opt_batch_size"],
            n_steps=v["opt_n_steps"],
        )
        policy = CategoricalGRUPolicy(env_spec=env.spec,
                                      hidden_nonlinearity=getattr(
                                          tf.nn, v["nonlinearity"]),
                                      hidden_dim=v["hidden_dim"],
                                      state_include_action=True,
                                      name="policy")

        n_itr = 3 if local_test else 100

        algo = PPOSGD(
            env=env,
            policy=policy,
            baseline=baseline,
            batch_size=v["batch_size"],
            max_path_length=v["n_episodes"],
            # One vectorized env per trajectory in a batch, capped at 100.
            sampler_args=dict(n_envs=max(
                1, min(int(np.ceil(v["batch_size"] / v["n_episodes"])), 100))),
            n_itr=n_itr,
            step_size=v["mean_kl"],
            clip_lr=v["clip_lr"],
            log_loss_kl_before=False,
            log_loss_kl_after=False,
            use_kl_penalty=v["use_kl_penalty"],
            min_n_epochs=v["min_epochs"],
            entropy_bonus_coeff=v["entropy_bonus_coeff"],
            optimizer=TBPTTOptimizer(
                batch_size=v["opt_batch_size"],
                n_steps=v["opt_n_steps"],
                n_epochs=v["min_epochs"],
            ),
            discount=v["discount"],
            gae_lambda=v["gae_lambda"],
            use_line_search=True
            # scope=ename
        )

        # Entering the session makes it the default session (as the previous
        # `sess.as_default()` did) and also closes it on exit.
        with tf.Session() as sess:
            algo.train(sess)

            if active_tabular is not None:
                # Evaluation rows go to a separate table from training stats.
                logger.remove_tabular_output(active_tabular)
                active_tabular = None
                logger.add_tabular_output(info_theory_tabular_output)
                active_tabular = info_theory_tabular_output

            # Now gather statistics for t-tests and such!
            for human_policy_name, human_policy in human_policy_dict.items():

                logger.log("-------------------")
                logger.log("Evaluating against {}".format(human_policy.__name__))
                logger.log("-------------------")

                logger.log("Obtaining Samples...")
                test_pi_H = human_policy(bandit)
                test_env = h_wrapper(bandit, test_pi_H, penalty=0.)
                eval_sampler = VectorizedSampler(algo, n_envs=100)
                algo.batch_size = v["num_eval_traj"] * v["n_episodes"]
                algo.env = test_env
                logger.log("algo.env.pi_H has class: {}".format(
                    algo.env.pi_H.__class__))
                eval_sampler.start_worker()
                try:
                    paths = eval_sampler.obtain_samples(-1)
                finally:
                    eval_sampler.shutdown_worker()
                rewards = []

                H_act_seqs = []
                R_act_seqs = []
                best_arms = []
                optimal_a_seqs = []
                for p in paths:
                    a_Rs = env.action_space.unflatten_n(p['actions'])
                    obs_R = env.observation_space.unflatten_n(p['observations'])
                    best_arm = np.argmax(p['env_infos']['arm_means'][0])

                    # Column 1 of the robot's observation holds the human's action.
                    H_act_seqs.append(obs_R[:, 1])
                    R_act_seqs.append(a_Rs)
                    best_arms.append(best_arm)
                    optimal_a_seqs.append(
                        [best_arm for _ in range(v["n_episodes"])])

                    rewards.append(np.sum(p['rewards']))

                #feel free to add more data
                logger.log("NumTrajs {}".format(v["num_eval_traj"]))
                logger.log("AverageReturn {}".format(np.mean(rewards)))
                logger.log("StdReturn {}".format(np.std(rewards)))
                logger.log("MaxReturn {}".format(np.max(rewards)))
                logger.log("MinReturn {}".format(np.min(rewards)))

                optimal_a_H_freqs = _frequency_agreement(H_act_seqs,
                                                         optimal_a_seqs)
                optimal_a_R_freqs = _frequency_agreement(R_act_seqs,
                                                         optimal_a_seqs)

                for t in range(v["n_episodes"]):
                    logger.record_tabular("PolicyExecTime", 0)
                    logger.record_tabular("EnvExecTime", 0)
                    logger.record_tabular("ProcessExecTime", 0)
                    logger.record_tabular("Tested Against", human_policy_name)
                    logger.record_tabular("t", t)
                    logger.record_tabular("a_H_agreement", optimal_a_H_freqs[t])
                    logger.record_tabular("a_R_agreement", optimal_a_R_freqs[t])

                    H_act_seqs_truncated = [a_Hs[0:t] for a_Hs in H_act_seqs]
                    R_act_seqs_truncated = [a_Rs[0:t] for a_Rs in R_act_seqs]
                    h_mutual_info = _mutual_info_seqs(H_act_seqs_truncated,
                                                      best_arms, v["n_arms"] + 1)
                    r_mutual_info = _mutual_info_seqs(R_act_seqs_truncated,
                                                      best_arms, v["n_arms"] + 1)
                    logger.record_tabular("h_mutual_info", h_mutual_info)
                    logger.record_tabular("r_mutual_info", r_mutual_info)
                    logger.record_tabular("a_H_opt_freq", optimal_a_H_freqs[t])
                    logger.record_tabular("a_R_opt_freq", optimal_a_R_freqs[t])
                    logger.dump_tabular()

                test_env = h_wrapper(bandit, human_policy(bandit), penalty=0.)

                logger.log("Printing Example Trajectories")
                for i in range(v["num_display_traj"]):
                    observation = test_env.reset()
                    policy.reset()
                    logger.log("-- Trajectory {} of {}".format(
                        i + 1, v["num_display_traj"]))
                    logger.log("t \t obs \t act \t reward \t act_probs")
                    for t in range(v["n_episodes"]):
                        action, act_info = policy.get_action(observation)
                        new_obs, reward, done, info = test_env.step(action)
                        logger.log("{} \t {} \t {} \t {} \t {}".format(
                            t, observation, action, reward, act_info['prob']))
                        observation = new_obs
                        if done:
                            logger.log("Total reward: {}".format(
                                info["accumulated rewards"]))
                            break
    finally:
        if text_added:
            logger.remove_text_output(text_output_file)
        if active_tabular is not None:
            logger.remove_tabular_output(active_tabular)
コード例 #19
0
ファイル: run_hostage.py プロジェクト: TJUSCS-RLLAB/MADRL
def main():
    """Train a TRPO policy on ``ContinuousHostageWorld`` from CLI options.

    Parses the command line, builds the environment, policy and baseline,
    redirects the global rllab ``logger`` into the experiment's log
    directory, runs ``algo.train()``, and finally restores the logger
    configuration it changed (snapshot dir/mode, outputs and prefix) even if
    training raises.
    """
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')
    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--exp_name', type=str, default=default_exp_name, help='Name of the experiment.')

    parser.add_argument('--discount', type=float, default=0.95)
    parser.add_argument('--gae_lambda', type=float, default=0.99)

    parser.add_argument('--n_iter', type=int, default=250)
    parser.add_argument('--sampler_workers', type=int, default=1)
    parser.add_argument('--max_traj_len', type=int, default=250)
    # NOTE(review): parsed and logged, but not forwarded to TRPO below.
    parser.add_argument('--update_curriculum', action='store_true', default=False)
    parser.add_argument('--n_timesteps', type=int, default=8000)

    # '--control' must be registered exactly once: adding the same option
    # string twice makes argparse raise ArgumentError at startup.
    parser.add_argument('--control', type=str, default='centralized')
    parser.add_argument('--buffer_size', type=int, default=1)
    parser.add_argument('--n_good', type=int, default=3)
    parser.add_argument('--n_hostage', type=int, default=5)
    parser.add_argument('--n_bad', type=int, default=5)
    parser.add_argument('--n_coop_save', type=int, default=2)
    parser.add_argument('--n_coop_avoid', type=int, default=2)
    parser.add_argument('--n_sensors', type=int, default=20)
    parser.add_argument('--sensor_range', type=float, default=0.2)
    parser.add_argument('--save_reward', type=float, default=3)
    parser.add_argument('--hit_reward', type=float, default=-1)
    parser.add_argument('--encounter_reward', type=float, default=0.01)
    parser.add_argument('--bomb_reward', type=float, default=-10.)

    parser.add_argument('--recurrent', action='store_true', default=False)
    parser.add_argument('--baseline_type', type=str, default='linear')
    parser.add_argument('--policy_hidden_sizes', type=str, default='128,128')
    # The historical misspelling comes first so it stays the attribute name
    # (args.baselin_hidden_sizes); the correct spelling is accepted as an alias.
    parser.add_argument('--baselin_hidden_sizes', '--baseline_hidden_sizes',
                        type=str, default='128,128')

    parser.add_argument('--max_kl', type=float, default=0.01)

    parser.add_argument('--log_dir', type=str, required=False)
    parser.add_argument('--tabular_log_file', type=str, default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file', type=str, default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file', type=str, default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--seed', type=int,
                        help='Random seed for numpy')
    parser.add_argument('--args_data', type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--snapshot_mode', type=str, default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                             '(all iterations will be saved), "last" (only '
                             'the last iteration will be saved), or "none" '
                             '(do not save snapshots)')
    parser.add_argument('--log_tabular_only', type=ast.literal_eval, default=False,
                        help='Whether to only print the tabular log information (in a horizontal format)')

    args = parser.parse_args()

    parallel_sampler.initialize(n_parallel=args.sampler_workers)

    if args.seed is not None:
        set_seed(args.seed)
        parallel_sampler.set_seed(args.seed)

    args.hidden_sizes = tuple(map(int, args.policy_hidden_sizes.split(',')))

    # --sensor_range is parsed as a single float and passed through unchanged.
    env = ContinuousHostageWorld(args.n_good, args.n_hostage, args.n_bad, args.n_coop_save,
                                 args.n_coop_avoid, n_sensors=args.n_sensors,
                                 sensor_range=args.sensor_range, save_reward=args.save_reward,
                                 hit_reward=args.hit_reward, encounter_reward=args.encounter_reward,
                                 bomb_reward=args.bomb_reward)

    env = RLLabEnv(StandardizedEnv(env), mode=args.control)

    if args.buffer_size > 1:
        env = ObservationBuffer(env, args.buffer_size)

    if args.recurrent:
        policy = GaussianGRUPolicy(env_spec=env.spec, hidden_sizes=args.hidden_sizes)
    else:
        policy = GaussianMLPPolicy(env_spec=env.spec, hidden_sizes=args.hidden_sizes)

    if args.baseline_type == 'linear':
        baseline = LinearFeatureBaseline(env_spec=env.spec)
    else:
        # Built from the env spec, the same way as the linear baseline.
        baseline = ZeroBaseline(env_spec=env.spec)

    # logger
    default_log_dir = config.LOG_DIR
    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    logger.log_parameters_lite(params_log_file, args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    # Capture the global logger state overridden below so it can be restored.
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    try:
        algo = TRPO(env=env,
                    policy=policy,
                    baseline=baseline,
                    batch_size=args.n_timesteps,
                    max_path_length=args.max_traj_len,
                    n_itr=args.n_iter,
                    discount=args.discount,
                    # Forward the parsed --gae_lambda instead of ignoring it.
                    gae_lambda=args.gae_lambda,
                    step_size=args.max_kl,
                    mode=args.control,)

        algo.train()
    finally:
        # Undo the global logger configuration so output files are released
        # and later in-process runs do not write into this experiment.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
コード例 #20
0
    def __init__(self, env, args):
        """Set up parallel sampling, logging and the training algorithm.

        Args:
            env: environment handed to ``rllab_envpolicy_parser``, which
                returns the environment to train on together with a policy.
            args: options namespace. Fields read here include ``n_parallel``,
                ``seed``, ``algo``, ``baseline_type``, ``log_dir``,
                ``exp_name``, the ``*_log_file`` names, ``snapshot_mode``,
                ``log_tabular_only`` and the algorithm hyper-parameters.

        Raises:
            NotImplementedError: if ``args.algo`` is neither 'tftrpo' nor
                'thddpg', if it is 'thddpg' (that path is currently disabled),
                or if ``args.baseline_type`` is neither 'linear' nor 'zero'.
        """
        self.args = args
        # Reject unsupported algorithms before touching any global state;
        # otherwise self.algo would silently never be assigned.
        if args.algo not in ('tftrpo', 'thddpg'):
            raise NotImplementedError(args.algo)

        # Parallel setup
        parallel_sampler.initialize(n_parallel=args.n_parallel)
        if args.seed is not None:
            set_seed(args.seed)
            parallel_sampler.set_seed(args.seed)

        env, policy = rllab_envpolicy_parser(env, args)

        if args.algo != 'thddpg':
            # Baseline (policy-gradient algorithms only)
            if args.baseline_type == 'linear':
                baseline = LinearFeatureBaseline(env_spec=env.spec)
            elif args.baseline_type == 'zero':
                baseline = ZeroBaseline(env_spec=env.spec)
            else:
                raise NotImplementedError(args.baseline_type)

        # Logger
        default_log_dir = config.LOG_DIR
        if args.log_dir is None:
            log_dir = osp.join(default_log_dir, args.exp_name)
        else:
            log_dir = args.log_dir

        tabular_log_file = osp.join(log_dir, args.tabular_log_file)
        text_log_file = osp.join(log_dir, args.text_log_file)
        params_log_file = osp.join(log_dir, args.params_log_file)

        logger.log_parameters_lite(params_log_file, args)
        logger.add_text_output(text_log_file)
        logger.add_tabular_output(tabular_log_file)
        # Keep the logger state overridden below on the instance, so callers
        # can restore it once the run is finished.
        self._prev_snapshot_dir = logger.get_snapshot_dir()
        self._prev_snapshot_mode = logger.get_snapshot_mode()
        logger.set_snapshot_dir(log_dir)
        logger.set_snapshot_mode(args.snapshot_mode)
        logger.set_log_tabular_only(args.log_tabular_only)
        logger.push_prefix("[%s] " % args.exp_name)

        if args.algo == 'tftrpo':
            self.algo = TRPO(
                env=env,
                policy=policy,
                baseline=baseline,
                batch_size=args.batch_size,
                max_path_length=args.max_path_length,
                n_itr=args.n_iter,
                discount=args.discount,
                gae_lambda=args.gae_lambda,
                step_size=args.step_size,
                optimizer=ConjugateGradientOptimizer(
                    hvp_approach=FiniteDifferenceHvp(
                        base_eps=1e-5)) if args.recurrent else None,
                mode=args.control,
                sampler_cls=GSMDPBatchSampler)
        elif args.algo == 'thddpg':
            # The Theano DDPG path is disabled: everything after this raise is
            # unreachable and only documents how the algorithm would be built.
            raise NotImplementedError(args.algo)
            qfunc = thContinuousMLPQFunction(env_spec=env.spec)
            if args.exp_strategy == 'ou':
                es = OUStrategy(env_spec=env.spec)
            elif args.exp_strategy == 'gauss':
                es = GaussianStrategy(env_spec=env.spec)
            else:
                raise NotImplementedError()

            self.algo = thDDPG(
                env=env,
                policy=policy,
                qf=qfunc,
                es=es,
                batch_size=args.batch_size,
                max_path_length=args.max_path_length,
                epoch_length=args.epoch_length,
                min_pool_size=args.min_pool_size,
                replay_pool_size=args.replay_pool_size,
                n_epochs=args.n_iter,
                discount=args.discount,
                scale_reward=0.01,
                qf_learning_rate=args.qfunc_lr,
                policy_learning_rate=args.policy_lr,
                eval_samples=args.eval_samples,
                mode=args.control,
            )
コード例 #21
0
ファイル: run_waterworld.py プロジェクト: TJUSCS-RLLAB/MADRL
def main():
    """Train a TRPO policy on ``MAWaterWorld`` from CLI options.

    Parses and validates the command line, builds the environment, an MLP or
    recurrent (GRU/LSTM) policy and a baseline, redirects the global rllab
    ``logger`` into the experiment's log directory, runs ``algo.train()``,
    and finally restores the logger configuration it changed even if
    training raises.
    """
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')
    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)

    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_name', type=str, default=default_exp_name,
                        help='Name of the experiment.')

    parser.add_argument('--discount', type=float, default=0.95)
    parser.add_argument('--gae_lambda', type=float, default=0.99)
    parser.add_argument('--reward_scale', type=float, default=1.0)
    parser.add_argument('--enable_obsnorm', action='store_true', default=False)
    parser.add_argument('--chunked', action='store_true', default=False)

    parser.add_argument('--n_iter', type=int, default=250)
    parser.add_argument('--sampler_workers', type=int, default=1)
    parser.add_argument('--max_traj_len', type=int, default=250)
    parser.add_argument('--update_curriculum', action='store_true', default=False)
    parser.add_argument('--anneal_step_size', type=int, default=0)

    parser.add_argument('--n_timesteps', type=int, default=8000)

    parser.add_argument('--control', type=str, default='centralized')
    parser.add_argument('--buffer_size', type=int, default=1)
    parser.add_argument('--radius', type=float, default=0.015)
    parser.add_argument('--n_evaders', type=int, default=10)
    parser.add_argument('--n_pursuers', type=int, default=8)
    parser.add_argument('--n_poison', type=int, default=10)
    parser.add_argument('--n_coop', type=int, default=4)
    parser.add_argument('--n_sensors', type=int, default=30)
    # Either one range shared by all pursuers, or a comma-separated list with
    # exactly n_pursuers entries.
    parser.add_argument('--sensor_range', type=str, default='0.2')
    parser.add_argument('--food_reward', type=float, default=5)
    parser.add_argument('--poison_reward', type=float, default=-1)
    parser.add_argument('--encounter_reward', type=float, default=0.05)
    parser.add_argument('--reward_mech', type=str, default='local')

    # None/empty -> MLP policy; 'gru' or 'lstm' -> recurrent policy.
    parser.add_argument('--recurrent', type=str, default=None)
    parser.add_argument('--baseline_type', type=str, default='linear')
    parser.add_argument('--policy_hidden_sizes', type=str, default='128,128')
    parser.add_argument('--baseline_hidden_sizes', type=str, default='128,128')

    parser.add_argument('--max_kl', type=float, default=0.01)

    parser.add_argument('--log_dir', type=str, required=False)
    parser.add_argument('--tabular_log_file', type=str, default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file', type=str, default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file', type=str, default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data', type=str, help='Pickled data for stub objects')
    parser.add_argument('--snapshot_mode', type=str, default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                        '(all iterations will be saved), "last" (only '
                        'the last iteration will be saved), or "none" '
                        '(do not save snapshots)')
    parser.add_argument(
        '--log_tabular_only', type=ast.literal_eval, default=False,
        help='Whether to only print the tabular log information (in a horizontal format)')

    args = parser.parse_args()

    # Derived value; stored on args so it is also written to the params log.
    args.hidden_sizes = tuple(map(int, args.policy_hidden_sizes.split(',')))

    # Validate option combinations up front, before any global state is touched.
    if args.recurrent:
        if args.recurrent not in ('gru', 'lstm'):
            parser.error("--recurrent must be 'gru' or 'lstm', got {!r}".format(args.recurrent))
        if len(args.hidden_sizes) != 1:
            parser.error('recurrent policies take a single hidden size, got '
                         '--policy_hidden_sizes={}'.format(args.policy_hidden_sizes))

    # A list (not a lazy map object) is required: on Python 3,
    # np.array(map(...)) yields a 0-d object array that has no len().
    sensor_range = np.array([float(r) for r in args.sensor_range.split(',')])
    if len(sensor_range) == 1:
        sensor_range = sensor_range[0]
    elif sensor_range.shape != (args.n_pursuers,):
        parser.error('--sensor_range needs 1 or n_pursuers={} values, got {}'.format(
            args.n_pursuers, len(sensor_range)))

    parallel_sampler.initialize(n_parallel=args.sampler_workers)

    if args.seed is not None:
        set_seed(args.seed)
        parallel_sampler.set_seed(args.seed)

    env = MAWaterWorld(args.n_pursuers, args.n_evaders, args.n_coop, args.n_poison,
                       radius=args.radius, n_sensors=args.n_sensors, food_reward=args.food_reward,
                       poison_reward=args.poison_reward, encounter_reward=args.encounter_reward,
                       reward_mech=args.reward_mech, sensor_range=sensor_range, obstacle_loc=None)

    env = TfEnv(
        RLLabEnv(
            StandardizedEnv(env, scale_reward=args.reward_scale,
                            enable_obsnorm=args.enable_obsnorm), mode=args.control))

    if args.buffer_size > 1:
        env = ObservationBuffer(env, args.buffer_size)

    if args.recurrent:
        feature_network = MLP(
            name='feature_net',
            input_shape=(env.spec.observation_space.flat_dim + env.spec.action_space.flat_dim,),
            output_dim=16, hidden_sizes=(128, 64, 32), hidden_nonlinearity=tf.nn.tanh,
            output_nonlinearity=None)
        recurrent_policy_cls = GaussianGRUPolicy if args.recurrent == 'gru' else GaussianLSTMPolicy
        policy = recurrent_policy_cls(env_spec=env.spec, feature_network=feature_network,
                                      hidden_dim=args.hidden_sizes[0], name='policy')
    else:
        policy = GaussianMLPPolicy(
            name='policy', env_spec=env.spec,
            hidden_sizes=args.hidden_sizes, min_std=10e-5)

    if args.baseline_type == 'linear':
        baseline = LinearFeatureBaseline(env_spec=env.spec)
    elif args.baseline_type == 'mlp':
        raise NotImplementedError()
        # baseline = GaussianMLPBaseline(
        #     env_spec=env.spec, hidden_sizes=tuple(map(int, args.baseline_hidden_sizes.split(','))))
    else:
        # Any other value (including 'zero') falls back to a zero baseline.
        baseline = ZeroBaseline(env_spec=env.spec)

    # logger
    default_log_dir = config.LOG_DIR
    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    logger.log_parameters_lite(params_log_file, args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    # Capture the global logger state overridden below so it can be restored.
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    # Chunked sampling is selected by prefixing the control mode.
    mode = 'chunk_{}'.format(args.control) if args.chunked else args.control

    try:
        algo = TRPO(
            env=env,
            policy=policy,
            baseline=baseline,
            batch_size=args.n_timesteps,
            max_path_length=args.max_traj_len,
            #max_path_length_limit=args.max_path_length_limit,
            update_max_path_length=args.update_curriculum,
            anneal_step_size=args.anneal_step_size,
            n_itr=args.n_iter,
            discount=args.discount,
            gae_lambda=args.gae_lambda,
            step_size=args.max_kl,
            optimizer=ConjugateGradientOptimizer(
                hvp_approach=FiniteDifferenceHvp(base_eps=1e-5)) if args.recurrent else None,
            mode=mode,)

        algo.train()
    finally:
        # Undo the global logger configuration so output files are released
        # and later in-process runs do not write into this experiment.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
コード例 #22
0
def experiment(variant):
    """Meta-train a MAML-TRPO policy on a Sawyer push or pick-and-place task set.

    Args:
        variant: dict of run options. Keys read here: ``seed``, ``flr``,
            ``fbs`` (trajectories per fast gradient update), ``mlr`` (meta
            step size), ``regionSize`` ('20X20' or '60X30'), ``envType``
            ('Push' or 'PickPlace'), ``init_param_file`` (policy to load) and
            ``saveDir`` (snapshot / progress.csv directory, created if missing).

    Raises:
        ValueError: if ``regionSize`` or ``envType`` is not one of the
            supported values.
    """
    import os

    seed = variant['seed']

    tf.set_random_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # NOTE(review): read but not currently passed to MAMLTRPO.
    fast_learning_rate = variant['flr']  # noqa: F841

    fast_batch_size = variant[
        'fbs']  # 10 works for [0.1, 0.2], 20 doesn't improve much for [0,0.2]
    meta_batch_size = 20  # 10 also works, but much less stable, 20 is fairly stable, 40 is more stable
    max_path_length = 150
    num_grad_updates = 1
    meta_step_size = variant['mlr']

    # Pickled task definitions for each supported region size.
    tasks_files = {
        '20X20': '/root/code/multiworld/multiworld/envs/goals/pickPlace_20X20_6_8.pkl',
        '60X30': '/root/code/multiworld/multiworld/envs/goals/pickPlace_60X30.pkl',
    }
    regionSize = variant['regionSize']
    if regionSize not in tasks_files:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        raise ValueError('unsupported regionSize: {!r}'.format(regionSize))

    envType = variant['envType']
    if envType not in ('Push', 'PickPlace'):
        raise ValueError('unsupported envType: {!r}'.format(envType))

    # Close the task file deterministically instead of leaking the handle.
    with open(tasks_files[regionSize], 'rb') as tasks_fh:
        tasks = pickle.load(tasks_fh)

    if envType == 'Push':
        baseEnv = SawyerPushEnv(tasks=tasks)
    else:
        baseEnv = SawyerPickPlaceEnv(tasks=tasks)
    env = FinnMamlEnv(
        FlatGoalEnv(baseEnv,
                    obs_keys=['state_observation', 'state_desired_goal']))

    env = TfEnv(NormalizedBoxEnv(env))

    baseline = LinearFeatureBaseline(env_spec=env.spec)

    algo = MAMLTRPO(
        env=env,
        policy=None,
        load_policy=variant['init_param_file'],
        baseline=baseline,
        batch_size=fast_batch_size,  # number of trajs for grad update
        max_path_length=max_path_length,
        meta_batch_size=meta_batch_size,
        num_grad_updates=num_grad_updates,
        n_itr=1000,
        use_maml=True,
        step_size=meta_step_size,
        plot=False,
    )

    saveDir = variant['saveDir']

    # makedirs also creates missing parent directories.
    os.makedirs(saveDir, exist_ok=True)

    logger.set_snapshot_dir(saveDir)
    # os.path.join inserts the separator even when saveDir has no trailing slash.
    logger.add_tabular_output(os.path.join(saveDir, 'progress.csv'))

    algo.train()