Example #1
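    # ResNet bottleneck block: 1x1 -> 3x3 -> 1x1 weight-standardized convs
    # (StdConv), each followed by GroupNorm. The shortcut gets a 1x1 projection
    # whenever the channel count or stride changes, and the final GroupNorm's
    # scale is zero-initialized so the residual branch starts as a no-op.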
    def __call__(self, x):
        needs_projection = (x.shape[-1] != self.features * 4 or self.strides !=
                            (1, 1))

        residual = x
        if needs_projection:
            residual = StdConv(features=self.features * 4,
                               kernel_size=(1, 1),
                               strides=self.strides,
                               use_bias=False,
                               name='conv_proj')(residual)
            residual = nn.GroupNorm(name='gn_proj')(residual)

        y = StdConv(features=self.features,
                    kernel_size=(1, 1),
                    use_bias=False,
                    name='conv1')(x)
        y = nn.GroupNorm(name='gn1')(y)
        y = nn.relu(y)
        y = StdConv(features=self.features,
                    kernel_size=(3, 3),
                    strides=self.strides,
                    use_bias=False,
                    name='conv2')(y)
        y = nn.GroupNorm(name='gn2')(y)
        y = nn.relu(y)
        y = StdConv(features=self.features * 4,
                    kernel_size=(1, 1),
                    use_bias=False,
                    name='conv3')(y)

        y = nn.GroupNorm(name='gn3', scale_init=nn.initializers.zeros)(y)
        y = nn.relu(residual + y)
        return y
Example #2
 def __call__(self, x):
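     # Convolutional VAE encoder: three strided Conv/GroupNorm/GELU stages,
     # then two Dense heads producing the latent mean and log-variance.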
     x = nn.Conv(features=28, kernel_size=(3, 3), strides=(2, 2))(x)
     x = nn.GroupNorm(28)(x)
     x = nn.gelu(x)
     x = nn.Conv(features=64, kernel_size=(3, 3), strides=(2, 2))(x)
     x = nn.GroupNorm(32)(x)
     x = nn.gelu(x)
     x = nn.Conv(features=64, kernel_size=(3, 3), strides=(2, 2))(x)
     x = nn.GroupNorm(32)(x)
     x = nn.gelu(x)
     x = x.reshape((x.shape[0], -1))
     mean_x = nn.Dense(self.latent_dim, name='fc2_mean')(x)
     logvar_x = nn.Dense(self.latent_dim, name='fc2_logvar')(x)
     return mean_x, logvar_x
Example #3
 def __call__(self, z):
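     # Matching decoder: project the latent back to the pre-flatten feature
     # shape, then upsample with three ConvTranspose stages (GroupNorm + GELU
     # after the first two) down to a single output channel.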
     shape_before_flattening, flatten_out_size = self.flatten_enc_shape()
     #print(shape_before_flattening, flatten_out_size)
     x = nn.Dense(flatten_out_size, name='fc1')(z)
     x = nn.gelu(x)
     x = x.reshape((x.shape[0], *shape_before_flattening[1:]))
     x = nn.ConvTranspose(features=32, kernel_size=(3, 3),
                          strides=(2, 2))(x)
     x = nn.GroupNorm(32)(x)
     x = nn.gelu(x)
     x = nn.ConvTranspose(features=28, kernel_size=(3, 3),
                          strides=(2, 2))(x)
     x = nn.GroupNorm(28)(x)
     x = nn.gelu(x)
     x = nn.ConvTranspose(features=1, kernel_size=(3, 3), strides=(2, 2))(x)
     return x
Example #4
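# GroupNorm + optional ReLU helper; the group count adapts to the channel
# width as min(channels // 4, 32). The `train` argument is unused here.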
def activation(x, train, apply_relu=True, name=''):
    x = nn.GroupNorm(name=name,
                     epsilon=1e-5,
                     num_groups=min(x.shape[-1] // 4, 32))(x)
    if apply_relu:
        x = jax.nn.relu(x)
    return x
Example #5
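    # BigGAN-style residual block: GroupNorm -> activation, optional FIR or
    # naive up/down-sampling applied to both branches, a 3x3 conv, a
    # time-embedding bias, then a second GroupNorm/act/dropout/conv. The
    # shortcut is projected with a 1x1 conv when shapes change, and the sum
    # is optionally rescaled by 1/sqrt(2).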
    def __call__(self, x, temb=None, train=True):
        B, H, W, C = x.shape
        out_ch = self.out_ch if self.out_ch else C
        h = self.act(nn.GroupNorm(num_groups=min(x.shape[-1] // 4, 32))(x))

        if self.up:
            if self.fir:
                h = up_or_down_sampling.upsample_2d(h,
                                                    self.fir_kernel,
                                                    factor=2)
                x = up_or_down_sampling.upsample_2d(x,
                                                    self.fir_kernel,
                                                    factor=2)
            else:
                h = up_or_down_sampling.naive_upsample_2d(h, factor=2)
                x = up_or_down_sampling.naive_upsample_2d(x, factor=2)
        elif self.down:
            if self.fir:
                h = up_or_down_sampling.downsample_2d(h,
                                                      self.fir_kernel,
                                                      factor=2)
                x = up_or_down_sampling.downsample_2d(x,
                                                      self.fir_kernel,
                                                      factor=2)
            else:
                h = up_or_down_sampling.naive_downsample_2d(h, factor=2)
                x = up_or_down_sampling.naive_downsample_2d(x, factor=2)

        h = conv3x3(h, out_ch)
        # Add bias to each feature map conditioned on the time embedding
        if temb is not None:
            h += nn.Dense(out_ch,
                          kernel_init=default_init())(self.act(temb))[:, None,
                                                                      None, :]

        h = self.act(nn.GroupNorm(num_groups=min(h.shape[-1] // 4, 32))(h))
        h = nn.Dropout(self.dropout)(h, deterministic=not train)
        h = conv3x3(h, out_ch, init_scale=self.init_scale)
        if C != out_ch or self.up or self.down:
            x = conv1x1(x, out_ch)

        if not self.skip_rescale:
            return x + h
        else:
            return (x + h) / np.sqrt(2.)
Example #6
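  # Encoder: one strided Conv + GroupNorm (Flax default of 32 groups) + GELU
  # block per entry in hidden_dims, followed by mean/log-variance heads.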
  def __call__(self, x):
    # Build Encoder
    for h_dim in self.hidden_dims:
      x = nn.Conv(features=h_dim, kernel_size=(3, 3), strides=(2,2), padding="valid")(x)
      x = nn.GroupNorm()(x)
      x = nn.gelu(x)

    x = x.reshape((x.shape[0], -1))
    mean_x = nn.Dense(self.latent_dim, name='fc2_mean')(x)
    logvar_x = nn.Dense(self.latent_dim, name='fc2_logvar')(x)
    return mean_x, logvar_x
Example #7
 def setup(self):
     activation = nn.softplus if self.activation == 'softplus' else nn.relu
     if self.group_norm:
         self.double_conv = Sequential([
             nn.Conv(self.mid_channels, kernel_size=(3, 3), use_bias=False),
             nn.GroupNorm(self.num_groups),
             activation,
             nn.Conv(self.out_channels, kernel_size=(3, 3), use_bias=False),
             nn.GroupNorm(self.num_groups),
             activation,
         ])
     else:
         self.double_conv = Sequential([
             nn.Conv(self.mid_channels, kernel_size=(3, 3), use_bias=False),
             nn.BatchNorm(use_running_average=self.test),
             activation,
             nn.Conv(self.out_channels, kernel_size=(3, 3), use_bias=False),
             nn.BatchNorm(use_running_average=self.test),
             activation,
         ])
Example #8
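    # num_groups=3 does not divide the 32 input channels, so GroupNorm is
    # expected to raise a ValueError at initialization.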
    def test_group_norm_raises(self):
        rng = random.PRNGKey(0)
        key1, key2 = random.split(rng)
        e = 1e-5
        x = random.normal(key1, (2, 5, 4, 4, 32))
        model_cls = nn.GroupNorm(num_groups=3,
                                 use_bias=False,
                                 use_scale=False,
                                 epsilon=e)

        with self.assertRaises(ValueError):
            model_cls.init_with_output(key2, x)
Example #9
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        features = self.nout
        nout = self.nout * 4 if self.bottleneck else self.nout
        needs_projection = x.shape[-1] != nout or self.strides != (1, 1)
        residual = x
        if needs_projection:
            residual = StdConv(nout, (1, 1),
                               self.strides,
                               use_bias=False,
                               name='conv_proj')(residual)
            residual = nn.GroupNorm(num_groups=self.gn_num_groups,
                                    epsilon=1e-4,
                                    name='gn_proj')(residual)

        if self.bottleneck:
            x = StdConv(features, (1, 1), use_bias=False, name='conv1')(x)
            x = nn.GroupNorm(num_groups=self.gn_num_groups,
                             epsilon=1e-4,
                             name='gn1')(x)
            x = nn.relu(x)

        x = StdConv(features, (3, 3),
                    self.strides,
                    kernel_dilation=self.dilation,
                    use_bias=False,
                    name='conv2')(x)
        x = nn.GroupNorm(num_groups=self.gn_num_groups,
                         epsilon=1e-4,
                         name='gn2')(x)
        x = nn.relu(x)

        last_kernel = (1, 1) if self.bottleneck else (3, 3)
        x = StdConv(nout, last_kernel, use_bias=False, name='conv3')(x)
        x = nn.GroupNorm(num_groups=self.gn_num_groups,
                         epsilon=1e-4,
                         name='gn3',
                         scale_init=nn.initializers.zeros)(x)
        x = nn.relu(residual + x)

        return x
Example #10
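    # DDPM-style residual block: GroupNorm/act/conv, a per-channel bias from
    # the time embedding, GroupNorm/act/dropout/conv, and a 3x3 or 1x1 (NIN)
    # shortcut when the channel count changes.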
    def __call__(self, x, temb=None, train=True):
        B, H, W, C = x.shape
        out_ch = self.out_ch if self.out_ch else C
        h = self.act(nn.GroupNorm(num_groups=min(x.shape[-1] // 4, 32))(x))
        h = conv3x3(h, out_ch)
        # Add bias to each feature map conditioned on the time embedding
        if temb is not None:
            h += nn.Dense(out_ch,
                          kernel_init=default_init())(self.act(temb))[:, None,
                                                                      None, :]

        h = self.act(nn.GroupNorm(num_groups=min(h.shape[-1] // 4, 32))(h))
        h = nn.Dropout(self.dropout)(h, deterministic=not train)
        h = conv3x3(h, out_ch, init_scale=self.init_scale)
        if C != out_ch:
            if self.conv_shortcut:
                x = conv3x3(x, out_ch)
            else:
                x = NIN(out_ch)(x)

        if not self.skip_rescale:
            return x + h
        else:
            return (x + h) / np.sqrt(2.)
Example #11
  def __call__(self, z):
    shape_before_flattening, flatten_out_size = self.flatten_enc_shape()

    x = nn.Dense(flatten_out_size, name='fc1')(z)
    x = x.reshape((x.shape[0], *shape_before_flattening[1:]))
    
    hidden_dims = self.hidden_dims[::-1]
    # Build Decoder
    for h_dim in hidden_dims[:-1]:
      x = nn.ConvTranspose(features=h_dim, kernel_size=(3, 3), strides=(2, 2))(x)
      x = nn.GroupNorm()(x)
      x = nn.gelu(x)
    
    x = nn.ConvTranspose(features=3, kernel_size=(3, 3), strides=(2,2))(x)
    x = nn.sigmoid(x)
    return x
Example #12
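  # GroupNorm with 2 groups over 32 channels should equal normalizing each
  # 16-channel group over its spatial and channel axes, per example.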
  def test_group_norm(self):
    rng = random.PRNGKey(0)
    key1, key2 = random.split(rng)
    e = 1e-5
    x = random.normal(key1, (2, 5, 4, 4, 32))
    model_cls = nn.GroupNorm(num_groups=2, use_bias=False, use_scale=False, epsilon=e)

    y, _ = model_cls.init_with_output(key2, x)
    self.assertEqual(x.shape, y.shape)
    self.assertIsInstance(y, type(x))

    x_gr = x.reshape([2, 5, 4, 4, 2, 16])
    y_test = ((x_gr - x_gr.mean(axis=[1, 2, 3, 5], keepdims=True)) *
              jax.lax.rsqrt(x_gr.var(axis=[1, 2, 3, 5], keepdims=True) + e))
    y_test = y_test.reshape([2, 5, 4, 4, 32])

    np.testing.assert_allclose(y_test, y, atol=1e-4)
Example #13
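    # Spatial self-attention block: GroupNorm, 1x1 (NIN) projections to q/k/v,
    # softmax attention over all spatial positions, and a residual connection
    # that is optionally rescaled by 1/sqrt(2).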
    def __call__(self, x):
        B, H, W, C = x.shape
        h = nn.GroupNorm(num_groups=min(x.shape[-1] // 4, 32))(x)
        q = NIN(C)(h)
        k = NIN(C)(h)
        v = NIN(C)(h)

        w = jnp.einsum('bhwc,bHWc->bhwHW', q, k) * (int(C)**(-0.5))
        w = jnp.reshape(w, (B, H, W, H * W))
        w = jax.nn.softmax(w, axis=-1)
        w = jnp.reshape(w, (B, H, W, H, W))
        h = jnp.einsum('bhwHW,bHWc->bhwc', w, v)
        h = NIN(C, init_scale=self.init_scale)(h)
        if not self.skip_rescale:
            return x + h
        else:
            return (x + h) / np.sqrt(2.)
Example #14
    def __call__(self, inputs, *, train):
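        # Hybrid ViT: optional ResNet root and stages, patch embedding via a
        # single strided Conv, optional [cls] token, Transformer encoder, then
        # token or global-average pooling and a linear head.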

        x = inputs
        # (Possibly partial) ResNet root.
        if self.resnet is not None:
            width = int(64 * self.resnet.width_factor)

            # Root block.
            x = models_resnet.StdConv(features=width,
                                      kernel_size=(7, 7),
                                      strides=(2, 2),
                                      use_bias=False,
                                      name='conv_root')(x)
            x = nn.GroupNorm(name='gn_root')(x)
            x = nn.relu(x)
            x = nn.max_pool(x,
                            window_shape=(3, 3),
                            strides=(2, 2),
                            padding='SAME')

            # ResNet stages.
            if self.resnet.num_layers:
                x = models_resnet.ResNetStage(
                    block_size=self.resnet.num_layers[0],
                    nout=width,
                    first_stride=(1, 1),
                    name='block1')(x)
                for i, block_size in enumerate(self.resnet.num_layers[1:], 1):
                    x = models_resnet.ResNetStage(block_size=block_size,
                                                  nout=width * 2**i,
                                                  first_stride=(2, 2),
                                                  name=f'block{i + 1}')(x)

        n, h, w, c = x.shape

        # We can merge s2d+emb into a single conv; it's the same.
        x = nn.Conv(features=self.hidden_size,
                    kernel_size=self.patches.size,
                    strides=self.patches.size,
                    padding='VALID',
                    name='embedding')(x)

        # Here, x is a grid of embeddings.

        # Transformer.
        n, h, w, c = x.shape
        x = jnp.reshape(x, [n, h * w, c])

        # If we want to add a class token, add it here.
        if self.classifier == 'token':
            cls = self.param('cls', nn.initializers.zeros, (1, 1, c))
            cls = jnp.tile(cls, [n, 1, 1])
            x = jnp.concatenate([cls, x], axis=1)

        x = Encoder(name='Transformer', **self.transformer)(x, train=train)

        if self.classifier == 'token':
            x = x[:, 0]
        elif self.classifier == 'gap':
            x = jnp.mean(x, axis=list(range(1, x.ndim - 1)))  # (1,) or (1,2)
        else:
            raise ValueError(f'Invalid classifier={self.classifier}')

        if self.representation_size is not None:
            x = nn.Dense(features=self.representation_size,
                         name='pre_logits')(x)
            x = nn.tanh(x)
        else:
            x = IdentityLayer(name='pre_logits')(x)

        if self.num_classes:
            x = nn.Dense(features=self.num_classes,
                         name='head',
                         kernel_init=nn.initializers.zeros)(x)
        return x
Example #15
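    # NCSN++ / score-model U-Net: Fourier or positional noise embedding, a
    # downsampling path with attention at selected resolutions, a bottleneck,
    # and an upsampling path with optional progressive outputs. GroupNorm
    # group counts follow min(channels // 4, 32) throughout.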
    def __call__(self, x, time_cond, train=True):
        # config parsing
        config = self.config
        act = get_act(config)
        sigmas = utils.get_sigmas(config)

        nf = config.model.nf
        ch_mult = config.model.ch_mult
        num_res_blocks = config.model.num_res_blocks
        attn_resolutions = config.model.attn_resolutions
        dropout = config.model.dropout
        resamp_with_conv = config.model.resamp_with_conv
        num_resolutions = len(ch_mult)

        conditional = config.model.conditional  # noise-conditional
        fir = config.model.fir
        fir_kernel = config.model.fir_kernel
        skip_rescale = config.model.skip_rescale
        resblock_type = config.model.resblock_type.lower()
        progressive = config.model.progressive.lower()
        progressive_input = config.model.progressive_input.lower()
        embedding_type = config.model.embedding_type.lower()
        init_scale = config.model.init_scale
        assert progressive in ['none', 'output_skip', 'residual']
        assert progressive_input in ['none', 'input_skip', 'residual']
        assert embedding_type in ['fourier', 'positional']
        combine_method = config.model.progressive_combine.lower()
        combiner = functools.partial(Combine, method=combine_method)

        # timestep/noise_level embedding; only for continuous training
        if embedding_type == 'fourier':
            # Gaussian Fourier features embeddings.
            assert config.training.continuous, "Fourier features are only used for continuous training."
            used_sigmas = time_cond
            temb = layerspp.GaussianFourierProjection(
                embedding_size=nf,
                scale=config.model.fourier_scale)(jnp.log(used_sigmas))

        elif embedding_type == 'positional':
            # Sinusoidal positional embeddings.
            timesteps = time_cond
            used_sigmas = sigmas[time_cond.astype(jnp.int32)]
            temb = layers.get_timestep_embedding(timesteps, nf)
        else:
            raise ValueError(f'embedding type {embedding_type} unknown.')

        if conditional:
            temb = nn.Dense(nf * 4, kernel_init=default_initializer())(temb)
            temb = nn.Dense(nf * 4,
                            kernel_init=default_initializer())(act(temb))
        else:
            temb = None

        AttnBlock = functools.partial(layerspp.AttnBlockpp,
                                      init_scale=init_scale,
                                      skip_rescale=skip_rescale)

        Upsample = functools.partial(layerspp.Upsample,
                                     with_conv=resamp_with_conv,
                                     fir=fir,
                                     fir_kernel=fir_kernel)

        if progressive == 'output_skip':
            pyramid_upsample = functools.partial(layerspp.Upsample,
                                                 fir=fir,
                                                 fir_kernel=fir_kernel,
                                                 with_conv=False)
        elif progressive == 'residual':
            pyramid_upsample = functools.partial(layerspp.Upsample,
                                                 fir=fir,
                                                 fir_kernel=fir_kernel,
                                                 with_conv=True)

        Downsample = functools.partial(layerspp.Downsample,
                                       with_conv=resamp_with_conv,
                                       fir=fir,
                                       fir_kernel=fir_kernel)

        if progressive_input == 'input_skip':
            pyramid_downsample = functools.partial(layerspp.Downsample,
                                                   fir=fir,
                                                   fir_kernel=fir_kernel,
                                                   with_conv=False)
        elif progressive_input == 'residual':
            pyramid_downsample = functools.partial(layerspp.Downsample,
                                                   fir=fir,
                                                   fir_kernel=fir_kernel,
                                                   with_conv=True)

        if resblock_type == 'ddpm':
            ResnetBlock = functools.partial(ResnetBlockDDPM,
                                            act=act,
                                            dropout=dropout,
                                            init_scale=init_scale,
                                            skip_rescale=skip_rescale)

        elif resblock_type == 'biggan':
            ResnetBlock = functools.partial(ResnetBlockBigGAN,
                                            act=act,
                                            dropout=dropout,
                                            fir=fir,
                                            fir_kernel=fir_kernel,
                                            init_scale=init_scale,
                                            skip_rescale=skip_rescale)

        else:
            raise ValueError(f'resblock type {resblock_type} unrecognized.')

        if not config.data.centered:
            # If input data is in [0, 1]
            x = 2 * x - 1.

        # Downsampling block

        input_pyramid = None
        if progressive_input != 'none':
            input_pyramid = x

        hs = [conv3x3(x, nf)]
        for i_level in range(num_resolutions):
            # Residual blocks for this resolution
            for i_block in range(num_res_blocks):
                h = ResnetBlock(out_ch=nf * ch_mult[i_level])(hs[-1], temb,
                                                              train)
                if h.shape[1] in attn_resolutions:
                    h = AttnBlock()(h)
                hs.append(h)

            if i_level != num_resolutions - 1:
                if resblock_type == 'ddpm':
                    h = Downsample()(hs[-1])
                else:
                    h = ResnetBlock(down=True)(hs[-1], temb, train)

                if progressive_input == 'input_skip':
                    input_pyramid = pyramid_downsample()(input_pyramid)
                    h = combiner()(input_pyramid, h)

                elif progressive_input == 'residual':
                    input_pyramid = pyramid_downsample(
                        out_ch=h.shape[-1])(input_pyramid)
                    if skip_rescale:
                        input_pyramid = (input_pyramid + h) / np.sqrt(2.)
                    else:
                        input_pyramid = input_pyramid + h
                    h = input_pyramid

                hs.append(h)

        h = hs[-1]
        h = ResnetBlock()(h, temb, train)
        h = AttnBlock()(h)
        h = ResnetBlock()(h, temb, train)

        pyramid = None

        # Upsampling block
        for i_level in reversed(range(num_resolutions)):
            for i_block in range(num_res_blocks + 1):
                h = ResnetBlock(out_ch=nf * ch_mult[i_level])(jnp.concatenate(
                    [h, hs.pop()], axis=-1), temb, train)

            if h.shape[1] in attn_resolutions:
                h = AttnBlock()(h)

            if progressive != 'none':
                if i_level == num_resolutions - 1:
                    if progressive == 'output_skip':
                        pyramid = conv3x3(act(
                            nn.GroupNorm(num_groups=min(h.shape[-1] //
                                                        4, 32))(h)),
                                          x.shape[-1],
                                          bias=True,
                                          init_scale=init_scale)
                    elif progressive == 'residual':
                        pyramid = conv3x3(act(
                            nn.GroupNorm(num_groups=min(h.shape[-1] //
                                                        4, 32))(h)),
                                          h.shape[-1],
                                          bias=True)
                    else:
                        raise ValueError(f'{progressive} is not a valid name.')
                else:
                    if progressive == 'output_skip':
                        pyramid = pyramid_upsample()(pyramid)
                        pyramid = pyramid + conv3x3(act(
                            nn.GroupNorm(num_groups=min(h.shape[-1] //
                                                        4, 32))(h)),
                                                    x.shape[-1],
                                                    bias=True,
                                                    init_scale=init_scale)
                    elif progressive == 'residual':
                        pyramid = pyramid_upsample(out_ch=h.shape[-1])(pyramid)
                        if skip_rescale:
                            pyramid = (pyramid + h) / np.sqrt(2.)
                        else:
                            pyramid = pyramid + h
                        h = pyramid
                    else:
                        raise ValueError(f'{progressive} is not a valid name')

            if i_level != 0:
                if resblock_type == 'ddpm':
                    h = Upsample()(h)
                else:
                    h = ResnetBlock(up=True)(h, temb, train)

        assert not hs

        if progressive == 'output_skip':
            h = pyramid
        else:
            h = act(nn.GroupNorm(num_groups=min(h.shape[-1] // 4, 32))(h))
            h = conv3x3(h, x.shape[-1], init_scale=init_scale)

        if config.model.scale_by_sigma:
            used_sigmas = used_sigmas.reshape(
                (x.shape[0], *([1] * len(x.shape[1:]))))
            h = h / used_sigmas

        return h
Example #16
    def __call__(
            self,
            x: jnp.ndarray,
            train: bool = True,
            debug: bool = False) -> Union[jnp.ndarray, Dict[str, jnp.ndarray]]:
        """Applies the Bit ResNet model to the inputs.

    Args:
      x: Inputs to the model.
      train: Unused.
      debug: Unused.

    Returns:
       Un-normalized logits if `num_outputs` is provided, a dictionary with
       representations otherwise.
    """
        del train
        del debug
        if self.max_output_stride not in [4, 8, 16, 32]:
            raise ValueError('Only supports output strides of [4, 8, 16, 32]')

        blocks, bottleneck = _BLOCK_SIZE_OPTIONS[self.num_layers]

        width = int(64 * self.width_factor)

        # Root block.
        x = StdConv(width, (7, 7), (2, 2), use_bias=False, name='conv_root')(x)
        x = nn.GroupNorm(num_groups=self.gn_num_groups,
                         epsilon=1e-4,
                         name='gn_root')(x)
        x = nn.relu(x)
        x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
        representations = {'stem': x}

        # Stages.
        x = ResNetStage(blocks[0],
                        width,
                        first_stride=(1, 1),
                        bottleneck=bottleneck,
                        gn_num_groups=self.gn_num_groups,
                        name='block1')(x)
        stride = 4
        for i, block_size in enumerate(blocks[1:], 1):
            max_stride_reached = self.max_output_stride <= stride
            x = ResNetStage(block_size,
                            width * 2**i,
                            first_stride=(2, 2) if not max_stride_reached else
                            (1, 1),
                            first_dilation=(2, 2) if max_stride_reached else
                            (1, 1),
                            bottleneck=bottleneck,
                            gn_num_groups=self.gn_num_groups,
                            name=f'block{i + 1}')(x)
            if not max_stride_reached:
                stride *= 2
            representations[f'stage_{i + 1}'] = x

        # Head.
        x = jnp.mean(x, axis=(1, 2))
        x = IdentityLayer(name='pre_logits')(x)
        representations['pre_logits'] = x
        x = nn.Dense(self.num_outputs,
                     kernel_init=nn.initializers.zeros,
                     name='head')(x)
        return x, representations
Example #17
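 # Layers are declared in setup(): nn.GroupNorm(1) puts all channels in a
 # single group, while nn.GroupNorm(8) splits the features into 8 groups.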
 def setup(self):
     self.straight1 = nn.Conv(12, (3, 3), strides=(1, 1), use_bias=True)
     self.straight2 = nn.Conv(32, (3, 3), strides=(1, 1), use_bias=True)
     self.straight3 = nn.Conv(3, (3, 3), strides=(1, 1), use_bias=True)
     self.groupnorm1 = nn.GroupNorm(1)
     self.groupnorm2 = nn.GroupNorm(8)
Example #18
    def exec_op(self, op, input_values, deterministic, training, **_):
        """Executes an op according to the normal concrete semantics."""
        input_kwargs: Dict[str, Any] = op.input_kwargs
        op_kwargs: Dict[str, Any] = op.op_kwargs
        op_type = op.type
        if "name" not in op_kwargs:
            raise ValueError("Op kwargs must contain a name.")
        op_name = op_kwargs["name"]

        if op_type == OpType.NONE:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert not input_kwargs
            assert len(op_kwargs) == 1
            output_values = [lax.stop_gradient(input_value)]

        elif op_type == OpType.IDENTITY:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert not input_kwargs
            assert len(op_kwargs) == 1
            output_values = [input_value]

        # nn.linear

        elif op_type == OpType.DENSE:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert not input_kwargs
            output_values = [nn.Dense(**op_kwargs)(input_value)]

        elif op_type == OpType.DENSE_GENERAL:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert not input_kwargs
            assert 2 <= len(op_kwargs) <= 7
            output_values = [nn.DenseGeneral(**op_kwargs)(input_value)]

        elif op_type == OpType.CONV:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert not input_kwargs

            ks = op_kwargs["kernel_size"]
            if isinstance(ks, int):
                op_kwargs["kernel_size"] = (ks, ) * (input_value.ndim - 2)

            output_values = [nn.Conv(**op_kwargs)(input_value)]

        # others

        elif op_type == OpType.MUL:
            assert len(input_values) == 2
            assert not input_kwargs
            assert len(op_kwargs) == 1  # name
            output_values = [input_values[0] * input_values[1]]

        elif op_type in [OpType.ADD, OpType.STOCH_DEPTH]:
            assert len(op_kwargs) == 1  # name

            input_value = input_values[0]
            if "layer_drop_rate" in input_kwargs:
                assert len(input_kwargs) == 1
                survival_rate = 1 - input_kwargs["layer_drop_rate"]
                if survival_rate == 1.0 or deterministic:
                    pass
                else:
                    # Reuse dropout's rng stream.
                    rng = self.make_rng("dropout")
                    mask_shape = [input_value.shape[0]
                                  ] + [1] * (input_value.ndim - 1)
                    mask = random.bernoulli(rng,
                                            p=survival_rate,
                                            shape=mask_shape)
                    mask = jnp.tile(mask, [1] + list(input_value.shape[1:]))
                    input_value = lax.select(mask, input_value / survival_rate,
                                             jnp.zeros_like(input_value))
            else:
                assert not input_kwargs
                assert op_type == OpType.ADD

            if op_type == OpType.ADD:
                assert len(input_values) == 2
                output_values = [input_value + input_values[1]]
            else:
                assert len(input_values) == 1
                output_values = [input_value]

        elif op_type == OpType.SCALAR_MUL:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert len(input_kwargs) <= 1
            assert len(op_kwargs) == 1  # name
            if "const" in input_kwargs:
                c = input_kwargs["const"]
            else:
                c = 1 / jnp.sqrt(input_values[0].shape[-1])
            output_values = [input_values[0] * c]

        elif op_type == OpType.SCALAR_ADD:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert len(input_kwargs) <= 1
            assert len(op_kwargs) == 1  # name
            assert "const" in input_kwargs
            c = input_kwargs["const"]
            output_values = [input_values[0] + c]

        elif op_type == OpType.DOT_GENERAL:
            assert len(input_values) == 2
            assert 0 < len(input_kwargs) <= 3
            assert len(op_kwargs) == 1  # name
            output_values = [
                lax.dot_general(input_values[0], input_values[1],
                                **input_kwargs)
            ]

        elif op_type == OpType.EINSUM:
            assert len(input_values) == 2
            assert len(input_kwargs) == 1
            assert "sum" in input_kwargs
            output_values = [
                jnp.einsum(input_kwargs["sum"], input_values[0],
                           input_values[1])
            ]

        # nn.attention

        elif op_type == OpType.SELF_ATTENTION:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert not input_kwargs
            output_values = [
                nn.SelfAttention(**op_kwargs,
                                 deterministic=deterministic)(input_value)
            ]

        # nn.activation

        elif op_type in [
                OpType.RELU, OpType.GELU, OpType.SWISH, OpType.SIGMOID
        ]:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert not input_kwargs
            fn = {
                OpType.RELU: nn.relu,
                OpType.GELU: nn.gelu,
                OpType.SWISH: nn.swish,
                OpType.SIGMOID: nn.sigmoid
            }[op_type]
            output_values = [fn(input_value)]

        elif op_type == OpType.SOFTMAX:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert len(input_kwargs) <= 1
            output_values = [nn.softmax(input_value, **input_kwargs)]

        # nn.normalization

        elif op_type == OpType.BATCH_NORM:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert len(input_kwargs) <= 1
            if "use_running_average" not in input_kwargs:
                add_kwargs = {"use_running_average": not training}
            else:
                add_kwargs = {}
            output_values = [
                nn.BatchNorm(**op_kwargs)(input_value, **input_kwargs,
                                          **add_kwargs)
            ]

        elif op_type == OpType.LAYER_NORM:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert not input_kwargs
            output_values = [nn.LayerNorm(**op_kwargs)(input_value)]

        elif op_type == OpType.GROUP_NORM:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert not input_kwargs
            output_values = [nn.GroupNorm(**op_kwargs)(input_value)]

        # reshape operators

        elif op_type == OpType.RESHAPE:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert 0 < len(input_kwargs) < 3
            new_shape = input_kwargs.pop("new_shape")
            if new_shape[0] == "B":
                new_shape = (input_value.shape[0], ) + new_shape[1:]
            output_values = [
                jnp.reshape(input_value, new_shape, **input_kwargs)
            ]

        elif op_type == OpType.FLATTEN:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert not input_kwargs
            new_shape = (input_value.shape[0], -1)
            output_values = [jnp.reshape(input_value, new_shape)]

        elif op_type == OpType.TRANSPOSE:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert len(input_kwargs) == 1
            assert len(op_kwargs) == 1  # name
            output_values = [jnp.transpose(input_value, **input_kwargs)]

        # nn.stochastic

        elif op_type == OpType.DROPOUT:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert len(input_kwargs) <= 1
            output_values = [
                nn.Dropout(**op_kwargs)(input_value,
                                        deterministic=deterministic,
                                        **input_kwargs)
            ]

        # nn.pooling

        elif op_type == OpType.AVG_POOL or op_type == OpType.MAX_POOL:
            op_fn = nn.avg_pool if op_type == OpType.AVG_POOL else nn.max_pool
            assert len(input_values) == 1
            input_value = input_values[0]
            assert input_kwargs

            ws = input_kwargs["window_shape"]
            if isinstance(ws, int):
                ws = [ws] * (input_value.ndim - 2)
            new_ws = []
            for window_dim_shape, dim_shape in zip(ws, input_value.shape[1:]):
                if window_dim_shape == 0:
                    new_ws.append(dim_shape)
                else:
                    new_ws.append(window_dim_shape)
            input_kwargs["window_shape"] = tuple(new_ws)

            if "strides" in input_kwargs:
                s = input_kwargs["strides"]
                if isinstance(s, int):
                    input_kwargs["strides"] = (s, ) * (input_value.ndim - 2)

            output_values = [op_fn(input_value, **input_kwargs)]

        elif op_type == OpType.MEAN:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert input_kwargs
            output_values = [jnp.mean(input_value, **input_kwargs)]

        # new param

        elif op_type == OpType.PARAM:
            assert not input_values
            assert 0 < len(input_kwargs) <= 2
            init_fn = input_kwargs.pop("init_fn")

            init_fn_with_kwargs = functools.partial(init_fn, **input_kwargs)
            output_values = [self.param(op_name, init_fn_with_kwargs)]

        else:
            raise ValueError(f"op_type {op_type} not supported...")

        return output_values