Ejemplo n.º 1
0
def get_model(args,
              model_config,
              steps_per_epoch,
              warmup_steps,
              num_labels,
              max_seq_length,
              is_hub_module=False):
    """Build the BERT classifier and attach a configured optimizer to it.

    Args:
      args: run configuration; this function reads `init_checkpoint`,
        `model_class`, `learning_rate`, `num_epochs`, `end_lr` and
        `optimizer_type` from it.
      model_config: BERT configuration forwarded to
        `bert_models.classifier_model`.
      steps_per_epoch: number of training steps in one epoch.
      warmup_steps: warm-up step count forwarded to the optimizer factory.
      num_labels: number of output labels of the classifier.
      max_seq_length: maximum input sequence length.
      is_hub_module: accepted for interface compatibility; not used here.

    Returns:
      A `(classifier_model, core_model)` tuple. The classifier has its
      `optimizer` attribute set; the core model is returned so callers can
      initialize it from a checkpoint.
    """
    # Default to building the network locally; switch to TF-Hub only when no
    # initial checkpoint was given and the model class is published there.
    hub_module_url = None
    hub_module_trainable = False
    if args.init_checkpoint is None:
        model_spec = PRETRAINED_MODELS[args.model_class]
        if model_spec['is_tfhub_model']:
            hub_module_url = f"https://tfhub.dev/{model_spec['hub_url']}"
            hub_module_trainable = True

    classifier, core = bert_models.classifier_model(
        model_config,
        num_labels,
        max_seq_length,
        hub_module_url=hub_module_url,
        hub_module_trainable=hub_module_trainable)

    # The learning-rate schedule spans the whole training run.
    total_train_steps = steps_per_epoch * args.num_epochs
    base_optimizer = utils.optimizer.create_optimizer(
        args.learning_rate, total_train_steps, warmup_steps, args.end_lr,
        args.optimizer_type)
    classifier.optimizer = configure_optimizer(base_optimizer,
                                               use_float16=False,
                                               use_graph_rewrite=False)
    return classifier, core
Ejemplo n.º 2
0
def export_classifier(model_export_path, input_meta_data, bert_config,
                      model_dir):
    """Exports a trained model as a `SavedModel` for inference.

    Args:
      model_export_path: a string specifying the path to the SavedModel
        directory.
      input_meta_data: dictionary containing meta data about input and model.
        Only `num_labels` is read here; it defaults to 1 when absent.
      bert_config: Bert configuration used to define the core bert layers.
      model_dir: the directory where the model weights and
        training/evaluation summaries are stored.

    Raises:
      ValueError: if `model_export_path` or `model_dir` is empty or None.
    """
    if not model_export_path:
        raise ValueError('Export path is not specified: %s' %
                         model_export_path)
    if not model_dir:
        # Report the argument that is actually missing; this used to repeat
        # the export-path message and pointed users at the wrong flag.
        raise ValueError('Model directory is not specified: %s' % model_dir)

    # Export uses float32 for now, even if training uses mixed precision.
    tf.keras.mixed_precision.experimental.set_policy('float32')
    # classifier_model returns (classifier, core_model); only the classifier
    # is exported.
    classifier_model = bert_models.classifier_model(
        bert_config,
        input_meta_data.get('num_labels', 1),
        hub_module_url=FLAGS.hub_module_url,
        hub_module_trainable=False)[0]

    model_saving_utils.export_bert_model(model_export_path,
                                         model=classifier_model,
                                         checkpoint_dir=model_dir)
def predict(strategy, albert_config, input_meta_data, predict_input_fn):
    """Writes predicted probabilities and ground-truth labels as .tsv files.

    Restores the classifier from `FLAGS.predict_checkpoint_path` (or, when
    that flag is unset, from the latest checkpoint in `FLAGS.model_dir`),
    runs it over `predict_input_fn` and writes two files into
    `FLAGS.model_dir`:

      * `test_results.tsv`: one line per prediction, holding the
        tab-separated class probabilities.
      * `output_labels.tsv`: one ground-truth label per line.

    Args:
      strategy: distribution strategy whose scope is used to build the model
        and run prediction.
      albert_config: model configuration forwarded to
        `bert_models.classifier_model`.
      input_meta_data: dictionary of dataset meta data; `num_labels` is
        required.
      predict_input_fn: input function forwarded to
        `run_classifier_bert.get_predictions_and_labels`.

    Raises:
      AssertionError: if no checkpoint file could be located.
    """
    with strategy.scope():
        classifier_model = bert_models.classifier_model(
            albert_config, input_meta_data['num_labels'])[0]
        checkpoint = tf.train.Checkpoint(model=classifier_model)
        latest_checkpoint_file = (FLAGS.predict_checkpoint_path or
                                  tf.train.latest_checkpoint(FLAGS.model_dir))
        assert latest_checkpoint_file
        logging.info(
            'Checkpoint file %s found and restoring from '
            'checkpoint', latest_checkpoint_file)
        checkpoint.restore(
            latest_checkpoint_file).assert_existing_objects_matched()
        preds, ground_truth = run_classifier_bert.get_predictions_and_labels(
            strategy, classifier_model, predict_input_fn, return_probs=True)

        output_predict_file = os.path.join(FLAGS.model_dir, 'test_results.tsv')
        with tf.io.gfile.GFile(output_predict_file, 'w') as writer:
            logging.info('***** Predict results *****')
            for probabilities in preds:
                output_line = '\t'.join(
                    str(class_probability)
                    for class_probability in probabilities) + '\n'
                writer.write(output_line)

        ground_truth_labels_file = os.path.join(FLAGS.model_dir,
                                                'output_labels.tsv')
        with tf.io.gfile.GFile(ground_truth_labels_file, 'w') as writer:
            logging.info('***** Ground truth results *****')
            for label in ground_truth:
                # Write the label verbatim. The previous
                # '\t'.join(str(label)) iterated over the *characters* of
                # the label's string form, so a label such as 12 was written
                # as "1<TAB>2". Single-character labels are unaffected.
                writer.write(str(label) + '\n')
Ejemplo n.º 4
0
 def _get_classifier_model():
   """Builds the classifier/core model pair and attaches an optimizer.

   All inputs (`bert_config`, `num_classes`, `self`, `initial_lr`,
   `steps_per_epoch`, `epochs`, `warmup_steps`) come from the enclosing scope.

   Returns:
     A `(classifier_model, core_model)` tuple; the classifier carries the
     newly created optimizer in its `optimizer` attribute.
   """
   classifier, core = bert_models.classifier_model(
       bert_config, num_classes, self.seq_len, hub_module_url=self.uri)
   # The schedule covers every step of the training run.
   total_train_steps = steps_per_epoch * epochs
   classifier.optimizer = optimization.create_optimizer(
       initial_lr, total_train_steps, warmup_steps)
   return classifier, core
Ejemplo n.º 5
0
def get_model(args, model_config, num_labels, max_seq_length, is_hub_module=False):
    """Build a BERT classifier that does not load weights from TF-Hub.

    Args:
      args: accepted for interface compatibility; not used here.
      model_config: BERT configuration forwarded to
        `bert_models.classifier_model`.
      num_labels: number of output labels of the classifier.
      max_seq_length: maximum input sequence length.
      is_hub_module: accepted for interface compatibility; not used here.

    Returns:
      The classifier model; the accompanying core model is discarded.
    """
    # TF-Hub loading is explicitly disabled for this variant.
    hub_options = {'hub_module_url': None, 'hub_module_trainable': False}
    classifier, _core = bert_models.classifier_model(
        model_config, num_labels, max_seq_length, **hub_options)
    return classifier
Ejemplo n.º 6
0
def load_tf_model(tf_config_path, num_labels=3, max_seq_length=96):
    """Build a BERT classifier from a JSON configuration file.

    Args:
      tf_config_path: path to the BERT config JSON read by
        `BertConfig.from_json_file`.
      num_labels: number of output labels of the classifier.
      max_seq_length: maximum input sequence length.

    Returns:
      The classifier model; the accompanying core model is discarded. No
      weights are loaded by this function.
    """
    bert_config = bert_configs.BertConfig.from_json_file(tf_config_path)
    # TF-Hub loading is explicitly disabled.
    hub_options = {'hub_module_url': None, 'hub_module_trainable': False}
    classifier, _core = bert_models.classifier_model(
        bert_config, num_labels, max_seq_length, **hub_options)
    return classifier
Ejemplo n.º 7
0
def get_model(args, model_config, num_labels, max_seq_length):
    """Build a BERT classifier, optionally backed by a trainable TF-Hub module.

    Args:
      args: run configuration; `use_tf_hub` and `model_class` are read.
      model_config: BERT configuration forwarded to
        `bert_models.classifier_model`.
      num_labels: number of output labels of the classifier.
      max_seq_length: maximum input sequence length.

    Returns:
      The classifier model; the accompanying core model is discarded.
    """
    # Build locally unless TF-Hub was requested *and* the model class is
    # actually published there.
    hub_module_url = None
    hub_module_trainable = False
    if args.use_tf_hub:
        model_spec = PRETRAINED_MODELS[args.model_class]
        if model_spec['is_tfhub_model']:
            hub_module_url = f"https://tfhub.dev/{model_spec['hub_url']}"
            hub_module_trainable = True

    classifier, _core = bert_models.classifier_model(
        model_config,
        num_labels,
        max_seq_length,
        hub_module_url=hub_module_url,
        hub_module_trainable=hub_module_trainable)
    return classifier
Ejemplo n.º 8
0
 def _get_classifier_model():
     """Builds the classifier/core model pair with a configured optimizer.

     Model dimensions and schedule parameters come from the enclosing scope;
     hub, end-learning-rate and optimizer-type settings come from `FLAGS`.

     Returns:
       A `(classifier_model, core_model)` tuple; the classifier carries the
       configured optimizer in its `optimizer` attribute.
     """
     classifier, core = bert_models.classifier_model(
         bert_config,
         num_classes,
         max_seq_length,
         hub_module_url=FLAGS.hub_module_url,
         hub_module_trainable=FLAGS.hub_module_trainable)

     # The schedule covers every step of the training run.
     total_train_steps = steps_per_epoch * epochs
     base_optimizer = optimization.create_optimizer(
         initial_lr, total_train_steps, warmup_steps, FLAGS.end_lr,
         FLAGS.optimizer_type)
     classifier.optimizer = performance.configure_optimizer(
         base_optimizer, use_float16=common_flags.use_float16())
     return classifier, core
Ejemplo n.º 9
0
 def _get_classifier_model():
     """Builds the classifier/core model pair and attaches an optimizer.

     The BERT config is loaded from `FLAGS.bert_config_file`; label count and
     sequence length come from `input_meta_data` in the enclosing scope.

     Returns:
       A `(classifier_model, core_model)` tuple; the classifier carries the
       optimizer in its `optimizer` attribute.
     """
     config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
     label_count = input_meta_data['num_labels']
     classifier, core = bert_models.classifier_model(
         config,
         label_count,
         input_meta_data['max_seq_length'],
         hub_module_url=FLAGS.hub_module_url,
         hub_module_trainable=False)

     total_train_steps = steps_per_epoch * epochs
     classifier.optimizer = optimization.create_optimizer(
         initial_lr, total_train_steps, warmup_steps)
     if FLAGS.fp16_implementation == 'graph_rewrite':
         # Wrap the optimizer that was just attached to the model.
         enable_rewrite = (
             tf.train.experimental.enable_mixed_precision_graph_rewrite)
         classifier.optimizer = enable_rewrite(classifier.optimizer)
     return classifier, core
Ejemplo n.º 10
0
  def test_classifier_model(self):
    """Checks the types and output structure of a freshly built classifier."""
    classifier, encoder = bert_models.classifier_model(
        self._bert_test_config,
        num_labels=3,
        max_seq_length=5,
        final_layer_initializer=None,
        hub_module_url=None,
        hub_module_trainable=None)

    # Both returned objects must be Keras models.
    for built_model in (classifier, encoder):
      self.assertIsInstance(built_model, tf.keras.Model)

    # The classifier has a single classification output sized by num_labels=3.
    classifier_output_shape = classifier.output.shape.as_list()
    self.assertEqual(classifier_output_shape, [None, 3])

    # The core model exposes two outputs: sequence and classification output.
    encoder_outputs = encoder.output
    self.assertIsInstance(encoder_outputs, list)
    self.assertLen(encoder_outputs, 2)
Ejemplo n.º 11
0
 def _get_model():
     """Builds a siamese or plain classifier model with a configured optimizer.

     `FLAGS.model_type == 'siamese'` selects `siamese_bert.siamese_model`;
     any other value selects `bert_models.classifier_model`.

     Returns:
       A `(model, core_model)` tuple; the model carries the configured
       optimizer in its `optimizer` attribute.
     """
     if FLAGS.model_type != 'siamese':
         built_model, core = bert_models.classifier_model(
             bert_config, num_classes, max_seq_length)
     else:
         built_model, core = siamese_bert.siamese_model(
             bert_config, num_classes, siamese_type=FLAGS.siamese_type)

     # The schedule covers every step of the training run.
     total_train_steps = steps_per_epoch * epochs
     base_optimizer = optimization.create_optimizer(
         initial_lr, total_train_steps, warmup_steps, FLAGS.end_lr,
         FLAGS.optimizer_type)
     built_model.optimizer = performance.configure_optimizer(
         base_optimizer,
         use_float16=common_flags.use_float16(),
         use_graph_rewrite=common_flags.use_graph_rewrite())
     return built_model, core
Ejemplo n.º 12
0
 def _get_classifier_model():
     """Builds the classifier/core model pair and attaches an optimizer.

     Returns:
       A `(classifier_model, core_model)` tuple; the classifier carries the
       optimizer in its `optimizer` attribute.
     """
     classifier, core = bert_models.classifier_model(
         bert_config,
         num_classes,
         max_seq_length,
         hub_module_url=FLAGS.hub_module_url)

     total_train_steps = steps_per_epoch * epochs
     classifier.optimizer = optimization.create_optimizer(
         initial_lr, total_train_steps, warmup_steps)
     if FLAGS.fp16_implementation == 'graph_rewrite':
         # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype
         # as determined by flags_core.get_tf_dtype(flags_obj) would be
         # 'float32', which ensures tf.compat.v2.keras.mixed_precision and
         # tf.train.experimental.enable_mixed_precision_graph_rewrite do not
         # double up.
         enable_rewrite = (
             tf.train.experimental.enable_mixed_precision_graph_rewrite)
         classifier.optimizer = enable_rewrite(classifier.optimizer)
     return classifier, core
Ejemplo n.º 13
0
  def _get_classifier_model():
    """Builds the classifier/core model pair with a fixed training schedule.

    The BERT config is loaded from `FLAGS.bert_config_file`; `num_classes` and
    `input_meta_data` come from the enclosing scope. The optimizer schedule is
    hard-coded for 70000 examples, batch size 32 and 2 epochs.

    Returns:
      A `(classifier_model, core_model)` tuple; the classifier carries the
      optimizer in its `optimizer` attribute.
    """
    config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
    classifier, core = bert_models.classifier_model(
        config,
        num_classes,
        input_meta_data['max_seq_length'],
        hub_module_url=FLAGS.hub_module_url,
        hub_module_trainable=False)

    # NOTE(review): the flag is read but never used; the schedule below is
    # fixed at 2 epochs regardless of its value -- confirm this is intended.
    epochs = FLAGS.num_train_epochs

    train_examples = 70000
    batch_size = 32
    steps_in_epoch = int(train_examples / batch_size)
    # Warm up over 10% of the (2-epoch) training run.
    warmup = int(2 * train_examples * 0.1 / batch_size)

    classifier.optimizer = optimization.create_optimizer(
        2e-5, steps_in_epoch * 2, warmup)
    return classifier, core
Ejemplo n.º 14
0
def export_classifier(model_export_path, input_meta_data,
                      restore_model_using_load_weights, bert_config,
                      model_dir):
    """Exports a trained model as a `SavedModel` for inference.

    Args:
      model_export_path: a string specifying the path to the SavedModel
        directory.
      input_meta_data: dictionary containing meta data about input and model;
        `num_labels` and `max_seq_length` are required.
      restore_model_using_load_weights: Whether to use checkpoint.restore()
        API for custom checkpoint or to use model.load_weights() API. There
        are 2 different ways to save checkpoints. One is using
        tf.train.Checkpoint and another is using Keras model.save_weights().
        Custom training loop implementation uses tf.train.Checkpoint API and
        Keras ModelCheckpoint callback internally uses model.save_weights()
        API. Since these two APIs cannot be used together, the model loading
        logic must take into account how the model checkpoint was saved.
      bert_config: Bert configuration used to define the core bert layers.
      model_dir: the directory where the model weights and
        training/evaluation summaries are stored.

    Raises:
      ValueError: if `model_export_path` or `model_dir` is empty or None.
    """
    if not model_export_path:
        raise ValueError('Export path is not specified: %s' %
                         model_export_path)
    if not model_dir:
        # Report the argument that is actually missing; this used to repeat
        # the export-path message and pointed users at the wrong flag.
        raise ValueError('Model directory is not specified: %s' % model_dir)

    # Export uses float32 for now, even if training uses mixed precision.
    tf.keras.mixed_precision.experimental.set_policy('float32')
    # classifier_model returns (classifier, core_model); only the classifier
    # is exported.
    classifier_model = bert_models.classifier_model(
        bert_config, input_meta_data['num_labels'],
        input_meta_data['max_seq_length'])[0]

    model_saving_utils.export_bert_model(
        model_export_path,
        model=classifier_model,
        checkpoint_dir=model_dir,
        restore_model_using_load_weights=restore_model_using_load_weights)
Ejemplo n.º 15
0
def custom_main(custom_callbacks=None, custom_metrics=None):
    """Exports, predicts with, or trains and evaluates a BERT classifier.

    The action is selected by `FLAGS.mode`:
      * 'export_only': export the model via `export_classifier` and return.
      * 'predict': restore a checkpoint, run prediction on the eval data and
        write the class probabilities to `test_results.tsv` in the model
        directory.
      * 'train_and_eval': hand over to `run_bert`.

    Args:
      custom_callbacks: list of tf.keras.Callbacks passed to training loop.
      custom_metrics: list of metrics passed to the training loop.

    Raises:
      ValueError: if `FLAGS.mode` is none of the modes listed above.
      AssertionError: in 'predict' mode, if no checkpoint file is found.
    """
    gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)

    with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as meta_file:
        raw_meta = meta_file.read()
    input_meta_data = json.loads(raw_meta.decode('utf-8'))
    label_type = LABEL_TYPES_MAP[input_meta_data.get('label_type', 'int')]
    include_sample_weights = input_meta_data.get('has_sample_weights', False)

    # Fall back to a fixed scratch directory when none was configured.
    if not FLAGS.model_dir:
        FLAGS.model_dir = '/tmp/bert20/'

    bert_config = bert_configs.BertConfig.from_json_file(
        FLAGS.bert_config_file)

    mode = FLAGS.mode
    if mode == 'export_only':
        export_classifier(FLAGS.model_export_path, input_meta_data,
                          bert_config, FLAGS.model_dir)
        return

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=FLAGS.distribution_strategy,
        num_gpus=FLAGS.num_gpus,
        tpu_address=FLAGS.tpu)

    def _input_fn_for(data_path, batch_size, is_training):
        # Train and eval inputs differ only in path, batch size and mode.
        return get_dataset_fn(
            data_path,
            input_meta_data['max_seq_length'],
            batch_size,
            is_training=is_training,
            label_type=label_type,
            include_sample_weights=include_sample_weights)

    eval_input_fn = _input_fn_for(FLAGS.eval_data_path,
                                  FLAGS.eval_batch_size, False)

    if mode == 'predict':
        with strategy.scope():
            classifier = bert_models.classifier_model(
                bert_config, input_meta_data['num_labels'])[0]
            checkpoint = tf.train.Checkpoint(model=classifier)
            # An explicit checkpoint path wins over the newest one in
            # model_dir.
            checkpoint_file = (FLAGS.predict_checkpoint_path
                               or tf.train.latest_checkpoint(FLAGS.model_dir))
            assert checkpoint_file
            logging.info(
                'Checkpoint file %s found and restoring from '
                'checkpoint', checkpoint_file)
            restore_status = checkpoint.restore(checkpoint_file)
            restore_status.assert_existing_objects_matched()
            preds, _ = get_predictions_and_labels(strategy,
                                                  classifier,
                                                  eval_input_fn,
                                                  return_probs=True)
        results_path = os.path.join(FLAGS.model_dir, 'test_results.tsv')
        with tf.io.gfile.GFile(results_path, 'w') as writer:
            logging.info('***** Predict results *****')
            for probabilities in preds:
                # One line per prediction: tab-separated class probabilities.
                writer.write('\t'.join(map(str, probabilities)) + '\n')
        return

    if mode != 'train_and_eval':
        raise ValueError('Unsupported mode is specified: %s' % mode)

    train_input_fn = _input_fn_for(FLAGS.train_data_path,
                                   FLAGS.train_batch_size, True)
    run_bert(strategy,
             input_meta_data,
             bert_config,
             train_input_fn,
             eval_input_fn,
             custom_callbacks=custom_callbacks,
             custom_metrics=custom_metrics)
Ejemplo n.º 16
0
# Script fragment: prepares input functions, a distribution strategy, the BERT
# classifier, an optimizer and a metric factory for fine-tuning.
# `input_meta_data`, `train_data_path`, `eval_data_path` and
# `bert_config_file` must already be defined before this point.

# Sequence length and label count come from the dataset's meta data.
max_seq_length = input_meta_data['max_seq_length']
num_classes = input_meta_data['num_labels']
batch_size = 32
eval_batch_size = 32
train_input_fn = run_classifier.get_dataset_fn(train_data_path, max_seq_length, batch_size, is_training=True)
eval_input_fn = run_classifier.get_dataset_fn(eval_data_path, max_seq_length, eval_batch_size, is_training=False)

# NOTE(review): 'one_device' is requested together with num_gpus=2. The
# Model Garden helper is believed to reject 'one_device' with more than one
# GPU -- confirm against the installed version and pick either num_gpus=1 or
# a multi-device strategy.
strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy='one_device', num_gpus=2)

# Datasets, model and optimizer are all created under the strategy scope.
with strategy.scope():
  training_dataset = train_input_fn()
  evaluation_dataset = eval_input_fn()
  bert_config = bert_configs.BertConfig.from_json_file(bert_config_file)
  # classifier_model returns the full classifier and its core encoder.
  classifier_model, encoder = bert_models.classifier_model(
      bert_config, num_classes, max_seq_length)

  # Checkpoint object tracking only the encoder. The restore call below is
  # commented out, so no weights are loaded in the visible part of the script.
  checkpoint = tf.train.Checkpoint(model=encoder)
  #checkpoint.restore(ckpt_path).assert_consumed()

  epochs = 3
  train_data_size = input_meta_data['train_data_size']
  # Not used in the visible part of the script.
  eval_data_size = input_meta_data['eval_data_size']
  steps_per_epoch = int(train_data_size / batch_size)
  # Warm up over 10% of all training steps.
  warmup_steps = int(epochs * train_data_size * 0.1 / batch_size)
  optimizer = optimization.create_optimizer(
      2e-5, num_train_steps=steps_per_epoch * epochs, num_warmup_steps=warmup_steps)

  def metric_fn():
    """Returns a fresh sparse categorical accuracy metric named 'test_accuracy'."""
    return tf.keras.metrics.SparseCategoricalAccuracy(
        'test_accuracy', dtype=tf.float32)
Ejemplo n.º 17
0
def custom_main(custom_callbacks=None, custom_metrics=None):
    """Run classification or regression.

    The action is selected by `FLAGS.mode`:
      * 'export_only': export the model via `export_classifier` and return.
      * 'predict': restore a checkpoint, run prediction on the eval data and
        write labels/predictions to a .tsv file plus probabilities and
        indices via `np.save`, all inside `FLAGS.model_dir`.
      * 'train_and_eval': hand over to `run_bert`.

    Args:
      custom_callbacks: list of tf.keras.Callbacks passed to training loop.
      custom_metrics: list of metrics passed to the training loop.

    Raises:
      ValueError: if `FLAGS.mode` is none of the modes listed above.
      AssertionError: in 'predict' mode, if no checkpoint file is found.
    """
    gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)

    with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
        input_meta_data = json.loads(reader.read().decode('utf-8'))
    label_type = LABEL_TYPES_MAP[input_meta_data.get('label_type', 'int')]
    include_sample_weights = input_meta_data.get('has_sample_weights', False)

    # Fall back to a fixed scratch directory when none was configured.
    if not FLAGS.model_dir:
        FLAGS.model_dir = '/tmp/type-bert-finetune/'

    bert_config = bert_configs.BertConfig.from_json_file(
        FLAGS.bert_config_file)

    if FLAGS.mode == 'export_only':
        export_classifier(FLAGS.model_export_path, input_meta_data,
                          bert_config, FLAGS.model_dir)
        return

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=FLAGS.distribution_strategy,
        num_gpus=FLAGS.num_gpus,
        tpu_address=FLAGS.tpu)
    eval_input_fn = get_dataset_fn(
        FLAGS.eval_data_path,
        input_meta_data['max_seq_length'],
        FLAGS.eval_batch_size,
        is_training=False,
        label_type=label_type,
        include_sample_weights=include_sample_weights)

    if FLAGS.mode == 'predict':
        # A single label means the task is treated as regression.
        num_labels = input_meta_data.get('num_labels', 1)
        with strategy.scope():
            classifier_model = bert_models.classifier_model(
                bert_config, num_labels)[0]
            checkpoint = tf.train.Checkpoint(model=classifier_model)
            # An explicit checkpoint path wins over the newest one in
            # model_dir.
            latest_checkpoint_file = (FLAGS.predict_checkpoint_path
                                      or tf.train.latest_checkpoint(
                                          FLAGS.model_dir))
            assert latest_checkpoint_file
            logging.info(
                'Checkpoint file %s found and restoring from '
                'checkpoint', latest_checkpoint_file)
            checkpoint.restore(
                latest_checkpoint_file).assert_existing_objects_matched()
            # Only preds, labels, indices, probs and user_def_bool are used
            # below; the remaining values are unpacked to match the 9-tuple.
            preds, labels, indices, probs, user_def_bool, filelist, input_seqs, seq_indices, times = get_predictions_and_labels(
                strategy,
                classifier_model,
                eval_input_fn,
                is_regression=(num_labels == 1),
                return_probs=True)

            # Derive the file suffix from the checkpoint that was actually
            # restored. Reading FLAGS.predict_checkpoint_path directly raised
            # AttributeError whenever the flag was unset and the latest
            # checkpoint in model_dir had been used instead.
            ckpt = latest_checkpoint_file.split('-')[-1]
            output_predict_file = os.path.join(
                FLAGS.model_dir, 'test_results_{0}.tsv'.format(ckpt))
            output_probs_file = os.path.join(
                FLAGS.model_dir, 'test_probs_results_{0}.nps'.format(ckpt))
            output_indices_file = os.path.join(
                FLAGS.model_dir, 'test_indices_results_{0}.nps'.format(ckpt))
            with tf.io.gfile.GFile(output_predict_file, 'w') as writer:
                logging.info('***** Predict results *****')
                # One line per example: label, prediction, user-defined flag.
                for pred, label, is_user_def in zip(preds, labels,
                                                    user_def_bool):
                    output_line = '\t'.join(
                        [str(label), str(pred),
                         str(is_user_def)]) + '\n'
                    writer.write(output_line)
            indices = np.concatenate(indices, axis=0)
            probs = np.concatenate(probs, axis=0)
            # NOTE(review): np.save appends '.npy' to names that do not
            # already end with it, so these land on disk as '*.nps.npy'.
            np.save(output_probs_file, probs)
            np.save(output_indices_file, indices)
        return

    if FLAGS.mode != 'train_and_eval':
        raise ValueError('Unsupported mode is specified: %s' % FLAGS.mode)
    train_input_fn = get_dataset_fn(
        FLAGS.train_data_path,
        input_meta_data['max_seq_length'],
        FLAGS.train_batch_size,
        is_training=True,
        label_type=label_type,
        include_sample_weights=include_sample_weights)

    if FLAGS.checkpoint_path:
        # NOTE(review): the model restored here is not passed to run_bert
        # below, so in the visible code this block only checks that the
        # checkpoint matches the model structure -- confirm whether run_bert
        # is expected to pick up these weights some other way.
        with strategy.scope():
            num_labels = input_meta_data.get('num_labels', 1)
            classifier_model = bert_models.classifier_model(
                bert_config, num_labels)[0]
            checkpoint = tf.train.Checkpoint(model=classifier_model)
            # The enclosing `if` guarantees the flag is set, so no fallback
            # to the latest checkpoint in model_dir is needed here.
            latest_checkpoint_file = FLAGS.checkpoint_path
            logging.info(
                'Checkpoint file %s found and restoring from '
                'checkpoint', latest_checkpoint_file)
            checkpoint.restore(
                latest_checkpoint_file).assert_existing_objects_matched()

    run_bert(strategy,
             input_meta_data,
             bert_config,
             train_input_fn,
             eval_input_fn,
             custom_callbacks=custom_callbacks,
             custom_metrics=custom_metrics)