예제 #1
0
def rllab_logdir(algo=None, dirname=None):
    """Yield the rllab snapshot directory with a ``progress.csv`` output attached.

    This is a one-shot generator: everything before the ``yield`` is set-up,
    everything after is tear-down.  It is presumably wrapped by
    ``contextlib.contextmanager`` outside this view -- TODO confirm.

    Args:
        algo: Optional algorithm object.  When truthy, its hyper-parameters
            (as returned by ``extract_hyperparams``) are written to
            ``params.json`` inside the snapshot directory.
        dirname: Optional directory.  When truthy it becomes the logger's
            snapshot directory; otherwise the logger's current one is used.

    Yields:
        The snapshot directory reported by the logger.
    """
    if dirname:
        rllablogger.set_snapshot_dir(dirname)
    dirname = rllablogger.get_snapshot_dir()

    progress_csv = os.path.join(dirname, 'progress.csv')
    rllablogger.add_tabular_output(progress_csv)
    try:
        if algo:
            with open(os.path.join(dirname, 'params.json'), 'w') as f:
                json.dump(extract_hyperparams(algo), f)
        yield dirname
    finally:
        # Detach the tabular output even when the consumer raises, so a failed
        # run does not leave the logger writing to a stale progress.csv.
        rllablogger.remove_tabular_output(progress_csv)
예제 #2
0
    def test(self, dir):  # Rui: just run and render
        """Sample for 100 iterations and dump baseline values, distances and positions.

        Each iteration collects paths from the sampler, gathers the
        ``distance`` and ``actual_pos`` entries of every path's
        ``last_env_info``, evaluates the baseline on every path, and pickles
        the three concatenated arrays.  The files are rewritten on every
        iteration, so only the last iteration's data remains on disk.  No
        optimisation step is performed.

        Args:
            dir: Logger snapshot directory and prefix of the output files.  It
                is concatenated directly with the file names, so it must end
                with a path separator for the files to land inside it.  (The
                name shadows the builtin but is kept for existing callers.)
        """
        self.start_worker()
        try:
            logger.set_snapshot_dir(dir)
            self.anneal_step_num(self.current_itr)

            for itr in range(100):  # run 100 iterations in total!
                self.n_parallel = 1
                paths = self.sampler.obtain_samples(itr)

                # Per-episode records taken from the final env info of each path.
                last_infos = [path['env_infos']['last_env_info'] for path in paths]
                distance_record = np.concatenate(
                    [info['distance'] for info in last_infos])
                pos = np.concatenate(
                    [info['actual_pos'] for info in last_infos])

                # Baseline (value) predictions for every path.
                high_baselines = np.concatenate(
                    [self.baseline.predict(path) for path in paths])

                with open(dir + "high_values.pkl", 'wb') as f:
                    pickle.dump(high_baselines, f)
                with open(dir + "distance.pkl", "wb") as f:
                    pickle.dump(distance_record, f)
                # for v plot: record x and y, and baseline value
                with open(dir + "xyposition.pkl", "wb") as f:
                    pickle.dump(pos, f)

                self.env.log_diagnostics(paths)
                self.policy.log_diagnostics(paths)
                self.baseline.log_diagnostics(paths)
                logger.dump_tabular(with_prefix=False)
        finally:
            # Shut the workers down exactly once, after all iterations (or on
            # error).  The original called this inside the loop, tearing the
            # workers down after the first iteration.
            self.shutdown_worker()
예제 #3
0
def main(args):
    """Train TRPO on a Gym environment and pickle the resulting policy.

    Args:
        args: Namespace-like object providing ``snapshot_dir``, ``env_id``,
            ``input_policy``, ``hidden_sizes``, ``batch_size``, ``n_itr``,
            ``discount``, ``step_size``, ``gae_lambda`` and ``output_policy``.
    """
    snapshot_dir = args.snapshot_dir
    logger.set_snapshot_dir(snapshot_dir)
    logger.set_snapshot_mode("none")
    logger.add_tabular_output(os.path.join(snapshot_dir, "tabular.csv"))

    gym_env = GymEnv(args.env_id)

    # Start from a fresh categorical MLP policy unless the caller supplied a
    # pickled one to continue from.
    if args.input_policy is None:
        agent = CategoricalMLPPolicy(
            env_spec=gym_env.spec,
            hidden_sizes=args.hidden_sizes,
        )
    else:
        with open(args.input_policy, "rb") as policy_file:
            agent = pickle.load(policy_file)

    trainer = TRPO(
        env=gym_env,
        policy=agent,
        baseline=LinearFeatureBaseline(env_spec=gym_env.spec),
        batch_size=args.batch_size,
        max_path_length=gym_env.horizon,
        n_itr=args.n_itr,
        discount=args.discount,
        step_size=args.step_size,
        gae_lambda=args.gae_lambda,
    )
    trainer.train()

    # Persist the trained policy object.
    with open(args.output_policy, "wb") as policy_file:
        pickle.dump(agent, policy_file)
예제 #4
0
def experiment(variant):
    """Build a Sawyer Push / PickPlace MAML-TRPO run from ``variant`` and train it.

    Args:
        variant: Mapping with the keys ``seed``, ``flr``, ``fbs``, ``mlr``,
            ``regionSize`` (``'20X20'`` or ``'60X30'``), ``envType``
            (``'Push'`` or ``'PickPlace'``), ``init_param_file`` and
            ``saveDir``.

    Raises:
        AssertionError: If ``regionSize`` or ``envType`` has an unsupported value.
        KeyError: If a required key is missing from ``variant``.
    """
    import os

    seed = variant['seed']

    tf.set_random_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # NOTE(review): 'flr' is read but never forwarded to MAMLTRPO in this
    # function -- confirm whether that is intentional.
    fast_learning_rate = variant['flr']

    fast_batch_size = variant[
        'fbs']  # 10 works for [0.1, 0.2], 20 doesn't improve much for [0,0.2]
    meta_batch_size = 20  # 10 also works, but much less stable, 20 is fairly stable, 40 is more stable
    max_path_length = 150
    num_grad_updates = 1
    meta_step_size = variant['mlr']

    region_size = variant['regionSize']
    if region_size == '20X20':
        tasks_file = '/root/code/multiworld/multiworld/envs/goals/pickPlace_20X20_6_8.pkl'
    else:
        assert region_size == '60X30', 'unsupported regionSize: %r' % (region_size,)
        tasks_file = '/root/code/multiworld/multiworld/envs/goals/pickPlace_60X30.pkl'

    # Close the task file deterministically instead of leaking the handle.
    with open(tasks_file, 'rb') as f:
        tasks = pickle.load(f)

    env_type = variant['envType']
    if env_type == 'Push':
        base_env = SawyerPushEnv(tasks=tasks)
    else:
        assert env_type == 'PickPlace', 'unsupported envType: %r' % (env_type,)
        base_env = SawyerPickPlaceEnv(tasks=tasks)

    env = FinnMamlEnv(
        FlatGoalEnv(base_env,
                    obs_keys=['state_observation', 'state_desired_goal']))
    env = TfEnv(NormalizedBoxEnv(env))

    baseline = LinearFeatureBaseline(env_spec=env.spec)

    algo = MAMLTRPO(
        env=env,
        policy=None,
        load_policy=variant['init_param_file'],
        baseline=baseline,
        batch_size=fast_batch_size,  # number of trajs for grad update
        max_path_length=max_path_length,
        meta_batch_size=meta_batch_size,
        num_grad_updates=num_grad_updates,
        n_itr=1000,
        use_maml=True,
        step_size=meta_step_size,
        plot=False,
    )

    save_dir = variant['saveDir']
    if not os.path.isdir(save_dir):
        # makedirs also creates missing parents, which os.mkdir would not.
        os.makedirs(save_dir)

    logger.set_snapshot_dir(save_dir)
    # Join with a separator so progress.csv lands inside save_dir even when
    # the configured path has no trailing slash.
    logger.add_tabular_output(os.path.join(save_dir, 'progress.csv'))

    algo.train()
예제 #5
0
# Script fragment: configures an already-created `env`, wires up the logger,
# initialises the parallel sampler and builds the policy and baseline.

# Switch off the generate_grid / generate_b0_start_goal flags on the wrapped
# environment, then reset it.
env._wrapped_env.generate_grid = False
env._wrapped_env.generate_b0_start_goal = False
env.reset()

# All log artefacts for this run live under this directory.
log_dir = "./Data/obs_1goal20step0stay_1_gru"

tabular_log_file = osp.join(log_dir, "progress.csv")
text_log_file = osp.join(log_dir, "debug.log")
params_log_file = osp.join(log_dir, "params.json")
pkl_file = osp.join(log_dir, "params.pkl")

# Attach text/tabular outputs and remember the previous snapshot settings
# before overriding them (they are not restored within this fragment).
logger.add_text_output(text_log_file)
logger.add_tabular_output(tabular_log_file)
prev_snapshot_dir = logger.get_snapshot_dir()
prev_mode = logger.get_snapshot_mode()
logger.set_snapshot_dir(log_dir)
# NOTE(review): "gaplast" is presumably a snapshot mode that combines "gap"
# and "last" behaviour with the gap set below -- confirm against the logger.
logger.set_snapshot_mode("gaplast")
logger.set_snapshot_gap(1000)
logger.set_log_tabular_only(False)
logger.push_prefix("[%s] " % "FixMapStartState")

# Single sampler worker with a fixed seed of 0.
from Algo import parallel_sampler
parallel_sampler.initialize(n_parallel=1)
parallel_sampler.set_seed(0)

# The policy receives the wrapped environment's `params` as its QMDP parameters.
policy = QMDPPolicy(env_spec=env.spec,
                    name="QMDP",
                    qmdp_param=env._wrapped_env.params)

baseline = LinearFeatureBaseline(env_spec=env.spec)
예제 #6
0
def main():
    """Parse command-line options and train TRPO on the MAWaterWorld environment.

    The function builds the environment (optionally wrapped in an observation
    buffer), a Gaussian MLP / GRU / LSTM policy, a baseline, configures the
    logger under the chosen log directory and runs ``TRPO.train()``.  The
    logger's previous snapshot settings and outputs are restored afterwards.

    Raises:
        NotImplementedError: For ``--baseline_type mlp`` or an unknown
            ``--recurrent`` value.
    """
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')
    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)

    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_name', type=str, default=default_exp_name,
                        help='Name of the experiment.')

    parser.add_argument('--discount', type=float, default=0.95)
    parser.add_argument('--gae_lambda', type=float, default=0.99)
    parser.add_argument('--reward_scale', type=float, default=1.0)
    parser.add_argument('--enable_obsnorm', action='store_true', default=False)
    parser.add_argument('--chunked', action='store_true', default=False)

    parser.add_argument('--n_iter', type=int, default=250)
    parser.add_argument('--sampler_workers', type=int, default=1)
    parser.add_argument('--max_traj_len', type=int, default=250)
    parser.add_argument('--update_curriculum', action='store_true', default=False)
    parser.add_argument('--anneal_step_size', type=int, default=0)

    parser.add_argument('--n_timesteps', type=int, default=8000)

    parser.add_argument('--control', type=str, default='centralized')
    parser.add_argument('--buffer_size', type=int, default=1)
    parser.add_argument('--radius', type=float, default=0.015)
    parser.add_argument('--n_evaders', type=int, default=10)
    parser.add_argument('--n_pursuers', type=int, default=8)
    parser.add_argument('--n_poison', type=int, default=10)
    parser.add_argument('--n_coop', type=int, default=4)
    parser.add_argument('--n_sensors', type=int, default=30)
    parser.add_argument('--sensor_range', type=str, default='0.2')
    parser.add_argument('--food_reward', type=float, default=5)
    parser.add_argument('--poison_reward', type=float, default=-1)
    parser.add_argument('--encounter_reward', type=float, default=0.05)
    parser.add_argument('--reward_mech', type=str, default='local')

    parser.add_argument('--recurrent', type=str, default=None)
    parser.add_argument('--baseline_type', type=str, default='linear')
    parser.add_argument('--policy_hidden_sizes', type=str, default='128,128')
    parser.add_argument('--baseline_hidden_sizes', type=str, default='128,128')

    parser.add_argument('--max_kl', type=float, default=0.01)

    parser.add_argument('--log_dir', type=str, required=False)
    parser.add_argument('--tabular_log_file', type=str, default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file', type=str, default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file', type=str, default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data', type=str, help='Pickled data for stub objects')
    parser.add_argument('--snapshot_mode', type=str, default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                        '(all iterations will be saved), "last" (only '
                        'the last iteration will be saved), or "none" '
                        '(do not save snapshots)')
    parser.add_argument(
        '--log_tabular_only', type=ast.literal_eval, default=False,
        help='Whether to only print the tabular log information (in a horizontal format)')

    args = parser.parse_args()

    parallel_sampler.initialize(n_parallel=args.sampler_workers)

    if args.seed is not None:
        set_seed(args.seed)
        parallel_sampler.set_seed(args.seed)

    # Stored on args so that it is included when the parameters are logged.
    args.hidden_sizes = tuple(map(int, args.policy_hidden_sizes.split(',')))

    # --sensor_range is either one value or one comma-separated value per
    # pursuer.  Build a real list first: on Python 3, np.array(map(...)) wraps
    # the iterator in a 0-d object array and len() on it raises.
    sensor_range = np.array([float(v) for v in args.sensor_range.split(',')])
    if len(sensor_range) == 1:
        sensor_range = sensor_range[0]
    else:
        assert sensor_range.shape == (args.n_pursuers,)

    env = MAWaterWorld(args.n_pursuers, args.n_evaders, args.n_coop, args.n_poison,
                       radius=args.radius, n_sensors=args.n_sensors, food_reward=args.food_reward,
                       poison_reward=args.poison_reward, encounter_reward=args.encounter_reward,
                       reward_mech=args.reward_mech, sensor_range=sensor_range, obstacle_loc=None)

    env = TfEnv(
        RLLabEnv(
            StandardizedEnv(env, scale_reward=args.reward_scale,
                            enable_obsnorm=args.enable_obsnorm), mode=args.control))

    if args.buffer_size > 1:
        env = ObservationBuffer(env, args.buffer_size)

    if args.recurrent:
        feature_network = MLP(
            name='feature_net',
            input_shape=(env.spec.observation_space.flat_dim + env.spec.action_space.flat_dim,),
            output_dim=16, hidden_sizes=(128, 64, 32), hidden_nonlinearity=tf.nn.tanh,
            output_nonlinearity=None)
        # NOTE(review): int() requires --policy_hidden_sizes to be a single
        # integer for recurrent policies; the default '128,128' raises
        # ValueError here.
        if args.recurrent == 'gru':
            policy = GaussianGRUPolicy(env_spec=env.spec, feature_network=feature_network,
                                       hidden_dim=int(args.policy_hidden_sizes), name='policy')
        elif args.recurrent == 'lstm':
            policy = GaussianLSTMPolicy(env_spec=env.spec, feature_network=feature_network,
                                        hidden_dim=int(args.policy_hidden_sizes), name='policy')
        else:
            # Fail here with a clear message instead of a NameError on
            # `policy` further down.
            raise NotImplementedError(args.recurrent)
    else:
        policy = GaussianMLPPolicy(
            name='policy', env_spec=env.spec,
            hidden_sizes=args.hidden_sizes, min_std=10e-5)

    if args.baseline_type == 'linear':
        baseline = LinearFeatureBaseline(env_spec=env.spec)
    elif args.baseline_type == 'mlp':
        raise NotImplementedError()
    else:
        baseline = ZeroBaseline(env_spec=env.spec)

    # logger
    default_log_dir = config.LOG_DIR
    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    logger.log_parameters_lite(params_log_file, args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    try:
        algo = TRPO(
            env=env,
            policy=policy,
            baseline=baseline,
            batch_size=args.n_timesteps,
            max_path_length=args.max_traj_len,
            update_max_path_length=args.update_curriculum,
            anneal_step_size=args.anneal_step_size,
            n_itr=args.n_iter,
            discount=args.discount,
            gae_lambda=args.gae_lambda,
            step_size=args.max_kl,
            optimizer=ConjugateGradientOptimizer(hvp_approach=FiniteDifferenceHvp(base_eps=1e-5)) if
            args.recurrent else None,
            mode=args.control if not args.chunked else 'chunk_{}'.format(args.control),)

        algo.train()
    finally:
        # Put the logger back the way it was found, mirroring the tear-down
        # done by run_experiment.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
예제 #7
0
def main():
    """Parse command-line options and train TRPO on ContinuousHostageWorld.

    The function builds the environment (optionally wrapped in an observation
    buffer), a Gaussian MLP or GRU policy and a baseline, configures the
    logger under the chosen log directory and runs ``TRPO.train()``.  The
    logger's previous snapshot settings and outputs are restored afterwards.
    """
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')
    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--exp_name', type=str, default=default_exp_name, help='Name of the experiment.')

    parser.add_argument('--discount', type=float, default=0.95)
    # NOTE(review): --gae_lambda is parsed but not passed to TRPO below.
    parser.add_argument('--gae_lambda', type=float, default=0.99)

    parser.add_argument('--n_iter', type=int, default=250)
    parser.add_argument('--sampler_workers', type=int, default=1)
    parser.add_argument('--max_traj_len', type=int, default=250)
    parser.add_argument('--update_curriculum', action='store_true', default=False)
    parser.add_argument('--n_timesteps', type=int, default=8000)

    # Registered once only: a second add_argument('--control', ...) makes
    # argparse raise a conflicting-option error before any parsing happens.
    parser.add_argument('--control', type=str, default='centralized')
    parser.add_argument('--buffer_size', type=int, default=1)
    parser.add_argument('--n_good', type=int, default=3)
    parser.add_argument('--n_hostage', type=int, default=5)
    parser.add_argument('--n_bad', type=int, default=5)
    parser.add_argument('--n_coop_save', type=int, default=2)
    parser.add_argument('--n_coop_avoid', type=int, default=2)
    parser.add_argument('--n_sensors', type=int, default=20)
    parser.add_argument('--sensor_range', type=float, default=0.2)
    parser.add_argument('--save_reward', type=float, default=3)
    parser.add_argument('--hit_reward', type=float, default=-1)
    parser.add_argument('--encounter_reward', type=float, default=0.01)
    parser.add_argument('--bomb_reward', type=float, default=-10.)

    parser.add_argument('--recurrent', action='store_true', default=False)
    parser.add_argument('--baseline_type', type=str, default='linear')
    parser.add_argument('--policy_hidden_sizes', type=str, default='128,128')
    # The flag name is misspelled ("baselin"); kept as-is so existing command
    # lines and logged parameter names do not change.
    parser.add_argument('--baselin_hidden_sizes', type=str, default='128,128')

    parser.add_argument('--max_kl', type=float, default=0.01)

    parser.add_argument('--log_dir', type=str, required=False)
    parser.add_argument('--tabular_log_file', type=str, default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file', type=str, default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file', type=str, default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--seed', type=int,
                        help='Random seed for numpy')
    parser.add_argument('--args_data', type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--snapshot_mode', type=str, default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                             '(all iterations will be saved), "last" (only '
                             'the last iteration will be saved), or "none" '
                             '(do not save snapshots)')
    parser.add_argument('--log_tabular_only', type=ast.literal_eval, default=False,
                        help='Whether to only print the tabular log information (in a horizontal format)')

    args = parser.parse_args()

    parallel_sampler.initialize(n_parallel=args.sampler_workers)

    if args.seed is not None:
        set_seed(args.seed)
        parallel_sampler.set_seed(args.seed)

    # Stored on args so that it is included when the parameters are logged.
    args.hidden_sizes = tuple(map(int, args.policy_hidden_sizes.split(',')))

    # --sensor_range is parsed as a single float and handed to the environment
    # unchanged.  The original tried to split it like a string and compared
    # it against an undeclared --n_pursuers option, which could only raise.
    env = ContinuousHostageWorld(args.n_good, args.n_hostage, args.n_bad, args.n_coop_save,
                                 args.n_coop_avoid, n_sensors=args.n_sensors,
                                 sensor_range=args.sensor_range, save_reward=args.save_reward,
                                 hit_reward=args.hit_reward, encounter_reward=args.encounter_reward,
                                 bomb_reward=args.bomb_reward)

    env = RLLabEnv(StandardizedEnv(env), mode=args.control)

    if args.buffer_size > 1:
        env = ObservationBuffer(env, args.buffer_size)

    if args.recurrent:
        policy = GaussianGRUPolicy(env_spec=env.spec, hidden_sizes=args.hidden_sizes)
    else:
        policy = GaussianMLPPolicy(env_spec=env.spec, hidden_sizes=args.hidden_sizes)

    if args.baseline_type == 'linear':
        baseline = LinearFeatureBaseline(env_spec=env.spec)
    else:
        # `obsfeat_space` was never defined; construct the baseline from the
        # env spec like the other launchers in this file do.
        baseline = ZeroBaseline(env_spec=env.spec)

    # logger
    default_log_dir = config.LOG_DIR
    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    logger.log_parameters_lite(params_log_file, args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    try:
        algo = TRPO(env=env,
                    policy=policy,
                    baseline=baseline,
                    batch_size=args.n_timesteps,
                    max_path_length=args.max_traj_len,
                    n_itr=args.n_iter,
                    discount=args.discount,
                    step_size=args.max_kl,
                    mode=args.control,)

        algo.train()
    finally:
        # Put the logger back the way it was found, mirroring the tear-down
        # done by run_experiment.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
예제 #8
0
파일: rurllab.py 프로젝트: zeyuan1987/MADRL
    def __init__(self, env, args):
        """Set up the sampler, logger and learning algorithm described by ``args``.

        Args:
            env: Environment handed to ``rllab_envpolicy_parser`` together
                with ``args``; the parser returns the environment and policy
                actually used.
            args: Parsed options.  ``args.algo`` selects ``'tftrpo'`` or
                ``'thddpg'``; for any other value ``self.algo`` is not set.

        Raises:
            NotImplementedError: For an unknown baseline type (when the
                algorithm is not ``'thddpg'``) or an unknown exploration
                strategy (for ``'thddpg'``).
        """
        self.args = args

        # Bring up the sampler workers, seeding only when a seed was given.
        parallel_sampler.initialize(n_parallel=args.n_parallel)
        if args.seed is not None:
            set_seed(args.seed)
            parallel_sampler.set_seed(args.seed)

        env, policy = rllab_envpolicy_parser(env, args)

        # A baseline is built for every algorithm except thddpg.
        if args.algo != 'thddpg':
            baseline_kind = args.baseline_type
            if baseline_kind == 'linear':
                baseline = LinearFeatureBaseline(env_spec=env.spec)
            elif baseline_kind == 'zero':
                baseline = ZeroBaseline(env_spec=env.spec)
            else:
                raise NotImplementedError(baseline_kind)

        # Resolve the log directory: an explicit --log_dir wins, otherwise
        # the experiment name under the configured default.
        default_log_dir = config.LOG_DIR
        log_dir = args.log_dir
        if log_dir is None:
            log_dir = osp.join(default_log_dir, args.exp_name)

        tabular_log_file = osp.join(log_dir, args.tabular_log_file)
        text_log_file = osp.join(log_dir, args.text_log_file)
        params_log_file = osp.join(log_dir, args.params_log_file)

        logger.log_parameters_lite(params_log_file, args)
        logger.add_text_output(text_log_file)
        logger.add_tabular_output(tabular_log_file)
        # Previous settings are read here but not restored by this constructor.
        prev_snapshot_dir = logger.get_snapshot_dir()
        prev_mode = logger.get_snapshot_mode()
        logger.set_snapshot_dir(log_dir)
        logger.set_snapshot_mode(args.snapshot_mode)
        logger.set_log_tabular_only(args.log_tabular_only)
        logger.push_prefix("[%s] " % args.exp_name)

        if args.algo == 'tftrpo':
            trpo_kwargs = dict(
                env=env,
                policy=policy,
                baseline=baseline,
                batch_size=args.batch_size,
                max_path_length=args.max_path_length,
                n_itr=args.n_iter,
                discount=args.discount,
                gae_lambda=args.gae_lambda,
                step_size=args.step_size,
            )
            # Recurrent policies get a finite-difference HVP optimizer;
            # otherwise None is passed through.
            if args.recurrent:
                trpo_kwargs['optimizer'] = ConjugateGradientOptimizer(
                    hvp_approach=FiniteDifferenceHvp(base_eps=1e-5))
            else:
                trpo_kwargs['optimizer'] = None
            trpo_kwargs['mode'] = args.control
            self.algo = TRPO(**trpo_kwargs)
        elif args.algo == 'thddpg':
            critic = thContinuousMLPQFunction(env_spec=env.spec)

            strategy_name = args.exp_strategy
            if strategy_name == 'ou':
                explorer = OUStrategy(env_spec=env.spec)
            elif strategy_name == 'gauss':
                explorer = GaussianStrategy(env_spec=env.spec)
            else:
                raise NotImplementedError()

            self.algo = thDDPG(
                env=env,
                policy=policy,
                qf=critic,
                es=explorer,
                batch_size=args.batch_size,
                max_path_length=args.max_path_length,
                epoch_length=args.epoch_length,
                min_pool_size=args.min_pool_size,
                replay_pool_size=args.replay_pool_size,
                n_epochs=args.n_iter,
                discount=args.discount,
                scale_reward=0.01,
                qf_learning_rate=args.qfunc_lr,
                policy_learning_rate=args.policy_lr,
                eval_samples=args.eval_samples,
                mode=args.control,
            )
예제 #9
0
def run_experiment(argv):
    """Run (or resume) a serialized experiment described by command-line arguments.

    The logger is pointed at the experiment's log directory for the duration
    of the run and restored afterwards, even when the run raises.

    Two modes are supported:

    * Resume: ``--resume_from`` contains ``'&|&'``-separated fields
      ``<snapshot path>&|&<n_itr>[&|&<batch size>]``.  The snapshot is loaded
      with ``joblib``, its ``algo`` gets the new iteration count (and batch
      size, if given) and ``train()`` is called.
    * New run: ``--args_data`` holds a base64-encoded pickle (or cloudpickle
      when ``--use_cloudpickle`` is true) that is decoded and executed.

    Args:
        argv: Full argument vector; ``argv[0]`` is skipped.
    """
    default_log_dir = config.LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')

    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--n_parallel',
        type=int,
        default=1,
        help=
        'Number of parallel workers to perform rollouts. 0 => don\'t start any workers'
    )
    parser.add_argument('--exp_name',
                        type=str,
                        default=default_exp_name,
                        help='Name of the experiment.')
    parser.add_argument('--log_dir',
                        type=str,
                        default=None,
                        help='Path to save the log and iteration snapshot.')
    parser.add_argument('--snapshot_mode',
                        type=str,
                        default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                        '(all iterations will be saved), "last" (only '
                        'the last iteration will be saved), "gap" (every '
                        '`snapshot_gap` iterations are saved), or "none" '
                        '(do not save snapshots)')
    parser.add_argument('--snapshot_gap',
                        type=int,
                        default=1,
                        help='Gap between snapshot iterations.')
    parser.add_argument('--tabular_log_file',
                        type=str,
                        default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file',
                        type=str,
                        default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file',
                        type=str,
                        default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--variant_log_file',
                        type=str,
                        default='variant.json',
                        help='Name of the variant log file (in json).')
    parser.add_argument(
        '--resume_from',
        type=str,
        default=None,
        help='Name of the pickle file to resume experiment from.')
    parser.add_argument('--plot',
                        type=ast.literal_eval,
                        default=False,
                        help='Whether to plot the iteration results')
    parser.add_argument(
        '--log_tabular_only',
        type=ast.literal_eval,
        default=False,
        help=
        'Whether to only print the tabular log information (in a horizontal format)'
    )
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data',
                        type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--variant_data',
                        type=str,
                        help='Pickled data for variant configuration')
    parser.add_argument('--use_cloudpickle',
                        type=ast.literal_eval,
                        default=False)

    args = parser.parse_args(argv[1:])

    if args.seed is not None:
        set_seed(args.seed)

    if args.n_parallel > 0:
        from rllab.sampler import parallel_sampler
        parallel_sampler.initialize(n_parallel=args.n_parallel)
        if args.seed is not None:
            parallel_sampler.set_seed(args.seed)

    if args.plot:
        from rllab.plotter import plotter
        plotter.init_worker()

    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    # SECURITY NOTE: --variant_data / --args_data are unpickled below; only
    # pass data produced by a trusted launcher.
    if args.variant_data is not None:
        variant_data = pickle.loads(base64.b64decode(args.variant_data))
        variant_log_file = osp.join(log_dir, args.variant_log_file)
        logger.log_variant(variant_log_file, variant_data)
    else:
        variant_data = None

    if not args.use_cloudpickle:
        logger.log_parameters_lite(params_log_file, args)

    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_snapshot_gap(args.snapshot_gap)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    try:
        # variant_data is the variant dictionary sent from trpoTests_ExpLite
        if (args.resume_from is not None) and ('&|&' in args.resume_from):
            # separate string on &|& : dirRes | numItrs to go | new batchSize
            vals = args.resume_from.split('&|&')
            dirRes = vals[0]
            numItrs = int(vals[1])
            batchSize = int(vals[2]) if len(vals) > 2 else None
            print("resuming from :{}".format(dirRes))
            data = joblib.load(dirRes)
            # data is dict : 'baseline', 'algo', 'itr', 'policy', 'env'
            assert 'algo' in data
            algo = data['algo']
            assert 'policy' in data
            oldBatchSize = algo.batch_size
            algo.n_itr = numItrs
            if batchSize is not None:
                algo.batch_size = batchSize
                print(
                    'algo iters : {} cur iter :{} oldBatchSize : {} newBatchSize : {}'
                    .format(algo.n_itr, algo.current_itr, oldBatchSize,
                            algo.batch_size))
            else:
                print('algo iters : {} cur iter :{} '.format(
                    algo.n_itr, algo.current_itr))
            algo.train()
        else:
            print('Not resuming - building new exp')
            if args.use_cloudpickle:  # set to use cloudpickle
                import cloudpickle
                method_call = cloudpickle.loads(base64.b64decode(args.args_data))
                method_call(variant_data)
            else:
                print('not use cloud pickle')
                data = pickle.loads(base64.b64decode(args.args_data))
                maybe_iter = concretize(data)
                if is_iterable(maybe_iter):
                    # Exhaust the iterator purely for its side effects.
                    for _ in maybe_iter:
                        pass
    finally:
        # Restore the logger even if the experiment raised, so a failed run
        # does not leave its outputs and prefix attached.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
예제 #10
0
def main():
    """Parse command-line options and train TRPO on ``MAWaterWorld``.

    The function builds the multi-agent water-world environment, wraps it
    for rllab/TensorFlow, creates a policy (MLP, GRU or LSTM) and a baseline,
    configures the rllab logger and finally runs ``TRPO.train()``.

    Raises:
        NotImplementedError: if ``--baseline_type mlp`` is requested, or if
            ``--recurrent`` is set to anything other than ``gru``/``lstm``.
    """
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    # Short random suffix so that concurrently started runs get distinct names.
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')
    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)

    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_name',
                        type=str,
                        default=default_exp_name,
                        help='Name of the experiment.')

    parser.add_argument('--discount', type=float, default=0.95)
    parser.add_argument('--gae_lambda', type=float, default=0.99)
    parser.add_argument('--reward_scale', type=float, default=1.0)
    parser.add_argument('--enable_obsnorm', action='store_true', default=False)
    parser.add_argument('--chunked', action='store_true', default=False)

    parser.add_argument('--n_iter', type=int, default=250)
    parser.add_argument('--sampler_workers', type=int, default=1)
    parser.add_argument('--max_traj_len', type=int, default=250)
    parser.add_argument('--update_curriculum',
                        action='store_true',
                        default=False)
    parser.add_argument('--anneal_step_size', type=int, default=0)

    parser.add_argument('--n_timesteps', type=int, default=8000)

    parser.add_argument('--control', type=str, default='centralized')
    parser.add_argument('--buffer_size', type=int, default=1)
    parser.add_argument('--radius', type=float, default=0.015)
    parser.add_argument('--n_evaders', type=int, default=10)
    parser.add_argument('--n_pursuers', type=int, default=8)
    parser.add_argument('--n_poison', type=int, default=10)
    parser.add_argument('--n_coop', type=int, default=4)
    parser.add_argument('--n_sensors', type=int, default=30)
    parser.add_argument('--sensor_range', type=str, default='0.2')
    parser.add_argument('--food_reward', type=float, default=5)
    parser.add_argument('--poison_reward', type=float, default=-1)
    parser.add_argument('--encounter_reward', type=float, default=0.05)
    parser.add_argument('--reward_mech', type=str, default='local')

    parser.add_argument('--recurrent', type=str, default=None)
    parser.add_argument('--baseline_type', type=str, default='linear')
    parser.add_argument('--policy_hidden_sizes', type=str, default='128,128')
    parser.add_argument('--baseline_hidden_sizes', type=str, default='128,128')

    parser.add_argument('--max_kl', type=float, default=0.01)

    parser.add_argument('--log_dir', type=str, required=False)
    parser.add_argument('--tabular_log_file',
                        type=str,
                        default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file',
                        type=str,
                        default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file',
                        type=str,
                        default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data',
                        type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--snapshot_mode',
                        type=str,
                        default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                        '(all iterations will be saved), "last" (only '
                        'the last iteration will be saved), or "none" '
                        '(do not save snapshots)')
    parser.add_argument(
        '--log_tabular_only',
        type=ast.literal_eval,
        default=False,
        help=
        'Whether to only print the tabular log information (in a horizontal format)'
    )

    args = parser.parse_args()

    parallel_sampler.initialize(n_parallel=args.sampler_workers)

    if args.seed is not None:
        set_seed(args.seed)
        parallel_sampler.set_seed(args.seed)

    # Attached to ``args``, which is later handed to
    # ``logger.log_parameters_lite`` together with the other options.
    args.hidden_sizes = tuple(map(int, args.policy_hidden_sizes.split(',')))

    # ``map`` is lazy on Python 3; without ``list`` NumPy would wrap the map
    # object in a 0-d object array and ``len()`` below would raise TypeError.
    sensor_range = np.array(list(map(float, args.sensor_range.split(','))))
    if len(sensor_range) == 1:
        # A single value is passed to the environment as a scalar.
        sensor_range = sensor_range[0]
    else:
        # Otherwise exactly one range per pursuer is required.
        assert sensor_range.shape == (args.n_pursuers, )

    env = MAWaterWorld(args.n_pursuers,
                       args.n_evaders,
                       args.n_coop,
                       args.n_poison,
                       radius=args.radius,
                       n_sensors=args.n_sensors,
                       food_reward=args.food_reward,
                       poison_reward=args.poison_reward,
                       encounter_reward=args.encounter_reward,
                       reward_mech=args.reward_mech,
                       sensor_range=sensor_range,
                       obstacle_loc=None)

    env = TfEnv(
        RLLabEnv(StandardizedEnv(env,
                                 scale_reward=args.reward_scale,
                                 enable_obsnorm=args.enable_obsnorm),
                 mode=args.control))

    if args.buffer_size > 1:
        env = ObservationBuffer(env, args.buffer_size)

    if args.recurrent:
        feature_network = MLP(
            name='feature_net',
            input_shape=(env.spec.observation_space.flat_dim +
                         env.spec.action_space.flat_dim, ),
            output_dim=16,
            hidden_sizes=(128, 64, 32),
            hidden_nonlinearity=tf.nn.tanh,
            output_nonlinearity=None)
        # For recurrent policies ``--policy_hidden_sizes`` must be a single
        # integer (e.g. '32'); ``int('128,128')`` raises ValueError.
        if args.recurrent == 'gru':
            policy = GaussianGRUPolicy(env_spec=env.spec,
                                       feature_network=feature_network,
                                       hidden_dim=int(
                                           args.policy_hidden_sizes),
                                       name='policy')
        elif args.recurrent == 'lstm':
            policy = GaussianLSTMPolicy(env_spec=env.spec,
                                        feature_network=feature_network,
                                        hidden_dim=int(
                                            args.policy_hidden_sizes),
                                        name='policy')
        else:
            # Fail here rather than with an unbound ``policy`` further down.
            raise NotImplementedError(args.recurrent)
    else:
        policy = GaussianMLPPolicy(
            name='policy',
            env_spec=env.spec,
            hidden_sizes=tuple(map(int, args.policy_hidden_sizes.split(','))),
            min_std=10e-5)

    if args.baseline_type == 'linear':
        baseline = LinearFeatureBaseline(env_spec=env.spec)
    elif args.baseline_type == 'mlp':
        raise NotImplementedError()
        # baseline = GaussianMLPBaseline(
        #     env_spec=env.spec, hidden_sizes=tuple(map(int, args.baseline_hidden_sizes.split(','))))
    else:
        baseline = ZeroBaseline(env_spec=env.spec)

    # Logger: all outputs go below ``--log_dir`` or, when that is not given,
    # below ``config.LOG_DIR/<exp_name>``.
    default_log_dir = config.LOG_DIR
    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    logger.log_parameters_lite(params_log_file, args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    # NOTE(review): the previous snapshot dir/mode are read but never
    # restored in this function.
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    algo = TRPO(
        env=env,
        policy=policy,
        baseline=baseline,
        batch_size=args.n_timesteps,
        max_path_length=args.max_traj_len,
        #max_path_length_limit=args.max_path_length_limit,
        update_max_path_length=args.update_curriculum,
        anneal_step_size=args.anneal_step_size,
        n_itr=args.n_iter,
        discount=args.discount,
        gae_lambda=args.gae_lambda,
        step_size=args.max_kl,
        # A finite-difference HVP optimizer is only supplied for recurrent
        # policies; otherwise TRPO's default optimizer is used.
        optimizer=ConjugateGradientOptimizer(hvp_approach=FiniteDifferenceHvp(
            base_eps=1e-5)) if args.recurrent else None,
        mode=args.control
        if not args.chunked else 'chunk_{}'.format(args.control),
    )

    algo.train()
예제 #11
0
파일: resume_vrep.py 프로젝트: flyers/rllab
import joblib

# Resume script: configure the rllab logger for a fresh experiment directory
# and reload policy, environment and baseline from an earlier snapshot.

# NOTE: hard-coded, machine-specific root under which experiment
# directories are created.
default_log_dir = '/home/sliay/Documents/rllab/data/local/experiment'
now = datetime.datetime.now(dateutil.tz.tzlocal())
# avoid name clashes when running distributed jobs
rand_id = str(uuid.uuid4())[:5]
timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')
default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
log_dir = os.path.join(default_log_dir, default_exp_name)
tabular_log_file = os.path.join(log_dir, 'progress.csv')
text_log_file = os.path.join(log_dir, 'debug.log')
# NOTE(review): params_log_file is computed but not used in the visible part
# of this script.
params_log_file = os.path.join(log_dir, 'params.json')

# Route text/tabular output and snapshots into the new experiment directory
# and prefix every log line with the experiment name.
logger.add_text_output(text_log_file)
logger.add_tabular_output(tabular_log_file)
logger.set_snapshot_dir(log_dir)
logger.set_snapshot_mode('last')
logger.set_log_tabular_only(False)
logger.push_prefix("[%s] " % default_exp_name)

# Snapshot of the earlier run to resume from (hard-coded path); the pickle is
# indexed below by the keys 'policy', 'env' and 'baseline'.
last_snapshot_dir = '/home/sliay/Documents/rllab/data/local/experiment/experiment_2016_07_07_498itr'
data = joblib.load(os.path.join(last_snapshot_dir, 'params.pkl'))
policy = data['policy']
env = data['env']
baseline = data['baseline']
# env = normalize(GymEnv("VREP-v0", record_video=False))

# policy = GaussianMLPPolicy(
#     env_spec=env.spec,
#     The neural network policy should have two hidden layers, each with 32 hidden units.
#     hidden_sizes=(128, 128)
예제 #12
0
def main():
    """Parse command-line options and train TRPO on ``PursuitEvade``.

    The function builds (or loads) the map pool, creates the pursuit-evasion
    environment, selects a categorical policy (MLP, conv-MLP or GRU) and a
    baseline, configures the rllab logger and finally runs ``TRPO.train()``.

    Raises:
        NotImplementedError: if ``--map_type`` is neither ``rectangle`` nor
            ``complex``, or if ``--recurrent`` is set to anything other
            than ``gru``.
    """
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    # Short random suffix so that concurrently started runs get distinct names.
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')
    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--exp_name', type=str, default=default_exp_name, help='Name of the experiment.')

    parser.add_argument('--discount', type=float, default=0.99)
    parser.add_argument('--gae_lambda', type=float, default=1.0)
    parser.add_argument('--reward_scale', type=float, default=1.0)

    parser.add_argument('--n_iter', type=int, default=250)
    parser.add_argument('--sampler_workers', type=int, default=1)
    parser.add_argument('--max_traj_len', type=int, default=250)
    parser.add_argument('--update_curriculum', action='store_true', default=False)
    parser.add_argument('--n_timesteps', type=int, default=8000)
    parser.add_argument('--control', type=str, default='centralized')

    parser.add_argument('--rectangle', type=str, default='10,10')
    parser.add_argument('--map_type', type=str, default='rectangle')
    parser.add_argument('--n_evaders', type=int, default=5)
    parser.add_argument('--n_pursuers', type=int, default=2)
    parser.add_argument('--obs_range', type=int, default=3)
    parser.add_argument('--n_catch', type=int, default=2)
    parser.add_argument('--urgency', type=float, default=0.0)
    parser.add_argument('--pursuit', dest='train_pursuit', action='store_true')
    parser.add_argument('--evade', dest='train_pursuit', action='store_false')
    parser.set_defaults(train_pursuit=True)
    parser.add_argument('--surround', action='store_true', default=False)
    parser.add_argument('--constraint_window', type=float, default=1.0)
    parser.add_argument('--sample_maps', action='store_true', default=False)
    parser.add_argument('--map_file', type=str, default='../maps/map_pool.npy')
    parser.add_argument('--flatten', action='store_true', default=False)
    parser.add_argument('--reward_mech', type=str, default='global')
    parser.add_argument('--catchr', type=float, default=0.1)
    parser.add_argument('--term_pursuit', type=float, default=5.0)

    parser.add_argument('--recurrent', type=str, default=None)
    parser.add_argument('--policy_hidden_sizes', type=str, default='128,128')
    # NOTE: the misspelled option name is kept so existing command lines and
    # logged parameter names stay valid; the value is not read in this function.
    parser.add_argument('--baselin_hidden_sizes', type=str, default='128,128')
    parser.add_argument('--baseline_type', type=str, default='linear')

    parser.add_argument('--conv', action='store_true', default=False)

    parser.add_argument('--max_kl', type=float, default=0.01)

    parser.add_argument('--log_dir', type=str, required=False)
    parser.add_argument('--tabular_log_file', type=str, default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file', type=str, default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file', type=str, default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--seed', type=int,
                        help='Random seed for numpy')
    parser.add_argument('--args_data', type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--snapshot_mode', type=str, default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                             '(all iterations will be saved), "last" (only '
                             'the last iteration will be saved), or "none" '
                             '(do not save snapshots)')
    parser.add_argument('--log_tabular_only', type=ast.literal_eval, default=False,
                        help='Whether to only print the tabular log information (in a horizontal format)')

    args = parser.parse_args()

    parallel_sampler.initialize(n_parallel=args.sampler_workers)

    if args.seed is not None:
        set_seed(args.seed)
        parallel_sampler.set_seed(args.seed)

    # Parsed once; used for the plain MLP policy below and attached to
    # ``args``, which is later handed to ``logger.log_parameters_lite``.
    args.hidden_sizes = tuple(map(int, args.policy_hidden_sizes.split(',')))

    if args.sample_maps:
        map_pool = np.load(args.map_file)
    else:
        # ``--rectangle`` is "<int>,<int>" and is splatted into the map factory.
        if args.map_type == 'rectangle':
            env_map = TwoDMaps.rectangle_map(*map(int, args.rectangle.split(',')))
        elif args.map_type == 'complex':
            env_map = TwoDMaps.complex_map(*map(int, args.rectangle.split(',')))
        else:
            raise NotImplementedError()
        map_pool = [env_map]

    env = PursuitEvade(map_pool, n_evaders=args.n_evaders, n_pursuers=args.n_pursuers,
                       obs_range=args.obs_range, n_catch=args.n_catch,
                       train_pursuit=args.train_pursuit, urgency_reward=args.urgency,
                       surround=args.surround, sample_maps=args.sample_maps,
                       constraint_window=args.constraint_window,
                       flatten=args.flatten,
                       reward_mech=args.reward_mech,
                       catchr=args.catchr,
                       term_pursuit=args.term_pursuit)

    env = RLLabEnv(
            StandardizedEnv(env, scale_reward=args.reward_scale, enable_obsnorm=False),
            mode=args.control)

    if args.recurrent:
        if args.conv:
            # NOTE(review): this branch passes upper-case 'VALID' pads while
            # the non-recurrent conv branch passes 'valid' — confirm which
            # spelling ConvNetwork expects.
            feature_network = ConvNetwork(
                # Fixed: the original read ``emv.spec`` (NameError).
                input_shape=env.spec.observation_space.shape,
                output_dim=5,
                conv_filters=(8,16,16),
                conv_filter_sizes=(3,3,3),
                conv_strides=(1,1,1),
                conv_pads=('VALID','VALID','VALID'),
                hidden_sizes=(64,),
                hidden_nonlinearity=NL.rectify,
                output_nonlinearity=NL.softmax)
        else:
            feature_network = MLP(
                input_shape=(env.spec.observation_space.flat_dim + env.spec.action_space.flat_dim,),
                output_dim=5, hidden_sizes=(128,128,128), hidden_nonlinearity=NL.tanh,
                output_nonlinearity=None)
        # For recurrent policies ``--policy_hidden_sizes`` must be a single
        # integer (e.g. '32'); ``int('128,128')`` raises ValueError.
        if args.recurrent == 'gru':
            policy = CategoricalGRUPolicy(env_spec=env.spec, feature_network=feature_network,
                                       hidden_dim=int(args.policy_hidden_sizes))
        else:
            # Fail here rather than with an unbound ``policy`` further down.
            raise NotImplementedError(args.recurrent)
    elif args.conv:
        feature_network = ConvNetwork(
            input_shape=env.spec.observation_space.shape,
            output_dim=5,
            conv_filters=(8,16,16),
            conv_filter_sizes=(3,3,3),
            conv_strides=(1,1,1),
            conv_pads=('valid','valid','valid'),
            hidden_sizes=(64,),
            hidden_nonlinearity=NL.rectify,
            output_nonlinearity=NL.softmax)
        policy = CategoricalMLPPolicy(env_spec=env.spec, prob_network=feature_network)
    else:
        policy = CategoricalMLPPolicy(env_spec=env.spec, hidden_sizes=args.hidden_sizes)

    if args.baseline_type == 'linear':
        baseline = LinearFeatureBaseline(env_spec=env.spec)
    else:
        # Fixed: the original passed the undefined name ``obsfeat_space``.
        baseline = ZeroBaseline(env_spec=env.spec)

    # Logger: all outputs go below ``--log_dir`` or, when that is not given,
    # below ``config.LOG_DIR/<exp_name>``.
    default_log_dir = config.LOG_DIR
    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    logger.log_parameters_lite(params_log_file, args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    # NOTE(review): the previous snapshot dir/mode are read but never
    # restored in this function.
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    algo = TRPO(
        env=env,
        policy=policy,
        baseline=baseline,
        batch_size=args.n_timesteps,
        max_path_length=args.max_traj_len,
        n_itr=args.n_iter,
        discount=args.discount,
        gae_lambda=args.gae_lambda,
        step_size=args.max_kl,
        mode=args.control,)

    algo.train()
예제 #13
0
def experiment(variant):
    """Run a short VPG fine-tuning experiment on a Sawyer task.

    Args:
        variant: dict of experiment settings. The keys read here are
            ``seed``, ``initial_params_file``, ``goalIndex``,
            ``init_step_size``, ``regionSize``, ``mode``, ``valRegionSize``,
            ``envType`` (``'Push'`` or ``'PickPlace'``) and ``saveDir``.

    Raises:
        AssertionError: if ``envType`` is neither ``'Push'`` nor
            ``'PickPlace'``.
    """
    import os

    seed = variant['seed']
    tf.set_random_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    initial_params_file = variant['initial_params_file']
    goalIndex = variant['goalIndex']
    init_step_size = variant['init_step_size']
    regionSize = variant['regionSize']
    mode = variant['mode']

    # The task files live under a different root inside the docker image.
    if 'docker' in mode:
        taskFilePrefix = '/root/code'
    else:
        taskFilePrefix = '/home/russellm'

    # A validation region, when given, takes precedence over ``regionSize``.
    if variant['valRegionSize'] is not None:
        valRegionSize = variant['valRegionSize']
        tasksFile = taskFilePrefix + '/multiworld/multiworld/envs/goals/pickPlace_' + valRegionSize + '_val.pkl'
    else:
        tasksFile = taskFilePrefix + '/multiworld/multiworld/envs/goals/pickPlace_' + regionSize + '.pkl'

    # Context manager so the task file is closed once it has been read.
    with open(tasksFile, 'rb') as tasks_fh:
        tasks = pickle.load(tasks_fh)

    envType = variant['envType']
    if envType == 'Push':
        baseEnv = SawyerPushEnv(tasks=tasks)
    else:
        assert (envType) == 'PickPlace'
        baseEnv = SawyerPickPlaceEnv(tasks=tasks)

    env = FinnMamlEnv(
        FlatGoalEnv(baseEnv,
                    obs_keys=['state_observation', 'state_desired_goal']))
    env = TfEnv(NormalizedBoxEnv(env))
    baseline = LinearFeatureBaseline(env_spec=env.spec)

    algo = VPG(
        env=env,
        policy=None,
        load_policy=initial_params_file,
        baseline=baseline,
        batch_size=7500,  # 2x
        max_path_length=150,
        n_itr=10,
        reset_arg=goalIndex,
        optimizer_args={
            'init_learning_rate': init_step_size,
            'tf_optimizer_args': {
                'learning_rate': 0.1 * init_step_size
            },
            'tf_optimizer_cls': tf.train.GradientDescentOptimizer
        })

    saveDir = variant['saveDir']
    # Create the whole directory chain in one call; an empty saveDir means
    # "current directory" and needs no creation.
    if saveDir and not os.path.isdir(saveDir):
        os.makedirs(saveDir)

    logger.set_snapshot_dir(saveDir)
    # os.path.join keeps the CSV inside saveDir even when saveDir has no
    # trailing slash (plain concatenation would yield '<dir>progress.csv').
    logger.add_tabular_output(os.path.join(saveDir, 'progress.csv'))
    algo.train()
예제 #14
0
    env_spec=env.spec,
    # The neural network policy should have two hidden layers, each with 32 hidden units.
    hidden_sizes=(32, ))

# Baseline built from the environment spec and handed to TRPO below.
baseline = LinearFeatureBaseline(env_spec=env.spec)

# Logger setup: every output file lives under LOG_DIR (a relative path, so it
# is resolved against the current working directory).
LOG_DIR = 'walker_gru_test'

tabular_log_file = osp.join(LOG_DIR, 'progress.csv')
text_log_file = osp.join(LOG_DIR, 'debug.log')
# NOTE(review): params_log_file is computed but not used in the visible part
# of this script.
params_log_file = osp.join(LOG_DIR, 'params.json')

logger.add_text_output(text_log_file)
logger.add_tabular_output(tabular_log_file)
logger.set_snapshot_dir(LOG_DIR)
logger.set_snapshot_mode('last')
logger.set_log_tabular_only(False)
# Prefix every log line with "[Walker] ".
logger.push_prefix("[%s] " % 'Walker')

# Train TRPO with fixed hyper-parameters; `env` and `policy` are defined
# earlier in the script.
algo = TRPO(env=env,
            policy=policy,
            baseline=baseline,
            batch_size=1200,
            max_path_length=500,
            n_itr=500,
            discount=0.99,
            step_size=0.01,
            mode='centralized')
algo.train()
예제 #15
0
    def __init__(self,
                 base_kwargs,
                 env,
                 policy,
                 discriminator,
                 qf,
                 vf,
                 pool,
                 metric,
                 env_id,
                 log_dir,
                 eval_freq,
                 plotter=None,
                 lr=3E-3,
                 scale_entropy=1,
                 discount=0.99,
                 tau=0.01,
                 num_skills=20,
                 save_full_state=False,
                 find_best_skill_interval=10,
                 best_skill_n_rollouts=10,
                 learn_p_z=False,
                 include_actions=False,
                 add_p_z=True):
        """Build the algorithm, its TF session/graph and the tracking files.

        Same as DIAYN, with behaviour-descriptor tracking added and the
        environment id and log directory passed in.

        Args:
            base_kwargs: keyword arguments forwarded to the base-class
                ``__init__``.
            env: environment; its ``action_space``/``observation_space``
                flat dimensions are read, and
                ``env._wrapped_env.env.env_info`` supplies extra keyword
                arguments for the behaviour-descriptor metric.
            policy, discriminator, qf, vf, pool: stored on the instance.
            metric: dict; ``metric['type']`` selects the metric class and the
                whole dict is expanded as keyword arguments to it. It is also
                written to ``experiment_metadata.json``.
            env_id: underscore-separated environment identifier, used to
                build the metric class name and the CSV file name.
            log_dir: directory used as logger snapshot dir and for the CSV
                and JSON files; created if it does not exist.
            eval_freq: stored on the instance and included in the log line.
            plotter: stored on the instance.
            lr: single learning rate used for the policy, discriminator,
                Q-function and value function.
            scale_entropy, discount, tau: stored on the instance.
            num_skills: number of skills; ``_p_z`` is initialised to the
                uniform distribution over them.
            save_full_state, find_best_skill_interval,
            best_skill_n_rollouts, learn_p_z, include_actions, add_p_z:
                stored on the instance; not otherwise used in this
                constructor.
        """
        # locals() is captured before any other local is bound, so it holds
        # exactly the constructor arguments.
        Serializable.quick_init(self, locals())
        # NOTE(review): super() is given SAC rather than this class, so the
        # MRO lookup starts after SAC — presumably deliberate; confirm.
        super(SAC, self).__init__(**base_kwargs)

        self._env = env
        self._policy = policy
        self._discriminator = discriminator
        self._qf = qf
        self._vf = vf
        self._pool = pool
        self._plotter = plotter

        # One learning rate shared by all four optimised components.
        self._policy_lr = lr
        self._discriminator_lr = lr
        self._qf_lr = lr
        self._vf_lr = lr
        self._scale_entropy = scale_entropy
        self._discount = discount
        self._tau = tau
        self._num_skills = num_skills
        # Uniform prior over skills.
        self._p_z = np.full(num_skills, 1.0 / num_skills)
        self._find_best_skill_interval = find_best_skill_interval
        self._best_skill_n_rollouts = best_skill_n_rollouts
        self._learn_p_z = learn_p_z
        self._save_full_state = save_full_state
        self._include_actions = include_actions
        self._add_p_z = add_p_z

        # Flattened action / observation dimensionalities.
        self._Da = self._env.action_space.flat_dim
        self._Do = self._env.observation_space.flat_dim

        self._training_ops = list()

        # TF session limited to two intra-op and two inter-op threads.
        config = tf.ConfigProto()
        # config.gpu_options.allow_growth = True
        config.intra_op_parallelism_threads = 2
        config.inter_op_parallelism_threads = 2
        self._sess = tf.InteractiveSession(config=config)

        # Graph construction; the global initializer is run only after all
        # of these have been called.
        self._init_placeholders()
        self._init_actor_update()
        self._init_critic_update()
        self._init_discriminator_update()
        self._init_target_ops()

        self._sess.run(tf.global_variables_initializer())

        ### Additional params for behaviour tracking ###

        self.dirname = log_dir
        self.eval_freq = eval_freq
        self.env_id = env_id
        logger.set_snapshot_dir(self.dirname)
        logger.log("EXPERIMENT NAME: {} (eval freq: {})".format(
            self.dirname, self.eval_freq))

        # Initialise behaviour descriptor metric.
        # The class is looked up in the `bmet` module by name: the
        # underscore-separated parts of env_id and metric['type'] are each
        # capitalised and concatenated.
        env_info = env._wrapped_env.env.env_info
        env_class = ''.join([c.capitalize() for c in env_id.split('_')])
        bd_class = ''.join([s.capitalize() for s in metric['type'].split('_')])
        self.bd_metric = bmet.__dict__[env_class + bd_class](**metric,
                                                             **env_info)

        # Initialise data discovery writer
        if not os.path.isdir(self.dirname):
            os.makedirs(self.dirname)
        filepath = '{}/ref_data_DIAYN_{}.csv'.format(self.dirname, self.env_id)
        # NOTE(review): the file is opened in append mode and a header row is
        # written on every construction, so re-using a log_dir adds another
        # header row to an existing file.
        with open(filepath, 'a') as outfile:
            writer = csv.DictWriter(outfile,
                                    fieldnames=[
                                        "nloop", "niter", "nsmp", "nstep",
                                        "coverage", "fitness", "outcomes",
                                        "ratios"
                                    ])
            writer.writeheader()

        # Save the modified arguments
        # NOTE(review): the controller architecture [20, 20] is a hard-coded
        # literal here, not read from the policy.
        metadata_dict = {
            "experiment": {
                "controller": {
                    "architecture": [20, 20],
                    "type": "nn_policy"
                },
                "environment": {
                    "id": env_id
                },
                "metric": metric,
                "type": "nn_policy"
            }
        }
        filename = '{}/experiment_metadata.json'.format(self.dirname)
        with open(filename, 'w') as outfile:
            json.dump(metadata_dict, outfile, sort_keys=True, indent=4)
예제 #16
0
from inverse_rl.envs.env_utils import CustomGymEnv
from inverse_rl.utils.log_utils import rllab_logdir
from inverse_rl.utils.hyper_sweep import run_sweep_parallel, run_sweep_serial

# Load a policy from the given pickle file and record a video of it acting.
if __name__ == "__main__":
    # Alternative snapshots that can be swapped in for `filename`:
    #filename='data/ant_data_collect/2018_05_25_13_42_59_0/itr_1499.pkl'
    #filename='data/ant_data_collect/2018_05_23_15_21_40_0/itr_1499.pkl'
    #filename='data/ant_data_collect/2018_05_19_07_56_37_1/itr_1499.pkl'
    #filename='data/ant_data_collect/2018_05_19_07_56_37_0/itr_1485.pkl'
    #filename='data/ant_state_irl/2018_05_26_08_51_16_0/itr_999.pkl'
    #filename='data/ant_state_irl/2018_05_26_08_51_16_1/itr_999.pkl'
    #filename='data/ant_state_irl/2018_05_26_08_51_16_2/itr_999.pkl'
    filename = 'data/ant_transfer/2018_05_26_16_06_05_4/itr_999.pkl'
    import gym
    import joblib
    import rllab.misc.logger as rllablogger
    tf.reset_default_graph()
    with tf.Session(config=get_session_config()) as sess:
        # Video and log output of the recorded environment goes here.
        rllablogger.set_snapshot_dir("data/video")
        saved = joblib.load(filename)
        env = TfEnv(
            CustomGymEnv('CustomAnt-v0', record_video=True, record_log=True)
        )  #'DisabledAnt-v0' #Switch for the DisabledAnt for the transfer task
        policy = saved['policy']
        observation = env.reset()
        for _ in range(1000):
            env.render()
            action, _agent_info = policy.get_action(observation)
            observation, _reward, done, _info = env.step(action)
            if done:
                # Start a new episode instead of stepping a finished one;
                # a monitored Gym env rejects steps taken after `done`.
                observation = env.reset()
예제 #17
0
    def setup(self, env, policy, start_itr):
        """Configure logging and build the training algorithm.

        Args:
            env: environment; ``env.spec`` is used for baselines, the
                Q-function and exploration strategies, and ``env.agents``
                for the number of agents in concurrent control.
            policy: policy (or policies) handed to the algorithm.
            start_itr: iteration to start from (used by ``MATRPO`` only).

        Returns:
            The constructed ``MATRPO`` or ``thDDPG`` instance, selected by
            ``self.args.algo`` (``'tftrpo'`` or ``'thddpg'``).

        Raises:
            NotImplementedError: for an unsupported ``baseline_type``,
                ``exp_strategy`` or ``algo``.
        """
        # thDDPG takes no baseline, so one is only built for other algorithms.
        if self.args.algo != 'thddpg':
            # Baseline
            if self.args.baseline_type == 'linear':
                baseline = LinearFeatureBaseline(env_spec=env.spec)
            elif self.args.baseline_type == 'zero':
                baseline = ZeroBaseline(env_spec=env.spec)
            else:
                raise NotImplementedError(self.args.baseline_type)

            if self.args.control == 'concurrent':
                # NOTE(review): every list entry is the same baseline
                # object, so all agents share it — confirm this is intended.
                baseline = [baseline for _ in range(len(env.agents))]

        # Logger: all outputs go below ``args.log_dir`` or, when that is not
        # given, below ``config.LOG_DIR/<exp_name>``.
        default_log_dir = config.LOG_DIR
        if self.args.log_dir is None:
            log_dir = osp.join(default_log_dir, self.args.exp_name)
        else:
            log_dir = self.args.log_dir

        tabular_log_file = osp.join(log_dir, self.args.tabular_log_file)
        text_log_file = osp.join(log_dir, self.args.text_log_file)
        params_log_file = osp.join(log_dir, self.args.params_log_file)

        logger.log_parameters_lite(params_log_file, self.args)
        logger.add_text_output(text_log_file)
        logger.add_tabular_output(tabular_log_file)
        # NOTE(review): the previous snapshot dir/mode are read but never
        # restored in this method.
        prev_snapshot_dir = logger.get_snapshot_dir()
        prev_mode = logger.get_snapshot_mode()
        logger.set_snapshot_dir(log_dir)
        logger.set_snapshot_mode(self.args.snapshot_mode)
        logger.set_log_tabular_only(self.args.log_tabular_only)
        logger.push_prefix("[%s] " % self.args.exp_name)

        if self.args.algo == 'tftrpo':
            algo = MATRPO(
                env=env,
                policy_or_policies=policy,
                baseline_or_baselines=baseline,
                batch_size=self.args.batch_size,
                start_itr=start_itr,
                max_path_length=self.args.max_path_length,
                n_itr=self.args.n_iter,
                discount=self.args.discount,
                gae_lambda=self.args.gae_lambda,
                step_size=self.args.step_size,
                # A finite-difference HVP optimizer is only supplied for
                # recurrent policies; otherwise the default is used.
                optimizer=ConjugateGradientOptimizer(
                    hvp_approach=FiniteDifferenceHvp(
                        base_eps=1e-5)) if self.args.recurrent else None,
                ma_mode=self.args.control)
        elif self.args.algo == 'thddpg':
            qfunc = thContinuousMLPQFunction(env_spec=env.spec)
            if self.args.exp_strategy == 'ou':
                es = OUStrategy(env_spec=env.spec)
            elif self.args.exp_strategy == 'gauss':
                es = GaussianStrategy(env_spec=env.spec)
            else:
                raise NotImplementedError()

            algo = thDDPG(env=env,
                          policy=policy,
                          qf=qfunc,
                          es=es,
                          batch_size=self.args.batch_size,
                          max_path_length=self.args.max_path_length,
                          epoch_length=self.args.epoch_length,
                          min_pool_size=self.args.min_pool_size,
                          replay_pool_size=self.args.replay_pool_size,
                          n_epochs=self.args.n_iter,
                          discount=self.args.discount,
                          scale_reward=0.01,
                          qf_learning_rate=self.args.qfunc_lr,
                          policy_learning_rate=self.args.policy_lr,
                          eval_samples=self.args.eval_samples,
                          mode=self.args.control)
        else:
            # Previously fell through to ``return algo`` with ``algo``
            # unbound (UnboundLocalError); report the unsupported name.
            raise NotImplementedError(self.args.algo)
        return algo
예제 #18
0
    parser.add_argument('--max_path_length', type=int, default=1000,
                        help='Max length of rollout')
    parser.add_argument('--speedup', type=float, default=1,
                        help='Speedup')
    parser.add_argument('--record_gym', type=bool, default=True,
                        help='Record video and log for gym environment.')
    args = parser.parse_args()

    # If the snapshot file use tensorflow, do:
    # import tensorflow as tf
    # with tf.Session():
    #     [rest of the code]
    while True:
        with tf.Session() as sess:
            #with tf.variable_scope("load_policy"):
            #with tf.variable_scope("load_policy", reuse=True):
                data = joblib.load(args.file)
                policy = data['policy']
                env = data['env']
                if args.record_gym:
                    from sandbox.rocky.tf.launchers.launcher_utils import get_env
                    import rllab.misc.logger as logger
                    import os.path as osp
                    logger.set_snapshot_dir(osp.dirname(args.file))
                    env = get_env(env_name = env._wrapped_env._wrapped_env.env_id,
                        record_video=True, record_log=True)
                while True:
                    path = rollout(env, policy, max_path_length=args.max_path_length,
                                   animated=True, always_return_paths=True, speedup=args.speedup)
                    print("Path length=%d, reward=%f"%(len(path["rewards"]),path["rewards"].sum()))
예제 #19
0
def main():
    """Train a TRPO policy on ``ContinuousHostageWorld`` from CLI options.

    Parses the command line, initialises the parallel sampler (and seeds it
    when ``--seed`` is given), builds the environment, policy and baseline,
    configures the rllab logger to write into the experiment's log directory
    and finally runs ``TRPO.train()``.
    """
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')
    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)

    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_name',
                        type=str,
                        default=default_exp_name,
                        help='Name of the experiment.')

    parser.add_argument('--discount', type=float, default=0.95)
    parser.add_argument('--gae_lambda', type=float, default=0.99)

    parser.add_argument('--n_iter', type=int, default=250)
    parser.add_argument('--sampler_workers', type=int, default=1)
    parser.add_argument('--max_traj_len', type=int, default=250)
    parser.add_argument('--update_curriculum',
                        action='store_true',
                        default=False)
    parser.add_argument('--n_timesteps', type=int, default=8000)
    # Registered exactly once: a second add_argument('--control', ...) makes
    # argparse raise a "conflicting option string" error.
    parser.add_argument('--control', type=str, default='centralized')

    parser.add_argument('--buffer_size', type=int, default=1)
    parser.add_argument('--n_good', type=int, default=3)
    parser.add_argument('--n_hostage', type=int, default=5)
    parser.add_argument('--n_bad', type=int, default=5)
    parser.add_argument('--n_coop_save', type=int, default=2)
    parser.add_argument('--n_coop_avoid', type=int, default=2)
    parser.add_argument('--n_sensors', type=int, default=20)
    parser.add_argument('--sensor_range', type=float, default=0.2)
    parser.add_argument('--save_reward', type=float, default=3)
    parser.add_argument('--hit_reward', type=float, default=-1)
    parser.add_argument('--encounter_reward', type=float, default=0.01)
    parser.add_argument('--bomb_reward', type=float, default=-10.)

    parser.add_argument('--recurrent', action='store_true', default=False)
    parser.add_argument('--baseline_type', type=str, default='linear')
    parser.add_argument('--policy_hidden_sizes', type=str, default='128,128')
    parser.add_argument('--baselin_hidden_sizes', type=str, default='128,128')

    parser.add_argument('--max_kl', type=float, default=0.01)

    parser.add_argument('--log_dir', type=str, required=False)
    parser.add_argument('--tabular_log_file',
                        type=str,
                        default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file',
                        type=str,
                        default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file',
                        type=str,
                        default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data',
                        type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--snapshot_mode',
                        type=str,
                        default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                        '(all iterations will be saved), "last" (only '
                        'the last iteration will be saved), or "none" '
                        '(do not save snapshots)')
    parser.add_argument(
        '--log_tabular_only',
        type=ast.literal_eval,
        default=False,
        help=
        'Whether to only print the tabular log information (in a horizontal format)'
    )

    args = parser.parse_args()

    parallel_sampler.initialize(n_parallel=args.sampler_workers)

    if args.seed is not None:
        set_seed(args.seed)
        parallel_sampler.set_seed(args.seed)

    # "128,128" -> (128, 128)
    args.hidden_sizes = tuple(map(int, args.policy_hidden_sizes.split(',')))

    # --sensor_range is parsed as a single float and handed to the environment
    # as-is; the earlier per-agent split/assert operated on a float and on a
    # non-existent ``n_pursuers`` option, and its result was never used.
    env = ContinuousHostageWorld(args.n_good,
                                 args.n_hostage,
                                 args.n_bad,
                                 args.n_coop_save,
                                 args.n_coop_avoid,
                                 n_sensors=args.n_sensors,
                                 sensor_range=args.sensor_range,
                                 save_reward=args.save_reward,
                                 hit_reward=args.hit_reward,
                                 encounter_reward=args.encounter_reward,
                                 bomb_reward=args.bomb_reward)

    env = RLLabEnv(StandardizedEnv(env), mode=args.control)

    if args.buffer_size > 1:
        env = ObservationBuffer(env, args.buffer_size)

    if args.recurrent:
        policy = GaussianGRUPolicy(env_spec=env.spec,
                                   hidden_sizes=args.hidden_sizes)
    else:
        policy = GaussianMLPPolicy(env_spec=env.spec,
                                   hidden_sizes=args.hidden_sizes)

    if args.baseline_type == 'linear':
        baseline = LinearFeatureBaseline(env_spec=env.spec)
    else:
        # Built from the env spec, like the linear baseline above.
        baseline = ZeroBaseline(env_spec=env.spec)

    # logger
    default_log_dir = config.LOG_DIR
    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    logger.log_parameters_lite(params_log_file, args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    algo = TRPO(
        env=env,
        policy=policy,
        baseline=baseline,
        batch_size=args.n_timesteps,
        max_path_length=args.max_traj_len,
        n_itr=args.n_iter,
        discount=args.discount,
        step_size=args.max_kl,
        mode=args.control,
    )

    algo.train()
예제 #20
0
def run_experiment(argv):
    """Run a pickled stub experiment described by command-line arguments.

    Args:
        argv: Full argument vector; ``argv[0]`` is skipped and the rest is
            parsed with argparse.

    The function initialises the exploration parallel sampler, redirects the
    rllab logger into ``--log_dir``, concretizes the object decoded from
    ``--args_data`` (draining it when it is iterable) and then restores the
    logger's previous snapshot settings, outputs and prefix. The restore
    step also runs when the experiment raises.
    """

    default_log_dir = config.LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')

    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
    parser = argparse.ArgumentParser()
    parser.add_argument('--n_parallel', type=int, default=1,
                        help='Number of parallel workers to perform rollouts.')
    parser.add_argument(
        '--exp_name', type=str, default=default_exp_name, help='Name of the experiment.')
    parser.add_argument('--log_dir', type=str, default=default_log_dir,
                        help='Path to save the log and iteration snapshot.')
    parser.add_argument('--snapshot_mode', type=str, default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                             '(all iterations will be saved), "last" (only '
                             'the last iteration will be saved), or "none" '
                             '(do not save snapshots)')
    parser.add_argument('--tabular_log_file', type=str, default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file', type=str, default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file', type=str, default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--plot', type=ast.literal_eval, default=False,
                        help='Whether to plot the iteration results')
    parser.add_argument('--log_tabular_only', type=ast.literal_eval, default=False,
                        help='Whether to only print the tabular log information (in a horizontal format)')
    parser.add_argument('--seed', type=int,
                        help='Random seed for numpy')
    parser.add_argument('--args_data', type=str,
                        help='Pickled data for stub objects')

    args = parser.parse_args(argv[1:])

    from sandbox.vime.sampler import parallel_sampler_expl as parallel_sampler
    parallel_sampler.initialize(n_parallel=args.n_parallel)

    if args.seed is not None:
        set_seed(args.seed)
        parallel_sampler.set_seed(args.seed)

    if args.plot:
        from rllab.plotter import plotter
        plotter.init_worker()

    # SECURITY NOTE: unpickling executes arbitrary code; --args_data must only
    # ever come from a trusted launcher.
    data = pickle.loads(base64.b64decode(args.args_data))

    log_dir = args.log_dir
    # exp_dir = osp.join(log_dir, args.exp_name)
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    logger.log_parameters_lite(params_log_file, args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    try:
        maybe_iter = concretize(data)
        if is_iterable(maybe_iter):
            # Drain the iterator; the work happens as a side effect.
            for _ in maybe_iter:
                pass
    finally:
        # Undo the logger changes even if the experiment failed, so the
        # outputs are detached and the previous settings come back.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
예제 #21
0
    def setup(self, env, policy, start_itr):

        # Baseline
        if self.args.baseline_type == 'linear':
            baseline = LinearFeatureBaseline(env_spec=env.spec)
        elif self.args.baseline_type == 'zero':
            baseline = ZeroBaseline(env_spec=env.spec)
        else:
            raise NotImplementedError(self.args.baseline_type)

        # Logger
        default_log_dir = config.LOG_DIR
        if self.args.log_dir is None:
            log_dir = osp.join(default_log_dir, self.args.exp_name)
        else:
            log_dir = self.args.log_dir

        tabular_log_file = osp.join(log_dir, self.args.tabular_log_file)
        text_log_file = osp.join(log_dir, self.args.text_log_file)
        params_log_file = osp.join(log_dir, self.args.params_log_file)

        logger.log_parameters_lite(params_log_file, self.args)
        logger.add_text_output(text_log_file)
        logger.add_tabular_output(tabular_log_file)
        # prev_snapshot_dir = logger.get_snapshot_dir()
        # prev_mode = logger.get_snapshot_mode()
        logger.set_snapshot_dir(log_dir)
        logger.set_snapshot_mode(self.args.snapshot_mode)
        logger.set_log_tabular_only(self.args.log_tabular_only)
        logger.push_prefix("[%s] " % self.args.exp_name)

        if self.args.algo == 'reinforce':
            algo = MAReinforce(env=env,
                               policy_or_policies=policy,
                               plot=False,
                               baseline_or_baselines=baseline,
                               batch_size=self.args.batch_size,
                               pause_for_plot=True,
                               start_itr=start_itr,
                               max_path_length=self.args.max_path_length,
                               n_itr=self.args.n_iter,
                               discount=self.args.discount,
                               gae_lambda=self.args.gae_lambda,
                               step_size=self.args.step_size,
                               ma_mode=self.args.control,
                               save_param_update=self.args.save_param_update)

        elif self.args.algo == 'dqn':
            algo = MADQN(env=env,
                         networks=policy,
                         plot=False,
                         batch_size=self.args.batch_size,
                         pause_for_plot=True,
                         start_itr=start_itr,
                         max_path_length=self.args.max_path_length,
                         n_itr=self.args.n_iter,
                         discount=self.args.discount,
                         ma_mode=self.args.control,
                         pre_trained_size=self.args.replay_pre_trained_size,
                         target_network_update=self.args.target_network_update,
                         save_param_update=self.args.save_param_update)

        elif self.args.algo == 'a2c':
            algo = MAA2C(env=env,
                         policy_or_policies=policy,
                         plot=False,
                         baseline_or_baselines=baseline,
                         batch_size=self.args.batch_size,
                         pause_for_plot=True,
                         start_itr=start_itr,
                         max_path_length=self.args.max_path_length,
                         n_itr=self.args.n_iter,
                         discount=self.args.discount,
                         ma_mode=self.args.control,
                         actor_learning_rate=self.args.policy_lr,
                         critic_learning_rate=self.args.qfunc_lr,
                         value_coefficient=0.5,
                         entropy_coefficient=0.01,
                         clip_grads=0.5,
                         save_param_update=self.args.save_param_update)

        return algo
예제 #22
0
    def setup(self, env, policy, start_itr):

        if not self.args.algo == 'thddpg':
            # Baseline
            if self.args.baseline_type == 'linear':
                baseline = LinearFeatureBaseline(env_spec=env.spec)
            elif self.args.baseline_type == 'zero':
                baseline = ZeroBaseline(env_spec=env.spec)
            else:
                raise NotImplementedError(self.args.baseline_type)

            if self.args.control == 'concurrent':
                baseline = [baseline for _ in range(len(env.agents))]
        # Logger
        default_log_dir = config.LOG_DIR
        if self.args.log_dir is None:
            log_dir = osp.join(default_log_dir, self.args.exp_name)
        else:
            log_dir = self.args.log_dir

        tabular_log_file = osp.join(log_dir, self.args.tabular_log_file)
        text_log_file = osp.join(log_dir, self.args.text_log_file)
        params_log_file = osp.join(log_dir, self.args.params_log_file)

        logger.log_parameters_lite(params_log_file, self.args)
        logger.add_text_output(text_log_file)
        logger.add_tabular_output(tabular_log_file)
        prev_snapshot_dir = logger.get_snapshot_dir()
        prev_mode = logger.get_snapshot_mode()
        logger.set_snapshot_dir(log_dir)
        logger.set_snapshot_mode(self.args.snapshot_mode)
        logger.set_log_tabular_only(self.args.log_tabular_only)
        logger.push_prefix("[%s] " % self.args.exp_name)

        if self.args.algo == 'tftrpo':
            algo = MATRPO(env=env, policy_or_policies=policy, baseline_or_baselines=baseline,
                          batch_size=self.args.batch_size, start_itr=start_itr,
                          max_path_length=self.args.max_path_length, n_itr=self.args.n_iter,
                          discount=self.args.discount, gae_lambda=self.args.gae_lambda,
                          step_size=self.args.step_size, optimizer=ConjugateGradientOptimizer(
                              hvp_approach=FiniteDifferenceHvp(base_eps=1e-5)) if
                          self.args.recurrent else None, ma_mode=self.args.control)
        elif self.args.algo == 'thddpg':
            qfunc = thContinuousMLPQFunction(env_spec=env.spec)
            if self.args.exp_strategy == 'ou':
                es = OUStrategy(env_spec=env.spec)
            elif self.args.exp_strategy == 'gauss':
                es = GaussianStrategy(env_spec=env.spec)
            else:
                raise NotImplementedError()

            algo = thDDPG(env=env, policy=policy, qf=qfunc, es=es, batch_size=self.args.batch_size,
                          max_path_length=self.args.max_path_length,
                          epoch_length=self.args.epoch_length,
                          min_pool_size=self.args.min_pool_size,
                          replay_pool_size=self.args.replay_pool_size, n_epochs=self.args.n_iter,
                          discount=self.args.discount, scale_reward=0.01,
                          qf_learning_rate=self.args.qfunc_lr,
                          policy_learning_rate=self.args.policy_lr,
                          eval_samples=self.args.eval_samples, mode=self.args.control)
        return algo
예제 #23
0
def run_experiment(
    args_data,
    variant_data=None,
    seed=None,
    n_parallel=1,
    exp_name=None,
    log_dir=None,
    snapshot_mode='all',
    snapshot_gap=1,
    tabular_log_file='progress.csv',
    text_log_file='debug.log',
    params_log_file='params.json',
    variant_log_file='variant.json',
    resume_from=None,
    plot=False,
    log_tabular_only=False,
    log_debug_log_only=False,
):
    """Run an experiment in-process with the rllab logger pointed at its log dir.

    Args:
        args_data: Callable invoked as ``args_data(variant_data)`` unless
            ``resume_from`` is given.
        variant_data: Optional variant configuration; logged to
            ``variant_log_file`` when not None.
        seed: Optional seed passed to ``set_seed`` and the parallel sampler.
        n_parallel: Number of sampler workers; the sampler is only
            initialised when this is greater than zero.
        exp_name: Experiment name; a timestamped name is generated if None.
        log_dir: Log directory; defaults to ``config.LOG_DIR/exp_name``.
        snapshot_mode: Passed to ``logger.set_snapshot_mode``.
        snapshot_gap: Passed to ``logger.set_snapshot_gap``.
        tabular_log_file: File name (inside ``log_dir``) for tabular output.
        text_log_file: File name (inside ``log_dir``) for text output.
        params_log_file: Accepted for interface compatibility; this function
            does not write a parameter log.
        variant_log_file: File name (inside ``log_dir``) for the variant log.
        resume_from: Optional snapshot path; when set, the stored ``algo`` is
            loaded with joblib and trained instead of calling ``args_data``.
        plot: Whether to start the plotter worker.
        log_tabular_only: Passed to ``logger.set_log_tabular_only``.
        log_debug_log_only: Passed to ``logger.set_debug_log_only``.

    The logger's previous snapshot settings, outputs and prefix are restored
    on exit, including when the experiment raises.
    """
    default_log_dir = config.LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')

    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
    if exp_name is None:
        exp_name = default_exp_name

    if seed is not None:
        set_seed(seed)

    if n_parallel > 0:
        from rllab.sampler import parallel_sampler
        parallel_sampler.initialize(n_parallel=n_parallel)
        if seed is not None:
            parallel_sampler.set_seed(seed)

    if plot:
        from rllab.plotter import plotter
        plotter.init_worker()

    if log_dir is None:
        log_dir = osp.join(default_log_dir, exp_name)
    tabular_log_file = osp.join(log_dir, tabular_log_file)
    text_log_file = osp.join(log_dir, text_log_file)

    if variant_data is not None:
        variant_log_file = osp.join(log_dir, variant_log_file)
        logger.log_variant(variant_log_file, variant_data)

    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(snapshot_mode)
    logger.set_snapshot_gap(snapshot_gap)
    logger.set_log_tabular_only(log_tabular_only)
    logger.set_debug_log_only(log_debug_log_only)
    logger.push_prefix("[%s] " % exp_name)

    try:
        if resume_from is not None:
            data = joblib.load(resume_from)
            assert 'algo' in data
            algo = data['algo']
            algo.train()
        else:
            args_data(variant_data)
    finally:
        # Undo the logger changes even if the experiment failed, so the
        # outputs are detached and the previous settings come back.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
예제 #24
0
def run_experiment(argv):
    """Run (or resume) an experiment described by command-line arguments.

    Args:
        argv: Full argument vector; ``argv[0]`` is skipped and the rest is
            parsed with argparse.

    Depending on the options, the function either resumes training from the
    ``algo`` stored in ``--resume_from``, calls a cloudpickled callable with
    the decoded variant data, or concretizes the pickled stub object from
    ``--args_data`` (draining it when it is iterable). The rllab logger is
    redirected into the log directory for the duration of the run and its
    previous state is restored afterwards, including when the run raises.
    """
    default_log_dir = config.LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')

    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
    parser = argparse.ArgumentParser()
    parser.add_argument('--n_parallel', type=int, default=1,
                        help='Number of parallel workers to perform rollouts. 0 => don\'t start any workers')
    parser.add_argument(
        '--exp_name', type=str, default=default_exp_name, help='Name of the experiment.')
    parser.add_argument('--log_dir', type=str, default=None,
                        help='Path to save the log and iteration snapshot.')
    parser.add_argument('--snapshot_mode', type=str, default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                             '(all iterations will be saved), "last" (only '
                             'the last iteration will be saved), "gap" (every '
                             '`snapshot_gap` iterations are saved), or "none" '
                             '(do not save snapshots)')
    parser.add_argument('--snapshot_gap', type=int, default=1,
                        help='Gap between snapshot iterations.')
    parser.add_argument('--tabular_log_file', type=str, default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file', type=str, default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file', type=str, default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--variant_log_file', type=str, default='variant.json',
                        help='Name of the variant log file (in json).')
    parser.add_argument('--resume_from', type=str, default=None,
                        help='Name of the pickle file to resume experiment from.')
    parser.add_argument('--plot', type=ast.literal_eval, default=False,
                        help='Whether to plot the iteration results')
    parser.add_argument('--log_tabular_only', type=ast.literal_eval, default=False,
                        help='Whether to only print the tabular log information (in a horizontal format)')
    parser.add_argument('--seed', type=int,
                        help='Random seed for numpy')
    parser.add_argument('--args_data', type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--variant_data', type=str,
                        help='Pickled data for variant configuration')
    parser.add_argument('--use_cloudpickle', type=ast.literal_eval, default=False)

    args = parser.parse_args(argv[1:])

    if args.seed is not None:
        set_seed(args.seed)

    if args.n_parallel > 0:
        from rllab.sampler import parallel_sampler
        parallel_sampler.initialize(n_parallel=args.n_parallel)
        if args.seed is not None:
            parallel_sampler.set_seed(args.seed)

    if args.plot:
        from rllab.plotter import plotter
        plotter.init_worker()

    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    # SECURITY NOTE: the pickle / cloudpickle loads in this function execute
    # arbitrary code; the encoded arguments must come from a trusted launcher.
    if args.variant_data is not None:
        variant_data = pickle.loads(base64.b64decode(args.variant_data))
        variant_log_file = osp.join(log_dir, args.variant_log_file)
        logger.log_variant(variant_log_file, variant_data)
    else:
        variant_data = None

    if not args.use_cloudpickle:
        logger.log_parameters_lite(params_log_file, args)

    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_snapshot_gap(args.snapshot_gap)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    try:
        if args.resume_from is not None:
            data = joblib.load(args.resume_from)
            assert 'algo' in data
            algo = data['algo']
            algo.train()
        else:
            # read from stdin
            if args.use_cloudpickle:
                import cloudpickle
                method_call = cloudpickle.loads(base64.b64decode(args.args_data))
                method_call(variant_data)
            else:
                data = pickle.loads(base64.b64decode(args.args_data))
                maybe_iter = concretize(data)
                if is_iterable(maybe_iter):
                    # Drain the iterator; the work happens as a side effect.
                    for _ in maybe_iter:
                        pass
    finally:
        # Undo the logger changes even if the experiment failed, so the
        # outputs are detached and the previous settings come back.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
예제 #25
0
def main():
    """Train a TRPO policy on the ``PursuitEvade`` environment from CLI options.

    Parses the command line, initialises the parallel sampler (and seeds it
    when ``--seed`` is given), then either loads ``policy`` and ``env`` from
    ``--checkpoint`` or builds them from the map/environment/policy options.
    Finally it configures the rllab logger and runs ``TRPO.train()``.

    Raises:
        NotImplementedError: If ``--map_type`` is neither 'rectangle' nor
            'complex', or ``--recurrent`` is set to something other than
            'gru' or 'lstm'.
    """
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')
    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)

    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_name',
                        type=str,
                        default=default_exp_name,
                        help='Name of the experiment.')

    parser.add_argument('--discount', type=float, default=0.99)
    parser.add_argument('--gae_lambda', type=float, default=1.0)
    parser.add_argument('--reward_scale', type=float, default=1.0)

    parser.add_argument('--n_iter', type=int, default=250)
    parser.add_argument('--sampler_workers', type=int, default=1)
    parser.add_argument('--max_traj_len', type=int, default=250)
    parser.add_argument('--update_curriculum',
                        action='store_true',
                        default=False)
    parser.add_argument('--n_timesteps', type=int, default=8000)
    parser.add_argument('--control', type=str, default='centralized')

    parser.add_argument('--rectangle', type=str, default='10,10')
    parser.add_argument('--map_type', type=str, default='rectangle')
    parser.add_argument('--n_evaders', type=int, default=5)
    parser.add_argument('--n_pursuers', type=int, default=2)
    parser.add_argument('--obs_range', type=int, default=3)
    parser.add_argument('--n_catch', type=int, default=2)
    parser.add_argument('--urgency', type=float, default=0.0)
    parser.add_argument('--pursuit', dest='train_pursuit', action='store_true')
    parser.add_argument('--evade', dest='train_pursuit', action='store_false')
    parser.set_defaults(train_pursuit=True)
    parser.add_argument('--surround', action='store_true', default=False)
    parser.add_argument('--constraint_window', type=float, default=1.0)
    parser.add_argument('--sample_maps', action='store_true', default=False)
    parser.add_argument('--map_file', type=str, default='../maps/map_pool.npy')
    parser.add_argument('--flatten', action='store_true', default=False)
    parser.add_argument('--reward_mech', type=str, default='global')
    parser.add_argument('--catchr', type=float, default=0.1)
    parser.add_argument('--term_pursuit', type=float, default=5.0)

    parser.add_argument('--recurrent', type=str, default=None)
    parser.add_argument('--policy_hidden_sizes', type=str, default='128,128')
    parser.add_argument('--baselin_hidden_sizes', type=str, default='128,128')
    parser.add_argument('--baseline_type', type=str, default='linear')

    parser.add_argument('--conv', action='store_true', default=False)

    parser.add_argument('--max_kl', type=float, default=0.01)

    parser.add_argument('--checkpoint', type=str, default=None)

    parser.add_argument('--log_dir', type=str, required=False)
    parser.add_argument('--tabular_log_file',
                        type=str,
                        default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file',
                        type=str,
                        default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file',
                        type=str,
                        default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data',
                        type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--snapshot_mode',
                        type=str,
                        default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                        '(all iterations will be saved), "last" (only '
                        'the last iteration will be saved), or "none" '
                        '(do not save snapshots)')
    parser.add_argument(
        '--log_tabular_only',
        type=ast.literal_eval,
        default=False,
        help=
        'Whether to only print the tabular log information (in a horizontal format)'
    )

    args = parser.parse_args()

    parallel_sampler.initialize(n_parallel=args.sampler_workers)

    if args.seed is not None:
        set_seed(args.seed)
        parallel_sampler.set_seed(args.seed)

    # "128,128" -> (128, 128)
    args.hidden_sizes = tuple(map(int, args.policy_hidden_sizes.split(',')))

    if args.checkpoint:
        # Reuse the policy and environment stored in an earlier snapshot.
        with tf.Session() as sess:
            data = joblib.load(args.checkpoint)
            policy = data['policy']
            env = data['env']
    else:
        if args.sample_maps:
            map_pool = np.load(args.map_file)
        else:
            # --rectangle is "<int>,<int>", unpacked as the map arguments.
            if args.map_type == 'rectangle':
                env_map = TwoDMaps.rectangle_map(
                    *map(int, args.rectangle.split(',')))
            elif args.map_type == 'complex':
                env_map = TwoDMaps.complex_map(
                    *map(int, args.rectangle.split(',')))
            else:
                raise NotImplementedError()
            map_pool = [env_map]

        env = PursuitEvade(map_pool,
                           n_evaders=args.n_evaders,
                           n_pursuers=args.n_pursuers,
                           obs_range=args.obs_range,
                           n_catch=args.n_catch,
                           train_pursuit=args.train_pursuit,
                           urgency_reward=args.urgency,
                           surround=args.surround,
                           sample_maps=args.sample_maps,
                           constraint_window=args.constraint_window,
                           flatten=args.flatten,
                           reward_mech=args.reward_mech,
                           catchr=args.catchr,
                           term_pursuit=args.term_pursuit)

        env = TfEnv(
            RLLabEnv(StandardizedEnv(env,
                                     scale_reward=args.reward_scale,
                                     enable_obsnorm=False),
                     mode=args.control))

        if args.recurrent:
            if args.conv:
                feature_network = ConvNetwork(
                    name='feature_net',
                    input_shape=env.spec.observation_space.shape,
                    output_dim=5,
                    conv_filters=(16, 32, 32),
                    conv_filter_sizes=(3, 3, 3),
                    conv_strides=(1, 1, 1),
                    conv_pads=('VALID', 'VALID', 'VALID'),
                    hidden_sizes=(64, ),
                    hidden_nonlinearity=tf.nn.relu,
                    output_nonlinearity=tf.nn.softmax)
            else:
                feature_network = MLP(
                    name='feature_net',
                    input_shape=(env.spec.observation_space.flat_dim +
                                 env.spec.action_space.flat_dim, ),
                    output_dim=5,
                    hidden_sizes=(256, 128, 64),
                    hidden_nonlinearity=tf.nn.tanh,
                    output_nonlinearity=None)
            # NOTE(review): int() needs --policy_hidden_sizes to be a single
            # integer here; the default '128,128' raises ValueError.
            if args.recurrent == 'gru':
                policy = CategoricalGRUPolicy(env_spec=env.spec,
                                              feature_network=feature_network,
                                              hidden_dim=int(
                                                  args.policy_hidden_sizes),
                                              name='policy')
            elif args.recurrent == 'lstm':
                policy = CategoricalLSTMPolicy(env_spec=env.spec,
                                               feature_network=feature_network,
                                               hidden_dim=int(
                                                   args.policy_hidden_sizes),
                                               name='policy')
            else:
                # Fail with a clear message instead of leaving ``policy``
                # unbound for the code below.
                raise NotImplementedError(args.recurrent)
        elif args.conv:
            feature_network = ConvNetwork(
                name='feature_net',
                input_shape=env.spec.observation_space.shape,
                output_dim=5,
                conv_filters=(8, 16),
                conv_filter_sizes=(3, 3),
                conv_strides=(2, 1),
                conv_pads=('VALID', 'VALID'),
                hidden_sizes=(32, ),
                hidden_nonlinearity=tf.nn.relu,
                output_nonlinearity=tf.nn.softmax)
            policy = CategoricalMLPPolicy(name='policy',
                                          env_spec=env.spec,
                                          prob_network=feature_network)
        else:
            policy = CategoricalMLPPolicy(name='policy',
                                          env_spec=env.spec,
                                          hidden_sizes=args.hidden_sizes)

    if args.baseline_type == 'linear':
        baseline = LinearFeatureBaseline(env_spec=env.spec)
    else:
        baseline = ZeroBaseline(env_spec=env.spec)

    # logger
    default_log_dir = config.LOG_DIR
    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    logger.log_parameters_lite(params_log_file, args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    # Recurrent policies get a finite-difference HVP optimizer; otherwise the
    # algorithm's default optimizer is used.
    algo = TRPO(
        env=env,
        policy=policy,
        baseline=baseline,
        batch_size=args.n_timesteps,
        max_path_length=args.max_traj_len,
        n_itr=args.n_iter,
        discount=args.discount,
        gae_lambda=args.gae_lambda,
        step_size=args.max_kl,
        optimizer=ConjugateGradientOptimizer(hvp_approach=FiniteDifferenceHvp(
            base_eps=1e-5)) if args.recurrent else None,
        mode=args.control,
    )

    algo.train()
예제 #26
0
def run_experiment(argv):
    """Parse experiment options from ``argv`` and run (or resume) one experiment.

    The function seeds the RNGs when ``--seed`` is given, optionally starts
    the parallel sampler and the plotter worker, points the global ``logger``
    at the experiment's log directory, and then either resumes a saved
    algorithm (``--resume_from``) or executes the serialized call passed in
    ``--args_data``.

    The logger's previous snapshot mode/dir are restored and the outputs and
    prefix added here are removed when the experiment finishes, including
    when it terminates with an exception.

    Args:
        argv: Full argument vector; ``argv[0]`` is skipped and the remainder
            is handed to ``argparse``.
    """
    # e2crawfo: These imports, in this order, were necessary for fixing issues on cedar.
    import rllab.mujoco_py.mjlib
    import tensorflow

    default_log_dir = config.LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')

    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--n_parallel',
        type=int,
        default=1,
        help=
        'Number of parallel workers to perform rollouts. 0 => don\'t start any workers'
    )
    parser.add_argument('--exp_name',
                        type=str,
                        default=default_exp_name,
                        help='Name of the experiment.')
    parser.add_argument('--log_dir',
                        type=str,
                        default=None,
                        help='Path to save the log and iteration snapshot.')
    parser.add_argument('--snapshot_mode',
                        type=str,
                        default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                        '(all iterations will be saved), "last" (only '
                        'the last iteration will be saved), or "none" '
                        '(do not save snapshots)')
    parser.add_argument('--snapshot_gap',
                        type=int,
                        default=1,
                        help='Gap between snapshot iterations.')
    parser.add_argument('--tabular_log_file',
                        type=str,
                        default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file',
                        type=str,
                        default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file',
                        type=str,
                        default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--variant_log_file',
                        type=str,
                        default='variant.json',
                        help='Name of the variant log file (in json).')
    parser.add_argument(
        '--resume_from',
        type=str,
        default=None,
        help='Name of the pickle file to resume experiment from.')
    parser.add_argument('--plot',
                        type=ast.literal_eval,
                        default=False,
                        help='Whether to plot the iteration results')
    parser.add_argument(
        '--log_tabular_only',
        type=ast.literal_eval,
        default=False,
        help=
        'Whether to only print the tabular log information (in a horizontal format)'
    )
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data',
                        type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--variant_data',
                        type=str,
                        help='Pickled data for variant configuration')
    parser.add_argument('--use_cloudpickle',
                        type=ast.literal_eval,
                        default=False)

    args = parser.parse_args(argv[1:])

    if args.seed is not None:
        set_seed(args.seed)

    if args.n_parallel > 0:
        from rllab.sampler import parallel_sampler
        parallel_sampler.initialize(n_parallel=args.n_parallel)
        if args.seed is not None:
            parallel_sampler.set_seed(args.seed)

    if args.plot:
        from rllab.plotter import plotter
        plotter.init_worker()

    # Without an explicit --log_dir, logs go to <config.LOG_DIR>/<exp_name>.
    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    # NOTE(security): --variant_data / --args_data are unpickled below, and
    # unpickling can execute arbitrary code. Only invoke this entry point
    # with payloads produced by a trusted launcher.
    if args.variant_data is not None:
        variant_data = pickle.loads(base64.b64decode(args.variant_data))
        variant_log_file = osp.join(log_dir, args.variant_log_file)
        logger.log_variant(variant_log_file, variant_data)
    else:
        variant_data = None

    if not args.use_cloudpickle:
        logger.log_parameters_lite(params_log_file, args)

    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    # Remember the logger's snapshot settings so they can be restored below.
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_tf_summary_dir(osp.join(log_dir, "tf_summary"))
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_snapshot_gap(args.snapshot_gap)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    try:
        if args.resume_from is not None:
            data = joblib.load(args.resume_from)
            assert 'algo' in data
            algo = data['algo']
            maybe_iter = algo.train()
            # train() may return an iterable; drain it so every step runs.
            if is_iterable(maybe_iter):
                for _ in maybe_iter:
                    pass
        elif args.use_cloudpickle:
            # The call to execute is carried in --args_data (base64-encoded).
            import cloudpickle
            method_call = cloudpickle.loads(base64.b64decode(args.args_data))
            method_call(variant_data)
        else:
            # The stub description is carried in --args_data (base64-encoded).
            data = pickle.loads(base64.b64decode(args.args_data))
            maybe_iter = concretize(data)
            if is_iterable(maybe_iter):
                for _ in maybe_iter:
                    pass
    finally:
        # Undo the logger changes made above even if the experiment raised,
        # so a failed run does not leave its outputs/prefix attached.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
예제 #27
0
def run_experiment(argv):
    """Parse experiment options from ``argv`` and run (or resume) one experiment.

    The function seeds the RNGs when ``--seed`` is given, optionally starts
    the parallel sampler and the plotter worker, directs the global
    ``logger`` (text, tabular, tensorboard, checkpoint and observation
    outputs) to the experiment's log directory, records the git commit, a git
    diff and host information, and then either resumes a saved algorithm
    (``--resume_from``) or executes the serialized call passed in
    ``--args_data``.

    The logger's previous snapshot mode/dir are restored and the outputs and
    prefix added here are removed when the experiment finishes, including
    when it terminates with an exception.

    Args:
        argv: Full argument vector; ``argv[0]`` is skipped and the remainder
            is handed to ``argparse``.
    """
    default_log_dir = config.LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')

    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--n_parallel',
        type=int,
        default=1,
        help=
        'Number of parallel workers to perform rollouts. 0 => don\'t start any workers'
    )
    parser.add_argument('--exp_name',
                        type=str,
                        default=default_exp_name,
                        help='Name of the experiment.')
    parser.add_argument('--log_dir',
                        type=str,
                        default=None,
                        help='Path to save the log and iteration snapshot.')
    # The trailing space after "every" keeps the concatenated help text from
    # rendering as "(every`snapshot_gap`".
    parser.add_argument('--snapshot_mode',
                        type=str,
                        default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                        '(all iterations will be saved), "last" (only '
                        'the last iteration will be saved), "gap" (every '
                        '`snapshot_gap` iterations are saved), or "none" '
                        '(do not save snapshots)')
    parser.add_argument('--snapshot_gap',
                        type=int,
                        default=1,
                        help='Gap between snapshot iterations.')
    parser.add_argument('--tabular_log_file',
                        type=str,
                        default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file',
                        type=str,
                        default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--tensorboard_log_dir',
                        type=str,
                        default='tb',
                        help='Name of the folder for tensorboard_summary.')
    parser.add_argument(
        '--tensorboard_step_key',
        type=str,
        default=None,
        help=
        'Name of the step key in log data which shows the step in tensorboard_summary.'
    )
    parser.add_argument('--params_log_file',
                        type=str,
                        default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--variant_log_file',
                        type=str,
                        default='variant.json',
                        help='Name of the variant log file (in json).')
    parser.add_argument(
        '--resume_from',
        type=str,
        default=None,
        help='Name of the pickle file to resume experiment from.')
    parser.add_argument('--plot',
                        type=ast.literal_eval,
                        default=False,
                        help='Whether to plot the iteration results')
    parser.add_argument(
        '--log_tabular_only',
        type=ast.literal_eval,
        default=False,
        help=
        'Whether to only print the tabular log information (in a horizontal format)'
    )
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data',
                        type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--variant_data',
                        type=str,
                        help='Pickled data for variant configuration')
    parser.add_argument('--use_cloudpickle',
                        type=ast.literal_eval,
                        default=False)
    parser.add_argument('--checkpoint_dir',
                        type=str,
                        default='checkpoint',
                        help='Name of the folder for checkpoints.')
    parser.add_argument('--obs_dir',
                        type=str,
                        default='obs',
                        help='Name of the folder for original observations.')

    args = parser.parse_args(argv[1:])

    if args.seed is not None:
        set_seed(args.seed)

    if args.n_parallel > 0:
        from rllab.sampler import parallel_sampler
        parallel_sampler.initialize(n_parallel=args.n_parallel)
        if args.seed is not None:
            parallel_sampler.set_seed(args.seed)

    if args.plot:
        from rllab.plotter import plotter
        plotter.init_worker()

    # Without an explicit --log_dir, logs go to <config.LOG_DIR>/<exp_name>.
    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)
    tensorboard_log_dir = osp.join(log_dir, args.tensorboard_log_dir)
    checkpoint_dir = osp.join(log_dir, args.checkpoint_dir)
    obs_dir = osp.join(log_dir, args.obs_dir)

    # NOTE(security): --variant_data / --args_data are unpickled below, and
    # unpickling can execute arbitrary code. Only invoke this entry point
    # with payloads produced by a trusted launcher.
    if args.variant_data is not None:
        variant_data = pickle.loads(base64.b64decode(args.variant_data))
        variant_log_file = osp.join(log_dir, args.variant_log_file)
        logger.log_variant(variant_log_file, variant_data)
    else:
        variant_data = None

    if not args.use_cloudpickle:
        logger.log_parameters_lite(params_log_file, args)

    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    logger.set_tensorboard_dir(tensorboard_log_dir)
    logger.set_checkpoint_dir(checkpoint_dir)
    logger.set_obs_dir(obs_dir)
    # Remember the logger's snapshot settings so they can be restored below.
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_snapshot_gap(args.snapshot_gap)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.set_tensorboard_step_key(args.tensorboard_step_key)
    logger.push_prefix("[%s] " % args.exp_name)

    try:
        # Record provenance (commit, uncommitted diff, host) alongside the run.
        git_commit = get_git_commit_hash()
        logger.log('Git commit: {}'.format(git_commit))

        git_diff_file_path = osp.join(log_dir,
                                      'git_diff_{}.patch'.format(git_commit))
        save_git_diff_to_file(git_diff_file_path)

        logger.log('hostname: {}, pid: {}, tmux session: {}'.format(
            socket.gethostname(), os.getpid(), get_tmux_session_name()))

        if args.resume_from is not None:
            data = joblib.load(args.resume_from)
            assert 'algo' in data
            algo = data['algo']
            algo.train()
        elif args.use_cloudpickle:
            # The call to execute is carried in --args_data (base64-encoded).
            import cloudpickle
            method_call = cloudpickle.loads(base64.b64decode(args.args_data))
            method_call(variant_data)
        else:
            # The stub description is carried in --args_data (base64-encoded).
            data = pickle.loads(base64.b64decode(args.args_data))
            maybe_iter = concretize(data)
            # concretize() may return an iterable; drain it so every step runs.
            if is_iterable(maybe_iter):
                for _ in maybe_iter:
                    pass
    finally:
        # Undo the logger changes made above even if the experiment raised,
        # so a failed run does not leave its outputs/prefix attached.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
예제 #28
0
def run_experiment(argv):
    """Parse experiment options from ``argv`` and run one stubbed experiment.

    The function seeds the RNGs when ``--seed`` is given, optionally starts
    the exploration parallel sampler and the plotter worker, unpickles the
    stub description from ``--args_data``, directs the global ``logger`` to
    ``--log_dir``, and executes the experiment via ``concretize``.

    The logger's previous snapshot mode/dir are restored and the outputs and
    prefix added here are removed when the experiment finishes, including
    when it terminates with an exception.

    Args:
        argv: Full argument vector; ``argv[0]`` is skipped and the remainder
            is handed to ``argparse``.
    """

    default_log_dir = config.LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')

    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
    parser = argparse.ArgumentParser()
    parser.add_argument('--n_parallel',
                        type=int,
                        default=1,
                        help='Number of parallel workers to perform rollouts.')
    parser.add_argument('--exp_name',
                        type=str,
                        default=default_exp_name,
                        help='Name of the experiment.')
    parser.add_argument('--log_dir',
                        type=str,
                        default=default_log_dir,
                        help='Path to save the log and iteration snapshot.')
    parser.add_argument('--snapshot_mode',
                        type=str,
                        default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                        '(all iterations will be saved), "last" (only '
                        'the last iteration will be saved), or "none" '
                        '(do not save snapshots)')
    parser.add_argument('--snapshot_gap',
                        type=int,
                        default=1,
                        help='Gap between snapshot iterations.')
    parser.add_argument('--tabular_log_file',
                        type=str,
                        default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file',
                        type=str,
                        default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file',
                        type=str,
                        default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--plot',
                        type=ast.literal_eval,
                        default=False,
                        help='Whether to plot the iteration results')
    parser.add_argument(
        '--log_tabular_only',
        type=ast.literal_eval,
        default=False,
        help=
        'Whether to only print the tabular log information (in a horizontal format)'
    )
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data',
                        type=str,
                        help='Pickled data for stub objects')
    # This runner always decodes --args_data with pickle; the flag is parsed
    # but not consulted when loading the payload.
    parser.add_argument('--use_cloudpickle',
                        type=ast.literal_eval,
                        default=False,
                        help='Whether args_data was serialized with '
                        'cloudpickle (accepted for compatibility; this '
                        'runner always loads args_data with pickle)')

    args = parser.parse_args(argv[1:])

    if args.seed is not None:
        set_seed(args.seed)

    if args.n_parallel > 0:
        from sandbox.vase.sampler import parallel_sampler_expl as parallel_sampler
        parallel_sampler.initialize(n_parallel=args.n_parallel)

        if args.seed is not None:
            # NOTE(review): the main process is seeded a second time here,
            # presumably to reset RNG state after worker initialization —
            # TODO confirm whether this repeat call is still needed.
            set_seed(args.seed)
            parallel_sampler.set_seed(args.seed)

    if args.plot:
        from rllab.plotter import plotter
        plotter.init_worker()

    # The stub description is carried in --args_data (base64-encoded).
    # NOTE(security): unpickling can execute arbitrary code. Only invoke this
    # entry point with payloads produced by a trusted launcher.
    data = pickle.loads(base64.b64decode(args.args_data))

    log_dir = args.log_dir
    # exp_dir = osp.join(log_dir, args.exp_name)
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    logger.log_parameters_lite(params_log_file, args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    # Remember the logger's snapshot settings so they can be restored below.
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_gap(args.snapshot_gap)
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    try:
        maybe_iter = concretize(data)
        # concretize() may return an iterable; drain it so every step runs.
        if is_iterable(maybe_iter):
            for _ in maybe_iter:
                pass
    finally:
        # Undo the logger changes made above even if the experiment raised,
        # so a failed run does not leave its outputs/prefix attached.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
예제 #29
0
# -*- coding: utf-8 -*-
import os
#trpo
import rllab.misc.logger as logger
from rllab.algos.trpo import TRPO
from rllab.baselines.linear_feature_baseline import LinearFeatureBaseline
#from rllab.envs.box2d.cartpole_env import CartpoleEnv
from rllab.envs.gym_env import GymEnv
from rllab.envs.normalized_env import normalize
from rllab.policies.gaussian_mlp_policy import GaussianMLPPolicy
#from rllab.policies.categorical_mlp_policy import CategoricalMLPPolicy

# Configure the rllab logger's snapshot output: snapshots are written under
# the 'snapShots16_8' directory using the 'last' snapshot mode.

logger.set_snapshot_dir('snapShots16_8')
logger.set_snapshot_mode('last')
#logger.set_snapshot_gap(100)

# Build the Gym 'DartWalker3d-v1' environment wrapped in rllab's normalize(),
# then call render() once on it before the policy is constructed.
env = normalize(GymEnv('DartWalker3d-v1'))
env.render()

# Gaussian MLP policy built from the env spec with hidden_sizes=(16, 8); the
# commented-out lines are alternative hidden-layer configurations.
policy16 = GaussianMLPPolicy(
    env_spec=env.spec,
    hidden_sizes=(16, 8)
    #hidden_sizes=(128,128,64)
    #hidden_sizes=(64, 32, 16)
    #hidden_sizes=(32, 32, 16)
)
# Linear-feature baseline built from the same env spec (presumably handed to
# the TRPO constructor that follows — confirm against the call below).
baseline = LinearFeatureBaseline(env_spec=env.spec)

algo = TRPO(env=env,