def run_dqn(experiment_name):
    """Train a DQN agent on the Winter_is_coming environment.

    Builds the environment, network and agent, runs the training loop,
    saves a final checkpoint, and optionally logs everything to Neptune.

    Args:
        experiment_name: Name used both for the local save directories
            and for the (optional) Neptune experiment.
    """
    current_dir = pathlib.Path().absolute()
    directories = Save_paths(data_dir=f'{current_dir}/data', experiment_name=experiment_name)

    game = Winter_is_coming(setup=PARAMS['setup'])
    environment = wrappers.SinglePrecisionWrapper(game)
    spec = specs.make_environment_spec(environment)

    # Build the network: a flat MLP whose output layer matches the action count.
    def _make_network(spec) -> snt.Module:
        network = snt.Sequential([
            snt.Flatten(),
            snt.nets.MLP([50, 50, spec.actions.num_values]),
        ])
        # Force variable creation up front so the network is fully built
        # before it is handed to the agent.
        tf2_utils.create_variables(network, [spec.observations])
        return network

    network = _make_network(spec)

    # Set up the loggers: Neptune when enabled, terminal output otherwise.
    if neptune_enabled:
        agent_logger = NeptuneLogger(label='DQN agent', time_delta=0.1)
        loop_logger = NeptuneLogger(label='Environment loop', time_delta=0.1)
        PARAMS['network'] = f'{network}'
        neptune.init('cvasquez/sandbox')
        neptune.create_experiment(name=experiment_name, params=PARAMS)
    else:
        agent_logger = loggers.TerminalLogger('DQN agent', time_delta=1.)
        loop_logger = loggers.TerminalLogger('Environment loop', time_delta=1.)

    # Build the agent
    agent = DQN(
        environment_spec=spec,
        network=network,
        params=PARAMS,
        checkpoint=True,
        paths=directories,
        logger=agent_logger
    )
    # Try running the environment loop. We have no assertions here because all
    # we care about is that the agent runs without raising any errors.
    loop = acme.EnvironmentLoop(environment, agent, logger=loop_logger)
    loop.run(num_episodes=PARAMS['num_episodes'])

    last_checkpoint_path = agent.save()

    # Upload the last checkpoint. Also guard on neptune_enabled: the original
    # code would call neptune.log_artifact without a prior neptune.init()
    # whenever uploading was requested but Neptune logging was disabled.
    if neptune_enabled and neptune_upload_checkpoint and last_checkpoint_path:
        for f in os.listdir(last_checkpoint_path):
            neptune.log_artifact(os.path.join(last_checkpoint_path, f))

    if neptune_enabled:
        neptune.stop()

    do_example_run(game, agent)
# --- Example #2 ---
def main():
    """Train a DQN agent on a Gym environment, then save model, rewards and plots."""
    get_env_version()
    cfg = DQNConfig(env="CartPole-v0", train_eps=200)
    # cfg = DQNConfig(env="MountainCar-v0", train_eps=500)
    get_env_information(env_name=cfg.env)
    env = gym.make(cfg.env)
    env.seed(0)  # fixed seed for reproducibility
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    agent = DQN(state_dim, action_dim, cfg)
    rewards, smooth_rewards = train(cfg, env, agent)
    # exist_ok=True: the original os.makedirs raised FileExistsError on any
    # re-run once the result directory had been created.
    os.makedirs(cfg.result_path, exist_ok=True)
    agent.save(path=cfg.result_path)
    save_results(rewards, smooth_rewards, tag='train', path=cfg.result_path)
    plot_rewards(rewards,
                 smooth_rewards,
                 tag='train',
                 env=cfg.env,
                 algo=cfg.algo,
                 path=cfg.result_path)