def main(parser):
    """Build a multi-walker environment from parsed options and launch training.

    ``parser`` must expose ``_mode`` (``'rllab'`` or ``'rltools'``) and
    ``args`` (the parsed option namespace).  The environment is wrapped in an
    ``ObservationBuffer`` when ``args.buffer_size`` exceeds 1.  If
    ``args.curriculum`` is truthy the runner is invoked with a ``Curriculum``
    built from it, otherwise it is invoked without arguments.

    Raises:
        NotImplementedError: if the mode is neither ``'rllab'`` nor
            ``'rltools'``.
    """
    backend, opts = parser._mode, parser.args

    walker_kwargs = {
        'n_walkers': opts.n_walkers,
        'position_noise': opts.position_noise,
        'angle_noise': opts.angle_noise,
        'reward_mech': opts.reward_mech,
        'forward_reward': opts.forward_reward,
        'fall_reward': opts.fall_reward,
        'drop_reward': opts.drop_reward,
        # These two options are coerced to real booleans before use.
        'terminate_on_fall': bool(opts.terminate_on_fall),
        'one_hot': bool(opts.one_hot),
    }
    env = MultiWalkerEnv(**walker_kwargs)

    history = opts.buffer_size
    if history > 1:
        env = ObservationBuffer(env, history)

    # The runner modules are imported lazily so only the selected backend
    # has to be importable.
    if backend == 'rllab':
        from runners.rurllab import RLLabRunner as runner_cls
    elif backend == 'rltools':
        from runners.rurltools import RLToolsRunner as runner_cls
    else:
        raise NotImplementedError()
    run = runner_cls(env, opts)

    if opts.curriculum:
        run(Curriculum(opts.curriculum))
    else:
        run()
def main(parser):
    """Create a ``MultiAnt`` environment from the parsed options and train on it.

    ``parser`` must expose ``_mode`` (``'rllab'`` or ``'rltools'``) and
    ``args`` (the parsed option namespace).  The environment is wrapped in an
    ``ObservationBuffer`` when ``args.buffer_size`` exceeds 1, and the runner
    receives a ``Curriculum`` only when ``args.curriculum`` is truthy.

    Raises:
        NotImplementedError: if the mode is neither ``'rllab'`` nor
            ``'rltools'``.
    """
    mode = parser._mode
    args = parser.args

    # Every constructor keyword is forwarded verbatim from the option of the
    # same name.
    forwarded = ('n_legs', 'ts', 'integrator', 'leg_length', 'out_file',
                 'base_file', 'reward_mech')
    ant_options = {name: getattr(args, name) for name in forwarded}
    env = MultiAnt(**ant_options)

    if args.buffer_size > 1:
        env = ObservationBuffer(env, args.buffer_size)

    # Lazy imports: only the chosen backend needs to be installed.
    if mode == 'rltools':
        from runners.rurltools import RLToolsRunner
        run = RLToolsRunner(env, args)
    elif mode == 'rllab':
        from runners.rurllab import RLLabRunner
        run = RLLabRunner(env, args)
    else:
        raise NotImplementedError()

    if not args.curriculum:
        run()
    else:
        run(Curriculum(args.curriculum))
def main(parser):
    """Set up a ``MAWaterWorld`` environment from parsed options and train on it.

    ``parser`` must expose ``_mode`` (``'rllab'`` or ``'rltools'``) and
    ``args`` (the parsed option namespace).  Agent ids are added unless
    ``args.noid`` is truthy.  The environment is wrapped in an
    ``ObservationBuffer`` when ``args.buffer_size`` exceeds 1, and the runner
    receives a ``Curriculum`` only when ``args.curriculum`` is truthy.

    Raises:
        NotImplementedError: if the mode is neither ``'rllab'`` nor
            ``'rltools'``.
    """
    backend, opts = parser._mode, parser.args

    # The four population counts are passed positionally, in this order.
    counts = (opts.n_pursuers, opts.n_evaders, opts.n_coop, opts.n_poison)
    world_kwargs = {
        'radius': opts.radius,
        'n_sensors': opts.n_sensors,
        'food_reward': opts.food_reward,
        'poison_reward': opts.poison_reward,
        'encounter_reward': opts.encounter_reward,
        'reward_mech': opts.reward_mech,
        'sensor_range': opts.sensor_range,
        'obstacle_loc': None,
        'addid': not opts.noid,
        'speed_features': bool(opts.speed_features),
    }
    env = MAWaterWorld(*counts, **world_kwargs)

    if opts.buffer_size > 1:
        env = ObservationBuffer(env, opts.buffer_size)

    # Import the backend lazily so the unused one need not be importable.
    if backend == 'rllab':
        from runners.rurllab import RLLabRunner as runner_cls
    elif backend == 'rltools':
        from runners.rurltools import RLToolsRunner as runner_cls
    else:
        raise NotImplementedError()
    run = runner_cls(env, opts)

    if opts.curriculum:
        run(Curriculum(opts.curriculum))
    else:
        run()
Esempio n. 4
0
def main(parser):
    """Build a ``PursuitEvade`` environment from parsed options and start training.

    The map pool is loaded with ``np.load`` when ``args.map_file`` is set;
    otherwise a single map is generated from ``args.map_type``
    (``'rectangle'`` or ``'complex'``) and the comma-separated integers in
    ``args.map_size``.  The environment is wrapped in an ``ObservationBuffer``
    when ``args.buffer_size`` exceeds 1.

    Raises:
        NotImplementedError: if no map file is given and ``args.map_type`` is
            not one of the supported kinds.
    """
    opts = parser.args

    if opts.map_file:
        map_pool = np.load(opts.map_file)
    else:
        if opts.map_type == 'rectangle':
            make_map = TwoDMaps.rectangle_map
        elif opts.map_type == 'complex':
            make_map = TwoDMaps.complex_map
        else:
            raise NotImplementedError()
        map_dims = [int(token) for token in opts.map_size.split(',')]
        map_pool = [make_map(*map_dims)]

    pursuit_kwargs = {
        'n_evaders': opts.n_evaders,
        'n_pursuers': opts.n_pursuers,
        'obs_range': opts.obs_range,
        'n_catch': opts.n_catch,
        'urgency_reward': opts.urgency,
        'surround': bool(opts.surround),
        'sample_maps': bool(opts.sample_maps),
        'flatten': bool(opts.flatten),
        'reward_mech': opts.reward_mech,
        'catchr': opts.catchr,
        'term_pursuit': opts.term_pursuit,
        'include_id': not opts.noid,
    }
    env = PursuitEvade(map_pool, **pursuit_kwargs)

    if opts.buffer_size > 1:
        env = ObservationBuffer(env, opts.buffer_size)

    Runner(env, opts).start_training()
Esempio n. 5
0
def main(parser):
    """Build a ``MultiAircraftEnv`` from parsed options and launch training.

    ``parser`` must expose ``_mode`` (``'rllab'`` or ``'rltools'``) and
    ``args`` (the parsed option namespace).  The environment is wrapped in an
    ``ObservationBuffer`` when ``args.buffer_size`` exceeds 1, and the runner
    receives a ``Curriculum`` only when ``args.curriculum`` is truthy.

    Raises:
        NotImplementedError: if the mode is neither ``'rllab'`` nor
            ``'rltools'``.
    """
    mode = parser._mode
    args = parser.args

    # Options forwarded unchanged under the same keyword name.
    passthrough = ('n_agents', 'speed_noise', 'position_noise', 'angle_noise',
                   'reward_mech', 'rew_arrival', 'rew_closing', 'rew_nmac',
                   'rew_large_turnrate', 'rew_large_acc', 'pen_action_heavy')
    env_config = {name: getattr(args, name) for name in passthrough}
    env_config['one_hot'] = bool(args.one_hot)

    env = MultiAircraftEnv(**env_config)
    buffer_len = args.buffer_size
    if buffer_len > 1:
        env = ObservationBuffer(env, buffer_len)

    # Lazy imports keep the unused backend optional.
    if mode == 'rllab':
        from runners.rurllab import RLLabRunner as runner_cls
    elif mode == 'rltools':
        from runners.rurltools import RLToolsRunner as runner_cls
    else:
        raise NotImplementedError()
    run = runner_cls(env, args)

    if args.curriculum:
        run(Curriculum(args.curriculum))
    else:
        run()
Esempio n. 6
0
def main():
    """Evaluate or visualize a saved pursuit-evasion policy.

    The positional ``filename`` is handed to ``FileHandler``, whose
    ``train_args`` supply the environment settings.  A single rectangle map
    is always generated from the stored ``map_size``; the stored map file is
    not loaded.  With ``--evaluate`` the ``Evaluator`` result is printed raw
    and as a table; otherwise a ``Visualizer`` is run (writing to ``--vid``)
    and its reward and info are pretty-printed.
    """
    parser = argparse.ArgumentParser()
    # Snapshot path, e.g. defaultIS.h5/snapshots/iter0000480
    parser.add_argument('filename', type=str)
    parser.add_argument('--vid', type=str, default='/tmp/madrl.mp4')
    for switch in ('--deterministic', '--heuristic', '--evaluate'):
        parser.add_argument(switch, action='store_true', default=False)
    parser.add_argument('--n_trajs', type=int, default=20)
    parser.add_argument('--n_steps', type=int, default=500)
    parser.add_argument('--same_con_pol', action='store_true')
    args = parser.parse_args()

    fh = FileHandler(args.filename)
    train_args = fh.train_args

    map_dims = [int(token) for token in train_args['map_size'].split(',')]
    map_pool = [TwoDMaps.rectangle_map(*map_dims)]

    env_kwargs = {key: train_args[key]
                  for key in ('n_evaders', 'n_pursuers', 'obs_range', 'n_catch')}
    env_kwargs['urgency_reward'] = train_args['urgency']
    for flag in ('surround', 'sample_maps', 'flatten'):
        env_kwargs[flag] = bool(train_args[flag])
    # The reward mechanism is fixed to 'global' here regardless of training.
    env_kwargs['reward_mech'] = 'global'
    env_kwargs['catchr'] = train_args['catchr']
    env_kwargs['term_pursuit'] = train_args['term_pursuit']
    env = PursuitEvade(map_pool, **env_kwargs)

    buffer_len = train_args['buffer_size']
    if buffer_len > 1:
        env = ObservationBuffer(env, buffer_len)

    hpolicy = None
    if args.heuristic:
        from heuristics.pursuit import PursuitHeuristicPolicy
        first_agent = env.agents[0]
        hpolicy = PursuitHeuristicPolicy(first_agent.observation_space,
                                         first_agent.action_space)

    rollout_setup = (env, fh.train_args, args.n_steps, args.n_trajs,
                     args.deterministic)
    if not args.evaluate:
        minion = Visualizer(*(rollout_setup + (fh.mode,)))
        rew, info = minion(fh.filename, file_key=fh.file_key, vid=args.vid)
        pprint.pprint(rew)
        pprint.pprint(info)
        return

    policy_mode = 'heuristic' if args.heuristic else fh.mode
    minion = Evaluator(*(rollout_setup + (policy_mode,)))
    evr = minion(fh.filename,
                 file_key=fh.file_key,
                 same_con_pol=args.same_con_pol,
                 hpolicy=hpolicy)
    from tabulate import tabulate
    print(evr)
    print(tabulate(evr, headers='keys'))
def main():
    """Evaluate or visualize a saved waterworld policy.

    The positional ``filename`` is handed to ``FileHandler``, whose
    ``train_args`` supply the environment settings.  The environment is built
    with a 'global' reward mechanism, zero encounter reward and agent ids
    enabled.  With ``--evaluate`` the ``Evaluator`` result is printed as a
    table; otherwise a ``Visualizer`` is run (writing to ``--vid``) and its
    reward and info are pretty-printed.
    """
    parser = argparse.ArgumentParser()
    # Snapshot path, e.g. defaultIS.h5/snapshots/iter0000480
    parser.add_argument('filename', type=str)
    parser.add_argument('--vid', type=str, default='/tmp/madrl.mp4')
    for switch in ('--deterministic', '--heuristic', '--evaluate'):
        parser.add_argument(switch, action='store_true', default=False)
    parser.add_argument('--n_trajs', type=int, default=10)
    parser.add_argument('--n_steps', type=int, default=500)
    parser.add_argument('--same_con_pol', action='store_true')
    args = parser.parse_args()

    fh = FileHandler(args.filename)
    train_args = fh.train_args

    # The four population counts are positional, in this order.
    counts = [train_args[key]
              for key in ('n_pursuers', 'n_evaders', 'n_coop', 'n_poison')]
    env = MAWaterWorld(
        *counts,
        n_sensors=train_args['n_sensors'],
        food_reward=train_args['food_reward'],
        poison_reward=train_args['poison_reward'],
        reward_mech='global',
        # The stored train_args['encounter_reward'] is deliberately not used.
        encounter_reward=0,
        addid=True
    )

    buffer_len = train_args['buffer_size']
    if buffer_len > 1:
        env = ObservationBuffer(env, buffer_len)

    hpolicy = None
    if args.heuristic:
        from heuristics.waterworld import WaterworldHeuristicPolicy
        first_agent = env.agents[0]
        hpolicy = WaterworldHeuristicPolicy(first_agent.observation_space,
                                            first_agent.action_space)

    policy_mode = 'heuristic' if args.heuristic else fh.mode
    rollout_setup = (env, fh.train_args, args.n_steps, args.n_trajs,
                     args.deterministic, policy_mode)

    if not args.evaluate:
        minion = Visualizer(*rollout_setup)
        rew, info = minion(fh.filename,
                           file_key=fh.file_key,
                           vid=args.vid,
                           hpolicy=hpolicy)
        pprint.pprint(rew)
        pprint.pprint(info)
        return

    minion = Evaluator(*rollout_setup)
    evr = minion(fh.filename,
                 file_key=fh.file_key,
                 same_con_pol=args.same_con_pol,
                 hpolicy=hpolicy)
    from tabulate import tabulate
    print(tabulate(evr, headers='keys'))
Esempio n. 8
0
def main():
    """Evaluate or visualize a saved pursuit-evasion policy.

    The positional ``filename`` is handed to ``FileHandler``, whose
    ``train_args`` supply the environment settings.  The map pool is either
    loaded from the stored map file (looked up by base name relative to the
    current working directory) or regenerated from the stored ``map_type``
    and ``map_size``.

    Evaluation is the default: the ``Evaluator`` result is printed raw and as
    a table.  Pass ``--no-evaluate`` to run the ``Visualizer`` instead
    (writing to ``--vid``) and pretty-print its reward and info.

    Raises:
        NotImplementedError: if no map file was stored and the stored
            ``map_type`` is neither ``'rectangle'`` nor ``'complex'``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('filename', type=str)
    parser.add_argument('--vid', type=str, default='madrl.mp4')
    parser.add_argument('--deterministic', action='store_true', default=False)
    parser.add_argument('--heuristic', action='store_true', default=False)
    # --evaluate defaults to True, so on its own the flag could never be
    # switched off and the visualizer branch below was unreachable.
    # --no-evaluate is the opt-out; existing invocations behave as before.
    parser.add_argument('--evaluate', dest='evaluate', action='store_true',
                        default=True)
    parser.add_argument('--no-evaluate', dest='evaluate', action='store_false')
    parser.add_argument('--n_trajs', type=int, default=20)
    parser.add_argument('--n_steps', type=int, default=20)
    parser.add_argument('--same_con_pol', action='store_true')
    args = parser.parse_args()

    fh = FileHandler(args.filename)
    train_args = fh.train_args

    if train_args['map_file'] is not None:
        # Only the base name is kept, so the map file is resolved relative to
        # the current working directory rather than its training-time path.
        map_pool = np.load(
            os.path.join('', os.path.basename(train_args['map_file'])))
    else:
        # The map type comes from the stored training arguments; this
        # script's own CLI defines no ``map_type`` option.
        map_type = train_args['map_type']
        map_dims = [int(token) for token in train_args['map_size'].split(',')]
        if map_type == 'rectangle':
            env_map = TwoDMaps.rectangle_map(*map_dims)
        elif map_type == 'complex':
            env_map = TwoDMaps.complex_map(*map_dims)
        else:
            raise NotImplementedError()
        map_pool = [env_map]

    env = PursuitEvade(map_pool, n_evaders=train_args['n_evaders'],
                       n_pursuers=train_args['n_pursuers'],
                       obs_range=train_args['obs_range'],
                       n_catch=train_args['n_catch'],
                       urgency_reward=train_args['urgency'],
                       surround=bool(train_args['surround']),
                       sample_maps=bool(train_args['sample_maps']),
                       flatten=bool(train_args['flatten']),
                       # Reward mechanism is fixed to 'global' here.
                       reward_mech='global',
                       catchr=train_args['catchr'],
                       term_pursuit=train_args['term_pursuit'])

    if train_args['buffer_size'] > 1:
        env = ObservationBuffer(env, train_args['buffer_size'])

    # NOTE(review): --heuristic switches the evaluator mode to 'heuristic'
    # but no heuristic policy object is constructed here - confirm intended.
    hpolicy = None
    if args.evaluate:
        minion = Evaluator(env, fh.train_args, args.n_steps, args.n_trajs,
                           args.deterministic,
                           'heuristic' if args.heuristic else fh.mode)
        evr = minion(fh.filename, file_key=fh.file_key,
                     same_con_pol=args.same_con_pol, hpolicy=hpolicy)
        print(evr)
        print(tabulate(evr, headers='keys'))
    else:
        minion = Visualizer(env, fh.train_args, args.n_steps, args.n_trajs,
                            args.deterministic, fh.mode)
        rew, info = minion(fh.filename, file_key=fh.file_key, vid=args.vid)
        pprint.pprint(rew)
        pprint.pprint(info)
def main():
    """Evaluate or visualize a saved multi-ant policy.

    The positional ``filename`` is handed to ``FileHandler``, whose
    ``train_args`` supply the environment settings.  With ``--evaluate`` the
    ``Evaluator`` result is pickled to ``--save_file`` when one is given; it
    is not printed.  Otherwise a ``Visualizer`` is run (writing to ``--vid``)
    and its reward and info are pretty-printed.
    """
    parser = argparse.ArgumentParser()
    # Snapshot path, e.g. defaultIS.h5/snapshots/iter0000480
    parser.add_argument('filename', type=str)
    parser.add_argument('--vid', type=str, default='/tmp/madrl.mp4')
    parser.add_argument('--deterministic', action='store_true', default=False)
    parser.add_argument('--heuristic', action='store_true', default=False)
    parser.add_argument('--evaluate', action='store_true', default=False)
    parser.add_argument('--save_file', type=str, default=None)
    parser.add_argument('--n_trajs', type=int, default=20)
    parser.add_argument('--n_steps', type=int, default=500)
    parser.add_argument('--same_con_pol', action='store_true')
    args = parser.parse_args()

    fh = FileHandler(args.filename)
    train_args = fh.train_args

    env = MultiAnt(n_legs=train_args['n_legs'],
                   ts=train_args['ts'],
                   integrator=train_args['integrator'],
                   out_file=train_args['out_file'],
                   base_file=train_args['base_file'],
                   reward_mech=train_args['reward_mech'])

    if train_args['buffer_size'] > 1:
        env = ObservationBuffer(env, train_args['buffer_size'])

    # NOTE(review): --heuristic switches the evaluator mode to 'heuristic'
    # but no heuristic policy object is constructed here - confirm intended.
    hpolicy = None
    if args.evaluate:
        minion = Evaluator(env, fh.train_args, args.n_steps, args.n_trajs,
                           args.deterministic,
                           'heuristic' if args.heuristic else fh.mode)
        evr = minion(fh.filename,
                     file_key=fh.file_key,
                     same_con_pol=args.same_con_pol,
                     hpolicy=hpolicy)
        if args.save_file:
            # Context manager guarantees the file is flushed and closed even
            # if pickling fails part-way through.
            with open(args.save_file, "wb") as out_f:
                pickle.dump(evr, out_f)
    else:
        minion = Visualizer(env, fh.train_args, args.n_steps, args.n_trajs,
                            args.deterministic, fh.mode)
        rew, info = minion(fh.filename, file_key=fh.file_key, vid=args.vid)
        pprint.pprint(rew)
        pprint.pprint(info)
Esempio n. 10
0
def main():
    """Evaluate or visualize a saved multi-walker policy.

    The positional ``filename`` is handed to ``FileHandler``, whose
    ``train_args`` supply the environment settings; the reward mechanism is
    fixed to 'global'.  With ``--evaluate`` the ``Evaluator`` result is
    printed as a table; otherwise a ``Visualizer`` is run (writing to
    ``--vid``) and its reward and info are pretty-printed.
    """
    parser = argparse.ArgumentParser()
    # Snapshot path, e.g. defaultIS.h5/snapshots/iter0000480
    parser.add_argument('filename', type=str)
    parser.add_argument('--vid', type=str, default='/tmp/madrl.mp4')
    for switch in ('--deterministic', '--heuristic', '--evaluate'):
        parser.add_argument(switch, action='store_true', default=False)
    parser.add_argument('--n_trajs', type=int, default=20)
    parser.add_argument('--n_steps', type=int, default=500)
    parser.add_argument('--same_con_pol', action='store_true')
    args = parser.parse_args()

    fh = FileHandler(args.filename)
    train_args = fh.train_args

    # The stored train_args['reward_mech'] is deliberately not used here.
    walker_setup = [train_args[key]
                    for key in ('n_walkers', 'position_noise', 'angle_noise')]
    env = MultiWalkerEnv(*walker_setup, reward_mech='global')

    buffer_len = train_args['buffer_size']
    if buffer_len > 1:
        env = ObservationBuffer(env, buffer_len)

    hpolicy = None
    if args.heuristic:
        from heuristics.multiwalker import MultiwalkerHeuristicPolicy
        first_agent = env.agents[0]
        hpolicy = MultiwalkerHeuristicPolicy(first_agent.observation_space,
                                             first_agent.action_space)

    rollout_setup = (env, fh.train_args, args.n_steps, args.n_trajs,
                     args.deterministic)
    if not args.evaluate:
        minion = Visualizer(*(rollout_setup + (fh.mode,)))
        rew, info = minion(fh.filename, file_key=fh.file_key, vid=args.vid)
        pprint.pprint(rew)
        pprint.pprint(info)
        return

    policy_mode = 'heuristic' if args.heuristic else fh.mode
    minion = Evaluator(*(rollout_setup + (policy_mode,)))
    evr = minion(fh.filename,
                 file_key=fh.file_key,
                 same_con_pol=args.same_con_pol,
                 hpolicy=hpolicy)
    from tabulate import tabulate
    print(tabulate(evr, headers='keys'))
Esempio n. 11
0
def main(parser):
    """Build a ``PursuitEvade`` environment from parsed options and run training.

    ``parser`` must expose ``_mode`` (``'rllab'`` or ``'rltools'``) and
    ``args`` (the parsed option namespace).  The map pool is loaded with
    ``np.load`` when ``args.map_file`` is set; otherwise a single map is
    generated from ``args.map_type`` and the comma-separated integers in
    ``args.map_size``.  The environment is wrapped in an
    ``ObservationBuffer`` when ``args.buffer_size`` exceeds 1.

    Raises:
        NotImplementedError: for an unsupported map type or mode.
    """
    backend, opts = parser._mode, parser.args

    if opts.map_file:
        map_pool = np.load(opts.map_file)
    else:
        if opts.map_type == 'rectangle':
            make_map = TwoDMaps.rectangle_map
        elif opts.map_type == 'complex':
            make_map = TwoDMaps.complex_map
        else:
            raise NotImplementedError()
        map_dims = [int(token) for token in opts.map_size.split(',')]
        map_pool = [make_map(*map_dims)]

    pursuit_kwargs = {
        'n_evaders': opts.n_evaders,
        'n_pursuers': opts.n_pursuers,
        'obs_range': opts.obs_range,
        'n_catch': opts.n_catch,
        'urgency_reward': opts.urgency,
        'surround': bool(opts.surround),
        'sample_maps': bool(opts.sample_maps),
        'flatten': bool(opts.flatten),
        'reward_mech': opts.reward_mech,
        'catchr': opts.catchr,
        'term_pursuit': opts.term_pursuit,
    }
    env = PursuitEvade(map_pool, **pursuit_kwargs)

    if opts.buffer_size > 1:
        env = ObservationBuffer(env, opts.buffer_size)

    # Lazy imports keep the unused backend optional.
    if backend == 'rllab':
        from runners.rurllab import RLLabRunner as runner_cls
    elif backend == 'rltools':
        from runners.rurltools import RLToolsRunner as runner_cls
    else:
        raise NotImplementedError()

    runner_cls(env, opts)()
Esempio n. 12
0
def main(parser):
    """Build a ``sniper`` environment from parsed options and run training.

    ``parser`` must expose ``_mode`` (``'rllab'`` or ``'rltools'``) and
    ``args`` (the parsed option namespace).  The map pool is loaded with
    ``np.load`` when ``args.map_file`` is set; otherwise it is a one-element
    list holding a map generated from ``args.map_type`` and the
    comma-separated integers in ``args.map_size``.  The environment is
    wrapped in an ``ObservationBuffer`` when ``args.buffer_size`` exceeds 1.

    Raises:
        NotImplementedError: for an unsupported map type or mode.
    """
    backend, opts = parser._mode, parser.args

    if opts.map_file:
        map_pool = np.load(opts.map_file)
    else:
        if opts.map_type == 'rectangle':
            make_map = TwoDMaps.rectangle_map
        elif opts.map_type == 'complex':
            make_map = TwoDMaps.complex_map
        else:
            raise NotImplementedError()
        # The map size option is a comma-separated list of integers that is
        # unpacked into the map constructor's positional arguments.
        map_dims = [int(token) for token in opts.map_size.split(',')]
        # The environment takes a list of maps; here it holds just one.
        map_pool = [make_map(*map_dims)]

    sniper_kwargs = {
        'n_targets': opts.n_targets,
        'n_snipers': opts.n_snipers,
        'obs_range': opts.obs_range,
        'n_catch': opts.n_catch,
        'urgency_reward': opts.urgency,
        'surround': bool(opts.surround),
        'sample_maps': bool(opts.sample_maps),
        'flatten': bool(opts.flatten),
        'reward_mech': opts.reward_mech,
        'catchr': opts.catchr,
        'term_sniper': opts.term_sniper,
    }
    env = sniper(map_pool, **sniper_kwargs)

    if opts.buffer_size > 1:
        env = ObservationBuffer(env, opts.buffer_size)

    # Lazy imports keep the unused backend optional.
    if backend == 'rllab':
        from runners.rurllab import RLLabRunner as runner_cls
    elif backend == 'rltools':
        from runners.rurltools import RLToolsRunner as runner_cls
    else:
        raise NotImplementedError()

    runner_cls(env, opts)()
Esempio n. 13
0
def main(parser):
    """Build a ``ContinuousHostageWorld`` from parsed options and run training.

    ``parser`` must expose ``_mode`` (``'rllab'`` or ``'rltools'``) and
    ``args`` (the parsed option namespace).  Every environment keyword is
    taken from the option of the same name.  The environment is wrapped in
    an ``ObservationBuffer`` when ``args.buffer_size`` exceeds 1.

    Raises:
        NotImplementedError: if the mode is neither ``'rllab'`` nor
            ``'rltools'``.
    """
    backend, opts = parser._mode, parser.args

    forwarded = (
        'n_good', 'n_hostages', 'n_bad', 'n_coop_save', 'n_coop_avoid',
        'radius', 'key_loc', 'bad_speed', 'n_sensors', 'sensor_range',
        'save_reward', 'hit_reward', 'encounter_reward', 'bomb_reward',
        'bomb_radius', 'control_penalty', 'reward_mech',
    )
    world_kwargs = {name: getattr(opts, name) for name in forwarded}
    env = ContinuousHostageWorld(**world_kwargs)

    buffer_len = opts.buffer_size
    if buffer_len > 1:
        env = ObservationBuffer(env, buffer_len)

    # Lazy imports keep the unused backend optional.
    if backend == 'rllab':
        from runners.rurllab import RLLabRunner as runner_cls
    elif backend == 'rltools':
        from runners.rurltools import RLToolsRunner as runner_cls
    else:
        raise NotImplementedError()

    runner_cls(env, opts)()
def main():
    """Train a TRPO policy on the ``MAWaterWorld`` environment with rltools.

    Parses the command line, builds the environment (optionally wrapped in
    an ``ObservationBuffer``), a Gaussian MLP or GRU policy, a baseline and a
    sampler, then runs ``SamplingPolicyOptimizer.train`` inside a TensorFlow
    session.

    ``--sensor_range`` is a comma-separated list of floats: a single value is
    passed to the environment as a scalar, several values must number
    exactly ``--n_pursuers``.

    Raises:
        NotImplementedError: for an unknown ``--sampler``, or for the
            'simple' sampler with a ``--control`` other than 'centralized'
            or 'decentralized'.
        AssertionError: if several sensor ranges are given and their count
            differs from ``--n_pursuers``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--discount', type=float, default=0.95)
    parser.add_argument('--gae_lambda', type=float, default=0.99)

    parser.add_argument('--n_iter', type=int, default=250)
    parser.add_argument('--sampler', type=str, default='simple')
    parser.add_argument('--sampler_workers', type=int, default=4)
    parser.add_argument('--max_traj_len', type=int, default=500)
    parser.add_argument('--adaptive_batch', action='store_true', default=False)

    parser.add_argument('--n_timesteps', type=int, default=8000)
    parser.add_argument('--n_timesteps_min', type=int, default=1000)
    parser.add_argument('--n_timesteps_max', type=int, default=64000)
    parser.add_argument('--timestep_rate', type=int, default=20)

    parser.add_argument('--is_n_backtrack', type=int, default=1)
    parser.add_argument('--is_randomize_draw',
                        action='store_true',
                        default=False)
    parser.add_argument('--is_n_pretrain', type=int, default=0)
    parser.add_argument('--is_skip_is', action='store_true', default=False)
    parser.add_argument('--is_max_is_ratio', type=float, default=0)

    parser.add_argument('--control', type=str, default='centralized')
    parser.add_argument('--buffer_size', type=int, default=1)
    parser.add_argument('--radius', type=float, default=0.015)
    parser.add_argument('--n_evaders', type=int, default=10)
    parser.add_argument('--n_pursuers', type=int, default=8)
    parser.add_argument('--n_poison', type=int, default=10)
    parser.add_argument('--n_coop', type=int, default=4)
    parser.add_argument('--n_sensors', type=int, default=30)
    parser.add_argument('--sensor_range', type=str, default='0.2')
    parser.add_argument('--food_reward', type=float, default=5)
    parser.add_argument('--poison_reward', type=float, default=-1)
    parser.add_argument('--encounter_reward', type=float, default=0.05)
    parser.add_argument('--reward_mech', type=str, default='local')

    parser.add_argument('--recurrent', action='store_true', default=False)
    parser.add_argument('--policy_hidden_spec', type=str, default='GAE_ARCH')

    parser.add_argument('--baseline_type', type=str, default='mlp')
    parser.add_argument('--baseline_hidden_spec', type=str, default='GAE_ARCH')

    parser.add_argument('--max_kl', type=float, default=0.01)
    parser.add_argument('--vf_max_kl', type=float, default=0.01)
    parser.add_argument('--vf_cg_damping', type=float, default=0.01)

    parser.add_argument('--save_freq', type=int, default=20)
    parser.add_argument('--log', type=str, required=False)
    parser.add_argument('--tblog',
                        type=str,
                        default='/tmp/madrl_tb_{}'.format(uuid.uuid4()))
    parser.add_argument('--debug', dest='debug', action='store_true')
    parser.add_argument('--no-debug', dest='debug', action='store_false')
    parser.set_defaults(debug=True)

    args = parser.parse_args()

    centralized = args.control == 'centralized'
    # --recurrent overrides whatever policy architecture was requested.
    if args.recurrent:
        args.policy_hidden_spec = 'SIMPLE_GRU_ARCH'
    args.policy_hidden_spec = get_arch(args.policy_hidden_spec)
    args.baseline_hidden_spec = get_arch(args.baseline_hidden_spec)

    # Materialize the floats into a list first: on Python 3 ``map`` returns
    # an iterator, and np.array(iterator) yields a 0-d object array on which
    # len() fails.
    sensor_range = np.array(
        [float(token) for token in args.sensor_range.split(',')])
    if len(sensor_range) == 1:
        sensor_range = sensor_range[0]
    else:
        assert sensor_range.shape == (args.n_pursuers, ), (
            'expected 1 or {} sensor ranges, got {}'.format(
                args.n_pursuers, len(sensor_range)))

    env = MAWaterWorld(args.n_pursuers,
                       args.n_evaders,
                       args.n_coop,
                       args.n_poison,
                       radius=args.radius,
                       n_sensors=args.n_sensors,
                       food_reward=args.food_reward,
                       poison_reward=args.poison_reward,
                       encounter_reward=args.encounter_reward,
                       reward_mech=args.reward_mech,
                       sensor_range=sensor_range,
                       obstacle_loc=None)

    if args.buffer_size > 1:
        env = ObservationBuffer(env, args.buffer_size)

    if centralized:
        # One joint space: the first agent's bounds, with the first
        # dimension scaled by the number of agents.
        first_agent = env.agents[0]
        n_agents = len(env.agents)
        obsfeat_space = spaces.Box(
            low=first_agent.observation_space.low[0],
            high=first_agent.observation_space.high[0],
            shape=(first_agent.observation_space.shape[0] *
                   n_agents, ))  # XXX
        action_space = spaces.Box(
            low=first_agent.action_space.low[0],
            high=first_agent.action_space.high[0],
            shape=(first_agent.action_space.shape[0] * n_agents, ))  # XXX
    else:
        obsfeat_space = env.agents[0].observation_space
        action_space = env.agents[0].action_space

    if args.recurrent:
        policy = GaussianGRUPolicy(obsfeat_space,
                                   action_space,
                                   hidden_spec=args.policy_hidden_spec,
                                   min_stdev=0.,
                                   init_logstdev=0.,
                                   enable_obsnorm=False,
                                   state_include_action=False,
                                   tblog=args.tblog,
                                   varscope_name='gaussgru_policy')
    else:
        policy = GaussianMLPPolicy(obsfeat_space,
                                   action_space,
                                   hidden_spec=args.policy_hidden_spec,
                                   enable_obsnorm=True,
                                   min_stdev=0.,
                                   init_logstdev=0.,
                                   tblog=args.tblog,
                                   varscope_name='gaussmlp_policy')

    if args.baseline_type == 'linear':
        baseline = LinearFeatureBaseline(
            obsfeat_space,
            enable_obsnorm=False,
            varscope_name='pursuit_linear_baseline')
    elif args.baseline_type == 'mlp':
        baseline = MLPBaseline(obsfeat_space,
                               args.baseline_hidden_spec,
                               enable_obsnorm=True,
                               enable_vnorm=True,
                               max_kl=args.vf_max_kl,
                               damping=args.vf_cg_damping,
                               time_scale=1. / args.max_traj_len,
                               varscope_name='pursuit_mlp_baseline')
    else:
        # Any other baseline type falls back to the zero baseline.
        baseline = ZeroBaseline(obsfeat_space)

    # Keyword arguments shared by every sampler; sampler-specific extras are
    # layered on top below.
    common_sampler_args = dict(max_traj_len=args.max_traj_len,
                               n_timesteps=args.n_timesteps,
                               n_timesteps_min=args.n_timesteps_min,
                               n_timesteps_max=args.n_timesteps_max,
                               timestep_rate=args.timestep_rate,
                               adaptive=args.adaptive_batch)
    if args.sampler == 'simple':
        if centralized:
            sampler_cls = SimpleSampler
        elif args.control == 'decentralized':
            sampler_cls = DecSampler
        else:
            raise NotImplementedError()
        sampler_args = dict(common_sampler_args)
    elif args.sampler == 'parallel':
        sampler_cls = ParallelSampler
        sampler_args = dict(common_sampler_args,
                            n_workers=args.sampler_workers,
                            mode=args.control)
    elif args.sampler == 'imp':
        sampler_cls = ImportanceWeightedSampler
        sampler_args = dict(common_sampler_args,
                            n_backtrack=args.is_n_backtrack,
                            randomize_draw=args.is_randomize_draw,
                            n_pretrain=args.is_n_pretrain,
                            skip_is=args.is_skip_is,
                            max_is_ratio=args.is_max_is_ratio)
    else:
        raise NotImplementedError()

    step_func = rltools.algos.policyopt.TRPO(max_kl=args.max_kl)
    popt = rltools.algos.policyopt.SamplingPolicyOptimizer(
        env=env,
        policy=policy,
        baseline=baseline,
        step_func=step_func,
        discount=args.discount,
        gae_lambda=args.gae_lambda,
        sampler_cls=sampler_cls,
        sampler_args=sampler_args,
        n_iter=args.n_iter)

    # The full argument set is echoed and stored alongside the training log.
    argstr = json.dumps(vars(args), separators=(',', ':'), indent=2)
    rltools.util.header(argstr)
    log_f = rltools.log.TrainingLog(args.log, [('args', argstr)],
                                    debug=args.debug)

    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        popt.train(sess, log_f, args.save_freq)
def main():
    """Train per-agent Gaussian MLP policies on MAWaterWorld concurrently.

    Parses the command line, builds the environment plus one policy and one
    baseline per agent, and runs
    ``rltools.algos.policyopt.ConcurrentPolicyOptimizer`` with a TRPO step
    function inside a TensorFlow session.

    Raises:
        NotImplementedError: if ``--sampler`` is anything but ``'parallel'``.
        SystemExit: raised through ``parser.error`` when ``--sensor_range`` or
            ``--policy_avg_weights`` does not hold one value per pursuer.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--discount', type=float, default=0.95)
    parser.add_argument('--gae_lambda', type=float, default=0.99)

    parser.add_argument('--interp_alpha', type=float, default=0.5)
    parser.add_argument('--policy_avg_weights',
                        type=str,
                        default='0.33,0.33,0.33')

    parser.add_argument('--n_iter', type=int, default=250)
    # Only the parallel sampler is wired up below, so it is the default;
    # the previous default ('simple') always ended in NotImplementedError.
    parser.add_argument('--sampler', type=str, default='parallel')
    parser.add_argument('--sampler_workers', type=int, default=4)
    parser.add_argument('--max_traj_len', type=int, default=500)
    parser.add_argument('--adaptive_batch', action='store_true', default=False)

    parser.add_argument('--n_timesteps', type=int, default=8000)
    parser.add_argument('--n_timesteps_min', type=int, default=1000)
    parser.add_argument('--n_timesteps_max', type=int, default=64000)
    parser.add_argument('--timestep_rate', type=int, default=20)

    parser.add_argument('--is_n_backtrack', type=int, default=1)
    parser.add_argument('--is_randomize_draw',
                        action='store_true',
                        default=False)
    parser.add_argument('--is_n_pretrain', type=int, default=0)
    parser.add_argument('--is_skip_is', action='store_true', default=False)
    parser.add_argument('--is_max_is_ratio', type=float, default=0)

    parser.add_argument('--buffer_size', type=int, default=1)
    parser.add_argument('--n_evaders', type=int, default=5)
    parser.add_argument('--n_pursuers', type=int, default=3)
    parser.add_argument('--n_poison', type=int, default=10)
    parser.add_argument('--n_coop', type=int, default=2)
    parser.add_argument('--n_sensors', type=int, default=30)
    parser.add_argument('--sensor_range', type=str, default='0.2,0.2,0.2')
    parser.add_argument('--food_reward', type=float, default=3)
    parser.add_argument('--poison_reward', type=float, default=-1)
    parser.add_argument('--encounter_reward', type=float, default=0.05)

    parser.add_argument('--policy_hidden_spec', type=str, default=GAE_ARCH)
    parser.add_argument('--blend_freq', type=int, default=20)

    parser.add_argument('--baseline_type', type=str, default='mlp')
    parser.add_argument('--baseline_hidden_spec', type=str, default=GAE_ARCH)

    parser.add_argument('--max_kl', type=float, default=0.01)
    parser.add_argument('--vf_max_kl', type=float, default=0.01)
    parser.add_argument('--vf_cg_damping', type=float, default=0.01)

    parser.add_argument('--save_freq', type=int, default=20)
    parser.add_argument('--log', type=str, required=False)
    parser.add_argument('--tblog', type=str, default='/tmp/madrl_tb')
    parser.add_argument('--debug', dest='debug', action='store_true')
    parser.add_argument('--no-debug', dest='debug', action='store_false')
    parser.set_defaults(debug=True)

    args = parser.parse_args()

    # Build real lists before handing them to NumPy: on Python 3 ``map``
    # returns an iterator, which ``np.array`` would wrap in a 0-d object
    # array instead of producing a 1-D float array.
    sensor_range = np.array([float(v) for v in args.sensor_range.split(',')])
    if sensor_range.shape != (args.n_pursuers, ):
        parser.error(
            '--sensor_range needs exactly {} comma-separated values'.format(
                args.n_pursuers))

    policy_avg_weights = np.array(
        [float(v) for v in args.policy_avg_weights.split(',')])
    if len(policy_avg_weights) != args.n_pursuers:
        parser.error(
            '--policy_avg_weights needs exactly {} comma-separated values'.
            format(args.n_pursuers))

    env = MAWaterWorld(args.n_pursuers,
                       args.n_evaders,
                       args.n_coop,
                       args.n_poison,
                       n_sensors=args.n_sensors,
                       food_reward=args.food_reward,
                       poison_reward=args.poison_reward,
                       encounter_reward=args.encounter_reward,
                       sensor_range=sensor_range,
                       obstacle_loc=None)

    if args.buffer_size > 1:
        env = ObservationBuffer(env, args.buffer_size)

    # One independently-scoped policy per agent.
    policies = [
        GaussianMLPPolicy(agent.observation_space,
                          agent.action_space,
                          hidden_spec=args.policy_hidden_spec,
                          enable_obsnorm=True,
                          min_stdev=0.,
                          init_logstdev=0.,
                          tblog=args.tblog,
                          varscope_name='gaussmlp_policy_{}'.format(agid))
        for agid, agent in enumerate(env.agents)
    ]

    if args.blend_freq:
        # The target policy is built from agent 0's spaces, so every agent
        # must share that observation space.
        assert all(agent.observation_space == env.agents[0].observation_space
                   for agent in env.agents)
        target_policy = GaussianMLPPolicy(
            env.agents[0].observation_space,
            env.agents[0].action_space,
            hidden_spec=args.policy_hidden_spec,
            enable_obsnorm=True,
            min_stdev=0.,
            init_logstdev=0.,
            tblog=args.tblog,
            varscope_name='targetgaussmlp_policy')
    else:
        target_policy = None

    # One baseline per agent; any unrecognised type falls back to zero.
    if args.baseline_type == 'linear':
        baselines = [
            LinearFeatureBaseline(
                agent.observation_space,
                enable_obsnorm=True,
                varscope_name='linear_baseline_{}'.format(agid))
            for agid, agent in enumerate(env.agents)
        ]
    elif args.baseline_type == 'mlp':
        baselines = [
            MLPBaseline(agent.observation_space,
                        args.baseline_hidden_spec,
                        enable_obsnorm=True,
                        enable_vnorm=True,
                        max_kl=args.vf_max_kl,
                        damping=args.vf_cg_damping,
                        time_scale=1. / args.max_traj_len,
                        varscope_name='mlp_baseline_{}'.format(agid))
            for agid, agent in enumerate(env.agents)
        ]
    else:
        baselines = [
            ZeroBaseline(agent.observation_space) for agent in env.agents
        ]

    if args.sampler == 'parallel':
        sampler_cls = ParallelSampler
        sampler_args = dict(max_traj_len=args.max_traj_len,
                            n_timesteps=args.n_timesteps,
                            n_timesteps_min=args.n_timesteps_min,
                            n_timesteps_max=args.n_timesteps_max,
                            timestep_rate=args.timestep_rate,
                            adaptive=args.adaptive_batch,
                            enable_rewnorm=True,
                            n_workers=args.sampler_workers,
                            mode='concurrent')

    else:
        raise NotImplementedError(
            "unsupported sampler {!r}; only 'parallel' is implemented".format(
                args.sampler))

    step_func = rltools.algos.policyopt.TRPO(max_kl=args.max_kl)

    popt = rltools.algos.policyopt.ConcurrentPolicyOptimizer(
        env=env,
        policies=policies,
        baselines=baselines,
        step_func=step_func,
        discount=args.discount,
        gae_lambda=args.gae_lambda,
        sampler_cls=sampler_cls,
        sampler_args=sampler_args,
        n_iter=args.n_iter,
        target_policy=target_policy,
        weights=policy_avg_weights,
        interp_alpha=args.interp_alpha)

    # Record the full argument set as the header of the training log.
    argstr = json.dumps(vars(args), separators=(',', ':'), indent=2)
    rltools.util.header(argstr)
    log_f = rltools.log.TrainingLog(args.log, [('args', argstr)],
                                    debug=args.debug)

    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        popt.train(sess, log_f, args.blend_freq, args.save_freq)
Esempio n. 16
0
def main():
    """Train one Gaussian MLP policy on ContinuousHostageWorld with TRPO.

    Reads hyper-parameters from the command line, builds the environment,
    policy, baseline and sampler configuration, then runs
    ``rltools.algos.policyopt.SamplingPolicyOptimizer`` inside a TensorFlow
    session.

    Raises:
        NotImplementedError: for an unknown ``--sampler``, or when the
            ``simple`` sampler is combined with a ``--control`` value other
            than ``centralized`` / ``decentralized``.
    """
    ap = argparse.ArgumentParser()
    # (flag, add_argument keyword options), registered in this order.
    option_table = [
        ('--discount', dict(type=float, default=0.95)),
        ('--gae_lambda', dict(type=float, default=0.99)),
        ('--n_iter', dict(type=int, default=500)),
        ('--sampler', dict(type=str, default='simple')),
        ('--sampler_workers', dict(type=int, default=2)),
        ('--max_traj_len', dict(type=int, default=1500)),
        ('--adaptive_batch', dict(action='store_true', default=False)),
        ('--n_timesteps', dict(type=int, default=8000)),
        ('--n_timesteps_min', dict(type=int, default=1000)),
        ('--n_timesteps_max', dict(type=int, default=64000)),
        ('--timestep_rate', dict(type=int, default=20)),
        ('--is_n_backtrack', dict(type=int, default=1)),
        ('--is_randomize_draw', dict(action='store_true', default=False)),
        ('--is_n_pretrain', dict(type=int, default=0)),
        ('--is_skip_is', dict(action='store_true', default=False)),
        ('--is_max_is_ratio', dict(type=float, default=0)),
        ('--control', dict(type=str, default='centralized')),
        ('--buffer_size', dict(type=int, default=1)),
        ('--n_good', dict(type=int, default=3)),
        ('--n_hostage', dict(type=int, default=5)),
        ('--n_bad', dict(type=int, default=5)),
        ('--n_coop_save', dict(type=int, default=2)),
        ('--n_coop_avoid', dict(type=int, default=2)),
        ('--n_sensors', dict(type=int, default=20)),
        ('--sensor_range', dict(type=float, default=0.2)),
        ('--save_reward', dict(type=float, default=3)),
        ('--hit_reward', dict(type=float, default=-1)),
        ('--encounter_reward', dict(type=float, default=0.01)),
        ('--bomb_reward', dict(type=float, default=-10.)),
        ('--policy_hidden_spec', dict(type=str, default='GAE_ARCH')),
        ('--min_std', dict(type=float, default=0)),
        ('--baseline_type', dict(type=str, default='mlp')),
        ('--baseline_hidden_spec', dict(type=str, default='GAE_ARCH')),
        ('--max_kl', dict(type=float, default=0.01)),
        ('--vf_max_kl', dict(type=float, default=0.01)),
        ('--vf_cg_damping', dict(type=float, default=0.01)),
        ('--save_freq', dict(type=int, default=20)),
        ('--log', dict(type=str, required=False)),
        ('--tblog',
         dict(type=str, default='/tmp/madrl_tb_{}'.format(uuid.uuid4()))),
        ('--debug', dict(dest='debug', action='store_true')),
        ('--no-debug', dict(dest='debug', action='store_false')),
    ]
    for flag, flag_opts in option_table:
        ap.add_argument(flag, **flag_opts)
    ap.set_defaults(debug=True)

    cli = ap.parse_args()

    # Swap the architecture names for whatever get_arch resolves them to.
    cli.policy_hidden_spec = get_arch(cli.policy_hidden_spec)
    cli.baseline_hidden_spec = get_arch(cli.baseline_hidden_spec)

    centralized = cli.control == 'centralized'

    env = ContinuousHostageWorld(cli.n_good,
                                 cli.n_hostage,
                                 cli.n_bad,
                                 cli.n_coop_save,
                                 cli.n_coop_avoid,
                                 n_sensors=cli.n_sensors,
                                 sensor_range=cli.sensor_range,
                                 save_reward=cli.save_reward,
                                 hit_reward=cli.hit_reward,
                                 encounter_reward=cli.encounter_reward,
                                 bomb_reward=cli.bomb_reward)
    if cli.buffer_size > 1:
        env = ObservationBuffer(env, cli.buffer_size)

    agent_obs = env.agents[0].observation_space
    agent_act = env.agents[0].action_space
    if not centralized:
        obsfeat_space = agent_obs
        action_space = agent_act
    else:
        # Joint spaces: agent 0's bounds, with the first dimension scaled by
        # the number of agents.  # XXX
        n_agents = len(env.agents)
        obsfeat_space = spaces.Box(low=agent_obs.low[0],
                                   high=agent_obs.high[0],
                                   shape=(agent_obs.shape[0] * n_agents, ))
        action_space = spaces.Box(low=agent_act.low[0],
                                  high=agent_act.high[0],
                                  shape=(agent_act.shape[0] * n_agents, ))

    policy = GaussianMLPPolicy(obsfeat_space,
                               action_space,
                               hidden_spec=cli.policy_hidden_spec,
                               enable_obsnorm=True,
                               min_stdev=cli.min_std,
                               init_logstdev=0.,
                               tblog=cli.tblog,
                               varscope_name='gaussmlp_policy')

    # Any baseline type other than 'linear' / 'mlp' means "no baseline".
    if cli.baseline_type == 'linear':
        baseline = LinearFeatureBaseline(
            obsfeat_space,
            enable_obsnorm=True,
            varscope_name='pursuit_linear_baseline')
    elif cli.baseline_type == 'mlp':
        baseline = MLPBaseline(obsfeat_space,
                               cli.baseline_hidden_spec,
                               enable_obsnorm=True,
                               enable_vnorm=True,
                               max_kl=cli.vf_max_kl,
                               damping=cli.vf_cg_damping,
                               time_scale=1. / cli.max_traj_len,
                               varscope_name='pursuit_mlp_baseline')
    else:
        baseline = ZeroBaseline(obsfeat_space)

    # Batch-size settings shared by every sampler; each branch appends its
    # own extras after them.
    batch_args = dict(max_traj_len=cli.max_traj_len,
                      n_timesteps=cli.n_timesteps,
                      n_timesteps_min=cli.n_timesteps_min,
                      n_timesteps_max=cli.n_timesteps_max,
                      timestep_rate=cli.timestep_rate,
                      adaptive=cli.adaptive_batch)

    if cli.sampler == 'simple':
        if centralized:
            sampler_cls = SimpleSampler
        elif cli.control == 'decentralized':
            sampler_cls = DecSampler
        else:
            raise NotImplementedError()
        sampler_args = dict(batch_args, enable_rewnorm=True)
    elif cli.sampler == 'thread':
        sampler_cls = ThreadedSampler
        sampler_args = dict(batch_args, enable_rewnorm=True)
    elif cli.sampler == 'parallel':
        sampler_cls = ParallelSampler
        sampler_args = dict(batch_args,
                            n_workers=cli.sampler_workers,
                            mode=cli.control)
    elif cli.sampler == 'imp':
        sampler_cls = ImportanceWeightedSampler
        sampler_args = dict(batch_args,
                            enable_rewnorm=True,
                            n_backtrack=cli.is_n_backtrack,
                            randomize_draw=cli.is_randomize_draw,
                            n_pretrain=cli.is_n_pretrain,
                            skip_is=cli.is_skip_is,
                            max_is_ratio=cli.is_max_is_ratio)
    else:
        raise NotImplementedError()

    popt = rltools.algos.policyopt.SamplingPolicyOptimizer(
        env=env,
        policy=policy,
        baseline=baseline,
        step_func=rltools.algos.policyopt.TRPO(max_kl=cli.max_kl),
        discount=cli.discount,
        gae_lambda=cli.gae_lambda,
        sampler_cls=sampler_cls,
        sampler_args=sampler_args,
        n_iter=cli.n_iter)

    # The serialized arguments go both to the console header and the log.
    argstr = json.dumps(vars(cli), separators=(',', ':'), indent=2)
    rltools.util.header(argstr)
    log_f = rltools.log.TrainingLog(cli.log, [('args', argstr)],
                                    debug=cli.debug)

    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        popt.train(sess, log_f, cli.save_freq)
Esempio n. 17
0
def main():
    """Load a trained ContinuousHostageWorld policy and animate it.

    The positional ``filename`` argument is split by
    ``rltools.util.split_h5_name`` into an HDF5 file and a key inside it
    (e.g. ``defaultIS.h5/snapshots/iter0000480``). The training arguments
    stored in the file's ``args`` attribute are used to rebuild the
    environment and policy before the saved parameters are loaded.
    """
    cli_parser = argparse.ArgumentParser()
    cli_parser.add_argument('filename',
                            type=str)  # defaultIS.h5/snapshots/iter0000480
    cli_parser.add_argument('--vid', type=str, default='/tmp/madrl.mp4')
    cli_parser.add_argument('--deterministic',
                            action='store_true',
                            default=False)
    cli_parser.add_argument('--n_steps', type=int, default=1000)
    args = cli_parser.parse_args()

    # Recover the training configuration saved alongside the snapshot.
    filename, file_key = rltools.util.split_h5_name(args.filename)
    print('Loading parameters from {} in {}'.format(file_key, filename))
    with h5py.File(filename, 'r') as h5file:
        train_args = json.loads(h5file.attrs['args'])
        snapshot = h5file[file_key]
        pprint.pprint(dict(snapshot.attrs))

    centralized = train_args['control'] == 'centralized'

    # Keyword settings copied straight from the training run.
    world_kwargs = {
        key: train_args[key]
        for key in ('n_sensors', 'sensor_range', 'save_reward', 'hit_reward',
                    'encounter_reward', 'bomb_reward')
    }
    env = ContinuousHostageWorld(train_args['n_good'],
                                 train_args['n_hostage'],
                                 train_args['n_bad'],
                                 train_args['n_coop_save'],
                                 train_args['n_coop_avoid'],
                                 **world_kwargs)

    buffer_size = train_args['buffer_size']
    if buffer_size > 1:
        env = ObservationBuffer(env, buffer_size)

    agent_obs = env.agents[0].observation_space
    agent_act = env.agents[0].action_space
    if not centralized:
        obsfeat_space = agent_obs
        action_space = agent_act
    else:
        # Joint spaces: agent 0's bounds, first dimension times agent count.
        n_agents = len(env.agents)  # XXX
        obsfeat_space = spaces.Box(low=agent_obs.low[0],
                                   high=agent_obs.high[0],
                                   shape=(agent_obs.shape[0] * n_agents, ))
        action_space = spaces.Box(low=agent_act.low[0],
                                  high=agent_act.high[0],
                                  shape=(agent_act.shape[0] * n_agents, ))

    policy = GaussianMLPPolicy(obsfeat_space,
                               action_space,
                               hidden_spec=train_args['policy_hidden_spec'],
                               enable_obsnorm=True,
                               min_stdev=0.,
                               init_logstdev=0.,
                               tblog=train_args['tblog'],
                               varscope_name='gaussmlp_policy')

    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        policy.load_h5(sess, filename, file_key)

        def choose_action(obs):
            # Add a leading axis to the observation before querying the policy.
            return policy.sample_actions(sess,
                                         obs[None, ...],
                                         deterministic=args.deterministic)

        rew = env.animate(act_fn=choose_action,
                          nsteps=args.n_steps,
                          file_name=args.vid)
        print(rew)
def main():
    """Train a policy on MAWaterWorld with rllab's TRPO implementation.

    Parses the command line, initialises the parallel sampler and optional
    seed, wraps the environment for rllab/TensorFlow, builds a feed-forward
    or recurrent Gaussian policy plus a baseline, configures the rllab
    logger, and runs ``TRPO.train()``.

    Raises:
        NotImplementedError: if ``--baseline_type`` is ``'mlp'`` or
            ``--recurrent`` is set to anything other than ``'gru'`` /
            ``'lstm'``.
    """
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')
    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)

    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_name',
                        type=str,
                        default=default_exp_name,
                        help='Name of the experiment.')

    parser.add_argument('--discount', type=float, default=0.95)
    parser.add_argument('--gae_lambda', type=float, default=0.99)
    parser.add_argument('--reward_scale', type=float, default=1.0)
    parser.add_argument('--enable_obsnorm', action='store_true', default=False)
    parser.add_argument('--chunked', action='store_true', default=False)

    parser.add_argument('--n_iter', type=int, default=250)
    parser.add_argument('--sampler_workers', type=int, default=1)
    parser.add_argument('--max_traj_len', type=int, default=250)
    parser.add_argument('--update_curriculum',
                        action='store_true',
                        default=False)
    parser.add_argument('--anneal_step_size', type=int, default=0)

    parser.add_argument('--n_timesteps', type=int, default=8000)

    parser.add_argument('--control', type=str, default='centralized')
    parser.add_argument('--buffer_size', type=int, default=1)
    parser.add_argument('--radius', type=float, default=0.015)
    parser.add_argument('--n_evaders', type=int, default=10)
    parser.add_argument('--n_pursuers', type=int, default=8)
    parser.add_argument('--n_poison', type=int, default=10)
    parser.add_argument('--n_coop', type=int, default=4)
    parser.add_argument('--n_sensors', type=int, default=30)
    parser.add_argument('--sensor_range', type=str, default='0.2')
    parser.add_argument('--food_reward', type=float, default=5)
    parser.add_argument('--poison_reward', type=float, default=-1)
    parser.add_argument('--encounter_reward', type=float, default=0.05)
    parser.add_argument('--reward_mech', type=str, default='local')

    parser.add_argument('--recurrent', type=str, default=None)
    parser.add_argument('--baseline_type', type=str, default='linear')
    parser.add_argument('--policy_hidden_sizes', type=str, default='128,128')
    parser.add_argument('--baseline_hidden_sizes', type=str, default='128,128')

    parser.add_argument('--max_kl', type=float, default=0.01)

    parser.add_argument('--log_dir', type=str, required=False)
    parser.add_argument('--tabular_log_file',
                        type=str,
                        default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file',
                        type=str,
                        default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file',
                        type=str,
                        default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data',
                        type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--snapshot_mode',
                        type=str,
                        default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                        '(all iterations will be saved), "last" (only '
                        'the last iteration will be saved), or "none" '
                        '(do not save snapshots)')
    parser.add_argument(
        '--log_tabular_only',
        type=ast.literal_eval,
        default=False,
        help=
        'Whether to only print the tabular log information (in a horizontal format)'
    )

    args = parser.parse_args()

    parallel_sampler.initialize(n_parallel=args.sampler_workers)

    if args.seed is not None:
        set_seed(args.seed)
        parallel_sampler.set_seed(args.seed)

    # Stored on ``args`` so it is included in the parameter log written below.
    args.hidden_sizes = tuple(map(int, args.policy_hidden_sizes.split(',')))

    # Build a real list before handing it to NumPy: on Python 3 ``map``
    # returns an iterator, which ``np.array`` would wrap in a 0-d object
    # array, making the ``len()`` call below raise TypeError.
    sensor_range = np.array([float(v) for v in args.sensor_range.split(',')])
    if len(sensor_range) == 1:
        # A single value is passed on as a scalar.
        sensor_range = sensor_range[0]
    else:
        assert sensor_range.shape == (args.n_pursuers, )

    env = MAWaterWorld(args.n_pursuers,
                       args.n_evaders,
                       args.n_coop,
                       args.n_poison,
                       radius=args.radius,
                       n_sensors=args.n_sensors,
                       food_reward=args.food_reward,
                       poison_reward=args.poison_reward,
                       encounter_reward=args.encounter_reward,
                       reward_mech=args.reward_mech,
                       sensor_range=sensor_range,
                       obstacle_loc=None)

    env = TfEnv(
        RLLabEnv(StandardizedEnv(env,
                                 scale_reward=args.reward_scale,
                                 enable_obsnorm=args.enable_obsnorm),
                 mode=args.control))

    if args.buffer_size > 1:
        env = ObservationBuffer(env, args.buffer_size)

    if args.recurrent:
        # The feature network's input size is the flattened observation
        # dimension plus the flattened action dimension.
        feature_network = MLP(
            name='feature_net',
            input_shape=(env.spec.observation_space.flat_dim +
                         env.spec.action_space.flat_dim, ),
            output_dim=16,
            hidden_sizes=(128, 64, 32),
            hidden_nonlinearity=tf.nn.tanh,
            output_nonlinearity=None)
        # ``int(...)`` below means --policy_hidden_sizes must be a single
        # integer (no commas) when a recurrent policy is requested.
        if args.recurrent == 'gru':
            policy = GaussianGRUPolicy(env_spec=env.spec,
                                       feature_network=feature_network,
                                       hidden_dim=int(
                                           args.policy_hidden_sizes),
                                       name='policy')
        elif args.recurrent == 'lstm':
            policy = GaussianLSTMPolicy(env_spec=env.spec,
                                        feature_network=feature_network,
                                        hidden_dim=int(
                                            args.policy_hidden_sizes),
                                        name='policy')
        else:
            # Fail here rather than with a NameError on ``policy`` later.
            raise NotImplementedError(
                'unsupported recurrent policy type {!r}'.format(
                    args.recurrent))
    else:
        policy = GaussianMLPPolicy(
            name='policy',
            env_spec=env.spec,
            hidden_sizes=tuple(map(int, args.policy_hidden_sizes.split(','))),
            min_std=10e-5)

    if args.baseline_type == 'linear':
        baseline = LinearFeatureBaseline(env_spec=env.spec)
    elif args.baseline_type == 'mlp':
        raise NotImplementedError()
        # baseline = GaussianMLPBaseline(
        #     env_spec=env.spec, hidden_sizes=tuple(map(int, args.baseline_hidden_sizes.split(','))))
    else:
        baseline = ZeroBaseline(env_spec=env.spec)

    # logger
    default_log_dir = config.LOG_DIR
    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    logger.log_parameters_lite(params_log_file, args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    # NOTE(review): the previous snapshot settings are captured here but
    # never restored anywhere in this function.
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    algo = TRPO(
        env=env,
        policy=policy,
        baseline=baseline,
        batch_size=args.n_timesteps,
        max_path_length=args.max_traj_len,
        #max_path_length_limit=args.max_path_length_limit,
        update_max_path_length=args.update_curriculum,
        anneal_step_size=args.anneal_step_size,
        n_itr=args.n_iter,
        discount=args.discount,
        gae_lambda=args.gae_lambda,
        step_size=args.max_kl,
        # Recurrent policies get a finite-difference HVP optimizer; otherwise
        # None is passed through to TRPO.
        optimizer=ConjugateGradientOptimizer(hvp_approach=FiniteDifferenceHvp(
            base_eps=1e-5)) if args.recurrent else None,
        mode=args.control
        if not args.chunked else 'chunk_{}'.format(args.control),
    )

    algo.train()
def main():
    """Load a trained MAWaterWorld policy and either evaluate or animate it.

    The positional ``filename`` argument is split by
    ``rltools.util.split_h5_name`` into an HDF5 file and a key inside it
    (e.g. ``defaultIS.h5/snapshots/iter0000480``). The training arguments
    stored in the file's ``args`` attribute are used to rebuild the
    environment and policy. With ``--evaluate`` the policy is run through
    ``rltools.util.evaluate_policy`` and the result is printed as a table;
    otherwise the environment is animated for ``--n_steps`` steps.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('filename',
                        type=str)  # defaultIS.h5/snapshots/iter0000480
    # NOTE(review): --vid is parsed but not passed to env.animate below.
    parser.add_argument('--vid', type=str, default='/tmp/madrl.mp4')
    parser.add_argument('--deterministic', action='store_true', default=False)
    parser.add_argument('--evaluate', action='store_true', default=False)
    parser.add_argument('--n_steps', type=int, default=500)
    args = parser.parse_args()

    # Load file
    filename, file_key = rltools.util.split_h5_name(args.filename)
    print('Loading parameters from {} in {}'.format(file_key, filename))
    with h5py.File(filename, 'r') as f:
        train_args = json.loads(f.attrs['args'])
        dset = f[file_key]

        pprint.pprint(dict(dset.attrs))

    centralized = train_args['control'] == 'centralized'
    env = MAWaterWorld(train_args['n_pursuers'],
                       train_args['n_evaders'],
                       train_args['n_coop'],
                       train_args['n_poison'],
                       n_sensors=train_args['n_sensors'],
                       food_reward=train_args['food_reward'],
                       poison_reward=train_args['poison_reward'],
                       obstacle_location=None,
                       encounter_reward=train_args['encounter_reward'],
                       addid=centralized,
                       speed_features=bool(train_args['speed_features']))

    if train_args['buffer_size'] > 1:
        env = ObservationBuffer(env, train_args['buffer_size'])

    if centralized:
        # Joint spaces: agent 0's bounds, first dimension times agent count.
        # This must bind ``obsfeat_space`` (it was ``obs_space`` before),
        # otherwise the policy construction below hits a NameError.
        obsfeat_space = spaces.Box(
            low=env.agents[0].observation_space.low[0],
            high=env.agents[0].observation_space.high[0],
            shape=(env.agents[0].observation_space.shape[0] *
                   len(env.agents), ))  # XXX
        action_space = spaces.Box(low=env.agents[0].action_space.low[0],
                                  high=env.agents[0].action_space.high[0],
                                  shape=(env.agents[0].action_space.shape[0] *
                                         len(env.agents), ))  # XXX
    else:
        obsfeat_space = env.agents[0].observation_space
        action_space = env.agents[0].action_space

    policy = GaussianMLPPolicy(obsfeat_space,
                               action_space,
                               hidden_spec=train_args['policy_hidden_spec'],
                               enable_obsnorm=train_args['enable_obsnorm'],
                               min_stdev=0.,
                               init_logstdev=0.,
                               tblog=train_args['tblog'],
                               varscope_name='policy')

    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        policy.load_h5(sess, filename, file_key)
        # Evaluate
        if args.evaluate:
            rltools.util.ok("Evaluating...")
            evr = rltools.util.evaluate_policy(
                env,
                policy,
                deterministic=args.deterministic,
                disc=train_args['discount'],
                mode=train_args['control'],
                max_traj_len=args.n_steps,
                n_trajs=100)
            from tabulate import tabulate
            print(tabulate(evr, headers='keys'))
        else:
            # The observation gets a leading axis before the policy call, and
            # element 0 of the result is handed back to the environment.
            rew, trajinfo = env.animate(act_fn=lambda o: policy.sample_actions(
                o[None, ...], deterministic=args.deterministic)[0],
                                        nsteps=args.n_steps)
            # Sum every per-key trajectory statistic for the summary line.
            info = {key: np.sum(value) for key, value in trajinfo.items()}
            print(rew, info)