Example 1
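A generator method from GymPongTrajectoriesFromPolicy: it rebuilds the agent's policy network with tf.make_template, restores the trained weights from FLAGS.model_path, and then yields trajectories from the parent class's generator while the session is open.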
  def generator(self, data_dir, tmp_dir):
    # Thunk that builds the wrapped Atari environment on demand.
    env_spec = lambda: atari_wrappers.wrap_atari(  # pylint: disable=g-long-lambda
        gym.make(self.env_name),
        warp=False,
        frame_skip=4,
        frame_stack=False)
    hparams = rl.atari_base()
    with tf.variable_scope("train", reuse=tf.AUTO_REUSE):
      # Wrap the policy network in a template so its variables are reused
      # across calls.
      policy_lambda = hparams.network
      policy_factory = tf.make_template(
          "network",
          functools.partial(policy_lambda, env_spec().action_space, hparams))
      # Placeholder for a single observation; the two expand_dims calls add
      # the leading singleton dimensions the network expects.
      self._max_frame_pl = tf.placeholder(
          tf.float32, self.env.observation_space.shape)
      actor_critic = policy_factory(tf.expand_dims(tf.expand_dims(
          self._max_frame_pl, 0), 0))
      policy = actor_critic.policy
      # Act greedily by taking the mode of the policy distribution.
      self._last_policy_op = policy.mode()
      with tf.Session() as sess:
        # Restore only the policy's weights, then delegate trajectory
        # generation to the parent class while the session is open.
        model_saver = tf.train.Saver(
            tf.global_variables(".*network_parameters.*"))
        model_saver.restore(sess, FLAGS.model_path)
        for item in super(GymPongTrajectoriesFromPolicy,
                          self).generator(data_dir, tmp_dir):
          yield item
Example 2
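The GymPongTrajectoriesFromPolicy constructor: it builds the same policy graph for PongNoFrameskip-v4, opens a persistent tf.Session, restores the network weights from FLAGS.model_path, and initializes the frame-skip state used when stepping the environment.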
  def __init__(self, event_dir, *args, **kwargs):
      super(GymPongTrajectoriesFromPolicy, self).__init__(*args, **kwargs)
      self._env = None
      self._event_dir = event_dir
      # Thunk that builds the wrapped Atari environment on demand.
      env_spec = lambda: atari_wrappers.wrap_atari(  # pylint: disable=g-long-lambda
          gym.make("PongNoFrameskip-v4"),
          warp=False,
          frame_skip=4,
          frame_stack=False)
      hparams = rl.atari_base()
      with tf.variable_scope("train"):
          # Share the policy network's variables across calls via a template.
          policy_lambda = hparams.network
          policy_factory = tf.make_template(
              "network",
              functools.partial(policy_lambda,
                                env_spec().action_space, hparams))
          # Placeholder for a single observation; the two expand_dims calls
          # add the leading singleton dimensions the network expects.
          self._max_frame_pl = tf.placeholder(
              tf.float32, self.env.observation_space.shape)
          actor_critic = policy_factory(
              tf.expand_dims(tf.expand_dims(self._max_frame_pl, 0), 0))
          policy = actor_critic.policy
          # Act greedily by taking the mode of the policy distribution.
          self._last_policy_op = policy.mode()
      # Frame-skip state: repeat an action for _skip frames and buffer the
      # two most recent observations (their pixel-wise max is fed to
      # _max_frame_pl).
      self._last_action = self.env.action_space.sample()
      self._skip = 4
      self._skip_step = 0
      self._obs_buffer = np.zeros((2,) + self.env.observation_space.shape,
                                  dtype=np.uint8)
      # Keep a session open for the object's lifetime and restore the
      # trained policy weights into it.
      self._sess = tf.Session()
      model_saver = tf.train.Saver(
          tf.global_variables(".*network_parameters.*"))
      model_saver.restore(self._sess, FLAGS.model_path)
Example 3
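A GymDiscreteProblemWithAgent constructor that only sets collection defaults: a PongDeterministic-v4 environment spec, no in-graph wrappers, rl.atari_base() hyperparameters, and 20000 settable steps.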
    def __init__(self, *args, **kwargs):
        super(GymDiscreteProblemWithAgent, self).__init__(*args, **kwargs)
        self._env = None
        self.debug_dump_frames_path = "debug_frames_env"

        # defaults
        self.environment_spec = lambda: gym.make("PongDeterministic-v4")
        self.in_graph_wrappers = []
        self.collect_hparams = rl.atari_base()
        self.settable_num_steps = 20000
        self.simulated_environment = None
        self.warm_up = 10
Example 4
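A GymDiscreteProblemWithAgent variant that wraps the environment in an in-graph atari.MaxAndSkipWrapper (skip 4), keeps a history of 2 frames, and collects 1000 steps after a warm-up of 70.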
    def __init__(self, *args, **kwargs):
        super(GymDiscreteProblemWithAgent, self).__init__(*args, **kwargs)
        self._env = None
        self.history_size = 2

        # defaults
        self.environment_spec = lambda: gym.make("PongDeterministic-v4")
        self.in_graph_wrappers = [(atari.MaxAndSkipWrapper, {"skip": 4})]
        self.collect_hparams = rl.atari_base()
        self.settable_num_steps = 1000
        self.simulated_environment = None
        self.warm_up = 70
Example 5
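Like Example 3, but the environment spec is built from self.env_name and an extra make_extra_debug_info flag is switched off.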
    def __init__(self, *args, **kwargs):
        super(GymDiscreteProblemWithAgent, self).__init__(*args, **kwargs)
        self._env = None
        self.debug_dump_frames_path = "debug_frames_env"
        self.make_extra_debug_info = False

        # defaults
        self.environment_spec = lambda: gym.make(self.env_name)
        self.in_graph_wrappers = []
        self.collect_hparams = rl.atari_base()
        self.settable_num_steps = 20000
        self.simulated_environment = None
        self.warm_up = 10  # TODO(piotrm): This should probably be removed.
Example 6
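The GymDiscreteProblem constructor: 3 color channels, a 2-frame history, a PongNoFrameskip-v4 environment wrapped in MaxAndSkipWrapper (skip 4), 1000 collection steps, and movie dumping enabled at 24 fps.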
    def __init__(self, *args, **kwargs):
        super(GymDiscreteProblem, self).__init__(*args, **kwargs)
        self.num_channels = 3
        self.history_size = 2

        # defaults
        self.environment_spec = lambda: gym.make("PongNoFrameskip-v4")
        self.in_graph_wrappers = [(MaxAndSkipWrapper, {"skip": 4})]
        self.collect_hparams = rl.atari_base()
        self.num_steps = 1000
        self.movies = True
        self.movies_fps = 24
        self.simulated_environment = None
        self.warm_up = 70