Example #1
def build(self, input_shape):
    self._layers = []
    if self._num_proj_layers > 0:
        intermediate_dim = int(input_shape[-1])
        for j in range(self._num_proj_layers):
            if j != self._num_proj_layers - 1:
                # Intermediate layers use a bias term and a ReLU activation.
                layer = nn_blocks.DenseBN(
                    output_dim=intermediate_dim,
                    use_bias=True,
                    use_normalization=True,
                    activation='relu',
                    kernel_initializer=self._kernel_initializer,
                    kernel_regularizer=self._kernel_regularizer,
                    bias_regularizer=self._bias_regularizer,
                    use_sync_bn=self._use_sync_bn,
                    norm_momentum=self._norm_momentum,
                    norm_epsilon=self._norm_epsilon,
                    name='nl_%d' % j)
            else:
                # The final layer uses neither a bias term nor an activation.
                layer = nn_blocks.DenseBN(
                    output_dim=self._proj_output_dim,
                    use_bias=False,
                    use_normalization=True,
                    activation=None,
                    kernel_regularizer=self._kernel_regularizer,
                    kernel_initializer=self._kernel_initializer,
                    use_sync_bn=self._use_sync_bn,
                    norm_momentum=self._norm_momentum,
                    norm_epsilon=self._norm_epsilon,
                    name='nl_%d' % j)
            self._layers.append(layer)
    super(ProjectionHead, self).build(input_shape)
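This build method stacks DenseBN blocks into a projection head: each intermediate block applies a dense transform with bias, batch normalization, and ReLU, while the final block is a linear dense transform with batch normalization and no bias. For orientation, below is a minimal sketch of what such a DenseBN block could look like. It is an assumption inferred from the arguments above and from the variable counts in Example #2, not the actual nn_blocks.DenseBN implementation (which also accepts initializers, regularizers, and sync batch norm options):

import tensorflow as tf

class DenseBNSketch(tf.keras.layers.Layer):
    """Dense layer optionally followed by batch norm and an activation."""

    def __init__(self, output_dim, use_bias=True, use_normalization=False,
                 activation=None, **kwargs):
        super().__init__(**kwargs)
        # When normalization is enabled, a dense bias would be redundant with
        # the batch-norm beta, so the bias is expressed via `center` instead.
        self._dense = tf.keras.layers.Dense(
            output_dim, use_bias=use_bias and not use_normalization)
        self._norm = (tf.keras.layers.BatchNormalization(center=use_bias)
                      if use_normalization else None)
        self._activation = tf.keras.activations.get(activation)

    def call(self, inputs, training=None):
        x = self._dense(inputs)
        if self._norm is not None:
            x = self._norm(x, training=training)
        return self._activation(x)

Under this sketch the trainable variables are the dense kernel (always), the dense bias (only when normalization is off), the batch-norm gamma (whenever normalization is on), and beta (only when use_bias is True), which matches the counting in Example #2 below.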
Example #2
    def test_pass_through(self, output_dim, use_bias, use_normalization):
        test_layer = nn_blocks.DenseBN(output_dim=output_dim,
                                       use_bias=use_bias,
                                       use_normalization=use_normalization)

        x = tf.keras.Input(shape=(64,))
        out_x = test_layer(x)

        self.assertAllEqual(out_x.shape.as_list(), [None, output_dim])

        # The dense layer's kernel is always trainable.
        train_var_len = 1
        if use_normalization:
            if use_bias:
                # Batch norm adds two trainable variables (gamma and beta).
                train_var_len += 2
            else:
                # With use_bias=False, centering is disabled, so batch norm
                # adds only gamma.
                train_var_len += 1
        else:
            if use_bias:
                # Without normalization, the dense layer keeps its own bias.
                train_var_len += 1
        self.assertLen(test_layer.trainable_variables, train_var_len)
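The three test arguments suggest this method is driven by a parameterized test runner. Below is a minimal harness sketch, assuming absl's parameterized module; the class name, parameter grid, and import path are illustrative, not taken from the source:

from absl.testing import parameterized
import tensorflow as tf

# Hypothetical import path for the module under test.
from official.projects.simclr.modeling.layers import nn_blocks


class DenseBNTest(parameterized.TestCase, tf.test.TestCase):

    @parameterized.parameters(
        (128, True, True),
        (128, True, False),
        (128, False, True),
        (128, False, False),
    )
    def test_pass_through(self, output_dim, use_bias, use_normalization):
        # Body as in Example #2 above.
        ...


if __name__ == '__main__':
    tf.test.main()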