コード例 #1
0
def _log_summary(label, values):
    """Log the mean, max, min and standard deviation of ``values``.

    Args:
        label: Metric name spliced into each log line; e.g. the label
            "length |p1 - p2|" produces "Average length |p1 - p2| = ...".
        values: Non-empty sequence of per-example scalars (np.max / np.min
            raise ValueError on an empty sequence).
    """
    tf.logging.info("Average %s = %.8f", label, np.mean(values))
    tf.logging.info("Max %s = %.8f", label, np.max(values))
    tf.logging.info("Min %s = %.8f", label, np.min(values))
    tf.logging.info("Std %s = %.8f", label, np.std(values))


def _format_probabilities(probabilities):
    """Return each class probability formatted as "%.6f", joined by tabs.

    Handles any number of classes rather than assuming exactly three, so it
    works for every task registered in ``main``'s processor table.
    """
    return "\t".join("%.6f" % float(p) for p in probabilities)


def main(_):
    """Compare the predictions of two BERT classifiers on the same inputs.

    Both checkpoints (FLAGS.init_checkpoint1 / FLAGS.init_checkpoint2) are run
    over the examples in FLAGS.diff_input_file, and the following are logged:
      * an agreement score, treating model 1's argmax as the gold label;
      * mean/max/min/std of the L2 distance between the probability vectors;
      * optionally (FLAGS.diff_type "kld1" / "kld2") the same statistics for
        KL(p1 || p2) or KL(p2 || p1).
    If FLAGS.diff_output_file is set, one tab-separated line per example is
    written: text_a, text_b, model 1 probabilities, model 2 probabilities.

    Raises:
        TFCheckpointNotFoundError: if checkpoint1 is missing, or checkpoint2
            is missing and FLAGS.use_random is not set.
        ValueError: if max_seq_length exceeds either model's position
            embeddings, the task is unknown, or the input file is empty.
    """
    tf.logging.set_verbosity(tf.logging.INFO)

    processors = {
        "sst-2": run_classifier.SST2Processor,
        "mnli": run_classifier.MnliProcessor
    }

    tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                  FLAGS.init_checkpoint1)
    tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                  FLAGS.init_checkpoint2)

    if not tf.train.checkpoint_exists(FLAGS.init_checkpoint1):
        raise TFCheckpointNotFoundError("checkpoint1 does not exist!")

    # A missing checkpoint2 is tolerated under FLAGS.use_random (presumably
    # to compare against an untrained model -- confirm intent).
    if not tf.train.checkpoint_exists(FLAGS.init_checkpoint2) and \
       not FLAGS.use_random:
        raise TFCheckpointNotFoundError("checkpoint2 does not exist!")

    bert_config1 = modeling.BertConfig.from_json_file(FLAGS.bert_config_file1)
    bert_config2 = modeling.BertConfig.from_json_file(FLAGS.bert_config_file2)

    # Both models read the same TFRecord, so the sequence length must fit
    # within the position embeddings of *each* model, not just model 1.
    for config in (bert_config1, bert_config2):
        if FLAGS.max_seq_length > config.max_position_embeddings:
            raise ValueError(
                "Cannot use sequence length %d because the BERT model "
                "was only trained up to sequence length %d" %
                (FLAGS.max_seq_length, config.max_position_embeddings))

    task_name = FLAGS.task_name.lower()
    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    label_list = processor.get_labels()

    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)

    predict_examples = processor.get_test_examples(FLAGS.diff_input_file)
    num_actual_predict_examples = len(predict_examples)
    # Fail before running two models: every metric below divides by, or
    # reduces over, the number of examples.
    if num_actual_predict_examples == 0:
        raise ValueError("No examples found in %s" % FLAGS.diff_input_file)

    # Keep the raw text before any padding is appended.
    # For single sentence tasks (like SST2) eg.text_b is None.
    original_data = [(eg.text_a, eg.text_b) for eg in predict_examples]
    if FLAGS.use_tpu:
        # TPU requires a fixed batch size for all batches, therefore the number
        # of examples must be a multiple of the batch size, or else examples
        # will get dropped. So we pad with fake examples which are ignored
        # later on.
        while len(predict_examples) % FLAGS.predict_batch_size != 0:
            predict_examples.append(run_classifier.PaddingInputExample())

    # Features are serialized once and shared by both models.
    predict_file = os.path.join(FLAGS.init_checkpoint1,
                                FLAGS.exp_name + ".predict.tf_record")
    run_classifier.file_based_convert_examples_to_features(
        predict_examples, label_list, FLAGS.max_seq_length, tokenizer,
        predict_file)

    all_results = []
    for bert_config, output_dir in [
        (bert_config1, FLAGS.init_checkpoint1),
        (bert_config2, FLAGS.init_checkpoint2)
    ]:
        tpu_cluster_resolver = None
        if FLAGS.use_tpu and FLAGS.tpu_name:
            tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
                FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

        is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
        # model_dir points at the checkpoint directory being evaluated.
        run_config = tf.contrib.tpu.RunConfig(
            cluster=tpu_cluster_resolver,
            master=FLAGS.master,
            model_dir=output_dir,
            save_checkpoints_steps=FLAGS.save_checkpoints_steps,
            tpu_config=tf.contrib.tpu.TPUConfig(
                iterations_per_loop=FLAGS.iterations_per_loop,
                num_shards=FLAGS.num_tpu_cores,
                per_host_input_for_training=is_per_host))

        model_fn = run_classifier.model_fn_builder(
            bert_config=bert_config,
            num_labels=len(label_list),
            # This init checkpoint is eventually overridden by the estimator.
            # NOTE(review): when checkpoint2 is absent under FLAGS.use_random,
            # confirm model 2 is not initialised from FLAGS.output_dir instead
            # of randomly.
            init_checkpoint=FLAGS.output_dir,
            learning_rate=FLAGS.learning_rate,
            num_train_steps=None,
            num_warmup_steps=None,
            use_tpu=FLAGS.use_tpu,
            use_one_hot_embeddings=FLAGS.use_tpu)

        # If TPU is not available, this will fall back to normal Estimator on CPU
        # or GPU.
        estimator = tf.contrib.tpu.TPUEstimator(
            use_tpu=FLAGS.use_tpu,
            model_fn=model_fn,
            config=run_config,
            train_batch_size=FLAGS.train_batch_size,
            eval_batch_size=FLAGS.eval_batch_size,
            predict_batch_size=FLAGS.predict_batch_size)

        tf.logging.info("***** Running prediction*****")
        tf.logging.info("  Num examples = %d (%d actual, %d padding)",
                        len(predict_examples), num_actual_predict_examples,
                        len(predict_examples) - num_actual_predict_examples)
        tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)

        predict_input_fn = run_classifier.file_based_input_fn_builder(
            input_file=predict_file,
            seq_length=FLAGS.max_seq_length,
            is_training=False,
            drop_remainder=bool(FLAGS.use_tpu))

        all_results.append(list(estimator.predict(input_fn=predict_input_fn)))

    # Drop predictions that belong to TPU padding examples.
    results1 = all_results[0][:num_actual_predict_examples]
    results2 = all_results[1][:num_actual_predict_examples]

    assert len(results1) == len(results2)

    # Assuming model1's predictions are gold labels, calculate model2's accuracy
    score = sum(
        1 for prob1, prob2 in zip(results1, results2)
        if np.argmax(prob1["probabilities"]) == np.argmax(
            prob2["probabilities"]))

    tf.logging.info("Agreement score = %.6f",
                    float(score) / num_actual_predict_examples)

    # Calculate the average value of |v1 - v2|, the distance on the simplex
    # Unlike KL divergence, this is a bounded metric
    # However, these results are not comparable across tasks
    # with different number classes
    distances = [
        np.linalg.norm(prob1["probabilities"] - prob2["probabilities"])
        for prob1, prob2 in zip(results1, results2)
    ]
    _log_summary("length |p1 - p2|", distances)

    # stats.entropy(pk, qk) computes KL(pk || qk); the two diff types only
    # differ in argument order.
    if FLAGS.diff_type == "kld1":
        all_kld = [
            stats.entropy(prob1["probabilities"], prob2["probabilities"])
            for prob1, prob2 in zip(results1, results2)
        ]
        _log_summary("kl-divergence (p1, p2)", all_kld)

    elif FLAGS.diff_type == "kld2":
        all_kld = [
            stats.entropy(prob2["probabilities"], prob1["probabilities"])
            for prob1, prob2 in zip(results1, results2)
        ]
        _log_summary("kl-divergence (p2, p1)", all_kld)

    if FLAGS.diff_output_file:
        with tf.gfile.GFile(FLAGS.diff_output_file, "w") as f:
            for i, (eg, prob1, prob2) in enumerate(
                    zip(original_data, results1, results2)):

                if i % 1000 == 0:
                    tf.logging.info("Writing instance %d", i + 1)

                # Columns: text_a, text_b, then one column per class for
                # each model, for any number of labels.
                f.write("%s\t%s\t%s\t%s\n" % (
                    eg[0], eg[1],
                    _format_probabilities(prob1["probabilities"]),
                    _format_probabilities(prob2["probabilities"])))
コード例 #2
0
def _num_batches(num_examples, batch_size):
  """Return the number of ``batch_size`` batches covering ``num_examples``."""
  # Integer ceil-division: no extra step when num_examples divides evenly.
  return -(-num_examples // batch_size)


def main(_):
  """Build "mixup" sentences by interpolating BERT word embeddings.

  Two one-shot iterators over the same prediction dataset supply paired
  batches. Their non-contextual word embeddings are interpolated with a
  weight drawn from Beta(alpha, alpha) or fixed to alpha. Each interpolated
  vector is then mapped to its nearest neighbour in the embedding table
  returned by em_util.run_bert_embeddings, and the resulting ids are
  converted back to tokens.

  One tab-separated line per example (input 1 tokens, input 2 tokens,
  nearest-neighbour tokens) is written to FLAGS.predict_output_file. The
  count of unmasked positions whose neighbour differs from both inputs is
  logged.

  Raises:
    ValueError: if FLAGS.task_name or FLAGS.interpolate_scheme is unknown.
  """
  tf.logging.set_verbosity(tf.logging.INFO)

  processors = {
      "sst-2": rc.SST2Processor,
      "mnli": rc.MnliProcessor,
  }

  tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                FLAGS.init_checkpoint)
  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

  # Validate flags before any expensive work (feature conversion, graph
  # building); previously an unknown scheme surfaced as a NameError.
  task_name = FLAGS.task_name.lower()
  if task_name not in processors:
    raise ValueError("Task not found: %s" % (task_name))
  if FLAGS.interpolate_scheme not in ("beta", "fixed"):
    raise ValueError(
        "Unknown interpolate_scheme: %s" % FLAGS.interpolate_scheme)

  processor = processors[task_name]()
  label_list = processor.get_labels()
  predict_examples = processor.get_test_examples(FLAGS.predict_input_file)

  tokenizer = tokenization.FullTokenizer(
      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

  predict_file = os.path.join(FLAGS.output_dir,
                              "mixup_%s.tf_record" % FLAGS.exp_name)

  rc.file_based_convert_examples_to_features(predict_examples, label_list,
                                             FLAGS.max_seq_length, tokenizer,
                                             predict_file)

  # NOTE(review): is_training=True is presumably chosen so the pipeline
  # shuffles/repeats, letting the two iterators below pair different
  # examples -- confirm against rc.file_based_input_fn_builder.
  predict_input_fn = rc.file_based_input_fn_builder(
      input_file=predict_file,
      seq_length=FLAGS.max_seq_length,
      is_training=True,
      drop_remainder=False)

  predict_dataset = predict_input_fn({"batch_size": FLAGS.predict_batch_size})

  # Two independent iterators over the same dataset provide the two sides
  # of each mixup pair.
  predict_iterator1 = predict_dataset.make_one_shot_iterator()
  predict_iterator2 = predict_dataset.make_one_shot_iterator()

  predict_dict1 = predict_iterator1.get_next()
  predict_dict2 = predict_iterator2.get_next()

  # Extract only the BERT non-contextual word embeddings, see their outputs
  embed1_out, embed_var = em_util.run_bert_embeddings(
      predict_dict1["input_ids"], bert_config)
  embed2_out, _ = em_util.run_bert_embeddings(predict_dict2["input_ids"],
                                              bert_config)

  if FLAGS.interpolate_scheme == "beta":
    # Interpolate two embeddings using samples from a beta(alpha, alpha) distro
    beta_distro = tf.distributions.Beta(FLAGS.alpha, FLAGS.alpha)
    interpolate = beta_distro.sample()
  else:
    # "fixed" (validated above): use a fixed interpolation constant
    interpolate = tf.constant(FLAGS.alpha)

  new_embed = interpolate * embed1_out + (1 - interpolate) * embed2_out

  # Get nearest neighbour in embedding space for interpolated embeddings
  nearest_neighbour, _ = em_util.get_nearest_neighbour(
      source=new_embed, reference=embed_var)
  nearest_neighbour = tf.cast(nearest_neighbour, tf.int32)

  # Check whether nearest neighbour is a new word
  new_vectors = tf.logical_and(
      tf.not_equal(nearest_neighbour, predict_dict1["input_ids"]),
      tf.not_equal(nearest_neighbour, predict_dict2["input_ids"]))

  # Combine the two input masks
  token_mask = tf.logical_or(
      tf.cast(predict_dict1["input_mask"], tf.bool),
      tf.cast(predict_dict2["input_mask"], tf.bool))

  # Mask out new vectors with original tokens mask
  new_vectors_masked = tf.logical_and(new_vectors, token_mask)

  tvars = tf.trainable_variables()

  assignment_map, _ = modeling.get_assignment_map_from_checkpoint(
      tvars, FLAGS.init_checkpoint)

  # init_from_checkpoint replaces the variables' initializers, so running
  # global_variables_initializer() below loads the checkpoint values.
  tf.train.init_from_checkpoint(FLAGS.init_checkpoint, assignment_map)

  total_score = 0
  total_tokens = 0
  total_steps = _num_batches(len(predict_examples), FLAGS.predict_batch_size)

  # Count the total words where new embeddings are produced via interpolation
  all_predict_input1 = []
  all_predict_input2 = []
  all_nearest_neighbours = []

  # The context manager closes the session on exit (it was leaked before).
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    for i in range(total_steps):
      tf.logging.info("%d/%d, total_score = %d / %d", i, total_steps,
                      total_score, total_tokens)
      pd1, pd2, nn, tm, nvm = sess.run([
          predict_dict1, predict_dict2, nearest_neighbour, token_mask,
          new_vectors_masked
      ])

      # populate global lists of inputs and mix-ups
      all_nearest_neighbours.extend(nn.tolist())
      all_predict_input1.extend(pd1["input_ids"].tolist())
      all_predict_input2.extend(pd2["input_ids"].tolist())
      total_score += nvm.sum()
      total_tokens += tm.sum()

  tf.logging.info("Total score = %d", total_score)

  with tf.gfile.GFile(FLAGS.predict_output_file, "w") as f:
    for pd1, pd2, nn in zip(all_predict_input1, all_predict_input2,
                            all_nearest_neighbours):
      pd1_sent = " ".join(tokenizer.convert_ids_to_tokens(pd1))
      pd2_sent = " ".join(tokenizer.convert_ids_to_tokens(pd2))
      nn_sent = " ".join(tokenizer.convert_ids_to_tokens(nn))
      f.write(pd1_sent + "\t" + pd2_sent + "\t" + nn_sent + "\n")