def haiku_model(x, dense_kernel_size=64, max_conv_size=256, num_classes=10):
    """Apply a small strided-conv classifier to ``x``.

    Four 3x3, stride-2 convolutions (32, 64, 128 and 256 channels, each capped
    at ``max_conv_size``) with GELU activations are followed by
    ``global_spatial_mean_pooling``, a GELU-activated dense layer and a final
    ``logits`` layer.

    Args:
      x: Input batch fed to the first convolution.
      dense_kernel_size: Width of the hidden dense layer.
      max_conv_size: Upper bound applied to every conv's channel count.
      num_classes: Width of the output ``logits`` layer.

    Returns:
      The output of the ``logits`` linear layer.
    """
    conv_stack = []
    for idx, width in enumerate([32, 64, 128, 256]):
        capped = min(width, max_conv_size)
        conv_stack += [
            hk.Conv2D(output_channels=capped, kernel_shape=3, stride=2,
                      name="conv%d_%d" % (idx, capped)),
            jax.nn.gelu,
        ]
    head = [
        global_spatial_mean_pooling,
        hk.Linear(dense_kernel_size, name="dense_%d" % dense_kernel_size),
        jax.nn.gelu,
        hk.Linear(num_classes, name='logits'),
    ]
    return hk.Sequential(conv_stack + head)(x)
def _resnet_layer(inputs, num_filters, normalization_layer, kernel_size=3,
                  strides=1, activation=lambda x: x, use_bias=True,
                  is_training=True):
    """Conv -> normalization -> activation building block.

    Args:
      inputs: Tensor fed to the convolution.
      num_filters: Output channels of the convolution.
      normalization_layer: Zero-argument factory; the object it returns is
        called as ``norm(x, is_training=is_training)``.
      kernel_size: Convolution kernel size.
      strides: Convolution stride.
      activation: Callable applied last (identity by default).
      use_bias: Forwarded to ``hk.Conv2D`` as ``with_bias``.
      is_training: Forwarded to the normalization call.

    Returns:
      ``activation(norm(conv(inputs)))``.
    """
    conv = hk.Conv2D(num_filters, kernel_size, stride=strides, padding="same",
                     w_init=he_normal, with_bias=use_bias)
    out = conv(inputs)
    out = normalization_layer()(out, is_training=is_training)
    return activation(out)
def __call__(self, image, debug=False):
    """Classify ``image`` with a stack of bias-free 3x3 convolutions.

    Six conv + ReLU stages (strides 1, 2, 1, 2, 1, 2; the first named
    ``misc``, the rest ``conv``) are followed by a flatten and a linear layer
    with ``self.n_classes`` outputs. ``debug_layer(debug)`` is inserted after
    every stage and after the linear layer; per the original docstring it
    prints activation shapes when ``debug`` is true.

    Args:
      image: Input batch.
      debug: Passed to every ``debug_layer`` call.

    Returns:
      The output of the final linear layer.
    """
    # Construction order is kept identical to the original so Haiku's
    # auto-uniquified module names (misc, conv, conv_1, ..., misc_1) match.
    modules = []
    for stage_stride, stage_name in zip([1, 2, 1, 2, 1, 2],
                                        ['misc'] + ['conv'] * 5):
        modules.append(hk.Conv2D(self.n_channels,
                                 kernel_shape=3,
                                 w_init=self.initializer,
                                 b_init=self.initializer,
                                 with_bias=False,
                                 stride=stage_stride,
                                 name=stage_name))
        modules.append(jax.nn.relu)
        modules.append(debug_layer(debug))
    modules.append(hk.Flatten())
    modules.append(hk.Linear(self.n_classes,
                             w_init=self.initializer,
                             b_init=self.initializer,
                             name='misc'))
    modules.append(debug_layer(debug))
    return hk.Sequential(modules)(image)
def __init__(
    self,
    num_layers: int = 1,
    num_channels: int = 64,
    use_batchnorm: bool = True,
    bn_config: Optional[Mapping[str, float]] = None,
    name: Optional[str] = None,
):
    """Constructs a Conv2DDownsample model.

    Each of the ``num_layers`` stages gets a bias-free 7x7, stride-2,
    SAME-padded conv named ``conv`` and, when ``use_batchnorm`` is set, a
    ``hk.BatchNorm`` named ``batchnorm``. The original docs describe a stage
    as conv -> max_pool; no pooling module is created here.

    Args:
      num_layers: Number of downsampling stages.
      num_channels: Output channels of every conv.
      use_batchnorm: Whether to create a BatchNorm per stage.
      bn_config: Optional ``hk.BatchNorm`` keyword arguments. Missing keys
        default to ``decay_rate=0.9``, ``eps=1e-5``, ``create_scale=True`` and
        ``create_offset=True``.
      name: Name of the module.
    """
    super().__init__(name=name)
    self._num_layers = num_layers
    self._use_batchnorm = use_batchnorm

    # Defaults first, then any caller-provided values override them.
    config = {
        'decay_rate': 0.9,
        'eps': 1e-5,
        'create_scale': True,
        'create_offset': True,
    }
    config.update(bn_config or {})

    def _make_stage():
        # One {conv, batchnorm-or-None} record; the conv is built first.
        conv = hk.Conv2D(output_channels=num_channels,
                         kernel_shape=7,
                         stride=2,
                         with_bias=False,
                         padding='SAME',
                         name='conv')
        norm = hk.BatchNorm(name='batchnorm', **config) if use_batchnorm else None
        return dict(conv=conv, batchnorm=norm)

    self.layers = [_make_stage() for _ in range(self._num_layers)]
def __init__(self, num_blocks=(5, 5, 5), num_classes=10):
    """Builds a three-stage CIFAR-style ResNet for NCHW inputs.

    Args:
      num_blocks: Number of blocks in each of the three ``MultiBlock`` stages.
        Any indexable sequence with at least three ints works. The default is
        an immutable tuple, so it is never shared as a mutable object across
        instances.
      num_classes: Output width of the final linear layer.
    """
    super().__init__()
    self.conv1 = hk.Conv2D(output_channels=16,
                           kernel_shape=3,
                           stride=1,
                           with_bias=False,
                           padding='SAME',
                           data_format='NCHW')
    # 'NC...' places the channel axis at position 1, matching the NCHW conv.
    self.bn1 = hk.BatchNorm(create_scale=True, create_offset=True,
                            decay_rate=0.9, data_format='NC...')

    def _stage_strides(first, count):
        # ``first`` for the stage's first block, 1 for the rest. This is
        # presumably MultiBlock's per-block stride list; confirm there.
        return [first] + [1] * (count - 1)

    self.layer1 = MultiBlock(16, 16, _stage_strides(1, num_blocks[0]))
    self.layer2 = MultiBlock(16, 32, _stage_strides(2, num_blocks[1]))
    self.layer3 = MultiBlock(32, 64, _stage_strides(2, num_blocks[2]))
    self.linear = hk.Linear(num_classes)
def __init__(self,
             num_classes: int = 10,
             depth: int = 28,
             width: int = 10,
             activation: str = 'relu',
             norm_args: Optional[Dict[str, Any]] = None,
             name: Optional[str] = None):
    """Builds a Wide ResNet.

    Args:
      num_classes: Width of the ``logits`` layer (zero-initialized weights).
      depth: Network depth. Must satisfy ``(depth - 4) % 6 == 0``; each of the
        three stages then has ``(depth - 4) // 6`` blocks.
      width: Multiplier applied to the base filter sizes 16, 32 and 64.
      activation: Name of a function in ``jax.nn`` (looked up with
        ``getattr``, so an unknown name raises ``AttributeError``).
      norm_args: Keyword arguments for ``hk.BatchNorm`` and every block.
        Defaults to scale and offset enabled with ``decay_rate=.99``.
      name: Name of the module.

    Raises:
      ValueError: If ``depth`` is not of the form 6n+4.
    """
    super(WideResNet, self).__init__(name=name)
    if (depth - 4) % 6 != 0:
        raise ValueError('depth should be 6n+4.')
    self._activation = getattr(jax.nn, activation)

    if norm_args is None:
        norm_args = {
            'create_offset': True,
            'create_scale': True,
            'decay_rate': .99,
        }

    self._conv = hk.Conv2D(output_channels=16, kernel_shape=(3, 3), stride=1,
                           with_bias=False, name='init_conv')  # pytype: disable=not-callable
    self._bn = hk.BatchNorm(name='batchnorm', **norm_args)
    self._linear = hk.Linear(num_classes, w_init=jnp.zeros, name='logits')

    blocks_per_layer = (depth - 4) // 6
    # Only the first block of each non-initial stage downsamples; the first
    # block of every stage uses a projection shortcut.
    self._blocks = [
        [
            _WideResNetBlock(
                num_filters=width * base_filters,
                stride=2 if (stage_idx != 0 and block_idx == 0) else 1,
                projection_shortcut=(block_idx == 0),
                activation=self._activation,
                norm_args=norm_args,
                name='resnet_lay_{}_block_{}'.format(stage_idx, block_idx))
            for block_idx in range(blocks_per_layer)
        ]
        for stage_idx, base_filters in enumerate([16, 32, 64])
    ]
def __call__(self, x: jnp.ndarray, is_train: bool):
    """Run the three-stage VALID-padded conv net and return logits.

    Each stage has two 3x3 VALID convolutions (64, 96, then 128 channels),
    each followed by the activation, then a 2x2 stride-1 VALID average pool
    and dropout (rates 0.2, 0.3, 0.4). The head flattens, applies a 128-unit
    linear layer, the activation and dropout(0.5), then a linear layer with
    ``self._num_classes`` outputs.

    Args:
      x: Input batch.
      is_train: Forwarded to every ``Dropout`` call.

    Returns:
      Output of the final linear layer.
    """

    def activate(h):
        # ``self._act_fn`` receives the extra interval argument only when one
        # is configured (truthy), exactly as in the original inline code.
        if self._interval:
            return self._act_fn(h, self._interval)
        return self._act_fn(h)

    # NOTE(review): every Dropout gets the same ``self._seed``. If Dropout
    # derives its RNG only from that seed, all layers share a mask pattern;
    # confirm this is intended.
    for channels, drop_rate in ((64, 0.2), (96, 0.3), (128, 0.4)):
        for _ in range(2):
            x = hk.Conv2D(output_channels=channels, kernel_shape=3,
                          padding='VALID')(x)
            x = activate(x)
        x = hk.AvgPool(window_shape=2, strides=1, padding='VALID')(x)
        x = Dropout(drop_rate, self._seed)(x, is_train)

    x = hk.Flatten()(x)
    x = hk.Linear(128)(x)
    x = activate(x)
    x = Dropout(0.5, self._seed)(x, is_train)
    x = hk.Linear(self._num_classes)(x)
    return x
def __init__(self):
    """Create the layers of a three-block VGG-style CNN with a 10-way head.

    Each block has two 3x3 convolutions (32, 64, then 128 channels) with
    ``Vscaling(2.0)`` weight init, each paired with a BatchNorm (scale and
    offset enabled, decay rate 0.99). The head is a 128-unit linear layer
    with its own BatchNorm, then a 10-unit linear layer initialized with
    ``Vscaling(1.0, "fan_avg")``.
    """
    super().__init__()

    def conv(channels):
        return hk.Conv2D(channels, 3, w_init=Vscaling(2.0))

    def batchnorm():
        return hk.BatchNorm(create_scale=True, create_offset=True,
                            decay_rate=0.99)

    # Modules are created in the original order so Haiku's auto-generated
    # names (conv2_d, conv2_d_1, batch_norm, ...) are unchanged.

    # Block 1
    self.conv1_1 = conv(32)
    self.bn1_1 = batchnorm()
    self.conv1_2 = conv(32)
    self.bn1_2 = batchnorm()

    # Block 2
    self.conv2_1 = conv(64)
    self.bn2_1 = batchnorm()
    self.conv2_2 = conv(64)
    self.bn2_2 = batchnorm()

    # Block 3
    self.conv3_1 = conv(128)
    self.bn3_1 = batchnorm()
    self.conv3_2 = conv(128)
    self.bn3_2 = batchnorm()

    # Linear head
    self.lin1 = hk.Linear(128, w_init=Vscaling(2.0))
    self.bn4 = batchnorm()
    self.lin2 = hk.Linear(10, w_init=Vscaling(1.0, "fan_avg"))
def __init__(self):
    """Create conv stacks, a flatten, two linear+BatchNorm pairs and a 10-way head.

    ``conv_block`` is three 3x3, stride-3 convs (32, 32, 64 channels) with
    ReLU. ``conv_res_block`` is the same channel pattern with 1x1, stride-1
    kernels. ``lin_res_block`` pairs 128- and 256-unit linear layers with
    named BatchNorms (scale and offset enabled, decay rate 0.999).
    """
    super().__init__()
    bn_config = {
        'create_scale': True,
        'create_offset': True,
        'decay_rate': 0.999,
    }

    def relu_conv_stack(kernel, stride):
        # Three conv+ReLU pairs; the convs are built before the Sequential,
        # matching the original construction order.
        stack = []
        for channels in (32, 32, 64):
            stack += [hk.Conv2D(channels, kernel, stride=stride, rate=1),
                      jax.nn.relu]
        return hk.Sequential(stack)

    self.conv_block = relu_conv_stack((3, 3), 3)
    self.conv_res_block = relu_conv_stack((1, 1), 1)
    self.reshape_mod = hk.Flatten()
    self.lin_res_block = [
        (hk.Linear(units), hk.BatchNorm(name=bn_name, **bn_config))
        for units, bn_name in ((128, 'lin_batchnorm_0'),
                               (256, 'lin_batchnorm_1'))
    ]
    self.final_linear = hk.Linear(10)
def __init__(
    self,
    blocks_per_group: Sequence[int],
    num_classes: Optional[int] = None,
    bn_config: Optional[Mapping[str, float]] = None,
    resnet_v2: bool = False,
    bottleneck: bool = True,
    channels_per_group: Sequence[int] = (256, 512, 1024, 2048),
    use_projection: Sequence[bool] = (True, True, True, True),
    width_multiplier: int = 1,
    name: Optional[str] = None,
):
    """Constructs a ResNet model.

    Args:
      blocks_per_group: Length-4 sequence with the number of blocks in each
        of the four block groups.
      num_classes: Output width of the ``logits`` linear layer.
      bn_config: Optional ``hk.BatchNorm`` keyword arguments (for example
        ``decay_rate``, ``eps`` or ``cross_replica_axis``). Missing
        ``decay_rate``/``eps`` default to ``0.9``/``1e-5``, and
        ``create_scale``/``create_offset`` default to ``True``.
      resnet_v2: Build the v2 variant instead of v1. v1 creates an initial
        BatchNorm after the stem conv; v2 creates a final BatchNorm instead.
      bottleneck: Whether the blocks should bottleneck.
      channels_per_group: Length-4 sequence of per-group channel counts,
        before ``width_multiplier`` is applied.
      use_projection: Per-group flags selecting a projection shortcut.
      width_multiplier: Integer scale applied to the stem and group widths.
      name: Name of the module.
    """
    super().__init__(name=name)
    self.resnet_v2 = resnet_v2

    bn_config = dict(bn_config or {})
    for key, default in (('decay_rate', 0.9), ('eps', 1e-5),
                         ('create_scale', True), ('create_offset', True)):
        bn_config.setdefault(key, default)

    # Number of blocks / channels in each of the four groups.
    check_length(4, blocks_per_group, 'blocks_per_group')
    check_length(4, channels_per_group, 'channels_per_group')

    self.initial_conv = hk.Conv2D(output_channels=64 * width_multiplier,
                                  kernel_shape=7,
                                  stride=2,
                                  with_bias=False,
                                  padding='SAME',
                                  name='initial_conv')
    if not self.resnet_v2:
        self.initial_batchnorm = hk.BatchNorm(name='initial_batchnorm',
                                              **bn_config)

    self.block_groups = [
        hk.nets.ResNet.BlockGroup(
            channels=width_multiplier * channels_per_group[i],
            num_blocks=blocks_per_group[i],
            stride=group_stride,
            bn_config=bn_config,
            resnet_v2=resnet_v2,
            bottleneck=bottleneck,
            use_projection=use_projection[i],
            name='block_group_%d' % (i))
        for i, group_stride in enumerate((1, 2, 2, 2))
    ]

    if self.resnet_v2:
        self.final_batchnorm = hk.BatchNorm(name='final_batchnorm', **bn_config)
    # NOTE(review): ``num_classes`` defaults to None, yet the logits layer is
    # always built; confirm callers always pass an integer.
    self.logits = hk.Linear(num_classes, w_init=jnp.zeros, name='logits')
def conv2d_model(inp):
    """Apply one 2x2, stride-1, VALID-padded, biased conv with a single output channel."""
    layer = hk.Conv2D(output_channels=1,
                      kernel_shape=(2, 2),
                      padding='VALID',
                      stride=1,
                      with_bias=True)
    return layer(inp)
def __init__(self,
             depth=50,
             num_classes: Optional[int] = 1000,
             width_mult: int = 1,
             normalize_fn: Optional[types.NormalizeFn] = None,
             name: Optional[Text] = None,
             remat: bool = False):
    """Creates ResNetV2 Haiku module.

    Args:
      depth: depth of the desired ResNet (18, 34, 50, 101, 152 or 200).
        Depths of 50 and above use ``BottleneckBlock`` with channels
        (256, 512, 1024, 2048); 18 and 34 use ``BasicBlock`` with channels
        (64, 128, 256, 512).
      num_classes: (int) Number of outputs in final layer. If None, no
        ``_logits_layer`` is created (the original docs state the output
        embedding is then returned).
      width_mult: multiplier for channel width.
      normalize_fn: normalization function, see helpers/
      name: Name of the module.
      remat: Whether to rematerialize intermediate activations (saves memory).

    Raises:
      ValueError: If ``depth`` is not one of the supported depths.
    """
    super(ResNetV2, self).__init__(name=name)
    self._normalize_fn = normalize_fn
    self._num_classes = num_classes
    self._width_mult = width_mult

    self._strides = [1, 2, 2, 2]
    # Blocks per group for each supported depth. Key order matters for the
    # error message below.
    num_blocks = {
        18: [2, 2, 2, 2],
        34: [3, 4, 6, 3],
        50: [3, 4, 6, 3],
        101: [3, 4, 23, 3],
        152: [3, 8, 36, 3],
        200: [3, 24, 36, 3],
    }
    if depth not in num_blocks:
        raise ValueError(
            f'`depth` should be in {list(num_blocks.keys())} ({depth} given).'
        )
    self._num_blocks = num_blocks[depth]

    if depth >= 50:
        self._block_module = BottleneckBlock
        self._channels = [256, 512, 1024, 2048]
    else:
        self._block_module = BasicBlock
        self._channels = [64, 128, 256, 512]

    self._initial_conv = hk.Conv2D(output_channels=64 * self._width_mult,
                                   kernel_shape=7,
                                   stride=2,
                                   with_bias=False,
                                   padding='SAME',
                                   name='initial_conv')
    if remat:
        self._initial_conv = hk.remat(self._initial_conv)

    self._block_groups = []
    for i in range(4):
        self._block_groups.append(
            ResNetUnit(channels=self._channels[i] * self._width_mult,
                       num_blocks=self._num_blocks[i],
                       block_module=self._block_module,
                       stride=self._strides[i],
                       normalize_fn=self._normalize_fn,
                       name='block_group_%d' % i,
                       remat=remat))

    if num_classes is not None:
        self._logits_layer = hk.Linear(output_size=num_classes,
                                       w_init=jnp.zeros,
                                       name='logits')
def __init__(self, **kwargs):
    """Wrap a single ``hk.Conv2D``.

    Args:
      **kwargs: Forwarded verbatim to ``hk.Conv2D``.
    """
    super(ConcatConv2D, self).__init__()
    conv_kwargs = kwargs
    self._layer = hk.Conv2D(**conv_kwargs)
def __init__(self,
             blocks_per_group,
             bn_config,
             bottleneck,
             channels_per_group,
             use_projection,
             strides,
             n_output_channels=1,
             use_bn=True,
             pad_crop=False,
             name=None):
    """Constructs a Residual UNet model based on a traditional ResNet.

    The number of encoder (and decoder) stages is ``len(strides)``. Every
    per-stage sequence argument must provide at least that many entries.

    Args:
      blocks_per_group: Per-stage number of blocks in each group.
      bn_config: Mapping of ``hk.BatchNorm`` keyword arguments, or None.
        Missing ``decay_rate``/``eps`` default to ``0.9``/``1e-5``, and
        ``create_scale``/``create_offset`` default to ``True``.
      bottleneck: Whether the blocks should bottleneck.
      channels_per_group: Per-stage number of channels used in each group.
      use_projection: Per-stage flags selecting a projection shortcut.
      strides: Per-stage stride forwarded to each down and up
        ``BlockGroup``; its length sets the number of stages.
      n_output_channels: The number of output channels, for example to
        change in the case of a complex denoising. Defaults to 1.
      use_bn: Whether the network should use batch normalisation. Defaults
        to ``True``.
      pad_crop: Whether to use cropping/padding to make sure the images can
        be downsampled and upsampled correctly. Defaults to ``False``.
      name: Name of the module.
    """
    super().__init__(name=name)
    # Always the v1 layout; kept as an attribute in case other methods read it.
    self.resnet_v2 = False
    self.use_bn = use_bn
    self.pad_crop = pad_crop
    self.n_output_channels = n_output_channels
    bn_config = dict(bn_config or {})
    bn_config.setdefault("decay_rate", 0.9)
    bn_config.setdefault("eps", 1e-5)
    bn_config.setdefault("create_scale", True)
    bn_config.setdefault("create_offset", True)
    self.strides = strides
    bl = len(self.strides)

    # Number of blocks / channels in each group, one entry per stage.
    check_length(bl, blocks_per_group, "blocks_per_group")
    check_length(bl, channels_per_group, "channels_per_group")
    self.upsampling = upsample
    self.pooling = hk.AvgPool(window_shape=2, strides=2, padding='SAME')

    # A bias is only needed when no BatchNorm offset follows the conv.
    self.initial_conv = hk.Conv2D(output_channels=32,
                                  kernel_shape=7,
                                  stride=1,
                                  with_bias=not self.use_bn,
                                  padding="SAME",
                                  name="initial_conv")
    if not self.resnet_v2 and self.use_bn:
        self.initial_batchnorm = hk.BatchNorm(name="initial_batchnorm",
                                              **bn_config)

    self.block_groups = []
    self.up_block_groups = []
    for i in range(bl):
        self.block_groups.append(
            BlockGroup(channels=channels_per_group[i],
                       num_blocks=blocks_per_group[i],
                       stride=strides[i],
                       bn_config=bn_config,
                       bottleneck=bottleneck,
                       use_projection=use_projection[i],
                       transpose=False,
                       use_bn=self.use_bn,
                       name="block_group_%d" % (i)))

    for i in range(bl):
        self.up_block_groups.append(
            BlockGroup(channels=channels_per_group[i],
                       num_blocks=blocks_per_group[i],
                       stride=strides[i],
                       bn_config=bn_config,
                       bottleneck=bottleneck,
                       use_projection=use_projection[i],
                       transpose=True,
                       use_bn=self.use_bn,
                       name="up_block_group_%d" % (i)))

    # The former ``if self.resnet_v2 and self.use_bn`` branch that built a
    # ``final_batchnorm`` was removed: ``self.resnet_v2`` is hard-coded to
    # False above, so it could never run.
    self.final_conv = hk.Conv2D(
        output_channels=self.n_output_channels,
        kernel_shape=5,
        stride=1,
        padding='SAME',
        name='final_conv',
    )
def __init__(self,
             channels: int,
             stride: Union[int, Sequence[int]],
             use_projection: bool,
             bn_config: Mapping[str, float],
             bottleneck: bool,
             transpose: bool = False,
             name: Optional[str] = None):
    """Build the layers of a (optionally transposed) residual block.

    Args:
      channels: Output channels of the block.
      stride: Stride applied by the shortcut conv and by ``conv_1``.
      use_projection: Whether to create a 1x1 projection shortcut with its
        own BatchNorm.
      bn_config: ``hk.BatchNorm`` keyword arguments. Missing keys default to
        ``create_scale=True``, ``create_offset=True`` and
        ``decay_rate=0.999``.
      bottleneck: If true, use 1x1 -> 3x3 -> 1x1 convs at ``channels // 4``
        inner width, with the last BatchNorm's scale initialized to zeros.
        Otherwise use two 3x3 convs at full width.
      transpose: If true, the strided convs (shortcut and ``conv_1``) are
        ``hk.Conv2DTranspose`` instead of ``hk.Conv2D``.
      name: Name of the module.
    """
    super().__init__(name=name)
    self.use_projection = use_projection

    bn_config = dict(bn_config)
    for key, value in (("create_scale", True), ("create_offset", True),
                       ("decay_rate", 0.999)):
        bn_config.setdefault(key, value)

    strided_conv = hk.Conv2DTranspose if transpose else hk.Conv2D

    if self.use_projection:
        self.proj_conv = strided_conv(output_channels=channels,
                                      kernel_shape=1,
                                      stride=stride,
                                      with_bias=False,
                                      padding="SAME",
                                      name="shortcut_conv")
        self.proj_batchnorm = hk.BatchNorm(name="shortcut_batchnorm",
                                           **bn_config)

    inner_channels = channels // (4 if bottleneck else 1)

    conv_0 = hk.Conv2D(output_channels=inner_channels,
                       kernel_shape=1 if bottleneck else 3,
                       stride=1,
                       with_bias=False,
                       padding="SAME",
                       name="conv_0")
    bn_0 = hk.BatchNorm(name="batchnorm_0", **bn_config)

    conv_1 = strided_conv(output_channels=inner_channels,
                          kernel_shape=3,
                          stride=stride,
                          with_bias=False,
                          padding="SAME",
                          name="conv_1")
    bn_1 = hk.BatchNorm(name="batchnorm_1", **bn_config)

    stages = [(conv_0, bn_0), (conv_1, bn_1)]
    if bottleneck:
        conv_2 = hk.Conv2D(output_channels=channels,
                           kernel_shape=1,
                           stride=1,
                           with_bias=False,
                           padding="SAME",
                           name="conv_2")
        bn_2 = hk.BatchNorm(name="batchnorm_2", scale_init=jnp.zeros,
                            **bn_config)
        stages.append((conv_2, bn_2))
    self.layers = tuple(stages)
def __call__(self, x: Tensor, training: bool = False) -> Tuple[Tensor, Tensor]:
    """Run the SSD backbone, extra feature layers and detection heads.

    Args:
      x: Input batch passed to the VGG16 backbone.
      training: Forwarded to every ``self._additional_conv`` call.

    Returns:
      ``(clf, reg)``: the per-feature-map head outputs, each concatenated
      along axis 1.
    """
    # jax.numpy is used for the final concatenation. ``np.concatenate`` forces
    # JAX tracers to NumPy and fails under jit/grad. Imported locally so the
    # fix does not depend on module-level import names.
    import jax.numpy as jnp

    params = self.ssd_initial_weights

    def pretrained(module_key, leaf, fallback):
        # Constant initializer from the stored SSD weights when present,
        # otherwise ``fallback``.
        if params is None:
            return fallback
        return hk.initializers.Constant(params[module_key][leaf])

    feature_maps = aj.zoo.VGG16(include_top=False,
                                pretrained=False,
                                initial_weights=self.backbone_initial_weights,
                                output_feature_maps=True)(x)
    conv4_3, x = feature_maps[-2:]
    conv4_3 = aj.nn.layers.L2Norm(init_fn=pretrained(
        'ssd/l2_norm', 'gamma', hk.initializers.Constant(20.)))(conv4_3)

    x = hk.MaxPool(window_shape=(1, 3, 3, 1), strides=1, padding='SAME')(x)

    # Replace the fully connected layers with convolutions. The two convs are
    # left unnamed and in this order; their weight keys ('ssd/conv2_d',
    # 'ssd/conv2_d_1') appear to match Haiku's auto-generated names.
    conv6 = hk.Conv2D(
        output_channels=1024,
        kernel_shape=3,
        stride=1,
        rate=6,
        w_init=pretrained('ssd/conv2_d', 'w', self.xavier_init_fn),
        b_init=pretrained('ssd/conv2_d', 'b', None),
        padding='SAME')(x)
    conv6 = jax.nn.relu(conv6)

    conv7 = hk.Conv2D(
        output_channels=1024,
        kernel_shape=1,
        stride=1,
        w_init=pretrained('ssd/conv2_d_1', 'w', self.xavier_init_fn),
        b_init=pretrained('ssd/conv2_d_1', 'b', None),
        padding='SAME')(conv6)
    conv7 = jax.nn.relu(conv7)

    # Build additional features
    conv8_2 = self._additional_conv(conv7, 256, 512, stride=2,
                                    training=training, name='conv_8')
    conv9_2 = self._additional_conv(conv8_2, 128, 256, stride=2,
                                    training=training, name='conv_9')
    conv10_2 = self._additional_conv(conv9_2, 128, 256, stride=1,
                                     training=training, name='conv_10')
    conv11_2 = self._additional_conv(conv10_2, 128, 256, stride=1,
                                     training=training, name='conv_11')

    detection_features = [conv4_3, conv7, conv8_2, conv9_2, conv10_2, conv11_2]
    # ``self.k`` is cycled so each feature map gets its own k value.
    head_outputs = [
        self._head(feature, k=k, name=f"fm_{i}")
        for i, (k, feature) in enumerate(
            zip(itertools.cycle(self.k), detection_features))
    ]
    clf, reg = zip(*head_outputs)
    return jnp.concatenate(clf, axis=1), jnp.concatenate(reg, axis=1)
def __call__(self,
             inputs: types.TensorLike,
             is_training: bool = True) -> jnp.ndarray:
    """Connects the ResNetBlock module into the graph.

    Pre-activation bottleneck block: (optional normalization + ReLU) precedes
    each of the three convolutions (1x1, 3x3 strided, 1x1). The result is
    added to either the raw ``inputs`` or a 1x1 projection of the
    pre-activated input.

    Args:
      inputs: A 4-D float array (documented as `[B, H, W, C]`).
      is_training: Whether to use training mode; forwarded to
        ``self._normalize_fn``.

    Returns:
      A 4-D float array with ``self._output_channels`` channels, spatially
      strided by ``self._stride``.
    """

    def norm_relu(t):
        # ResNet V2 pre-activation: normalize (if configured), then ReLU.
        if self._normalize_fn is not None:
            t = self._normalize_fn(t, is_training=is_training)
        return jax.nn.relu(t)

    preact = norm_relu(inputs)

    shortcut = inputs
    if self._use_projection:
        shortcut = hk.Conv2D(output_channels=self._output_channels,
                             kernel_shape=1,
                             stride=self._stride,
                             with_bias=False,
                             padding='SAME',
                             name='shortcut_conv')(preact)

    # Optional Temporal Shift Module, applied only to the residual path.
    if self._channel_shift_fraction != 0:
        preact = tsmu.apply_temporal_shift(
            preact,
            tsm_mode=self._tsm_mode,
            num_frames=self._num_frames,
            channel_shift_fraction=self._channel_shift_fraction)

    residual = hk.Conv2D(self._bottleneck_channels,
                         kernel_shape=1,
                         stride=1,
                         with_bias=False,
                         padding='SAME',
                         name='conv_0')(preact)

    residual = norm_relu(residual)
    residual = hk.Conv2D(output_channels=self._bottleneck_channels,
                         kernel_shape=3,
                         stride=self._stride,
                         with_bias=False,
                         padding='SAME',
                         name='conv_1')(residual)

    residual = norm_relu(residual)
    residual = hk.Conv2D(output_channels=self._output_channels,
                         kernel_shape=1,
                         stride=1,
                         with_bias=False,
                         padding='SAME',
                         name='conv_2')(residual)

    # No block multiplier is applied to the residual.
    return shortcut + residual
def __call__(self,
             inputs: types.TensorLike,
             is_training: bool = True,
             final_endpoint: str = 'Embeddings') -> jnp.ndarray:
    """Connects the TSM ResNetV2 module into the graph.

    Args:
      inputs: A 4-D float array of shape `[B, H, W, C]`.
      is_training: Whether to use training mode.
      final_endpoint: Endpoint to stop at and return: 'tsm_resnet_stem',
        'tsm_resnet_unit_<i>', 'last_conv' or 'Embeddings' (must be in
        ``self.VALID_ENDPOINTS``).

    Returns:
      Network output at ``final_endpoint``. For 'Embeddings' this is the
      spatially averaged features after ``tsmu.prepare_outputs``.

    Raises:
      ValueError: If `final_endpoint` is not recognized.
    """
    inputs, tsm_mode, num_frames = tsmu.prepare_inputs(inputs)
    num_frames = num_frames or self._num_frames

    # The requested endpoint is stored on the instance before validation.
    self._final_endpoint = final_endpoint
    if final_endpoint not in self.VALID_ENDPOINTS:
        raise ValueError(f'Unknown final endpoint {self._final_endpoint}')

    end_point = 'tsm_resnet_stem'
    net = hk.Conv2D(output_channels=64 * self._width_mult,
                    kernel_shape=7,
                    stride=2,
                    with_bias=False,
                    name=end_point,
                    padding='SAME')(inputs)
    net = hk.MaxPool(window_shape=(1, 3, 3, 1),
                     strides=(1, 2, 2, 1),
                     padding='SAME')(net)
    if final_endpoint == end_point:
        return net

    stage_specs = zip(self._channels, self._num_blocks, self._strides)
    for unit_id, (channels, num_blocks, stride) in enumerate(stage_specs):
        end_point = f'tsm_resnet_unit_{unit_id}'
        unit = TSMResNetUnit(
            output_channels=channels * self._width_mult,
            num_blocks=num_blocks,
            stride=stride,
            normalize_fn=self._normalize_fn,
            channel_shift_fraction=self._channel_shift_fraction,
            num_frames=num_frames,
            tsm_mode=tsm_mode,
            name=end_point)
        net = unit(net, is_training=is_training)
        if final_endpoint == end_point:
            return net

    if self._normalize_fn is not None:
        net = self._normalize_fn(net, is_training=is_training)
    net = jax.nn.relu(net)

    if final_endpoint == 'last_conv':
        return net

    net = jnp.mean(net, axis=(1, 2))
    # Prepare embedding outputs for TSM (temporal average of features).
    net = tsmu.prepare_outputs(net, tsm_mode, num_frames)
    # Every other valid endpoint has returned by now.
    assert final_endpoint == 'Embeddings'
    return net
def conv(c):
    """Return a 3x3, stride-2 ``hk.Conv2D`` with ``c`` output channels."""
    layer = hk.Conv2D(c, 3, stride=2)
    return layer
def augmented_vgg19(fp: str,
                    content_image: jnp.ndarray,
                    style_image: jnp.ndarray,
                    mean: jnp.ndarray,
                    std: jnp.ndarray,
                    content_layers: List[str] = None,
                    style_layers: List[str] = None,
                    pooling: str = "avg") -> hk.Sequential:
    """Build a VGG19 network augmented by content and style loss layers.

    Layers come from ``get_model_params(fp=fp)`` in iteration order. Keys
    containing "pool" become 2x2 stride-2 avg/max pools. Keys containing
    "conv" become NCHW convs initialized from the stored weights, each
    followed by ReLU. When a conv's name (``conv_<n>``) is listed in
    ``style_layers``/``content_layers``, the network built so far is run on
    the style/content image and a ``StyleLoss``/``ContentLoss`` with that
    target is appended right after the conv. The result is truncated after
    the last loss layer (or to the normalization layer alone if there is
    none).

    Args:
      fp: Path passed to ``get_model_params``.
      content_image: Content image; also given to the ``Normalization`` layer.
      style_image: Style image used to compute style targets.
      mean: Normalization mean.
      std: Normalization std.
      content_layers: Conv names after which to add content losses.
      style_layers: Conv names after which to add style losses.
      pooling: "avg" or "max" (case-insensitive).

    Returns:
      The truncated ``hk.Sequential`` model.

    Raises:
      ValueError: If ``pooling`` is not recognized.
    """
    pooling = pooling.lower()
    if pooling not in ["avg", "max"]:
        raise ValueError("Pooling method not recognized. Options are: "
                         "\"avg\", \"max\".")

    params = get_model_params(fp=fp)

    # Prepend a normalization layer.
    layers = [Normalization(content_image, mean, std, "norm")]
    n = 0  # number of conv layers seen so far
    content_layers = content_layers or []
    style_layers = style_layers or []
    model = hk.Sequential(layers=layers)

    def run_prefix(image):
        # Evaluate the network built so far (including the newest conv).
        model.layers = tuple(layers)
        return model(image, is_training=False)

    for key, p_dict in params.items():
        if "pool" in key:
            # Pool names reuse the current conv count.
            if pooling == "avg":
                pool = hk.AvgPool(window_shape=2, strides=2, padding="VALID",
                                  channel_axis=1, name=f"avg_pool_{n}")
            else:
                pool = hk.MaxPool(window_shape=2, strides=2, padding="VALID",
                                  channel_axis=1, name=f"max_pool_{n}")
            layers.append(pool)
            continue
        if "conv" not in key:
            continue

        n += 1
        name = f"conv_{n}"
        kernel_h, kernel_w, _, out_ch = p_dict["w"].shape
        assert kernel_w == kernel_h, "VGG19 only has square conv kernels"
        layers.append(
            hk.Conv2D(output_channels=out_ch,
                      kernel_shape=kernel_h,
                      stride=1,
                      padding="SAME",
                      data_format="NCHW",
                      w_init=hk.initializers.Constant(p_dict["w"]),
                      b_init=hk.initializers.Constant(p_dict["b"]),
                      name=name))

        if name in style_layers:
            style_target = run_prefix(style_image)
            layers.append(StyleLoss(target=style_target, name=f"style_loss_{n}"))

        if name in content_layers:
            content_target = run_prefix(content_image)
            layers.append(
                ContentLoss(target=content_target, name=f"content_loss_{n}"))

        layers.append(jax.nn.relu)

    # Index of the last loss layer; falls back to 0 (the normalization layer)
    # when no loss layer was added.
    last_loss = next(
        (idx for idx in range(len(layers) - 1, -1, -1)
         if isinstance(layers[idx], (StyleLoss, ContentLoss))),
        0)
    layers = layers[:last_loss + 1]
    model.layers = tuple(layers)
    return model
shape=(BATCH_SIZE, 2, 2)), ModuleDescriptor( name="ConvNDTranspose", create=lambda: hk.ConvNDTranspose(1, 3, 3), shape=(BATCH_SIZE, 2, 2)), ModuleDescriptor( name="Conv1D", create=lambda: hk.Conv1D(3, 3), shape=(BATCH_SIZE, 2, 2)), ModuleDescriptor( name="Conv1DTranspose", create=lambda: hk.Conv1DTranspose(3, 3), shape=(BATCH_SIZE, 2, 2)), ModuleDescriptor( name="Conv2D", create=lambda: hk.Conv2D(3, 3), shape=(BATCH_SIZE, 2, 2, 2)), ModuleDescriptor( name="Conv2DTranspose", create=lambda: hk.Conv2DTranspose(3, 3), shape=(BATCH_SIZE, 2, 2, 2)), ModuleDescriptor( name="Conv3D", create=lambda: hk.Conv3D(3, 3), shape=(BATCH_SIZE, 2, 2, 2, 2)), ModuleDescriptor( name="Conv3DTranspose", create=lambda: hk.Conv3DTranspose(3, 3), shape=(BATCH_SIZE, 2, 2, 2, 2)), ModuleDescriptor( name="DepthwiseConv2D",
def __init__(self,
             channels: int,
             stride: Union[int, Sequence[int]],
             use_projection: bool,
             bn_config: Mapping[str, float],
             bottleneck: bool,
             use_bn: bool = True,
             transpose: bool = False,
             name: Optional[str] = None):
    """Build the layers of a residual block with optional BatchNorm.

    All convs here use stride 1. Resampling is delegated to
    ``self.pooling_or_upsampling``: ``upsample`` when ``transpose`` is set,
    otherwise a 2x2 stride-2 SAME average pool. ``stride`` is only stored as
    ``self.stride``.

    Args:
      channels: Output channels of the block.
      stride: Stored as ``self.stride``.
      use_projection: Whether to create a 1x1 ``shortcut_conv`` (and, with
        BatchNorm enabled, ``shortcut_batchnorm``).
      bn_config: ``hk.BatchNorm`` keyword arguments, used only when
        ``use_bn``. Missing keys default to ``create_scale=True``,
        ``create_offset=True`` and ``decay_rate=0.999``.
      bottleneck: If true, use 1x1 -> 3x3 -> 1x1 convs at ``channels // 4``
        inner width, with the last BatchNorm's scale initialized to zeros.
      use_bn: Whether to create BatchNorm layers. When false, convs get a
        bias and the BatchNorm slots in ``self.layers`` are None.
      transpose: Selects upsampling instead of average pooling.
      name: Name of the module.
    """
    super().__init__(name=name)
    self.use_projection = use_projection
    self.use_bn = use_bn
    if self.use_bn:
        bn_config = dict(bn_config)
        for key, value in (("create_scale", True), ("create_offset", True),
                           ("decay_rate", 0.999)):
            bn_config.setdefault(key, value)

    if transpose:
        self.pooling_or_upsampling = upsample
    else:
        self.pooling_or_upsampling = hk.AvgPool(window_shape=2, strides=2,
                                                padding='SAME')
    self.stride = stride

    def maybe_batchnorm(bn_name, **extra):
        # A BatchNorm when enabled, otherwise None (keeps the slot shape).
        if not self.use_bn:
            return None
        return hk.BatchNorm(name=bn_name, **extra, **bn_config)

    if self.use_projection:
        # Skip-connection conv. Its stride stays 1 because resampling is done
        # by ``pooling_or_upsampling`` instead.
        self.proj_conv = hk.Conv2D(output_channels=channels,
                                   kernel_shape=1,
                                   stride=1,
                                   with_bias=not self.use_bn,
                                   padding="SAME",
                                   name="shortcut_conv")
        if self.use_bn:
            self.proj_batchnorm = hk.BatchNorm(name="shortcut_batchnorm",
                                               **bn_config)

    inner_channels = channels // (4 if bottleneck else 1)

    conv_0 = hk.Conv2D(output_channels=inner_channels,
                       kernel_shape=1 if bottleneck else 3,
                       stride=1,
                       with_bias=not self.use_bn,
                       padding="SAME",
                       name="conv_0")
    bn_0 = maybe_batchnorm("batchnorm_0")

    conv_1 = hk.Conv2D(output_channels=inner_channels,
                       kernel_shape=3,
                       stride=1,
                       with_bias=not self.use_bn,
                       padding="SAME",
                       name="conv_1")
    bn_1 = maybe_batchnorm("batchnorm_1")

    layers = ((conv_0, bn_0), (conv_1, bn_1))
    if bottleneck:
        conv_2 = hk.Conv2D(output_channels=channels,
                           kernel_shape=1,
                           stride=1,
                           with_bias=not self.use_bn,
                           padding="SAME",
                           name="conv_2")
        bn_2 = maybe_batchnorm("batchnorm_2", scale_init=jnp.zeros)
        layers = layers + ((conv_2, bn_2),)
    self.layers = layers