예제 #1
0
파일: rurllab.py 프로젝트: zeyuan1987/MADRL
def rllab_envpolicy_parser(env, args):
    """Wrap ``env`` for rllab and build the policy that ``args`` describes.

    The backend is chosen by the first two characters of ``args.algo``:
    ``'tf'`` additionally wraps the environment in ``TfEnv`` and builds a
    TensorFlow policy; ``'th'`` builds a Theano policy on the ``RLLabEnv``
    wrapper.

    Args:
        env: Environment to wrap with ``RLLabEnv(env, mode=args.control)``.
        args: A namedtuple-like object, or a dict converted with
            ``tonamedtuple``. Fields read include ``control``, ``algo``,
            ``recurrent``, ``feature_net``, ``conv``, ``policy_hidden`` and
            ``min_std``, plus the ``feature_*`` / ``conv_*`` fields used by
            the feature and convolutional networks.

    Returns:
        ``(env, policy)``: the wrapped environment and its policy.

    Raises:
        NotImplementedError: if ``args.control == 'concurrent'`` (only a
            single policy is built here, not one per agent), if
            ``args.algo`` starts with neither ``'tf'`` nor ``'th'``, or if the
            recurrent cell type or action space is unsupported.
        AssertionError: if the conv settings differ in length or a conv
            network is requested for a non-``Discrete`` action space.
    """
    if isinstance(args, dict):
        args = tonamedtuple(args)

    # Concurrent control expects one policy per agent, which this parser does
    # not build; reject the mode up front rather than building a graph and
    # then failing on an undefined name.
    if args.control == 'concurrent':
        raise NotImplementedError(
            "control='concurrent' needs one policy per agent, which "
            "rllab_envpolicy_parser does not build")

    env = RLLabEnv(env, mode=args.control)
    backend = args.algo[:2]
    if backend == 'tf':
        env = TfEnv(env)
        policy = _build_tf_policy(env, args)
    elif backend == 'th':
        policy = _build_th_policy(env, args)
    else:
        raise NotImplementedError(args.algo)
    return env, policy


def _conv_layer_spec(env, args):
    """Return validated ``(strides, channels, filter_sizes)`` tuples from ``args``."""
    strides = tuple(args.conv_strides)
    chans = tuple(args.conv_channels)
    filts = tuple(args.conv_filters)

    assert len(strides) == len(chans) == len(
        filts), "strides, chans and filts not equal"
    # only discrete actions supported, should be straightforward to extend to continuous
    assert isinstance(
        env.spec.action_space,
        Discrete), "Only discrete action spaces support conv"
    return strides, chans, filts


def _build_tf_policy(env, args):
    """Build the TensorFlow policy for the ``TfEnv``-wrapped ``env``."""
    if args.recurrent:
        return _build_tf_recurrent_policy(env, args)

    if args.conv:
        strides, chans, filts = _conv_layer_spec(env, args)
        # The conv network is the whole probability network: its softmax
        # output has one unit per discrete action.
        prob_network = ConvNetwork(
            name='feature_net',
            input_shape=env.spec.observation_space.shape,
            output_dim=env.spec.action_space.n,
            conv_filters=chans,
            conv_filter_sizes=filts,
            conv_strides=strides,
            conv_pads=('VALID', ) * len(chans),
            hidden_sizes=tuple(args.policy_hidden),
            hidden_nonlinearity=tf.nn.relu,
            output_nonlinearity=tf.nn.softmax)
        return CategoricalMLPPolicy(name='policy',
                                    env_spec=env.spec,
                                    prob_network=prob_network)

    action_space = env.spec.action_space
    if isinstance(action_space, Box):
        return GaussianMLPPolicy(env_spec=env.spec,
                                 hidden_sizes=tuple(args.policy_hidden),
                                 min_std=args.min_std,
                                 name='policy')
    if isinstance(action_space, Discrete):
        return CategoricalMLPPolicy(env_spec=env.spec,
                                    hidden_sizes=tuple(args.policy_hidden),
                                    name='policy')
    raise NotImplementedError(action_space)


def _build_tf_recurrent_policy(env, args):
    """Build a TensorFlow GRU/LSTM policy with an optional feature network."""
    if args.feature_net:
        # The MLP input is the flattened observation concatenated with the
        # flattened action.
        feature_network = MLP(
            name='feature_net',
            input_shape=(env.spec.observation_space.flat_dim +
                         env.spec.action_space.flat_dim, ),
            output_dim=args.feature_output,
            hidden_sizes=tuple(args.feature_hidden),
            hidden_nonlinearity=tf.nn.tanh,
            output_nonlinearity=None)
    elif args.conv:
        strides, chans, filts = _conv_layer_spec(env, args)
        feature_network = ConvNetwork(
            name='feature_net',
            input_shape=env.spec.observation_space.shape,
            output_dim=args.feature_output,
            conv_filters=chans,
            conv_filter_sizes=filts,
            conv_strides=strides,
            conv_pads=('VALID', ) * len(chans),
            hidden_sizes=tuple(args.feature_hidden),
            hidden_nonlinearity=tf.nn.relu,
            output_nonlinearity=None)
    else:
        feature_network = None

    action_space = env.spec.action_space
    if args.recurrent == 'gru':
        hidden_dim = int(args.policy_hidden[0])
        if isinstance(action_space, Box):
            return GaussianGRUPolicy(env_spec=env.spec,
                                     feature_network=feature_network,
                                     hidden_dim=hidden_dim,
                                     name='policy')
        if isinstance(action_space, Discrete):
            return CategoricalGRUPolicy(
                env_spec=env.spec,
                feature_network=feature_network,
                hidden_dim=hidden_dim,
                name='policy',
                state_include_action=not args.conv)
        raise NotImplementedError(action_space)

    if args.recurrent == 'lstm':
        # policy_hidden is a sequence (see the tuple(...) uses above); its
        # first entry is the cell size, exactly as in the GRU branch.
        hidden_dim = int(args.policy_hidden[0])
        if isinstance(action_space, Box):
            return GaussianLSTMPolicy(env_spec=env.spec,
                                      feature_network=feature_network,
                                      hidden_dim=hidden_dim,
                                      name='policy')
        if isinstance(action_space, Discrete):
            return CategoricalLSTMPolicy(env_spec=env.spec,
                                         feature_network=feature_network,
                                         hidden_dim=hidden_dim,
                                         name='policy')
        raise NotImplementedError(action_space)

    raise NotImplementedError(args.recurrent)


def _build_th_policy(env, args):
    """Build the Theano policy for the ``RLLabEnv``-wrapped ``env``."""
    action_space = env.spec.action_space
    if args.recurrent:
        if args.feature_net:
            # NOTE(review): tf.nn.tanh is a TensorFlow op; a Theano MLP most
            # likely needs a Theano/Lasagne nonlinearity here -- confirm.
            feature_network = thMLP(
                input_shape=(env.spec.observation_space.flat_dim +
                             action_space.flat_dim, ),
                output_dim=args.feature_output,
                hidden_sizes=tuple(args.feature_hidden),
                hidden_nonlinearity=tf.nn.tanh,
                output_nonlinearity=None)
        else:
            feature_network = None

        # Only GRU cells are wired up for the Theano backend.
        if args.recurrent != 'gru':
            raise NotImplementedError(args.recurrent)

        hidden_dim = int(args.policy_hidden[0])
        # The policy family follows the *action* space: Gaussian for
        # continuous actions, categorical for discrete ones.
        if isinstance(action_space, thBox):
            return thGaussianGRUPolicy(
                env_spec=env.spec,
                feature_network=feature_network,
                hidden_dim=hidden_dim,
            )
        if isinstance(action_space, thDiscrete):
            return thCategoricalGRUPolicy(
                env_spec=env.spec,
                feature_network=feature_network,
                hidden_dim=hidden_dim,
            )
        raise NotImplementedError(action_space)

    if args.algo == 'thddpg':
        assert isinstance(action_space, thBox)
        return thDeterministicMLPPolicy(
            env_spec=env.spec,
            hidden_sizes=tuple(args.policy_hidden),
        )
    if isinstance(action_space, thBox):
        return thGaussianMLPPolicy(env_spec=env.spec,
                                   hidden_sizes=tuple(args.policy_hidden),
                                   min_std=args.min_std)
    if isinstance(action_space, thDiscrete):
        # CategoricalMLPPolicy takes no ``min_std`` argument (there is no
        # standard deviation for a categorical distribution).
        return thCategoricalMLPPolicy(env_spec=env.spec,
                                      hidden_sizes=tuple(args.policy_hidden))
    raise NotImplementedError(action_space)
예제 #2
0
def get_policy(env, algo_name, info, policy_hidden_sizes,
               policy_hidden_nonlinearity, policy_output_nonlinearity,
               recurrent, **kwargs):
    """Create the policy network used by ``algo_name`` on ``env``.

    Args:
        env: Environment whose ``spec`` the policy is built for.
        algo_name: Algorithm key. The stochastic-policy algorithms listed
            below get a Gaussian (continuous actions) or categorical
            (discrete actions) policy; ``'ddpg'`` gets a deterministic MLP.
        info: Mapping that provides ``'is_action_discrete'``.
        policy_hidden_sizes: Converted with ``get_hidden_sizes``.
        policy_hidden_nonlinearity: Converted with ``get_nonlinearity`` for
            the hidden layers.
        policy_output_nonlinearity: Converted with ``get_nonlinearity`` for
            the output layer.
        recurrent: When truthy, stochastic algorithms get an LSTM policy.
        **kwargs: Accepted and ignored.

    Returns:
        The policy instance, or ``None`` if ``algo_name`` is not recognised.
        The two summary lines are printed in either case.
    """
    stochastic_algos = (
        'trpo', 'actrpo', 'acqftrpo', 'qprop', 'mqprop', 'qfqprop', 'trpg',
        'trpgoff', 'nuqprop', 'nuqfqprop', 'nafqprop', 'vpg', 'qvpg', 'dspg',
        'dspgoff',
    )
    sizes = get_hidden_sizes(policy_hidden_sizes)
    hidden_nl = get_nonlinearity(policy_hidden_nonlinearity)
    output_nl = get_nonlinearity(policy_output_nonlinearity)

    policy = policy_class = None
    if algo_name in stochastic_algos:
        discrete = info['is_action_discrete']
        if discrete and recurrent:
            policy = CategoricalLSTMPolicy(
                name="cat_lstm_policy",
                env_spec=env.spec,
                lstm_layer_cls=L.TfBasicLSTMLayer,
            )
            policy_class = 'CategoricalLSTMPolicy'
        elif discrete:
            policy = CategoricalMLPPolicy(
                name="cat_policy",
                env_spec=env.spec,
                hidden_sizes=sizes,
                hidden_nonlinearity=hidden_nl,
            )
            policy_class = 'CategoricalMLPPolicy'
        elif recurrent:
            policy = GaussianLSTMPolicy(
                name="gauss_lstm_policy",
                env_spec=env.spec,
                lstm_layer_cls=L.TfBasicLSTMLayer,
                output_nonlinearity=output_nl,
            )
            policy_class = 'GaussianLSTMPolicy'
        else:
            policy = GaussianMLPPolicy(
                name="gauss_policy",
                env_spec=env.spec,
                hidden_sizes=sizes,
                hidden_nonlinearity=hidden_nl,
                output_nonlinearity=output_nl,
            )
            policy_class = 'GaussianMLPPolicy'
    elif algo_name == 'ddpg':
        # DDPG's actor is deterministic and only defined for continuous actions.
        assert not info['is_action_discrete']
        policy = DeterministicMLPPolicy(
            name="det_policy",
            env_spec=env.spec,
            hidden_sizes=sizes,
            hidden_nonlinearity=hidden_nl,
            output_nonlinearity=output_nl,
        )
        policy_class = 'DeterministicMLPPolicy'

    summary = ('[get_policy] Instantiating %s, with sizes=%s, '
               'hidden_nonlinearity=%s.')
    print(summary % (policy_class, str(sizes), policy_hidden_nonlinearity))
    print('[get_policy] output_nonlinearity=%s.' % policy_output_nonlinearity)
    return policy
예제 #3
0
    def parse_env_args(self, env, args):
        """Wrap ``env`` for multi-agent TensorFlow training and build its policy.

        Args:
            env: Environment to wrap as
                ``MATfEnv(RLLabEnv(env, ma_mode=args.control))``.
            args: A namedtuple-like object, or a dict converted with
                ``to_named_tuple``.

        Returns:
            ``(env, policy)`` where ``policy`` is one of: a single policy; a
            dict with ``'q_network'`` and ``'target_q_network'`` entries (for
            discrete-action recurrent policies, and for conv policies when
            ``args.algo == 'dqn'``); or a list of ``DQNNetwork`` objects, one
            per index in ``range(env.n)``, when ``env.spec`` is ``None``.

        Raises:
            NotImplementedError: for an unsupported ``args.recurrent`` value
                or action space.
            AssertionError: if the conv settings differ in length or a conv
                network is requested for a non-``Discrete`` action space.
        """
        if isinstance(args, dict):
            args = to_named_tuple(args)

        # Multi-agent wrapper
        env = RLLabEnv(env, ma_mode=args.control)
        env = MATfEnv(env)
        concurrent = args.control == 'concurrent'

        def conv_spec():
            """Return validated ``(strides, channels, filter_sizes)`` tuples."""
            strides = tuple(args.conv_strides)
            chans = tuple(args.conv_channels)
            filts = tuple(args.conv_filters)
            assert len(strides) == len(chans) == len(
                filts), "strides, chans and filts not equal"
            # only discrete actions supported, should be straightforward to
            # extend to continuous
            assert isinstance(
                env.spec.action_space,
                Discrete), "Only discrete action spaces support conv"
            return strides, chans, filts

        # Policy
        if args.recurrent:
            if args.feature_net:
                feature_network = MLP(
                    name='feature_net',
                    input_shape=(env.spec.observation_space.flat_dim +
                                 env.spec.action_space.flat_dim, ),
                    output_dim=args.feature_output,
                    hidden_sizes=tuple(args.feature_hidden),
                    hidden_nonlinearity=tf.nn.tanh,
                    output_nonlinearity=None)
            elif args.conv:
                strides, chans, filts = conv_spec()
                feature_network = ConvNetwork(
                    name='feature_net',
                    input_shape=env.spec.observation_space.shape,
                    output_dim=args.feature_output,
                    conv_filters=chans,
                    conv_filter_sizes=filts,
                    conv_strides=strides,
                    conv_pads=('VALID', ) * len(chans),
                    hidden_sizes=tuple(args.feature_hidden),
                    hidden_nonlinearity=tf.nn.relu,
                    output_nonlinearity=None)
            else:
                feature_network = None

            # Pick the policy class from the cell type and the action space.
            # Discrete-action policies are built as an online/target
            # Q-network pair.
            action_space = env.spec.action_space
            if args.recurrent == 'gru':
                if isinstance(action_space, Box):
                    policy_cls, extra, as_q_pair = GaussianGRUPolicy, {}, False
                elif isinstance(action_space, Discrete):
                    policy_cls, as_q_pair = CategoricalGRUPolicy, True
                    extra = {'state_include_action': not args.conv}
                else:
                    raise NotImplementedError(action_space)
            elif args.recurrent == 'lstm':
                if isinstance(action_space, Box):
                    policy_cls, extra, as_q_pair = GaussianLSTMPolicy, {}, False
                elif isinstance(action_space, Discrete):
                    policy_cls, extra, as_q_pair = CategoricalLSTMPolicy, {}, True
                else:
                    raise NotImplementedError(action_space)
            else:
                raise NotImplementedError(args.recurrent)

            # policy_hidden is a sequence (see the tuple(...) uses); its first
            # entry is the cell size for both GRU and LSTM cells.
            hidden_dim = int(args.policy_hidden[0])

            def make_policy(name):
                """Instantiate one recurrent policy with the shared settings."""
                return policy_cls(env_spec=env.spec,
                                  feature_network=feature_network,
                                  hidden_dim=hidden_dim,
                                  name=name,
                                  **extra)

            if concurrent:
                # NOTE(review): these per-agent policies are created (adding
                # their variables to the graph) but never returned; callers
                # only receive ``policy`` below -- confirm intent.
                policies = [
                    make_policy('policy_{}'.format(agid))
                    for agid in range(len(env.agents))
                ]
            if as_q_pair:
                policy = {
                    'q_network': make_policy('q_network'),
                    'target_q_network': make_policy('target_q_network')
                }
            else:
                policy = make_policy('policy')
        elif args.conv:
            strides, chans, filts = conv_spec()
            feature_network = ConvNetwork(
                name='feature_net',
                input_shape=env.spec.observation_space.shape,
                output_dim=env.spec.action_space.n,
                conv_filters=chans,
                conv_filter_sizes=filts,
                conv_strides=strides,
                conv_pads=(args.conv_pads, ) * len(chans),
                hidden_sizes=tuple(args.policy_hidden),
                hidden_nonlinearity=tf.nn.relu,
                output_nonlinearity=tf.nn.softmax,
                batch_normalization=args.batch_normalization)
            if args.algo == 'dqn':
                # NOTE(review): both policies wrap the same prob_network object,
                # so the target network shares the online network's weights;
                # a separate ConvNetwork for the target is probably intended.
                policy = {
                    'q_network':
                    CategoricalMLPPolicy(name='q_network',
                                         env_spec=env.spec,
                                         prob_network=feature_network),
                    'target_q_network':
                    CategoricalMLPPolicy(name='target_q_network',
                                         env_spec=env.spec,
                                         prob_network=feature_network)
                }
            else:
                policy = CategoricalMLPPolicy(name='policy',
                                              env_spec=env.spec,
                                              prob_network=feature_network)
        elif env.spec is None:
            # One DQN network per agent index.
            # NOTE(review): hyper-parameters come from ``self.args`` rather
            # than the ``args`` passed in -- confirm this is intended.
            policy = [
                DQNNetwork(i,
                           env,
                           target_network_update_freq=self.args.
                           target_network_update,
                           discount_factor=self.args.discount,
                           batch_size=self.args.batch_size,
                           learning_rate=self.args.qfunc_lr)
                for i in range(env.n)
            ]
        elif isinstance(env.spec.action_space, Box):
            policy = GaussianMLPPolicy(env_spec=env.spec,
                                       hidden_sizes=tuple(args.policy_hidden),
                                       min_std=args.min_std,
                                       name='policy')
        elif isinstance(env.spec.action_space, Discrete):
            policy = CategoricalMLPPolicy(env_spec=env.spec,
                                          hidden_sizes=tuple(
                                              args.policy_hidden),
                                          name='policy')
        else:
            raise NotImplementedError(env.spec.action_space)

        return env, policy
예제 #4
0
def main():
    """Train a TRPO policy on PursuitEvade using command-line settings.

    Parses the CLI flags, builds the environment and policy (or loads both
    from ``--checkpoint``), configures the rllab logger under ``--log_dir``
    (default ``config.LOG_DIR/<exp_name>``) and runs TRPO. The logger's
    previous snapshot dir/mode are restored and the outputs added here are
    removed once training finishes or fails.
    """
    args = _parse_args()

    parallel_sampler.initialize(n_parallel=args.sampler_workers)

    if args.seed is not None:
        set_seed(args.seed)
        parallel_sampler.set_seed(args.seed)

    args.hidden_sizes = tuple(map(int, args.policy_hidden_sizes.split(',')))

    if args.checkpoint:
        # joblib restores the policy's parameters, which needs an active session.
        with tf.Session():
            data = joblib.load(args.checkpoint)
            policy = data['policy']
            env = data['env']
    else:
        env, policy = _make_env_and_policy(args)

    if args.baseline_type == 'linear':
        baseline = LinearFeatureBaseline(env_spec=env.spec)
    else:
        baseline = ZeroBaseline(env_spec=env.spec)

    # logger
    log_dir = args.log_dir
    if log_dir is None:
        log_dir = osp.join(config.LOG_DIR, args.exp_name)
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    logger.log_parameters_lite(params_log_file, args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    try:
        algo = TRPO(
            env=env,
            policy=policy,
            baseline=baseline,
            batch_size=args.n_timesteps,
            max_path_length=args.max_traj_len,
            n_itr=args.n_iter,
            discount=args.discount,
            gae_lambda=args.gae_lambda,
            step_size=args.max_kl,
            # Recurrent policies use finite-difference Hessian-vector products.
            optimizer=ConjugateGradientOptimizer(
                hvp_approach=FiniteDifferenceHvp(base_eps=1e-5))
            if args.recurrent else None,
            mode=args.control,
        )
        algo.train()
    finally:
        # Undo the logger configuration made above.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()


def _parse_args():
    """Build the command-line parser used by ``main`` and parse ``sys.argv``."""
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')
    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)

    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_name',
                        type=str,
                        default=default_exp_name,
                        help='Name of the experiment.')

    parser.add_argument('--discount', type=float, default=0.99)
    parser.add_argument('--gae_lambda', type=float, default=1.0)
    parser.add_argument('--reward_scale', type=float, default=1.0)

    parser.add_argument('--n_iter', type=int, default=250)
    parser.add_argument('--sampler_workers', type=int, default=1)
    parser.add_argument('--max_traj_len', type=int, default=250)
    parser.add_argument('--update_curriculum',
                        action='store_true',
                        default=False)
    parser.add_argument('--n_timesteps', type=int, default=8000)
    parser.add_argument('--control', type=str, default='centralized')

    parser.add_argument('--rectangle', type=str, default='10,10')
    parser.add_argument('--map_type', type=str, default='rectangle')
    parser.add_argument('--n_evaders', type=int, default=5)
    parser.add_argument('--n_pursuers', type=int, default=2)
    parser.add_argument('--obs_range', type=int, default=3)
    parser.add_argument('--n_catch', type=int, default=2)
    parser.add_argument('--urgency', type=float, default=0.0)
    parser.add_argument('--pursuit', dest='train_pursuit', action='store_true')
    parser.add_argument('--evade', dest='train_pursuit', action='store_false')
    parser.set_defaults(train_pursuit=True)
    parser.add_argument('--surround', action='store_true', default=False)
    parser.add_argument('--constraint_window', type=float, default=1.0)
    parser.add_argument('--sample_maps', action='store_true', default=False)
    parser.add_argument('--map_file', type=str, default='../maps/map_pool.npy')
    parser.add_argument('--flatten', action='store_true', default=False)
    parser.add_argument('--reward_mech', type=str, default='global')
    parser.add_argument('--catchr', type=float, default=0.1)
    parser.add_argument('--term_pursuit', type=float, default=5.0)

    parser.add_argument('--recurrent', type=str, default=None)
    parser.add_argument('--policy_hidden_sizes', type=str, default='128,128')
    # The misspelled flag (and its ``baselin_hidden_sizes`` dest) is kept for
    # backward compatibility; the correctly spelled flag is an alias.
    parser.add_argument('--baselin_hidden_sizes',
                        '--baseline_hidden_sizes',
                        type=str,
                        default='128,128')
    parser.add_argument('--baseline_type', type=str, default='linear')

    parser.add_argument('--conv', action='store_true', default=False)

    parser.add_argument('--max_kl', type=float, default=0.01)

    parser.add_argument('--checkpoint', type=str, default=None)

    parser.add_argument('--log_dir', type=str, required=False)
    parser.add_argument('--tabular_log_file',
                        type=str,
                        default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file',
                        type=str,
                        default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file',
                        type=str,
                        default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data',
                        type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--snapshot_mode',
                        type=str,
                        default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                        '(all iterations will be saved), "last" (only '
                        'the last iteration will be saved), or "none" '
                        '(do not save snapshots)')
    parser.add_argument(
        '--log_tabular_only',
        type=ast.literal_eval,
        default=False,
        help=
        'Whether to only print the tabular log information (in a horizontal format)'
    )

    return parser.parse_args()


def _make_env_and_policy(args):
    """Build the wrapped PursuitEvade env and a fresh TF policy from ``args``.

    Expects ``args.hidden_sizes`` to have been parsed by ``main``.

    Raises:
        NotImplementedError: for an unknown ``--map_type`` or ``--recurrent``.
    """
    if args.sample_maps:
        map_pool = np.load(args.map_file)
    else:
        if args.map_type == 'rectangle':
            env_map = TwoDMaps.rectangle_map(
                *map(int, args.rectangle.split(',')))
        elif args.map_type == 'complex':
            env_map = TwoDMaps.complex_map(
                *map(int, args.rectangle.split(',')))
        else:
            raise NotImplementedError()
        map_pool = [env_map]

    env = PursuitEvade(map_pool,
                       n_evaders=args.n_evaders,
                       n_pursuers=args.n_pursuers,
                       obs_range=args.obs_range,
                       n_catch=args.n_catch,
                       train_pursuit=args.train_pursuit,
                       urgency_reward=args.urgency,
                       surround=args.surround,
                       sample_maps=args.sample_maps,
                       constraint_window=args.constraint_window,
                       flatten=args.flatten,
                       reward_mech=args.reward_mech,
                       catchr=args.catchr,
                       term_pursuit=args.term_pursuit)

    env = TfEnv(
        RLLabEnv(StandardizedEnv(env,
                                 scale_reward=args.reward_scale,
                                 enable_obsnorm=False),
                 mode=args.control))

    if args.recurrent:
        if args.conv:
            feature_network = ConvNetwork(
                name='feature_net',
                input_shape=env.spec.observation_space.shape,
                output_dim=5,
                conv_filters=(16, 32, 32),
                conv_filter_sizes=(3, 3, 3),
                conv_strides=(1, 1, 1),
                conv_pads=('VALID', 'VALID', 'VALID'),
                hidden_sizes=(64, ),
                hidden_nonlinearity=tf.nn.relu,
                output_nonlinearity=tf.nn.softmax)
        else:
            feature_network = MLP(
                name='feature_net',
                input_shape=(env.spec.observation_space.flat_dim +
                             env.spec.action_space.flat_dim, ),
                output_dim=5,
                hidden_sizes=(256, 128, 64),
                hidden_nonlinearity=tf.nn.tanh,
                output_nonlinearity=None)
        # The recurrent cell size is the first --policy_hidden_sizes entry;
        # int() of the raw string fails for lists such as the default '128,128'.
        hidden_dim = args.hidden_sizes[0]
        if args.recurrent == 'gru':
            policy = CategoricalGRUPolicy(env_spec=env.spec,
                                          feature_network=feature_network,
                                          hidden_dim=hidden_dim,
                                          name='policy')
        elif args.recurrent == 'lstm':
            policy = CategoricalLSTMPolicy(env_spec=env.spec,
                                           feature_network=feature_network,
                                           hidden_dim=hidden_dim,
                                           name='policy')
        else:
            raise NotImplementedError(args.recurrent)
    elif args.conv:
        feature_network = ConvNetwork(
            name='feature_net',
            input_shape=env.spec.observation_space.shape,
            output_dim=5,
            conv_filters=(8, 16),
            conv_filter_sizes=(3, 3),
            conv_strides=(2, 1),
            conv_pads=('VALID', 'VALID'),
            hidden_sizes=(32, ),
            hidden_nonlinearity=tf.nn.relu,
            output_nonlinearity=tf.nn.softmax)
        policy = CategoricalMLPPolicy(name='policy',
                                      env_spec=env.spec,
                                      prob_network=feature_network)
    else:
        policy = CategoricalMLPPolicy(name='policy',
                                      env_spec=env.spec,
                                      hidden_sizes=args.hidden_sizes)

    return env, policy
예제 #5
0
# Route text and tabular log output to this run's log files.
logger.add_text_output(text_log_file)
logger.add_tabular_output(tabular_log_file)
# Remember the logger's previous snapshot settings.
# NOTE(review): they are not restored in the visible part of this script.
prev_snapshot_dir = logger.get_snapshot_dir()
prev_mode = logger.get_snapshot_mode()
# Write snapshots to log_dir in "gaplast" mode with a gap of 1000 iterations
# (presumably every 1000th iteration plus the latest one -- confirm against
# the logger implementation).
logger.set_snapshot_dir(log_dir)
logger.set_snapshot_mode("gaplast")
logger.set_snapshot_gap(1000)
logger.set_log_tabular_only(False)
# Prefix log lines with "[<game_name>_<mask_num>] ".
logger.push_prefix("[%s] " % (game_name + '_' + str(mask_num)))

# Single-process sampler seeded with 0.
from Algo import parallel_sampler
parallel_sampler.initialize(n_parallel=1)
parallel_sampler.set_seed(0)

# Categorical LSTM policy built from env's spec.
policy = CategoricalLSTMPolicy(
    env_spec=env.spec,
    name="lstm",
)

# Linear feature baseline built from env's spec.
baseline = LinearFeatureBaseline(env_spec=env.spec)

with tf.Session() as sess:

    # writer = tf.summary.FileWriter(logdir=log_dir,)

    algo = VPG_t(
        env=env,
        policy=policy,
        baseline=baseline,
        batch_size=2048,  #2*env._wrapped_env.params['traj_limit'],
        max_path_length=200,
        n_itr=10000,