Ejemplo n.º 1
0
def shutdown_output():
    """Detach the rllab text and tabular outputs registered under OUTPUT_DIR.

    Removes ``<OUTPUT_DIR>/rllab.txt`` (text) and ``<OUTPUT_DIR>/rllab.csv``
    (tabular) from the logger, then resets the module-level OUTPUT_DIR to
    None so a new output directory can be set up.

    Raises:
        AssertionError: if OUTPUT_DIR is None, i.e. output was never set up
            (only while assertions are enabled).
    """
    global OUTPUT_DIR
    assert OUTPUT_DIR is not None, "Cannot shutdown output that has not been setup."
    out_dir = OUTPUT_DIR
    text_path = os.path.join(out_dir, "rllab.txt")
    logger.remove_text_output(text_path)
    csv_path = os.path.join(out_dir, "rllab.csv")
    logger.remove_tabular_output(csv_path)
    OUTPUT_DIR = None
Ejemplo n.º 2
0
def logger_context(log_dir,
                   name,
                   run_ID,
                   log_params=None,
                   snapshot_mode="none"):
    """Configure rllab logging for one run, yield, then tear the outputs down.

    The run's files live in ``<abs_log_dir>/<name>_<run_ID>/``:
    ``progress.csv`` (tabular output), ``debug.log`` (text output) and
    ``params.json`` (``log_params`` plus ``name`` and ``run_ID``). If
    ``log_dir`` does not resolve to a path inside ``LOG_DIR``, it is re-rooted
    via ``make_log_dir``.

    This is a generator. It is presumably wrapped with
    ``contextlib.contextmanager`` at the definition site; the decorator is not
    visible here, so confirm. Teardown (removing outputs, popping the prefix)
    runs in a ``finally`` so it also happens when the managed body raises.

    Args:
        log_dir: directory for this run's logs (relative or absolute).
        name: experiment name, used in the run directory and log prefix.
        run_ID: run identifier, used in the run directory and log prefix.
        log_params: optional mapping dumped to ``params.json``. It is copied,
            never modified.
        snapshot_mode: forwarded to ``logger.set_snapshot_mode``.
    """
    logger.set_snapshot_mode(snapshot_mode)
    logger.set_log_tabular_only(False)
    abs_log_dir = os.path.abspath(log_dir)
    if LOG_DIR != os.path.commonpath([abs_log_dir, LOG_DIR]):
        print(
            "logger_context received log_dir outside of rllab.config.LOG_DIR: "
            "prepending by {}/local/<yyyymmdd>/".format(LOG_DIR))
        abs_log_dir = make_log_dir(log_dir)
    exp_dir = os.path.join(abs_log_dir, "{}_{}".format(name, run_ID))
    tabular_log_file = os.path.join(exp_dir, "progress.csv")
    text_log_file = os.path.join(exp_dir, "debug.log")
    params_log_file = os.path.join(exp_dir, "params.json")

    logger.set_snapshot_dir(exp_dir)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    logger.push_prefix("{}_{} ".format(name, run_ID))

    try:
        # Work on a copy so adding name/run_ID does not mutate the caller's dict.
        params = dict(log_params) if log_params is not None else {}
        params["name"] = name
        params["run_ID"] = run_ID
        with open(params_log_file, "w") as f:
            json.dump(params, f)

        yield
    finally:
        # Always detach this run's outputs and prefix, even on error, so they
        # do not leak into later runs in the same process.
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
Ejemplo n.º 3
0
def eval_mab_policies(n_arms=4,
                      horizon=15,
                      n_traj=1000,
                      log_dir=None,
                      turntaking=False):
    """Evaluate human bandit policies on a Bernoulli multi-armed bandit.

    For each selected human policy, wraps a ``BanditEnv`` in either a
    ``HumanIterativeWrapper`` (``turntaking=True``) or a ``HumanCRLWrapper``.
    It then rolls out ``n_traj`` trajectories of at most ``horizon`` steps and
    logs summary statistics of ``info["accumulated rewards"]``.

    Args:
        n_arms: number of bandit arms.
        horizon: maximum number of steps per trajectory.
        n_traj: number of trajectories to sample per policy.
        log_dir: if given, log text is also written to ``<log_dir>/text``.
        turntaking: if True, the robot plays the arm the human has chosen
            most often so far (argmax of observed human-action counts).
            Otherwise it always plays action ``nA - 1``.
    """
    text_output_file = None if log_dir is None else osp.join(log_dir, "text")
    rag = uniform_bernoulli_iterator()
    bandit = BanditEnv(n_arms=n_arms,
                       reward_dist=bernoulli,
                       reward_args_generator=rag,
                       horizon=horizon)
    if text_output_file is not None:
        logger.add_text_output(text_output_file)

    try:
        # Only the 'ucl' policy is evaluated; iterate over
        # human_policy_dict.values() to evaluate every policy.
        for human_policy in [human_policy_dict['ucl']]:
            logger.log("-------------------")
            logger.log("Evaluating {} for {} timesteps".format(
                human_policy.__name__, horizon))
            logger.log("-------------------")

            test_pi_H = human_policy(bandit)
            if turntaking:
                test_env = HumanIterativeWrapper(bandit, test_pi_H)
            else:
                test_env = HumanCRLWrapper(bandit, test_pi_H, 0)

            logger.log("Obtaining Samples...")
            # Alas, the rllab samplers don't support hot swapping envs and batch sizes
            # TODO: write a new parallel sampler, instead of sampling manually
            rewards = []
            for _ in pyprind.prog_bar(range(n_traj)):
                if turntaking:
                    act_counts = [0] * bandit.nA
                test_env.reset()
                action = test_env.nA - 1
                for _step in range(horizon):
                    observation, _reward, done, info = test_env.step(action)
                    if turntaking:
                        # observation[1] carries the human's last action; values
                        # >= bandit.nA are not counted as an arm choice.
                        a_H = observation[1]
                        if a_H < bandit.nA:
                            act_counts[a_H] += 1
                        action = np.argmax(act_counts)
                    if done:
                        rewards.append(info["accumulated rewards"])
                        break

            logger.log("NumTrajs {}".format(n_traj))
            if not rewards:
                # np.max/np.min raise on an empty list; report instead of crashing.
                logger.log("No trajectory finished within the horizon; "
                           "no return statistics available")
                continue
            logger.log("AverageReturn {}".format(np.mean(rewards)))
            logger.log("StdReturn {}".format(np.std(rewards)))
            logger.log("MaxReturn {}".format(np.max(rewards)))
            logger.log("MinReturn {}".format(np.min(rewards)))
    finally:
        # Detach the text output even if evaluation raised.
        if text_output_file is not None:
            logger.remove_text_output(text_output_file)
Ejemplo n.º 4
0
def run_experiment(algo,
                   n_parallel=0,
                   seed=0,
                   plot=False,
                   log_dir=None,
                   exp_name=None,
                   snapshot_mode='last',
                   snapshot_gap=1,
                   exp_prefix='experiment',
                   log_tabular_only=False):
    """Train ``algo`` in-process with rllab logging configured for the run.

    Setup:
    - Seeds RNGs.
    - Optionally starts parallel sampler workers and the plotter.
    - Attaches ``progress.csv`` and ``debug.log`` in ``log_dir``.
    - Configures snapshotting and pushes an ``[exp_name]`` log prefix.

    It then calls ``algo.train()``. The previous snapshot mode and directory
    are restored, and the outputs and prefix removed, even if training raises.

    Args:
        algo: object exposing ``train()``.
        n_parallel: number of parallel sampler workers; 0 starts none.
        seed: seed passed to ``set_seed`` (and to the parallel sampler).
        plot: whether to start the plotter worker.
        log_dir: output directory; defaults to
            ``config.LOG_DIR + "/local/" + exp_prefix + "/" + exp_name``.
        exp_name: experiment name; defaults to a timestamped name.
        snapshot_mode: forwarded to ``logger.set_snapshot_mode``.
        snapshot_gap: forwarded to ``logger.set_snapshot_gap``.
        exp_prefix: sub-directory used for the default ``log_dir``.
        log_tabular_only: forwarded to ``logger.set_log_tabular_only``.
    """
    default_log_dir = config.LOG_DIR + "/local/" + exp_prefix
    set_seed(seed)
    if exp_name is None:
        now = datetime.datetime.now(dateutil.tz.tzlocal())
        timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')
        exp_name = 'experiment_%s' % (timestamp)
    if n_parallel > 0:
        from rllab.sampler import parallel_sampler
        parallel_sampler.initialize(n_parallel=n_parallel)
        parallel_sampler.set_seed(seed)
    if plot:
        from rllab.plotter import plotter
        plotter.init_worker()
    if log_dir is None:
        log_dir = osp.join(default_log_dir, exp_name)
    tabular_log_file = osp.join(log_dir, 'progress.csv')
    text_log_file = osp.join(log_dir, 'debug.log')

    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(snapshot_mode)
    logger.set_snapshot_gap(snapshot_gap)
    logger.set_log_tabular_only(log_tabular_only)
    logger.push_prefix("[%s] " % exp_name)

    try:
        algo.train()
    finally:
        # Undo the logger configuration even if training fails, so later runs
        # in this process do not write into this run's files or prefix.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
Ejemplo n.º 5
0
def find_optimal_epsilon(n_arms=4, horizon=15, n_traj=10000, log_dir=None):
    """Log returns of an epsilon-greedy human policy for each candidate epsilon.

    For every value in the module-level ``candidate_epsilons``, wraps a
    Bernoulli ``BanditEnv`` with ``EpsGreedyBanditPolicy`` in a
    ``HumanCRLWrapper``. It then samples ``n_traj`` trajectories of at most
    ``horizon`` steps, always taking action ``nA - 1``, and logs summary
    statistics of ``info["accumulated rewards"]``.

    Args:
        n_arms: number of bandit arms.
        horizon: maximum number of steps per trajectory.
        n_traj: number of trajectories sampled per epsilon.
        log_dir: if given, log text is also written to ``<log_dir>/text``.
    """
    text_output_file = None if log_dir is None else osp.join(log_dir, "text")
    rag = uniform_bernoulli_iterator()
    bandit = BanditEnv(n_arms=n_arms,
                       reward_dist=bernoulli,
                       reward_args_generator=rag,
                       horizon=horizon)
    if text_output_file is not None:
        logger.add_text_output(text_output_file)

    try:
        for epsilon in candidate_epsilons:
            logger.log("-------------------")
            logger.log("Evaluating epsilon={} for {} timesteps".format(
                epsilon, horizon))
            logger.log("-------------------")

            test_pi_H = EpsGreedyBanditPolicy(bandit, epsilon=epsilon)
            test_env = HumanCRLWrapper(bandit, test_pi_H, 0)
            logger.log("Obtaining Samples...")
            # Alas, the rllab samplers don't support hot swapping envs and batch sizes
            # TODO: write a new parallel sampler, instead of sampling manually
            rewards = []
            for _ in pyprind.prog_bar(range(n_traj)):
                test_env.reset()
                # The robot action is fixed for the whole trajectory.
                action = test_env.nA - 1
                for _step in range(horizon):
                    _obs, _reward, done, info = test_env.step(action)
                    if done:
                        rewards.append(info["accumulated rewards"])
                        break

            logger.log("NumTrajs {}".format(n_traj))
            if not rewards:
                # np.max/np.min raise on an empty list; report instead of crashing.
                logger.log("No trajectory finished within the horizon; "
                           "no return statistics available")
                continue
            logger.log("AverageReturn {}".format(np.mean(rewards)))
            logger.log("StdReturn {}".format(np.std(rewards)))
            logger.log("MaxReturn {}".format(np.max(rewards)))
            logger.log("MinReturn {}".format(np.min(rewards)))
    finally:
        # Detach the text output even if evaluation raised.
        if text_output_file is not None:
            logger.remove_text_output(text_output_file)
Ejemplo n.º 6
0
def run_experiment(argv):
    """Run one rllab experiment described by command-line arguments.

    Steps:
    1. Parse ``argv[1:]``, seed RNGs, and optionally start parallel sampler
       workers and the plotter.
    2. Configure rllab logging (text/tabular outputs, snapshots, TF summary
       dir, log prefix) under the log directory.
    3. Either resume training from a joblib snapshot (``--resume_from``), or
       unpickle and run the stub call passed in ``--args_data``.

    The logger's previous snapshot mode/dir are restored and this run's
    outputs/prefix removed afterwards, even if the run raises.

    Args:
        argv: full argument vector; ``argv[0]`` (program name) is skipped.

    Security: ``--args_data`` and ``--variant_data`` are base64-encoded
    pickles deserialized with (cloud)pickle, which can execute arbitrary
    code. Only pass data produced by a trusted launcher.
    """
    # e2crawfo: These imports, in this order, were necessary for fixing issues on cedar.
    import rllab.mujoco_py.mjlib
    import tensorflow

    default_log_dir = config.LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')

    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--n_parallel',
        type=int,
        default=1,
        help=
        'Number of parallel workers to perform rollouts. 0 => don\'t start any workers'
    )
    parser.add_argument('--exp_name',
                        type=str,
                        default=default_exp_name,
                        help='Name of the experiment.')
    parser.add_argument('--log_dir',
                        type=str,
                        default=None,
                        help='Path to save the log and iteration snapshot.')
    parser.add_argument('--snapshot_mode',
                        type=str,
                        default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                        '(all iterations will be saved), "last" (only '
                        'the last iteration will be saved), or "none" '
                        '(do not save snapshots)')
    parser.add_argument('--snapshot_gap',
                        type=int,
                        default=1,
                        help='Gap between snapshot iterations.')
    parser.add_argument('--tabular_log_file',
                        type=str,
                        default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file',
                        type=str,
                        default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file',
                        type=str,
                        default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--variant_log_file',
                        type=str,
                        default='variant.json',
                        help='Name of the variant log file (in json).')
    parser.add_argument(
        '--resume_from',
        type=str,
        default=None,
        help='Name of the pickle file to resume experiment from.')
    parser.add_argument('--plot',
                        type=ast.literal_eval,
                        default=False,
                        help='Whether to plot the iteration results')
    parser.add_argument(
        '--log_tabular_only',
        type=ast.literal_eval,
        default=False,
        help=
        'Whether to only print the tabular log information (in a horizontal format)'
    )
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data',
                        type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--variant_data',
                        type=str,
                        help='Pickled data for variant configuration')
    parser.add_argument('--use_cloudpickle',
                        type=ast.literal_eval,
                        default=False)

    args = parser.parse_args(argv[1:])
    if args.resume_from is None and args.args_data is None:
        # Without a snapshot to resume there is nothing to run; report a usage
        # error instead of crashing later with TypeError from b64decode(None).
        parser.error("--args_data is required unless --resume_from is given")

    if args.seed is not None:
        set_seed(args.seed)

    if args.n_parallel > 0:
        from rllab.sampler import parallel_sampler
        parallel_sampler.initialize(n_parallel=args.n_parallel)
        if args.seed is not None:
            parallel_sampler.set_seed(args.seed)

    if args.plot:
        from rllab.plotter import plotter
        plotter.init_worker()

    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    if args.variant_data is not None:
        variant_data = pickle.loads(base64.b64decode(args.variant_data))
        variant_log_file = osp.join(log_dir, args.variant_log_file)
        logger.log_variant(variant_log_file, variant_data)
    else:
        variant_data = None

    if not args.use_cloudpickle:
        logger.log_parameters_lite(params_log_file, args)

    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_tf_summary_dir(osp.join(log_dir, "tf_summary"))
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_snapshot_gap(args.snapshot_gap)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    try:
        if args.resume_from is not None:
            data = joblib.load(args.resume_from)
            assert 'algo' in data
            algo = data['algo']
            # train() may be a generator; drain it so training actually runs.
            maybe_iter = algo.train()
            if is_iterable(maybe_iter):
                for _ in maybe_iter:
                    pass
        else:
            # read from stdin
            if args.use_cloudpickle:
                import cloudpickle
                method_call = cloudpickle.loads(base64.b64decode(args.args_data))
                method_call(variant_data)
            else:
                data = pickle.loads(base64.b64decode(args.args_data))
                maybe_iter = concretize(data)
                if is_iterable(maybe_iter):
                    for _ in maybe_iter:
                        pass
    finally:
        # Restore the logger even if the run raised, so later runs in this
        # process do not write into this run's files or prefix.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
Ejemplo n.º 7
0
def run_experiment(argv):
    """Run one rllab experiment described by command-line arguments.

    Steps:
    1. Parse ``argv[1:]``, seed RNGs, and optionally start parallel sampler
       workers and the plotter.
    2. Configure rllab logging under the log directory.
    3. Either resume training from a joblib snapshot (``--resume_from``), or
       unpickle and run the stub call passed in ``--args_data``.

    The logger's previous snapshot mode/dir are restored and this run's
    outputs/prefix removed afterwards, even if the run raises.

    Args:
        argv: full argument vector; ``argv[0]`` (program name) is skipped.

    Security: ``--args_data`` and ``--variant_data`` are base64-encoded
    pickles deserialized with (cloud)pickle, which can execute arbitrary
    code. Only pass data produced by a trusted launcher.
    """
    default_log_dir = config.LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')

    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
    parser = argparse.ArgumentParser()
    parser.add_argument('--n_parallel', type=int, default=1,
                        help='Number of parallel workers to perform rollouts. 0 => don\'t start any workers')
    parser.add_argument(
        '--exp_name', type=str, default=default_exp_name, help='Name of the experiment.')
    parser.add_argument('--log_dir', type=str, default=None,
                        help='Path to save the log and iteration snapshot.')
    parser.add_argument('--snapshot_mode', type=str, default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                             '(all iterations will be saved), "last" (only '
                             'the last iteration will be saved), "gap" (every'
                             '`snapshot_gap` iterations are saved), or "none" '
                             '(do not save snapshots)')
    parser.add_argument('--snapshot_gap', type=int, default=1,
                        help='Gap between snapshot iterations.')
    parser.add_argument('--tabular_log_file', type=str, default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file', type=str, default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file', type=str, default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--variant_log_file', type=str, default='variant.json',
                        help='Name of the variant log file (in json).')
    parser.add_argument('--resume_from', type=str, default=None,
                        help='Name of the pickle file to resume experiment from.')
    parser.add_argument('--plot', type=ast.literal_eval, default=False,
                        help='Whether to plot the iteration results')
    parser.add_argument('--log_tabular_only', type=ast.literal_eval, default=False,
                        help='Whether to only print the tabular log information (in a horizontal format)')
    parser.add_argument('--seed', type=int,
                        help='Random seed for numpy')
    parser.add_argument('--args_data', type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--variant_data', type=str,
                        help='Pickled data for variant configuration')
    parser.add_argument('--use_cloudpickle', type=ast.literal_eval, default=False)

    args = parser.parse_args(argv[1:])
    if args.resume_from is None and args.args_data is None:
        # Without a snapshot to resume there is nothing to run; report a usage
        # error instead of crashing later with TypeError from b64decode(None).
        parser.error("--args_data is required unless --resume_from is given")

    if args.seed is not None:
        set_seed(args.seed)

    if args.n_parallel > 0:
        from rllab.sampler import parallel_sampler
        parallel_sampler.initialize(n_parallel=args.n_parallel)
        if args.seed is not None:
            parallel_sampler.set_seed(args.seed)

    if args.plot:
        from rllab.plotter import plotter
        plotter.init_worker()

    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    if args.variant_data is not None:
        variant_data = pickle.loads(base64.b64decode(args.variant_data))
        variant_log_file = osp.join(log_dir, args.variant_log_file)
        logger.log_variant(variant_log_file, variant_data)
    else:
        variant_data = None

    if not args.use_cloudpickle:
        logger.log_parameters_lite(params_log_file, args)

    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_snapshot_gap(args.snapshot_gap)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    try:
        if args.resume_from is not None:
            data = joblib.load(args.resume_from)
            assert 'algo' in data
            algo = data['algo']
            # train() may be a generator (as the concretize path below allows);
            # drain it, otherwise calling it alone would run no training.
            maybe_iter = algo.train()
            if is_iterable(maybe_iter):
                for _ in maybe_iter:
                    pass
        else:
            # read from stdin
            if args.use_cloudpickle:
                import cloudpickle
                method_call = cloudpickle.loads(base64.b64decode(args.args_data))
                method_call(variant_data)
            else:
                data = pickle.loads(base64.b64decode(args.args_data))
                maybe_iter = concretize(data)
                if is_iterable(maybe_iter):
                    for _ in maybe_iter:
                        pass
    finally:
        # Restore the logger even if the run raised, so later runs in this
        # process do not write into this run's files or prefix.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
Ejemplo n.º 8
0
def run_experiment(argv):
    """Run one experiment using the VIME exploration parallel sampler.

    Steps:
    1. Parse ``argv[1:]`` and always initialize
       ``sandbox.vime.sampler.parallel_sampler_expl`` with ``--n_parallel``
       workers.
    2. Seed RNGs if ``--seed`` is given, and optionally start the plotter.
    3. Unpickle the stub call from ``--args_data``.
    4. Configure rllab logging directly in ``--log_dir`` (no per-experiment
       sub-directory is added).
    5. Run the stub via ``concretize``, draining it if it returns an iterable.

    The logger's previous snapshot mode/dir are restored and this run's
    outputs/prefix removed afterwards, even if the run raises.

    Args:
        argv: full argument vector; ``argv[0]`` (program name) is skipped.

    Security: ``--args_data`` is a base64-encoded pickle deserialized with
    ``pickle.loads``, which can execute arbitrary code. Only pass data
    produced by a trusted launcher.
    """

    default_log_dir = config.LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')

    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
    parser = argparse.ArgumentParser()
    parser.add_argument('--n_parallel', type=int, default=1,
                        help='Number of parallel workers to perform rollouts.')
    parser.add_argument(
        '--exp_name', type=str, default=default_exp_name, help='Name of the experiment.')
    parser.add_argument('--log_dir', type=str, default=default_log_dir,
                        help='Path to save the log and iteration snapshot.')
    parser.add_argument('--snapshot_mode', type=str, default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                             '(all iterations will be saved), "last" (only '
                             'the last iteration will be saved), or "none" '
                             '(do not save snapshots)')
    parser.add_argument('--tabular_log_file', type=str, default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file', type=str, default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file', type=str, default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--plot', type=ast.literal_eval, default=False,
                        help='Whether to plot the iteration results')
    parser.add_argument('--log_tabular_only', type=ast.literal_eval, default=False,
                        help='Whether to only print the tabular log information (in a horizontal format)')
    parser.add_argument('--seed', type=int,
                        help='Random seed for numpy')
    parser.add_argument('--args_data', type=str,
                        help='Pickled data for stub objects')

    args = parser.parse_args(argv[1:])
    if args.args_data is None:
        # There is nothing to run without the pickled stub; report a usage
        # error instead of crashing later with TypeError from b64decode(None).
        parser.error("--args_data is required")

    from sandbox.vime.sampler import parallel_sampler_expl as parallel_sampler
    parallel_sampler.initialize(n_parallel=args.n_parallel)

    if args.seed is not None:
        set_seed(args.seed)
        parallel_sampler.set_seed(args.seed)

    if args.plot:
        from rllab.plotter import plotter
        plotter.init_worker()

    # read from stdin
    data = pickle.loads(base64.b64decode(args.args_data))

    log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    logger.log_parameters_lite(params_log_file, args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    try:
        maybe_iter = concretize(data)
        if is_iterable(maybe_iter):
            for _ in maybe_iter:
                pass
    finally:
        # Restore the logger even if the run raised, so later runs in this
        # process do not write into this run's files or prefix.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
Ejemplo n.º 9
0
def run_experiment(argv):
    """Run or resume one rllab experiment described by command-line arguments.

    Steps:
    1. Parse ``argv[1:]``, seed RNGs, and optionally start parallel sampler
       workers and the plotter.
    2. Configure rllab logging under the log directory.
    3. Run the experiment in one of two modes:
       - Resume, when ``--resume_from`` is given. Its format is
         ``"<snapshot path>[&|&<n_itr>[&|&<batch_size>]]"``. The joblib
         snapshot is loaded, ``algo.n_itr`` and ``algo.batch_size`` are
         optionally overridden, and training continues.
       - New run, otherwise. The stub call passed in ``--args_data`` is
         unpickled and run.

    The logger's previous snapshot mode/dir are restored and this run's
    outputs/prefix removed afterwards, even if the run raises.

    Args:
        argv: full argument vector; ``argv[0]`` (program name) is skipped.

    Security: ``--args_data`` and ``--variant_data`` are base64-encoded
    pickles deserialized with (cloud)pickle, which can execute arbitrary
    code. Only pass data produced by a trusted launcher.
    """
    default_log_dir = config.LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')

    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--n_parallel',
        type=int,
        default=1,
        help=
        'Number of parallel workers to perform rollouts. 0 => don\'t start any workers'
    )
    parser.add_argument('--exp_name',
                        type=str,
                        default=default_exp_name,
                        help='Name of the experiment.')
    parser.add_argument('--log_dir',
                        type=str,
                        default=None,
                        help='Path to save the log and iteration snapshot.')
    parser.add_argument('--snapshot_mode',
                        type=str,
                        default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                        '(all iterations will be saved), "last" (only '
                        'the last iteration will be saved), "gap" (every'
                        '`snapshot_gap` iterations are saved), or "none" '
                        '(do not save snapshots)')
    parser.add_argument('--snapshot_gap',
                        type=int,
                        default=1,
                        help='Gap between snapshot iterations.')
    parser.add_argument('--tabular_log_file',
                        type=str,
                        default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file',
                        type=str,
                        default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file',
                        type=str,
                        default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--variant_log_file',
                        type=str,
                        default='variant.json',
                        help='Name of the variant log file (in json).')
    parser.add_argument(
        '--resume_from',
        type=str,
        default=None,
        help='Name of the pickle file to resume experiment from.')
    parser.add_argument('--plot',
                        type=ast.literal_eval,
                        default=False,
                        help='Whether to plot the iteration results')
    parser.add_argument(
        '--log_tabular_only',
        type=ast.literal_eval,
        default=False,
        help=
        'Whether to only print the tabular log information (in a horizontal format)'
    )
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data',
                        type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--variant_data',
                        type=str,
                        help='Pickled data for variant configuration')
    parser.add_argument('--use_cloudpickle',
                        type=ast.literal_eval,
                        default=False)

    args = parser.parse_args(argv[1:])
    if args.resume_from is None and args.args_data is None:
        # Without a snapshot to resume there is nothing to run; report a usage
        # error instead of crashing later with TypeError from b64decode(None).
        parser.error("--args_data is required unless --resume_from is given")

    if args.seed is not None:
        set_seed(args.seed)

    if args.n_parallel > 0:
        from rllab.sampler import parallel_sampler
        parallel_sampler.initialize(n_parallel=args.n_parallel)
        if args.seed is not None:
            parallel_sampler.set_seed(args.seed)

    if args.plot:
        from rllab.plotter import plotter
        plotter.init_worker()

    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    # variant_data is the variant dictionary sent from trpoTests_ExpLite
    if args.variant_data is not None:
        variant_data = pickle.loads(base64.b64decode(args.variant_data))
        variant_log_file = osp.join(log_dir, args.variant_log_file)
        logger.log_variant(variant_log_file, variant_data)
    else:
        variant_data = None

    if not args.use_cloudpickle:
        logger.log_parameters_lite(params_log_file, args)

    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_snapshot_gap(args.snapshot_gap)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    try:
        if args.resume_from is not None:
            # "<snapshot path>[&|&<n_itr>[&|&<batch_size>]]". A plain path
            # (no separator) resumes with the snapshot's own settings instead
            # of being silently ignored.
            vals = args.resume_from.split('&|&')
            dirRes = vals[0]
            # Parse the overrides before loading, so malformed input fails fast.
            numItrs = int(vals[1]) if len(vals) > 1 else None
            batchSize = int(vals[2]) if len(vals) > 2 else None
            print("resuming from :{}".format(dirRes))
            data = joblib.load(dirRes)
            # The snapshot is presumably a dict with keys such as 'baseline',
            # 'algo', 'itr', 'policy', 'env'; only 'algo' and 'policy' are
            # checked here.
            assert 'algo' in data
            algo = data['algo']
            assert 'policy' in data
            if numItrs is not None:
                algo.n_itr = numItrs
            if batchSize is not None:
                oldBatchSize = algo.batch_size
                algo.batch_size = batchSize
                print(
                    'algo iters : {} cur iter :{} oldBatchSize : {} newBatchSize : {}'
                    .format(algo.n_itr, algo.current_itr, oldBatchSize,
                            algo.batch_size))
            else:
                print('algo iters : {} cur iter :{} '.format(
                    algo.n_itr, algo.current_itr))
            # train() may be a generator; drain it so training actually runs.
            maybe_iter = algo.train()
            if is_iterable(maybe_iter):
                for _ in maybe_iter:
                    pass
        else:
            print('Not resuming - building new exp')
            # read from stdin
            if args.use_cloudpickle:
                import cloudpickle
                method_call = cloudpickle.loads(base64.b64decode(args.args_data))
                method_call(variant_data)
            else:
                print('not use cloud pickle')
                data = pickle.loads(base64.b64decode(args.args_data))
                maybe_iter = concretize(data)
                if is_iterable(maybe_iter):
                    for _ in maybe_iter:
                        pass
    finally:
        # Restore the logger even if the run raised, so later runs in this
        # process do not write into this run's files or prefix.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
Ejemplo n.º 10
0
        #     if params['use_hide_alg'] == 1:
        #         if params['batch_size_uniform'] is not None and params['batch_size_uniform'] > 0:
        #             logger.log('WARNING: Training with uniform sampling. Testing is done BEFORE the optimization !!!!')
        #             algo.train_brownian_with_uniform()
        #         else:
        #             algo.train_brownian()
        #     elif params['use_hide_alg'] == 2:
        #         if train_mode == 0:
        #             algo.train_brownian_with_goals()
        #         elif train_mode == 1:
        #             algo.train_brownian_reverse_repeat()
        #         elif train_mode == 2:
        #             algo.train_brownian_multiseed()
        #         elif train_mode == 3:
        #             algo.train_brownian_multiseed_swap_every_update_period()
        #         else:
        #             raise NotImplementedError
        #     else: #my version of the alg
        #         algo.train_hide_seek()
        # else:
        #     algo.train_seek()
        logger.log('Experiment finished ...')

        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file_fullpath)
        logger.remove_text_output(text_log_file_fullpath)
        logger.pop_prefix()

        print('Tabular log file:', tabular_log_file_fullpath)
Ejemplo n.º 11
0
def run_experiment(
    args_data,
    variant_data=None,
    seed=None,
    n_parallel=1,
    exp_name=None,
    log_dir=None,
    snapshot_mode='all',
    snapshot_gap=1,
    tabular_log_file='progress.csv',
    text_log_file='debug.log',
    params_log_file='params.json',
    variant_log_file='variant.json',
    resume_from=None,
    plot=False,
    log_tabular_only=False,
    log_debug_log_only=False,
):
    """Run an experiment in-process with the rllab logger configured.

    Seeds RNGs, optionally starts the parallel sampler and plotter, attaches
    text/tabular logger outputs under ``log_dir``, then either resumes
    training from a ``joblib`` snapshot or calls ``args_data(variant_data)``.
    The logger's snapshot mode/dir, file outputs and prefix are restored
    afterwards, even if the experiment raises.

    Args:
        args_data: Callable invoked as ``args_data(variant_data)`` when
            ``resume_from`` is None.
        variant_data: Variant configuration. When not None it is written to
            ``variant_log_file`` via ``logger.log_variant`` and forwarded to
            ``args_data``.
        seed: Optional seed for ``set_seed`` and the parallel sampler.
        n_parallel: Number of sampler workers; values <= 0 skip sampler
            initialization.
        exp_name: Experiment name; defaults to a timestamped unique name.
        log_dir: Output directory; defaults to ``config.LOG_DIR/exp_name``.
        snapshot_mode: Forwarded to ``logger.set_snapshot_mode``.
        snapshot_gap: Forwarded to ``logger.set_snapshot_gap``.
        tabular_log_file: Tabular (csv) file name, relative to ``log_dir``.
        text_log_file: Text log file name, relative to ``log_dir``.
        params_log_file: Accepted for interface compatibility; this function
            does not write it.
        variant_log_file: Variant file name, relative to ``log_dir``.
        resume_from: Path to a joblib snapshot containing an ``'algo'``
            entry whose ``train()`` is called instead of ``args_data``.
        plot: Whether to start the plotter worker.
        log_tabular_only: Forwarded to ``logger.set_log_tabular_only``.
        log_debug_log_only: Forwarded to ``logger.set_debug_log_only``.
    """
    if exp_name is None:
        # Timestamp plus a short random id avoids name clashes when running
        # distributed jobs.
        now = datetime.datetime.now(dateutil.tz.tzlocal())
        timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')
        rand_id = str(uuid.uuid4())[:5]
        exp_name = 'experiment_%s_%s' % (timestamp, rand_id)

    if seed is not None:
        set_seed(seed)

    if n_parallel > 0:
        from rllab.sampler import parallel_sampler
        parallel_sampler.initialize(n_parallel=n_parallel)
        if seed is not None:
            parallel_sampler.set_seed(seed)

    if plot:
        from rllab.plotter import plotter
        plotter.init_worker()

    if log_dir is None:
        log_dir = osp.join(config.LOG_DIR, exp_name)
    tabular_log_file = osp.join(log_dir, tabular_log_file)
    text_log_file = osp.join(log_dir, text_log_file)

    if variant_data is not None:
        logger.log_variant(osp.join(log_dir, variant_log_file), variant_data)

    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(snapshot_mode)
    logger.set_snapshot_gap(snapshot_gap)
    logger.set_log_tabular_only(log_tabular_only)
    logger.set_debug_log_only(log_debug_log_only)
    logger.push_prefix("[%s] " % exp_name)

    try:
        if resume_from is not None:
            data = joblib.load(resume_from)
            assert 'algo' in data
            algo = data['algo']
            algo.train()
        else:
            args_data(variant_data)
    finally:
        # Restore logger state even on failure so later runs in the same
        # process do not inherit this run's outputs, snapshot settings or
        # prefix.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
Ejemplo n.º 12
0
def close_logger(log_dir):
    """Detach the default rllab outputs under *log_dir* and pop the prefix.

    Removes the ``progress.csv`` tabular output and the ``debug.log`` text
    output rooted at ``log_dir`` from the logger, then pops one prefix.
    """
    csv_path = osp.join(log_dir, 'progress.csv')
    debug_path = osp.join(log_dir, 'debug.log')

    logger.remove_tabular_output(csv_path)
    logger.remove_text_output(debug_path)
    logger.pop_prefix()
Ejemplo n.º 13
0
def run_task(v, num_cpu=8, log_dir="./data", ename=None, **kwargs):
    """Train a PPO recurrent policy on an assistive bandit, then evaluate it.

    Builds a Bernoulli ``BanditEnv`` wrapped with the human policy and
    wrapper named in ``v``, trains a ``CategoricalGRUPolicy`` with
    ``PPOSGD``, and then, for every entry of ``human_policy_dict``, samples
    evaluation trajectories, logs return statistics and per-timestep
    agreement / mutual-information rows, and logs example trajectories.
    The TensorFlow session is closed and logger file outputs are detached
    on exit, even if training or evaluation raises.

    Args:
        v: Variant dict. Keys read here: ``n_arms``, ``n_episodes``,
            ``human_policy``, ``human_wrapper``, ``intervention_penalty``,
            ``seed``, ``nonlinearity``, ``weight_normalization``,
            ``layer_normalization``, ``hidden_dim``, ``opt_batch_size``,
            ``opt_n_steps``, ``min_epochs``, ``batch_size``, ``mean_kl``,
            ``clip_lr``, ``use_kl_penalty``, ``entropy_bonus_coeff``,
            ``discount``, ``gae_lambda``, ``num_eval_traj``,
            ``num_display_traj``.
        num_cpu: Accepted but unused here.
        log_dir: Base output directory; results go under
            ``log_dir/<n_episodes>/<human_policy>``. When None, no logger
            file outputs are attached. Replaced by the rl_algs logger
            directory when ``not local_test and force_remote``.
        ename: Accepted but unused here.
        **kwargs: Accepted but unused here.
    """
    from scipy.stats import bernoulli, uniform, beta

    import tensorflow as tf
    from assistive_bandits.experiments.l2_rnn_baseline import L2RNNBaseline
    from assistive_bandits.experiments.tbptt_optimizer import TBPTTOptimizer
    from sandbox.rocky.tf.policies.categorical_gru_policy import CategoricalGRUPolicy
    from assistive_bandits.experiments.pposgd_clip_ratio import PPOSGD

    # NOTE(review): `local_test` / `force_remote` are not defined in this
    # function; presumably module-level flags -- confirm at their definition.
    if not local_test and force_remote:
        import rl_algs.logger as rl_algs_logger
        log_dir = rl_algs_logger.get_dir()

    if log_dir is not None:
        log_dir = osp.join(log_dir, str(v["n_episodes"]), v["human_policy"])

    if log_dir is None:
        text_output_file = None
        tabular_output_file = None
        info_theory_tabular_output = None
    else:
        text_output_file = osp.join(log_dir, "text")
        tabular_output_file = osp.join(log_dir, "train_table.csv")
        info_theory_tabular_output = osp.join(log_dir, "info_table.csv")

    rag = uniform_bernoulli_iterator()
    bandit = BanditEnv(n_arms=v["n_arms"],
                       reward_dist=bernoulli,
                       reward_args_generator=rag,
                       horizon=v["n_episodes"])
    pi_H = human_policy_dict[v["human_policy"]](bandit)

    h_wrapper = human_wrapper_dict[v["human_wrapper"]]

    env = h_wrapper(bandit, pi_H, penalty=v["intervention_penalty"])

    if text_output_file is not None:
        logger.add_text_output(text_output_file)
        logger.add_tabular_output(tabular_output_file)
    # Tabular file currently attached to the logger. It is swapped after
    # training, so cleanup must detach whichever one is live.
    active_tabular_output = tabular_output_file

    try:
        logger.log("Training against {}".format(v["human_policy"]))
        logger.log("Setting seed to {}".format(v["seed"]))
        env.seed(v["seed"])

        baseline = L2RNNBaseline(
            name="vf",
            env_spec=env.spec,
            log_loss_before=False,
            log_loss_after=False,
            hidden_nonlinearity=getattr(tf.nn, v["nonlinearity"]),
            weight_normalization=v["weight_normalization"],
            layer_normalization=v["layer_normalization"],
            state_include_action=False,
            hidden_dim=v["hidden_dim"],
            optimizer=TBPTTOptimizer(
                batch_size=v["opt_batch_size"],
                n_steps=v["opt_n_steps"],
                n_epochs=v["min_epochs"],
            ),
            batch_size=v["opt_batch_size"],
            n_steps=v["opt_n_steps"],
        )
        policy = CategoricalGRUPolicy(env_spec=env.spec,
                                      hidden_nonlinearity=getattr(
                                          tf.nn, v["nonlinearity"]),
                                      hidden_dim=v["hidden_dim"],
                                      state_include_action=True,
                                      name="policy")

        n_itr = 3 if local_test else 100

        algo = PPOSGD(
            env=env,
            policy=policy,
            baseline=baseline,
            batch_size=v["batch_size"],
            max_path_length=v["n_episodes"],
            # 43.65 env time
            sampler_args=dict(n_envs=max(
                1, min(int(np.ceil(v["batch_size"] / v["n_episodes"])), 100))),
            # 100 threads -> 1:36 to sample 187.275
            # 6 threads -> 1:31
            # force_batch_sampler=True,
            n_itr=n_itr,
            step_size=v["mean_kl"],
            clip_lr=v["clip_lr"],
            log_loss_kl_before=False,
            log_loss_kl_after=False,
            use_kl_penalty=v["use_kl_penalty"],
            min_n_epochs=v["min_epochs"],
            entropy_bonus_coeff=v["entropy_bonus_coeff"],
            optimizer=TBPTTOptimizer(
                batch_size=v["opt_batch_size"],
                n_steps=v["opt_n_steps"],
                n_epochs=v["min_epochs"],
            ),
            discount=v["discount"],
            gae_lambda=v["gae_lambda"],
            use_line_search=True
            # scope=ename
        )

        # Entering the session makes it the default session (as the
        # previous `sess.as_default()` did) and also closes it on exit.
        with tf.Session() as sess:
            algo.train(sess)

            if text_output_file is not None:
                logger.remove_tabular_output(tabular_output_file)
                active_tabular_output = None
                logger.add_tabular_output(info_theory_tabular_output)
                active_tabular_output = info_theory_tabular_output

            # Now gather statistics for t-tests and such!
            for human_policy_name, human_policy in human_policy_dict.items():

                logger.log("-------------------")
                logger.log("Evaluating against {}".format(human_policy.__name__))
                logger.log("-------------------")

                logger.log("Obtaining Samples...")
                test_pi_H = human_policy(bandit)
                test_env = h_wrapper(bandit, test_pi_H, penalty=0.)
                eval_sampler = VectorizedSampler(algo, n_envs=100)
                algo.batch_size = v["num_eval_traj"] * v["n_episodes"]
                algo.env = test_env
                logger.log("algo.env.pi_H has class: {}".format(
                    algo.env.pi_H.__class__))
                eval_sampler.start_worker()
                paths = eval_sampler.obtain_samples(-1)
                eval_sampler.shutdown_worker()
                rewards = []

                H_act_seqs = []
                R_act_seqs = []
                best_arms = []
                optimal_a_seqs = []
                for p in paths:
                    a_Rs = env.action_space.unflatten_n(p['actions'])
                    obs_R = env.observation_space.unflatten_n(p['observations'])
                    best_arm = np.argmax(p['env_infos']['arm_means'][0])

                    H_act_seqs.append(obs_R[:, 1])
                    R_act_seqs.append(a_Rs)
                    best_arms.append(best_arm)
                    optimal_a_seqs.append(
                        [best_arm for _ in range(v["n_episodes"])])

                    rewards.append(np.sum(p['rewards']))

                # feel free to add more data
                logger.log("NumTrajs {}".format(v["num_eval_traj"]))
                logger.log("AverageReturn {}".format(np.mean(rewards)))
                logger.log("StdReturn {}".format(np.std(rewards)))
                logger.log("MaxReturn {}".format(np.max(rewards)))
                logger.log("MinReturn {}".format(np.min(rewards)))

                optimal_a_H_freqs = _frequency_agreement(H_act_seqs,
                                                         optimal_a_seqs)
                optimal_a_R_freqs = _frequency_agreement(R_act_seqs,
                                                         optimal_a_seqs)

                for t in range(v["n_episodes"]):
                    logger.record_tabular("PolicyExecTime", 0)
                    logger.record_tabular("EnvExecTime", 0)
                    logger.record_tabular("ProcessExecTime", 0)
                    logger.record_tabular("Tested Against", human_policy_name)
                    logger.record_tabular("t", t)
                    logger.record_tabular("a_H_agreement", optimal_a_H_freqs[t])
                    logger.record_tabular("a_R_agreement", optimal_a_R_freqs[t])

                    H_act_seqs_truncated = [a_Hs[0:t] for a_Hs in H_act_seqs]
                    R_act_seqs_truncated = [a_Rs[0:t] for a_Rs in R_act_seqs]
                    h_mutual_info = _mutual_info_seqs(H_act_seqs_truncated,
                                                      best_arms, v["n_arms"] + 1)
                    r_mutual_info = _mutual_info_seqs(R_act_seqs_truncated,
                                                      best_arms, v["n_arms"] + 1)
                    logger.record_tabular("h_mutual_info", h_mutual_info)
                    logger.record_tabular("r_mutual_info", r_mutual_info)
                    logger.record_tabular("a_H_opt_freq", optimal_a_H_freqs[t])
                    logger.record_tabular("a_R_opt_freq", optimal_a_R_freqs[t])
                    logger.dump_tabular()

                test_env = h_wrapper(bandit, human_policy(bandit), penalty=0.)

                logger.log("Printing Example Trajectories")
                for i in range(v["num_display_traj"]):
                    observation = test_env.reset()
                    policy.reset()
                    logger.log("-- Trajectory {} of {}".format(
                        i + 1, v["num_display_traj"]))
                    logger.log("t \t obs \t act \t reward \t act_probs")
                    for t in range(v["n_episodes"]):
                        action, act_info = policy.get_action(observation)
                        new_obs, reward, done, info = test_env.step(action)
                        logger.log("{} \t {} \t {} \t {} \t {}".format(
                            t, observation, action, reward, act_info['prob']))
                        observation = new_obs
                        if done:
                            logger.log("Total reward: {}".format(
                                info["accumulated rewards"]))
                            break
    finally:
        # Detach file outputs even on failure so the process-wide logger does
        # not keep writing to this run's files.
        if text_output_file is not None:
            logger.remove_text_output(text_output_file)
            if active_tabular_output is not None:
                logger.remove_tabular_output(active_tabular_output)
Ejemplo n.º 14
0
            optimal_a_H_freqs = _frequency_agreement(H_act_seqs,
                                                     optimal_a_seqs)
            optimal_a_R_freqs = _frequency_agreement(R_act_seqs,
                                                     optimal_a_seqs)

            for t in range(max_timesteps):
                logger.record_tabular("PolicyExecTime", 0)
                logger.record_tabular("EnvExecTime", 0)
                logger.record_tabular("ProcessExecTime", 0)

                logger.record_tabular("Tested Against", human_policy_name)
                logger.record_tabular("t", t)
                logger.record_tabular("a_H_agreement", optimal_a_H_freqs[t])
                logger.record_tabular("a_R_agreement", optimal_a_R_freqs[t])

                H_act_seqs_truncated = [a_Hs[0:t] for a_Hs in H_act_seqs]
                R_act_seqs_truncated = [a_Rs[0:t] for a_Rs in R_act_seqs]
                h_mutual_info = _mutual_info_seqs(H_act_seqs_truncated,
                                                  best_arms, bandit.nA + 1)
                r_mutual_info = _mutual_info_seqs(R_act_seqs_truncated,
                                                  best_arms, bandit.nA + 1)
                logger.record_tabular("h_mutual_info", h_mutual_info)
                logger.record_tabular("r_mutual_info", r_mutual_info)
                logger.record_tabular("a_H_opt_freq", optimal_a_H_freqs[t])
                logger.record_tabular("a_R_opt_freq", optimal_a_R_freqs[t])
                logger.dump_tabular()

    if text_output_file is not None:
        logger.remove_text_output(text_output_file)
        logger.remove_tabular_output(info_theory_tabular_output)
Ejemplo n.º 15
0
# Train TRPO with the env, policy, baseline and log files set up earlier in
# this script. (An earlier setting used batch_size=4000.)
algo = TRPO(env=env, policy=policy, baseline=baseline,
            batch_size=1000, max_path_length=env.horizon, n_itr=500,
            discount=0.99, step_size=0.0025)
# To enable plotting, pass plot=True here and in the launcher call below.
algo.train()

# Detach this run's logger outputs and prefix.
logger.remove_tabular_output(tabular_log_file)
logger.remove_text_output(text_log_file)
logger.pop_prefix()

# Alternative launcher, kept for reference:
# run_experiment_lite(
#     algo.train(),
#     n_parallel=1,          # number of parallel sampling workers
#     snapshot_mode="last",  # keep only the last iteration's parameters
#     seed=1,                # experiment seed; random if not provided
#     # plot=True,
# )
Ejemplo n.º 16
0
def run_experiment(argv):
    """Run a pickled rllab experiment described by command-line arguments.

    Parses ``argv`` (``argv[0]`` is skipped), seeds RNGs, optionally starts
    the exploration parallel sampler and the plotter, and configures the
    rllab logger under ``--log_dir``. It then unpickles the base64-encoded
    ``--args_data`` and concretizes it. If the result is iterable, it is
    exhausted. Logger snapshot mode/dir, file outputs and prefix are
    restored afterwards, even if the experiment raises.

    Args:
        argv: Full argument vector, e.g. ``sys.argv``.
    """
    default_log_dir = config.LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')

    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
    parser = argparse.ArgumentParser()
    parser.add_argument('--n_parallel',
                        type=int,
                        default=1,
                        help='Number of parallel workers to perform rollouts.')
    parser.add_argument('--exp_name',
                        type=str,
                        default=default_exp_name,
                        help='Name of the experiment.')
    parser.add_argument('--log_dir',
                        type=str,
                        default=default_log_dir,
                        help='Path to save the log and iteration snapshot.')
    parser.add_argument('--snapshot_mode',
                        type=str,
                        default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                        '(all iterations will be saved), "last" (only '
                        'the last iteration will be saved), or "none" '
                        '(do not save snapshots)')
    parser.add_argument('--snapshot_gap',
                        type=int,
                        default=1,
                        help='Gap between snapshot iterations.')
    parser.add_argument('--tabular_log_file',
                        type=str,
                        default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file',
                        type=str,
                        default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file',
                        type=str,
                        default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--plot',
                        type=ast.literal_eval,
                        default=False,
                        help='Whether to plot the iteration results')
    parser.add_argument(
        '--log_tabular_only',
        type=ast.literal_eval,
        default=False,
        help=
        'Whether to only print the tabular log information (in a horizontal format)'
    )
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data',
                        type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--use_cloudpickle',
                        type=ast.literal_eval,
                        default=False,
                        help='Accepted for command-line compatibility; this '
                        'runner always decodes --args_data with pickle.')

    args = parser.parse_args(argv[1:])
    # Fail fast with a usage message instead of a TypeError from
    # b64decode(None) after workers have already been started.
    if args.args_data is None:
        parser.error('--args_data is required')

    if args.seed is not None:
        set_seed(args.seed)

    if args.n_parallel > 0:
        from sandbox.vase.sampler import parallel_sampler_expl as parallel_sampler
        parallel_sampler.initialize(n_parallel=args.n_parallel)

        if args.seed is not None:
            # Re-seed after worker initialization, then seed the workers.
            set_seed(args.seed)
            parallel_sampler.set_seed(args.seed)

    if args.plot:
        from rllab.plotter import plotter
        plotter.init_worker()

    # The stub objects arrive base64-encoded in --args_data.
    # NOTE: pickle.loads can execute arbitrary code; only pass trusted data.
    data = pickle.loads(base64.b64decode(args.args_data))

    log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    logger.log_parameters_lite(params_log_file, args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_gap(args.snapshot_gap)
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    try:
        maybe_iter = concretize(data)
        if is_iterable(maybe_iter):
            for _ in maybe_iter:
                pass
    finally:
        # Restore logger state even on failure.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
Ejemplo n.º 17
0
def run_experiment(argv):
    """Run an rllab experiment described by command-line arguments.

    Parses ``argv`` (``argv[0]`` is skipped), seeds RNGs, and optionally
    starts the parallel sampler and plotter. It configures the rllab logger
    (text/tabular outputs, tensorboard, checkpoint and observation
    directories, snapshot settings) under ``--log_dir`` and logs the git
    commit, a git diff patch, and host/process info. It then either resumes
    training from ``--resume_from``, calls a cloudpickled callable with the
    variant data, or concretizes pickled stub objects. Logger snapshot
    mode/dir, file outputs and prefix are restored afterwards, even if the
    experiment raises.

    Args:
        argv: Full argument vector, e.g. ``sys.argv``.
    """
    default_log_dir = config.LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')

    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--n_parallel',
        type=int,
        default=1,
        help=
        'Number of parallel workers to perform rollouts. 0 => don\'t start any workers'
    )
    parser.add_argument('--exp_name',
                        type=str,
                        default=default_exp_name,
                        help='Name of the experiment.')
    parser.add_argument('--log_dir',
                        type=str,
                        default=None,
                        help='Path to save the log and iteration snapshot.')
    parser.add_argument('--snapshot_mode',
                        type=str,
                        default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                        '(all iterations will be saved), "last" (only '
                        'the last iteration will be saved), "gap" (every '
                        '`snapshot_gap` iterations are saved), or "none" '
                        '(do not save snapshots)')
    parser.add_argument('--snapshot_gap',
                        type=int,
                        default=1,
                        help='Gap between snapshot iterations.')
    parser.add_argument('--tabular_log_file',
                        type=str,
                        default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file',
                        type=str,
                        default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--tensorboard_log_dir',
                        type=str,
                        default='tb',
                        help='Name of the folder for tensorboard_summary.')
    parser.add_argument(
        '--tensorboard_step_key',
        type=str,
        default=None,
        help=
        'Name of the step key in log data which shows the step in tensorboard_summary.'
    )
    parser.add_argument('--params_log_file',
                        type=str,
                        default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--variant_log_file',
                        type=str,
                        default='variant.json',
                        help='Name of the variant log file (in json).')
    parser.add_argument(
        '--resume_from',
        type=str,
        default=None,
        help='Name of the pickle file to resume experiment from.')
    parser.add_argument('--plot',
                        type=ast.literal_eval,
                        default=False,
                        help='Whether to plot the iteration results')
    parser.add_argument(
        '--log_tabular_only',
        type=ast.literal_eval,
        default=False,
        help=
        'Whether to only print the tabular log information (in a horizontal format)'
    )
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data',
                        type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--variant_data',
                        type=str,
                        help='Pickled data for variant configuration')
    parser.add_argument('--use_cloudpickle',
                        type=ast.literal_eval,
                        default=False,
                        help='Whether --args_data is a cloudpickled callable '
                        '(invoked with the variant data) rather than pickled '
                        'stub objects.')
    parser.add_argument('--checkpoint_dir',
                        type=str,
                        default='checkpoint',
                        help='Name of the folder for checkpoints.')
    parser.add_argument('--obs_dir',
                        type=str,
                        default='obs',
                        help='Name of the folder for original observations.')

    args = parser.parse_args(argv[1:])

    if args.seed is not None:
        set_seed(args.seed)

    if args.n_parallel > 0:
        from rllab.sampler import parallel_sampler
        parallel_sampler.initialize(n_parallel=args.n_parallel)
        if args.seed is not None:
            parallel_sampler.set_seed(args.seed)

    if args.plot:
        from rllab.plotter import plotter
        plotter.init_worker()

    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)
    tensorboard_log_dir = osp.join(log_dir, args.tensorboard_log_dir)
    checkpoint_dir = osp.join(log_dir, args.checkpoint_dir)
    obs_dir = osp.join(log_dir, args.obs_dir)

    if args.variant_data is not None:
        # NOTE: pickle.loads can execute arbitrary code; only pass trusted data.
        variant_data = pickle.loads(base64.b64decode(args.variant_data))
        variant_log_file = osp.join(log_dir, args.variant_log_file)
        logger.log_variant(variant_log_file, variant_data)
    else:
        variant_data = None

    if not args.use_cloudpickle:
        logger.log_parameters_lite(params_log_file, args)

    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    logger.set_tensorboard_dir(tensorboard_log_dir)
    logger.set_checkpoint_dir(checkpoint_dir)
    logger.set_obs_dir(obs_dir)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_snapshot_gap(args.snapshot_gap)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.set_tensorboard_step_key(args.tensorboard_step_key)
    logger.push_prefix("[%s] " % args.exp_name)

    try:
        git_commit = get_git_commit_hash()
        logger.log('Git commit: {}'.format(git_commit))

        git_diff_file_path = osp.join(log_dir,
                                      'git_diff_{}.patch'.format(git_commit))
        save_git_diff_to_file(git_diff_file_path)

        logger.log('hostname: {}, pid: {}, tmux session: {}'.format(
            socket.gethostname(), os.getpid(), get_tmux_session_name()))

        if args.resume_from is not None:
            data = joblib.load(args.resume_from)
            assert 'algo' in data
            algo = data['algo']
            algo.train()
        elif args.use_cloudpickle:
            # --args_data holds a base64-encoded cloudpickled callable.
            import cloudpickle
            method_call = cloudpickle.loads(base64.b64decode(args.args_data))
            method_call(variant_data)
        else:
            # --args_data holds base64-encoded pickled stub objects.
            # NOTE: pickle.loads can execute arbitrary code; trusted data only.
            data = pickle.loads(base64.b64decode(args.args_data))
            maybe_iter = concretize(data)
            if is_iterable(maybe_iter):
                for _ in maybe_iter:
                    pass
    finally:
        # Restore logger state even on failure.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()