def __call__(self, inputs, is_training):
    """Apply a depthwise-separable convolution block.

    Pipeline: depthwise conv -> [BatchNorm] -> ReLU -> pointwise conv
    -> [BatchNorm] -> ReLU. The BatchNorm steps run only when
    ``self._use_bn`` is truthy.

    Args:
      inputs: Input activations passed to the depthwise convolution
        (layout presumably channels-last — TODO confirm against callers).
      is_training: Forwarded to each BatchNorm call. It is unused when
        ``self._use_bn`` is false.

    Returns:
      The activations after the final ReLU.
    """

    def _maybe_norm_then_relu(x):
        # BatchNorm modules are constructed inline and unnamed, exactly as
        # before. NOTE(review): auto-generated module names presumably depend
        # on creation order, so keep the call order below unchanged.
        if self._use_bn:
            x = batch_norm.BatchNorm(create_scale=True, create_offset=True)(
                x, is_training)
        return jax.nn.relu(x)

    # Both conv layers are constructed up front, before any layer is called,
    # matching the original ordering.
    depthwise = depthwise_conv.DepthwiseConv2D(
        1,
        3,
        stride=self._stride,
        padding=((1, 1), (1, 1)),  # explicit one-pixel padding on H and W
        with_bias=self._with_bias,
        name="depthwise_conv",
    )
    pointwise = conv.Conv2D(
        self._channels,
        (1, 1),
        stride=1,
        padding="VALID",
        with_bias=self._with_bias,
        name="pointwise_conv",
    )

    hidden = _maybe_norm_then_relu(depthwise(inputs))
    return _maybe_norm_then_relu(pointwise(hidden))
def f():
    """Apply a VALID depthwise conv to a small per-channel-constant input.

    The input has shape [1, 10, 10, 3] (float64 before conversion), and
    channel ``c`` is filled with the value ``1 + c``. The module uses
    channel_multiplier=3, a kernel_shape of 3, stride 1, VALID padding and
    channels_last layout. Its initializers come from
    ``create_constant_initializers(1.0, 0.0, with_bias)``, presumably a
    constant weight of 1.0 and a bias of 0.0 (TODO confirm against the
    helper). ``with_bias`` and ``create_constant_initializers`` are resolved
    from the enclosing scope.
    """
    # Broadcasting over the last axis adds 0, 1 and 2 to channels 0, 1 and 2.
    # This is equivalent to the per-channel in-place additions.
    ramp = np.ones([1, 10, 10, 3])
    ramp += np.arange(3)
    conv_module = depthwise_conv.DepthwiseConv2D(
        channel_multiplier=3,
        kernel_shape=3,
        stride=1,
        padding="VALID",
        with_bias=with_bias,
        data_format="channels_last",
        **create_constant_initializers(1.0, 0.0, with_bias),
    )
    return conv_module(jnp.array(ramp))