Esempio n. 1
0
  def __init__(
      self,
      environment_spec: specs.EnvironmentSpec,
      hidden_sizes: Tuple[int, ...],
  ):
    """Builds the state, reward and discount prediction networks.

    Args:
      environment_spec: Environment spec; `actions.num_values` and
        `observations.shape` are read from it to size the model.
      hidden_sizes: Hidden layer widths used by each of the three MLPs. A
        final output layer is appended to this tuple for every head.
    """
    super(MLPTransitionModel, self).__init__(name='mlp_transition_model')

    def _drop_last_axis(x):
      # Each scalar head ends in a width-1 layer; remove that trailing axis.
      return tf.squeeze(x, axis=-1)

    # Sizes taken from the environment spec.
    self._num_actions = environment_spec.actions.num_values
    observation_shape = environment_spec.observations.shape
    self._input_shape = observation_shape
    self._flat_shape = int(np.prod(observation_shape))

    # Layer widths of the heads: the state head emits one value per
    # observation element, the reward and discount heads emit one value.
    state_head_sizes = hidden_sizes + (self._flat_shape,)
    scalar_head_sizes = hidden_sizes + (1,)

    # State head: MLP output reshaped back to the observation shape.
    state_mlp = snt.nets.MLP(state_head_sizes)
    restore_shape = snt.Reshape(self._input_shape)
    self._state_network = snt.Sequential([state_mlp, restore_shape])

    # Reward head.
    reward_mlp = snt.nets.MLP(scalar_head_sizes)
    self._reward_network = snt.Sequential([reward_mlp, _drop_last_axis])

    # Discount head.
    discount_mlp = snt.nets.MLP(scalar_head_sizes)
    self._discount_network = snt.Sequential([discount_mlp, _drop_last_axis])
Esempio n. 2
0
    def __init__(self, n_latent=4, kernel_size=4, name=None):
        """Assemble the encoder, the two latent heads and the decoder.

        Args:
            n_latent: Output width of the single layer in each of the `mn`
                and `std` heads; also stored as `self.n_latent`.
            kernel_size: Kernel size handed to every convolution and
                transposed convolution built here.
            name: Module name forwarded to the parent constructor.
        """
        super(VariationalAutoEncoder, self).__init__(name=name)

        self.n_latent = n_latent

        # (output channels, stride) of each encoder convolution, in order.
        # NOTE(review): the original shape notes ([b, 250, 250, 4] ->
        # [b, 63, 63, 16] -> [b, 16, 16, 32] -> [b, 8, 8, 64]) presume an
        # input of roughly [b, 1000, 1000, c] -- TODO confirm with callers.
        encoder_spec = ((4, 4), (16, 4), (32, 4), (64, 2))
        encoder_layers = []
        for channels, stride in encoder_spec:
            encoder_layers.append(
                snt.Conv2D(channels, kernel_size, stride=stride,
                           padding='SAME'))
            encoder_layers.append(tf.nn.relu)
        encoder_layers.append(snt.Flatten())
        self.encoder = snt.Sequential(encoder_layers)

        # Two single-layer heads of width n_latent.
        latent_sizes = [n_latent]
        self.mn = snt.nets.MLP(latent_sizes, activation=tf.nn.relu)
        self.std = snt.nets.MLP([n_latent], activation=tf.nn.relu)

        # The decoder starts from a dense layer reshaped to [8, 8, 64] and
        # upsamples with transposed convolutions. With SAME padding the
        # spatial size is multiplied by the stride at each step, matching
        # the original notes: 16 -> 64 -> 256 -> 1024, ending at
        # [b, 1024, 1024, 1] after the final convolution.
        decoder_layers = [
            snt.nets.MLP([8 * 8 * 64], activation=tf.nn.leaky_relu),
            snt.Reshape([8, 8, 64]),
        ]
        decoder_spec = ((64, 2), (32, 4), (16, 4), (4, 4))
        for channels, stride in decoder_spec:
            decoder_layers.append(
                snt.Conv2DTranspose(channels, kernel_size, stride=stride,
                                    padding='SAME'))
            decoder_layers.append(tf.nn.relu)
        decoder_layers.append(snt.Conv2D(1, kernel_size, padding='SAME'))
        self.decoder = snt.Sequential(decoder_layers)
Esempio n. 3
0
 def decode(self, code):
     """Decode the image observation from a latent code.

     The code is projected by a linear layer to as many units as
     `self._convnet_output_shape` has elements, passed through a ReLU,
     reshaped to that shape and fed to the transpose of `self._convnet`.

     Args:
         code: Latent code given to the initial linear layer.

     Returns:
         The output of the transposed conv net.

     Raises:
         ValueError: If `self._convnet_output_shape` is None, i.e. `encode`
             has not been called yet.
     """
     if self._convnet_output_shape is None:
         raise ValueError('Must call `encode` before `decode`.')

     # Project the code up to the flattened size of the conv net output.
     initial_linear = snt.Linear(
         self._convnet_output_shape.num_elements(),
         name='decode_initial_linear')
     transpose_convnet_in_flat = tf.nn.relu(initial_linear(code))

     # Restore the conv net output shape before running the transpose.
     transpose_convnet_in = snt.Reshape(
         self._convnet_output_shape.as_list())(transpose_convnet_in_flat)

     # The leftover debug print() that echoed the docstring to stdout on
     # every call was removed; decoding should not write to stdout.
     return self._convnet.transpose(None)(transpose_convnet_in)