Example #1
 def apply(self, x):
     net = nn.Dense(x, 500, name='fc1')
     net = nn.leaky_relu(net)
     net = nn.BatchNorm(net)
     net = nn.Dense(net, 500, name='fc2')
     net = nn.leaky_relu(net)
     net = nn.BatchNorm(net)
     net = nn.Dense(net, 500, name='fc3')
     net = nn.leaky_relu(net)
     net = nn.BatchNorm(net)
     return nn.softmax(nn.Dense(net, n_bin))
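These apply-style snippets target the deprecated pre-Linen Flax module API (flax.nn, later moved to flax.deprecated.nn), where a network is an apply method on an nn.Module subclass. Below is a minimal, hedged sketch of how Example #1 could be wrapped and initialized under that API; the class name MLP, the 784-feature input shape, and the value of n_bin are assumptions made only for this illustration.

import jax
import jax.numpy as jnp
from flax import nn  # pre-Linen API; available as flax.deprecated.nn in later releases

n_bin = 10  # hypothetical number of output bins, chosen only for this sketch


class MLP(nn.Module):
    """Example #1 wrapped into a module class."""

    def apply(self, x):
        net = nn.Dense(x, 500, name='fc1')
        net = nn.leaky_relu(net)
        net = nn.BatchNorm(net)
        net = nn.Dense(net, 500, name='fc2')
        net = nn.leaky_relu(net)
        net = nn.BatchNorm(net)
        net = nn.Dense(net, 500, name='fc3')
        net = nn.leaky_relu(net)
        net = nn.BatchNorm(net)
        return nn.softmax(nn.Dense(net, n_bin))


rng = jax.random.PRNGKey(0)
# nn.stateful collects the BatchNorm running statistics as mutable state.
with nn.stateful() as init_state:
    _, params = MLP.init_by_shape(rng, [((1, 784), jnp.float32)])
model = nn.Model(MLP, params)

# Forward pass; updated batch statistics come back through the context.
with nn.stateful(init_state) as new_state:
    probs = model(jnp.ones((8, 784)))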
Example #2
    def apply(self,
              x,
              blocks_per_group,
              channel_multiplier,
              num_outputs,
              dropout_rate=0.0,
              train=True):

        x = nn.Conv(x, 16, (3, 3), padding='SAME', name='init_conv')
        x = WideResnetGroup(x,
                            blocks_per_group,
                            16 * channel_multiplier,
                            dropout_rate=dropout_rate,
                            train=train)
        x = WideResnetGroup(x,
                            blocks_per_group,
                            32 * channel_multiplier, (2, 2),
                            dropout_rate=dropout_rate,
                            train=train)
        x = WideResnetGroup(x,
                            blocks_per_group,
                            64 * channel_multiplier, (2, 2),
                            dropout_rate=dropout_rate,
                            train=train)
        x = nn.BatchNorm(x,
                         use_running_average=not train,
                         momentum=0.9,
                         epsilon=1e-5)
        x = jax.nn.relu(x)
        x = nn.avg_pool(x, (8, 8))
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(x, num_outputs)
        return x
Example #3
    def apply(self,
              x,
              num_filters=64,
              block_sizes=(3, 4, 6, 3),
              train=True,
              block=BottleneckBlock,
              small_inputs=False):
        if small_inputs:
            x = nn.Conv(x,
                        num_filters,
                        kernel_size=(3, 3),
                        strides=(1, 1),
                        bias=False,
                        name="init_conv")
        else:
            x = nn.Conv(x,
                        num_filters,
                        kernel_size=(7, 7),
                        strides=(2, 2),
                        bias=False,
                        name="init_conv")
        x = nn.BatchNorm(x,
                         use_running_average=not train,
                         epsilon=1e-5,
                         name="init_bn")
        if not small_inputs:
            x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME")
        for i, block_size in enumerate(block_sizes):
            for j in range(block_size):
                strides = (2, 2) if i > 0 and j == 0 else (1, 1)
                x = block(x, num_filters * 2**i, strides=strides, train=train)

        return x
Example #4
 def apply(
         self,
         x,
         num_outputs,
         num_filters=64,
         block_sizes=[3, 4, 6, 3],  # pylint: disable=dangerous-default-value
         train=True):
     x = nn.Conv(x,
                 num_filters, (7, 7), (2, 2),
                 bias=False,
                 name='init_conv')
     x = nn.BatchNorm(x,
                      use_running_average=not train,
                      epsilon=1e-5,
                      name='init_bn')
     x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
     for i, block_size in enumerate(block_sizes):
         for j in range(block_size):
             strides = (2, 2) if i > 0 and j == 0 else (1, 1)
             x = BottleneckBlock(x,
                                 num_filters * 2**i,
                                 strides=strides,
                                 train=train)
     x = jnp.mean(x, axis=(1, 2))
     x_clf = nn.Dense(x, num_outputs, name='clf')
     # We return both the outputs from the dense layer *and* the features
     # that go into it.
     return x_clf, x
Example #5
 def apply(self,
           x,
           num_classes,
           num_filters=64,
           num_layers=50,
           train=True,
           dtype=jnp.float32):
     if num_layers not in _block_size_options:
         raise ValueError('Please provide a valid number of layers')
     block_sizes = _block_size_options[num_layers]
     x = nn.Conv(x,
                 num_filters, (7, 7), (2, 2),
                 padding=[(3, 3), (3, 3)],
                 bias=False,
                 dtype=dtype,
                 name='init_conv')
     x = nn.BatchNorm(x,
                      use_running_average=not train,
                      momentum=0.9,
                      epsilon=1e-5,
                      dtype=dtype,
                      name='init_bn')
     x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
     for i, block_size in enumerate(block_sizes):
         for j in range(block_size):
             strides = (2, 2) if i > 0 and j == 0 else (1, 1)
             x = ResidualBlock(x,
                               num_filters * 2**i,
                               strides=strides,
                               train=train,
                               dtype=dtype)
     x = jnp.mean(x, axis=(1, 2))
     x = nn.Dense(x, num_classes)
     x = nn.log_softmax(x)
     return x
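The _block_size_options lookup referenced above is not part of the snippet. A plausible definition, assuming the standard ResNet depth-to-blocks-per-stage mapping, would be:

# Assumed mapping from ResNet depth to blocks per stage; the snippet only
# requires that the chosen num_layers key exists in this dict.
_block_size_options = {
    18: [2, 2, 2, 2],
    34: [3, 4, 6, 3],
    50: [3, 4, 6, 3],
    101: [3, 4, 23, 3],
    152: [3, 8, 36, 3],
    200: [3, 24, 36, 3],
}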
Example #6
def conv2d(inputs: jnp.ndarray,
           conv_filters: Optional[int],
           config: ModelConfig,
           kernel_size: Union[int, Tuple[int, int]] = (1, 1),
           strides: Tuple[int, int] = (1, 1),
           use_batch_norm: bool = True,
           use_bias: bool = False,
           activation: Any = None,
           depthwise: bool = False,
           train: bool = True,
           conv_name: Optional[str] = None,
           bn_name: Optional[str] = None) -> jnp.ndarray:
  """Convolutional layer with possibly batch norm and activation.

  Args:
    inputs: Input data with dimensions (batch, spatial_dims..., features).
    conv_filters: Number of convolution filters.
    config: Configuration for the model.
    kernel_size: Size of the kernel, as an int or a tuple of two ints.
    strides: Strides for the convolution, as a tuple of int.
    use_batch_norm: Whether batch norm should be applied to the output.
    use_bias: Whether to add a bias to the output of the convolution.
    activation: Name of the activation function to use.
    depthwise: If true, will use depthwise convolutions.
    train: Whether the model should behave in training or inference mode.
    conv_name: Name to give to the convolution layer.
    bn_name: Name to give to the batch norm layer.

  Returns:
    The output of the convolutional layer.
  """
  conv_fn = DepthwiseConv if depthwise else flax.nn.Conv
  kernel_size = ((kernel_size, kernel_size)
                 if isinstance(kernel_size, int) else tuple(kernel_size))
  conv_name = conv_name if conv_name else 'conv2d'
  bn_name = bn_name if bn_name else 'batch_normalization'

  x = conv_fn(
      inputs,
      conv_filters,
      kernel_size,
      tuple(strides),
      padding='SAME',
      bias=use_bias,
      kernel_init=conv_kernel_init_fn,
      name=conv_name)

  if use_batch_norm:
    x = nn.BatchNorm(
        x,
        use_running_average=not train or FLAGS.from_pretrained_checkpoint,
        momentum=config.bn_momentum,
        epsilon=config.bn_epsilon,
        name=bn_name,
        axis_name='batch')

  if activation is not None:
    x = getattr(flax.nn.activation, activation.lower())(x)
  return x
Example #7
 def apply(self, x):
     b = x.shape[0]
     x = nn.Conv(x, features=128, kernel_size=(4, ), padding='SAME')
     x = nn.BatchNorm(x)
     x = nn.leaky_relu(x)
     x = nn.avg_pool(x, window_shape=(2, ), padding='SAME')
     x = nn.Conv(x, features=256, kernel_size=(4, ), padding='SAME')
     x = nn.BatchNorm(x)
     x = nn.leaky_relu(x)
     x = nn.avg_pool(x, window_shape=(2, ), padding='SAME')
     x = x.reshape(b, -1)
     x = nn.Dense(x, features=128)
     x = nn.BatchNorm(x)
     x = nn.leaky_relu(x)
     x = nn.Dense(x, features=n_bins)
     x = nn.softmax(x)
     return x
Example #8
    def apply(self, g, x, in_feats, hidden_feats, out_feats, num_layers, dropout):
        with nn.stochastic(jax.random.PRNGKey(0)):  # fixed PRNG key: dropout masks repeat across calls
            x = SAGEConv(g, x, in_feats, hidden_feats)

            for idx in range(num_layers-2):
                x = SAGEConv(g, x, hidden_feats, hidden_feats)
                x = nn.BatchNorm(x)
                x = nn.dropout(x, rate=dropout)

            x = SAGEConv(g, x, hidden_feats, out_feats)

        return jax.nn.log_softmax(x, axis=-1)
Example #9
 def apply(self,
           x,
           num_classes,
           num_filters=64,
           num_layers=50,
           train=True,
           axis_name=None,
           axis_index_groups=None,
           dtype=jnp.float32,
           conv0_space_to_depth=False):
     if num_layers not in _block_size_options:
         raise ValueError('Please provide a valid number of layers')
     block_sizes = _block_size_options[num_layers]
     if conv0_space_to_depth:
         conv = SpaceToDepthConv.partial(block_size=(2, 2),
                                         padding=[(2, 1), (2, 1)])
     else:
         conv = nn.Conv.partial(padding=[(3, 3), (3, 3)])
     x = conv(x,
              num_filters,
              kernel_size=(7, 7),
              strides=(2, 2),
              bias=False,
              dtype=dtype,
              name='conv0')
     x = nn.BatchNorm(x,
                      use_running_average=not train,
                      momentum=0.9,
                      epsilon=1e-5,
                      name='init_bn',
                      axis_name=axis_name,
                      axis_index_groups=axis_index_groups,
                      dtype=dtype)
     x = nn.relu(x)  # MLPerf-required
     x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
     for i, block_size in enumerate(block_sizes):
         for j in range(block_size):
             strides = (2, 2) if i > 0 and j == 0 else (1, 1)
             x = ResidualBlock(x,
                               num_filters * 2**i,
                               strides=strides,
                               train=train,
                               axis_name=axis_name,
                               axis_index_groups=axis_index_groups,
                               dtype=dtype)
     x = jnp.mean(x, axis=(1, 2))
     x = nn.Dense(x,
                  num_classes,
                  kernel_init=nn.initializers.normal(),
                  dtype=dtype)
     x = nn.log_softmax(x)
     return x
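The conv0_space_to_depth path relies on a space-to-depth rearrangement of the input images (an MLPerf trick that makes the stem convolution TPU-friendly). Purely for reference, a self-contained sketch of that rearrangement for NHWC inputs with a 2x2 block follows; the actual SpaceToDepthConv module and input pipeline may apply it differently.

import jax.numpy as jnp


def space_to_depth(x: jnp.ndarray, block: int = 2) -> jnp.ndarray:
    """Rearranges (N, H, W, C) into (N, H // block, W // block, C * block**2)."""
    n, h, w, c = x.shape
    x = x.reshape(n, h // block, block, w // block, block, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(n, h // block, w // block, c * block * block)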
Example #10
  def apply(self, x, num_classes, parameters, num_filters=64,
            train=True, axis_name=None, num_layers='34'):
    block_sizes = [3, 4, 6]
    data_format = 'channels_last'
    if ('conv0_space_to_depth' in parameters and
        parameters['conv0_space_to_depth']):
      # conv0 uses space-to-depth transform for TPU performance.
      x = func_conv0_space_to_depth(inputs=x, data_format=data_format,
                                    dtype=parameters['dtype'])
    else:
      x = conv2d_fixed_padding(
          inputs=x,
          filters=num_filters,
          kernel_size=7,
          strides=2,
          data_format=data_format,
          name='init_conv')

    replica_groups = _make_replica_groups(parameters)
    x = nn.BatchNorm(
        x,
        use_running_average=not train,
        momentum=0.9,
        epsilon=1e-5,
        name='init_bn',
        axis_name=axis_name,
        dtype=parameters['dtype'],
        axis_index_groups=replica_groups)
    x = nn.relu(x)
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
    for i, block_size in enumerate(block_sizes):
      for j in range(block_size):
        strides = 2 if i == 1 and j == 0 else 1
        use_projection = i > 0 and j == 0  # project only on the first block of groups after the first
        x = ResidualBlock(
            x,
            num_filters * 2**i,
            parameters,
            strides=strides,
            train=train,
            axis_name=axis_name,
            use_projection=use_projection,
            data_format=data_format)
    if num_layers == '34':
      x = jnp.mean(x, axis=(1, 2))
      x = nn.Dense(x, num_classes, kernel_init=nn.initializers.normal(),
                   dtype=jnp.float32)  # TODO(deveci): dtype=dtype
      x = nn.log_softmax(x)
    return x
Example #11
 def apply(self,
           x,
           num_classes,
           train=True,
           batch_stats=None,
           axis_name=None,
           dtype=jnp.float32):
   x = nn.BatchNorm(x,
                    batch_stats=batch_stats,
                    use_running_average=not train,
                    momentum=0.9, epsilon=1e-5,
                    name='init_bn', axis_name=axis_name, dtype=dtype)
   x = jnp.mean(x, axis=(1, 2))
   x = nn.Dense(x, num_classes, kernel_init=nn.initializers.normal(),
                dtype=dtype)
   x = nn.log_softmax(x)
   return x
Example #12
 def apply(self, x, num_outputs, train=True):
   x = nn.Conv(x, self.NUM_FILTERS, (7, 7), (2, 2), bias=False,
               name='init_conv')
   x = nn.BatchNorm(x,
                    use_running_average=not train,
                    momentum=0.9, epsilon=1e-5,
                    name='init_bn')
   x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
   for i, block_size in enumerate(self.BLOCK_SIZES):
     for j in range(block_size):
       strides = (2, 2) if i > 0 and j == 0 else (1, 1)
       x = BottleneckBlock(x, self.NUM_FILTERS * 2 ** i,
                           strides=strides, groups=self.GROUPS,
                           base_width=self.WIDTH_PER_GROUP, train=train)
   x = jnp.mean(x, axis=(1, 2))
   x = nn.Dense(x, num_outputs, name='clf')
   return x
Example #13
 def apply(
     self,
     x: Array,
     upsampling_factor: int,
     max_num_features: int,
 ) -> Array:
     activation = nn.silu
     num_upsampling_layers = np.log2(upsampling_factor).astype(int)
     resizes = [1] + [2] * num_upsampling_layers
     feature_sizes = [
         max(max_num_features // 2**n, 2) for n in range(len(resizes))
     ]
     for fs, rs in zip(feature_sizes, resizes):
         x = Upsample2D(x, factor=rs)
         x = PeriodicSpaceConv(x, fs, kernel_size=(3, 3, 3))
         x = nn.BatchNorm(x)
         x = activation(x)
     x = PeriodicSpaceConv(x, features=2, kernel_size=(3, 3, 3))
     return x
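Upsample2D and PeriodicSpaceConv are project-specific modules that are not shown here. Purely as an illustration of the upsampling step, and assuming a batched NHWC input with nearest-neighbour interpolation (the real module may differ in input rank and boundary handling), an Upsample2D-style helper could look like:

import jax
import jax.numpy as jnp


def upsample_2d(x: jnp.ndarray, factor: int = 2) -> jnp.ndarray:
    """Nearest-neighbour upsampling of the two spatial axes of an NHWC array."""
    if factor == 1:
        return x
    n, h, w, c = x.shape
    return jax.image.resize(x, (n, h * factor, w * factor, c), method='nearest')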
Example #14
    def apply(self,
              x,
              num_outputs,
              pyramid_alpha=200,
              pyramid_depth=272,
              train=True):
        assert (pyramid_depth - 2) % 9 == 0

        # Shake-drop hyper-params
        mask_prob = 0.5
        alpha_min, alpha_max = (-1.0, 1.0)
        beta_min, beta_max = (0.0, 1.0)

        # Bottleneck network size
        blocks_per_group = (pyramid_depth - 2) // 9
        # See Eqn 2 in https://arxiv.org/abs/1610.02915
        num_channels = 16
        # N in https://arxiv.org/abs/1610.02915
        total_blocks = blocks_per_group * 3
        delta_channels = pyramid_alpha / total_blocks

        x = nn.Conv(x, 16, (3, 3), padding='SAME', name='init_conv')
        x = nn.BatchNorm(x,
                         use_running_average=not train,
                         momentum=0.9,
                         epsilon=1e-5,
                         name='init_bn')

        layer_num = 1

        for block_i in range(blocks_per_group):
            num_channels += delta_channels
            layer_mask_prob = _calc_shakedrop_mask_prob(
                layer_num, total_blocks, mask_prob)
            x = BottleneckShakeDrop(x,
                                    int(num_channels), (1, 1),
                                    layer_mask_prob,
                                    alpha_min,
                                    alpha_max,
                                    beta_min,
                                    beta_max,
                                    train=train)
            layer_num += 1

        for block_i in range(blocks_per_group):
            num_channels += delta_channels
            layer_mask_prob = _calc_shakedrop_mask_prob(
                layer_num, total_blocks, mask_prob)
            x = BottleneckShakeDrop(x,
                                    int(num_channels),
                                    ((2, 2) if block_i == 0 else (1, 1)),
                                    layer_mask_prob,
                                    alpha_min,
                                    alpha_max,
                                    beta_min,
                                    beta_max,
                                    train=train)
            layer_num += 1

        for block_i in range(blocks_per_group):
            num_channels += delta_channels
            layer_mask_prob = _calc_shakedrop_mask_prob(
                layer_num, total_blocks, mask_prob)
            x = BottleneckShakeDrop(x,
                                    int(num_channels),
                                    ((2, 2) if block_i == 0 else (1, 1)),
                                    layer_mask_prob,
                                    alpha_min,
                                    alpha_max,
                                    beta_min,
                                    beta_max,
                                    train=train)
            layer_num += 1

        assert layer_num - 1 == total_blocks
        x = nn.BatchNorm(x,
                         use_running_average=not train,
                         momentum=0.9,
                         epsilon=1e-5,
                         name='final_bn')
        x = jax.nn.relu(x)
        x = nn.avg_pool(x, (8, 8))
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(x, num_outputs)
        return x
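_calc_shakedrop_mask_prob is not included in the snippet. One plausible definition, assuming the linearly decaying rule used by stochastic depth and ShakeDrop (deeper blocks are regularized more aggressively), is sketched below; the exact formula in the original codebase may differ.

def _calc_shakedrop_mask_prob(layer_num: int, total_blocks: int,
                              mask_prob: float) -> float:
    """Assumed linear decay from ~1.0 at the first block to mask_prob at the last."""
    return 1.0 - (float(layer_num) / total_blocks) * (1.0 - mask_prob)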