# Example No. 1
# 0
def create_model(features,
                 mode,
                 bert_config,
                 disabled_features=None,
                 disable_position_embeddings=False):
    """Creates a TABLE BERT model.

    Args:
      features: Mapping from feature name to tensor. Must contain
        "input_ids", "input_mask" and each token-type feature name listed in
        the body of this function.
      mode: Estimator mode. The model is built with `is_training=True` only
        when this equals `tf.estimator.ModeKeys.TRAIN`.
      bert_config: Configuration forwarded unchanged to `modeling.BertModel`.
      disabled_features: Optional collection of token-type feature names. For
        each listed name, the feature tensor is replaced by
        `tf.zeros_like(features[name])` before it reaches the model.
      disable_position_embeddings: When True, the model is constructed with
        `use_position_embeddings=False`.

    Returns:
      The constructed `modeling.BertModel`.
    """
    training = mode == tf.estimator.ModeKeys.TRAIN

    def _type_ids_for(name):
        # Enabled features (or no disable list at all) pass through untouched.
        if disabled_features is None or name not in disabled_features:
            return features[name]
        # Disabled: same shape/dtype as the real feature, but all zeros.
        return tf.zeros_like(features[name])

    # NOTE(review): the order of these names presumably has to line up with
    # how `modeling.BertModel` consumes `token_type_ids` -- keep it stable.
    token_type_ids = [
        _type_ids_for(name)
        for name in ("segment_ids", "column_ids", "row_ids", "prev_label_ids",
                     "column_ranks", "inv_column_ranks", "numeric_relations")
    ]

    return modeling.BertModel(
        config=bert_config,
        is_training=training,
        input_ids=features["input_ids"],
        input_mask=features["input_mask"],
        token_type_ids=token_type_ids,
        use_position_embeddings=not disable_position_embeddings)
# Example No. 2
# 0
def create_model(
    features,
    mode,
    bert_config,
    restrict_attention_mode=_AttentionMode.FULL,
    restrict_attention_bucket_size=0,
    restrict_attention_header_size=None,
    restrict_attention_row_heads_ratio=0.5,
    restrict_attention_sort_after_projection=True,
    token_weights=None,
    disabled_features=None,
    disable_position_embeddings=False,
    reset_position_index_per_cell=False,
    proj_value_length=None,
    attention_bias_disabled=0,
    attention_bias_use_relative_scalar_only=True,
):
    """Creates a TABLE BERT model.

    Token-type ids (and the ids of disabled features) are produced by
    `get_token_type_ids`. The attention mask and optional custom attention
    layer come from `run_custom_attention`, configured with every
    `restrict_attention_*` / `attention_bias_*` argument. All of these are
    passed to `modeling.BertModel` together with the remaining arguments.

    Args:
      features: Mapping from feature name to tensor. Must contain at least
        "input_ids" and "input_mask". It is also handed to
        `get_token_type_ids` and `run_custom_attention`.
      mode: Estimator mode. Training mode is used only when this equals
        `tf_estimator.ModeKeys.TRAIN`.
      bert_config: BERT configuration. Its `num_attention_heads` is forwarded
        to `run_custom_attention`, and the object itself to the model.
      restrict_attention_mode, restrict_attention_bucket_size,
      restrict_attention_header_size, restrict_attention_row_heads_ratio,
      restrict_attention_sort_after_projection, attention_bias_disabled,
      attention_bias_use_relative_scalar_only: Fields of the
        `CustomAttentionConfig` passed to `run_custom_attention`.
      token_weights: Forwarded to `modeling.BertModel`.
      disabled_features: Forwarded to `get_token_type_ids`.
      disable_position_embeddings: When True, the model is constructed with
        `use_position_embeddings=False`.
      reset_position_index_per_cell: Forwarded to `modeling.BertModel`.
      proj_value_length: Forwarded to `modeling.BertModel`.

    Returns:
      The constructed `modeling.BertModel`.
    """
    training = mode == tf_estimator.ModeKeys.TRAIN

    type_ids, disabled_type_ids = get_token_type_ids(
        features=features,
        token_type_features=get_token_type_features(),
        disabled_features=disabled_features,
    )

    attention_config = CustomAttentionConfig(
        restrict_attention_mode=restrict_attention_mode,
        restrict_attention_row_heads_ratio=restrict_attention_row_heads_ratio,
        restrict_attention_bucket_size=restrict_attention_bucket_size,
        restrict_attention_header_size=restrict_attention_header_size,
        restrict_attention_sort_after_projection=(
            restrict_attention_sort_after_projection),
        attention_bias_disabled=attention_bias_disabled,
        attention_bias_use_relative_scalar_only=(
            attention_bias_use_relative_scalar_only),
    )
    attention = run_custom_attention(
        is_training=training,
        num_attention_heads=bert_config.num_attention_heads,
        config=attention_config,
        features=features,
    )

    return modeling.BertModel(
        config=bert_config,
        is_training=training,
        input_ids=features["input_ids"],
        input_mask=features["input_mask"],
        attention_mask=attention.attention_mask,
        custom_attention_layer=attention.custom_attention_layer,
        token_weights=token_weights,
        token_type_ids=type_ids,
        disabled_ids=disabled_type_ids,
        use_position_embeddings=not disable_position_embeddings,
        reset_position_index_per_cell=reset_position_index_per_cell,
        proj_value_length=proj_value_length,
    )
# Example No. 3
# 0
        def create_model(self):
            """Builds a `modeling.BertModel` on random ids and returns its outputs.

            Random id tensors of shape [batch_size, seq_length] are drawn with
            `BertModelTest.ids_tensor`. The input mask and token-type ids are
            only generated when `use_input_mask` / `use_token_type_ids` are
            set; otherwise None is passed.

            Returns:
              Dict with the keys "embedding_output", "sequence_output",
              "pooled_output" and "all_encoder_layers", filled from the
              model's corresponding getters.
            """
            def seq_shape():
                # Fresh list per call, mirroring one literal per tensor.
                return [self.batch_size, self.seq_length]

            input_ids = BertModelTest.ids_tensor(seq_shape(), self.vocab_size)
            # vocab_size=2 -> presumably a 0/1 mask; confirm in ids_tensor.
            input_mask = (
                BertModelTest.ids_tensor(seq_shape(), vocab_size=2)
                if self.use_input_mask else None)
            token_type_ids = (
                BertModelTest.ids_tensor(seq_shape(), self.type_vocab_size)
                if self.use_token_type_ids else None)

            # Each BertConfig keyword has a same-named attribute on self.
            config_fields = (
                "vocab_size",
                "hidden_size",
                "num_hidden_layers",
                "num_attention_heads",
                "intermediate_size",
                "hidden_act",
                "hidden_dropout_prob",
                "attention_probs_dropout_prob",
                "max_position_embeddings",
                "type_vocab_size",
                "initializer_range",
                "softmax_temperature",
            )
            config = modeling.BertConfig(
                **{field: getattr(self, field) for field in config_fields})

            bert = modeling.BertModel(
                config=config,
                is_training=self.is_training,
                input_ids=input_ids,
                input_mask=input_mask,
                token_type_ids=token_type_ids,
                scope=self.scope,
                proj_value_length=self.proj_value_length)

            return {
                "embedding_output": bert.get_embedding_output(),
                "sequence_output": bert.get_sequence_output(),
                "pooled_output": bert.get_pooled_output(),
                "all_encoder_layers": bert.get_all_encoder_layers(),
            }