Example No. 1
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)

  tokenizer = tokenization.FullTokenizer(
      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

  input_files = []
  for input_pattern in FLAGS.input_file.split(","):
    input_files.extend(tf.gfile.Glob(input_pattern))

  tf.logging.info("*** Reading from input files ***")
  for input_file in input_files:
    tf.logging.info("  %s", input_file)

  rng = random.Random(FLAGS.random_seed)
  instances = create_training_instances(
      input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
      FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
      rng)

  output_files = FLAGS.output_file.split(",")
  tf.logging.info("*** Writing to output files ***")
  for output_file in output_files:
    tf.logging.info("  %s", output_file)

  write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length,
                                  FLAGS.max_predictions_per_seq, output_files)
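
For reference, a minimal sketch of how the flags consumed by this main() are typically declared in BERT's pre-training data script. The flag names are taken from the call sites above; the defaults shown are illustrative assumptions, and the sketch is meant to sit alongside the main() above.

import tensorflow as tf

flags = tf.flags
FLAGS = flags.FLAGS

# Flag names mirror the call sites in main() above; defaults are assumptions.
flags.DEFINE_string("input_file", None, "Comma-separated input files or glob patterns.")
flags.DEFINE_string("output_file", None, "Comma-separated output TFRecord files.")
flags.DEFINE_string("vocab_file", None, "WordPiece vocabulary file.")
flags.DEFINE_bool("do_lower_case", True, "Whether to lowercase the input text.")
flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.")
flags.DEFINE_integer("max_predictions_per_seq", 20, "Maximum masked LM predictions per sequence.")
flags.DEFINE_integer("dupe_factor", 10, "Times to duplicate the input with different masks.")
flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.")
flags.DEFINE_float("short_seq_prob", 0.1, "Probability of creating shorter-than-maximum sequences.")
flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.")

if __name__ == "__main__":
    flags.mark_flag_as_required("input_file")
    flags.mark_flag_as_required("output_file")
    flags.mark_flag_as_required("vocab_file")
    tf.app.run()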
Example No. 2
    def __init__(self):

        self.output_dir = cur + '/' + 'model_data'
        self.ckpt = cur + '/' + 'model_data/model.ckpt-10859'
        #tf.logging.set_verbosity(tf.logging.INFO)
        self.RawResult = collections.namedtuple(
            "RawResult", ["unique_id", "start_logits", "end_logits"])

        self.bert_config = modeling.BertConfig.from_json_file(
            cur + '/' + 'model_data/bert_config.json')
        self.tokenizer = tokenization.FullTokenizer(vocab_file=cur + '/' +
                                                    'model_data/vocab.txt',
                                                    do_lower_case=True)

        tpu_cluster_resolver = None
        is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
        run_config = tf.contrib.tpu.RunConfig(
            cluster=tpu_cluster_resolver,
            master=None,
            model_dir=self.ckpt,
            save_checkpoints_steps=1000,
            tpu_config=tf.contrib.tpu.TPUConfig(
                iterations_per_loop=1000,
                num_shards=8,
                per_host_input_for_training=is_per_host))

        model_fn = model_fn_builder(bert_config=self.bert_config,
                                    init_checkpoint=self.ckpt,
                                    learning_rate=3e-5,
                                    num_train_steps=2,  # unused: no training is run here
                                    num_warmup_steps=None,
                                    use_tpu=False,
                                    use_one_hot_embeddings=False)

        # If TPU is not available, this will fall back to normal Estimator on CPU
        # or GPU.
        self.estimator = tf.contrib.tpu.TPUEstimator(use_tpu=False,
                                                     model_fn=model_fn,
                                                     config=run_config,
                                                     train_batch_size=32,
                                                     predict_batch_size=8)
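
The `cur` prefix used for every model_data path is not defined in this excerpt; presumably it points at the module's own directory. A common definition (an assumption, not shown in the source) would be:

import os

# Assumed definition of `cur`: the directory containing this module, so the
# model_data paths above resolve relative to the source file.
cur = os.path.dirname(os.path.abspath(__file__))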
Example No. 3
    def test_full_tokenizer(self):
        vocab_tokens = [
            "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un",
            "runn", "##ing", ","
        ]
        with tempfile.NamedTemporaryFile(delete=False) as vocab_writer:
            # NamedTemporaryFile opens in binary mode by default, so the
            # vocab text must be encoded before writing (Python 3).
            vocab_writer.write(
                "".join([x + "\n" for x in vocab_tokens]).encode("utf-8"))

            vocab_file = vocab_writer.name

        tokenizer = tokenization.FullTokenizer(vocab_file)
        os.unlink(vocab_file)

        tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
        self.assertAllEqual(tokens,
                            ["un", "##want", "##ed", ",", "runn", "##ing"])

        self.assertAllEqual(tokenizer.convert_tokens_to_ids(tokens),
                            [7, 4, 5, 10, 8, 9])
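
The expected ids are simply each token's line index in the vocab file written above; a standalone check (reusing the same token list) makes the mapping explicit:

# Standalone sanity check of the expected ids: convert_tokens_to_ids maps each
# token to its line index in the vocab file.
vocab_tokens = [
    "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un",
    "runn", "##ing", ","
]
vocab = {token: index for index, token in enumerate(vocab_tokens)}
print([vocab[t] for t in ["un", "##want", "##ed", ",", "runn", "##ing"]])
# -> [7, 4, 5, 10, 8, 9]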
Example No. 4
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)

    layer_indexes = [int(x) for x in FLAGS.layers.split(",")]

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)

    is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf.contrib.tpu.RunConfig(
        master=FLAGS.master,
        tpu_config=tf.contrib.tpu.TPUConfig(
            num_shards=FLAGS.num_tpu_cores,
            per_host_input_for_training=is_per_host))

    examples = read_examples(FLAGS.input_file)

    features = convert_examples_to_features(examples=examples,
                                            seq_length=FLAGS.max_seq_length,
                                            tokenizer=tokenizer)

    unique_id_to_feature = {}
    for feature in features:
        unique_id_to_feature[feature.unique_id] = feature

    model_fn = model_fn_builder(
        bert_config=bert_config,
        init_checkpoint=FLAGS.init_checkpoint,
        layer_indexes=layer_indexes,
        use_tpu=FLAGS.use_tpu,
        use_one_hot_embeddings=FLAGS.use_one_hot_embeddings)

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    estimator = tf.contrib.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        predict_batch_size=FLAGS.batch_size)

    input_fn = input_fn_builder(features=features,
                                seq_length=FLAGS.max_seq_length)

    with codecs.getwriter("utf-8")(tf.gfile.Open(FLAGS.output_file,
                                                 "w")) as writer:
        for result in estimator.predict(input_fn, yield_single_examples=True):
            unique_id = int(result["unique_id"])
            feature = unique_id_to_feature[unique_id]
            output_json = collections.OrderedDict()
            output_json["linex_index"] = unique_id
            all_features = []
            for (i, token) in enumerate(feature.tokens):
                all_layers = []
                for (j, layer_index) in enumerate(layer_indexes):
                    layer_output = result["layer_output_%d" % j]
                    layers = collections.OrderedDict()
                    layers["index"] = layer_index
                    layers["values"] = [
                        round(float(x), 6)
                        for x in layer_output[i:(i + 1)].flat
                    ]
                    all_layers.append(layers)
                features = collections.OrderedDict()
                features["token"] = token
                features["layers"] = all_layers
                all_features.append(features)
            output_json["features"] = all_features
            writer.write(json.dumps(output_json) + "\n")
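
Each line written above is one JSON record keyed by linex_index. A small sketch of reading the features back; "output.jsonl" is a placeholder for whatever FLAGS.output_file was set to:

import json

# Sketch: read back the per-token, per-layer embeddings written by main().
# "output.jsonl" stands in for FLAGS.output_file.
with open("output.jsonl") as f:
    for line in f:
        record = json.loads(line)
        for token_feature in record["features"]:
            layer_indexes = [layer["index"] for layer in token_feature["layers"]]
            print(record["linex_index"], token_feature["token"], layer_indexes)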
Example No. 5
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)

  processors = {
      "cola": ColaProcessor,
      "mnli": MnliProcessor,
      "mrpc": MrpcProcessor,
      "xnli": XnliProcessor,
  }

  if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
    raise ValueError(
        "At least one of `do_train`, `do_eval` or `do_predict' must be True.")

  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

  if FLAGS.max_seq_length > bert_config.max_position_embeddings:
    raise ValueError(
        "Cannot use sequence length %d because the BERT model "
        "was only trained up to sequence length %d" %
        (FLAGS.max_seq_length, bert_config.max_position_embeddings))

  tf.gfile.MakeDirs(FLAGS.output_dir)

  task_name = FLAGS.task_name.lower()

  if task_name not in processors:
    raise ValueError("Task not found: %s" % (task_name))

  processor = processors[task_name]()

  label_list = processor.get_labels()

  tokenizer = tokenization.FullTokenizer(
      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

  tpu_cluster_resolver = None
  if FLAGS.use_tpu and FLAGS.tpu_name:
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

  is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
  run_config = tf.contrib.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      master=FLAGS.master,
      model_dir=FLAGS.output_dir,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      tpu_config=tf.contrib.tpu.TPUConfig(
          iterations_per_loop=FLAGS.iterations_per_loop,
          num_shards=FLAGS.num_tpu_cores,
          per_host_input_for_training=is_per_host))

  train_examples = None
  num_train_steps = None
  num_warmup_steps = None
  if FLAGS.do_train:
    train_examples = processor.get_train_examples(FLAGS.data_dir)
    num_train_steps = int(
        len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs)
    num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

  model_fn = model_fn_builder(
      bert_config=bert_config,
      num_labels=len(label_list),
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=num_train_steps,
      num_warmup_steps=num_warmup_steps,
      use_tpu=FLAGS.use_tpu,
      use_one_hot_embeddings=FLAGS.use_tpu)

  # If TPU is not available, this will fall back to normal Estimator on CPU
  # or GPU.
  estimator = tf.contrib.tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=model_fn,
      config=run_config,
      train_batch_size=FLAGS.train_batch_size,
      eval_batch_size=FLAGS.eval_batch_size,
      predict_batch_size=FLAGS.predict_batch_size)

  if FLAGS.do_train:
    train_file = os.path.join(FLAGS.output_dir, "train.tf_record")
    file_based_convert_examples_to_features(
        train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file)
    tf.logging.info("***** Running training *****")
    tf.logging.info("  Num examples = %d", len(train_examples))
    tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
    tf.logging.info("  Num steps = %d", num_train_steps)
    train_input_fn = file_based_input_fn_builder(
        input_file=train_file,
        seq_length=FLAGS.max_seq_length,
        is_training=True,
        drop_remainder=True)
    estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)

  if FLAGS.do_eval:
    eval_examples = processor.get_dev_examples(FLAGS.data_dir)
    eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")
    file_based_convert_examples_to_features(
        eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file)

    tf.logging.info("***** Running evaluation *****")
    tf.logging.info("  Num examples = %d", len(eval_examples))
    tf.logging.info("  Batch size = %d", FLAGS.eval_batch_size)

    # This tells the estimator to run through the entire set.
    eval_steps = None
    # However, if running eval on the TPU, you will need to specify the
    # number of steps.
    if FLAGS.use_tpu:
      # Eval will be slightly WRONG on the TPU because it will truncate
      # the last batch.
      eval_steps = int(len(eval_examples) / FLAGS.eval_batch_size)

    eval_drop_remainder = FLAGS.use_tpu
    eval_input_fn = file_based_input_fn_builder(
        input_file=eval_file,
        seq_length=FLAGS.max_seq_length,
        is_training=False,
        drop_remainder=eval_drop_remainder)

    result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)

    output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
    with tf.gfile.GFile(output_eval_file, "w") as writer:
      tf.logging.info("***** Eval results *****")
      for key in sorted(result.keys()):
        tf.logging.info("  %s = %s", key, str(result[key]))
        writer.write("%s = %s\n" % (key, str(result[key])))

  if FLAGS.do_predict:
    predict_examples = processor.get_test_examples(FLAGS.data_dir)
    predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
    file_based_convert_examples_to_features(predict_examples, label_list,
                                            FLAGS.max_seq_length, tokenizer,
                                            predict_file)

    tf.logging.info("***** Running prediction*****")
    tf.logging.info("  Num examples = %d", len(predict_examples))
    tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)

    if FLAGS.use_tpu:
      # Warning: according to tpu_estimator.py, prediction on TPU is an
      # experimental feature and hence is not supported here.
      raise ValueError("Prediction on TPU is not supported")

    predict_drop_remainder = FLAGS.use_tpu
    predict_input_fn = file_based_input_fn_builder(
        input_file=predict_file,
        seq_length=FLAGS.max_seq_length,
        is_training=False,
        drop_remainder=predict_drop_remainder)

    result = estimator.predict(input_fn=predict_input_fn)

    output_predict_file = os.path.join(FLAGS.output_dir, "test_results.tsv")
    with tf.gfile.GFile(output_predict_file, "w") as writer:
      tf.logging.info("***** Predict results *****")
      for prediction in result:
        output_line = "\t".join(
            str(class_probability) for class_probability in prediction) + "\n"
        writer.write(output_line)
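
Each row of test_results.tsv holds one example's class probabilities. A minimal sketch of recovering predicted label indices from it (the relative path is an assumption; the file lives under FLAGS.output_dir):

import csv

# Sketch: map each probability row back to a predicted label index.
with open("test_results.tsv") as f:
    for row in csv.reader(f, delimiter="\t"):
        probs = [float(p) for p in row]
        predicted_label = max(range(len(probs)), key=probs.__getitem__)
        print(predicted_label, probs)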
Example No. 6
def learning_inversion():
    assert FLAGS.low_layer_idx == FLAGS.high_layer_idx == -1

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
    num_words = bert_config.vocab_size

    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)

    cls_id = tokenizer.vocab['[CLS]']
    sep_id = tokenizer.vocab['[SEP]']
    mask_id = tokenizer.vocab['[MASK]']

    train_x, train_y, test_x, test_y = load_inversion_data()
    filters = [cls_id, sep_id, mask_id, 0]
    train_y = filter_labels(train_y[0], filters)
    test_y = filter_labels(test_y[0], filters)

    label_freq = count_label_freq(train_y, num_words)
    log('Imbalance ratio: {}'.format(np.max(label_freq) / np.min(label_freq)))

    label_margin = tf.constant(np.reciprocal(label_freq**0.25),
                               dtype=tf.float32)
    C = FLAGS.C

    log('Build attack model for {} words...'.format(num_words))

    encoder_dim = train_x.shape[1]
    inputs = tf.placeholder(tf.float32, (None, encoder_dim), name="inputs")
    labels = tf.placeholder(tf.float32, (None, num_words), name="labels")
    training = tf.placeholder(tf.bool, name='training')

    if FLAGS.model == 'multiset':
        init_word_emb = None
        emb_dim = 512
        model = MultiSetInversionModel(emb_dim,
                                       num_words,
                                       FLAGS.seq_len,
                                       init_word_emb,
                                       C=C,
                                       label_margin=label_margin)
    elif FLAGS.model == 'multilabel':
        model = MultiLabelInversionModel(num_words,
                                         C=C,
                                         label_margin=label_margin)
    else:
        raise ValueError(FLAGS.model)

    preds, loss = model.forward(inputs, labels, training)
    true_pos, false_pos, false_neg = tp_fp_fn_metrics(labels, preds)
    eval_fetch = [loss, true_pos, false_pos, false_neg]

    t_vars = tf.trainable_variables()
    wd = FLAGS.wd
    post_ops = [
        tf.assign(v, v * (1 - wd)) for v in t_vars if 'kernel' in v.name
    ]

    optimizer = tf.train.AdamOptimizer(FLAGS.lr)
    grads_and_vars = optimizer.compute_gradients(
        loss + tf.losses.get_regularization_loss(), t_vars)
    train_ops = optimizer.apply_gradients(
        grads_and_vars, global_step=tf.train.get_or_create_global_step())

    with tf.control_dependencies([train_ops]):
        train_ops = tf.group(*post_ops)

    log('Train attack model with {} data...'.format(len(train_x)))
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        sess.run(tf.global_variables_initializer())

        for epoch in range(30):
            train_iterations = 0
            train_loss = 0

            for batch_idx in iterate_minibatches_indices(
                    len(train_y), FLAGS.batch_size, True):
                one_hot_labels = np.zeros((len(batch_idx), num_words),
                                          dtype=np.float32)
                for i, idx in enumerate(batch_idx):
                    one_hot_labels[i][train_y[idx]] = 1
                feed = {
                    inputs: train_x[batch_idx],
                    labels: one_hot_labels,
                    training: True
                }
                err, _ = sess.run([loss, train_ops], feed_dict=feed)
                train_loss += err
                train_iterations += 1

            test_iterations = 0
            test_loss = 0
            test_tp, test_fp, test_fn = 0, 0, 0

            for batch_idx in iterate_minibatches_indices(len(test_y),
                                                         batch_size=512,
                                                         shuffle=False):
                one_hot_labels = np.zeros((len(batch_idx), num_words),
                                          dtype=np.float32)
                for i, idx in enumerate(batch_idx):
                    one_hot_labels[i][test_y[idx]] = 1
                feed = {
                    inputs: test_x[batch_idx],
                    labels: one_hot_labels,
                    training: False
                }

                fetch = sess.run(eval_fetch, feed_dict=feed)
                err, tp, fp, fn = fetch

                test_iterations += 1
                test_loss += err
                test_tp += tp
                test_fp += fp
                test_fn += fn

            precision = test_tp / (test_tp + test_fp) * 100
            recall = test_tp / (test_tp + test_fn) * 100
            f1 = 2 * precision * recall / (precision + recall)

            log("Epoch: {}, train loss: {:.4f}, test loss: {:.4f}, "
                "pre: {:.2f}%, rec: {:.2f}%, f1: {:.2f}%".format(
                    epoch, train_loss / train_iterations,
                    test_loss / test_iterations, precision, recall, f1))
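
The per-epoch metrics are micro-averaged over the whole test set. Pulled out as a standalone helper, with toy counts for illustration:

# Micro-averaged precision/recall/F1 from accumulated counts, matching the
# epoch-level computation above.
def micro_prf(tp, fp, fn):
    precision = tp / (tp + fp) * 100
    recall = tp / (tp + fn) * 100
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1

print(micro_prf(tp=80.0, fp=20.0, fn=40.0))  # (80.0, ~66.67, ~72.73)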
Example No. 7
def optimization_inversion():
    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)
    cls_id = tokenizer.vocab['[CLS]']
    sep_id = tokenizer.vocab['[SEP]']
    mask_id = tokenizer.vocab['[MASK]']

    _, _, x, y = load_inversion_data()
    filters = [cls_id, sep_id, mask_id]
    y = filter_labels(y[0], filters)

    batch_size = FLAGS.batch_size
    seq_len = FLAGS.seq_len
    max_iters = FLAGS.max_iters

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
    input_ids = tf.ones((batch_size, seq_len + 2), tf.int32)
    input_mask = tf.ones_like(input_ids, tf.int32)
    input_type_ids = tf.zeros_like(input_ids, tf.int32)

    model = modeling.BertModel(config=bert_config,
                               is_training=False,
                               input_ids=input_ids,
                               input_mask=input_mask,
                               token_type_ids=input_type_ids,
                               use_one_hot_embeddings=False)

    bert_vars = tf.trainable_variables()

    (assignment_map,
     _) = modeling.get_assignment_map_from_checkpoint(bert_vars,
                                                      FLAGS.init_checkpoint)
    tf.train.init_from_checkpoint(FLAGS.init_checkpoint, assignment_map)
    word_emb = model.embedding_table

    batch_cls_ids = tf.ones((batch_size, 1), tf.int32) * cls_id
    batch_sep_ids = tf.ones((batch_size, 1), tf.int32) * sep_id
    cls_emb = tf.nn.embedding_lookup(word_emb, batch_cls_ids)
    sep_emb = tf.nn.embedding_lookup(word_emb, batch_sep_ids)

    prob_mask = np.zeros((bert_config.vocab_size, ), np.float32)
    prob_mask[filters] = -1e9
    prob_mask = tf.constant(prob_mask, dtype=np.float32)

    logit_inputs = tf.get_variable(
        name='inputs',
        shape=(batch_size, seq_len, bert_config.vocab_size),
        initializer=tf.random_uniform_initializer(-0.1, 0.1))
    t_vars = [logit_inputs]
    t_var_names = {logit_inputs.name}

    logit_inputs += prob_mask
    prob_inputs = tf.nn.softmax(logit_inputs / FLAGS.temp, axis=-1)
    emb_inputs = tf.matmul(prob_inputs, word_emb)

    emb_inputs = tf.concat([cls_emb, emb_inputs, sep_emb], axis=1)
    if FLAGS.low_layer_idx == 0:
        encoded = mean_pool(emb_inputs, input_mask)
    else:
        encoded = encode(emb_inputs, input_ids, input_mask, input_type_ids,
                         bert_config)
    targets = tf.placeholder(tf.float32,
                             shape=(batch_size, encoded.shape.as_list()[-1]))
    loss = get_similarity_metric(encoded, targets, FLAGS.metric, rtn_loss=True)
    loss = tf.reduce_sum(loss)

    if FLAGS.alpha > 0.:
        # encourage the words to be different
        diff = tf.expand_dims(prob_inputs, 2) - tf.expand_dims(prob_inputs, 1)
        reg = tf.reduce_sum(-tf.exp(tf.reduce_sum(diff**2, axis=-1)), [1, 2])
        loss += FLAGS.alpha * tf.reduce_sum(reg)

    optimizer = tf.train.AdamOptimizer(FLAGS.lr)

    start_vars = set(v.name for v in tf.global_variables()
                     if v.name not in t_var_names)
    grads_and_vars = optimizer.compute_gradients(loss, t_vars)
    train_ops = optimizer.apply_gradients(
        grads_and_vars, global_step=tf.train.get_or_create_global_step())

    end_vars = tf.global_variables()
    new_vars = [v for v in end_vars if v.name not in start_vars]

    preds = tf.argmax(prob_inputs, axis=-1)
    batch_init_ops = tf.variables_initializer(new_vars)

    total_it = len(x) // batch_size

    with tf.Session() as sess:
        sess.run([tf.global_variables_initializer(), tf.tables_initializer()])

        def invert_one_batch(batch_targets):
            sess.run(batch_init_ops)
            feed_dict = {targets: batch_targets}
            prev = 1e6
            for i in range(max_iters):
                curr, _ = sess.run([loss, train_ops], feed_dict)
                # stop if no progress
                if (i + 1) % (max_iters // 10) == 0 and curr > prev:
                    break
                prev = curr
            return sess.run([preds, loss], feed_dict)

        start_time = time.time()
        it = 0.0
        all_tp, all_fp, all_fn, all_err = 0.0, 0.0, 0.0, 0.0

        for batch_idx in iterate_minibatches_indices(len(x), batch_size, False,
                                                     False):
            y_pred, err = invert_one_batch(x[batch_idx])
            tp, fp, fn = tp_fp_fn_metrics_np(y_pred, y[batch_idx])

            # for yp, yt in zip(y_pred, y[batch_idx]):
            #   print(','.join(set(tokenizer.convert_ids_to_tokens(yp))))
            #   print(','.join(set(tokenizer.convert_ids_to_tokens(yt))))

            it += 1.0
            all_err += err
            all_tp += tp
            all_fp += fp
            all_fn += fn

            all_pre = all_tp / (all_tp + all_fp + 1e-7)
            all_rec = all_tp / (all_tp + all_fn + 1e-7)
            all_f1 = 2 * all_pre * all_rec / (all_pre + all_rec + 1e-7)

            if it % FLAGS.print_every == 0:
                it_time = (time.time() - start_time) / it
                log("Iter {:.2f}%, err={}, pre={:.2f}%, rec={:.2f}%, f1={:.2f}%,"
                    " {:.2f} sec/it".format(it / total_it * 100, all_err / it,
                                            all_pre * 100, all_rec * 100,
                                            all_f1 * 100, it_time))

        all_pre = all_tp / (all_tp + all_fp + 1e-7)
        all_rec = all_tp / (all_tp + all_fn + 1e-7)
        all_f1 = 2 * all_pre * all_rec / (all_pre + all_rec + 1e-7)
        log("Final err={}, pre={:.2f}%, rec={:.2f}%, f1={:.2f}%".format(
            all_err / it, all_pre * 100, all_rec * 100, all_f1 * 100))
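
The core relaxation above, rows of logits softmaxed with a temperature and then mixed with the embedding table, can be illustrated in isolation. This toy sketch (assumed sizes and values, NumPy instead of TF) mirrors the prob_inputs/emb_inputs construction:

import numpy as np

# Toy illustration of the continuous relaxation above: a temperature softmax
# over the vocabulary, then a convex mixture of word embeddings in place of a
# hard embedding lookup.
vocab_size, emb_dim, seq_len, temp = 5, 3, 2, 0.1
rng = np.random.RandomState(0)
logits = rng.uniform(-0.1, 0.1, size=(seq_len, vocab_size))  # cf. logit_inputs
word_emb = rng.randn(vocab_size, emb_dim)                    # cf. embedding_table

probs = np.exp(logits / temp)
probs /= probs.sum(axis=-1, keepdims=True)  # tf.nn.softmax(logit_inputs / temp)
emb_inputs = probs @ word_emb               # tf.matmul(prob_inputs, word_emb)
preds = probs.argmax(axis=-1)               # hard readout, as in tf.argmax above
print(preds, emb_inputs.shape)              # [1 3] (2, 3)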
Example No. 8
def load_inversion_data():
    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)

    train_sents, _, test_sents, _, _, _ = load_bookcorpus_author(
        train_size=FLAGS.train_size,
        test_size=FLAGS.test_size,
        unlabeled_size=0,
        split_by_book=True,
        split_word=False,
        top_attr=800)
    # train_sents, test_sents = load_all_diagnosis(split_word=False)

    if FLAGS.cross_domain:
        train_sents = load_cross_domain_data(800000, split_word=False)

    def sents_to_examples(sents):
        examples = read_examples(sents, tokenization.convert_to_unicode)
        return convert_examples_to_features(examples=examples,
                                            seq_length=FLAGS.max_seq_length,
                                            tokenizer=tokenizer)

    input_ids, input_mask, input_type_ids, outputs = model_fn_builder(
        bert_config=bert_config,
        init_checkpoint=FLAGS.init_checkpoint,
        use_one_hot_embeddings=False)

    sess = tf.Session()
    sess.run([tf.global_variables_initializer(), tf.tables_initializer()])

    learn_mapping = FLAGS.high_layer_idx != FLAGS.low_layer_idx

    def encode_example(features):
        n_data = len(features[0])
        embs_low, embs_high = [], []
        pbar = tqdm.tqdm(total=n_data)
        for b_idx in iterate_minibatches_indices(n_data, 128):
            emb = sess.run(outputs,
                           feed_dict={
                               input_ids: features[0][b_idx],
                               input_mask: features[1][b_idx],
                               input_type_ids: features[2][b_idx]
                           })
            if learn_mapping:
                embs_low.append(emb[0])
                embs_high.append(emb[1])
                n_batch = len(emb[0])
            else:
                embs_low.append(emb)
                n_batch = len(emb)
            pbar.update(n_batch)
        pbar.close()

        if learn_mapping:
            return np.vstack(embs_low), np.vstack(embs_high)
        else:
            return np.vstack(embs_low)

    train_features = sents_to_examples(train_sents)
    train_x = encode_example(train_features)

    test_features = sents_to_examples(test_sents)
    test_x = encode_example(test_features)
    tf.keras.backend.clear_session()

    if learn_mapping:
        log('Training high to low mapping...')
        if FLAGS.mapper == 'linear':
            mapping = linear_mapping(train_x[1], train_x[0])
        elif FLAGS.mapper == 'mlp':
            mapping = mlp_mapping(train_x[1],
                                  train_x[0],
                                  epochs=30,
                                  activation=tf.tanh)
        elif FLAGS.mapper == 'gan':
            mapping = gan_mapping(train_x[1],
                                  train_x[0],
                                  disc_iters=5,
                                  batch_size=64,
                                  gamma=1.0,
                                  epoch=100,
                                  activation=tf.tanh)
        else:
            raise ValueError(FLAGS.mapper)
        test_x = mapping(test_x[1])

    return train_x, train_features, test_x, test_features
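
To connect this loader back to the two attacks above: the feature tuples it returns double as labels, which is why Examples 6 and 7 both index y[0] before calling filter_labels. A short sketch of the call site:

# How Examples 6 and 7 unpack the return values above. Each *_features value
# is the (input_ids, input_mask, input_type_ids) tuple produced by
# sents_to_examples, so element [0] holds the token ids that the attacks
# filter and use as labels.
train_x, train_y, test_x, test_y = load_inversion_data()  # Example 6's view
token_ids = test_y[0]  # what filter_labels() receives in both attacks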