def apply(self, x, *, stride, filters, train):
  norm_layer = nn.BatchNorm.partial(
      use_running_average=not train, momentum=0.9, epsilon=1e-5)
  conv3x3 = nn.Conv.partial(kernel_size=(3, 3), padding="SAME", bias=False)
  conv1x1 = nn.Conv.partial(kernel_size=(1, 1), padding="SAME", bias=False)

  x = norm_layer(x)
  x = nn.relu(x)
  identity = x

  needs_projection = x.shape[-1] != filters or stride != (1, 1)
  if needs_projection:
    identity = conv1x1(x, features=filters, strides=stride)

  x = conv3x3(x, features=filters, strides=stride)
  x = norm_layer(x)
  x = nn.relu(x)
  x = conv3x3(x, features=filters, strides=(1, 1))

  x += identity
  return x
def apply(self,
          x,
          filters,
          strides=(1, 1),
          dropout_rate=0.0,
          epsilon=1e-5,
          momentum=0.9,
          norm_layer='batch_norm',
          train=True,
          dtype=jnp.float32):
  # TODO(samirabnar): Make 4 a parameter.
  needs_projection = x.shape[-1] != filters * 4 or strides != (1, 1)

  norm_layer_name = ''
  if norm_layer == 'batch_norm':
    norm_layer = nn.BatchNorm.partial(
        use_running_average=not train,
        momentum=momentum,
        epsilon=epsilon,
        dtype=dtype)
    norm_layer_name = 'bn'
  elif norm_layer == 'group_norm':
    norm_layer = nn.GroupNorm.partial(num_groups=16, dtype=dtype)
    norm_layer_name = 'gn'

  conv = nn.Conv.partial(bias=False, dtype=dtype)

  residual = x
  if needs_projection:
    residual = conv(residual, filters * 4, (1, 1), strides, name='proj_conv')
    residual = norm_layer(residual, name=f'proj_{norm_layer_name}')

  y = conv(x, filters, (1, 1), name='conv1')
  y = norm_layer(y, name=f'{norm_layer_name}1')
  y = nn.relu(y)
  y = conv(y, filters, (3, 3), strides, name='conv2')
  y = norm_layer(y, name=f'{norm_layer_name}2')
  y = nn.relu(y)
  if dropout_rate > 0.0:
    y = nn.dropout(y, dropout_rate, deterministic=not train)
  y = conv(y, filters * 4, (1, 1), name='conv3')
  y = norm_layer(y, name=f'{norm_layer_name}3',
                 scale_init=nn.initializers.zeros)
  y = nn.relu(residual + y)
  return y
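def _bottleneck_usage_sketch():
  """Hedged usage sketch for the bottleneck block above; not from the source.

  Assumes the pre-Linen `flax.nn` API used throughout this file, and that the
  `apply` above belongs to a module class named `BottleneckBlock` (the class
  name is a guess; only the method body appears here).
  """
  import jax
  import jax.numpy as jnp
  from flax import nn

  x = jnp.ones((2, 56, 56, 64))
  # `nn.stateful` collects the BatchNorm running averages created at init.
  with nn.stateful() as init_state:
    y, params = BottleneckBlock.init(
        jax.random.PRNGKey(0), x, filters=64, strides=(2, 2))
  # The stride halves the spatial dims and the output widens to filters * 4
  # channels, so y.shape == (2, 28, 28, 256); the shape mismatch with x is
  # also what triggers the 1x1 projection shortcut ('proj_conv').
  del init_state, params
  return y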
def apply(self, x, inner_channels=8):
  x = nn.Conv(x, features=inner_channels, kernel_size=(3, 3), bias=False,
              padding='SAME')
  x = nn.relu(x)
  # x = nn.BatchNorm(x)
  x = nn.Conv(x, features=inner_channels, kernel_size=(3, 3), bias=False,
              padding='SAME')
  x = nn.relu(x)
  # x = nn.BatchNorm(x)
  x = nn.Conv(x, features=inner_channels, kernel_size=(3, 3), bias=False,
              padding='SAME')
  x = nn.relu(x)
  # x = nn.BatchNorm(x)
  return x
def apply(self, x, use_squeeze_excite=False):
  x = nn.Conv(x, features=8, kernel_size=(3, 3), padding="VALID")
  x = nn.relu(x)
  x = nn.Conv(x, features=16, kernel_size=(3, 3), padding="VALID")
  x = nn.relu(x)
  if use_squeeze_excite:
    x = SqueezeExciteLayer(x)
  x = nn.Conv(x, features=32, kernel_size=(3, 3), padding="VALID")
  x = nn.relu(x)
  if use_squeeze_excite:
    x = SqueezeExciteLayer(x)
  x = nn.Conv(x, features=1, kernel_size=(3, 3), padding="VALID")
  scores = nn.max_pool(x, window_shape=(8, 8), strides=(8, 8))[Ellipsis, 0]
  return scores
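# Shape trace for the scorer above (illustrative numbers, not from the
# source): each of the four VALID 3x3 convs trims 2 pixels from h and w, and
# the final 8x8/8 max-pool reduces each non-overlapping 8x8 window to one
# score. For example, an input of shape (b, 72, 72, c) shrinks
# 72 -> 70 -> 68 -> 66 -> 64 spatially, and the pooled `scores` come out as
# (b, 8, 8); the trailing `[Ellipsis, 0]` drops the singleton channel axis.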
def apply(self, x):
  x = nn.Conv(x, features=32, kernel_size=(3, 3), name="conv")
  x = nn.relu(x)
  x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
  x = x.reshape((x.shape[0], -1))  # Flatten.
  x = nn.Dense(x, 128, name="fc")
  return x
def apply(self, x, reduction=16):
  num_channels = x.shape[-1]
  y = x.mean(axis=(1, 2))
  y = nn.Dense(y, features=num_channels // reduction, bias=False)
  y = nn.relu(y)
  y = nn.Dense(y, features=num_channels, bias=False)
  y = nn.sigmoid(y)
  return x * y[:, None, None, :]
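def _squeeze_excite_sketch():
  """Pure-jnp restatement of the squeeze-excite gate above; illustration only.

  The random w1/w2 matrices are stand-ins for the two `nn.Dense` layers; this
  just makes the squeeze -> bottleneck -> sigmoid-gate -> excite dataflow
  explicit, outside of any flax module.
  """
  import jax
  import jax.numpy as jnp

  x = jnp.ones((2, 8, 8, 32))
  c, r = x.shape[-1], 16
  k1, k2 = jax.random.split(jax.random.PRNGKey(0))
  w1 = jax.random.normal(k1, (c, c // r))
  w2 = jax.random.normal(k2, (c // r, c))

  y = x.mean(axis=(1, 2))          # Squeeze: global average pool -> (b, c).
  y = jax.nn.relu(y @ w1)          # Bottleneck to c // r features.
  y = jax.nn.sigmoid(y @ w2)       # Per-channel gates in (0, 1) -> (b, c).
  return x * y[:, None, None, :]   # Excite: broadcast gates over h and w.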
def apply(self, x, inner_channels=8):
  x = NonLinearCycle(x, 4, inner_channels)
  x = nn.Conv(x, features=1, kernel_size=(3, 3), bias=False, padding='SAME')
  x = nn.relu(x)
  return x
def apply(self, x, strides=(1, 2, 2, 2), filters=(32, 32, 32, 32), train=True):
  """This is an adaptation of a ResNetv2 used in ATS (see links above).

  Note that the size of each block is fixed to 1 and the first block is only
  a convolution.

  Args:
    x: Input tensor of shape (b, h, w, c).
    strides: Strides of the blocks.
    filters: Number of filters of each block.
    train: Whether the module is being trained.

  Returns:
    The globally average-pooled vector representation of each image.
  """
  norm_layer = nn.BatchNorm.partial(
      use_running_average=not train, momentum=0.9, epsilon=1e-5)
  conv3x3 = nn.Conv.partial(kernel_size=(3, 3), padding="SAME", bias=False)

  # Make each stride a pair of integers instead of a single int.
  strides = [(s, s) if isinstance(s, int) else s for s in strides]

  x = conv3x3(x, features=filters[0], strides=strides[0])
  for s, f in zip(strides[1:], filters[1:]):
    x = BasicBlockv2(x, stride=s, filters=f, train=train)

  x = norm_layer(x)
  x = nn.relu(x)

  # Global average pooling.
  x = x.mean(axis=(1, 2))
  return x
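# Hedged shape note, derived from the code above rather than stated in the
# source: with the default strides (1, 2, 2, 2) and "SAME" padding, an input
# of shape (b, h, w, c) is downsampled 8x spatially before the global average
# pool, so the returned representation has shape (b, filters[-1]) = (b, 32).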
def apply(self, x, config, num_classes, train=True):
  """Creates a model definition."""
  if config.get("append_position_to_input", False):
    b, h, w, _ = x.shape
    coords = utils.create_grid([h, w], value_range=(0., 1.))
    x = jnp.concatenate(
        [x, coords[jnp.newaxis, Ellipsis].repeat(b, axis=0)], axis=-1)

  if config.model.lower() == "cnn":
    h = models.SimpleCNNImageClassifier(x)
    h = nn.relu(h)
    stats = None
  elif config.model.lower() == "resnet":
    small_inputs = config.get("resnet.small_inputs", False)
    blocks = config.get("resnet.blocks", [3, 4, 6, 3])
    h = models.ResNet(
        x, train=train, block_sizes=blocks, small_inputs=small_inputs)
    h = jnp.mean(h, axis=[1, 2])  # Global average pool.
    stats = None
  elif config.model.lower() == "resnet18":
    h = models.ResNet18(x, train=train)
    h = jnp.mean(h, axis=[1, 2])  # Global average pool.
    stats = None
  elif config.model.lower() == "resnet50":
    h = models.ResNet50(x, train=train)
    h = jnp.mean(h, axis=[1, 2])  # Global average pool.
    stats = None
  elif config.model.lower() == "ats-traffic":
    h = models.ATSFeatureNetwork(x, train=train)
    stats = None
  elif config.model.lower() == "patchnet":
    feature_network = {
        "resnet18": models.ResNet18,
        "resnet18-fourth": models.ResNet.partial(
            num_filters=16,
            block_sizes=(2, 2, 2, 2),
            block=models.BasicBlock),
        "resnet50": models.ResNet50,
        "ats-traffic": models.ATSFeatureNetwork,
    }[config.feature_network.lower()]

    selection_method = sample_patches.SelectionMethod(config.selection_method)
    selection_method_kwargs = {}
    if selection_method is sample_patches.SelectionMethod.SINKHORN_TOPK:
      selection_method_kwargs = config.sinkhorn_topk_kwargs
    if selection_method is sample_patches.SelectionMethod.PERTURBED_TOPK:
      selection_method_kwargs = config.perturbed_topk_kwargs

    h, stats = sample_patches.PatchNet(
        x,
        patch_size=config.patch_size,
        k=config.k,
        downscale=config.downscale,
        scorer_has_se=config.get("scorer_has_se", False),
        selection_method=config.selection_method,
        selection_method_kwargs=selection_method_kwargs,
        selection_method_inference=config.get("selection_method_inference",
                                              None),
        normalization_str=config.normalization_str,
        aggregation_method=config.aggregation_method,
        aggregation_method_kwargs=config.get("aggregation_method_kwargs", {}),
        append_position_to_input=config.get("append_position_to_input",
                                            False),
        feature_network=feature_network,
        use_iterative_extraction=config.use_iterative_extraction,
        hard_topk_probability=config.get("hard_topk_probability", 0.),
        random_patch_probability=config.get("random_patch_probability", 0.),
        train=train)
    stats["x"] = x
  else:
    raise RuntimeError("Unknown classification model type: %s" %
                       config.model.lower())

  out = nn.Dense(h, num_classes, name="final")
  return out, stats
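def _classifier_config_sketch():
  """Hedged sketch of a minimal config for the dispatcher above.

  Assumes the config behaves like an `ml_collections.ConfigDict` (the code
  only requires attribute access plus `.get`); the module class name
  `ClassificationModule` in the comment is a placeholder, not from the source.
  """
  import ml_collections

  config = ml_collections.ConfigDict()
  config.model = "resnet18"  # Any of the branches handled above.
  config.append_position_to_input = False
  # A forward pass would then look like:
  #   out, stats = ClassificationModule(
  #       x, config=config, num_classes=10, train=False)
  return config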
def apply(self, x):
  x = nn.Dense(x, features=256)
  x = nn.relu(x)
  x = nn.Dense(x, features=256)
  x = nn.relu(x)
  return nn.Dense(x, features=2)
def apply(self, x, communication=Communication.NONE, train=True):
  """Forward pass."""
  batch_size = x.shape[0]

  if communication is Communication.SQUEEZE_EXCITE_X:
    x = sample_patches.SqueezeExciteLayer(x)

  d1 = nn.relu(nn.Conv(
      x, 128, kernel_size=(3, 3), strides=(1, 1), bias=True, name="down1"))
  d2 = nn.relu(nn.Conv(
      d1, 128, kernel_size=(3, 3), strides=(2, 2), bias=True, name="down2"))
  d3 = nn.relu(nn.Conv(
      d2, 128, kernel_size=(3, 3), strides=(2, 2), bias=True, name="down3"))

  if communication is Communication.SQUEEZE_EXCITE_D:
    d1_flatten = einops.rearrange(d1, "b h w c -> b (h w) c")
    d2_flatten = einops.rearrange(d2, "b h w c -> b (h w) c")
    d3_flatten = einops.rearrange(d3, "b h w c -> b (h w) c")

    nd1 = d1_flatten.shape[1]
    nd2 = d2_flatten.shape[1]

    d_together = jnp.concatenate([d1_flatten, d2_flatten, d3_flatten], axis=1)

    num_channels = d_together.shape[-1]
    y = d_together.mean(axis=1)
    y = nn.Dense(y, features=num_channels // 4, bias=False)
    y = nn.relu(y)
    y = nn.Dense(y, features=num_channels, bias=False)
    y = nn.sigmoid(y)

    d_together = d_together * y[:, None, :]

    # Split and reshape.
    d1 = d_together[:, :nd1].reshape(d1.shape)
    d2 = d_together[:, nd1:nd1 + nd2].reshape(d2.shape)
    d3 = d_together[:, nd1 + nd2:].reshape(d3.shape)

  elif communication is Communication.TRANSFORMER:
    d1_flatten = einops.rearrange(d1, "b h w c -> b (h w) c")
    d2_flatten = einops.rearrange(d2, "b h w c -> b (h w) c")
    d3_flatten = einops.rearrange(d3, "b h w c -> b (h w) c")

    nd1 = d1_flatten.shape[1]
    nd2 = d2_flatten.shape[1]

    d_together = jnp.concatenate([d1_flatten, d2_flatten, d3_flatten], axis=1)

    positional_encodings = self.param(
        "scale_ratio_position_encodings",
        shape=(1,) + d_together.shape[1:],
        initializer=jax.nn.initializers.normal(1. / d_together.shape[-1]))
    d_together = transformer.Transformer(
        d_together + positional_encodings,
        num_layers=2,
        num_heads=8,
        is_training=train)

    # Split and reshape.
    d1 = d_together[:, :nd1].reshape(d1.shape)
    d2 = d_together[:, nd1:nd1 + nd2].reshape(d2.shape)
    d3 = d_together[:, nd1 + nd2:].reshape(d3.shape)

  t1 = nn.Conv(d1, 6, kernel_size=(1, 1), strides=(1, 1), bias=True,
               name="tidy1")
  t2 = nn.Conv(d2, 6, kernel_size=(1, 1), strides=(1, 1), bias=True,
               name="tidy2")
  t3 = nn.Conv(d3, 9, kernel_size=(1, 1), strides=(1, 1), bias=True,
               name="tidy3")

  raw_scores = (jnp.split(t1, 6, axis=-1) + jnp.split(t2, 6, axis=-1) +
                jnp.split(t3, 9, axis=-1))

  # The following is for normalization.
  t = jnp.concatenate(
      (jnp.reshape(t1, [batch_size, -1]), jnp.reshape(t2, [batch_size, -1]),
       jnp.reshape(t3, [batch_size, -1])),
      axis=1)
  t_min = jnp.reshape(jnp.min(t, axis=-1), [batch_size, 1, 1, 1])
  t_max = jnp.reshape(jnp.max(t, axis=-1), [batch_size, 1, 1, 1])
  normalized_scores = zeroone(raw_scores, t_min, t_max)

  stats = {
      "scores": normalized_scores,
      "raw_scores": t,
  }
  # Remove the split dimension; scores are now (b, h', w') shaped.
  normalized_scores = [s.squeeze(-1) for s in normalized_scores]

  return normalized_scores, stats
def apply(
    self,
    inputs,
    blocks_per_group,
    channel_multiplier,
    num_outputs,
    kernel_size=(3, 3),
    strides=None,
    maxpool=False,
    dropout_rate=0.0,
    dtype=jnp.float32,
    norm_layer='group_norm',
    train=True,
    return_activations=False,
    input_layer_key='input',
    has_discriminator=False,
    discriminator=False,
):
  norm_layer_name = ''
  if norm_layer == 'batch_norm':
    norm_layer = nn.BatchNorm.partial(use_running_average=not train)
    norm_layer_name = 'bn'
  elif norm_layer == 'group_norm':
    norm_layer = nn.GroupNorm.partial(num_groups=16)
    norm_layer_name = 'gn'

  layer_activations = collections.OrderedDict()
  input_is_set = False

  current_rep_key = 'input'
  if input_layer_key == current_rep_key:
    x = inputs
    input_is_set = True
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key

  current_rep_key = 'init_conv'
  if input_is_set:
    x = nn.Conv(
        x,
        16,
        kernel_size=kernel_size,
        strides=strides,
        padding='SAME',
        name='init_conv')
    if maxpool:
      x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key
  elif input_layer_key == current_rep_key:
    x = inputs
    input_is_set = True
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key

  current_rep_key = 'l1'
  if input_is_set:
    x = WideResnetGroup(
        x,
        blocks_per_group,
        16 * channel_multiplier,
        dropout_rate=dropout_rate,
        norm_layer=norm_layer,
        train=train,
        name='l1')
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key
  elif input_layer_key == current_rep_key:
    x = inputs
    input_is_set = True
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key

  current_rep_key = 'l2'
  if input_is_set:
    x = WideResnetGroup(
        x,
        blocks_per_group,
        32 * channel_multiplier, (2, 2),
        dropout_rate=dropout_rate,
        norm_layer=norm_layer,
        train=train,
        name='l2')
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key
  elif input_layer_key == current_rep_key:
    x = inputs
    input_is_set = True
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key

  current_rep_key = 'l3'
  if input_is_set:
    x = WideResnetGroup(
        x,
        blocks_per_group,
        64 * channel_multiplier, (2, 2),
        dropout_rate=dropout_rate,
        norm_layer=norm_layer,
        train=train,
        name='l3')
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key
  elif input_layer_key == current_rep_key:
    x = inputs
    input_is_set = True
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key

  current_rep_key = 'l4'
  if input_is_set:
    x = norm_layer(x, name=norm_layer_name)
    x = jax.nn.relu(x)
    x = nn.avg_pool(x, (8, 8))
    x = x.reshape((x.shape[0], -1))
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key
  elif input_layer_key == current_rep_key:
    x = inputs
    input_is_set = True
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key

  # DANN module.
  if has_discriminator:
    z = dann_utils.flip_grad_identity(x)
    z = nn.Dense(z, 2, name='disc_l1', bias=True)
    z = nn.relu(z)
    z = nn.Dense(z, 2, name='disc_l2', bias=True)

  current_rep_key = 'head'
  if input_is_set:
    x = nn.Dense(x, num_outputs, dtype=dtype, name='head')
  else:
    x = inputs
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key
    logging.warn('Input was never used')

  outputs = x
  if return_activations:
    outputs = (x, layer_activations, rep_key)
    if discriminator and has_discriminator:
      outputs = outputs + (z,)
  else:
    del layer_activations
    if discriminator and has_discriminator:
      outputs = (x, z)
  if discriminator and (not has_discriminator):
    raise ValueError(
        'Inconsistent values passed for discriminator and has_discriminator')
  return outputs
def apply(self, x):
  x = nn.Dense(x, hidden_reps_dim, bias=True, name='l1')
  x = nn.relu(x)
  x = nn.Dense(x, hidden_reps_dim, bias=True, name='l2')
  return x
def apply(self,
          inputs,
          num_outputs,
          num_filters=64,
          num_layers=50,
          dropout_rate=0.0,
          input_dropout_rate=0.0,
          train=True,
          dtype=jnp.float32,
          head_bias_init=jnp.zeros,
          return_activations=False,
          input_layer_key='input',
          has_discriminator=False,
          discriminator=False):
  """Apply a ResNet network to the input.

  Args:
    inputs: jnp array; Inputs.
    num_outputs: int; Number of output units.
    num_filters: int; Determines the base number of filters. The number of
      filters in block i is num_filters * 2 ** i.
    num_layers: int; Number of layers (should be one of the predefined ones).
    dropout_rate: float; Rate for dropping out the outputs of the different
      hidden layers.
    input_dropout_rate: float; Rate for dropping out input units.
    train: bool; Whether the module is in training mode.
    dtype: jnp type; Type of the outputs.
    head_bias_init: fn(rng_key, shape) --> jnp array; Initializer for the
      head bias parameters.
    return_activations: bool; If True, hidden activations are also returned.
    input_layer_key: str; Determines where to plug in the input (this enables
      providing inputs to slices of the model). If `input_layer_key` is
      `layer_i`, the inputs are assumed to be the activations of `layer_i`
      and are passed on to `layer_{i+1}`.
    has_discriminator: bool; Whether the model should have a discriminator
      layer.
    discriminator: bool; Whether discriminator logits should be returned.

  Returns:
    Unnormalized logits of shape `[bs, num_outputs]`; if return_activations,
    also a dict of hidden activations and the key of the representation(s)
    to be used as "the representation", e.g. for computing losses.
  """
  if num_layers not in ResNet._block_size_options:
    raise ValueError('Please provide a valid number of layers')
  block_sizes = ResNet._block_size_options[num_layers]

  layer_activations = collections.OrderedDict()
  input_is_set = False

  current_rep_key = 'input'
  if input_layer_key == current_rep_key:
    x = inputs
    input_is_set = True

  if input_is_set:
    # Input dropout.
    x = nn.dropout(x, input_dropout_rate, deterministic=not train)
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key

  current_rep_key = 'init_conv'
  if input_layer_key == current_rep_key:
    x = inputs
    input_is_set = True
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key
  elif input_is_set:
    # First block.
    x = nn.Conv(
        x,
        num_filters, (7, 7), (2, 2),
        padding=[(3, 3), (3, 3)],
        bias=False,
        dtype=dtype,
        name='init_conv')
    x = nn.BatchNorm(
        x,
        use_running_average=not train,
        momentum=0.9,
        epsilon=1e-5,
        dtype=dtype,
        name='init_bn')
    x = nn.relu(x)
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key

  # Residual blocks.
  for i, block_size in enumerate(block_sizes):
    # Stage i (each stage contains blocks of the same size).
    for j in range(block_size):
      strides = (2, 2) if i > 0 and j == 0 else (1, 1)
      current_rep_key = f'block_{i + 1}+{j}'
      if input_layer_key == current_rep_key:
        x = inputs
        input_is_set = True
        layer_activations[current_rep_key] = x
        rep_key = current_rep_key
      elif input_is_set:
        x = ResidualBlock(
            x,
            num_filters * 2**i,
            strides=strides,
            dropout_rate=dropout_rate,
            train=train,
            dtype=dtype,
            name=f'block_{i + 1}_{j}')
        layer_activations[current_rep_key] = x
        rep_key = current_rep_key

  current_rep_key = 'avg_pool'
  if input_layer_key == current_rep_key:
    x = inputs
    input_is_set = True
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key
  elif input_is_set:
    # Global average pooling.
    x = jnp.mean(x, axis=(1, 2))
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key

  # DANN module.
  if has_discriminator:
    z = dann_utils.flip_grad_identity(x)
    z = nn.Dense(z, 2, name='disc_l1', bias=True)
    z = nn.relu(z)
    z = nn.Dense(z, 2, name='disc_l2', bias=True)

  current_rep_key = 'head'
  if input_layer_key == current_rep_key:
    x = inputs
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key
    logging.warn('Input was never used')
  elif input_is_set:
    x = nn.Dense(
        x, num_outputs, dtype=dtype, bias_init=head_bias_init, name='head')

  # Make sure that the output is float32, even if our previous computations
  # are in float16 or other types.
  x = jnp.asarray(x, jnp.float32)

  outputs = x
  if return_activations:
    outputs = (x, layer_activations, rep_key)
    if discriminator and has_discriminator:
      outputs = outputs + (z,)
  else:
    del layer_activations
    if discriminator and has_discriminator:
      outputs = (x, z)
  if discriminator and (not has_discriminator):
    raise ValueError(
        'Inconsistent values passed for discriminator and has_discriminator')
  return outputs
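# Hedged usage note for `input_layer_key` above (not from the source): the
# key selects where `inputs` enters the network, so the model can be run as a
# slice. With the default 'input', the full network runs; with
# input_layer_key='avg_pool', the inputs are treated as already-pooled
# features and only the head (plus the optional discriminator) is applied.
# Individual residual blocks are addressed with keys of the form
# f'block_{i + 1}+{j}', matching the `current_rep_key` pattern in the loop.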