예제 #1
0
        def train_smac(ctxt=None, args_dict=vars(args)):
            """Train a centralized MAPPO agent with an LSTM policy on SMAC.

            Args:
                ctxt: Experiment context handed to ``LocalRunnerWrapper``
                    (presumably injected by an experiment decorator outside
                    this view -- confirm).
                args_dict: Experiment settings. The default ``vars(args)`` is
                    evaluated once, when this function is defined; it is
                    copied into a ``SimpleNamespace`` so the settings can be
                    read as attributes below.
            """
            args = SimpleNamespace(**args_dict)

            # NOTE(review): this function does not call set_seed(args.seed)
            # and the env seed kwarg is commented out, so the run is not
            # explicitly seeded here -- confirm seeding happens elsewhere.
            env = SMACWrapper(
                centralized=True,  # important, using centralized sampler
                map_name=args.map,
                difficulty=args.difficulty,
                # seed=args.seed
            )
            env = GarageEnv(env)

            runner = LocalRunnerWrapper(ctxt,
                                        eval=args.eval_during_training,
                                        n_eval_episodes=args.n_eval_episodes,
                                        eval_greedy=args.eval_greedy,
                                        eval_epoch_freq=args.eval_epoch_freq,
                                        save_env=env.pickleable)

            # NOTE(review): args.hidden_nonlinearity is not forwarded to the
            # LSTM policy; confirm whether DecCategoricalLSTMPolicy should
            # receive a nonlinearity.
            policy = DecCategoricalLSTMPolicy(
                env.spec,
                n_agents=env.n_agents,
                encoder_hidden_sizes=args.encoder_hidden_sizes,
                embedding_dim=args.embedding_dim,  # encoder output size
                lstm_hidden_size=args.lstm_hidden_size,
                state_include_actions=args.state_include_actions,
                name='dec_categorical_lstm_policy')

            baseline = GaussianMLPBaseline(env_spec=env.spec,
                                           hidden_sizes=(64, 64, 64))

            # Set max_path_length <= max_steps
            # If max_path_length > max_steps, algo will pad obs
            # obs.shape = torch.Size([n_paths, algo.max_path_length, feat_dim])
            algo = CentralizedMAPPO(
                env_spec=env.spec,
                policy=policy,
                baseline=baseline,
                # Tied to the env's own episode limit (see note above).
                max_path_length=env.episode_limit,
                discount=args.discount,
                center_adv=bool(args.center_adv),
                positive_adv=bool(args.positive_adv),
                gae_lambda=args.gae_lambda,
                policy_ent_coeff=args.ent,
                entropy_method=args.entropy_method,
                # Stop the entropy gradient only for the 'max' entropy method.
                stop_entropy_gradient=(args.entropy_method == 'max'),
                clip_grad_norm=args.clip_grad_norm,
                optimization_n_minibatches=args.opt_n_minibatches,
                optimization_mini_epochs=args.opt_mini_epochs,
            )

            runner.setup(algo,
                         env,
                         sampler_cls=CentralizedMAOnPolicyVectorizedSampler,
                         sampler_args={'n_envs': args.n_envs})
            runner.train(n_epochs=args.n_epochs, batch_size=args.bs)
예제 #2
0
        def train_predatorprey(ctxt=None, args_dict=vars(args)):
            """Train a centralized MAPPO agent with an MLP policy on PredatorPrey.

            Args:
                ctxt: Experiment context handed to ``LocalRunnerWrapper``
                    (presumably injected by an experiment decorator outside
                    this view -- confirm).
                args_dict: Experiment settings. The default ``vars(args)`` is
                    evaluated once, when this function is defined; it is
                    copied into a ``SimpleNamespace`` so the settings can be
                    read as attributes below.
            """
            args = SimpleNamespace(**args_dict)
            set_seed(args.seed)

            env = PredatorPreyWrapper(
                centralized=True,  # centralized training
                grid_shape=(args.grid_size, args.grid_size),
                n_agents=args.n_agents,
                n_preys=args.n_preys,
                max_steps=args.max_env_steps,
                step_cost=args.step_cost,
                prey_capture_reward=args.capture_reward,
                penalty=args.penalty,
                other_agent_visible=args.agent_visible)
            env = GarageEnv(env)

            runner = LocalRunnerWrapper(ctxt,
                                        eval=args.eval_during_training,
                                        n_eval_episodes=args.n_eval_episodes,
                                        eval_greedy=args.eval_greedy,
                                        eval_epoch_freq=args.eval_epoch_freq,
                                        save_env=env.pickleable)

            # Any value other than 'relu' selects tanh.
            hidden_nonlinearity = F.relu if args.hidden_nonlinearity == 'relu' \
                                    else torch.tanh
            policy = DecCategoricalMLPPolicy(
                env.spec,
                env.n_agents,
                hidden_nonlinearity=hidden_nonlinearity,
                hidden_sizes=args.hidden_sizes,
                name='dec_categorical_mlp_policy')

            baseline = GaussianMLPBaseline(env_spec=env.spec,
                                           hidden_sizes=(64, 64, 64))

            # Set max_path_length <= max_steps
            # If max_path_length > max_steps, algo will pad obs
            # obs.shape = torch.Size([n_paths, algo.max_path_length, feat_dim])
            # NOTE(review): clip_grad_norm is not passed here, so presumably
            # CentralizedMAPPO's default applies -- confirm this is intended.
            algo = CentralizedMAPPO(
                env_spec=env.spec,
                policy=policy,
                baseline=baseline,
                # Same value as the env's max_steps (see note above).
                max_path_length=args.max_env_steps,
                discount=args.discount,
                center_adv=bool(args.center_adv),
                positive_adv=bool(args.positive_adv),
                gae_lambda=args.gae_lambda,
                policy_ent_coeff=args.ent,
                entropy_method=args.entropy_method,
                # Stop the entropy gradient only for the 'max' entropy method.
                stop_entropy_gradient=(args.entropy_method == 'max'),
                optimization_n_minibatches=args.opt_n_minibatches,
                optimization_mini_epochs=args.opt_mini_epochs,
            )

            runner.setup(algo,
                         env,
                         sampler_cls=CentralizedMAOnPolicyVectorizedSampler,
                         sampler_args={'n_envs': args.n_envs})
            runner.train(n_epochs=args.n_epochs, batch_size=args.bs)
예제 #3
0
파일: runner_utils.py 프로젝트: sisl/DICG
def restore_training(log_dir, exp_name, args, env_saved=True, env=None):
    """Resume a training run from the last snapshot stored in ``log_dir``.

    Registers fresh dowel outputs (text, CSV, TensorBoard, stdout) whose file
    names carry the restore time and hostname. It then reloads runner state
    and the algorithm from the ``'last'`` snapshot and continues training for
    ``args.n_epochs`` additional epochs.

    Args:
        log_dir: Directory holding the snapshot; the new log files are
            written here too.
        exp_name: Label pushed as the logger prefix.
        args: Namespace read for ``eval_during_training``,
            ``n_eval_episodes``, ``eval_greedy``, ``eval_epoch_freq``,
            ``clip_grad_norm`` and ``n_epochs``.
        env_saved: If True, the environment is taken from the snapshot and
            ``env`` is ignored.
        env: Environment to train on when ``env_saved`` is False.

    Returns:
        Whatever ``runner._algo.train(runner)`` returns.

    Raises:
        ValueError: If ``env_saved`` is False and ``env`` is None.
    """
    # Fail before touching the logger: without a saved or supplied env there
    # is nothing to hand to runner.setup().
    if not env_saved and env is None:
        raise ValueError('env must be provided when env_saved is False')

    # Take the timestamp and hostname once so both log files share a tag.
    stamp = str(int(time.time()))  # whole seconds since the epoch
    host = socket.gethostname()
    tabular_log_file = os.path.join(
        log_dir, 'progress_restored.{}.{}.csv'.format(stamp, host))
    text_log_file = os.path.join(
        log_dir, 'debug_restored.{}.{}.log'.format(stamp, host))
    logger.add_output(dowel.TextOutput(text_log_file))
    logger.add_output(dowel.CsvOutput(tabular_log_file))
    logger.add_output(dowel.TensorBoardOutput(log_dir))
    logger.add_output(dowel.StdOutput())
    logger.push_prefix('[%s] ' % exp_name)

    ctxt = ExperimentContext(snapshot_dir=log_dir,
                             snapshot_mode='last',
                             snapshot_gap=1)

    runner = LocalRunnerWrapper(ctxt,
                                eval=args.eval_during_training,
                                n_eval_episodes=args.n_eval_episodes,
                                eval_greedy=args.eval_greedy,
                                eval_epoch_freq=args.eval_epoch_freq,
                                save_env=env_saved)
    saved = runner._snapshotter.load(log_dir, 'last')
    runner._setup_args = saved['setup_args']
    runner._train_args = saved['train_args']
    runner._stats = saved['stats']

    # Re-seed with the seed recorded in the snapshot's setup args.
    set_seed(runner._setup_args.seed)
    algo = saved['algo']

    # Compatibility patch: snapshots written before the algorithm had a
    # ``_clip_grad_norm`` attribute lack it, so fill it in from args.
    if not hasattr(algo, '_clip_grad_norm'):
        setattr(algo, '_clip_grad_norm', args.clip_grad_norm)

    if env_saved:
        env = saved['env']

    runner.setup(env=env,
                 algo=algo,
                 sampler_cls=runner._setup_args.sampler_cls,
                 sampler_args=runner._setup_args.sampler_args)
    # Resume at the epoch after the last recorded one; n_epochs is stored as
    # an absolute end value, not a count.
    runner._train_args.start_epoch = runner._stats.total_epoch + 1
    runner._train_args.n_epochs = runner._train_args.start_epoch + args.n_epochs

    print('\nRestored checkpoint from epoch #{}...'.format(
        runner._train_args.start_epoch))
    print('To be trained for additional {} epochs...'.format(args.n_epochs))
    print('Will be finished at epoch #{}...\n'.format(
        runner._train_args.n_epochs))

    return runner._algo.train(runner)
예제 #4
0
        def train_predatorprey(ctxt=None, args_dict=vars(args)):
            """Train centralized MAPPO with a DICG critic on TrafficJunction.

            Despite its name, this function builds a ``TrafficJunctionWrapper``
            environment. The name is kept so existing callers keep working.

            Args:
                ctxt: Experiment context handed to ``LocalRunnerWrapper``
                    (presumably injected by an experiment decorator outside
                    this view -- confirm).
                args_dict: Experiment settings. The default ``vars(args)`` is
                    evaluated once, when this function is defined; it is
                    copied into a ``SimpleNamespace``, so the assignment to
                    ``args.add_rate_min`` below only affects the local copy.
            """
            args = SimpleNamespace(**args_dict)

            set_seed(args.seed)

            if args.curriculum:
                # Curriculum window at 12.5% .. 62.5% of training (presumably
                # the wrapper anneals the add rate over it -- confirm).
                curr_start = int(0.125 * args.n_epochs)
                curr_end = int(0.625 * args.n_epochs)
            else:
                # No curriculum: zero-length window and min == max add rate.
                curr_start = 0
                curr_end = 0
                args.add_rate_min = args.add_rate_max

            env = TrafficJunctionWrapper(centralized=True,
                                         dim=args.dim,
                                         vision=1,
                                         add_rate_min=args.add_rate_min,
                                         add_rate_max=args.add_rate_max,
                                         curr_start=curr_start,
                                         curr_end=curr_end,
                                         difficulty=args.difficulty,
                                         n_agents=args.n_agents,
                                         max_steps=args.max_env_steps)
            env = GarageEnv(env)

            runner = LocalRunnerWrapper(ctxt,
                                        eval=args.eval_during_training,
                                        n_eval_episodes=args.n_eval_episodes,
                                        eval_greedy=args.eval_greedy,
                                        eval_epoch_freq=args.eval_epoch_freq,
                                        save_env=env.pickleable)

            # Any value other than 'relu' selects tanh.
            hidden_nonlinearity = F.relu if args.hidden_nonlinearity == 'relu' \
                                    else torch.tanh

            policy = DecCategoricalMLPPolicy(
                env.spec,
                env.n_agents,
                hidden_nonlinearity=hidden_nonlinearity,
                hidden_sizes=args.policy_hidden_sizes,
                name='dec_categorical_mlp_policy')

            baseline = DICGCritic(
                env.spec,
                env.n_agents,
                encoder_hidden_sizes=args.encoder_hidden_sizes,
                embedding_dim=args.embedding_dim,
                attention_type=args.attention_type,
                n_gcn_layers=args.n_gcn_layers,
                residual=args.residual,
                gcn_bias=args.gcn_bias,
                name='dicg_critic')

            # Set max_path_length <= max_steps
            # If max_path_length > max_steps, algo will pad obs
            # obs.shape = torch.Size([n_paths, algo.max_path_length, feat_dim])
            algo = CentralizedMAPPO(
                env_spec=env.spec,
                policy=policy,
                baseline=baseline,
                # Same value as the env's max_steps (see note above).
                max_path_length=args.max_env_steps,
                discount=args.discount,
                center_adv=bool(args.center_adv),
                positive_adv=bool(args.positive_adv),
                gae_lambda=args.gae_lambda,
                policy_ent_coeff=args.ent,
                entropy_method=args.entropy_method,
                # Stop the entropy gradient only for the 'max' entropy method.
                stop_entropy_gradient=(args.entropy_method == 'max'),
                clip_grad_norm=args.clip_grad_norm,
                optimization_n_minibatches=args.opt_n_minibatches,
                optimization_mini_epochs=args.opt_mini_epochs,
            )

            runner.setup(algo,
                         env,
                         sampler_cls=CentralizedMAOnPolicyVectorizedSampler,
                         sampler_args={'n_envs': args.n_envs})
            runner.train(n_epochs=args.n_epochs, batch_size=args.bs)