def mnist_model(features, **_):
  """MNIST conv net: two conv/pool stages followed by a small MLP head."""
  layers = [
      hk.Conv2D(16, (8, 8), padding='SAME', stride=(2, 2)),
      jax.nn.relu,
      hk.MaxPool(2, 1, padding='VALID'),  # matches stax
      hk.Conv2D(32, (4, 4), padding='VALID', stride=(2, 2)),
      jax.nn.relu,
      hk.MaxPool(2, 1, padding='VALID'),  # matches stax
      hk.Flatten(),
      hk.Linear(32),
      jax.nn.relu,
      hk.Linear(10),  # one logit per digit class
  ]
  return hk.Sequential(layers)(features)
def __call__(self, x):
  """Tiny conv head: conv -> relu -> 2x2 max-pool -> flatten -> linear(3)."""
  network = hk.Sequential([
      hk.Conv2D(8, (3, 3)),
      jax.nn.relu,
      hk.MaxPool((1, 2, 2, 1), (1, 2, 2, 1), 'VALID'),
      hk.Flatten(),
      hk.Linear(3, with_bias=False),
  ])
  return network(x)
def __init__(self, model_size):
  """Builds a ResNet-9-style network with size-dependent channel widths.

  Args:
    model_size: One of 'large', 'med', 'small' or 'tiny'; selects the
      channel widths used in every stage of the network.

  Raises:
    ValueError: If `model_size` is not one of the supported sizes.
  """
  super().__init__()
  # Channel widths per stage, keyed by model size.
  size_to_channels = {
      'large': {'prep': 64, 'layer1': 128, 'layer2': 256, 'layer3': 512},
      'med': {'prep': 32, 'layer1': 64, 'layer2': 128, 'layer3': 256},
      'small': {'prep': 16, 'layer1': 32, 'layer2': 64, 'layer3': 128},
      'tiny': {'prep': 8, 'layer1': 16, 'layer2': 32, 'layer3': 64},
  }
  if model_size not in size_to_channels:
    # The original if/elif chain silently fell through on an unknown size
    # and then raised a confusing NameError on `channels`; fail fast with
    # a clear message instead.
    raise ValueError(
        f'Unknown model_size {model_size!r}; expected one of '
        f'{sorted(size_to_channels)}.')
  channels = size_to_channels[model_size]
  self.prep = ConvBN(3, channels['prep'])
  # Layer 1
  self.conv1 = ConvBN(channels['prep'], channels['layer1'])
  self.pool1 = hk.MaxPool(
      window_shape=(1, 1, 2, 2), strides=(1, 1, 2, 2), padding='SAME')
  self.residual1 = Residual(channels['layer1'])
  # Layer 2
  self.conv2 = ConvBN(channels['layer1'], channels['layer2'])
  self.pool2 = hk.MaxPool(
      window_shape=(1, 1, 2, 2), strides=(1, 1, 2, 2), padding='SAME')
  # Layer 3
  self.conv3 = ConvBN(channels['layer2'], channels['layer3'])
  self.pool3 = hk.MaxPool(
      window_shape=(1, 1, 2, 2), strides=(1, 1, 2, 2), padding='SAME')
  self.residual3 = Residual(channels['layer3'])
  self.pool4 = hk.MaxPool(
      window_shape=(1, 1, 4, 4), strides=(1, 1, 4, 4), padding='SAME')
  self.fc = hk.Linear(10, with_bias=False)
  # Scale factor for the logits — presumably applied in the forward pass
  # (ResNet-9 "bag of tricks" convention); confirm against __call__.
  self.logit_weight = 0.125
def lenet_fn(batch, is_training):
  """Network inspired by LeNet-5."""
  x, _ = batch
  # Three conv/relu/pool stages, widths as in the classic LeNet-5 layout.
  layers = []
  for width in (6, 16, 120):
    layers += [
        hk.Conv2D(output_channels=width, kernel_shape=5, padding="SAME"),
        jax.nn.relu,
        hk.MaxPool(window_shape=3, strides=2, padding="VALID"),
    ]
  # Classifier head.
  layers += [
      hk.Flatten(),
      hk.Linear(84),
      jax.nn.relu,
      hk.Linear(num_classes),
  ]
  return hk.Sequential(layers)(x)
def lenet_fn(batch):
  """Network inspired by LeNet-5."""
  images, _ = batch
  images = images.astype(jnp.float32)
  # NOTE(review): the third conv stage has no ReLU between conv and pool,
  # unlike the first two — preserved as-is; confirm this is intentional.
  model = hk.Sequential([
      hk.Conv2D(output_channels=32, kernel_shape=5, padding="SAME"),
      jax.nn.relu,
      hk.MaxPool(window_shape=3, strides=2, padding="VALID"),
      hk.Conv2D(output_channels=64, kernel_shape=5, padding="SAME"),
      jax.nn.relu,
      hk.MaxPool(window_shape=3, strides=2, padding="VALID"),
      hk.Conv2D(output_channels=128, kernel_shape=5, padding="SAME"),
      hk.MaxPool(window_shape=3, strides=2, padding="VALID"),
      hk.Flatten(),
      hk.Linear(1000),
      jax.nn.relu,
      hk.Linear(1000),
      jax.nn.relu,
      hk.Linear(10),
  ])
  return model(images)
def _call_layers(
    cfg,
    inp: Tensor,
    batch_norm: bool = True,
    include_top: bool = True,
    initial_weights: Optional[hk.Params] = None,
    output_feature_maps: bool = False) -> Union[Tensor, List[Tensor]]:
  """Runs the VGG16 conv stack described by `cfg` over `inp`.

  `cfg` lists channel counts interleaved with 'M' markers for max-pooling.
  Returns the final tensor, or every intermediate feature map when
  `output_feature_maps` is set.
  """
  x = inp
  feature_maps = []
  # Ignore max pooling if we do not append the classifier.
  if not include_top:
    cfg = cfg[:-1]
  conv_index = 0
  base_name = 'vgg16/conv2_d'
  for entry in cfg:
    if entry == 'M':
      feature_maps.append(x)
      x = hk.MaxPool(window_shape=2, strides=2, padding="VALID")(x)
      continue
    # Haiku names repeated modules conv2_d, conv2_d_1, conv2_d_2, ...
    param_name = base_name if conv_index == 0 else f'{base_name}_{conv_index}'
    conv_index += 1
    if initial_weights is None:
      w_init = None
      b_init = None
    else:
      w_init = hk.initializers.Constant(
          constant=initial_weights[param_name]['w'])
      b_init = hk.initializers.Constant(
          constant=initial_weights[param_name]['b'])
    x = hk.Conv2D(entry, kernel_shape=3, stride=1, padding='SAME',
                  w_init=w_init, b_init=b_init)(x)
    if batch_norm:
      x = hk.BatchNorm(True, True, decay_rate=0.999)(x)
    x = jax.nn.relu(x)
  feature_maps.append(x)
  return feature_maps if output_feature_maps else feature_maps[-1]
def __init__(self):
  """LeNet-style MNIST net in NCHW layout (input is 28x28)."""
  super().__init__()
  # 'SAME' padding keeps the 28x28 spatial size through the first conv.
  self.conv1 = hk.Conv2D(output_channels=32, kernel_shape=5,
                         padding='SAME', data_format='NCHW')
  # 28x28 -> 14x14 after pooling.
  self.pool1 = hk.MaxPool(window_shape=(1, 1, 2, 2),
                          strides=(1, 1, 2, 2), padding='SAME')
  self.conv2 = hk.Conv2D(output_channels=64, kernel_shape=5,
                         padding='SAME', data_format='NCHW')
  # 14x14 -> 7x7 after pooling.
  self.pool2 = hk.MaxPool(window_shape=(1, 1, 2, 2),
                          strides=(1, 1, 2, 2), padding='SAME')
  self.fc = hk.Linear(10)
def __call__(self, x: jnp.ndarray, is_train: bool):
  """Two VALID conv blocks, 2x2 max-pool, then a dropout-regularized head."""
  h = hk.Conv2D(output_channels=32, kernel_shape=(3, 3), padding='VALID')(x)
  h = jax.nn.relu(h)
  h = hk.Conv2D(output_channels=64, kernel_shape=(3, 3), padding='VALID')(h)
  h = jax.nn.relu(h)
  h = hk.MaxPool(window_shape=(1, 2, 2, 1), strides=(1, 2, 2, 1),
                 padding='VALID')(h)
  # Dropout is active only when is_train is True.
  h = Dropout(rate=0.25)(h, is_train)
  h = hk.Flatten()(h)
  h = hk.Linear(128)(h)
  h = jax.nn.relu(h)
  h = Dropout(rate=0.5)(h, is_train)
  return hk.Linear(self._num_classes)(h)
def make_downsampling_layer(
    strategy: Union[str, DownsamplingStrategy],
    output_channels: int,
) -> hk.SupportsCall:
  """Returns a sequence of modules corresponding to the desired downsampling."""
  strategy = DownsamplingStrategy(strategy)
  # Shared initializer for the strided convs below.
  conv_w_init = hk.initializers.TruncatedNormal(1e-2)
  if strategy is DownsamplingStrategy.AVG_POOL:
    return hk.AvgPool(window_shape=(3, 3, 1), strides=(2, 2, 1),
                      padding='SAME')
  if strategy is DownsamplingStrategy.CONV:
    return hk.Sequential([
        hk.Conv2D(output_channels, kernel_shape=3, stride=2,
                  w_init=conv_w_init),
    ])
  if strategy is DownsamplingStrategy.LAYERNORM_RELU_CONV:
    return hk.Sequential([
        hk.LayerNorm(axis=(1, 2, 3), create_scale=True, create_offset=True,
                     eps=1e-6),
        jax.nn.relu,
        hk.Conv2D(output_channels, kernel_shape=3, stride=2,
                  w_init=conv_w_init),
    ])
  if strategy is DownsamplingStrategy.CONV_MAX:
    return hk.Sequential([
        hk.Conv2D(output_channels, kernel_shape=3, stride=1),
        hk.MaxPool(window_shape=(3, 3, 1), strides=(2, 2, 1), padding='SAME'),
    ])
  raise ValueError(
      'Unrecognized downsampling strategy. Expected one of'
      f' {[s.value for s in DownsamplingStrategy]}'
      f' but received {strategy}.')
def forward(batch, is_training):
  """Embed -> Conv1D -> (optional max-pool) -> LSTM -> linear logits."""
  tokens, _ = batch
  batch_size = tokens.shape[0]
  h = hk.Embed(vocab_size=max_features, embed_dim=embedding_size)(tokens)
  h = hk.Conv1D(output_channels=num_filters, kernel_shape=kernel_size,
                padding="VALID")(h)
  h = jax.nn.swish(h) if use_swish else jax.nn.relu(h)
  if use_maxpool:
    h = hk.MaxPool(window_shape=pool_size, strides=pool_size,
                   padding='VALID', channel_axis=2)(h)
  # Time-major for static_unroll: [T, B, F].
  h = jnp.moveaxis(h, 1, 0)[:, :]
  lstm = hk.LSTM(hidden_size=cell_size)
  initial = lstm.initial_state(batch_size)
  outputs, _ = hk.static_unroll(lstm, h, initial)
  # Classify from the final timestep only.
  return hk.Linear(num_classes)(outputs[-1])
def __call__(self, x: Tensor, training: bool = False) -> Tuple[Tensor, Tensor]:
  """Runs the SSD detector over a batch of images.

  Args:
    x: Input image batch. NOTE(review): presumably NHWC float images —
      confirm against the caller.
    training: Training-mode flag forwarded to the extra conv blocks.

  Returns:
    A `(classification, regression)` pair, each concatenated across all
    detection feature maps along axis 1.
  """
  params = self.ssd_initial_weights
  # VGG16 backbone returning all intermediate feature maps.
  x = aj.zoo.VGG16(include_top=False,
                   pretrained=False,
                   initial_weights=self.backbone_initial_weights,
                   output_feature_maps=True)(x)
  # conv4_3 feeds the first detection head; x becomes the deepest map.
  conv4_3, x = x[-2:]
  # L2-normalize conv4_3; gamma comes from pretrained weights when given,
  # otherwise the SSD-paper constant 20.
  conv4_3 = aj.nn.layers.L2Norm(init_fn=(
      hk.initializers.Constant(params['ssd/l2_norm']['gamma'])
      if params is not None else hk.initializers.Constant(20.)))(conv4_3)
  # Stride-1 pool keeps the spatial size of the deepest VGG map.
  x = hk.MaxPool(window_shape=(1, 3, 3, 1), strides=1, padding='SAME')(x)
  # Replace fully connected by FCN: dilated (rate=6) 3x3 conv, then 1x1.
  conv_6 = hk.Conv2D(
      output_channels=1024,
      kernel_shape=3,
      stride=1,
      rate=6,
      w_init=(hk.initializers.Constant(params['ssd/conv2_d']['w'])
              if params is not None else self.xavier_init_fn),
      b_init=(hk.initializers.Constant(params['ssd/conv2_d']['b'])
              if params is not None else None),
      padding='SAME')(x)
  conv_6 = jax.nn.relu(conv_6)
  conv7 = hk.Conv2D(
      output_channels=1024,
      kernel_shape=1,
      stride=1,
      w_init=(hk.initializers.Constant(params['ssd/conv2_d_1']['w'])
              if params is not None else self.xavier_init_fn),
      b_init=(hk.initializers.Constant(params['ssd/conv2_d_1']['b'])
              if params is not None else None),
      padding='SAME')(conv_6)
  conv7 = jax.nn.relu(conv7)
  # Build additional features: progressively smaller extra conv blocks.
  conv8_2 = self._additional_conv(conv7, 256, 512, stride=2,
                                  training=training, name='conv_8')
  conv9_2 = self._additional_conv(conv8_2, 128, 256, stride=2,
                                  training=training, name='conv_9')
  conv10_2 = self._additional_conv(conv9_2, 128, 256, stride=1,
                                   training=training, name='conv_10')
  conv11_2 = self._additional_conv(conv10_2, 128, 256, stride=1,
                                   training=training, name='conv_11')
  detection_features = [
      conv4_3, conv7, conv8_2, conv9_2, conv10_2, conv11_2
  ]
  # Pair each feature map with its anchor count k (cycled if fewer ks
  # than maps are configured).
  detection_features = zip(itertools.cycle(self.k), detection_features)
  clf, reg = list(
      zip(*[
          self._head(o, k=k, name=f"fm_{i}")
          for i, (k, o) in enumerate(detection_features)
      ]))
  return np.concatenate(clf, axis=1), np.concatenate(reg, axis=1)
def __call__(self, inputs: types.TensorLike, is_training: bool) -> jnp.ndarray:
  """Connects the module to inputs.

  Args:
    inputs: A 5-D float array of shape `[B, T, H, W, C]`.
    is_training: Whether to use training mode.

  Returns:
    A 5-D float array of shape
    `[B, new_t, new_h, new_w, sum(output_channels)]`.
  """
  # Branch 0: plain 1x1x1 unit with self-gating.
  b0 = SUnit3D(output_channels=self._output_channels[0],
               kernel_shape=(1, 1, 1),
               separable=False,
               normalize_fn=self._normalize_fn,
               self_gating_fn=self._self_gating_fn,
               name='Branch_0_Conv2d_0a_1x1')(inputs, is_training=is_training)
  # Branch 1: 1x1x1 bottleneck, then a separable t x 3 x 3 unit.
  b1 = SUnit3D(output_channels=self._output_channels[1],
               kernel_shape=(1, 1, 1),
               separable=False,
               normalize_fn=self._normalize_fn,
               self_gating_fn=None,
               name='Branch_1_Conv2d_0a_1x1')(inputs, is_training=is_training)
  b1 = SUnit3D(output_channels=self._output_channels[2],
               kernel_shape=(self._temporal_kernel_size, 3, 3),
               separable=True,
               normalize_fn=self._normalize_fn,
               self_gating_fn=self._self_gating_fn,
               name='Branch_1_Conv2d_0b_3x3')(b1, is_training=is_training)
  # Branch 2: same structure as branch 1 with its own channel counts.
  b2 = SUnit3D(output_channels=self._output_channels[3],
               kernel_shape=(1, 1, 1),
               separable=False,
               normalize_fn=self._normalize_fn,
               self_gating_fn=None,
               name='Branch_2_Conv2d_0a_1x1')(inputs, is_training=is_training)
  b2 = SUnit3D(output_channels=self._output_channels[4],
               kernel_shape=(self._temporal_kernel_size, 3, 3),
               separable=True,
               normalize_fn=self._normalize_fn,
               self_gating_fn=self._self_gating_fn,
               name='Branch_2_Conv2d_0b_3x3')(b2, is_training=is_training)
  # Branch 3: 3x3x3 max-pool followed by a 1x1x1 projection.
  b3 = hk.MaxPool(window_shape=(1, 3, 3, 3, 1),
                  strides=(1, 1, 1, 1, 1),
                  padding='SAME',
                  name='Branch_3_MaxPool_0a_3x3')(inputs)
  b3 = SUnit3D(output_channels=self._output_channels[5],
               kernel_shape=(1, 1, 1),
               separable=False,
               normalize_fn=self._normalize_fn,
               self_gating_fn=self._self_gating_fn,
               name='Branch_3_Conv2d_0b_1x1')(b3, is_training=is_training)
  # Stack all branches along the channel axis.
  return jnp.concatenate((b0, b1, b2, b3), axis=4)
def __init__(
    self,
    stochastic_parameters: bool,
    dropout: bool,
    dropout_rate: float,
    linear_model: bool,
    blocks_per_group: Sequence[int],
    num_classes: int,
    bn_config: Optional[Mapping[str, float]] = None,
    resnet_v2: bool = False,
    bottleneck: bool = True,
    channels_per_group: Sequence[int] = (256, 512, 1024, 2048),
    use_projection: Sequence[bool] = (True, True, True, True),
    logits_config: Optional[Mapping[str, Any]] = None,
    name: Optional[str] = None,
    uniform_init_minval: float = -20.0,
    uniform_init_maxval: float = -18.0,
    w_init: str = "uniform",
    b_init: str = "uniform",
):
  """Constructs a ResNet model.

  Args:
    stochastic_parameters: TODO(nband).
    dropout: TODO(nband).
    dropout_rate: TODO(nband).
    linear_model: TODO(nband).
    blocks_per_group: A sequence of length 4 that indicates the number of
      blocks created in each group.
    num_classes: The number of classes to classify the inputs into.
    bn_config: A dictionary of two elements, ``decay_rate`` and ``eps`` to be
      passed on to the :class:`~haiku.BatchNorm` layers. By default the
      ``decay_rate`` is ``0.9`` and ``eps`` is ``1e-5``.
    resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults to
      ``False``.
    bottleneck: Whether the block should bottleneck or not. Defaults to
      ``True``.
    channels_per_group: A sequence of length 4 that indicates the number of
      channels used for each block in each group.
    use_projection: A sequence of length 4 that indicates whether each
      residual block should use projection.
    logits_config: A dictionary of keyword arguments for the logits layer.
    name: Name of the module.
    uniform_init_minval: TODO(nband).
    uniform_init_maxval: TODO(nband).
    w_init: weight init.
    b_init: bias init.
  """
  super().__init__(name=name)
  self.resnet_v2 = resnet_v2
  self.linear_model = linear_model
  self.dropout = dropout
  self.dropout_rate = dropout_rate
  # In the linear-model case only the final layer may be stochastic; the
  # feature extractor stays deterministic.
  if self.linear_model:
    self.stochastic_parameters_feature_mapping = False
    self.stochastic_parameters_final_layer = stochastic_parameters
  else:
    self.stochastic_parameters_feature_mapping = stochastic_parameters
    self.stochastic_parameters_final_layer = stochastic_parameters
  # TODO(nband): Maybe remove hardcoding here
  self.uniform_init_minval = uniform_init_minval
  self.uniform_init_maxval = uniform_init_maxval
  # Fill in batch-norm defaults without mutating the caller's mapping.
  bn_config = dict(bn_config or {})
  bn_config.setdefault("decay_rate", 0.9)
  bn_config.setdefault("eps", 1e-5)
  bn_config.setdefault("create_scale", True)
  bn_config.setdefault("create_offset", True)
  logits_config = dict(logits_config or {})
  # logits_config.setdefault("w_init", jnp.zeros)
  logits_config.setdefault("name", "logits")
  logits_config.setdefault("with_bias", True)  # TR: added
  # Number of blocks in each group for ResNet.
  check_length(4, blocks_per_group, "blocks_per_group")
  check_length(4, channels_per_group, "channels_per_group")
  # Stem: 7x7 stride-2 conv, as in standard ResNet.
  self.initial_conv = Conv2dStochastic(
      output_channels=64,
      kernel_shape=7,
      stride=2,
      with_bias=False,
      padding="VALID",
      name="initial_conv",
      stochastic_parameters=self.stochastic_parameters_feature_mapping,
      uniform_init_minval=self.uniform_init_minval,
      uniform_init_maxval=self.uniform_init_maxval,
      w_init=w_init,
      b_init=b_init,
  )
  # ResNet v1 normalizes right after the stem; v2 defers normalization.
  if not self.resnet_v2:
    self.initial_batchnorm = hk.BatchNorm(name="batchnorm", **bn_config)
  self.block_groups = []
  strides = (1, 2, 2, 2)
  for i in range(4):
    self.block_groups.append(
        BlockGroup(
            stochastic_parameters=self.stochastic_parameters_feature_mapping,
            dropout=self.dropout,
            dropout_rate=self.dropout_rate,
            uniform_init_minval=self.uniform_init_minval,
            uniform_init_maxval=self.uniform_init_maxval,
            channels=channels_per_group[i],
            num_blocks=blocks_per_group[i],
            stride=strides[i],
            bn_config=bn_config,
            bottleneck=bottleneck,
            use_projection=use_projection[i],
            name=f"block_group_{i}",
            w_init=w_init,
            b_init=b_init,
        ))
  self.max_pool = hk.MaxPool(
      window_shape=(1, 3, 3, 1), strides=(1, 2, 2, 1), padding="SAME")
  # v2 applies a final batch-norm before the logits.
  if self.resnet_v2:
    self.final_batchnorm = hk.BatchNorm(name="batchnorm", **bn_config)
  self.logits = DenseStochasticHaiku(
      output_size=num_classes,
      uniform_init_minval=self.uniform_init_minval,
      uniform_init_maxval=self.uniform_init_maxval,
      stochastic_parameters=self.stochastic_parameters_final_layer,
      w_init=w_init,
      b_init=b_init,
      **logits_config,
  )
def __call__(self,
             inputs: types.TensorLike,
             is_training: bool = True,
             final_endpoint: str = 'Embeddings') -> jnp.ndarray:
  """Connects the TSM ResNetV2 module into the graph.

  Args:
    inputs: A 4-D float array of shape `[B, H, W, C]`.
    is_training: Whether to use training mode.
    final_endpoint: Up to which endpoint to run / return.

  Returns:
    Network output at location `final_endpoint`. A float array which shape
    depends on `final_endpoint`.

  Raises:
    ValueError: If `final_endpoint` is not recognized.
  """
  # Prepare inputs for TSM.
  inputs, tsm_mode, num_frames = tsmu.prepare_inputs(inputs)
  # Fall back to the configured frame count when not inferable from input.
  num_frames = num_frames or self._num_frames
  self._final_endpoint = final_endpoint
  if self._final_endpoint not in self.VALID_ENDPOINTS:
    raise ValueError(f'Unknown final endpoint {self._final_endpoint}')
  # Stem convolution.
  end_point = 'tsm_resnet_stem'
  net = hk.Conv2D(
      output_channels=64 * self._width_mult,
      kernel_shape=7,
      stride=2,
      with_bias=False,
      name=end_point,
      padding='SAME')(inputs)
  net = hk.MaxPool(
      window_shape=(1, 3, 3, 1), strides=(1, 2, 2, 1), padding='SAME')(net)
  if self._final_endpoint == end_point:
    return net
  # Residual block.
  for unit_id, (channels, num_blocks, stride) in enumerate(
      zip(self._channels, self._num_blocks, self._strides)):
    end_point = f'tsm_resnet_unit_{unit_id}'
    net = TSMResNetUnit(
        output_channels=channels * self._width_mult,
        num_blocks=num_blocks,
        stride=stride,
        normalize_fn=self._normalize_fn,
        channel_shift_fraction=self._channel_shift_fraction,
        num_frames=num_frames,
        tsm_mode=tsm_mode,
        name=end_point)(net, is_training=is_training)
    # Early exit as soon as the requested endpoint has been produced.
    if self._final_endpoint == end_point:
      return net
  if self._normalize_fn is not None:
    net = self._normalize_fn(net, is_training=is_training)
  net = jax.nn.relu(net)
  end_point = 'last_conv'
  if self._final_endpoint == end_point:
    return net
  # Global spatial average pooling over H and W.
  net = jnp.mean(net, axis=(1, 2))
  # Prepare embedding outputs for TSM (temporal average of features).
  net = tsmu.prepare_outputs(net, tsm_mode, num_frames)
  assert self._final_endpoint == 'Embeddings'
  return net
def augmented_vgg19(fp: str,
                    content_image: jnp.ndarray,
                    style_image: jnp.ndarray,
                    mean: jnp.ndarray,
                    std: jnp.ndarray,
                    content_layers: List[str] = None,
                    style_layers: List[str] = None,
                    pooling: str = "avg") -> hk.Sequential:
  """Build a VGG19 network augmented by content and style loss layers.

  Loads pretrained parameters from `fp`, interleaves StyleLoss/ContentLoss
  layers (whose targets are pre-computed from `style_image` /
  `content_image`) at the requested conv layers, and truncates the model
  after the last loss layer.
  """
  pooling = pooling.lower()
  if pooling not in ["avg", "max"]:
    raise ValueError("Pooling method not recognized. Options are: "
                     "\"avg\", \"max\".")
  params = get_model_params(fp=fp)
  # prepend a normalization layer
  layers = [Normalization(content_image, mean, std, "norm")]
  # tracks number of conv layers
  n = 0
  # desired depth layers to compute style/content losses :
  content_layers = content_layers or []
  style_layers = style_layers or []
  # NOTE: `model` shares the `layers` list; `model.layers` is re-pointed at
  # a snapshot tuple each time a loss target needs to be computed below.
  model = hk.Sequential(layers=layers)
  for k, p_dict in params.items():
    if "pool" in k:
      if pooling == "avg":
        # exactly as many as pools as convolutions
        layers.append(
            hk.AvgPool(window_shape=2,
                       strides=2,
                       padding="VALID",
                       channel_axis=1,
                       name=f"avg_pool_{n}"))
      else:
        layers.append(
            hk.MaxPool(window_shape=2,
                       strides=2,
                       padding="VALID",
                       channel_axis=1,
                       name=f"max_pool_{n}"))
    elif "conv" in k:
      n += 1
      name = f"conv_{n}"
      kernel_h, kernel_w, in_ch, out_ch = p_dict["w"].shape
      # VGG only has square conv kernels
      assert kernel_w == kernel_h, "VGG19 only has square conv kernels"
      kernel_shape = kernel_h
      layers.append(
          hk.Conv2D(output_channels=out_ch,
                    kernel_shape=kernel_shape,
                    stride=1,
                    padding="SAME",
                    data_format="NCHW",
                    w_init=hk.initializers.Constant(p_dict["w"]),
                    b_init=hk.initializers.Constant(p_dict["b"]),
                    name=name))
      if name in style_layers:
        # Run the model built so far on the style image to fix this
        # layer's style target.
        model.layers = tuple(layers)
        style_target = model(style_image, is_training=False)
        layers.append(
            StyleLoss(target=style_target, name=f"style_loss_{n}"))
      if name in content_layers:
        model.layers = tuple(layers)
        content_target = model(content_image, is_training=False)
        layers.append(
            ContentLoss(target=content_target, name=f"content_loss_{n}"))
      layers.append(jax.nn.relu)
  # this modifies our n from before in place
  for n in range(len(layers) - 1, -1, -1):
    if isinstance(layers[n], (StyleLoss, ContentLoss)):
      break
  # break off after last content loss layer
  layers = layers[:(n + 1)]
  model.layers = tuple(layers)
  return model