Exemplo n.º 1
0
 def _get_pretrain_model():
   """Builds the BERT pretraining model and attaches its optimizer.

   Reads `bert_config`, `max_seq_length`, `max_predictions_per_seq`,
   `initial_lr`, `steps_per_epoch`, `epochs` and `warmup_steps` from the
   enclosing scope.

   Returns:
     The `(pretrain_model, core_model)` pair produced by
     `bert_models.pretrain_model`, with `pretrain_model.optimizer` replaced
     by the output of `performance.configure_optimizer`.
   """
   pretrain_model, core_model = bert_models.pretrain_model(
       bert_config, max_seq_length, max_predictions_per_seq)
   # Total number of training steps across all epochs.
   total_train_steps = steps_per_epoch * epochs
   base_optimizer = optimization.create_optimizer(
       initial_lr, total_train_steps, warmup_steps)
   # Both settings are read from the shared command-line flags.
   fp16_enabled = common_flags.use_float16()
   graph_rewrite_enabled = common_flags.use_graph_rewrite()
   pretrain_model.optimizer = performance.configure_optimizer(
       base_optimizer,
       use_float16=fp16_enabled,
       use_graph_rewrite=graph_rewrite_enabled)
   return pretrain_model, core_model
Exemplo n.º 2
0
 def _get_model():
     """Builds the NER model and attaches its optimizer.

     Reads `bert_config`, `num_classes`, `initial_lr`, `steps_per_epoch`,
     `epochs` and `warmup_steps` from the enclosing scope, plus
     `FLAGS.use_crf`, `FLAGS.end_lr` and `FLAGS.optimizer_type`.

     Returns:
       The `(model, core_model)` pair returned by `ner_model`, with
       `model.optimizer` set to the configured optimizer.
     """
     model, core_model = ner_model(bert_config, num_classes, FLAGS.use_crf)
     # Total number of training steps across all epochs.
     num_train_steps = steps_per_epoch * epochs
     optimizer = optimization.create_optimizer(
         initial_lr, num_train_steps, warmup_steps,
         FLAGS.end_lr, FLAGS.optimizer_type)
     optimizer_options = dict(
         use_float16=common_flags.use_float16(),
         use_graph_rewrite=common_flags.use_graph_rewrite())
     model.optimizer = performance.configure_optimizer(optimizer,
                                                       **optimizer_options)
     return model, core_model
Exemplo n.º 3
0
    def _get_squad_model():
        """Builds the SQuAD model and attaches its optimizer.

        Hub settings come from `FLAGS.hub_module_url` and
        `FLAGS.hub_module_trainable`. The learning rate is
        `FLAGS.learning_rate`. `bert_config`, `max_seq_length`,
        `steps_per_epoch`, `epochs` and `warmup_steps` are read from the
        enclosing scope.

        Returns:
          The `(squad_model, core_model)` pair returned by
          `bert_models.squad_model`, with `squad_model.optimizer` set.
        """
        hub_kwargs = {
            'hub_module_url': FLAGS.hub_module_url,
            'hub_module_trainable': FLAGS.hub_module_trainable,
        }
        squad_model, core_model = bert_models.squad_model(
            bert_config, max_seq_length, **hub_kwargs)
        optimizer = optimization.create_optimizer(
            FLAGS.learning_rate, steps_per_epoch * epochs, warmup_steps)

        use_fp16 = common_flags.use_float16()
        use_rewrite = common_flags.use_graph_rewrite()
        squad_model.optimizer = performance.configure_optimizer(
            optimizer, use_float16=use_fp16, use_graph_rewrite=use_rewrite)
        return squad_model, core_model
Exemplo n.º 4
0
 def _get_classifier_model():
     """Builds the BERT classifier model and attaches its optimizer.

     Reads `bert_config`, `num_classes`, `max_seq_length`, `initial_lr`,
     `steps_per_epoch`, `epochs` and `warmup_steps` from the enclosing
     scope. Hub settings come from `FLAGS.hub_module_url` and
     `FLAGS.hub_module_trainable`.

     Returns:
       The `(classifier_model, core_model)` pair returned by
       `bert_models.classifier_model`, with `classifier_model.optimizer` set.
     """
     hub_options = {
         'hub_module_url': FLAGS.hub_module_url,
         'hub_module_trainable': FLAGS.hub_module_trainable,
     }
     classifier_model, core_model = bert_models.classifier_model(
         bert_config, num_classes, max_seq_length, **hub_options)
     # Total number of training steps across all epochs.
     total_steps = steps_per_epoch * epochs
     classifier_model.optimizer = performance.configure_optimizer(
         optimization.create_optimizer(initial_lr, total_steps, warmup_steps),
         use_float16=common_flags.use_float16(),
         use_graph_rewrite=common_flags.use_graph_rewrite())
     return classifier_model, core_model
Exemplo n.º 5
0
 def _get_pretrain_model():
     """Builds the BERT pretraining model and attaches its optimizer.

     Reads `bert_config`, `max_seq_length`, `max_predictions_per_seq`,
     `use_next_sentence_label`, `initial_lr`, `steps_per_epoch`, `epochs`,
     `warmup_steps`, `end_lr` and `optimizer_type` from the enclosing scope.
     The optimizer is configured with `use_experimental_api=False`.

     Returns:
       The `(pretrain_model, core_model)` pair returned by
       `bert_models.pretrain_model`, with `pretrain_model.optimizer` set.
     """
     pretrain_model, core_model = bert_models.pretrain_model(
         bert_config, max_seq_length, max_predictions_per_seq,
         use_next_sentence_label=use_next_sentence_label)
     # Positional arguments for create_optimizer; the second is the total
     # number of training steps across all epochs.
     optimizer_args = (initial_lr, steps_per_epoch * epochs, warmup_steps,
                       end_lr, optimizer_type)
     optimizer = optimization.create_optimizer(*optimizer_args)
     pretrain_model.optimizer = performance.configure_optimizer(
         optimizer,
         use_experimental_api=False,
         use_float16=common_flags.use_float16(),
         use_graph_rewrite=common_flags.use_graph_rewrite())
     return pretrain_model, core_model
Exemplo n.º 6
0
 def _get_model():
     """Builds a siamese or a plain classifier model and attaches its optimizer.

     When `FLAGS.model_type == 'siamese'`, the model comes from
     `siamese_bert.siamese_model` (with `siamese_type=FLAGS.siamese_type`).
     For any other `FLAGS.model_type`, it falls back to
     `bert_models.classifier_model`.

     The optimizer is created from `initial_lr`, `steps_per_epoch * epochs`,
     `warmup_steps`, `FLAGS.end_lr` and `FLAGS.optimizer_type`. It is then
     passed through `performance.configure_optimizer` with the float16 and
     graph-rewrite settings from `common_flags`. `bert_config`, `num_classes`,
     `max_seq_length`, `initial_lr`, `steps_per_epoch`, `epochs` and
     `warmup_steps` are read from the enclosing scope.

     Returns:
       A `(model, core_model)` pair, with `model.optimizer` set to the
       configured optimizer.
     """
     if FLAGS.model_type == 'siamese':
         # NOTE(review): max_seq_length is not passed on this branch, only
         # on the classifier branch below — confirm this is intended.
         model, core_model = siamese_bert.siamese_model(
             bert_config, num_classes, siamese_type=FLAGS.siamese_type)
     else:
         # Any non-siamese model_type builds a standard BERT classifier.
         model, core_model = bert_models.classifier_model(
             bert_config, num_classes, max_seq_length)
     optimizer = optimization.create_optimizer(initial_lr,
                                               steps_per_epoch * epochs,
                                               warmup_steps, FLAGS.end_lr,
                                               FLAGS.optimizer_type)
     model.optimizer = performance.configure_optimizer(
         optimizer,
         use_float16=common_flags.use_float16(),
         use_graph_rewrite=common_flags.use_graph_rewrite())
     return model, core_model