Example #1
    def apply(self, x, *, stride, filters, train):
        norm_layer = nn.BatchNorm.partial(use_running_average=not train,
                                          momentum=0.9,
                                          epsilon=1e-5)
        conv3x3 = nn.Conv.partial(kernel_size=(3, 3),
                                  padding="SAME",
                                  bias=False)
        conv1x1 = nn.Conv.partial(kernel_size=(1, 1),
                                  padding="SAME",
                                  bias=False)

        x = norm_layer(x)
        x = nn.relu(x)
        identity = x
        needs_projection = x.shape[-1] != filters or stride != (1, 1)
        if needs_projection:
            identity = conv1x1(x, features=filters, strides=stride)

        x = conv3x3(x, features=filters, strides=stride)
        x = norm_layer(x)
        x = nn.relu(x)
        x = conv3x3(x, features=filters, strides=(1, 1))

        x += identity
        return x
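
This pre-activation ("v2") basic block normalizes and activates before each convolution, and projects the identity with a 1x1 convolution whenever the channel count or the stride changes. A minimal sketch of the projection condition under hypothetical shapes (illustration only, not part of the module above):

    import jax.numpy as jnp

    # Hypothetical input: batch of 2, 32x32 feature maps with 16 channels.
    x = jnp.zeros((2, 32, 32, 16))
    filters, stride = 32, (2, 2)

    # Project the identity whenever channels or stride change, so that
    # `x += identity` at the end of the block has matching shapes.
    needs_projection = x.shape[-1] != filters or stride != (1, 1)
    print(needs_projection)  # True: 16 -> 32 channels and stride 2.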
Example #2
    def apply(self,
              x,
              filters,
              strides=(1, 1),
              dropout_rate=0.0,
              epsilon=1e-5,
              momentum=0.9,
              norm_layer='batch_norm',
              train=True,
              dtype=jnp.float32):

        # TODO(samirabnar): Make 4 a parameter.
        needs_projection = x.shape[-1] != filters * 4 or strides != (1, 1)
        norm_layer_name = ''
        if norm_layer == 'batch_norm':
            norm_layer = nn.BatchNorm.partial(use_running_average=not train,
                                              momentum=momentum,
                                              epsilon=epsilon,
                                              dtype=dtype)
            norm_layer_name = 'bn'
        elif norm_layer == 'group_norm':
            norm_layer = nn.GroupNorm.partial(num_groups=16, dtype=dtype)
            norm_layer_name = 'gn'

        conv = nn.Conv.partial(bias=False, dtype=dtype)

        residual = x
        if needs_projection:
            residual = conv(residual,
                            filters * 4, (1, 1),
                            strides,
                            name='proj_conv')
            residual = norm_layer(residual, name=f'proj_{norm_layer_name}')

        y = conv(x, filters, (1, 1), name='conv1')
        y = norm_layer(y, name=f'{norm_layer_name}1')
        y = nn.relu(y)

        y = conv(y, filters, (3, 3), strides, name='conv2')
        y = norm_layer(y, name=f'{norm_layer_name}2')
        y = nn.relu(y)

        if dropout_rate > 0.0:
            y = nn.dropout(y, dropout_rate, deterministic=not train)
        y = conv(y, filters * 4, (1, 1), name='conv3')
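        # scale_init=zeros below starts the final norm's scale at 0, so the
        # residual branch initially contributes nothing and the block begins
        # as (roughly) an identity mapping of `residual`.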
        y = norm_layer(y,
                       name=f'{norm_layer_name}3',
                       scale_init=nn.initializers.zeros)
        y = nn.relu(residual + y)

        return y
Example #3
 def apply(self, x, inner_channels=8):
   x = nn.Conv(x, features=inner_channels, kernel_size=(3, 3),
               bias=False, padding='SAME')
   x = nn.relu(x)
   # x = nn.BatchNorm(x)
   x = nn.Conv(x, features=inner_channels, kernel_size=(3, 3),
               bias=False, padding='SAME')
   x = nn.relu(x)
   # x = nn.BatchNorm(x)
   x = nn.Conv(x, features=inner_channels, kernel_size=(3, 3),
               bias=False, padding='SAME')
   x = nn.relu(x)
   # x = nn.BatchNorm(x)
   return x
Example #4
 def apply(self, x, use_squeeze_excite=False):
   x = nn.Conv(x, features=8, kernel_size=(3, 3), padding="VALID")
   x = nn.relu(x)
   x = nn.Conv(x, features=16, kernel_size=(3, 3), padding="VALID")
   x = nn.relu(x)
   if use_squeeze_excite:
     x = SqueezeExciteLayer(x)
   x = nn.Conv(x, features=32, kernel_size=(3, 3), padding="VALID")
   x = nn.relu(x)
   if use_squeeze_excite:
     x = SqueezeExciteLayer(x)
   x = nn.Conv(x, features=1, kernel_size=(3, 3), padding="VALID")
   scores = nn.max_pool(x, window_shape=(8, 8), strides=(8, 8))[Ellipsis, 0]
   return scores
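
The last convolution produces a single-channel score map, and the 8x8 max-pool with stride 8 keeps one score per 8x8 cell; indexing with [Ellipsis, 0] then drops the channel axis. A quick shape trace under an assumed 72x72 input (the real input size is not shown here):

    h = w = 72                 # assumed input resolution, purely for illustration
    for _ in range(4):         # four VALID 3x3 convolutions
        h, w = h - 2, w - 2    # each shaves one pixel per side
    h, w = h // 8, w // 8      # 8x8 max-pool with stride 8
    print(h, w)                # 8 8 -> scores has shape (batch, 8, 8)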
Example #5
 def apply(self, x):
     x = nn.Conv(x, features=32, kernel_size=(3, 3), name="conv")
     x = nn.relu(x)
     x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
     x = x.reshape((x.shape[0], -1))  # flatten.
     x = nn.Dense(x, 128, name="fc")
     return x
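
These snippets use the pre-Linen flax.nn Module API (apply methods plus Module.partial), so instantiation goes through init_by_shape and nn.Model. A hedged usage sketch, assuming the apply above belongs to a module named SimpleCNN and a 28x28 single-channel input (both are assumptions):

    import jax
    import jax.numpy as jnp
    from flax.deprecated import nn  # pre-Linen API; older Flax versions expose it as flax.nn

    rng = jax.random.PRNGKey(0)
    # Shapes: 3x3 conv (default SAME padding) keeps 28x28, avg_pool 2x2/2 gives 14x14,
    # flatten to 14*14*32 = 6272 features, then Dense -> 128.
    _, params = SimpleCNN.init_by_shape(rng, [((1, 28, 28, 1), jnp.float32)])
    model = nn.Model(SimpleCNN, params)
    features = model(jnp.ones((8, 28, 28, 1)))  # -> shape (8, 128)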
Example #6
 def apply(self, x, reduction=16):
     num_channels = x.shape[-1]
     y = x.mean(axis=(1, 2))
     y = nn.Dense(y, features=num_channels // reduction, bias=False)
     y = nn.relu(y)
     y = nn.Dense(y, features=num_channels, bias=False)
     y = nn.sigmoid(y)
     return x * y[:, None, None, :]
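
This is a standard squeeze-and-excitation layer: global-average-pool each channel, pass the pooled vector through a bottleneck MLP, and use the resulting sigmoid gates to rescale the channels of x. The same computation written out with explicit weight matrices (hypothetical shapes; the module above learns w1/w2 as Dense kernels):

    import jax
    import jax.numpy as jnp

    def squeeze_excite(x, w1, w2):
        y = x.mean(axis=(1, 2))          # squeeze: global average pool -> (b, c)
        y = jax.nn.relu(y @ w1)          # bottleneck: c -> c // reduction
        y = jax.nn.sigmoid(y @ w2)       # excite: back to c, gates in (0, 1)
        return x * y[:, None, None, :]   # rescale each channel of x

    b, h, w, c, r = 2, 8, 8, 32, 16
    k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
    out = squeeze_excite(jax.random.normal(k1, (b, h, w, c)),
                         jax.random.normal(k2, (c, c // r)),
                         jax.random.normal(k3, (c // r, c)))
    print(out.shape)  # (2, 8, 8, 32)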
Example #7
 def apply(self, x, inner_channels=8):
     x = NonLinearCycle(x, 4, inner_channels)
     x = nn.Conv(x,
                 features=1,
                 kernel_size=(3, 3),
                 bias=False,
                 padding='SAME')
     x = nn.relu(x)
     return x
Example #8
    def apply(self,
              x,
              strides=(1, 2, 2, 2),
              filters=(32, 32, 32, 32),
              train=True):
        """This is an adaptation of a ResNetv2 used in ATS (see links above).

    Note that the size of each block is fixed to 1 and the first block is only
    a convolution.

    Args:
      x: Input tensor of shape (b, h, w, c).
      strides: Strides of the blocks.
      filters: Number of filters of each block.
      train: Whether the module is being trained.

    Returns:
      The globally average-pooled vector representation of each image.
    """
        norm_layer = nn.BatchNorm.partial(use_running_average=not train,
                                          momentum=0.9,
                                          epsilon=1e-5)
        conv3x3 = nn.Conv.partial(kernel_size=(3, 3),
                                  padding="SAME",
                                  bias=False)

        # Make each stride a pair of integers instead of a single int.
        strides = [(s, s) if isinstance(s, int) else s for s in strides]

        x = conv3x3(x, features=filters[0], strides=strides[0])

        for s, f in zip(strides[1:], filters[1:]):
            x = BasicBlockv2(x, stride=s, filters=f, train=train)

        x = norm_layer(x)
        x = nn.relu(x)

        # Global average pooling.
        x = x.mean(axis=(1, 2))

        return x
Example #9
    def apply(self, x, config, num_classes, train=True):
        """Creates a model definition."""

        if config.get("append_position_to_input", False):
            b, h, w, _ = x.shape
            coords = utils.create_grid([h, w], value_range=(0., 1.))
            x = jnp.concatenate(
                [x, coords[jnp.newaxis, Ellipsis].repeat(b, axis=0)], axis=-1)

        if config.model.lower() == "cnn":
            h = models.SimpleCNNImageClassifier(x)
            h = nn.relu(h)
            stats = None
        elif config.model.lower() == "resnet":
            smallinputs = config.get("resnet.small_inputs", False)
            blocks = config.get("resnet.blocks", [3, 4, 6, 3])
            h = models.ResNet(x,
                              train=train,
                              block_sizes=blocks,
                              small_inputs=smallinputs)
            h = jnp.mean(h, axis=[1, 2])  # global average pool
            stats = None
        elif config.model.lower() == "resnet18":
            h = models.ResNet18(x, train=train)
            h = jnp.mean(h, axis=[1, 2])  # global average pool
            stats = None
        elif config.model.lower() == "resnet50":
            h = models.ResNet50(x, train=train)
            h = jnp.mean(h, axis=[1, 2])  # global average pool
            stats = None
        elif config.model.lower() == "ats-traffic":
            h = models.ATSFeatureNetwork(x, train=train)
            stats = None
        elif config.model.lower() == "patchnet":
            feature_network = {
                "resnet18": models.ResNet18,
                "resnet18-fourth": models.ResNet.partial(
                    num_filters=16,
                    block_sizes=(2, 2, 2, 2),
                    block=models.BasicBlock),
                "resnet50": models.ResNet50,
                "ats-traffic": models.ATSFeatureNetwork,
            }[config.feature_network.lower()]

            selection_method = sample_patches.SelectionMethod(
                config.selection_method)
            selection_method_kwargs = {}
            if selection_method is sample_patches.SelectionMethod.SINKHORN_TOPK:
                selection_method_kwargs = config.sinkhorn_topk_kwargs
            if selection_method is sample_patches.SelectionMethod.PERTURBED_TOPK:
                selection_method_kwargs = config.perturbed_topk_kwargs

            h, stats = sample_patches.PatchNet(
                x,
                patch_size=config.patch_size,
                k=config.k,
                downscale=config.downscale,
                scorer_has_se=config.get("scorer_has_se", False),
                selection_method=config.selection_method,
                selection_method_kwargs=selection_method_kwargs,
                selection_method_inference=config.get(
                    "selection_method_inference", None),
                normalization_str=config.normalization_str,
                aggregation_method=config.aggregation_method,
                aggregation_method_kwargs=config.get(
                    "aggregation_method_kwargs", {}),
                append_position_to_input=config.get("append_position_to_input",
                                                    False),
                feature_network=feature_network,
                use_iterative_extraction=config.use_iterative_extraction,
                hard_topk_probability=config.get("hard_topk_probability", 0.),
                random_patch_probability=config.get("random_patch_probability",
                                                    0.),
                train=train)
            stats["x"] = x
        else:
            raise RuntimeError("Unknown classification model type: %s" %
                               config.model.lower())
        out = nn.Dense(h, num_classes, name="final")
        return out, stats
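
The function dispatches on config.model and only then reads the model-specific fields, so the smallest valid configuration is the "cnn" branch. A hedged sketch of such a config, assuming ml_collections.ConfigDict (the config.get / attribute access above is consistent with it, but the actual config library is not shown here):

    import ml_collections

    config = ml_collections.ConfigDict()
    config.model = "cnn"
    config.append_position_to_input = False  # optional; config.get falls back to False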
Example #10
 def apply(self, x):
     x = nn.Dense(x, features=256)
     x = nn.relu(x)
     x = nn.Dense(x, features=256)
     x = nn.relu(x)
     return nn.Dense(x, features=2)
Example #11
    def apply(self, x, communication=Communication.NONE, train=True):
        """Forward pass."""
        batch_size = x.shape[0]

        if communication is Communication.SQUEEZE_EXCITE_X:
            x = sample_patches.SqueezeExciteLayer(x)
        # end if squeeze excite x

        d1 = nn.relu(
            nn.Conv(x,
                    128,
                    kernel_size=(3, 3),
                    strides=(1, 1),
                    bias=True,
                    name="down1"))
        d2 = nn.relu(
            nn.Conv(d1,
                    128,
                    kernel_size=(3, 3),
                    strides=(2, 2),
                    bias=True,
                    name="down2"))
        d3 = nn.relu(
            nn.Conv(d2,
                    128,
                    kernel_size=(3, 3),
                    strides=(2, 2),
                    bias=True,
                    name="down3"))

        if communication is Communication.SQUEEZE_EXCITE_D:
            d1_flatten = einops.rearrange(d1, "b h w c -> b (h w) c")
            d2_flatten = einops.rearrange(d2, "b h w c -> b (h w) c")
            d3_flatten = einops.rearrange(d3, "b h w c -> b (h w) c")

            nd1 = d1_flatten.shape[1]
            nd2 = d2_flatten.shape[1]

            d_together = jnp.concatenate([d1_flatten, d2_flatten, d3_flatten],
                                         axis=1)

            num_channels = d_together.shape[-1]
            y = d_together.mean(axis=1)
            y = nn.Dense(y, features=num_channels // 4, bias=False)
            y = nn.relu(y)
            y = nn.Dense(y, features=num_channels, bias=False)
            y = nn.sigmoid(y)

            d_together = d_together * y[:, None, :]

            # split and reshape
            d1 = d_together[:, :nd1].reshape(d1.shape)
            d2 = d_together[:, nd1:nd1 + nd2].reshape(d2.shape)
            d3 = d_together[:, nd1 + nd2:].reshape(d3.shape)

        elif communication is Communication.TRANSFORMER:
            d1_flatten = einops.rearrange(d1, "b h w c -> b (h w) c")
            d2_flatten = einops.rearrange(d2, "b h w c -> b (h w) c")
            d3_flatten = einops.rearrange(d3, "b h w c -> b (h w) c")

            nd1 = d1_flatten.shape[1]
            nd2 = d2_flatten.shape[1]

            d_together = jnp.concatenate([d1_flatten, d2_flatten, d3_flatten],
                                         axis=1)

            positional_encodings = self.param(
                "scale_ratio_position_encodings",
                shape=(1, ) + d_together.shape[1:],
                initializer=jax.nn.initializers.normal(1. /
                                                       d_together.shape[-1]))
            d_together = transformer.Transformer(d_together +
                                                 positional_encodings,
                                                 num_layers=2,
                                                 num_heads=8,
                                                 is_training=train)

            # split and reshape
            d1 = d_together[:, :nd1].reshape(d1.shape)
            d2 = d_together[:, nd1:nd1 + nd2].reshape(d2.shape)
            d3 = d_together[:, nd1 + nd2:].reshape(d3.shape)

        t1 = nn.Conv(d1,
                     6,
                     kernel_size=(1, 1),
                     strides=(1, 1),
                     bias=True,
                     name="tidy1")
        t2 = nn.Conv(d2,
                     6,
                     kernel_size=(1, 1),
                     strides=(1, 1),
                     bias=True,
                     name="tidy2")
        t3 = nn.Conv(d3,
                     9,
                     kernel_size=(1, 1),
                     strides=(1, 1),
                     bias=True,
                     name="tidy3")

        raw_scores = (jnp.split(t1, 6, axis=-1) + jnp.split(t2, 6, axis=-1) +
                      jnp.split(t3, 9, axis=-1))

        # The following is for normalization.
        t = jnp.concatenate((jnp.reshape(
            t1, [batch_size, -1]), jnp.reshape(
                t2, [batch_size, -1]), jnp.reshape(t3, [batch_size, -1])),
                            axis=1)
        t_min = jnp.reshape(jnp.min(t, axis=-1), [batch_size, 1, 1, 1])
        t_max = jnp.reshape(jnp.max(t, axis=-1), [batch_size, 1, 1, 1])
        normalized_scores = zeroone(raw_scores, t_min, t_max)

        stats = {
            "scores": normalized_scores,
            "raw_scores": t,
        }
        # Remove the split dimension; scores are now shaped (b, h', w').
        normalized_scores = [s.squeeze(-1) for s in normalized_scores]

        return normalized_scores, stats
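
zeroone is defined elsewhere in the repo; from the way it is called, it presumably rescales every score map into [0, 1] using the per-image min and max computed across all three heads. A sketch of that assumed behaviour (not the repo's actual helper):

    import jax.numpy as jnp

    def zeroone(score_maps, t_min, t_max):
        # t_min / t_max have shape (b, 1, 1, 1) and broadcast against each map.
        return [(s - t_min) / (t_max - t_min) for s in score_maps]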
Example #12
  def apply(
      self,
      inputs,
      blocks_per_group,
      channel_multiplier,
      num_outputs,
      kernel_size=(3, 3),
      strides=None,
      maxpool=False,
      dropout_rate=0.0,
      dtype=jnp.float32,
      norm_layer='group_norm',
      train=True,
      return_activations=False,
      input_layer_key='input',
      has_discriminator=False,
      discriminator=False,
  ):

    norm_layer_name = ''
    if norm_layer == 'batch_norm':
      norm_layer = nn.BatchNorm.partial(use_running_average=not train)
      norm_layer_name = 'bn'
    elif norm_layer == 'group_norm':
      norm_layer = nn.GroupNorm.partial(num_groups=16)
      norm_layer_name = 'gn'

    layer_activations = collections.OrderedDict()
    input_is_set = False
    current_rep_key = 'input'
    if input_layer_key == current_rep_key:
      x = inputs
      input_is_set = True
      layer_activations[current_rep_key] = x
      rep_key = current_rep_key

    current_rep_key = 'init_conv'
    if input_is_set:
      x = nn.Conv(
          x,
          16,
          kernel_size=kernel_size,
          strides=strides,
          padding='SAME',
          name='init_conv')
      if maxpool:
        x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')

      layer_activations[current_rep_key] = x
      rep_key = current_rep_key
    elif input_layer_key == current_rep_key:
      x = inputs
      input_is_set = True
      layer_activations[current_rep_key] = x
      rep_key = current_rep_key

    current_rep_key = 'l1'
    if input_is_set:
      x = WideResnetGroup(
          x,
          blocks_per_group,
          16 * channel_multiplier,
          dropout_rate=dropout_rate,
          norm_layer=norm_layer,
          train=train,
          name='l1')
      layer_activations[current_rep_key] = x
      rep_key = current_rep_key
    elif input_layer_key == current_rep_key:
      x = inputs
      input_is_set = True
      layer_activations[current_rep_key] = x
      rep_key = current_rep_key

    current_rep_key = 'l2'
    if input_is_set:
      x = WideResnetGroup(
          x,
          blocks_per_group,
          32 * channel_multiplier, (2, 2),
          dropout_rate=dropout_rate,
          norm_layer=norm_layer,
          train=train,
          name='l2')
      layer_activations[current_rep_key] = x
      rep_key = current_rep_key
    elif input_layer_key == current_rep_key:
      x = inputs
      input_is_set = True
      layer_activations[current_rep_key] = x
      rep_key = current_rep_key

    current_rep_key = 'l3'
    if input_is_set:
      x = WideResnetGroup(
          x,
          blocks_per_group,
          64 * channel_multiplier, (2, 2),
          dropout_rate=dropout_rate,
          norm_layer=norm_layer,
          train=train,
          name='l3')
      layer_activations[current_rep_key] = x
      rep_key = current_rep_key
    elif input_layer_key == current_rep_key:
      x = inputs
      input_is_set = True
      layer_activations[current_rep_key] = x
      rep_key = current_rep_key

    current_rep_key = 'l4'
    if input_is_set:
      x = norm_layer(x, name=f'{norm_layer_name}')
      x = jax.nn.relu(x)
      x = nn.avg_pool(x, (8, 8))
      x = x.reshape((x.shape[0], -1))
      layer_activations[current_rep_key] = x
      rep_key = current_rep_key
    elif input_layer_key == current_rep_key:
      x = inputs
      input_is_set = True
      layer_activations[current_rep_key] = x
      rep_key = current_rep_key

    # DANN module
    if has_discriminator:
      z = dann_utils.flip_grad_identity(x)
      z = nn.Dense(z, 2, name='disc_l1', bias=True)
      z = nn.relu(z)
      z = nn.Dense(z, 2, name='disc_l2', bias=True)

    current_rep_key = 'head'
    if input_is_set:
      x = nn.Dense(x, num_outputs, dtype=dtype, name='head')
    else:
      x = inputs
      layer_activations[current_rep_key] = x
      rep_key = current_rep_key

      logging.warning('Input was never used')

    outputs = x
    if return_activations:
      outputs = (x, layer_activations, rep_key)
      if discriminator and has_discriminator:
        outputs = outputs + (z,)
    else:
      del layer_activations
      if discriminator and has_discriminator:
        outputs = (x, z)
    if discriminator and (not has_discriminator):
      raise ValueError(
          'Inconsistent values passed for discriminator and has_discriminator')
    return outputs
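
The input_layer_key / input_is_set bookkeeping lets the network be entered at an intermediate layer: inputs are treated as the activations of the named layer, and only the layers after it are applied. The same pattern in isolation (a stand-alone sketch, not code from the module above):

    def sliced_forward(inputs, input_layer_key, layers):
        # `layers` is a list of (name, fn) pairs in forward order.
        started = input_layer_key == 'input'
        x = inputs if started else None
        for name, fn in layers:
            if started:
                x = fn(x)
            elif input_layer_key == name:
                x, started = inputs, True  # plug the input in as this layer's output
        return x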
Example #13
            def apply(self, x):
                x = nn.Dense(x, hidden_reps_dim, bias=True, name='l1')
                x = nn.relu(x)
                x = nn.Dense(x, hidden_reps_dim, bias=True, name='l2')

                return x
Example #14
    def apply(self,
              inputs,
              num_outputs,
              num_filters=64,
              num_layers=50,
              dropout_rate=0.0,
              input_dropout_rate=0.0,
              train=True,
              dtype=jnp.float32,
              head_bias_init=jnp.zeros,
              return_activations=False,
              input_layer_key='input',
              has_discriminator=False,
              discriminator=False):
        """Apply a ResNet network on the input.

    Args:
      inputs: jnp array; Inputs.
      num_outputs: int; Number of output units.
      num_filters: int; Determines base number of filters. Number of filters in
        block i is num_filters * 2 ** i.
      num_layers: int; Number of layers (should be one of the predefined ones.)
      dropout_rate: float; Rate of dropping out the output of different hidden
        layers.
      input_dropout_rate: float; Rate of dropping out the input units.
      train: bool; Whether the model is being trained.
      dtype: jnp type; Type of the outputs.
      head_bias_init: fn(rng_key, shape) -> jnp array; Initializer for head bias
        parameters.
      return_activations: bool; If True, hidden activations are also returned.
      input_layer_key: str; Determines where to plugin the input (this is to
        enable providing inputs to slices of the model). If `input_layer_key` is
        `layer_i` we assume the inputs are the activations of `layer_i` and pass
        them to `layer_{i+1}`.
      has_discriminator: bool; Whether the model should have discriminator
        layer.
      discriminator: bool; Whether we should return discriminator logits.

    Returns:
      Unnormalized Logits with shape `[bs, num_outputs]`,
      if return_activations:
        Logits, dict of hidden activations and the key to the representation(s)
        which will be used as ``The Representation'', e.g., for computing
        losses.
    """
        if num_layers not in ResNet._block_size_options:
            raise ValueError('Please provide a valid number of layers')

        block_sizes = ResNet._block_size_options[num_layers]

        layer_activations = collections.OrderedDict()
        input_is_set = False
        current_rep_key = 'input'
        if input_layer_key == current_rep_key:
            x = inputs
            input_is_set = True

        if input_is_set:
            # Input dropout
            x = nn.dropout(x, input_dropout_rate, deterministic=not train)
            layer_activations[current_rep_key] = x
            rep_key = current_rep_key

        current_rep_key = 'init_conv'
        if input_layer_key == current_rep_key:
            x = inputs
            input_is_set = True
            layer_activations[current_rep_key] = x
            rep_key = current_rep_key
        elif input_is_set:
            # First block
            x = nn.Conv(x,
                        num_filters, (7, 7), (2, 2),
                        padding=[(3, 3), (3, 3)],
                        bias=False,
                        dtype=dtype,
                        name='init_conv')
            x = nn.BatchNorm(x,
                             use_running_average=not train,
                             momentum=0.9,
                             epsilon=1e-5,
                             dtype=dtype,
                             name='init_bn')
            x = nn.relu(x)
            x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')

            layer_activations[current_rep_key] = x
            rep_key = current_rep_key

        # Residual blocks
        for i, block_size in enumerate(block_sizes):

            # Stage i (each stage contains blocks of the same size).
            for j in range(block_size):
                strides = (2, 2) if i > 0 and j == 0 else (1, 1)
                current_rep_key = f'block_{i + 1}_{j}'
                if input_layer_key == current_rep_key:
                    x = inputs
                    input_is_set = True
                    layer_activations[current_rep_key] = x
                    rep_key = current_rep_key
                elif input_is_set:
                    x = ResidualBlock(x,
                                      num_filters * 2**i,
                                      strides=strides,
                                      dropout_rate=dropout_rate,
                                      train=train,
                                      dtype=dtype,
                                      name=f'block_{i + 1}_{j}')
                    layer_activations[current_rep_key] = x
                    rep_key = current_rep_key

        current_rep_key = 'avg_pool'
        if input_layer_key == current_rep_key:
            x = inputs
            input_is_set = True
            layer_activations[current_rep_key] = x
            rep_key = current_rep_key
        elif input_is_set:
            # Global Average Pool
            x = jnp.mean(x, axis=(1, 2))
            layer_activations[current_rep_key] = x
            rep_key = current_rep_key

        # DANN module
        if has_discriminator:
            z = dann_utils.flip_grad_identity(x)
            z = nn.Dense(z, 2, name='disc_l1', bias=True)
            z = nn.relu(z)
            z = nn.Dense(z, 2, name='disc_l2', bias=True)

        current_rep_key = 'head'
        if input_layer_key == current_rep_key:
            x = inputs
            layer_activations[current_rep_key] = x
            rep_key = current_rep_key

            logging.warning('Input was never used')
        elif input_is_set:
            x = nn.Dense(x,
                         num_outputs,
                         dtype=dtype,
                         bias_init=head_bias_init,
                         name='head')

        # Make sure that the output is float32, even if our previous computations
        # are in float16, or other types.
        x = jnp.asarray(x, jnp.float32)

        outputs = x
        if return_activations:
            outputs = (x, layer_activations, rep_key)
            if discriminator and has_discriminator:
                outputs = outputs + (z, )
        else:
            del layer_activations
            if discriminator and has_discriminator:
                outputs = (x, z)
        if discriminator and (not has_discriminator):
            raise ValueError(
                'Inconsistent values passed for discriminator and has_discriminator'
            )
        return outputs
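
ResNet._block_size_options (not shown here) maps the num_layers argument to the per-stage block counts. For the standard ResNet family the table would look roughly like this; the exact contents in the source may differ:

    # Assumed depth -> blocks-per-stage mapping (standard ResNet definitions).
    _block_size_options = {
        18: [2, 2, 2, 2],
        34: [3, 4, 6, 3],
        50: [3, 4, 6, 3],
        101: [3, 4, 23, 3],
        152: [3, 8, 36, 3],
    }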