# ## Build Keras Model

# In[9]:

if USE_AMP:
    # Set the experimental mixed-precision policy before any layer below is
    # created. NOTE(review): 'infer_float32_vars' is an experimental,
    # TF-version-specific policy name — confirm against the installed TF.
    tf.keras.mixed_precision.experimental.set_policy('infer_float32_vars')

# One Keras Input of length MAX_SEQ_LEN per BERT feature, created in the order
# input_ids, input_masks, segment_ids (the order passed to the BERT layer).
in_id, in_mask, in_segment = (
    layers.Input(shape=(MAX_SEQ_LEN, ), name=input_name)
    for input_name in ("input_ids", "input_masks", "segment_ids")
)
in_bert = [in_id, in_mask, in_segment]

# Apply the project's BERT wrapper to the three inputs. With
# return_sequence=False it presumably yields one vector of width H_SIZE per
# example rather than per-token outputs — confirm in bert_utils.
l_bert = bert_utils.BERT(
    fine_tune_layers=TUNE_LAYERS,
    bert_path=BERT_PATH,
    return_sequence=False,
    output_size=H_SIZE,
    debug=False,
)(in_bert)

# Softmax classification head with num_classes outputs.
out_pred = layers.Dense(num_classes, activation="softmax")(l_bert)

# Functional model mapping the three inputs to class probabilities.
model = tf.keras.models.Model(inputs=in_bert, outputs=out_pred)

# In[10]:

# Adam optimizer. `learning_rate` is the documented keyword argument; `lr` is
# only a deprecated alias in tf.keras optimizers.
opt = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)

if USE_AMP:
    # With mixed precision enabled, wrap the optimizer in Keras' experimental
    # LossScaleOptimizer using dynamic loss scaling.
    opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
        opt, "dynamic")
# Example #2
# (scraped listing value: 0)
test_dataset = tf.keras.utils.get_file("news_test.csv.zip", TEST_SET_URL,
                                       cache_subdir='datasets', extract=True)
"""
# NOTE(review): the lone triple quote above has no visible partner in this
# excerpt. It presumably closes a block string, opened earlier in the full
# script, that disables the download call above it. That call would fetch
# "news_test.csv.zip" from TEST_SET_URL into the Keras 'datasets' cache subdir
# and extract it. If the quote instead opens a string, everything below is
# inside an unterminated literal — confirm before running.

MAX_SEQ_LEN = 512


# Build one BERT variant on fresh Keras inputs. Per the "Caching" log line,
# this is meant to load the hub module at bert_path; presumably applying
# bert_utils.BERT fetches it — confirm in bert_utils.
#
# label:       name printed in the "[INFO ] Caching ..." log line.
# bert_path:   TF Hub URL handed to bert_utils.BERT(bert_path=...).
# output_size: handed to bert_utils.BERT(output_size=...). The callers below
#              pass 768 / 1024, matching the H-768 / H-1024 in their hub URLs.
# Returns (in_id, in_mask, in_segment, in_bert, l_bert), so callers can assign
# the same module-level names the previously duplicated inline code defined.
def _cache_bert(label, bert_path, output_size):
    print("[INFO ] Caching " + label)
    in_id = layers.Input(shape=(MAX_SEQ_LEN, ), name="input_ids")
    in_mask = layers.Input(shape=(MAX_SEQ_LEN, ), name="input_masks")
    in_segment = layers.Input(shape=(MAX_SEQ_LEN, ), name="segment_ids")
    in_bert = [in_id, in_mask, in_segment]
    l_bert = bert_utils.BERT(fine_tune_layers=-1,
                             bert_path=bert_path,
                             return_sequence=False,
                             output_size=output_size,
                             debug=False)(in_bert)
    return in_id, in_mask, in_segment, in_bert, l_bert


# Each requested variant is built independently; if both flags are set, both
# run, and the BERT-Large results are the ones left in the module-level names.
if args.bertbase:
    BERT_PATH = "https://tfhub.dev/google/bert_uncased_L-12_H-768_A-12/1"
    in_id, in_mask, in_segment, in_bert, l_bert = _cache_bert(
        "BERTBASE", BERT_PATH, 768)

if args.bertlarge:
    BERT_PATH = "https://tfhub.dev/google/bert_uncased_L-24_H-1024_A-16/1"
    in_id, in_mask, in_segment, in_bert, l_bert = _cache_bert(
        "BERTLARGE", BERT_PATH, 1024)