Ejemplo n.º 1
0
 def __call__(self, x: jnp.ndarray):
     """LeNet-style forward pass returning unnormalised class scores.

     Two stages of (5x5 VALID conv -> activation -> 2x2 average pool with
     stride 1) are followed by a flatten and a 120 -> 84 -> num_classes MLP.
     The activation is ``self._act_fn``; ``self._interval`` is forwarded to it
     as a second argument only when it is truthy.

     NOTE(review): the pools use ``strides=1``, so they do not downsample
     (classic LeNet uses stride 2) — confirm this is intended.
     """
     def activate(h):
         # Only pass the interval through when one is configured.
         if not self._interval:
             return self._act_fn(h)
         return self._act_fn(h, self._interval)

     for channels in (6, 16):
         x = activate(
             hk.Conv2D(output_channels=channels, kernel_shape=5, padding='VALID')(x))
         x = hk.AvgPool(window_shape=2, strides=1, padding='VALID')(x)

     x = hk.Flatten()(x)
     for hidden in (120, 84):
         x = activate(hk.Linear(hidden)(x))
     return hk.Linear(self._num_classes)(x)
Ejemplo n.º 2
0
 def forward(batch, is_training):
   """Run the residual network and return class logits.

   Relies on names from the enclosing scope: `width`, `num_res_blocks`,
   `use_bias`, `normalization_layer`, `num_classes`, `he_normal` and
   `_resnet_layer`.

   Args:
     batch: A pair whose first element is the network input; the second
       element is ignored (presumably the labels).
     is_training: Forwarded to every `_resnet_layer` call, including the
       stem, so all normalization layers follow the same train/eval mode.

   Returns:
     The output of a final `hk.Linear(num_classes)` layer.
   """
   x, _ = batch
   num_filters = width
   # Stem. `is_training` must be passed here as well; otherwise this layer
   # silently uses `_resnet_layer`'s default mode regardless of the caller.
   x = _resnet_layer(
       x, num_filters=num_filters, activation=jax.nn.relu, use_bias=use_bias,
       is_training=is_training, normalization_layer=normalization_layer)
   for stack in range(3):
     for res_block in range(num_res_blocks):
       # The first block of every stack except the first uses stride 2 and a
       # 1x1 projection shortcut so that `x` and `y` have matching shapes.
       downsample = stack > 0 and res_block == 0
       strides = 2 if downsample else 1
       y = _resnet_layer(
           x, num_filters=num_filters, strides=strides, activation=jax.nn.relu,
           use_bias=use_bias, is_training=is_training,
           normalization_layer=normalization_layer)
       # No `activation` is passed here: a ReLU is applied after the residual
       # addition below instead.
       y = _resnet_layer(
           y, num_filters=num_filters, use_bias=use_bias,
           is_training=is_training, normalization_layer=normalization_layer)
       if downsample:
         x = _resnet_layer(
             x, num_filters=num_filters, kernel_size=1, strides=strides,
             use_bias=use_bias, is_training=is_training,
             normalization_layer=normalization_layer)
       x = jax.nn.relu(x + y)
     num_filters *= 2
   # NOTE(review): the 8x8 window presumably matches the final feature-map
   # size (e.g. 32x32 inputs after two stride-2 stacks) — confirm.
   x = hk.AvgPool((8, 8, 1), 8, 'VALID')(x)
   x = hk.Flatten()(x)
   return hk.Linear(num_classes, w_init=he_normal)(x)
Ejemplo n.º 3
0
    def __call__(self, x: jnp.ndarray, is_train: bool):
        """VGG-style forward pass with dropout, returning class scores.

        Three stages, each made of two 3x3 VALID convolutions (64, 96 and 128
        channels) with activations, a 2x2 stride-1 average pool and a
        ``Dropout`` (rates 0.2, 0.3, 0.4). These are followed by a flatten, a
        128-unit ``Linear`` + activation, ``Dropout(0.5)`` and a final
        ``Linear(self._num_classes)``.

        Args:
          x: Input batch.
          is_train: Forwarded to every ``Dropout`` call.
        """
        def activate(h):
            # Only pass the interval through when one is configured.
            if not self._interval:
                return self._act_fn(h)
            return self._act_fn(h, self._interval)

        # NOTE(review): every Dropout receives the same ``self._seed``; if
        # Dropout derives its mask only from this seed, the masks of different
        # layers may be correlated — confirm against Dropout's implementation.
        for channels, rate in ((64, 0.2), (96, 0.3), (128, 0.4)):
            for _ in range(2):
                x = activate(
                    hk.Conv2D(output_channels=channels, kernel_shape=3,
                              padding='VALID')(x))
            x = hk.AvgPool(window_shape=2, strides=1, padding='VALID')(x)
            x = Dropout(rate, self._seed)(x, is_train)

        x = hk.Flatten()(x)
        x = activate(hk.Linear(128)(x))
        x = Dropout(0.5, self._seed)(x, is_train)
        return hk.Linear(self._num_classes)(x)
Ejemplo n.º 4
0
def make_downsampling_layer(
    strategy: Union[str, DownsamplingStrategy],
    output_channels: int,
) -> hk.SupportsCall:
    """Returns a sequence of modules corresponding to the desired downsampling.

    Args:
      strategy: A ``DownsamplingStrategy`` member or its value.
      output_channels: Output channels of the convolution used by the
        ``CONV``, ``LAYERNORM_RELU_CONV`` and ``CONV_MAX`` strategies; unused
        by ``AVG_POOL``.

    Returns:
      A callable Haiku module (a pooling layer or an ``hk.Sequential``).

    Raises:
      ValueError: If ``strategy`` is not a known ``DownsamplingStrategy``.
    """
    expected = [member.value for member in DownsamplingStrategy]
    try:
        strategy = DownsamplingStrategy(strategy)
    except ValueError as err:
        # Re-raise with the list of accepted values instead of enum's generic
        # "... is not a valid DownsamplingStrategy" message.
        raise ValueError(
            'Unrecognized downsampling strategy. Expected one of'
            f' {expected} but received {strategy!r}.') from err

    if strategy is DownsamplingStrategy.AVG_POOL:
        # 3x3 window, stride 2 on the first two non-batch axes and window /
        # stride 1 on the last axis (presumably channels).
        return hk.AvgPool(window_shape=(3, 3, 1),
                          strides=(2, 2, 1),
                          padding='SAME')

    elif strategy is DownsamplingStrategy.CONV:
        return hk.Sequential([
            hk.Conv2D(output_channels,
                      kernel_shape=3,
                      stride=2,
                      w_init=hk.initializers.TruncatedNormal(1e-2)),
        ])

    elif strategy is DownsamplingStrategy.LAYERNORM_RELU_CONV:
        return hk.Sequential([
            # Normalises over axes 1-3 (everything but axis 0, presumably
            # the batch axis).
            hk.LayerNorm(axis=(1, 2, 3),
                         create_scale=True,
                         create_offset=True,
                         eps=1e-6),
            jax.nn.relu,
            hk.Conv2D(output_channels,
                      kernel_shape=3,
                      stride=2,
                      w_init=hk.initializers.TruncatedNormal(1e-2)),
        ])

    elif strategy is DownsamplingStrategy.CONV_MAX:
        # Stride-1 conv; the spatial reduction comes from the max pool.
        return hk.Sequential([
            hk.Conv2D(output_channels, kernel_shape=3, stride=1),
            hk.MaxPool(window_shape=(3, 3, 1),
                       strides=(2, 2, 1),
                       padding='SAME')
        ])
    else:
        # Only reachable if a member is added to DownsamplingStrategy without
        # a matching branch here.
        raise ValueError(
            'Unrecognized downsampling strategy. Expected one of'
            f' {expected} but received {strategy}.')
Ejemplo n.º 5
0
    def __init__(self,
                 blocks_per_group,
                 bn_config,
                 bottleneck,
                 channels_per_group,
                 use_projection,
                 strides,
                 n_output_channels=1,
                 use_bn=True,
                 pad_crop=False,
                 name=None):
        """Constructs a Residual UNet model based on a traditional ResNet.

        The number of block groups is ``len(strides)``. It creates that many
        encoder groups (``block_group_i``) followed by the same number of
        decoder groups (``up_block_group_i``, built with ``transpose=True``).
        Group ``i`` of both uses ``channels_per_group[i]``,
        ``blocks_per_group[i]``, ``strides[i]`` and ``use_projection[i]``.
        Only the ResNet v1 layout is built (``self.resnet_v2`` is always
        False).

        Args:
          blocks_per_group: A sequence with one entry per stride that indicates
            the number of blocks created in each group.
          bn_config: A mapping (or ``None``) of keyword arguments passed on to
            the :class:`~haiku.BatchNorm` layers. Missing keys default to
            ``decay_rate=0.9``, ``eps=1e-5``, ``create_scale=True`` and
            ``create_offset=True``. The caller's mapping is not modified.
          bottleneck: Whether the blocks should bottleneck or not.
          channels_per_group: A sequence with one entry per stride that
            indicates the number of channels used for each block in each group.
          use_projection: A sequence with (at least) one entry per stride that
            indicates whether each residual block should use projection.
          strides: A sequence with one stride per block group; its length sets
            the number of groups.
          n_output_channels: The number of output channels, for example to
            change in the case of a complex denoising. Defaults to 1.
          use_bn: Whether the network should use batch normalisation. Defaults
            to ``True``.
          pad_crop: Whether to use cropping/padding to make sure the images can
            be downsampled and upsampled correctly. Defaults to ``False``.
          name: Name of the module.
        """
        super().__init__(name=name)
        # Not a constructor argument: only v1 is supported, so the
        # ``final_batchnorm`` branch below is currently never taken.
        self.resnet_v2 = False
        self.use_bn = use_bn
        self.pad_crop = pad_crop
        self.n_output_channels = n_output_channels
        # Copy so the defaults below never mutate the caller's mapping.
        bn_config = dict(bn_config or {})
        bn_config.setdefault("decay_rate", 0.9)
        bn_config.setdefault("eps", 1e-5)
        bn_config.setdefault("create_scale", True)
        bn_config.setdefault("create_offset", True)
        self.strides = strides
        bl = len(self.strides)

        # Per-group sequences must match the number of strides (presumably
        # check_length raises on a mismatch).
        check_length(bl, blocks_per_group, "blocks_per_group")
        check_length(bl, channels_per_group, "channels_per_group")
        self.upsampling = upsample
        self.pooling = hk.AvgPool(window_shape=2, strides=2, padding='SAME')

        self.initial_conv = hk.Conv2D(output_channels=32,
                                      kernel_shape=7,
                                      stride=1,
                                      with_bias=not self.use_bn,
                                      padding="SAME",
                                      name="initial_conv")

        if not self.resnet_v2 and self.use_bn:
            self.initial_batchnorm = hk.BatchNorm(name="initial_batchnorm",
                                                  **bn_config)

        # Encoder groups are created first, then the decoder groups, which
        # share the per-group settings but use ``transpose=True``.
        self.block_groups = [
            BlockGroup(channels=channels_per_group[i],
                       num_blocks=blocks_per_group[i],
                       stride=strides[i],
                       bn_config=bn_config,
                       bottleneck=bottleneck,
                       use_projection=use_projection[i],
                       transpose=False,
                       use_bn=self.use_bn,
                       name="block_group_%d" % (i))
            for i in range(bl)
        ]
        self.up_block_groups = [
            BlockGroup(channels=channels_per_group[i],
                       num_blocks=blocks_per_group[i],
                       stride=strides[i],
                       bn_config=bn_config,
                       bottleneck=bottleneck,
                       use_projection=use_projection[i],
                       transpose=True,
                       use_bn=self.use_bn,
                       name="up_block_group_%d" % (i))
            for i in range(bl)
        ]

        if self.resnet_v2 and self.use_bn:
            self.final_batchnorm = hk.BatchNorm(name="final_batchnorm",
                                                **bn_config)

        self.final_conv = hk.Conv2D(
            output_channels=self.n_output_channels,
            kernel_shape=5,
            stride=1,
            padding='SAME',
            name='final_conv',
        )
Ejemplo n.º 6
0
    def __init__(self,
                 channels: int,
                 stride: Union[int, Sequence[int]],
                 use_projection: bool,
                 bn_config: Mapping[str, float],
                 bottleneck: bool,
                 use_bn: bool = True,
                 transpose: bool = False,
                 name: Optional[str] = None):
        """Creates the sub-modules of one residual block.

        Builds an optional 1x1 projection shortcut plus two (or, with
        ``bottleneck``, three) stride-1 SAME convolutions, each paired with a
        ``BatchNorm`` when ``use_bn`` is set. The pairs are stored in
        ``self.layers``.

        Args:
          channels: Output channels of the block and of the shortcut conv;
            the inner convolutions use ``channels // 4`` when ``bottleneck``.
          stride: Stored as ``self.stride``. None of the convolutions created
            here use it; presumably resizing is done through
            ``self.pooling_or_upsampling`` in ``__call__``.
          use_projection: Create ``shortcut_conv`` (and, with ``use_bn``,
            ``shortcut_batchnorm``) for the skip connection.
          bn_config: Keyword arguments for every ``hk.BatchNorm``. Only read
            when ``use_bn``; missing ``create_scale``/``create_offset``
            default to True and ``decay_rate`` to 0.999.
          bottleneck: Use 1x1 -> 3x3 -> 1x1 convolutions instead of two 3x3.
          use_bn: Pair each conv with a ``BatchNorm``. Convolutions carry a
            bias only when this is false.
          transpose: Use ``upsample`` instead of a 2x2, stride-2 average pool
            as ``self.pooling_or_upsampling``.
          name: Module name.
        """
        super().__init__(name=name)
        self.use_projection = use_projection
        self.use_bn = use_bn
        if self.use_bn:
            bn_config = dict(bn_config)
            for key, default in (("create_scale", True),
                                 ("create_offset", True),
                                 ("decay_rate", 0.999)):
                bn_config.setdefault(key, default)

        self.pooling_or_upsampling = (
            upsample if transpose
            else hk.AvgPool(window_shape=2, strides=2, padding='SAME'))
        self.stride = stride

        def conv_bn(conv_name, bn_name, out_channels, kernel, **bn_kwargs):
            # A stride-1 SAME conv followed by its BatchNorm (None without BN).
            conv = hk.Conv2D(output_channels=out_channels,
                             kernel_shape=kernel,
                             stride=1,
                             with_bias=not self.use_bn,
                             padding="SAME",
                             name=conv_name)
            norm = (hk.BatchNorm(name=bn_name, **bn_kwargs, **bn_config)
                    if self.use_bn else None)
            return conv, norm

        if self.use_projection:
            # Skip-connection conv only. Its stride stays 1; per the original
            # note, resizing is replaced by upsampling or avg pooling.
            self.proj_conv, proj_norm = conv_bn(
                "shortcut_conv", "shortcut_batchnorm", channels, 1)
            if proj_norm is not None:
                self.proj_batchnorm = proj_norm

        inner_channels = channels // (4 if bottleneck else 1)
        layers = (
            conv_bn("conv_0", "batchnorm_0", inner_channels,
                    1 if bottleneck else 3),
            conv_bn("conv_1", "batchnorm_1", inner_channels, 3),
        )
        if bottleneck:
            # The last BatchNorm of the bottleneck starts with a zero scale.
            layers += (conv_bn("conv_2", "batchnorm_2", channels, 1,
                               scale_init=jnp.zeros),)
        self.layers = layers
Ejemplo n.º 7
0
def augmented_vgg19(fp: str,
                    content_image: jnp.ndarray,
                    style_image: jnp.ndarray,
                    mean: jnp.ndarray,
                    std: jnp.ndarray,
                    content_layers: List[str] = None,
                    style_layers: List[str] = None,
                    pooling: str = "avg") -> hk.Sequential:
    """Build a VGG19 network augmented by content and style loss layers.

    Layers are created in the order of ``get_model_params(fp=fp)``. Keys
    containing "pool" become 2x2/stride-2 pooling layers, and keys containing
    "conv" become NCHW ``Conv2D`` layers (initialised from the stored ``w``
    and ``b``) followed by a ReLU. Convolutions are named ``conv_<k>`` with a
    1-based counter ``k``. When that name appears in ``style_layers`` or
    ``content_layers``, a ``StyleLoss`` / ``ContentLoss`` is inserted right
    after the conv. Its target is computed by running the network built so
    far on ``style_image`` / ``content_image``.

    Args:
      fp: Passed to ``get_model_params`` to load the weights.
      content_image: Passed to the leading ``Normalization`` layer and used
        to compute content targets.
      style_image: Used to compute style targets.
      mean: Passed to the leading ``Normalization`` layer.
      std: Passed to the leading ``Normalization`` layer.
      content_layers: Conv names that get a ``ContentLoss``; None means none.
      style_layers: Conv names that get a ``StyleLoss``; None means none.
      pooling: "avg" or "max" (case-insensitive).

    Returns:
      An ``hk.Sequential`` truncated right after the last loss layer. If no
      loss layer was inserted, only the normalization layer remains.

    Raises:
      ValueError: If ``pooling`` is not "avg" or "max".
      AssertionError: If a conv kernel is not square.
    """
    pooling = pooling.lower()
    if pooling not in ["avg", "max"]:
        raise ValueError("Pooling method not recognized. Options are: "
                         "\"avg\", \"max\".")

    params = get_model_params(fp=fp)

    layers = [Normalization(content_image, mean, std, "norm")]
    conv_count = 0
    content_layers = content_layers or []
    style_layers = style_layers or []

    model = hk.Sequential(layers=layers)

    if pooling == "avg":
        pool_cls, pool_prefix = hk.AvgPool, "avg_pool"
    else:
        pool_cls, pool_prefix = hk.MaxPool, "max_pool"

    def target_for(image):
        # Evaluate the network built so far to get the activation the next
        # loss layer should compare against.
        model.layers = tuple(layers)
        return model(image, is_training=False)

    for key, weights in params.items():
        if "pool" in key:
            # Pools are named after the number of convs seen so far, which
            # keeps the names unique (there are fewer pools than convs).
            layers.append(
                pool_cls(window_shape=2,
                         strides=2,
                         padding="VALID",
                         channel_axis=1,
                         name=f"{pool_prefix}_{conv_count}"))
            continue
        if "conv" not in key:
            continue

        conv_count += 1
        conv_name = f"conv_{conv_count}"
        # Stored kernels are laid out as (height, width, in, out).
        kernel_h, kernel_w, _, out_ch = weights["w"].shape
        assert kernel_w == kernel_h, "VGG19 only has square conv kernels"

        layers.append(
            hk.Conv2D(output_channels=out_ch,
                      kernel_shape=kernel_h,
                      stride=1,
                      padding="SAME",
                      data_format="NCHW",
                      w_init=hk.initializers.Constant(weights["w"]),
                      b_init=hk.initializers.Constant(weights["b"]),
                      name=conv_name))

        if conv_name in style_layers:
            layers.append(StyleLoss(target=target_for(style_image),
                                    name=f"style_loss_{conv_count}"))
        if conv_name in content_layers:
            layers.append(ContentLoss(target=target_for(content_image),
                                      name=f"content_loss_{conv_count}"))

        layers.append(jax.nn.relu)

    # Drop everything after the last loss layer; with no loss layers, keep
    # only the leading normalization layer.
    loss_positions = [
        idx for idx, layer in enumerate(layers)
        if isinstance(layer, (StyleLoss, ContentLoss))
    ]
    last_kept = loss_positions[-1] if loss_positions else 0
    model.layers = tuple(layers[:last_kept + 1])

    return model