Example No. 1
from stable_baselines import A2C
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import SubprocVecEnv
from env.RSEnv import RSEnv
from env.TestRSEnv import TestRSEnv


def test():
    # Train A2C on several copies of the environment running in parallel processes.
    n_cpu = 4
    env = SubprocVecEnv([lambda: RSEnv() for i in range(n_cpu)])

    model = A2C(MlpPolicy, env, verbose=1)
    model.learn(total_timesteps=600000, log_interval=10)

    model.save("sba2c")

    # Roll out the trained model on the test environment.
    env = TestRSEnv()
    obs = env.reset()
    done = False
    while not done:
        action, _ = model.predict(obs)
        obs, rewards, done, info = env.step(action)
        env.render()
    env.close()
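
A note on launching this: SubprocVecEnv spawns worker processes, so the entry point should be guarded when the module is run directly. A minimal sketch, assuming this file is the script being executed (only test comes from the snippet above):

if __name__ == "__main__":
    # Guard the entry point so spawned workers do not re-execute the script body.
    test()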
Example No. 2
from env.RSEnv import RSEnv
from env.TestRSEnv import TestRSEnv
from acme import environment_loop
from acme import specs
from acme import wrappers
from acme.agents.tf import d4pg
from acme.tf import networks
from acme.tf import utils as tf2_utils
from acme.utils import loggers
import numpy as np
import sonnet as snt

import gym

environment = RSEnv()
environment = wrappers.GymWrapper(environment)  # To dm_env interface.

# Make sure the environment outputs single-precision floats.
environment = wrappers.SinglePrecisionWrapper(environment)

# Grab the spec of the environment.
environment_spec = specs.make_environment_spec(environment)

# Build the networks for a D4PG agent.

# Get total number of action dimensions from action spec.
num_dimensions = np.prod(environment_spec.actions.shape, dtype=int)

# Create the shared observation network; here simply a state-less operation.
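
The snippet is cut off at this point. A sketch of how it plausibly continues, following the network construction and agent setup from Acme's published D4PG tutorial; the layer sizes, value-distribution bounds, and episode count below are assumptions, not taken from this project:

observation_network = tf2_utils.batch_concat  # state-less flatten/concat

# Create the deterministic policy network.
policy_network = snt.Sequential([
    networks.LayerNormMLP((256, 256, 256), activate_final=True),
    networks.NearZeroInitializedLinear(num_dimensions),
    networks.TanhToSpec(environment_spec.actions),
])

# Create the distributional critic network.
critic_network = snt.Sequential([
    # The multiplexer concatenates observations and actions.
    networks.CriticMultiplexer(),
    networks.LayerNormMLP((512, 512, 256), activate_final=True),
    networks.DiscreteValuedHead(vmin=-150., vmax=150., num_atoms=51),
])

# Construct the agent and run a training loop.
agent = d4pg.D4PG(
    environment_spec=environment_spec,
    policy_network=policy_network,
    critic_network=critic_network,
    observation_network=observation_network,
)
loop = environment_loop.EnvironmentLoop(environment, agent)
loop.run(num_episodes=100)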
Example No. 3
import gym

from stable_baselines.common.policies import MlpPolicy
from stable_baselines import PPO1
from env.RSEnv import RSEnv
from env.TestRSEnv import TestRSEnv

env = RSEnv()

# Resume training from an earlier checkpoint rather than starting fresh.
#model = PPO1(MlpPolicy, env, verbose=1)
model = PPO1.load("sbppov3")
model.set_env(env)
model.learn(total_timesteps=3000000,
            log_interval=10,
            reset_num_timesteps=False)  # keep the running timestep count
model.save("sbppov4")

# Roll out the updated model on the test environment.
env = TestRSEnv()
obs = env.reset()
done = False
while not done:
    action, _ = model.predict(obs)
    obs, rewards, done, info = env.step(action)
    env.render()
env.close()
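
One detail worth noting: in stable-baselines, model.predict samples from the stochastic policy by default, so evaluation rollouts are often run deterministically. A sketch reusing the names above:

env = TestRSEnv()
obs = env.reset()
done = False
while not done:
    # Take the most likely action rather than sampling from the policy.
    action, _ = model.predict(obs, deterministic=True)
    obs, rewards, done, info = env.step(action)
env.close()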
Example No. 4
from env.RSEnv import RSEnv

def make_env(rank, seed=0):
    def _init():
        env = RSEnv()
        env.seed(seed + rank)
        return env
    return _init
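
This _init fragment matches stable-baselines' standard multiprocessing pattern, where a factory closes over rank and seed and the resulting thunks feed SubprocVecEnv. A sketch of that usage (the worker count is an assumption, echoing Example No. 1):

from stable_baselines.common.vec_env import SubprocVecEnv

n_cpu = 4  # assumed worker count
env = SubprocVecEnv([make_env(rank=i) for i in range(n_cpu)])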