def apply(self, x, *, stride, filters, train):
  norm_layer = nn.BatchNorm.partial(
      use_running_average=not train, momentum=0.9, epsilon=1e-5)
  conv3x3 = nn.Conv.partial(kernel_size=(3, 3), padding="SAME", bias=False)
  conv1x1 = nn.Conv.partial(kernel_size=(1, 1), padding="SAME", bias=False)

  x = norm_layer(x)
  x = nn.relu(x)
  identity = x
  needs_projection = x.shape[-1] != filters or stride != (1, 1)
  if needs_projection:
    identity = conv1x1(x, features=filters, strides=stride)
  x = conv3x3(x, features=filters, strides=stride)
  x = norm_layer(x)
  x = nn.relu(x)
  x = conv3x3(x, features=filters, strides=(1, 1))
  x += identity
  return x
def apply(self, x, action_dim, max_action):
  x = nn.Dense(x, features=256)
  x = nn.relu(x)
  x = nn.Dense(x, features=256)
  x = nn.relu(x)
  x = nn.Dense(x, features=action_dim)
  return max_action * nn.tanh(x)
def apply(self, x, filters, strides=(1, 1), train=True, dtype=jnp.float32):
  needs_projection = x.shape[-1] != filters * 4 or strides != (1, 1)
  batch_norm = nn.BatchNorm.partial(
      use_running_average=not train, momentum=0.9, epsilon=1e-5, dtype=dtype)
  conv = nn.Conv.partial(bias=False, dtype=dtype)

  residual = x
  if needs_projection:
    residual = conv(residual, filters * 4, (1, 1), strides, name='proj_conv')
    residual = batch_norm(residual, name='proj_bn')

  y = conv(x, filters, (1, 1), name='conv1')
  y = batch_norm(y, name='bn1')
  y = nn.relu(y)
  y = conv(y, filters, (3, 3), strides, name='conv2')
  y = batch_norm(y, name='bn2')
  y = nn.relu(y)
  y = conv(y, filters * 4, (1, 1), name='conv3')
  y = batch_norm(y, name='bn3', scale_init=nn.initializers.zeros)
  y = nn.relu(residual + y)
  return y
def apply(self, x, rep_size, m_layers, m_features, m_kernel_sizes,
          conv_rep_size, padding_mask=None):
  H_0 = nn.relu(nn.Dense(x, conv_rep_size))
  G_0 = nn.relu(nn.Dense(x, conv_rep_size))
  H, G = jnp.expand_dims(H_0, axis=2), jnp.expand_dims(G_0, axis=2)

  for layer in range(1, m_layers + 1):
    if layer < m_layers:
      H_features, G_features = m_features[layer - 1]
    else:
      H_features, G_features = conv_rep_size, conv_rep_size
    H_kernel_size, G_kernel_size = m_kernel_sizes[layer - 1]
    H = nn.Conv(H, features=H_features, kernel_size=(H_kernel_size, 1))
    G = nn.Conv(G, features=G_features, kernel_size=(G_kernel_size, 1))
    if layer < m_layers:
      H = nn.relu(H)
      G = nn.relu(G)
    else:
      H = nn.tanh(H)
      G = nn.sigmoid(G)

  H, G = jnp.squeeze(H, axis=2), jnp.squeeze(G, axis=2)
  F = H * G + G_0
  rep = linear_max_pool(F, padding_mask=padding_mask, rep_size=rep_size)
  return rep
def apply(self, x, nout, strides=(1, 1), bottleneck=True):
  features = nout
  nout = nout * 4 if bottleneck else nout
  needs_projection = x.shape[-1] != nout or strides != (1, 1)

  residual = x
  if needs_projection:
    residual = StdConv(residual, nout, (1, 1), strides,
                       bias=False, name="conv_proj")
    residual = nn.GroupNorm(residual, epsilon=1e-4, name="gn_proj")

  if bottleneck:
    x = StdConv(x, features, (1, 1), bias=False, name="conv1")
    x = nn.GroupNorm(x, epsilon=1e-4, name="gn1")
    x = nn.relu(x)

  x = StdConv(x, features, (3, 3), strides, bias=False, name="conv2")
  x = nn.GroupNorm(x, epsilon=1e-4, name="gn2")
  x = nn.relu(x)

  last_kernel = (1, 1) if bottleneck else (3, 3)
  x = StdConv(x, nout, last_kernel, bias=False, name="conv3")
  x = nn.GroupNorm(x, epsilon=1e-4, name="gn3",
                   scale_init=nn.initializers.zeros)
  x = nn.relu(residual + x)
  return x
def classifier(x, num_outputs, dropout_rate, deterministic):
  """Implements the classification portion of the network."""
  x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
  x = nn.Dense(x, 512)
  x = nn.relu(x)
  x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
  x = nn.Dense(x, 512)
  x = nn.relu(x)
  x = nn.Dense(x, num_outputs)
  return x
def apply(self, x, depth, features, kernel_size, use_one_hot):
  if use_one_hot:
    x = one_hot(x)
  x = MaskedConv1d(x, features, (kernel_size,), is_first_layer=True)
  x = nn.relu(x)
  for _ in range(depth - 2):
    x = MaskedConv1d(x, features, (kernel_size,))
    x = nn.relu(x)
  x = MaskedConv1d(x, 4, (kernel_size,))
  x = real_to_complex(x)
  return x
def apply(self, x):
  x = nn.Conv(x, features=32, kernel_size=(3, 3))
  x = nn.relu(x)
  x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
  x = nn.Conv(x, features=64, kernel_size=(3, 3))
  x = nn.relu(x)
  x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
  x = x.reshape((x.shape[0], -1))  # flatten
  x = nn.Dense(x, features=256)
  x = nn.relu(x)
  x = nn.Dense(x, features=10)
  x = nn.log_softmax(x)
  return x
def classifier_head_dual(encoded1, encoded2, num_classes, mlp_dim,
                         pooling_mode='MEAN', interaction=None):
  """Classifier head for dual encoding or pairwise problems.

  We put this here just so that all models consistently call the same function.

  Args:
    encoded1: tensor inputs of shape [bs, len, dim].
    encoded2: tensor inputs of shape [bs, len, dim].
    num_classes: int, number of classes.
    mlp_dim: int, dim of intermediate MLP.
    pooling_mode: str, pooling op, one of {MEAN, SUM, FLATTEN, CLS}.
    interaction: str, interaction between e1 and e2 (e.g. 'NLI').

  Returns:
    tensor of shape [bs, num_classes]
  """
  if pooling_mode == 'MEAN':
    encoded1 = jnp.mean(encoded1, axis=1)
    encoded2 = jnp.mean(encoded2, axis=1)
  elif pooling_mode == 'SUM':
    encoded1 = jnp.sum(encoded1, axis=1)
    encoded2 = jnp.sum(encoded2, axis=1)
  elif pooling_mode == 'FLATTEN':
    encoded1 = encoded1.reshape((encoded1.shape[0], -1))
    encoded2 = encoded2.reshape((encoded2.shape[0], -1))
  elif pooling_mode == 'CLS':
    encoded1 = encoded1[:, 0]
    encoded2 = encoded2[:, 0]
  else:
    raise NotImplementedError('Pooling not supported yet.')

  if interaction == 'NLI':
    # NLI interaction style
    encoded = jnp.concatenate(
        [encoded1, encoded2, encoded1 * encoded2, encoded1 - encoded2], 1)
  else:
    encoded = jnp.concatenate([encoded1, encoded2], 1)
  encoded = nn.Dense(encoded, mlp_dim, name='mlp')
  encoded = nn.relu(encoded)
  encoded = nn.Dense(encoded, int(mlp_dim // 2), name='mlp2')
  encoded = nn.relu(encoded)
  encoded = nn.Dense(encoded, num_classes, name='logits')
  return encoded
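# Note: with the 'NLI' interaction and pooled encodings of width d, the
# concatenated feature [e1, e2, e1 * e2, e1 - e2] has width 4 * d before the
# MLP; otherwise the concatenation has width 2 * d.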
def apply(self, x, use_squeeze_excite=False):
  x = nn.Conv(x, features=8, kernel_size=(3, 3), padding="VALID")
  x = nn.relu(x)
  x = nn.Conv(x, features=16, kernel_size=(3, 3), padding="VALID")
  x = nn.relu(x)
  if use_squeeze_excite:
    x = SqueezeExciteLayer(x)
  x = nn.Conv(x, features=32, kernel_size=(3, 3), padding="VALID")
  x = nn.relu(x)
  if use_squeeze_excite:
    x = SqueezeExciteLayer(x)
  x = nn.Conv(x, features=1, kernel_size=(3, 3), padding="VALID")
  scores = nn.max_pool(x, window_shape=(8, 8), strides=(8, 8))[Ellipsis, 0]
  return scores
def apply(self,
          x,
          action_dim,
          max_action,
          key=None,
          MPO=False,
          sample=False,
          log_sig_min=-20,
          log_sig_max=2):
  x = nn.Dense(x, features=200)
  x = nn.LayerNorm(x)
  x = nn.tanh(x)
  x = nn.Dense(x, features=200)
  x = nn.elu(x)
  x = nn.Dense(x, features=2 * action_dim)

  mu, log_sig = jnp.split(x, 2, axis=-1)
  log_sig = nn.softplus(log_sig)
  log_sig = jnp.clip(log_sig, log_sig_min, log_sig_max)

  if MPO:
    return mu, log_sig

  if not sample:
    return max_action * nn.tanh(mu), log_sig
  else:
    pi = mu + random.normal(key, mu.shape) * jnp.exp(log_sig)
    log_pi = gaussian_likelihood(pi, mu, log_sig)
    pi = nn.tanh(pi)
    log_pi -= jnp.sum(jnp.log(nn.relu(1 - pi ** 2) + 1e-6), axis=1)
    return max_action * pi, log_pi
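# `gaussian_likelihood`, used in the sampling branch above, is not defined in
# this section. The sketch below shows the assumed behaviour (diagonal
# Gaussian log-density summed over the action dimension); it is illustrative,
# not the repository's implementation.
def gaussian_likelihood(sample, mu, log_sig):
  pre_sum = -0.5 * (((sample - mu) / (jnp.exp(log_sig) + 1e-6)) ** 2
                    + 2.0 * log_sig + jnp.log(2.0 * jnp.pi))
  return jnp.sum(pre_sum, axis=1)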
def apply(self,
          inputs,
          mlp_dim,
          dtype=jnp.float32,
          out_dim=None,
          dropout_rate=0.1,
          deterministic=False,
          kernel_init=nn.initializers.xavier_uniform(),
          bias_init=nn.initializers.normal(stddev=1e-6)):
  """Applies Transformer MlpBlock module."""
  actual_out_dim = inputs.shape[-1] if out_dim is None else out_dim
  x = nn.Dense(inputs, mlp_dim, dtype=dtype, kernel_init=kernel_init,
               bias_init=bias_init)
  x = nn.relu(x)
  x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
  output = nn.Dense(x, actual_out_dim, dtype=dtype, kernel_init=kernel_init,
                    bias_init=bias_init)
  output = nn.dropout(output, rate=dropout_rate, deterministic=deterministic)
  return output
def apply(self, x, num_classes=1000, width_factor=1, num_layers=50):
  block_sizes = _block_sizes[num_layers]
  width = 64 * width_factor

  root_block = RootBlock.partial(width=width)
  x = root_block(x, name='root_block')

  # Blocks
  for i, block_size in enumerate(block_sizes):
    x = ResidualBlock(x, block_size, width * 2**i,
                      first_stride=(1, 1) if i == 0 else (2, 2),
                      name=f"block{i + 1}")

  # Pre-head
  x = GroupNorm(x, name='norm-pre-head')
  x = nn.relu(x)
  x = jnp.mean(x, axis=(1, 2))

  # Head
  x = nn.Dense(x, num_classes, name="conv_head",
               kernel_init=nn.initializers.zeros)
  return x.astype(jnp.float32)
def apply(self,
          inputs,
          mlp_dim,
          dtype=jnp.float32,
          out_dim=None,
          dropout_rate=0.1,
          deterministic=False,
          kernel_init=nn.initializers.xavier_uniform(),
          bias_init=nn.initializers.normal(stddev=1e-6),
          num_partitions=2):
  """Applies Transformer MlpBlock module."""
  actual_out_dim = inputs.shape[-1] if out_dim is None else out_dim
  inputs_shape = inputs.shape
  inputs = inputs.reshape((-1, inputs_shape[-1]))
  x = nn.Dense(inputs, mlp_dim, dtype=dtype, kernel_init=kernel_init,
               bias_init=bias_init)
  x = nn.relu(x)
  if num_partitions > 1:
    x = with_sharding_constraint(x, P(1, num_partitions))
  x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
  output = nn.Dense(x, actual_out_dim, dtype=dtype, kernel_init=kernel_init,
                    bias_init=bias_init)
  output = nn.dropout(output, rate=dropout_rate, deterministic=deterministic)
  output = output.reshape(inputs_shape[:-1] + (actual_out_dim,))
  return output
def classifier_head(encoded, num_classes, mlp_dim, pooling_mode='MEAN'):
  """Classifier head.

  We put this here just so that all models consistently call the same function.

  Args:
    encoded: tensor inputs of shape [bs, len, dim].
    num_classes: int, number of classes.
    mlp_dim: int, dim of intermediate MLP.
    pooling_mode: str, pooling op, one of {MEAN, SUM, FLATTEN, CLS}.

  Returns:
    tensor of shape [bs, num_classes]
  """
  if pooling_mode == 'MEAN':
    encoded = jnp.mean(encoded, axis=1)
  elif pooling_mode == 'SUM':
    encoded = jnp.sum(encoded, axis=1)
  elif pooling_mode == 'FLATTEN':
    encoded = encoded.reshape((encoded.shape[0], -1))
  elif pooling_mode == 'CLS':
    encoded = encoded[:, 0]
  else:
    raise NotImplementedError('Pooling not supported yet.')
  encoded = nn.Dense(encoded, mlp_dim, name='mlp')
  encoded = nn.relu(encoded)
  encoded = nn.Dense(encoded, num_classes, name='logits')
  return encoded
def apply_activation(intermediate_output, intermediate_activation):
  """Applies selected activation function to intermediate output."""
  if intermediate_activation is None:
    return intermediate_output

  if intermediate_activation == 'gelu':
    intermediate_output = nn.gelu(intermediate_output)
  elif intermediate_activation == 'relu':
    intermediate_output = nn.relu(intermediate_output)
  elif intermediate_activation == 'sigmoid':
    intermediate_output = nn.sigmoid(intermediate_output)
  elif intermediate_activation == 'softmax':
    intermediate_output = nn.softmax(intermediate_output)
  elif intermediate_activation == 'celu':
    intermediate_output = nn.celu(intermediate_output)
  elif intermediate_activation == 'elu':
    intermediate_output = nn.elu(intermediate_output)
  elif intermediate_activation == 'log_sigmoid':
    intermediate_output = nn.log_sigmoid(intermediate_output)
  elif intermediate_activation == 'log_softmax':
    intermediate_output = nn.log_softmax(intermediate_output)
  elif intermediate_activation == 'soft_sign':
    intermediate_output = nn.soft_sign(intermediate_output)
  elif intermediate_activation == 'softplus':
    intermediate_output = nn.softplus(intermediate_output)
  elif intermediate_activation == 'swish':
    intermediate_output = nn.swish(intermediate_output)
  elif intermediate_activation == 'tanh':
    intermediate_output = jnp.tanh(intermediate_output)
  else:
    raise NotImplementedError(
        '%s activation function is not yet supported.' %
        intermediate_activation)

  return intermediate_output
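# Illustrative usage of `apply_activation` (the variable names and dimensions
# here are assumptions, not from this codebase):
#   intermediate = nn.Dense(hidden_states, intermediate_dim)
#   intermediate = apply_activation(intermediate, 'gelu')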
def apply(self, x):
  x = nn.Conv(x, features=32, kernel_size=(3, 3), name="conv")
  x = nn.relu(x)
  x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
  x = x.reshape((x.shape[0], -1))  # flatten.
  x = nn.Dense(x, 128, name="fc")
  return x
def apply(self, x, num_classes, *,
          stage_sizes,
          block_cls,
          num_filters=64,
          dtype=jnp.float32,
          act=nn.relu,
          train=True):
  conv = nn.Conv.partial(bias=False, dtype=dtype)
  norm = nn.BatchNorm.partial(
      use_running_average=not train, momentum=0.9, epsilon=1e-5, dtype=dtype)

  x = conv(x, num_filters, (7, 7), (2, 2),
           padding=[(3, 3), (3, 3)],
           name='conv_init')
  x = norm(x, name='bn_init')
  x = nn.relu(x)
  x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
  for i, block_size in enumerate(stage_sizes):
    for j in range(block_size):
      strides = (2, 2) if i > 0 and j == 0 else (1, 1)
      x = block_cls(x, num_filters * 2 ** i,
                    strides=strides,
                    conv=conv,
                    norm=norm,
                    act=act)
  x = jnp.mean(x, axis=(1, 2))
  x = nn.Dense(x, num_classes, dtype=dtype)
  x = jnp.asarray(x, dtype)
  x = nn.log_softmax(x)
  return x
def apply(self, x, hidden_layers, hidden_dim, n_classes):
  x = jnp.reshape(x, (x.shape[0], -1))
  for layer in range(hidden_layers):
    x = nn.Dense(x, hidden_dim, name=f'fc{layer}')
    x = nn.relu(x)
  x = nn.Dense(x, n_classes, name=f'fc{hidden_layers}')
  preds = nn.log_softmax(x)
  return preds
def apply(self, x, reduction=16):
  num_channels = x.shape[-1]
  y = x.mean(axis=(1, 2))
  y = nn.Dense(y, features=num_channels // reduction, bias=False)
  y = nn.relu(y)
  y = nn.Dense(y, features=num_channels, bias=False)
  y = nn.sigmoid(y)
  return x * y[:, None, None, :]
def apply(self,
          x,
          num_layers,
          num_outputs,
          growth_rate,
          reduction,
          normalizer='batch_norm',
          dtype='float32',
          train=True):

  def dense_layers(y, block, num_blocks, growth_rate):
    for _ in range(num_blocks):
      y = block(y, growth_rate)
    return y

  def update_num_features(num_features, num_blocks, growth_rate, reduction):
    num_features += num_blocks * growth_rate
    if reduction is not None:
      num_features = int(math.floor(num_features * reduction))
    return num_features

  # Initial convolutional layer
  num_features = 2 * growth_rate
  conv = nn.Conv.partial(bias=False, dtype=dtype)
  y = conv(x, features=num_features, kernel_size=(3, 3),
           padding=((1, 1), (1, 1)), name='conv1')

  # Internal dense and transition blocks
  num_blocks = _block_size_options[num_layers]
  block = BottleneckBlock.partial(
      train=train, dtype=dtype, normalizer=normalizer)
  for i in range(3):
    y = dense_layers(y, block, num_blocks[i], growth_rate)
    num_features = update_num_features(num_features, num_blocks[i],
                                       growth_rate, reduction)
    y = TransitionBlock(y, num_features, train=train, dtype=dtype,
                        normalizer=normalizer)

  # Final dense block
  y = dense_layers(y, block, num_blocks[3], growth_rate)

  # Final pooling
  maybe_normalize = model_utils.get_normalizer(normalizer, train)
  y = maybe_normalize(y)
  y = nn.relu(y)
  y = nn.avg_pool(y, window_shape=(4, 4))

  # Classification layer
  y = jnp.reshape(y, (y.shape[0], -1))
  y = nn.Dense(y, num_outputs)
  return y
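# `_block_size_options`, `BottleneckBlock`, and `TransitionBlock` are defined
# elsewhere in the repository. For illustration only (an assumption, not the
# actual table), the standard DenseNet configurations use four dense blocks:
#   _block_size_options = {121: [6, 12, 24, 16], 169: [6, 12, 32, 32],
#                          201: [6, 12, 48, 32]}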
def apply(self, state, action, Q1=False):
  state_action = jnp.concatenate([state, action], axis=1)

  q1 = nn.Dense(state_action, features=256)
  q1 = nn.relu(q1)
  q1 = nn.Dense(q1, features=256)
  q1 = nn.relu(q1)
  q1 = nn.Dense(q1, features=1)

  if Q1:
    return q1

  q2 = nn.Dense(state_action, features=256)
  q2 = nn.relu(q2)
  q2 = nn.Dense(q2, features=256)
  q2 = nn.relu(q2)
  q2 = nn.Dense(q2, features=1)
  return q1, q2
def apply(self,
          x,
          num_classes,
          num_filters=64,
          num_layers=50,
          train=True,
          axis_name=None,
          axis_index_groups=None,
          dtype=jnp.float32,
          batch_norm_momentum=0.9,
          batch_norm_epsilon=1e-5,
          bn_output_scale=0.0,
          virtual_batch_size=None,
          data_format=None):
  if num_layers not in _block_size_options:
    raise ValueError('Please provide a valid number of layers')
  block_sizes = _block_size_options[num_layers]

  conv = nn.Conv.partial(padding=[(3, 3), (3, 3)])
  x = conv(x, num_filters, kernel_size=(7, 7), strides=(2, 2), bias=False,
           dtype=dtype, name='conv0')
  x = normalization.VirtualBatchNorm(
      x,
      use_running_average=not train,
      momentum=batch_norm_momentum,
      epsilon=batch_norm_epsilon,
      name='init_bn',
      axis_name=axis_name,
      axis_index_groups=axis_index_groups,
      dtype=dtype,
      virtual_batch_size=virtual_batch_size,
      data_format=data_format)
  x = nn.relu(x)  # MLPerf-required
  x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
  for i, block_size in enumerate(block_sizes):
    for j in range(block_size):
      strides = (2, 2) if i > 0 and j == 0 else (1, 1)
      x = ResidualBlock(
          x,
          num_filters * 2 ** i,
          strides=strides,
          train=train,
          axis_name=axis_name,
          axis_index_groups=axis_index_groups,
          dtype=dtype,
          batch_norm_momentum=batch_norm_momentum,
          batch_norm_epsilon=batch_norm_epsilon,
          bn_output_scale=bn_output_scale,
          virtual_batch_size=virtual_batch_size,
          data_format=data_format)
  x = jnp.mean(x, axis=(1, 2))
  x = nn.Dense(x, num_classes, kernel_init=nn.initializers.normal(),
               dtype=dtype)
  return x
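# `_block_size_options`, referenced by this model and by the space-to-depth
# ResNet variant further below, is defined elsewhere in the repository. For
# illustration only (an assumption, not the actual table), the standard
# bottleneck-ResNet stage sizes are:
#   _block_size_options = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3],
#                          152: [3, 8, 36, 3]}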
def apply(self,
          x,
          filters,
          strides=(1, 1),
          train=True,
          batch_stats=None,
          dtype=jnp.float32,
          batch_norm_momentum=0.9,
          batch_norm_epsilon=1e-5,
          virtual_batch_size=None,
          data_format=None):
  needs_projection = x.shape[-1] != filters * 4 or strides != (1, 1)
  batch_norm = normalization.VirtualBatchNorm.partial(
      batch_stats=batch_stats,
      use_running_average=not train,
      momentum=batch_norm_momentum,
      epsilon=batch_norm_epsilon,
      dtype=dtype,
      virtual_batch_size=virtual_batch_size,
      data_format=data_format)
  conv = nn.Conv.partial(bias=False, dtype=dtype)

  residual = x
  if needs_projection:
    residual = conv(residual, filters * 4, (1, 1), strides, name='proj_conv')
    residual = batch_norm(residual, name='proj_bn')

  y = conv(x, filters, (1, 1), name='conv1')
  y = batch_norm(y, name='bn1')
  y = nn.relu(y)
  y = conv(y, filters, (3, 3), strides, name='conv2')
  y = batch_norm(y, name='bn2')
  y = nn.relu(y)
  y = conv(y, filters * 4, (1, 1), name='conv3')
  y = batch_norm(y, name='bn3', scale_init=nn.initializers.zeros)
  y = nn.relu(residual + y)
  return y
def apply(self, x, num_outputs):
  """Define the convolutional network architecture.

  Architecture originates from "Human-level control through deep
  reinforcement learning.", Nature 518, no. 7540 (2015): 529-533.
  Note that this is different from the one in "Playing atari with deep
  reinforcement learning." arxiv.org/abs/1312.5602 (2013).
  """
  dtype = jnp.float32
  x = x.astype(dtype) / 255.
  x = nn.Conv(x, features=32, kernel_size=(8, 8), strides=(4, 4),
              name='conv1', dtype=dtype)
  x = nn.relu(x)
  x = nn.Conv(x, features=64, kernel_size=(4, 4), strides=(2, 2),
              name='conv2', dtype=dtype)
  x = nn.relu(x)
  x = nn.Conv(x, features=64, kernel_size=(3, 3), strides=(1, 1),
              name='conv3', dtype=dtype)
  x = nn.relu(x)
  x = x.reshape((x.shape[0], -1))  # flatten
  x = nn.Dense(x, features=512, name='hidden', dtype=dtype)
  x = nn.relu(x)
  # Network used to both estimate policy (logits) and expected state value.
  # See github.com/openai/baselines/blob/master/baselines/ppo1/cnn_policy.py
  logits = nn.Dense(x, features=num_outputs, name='logits', dtype=dtype)
  policy_log_probabilities = nn.log_softmax(logits)
  value = nn.Dense(x, features=1, name='value', dtype=dtype)
  return policy_log_probabilities, value
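# Shape sketch (assumptions: 84x84x4 Atari frames and Flax's default 'SAME'
# convolution padding): conv1 (stride 4) -> 21x21x32, conv2 (stride 2) ->
# 11x11x64, conv3 (stride 1) -> 11x11x64, so the flattened feature fed to the
# 512-unit hidden layer has 11 * 11 * 64 = 7744 entries.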
def apply(self,
          x,
          num_classes,
          num_filters=64,
          num_layers=50,
          train=True,
          axis_name=None,
          axis_index_groups=None,
          dtype=jnp.float32,
          conv0_space_to_depth=False):
  if num_layers not in _block_size_options:
    raise ValueError('Please provide a valid number of layers')
  block_sizes = _block_size_options[num_layers]

  if conv0_space_to_depth:
    conv = SpaceToDepthConv.partial(block_size=(2, 2),
                                    padding=[(2, 1), (2, 1)])
  else:
    conv = nn.Conv.partial(padding=[(3, 3), (3, 3)])
  x = conv(x, num_filters, kernel_size=(7, 7), strides=(2, 2), bias=False,
           dtype=dtype, name='conv0')
  x = nn.BatchNorm(x,
                   use_running_average=not train,
                   momentum=0.9,
                   epsilon=1e-5,
                   name='init_bn',
                   axis_name=axis_name,
                   axis_index_groups=axis_index_groups,
                   dtype=dtype)
  x = nn.relu(x)  # MLPerf-required
  x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
  for i, block_size in enumerate(block_sizes):
    for j in range(block_size):
      strides = (2, 2) if i > 0 and j == 0 else (1, 1)
      x = ResidualBlock(x,
                        num_filters * 2**i,
                        strides=strides,
                        train=train,
                        axis_name=axis_name,
                        axis_index_groups=axis_index_groups,
                        dtype=dtype)
  x = jnp.mean(x, axis=(1, 2))
  x = nn.Dense(x, num_classes, kernel_init=nn.initializers.normal(),
               dtype=dtype)
  x = nn.log_softmax(x)
  return x
def apply(self, x, n_layers, n_features, n_kernel_sizes):
  x = jnp.expand_dims(x, axis=2)
  for layer in range(n_layers):
    features = n_features[layer]
    kernel_size = (n_kernel_sizes[layer], 1)
    x = nn.Conv(x, features=features, kernel_size=kernel_size)
    x = nn.relu(x)
  x = jnp.squeeze(x, axis=2)
  return x
def apply(self, actions, num_layers, hidden_dims):
  timesteps = actions.shape[1]
  # flatten time into batch
  actions = jnp.reshape(actions, (-1,) + actions.shape[2:])
  # embed actions
  x = nn.Dense(actions, hidden_dims)
  for _ in range(num_layers):
    x = nn.Dense(x, hidden_dims)
    x = nn.LayerNorm(x)
    x = nn.relu(x)
  x = nn.Dense(x, 1)
  x = jnp.reshape(x, (-1, timesteps, 1))
  return x
def apply(self, x, num_features, train=True, dtype=jnp.float32,
          normalizer='batch_norm'):
  conv = nn.Conv.partial(bias=False, dtype=dtype)
  maybe_normalize = model_utils.get_normalizer(normalizer, train)
  y = maybe_normalize(x)
  y = nn.relu(y)
  y = conv(y, features=num_features, kernel_size=(1, 1))
  y = nn.avg_pool(y, window_shape=(2, 2))
  return y
def apply(self, x, nout, strides=(1, 1)):
  x_shortcut = x
  needs_projection = x.shape[-1] != nout * 4 or strides != (1, 1)

  group_norm = GroupNorm
  conv = StdConv.partial(bias=False)

  x = group_norm(x, name="gn1")
  x = nn.relu(x)
  if needs_projection:
    x_shortcut = conv(x, nout * 4, (1, 1), strides, name="conv_proj")
  x = conv(x, nout, (1, 1), name="conv1")

  x = group_norm(x, name="gn2")
  x = nn.relu(x)
  x = fixed_padding(x, 3)
  x = conv(x, nout, (3, 3), strides, name="conv2", padding='VALID')

  x = group_norm(x, name="gn3")
  x = nn.relu(x)
  x = conv(x, nout * 4, (1, 1), name="conv3")
  return x + x_shortcut
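# `StdConv` and `fixed_padding`, used by the GroupNorm residual blocks above,
# are defined elsewhere. The helper below is only a sketch of the
# weight-standardization idea behind StdConv, assuming an HWIO kernel layout;
# it is illustrative, not the repository's implementation.
def weight_standardize(kernel, eps=1e-5):
  # Standardize each output filter over its spatial and input-channel axes.
  mean = jnp.mean(kernel, axis=(0, 1, 2), keepdims=True)
  var = jnp.var(kernel, axis=(0, 1, 2), keepdims=True)
  return (kernel - mean) / jnp.sqrt(var + eps)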