def __call__(self, x):
  needs_projection = (
      x.shape[-1] != self.features * 4 or self.strides != (1, 1))
  residual = x
  if needs_projection:
    residual = StdConv(features=self.features * 4,
                       kernel_size=(1, 1),
                       strides=self.strides,
                       use_bias=False,
                       name='conv_proj')(residual)
    residual = nn.GroupNorm(name='gn_proj')(residual)

  y = StdConv(features=self.features,
              kernel_size=(1, 1),
              use_bias=False,
              name='conv1')(x)
  y = nn.GroupNorm(name='gn1')(y)
  y = nn.relu(y)
  y = StdConv(features=self.features,
              kernel_size=(3, 3),
              strides=self.strides,
              use_bias=False,
              name='conv2')(y)
  y = nn.GroupNorm(name='gn2')(y)
  y = nn.relu(y)
  y = StdConv(features=self.features * 4,
              kernel_size=(1, 1),
              use_bias=False,
              name='conv3')(y)
  y = nn.GroupNorm(name='gn3', scale_init=nn.initializers.zeros)(y)
  y = nn.relu(residual + y)
  return y
def __call__(self, x):
  x = nn.Conv(features=28, kernel_size=(3, 3), strides=(2, 2))(x)
  x = nn.GroupNorm(28)(x)
  x = nn.gelu(x)
  x = nn.Conv(features=64, kernel_size=(3, 3), strides=(2, 2))(x)
  x = nn.GroupNorm(32)(x)
  x = nn.gelu(x)
  x = nn.Conv(features=64, kernel_size=(3, 3), strides=(2, 2))(x)
  x = nn.GroupNorm(32)(x)
  x = nn.gelu(x)
  x = x.reshape((x.shape[0], -1))
  mean_x = nn.Dense(self.latent_dim, name='fc2_mean')(x)
  logvar_x = nn.Dense(self.latent_dim, name='fc2_logvar')(x)
  return mean_x, logvar_x
def __call__(self, z):
  shape_before_flattening, flatten_out_size = self.flatten_enc_shape()
  # print(shape_before_flattening, flatten_out_size)
  x = nn.Dense(flatten_out_size, name='fc1')(z)
  x = nn.gelu(x)
  x = x.reshape((x.shape[0], *shape_before_flattening[1:]))
  x = nn.ConvTranspose(features=32, kernel_size=(3, 3), strides=(2, 2))(x)
  x = nn.GroupNorm(32)(x)
  x = nn.gelu(x)
  x = nn.ConvTranspose(features=28, kernel_size=(3, 3), strides=(2, 2))(x)
  x = nn.GroupNorm(28)(x)
  x = nn.gelu(x)
  x = nn.ConvTranspose(features=1, kernel_size=(3, 3), strides=(2, 2))(x)
  return x
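# Hedged usage sketch connecting the encoder/decoder pair above. The module
# and field names (Encoder, Decoder, latent_dim, enc_params, dec_params) are
# assumptions for illustration only; the reparameterization step itself,
# z = mean + exp(0.5 * logvar) * eps, is the standard VAE sampling that
# consumes the (mean_x, logvar_x) outputs returned by the encoder.
import jax
import jax.numpy as jnp

def sample_z(rng, mean, logvar):
  # Draw z ~ N(mean, exp(logvar)) with a differentiable reparameterization.
  eps = jax.random.normal(rng, mean.shape)
  return mean + jnp.exp(0.5 * logvar) * eps

# rng = jax.random.PRNGKey(0)
# mean, logvar = Encoder(latent_dim=8).apply(enc_params, images)
# z = sample_z(rng, mean, logvar)
# recon = Decoder().apply(dec_params, z)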
def activation(x, train, apply_relu=True, name=''):
  x = nn.GroupNorm(name=name,
                   epsilon=1e-5,
                   num_groups=min(x.shape[-1] // 4, 32))(x)
  if apply_relu:
    x = jax.nn.relu(x)
  return x
def __call__(self, x, temb=None, train=True):
  B, H, W, C = x.shape
  out_ch = self.out_ch if self.out_ch else C
  h = self.act(nn.GroupNorm(num_groups=min(x.shape[-1] // 4, 32))(x))

  if self.up:
    if self.fir:
      h = up_or_down_sampling.upsample_2d(h, self.fir_kernel, factor=2)
      x = up_or_down_sampling.upsample_2d(x, self.fir_kernel, factor=2)
    else:
      h = up_or_down_sampling.naive_upsample_2d(h, factor=2)
      x = up_or_down_sampling.naive_upsample_2d(x, factor=2)
  elif self.down:
    if self.fir:
      h = up_or_down_sampling.downsample_2d(h, self.fir_kernel, factor=2)
      x = up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2)
    else:
      h = up_or_down_sampling.naive_downsample_2d(h, factor=2)
      x = up_or_down_sampling.naive_downsample_2d(x, factor=2)

  h = conv3x3(h, out_ch)
  # Add bias to each feature map conditioned on the time embedding
  if temb is not None:
    h += nn.Dense(out_ch,
                  kernel_init=default_init())(self.act(temb))[:, None, None, :]

  h = self.act(nn.GroupNorm(num_groups=min(h.shape[-1] // 4, 32))(h))
  h = nn.Dropout(self.dropout)(h, deterministic=not train)
  h = conv3x3(h, out_ch, init_scale=self.init_scale)
  if C != out_ch or self.up or self.down:
    x = conv1x1(x, out_ch)

  if not self.skip_rescale:
    return x + h
  else:
    return (x + h) / np.sqrt(2.)
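# Hedged shape sketch for the time-embedding bias used above: projecting temb
# to out_ch gives a (B, out_ch) array, and the [:, None, None, :] indexing
# broadcasts it to (B, 1, 1, out_ch), i.e. one bias per channel added to every
# spatial position of the (B, H, W, out_ch) feature map. Shapes are examples.
import jax.numpy as jnp

temb_proj = jnp.ones((2, 8))          # (B, out_ch) projection of the embedding
h = jnp.zeros((2, 4, 4, 8))           # (B, H, W, out_ch) feature map
h = h + temb_proj[:, None, None, :]   # broadcast add over H and W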
def __call__(self, x):
  # Build Encoder
  for h_dim in self.hidden_dims:
    x = nn.Conv(features=h_dim,
                kernel_size=(3, 3),
                strides=(2, 2),
                padding="valid")(x)
    x = nn.GroupNorm()(x)
    x = nn.gelu(x)

  x = x.reshape((x.shape[0], -1))
  mean_x = nn.Dense(self.latent_dim, name='fc2_mean')(x)
  logvar_x = nn.Dense(self.latent_dim, name='fc2_logvar')(x)
  return mean_x, logvar_x
def setup(self):
  activation = nn.softplus if self.activation == 'softplus' else nn.relu
  if self.group_norm:
    self.double_conv = Sequential([
        nn.Conv(self.mid_channels, kernel_size=(3, 3), use_bias=False),
        nn.GroupNorm(self.num_groups),
        activation,
        nn.Conv(self.out_channels, kernel_size=(3, 3), use_bias=False),
        nn.GroupNorm(self.num_groups),
        activation,
    ])
  else:
    self.double_conv = Sequential([
        nn.Conv(self.mid_channels, kernel_size=(3, 3), use_bias=False),
        nn.BatchNorm(use_running_average=self.test),
        activation,
        nn.Conv(self.out_channels, kernel_size=(3, 3), use_bias=False),
        nn.BatchNorm(use_running_average=self.test),
        activation,
    ])
def test_group_norm_raises(self):
  rng = random.PRNGKey(0)
  key1, key2 = random.split(rng)
  e = 1e-5
  x = random.normal(key1, (2, 5, 4, 4, 32))
  model_cls = nn.GroupNorm(num_groups=3,
                           use_bias=False,
                           use_scale=False,
                           epsilon=e)

  with self.assertRaises(ValueError):
    model_cls.init_with_output(key2, x)
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
  features = self.nout
  nout = self.nout * 4 if self.bottleneck else self.nout
  needs_projection = x.shape[-1] != nout or self.strides != (1, 1)
  residual = x
  if needs_projection:
    residual = StdConv(nout, (1, 1),
                       self.strides,
                       use_bias=False,
                       name='conv_proj')(residual)
    residual = nn.GroupNorm(num_groups=self.gn_num_groups,
                            epsilon=1e-4,
                            name='gn_proj')(residual)

  if self.bottleneck:
    x = StdConv(features, (1, 1), use_bias=False, name='conv1')(x)
    x = nn.GroupNorm(num_groups=self.gn_num_groups,
                     epsilon=1e-4,
                     name='gn1')(x)
    x = nn.relu(x)

  x = StdConv(features, (3, 3),
              self.strides,
              kernel_dilation=self.dilation,
              use_bias=False,
              name='conv2')(x)
  x = nn.GroupNorm(num_groups=self.gn_num_groups, epsilon=1e-4, name='gn2')(x)
  x = nn.relu(x)

  last_kernel = (1, 1) if self.bottleneck else (3, 3)
  x = StdConv(nout, last_kernel, use_bias=False, name='conv3')(x)
  x = nn.GroupNorm(num_groups=self.gn_num_groups,
                   epsilon=1e-4,
                   name='gn3',
                   scale_init=nn.initializers.zeros)(x)
  x = nn.relu(residual + x)
  return x
def __call__(self, x, temb=None, train=True):
  B, H, W, C = x.shape
  out_ch = self.out_ch if self.out_ch else C
  h = self.act(nn.GroupNorm(num_groups=min(x.shape[-1] // 4, 32))(x))
  h = conv3x3(h, out_ch)
  # Add bias to each feature map conditioned on the time embedding
  if temb is not None:
    h += nn.Dense(out_ch,
                  kernel_init=default_init())(self.act(temb))[:, None, None, :]

  h = self.act(nn.GroupNorm(num_groups=min(h.shape[-1] // 4, 32))(h))
  h = nn.Dropout(self.dropout)(h, deterministic=not train)
  h = conv3x3(h, out_ch, init_scale=self.init_scale)
  if C != out_ch:
    if self.conv_shortcut:
      x = conv3x3(x, out_ch)
    else:
      x = NIN(out_ch)(x)

  if not self.skip_rescale:
    return x + h
  else:
    return (x + h) / np.sqrt(2.)
def __call__(self, z):
  shape_before_flattening, flatten_out_size = self.flatten_enc_shape()
  x = nn.Dense(flatten_out_size, name='fc1')(z)
  x = x.reshape((x.shape[0], *shape_before_flattening[1:]))
  hidden_dims = self.hidden_dims[::-1]

  # Build Decoder
  for h_dim in range(len(hidden_dims) - 1):
    x = nn.ConvTranspose(features=hidden_dims[h_dim],
                         kernel_size=(3, 3),
                         strides=(2, 2))(x)
    x = nn.GroupNorm()(x)
    x = nn.gelu(x)

  x = nn.ConvTranspose(features=3, kernel_size=(3, 3), strides=(2, 2))(x)
  x = nn.sigmoid(x)
  return x
def test_group_norm(self):
  rng = random.PRNGKey(0)
  key1, key2 = random.split(rng)
  e = 1e-5
  x = random.normal(key1, (2, 5, 4, 4, 32))
  model_cls = nn.GroupNorm(num_groups=2,
                           use_bias=False,
                           use_scale=False,
                           epsilon=e)
  y, _ = model_cls.init_with_output(key2, x)
  self.assertEqual(x.shape, y.shape)
  self.assertIsInstance(y, type(x))

  x_gr = x.reshape([2, 5, 4, 4, 2, 16])
  y_test = ((x_gr - x_gr.mean(axis=[1, 2, 3, 5], keepdims=True)) *
            jax.lax.rsqrt(x_gr.var(axis=[1, 2, 3, 5], keepdims=True) + e))
  y_test = y_test.reshape([2, 5, 4, 4, 32])

  np.testing.assert_allclose(y_test, y, atol=1e-4)
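# Hedged standalone sketch of what the two tests above exercise: GroupNorm
# splits the channel axis into num_groups groups and normalizes each group
# over all non-batch axes, so num_groups must divide the channel count
# (32 % 3 != 0 is why the earlier test expects a ValueError). Shapes are
# examples chosen for illustration.
import jax
import flax.linen as nn

x = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 4, 32))
gn = nn.GroupNorm(num_groups=2, use_bias=False, use_scale=False)
y, _ = gn.init_with_output(jax.random.PRNGKey(1), x)
# y has shape (2, 4, 4, 32); each group of 16 channels is normalized
# over its channels and all spatial positions, per example.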
def __call__(self, x):
  B, H, W, C = x.shape
  h = nn.GroupNorm(num_groups=min(x.shape[-1] // 4, 32))(x)
  q = NIN(C)(h)
  k = NIN(C)(h)
  v = NIN(C)(h)

  w = jnp.einsum('bhwc,bHWc->bhwHW', q, k) * (int(C)**(-0.5))
  w = jnp.reshape(w, (B, H, W, H * W))
  w = jax.nn.softmax(w, axis=-1)
  w = jnp.reshape(w, (B, H, W, H, W))
  h = jnp.einsum('bhwHW,bHWc->bhwc', w, v)
  h = NIN(C, init_scale=self.init_scale)(h)
  if not self.skip_rescale:
    return x + h
  else:
    return (x + h) / np.sqrt(2.)
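# Hedged shape walk-through of the self-attention einsums above, using toy
# sizes (B=1, H=W=2, C=8): 'bhwc,bHWc->bhwHW' produces pairwise logits over
# all spatial positions, the softmax runs over the flattened H*W key axis, and
# 'bhwHW,bHWc->bhwc' aggregates the values back to the feature-map shape.
import jax
import jax.numpy as jnp

q = k = v = jax.random.normal(jax.random.PRNGKey(0), (1, 2, 2, 8))
w = jnp.einsum('bhwc,bHWc->bhwHW', q, k) * (8 ** -0.5)    # (1, 2, 2, 2, 2)
w = jax.nn.softmax(w.reshape(1, 2, 2, 4), axis=-1).reshape(1, 2, 2, 2, 2)
h = jnp.einsum('bhwHW,bHWc->bhwc', w, v)                   # (1, 2, 2, 8)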
def __call__(self, inputs, *, train):
  x = inputs
  # (Possibly partial) ResNet root.
  if self.resnet is not None:
    width = int(64 * self.resnet.width_factor)

    # Root block.
    x = models_resnet.StdConv(features=width,
                              kernel_size=(7, 7),
                              strides=(2, 2),
                              use_bias=False,
                              name='conv_root')(x)
    x = nn.GroupNorm(name='gn_root')(x)
    x = nn.relu(x)
    x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding='SAME')

    # ResNet stages.
    if self.resnet.num_layers:
      x = models_resnet.ResNetStage(
          block_size=self.resnet.num_layers[0],
          nout=width,
          first_stride=(1, 1),
          name='block1')(x)
      for i, block_size in enumerate(self.resnet.num_layers[1:], 1):
        x = models_resnet.ResNetStage(block_size=block_size,
                                      nout=width * 2**i,
                                      first_stride=(2, 2),
                                      name=f'block{i + 1}')(x)

  n, h, w, c = x.shape

  # We can merge s2d+emb into a single conv; it's the same.
  x = nn.Conv(features=self.hidden_size,
              kernel_size=self.patches.size,
              strides=self.patches.size,
              padding='VALID',
              name='embedding')(x)

  # Here, x is a grid of embeddings.

  # Transformer.
  n, h, w, c = x.shape
  x = jnp.reshape(x, [n, h * w, c])

  # If we want to add a class token, add it here.
  if self.classifier == 'token':
    cls = self.param('cls', nn.initializers.zeros, (1, 1, c))
    cls = jnp.tile(cls, [n, 1, 1])
    x = jnp.concatenate([cls, x], axis=1)

  x = Encoder(name='Transformer', **self.transformer)(x, train=train)

  if self.classifier == 'token':
    x = x[:, 0]
  elif self.classifier == 'gap':
    x = jnp.mean(x, axis=list(range(1, x.ndim - 1)))  # (1,) or (1,2)
  else:
    raise ValueError(f'Invalid classifier={self.classifier}')

  if self.representation_size is not None:
    x = nn.Dense(features=self.representation_size, name='pre_logits')(x)
    x = nn.tanh(x)
  else:
    x = IdentityLayer(name='pre_logits')(x)

  if self.num_classes:
    x = nn.Dense(features=self.num_classes,
                 name='head',
                 kernel_init=nn.initializers.zeros)(x)
  return x
def __call__(self, x, time_cond, train=True):
  # config parsing
  config = self.config
  act = get_act(config)
  sigmas = utils.get_sigmas(config)

  nf = config.model.nf
  ch_mult = config.model.ch_mult
  num_res_blocks = config.model.num_res_blocks
  attn_resolutions = config.model.attn_resolutions
  dropout = config.model.dropout
  resamp_with_conv = config.model.resamp_with_conv
  num_resolutions = len(ch_mult)

  conditional = config.model.conditional  # noise-conditional
  fir = config.model.fir
  fir_kernel = config.model.fir_kernel
  skip_rescale = config.model.skip_rescale
  resblock_type = config.model.resblock_type.lower()
  progressive = config.model.progressive.lower()
  progressive_input = config.model.progressive_input.lower()
  embedding_type = config.model.embedding_type.lower()
  init_scale = config.model.init_scale
  assert progressive in ['none', 'output_skip', 'residual']
  assert progressive_input in ['none', 'input_skip', 'residual']
  assert embedding_type in ['fourier', 'positional']
  combine_method = config.model.progressive_combine.lower()
  combiner = functools.partial(Combine, method=combine_method)

  # timestep/noise_level embedding; only for continuous training
  if embedding_type == 'fourier':
    # Gaussian Fourier features embeddings.
    assert config.training.continuous, (
        "Fourier features are only used for continuous training.")
    used_sigmas = time_cond
    temb = layerspp.GaussianFourierProjection(
        embedding_size=nf,
        scale=config.model.fourier_scale)(jnp.log(used_sigmas))
  elif embedding_type == 'positional':
    # Sinusoidal positional embeddings.
    timesteps = time_cond
    used_sigmas = sigmas[time_cond.astype(jnp.int32)]
    temb = layers.get_timestep_embedding(timesteps, nf)
  else:
    raise ValueError(f'embedding type {embedding_type} unknown.')

  if conditional:
    temb = nn.Dense(nf * 4, kernel_init=default_initializer())(temb)
    temb = nn.Dense(nf * 4, kernel_init=default_initializer())(act(temb))
  else:
    temb = None

  AttnBlock = functools.partial(layerspp.AttnBlockpp,
                                init_scale=init_scale,
                                skip_rescale=skip_rescale)

  Upsample = functools.partial(layerspp.Upsample,
                               with_conv=resamp_with_conv,
                               fir=fir,
                               fir_kernel=fir_kernel)

  if progressive == 'output_skip':
    pyramid_upsample = functools.partial(layerspp.Upsample,
                                         fir=fir,
                                         fir_kernel=fir_kernel,
                                         with_conv=False)
  elif progressive == 'residual':
    pyramid_upsample = functools.partial(layerspp.Upsample,
                                         fir=fir,
                                         fir_kernel=fir_kernel,
                                         with_conv=True)

  Downsample = functools.partial(layerspp.Downsample,
                                 with_conv=resamp_with_conv,
                                 fir=fir,
                                 fir_kernel=fir_kernel)

  if progressive_input == 'input_skip':
    pyramid_downsample = functools.partial(layerspp.Downsample,
                                           fir=fir,
                                           fir_kernel=fir_kernel,
                                           with_conv=False)
  elif progressive_input == 'residual':
    pyramid_downsample = functools.partial(layerspp.Downsample,
                                           fir=fir,
                                           fir_kernel=fir_kernel,
                                           with_conv=True)

  if resblock_type == 'ddpm':
    ResnetBlock = functools.partial(ResnetBlockDDPM,
                                    act=act,
                                    dropout=dropout,
                                    init_scale=init_scale,
                                    skip_rescale=skip_rescale)
  elif resblock_type == 'biggan':
    ResnetBlock = functools.partial(ResnetBlockBigGAN,
                                    act=act,
                                    dropout=dropout,
                                    fir=fir,
                                    fir_kernel=fir_kernel,
                                    init_scale=init_scale,
                                    skip_rescale=skip_rescale)
  else:
    raise ValueError(f'resblock type {resblock_type} unrecognized.')

  if not config.data.centered:
    # If input data is in [0, 1]
    x = 2 * x - 1.
  # Downsampling block
  input_pyramid = None
  if progressive_input != 'none':
    input_pyramid = x

  hs = [conv3x3(x, nf)]
  for i_level in range(num_resolutions):
    # Residual blocks for this resolution
    for i_block in range(num_res_blocks):
      h = ResnetBlock(out_ch=nf * ch_mult[i_level])(hs[-1], temb, train)
      if h.shape[1] in attn_resolutions:
        h = AttnBlock()(h)
      hs.append(h)

    if i_level != num_resolutions - 1:
      if resblock_type == 'ddpm':
        h = Downsample()(hs[-1])
      else:
        h = ResnetBlock(down=True)(hs[-1], temb, train)

      if progressive_input == 'input_skip':
        input_pyramid = pyramid_downsample()(input_pyramid)
        h = combiner()(input_pyramid, h)
      elif progressive_input == 'residual':
        input_pyramid = pyramid_downsample(out_ch=h.shape[-1])(input_pyramid)
        if skip_rescale:
          input_pyramid = (input_pyramid + h) / np.sqrt(2.)
        else:
          input_pyramid = input_pyramid + h
        h = input_pyramid

      hs.append(h)

  h = hs[-1]
  h = ResnetBlock()(h, temb, train)
  h = AttnBlock()(h)
  h = ResnetBlock()(h, temb, train)

  pyramid = None

  # Upsampling block
  for i_level in reversed(range(num_resolutions)):
    for i_block in range(num_res_blocks + 1):
      h = ResnetBlock(out_ch=nf * ch_mult[i_level])(
          jnp.concatenate([h, hs.pop()], axis=-1), temb, train)

    if h.shape[1] in attn_resolutions:
      h = AttnBlock()(h)

    if progressive != 'none':
      if i_level == num_resolutions - 1:
        if progressive == 'output_skip':
          pyramid = conv3x3(
              act(nn.GroupNorm(num_groups=min(h.shape[-1] // 4, 32))(h)),
              x.shape[-1],
              bias=True,
              init_scale=init_scale)
        elif progressive == 'residual':
          pyramid = conv3x3(
              act(nn.GroupNorm(num_groups=min(h.shape[-1] // 4, 32))(h)),
              h.shape[-1],
              bias=True)
        else:
          raise ValueError(f'{progressive} is not a valid name.')
      else:
        if progressive == 'output_skip':
          pyramid = pyramid_upsample()(pyramid)
          pyramid = pyramid + conv3x3(
              act(nn.GroupNorm(num_groups=min(h.shape[-1] // 4, 32))(h)),
              x.shape[-1],
              bias=True,
              init_scale=init_scale)
        elif progressive == 'residual':
          pyramid = pyramid_upsample(out_ch=h.shape[-1])(pyramid)
          if skip_rescale:
            pyramid = (pyramid + h) / np.sqrt(2.)
          else:
            pyramid = pyramid + h
          h = pyramid
        else:
          raise ValueError(f'{progressive} is not a valid name')

    if i_level != 0:
      if resblock_type == 'ddpm':
        h = Upsample()(h)
      else:
        h = ResnetBlock(up=True)(h, temb, train)

  assert not hs

  if progressive == 'output_skip':
    h = pyramid
  else:
    h = act(nn.GroupNorm(num_groups=min(h.shape[-1] // 4, 32))(h))
    h = conv3x3(h, x.shape[-1], init_scale=init_scale)

  if config.model.scale_by_sigma:
    used_sigmas = used_sigmas.reshape(
        (x.shape[0], *([1] * len(x.shape[1:]))))
    h = h / used_sigmas

  return h
def __call__(
    self,
    x: jnp.ndarray,
    train: bool = True,
    debug: bool = False) -> Union[jnp.ndarray, Dict[str, jnp.ndarray]]:
  """Applies the Bit ResNet model to the inputs.

  Args:
    x: Inputs to the model.
    train: Unused.
    debug: Unused.

  Returns:
    Un-normalized logits if `num_outputs` is provided, a dictionary with
    representations otherwise.
  """
  del train
  del debug
  if self.max_output_stride not in [4, 8, 16, 32]:
    raise ValueError('Only supports output strides of [4, 8, 16, 32]')

  blocks, bottleneck = _BLOCK_SIZE_OPTIONS[self.num_layers]

  width = int(64 * self.width_factor)

  # Root block.
  x = StdConv(width, (7, 7), (2, 2), use_bias=False, name='conv_root')(x)
  x = nn.GroupNorm(num_groups=self.gn_num_groups,
                   epsilon=1e-4,
                   name='gn_root')(x)
  x = nn.relu(x)
  x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
  representations = {'stem': x}

  # Stages.
  x = ResNetStage(blocks[0],
                  width,
                  first_stride=(1, 1),
                  bottleneck=bottleneck,
                  gn_num_groups=self.gn_num_groups,
                  name='block1')(x)
  stride = 4
  for i, block_size in enumerate(blocks[1:], 1):
    max_stride_reached = self.max_output_stride <= stride
    x = ResNetStage(block_size,
                    width * 2**i,
                    first_stride=(2, 2) if not max_stride_reached else (1, 1),
                    first_dilation=(2, 2) if max_stride_reached else (1, 1),
                    bottleneck=bottleneck,
                    gn_num_groups=self.gn_num_groups,
                    name=f'block{i + 1}')(x)
    if not max_stride_reached:
      stride *= 2
    representations[f'stage_{i + 1}'] = x

  # Head.
  x = jnp.mean(x, axis=(1, 2))
  x = IdentityLayer(name='pre_logits')(x)
  representations['pre_logits'] = x
  x = nn.Dense(self.num_outputs,
               kernel_init=nn.initializers.zeros,
               name='head')(x)
  return x, representations
def setup(self):
  self.straight1 = nn.Conv(12, (3, 3), strides=(1, 1), use_bias=True)
  self.straight2 = nn.Conv(32, (3, 3), strides=(1, 1), use_bias=True)
  self.straight3 = nn.Conv(3, (3, 3), strides=(1, 1), use_bias=True)
  self.groupnorm1 = nn.GroupNorm(1)
  self.groupnorm2 = nn.GroupNorm(8)
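# Hedged side note on the group counts above: nn.GroupNorm(1) puts every
# channel in a single group, so each example is normalized over all spatial
# positions and channels jointly, while nn.GroupNorm(8) yields 8 groups of
# 4 channels if it is applied to the 32-channel output of straight2 (the
# pairing of norms to convs is an assumption here). Quick shape check:
import jax
import flax.linen as nn

x = jax.random.normal(jax.random.PRNGKey(0), (2, 8, 8, 12))
y, _ = nn.GroupNorm(num_groups=1).init_with_output(jax.random.PRNGKey(1), x)
assert y.shape == x.shape  # normalization never changes the shape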
def exec_op(self, op, input_values, deterministic, training, **_):
  """Executes an op according to the normal concrete semantics."""
  input_kwargs: Dict[str, Any] = op.input_kwargs
  op_kwargs: Dict[str, Any] = op.op_kwargs
  op_type = op.type

  if "name" not in op_kwargs:
    raise ValueError("Op kwargs must contain a name.")
  op_name = op_kwargs["name"]

  if op_type == OpType.NONE:
    assert len(input_values) == 1
    input_value = input_values[0]
    assert not input_kwargs
    assert len(op_kwargs) == 1
    output_values = [lax.stop_gradient(input_value)]

  elif op_type == OpType.IDENTITY:
    assert len(input_values) == 1
    input_value = input_values[0]
    assert not input_kwargs
    assert len(op_kwargs) == 1
    output_values = [input_value]

  # nn.linear

  elif op_type == OpType.DENSE:
    assert len(input_values) == 1
    input_value = input_values[0]
    assert not input_kwargs
    output_values = [nn.Dense(**op_kwargs)(input_value)]

  elif op_type == OpType.DENSE_GENERAL:
    assert len(input_values) == 1
    input_value = input_values[0]
    assert not input_kwargs
    assert 2 <= len(op_kwargs) <= 7
    output_values = [nn.DenseGeneral(**op_kwargs)(input_value)]

  elif op_type == OpType.CONV:
    assert len(input_values) == 1
    input_value = input_values[0]
    assert not input_kwargs
    ks = op_kwargs["kernel_size"]
    if isinstance(ks, int):
      op_kwargs["kernel_size"] = (ks,) * (input_value.ndim - 2)
    output_values = [nn.Conv(**op_kwargs)(input_value)]

  # others

  elif op_type == OpType.MUL:
    assert len(input_values) == 2
    assert not input_kwargs
    assert len(op_kwargs) == 1  # name
    output_values = [input_values[0] * input_values[1]]

  elif op_type in [OpType.ADD, OpType.STOCH_DEPTH]:
    assert len(op_kwargs) == 1  # name
    input_value = input_values[0]
    if "layer_drop_rate" in input_kwargs:
      assert len(input_kwargs) == 1
      survival_rate = 1 - input_kwargs["layer_drop_rate"]
      if survival_rate == 1.0 or deterministic:
        pass
      else:
        # Reuse dropout's rng stream.
rng = self.make_rng("dropout") mask_shape = [input_value.shape[0] ] + [1] * (input_value.ndim - 1) mask = random.bernoulli(rng, p=survival_rate, shape=mask_shape) mask = jnp.tile(mask, [1] + list(input_value.shape[1:])) input_value = lax.select(mask, input_value / survival_rate, jnp.zeros_like(input_value)) else: assert not input_kwargs assert op_type == OpType.ADD if op_type == OpType.ADD: assert len(input_values) == 2 output_values = [input_value + input_values[1]] else: assert len(input_values) == 1 output_values = [input_value] elif op_type == OpType.SCALAR_MUL: assert len(input_values) == 1 input_value = input_values[0] assert len(input_kwargs) <= 1 assert len(op_kwargs) == 1 # name if "const" in input_kwargs: c = input_kwargs["const"] else: c = 1 / jnp.sqrt(input_values[0].shape[-1]) output_values = [input_values[0] * c] elif op_type == OpType.SCALAR_ADD: assert len(input_values) == 1 input_value = input_values[0] assert len(input_kwargs) <= 1 assert len(op_kwargs) == 1 # name assert "const" in input_kwargs c = input_kwargs["const"] output_values = [input_values[0] + c] elif op_type == OpType.DOT_GENERAL: assert len(input_values) == 2 assert 0 < len(input_kwargs) <= 3 assert len(op_kwargs) == 1 # name output_values = [ lax.dot_general(input_values[0], input_values[1], **input_kwargs) ] elif op_type == OpType.EINSUM: assert len(input_values) == 2 assert len(input_kwargs) == 1 assert "sum" in input_kwargs output_values = [ jnp.einsum(input_kwargs["sum"], input_values[0], input_values[1]) ] # nn.attention elif op_type == OpType.SELF_ATTENTION: assert len(input_values) == 1 input_value = input_values[0] assert not input_kwargs output_values = [ nn.SelfAttention(**op_kwargs, deterministic=deterministic)(input_value) ] # nn.activation elif op_type in [ OpType.RELU, OpType.GELU, OpType.SWISH, OpType.SIGMOID ]: assert len(input_values) == 1 input_value = input_values[0] assert not input_kwargs fn = { OpType.RELU: nn.relu, OpType.GELU: nn.gelu, OpType.SWISH: nn.swish, OpType.SIGMOID: nn.sigmoid }[op_type] output_values = [fn(input_value)] elif op_type == OpType.SOFTMAX: assert len(input_values) == 1 input_value = input_values[0] assert len(input_kwargs) <= 1 output_values = [nn.softmax(input_value, **input_kwargs)] # nn.normalization elif op_type == OpType.BATCH_NORM: assert len(input_values) == 1 input_value = input_values[0] assert len(input_kwargs) <= 1 add_kwargs = {} if "use_running_average" not in input_kwargs: add_kwargs = {"use_running_average": not training} else: add_kwargs = {} output_values = [ nn.BatchNorm(**op_kwargs)(input_value, **input_kwargs, **add_kwargs) ] elif op_type == OpType.LAYER_NORM: assert len(input_values) == 1 input_value = input_values[0] assert not input_kwargs output_values = [nn.LayerNorm(**op_kwargs)(input_value)] elif op_type == OpType.GROUP_NORM: assert len(input_values) == 1 input_value = input_values[0] assert not input_kwargs output_values = [nn.GroupNorm(**op_kwargs)(input_value)] # reshape operators elif op_type == OpType.RESHAPE: assert len(input_values) == 1 input_value = input_values[0] assert 0 < len(input_kwargs) < 3 new_shape = input_kwargs.pop("new_shape") if new_shape[0] == "B": new_shape = (input_value.shape[0], ) + new_shape[1:] output_values = [ jnp.reshape(input_value, new_shape, **input_kwargs) ] elif op_type == OpType.FLATTEN: assert len(input_values) == 1 input_value = input_values[0] assert not input_kwargs new_shape = (input_value.shape[0], -1) output_values = [jnp.reshape(input_value, new_shape)] elif op_type == OpType.TRANSPOSE: 
    assert len(input_values) == 1
    input_value = input_values[0]
    assert len(input_kwargs) == 1
    assert len(op_kwargs) == 1  # name
    output_values = [jnp.transpose(input_value, **input_kwargs)]

  # nn.stochastic

  elif op_type == OpType.DROPOUT:
    assert len(input_values) == 1
    input_value = input_values[0]
    assert len(input_kwargs) <= 1
    output_values = [
        nn.Dropout(**op_kwargs)(input_value,
                                deterministic=deterministic,
                                **input_kwargs)
    ]

  # nn.pooling

  elif op_type == OpType.AVG_POOL or op_type == OpType.MAX_POOL:
    op_fn = nn.avg_pool if op_type == OpType.AVG_POOL else nn.max_pool
    assert len(input_values) == 1
    input_value = input_values[0]
    assert input_kwargs
    ws = input_kwargs["window_shape"]
    if isinstance(ws, int):
      ws = [ws] * (input_value.ndim - 2)
    new_ws = []
    for window_dim_shape, dim_shape in zip(ws, input_value.shape[1:]):
      if window_dim_shape == 0:
        new_ws.append(dim_shape)
      else:
        new_ws.append(window_dim_shape)
    input_kwargs["window_shape"] = tuple(new_ws)
    if "strides" in input_kwargs:
      s = input_kwargs["strides"]
      if isinstance(s, int):
        input_kwargs["strides"] = (s,) * (input_value.ndim - 2)
    output_values = [op_fn(input_value, **input_kwargs)]

  elif op_type == OpType.MEAN:
    assert len(input_values) == 1
    input_value = input_values[0]
    assert input_kwargs
    output_values = [jnp.mean(input_value, **input_kwargs)]

  # new param

  elif op_type == OpType.PARAM:
    assert not input_values
    assert 0 < len(input_kwargs) <= 2
    init_fn = input_kwargs.pop("init_fn")
    init_fn_with_kwargs = functools.partial(init_fn, **input_kwargs)
    output_values = [self.param(op_name, init_fn_with_kwargs)]

  else:
    raise ValueError(f"op_type {op_type} not supported...")

  return output_values
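# Hedged standalone sketch of the STOCH_DEPTH branch in exec_op above
# (stochastic depth): draw one Bernoulli survival decision per example,
# broadcast it across the non-batch dims, and rescale surviving examples by
# 1/survival_rate so the expected value is unchanged. Shapes and the survival
# rate below are illustrative assumptions.
import jax.numpy as jnp
from jax import lax, random

rng = random.PRNGKey(0)
x = jnp.ones((4, 8, 8, 3))
survival_rate = 0.9
mask_shape = [x.shape[0]] + [1] * (x.ndim - 1)
mask = random.bernoulli(rng, p=survival_rate, shape=mask_shape)
mask = jnp.tile(mask, [1] + list(x.shape[1:]))
out = lax.select(mask, x / survival_rate, jnp.zeros_like(x))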