def run_entity_marker_cls(bert_model_dir, do_lower_case):
    vocab_file = os.path.join(bert_model_dir, "vocab.txt")
    processor = input_processors.EntityProcessor(vocab_file=vocab_file,
                                                 do_lower_case=do_lower_case,
                                                 max_seq_length=128)
    head = heads.CLSHead(n_classes=1, out_activation="sigmoid",
                         bias_initializer="zeros", dropout_rate=0.0)
    inputs = processor.get_input_placeholders()
    bert_params = bert.params_from_pretrained_ckpt(bert_model_dir)
    bert_params["vocab_size"] = processor.vocab_size
    model_ckpt = os.path.join(bert_model_dir, "bert_model.ckpt")
    # Calls model.build()
    model = get_bert_classifier(inputs, bert_params, model_ckpt, head)
    opt = tf.keras.optimizers.Adam(learning_rate=3e-5)
    loss_fn = "binary_crossentropy"
    model.compile(optimizer=opt, loss=loss_fn)
    # `docs` is expected to be a module-level dict of example documents.
    bert_inputs, tokenized_docs = processor.process(docs["entity"])
    loss = model.evaluate(bert_inputs, np.array([1]))
    print()
    print("=== Entity CLS ===")
    print(docs["entity"])
    print(tokenized_docs)
    print(bert_inputs)
    print()
    model.summary()
    print(f"Loss: {loss}")
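
# NOTE: `compute_initial_bias` and `compute_expected_loss_categorical` (used in
# main() below) are project helpers; the sketch here is only an assumption about
# what they compute. Initializing the softmax layer's biases to the log class
# priors makes the untrained model predict the empirical label distribution, so
# the starting cross-entropy should be close to the entropy of that distribution.
# `_sketch_initial_bias` and `_sketch_expected_loss` are hypothetical names used
# only for illustration and are not called elsewhere in this script.
def _sketch_initial_bias(train_labels):
    # train_labels: one-hot array of shape (n_examples, n_classes).
    priors = train_labels.mean(axis=0)          # empirical class frequencies
    return np.log(priors)                       # -inf for classes with no examples


def _sketch_expected_loss(train_labels):
    priors = train_labels.mean(axis=0)
    nonzero = priors[priors > 0]
    return -np.sum(nonzero * np.log(nonzero))   # entropy of the label priors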
def main(args):
    if args.model_id == "3c":
        msg = "Model 3c: Positional Emb - Mention Pool, not yet implemented."
        raise NotImplementedError(msg)

    logger = get_logger(os.path.join(args.outdir, f"model_{args.model_id}.log"))
    logger.info(f"RUNNING MODEL: {args.model_id}")
    argstr = "Command line arguments:"
    for (k, v) in args.__dict__.items():
        argstr += f"\n {k}: {v}"
    logger.info(argstr)

    # Build the input processor.
    vocab_file = os.path.join(args.bert_model_dir, "vocab.txt")
    processor = get_processor_from_model_id(args.model_id,
                                            vocab_file=vocab_file,
                                            do_lower_case=True,
                                            max_seq_length=128)

    # Training and validation data
    if args.test is True:
        print("TESTING MODEL OVERFITS SMALL TRAINING DATA SET")
        train_docs, train_labels = get_single_examples(args.train_file)
        val_docs = None
    else:
        all_train_docs, all_train_labels = read_dataset(args.train_file, n=-1)
        train_docs, val_docs, train_labels, val_labels = train_test_split(
            all_train_docs, all_train_labels, train_size=0.8,
            random_state=RANDOM_STATE, shuffle=True)
    train_inputs, train_tokens = processor.process(train_docs)
    assert len(train_docs) == len(train_inputs[0])
    assert len(train_docs) == len(train_labels)
    logger.info(f"Training on {len(train_labels)} examples")

    if val_docs is not None:
        val_inputs, val_tokens = processor.process(val_docs)
        assert len(val_docs) == len(val_inputs[0])
        assert len(val_docs) == len(val_labels)
        logger.info(f"Validating on {len(val_labels)} examples")
        val_data = (val_inputs, val_labels)
    else:
        val_data = None

    # Test data
    if args.test_file is not None:
        test_docs, test_labels = read_dataset(args.test_file, n=-1)
        test_inputs, _ = processor.process(test_docs)
        assert len(test_docs) == len(test_inputs[0])
        assert len(test_docs) == len(test_labels)
        val_data = (test_inputs, test_labels)
        logger.info(f"Evaluating on {len(test_labels)} examples")

    # Build the BERT model
    initial_biases = compute_initial_bias(train_labels)
    initial_biases[initial_biases == -np.inf] = 0.0
    bias_init = tf.keras.initializers.Constant(initial_biases)
    head = get_classification_head_from_model_id(
        args.model_id, n_classes=train_labels.shape[1],
        out_activation="softmax", bias_initializer=bias_init,
        dropout_rate=0.1)
    inputs = processor.get_input_placeholders()
    bert_params = bert.params_from_pretrained_ckpt(args.bert_model_dir)
    bert_params["vocab_size"] = processor.vocab_size
    logger.info(f"Input Processor: {processor.name}")
    logger.info(f"Classification Head: {head.name}")
    model_ckpt = os.path.join(args.bert_model_dir, "bert_model.ckpt")
    model = get_bert_classifier(inputs, bert_params, model_ckpt, head,
                                logging_fn=logger.info)

    # Define the learning rate schedule:
    # linear warm up with polynomial decay.
    train_data_size = len(train_labels)
    steps_per_epoch = int(train_data_size / args.batch_size)
    num_train_steps = steps_per_epoch * args.epochs
    warmup_steps = int(args.epochs * train_data_size * 0.1 / args.batch_size)
    if warmup_steps == 0:
        warmup_steps = 1
    lr_schedule = lr_schedulers.PolynomialDecayWithLinearWarmup(
        target_learning_rate=args.learning_rate,
        warmup_steps=warmup_steps,
        decay_steps=num_train_steps)
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule, epsilon=1e-8)
    loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=False)
    model.compile(optimizer=optimizer, loss=loss_fn, metrics=["accuracy"])

    # Visualize the model
    model.summary(print_fn=logger.info)

    # Check that we've initialized properly
    expected_loss = compute_expected_loss_categorical(train_labels)
    init_bias_str = str(dict(zip(datasets.SEMEVAL_LABELS, initial_biases)))
    logger.info(f"Initial biases: {init_bias_str}")
    logger.info(f"Expected starting loss: {expected_loss}")
    actual_loss = model.evaluate(train_inputs, train_labels, verbose=1)
    logger.info(f"Actual starting loss: {actual_loss[0]}")

    # Define training callbacks
    tb_logdir = os.path.join(args.outdir,
                             f"tensorboard_logs/model_{args.model_id}")
    if not os.path.exists(tb_logdir):
        os.makedirs(tb_logdir)
    tb_callback = tf.keras.callbacks.TensorBoard(log_dir=tb_logdir,
                                                 histogram_freq=1,
                                                 write_graph=True,
                                                 write_images=True,
                                                 update_freq="batch",
                                                 embeddings_freq=1)
    lr_callback = LogLearningRateCallback(log_dir=tb_logdir)
    early_stop_callback = tf.keras.callbacks.EarlyStopping(
        monitor="val_loss", patience=1, restore_best_weights=True)
    ckpt_outfile = os.path.join(
        args.outdir, f"model_checkpoints/model_{args.model_id}/model_ckpt")
    ckpt_callback = tf.keras.callbacks.ModelCheckpoint(ckpt_outfile,
                                                       monitor="val_loss",
                                                       save_best_only=True,
                                                       save_weights_only=True)
    progress_callback = tf.keras.callbacks.LambdaCallback(
        on_epoch_end=lambda epoch, logs: logger.info(f"Completed epoch {epoch}"))
    callbacks = [tb_callback, lr_callback, early_stop_callback,
                 ckpt_callback, progress_callback]

    # Train the model
    logger.info(f"TRAINING FOR {args.epochs} EPOCHS")
    history = model.fit(train_inputs, train_labels,
                        batch_size=args.batch_size,
                        epochs=args.epochs,
                        validation_data=val_data,
                        shuffle=True,
                        callbacks=callbacks,
                        verbose=1)
    logger.info("TRAINING COMPLETE")
    res_str = "Final metrics on train/val data:"
    for (k, v) in history.history.items():
        res_str += f"\n {k}: {v[-1]}"
    logger.info(res_str)

    # Evaluate the model
    if args.test_file is not None:
        logger.info("EVALUATING ON TEST DATA")
        predictions = model.predict(test_inputs,
                                    batch_size=args.batch_size,
                                    verbose=1)
        true_labels = np.argmax(test_labels, axis=1)
        pred_labels = np.argmax(predictions, axis=1)
        report = classification_report(true_labels, pred_labels,
                                       target_names=datasets.SEMEVAL_LABELS)
        logger.info('\n' + report)
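

# A minimal command-line entry point, sketched from the attributes main() reads
# off `args`. The flag types and defaults below are assumptions for illustration;
# the project's real argument parser may differ.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Train a BERT relation classifier.")
    parser.add_argument("--model_id", type=str, required=True,
                        help="Which model variant to run (e.g. '3c').")
    parser.add_argument("--bert_model_dir", type=str, required=True,
                        help="Directory containing vocab.txt and bert_model.ckpt.")
    parser.add_argument("--train_file", type=str, required=True)
    parser.add_argument("--test_file", type=str, default=None)
    parser.add_argument("--outdir", type=str, required=True,
                        help="Where to write logs, checkpoints, and TensorBoard data.")
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--learning_rate", type=float, default=3e-5)
    parser.add_argument("--test", action="store_true",
                        help="Overfit a small training set as a sanity check.")
    main(parser.parse_args())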