Example #1
 def __call__(self, x):
     initializer = nn.initializers.variance_scaling(scale=1.0 /
                                                    jnp.sqrt(3.0),
                                                    mode='fan_in',
                                                    distribution='uniform')
     x = x.astype(jnp.float32) / 255.
     x = nn.Conv(features=32,
                 kernel_size=(8, 8),
                 strides=(4, 4),
                 kernel_init=initializer)(x)
     x = nn.relu(x)
     x = nn.Conv(features=64,
                 kernel_size=(4, 4),
                 strides=(2, 2),
                 kernel_init=initializer)(x)
     x = nn.relu(x)
     x = nn.Conv(features=64,
                 kernel_size=(3, 3),
                 strides=(1, 1),
                 kernel_init=initializer)(x)
     x = nn.relu(x)
     x = x.reshape((-1))  # flatten
     x = nn.Dense(features=512, kernel_init=initializer)(x)
     x = nn.relu(x)
     x = nn.Dense(features=self.num_actions * self.num_atoms,
                  kernel_init=initializer)(x)
     logits = x.reshape((self.num_actions, self.num_atoms))
     probabilities = nn.softmax(logits)
     q_values = jnp.mean(logits, axis=1)
     return atari_lib.RainbowNetworkType(q_values, logits, probabilities)
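
This variant reads q_values off as the per-action mean of the logits rather than as an expectation over the distributional support (compare Examples #5 and #7 below, which weight by support). For reference, the standard C51/Rainbow readout looks like this; the support bounds here are made up:

import jax
import jax.numpy as jnp

num_actions, num_atoms = 4, 51
logits = jnp.zeros((num_actions, num_atoms))          # stand-in network output
support = jnp.linspace(-10.0, 10.0, num_atoms)        # atom locations z_i
probabilities = jax.nn.softmax(logits, axis=-1)       # per-action distribution
q_values = jnp.sum(support * probabilities, axis=-1)  # E[Z(s, a)] per action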
Example #2
 def apply(self, x, num_actions, num_atoms):
     initializer = jax.nn.initializers.variance_scaling(
         scale=1.0 / jnp.sqrt(3.0), mode='fan_in', distribution='uniform')
     # We need to add a "batch dimension" as nn.Conv expects it, yet vmap will
     # have removed the true batch dimension.
     x = x[None, ...]
     x = x.astype(jnp.float32) / 255.
     x = nn.Conv(x,
                 features=32,
                 kernel_size=(8, 8),
                 strides=(4, 4),
                 kernel_init=initializer)
     x = jax.nn.relu(x)
     x = nn.Conv(x,
                 features=64,
                 kernel_size=(4, 4),
                 strides=(2, 2),
                 kernel_init=initializer)
     x = jax.nn.relu(x)
     x = nn.Conv(x,
                 features=64,
                 kernel_size=(3, 3),
                 strides=(1, 1),
                 kernel_init=initializer)
     x = jax.nn.relu(x)
     x = x.reshape((x.shape[0], -1))  # flatten
     x = nn.Dense(x, features=512, kernel_init=initializer)
     x = jax.nn.relu(x)
     x = nn.Dense(x,
                  features=num_actions * num_atoms,
                  kernel_init=initializer)
     logits = x.reshape((x.shape[0], num_actions, num_atoms))
     probabilities = nn.softmax(logits)
     q_values = jnp.mean(logits, axis=2)
     return atari_lib.RainbowNetworkType(q_values, logits, probabilities)
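
The comment about the missing batch dimension reflects how these apply methods process a single observation; the surrounding agent presumably maps them over a batch with jax.vmap. A self-contained sketch of that pattern (TinyNet and all shapes are illustrative, not from the source):

import jax
import jax.numpy as jnp
import flax.linen as nn

class TinyNet(nn.Module):
    # Minimal stand-in for the unbatched networks in these examples.
    @nn.compact
    def __call__(self, x):
        return nn.Dense(features=4)(x.reshape(-1))

network_def = TinyNet()
states = jnp.ones((32, 8, 8, 1))  # batch of 32 observations
params = network_def.init(jax.random.PRNGKey(0), states[0])
# Share the params (None) while mapping over the leading axis of states (0).
q = jax.vmap(network_def.apply, in_axes=(None, 0))(params, states)  # (32, 4)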
Example #3
 def call(self, state):
     inputs = tf.constant(np.zeros((state.shape[0], stack_size)),
                          dtype=tf.float32)
     net = self.layer(inputs)
     logits = tf.reshape(net,
                         [-1, self.num_actions, self.num_atoms])
     probabilities = tf.keras.activations.softmax(logits)
     qs = tf.reduce_sum(self.support * probabilities, axis=2)
     return atari_lib.RainbowNetworkType(qs, logits, probabilities)
Example #4
  def apply(self, x, num_actions, minatar, env, normalize_obs, noisy,
            dueling, num_atoms, hidden_layer=2, neurons=512):
    del normalize_obs

    if minatar:
      x = x.squeeze(3)
      x = x[None, ...]
      x = x.astype(jnp.float32)
      x = nn.Conv(x, features=16, kernel_size=(3, 3, 3), strides=(1, 1, 1),
                  kernel_init=nn.initializers.xavier_uniform())
      x = jax.nn.relu(x)
      x = x.reshape((x.shape[0], -1))

    else:
      x = x[None, ...]
      x = x.astype(jnp.float32)
      x = x.reshape((x.shape[0], -1))


    if env is not None:
      x = x - env_inf[env]['MIN_VALS']
      x /= env_inf[env]['MAX_VALS'] - env_inf[env]['MIN_VALS']
      x = 2.0 * x - 1.0


    if noisy:
      def net(x, features):
        return NoisyNetwork(x, features)
    else:
      def net(x, features):
        return nn.Dense(x, features, kernel_init=nn.initializers.xavier_uniform())


    for _ in range(hidden_layer):
      x = net(x, features=neurons)
      x = jax.nn.relu(x)

    if dueling:
      print('dueling')
      adv = net(x, features=num_actions * num_atoms)
      value = net(x, features=num_atoms)
      adv = adv.reshape((adv.shape[0], num_actions, num_atoms))
      value = value.reshape((value.shape[0], 1, num_atoms))
      logits = value + (adv - jnp.mean(adv, axis=1, keepdims=True))  # mean over actions
      probabilities = nn.softmax(logits)
      q_values = jnp.mean(logits, axis=2)

    else:
      x = net(x, features=num_actions * num_atoms)
      logits = x.reshape((x.shape[0], num_actions, num_atoms))
      probabilities = nn.softmax(logits)
      q_values = jnp.mean(logits, axis=2)


    return atari_lib.RainbowNetworkType(q_values, logits, probabilities)
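
NoisyNetwork is used here but not defined in the excerpt; it presumably implements the factorized noisy linear layer of Fortunato et al. (2018). A minimal flax.linen sketch of that idea; the class name, initial values, and the rng-passing convention are my assumptions, not the original implementation:

import jax
import jax.numpy as jnp
import flax.linen as nn

def scale_noise(x):
    # f(x) = sign(x) * sqrt(|x|), the factorized-noise transform from the paper.
    return jnp.sign(x) * jnp.sqrt(jnp.abs(x))

class NoisyDense(nn.Module):
    """Hypothetical stand-in for NoisyNetwork: dense layer with factorized noise."""
    features: int

    @nn.compact
    def __call__(self, x, rng):
        in_features = x.shape[-1]
        bound = 1.0 / jnp.sqrt(in_features)
        def mu_init(key, shape, dtype=jnp.float32):
            return jax.random.uniform(key, shape, dtype, -bound, bound)
        sigma_init = nn.initializers.constant(0.5 * bound)
        w_mu = self.param('kernel_mu', mu_init, (in_features, self.features))
        w_sigma = self.param('kernel_sigma', sigma_init,
                             (in_features, self.features))
        b_mu = self.param('bias_mu', mu_init, (self.features,))
        b_sigma = self.param('bias_sigma', sigma_init, (self.features,))
        k_in, k_out = jax.random.split(rng)
        eps_in = scale_noise(jax.random.normal(k_in, (in_features,)))
        eps_out = scale_noise(jax.random.normal(k_out, (self.features,)))
        w = w_mu + w_sigma * jnp.outer(eps_in, eps_out)  # factorized noise
        b = b_mu + b_sigma * eps_out
        return x @ w + b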
Example #5
 def __call__(self, x, support):
   initializer = nn.initializers.xavier_uniform()
   x = x.astype(jnp.float32)
   x = nn.Conv(features=16, kernel_size=(3, 3), strides=(1, 1),
               kernel_init=initializer)(x)
   x = nn.relu(x)
   x = x.reshape(-1)  # flatten
   x = nn.Dense(features=self.num_actions * self.num_atoms,
                kernel_init=initializer)(x)
   logits = x.reshape((self.num_actions, self.num_atoms))
   probabilities = nn.softmax(logits)
   q_values = jnp.sum(support * probabilities, axis=1)
   return atari_lib.RainbowNetworkType(q_values, logits, probabilities)
Example #6
 def __call__(self, x, support):
     x = x.astype(jnp.float32)
     x = x.reshape((-1))  # flatten
      if self.min_vals is not None:
          x -= self.min_vals
          x /= self.max_vals - self.min_vals
         x = 2.0 * x - 1.0  # Rescale in range [-1, 1].
     for layer in self.layers:
         x = layer(x)
         x = nn.relu(x)
     x = self.final_layer(x)
     logits = x.reshape((self.num_actions, self.num_atoms))
     probabilities = nn.softmax(logits)
     q_values = jnp.sum(support * probabilities, axis=1)
     return atari_lib.RainbowNetworkType(q_values, logits, probabilities)
Example #7
 def __call__(self, x, support):
     initializer = nn.initializers.xavier_uniform()
     x = x.astype(jnp.float32)
     x = x.reshape((-1))  # flatten
     x -= gym_lib.ACROBOT_MIN_VALS
     x /= gym_lib.ACROBOT_MAX_VALS - gym_lib.ACROBOT_MIN_VALS
     x = 2.0 * x - 1.0  # Rescale in range [-1, 1].
     x = nn.Dense(features=512, kernel_init=initializer)(x)
     x = nn.relu(x)
     x = nn.Dense(features=512, kernel_init=initializer)(x)
     x = nn.relu(x)
     x = nn.Dense(features=self.num_actions * self.num_atoms,
                  kernel_init=initializer)(x)
     logits = x.reshape((self.num_actions, self.num_atoms))
     probabilities = nn.softmax(logits)
     q_values = jnp.sum(support * probabilities, axis=1)
     return atari_lib.RainbowNetworkType(q_values, logits, probabilities)
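
The three rescaling lines implement min-max normalization followed by an affine map into [-1, 1]. A quick numeric check with made-up bounds:

import jax.numpy as jnp

min_vals = jnp.array([-4.0, -1.0])  # hypothetical per-dimension bounds
max_vals = jnp.array([4.0, 1.0])
x = jnp.array([0.0, 1.0])
x = (x - min_vals) / (max_vals - min_vals)  # -> [0.5, 1.0]
x = 2.0 * x - 1.0                           # -> [0.0, 1.0], both within [-1, 1]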
Example #8
      def apply(self, x, num_actions, num_atoms, support):
        def custom_init(key, shape, dtype=jnp.float32):
          del key
          to_pick_first_action = onp.ones(shape, dtype)
          to_pick_first_action[:, :num_atoms] = onp.arange(1, num_atoms + 1)
          return to_pick_first_action

        x = x[None, :]
        x = x.astype(jnp.float32)
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(x, features=num_actions * num_atoms,
                     kernel_init=custom_init,
                     bias_init=jax.nn.initializers.ones)
        logits = x.reshape((-1, num_actions, num_atoms))
        probabilities = nn.softmax(logits)
        qs = jnp.sum(support * probabilities, axis=2)
        return atari_lib.RainbowNetworkType(qs, logits, probabilities)
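
custom_init fills the kernel with ones, except that the columns feeding the first action's atoms get the values 1..num_atoms; together with bias_init=ones this makes the first action's logits dominate on the higher atoms for any non-negative input, so it reads like a test fixture that forces action 0 to be selected. A small demonstration with illustrative shapes:

import numpy as onp

num_actions, num_atoms, in_features = 3, 4, 2
kernel = onp.ones((in_features, num_actions * num_atoms))
kernel[:, :num_atoms] = onp.arange(1, num_atoms + 1)
logits = onp.ones(in_features) @ kernel + 1.0  # all-ones input, bias_init=ones
print(logits.reshape(num_actions, num_atoms))
# action 0: [3. 5. 7. 9.]; actions 1 and 2: [3. 3. 3. 3.]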
Example #9
      def __call__(self, x):
        def custom_init(key, shape, dtype=jnp.float32):
          del key
          to_pick_first_action = onp.ones(shape, dtype)
          to_pick_first_action[:, :self.num_atoms] = onp.arange(
              1, self.num_atoms + 1)
          return to_pick_first_action

        x = x.astype(jnp.float32)
        x = x.reshape((-1))  # flatten
        x = linen.Dense(features=self.num_actions * self.num_atoms,
                        kernel_init=custom_init,
                        bias_init=linen.initializers.ones)(x)
        logits = x.reshape((self.num_actions, self.num_atoms))
        probabilities = linen.softmax(logits)
        qs = jnp.mean(logits, axis=1)
        return atari_lib.RainbowNetworkType(qs, logits, probabilities)
Example #10
  def apply(self, x, num_actions, num_atoms, support, noisy, dueling):
    # We need to add a "batch dimension" as nn.Conv expects it, yet vmap will
    # have removed the true batch dimension.

    x = x[None, ...]
    x = x.astype(jnp.float32)
    x = x.reshape((x.shape[0], -1))  # flatten
    #x -= gym_lib.CARTPOLE_MIN_VALS
    #x /= gym_lib.CARTPOLE_MAX_VALS - gym_lib.CARTPOLE_MIN_VALS
    #x = 2.0 * x - 1.0  # Rescale in range [-1, 1].

    if noisy:
        print('LunarLander-Noisy[Johan]')
        initializer = None
        bias = True
        def net(x, features, bias, kernel_init):
            return NoisyNetwork(x, features, bias, kernel_init)
    else:
        initializer = nn.initializers.xavier_uniform()
        bias = None
        def net(x, features, bias, kernel_init):
            return nn.Dense(x, features, kernel_init=kernel_init)

    x = net(x, features=512, bias=bias, kernel_init=initializer)
    x = jax.nn.relu(x)
    x = net(x, features=512, bias=bias, kernel_init=initializer)
    x = jax.nn.relu(x)

    if dueling:
        print('LunarLanderRainbowFull-Dueling')
        adv = net(x, features=num_actions * num_atoms, bias=bias, kernel_init=initializer)
        value = net(x, features=num_atoms, bias=bias, kernel_init=initializer)
        adv = adv.reshape((adv.shape[0], num_actions, num_atoms))
        value = value.reshape((value.shape[0], 1, num_atoms))
        logits = value + (adv - jnp.mean(adv, axis=1, keepdims=True))  # mean over actions
        probabilities = nn.softmax(logits)
        q_values = jnp.sum(support * probabilities, axis=2)

    else:
        x = net(x, features=num_actions * num_atoms, bias=bias, kernel_init=initializer)
        logits = x.reshape((x.shape[0], num_actions, num_atoms))
        probabilities = nn.softmax(logits)
        q_values = jnp.sum(support * probabilities, axis=2)
    
    return atari_lib.RainbowNetworkType(q_values, logits, probabilities)
Example #11
    def __call__(self, x, support, eval_mode=False, key=None):
        # Generate a random number generation key if not provided
        if key is None:
            key = jax.random.PRNGKey(int(time.time() * 1e6))

        if not self.inputs_preprocessed:
            x = preprocess_atari_inputs(x)

        hidden_sizes = [32, 64, 64]
        kernel_sizes = [8, 4, 3]
        stride_sizes = [4, 2, 1]
        for hidden_size, kernel_size, stride_size in zip(
                hidden_sizes, kernel_sizes, stride_sizes):
            x = nn.Conv(features=hidden_size,
                        kernel_size=(kernel_size, kernel_size),
                        strides=(stride_size, stride_size),
                        kernel_init=nn.initializers.xavier_uniform())(x)
            x = nn.relu(x)
        x = x.reshape((-1))  # flatten

        net = feature_layer(key, self.noisy, eval_mode=eval_mode)
        x = net(x, features=512)  # Single hidden layer of size 512
        x = nn.relu(x)

        if self.dueling:
            adv = net(x, features=self.num_actions * self.num_atoms)
            value = net(x, features=self.num_atoms)
            adv = adv.reshape((self.num_actions, self.num_atoms))
            value = value.reshape((1, self.num_atoms))
            logits = value + (adv - (jnp.mean(adv, axis=0, keepdims=True)))
        else:
            x = net(x, features=self.num_actions * self.num_atoms)
            logits = x.reshape((self.num_actions, self.num_atoms))

        if self.distributional:
            probabilities = nn.softmax(logits)
            q_values = jnp.sum(support * probabilities, axis=1)
            return atari_lib.RainbowNetworkType(q_values, logits,
                                                probabilities)
        q_values = jnp.sum(logits, axis=1)  # Sum over all the num_atoms
        return atari_lib.DQNNetworkType(q_values)
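
feature_layer is not shown in this excerpt; presumably it returns a closure that builds either a noisy or a plain Dense layer depending on self.noisy, threading the RNG key through for noise sampling. A plausible sketch reusing the hypothetical NoisyDense stand-in from the sketch after Example #4; every name here is an assumption:

import flax.linen as nn

def feature_layer(key, noisy, eval_mode=False):
    # Hypothetical helper: dispatch between a noisy and a plain dense layer.
    # eval_mode would plausibly switch NoisyDense to its deterministic
    # (mu-only) path; it is unused in this sketch.
    def net(x, features):
        if noisy:
            return NoisyDense(features=features)(x, key)
        return nn.Dense(features,
                        kernel_init=nn.initializers.xavier_uniform())(x)
    return net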
Example #12
  def apply(self, x, num_actions, num_atoms):
    
    initializer = nn.initializers.xavier_uniform()
    # We need to add a "batch dimension" as nn.Conv expects it, yet vmap will
    # have removed the true batch dimension.
    x = x[None, ...]
    x = x.astype(jnp.float32)
    x = x.reshape((x.shape[0], -1))  # flatten
    x -= gym_lib.CARTPOLE_MIN_VALS
    x /= gym_lib.CARTPOLE_MAX_VALS - gym_lib.CARTPOLE_MIN_VALS
    x = 2.0 * x - 1.0  # Rescale in range [-1, 1].
    x = nn.Dense(x, features=512, kernel_init=initializer)
    x = jax.nn.relu(x)
    x = nn.Dense(x, features=512, kernel_init=initializer)
    x = jax.nn.relu(x)
    x = nn.Dense(x, features=num_actions * num_atoms, kernel_init=initializer)

    logits = x.reshape((x.shape[0], num_actions, num_atoms))
    probabilities = nn.softmax(logits)
    q_values = jnp.mean(logits, axis=2)
    return atari_lib.RainbowNetworkType(q_values, logits, probabilities)
Example #13
  def apply(self, x, num_actions, num_atoms, support, noisy, dueling):
    # We need to add a "batch dimension" as nn.Conv expects it, yet vmap will
    # have removed the true batch dimension.
    initializer_conv = nn.initializers.xavier_uniform()
    x = x[None, ...]
    x = x.astype(jnp.float32)
    x = nn.Conv(x, features=16, kernel_size=(3, 3, 3), strides=(1, 1, 1),
                kernel_init=initializer_conv)
    x = jax.nn.relu(x)
    x = x.reshape((x.shape[0], -1))  # flatten.

    if noisy:
        print('InvadersRainbowFull-Noisy[Johan]')
        initializer = None
        bias = True
        def net(x, features, bias, kernel_init):
            return NoisyNetwork(x, features, bias, kernel_init)
    else:
        initializer = nn.initializers.xavier_uniform()
        bias = None
        def net(x, features, bias, kernel_init):
            return nn.Dense(x, features, kernel_init=kernel_init)

    if dueling:
        print('InvadersRainbowFull-Dueling')
        adv = net(x, features=num_actions * num_atoms, bias=bias, kernel_init=initializer)
        value = net(x, features=num_atoms, bias=bias, kernel_init=initializer)
        adv = adv.reshape((adv.shape[0], num_actions, num_atoms))
        value = value.reshape((value.shape[0], 1, num_atoms))
        logits = value + (adv - jnp.mean(adv, axis=1, keepdims=True))  # mean over actions
        probabilities = nn.softmax(logits)
        q_values = jnp.sum(support * probabilities, axis=2)

    else:
        x = net(x, features=num_actions * num_atoms, bias=bias, kernel_init=initializer)
        logits = x.reshape((x.shape[0], num_actions, num_atoms))
        probabilities = nn.softmax(logits)
        q_values = jnp.sum(support * probabilities, axis=2)
    
    return atari_lib.RainbowNetworkType(q_values, logits, probabilities)  
Example #14
      def __call__(self, x, support, eval_mode=False, key=None):

        def custom_init(key, shape, dtype=jnp.float32):
          del key
          to_pick_first_action = onp.ones(shape, dtype)
          to_pick_first_action[:, :self.num_atoms] = onp.arange(
              1, self.num_atoms + 1)
          return to_pick_first_action

        x = x.astype(jnp.float32)
        x = x.reshape((-1))  # flatten
        x = nn.Dense(features=self.num_actions * self.num_atoms,
                     kernel_init=custom_init,
                     bias_init=nn.initializers.ones)(x)
        logits = x.reshape((self.num_actions, self.num_atoms))
        if not self.distributional:
          qs = jnp.sum(logits, axis=-1)  # Sum over all the num_atoms
          return atari_lib.DQNNetworkType(qs)
        probabilities = nn.softmax(logits)
        qs = jnp.sum(support * probabilities, axis=1)
        return atari_lib.RainbowNetworkType(qs, logits, probabilities)
Example #15
    def __call__(self, x, support, rng):

        if self.net_conf == 'minatar':
            x = x.squeeze(3)
            x = x.astype(jnp.float32)
            for _ in range(self.hidden_conv):
                x = nn.Conv(features=16,
                            kernel_size=(3, 3),
                            strides=(1, 1),
                            padding='SAME',
                            kernel_init=self.initzer)(x)
                x = layer_funct_inf[self.layer_funct](x)
            x = nn.Conv(features=16,
                        kernel_size=(3, 3),
                        strides=(1, 1),
                        kernel_init=self.initzer)(x)
            x = layer_funct_inf[self.layer_funct](x)
            x = x.reshape((-1))

        elif self.net_conf == 'atari':
            # We need to add a "batch dimension" as nn.Conv expects it, yet vmap will
            # have removed the true batch dimension.
            x = x.astype(jnp.float32) / 255.
            x = nn.Conv(features=32,
                        kernel_size=(8, 8),
                        strides=(4, 4),
                        kernel_init=self.initzer)(x)
            x = layer_funct_inf[self.layer_funct](x)
            x = nn.Conv(features=64,
                        kernel_size=(4, 4),
                        strides=(2, 2),
                        kernel_init=self.initzer)(x)
            x = layer_funct_inf[self.layer_funct](x)
            x = nn.Conv(features=64,
                        kernel_size=(3, 3),
                        strides=(1, 1),
                        kernel_init=self.initzer)(x)
            x = layer_funct_inf[self.layer_funct](x)
            x = x.reshape((-1))  # flatten

        elif self.net_conf == 'classic':
            x = x.astype(jnp.float32)
            x = x.reshape((-1))

        if self.env is not None and self.env in env_inf:
            x = x - env_inf[self.env]['MIN_VALS']
            x /= env_inf[self.env]['MAX_VALS'] - env_inf[self.env]['MIN_VALS']
            x = 2.0 * x - 1.0

        if self.noisy:

            def net(x, features, rng):
                return NoisyNetwork(features, rng=rng, bias_in=True)(x)
        else:

            def net(x, features, rng):
                return nn.Dense(features, kernel_init=self.initzer)(x)

        for _ in range(self.hidden_layer):
            x = net(x, features=self.neurons, rng=rng)
            if self.normalization == 'non_normalization':
                if self.layer_funct != 'non_activation':
                    x = layer_funct_inf[self.layer_funct](x)
            elif self.normalization == 'BatchNorm':
                x = nn.BatchNorm(use_running_average=True)(x)
                if self.layer_funct != 'non_activation':
                    x = layer_funct_inf[self.layer_funct](x)
            elif self.normalization == 'LayerNorm':
                if self.layer_funct != 'non_activation':
                    x = layer_funct_inf[self.layer_funct](x)
                x = nn.LayerNorm()(x)
            else:
                print('Error: choose a valid normalization module')

        if self.dueling:
            adv = net(x, features=self.num_actions * self.num_atoms, rng=rng)
            value = net(x, features=self.num_atoms, rng=rng)

            adv = adv.reshape((self.num_actions, self.num_atoms))
            value = value.reshape((1, self.num_atoms))

            logits = value + (adv - (jnp.mean(adv, -2, keepdims=True)))
            probabilities = nn.softmax(logits)
            q_values = jnp.sum(support * probabilities, axis=1)

        else:
            x = net(x, features=self.num_actions * self.num_atoms, rng=rng)
            logits = x.reshape((self.num_actions, self.num_atoms))
            probabilities = nn.softmax(logits)
            q_values = jnp.sum(support * probabilities, axis=1)

        return atari_lib.RainbowNetworkType(q_values, logits, probabilities)
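
Note the ordering in the hidden-layer loop above: BatchNorm is applied to the pre-activations and followed by the activation, while LayerNorm comes after the activation. The same ordering, condensed into a runnable toy module (my framing, not the original code):

import jax
import jax.numpy as jnp
import flax.linen as nn

class NormBlock(nn.Module):
    normalization: str  # 'BatchNorm' or 'LayerNorm'

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(8)(x)
        if self.normalization == 'BatchNorm':
            x = nn.BatchNorm(use_running_average=True)(x)  # normalize, then activate
            x = nn.relu(x)
        else:
            x = nn.relu(x)  # activate, then normalize
            x = nn.LayerNorm()(x)
        return x

variables = NormBlock('LayerNorm').init(jax.random.PRNGKey(0), jnp.ones((4,)))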
Example #16
 def call(self, state):
   x = self.net(state)
   logits = tf.reshape(x, [-1, self.num_actions, self.num_atoms])
   probabilities = layers.softmax(logits)
   q_values = tf.reduce_sum(self.support * probabilities, axis=2)
   return atari_lib.RainbowNetworkType(q_values, logits, probabilities)
Example #17
    def __call__(self, x, rng):

        if self.net_conf == 'minatar':
            x = x.squeeze(3)
            x = x.astype(jnp.float32)
            x = nn.Conv(features=16,
                        kernel_size=(3, 3, 3),
                        strides=(1, 1, 1),
                        kernel_init=self.initzer)(x)
            x = jax.nn.relu(x)
            x = x.reshape((x.shape[0], -1))

        elif self.net_conf == 'atari':
            # We need to add a "batch dimension" as nn.Conv expects it, yet vmap will
            # have removed the true batch dimension.
            x = x.astype(jnp.float32) / 255.
            x = nn.Conv(features=32,
                        kernel_size=(8, 8),
                        strides=(4, 4),
                        kernel_init=self.initzer)(x)
            x = jax.nn.relu(x)
            x = nn.Conv(features=64,
                        kernel_size=(4, 4),
                        strides=(2, 2),
                        kernel_init=self.initzer)(x)
            x = jax.nn.relu(x)
            x = nn.Conv(features=64,
                        kernel_size=(3, 3),
                        strides=(1, 1),
                        kernel_init=self.initzer)(x)
            x = jax.nn.relu(x)
            x = x.reshape((-1))  # flatten

        elif self.net_conf == 'classic':
            #classic environments
            x = x.astype(jnp.float32)
            x = x.reshape((-1))

        if self.env is not None and self.env in env_inf:
            x = x - env_inf[self.env]['MIN_VALS']
            x /= env_inf[self.env]['MAX_VALS'] - env_inf[self.env]['MIN_VALS']
            x = 2.0 * x - 1.0

        if self.noisy:

            def net(x, features, rng):
                return NoisyNetwork(features, rng=rng, bias_in=True)(x)
        else:

            def net(x, features, rng):
                return nn.Dense(features, kernel_init=self.initzer)(x)

        for _ in range(self.hidden_layer):
            x = net(x, features=self.neurons, rng=rng)
            x = jax.nn.relu(x)

        if self.dueling:
            adv = net(x, features=self.num_actions * self.num_atoms, rng=rng)
            value = net(x, features=self.num_atoms, rng=rng)
            adv = adv.reshape((self.num_actions, self.num_atoms))
            value = value.reshape((1, self.num_atoms))
            logits = value + (adv - (jnp.mean(adv, -2, keepdims=True)))
            probabilities = nn.softmax(logits)
            q_values = jnp.mean(logits, axis=1)

        else:
            x = net(x, features=self.num_actions * self.num_atoms, rng=rng)
            logits = x.reshape((self.num_actions, self.num_atoms))
            probabilities = nn.softmax(logits)
            q_values = jnp.mean(logits, axis=1)

        return atari_lib.RainbowNetworkType(q_values, logits, probabilities)
Example #18
    def apply(self,
              x,
              num_actions,
              net_conf,
              env,
              normalize_obs,
              noisy,
              dueling,
              num_atoms,
              hidden_layer=2,
              neurons=512):
        del normalize_obs

        if net_conf == 'minatar':
            x = x.squeeze(3)
            x = x[None, ...]
            x = x.astype(jnp.float32)
            x = nn.Conv(x,
                        features=16,
                        kernel_size=(3, 3, 3),
                        strides=(1, 1, 1),
                        kernel_init=nn.initializers.xavier_uniform())
            x = jax.nn.relu(x)
            x = x.reshape((x.shape[0], -1))

        elif net_conf == 'atari':
            # We need to add a "batch dimension" as nn.Conv expects it, yet vmap will
            # have removed the true batch dimension.
            x = x[None, ...]
            x = x.astype(jnp.float32) / 255.
            x = nn.Conv(x,
                        features=32,
                        kernel_size=(8, 8),
                        strides=(4, 4),
                        kernel_init=nn.initializers.xavier_uniform())
            x = jax.nn.relu(x)
            x = nn.Conv(x,
                        features=64,
                        kernel_size=(4, 4),
                        strides=(2, 2),
                        kernel_init=nn.initializers.xavier_uniform())
            x = jax.nn.relu(x)
            x = nn.Conv(x,
                        features=64,
                        kernel_size=(3, 3),
                        strides=(1, 1),
                        kernel_init=nn.initializers.xavier_uniform())
            x = jax.nn.relu(x)
            x = x.reshape((x.shape[0], -1))  # flatten

        elif net_conf == 'classic':
            #classic environments
            x = x[None, ...]
            x = x.astype(jnp.float32)
            x = x.reshape((x.shape[0], -1))

        if env is not None:
            x = x - env_inf[env]['MIN_VALS']
            x /= env_inf[env]['MAX_VALS'] - env_inf[env]['MIN_VALS']
            x = 2.0 * x - 1.0

        if noisy:

            def net(x, features):
                return NoisyNetwork(x, features)
        else:

            def net(x, features):
                return nn.Dense(x,
                                features,
                                kernel_init=nn.initializers.xavier_uniform())

        for _ in range(hidden_layer):
            x = net(x, features=neurons)
            x = jax.nn.relu(x)

        if dueling:
            adv = net(x, features=num_actions * num_atoms)
            value = net(x, features=num_atoms)
            adv = adv.reshape((adv.shape[0], num_actions, num_atoms))
            value = value.reshape((value.shape[0], 1, num_atoms))
            logits = value + (adv - jnp.mean(adv, axis=1, keepdims=True))  # mean over actions
            probabilities = nn.softmax(logits)
            q_values = jnp.mean(logits, axis=2)

        else:
            x = net(x, features=num_actions * num_atoms)
            logits = x.reshape((x.shape[0], num_actions, num_atoms))
            probabilities = nn.softmax(logits)
            q_values = jnp.mean(logits, axis=2)

        return atari_lib.RainbowNetworkType(q_values, logits, probabilities)