Example #1
0
    def _build(self, inputs, is_training):
        """Connects the SepConv component: two chained Conv3D layers.

        The first Conv3D uses the `_sp_*` kernel/stride settings and the
        second, applied to the first one's output, uses the `_temp_*`
        settings. Both emit `self._output_channels` channels. Optional
        snt.BatchNorm and an optional activation follow.

        NOTE(review): padding is read from `self.padding` while every other
        setting uses an underscore-prefixed attribute — confirm both exist.

        Args:
          inputs: Inputs to the SepConv component.
          is_training: whether to use training mode for snt.BatchNorm (boolean).

        Returns:
          Outputs from the module.
        """
        shared = dict(output_channels=self._output_channels,
                      padding=self.padding,
                      use_bias=self._use_bias)

        first_stage = snt.Conv3D(kernel_shape=self._sp_kernel_shape,
                                 stride=self._sp_stride_shape,
                                 **shared)(inputs)
        outputs = snt.Conv3D(kernel_shape=self._temp_kernel_shape,
                             stride=self._temp_stride_shape,
                             **shared)(first_stage)

        if self._use_batch_norm:
            # test_local_stats=False: moving averages are used when not training.
            outputs = snt.BatchNorm()(outputs,
                                      is_training=is_training,
                                      test_local_stats=False)
        activation = self._activation_fn
        return outputs if activation is None else activation(outputs)
Example #2
0
    def _build(self, inputs, is_training):
        """Factorized conv: (1, h, w) Conv3D -> ReLU -> (t, 1, 1) Conv3D.

        `self.kernel_shape` and `self.stride` are read as (t, h, w). After the
        two convolutions an optional snt.BatchNormV2(scale=True) and an optional
        `self.activation_fn` are applied.

        NOTE(review): `regularizers` and `ones_initializer` are resolved from an
        enclosing scope not visible here. The temporal conv uses
        `ones_initializer` instead of `self.initializer` — confirm intentional.

        Args:
          inputs: tensor fed to the spatial Conv3D.
          is_training: bool forwarded to the batch norm.

        Returns:
          The resulting tensor.
        """
        t, h, w = self.kernel_shape[:3]
        t_stride, h_stride, w_stride = self.stride[:3]

        spatial = snt.Conv3D(output_channels=self.output_channels,
                             kernel_shape=(1, h, w),
                             stride=(1, h_stride, w_stride),
                             padding=snt.SAME,
                             initializers=self.initializer,
                             use_bias=self.use_bias,
                             regularizers=regularizers,
                             name='conv_3d')
        hidden = tf.nn.relu(spatial(inputs))

        temporal = snt.Conv3D(output_channels=self.output_channels,
                              kernel_shape=(t, 1, 1),
                              stride=(t_stride, 1, 1),
                              padding=snt.SAME,
                              initializers=ones_initializer,
                              use_bias=self.use_bias,
                              regularizers=regularizers,
                              name='conv_3d_temporal')
        hidden = temporal(hidden)

        if self.use_batch_norm:
            hidden = snt.BatchNormV2(scale=True)(hidden,
                                                 is_training=is_training,
                                                 test_local_stats=False)
        if self.activation_fn is None:
            return hidden
        return self.activation_fn(hidden)
Example #3
0
 def _build(self, inputs, block_function, softmax=False):
     """Relation block that mixes positions and adds the result to `inputs`.

     Projects `inputs` with a Conv3D to `self._output_channels // 2` channels,
     flattens axes 1-3 of that projection into one "positions" axis, and
     multiplies it by a relation matrix from
     `self._block_function[block_function](inputs)`. The product is reshaped
     back, projected to `self._output_channels` with a 1x1x1 Conv3D and added
     to `inputs`, or to a 1x1x1 projection of `inputs` when its channel count
     differs from `self._output_channels`.

     Args:
       inputs: 5-D tensor. The projection's batch, axis 1-3 and channel
         dimensions must be statically known, since `.value` is read from them.
       block_function: key into `self._block_function` selecting the relation
         function, which is called with `inputs`.
       softmax: if True, the relation is divided by its last dimension; if
         False, `tf.nn.softmax` is applied. NOTE(review): this looks inverted
         relative to the flag's name; it is kept as-is because callers may
         rely on it — confirm.

     Returns:
       Tensor with `self._output_channels` channels.

     Raises:
       ValueError: if `block_function` is not a key of `self._block_function`.
     """
     # Validate first so an unknown key fails with a clear message instead of
     # an UnboundLocalError on `relation` further down.
     if block_function not in self._block_function:
         raise ValueError('Unknown block_function %r; expected one of %r.' %
                          (block_function, list(self._block_function)))

     represent = snt.Conv3D(output_channels=self._output_channels // 2,
                            kernel_shape=self._kernel_shape,
                            stride=self._stride,
                            use_bias=self._use_bias)(inputs)
     shape_0, shape_1, shape_2, shape_3, _ = represent.shape
     # [B, D1, D2, D3, C] -> [B, D1*D2*D3, C]
     represent = tf.reshape(represent, [
         shape_0.value, shape_1.value * shape_2.value * shape_3.value,
         represent.shape[-1].value
     ])

     relation = self._block_function[block_function](inputs)
     if softmax:
         factor = relation.shape[-1].value
         relation = relation / tf.cast(factor, tf.float32)
     else:
         relation = tf.nn.softmax(relation)

     response = tf.matmul(relation, represent)
     response = tf.reshape(response,
                           [shape_0, shape_1, shape_2, shape_3, -1])
     fg = snt.Conv3D(output_channels=self._output_channels,
                     kernel_shape=[1, 1, 1])(response)

     if inputs.shape[-1].value == self._output_channels:
         return inputs + fg
     # Channel counts differ: project the shortcut so the sum is well-formed.
     shortcut = snt.Conv3D(output_channels=self._output_channels,
                           kernel_shape=[1, 1, 1])(inputs)
     return shortcut + fg
 def __init__(self, out_size, post_gain, name=None):
     """Builds a 3D decoder residual block.

     `id_path` is a 1x1x1 Conv3D to `out_size` channels. `conv_block` is
     LayerNorm followed by four (ReLU, Conv3D) stages: 1x1x1, 3x3x3 and 3x3x3
     convs at `out_size // 4` channels, then a 3x3x3 conv back to `out_size`.

     Args:
       out_size: output channels; must be divisible by 4.
       post_gain: stored on `self.post_gain` (consumed outside this method).
       name: optional module name.
     """
     super(DecoderResBlock3D, self).__init__(name=name)
     assert out_size % 4 == 0
     self.out_size = out_size
     bottleneck = out_size // 4
     self.id_path = snt.Conv3D(self.out_size, 1, name='id_path')

     stack = [snt.LayerNorm(axis=-1, create_scale=True, create_offset=True,
                            name='layer_norm')]
     for channels, kernel, conv_name in ((bottleneck, 1, 'conv_1'),
                                         (bottleneck, 3, 'conv_2'),
                                         (bottleneck, 3, 'conv_3'),
                                         (out_size, 3, 'conv_4')):
         stack.append(tf.nn.relu)
         stack.append(snt.Conv3D(channels, kernel, padding="SAME",
                                 name=conv_name))
     self.conv_block = snt.Sequential(stack)
     self.post_gain = post_gain
Example #5
0
 def Dot_product(self, inputs):
     """Dot-product relation between two learned projections of `inputs`.

     Each projection is a separate Conv3D to `self._output_channels // 2`
     channels, flattened to [batch, positions, channels]. The relation is
     theta @ phi^T, giving one score per pair of positions.

     Args:
       inputs: input tensor. The batch and channel dimensions of each
         projection must be statically known (their `.value` is read).

     Returns:
       A [batch, positions, positions] tensor.
     """
     def _embed(x):
         # One independent projection per call (separate Conv3D variables).
         projected = snt.Conv3D(output_channels=self._output_channels // 2,
                                kernel_shape=self._kernel_shape,
                                stride=self._stride,
                                use_bias=self._use_bias)(x)
         # Collapse all middle axes into a single positions axis.
         return tf.reshape(
             projected,
             [projected.shape[0].value, -1, projected.shape[-1].value])

     theta = _embed(inputs)
     phi = _embed(inputs)
     return tf.matmul(theta, tf.transpose(phi, [0, 2, 1]))
Example #6
0
 def _build(self, inputs, is_training):
     """Residual block of two Unit_2D1D layers with batch normalization.

     Main path: Unit_2D1D(in -> out) -> BN -> ReLU -> Unit_2D1D(out -> out) -> BN.
     Shortcut: `inputs` unchanged, or a 1x1x1 stride-2 Conv3D + BN when the
     channel counts differ. Returns ReLU(main + shortcut).

     NOTE(review): the projection shortcut always uses stride 2, while the
     Unit_2D1D calls here pass no stride; the sum only type-checks if
     Unit_2D1D downsamples to the same shape — confirm.

     Args:
       inputs: block input tensor.
       is_training: bool forwarded to Unit_2D1D and batch normalization.

     Returns:
       The activated residual sum.
     """
     def _bn(x, bn_name):
         return tf.layers.batch_normalization(x, training=is_training,
                                              name=bn_name)

     main = Unit_2D1D(self._in_channels,
                      self._out_channels,
                      kernels=[3, 3, 3],
                      name='1')(inputs, is_training)
     main = tf.nn.relu(_bn(main, 'spatbn_1'))
     main = Unit_2D1D(self._out_channels,
                      self._out_channels,
                      kernels=[3, 3, 3],
                      name='2')(main, is_training)
     main = _bn(main, 'spatbn_2')

     shortcut = inputs
     if self._in_channels != self._out_channels:
         projection = snt.Conv3D(output_channels=self._out_channels,
                                 kernel_shape=[1, 1, 1],
                                 stride=[2, 2, 2],
                                 padding=snt.SAME,
                                 use_bias=False,
                                 name='shortcut_projection')
         shortcut = _bn(projection(shortcut), 'shortcut_projection_spatbn')

     return tf.nn.relu(main + shortcut)
Example #7
0
    def _build(self, inputs, is_training):
        """Bottleneck residual block: Unit3D -> SepConv -> Conv3D -> BN, plus `inputs`.

        The three stages emit `self._output_channels[0]`, `[1]` and `[2]`
        channels. The batch-normalized output is added to `inputs` with
        `layers.add` (no projection shortcut), and `self._activation_fn` is
        applied to the sum.

        Args:
          inputs: block input; presumably must match the last stage's output
            shape for the addition — TODO confirm.
          is_training: forwarded to Unit3D, SepConv and snt.BatchNorm.

        Returns:
          The activated residual sum.
        """
        prefix = self.name
        reduced = Unit3D(output_channels=self._output_channels[0],
                         kernel_shape=[1, 1, 1],
                         padding=self.padding,
                         use_bias=self._use_bias,
                         name=prefix + "_1")(inputs, is_training=is_training)
        mixed = SepConv(output_channels=self._output_channels[1],
                        kernel_shape=self.kernel_shape,
                        padding=snt.SAME,
                        name=prefix + "_2")(reduced, is_training=is_training)
        expanded = snt.Conv3D(output_channels=self._output_channels[2],
                              kernel_shape=[1, 1, 1],
                              padding=self.padding,
                              use_bias=self._use_bias,
                              name=prefix + "_3")(mixed)
        expanded = snt.BatchNorm()(expanded,
                                   is_training=is_training,
                                   test_local_stats=False)
        return self._activation_fn(layers.add([expanded, inputs]))
Example #8
0
 def Concatenation(self, inputs):
     """Concatenation-based pairwise relation (not implemented).

     Args:
       inputs: tensor the relation would be computed from.

     Raises:
       NotImplementedError: always. Returning None would only surface later
         as an unrelated error wherever the result is used.
     """
     # TODO: implement. A concatenation relation needs one score per pair of
     # flattened positions (i, j). For example, broadcast the two projected
     # embeddings to [batch, N, 1, C] and [batch, 1, N, C], concatenate on
     # channels, and project each pair to a scalar. That yields a
     # [batch, N, N] relation of the same shape `Dot_product` returns.
     raise NotImplementedError(
         'Concatenation relation is not implemented yet; '
         'use another block_function.')
Example #9
0
 def _build(self,
            inputs,
            is_training,
            is_spatial_first=True,
            is_separate_pooling=False):
     """Factorized 3D convolution with optional intermediate max-pooling.

     Applies a (1, x, y) Conv3D and a (t, 1, 1) Conv3D in the order selected
     by `is_spatial_first`. The second conv consumes the first one's output.
     When `is_separate_pooling` is set, that output is first max-pooled with
     ksize (1, 1, 3, 3, 1) / strides (1, 1, 2, 2, 1), i.e. over axes 2 and 3
     (H and W under tf.nn.max_pool3d's default NDHWC layout). Optional
     snt.BatchNorm and activation follow.

     Args:
       inputs: 5-D input tensor.
       is_training: bool forwarded to snt.BatchNorm.
       is_spatial_first: apply the (1, x, y) conv before the (t, 1, 1) conv.
       is_separate_pooling: max-pool between the two convolutions.

     Returns:
       The resulting tensor.
     """
     def _spatial_conv(x):
         return snt.Conv3D(output_channels=self._output_channels,
                           kernel_shape=(1, self._x_shape, self._y_shape),
                           stride=(1, self._x_stride, self._y_stride),
                           padding=snt.SAME,
                           use_bias=self._use_bias)(x)

     def _temporal_conv(x):
         return snt.Conv3D(output_channels=self._output_channels,
                           kernel_shape=(self._temporal_shape, 1, 1),
                           stride=(self._temporal_stride, 1, 1),
                           padding=snt.SAME,
                           use_bias=self._use_bias)(x)

     if is_spatial_first:
         first_conv, second_conv = _spatial_conv, _temporal_conv
     else:
         first_conv, second_conv = _temporal_conv, _spatial_conv

     net = first_conv(inputs)
     if is_separate_pooling:
         net = tf.nn.max_pool3d(net,
                                ksize=(1, 1, 3, 3, 1),
                                strides=(1, 1, 2, 2, 1),
                                padding='SAME')
     # The second conv must consume the first conv's (possibly pooled)
     # output; feeding it `inputs` would silently discard the first conv
     # and the pooling.
     net = second_conv(net)

     if self._use_batch_norm:
         bn = snt.BatchNorm()
         net = bn(net, is_training=is_training, test_local_stats=False)
     if self._activation_fn is not None:
         net = self._activation_fn(net)
     return net
 def _build(self, inputs, is_training):
   """Conv3D (SAME padding) -> optional snt.BatchNorm -> optional activation.

   Args:
     inputs: tensor fed to snt.Conv3D.
     is_training: bool forwarded to snt.BatchNorm (test_local_stats=False,
       so moving averages are used when not training).

   Returns:
     The resulting tensor.
   """
   conv = snt.Conv3D(output_channels=self._output_channels,
                     kernel_shape=self._kernel_shape,
                     stride=self._stride,
                     padding=snt.SAME,
                     use_bias=self._use_bias)
   outputs = conv(inputs)
   if self._use_batch_norm:
     norm = snt.BatchNorm()
     outputs = norm(outputs, is_training=is_training, test_local_stats=False)
   if self._activation_fn is None:
     return outputs
   return self._activation_fn(outputs)
 def _build(self, inputs, is_training):
     """Applies Conv3D, then optional batch norm, then optional activation.

     Args:
       inputs: tensor fed to snt.Conv3D (SAME padding, `self._stride`).
       is_training: bool forwarded to snt.BatchNorm; with
         test_local_stats=False, moving averages are used when it is False.

     Returns:
       The resulting tensor.
     """
     conv_out = snt.Conv3D(output_channels=self._output_channels,
                           kernel_shape=self._kernel_shape,
                           stride=self._stride,
                           padding=snt.SAME,
                           use_bias=self._use_bias)(inputs)
     normed = (snt.BatchNorm()(conv_out,
                               is_training=is_training,
                               test_local_stats=False)
               if self._use_batch_norm else conv_out)
     return (normed if self._activation_fn is None
             else self._activation_fn(normed))
Example #12
0
 def _build(self, inputs, is_training):
     """Factorized conv: (1, kh, kw) Conv3D -> BN -> ReLU -> (t, 1, 1) Conv3D.

     The intermediate width M follows the parameter-matching rule from the
     R(2+1)D paper, with (t, kh, kw) = self._kernels:

         M = floor(t*kh*kw*N_in*N_out / (kh*kw*N_in + t*N_out))

     With this M, the factorized pair has M*(kh*kw*N_in + t*N_out) weights,
     about as many as one full t x kh x kw conv (t*kh*kw*N_in*N_out).

     Args:
       inputs: tensor fed to the spatial Conv3D.
       is_training: bool forwarded to tf.layers.batch_normalization.

     Returns:
       Output of the (t, 1, 1) Conv3D with `self._out_filters` channels. No
       BN or activation is applied after it.
     """
     t, kh, kw = self._kernels[0], self._kernels[1], self._kernels[2]
     spatial_area = kh * kw
     # Use the configured temporal kernel size, since the temporal conv below
     # uses self._kernels[0]; a hard-coded 3 is wrong whenever t != 3. Integer
     # floor division keeps M exact.
     middle_filters = int(
         (t * spatial_area * self._in_filters * self._out_filters) //
         (spatial_area * self._in_filters + t * self._out_filters))

     net = snt.Conv3D(output_channels=middle_filters,
                      kernel_shape=[1, kh, kw],
                      stride=[1, self._strides[1], self._strides[2]],
                      padding=snt.SAME,
                      use_bias=False,
                      name='conv_middle')(inputs)
     net = tf.layers.batch_normalization(net, training=is_training,
                                         name='spatbn_middle')
     net = tf.nn.relu(net)

     net = snt.Conv3D(output_channels=self._out_filters,
                      kernel_shape=[t, 1, 1],
                      stride=[self._strides[0], 1, 1],
                      padding=snt.SAME,
                      use_bias=False,
                      name='conv')(net)
     return net
    def __init__(self, hidden_size, num_embeddings, num_groups=4, name=None):
        """Builds the 3D encoder stack.

        Layout:
          1. A kernel-7 input Conv3D with `hidden_size` channels.
          2. `num_groups` groups of EncoderResBlock3D, where group g is
             `hidden_size * 2**g` wide. Every group except the last ends in a
             max-pool (window 2, stride 2).
          3. ReLU plus a 1x1x1 Conv3D producing `num_embeddings` channels.

        Args:
          hidden_size: width of the input conv and of the first group.
          num_embeddings: output channels of the final conv.
          num_groups: number of residual groups.
          name: optional module name.
        """
        super(Encoder3D, self).__init__(name=name)
        # One stride-2 pool after each of the first num_groups - 1 groups.
        self.shrink_factor = 2 ** (num_groups - 1)
        blocks_per_group = 1
        post_gain = 1. / (num_groups * blocks_per_group) ** 2

        def _downsample(x):
            return tf.nn.max_pool3d(x, 2, strides=2, padding='SAME')

        def _make_group(group_idx):
            width = hidden_size * 2 ** group_idx
            members = [EncoderResBlock3D(width, post_gain, name=f'blk_{blk_idx}')
                       for blk_idx in range(blocks_per_group)]
            if group_idx < num_groups - 1:
                members.append(_downsample)
            return snt.Sequential(members, name=f'group_{group_idx}')

        stages = [snt.Conv3D(hidden_size, 7, padding="SAME", name='input_group')]
        stages.extend(_make_group(group_idx) for group_idx in range(num_groups))
        logits_conv = snt.Conv3D(num_embeddings, 1, padding="SAME",
                                 name='logits_conv')
        stages.append(snt.Sequential([tf.nn.relu, logits_conv],
                                     name='output_group'))

        self.blocks = snt.Sequential(stages, name='groups')
def conv_3d(inputs,
            output_channels,
            kernel_shape,
            strides,
            name,
            activation=tf.nn.relu,
            use_bias=True):
    """Applies a Sonnet snt.Conv3D, then `activation`.

    Args:
      inputs: tensor fed to the convolution.
      output_channels: number of output channels.
      kernel_shape: convolution kernel shape.
      strides: passed as the Conv3D `stride`.
      name: Sonnet module name.
      activation: callable applied to the conv output (default tf.nn.relu).
      use_bias: whether the convolution adds a bias.

    Returns:
      `activation(conv(inputs))`.
    """
    conv = snt.Conv3D(output_channels=output_channels,
                      kernel_shape=kernel_shape,
                      stride=strides,
                      use_bias=use_bias,
                      name=name)
    pre_activation = conv(inputs)
    return activation(pre_activation)
Example #15
0
 def _build(self, inputs, is_training):
     """Conv3D -> optional snt.BatchNorm(scale=self.use_scale) -> optional activation.

     NOTE(review): the Conv3D module is named 'conv_2d'. The name is kept
     because it presumably determines variable/checkpoint names.
     `regularizers` is resolved from an enclosing scope not visible here.

     Args:
       inputs: tensor fed to the convolution.
       is_training: bool forwarded to snt.BatchNorm.

     Returns:
       The resulting tensor.
     """
     conv = snt.Conv3D(output_channels=self.output_channels,
                       kernel_shape=self.kernel_shape,
                       stride=self.stride,
                       padding=snt.SAME,
                       name='conv_2d',
                       use_bias=self.use_bias,
                       regularizers=regularizers)
     outputs = conv(inputs)
     if self.use_batch_norm:
         outputs = snt.BatchNorm(scale=self.use_scale)(
             outputs, is_training=is_training, test_local_stats=False)
     activation = self.activation_fn
     if activation is not None:
         outputs = activation(outputs)
     return outputs
    def __init__(self, hidden_size, num_channels=1, num_groups=4, name=None):
        """Builds the 3D decoder stack.

        Layout:
          1. A 1x1x1 input Conv3D with `hidden_size // 2` channels.
          2. `num_groups` groups of DecoderResBlock3D, where group g is
             `hidden_size * 2**(num_groups - g - 1)` wide. Every group except
             the last ends with `upsample`, which is defined outside this
             method.
          3. ReLU plus a 1x1x1 Conv3D producing `num_channels * 2` channels.

        Args:
          hidden_size: base channel width.
          num_channels: the output conv emits twice this many channels.
          num_groups: number of residual groups.
          name: optional module name.
        """
        super(Decoder3D, self).__init__(name=name)
        self.shrink_factor = 2 ** (num_groups - 1)
        blocks_per_group = 1
        post_gain = 1. / (num_groups * blocks_per_group) ** 2

        def _make_group(group_idx):
            width = hidden_size * 2 ** (num_groups - group_idx - 1)
            members = [DecoderResBlock3D(width, post_gain, name=f'blk_{blk_idx}')
                       for blk_idx in range(blocks_per_group)]
            if group_idx < num_groups - 1:
                members.append(upsample)
            return snt.Sequential(members, name=f'group_{group_idx}')

        stages = [snt.Conv3D(hidden_size // 2, 1, padding="SAME", name='input_group')]
        stages += [_make_group(group_idx) for group_idx in range(num_groups)]
        params_conv = snt.Conv3D(num_channels * 2, 1, padding="SAME",
                                 name='likelihood_params_conv')
        stages.append(snt.Sequential([tf.nn.relu, params_conv],
                                     name='output_group'))

        self.blocks = snt.Sequential(stages, name='groups')
Example #17
0
 def func(name, data_format, custom_getter=None):
   """Builds snt.Sequential([Conv3D, BatchNorm in training mode]).

   Reads `self`, `use_bias` and `create_initializers` from the enclosing
   scope, which is not visible here.

   Args:
     name: name of the Conv3D module.
     data_format: "NDHWC", or anything else treated as channels-first, which
       selects the BatchNorm reduction axes.
     custom_getter: forwarded to snt.Conv3D.

   Returns:
     snt.Sequential applying the conv and then the batch norm with
     is_training=True.
   """
   conv = snt.Conv3D(
       name=name,
       output_channels=self.OUT_CHANNELS,
       kernel_shape=self.KERNEL_SHAPE,
       use_bias=use_bias,
       initializers=create_initializers(use_bias),
       data_format=data_format,
       custom_getter=custom_getter)
   bn_kwargs = {"scale": True, "update_ops_collection": None}
   if data_format != "NDHWC":
     # Channels-first ("NCDHW"): reduce over every axis except channel axis 1.
     bn_kwargs["axis"] = (0, 2, 3, 4)
   batch_norm = snt.BatchNorm(**bn_kwargs)
   training_bn = functools.partial(batch_norm, is_training=True)
   return snt.Sequential([conv, training_bn])
Example #18
0
 def __call__(self, inputs, is_training):
     """Connects the module to inputs.

     Applies snt.Conv3D ("SAME" padding), then snt.BatchNorm when
     `self._use_batch_norm`, then `self._activation_fn` when set. The
     submodules are created on the first call and reused afterwards, so their
     variables are shared across calls.

     Args:
       inputs: Inputs to the Unit3D component.
       is_training: whether to use training mode for snt.BatchNorm (boolean).

     Returns:
       Outputs from the module.
     """
     # Sonnet 2 modules own their variables. Building new ones on every call
     # would return freshly initialised, never-trained weights each time, and
     # tf.function rejects variable creation after the first trace.
     conv = getattr(self, '_unit3d_conv', None)
     if conv is None:
         conv = snt.Conv3D(output_channels=self._output_channels,
                           kernel_shape=self._kernel_shape,
                           stride=self._stride,
                           padding="SAME",
                           with_bias=self._use_bias)
         self._unit3d_conv = conv
     net = conv(inputs)

     if self._use_batch_norm:
         bn = getattr(self, '_unit3d_bn', None)
         if bn is None:
             bn = snt.BatchNorm(create_scale=True, create_offset=True)
             self._unit3d_bn = bn
         # test_local_stats: if not, moving averages are used.
         net = bn(net, is_training=is_training, test_local_stats=False)

     if self._activation_fn is not None:
         net = self._activation_fn(net)
     return net
Example #19
0
    def _build(self, inputs, is_training):
        """Connects the module to inputs.

        Applies snt.Conv3D (SAME padding), then snt.BatchNorm when
        `self._use_batch_norm`, then `self._activation_fn` when it is set.

        Args:
          inputs: Inputs to the Unit3D component.
          is_training: whether to use training mode for snt.BatchNorm (boolean).

        Returns:
          Outputs from the module.
        """
        conv_out = snt.Conv3D(
            output_channels=self._output_channels,
            kernel_shape=self._kernel_shape,
            stride=self._stride,
            padding=snt.SAME,
            use_bias=self._use_bias,
        )(inputs)

        normalized = conv_out
        if self._use_batch_norm:
            # test_local_stats=False: moving averages are used outside training.
            normalized = snt.BatchNorm()(conv_out,
                                         is_training=is_training,
                                         test_local_stats=False)

        act = self._activation_fn
        if act is not None:
            return act(normalized)
        return normalized
Example #20
0
 def _build(self, inputs, is_training):
     """Separable 2D conv or Conv3D, then optional BatchNorm and activation.

     NOTE(review): the depthwise branch uses snt.SeparableConv2D while the
     other branch uses snt.Conv3D. If `inputs` is 5-D, as Conv3D requires, the
     2-D separable conv will likely reject it — confirm the expected rank.
     `regularizers` is resolved from an enclosing scope not visible here.

     Args:
       inputs: tensor fed to the selected convolution.
       is_training: bool forwarded to snt.BatchNorm.

     Returns:
       The resulting tensor.
     """
     if self.depthwise:
         conv = snt.SeparableConv2D(output_channels=self.output_channels,
                                    channel_multiplier=8,
                                    kernel_shape=self.kernel_shape,
                                    stride=self.stride,
                                    padding=snt.SAME,
                                    use_bias=self.use_bias)
     else:
         conv = snt.Conv3D(output_channels=self.output_channels,
                           kernel_shape=self.kernel_shape,
                           stride=self.stride,
                           padding=snt.SAME,
                           initializers=self.initializer,
                           use_bias=self.use_bias,
                           regularizers=regularizers,
                           name='conv_2d')
     features = conv(inputs)

     if not self.use_batch_norm:
         normalized = features
     else:
         normalized = snt.BatchNorm(scale=self.use_scale)(
             features, is_training=is_training, test_local_stats=False)

     if self.activation_fn is None:
         return normalized
     return self.activation_fn(normalized)
def conv_1x1x1(inputs, channels, name):
    """Applies a pointwise (1x1x1) snt.Conv3D with stride 1 and SAME padding.

    Args:
      inputs: tensor fed to the convolution.
      channels: number of output channels.
      name: Sonnet module name.

    Returns:
      The convolution output.
    """
    pointwise = snt.Conv3D(output_channels=channels,
                           kernel_shape=(1, 1, 1),
                           stride=1,
                           padding=snt.SAME,
                           name=name)
    return pointwise(inputs)
Example #22
0
 def _build(self,
            inputs,
            is_training):
     """R(2+1)D-style network: factorized stem, 16 residual blocks, average pool.

     Stem: a [1, 7, 7] stride-[1, 2, 2] Conv3D (45 channels), then a [3, 1, 1]
     Conv3D (64 channels), each followed by batch norm and ReLU. Residual
     stages conv_2x..conv_5x have 3, 4, 6 and 3 R21D_Block layers with widths
     64, 128, 256 and 512. The first block of each stage changes the width.
     The result is average-pooled (VALID) with the configured final kernels,
     squeezed on axes 2 and 3, and averaged over axis 1.

     Args:
       inputs: network input tensor.
       is_training: bool forwarded to batch norm and every R21D_Block.

     Returns:
       The averaged features (named `logits` internally).
     """
     def _bn_relu(x, bn_name):
         x = tf.layers.batch_normalization(x, training=is_training, name=bn_name)
         return tf.nn.relu(x)

     # Decomposed stem.
     net = snt.Conv3D(output_channels=45,
                      kernel_shape=[1, 7, 7],
                      stride=[1, 2, 2],
                      padding=snt.SAME,
                      use_bias=False,
                      name='conv1_middle')(inputs)
     net = _bn_relu(net, 'conv1_middle_spatbn_relu')
     net = snt.Conv3D(output_channels=64,
                      kernel_shape=[3, 1, 1],
                      stride=[1, 1, 1],
                      padding=snt.SAME,
                      use_bias=False,
                      name='conv1')(net)
     net = _bn_relu(net, 'conv1_spatbn_relu')

     # (input width, output width, number of blocks) for conv_2x .. conv_5x.
     stages = ((64, 64, 3), (64, 128, 4), (128, 256, 6), (256, 512, 3))
     block_idx = 0
     for stage_in, stage_out, num_blocks in stages:
         for position in range(num_blocks):
             in_width = stage_in if position == 0 else stage_out
             net = R21D_Block(in_width, stage_out,
                              name='comp_%d' % block_idx)(net, is_training)
             block_idx += 1

     pool_window = [
         self._final_temporal_kernel,
         self._final_spatial_kernel,
         self._final_spatial_kernel,
     ]
     net = tf.nn.pool(net,
                      window_shape=pool_window,
                      pooling_type="AVG",
                      strides=[1, 1, 1],
                      padding='VALID')
     # Remove axes 2 and 3 (tf.squeeze requires them to be size 1), then
     # average over axis 1 (presumably time).
     logits = tf.squeeze(net, [2, 3], name='SpatialSqueeze')
     return tf.reduce_mean(logits, axis=1)
     
     
    def _build(self, inputs, is_training=True):
        """Internal method to build the sonnet module.

        Six levels of `block` calls (levels 3-6 preceded by a 1x1x1 conv),
        with max-pooling after levels 1-5, then a final 1x1x1 Conv3D. The
        shape after each level is printed.

        Args:
          inputs: tensor of batch input OCT or dense segmentation maps.
                  OCT shape: [batch, 41, 450, 450, 1]
                  Segmentation map shape: [batch, 41, 450, 450, 17]
          is_training: flag for model usage when training. NOTE(review): not
            read anywhere in this method.

        Returns:
          Output tensor of the final 1x1x1 conv, with
          `self._bottleneck_chs * 4` channels. NOTE(review): nothing here maps
          the output to a number of classes — confirm against callers.
        """
        in_plane = (1, 2, 2)   # pool only the last two spatial axes
        all_axes = (2, 2, 2)

        def _pool(x, size, pool_name):
            # Non-overlapping max-pool: stride equals the window size.
            return max_pool3d(x, pool_size=size, strides=size, name=pool_name)

        # Level 1.
        net = block(inputs,
                    'l1',
                    self._filter_chs // 4,
                    block_kernels=[(1, 3, 3), (1, 3, 3)])
        net = _pool(net, in_plane, 'l1_out')
        print('Shape after L1: %s' % net.shape.as_list())

        # Level 2.
        net = block(net, 'l2', channels_per_layer=self._filter_chs // 2)
        net = _pool(net, in_plane, 'l2_out')
        print('Shape after L2: %s' % net.shape.as_list())

        # Level 3.
        net = conv_1x1x1(net, self._bottleneck_chs * 4, 'l3_1x1x1')
        net = block(net, 'l3', channels_per_layer=self._filter_chs // 2)
        net = _pool(net, all_axes, 'l3_out')
        print('Shape after L3 level: %s' % net.shape.as_list())

        # Level 4 (block names are 1-based: l4_b1, l4_b2).
        net = conv_1x1x1(net, self._bottleneck_chs * 4, 'l4_1x1x1')
        for suffix in (1, 2):
            net = block(net, 'l4_b%d' % suffix,
                        channels_per_layer=self._filter_chs)
        net = _pool(net, all_axes, 'l4_out')
        print('Shape after L4 level: %s' % net.shape.as_list())

        # Level 5 (block names are 0-based: l5_b0, l5_b1).
        net = conv_1x1x1(net, self._bottleneck_chs * 4, 'l5_1x1x1')
        for suffix in (0, 1):
            net = block(net, 'l5_b%d' % suffix,
                        channels_per_layer=self._filter_chs)
        net = _pool(net, all_axes, 'l5_out')
        print('Shape after L5 level: %s' % net.shape.as_list())

        # Level 6 (no pooling afterwards).
        net = conv_1x1x1(net, self._bottleneck_chs * 8, 'l6_1x1x1')
        for suffix in (0, 1):
            net = block(net, 'l6_b%d' % suffix,
                        channels_per_layer=self._filter_chs)
        print('Shape after L6 level: %s' % net.shape.as_list())

        # Output projection.
        head = snt.Conv3D(output_channels=self._bottleneck_chs * 4,
                          kernel_shape=(1, 1, 1),
                          stride=1,
                          padding=snt.SAME,
                          name='final_1x1x1')
        net = head(net)
        print('Output shape: %s' % net.shape.as_list())
        return net
  def _build(self, inputs, is_training):
    """Adds the network into the graph.

    Layers:
      1. One kernel-4, stride-2 Conv3D per entry of
         `self._output_channels_list`. Each is followed by snt.BatchNorm
         (optional for the first layer, via `self._use_input_batchnorm`) and
         leaky ReLU with alpha 0.2.
      2. A stride-1 Conv3D to `self._latent_size` channels, then BatchNorm.
      3. `self._activation`, if one is set.

    Args:
      inputs: The network input, laid out according to `self._data_format`.
        NOTE(review): snt.Conv3D is used throughout, so the input is
        presumably 5-D rather than [batch, channels, height, width] — confirm.
      is_training: True if running in training mode, False otherwise.

    Returns:
      outputs: The network output, with `self._latent_size` channels.
    """
    conv_common = dict(kernel_shape=4,
                       rate=1,
                       padding=snt.SAME,
                       use_bias=False,
                       data_format=self._data_format,
                       initializers=self._initializers,
                       regularizers=self._regularizers)

    def batch_norm(x):
      return snt.BatchNorm(axis=self._batchnorm_axis)(
          x, is_training, test_local_stats=False)

    # DCGAN-style downsampling (5 layers for 128x128): 128->64->32->16->8->4.
    net = snt.Conv3D(name="first_layer",
                     output_channels=self._output_channels_list[0],
                     stride=2,
                     **conv_common)(inputs)
    if self._use_input_batchnorm:
      net = batch_norm(net)
    net = tf.nn.leaky_relu(net, alpha=0.2)

    for channels in self._output_channels_list[1:]:
      net = snt.Conv3D(output_channels=channels, stride=2, **conv_common)(net)
      net = tf.nn.leaky_relu(batch_norm(net), alpha=0.2)

    # Output layer: no downsampling.
    net = snt.Conv3D(name="last_layer",
                     output_channels=self._latent_size,
                     stride=1,
                     **conv_common)(net)
    net = batch_norm(net)

    return net if self._activation is None else self._activation(net)