Example #1
    def apply(
        self,
        x,
        action_dim,
        max_action,
        key=None,
        MPO=False,
        sample=False,
        log_sig_min=-20,
        log_sig_max=2,
    ):
        x = nn.Dense(x, features=200)
        x = nn.LayerNorm(x)
        x = nn.tanh(x)
        x = nn.Dense(x, features=200)
        x = nn.elu(x)
        x = nn.Dense(x, features=2 * action_dim)

        mu, log_sig = jnp.split(x, 2, axis=-1)
        log_sig = nn.softplus(log_sig)
        log_sig = jnp.clip(log_sig, log_sig_min, log_sig_max)

        if MPO:
            return mu, log_sig

        if not sample:
            return max_action * nn.tanh(mu), log_sig
        else:
            pi = mu + random.normal(key, mu.shape) * jnp.exp(log_sig)
            log_pi = gaussian_likelihood(pi, mu, log_sig)
            pi = nn.tanh(pi)
            log_pi -= jnp.sum(jnp.log(nn.relu(1 - pi ** 2) + 1e-6), axis=1)
            return max_action * pi, log_pi
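The sampling branch above is the standard tanh-squashing change-of-variables correction for a Gaussian policy. Below is a minimal, framework-free sketch of that correction in plain JAX; the helper names (gaussian_log_prob, sample_squashed) are illustrative, not from the original code, and it assumes gaussian_likelihood in the example computes the diagonal-Gaussian log-density.

import jax
import jax.numpy as jnp

def gaussian_log_prob(x, mu, log_sig):
    # Diagonal-Gaussian log-density, summed over the action dimension.
    return jnp.sum(
        -0.5 * ((x - mu) / jnp.exp(log_sig)) ** 2 - log_sig - 0.5 * jnp.log(2.0 * jnp.pi),
        axis=-1)

def sample_squashed(key, mu, log_sig, max_action=1.0):
    # Reparameterized sample, squashed through tanh; the log-prob is corrected
    # with the change-of-variables term sum(log(1 - tanh(u)^2)).
    u = mu + jax.random.normal(key, mu.shape) * jnp.exp(log_sig)
    log_pi = gaussian_log_prob(u, mu, log_sig)
    a = jnp.tanh(u)
    log_pi -= jnp.sum(jnp.log(1.0 - a ** 2 + 1e-6), axis=-1)
    return max_action * a, log_pi

key = jax.random.PRNGKey(0)
mu, log_sig = jnp.zeros((4, 2)), jnp.full((4, 2), -1.0)
action, log_pi = sample_squashed(key, mu, log_sig)
print(action.shape, log_pi.shape)  # (4, 2) (4,)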
Example #2
 def apply(self,
           inputs,
           mlp_dim,
           dtype=jnp.float32,
           out_dim=None,
           dropout_rate=0.1,
           deterministic=True,
           kernel_init=nn.initializers.xavier_uniform(),
           bias_init=nn.initializers.normal(stddev=1e-6)):
   """Applies Transformer MlpBlock module."""
   actual_out_dim = inputs.shape[-1] if out_dim is None else out_dim
   x = nn.Dense(
       inputs,
       mlp_dim,
       dtype=dtype,
       kernel_init=kernel_init,
       bias_init=bias_init)
   x = nn.gelu(x)
   x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
   output = nn.Dense(
       x,
       actual_out_dim,
       dtype=dtype,
       kernel_init=kernel_init,
       bias_init=bias_init)
   output = nn.dropout(output, rate=dropout_rate, deterministic=deterministic)
   return output
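This snippet, like most on this page, uses the pre-Linen flax.nn API, where layers are applied functionally as nn.Dense(x, features). As a rough orientation for readers on current Flax, here is an approximate flax.linen equivalent of the block above; the class and attribute names are illustrative and the kernel/bias initializers are left at their defaults, so it is a sketch rather than a drop-in replacement.

from typing import Optional

import jax
import jax.numpy as jnp
import flax.linen as nn

class MlpBlock(nn.Module):
    mlp_dim: int
    out_dim: Optional[int] = None
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, inputs, *, deterministic=True):
        out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim
        x = nn.Dense(self.mlp_dim)(inputs)
        x = nn.gelu(x)
        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
        x = nn.Dense(out_dim)(x)
        return nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)

# Parameter initialization and a forward pass (deterministic, so no dropout RNG is needed).
params = MlpBlock(mlp_dim=256).init(jax.random.PRNGKey(0), jnp.ones((1, 4, 64)))
out = MlpBlock(mlp_dim=256).apply(params, jnp.ones((1, 4, 64)))
print(out.shape)  # (1, 4, 64)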
Example #3
 def apply(self, x, num_actions):
     initializer = nn.initializers.xavier_uniform()
     # We need to add a "batch dimension" as nn.Conv expects it, yet vmap will
     # have removed the true batch dimension.
     x = x[None, ...]
     x = x.astype(jnp.float32) / 255.
     x = nn.Conv(x,
                 features=32,
                 kernel_size=(8, 8),
                 strides=(4, 4),
                 kernel_init=initializer)
     x = jax.nn.relu(x)
     x = nn.Conv(x,
                 features=64,
                 kernel_size=(4, 4),
                 strides=(2, 2),
                 kernel_init=initializer)
     x = jax.nn.relu(x)
     x = nn.Conv(x,
                 features=64,
                 kernel_size=(3, 3),
                 strides=(1, 1),
                 kernel_init=initializer)
     x = jax.nn.relu(x)
     x = x.reshape((x.shape[0], -1))  # flatten
     x = nn.Dense(x, features=512, kernel_init=initializer)
     x = jax.nn.relu(x)
     q_values = nn.Dense(x, features=num_actions, kernel_init=initializer)
     return atari_lib.DQNNetworkType(q_values)
Example #4
 def apply(self,
           inputs,
           mlp_dim,
           dtype=jnp.float32,
           out_dim=None,
           dropout_rate=0.1,
           deterministic=False,
           kernel_init=nn.initializers.xavier_uniform(),
           bias_init=nn.initializers.normal(stddev=1e-6),
           num_partitions=2):
     """Applies Transformer MlpBlock module."""
     actual_out_dim = inputs.shape[-1] if out_dim is None else out_dim
     inputs_shape = inputs.shape
     inputs = inputs.reshape((-1, inputs_shape[-1]))
     x = nn.Dense(inputs,
                  mlp_dim,
                  dtype=dtype,
                  kernel_init=kernel_init,
                  bias_init=bias_init)
     x = nn.relu(x)
     if num_partitions > 1:
         x = with_sharding_constraint(x, P(1, num_partitions))
     x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
     output = nn.Dense(x,
                       actual_out_dim,
                       dtype=dtype,
                       kernel_init=kernel_init,
                       bias_init=bias_init)
     output = nn.dropout(output,
                         rate=dropout_rate,
                         deterministic=deterministic)
     output = output.reshape(inputs_shape[:-1] + (actual_out_dim, ))
     return output
Example #5
 def apply(self, x):
     x = nn.Dense(x, features=32)
     x = nn.sigmoid(x)
     x = nn.Dense(x, features=32)
     x = nn.sigmoid(x)
     x = nn.Dense(x, features=1)
     return nn.sigmoid(x)
Example #6
  def apply(self, x, rep_size, m_layers, m_features, m_kernel_sizes, conv_rep_size, padding_mask=None):
        
    H_0 = nn.relu(nn.Dense(x, conv_rep_size))
    G_0 = nn.relu(nn.Dense(x, conv_rep_size))
    H, G = jnp.expand_dims(H_0, axis=2), jnp.expand_dims(G_0, axis=2)

    for layer in range(1, m_layers+1):
      
      if layer < m_layers:
        H_features, G_features = m_features[layer-1]
      else:
        H_features, G_features = conv_rep_size, conv_rep_size
      
      H_kernel_size, G_kernel_size = m_kernel_sizes[layer-1]

      H = nn.Conv(H, features=H_features, kernel_size=(H_kernel_size, 1))
      G = nn.Conv(G, features=G_features, kernel_size=(G_kernel_size, 1)) 

      if layer < m_layers:
        H = nn.relu(H)
        G = nn.relu(G)
      else:
        H = nn.tanh(H)
        G = nn.sigmoid(G)

    H, G = jnp.squeeze(H, axis=2), jnp.squeeze(G, axis=2)
    
    F = H * G + G_0
    
    rep = linear_max_pool(F, padding_mask=padding_mask, rep_size=rep_size)
    
    return rep
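For readers skimming the block above: the output combines a tanh "content" path with a sigmoid "gate" and adds the first projection back as a skip connection. A toy, shapes-only sketch of that final combination (values are placeholders):

import jax
import jax.numpy as jnp

G_0 = jnp.ones((2, 7, 16))                 # first projection, reused as the skip connection
H = jnp.tanh(jnp.zeros((2, 7, 16)))        # content path after the conv stack
G = jax.nn.sigmoid(jnp.zeros((2, 7, 16)))  # gate path after the conv stack
F = H * G + G_0                            # gated features plus residual
print(F.shape)                             # (2, 7, 16)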
Example #7
    def apply(self, s, layers=[10], bias=False, actFun=[jax.nn.elu]):

        for l in range(len(actFun), len(layers) + 1):
            actFun.append(actFun[-1])

        s = 2 * s - 1
        for l, fun in zip(layers, actFun[:-1]):
            s = fun(
                nn.Dense(s,
                         features=l,
                         bias=bias,
                         dtype=global_defs.tReal,
                         kernel_init=jax.nn.initializers.lecun_normal(
                             dtype=global_defs.tReal),
                         bias_init=partial(jax.nn.initializers.zeros,
                                           dtype=global_defs.tReal)))

        return jnp.sum(actFun[-1](nn.Dense(
            s,
            features=1,
            bias=bias,
            dtype=global_defs.tReal,
            kernel_init=jax.nn.initializers.lecun_normal(
                dtype=global_defs.tReal),
            bias_init=partial(jax.nn.initializers.zeros,
                              dtype=global_defs.tReal))))
Example #8
def classifier_head(encoded, num_classes, mlp_dim, pooling_mode='MEAN'):
  """Classifier head.

  We put this here just so that all models consistently call the same function.

  Args:
    encoded: tensor inputs are shape of [bs, len, dim].
    num_classes: int, number of classes
    mlp_dim: int, dim of intermediate MLP.
    pooling_mode: str, string dictating pooling op {MEAN}

  Returns:
    tensor of shape [bs, num_classes]

  """
  if pooling_mode == 'MEAN':
    encoded = jnp.mean(encoded, axis=1)
  elif pooling_mode == 'SUM':
    encoded = jnp.sum(encoded, axis=1)
  elif pooling_mode == 'FLATTEN':
    encoded = encoded.reshape((encoded.shape[0], -1))
  elif pooling_mode == 'CLS':
    encoded = encoded[:, 0]
  else:
    raise NotImplementedError('Pooling not supported yet.')
  encoded = nn.Dense(encoded, mlp_dim, name='mlp')
  encoded = nn.relu(encoded)
  encoded = nn.Dense(encoded, num_classes, name='logits')
  return encoded
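A quick, self-contained illustration of the pooling modes dispatched above, shown on a toy [bs, len, dim] tensor (shapes only):

import jax.numpy as jnp

encoded = jnp.ones((2, 5, 8))                          # [bs, len, dim]
print(jnp.mean(encoded, axis=1).shape)                 # MEAN    -> (2, 8)
print(jnp.sum(encoded, axis=1).shape)                  # SUM     -> (2, 8)
print(encoded.reshape((encoded.shape[0], -1)).shape)   # FLATTEN -> (2, 40)
print(encoded[:, 0].shape)                             # CLS     -> (2, 8)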
Example #9
    def apply(self, x, num_actions):
        initializer = nn.initializers.xavier_uniform()
        # We need to add a "batch dimension" as nn.Conv expects it, yet vmap will
        # have removed the true batch dimension.
        x = x[None, ...]
        x = x.astype(jnp.float32)
        x = x.reshape((x.shape[0], -1))  # flatten
        #x -= gym_lib.CARTPOLE_MIN_VALS
        #x /= gym_lib.CARTPOLE_MAX_VALS - gym_lib.CARTPOLE_MIN_VALS
        #x = 2.0 * x - 1.0  # Rescale in range [-1, 1].
        x = nn.Dense(x, features=512, kernel_init=initializer)
        x = jax.nn.relu(x)
        x = nn.Dense(x, features=512, kernel_init=initializer)
        x = jax.nn.relu(x)
        print('x', x.shape, len(x))

        adv = nn.Dense(x, features=num_actions, kernel_init=initializer)
        val = nn.Dense(x, features=1, kernel_init=initializer)

        #q_values = nn.Dense(x, features=num_actions, kernel_init=initializer)
        # https://jax.readthedocs.io/en/latest/_modules/jax/nn/functions.html (JAX Mean)

        #q_values = val + (adv - (jnp.mean(adv, 1, keepdims=True)))
        q_values = val + (adv - (jnp.mean(adv, -1, keepdims=True)))
        return atari_lib.DQNNetworkType(q_values)
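The aggregation at the end is the standard dueling form Q = V + (A - mean(A)): subtracting the mean advantage makes the value/advantage split identifiable while leaving the greedy action unchanged. A tiny numeric sketch with toy values (not taken from the example):

import jax.numpy as jnp

adv = jnp.array([[1.0, 2.0, 4.0]])   # advantage stream, shape (batch, num_actions)
val = jnp.array([[10.0]])            # value stream, shape (batch, 1)
q = val + (adv - jnp.mean(adv, axis=-1, keepdims=True))
print(q)                                         # approx. [[ 8.67  9.67 11.67]]
print(jnp.argmax(q, -1) == jnp.argmax(adv, -1))  # greedy action unchanged: [ True]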
Example #10
 def apply(self, x, action_dim, max_action):
     x = nn.Dense(x, features=256)
     x = nn.relu(x)
     x = nn.Dense(x, features=256)
     x = nn.relu(x)
     x = nn.Dense(x, features=action_dim)
     return max_action * nn.tanh(x)
Example #11
 def apply(
         self,
         hidden_states,
         *,
         d_ff: int,
         dropout_rate: float = 0.0,
         intermediate_activation=nn.gelu,
         # TODO(kitaev): chunk_size hparam for chunking
         kernel_init=nn.initializers.xavier_uniform(),
         deterministic: bool = False):
     """Applies FeedForward module."""
     d_model = hidden_states.shape[-1]
     hidden_states = nn.Dense(hidden_states,
                              d_ff,
                              kernel_init=kernel_init,
                              name='intermediate')
     hidden_states = intermediate_activation(hidden_states)
     hidden_states = nn.Dense(hidden_states,
                              d_model,
                              kernel_init=kernel_init,
                              name='output')
     hidden_states = nn.dropout(hidden_states,
                                rate=dropout_rate,
                                deterministic=deterministic)
     return hidden_states
Example #12
 def apply(self, x, num_actions, num_atoms):
     initializer = jax.nn.initializers.variance_scaling(
         scale=1.0 / jnp.sqrt(3.0), mode='fan_in', distribution='uniform')
     # We need to add a "batch dimension" as nn.Conv expects it, yet vmap will
     # have removed the true batch dimension.
     x = x[None, ...]
     x = x.astype(jnp.float32) / 255.
     x = nn.Conv(x,
                 features=32,
                 kernel_size=(8, 8),
                 strides=(4, 4),
                 kernel_init=initializer)
     x = jax.nn.relu(x)
     x = nn.Conv(x,
                 features=64,
                 kernel_size=(4, 4),
                 strides=(2, 2),
                 kernel_init=initializer)
     x = jax.nn.relu(x)
     x = nn.Conv(x,
                 features=64,
                 kernel_size=(3, 3),
                 strides=(1, 1),
                 kernel_init=initializer)
     x = jax.nn.relu(x)
     x = x.reshape((x.shape[0], -1))  # flatten
     x = nn.Dense(x, features=512, kernel_init=initializer)
     x = jax.nn.relu(x)
     x = nn.Dense(x,
                  features=num_actions * num_atoms,
                  kernel_init=initializer)
     logits = x.reshape((x.shape[0], num_actions, num_atoms))
     probabilities = nn.softmax(logits)
     q_values = jnp.mean(logits, axis=2)
     return atari_lib.RainbowNetworkType(q_values, logits, probabilities)
Example #13
 def apply(self, x, hidden_layers, hidden_dim, n_classes):
     x = jnp.reshape(x, (x.shape[0], -1))
     for layer in range(hidden_layers):
         x = nn.Dense(x, hidden_dim, name=f'fc{layer}')
         x = nn.relu(x)
     x = nn.Dense(x, n_classes, name=f'fc{hidden_layers}')
     preds = nn.log_softmax(x)
     return preds
Example #14
 def apply(self, x, reduction=16):
     num_channels = x.shape[-1]
     y = x.mean(axis=(1, 2))
     y = nn.Dense(y, features=num_channels // reduction, bias=False)
     y = nn.relu(y)
     y = nn.Dense(y, features=num_channels, bias=False)
     y = nn.sigmoid(y)
     return x * y[:, None, None, :]
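The block above is a squeeze-and-excitation gate: global-average-pool to one vector per channel, bottleneck it by reduction, and rescale the feature map channel-wise. A shapes-only walk-through under assumed toy sizes:

import jax.numpy as jnp

x = jnp.ones((8, 32, 32, 64))      # NHWC feature map, num_channels = 64
y = x.mean(axis=(1, 2))            # squeeze: (8, 64)
# Dense(64 // 16) + relu, then Dense(64) + sigmoid -> per-channel gate of shape (8, 64)
gate = jnp.full((8, 64), 0.5)      # stand-in for the learned gate
out = x * gate[:, None, None, :]   # excitation: broadcast back over H and W
print(out.shape)                   # (8, 32, 32, 64)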
Example #15
  def apply(self,
            x,
            num_classes=1000,
            train=False,
            resnet=None,
            patches=None,
            hidden_size=None,
            transformer=None,
            representation_size=None,
            classifier='gap'):

    n, h, w, c = x.shape

    # Embed the grid or patches of the grid.
    fh, fw = patches.size
    gh, gw = h // fh, w // fw
    if hidden_size:  # We can merge s2d+emb into a single conv; it's the same.
      x = nn.Conv(
          x,
          hidden_size, (fh, fw),
          strides=(fh, fw),
          padding='VALID',
          name='embedding')
    else:
      # This path often results in excessive padding.
      x = jnp.reshape(x, [n, gh, fh, gw, fw, c])
      x = jnp.transpose(x, [0, 1, 3, 2, 4, 5])
      x = jnp.reshape(x, [n, gh, gw, -1])

    # Here, x is a grid of embeddings.

    # (Possibly partial) Transformer.
    if transformer is not None:
      n, h, w, c = x.shape
      x = jnp.reshape(x, [n, h * w, c])

      # If we want to add a class token, add it here.
      if classifier == 'token':
        cls = self.param('cls', (1, 1, c), nn.initializers.zeros)
        cls = jnp.tile(cls, [n, 1, 1])
        x = jnp.concatenate([cls, x], axis=1)

      x = Encoder(x, train=train, name='Transformer', **transformer)

    if classifier == 'token':
      x = x[:, 0]
    elif classifier == 'gap':
      x = jnp.mean(x, axis=list(range(1, x.ndim - 1)))  # (1,) or (1,2)

    if representation_size is not None:
      x = nn.Dense(x, representation_size, name='pre_logits')
      x = nn.tanh(x)
    else:
      x = IdentityLayer(x, name='pre_logits')

    x = nn.Dense(x, num_classes, name='head', kernel_init=nn.initializers.zeros)
    return x
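In the else-branch above, patch extraction is done with a reshape/transpose ("space-to-depth") instead of the strided convolution; both produce an [n, gh, gw, fh*fw*c] grid, which the Conv branch additionally projects to hidden_size. A small sketch of the reshape path with toy sizes:

import jax.numpy as jnp

n, h, w, c, fh, fw = 1, 8, 8, 3, 4, 4
gh, gw = h // fh, w // fw
x = jnp.arange(n * h * w * c, dtype=jnp.float32).reshape(n, h, w, c)
patches = jnp.reshape(x, [n, gh, fh, gw, fw, c])
patches = jnp.transpose(patches, [0, 1, 3, 2, 4, 5])
patches = jnp.reshape(patches, [n, gh, gw, -1])
print(patches.shape)  # (1, 2, 2, 48) == (n, gh, gw, fh * fw * c)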
Example #16
    def apply(self, x):
        real = nn.Dense(x, 25)
        real = jnp.sin(real)
        real = nn.Dense(real, 2)

        imag = nn.Dense(x, 25)
        imag = jnp.sin(imag)
        imag = nn.Dense(imag, 2)
        imag = jnp.pi * nn.soft_sign(imag)
        return real * jnp.exp(1j * imag)
Example #17
 def apply(self, x):
     net = nn.Dense(x, 500, name='fc1')
     net = nn.leaky_relu(net)
     net = nn.BatchNorm(net)
     net = nn.Dense(net, 500, name='fc2')
     net = nn.leaky_relu(net)
     net = nn.BatchNorm(net)
     net = nn.Dense(net, 500, name='fc3')
     net = nn.leaky_relu(net)
     net = nn.BatchNorm(net)
     return nn.softmax(nn.Dense(net, n_bin))
Example #18
def classifier(x, num_outputs, dropout_rate, deterministic):
    """Implements the classification portion of the network."""

    x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
    x = nn.Dense(x, 512)
    x = nn.relu(x)
    x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
    x = nn.Dense(x, 512)
    x = nn.relu(x)
    x = nn.Dense(x, num_outputs)
    return x
Example #19
 def apply(self, x, num_actions):
     initializer = nn.initializers.xavier_uniform()
     x = x[None, ...]
     x = x.astype(jnp.float32)
     x = x.reshape((x.shape[0], -1))  # flatten
     x = nn.Dense(x, features=512, kernel_init=initializer)
     x = jax.nn.relu(x)
     x = nn.Dense(x, features=512, kernel_init=initializer)
     x = jax.nn.relu(x)
     q_values = nn.Dense(x, features=num_actions, kernel_init=initializer)
     return atari_lib.DQNNetworkType(q_values)
Example #20
    def apply(self,
              x,
              L=10,
              hiddenSize=10,
              inputDim=2,
              actFun=nn.elu,
              initScale=1.0,
              logProbFactor=0.5):

        rnnCell = RNNCell.shared(hiddenSize=hiddenSize,
                                 outDim=hiddenSize,
                                 actFun=actFun,
                                 initScale=initScale,
                                 name="myCell")

        probDense = nn.Dense.shared(
            features=inputDim,
            name="probDense",
            dtype=global_defs.tReal,
            kernel_init=jax.nn.initializers.lecun_normal(
                dtype=global_defs.tReal),
            bias_init=partial(jax.nn.initializers.zeros,
                              dtype=global_defs.tReal))

        state = jnp.zeros((hiddenSize, ))

        def rnn_cell(carry, x):
            newCarry, out = rnnCell(carry[0], carry[1])
            logProb = nn.log_softmax(actFun(probDense(out)))
            logProb = jnp.sum(logProb * x, axis=-1)
            return (newCarry, x), (jnp.nan_to_num(logProb, nan=-35), out)

        _, (probs, phaseOut) = jax.lax.scan(rnn_cell,
                                            (state, jnp.zeros(inputDim)),
                                            jax.nn.one_hot(x, inputDim))

        phase = nn.Dense(phaseOut,
                         features=6,
                         dtype=global_defs.tReal,
                         kernel_init=jax.nn.initializers.lecun_normal(
                             dtype=global_defs.tReal),
                         bias_init=partial(jax.nn.initializers.zeros,
                                           dtype=global_defs.tReal))
        phase = actFun(phase)
        phase = nn.Dense(phaseOut,
                         features=4,
                         dtype=global_defs.tReal,
                         kernel_init=jax.nn.initializers.lecun_normal(
                             dtype=global_defs.tReal),
                         bias_init=partial(jax.nn.initializers.zeros,
                                           dtype=global_defs.tReal))

        return logProbFactor * jnp.sum(probs, axis=0) + 1.j * jnp.mean(
            actFun(phase))
Example #21
 def apply(self, x):
     x = nn.Dense(x, features=50)
     x = nn.tanh(x)
     x = nn.Dense(x, features=50)
     x = nn.tanh(x)
     x = nn.Dense(x, features=50)
     x = nn.tanh(x)
     x = nn.Dense(x, features=50)
     x = nn.tanh(x)
     x = nn.Dense(x, features=1)
     return x
Example #22
 def apply(self, actions, num_layers, hidden_dims):
     timesteps = actions.shape[1]
     # flatten time into batch
     actions = jnp.reshape(actions, (-1, ) + actions.shape[2:])
     # embed actions
     x = nn.Dense(actions, hidden_dims)
     for _ in range(num_layers):
         x = nn.Dense(x, hidden_dims)
         x = nn.LayerNorm(x)
         x = nn.relu(x)
     x = nn.Dense(x, 1)
     x = jnp.reshape(x, (-1, timesteps, 1))
     return x
Example #23
 def apply(self, x):
     x = nn.Conv(x, features=32, kernel_size=(3, 3))
     x = nn.relu(x)
     x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
     x = nn.Conv(x, features=64, kernel_size=(3, 3))
     x = nn.relu(x)
     x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
     x = x.reshape((x.shape[0], -1))  # flatten
     x = nn.Dense(x, features=256)
     x = nn.relu(x)
     x = nn.Dense(x, features=10)
     x = nn.log_softmax(x)
     return x
Example #24
 def apply(self, x, num_actions):
     initializer = nn.initializers.xavier_uniform()
     x = x[None, ...]
     x = x.astype(jnp.float32)
     x = x.reshape((x.shape[0], -1))  # flatten
     x -= gym_lib.ACROBOT_MIN_VALS
     x /= gym_lib.ACROBOT_MAX_VALS - gym_lib.ACROBOT_MIN_VALS
     x = 2.0 * x - 1.0  # Rescale in range [-1, 1].
     x = nn.Dense(x, features=512, kernel_init=initializer)
     x = jax.nn.relu(x)
     x = nn.Dense(x, features=512, kernel_init=initializer)
     x = jax.nn.relu(x)
     q_values = nn.Dense(x, features=num_actions, kernel_init=initializer)
     return atari_lib.DQNNetworkType(q_values)
Example #25
 def apply(self,
           inputs: jnp.ndarray,
           hidden_size: int = None,
           output_size: int = None,
           output_bias: bool = False,
           dropout: float = None,
           train: bool = None):
     # inputs.shape = <float32>[batch_size, seq_length, hidden_size]
     hidden = nn.Dense(inputs, hidden_size, name='hidden')
     hidden = nn.tanh(hidden)
     if train:
         hidden = nn.dropout(hidden, rate=dropout)
     output = nn.Dense(hidden, output_size, bias=output_bias, name='output')
     return output
Example #26
def classifier_head_dual(encoded1,
                         encoded2,
                         num_classes,
                         mlp_dim,
                         pooling_mode='MEAN',
                         interaction=None):
    """Classifier head for dual encoding or pairwise problem.

  We put this here just so that all models consistently call the same function.

  Args:
    encoded1: tensor inputs are shape of [bs, len, dim].
    encoded2: tensor inputs are shape of [bs, len, dim].
    num_classes: int, number of classes
    mlp_dim: int, dim of intermediate MLP.
    pooling_mode: str, string dictating pooling op {MEAN}
    interaction: str, string dictating interaction between e1, e2

  Returns:
    tensor of shape [bs, num_classes]

  """
    if pooling_mode == 'MEAN':
        encoded1 = jnp.mean(encoded1, axis=1)
        encoded2 = jnp.mean(encoded2, axis=1)
    elif pooling_mode == 'SUM':
        encoded1 = jnp.sum(encoded1, axis=1)
        encoded2 = jnp.sum(encoded2, axis=1)
    elif pooling_mode == 'FLATTEN':
        encoded1 = encoded1.reshape((encoded1.shape[0], -1))
        encoded2 = encoded2.reshape((encoded2.shape[0], -1))
    elif pooling_mode == 'CLS':
        encoded1 = encoded1[:, 0]
        encoded2 = encoded2[:, 0]
    else:
        raise NotImplementedError('Pooling not supported yet.')

    if interaction == 'NLI':
        # NLI interaction style
        encoded = jnp.concatenate(
            [encoded1, encoded2, encoded1 * encoded2, encoded1 - encoded2], 1)
    else:
        encoded = jnp.concatenate([encoded1, encoded2], 1)
    encoded = nn.Dense(encoded, mlp_dim, name='mlp')
    encoded = nn.relu(encoded)
    encoded = nn.Dense(encoded, int(mlp_dim // 2), name='mlp2')
    encoded = nn.relu(encoded)
    encoded = nn.Dense(encoded, num_classes, name='logits')
    return encoded
Example #27
 def apply(self, x, num_actions, quantile_embedding_dim, num_quantiles,
           rng):
     initializer = jax.nn.initializers.variance_scaling(
         scale=1.0 / jnp.sqrt(3.0), mode='fan_in', distribution='uniform')
     # We need to add a "batch dimension" as nn.Conv expects it, yet vmap will
     # have removed the true batch dimension.
     x = x[None, ...]
     x = x.astype(jnp.float32) / 255.
     x = nn.Conv(x,
                 features=32,
                 kernel_size=(8, 8),
                 strides=(4, 4),
                 kernel_init=initializer)
     x = jax.nn.relu(x)
     x = nn.Conv(x,
                 features=64,
                 kernel_size=(4, 4),
                 strides=(2, 2),
                 kernel_init=initializer)
     x = jax.nn.relu(x)
     x = nn.Conv(x,
                 features=64,
                 kernel_size=(3, 3),
                 strides=(1, 1),
                 kernel_init=initializer)
     x = jax.nn.relu(x)
     x = x.reshape((x.shape[0], -1))  # flatten
     state_vector_length = x.shape[-1]
     state_net_tiled = jnp.tile(x, [num_quantiles, 1])
     quantiles_shape = [num_quantiles, 1]
     quantiles = jax.random.uniform(rng, shape=quantiles_shape)
     quantile_net = jnp.tile(quantiles, [1, quantile_embedding_dim])
     quantile_net = (
         jnp.arange(1, quantile_embedding_dim + 1, 1).astype(jnp.float32) *
         onp.pi * quantile_net)
     quantile_net = jnp.cos(quantile_net)
     quantile_net = nn.Dense(quantile_net,
                             features=state_vector_length,
                             kernel_init=initializer)
     quantile_net = jax.nn.relu(quantile_net)
     x = state_net_tiled * quantile_net
     x = nn.Dense(x, features=512, kernel_init=initializer)
     x = jax.nn.relu(x)
     quantile_values = nn.Dense(x,
                                features=num_actions,
                                kernel_init=initializer)
     return atari_lib.ImplicitQuantileNetworkType(quantile_values,
                                                  quantiles)
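The quantile embedding above follows the IQN recipe: sample quantile fractions tau, embed each as cos(pi * i * tau) for i = 1..quantile_embedding_dim, pass that through a Dense + ReLU, and multiply it element-wise with the tiled state features. The cosine embedding in isolation (Dense layer omitted, toy sizes assumed):

import jax
import jax.numpy as jnp

num_quantiles, quantile_embedding_dim = 32, 64
tau = jax.random.uniform(jax.random.PRNGKey(0), (num_quantiles, 1))  # sampled fractions
i = jnp.arange(1, quantile_embedding_dim + 1, dtype=jnp.float32)     # 1..embedding_dim
phi = jnp.cos(jnp.pi * i * tau)    # broadcasts to (num_quantiles, quantile_embedding_dim)
print(phi.shape)                   # (32, 64)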
Example #28
    def apply(self,
              x,
              act,
              normalize,
              temb=None,
              out_ch=None,
              conv_shortcut=False,
              dropout=0.1,
              train=True,
              skip_rescale=False,
              init_scale=0.):
        B, H, W, C = x.shape
        out_ch = out_ch if out_ch else C
        h = act(normalize(x, num_groups=min(x.shape[-1] // 4, 32)))
        h = conv3x3(h, out_ch)
        # Add bias to each feature map conditioned on the time embedding
        if temb is not None:
            h += nn.Dense(act(temb), out_ch,
                          kernel_init=default_init())[:, None, None, :]

        h = act(normalize(h, num_groups=min(h.shape[-1] // 4, 32)))
        h = nn.dropout(h, dropout, deterministic=not train)
        h = conv3x3(h, out_ch, init_scale=init_scale)
        if C != out_ch:
            if conv_shortcut:
                x = conv3x3(x, out_ch)
            else:
                x = NIN(x, out_ch)

        if not skip_rescale:
            return x + h
        else:
            return (x + h) / np.sqrt(2.)
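The [:, None, None, :] indexing above turns the projected time embedding into a per-channel bias that is broadcast over the spatial dimensions. A minimal sketch with placeholder shapes:

import jax.numpy as jnp

h = jnp.zeros((4, 16, 16, 128))      # feature map, NHWC
temb_proj = jnp.ones((4, 128))       # stand-in for nn.Dense(act(temb), out_ch)
h = h + temb_proj[:, None, None, :]  # same bias added at every spatial location
print(h.shape)                       # (4, 16, 16, 128)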
Example #29
 def apply(self, x, num_actions):
     initializer = nn.initializers.xavier_uniform()
     # We need to add a "batch dimension" as nn.Conv expects it, yet vmap will
     # have removed the true batch dimension.
     x = x[None, ...]
     x = x.astype(jnp.float32)
     x = x.reshape((x.shape[0], -1))  # flatten
     x -= gym_lib.CARTPOLE_MIN_VALS
     x /= gym_lib.CARTPOLE_MAX_VALS - gym_lib.CARTPOLE_MIN_VALS
     x = 2.0 * x - 1.0  # Rescale in range [-1, 1].
     x = nn.Dense(x, features=512, kernel_init=initializer)
     x = jax.nn.relu(x)
     x = nn.Dense(x, features=512, kernel_init=initializer)
     x = jax.nn.relu(x)
     q_values = nn.Dense(x, features=num_actions, kernel_init=initializer)
     return atari_lib.DQNNetworkType(q_values)
Example #30
    def apply(self,
              x,
              blocks_per_group,
              channel_multiplier,
              num_outputs,
              dropout_rate=0.0,
              train=True):

        x = nn.Conv(x, 16, (3, 3), padding='SAME', name='init_conv')
        x = WideResnetGroup(x,
                            blocks_per_group,
                            16 * channel_multiplier,
                            dropout_rate=dropout_rate,
                            train=train)
        x = WideResnetGroup(x,
                            blocks_per_group,
                            32 * channel_multiplier, (2, 2),
                            dropout_rate=dropout_rate,
                            train=train)
        x = WideResnetGroup(x,
                            blocks_per_group,
                            64 * channel_multiplier, (2, 2),
                            dropout_rate=dropout_rate,
                            train=train)
        x = nn.BatchNorm(x,
                         use_running_average=not train,
                         momentum=0.9,
                         epsilon=1e-5)
        x = jax.nn.relu(x)
        x = nn.avg_pool(x, (8, 8))
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(x, num_outputs)
        return x