Пример #1
0
def main(env_id, policy_file, record, stochastic, extra_kwargs):
    """Load an ES Atari policy and replay it in ``env_id`` with rendering.

    Episodes are rolled out repeatedly and the return and length of each one
    are printed. When ``record`` is true, a single episode is recorded through
    ``gym.wrappers.Monitor`` into a fresh ``/tmp/<uuid4>`` directory and the
    function returns; otherwise it loops until interrupted.

    Args:
        env_id: Gym environment id. Must contain ``"NoFrameskip"``: the policy
            is always loaded as an ``ESAtariPolicy`` and the environment is
            always wrapped with ``wrap_deepmind``.
        policy_file: Path passed to ``ESAtariPolicy.Load``.
        record: If true, record one episode with ``Monitor`` and stop.
        stochastic: If true, pass ``np.random`` as the rollout's random
            stream; otherwise pass ``None``.
        extra_kwargs: Optional JSON string; decoded and forwarded to
            ``ESAtariPolicy.Load``.

    Raises:
        ValueError: If ``env_id`` does not contain ``"NoFrameskip"``.
    """
    # Fail fast, before the heavy imports: episodes were only rolled out for
    # Atari ids, so any other id used to crash with a NameError on ``rews``
    # after the policy had already been loaded.
    if "NoFrameskip" not in env_id:
        raise ValueError(
            "only Atari 'NoFrameskip' environments are supported, got {!r}".format(env_id))

    import gym
    from gym import wrappers
    import tensorflow as tf
    from es_distributed.policies import ESAtariPolicy
    from es_distributed.atari_wrappers import wrap_deepmind
    from es_distributed.es import get_ref_batch
    import numpy as np

    # Decode before creating the environment so malformed JSON cannot leak it.
    if extra_kwargs:
        import json
        extra_kwargs = json.loads(extra_kwargs)

    env = wrap_deepmind(gym.make(env_id))

    if record:
        import uuid
        env = wrappers.Monitor(env, '/tmp/' + str(uuid.uuid4()), force=True)

    random_stream = np.random if stochastic else None

    try:
        with tf.Session():
            pi = ESAtariPolicy.Load(policy_file, extra_kwargs=extra_kwargs)
            pi.set_ref_batch(get_ref_batch(env, batch_size=128))

            while True:
                rews, t, _novelty_vector = pi.rollout(
                    env, render=True, random_stream=random_stream)
                print('return={:.4f} len={}'.format(rews, t))

                if record:
                    # One recorded episode is enough; ``finally`` closes the Monitor.
                    return
    finally:
        # Previously the env was closed only on the ``record`` path.
        env.close()
Пример #2
0
def build_env(env_id):
    """Create the gym environment ``env_id``.

    ``gym.undo_logger_setup()`` is called first. Ids ending in
    ``'NoFrameskip-v4'`` are additionally wrapped with ``wrap_deepmind``;
    every other environment is returned exactly as ``gym.make`` built it.
    """
    gym.undo_logger_setup()
    environment = gym.make(env_id)
    needs_atari_wrappers = env_id.endswith('NoFrameskip-v4')
    return wrap_deepmind(environment) if needs_atari_wrappers else environment
Пример #3
0
def main(env_ids, policy_directory, record, stochastic, extra_kwargs):
    """Visualize every policy in ``policy_directory``, each in its own forked process.

    For each entry of ``policy_directory`` a child is created with
    ``os.fork()``. The child builds one environment per id in ``env_ids``
    (ids ending in ``'NoFrameskip-v4'`` are wrapped with ``wrap_deepmind``) and
    loads the policy. For Atari policies it renders rollouts of a
    ``GAAtariPolicy`` until killed. The parent waits for all children to exit.

    Args:
        env_ids: Whitespace-separated gym environment ids. The policy is
            treated as Atari when the first id contains ``"NoFrameskip"``.
        policy_directory: Directory whose entries are each loaded as a policy file.
        record: Not used by this function; kept for interface compatibility.
        stochastic: If true, pass ``np.random`` as the rollout's random
            stream; otherwise pass ``None``.
        extra_kwargs: Optional JSON string, decoded once and forwarded to the
            policy's ``Load`` call.

    Raises:
        ValueError: If ``env_ids`` contains no environment id.
    """
    # ``split()`` (not ``split(' ')``) so repeated spaces do not yield empty ids.
    env_ids = env_ids.split()
    if not env_ids:
        raise ValueError("env_ids must contain at least one environment id")
    is_atari_policy = "NoFrameskip" in env_ids[0]

    import os
    import sys
    import traceback
    import gym
    import tensorflow as tf
    from es_distributed.policies import MujocoPolicy, GAAtariPolicy
    from es_distributed.atari_wrappers import wrap_deepmind
    from es_distributed.es import get_ref_batch
    import numpy as np

    if extra_kwargs:
        import json
        # Decode once in the parent so malformed JSON fails before any fork.
        extra_kwargs = json.loads(extra_kwargs)

    random_stream = np.random if stochastic else None

    children = 0
    for policy_name in os.listdir(policy_directory):
        policy_file = os.path.join(policy_directory, policy_name)
        pid = os.fork()
        if pid == 0:
            # Child process: it must never return into this loop or the caller
            # (it would fork and wait like the parent), so it always leaves via
            # os._exit, with a non-zero status on failure.
            exit_code = 1
            try:
                envs = []
                for env_id in env_ids:
                    env = gym.make(env_id)
                    if env_id.endswith('NoFrameskip-v4'):
                        env = wrap_deepmind(env)
                    envs.append(env)

                with tf.Session():
                    if is_atari_policy:
                        pi = GAAtariPolicy.Load(policy_file, extra_kwargs=extra_kwargs)
                        if pi.needs_ref_batch:
                            pi.set_ref_batch(get_ref_batch(envs[0], batch_size=128))
                        while True:
                            # Renders until the process is interrupted or killed.
                            _rews, _t, _novelty_vector = pi.rollout(
                                envs, render=True, random_stream=random_stream)
                    else:
                        # NOTE(review): no rollout is implemented for non-Atari
                        # (Mujoco) policies; the original spun forever in an empty
                        # ``while True`` loop. The policy is still loaded so load
                        # errors surface, then the child exits.
                        MujocoPolicy.Load(policy_file, extra_kwargs=extra_kwargs)
                exit_code = 0
            except KeyboardInterrupt:
                pass
            except Exception:
                traceback.print_exc()
            finally:
                # os._exit skips interpreter cleanup, so flush stdio explicitly.
                sys.stdout.flush()
                sys.stderr.flush()
                os._exit(exit_code)
        # Count only after a successful fork so we never wait for a missing child.
        children += 1

    for _ in range(children):
        os.wait()
Пример #4
0
 def make_env(self):
     """Return ``FrostbiteNoFrameskip-v4`` from ``gym.make``, wrapped by ``wrap_deepmind``."""
     from es_distributed.atari_wrappers import wrap_deepmind
     frostbite_id = "FrostbiteNoFrameskip-v4"
     return wrap_deepmind(gym.make(frostbite_id))