Esempio n. 1
0
    def restore(self, checkpoint_dir=None):
        """Restore training state from a checkpoint directory.

        Restores TF variables via ``tf.train.Saver``, swaps the tabular log
        back to its checkpointed copy, and (when a Q-function is present)
        reloads the replay pool.

        Args:
            checkpoint_dir: Directory holding the checkpoint files; defaults
                to the logger's snapshot directory.
        """
        if checkpoint_dir is None:
            checkpoint_dir = logger.get_snapshot_dir()
        checkpoint_file = os.path.join(checkpoint_dir, 'params.chk')
        # TF writes a '<name>.meta' file alongside each checkpoint; its
        # presence is used as the "checkpoint exists" test.
        if os.path.isfile(checkpoint_file + '.meta'):
            sess = tf.get_default_session()
            saver = tf.train.Saver()
            saver.restore(sess, checkpoint_file)

            tabular_chk_file = os.path.join(checkpoint_dir, 'progress.csv.chk')
            if os.path.isfile(tabular_chk_file):
                tabular_file = os.path.join(checkpoint_dir, 'progress.csv')
                # Detach, overwrite with the checkpointed copy, then
                # re-attach so logging resumes from the saved progress.
                logger.remove_tabular_output(tabular_file)
                shutil.copy(tabular_chk_file, tabular_file)
                logger.add_tabular_output(tabular_file)

            if self.qf is not None:
                pool_file = os.path.join(checkpoint_dir, 'pool.chk')
                if self.save_format == 'pickle':
                    # BUG FIX: the loaded pool was previously discarded
                    # (`pickle_load(pool_file)` with no assignment), so the
                    # replay pool was never actually restored in pickle mode.
                    self.pool = pickle_load(pool_file)
                elif self.save_format == 'joblib':
                    self.pool = joblib.load(pool_file)
                else:
                    raise NotImplementedError

            logger.log('Restored from checkpoint %s' % checkpoint_file)
        else:
            logger.log('No checkpoint %s' % checkpoint_file)
Esempio n. 2
0
def logger_context(log_dir,
                   name,
                   run_ID,
                   log_params=None,
                   snapshot_mode="none"):
    """Context generator configuring the global logger for one run.

    Sets snapshot mode/dir, attaches text and tabular outputs under
    ``<log_dir>/<name>_<run_ID>``, pushes a matching prefix, dumps run
    parameters to ``params.json``, yields, then detaches everything.

    Args:
        log_dir: Base output directory; redirected if outside LOG_DIR.
        name: Experiment name (used in the directory and log prefix).
        run_ID: Run identifier appended to *name*.
        log_params: Optional dict of parameters to record; a copy is not
            made — ``name`` and ``run_ID`` keys are added in place.
        snapshot_mode: Logger snapshot policy (e.g. "none", "last", "all").
    """
    logger.set_snapshot_mode(snapshot_mode)
    logger.set_log_tabular_only(False)
    abs_log_dir = os.path.abspath(log_dir)
    # Keep all logs under LOG_DIR; anything outside is redirected into the
    # standard local/<date> layout.
    if LOG_DIR != os.path.commonpath([abs_log_dir, LOG_DIR]):
        print(
            "logger_context received log_dir outside of rllab.config.LOG_DIR: "
            "prepending by {}/local/<yyyymmdd>/".format(LOG_DIR))
        abs_log_dir = make_log_dir(log_dir)
    exp_dir = os.path.join(abs_log_dir, "{}_{}".format(name, run_ID))
    tabular_log_file = os.path.join(exp_dir, "progress.csv")
    text_log_file = os.path.join(exp_dir, "debug.log")
    params_log_file = os.path.join(exp_dir, "params.json")

    logger.set_snapshot_dir(exp_dir)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    logger.push_prefix("{}_{} ".format(name, run_ID))

    if log_params is None:
        log_params = dict()
    log_params["name"] = name
    log_params["run_ID"] = run_ID
    with open(params_log_file, "w") as f:
        json.dump(log_params, f)

    try:
        yield
    finally:
        # BUG FIX: teardown previously ran only on normal exit; if the
        # caller's body raised, stale outputs and the prefix were left on
        # the global logger. try/finally guarantees cleanup either way.
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
Esempio n. 3
0
def shutdown_output():
    """Detach the text/tabular outputs registered at setup and clear OUTPUT_DIR."""
    global OUTPUT_DIR
    assert OUTPUT_DIR is not None, "Cannot shutdown output that has not been setup."
    out_dir = OUTPUT_DIR
    logger.remove_text_output(os.path.join(out_dir, "rllab.txt"))
    logger.remove_tabular_output(os.path.join(out_dir, "rllab.csv"))
    # Mark output as torn down so a later setup/shutdown cycle can assert state.
    OUTPUT_DIR = None
Esempio n. 4
0
def rllab_logdir(algo=None, dirname=None):
    """Context generator attaching rllab's tabular log to a snapshot dir.

    Points the rllab logger at *dirname* (or its current snapshot dir),
    attaches ``progress.csv``, optionally dumps *algo*'s hyperparameters to
    ``params.json``, and yields the directory in use.
    """
    if dirname:
        rllablogger.set_snapshot_dir(dirname)
    dirname = rllablogger.get_snapshot_dir()
    tabular_file = os.path.join(dirname, 'progress.csv')
    rllablogger.add_tabular_output(tabular_file)
    if algo:
        with open(os.path.join(dirname, 'params.json'), 'w') as f:
            params = extract_hyperparams(algo)
            json.dump(params, f)
    try:
        yield dirname
    finally:
        # BUG FIX: detach the tabular output even if the caller's body
        # raises; previously an exception leaked the output registration.
        rllablogger.remove_tabular_output(tabular_file)
Esempio n. 5
0
def rllab_logdir(algo=None, dirname=None):
    """Context generator attaching rllab's tabular log to a snapshot dir.

    Points the rllab logger at *dirname* (or its current snapshot dir),
    attaches ``progress.csv``, optionally dumps *algo*'s hyperparameters to
    ``params.json``, and yields the directory in use.
    """
    if dirname:
        rllablogger.set_snapshot_dir(dirname)
    dirname = rllablogger.get_snapshot_dir()
    tabular_file = os.path.join(dirname, 'progress.csv')
    rllablogger.add_tabular_output(tabular_file)
    if algo:
        with open(os.path.join(dirname, 'params.json'), 'w') as f:
            params = extract_hyperparams(algo)
            json.dump(params, f)
    try:
        yield dirname
    finally:
        # BUG FIX: detach the tabular output even if the caller's body
        # raises; previously an exception leaked the output registration.
        rllablogger.remove_tabular_output(tabular_file)
Esempio n. 6
0
def run_experiment(algo,
                   n_parallel=0,
                   seed=0,
                   plot=False,
                   log_dir=None,
                   exp_name=None,
                   snapshot_mode='last',
                   snapshot_gap=1,
                   exp_prefix='experiment',
                   log_tabular_only=False):
    """Run ``algo.train()`` with the global rllab logger configured for it.

    Args:
        algo: Object exposing a ``train()`` method.
        n_parallel: Number of parallel sampler workers; 0 disables them.
        seed: Random seed for numpy / samplers.
        plot: Whether to start the plotter worker.
        log_dir: Output directory; defaults to
            ``LOG_DIR/local/<exp_prefix>/<exp_name>``.
        exp_name: Experiment name; a timestamp-based default is generated.
        snapshot_mode: Logger snapshot policy ('all', 'last', 'gap', 'none').
        snapshot_gap: Iterations between snapshots in 'gap' mode.
        exp_prefix: Subdirectory under ``LOG_DIR/local``.
        log_tabular_only: Suppress non-tabular console output.
    """
    default_log_dir = config.LOG_DIR + "/local/" + exp_prefix
    set_seed(seed)
    if exp_name is None:
        now = datetime.datetime.now(dateutil.tz.tzlocal())
        timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')
        exp_name = 'experiment_%s' % (timestamp)
    if n_parallel > 0:
        from rllab.sampler import parallel_sampler
        parallel_sampler.initialize(n_parallel=n_parallel)
        parallel_sampler.set_seed(seed)
    if plot:
        from rllab.plotter import plotter
        plotter.init_worker()
    if log_dir is None:
        log_dir = osp.join(default_log_dir, exp_name)
    tabular_log_file = osp.join(log_dir, 'progress.csv')
    text_log_file = osp.join(log_dir, 'debug.log')
    #params_log_file = osp.join(log_dir, 'params.json')

    #logger.log_parameters_lite(params_log_file, args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    # Remember previous logger state so it can be restored afterwards.
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(snapshot_mode)
    logger.set_snapshot_gap(snapshot_gap)
    logger.set_log_tabular_only(log_tabular_only)
    logger.push_prefix("[%s] " % exp_name)

    try:
        algo.train()
    finally:
        # BUG FIX: restore the global logger state even when training raises;
        # previously an exception left outputs/prefix attached.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
Esempio n. 7
0
    def train(self):
        """Run the full training loop inside a fresh TF session.

        Per iteration: obtain rollouts, process them into samples, log
        diagnostics, save a snapshot of the paths, and optimize the policy.
        The iteration schedule is perturbed for expert-trajectory iterations
        (``goals_for_ET_dict``) and the final iteration.
        """

        config = tf.ConfigProto()
        # Grow GPU memory on demand instead of grabbing it all up front.
        config.gpu_options.allow_growth = True

        with tf.Session(config=config) as sess:
            # Optionally warm-start from a previously pickled policy.
            if self.load_policy is not None:
                self.policy = joblib.load(self.load_policy)['policy']
            self.init_opt()
            # initialize uninitialized vars (I know, it's ugly)

            # Probe each variable: reading an uninitialized TF variable
            # raises FailedPreconditionError, so only those get initialized.
            uninit_vars = []
            for var in tf.all_variables():
                try:
                    sess.run(var)
                except tf.errors.FailedPreconditionError:
                    uninit_vars.append(var)
            sess.run(tf.initialize_variables(uninit_vars))
            #sess.run(tf.initialize_all_variables())
            self.start_worker()
            start_time = time.time()
            for itr in range(self.start_itr, self.n_itr):
                if self.comet_logger:
                    self.comet_logger.set_step(itr + self.outer_iteration)
                # Last iteration: collapse the policy's action noise to make
                # behavior near-deterministic for the final rollout.
                if itr == self.n_itr - 1:
                    self.policy.std_modifier = 0.00001
                    #self.policy.std_modifier = 1
                    self.policy.recompute_dist_for_adjusted_std()
                # Expert-trajectory iterations use test noise and a larger
                # batch size; all others use the training configuration.
                if itr in self.goals_for_ET_dict.keys():
                    # self.policy.std_modifier = 0.0001
                    # self.policy.recompute_dist_for_adjusted_std()
                    goals = self.goals_for_ET_dict[itr]
                    noise = self.action_noise_test
                    self.batch_size = self.batch_size_expert_traj
                else:
                    if self.reset_arg is None:
                        goals = [None]
                    else:
                        goals = [self.reset_arg]
                    noise = self.action_noise_train
                    self.batch_size = self.batch_size_train
                paths_to_save = {}
                itr_start_time = time.time()
                with logger.prefix('itr #%d | ' % itr):

                    logger.log("Obtaining samples...")

                    # preupdate is False only on the final iteration.
                    preupdate = True if itr < self.n_itr - 1 else False
                    # paths_for_goal = self.obtain_samples(itr=itr, reset_args=[{'goal': goal, 'noise': noise}])  # when using oracle environments with changing noise, use this line!
                    paths = self.obtain_samples(itr=itr,
                                                reset_args=[self.reset_arg],
                                                preupdate=preupdate)

                    logger.log("Processing samples...")
                    samples_data = self.process_samples(itr, paths)
                    logger.log("Logging diagnostics...")
                    self.log_diagnostics(paths)

                    #new_param_values = self.policy.get_variable_values(self.policy.all_params)
                    logger.log("Saving snapshot...")
                    params = self.get_itr_snapshot(itr,
                                                   samples_data)  # , **kwargs)
                    if self.store_paths:
                        params["paths"] = samples_data["paths"]

                    # NOTE(review): the snapshot saved here is the sampled
                    # paths keyed by reset_arg, not `params` — confirm this
                    # is intentional.
                    logger.save_itr_params(itr,
                                           samples_data["paths"],
                                           file_name=str(self.reset_arg) +
                                           '.pkl')
                    # print(samples_data['paths'])
                    logger.log("Saved")
                    logger.log("Optimizing policy...")
                    self.optimize_policy(itr, samples_data)

                    logger.record_tabular('Time', time.time() - start_time)
                    logger.record_tabular('ItrTime',
                                          time.time() - itr_start_time)

                    logger.dump_tabular(with_prefix=False)

        if self.log_dir is not None:
            logger.remove_tabular_output(self.log_dir + '/progress.csv')
        self.shutdown_worker()
Esempio n. 8
0
def run_experiment(argv):
    """Parse CLI arguments, configure the global rllab logger, and execute
    the pickled stub objects passed via ``--args_data``.

    The previous logger snapshot dir/mode are restored and the outputs
    detached when the run finishes, even if it raises.
    """

    default_log_dir = config.LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')

    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
    parser = argparse.ArgumentParser()
    parser.add_argument('--n_parallel', type=int, default=1,
                        help='Number of parallel workers to perform rollouts.')
    parser.add_argument(
        '--exp_name', type=str, default=default_exp_name, help='Name of the experiment.')
    parser.add_argument('--log_dir', type=str, default=default_log_dir,
                        help='Path to save the log and iteration snapshot.')
    parser.add_argument('--snapshot_mode', type=str, default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                             '(all iterations will be saved), "last" (only '
                             'the last iteration will be saved), or "none" '
                             '(do not save snapshots)')
    parser.add_argument('--tabular_log_file', type=str, default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file', type=str, default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file', type=str, default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--plot', type=ast.literal_eval, default=False,
                        help='Whether to plot the iteration results')
    parser.add_argument('--log_tabular_only', type=ast.literal_eval, default=False,
                        help='Whether to only print the tabular log information (in a horizontal format)')
    parser.add_argument('--seed', type=int,
                        help='Random seed for numpy')
    parser.add_argument('--args_data', type=str,
                        help='Pickled data for stub objects')

    args = parser.parse_args(argv[1:])

    from sandbox.vime.sampler import parallel_sampler_expl as parallel_sampler
    parallel_sampler.initialize(n_parallel=args.n_parallel)

    if args.seed is not None:
        set_seed(args.seed)
        parallel_sampler.set_seed(args.seed)

    if args.plot:
        from rllab.plotter import plotter
        plotter.init_worker()

    # read from stdin
    # NOTE(review): unpickling args_data executes arbitrary code — only run
    # payloads produced by a trusted launcher.
    data = pickle.loads(base64.b64decode(args.args_data))

    log_dir = args.log_dir
    # exp_dir = osp.join(log_dir, args.exp_name)
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    logger.log_parameters_lite(params_log_file, args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    # Remember previous logger state so it can be restored afterwards.
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    try:
        maybe_iter = concretize(data)
        if is_iterable(maybe_iter):
            for _ in maybe_iter:
                pass
    finally:
        # BUG FIX: restore logger state even when the experiment raises;
        # previously a failure left stale outputs/prefix on the global logger.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
def run_experiment(argv):
    """Parse CLI arguments, configure the rllab logger, then either resume a
    pickled experiment (``--resume_from``) or run the ``--args_data`` payload.

    A ``--resume_from`` value may be a plain pickle path, or a string of the
    form ``dir&|&n_itrs[&|&batch_size]`` to also override the algorithm's
    iteration count and batch size.
    """
    default_log_dir = config.LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')

    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--n_parallel',
        type=int,
        default=1,
        help=
        'Number of parallel workers to perform rollouts. 0 => don\'t start any workers'
    )
    parser.add_argument('--exp_name',
                        type=str,
                        default=default_exp_name,
                        help='Name of the experiment.')
    parser.add_argument('--log_dir',
                        type=str,
                        default=None,
                        help='Path to save the log and iteration snapshot.')
    parser.add_argument('--snapshot_mode',
                        type=str,
                        default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                        '(all iterations will be saved), "last" (only '
                        'the last iteration will be saved), "gap" (every'
                        '`snapshot_gap` iterations are saved), or "none" '
                        '(do not save snapshots)')
    parser.add_argument('--snapshot_gap',
                        type=int,
                        default=1,
                        help='Gap between snapshot iterations.')
    parser.add_argument('--tabular_log_file',
                        type=str,
                        default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file',
                        type=str,
                        default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file',
                        type=str,
                        default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--variant_log_file',
                        type=str,
                        default='variant.json',
                        help='Name of the variant log file (in json).')
    parser.add_argument(
        '--resume_from',
        type=str,
        default=None,
        help='Name of the pickle file to resume experiment from.')
    parser.add_argument('--plot',
                        type=ast.literal_eval,
                        default=False,
                        help='Whether to plot the iteration results')
    parser.add_argument(
        '--log_tabular_only',
        type=ast.literal_eval,
        default=False,
        help=
        'Whether to only print the tabular log information (in a horizontal format)'
    )
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data',
                        type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--variant_data',
                        type=str,
                        help='Pickled data for variant configuration')
    parser.add_argument('--use_cloudpickle',
                        type=ast.literal_eval,
                        default=False)

    args = parser.parse_args(argv[1:])

    if args.seed is not None:
        set_seed(args.seed)

    # Optional parallel sampler / plotter workers.
    if args.n_parallel > 0:
        from rllab.sampler import parallel_sampler
        parallel_sampler.initialize(n_parallel=args.n_parallel)
        if args.seed is not None:
            parallel_sampler.set_seed(args.seed)

    if args.plot:
        from rllab.plotter import plotter
        plotter.init_worker()

    # Resolve output paths for the run.
    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    # NOTE(review): variant_data is unpickled from a CLI string — trusted
    # launcher input only.
    if args.variant_data is not None:
        variant_data = pickle.loads(base64.b64decode(args.variant_data))
        variant_log_file = osp.join(log_dir, args.variant_log_file)
        logger.log_variant(variant_log_file, variant_data)
    else:
        variant_data = None

    if not args.use_cloudpickle:
        logger.log_parameters_lite(params_log_file, args)

    # Configure the global logger, remembering the previous state so it can
    # be restored after the run.
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_snapshot_gap(args.snapshot_gap)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)
    #variant_data is the variant dictionary sent from trpoTests_ExpLite
    if (args.resume_from is not None) and (
            '&|&' in args.resume_from
    ):  #separate string on &|& to get iters and file location
        vals = args.resume_from.split(
            '&|&')  #dirRes | numItrs to go | new batchSize
        dirRes = vals[0]
        numItrs = int(vals[1])
        if (len(vals) > 2):
            batchSize = int(vals[2])
        print("resuming from :{}".format(dirRes))
        data = joblib.load(dirRes)
        #data is dict : 'baseline', 'algo', 'itr', 'policy', 'env'
        assert 'algo' in data
        algo = data['algo']
        assert 'policy' in data
        pol = data['policy']
        bl = data['baseline']
        # Override the resumed algorithm's schedule with the CLI values.
        oldBatchSize = algo.batch_size
        algo.n_itr = numItrs
        if (len(vals) > 2):
            algo.batch_size = batchSize
            print(
                'algo iters : {} cur iter :{} oldBatchSize : {} newBatchSize : {}'
                .format(algo.n_itr, algo.current_itr, oldBatchSize,
                        algo.batch_size))
        else:
            print('algo iters : {} cur iter :{} '.format(
                algo.n_itr, algo.current_itr))
        algo.train()
    else:
        print('Not resuming - building new exp')
        # read from stdin
        if args.use_cloudpickle:  #set to use cloudpickle
            import cloudpickle
            method_call = cloudpickle.loads(base64.b64decode(args.args_data))
            method_call(variant_data)
        else:
            print('not use cloud pickle')
            data = pickle.loads(base64.b64decode(args.args_data))
            maybe_iter = concretize(data)
            if is_iterable(maybe_iter):
                for _ in maybe_iter:
                    pass

    # Restore the logger to its pre-run state.
    logger.set_snapshot_mode(prev_mode)
    logger.set_snapshot_dir(prev_snapshot_dir)
    logger.remove_tabular_output(tabular_log_file)
    logger.remove_text_output(text_log_file)
    logger.pop_prefix()
            optimal_a_H_freqs = _frequency_agreement(H_act_seqs,
                                                     optimal_a_seqs)
            optimal_a_R_freqs = _frequency_agreement(R_act_seqs,
                                                     optimal_a_seqs)

            for t in range(max_timesteps):
                logger.record_tabular("PolicyExecTime", 0)
                logger.record_tabular("EnvExecTime", 0)
                logger.record_tabular("ProcessExecTime", 0)

                logger.record_tabular("Tested Against", human_policy_name)
                logger.record_tabular("t", t)
                logger.record_tabular("a_H_agreement", optimal_a_H_freqs[t])
                logger.record_tabular("a_R_agreement", optimal_a_R_freqs[t])

                H_act_seqs_truncated = [a_Hs[0:t] for a_Hs in H_act_seqs]
                R_act_seqs_truncated = [a_Rs[0:t] for a_Rs in R_act_seqs]
                h_mutual_info = _mutual_info_seqs(H_act_seqs_truncated,
                                                  best_arms, bandit.nA + 1)
                r_mutual_info = _mutual_info_seqs(R_act_seqs_truncated,
                                                  best_arms, bandit.nA + 1)
                logger.record_tabular("h_mutual_info", h_mutual_info)
                logger.record_tabular("r_mutual_info", r_mutual_info)
                logger.record_tabular("a_H_opt_freq", optimal_a_H_freqs[t])
                logger.record_tabular("a_R_opt_freq", optimal_a_R_freqs[t])
                logger.dump_tabular()

    if text_output_file is not None:
        logger.remove_text_output(text_output_file)
        logger.remove_tabular_output(info_theory_tabular_output)
Esempio n. 11
0
def run_experiment(argv):
    """Parse CLI arguments, configure the rllab logger (including tensorboard,
    checkpoint, and observation directories), log git/host metadata, then
    either resume from ``--resume_from`` or execute the ``--args_data``
    payload.
    """
    default_log_dir = config.LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')

    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--n_parallel',
        type=int,
        default=1,
        help=
        'Number of parallel workers to perform rollouts. 0 => don\'t start any workers'
    )
    parser.add_argument('--exp_name',
                        type=str,
                        default=default_exp_name,
                        help='Name of the experiment.')
    parser.add_argument('--log_dir',
                        type=str,
                        default=None,
                        help='Path to save the log and iteration snapshot.')
    parser.add_argument('--snapshot_mode',
                        type=str,
                        default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                        '(all iterations will be saved), "last" (only '
                        'the last iteration will be saved), "gap" (every'
                        '`snapshot_gap` iterations are saved), or "none" '
                        '(do not save snapshots)')
    parser.add_argument('--snapshot_gap',
                        type=int,
                        default=1,
                        help='Gap between snapshot iterations.')
    parser.add_argument('--tabular_log_file',
                        type=str,
                        default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file',
                        type=str,
                        default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--tensorboard_log_dir',
                        type=str,
                        default='tb',
                        help='Name of the folder for tensorboard_summary.')
    parser.add_argument(
        '--tensorboard_step_key',
        type=str,
        default=None,
        help=
        'Name of the step key in log data which shows the step in tensorboard_summary.'
    )
    parser.add_argument('--params_log_file',
                        type=str,
                        default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--variant_log_file',
                        type=str,
                        default='variant.json',
                        help='Name of the variant log file (in json).')
    parser.add_argument(
        '--resume_from',
        type=str,
        default=None,
        help='Name of the pickle file to resume experiment from.')
    parser.add_argument('--plot',
                        type=ast.literal_eval,
                        default=False,
                        help='Whether to plot the iteration results')
    parser.add_argument(
        '--log_tabular_only',
        type=ast.literal_eval,
        default=False,
        help=
        'Whether to only print the tabular log information (in a horizontal format)'
    )
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data',
                        type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--variant_data',
                        type=str,
                        help='Pickled data for variant configuration')
    parser.add_argument('--use_cloudpickle',
                        type=ast.literal_eval,
                        default=False)
    parser.add_argument('--checkpoint_dir',
                        type=str,
                        default='checkpoint',
                        help='Name of the folder for checkpoints.')
    parser.add_argument('--obs_dir',
                        type=str,
                        default='obs',
                        help='Name of the folder for original observations.')

    args = parser.parse_args(argv[1:])

    if args.seed is not None:
        set_seed(args.seed)

    # Optional parallel sampler / plotter workers.
    if args.n_parallel > 0:
        from rllab.sampler import parallel_sampler
        parallel_sampler.initialize(n_parallel=args.n_parallel)
        if args.seed is not None:
            parallel_sampler.set_seed(args.seed)

    if args.plot:
        from rllab.plotter import plotter
        plotter.init_worker()

    # Resolve the output layout under log_dir.
    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)
    tensorboard_log_dir = osp.join(log_dir, args.tensorboard_log_dir)
    checkpoint_dir = osp.join(log_dir, args.checkpoint_dir)
    obs_dir = osp.join(log_dir, args.obs_dir)

    # NOTE(review): variant_data is unpickled from a CLI string — trusted
    # launcher input only.
    if args.variant_data is not None:
        variant_data = pickle.loads(base64.b64decode(args.variant_data))
        variant_log_file = osp.join(log_dir, args.variant_log_file)
        logger.log_variant(variant_log_file, variant_data)
    else:
        variant_data = None

    if not args.use_cloudpickle:
        logger.log_parameters_lite(params_log_file, args)

    # Configure the global logger, remembering the previous snapshot state
    # so it can be restored after the run.
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    logger.set_tensorboard_dir(tensorboard_log_dir)
    logger.set_checkpoint_dir(checkpoint_dir)
    logger.set_obs_dir(obs_dir)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_snapshot_gap(args.snapshot_gap)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.set_tensorboard_step_key(args.tensorboard_step_key)
    logger.push_prefix("[%s] " % args.exp_name)

    # Record provenance: git commit, diff patch, host/process identity.
    git_commit = get_git_commit_hash()
    logger.log('Git commit: {}'.format(git_commit))

    git_diff_file_path = osp.join(log_dir,
                                  'git_diff_{}.patch'.format(git_commit))
    save_git_diff_to_file(git_diff_file_path)

    logger.log('hostname: {}, pid: {}, tmux session: {}'.format(
        socket.gethostname(), os.getpid(), get_tmux_session_name()))

    if args.resume_from is not None:
        data = joblib.load(args.resume_from)
        assert 'algo' in data
        algo = data['algo']
        algo.train()
    else:
        # read from stdin
        if args.use_cloudpickle:
            import cloudpickle
            method_call = cloudpickle.loads(base64.b64decode(args.args_data))
            method_call(variant_data)
        else:
            data = pickle.loads(base64.b64decode(args.args_data))
            maybe_iter = concretize(data)
            if is_iterable(maybe_iter):
                for _ in maybe_iter:
                    pass

    # Restore the logger to its pre-run state.
    logger.set_snapshot_mode(prev_mode)
    logger.set_snapshot_dir(prev_snapshot_dir)
    logger.remove_tabular_output(tabular_log_file)
    logger.remove_text_output(text_log_file)
    logger.pop_prefix()
Esempio n. 12
0
def run_experiment(
    args_data,
    variant_data=None,
    seed=None,
    n_parallel=1,
    exp_name=None,
    log_dir=None,
    snapshot_mode='all',
    snapshot_gap=1,
    tabular_log_file='progress.csv',
    text_log_file='debug.log',
    params_log_file='params.json',
    variant_log_file='variant.json',
    resume_from=None,
    plot=False,
    log_tabular_only=False,
    log_debug_log_only=False,
):
    """Configure the rllab logger and run an experiment.

    Either resumes the algorithm pickled at *resume_from* (calling its
    ``train()``), or invokes ``args_data(variant_data)``.

    Args:
        args_data: Callable executed with *variant_data* when not resuming.
        variant_data: Optional variant dict, also logged to *variant_log_file*.
        seed: Optional random seed.
        n_parallel: Number of parallel sampler workers; 0 disables them.
        exp_name: Experiment name; a timestamp-based default is generated.
        log_dir: Output directory; defaults to ``LOG_DIR/<exp_name>``.
        snapshot_mode / snapshot_gap: Logger snapshot policy.
        tabular_log_file / text_log_file / params_log_file / variant_log_file:
            File names created under *log_dir*.
        resume_from: Path of a pickle with an 'algo' entry to resume.
        plot: Whether to start the plotter worker.
        log_tabular_only / log_debug_log_only: Console output filters.
    """
    default_log_dir = config.LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')

    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
    if exp_name is None:
        exp_name = default_exp_name

    if seed is not None:
        set_seed(seed)

    if n_parallel > 0:
        from rllab.sampler import parallel_sampler
        parallel_sampler.initialize(n_parallel=n_parallel)
        if seed is not None:
            parallel_sampler.set_seed(seed)

    if plot:
        from rllab.plotter import plotter
        plotter.init_worker()

    if log_dir is None:
        log_dir = osp.join(default_log_dir, exp_name)
    tabular_log_file = osp.join(log_dir, tabular_log_file)
    text_log_file = osp.join(log_dir, text_log_file)
    params_log_file = osp.join(log_dir, params_log_file)

    if variant_data is not None:
        variant_log_file = osp.join(log_dir, variant_log_file)
        logger.log_variant(variant_log_file, variant_data)

    # Configure the global logger, remembering the previous snapshot state.
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(snapshot_mode)
    logger.set_snapshot_gap(snapshot_gap)
    logger.set_log_tabular_only(log_tabular_only)
    logger.set_debug_log_only(log_debug_log_only)
    logger.push_prefix("[%s] " % exp_name)

    try:
        if resume_from is not None:
            data = joblib.load(resume_from)
            assert 'algo' in data
            algo = data['algo']
            algo.train()
        else:
            args_data(variant_data)
    finally:
        # BUG FIX: restore the global logger state even when the experiment
        # raises; previously a failure left outputs/prefix attached.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
Esempio n. 13
0
def close_logger(log_dir):
    """Detach the progress/debug outputs under *log_dir* and pop the logger prefix."""
    progress_path = osp.join(log_dir, 'progress.csv')
    debug_path = osp.join(log_dir, 'debug.log')
    logger.remove_tabular_output(progress_path)
    logger.remove_text_output(debug_path)
    logger.pop_prefix()
def run_task(v, num_cpu=8, log_dir="./data", ename=None, **kwargs):
    """Train a PPO robot policy against a simulated human bandit player, then
    evaluate the trained policy against every policy in ``human_policy_dict``.

    Args:
        v: variant dict of hyperparameters (``n_arms``, ``n_episodes``,
            ``seed``, ``human_policy``, ``human_wrapper``, optimizer settings,
            evaluation counts, ...).
        num_cpu: accepted for launcher compatibility; not used in this body.
        log_dir: root directory for text/tabular logs; may be overridden when
            running remotely, and is nested by horizon and human-policy name.
        ename: optional experiment name (unused; see commented-out ``scope``).
        **kwargs: extra launcher arguments, ignored.
    """
    from scipy.stats import bernoulli, uniform, beta

    import tensorflow as tf
    from assistive_bandits.experiments.l2_rnn_baseline import L2RNNBaseline
    from assistive_bandits.experiments.tbptt_optimizer import TBPTTOptimizer
    from sandbox.rocky.tf.policies.categorical_gru_policy import CategoricalGRUPolicy
    from assistive_bandits.experiments.pposgd_clip_ratio import PPOSGD

    # When forced to run remotely, inherit the log directory from rl_algs'
    # own logger instead of the local default.
    if not local_test and force_remote:
        import rl_algs.logger as rl_algs_logger
        log_dir = rl_algs_logger.get_dir()

    # Nest logs by horizon length, then by the training human-policy name.
    if log_dir is not None:
        log_dir = osp.join(log_dir, str(v["n_episodes"]))
        log_dir = osp.join(log_dir, v["human_policy"])

    text_output_file = None if log_dir is None else osp.join(log_dir, "text")
    tabular_output_file = None if log_dir is None else osp.join(
        log_dir, "train_table.csv")
    info_theory_tabular_output = None if log_dir is None else osp.join(
        log_dir, "info_table.csv")
    # NOTE(review): presumably yields uniformly-drawn Bernoulli arm
    # parameters for each episode — confirm against uniform_bernoulli_iterator.
    rag = uniform_bernoulli_iterator()
    bandit = BanditEnv(n_arms=v["n_arms"],
                       reward_dist=bernoulli,
                       reward_args_generator=rag,
                       horizon=v["n_episodes"])
    pi_H = human_policy_dict[v["human_policy"]](bandit)

    h_wrapper = human_wrapper_dict[v["human_wrapper"]]

    # Wrap the bandit with the simulated human; the robot pays a penalty
    # for intervening during training.
    env = h_wrapper(bandit, pi_H, penalty=v["intervention_penalty"])

    # Register log sinks only when a log directory was resolved.
    if text_output_file is not None:
        logger.add_text_output(text_output_file)
        logger.add_tabular_output(tabular_output_file)

    logger.log("Training against {}".format(v["human_policy"]))
    logger.log("Setting seed to {}".format(v["seed"]))
    env.seed(v["seed"])

    # Recurrent value-function baseline trained with truncated BPTT.
    baseline = L2RNNBaseline(
        name="vf",
        env_spec=env.spec,
        log_loss_before=False,
        log_loss_after=False,
        hidden_nonlinearity=getattr(tf.nn, v["nonlinearity"]),
        weight_normalization=v["weight_normalization"],
        layer_normalization=v["layer_normalization"],
        state_include_action=False,
        hidden_dim=v["hidden_dim"],
        optimizer=TBPTTOptimizer(
            batch_size=v["opt_batch_size"],
            n_steps=v["opt_n_steps"],
            n_epochs=v["min_epochs"],
        ),
        batch_size=v["opt_batch_size"],
        n_steps=v["opt_n_steps"],
    )
    policy = CategoricalGRUPolicy(env_spec=env.spec,
                                  hidden_nonlinearity=getattr(
                                      tf.nn, v["nonlinearity"]),
                                  hidden_dim=v["hidden_dim"],
                                  state_include_action=True,
                                  name="policy")

    # Short run for local smoke tests, full run otherwise.
    n_itr = 3 if local_test else 100
    # logger.log('sampler_args {}'.format(dict(n_envs=max(1, min(int(np.ceil(v["batch_size"] / v["n_episodes"])), 100)))))

    # parallel_sampler.initialize(6)
    # parallel_sampler.set_seed(v["seed"])

    algo = PPOSGD(
        env=env,
        policy=policy,
        baseline=baseline,
        batch_size=v["batch_size"],
        max_path_length=v["n_episodes"],
        # 43.65 env time
        sampler_args=dict(n_envs=max(
            1, min(int(np.ceil(v["batch_size"] / v["n_episodes"])), 100))),
        # 100 threads -> 1:36 to sample 187.275
        # 6 threads -> 1:31
        # force_batch_sampler=True,
        n_itr=n_itr,
        step_size=v["mean_kl"],
        clip_lr=v["clip_lr"],
        log_loss_kl_before=False,
        log_loss_kl_after=False,
        use_kl_penalty=v["use_kl_penalty"],
        min_n_epochs=v["min_epochs"],
        entropy_bonus_coeff=v["entropy_bonus_coeff"],
        optimizer=TBPTTOptimizer(
            batch_size=v["opt_batch_size"],
            n_steps=v["opt_n_steps"],
            n_epochs=v["min_epochs"],
        ),
        discount=v["discount"],
        gae_lambda=v["gae_lambda"],
        use_line_search=True
        # scope=ename
    )

    # Training and all evaluations share one TF session.
    sess = tf.Session()
    with sess.as_default():
        algo.train(sess)

        # Swap the training table for the info-theory table before evaluation.
        if text_output_file is not None:
            logger.remove_tabular_output(tabular_output_file)
            logger.add_tabular_output(info_theory_tabular_output)

        # Now gather statistics for t-tests and such!
        for human_policy_name, human_policy in human_policy_dict.items():

            logger.log("-------------------")
            logger.log("Evaluating against {}".format(human_policy.__name__))
            logger.log("-------------------")

            logger.log("Obtaining Samples...")
            test_pi_H = human_policy(bandit)
            # Evaluate with no intervention penalty; algo's env/batch_size are
            # mutated in place so the sampler rolls out against the test human.
            test_env = h_wrapper(bandit, test_pi_H, penalty=0.)
            eval_sampler = VectorizedSampler(algo, n_envs=100)
            algo.batch_size = v["num_eval_traj"] * v["n_episodes"]
            algo.env = test_env
            logger.log("algo.env.pi_H has class: {}".format(
                algo.env.pi_H.__class__))
            eval_sampler.start_worker()
            paths = eval_sampler.obtain_samples(-1)
            eval_sampler.shutdown_worker()
            rewards = []

            # Recover per-step human/robot actions and the best arm per path.
            H_act_seqs = []
            R_act_seqs = []
            best_arms = []
            optimal_a_seqs = []
            for p in paths:
                a_Rs = env.action_space.unflatten_n(p['actions'])
                obs_R = env.observation_space.unflatten_n(p['observations'])
                best_arm = np.argmax(p['env_infos']['arm_means'][0])

                # NOTE(review): column 1 of the robot's observation appears to
                # hold the human's action — confirm against the wrapper's
                # observation layout.
                H_act_seqs.append(obs_R[:, 1])
                R_act_seqs.append(a_Rs)
                best_arms.append(best_arm)
                optimal_a_seqs.append(
                    [best_arm for _ in range(v["n_episodes"])])

                rewards.append(np.sum(p['rewards']))

            #feel free to add more data
            logger.log("NumTrajs {}".format(v["num_eval_traj"]))
            logger.log("AverageReturn {}".format(np.mean(rewards)))
            logger.log("StdReturn {}".format(np.std(rewards)))
            logger.log("MaxReturn {}".format(np.max(rewards)))
            logger.log("MinReturn {}".format(np.min(rewards)))

            optimal_a_H_freqs = _frequency_agreement(H_act_seqs,
                                                     optimal_a_seqs)
            optimal_a_R_freqs = _frequency_agreement(R_act_seqs,
                                                     optimal_a_seqs)

            # One tabular row per timestep: agreement with the optimal arm and
            # mutual information between action prefixes and the best arm.
            for t in range(v["n_episodes"]):
                logger.record_tabular("PolicyExecTime", 0)
                logger.record_tabular("EnvExecTime", 0)
                logger.record_tabular("ProcessExecTime", 0)
                logger.record_tabular("Tested Against", human_policy_name)
                logger.record_tabular("t", t)
                logger.record_tabular("a_H_agreement", optimal_a_H_freqs[t])
                logger.record_tabular("a_R_agreement", optimal_a_R_freqs[t])

                H_act_seqs_truncated = [a_Hs[0:t] for a_Hs in H_act_seqs]
                R_act_seqs_truncated = [a_Rs[0:t] for a_Rs in R_act_seqs]
                h_mutual_info = _mutual_info_seqs(H_act_seqs_truncated,
                                                  best_arms, v["n_arms"] + 1)
                r_mutual_info = _mutual_info_seqs(R_act_seqs_truncated,
                                                  best_arms, v["n_arms"] + 1)
                logger.record_tabular("h_mutual_info", h_mutual_info)
                logger.record_tabular("r_mutual_info", r_mutual_info)
                logger.record_tabular("a_H_opt_freq", optimal_a_H_freqs[t])
                logger.record_tabular("a_R_opt_freq", optimal_a_R_freqs[t])
                logger.dump_tabular()

            # Fresh wrapper instance for printing qualitative rollouts.
            test_env = h_wrapper(bandit, human_policy(bandit), penalty=0.)

            logger.log("Printing Example Trajectories")
            for i in range(v["num_display_traj"]):
                observation = test_env.reset()
                policy.reset()
                logger.log("-- Trajectory {} of {}".format(
                    i + 1, v["num_display_traj"]))
                logger.log("t \t obs \t act \t reward \t act_probs")
                for t in range(v["n_episodes"]):
                    action, act_info = policy.get_action(observation)
                    new_obs, reward, done, info = test_env.step(action)
                    logger.log("{} \t {} \t {} \t {} \t {}".format(
                        t, observation, action, reward, act_info['prob']))
                    observation = new_obs
                    if done:
                        logger.log("Total reward: {}".format(
                            info["accumulated rewards"]))
                        break

    # Detach the log sinks registered at the top of this function.
    if text_output_file is not None:
        logger.remove_text_output(text_output_file)
        logger.remove_tabular_output(info_theory_tabular_output)
Esempio n. 15
0
def run_experiment(argv):
    """Entry point for a launched rllab-style experiment.

    Parses CLI flags, configures the global logger (text, tabular, and
    snapshot outputs), then either resumes a joblib-pickled algorithm
    (``--resume_from``) or unpickles and runs the stub/cloudpickled call
    passed via ``--args_data``.  Prior logger snapshot settings are saved
    before the run and restored afterwards.

    Args:
        argv: full argument vector; argv[1:] is parsed with argparse.
    """
    default_log_dir = config.LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')

    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
    parser = argparse.ArgumentParser()
    parser.add_argument('--n_parallel', type=int, default=1,
                        help='Number of parallel workers to perform rollouts. 0 => don\'t start any workers')
    parser.add_argument(
        '--exp_name', type=str, default=default_exp_name, help='Name of the experiment.')
    parser.add_argument('--log_dir', type=str, default=None,
                        help='Path to save the log and iteration snapshot.')
    parser.add_argument('--snapshot_mode', type=str, default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                             '(all iterations will be saved), "last" (only '
                             'the last iteration will be saved), "gap" (every'
                             '`snapshot_gap` iterations are saved), or "none" '
                             '(do not save snapshots)')
    parser.add_argument('--snapshot_gap', type=int, default=1,
                        help='Gap between snapshot iterations.')
    parser.add_argument('--tabular_log_file', type=str, default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file', type=str, default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file', type=str, default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--variant_log_file', type=str, default='variant.json',
                        help='Name of the variant log file (in json).')
    parser.add_argument('--resume_from', type=str, default=None,
                        help='Name of the pickle file to resume experiment from.')
    parser.add_argument('--plot', type=ast.literal_eval, default=False,
                        help='Whether to plot the iteration results')
    parser.add_argument('--log_tabular_only', type=ast.literal_eval, default=False,
                        help='Whether to only print the tabular log information (in a horizontal format)')
    parser.add_argument('--seed', type=int,
                        help='Random seed for numpy')
    parser.add_argument('--args_data', type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--variant_data', type=str,
                        help='Pickled data for variant configuration')
    parser.add_argument('--use_cloudpickle', type=ast.literal_eval, default=False)

    args = parser.parse_args(argv[1:])

    if args.seed is not None:
        set_seed(args.seed)

    # Spin up rollout workers before the heavy imports inside the algorithm run.
    if args.n_parallel > 0:
        from rllab.sampler import parallel_sampler
        parallel_sampler.initialize(n_parallel=args.n_parallel)
        if args.seed is not None:
            parallel_sampler.set_seed(args.seed)

    if args.plot:
        from rllab.plotter import plotter
        plotter.init_worker()

    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    # NOTE(review): unpickling data supplied on the command line — safe only
    # when the launcher is trusted; pickle can execute arbitrary code.
    if args.variant_data is not None:
        variant_data = pickle.loads(base64.b64decode(args.variant_data))
        variant_log_file = osp.join(log_dir, args.variant_log_file)
        logger.log_variant(variant_log_file, variant_data)
    else:
        variant_data = None

    if not args.use_cloudpickle:
        logger.log_parameters_lite(params_log_file, args)

    # Save the logger's previous snapshot settings so they can be restored
    # after the run (the logger is global, shared state).
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_snapshot_gap(args.snapshot_gap)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    if args.resume_from is not None:
        data = joblib.load(args.resume_from)
        assert 'algo' in data
        algo = data['algo']
        algo.train()
    else:
        # read from stdin
        if args.use_cloudpickle:
            import cloudpickle
            method_call = cloudpickle.loads(base64.b64decode(args.args_data))
            method_call(variant_data)
        else:
            # Stub-object path: concretize may return an iterable of results
            # to drain, or a plain object.
            data = pickle.loads(base64.b64decode(args.args_data))
            maybe_iter = concretize(data)
            if is_iterable(maybe_iter):
                for _ in maybe_iter:
                    pass

    # Restore the global logger state saved above.
    logger.set_snapshot_mode(prev_mode)
    logger.set_snapshot_dir(prev_snapshot_dir)
    logger.remove_tabular_output(tabular_log_file)
    logger.remove_text_output(text_log_file)
    logger.pop_prefix()
Esempio n. 16
0
def run_experiment(argv):
    """Entry point for a launched experiment (variant with TF summaries).

    Same flow as the plain runner — parse flags, configure the global logger,
    resume or unpickle-and-run the experiment, then restore logger state —
    but additionally sets a TensorFlow summary directory and drains an
    iterable returned by a resumed ``algo.train()``.

    Args:
        argv: full argument vector; argv[1:] is parsed with argparse.
    """
    # e2crawfo: These imports, in this order, were necessary for fixing issues on cedar.
    import rllab.mujoco_py.mjlib
    import tensorflow

    default_log_dir = config.LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')

    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--n_parallel',
        type=int,
        default=1,
        help=
        'Number of parallel workers to perform rollouts. 0 => don\'t start any workers'
    )
    parser.add_argument('--exp_name',
                        type=str,
                        default=default_exp_name,
                        help='Name of the experiment.')
    parser.add_argument('--log_dir',
                        type=str,
                        default=None,
                        help='Path to save the log and iteration snapshot.')
    parser.add_argument('--snapshot_mode',
                        type=str,
                        default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                        '(all iterations will be saved), "last" (only '
                        'the last iteration will be saved), or "none" '
                        '(do not save snapshots)')
    parser.add_argument('--snapshot_gap',
                        type=int,
                        default=1,
                        help='Gap between snapshot iterations.')
    parser.add_argument('--tabular_log_file',
                        type=str,
                        default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file',
                        type=str,
                        default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file',
                        type=str,
                        default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--variant_log_file',
                        type=str,
                        default='variant.json',
                        help='Name of the variant log file (in json).')
    parser.add_argument(
        '--resume_from',
        type=str,
        default=None,
        help='Name of the pickle file to resume experiment from.')
    parser.add_argument('--plot',
                        type=ast.literal_eval,
                        default=False,
                        help='Whether to plot the iteration results')
    parser.add_argument(
        '--log_tabular_only',
        type=ast.literal_eval,
        default=False,
        help=
        'Whether to only print the tabular log information (in a horizontal format)'
    )
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data',
                        type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--variant_data',
                        type=str,
                        help='Pickled data for variant configuration')
    parser.add_argument('--use_cloudpickle',
                        type=ast.literal_eval,
                        default=False)

    args = parser.parse_args(argv[1:])

    if args.seed is not None:
        set_seed(args.seed)

    if args.n_parallel > 0:
        from rllab.sampler import parallel_sampler
        parallel_sampler.initialize(n_parallel=args.n_parallel)
        if args.seed is not None:
            parallel_sampler.set_seed(args.seed)

    if args.plot:
        from rllab.plotter import plotter
        plotter.init_worker()

    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    # NOTE(review): unpickling data supplied on the command line — safe only
    # when the launcher is trusted; pickle can execute arbitrary code.
    if args.variant_data is not None:
        variant_data = pickle.loads(base64.b64decode(args.variant_data))
        variant_log_file = osp.join(log_dir, args.variant_log_file)
        logger.log_variant(variant_log_file, variant_data)
    else:
        variant_data = None

    if not args.use_cloudpickle:
        logger.log_parameters_lite(params_log_file, args)

    # Save the logger's previous snapshot settings so they can be restored
    # after the run (the logger is global, shared state).
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    # Unlike the plain runner, also point TF summaries under the log dir.
    logger.set_tf_summary_dir(osp.join(log_dir, "tf_summary"))
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_snapshot_gap(args.snapshot_gap)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    if args.resume_from is not None:
        data = joblib.load(args.resume_from)
        assert 'algo' in data
        algo = data['algo']
        # train() may be a generator; drain it so training actually runs.
        maybe_iter = algo.train()
        if is_iterable(maybe_iter):
            for _ in maybe_iter:
                pass
    else:
        # read from stdin
        if args.use_cloudpickle:
            import cloudpickle
            method_call = cloudpickle.loads(base64.b64decode(args.args_data))
            method_call(variant_data)
        else:
            data = pickle.loads(base64.b64decode(args.args_data))
            maybe_iter = concretize(data)
            if is_iterable(maybe_iter):
                for _ in maybe_iter:
                    pass

    # Restore the global logger state saved above.
    logger.set_snapshot_mode(prev_mode)
    logger.set_snapshot_dir(prev_snapshot_dir)
    logger.remove_tabular_output(tabular_log_file)
    logger.remove_text_output(text_log_file)
    logger.pop_prefix()
Esempio n. 17
0
        #     if params['use_hide_alg'] == 1:
        #         if params['batch_size_uniform'] is not None and params['batch_size_uniform'] > 0:
        #             logger.log('WARNING: Training with uniform sampling. Testing is done BEFORE the optimization !!!!')
        #             algo.train_brownian_with_uniform()
        #         else:
        #             algo.train_brownian()
        #     elif params['use_hide_alg'] == 2:
        #         if train_mode == 0:
        #             algo.train_brownian_with_goals()
        #         elif train_mode == 1:
        #             algo.train_brownian_reverse_repeat()
        #         elif train_mode == 2:
        #             algo.train_brownian_multiseed()
        #         elif train_mode == 3:
        #             algo.train_brownian_multiseed_swap_every_update_period()
        #         else:
        #             raise NotImplementedError
        #     else: #my version of the alg
        #         algo.train_hide_seek()
        # else:
        #     algo.train_seek()
        logger.log('Experiment finished ...')

        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file_fullpath)
        logger.remove_text_output(text_log_file_fullpath)
        logger.pop_prefix()

        print('Tabular log file:', tabular_log_file_fullpath)
Esempio n. 18
0
def run_experiment(argv):
    """Entry point for a launched experiment (exploration-sampler variant).

    Parses CLI flags, configures the global logger, unpickles the stub call
    from ``--args_data`` and runs it, then restores the logger's previous
    snapshot settings.  Uses the sandbox.vase exploration parallel sampler
    and logs directly under ``--log_dir`` (no per-experiment subdirectory).

    Args:
        argv: full argument vector; argv[1:] is parsed with argparse.
    """

    default_log_dir = config.LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')

    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
    parser = argparse.ArgumentParser()
    parser.add_argument('--n_parallel',
                        type=int,
                        default=1,
                        help='Number of parallel workers to perform rollouts.')
    parser.add_argument('--exp_name',
                        type=str,
                        default=default_exp_name,
                        help='Name of the experiment.')
    parser.add_argument('--log_dir',
                        type=str,
                        default=default_log_dir,
                        help='Path to save the log and iteration snapshot.')
    parser.add_argument('--snapshot_mode',
                        type=str,
                        default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                        '(all iterations will be saved), "last" (only '
                        'the last iteration will be saved), or "none" '
                        '(do not save snapshots)')
    parser.add_argument('--snapshot_gap',
                        type=int,
                        default=1,
                        help='Gap between snapshot iterations.')
    parser.add_argument('--tabular_log_file',
                        type=str,
                        default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file',
                        type=str,
                        default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file',
                        type=str,
                        default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--plot',
                        type=ast.literal_eval,
                        default=False,
                        help='Whether to plot the iteration results')
    parser.add_argument(
        '--log_tabular_only',
        type=ast.literal_eval,
        default=False,
        help=
        'Whether to only print the tabular log information (in a horizontal format)'
    )
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data',
                        type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--use_cloudpickle',
                        type=ast.literal_eval,
                        default=False,
                        help='Whether to plot the iteration results')

    args = parser.parse_args(argv[1:])

    if args.seed is not None:
        set_seed(args.seed)

    if args.n_parallel > 0:
        from sandbox.vase.sampler import parallel_sampler_expl as parallel_sampler
        parallel_sampler.initialize(n_parallel=args.n_parallel)

        # Re-seed after worker initialization so workers pick up the seed too.
        if args.seed is not None:
            set_seed(args.seed)
            parallel_sampler.set_seed(args.seed)

    if args.plot:
        from rllab.plotter import plotter
        plotter.init_worker()

    # read from stdin
    # NOTE(review): unpickling data supplied on the command line — safe only
    # when the launcher is trusted; pickle can execute arbitrary code.
    data = pickle.loads(base64.b64decode(args.args_data))

    log_dir = args.log_dir
    # exp_dir = osp.join(log_dir, args.exp_name)
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    # Save the logger's previous snapshot settings so they can be restored
    # after the run (the logger is global, shared state).
    logger.log_parameters_lite(params_log_file, args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_gap(args.snapshot_gap)
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    # Stub-object path: concretize may return an iterable of results to drain.
    maybe_iter = concretize(data)
    if is_iterable(maybe_iter):
        for _ in maybe_iter:
            pass

    # Restore the global logger state saved above.
    logger.set_snapshot_mode(prev_mode)
    logger.set_snapshot_dir(prev_snapshot_dir)
    logger.remove_tabular_output(tabular_log_file)
    logger.remove_text_output(text_log_file)
    logger.pop_prefix()
Esempio n. 19
0
# Train TRPO on the env/policy/baseline constructed earlier in this script.
algo = TRPO(
    env=env,
    policy=policy,
    baseline=baseline,
    # batch_size=4000,
    batch_size=1000,
    max_path_length=env.horizon,
    n_itr=500,
    discount=0.99,
    step_size=0.0025,
    # Uncomment both lines (this and the plot parameter below) to enable plotting
    # plot=True,
)
algo.train()

# Detach the log outputs registered earlier and pop the logger prefix.
logger.remove_tabular_output(tabular_log_file)
logger.remove_text_output(text_log_file)
logger.pop_prefix()
# run_experiment_lite(
#     algo.train(),
#     # Number of parallel workers for sampling
#     n_parallel=1,
#     # Only keep the snapshot parameters for the last iteration
#     snapshot_mode="last",
#     # Specifies the seed for the experiment. If this is not provided, a random seed
#     # will be used
#     seed=1,
#     # plot=True,
# )