import torch
import torch.nn.functional as F
from types import SimpleNamespace

from garage.envs import GarageEnv
from garage.experiment.deterministic import set_seed

# Project-specific components (PredatorPreyWrapper, TrafficJunctionWrapper,
# LocalRunnerWrapper, DecCategoricalMLPPolicy, GaussianMLPBaseline, DICGCritic,
# CentralizedMAPPO, CentralizedMAOnPolicyVectorizedSampler) are imported from
# the accompanying project modules; their import paths are not shown in this
# excerpt.


def train_predatorprey(ctxt=None, args_dict=vars(args)):
    """Train decentralized categorical MLP policies with centralized MAPPO
    and a shared Gaussian MLP baseline on the Predator-Prey grid world."""
    args = SimpleNamespace(**args_dict)
    set_seed(args.seed)
    env = PredatorPreyWrapper(
        centralized=True,  # centralized training
        grid_shape=(args.grid_size, args.grid_size),
        n_agents=args.n_agents,
        n_preys=args.n_preys,
        max_steps=args.max_env_steps,
        step_cost=args.step_cost,
        prey_capture_reward=args.capture_reward,
        penalty=args.penalty,
        other_agent_visible=args.agent_visible)
    env = GarageEnv(env)

    runner = LocalRunnerWrapper(
        ctxt,
        eval=args.eval_during_training,
        n_eval_episodes=args.n_eval_episodes,
        eval_greedy=args.eval_greedy,
        eval_epoch_freq=args.eval_epoch_freq,
        save_env=env.pickleable)

    hidden_nonlinearity = F.relu if args.hidden_nonlinearity == 'relu' \
        else torch.tanh
    policy = DecCategoricalMLPPolicy(
        env.spec,
        env.n_agents,
        hidden_nonlinearity=hidden_nonlinearity,
        hidden_sizes=args.hidden_sizes,
        name='dec_categorical_mlp_policy')
    baseline = GaussianMLPBaseline(env_spec=env.spec,
                                   hidden_sizes=(64, 64, 64))

    # Set max_path_length <= max_steps.
    # If max_path_length > max_steps, the algo pads the observations:
    # obs.shape = torch.Size([n_paths, algo.max_path_length, feat_dim])
    algo = CentralizedMAPPO(
        env_spec=env.spec,
        policy=policy,
        baseline=baseline,
        max_path_length=args.max_env_steps,  # matches max_steps (see note above)
        discount=args.discount,
        center_adv=bool(args.center_adv),
        positive_adv=bool(args.positive_adv),
        gae_lambda=args.gae_lambda,
        policy_ent_coeff=args.ent,
        entropy_method=args.entropy_method,
        stop_entropy_gradient=(args.entropy_method == 'max'),
        optimization_n_minibatches=args.opt_n_minibatches,
        optimization_mini_epochs=args.opt_mini_epochs,
    )

    runner.setup(algo, env,
                 sampler_cls=CentralizedMAOnPolicyVectorizedSampler,
                 sampler_args={'n_envs': args.n_envs})
    runner.train(n_epochs=args.n_epochs, batch_size=args.bs)
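
# ---------------------------------------------------------------------------
# Hedged sketch (not shown in this excerpt): the module-level `args` object
# consumed via `args_dict=vars(args)` would normally come from an argparse
# parser defined *before* the trainer functions.  The flag names below are
# taken directly from the `args.<name>` accesses in train_predatorprey();
# every default value is an illustrative assumption, not the project's actual
# configuration.
# ---------------------------------------------------------------------------
import argparse

parser = argparse.ArgumentParser(description='Centralized MAPPO on Predator-Prey')
# experiment / training loop
parser.add_argument('--seed', type=int, default=1)
parser.add_argument('--n_epochs', type=int, default=500)
parser.add_argument('--bs', type=int, default=60000)          # samples per epoch
parser.add_argument('--n_envs', type=int, default=1)          # parallel sampler envs
# environment
parser.add_argument('--grid_size', type=int, default=10)
parser.add_argument('--n_agents', type=int, default=6)
parser.add_argument('--n_preys', type=int, default=2)
parser.add_argument('--max_env_steps', type=int, default=200)
parser.add_argument('--step_cost', type=float, default=-0.1)
parser.add_argument('--capture_reward', type=float, default=10.0)
parser.add_argument('--penalty', type=float, default=0.0)
parser.add_argument('--agent_visible', type=int, default=1)
# evaluation
parser.add_argument('--eval_during_training', type=int, default=1)
parser.add_argument('--n_eval_episodes', type=int, default=100)
parser.add_argument('--eval_greedy', type=int, default=1)
parser.add_argument('--eval_epoch_freq', type=int, default=5)
# policy / optimization
parser.add_argument('--hidden_nonlinearity', type=str, default='tanh')
parser.add_argument('--hidden_sizes', nargs='+', type=int, default=[128, 64, 32])
parser.add_argument('--discount', type=float, default=0.99)
parser.add_argument('--center_adv', type=int, default=1)
parser.add_argument('--positive_adv', type=int, default=0)
parser.add_argument('--gae_lambda', type=float, default=0.97)
parser.add_argument('--ent', type=float, default=0.01)        # entropy coefficient
parser.add_argument('--entropy_method', type=str, default='regularized')
parser.add_argument('--opt_n_minibatches', type=int, default=4)
parser.add_argument('--opt_mini_epochs', type=int, default=10)

args = parser.parse_args()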
def train_trafficjunction(ctxt=None, args_dict=vars(args)):
    """Train decentralized categorical MLP policies with centralized MAPPO
    and a DICG critic on Traffic Junction, optionally with a curriculum on
    the car add-rate."""
    args = SimpleNamespace(**args_dict)
    set_seed(args.seed)

    if args.curriculum:
        curr_start = int(0.125 * args.n_epochs)
        curr_end = int(0.625 * args.n_epochs)
    else:
        curr_start = 0
        curr_end = 0
        args.add_rate_min = args.add_rate_max

    env = TrafficJunctionWrapper(
        centralized=True,  # centralized training
        dim=args.dim,
        vision=1,
        add_rate_min=args.add_rate_min,
        add_rate_max=args.add_rate_max,
        curr_start=curr_start,
        curr_end=curr_end,
        difficulty=args.difficulty,
        n_agents=args.n_agents,
        max_steps=args.max_env_steps)
    env = GarageEnv(env)

    runner = LocalRunnerWrapper(
        ctxt,
        eval=args.eval_during_training,
        n_eval_episodes=args.n_eval_episodes,
        eval_greedy=args.eval_greedy,
        eval_epoch_freq=args.eval_epoch_freq,
        save_env=env.pickleable)

    hidden_nonlinearity = F.relu if args.hidden_nonlinearity == 'relu' \
        else torch.tanh
    policy = DecCategoricalMLPPolicy(
        env.spec,
        env.n_agents,
        hidden_nonlinearity=hidden_nonlinearity,
        hidden_sizes=args.policy_hidden_sizes,
        name='dec_categorical_mlp_policy')
    baseline = DICGCritic(
        env.spec,
        env.n_agents,
        encoder_hidden_sizes=args.encoder_hidden_sizes,
        embedding_dim=args.embedding_dim,
        attention_type=args.attention_type,
        n_gcn_layers=args.n_gcn_layers,
        residual=args.residual,
        gcn_bias=args.gcn_bias,
        name='dicg_critic')

    # Set max_path_length <= max_steps.
    # If max_path_length > max_steps, the algo pads the observations:
    # obs.shape = torch.Size([n_paths, algo.max_path_length, feat_dim])
    algo = CentralizedMAPPO(
        env_spec=env.spec,
        policy=policy,
        baseline=baseline,
        max_path_length=args.max_env_steps,  # matches max_steps (see note above)
        discount=args.discount,
        center_adv=bool(args.center_adv),
        positive_adv=bool(args.positive_adv),
        gae_lambda=args.gae_lambda,
        policy_ent_coeff=args.ent,
        entropy_method=args.entropy_method,
        stop_entropy_gradient=(args.entropy_method == 'max'),
        clip_grad_norm=args.clip_grad_norm,
        optimization_n_minibatches=args.opt_n_minibatches,
        optimization_mini_epochs=args.opt_mini_epochs,
    )

    runner.setup(algo, env,
                 sampler_cls=CentralizedMAOnPolicyVectorizedSampler,
                 sampler_args={'n_envs': args.n_envs})
    runner.train(n_epochs=args.n_epochs, batch_size=args.bs)
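
# ---------------------------------------------------------------------------
# Hedged sketch continuing the parser above: the Traffic Junction trainer
# additionally reads the flags below (names taken from its `args.<name>`
# accesses; defaults again illustrative).  Like the parser itself, these
# add_argument calls would live before the trainer definitions.  The direct
# call at the bottom is one possible launch path; the project's real entry
# point (e.g. an experiment launcher that supplies `ctxt`) is not shown in
# this excerpt.
# ---------------------------------------------------------------------------
parser.add_argument('--curriculum', type=int, default=0)
parser.add_argument('--add_rate_min', type=float, default=0.05)
parser.add_argument('--add_rate_max', type=float, default=0.2)
parser.add_argument('--dim', type=int, default=8)              # road grid dimension
parser.add_argument('--difficulty', type=str, default='easy')
parser.add_argument('--policy_hidden_sizes', nargs='+', type=int, default=[128, 64, 32])
parser.add_argument('--encoder_hidden_sizes', nargs='+', type=int, default=[128])
parser.add_argument('--embedding_dim', type=int, default=64)
parser.add_argument('--attention_type', type=str, default='general')
parser.add_argument('--n_gcn_layers', type=int, default=2)
parser.add_argument('--residual', type=int, default=1)
parser.add_argument('--gcn_bias', type=int, default=1)
parser.add_argument('--clip_grad_norm', type=float, default=7.0)

if __name__ == '__main__':
    # Minimal launch: call a trainer directly with the parsed flags.
    # With ctxt=None the LocalRunnerWrapper falls back to whatever default
    # snapshot/logging behaviour it implements (an assumption here).
    train_trafficjunction(ctxt=None, args_dict=vars(args))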