コード例 #1
0
 def __init__(
     self,
     env,
     reward_scale=1.,
     obs_mean=None,
     obs_std=None,
 ):
     """Wrap ``env`` with optional observation normalization and reward scaling.

     :param env: environment to wrap; its ``observation_space.low`` and
         ``action_space.shape`` are read here.
     :param reward_scale: stored as ``self._reward_scale``.
     :param obs_mean: observation mean. If only ``obs_std`` is supplied, this
         defaults to zeros shaped like ``env.observation_space.low``.
     :param obs_std: observation std. If only ``obs_mean`` is supplied, this
         defaults to ones shaped like ``env.observation_space.low``.

     When both ``obs_mean`` and ``obs_std`` are None, normalization is
     disabled and both attributes stay None. The exposed ``action_space`` is
     replaced by a Box spanning [-1, 1] in every dimension of the wrapped
     env's action shape.
     """
     # self._wrapped_env needs to be set first because
     # Serializable.quick_init calls getattr on this class, and the
     # implementation of getattr (see below) calls self._wrapped_env.
     # Without setting this first, the access to self._wrapped_env would call
     # getattr again (since it's not set yet) and therefore loop forever.
     self._wrapped_env = env
     # Otherwise serialization gets delegated to the wrapped_env. Serialize
     # this env separately from the wrapped_env.
     self._serializable_initialized = False
     Serializable.quick_init(self, locals())
     ProxyEnv.__init__(self, env)
     # Normalization is enabled as soon as either statistic is supplied;
     # the missing one is filled with an identity-like default below.
     self._should_normalize = not (obs_mean is None and obs_std is None)
     if self._should_normalize:
         if obs_mean is None:
             obs_mean = np.zeros_like(env.observation_space.low)
         else:
             obs_mean = np.array(obs_mean)
         if obs_std is None:
             obs_std = np.ones_like(env.observation_space.low)
         else:
             obs_std = np.array(obs_std)
     self._reward_scale = reward_scale
     self._obs_mean = obs_mean
     self._obs_std = obs_std
     # Expose a unit action box. NOTE(review): rescaling back to the wrapped
     # env's real bounds presumably happens in step() — confirm.
     ub = np.ones(self._wrapped_env.action_space.shape)
     self.action_space = Box(-1 * ub, ub)
コード例 #2
0
ファイル: policies.py プロジェクト: jinparksj/RL_NABI
 def __setstate__(self, d):
     """Restore pickled state, then optionally reload parameter values.

     When the module-level ``load_params`` flag is truthy, the variables
     returned by ``get_params()`` are initialized in the default TF session
     and then overwritten with the values stored under ``d["params"]``.
     """
     Serializable.__setstate__(self, d)
     if not load_params:
         return
     # Bind ``run`` before building the initializer op to keep the original
     # evaluation order (session lookup happens first).
     run_in_session = tf.get_default_session().run
     run_in_session(tf.variables_initializer(self.get_params()))
     self.set_param_values(d["params"])
コード例 #3
0
 def __init__(self, observation_space, action_space):
     """Store the observation and action spaces.

     :param observation_space: observation space to expose
     :type observation_space: Space
     :param action_space: action space to expose
     :type action_space: Space
     """
     Serializable.quick_init(self, locals())
     self._observation_space, self._action_space = (
         observation_space, action_space)
コード例 #4
0
ファイル: policies.py プロジェクト: jinparksj/RL_NABI
    def __init__(self, env_spec, obs_pl, action, scope_name=None):
        """Store the policy's observation input, action output and scope.

        Args:
            env_spec: Forwarded to the parent constructor.
            obs_pl: Observation input (presumably a TF placeholder).
            action: The policy's action output (presumably a TF tensor).
            scope_name: Variable-scope name. A falsy value falls back to the
                name of the current ``tf.get_variable_scope()``.
        """
        Serializable.quick_init(self, locals())

        self._obs_pl, self._action = obs_pl, action
        # ``or`` evaluates the fallback only when scope_name is falsy.
        self._scope_name = scope_name or tf.get_variable_scope().name
        super(NNPolicy, self).__init__(env_spec)
コード例 #5
0
ファイル: policies.py プロジェクト: jinparksj/RL_NABI
    def __init__(self, inputs, name, hidden_layer_sizes):
        """Record layer configuration and build the output for ``inputs``.

        Args:
            inputs: Input(s) handed to ``self._output_for``.
            name: Stored as ``self._name``.
            hidden_layer_sizes: Iterable of hidden-layer widths; a final
                single-unit layer is appended.
        """
        Parameterized.__init__(self)
        Serializable.quick_init(self, locals())

        self._name, self._inputs = name, inputs
        widths = list(hidden_layer_sizes)
        widths.append(1)  # single-unit output layer
        self._layer_sizes = widths

        self._output = self._output_for(self._inputs)
コード例 #6
0
ファイル: gym_env.py プロジェクト: jinparksj/RL_NABI
    def __init__(self,
                 env_name,
                 record_video=False,
                 video_schedule=None,
                 log_dir=None,
                 record_log=False,
                 force_reset=True):
        """Create the Gym environment ``env_name`` and optionally monitor it.

        Args:
            env_name: Id passed to ``gym.envs.make``.
            record_video: Record videos through the monitor. Requires
                ``record_log`` to be truthy.
            video_schedule: ``video_callable`` for ``gym.wrappers.Monitor``.
                When ``record_video`` is falsy it is replaced by
                ``NoVideoSchedule()``; when it is truthy and this is None,
                ``CappedCubicVideoSchedule()`` is used.
            log_dir: Monitor output directory. When None it defaults to
                ``<snapshot_dir>/gym_log`` if the logger has a snapshot dir.
            record_log: Wrap the env in ``gym.wrappers.Monitor``. This only
                happens when a ``log_dir`` is available.
            force_reset: Stored as ``self._force_reset``.

        Raises:
            AssertionError: if ``record_video`` is truthy but ``record_log``
                is not (skipped under ``python -O``).
        """
        if log_dir is None:
            if logger.get_snapshot_dir() is None:
                logger.log(
                    "Warning: skipping Gym environment monitoring since snapshot_dir not configured."
                )
            else:
                log_dir = os.path.join(logger.get_snapshot_dir(), "gym_log")
        # Called after log_dir is resolved so locals() holds the resolved
        # directory (assumes quick_init captures constructor args from it).
        Serializable.quick_init(self, locals())

        env = gym.envs.make(env_name)

        # HACK: Gets rid of the TimeLimit wrapper that sets 'done = True' when
        # the time limit specified for each environment has been passed and
        # therefore the environment is not Markovian (terminal condition depends
        # on time rather than state).
        env = env.env

        self.env = env
        self.env_id = env.spec.id

        # Videos are written by the monitor, so recording them needs logging.
        assert record_log or not record_video, (
            "record_video=True requires record_log=True")

        # Truthiness (not ``is False``) so that falsy values such as None or 0
        # disable monitoring, consistent with the assertion above.
        if log_dir is None or not record_log:
            self.monitoring = False
        else:
            if not record_video:
                video_schedule = NoVideoSchedule()
            elif video_schedule is None:
                video_schedule = CappedCubicVideoSchedule()
            self.env = gym.wrappers.Monitor(self.env,
                                            log_dir,
                                            video_callable=video_schedule,
                                            force=True)
            self.monitoring = True

        # Spaces are read from the unwrapped env; the Monitor does not
        # change them.
        self._observation_space = convert_gym_space(env.observation_space)
        logger.log("observation space: {}".format(self._observation_space))
        self._action_space = convert_gym_space(env.action_space)
        logger.log("action space: {}".format(self._action_space))
        self._horizon = env.spec.tags[
            'wrapper_config.TimeLimit.max_episode_steps']
        self._log_dir = log_dir
        self._force_reset = force_reset
コード例 #7
0
 def __getstate__(self):
     """Extend the Serializable state with the live normalization settings.

     The values are copied from the instance at pickling time, so any
     changes made after construction are preserved.
     """
     state = Serializable.__getstate__(self)
     for attr_name in ("_obs_mean", "_obs_std", "_reward_scale"):
         state[attr_name] = getattr(self, attr_name)
     return state
コード例 #8
0
 def __init__(
         self,
         env,
         scale_reward=1.,
         normalize_obs=False,
         normalize_reward=False,
         obs_alpha=0.001,
         reward_alpha=0.001,
 ):
     """Wrap ``env`` and initialize observation/reward statistics.

     :param env: environment to wrap; ``observation_space.low.shape`` sizes
         the observation statistics.
     :param scale_reward: stored as ``self._scale_reward``.
     :param normalize_obs: flag stored as ``self._normalize_obs``.
     :param normalize_reward: flag stored as ``self._normalize_reward``.
     :param obs_alpha: stored as ``self._obs_alpha``.
     :param reward_alpha: stored as ``self._reward_alpha``.

     NOTE(review): the ``*_alpha`` values are presumably smoothing rates for
     running statistics updated elsewhere — confirm.
     """
     Serializable.quick_init(self, locals())
     ProxyEnv.__init__(self, env)
     self._scale_reward = scale_reward
     self._normalize_obs = normalize_obs
     self._normalize_reward = normalize_reward

     # Observation statistics live over the flattened observation vector.
     flat_obs_dim = np.prod(env.observation_space.low.shape)
     self._obs_alpha = obs_alpha
     self._obs_mean = np.zeros(flat_obs_dim)
     self._obs_var = np.ones(flat_obs_dim)

     self._reward_alpha = reward_alpha
     self._reward_mean, self._reward_var = 0., 1.
コード例 #9
0
ファイル: value_functions.py プロジェクト: jinparksj/RL_NABI
    def __init__(self,
                 env_spec,
                 hidden_layer_sizes=(100, 100),
                 name='q_function',
                 ac_dim=np.int64(3)):
        """Create observation/action placeholders and build the Q network.

        Args:
            env_spec: Provides ``observation_flat_dim`` for the observation
                placeholder width.
            hidden_layer_sizes: Hidden widths forwarded to the parent class.
            name: Forwarded to the parent class.
            ac_dim: Width of the action placeholder (default 3); it is not
                read from ``env_spec``.
        """
        Serializable.quick_init(self, locals())

        self._Da = ac_dim
        self._Do = env_spec.observation_flat_dim

        obs_shape = [None, self._Do]
        act_shape = [None, self._Da]
        self._observations_ph = tf.placeholder(
            tf.float32, shape=obs_shape, name='observations')
        self._actions_ph = tf.placeholder(
            tf.float32, shape=act_shape, name='actions')

        network_inputs = (self._observations_ph, self._actions_ph)
        super(NNQFunction, self).__init__(
            inputs=network_inputs,
            name=name,
            hidden_layer_sizes=hidden_layer_sizes,
        )
コード例 #10
0
ファイル: policies.py プロジェクト: jinparksj/RL_NABI
    def __init__(self,
                 env_spec,
                 hidden_layer_sizes,
                 squash=True,
                 name='policy',
                 ac_dim=np.int64(3)):
        """Configure the policy network and build its action output.

        Args:
            env_spec: Provides ``observation_flat_dim``; also forwarded to the
                parent constructor.
            hidden_layer_sizes: Iterable of hidden widths; the action
                dimension is appended as the last layer width.
            squash: Stored as ``self._squash``.
            name: Stored as ``self._name`` and passed to the parent.
            ac_dim: Action dimension (default 3), not read from ``env_spec``.
        """
        Serializable.quick_init(self, locals())

        self._action_dim = ac_dim
        self._observation_dim = env_spec.observation_flat_dim
        layer_sizes = list(hidden_layer_sizes)
        layer_sizes.append(self._action_dim)
        self._layer_sizes = layer_sizes
        self._squash, self._name = squash, name

        obs_shape = [None, self._observation_dim]
        self._observation_ph = tf.placeholder(
            tf.float32, shape=obs_shape, name='observation')

        # Must follow the attribute setup above, which actions_for may read.
        self._actions = self.actions_for(self._observation_ph)

        super(StochasticNNPolicy, self).__init__(
            env_spec, self._observation_ph, self._actions, self._name)
コード例 #11
0
 def __setstate__(self, d):
     """Restore Serializable state, then the normalization settings.

     ``d`` must contain ``_obs_mean``, ``_obs_std`` and ``_reward_scale``.
     They are assigned after the Serializable restore, so they take
     precedence over it.
     """
     Serializable.__setstate__(self, d)
     for attr_name in ("_obs_mean", "_obs_std", "_reward_scale"):
         setattr(self, attr_name, d[attr_name])
コード例 #12
0
 def __init__(self, wrapped_env):
     """Wrap ``wrapped_env`` and expose its action/observation spaces.

     :param wrapped_env: environment whose spaces are mirrored on the proxy.
     """
     Serializable.quick_init(self, locals())
     self._wrapped_env = wrapped_env
     # Mirror the inner env's spaces directly on this proxy.
     inner = self._wrapped_env
     self.action_space, self.observation_space = (
         inner.action_space, inner.observation_space)
コード例 #13
0
 def __init__(self, wrapped_env):
     """Call ``Serializable.quick_init`` and store the wrapped environment.

     :param wrapped_env: the environment this proxy wraps; stored as
         ``self._wrapped_env``.
     """
     # quick_init runs before _wrapped_env is assigned. NOTE(review): if this
     # class defines a __getattr__ that delegates to self._wrapped_env,
     # confirm quick_init's attribute lookups cannot recurse through it here.
     Serializable.quick_init(self, locals())
     self._wrapped_env = wrapped_env
コード例 #14
0
 def init_serialization(self, locals):
     """Forward ``locals`` to ``Serializable.quick_init`` for this object.

     :param locals: mapping passed unchanged to ``quick_init`` (presumably the
         caller's ``locals()``). The name shadows the ``locals`` builtin; it is
         kept so keyword callers keep working.
     """
     Serializable.quick_init(self, locals)
コード例 #15
0
ファイル: policies.py プロジェクト: jinparksj/RL_NABI
 def __getstate__(self):
     """Return the Serializable state, adding parameter values if enabled.

     ``"params"`` is filled from ``get_param_values()`` only when the
     module-level ``load_params`` flag is truthy.
     """
     state = Serializable.__getstate__(self)
     if load_params:
         state["params"] = self.get_param_values()
     return state