def build(self, input_shape):
  """Builds the projection-head layer stack.

  Creates `self._num_proj_layers` DenseBN layers: every layer except the
  last uses a bias and ReLU at `intermediate_dim` (the input's channel
  dim); the final layer projects to `self._proj_output_dim` with no bias
  and no activation.

  Args:
    input_shape: Shape of the head's input; only the last dimension is
      read to size the intermediate layers.
  """
  self._layers = []
  if self._num_proj_layers > 0:
    intermediate_dim = int(input_shape[-1])
    final_index = self._num_proj_layers - 1
    for j in range(self._num_proj_layers):
      # Arguments shared by every projection layer.
      shared_kwargs = dict(
          use_normalization=True,
          kernel_initializer=self._kernel_initializer,
          kernel_regularizer=self._kernel_regularizer,
          use_sync_bn=self._use_sync_bn,
          norm_momentum=self._norm_momentum,
          norm_epsilon=self._norm_epsilon,
          name='nl_%d' % j)
      if j != final_index:
        # Middle layers: biased, ReLU-activated, same width as the input.
        layer = nn_blocks.DenseBN(
            output_dim=intermediate_dim,
            use_bias=True,
            activation='relu',
            bias_regularizer=self._bias_regularizer,
            **shared_kwargs)
      else:
        # Final layer: projects to the output dim; no bias, no activation.
        layer = nn_blocks.DenseBN(
            output_dim=self._proj_output_dim,
            use_bias=False,
            activation=None,
            **shared_kwargs)
      self._layers.append(layer)
  super().build(input_shape)
def test_pass_through(self, output_dim, use_bias, use_normalization):
  """Verifies DenseBN output shape and its trainable-variable count."""
  layer_under_test = nn_blocks.DenseBN(
      output_dim=output_dim,
      use_bias=use_bias,
      use_normalization=use_normalization)
  inputs = tf.keras.Input(shape=(64,))
  outputs = layer_under_test(inputs)
  self.assertAllEqual(outputs.shape.as_list(), [None, output_dim])

  # The dense layer's kernel is always trainable.
  expected_var_count = 1
  if use_normalization:
    # Batch norm's gamma is always trainable; beta appears only when
    # use_bias=True (center is set to False otherwise).
    expected_var_count += 2 if use_bias else 1
  elif use_bias:
    # No normalization: only the dense layer's own bias is added.
    expected_var_count += 1
  self.assertLen(layer_under_test.trainable_variables, expected_var_count)