Example No. 1
def main():
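    """Load a trained snapshot, rebuild the PursuitEvade env from its stored
    training args, and either evaluate the policy or render it to video."""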
    parser = argparse.ArgumentParser()
    parser.add_argument('filename',
                        type=str)  # defaultIS.h5/snapshots/iter0000480
    parser.add_argument('--vid', type=str, default='/tmp/madrl.mp4')
    parser.add_argument('--deterministic', action='store_true', default=False)
    parser.add_argument('--heuristic', action='store_true', default=False)
    parser.add_argument('--evaluate', action='store_true', default=False)
    parser.add_argument('--n_trajs', type=int, default=20)
    parser.add_argument('--n_steps', type=int, default=500)
    parser.add_argument('--same_con_pol', action='store_true')
    args = parser.parse_args()

    fh = FileHandler(args.filename)
    env_map = TwoDMaps.rectangle_map(
        *map(int, fh.train_args['map_size'].split(',')))
    map_pool = [env_map]
    # map_pool = np.load(
    # os.path.join('/scratch/megorov/deeprl/MADRL/runners/maps/', os.path.basename(fh.train_args[
    # 'map_file'])))

    env = PursuitEvade(map_pool,
                       n_evaders=fh.train_args['n_evaders'],
                       n_pursuers=fh.train_args['n_pursuers'],
                       obs_range=fh.train_args['obs_range'],
                       n_catch=fh.train_args['n_catch'],
                       urgency_reward=fh.train_args['urgency'],
                       surround=bool(fh.train_args['surround']),
                       sample_maps=bool(fh.train_args['sample_maps']),
                       flatten=bool(fh.train_args['flatten']),
                       reward_mech='global',
                       catchr=fh.train_args['catchr'],
                       term_pursuit=fh.train_args['term_pursuit'])

    if fh.train_args['buffer_size'] > 1:
        env = ObservationBuffer(env, fh.train_args['buffer_size'])

    hpolicy = None
    if args.heuristic:
        from heuristics.pursuit import PursuitHeuristicPolicy
        hpolicy = PursuitHeuristicPolicy(env.agents[0].observation_space,
                                         env.agents[0].action_space)

    if args.evaluate:
        minion = Evaluator(env, fh.train_args, args.n_steps, args.n_trajs,
                           args.deterministic,
                           'heuristic' if args.heuristic else fh.mode)
        evr = minion(fh.filename,
                     file_key=fh.file_key,
                     same_con_pol=args.same_con_pol,
                     hpolicy=hpolicy)
        from tabulate import tabulate
        print(evr)
        print(tabulate(evr, headers='keys'))
    else:
        minion = Visualizer(env, fh.train_args, args.n_steps, args.n_trajs,
                            args.deterministic, fh.mode)
        rew, info = minion(fh.filename, file_key=fh.file_key, vid=args.vid)
        pprint.pprint(rew)
        pprint.pprint(info)
Example No. 2
def main(parser):
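    """Build the map pool and PursuitEvade env from parsed args, then train."""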
    args = parser.args

    if args.map_file:
        map_pool = np.load(args.map_file)
    else:
        if args.map_type == 'rectangle':
            env_map = TwoDMaps.rectangle_map(*map(int, args.map_size.split(',')))
        elif args.map_type == 'complex':
            env_map = TwoDMaps.complex_map(*map(int, args.map_size.split(',')))
        else:
            raise NotImplementedError()
        map_pool = [env_map]

    env = PursuitEvade(map_pool, n_evaders=args.n_evaders, n_pursuers=args.n_pursuers,
                       obs_range=args.obs_range, n_catch=args.n_catch, urgency_reward=args.urgency,
                       surround=bool(args.surround), sample_maps=bool(args.sample_maps),
                       flatten=bool(args.flatten), reward_mech=args.reward_mech, catchr=args.catchr,
                       term_pursuit=args.term_pursuit, include_id=not bool(args.noid))

    if args.buffer_size > 1:
        env = ObservationBuffer(env, args.buffer_size)

    runner = Runner(env, args)
    runner.start_training()
Example No. 3
def envname2env(envname, args):
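    """Map an environment name string to a constructed MADRL environment."""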
    from madrl_environments.pursuit import PursuitEvade
    from madrl_environments.pursuit import MAWaterWorld
    from madrl_environments.walker.multi_walker import MultiWalkerEnv

    # XXX
    # Will generalize later
    if envname == 'multiwalker':
        env = MultiWalkerEnv(
            args['n_walkers'],
            args['position_noise'],
            args['angle_noise'],
            reward_mech='global',
        )

    elif envname == 'waterworld':
        env = MAWaterWorld(
            args['n_pursuers'],
            args['n_evaders'],
            args['n_coop'],
            args['n_poison'],
            n_sensors=args['n_sensors'],
            food_reward=args['food_reward'],
            poison_reward=args['poison_reward'],
            encounter_reward=args['encounter_reward'],
            reward_mech='global',
        )

    elif envname == 'pursuit':
        env = PursuitEvade()
    else:
        raise NotImplementedError()

    return env
Example No. 4
def main():
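    """Evaluate (or visualize) a trained policy, reconstructing the env from
    the training args stored alongside the snapshot."""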
    parser = argparse.ArgumentParser()
    parser.add_argument('filename', type=str)
    parser.add_argument('--vid', type=str, default='madrl.mp4')
    parser.add_argument('--deterministic', action='store_true', default=False)
    parser.add_argument('--heuristic', action='store_true', default=False)
    parser.add_argument('--evaluate', action='store_true', default=True)
    parser.add_argument('--n_trajs', type=int, default=20)
    parser.add_argument('--n_steps', type=int, default=20)
    parser.add_argument('--same_con_pol', action='store_true')
    args = parser.parse_args()

    fh = FileHandler(args.filename)

    if fh.train_args['map_file'] is not None:
        map_pool = np.load(
            os.path.join('', os.path.basename(fh.train_args[
                'map_file'])))
    else:
        if fh.train_args['map_type'] == 'rectangle':
            env_map = TwoDMaps.rectangle_map(*map(int, fh.train_args['map_size'].split(',')))
        elif fh.train_args['map_type'] == 'complex':
            env_map = TwoDMaps.complex_map(*map(int, fh.train_args['map_size'].split(',')))
        else:
            raise NotImplementedError()
        map_pool = [env_map]

    env = PursuitEvade(map_pool, n_evaders=fh.train_args['n_evaders'],
                       n_pursuers=fh.train_args['n_pursuers'], obs_range=fh.train_args['obs_range'],
                       n_catch=fh.train_args['n_catch'], urgency_reward=fh.train_args['urgency'],
                       surround=bool(fh.train_args['surround']),
                       sample_maps=bool(fh.train_args['sample_maps']),
                       flatten=bool(fh.train_args['flatten']), reward_mech='global',
                       catchr=fh.train_args['catchr'], term_pursuit=fh.train_args['term_pursuit'])

    if fh.train_args['buffer_size'] > 1:
        env = ObservationBuffer(env, fh.train_args['buffer_size'])

    hpolicy = None
    if args.evaluate:
        minion = Evaluator(env, fh.train_args, args.n_steps, args.n_trajs, args.deterministic,
                           'heuristic' if args.heuristic else fh.mode)
        evr = minion(fh.filename, file_key=fh.file_key, same_con_pol=args.same_con_pol,
                     hpolicy=hpolicy)
        from tabulate import tabulate
        print(evr)
        print(tabulate(evr, headers='keys'))
    else:
        minion = Visualizer(env, fh.train_args, args.n_steps, args.n_trajs, args.deterministic,
                            fh.mode)
        rew, info = minion(fh.filename, file_key=fh.file_key, vid=args.vid)
        pprint.pprint(rew)
        pprint.pprint(info)
Example No. 5
def main(parser):
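    """Construct the PursuitEvade env and dispatch training to the rllab or
    rltools runner, depending on the parser mode."""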
    mode = parser._mode
    args = parser.args

    if args.map_file:
        map_pool = np.load(args.map_file)
    else:
        if args.map_type == 'rectangle':
            env_map = TwoDMaps.rectangle_map(
                *map(int, args.map_size.split(',')))
        elif args.map_type == 'complex':
            env_map = TwoDMaps.complex_map(*map(int, args.map_size.split(',')))
        else:
            raise NotImplementedError()
        map_pool = [env_map]

    env = PursuitEvade(map_pool,
                       n_evaders=args.n_evaders,
                       n_pursuers=args.n_pursuers,
                       obs_range=args.obs_range,
                       n_catch=args.n_catch,
                       urgency_reward=args.urgency,
                       surround=bool(args.surround),
                       sample_maps=bool(args.sample_maps),
                       flatten=bool(args.flatten),
                       reward_mech=args.reward_mech,
                       catchr=args.catchr,
                       term_pursuit=args.term_pursuit)

    if args.buffer_size > 1:
        env = ObservationBuffer(env, args.buffer_size)

    if mode == 'rllab':
        from runners.rurllab import RLLabRunner
        run = RLLabRunner(env, args)
    elif mode == 'rltools':
        from runners.rurltools import RLToolsRunner
        run = RLToolsRunner(env, args)
    else:
        raise NotImplementedError()

    run()
Example No. 6
    def set_state(self, *args):
        pass


if __name__ == '__main__':
    from madrl_environments.pursuit import PursuitEvade
    from madrl_environments.pursuit.utils import *
    from vis import Visualizer
    import pprint
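    # Smoke test: roll out the heuristic pursuit policy on a 16x16 map.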
    map_mat = TwoDMaps.rectangle_map(16, 16)

    env = PursuitEvade([map_mat],
                       n_evaders=30,
                       n_pursuers=8,
                       obs_range=7,
                       n_catch=4,
                       surround=False,
                       flatten=False)

    policy = PursuitHeuristicPolicy(env.agents[0].observation_space,
                                    env.agents[0].action_space)

    for i in range(20):
        rew = 0.0
        obs = env.reset()
        infolist = []
        for _ in xrange(500):
            # env.render()
            act_list = []
            for o in obs:
                a, _ = policy.sample_actions(o)
                act_list.append(a)
            obs, r, done, info = env.step(act_list)
            rew += np.mean(r)
Example No. 7
def main():
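    """End-to-end TRPO training on PursuitEvade with rllab/TensorFlow:
    parse args, build the env and policy (MLP, conv, or recurrent),
    set up logging, and train."""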
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')
    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)

    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_name',
                        type=str,
                        default=default_exp_name,
                        help='Name of the experiment.')

    parser.add_argument('--discount', type=float, default=0.99)
    parser.add_argument('--gae_lambda', type=float, default=1.0)
    parser.add_argument('--reward_scale', type=float, default=1.0)

    parser.add_argument('--n_iter', type=int, default=250)
    parser.add_argument('--sampler_workers', type=int, default=1)
    parser.add_argument('--max_traj_len', type=int, default=250)
    parser.add_argument('--update_curriculum',
                        action='store_true',
                        default=False)
    parser.add_argument('--n_timesteps', type=int, default=8000)
    parser.add_argument('--control', type=str, default='centralized')

    parser.add_argument('--rectangle', type=str, default='10,10')
    parser.add_argument('--map_type', type=str, default='rectangle')
    parser.add_argument('--n_evaders', type=int, default=5)
    parser.add_argument('--n_pursuers', type=int, default=2)
    parser.add_argument('--obs_range', type=int, default=3)
    parser.add_argument('--n_catch', type=int, default=2)
    parser.add_argument('--urgency', type=float, default=0.0)
    parser.add_argument('--pursuit', dest='train_pursuit', action='store_true')
    parser.add_argument('--evade', dest='train_pursuit', action='store_false')
    parser.set_defaults(train_pursuit=True)
    parser.add_argument('--surround', action='store_true', default=False)
    parser.add_argument('--constraint_window', type=float, default=1.0)
    parser.add_argument('--sample_maps', action='store_true', default=False)
    parser.add_argument('--map_file', type=str, default='../maps/map_pool.npy')
    parser.add_argument('--flatten', action='store_true', default=False)
    parser.add_argument('--reward_mech', type=str, default='global')
    parser.add_argument('--catchr', type=float, default=0.1)
    parser.add_argument('--term_pursuit', type=float, default=5.0)

    parser.add_argument('--recurrent', type=str, default=None)
    parser.add_argument('--policy_hidden_sizes', type=str, default='128,128')
    parser.add_argument('--baseline_hidden_sizes', type=str, default='128,128')
    parser.add_argument('--baseline_type', type=str, default='linear')

    parser.add_argument('--conv', action='store_true', default=False)

    parser.add_argument('--max_kl', type=float, default=0.01)

    parser.add_argument('--checkpoint', type=str, default=None)

    parser.add_argument('--log_dir', type=str, required=False)
    parser.add_argument('--tabular_log_file',
                        type=str,
                        default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file',
                        type=str,
                        default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file',
                        type=str,
                        default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data',
                        type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--snapshot_mode',
                        type=str,
                        default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                        '(all iterations will be saved), "last" (only '
                        'the last iteration will be saved), or "none" '
                        '(do not save snapshots)')
    parser.add_argument(
        '--log_tabular_only',
        type=ast.literal_eval,
        default=False,
        help=
        'Whether to only print the tabular log information (in a horizontal format)'
    )

    args = parser.parse_args()

    parallel_sampler.initialize(n_parallel=args.sampler_workers)

    if args.seed is not None:
        set_seed(args.seed)
        parallel_sampler.set_seed(args.seed)

    args.hidden_sizes = tuple(map(int, args.policy_hidden_sizes.split(',')))

    if args.checkpoint:
        with tf.Session() as sess:
            data = joblib.load(args.checkpoint)
            policy = data['policy']
            env = data['env']
    else:
        if args.sample_maps:
            map_pool = np.load(args.map_file)
        else:
            if args.map_type == 'rectangle':
                env_map = TwoDMaps.rectangle_map(
                    *map(int, args.rectangle.split(',')))
            elif args.map_type == 'complex':
                env_map = TwoDMaps.complex_map(
                    *map(int, args.rectangle.split(',')))
            else:
                raise NotImplementedError()
            map_pool = [env_map]

        env = PursuitEvade(map_pool,
                           n_evaders=args.n_evaders,
                           n_pursuers=args.n_pursuers,
                           obs_range=args.obs_range,
                           n_catch=args.n_catch,
                           train_pursuit=args.train_pursuit,
                           urgency_reward=args.urgency,
                           surround=args.surround,
                           sample_maps=args.sample_maps,
                           constraint_window=args.constraint_window,
                           flatten=args.flatten,
                           reward_mech=args.reward_mech,
                           catchr=args.catchr,
                           term_pursuit=args.term_pursuit)

        env = TfEnv(
            RLLabEnv(StandardizedEnv(env,
                                     scale_reward=args.reward_scale,
                                     enable_obsnorm=False),
                     mode=args.control))

        if args.recurrent:
            if args.conv:
                feature_network = ConvNetwork(
                    name='feature_net',
                    input_shape=env.spec.observation_space.shape,
                    output_dim=5,
                    conv_filters=(16, 32, 32),
                    conv_filter_sizes=(3, 3, 3),
                    conv_strides=(1, 1, 1),
                    conv_pads=('VALID', 'VALID', 'VALID'),
                    hidden_sizes=(64, ),
                    hidden_nonlinearity=tf.nn.relu,
                    output_nonlinearity=tf.nn.softmax)
            else:
                feature_network = MLP(
                    name='feature_net',
                    input_shape=(env.spec.observation_space.flat_dim +
                                 env.spec.action_space.flat_dim, ),
                    output_dim=5,
                    hidden_sizes=(256, 128, 64),
                    hidden_nonlinearity=tf.nn.tanh,
                    output_nonlinearity=None)
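            # NOTE: int() below assumes --policy_hidden_sizes holds a single
            # integer (e.g. '128') when a recurrent policy is selected; the
            # comma-separated default would raise a ValueError here.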
            if args.recurrent == 'gru':
                policy = CategoricalGRUPolicy(env_spec=env.spec,
                                              feature_network=feature_network,
                                              hidden_dim=int(
                                                  args.policy_hidden_sizes),
                                              name='policy')
            elif args.recurrent == 'lstm':
                policy = CategoricalLSTMPolicy(env_spec=env.spec,
                                               feature_network=feature_network,
                                               hidden_dim=int(
                                                   args.policy_hidden_sizes),
                                               name='policy')
        elif args.conv:
            feature_network = ConvNetwork(
                name='feature_net',
                input_shape=env.spec.observation_space.shape,
                output_dim=5,
                conv_filters=(8, 16),
                conv_filter_sizes=(3, 3),
                conv_strides=(2, 1),
                conv_pads=('VALID', 'VALID'),
                hidden_sizes=(32, ),
                hidden_nonlinearity=tf.nn.relu,
                output_nonlinearity=tf.nn.softmax)
            policy = CategoricalMLPPolicy(name='policy',
                                          env_spec=env.spec,
                                          prob_network=feature_network)
        else:
            policy = CategoricalMLPPolicy(name='policy',
                                          env_spec=env.spec,
                                          hidden_sizes=args.hidden_sizes)

    if args.baseline_type == 'linear':
        baseline = LinearFeatureBaseline(env_spec=env.spec)
    else:
        baseline = ZeroBaseline(env_spec=env.spec)

    # logger
    default_log_dir = config.LOG_DIR
    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    logger.log_parameters_lite(params_log_file, args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    algo = TRPO(
        env=env,
        policy=policy,
        baseline=baseline,
        batch_size=args.n_timesteps,
        max_path_length=args.max_traj_len,
        n_itr=args.n_iter,
        discount=args.discount,
        gae_lambda=args.gae_lambda,
        step_size=args.max_kl,
        optimizer=ConjugateGradientOptimizer(hvp_approach=FiniteDifferenceHvp(
            base_eps=1e-5)) if args.recurrent else None,
        mode=args.control,
    )

    algo.train()
Example No. 8
def main():
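    """TRPO training on PursuitEvade with rltools, supporting centralized or
    decentralized control and optional warm-starting from an h5 checkpoint."""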
    parser = argparse.ArgumentParser()
    parser.add_argument('--discount', type=float, default=0.99)
    parser.add_argument('--gae_lambda', type=float, default=0.99)

    parser.add_argument('--n_iter', type=int, default=250)
    parser.add_argument('--sampler', type=str, default='simple')
    parser.add_argument('--sampler_workers', type=int, default=4)
    parser.add_argument('--max_traj_len', type=int, default=250)
    parser.add_argument('--adaptive_batch', action='store_true', default=False)
    parser.add_argument('--update_curriculum', action='store_true', default=False)

    parser.add_argument('--n_timesteps', type=int, default=8000)
    parser.add_argument('--n_timesteps_min', type=int, default=1000)
    parser.add_argument('--n_timesteps_max', type=int, default=64000)
    parser.add_argument('--timestep_rate', type=int, default=20)

    parser.add_argument('--control', type=str, default='centralized')
    parser.add_argument('--rectangle', type=str, default='10,10')
    parser.add_argument('--map_type', type=str, default='rectangle')
    parser.add_argument('--n_evaders', type=int, default=5)
    parser.add_argument('--n_pursuers', type=int, default=2)
    parser.add_argument('--obs_range', type=int, default=3)
    parser.add_argument('--n_catch', type=int, default=2)
    parser.add_argument('--urgency', type=float, default=0.0)
    parser.add_argument('--pursuit', dest='train_pursuit', action='store_true')
    parser.add_argument('--evade', dest='train_pursuit', action='store_false')
    parser.set_defaults(train_pursuit=True)
    parser.add_argument('--surround', action='store_true', default=False)
    parser.add_argument('--constraint_window', type=float, default=1.0)
    parser.add_argument('--sample_maps', action='store_true', default=False)
    parser.add_argument('--map_file', type=str, default='maps/map_pool.npy')
    parser.add_argument('--flatten', action='store_true', default=False)
    parser.add_argument('--reward_mech', type=str, default='global')
    parser.add_argument('--cur_remove', type=int, default=100)
    parser.add_argument('--cur_const_rate', type=float, default=0.0)
    parser.add_argument('--cur_shaping', type=int, default=np.inf)
    parser.add_argument('--catchr', type=float, default=0.1)
    parser.add_argument('--term_pursuit', type=float, default=5.0)

    parser.add_argument('--policy_hidden_spec', type=str, default='MED_POLICY_ARCH')

    parser.add_argument('--baseline_type', type=str, default='mlp')
    parser.add_argument('--baseline_hidden_spec', type=str, default='MED_POLICY_ARCH')

    parser.add_argument('--max_kl', type=float, default=0.01)
    parser.add_argument('--vf_max_kl', type=float, default=0.01)
    parser.add_argument('--vf_cg_damping', type=float, default=0.01)

    parser.add_argument('--save_freq', type=int, default=50)
    parser.add_argument('--log', type=str, required=False)
    parser.add_argument('--tblog', type=str, default='/tmp/madrl_tb')
    parser.add_argument('--debug', dest='debug', action='store_true')
    parser.add_argument('--no-debug', dest='debug', action='store_false')
    parser.add_argument('--load_checkpoint', type=str, default='')
    parser.set_defaults(debug=True)

    args = parser.parse_args()

    args.policy_hidden_spec = get_arch(args.policy_hidden_spec)
    args.baseline_hidden_spec = get_arch(args.baseline_hidden_spec)

    if args.sample_maps:
        map_pool = np.load(args.map_file)
    else:
        if args.map_type == 'rectangle':
            env_map = TwoDMaps.rectangle_map(*map(int, args.rectangle.split(',')))
        elif args.map_type == 'complex':
            env_map = TwoDMaps.complex_map(*map(int, args.rectangle.split(',')))
        else:
            raise NotImplementedError()
        map_pool = [env_map]
    env = PursuitEvade(map_pool, n_evaders=args.n_evaders, n_pursuers=args.n_pursuers,
                       obs_range=args.obs_range, n_catch=args.n_catch,
                       train_pursuit=args.train_pursuit, urgency_reward=args.urgency,
                       surround=args.surround, sample_maps=args.sample_maps,
                       constraint_window=args.constraint_window,
                       flatten=args.flatten,
                       reward_mech=args.reward_mech,
                       curriculum_remove_every=args.cur_remove,
                       curriculum_constrain_rate=args.cur_const_rate,
                       curriculum_turn_off_shaping=args.cur_shaping,
                       catchr=args.catchr,
                       term_pursuit=args.term_pursuit)

    if args.control == 'centralized':
        obsfeat_space = spaces.Box(low=env.agents[0].observation_space.low[0],
                                   high=env.agents[0].observation_space.high[0],
                                   shape=(env.agents[0].observation_space.shape[0] *
                                          len(env.agents),))  # XXX
        action_space = spaces.Discrete(env.agents[0].action_space.n * len(env.agents))

    elif args.control == 'decentralized':
        obsfeat_space = env.agents[0].observation_space
        action_space = env.agents[0].action_space
    else:
        raise NotImplementedError()


    if args.load_checkpoint != '':
        filename, file_key = rltools.util.split_h5_name(args.load_checkpoint)
        print('Loading parameters from {} in {}'.format(file_key, filename))
        with h5py.File(filename, 'r') as f:
            train_args = json.loads(f.attrs['args'])
            dset = f[file_key]

            pprint.pprint(dict(dset.attrs))
        policy = CategoricalMLPPolicy(obsfeat_space, action_space,
                                      hidden_spec=train_args['policy_hidden_spec'], enable_obsnorm=False,
                                      tblog=train_args['tblog'], varscope_name='pursuit_catmlp_policy')
    else: 
        policy = CategoricalMLPPolicy(obsfeat_space, action_space, hidden_spec=args.policy_hidden_spec,
                                      enable_obsnorm=False, tblog=args.tblog,
                                      varscope_name='pursuit_catmlp_policy')

    if args.baseline_type == 'linear':
        baseline = LinearFeatureBaseline(obsfeat_space, enable_obsnorm=False,
                                         varscope_name='pursuit_linear_baseline')
    elif args.baseline_type == 'mlp':
        baseline = MLPBaseline(obsfeat_space, args.baseline_hidden_spec, True, True,
                               max_kl=args.vf_max_kl, damping=args.vf_cg_damping,
                               time_scale=1. / args.max_traj_len,
                               varscope_name='pursuit_mlp_baseline')
    else:
        baseline = ZeroBaseline(obsfeat_space)

    if args.sampler == 'simple':
        if args.control == "centralized":
            sampler_cls = SimpleSampler
        elif args.control == "decentralized":
            sampler_cls = DecSampler
        else:
            raise NotImplementedError()
        sampler_args = dict(max_traj_len=args.max_traj_len, n_timesteps=args.n_timesteps,
                            n_timesteps_min=args.n_timesteps_min,
                            n_timesteps_max=args.n_timesteps_max, timestep_rate=args.timestep_rate,
                            adaptive=args.adaptive_batch, enable_rewnorm=True)
    elif args.sampler == 'parallel':
        sampler_cls = ParallelSampler
        sampler_args = dict(max_traj_len=args.max_traj_len, n_timesteps=args.n_timesteps,
                            n_timesteps_min=args.n_timesteps_min,
                            n_timesteps_max=args.n_timesteps_max, timestep_rate=args.timestep_rate,
                            adaptive=args.adaptive_batch, enable_rewnorm=True,
                            n_workers=args.sampler_workers, mode=args.control)
    else:
        raise NotImplementedError()
    step_func = rltools.algos.policyopt.TRPO(max_kl=args.max_kl)
    popt = rltools.algos.policyopt.SamplingPolicyOptimizer(env=env, policy=policy,
                                                           baseline=baseline, step_func=step_func,
                                                           discount=args.discount,
                                                           gae_lambda=args.gae_lambda,
                                                           sampler_cls=sampler_cls,
                                                           sampler_args=sampler_args,
                                                           n_iter=args.n_iter,
                                                           update_curriculum=args.update_curriculum)
    argstr = json.dumps(vars(args), separators=(',', ':'), indent=2)
    rltools.util.header(argstr)
    log_f = rltools.log.TrainingLog(args.log, [('args', argstr)], debug=args.debug)

    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        if args.load_checkpoint != '':
            policy.load_h5(sess, filename, file_key) 
        popt.train(sess, log_f, args.save_freq)
Example No. 9
def main():
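    """Load a trained policy from an h5 snapshot and animate it in the
    PursuitEvade env, saving the rollout to video."""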
    parser = argparse.ArgumentParser()
    parser.add_argument('filename', type=str) # defaultIS.h5/snapshots/iter0000480
    parser.add_argument('--vid', type=str, default='/tmp/madrl.mp4')
    parser.add_argument('--deterministic', action='store_true', default=False)
    parser.add_argument('--verbose', action='store_true', default=False)
    parser.add_argument('--n_steps', type=int, default=200)
    parser.add_argument('--map_file', type=str, default='')
    args = parser.parse_args()

    # Load file
    filename, file_key = rltools.util.split_h5_name(args.filename)
    print('Loading parameters from {} in {}'.format(file_key, filename))
    with h5py.File(filename, 'r') as f:
        train_args = json.loads(f.attrs['args'])
        dset = f[file_key]

        pprint.pprint(dict(dset.attrs))

    pprint.pprint(train_args)
    if train_args['sample_maps']:
        map_pool = np.load(args.map_file)
    else:
        if train_args['map_type'] == 'rectangle':
            env_map = TwoDMaps.rectangle_map(*map(int, train_args['rectangle'].split(',')))
        elif train_args['map_type'] == 'complex':
            env_map = TwoDMaps.complex_map(*map(int, train_args['rectangle'].split(',')))
        else:
            raise NotImplementedError()
        map_pool = [env_map]

    env = PursuitEvade(map_pool,
                       #n_evaders=train_args['n_evaders'],
                       #n_pursuers=train_args['n_pursuers'],
                       n_evaders=50,
                       n_pursuers=50,
                       obs_range=train_args['obs_range'],
                       n_catch=train_args['n_catch'],
                       urgency_reward=train_args['urgency'],
                       surround=train_args['surround'],
                       sample_maps=train_args['sample_maps'],
                       flatten=train_args['flatten'],
                       reward_mech=train_args['reward_mech']
                       )

    if train_args['control'] == 'decentralized':
        obsfeat_space = env.agents[0].observation_space
        action_space = env.agents[0].action_space
    elif train_args['control'] == 'centralized':
        obsfeat_space = spaces.Box(low=env.agents[0].observation_space.low[0],
                                   high=env.agents[0].observation_space.high[0],
                                   shape=(env.agents[0].observation_space.shape[0] *
                                          len(env.agents),))  # XXX
        action_space = spaces.Discrete(env.agents[0].action_space.n * len(env.agents))

    else:
        raise NotImplementedError()


    policy = CategoricalMLPPolicy(obsfeat_space, action_space,
                                  hidden_spec=train_args['policy_hidden_spec'],
                                  enable_obsnorm=True,
                                  tblog=train_args['tblog'], varscope_name='pursuit_catmlp_policy')


    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        policy.load_h5(sess, filename, file_key)
        if train_args['control'] == 'centralized':
            act_fn = lambda o: policy.sample_actions(
                np.expand_dims(np.array(o).flatten(), 0),
                deterministic=args.deterministic)[0][0, 0]
        elif train_args['control'] == 'decentralized':
            def act_fn(o):
                action_list = []
                for agent_obs in o:
                    a, adist = policy.sample_actions(np.expand_dims(agent_obs,0), deterministic=args.deterministic)
                    action_list.append(a[0, 0])
                return action_list
        #import IPython
        #IPython.embed()
        env.animate(act_fn=act_fn, nsteps=args.n_steps, file_name=args.vid, verbose=args.verbose)
Example No. 10
def main():
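    """Same replay flow as the previous example: restore a policy from an h5
    snapshot and record an animated rollout of the PursuitEvade env."""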
    parser = argparse.ArgumentParser()
    parser.add_argument('filename',
                        type=str)  # defaultIS.h5/snapshots/iter0000480
    parser.add_argument('--vid', type=str, default='/tmp/madrl.mp4')
    parser.add_argument('--deterministic', action='store_true', default=False)
    parser.add_argument('--verbose', action='store_true', default=False)
    parser.add_argument('--n_steps', type=int, default=200)
    parser.add_argument('--map_file', type=str, default='')
    args = parser.parse_args()

    # Load file
    filename, file_key = rltools.util.split_h5_name(args.filename)
    print('Loading parameters from {} in {}'.format(file_key, filename))
    with h5py.File(filename, 'r') as f:
        train_args = json.loads(f.attrs['args'])
        dset = f[file_key]

        pprint.pprint(dict(dset.attrs))

    pprint.pprint(train_args)
    if train_args['sample_maps']:
        map_pool = np.load(args.map_file)
    else:
        if train_args['map_type'] == 'rectangle':
            env_map = TwoDMaps.rectangle_map(
                *map(int, train_args['rectangle'].split(',')))
        elif train_args['map_type'] == 'complex':
            env_map = TwoDMaps.complex_map(
                *map(int, train_args['rectangle'].split(',')))
        else:
            raise NotImplementedError()
        map_pool = [env_map]

    env = PursuitEvade(
        map_pool,
        #n_evaders=train_args['n_evaders'],
        #n_pursuers=train_args['n_pursuers'],
        n_evaders=50,
        n_pursuers=50,
        obs_range=train_args['obs_range'],
        n_catch=train_args['n_catch'],
        urgency_reward=train_args['urgency'],
        surround=train_args['surround'],
        sample_maps=train_args['sample_maps'],
        flatten=train_args['flatten'],
        reward_mech=train_args['reward_mech'])

    if train_args['control'] == 'decentralized':
        obsfeat_space = env.agents[0].observation_space
        action_space = env.agents[0].action_space
    elif train_args['control'] == 'centralized':
        obsfeat_space = spaces.Box(
            low=env.agents[0].observation_space.low[0],
            high=env.agents[0].observation_space.high[0],
            shape=(env.agents[0].observation_space.shape[0] *
                   len(env.agents), ))  # XXX
        action_space = spaces.Discrete(env.agents[0].action_space.n *
                                       len(env.agents))

    else:
        raise NotImplementedError()

    policy = CategoricalMLPPolicy(obsfeat_space,
                                  action_space,
                                  hidden_spec=train_args['policy_hidden_spec'],
                                  enable_obsnorm=True,
                                  tblog=train_args['tblog'],
                                  varscope_name='pursuit_catmlp_policy')

    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        policy.load_h5(sess, filename, file_key)
        if train_args['control'] == 'centralized':
            act_fn = lambda o: policy.sample_actions(
                np.expand_dims(np.array(o).flatten(), 0),
                deterministic=args.deterministic)[0][0, 0]
        elif train_args['control'] == 'decentralized':

            def act_fn(o):
                action_list = []
                for agent_obs in o:
                    a, adist = policy.sample_actions(
                        np.expand_dims(agent_obs, 0),
                        deterministic=args.deterministic)
                    action_list.append(a[0, 0])
                return action_list

        #import IPython
        #IPython.embed()
        env.animate(act_fn=act_fn,
                    nsteps=args.n_steps,
                    file_name=args.vid,
                    verbose=args.verbose)
Example No. 11
    def get_state(self):
        return []

    def set_state(self, *args):
        pass


if __name__ == '__main__':
    from madrl_environments.pursuit import PursuitEvade
    from madrl_environments.pursuit.utils import *
    from vis import Visualizer
    import pprint
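    # Roll out the heuristic pursuit policy on a 16x16 map, tracking mean reward.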
    map_mat = TwoDMaps.rectangle_map(16, 16)

    env = PursuitEvade([map_mat], n_evaders=30, n_pursuers=8, obs_range=7, n_catch=4,
                       surround=False, flatten=False)

    policy = PursuitHeuristicPolicy(env.agents[0].observation_space, env.agents[0].action_space)

    for i in range(20):
        rew = 0.0
        obs = env.reset()
        infolist = []
        for _ in xrange(500):
            # env.render()
            act_list = []
            for o in obs:
                a, _ = policy.sample_actions(o)
                act_list.append(a)
            obs, r, done, info = env.step(act_list)
            rew += np.mean(r)