コード例 #1
0
def main(parser):
    """Build a MultiAnt environment and hand it to the selected runner.

    Args:
        parser: Object exposing ``_mode`` (``'rllab'`` or ``'rltools'``)
            and ``args`` (the namespace of experiment options).

    Raises:
        NotImplementedError: If ``parser._mode`` names neither backend.
    """
    backend = parser._mode
    args = parser.args

    env_kwargs = dict(
        n_legs=args.n_legs,
        ts=args.ts,
        integrator=args.integrator,
        leg_length=args.leg_length,
        out_file=args.out_file,
        base_file=args.base_file,
        reward_mech=args.reward_mech,
    )
    env = MultiAnt(**env_kwargs)

    # The ObservationBuffer wrapper is applied only for buffer sizes above one.
    if args.buffer_size > 1:
        env = ObservationBuffer(env, args.buffer_size)

    # Backend modules are imported lazily so only the chosen one is loaded.
    if backend == 'rllab':
        from runners.rurllab import RLLabRunner as runner_cls
    elif backend == 'rltools':
        from runners.rurltools import RLToolsRunner as runner_cls
    else:
        raise NotImplementedError()
    runner = runner_cls(env, args)

    if not args.curriculum:
        runner()
        return
    runner(Curriculum(args.curriculum))
コード例 #2
0
def main(parser):
    """Build a MultiWalkerEnv and hand it to the selected runner.

    Args:
        parser: Object exposing ``_mode`` (``'rllab'`` or ``'rltools'``)
            and ``args`` (the namespace of experiment options).

    Raises:
        NotImplementedError: If ``parser._mode`` names neither backend.
    """
    backend = parser._mode
    args = parser.args

    env = MultiWalkerEnv(**{
        'n_walkers': args.n_walkers,
        'position_noise': args.position_noise,
        'angle_noise': args.angle_noise,
        'reward_mech': args.reward_mech,
        'forward_reward': args.forward_reward,
        'fall_reward': args.fall_reward,
        'drop_reward': args.drop_reward,
        # These two options are coerced to real booleans before use.
        'terminate_on_fall': bool(args.terminate_on_fall),
        'one_hot': bool(args.one_hot),
    })

    # The ObservationBuffer wrapper is applied only for buffer sizes above one.
    if args.buffer_size > 1:
        env = ObservationBuffer(env, args.buffer_size)

    # Backend modules are imported lazily so only the chosen one is loaded.
    if backend == 'rllab':
        from runners.rurllab import RLLabRunner as runner_cls
    elif backend == 'rltools':
        from runners.rurltools import RLToolsRunner as runner_cls
    else:
        raise NotImplementedError()
    runner = runner_cls(env, args)

    if not args.curriculum:
        runner()
        return
    runner(Curriculum(args.curriculum))
コード例 #3
0
def main(parser):
    """Build a MAWaterWorld environment and hand it to the selected runner.

    Args:
        parser: Object exposing ``_mode`` (``'rllab'`` or ``'rltools'``)
            and ``args`` (the namespace of experiment options).

    Raises:
        NotImplementedError: If ``parser._mode`` names neither backend.
    """
    backend = parser._mode
    args = parser.args

    # The four population counts are passed positionally, in this order.
    counts = (args.n_pursuers, args.n_evaders, args.n_coop, args.n_poison)
    options = dict(
        radius=args.radius,
        n_sensors=args.n_sensors,
        food_reward=args.food_reward,
        poison_reward=args.poison_reward,
        encounter_reward=args.encounter_reward,
        reward_mech=args.reward_mech,
        sensor_range=args.sensor_range,
        obstacle_loc=None,
        # ``addid`` is the logical inverse of the ``noid`` option.
        addid=not args.noid,
        speed_features=bool(args.speed_features),
    )
    env = MAWaterWorld(*counts, **options)

    # The ObservationBuffer wrapper is applied only for buffer sizes above one.
    if args.buffer_size > 1:
        env = ObservationBuffer(env, args.buffer_size)

    # Backend modules are imported lazily so only the chosen one is loaded.
    if backend == 'rllab':
        from runners.rurllab import RLLabRunner as runner_cls
    elif backend == 'rltools':
        from runners.rurltools import RLToolsRunner as runner_cls
    else:
        raise NotImplementedError()
    runner = runner_cls(env, args)

    if not args.curriculum:
        runner()
        return
    runner(Curriculum(args.curriculum))
コード例 #4
0
def main(parser):
    """Build a MultiAircraftEnv and hand it to the selected runner.

    Args:
        parser: Object exposing ``_mode`` (``'rllab'`` or ``'rltools'``)
            and ``args`` (the namespace of experiment options).

    Raises:
        NotImplementedError: If ``parser._mode`` names neither backend.
    """
    backend = parser._mode
    args = parser.args

    env = MultiAircraftEnv(**{
        'n_agents': args.n_agents,
        'speed_noise': args.speed_noise,
        'position_noise': args.position_noise,
        'angle_noise': args.angle_noise,
        'reward_mech': args.reward_mech,
        'rew_arrival': args.rew_arrival,
        'rew_closing': args.rew_closing,
        'rew_nmac': args.rew_nmac,
        'rew_large_turnrate': args.rew_large_turnrate,
        'rew_large_acc': args.rew_large_acc,
        'pen_action_heavy': args.pen_action_heavy,
        # Coerced to a real boolean before use.
        'one_hot': bool(args.one_hot),
    })

    # The ObservationBuffer wrapper is applied only for buffer sizes above one.
    if args.buffer_size > 1:
        env = ObservationBuffer(env, args.buffer_size)

    # Backend modules are imported lazily so only the chosen one is loaded.
    if backend == 'rllab':
        from runners.rurllab import RLLabRunner as runner_cls
    elif backend == 'rltools':
        from runners.rurltools import RLToolsRunner as runner_cls
    else:
        raise NotImplementedError()
    runner = runner_cls(env, args)

    if not args.curriculum:
        runner()
        return
    runner(Curriculum(args.curriculum))
コード例 #5
0
def main(parser):
    """Build a PursuitEvade environment and run it with the selected runner.

    The map pool is loaded with ``np.load`` when ``args.map_file`` is set;
    otherwise a single map is generated from ``args.map_type`` and the
    comma-separated integers in ``args.map_size``.

    Args:
        parser: Object exposing ``_mode`` (``'rllab'`` or ``'rltools'``)
            and ``args`` (the namespace of experiment options).

    Raises:
        NotImplementedError: If ``args.map_type`` is neither ``'rectangle'``
            nor ``'complex'`` (when no map file is given), or if
            ``parser._mode`` names neither backend.
    """
    backend = parser._mode
    args = parser.args

    if args.map_file:
        map_pool = np.load(args.map_file)
    else:
        # Pick the generator first; the size string is parsed only for a
        # recognised map type.
        if args.map_type == 'rectangle':
            make_map = TwoDMaps.rectangle_map
        elif args.map_type == 'complex':
            make_map = TwoDMaps.complex_map
        else:
            raise NotImplementedError()
        dims = [int(token) for token in args.map_size.split(',')]
        map_pool = [make_map(*dims)]

    env_kwargs = dict(
        n_evaders=args.n_evaders,
        n_pursuers=args.n_pursuers,
        obs_range=args.obs_range,
        n_catch=args.n_catch,
        urgency_reward=args.urgency,
        surround=bool(args.surround),
        sample_maps=bool(args.sample_maps),
        flatten=bool(args.flatten),
        reward_mech=args.reward_mech,
        catchr=args.catchr,
        term_pursuit=args.term_pursuit,
        # ``include_id`` is the logical inverse of the ``noid`` option.
        include_id=not args.noid,
    )
    env = PursuitEvade(map_pool, **env_kwargs)

    # The ObservationBuffer wrapper is applied only for buffer sizes above one.
    if args.buffer_size > 1:
        env = ObservationBuffer(env, args.buffer_size)

    # Backend modules are imported lazily so only the chosen one is loaded.
    if backend == 'rllab':
        from runners.rurllab import RLLabRunner as runner_cls
    elif backend == 'rltools':
        from runners.rurltools import RLToolsRunner as runner_cls
    else:
        raise NotImplementedError()

    runner = runner_cls(env, args)
    runner()
コード例 #6
0
def main(parser):
    """Build a sniper environment and run it with the selected runner.

    The map pool is loaded with ``np.load`` when ``args.map_file`` is set;
    otherwise a single map is generated from ``args.map_type`` and the
    comma-separated integers in ``args.map_size``.

    Args:
        parser: Object exposing ``_mode`` (``'rllab'`` or ``'rltools'``)
            and ``args`` (the namespace of experiment options).

    Raises:
        NotImplementedError: If ``args.map_type`` is neither ``'rectangle'``
            nor ``'complex'`` (when no map file is given), or if
            ``parser._mode`` names neither backend.
    """
    backend = parser._mode
    args = parser.args

    if args.map_file:
        map_pool = np.load(args.map_file)
    else:
        # Pick the generator first; the size string is parsed only for a
        # recognised map type.
        if args.map_type == 'rectangle':
            make_map = TwoDMaps.rectangle_map
        elif args.map_type == 'complex':
            make_map = TwoDMaps.complex_map
        else:
            raise NotImplementedError()
        # The parsed integers are unpacked as the generator's arguments.
        dims = [int(token) for token in args.map_size.split(',')]
        # The environment takes a list of maps; here it holds just one.
        map_pool = [make_map(*dims)]

    env_kwargs = dict(
        n_targets=args.n_targets,
        n_snipers=args.n_snipers,
        obs_range=args.obs_range,
        n_catch=args.n_catch,
        urgency_reward=args.urgency,
        surround=bool(args.surround),
        sample_maps=bool(args.sample_maps),
        flatten=bool(args.flatten),
        reward_mech=args.reward_mech,
        catchr=args.catchr,
        term_sniper=args.term_sniper,
    )
    env = sniper(map_pool, **env_kwargs)

    # The ObservationBuffer wrapper is applied only for buffer sizes above one.
    if args.buffer_size > 1:
        env = ObservationBuffer(env, args.buffer_size)

    # Backend modules are imported lazily so only the chosen one is loaded.
    if backend == 'rllab':
        from runners.rurllab import RLLabRunner as runner_cls
    elif backend == 'rltools':
        from runners.rurltools import RLToolsRunner as runner_cls
    else:
        raise NotImplementedError()

    runner = runner_cls(env, args)
    runner()
コード例 #7
0
def main(parser):
    """Build a ContinuousHostageWorld and run it with the selected runner.

    Args:
        parser: Object exposing ``_mode`` (``'rllab'`` or ``'rltools'``)
            and ``args`` (the namespace of experiment options).

    Raises:
        NotImplementedError: If ``parser._mode`` names neither backend.
    """
    backend = parser._mode
    args = parser.args

    # Every constructor option is forwarded from the same-named attribute.
    env_kwargs = dict(
        n_good=args.n_good,
        n_hostages=args.n_hostages,
        n_bad=args.n_bad,
        n_coop_save=args.n_coop_save,
        n_coop_avoid=args.n_coop_avoid,
        radius=args.radius,
        key_loc=args.key_loc,
        bad_speed=args.bad_speed,
        n_sensors=args.n_sensors,
        sensor_range=args.sensor_range,
        save_reward=args.save_reward,
        hit_reward=args.hit_reward,
        encounter_reward=args.encounter_reward,
        bomb_reward=args.bomb_reward,
        bomb_radius=args.bomb_radius,
        control_penalty=args.control_penalty,
        reward_mech=args.reward_mech,
    )
    env = ContinuousHostageWorld(**env_kwargs)

    # The ObservationBuffer wrapper is applied only for buffer sizes above one.
    if args.buffer_size > 1:
        env = ObservationBuffer(env, args.buffer_size)

    # Backend modules are imported lazily so only the chosen one is loaded.
    if backend == 'rllab':
        from runners.rurllab import RLLabRunner as runner_cls
    elif backend == 'rltools':
        from runners.rurltools import RLToolsRunner as runner_cls
    else:
        raise NotImplementedError()

    runner = runner_cls(env, args)
    runner()
コード例 #8
0
def main(parser):
    """Build a MAContWorld environment and hand it to the rllab runner.

    Args:
        parser: Object exposing ``_mode`` (only ``'rllab'`` is handled here)
            and ``args`` (the namespace of experiment options).

    Raises:
        NotImplementedError: If ``parser._mode`` is anything but ``'rllab'``.
    """
    backend = parser._mode
    args = parser.args

    # The four counts are passed positionally, in this order.
    counts = (args.n_rovers, args.n_areas_of_int, args.n_coop, args.n_crater)
    options = dict(
        radius=args.radius,
        n_sensors=args.n_sensors,
        scout_reward=args.scout_reward,
        crater_reward=args.crater_reward,
        encounter_reward=args.encounter_reward,
        reward_mech=args.reward_mech,
        sensor_range=args.sensor_range,
        obstacle_loc=None,
        # ``addid`` is the logical inverse of the ``noid`` option.
        addid=not args.noid,
        speed_features=bool(args.speed_features),
    )
    env = MAContWorld(*counts, **options)

    # The ObservationBuffer wrapper is applied only for buffer sizes above one.
    if args.buffer_size > 1:
        env = ObservationBuffer(env, args.buffer_size)

    # The environment is built before the mode check, as in the other entry
    # points; the rllab module is imported only once the mode is accepted.
    if backend != 'rllab':
        raise NotImplementedError()
    from runners.rurllab import RLLabRunner
    runner = RLLabRunner(env, args)

    if not args.curriculum:
        runner()
        return
    runner(Curriculum(args.curriculum))