    def test_residual_block_ordering(self):
        inputs = tf.constant([[1.0, -1.0], [0.5, -1.5]])

        inner_layer = tf.keras.layers.ReLU()
        normalization_layer = tf.keras.layers.Lambda(lambda x: 2 * x)

        residual_block_default_order = readtwice_layers.ResidualBlock(
            inner_layer=inner_layer,
            normalization_layer=normalization_layer,
            use_pre_activation_order=False)
        default_order_result = residual_block_default_order(inputs)

        residual_block_pre_act_order = readtwice_layers.ResidualBlock(
            inner_layer=inner_layer,
            normalization_layer=normalization_layer,
            use_pre_activation_order=True)
        pre_act_order_result = residual_block_pre_act_order(inputs)

        self.evaluate(tf.compat.v1.global_variables_initializer())

        self.assertAllClose([[4.0, -2.0], [2.0, -3.0]], default_order_result)

        self.assertAllClose([[3.0, -1.0], [1.5, -1.5]], pre_act_order_result)
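
# Standalone sketch (not part of the original test class) of the two orderings the
# assertions above encode, with inner = ReLU, norm = (lambda x: 2 * x), and
# dropout ignored:
#   default (post-activation) order:  norm(x + inner(x))
#   pre-activation order:             x + inner(norm(x))
import tensorflow as tf

x = tf.constant([[1.0, -1.0], [0.5, -1.5]])
relu = tf.keras.layers.ReLU()
norm = lambda t: 2.0 * t

post_activation = norm(x + relu(x))  # [[4.0, -2.0], [2.0, -3.0]]
pre_activation = x + relu(norm(x))   # [[3.0, -1.0], [1.5, -1.5]]
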
    def test_residual_block_training_vs_inference_dropout(
            self, use_pre_activation_order):
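        # Note: `use_pre_activation_order` is presumably supplied by a
        # test-parameterization decorator that is not shown in this excerpt.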
        tf.compat.v1.random.set_random_seed(1234)
        np.random.seed(1234)

        batch_size = 3
        input_size = 10
        inputs = tf.constant(np.random.normal(size=[batch_size, input_size]))

        residual_block = readtwice_layers.ResidualBlock(
            dropout_probability=0.5,
            use_pre_activation_order=use_pre_activation_order)

        inference_output1 = residual_block(inputs, training=False)
        inference_output2 = residual_block(inputs, training=False)
        self.evaluate(tf.compat.v1.global_variables_initializer())
        self.assertAllClose(inference_output1, inference_output2)

        # Dropout makes this non-deterministic.
        training_output1 = residual_block(inputs, training=True)
        training_output2 = residual_block(inputs, training=True)
        self.assertNotAllClose(training_output1, training_output2)
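
# Standalone illustration (same idea as the assertion above, assuming the block's
# dropout behaves like a stock Keras Dropout layer): dropout is the identity at
# inference time but applies a fresh random mask on every training call.
import tensorflow as tf

dropout = tf.keras.layers.Dropout(rate=0.5)
x = tf.ones([2, 4])
inference_out = dropout(x, training=False)  # equal to x
train_out1 = dropout(x, training=True)      # randomly zeroed, rescaled by 1 / 0.5
train_out2 = dropout(x, training=True)      # a different random mask
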
    def test_residual_block_training_vs_inference_normalization_layer(
            self, use_pre_activation_order):
        np.random.seed(1234)

        batch_size = 3
        input_size = 10
        inputs = tf.constant(np.random.normal(size=[batch_size, input_size]))

        residual_block = readtwice_layers.ResidualBlock(
            normalization_layer=tf.keras.layers.BatchNormalization(),
            dropout_probability=0.0,
            use_pre_activation_order=use_pre_activation_order)

        inference_output1 = residual_block(inputs, training=False)
        inference_output2 = residual_block(inputs, training=False)
        self.evaluate(tf.compat.v1.global_variables_initializer())
        self.assertAllClose(inference_output1, inference_output2)

        training_output1 = residual_block(inputs, training=True)
        training_output2 = residual_block(inputs, training=True)
        self.assertAllClose(training_output1, training_output2)

        # Batch normalization gives different results for training vs. inference.
        self.assertNotAllClose(inference_output1, training_output1)
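
# Standalone illustration of the last assertion: BatchNormalization normalizes
# with batch statistics when training=True and with its moving averages when
# training=False, so the two modes generally produce different outputs.
import numpy as np
import tensorflow as tf

bn = tf.keras.layers.BatchNormalization()
x = tf.constant(np.random.normal(size=[3, 10]), dtype=tf.float32)
train_out = bn(x, training=True)   # uses this batch's mean and variance
infer_out = bn(x, training=False)  # uses the (initially 0 / 1) moving statistics
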
Example #4
    def __init__(self,
                 config,
                 use_one_hot_embeddings=False,
                 name="read_it_twice_bert",
                 **kwargs):
        """Constructor for ReadItTwiceBertModel.

        Args:
          config: `model_config.ReadItTwiceBertConfig` instance.
          use_one_hot_embeddings: (optional) bool. Whether to use one-hot word
            embeddings or tf.nn.embedding_lookup() for the word embeddings.
          name: (Optional) name of the layer.
          **kwargs: Forwarded to super.

        Raises:
          ValueError: The config is invalid.
        """
        super(ReadItTwiceBertModel, self).__init__(name=name, **kwargs)

        self.use_one_hot_embeddings = use_one_hot_embeddings

        if config.cross_attention_top_k is not None:
            assert config.second_read_type == "cross_attend_once"

        if config.embedding_size is None:
            config = dataclasses.replace(config,
                                         embedding_size=config.hidden_size)

        self.config = config

        self.token_embedding = readtwice_layers.EmbeddingLookup(
            vocab_size=config.vocab_size,
            embedding_size=config.embedding_size,
            projection_size=config.hidden_size,
            initializer_range=config.initializer_range,
            use_one_hot_lookup=use_one_hot_embeddings,
            name="token_emb_lookup")

        self.token_embedding_norm = tf.keras.layers.LayerNormalization(
            axis=-1, epsilon=1e-12, name="emb_layer_norm")
        self.token_embedding_dropout = tf.keras.layers.Dropout(
            rate=config.hidden_dropout_prob)

        self.position_embedding = readtwice_layers.EmbeddingLookup(
            vocab_size=config.max_seq_length,
            embedding_size=config.hidden_size,
            initializer_range=config.initializer_range,
            use_one_hot_lookup=use_one_hot_embeddings,
            name="position_emb_lookup_long")
        # Call layers to force variable initialization.
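        # (A Keras layer only creates its variables on its first call, so this
        # dummy call builds the position embedding table right away.)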
        self.position_embedding(tf.ones([1, 1], tf.int32))

        if config.cross_attention_pos_emb_mode is not None:
            # We would end up adding block position embeddings multiple times.
            assert config.summary_postprocessing_type not in [
                "pos", "transformer"
            ]

        if config.second_read_type == "from_scratch":
            share_kv_projections_first_read = config.share_kv_projections
        else:
            # Summaries are not going to be used by the first read model anyway.
            share_kv_projections_first_read = True

        self.transformer_with_side_inputs = readtwice_layers.TransformerWithSideInputLayers(
            hidden_size=config.hidden_size,
            num_hidden_layers=config.num_hidden_layers,
            num_attention_heads=config.num_attention_heads,
            intermediate_size=config.intermediate_size,
            hidden_act=tensor_utils.get_activation(config.hidden_act),
            hidden_dropout_prob=config.hidden_dropout_prob,
            attention_probs_dropout_prob=config.attention_probs_dropout_prob,
            initializer_range=config.initializer_range,
            share_kv_projections=share_kv_projections_first_read,
            name="transformer_layers")
        # grad_checkpointing_period=config.grad_checkpointing_period)

        self.summary_extraction = SummaryExtraction(
            config=config, use_one_hot_embeddings=use_one_hot_embeddings)

        if config.second_read_type == "new_layers":
            if config.second_read_num_new_layers is None:
                raise ValueError("Must specify `second_read_num_new_layers`"
                                 "when `second_read_type` is new_layers")

            self.second_read_transformer = readtwice_layers.TransformerWithSideInputLayers(
                hidden_size=config.hidden_size,
                num_hidden_layers=config.second_read_num_new_layers,
                num_attention_heads=config.num_attention_heads,
                intermediate_size=config.intermediate_size,
                hidden_act=tensor_utils.get_activation(config.hidden_act),
                hidden_dropout_prob=config.hidden_dropout_prob,
                attention_probs_dropout_prob=(
                    config.attention_probs_dropout_prob),
                initializer_range=config.initializer_range,
                share_kv_projections=config.share_kv_projections,
                name="transformer_layers")
        elif config.second_read_type == "cross_attend_once":
            if config.second_read_num_new_layers is None:
                raise ValueError(
                    "Must specify `second_read_num_new_layers`"
                    "when `second_read_type` is cross_attend_once")
            if config.second_read_num_cross_attention_heads is None:
                raise ValueError(
                    "Must specify `second_read_num_cross_attention_heads`"
                    "when `second_read_type` is cross_attend_once")
            if config.second_read_enable_default_side_input is None:
                raise ValueError(
                    "Must specify `second_read_enable_default_side_input`"
                    "when `second_read_type` is cross_attend_once")

            self.cross_attention_layer = readtwice_layers.ResidualBlock(
                inner_layer=readtwice_layers.SideAttention(
                    hidden_size=config.hidden_size,
                    num_heads=config.second_read_num_cross_attention_heads,
                    att_dropout_prob=0,
                    initializer=tf.keras.initializers.TruncatedNormal(
                        stddev=config.initializer_range),
                    top_k_attention=config.cross_attention_top_k,
                    pos_embed_mode=config.cross_attention_pos_emb_mode,
                    pos_embed_size=config.max_num_blocks_per_document,
                    use_one_hot_embeddings=use_one_hot_embeddings,
                    enable_default_side_input=(
                        config.second_read_enable_default_side_input)),
                dropout_probability=config.hidden_dropout_prob,
                use_pre_activation_order=False,
                name="cross_attention_layer")

            self.second_read_transformer = readtwice_layers.TransformerWithSideInputLayers(
                hidden_size=config.hidden_size,
                num_hidden_layers=config.second_read_num_new_layers,
                num_attention_heads=config.num_attention_heads,
                intermediate_size=config.intermediate_size,
                hidden_act=tensor_utils.get_activation(config.hidden_act),
                hidden_dropout_prob=config.hidden_dropout_prob,
                attention_probs_dropout_prob=(
                    config.attention_probs_dropout_prob),
                initializer_range=config.initializer_range,
                share_kv_projections=True,
                name="transformer_layers")
        elif config.second_read_type == "new_layers_cross_attention":
            if config.second_read_num_new_layers is None:
                raise ValueError(
                    "Must specify `second_read_num_new_layers`"
                    "when `second_read_type` is cross_attend_once")
            if config.second_read_num_cross_attention_heads is None:
                raise ValueError(
                    "Must specify `second_read_num_cross_attention_heads`"
                    "when `second_read_type` is cross_attend_once")
            if config.second_read_enable_default_side_input is None:
                raise ValueError(
                    "Must specify `second_read_enable_default_side_input`"
                    "when `second_read_type` is cross_attend_once")

            self.second_read_transformer = readtwice_layers.TransformerWithSideInputLayers(
                hidden_size=config.hidden_size,
                num_hidden_layers=config.second_read_num_new_layers,
                num_attention_heads=config.num_attention_heads,
                intermediate_size=config.intermediate_size,
                hidden_act=tensor_utils.get_activation(config.hidden_act),
                hidden_dropout_prob=config.hidden_dropout_prob,
                attention_probs_dropout_prob=(
                    config.attention_probs_dropout_prob),
                initializer_range=config.initializer_range,
                share_kv_projections=True,
                num_cross_attention_heads=(
                    config.second_read_num_cross_attention_heads),
                enable_default_side_input=(
                    config.second_read_enable_default_side_input),
                name="transformer_layers")
        else:
            if config.second_read_type != "from_scratch":
                raise ValueError("Unknown `second_read_type`: '{}'".format(
                    config.second_read_type))