Exemplo n.º 1
0
 def __init__(self, algo):
     """Keep a reference to ``algo`` and start sampler workers when it needs several.

     :type algo: BatchPolopt
     """
     self.algo = algo
     worker_count = self.algo.num_workers
     # A single worker samples in-process; only start the pool for more.
     if worker_count > 1:
         parallel_sampler.initialize(n_parallel=worker_count)
Exemplo n.º 2
0
def _init(args):
    """Build a one-point hyper-parameter grid from ``args`` and launch the sweep.

    With ``args.repeat > 1`` the runs go through ``run_sweep_parallel`` and the
    parallel sampler is deliberately not initialized here. Otherwise the sampler
    is started with 8 workers, seeded with 1, and ``run_sweep_serial`` runs once.
    """
    env_name = args.env_name
    print('Using environment %s' % env_name)
    many_runs = args.repeat > 1

    settings = (
        ('env_name', env_name),
        ('rundir', args.rundir),
        ('ent_wt', args.trpo_ent),
        ('trpo_step', args.trpo_step),
        ('hid_size', args.hid_size),
        ('hid_layers', args.hid_layers),
        ('many_runs', many_runs),
    )
    # Each value is wrapped in a one-element list: the sweep has a single point.
    params_dict = {name: [value] for name, value in settings}

    if not many_runs:
        parallel_sampler.initialize(n_parallel=8)
        parallel_sampler.set_seed(1)
        run_sweep_serial(main, params_dict, repeat=1)
        return

    # Stacking the sweep's own parallelism on top of the parallel sampler
    # does not work, so the sampler is left uninitialized on this path.
    warnings.warn(
        "You're trying to use --repeat N for N > 1, but that "
        "disables parallel sampling. This is probably going to be "
        "heinously slow or something, use at own risk.")
    run_sweep_parallel(main, params_dict, repeat=args.repeat)
Exemplo n.º 3
0
 def __init__(self, env, args):
     """Store ``env``/``args``, start the parallel sampler, and seed when ``args.seed`` is set."""
     self.env = env
     self.args = args
     # Parallel setup
     parallel_sampler.initialize(n_parallel=args.n_parallel)
     seed = args.seed
     if seed is None:
         return
     set_seed(seed)
     parallel_sampler.set_seed(seed)
Exemplo n.º 4
0
 def __init__(self, env, args):
     """Remember ``env``/``args``, start the parallel sampler and apply ``args.seed`` if given."""
     self.env, self.args = env, args
     # Parallel setup
     parallel_sampler.initialize(n_parallel=args.n_parallel)
     if args.seed is not None:
         # seed through set_seed first, then the sampler workers
         for apply_seed in (set_seed, parallel_sampler.set_seed):
             apply_seed(args.seed)
Exemplo n.º 5
0
    def __init__(self, env, args):
        """Configure parallel sampling and seeding, then build the algorithm.

        ``self.parse_env_args(env, args)`` yields the environment and policy
        that are handed to ``self.setup`` (starting at iteration 0); the result
        is stored on ``self.algo``. ``self.env`` keeps the ``env`` passed in.
        """
        self.env = env
        self.args = args

        # Parallel setup
        parallel_sampler.initialize(n_parallel=args.n_parallel)
        seed = args.seed
        if seed is not None:
            set_seed(seed)
            parallel_sampler.set_seed(seed)

        parsed_env, policy = self.parse_env_args(env, args)
        self.algo = self.setup(parsed_env, policy, start_itr=0)
Exemplo n.º 6
0
def setup(seed, n_parallel, log_dir):
    """Seed, optionally start the parallel sampler, and point the logger at ``log_dir``.

    Args:
        seed: Passed to ``set_seed`` and, when workers are started, to the
            parallel sampler. ``None`` skips seeding entirely.
        n_parallel: Number of sampler workers; values <= 0 start none.
        log_dir: Snapshot directory; ``progress.csv`` is written inside it.
            Created (with parents) if missing.
    """
    if seed is not None:
        set_seed(seed)

    if n_parallel > 0:
        from rllab.sampler import parallel_sampler
        parallel_sampler.initialize(n_parallel=n_parallel)
        if seed is not None:
            parallel_sampler.set_seed(seed)

    # exist_ok=True already makes this a no-op for an existing directory, so
    # no separate isdir() pre-check (and its check-then-create race) is needed.
    os.makedirs(log_dir, exist_ok=True)

    logger.set_snapshot_dir(log_dir)
    logger.add_tabular_output(os.path.join(log_dir, 'progress.csv'))
Exemplo n.º 7
0
def run_experiment(algo,
                   n_parallel=0,
                   seed=0,
                   plot=False,
                   log_dir=None,
                   exp_name=None,
                   snapshot_mode='last',
                   snapshot_gap=1,
                   exp_prefix='experiment',
                   log_tabular_only=False):
    """Train ``algo`` with seeding, optional parallel sampling and logging set up.

    Args:
        algo: Object whose ``train()`` method runs the experiment.
        n_parallel: Number of parallel sampler workers; 0 starts none.
        seed: Passed to ``set_seed`` and, when workers are started, to the
            parallel sampler.
        plot: If true, start the rllab plotter worker.
        log_dir: Output directory. Defaults to
            ``config.LOG_DIR/local/<exp_prefix>/<exp_name>``.
        exp_name: Run name; defaults to ``experiment_<local timestamp>``.
        snapshot_mode: Passed to ``logger.set_snapshot_mode``.
        snapshot_gap: Passed to ``logger.set_snapshot_gap``.
        exp_prefix: Sub-directory of ``config.LOG_DIR/local`` for the default
            ``log_dir``.
        log_tabular_only: Passed to ``logger.set_log_tabular_only``.

    The logger's snapshot dir/mode, tabular/text outputs and prefix are
    restored afterwards, even if ``algo.train()`` raises. (Snapshot gap and
    the tabular-only flag are left as set here.)
    """
    default_log_dir = config.LOG_DIR + "/local/" + exp_prefix
    set_seed(seed)
    if exp_name is None:
        now = datetime.datetime.now(dateutil.tz.tzlocal())
        timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')
        exp_name = 'experiment_%s' % timestamp
    if n_parallel > 0:
        from rllab.sampler import parallel_sampler
        parallel_sampler.initialize(n_parallel=n_parallel)
        parallel_sampler.set_seed(seed)
    if plot:
        from rllab.plotter import plotter
        plotter.init_worker()
    if log_dir is None:
        log_dir = osp.join(default_log_dir, exp_name)
    tabular_log_file = osp.join(log_dir, 'progress.csv')
    text_log_file = osp.join(log_dir, 'debug.log')

    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(snapshot_mode)
    logger.set_snapshot_gap(snapshot_gap)
    logger.set_log_tabular_only(log_tabular_only)
    logger.push_prefix("[%s] " % exp_name)

    try:
        algo.train()
    finally:
        # Undo the global logger changes even when training fails, so a later
        # run in the same process does not inherit this run's outputs/prefix.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
Exemplo n.º 8
0
def _build_run_experiment_parser(default_exp_name):
    """Return the command-line parser used by ``run_experiment(argv)``.

    Args:
        default_exp_name: Value used for ``--exp_name`` when it is omitted.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--n_parallel',
        type=int,
        default=1,
        help=
        'Number of parallel workers to perform rollouts. 0 => don\'t start any workers'
    )
    parser.add_argument('--exp_name',
                        type=str,
                        default=default_exp_name,
                        help='Name of the experiment.')
    parser.add_argument('--log_dir',
                        type=str,
                        default=None,
                        help='Path to save the log and iteration snapshot.')
    parser.add_argument('--snapshot_mode',
                        type=str,
                        default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                        '(all iterations will be saved), "last" (only '
                        'the last iteration will be saved), "gap" (every '
                        '`snapshot_gap` iterations are saved), or "none" '
                        '(do not save snapshots)')
    parser.add_argument('--snapshot_gap',
                        type=int,
                        default=1,
                        help='Gap between snapshot iterations.')
    parser.add_argument('--tabular_log_file',
                        type=str,
                        default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file',
                        type=str,
                        default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--tensorboard_log_dir',
                        type=str,
                        default='tb',
                        help='Name of the folder for tensorboard_summary.')
    parser.add_argument(
        '--tensorboard_step_key',
        type=str,
        default=None,
        help=
        'Name of the step key in log data which shows the step in tensorboard_summary.'
    )
    parser.add_argument('--params_log_file',
                        type=str,
                        default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--variant_log_file',
                        type=str,
                        default='variant.json',
                        help='Name of the variant log file (in json).')
    parser.add_argument(
        '--resume_from',
        type=str,
        default=None,
        help='Name of the pickle file to resume experiment from.')
    parser.add_argument('--plot',
                        type=ast.literal_eval,
                        default=False,
                        help='Whether to plot the iteration results')
    parser.add_argument(
        '--log_tabular_only',
        type=ast.literal_eval,
        default=False,
        help=
        'Whether to only print the tabular log information (in a horizontal format)'
    )
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data',
                        type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--variant_data',
                        type=str,
                        help='Pickled data for variant configuration')
    parser.add_argument('--use_cloudpickle',
                        type=ast.literal_eval,
                        default=False)
    parser.add_argument('--checkpoint_dir',
                        type=str,
                        default='checkpoint',
                        help='Name of the folder for checkpoints.')
    parser.add_argument('--obs_dir',
                        type=str,
                        default='obs',
                        help='Name of the folder for original observations.')
    return parser


def run_experiment(argv):
    """Run one experiment described by command-line style arguments.

    ``argv[0]`` is skipped (program name); the rest is parsed by
    ``_build_run_experiment_parser``. The experiment either resumes the
    ``algo`` stored in ``--resume_from`` or executes the pickled call given in
    ``--args_data`` (required when not resuming). The logger's snapshot
    dir/mode, tabular/text outputs and prefix are restored afterwards, even if
    the experiment raises.

    WARNING: ``--args_data`` and ``--variant_data`` are unpickled and
    ``--resume_from`` is loaded with joblib (also pickle-based); unpickling
    can execute arbitrary code, so only pass trusted data.
    """
    default_log_dir = config.LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')
    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)

    parser = _build_run_experiment_parser(default_exp_name)
    args = parser.parse_args(argv[1:])
    if args.resume_from is None and args.args_data is None:
        # Otherwise this surfaces as a TypeError from b64decode(None), and only
        # after the global logger has already been reconfigured.
        parser.error('--args_data is required unless --resume_from is given')

    if args.seed is not None:
        set_seed(args.seed)

    if args.n_parallel > 0:
        from rllab.sampler import parallel_sampler
        parallel_sampler.initialize(n_parallel=args.n_parallel)
        if args.seed is not None:
            parallel_sampler.set_seed(args.seed)

    if args.plot:
        from rllab.plotter import plotter
        plotter.init_worker()

    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)
    tensorboard_log_dir = osp.join(log_dir, args.tensorboard_log_dir)
    checkpoint_dir = osp.join(log_dir, args.checkpoint_dir)
    obs_dir = osp.join(log_dir, args.obs_dir)

    if args.variant_data is not None:
        # NOTE: unpickling runs arbitrary code; trusted input only.
        variant_data = pickle.loads(base64.b64decode(args.variant_data))
        variant_log_file = osp.join(log_dir, args.variant_log_file)
        logger.log_variant(variant_log_file, variant_data)
    else:
        variant_data = None

    if not args.use_cloudpickle:
        logger.log_parameters_lite(params_log_file, args)

    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    logger.set_tensorboard_dir(tensorboard_log_dir)
    logger.set_checkpoint_dir(checkpoint_dir)
    logger.set_obs_dir(obs_dir)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_snapshot_gap(args.snapshot_gap)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.set_tensorboard_step_key(args.tensorboard_step_key)
    logger.push_prefix("[%s] " % args.exp_name)

    try:
        git_commit = get_git_commit_hash()
        logger.log('Git commit: {}'.format(git_commit))

        git_diff_file_path = osp.join(log_dir,
                                      'git_diff_{}.patch'.format(git_commit))
        save_git_diff_to_file(git_diff_file_path)

        logger.log('hostname: {}, pid: {}, tmux session: {}'.format(
            socket.gethostname(), os.getpid(), get_tmux_session_name()))

        if args.resume_from is not None:
            data = joblib.load(args.resume_from)
            assert 'algo' in data
            algo = data['algo']
            algo.train()
        elif args.use_cloudpickle:
            import cloudpickle
            method_call = cloudpickle.loads(base64.b64decode(args.args_data))
            method_call(variant_data)
        else:
            data = pickle.loads(base64.b64decode(args.args_data))
            maybe_iter = concretize(data)
            if is_iterable(maybe_iter):
                # Consume the result when it is iterable (presumably a
                # generator-style experiment that only runs when iterated).
                for _ in maybe_iter:
                    pass
    finally:
        # Restore the global logger state even if the experiment fails.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
Exemplo n.º 9
0
def _build_hostage_parser(default_exp_name):
    """Return the command-line parser for the hostage-world TRPO script ``main``.

    Args:
        default_exp_name: Value used for ``--exp_name`` when it is omitted.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--exp_name', type=str, default=default_exp_name, help='Name of the experiment.')

    parser.add_argument('--discount', type=float, default=0.95)
    parser.add_argument('--gae_lambda', type=float, default=0.99)

    parser.add_argument('--n_iter', type=int, default=250)
    parser.add_argument('--sampler_workers', type=int, default=1)
    parser.add_argument('--max_traj_len', type=int, default=250)
    parser.add_argument('--update_curriculum', action='store_true', default=False)
    parser.add_argument('--n_timesteps', type=int, default=8000)
    # Registered exactly once: adding '--control' a second time makes argparse
    # raise "conflicting option string" before the script can run.
    parser.add_argument('--control', type=str, default='centralized')

    parser.add_argument('--buffer_size', type=int, default=1)
    parser.add_argument('--n_good', type=int, default=3)
    parser.add_argument('--n_hostage', type=int, default=5)
    parser.add_argument('--n_bad', type=int, default=5)
    parser.add_argument('--n_coop_save', type=int, default=2)
    parser.add_argument('--n_coop_avoid', type=int, default=2)
    parser.add_argument('--n_sensors', type=int, default=20)
    parser.add_argument('--sensor_range', type=float, default=0.2)
    parser.add_argument('--save_reward', type=float, default=3)
    parser.add_argument('--hit_reward', type=float, default=-1)
    parser.add_argument('--encounter_reward', type=float, default=0.01)
    parser.add_argument('--bomb_reward', type=float, default=-10.)

    parser.add_argument('--recurrent', action='store_true', default=False)
    parser.add_argument('--baseline_type', type=str, default='linear')
    parser.add_argument('--policy_hidden_sizes', type=str, default='128,128')
    # The old misspelled flag stays accepted as an alias for existing command lines.
    parser.add_argument('--baseline_hidden_sizes', '--baselin_hidden_sizes',
                        type=str, default='128,128')

    parser.add_argument('--max_kl', type=float, default=0.01)

    parser.add_argument('--log_dir', type=str, required=False)
    parser.add_argument('--tabular_log_file', type=str, default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file', type=str, default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file', type=str, default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--seed', type=int,
                        help='Random seed for numpy')
    parser.add_argument('--args_data', type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--snapshot_mode', type=str, default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                             '(all iterations will be saved), "last" (only '
                             'the last iteration will be saved), or "none" '
                             '(do not save snapshots)')
    parser.add_argument('--log_tabular_only', type=ast.literal_eval, default=False,
                        help='Whether to only print the tabular log information (in a horizontal format)')
    return parser


def main():
    """Train a TRPO policy on ``ContinuousHostageWorld`` from command-line options.

    Parses ``sys.argv`` with ``_build_hostage_parser``, starts the parallel
    sampler (seeding it when ``--seed`` is given), builds the environment,
    policy and baseline, configures the logger under ``--log_dir`` (default
    ``config.LOG_DIR/<exp_name>``) and runs ``TRPO.train()``. The logger's
    snapshot dir/mode, text/tabular outputs and prefix are restored
    afterwards, even if training raises.

    Raises:
        NotImplementedError: if ``--baseline_type`` is not 'linear' or 'zero'.
    """
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')
    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)

    args = _build_hostage_parser(default_exp_name).parse_args()

    parallel_sampler.initialize(n_parallel=args.sampler_workers)

    if args.seed is not None:
        set_seed(args.seed)
        parallel_sampler.set_seed(args.seed)

    args.hidden_sizes = tuple(map(int, args.policy_hidden_sizes.split(',')))

    # --sensor_range is a single float (type=float) passed straight through.
    env = ContinuousHostageWorld(args.n_good, args.n_hostage, args.n_bad, args.n_coop_save,
                                 args.n_coop_avoid, n_sensors=args.n_sensors,
                                 sensor_range=args.sensor_range, save_reward=args.save_reward,
                                 hit_reward=args.hit_reward, encounter_reward=args.encounter_reward,
                                 bomb_reward=args.bomb_reward)

    env = RLLabEnv(StandardizedEnv(env), mode=args.control)

    if args.buffer_size > 1:
        env = ObservationBuffer(env, args.buffer_size)

    policy_cls = GaussianGRUPolicy if args.recurrent else GaussianMLPPolicy
    policy = policy_cls(env_spec=env.spec, hidden_sizes=args.hidden_sizes)

    if args.baseline_type == 'linear':
        baseline = LinearFeatureBaseline(env_spec=env.spec)
    elif args.baseline_type == 'zero':
        baseline = ZeroBaseline(env_spec=env.spec)
    else:
        raise NotImplementedError(args.baseline_type)

    # logger
    default_log_dir = config.LOG_DIR
    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    logger.log_parameters_lite(params_log_file, args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    try:
        algo = TRPO(env=env,
                    policy=policy,
                    baseline=baseline,
                    batch_size=args.n_timesteps,
                    max_path_length=args.max_traj_len,
                    n_itr=args.n_iter,
                    discount=args.discount,
                    gae_lambda=args.gae_lambda,
                    step_size=args.max_kl,
                    mode=args.control)
        algo.train()
    finally:
        # Restore the logger state captured above, even if training fails.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
Exemplo n.º 10
0
def before_experiment(cfg, results):
    """Start the rllab parallel sampler when ``cfg['num_cores']`` asks for more than one core.

    ``results`` is accepted for the hook signature but not used here.
    """
    core_count = cfg['num_cores']
    if core_count > 1:
        # imported lazily so single-core runs never load the sampler module
        from rllab.sampler import parallel_sampler
        parallel_sampler.initialize(core_count)
Exemplo n.º 11
0
    def __init__(self, env, args):
        """Set up sampling, logging and the learning algorithm described by ``args``.

        Starts the parallel sampler (seeding it, and ``set_seed``, when
        ``args.seed`` is given), obtains the environment and policy from
        ``rllab_envpolicy_parser(env, args)``, configures the logger under
        ``args.log_dir`` (default ``config.LOG_DIR/<exp_name>``) and stores the
        algorithm on ``self.algo``.

        Supported ``args.algo`` values: ``'tftrpo'`` (TRPO with a 'linear' or
        'zero' baseline) and ``'thddpg'`` (DDPG with an 'ou' or 'gauss'
        exploration strategy).

        Raises:
            NotImplementedError: for an unsupported ``args.algo``,
                ``args.baseline_type`` (TRPO) or ``args.exp_strategy`` (DDPG).
        """
        if args.algo not in ('tftrpo', 'thddpg'):
            # Fail before touching global sampler/logger state instead of
            # silently finishing without setting self.algo.
            raise NotImplementedError(args.algo)

        self.args = args
        # Parallel setup
        parallel_sampler.initialize(n_parallel=args.n_parallel)
        if args.seed is not None:
            set_seed(args.seed)
            parallel_sampler.set_seed(args.seed)

        env, policy = rllab_envpolicy_parser(env, args)

        if args.algo == 'tftrpo':
            # Baseline (only the TRPO branch uses one)
            if args.baseline_type == 'linear':
                baseline = LinearFeatureBaseline(env_spec=env.spec)
            elif args.baseline_type == 'zero':
                baseline = ZeroBaseline(env_spec=env.spec)
            else:
                raise NotImplementedError(args.baseline_type)

        # Logger
        default_log_dir = config.LOG_DIR
        if args.log_dir is None:
            log_dir = osp.join(default_log_dir, args.exp_name)
        else:
            log_dir = args.log_dir

        tabular_log_file = osp.join(log_dir, args.tabular_log_file)
        text_log_file = osp.join(log_dir, args.text_log_file)
        params_log_file = osp.join(log_dir, args.params_log_file)

        logger.log_parameters_lite(params_log_file, args)
        logger.add_text_output(text_log_file)
        logger.add_tabular_output(tabular_log_file)
        prev_snapshot_dir = logger.get_snapshot_dir()
        prev_mode = logger.get_snapshot_mode()
        logger.set_snapshot_dir(log_dir)
        logger.set_snapshot_mode(args.snapshot_mode)
        logger.set_log_tabular_only(args.log_tabular_only)
        logger.push_prefix("[%s] " % args.exp_name)

        if args.algo == 'tftrpo':
            # Recurrent policies get a CG optimizer with a finite-difference
            # Hessian-vector product; None presumably selects TRPO's default.
            self.algo = TRPO(
                env=env,
                policy=policy,
                baseline=baseline,
                batch_size=args.batch_size,
                max_path_length=args.max_path_length,
                n_itr=args.n_iter,
                discount=args.discount,
                gae_lambda=args.gae_lambda,
                step_size=args.step_size,
                optimizer=ConjugateGradientOptimizer(
                    hvp_approach=FiniteDifferenceHvp(
                        base_eps=1e-5)) if args.recurrent else None,
                mode=args.control)
        else:  # 'thddpg'
            qfunc = thContinuousMLPQFunction(env_spec=env.spec)
            if args.exp_strategy == 'ou':
                es = OUStrategy(env_spec=env.spec)
            elif args.exp_strategy == 'gauss':
                es = GaussianStrategy(env_spec=env.spec)
            else:
                raise NotImplementedError(args.exp_strategy)

            self.algo = thDDPG(env=env,
                               policy=policy,
                               qf=qfunc,
                               es=es,
                               batch_size=args.batch_size,
                               max_path_length=args.max_path_length,
                               epoch_length=args.epoch_length,
                               min_pool_size=args.min_pool_size,
                               replay_pool_size=args.replay_pool_size,
                               n_epochs=args.n_iter,
                               discount=args.discount,
                               scale_reward=0.01,
                               qf_learning_rate=args.qfunc_lr,
                               policy_learning_rate=args.policy_lr,
                               eval_samples=args.eval_samples,
                               mode=args.control)
Exemplo n.º 12
0
def perform_evaluation(num_parallel,
                       hidden_size,
                       batch_size,
                       pathlength,
                       random_split,
                       prioritized_split,
                       adaptive_sample,
                       initialize_epochs,
                       grad_epochs,
                       test_epochs,
                       append,
                       task_size,
                       load_init_policy,
                       load_split_data,
                       alternate_update,
                       accumulate_gradient,
                       imbalance_sample,
                       sample_ratio,
                       split_percentages,
                       env_name,
                       seed,
                       test_num=1,
                       param_update_start=50,
                       param_update_frequency=50,
                       param_update_end=200,
                       use_param_variance=0,
                       param_variance_batch=10000,
                       param_variance_sample=100,
                       reverse_metric=False):
    reps = 1

    learning_curves = []
    kl_divergences = []
    for i in range(len(split_percentages)):
        learning_curves.append([])
        kl_divergences.append([])

    performances = []

    diretory = 'data/trained/gradient_temp/rl_split_' + append

    if not os.path.exists(diretory):
        os.makedirs(diretory)
        os.makedirs(diretory + '/policies')

    for testit in range(test_num):
        print('======== Start Test ', testit, ' ========')
        env = normalize(GymEnv(env_name, record_log=False, record_video=False))
        dartenv = env._wrapped_env.env.env
        if env._wrapped_env.monitoring:
            dartenv = dartenv.env

        np.random.seed(testit * 3 + seed)
        random.seed(testit * 3 + seed)

        pre_training_learning_curve = []

        policy = GaussianMLPPolicy(
            env_spec=env.spec,
            # The neural network policy should have two hidden layers, each with 32 hidden units.
            hidden_sizes=hidden_size,
            # append_dim=2,
            net_mode=0,
        )
        baseline = LinearFeatureBaseline(env_spec=env.spec, additional_dim=0)

        if load_init_policy:
            policy = joblib.load(diretory + '/init_policy.pkl')

        if adaptive_sample:
            new_batch_size = int(batch_size / task_size)
        else:
            new_batch_size = batch_size

        algo = TRPO(  # _MultiTask(
            env=env,
            policy=policy,
            baseline=baseline,
            batch_size=new_batch_size,
            max_path_length=pathlength,
            n_itr=5,
            discount=0.995,
            step_size=0.02,
            gae_lambda=0.97,
            whole_paths=False,
            # task_num=task_size,
        )
        algo.init_opt()

        from rllab.sampler import parallel_sampler

        parallel_sampler.initialize(n_parallel=num_parallel)
        parallel_sampler.set_seed(0)

        algo.start_worker()

        if not load_init_policy:
            for i in range(initialize_epochs):
                print('------ Iter ', i, ' in Init Training --------')
                if adaptive_sample:
                    paths = []
                    reward_paths = []
                    for t in range(task_size):
                        paths += algo.sampler.obtain_samples(0, t)
                        #reward_paths += algo.sampler.obtain_samples(0)
                elif imbalance_sample:
                    paths = []
                    reward_paths = []
                    for t in range(task_size):
                        algo.batch_size = batch_size * sample_ratio[t]
                        task_path = algo.sampler.obtain_samples(0, t)
                        paths += task_path
                        if t == 0:
                            reward_paths += task_path
                else:
                    paths = algo.sampler.obtain_samples(0)
                samples_data = algo.sampler.process_samples(0, paths)
                opt_data = algo.optimize_policy(0, samples_data)
                pol_aft = (policy.get_param_values())
                print(algo.mean_kl(samples_data))

                print(dict(logger._tabular)['AverageReturn'])
                pre_training_learning_curve.append(
                    dict(logger._tabular)['AverageReturn'])
            joblib.dump(policy, diretory + '/init_policy.pkl', compress=True)

        print('------- initial training complete ---------------')
        if not load_split_data:
            split_data = []
            net_weights = []
            net_weight_values = []
            for i in range(grad_epochs):
                cur_param_val = np.copy(policy.get_param_values())
                cur_param = copy.deepcopy(policy.get_params())

                cp = []
                for param in policy._mean_network.get_params():
                    cp.append(np.copy(param.get_value()))
                net_weights.append(cp)
                net_weight_values.append(np.copy(policy.get_param_values()))

                if adaptive_sample:
                    paths = []
                    reward_paths = []
                    for t in range(task_size):
                        paths += algo.sampler.obtain_samples(0, t)
                        #reward_paths += algo.sampler.obtain_samples(0)
                elif imbalance_sample:
                    paths = []
                    reward_paths = []
                    for t in range(task_size):
                        algo.batch_size = batch_size * sample_ratio[t]
                        task_path = algo.sampler.obtain_samples(0, t)
                        paths += task_path
                        if t == 0:
                            reward_paths += task_path
                else:
                    paths = algo.sampler.obtain_samples(0)
                split_data.append(paths)

                samples_data = algo.sampler.process_samples(0, paths)
                opt_data = algo.optimize_policy(0, samples_data)
                pre_training_learning_curve.append(
                    dict(logger._tabular)['AverageReturn'])
            joblib.dump(split_data,
                        diretory + '/split_data.pkl',
                        compress=True)
            joblib.dump(net_weights,
                        diretory + '/net_weights.pkl',
                        compress=True)
            joblib.dump(net_weight_values,
                        diretory + '/net_weight_values.pkl',
                        compress=True)
            joblib.dump(pre_training_learning_curve,
                        diretory + '/pretrain_learningcurve_' + str(testit) +
                        '.pkl',
                        compress=True)
        else:
            split_data = joblib.load(diretory + '/split_data.pkl')
            net_weights = joblib.load(diretory + '/net_weights.pkl')
            net_weight_values = joblib.load(diretory +
                                            '/net_weight_values.pkl')
            pre_training_learning_curve = joblib.load(
                diretory + '/pretrain_learningcurve_' + str(testit) + '.pkl')

        task_grads = []
        variance_grads = []
        for i in range(task_size):
            task_grads.append([])
        for i in range(grad_epochs):
            policy.set_param_values(net_weight_values[i])
            task_paths = []
            for j in range(task_size):
                task_paths.append([])
            for path in split_data[i]:
                taskid = path['env_infos']['state_index'][-1]
                task_paths[taskid].append(path)

            for j in range(task_size):
                samples_data = algo.sampler.process_samples(
                    0, task_paths[j], False)
                grad = get_gradient(algo, samples_data, False)
                task_grads[j].append(grad)
            if use_param_variance == 1 and i == grad_epochs - 1:
                for j in range(param_variance_sample):
                    samples_data_ori = algo.sampler.process_samples(
                        0, split_data[i], False)
                    samples_data = {}
                    indices = np.arange(len(samples_data_ori['observations']))
                    np.random.shuffle(indices)
                    samples_data["observations"] = samples_data_ori[
                        "observations"][indices[0:param_variance_batch]]
                    samples_data["actions"] = samples_data_ori["actions"][
                        indices[0:param_variance_batch]]
                    samples_data["rewards"] = samples_data_ori["rewards"][
                        indices[0:param_variance_batch]]
                    samples_data["advantages"] = samples_data_ori[
                        "advantages"][indices[0:param_variance_batch]]
                    samples_data["agent_infos"] = {}
                    samples_data["agent_infos"]["log_std"] = samples_data_ori[
                        "agent_infos"]["log_std"][
                            indices[0:param_variance_batch]]
                    samples_data["agent_infos"]["mean"] = samples_data_ori[
                        "agent_infos"]["mean"][indices[0:param_variance_batch]]
                    grad = get_gradient(algo, samples_data, False)
                    variance_grads.append(grad)
            algo.sampler.process_samples(0, split_data[i])

        weight_variances = []
        for i in range(len(task_grads[0][0]) - 1):
            weight_variances.append(np.zeros(task_grads[0][0][i].shape))
        if use_param_variance == 1:
            for k in range(len(task_grads[0][0]) - 1):
                one_grad = []
                for g in range(len(variance_grads)):
                    one_grad.append(np.asarray(variance_grads[g][k]))
                weight_variances[k] += np.var(one_grad, axis=0)

        print('------- collected gradient info -------------')

        split_counts = []
        for i in range(len(task_grads[0][0]) - 1):
            split_counts.append(np.zeros(task_grads[0][0][i].shape))

        for i in range(len(task_grads[0])):
            for k in range(len(task_grads[0][i]) - 1):
                region_gradients = []
                for region in range(len(task_grads)):
                    region_gradients.append(task_grads[region][i][k])
                region_gradients = np.array(region_gradients)
                if not random_split:
                    split_counts[k] += np.var(
                        region_gradients, axis=0
                    )  # * np.abs(net_weights[i][k])# + 100 * (len(task_grads[0][i])-k)
                elif prioritized_split:
                    split_counts[k] += np.random.random(
                        split_counts[k].shape) * (len(task_grads[0][i]) - k)
                else:
                    split_counts[k] += np.random.random(split_counts[k].shape)

        for j in range(len(split_counts)):
            plt.figure()
            plt.title(policy._mean_network.get_params()[j].name)
            if len(split_counts[j].shape) == 2:
                plt.imshow(split_counts[j])
                plt.colorbar()
            elif len(split_counts[j].shape) == 1:
                plt.plot(split_counts[j])

            plt.savefig(diretory + '/' +
                        policy._mean_network.get_params()[j].name + '.png')

            if use_param_variance:
                plt.figure()
                plt.title(policy._mean_network.get_params()[j].name)
                if len(weight_variances[j].shape) == 2:
                    plt.imshow(weight_variances[j])
                    plt.colorbar()
                elif len(weight_variances[j].shape) == 1:
                    plt.plot(weight_variances[j])

                plt.savefig(diretory + '/' +
                            policy._mean_network.get_params()[j].name +
                            '_variances.png')

        algo.shutdown_worker()

        # organize the metric into each edges and sort them
        split_metrics = []
        metrics_list = []
        variance_list = []
        for k in range(len(task_grads[0][0]) - 1):
            for index, value in np.ndenumerate(split_counts[k]):
                split_metrics.append(
                    [k, index, value, weight_variances[k][index]])
                metrics_list.append(value)
                variance_list.append(weight_variances[k][index])
        if use_param_variance == 0:
            split_metrics.sort(key=lambda x: x[2], reverse=True)
        else:
            split_metrics.sort(key=lambda x: x[3], reverse=True)

        # test the effect of splitting
        total_param_size = len(policy._mean_network.get_param_values())

        pred_list = []
        # use the optimized network
        init_param_value = np.copy(policy.get_param_values())

        for split_id, split_percentage in enumerate(split_percentages):
            split_param_size = split_percentage * total_param_size
            masks = []
            for k in range(len(task_grads[0][0]) - 1):
                masks.append(np.zeros(split_counts[k].shape))

            if split_percentage <= 1.0:
                for i in range(int(split_param_size)):
                    masks[split_metrics[i][0]][split_metrics[i][1]] = 1
            else:
                threshold = np.mean(metrics_list) + np.std(metrics_list)
                print('threashold,', threshold)
                for i in range(len(split_metrics)):
                    if split_metrics[i][2] < threshold:
                        break
                    else:
                        masks[split_metrics[i][0]][split_metrics[i][1]] = 1

            mask_split_flat = np.array([])
            for k in range(int((len(task_grads[0][0]) - 1) / 2)):
                for j in range(task_size):
                    mask_split_flat = np.concatenate([
                        mask_split_flat,
                        np.array(masks[k * 2]).flatten(),
                        np.array(masks[k * 2 + 1]).flatten()
                    ])
            mask_share_flat = np.ones(len(mask_split_flat))
            mask_share_flat -= mask_split_flat
            if np.abs(split_percentage - 1.0) < 0.0001:
                mask_split_flat = np.concatenate(
                    [mask_split_flat,
                     np.ones(dartenv.act_dim * task_size)])
                mask_share_flat = np.concatenate(
                    [mask_share_flat,
                     np.zeros(dartenv.act_dim * task_size)])
            else:
                mask_split_flat = np.concatenate(
                    [mask_split_flat,
                     np.zeros(dartenv.act_dim)])
                mask_share_flat = np.concatenate(
                    [mask_share_flat,
                     np.ones(dartenv.act_dim)])

            policy.set_param_values(init_param_value)
            if split_param_size != 0:
                if dartenv.avg_div != task_size:
                    dartenv.avg_div = task_size
                    dartenv.obs_dim += dartenv.avg_div
                    high = np.inf * np.ones(dartenv.obs_dim)
                    low = -high
                    dartenv.observation_space = spaces.Box(low, high)
                    env._wrapped_env._observation_space = rllab.envs.gym_env.convert_gym_space(
                        dartenv.observation_space)
                    env.spec = rllab.envs.env_spec.EnvSpec(
                        observation_space=env.observation_space,
                        action_space=env.action_space,
                    )

                split_policy = GaussianMLPPolicy(
                    env_spec=env.spec,
                    # The neural network policy should have two hidden layers, each with 32 hidden units.
                    hidden_sizes=hidden_size,
                    # append_dim=2,
                    net_mode=8,
                    split_num=task_size,
                    split_masks=masks,
                    split_init_net=policy,
                    split_std=np.abs(split_percentage - 1.0) < 0.0001,
                )
            else:
                split_policy = copy.deepcopy(policy)

            if split_param_size == 0:
                baseline_add = 0
            else:
                baseline_add = task_size  # use 0 for now, though task_size should in theory improve performance more
            split_baseline = LinearFeatureBaseline(env_spec=env.spec,
                                                   additional_dim=baseline_add)

            new_batch_size = batch_size
            if (split_param_size != 0 and alternate_update) or adaptive_sample:
                new_batch_size = int(batch_size / task_size)
            split_algo = TRPO(  # _MultiTask(
                env=env,
                policy=split_policy,
                baseline=split_baseline,
                batch_size=new_batch_size,
                max_path_length=pathlength,
                n_itr=5,
                discount=0.995,
                step_size=0.02,
                gae_lambda=0.97,
                whole_paths=False,
                # task_num=task_size,
            )
            split_algo.init_opt()

            parallel_sampler.initialize(n_parallel=num_parallel)
            parallel_sampler.set_seed(0)

            split_algo.start_worker()
            if split_param_size != 0:
                parallel_sampler.update_env_params({
                    'avg_div':
                    dartenv.avg_div,
                    'obs_dim':
                    dartenv.obs_dim,
                    'observation_space':
                    dartenv.observation_space
                })

            print('Network parameter size: ', total_param_size,
                  len(split_policy.get_param_values()))

            split_init_param = np.copy(split_policy.get_param_values())
            avg_error = 0.0

            avg_learning_curve = []
            for rep in range(int(reps)):
                split_policy.set_param_values(split_init_param)
                learning_curve = []
                kl_div_curve = []
                for i in range(test_epochs):
                    # if not split
                    if split_param_size == 0:
                        paths, _ = get_samples(split_algo, task_size,
                                               adaptive_sample,
                                               imbalance_sample, batch_size,
                                               sample_ratio)
                        # sanity check
                        samp_num = 0
                        for p in paths:
                            samp_num += len(p['observations'])
                        print('samp_num: ', samp_num, adaptive_sample,
                              imbalance_sample)
                        samples_data = split_algo.sampler.process_samples(
                            0, paths)
                        opt_data = split_algo.optimize_policy(0, samples_data)

                        if imbalance_sample:
                            reward = 0
                            for path in reward_paths:
                                reward += np.sum(path["rewards"])
                            reward /= len(reward_paths)
                        else:
                            reward = float(
                                (dict(logger._tabular)['AverageReturn']))
                        kl_div_curve.append(split_algo.mean_kl(samples_data))
                        print('reward: ', reward)
                        print(split_algo.mean_kl(samples_data))
                    elif alternate_update:
                        reward = 0
                        total_traj = 0
                        task_rewards = []
                        for j in range(task_size):
                            paths = split_algo.sampler.obtain_samples(0, j)
                            # split_algo.sampler.process_samples(0, paths)
                            samples_data = split_algo.sampler.process_samples(
                                0, paths)
                            opt_data = split_algo.optimize_policy(
                                0, samples_data)
                            reward += float((dict(
                                logger._tabular)['AverageReturn'])) * float(
                                    (dict(logger._tabular)['NumTrajs']))
                            total_traj += float(
                                (dict(logger._tabular)['NumTrajs']))
                            task_rewards.append(
                                dict(logger._tabular)['AverageReturn'])
                        reward /= total_traj
                        print('reward for different tasks: ', task_rewards,
                              reward)
                    elif accumulate_gradient:
                        paths, _ = get_samples(split_algo, task_size,
                                               adaptive_sample,
                                               imbalance_sample, batch_size,
                                               sample_ratio)

                        task_paths = []
                        task_rewards = []
                        for j in range(task_size):
                            task_paths.append([])
                            task_rewards.append([])
                        for path in paths:
                            taskid = path['env_infos']['state_index'][-1]
                            task_paths[taskid].append(path)
                            task_rewards[taskid].append(np.sum(
                                path['rewards']))
                        pre_opt_parameter = np.copy(
                            split_policy.get_param_values())

                        # compute the split gradient first
                        split_policy.set_param_values(pre_opt_parameter)
                        accum_grad = np.zeros(pre_opt_parameter.shape)
                        processed_task_data = []
                        for j in range(task_size):
                            if len(task_paths[j]) == 0:
                                processed_task_data.append([])
                                continue
                            split_policy.set_param_values(pre_opt_parameter)
                            # split_algo.sampler.process_samples(0, task_paths[j])
                            samples_data = split_algo.sampler.process_samples(
                                0, task_paths[j], False)
                            processed_task_data.append(samples_data)
                            #split_algo.optimize_policy(0, samples_data)

                            # if j == 1:
                            accum_grad += split_policy.get_param_values(
                            ) - pre_opt_parameter
                        # sanity check
                        samp_num = 0
                        for p in paths:
                            samp_num += len(p['observations'])
                        print('samp_num: ', samp_num)

                        # compute the gradient together
                        split_policy.set_param_values(pre_opt_parameter)
                        all_data = split_algo.sampler.process_samples(0, paths)
                        if imbalance_sample:
                            reward = 0
                            for path in reward_paths:
                                reward += np.sum(path["rewards"])
                            reward /= len(reward_paths)
                        else:
                            reward = float(
                                (dict(logger._tabular)['AverageReturn']))

                        split_algo.optimize_policy(0, all_data)
                        all_data_grad = split_policy.get_param_values(
                        ) - pre_opt_parameter

                        # do a line search to project the udpate onto the constraint manifold
                        sum_grad = all_data_grad  # * mask_split_flat + all_data_grad * mask_share_flat

                        ls_steps = []
                        loss_before = split_algo.loss(all_data)

                        for s in range(50):
                            ls_steps.append(0.97**s)
                        for step in ls_steps:
                            split_policy.set_param_values(pre_opt_parameter +
                                                          sum_grad * step)
                            if split_algo.mean_kl(
                                    all_data
                            )[0] < split_algo.step_size:  # and split_algo.loss(all_data)[0] < loss_before[0]:
                                break
                        # step=1

                        split_policy.set_param_values(pre_opt_parameter +
                                                      sum_grad * step)

                        for j in range(task_size):
                            task_rewards[j] = np.mean(task_rewards[j])

                        print('reward for different tasks: ', task_rewards,
                              reward)
                        print('mean kl: ', split_algo.mean_kl(all_data),
                              ' step size: ', step)
                        task_mean_kls = []
                        for j in range(task_size):
                            if len(processed_task_data[j]) == 0:
                                task_mean_kls.append(0)
                            else:
                                task_mean_kls.append(
                                    split_algo.mean_kl(
                                        processed_task_data[j])[0])
                        print('mean kl for different tasks: ', task_mean_kls)
                        kl_div_curve.append(
                            np.concatenate(
                                [split_algo.mean_kl(all_data), task_mean_kls]))
                    else:
                        paths = split_algo.sampler.obtain_samples(0)
                        reward = float(
                            (dict(logger._tabular)['AverageReturn']))
                        task_paths = []
                        task_rewards = []
                        for j in range(task_size):
                            task_paths.append([])
                            task_rewards.append([])
                        for path in paths:
                            taskid = path['env_infos']['state_index'][-1]
                            task_paths[taskid].append(path)
                            task_rewards[taskid].append(np.sum(
                                path['rewards']))
                        pre_opt_parameter = np.copy(
                            split_policy.get_param_values())
                        # optimize the shared part
                        # split_algo.sampler.process_samples(0, paths)
                        samples_data = split_algo.sampler.process_samples(
                            0, paths)
                        for layer in split_policy._mean_network._layers:
                            for param in layer.get_params():
                                if 'split' in param.name:
                                    layer.params[param].remove('trainable')
                        split_policy._cached_params = {}
                        split_policy._cached_param_dtypes = {}
                        split_policy._cached_param_shapes = {}
                        split_algo.init_opt()
                        print(
                            'Optimizing shared parameter size: ',
                            len(split_policy.get_param_values(trainable=True)))
                        split_algo.optimize_policy(0, samples_data)

                        # optimize the tasks
                        for layer in split_policy._mean_network._layers:
                            for param in layer.get_params():
                                if 'split' in param.name:
                                    layer.params[param].add('trainable')
                                if 'share' in param.name:
                                    layer.params[param].remove('trainable')

                        # shuffle the optimization order
                        opt_order = np.arange(task_size)
                        np.random.shuffle(opt_order)
                        split_policy._cached_params = {}
                        split_policy._cached_param_dtypes = {}
                        split_policy._cached_param_shapes = {}
                        split_algo.init_opt()
                        for taskid in opt_order:
                            # split_algo.sampler.process_samples(0, task_paths[taskid])
                            samples_data = split_algo.sampler.process_samples(
                                0, task_paths[taskid])
                            print(
                                'Optimizing parameter size: ',
                                len(
                                    split_policy.get_param_values(
                                        trainable=True)))
                            split_algo.optimize_policy(0, samples_data)
                        for layer in split_policy._mean_network._layers:
                            for param in layer.get_params():
                                if 'share' in param.name:
                                    layer.params[param].add('trainable')

                        for j in range(task_size):
                            task_rewards[j] = np.mean(task_rewards[j])
                        print('reward for different tasks: ', task_rewards,
                              reward)

                    learning_curve.append(reward)
                    if (i + initialize_epochs +
                            grad_epochs) % param_update_frequency == 0 and (
                                i + initialize_epochs +
                                grad_epochs) < param_update_end and (
                                    i + initialize_epochs +
                                    grad_epochs) > param_update_start:
                        print("Updating model parameters...")
                        parallel_sampler.update_env_params(
                            {'task_expand_flag': True})
                    print('============= Finished ', split_percentage, ' Rep ',
                          rep, '   test ', i, ' ================')
                    print(diretory)
                    joblib.dump(split_policy,
                                diretory + '/policies/policy_' + str(rep) +
                                '_' + str(i) + '_' + str(split_percentage) +
                                '.pkl',
                                compress=True)
                avg_learning_curve.append(learning_curve)
                kl_divergences[split_id].append(kl_div_curve)
                joblib.dump(split_policy,
                            diretory + '/policies/final_policy_' +
                            str(split_percentage) + '.pkl',
                            compress=True)

                avg_error += float(reward)
            pred_list.append(avg_error / reps)
            print(split_percentage, avg_error / reps)
            split_algo.shutdown_worker()
            print(avg_learning_curve)
            avg_learning_curve = np.mean(avg_learning_curve, axis=0)
            learning_curves[split_id].append(avg_learning_curve)
            # output the learning curves so far
            joblib.dump(learning_curves,
                        diretory + '/learning_curve.pkl',
                        compress=True)
            avg_learning_curve = []
            for lc in range(len(learning_curves)):
                avg_learning_curve.append(np.mean(learning_curves[lc], axis=0))
            plt.figure()
            for lc in range(len(learning_curves)):
                plt.plot(avg_learning_curve[lc],
                         label=str(split_percentages[lc]))
            plt.legend(bbox_to_anchor=(0.3, 0.3),
                       bbox_transform=plt.gcf().transFigure,
                       numpoints=1)
            plt.savefig(diretory + '/split_learning_curves.png')

            if len(kl_divergences[0]) > 0:
                #print('kldiv:', kl_divergences)
                avg_kl_div = []
                for i in range(len(kl_divergences)):
                    if len(kl_divergences[i]) > 0:
                        avg_kl_div.append(np.mean(kl_divergences[i], axis=0))
                #print(avg_kl_div)
                joblib.dump(avg_kl_div,
                            diretory + '/kl_divs.pkl',
                            compress=True)
                for i in range(len(avg_kl_div)):
                    one_perc_kl_div = np.array(avg_kl_div[i])
                    #print(i, one_perc_kl_div)
                    plt.figure()
                    for j in range(len(one_perc_kl_div[0])):
                        append = 'task%d' % j
                        if j == 0:
                            append = 'all'
                        plt.plot(one_perc_kl_div[:, j],
                                 label=str(split_percentages[i]) + append,
                                 alpha=0.3)
                    plt.legend(bbox_to_anchor=(0.3, 0.3),
                               bbox_transform=plt.gcf().transFigure,
                               numpoints=1)
                    plt.savefig(diretory +
                                '/kl_div_%s.png' % str(split_percentages[i]))
        performances.append(pred_list)

    np.savetxt(diretory + '/performance.txt', performances)
    plt.figure()
    plt.plot(split_percentages, np.mean(performances, axis=0))
    plt.savefig(diretory + '/split_performance.png')
    joblib.dump(learning_curves,
                diretory + '/learning_curve.pkl',
                compress=True)

    avg_learning_curve = []
    for i in range(len(learning_curves)):
        avg_learning_curve.append(np.mean(learning_curves[i], axis=0))
    plt.figure()
    for i in range(len(split_percentages)):
        plt.plot(avg_learning_curve[i], label=str(split_percentages[i]))
    plt.legend(bbox_to_anchor=(0.3, 0.3),
               bbox_transform=plt.gcf().transFigure,
               numpoints=1)
    plt.savefig(diretory + '/split_learning_curves.png')
    #np.savetxt(diretory + '/learning_curves.txt', avg_learning_curve)

    if len(kl_divergences[0]) > 0:
        avg_kl_div = []
        for i in range(len(kl_divergences)):
            avg_kl_div.append(np.mean(kl_divergences[i], axis=0))
        joblib.dump(avg_kl_div, diretory + '/kl_divs.pkl', compress=True)
        for i in range(len(avg_kl_div)):
            one_perc_kl_div = np.array(avg_kl_div[i])
            plt.figure()
            for j in range(len(one_perc_kl_div[0])):
                append = 'task%d' % j
                if j == 0:
                    append = 'all'
                plt.plot(one_perc_kl_div[:, j],
                         label=str(split_percentages[i]) + append,
                         alpha=0.3)
            plt.legend(bbox_to_anchor=(0.3, 0.3),
                       bbox_transform=plt.gcf().transFigure,
                       numpoints=1)
            plt.savefig(diretory +
                        '/kl_div_%s.png' % str(split_percentages[i]))

    plt.close('all')

    print(diretory)
Exemplo n.º 13
0
def main():
    """Parse CLI arguments and train a TRPO policy on ContinuousHostageWorld.

    Steps:
      1. Initialize the parallel sampler and, if ``--seed`` is given, seed
         both the main process and the sampler workers.
      2. Build the environment: ContinuousHostageWorld wrapped in
         StandardizedEnv and RLLabEnv (mode = ``--control``), plus an
         ObservationBuffer when ``--buffer_size`` > 1.
      3. Create a Gaussian GRU policy (``--recurrent``) or a Gaussian MLP
         policy, and a linear-feature or zero baseline.
      4. Point the global rllab logger at the experiment's log directory.
      5. Run ``TRPO.train()``.

    The logger's snapshot dir/mode, the text/tabular outputs added here and
    the pushed prefix are restored afterwards, even if training raises. This
    keeps a later experiment in the same process from writing into this
    run's files.
    """
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    # Random suffix keeps default experiment names unique across runs
    # started at the same timestamp.
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')
    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)

    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_name',
                        type=str,
                        default=default_exp_name,
                        help='Name of the experiment.')

    parser.add_argument('--discount', type=float, default=0.95)
    parser.add_argument('--gae_lambda', type=float, default=0.99)

    parser.add_argument('--n_iter', type=int, default=250)
    parser.add_argument('--sampler_workers', type=int, default=1)
    parser.add_argument('--max_traj_len', type=int, default=250)
    parser.add_argument('--update_curriculum',
                        action='store_true',
                        default=False)
    parser.add_argument('--n_timesteps', type=int, default=8000)
    # Registered exactly once: a second add_argument('--control', ...) makes
    # argparse raise ArgumentError (conflicting option string).
    parser.add_argument('--control', type=str, default='centralized')

    parser.add_argument('--buffer_size', type=int, default=1)
    parser.add_argument('--n_good', type=int, default=3)
    parser.add_argument('--n_hostage', type=int, default=5)
    parser.add_argument('--n_bad', type=int, default=5)
    parser.add_argument('--n_coop_save', type=int, default=2)
    parser.add_argument('--n_coop_avoid', type=int, default=2)
    parser.add_argument('--n_sensors', type=int, default=20)
    parser.add_argument('--sensor_range', type=float, default=0.2)
    parser.add_argument('--save_reward', type=float, default=3)
    parser.add_argument('--hit_reward', type=float, default=-1)
    parser.add_argument('--encounter_reward', type=float, default=0.01)
    parser.add_argument('--bomb_reward', type=float, default=-10.)

    parser.add_argument('--recurrent', action='store_true', default=False)
    parser.add_argument('--baseline_type', type=str, default='linear')
    parser.add_argument('--policy_hidden_sizes', type=str, default='128,128')
    # The original (misspelled) flag stays the primary name and dest, so
    # existing command lines and logged params are unchanged; the correctly
    # spelled flag is accepted as an alias.
    parser.add_argument('--baselin_hidden_sizes',
                        '--baseline_hidden_sizes',
                        dest='baselin_hidden_sizes',
                        type=str,
                        default='128,128')

    parser.add_argument('--max_kl', type=float, default=0.01)

    parser.add_argument('--log_dir', type=str, required=False)
    parser.add_argument('--tabular_log_file',
                        type=str,
                        default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file',
                        type=str,
                        default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file',
                        type=str,
                        default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data',
                        type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--snapshot_mode',
                        type=str,
                        default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                        '(all iterations will be saved), "last" (only '
                        'the last iteration will be saved), or "none" '
                        '(do not save snapshots)')
    parser.add_argument(
        '--log_tabular_only',
        type=ast.literal_eval,
        default=False,
        help=
        'Whether to only print the tabular log information (in a horizontal format)'
    )

    args = parser.parse_args()

    parallel_sampler.initialize(n_parallel=args.sampler_workers)

    if args.seed is not None:
        set_seed(args.seed)
        parallel_sampler.set_seed(args.seed)

    args.hidden_sizes = tuple(map(int, args.policy_hidden_sizes.split(',')))

    # --sensor_range is a single float (type=float above) and is passed
    # straight to the environment; there is nothing to split or validate
    # per agent here.
    env = ContinuousHostageWorld(args.n_good,
                                 args.n_hostage,
                                 args.n_bad,
                                 args.n_coop_save,
                                 args.n_coop_avoid,
                                 n_sensors=args.n_sensors,
                                 sensor_range=args.sensor_range,
                                 save_reward=args.save_reward,
                                 hit_reward=args.hit_reward,
                                 encounter_reward=args.encounter_reward,
                                 bomb_reward=args.bomb_reward)

    # The control mode ('centralized' by default) is forwarded to both the
    # env wrapper and TRPO below.
    env = RLLabEnv(StandardizedEnv(env), mode=args.control)

    if args.buffer_size > 1:
        env = ObservationBuffer(env, args.buffer_size)

    if args.recurrent:
        policy = GaussianGRUPolicy(env_spec=env.spec,
                                   hidden_sizes=args.hidden_sizes)
    else:
        policy = GaussianMLPPolicy(env_spec=env.spec,
                                   hidden_sizes=args.hidden_sizes)

    if args.baseline_type == 'linear':
        baseline = LinearFeatureBaseline(env_spec=env.spec)
    else:
        # Any other --baseline_type falls back to rllab's ZeroBaseline, which
        # is constructed from the env spec like the linear baseline.
        baseline = ZeroBaseline(env_spec=env.spec)

    # logger
    default_log_dir = config.LOG_DIR
    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    logger.log_parameters_lite(params_log_file, args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    # Remember the global logger state so it can be restored after training.
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    algo = TRPO(
        env=env,
        policy=policy,
        baseline=baseline,
        batch_size=args.n_timesteps,
        max_path_length=args.max_traj_len,
        n_itr=args.n_iter,
        discount=args.discount,
        step_size=args.max_kl,
        mode=args.control,
    )

    try:
        algo.train()
    finally:
        # Undo the global logger configuration made above, even on failure.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
Exemplo n.º 14
0
)

f_train_imp = theano.function(
    inputs=[observations_var, actions_var, d_rewards_var, importance_weights_var],
    outputs=grad_var,
)



variance_svrg_data={}
variance_sgd_data={}
importance_weights_data={}
rewards_snapshot_data={}
rewards_subiter_data={}
n_sub_iter_data={}
parallel_sampler.initialize(3)
for k in range(10):
    if (load_policy):
        snap_policy.set_param_values(np.loadtxt('policy_novar.txt'), trainable=True)
        policy.set_param_values(np.loadtxt('policy_novar.txt'), trainable=True)
    avg_return = list()
    n_sub_iter=[]
    rewards_sub_iter=[]
    rewards_snapshot=[]
    importance_weights=[]
    variance_svrg = []
    variance_sgd = []

    #np.savetxt("policy_novar.txt",snap_policy.get_param_values(trainable=True))
    j=0
    while j<s_tot-N:
Exemplo n.º 15
0
    zero_adv_policy = ConstantControlPolicy(
        env_spec=env.spec,
        is_protagonist=False,
        constant_val = 0.0
    )

    ## Adversary policy definition ##
    adv_policy = GaussianMLPPolicy(
        env_spec=env.spec,
        hidden_sizes=layer_size,
        is_protagonist=False
    )
    adv_baseline = LinearFeatureBaseline(env_spec=env.spec)

    ## Initializing the parallel sampler ##
    parallel_sampler.initialize(n_process)

    ## Optimizer for the Protagonist ##
    pro_algo = TRPO(
        env=env,
        pro_policy=pro_policy,
        adv_policy=adv_policy,
        pro_baseline=pro_baseline,
        adv_baseline=adv_baseline,
        batch_size=batch_size,
        max_path_length=path_length,
        n_itr=n_pro_itr,
        discount=0.995,
        gae_lambda=gae_lambda,
        step_size=step_size,
        is_protagonist=True
Exemplo n.º 16
0
def run_experiment(argv):
    """Parse ``argv``, set up sampling and logging, then run one experiment.

    The experiment is either resumed from a joblib snapshot (``--resume_from``)
    or reconstructed from base64-encoded pickled data given in ``--args_data``.
    The global ``logger`` configuration changed here (snapshot dir/mode, text
    and tabular outputs, prefix) is restored when the run finishes, including
    when the experiment raises.

    :param argv: argument vector in ``sys.argv`` layout; ``argv[0]`` is skipped
        and ``argv[1:]`` is parsed.
    """
    # e2crawfo: These imports, in this order, were necessary for fixing issues on cedar.
    import rllab.mujoco_py.mjlib  # noqa: F401 -- imported only for its side effects
    import tensorflow  # noqa: F401 -- imported only for its side effects

    default_log_dir = config.LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')

    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--n_parallel',
        type=int,
        default=1,
        help=
        'Number of parallel workers to perform rollouts. 0 => don\'t start any workers'
    )
    parser.add_argument('--exp_name',
                        type=str,
                        default=default_exp_name,
                        help='Name of the experiment.')
    parser.add_argument('--log_dir',
                        type=str,
                        default=None,
                        help='Path to save the log and iteration snapshot.')
    parser.add_argument('--snapshot_mode',
                        type=str,
                        default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                        '(all iterations will be saved), "last" (only '
                        'the last iteration will be saved), "gap" (every '
                        '`snapshot_gap` iterations are saved), or "none" '
                        '(do not save snapshots)')
    parser.add_argument('--snapshot_gap',
                        type=int,
                        default=1,
                        help='Gap between snapshot iterations.')
    parser.add_argument('--tabular_log_file',
                        type=str,
                        default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file',
                        type=str,
                        default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file',
                        type=str,
                        default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--variant_log_file',
                        type=str,
                        default='variant.json',
                        help='Name of the variant log file (in json).')
    parser.add_argument(
        '--resume_from',
        type=str,
        default=None,
        help='Name of the pickle file to resume experiment from.')
    parser.add_argument('--plot',
                        type=ast.literal_eval,
                        default=False,
                        help='Whether to plot the iteration results')
    parser.add_argument(
        '--log_tabular_only',
        type=ast.literal_eval,
        default=False,
        help=
        'Whether to only print the tabular log information (in a horizontal format)'
    )
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data',
                        type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--variant_data',
                        type=str,
                        help='Pickled data for variant configuration')
    parser.add_argument('--use_cloudpickle',
                        type=ast.literal_eval,
                        default=False)

    args = parser.parse_args(argv[1:])

    # Without either source there is nothing to run. Fail here, before any
    # worker processes are started, instead of with an opaque TypeError from
    # base64.b64decode(None) further down.
    if args.resume_from is None and args.args_data is None:
        parser.error('one of --resume_from or --args_data is required')

    if args.seed is not None:
        set_seed(args.seed)

    if args.n_parallel > 0:
        from rllab.sampler import parallel_sampler
        parallel_sampler.initialize(n_parallel=args.n_parallel)
        if args.seed is not None:
            parallel_sampler.set_seed(args.seed)

    if args.plot:
        from rllab.plotter import plotter
        plotter.init_worker()

    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    if args.variant_data is not None:
        # SECURITY: unpickling can execute arbitrary code; --variant_data must
        # come from a trusted launcher.
        variant_data = pickle.loads(base64.b64decode(args.variant_data))
        variant_log_file = osp.join(log_dir, args.variant_log_file)
        logger.log_variant(variant_log_file, variant_data)
    else:
        variant_data = None

    if not args.use_cloudpickle:
        logger.log_parameters_lite(params_log_file, args)

    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_tf_summary_dir(osp.join(log_dir, "tf_summary"))
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_snapshot_gap(args.snapshot_gap)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    try:
        if args.resume_from is not None:
            data = joblib.load(args.resume_from)
            assert 'algo' in data
            algo = data['algo']
            # train() may be a generator; exhaust it so training actually runs.
            maybe_iter = algo.train()
            if is_iterable(maybe_iter):
                for _ in maybe_iter:
                    pass
        else:
            # Decode the experiment from --args_data. SECURITY: both branches
            # unpickle, which can execute arbitrary code; only pass payloads
            # produced by a trusted launcher.
            if args.use_cloudpickle:
                import cloudpickle
                method_call = cloudpickle.loads(base64.b64decode(args.args_data))
                method_call(variant_data)
            else:
                data = pickle.loads(base64.b64decode(args.args_data))
                maybe_iter = concretize(data)
                if is_iterable(maybe_iter):
                    for _ in maybe_iter:
                        pass
    finally:
        # Restore the global logger state even if the experiment raised, so
        # later code in the same process does not keep writing to this run's
        # files or under this run's prefix.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
Exemplo n.º 17
0
def run_experiment(argv):
    """Parse ``argv``, set up sampling and logging, then run one experiment.

    The experiment is either resumed from a joblib snapshot (``--resume_from``)
    or reconstructed from base64-encoded pickled data given in ``--args_data``.
    The global ``logger`` configuration changed here (snapshot dir/mode, text
    and tabular outputs, prefix) is restored when the run finishes, including
    when the experiment raises.

    :param argv: argument vector in ``sys.argv`` layout; ``argv[0]`` is skipped
        and ``argv[1:]`` is parsed.
    """
    default_log_dir = config.LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')

    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
    parser = argparse.ArgumentParser()
    parser.add_argument('--n_parallel', type=int, default=1,
                        help='Number of parallel workers to perform rollouts. 0 => don\'t start any workers')
    parser.add_argument(
        '--exp_name', type=str, default=default_exp_name, help='Name of the experiment.')
    parser.add_argument('--log_dir', type=str, default=None,
                        help='Path to save the log and iteration snapshot.')
    parser.add_argument('--snapshot_mode', type=str, default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                             '(all iterations will be saved), "last" (only '
                             'the last iteration will be saved), "gap" (every '
                             '`snapshot_gap` iterations are saved), or "none" '
                             '(do not save snapshots)')
    parser.add_argument('--snapshot_gap', type=int, default=1,
                        help='Gap between snapshot iterations.')
    parser.add_argument('--tabular_log_file', type=str, default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file', type=str, default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file', type=str, default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--variant_log_file', type=str, default='variant.json',
                        help='Name of the variant log file (in json).')
    parser.add_argument('--resume_from', type=str, default=None,
                        help='Name of the pickle file to resume experiment from.')
    parser.add_argument('--plot', type=ast.literal_eval, default=False,
                        help='Whether to plot the iteration results')
    parser.add_argument('--log_tabular_only', type=ast.literal_eval, default=False,
                        help='Whether to only print the tabular log information (in a horizontal format)')
    parser.add_argument('--seed', type=int,
                        help='Random seed for numpy')
    parser.add_argument('--args_data', type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--variant_data', type=str,
                        help='Pickled data for variant configuration')
    parser.add_argument('--use_cloudpickle', type=ast.literal_eval, default=False)

    args = parser.parse_args(argv[1:])

    # Without either source there is nothing to run. Fail before spawning
    # workers rather than with a TypeError from base64.b64decode(None).
    if args.resume_from is None and args.args_data is None:
        parser.error('one of --resume_from or --args_data is required')

    if args.seed is not None:
        set_seed(args.seed)

    if args.n_parallel > 0:
        from rllab.sampler import parallel_sampler
        parallel_sampler.initialize(n_parallel=args.n_parallel)
        if args.seed is not None:
            parallel_sampler.set_seed(args.seed)

    if args.plot:
        from rllab.plotter import plotter
        plotter.init_worker()

    log_dir = (osp.join(default_log_dir, args.exp_name)
               if args.log_dir is None else args.log_dir)
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    if args.variant_data is not None:
        # SECURITY: unpickling can execute arbitrary code; --variant_data must
        # come from a trusted launcher.
        variant_data = pickle.loads(base64.b64decode(args.variant_data))
        variant_log_file = osp.join(log_dir, args.variant_log_file)
        logger.log_variant(variant_log_file, variant_data)
    else:
        variant_data = None

    if not args.use_cloudpickle:
        logger.log_parameters_lite(params_log_file, args)

    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_snapshot_gap(args.snapshot_gap)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    try:
        if args.resume_from is not None:
            data = joblib.load(args.resume_from)
            assert 'algo' in data
            algo = data['algo']
            # train() may be a generator (the --args_data path below already
            # allows for that); exhaust it so resumed training actually runs.
            maybe_iter = algo.train()
            if is_iterable(maybe_iter):
                for _ in maybe_iter:
                    pass
        else:
            # Decode the experiment from --args_data. SECURITY: both branches
            # unpickle, which can execute arbitrary code; only pass payloads
            # produced by a trusted launcher.
            if args.use_cloudpickle:
                import cloudpickle
                method_call = cloudpickle.loads(base64.b64decode(args.args_data))
                method_call(variant_data)
            else:
                data = pickle.loads(base64.b64decode(args.args_data))
                maybe_iter = concretize(data)
                if is_iterable(maybe_iter):
                    for _ in maybe_iter:
                        pass
    finally:
        # Restore the global logger state even if the experiment raised.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
Exemplo n.º 18
0
def main():
    """Train a policy with TRPO on the PursuitEvade environment from CLI args.

    Either loads ``policy`` and ``env`` from a joblib ``--checkpoint`` or builds
    the environment (wrapped in StandardizedEnv/RLLabEnv/TfEnv) and a
    categorical MLP, conv, GRU or LSTM policy. It then configures the global
    ``logger`` under ``--log_dir`` (default ``config.LOG_DIR/<exp_name>``) and
    runs ``TRPO.train()``. The logger's previous snapshot dir/mode are restored
    and this run's outputs/prefix removed afterwards, even if training raises.

    :raises NotImplementedError: for an unknown ``--map_type`` or ``--recurrent``.
    :raises ValueError: if ``--recurrent`` is used with more than one
        ``--policy_hidden_sizes`` entry.
    """
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')
    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)

    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_name',
                        type=str,
                        default=default_exp_name,
                        help='Name of the experiment.')

    parser.add_argument('--discount', type=float, default=0.99)
    parser.add_argument('--gae_lambda', type=float, default=1.0)
    parser.add_argument('--reward_scale', type=float, default=1.0)

    parser.add_argument('--n_iter', type=int, default=250)
    parser.add_argument('--sampler_workers', type=int, default=1)
    parser.add_argument('--max_traj_len', type=int, default=250)
    parser.add_argument('--update_curriculum',
                        action='store_true',
                        default=False)
    parser.add_argument('--n_timesteps', type=int, default=8000)
    parser.add_argument('--control', type=str, default='centralized')

    parser.add_argument('--rectangle', type=str, default='10,10')
    parser.add_argument('--map_type', type=str, default='rectangle')
    parser.add_argument('--n_evaders', type=int, default=5)
    parser.add_argument('--n_pursuers', type=int, default=2)
    parser.add_argument('--obs_range', type=int, default=3)
    parser.add_argument('--n_catch', type=int, default=2)
    parser.add_argument('--urgency', type=float, default=0.0)
    parser.add_argument('--pursuit', dest='train_pursuit', action='store_true')
    parser.add_argument('--evade', dest='train_pursuit', action='store_false')
    parser.set_defaults(train_pursuit=True)
    parser.add_argument('--surround', action='store_true', default=False)
    parser.add_argument('--constraint_window', type=float, default=1.0)
    parser.add_argument('--sample_maps', action='store_true', default=False)
    parser.add_argument('--map_file', type=str, default='../maps/map_pool.npy')
    parser.add_argument('--flatten', action='store_true', default=False)
    parser.add_argument('--reward_mech', type=str, default='global')
    parser.add_argument('--catchr', type=float, default=0.1)
    parser.add_argument('--term_pursuit', type=float, default=5.0)

    parser.add_argument('--recurrent', type=str, default=None)
    parser.add_argument('--policy_hidden_sizes', type=str, default='128,128')
    # NOTE(review): misspelled flag name and never read below; kept as-is so
    # existing command lines keep parsing.
    parser.add_argument('--baselin_hidden_sizes', type=str, default='128,128')
    parser.add_argument('--baseline_type', type=str, default='linear')

    parser.add_argument('--conv', action='store_true', default=False)

    parser.add_argument('--max_kl', type=float, default=0.01)

    parser.add_argument('--checkpoint', type=str, default=None)

    parser.add_argument('--log_dir', type=str, required=False)
    parser.add_argument('--tabular_log_file',
                        type=str,
                        default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file',
                        type=str,
                        default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file',
                        type=str,
                        default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data',
                        type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--snapshot_mode',
                        type=str,
                        default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                        '(all iterations will be saved), "last" (only '
                        'the last iteration will be saved), or "none" '
                        '(do not save snapshots)')
    parser.add_argument(
        '--log_tabular_only',
        type=ast.literal_eval,
        default=False,
        help=
        'Whether to only print the tabular log information (in a horizontal format)'
    )

    args = parser.parse_args()

    parallel_sampler.initialize(n_parallel=args.sampler_workers)

    if args.seed is not None:
        set_seed(args.seed)
        parallel_sampler.set_seed(args.seed)

    args.hidden_sizes = tuple(map(int, args.policy_hidden_sizes.split(',')))

    if args.checkpoint:
        # NOTE(review): the session is closed when this `with` exits, but the
        # loaded policy is used for training afterwards -- confirm intended.
        with tf.Session() as sess:
            data = joblib.load(args.checkpoint)
            policy = data['policy']
            env = data['env']
    else:
        if args.sample_maps:
            map_pool = np.load(args.map_file)
        else:
            if args.map_type == 'rectangle':
                env_map = TwoDMaps.rectangle_map(
                    *map(int, args.rectangle.split(',')))
            elif args.map_type == 'complex':
                env_map = TwoDMaps.complex_map(
                    *map(int, args.rectangle.split(',')))
            else:
                raise NotImplementedError()
            map_pool = [env_map]

        env = PursuitEvade(map_pool,
                           n_evaders=args.n_evaders,
                           n_pursuers=args.n_pursuers,
                           obs_range=args.obs_range,
                           n_catch=args.n_catch,
                           train_pursuit=args.train_pursuit,
                           urgency_reward=args.urgency,
                           surround=args.surround,
                           sample_maps=args.sample_maps,
                           constraint_window=args.constraint_window,
                           flatten=args.flatten,
                           reward_mech=args.reward_mech,
                           catchr=args.catchr,
                           term_pursuit=args.term_pursuit)

        env = TfEnv(
            RLLabEnv(StandardizedEnv(env,
                                     scale_reward=args.reward_scale,
                                     enable_obsnorm=False),
                     mode=args.control))

        if args.recurrent:
            # Validate up front: otherwise an unknown type leaves `policy`
            # unbound and fails later with a NameError.
            if args.recurrent not in ('gru', 'lstm'):
                raise NotImplementedError(
                    'Unknown --recurrent type %r (expected "gru" or "lstm")'
                    % args.recurrent)
            # Recurrent policies take a single hidden-state size. The original
            # int('128,128') on the default failed with an unhelpful message.
            if len(args.hidden_sizes) != 1:
                raise ValueError(
                    '--policy_hidden_sizes must be a single integer for a '
                    'recurrent policy, got %r' % args.policy_hidden_sizes)
            hidden_dim = args.hidden_sizes[0]

            if args.conv:
                feature_network = ConvNetwork(
                    name='feature_net',
                    input_shape=env.spec.observation_space.shape,
                    output_dim=5,
                    conv_filters=(16, 32, 32),
                    conv_filter_sizes=(3, 3, 3),
                    conv_strides=(1, 1, 1),
                    conv_pads=('VALID', 'VALID', 'VALID'),
                    hidden_sizes=(64, ),
                    hidden_nonlinearity=tf.nn.relu,
                    output_nonlinearity=tf.nn.softmax)
            else:
                # Input is the flattened observation concatenated with the
                # flattened action.
                feature_network = MLP(
                    name='feature_net',
                    input_shape=(env.spec.observation_space.flat_dim +
                                 env.spec.action_space.flat_dim, ),
                    output_dim=5,
                    hidden_sizes=(256, 128, 64),
                    hidden_nonlinearity=tf.nn.tanh,
                    output_nonlinearity=None)
            if args.recurrent == 'gru':
                policy = CategoricalGRUPolicy(env_spec=env.spec,
                                              feature_network=feature_network,
                                              hidden_dim=hidden_dim,
                                              name='policy')
            else:  # 'lstm', guaranteed by the check above
                policy = CategoricalLSTMPolicy(env_spec=env.spec,
                                               feature_network=feature_network,
                                               hidden_dim=hidden_dim,
                                               name='policy')
        elif args.conv:
            feature_network = ConvNetwork(
                name='feature_net',
                input_shape=env.spec.observation_space.shape,
                output_dim=5,
                conv_filters=(8, 16),
                conv_filter_sizes=(3, 3),
                conv_strides=(2, 1),
                conv_pads=('VALID', 'VALID'),
                hidden_sizes=(32, ),
                hidden_nonlinearity=tf.nn.relu,
                output_nonlinearity=tf.nn.softmax)
            policy = CategoricalMLPPolicy(name='policy',
                                          env_spec=env.spec,
                                          prob_network=feature_network)
        else:
            policy = CategoricalMLPPolicy(name='policy',
                                          env_spec=env.spec,
                                          hidden_sizes=args.hidden_sizes)

    if args.baseline_type == 'linear':
        baseline = LinearFeatureBaseline(env_spec=env.spec)
    else:
        baseline = ZeroBaseline(env_spec=env.spec)

    # logger
    default_log_dir = config.LOG_DIR
    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    logger.log_parameters_lite(params_log_file, args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    try:
        algo = TRPO(
            env=env,
            policy=policy,
            baseline=baseline,
            batch_size=args.n_timesteps,
            max_path_length=args.max_traj_len,
            n_itr=args.n_iter,
            discount=args.discount,
            gae_lambda=args.gae_lambda,
            step_size=args.max_kl,
            # Recurrent policies use a finite-difference Hessian-vector product;
            # None lets TRPO pick its default optimizer.
            optimizer=ConjugateGradientOptimizer(hvp_approach=FiniteDifferenceHvp(
                base_eps=1e-5)) if args.recurrent else None,
            mode=args.control,
        )

        algo.train()
    finally:
        # Undo the logger configuration above (the previous dir/mode were
        # captured for exactly this purpose).
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
Exemplo n.º 19
0
def run_experiment(
    args_data,
    variant_data=None,
    seed=None,
    n_parallel=1,
    exp_name=None,
    log_dir=None,
    snapshot_mode='all',
    snapshot_gap=1,
    tabular_log_file='progress.csv',
    text_log_file='debug.log',
    params_log_file='params.json',
    variant_log_file='variant.json',
    resume_from=None,
    plot=False,
    log_tabular_only=False,
    log_debug_log_only=False,
):
    """Run one experiment in-process with the given logging configuration.

    The global ``logger`` configuration changed here (snapshot dir/mode, text
    and tabular outputs, prefix) is restored when the run finishes, including
    when the experiment raises.

    :param args_data: callable invoked as ``args_data(variant_data)`` to run the
        experiment; not called when ``resume_from`` is given.
    :param variant_data: variant configuration; when not None it is logged to
        ``variant_log_file`` inside the log directory.
    :param seed: if not None, passed to ``set_seed`` and the parallel sampler.
    :param n_parallel: number of sampler workers; 0 (or less) starts none.
    :param exp_name: experiment name; defaults to a timestamp plus a random suffix.
    :param log_dir: output directory; defaults to ``config.LOG_DIR/<exp_name>``.
    :param snapshot_mode: passed to ``logger.set_snapshot_mode``.
    :param snapshot_gap: passed to ``logger.set_snapshot_gap``.
    :param tabular_log_file: tabular log file name, relative to the log directory.
    :param text_log_file: text log file name, relative to the log directory.
    :param params_log_file: not used by this function; kept for interface
        compatibility with the command-line variants.
    :param variant_log_file: variant log file name, relative to the log directory.
    :param resume_from: joblib snapshot path; if given, its ``'algo'`` entry is
        trained instead of calling ``args_data``.
    :param plot: if true, start the plotter worker.
    :param log_tabular_only: passed to ``logger.set_log_tabular_only``.
    :param log_debug_log_only: passed to ``logger.set_debug_log_only``.
    """
    default_log_dir = config.LOG_DIR

    if exp_name is None:
        # avoid name clashes when running distributed jobs
        now = datetime.datetime.now(dateutil.tz.tzlocal())
        rand_id = str(uuid.uuid4())[:5]
        timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')
        exp_name = 'experiment_%s_%s' % (timestamp, rand_id)

    if seed is not None:
        set_seed(seed)

    if n_parallel > 0:
        from rllab.sampler import parallel_sampler
        parallel_sampler.initialize(n_parallel=n_parallel)
        if seed is not None:
            parallel_sampler.set_seed(seed)

    if plot:
        from rllab.plotter import plotter
        plotter.init_worker()

    if log_dir is None:
        log_dir = osp.join(default_log_dir, exp_name)
    tabular_log_file = osp.join(log_dir, tabular_log_file)
    text_log_file = osp.join(log_dir, text_log_file)

    if variant_data is not None:
        logger.log_variant(osp.join(log_dir, variant_log_file), variant_data)

    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(snapshot_mode)
    logger.set_snapshot_gap(snapshot_gap)
    logger.set_log_tabular_only(log_tabular_only)
    logger.set_debug_log_only(log_debug_log_only)
    logger.push_prefix("[%s] " % exp_name)

    try:
        if resume_from is not None:
            data = joblib.load(resume_from)
            assert 'algo' in data
            algo = data['algo']
            # train() may be implemented as a generator; exhaust it so that
            # resumed training actually runs.
            maybe_iter = algo.train()
            if hasattr(maybe_iter, '__iter__'):
                for _ in maybe_iter:
                    pass
        else:
            args_data(variant_data)
    finally:
        # Restore the global logger state even if the experiment raised.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
Exemplo n.º 20
0
def run_experiment(argv):
    """Parse ``argv``, set up sampling and logging, then run one experiment.

    Only cloudpickle payloads are supported: ``--use_cloudpickle`` must be true,
    otherwise ``NotImplementedError`` is raised before any work is done. The
    experiment is either resumed from a joblib snapshot (``--resume_from``) or
    decoded from ``--args_data`` and called with the variant data. The optional
    code diff, commit hash and script name are written into the log directory
    for reproducibility. The global ``logger`` configuration changed here is
    restored when the run finishes, including when it raises.

    :param argv: argument vector in ``sys.argv`` layout; ``argv[0]`` is skipped
        and ``argv[1:]`` is parsed.
    :raises NotImplementedError: if ``--use_cloudpickle`` is false.
    """
    default_log_dir = config.LOCAL_LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')

    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--n_parallel',
        type=int,
        default=1,
        help=
        'Number of parallel workers to perform rollouts. 0 => don\'t start any workers'
    )
    parser.add_argument('--exp_name',
                        type=str,
                        default=default_exp_name,
                        help='Name of the experiment.')
    parser.add_argument('--log_dir',
                        type=str,
                        default=None,
                        help='Path to save the log and iteration snapshot.')
    parser.add_argument('--snapshot_mode',
                        type=str,
                        default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                        '(all iterations will be saved), "last" (only '
                        'the last iteration will be saved), "gap" (every '
                        '`snapshot_gap` iterations are saved), or "none" '
                        '(do not save snapshots)')
    parser.add_argument('--snapshot_gap',
                        type=int,
                        default=1,
                        help='Gap between snapshot iterations.')
    parser.add_argument('--tabular_log_file',
                        type=str,
                        default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file',
                        type=str,
                        default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file',
                        type=str,
                        default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--variant_log_file',
                        type=str,
                        default='variant.json',
                        help='Name of the variant log file (in json).')
    parser.add_argument(
        '--resume_from',
        type=str,
        default=None,
        help='Name of the pickle file to resume experiment from.')
    parser.add_argument('--plot',
                        type=ast.literal_eval,
                        default=False,
                        help='Whether to plot the iteration results')
    parser.add_argument(
        '--log_tabular_only',
        type=ast.literal_eval,
        default=False,
        help=
        'Whether to only print the tabular log information (in a horizontal format)'
    )
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data',
                        type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--variant_data',
                        type=str,
                        help='Pickled data for variant configuration')
    parser.add_argument('--use_cloudpickle',
                        type=ast.literal_eval,
                        default=False)
    parser.add_argument('--code_diff',
                        type=str,
                        help='A string of the code diff to save.')
    parser.add_argument('--commit_hash',
                        type=str,
                        help='A string of the commit hash')
    parser.add_argument('--script_name',
                        type=str,
                        help='Name of the launched script')

    args = parser.parse_args(argv[1:])

    # Reject unsupported payloads before spawning workers or writing anything
    # into the log directory (previously this check ran after both, which also
    # left the plain-pickle branch below unreachable).
    if not args.use_cloudpickle:
        raise NotImplementedError("Not supporting non-cloud-pickle")
    if args.resume_from is None and args.args_data is None:
        parser.error('one of --resume_from or --args_data is required')

    # cloudpickle decodes --args_data and --code_diff below; import it here so
    # the function does not depend on a module-level import.
    import cloudpickle

    if args.seed is not None:
        set_seed(args.seed)

    if args.n_parallel > 0:
        from rllab.sampler import parallel_sampler
        parallel_sampler.initialize(n_parallel=args.n_parallel)
        if args.seed is not None:
            parallel_sampler.set_seed(args.seed)

    if args.plot:
        from rllab.plotter import plotter
        plotter.init_worker()

    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)

    if args.variant_data is not None:
        # SECURITY: unpickling can execute arbitrary code; --variant_data must
        # come from a trusted launcher.
        variant_data = pickle.loads(base64.b64decode(args.variant_data))
        variant_log_file = osp.join(log_dir, args.variant_log_file)
        logger.log_variant(variant_log_file, variant_data)
    else:
        variant_data = None

    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_snapshot_gap(args.snapshot_gap)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    try:
        # Save information for code reproducibility.
        if args.code_diff is not None:
            code_diff_str = cloudpickle.loads(base64.b64decode(args.code_diff))
            with open(osp.join(log_dir, "code.diff"), "w") as f:
                f.write(code_diff_str)
        if args.commit_hash is not None:
            with open(osp.join(log_dir, "commit_hash.txt"), "w") as f:
                f.write(args.commit_hash)
        if args.script_name is not None:
            with open(osp.join(log_dir, "script_name.txt"), "w") as f:
                f.write(args.script_name)

        if args.resume_from is not None:
            data = joblib.load(args.resume_from)
            assert 'algo' in data
            algo = data['algo']
            # train() may be a generator; exhaust it so training actually runs.
            maybe_iter = algo.train()
            if is_iterable(maybe_iter):
                for _ in maybe_iter:
                    pass
        else:
            # Decode the experiment from --args_data. SECURITY: unpickling can
            # execute arbitrary code; only pass trusted payloads.
            method_call = cloudpickle.loads(base64.b64decode(args.args_data))
            method_call(variant_data)
    finally:
        # Restore the global logger state even if the experiment raised.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
Exemplo n.º 21
0
def initialize_parallel_sampler(n_processes=-1):
    """Start rllab's parallel sampler.

    :param n_processes: number of sampler workers to request. The sentinel
        value -1 means "one per CPU reported by multiprocessing.cpu_count(),
        capped at 64"; any other value is passed through unchanged.
    """
    use_all_cpus = n_processes == -1
    worker_count = min(64, multiprocessing.cpu_count()) if use_all_cpus else n_processes
    parallel_sampler.initialize(n_parallel=worker_count)
    inputs=[
        observations_var, actions_var, d_rewards_var, importance_weights_var
    ],
    outputs=grad_imp,
)

# Result containers for the experiment; presumably filled once per outer
# iteration `k` below -- TODO confirm against the code that populates them.
alla = {}
variance_svrg_data = {}
variance_sgd_data = {}
importance_weights_data = {}
rewards_snapshot_data = {}
rewards_subiter_data = {}
n_sub_iter_data = {}
diff_lr_data = {}
alfa_t_data = {}
# Start rllab's parallel sampler with 10 workers.
parallel_sampler.initialize(10)
# Outer loop: 10 repetitions (k = 0..9).
for k in range(10):
    # Either reload both the snapshot and working policies from the saved
    # parameter file, or start the working policy from the snapshot policy's
    # current trainable parameters.
    if (load_policy):
        snap_policy.set_param_values(np.loadtxt('policy_swimmer.txt'),
                                     trainable=True)
        policy.set_param_values(np.loadtxt('policy_swimmer.txt'),
                                trainable=True)
    else:
        policy.set_param_values(snap_policy.get_param_values(trainable=True),
                                trainable=True)
    # Per-repetition accumulators, reset at the start of every k.
    avg_return = []
    #np.savetxt("policy_novar.txt",snap_policy.get_param_values(trainable=True))
    n_sub_iter = []
    rewards_sub_iter = []
    rewards_snapshot = []
    importance_weights = []
Exemplo n.º 23
0
def main():
    """Train a TRPO policy on the multi-agent WaterWorld using command-line flags.

    Parses the flags below and starts the parallel sampler (seeding it when
    ``--seed`` is given). It then builds an ``MAWaterWorld`` environment
    wrapped in ``StandardizedEnv`` / ``RLLabEnv`` / ``TfEnv``, plus
    ``ObservationBuffer`` when ``--buffer_size > 1``. Next it picks a policy
    and a baseline, points the logger at the experiment's log directory, and
    runs ``TRPO.train()``.

    Notes:
        * Recurrent policies (``--recurrent gru|lstm``) take a single hidden
          dimension; the first entry of ``--policy_hidden_sizes`` is used.
        * The logger's snapshot dir/mode, text/tabular outputs and prefix are
          restored when training finishes or raises.

    Raises:
        ValueError: ``--sensor_range`` does not have exactly 1 or
            ``--n_pursuers`` comma-separated values, or ``--recurrent`` is
            neither 'gru' nor 'lstm'.
        NotImplementedError: ``--baseline_type mlp`` was requested.
    """
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')
    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)

    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_name',
                        type=str,
                        default=default_exp_name,
                        help='Name of the experiment.')

    parser.add_argument('--discount', type=float, default=0.95)
    parser.add_argument('--gae_lambda', type=float, default=0.99)
    parser.add_argument('--reward_scale', type=float, default=1.0)
    parser.add_argument('--enable_obsnorm', action='store_true', default=False)
    parser.add_argument('--chunked', action='store_true', default=False)

    parser.add_argument('--n_iter', type=int, default=250)
    parser.add_argument('--sampler_workers', type=int, default=1)
    parser.add_argument('--max_traj_len', type=int, default=250)
    parser.add_argument('--update_curriculum',
                        action='store_true',
                        default=False)
    parser.add_argument('--anneal_step_size', type=int, default=0)

    parser.add_argument('--n_timesteps', type=int, default=8000)

    parser.add_argument('--control', type=str, default='centralized')
    parser.add_argument('--buffer_size', type=int, default=1)
    parser.add_argument('--radius', type=float, default=0.015)
    parser.add_argument('--n_evaders', type=int, default=10)
    parser.add_argument('--n_pursuers', type=int, default=8)
    parser.add_argument('--n_poison', type=int, default=10)
    parser.add_argument('--n_coop', type=int, default=4)
    parser.add_argument('--n_sensors', type=int, default=30)
    parser.add_argument('--sensor_range', type=str, default='0.2')
    parser.add_argument('--food_reward', type=float, default=5)
    parser.add_argument('--poison_reward', type=float, default=-1)
    parser.add_argument('--encounter_reward', type=float, default=0.05)
    parser.add_argument('--reward_mech', type=str, default='local')

    parser.add_argument('--recurrent', type=str, default=None)
    parser.add_argument('--baseline_type', type=str, default='linear')
    parser.add_argument('--policy_hidden_sizes', type=str, default='128,128')
    parser.add_argument('--baseline_hidden_sizes', type=str, default='128,128')

    parser.add_argument('--max_kl', type=float, default=0.01)

    parser.add_argument('--log_dir', type=str, required=False)
    parser.add_argument('--tabular_log_file',
                        type=str,
                        default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file',
                        type=str,
                        default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file',
                        type=str,
                        default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data',
                        type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--snapshot_mode',
                        type=str,
                        default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                        '(all iterations will be saved), "last" (only '
                        'the last iteration will be saved), or "none" '
                        '(do not save snapshots)')
    parser.add_argument(
        '--log_tabular_only',
        type=ast.literal_eval,
        default=False,
        help=
        'Whether to only print the tabular log information (in a horizontal format)'
    )

    args = parser.parse_args()

    parallel_sampler.initialize(n_parallel=args.sampler_workers)

    if args.seed is not None:
        set_seed(args.seed)
        parallel_sampler.set_seed(args.seed)

    # Kept on ``args`` (which is later passed to logger.log_parameters_lite).
    args.hidden_sizes = tuple(map(int, args.policy_hidden_sizes.split(',')))

    # ``map`` is lazy on Python 3: np.array(map(...)) would produce a 0-d
    # object array whose len() raises, so materialise the floats first.
    sensor_range = np.array(
        [float(r) for r in args.sensor_range.split(',')])
    if len(sensor_range) == 1:
        sensor_range = sensor_range[0]
    elif sensor_range.shape != (args.n_pursuers, ):
        raise ValueError(
            '--sensor_range must have 1 or n_pursuers (%d) values, got %d' %
            (args.n_pursuers, len(sensor_range)))

    env = MAWaterWorld(args.n_pursuers,
                       args.n_evaders,
                       args.n_coop,
                       args.n_poison,
                       radius=args.radius,
                       n_sensors=args.n_sensors,
                       food_reward=args.food_reward,
                       poison_reward=args.poison_reward,
                       encounter_reward=args.encounter_reward,
                       reward_mech=args.reward_mech,
                       sensor_range=sensor_range,
                       obstacle_loc=None)

    env = TfEnv(
        RLLabEnv(StandardizedEnv(env,
                                 scale_reward=args.reward_scale,
                                 enable_obsnorm=args.enable_obsnorm),
                 mode=args.control))

    if args.buffer_size > 1:
        env = ObservationBuffer(env, args.buffer_size)

    if args.recurrent:
        feature_network = MLP(
            name='feature_net',
            input_shape=(env.spec.observation_space.flat_dim +
                         env.spec.action_space.flat_dim, ),
            output_dim=16,
            hidden_sizes=(128, 64, 32),
            hidden_nonlinearity=tf.nn.tanh,
            output_nonlinearity=None)
        # Recurrent policies take a single hidden size. int() on the default
        # '128,128' used to crash, so take the first listed size instead.
        hidden_dim = args.hidden_sizes[0]
        if args.recurrent == 'gru':
            policy = GaussianGRUPolicy(env_spec=env.spec,
                                       feature_network=feature_network,
                                       hidden_dim=hidden_dim,
                                       name='policy')
        elif args.recurrent == 'lstm':
            policy = GaussianLSTMPolicy(env_spec=env.spec,
                                        feature_network=feature_network,
                                        hidden_dim=hidden_dim,
                                        name='policy')
        else:
            # Previously fell through and failed later with a NameError.
            raise ValueError('Unknown --recurrent type: %r' %
                             (args.recurrent, ))
    else:
        policy = GaussianMLPPolicy(name='policy',
                                   env_spec=env.spec,
                                   hidden_sizes=args.hidden_sizes,
                                   min_std=10e-5)

    if args.baseline_type == 'linear':
        baseline = LinearFeatureBaseline(env_spec=env.spec)
    elif args.baseline_type == 'mlp':
        raise NotImplementedError()
    else:
        baseline = ZeroBaseline(env_spec=env.spec)

    # logger
    default_log_dir = config.LOG_DIR
    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    logger.log_parameters_lite(params_log_file, args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    try:
        algo = TRPO(
            env=env,
            policy=policy,
            baseline=baseline,
            batch_size=args.n_timesteps,
            max_path_length=args.max_traj_len,
            update_max_path_length=args.update_curriculum,
            anneal_step_size=args.anneal_step_size,
            n_itr=args.n_iter,
            discount=args.discount,
            gae_lambda=args.gae_lambda,
            step_size=args.max_kl,
            optimizer=ConjugateGradientOptimizer(
                hvp_approach=FiniteDifferenceHvp(
                    base_eps=1e-5)) if args.recurrent else None,
            mode=args.control
            if not args.chunked else 'chunk_{}'.format(args.control),
        )

        algo.train()
    finally:
        # Undo the logger configuration above (saved state was previously
        # captured but never restored).
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
Exemplo n.º 24
0
def main():
    """Train a TRPO policy on the multi-agent WaterWorld using command-line flags.

    Parses the flags below and starts the parallel sampler (seeding it when
    ``--seed`` is given). It then builds an ``MAWaterWorld`` environment
    wrapped in ``StandardizedEnv`` / ``RLLabEnv`` / ``TfEnv``, plus
    ``ObservationBuffer`` when ``--buffer_size > 1``. Next it picks a policy
    and a baseline, points the logger at the experiment's log directory, and
    runs ``TRPO.train()``.

    Notes:
        * Recurrent policies (``--recurrent gru|lstm``) take a single hidden
          dimension; the first entry of ``--policy_hidden_sizes`` is used.
        * The logger's snapshot dir/mode, text/tabular outputs and prefix are
          restored when training finishes or raises.

    Raises:
        ValueError: ``--sensor_range`` does not have exactly 1 or
            ``--n_pursuers`` comma-separated values, or ``--recurrent`` is
            neither 'gru' nor 'lstm'.
        NotImplementedError: ``--baseline_type mlp`` was requested.
    """
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')
    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)

    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_name', type=str, default=default_exp_name,
                        help='Name of the experiment.')

    parser.add_argument('--discount', type=float, default=0.95)
    parser.add_argument('--gae_lambda', type=float, default=0.99)
    parser.add_argument('--reward_scale', type=float, default=1.0)
    parser.add_argument('--enable_obsnorm', action='store_true', default=False)
    parser.add_argument('--chunked', action='store_true', default=False)

    parser.add_argument('--n_iter', type=int, default=250)
    parser.add_argument('--sampler_workers', type=int, default=1)
    parser.add_argument('--max_traj_len', type=int, default=250)
    parser.add_argument('--update_curriculum', action='store_true', default=False)
    parser.add_argument('--anneal_step_size', type=int, default=0)

    parser.add_argument('--n_timesteps', type=int, default=8000)

    parser.add_argument('--control', type=str, default='centralized')
    parser.add_argument('--buffer_size', type=int, default=1)
    parser.add_argument('--radius', type=float, default=0.015)
    parser.add_argument('--n_evaders', type=int, default=10)
    parser.add_argument('--n_pursuers', type=int, default=8)
    parser.add_argument('--n_poison', type=int, default=10)
    parser.add_argument('--n_coop', type=int, default=4)
    parser.add_argument('--n_sensors', type=int, default=30)
    parser.add_argument('--sensor_range', type=str, default='0.2')
    parser.add_argument('--food_reward', type=float, default=5)
    parser.add_argument('--poison_reward', type=float, default=-1)
    parser.add_argument('--encounter_reward', type=float, default=0.05)
    parser.add_argument('--reward_mech', type=str, default='local')

    parser.add_argument('--recurrent', type=str, default=None)
    parser.add_argument('--baseline_type', type=str, default='linear')
    parser.add_argument('--policy_hidden_sizes', type=str, default='128,128')
    parser.add_argument('--baseline_hidden_sizes', type=str, default='128,128')

    parser.add_argument('--max_kl', type=float, default=0.01)

    parser.add_argument('--log_dir', type=str, required=False)
    parser.add_argument('--tabular_log_file', type=str, default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file', type=str, default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file', type=str, default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data', type=str, help='Pickled data for stub objects')
    parser.add_argument('--snapshot_mode', type=str, default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                        '(all iterations will be saved), "last" (only '
                        'the last iteration will be saved), or "none" '
                        '(do not save snapshots)')
    parser.add_argument(
        '--log_tabular_only', type=ast.literal_eval, default=False,
        help='Whether to only print the tabular log information (in a horizontal format)')

    args = parser.parse_args()

    parallel_sampler.initialize(n_parallel=args.sampler_workers)

    if args.seed is not None:
        set_seed(args.seed)
        parallel_sampler.set_seed(args.seed)

    # Kept on ``args`` (which is later passed to logger.log_parameters_lite).
    args.hidden_sizes = tuple(map(int, args.policy_hidden_sizes.split(',')))

    # ``map`` is lazy on Python 3: np.array(map(...)) would produce a 0-d
    # object array whose len() raises, so materialise the floats first.
    sensor_range = np.array([float(r) for r in args.sensor_range.split(',')])
    if len(sensor_range) == 1:
        sensor_range = sensor_range[0]
    elif sensor_range.shape != (args.n_pursuers,):
        raise ValueError('--sensor_range must have 1 or n_pursuers (%d) values, got %d'
                         % (args.n_pursuers, len(sensor_range)))

    env = MAWaterWorld(args.n_pursuers, args.n_evaders, args.n_coop, args.n_poison,
                       radius=args.radius, n_sensors=args.n_sensors, food_reward=args.food_reward,
                       poison_reward=args.poison_reward, encounter_reward=args.encounter_reward,
                       reward_mech=args.reward_mech, sensor_range=sensor_range, obstacle_loc=None)

    env = TfEnv(
        RLLabEnv(
            StandardizedEnv(env, scale_reward=args.reward_scale,
                            enable_obsnorm=args.enable_obsnorm), mode=args.control))

    if args.buffer_size > 1:
        env = ObservationBuffer(env, args.buffer_size)

    if args.recurrent:
        feature_network = MLP(
            name='feature_net',
            input_shape=(env.spec.observation_space.flat_dim + env.spec.action_space.flat_dim,),
            output_dim=16, hidden_sizes=(128, 64, 32), hidden_nonlinearity=tf.nn.tanh,
            output_nonlinearity=None)
        # Recurrent policies take a single hidden size. int() on the default
        # '128,128' used to crash, so take the first listed size instead.
        hidden_dim = args.hidden_sizes[0]
        if args.recurrent == 'gru':
            policy = GaussianGRUPolicy(env_spec=env.spec, feature_network=feature_network,
                                       hidden_dim=hidden_dim, name='policy')
        elif args.recurrent == 'lstm':
            policy = GaussianLSTMPolicy(env_spec=env.spec, feature_network=feature_network,
                                        hidden_dim=hidden_dim, name='policy')
        else:
            # Previously fell through and failed later with a NameError.
            raise ValueError('Unknown --recurrent type: %r' % (args.recurrent,))
    else:
        policy = GaussianMLPPolicy(
            name='policy', env_spec=env.spec,
            hidden_sizes=args.hidden_sizes, min_std=10e-5)

    if args.baseline_type == 'linear':
        baseline = LinearFeatureBaseline(env_spec=env.spec)
    elif args.baseline_type == 'mlp':
        raise NotImplementedError()
    else:
        baseline = ZeroBaseline(env_spec=env.spec)

    # logger
    default_log_dir = config.LOG_DIR
    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    logger.log_parameters_lite(params_log_file, args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    try:
        algo = TRPO(
            env=env,
            policy=policy,
            baseline=baseline,
            batch_size=args.n_timesteps,
            max_path_length=args.max_traj_len,
            update_max_path_length=args.update_curriculum,
            anneal_step_size=args.anneal_step_size,
            n_itr=args.n_iter,
            discount=args.discount,
            gae_lambda=args.gae_lambda,
            step_size=args.max_kl,
            optimizer=ConjugateGradientOptimizer(hvp_approach=FiniteDifferenceHvp(base_eps=1e-5)) if
            args.recurrent else None,
            mode=args.control if not args.chunked else 'chunk_{}'.format(args.control),)

        algo.train()
    finally:
        # Undo the logger configuration above (saved state was previously
        # captured but never restored).
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
Exemplo n.º 25
0
def run_experiment_here(
        experiment_function,
        exp_prefix="default",
        variant=None,
        exp_id=0,
        seed=0,
        use_gpu=True,
        snapshot_mode='last',
        snapshot_gap=1,
        code_diff=None,
        commit_hash=None,
        script_name=None,
        n_parallel=0,
        base_log_dir=None,
        log_dir=None,
        exp_name=None,
):
    """
    Run an experiment locally without any serialization.

    :param experiment_function: Function. `variant` will be passed in as its
    only argument.
    :param exp_prefix: Experiment prefix for the save file.
    :param variant: Dictionary passed in to `experiment_function`. It is
    modified in place: 'exp_id' is always set, and 'seed' is set when a
    random seed is drawn.
    :param exp_id: Experiment ID. Should be unique across all
    experiments. Note that one experiment may correspond to multiple seeds.
    :param seed: Seed used for this experiment. If None, the seed stored in
    `variant['seed']` is used; if that is also absent, a random seed in
    [0, 100000] is drawn and recorded in `variant['seed']` as a string.
    :param use_gpu: Run with GPU. Defaults to True.
    :param snapshot_mode: Snapshot mode passed to `setup_logger`.
    :param snapshot_gap: Snapshot gap passed to `setup_logger`.
    :param code_diff: If not None, written to `code.diff` in the log directory.
    :param commit_hash: If not None, written to `commit_hash.txt` in the log
    directory.
    :param script_name: Name of the running script. If not None, written to
    `script_name.txt` in the log directory.
    :param n_parallel: If > 0, start the rllab parallel sampler with this many
    workers and seed it with `seed`.
    :param base_log_dir: Passed to `setup_logger`.
    :param log_dir: If set, set the log directory to this. Otherwise,
    the directory will be auto-generated based on the exp_prefix.
    :param exp_name: Passed to `setup_logger`.
    :return: Whatever `experiment_function(variant)` returns.
    """
    if variant is None:
        variant = {}
    if seed is None:
        if 'seed' in variant:
            # Previously the seed stayed None here and None was passed on to
            # set_seed / parallel_sampler.set_seed.
            seed = int(variant['seed'])
        else:
            seed = random.randint(0, 100000)
            variant['seed'] = str(seed)
    if n_parallel > 0:
        from rllab.sampler import parallel_sampler
        parallel_sampler.initialize(n_parallel=n_parallel)
        parallel_sampler.set_seed(seed)
    variant['exp_id'] = str(exp_id)
    reset_execution_environment()
    set_seed(seed)
    setup_logger(
        exp_prefix=exp_prefix,
        variant=variant,
        exp_id=exp_id,
        seed=seed,
        snapshot_mode=snapshot_mode,
        snapshot_gap=snapshot_gap,
        base_log_dir=base_log_dir,
        log_dir=log_dir,
        exp_name=exp_name,
    )
    log_dir = logger.get_snapshot_dir()
    # Record optional provenance information next to the logs.
    for contents, filename in (
            (code_diff, "code.diff"),
            (commit_hash, "commit_hash.txt"),
            (script_name, "script_name.txt"),
    ):
        if contents is not None:
            with open(osp.join(log_dir, filename), "w") as f:
                f.write(contents)
    set_gpu_mode(use_gpu)

    print('variant', variant)
    return experiment_function(variant)
Exemplo n.º 26
0
def run_experiment(argv):
    """Run a stubbed experiment that was pickled onto the command line.

    ``argv`` is parsed like ``sys.argv``: ``argv[0]`` is skipped. The
    experiment comes from ``--args_data`` as a base64-encoded pickle. It is
    ``concretize``-d, and if the result is iterable it is driven to
    exhaustion. The logger's text/tabular outputs, snapshot dir/mode and
    prefix are configured for the run and restored afterwards, even if the
    experiment raises.

    :param argv: Command-line argument list; ``argv[1:]`` is parsed.
    """
    default_log_dir = config.LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')

    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
    parser = argparse.ArgumentParser()
    parser.add_argument('--n_parallel',
                        type=int,
                        default=1,
                        help='Number of parallel workers to perform rollouts.')
    parser.add_argument('--exp_name',
                        type=str,
                        default=default_exp_name,
                        help='Name of the experiment.')
    parser.add_argument('--log_dir',
                        type=str,
                        default=default_log_dir,
                        help='Path to save the log and iteration snapshot.')
    parser.add_argument('--snapshot_mode',
                        type=str,
                        default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                        '(all iterations will be saved), "last" (only '
                        'the last iteration will be saved), or "none" '
                        '(do not save snapshots)')
    parser.add_argument('--tabular_log_file',
                        type=str,
                        default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file',
                        type=str,
                        default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file',
                        type=str,
                        default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--plot',
                        type=ast.literal_eval,
                        default=False,
                        help='Whether to plot the iteration results')
    parser.add_argument(
        '--log_tabular_only',
        type=ast.literal_eval,
        default=False,
        help=
        'Whether to only print the tabular log information (in a horizontal format)'
    )
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data',
                        type=str,
                        help='Pickled data for stub objects')

    args = parser.parse_args(argv[1:])

    # Fail fast, before starting sampler workers; otherwise b64decode(None)
    # raises an unhelpful TypeError further down.
    if args.args_data is None:
        parser.error('--args_data is required')

    from rllab.sampler import parallel_sampler
    parallel_sampler.initialize(n_parallel=args.n_parallel)

    if args.seed is not None:
        set_seed(args.seed)
        parallel_sampler.set_seed(args.seed)

    if args.plot:
        from rllab.plotter import plotter
        plotter.init_worker()

    # The experiment is decoded from the --args_data argument.
    # SECURITY: pickle.loads can execute arbitrary code; --args_data must come
    # from a trusted launcher, never from user-controlled input.
    data = pickle.loads(base64.b64decode(args.args_data))

    log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    logger.log_parameters_lite(params_log_file, args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    try:
        maybe_iter = concretize(data)
        if is_iterable(maybe_iter):
            # The experiment may be a generator; exhaust it to run it fully.
            for _ in maybe_iter:
                pass
    finally:
        # Restore the logger state even if the experiment raised.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
Exemplo n.º 27
0
def train(num_experiments, thread_id, queue):

    ############ DEFAULT PARAMETERS ############

    env_name = None  #Name of adversarial environment
    path_length = 1000  #Maximum episode length
    layer_size = tuple([100, 100, 100])  #Layer definition
    ifRender = False  #Should we render?
    afterRender = 100  #After how many to animate
    n_exps = 1  #Number of training instances to run
    n_itr = 25  #Number of iterations of the alternating optimization
    n_pro_itr = 1  #Number of iterations for the protaginist
    n_adv_itr = 1  #Number of interations for the adversary
    batch_size = 4000  #Number of training samples for each iteration
    ifSave = True  #Should we save?
    save_every = 100  #Save checkpoint every save_every iterations
    n_process = 1  #Number of parallel threads for sampling environment
    adv_fraction = 0.25  #Fraction of maximum adversarial force to be applied
    step_size = 0.01  #kl step size for TRPO
    gae_lambda = 0.97  #gae_lambda for learner
    save_dir = './results'  #folder to save result in

    ############ ENV SPECIFIC PARAMETERS ############

    env_name = 'HopperAdv-v1'

    layer_size = tuple([64, 64])
    step_size = 0.01
    gae_lambda = 1.0
    batch_size = 25000

    n_exps = num_experiments
    n_itr = 500
    ifSave = False
    n_process = 4

    adv_fraction = 3.0

    save_dir = './../results/StaticHopper'

    args = [
        env_name, path_length, layer_size, ifRender, afterRender, n_exps,
        n_itr, n_pro_itr, n_adv_itr, batch_size, save_every, n_process,
        adv_fraction, step_size, gae_lambda, save_dir
    ]

    ############ ADVERSARIAL POLICY LOAD ############

    filepath = './../initial_results/Hopper/env-HopperAdv-v1_Exp1_Itr500_BS25000_Adv0.25_stp0.01_lam1.0_369983.p'
    res_D = pickle.load(open(filepath, 'rb'))
    pretrained_adv_policy = res_D['adv_policy']

    ############ MAIN LOOP ############

    ## Initializing summaries for the tests ##
    const_test_rew_summary = []
    rand_test_rew_summary = []
    step_test_rew_summary = []
    rand_step_test_rew_summary = []
    adv_test_rew_summary = []

    ## Preparing file to save results in ##
    save_prefix = 'static_env-{}_Exp{}_Itr{}_BS{}_Adv{}_stp{}_lam{}_{}'.format(
        env_name, n_exps, n_itr, batch_size, adv_fraction, step_size,
        gae_lambda, random.randint(0, 1000000))
    save_name = save_dir + '/' + save_prefix

    ## Looping over experiments to carry out ##
    for ne in range(n_exps):
        ## Environment definition ##
        ## The second argument in GymEnv defines the relative magnitude of adversary. For testing we set this to 1.0.
        env = normalize(GymEnv(env_name, adv_fraction))
        env_orig = normalize(GymEnv(env_name, 1.0))

        ## Protagonist policy definition ##
        pro_policy = GaussianMLPPolicy(env_spec=env.spec,
                                       hidden_sizes=layer_size,
                                       is_protagonist=True)
        pro_baseline = LinearFeatureBaseline(env_spec=env.spec)

        ## Zero Adversary for the protagonist training ##
        zero_adv_policy = ConstantControlPolicy(env_spec=env.spec,
                                                is_protagonist=False,
                                                constant_val=0.0)

        ## Adversary policy definition ##
        adv_policy = pretrained_adv_policy
        adv_baseline = LinearFeatureBaseline(env_spec=env.spec)

        ## Initializing the parallel sampler ##
        parallel_sampler.initialize(n_process)

        ## Optimizer for the Protagonist ##
        pro_algo = TRPO(env=env,
                        pro_policy=pro_policy,
                        adv_policy=adv_policy,
                        pro_baseline=pro_baseline,
                        adv_baseline=adv_baseline,
                        batch_size=batch_size,
                        max_path_length=path_length,
                        n_itr=n_pro_itr,
                        discount=0.995,
                        gae_lambda=gae_lambda,
                        step_size=step_size,
                        is_protagonist=True)

        ## Setting up summaries for testing for a specific training instance ##
        pro_rews = []
        adv_rews = []
        all_rews = []
        const_testing_rews = []
        const_testing_rews.append(
            test_const_adv(env_orig, pro_policy, path_length=path_length))
        rand_testing_rews = []
        rand_testing_rews.append(
            test_rand_adv(env_orig, pro_policy, path_length=path_length))
        step_testing_rews = []
        step_testing_rews.append(
            test_step_adv(env_orig, pro_policy, path_length=path_length))
        rand_step_testing_rews = []
        rand_step_testing_rews.append(
            test_rand_step_adv(env_orig, pro_policy, path_length=path_length))
        adv_testing_rews = []
        adv_testing_rews.append(
            test_learnt_adv(env,
                            pro_policy,
                            adv_policy,
                            path_length=path_length))

        ## Beginning alternating optimization ##
        for ni in range(n_itr):
            logger.log('\n\nThread: {} Experiment: {} Iteration: {}\n'.format(
                thread_id,
                ne,
                ni,
            ))

            ## Train Protagonist
            pro_algo.train()
            pro_rews += pro_algo.rews
            all_rews += pro_algo.rews
            logger.log('Protag Reward: {}'.format(
                np.array(pro_algo.rews).mean()))

            ## Test the learnt policies
            const_testing_rews.append(
                test_const_adv(env, pro_policy, path_length=path_length))
            rand_testing_rews.append(
                test_rand_adv(env, pro_policy, path_length=path_length))
            step_testing_rews.append(
                test_step_adv(env, pro_policy, path_length=path_length))
            rand_step_testing_rews.append(
                test_rand_step_adv(env, pro_policy, path_length=path_length))
            adv_testing_rews.append(
                test_learnt_adv(env,
                                pro_policy,
                                adv_policy,
                                path_length=path_length))

            if ni % afterRender == 0 and ifRender == True:
                test_const_adv(env,
                               pro_policy,
                               path_length=path_length,
                               n_traj=1,
                               render=True)

            if ni != 0 and ni % save_every == 0 and ifSave == True:
                ## SAVING CHECKPOINT INFO ##
                pickle.dump(
                    {
                        'args': args,
                        'pro_policy': pro_policy,
                        'adv_policy': adv_policy,
                        'zero_test': [const_testing_rews],
                        'rand_test': [rand_testing_rews],
                        'step_test': [step_testing_rews],
                        'rand_step_test': [rand_step_testing_rews],
                        'iter_save': ni,
                        'exp_save': ne,
                        'adv_test': [adv_testing_rews]
                    }, open(save_name + '_' + str(ni) + '.p', 'wb'))

        ## Shutting down the optimizer ##
        pro_algo.shutdown_worker()

        ## Updating the test summaries over all training instances
        const_test_rew_summary.append(const_testing_rews)
        rand_test_rew_summary.append(rand_testing_rews)
        step_test_rew_summary.append(step_testing_rews)
        rand_step_test_rew_summary.append(rand_step_testing_rews)
        adv_test_rew_summary.append(adv_testing_rews)

    queue.put([
        const_test_rew_summary, rand_test_rew_summary, step_test_rew_summary,
        rand_step_test_rew_summary, adv_test_rew_summary
    ])

    ############ SAVING MODEL ############
    '''
Exemplo n.º 28
0
def _build_experiment_parser(default_exp_name):
    """Build the command-line parser used by :func:`run_experiment`.

    :param default_exp_name: value used for ``--exp_name`` when the flag is
        not given on the command line.
    :return: a configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--n_parallel',
        type=int,
        default=1,
        help=
        'Number of parallel workers to perform rollouts. 0 => don\'t start any workers'
    )
    parser.add_argument('--exp_name',
                        type=str,
                        default=default_exp_name,
                        help='Name of the experiment.')
    parser.add_argument('--log_dir',
                        type=str,
                        default=None,
                        help='Path to save the log and iteration snapshot.')
    parser.add_argument('--snapshot_mode',
                        type=str,
                        default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                        '(all iterations will be saved), "last" (only '
                        'the last iteration will be saved), "gap" (every '
                        '`snapshot_gap` iterations are saved), or "none" '
                        '(do not save snapshots)')
    parser.add_argument('--snapshot_gap',
                        type=int,
                        default=1,
                        help='Gap between snapshot iterations.')
    parser.add_argument('--tabular_log_file',
                        type=str,
                        default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file',
                        type=str,
                        default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file',
                        type=str,
                        default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--variant_log_file',
                        type=str,
                        default='variant.json',
                        help='Name of the variant log file (in json).')
    parser.add_argument(
        '--resume_from',
        type=str,
        default=None,
        help='Name of the pickle file to resume experiment from.')
    parser.add_argument('--plot',
                        type=ast.literal_eval,
                        default=False,
                        help='Whether to plot the iteration results')
    parser.add_argument(
        '--log_tabular_only',
        type=ast.literal_eval,
        default=False,
        help=
        'Whether to only print the tabular log information (in a horizontal format)'
    )
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data',
                        type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--variant_data',
                        type=str,
                        help='Pickled data for variant configuration')
    parser.add_argument('--use_cloudpickle',
                        type=ast.literal_eval,
                        default=False)
    return parser


def _resume_experiment(resume_from):
    """Reload a joblib snapshot and continue training its algorithm.

    :param resume_from: string of the form ``"<file>&|&<n_itr>"`` or
        ``"<file>&|&<n_itr>&|&<batch_size>"``. ``<file>`` is loaded with
        ``joblib.load``. ``<n_itr>`` replaces ``algo.n_itr``. The optional
        ``<batch_size>`` replaces ``algo.batch_size``.
    :raises KeyError: if the snapshot lacks an ``'algo'``, ``'policy'`` or
        ``'baseline'`` entry.
    :raises ValueError: if the iteration count or batch size is not an int.
    """
    vals = resume_from.split('&|&')  # file | num iters to go | new batch size
    dirRes = vals[0]
    numItrs = int(vals[1])
    # Parse the optional batch size before the (slow) load so bad input
    # fails fast.
    new_batch_size = int(vals[2]) if len(vals) > 2 else None
    print("resuming from :{}".format(dirRes))
    data = joblib.load(dirRes)
    # data is dict : 'baseline', 'algo', 'itr', 'policy', 'env'
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    for key in ('algo', 'policy', 'baseline'):
        if key not in data:
            raise KeyError('snapshot %s has no %r entry' % (dirRes, key))
    algo = data['algo']
    oldBatchSize = algo.batch_size
    algo.n_itr = numItrs
    if new_batch_size is not None:
        algo.batch_size = new_batch_size
        print(
            'algo iters : {} cur iter :{} oldBatchSize : {} newBatchSize : {}'
            .format(algo.n_itr, algo.current_itr, oldBatchSize,
                    algo.batch_size))
    else:
        print('algo iters : {} cur iter :{} '.format(
            algo.n_itr, algo.current_itr))
    algo.train()


def run_experiment(argv):
    """Parse ``argv`` and run (or resume) an rllab experiment.

    The function does the following, in order:

    * Seeds the RNG when ``--seed`` is given.
    * Starts ``--n_parallel`` sampler workers when ``--n_parallel`` > 0.
    * Starts the plot worker when ``--plot`` is set.
    * Points the logger at the experiment's log directory.
    * Runs the experiment. If ``--resume_from`` contains ``&|&`` it resumes
      from that snapshot. Otherwise it executes the base64-pickled call in
      ``--args_data`` (via cloudpickle when ``--use_cloudpickle`` is set).

    The logger's previous snapshot dir/mode, the added text/tabular outputs
    and the pushed prefix are always restored, even when the run raises.

    :param argv: full argument vector including the program name
        (only ``argv[1:]`` is parsed).

    NOTE(review): ``--args_data``/``--variant_data`` are unpickled. Only run
    this with trusted input.
    """
    default_log_dir = config.LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')

    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
    parser = _build_experiment_parser(default_exp_name)
    args = parser.parse_args(argv[1:])

    # Resume only when --resume_from carries the '&|&'-separated spec.
    resuming = (args.resume_from is not None) and ('&|&' in args.resume_from)
    if not resuming and args.args_data is None:
        # Previously this surfaced later as an opaque TypeError from b64decode.
        parser.error('--args_data is required unless --resume_from is given '
                     'as "<file>&|&<n_itr>[&|&<batch_size>]"')

    if args.seed is not None:
        set_seed(args.seed)

    if args.n_parallel > 0:
        from rllab.sampler import parallel_sampler
        parallel_sampler.initialize(n_parallel=args.n_parallel)
        if args.seed is not None:
            parallel_sampler.set_seed(args.seed)

    if args.plot:
        from rllab.plotter import plotter
        plotter.init_worker()

    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    if args.variant_data is not None:
        variant_data = pickle.loads(base64.b64decode(args.variant_data))
        variant_log_file = osp.join(log_dir, args.variant_log_file)
        logger.log_variant(variant_log_file, variant_data)
    else:
        variant_data = None

    if not args.use_cloudpickle:
        logger.log_parameters_lite(params_log_file, args)

    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_snapshot_gap(args.snapshot_gap)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)
    try:
        # variant_data is the variant dictionary sent from trpoTests_ExpLite
        if resuming:
            _resume_experiment(args.resume_from)
        else:
            print('Not resuming - building new exp')
            # read from stdin
            if args.use_cloudpickle:  # set to use cloudpickle
                import cloudpickle
                method_call = cloudpickle.loads(
                    base64.b64decode(args.args_data))
                method_call(variant_data)
            else:
                print('not use cloud pickle')
                data = pickle.loads(base64.b64decode(args.args_data))
                maybe_iter = concretize(data)
                if is_iterable(maybe_iter):
                    # Drain the generator so the stubbed experiment actually runs.
                    for _ in maybe_iter:
                        pass
    finally:
        # Restore the global logger even if the experiment raised.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
Exemplo n.º 29
0
        env=env,
        policy=policy,
        baseline=baseline,
        batch_size=batch_size,
        max_path_length=pathlength,
        n_itr=5,
        discount=0.995,
        step_size=0.01,
        gae_lambda=0.97,

        #task_num=task_size,
    )
    algo.init_opt()

    from rllab.sampler import parallel_sampler
    parallel_sampler.initialize(n_parallel=num_parallel)
    parallel_sampler.set_seed(0)

    algo.start_worker()

    for i in range(initialize_epochs):
        print('------ Iter ', i, ' in Init Training ', diretory, '--------')
        paths = algo.sampler.obtain_samples(0)
        samples_data = algo.sampler.process_samples(0, paths)
        opt_data = algo.optimize_policy(0, samples_data)
        pol_aft = (policy.get_param_values())
        print(algo.mean_kl(samples_data))
        print(dict(logger._tabular)['AverageReturn'])

    data_perc_list = [0.999, 0.7, 0.5, 0.3, 0.1, 0.05, 0.01]
Exemplo n.º 30
0
from rllab.policies.gaussian_mlp_policy import GaussianMLPPolicy
from rllab.optimizers.conjugate_gradient_optimizer import ConjugateGradientOptimizer
from examples.discriminator_bullet import Mlp_Discriminator
import pickle
from subprocess import Popen
import time
import os
import signal
import sys, traceback

# auto save config
# Tag naming this experiment configuration (layer sizes / reward setup).
# Presumably used further down the script to name saved outputs — not
# visible in this snippet.
experiment_spec = "100X50X25_22D_DiscriminateReward_GAE"
# Policy save interval; the name suggests "every 50 iterations" — usage is
# not visible here, confirm against the rest of the script.
save_policy_every = 50

from rllab.sampler import parallel_sampler
# Start the rllab sampler with a single worker.
parallel_sampler.initialize(n_parallel=1)

# Launch the external simulator binary, then wait 3 seconds, presumably so it
# is ready before the environment connects to it.
# NOTE(review): the child process is never terminated in the visible code.
# `signal` is imported above, so cleanup may happen later — confirm, otherwise
# the simulator outlives a crashed run.
simulator = Popen(["./HumanDemoNoGUI"])
time.sleep(3)

# try:
# MLP discriminator from examples.discriminator_bullet. The meaning of each
# keyword (a_max/a_min, decent_portion, disc_window, ...) is defined in that
# module, not here.
discriminator = Mlp_Discriminator(a_max=0.8,
                                  a_min=0.5,
                                  decent_portion=0.8,
                                  disc_window=2,
                                  iteration=3000,
                                  disc_joints_dim=16,
                                  hidden_sizes=(128, 64, 32))

# baseline
#env = normalize(HumanEnv_v2(discriminator=None), normalize_obs=True)
Exemplo n.º 31
0
def main():
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')
    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--exp_name', type=str, default=default_exp_name, help='Name of the experiment.')

    parser.add_argument('--discount', type=float, default=0.99)
    parser.add_argument('--gae_lambda', type=float, default=1.0)
    parser.add_argument('--reward_scale', type=float, default=1.0)

    parser.add_argument('--n_iter', type=int, default=250)
    parser.add_argument('--sampler_workers', type=int, default=1)
    parser.add_argument('--max_traj_len', type=int, default=250)
    parser.add_argument('--update_curriculum', action='store_true', default=False)
    parser.add_argument('--n_timesteps', type=int, default=8000)
    parser.add_argument('--control', type=str, default='centralized')

    parser.add_argument('--rectangle', type=str, default='10,10')
    parser.add_argument('--map_type', type=str, default='rectangle')
    parser.add_argument('--n_evaders', type=int, default=5)
    parser.add_argument('--n_pursuers', type=int, default=2)
    parser.add_argument('--obs_range', type=int, default=3)
    parser.add_argument('--n_catch', type=int, default=2)
    parser.add_argument('--urgency', type=float, default=0.0)
    parser.add_argument('--pursuit', dest='train_pursuit', action='store_true')
    parser.add_argument('--evade', dest='train_pursuit', action='store_false')
    parser.set_defaults(train_pursuit=True)
    parser.add_argument('--surround', action='store_true', default=False)
    parser.add_argument('--constraint_window', type=float, default=1.0)
    parser.add_argument('--sample_maps', action='store_true', default=False)
    parser.add_argument('--map_file', type=str, default='../maps/map_pool.npy')
    parser.add_argument('--flatten', action='store_true', default=False)
    parser.add_argument('--reward_mech', type=str, default='global')
    parser.add_argument('--catchr', type=float, default=0.1)
    parser.add_argument('--term_pursuit', type=float, default=5.0)

    parser.add_argument('--recurrent', type=str, default=None)
    parser.add_argument('--policy_hidden_sizes', type=str, default='128,128')
    parser.add_argument('--baselin_hidden_sizes', type=str, default='128,128')
    parser.add_argument('--baseline_type', type=str, default='linear')

    parser.add_argument('--conv', action='store_true', default=False)

    parser.add_argument('--max_kl', type=float, default=0.01)

    parser.add_argument('--log_dir', type=str, required=False)
    parser.add_argument('--tabular_log_file', type=str, default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file', type=str, default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file', type=str, default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--seed', type=int,
                        help='Random seed for numpy')
    parser.add_argument('--args_data', type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--snapshot_mode', type=str, default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                             '(all iterations will be saved), "last" (only '
                             'the last iteration will be saved), or "none" '
                             '(do not save snapshots)')
    parser.add_argument('--log_tabular_only', type=ast.literal_eval, default=False,
                        help='Whether to only print the tabular log information (in a horizontal format)')


    args = parser.parse_args()

    parallel_sampler.initialize(n_parallel=args.sampler_workers)

    if args.seed is not None:
        set_seed(args.seed)
        parallel_sampler.set_seed(args.seed)

    args.hidden_sizes = tuple(map(int, args.policy_hidden_sizes.split(',')))

    if args.sample_maps:
        map_pool = np.load(args.map_file)
    else:
        if args.map_type == 'rectangle':
            env_map = TwoDMaps.rectangle_map(*map(int, args.rectangle.split(',')))
        elif args.map_type == 'complex':
            env_map = TwoDMaps.complex_map(*map(int, args.rectangle.split(',')))
        else:
            raise NotImplementedError()
        map_pool = [env_map]

    env = PursuitEvade(map_pool, n_evaders=args.n_evaders, n_pursuers=args.n_pursuers,
                       obs_range=args.obs_range, n_catch=args.n_catch,
                       train_pursuit=args.train_pursuit, urgency_reward=args.urgency,
                       surround=args.surround, sample_maps=args.sample_maps,
                       constraint_window=args.constraint_window,
                       flatten=args.flatten,
                       reward_mech=args.reward_mech,
                       catchr=args.catchr,
                       term_pursuit=args.term_pursuit)

    env = RLLabEnv(
            StandardizedEnv(env, scale_reward=args.reward_scale, enable_obsnorm=False),
            mode=args.control)

    if args.recurrent:
        if args.conv:
            feature_network = ConvNetwork(
                input_shape=emv.spec.observation_space.shape,
                output_dim=5, 
                conv_filters=(8,16,16),
                conv_filter_sizes=(3,3,3),
                conv_strides=(1,1,1),
                conv_pads=('VALID','VALID','VALID'),
                hidden_sizes=(64,), 
                hidden_nonlinearity=NL.rectify,
                output_nonlinearity=NL.softmax)
        else:
            feature_network = MLP(
                input_shape=(env.spec.observation_space.flat_dim + env.spec.action_space.flat_dim,),
                output_dim=5, hidden_sizes=(128,128,128), hidden_nonlinearity=NL.tanh,
                output_nonlinearity=None)
        if args.recurrent == 'gru':
            policy = CategoricalGRUPolicy(env_spec=env.spec, feature_network=feature_network,
                                       hidden_dim=int(args.policy_hidden_sizes))
    elif args.conv:
        feature_network = ConvNetwork(
            input_shape=env.spec.observation_space.shape,
            output_dim=5, 
            conv_filters=(8,16,16),
            conv_filter_sizes=(3,3,3),
            conv_strides=(1,1,1),
            conv_pads=('valid','valid','valid'),
            hidden_sizes=(64,), 
            hidden_nonlinearity=NL.rectify,
            output_nonlinearity=NL.softmax)
        policy = CategoricalMLPPolicy(env_spec=env.spec, prob_network=feature_network)
    else:
        policy = CategoricalMLPPolicy(env_spec=env.spec, hidden_sizes=args.hidden_sizes)

    if args.baseline_type == 'linear':
        baseline = LinearFeatureBaseline(env_spec=env.spec)
    else:
        baseline = ZeroBaseline(obsfeat_space)

    # logger
    default_log_dir = config.LOG_DIR
    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    logger.log_parameters_lite(params_log_file, args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    algo = TRPO(
        env=env,
        policy=policy,
        baseline=baseline,
        batch_size=args.n_timesteps,
        max_path_length=args.max_traj_len,
        n_itr=args.n_iter,
        discount=args.discount,
        gae_lambda=args.gae_lambda,
        step_size=args.max_kl,
        mode=args.control,)

    algo.train()