# Example #1
0
def main(
    envname: str = "CartPole-v0",
    tau: float = 12 * 20,
    update_freq: int = 10,
) -> rainy.Config:
    """Build an ACKTR-style configuration for a classic-control environment.

    Args:
        envname: Name handed to ``ClassicControl`` when creating an environment.
        tau: Forwarded as ``tau`` to ``kfac.KfacPreConditioner``.
        update_freq: Forwarded as ``update_freq`` to ``kfac.KfacPreConditioner``.

    Returns:
        The populated ``rainy.Config``.
    """
    # The optimizer and the Fisher norm scaler both receive this eta_max;
    # one name keeps the two values from drifting apart.
    eta_max = 0.1

    def make_preconditioner(net):
        # Invoked with the network to precondition (presumably by rainy).
        # A new norm scaler is constructed on every invocation.
        return kfac.KfacPreConditioner(
            net,
            tau=tau,
            update_freq=update_freq,
            norm_scaler=kfac.SquaredFisherScaler(eta_max=eta_max, delta=0.001),
        )

    cfg = rainy.Config()
    cfg.set_env(lambda: ClassicControl(envname))
    cfg.max_steps = int(4e5)
    # nworkers * nsteps == 12 * 20, the same product used as the default tau.
    cfg.nworkers = 12
    cfg.nsteps = 20
    cfg.set_parallel_env(MultiProcEnv)
    cfg.set_optimizer(kfac.default_sgd(eta_max=eta_max))
    cfg.set_preconditioner(make_preconditioner)
    # NOTE(review): gae_lambda is assigned although use_gae is False;
    # presumably it is ignored in that case — confirm.
    cfg.gae_lambda = 0.95
    cfg.use_gae = False
    cfg.lr_min = 0.0
    cfg.value_loss_weight = 0.2
    cfg.entropy_weight = 0.01
    cfg.eval_freq = None
    return cfg
# Example #2
0
def main(
    envname: str = "Hopper",
    tau: float = 12 * 20,
    update_freq: int = 10,
) -> Config:
    """Build an ACKTR-style configuration for a PyBullet environment.

    Args:
        envname: Name handed to ``PyBullet`` when creating an environment.
        tau: Forwarded as ``tau`` to ``kfac.KfacPreConditioner``.
        update_freq: Forwarded as ``update_freq`` to ``kfac.KfacPreConditioner``.

    Returns:
        The populated ``Config``.
    """
    # Shared by the SGD optimizer and the Fisher norm scaler.
    eta_max = 0.1

    def make_preconditioner(net):
        # NOTE: this parameter shadows the outer name ``net`` that is used
        # below as ``net.actor_critic``; here it is the network argument.
        return kfac.KfacPreConditioner(
            net,
            tau=tau,
            update_freq=update_freq,
            norm_scaler=kfac.SquaredFisherScaler(eta_max=eta_max, delta=0.001),
        )

    cfg = Config()
    cfg.max_steps = int(4e5)
    # nworkers * nsteps == 12 * 20, the same product used as the default tau.
    cfg.nworkers = 12
    cfg.nsteps = 20
    cfg.set_env(lambda: PyBullet(envname))
    net_fn = net.actor_critic.fc_shared(policy=SeparateStdGaussianDist)
    cfg.set_net_fn("actor-critic", net_fn)
    cfg.set_parallel_env(pybullet_parallel())
    cfg.set_optimizer(kfac.default_sgd(eta_max=eta_max))
    cfg.set_preconditioner(make_preconditioner)
    cfg.gae_lambda = 0.95
    cfg.use_gae = True
    cfg.eval_deterministic = False
    cfg.value_loss_weight = 0.5
    cfg.entropy_weight = 0.0
    cfg.eval_freq = None
    return cfg
# Example #3
0
def main(
    envname: str = "Breakout",
    tau: float = 32 * 20 // 2,
    update_freq: int = 10,
) -> rainy.Config:
    """Build an ACKTR-style configuration for an Atari game.

    Args:
        envname: Game name handed to ``Atari`` for both training and eval envs.
        tau: Forwarded as ``tau`` to ``kfac.KfacPreConditioner``; the default
            is half of nworkers * nsteps (32 * 20) as set below.
        update_freq: Forwarded as ``update_freq`` to ``kfac.KfacPreConditioner``.

    Returns:
        The populated ``rainy.Config``.
    """
    # Shared by the SGD optimizer and the Fisher norm scaler.
    eta_max = 0.2

    def make_preconditioner(net):
        # A new norm scaler is constructed on every invocation.
        return kfac.KfacPreConditioner(
            net,
            tau=tau,
            update_freq=update_freq,
            norm_scaler=kfac.SquaredFisherScaler(eta_max=eta_max, delta=0.001),
        )

    cfg = rainy.Config()
    # NOTE(review): training envs pass frame_stack=False while eval_env below
    # uses Atari's default; presumably atari_parallel() stacks frames for the
    # workers — confirm.
    cfg.set_env(lambda: Atari(envname, frame_stack=False))
    cfg.set_optimizer(kfac.default_sgd(eta_max=eta_max))
    cfg.set_preconditioner(make_preconditioner)
    cfg.set_net_fn("actor-critic", rainy.net.actor_critic.conv_shared())
    cfg.nworkers = 32
    cfg.nsteps = 20
    cfg.set_parallel_env(atari_parallel())
    cfg.value_loss_weight = 1.0
    cfg.use_gae = True
    cfg.lr_min = 0.0
    cfg.max_steps = int(2e7)
    # Constructed eagerly here, unlike the training env factory above.
    cfg.eval_env = Atari(envname)
    cfg.eval_freq = None
    cfg.episode_log_freq = 100
    cfg.eval_deterministic = False
    return cfg
# Example #4
0
import os
from rainy import Config
from rainy.agents import AcktrAgent
import rainy.utils.cli as cli
from rainy.envs import MultiProcEnv
from rainy.lib import kfac

# Keyword arguments splatted into kfac.KfacPreConditioner by config().
# tau = 12 * 20 is the same product as nworkers * nsteps in config(), and the
# scaler's eta_max (0.1) matches the eta_max passed to kfac.default_sgd there.
# NOTE: norm_scaler is a single instance built at import time, so every
# preconditioner created from this dict receives the same object.
KFAC_KWARGS = dict(
    tau=12 * 20,
    update_freq=10,
    norm_scaler=kfac.SquaredFisherScaler(eta_max=0.1, delta=0.001),
)


def config() -> Config:
    """Build the ACKTR-style configuration used by this script.

    The K-FAC preconditioner is created from the module-level ``KFAC_KWARGS``,
    which is looked up each time the preconditioner factory is invoked.

    Returns:
        The populated ``Config``.
    """
    cfg = Config()
    cfg.max_steps = int(4e5)
    # nworkers * nsteps == 12 * 20, the same product used for tau in KFAC_KWARGS.
    cfg.nworkers = 12
    cfg.nsteps = 20
    cfg.set_parallel_env(MultiProcEnv)
    cfg.set_optimizer(kfac.default_sgd(eta_max=0.1))

    def make_preconditioner(net):
        # KFAC_KWARGS is resolved at call time, not when config() runs.
        return kfac.KfacPreConditioner(net, **KFAC_KWARGS)

    cfg.set_preconditioner(make_preconditioner)
    # NOTE(review): gae_lambda is assigned although use_gae is False;
    # presumably it is ignored in that case — confirm.
    cfg.gae_lambda = 0.95
    cfg.use_gae = False
    cfg.lr_min = 0.0
    cfg.value_loss_weight = 0.1
    cfg.entropy_weight = 0.01
    cfg.eval_freq = None
    return cfg