Example #1
import math

import pandas as pd

# `multilabel` is the project module providing predict_routine and LABEL_COLUMN.


def max_variance_balanced(tokenizer, estimator, data_dir, model_dir, pool_name,
                          n_instances, sample_size, label_distn,
                          max_seq_length, use_tpu, batch_size,
                          overwrite_tfrecord):
    """Queries the pool rows with the highest predictive variance per label.

    The per-label query budget follows label_distn. Returns a tuple of
    (remaining pool, queried rows).
    """
    pool_df = pd.read_csv(f"{data_dir}/{pool_name}.tsv", sep="\t")
    # Round the per-label budget up so each label gets at least its share.
    n_instances_per_label = [
        int(math.ceil(probs * n_instances)) for probs in label_distn
    ]
    output = multilabel.predict_routine(pool_name, data_dir, model_dir,
                                        max_seq_length, tokenizer, estimator,
                                        use_tpu, batch_size, sample_size,
                                        overwrite_tfrecord)
    result_df = pd.DataFrame(output)
    variance_df = pd.DataFrame(result_df["variance"].tolist(),
                               index=result_df.index,
                               columns=multilabel.LABEL_COLUMN)
    query_idx = []
    avail_idx = set(range(len(variance_df)))
    for i in range(len(multilabel.LABEL_COLUMN)):
        chosen_idx = set(
            variance_df[variance_df.index.isin(avail_idx)].nlargest(
                n_instances_per_label[i],
                multilabel.LABEL_COLUMN[i],
                keep="all").index)
        query_idx.extend(chosen_idx)
        avail_idx -= chosen_idx
    return pool_df[~pool_df.index.isin(query_idx)], pool_df.iloc[query_idx]
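The function returns the untouched remainder of the pool together with the queried rows. A minimal sketch of one acquisition round, assuming a tokenizer and estimator already built as in Example #3, a hypothetical pool file named "pool", and a uniform label distribution (all values here are placeholders, not from the source):

# Hypothetical driver code; paths, sizes and flags are illustrative only.
n_labels = len(multilabel.LABEL_COLUMN)
remaining_pool, queried = max_variance_balanced(
    tokenizer, estimator,
    data_dir="data", model_dir="model", pool_name="pool",
    n_instances=100, sample_size=10,
    label_distn=[1.0 / n_labels] * n_labels,  # uniform per-label budget
    max_seq_length=128, use_tpu=False, batch_size=32,
    overwrite_tfrecord=True)
# `queried` would go to annotators; the remainder becomes the next round's pool.
remaining_pool.to_csv("data/pool.tsv", sep="\t", index=False)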
Example #2
import numpy as np
import pandas as pd

# `multilabel` is the project module providing predict_routine and LABEL_COLUMN.


def margin(tokenizer, estimator, data_dir, model_dir, pool_name, n_instances,
           sample_size, label_distn, max_seq_length, use_tpu, batch_size,
           overwrite_tfrecord):
    """Queries the pool rows closest to the 0.5 decision boundary per label.

    The query budget is split evenly across labels. Returns a tuple of
    (remaining pool, queried rows).
    """
    del label_distn  # Unused.
    pool_df = pd.read_csv(f"{data_dir}/{pool_name}.tsv", sep="\t")
    n_instances_per_label = int(n_instances / len(multilabel.LABEL_COLUMN))

    output = multilabel.predict_routine(pool_name, data_dir, model_dir,
                                        max_seq_length, tokenizer, estimator,
                                        use_tpu, batch_size, sample_size,
                                        overwrite_tfrecord)
    result_df = pd.DataFrame(output)
    probs_df = pd.DataFrame(result_df["probs"].tolist(),
                            index=result_df.index,
                            columns=multilabel.LABEL_COLUMN)
    margin_df = np.abs(probs_df - 0.5)
    query_idx = []
    avail_idx = set(range(len(margin_df)))
    for label in multilabel.LABEL_COLUMN:
        chosen_idx = set(margin_df[margin_df.index.isin(avail_idx)].nsmallest(
            n_instances_per_label, label).index)
        query_idx.extend(chosen_idx)
        avail_idx -= chosen_idx
    return pool_df[~pool_df.index.isin(query_idx)], pool_df.iloc[query_idx]
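Here the margin is the distance of each predicted probability from the 0.5 decision boundary, and per label the rows with the smallest margins (the most uncertain predictions) are queried. A toy illustration of that selection step, using made-up probabilities and two hypothetical label names:

import numpy as np
import pandas as pd

# Made-up probabilities for two hypothetical labels.
probs_df = pd.DataFrame({"toxic": [0.90, 0.52, 0.10, 0.48],
                         "spam": [0.20, 0.80, 0.55, 0.05]})
margin_df = np.abs(probs_df - 0.5)
# Rows 1 and 3 sit closest to the boundary for "toxic".
print(margin_df.nsmallest(2, "toxic").index.tolist())  # [1, 3]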
Example #3
# Module-level dependencies not shown in this snippet: numpy as np,
# scipy.special.expit, tensorflow as tf (1.x-style API), the BERT `modeling`
# and `tokenization` modules, the project's `multilabel` module, and FLAGS
# defined via tf/absl flags.
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)

    tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                  FLAGS.init_checkpoint)

    if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
        raise ValueError(
            "At least one of `do_train`, `do_eval` or `do_predict' must be True."
        )

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    tf.io.gfile.makedirs(FLAGS.output_dir)
    tf.logging.info("***** FLAGS *****")
    writer = tf.io.gfile.GFile(f"{FLAGS.output_dir}/flags.txt", "w+")
    for key, val in FLAGS.__flags.items():
        tf.logging.info("  %s = %s", key, str(val.value))
        writer.write("%s = %s\n" % (key, str(val.value)))
    writer.close()

    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)

    tpu_cluster_resolver = None
    if FLAGS.use_tpu and FLAGS.tpu_name:
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
    is_per_host = tf.compat.v1.estimator.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf.compat.v1.estimator.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=FLAGS.master,
        tpu_config=tf.compat.v1.estimator.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_tpu_cores,
            per_host_input_for_training=is_per_host),
        model_dir=FLAGS.output_dir,
        tf_random_seed=100,
        save_summary_steps=FLAGS.save_summary_steps,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        keep_checkpoint_max=FLAGS.keep_checkpoint_max,
        log_step_count_steps=FLAGS.log_step_count_steps)

    train_task_name = FLAGS.task_name if FLAGS.task_name else "train"
    _, num_train_examples, _, num_train_steps = multilabel.prepare_tfrecords(
        train_task_name, FLAGS.data_dir, FLAGS.max_seq_length, tokenizer,
        False, FLAGS.train_batch_size, 1, FLAGS.convert_tsv_to_tfrecord)
    num_train_steps = int(num_train_steps * FLAGS.num_train_epochs)
    num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    model_fn = multilabel.model_fn_builder(
        bert_config=bert_config,
        num_labels=len(multilabel.LABEL_COLUMN),
        init_checkpoint=tf.train.latest_checkpoint(FLAGS.init_checkpoint),
        learning_rate=FLAGS.learning_rate,
        num_train_steps=num_train_steps,
        num_warmup_steps=num_warmup_steps,
        use_tpu=FLAGS.use_tpu,
        use_one_hot_embeddings=FLAGS.use_tpu,
        freeze_bert=FLAGS.freeze_bert,
        finetune_module=FLAGS.finetune_module,
        num_train_examples=num_train_examples)

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    estimator = tf.estimator.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        predict_batch_size=FLAGS.predict_batch_size)

    if FLAGS.do_train and FLAGS.do_eval and (not FLAGS.use_tpu):
        eval_task_name = "eval"
        _, num_eval_examples, num_padded_eval_examples, num_eval_steps = multilabel.prepare_tfrecords(
            eval_task_name, FLAGS.data_dir, FLAGS.max_seq_length, tokenizer,
            FLAGS.use_tpu, FLAGS.eval_batch_size, 1,
            FLAGS.convert_tsv_to_tfrecord)

        tf.logging.info("***** Running training and evaluation *****")
        tf.logging.info("  Num training examples = %d", num_train_examples)
        tf.logging.info("  Training batch size = %d", FLAGS.train_batch_size)
        tf.logging.info("  Num training steps = %d", num_train_steps)
        tf.logging.info("  Num eval examples = %d ", num_eval_examples)
        tf.logging.info("  Eval batch size = %d", FLAGS.eval_batch_size)
        tf.logging.info("  Num eval steps = %d", num_eval_steps)

        train_input_fn = multilabel.file_based_input_fn_builder(
            input_file=f"{FLAGS.data_dir}/{train_task_name}-1.tfrecord",
            seq_length=FLAGS.max_seq_length,
            is_training=True,
            drop_remainder=True)
        eval_input_fn = multilabel.file_based_input_fn_builder(
            input_file=f"{FLAGS.data_dir}/{eval_task_name}-1.tfrecord",
            seq_length=FLAGS.max_seq_length,
            is_training=False,
            drop_remainder=False)
        train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn,
                                            max_steps=num_train_steps)
        eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn,
                                          steps=num_eval_steps,
                                          start_delay_secs=60,
                                          throttle_secs=120)
        tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
    else:
        if FLAGS.do_train:
            tf.logging.info("***** Running training *****")
            tf.logging.info("  Num examples = %d", num_train_examples)
            tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
            tf.logging.info("  Num steps = %d", num_train_steps)
            train_input_fn = multilabel.file_based_input_fn_builder(
                input_file=f"{FLAGS.data_dir}/{train_task_name}-1.tfrecord",
                seq_length=FLAGS.max_seq_length,
                is_training=True,
                drop_remainder=True)
            estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)

        if FLAGS.do_eval:
            eval_task_name = "eval"
            ckpt_steps = list(
                range(0, num_train_steps, FLAGS.save_checkpoints_steps))
            ckpt_steps.append(num_train_steps)
            best_result = multilabel.eval_routine(
                eval_task_name, FLAGS.data_dir, FLAGS.output_dir,
                FLAGS.max_seq_length, tokenizer, estimator, FLAGS.use_tpu,
                FLAGS.predict_batch_size, ckpt_steps, 1,
                FLAGS.convert_tsv_to_tfrecord)

    if FLAGS.do_predict:
        predict_task_name = "predict"
        output = multilabel.predict_routine(
            predict_task_name, FLAGS.data_dir, FLAGS.output_dir,
            FLAGS.max_seq_length, tokenizer, estimator, FLAGS.use_tpu,
            FLAGS.predict_batch_size, FLAGS.sample_size,
            FLAGS.convert_tsv_to_tfrecord)
        file_name = f"{FLAGS.output_dir}/{predict_task_name}-{FLAGS.sample_size}-results.tsv"
        with tf.io.gfile.GFile(file_name, "w+") as writer:
            num_written_items = 0
            # Header: id and language, then per-label columns for each statistic.
            header_cols = [multilabel.ID_COLUMN, multilabel.LANG_COLUMN]
            for suffix in (" prob", " var", " ci_lb", " ci_ub", " ci_68_lb",
                           " ci_68_ub"):
                header_cols.extend(name + suffix
                                   for name in multilabel.LABEL_COLUMN)
            writer.write("\t".join(header_cols) + "\n")
            for (i, pred) in enumerate(output):
                logits = np.array(pred["logits"])
                vars = np.array(pred["variance"])
                std_dev = np.sqrt(vars)
                lower_95_ci = expit(logits - 2 * std_dev)
                upper_95_ci = expit(logits + 2 * std_dev)
                lower_68_ci = expit(logits - std_dev)
                upper_68_ci = expit(logits + std_dev)
                output_line = pred[multilabel.ID_COLUMN] + "\t" + pred[
                    multilabel.LANG_COLUMN] + "\t"
                output_line += "\t".join(
                    str(class_prob) for class_prob in pred["probs"]) + "\t"
                output_line += "\t".join(str(class_var)
                                         for class_var in vars) + "\t"
                output_line += "\t".join(str(lb) for lb in lower_95_ci) + "\t"
                output_line += "\t".join(str(ub) for ub in upper_95_ci) + "\t"
                output_line += "\t".join(str(lb) for lb in lower_68_ci) + "\t"
                output_line += "\t".join(str(ub) for ub in upper_68_ci) + "\n"
                writer.write(output_line)
                num_written_items += 1
            assert num_written_items == len(output)
        tf.logging.info(f"Prediction results written to {file_name}")