Example #1
    def apply(self, x, *, stride, filters, train):
        norm_layer = nn.BatchNorm.partial(use_running_average=not train,
                                          momentum=0.9,
                                          epsilon=1e-5)
        conv3x3 = nn.Conv.partial(kernel_size=(3, 3),
                                  padding="SAME",
                                  bias=False)
        conv1x1 = nn.Conv.partial(kernel_size=(1, 1),
                                  padding="SAME",
                                  bias=False)

        x = norm_layer(x)
        x = nn.relu(x)
        identity = x
        needs_projection = x.shape[-1] != filters or stride != (1, 1)
        if needs_projection:
            identity = conv1x1(x, features=filters, strides=stride)

        x = conv3x3(x, features=filters, strides=stride)
        x = norm_layer(x)
        x = nn.relu(x)
        x = conv3x3(x, features=filters, strides=(1, 1))

        x += identity
        return x
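A minimal shape sketch (pure jax.numpy, values hypothetical) of the shortcut rule above: the identity path must match the main path in both channel count and spatial stride, so the 1x1 projection is inserted only when either differs.

import jax.numpy as jnp

x = jnp.zeros((8, 32, 32, 64))   # NHWC features with 64 channels
filters, stride = 128, (2, 2)    # hypothetical block configuration
needs_projection = x.shape[-1] != filters or stride != (1, 1)
print(needs_projection)          # True: both channel count and stride change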
Example #2
 def apply(self, x, action_dim, max_action):
     x = nn.Dense(x, features=256)
     x = nn.relu(x)
     x = nn.Dense(x, features=256)
     x = nn.relu(x)
     x = nn.Dense(x, features=action_dim)
     return max_action * nn.tanh(x)
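A quick numeric sketch of the output squashing above: tanh maps unbounded activations into (-1, 1), so scaling by max_action bounds the policy output to the environment's action range.

import jax.numpy as jnp

raw = jnp.array([-10.0, 0.0, 10.0])
max_action = 2.0
print(max_action * jnp.tanh(raw))  # roughly [-2., 0., 2.]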
Example #3
    def apply(self, x, filters, strides=(1, 1), train=True, dtype=jnp.float32):
        needs_projection = x.shape[-1] != filters * 4 or strides != (1, 1)
        batch_norm = nn.BatchNorm.partial(use_running_average=not train,
                                          momentum=0.9,
                                          epsilon=1e-5,
                                          dtype=dtype)
        conv = nn.Conv.partial(bias=False, dtype=dtype)

        residual = x
        if needs_projection:
            residual = conv(residual,
                            filters * 4, (1, 1),
                            strides,
                            name='proj_conv')
            residual = batch_norm(residual, name='proj_bn')

        y = conv(x, filters, (1, 1), name='conv1')
        y = batch_norm(y, name='bn1')
        y = nn.relu(y)
        y = conv(y, filters, (3, 3), strides, name='conv2')
        y = batch_norm(y, name='bn2')
        y = nn.relu(y)
        y = conv(y, filters * 4, (1, 1), name='conv3')

        y = batch_norm(y, name='bn3', scale_init=nn.initializers.zeros)
        y = nn.relu(residual + y)
        return y
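A numeric sketch of the zero-init trick above (assumed behavior): with bn3's scale initialized to zero, the main path starts near zero, so at initialization the block acts like an identity mapping followed by ReLU.

import jax
import jax.numpy as jnp

residual = jnp.array([1.0, -2.0, 3.0])
y = jnp.zeros_like(residual)        # main path output at initialization
print(jax.nn.relu(residual + y))    # [1., 0., 3.]: just the rectified shortcut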
Example #4
  def apply(self, x, rep_size, m_layers, m_features, m_kernel_sizes, conv_rep_size, padding_mask=None):
        
    H_0 = nn.relu(nn.Dense(x, conv_rep_size))
    G_0 = nn.relu(nn.Dense(x, conv_rep_size))
    H, G = jnp.expand_dims(H_0, axis=2), jnp.expand_dims(G_0, axis=2)

    for layer in range(1, m_layers+1):
      
      if layer < m_layers:
        H_features, G_features = m_features[layer-1]
      else:
        H_features, G_features = conv_rep_size, conv_rep_size
      
      H_kernel_size, G_kernel_size = m_kernel_sizes[layer-1]

      H = nn.Conv(H, features=H_features, kernel_size=(H_kernel_size, 1))
      G = nn.Conv(G, features=G_features, kernel_size=(G_kernel_size, 1)) 

      if layer < m_layers:
        H = nn.relu(H)
        G = nn.relu(G)
      else:
        H = nn.tanh(H)
        G = nn.sigmoid(G)

    H, G = jnp.squeeze(H, axis=2), jnp.squeeze(G, axis=2)
    
    F = H * G + G_0
    
    rep = linear_max_pool(F, padding_mask=padding_mask, rep_size=rep_size)
    
    return rep
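A minimal sketch of the gating arithmetic above (values hypothetical): a tanh content path is modulated elementwise by a sigmoid gate in (0, 1), and the input projection G_0 is added back as a residual.

import jax
import jax.numpy as jnp

H = jnp.tanh(jnp.array([1.0, -1.0]))          # content in (-1, 1)
G = jax.nn.sigmoid(jnp.array([4.0, -4.0]))    # gates near 1 and near 0
G_0 = jnp.array([0.1, 0.1])                   # residual projection
print(H * G + G_0)                            # first unit passes, second is suppressed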
Example #5
    def apply(self, x, nout, strides=(1, 1), bottleneck=True):
        features = nout
        nout = nout * 4 if bottleneck else nout
        needs_projection = x.shape[-1] != nout or strides != (1, 1)
        residual = x
        if needs_projection:
            residual = StdConv(residual,
                               nout, (1, 1),
                               strides,
                               bias=False,
                               name="conv_proj")
            residual = nn.GroupNorm(residual, epsilon=1e-4, name="gn_proj")

        if bottleneck:
            x = StdConv(x, features, (1, 1), bias=False, name="conv1")
            x = nn.GroupNorm(x, epsilon=1e-4, name="gn1")
            x = nn.relu(x)

        x = StdConv(x, features, (3, 3), strides, bias=False, name="conv2")
        x = nn.GroupNorm(x, epsilon=1e-4, name="gn2")
        x = nn.relu(x)

        last_kernel = (1, 1) if bottleneck else (3, 3)
        x = StdConv(x, nout, last_kernel, bias=False, name="conv3")
        x = nn.GroupNorm(x,
                         epsilon=1e-4,
                         name="gn3",
                         scale_init=nn.initializers.zeros)
        x = nn.relu(residual + x)

        return x
Example #6
def classifier(x, num_outputs, dropout_rate, deterministic):
    """Implements the classification portion of the network."""

    x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
    x = nn.Dense(x, 512)
    x = nn.relu(x)
    x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
    x = nn.Dense(x, 512)
    x = nn.relu(x)
    x = nn.Dense(x, num_outputs)
    return x
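A minimal dropout sketch in pure JAX, under the assumed semantics of nn.dropout: with deterministic=True the input passes through unchanged; otherwise units are zeroed with probability rate and survivors are rescaled by 1/(1-rate).

import jax
import jax.numpy as jnp

def dropout(x, rate, deterministic, rng):
    if deterministic or rate == 0.0:
        return x
    keep = jax.random.bernoulli(rng, 1.0 - rate, x.shape)  # keep mask
    return jnp.where(keep, x / (1.0 - rate), 0.0)          # inverted scaling

x = jnp.ones((2, 4))
print(dropout(x, 0.5, True, jax.random.PRNGKey(0)))  # eval: unchanged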
Example #7
 def apply(self, x, depth, features, kernel_size, use_one_hot):
     if use_one_hot:
         x = one_hot(x)
     x = MaskedConv1d(x, features, (kernel_size, ), is_first_layer=True)
     x = nn.relu(x)
     for _ in range(depth - 2):
         x = MaskedConv1d(x, features, (kernel_size, ))
         x = nn.relu(x)
     x = MaskedConv1d(x, 4, (kernel_size, ))
     x = real_to_complex(x)
     return x
Example #8
 def apply(self, x):
     x = nn.Conv(x, features=32, kernel_size=(3, 3))
     x = nn.relu(x)
     x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
     x = nn.Conv(x, features=64, kernel_size=(3, 3))
     x = nn.relu(x)
     x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
     x = x.reshape((x.shape[0], -1))  # flatten
     x = nn.Dense(x, features=256)
     x = nn.relu(x)
     x = nn.Dense(x, features=10)
     x = nn.log_softmax(x)
     return x
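A quick check of the log_softmax head above: it returns log-probabilities, so exponentiating recovers a distribution that sums to one.

import jax
import jax.numpy as jnp

logits = jnp.array([2.0, 0.0, -2.0])
log_probs = jax.nn.log_softmax(logits)
print(jnp.exp(log_probs).sum())  # ~1.0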
Example #9
def classifier_head_dual(encoded1,
                         encoded2,
                         num_classes,
                         mlp_dim,
                         pooling_mode='MEAN',
                         interaction=None):
    """Classifier head for dual encoding or pairwise problem.

  We put this here just so that all models consistently call the same function.

  Args:
    encoded1: tensor input of shape [bs, len, dim].
    encoded2: tensor input of shape [bs, len, dim].
    num_classes: int, number of classes.
    mlp_dim: int, dim of intermediate MLP.
    pooling_mode: str, pooling op, one of {MEAN, SUM, FLATTEN, CLS}.
    interaction: str, interaction between encoded1 and encoded2 (e.g. NLI).

  Returns:
    tensor of shape [bs, num_classes]

  """
    if pooling_mode == 'MEAN':
        encoded1 = jnp.mean(encoded1, axis=1)
        encoded2 = jnp.mean(encoded2, axis=1)
    elif pooling_mode == 'SUM':
        encoded1 = jnp.sum(encoded1, axis=1)
        encoded2 = jnp.sum(encoded2, axis=1)
    elif pooling_mode == 'FLATTEN':
        encoded1 = encoded1.reshape((encoded1.shape[0], -1))
        encoded2 = encoded2.reshape((encoded2.shape[0], -1))
    elif pooling_mode == 'CLS':
        encoded1 = encoded1[:, 0]
        encoded2 = encoded2[:, 0]
    else:
        raise NotImplementedError('Pooling not supported yet.')

    if interaction == 'NLI':
        # NLI interaction style
        encoded = jnp.concatenate(
            [encoded1, encoded2, encoded1 * encoded2, encoded1 - encoded2], 1)
    else:
        encoded = jnp.concatenate([encoded1, encoded2], 1)
    encoded = nn.Dense(encoded, mlp_dim, name='mlp')
    encoded = nn.relu(encoded)
    encoded = nn.Dense(encoded, int(mlp_dim // 2), name='mlp2')
    encoded = nn.relu(encoded)
    encoded = nn.Dense(encoded, num_classes, name='logits')
    return encoded
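A shape sketch of the NLI-style interaction above: concatenating the two encodings with their elementwise product and difference yields a feature vector four times the encoder width.

import jax.numpy as jnp

e1 = jnp.ones((2, 8))
e2 = 2.0 * jnp.ones((2, 8))
nli = jnp.concatenate([e1, e2, e1 * e2, e1 - e2], axis=1)
print(nli.shape)  # (2, 32): four dim-8 views per pair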
Example #10
 def apply(self, x, use_squeeze_excite=False):
     x = nn.Conv(x, features=8, kernel_size=(3, 3), padding="VALID")
     x = nn.relu(x)
     x = nn.Conv(x, features=16, kernel_size=(3, 3), padding="VALID")
     x = nn.relu(x)
     if use_squeeze_excite:
         x = SqueezeExciteLayer(x)
     x = nn.Conv(x, features=32, kernel_size=(3, 3), padding="VALID")
     x = nn.relu(x)
     if use_squeeze_excite:
         x = SqueezeExciteLayer(x)
     x = nn.Conv(x, features=1, kernel_size=(3, 3), padding="VALID")
     scores = nn.max_pool(x, window_shape=(8, 8), strides=(8, 8))[Ellipsis, 0]
     return scores
Example #11
    def apply(
        self,
        x,
        action_dim,
        max_action,
        key=None,
        MPO=False,
        sample=False,
        log_sig_min=-20,
        log_sig_max=2,
    ):
        x = nn.Dense(x, features=200)
        x = nn.LayerNorm(x)
        x = nn.tanh(x)
        x = nn.Dense(x, features=200)
        x = nn.elu(x)
        x = nn.Dense(x, features=2 * action_dim)

        mu, log_sig = jnp.split(x, 2, axis=-1)
        log_sig = nn.softplus(log_sig)
        log_sig = jnp.clip(log_sig, log_sig_min, log_sig_max)

        if MPO:
            return mu, log_sig

        if not sample:
            return max_action * nn.tanh(mu), log_sig
        else:
            pi = mu + random.normal(key, mu.shape) * jnp.exp(log_sig)
            log_pi = gaussian_likelihood(pi, mu, log_sig)
            pi = nn.tanh(pi)
            log_pi -= jnp.sum(jnp.log(nn.relu(1 - pi ** 2) + 1e-6), axis=1)
            return max_action * pi, log_pi
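A numeric sketch of the tanh change-of-variables term used in the sampling branch above: squashing a Gaussian sample u through tanh subtracts sum(log(1 - tanh(u)^2)) from its log-density, with 1e-6 guarding against log(0).

import jax.numpy as jnp

u = jnp.array([[0.5, -1.0]])  # hypothetical pre-squash samples
correction = jnp.sum(jnp.log(1.0 - jnp.tanh(u) ** 2 + 1e-6), axis=1)
print(correction)  # negative, since 1 - tanh(u)**2 < 1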
Example #12
 def apply(self,
           inputs,
           mlp_dim,
           dtype=jnp.float32,
           out_dim=None,
           dropout_rate=0.1,
           deterministic=False,
           kernel_init=nn.initializers.xavier_uniform(),
           bias_init=nn.initializers.normal(stddev=1e-6)):
     """Applies Transformer MlpBlock module."""
     actual_out_dim = inputs.shape[-1] if out_dim is None else out_dim
     x = nn.Dense(inputs,
                  mlp_dim,
                  dtype=dtype,
                  kernel_init=kernel_init,
                  bias_init=bias_init)
     x = nn.relu(x)
     x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
     output = nn.Dense(x,
                       actual_out_dim,
                       dtype=dtype,
                       kernel_init=kernel_init,
                       bias_init=bias_init)
     output = nn.dropout(output,
                         rate=dropout_rate,
                         deterministic=deterministic)
     return output
Example #13
    def apply(self, x, num_classes=1000, width_factor=1, num_layers=50):
        block_sizes = _block_sizes[num_layers]

        width = 64 * width_factor

        root_block = RootBlock.partial(width=width)
        x = root_block(x, name='root_block')

        # Blocks
        for i, block_size in enumerate(block_sizes):
            x = ResidualBlock(x,
                              block_size,
                              width * 2**i,
                              first_stride=(1, 1) if i == 0 else (2, 2),
                              name=f"block{i + 1}")

        # Pre-head
        x = GroupNorm(x, name='norm-pre-head')
        x = nn.relu(x)
        x = jnp.mean(x, axis=(1, 2))

        # Head
        x = nn.Dense(x,
                     num_classes,
                     name="conv_head",
                     kernel_init=nn.initializers.zeros)

        return x.astype(jnp.float32)
Example #14
 def apply(self,
           inputs,
           mlp_dim,
           dtype=jnp.float32,
           out_dim=None,
           dropout_rate=0.1,
           deterministic=False,
           kernel_init=nn.initializers.xavier_uniform(),
           bias_init=nn.initializers.normal(stddev=1e-6),
           num_partitions=2):
     """Applies Transformer MlpBlock module."""
     actual_out_dim = inputs.shape[-1] if out_dim is None else out_dim
     inputs_shape = inputs.shape
     inputs = inputs.reshape((-1, inputs_shape[-1]))
     x = nn.Dense(inputs,
                  mlp_dim,
                  dtype=dtype,
                  kernel_init=kernel_init,
                  bias_init=bias_init)
     x = nn.relu(x)
     if num_partitions > 1:
         x = with_sharding_constraint(x, P(1, num_partitions))
     x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
     output = nn.Dense(x,
                       actual_out_dim,
                       dtype=dtype,
                       kernel_init=kernel_init,
                       bias_init=bias_init)
     output = nn.dropout(output,
                         rate=dropout_rate,
                         deterministic=deterministic)
     output = output.reshape(inputs_shape[:-1] + (actual_out_dim, ))
     return output
Example #15
def classifier_head(encoded, num_classes, mlp_dim, pooling_mode='MEAN'):
  """Classifier head.

  We put this here just so that all models consistently call the same function.

  Args:
    encoded: tensor input of shape [bs, len, dim].
    num_classes: int, number of classes.
    mlp_dim: int, dim of intermediate MLP.
    pooling_mode: str, pooling op, one of {MEAN, SUM, FLATTEN, CLS}.

  Returns:
    tensor of shape [bs, num_classes]

  """
  if pooling_mode == 'MEAN':
    encoded = jnp.mean(encoded, axis=1)
  elif pooling_mode == 'SUM':
    encoded = jnp.sum(encoded, axis=1)
  elif pooling_mode == 'FLATTEN':
    encoded = encoded.reshape((encoded.shape[0], -1))
  elif pooling_mode == 'CLS':
    encoded = encoded[:, 0]
  else:
    raise NotImplementedError('Pooling not supported yet.')
  encoded = nn.Dense(encoded, mlp_dim, name='mlp')
  encoded = nn.relu(encoded)
  encoded = nn.Dense(encoded, num_classes, name='logits')
  return encoded
Example #16
def apply_activation(intermediate_output, intermediate_activation):
    """Applies selected activation function to intermediate output."""
    if intermediate_activation is None:
        return intermediate_output

    if intermediate_activation == 'gelu':
        intermediate_output = nn.gelu(intermediate_output)
    elif intermediate_activation == 'relu':
        intermediate_output = nn.relu(intermediate_output)
    elif intermediate_activation == 'sigmoid':
        intermediate_output = nn.sigmoid(intermediate_output)
    elif intermediate_activation == 'softmax':
        intermediate_output = nn.softmax(intermediate_output)
    elif intermediate_activation == 'celu':
        intermediate_output = nn.celu(intermediate_output)
    elif intermediate_activation == 'elu':
        intermediate_output = nn.elu(intermediate_output)
    elif intermediate_activation == 'log_sigmoid':
        intermediate_output = nn.log_sigmoid(intermediate_output)
    elif intermediate_activation == 'log_softmax':
        intermediate_output = nn.log_softmax(intermediate_output)
    elif intermediate_activation == 'soft_sign':
        intermediate_output = nn.soft_sign(intermediate_output)
    elif intermediate_activation == 'softplus':
        intermediate_output = nn.softplus(intermediate_output)
    elif intermediate_activation == 'swish':
        intermediate_output = nn.swish(intermediate_output)
    elif intermediate_activation == 'tanh':
        intermediate_output = jnp.tanh(intermediate_output)
    else:
        raise NotImplementedError(
            '%s activation function is not yet supported.' %
            intermediate_activation)

    return intermediate_output
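The same selection logic could also be written as a table lookup; a minimal dict-dispatch sketch, using only activations known to exist in jax.nn / jax.numpy:

import jax
import jax.numpy as jnp

_ACTIVATIONS = {
    'gelu': jax.nn.gelu,
    'relu': jax.nn.relu,
    'sigmoid': jax.nn.sigmoid,
    'softplus': jax.nn.softplus,
    'tanh': jnp.tanh,
}

def apply_activation(intermediate_output, intermediate_activation):
    if intermediate_activation is None:
        return intermediate_output
    if intermediate_activation not in _ACTIVATIONS:
        raise NotImplementedError(
            '%s activation function is not yet supported.' %
            intermediate_activation)
    return _ACTIVATIONS[intermediate_activation](intermediate_output)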
Example #17
 def apply(self, x):
     x = nn.Conv(x, features=32, kernel_size=(3, 3), name="conv")
     x = nn.relu(x)
     x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
     x = x.reshape((x.shape[0], -1))  # flatten.
     x = nn.Dense(x, 128, name="fc")
     return x
Example #18
  def apply(self, x, num_classes, *,
            stage_sizes,
            block_cls,
            num_filters=64,
            dtype=jnp.float32,
            act=nn.relu,
            train=True):
    conv = nn.Conv.partial(bias=False, dtype=dtype)
    norm = nn.BatchNorm.partial(
        use_running_average=not train,
        momentum=0.9, epsilon=1e-5,
        dtype=dtype)

    x = conv(x, num_filters, (7, 7), (2, 2),
             padding=[(3, 3), (3, 3)],
             name='conv_init')
    x = norm(x, name='bn_init')
    x = nn.relu(x)
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
    for i, block_size in enumerate(stage_sizes):
      for j in range(block_size):
        strides = (2, 2) if i > 0 and j == 0 else (1, 1)
        x = block_cls(x, num_filters * 2 ** i,
                      strides=strides,
                      conv=conv,
                      norm=norm,
                      act=act)
    x = jnp.mean(x, axis=(1, 2))
    x = nn.Dense(x, num_classes, dtype=dtype)
    x = jnp.asarray(x, dtype)
    x = nn.log_softmax(x)
    return x
Example #19
 def apply(self, x, hidden_layers, hidden_dim, n_classes):
     x = jnp.reshape(x, (x.shape[0], -1))
     for layer in range(hidden_layers):
         x = nn.Dense(x, hidden_dim, name=f'fc{layer}')
         x = nn.relu(x)
     x = nn.Dense(x, n_classes, name=f'fc{hidden_layers}')
     preds = nn.log_softmax(x)
     return preds
Example #20
 def apply(self, x, reduction=16):
     num_channels = x.shape[-1]
     y = x.mean(axis=(1, 2))
     y = nn.Dense(y, features=num_channels // reduction, bias=False)
     y = nn.relu(y)
     y = nn.Dense(y, features=num_channels, bias=False)
     y = nn.sigmoid(y)
     return x * y[:, None, None, :]
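A broadcasting sketch of the squeeze-excite rescaling above: per-channel gates of shape (N, C) gain singleton spatial axes so they multiply every position of the NHWC feature map.

import jax.numpy as jnp

x = jnp.ones((2, 4, 4, 8))       # NHWC features
gate = 0.5 * jnp.ones((2, 8))    # per-channel gates in (0, 1)
y = x * gate[:, None, None, :]   # broadcast over H and W
print(y.shape)                   # (2, 4, 4, 8)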
Example #21
    def apply(self,
              x,
              num_layers,
              num_outputs,
              growth_rate,
              reduction,
              normalizer='batch_norm',
              dtype='float32',
              train=True):
        def dense_layers(y, block, num_blocks, growth_rate):
            for _ in range(num_blocks):
                y = block(y, growth_rate)
            return y

        def update_num_features(num_features, num_blocks, growth_rate,
                                reduction):
            num_features += num_blocks * growth_rate
            if reduction is not None:
                num_features = int(math.floor(num_features * reduction))
            return num_features

        # Initial convolutional layer
        num_features = 2 * growth_rate
        conv = nn.Conv.partial(bias=False, dtype=dtype)
        y = conv(x,
                 features=num_features,
                 kernel_size=(3, 3),
                 padding=((1, 1), (1, 1)),
                 name='conv1')

        # Internal dense and transition blocks
        num_blocks = _block_size_options[num_layers]
        block = BottleneckBlock.partial(train=train,
                                        dtype=dtype,
                                        normalizer=normalizer)
        for i in range(3):
            y = dense_layers(y, block, num_blocks[i], growth_rate)
            num_features = update_num_features(num_features, num_blocks[i],
                                               growth_rate, reduction)
            y = TransitionBlock(y,
                                num_features,
                                train=train,
                                dtype=dtype,
                                normalizer=normalizer)

        # Final dense block
        y = dense_layers(y, block, num_blocks[3], growth_rate)

        # Final pooling
        maybe_normalize = model_utils.get_normalizer(normalizer, train)
        y = maybe_normalize(y)
        y = nn.relu(y)
        y = nn.avg_pool(y, window_shape=(4, 4))

        # Classification layer
        y = jnp.reshape(y, (y.shape[0], -1))
        y = nn.Dense(y, num_outputs)
        return y
Example #22
    def apply(self, state, action, Q1=False):
        state_action = jnp.concatenate([state, action], axis=1)

        q1 = nn.Dense(state_action, features=256)
        q1 = nn.relu(q1)
        q1 = nn.Dense(q1, features=256)
        q1 = nn.relu(q1)
        q1 = nn.Dense(q1, features=1)

        if Q1:
            return q1

        q2 = nn.Dense(state_action, features=256)
        q2 = nn.relu(q2)
        q2 = nn.Dense(q2, features=256)
        q2 = nn.relu(q2)
        q2 = nn.Dense(q2, features=1)

        return q1, q2
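Usage sketch (TD3-style, an assumption about how the twin heads are consumed): the two Q estimates are typically combined with an elementwise minimum when forming targets, to damp overestimation.

import jax.numpy as jnp

q1 = jnp.array([[1.2], [0.7]])  # hypothetical critic outputs
q2 = jnp.array([[0.9], [1.1]])
print(jnp.minimum(q1, q2))      # clipped double-Q values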
Example #23
 def apply(
     self,
     x,
     num_classes,
     num_filters=64,
     num_layers=50,
     train=True,
     axis_name=None,
     axis_index_groups=None,
     dtype=jnp.float32,
     batch_norm_momentum=0.9,
     batch_norm_epsilon=1e-5,
     bn_output_scale=0.0,
     virtual_batch_size=None,
     data_format=None):
   if num_layers not in _block_size_options:
     raise ValueError('Please provide a valid number of layers')
   block_sizes = _block_size_options[num_layers]
   conv = nn.Conv.partial(padding=[(3, 3), (3, 3)])
   x = conv(x, num_filters, kernel_size=(7, 7), strides=(2, 2), bias=False,
            dtype=dtype, name='conv0')
   x = normalization.VirtualBatchNorm(
       x,
       use_running_average=not train,
       momentum=batch_norm_momentum,
       epsilon=batch_norm_epsilon,
       name='init_bn',
       axis_name=axis_name,
       axis_index_groups=axis_index_groups,
       dtype=dtype,
       virtual_batch_size=virtual_batch_size,
       data_format=data_format)
   x = nn.relu(x)  # MLPerf-required
   x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
   for i, block_size in enumerate(block_sizes):
     for j in range(block_size):
       strides = (2, 2) if i > 0 and j == 0 else (1, 1)
       x = ResidualBlock(
           x,
           num_filters * 2 ** i,
           strides=strides,
           train=train,
           axis_name=axis_name,
           axis_index_groups=axis_index_groups,
           dtype=dtype,
           batch_norm_momentum=batch_norm_momentum,
           batch_norm_epsilon=batch_norm_epsilon,
           bn_output_scale=bn_output_scale,
           virtual_batch_size=virtual_batch_size,
           data_format=data_format)
   x = jnp.mean(x, axis=(1, 2))
   x = nn.Dense(x, num_classes, kernel_init=nn.initializers.normal(),
                dtype=dtype)
   return x
Example #24
    def apply(self,
              x,
              filters,
              strides=(1, 1),
              train=True,
              batch_stats=None,
              dtype=jnp.float32,
              batch_norm_momentum=0.9,
              batch_norm_epsilon=1e-5,
              virtual_batch_size=None,
              data_format=None):
        needs_projection = x.shape[-1] != filters * 4 or strides != (1, 1)
        batch_norm = normalization.VirtualBatchNorm.partial(
            batch_stats=batch_stats,
            use_running_average=not train,
            momentum=batch_norm_momentum,
            epsilon=batch_norm_epsilon,
            dtype=dtype,
            virtual_batch_size=virtual_batch_size,
            data_format=data_format)
        conv = nn.Conv.partial(bias=False, dtype=dtype)

        residual = x
        if needs_projection:
            residual = conv(residual,
                            filters * 4, (1, 1),
                            strides,
                            name='proj_conv')
            residual = batch_norm(residual, name='proj_bn')

        y = conv(x, filters, (1, 1), name='conv1')
        y = batch_norm(y, name='bn1')
        y = nn.relu(y)
        y = conv(y, filters, (3, 3), strides, name='conv2')
        y = batch_norm(y, name='bn2')
        y = nn.relu(y)
        y = conv(y, filters * 4, (1, 1), name='conv3')

        y = batch_norm(y, name='bn3', scale_init=nn.initializers.zeros)
        y = nn.relu(residual + y)
        return y
Example #25
    def apply(self, x, num_outputs):
        """Define the convolutional network architecture.

    Architecture originates from "Human-level control through deep reinforcement
    learning.", Nature 518, no. 7540 (2015): 529-533.
    Note that this is different from the one in "Playing Atari with deep
    reinforcement learning," arxiv.org/abs/1312.5602 (2013).
    """
        dtype = jnp.float32
        x = x.astype(dtype) / 255.
        x = nn.Conv(x,
                    features=32,
                    kernel_size=(8, 8),
                    strides=(4, 4),
                    name='conv1',
                    dtype=dtype)
        x = nn.relu(x)
        x = nn.Conv(x,
                    features=64,
                    kernel_size=(4, 4),
                    strides=(2, 2),
                    name='conv2',
                    dtype=dtype)
        x = nn.relu(x)
        x = nn.Conv(x,
                    features=64,
                    kernel_size=(3, 3),
                    strides=(1, 1),
                    name='conv3',
                    dtype=dtype)
        x = nn.relu(x)
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(x, features=512, name='hidden', dtype=dtype)
        x = nn.relu(x)
        # Network used to both estimate policy (logits) and expected state value.
        # See github.com/openai/baselines/blob/master/baselines/ppo1/cnn_policy.py
        logits = nn.Dense(x, features=num_outputs, name='logits', dtype=dtype)
        policy_log_probabilities = nn.log_softmax(logits)
        value = nn.Dense(x, features=1, name='value', dtype=dtype)
        return policy_log_probabilities, value
Example #26
 def apply(self,
           x,
           num_classes,
           num_filters=64,
           num_layers=50,
           train=True,
           axis_name=None,
           axis_index_groups=None,
           dtype=jnp.float32,
           conv0_space_to_depth=False):
     if num_layers not in _block_size_options:
         raise ValueError('Please provide a valid number of layers')
     block_sizes = _block_size_options[num_layers]
     if conv0_space_to_depth:
         conv = SpaceToDepthConv.partial(block_size=(2, 2),
                                         padding=[(2, 1), (2, 1)])
     else:
         conv = nn.Conv.partial(padding=[(3, 3), (3, 3)])
     x = conv(x,
              num_filters,
              kernel_size=(7, 7),
              strides=(2, 2),
              bias=False,
              dtype=dtype,
              name='conv0')
     x = nn.BatchNorm(x,
                      use_running_average=not train,
                      momentum=0.9,
                      epsilon=1e-5,
                      name='init_bn',
                      axis_name=axis_name,
                      axis_index_groups=axis_index_groups,
                      dtype=dtype)
     x = nn.relu(x)  # MLPerf-required
     x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
     for i, block_size in enumerate(block_sizes):
         for j in range(block_size):
             strides = (2, 2) if i > 0 and j == 0 else (1, 1)
             x = ResidualBlock(x,
                               num_filters * 2**i,
                               strides=strides,
                               train=train,
                               axis_name=axis_name,
                               axis_index_groups=axis_index_groups,
                               dtype=dtype)
     x = jnp.mean(x, axis=(1, 2))
     x = nn.Dense(x,
                  num_classes,
                  kernel_init=nn.initializers.normal(),
                  dtype=dtype)
     x = nn.log_softmax(x)
     return x
Example #27
  def apply(self, x, n_layers, n_features, n_kernel_sizes):
    
    x = jnp.expand_dims(x, axis=2)

    for layer in range(n_layers):
      features = n_features[layer]
      kernel_size = (n_kernel_sizes[layer], 1)
      x = nn.Conv(x, features=features, kernel_size=kernel_size)
      x = nn.relu(x)
    
    x = jnp.squeeze(x, axis=2)

    return x
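A shape sketch of the dummy-axis trick above: a (batch, length, channels) sequence gains a width-1 axis so 2-D convolutions with (k, 1) kernels slide along the length dimension only, and the axis is squeezed away afterwards.

import jax.numpy as jnp

x = jnp.ones((2, 16, 32))              # (batch, length, channels)
x4 = jnp.expand_dims(x, axis=2)        # (2, 16, 1, 32): add dummy width
print(jnp.squeeze(x4, axis=2).shape)   # (2, 16, 32): back to a sequence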
Example #28
 def apply(self, actions, num_layers, hidden_dims):
     timesteps = actions.shape[1]
     # flatten time into batch
     actions = jnp.reshape(actions, (-1, ) + actions.shape[2:])
     # embed actions
     x = nn.Dense(actions, hidden_dims)
     for _ in range(num_layers):
         x = nn.Dense(x, hidden_dims)
         x = nn.LayerNorm(x)
         x = nn.relu(x)
     x = nn.Dense(x, 1)
     x = jnp.reshape(x, (-1, timesteps, 1))
     return x
Example #29
    def apply(self,
              x,
              num_features,
              train=True,
              dtype=jnp.float32,
              normalizer='batch_norm'):
        conv = nn.Conv.partial(bias=False, dtype=dtype)
        maybe_normalize = model_utils.get_normalizer(normalizer, train)

        y = maybe_normalize(x)
        y = nn.relu(y)
        y = conv(y, features=num_features, kernel_size=(1, 1))
        y = nn.avg_pool(y, window_shape=(2, 2))
        return y
Example #30
    def apply(self, x, nout, strides=(1, 1)):
        x_shortcut = x
        needs_projection = x.shape[-1] != nout * 4 or strides != (1, 1)

        group_norm = GroupNorm
        conv = StdConv.partial(bias=False)

        x = group_norm(x, name="gn1")
        x = nn.relu(x)
        if needs_projection:
            x_shortcut = conv(x, nout * 4, (1, 1), strides, name="conv_proj")
        x = conv(x, nout, (1, 1), name="conv1")

        x = group_norm(x, name="gn2")
        x = nn.relu(x)
        x = fixed_padding(x, 3)
        x = conv(x, nout, (3, 3), strides, name="conv2", padding='VALID')

        x = group_norm(x, name="gn3")
        x = nn.relu(x)
        x = conv(x, nout * 4, (1, 1), name="conv3")

        return x + x_shortcut