def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    hidden_sizes: Tuple[int, ...],
):
    """Build the state, reward and discount prediction networks.

    Args:
      environment_spec: Environment description; `actions.num_values` and
        `observations.shape` are read from it.
      hidden_sizes: Hidden layer widths used by all three MLPs. Must be a
        tuple, since each network's output width is appended with `+`.
    """
    super(MLPTransitionModel, self).__init__(name='mlp_transition_model')

    # Cache the action count and the observation geometry.
    self._num_actions = environment_spec.actions.num_values
    self._input_shape = environment_spec.observations.shape
    self._flat_shape = int(np.prod(self._input_shape))

    def _drop_last_axis(tensor):
        # Collapse the trailing size-1 axis produced by a width-1 MLP output.
        return tf.squeeze(tensor, axis=-1)

    def _scalar_head():
        # MLP ending in one unit, followed by removal of that unit axis.
        head_mlp = snt.nets.MLP(hidden_sizes + (1,))
        return snt.Sequential([head_mlp, _drop_last_axis])

    # Next-state predictor: flat MLP output reshaped to the observation shape.
    state_mlp = snt.nets.MLP(hidden_sizes + (self._flat_shape,))
    state_reshape = snt.Reshape(self._input_shape)
    self._state_network = snt.Sequential([state_mlp, state_reshape])

    # Reward and discount predictors share the same scalar-head layout.
    self._reward_network = _scalar_head()
    self._discount_network = _scalar_head()
def __init__(self, n_latent=4, kernel_size=4, name=None):
    """Construct the encoder, the two latent heads and the decoder.

    Args:
      n_latent: Output width of the `mn` and `std` heads.
      kernel_size: Kernel size passed to every convolution and transposed
        convolution.
      name: Module name forwarded to the parent constructor.
    """
    super(VariationalAutoEncoder, self).__init__(name=name)
    self.n_latent = n_latent

    # (output channels, stride) for each encoder convolution. The shape
    # notes are carried over from the original annotations.
    # NOTE(review): they imply a [b, 1000, 1000, C] input -- confirm.
    encoder_plan = (
        (4, 4),   # [b, 250, 250, 4]
        (16, 4),  # [b, 63, 63, 16]
        (32, 4),  # [b, 16, 16, 32]
        (64, 2),  # [b, 8, 8, 64]
    )
    encoder_layers = []
    for channels, stride in encoder_plan:
        encoder_layers.append(
            snt.Conv2D(channels, kernel_size, stride=stride, padding='SAME'))
        encoder_layers.append(tf.nn.relu)
    encoder_layers.append(snt.Flatten())
    self.encoder = snt.Sequential(encoder_layers)

    # Latent heads; presumably mean and standard deviation, going by the
    # attribute names.
    # NOTE(review): each MLP has a single layer, so with the default
    # `activate_final` the relu is presumably never applied -- confirm.
    self.mn = snt.nets.MLP([n_latent], activation=tf.nn.relu)
    self.std = snt.nets.MLP([n_latent], activation=tf.nn.relu)

    # (output channels, stride) for each decoder transposed convolution.
    # NOTE(review): per the original annotations the decoder ends at
    # 1024x1024, not the 1000x1000 implied by the encoder notes -- confirm.
    decoder_plan = (
        (64, 2),  # [b, 16, 16, 64]
        (32, 4),  # [b, 64, 64, 32]
        (16, 4),  # [b, 256, 256, 16]
        (4, 4),   # [b, 1024, 1024, 4]
    )
    decoder_layers = [
        snt.nets.MLP([8 * 8 * 64], activation=tf.nn.leaky_relu),
        snt.Reshape([8, 8, 64]),
    ]
    for channels, stride in decoder_plan:
        decoder_layers.append(
            snt.Conv2DTranspose(
                channels, kernel_size, stride=stride, padding='SAME'))
        decoder_layers.append(tf.nn.relu)
    # Final convolution down to a single channel: [b, 1024, 1024, 1].
    decoder_layers.append(snt.Conv2D(1, kernel_size, padding='SAME'))
    self.decoder = snt.Sequential(decoder_layers)
def decode(self, code):
    """Decode the image observation from a latent code.

    The code is projected by a linear layer to the number of elements in
    `self._convnet_output_shape`, passed through a relu, reshaped to that
    shape and fed through the transpose of `self._convnet`.

    Args:
      code: Latent code tensor to decode.

    Returns:
      The output of the transposed convnet.

    Raises:
      ValueError: If `self._convnet_output_shape` is None. Per the error
        message, `encode` must be called first.
    """
    output_shape = self._convnet_output_shape
    if output_shape is None:
        raise ValueError('Must call `encode` before `decode`.')

    # Project the code up to the flattened size of the convnet output.
    initial_linear = snt.Linear(
        output_shape.num_elements(), name='decode_initial_linear')
    flat_features = tf.nn.relu(initial_linear(code))

    # Restore the spatial layout expected by the transposed convnet.
    spatial_features = snt.Reshape(output_shape.as_list())(flat_features)

    # A leftover debugging print that echoed the docstring to stdout on
    # every call was removed here; decoding has no console side effects.
    return self._convnet.transpose(None)(spatial_features)