def __call__(self, x):
  x = nn.Conv(features=32, kernel_size=(3, 3))(x)
  x = nn.relu(x)
  x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
  x = nn.Conv(features=64, kernel_size=(3, 3))(x)
  x = nn.relu(x)
  x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
  x = x.reshape((x.shape[0], -1))  # flatten
  x = nn.Dense(features=256)(x)
  x = nn.relu(x)
  x = nn.Dense(features=10)(x)
  return x

def __call__(self, x):
  x = nn.Conv(features=32, kernel_size=(3, 3))(x)
  x = nn.relu(x)
  x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
  x = nn.Conv(features=64, kernel_size=(3, 3))(x)
  x = nn.relu(x)
  x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
  x = x.reshape((x.shape[0], -1))  # Flatten
  x = nn.Dense(features=256)(x)
  x = nn.relu(x)
  x = nn.Dense(features=10)(x)  # There are 10 classes in MNIST
  x = nn.log_softmax(x)
  return x

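# A minimal, hypothetical sketch of how a __call__ like the one above is
# wrapped in an nn.Module and exercised end to end; the class name CNN and
# the MNIST-shaped dummy input are assumptions, not part of the original.
import jax
import jax.numpy as jnp
import flax.linen as nn

class CNN(nn.Module):

  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))
    x = nn.Dense(features=10)(x)
    return nn.log_softmax(x)

variables = CNN().init(jax.random.PRNGKey(0), jnp.ones((1, 28, 28, 1)))
log_probs = CNN().apply(variables, jnp.ones((8, 28, 28, 1)))  # (8, 10)
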
def __call__(self, x):
  x = nn.Conv(features=16, kernel_size=(3, 3))(x)
  x = nn.relu(x)
  x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
  x = nn.Conv(features=16, kernel_size=(3, 3))(x)
  x = nn.relu(x)
  x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
  x = x.reshape((x.shape[0], -1))  # flatten
  x = nn.Dense(features=256)(x)  # return intermediate output
  x = TestDense()(x)
  x = nn.Dense(features=N_CLASSES)(x)
  return x

def __call__(self, x, with_classifier=True):
  x = nn.Conv(features=32, kernel_size=(3, 3))(x)
  x = nn.relu(x)
  x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
  x = nn.Conv(features=64, kernel_size=(3, 3))(x)
  x = nn.relu(x)
  x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
  x = x.reshape((x.shape[0], -1))  # flatten
  x = nn.Dense(features=256)(x)
  x = nn.relu(x)
  if not with_classifier:
    return x
  x = nn.Dense(features=10)(x)
  x = nn.log_softmax(x)
  return x

def __call__(self, x):
  B, H, W, C = x.shape
  out_ch = self.out_ch if self.out_ch else C
  if not self.fir:
    if self.with_conv:
      x = conv3x3(x, out_ch, stride=2)
    else:
      x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2), padding='SAME')
  else:
    if not self.with_conv:
      x = up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2)
    else:
      x = up_or_down_sampling.Conv2d(
          out_ch,
          kernel=3,
          down=True,
          resample_kernel=self.fir_kernel,
          use_bias=True,
          kernel_init=default_init())(x)
  assert x.shape == (B, H // 2, W // 2, out_ch)
  return x

def __call__(self, x, with_classifier=True):
  x = nn.Conv(features=32, kernel_size=(3, 3))(x)
  x = nn.relu(x)
  x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
  x = nn.Conv(features=64, kernel_size=(3, 3))(x)
  x = nn.relu(x)
  x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
  # TODO: replace np.prod(x.shape[1:]) with -1 once we fix shape_polymorphism
  x = x.reshape((x.shape[0], np.prod(x.shape[1:])))  # flatten
  x = nn.Dense(features=256)(x)
  x = nn.relu(x)
  if not with_classifier:
    return x
  x = nn.Dense(features=10)(x)
  x = nn.log_softmax(x)
  return x

def __call__(self, x):
  # Helper macro.
  R_ = lambda hidden_: ResidualUnit(
      hidden_features=hidden_,
      norm=self.norm,
      training=self.training,
      activation=nn.gelu)
  # First filter to make features.
  h = nn.Conv(
      features=self.hidden * self.alpha,
      use_bias=False,
      kernel_size=(3, 3),
      kernel_init=INITS[self.kernel_init])(x)
  h = NORMS[self.norm](use_running_average=not self.training)(h)
  h = nn.gelu(h)
  # 2 stages of continuous segments:
  h = ResidualStitch(
      hidden_features=self.hidden * self.alpha,
      output_features=self.hidden * self.alpha,
      strides=(1, 1),
      norm=self.norm,
      training=self.training,
      activation=nn.gelu)(h)
  h = StatefulContinuousBlock(
      R=R_(self.hidden * self.alpha),
      scheme=self.scheme,
      n_step=self.n_step,
      n_basis=self.n_basis,
      basis=self.basis,
      training=self.training)(h)
  # Pool and linearly classify:
  h = NORMS[self.norm](use_running_average=not self.training)(h)
  h = nn.gelu(h)
  h = nn.avg_pool(h, window_shape=(8, 8), strides=(8, 8))
  h = h.reshape((h.shape[0], -1))
  h = nn.Dense(features=self.n_classes)(h)
  return nn.log_softmax(h)  # log-probabilities, not a plain softmax

def __call__(self, x):
  branch1x1 = self.conv_block(64, kernel_size=(1, 1), name='branch1x1')(x)

  branch5x5 = self.conv_block(48, kernel_size=(1, 1), name='branch5x5_1')(x)
  branch5x5 = self.conv_block(
      64, kernel_size=(5, 5), padding=[(2, 2), (2, 2)],
      name='branch5x5_2')(branch5x5)

  branch3x3dbl = self.conv_block(
      64, kernel_size=(1, 1), name='branch3x3dbl_1')(x)
  branch3x3dbl = self.conv_block(
      96, kernel_size=(3, 3), padding=[(1, 1), (1, 1)],
      name='branch3x3dbl_2')(branch3x3dbl)
  branch3x3dbl = self.conv_block(
      96, kernel_size=(3, 3), padding=[(1, 1), (1, 1)],
      name='branch3x3dbl_3')(branch3x3dbl)

  branch_pool = nn.avg_pool(
      x, (3, 3), strides=(1, 1), padding=[(1, 1), (1, 1)])
  branch_pool = self.conv_block(
      self.pool_features, kernel_size=(1, 1), name='branch_pool')(branch_pool)

  outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
  return jnp.concatenate(outputs, 3)

def __call__(self, inputs): """Passes the input through a bottleneck transformer block. Arguments: inputs: [batch_size, height, width, dim] Returns: output: [batch_size, height, width, dim * config.projection_factor] """ residual = inputs cfg = self.config y = self.conv(self.filters, kernel_size=(1, 1))(inputs) y = self.norm()(y) y = cfg.activation_fn(y) y = BoTMHSA(config=cfg)(y) if self.strides == (2, 2): y = nn.avg_pool(y, window_shape=(2, 2), strides=self.strides, padding='SAME') y = self.norm()(y) y = cfg.activation_fn(y) y = self.conv(self.filters * cfg.projection_factor, kernel_size=(1, 1))(y) y = self.norm(scale_init=initializers.zeros)(y) if self.strides == (2, 2) or residual.shape != y.shape: residual = self.conv(self.filters * cfg.projection_factor, kernel_size=(1, 1), strides=self.strides)(residual) residual = self.norm()(residual) residual = cfg.activation_fn(residual) y = cfg.activation_fn(residual + y) return y
def __call__(self, inputs): """Applies spherical pooling. Args: inputs: An array of dimensions (batch_size, resolution, resolution, n_spins_in, n_channels_in). Returns: An array of dimensions (batch_size, resolution // stride, resolution // stride, n_spins_in, n_channels_in). """ # We use variables to cache the in/out weights. resolution_in = inputs.shape[1] resolution_out = resolution_in // self.stride weights_in = sphere_utils.sphere_quadrature_weights(resolution_in) weights_out = sphere_utils.sphere_quadrature_weights(resolution_out) weighted = inputs * jnp.expand_dims(weights_in, (0, 2, 3, 4)) pooled = nn.avg_pool(weighted, window_shape=(self.stride, self.stride, 1), strides=(self.stride, self.stride, 1)) # This was average pooled. We multiply by stride**2 to obtain the sum # pooled, then divide by output weights to get the weighted average. pooled = (pooled * self.stride**2 / jnp.expand_dims(weights_out, (0, 2, 3, 4))) return pooled
def __call__(self, x, sigmas, train=True):
  # Per-image standardization.
  N = np.prod(x.shape[1:])
  x = (x - jnp.mean(x, axis=(1, 2, 3), keepdims=True)) / jnp.maximum(
      jnp.std(x, axis=(1, 2, 3), keepdims=True), 1. / np.sqrt(N))
  temb = GaussianFourierProjection(
      embedding_size=128, scale=16)(jnp.log(sigmas))
  temb = nn.Dense(128 * 4)(temb)
  temb = nn.Dense(128 * 4)(nn.swish(temb))
  x = nn.Conv(
      16, (3, 3),
      padding='SAME',
      name='init_conv',
      kernel_init=conv_kernel_init_fn,
      use_bias=False)(x)
  x = WideResnetGroup(
      self.blocks_per_group,
      16 * self.channel_multiplier,
      activate_before_residual=True)(x, temb, train)
  x = WideResnetGroup(
      self.blocks_per_group,
      32 * self.channel_multiplier,
      (2, 2))(x, temb, train)
  x = WideResnetGroup(
      self.blocks_per_group,
      64 * self.channel_multiplier,
      (2, 2))(x, temb, train)
  x = activation(x, train=train, name='pre-pool-bn')
  x = nn.avg_pool(x, x.shape[1:3])
  x = x.reshape((x.shape[0], -1))
  x = nn.Dense(self.num_outputs, kernel_init=dense_layer_init_fn)(x)
  return x

def __call__(self, x):
  for feat in self.features:
    x = nn.Conv(
        features=feat,
        kernel_size=(self.kernel_size, self.kernel_size))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
  x = x.reshape((x.shape[0], -1))  # flatten
  return x

def __call__(self, x):
  x = nn.avg_pool(x, (5, 5), strides=(3, 3))
  x = self.conv_block(128, kernel_size=(1, 1), name='conv0')(x)
  x = self.conv_block(768, kernel_size=(5, 5), name='conv1')(x)
  x = x.transpose((0, 3, 1, 2))
  x = x.reshape((x.shape[0], -1))
  x = nn.Dense(self.num_classes, name='fc')(x)
  return x

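# Note (an inference, not stated in the original source): transposing to
# NCHW before the flatten above reproduces a PyTorch-style flatten order,
# which matters when reusing weights converted from a PyTorch checkpoint.
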
def __call__(self, x, train=True):
  down_sample_layers = [ConvBlock(self.chans, self.drop_prob)]
  ch = self.chans
  for _ in range(self.num_pool_layers - 1):
    down_sample_layers.append(ConvBlock(ch * 2, self.drop_prob))
    ch *= 2
  conv = ConvBlock(ch * 2, self.drop_prob)

  up_conv = []
  up_transpose_conv = []
  for _ in range(self.num_pool_layers - 1):
    up_transpose_conv.append(TransposeConvBlock(ch))
    up_conv.append(ConvBlock(ch, self.drop_prob))
    ch //= 2
  up_transpose_conv.append(TransposeConvBlock(ch))
  up_conv.append(ConvBlock(ch, self.drop_prob))

  final_conv = nn.Conv(self.out_chans, kernel_size=(1, 1), strides=(1, 1))

  stack = []
  output = jnp.expand_dims(x, axis=-1)

  # Apply down-sampling layers.
  for layer in down_sample_layers:
    output = layer(output, train)
    stack.append(output)
    output = nn.avg_pool(output, window_shape=(2, 2), strides=(2, 2))

  output = conv(output, train)

  # Apply up-sampling layers.
  for transpose_conv, conv in zip(up_transpose_conv, up_conv):
    downsample_layer = stack.pop()
    output = transpose_conv(output)

    # Reflect-pad on the right/bottom if needed to handle odd input
    # dimensions.
    padding_right = 0
    padding_bottom = 0
    if output.shape[-2] != downsample_layer.shape[-2]:
      padding_right = 1  # padding right
    if output.shape[-3] != downsample_layer.shape[-3]:
      padding_bottom = 1  # padding bottom
    if padding_right or padding_bottom:
      padding = ((0, 0), (0, padding_bottom), (0, padding_right), (0, 0))
      output = jnp.pad(output, padding, mode='reflect')

    output = jnp.concatenate((output, downsample_layer), axis=-1)
    output = conv(output, train)

  output = final_conv(output)
  return output.squeeze(-1)

def __call__(self, x):
  B, H, W, C = x.shape
  if self.with_conv:
    x = ddpm_conv3x3(x, C, stride=2)
  else:
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2), padding='SAME')
  assert x.shape == (B, H // 2, W // 2, C)
  return x

def __call__(self, inputs, train: bool = False):
  in_shape = np.shape(inputs)[1:-1]
  x = nn.avg_pool(inputs, in_shape)
  x = nn.Conv(
      self.channels, (1, 1), padding='SAME', use_bias=False,
      name="conv1")(x)
  x = nn.BatchNorm(use_running_average=not train, name="bn1")(x)
  x = nn.relu(x)
  out_shape = (1, in_shape[0], in_shape[1], self.channels)
  x = jax.image.resize(x, shape=out_shape, method='bilinear')
  return x

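# A minimal sketch of the pool-then-resize pattern above: global average
# pooling down to 1x1 followed by bilinear upsampling back to the input
# resolution; the shapes here are assumptions for illustration.
import jax
import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((1, 16, 16, 8))
g = nn.avg_pool(x, window_shape=(16, 16))  # -> (1, 1, 1, 8)
g = jax.image.resize(g, (1, 16, 16, 8), method='bilinear')
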
def __call__(self, x, train):
  conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)
  maybe_normalize = model_utils.get_normalizer(self.normalizer, train)
  y = maybe_normalize()(x)
  y = nn.relu(y)
  y = conv(features=self.num_features, kernel_size=(1, 1))(y)
  y = nn.avg_pool(
      y,
      window_shape=(2, 2),
      strides=(2, 2) if self.use_kernel_size_as_stride_in_pooling else (1, 1))
  return y

def __call__(self, x, train=False):
  conv_block = partial(BasicConv, train=train, dtype=self.dtype)
  inception_a = partial(InceptionA, conv_block=conv_block)
  inception_b = partial(InceptionB, conv_block=conv_block)
  inception_c = partial(InceptionC, conv_block=conv_block)
  inception_d = partial(InceptionD, conv_block=conv_block)
  inception_e = partial(InceptionE, conv_block=conv_block)
  inception_aux = partial(InceptionAux, conv_block=conv_block)

  if self.transform_input:
    # Use jnp rather than np so the rescaling also works on traced values
    # under jit.
    x = jnp.transpose(x, (0, 3, 1, 2))
    x_ch0 = jnp.expand_dims(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
    x_ch1 = jnp.expand_dims(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
    x_ch2 = jnp.expand_dims(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
    x = jnp.concatenate((x_ch0, x_ch1, x_ch2), 1)
    x = jnp.transpose(x, (0, 2, 3, 1))

  x = conv_block(32, kernel_size=(3, 3), strides=(2, 2),
                 name='Conv2d_1a_3x3')(x)
  x = conv_block(32, kernel_size=(3, 3), name='Conv2d_2a_3x3')(x)
  x = conv_block(64, kernel_size=(3, 3), padding=[(1, 1), (1, 1)],
                 name='Conv2d_2b_3x3')(x)
  x = nn.max_pool(x, (3, 3), strides=(2, 2))
  x = conv_block(80, kernel_size=(1, 1), name='Conv2d_3b_1x1')(x)
  x = conv_block(192, kernel_size=(3, 3), name='Conv2d_4a_3x3')(x)
  x = nn.max_pool(x, (3, 3), strides=(2, 2))
  x = inception_a(pool_features=32, name='Mixed_5b')(x)
  x = inception_a(pool_features=64, name='Mixed_5c')(x)
  x = inception_a(pool_features=64, name='Mixed_5d')(x)
  x = inception_b(name='Mixed_6a')(x)
  x = inception_c(channels_7x7=128, name='Mixed_6b')(x)
  x = inception_c(channels_7x7=160, name='Mixed_6c')(x)
  x = inception_c(channels_7x7=160, name='Mixed_6d')(x)
  x = inception_c(channels_7x7=192, name='Mixed_6e')(x)
  aux = (inception_aux(self.num_classes, name='AuxLogits')(x)
         if train and self.aux_logits else None)
  x = inception_d(name='Mixed_7a')(x)
  x = inception_e(name='Mixed_7b')(x)
  x = inception_e(name='Mixed_7c')(x)
  x = nn.avg_pool(x, (8, 8))
  x = nn.Dropout(0.5)(x, deterministic=not train)
  return x, aux

def __call__(self, x, y):
  x = self.act(x)
  path = x
  for _ in range(self.n_stages):
    path = self.normalizer()(path, y)
    path = nn.avg_pool(
        path, window_shape=(5, 5), strides=(1, 1), padding='SAME')
    path = ncsn_conv3x3(path, self.features, stride=1, bias=False)
    x = path + x
  return x

def test_avg_pool_no_batch(self):
  x = jnp.full((3, 3, 1), 2.)
  pool = lambda x: nn.avg_pool(x, (2, 2))
  y = pool(x)
  np.testing.assert_allclose(y, np.full((2, 2, 1), 2.))
  y_grad = jax.grad(lambda x: pool(x).sum())(x)
  expected_grad = jnp.array([
      [0.25, 0.5, 0.25],
      [0.5, 1., 0.5],
      [0.25, 0.5, 0.25],
  ]).reshape((3, 3, 1))
  np.testing.assert_allclose(y_grad, expected_grad)

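# Why the expected gradient above has that shape: with a 2x2 window and the
# default stride of 1, each output element is the mean of four inputs, so
# every input receives 0.25 per window that covers it; corner pixels fall in
# one window (0.25), edge pixels in two (0.5), and the center in four (1.0).
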
def __call__(self, x, train):
  x = nn.Conv(
      16, (3, 3),
      padding='SAME',
      name='init_conv',
      kernel_init=self.conv_kernel_init,
      use_bias=False)(x)
  x = WideResnetGroup(
      self.blocks_per_group,
      16 * self.channel_multiplier,
      self.group_strides[0],
      conv_kernel_init=self.conv_kernel_init,
      normalizer=self.normalizer,
      dropout_rate=self.dropout_rate,
      activation_function=self.activation_function,
      batch_size=self.batch_size,
      virtual_batch_size=self.virtual_batch_size,
      total_batch_size=self.total_batch_size)(x, train=train)
  x = WideResnetGroup(
      self.blocks_per_group,
      32 * self.channel_multiplier,
      self.group_strides[1],
      conv_kernel_init=self.conv_kernel_init,
      normalizer=self.normalizer,
      dropout_rate=self.dropout_rate,
      activation_function=self.activation_function,
      batch_size=self.batch_size,
      virtual_batch_size=self.virtual_batch_size,
      total_batch_size=self.total_batch_size)(x, train=train)
  x = WideResnetGroup(
      self.blocks_per_group,
      64 * self.channel_multiplier,
      self.group_strides[2],
      conv_kernel_init=self.conv_kernel_init,
      dropout_rate=self.dropout_rate,
      normalizer=self.normalizer,
      activation_function=self.activation_function,
      batch_size=self.batch_size,
      virtual_batch_size=self.virtual_batch_size,
      total_batch_size=self.total_batch_size)(x, train=train)
  maybe_normalize = model_utils.get_normalizer(
      self.normalizer,
      train,
      batch_size=self.batch_size,
      virtual_batch_size=self.virtual_batch_size,
      total_batch_size=self.total_batch_size)
  x = maybe_normalize()(x)
  x = model_utils.ACTIVATIONS[self.activation_function](x)
  x = nn.avg_pool(x, (8, 8))
  x = x.reshape((x.shape[0], -1))
  x = nn.Dense(self.num_outputs, kernel_init=self.dense_kernel_init)(x)
  return x

def __call__(self, x):
  Conv1x1_ = partial(Conv1x1, precision=self.conv_precision)
  Conv3x3_ = partial(
      Conv3x3 if self.use_3x3 else Conv1x1, precision=self.conv_precision)
  x_ = Conv1x1_(self.middle_width)(nn.gelu(x))
  x_ = Conv3x3_(self.middle_width)(nn.gelu(x_))
  x_ = Conv3x3_(self.middle_width)(nn.gelu(x_))
  x_ = Conv1x1_(
      self.out_width, kernel_init=lecun_normal(self.last_scale))(nn.gelu(x_))
  out = x + x_ if self.residual else x_
  if self.down_rate > 1:
    window_shape = 2 * (self.down_rate,)
    out = nn.avg_pool(out, window_shape, window_shape)
  return out

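# Note on `window_shape = 2 * (self.down_rate,)` above: repeating the
# 1-tuple builds a square window, e.g. down_rate=2 yields (2, 2), so the
# avg_pool downsamples by down_rate along both spatial axes.
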
def __call__(self, x):
  branch1x1 = self.conv_block(192, kernel_size=(1, 1), name='branch1x1')(x)

  c7 = self.channels_7x7
  branch7x7 = self.conv_block(c7, kernel_size=(1, 1), name='branch7x7_1')(x)
  branch7x7 = self.conv_block(
      c7, kernel_size=(1, 7), padding=[(0, 0), (3, 3)],
      name='branch7x7_2')(branch7x7)
  branch7x7 = self.conv_block(
      192, kernel_size=(7, 1), padding=[(3, 3), (0, 0)],
      name='branch7x7_3')(branch7x7)

  branch7x7dbl = self.conv_block(
      c7, kernel_size=(1, 1), name='branch7x7dbl_1')(x)
  branch7x7dbl = self.conv_block(
      c7, kernel_size=(7, 1), padding=[(3, 3), (0, 0)],
      name='branch7x7dbl_2')(branch7x7dbl)
  branch7x7dbl = self.conv_block(
      c7, kernel_size=(1, 7), padding=[(0, 0), (3, 3)],
      name='branch7x7dbl_3')(branch7x7dbl)
  branch7x7dbl = self.conv_block(
      c7, kernel_size=(7, 1), padding=[(3, 3), (0, 0)],
      name='branch7x7dbl_4')(branch7x7dbl)
  branch7x7dbl = self.conv_block(
      192, kernel_size=(1, 7), padding=[(0, 0), (3, 3)],
      name='branch7x7dbl_5')(branch7x7dbl)

  branch_pool = nn.avg_pool(
      x, (3, 3), strides=(1, 1), padding=[(1, 1), (1, 1)])
  branch_pool = self.conv_block(
      192, kernel_size=(1, 1), name='branch_pool')(branch_pool)

  outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
  return jnp.concatenate(outputs, 3)

def _output_add(block_x, orig_x):
  """Add two tensors, padding them with zeros or pooling them if necessary.

  Args:
    block_x: Output of a resnet block.
    orig_x: Residual branch to add to the output of the resnet block.

  Returns:
    The sum of block_x and orig_x. If necessary, orig_x will be average
    pooled or zero padded so that its shape matches block_x.
  """
  stride = orig_x.shape[-2] // block_x.shape[-2]
  strides = (stride, stride)
  if block_x.shape[-1] != orig_x.shape[-1]:
    orig_x = nn.avg_pool(orig_x, strides, strides)
    channels_to_add = block_x.shape[-1] - orig_x.shape[-1]
    orig_x = jnp.pad(orig_x, [(0, 0), (0, 0), (0, 0), (0, channels_to_add)])
  return block_x + orig_x

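# A hypothetical usage sketch for _output_add, assuming a block that halves
# the spatial resolution and doubles the channel count; the shapes below are
# assumptions for illustration.
import jax.numpy as jnp

orig = jnp.ones((2, 8, 8, 16))   # residual branch before the block
block = jnp.ones((2, 4, 4, 32))  # block output: halved HxW, doubled C
out = _output_add(block, orig)   # orig is avg-pooled to 4x4, zero-padded to 32
assert out.shape == (2, 4, 4, 32)
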
def __call__(self, x):
  branch1x1 = self.conv_block(320, kernel_size=(1, 1), name='branch1x1')(x)

  branch3x3 = self.conv_block(384, kernel_size=(1, 1), name='branch3x3_1')(x)
  branch3x3_2a = self.conv_block(
      384, kernel_size=(1, 3), padding=[(0, 0), (1, 1)],
      name='branch3x3_2a')(branch3x3)
  branch3x3_2b = self.conv_block(
      384, kernel_size=(3, 1), padding=[(1, 1), (0, 0)],
      name='branch3x3_2b')(branch3x3)
  branch3x3 = jnp.concatenate([branch3x3_2a, branch3x3_2b], 3)

  branch3x3dbl = self.conv_block(
      448, kernel_size=(1, 1), name='branch3x3dbl_1')(x)
  branch3x3dbl = self.conv_block(
      384, kernel_size=(3, 3), padding=[(1, 1), (1, 1)],
      name='branch3x3dbl_2')(branch3x3dbl)
  branch3x3dbl_3a = self.conv_block(
      384, kernel_size=(1, 3), padding=[(0, 0), (1, 1)],
      name='branch3x3dbl_3a')(branch3x3dbl)
  branch3x3dbl_3b = self.conv_block(
      384, kernel_size=(3, 1), padding=[(1, 1), (0, 0)],
      name='branch3x3dbl_3b')(branch3x3dbl)
  branch3x3dbl = jnp.concatenate([branch3x3dbl_3a, branch3x3dbl_3b], 3)

  branch_pool = nn.avg_pool(
      x, (3, 3), strides=(1, 1), padding=[(1, 1), (1, 1)])
  branch_pool = self.conv_block(
      192, kernel_size=(1, 1), name='branch_pool')(branch_pool)

  outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
  return jnp.concatenate(outputs, 3)

def __call__(self, inputs, train): """Applies the network to inputs. Args: inputs: (batch_size, resolution, resolution, n_spins, n_channels) array. train: whether to run in training or inference mode. Returns: A (batch_size, num_classes) float32 array with per-class scores (logits). Raises: ValueError: If resolutions cannot be enforced with 2x2 pooling. """ num_layers = len(self.resolutions) # Merge spin and channel dimensions. features = inputs.reshape((*inputs.shape[:3], -1)) for layer_id in range(num_layers - 1): resolution_in = self.resolutions[layer_id] resolution_out = self.resolutions[layer_id + 1] n_channels = self.widths[layer_id + 1] if resolution_out == resolution_in // 2: features = nn.avg_pool(features, window_shape=(2, 2), strides=(2, 2), padding='SAME') elif resolution_out != resolution_in: raise ValueError( 'Consecutive resolutions must be equal or halved.') features = nn.Conv(features=n_channels, kernel_size=(3, 3), strides=(1, 1))(features) features = nn.BatchNorm(use_running_average=not train, axis_name=self.axis_name)(features) features = nn.relu(features) features = jnp.mean(features, axis=(1, 2)) features = nn.Dense(self.num_classes)(features) return features
def __call__(self, x, *, emb, deterministic):
  B, _, _, C = x.shape  # pylint: disable=invalid-name
  assert emb.shape[0] == B and len(emb.shape) == 2
  out_ch = C if self.out_ch is None else self.out_ch

  h = nonlinearity(Normalize(name='norm1')(x))

  if self.resample is not None:
    updown = lambda z: {
        'up': nearest_neighbor_upsample(z),
        'down': nn.avg_pool(z, (2, 2), (2, 2))
    }[self.resample]
    h = updown(h)
    x = updown(x)

  h = nn.Conv(
      features=out_ch, kernel_size=(3, 3), strides=(1, 1), name='conv1')(h)

  # add in timestep/class embedding
  emb_out = nn.Dense(features=2 * out_ch, name='temb_proj')(
      nonlinearity(emb))[:, None, None, :]
  scale, shift = jnp.split(emb_out, 2, axis=-1)
  h = Normalize(name='norm2')(h) * (1 + scale) + shift

  # rest
  h = nonlinearity(h)
  h = nn.Dropout(rate=self.dropout)(h, deterministic=deterministic)
  h = nn.Conv(
      features=out_ch,
      kernel_size=(3, 3),
      strides=(1, 1),
      kernel_init=nn.initializers.zeros,
      name='conv2')(h)

  if C != out_ch:
    x = nn.Dense(features=out_ch, name='nin_shortcut')(x)

  assert x.shape == h.shape
  logging.info(
      '%s: x=%r emb=%r resample=%r', self.name, x.shape, emb.shape,
      self.resample)
  return x + h

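# A short sketch of the embedding-conditioned scale/shift above: the
# projected embedding is split into a multiplicative and an additive term
# that broadcast over the spatial axes. The shapes here are assumptions.
import jax.numpy as jnp

emb_out = jnp.ones((4, 1, 1, 2 * 64))          # (B, 1, 1, 2 * out_ch)
scale, shift = jnp.split(emb_out, 2, axis=-1)  # each (4, 1, 1, 64)
h = jnp.ones((4, 8, 8, 64))
h = h * (1 + scale) + shift                    # broadcasts over H and W
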
def __call__(self, x, train):

  def dense_layers(y, block, num_blocks, growth_rate):
    for _ in range(num_blocks):
      y = block(growth_rate)(y, train=train)
    return y

  def update_num_features(num_features, num_blocks, growth_rate, reduction):
    num_features += num_blocks * growth_rate
    if reduction is not None:
      num_features = int(math.floor(num_features * reduction))
    return num_features

  # Initial convolutional layer.
  num_features = 2 * self.growth_rate
  conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)
  y = conv(
      features=num_features,
      kernel_size=(3, 3),
      padding=((1, 1), (1, 1)),
      name='conv1')(x)

  # Internal dense and transition blocks.
  num_blocks = _block_size_options[self.num_layers]
  block = functools.partial(
      BottleneckBlock, dtype=self.dtype, normalizer=self.normalizer)
  for i in range(3):
    y = dense_layers(y, block, num_blocks[i], self.growth_rate)
    num_features = update_num_features(num_features, num_blocks[i],
                                       self.growth_rate, self.reduction)
    y = TransitionBlock(
        num_features,
        dtype=self.dtype,
        normalizer=self.normalizer,
        use_kernel_size_as_stride_in_pooling=self
        .use_kernel_size_as_stride_in_pooling)(y, train=train)

  # Final dense block.
  y = dense_layers(y, block, num_blocks[3], self.growth_rate)

  # Final pooling.
  maybe_normalize = model_utils.get_normalizer(self.normalizer, train)
  y = maybe_normalize()(y)
  y = nn.relu(y)
  y = nn.avg_pool(
      y,
      window_shape=(4, 4),
      strides=(4, 4) if self.use_kernel_size_as_stride_in_pooling else (1, 1))

  # Classification layer.
  y = jnp.reshape(y, (y.shape[0], -1))
  if self.normalize_classifier_input:
    maybe_normalize = model_utils.get_normalizer(
        self.normalize_classifier_input, train)
    y = maybe_normalize()(y)
    y = y * self.classification_scale_factor
  y = nn.Dense(self.num_outputs)(y)
  return y

def __call__(self, inputs, context_vectors=None):
  """Applies the res block to input images.

  Args:
    inputs: a rank-4 array of input images of shape (B, H, W, C).
    context_vectors: optional auxiliary inputs, typically used for
      conditioning. If set, they should be of rank 2, and their first
      (batch) dimension should match that of `inputs`. Their number of
      features is arbitrary. They will be reshaped from (B, D) to
      (B, 1, 1, D) and a 1x1 convolution will be applied to them.

  Returns:
    The rank-4 output of the block.
  """
  if self.downsampling_rate < 1:
    raise ValueError('downsampling_rate should be >= 1, but got '
                     f'{self.downsampling_rate}.')

  def build_layers(inputs):
    """Build layers of the ResBlock given a batch of inputs."""
    resolution = inputs.shape[1]
    if resolution > 2:
      kernel_shapes = ((1, 1), (3, 3), (3, 3), (1, 1))
    else:
      kernel_shapes = ((1, 1), (1, 1), (1, 1), (1, 1))
    conv_layers = []
    aux_conv_layers = []
    for layer_idx, kernel_shape in enumerate(kernel_shapes):
      is_last = layer_idx == _NUM_CONV_LAYER_PER_BLOCK - 1
      num_channels = (self.output_channels if is_last
                      else self.internal_channels)
      weights_scale = self.last_weights_scale if is_last else 1.
      conv_layers.append(
          get_vdvae_convolution(num_channels, kernel_shape, weights_scale,
                                name=f'c{layer_idx}',
                                precision=self.precision))
      aux_conv_layers.append(
          get_vdvae_convolution(num_channels, (1, 1), 0.,
                                name=f'aux_c{layer_idx}',
                                precision=self.precision))
    return conv_layers, aux_conv_layers

  chex.assert_rank(inputs, 4)
  if inputs.shape[1] != inputs.shape[2]:
    raise ValueError('VDVAE only works with square images, but got '
                     f'rectangular images of shape {inputs.shape[1:3]}.')

  if context_vectors is not None:
    chex.assert_rank(context_vectors, 2)
    inputs_batch_dim = inputs.shape[0]
    aux_batch_dim = context_vectors.shape[0]
    if inputs_batch_dim != aux_batch_dim:
      raise ValueError('Context vectors batch dimension is incompatible '
                       'with inputs batch dimension. Got '
                       f'{aux_batch_dim} vs {inputs_batch_dim}.')
    context_vectors = context_vectors[:, None, None, :]

  conv_layers, aux_conv_layers = build_layers(inputs)

  outputs = inputs
  for conv, auxiliary_conv in zip(conv_layers, aux_conv_layers):
    outputs = conv(jax.nn.gelu(outputs))
    if context_vectors is not None:
      outputs += auxiliary_conv(context_vectors)

  if self.use_residual_connection:
    in_channels = inputs.shape[-1]
    out_channels = outputs.shape[-1]
    if in_channels != out_channels:
      raise AssertionError('Cannot apply residual connection because the '
                           'number of output channels differs from the '
                           'number of input channels: '
                           f'{out_channels} vs {in_channels}.')
    outputs += inputs

  if self.downsampling_rate > 1:
    shape = (self.downsampling_rate, self.downsampling_rate)
    outputs = nn.avg_pool(
        outputs, window_shape=shape, strides=shape, padding='VALID')

  return outputs

def __call__( self, inputs, ): """Applies ResNet model. Number of residual blocks inferred from hparams.""" num_classes = self.num_classes hparams = self.hparams num_filters = self.num_filters dtype = self.dtype assert hparams.act_function in act_function_zoo.keys( ), 'Activation function type is not supported.' x = aqt_flax_layers.ConvAqt( features=num_filters, kernel_size=(7, 7), strides=(2, 2), padding=[(3, 3), (3, 3)], use_bias=False, dtype=dtype, name='init_conv', train=self.train, quant_context=self.quant_context, paxis_name='batch', hparams=hparams.conv_init, )( inputs) x = nn.BatchNorm( use_running_average=not self.train, momentum=0.9, epsilon=1e-5, dtype=dtype, name='init_bn')( x) if hparams.act_function == 'relu': x = nn.relu(x) x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME') else: # TODO(yichi): try adding other activation functions here # Use avg pool so that for binary nets, the distribution is symmetric. x = nn.avg_pool(x, (3, 3), strides=(2, 2), padding='SAME') filter_multiplier = hparams.filter_multiplier for i, block_hparams in enumerate(hparams.residual_blocks): proj = block_hparams.conv_proj # For projection layers (unless it is the first layer), strides = (2, 2) if i > 0 and proj is not None: filter_multiplier *= 2 strides = (2, 2) else: strides = (1, 1) x = ResidualBlock( filters=int(num_filters * filter_multiplier), hparams=block_hparams, quant_context=self.quant_context, strides=strides, train=self.train, dtype=dtype)( x) if hparams.act_function == 'none': # The DenseAQT below is not binarized. # If removing the activation functions, there will be no act function # between the last residual block and the dense layer. # So add a ReLU in that case. # TODO(yichi): try BPReLU x = nn.relu(x) else: pass x = jnp.mean(x, axis=(1, 2)) x = aqt_flax_layers.DenseAqt( features=num_classes, dtype=dtype, train=self.train, quant_context=self.quant_context, paxis_name='batch', hparams=hparams.dense_layer, )(x, padding_mask=None) x = jnp.asarray(x, dtype) output = nn.log_softmax(x) return output