Пример #1
0
def get_config(version='base', batch_size=1):
    """
    Build a BertConfig for the requested model version.

    Args:
        version: 'base' or 'large' selects one of the presets below; any
            other value yields a BertConfig built from ``batch_size`` alone.
        batch_size: forwarded to BertConfig as ``batch_size``.

    Returns:
        The constructed BertConfig.
    """
    # Only these four sizes differ between the two presets:
    # (hidden_size, num_hidden_layers, num_attention_heads, intermediate_size)
    if version == 'base':
        sizes = (768, 12, 12, 3072)
    elif version == 'large':
        sizes = (1024, 24, 16, 4096)
    else:
        # Unknown version: rely on BertConfig's own defaults.
        return BertConfig(batch_size=batch_size)

    hidden, layers, heads, intermediate = sizes
    return BertConfig(
        batch_size=batch_size,
        seq_length=128,
        vocab_size=21128,
        hidden_size=hidden,
        num_hidden_layers=layers,
        num_attention_heads=heads,
        intermediate_size=intermediate,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        use_relative_positions=True,
        input_mask_from_dataset=True,
        token_type_ids_from_dataset=True,
        dtype=mstype.float32,
        compute_type=mstype.float32)
Пример #2
0
    'decay_steps': 1000,
    'power': 10.0,
    'save_checkpoint_steps': 2000,
    'keep_checkpoint_max': 10,
    'checkpoint_prefix': "checkpoint_bert",
    # please add your own dataset path
    'DATA_DIR': "/your/path/examples.tfrecord",
    # please add your own dataset schema path
    'SCHEMA_DIR': "/your/path/datasetSchema.json"
})
# Network settings handed to BertConfig, grouped by theme.
bert_net_cfg = BertConfig(
    # Batch and sequence dimensions.
    batch_size=16,
    seq_length=128,
    max_position_embeddings=512,
    # Vocabulary sizes.
    vocab_size=21136,
    type_vocab_size=2,
    # Layer sizes.
    hidden_size=1024,
    num_hidden_layers=24,
    num_attention_heads=16,
    intermediate_size=4096,
    hidden_act="gelu",
    # Both dropout probabilities are set to zero here.
    hidden_dropout_prob=0.0,
    attention_probs_dropout_prob=0.0,
    initializer_range=0.02,
    # Feature switches.
    use_relative_positions=True,
    input_mask_from_dataset=True,
    token_type_ids_from_dataset=True,
    # dtype is float32 while compute_type is float16.
    dtype=mstype.float32,
    compute_type=mstype.float16,
)
Пример #3
0
    return dataset


def load_test_data():
    """Return the result of calling ``next()`` on the object from ``get_dataset()``."""
    return get_dataset().next()


# Unpack the value returned by load_test_data() into seven module-level names.
(
    input_ids,
    input_mask,
    token_type_id,
    next_sentence_labels,
    masked_lm_positions,
    masked_lm_ids,
    masked_lm_weights,
) = load_test_data()

test_sets = [
    ('BertNetworkWithLoss_1', {
        'block':
        BertNetworkWithLoss(BertConfig(batch_size=1),
                            False,
                            use_one_hot_embeddings=True),
        'desc_inputs': [
            input_ids, input_mask, token_type_id, next_sentence_labels,
            masked_lm_positions, masked_lm_ids, masked_lm_weights
        ],
        'desc_bprop': [[1]]
    }),
    ('BertNetworkWithLoss_2', {
        'block':
        BertNetworkWithLoss(BertConfig(batch_size=1), False, True),
        'desc_inputs': [
            input_ids, input_mask, token_type_id, next_sentence_labels,
            masked_lm_positions, masked_lm_ids, masked_lm_weights
        ],
Пример #4
0
def test_bert_model():
    """Expect ValueError from BertModel when hidden_size % num_attention_heads != 0."""
    # 512 is not a multiple of 10, so this config is deliberately invalid.
    bad_config = BertConfig(32, hidden_size=512, num_attention_heads=10)
    with pytest.raises(ValueError):
        BertModel(bad_config, True)