def __call__(self, x: jnp.ndarray):
    x = hk.Conv2D(output_channels=6, kernel_shape=5, padding='VALID')(x)
    if self._interval:
        x = self._act_fn(x, self._interval)
    else:
        x = self._act_fn(x)
    x = hk.AvgPool(window_shape=2, strides=1, padding='VALID')(x)
    x = hk.Conv2D(output_channels=16, kernel_shape=5, padding='VALID')(x)
    if self._interval:
        x = self._act_fn(x, self._interval)
    else:
        x = self._act_fn(x)
    x = hk.AvgPool(window_shape=2, strides=1, padding='VALID')(x)
    x = hk.Flatten()(x)
    x = hk.Linear(120)(x)
    if self._interval:
        x = self._act_fn(x, self._interval)
    else:
        x = self._act_fn(x)
    x = hk.Linear(84)(x)
    if self._interval:
        x = self._act_fn(x, self._interval)
    else:
        x = self._act_fn(x)
    x = hk.Linear(self._num_classes)(x)
    return x
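# Hedged usage sketch for the LeNet-style module above. The enclosing class
# name `LeNet5` and its constructor signature are assumptions; only the
# __call__ body is given in the source. Since the module holds no BatchNorm
# state, a plain hk.transform suffices.
import haiku as hk
import jax
import jax.numpy as jnp

def _lenet_fn(x):
    return LeNet5(num_classes=10, act_fn=jax.nn.relu, interval=None)(x)

lenet = hk.transform(_lenet_fn)
rng = jax.random.PRNGKey(0)
dummy = jnp.zeros((1, 32, 32, 1))  # NHWC batch of 32x32 grayscale images.
params = lenet.init(rng, dummy)
logits = lenet.apply(params, rng, dummy)  # Shape: (1, 10).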
def forward(batch, is_training):
    num_filters = width
    x, _ = batch
    x = _resnet_layer(
        x,
        num_filters=num_filters,
        activation=jax.nn.relu,
        use_bias=use_bias,
        is_training=is_training,
        normalization_layer=normalization_layer)
    for stack in range(3):
        for res_block in range(num_res_blocks):
            strides = 1
            if stack > 0 and res_block == 0:  # First block of a later stack:
                strides = 2  # downsample.
            y = _resnet_layer(
                x,
                num_filters=num_filters,
                strides=strides,
                activation=jax.nn.relu,
                use_bias=use_bias,
                is_training=is_training,
                normalization_layer=normalization_layer)
            y = _resnet_layer(
                y,
                num_filters=num_filters,
                use_bias=use_bias,
                is_training=is_training,
                normalization_layer=normalization_layer)
            if stack > 0 and res_block == 0:
                # Linear-projection residual shortcut to match the changed
                # dimensions of the main path.
                x = _resnet_layer(
                    x,
                    num_filters=num_filters,
                    kernel_size=1,
                    strides=strides,
                    use_bias=use_bias,
                    is_training=is_training,
                    normalization_layer=normalization_layer)
            x = jax.nn.relu(x + y)
        num_filters *= 2
    x = hk.AvgPool((8, 8, 1), 8, 'VALID')(x)
    x = hk.Flatten()(x)
    logits = hk.Linear(num_classes, w_init=he_normal)(x)
    return logits
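# Hedged usage sketch: `forward` closes over `width`, `num_res_blocks`,
# `use_bias`, `normalization_layer`, `num_classes`, and `he_normal`, all of
# which must be defined in the enclosing scope. Because the normalization
# layer is typically stateful (e.g. hk.BatchNorm moving averages), the
# function is transformed with state:
model = hk.transform_with_state(forward)
rng = jax.random.PRNGKey(0)
batch = (jnp.zeros((8, 32, 32, 3)), jnp.zeros((8,)))  # (images, labels)
params, state = model.init(rng, batch, is_training=True)
logits, state = model.apply(params, state, rng, batch, is_training=True)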
def __call__(self, x: jnp.ndarray, is_train: bool):
    # Stage 1: two 3x3 convs at 64 channels, average pooling, light dropout.
    x = hk.Conv2D(output_channels=64, kernel_shape=3, padding='VALID')(x)
    if self._interval:
        x = self._act_fn(x, self._interval)
    else:
        x = self._act_fn(x)
    x = hk.Conv2D(output_channels=64, kernel_shape=3, padding='VALID')(x)
    if self._interval:
        x = self._act_fn(x, self._interval)
    else:
        x = self._act_fn(x)
    x = hk.AvgPool(window_shape=2, strides=1, padding='VALID')(x)
    x = Dropout(0.2, self._seed)(x, is_train)
    # Stage 2: 96 channels, stronger dropout.
    x = hk.Conv2D(output_channels=96, kernel_shape=3, padding='VALID')(x)
    if self._interval:
        x = self._act_fn(x, self._interval)
    else:
        x = self._act_fn(x)
    x = hk.Conv2D(output_channels=96, kernel_shape=3, padding='VALID')(x)
    if self._interval:
        x = self._act_fn(x, self._interval)
    else:
        x = self._act_fn(x)
    x = hk.AvgPool(window_shape=2, strides=1, padding='VALID')(x)
    x = Dropout(0.3, self._seed)(x, is_train)
    # Stage 3: 128 channels.
    x = hk.Conv2D(output_channels=128, kernel_shape=3, padding='VALID')(x)
    if self._interval:
        x = self._act_fn(x, self._interval)
    else:
        x = self._act_fn(x)
    x = hk.Conv2D(output_channels=128, kernel_shape=3, padding='VALID')(x)
    if self._interval:
        x = self._act_fn(x, self._interval)
    else:
        x = self._act_fn(x)
    x = hk.AvgPool(window_shape=2, strides=1, padding='VALID')(x)
    x = Dropout(0.4, self._seed)(x, is_train)
    # Classifier head.
    x = hk.Flatten()(x)
    x = hk.Linear(128)(x)
    if self._interval:
        x = self._act_fn(x, self._interval)
    else:
        x = self._act_fn(x)
    x = Dropout(0.5, self._seed)(x, is_train)
    x = hk.Linear(self._num_classes)(x)
    return x
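# `Dropout` above is not a built-in Haiku module. A minimal sketch consistent
# with the call pattern Dropout(rate, seed)(x, is_train) -- the body is an
# assumption, not necessarily the repo's implementation:
class Dropout(hk.Module):
    def __init__(self, rate: float, seed: int, name=None):
        super().__init__(name=name)
        self._rate = rate
        self._seed = seed

    def __call__(self, x: jnp.ndarray, is_train: bool) -> jnp.ndarray:
        if not is_train or self._rate == 0.0:
            return x
        keep = 1.0 - self._rate
        # hk.next_rng_key() draws from the key passed to apply(); the stored
        # seed could be folded in instead for fully deterministic masks.
        mask = jax.random.bernoulli(hk.next_rng_key(), keep, x.shape)
        return jnp.where(mask, x / keep, jnp.zeros_like(x))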
def make_downsampling_layer(
    strategy: Union[str, DownsamplingStrategy],
    output_channels: int,
) -> hk.SupportsCall:
    """Returns a sequence of modules corresponding to the desired downsampling."""
    strategy = DownsamplingStrategy(strategy)
    if strategy is DownsamplingStrategy.AVG_POOL:
        return hk.AvgPool(window_shape=(3, 3, 1), strides=(2, 2, 1),
                          padding='SAME')
    elif strategy is DownsamplingStrategy.CONV:
        return hk.Sequential([
            hk.Conv2D(output_channels,
                      kernel_shape=3,
                      stride=2,
                      w_init=hk.initializers.TruncatedNormal(1e-2)),
        ])
    elif strategy is DownsamplingStrategy.LAYERNORM_RELU_CONV:
        return hk.Sequential([
            hk.LayerNorm(axis=(1, 2, 3),
                         create_scale=True,
                         create_offset=True,
                         eps=1e-6),
            jax.nn.relu,
            hk.Conv2D(output_channels,
                      kernel_shape=3,
                      stride=2,
                      w_init=hk.initializers.TruncatedNormal(1e-2)),
        ])
    elif strategy is DownsamplingStrategy.CONV_MAX:
        return hk.Sequential([
            hk.Conv2D(output_channels, kernel_shape=3, stride=1),
            hk.MaxPool(window_shape=(3, 3, 1), strides=(2, 2, 1),
                       padding='SAME')
        ])
    else:
        raise ValueError(
            'Unrecognized downsampling strategy. Expected one of'
            f' {[strategy.value for strategy in DownsamplingStrategy]}'
            f' but received {strategy}.')
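# Hedged example of selecting a strategy by its string value. The enum member
# values ('avg_pool', 'conv', 'layernorm_relu_conv', 'conv_max') are
# assumptions about how DownsamplingStrategy is defined.
def _downsample_fn(x):
    return make_downsampling_layer('conv_max', output_channels=64)(x)

down = hk.transform(_downsample_fn)
x = jnp.zeros((1, 32, 32, 3))
params = down.init(jax.random.PRNGKey(0), x)
y = down.apply(params, None, x)  # Stride-2 max pool halves H and W: (1, 16, 16, 64).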
def __init__(self,
             blocks_per_group,
             bn_config,
             bottleneck,
             channels_per_group,
             use_projection,
             strides,
             n_output_channels=1,
             use_bn=True,
             pad_crop=False,
             name=None):
    """Constructs a Residual UNet model based on a traditional ResNet.

    Args:
      blocks_per_group: A sequence with one element per block group that
        indicates the number of blocks created in each group.
      bn_config: A dictionary of two elements, ``decay_rate`` and ``eps`` to
        be passed on to the :class:`~haiku.BatchNorm` layers. By default the
        ``decay_rate`` is ``0.9`` and ``eps`` is ``1e-5``.
      bottleneck: Whether the block should bottleneck or not. Defaults to
        ``True``.
      channels_per_group: A sequence with one element per block group that
        indicates the number of channels used for each block in each group.
      use_projection: A sequence with one element per block group that
        indicates whether each residual block should use projection.
      strides: A sequence giving the stride of each block group; its length
        determines the number of block groups.
      n_output_channels: The number of output channels, for example to change
        in the case of a complex denoising. Defaults to 1.
      use_bn: Whether the network should use batch normalisation. Defaults to
        ``True``.
      pad_crop: Whether to use cropping/padding to make sure the images can
        be downsampled and upsampled correctly. Defaults to ``False``.
      name: Name of the module.
    """
    super().__init__(name=name)
    # Only the ResNet v1 layout is supported here.
    self.resnet_v2 = False
    self.use_bn = use_bn
    self.pad_crop = pad_crop
    self.n_output_channels = n_output_channels
    bn_config = dict(bn_config or {})
    bn_config.setdefault("decay_rate", 0.9)
    bn_config.setdefault("eps", 1e-5)
    bn_config.setdefault("create_scale", True)
    bn_config.setdefault("create_offset", True)
    self.strides = strides
    bl = len(self.strides)  # Number of block groups in the ResNet.
    check_length(bl, blocks_per_group, "blocks_per_group")
    check_length(bl, channels_per_group, "channels_per_group")
    self.upsampling = upsample
    self.pooling = hk.AvgPool(window_shape=2, strides=2, padding='SAME')
    self.initial_conv = hk.Conv2D(output_channels=32,
                                  kernel_shape=7,
                                  stride=1,
                                  with_bias=not self.use_bn,
                                  padding="SAME",
                                  name="initial_conv")
    if not self.resnet_v2 and self.use_bn:
        self.initial_batchnorm = hk.BatchNorm(name="initial_batchnorm",
                                              **bn_config)
    self.block_groups = []
    self.up_block_groups = []
    for i in range(bl):
        self.block_groups.append(
            BlockGroup(channels=channels_per_group[i],
                       num_blocks=blocks_per_group[i],
                       stride=strides[i],
                       bn_config=bn_config,
                       bottleneck=bottleneck,
                       use_projection=use_projection[i],
                       transpose=False,
                       use_bn=self.use_bn,
                       name="block_group_%d" % i))
    for i in range(bl):
        self.up_block_groups.append(
            BlockGroup(channels=channels_per_group[i],
                       num_blocks=blocks_per_group[i],
                       stride=strides[i],
                       bn_config=bn_config,
                       bottleneck=bottleneck,
                       use_projection=use_projection[i],
                       transpose=True,
                       use_bn=self.use_bn,
                       name="up_block_group_%d" % i))
    if self.resnet_v2 and self.use_bn:
        self.final_batchnorm = hk.BatchNorm(name="final_batchnorm",
                                            **bn_config)
    self.final_conv = hk.Conv2D(
        output_channels=self.n_output_channels,
        kernel_shape=5,
        stride=1,
        padding='SAME',
        name='final_conv',
    )
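# `check_length` and `upsample` are referenced above but defined elsewhere.
# Plausible sketches, not necessarily the repo's implementations:
def check_length(length, value, name):
    # Mirrors the argument-validation helper used by Haiku's ResNet.
    if len(value) != length:
        raise ValueError(f"`{name}` must be of length {length} not {len(value)}")

def upsample(x: jnp.ndarray) -> jnp.ndarray:
    # Doubles the spatial resolution of an NHWC tensor with
    # nearest-neighbour resizing.
    n, h, w, c = x.shape
    return jax.image.resize(x, (n, 2 * h, 2 * w, c), method='nearest')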
def __init__(self,
             channels: int,
             stride: Union[int, Sequence[int]],
             use_projection: bool,
             bn_config: Mapping[str, float],
             bottleneck: bool,
             use_bn: bool = True,
             transpose: bool = False,
             name: Optional[str] = None):
    super().__init__(name=name)
    self.use_projection = use_projection
    self.use_bn = use_bn
    if self.use_bn:
        bn_config = dict(bn_config)
        bn_config.setdefault("create_scale", True)
        bn_config.setdefault("create_offset", True)
        bn_config.setdefault("decay_rate", 0.999)
    if transpose:
        self.pooling_or_upsampling = upsample
    else:
        self.pooling_or_upsampling = hk.AvgPool(window_shape=2, strides=2,
                                                padding='SAME')
    self.stride = stride
    if self.use_projection:
        # This convolution is only used for the skip connection.
        self.proj_conv = hk.Conv2D(
            output_channels=channels,
            kernel_shape=1,
            # Depending on whether the block is transposed or not, this
            # stride must be replaced by upsampling or average pooling.
            stride=1,
            with_bias=not self.use_bn,
            padding="SAME",
            name="shortcut_conv")
        if self.use_bn:
            self.proj_batchnorm = hk.BatchNorm(name="shortcut_batchnorm",
                                               **bn_config)
    channel_div = 4 if bottleneck else 1
    conv_0 = hk.Conv2D(output_channels=channels // channel_div,
                       kernel_shape=1 if bottleneck else 3,
                       stride=1,
                       with_bias=not self.use_bn,
                       padding="SAME",
                       name="conv_0")
    if self.use_bn:
        bn_0 = hk.BatchNorm(name="batchnorm_0", **bn_config)
    conv_1 = hk.Conv2D(output_channels=channels // channel_div,
                       kernel_shape=3,
                       stride=1,
                       with_bias=not self.use_bn,
                       padding="SAME",
                       name="conv_1")
    if self.use_bn:
        bn_1 = hk.BatchNorm(name="batchnorm_1", **bn_config)
        layers = ((conv_0, bn_0), (conv_1, bn_1))
    else:
        layers = ((conv_0, None), (conv_1, None))
    if bottleneck:
        conv_2 = hk.Conv2D(output_channels=channels,
                           kernel_shape=1,
                           stride=1,
                           with_bias=not self.use_bn,
                           padding="SAME",
                           name="conv_2")
        if self.use_bn:
            # The final BatchNorm of the block is zero-initialised so that
            # each residual block starts out as an identity function.
            bn_2 = hk.BatchNorm(name="batchnorm_2", scale_init=jnp.zeros,
                                **bn_config)
            layers = layers + ((conv_2, bn_2),)
        else:
            layers = layers + ((conv_2, None),)
    self.layers = layers
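# Hedged sketch of how the (conv, batchnorm) pairs built above are typically
# consumed in a residual block's __call__ (a plausible reconstruction, not
# the repo's verbatim method; stride handling via pooling_or_upsampling is
# elided):
def __call__(self, inputs, is_training, test_local_stats=False):
    out = shortcut = inputs
    if self.use_projection:
        shortcut = self.proj_conv(shortcut)
        if self.use_bn:
            shortcut = self.proj_batchnorm(shortcut, is_training,
                                           test_local_stats)
    for i, (conv_i, bn_i) in enumerate(self.layers):
        out = conv_i(out)
        if bn_i is not None:
            out = bn_i(out, is_training, test_local_stats)
        if i < len(self.layers) - 1:
            # No activation right before the residual sum.
            out = jax.nn.relu(out)
    return jax.nn.relu(out + shortcut)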
def augmented_vgg19(fp: str,
                    content_image: jnp.ndarray,
                    style_image: jnp.ndarray,
                    mean: jnp.ndarray,
                    std: jnp.ndarray,
                    content_layers: List[str] = None,
                    style_layers: List[str] = None,
                    pooling: str = "avg") -> hk.Sequential:
    """Build a VGG19 network augmented by content and style loss layers."""
    pooling = pooling.lower()
    if pooling not in ["avg", "max"]:
        raise ValueError("Pooling method not recognized. Options are: "
                         "\"avg\", \"max\".")
    params = get_model_params(fp=fp)
    # Prepend a normalization layer.
    layers = [Normalization(content_image, mean, std, "norm")]
    # Tracks the number of conv layers seen so far.
    n = 0
    # Desired depth layers at which to compute style/content losses.
    content_layers = content_layers or []
    style_layers = style_layers or []
    model = hk.Sequential(layers=layers)
    for k, p_dict in params.items():
        if "pool" in k:
            # Pooling layers are numbered after the preceding conv layer.
            if pooling == "avg":
                layers.append(
                    hk.AvgPool(window_shape=2,
                               strides=2,
                               padding="VALID",
                               channel_axis=1,
                               name=f"avg_pool_{n}"))
            else:
                layers.append(
                    hk.MaxPool(window_shape=2,
                               strides=2,
                               padding="VALID",
                               channel_axis=1,
                               name=f"max_pool_{n}"))
        elif "conv" in k:
            n += 1
            name = f"conv_{n}"
            kernel_h, kernel_w, in_ch, out_ch = p_dict["w"].shape
            # VGG19 only has square conv kernels.
            assert kernel_w == kernel_h, "VGG19 only has square conv kernels"
            kernel_shape = kernel_h
            layers.append(
                hk.Conv2D(output_channels=out_ch,
                          kernel_shape=kernel_shape,
                          stride=1,
                          padding="SAME",
                          data_format="NCHW",
                          w_init=hk.initializers.Constant(p_dict["w"]),
                          b_init=hk.initializers.Constant(p_dict["b"]),
                          name=name))
            if name in style_layers:
                # Run the network built so far to record the style target.
                model.layers = tuple(layers)
                style_target = model(style_image, is_training=False)
                layers.append(StyleLoss(target=style_target,
                                        name=f"style_loss_{n}"))
            if name in content_layers:
                # Run the network built so far to record the content target.
                model.layers = tuple(layers)
                content_target = model(content_image, is_training=False)
                layers.append(ContentLoss(target=content_target,
                                          name=f"content_loss_{n}"))
            layers.append(jax.nn.relu)
    # Note: this loop reuses and overwrites the `n` from above.
    for n in range(len(layers) - 1, -1, -1):
        if isinstance(layers[n], (StyleLoss, ContentLoss)):
            break
    # Trim everything after the last style/content loss layer.
    layers = layers[:n + 1]
    model.layers = tuple(layers)
    return model
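# Hedged usage sketch: the loss network is built inside a Haiku transform so
# the constant-initialised VGG weights become regular parameters. The weights
# path, images, ImageNet statistics, and layer choices are illustrative
# assumptions, as is the use of transform_with_state (which applies if the
# StyleLoss/ContentLoss modules record their values with hk.set_state).
def _style_fn(image):
    vgg = augmented_vgg19(fp="vgg19_weights.h5",
                          content_image=content_image,
                          style_image=style_image,
                          mean=imagenet_mean,
                          std=imagenet_std,
                          content_layers=["conv_4"],
                          style_layers=["conv_1", "conv_2", "conv_3"],
                          pooling="avg")
    return vgg(image, is_training=True)

style_net = hk.transform_with_state(_style_fn)
params, state = style_net.init(jax.random.PRNGKey(0), content_image)
_, state = style_net.apply(params, state, None, generated_image)
# `state` would then hold the recorded style/content losses under the
# assumption above.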