Example #1
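    # Inception V3 forward pass: optional input re-normalization, stem convs with
    # 3x3 max pooling, Inception A-E blocks, an optional auxiliary logits head,
    # 8x8 average pooling and dropout.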
    def __call__(self, x, train=False):
        conv_block = partial(BasicConv, train=train, dtype=self.dtype)
        inception_a = partial(InceptionA, conv_block=conv_block)
        inception_b = partial(InceptionB, conv_block=conv_block)
        inception_c = partial(InceptionC, conv_block=conv_block)
        inception_d = partial(InceptionD, conv_block=conv_block)
        inception_e = partial(InceptionE, conv_block=conv_block)
        inception_aux = partial(InceptionAux, conv_block=conv_block)

        if self.transform_input:
            x = jnp.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW (jnp so this traces under jit)
            x_ch0 = jnp.expand_dims(x[:, 0],
                                    1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x_ch1 = jnp.expand_dims(x[:, 1],
                                    1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x_ch2 = jnp.expand_dims(x[:, 2],
                                    1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
            x = jnp.concatenate((x_ch0, x_ch1, x_ch2), 1)
            x = jnp.transpose(x, (0, 2, 3, 1))  # NCHW -> NHWC

        x = conv_block(32,
                       kernel_size=(3, 3),
                       strides=(2, 2),
                       name='Conv2d_1a_3x3')(x)
        x = conv_block(32, kernel_size=(3, 3), name='Conv2d_2a_3x3')(x)
        x = conv_block(64,
                       kernel_size=(3, 3),
                       padding=[(1, 1), (1, 1)],
                       name='Conv2d_2b_3x3')(x)
        x = nn.max_pool(x, (3, 3), strides=(2, 2))
        x = conv_block(80, kernel_size=(1, 1), name='Conv2d_3b_1x1')(x)
        x = conv_block(192, kernel_size=(3, 3), name='Conv2d_4a_3x3')(x)
        x = nn.max_pool(x, (3, 3), strides=(2, 2))

        x = inception_a(pool_features=32, name='Mixed_5b')(x)
        x = inception_a(pool_features=64, name='Mixed_5c')(x)
        x = inception_a(pool_features=64, name='Mixed_5d')(x)
        x = inception_b(name='Mixed_6a')(x)
        x = inception_c(channels_7x7=128, name='Mixed_6b')(x)
        x = inception_c(channels_7x7=160, name='Mixed_6c')(x)
        x = inception_c(channels_7x7=160, name='Mixed_6d')(x)
        x = inception_c(channels_7x7=192, name='Mixed_6e')(x)

        aux = inception_aux(self.num_classes, name='AuxLogits')(x) \
              if train and self.aux_logits else None

        x = inception_d(name='Mixed_7a')(x)
        x = inception_e(name='Mixed_7b')(x)
        x = inception_e(name='Mixed_7c')(x)
        x = nn.avg_pool(x, (8, 8))
        x = nn.Dropout(0.5)(x, deterministic=not train)

        return x, aux
Example #2
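    # Convolutional stem for a patch embedding: Conv + BatchNorm + max pooling,
    # then patches are flattened with einops rearrange and projected by a Dense layer.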
    def __call__(self, inputs, is_training: bool):

        x = nn.Conv(features=self.num_ch,
                    use_bias=self.use_bias,
                    kernel_size=(self.conv_kernel_size, self.conv_kernel_size),
                    strides=(self.conv_stride, self.conv_stride),
                    padding=[(self.patch_shape[0], ) * 2,
                             (self.patch_shape[1], ) * 2])(inputs)
        x = nn.BatchNorm(use_running_average=not is_training,
                         momentum=self.bn_momentum,
                         epsilon=self.bn_epsilon,
                         dtype=self.dtype)(x)
        x = nn.max_pool(
            inputs=x,
            window_shape=(self.pool_window_size, ) * 2,
            strides=(self.pool_stride, ) * 2,
        )
        x = rearrange(
            x,
            'b (h ph) (w pw) c -> b (h w) (ph pw c)',
            ph=self.patch_shape[0],
            pw=self.patch_shape[1],
        )

        output = nn.Dense(features=self.embed_dim,
                          use_bias=self.use_bias,
                          dtype=self.dtype,
                          precision=self.precision,
                          kernel_init=self.kernel_init,
                          bias_init=self.bias_init)(x)
        return output
Example #3
 def setup(self):
     self.maxpool_conv = Sequential([
         lambda x: nn.max_pool(x, (2, 2), (2, 2)),
         DoubleConv(self.in_channels, self.out_channels, self.out_channels,
                    self.test, self.group_norm, self.num_groups,
                    self.activation),
     ])
Example #4
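    # Conv stack: initial 3x3 conv, optional strided max pooling, then num_blocks
    # residual blocks of two ReLU + Conv layers with a skip connection.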
    def __call__(self, x):
        initializer = nn.initializers.xavier_uniform()
        conv_out = nn.Conv(features=self.num_ch,
                           kernel_size=(3, 3),
                           strides=1,
                           kernel_init=initializer,
                           padding='SAME')(x)
        if self.use_max_pooling:
            conv_out = nn.max_pool(conv_out,
                                   window_shape=(3, 3),
                                   padding='SAME',
                                   strides=(2, 2))

        for _ in range(self.num_blocks):
            block_input = conv_out
            conv_out = nn.relu(conv_out)
            conv_out = nn.Conv(features=self.num_ch,
                               kernel_size=(3, 3),
                               strides=1,
                               padding='SAME')(conv_out)
            conv_out = nn.relu(conv_out)
            conv_out = nn.Conv(features=self.num_ch,
                               kernel_size=(3, 3),
                               strides=1,
                               padding='SAME')(conv_out)
            conv_out += block_input

        return conv_out
Example #5
    def __call__(self, x):
        branch3x3 = self.conv_block(192,
                                    kernel_size=(1, 1),
                                    name='branch3x3_1')(x)
        branch3x3 = self.conv_block(320,
                                    kernel_size=(3, 3),
                                    strides=(2, 2),
                                    name='branch3x3_2')(branch3x3)

        branch7x7x3 = self.conv_block(192,
                                      kernel_size=(1, 1),
                                      name='branch7x7x3_1')(x)
        branch7x7x3 = self.conv_block(192,
                                      kernel_size=(1, 7),
                                      padding=[(0, 0), (3, 3)],
                                      name='branch7x7x3_2')(branch7x7x3)
        branch7x7x3 = self.conv_block(192,
                                      kernel_size=(7, 1),
                                      padding=[(3, 3), (0, 0)],
                                      name='branch7x7x3_3')(branch7x7x3)
        branch7x7x3 = self.conv_block(192,
                                      kernel_size=(3, 3),
                                      strides=(2, 2),
                                      name='branch7x7x3_4')(branch7x7x3)

        branch_pool = nn.max_pool(x, (3, 3), strides=(2, 2))

        outputs = [branch3x3, branch7x7x3, branch_pool]
        return jnp.concatenate(outputs, 3)
Example #6
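    # Standard ResNet: 7x7/stride-2 stem conv, BatchNorm, ReLU, 3x3 max pooling,
    # staged residual blocks, global average pooling, Dense head and log-softmax.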
    def __call__(self, x, train: bool = True):
        conv = partial(nn.Conv, use_bias=False, dtype=self.dtype)
        norm = partial(nn.BatchNorm,
                       use_running_average=not train,
                       momentum=0.9,
                       epsilon=1e-5,
                       dtype=self.dtype)

        x = conv(self.num_filters, (7, 7), (2, 2),
                 padding=[(3, 3), (3, 3)],
                 name='conv_init')(x)
        x = norm(name='bn_init')(x)
        x = nn.relu(x)
        x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
        for i, block_size in enumerate(self.stage_sizes):
            for j in range(block_size):
                strides = (2, 2) if i > 0 and j == 0 else (1, 1)
                x = self.block_cls(self.num_filters * 2**i,
                                   strides=strides,
                                   conv=conv,
                                   norm=norm,
                                   act=self.act)(x)
        x = jnp.mean(x, axis=(1, 2))
        x = nn.Dense(self.num_classes, dtype=self.dtype)(x)
        x = jnp.asarray(x, self.dtype)
        x = nn.log_softmax(x)
        return x
Example #7
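 # Configurable CNN: alternating Conv / activation / normalization / max-pool stages,
 # followed by Dense layers and a final linear output layer.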
 def __call__(self, x, train):
     maybe_normalize = model_utils.get_normalizer(self.normalizer, train)
     iterator = zip(self.num_filters, self.kernel_sizes,
                    self.kernel_paddings, self.window_sizes,
                    self.window_paddings, self.strides)
     for num_filters, kernel_size, kernel_padding, window_size, window_padding, stride in iterator:
         x = nn.Conv(num_filters, (kernel_size, kernel_size), (1, 1),
                     padding=kernel_padding,
                     kernel_init=self.kernel_init,
                     bias_init=self.bias_init)(x)
         x = model_utils.ACTIVATIONS[self.activation_fn](x)
         x = maybe_normalize()(x)
         x = nn.max_pool(x,
                         window_shape=(window_size, window_size),
                         strides=(stride, stride),
                         padding=window_padding)
     x = jnp.reshape(x, (x.shape[0], -1))
     for num_units in self.num_dense_units:
         x = nn.Dense(num_units,
                      kernel_init=self.kernel_init,
                      bias_init=self.bias_init)(x)
         x = model_utils.ACTIVATIONS[self.activation_fn](x)
         x = maybe_normalize()(x)
     x = nn.Dense(self.num_outputs,
                  kernel_init=self.kernel_init,
                  bias_init=self.bias_init)(x)
     return x
Example #8
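    # Quantization-aware (AQT) ResNet: quantized init conv, BatchNorm, max pooling,
    # residual blocks driven by per-block hparams, global average pooling, quantized Dense head.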
    def __call__(
        self,
        inputs,
    ):
        """Applies ResNet model. Number of residual blocks inferred from hparams."""
        num_classes = self.num_classes
        hparams = self.hparams
        num_filters = self.num_filters
        dtype = self.dtype

        x = aqt_flax_layers.ConvAqt(
            features=num_filters,
            kernel_size=(7, 7),
            strides=(2, 2),
            padding=[(3, 3), (3, 3)],
            use_bias=False,
            dtype=dtype,
            name='init_conv',
            train=self.train,
            quant_context=self.quant_context,
            paxis_name='batch',
            hparams=hparams.conv_init,
        )(inputs)
        x = nn.BatchNorm(use_running_average=not self.train,
                         momentum=0.9,
                         epsilon=1e-5,
                         dtype=dtype,
                         name='init_bn')(x)
        x = nn.relu(x)
        x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
        filter_multiplier = hparams.filter_multiplier
        for i, block_hparams in enumerate(hparams.residual_blocks):
            proj = block_hparams.conv_proj
            # For projection layers (unless it is the first layer), strides = (2, 2)
            if i > 0 and proj is not None:
                filter_multiplier *= 2
                strides = (2, 2)
            else:
                strides = (1, 1)
            x = ResidualBlock(filters=int(num_filters * filter_multiplier),
                              hparams=block_hparams,
                              quant_context=self.quant_context,
                              strides=strides,
                              train=self.train,
                              dtype=dtype)(x)
        x = jnp.mean(x, axis=(1, 2))

        x = aqt_flax_layers.DenseAqt(
            features=num_classes,
            dtype=dtype,
            train=self.train,
            quant_context=self.quant_context,
            paxis_name='batch',
            hparams=hparams.dense_layer,
        )(x, padding_mask=None)

        x = jnp.asarray(x, dtype)
        output = nn.log_softmax(x)
        return output
Example #9
    def __call__(self, inputs, train: bool = True):
        """Passes the input through the network.
        Arguments:
            inputs:     [batch_size, height, width, channels]
            train:      bool
        Returns:
            output:     [batch_size, config.num_classes]
        """
        conv = partial(nn.Conv,
                       use_bias=False,
                       dtype=self.dtype,
                       precision=self.precision,
                       kernel_init=self.kernel_init)
        norm = partial(nn.BatchNorm,
                       use_running_average=not train,
                       momentum=self.bn_momentum,
                       epsilon=self.bn_epsilon,
                       dtype=self.dtype)

        y = conv(self.initial_filters,
                 kernel_size=(7, 7),
                 strides=(2, 2),
                 padding=[(3, 3), (3, 3)])(inputs)
        y = norm()(y)
        y = self.activation_fn(y)
        y = nn.max_pool(y, (3, 3), strides=(2, 2), padding='SAME')
        for i, block_size in enumerate(self.stage_sizes[:-1]):
            for j in range(block_size):
                strides = (2, 2) if i > 0 and j == 0 else (1, 1)
                y = BottleneckResNetBlock(
                    filters=self.initial_filters * 2**i,
                    strides=strides,
                    conv=conv,
                    norm=norm,
                    se_ratio=self.se_ratio,
                    projection_factor=self.projection_factor,
                    activation_fn=self.activation_fn,
                    dtype=self.dtype,
                )(y)
        for j in range(self.stage_sizes[-1]):
            strides = (2, 2) if j == 0 and self.stride_one is False else (1, 1)
            y = BoTBlock(filters=self.initial_filters * 2**(i + 1),
                         strides=strides,
                         conv=conv,
                         norm=norm,
                         projection_factor=self.projection_factor,
                         activation_fn=self.activation_fn)(y)
        y = jnp.mean(y, axis=(1, 2))
        y = nn.Dense(self.num_classes,
                     dtype=self.dtype,
                     kernel_init=self.kernel_init,
                     bias_init=self.bias_init)(y)
        y = jnp.asarray(y, dtype=self.dtype)
        return y
Example #10
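    # Small CNN: Conv / max-pool / ReLU feature stages, then an MLP head.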
    def __call__(self, x):
        for feat in self.conv_features:
            x = nn.Conv(feat, kernel_size=(3, 3))(x)
            x = nn.max_pool(x, window_shape=(2, 2))
            x = nn.relu(x)

        x = x.reshape((x.shape[0], -1))
        for feat in self.mlp_features[:-1]:
            x = nn.relu(nn.Dense(feat)(x))
        x = nn.Dense(self.mlp_features[-1])(x)
        return x
Example #11
 def __call__(self, x):
     x = self.act(x)
     path = x
     for _ in range(self.n_stages):
         path = nn.max_pool(path,
                            window_shape=(5, 5),
                            strides=(1, 1),
                            padding='SAME')
         path = ncsn_conv3x3(path, self.features, stride=1, bias=False)
         x = path + x
     return x
Example #12
 def __call__(self, x, train=False):
     for v in self.cfg:
         if v == 'M':
             x = nn.max_pool(x, (2, 2), (2, 2))
         else:
             x = nn.Conv(v, (3, 3), padding='SAME', dtype=self.dtype)(x)
             if self.batch_norm:
                 x = nn.BatchNorm(use_running_average=not train,
                                  momentum=0.1,
                                  dtype=self.dtype)(x)
             x = nn.relu(x)
     return x
Example #13
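    # Three-block CNN classifier: Conv + ReLU + 2x2 max pooling per block,
    # then a Dense hidden layer and a softmax output over NB_CLASSES.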
    def __call__(self, x):
        x = nn.Conv(features=16, kernel_size=(3, 3), padding='SAME')(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))

        x = nn.Conv(features=32, kernel_size=(3, 3), padding='SAME')(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))

        x = nn.Conv(features=64, kernel_size=(3, 3), padding='SAME')(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))

        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=128)(x)
        x = nn.relu(x)

        x = nn.Dense(features=NB_CLASSES)(x)
        x = nn.softmax(x)

        return x
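For context, a snippet like the one above is normally the __call__ of a flax.linen Module. A minimal self-contained sketch of wrapping and invoking it (the CNN name, the NB_CLASSES value, and the input shape are illustrative assumptions, not taken from the original source):

import jax
import jax.numpy as jnp
import flax.linen as nn

NB_CLASSES = 10  # assumed value for illustration


class CNN(nn.Module):  # hypothetical wrapper for a __call__ like Example #13
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=16, kernel_size=(3, 3), padding='SAME')(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(features=NB_CLASSES)(x)
        return nn.softmax(x)


x = jnp.ones((4, 28, 28, 1))                      # dummy NHWC batch
variables = CNN().init(jax.random.PRNGKey(0), x)  # initialize parameters
probs = CNN().apply(variables, x)                 # forward pass -> (4, NB_CLASSES)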
Example #14
 def basic_module(self, x):
     x = nn.Conv(features=90,
                 kernel_size=(9, 9),
                 padding='VALID',
                 dtype=jp.float64)(x)
     x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
     x = nn.ConvTranspose(features=1,
                          kernel_size=(2, 2),
                          strides=(2, 2),
                          dtype=jp.float64)(x)
     x = x.reshape(x.shape[0], -1)
     x = jp.prod(x, 1)
     return x
Example #15
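    # ResNet with VirtualBatchNorm: strided init conv, optional virtual batch norm,
    # max pooling, post- or pre-activation residual blocks, optional dropout, Dense head.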
    def __call__(self, x, train):
        if self.num_layers not in _block_size_options:
            raise ValueError('Please provide a valid number of layers')
        block_sizes = _block_size_options[self.num_layers]

        x = nn.Conv(self.num_filters, (7, 7), (2, 2),
                    use_bias=False,
                    dtype=self.dtype,
                    name='init_conv')(x)
        if self.use_bn:
            x = normalization.VirtualBatchNorm(
                momentum=self.batch_norm_momentum,
                epsilon=self.batch_norm_epsilon,
                dtype=self.dtype,
                name='init_bn',
                batch_size=self.batch_size,
                virtual_batch_size=self.virtual_batch_size,
                total_batch_size=self.total_batch_size,
                data_format=self.data_format)(x, use_running_average=not train)
        x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
        if self.block_type == 'post_activation':
            residual_block = ResidualBlock
        elif self.block_type == 'pre_activation':
            residual_block = PreActResidualBlock
        else:
            raise ValueError('Invalid Block Type: {}'.format(self.block_type))
        index = 0
        for i, block_size in enumerate(block_sizes):
            for j in range(block_size):
                strides = (2, 2) if i > 0 and j == 0 else (1, 1)
                index += 1
                x = residual_block(
                    self.num_filters * 2**i,
                    strides=strides,
                    dtype=self.dtype,
                    batch_norm_momentum=self.batch_norm_momentum,
                    batch_norm_epsilon=self.batch_norm_epsilon,
                    batch_size=self.batch_size,
                    virtual_batch_size=self.virtual_batch_size,
                    total_batch_size=self.total_batch_size,
                    data_format=self.data_format,
                    bn_relu_conv=self.bn_relu_conv,
                    use_bn=self.use_bn,
                    activation_function=self.activation_function)(x,
                                                                  train=train)
        x = jnp.mean(x, axis=(1, 2))
        if self.dropout_rate > 0.0:
            x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x)
        x = nn.Dense(self.num_outputs, dtype=self.dtype)(x)
        return x
Example #16
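# VGG-style feature extractor: 3x3 convs with optional normalization, with 'M'
# entries in the layer spec mapped to 2x2 max pooling.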
def features(x, num_layers, normalizer, dtype, train):
    """Implements the feature extraction portion of the network."""

    layers = _layer_size_options[num_layers]
    conv = functools.partial(nn.Conv, use_bias=False, dtype=dtype)
    maybe_normalize = model_utils.get_normalizer(normalizer, train)
    for l in layers:
        if l == 'M':
            x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        else:
            x = conv(features=l, kernel_size=(3, 3),
                     padding=((1, 1), (1, 1)))(x)
            x = maybe_normalize()(x)
            x = nn.relu(x)
    return x
Example #17
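 # Unit test: checks nn.max_pool output values and that the gradient flows only
 # to the max element of each pooling window.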
 def test_max_pool(self):
     x = jnp.arange(9).reshape((1, 3, 3, 1)).astype(jnp.float32)
     pool = lambda x: nn.max_pool(x, (2, 2))
     expected_y = jnp.array([
         [4., 5.],
         [7., 8.],
     ]).reshape((1, 2, 2, 1))
     y = pool(x)
     np.testing.assert_allclose(y, expected_y)
     y_grad = jax.grad(lambda x: pool(x).sum())(x)
     expected_grad = jnp.array([
         [0., 0., 0.],
         [0., 1., 1.],
         [0., 1., 1.],
     ]).reshape((1, 3, 3, 1))
     np.testing.assert_allclose(y_grad, expected_grad)
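As the test above relies on, flax.linen.max_pool defaults to strides of 1 and 'VALID' padding, which is why the 3x3 input produces a 2x2 output. A quick standalone check:

import jax.numpy as jnp
import flax.linen as nn

x = jnp.arange(9.0).reshape((1, 3, 3, 1))  # NHWC input
y_valid = nn.max_pool(x, window_shape=(2, 2))                                 # default strides (1, 1), 'VALID' -> (1, 2, 2, 1)
y_same = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2), padding='SAME')  # -> (1, 2, 2, 1)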
Example #18
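  # torchvision-style ResNet backbone: stem conv + BatchNorm + padded max pooling,
  # then four stages with optional dilation in place of strides.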
  def __call__(self, inputs, train: bool = False):
    norm = functools.partial(nn.BatchNorm, use_running_average=not train, momentum=0.9, epsilon=1e-5, dtype=self.dtype)

    # Replace 2x2 strides with dilated convs. Flax modules are frozen dataclasses,
    # so use a local variable instead of reassigning the attribute inside __call__.
    use_dilation = self.use_dilation
    if use_dilation is None:
      use_dilation = [False, False, False]

    if len(use_dilation) != 3:
      raise ValueError("use_dilation should be None "
                       "or a 3-element tuple, got {}".format(use_dilation))

    x = nn.Conv(64, (7, 7), (2, 2), padding=[(3, 3), (3, 3)], use_bias=False, dtype=self.dtype, name='conv1')(inputs)
    x = norm(name='bn1')(x)
    x = nn.relu(x)
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding=[(1, 1), (1, 1)])

    dilation = 1
    for i, block_size in enumerate(self.layers):
      features = 64 * 2**i
      downsample = False
      previous_dilation = dilation
      strides = (2, 2) if i > 0 else (1, 1)

      if i > 0 and use_dilation[i - 1]:
        dilation *= strides[0]
        strides = (1, 1)

      block_expansion = 4 if "Bottleneck" in self.block.__name__ else 1

      if strides != (1, 1) or x.shape[-1] != features * block_expansion:
        downsample = True

      kwargs = {
          'features': features,
          'strides': strides,
          'downsample': downsample,
          'groups': self.groups,
          'dilation': previous_dilation,
          'base_width': self.width_per_group,
          'norm': norm,
          'dtype': self.dtype,
      }

      x = Layer(self.block, block_size, dilation, kwargs, name=f'layer{i+1}')(x)

    return x
Example #19
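    # Convolutional autoencoder: Conv + activation + max-pool encoder and a
    # ConvTranspose + activation decoder, both configured from the encoder/decoder dicts.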
    def __call__(self, x, train):
        del train
        encoder_keys = [
            'filter_sizes',
            'kernel_sizes',
            'kernel_paddings',
            'window_sizes',
            'window_paddings',
            'strides',
            'activations',
        ]
        if len(set(len(self.encoder[k]) for k in encoder_keys)) > 1:
            raise ValueError(
                'The elements in encoder dict do not have the same length.')

        decoder_keys = [
            'filter_sizes',
            'kernel_sizes',
            'window_sizes',
            'paddings',
            'activations',
        ]
        if len(set(len(self.decoder[k]) for k in decoder_keys)) > 1:
            raise ValueError(
                'The elements in decoder dict do not have the same length.')

        # encoder
        for i in range(len(self.encoder['filter_sizes'])):
            x = nn.Conv(self.encoder['filter_sizes'][i],
                        self.encoder['kernel_sizes'][i],
                        padding=self.encoder['kernel_paddings'][i])(x)
            x = model_utils.ACTIVATIONS[self.encoder['activations'][i]](x)
            x = nn.max_pool(x,
                            self.encoder['window_sizes'][i],
                            strides=self.encoder['strides'][i],
                            padding=self.encoder['window_paddings'][i])

        # decoder
        for i in range(len(self.decoder['filter_sizes'])):
            x = nn.ConvTranspose(self.decoder['filter_sizes'][i],
                                 self.decoder['kernel_sizes'][i],
                                 self.decoder['window_sizes'][i],
                                 padding=self.decoder['paddings'][i])(x)
            x = model_utils.ACTIVATIONS[self.decoder['activations'][i]](x)
        return x
Example #20
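 # MLPerf-style ResNet: conv stem with VirtualBatchNorm, ReLU, max pooling,
 # residual stages, global average pooling and a Dense classifier.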
 def __call__(self, x, train):
     if self.num_layers not in _block_size_options:
         raise ValueError('Please provide a valid number of layers')
     block_sizes = _block_size_options[self.num_layers]
     conv = functools.partial(nn.Conv, padding=[(3, 3), (3, 3)])
     x = conv(self.num_filters,
              kernel_size=(7, 7),
              strides=(2, 2),
              use_bias=False,
              dtype=self.dtype,
              name='conv0')(x)
     x = normalization.VirtualBatchNorm(
         momentum=self.batch_norm_momentum,
         epsilon=self.batch_norm_epsilon,
         name='init_bn',
         axis_name=self.axis_name,
         axis_index_groups=self.axis_index_groups,
         dtype=self.dtype,
         batch_size=self.batch_size,
         virtual_batch_size=self.virtual_batch_size,
         total_batch_size=self.total_batch_size,
         data_format=self.data_format)(x, use_running_average=not train)
     x = nn.relu(x)  # MLPerf-required
     x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
     for i, block_size in enumerate(block_sizes):
         for j in range(block_size):
             strides = (2, 2) if i > 0 and j == 0 else (1, 1)
             x = ResidualBlock(self.num_filters * 2**i,
                               strides=strides,
                               axis_name=self.axis_name,
                               axis_index_groups=self.axis_index_groups,
                               dtype=self.dtype,
                               batch_norm_momentum=self.batch_norm_momentum,
                               batch_norm_epsilon=self.batch_norm_epsilon,
                               bn_output_scale=self.bn_output_scale,
                               batch_size=self.batch_size,
                               virtual_batch_size=self.virtual_batch_size,
                               total_batch_size=self.total_batch_size,
                               data_format=self.data_format)(x, train=train)
     x = jnp.mean(x, axis=(1, 2))
     x = nn.Dense(self.num_classes,
                  kernel_init=nn.initializers.normal(),
                  dtype=self.dtype)(x)
     return x
Example #21
    def setup(self):
        self.conv1 = nn.Conv(features=64,
                             kernel_size=(5, 5),
                             strides=(1, 1),
                             padding=((0, 0), (0, 0)),
                             use_bias=True)
        self.conv2 = nn.Conv(features=64,
                             kernel_size=(5, 5),
                             strides=(1, 1),
                             padding=((0, 0), (0, 0)),
                             use_bias=True)
        self.dense1 = nn.Dense(features=384)
        self.dense2 = nn.Dense(features=192)
        self.dense3 = nn.Dense(features=10)

        self.pool = lambda x: nn.max_pool(
            x, window_shape=(2, 2), strides=(2, 2), padding=((0, 0), (0, 0)))

        self.activation = nn.leaky_relu
Example #22
    def __call__(self, x):
        branch3x3 = self.conv_block(384,
                                    kernel_size=(3, 3),
                                    strides=(2, 2),
                                    name='branch3x3')(x)

        branch3x3dbl = self.conv_block(64,
                                       kernel_size=(1, 1),
                                       name='branch3x3dbl_1')(x)
        branch3x3dbl = self.conv_block(96,
                                       kernel_size=(3, 3),
                                       padding=[(1, 1), (1, 1)],
                                       name='branch3x3dbl_2')(branch3x3dbl)
        branch3x3dbl = self.conv_block(96,
                                       kernel_size=(3, 3),
                                       strides=(2, 2),
                                       name='branch3x3dbl_3')(branch3x3dbl)

        branch_pool = nn.max_pool(x, (3, 3), strides=(2, 2))

        outputs = [branch3x3, branch3x3dbl, branch_pool]
        return jnp.concatenate(outputs, 3)
Example #23
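    # Big Transfer (BiT) ResNet: standardized root conv, GroupNorm, max pooling,
    # ResNet stages that switch to dilation once the requested output stride is reached;
    # returns logits plus a dict of intermediate representations.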
    def __call__(
            self,
            x: jnp.ndarray,
            train: bool = True,
            debug: bool = False) -> Union[jnp.ndarray, Dict[str, jnp.ndarray]]:
        """Applies the Bit ResNet model to the inputs.

        Args:
          x: Inputs to the model.
          train: Unused.
          debug: Unused.

        Returns:
          A tuple of un-normalized logits and a dictionary of intermediate
          representations.
        """
        del train
        del debug
        if self.max_output_stride not in [4, 8, 16, 32]:
            raise ValueError('Only supports output strides of [4, 8, 16, 32]')

        blocks, bottleneck = _BLOCK_SIZE_OPTIONS[self.num_layers]

        width = int(64 * self.width_factor)

        # Root block.
        x = StdConv(width, (7, 7), (2, 2), use_bias=False, name='conv_root')(x)
        x = nn.GroupNorm(num_groups=self.gn_num_groups,
                         epsilon=1e-4,
                         name='gn_root')(x)
        x = nn.relu(x)
        x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
        representations = {'stem': x}

        # Stages.
        x = ResNetStage(blocks[0],
                        width,
                        first_stride=(1, 1),
                        bottleneck=bottleneck,
                        gn_num_groups=self.gn_num_groups,
                        name='block1')(x)
        stride = 4
        for i, block_size in enumerate(blocks[1:], 1):
            max_stride_reached = self.max_output_stride <= stride
            x = ResNetStage(block_size,
                            width * 2**i,
                            first_stride=(2, 2) if not max_stride_reached else
                            (1, 1),
                            first_dilation=(2, 2) if max_stride_reached else
                            (1, 1),
                            bottleneck=bottleneck,
                            gn_num_groups=self.gn_num_groups,
                            name=f'block{i + 1}')(x)
            if not max_stride_reached:
                stride *= 2
            representations[f'stage_{i + 1}'] = x

        # Head.
        x = jnp.mean(x, axis=(1, 2))
        x = IdentityLayer(name='pre_logits')(x)
        representations['pre_logits'] = x
        x = nn.Dense(self.num_outputs,
                     kernel_init=nn.initializers.zeros,
                     name='head')(x)
        return x, representations
Example #24
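    # Hybrid Vision Transformer: optional ResNet root and stages, patch-embedding
    # conv, Transformer encoder, then a token/GAP classifier head.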
    def __call__(self, inputs, *, train):

        x = inputs
        # (Possibly partial) ResNet root.
        if self.resnet is not None:
            width = int(64 * self.resnet.width_factor)

            # Root block.
            x = models_resnet.StdConv(features=width,
                                      kernel_size=(7, 7),
                                      strides=(2, 2),
                                      use_bias=False,
                                      name='conv_root')(x)
            x = nn.GroupNorm(name='gn_root')(x)
            x = nn.relu(x)
            x = nn.max_pool(x,
                            window_shape=(3, 3),
                            strides=(2, 2),
                            padding='SAME')

            # ResNet stages.
            if self.resnet.num_layers:
                x = models_resnet.ResNetStage(
                    block_size=self.resnet.num_layers[0],
                    nout=width,
                    first_stride=(1, 1),
                    name='block1')(x)
                for i, block_size in enumerate(self.resnet.num_layers[1:], 1):
                    x = models_resnet.ResNetStage(block_size=block_size,
                                                  nout=width * 2**i,
                                                  first_stride=(2, 2),
                                                  name=f'block{i + 1}')(x)

        n, h, w, c = x.shape

        # We can merge s2d+emb into a single conv; it's the same.
        x = nn.Conv(features=self.hidden_size,
                    kernel_size=self.patches.size,
                    strides=self.patches.size,
                    padding='VALID',
                    name='embedding')(x)

        # Here, x is a grid of embeddings.

        # Transformer.
        n, h, w, c = x.shape
        x = jnp.reshape(x, [n, h * w, c])

        # If we want to add a class token, add it here.
        if self.classifier == 'token':
            cls = self.param('cls', nn.initializers.zeros, (1, 1, c))
            cls = jnp.tile(cls, [n, 1, 1])
            x = jnp.concatenate([cls, x], axis=1)

        x = Encoder(name='Transformer', **self.transformer)(x, train=train)

        if self.classifier == 'token':
            x = x[:, 0]
        elif self.classifier == 'gap':
            x = jnp.mean(x, axis=list(range(1, x.ndim - 1)))  # (1,) or (1,2)
        else:
            raise ValueError(f'Invalid classifier={self.classifier}')

        if self.representation_size is not None:
            x = nn.Dense(features=self.representation_size,
                         name='pre_logits')(x)
            x = nn.tanh(x)
        else:
            x = IdentityLayer(name='pre_logits')(x)

        if self.num_classes:
            x = nn.Dense(features=self.num_classes,
                         name='head',
                         kernel_init=nn.initializers.zeros)(x)
        return x
Example #25
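  # AQT ResNet variant: when the activation is not ReLU, average pooling replaces
  # max pooling after the init conv so the pre-pool distribution stays symmetric.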
  def __call__(
      self,
      inputs,
  ):
    """Applies ResNet model. Number of residual blocks inferred from hparams."""
    num_classes = self.num_classes
    hparams = self.hparams
    num_filters = self.num_filters
    dtype = self.dtype
    assert hparams.act_function in act_function_zoo.keys(
    ), 'Activation function type is not supported.'

    x = aqt_flax_layers.ConvAqt(
        features=num_filters,
        kernel_size=(7, 7),
        strides=(2, 2),
        padding=[(3, 3), (3, 3)],
        use_bias=False,
        dtype=dtype,
        name='init_conv',
        train=self.train,
        quant_context=self.quant_context,
        paxis_name='batch',
        hparams=hparams.conv_init,
    )(
        inputs)
    x = nn.BatchNorm(
        use_running_average=not self.train,
        momentum=0.9,
        epsilon=1e-5,
        dtype=dtype,
        name='init_bn')(
            x)
    if hparams.act_function == 'relu':
      x = nn.relu(x)
      x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
    else:
      # TODO(yichi): try adding other activation functions here
      # Use avg pool so that for binary nets, the distribution is symmetric.
      x = nn.avg_pool(x, (3, 3), strides=(2, 2), padding='SAME')
    filter_multiplier = hparams.filter_multiplier
    for i, block_hparams in enumerate(hparams.residual_blocks):
      proj = block_hparams.conv_proj
      # For projection layers (unless it is the first layer), strides = (2, 2)
      if i > 0 and proj is not None:
        filter_multiplier *= 2
        strides = (2, 2)
      else:
        strides = (1, 1)
      x = ResidualBlock(
          filters=int(num_filters * filter_multiplier),
          hparams=block_hparams,
          quant_context=self.quant_context,
          strides=strides,
          train=self.train,
          dtype=dtype)(
              x)
    if hparams.act_function == 'none':
      # The DenseAQT below is not binarized.
      # If removing the activation functions, there will be no act function
      # between the last residual block and the dense layer.
      # So add a ReLU in that case.
      # TODO(yichi): try BPReLU
      x = nn.relu(x)
    else:
      pass
    x = jnp.mean(x, axis=(1, 2))

    x = aqt_flax_layers.DenseAqt(
        features=num_classes,
        dtype=dtype,
        train=self.train,
        quant_context=self.quant_context,
        paxis_name='batch',
        hparams=hparams.dense_layer,
    )(x, padding_mask=None)

    x = jnp.asarray(x, dtype)
    output = nn.log_softmax(x)
    return output