Example 1
        def train_predatorprey(ctxt=None, args_dict=vars(args)):
            args = SimpleNamespace(**args_dict)
            set_seed(args.seed)
            env = PredatorPreyWrapper(
                centralized=True,  # centralized training
                grid_shape=(args.grid_size, args.grid_size),
                n_agents=args.n_agents,
                n_preys=args.n_preys,
                max_steps=args.max_env_steps,
                step_cost=args.step_cost,
                prey_capture_reward=args.capture_reward,
                penalty=args.penalty,
                other_agent_visible=args.agent_visible)
            env = GarageEnv(env)

            runner = LocalRunnerWrapper(ctxt,
                                        eval=args.eval_during_training,
                                        n_eval_episodes=args.n_eval_episodes,
                                        eval_greedy=args.eval_greedy,
                                        eval_epoch_freq=args.eval_epoch_freq,
                                        save_env=env.pickleable)

            hidden_nonlinearity = F.relu if args.hidden_nonlinearity == 'relu' \
                                    else torch.tanh
            policy = DecCategoricalMLPPolicy(
                env.spec,
                env.n_agents,
                hidden_nonlinearity=hidden_nonlinearity,
                hidden_sizes=args.hidden_sizes,
                name='dec_categorical_mlp_policy')

            baseline = GaussianMLPBaseline(env_spec=env.spec,
                                           hidden_sizes=(64, 64, 64))

            # Set max_path_length <= max_steps
            # If max_path_length > max_steps, algo will pad obs
            # obs.shape = torch.Size([n_paths, algo.max_path_length, feat_dim])
            algo = CentralizedMAPPO(
                env_spec=env.spec,
                policy=policy,
                baseline=baseline,
                max_path_length=args.max_env_steps,  # equal to the env's max_steps (see note above)
                discount=args.discount,
                center_adv=bool(args.center_adv),
                positive_adv=bool(args.positive_adv),
                gae_lambda=args.gae_lambda,
                policy_ent_coeff=args.ent,
                entropy_method=args.entropy_method,
                stop_entropy_gradient=(args.entropy_method == 'max'),
                optimization_n_minibatches=args.opt_n_minibatches,
                optimization_mini_epochs=args.opt_mini_epochs,
            )

            runner.setup(algo,
                         env,
                         sampler_cls=CentralizedMAOnPolicyVectorizedSampler,
                         sampler_args={'n_envs': args.n_envs})
            runner.train(n_epochs=args.n_epochs, batch_size=args.bs)
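
As written, the function relies on module-level state that the excerpt does not show: an args namespace (typically built with argparse) and imports for SimpleNamespace, torch / F, the garage helpers (GarageEnv, set_seed), and the project-specific classes (PredatorPreyWrapper, LocalRunnerWrapper, DecCategoricalMLPPolicy, GaussianMLPBaseline, CentralizedMAPPO, CentralizedMAOnPolicyVectorizedSampler). The sketch below shows one plausible way to wire it up and launch it, assuming garage's wrap_experiment decorator supplies the ctxt snapshot context; the decorator arguments and the argparse setup are illustrative, and the project's actual entry point may differ.

    import argparse
    from types import SimpleNamespace

    import torch
    import torch.nn.functional as F
    from garage import wrap_experiment
    from garage.envs import GarageEnv
    from garage.experiment.deterministic import set_seed

    parser = argparse.ArgumentParser()
    # ... add_argument calls for seed, grid_size, n_agents, n_preys, etc. (omitted)
    args = parser.parse_args()

    @wrap_experiment(snapshot_mode='last')  # illustrative decorator arguments
    def train_predatorprey(ctxt=None, args_dict=vars(args)):
        ...  # body as in the example above

    train_predatorprey(args_dict=vars(args))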
Example 2
        def train_trafficjunction(ctxt=None, args_dict=vars(args)):
            args = SimpleNamespace(**args_dict)

            set_seed(args.seed)

            if args.curriculum:
                curr_start = int(0.125 * args.n_epochs)
                curr_end = int(0.625 * args.n_epochs)
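                # Curriculum window (e.g. epochs 125-625 when n_epochs = 1000):
                # the env is expected to ramp its vehicle add rate from
                # add_rate_min up to add_rate_max between curr_start and curr_end.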
            else:
                curr_start = 0
                curr_end = 0
                args.add_rate_min = args.add_rate_max

            env = TrafficJunctionWrapper(centralized=True,
                                         dim=args.dim,
                                         vision=1,
                                         add_rate_min=args.add_rate_min,
                                         add_rate_max=args.add_rate_max,
                                         curr_start=curr_start,
                                         curr_end=curr_end,
                                         difficulty=args.difficulty,
                                         n_agents=args.n_agents,
                                         max_steps=args.max_env_steps)
            env = GarageEnv(env)

            runner = LocalRunnerWrapper(ctxt,
                                        eval=args.eval_during_training,
                                        n_eval_episodes=args.n_eval_episodes,
                                        eval_greedy=args.eval_greedy,
                                        eval_epoch_freq=args.eval_epoch_freq,
                                        save_env=env.pickleable)

            hidden_nonlinearity = F.relu if args.hidden_nonlinearity == 'relu' \
                                    else torch.tanh

            policy = DecCategoricalMLPPolicy(
                env.spec,
                env.n_agents,
                hidden_nonlinearity=hidden_nonlinearity,
                hidden_sizes=args.policy_hidden_sizes,
                name='dec_categorical_mlp_policy')

            baseline = DICGCritic(
                env.spec,
                env.n_agents,
                encoder_hidden_sizes=args.encoder_hidden_sizes,
                embedding_dim=args.embedding_dim,
                attention_type=args.attention_type,
                n_gcn_layers=args.n_gcn_layers,
                residual=args.residual,
                gcn_bias=args.gcn_bias,
                name='dicg_critic')

            # Set max_path_length <= max_steps
            # If max_path_length > max_steps, algo will pad obs
            # obs.shape = torch.Size([n_paths, algo.max_path_length, feat_dim])
            algo = CentralizedMAPPO(
                env_spec=env.spec,
                policy=policy,
                baseline=baseline,
                max_path_length=args.max_env_steps,  # equal to the env's max_steps (see note above)
                discount=args.discount,
                center_adv=bool(args.center_adv),
                positive_adv=bool(args.positive_adv),
                gae_lambda=args.gae_lambda,
                policy_ent_coeff=args.ent,
                entropy_method=args.entropy_method,
                stop_entropy_gradient=(args.entropy_method == 'max'),
                clip_grad_norm=args.clip_grad_norm,
                optimization_n_minibatches=args.opt_n_minibatches,
                optimization_mini_epochs=args.opt_mini_epochs,
            )

            runner.setup(algo,
                         env,
                         sampler_cls=CentralizedMAOnPolicyVectorizedSampler,
                         sampler_args={'n_envs': args.n_envs})
            runner.train(n_epochs=args.n_epochs, batch_size=args.bs)
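
Compared with Example 1, this variant replaces the per-agent GaussianMLPBaseline with the DICGCritic which, judging by its constructor arguments, embeds each agent's observation, builds an attention-weighted coordination graph over the embeddings, and refines them through n_gcn_layers graph-convolution layers (optionally with residual connections) before estimating values. It also enables gradient-norm clipping via clip_grad_norm and, when args.curriculum is set, passes the environment the curriculum window computed at the top of the function; the rest of the training setup is identical to Example 1.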