Example #1
def predict_coqa_customized(strategy, input_meta_data, bert_config,
                            predict_tfrecord_path, num_steps):
    """Make predictions using a Bert-based coqa model."""
    primary_cpu_task = '/job:worker' if FLAGS.tpu else ''

    num_train_examples = input_meta_data['train_data_size']
    max_seq_length = input_meta_data['max_seq_length']
    max_answer_length = input_meta_data['max_answer_length']

    # Propagate pointer-generator settings into the config.
    bert_config.add_from_dict({"use_pointer_gen": FLAGS.use_pointer_gen})
    # max_oov_size: a provisional value for now.
    bert_config.add_from_dict({"max_oov_size": FLAGS.max_oov_size})
    bert_config.add_from_dict({"max_seq_length": max_seq_length})
    bert_config.add_from_dict({"max_answer_length": max_answer_length})

    with tf.device(primary_cpu_task):
        predict_dataset = input_pipeline.create_coqa_dataset_seq2seq(
            predict_tfrecord_path,
            input_meta_data['max_seq_length'],
            max_answer_length,
            FLAGS.predict_batch_size,
            is_training=False)
        predict_iterator = iter(
            strategy.experimental_distribute_dataset(predict_dataset))

        with strategy.scope():
            # Prediction always uses float32, even if training uses mixed precision.
            #tf.keras.mixed_precision.experimental.set_policy('float32')
            coqa_model, _ = coqa_models.coqa_model_bert_2heads(
                config=bert_config,
                max_seq_length=input_meta_data['max_seq_length'],
                max_answer_length=max_answer_length,
                float_type=tf.float32)

        checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)
        logging.info('Restoring checkpoints from %s', checkpoint_path)
        checkpoint = tf.train.Checkpoint(model=coqa_model)
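        # expect_partial() silences warnings about checkpoint values that are
        # not restored at inference time (e.g. optimizer slots).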
        checkpoint.restore(checkpoint_path).expect_partial()

        def decode_sequence(x):
            """Greedily decodes the answer one token at a time."""
            pred_ids = tf.unstack(x['decode_ids'], axis=-1)
            pred_mask = tf.unstack(x['decode_mask'], axis=-1)

            for i in range(1, bert_config.max_answer_length):
                unique_ids, logits, _, _ = coqa_model(
                    inputs=({
                        'unique_ids': x['unique_ids'],
                        'input_word_ids': x['input_word_ids'],
                        'input_type_ids': x['input_type_ids'],
                        'input_mask': x['input_mask'],
                        'decode_ids': tf.stack(pred_ids, axis=1),
                        'decode_mask': tf.stack(pred_mask, axis=1),
                    }),
                    training=False)

                next_pred = tf.argmax(logits, axis=-1, output_type=tf.int32)

                # Only update the i-th column in one step; a predicted token
                # id of 105 is treated as the stop token and masked out.
                pred_ids[i] = next_pred[:, i - 1]
                pred_mask[i] = tf.cast(
                    tf.not_equal(next_pred[:, i - 1], 105), tf.int32)
            return x['unique_ids'], next_pred

        #@tf.function
        def predict_step(iterator):
            """Predicts on distributed devices."""
            def _replicated_step(inputs):
                """Replicated prediction calculation."""
                x, _ = inputs
                unique_ids, sentence_ids = decode_sequence(x)

                return dict(unique_ids=unique_ids, sentence_ids=sentence_ids)

            # Note: experimental_run_v2 was renamed strategy.run in TF 2.2+.
            outputs = strategy.experimental_run_v2(_replicated_step,
                                                   args=(next(iterator),))
            return outputs

        all_results = []
        for _ in range(num_steps):
            predictions = predict_step(predict_iterator)

            for result in get_raw_results(predictions):
                all_results.append(result)
            if len(all_results) % 100 == 0:
                logging.info('Made predictions for %d records.',
                             len(all_results))
        return all_results
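
A minimal driver sketch for the example above. The strategy choice and the
names FLAGS.input_meta_data_path, FLAGS.predict_file, and the eval_data_size
metadata key are illustrative assumptions, not part of the example:

import json

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
    input_meta_data = json.loads(reader.read().decode('utf-8'))

# One predict_step per batch of eval records (eval_data_size is assumed).
num_steps = input_meta_data['eval_data_size'] // FLAGS.predict_batch_size
all_results = predict_coqa_customized(strategy, input_meta_data, bert_config,
                                      FLAGS.predict_file, num_steps)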
Example #2
def predict_coqa_customized(strategy, input_meta_data, bert_config,
                            predict_tfrecord_path, num_steps):
  """Make predictions using a Bert-based coqa model."""
  primary_cpu_task = '/job:worker' if FLAGS.tpu else ''

  num_train_examples = input_meta_data['train_data_size']
  max_seq_length = input_meta_data['max_seq_length']
  max_answer_length = input_meta_data['max_answer_length']

  # Propagate pointer-generator settings into the config.
  bert_config.add_from_dict({"use_pointer_gen": FLAGS.use_pointer_gen})
  # max_oov_size: a provisional value for now.
  bert_config.add_from_dict({"max_oov_size": FLAGS.max_oov_size})
  bert_config.add_from_dict({"max_seq_length": max_seq_length})

  with tf.device(primary_cpu_task):
    predict_dataset = input_pipeline.create_coqa_dataset_seq2seq(
        predict_tfrecord_path,
        input_meta_data['max_seq_length'],
        max_answer_length,
        FLAGS.predict_batch_size,
        is_training=False)
    predict_iterator = iter(
        strategy.experimental_distribute_dataset(predict_dataset))

    with strategy.scope():
      # Prediction always uses float32, even if training uses mixed precision.
      #tf.keras.mixed_precision.experimental.set_policy('float32')
      coqa_model, _ = coqa_models.coqa_modelseq2seq(
          config=bert_config,
          max_seq_length=input_meta_data['max_seq_length'],
          max_answer_length=max_answer_length,
          max_oov_size=FLAGS.max_oov_size,
          float_type=tf.float32)

    checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)
    logging.info('Restoring checkpoints from %s', checkpoint_path)
    checkpoint = tf.train.Checkpoint(model=coqa_model)
    checkpoint.restore(checkpoint_path).expect_partial()

    # Split the trained model into an encoder (run once per batch) and a
    # one-step decoder (run once per generated token).
    encoder, decoder = coqa_models.one_step_decoder_model(coqa_model, bert_config)

    def decode_sequence(x):
      """Greedy decoding with the split encoder / one-step decoder."""
      # Encode the input once to get the encoder features and initial state.
      states_value, enc_feature = encoder(
          [x['input_word_ids'], x['input_mask'], x['input_type_ids']])

      # Seed the decoder with a length-1 target sequence (the first decode id).
      target_seq = tf.unstack(x['decode_ids'], axis=1)[0]
      target_seq = tf.expand_dims(target_seq, axis=1)

      # Sampling loop for a batch of sequences: each predicted token becomes
      # the decoder input for the next step.
      steps = 0
      results = []
      while steps < FLAGS.max_answer_length:
        output_tokens, h, c = decoder(
            [target_seq] + states_value + [x['input_mask']] + [enc_feature])
        results.append(output_tokens)

        target_seq = output_tokens
        # Update the recurrent states.
        states_value = [h, c]
        steps += 1

      decoded_sentences = tf.squeeze(tf.stack(results, axis=1))

      return x['unique_ids'], decoded_sentences


    @tf.function
    def predict_step(iterator):
      """Predicts on distributed devices."""

      def _replicated_step(inputs):
        """Replicated prediction calculation."""
        x, _ = inputs
        unique_ids, sentence_ids = decode_sequence(x)

        return dict(unique_ids=unique_ids, sentence_ids=sentence_ids)

      # Note: experimental_run_v2 was renamed strategy.run in TF 2.2+.
      outputs = strategy.experimental_run_v2(
          _replicated_step, args=(next(iterator),))
      return outputs

    all_results = []
    for _ in range(num_steps):
      predictions = predict_step(predict_iterator)

      for result in get_raw_results(predictions):
        all_results.append(result)
      if len(all_results) % 100 == 0:
        logging.info('Made predictions for %d records.', len(all_results))
    return all_results
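
Both examples implement the same greedy autoregressive loop: run the decoder,
take the argmax token at the current position, and feed it back in as the
next input. A self-contained sketch of that pattern, with a toy step_fn
standing in for the real decoder call (all names are illustrative):

import tensorflow as tf

def greedy_decode(step_fn, start_ids, max_len):
    """Greedy autoregressive decoding.

    step_fn: maps [batch, t] int32 token ids to [batch, t, vocab] logits.
    start_ids: [batch, 1] start-token ids.
    """
    ids = start_ids
    for _ in range(max_len - 1):
        logits = step_fn(ids)                      # [batch, t, vocab]
        next_id = tf.argmax(logits[:, -1], axis=-1,
                            output_type=tf.int32)  # last position only
        ids = tf.concat([ids, next_id[:, None]], axis=1)
    return ids

# Usage: answer_ids = greedy_decode(decoder_fn, start_ids, max_len=30)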