Example #1
def rllab_envpolicy_parser(env, args):
    if isinstance(args, dict):
        args = tonamedtuple(args)

    env = RLLabEnv(env, mode=args.control)
    if args.algo[:2] == 'tf':
        env = TfEnv(env)

        # Policy
        if args.recurrent:
            if args.feature_net:
                feature_network = MLP(
                    name='feature_net',
                    input_shape=(env.spec.observation_space.flat_dim +
                                 env.spec.action_space.flat_dim, ),
                    output_dim=args.feature_output,
                    hidden_sizes=tuple(args.feature_hidden),
                    hidden_nonlinearity=tf.nn.tanh,
                    output_nonlinearity=None)
            elif args.conv:
                strides = tuple(args.conv_strides)
                chans = tuple(args.conv_channels)
                filts = tuple(args.conv_filters)

                assert len(strides) == len(chans) == len(
                    filts), "strides, chans and filts not equal"
                # only discrete actions supported, should be straightforward to extend to continuous
                assert isinstance(
                    env.spec.action_space,
                    Discrete), "Only discrete action spaces support conv"
                feature_network = ConvNetwork(
                    name='feature_net',
                    input_shape=env.spec.observation_space.shape,
                    output_dim=args.feature_output,
                    conv_filters=chans,
                    conv_filter_sizes=filts,
                    conv_strides=strides,
                    conv_pads=('VALID', ) * len(chans),
                    hidden_sizes=tuple(args.feature_hidden),
                    hidden_nonlinearity=tf.nn.relu,
                    output_nonlinearity=None)
            else:
                feature_network = None
            if args.recurrent == 'gru':
                if isinstance(env.spec.action_space, Box):
                    policy = GaussianGRUPolicy(env_spec=env.spec,
                                               feature_network=feature_network,
                                               hidden_dim=int(
                                                   args.policy_hidden[0]),
                                               name='policy')
                elif isinstance(env.spec.action_space, Discrete):
                    policy = CategoricalGRUPolicy(
                        env_spec=env.spec,
                        feature_network=feature_network,
                        hidden_dim=int(args.policy_hidden[0]),
                        name='policy',
                        state_include_action=False if args.conv else True)
                else:
                    raise NotImplementedError(env.spec.action_space)

            elif args.recurrent == 'lstm':
                if isinstance(env.spec.action_space, Box):
                    policy = GaussianLSTMPolicy(
                        env_spec=env.spec,
                        feature_network=feature_network,
                        hidden_dim=int(args.policy_hidden[0]),
                        name='policy')
                elif isinstance(env.spec.action_space, Discrete):
                    policy = CategoricalLSTMPolicy(
                        env_spec=env.spec,
                        feature_network=feature_network,
                        hidden_dim=int(args.policy_hidden[0]),
                        name='policy')
                else:
                    raise NotImplementedError(env.spec.action_space)

            else:
                raise NotImplementedError(args.recurrent)
        elif args.conv:
            strides = tuple(args.conv_strides)
            chans = tuple(args.conv_channels)
            filts = tuple(args.conv_filters)

            assert len(strides) == len(chans) == len(
                filts), "strides, chans and filts not equal"
            # only discrete actions supported, should be straightforward to extend to continuous
            assert isinstance(
                env.spec.action_space,
                Discrete), "Only discrete action spaces support conv"
            feature_network = ConvNetwork(
                name='feature_net',
                input_shape=env.spec.observation_space.shape,
                output_dim=env.spec.action_space.n,
                conv_filters=chans,
                conv_filter_sizes=filts,
                conv_strides=strides,
                conv_pads=('VALID', ) * len(chans),
                hidden_sizes=tuple(args.policy_hidden),
                hidden_nonlinearity=tf.nn.relu,
                output_nonlinearity=tf.nn.softmax)
            policy = CategoricalMLPPolicy(name='policy',
                                          env_spec=env.spec,
                                          prob_network=feature_network)
        else:
            if isinstance(env.spec.action_space, Box):
                policy = GaussianMLPPolicy(env_spec=env.spec,
                                           hidden_sizes=tuple(
                                               args.policy_hidden),
                                           min_std=args.min_std,
                                           name='policy')
            elif isinstance(env.spec.action_space, Discrete):
                policy = CategoricalMLPPolicy(env_spec=env.spec,
                                              hidden_sizes=tuple(
                                                  args.policy_hidden),
                                              name='policy')
            else:
                raise NotImplementedError(env.spec.action_space)
    elif args.algo[:2] == 'th':
        # Policy
        if args.recurrent:
            if args.feature_net:
                feature_network = thMLP(
                    input_shape=(env.spec.observation_space.flat_dim +
                                 env.spec.action_space.flat_dim, ),
                    output_dim=args.feature_output,
                    hidden_sizes=tuple(args.feature_hidden),
                    hidden_nonlinearity=tf.nn.tanh,  # NB: thMLP (Theano) would normally expect a Theano/lasagne nonlinearity here
                    output_nonlinearity=None)
            else:
                feature_network = None
            if args.recurrent == 'gru':
                if isinstance(env.spec.action_space, thBox):
                    policy = thGaussianGRUPolicy(
                        env_spec=env.spec,
                        feature_network=feature_network,
                        hidden_dim=int(args.policy_hidden[0]),
                    )
                elif isinstance(env.spec.action_space, thDiscrete):
                    policy = thCategoricalGRUPolicy(
                        env_spec=env.spec,
                        feature_network=feature_network,
                        hidden_dim=int(args.policy_hidden[0]),
                    )
                else:
                    raise NotImplementedError(env.spec.action_space)

            # elif args.recurrent == 'lstm':
            #     if isinstance(env.spec.action_space, thBox):
            #         policy = thGaussianLSTMPolicy(env_spec=env.spec,
            #                                       feature_network=feature_network,
            #                                       hidden_dim=int(args.policy_hidden),
            #                                       name='policy')
            #     elif isinstance(env.spec.action_space, thDiscrete):
            #         policy = thCategoricalLSTMPolicy(env_spec=env.spec,
            #                                          feature_network=feature_network,
            #                                          hidden_dim=int(args.policy_hidden),
            #                                          name='policy')
            #     else:
            #         raise NotImplementedError(env.spec.action_space)

            else:
                raise NotImplementedError(args.recurrent)
        else:
            if args.algo == 'thddpg':
                assert isinstance(env.spec.action_space, thBox)
                policy = thDeterministicMLPPolicy(
                    env_spec=env.spec,
                    hidden_sizes=tuple(args.policy_hidden),
                )
            else:
                if isinstance(env.spec.action_space, thBox):
                    policy = thGaussianMLPPolicy(env_spec=env.spec,
                                                 hidden_sizes=tuple(
                                                     args.policy_hidden),
                                                 min_std=args.min_std)
                elif isinstance(env.spec.action_space, thDiscrete):
                    policy = thCategoricalMLPPolicy(env_spec=env.spec,
                                                    hidden_sizes=tuple(
                                                        args.policy_hidden))
                else:
                    raise NotImplementedError(env.spec.action_space)

    if args.control == 'concurrent':
        # Note: this example never builds a `policies` list; the concurrent
        # branches that construct one per agent appear in Example #3 below.
        return env, policies
    else:
        return env, policy
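
For orientation, here is a minimal sketch of how this parser could be driven with a plain dict of options. The environment constructor make_env() and the concrete values are assumptions for illustration; only the keys that the non-recurrent, non-convolutional TF branch above actually reads are included.

example_args = {
    'control': 'centralized',    # forwarded to RLLabEnv as mode=
    'algo': 'tftrpo',            # any value starting with 'tf' selects the TF branch
    'recurrent': None,           # skip the GRU/LSTM branches
    'conv': False,               # skip the ConvNetwork branch
    'policy_hidden': [128, 128],
    'min_std': 1e-6,             # only read for Box (continuous) action spaces
}

raw_env = make_env()  # placeholder for whatever environment is being wrapped
env, policy = rllab_envpolicy_parser(raw_env, example_args)
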
Example #2
logger.add_text_output(text_log_file)
logger.add_tabular_output(tabular_log_file)
prev_snapshot_dir = logger.get_snapshot_dir()
prev_mode = logger.get_snapshot_mode()
logger.set_snapshot_dir(log_dir)
logger.set_snapshot_mode("gaplast")
logger.set_snapshot_gap(1000)
logger.set_log_tabular_only(False)
logger.push_prefix("[%s] " % "FixMapStartState")

from Algo import parallel_sampler
parallel_sampler.initialize(n_parallel=1)
parallel_sampler.set_seed(0)

policy = CategoricalGRUPolicy(
    env_spec=env.spec,
    name="gru",
)

baseline = LinearFeatureBaseline(env_spec=env.spec)

with tf.Session() as sess:
    writer = tf.summary.FileWriter(logdir=log_dir, )

    algo = VPG_t(
        env=env,
        policy=policy,
        baseline=baseline,
        batch_size=2048,  #2*env._wrapped_env.params['traj_limit'],
        max_path_length=env._wrapped_env.params['traj_limit'],
        n_itr=10000,
        discount=0.95,
    )  # remaining arguments and the algo.train() call are truncated in the original snippet
Example #3
    def parse_env_args(self, env, args):

        if isinstance(args, dict):
            args = to_named_tuple(args)

        # Multi-agent wrapper
        env = RLLabEnv(env, ma_mode=args.control)
        env = MATfEnv(env)

        # Policy
        if args.recurrent:
            if args.feature_net:
                feature_network = MLP(
                    name='feature_net',
                    input_shape=(env.spec.observation_space.flat_dim +
                                 env.spec.action_space.flat_dim, ),
                    output_dim=args.feature_output,
                    hidden_sizes=tuple(args.feature_hidden),
                    hidden_nonlinearity=tf.nn.tanh,
                    output_nonlinearity=None)
            elif args.conv:
                strides = tuple(args.conv_strides)
                chans = tuple(args.conv_channels)
                filts = tuple(args.conv_filters)

                assert len(strides) == len(chans) == len(
                    filts), "strides, chans and filts not equal"
                # only discrete actions supported, should be straightforward to extend to continuous
                assert isinstance(
                    env.spec.action_space,
                    Discrete), "Only discrete action spaces support conv"
                feature_network = ConvNetwork(
                    name='feature_net',
                    input_shape=env.spec.observation_space.shape,
                    output_dim=args.feature_output,
                    conv_filters=chans,
                    conv_filter_sizes=filts,
                    conv_strides=strides,
                    conv_pads=('VALID', ) * len(chans),
                    hidden_sizes=tuple(args.feature_hidden),
                    hidden_nonlinearity=tf.nn.relu,
                    output_nonlinearity=None)
            else:
                feature_network = None
            if args.recurrent == 'gru':
                if isinstance(env.spec.action_space, Box):
                    if args.control == 'concurrent':
                        policies = [
                            GaussianGRUPolicy(env_spec=env.spec,
                                              feature_network=feature_network,
                                              hidden_dim=int(
                                                  args.policy_hidden[0]),
                                              name='policy_{}'.format(agid))
                            for agid in range(len(env.agents))
                        ]
                    policy = GaussianGRUPolicy(env_spec=env.spec,
                                               feature_network=feature_network,
                                               hidden_dim=int(
                                                   args.policy_hidden[0]),
                                               name='policy')
                elif isinstance(env.spec.action_space, Discrete):
                    if args.control == 'concurrent':
                        policies = [
                            CategoricalGRUPolicy(
                                env_spec=env.spec,
                                feature_network=feature_network,
                                hidden_dim=int(args.policy_hidden[0]),
                                name='policy_{}'.format(agid),
                                state_include_action=False
                                if args.conv else True)
                            for agid in range(len(env.agents))
                        ]
                    q_network = CategoricalGRUPolicy(
                        env_spec=env.spec,
                        feature_network=feature_network,
                        hidden_dim=int(args.policy_hidden[0]),
                        name='q_network',
                        state_include_action=False if args.conv else True)
                    target_q_network = CategoricalGRUPolicy(
                        env_spec=env.spec,
                        feature_network=feature_network,
                        hidden_dim=int(args.policy_hidden[0]),
                        name='target_q_network',
                        state_include_action=False if args.conv else True)
                    policy = {
                        'q_network': q_network,
                        'target_q_network': target_q_network
                    }
                else:
                    raise NotImplementedError(env.spec.action_space)

            elif args.recurrent == 'lstm':
                if isinstance(env.spec.action_space, Box):
                    if args.control == 'concurrent':
                        policies = [
                            GaussianLSTMPolicy(env_spec=env.spec,
                                               feature_network=feature_network,
                                               hidden_dim=int(
                                                   args.policy_hidden[0]),
                                               name='policy_{}'.format(agid))
                            for agid in range(len(env.agents))
                        ]
                    policy = GaussianLSTMPolicy(
                        env_spec=env.spec,
                        feature_network=feature_network,
                        hidden_dim=int(args.policy_hidden[0]),
                        name='policy')
                elif isinstance(env.spec.action_space, Discrete):
                    if args.control == 'concurrent':
                        policies = [
                            CategoricalLSTMPolicy(
                                env_spec=env.spec,
                                feature_network=feature_network,
                                hidden_dim=int(args.policy_hidden[0]),
                                name='policy_{}'.format(agid))
                            for agid in range(len(env.agents))
                        ]
                    q_network = CategoricalLSTMPolicy(
                        env_spec=env.spec,
                        feature_network=feature_network,
                        hidden_dim=int(args.policy_hidden[0]),
                        name='q_network')
                    target_q_network = CategoricalLSTMPolicy(
                        env_spec=env.spec,
                        feature_network=feature_network,
                        hidden_dim=int(args.policy_hidden[0]),
                        name='target_q_network')
                    policy = {
                        'q_network': q_network,
                        'target_q_network': target_q_network
                    }
                else:
                    raise NotImplementedError(env.spec.action_space)

            else:
                raise NotImplementedError(args.recurrent)
        elif args.conv:
            strides = tuple(args.conv_strides)
            chans = tuple(args.conv_channels)
            filts = tuple(args.conv_filters)

            assert len(strides) == len(chans) == len(
                filts), "strides, chans and filts not equal"
            # only discrete actions supported, should be straightforward to extend to continuous
            assert isinstance(
                env.spec.action_space,
                Discrete), "Only discrete action spaces support conv"
            feature_network = ConvNetwork(
                name='feature_net',
                input_shape=env.spec.observation_space.shape,
                output_dim=env.spec.action_space.n,
                conv_filters=chans,
                conv_filter_sizes=filts,
                conv_strides=strides,
                conv_pads=(args.conv_pads, ) * len(chans),
                hidden_sizes=tuple(args.policy_hidden),
                hidden_nonlinearity=tf.nn.relu,
                output_nonlinearity=tf.nn.softmax,
                batch_normalization=args.batch_normalization)
            if args.algo == 'dqn':
                q_network = CategoricalMLPPolicy(name='q_network',
                                                 env_spec=env.spec,
                                                 prob_network=feature_network)
                target_q_network = CategoricalMLPPolicy(
                    name='target_q_network',
                    env_spec=env.spec,
                    prob_network=feature_network)
                policy = {
                    'q_network': q_network,
                    'target_q_network': target_q_network
                }

            else:
                policy = CategoricalMLPPolicy(name='policy',
                                              env_spec=env.spec,
                                              prob_network=feature_network)
        else:
            if env.spec is None:

                networks = [
                    DQNNetwork(i,
                               env,
                               target_network_update_freq=self.args.
                               target_network_update,
                               discount_factor=self.args.discount,
                               batch_size=self.args.batch_size,
                               learning_rate=self.args.qfunc_lr)
                    for i in range(env.n)
                ]

                policy = networks

            elif isinstance(env.spec.action_space, Box):
                policy = GaussianMLPPolicy(env_spec=env.spec,
                                           hidden_sizes=tuple(
                                               args.policy_hidden),
                                           min_std=args.min_std,
                                           name='policy')
            elif isinstance(env.spec.action_space, Discrete):
                policy = CategoricalMLPPolicy(env_spec=env.spec,
                                              hidden_sizes=tuple(
                                                  args.policy_hidden),
                                              name='policy')
            else:
                raise NotImplementedError(env.spec.action_space)

        return env, policy
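
The to_named_tuple (and tonamedtuple) helper called by these parsers is not shown on this page. Below is a minimal sketch of one possible implementation, assuming a flat dict of option values with identifier-safe keys; it is an illustrative stand-in, not the helper's actual code.

from collections import namedtuple

def to_named_tuple(d):
    # Expose dict entries as attributes (args.control, args.recurrent, ...).
    Args = namedtuple('Args', sorted(d))
    return Args(**d)
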
Example #4
def main():
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')
    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)

    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_name',
                        type=str,
                        default=default_exp_name,
                        help='Name of the experiment.')

    parser.add_argument('--discount', type=float, default=0.99)
    parser.add_argument('--gae_lambda', type=float, default=1.0)
    parser.add_argument('--reward_scale', type=float, default=1.0)

    parser.add_argument('--n_iter', type=int, default=250)
    parser.add_argument('--sampler_workers', type=int, default=1)
    parser.add_argument('--max_traj_len', type=int, default=250)
    parser.add_argument('--update_curriculum',
                        action='store_true',
                        default=False)
    parser.add_argument('--n_timesteps', type=int, default=8000)
    parser.add_argument('--control', type=str, default='centralized')

    parser.add_argument('--rectangle', type=str, default='10,10')
    parser.add_argument('--map_type', type=str, default='rectangle')
    parser.add_argument('--n_evaders', type=int, default=5)
    parser.add_argument('--n_pursuers', type=int, default=2)
    parser.add_argument('--obs_range', type=int, default=3)
    parser.add_argument('--n_catch', type=int, default=2)
    parser.add_argument('--urgency', type=float, default=0.0)
    parser.add_argument('--pursuit', dest='train_pursuit', action='store_true')
    parser.add_argument('--evade', dest='train_pursuit', action='store_false')
    parser.set_defaults(train_pursuit=True)
    parser.add_argument('--surround', action='store_true', default=False)
    parser.add_argument('--constraint_window', type=float, default=1.0)
    parser.add_argument('--sample_maps', action='store_true', default=False)
    parser.add_argument('--map_file', type=str, default='../maps/map_pool.npy')
    parser.add_argument('--flatten', action='store_true', default=False)
    parser.add_argument('--reward_mech', type=str, default='global')
    parser.add_argument('--catchr', type=float, default=0.1)
    parser.add_argument('--term_pursuit', type=float, default=5.0)

    parser.add_argument('--recurrent', type=str, default=None)
    parser.add_argument('--policy_hidden_sizes', type=str, default='128,128')
    parser.add_argument('--baseline_hidden_sizes', type=str, default='128,128')
    parser.add_argument('--baseline_type', type=str, default='linear')

    parser.add_argument('--conv', action='store_true', default=False)

    parser.add_argument('--max_kl', type=float, default=0.01)

    parser.add_argument('--checkpoint', type=str, default=None)

    parser.add_argument('--log_dir', type=str, required=False)
    parser.add_argument('--tabular_log_file',
                        type=str,
                        default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file',
                        type=str,
                        default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file',
                        type=str,
                        default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data',
                        type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--snapshot_mode',
                        type=str,
                        default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                        '(all iterations will be saved), "last" (only '
                        'the last iteration will be saved), or "none" '
                        '(do not save snapshots)')
    parser.add_argument(
        '--log_tabular_only',
        type=ast.literal_eval,
        default=False,
        help=
        'Whether to only print the tabular log information (in a horizontal format)'
    )

    args = parser.parse_args()

    parallel_sampler.initialize(n_parallel=args.sampler_workers)

    if args.seed is not None:
        set_seed(args.seed)
        parallel_sampler.set_seed(args.seed)

    args.hidden_sizes = tuple(map(int, args.policy_hidden_sizes.split(',')))

    if args.checkpoint:
        with tf.Session() as sess:
            data = joblib.load(args.checkpoint)
            policy = data['policy']
            env = data['env']
    else:
        if args.sample_maps:
            map_pool = np.load(args.map_file)
        else:
            if args.map_type == 'rectangle':
                env_map = TwoDMaps.rectangle_map(
                    *map(int, args.rectangle.split(',')))
            elif args.map_type == 'complex':
                env_map = TwoDMaps.complex_map(
                    *map(int, args.rectangle.split(',')))
            else:
                raise NotImplementedError()
            map_pool = [env_map]

        env = PursuitEvade(map_pool,
                           n_evaders=args.n_evaders,
                           n_pursuers=args.n_pursuers,
                           obs_range=args.obs_range,
                           n_catch=args.n_catch,
                           train_pursuit=args.train_pursuit,
                           urgency_reward=args.urgency,
                           surround=args.surround,
                           sample_maps=args.sample_maps,
                           constraint_window=args.constraint_window,
                           flatten=args.flatten,
                           reward_mech=args.reward_mech,
                           catchr=args.catchr,
                           term_pursuit=args.term_pursuit)

        env = TfEnv(
            RLLabEnv(StandardizedEnv(env,
                                     scale_reward=args.reward_scale,
                                     enable_obsnorm=False),
                     mode=args.control))

        if args.recurrent:
            if args.conv:
                feature_network = ConvNetwork(
                    name='feature_net',
                    input_shape=env.spec.observation_space.shape,
                    output_dim=5,
                    conv_filters=(16, 32, 32),
                    conv_filter_sizes=(3, 3, 3),
                    conv_strides=(1, 1, 1),
                    conv_pads=('VALID', 'VALID', 'VALID'),
                    hidden_sizes=(64, ),
                    hidden_nonlinearity=tf.nn.relu,
                    output_nonlinearity=tf.nn.softmax)
            else:
                feature_network = MLP(
                    name='feature_net',
                    input_shape=(env.spec.observation_space.flat_dim +
                                 env.spec.action_space.flat_dim, ),
                    output_dim=5,
                    hidden_sizes=(256, 128, 64),
                    hidden_nonlinearity=tf.nn.tanh,
                    output_nonlinearity=None)
            if args.recurrent == 'gru':
                policy = CategoricalGRUPolicy(env_spec=env.spec,
                                              feature_network=feature_network,
                                              hidden_dim=args.hidden_sizes[0],
                                              name='policy')
            elif args.recurrent == 'lstm':
                policy = CategoricalLSTMPolicy(env_spec=env.spec,
                                               feature_network=feature_network,
                                               hidden_dim=args.hidden_sizes[0],
                                               name='policy')
        elif args.conv:
            feature_network = ConvNetwork(
                name='feature_net',
                input_shape=env.spec.observation_space.shape,
                output_dim=5,
                conv_filters=(8, 16),
                conv_filter_sizes=(3, 3),
                conv_strides=(2, 1),
                conv_pads=('VALID', 'VALID'),
                hidden_sizes=(32, ),
                hidden_nonlinearity=tf.nn.relu,
                output_nonlinearity=tf.nn.softmax)
            policy = CategoricalMLPPolicy(name='policy',
                                          env_spec=env.spec,
                                          prob_network=feature_network)
        else:
            policy = CategoricalMLPPolicy(name='policy',
                                          env_spec=env.spec,
                                          hidden_sizes=args.hidden_sizes)

    if args.baseline_type == 'linear':
        baseline = LinearFeatureBaseline(env_spec=env.spec)
    else:
        baseline = ZeroBaseline(env_spec=env.spec)

    # logger
    default_log_dir = config.LOG_DIR
    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    logger.log_parameters_lite(params_log_file, args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    algo = TRPO(
        env=env,
        policy=policy,
        baseline=baseline,
        batch_size=args.n_timesteps,
        max_path_length=args.max_traj_len,
        n_itr=args.n_iter,
        discount=args.discount,
        gae_lambda=args.gae_lambda,
        step_size=args.max_kl,
        optimizer=ConjugateGradientOptimizer(hvp_approach=FiniteDifferenceHvp(
            base_eps=1e-5)) if args.recurrent else None,
        mode=args.control,
    )

    algo.train()
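
One quick way to exercise main() above without a shell is to override sys.argv before calling it. The script name and the small run settings below are placeholders for illustration, not values taken from the original experiment.

import sys

# Hypothetical smoke-test invocation of main(); flags match the argparse
# options defined above, values are deliberately tiny.
sys.argv = [
    'pursuit_experiment.py',
    '--n_iter', '2',
    '--n_timesteps', '400',
    '--max_traj_len', '50',
    '--snapshot_mode', 'none',
]
main()
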
Example #5
def run_task(v, num_cpu=8, log_dir="./data", ename=None, **kwargs):
    from scipy.stats import bernoulli, uniform, beta

    import tensorflow as tf
    from assistive_bandits.experiments.l2_rnn_baseline import L2RNNBaseline
    from assistive_bandits.experiments.tbptt_optimizer import TBPTTOptimizer
    from sandbox.rocky.tf.policies.categorical_gru_policy import CategoricalGRUPolicy
    from assistive_bandits.experiments.pposgd_clip_ratio import PPOSGD

    if not local_test and force_remote:
        import rl_algs.logger as rl_algs_logger
        log_dir = rl_algs_logger.get_dir()

    if log_dir is not None:
        log_dir = osp.join(log_dir, str(v["n_episodes"]))
        log_dir = osp.join(log_dir, v["human_policy"])

    text_output_file = None if log_dir is None else osp.join(log_dir, "text")
    tabular_output_file = None if log_dir is None else osp.join(
        log_dir, "train_table.csv")
    info_theory_tabular_output = None if log_dir is None else osp.join(
        log_dir, "info_table.csv")
    rag = uniform_bernoulli_iterator()
    bandit = BanditEnv(n_arms=v["n_arms"],
                       reward_dist=bernoulli,
                       reward_args_generator=rag,
                       horizon=v["n_episodes"])
    pi_H = human_policy_dict[v["human_policy"]](bandit)

    h_wrapper = human_wrapper_dict[v["human_wrapper"]]

    env = h_wrapper(bandit, pi_H, penalty=v["intervention_penalty"])

    if text_output_file is not None:
        logger.add_text_output(text_output_file)
        logger.add_tabular_output(tabular_output_file)

    logger.log("Training against {}".format(v["human_policy"]))
    logger.log("Setting seed to {}".format(v["seed"]))
    env.seed(v["seed"])

    baseline = L2RNNBaseline(
        name="vf",
        env_spec=env.spec,
        log_loss_before=False,
        log_loss_after=False,
        hidden_nonlinearity=getattr(tf.nn, v["nonlinearity"]),
        weight_normalization=v["weight_normalization"],
        layer_normalization=v["layer_normalization"],
        state_include_action=False,
        hidden_dim=v["hidden_dim"],
        optimizer=TBPTTOptimizer(
            batch_size=v["opt_batch_size"],
            n_steps=v["opt_n_steps"],
            n_epochs=v["min_epochs"],
        ),
        batch_size=v["opt_batch_size"],
        n_steps=v["opt_n_steps"],
    )
    policy = CategoricalGRUPolicy(env_spec=env.spec,
                                  hidden_nonlinearity=getattr(
                                      tf.nn, v["nonlinearity"]),
                                  hidden_dim=v["hidden_dim"],
                                  state_include_action=True,
                                  name="policy")

    n_itr = 3 if local_test else 100
    # logger.log('sampler_args {}'.format(dict(n_envs=max(1, min(int(np.ceil(v["batch_size"] / v["n_episodes"])), 100)))))

    # parallel_sampler.initialize(6)
    # parallel_sampler.set_seed(v["seed"])

    algo = PPOSGD(
        env=env,
        policy=policy,
        baseline=baseline,
        batch_size=v["batch_size"],
        max_path_length=v["n_episodes"],
        # 43.65 env time
        sampler_args=dict(n_envs=max(
            1, min(int(np.ceil(v["batch_size"] / v["n_episodes"])), 100))),
        # 100 threads -> 1:36 to sample 187.275
        # 6 threads -> 1:31
        # force_batch_sampler=True,
        n_itr=n_itr,
        step_size=v["mean_kl"],
        clip_lr=v["clip_lr"],
        log_loss_kl_before=False,
        log_loss_kl_after=False,
        use_kl_penalty=v["use_kl_penalty"],
        min_n_epochs=v["min_epochs"],
        entropy_bonus_coeff=v["entropy_bonus_coeff"],
        optimizer=TBPTTOptimizer(
            batch_size=v["opt_batch_size"],
            n_steps=v["opt_n_steps"],
            n_epochs=v["min_epochs"],
        ),
        discount=v["discount"],
        gae_lambda=v["gae_lambda"],
        use_line_search=True
        # scope=ename
    )

    sess = tf.Session()
    with sess.as_default():
        algo.train(sess)

        if text_output_file is not None:
            logger.remove_tabular_output(tabular_output_file)
            logger.add_tabular_output(info_theory_tabular_output)

        # Now gather statistics for t-tests and such!
        for human_policy_name, human_policy in human_policy_dict.items():

            logger.log("-------------------")
            logger.log("Evaluating against {}".format(human_policy.__name__))
            logger.log("-------------------")

            logger.log("Obtaining Samples...")
            test_pi_H = human_policy(bandit)
            test_env = h_wrapper(bandit, test_pi_H, penalty=0.)
            eval_sampler = VectorizedSampler(algo, n_envs=100)
            algo.batch_size = v["num_eval_traj"] * v["n_episodes"]
            algo.env = test_env
            logger.log("algo.env.pi_H has class: {}".format(
                algo.env.pi_H.__class__))
            eval_sampler.start_worker()
            paths = eval_sampler.obtain_samples(-1)
            eval_sampler.shutdown_worker()
            rewards = []

            H_act_seqs = []
            R_act_seqs = []
            best_arms = []
            optimal_a_seqs = []
            for p in paths:
                a_Rs = env.action_space.unflatten_n(p['actions'])
                obs_R = env.observation_space.unflatten_n(p['observations'])
                best_arm = np.argmax(p['env_infos']['arm_means'][0])

                H_act_seqs.append(obs_R[:, 1])
                R_act_seqs.append(a_Rs)
                best_arms.append(best_arm)
                optimal_a_seqs.append(
                    [best_arm for _ in range(v["n_episodes"])])

                rewards.append(np.sum(p['rewards']))

            #feel free to add more data
            logger.log("NumTrajs {}".format(v["num_eval_traj"]))
            logger.log("AverageReturn {}".format(np.mean(rewards)))
            logger.log("StdReturn {}".format(np.std(rewards)))
            logger.log("MaxReturn {}".format(np.max(rewards)))
            logger.log("MinReturn {}".format(np.min(rewards)))

            optimal_a_H_freqs = _frequency_agreement(H_act_seqs,
                                                     optimal_a_seqs)
            optimal_a_R_freqs = _frequency_agreement(R_act_seqs,
                                                     optimal_a_seqs)

            for t in range(v["n_episodes"]):
                logger.record_tabular("PolicyExecTime", 0)
                logger.record_tabular("EnvExecTime", 0)
                logger.record_tabular("ProcessExecTime", 0)
                logger.record_tabular("Tested Against", human_policy_name)
                logger.record_tabular("t", t)
                logger.record_tabular("a_H_agreement", optimal_a_H_freqs[t])
                logger.record_tabular("a_R_agreement", optimal_a_R_freqs[t])

                H_act_seqs_truncated = [a_Hs[0:t] for a_Hs in H_act_seqs]
                R_act_seqs_truncated = [a_Rs[0:t] for a_Rs in R_act_seqs]
                h_mutual_info = _mutual_info_seqs(H_act_seqs_truncated,
                                                  best_arms, v["n_arms"] + 1)
                r_mutual_info = _mutual_info_seqs(R_act_seqs_truncated,
                                                  best_arms, v["n_arms"] + 1)
                logger.record_tabular("h_mutual_info", h_mutual_info)
                logger.record_tabular("r_mutual_info", r_mutual_info)
                logger.record_tabular("a_H_opt_freq", optimal_a_H_freqs[t])
                logger.record_tabular("a_R_opt_freq", optimal_a_R_freqs[t])
                logger.dump_tabular()

            test_env = h_wrapper(bandit, human_policy(bandit), penalty=0.)

            logger.log("Printing Example Trajectories")
            for i in range(v["num_display_traj"]):
                observation = test_env.reset()
                policy.reset()
                logger.log("-- Trajectory {} of {}".format(
                    i + 1, v["num_display_traj"]))
                logger.log("t \t obs \t act \t reward \t act_probs")
                for t in range(v["n_episodes"]):
                    action, act_info = policy.get_action(observation)
                    new_obs, reward, done, info = test_env.step(action)
                    logger.log("{} \t {} \t {} \t {} \t {}".format(
                        t, observation, action, reward, act_info['prob']))
                    observation = new_obs
                    if done:
                        logger.log("Total reward: {}".format(
                            info["accumulated rewards"]))
                        break

    if text_output_file is not None:
        logger.remove_text_output(text_output_file)
        logger.remove_tabular_output(info_theory_tabular_output)
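
Finally, a minimal sketch of a variant dict v covering the keys run_task reads. Every value below is an illustrative assumption, and the lookup tables (human_policy_dict, human_wrapper_dict) plus the bandit helpers are expected to come from the surrounding assistive_bandits experiment module.

# Hypothetical variant for a short local run; keys mirror the v[...] lookups
# in run_task above, values are placeholders.
example_variant = {
    'seed': 0,
    'n_arms': 4,
    'n_episodes': 50,
    'human_policy': 'eps_greedy',       # must be a key of human_policy_dict
    'human_wrapper': 'shared_control',  # must be a key of human_wrapper_dict
    'intervention_penalty': 0.0,
    'nonlinearity': 'relu',             # resolved via getattr(tf.nn, ...)
    'weight_normalization': False,
    'layer_normalization': False,
    'hidden_dim': 64,
    'opt_batch_size': 32,
    'opt_n_steps': 20,
    'min_epochs': 2,
    'batch_size': 5000,
    'mean_kl': 0.01,
    'clip_lr': 0.2,
    'use_kl_penalty': False,
    'entropy_bonus_coeff': 0.0,
    'discount': 0.99,
    'gae_lambda': 1.0,
    'num_eval_traj': 100,
    'num_display_traj': 2,
}

run_task(example_variant, log_dir='./data', ename='bandit_smoke_test')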