Example #1
0
 def call(self, state, num_quantiles):
     """Constant-output forward pass (appears to be a test stub).

     Builds an all-ones input of shape (batch * num_quantiles,
     num_actions), pushes it through ``self.layer``, and pairs the result
     with all-ones quantiles.

     Args:
       state: input tensor; only its static batch dimension is used.
       num_quantiles: number of quantile samples per state.

     Returns:
       atari_lib.ImplicitQuantileNetworkType(quantile_values, quantiles).
     """
     batch = state.get_shape().as_list()[0]
     rows = batch * num_quantiles
     ones_input = tf.constant(
         np.ones((rows, self.num_actions)), dtype=tf.float32)
     # Quantiles are fixed to 1.0 — no sampling happens here.
     quantiles = tf.ones([rows, 1])
     return atari_lib.ImplicitQuantileNetworkType(
         self.layer(ones_input), quantiles)
 def __call__(self, x, num_quantiles, rng):
     """Deterministic forward pass (appears to be a test stub).

     Flattens the input, tiles it ``num_quantiles`` times, multiplies the
     flat vector against the tiled copy (broadcasting to a
     (num_quantiles, features) matrix), and maps it through a Dense layer
     with all-ones kernel and zero bias.

     Args:
       x: input array; flattened to 1-D before use.
       num_quantiles: number of quantile rows to produce.
       rng: unused — the pass is deterministic.

     Returns:
       atari_lib.ImplicitQuantileNetworkType(quantile_values, quantiles).
     """
     del rng  # deterministic stub: no randomness consumed
     flat = x.reshape((-1))  # flatten
     tiled = jnp.tile(flat, [num_quantiles, 1])
     mixed = flat * tiled  # broadcasts to (num_quantiles, features)
     dense = linen.Dense(
         features=self.num_actions,
         kernel_init=linen.initializers.ones,
         bias_init=linen.initializers.zeros)
     # Quantiles are fixed to 1.0 — no sampling happens here.
     return atari_lib.ImplicitQuantileNetworkType(
         dense(mixed), jnp.ones([num_quantiles, 1]))
Example #3
0
 def apply(self, x, num_actions, quantile_embedding_dim, num_quantiles,
           rng):
   """Constant-weight forward pass (appears to be a test stub).

   Flattens the (batch-expanded) input and maps it through a Dense layer
   whose kernel is all ones and bias all zeros; quantiles are fixed to 1.

   Args:
     x: input array; its leading axis is read as the batch size *before*
        a new leading axis is prepended.
     num_actions: output feature count of the Dense layer.
     quantile_embedding_dim: unused in this stub.
     num_quantiles: number of quantile rows (times batch_size).
     rng: unused — the pass is deterministic.

   Returns:
     atari_lib.ImplicitQuantileNetworkType(quantile_values, quantiles).
   """
   del rng
   # NOTE(review): the original comment here claimed this initializer
   # gives action 0 a higher weight so argmax picks it, but kernel_init
   # is all-ones — every action receives identical weights. Confirm the
   # intended tie-breaking behavior against the calling test.
   batch_size = x.shape[0]  # read before the new leading axis is added
   x = x[None, :]
   x = x.astype(jnp.float32)
   x = x.reshape((x.shape[0], -1))  # flatten
   quantile_values = nn.Dense(x, features=num_actions,
                              kernel_init=jax.nn.initializers.ones,
                              bias_init=jax.nn.initializers.zeros)
   # Quantiles are fixed to 1.0 — no sampling happens here.
   quantiles = jnp.ones([num_quantiles * batch_size, 1])
   return atari_lib.ImplicitQuantileNetworkType(quantile_values, quantiles)
Example #4
0
 def apply(self, x, num_actions, quantile_embedding_dim, num_quantiles,
           rng):
     """Implicit Quantile Network forward pass (legacy flax.nn style).

     Runs the Atari conv torso over the (batch-expanded) frame, samples
     ``num_quantiles`` quantiles, embeds them with a cosine basis, and
     merges state and quantile embeddings via an elementwise product
     before the final Dense head.

     Args:
       x: single observation; a batch axis is prepended internally.
       num_actions: size of the action-value output.
       quantile_embedding_dim: width of the cosine quantile embedding.
       num_quantiles: number of quantile samples to draw.
       rng: PRNG key used once, to sample the quantiles.

     Returns:
       atari_lib.ImplicitQuantileNetworkType(quantile_values, quantiles).
     """
     init = jax.nn.initializers.variance_scaling(
         scale=1.0 / jnp.sqrt(3.0), mode='fan_in', distribution='uniform')

     # nn.Conv expects a batch axis, but vmap has stripped the real one;
     # prepend a singleton batch and rescale pixels to [0, 1].
     net = x[None, ...].astype(jnp.float32) / 255.
     # Standard DQN conv torso; module call order matches the original so
     # parameter naming is unchanged.
     for n_features, kernel, stride in ((32, (8, 8), (4, 4)),
                                        (64, (4, 4), (2, 2)),
                                        (64, (3, 3), (1, 1))):
         net = jax.nn.relu(nn.Conv(net, features=n_features,
                                   kernel_size=kernel, strides=stride,
                                   kernel_init=init))
     net = net.reshape((net.shape[0], -1))  # flatten
     feature_dim = net.shape[-1]
     tiled_state = jnp.tile(net, [num_quantiles, 1])

     # Sample quantiles in [0, 1) and build their cosine embedding:
     # cos(i * pi * tau) for i = 1..quantile_embedding_dim.
     quantiles = jax.random.uniform(rng, shape=[num_quantiles, 1])
     embedding = jnp.tile(quantiles, [1, quantile_embedding_dim])
     embedding = jnp.cos(
         jnp.arange(1, quantile_embedding_dim + 1, 1).astype(jnp.float32) *
         onp.pi * embedding)
     embedding = jax.nn.relu(
         nn.Dense(embedding, features=feature_dim, kernel_init=init))

     # Hadamard product merges state and quantile embeddings (IQN paper).
     merged = tiled_state * embedding
     hidden = jax.nn.relu(nn.Dense(merged, features=512, kernel_init=init))
     values = nn.Dense(hidden, features=num_actions, kernel_init=init)
     return atari_lib.ImplicitQuantileNetworkType(values, quantiles)
Example #5
0
 def __call__(self, x, num_quantiles, rng):
     """Implicit Quantile Network forward pass (Flax linen style).

     Runs the Atari conv torso, samples ``num_quantiles`` quantiles,
     embeds them with a cosine basis, and merges state and quantile
     embeddings via an elementwise product before the action-value head.

     Args:
       x: single observation (no batch axis — vmap removed it).
       num_quantiles: number of quantile samples to draw.
       rng: PRNG key used once, to sample the quantiles.

     Returns:
       atari_lib.ImplicitQuantileNetworkType(quantile_values, quantiles).
     """
     init = nn.initializers.variance_scaling(
         scale=1.0 / jnp.sqrt(3.0), mode='fan_in', distribution='uniform')
     if not self.inputs_preprocessed:
         x = preprocess_atari_inputs(x)
     # Standard DQN conv torso; submodule creation order matches the
     # original so linen parameter naming is unchanged.
     for n_features, kernel, stride in ((32, (8, 8), (4, 4)),
                                        (64, (4, 4), (2, 2)),
                                        (64, (3, 3), (1, 1))):
         x = nn.relu(nn.Conv(features=n_features, kernel_size=kernel,
                             strides=stride, kernel_init=init)(x))
     x = x.reshape((-1))  # flatten — no batch axis present
     feature_dim = x.shape[-1]
     tiled_state = jnp.tile(x, [num_quantiles, 1])

     # Sample quantiles in [0, 1) and build their cosine embedding:
     # cos(i * pi * tau) for i = 1..quantile_embedding_dim.
     quantiles = jax.random.uniform(rng, shape=[num_quantiles, 1])
     embedding = jnp.tile(quantiles, [1, self.quantile_embedding_dim])
     embedding = jnp.cos(
         jnp.arange(1, self.quantile_embedding_dim + 1,
                    1).astype(jnp.float32) * onp.pi * embedding)
     embedding = nn.relu(
         nn.Dense(features=feature_dim, kernel_init=init)(embedding))

     # Hadamard product merges state and quantile embeddings (IQN paper).
     hidden = nn.relu(
         nn.Dense(features=512, kernel_init=init)(tiled_state * embedding))
     values = nn.Dense(features=self.num_actions, kernel_init=init)(hidden)
     return atari_lib.ImplicitQuantileNetworkType(values, quantiles)
    def __call__(self, x, num_quantiles, rng):
        """IQN forward pass with a configurable torso and optional extras.

        Supports three torsos selected by ``self.net_conf`` ('minatar',
        'atari', 'classic'), optional per-environment input normalization
        via ``env_inf``, optional NoisyNetwork layers, and an optional
        dueling head selected at runtime by ``self.dueling``.

        Args:
          x: a single observation (the 'atari' branch assumes vmap has
             removed the batch axis; the 'minatar' branch keeps x.shape[0]
             as a leading axis — verify the two are used consistently).
          num_quantiles: number of quantile samples to draw.
          rng: PRNG key. NOTE(review): the same key is reused for every
             noisy layer and for quantile sampling, so all noisy layers
             share identical noise and it correlates with the sampled
             quantiles — confirm this is intentional.

        Returns:
          atari_lib.ImplicitQuantileNetworkType(quantile_values, quantiles).
        """

        if self.net_conf == 'minatar':
            # Drop the trailing singleton axis, then one small conv layer.
            # NOTE(review): kernel_size/strides are 3-tuples here — this
            # specifies a 3-D convolution; confirm the squeezed input rank
            # matches.
            x = x.squeeze(3)
            x = x.astype(jnp.float32)
            x = nn.Conv(features=16,
                        kernel_size=(3, 3, 3),
                        strides=(1, 1, 1),
                        kernel_init=self.initzer)(x)
            x = jax.nn.relu(x)
            x = x.reshape((x.shape[0], -1))

        elif self.net_conf == 'atari':
            # We need to add a "batch dimension" as nn.Conv expects it, yet vmap will
            # have removed the true batch dimension.
            x = x.astype(jnp.float32) / 255.
            x = nn.Conv(features=32,
                        kernel_size=(8, 8),
                        strides=(4, 4),
                        kernel_init=self.initzer)(x)
            x = jax.nn.relu(x)
            x = nn.Conv(features=64,
                        kernel_size=(4, 4),
                        strides=(2, 2),
                        kernel_init=self.initzer)(x)
            x = jax.nn.relu(x)
            x = nn.Conv(features=64,
                        kernel_size=(3, 3),
                        strides=(1, 1),
                        kernel_init=self.initzer)(x)
            x = jax.nn.relu(x)
            x = x.reshape((-1))  # flatten

        elif self.net_conf == 'classic':
            # Classic control environments: no conv torso, just flatten.
            x = x.astype(jnp.float32)
            x = x.reshape((-1))

        # Optional min/max normalization to [-1, 1] using per-environment
        # bounds from the module-level env_inf table.
        if self.env is not None and self.env in env_inf:
            x = x - env_inf[self.env]['MIN_VALS']
            x /= env_inf[self.env]['MAX_VALS'] - env_inf[self.env]['MIN_VALS']
            x = 2.0 * x - 1.0

        # Select the fully connected layer factory: noisy (NoisyNet-style
        # parameter noise) or a plain Dense layer.
        if self.noisy:

            def net(x, features, rng):
                return NoisyNetwork(features, rng=rng, bias_in=True)(x)
        else:

            def net(x, features, rng):
                return nn.Dense(features, kernel_init=self.initzer)(x)

        # Hidden MLP torso on top of the (normalized) features.
        for _ in range(self.hidden_layer):
            x = net(x, features=self.neurons, rng=rng)
            x = jax.nn.relu(x)

        # IQN quantile embedding: sample tau in [0, 1), embed with a
        # cosine basis cos(i * pi * tau), project to the state feature
        # width, and merge via elementwise product.
        state_vector_length = x.shape[-1]
        state_net_tiled = jnp.tile(x, [num_quantiles, 1])
        quantiles_shape = [num_quantiles, 1]
        quantiles = jax.random.uniform(rng, shape=quantiles_shape)
        quantile_net = jnp.tile(quantiles, [1, self.quantile_embedding_dim])
        quantile_net = (jnp.arange(1, self.quantile_embedding_dim + 1,
                                   1).astype(jnp.float32) * onp.pi *
                        quantile_net)
        quantile_net = jnp.cos(quantile_net)
        quantile_net = nn.Dense(features=state_vector_length,
                                kernel_init=self.initzer)(quantile_net)
        quantile_net = jax.nn.relu(quantile_net)
        x = state_net_tiled * quantile_net

        # Both heads are always built (so parameters exist for each);
        # jnp.where picks dueling vs. plain output at runtime.
        adv = net(x, features=self.num_actions, rng=rng)
        val = net(x, features=1, rng=rng)
        # Dueling aggregation: V + (A - mean(A)).
        dueling_q = val + (adv - (jnp.mean(adv, -1, keepdims=True)))
        non_dueling_q = net(x, features=self.num_actions, rng=rng)
        quantile_values = jnp.where(self.dueling, dueling_q, non_dueling_q)

        return atari_lib.ImplicitQuantileNetworkType(quantile_values,
                                                     quantiles)