Example #1
    def apply(self,
              x,
              num_filters=64,
              block_sizes=(3, 4, 6, 3),
              train=True,
              block=BottleneckBlock,
              small_inputs=False):
        if small_inputs:
            x = nn.Conv(x,
                        num_filters,
                        kernel_size=(3, 3),
                        strides=(1, 1),
                        bias=False,
                        name="init_conv")
        else:
            x = nn.Conv(x,
                        num_filters,
                        kernel_size=(7, 7),
                        strides=(2, 2),
                        bias=False,
                        name="init_conv")
        x = nn.BatchNorm(x,
                         use_running_average=not train,
                         epsilon=1e-5,
                         name="init_bn")
        if not small_inputs:
            x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME")
        for i, block_size in enumerate(block_sizes):
            for j in range(block_size):
                strides = (2, 2) if i > 0 and j == 0 else (1, 1)
                x = block(x, num_filters * 2**i, strides=strides, train=train)

        return x
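A note on the stride schedule inside the loop above: only the first block of each group after the first downsamples, so with the default block_sizes the spatial resolution halves three times across the groups. A tiny sketch of that schedule (pure Python, defaults assumed):

    block_sizes = (3, 4, 6, 3)
    schedule = [[(2, 2) if i > 0 and j == 0 else (1, 1) for j in range(size)]
                for i, size in enumerate(block_sizes)]
    # schedule[0] is all (1, 1); schedule[1][0] == (2, 2) and the rest of group 1 is (1, 1).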
Example #2
 def apply(self, x, num_actions, num_atoms):
     initializer = jax.nn.initializers.variance_scaling(
         scale=1.0 / jnp.sqrt(3.0), mode='fan_in', distribution='uniform')
     # We need to add a "batch dimension" as nn.Conv expects it, yet vmap will
     # have removed the true batch dimension.
     x = x[None, ...]
     x = x.astype(jnp.float32) / 255.
     x = nn.Conv(x,
                 features=32,
                 kernel_size=(8, 8),
                 strides=(4, 4),
                 kernel_init=initializer)
     x = jax.nn.relu(x)
     x = nn.Conv(x,
                 features=64,
                 kernel_size=(4, 4),
                 strides=(2, 2),
                 kernel_init=initializer)
     x = jax.nn.relu(x)
     x = nn.Conv(x,
                 features=64,
                 kernel_size=(3, 3),
                 strides=(1, 1),
                 kernel_init=initializer)
     x = jax.nn.relu(x)
     x = x.reshape((x.shape[0], -1))  # flatten
     x = nn.Dense(x, features=512, kernel_init=initializer)
     x = jax.nn.relu(x)
     x = nn.Dense(x,
                  features=num_actions * num_atoms,
                  kernel_init=initializer)
     logits = x.reshape((x.shape[0], num_actions, num_atoms))
     probabilities = nn.softmax(logits)
     q_values = jnp.mean(logits, axis=2)
     return atari_lib.RainbowNetworkType(q_values, logits, probabilities)
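The comment above reflects the convention of applying the network to a single observation and letting the caller vmap over the true batch dimension. A hedged usage sketch (model and states are illustrative names, with the parameters and the num_actions / num_atoms arguments assumed already bound):

    # Each call sees one observation, adds its own dummy batch axis, and vmap
    # restores the real batch dimension on every field of the returned tuple.
    outputs = jax.vmap(model)(states)          # states: [batch_size, ...]
    q_values = outputs.q_values.squeeze(1)     # drop the size-1 inner batch axis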
Example #3
  def apply(self, x, rep_size, m_layers, m_features, m_kernel_sizes, conv_rep_size, padding_mask=None):
        
    H_0 = nn.relu(nn.Dense(x, conv_rep_size))
    G_0 = nn.relu(nn.Dense(x, conv_rep_size))
    H, G = jnp.expand_dims(H_0, axis=2), jnp.expand_dims(G_0, axis=2)

    for layer in range(1, m_layers+1):
      
      if layer < m_layers:
        H_features, G_features = m_features[layer-1]
      else:
        H_features, G_features = conv_rep_size, conv_rep_size
      
      H_kernel_size, G_kernel_size = m_kernel_sizes[layer-1]

      H = nn.Conv(H, features=H_features, kernel_size=(H_kernel_size, 1))
      G = nn.Conv(G, features=G_features, kernel_size=(G_kernel_size, 1)) 

      if layer < m_layers:
        H = nn.relu(H)
        G = nn.relu(G)
      else:
        H = nn.tanh(H)
        G = nn.sigmoid(G)

    H, G = jnp.squeeze(H, axis=2), jnp.squeeze(G, axis=2)
    
    F = H * G + G_0
    
    rep = linear_max_pool(F, padding_mask=padding_mask, rep_size=rep_size)
    
    return rep
Example #4
 def apply(self, x, num_actions):
     initializer = nn.initializers.xavier_uniform()
     # We need to add a "batch dimension" as nn.Conv expects it, yet vmap will
     # have removed the true batch dimension.
     x = x[None, ...]
     x = x.astype(jnp.float32) / 255.
     x = nn.Conv(x,
                 features=32,
                 kernel_size=(8, 8),
                 strides=(4, 4),
                 kernel_init=initializer)
     x = jax.nn.relu(x)
     x = nn.Conv(x,
                 features=64,
                 kernel_size=(4, 4),
                 strides=(2, 2),
                 kernel_init=initializer)
     x = jax.nn.relu(x)
     x = nn.Conv(x,
                 features=64,
                 kernel_size=(3, 3),
                 strides=(1, 1),
                 kernel_init=initializer)
     x = jax.nn.relu(x)
     x = x.reshape((x.shape[0], -1))  # flatten
     x = nn.Dense(x, features=512, kernel_init=initializer)
     x = jax.nn.relu(x)
     q_values = nn.Dense(x, features=num_actions, kernel_init=initializer)
     return atari_lib.DQNNetworkType(q_values)
Example #5
 def apply(self,
           x,
           filters,
           strides=(1, 1),
           groups=1,
           base_width=64,
           train=True):
     needs_projection = x.shape[-1] != filters * 4 or strides != (1, 1)
     width = int(filters * (base_width / 64.)) * groups
     batch_norm = nn.BatchNorm.partial(use_running_average=not train,
                                       momentum=0.9,
                                       epsilon=1e-5)
     y = nn.Conv(x, width, (1, 1), (1, 1), bias=False, name='conv1')
     y = batch_norm(y, name='bn1')
     y = jax.nn.relu(y)
     y = nn.Conv(y,
                 width, (3, 3),
                 strides,
                 bias=False,
                 feature_group_count=groups,
                 name='conv2')
     y = batch_norm(y, name='bn2')
     y = jax.nn.relu(y)
     y = nn.Conv(y, filters * 4, (1, 1), (1, 1), bias=False, name='conv3')
     y = batch_norm(y, name='bn3', scale_init=initializers.zeros)
     if needs_projection:
         x = nn.Conv(x,
                     filters * 4, (1, 1),
                     strides,
                     bias=False,
                     name='proj_conv')
         x = batch_norm(x, name='proj_bn')
     return jax.nn.relu(x + y)
Example #6
  def apply(self, x, channels, strides=(1, 1), train=True):
    batch_norm = nn.BatchNorm.partial(use_running_average=not train,
                                      momentum=0.9, epsilon=1e-5)

    a = b = residual = x

    a = jax.nn.relu(a)
    a = nn.Conv(a, channels, (3, 3), strides, padding='SAME', name='conv_a_1')
    a = batch_norm(a, name='bn_a_1')
    a = jax.nn.relu(a)
    a = nn.Conv(a, channels, (3, 3), padding='SAME', name='conv_a_2')
    a = batch_norm(a, name='bn_a_2')

    b = jax.nn.relu(b)
    b = nn.Conv(b, channels, (3, 3), strides, padding='SAME', name='conv_b_1')
    b = batch_norm(b, name='bn_b_1')
    b = jax.nn.relu(b)
    b = nn.Conv(b, channels, (3, 3), padding='SAME', name='conv_b_2')
    b = batch_norm(b, name='bn_b_2')

    if train and not self.is_initializing():
      ab = utils.shake_shake_train(a, b)
    else:
      ab = utils.shake_shake_eval(a, b)

    # Apply an up projection in case of channel mismatch
    if (residual.shape[-1] != channels) or strides != (1, 1):
      residual = nn.Conv(residual, channels, (3, 3), strides, padding='SAME',
                         name='conv_residual')
      residual = batch_norm(residual, name='bn_residual')

    return residual + ab
Example #7
    def apply(self,
              x,
              channels,
              strides,
              prob,
              alpha_min,
              alpha_max,
              beta_min,
              beta_max,
              train=True):
        """Implements the forward pass in the module.

    Args:
      x: Input to the module. Should have shape [batch_size, dim, dim, features]
        where dim is the resolution (width and height if the input is an image).
      channels: How many channels to use in the convolutional layers.
      strides: Strides for the pooling.
      prob: Probability of dropping the block (see paper for details).
      alpha_min: See paper.
      alpha_max: See paper.
      beta_min: See paper.
      beta_max: See paper.
      train: If False, will use the moving average for batch norm statistics.
        Else, will use statistics computed on the batch.

    Returns:
      The output of the bottleneck block.
    """
        y = utils.activation(x, apply_relu=False, train=train, name='bn_1_pre')
        y = nn.Conv(y,
                    channels, (1, 1),
                    padding='SAME',
                    bias=False,
                    kernel_init=utils.conv_kernel_init_fn,
                    name='1x1_conv_contract')
        y = utils.activation(y, train=train, name='bn_1_post')
        y = nn.Conv(y,
                    channels, (3, 3),
                    strides,
                    padding='SAME',
                    bias=False,
                    kernel_init=utils.conv_kernel_init_fn,
                    name='3x3')
        y = utils.activation(y, train=train, name='bn_2')
        y = nn.Conv(y,
                    channels * 4, (1, 1),
                    padding='SAME',
                    bias=False,
                    kernel_init=utils.conv_kernel_init_fn,
                    name='1x1_conv_expand')
        y = utils.activation(y, apply_relu=False, train=train, name='bn_3')

        if train:
            y = utils.shake_drop_train(y, prob, alpha_min, alpha_max, beta_min,
                                       beta_max)
        else:
            y = utils.shake_drop_eval(y, prob, alpha_min, alpha_max)

        x = _shortcut(x, channels * 4, strides)
        return x + y
Example #8
 def apply(self, x):
     x = nn.Conv(x, features=32, kernel_size=(3, 3))
     x = nn.relu(x)
     x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
     x = nn.Conv(x, features=64, kernel_size=(3, 3))
     x = nn.relu(x)
     x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
     x = x.reshape((x.shape[0], -1))  # flatten
     x = nn.Dense(x, features=256)
     x = nn.relu(x)
     x = nn.Dense(x, features=10)
     x = nn.log_softmax(x)
     return x
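Because this model ends in log_softmax, a cross-entropy loss is just the negative log-probability of the true class. A minimal sketch, assuming the deprecated flax.nn (pre-Linen) module API used throughout these examples and a module class named CNN wrapping the apply above (both names are illustrative):

    rng = jax.random.PRNGKey(0)
    _, params = CNN.init(rng, jnp.ones((1, 28, 28, 1)))   # returns (output, params)

    def nll_loss(params, images, labels):
        log_probs = CNN.call(params, images)
        one_hot = jax.nn.one_hot(labels, num_classes=10)
        return -jnp.mean(jnp.sum(one_hot * log_probs, axis=-1))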
Example #9
  def apply(self,
            x,
            channels,
            strides = (1, 1),
            train = True):
    """Implements the forward pass in the module.

    Args:
      x: Input to the module. Should have shape [batch_size, dim, dim, features]
        where dim is the resolution (width and height if the input is an image).
      channels: How many channels to use in the convolutional layers.
      strides: Strides for the pooling.
      train: If False, will use the moving average for batch norm statistics.

    Returns:
      The output of the resnet block. Will have shape
        [batch_size, dim, dim, channels] if strides = (1, 1) or
        [batch_size, dim/2, dim/2, channels] if strides = (2, 2).
    """

    if x.shape[-1] == channels:
      return x

    # Skip path 1
    h1 = nn.avg_pool(x, (1, 1), strides=strides, padding='VALID')
    h1 = nn.Conv(
        h1,
        channels // 2, (1, 1),
        strides=(1, 1),
        padding='SAME',
        bias=False,
        kernel_init=utils.conv_kernel_init_fn,
        name='conv_h1')

    # Skip path 2
    # The next two lines offset the "image" by one pixel on the right and one
    # down (see Shake-Shake regularization, Xavier Gastaldi for details)
    pad_arr = [[0, 0], [0, 1], [0, 1], [0, 0]]
    h2 = jnp.pad(x, pad_arr)[:, 1:, 1:, :]
    h2 = nn.avg_pool(h2, (1, 1), strides=strides, padding='VALID')
    h2 = nn.Conv(
        h2,
        channels // 2, (1, 1),
        strides=(1, 1),
        padding='SAME',
        bias=False,
        kernel_init=utils.conv_kernel_init_fn,
        name='conv_h2')
    merged_branches = jnp.concatenate([h1, h2], axis=3)
    return utils.activation(
        merged_branches, apply_relu=False, train=train, name='bn_residual')
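The pad-then-slice above keeps the spatial size unchanged while sampling the feature map offset by one pixel, so the two skip paths see slightly shifted views before their 1x1 convolutions. A shape-only sketch (input values assumed):

    x = jnp.zeros((2, 8, 8, 16))
    pad_arr = [[0, 0], [0, 1], [0, 1], [0, 0]]      # append one row / column of zeros
    shifted = jnp.pad(x, pad_arr)[:, 1:, 1:, :]      # then drop the first row / column
    assert shifted.shape == x.shape                  # same size, content shifted by one pixel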
Example #10
 def apply(self, x, use_squeeze_excite=False):
     x = nn.Conv(x, features=8, kernel_size=(3, 3), padding="VALID")
     x = nn.relu(x)
     x = nn.Conv(x, features=16, kernel_size=(3, 3), padding="VALID")
     x = nn.relu(x)
     if use_squeeze_excite:
         x = SqueezeExciteLayer(x)
     x = nn.Conv(x, features=32, kernel_size=(3, 3), padding="VALID")
     x = nn.relu(x)
     if use_squeeze_excite:
         x = SqueezeExciteLayer(x)
     x = nn.Conv(x, features=1, kernel_size=(3, 3), padding="VALID")
     scores = nn.max_pool(x, window_shape=(8, 8), strides=(8, 8))[Ellipsis,
                                                                  0]
     return scores
Example #11
    def apply(self,
              x: jnp.ndarray,
              channels: int,
              strides: Tuple[int, int] = (1, 1),
              activate_before_residual: bool = False,
              train: bool = True) -> jnp.ndarray:
        """Implements the forward pass in the module.

    Args:
      x: Input to the module. Should have shape [batch_size, dim, dim, features]
        where dim is the resolution (width and height if the input is an image).
      channels: How many channels to use in the convolutional layers.
      strides: Strides for the pooling.
      activate_before_residual: True if the batch norm and relu should be
        applied before the residual branches out (should be True only for the
        first block of the model).
      train: If False, will use the moving average for batch norm statistics.
        Else, will use statistics computed on the batch.

    Returns:
      The output of the resnet block.
    """
        if activate_before_residual:
            x = utils.activation(x, train, name='init_bn')
            orig_x = x
        else:
            orig_x = x

        block_x = x
        if not activate_before_residual:
            block_x = utils.activation(block_x, train, name='init_bn')

        block_x = nn.Conv(block_x,
                          channels, (3, 3),
                          strides,
                          padding='SAME',
                          bias=False,
                          kernel_init=utils.conv_kernel_init_fn,
                          name='conv1')
        block_x = utils.activation(block_x, train=train, name='bn_2')
        block_x = nn.Conv(block_x,
                          channels, (3, 3),
                          padding='SAME',
                          bias=False,
                          kernel_init=utils.conv_kernel_init_fn,
                          name='conv2')

        return _output_add(block_x, orig_x)
Example #12
 def apply(self, x, num_actions, quantile_embedding_dim, num_quantiles,
           rng):
     initializer = jax.nn.initializers.variance_scaling(
         scale=1.0 / jnp.sqrt(3.0), mode='fan_in', distribution='uniform')
     # We need to add a "batch dimension" as nn.Conv expects it, yet vmap will
     # have removed the true batch dimension.
     x = x[None, ...]
     x = x.astype(jnp.float32) / 255.
     x = nn.Conv(x,
                 features=32,
                 kernel_size=(8, 8),
                 strides=(4, 4),
                 kernel_init=initializer)
     x = jax.nn.relu(x)
     x = nn.Conv(x,
                 features=64,
                 kernel_size=(4, 4),
                 strides=(2, 2),
                 kernel_init=initializer)
     x = jax.nn.relu(x)
     x = nn.Conv(x,
                 features=64,
                 kernel_size=(3, 3),
                 strides=(1, 1),
                 kernel_init=initializer)
     x = jax.nn.relu(x)
     x = x.reshape((x.shape[0], -1))  # flatten
     state_vector_length = x.shape[-1]
     state_net_tiled = jnp.tile(x, [num_quantiles, 1])
     quantiles_shape = [num_quantiles, 1]
     quantiles = jax.random.uniform(rng, shape=quantiles_shape)
     quantile_net = jnp.tile(quantiles, [1, quantile_embedding_dim])
     quantile_net = (
         jnp.arange(1, quantile_embedding_dim + 1, 1).astype(jnp.float32) *
         onp.pi * quantile_net)
     quantile_net = jnp.cos(quantile_net)
     quantile_net = nn.Dense(quantile_net,
                             features=state_vector_length,
                             kernel_init=initializer)
     quantile_net = jax.nn.relu(quantile_net)
     x = state_net_tiled * quantile_net
     x = nn.Dense(x, features=512, kernel_init=initializer)
     x = jax.nn.relu(x)
     quantile_values = nn.Dense(x,
                                features=num_actions,
                                kernel_init=initializer)
     return atari_lib.ImplicitQuantileNetworkType(quantile_values,
                                                  quantiles)
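The quantile embedding built above is cos(pi * i * tau) for i = 1..quantile_embedding_dim, evaluated for each sampled quantile tau before it is passed through the Dense layer. A worked sketch of just that embedding (sample values assumed):

    quantile_embedding_dim = 64
    quantiles = jnp.array([[0.1], [0.5], [0.9]])                       # [num_quantiles, 1]
    i = jnp.arange(1, quantile_embedding_dim + 1).astype(jnp.float32)  # [embedding_dim]
    embedding = jnp.cos(jnp.pi * i * quantiles)                        # [num_quantiles, embedding_dim]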
Example #13
 def apply(self, x):
     x = nn.Conv(x, features=32, kernel_size=(3, 3), name="conv")
     x = nn.relu(x)
     x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
     x = x.reshape((x.shape[0], -1))  # flatten.
     x = nn.Dense(x, 128, name="fc")
     return x
Example #14
  def apply(self, x, num_actions, noisy, dueling):
    # We need to add a "batch dimension" as nn.Conv expects it, yet vmap will
    # have removed the true batch dimension.
    initializer_conv = nn.initializers.xavier_uniform()
    x = x[None, ...]
    x = x.astype(jnp.float32)
    x = nn.Conv(x, features=16, kernel_size=(3, 3, 3), strides=(1, 1, 1),  kernel_init=initializer_conv)
    x = jax.nn.relu(x)
    x = x.reshape((x.shape[0], -1))  # flatten
    
    if noisy:
        print('InvadersDDQNNetwork-Noisy[Johan]')
        initializer = None
        bias = True
        def net(x, features, bias, kernel_init):
            return NoisyNetwork(x, features, bias, kernel_init)
    else:
        initializer = nn.initializers.xavier_uniform()
        bias = None
        def net(x, features, bias, kernel_init):
            return nn.Dense(x, features, kernel_init=kernel_init)

    if dueling:
        print('InvadersDDQNNetwork-Dueling[Johan]')
        adv = net(x, features=num_actions, bias=bias, kernel_init=initializer)
        val = net(x, features=1, bias=bias, kernel_init=initializer)
        q_values = val + (adv - (jnp.mean(adv, -1, keepdims=True)))

    else:
        q_values = net(x, features=num_actions, bias=bias, kernel_init=initializer)

    return atari_lib.DQNNetworkType(q_values)
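The dueling aggregation above computes Q(s, a) = V(s) + A(s, a) - mean_a A(s, a), i.e. the advantages are centered before being added to the state value. A worked sketch with toy numbers:

    adv = jnp.array([[1.0, 2.0, 3.0]])                   # advantages for 3 actions
    val = jnp.array([[0.5]])                              # state value
    q = val + (adv - jnp.mean(adv, -1, keepdims=True))    # -> [[-0.5, 0.5, 1.5]]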
Example #15
 def apply(
         self,
         x,
         num_outputs,
         num_filters=64,
         block_sizes=[3, 4, 6, 3],  # pylint: disable=dangerous-default-value
         train=True):
     x = nn.Conv(x,
                 num_filters, (7, 7), (2, 2),
                 bias=False,
                 name='init_conv')
     x = nn.BatchNorm(x,
                      use_running_average=not train,
                      epsilon=1e-5,
                      name='init_bn')
     x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
     for i, block_size in enumerate(block_sizes):
         for j in range(block_size):
             strides = (2, 2) if i > 0 and j == 0 else (1, 1)
             x = BottleneckBlock(x,
                                 num_filters * 2**i,
                                 strides=strides,
                                 train=train)
     x = jnp.mean(x, axis=(1, 2))
     x_clf = nn.Dense(x, num_outputs, name='clf')
     # We return both the outputs from the dense layer *and* the features
     # that go into it.
     return x_clf, x
Example #16
    def apply(self,
              x,
              blocks_per_group,
              channel_multiplier,
              num_outputs,
              dropout_rate=0.0,
              train=True):

        x = nn.Conv(x, 16, (3, 3), padding='SAME', name='init_conv')
        x = WideResnetGroup(x,
                            blocks_per_group,
                            16 * channel_multiplier,
                            dropout_rate=dropout_rate,
                            train=train)
        x = WideResnetGroup(x,
                            blocks_per_group,
                            32 * channel_multiplier, (2, 2),
                            dropout_rate=dropout_rate,
                            train=train)
        x = WideResnetGroup(x,
                            blocks_per_group,
                            64 * channel_multiplier, (2, 2),
                            dropout_rate=dropout_rate,
                            train=train)
        x = nn.BatchNorm(x,
                         use_running_average=not train,
                         momentum=0.9,
                         epsilon=1e-5)
        x = jax.nn.relu(x)
        x = nn.avg_pool(x, (8, 8))
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(x, num_outputs)
        return x
Example #17
  def apply(self,
            x,
            blocks_per_group,
            channel_multiplier,
            num_outputs,
            train=True):

    x = nn.Conv(
        x, 16, (3, 3), padding='SAME', name='init_conv')
    x = WideResnetShakeShakeGroup(
        x,
        blocks_per_group,
        16 * channel_multiplier,
        train=train)
    x = WideResnetShakeShakeGroup(
        x,
        blocks_per_group,
        32 * channel_multiplier, (2, 2),
        train=train)
    x = WideResnetShakeShakeGroup(
        x,
        blocks_per_group,
        64 * channel_multiplier, (2, 2),
        train=train)
    x = jax.nn.relu(x)
    x = nn.avg_pool(x, (8, 8))
    x = x.reshape((x.shape[0], -1))
    x = nn.Dense(x, num_outputs)
    return x
Example #18
 def apply(self,
           x,
           num_classes,
           num_filters=64,
           num_layers=50,
           train=True,
           dtype=jnp.float32):
     if num_layers not in _block_size_options:
         raise ValueError('Please provide a valid number of layers')
     block_sizes = _block_size_options[num_layers]
     x = nn.Conv(x,
                 num_filters, (7, 7), (2, 2),
                 padding=[(3, 3), (3, 3)],
                 bias=False,
                 dtype=dtype,
                 name='init_conv')
     x = nn.BatchNorm(x,
                      use_running_average=not train,
                      momentum=0.9,
                      epsilon=1e-5,
                      dtype=dtype,
                      name='init_bn')
     x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
     for i, block_size in enumerate(block_sizes):
         for j in range(block_size):
             strides = (2, 2) if i > 0 and j == 0 else (1, 1)
             x = ResidualBlock(x,
                               num_filters * 2**i,
                               strides=strides,
                               train=train,
                               dtype=dtype)
     x = jnp.mean(x, axis=(1, 2))
     x = nn.Dense(x, num_classes)
     x = nn.log_softmax(x)
     return x
Example #19
 def apply(self, x):
     b = x.shape[0]
     x = nn.Conv(x, features=128, kernel_size=(4, ), padding='SAME')
     x = nn.BatchNorm(x)
     x = nn.leaky_relu(x)
     x = nn.avg_pool(x, window_shape=(2, ), padding='SAME')
     x = nn.Conv(x, features=256, kernel_size=(4, ), padding='SAME')
     x = nn.BatchNorm(x)
     x = nn.leaky_relu(x)
     x = nn.avg_pool(x, window_shape=(2, ), padding='SAME')
     x = x.reshape(b, -1)
     x = nn.Dense(x, features=128)
     x = nn.BatchNorm(x)
     x = nn.leaky_relu(x)
     x = nn.Dense(x, features=n_bins)
     x = nn.softmax(x)
     return x
Example #20
  def apply(self, x, channels, strides=(1, 1), dropout_rate=0.0, train=True):
    batch_norm = nn.BatchNorm.partial(use_running_average=not train,
                                      momentum=0.9, epsilon=1e-5)

    y = batch_norm(x, name='bn1')
    y = jax.nn.relu(y)
    y = nn.Conv(y, channels, (3, 3), strides, padding='SAME', name='conv1')
    y = batch_norm(y, name='bn2')
    y = jax.nn.relu(y)
    if dropout_rate > 0.0:
      y = nn.dropout(y, dropout_rate, deterministic=not train)
    y = nn.Conv(y, channels, (3, 3), padding='SAME', name='conv2')

    # Apply an up projection in case of channel mismatch
    if (x.shape[-1] != channels) or strides != (1, 1):
      x = nn.Conv(x, channels, (3, 3), strides, padding='SAME')
    return x + y
Example #21
  def apply(self,
            x,
            num_classes=1000,
            train=False,
            resnet=None,
            patches=None,
            hidden_size=None,
            transformer=None,
            representation_size=None,
            classifier='gap'):

    n, h, w, c = x.shape

    # Embed the grid or patches of the grid.
    fh, fw = patches.size
    gh, gw = h // fh, w // fw
    if hidden_size:  # We can merge s2d+emb into a single conv; it's the same.
      x = nn.Conv(
          x,
          hidden_size, (fh, fw),
          strides=(fh, fw),
          padding='VALID',
          name='embedding')
    else:
      # This path often results in excessive padding.
      x = jnp.reshape(x, [n, gh, fh, gw, fw, c])
      x = jnp.transpose(x, [0, 1, 3, 2, 4, 5])
      x = jnp.reshape(x, [n, gh, gw, -1])

    # Here, x is a grid of embeddings.

    # (Possibly partial) Transformer.
    if transformer is not None:
      n, h, w, c = x.shape
      x = jnp.reshape(x, [n, h * w, c])

      # If we want to add a class token, add it here.
      if classifier == 'token':
        cls = self.param('cls', (1, 1, c), nn.initializers.zeros)
        cls = jnp.tile(cls, [n, 1, 1])
        x = jnp.concatenate([cls, x], axis=1)

      x = Encoder(x, train=train, name='Transformer', **transformer)

    if classifier == 'token':
      x = x[:, 0]
    elif classifier == 'gap':
      x = jnp.mean(x, axis=list(range(1, x.ndim - 1)))  # (1,) or (1,2)

    if representation_size is not None:
      x = nn.Dense(x, representation_size, name='pre_logits')
      x = nn.tanh(x)
    else:
      x = IdentityLayer(x, name='pre_logits')

    x = nn.Dense(x, num_classes, name='head', kernel_init=nn.initializers.zeros)
    return x
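The comment in the embedding branch above says the space-to-depth reshape plus a linear embedding can be merged into one strided convolution. A self-contained check of that equivalence with jax.lax (all shapes and the random kernel here are assumptions for illustration):

    import jax
    import jax.numpy as jnp

    n, h, w, c, fh, fw, hidden = 2, 8, 8, 3, 4, 4, 16
    x = jax.random.normal(jax.random.PRNGKey(0), (n, h, w, c))
    kernel = jax.random.normal(jax.random.PRNGKey(1), (fh, fw, c, hidden))  # HWIO layout

    # Path 1: convolution with kernel size == stride == patch size.
    conv_out = jax.lax.conv_general_dilated(
        x, kernel, window_strides=(fh, fw), padding='VALID',
        dimension_numbers=('NHWC', 'HWIO', 'NHWC'))

    # Path 2: flatten non-overlapping patches, then one matrix multiply.
    patches = x.reshape(n, h // fh, fh, w // fw, fw, c)
    patches = jnp.transpose(patches, [0, 1, 3, 2, 4, 5]).reshape(n, h // fh, w // fw, -1)
    dense_out = patches @ kernel.reshape(-1, hidden)

    assert jnp.allclose(conv_out, dense_out, atol=1e-5)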
Example #22
  def apply(self, x, num_actions, minatar, env, normalize_obs, noisy, dueling, num_atoms,hidden_layer=2, neurons=512):
    del normalize_obs

    if minatar:
      x = x.squeeze(3)
      x = x[None, ...]
      x = x.astype(jnp.float32)
      x = nn.Conv(x, features=16, kernel_size=(3, 3, 3), strides=(1, 1, 1),  kernel_init=nn.initializers.xavier_uniform())
      x = jax.nn.relu(x)
      x = x.reshape((x.shape[0], -1))

    else:
      x = x[None, ...]
      x = x.astype(jnp.float32)
      x = x.reshape((x.shape[0], -1))


    if env is not None:
      x = x - env_inf[env]['MIN_VALS']
      x /= env_inf[env]['MAX_VALS'] - env_inf[env]['MIN_VALS']
      x = 2.0 * x - 1.0


    if noisy:
      def net(x, features):
        return NoisyNetwork(x, features)
    else:
      def net(x, features):
        return nn.Dense(x, features, kernel_init=nn.initializers.xavier_uniform())


    for _ in range(hidden_layer):
      x = net(x, features=neurons)
      #print('x:',x)
      x = jax.nn.relu(x)

    if dueling:
      print('dueling')
      adv = net(x,features=num_actions * num_atoms)
      value = net(x, features=num_atoms)
      adv = adv.reshape((adv.shape[0], num_actions, num_atoms))
      value = value.reshape((value.shape[0], 1, num_atoms))
      logits = value + (adv - (jnp.mean(adv, -1, keepdims=True)))
      probabilities = nn.softmax(logits)
      q_values = jnp.mean(logits, axis=2)

    else:
      #print('No dueling')
      x = net(x, features=num_actions * num_atoms)
      logits = x.reshape((x.shape[0], num_actions, num_atoms))
      probabilities = nn.softmax(logits)
      q_values = jnp.mean(logits, axis=2)


    return atari_lib.RainbowNetworkType(q_values, logits, probabilities)
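The block guarded by `env is not None` above rescales observations from a per-environment [MIN_VALS, MAX_VALS] range into [-1, 1]. A worked sketch with toy bounds:

    obs = jnp.array([0.0, 5.0, 10.0])
    min_vals, max_vals = 0.0, 10.0
    x = (obs - min_vals) / (max_vals - min_vals)   # -> [0.0, 0.5, 1.0]
    x = 2.0 * x - 1.0                              # -> [-1.0, 0.0, 1.0]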
Example #23
    def apply(self,
              x,
              channels,
              strides=(1, 1),
              conv_kernel_init=initializers.lecun_normal(),
              normalizer='batch_norm',
              train=True):
        maybe_normalize = model_utils.get_normalizer(normalizer, train)
        y = maybe_normalize(x, name='bn1')
        y = jax.nn.relu(y)

        # Apply an up projection in case of channel mismatch
        if (x.shape[-1] != channels) or strides != (1, 1):
            x = nn.Conv(
                y,
                channels,
                (1, 1),  # Note: Some implementations use (3, 3) here.
                strides,
                padding='SAME',
                kernel_init=conv_kernel_init,
                bias=False)

        y = nn.Conv(y,
                    channels, (3, 3),
                    strides,
                    padding='SAME',
                    name='conv1',
                    kernel_init=conv_kernel_init,
                    bias=False)
        y = maybe_normalize(y, name='bn2')
        y = jax.nn.relu(y)
        y = nn.Conv(y,
                    channels, (3, 3),
                    padding='SAME',
                    name='conv2',
                    kernel_init=conv_kernel_init,
                    bias=False)

        if normalizer == 'none':
            y = model_utils.ScalarMultiply(y)

        return x + y
Example #24
    def apply(self,
              x: jnp.ndarray,
              blocks_per_group: int,
              channel_multiplier: int,
              num_outputs: int,
              train: bool = True,
              true_gradient: bool = False) -> jnp.ndarray:
        """Implements a WideResnet with ShakeShake regularization module.

    Args:
      x: Input to the module. Should have shape [batch_size, dim, dim, 3]
        where dim is the resolution of the image.
      blocks_per_group: How many resnet blocks to add to each group (should be
        4 blocks for a WRN26 as per standard shake shake implementation).
      channel_multiplier: The multiplier to apply to the number of filters in
        the model (1 is classical resnet, 6 for WRN26-2x6, etc...).
      num_outputs: Dimension of the output of the model (ie number of classes
        for a classification problem).
      train: If False, will use the moving average for batch norm statistics.
        Else, will use statistics computed on the batch.
      true_gradient: If true, the same mixing parameter will be used for the
        forward and backward pass (see paper for more details).

    Returns:
      The output of the WideResnet with ShakeShake regularization, a tensor of
      shape [batch_size, num_classes].
    """
        x = nn.Conv(x,
                    16, (3, 3),
                    padding='SAME',
                    kernel_init=utils.conv_kernel_init_fn,
                    bias=False,
                    name='init_conv')
        x = utils.activation(x, apply_relu=False, train=train, name='init_bn')
        x = WideResnetShakeShakeGroup(x,
                                      blocks_per_group,
                                      16 * channel_multiplier,
                                      train=train,
                                      true_gradient=true_gradient)
        x = WideResnetShakeShakeGroup(x,
                                      blocks_per_group,
                                      32 * channel_multiplier, (2, 2),
                                      train=train,
                                      true_gradient=true_gradient)
        x = WideResnetShakeShakeGroup(x,
                                      blocks_per_group,
                                      64 * channel_multiplier, (2, 2),
                                      train=train,
                                      true_gradient=true_gradient)
        x = jax.nn.relu(x)
        x = nn.avg_pool(x, x.shape[1:3])
        x = x.reshape((x.shape[0], -1))
        return nn.Dense(x, num_outputs, kernel_init=utils.dense_layer_init_fn)
Example #25
 def apply(self, inputs, output_dim, kernel_size=3, biases=True):
     output = nn.Conv(inputs,
                      features=output_dim,
                      kernel_size=(kernel_size, kernel_size),
                      strides=(1, 1),
                      padding='SAME',
                      bias=biases)
     output = sum([
         output[:, ::2, ::2, :], output[:, 1::2, ::2, :],
         output[:, ::2, 1::2, :], output[:, 1::2, 1::2, :]
     ]) / 4.
     return output
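The four strided slices averaged above implement 2x2 mean pooling with stride 2 on the convolution output (assuming even spatial dimensions). A small check with concrete numbers:

    out = jnp.arange(16.0).reshape(1, 4, 4, 1)
    pooled = (out[:, ::2, ::2, :] + out[:, 1::2, ::2, :] +
              out[:, ::2, 1::2, :] + out[:, 1::2, 1::2, :]) / 4.
    # Means of the four 2x2 blocks of the 4x4 grid:
    assert jnp.allclose(pooled.squeeze(), jnp.array([[2.5, 4.5], [10.5, 12.5]]))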
Example #26
    def apply(self, x, num_outputs):
        """Define the convolutional network architecture.

    Architecture originates from "Human-level control through deep reinforcement
    learning.", Nature 518, no. 7540 (2015): 529-533.
    Note that this is different than the one from  "Playing atari with deep
    reinforcement learning." arxiv.org/abs/1312.5602 (2013)
    """
        dtype = jnp.float32
        x = x.astype(dtype) / 255.
        x = nn.Conv(x,
                    features=32,
                    kernel_size=(8, 8),
                    strides=(4, 4),
                    name='conv1',
                    dtype=dtype)
        x = nn.relu(x)
        x = nn.Conv(x,
                    features=64,
                    kernel_size=(4, 4),
                    strides=(2, 2),
                    name='conv2',
                    dtype=dtype)
        x = nn.relu(x)
        x = nn.Conv(x,
                    features=64,
                    kernel_size=(3, 3),
                    strides=(1, 1),
                    name='conv3',
                    dtype=dtype)
        x = nn.relu(x)
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(x, features=512, name='hidden', dtype=dtype)
        x = nn.relu(x)
        # Network used to both estimate policy (logits) and expected state value.
        # See github.com/openai/baselines/blob/master/baselines/ppo1/cnn_policy.py
        logits = nn.Dense(x, features=num_outputs, name='logits', dtype=dtype)
        policy_log_probabilities = nn.log_softmax(logits)
        value = nn.Dense(x, features=1, name='value', dtype=dtype)
        return policy_log_probabilities, value
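Since the policy head returns log-probabilities, actions can be sampled from it directly: jax.random.categorical accepts unnormalized log-probabilities, and log_softmax only shifts each row by a constant. A hedged usage sketch (rng, model, and observations are illustrative names, with num_outputs assumed already bound):

    policy_log_probabilities, value = model(observations)
    actions = jax.random.categorical(rng, policy_log_probabilities, axis=-1)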
Example #27
  def apply(self, x, n_layers, n_features, n_kernel_sizes):
    
    x = jnp.expand_dims(x, axis=2)

    for layer in range(n_layers):
      features = n_features[layer]
      kernel_size = (n_kernel_sizes[layer], 1)
      x = nn.Conv(x, features=features, kernel_size=kernel_size)
      x = nn.relu(x)
    
    x = jnp.squeeze(x, axis=2)

    return x
Example #28
    def apply(self,
              x,
              channels,
              strides,
              prob,
              alpha_min,
              alpha_max,
              beta_min,
              beta_max,
              train=True):
        batch_norm = nn.BatchNorm.partial(use_running_average=not train,
                                          momentum=0.9,
                                          epsilon=1e-5)

        y = batch_norm(x, name='bn_1_pre')
        y = nn.Conv(y,
                    channels, (1, 1),
                    padding='SAME',
                    name='1x1_conv_contract')
        y = batch_norm(y, name='bn_1_post')
        y = jax.nn.relu(y)
        y = nn.Conv(y, channels, (3, 3), strides, padding='SAME', name='3x3')
        y = batch_norm(y, name='bn_2')
        y = jax.nn.relu(y)
        y = nn.Conv(y,
                    channels * 4, (1, 1),
                    padding='SAME',
                    name='1x1_conv_expand')
        y = batch_norm(y, name='bn_3')

        if train:
            y = shake.shake_drop_train(y, prob, alpha_min, alpha_max, beta_min,
                                       beta_max)
        else:
            y = shake.shake_drop_eval(y, prob, alpha_min, alpha_max)

        x = shortcut(x, channels * 4, strides)
        return x + y
Example #29
    def apply(self,
              x: jnp.ndarray,
              blocks_per_group: int,
              channel_multiplier: int,
              num_outputs: int,
              train: bool = True) -> jnp.ndarray:
        """Implements a WideResnet module.

    Args:
      x: Input to the module. Should have shape [batch_size, dim, dim, 3]
        where dim is the resolution of the image.
      blocks_per_group: How many resnet blocks to add to each group (should be
        4 blocks for a WRN28, and 6 for a WRN40).
      channel_multiplier: The multiplier to apply to the number of filters in
        the model (1 is classical resnet, 10 for WRN28-10, etc...).
      num_outputs: Dimension of the output of the model (ie number of classes
        for a classification problem).
      train: If False, will use the moving average for batch norm statistics.

    Returns:
      The output of the WideResnet, a tensor of shape [batch_size, num_classes].
    """
        first_x = x
        x = nn.Conv(x,
                    16, (3, 3),
                    padding='SAME',
                    name='init_conv',
                    kernel_init=utils.conv_kernel_init_fn,
                    bias=False)
        x = WideResnetGroup(x,
                            blocks_per_group,
                            16 * channel_multiplier,
                            activate_before_residual=True,
                            train=train)
        x = WideResnetGroup(x,
                            blocks_per_group,
                            32 * channel_multiplier, (2, 2),
                            train=train)
        x = WideResnetGroup(x,
                            blocks_per_group,
                            64 * channel_multiplier, (2, 2),
                            train=train)
        if FLAGS.use_additional_skip_connections:
            x = _output_add(x, first_x)
        x = utils.activation(x, train=train, name='pre-pool-bn')
        x = nn.avg_pool(x, x.shape[1:3])
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(x, num_outputs, kernel_init=utils.dense_layer_init_fn)
        return x
Example #30
 def apply(self,
           x,
           num_outputs,
           num_filters=64,
           num_layers=50,
           train=True,
           batch_stats=None,
           dtype=jnp.float32,
           batch_norm_momentum=0.9,
           batch_norm_epsilon=1e-5,
           virtual_batch_size=None,
           data_format=None):
     if num_layers not in _block_size_options:
         raise ValueError('Please provide a valid number of layers')
     block_sizes = _block_size_options[num_layers]
     x = nn.Conv(x,
                 num_filters, (3, 3), (1, 1),
                 'SAME',
                 bias=False,
                 dtype=dtype,
                 name='init_conv')
     x = normalization.VirtualBatchNorm(
         x,
         batch_stats=batch_stats,
         use_running_average=not train,
         momentum=batch_norm_momentum,
         epsilon=batch_norm_epsilon,
         dtype=dtype,
         name='init_bn',
         virtual_batch_size=virtual_batch_size,
         data_format=data_format)
     x = nn.relu(x)
     residual_block = block_type_options[num_layers]
     for i, block_size in enumerate(block_sizes):
         for j in range(block_size):
             strides = (2, 2) if i > 0 and j == 0 else (1, 1)
             x = residual_block(x,
                                num_filters * 2**i,
                                strides=strides,
                                train=train,
                                batch_stats=batch_stats,
                                dtype=dtype,
                                batch_norm_momentum=batch_norm_momentum,
                                batch_norm_epsilon=batch_norm_epsilon,
                                virtual_batch_size=virtual_batch_size,
                                data_format=data_format)
     x = jnp.mean(x, axis=(1, 2))
     x = nn.Dense(x, num_outputs, dtype=dtype)
     return x