Example 1
 def __call__(self, x):
     x = nn.Conv(features=32, kernel_size=(3, 3))(x)
     x = nn.relu(x)
     x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
     x = nn.Conv(features=64, kernel_size=(3, 3))(x)
     x = nn.relu(x)
     x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
     x = x.reshape((x.shape[0], -1))  # flatten
     x = nn.Dense(features=256)(x)
     x = nn.relu(x)
     x = nn.Dense(features=10)(x)
     return x
Example 2
 def __call__(self, x):
     x = nn.Conv(features=32, kernel_size=(3, 3))(x)
     x = nn.relu(x)
     x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
     x = nn.Conv(features=64, kernel_size=(3, 3))(x)
     x = nn.relu(x)
     x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
     x = x.reshape((x.shape[0], -1))  # Flatten
     x = nn.Dense(features=256)(x)
     x = nn.relu(x)
     x = nn.Dense(features=10)(x)  # There are 10 classes in MNIST
     x = nn.log_softmax(x)
     return x
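For context, here is a minimal self-contained sketch of how a snippet like the one above is typically wrapped in a full Flax linen module and applied to dummy MNIST-shaped input. The class name `CNN` and the input shape are illustrative assumptions, not taken from the original source.

import jax
import jax.numpy as jnp
import flax.linen as nn


class CNN(nn.Module):

    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        return nn.log_softmax(x)


model = CNN()
dummy = jnp.ones((1, 28, 28, 1))  # NHWC batch of one MNIST-sized image
params = model.init(jax.random.PRNGKey(0), dummy)
log_probs = model.apply(params, dummy)
print(log_probs.shape)  # (1, 10)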
Example 3
    def __call__(self, x):
        x = nn.Conv(features=16, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=16, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=256)(x)

        # return intermediate output
        x = TestDense()(x)

        x = nn.Dense(features=N_CLASSES)(x)
        return x
Example 4
 def __call__(self, x, with_classifier=True):
   x = nn.Conv(features=32, kernel_size=(3, 3))(x)
   x = nn.relu(x)
   x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
   x = nn.Conv(features=64, kernel_size=(3, 3))(x)
   x = nn.relu(x)
   x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
   x = x.reshape((x.shape[0], -1))  # flatten
   x = nn.Dense(features=256)(x)
   x = nn.relu(x)
   if not with_classifier:
     return x
   x = nn.Dense(features=10)(x)
   x = nn.log_softmax(x)
   return x
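A hedged sketch of how a `with_classifier` flag like the one above can be used: the same parameters serve both as a feature extractor and as a classifier. The class name `SmallCNN` and the reduced layer stack are illustrative, not from the original source.

import jax
import jax.numpy as jnp
import flax.linen as nn


class SmallCNN(nn.Module):

    @nn.compact
    def __call__(self, x, with_classifier=True):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        if not with_classifier:
            return x  # 256-dimensional feature embedding
        x = nn.Dense(features=10)(x)
        return nn.log_softmax(x)


model = SmallCNN()
dummy = jnp.ones((1, 28, 28, 1))
params = model.init(jax.random.PRNGKey(0), dummy)
features = model.apply(params, dummy, with_classifier=False)   # (1, 256)
log_probs = model.apply(params, dummy, with_classifier=True)   # (1, 10)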
Example 5
    def __call__(self, x):
        B, H, W, C = x.shape
        out_ch = self.out_ch if self.out_ch else C
        if not self.fir:
            if self.with_conv:
                x = conv3x3(x, out_ch, stride=2)
            else:
                x = nn.avg_pool(x,
                                window_shape=(2, 2),
                                strides=(2, 2),
                                padding='SAME')
        else:
            if not self.with_conv:
                x = up_or_down_sampling.downsample_2d(x,
                                                      self.fir_kernel,
                                                      factor=2)
            else:
                x = up_or_down_sampling.Conv2d(out_ch,
                                               kernel=3,
                                               down=True,
                                               resample_kernel=self.fir_kernel,
                                               use_bias=True,
                                               kernel_init=default_init())(x)

        assert x.shape == (B, H // 2, W // 2, out_ch)
        return x
Example 6
 def __call__(self, x, with_classifier=True):
   x = nn.Conv(features=32, kernel_size=(3, 3))(x)
   x = nn.relu(x)
   x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
   x = nn.Conv(features=64, kernel_size=(3, 3))(x)
   x = nn.relu(x)
   x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
   # TODO: replace np.prod(x.shape[1:]) with -1 once we fix shape_polymorphism
   x = x.reshape((x.shape[0], np.prod(x.shape[1:])))  # flatten
   x = nn.Dense(features=256)(x)
   x = nn.relu(x)
   if not with_classifier:
     return x
   x = nn.Dense(features=10)(x)
   x = nn.log_softmax(x)
   return x
Example 7
    def __call__(self, x):
        # Helper macro.
        R_ = lambda hidden_: ResidualUnit(hidden_features=hidden_,
                                          norm=self.norm,
                                          training=self.training,
                                          activation=nn.gelu)
        # First filter to make features.
        h = nn.Conv(features=self.hidden * self.alpha,
                    use_bias=False,
                    kernel_size=(3, 3),
                    kernel_init=INITS[self.kernel_init])(x)
        h = NORMS[self.norm](use_running_average=not self.training)(h)
        h = nn.gelu(h)
        # 2 stages of continuous segments:
        h = ResidualStitch(hidden_features=self.hidden * self.alpha,
                           output_features=self.hidden * self.alpha,
                           strides=(1, 1),
                           norm=self.norm,
                           training=self.training,
                           activation=nn.gelu)(h)
        h = StatefulContinuousBlock(R=R_(self.hidden * self.alpha),
                                    scheme=self.scheme,
                                    n_step=self.n_step,
                                    n_basis=self.n_basis,
                                    basis=self.basis,
                                    training=self.training)(h)

        # Pool and linearly classify:
        h = NORMS[self.norm](use_running_average=not self.training)(h)
        h = nn.gelu(h)
        h = nn.avg_pool(h, window_shape=(8, 8), strides=(8, 8))
        h = h.reshape((h.shape[0], -1))
        h = nn.Dense(features=self.n_classes)(h)
        return nn.log_softmax(h)  # log-probabilities
Example 8
    def __call__(self, x):
        branch1x1 = self.conv_block(64, kernel_size=(1, 1),
                                    name='branch1x1')(x)

        branch5x5 = self.conv_block(48, kernel_size=(1, 1),
                                    name='branch5x5_1')(x)
        branch5x5 = self.conv_block(64,
                                    kernel_size=(5, 5),
                                    padding=[(2, 2), (2, 2)],
                                    name='branch5x5_2')(branch5x5)

        branch3x3dbl = self.conv_block(64,
                                       kernel_size=(1, 1),
                                       name='branch3x3dbl_1')(x)
        branch3x3dbl = self.conv_block(96,
                                       kernel_size=(3, 3),
                                       padding=[(1, 1), (1, 1)],
                                       name='branch3x3dbl_2')(branch3x3dbl)
        branch3x3dbl = self.conv_block(96,
                                       kernel_size=(3, 3),
                                       padding=[(1, 1), (1, 1)],
                                       name='branch3x3dbl_3')(branch3x3dbl)

        branch_pool = nn.avg_pool(x, (3, 3),
                                  strides=(1, 1),
                                  padding=[(1, 1), (1, 1)])
        branch_pool = self.conv_block(self.pool_features,
                                      kernel_size=(1, 1),
                                      name='branch_pool')(branch_pool)

        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
        return jnp.concatenate(outputs, 3)
Example 9
    def __call__(self, inputs):
        """Passes the input through a bottleneck transformer block.
        Arguments:
            inputs:     [batch_size, height, width, dim]
        Returns:
            output:     [batch_size, height, width, dim * config.projection_factor]
        """
        residual = inputs
        cfg = self.config

        y = self.conv(self.filters, kernel_size=(1, 1))(inputs)
        y = self.norm()(y)
        y = cfg.activation_fn(y)
        y = BoTMHSA(config=cfg)(y)
        if self.strides == (2, 2):
            y = nn.avg_pool(y,
                            window_shape=(2, 2),
                            strides=self.strides,
                            padding='SAME')
        y = self.norm()(y)
        y = cfg.activation_fn(y)
        y = self.conv(self.filters * cfg.projection_factor,
                      kernel_size=(1, 1))(y)
        y = self.norm(scale_init=initializers.zeros)(y)

        if self.strides == (2, 2) or residual.shape != y.shape:
            residual = self.conv(self.filters * cfg.projection_factor,
                                 kernel_size=(1, 1),
                                 strides=self.strides)(residual)
            residual = self.norm()(residual)
            residual = cfg.activation_fn(residual)

        y = cfg.activation_fn(residual + y)
        return y
Example 10
    def __call__(self, inputs):
        """Applies spherical pooling.

    Args:
      inputs: An array of dimensions (batch_size, resolution, resolution,
      n_spins_in, n_channels_in).
    Returns:
      An array of dimensions (batch_size, resolution // stride, resolution //
      stride, n_spins_in, n_channels_in).
    """
        # We use variables to cache the in/out weights.
        resolution_in = inputs.shape[1]
        resolution_out = resolution_in // self.stride
        weights_in = sphere_utils.sphere_quadrature_weights(resolution_in)
        weights_out = sphere_utils.sphere_quadrature_weights(resolution_out)

        weighted = inputs * jnp.expand_dims(weights_in, (0, 2, 3, 4))
        pooled = nn.avg_pool(weighted,
                             window_shape=(self.stride, self.stride, 1),
                             strides=(self.stride, self.stride, 1))
        # This was average pooled. We multiply by stride**2 to obtain the sum
        # pooled, then divide by output weights to get the weighted average.
        pooled = (pooled * self.stride**2 /
                  jnp.expand_dims(weights_out, (0, 2, 3, 4)))

        return pooled
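The rescaling above relies on average pooling times the window area being equal to sum pooling. A quick self-contained check of that identity (illustrative only, not part of the original module):

import jax
import jax.numpy as jnp
import flax.linen as nn

x = jnp.arange(16, dtype=jnp.float32).reshape((1, 4, 4, 1))
avg = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
summed = jax.lax.reduce_window(x, 0., jax.lax.add,
                               window_dimensions=(1, 2, 2, 1),
                               window_strides=(1, 2, 2, 1),
                               padding='VALID')
assert jnp.allclose(avg * 2**2, summed)  # avg pool * window area == sum pool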
Example 11
    def __call__(self, x, sigmas, train=True):
        # per image standardization
        N = np.prod(x.shape[1:])
        x = (x - jnp.mean(x, axis=(1, 2, 3), keepdims=True)) / jnp.maximum(
            jnp.std(x, axis=(1, 2, 3), keepdims=True), 1. / np.sqrt(N))
        temb = GaussianFourierProjection(embedding_size=128,
                                         scale=16)(jnp.log(sigmas))
        temb = nn.Dense(128 * 4)(temb)
        temb = nn.Dense(128 * 4)(nn.swish(temb))

        x = nn.Conv(16, (3, 3),
                    padding='SAME',
                    name='init_conv',
                    kernel_init=conv_kernel_init_fn,
                    use_bias=False)(x)
        x = WideResnetGroup(self.blocks_per_group,
                            16 * self.channel_multiplier,
                            activate_before_residual=True)(x, temb, train)
        x = WideResnetGroup(self.blocks_per_group,
                            32 * self.channel_multiplier, (2, 2))(x, temb,
                                                                  train)
        x = WideResnetGroup(self.blocks_per_group,
                            64 * self.channel_multiplier, (2, 2))(x, temb,
                                                                  train)
        x = activation(x, train=train, name='pre-pool-bn')
        x = nn.avg_pool(x, x.shape[1:3])
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(self.num_outputs, kernel_init=dense_layer_init_fn)(x)
        return x
Example 12
    def __call__(self, x):
        for feat in self.features:
            x = nn.Conv(features=feat,
                        kernel_size=(self.kernel_size, self.kernel_size))(x)
            x = nn.relu(x)
            x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))

        x = x.reshape((x.shape[0], -1))  # flatten
        return x
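A minimal usage sketch for an encoder like the one above; the class name `Encoder` and its default attribute values are assumptions made so the snippet runs standalone.

from typing import Sequence

import jax
import jax.numpy as jnp
import flax.linen as nn


class Encoder(nn.Module):
    features: Sequence[int] = (32, 64)
    kernel_size: int = 3

    @nn.compact
    def __call__(self, x):
        for feat in self.features:
            x = nn.Conv(features=feat,
                        kernel_size=(self.kernel_size, self.kernel_size))(x)
            x = nn.relu(x)
            x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        return x.reshape((x.shape[0], -1))  # flatten


enc = Encoder()
dummy = jnp.ones((1, 32, 32, 3))
params = enc.init(jax.random.PRNGKey(0), dummy)
out = enc.apply(params, dummy)
print(out.shape)  # (1, 8 * 8 * 64) == (1, 4096)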
Example 13
    def __call__(self, x):
        x = nn.avg_pool(x, (5, 5), strides=(3, 3))
        x = self.conv_block(128, kernel_size=(1, 1), name='conv0')(x)
        x = self.conv_block(768, kernel_size=(5, 5), name='conv1')(x)

        x = x.transpose((0, 3, 1, 2))
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(self.num_classes, name='fc')(x)

        return x
Example 14
  def __call__(self, x, train=True):
    down_sample_layers = [ConvBlock(self.chans, self.drop_prob)]

    ch = self.chans
    for _ in range(self.num_pool_layers - 1):
      down_sample_layers.append(ConvBlock(ch * 2, self.drop_prob))
      ch *= 2
    conv = ConvBlock(ch * 2, self.drop_prob)

    up_conv = []
    up_transpose_conv = []
    for _ in range(self.num_pool_layers - 1):
      up_transpose_conv.append(TransposeConvBlock(ch))
      up_conv.append(ConvBlock(ch, self.drop_prob))
      ch //= 2

    up_transpose_conv.append(TransposeConvBlock(ch))
    up_conv.append(ConvBlock(ch, self.drop_prob))

    final_conv = nn.Conv(self.out_chans, kernel_size=(1, 1), strides=(1, 1))

    stack = []
    output = jnp.expand_dims(x, axis=-1)

    # apply down-sampling layers
    for layer in down_sample_layers:
      output = layer(output, train)
      stack.append(output)
      output = nn.avg_pool(output, window_shape=(2, 2), strides=(2, 2))

    output = conv(output, train)

    # apply up-sampling layers
    for transpose_conv, conv in zip(up_transpose_conv, up_conv):
      downsample_layer = stack.pop()
      output = transpose_conv(output)

      # reflect pad on the right/bottom if needed to handle odd input dimensions
      padding_right = 0
      padding_bottom = 0
      if output.shape[-2] != downsample_layer.shape[-2]:
        padding_right = 1  # padding right
      if output.shape[-3] != downsample_layer.shape[-3]:
        padding_bottom = 1  # padding bottom

      if padding_right or padding_bottom:
        padding = ((0, 0), (0, padding_bottom), (0, padding_right), (0, 0))
        output = jnp.pad(output, padding, mode='reflect')

      output = jnp.concatenate((output, downsample_layer), axis=-1)
      output = conv(output, train)

    output = final_conv(output)

    return output.squeeze(-1)
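The reflect padding above compensates for odd spatial sizes after the transpose convolution. A small standalone illustration of the same jnp.pad call, with made-up shapes:

import jax.numpy as jnp

output = jnp.ones((1, 7, 7, 8))  # odd height/width after upsampling
skip = jnp.ones((1, 8, 8, 8))    # matching encoder activation
pad_bottom = int(output.shape[-3] != skip.shape[-3])
pad_right = int(output.shape[-2] != skip.shape[-2])
padding = ((0, 0), (0, pad_bottom), (0, pad_right), (0, 0))
output = jnp.pad(output, padding, mode='reflect')
print(output.shape)  # (1, 8, 8, 8), now concatenable with `skip`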
Example 15
 def __call__(self, x):
     B, H, W, C = x.shape
     if self.with_conv:
         x = ddpm_conv3x3(x, C, stride=2)
     else:
         x = nn.avg_pool(x,
                         window_shape=(2, 2),
                         strides=(2, 2),
                         padding='SAME')
     assert x.shape == (B, H // 2, W // 2, C)
     return x
Example 16
  def __call__(self, inputs, train: bool = False):
    in_shape = np.shape(inputs)[1:-1]
    x = nn.avg_pool(inputs, in_shape)
    x = nn.Conv(self.channels, (1, 1), padding='SAME', use_bias=False, name="conv1")(x)
    x = nn.BatchNorm(use_running_average=not train, name="bn1")(x)
    x = nn.relu(x)

    out_shape = (1, in_shape[0], in_shape[1], self.channels)
    x = jax.image.resize(x, shape=out_shape, method='bilinear')

    return x
Example 17
  def __call__(self, x, train):
    conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)
    maybe_normalize = model_utils.get_normalizer(self.normalizer, train)

    y = maybe_normalize()(x)
    y = nn.relu(y)
    y = conv(features=self.num_features, kernel_size=(1, 1))(y)
    y = nn.avg_pool(
        y,
        window_shape=(2, 2),
        strides=(2, 2) if self.use_kernel_size_as_stride_in_pooling else (1, 1))
    return y
Example 18
    def __call__(self, x, train=False):
        conv_block = partial(BasicConv, train=train, dtype=self.dtype)
        inception_a = partial(InceptionA, conv_block=conv_block)
        inception_b = partial(InceptionB, conv_block=conv_block)
        inception_c = partial(InceptionC, conv_block=conv_block)
        inception_d = partial(InceptionD, conv_block=conv_block)
        inception_e = partial(InceptionE, conv_block=conv_block)
        inception_aux = partial(InceptionAux, conv_block=conv_block)

        if self.transform_input:
            x = np.transpose(x, (0, 3, 1, 2))
            x_ch0 = jnp.expand_dims(x[:, 0],
                                    1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x_ch1 = jnp.expand_dims(x[:, 1],
                                    1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x_ch2 = jnp.expand_dims(x[:, 2],
                                    1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
            x = jnp.concatenate((x_ch0, x_ch1, x_ch2), 1)
            x = np.transpose(x, (0, 2, 3, 1))

        x = conv_block(32,
                       kernel_size=(3, 3),
                       strides=(2, 2),
                       name='Conv2d_1a_3x3')(x)
        x = conv_block(32, kernel_size=(3, 3), name='Conv2d_2a_3x3')(x)
        x = conv_block(64,
                       kernel_size=(3, 3),
                       padding=[(1, 1), (1, 1)],
                       name='Conv2d_2b_3x3')(x)
        x = nn.max_pool(x, (3, 3), strides=(2, 2))
        x = conv_block(80, kernel_size=(1, 1), name='Conv2d_3b_1x1')(x)
        x = conv_block(192, kernel_size=(3, 3), name='Conv2d_4a_3x3')(x)
        x = nn.max_pool(x, (3, 3), strides=(2, 2))

        x = inception_a(pool_features=32, name='Mixed_5b')(x)
        x = inception_a(pool_features=64, name='Mixed_5c')(x)
        x = inception_a(pool_features=64, name='Mixed_5d')(x)
        x = inception_b(name='Mixed_6a')(x)
        x = inception_c(channels_7x7=128, name='Mixed_6b')(x)
        x = inception_c(channels_7x7=160, name='Mixed_6c')(x)
        x = inception_c(channels_7x7=160, name='Mixed_6d')(x)
        x = inception_c(channels_7x7=192, name='Mixed_6e')(x)

        aux = inception_aux(self.num_classes, name='AuxLogits')(x) \
              if train and self.aux_logits else None

        x = inception_d(name='Mixed_7a')(x)
        x = inception_e(name='Mixed_7b')(x)
        x = inception_e(name='Mixed_7c')(x)
        x = nn.avg_pool(x, (8, 8))
        x = nn.Dropout(0.5)(x, deterministic=not train)

        return x, aux
Example 19
 def __call__(self, x, y):
     x = self.act(x)
     path = x
     for _ in range(self.n_stages):
         path = self.normalizer()(path, y)
         path = nn.avg_pool(path,
                            window_shape=(5, 5),
                            strides=(1, 1),
                            padding='SAME')
         path = ncsn_conv3x3(path, self.features, stride=1, bias=False)
         x = path + x
     return x
Example 20
 def test_avg_pool_no_batch(self):
     x = jnp.full((3, 3, 1), 2.)
     pool = lambda x: nn.avg_pool(x, (2, 2))
     y = pool(x)
     np.testing.assert_allclose(y, np.full((2, 2, 1), 2.))
     y_grad = jax.grad(lambda x: pool(x).sum())(x)
     expected_grad = jnp.array([
         [0.25, 0.5, 0.25],
         [0.5, 1., 0.5],
         [0.25, 0.5, 0.25],
     ]).reshape((3, 3, 1))
     np.testing.assert_allclose(y_grad, expected_grad)
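For reference, a quick shape check contrasting the default 'VALID' padding used in the test above with 'SAME' padding (illustrative, not part of the test suite):

import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((3, 3, 1))
print(nn.avg_pool(x, (2, 2)).shape)                  # (2, 2, 1), 'VALID' default
print(nn.avg_pool(x, (2, 2), padding='SAME').shape)  # (3, 3, 1)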
Example 21
 def __call__(self, x, train):
     x = nn.Conv(16, (3, 3),
                 padding='SAME',
                 name='init_conv',
                 kernel_init=self.conv_kernel_init,
                 use_bias=False)(x)
     x = WideResnetGroup(self.blocks_per_group,
                         16 * self.channel_multiplier,
                         self.group_strides[0],
                         conv_kernel_init=self.conv_kernel_init,
                         normalizer=self.normalizer,
                         dropout_rate=self.dropout_rate,
                         activation_function=self.activation_function,
                         batch_size=self.batch_size,
                         virtual_batch_size=self.virtual_batch_size,
                         total_batch_size=self.total_batch_size)(
                             x, train=train)
     x = WideResnetGroup(self.blocks_per_group,
                         32 * self.channel_multiplier,
                         self.group_strides[1],
                         conv_kernel_init=self.conv_kernel_init,
                         normalizer=self.normalizer,
                         dropout_rate=self.dropout_rate,
                         activation_function=self.activation_function,
                         batch_size=self.batch_size,
                         virtual_batch_size=self.virtual_batch_size,
                         total_batch_size=self.total_batch_size)(
                             x, train=train)
     x = WideResnetGroup(self.blocks_per_group,
                         64 * self.channel_multiplier,
                         self.group_strides[2],
                         conv_kernel_init=self.conv_kernel_init,
                         dropout_rate=self.dropout_rate,
                         normalizer=self.normalizer,
                         activation_function=self.activation_function,
                         batch_size=self.batch_size,
                         virtual_batch_size=self.virtual_batch_size,
                         total_batch_size=self.total_batch_size)(
                             x, train=train)
     maybe_normalize = model_utils.get_normalizer(
         self.normalizer,
         train,
         batch_size=self.batch_size,
         virtual_batch_size=self.virtual_batch_size,
         total_batch_size=self.total_batch_size)
     x = maybe_normalize()(x)
     x = model_utils.ACTIVATIONS[self.activation_function](x)
     x = nn.avg_pool(x, (8, 8))
     x = x.reshape((x.shape[0], -1))
     x = nn.Dense(self.num_outputs, kernel_init=self.dense_kernel_init)(x)
     return x
Example 22
 def __call__(self, x):
     Conv1x1_ = partial(Conv1x1, precision=self.conv_precision)
     Conv3x3_ = partial(Conv3x3 if self.use_3x3 else Conv1x1,
                        precision=self.conv_precision)
     x_ = Conv1x1_(self.middle_width)(nn.gelu(x))
     x_ = Conv3x3_(self.middle_width)(nn.gelu(x_))
     x_ = Conv3x3_(self.middle_width)(nn.gelu(x_))
     x_ = Conv1x1_(self.out_width,
                   kernel_init=lecun_normal(self.last_scale))(nn.gelu(x_))
     out = x + x_ if self.residual else x_
     if self.down_rate > 1:
         window_shape = 2 * (self.down_rate, )
         out = nn.avg_pool(out, window_shape, window_shape)
     return out
Example 23
    def __call__(self, x):
        branch1x1 = self.conv_block(192, kernel_size=(1, 1),
                                    name='branch1x1')(x)

        c7 = self.channels_7x7
        branch7x7 = self.conv_block(c7, kernel_size=(1, 1),
                                    name='branch7x7_1')(x)
        branch7x7 = self.conv_block(c7,
                                    kernel_size=(1, 7),
                                    padding=[(0, 0), (3, 3)],
                                    name='branch7x7_2')(branch7x7)
        branch7x7 = self.conv_block(192,
                                    kernel_size=(7, 1),
                                    padding=[(3, 3), (0, 0)],
                                    name='branch7x7_3')(branch7x7)

        branch7x7dbl = self.conv_block(c7,
                                       kernel_size=(1, 1),
                                       name='branch7x7dbl_1')(x)
        branch7x7dbl = self.conv_block(c7,
                                       kernel_size=(7, 1),
                                       padding=[(3, 3), (0, 0)],
                                       name='branch7x7dbl_2')(branch7x7dbl)
        branch7x7dbl = self.conv_block(c7,
                                       kernel_size=(1, 7),
                                       padding=[(0, 0), (3, 3)],
                                       name='branch7x7dbl_3')(branch7x7dbl)
        branch7x7dbl = self.conv_block(c7,
                                       kernel_size=(7, 1),
                                       padding=[(3, 3), (0, 0)],
                                       name='branch7x7dbl_4')(branch7x7dbl)
        branch7x7dbl = self.conv_block(192,
                                       kernel_size=(1, 7),
                                       padding=[(0, 0), (3, 3)],
                                       name='branch7x7dbl_5')(branch7x7dbl)

        branch_pool = nn.avg_pool(x, (3, 3),
                                  strides=(1, 1),
                                  padding=[(1, 1), (1, 1)])
        branch_pool = self.conv_block(192,
                                      kernel_size=(1, 1),
                                      name='branch_pool')(branch_pool)

        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
        return jnp.concatenate(outputs, 3)
Example 24
def _output_add(block_x, orig_x):
    """Add two tensors, padding them with zeros or pooling them if necessary.

  Args:
    block_x: Output of a resnet block.
    orig_x: Residual branch to add to the output of the resnet block.

  Returns:
    The sum of block_x and orig_x. If necessary, orig_x will be average pooled
      or zero padded so that its shape matches block_x.
  """
    stride = orig_x.shape[-2] // block_x.shape[-2]
    strides = (stride, stride)
    if block_x.shape[-1] != orig_x.shape[-1]:
        orig_x = nn.avg_pool(orig_x, strides, strides)
        channels_to_add = block_x.shape[-1] - orig_x.shape[-1]
        orig_x = jnp.pad(orig_x, [(0, 0), (0, 0), (0, 0),
                                  (0, channels_to_add)])
    return block_x + orig_x
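A small shape walkthrough of the helper above, with illustrative shapes: the residual branch is average pooled down to the block's spatial size and zero padded along channels before the addition.

import jax.numpy as jnp
import flax.linen as nn

block_x = jnp.ones((2, 16, 16, 32))  # output of a strided resnet block
orig_x = jnp.ones((2, 32, 32, 16))   # residual branch: higher resolution, fewer channels

stride = orig_x.shape[-2] // block_x.shape[-2]  # 2
orig_x = nn.avg_pool(orig_x, (stride, stride), (stride, stride))       # (2, 16, 16, 16)
orig_x = jnp.pad(orig_x, [(0, 0), (0, 0), (0, 0),
                          (0, block_x.shape[-1] - orig_x.shape[-1])])  # (2, 16, 16, 32)
print((block_x + orig_x).shape)  # (2, 16, 16, 32)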
Example 25
    def __call__(self, x):
        branch1x1 = self.conv_block(320, kernel_size=(1, 1),
                                    name='branch1x1')(x)

        branch3x3 = self.conv_block(384,
                                    kernel_size=(1, 1),
                                    name='branch3x3_1')(x)
        branch3x3_2a = self.conv_block(384,
                                       kernel_size=(1, 3),
                                       padding=[(0, 0), (1, 1)],
                                       name='branch3x3_2a')(branch3x3)
        branch3x3_2b = self.conv_block(384,
                                       kernel_size=(3, 1),
                                       padding=[(1, 1), (0, 0)],
                                       name='branch3x3_2b')(branch3x3)
        branch3x3 = jnp.concatenate([branch3x3_2a, branch3x3_2b], 3)

        branch3x3dbl = self.conv_block(448,
                                       kernel_size=(1, 1),
                                       name='branch3x3dbl_1')(x)
        branch3x3dbl = self.conv_block(384,
                                       kernel_size=(3, 3),
                                       padding=[(1, 1), (1, 1)],
                                       name='branch3x3dbl_2')(branch3x3dbl)
        branch3x3dbl_3a = self.conv_block(384,
                                          kernel_size=(1, 3),
                                          padding=[(0, 0), (1, 1)],
                                          name='branch3x3dbl_3a')(branch3x3dbl)
        branch3x3dbl_3b = self.conv_block(384,
                                          kernel_size=(3, 1),
                                          padding=[(1, 1), (0, 0)],
                                          name='branch3x3dbl_3b')(branch3x3dbl)
        branch3x3dbl = jnp.concatenate([branch3x3dbl_3a, branch3x3dbl_3b], 3)

        branch_pool = nn.avg_pool(x, (3, 3),
                                  strides=(1, 1),
                                  padding=[(1, 1), (1, 1)])
        branch_pool = self.conv_block(192,
                                      kernel_size=(1, 1),
                                      name='branch_pool')(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return jnp.concatenate(outputs, 3)
Example 26
    def __call__(self, inputs, train):
        """Applies the network to inputs.

    Args:
      inputs: (batch_size, resolution, resolution, n_spins, n_channels) array.
      train: whether to run in training or inference mode.
    Returns:
      A (batch_size, num_classes) float32 array with per-class scores (logits).
    Raises:
      ValueError: If resolutions cannot be enforced with 2x2 pooling.
    """
        num_layers = len(self.resolutions)
        # Merge spin and channel dimensions.
        features = inputs.reshape((*inputs.shape[:3], -1))
        for layer_id in range(num_layers - 1):
            resolution_in = self.resolutions[layer_id]
            resolution_out = self.resolutions[layer_id + 1]
            n_channels = self.widths[layer_id + 1]

            if resolution_out == resolution_in // 2:
                features = nn.avg_pool(features,
                                       window_shape=(2, 2),
                                       strides=(2, 2),
                                       padding='SAME')
            elif resolution_out != resolution_in:
                raise ValueError(
                    'Consecutive resolutions must be equal or halved.')

            features = nn.Conv(features=n_channels,
                               kernel_size=(3, 3),
                               strides=(1, 1))(features)

            features = nn.BatchNorm(use_running_average=not train,
                                    axis_name=self.axis_name)(features)
            features = nn.relu(features)

        features = jnp.mean(features, axis=(1, 2))
        features = nn.Dense(self.num_classes)(features)

        return features
Example 27
  def __call__(self, x, *, emb, deterministic):
    B, _, _, C = x.shape  # pylint: disable=invalid-name
    assert emb.shape[0] == B and len(emb.shape) == 2
    out_ch = C if self.out_ch is None else self.out_ch

    h = nonlinearity(Normalize(name='norm1')(x))
    if self.resample is not None:
      updown = lambda z: {
          'up': nearest_neighbor_upsample(z),
          'down': nn.avg_pool(z, (2, 2), (2, 2))
      }[self.resample]
      h = updown(h)
      x = updown(x)
    h = nn.Conv(
        features=out_ch, kernel_size=(3, 3), strides=(1, 1), name='conv1')(h)

    # add in timestep/class embedding
    emb_out = nn.Dense(features=2 * out_ch, name='temb_proj')(
        nonlinearity(emb))[:, None, None, :]
    scale, shift = jnp.split(emb_out, 2, axis=-1)
    h = Normalize(name='norm2')(h) * (1 + scale) + shift
    # rest
    h = nonlinearity(h)
    h = nn.Dropout(rate=self.dropout)(h, deterministic=deterministic)
    h = nn.Conv(
        features=out_ch,
        kernel_size=(3, 3),
        strides=(1, 1),
        kernel_init=nn.initializers.zeros,
        name='conv2')(h)

    if C != out_ch:
      x = nn.Dense(features=out_ch, name='nin_shortcut')(x)

    assert x.shape == h.shape
    logging.info(
        '%s: x=%r emb=%r resample=%r',
        self.name, x.shape, emb.shape, self.resample)
    return x + h
Example 28
  def __call__(self, x, train):
    def dense_layers(y, block, num_blocks, growth_rate):
      for _ in range(num_blocks):
        y = block(growth_rate)(y, train=train)
      return y

    def update_num_features(num_features, num_blocks, growth_rate, reduction):
      num_features += num_blocks * growth_rate
      if reduction is not None:
        num_features = int(math.floor(num_features * reduction))
      return num_features

    # Initial convolutional layer
    num_features = 2 * self.growth_rate
    conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)
    y = conv(
        features=num_features,
        kernel_size=(3, 3),
        padding=((1, 1), (1, 1)),
        name='conv1')(x)

    # Internal dense and transition blocks
    num_blocks = _block_size_options[self.num_layers]
    block = functools.partial(
        BottleneckBlock,
        dtype=self.dtype,
        normalizer=self.normalizer)
    for i in range(3):
      y = dense_layers(y, block, num_blocks[i], self.growth_rate)
      num_features = update_num_features(num_features, num_blocks[i],
                                         self.growth_rate, self.reduction)
      y = TransitionBlock(
          num_features,
          dtype=self.dtype,
          normalizer=self.normalizer,
          use_kernel_size_as_stride_in_pooling=self
          .use_kernel_size_as_stride_in_pooling)(
              y, train=train)

    # Final dense block
    y = dense_layers(y, block, num_blocks[3], self.growth_rate)

    # Final pooling
    maybe_normalize = model_utils.get_normalizer(self.normalizer, train)
    y = maybe_normalize()(y)
    y = nn.relu(y)
    y = nn.avg_pool(
        y,
        window_shape=(4, 4),
        strides=(4, 4) if self.use_kernel_size_as_stride_in_pooling else (1, 1))

    # Classification layer
    y = jnp.reshape(y, (y.shape[0], -1))
    if self.normalize_classifier_input:
      maybe_normalize = model_utils.get_normalizer(
          self.normalize_classifier_input, train)
      y = maybe_normalize()(y)
    y = y * self.classification_scale_factor

    y = nn.Dense(self.num_outputs)(y)
    return y
Example 29
    def __call__(
        self,
        inputs,
        context_vectors=None,
    ):
        """Applies the res block to input images.

    Args:
      inputs: a rank-4 array of input images of shape (B, H, W, C).
      context_vectors: optional auxiliary inputs, typically used for
        conditioning. If set, they should be of rank 2, and their first (batch)
        dimension should match that of `inputs`. Their number of features is
        arbitrary. They will be reshaped from (B, D) to (B, 1, 1, D) and a 1x1
        convolution will be applied to them.

    Returns:
      the rank-4 output of the block.
    """
        if self.downsampling_rate < 1:
            raise ValueError('downsampling_rate should be >= 1, but got '
                             f'{self.downsampling_rate}.')

        def build_layers(inputs):
            """Build layers of the ResBlock given a batch of inputs."""
            resolution = inputs.shape[1]
            if resolution > 2:
                kernel_shapes = ((1, 1), (3, 3), (3, 3), (1, 1))
            else:
                kernel_shapes = ((1, 1), (1, 1), (1, 1), (1, 1))

            conv_layers = []
            aux_conv_layers = []
            for layer_idx, kernel_shape in enumerate(kernel_shapes):
                is_last = layer_idx == _NUM_CONV_LAYER_PER_BLOCK - 1
                num_channels = self.output_channels if is_last else self.internal_channels
                weights_scale = self.last_weights_scale if is_last else 1.
                conv_layers.append(
                    get_vdvae_convolution(num_channels,
                                          kernel_shape,
                                          weights_scale,
                                          name=f'c{layer_idx}',
                                          precision=self.precision))
                aux_conv_layers.append(
                    get_vdvae_convolution(num_channels, (1, 1),
                                          0.,
                                          name=f'aux_c{layer_idx}',
                                          precision=self.precision))

            return conv_layers, aux_conv_layers

        chex.assert_rank(inputs, 4)
        if inputs.shape[1] != inputs.shape[2]:
            raise ValueError(
                'VDVAE only works with square images, but got '
                f'rectangular images of shape {inputs.shape[1:3]}.')
        if context_vectors is not None:
            chex.assert_rank(context_vectors, 2)
            inputs_batch_dim = inputs.shape[0]
            aux_batch_dim = context_vectors.shape[0]
            if inputs_batch_dim != aux_batch_dim:
                raise ValueError(
                    'Context vectors batch dimension is incompatible '
                    'with inputs batch dimension. Got '
                    f'{aux_batch_dim} vs {inputs_batch_dim}.')
            context_vectors = context_vectors[:, None, None, :]

        conv_layers, aux_conv_layers = build_layers(inputs)

        outputs = inputs
        for conv, auxiliary_conv in zip(conv_layers, aux_conv_layers):
            outputs = conv(jax.nn.gelu(outputs))
            if context_vectors is not None:
                outputs += auxiliary_conv(context_vectors)

        if self.use_residual_connection:
            in_channels = inputs.shape[-1]
            out_channels = outputs.shape[-1]
            if in_channels != out_channels:
                raise AssertionError(
                    'Cannot apply residual connection because the '
                    'number of output channels differs from the '
                    'number of input channels: '
                    f'{out_channels} vs {in_channels}.')
            outputs += inputs
        if self.downsampling_rate > 1:
            shape = (self.downsampling_rate, self.downsampling_rate)
            outputs = nn.avg_pool(outputs,
                                  window_shape=shape,
                                  strides=shape,
                                  padding='VALID')
        return outputs
Example 30
  def __call__(
      self,
      inputs,
  ):
    """Applies ResNet model. Number of residual blocks inferred from hparams."""
    num_classes = self.num_classes
    hparams = self.hparams
    num_filters = self.num_filters
    dtype = self.dtype
    assert hparams.act_function in act_function_zoo.keys(
    ), 'Activation function type is not supported.'

    x = aqt_flax_layers.ConvAqt(
        features=num_filters,
        kernel_size=(7, 7),
        strides=(2, 2),
        padding=[(3, 3), (3, 3)],
        use_bias=False,
        dtype=dtype,
        name='init_conv',
        train=self.train,
        quant_context=self.quant_context,
        paxis_name='batch',
        hparams=hparams.conv_init,
    )(
        inputs)
    x = nn.BatchNorm(
        use_running_average=not self.train,
        momentum=0.9,
        epsilon=1e-5,
        dtype=dtype,
        name='init_bn')(
            x)
    if hparams.act_function == 'relu':
      x = nn.relu(x)
      x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
    else:
      # TODO(yichi): try adding other activation functions here
      # Use avg pool so that for binary nets, the distribution is symmetric.
      x = nn.avg_pool(x, (3, 3), strides=(2, 2), padding='SAME')
    filter_multiplier = hparams.filter_multiplier
    for i, block_hparams in enumerate(hparams.residual_blocks):
      proj = block_hparams.conv_proj
      # For projection layers (unless it is the first layer), strides = (2, 2)
      if i > 0 and proj is not None:
        filter_multiplier *= 2
        strides = (2, 2)
      else:
        strides = (1, 1)
      x = ResidualBlock(
          filters=int(num_filters * filter_multiplier),
          hparams=block_hparams,
          quant_context=self.quant_context,
          strides=strides,
          train=self.train,
          dtype=dtype)(
              x)
    if hparams.act_function == 'none':
      # The DenseAQT below is not binarized.
      # If removing the activation functions, there will be no act function
      # between the last residual block and the dense layer.
      # So add a ReLU in that case.
      # TODO(yichi): try BPReLU
      x = nn.relu(x)
    else:
      pass
    x = jnp.mean(x, axis=(1, 2))

    x = aqt_flax_layers.DenseAqt(
        features=num_classes,
        dtype=dtype,
        train=self.train,
        quant_context=self.quant_context,
        paxis_name='batch',
        hparams=hparams.dense_layer,
    )(x, padding_mask=None)

    x = jnp.asarray(x, dtype)
    output = nn.log_softmax(x)
    return output