def mnist_model(features, **_):
    return hk.Sequential([
        hk.Conv2D(16, (8, 8), padding='SAME', stride=(2, 2)),
        jax.nn.relu,
        hk.MaxPool(2, 1, padding='VALID'),  # matches stax
        hk.Conv2D(32, (4, 4), padding='VALID', stride=(2, 2)),
        jax.nn.relu,
        hk.MaxPool(2, 1, padding='VALID'),  # matches stax
        hk.Flatten(),
        hk.Linear(32),
        jax.nn.relu,
        hk.Linear(10),
    ])(features)
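Haiku modules can only be instantiated inside a transformed function, so a model function like the one above is wrapped with hk.transform before use. A minimal usage sketch, assuming NHWC MNIST inputs of shape [batch, 28, 28, 1] (the shape is an assumption for illustration):

import haiku as hk
import jax
import jax.numpy as jnp

# Wrap the function so Haiku can create and track its parameters.
model = hk.transform(mnist_model)

# Dummy MNIST-shaped batch (assumed NHWC).
dummy = jnp.zeros((1, 28, 28, 1), dtype=jnp.float32)

rng = jax.random.PRNGKey(42)
params = model.init(rng, dummy)           # create parameters
logits = model.apply(params, rng, dummy)  # forward pass, logits of shape (1, 10)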
Example #2
  def __call__(self, x):
    return hk.Sequential([
        hk.Conv2D(8, (3, 3)), jax.nn.relu,
        hk.MaxPool((1, 2, 2, 1), (1, 2, 2, 1), 'VALID'),
        hk.Flatten(),
        hk.Linear(3, with_bias=False)
    ])(x)
Example #3
    def __init__(self, model_size):
        super().__init__()
        if model_size == 'large':
            channels = {'prep': 64, 'layer1': 128, 'layer2': 256, 'layer3': 512}
        elif model_size == 'med':
            channels = {'prep': 32, 'layer1': 64, 'layer2': 128, 'layer3': 256}
        elif model_size == 'small':
            channels = {'prep': 16, 'layer1': 32, 'layer2': 64, 'layer3': 128}
        elif model_size == 'tiny':
            channels = {'prep': 8, 'layer1': 16, 'layer2': 32, 'layer3': 64}
        else:
            raise ValueError(f"Unknown model_size: {model_size!r}")
        self.prep = ConvBN(3, channels['prep'])

        # Layer 1
        self.conv1 = ConvBN(channels['prep'], channels['layer1'])
        self.pool1 = hk.MaxPool(window_shape=(1, 1, 2, 2),
                                strides=(1, 1, 2, 2),
                                padding='SAME')
        self.residual1 = Residual(channels['layer1'])

        # Layer 2
        self.conv2 = ConvBN(channels['layer1'], channels['layer2'])
        self.pool2 = hk.MaxPool(window_shape=(1, 1, 2, 2),
                                strides=(1, 1, 2, 2),
                                padding='SAME')

        # Layer 3
        self.conv3 = ConvBN(channels['layer2'], channels['layer3'])
        self.pool3 = hk.MaxPool(window_shape=(1, 1, 2, 2),
                                strides=(1, 1, 2, 2),
                                padding='SAME')
        self.residual3 = Residual(channels['layer3'])

        self.pool4 = hk.MaxPool(window_shape=(1, 1, 4, 4),
                                strides=(1, 1, 4, 4),
                                padding='SAME')

        self.fc = hk.Linear(10, with_bias=False)
        self.logit_weight = 0.125
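Example #3 relies on ConvBN and Residual helpers that are not shown. A minimal sketch of what they might look like, assuming NCHW data (to match the (1, 1, 2, 2) pooling windows) and a conv / batch-norm / ReLU pattern; the definitions in the original codebase may differ:

class ConvBN(hk.Module):
    # Conv2D + BatchNorm + ReLU (sketch). in_channels only mirrors the call
    # sites; Haiku infers input channels at call time.
    def __init__(self, in_channels, out_channels, name=None):
        super().__init__(name=name)
        del in_channels
        self.conv = hk.Conv2D(out_channels, kernel_shape=3, padding='SAME',
                              data_format='NCHW', with_bias=False)
        self.bn = hk.BatchNorm(create_scale=True, create_offset=True,
                               decay_rate=0.9, data_format='channels_first')

    def __call__(self, x, is_training=True):
        return jax.nn.relu(self.bn(self.conv(x), is_training))


class Residual(hk.Module):
    # Two ConvBN blocks with an identity skip connection (sketch).
    def __init__(self, channels, name=None):
        super().__init__(name=name)
        self.block1 = ConvBN(channels, channels)
        self.block2 = ConvBN(channels, channels)

    def __call__(self, x, is_training=True):
        return x + self.block2(self.block1(x, is_training), is_training)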
Example #4
    def lenet_fn(batch, is_training):
        """Network inspired by LeNet-5."""
        x, _ = batch

        cnn = hk.Sequential([
            hk.Conv2D(output_channels=6, kernel_shape=5, padding="SAME"),
            jax.nn.relu,
            hk.MaxPool(window_shape=3, strides=2, padding="VALID"),
            hk.Conv2D(output_channels=16, kernel_shape=5, padding="SAME"),
            jax.nn.relu,
            hk.MaxPool(window_shape=3, strides=2, padding="VALID"),
            hk.Conv2D(output_channels=120, kernel_shape=5, padding="SAME"),
            jax.nn.relu,
            hk.MaxPool(window_shape=3, strides=2, padding="VALID"),
            hk.Flatten(),
            hk.Linear(84),
            jax.nn.relu,
            hk.Linear(num_classes),
        ])
        return cnn(x)
Example #5
def lenet_fn(batch):
    """Network inspired by LeNet-5."""
    x, _ = batch
    x = x.astype(jnp.float32)
    cnn = hk.Sequential([
        hk.Conv2D(output_channels=32, kernel_shape=5, padding="SAME"),
        jax.nn.relu,
        hk.MaxPool(window_shape=3, strides=2, padding="VALID"),
        hk.Conv2D(output_channels=64, kernel_shape=5, padding="SAME"),
        jax.nn.relu,
        hk.MaxPool(window_shape=3, strides=2, padding="VALID"),
        hk.Conv2D(output_channels=128, kernel_shape=5, padding="SAME"),
        hk.MaxPool(window_shape=3, strides=2, padding="VALID"),
        hk.Flatten(),
        hk.Linear(1000),
        jax.nn.relu,
        hk.Linear(1000),
        jax.nn.relu,
        hk.Linear(10),
    ])
    return cnn(x)
Example #6
def _call_layers(
        cfg,
        inp: Tensor,
        batch_norm: bool = True,
        include_top: bool = True,
        initial_weights: Optional[hk.Params] = None,
        output_feature_maps: bool = False) -> Union[Tensor, List[Tensor]]:

    x = inp
    partial_results = []

    # Drop the final max-pool entry if we do not append the classifier
    if not include_top:
        cfg = cfg[:-1]

    i = 0
    base_name = 'vgg16/conv2_d'
    for v in cfg:
        if v == 'M':
            partial_results.append(x)
            x = hk.MaxPool(window_shape=2, strides=2, padding="VALID")(x)
        else:
            if i == 0:
                param_name = base_name
            else:
                param_name = base_name + f'_{i}'

            i += 1

            w_init = (None
                      if initial_weights is None else hk.initializers.Constant(
                          constant=initial_weights[param_name]['w']))
            b_init = (None
                      if initial_weights is None else hk.initializers.Constant(
                          constant=initial_weights[param_name]['b']))
            x = hk.Conv2D(v,
                          kernel_shape=3,
                          stride=1,
                          padding='SAME',
                          w_init=w_init,
                          b_init=b_init)(x)

            if batch_norm:
                x = hk.BatchNorm(True, True, decay_rate=0.999)(x)

            x = jax.nn.relu(x)

    partial_results.append(x)

    if not output_feature_maps:
        return partial_results[-1]
    else:
        return partial_results
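The cfg argument above is a VGG-style layer specification: integers are Conv2D output channel counts and 'M' marks a max-pool. A sketch of the standard VGG16 (configuration D) list this function presumably expects; the exact value in the original module may differ:

# Channel counts with 'M' for max-pool layers (standard VGG16 layout).
VGG16_CFG = [64, 64, 'M',
             128, 128, 'M',
             256, 256, 256, 'M',
             512, 512, 512, 'M',
             512, 512, 512, 'M']

# Usage sketch inside a transformed function:
#   features = _call_layers(VGG16_CFG, images, batch_norm=False,
#                           include_top=False, output_feature_maps=True)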
Example #7
    def __init__(self):
        super().__init__()
        # input is 28x28
        # padding=2 for same padding
        self.conv1 = hk.Conv2D(output_channels=32,
                               kernel_shape=5,
                               padding='SAME',
                               data_format='NCHW')
        self.pool1 = hk.MaxPool(window_shape=(1, 1, 2, 2),
                                strides=(1, 1, 2, 2),
                                padding='SAME')

        # feature map size is 14*14 by pooling
        # padding=2 for same padding
        self.conv2 = hk.Conv2D(output_channels=64,
                               kernel_shape=5,
                               padding='SAME',
                               data_format='NCHW')
        self.pool2 = hk.MaxPool(window_shape=(1, 1, 2, 2),
                                strides=(1, 1, 2, 2),
                                padding='SAME')
        # feature map size is 7*7 by pooling
        self.fc = hk.Linear(10)
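Example #7 shows only __init__; a matching forward pass just wires those layers together. A minimal sketch, assuming NCHW inputs of shape [batch, 1, 28, 28] and ReLU activations (the original __call__ is not shown):

    def __call__(self, x):
        x = jax.nn.relu(self.conv1(x))  # -> [B, 32, 28, 28]
        x = self.pool1(x)               # -> [B, 32, 14, 14]
        x = jax.nn.relu(self.conv2(x))  # -> [B, 64, 14, 14]
        x = self.pool2(x)               # -> [B, 64, 7, 7]
        x = hk.Flatten()(x)             # -> [B, 64 * 7 * 7]
        return self.fc(x)               # -> [B, 10]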
Example #8
  def __call__(self, x: jnp.ndarray, is_train: bool):
    x = hk.Conv2D(output_channels=32, kernel_shape=(3, 3), padding='VALID')(x)
    x = jax.nn.relu(x)
    x = hk.Conv2D(output_channels=64, kernel_shape=(3, 3), padding='VALID')(x)
    x = jax.nn.relu(x)
    x = hk.MaxPool(
        window_shape=(1, 2, 2, 1), strides=(1, 2, 2, 1), padding='VALID')(x)
    x = Dropout(rate=0.25)(x, is_train)
    x = hk.Flatten()(x)
    x = hk.Linear(128)(x)
    x = jax.nn.relu(x)
    x = Dropout(rate=0.5)(x, is_train)
    x = hk.Linear(self._num_classes)(x)
    return x
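The Dropout module used above is not one of Haiku's built-in modules; Haiku exposes dropout as the function hk.dropout. A minimal sketch of a wrapper with the same call signature, assuming is_train disables dropout at evaluation time:

class Dropout(hk.Module):
    # Thin wrapper around hk.dropout with a train/eval switch (sketch).
    def __init__(self, rate: float, name=None):
        super().__init__(name=name)
        self._rate = rate

    def __call__(self, x, is_train: bool):
        if not is_train:
            return x
        # hk.next_rng_key() requires apply() to be called with an rng key.
        return hk.dropout(hk.next_rng_key(), self._rate, x)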
Example #9
def make_downsampling_layer(
    strategy: Union[str, DownsamplingStrategy],
    output_channels: int,
) -> hk.SupportsCall:
    """Returns a sequence of modules corresponding to the desired downsampling."""
    strategy = DownsamplingStrategy(strategy)

    if strategy is DownsamplingStrategy.AVG_POOL:
        return hk.AvgPool(window_shape=(3, 3, 1),
                          strides=(2, 2, 1),
                          padding='SAME')

    elif strategy is DownsamplingStrategy.CONV:
        return hk.Sequential([
            hk.Conv2D(output_channels,
                      kernel_shape=3,
                      stride=2,
                      w_init=hk.initializers.TruncatedNormal(1e-2)),
        ])

    elif strategy is DownsamplingStrategy.LAYERNORM_RELU_CONV:
        return hk.Sequential([
            hk.LayerNorm(axis=(1, 2, 3),
                         create_scale=True,
                         create_offset=True,
                         eps=1e-6),
            jax.nn.relu,
            hk.Conv2D(output_channels,
                      kernel_shape=3,
                      stride=2,
                      w_init=hk.initializers.TruncatedNormal(1e-2)),
        ])

    elif strategy is DownsamplingStrategy.CONV_MAX:
        return hk.Sequential([
            hk.Conv2D(output_channels, kernel_shape=3, stride=1),
            hk.MaxPool(window_shape=(3, 3, 1),
                       strides=(2, 2, 1),
                       padding='SAME')
        ])
    else:
        raise ValueError(
            'Unrecognized downsampling strategy. Expected one of'
            f' {[strategy.value for strategy in DownsamplingStrategy]}'
            f' but received {strategy}.')
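make_downsampling_layer converts its strategy argument with DownsamplingStrategy(strategy), so plain strings are accepted. A sketch of the enum the branches above imply; the member values are an assumption based on the error message:

import enum

class DownsamplingStrategy(enum.Enum):
    AVG_POOL = 'avg_pool'
    CONV = 'conv'
    LAYERNORM_RELU_CONV = 'layernorm_relu_conv'
    CONV_MAX = 'conv_max'

# Usage sketch:
#   layer = make_downsampling_layer('conv_max', output_channels=64)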
Example #10
  def forward(batch, is_training):
    x, _ = batch
    batch_size = x.shape[0]
    x = hk.Embed(vocab_size=max_features, embed_dim=embedding_size)(x)
    x = hk.Conv1D(output_channels=num_filters, kernel_shape=kernel_size,
                  padding="VALID")(x)
    if use_swish:
      x = jax.nn.swish(x)
    else:
      x = jax.nn.relu(x)
    if use_maxpool:
      x = hk.MaxPool(
          window_shape=pool_size, strides=pool_size, padding='VALID',
          channel_axis=2)(x)
    x = jnp.moveaxis(x, 1, 0)  # [T, B, F]
    lstm_layer = hk.LSTM(hidden_size=cell_size)
    init_state = lstm_layer.initial_state(batch_size)
    x, state = hk.static_unroll(lstm_layer, x, init_state)
    x = x[-1]
    logits = hk.Linear(num_classes)(x)
    return logits
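The forward function above closes over several module-level hyperparameters. Purely illustrative values that would make the snippet self-contained (the original settings are not shown):

max_features = 20000   # vocabulary size for hk.Embed
embedding_size = 128   # embedding dimension
num_filters = 64       # Conv1D output channels
kernel_size = 5        # Conv1D kernel width
pool_size = 4          # MaxPool window and stride along the time axis
cell_size = 64         # LSTM hidden size
num_classes = 2        # output logits
use_swish = False
use_maxpool = True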
Example #11
    def __call__(self,
                 x: Tensor,
                 training: bool = False) -> Tuple[Tensor, Tensor]:
        params = self.ssd_initial_weights
        x = aj.zoo.VGG16(include_top=False,
                         pretrained=False,
                         initial_weights=self.backbone_initial_weights,
                         output_feature_maps=True)(x)

        conv4_3, x = x[-2:]
        conv4_3 = aj.nn.layers.L2Norm(init_fn=(
            hk.initializers.Constant(params['ssd/l2_norm']['gamma'])
            if params is not None else hk.initializers.Constant(20.)))(conv4_3)

        x = hk.MaxPool(window_shape=(1, 3, 3, 1), strides=1, padding='SAME')(x)

        # Replace fully connected by FCN
        conv_6 = hk.Conv2D(
            output_channels=1024,
            kernel_shape=3,
            stride=1,
            rate=6,
            w_init=(hk.initializers.Constant(params['ssd/conv2_d']['w'])
                    if params is not None else self.xavier_init_fn),
            b_init=(hk.initializers.Constant(params['ssd/conv2_d']['b'])
                    if params is not None else None),
            padding='SAME')(x)
        conv_6 = jax.nn.relu(conv_6)

        conv7 = hk.Conv2D(
            output_channels=1024,
            kernel_shape=1,
            stride=1,
            w_init=(hk.initializers.Constant(params['ssd/conv2_d_1']['w'])
                    if params is not None else self.xavier_init_fn),
            b_init=(hk.initializers.Constant(params['ssd/conv2_d_1']['b'])
                    if params is not None else None),
            padding='SAME')(conv_6)
        conv7 = jax.nn.relu(conv7)

        # Build additional features
        conv8_2 = self._additional_conv(conv7,
                                        256,
                                        512,
                                        stride=2,
                                        training=training,
                                        name='conv_8')
        conv9_2 = self._additional_conv(conv8_2,
                                        128,
                                        256,
                                        stride=2,
                                        training=training,
                                        name='conv_9')
        conv10_2 = self._additional_conv(conv9_2,
                                         128,
                                         256,
                                         stride=1,
                                         training=training,
                                         name='conv_10')
        conv11_2 = self._additional_conv(conv10_2,
                                         128,
                                         256,
                                         stride=1,
                                         training=training,
                                         name='conv_11')

        detection_features = [
            conv4_3, conv7, conv8_2, conv9_2, conv10_2, conv11_2
        ]

        detection_features = zip(itertools.cycle(self.k), detection_features)

        clf, reg = list(
            zip(*[
                self._head(o, k=k, name=f"fm_{i}")
                for i, (k, o) in enumerate(detection_features)
            ]))

        return np.concatenate(clf, axis=1), np.concatenate(reg, axis=1)
Example #12
    def __call__(self, inputs: types.TensorLike,
                 is_training: bool) -> jnp.ndarray:
        """Connects the module to inputs.

    Args:
      inputs: A 5-D float array of shape `[B, T, H, W, C]`.
      is_training: Whether to use training mode.

    Returns:
      A 5-D float array of shape
      `[B, new_t, new_h, new_w, sum(output_channels)]`.
    """
        # Branch 0
        branch_0 = SUnit3D(output_channels=self._output_channels[0],
                           kernel_shape=(1, 1, 1),
                           separable=False,
                           normalize_fn=self._normalize_fn,
                           self_gating_fn=self._self_gating_fn,
                           name='Branch_0_Conv2d_0a_1x1')(
                               inputs, is_training=is_training)

        # Branch 1
        branch_1 = SUnit3D(output_channels=self._output_channels[1],
                           kernel_shape=(1, 1, 1),
                           separable=False,
                           normalize_fn=self._normalize_fn,
                           self_gating_fn=None,
                           name='Branch_1_Conv2d_0a_1x1')(
                               inputs, is_training=is_training)
        branch_1 = SUnit3D(output_channels=self._output_channels[2],
                           kernel_shape=(self._temporal_kernel_size, 3, 3),
                           separable=True,
                           normalize_fn=self._normalize_fn,
                           self_gating_fn=self._self_gating_fn,
                           name='Branch_1_Conv2d_0b_3x3')(
                               branch_1, is_training=is_training)

        # Branch 2
        branch_2 = SUnit3D(output_channels=self._output_channels[3],
                           kernel_shape=(1, 1, 1),
                           separable=False,
                           normalize_fn=self._normalize_fn,
                           self_gating_fn=None,
                           name='Branch_2_Conv2d_0a_1x1')(
                               inputs, is_training=is_training)
        branch_2 = SUnit3D(output_channels=self._output_channels[4],
                           kernel_shape=(self._temporal_kernel_size, 3, 3),
                           separable=True,
                           normalize_fn=self._normalize_fn,
                           self_gating_fn=self._self_gating_fn,
                           name='Branch_2_Conv2d_0b_3x3')(
                               branch_2, is_training=is_training)

        # Branch 3
        branch_3 = hk.MaxPool(window_shape=(1, 3, 3, 3, 1),
                              strides=(1, 1, 1, 1, 1),
                              padding='SAME',
                              name='Branch_3_MaxPool_0a_3x3')(inputs)
        branch_3 = SUnit3D(output_channels=self._output_channels[5],
                           kernel_shape=(1, 1, 1),
                           separable=False,
                           normalize_fn=self._normalize_fn,
                           self_gating_fn=self._self_gating_fn,
                           name='Branch_3_Conv2d_0b_1x1')(
                               branch_3, is_training=is_training)

        return jnp.concatenate((branch_0, branch_1, branch_2, branch_3),
                               axis=4)
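The six entries of self._output_channels feed the branches in the usual Inception order (1x1, 1x1 then 3x3, 1x1 then 3x3, pool then 1x1). An illustrative setting borrowed from the common I3D Mixed_3b block; the actual values come from the enclosing network definition:

# branch_0 1x1 | branch_1 1x1 -> 3x3 | branch_2 1x1 -> 3x3 | branch_3 pool -> 1x1
output_channels = (64, 96, 128, 16, 32, 32)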
Example #13
    def __init__(
        self,
        stochastic_parameters: bool,
        dropout: bool,
        dropout_rate: float,
        linear_model: bool,
        blocks_per_group: Sequence[int],
        num_classes: int,
        bn_config: Optional[Mapping[str, float]] = None,
        resnet_v2: bool = False,
        bottleneck: bool = True,
        channels_per_group: Sequence[int] = (256, 512, 1024, 2048),
        use_projection: Sequence[bool] = (True, True, True, True),
        logits_config: Optional[Mapping[str, Any]] = None,
        name: Optional[str] = None,
        uniform_init_minval: float = -20.0,
        uniform_init_maxval: float = -18.0,
        w_init: str = "uniform",
        b_init: str = "uniform",
    ):
        """Constructs a ResNet model.

    Args:
      stochastic_parameters: TODO(nband).
      dropout: TODO(nband).
      dropout_rate: TODO(nband).
      linear_model: TODO(nband).
      blocks_per_group: A sequence of length 4 that indicates the number of
        blocks created in each group.
      num_classes: The number of classes to classify the inputs into.
      bn_config: A dictionary of two elements, ``decay_rate`` and ``eps``, to be
        passed on to the :class:`~haiku.BatchNorm` layers. By default the
        ``decay_rate`` is ``0.9`` and ``eps`` is ``1e-5``.
      resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults
        to ``False``.
      bottleneck: Whether the block should bottleneck or not. Defaults to
        ``True``.
      channels_per_group: A sequence of length 4 that indicates the number of
        channels used for each block in each group.
      use_projection: A sequence of length 4 that indicates whether each
        residual block should use projection.
      logits_config: A dictionary of keyword arguments for the logits layer.
      name: Name of the module.
      uniform_init_minval: TODO(nband).
      uniform_init_maxval: TODO(nband).
      w_init: weight init.
      b_init: bias init.
    """
        super().__init__(name=name)
        self.resnet_v2 = resnet_v2
        self.linear_model = linear_model
        self.dropout = dropout
        self.dropout_rate = dropout_rate

        if self.linear_model:
            self.stochastic_parameters_feature_mapping = False
            self.stochastic_parameters_final_layer = stochastic_parameters
        else:
            self.stochastic_parameters_feature_mapping = stochastic_parameters
            self.stochastic_parameters_final_layer = stochastic_parameters

        # TODO(nband): Maybe remove hardcoding here
        self.uniform_init_minval = uniform_init_minval
        self.uniform_init_maxval = uniform_init_maxval

        bn_config = dict(bn_config or {})
        bn_config.setdefault("decay_rate", 0.9)
        bn_config.setdefault("eps", 1e-5)
        bn_config.setdefault("create_scale", True)
        bn_config.setdefault("create_offset", True)

        logits_config = dict(logits_config or {})
        # logits_config.setdefault("w_init", jnp.zeros)
        logits_config.setdefault("name", "logits")
        logits_config.setdefault("with_bias", True)  # TR: added

        # Number of blocks in each group for ResNet.
        check_length(4, blocks_per_group, "blocks_per_group")
        check_length(4, channels_per_group, "channels_per_group")

        self.initial_conv = Conv2dStochastic(
            output_channels=64,
            kernel_shape=7,
            stride=2,
            with_bias=False,
            padding="VALID",
            name="initial_conv",
            stochastic_parameters=self.stochastic_parameters_feature_mapping,
            uniform_init_minval=self.uniform_init_minval,
            uniform_init_maxval=self.uniform_init_maxval,
            w_init=w_init,
            b_init=b_init,
        )

        if not self.resnet_v2:
            self.initial_batchnorm = hk.BatchNorm(name="batchnorm",
                                                  **bn_config)

        self.block_groups = []
        strides = (1, 2, 2, 2)
        for i in range(4):
            self.block_groups.append(
                BlockGroup(
                    stochastic_parameters=(
                        self.stochastic_parameters_feature_mapping),
                    dropout=self.dropout,
                    dropout_rate=self.dropout_rate,
                    uniform_init_minval=self.uniform_init_minval,
                    uniform_init_maxval=self.uniform_init_maxval,
                    channels=channels_per_group[i],
                    num_blocks=blocks_per_group[i],
                    stride=strides[i],
                    bn_config=bn_config,
                    bottleneck=bottleneck,
                    use_projection=use_projection[i],
                    name=f"block_group_{i}",
                    w_init=w_init,
                    b_init=b_init,
                ))

        self.max_pool = hk.MaxPool(window_shape=(1, 3, 3, 1),
                                   strides=(1, 2, 2, 1),
                                   padding="SAME")

        if self.resnet_v2:
            self.final_batchnorm = hk.BatchNorm(name="batchnorm", **bn_config)

        self.logits = DenseStochasticHaiku(
            output_size=num_classes,
            uniform_init_minval=self.uniform_init_minval,
            uniform_init_maxval=self.uniform_init_maxval,
            stochastic_parameters=self.stochastic_parameters_final_layer,
            w_init=w_init,
            b_init=b_init,
            **logits_config,
        )
Example #14
    def __call__(self,
                 inputs: types.TensorLike,
                 is_training: bool = True,
                 final_endpoint: str = 'Embeddings') -> jnp.ndarray:
        """Connects the TSM ResNetV2 module into the graph.

    Args:
      inputs: A 4-D float array of shape `[B, H, W, C]`.
      is_training: Whether to use training mode.
      final_endpoint: Up to which endpoint to run / return.

    Returns:
      Network output at location `final_endpoint`. A float array whose shape
      depends on `final_endpoint`.

    Raises:
      ValueError: If `final_endpoint` is not recognized.
    """

        # Prepare inputs for TSM.
        inputs, tsm_mode, num_frames = tsmu.prepare_inputs(inputs)
        num_frames = num_frames or self._num_frames

        self._final_endpoint = final_endpoint
        if self._final_endpoint not in self.VALID_ENDPOINTS:
            raise ValueError(f'Unknown final endpoint {self._final_endpoint}')

        # Stem convolution.
        end_point = 'tsm_resnet_stem'
        net = hk.Conv2D(output_channels=64 * self._width_mult,
                        kernel_shape=7,
                        stride=2,
                        with_bias=False,
                        name=end_point,
                        padding='SAME')(inputs)
        net = hk.MaxPool(window_shape=(1, 3, 3, 1),
                         strides=(1, 2, 2, 1),
                         padding='SAME')(net)
        if self._final_endpoint == end_point:
            return net

        # Residual block.
        for unit_id, (channels, num_blocks, stride) in enumerate(
                zip(self._channels, self._num_blocks, self._strides)):
            end_point = f'tsm_resnet_unit_{unit_id}'
            net = TSMResNetUnit(
                output_channels=channels * self._width_mult,
                num_blocks=num_blocks,
                stride=stride,
                normalize_fn=self._normalize_fn,
                channel_shift_fraction=self._channel_shift_fraction,
                num_frames=num_frames,
                tsm_mode=tsm_mode,
                name=end_point)(net, is_training=is_training)
            if self._final_endpoint == end_point:
                return net

        if self._normalize_fn is not None:
            net = self._normalize_fn(net, is_training=is_training)
        net = jax.nn.relu(net)

        end_point = 'last_conv'
        if self._final_endpoint == end_point:
            return net
        net = jnp.mean(net, axis=(1, 2))
        # Prepare embedding outputs for TSM (temporal average of features).
        net = tsmu.prepare_outputs(net, tsm_mode, num_frames)
        assert self._final_endpoint == 'Embeddings'
        return net
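The method above checks final_endpoint against a VALID_ENDPOINTS class attribute that is not shown. Based on the endpoints the method can return, it presumably looks like the following sketch (the number of unit endpoints depends on self._channels):

    VALID_ENDPOINTS = (
        'tsm_resnet_stem',
        'tsm_resnet_unit_0',
        'tsm_resnet_unit_1',
        'tsm_resnet_unit_2',
        'tsm_resnet_unit_3',
        'last_conv',
        'Embeddings',
    )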
Example #15
def augmented_vgg19(fp: str,
                    content_image: jnp.ndarray,
                    style_image: jnp.ndarray,
                    mean: jnp.ndarray,
                    std: jnp.ndarray,
                    content_layers: List[str] = None,
                    style_layers: List[str] = None,
                    pooling: str = "avg") -> hk.Sequential:
    """Build a VGG19 network augmented by content and style loss layers."""
    pooling = pooling.lower()
    if pooling not in ["avg", "max"]:
        raise ValueError("Pooling method not recognized. Options are: "
                         "\"avg\", \"max\".")

    params = get_model_params(fp=fp)

    # prepend a normalization layer
    layers = [Normalization(content_image, mean, std, "norm")]

    # tracks number of conv layers
    n = 0

    # desired layers at which to compute style/content losses
    content_layers = content_layers or []
    style_layers = style_layers or []

    model = hk.Sequential(layers=layers)

    for k, p_dict in params.items():
        if "pool" in k:
            if pooling == "avg":
                # exactly as many pools as convolutions
                layers.append(
                    hk.AvgPool(window_shape=2,
                               strides=2,
                               padding="VALID",
                               channel_axis=1,
                               name=f"avg_pool_{n}"))
            else:
                layers.append(
                    hk.MaxPool(window_shape=2,
                               strides=2,
                               padding="VALID",
                               channel_axis=1,
                               name=f"max_pool_{n}"))
        elif "conv" in k:
            n += 1
            name = f"conv_{n}"

            kernel_h, kernel_w, in_ch, out_ch = p_dict["w"].shape

            # VGG only has square conv kernels
            assert kernel_w == kernel_h, "VGG19 only has square conv kernels"
            kernel_shape = kernel_h

            layers.append(
                hk.Conv2D(output_channels=out_ch,
                          kernel_shape=kernel_shape,
                          stride=1,
                          padding="SAME",
                          data_format="NCHW",
                          w_init=hk.initializers.Constant(p_dict["w"]),
                          b_init=hk.initializers.Constant(p_dict["b"]),
                          name=name))

            if name in style_layers:
                model.layers = tuple(layers)
                style_target = model(style_image, is_training=False)
                layers.append(
                    StyleLoss(target=style_target, name=f"style_loss_{n}"))

            if name in content_layers:
                model.layers = tuple(layers)
                content_target = model(content_image, is_training=False)
                layers.append(
                    ContentLoss(target=content_target,
                                name=f"content_loss_{n}"))

            layers.append(jax.nn.relu)

    # walk backwards to find the index of the last loss layer (reuses n)
    for n in range(len(layers) - 1, -1, -1):
        if isinstance(layers[n], (StyleLoss, ContentLoss)):
            break

    # truncate the list just after the last style/content loss layer
    layers = layers[:(n + 1)]

    model.layers = tuple(layers)

    return model