示例#1
0
    def __call__(self, inputs, is_training):
        """Apply a depthwise-separable convolution block to ``inputs``.

        Pipeline: 3x3 depthwise conv (channel multiplier 1) -> [BatchNorm] ->
        ReLU -> 1x1 pointwise conv to ``self._channels`` channels ->
        [BatchNorm] -> ReLU. The BatchNorm stages run only when
        ``self._use_bn`` is true.

        Args:
          inputs: Input feature map for the depthwise convolution. Its layout
            is whatever the conv modules default to (presumably
            channels-last -- TODO confirm against callers).
          is_training: Passed to each BatchNorm call; unused when
            ``self._use_bn`` is false.

        Returns:
          The ReLU-activated output of the pointwise convolution.
        """
        # Explicit one-pixel padding on both sides of each spatial dimension
        # keeps the padding symmetric for any stride (presumably chosen over
        # "SAME", which can pad asymmetrically when the stride is > 1).
        dwc_layer = depthwise_conv.DepthwiseConv2D(
            channel_multiplier=1,
            kernel_shape=3,
            stride=self._stride,
            padding=((1, 1), (1, 1)),
            with_bias=self._with_bias,
            name="depthwise_conv")
        pwc_layer = conv.Conv2D(self._channels, (1, 1),
                                stride=1,
                                padding="VALID",
                                with_bias=self._with_bias,
                                name="pointwise_conv")

        def batch_norm_if_enabled(x):
            """Return BatchNorm(x) when ``self._use_bn`` is set, else ``x``."""
            if not self._use_bn:
                return x
            # A new BatchNorm module is built on each call, so the two stages
            # do not share one. ``decay_rate`` has no default in BatchNorm and
            # must be passed explicitly; 0.999 is a common moving-average decay.
            bn = batch_norm.BatchNorm(create_scale=True,
                                      create_offset=True,
                                      decay_rate=0.999)
            return bn(x, is_training)

        net = dwc_layer(inputs)
        net = jax.nn.relu(batch_norm_if_enabled(net))
        net = pwc_layer(net)
        net = jax.nn.relu(batch_norm_if_enabled(net))
        return net
 def f():
     """Apply a VALID 3x3 depthwise conv (channel multiplier 3) to fixed data.

     The input is a (1, 10, 10, 3) channels-last array in which channel ``c``
     holds the constant value ``c + 1``. ``with_bias`` is read from the
     enclosing scope and also forwarded to ``create_constant_initializers``.

     Returns:
       The depthwise convolution's output for that input.
     """
     # Broadcasting [1., 2., 3.] over a ones array gives channel c the value c + 1.
     channel_values = np.array([1.0, 2.0, 3.0])
     inputs = jnp.array(np.ones([1, 10, 10, 3]) * channel_values)
     depthwise = depthwise_conv.DepthwiseConv2D(
         channel_multiplier=3,
         kernel_shape=3,
         stride=1,
         padding="VALID",
         with_bias=with_bias,
         data_format="channels_last",
         # Extra constructor kwargs, presumably constant w_init / b_init
         # (TODO confirm against create_constant_initializers).
         **create_constant_initializers(1.0, 0.0, with_bias))
     return depthwise(inputs)