Пример #1
0
class SimpleQFunction(QFunction2):
    """Minimal Q-function wrapping a SimpleMLPModel, meant for tests.

    Args:
        env_spec: Environment specification. Its observation space supplies
            both the placeholder shape and the model's output dimension.
        name (str): Name forwarded to the base class; ``self.name`` is used
            as the variable scope when the graph is built.

    """

    def __init__(self, env_spec, name='SimpleQFunction'):
        super().__init__(name)
        observation_space = env_spec.observation_space
        self.obs_dim = observation_space.shape
        # NOTE(review): the output size comes from the *observation* space,
        # not the action space -- confirm this is intended for the fixture.
        self.model = SimpleMLPModel(output_dim=observation_space.flat_dim)

        self._initialize()

    def _initialize(self):
        """Create the observation placeholder and build the model on it."""
        input_shape = (None, ) + self.obs_dim
        observations = tf.compat.v1.placeholder(tf.float32,
                                                input_shape,
                                                name='obs')

        with tf.compat.v1.variable_scope(self.name, reuse=False) as scope:
            self._variable_scope = scope
            self.model.build(observations)

    @property
    def q_vals(self):
        """Outputs of the model's ``'default'`` network."""
        default_network = self.model.networks['default']
        return default_network.outputs

    def __setstate__(self, state):
        """Restore the instance attributes and rebuild the graph.

        Args:
            state (dict): Unpickled state.

        """
        vars(self).update(state)
        self._initialize()
Пример #2
0
class SimpleQFunction(QFunction):
    """Minimal Q-function wrapping a SimpleMLPModel, meant for tests.

    Args:
        env_spec (garage.envs.env_spec.EnvSpec): Environment specification.
        name (str): Name of the q-function, also serves as the variable scope.

    """

    def __init__(self, env_spec, name='SimpleQFunction'):
        super().__init__(name)
        observation_space = env_spec.observation_space
        self.obs_dim = observation_space.shape
        # NOTE(review): the output size comes from the *observation* space,
        # not the action space -- confirm this is intended for the fixture.
        self.model = SimpleMLPModel(output_dim=observation_space.flat_dim)

        # Filled in by _initialize() once the model has been built.
        self._q_val = None

        self._initialize()

    def _initialize(self):
        """Create the observation placeholder and build the model on it."""
        input_shape = (None, ) + self.obs_dim
        observations = tf.compat.v1.placeholder(tf.float32,
                                                input_shape,
                                                name='obs')

        with tf.compat.v1.variable_scope(self.name, reuse=False) as scope:
            self._variable_scope = scope
            network = self.model.build(observations)
            self._q_val = network.outputs

    @property
    def q_vals(self):
        """Return the Q values, the output of the network.

        Returns:
            The ``outputs`` of the network returned by ``self.model.build``.

        """
        return self._q_val

    def __setstate__(self, state):
        """Restore the instance attributes and rebuild the graph.

        Args:
            state (dict): Unpickled state.

        """
        vars(self).update(state)
        self._initialize()

    def __getstate__(self):
        """Build the picklable state of this instance.

        Returns:
            dict: A copy of the instance attributes without ``_q_val``,
            which ``__setstate__`` recreates through ``_initialize``.

        """
        state = dict(self.__dict__)
        state.pop('_q_val')
        return state
Пример #3
0
class SimpleMLPRegressor(Regressor2):
    """Simple MLP regressor for testing.

    ``fit`` only stores the labels it is given, and ``predict`` returns the
    stored value; the model is evaluated only when nothing has been stored.

    Args:
        input_shape (tuple[int]): Input shape of the training data.
        output_dim (int): Output dimension of the model.
        name (str): Model name, forwarded to the base class.
        args (list): Unused positional arguments.
        kwargs (dict): Unused keyword arguments.

    """

    def __init__(self, input_shape, output_dim, name, *args, **kwargs):
        super().__init__(input_shape, output_dim, name)

        self.model = SimpleMLPModel(output_dim=self._output_dim,
                                    name='SimpleMLPModel')

        self._initialize()

    def _initialize(self):
        """Build the model graph and clear the cached predictions."""
        # tf.compat.v1 keeps these graph-mode APIs available on TF 2.x, where
        # the top-level tf.placeholder / tf.variable_scope no longer exist.
        input_ph = tf.compat.v1.placeholder(tf.float32,
                                            shape=(None, ) + self._input_shape)
        with tf.compat.v1.variable_scope(self._name) as vs:
            self._variable_scope = vs
            self.model.build(input_ph)
        # Reset here (not in __init__) so unpickling also clears the cache.
        self.ys = None

    def fit(self, xs, ys):
        """Store the labels ``ys``; ``xs`` is not used.

        Args:
            xs: Input data (ignored).
            ys: Labels, returned by later calls to ``predict``.

        """
        self.ys = ys

    def predict(self, xs):
        """Return the stored ``ys``, evaluating the model if none is stored.

        Args:
            xs: Input data, fed to the model only when nothing is cached.

        Returns:
            The cached value: either what was passed to ``fit`` or the
            output of the first model evaluation.

        """
        if self.ys is None:
            outputs = tf.compat.v1.get_default_session().run(
                self.model.networks['default'].outputs,
                feed_dict={self.model.networks['default'].input: xs})
            self.ys = outputs

        return self.ys

    def get_params_internal(self, *args, **kwargs):
        """Return the trainable variables of this regressor's scope.

        Args:
            args (list): Unused positional arguments.
            kwargs (dict): Unused keyword arguments.

        Returns:
            The result of ``trainable_variables()`` on the variable scope
            captured in ``_initialize``.

        """
        return self._variable_scope.trainable_variables()

    def __setstate__(self, state):
        """Restore state through the base class and rebuild the graph.

        Args:
            state (dict): Unpickled state.

        """
        super().__setstate__(state)
        self._initialize()
Пример #4
0
class SimpleMLPRegressor(Regressor):
    """Lightweight MLP regressor used as a test fixture.

    Args:
        input_shape (tuple[int]): Input shape of the training data.
        output_dim (int): Output dimension of the model.
        name (str): Model name, also the variable scope.
        args (list): Unused positional arguments.
        kwargs (dict): Unused keyword arguments.

    """

    def __init__(self, input_shape, output_dim, name, *args, **kwargs):
        super().__init__(input_shape, output_dim, name)
        del args, kwargs  # deliberately unused
        self.model = SimpleMLPModel(name='SimpleMLPModel',
                                    output_dim=self._output_dim)

        # Cached labels / predictions; None until fit() or predict() runs.
        self._ys = None
        self._initialize()

    def _initialize(self):
        """Create the input placeholder and build the model on it."""
        placeholder_shape = (None, ) + self._input_shape
        inputs = tf.compat.v1.placeholder(tf.float32, shape=placeholder_shape)
        with tf.compat.v1.variable_scope(self._name) as scope:
            self._variable_scope = scope
            self.model.build(inputs)

    @property
    def recurrent(self):
        """bool: Whether this module has a hidden state (always False)."""
        return False

    @property
    def vectorized(self):
        """bool: Whether this module supports vectorized input (always True)."""
        return True

    def fit(self, xs, ys):
        """Remember the labels ``ys``; ``xs`` is not used.

        Args:
            xs (numpy.ndarray): Input data.
            ys (numpy.ndarray): Label of input data.

        """
        self._ys = ys

    def predict(self, xs):
        """Return the cached ys, evaluating the model if none is cached.

        Args:
            xs (numpy.ndarray): Input data.

        Returns:
            np.ndarray: The predicted ys.

        """
        if self._ys is not None:
            return self._ys

        # Nothing cached yet: run the 'default' network once and keep the
        # result for subsequent calls.
        session = tf.compat.v1.get_default_session()
        network = self.model.networks['default']
        self._ys = session.run(network.outputs, feed_dict={network.input: xs})
        return self._ys

    def get_params_internal(self):
        """Get the params, which are the trainable variables.

        Returns:
            List[tf.Variable]: The trainable variables of the variable scope
            captured in ``_initialize``.

        """
        scope = self._variable_scope
        return scope.trainable_variables()

    def __setstate__(self, state):
        """Restore state through the base class and rebuild the graph.

        Args:
            state (dict): Unpickled state.

        """
        super().__setstate__(state)
        self._initialize()