def _get_pretrain_model():
    """Build the pretraining model and attach its configured optimizer.

    Names not defined here (`bert_config`, `max_seq_length`,
    `max_predictions_per_seq`, `initial_lr`, `steps_per_epoch`, `epochs`,
    `warmup_steps`) are resolved from the surrounding scope.

    Returns:
      The `(pretrain_model, core_model)` pair produced by
      `bert_models.pretrain_model`, with `pretrain_model.optimizer` set to
      the optimizer returned by `performance.configure_optimizer`.
    """
    model, encoder = bert_models.pretrain_model(
        bert_config, max_seq_length, max_predictions_per_seq)

    # The schedule length handed to the optimizer is epochs worth of steps.
    total_steps = steps_per_epoch * epochs
    base_optimizer = optimization.create_optimizer(
        initial_lr, total_steps, warmup_steps)

    # Precision settings are read from the shared command-line flags.
    float16_enabled = common_flags.use_float16()
    graph_rewrite_enabled = common_flags.use_graph_rewrite()
    model.optimizer = performance.configure_optimizer(
        base_optimizer,
        use_float16=float16_enabled,
        use_graph_rewrite=graph_rewrite_enabled)

    return model, encoder
def _get_model():
    """Build the NER model and attach its configured optimizer.

    `bert_config`, `num_classes`, `initial_lr`, `steps_per_epoch`, `epochs`
    and `warmup_steps` are resolved from the surrounding scope; CRF usage,
    end learning rate and optimizer type are read from `FLAGS`.

    Returns:
      The `(model, core_model)` pair produced by `ner_model`, with
      `model.optimizer` set to the optimizer returned by
      `performance.configure_optimizer`.
    """
    ner, encoder = ner_model(bert_config, num_classes, FLAGS.use_crf)

    # The schedule length handed to the optimizer is epochs worth of steps.
    base_optimizer = optimization.create_optimizer(
        initial_lr,
        steps_per_epoch * epochs,
        warmup_steps,
        FLAGS.end_lr,
        FLAGS.optimizer_type)

    # Precision settings are read from the shared command-line flags.
    precision_options = dict(
        use_float16=common_flags.use_float16(),
        use_graph_rewrite=common_flags.use_graph_rewrite())
    ner.optimizer = performance.configure_optimizer(
        base_optimizer, **precision_options)

    return ner, encoder
def _get_squad_model():
    """Build the SQuAD model and attach its configured optimizer.

    `bert_config`, `max_seq_length`, `steps_per_epoch`, `epochs` and
    `warmup_steps` are resolved from the surrounding scope. The hub module
    settings and the learning rate are read from `FLAGS`.

    Returns:
      The `(squad_model, core_model)` pair produced by
      `bert_models.squad_model`, with `squad_model.optimizer` set to the
      optimizer returned by `performance.configure_optimizer`.
    """
    hub_options = dict(
        hub_module_url=FLAGS.hub_module_url,
        hub_module_trainable=FLAGS.hub_module_trainable)
    model, encoder = bert_models.squad_model(
        bert_config, max_seq_length, **hub_options)

    # Note: the learning rate comes straight from FLAGS in this builder.
    base_optimizer = optimization.create_optimizer(
        FLAGS.learning_rate, steps_per_epoch * epochs, warmup_steps)

    # Precision settings are read from the shared command-line flags.
    float16_enabled = common_flags.use_float16()
    graph_rewrite_enabled = common_flags.use_graph_rewrite()
    model.optimizer = performance.configure_optimizer(
        base_optimizer,
        use_float16=float16_enabled,
        use_graph_rewrite=graph_rewrite_enabled)

    return model, encoder
def _get_classifier_model():
    """Build the classifier model and attach its configured optimizer.

    `bert_config`, `num_classes`, `max_seq_length`, `initial_lr`,
    `steps_per_epoch`, `epochs` and `warmup_steps` are resolved from the
    surrounding scope. The hub module settings are read from `FLAGS`.

    Returns:
      The `(classifier_model, core_model)` pair produced by
      `bert_models.classifier_model`, with `classifier_model.optimizer` set
      to the optimizer returned by `performance.configure_optimizer`.
    """
    hub_options = dict(
        hub_module_url=FLAGS.hub_module_url,
        hub_module_trainable=FLAGS.hub_module_trainable)
    classifier, encoder = bert_models.classifier_model(
        bert_config, num_classes, max_seq_length, **hub_options)

    # The schedule length handed to the optimizer is epochs worth of steps.
    base_optimizer = optimization.create_optimizer(
        initial_lr, steps_per_epoch * epochs, warmup_steps)

    # Precision settings are read from the shared command-line flags.
    precision_options = dict(
        use_float16=common_flags.use_float16(),
        use_graph_rewrite=common_flags.use_graph_rewrite())
    classifier.optimizer = performance.configure_optimizer(
        base_optimizer, **precision_options)

    return classifier, encoder
def _get_pretrain_model():
    """Build the pretraining model and attach its configured optimizer.

    Names not defined here (`bert_config`, `max_seq_length`,
    `max_predictions_per_seq`, `use_next_sentence_label`, `initial_lr`,
    `steps_per_epoch`, `epochs`, `warmup_steps`, `end_lr`,
    `optimizer_type`) are resolved from the surrounding scope.

    Returns:
      The `(pretrain_model, core_model)` pair produced by
      `bert_models.pretrain_model`, with `pretrain_model.optimizer` set to
      the optimizer returned by `performance.configure_optimizer`.
    """
    model, encoder = bert_models.pretrain_model(
        bert_config,
        max_seq_length,
        max_predictions_per_seq,
        use_next_sentence_label=use_next_sentence_label)

    # The schedule length handed to the optimizer is epochs worth of steps.
    base_optimizer = optimization.create_optimizer(
        initial_lr,
        steps_per_epoch * epochs,
        warmup_steps,
        end_lr,
        optimizer_type)

    # Precision settings come from the shared flags; the experimental API
    # is explicitly switched off for this builder.
    float16_enabled = common_flags.use_float16()
    graph_rewrite_enabled = common_flags.use_graph_rewrite()
    model.optimizer = performance.configure_optimizer(
        base_optimizer,
        use_float16=float16_enabled,
        use_graph_rewrite=graph_rewrite_enabled,
        use_experimental_api=False)

    return model, encoder
def _get_model():
    """Build a siamese or classifier model and attach its optimizer.

    When `FLAGS.model_type` is `'siamese'` the model comes from
    `siamese_bert.siamese_model`; for any other value it comes from
    `bert_models.classifier_model`. `bert_config`, `num_classes`,
    `max_seq_length`, `initial_lr`, `steps_per_epoch`, `epochs` and
    `warmup_steps` are resolved from the surrounding scope.

    Returns:
      The `(model, core_model)` pair produced by the selected builder, with
      `model.optimizer` set to the optimizer returned by
      `performance.configure_optimizer`.
    """
    if FLAGS.model_type != 'siamese':
        built = bert_models.classifier_model(
            bert_config, num_classes, max_seq_length)
    else:
        built = siamese_bert.siamese_model(
            bert_config, num_classes, siamese_type=FLAGS.siamese_type)
    net, encoder = built

    # The schedule length handed to the optimizer is epochs worth of steps.
    base_optimizer = optimization.create_optimizer(
        initial_lr,
        steps_per_epoch * epochs,
        warmup_steps,
        FLAGS.end_lr,
        FLAGS.optimizer_type)

    # Precision settings are read from the shared command-line flags.
    precision_options = dict(
        use_float16=common_flags.use_float16(),
        use_graph_rewrite=common_flags.use_graph_rewrite())
    net.optimizer = performance.configure_optimizer(
        base_optimizer, **precision_options)

    return net, encoder