def apply(self,
          x,
          num_classes,
          *,
          stage_sizes,
          block_cls,
          num_filters=64,
          dtype=jnp.float32,
          act=nn.relu,
          train=True):
  conv = nn.Conv.partial(bias=False, dtype=dtype)
  norm = nn.BatchNorm.partial(
      use_running_average=not train, momentum=0.9, epsilon=1e-5, dtype=dtype)
  x = conv(x, num_filters, (7, 7), (2, 2),
           padding=[(3, 3), (3, 3)],
           name='conv_init')
  x = norm(x, name='bn_init')
  x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
  for i, block_size in enumerate(stage_sizes):
    for j in range(block_size):
      strides = (2, 2) if i > 0 and j == 0 else (1, 1)
      x = block_cls(x, num_filters * 2 ** i, strides=strides,
                    conv=conv, norm=norm, act=act)
  x = jnp.mean(x, axis=(1, 2))
  x = nn.Dense(x, num_classes, dtype=dtype)
  return x
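# Usage sketch (not from the original source): the `apply` methods in this
# collection are written against the pre-Linen `flax.nn` Module API (now
# `flax.deprecated.nn`), where layers are called as functions and parameters
# are created on first application. Assuming a module `ResNet` wrapping the
# `apply` above and a hypothetical `ResNetBlock` standing in for `block_cls`,
# initialization might look roughly like this:
import jax
import jax.numpy as jnp
from flax.deprecated import nn

model_def = ResNet.partial(num_classes=1000,
                           stage_sizes=[3, 4, 6, 3],  # ResNet-50 layout
                           block_cls=ResNetBlock)     # hypothetical block
with nn.stateful() as init_state:  # collects BatchNorm running statistics
  _, params = model_def.init_by_shape(
      jax.random.PRNGKey(0), [((1, 224, 224, 3), jnp.float32)])
model = nn.Model(model_def, params)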
def apply(self, x, num_classes, num_filters=64, num_layers=50,
          train=True, dtype=jnp.float32):
  if num_layers not in _block_size_options:
    raise ValueError('Please provide a valid number of layers')
  block_sizes = _block_size_options[num_layers]
  x = nn.Conv(x, num_filters, (7, 7), (2, 2),
              padding=[(3, 3), (3, 3)],
              bias=False, dtype=dtype, name='init_conv')
  x = nn.BatchNorm(x, use_running_average=not train, momentum=0.9,
                   epsilon=1e-5, dtype=dtype, name='init_bn')
  x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
  for i, block_size in enumerate(block_sizes):
    for j in range(block_size):
      strides = (2, 2) if i > 0 and j == 0 else (1, 1)
      x = ResidualBlock(x, num_filters * 2**i, strides=strides,
                        train=train, dtype=dtype)
  x = jnp.mean(x, axis=(1, 2))
  x = nn.Dense(x, num_classes)
  x = nn.log_softmax(x)
  return x
def apply(
    self,
    x,
    num_outputs,
    num_filters=64,
    block_sizes=[3, 4, 6, 3],  # pylint: disable=dangerous-default-value
    train=True):
  x = nn.Conv(x, num_filters, (7, 7), (2, 2), bias=False, name='init_conv')
  x = nn.BatchNorm(x, use_running_average=not train, epsilon=1e-5,
                   name='init_bn')
  x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
  for i, block_size in enumerate(block_sizes):
    for j in range(block_size):
      strides = (2, 2) if i > 0 and j == 0 else (1, 1)
      x = BottleneckBlock(x, num_filters * 2**i, strides=strides, train=train)
  x = jnp.mean(x, axis=(1, 2))
  x_clf = nn.Dense(x, num_outputs, name='clf')
  # We return both the outputs from the dense layer *and* the features
  # that go into it.
  return x_clf, x
def apply(self, x, num_filters=64, block_sizes=(3, 4, 6, 3), train=True,
          block=BottleneckBlock, small_inputs=False):
  if small_inputs:
    x = nn.Conv(x, num_filters, kernel_size=(3, 3), strides=(1, 1),
                bias=False, name="init_conv")
  else:
    x = nn.Conv(x, num_filters, kernel_size=(7, 7), strides=(2, 2),
                bias=False, name="init_conv")
  x = nn.BatchNorm(x, use_running_average=not train, epsilon=1e-5,
                   name="init_bn")
  if not small_inputs:
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME")
  for i, block_size in enumerate(block_sizes):
    for j in range(block_size):
      strides = (2, 2) if i > 0 and j == 0 else (1, 1)
      x = block(x, num_filters * 2**i, strides=strides, train=train)
  return x
def apply(self, x, n_bins, stage_sizes, block_cls, num_filters=64,
          dtype=jnp.float32, act=nn.leaky_relu, train=True):
  b = x.shape[0]
  conv = nn.Conv.partial(bias=False, dtype=dtype)
  norm = nn.BatchNorm.partial(
      use_running_average=not train, dtype=dtype)
  x = conv(x, num_filters, kernel_size=(7,), strides=(2,), padding=[(3, 3)])
  x = norm(x)
  x = nn.leaky_relu(x)
  x = nn.max_pool(x, window_shape=(3,), strides=(2,), padding='SAME')
  for i, block_size in enumerate(stage_sizes):
    for j in range(block_size):
      strides = (2,) if i > 0 and j == 0 else (1,)
      x = block_cls(x, num_filters * 2 ** i, strides=strides,
                    conv=conv, norm=norm, act=act)
  x = x.reshape(b, -1)
  x = nn.Dense(x, n_bins, dtype=dtype)
  x = nn.softmax(x)
  return x
def apply(self, x, features, n_stages, act=nn.relu):
  x = act(x)
  path = x
  for _ in range(n_stages):
    path = nn.max_pool(path, window_shape=(5, 5), strides=(1, 1),
                       padding='SAME')
    path = ncsn_conv3x3(path, features, stride=1, bias=False)
    x = path + x
  return x
def apply(
    self,
    x,
    num_classes,
    num_filters=64,
    num_layers=50,
    train=True,
    axis_name=None,
    axis_index_groups=None,
    dtype=jnp.float32,
    batch_norm_momentum=0.9,
    batch_norm_epsilon=1e-5,
    bn_output_scale=0.0,
    virtual_batch_size=None,
    data_format=None):
  if num_layers not in _block_size_options:
    raise ValueError('Please provide a valid number of layers')
  block_sizes = _block_size_options[num_layers]
  conv = nn.Conv.partial(padding=[(3, 3), (3, 3)])
  x = conv(x, num_filters, kernel_size=(7, 7), strides=(2, 2), bias=False,
           dtype=dtype, name='conv0')
  x = normalization.VirtualBatchNorm(
      x,
      use_running_average=not train,
      momentum=batch_norm_momentum,
      epsilon=batch_norm_epsilon,
      name='init_bn',
      axis_name=axis_name,
      axis_index_groups=axis_index_groups,
      dtype=dtype,
      virtual_batch_size=virtual_batch_size,
      data_format=data_format)
  x = nn.relu(x)  # MLPerf-required
  x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
  for i, block_size in enumerate(block_sizes):
    for j in range(block_size):
      strides = (2, 2) if i > 0 and j == 0 else (1, 1)
      x = ResidualBlock(
          x,
          num_filters * 2 ** i,
          strides=strides,
          train=train,
          axis_name=axis_name,
          axis_index_groups=axis_index_groups,
          dtype=dtype,
          batch_norm_momentum=batch_norm_momentum,
          batch_norm_epsilon=batch_norm_epsilon,
          bn_output_scale=bn_output_scale,
          virtual_batch_size=virtual_batch_size,
          data_format=data_format)
  x = jnp.mean(x, axis=(1, 2))
  x = nn.Dense(x, num_classes, kernel_init=nn.initializers.normal(),
               dtype=dtype)
  return x
def apply(self, x, width):
  x = fixed_padding(x, 7)
  x = StdConv(x, width, (7, 7), (2, 2), padding="VALID", bias=False,
              name="conv_root")
  x = fixed_padding(x, 3)
  x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="VALID")
  return x
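# A plausible sketch (assumption, not the original source) of the
# `fixed_padding` helper used above: it pads height and width explicitly so
# that the following VALID conv/pool behaves like SAME padding regardless of
# the input size — the standard TF-ResNet trick for shape-stable stems.
def fixed_padding(x, kernel_size):
  pad_total = kernel_size - 1
  pad_beg = pad_total // 2
  pad_end = pad_total - pad_beg
  # NHWC layout: pad only the two spatial dimensions.
  return jnp.pad(x, [(0, 0), (pad_beg, pad_end), (pad_beg, pad_end), (0, 0)])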
def apply(self, x, num_classes, num_filters=64, num_layers=50,
          train=True, axis_name=None, axis_index_groups=None,
          dtype=jnp.float32, conv0_space_to_depth=False):
  if num_layers not in _block_size_options:
    raise ValueError('Please provide a valid number of layers')
  block_sizes = _block_size_options[num_layers]
  if conv0_space_to_depth:
    conv = SpaceToDepthConv.partial(block_size=(2, 2),
                                    padding=[(2, 1), (2, 1)])
  else:
    conv = nn.Conv.partial(padding=[(3, 3), (3, 3)])
  x = conv(x, num_filters, kernel_size=(7, 7), strides=(2, 2), bias=False,
           dtype=dtype, name='conv0')
  x = nn.BatchNorm(x,
                   use_running_average=not train,
                   momentum=0.9,
                   epsilon=1e-5,
                   name='init_bn',
                   axis_name=axis_name,
                   axis_index_groups=axis_index_groups,
                   dtype=dtype)
  x = nn.relu(x)  # MLPerf-required
  x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
  for i, block_size in enumerate(block_sizes):
    for j in range(block_size):
      strides = (2, 2) if i > 0 and j == 0 else (1, 1)
      x = ResidualBlock(x, num_filters * 2**i, strides=strides,
                        train=train, axis_name=axis_name,
                        axis_index_groups=axis_index_groups, dtype=dtype)
  x = jnp.mean(x, axis=(1, 2))
  x = nn.Dense(x, num_classes, kernel_init=nn.initializers.normal(),
               dtype=dtype)
  x = nn.log_softmax(x)
  return x
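# Sketch of the space-to-depth idea behind `SpaceToDepthConv` (an assumption
# about its internals): fold each 2x2 spatial block into the channel
# dimension before conv0, so the 7x7/stride-2 stem runs as a denser,
# TPU-friendlier convolution on a smaller spatial grid.
def space_to_depth(x, block_size=2):
  n, h, w, c = x.shape
  x = x.reshape(n, h // block_size, block_size, w // block_size, block_size, c)
  x = x.transpose((0, 1, 3, 2, 4, 5))
  return x.reshape(n, h // block_size, w // block_size,
                   c * block_size * block_size)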
def apply(self, x, num_classes, parameters, num_filters=64, train=True,
          axis_name=None, num_layers='34'):
  block_sizes = [3, 4, 6]
  data_format = 'channels_last'
  if ('conv0_space_to_depth' in parameters and
      parameters['conv0_space_to_depth']):
    # conv0 uses space-to-depth transform for TPU performance.
    x = func_conv0_space_to_depth(inputs=x, data_format=data_format,
                                  dtype=parameters['dtype'])
  else:
    x = conv2d_fixed_padding(
        inputs=x,
        filters=num_filters,
        kernel_size=7,
        strides=2,
        data_format=data_format,
        name='init_conv')
  replica_groups = _make_replica_groups(parameters)
  x = nn.BatchNorm(
      x,
      use_running_average=not train,
      momentum=0.9,
      epsilon=1e-5,
      name='init_bn',
      axis_name=axis_name,
      dtype=parameters['dtype'],
      axis_index_groups=replica_groups)
  x = nn.relu(x)
  x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
  for i, block_size in enumerate(block_sizes):
    for j in range(block_size):
      strides = 2 if i == 1 and j == 0 else 1
      use_projection = False if i == 0 or j > 0 else True
      x = ResidualBlock(
          x,
          num_filters * 2**i,
          parameters,
          strides=strides,
          train=train,
          axis_name=axis_name,
          use_projection=use_projection,
          data_format=data_format)
  if num_layers == '34':
    x = jnp.mean(x, axis=(1, 2))
    x = nn.Dense(x, num_classes, kernel_init=nn.initializers.normal(),
                 dtype=jnp.float32)  # TODO(deveci): dtype=dtype
    x = nn.log_softmax(x)
  return x
def apply(self, x, use_squeeze_excite=False):
  x = nn.Conv(x, features=8, kernel_size=(3, 3), padding="VALID")
  x = nn.relu(x)
  x = nn.Conv(x, features=16, kernel_size=(3, 3), padding="VALID")
  x = nn.relu(x)
  if use_squeeze_excite:
    x = SqueezeExciteLayer(x)
  x = nn.Conv(x, features=32, kernel_size=(3, 3), padding="VALID")
  x = nn.relu(x)
  if use_squeeze_excite:
    x = SqueezeExciteLayer(x)
  x = nn.Conv(x, features=1, kernel_size=(3, 3), padding="VALID")
  scores = nn.max_pool(x, window_shape=(8, 8), strides=(8, 8))[Ellipsis, 0]
  return scores
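# Minimal sketch (assumption) of a `SqueezeExciteLayer` compatible with the
# calls above, written in the same deprecated flax.nn style: global average
# pool (squeeze), a bottleneck MLP, then a sigmoid gate over channels
# (excite) that rescales the input feature map.
class SqueezeExciteLayer(nn.Module):

  def apply(self, x, reduction=4):
    pooled = jnp.mean(x, axis=(1, 2))  # squeeze: (N, H, W, C) -> (N, C)
    squeezed = nn.relu(nn.Dense(pooled, x.shape[-1] // reduction))
    gate = nn.sigmoid(nn.Dense(squeezed, x.shape[-1]))
    return x * gate[:, None, None, :]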
def apply(
    self,
    x,
    num_outputs,
    num_filters=64,
    num_layers=50,
    train=True,
    batch_stats=None,
    dtype=jnp.float32,
    batch_norm_momentum=0.9,
    batch_norm_epsilon=1e-5,
    virtual_batch_size=None,
    data_format=None):
  if num_layers not in _block_size_options:
    raise ValueError('Please provide a valid number of layers')
  block_sizes = _block_size_options[num_layers]
  x = nn.Conv(x, num_filters, (7, 7), (2, 2), bias=False, dtype=dtype,
              name='init_conv')
  x = normalization.VirtualBatchNorm(
      x,
      batch_stats=batch_stats,
      use_running_average=not train,
      momentum=batch_norm_momentum,
      epsilon=batch_norm_epsilon,
      dtype=dtype,
      name='init_bn',
      virtual_batch_size=virtual_batch_size,
      data_format=data_format)
  x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
  for i, block_size in enumerate(block_sizes):
    for j in range(block_size):
      strides = (2, 2) if i > 0 and j == 0 else (1, 1)
      x = ResidualBlock(
          x,
          num_filters * 2 ** i,
          strides=strides,
          train=train,
          batch_stats=batch_stats,
          dtype=dtype,
          batch_norm_momentum=batch_norm_momentum,
          batch_norm_epsilon=batch_norm_epsilon,
          virtual_batch_size=virtual_batch_size,
          data_format=data_format)
  x = jnp.mean(x, axis=(1, 2))
  x = nn.Dense(x, num_outputs, dtype=dtype)
  return x
def test_max_pool(self):
  x = jnp.arange(9).reshape((1, 3, 3, 1)).astype(jnp.float32)
  pool = lambda x: nn.max_pool(x, (2, 2))
  expected_y = jnp.array([
      [4., 5.],
      [7., 8.],
  ]).reshape((1, 2, 2, 1))
  y = pool(x)
  onp.testing.assert_allclose(y, expected_y)
  y_grad = jax.grad(lambda x: pool(x).sum())(x)
  expected_grad = jnp.array([
      [0., 0., 0.],
      [0., 1., 1.],
      [0., 1., 1.],
  ]).reshape((1, 3, 3, 1))
  onp.testing.assert_allclose(y_grad, expected_grad)
def features(x, num_layers, normalizer, dtype, train):
  """Implements the feature extraction portion of the network."""
  layers = _layer_size_options[num_layers]
  conv = nn.Conv.partial(bias=False, dtype=dtype)
  maybe_normalize = model_utils.get_normalizer(normalizer, train)
  for l in layers:
    if l == 'M':
      x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
    else:
      x = conv(x, features=l, kernel_size=(3, 3), padding=((1, 1), (1, 1)))
      x = maybe_normalize(x)
      x = nn.relu(x)
  return x
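# Illustrative entry for `_layer_size_options` (an assumption about its
# contents): the classic VGG-16 feature layout, where integers are conv
# widths and 'M' marks a 2x2/stride-2 max-pooling stage consumed by
# `features` above.
_layer_size_options = {
    16: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
         512, 512, 512, 'M', 512, 512, 512, 'M'],
}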
def apply(self, x, num_outputs, train=True):
  x = nn.Conv(x, self.NUM_FILTERS, (7, 7), (2, 2), bias=False,
              name='init_conv')
  x = nn.BatchNorm(x, use_running_average=not train, momentum=0.9,
                   epsilon=1e-5, name='init_bn')
  x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
  for i, block_size in enumerate(self.BLOCK_SIZES):
    for j in range(block_size):
      strides = (2, 2) if i > 0 and j == 0 else (1, 1)
      x = BottleneckBlock(x, self.NUM_FILTERS * 2 ** i, strides=strides,
                          groups=self.GROUPS,
                          base_width=self.WIDTH_PER_GROUP,
                          train=train)
  x = jnp.mean(x, axis=(1, 2))
  x = nn.Dense(x, num_outputs, name='clf')
  return x
def test_max_pool_explicit_pads(self):
  x = jnp.arange(9).reshape((1, 3, 3, 1)).astype(jnp.float32)
  pool = lambda x: nn.max_pool(x, (2, 2), padding=((1, 1), (1, 1)))
  expected_y = jnp.array([
      [0., 1., 2., 2.],
      [3., 4., 5., 5.],
      [6., 7., 8., 8.],
      [6., 7., 8., 8.],
  ]).reshape((1, 4, 4, 1))
  y = pool(x)
  onp.testing.assert_allclose(y, expected_y)
  y_grad = jax.grad(lambda x: pool(x).sum())(x)
  expected_grad = jnp.array([
      [1., 1., 2.],
      [1., 1., 2.],
      [2., 2., 4.],
  ]).reshape((1, 3, 3, 1))
  onp.testing.assert_allclose(y_grad, expected_grad)
def apply(self, x, output_shape, encoder, decoder, train=True):
  if not len(encoder['filter_sizes']) == len(encoder['kernel_sizes']) == len(
      encoder['kernel_paddings']) == len(encoder['window_sizes']) == len(
          encoder['window_paddings']) == len(encoder['strides']) == len(
              encoder['activations']):
    raise ValueError(
        'The elements in encoder dict do not have the same length.')
  if not len(decoder['filter_sizes']) == len(decoder['kernel_sizes']) == len(
      decoder['window_sizes']) == len(decoder['paddings']) == len(
          decoder['activations']):
    raise ValueError(
        'The elements in decoder dict do not have the same length.')

  # encoder
  for i in range(len(encoder['filter_sizes'])):
    x = nn.Conv(
        x,
        encoder['filter_sizes'][i],
        encoder['kernel_sizes'][i],
        padding=encoder['kernel_paddings'][i])
    x = model_utils.ACTIVATIONS[encoder['activations'][i]](x)
    x = nn.max_pool(
        x,
        encoder['window_sizes'][i],
        strides=encoder['strides'][i],
        padding=encoder['window_paddings'][i])

  # decoder
  for i in range(len(decoder['filter_sizes'])):
    x = nn.ConvTranspose(
        x,
        decoder['filter_sizes'][i],
        decoder['kernel_sizes'][i],
        decoder['window_sizes'][i],
        padding=decoder['paddings'][i])
    x = model_utils.ACTIVATIONS[decoder['activations'][i]](x)
  return x
def apply(self,
          x,
          num_outputs,
          num_filters,
          kernel_sizes,
          kernel_paddings,
          window_sizes,
          window_paddings,
          strides,
          num_dense_units,
          activation_fn,
          normalizer='none',
          kernel_init=initializers.lecun_normal(),
          bias_init=initializers.zeros,
          train=True):
  maybe_normalize = model_utils.get_normalizer(normalizer, train)
  for num_filters, kernel_size, kernel_padding, window_size, window_padding, stride in zip(
      num_filters, kernel_sizes, kernel_paddings, window_sizes,
      window_paddings, strides):
    x = nn.Conv(
        x,
        num_filters, (kernel_size, kernel_size), (1, 1),
        padding=kernel_padding,
        kernel_init=kernel_init,
        bias_init=bias_init)
    x = model_utils.ACTIVATIONS[activation_fn](x)
    x = maybe_normalize(x)
    x = nn.max_pool(
        x,
        window_shape=(window_size, window_size),
        strides=(stride, stride),
        padding=window_padding)
  x = jnp.reshape(x, (x.shape[0], -1))
  for num_units in num_dense_units:
    x = nn.Dense(x, num_units, kernel_init=kernel_init, bias_init=bias_init)
    x = model_utils.ACTIVATIONS[activation_fn](x)
    x = maybe_normalize(x)
  x = nn.Dense(x, num_outputs, kernel_init=kernel_init, bias_init=bias_init)
  return x
def apply(self, x, num_classes=1000, train=False, width_factor=1,
          num_layers=50):
  del train
  blocks, bottleneck = get_block_desc(num_layers)
  width = int(64 * width_factor)

  # Root block
  x = StdConv(x, width, (7, 7), (2, 2), bias=False, name="conv_root")
  x = nn.GroupNorm(x, name="gn_root")
  x = nn.relu(x)
  x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME")

  # Stages
  x = ResNetStage(x, blocks[0], width, first_stride=(1, 1),
                  bottleneck=bottleneck, name="block1")
  for i, block_size in enumerate(blocks[1:], 1):
    x = ResNetStage(x, block_size, width * 2**i, first_stride=(2, 2),
                    bottleneck=bottleneck, name=f"block{i + 1}")

  # Head
  x = jnp.mean(x, axis=(1, 2))
  x = IdentityLayer(x, name="pre_logits")
  x = nn.Dense(x, num_classes, kernel_init=nn.initializers.zeros,
               name="head")
  return x
def apply(self,
          x,
          *,
          train,
          num_classes,
          block_class=BottleneckResNetImageNetBlock,
          stage_sizes,
          width_factor=1,
          normalization='bn',
          activation_f=None,
          std_penalty_mult=0,
          use_residual=1,
          bias_scale=0.0,
          weight_norm='none',
          compensate_padding=True,
          softplus_scale=None,
          no_head=False,
          zero_inits=True):
  """Construct ResNet V1 with `num_classes` outputs."""
  self._stage_sizes = stage_sizes
  if std_penalty_mult > 0:
    raise NotImplementedError(
        'std_penalty_mult not supported for ResNetImageNet')

  width = 64 * width_factor

  # Root block.
  activation_f = get_activation_f(activation_f, train, softplus_scale,
                                  bias_scale)
  norm = get_norm(activation_f, normalization, train)
  conv = get_conv(activation_f, bias_scale, weight_norm, compensate_padding,
                  normalization)
  x = conv(x, width, kernel_size=(7, 7), strides=(2, 2), name='init_conv')
  x = norm(x, name='init_bn')
  if compensate_padding:
    # NOTE: this leads to lower performance.
    x = nn.avg_pool(x, (2, 2), strides=(2, 2), padding='SAME')
  else:
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')

  # Stages.
  for i, stage_size in enumerate(stage_sizes):
    x = ResNetStage(
        x,
        stage_size,
        filters=width * 2**i,
        block_class=block_class,
        first_block_strides=(1, 1) if i == 0 else (2, 2),
        train=train,
        name=f'stage{i + 1}',
        conv=conv,
        norm=norm,
        activation_f=activation_f,
        use_residual=use_residual,
        zero_inits=zero_inits,
    )

  if not no_head:
    # Head.
    x = jnp.mean(x, axis=(1, 2))
    x = nn.Dense(
        x,
        num_classes,
        kernel_init=nn.initializers.zeros
        if zero_inits else nn.initializers.lecun_normal(),
        name='head')
  return x, 0, {}
def apply(self,
          x,
          depth,
          num_outputs,
          dropout_rate=0.0,
          normalization='bn',
          activation_f=None,
          std_penalty_mult=0,
          use_residual=1,
          train=True,
          bias_scale=0.0,
          weight_norm='none',
          filters=16,
          no_head=False,
          report_metrics=False,
          benchmark='cifar10',
          compensate_padding=True,
          softplus_scale=None):
  bn_index = iter(range(1000))
  conv_index = iter(range(1000))
  summaries = {}
  summary_ind = [0]

  def add_summary(name, val):
    """Summarize statistics of tensor."""
    if report_metrics:
      assert val.ndim == 4, (
          'Assuming 4D inputs with channels last, got %s' % str(val.shape))
      assert val.shape[1] == val.shape[2], (
          'Assuming 4D inputs with channels last')
      summaries['%s_%d_mean_abs' % (name, summary_ind[0] // 2)] = jnp.mean(
          jnp.abs(jnp.mean(val, axis=(0, 1, 2))))
      summaries['%s_%d_mean_std' % (name, summary_ind[0] // 2)] = jnp.mean(
          jnp.std(val, axis=(0, 1, 2)))
      summary_ind[0] += 1

  penalty = 0
  activation_f = get_activation_f(activation_f, train, softplus_scale,
                                  bias_scale)
  norm = get_norm(activation_f, normalization, train)
  conv = get_conv(activation_f, bias_scale, weight_norm, compensate_padding,
                  normalization)

  def resnet_layer(
      inputs,
      penalty,
      filters,
      kernel_size=3,
      strides=1,
      activation=None,
  ):
    """2D Convolution-Batch Normalization-Activation stack builder."""
    x = inputs
    x = conv(x, filters, (kernel_size, kernel_size),
             strides=(strides, strides), padding='SAME',
             name='conv%d' % next(conv_index))
    x = norm(x, name='norm%d' % next(bn_index))
    add_summary('postnorm', x)
    if std_penalty_mult > 0:
      penalty += std_penalty(x)
    if activation:
      x = activation_f(x, features=x.shape[-1])
    add_summary('postact', x)
    return x, penalty

  # Main network code.
  num_res_blocks = (depth - 2) // 6
  if (depth - 2) % 6 != 0:
    raise ValueError('depth must be 6n+2 (e.g. 20, 32, 44).')
  inputs = x
  add_summary('input', x)
  add_summary('inputb', x)
  if benchmark in ['cifar10', 'cifar100']:
    x, penalty = resnet_layer(inputs, penalty, filters=filters,
                              activation=True)
    head_kernel_init = nn.initializers.lecun_normal()
  elif benchmark in ['imagenet']:
    head_kernel_init = nn.initializers.zeros
    x, penalty = resnet_layer(inputs, penalty, filters=filters,
                              activation=False, kernel_size=7, strides=2)
    # TODO(basv): evaluate max pool v/s avg_pool in an experiment?
    # if compensate_padding:
    #   x = nn.avg_pool(x, (2, 2), strides=(2, 2), padding="VALID")
    # else:
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
  else:
    raise ValueError('Model def not prepared for benchmark %s' % benchmark)

  for stack in range(3):
    for res_block in range(num_res_blocks):
      strides = 1
      if stack > 0 and res_block == 0:  # First layer but not first stack.
        strides = 2  # Downsample.
      y, penalty = resnet_layer(
          x,
          penalty,
          filters=filters,
          strides=strides,
          activation=True,
      )
      y, penalty = resnet_layer(
          y,
          penalty,
          filters=filters,
          activation=False,
      )
      if stack > 0 and res_block == 0:  # First layer but not first stack.
        # Linear projection residual shortcut to match changed dims.
        x, penalty = resnet_layer(
            x,
            penalty,
            filters=filters,
            kernel_size=1,
            strides=strides,
            activation=False,
        )
      if use_residual == 1:
        # Apply an up projection in case of channel mismatch.
        x = x + y
      elif use_residual == 2:
        x = (x + y) / jnp.sqrt(1**2 + 1**2)  # Sum of independent normals.
      else:
        x = y
      add_summary('postres', x)
      x = activation_f(x, features=x.shape[-1])
      add_summary('postresact', x)
    filters *= 2

  # V1 does not use BN after last shortcut connection-ReLU.
  if not no_head:
    x = jnp.mean(x, axis=(1, 2))
    add_summary('postpool', x)
    x = x.reshape((x.shape[0], -1))
    x = nn.Dense(x, num_outputs, kernel_init=head_kernel_init)
  return x, penalty, summaries
def apply(self,
          x,
          num_classes=1000,
          train=False,
          resnet=None,
          patches=None,
          hidden_size=None,
          transformer=None,
          representation_size=None,
          classifier='gap'):
  # (Possibly partial) ResNet root.
  if resnet is not None:
    width = int(64 * resnet.width_factor)

    # Root block.
    x = models_resnet.StdConv(x, width, (7, 7), (2, 2), bias=False,
                              name='conv_root')
    x = nn.GroupNorm(x, name='gn_root')
    x = nn.relu(x)
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')

    # ResNet stages.
    x = models_resnet.ResNetStage(x, resnet.num_layers[0], width,
                                  first_stride=(1, 1), name='block1')
    for i, block_size in enumerate(resnet.num_layers[1:], 1):
      x = models_resnet.ResNetStage(x, block_size, width * 2**i,
                                    first_stride=(2, 2),
                                    name=f'block{i + 1}')

  n, h, w, c = x.shape

  # We can merge s2d+emb into a single conv; it's the same.
  x = nn.Conv(x, hidden_size, patches.size, strides=patches.size,
              padding='VALID', name='embedding')

  # Here, x is a grid of embeddings.

  # (Possibly partial) Transformer.
  if transformer is not None:
    n, h, w, c = x.shape
    x = jnp.reshape(x, [n, h * w, c])

    # If we want to add a class token, add it here.
    if classifier == 'token':
      cls = self.param('cls', (1, 1, c), nn.initializers.zeros)
      cls = jnp.tile(cls, [n, 1, 1])
      x = jnp.concatenate([cls, x], axis=1)

    x = Encoder(x, train=train, name='Transformer', **transformer)

  if classifier == 'token':
    x = x[:, 0]
  elif classifier == 'gap':
    x = jnp.mean(x, axis=list(range(1, x.ndim - 1)))  # (1,) or (1,2)

  if representation_size is not None:
    x = nn.Dense(x, representation_size, name='pre_logits')
    x = nn.tanh(x)
  else:
    x = IdentityLayer(x, name='pre_logits')

  x = nn.Dense(x, num_classes, name='head',
               kernel_init=nn.initializers.zeros)
  return x
def apply(self,
          x,
          num_classes=1,
          train=False,
          hidden_size=None,
          transformer=None,
          resnet_emb=None,
          representation_size=None):
  """Apply model on inputs.

  Args:
    x: the processed input patches and position annotations.
    num_classes: the number of output classes. 1 for single model.
    train: train or eval.
    hidden_size: the hidden dimension for patch embedding tokens.
    transformer: the model config for Transformer backbone.
    resnet_emb: the config for patch embedding w/ small resnet.
    representation_size: size of the last FC before prediction.

  Returns:
    Model prediction output.
  """
  assert transformer is not None
  # Either 3: (batch size, seq len, channel) or
  # 4: (batch size, crops, seq len, channel)
  assert len(x.shape) in [3, 4]
  multi_crops_input = False
  if len(x.shape) == 4:
    multi_crops_input = True
    batch_size, num_crops, l, channel = x.shape
    x = jnp.reshape(x, [batch_size * num_crops, l, channel])

  # We concat (x, spatial_positions, scale_positions, input_masks)
  # when preprocessing.
  inputs_spatial_positions = x[:, :, -3]
  inputs_spatial_positions = inputs_spatial_positions.astype(jnp.int32)
  inputs_scale_positions = x[:, :, -2]
  inputs_scale_positions = inputs_scale_positions.astype(jnp.int32)
  inputs_masks = x[:, :, -1]
  inputs_masks = inputs_masks.astype(jnp.bool_)
  x = x[:, :, :-3]
  n, l, channel = x.shape
  if hidden_size:
    if resnet_emb:
      # channel = patch_size * patch_size * 3
      patch_size = int(np.sqrt(channel // 3))
      x = jnp.reshape(x, [-1, patch_size, patch_size, 3])
      x = resnet.StdConv(
          x, RESNET_TOKEN_DIM, (7, 7), (2, 2), bias=False, name="conv_root")
      x = nn.GroupNorm(x, name="gn_root")
      x = nn.relu(x)
      x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME")

      if resnet_emb.num_layers > 0:
        blocks, bottleneck = resnet.get_block_desc(resnet_emb.num_layers)
        if blocks:
          x = resnet.ResNetStage(
              x,
              blocks[0],
              RESNET_TOKEN_DIM,
              first_stride=(1, 1),
              bottleneck=bottleneck,
              name="block1")
          for i, block_size in enumerate(blocks[1:], 1):
            x = resnet.ResNetStage(
                x,
                block_size,
                RESNET_TOKEN_DIM * 2**i,
                first_stride=(2, 2),
                bottleneck=bottleneck,
                name=f"block{i + 1}")
      x = jnp.reshape(x, [n, l, -1])

    x = nn.Dense(x, hidden_size, name="embedding")

  # Here, x is a list of embeddings.
  x = utils.Encoder(
      x,
      inputs_spatial_positions,
      inputs_scale_positions,
      inputs_masks,
      train=train,
      name="Transformer",
      **transformer)

  x = x[:, 0]

  if representation_size:
    x = nn.Dense(x, representation_size, name="pre_logits")
    x = nn.tanh(x)
  else:
    x = resnet.IdentityLayer(x, name="pre_logits")

  x = nn.Dense(x, num_classes, name="head",
               kernel_init=nn.initializers.zeros)

  if multi_crops_input:
    _, channel = x.shape
    x = jnp.reshape(x, [batch_size, num_crops, channel])
  return x