Example #1
def trainer_fn(use_dice, use_opp_modeling, epochs, batch_size, env_name, trace_length, gamma, grid_size,
               lr_inner, lr_outer, lr_value, lr_om, inner_asymm, n_agents, n_inner_steps, value_batch_size,
               value_epochs, om_batch_size, om_epochs, use_baseline, policy_maker, make_optim, **kwargs):
    # Instantiate the environment
    if env_name == "IPD":
        env = lola_dice_envs.IPD(max_steps=trace_length, batch_size=batch_size)
    elif env_name == "IMP":
        env = lola_dice_envs.IMP(trace_length)
    elif env_name == "CoinGame":
        env = lola_dice_envs.CG(trace_length, batch_size, grid_size)
        # Seed the environment from the current wall-clock time
        timestamp = datetime.now().timestamp()
        env.seed(int(timestamp))
    elif env_name == "AsymCoinGame":
        env = lola_dice_envs.AsymCG(trace_length, batch_size, grid_size)
        timestamp = datetime.now().timestamp()
        env.seed(int(timestamp))
    else:
        raise ValueError(f"env_name: {env_name}")

    train(env, policy_maker,
          make_optim,
          epochs=epochs,
          gamma=gamma,
          lr_inner=lr_inner,
          lr_outer=lr_outer,
          lr_value=lr_value,
          lr_om=lr_om,
          inner_asymm=inner_asymm,
          n_agents=n_agents,
          n_inner_steps=n_inner_steps,
          value_batch_size=value_batch_size,
          value_epochs=value_epochs,
          om_batch_size=om_batch_size,
          om_epochs=om_epochs,
          use_baseline=use_baseline,
          use_dice=use_dice,
          use_opp_modeling=use_opp_modeling)
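
A minimal usage sketch for the trainer_fn above. The hyperparameter values and the policy/optimizer factory names (make_simple_policy, make_adam_optimizer) are illustrative assumptions expected to come from the surrounding module, not defaults taken from this code.

# Hypothetical invocation; all values and factory names below are assumptions.
trainer_fn(
    use_dice=True,
    use_opp_modeling=False,
    epochs=200,
    batch_size=64,
    env_name="IPD",
    trace_length=150,
    gamma=0.96,
    grid_size=3,  # only used by the CoinGame variants
    lr_inner=0.1,
    lr_outer=0.2,
    lr_value=0.1,
    lr_om=0.1,
    inner_asymm=True,
    n_agents=2,
    n_inner_steps=2,
    value_batch_size=16,
    value_epochs=0,
    om_batch_size=16,
    om_epochs=0,
    use_baseline=False,
    policy_maker=make_simple_policy,  # assumed policy factory
    make_optim=make_adam_optimizer,  # assumed optimizer factory
)
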
Example #2
def trainer_fn(exp_name, num_episodes, trace_length, exact, pseudo, grid_size,
               lr, lr_correction, batch_size, bs_mul, simple_net, hidden, reg,
               gamma, lola_update, opp_model, mem_efficient, seed, set_zero,
               warmup, changed_config, ac_lr, summary_len, use_MAE,
               use_toolbox_env, clip_lola_update_norm, clip_loss_norm,
               entropy_coeff, weigth_decay, **kwargs):
    # Instantiate the environment
    if exp_name == "IPD":
        raise NotImplementedError()
    elif exp_name == "IMP":
        raise NotImplementedError()
    elif exp_name == "CoinGame":
        if use_toolbox_env:
            env = VectorizedCoinGame(
                config={
                    "batch_size": batch_size,
                    "max_steps": trace_length,
                    "grid_size": grid_size,
                    "get_additional_info": True,
                    "add_position_in_epi": False,
                })
        else:
            env = lola_dice_envs.CG(trace_length, batch_size, grid_size)
        env.seed(seed)
    elif exp_name == "AsymCoinGame":
        if use_toolbox_env:
            env = AsymVectorizedCoinGame(
                config={
                    "batch_size": batch_size,
                    "max_steps": trace_length,
                    "grid_size": grid_size,
                    "get_additional_info": True,
                    "add_position_in_epi": False,
                })
        else:
            env = lola_dice_envs.AsymCG(trace_length, batch_size, grid_size)
        env.seed(seed)
    else:
        raise ValueError(f"exp_name: {exp_name}")

    # Import the right training function
    if exact:
        raise NotImplementedError()
    elif exp_name in ("IPD", "IMP"):
        train_pg.train(env,
                       num_episodes=num_episodes,
                       trace_length=trace_length,
                       batch_size=batch_size,
                       gamma=gamma,
                       set_zero=set_zero,
                       lr=lr,
                       corrections=lola_update,
                       simple_net=simple_net,
                       hidden=hidden,
                       mem_efficient=mem_efficient)
    elif exp_name in ("CoinGame", "AsymCoinGame"):
        train_cg.train(
            env,
            num_episodes=num_episodes,
            trace_length=trace_length,
            batch_size=batch_size,
            bs_mul=bs_mul,
            gamma=gamma,
            grid_size=grid_size,
            lr=lr,
            corrections=lola_update,
            opp_model=opp_model,
            hidden=hidden,
            mem_efficient=mem_efficient,
            asymmetry=(exp_name == "AsymCoinGame"),
            warmup=warmup,
            changed_config=changed_config,
            ac_lr=ac_lr,
            summary_len=summary_len,
            use_MAE=use_MAE,
            use_toolbox_env=use_toolbox_env,
            clip_lola_update_norm=clip_lola_update_norm,
            clip_loss_norm=clip_loss_norm,
            entropy_coeff=entropy_coeff,
            weigth_decay=weigth_decay,
        )
    else:
        raise ValueError(f"exp_name: {exp_name}")
    def _init_lola(self, *, env, make_policy, make_optimizer, epochs,
                   batch_size, trace_length, grid_size, gamma, lr_inner,
                   lr_outer, lr_value, lr_om, inner_asymm, n_agents,
                   n_inner_steps, value_batch_size, value_epochs,
                   om_batch_size, om_epochs, use_baseline, use_dice,
                   use_opp_modeling, seed, **kwargs):

        print("args not used:", kwargs)

        # Instantiate the environment
        if env == "IPD":
            self.env = lola_dice_envs.IPD(max_steps=trace_length,
                                          batch_size=batch_size)
        elif env == "AsymBoS":
            self.env = lola_dice_envs.AsymBoS(max_steps=trace_length,
                                              batch_size=batch_size)
        elif env == "IMP":
            self.env = lola_dice_envs.IMP(trace_length)
        elif env == "CoinGame":
            self.env = lola_dice_envs.CG(trace_length, batch_size, grid_size)
            self.env.seed(int(seed))
        elif env == "AsymCoinGame":
            self.env = lola_dice_envs.AsymCG(trace_length, batch_size,
                                             grid_size)
            self.env.seed(int(seed))
        else:
            raise ValueError(f"env: {env}")

        self.gamma = gamma
        self.lr_inner = lr_inner
        self.lr_outer = lr_outer
        self.lr_value = lr_value
        self.lr_om = lr_om
        self.inner_asymm = inner_asymm
        self.n_agents = n_agents
        self.n_inner_steps = n_inner_steps
        self.value_batch_size = value_batch_size
        self.value_epochs = value_epochs
        self.om_batch_size = om_batch_size
        self.om_epochs = om_epochs
        self.use_baseline = use_baseline
        self.use_dice = use_dice
        self.use_opp_modeling = use_opp_modeling
        self.timestep = 0

        if make_policy[0] == "make_simple_policy":
            make_policy = functools.partial(make_simple_policy,
                                            **make_policy[1])
        elif make_policy[0] == "make_conv_policy":
            make_policy = functools.partial(make_conv_policy, **make_policy[1])
        elif make_policy[0] == "make_mlp_policy":
            make_policy = functools.partial(make_mlp_policy, **make_policy[1])
        else:
            raise NotImplementedError(f"make_policy: {make_policy[0]}")

        if make_optimizer[0] == "make_adam_optimizer":
            make_optimizer = functools.partial(make_adam_optimizer,
                                               **make_optimizer[1])
        elif make_optimizer[0] == "make_sgd_optimizer":
            make_optimizer = functools.partial(make_sgd_optimizer,
                                               **make_optimizer[1])
        else:
            raise NotImplementedError(f"make_optimizer: {make_optimizer[0]}")

        # Build.
        graph = tf.Graph()
        with graph.as_default():

            (self.policies, self.rollout_policies, pol_losses, val_losses,
             om_losses, update_pol_ops, update_val_ops,
             update_om_ops) = build_graph(
                 self.env,
                 make_policy,
                 make_optimizer,
                 lr_inner=lr_inner,
                 lr_outer=lr_outer,
                 lr_value=lr_value,
                 lr_om=lr_om,
                 n_agents=self.n_agents,
                 n_inner_steps=n_inner_steps,
                 use_baseline=use_baseline,
                 use_dice=use_dice,
                 use_opp_modeling=self.use_opp_modeling,
                 inner_asymm=inner_asymm)

            # Train.
            self.acs_all = []
            self.rets_all = []
            self.params_all = []
            self.params_om_all = []
            self.times_all = []
            self.pick_speed_all = []

            self.sess = tf.Session()
            self.sess.run(tf.global_variables_initializer())

            # Construct update functions.
            self.update_funcs = {
                'policy': [
                    get_update([self.policies[k]] + self.policies[k].opponents,
                               pol_losses[k],
                               update_pol_ops[k],
                               self.sess,
                               gamma=self.gamma) for k in range(self.n_agents)
                ],
                'value': [
                    get_update([self.policies[k]],
                               val_losses[k],
                               update_val_ops[k],
                               self.sess,
                               gamma=self.gamma) for k in range(self.n_agents)
                ],
                'opp': [
                    get_update(self.policies[k].root.opponents,
                               om_losses[k],
                               update_om_ops[k],
                               self.sess,
                               gamma=self.gamma) for k in range(self.n_agents)
                ] if om_losses else None,
            }

            self.root_policies = [pi.root for pi in self.policies]

            self.saver = tf.train.Saver(max_to_keep=5)
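
A sketch of how the tuple-style make_policy and make_optimizer arguments of _init_lola might be supplied. The ("factory_name", kwargs_dict) structure is dictated by the dispatch code above; everything else (the trainer instance, the hyperparameter values, the empty factory kwargs) is a hypothetical assumption.

# Hypothetical call on an instance of the enclosing (unshown) trainer class.
trainer._init_lola(
    env="IPD",
    make_policy=("make_simple_policy", {}),  # dispatched via functools.partial
    make_optimizer=("make_adam_optimizer", {}),
    epochs=200,
    batch_size=64,
    trace_length=150,
    grid_size=3,
    gamma=0.96,
    lr_inner=0.1,
    lr_outer=0.2,
    lr_value=0.1,
    lr_om=0.1,
    inner_asymm=True,
    n_agents=2,
    n_inner_steps=2,
    value_batch_size=16,
    value_epochs=0,
    om_batch_size=16,
    om_epochs=0,
    use_baseline=False,
    use_dice=True,
    use_opp_modeling=False,
    seed=42,
)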