def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('filename', type=str)  # defaultIS.h5/snapshots/iter0000480
    parser.add_argument('--vid', type=str, default='/tmp/madrl.mp4')
    parser.add_argument('--deterministic', action='store_true', default=False)
    parser.add_argument('--heuristic', action='store_true', default=False)
    parser.add_argument('--evaluate', action='store_true', default=False)
    parser.add_argument('--n_trajs', type=int, default=20)
    parser.add_argument('--n_steps', type=int, default=500)
    parser.add_argument('--same_con_pol', action='store_true')
    args = parser.parse_args()

    fh = FileHandler(args.filename)

    env_map = TwoDMaps.rectangle_map(*map(int, fh.train_args['map_size'].split(',')))
    map_pool = [env_map]
    # map_pool = np.load(
    #     os.path.join('/scratch/megorov/deeprl/MADRL/runners/maps/',
    #                  os.path.basename(fh.train_args['map_file'])))

    env = PursuitEvade(map_pool,
                       n_evaders=fh.train_args['n_evaders'],
                       n_pursuers=fh.train_args['n_pursuers'],
                       obs_range=fh.train_args['obs_range'],
                       n_catch=fh.train_args['n_catch'],
                       urgency_reward=fh.train_args['urgency'],
                       surround=bool(fh.train_args['surround']),
                       sample_maps=bool(fh.train_args['sample_maps']),
                       flatten=bool(fh.train_args['flatten']),
                       reward_mech='global',
                       catchr=fh.train_args['catchr'],
                       term_pursuit=fh.train_args['term_pursuit'])

    if fh.train_args['buffer_size'] > 1:
        env = ObservationBuffer(env, fh.train_args['buffer_size'])

    hpolicy = None
    if args.heuristic:
        from heuristics.pursuit import PursuitHeuristicPolicy
        hpolicy = PursuitHeuristicPolicy(env.agents[0].observation_space,
                                         env.agents[0].action_space)

    if args.evaluate:
        minion = Evaluator(env, fh.train_args, args.n_steps, args.n_trajs, args.deterministic,
                           'heuristic' if args.heuristic else fh.mode)
        evr = minion(fh.filename, file_key=fh.file_key, same_con_pol=args.same_con_pol,
                     hpolicy=hpolicy)
        from tabulate import tabulate
        print(evr)
        print(tabulate(evr, headers='keys'))
    else:
        minion = Visualizer(env, fh.train_args, args.n_steps, args.n_trajs, args.deterministic,
                            fh.mode)
        rew, info = minion(fh.filename, file_key=fh.file_key, vid=args.vid)
        pprint.pprint(rew)
        pprint.pprint(info)
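# The positional `filename` argument is an h5 snapshot path of the form shown in the
# comment above. Hypothetical invocations (the script name is a placeholder, not part
# of this source; the flags are the ones defined above):
#
#   python vis_pursuit.py defaultIS.h5/snapshots/iter0000480 --evaluate --n_trajs 50
#   python vis_pursuit.py defaultIS.h5/snapshots/iter0000480 --vid /tmp/madrl.mp4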
def main(parser):
    args = parser.args

    if args.map_file:
        map_pool = np.load(args.map_file)
    else:
        if args.map_type == 'rectangle':
            env_map = TwoDMaps.rectangle_map(*map(int, args.map_size.split(',')))
        elif args.map_type == 'complex':
            env_map = TwoDMaps.complex_map(*map(int, args.map_size.split(',')))
        else:
            raise NotImplementedError()
        map_pool = [env_map]

    env = PursuitEvade(map_pool,
                       n_evaders=args.n_evaders,
                       n_pursuers=args.n_pursuers,
                       obs_range=args.obs_range,
                       n_catch=args.n_catch,
                       urgency_reward=args.urgency,
                       surround=bool(args.surround),
                       sample_maps=bool(args.sample_maps),
                       flatten=bool(args.flatten),
                       reward_mech=args.reward_mech,
                       catchr=args.catchr,
                       term_pursuit=args.term_pursuit,
                       include_id=not bool(args.noid))

    if args.buffer_size > 1:
        env = ObservationBuffer(env, args.buffer_size)

    runner = Runner(env, args)
    runner.start_training()
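# ObservationBuffer's implementation is not shown in this file. Below is a minimal
# sketch of the frame-stacking behavior such a wrapper typically provides; the class
# name and the choice of stacking axis are assumptions, not the actual MADRL API. It
# follows this codebase's convention of list-of-per-agent observations.
import numpy as np

class FrameStackSketch(object):
    """Keeps the last `buffer_size` observations per agent, stacked on a trailing axis."""

    def __init__(self, env, buffer_size):
        self.env = env
        self.buffer_size = buffer_size
        self._frames = None

    def reset(self):
        obs = self.env.reset()
        # Seed the buffer by repeating each agent's first observation.
        self._frames = [[np.asarray(o)] * self.buffer_size for o in obs]
        return self._stacked()

    def step(self, actions):
        obs, rew, done, info = self.env.step(actions)
        for agent_frames, o in zip(self._frames, obs):
            agent_frames.pop(0)          # drop the oldest frame
            agent_frames.append(np.asarray(o))
        return self._stacked(), rew, done, info

    def _stacked(self):
        # Each agent sees obs_shape + (buffer_size,) worth of history.
        return [np.stack(agent_frames, axis=-1) for agent_frames in self._frames]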
def envname2env(envname, args):
    from madrl_environments.pursuit import PursuitEvade
    from madrl_environments.pursuit import MAWaterWorld
    from madrl_environments.walker.multi_walker import MultiWalkerEnv

    # XXX
    # Will generalize later
    if envname == 'multiwalker':
        env = MultiWalkerEnv(
            args['n_walkers'],
            args['position_noise'],
            args['angle_noise'],
            reward_mech='global',
        )
    elif envname == 'waterworld':
        env = MAWaterWorld(
            args['n_pursuers'],
            args['n_evaders'],
            args['n_coop'],
            args['n_poison'],
            n_sensors=args['n_sensors'],
            food_reward=args['food_reward'],
            poison_reward=args['poison_reward'],
            encounter_reward=args['encounter_reward'],
            reward_mech='global',
        )
    elif envname == 'pursuit':
        env = PursuitEvade()
    else:
        raise NotImplementedError()
    return env
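# Hypothetical usage of the factory above. The dict keys are exactly the ones the
# 'waterworld' branch reads; the values are illustrative placeholders, not defaults
# from this codebase:
#
#   env = envname2env('waterworld', {
#       'n_pursuers': 2, 'n_evaders': 5, 'n_coop': 2, 'n_poison': 10,
#       'n_sensors': 30, 'food_reward': 10.0, 'poison_reward': -1.0,
#       'encounter_reward': 0.01,
#   })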
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('filename', type=str)
    parser.add_argument('--vid', type=str, default='madrl.mp4')
    parser.add_argument('--deterministic', action='store_true', default=False)
    parser.add_argument('--heuristic', action='store_true', default=False)
    parser.add_argument('--evaluate', action='store_true', default=True)
    parser.add_argument('--n_trajs', type=int, default=20)
    parser.add_argument('--n_steps', type=int, default=20)
    parser.add_argument('--same_con_pol', action='store_true')
    args = parser.parse_args()

    fh = FileHandler(args.filename)

    if fh.train_args['map_file'] is not None:
        map_pool = np.load(os.path.join('', os.path.basename(fh.train_args['map_file'])))
    else:
        if fh.train_args['map_type'] == 'rectangle':
            env_map = TwoDMaps.rectangle_map(*map(int, fh.train_args['map_size'].split(',')))
        elif fh.train_args['map_type'] == 'complex':
            env_map = TwoDMaps.complex_map(*map(int, fh.train_args['map_size'].split(',')))
        else:
            raise NotImplementedError()
        map_pool = [env_map]

    env = PursuitEvade(map_pool,
                       n_evaders=fh.train_args['n_evaders'],
                       n_pursuers=fh.train_args['n_pursuers'],
                       obs_range=fh.train_args['obs_range'],
                       n_catch=fh.train_args['n_catch'],
                       urgency_reward=fh.train_args['urgency'],
                       surround=bool(fh.train_args['surround']),
                       sample_maps=bool(fh.train_args['sample_maps']),
                       flatten=bool(fh.train_args['flatten']),
                       reward_mech='global',
                       catchr=fh.train_args['catchr'],
                       term_pursuit=fh.train_args['term_pursuit'])

    if fh.train_args['buffer_size'] > 1:
        env = ObservationBuffer(env, fh.train_args['buffer_size'])

    hpolicy = None
    if args.evaluate:
        minion = Evaluator(env, fh.train_args, args.n_steps, args.n_trajs, args.deterministic,
                           'heuristic' if args.heuristic else fh.mode)
        evr = minion(fh.filename, file_key=fh.file_key, same_con_pol=args.same_con_pol,
                     hpolicy=hpolicy)
        from tabulate import tabulate
        print(evr)
        print(tabulate(evr, headers='keys'))
    else:
        minion = Visualizer(env, fh.train_args, args.n_steps, args.n_trajs, args.deterministic,
                            fh.mode)
        rew, info = minion(fh.filename, file_key=fh.file_key, vid=args.vid)
        pprint.pprint(rew)
        pprint.pprint(info)
def main(parser):
    mode = parser._mode
    args = parser.args

    if args.map_file:
        map_pool = np.load(args.map_file)
    else:
        if args.map_type == 'rectangle':
            env_map = TwoDMaps.rectangle_map(*map(int, args.map_size.split(',')))
        elif args.map_type == 'complex':
            env_map = TwoDMaps.complex_map(*map(int, args.map_size.split(',')))
        else:
            raise NotImplementedError()
        map_pool = [env_map]

    env = PursuitEvade(map_pool,
                       n_evaders=args.n_evaders,
                       n_pursuers=args.n_pursuers,
                       obs_range=args.obs_range,
                       n_catch=args.n_catch,
                       urgency_reward=args.urgency,
                       surround=bool(args.surround),
                       sample_maps=bool(args.sample_maps),
                       flatten=bool(args.flatten),
                       reward_mech=args.reward_mech,
                       catchr=args.catchr,
                       term_pursuit=args.term_pursuit)

    if args.buffer_size > 1:
        env = ObservationBuffer(env, args.buffer_size)

    if mode == 'rllab':
        from runners.rurllab import RLLabRunner
        run = RLLabRunner(env, args)
    elif mode == 'rltools':
        from runners.rurltools import RLToolsRunner
        run = RLToolsRunner(env, args)
    else:
        raise NotImplementedError()

    run()
def main():
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')
    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)

    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_name', type=str, default=default_exp_name,
                        help='Name of the experiment.')
    parser.add_argument('--discount', type=float, default=0.99)
    parser.add_argument('--gae_lambda', type=float, default=1.0)
    parser.add_argument('--reward_scale', type=float, default=1.0)
    parser.add_argument('--n_iter', type=int, default=250)
    parser.add_argument('--sampler_workers', type=int, default=1)
    parser.add_argument('--max_traj_len', type=int, default=250)
    parser.add_argument('--update_curriculum', action='store_true', default=False)
    parser.add_argument('--n_timesteps', type=int, default=8000)
    parser.add_argument('--control', type=str, default='centralized')
    parser.add_argument('--rectangle', type=str, default='10,10')
    parser.add_argument('--map_type', type=str, default='rectangle')
    parser.add_argument('--n_evaders', type=int, default=5)
    parser.add_argument('--n_pursuers', type=int, default=2)
    parser.add_argument('--obs_range', type=int, default=3)
    parser.add_argument('--n_catch', type=int, default=2)
    parser.add_argument('--urgency', type=float, default=0.0)
    parser.add_argument('--pursuit', dest='train_pursuit', action='store_true')
    parser.add_argument('--evade', dest='train_pursuit', action='store_false')
    parser.set_defaults(train_pursuit=True)
    parser.add_argument('--surround', action='store_true', default=False)
    parser.add_argument('--constraint_window', type=float, default=1.0)
    parser.add_argument('--sample_maps', action='store_true', default=False)
    parser.add_argument('--map_file', type=str, default='../maps/map_pool.npy')
    parser.add_argument('--flatten', action='store_true', default=False)
    parser.add_argument('--reward_mech', type=str, default='global')
    parser.add_argument('--catchr', type=float, default=0.1)
    parser.add_argument('--term_pursuit', type=float, default=5.0)
    parser.add_argument('--recurrent', type=str, default=None)
    parser.add_argument('--policy_hidden_sizes', type=str, default='128,128')
    parser.add_argument('--baseline_hidden_sizes', type=str, default='128,128')
    parser.add_argument('--baseline_type', type=str, default='linear')
    parser.add_argument('--conv', action='store_true', default=False)
    parser.add_argument('--max_kl', type=float, default=0.01)
    parser.add_argument('--checkpoint', type=str, default=None)
    parser.add_argument('--log_dir', type=str, required=False)
    parser.add_argument('--tabular_log_file', type=str, default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file', type=str, default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file', type=str, default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data', type=str, help='Pickled data for stub objects')
    parser.add_argument('--snapshot_mode', type=str, default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                        '(all iterations will be saved), "last" (only '
                        'the last iteration will be saved), or "none" '
                        '(do not save snapshots)')
    parser.add_argument('--log_tabular_only', type=ast.literal_eval, default=False,
                        help='Whether to only print the tabular log information '
                        '(in a horizontal format)')
    args = parser.parse_args()

    parallel_sampler.initialize(n_parallel=args.sampler_workers)
    if args.seed is not None:
        set_seed(args.seed)
        parallel_sampler.set_seed(args.seed)

    args.hidden_sizes = tuple(map(int, args.policy_hidden_sizes.split(',')))

    if args.checkpoint:
        with tf.Session() as sess:
            data = joblib.load(args.checkpoint)
            policy = data['policy']
            env = data['env']
    else:
        if args.sample_maps:
            map_pool = np.load(args.map_file)
        else:
            if args.map_type == 'rectangle':
                env_map = TwoDMaps.rectangle_map(*map(int, args.rectangle.split(',')))
            elif args.map_type == 'complex':
                env_map = TwoDMaps.complex_map(*map(int, args.rectangle.split(',')))
            else:
                raise NotImplementedError()
            map_pool = [env_map]

        env = PursuitEvade(map_pool,
                           n_evaders=args.n_evaders,
                           n_pursuers=args.n_pursuers,
                           obs_range=args.obs_range,
                           n_catch=args.n_catch,
                           train_pursuit=args.train_pursuit,
                           urgency_reward=args.urgency,
                           surround=args.surround,
                           sample_maps=args.sample_maps,
                           constraint_window=args.constraint_window,
                           flatten=args.flatten,
                           reward_mech=args.reward_mech,
                           catchr=args.catchr,
                           term_pursuit=args.term_pursuit)

        env = TfEnv(
            RLLabEnv(
                StandardizedEnv(env, scale_reward=args.reward_scale, enable_obsnorm=False),
                mode=args.control))

        if args.recurrent:
            if args.conv:
                feature_network = ConvNetwork(
                    name='feature_net',
                    input_shape=env.spec.observation_space.shape,
                    output_dim=5,
                    conv_filters=(16, 32, 32),
                    conv_filter_sizes=(3, 3, 3),
                    conv_strides=(1, 1, 1),
                    conv_pads=('VALID', 'VALID', 'VALID'),
                    hidden_sizes=(64,),
                    hidden_nonlinearity=tf.nn.relu,
                    output_nonlinearity=tf.nn.softmax)
            else:
                feature_network = MLP(
                    name='feature_net',
                    input_shape=(env.spec.observation_space.flat_dim +
                                 env.spec.action_space.flat_dim,),
                    output_dim=5,
                    hidden_sizes=(256, 128, 64),
                    hidden_nonlinearity=tf.nn.tanh,
                    output_nonlinearity=None)
            if args.recurrent == 'gru':
                policy = CategoricalGRUPolicy(env_spec=env.spec,
                                              feature_network=feature_network,
                                              hidden_dim=int(args.policy_hidden_sizes),
                                              name='policy')
            elif args.recurrent == 'lstm':
                policy = CategoricalLSTMPolicy(env_spec=env.spec,
                                               feature_network=feature_network,
                                               hidden_dim=int(args.policy_hidden_sizes),
                                               name='policy')
        elif args.conv:
            feature_network = ConvNetwork(
                name='feature_net',
                input_shape=env.spec.observation_space.shape,
                output_dim=5,
                conv_filters=(8, 16),
                conv_filter_sizes=(3, 3),
                conv_strides=(2, 1),
                conv_pads=('VALID', 'VALID'),
                hidden_sizes=(32,),
                hidden_nonlinearity=tf.nn.relu,
                output_nonlinearity=tf.nn.softmax)
            policy = CategoricalMLPPolicy(name='policy', env_spec=env.spec,
                                          prob_network=feature_network)
        else:
            policy = CategoricalMLPPolicy(name='policy', env_spec=env.spec,
                                          hidden_sizes=args.hidden_sizes)

    if args.baseline_type == 'linear':
        baseline = LinearFeatureBaseline(env_spec=env.spec)
    else:
        baseline = ZeroBaseline(env_spec=env.spec)

    # logger
    default_log_dir = config.LOG_DIR
    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    logger.log_parameters_lite(params_log_file, args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    algo = TRPO(
        env=env,
        policy=policy,
        baseline=baseline,
        batch_size=args.n_timesteps,
        max_path_length=args.max_traj_len,
        n_itr=args.n_iter,
        discount=args.discount,
        gae_lambda=args.gae_lambda,
        step_size=args.max_kl,
        optimizer=ConjugateGradientOptimizer(
            hvp_approach=FiniteDifferenceHvp(base_eps=1e-5)) if args.recurrent else None,
        mode=args.control,
    )

    algo.train()
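# A hypothetical invocation of this rllab-based trainer (script name and hyperparameter
# values are placeholders; the flags are the ones defined above):
#
#   python train_pursuit_rllab.py --exp_name pursuit_trpo --n_iter 500 \
#       --rectangle 16,16 --n_pursuers 4 --n_evaders 10 \
#       --control decentralized --recurrent gru --policy_hidden_sizes 128
#
# Note that with --recurrent set, --policy_hidden_sizes must parse as a single integer,
# since it is passed through int() to hidden_dim.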
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--discount', type=float, default=0.99)
    parser.add_argument('--gae_lambda', type=float, default=0.99)
    parser.add_argument('--n_iter', type=int, default=250)
    parser.add_argument('--sampler', type=str, default='simple')
    parser.add_argument('--sampler_workers', type=int, default=4)
    parser.add_argument('--max_traj_len', type=int, default=250)
    parser.add_argument('--adaptive_batch', action='store_true', default=False)
    parser.add_argument('--update_curriculum', action='store_true', default=False)
    parser.add_argument('--n_timesteps', type=int, default=8000)
    parser.add_argument('--n_timesteps_min', type=int, default=1000)
    parser.add_argument('--n_timesteps_max', type=int, default=64000)
    parser.add_argument('--timestep_rate', type=int, default=20)
    parser.add_argument('--control', type=str, default='centralized')
    parser.add_argument('--rectangle', type=str, default='10,10')
    parser.add_argument('--map_type', type=str, default='rectangle')
    parser.add_argument('--n_evaders', type=int, default=5)
    parser.add_argument('--n_pursuers', type=int, default=2)
    parser.add_argument('--obs_range', type=int, default=3)
    parser.add_argument('--n_catch', type=int, default=2)
    parser.add_argument('--urgency', type=float, default=0.0)
    parser.add_argument('--pursuit', dest='train_pursuit', action='store_true')
    parser.add_argument('--evade', dest='train_pursuit', action='store_false')
    parser.set_defaults(train_pursuit=True)
    parser.add_argument('--surround', action='store_true', default=False)
    parser.add_argument('--constraint_window', type=float, default=1.0)
    parser.add_argument('--sample_maps', action='store_true', default=False)
    parser.add_argument('--map_file', type=str, default='maps/map_pool.npy')
    parser.add_argument('--flatten', action='store_true', default=False)
    parser.add_argument('--reward_mech', type=str, default='global')
    parser.add_argument('--cur_remove', type=int, default=100)
    parser.add_argument('--cur_const_rate', type=float, default=0.0)
    parser.add_argument('--cur_shaping', type=int, default=np.inf)
    parser.add_argument('--catchr', type=float, default=0.1)
    parser.add_argument('--term_pursuit', type=float, default=5.0)
    parser.add_argument('--policy_hidden_spec', type=str, default='MED_POLICY_ARCH')
    parser.add_argument('--baseline_type', type=str, default='mlp')
    parser.add_argument('--baseline_hidden_spec', type=str, default='MED_POLICY_ARCH')
    parser.add_argument('--max_kl', type=float, default=0.01)
    parser.add_argument('--vf_max_kl', type=float, default=0.01)
    parser.add_argument('--vf_cg_damping', type=float, default=0.01)
    parser.add_argument('--save_freq', type=int, default=50)
    parser.add_argument('--log', type=str, required=False)
    parser.add_argument('--tblog', type=str, default='/tmp/madrl_tb')
    parser.add_argument('--debug', dest='debug', action='store_true')
    parser.add_argument('--no-debug', dest='debug', action='store_false')
    parser.add_argument('--load_checkpoint', type=str, default='')
    parser.set_defaults(debug=True)
    args = parser.parse_args()

    args.policy_hidden_spec = get_arch(args.policy_hidden_spec)
    args.baseline_hidden_spec = get_arch(args.baseline_hidden_spec)

    if args.sample_maps:
        map_pool = np.load(args.map_file)
    else:
        if args.map_type == 'rectangle':
            env_map = TwoDMaps.rectangle_map(*map(int, args.rectangle.split(',')))
        elif args.map_type == 'complex':
            env_map = TwoDMaps.complex_map(*map(int, args.rectangle.split(',')))
        else:
            raise NotImplementedError()
        map_pool = [env_map]

    env = PursuitEvade(map_pool,
                       n_evaders=args.n_evaders,
                       n_pursuers=args.n_pursuers,
                       obs_range=args.obs_range,
                       n_catch=args.n_catch,
                       train_pursuit=args.train_pursuit,
                       urgency_reward=args.urgency,
                       surround=args.surround,
                       sample_maps=args.sample_maps,
                       constraint_window=args.constraint_window,
                       flatten=args.flatten,
                       reward_mech=args.reward_mech,
                       curriculum_remove_every=args.cur_remove,
                       curriculum_constrain_rate=args.cur_const_rate,
                       curriculum_turn_off_shaping=args.cur_shaping,
                       catchr=args.catchr,
                       term_pursuit=args.term_pursuit)

    if args.control == 'centralized':
        # Concatenate all agents' observation spaces into one joint space.
        obsfeat_space = spaces.Box(low=env.agents[0].observation_space.low[0],
                                   high=env.agents[0].observation_space.high[0],
                                   shape=(env.agents[0].observation_space.shape[0] *
                                          len(env.agents),))  # XXX
        action_space = spaces.Discrete(env.agents[0].action_space.n * len(env.agents))
    elif args.control == 'decentralized':
        obsfeat_space = env.agents[0].observation_space
        action_space = env.agents[0].action_space
    else:
        raise NotImplementedError()

    if args.load_checkpoint != '':
        filename, file_key = rltools.util.split_h5_name(args.load_checkpoint)
        print('Loading parameters from {} in {}'.format(file_key, filename))
        with h5py.File(filename, 'r') as f:
            train_args = json.loads(f.attrs['args'])
            dset = f[file_key]
            pprint.pprint(dict(dset.attrs))
        policy = CategoricalMLPPolicy(obsfeat_space, action_space,
                                      hidden_spec=train_args['policy_hidden_spec'],
                                      enable_obsnorm=False,
                                      tblog=train_args['tblog'],
                                      varscope_name='pursuit_catmlp_policy')
    else:
        policy = CategoricalMLPPolicy(obsfeat_space, action_space,
                                      hidden_spec=args.policy_hidden_spec,
                                      enable_obsnorm=False,
                                      tblog=args.tblog,
                                      varscope_name='pursuit_catmlp_policy')

    if args.baseline_type == 'linear':
        baseline = LinearFeatureBaseline(obsfeat_space, enable_obsnorm=False,
                                         varscope_name='pursuit_linear_baseline')
    elif args.baseline_type == 'mlp':
        baseline = MLPBaseline(obsfeat_space, args.baseline_hidden_spec, True, True,
                               max_kl=args.vf_max_kl,
                               damping=args.vf_cg_damping,
                               time_scale=1. / args.max_traj_len,
                               varscope_name='pursuit_mlp_baseline')
    else:
        baseline = ZeroBaseline(obsfeat_space)

    if args.sampler == 'simple':
        if args.control == 'centralized':
            sampler_cls = SimpleSampler
        elif args.control == 'decentralized':
            sampler_cls = DecSampler
        else:
            raise NotImplementedError()
        sampler_args = dict(max_traj_len=args.max_traj_len,
                            n_timesteps=args.n_timesteps,
                            n_timesteps_min=args.n_timesteps_min,
                            n_timesteps_max=args.n_timesteps_max,
                            timestep_rate=args.timestep_rate,
                            adaptive=args.adaptive_batch,
                            enable_rewnorm=True)
    elif args.sampler == 'parallel':
        sampler_cls = ParallelSampler
        sampler_args = dict(max_traj_len=args.max_traj_len,
                            n_timesteps=args.n_timesteps,
                            n_timesteps_min=args.n_timesteps_min,
                            n_timesteps_max=args.n_timesteps_max,
                            timestep_rate=args.timestep_rate,
                            adaptive=args.adaptive_batch,
                            enable_rewnorm=True,
                            n_workers=args.sampler_workers,
                            mode=args.control)
    else:
        raise NotImplementedError()

    step_func = rltools.algos.policyopt.TRPO(max_kl=args.max_kl)
    popt = rltools.algos.policyopt.SamplingPolicyOptimizer(
        env=env,
        policy=policy,
        baseline=baseline,
        step_func=step_func,
        discount=args.discount,
        gae_lambda=args.gae_lambda,
        sampler_cls=sampler_cls,
        sampler_args=sampler_args,
        n_iter=args.n_iter,
        update_curriculum=args.update_curriculum)

    argstr = json.dumps(vars(args), separators=(',', ':'), indent=2)
    rltools.util.header(argstr)
    log_f = rltools.log.TrainingLog(args.log, [('args', argstr)], debug=args.debug)

    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        if args.load_checkpoint != '':
            policy.load_h5(sess, filename, file_key)
        popt.train(sess, log_f, args.save_freq)
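# A hypothetical invocation of this rltools-based trainer (script name and values are
# placeholders; the flags are the ones defined above):
#
#   python train_pursuit_rltools.py --control decentralized --sampler parallel \
#       --sampler_workers 8 --rectangle 16,16 --n_pursuers 4 --n_evaders 10 \
#       --log /tmp/pursuit_train.h5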
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('filename', type=str)  # defaultIS.h5/snapshots/iter0000480
    parser.add_argument('--vid', type=str, default='/tmp/madrl.mp4')
    parser.add_argument('--deterministic', action='store_true', default=False)
    parser.add_argument('--verbose', action='store_true', default=False)
    parser.add_argument('--n_steps', type=int, default=200)
    parser.add_argument('--map_file', type=str, default='')
    args = parser.parse_args()

    # Load file
    filename, file_key = rltools.util.split_h5_name(args.filename)
    print('Loading parameters from {} in {}'.format(file_key, filename))
    with h5py.File(filename, 'r') as f:
        train_args = json.loads(f.attrs['args'])
        dset = f[file_key]
        pprint.pprint(dict(dset.attrs))
        pprint.pprint(train_args)

    if train_args['sample_maps']:
        map_pool = np.load(args.map_file)
    else:
        if train_args['map_type'] == 'rectangle':
            env_map = TwoDMaps.rectangle_map(*map(int, train_args['rectangle'].split(',')))
        elif train_args['map_type'] == 'complex':
            env_map = TwoDMaps.complex_map(*map(int, train_args['rectangle'].split(',')))
        else:
            raise NotImplementedError()
        map_pool = [env_map]

    env = PursuitEvade(map_pool,
                       #n_evaders=train_args['n_evaders'],
                       #n_pursuers=train_args['n_pursuers'],
                       n_evaders=50,
                       n_pursuers=50,
                       obs_range=train_args['obs_range'],
                       n_catch=train_args['n_catch'],
                       urgency_reward=train_args['urgency'],
                       surround=train_args['surround'],
                       sample_maps=train_args['sample_maps'],
                       flatten=train_args['flatten'],
                       reward_mech=train_args['reward_mech'])

    if train_args['control'] == 'decentralized':
        obsfeat_space = env.agents[0].observation_space
        action_space = env.agents[0].action_space
    elif train_args['control'] == 'centralized':
        obsfeat_space = spaces.Box(low=env.agents[0].observation_space.low[0],
                                   high=env.agents[0].observation_space.high[0],
                                   shape=(env.agents[0].observation_space.shape[0] *
                                          len(env.agents),))  # XXX
        action_space = spaces.Discrete(env.agents[0].action_space.n * len(env.agents))
    else:
        raise NotImplementedError()

    policy = CategoricalMLPPolicy(obsfeat_space, action_space,
                                  hidden_spec=train_args['policy_hidden_spec'],
                                  enable_obsnorm=True,
                                  tblog=train_args['tblog'],
                                  varscope_name='pursuit_catmlp_policy')

    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        policy.load_h5(sess, filename, file_key)

        if train_args['control'] == 'centralized':
            # Flatten and concatenate all agents' observations into one policy input.
            act_fn = lambda o: policy.sample_actions(
                np.expand_dims(np.array(o).flatten(), 0),
                deterministic=args.deterministic)[0][0, 0]
        elif train_args['control'] == 'decentralized':
            def act_fn(o):
                action_list = []
                for agent_obs in o:
                    a, adist = policy.sample_actions(np.expand_dims(agent_obs, 0),
                                                     deterministic=args.deterministic)
                    action_list.append(a[0, 0])
                return action_list

        #import IPython
        #IPython.embed()
        env.animate(act_fn=act_fn, nsteps=args.n_steps, file_name=args.vid,
                    verbose=args.verbose)
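# The centralized branch above packs every agent's observation into one joint vector
# before querying the single policy, mirroring the spaces.Box construction. A minimal
# numpy sketch of that packing (the shapes are illustrative assumptions, not taken
# from the environment):
import numpy as np

n_agents, obs_dim = 3, 4
per_agent_obs = [np.random.rand(obs_dim) for _ in range(n_agents)]

# Joint observation: a (n_agents * obs_dim,) vector fed to one centralized policy,
# exactly what np.array(o).flatten() produces in the centralized act_fn.
joint_obs = np.concatenate([o.flatten() for o in per_agent_obs])
assert joint_obs.shape == (n_agents * obs_dim,)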
def get_state(self):
    return []

def set_state(self, *args):
    pass


if __name__ == '__main__':
    from madrl_environments.pursuit import PursuitEvade
    from madrl_environments.pursuit.utils import *
    from vis import Visualizer
    import pprint

    map_mat = TwoDMaps.rectangle_map(16, 16)

    env = PursuitEvade([map_mat], n_evaders=30, n_pursuers=8, obs_range=7, n_catch=4,
                       surround=False, flatten=False)

    policy = PursuitHeuristicPolicy(env.agents[0].observation_space,
                                    env.agents[0].action_space)

    for i in range(20):
        rew = 0.0
        obs = env.reset()
        infolist = []
        for _ in range(500):
            # env.render()
            act_list = []
            for o in obs:
                a, _ = policy.sample_actions(o)
                act_list.append(a)
            obs, r, done, info = env.step(act_list)
            rew += np.mean(r)
            # Record per-step info and end the episode when the environment reports done.
            infolist.append(info)
            if done:
                break
        print('Episode {}: cumulative mean reward {}'.format(i, rew))