Example #1
 def test_tokenizer_val(self):
   testdata = os.path.join(
       FLAGS.test_srcdir, ".//"
       "small_gold_annotation.jsonl")
   train_examples = preproc.read_tydi_examples(
       testdata,
       is_training=True,
       max_passages=45,
       max_position=45,
       fail_on_invalid=False,
       open_fn=tf_io.gopen)
   dev_examples = preproc.read_tydi_examples(
       testdata,
       is_training=False,
       max_passages=45,
       max_position=45,
       fail_on_invalid=False,
       open_fn=tf_io.gopen)
   vocab_file = self._get_vocab_file()
   tokenizer = tokenization.TyDiTokenizer(vocab_file=vocab_file)
   for tr_ex, dev_ex in zip(train_examples, dev_examples):
     train_feats = preproc.convert_single_example(
         tr_ex,
         tokenizer,
         is_training=True,
         max_question_length=64,
         max_seq_length=512,
         doc_stride=128,
         include_unknowns=1.0,
         errors=[],
         debug_info={})
     dev_feats = preproc.convert_single_example(
         dev_ex,
         tokenizer,
         is_training=False,
         max_question_length=64,
         max_seq_length=512,
         doc_stride=128,
         include_unknowns=1.0,
         errors=[],
         debug_info={})
     for train_f, dev_f in zip(train_feats, dev_feats):
       if train_f.answer_text:
         st_ = train_f.start_position
         ed_ = train_f.end_position
         st_offset = dev_f.wp_start_offset[st_]
         end_offset = dev_f.wp_end_offset[ed_]
         self.assertGreaterEqual(end_offset, st_offset)
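
The invariant checked above can be stated on its own: a train-side answer span, looked up in the dev-side wordpiece offset tables, must map to well-ordered byte offsets. A minimal sketch of that lookup, assuming `wp_start_offset`/`wp_end_offset` are parallel lists of byte offsets indexed by wordpiece position, with -1 for unmapped wordpieces (names are illustrative):

def answer_byte_span(start_wp, end_wp, wp_start_offset, wp_end_offset):
    """Maps a wordpiece span to a (start_byte, end_byte) pair, or None."""
    start_byte = wp_start_offset[start_wp]
    end_byte = wp_end_offset[end_wp]
    if start_byte < 0 or end_byte < 0:
        return None  # The span falls on unmapped wordpieces.
    assert end_byte >= start_byte, "Span byte offsets out of order."
    return start_byte, end_byte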
Example #2
 def test_tokenizer(self):
   testdata = os.path.join(
       FLAGS.test_srcdir, ".//"
       "small_gold_annotation.jsonl")
   test_examples = preproc.read_tydi_examples(
       testdata,
       is_training=True,
       max_passages=45,
       max_position=45,
       fail_on_invalid=False,
       open_fn=tf_io.gopen)
   vocab_file = self._get_vocab_file()
   tokenizer = tokenization.TyDiTokenizer(vocab_file=vocab_file)
   for tydi_example in test_examples:
     features = preproc.convert_single_example(
         tydi_example,
         tokenizer,
         is_training=True,
         max_question_length=64,
         max_seq_length=512,
         doc_stride=128,
         include_unknowns=1.0,
         errors=[],
         debug_info={})
      self.assertEqual(len({f.language_id for f in features}), 1)
     for feature in features:
       if feature.end_position <= 0:
         self.assertEqual(feature.start_position, 0)
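
The two invariants this test exercises can be packaged as a small standalone checker. A sketch, assuming feature objects expose `language_id`, `start_position`, and `end_position` as above (the helper name is hypothetical):

def check_example_features(features):
    """Checks the invariants asserted in test_tokenizer."""
    # All features split from one example share a single language id.
    assert len({f.language_id for f in features}) == 1
    for f in features:
        # Null answer spans are anchored at position 0 (conventionally the
        # [CLS] token in BERT-style inputs).
        if f.end_position <= 0:
            assert f.start_position == 0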
Example #3
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    bert_config = bert_modeling.BertConfig.from_json_file(
        FLAGS.bert_config_file)
    validate_flags_or_throw(bert_config)
    tf.gfile.MakeDirs(FLAGS.output_dir)

    tpu_cluster_resolver = None
    if FLAGS.use_tpu and FLAGS.tpu_name:
        tpu_cluster_resolver = tf_contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    is_per_host = tf_contrib.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf_contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=FLAGS.master,
        model_dir=FLAGS.output_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        tpu_config=tf_contrib.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            per_host_input_for_training=is_per_host))

    num_train_steps = None
    num_warmup_steps = None
    if FLAGS.do_train:
        with tf.gfile.Open(FLAGS.record_count_file, "r") as f:
            num_train_features = int(f.read().strip())
        num_train_steps = int(num_train_features / FLAGS.train_batch_size *
                              FLAGS.num_train_epochs)
        tf.logging.info("record_count_file: %s", FLAGS.record_count_file)
        tf.logging.info("num_records (features): %d", num_train_features)
        tf.logging.info("num_train_epochs: %d", FLAGS.num_train_epochs)
        tf.logging.info("train_batch_size: %d", FLAGS.train_batch_size)
        tf.logging.info("num_train_steps: %d", num_train_steps)

        num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    model_fn = tydi_modeling.model_fn_builder(
        bert_config=bert_config,
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=FLAGS.learning_rate,
        num_train_steps=num_train_steps,
        num_warmup_steps=num_warmup_steps,
        use_tpu=FLAGS.use_tpu,
        use_one_hot_embeddings=FLAGS.use_tpu)

    # If TPU is not available, this falls back to normal Estimator on CPU or GPU.
    estimator = tf_contrib.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        predict_batch_size=FLAGS.predict_batch_size)

    if FLAGS.do_train:
        tf.logging.info("***** Running training on precomputed features *****")
        tf.logging.info("  Num split examples = %d", num_train_features)
        tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
        tf.logging.info("  Num steps = %d", num_train_steps)
        train_filenames = tf.gfile.Glob(FLAGS.train_records_file)
        train_input_fn = tf_io.input_fn_builder(
            input_file=train_filenames,
            seq_length=FLAGS.max_seq_length,
            is_training=True,
            drop_remainder=True)

        estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
    if FLAGS.do_predict:
        if not FLAGS.output_prediction_file:
            raise ValueError(
                "--output_prediction_file must be defined in predict mode.")
        if not FLAGS.precomputed_predict_file:
            # `eval_tydi_examples` is a lazy generator.
            eval_tydi_examples = preproc.read_tydi_examples(
                input_file=FLAGS.predict_file,
                is_training=False,
                max_passages=FLAGS.max_passages,
                max_position=FLAGS.max_position,
                fail_on_invalid=FLAGS.fail_on_invalid,
                open_fn=tf_io.gopen)
            eval_writer = tf_io.FeatureWriter(
                filename=os.path.join(FLAGS.output_dir, "eval.tf_record"),
                is_training=False)

            eval_features = []

            def append_feature(feature):
                eval_features.append(feature)
                eval_writer.process_feature(feature)

            tf.logging.info("**** Converting examples.")
            num_spans_to_ids, num_examples = preproc.convert_examples_to_features(
                tydi_examples=eval_tydi_examples,
                vocab_file=FLAGS.vocab_file,
                is_training=False,
                max_question_length=FLAGS.max_question_length,
                max_seq_length=FLAGS.max_seq_length,
                doc_stride=FLAGS.doc_stride,
                include_unknowns=FLAGS.include_unknowns,
                output_fn=append_feature)
            eval_writer.close()
            eval_filename = eval_writer.filename
            tf.logging.info("**** Converting examples finished.")

            for spans, ids in num_spans_to_ids.items():
                tf.logging.info("  Num split into %d = %d", spans, len(ids))
            tf.logging.info("***** Running predictions *****")
            tf.logging.info("  Num orig examples = %d", num_examples)
            eval_filenames = [eval_filename]
        else:
            eval_filenames = tf.gfile.Glob(FLAGS.precomputed_predict_file)

        predict_input_fn = tf_io.input_fn_builder(
            input_file=eval_filenames,
            seq_length=FLAGS.max_seq_length,
            is_training=False,
            drop_remainder=False)

        # If running eval on the TPU, you will need to specify the number of steps.
        all_results = []
        for result in estimator.predict(predict_input_fn,
                                        yield_single_examples=True):
            if len(all_results) % 1000 == 0:
                tf.logging.info("Processing example: %d", len(all_results))
            unique_id = int(result["unique_ids"])
            start_logits = [float(x) for x in result["start_logits"].flat]
            end_logits = [float(x) for x in result["end_logits"].flat]
            answer_type_logits = [
                float(x) for x in result["answer_type_logits"].flat
            ]
            all_results.append(
                tydi_modeling.RawResult(unique_id=unique_id,
                                        start_logits=start_logits,
                                        end_logits=end_logits,
                                        answer_type_logits=answer_type_logits))

        candidates_dict = read_candidates(FLAGS.predict_file)

        tf.logging.info("Loaded candidates examples: %d", len(candidates_dict))
        eval_features = []
        tf.logging.info("Number of eval file shards: %d", len(eval_filenames))
        for eval_filename in eval_filenames:
            eval_features.extend([
                tf.train.Example.FromString(r)
                for r in tf.python_io.tf_record_iterator(eval_filename)
            ])
        tf.logging.info("Loaded eval features: %d", len(eval_features))
        tf.logging.info("Loaded results: %d", len(all_results))

        tydi_pred_dict = postproc.compute_pred_dict(
            candidates_dict,
            eval_features, [r._asdict() for r in all_results],
            candidate_beam=FLAGS.candidate_beam)
        predictions_json = {"predictions": list(tydi_pred_dict.values())}
        with tf.gfile.Open(FLAGS.output_prediction_file, "w") as f:
            json.dump(predictions_json, f, indent=4)
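
This version writes all predictions as one JSON object under a "predictions" key. A sketch of reading the output back (the path is illustrative):

import json

with open("predictions.json") as f:
    predictions = json.load(f)["predictions"]
print("Loaded %d predictions" % len(predictions))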
Example #4
def main(_):
    logging.set_verbosity(logging.INFO)
    bert_config = bert_modeling.BertConfig.from_json_file(
        FLAGS.bert_config_file)
    validate_flags_or_throw(bert_config)
    tf.gfile.MakeDirs(FLAGS.output_dir)

    tpu_cluster_resolver = None
    if FLAGS.use_tpu and FLAGS.tpu_name:
        tpu_cluster_resolver = tf_contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    is_per_host = tf_contrib.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf_contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=FLAGS.master,
        model_dir=FLAGS.output_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        tpu_config=tf_contrib.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            per_host_input_for_training=is_per_host))

    num_train_steps = None
    num_warmup_steps = None
    if FLAGS.do_train:
        with tf.gfile.Open(FLAGS.record_count_file, "r") as f:
            num_train_features = int(f.read().strip())
        num_train_steps = int(num_train_features / FLAGS.train_batch_size *
                              FLAGS.num_train_epochs)
        logging.info("record_count_file: %s", FLAGS.record_count_file)
        logging.info("num_records (features): %d", num_train_features)
        logging.info("num_train_epochs: %d", FLAGS.num_train_epochs)
        logging.info("train_batch_size: %d", FLAGS.train_batch_size)
        logging.info("num_train_steps: %d", num_train_steps)

        num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    model_fn = tydi_modeling.model_fn_builder(
        bert_config=bert_config,
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=FLAGS.learning_rate,
        num_train_steps=num_train_steps,
        num_warmup_steps=num_warmup_steps,
        use_tpu=FLAGS.use_tpu,
        use_one_hot_embeddings=FLAGS.use_tpu)

    # If TPU is not available, this falls back to normal Estimator on CPU or GPU.
    estimator = tf_contrib.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        predict_batch_size=FLAGS.predict_batch_size)

    if FLAGS.do_train:
        logging.info("Running training on precomputed features")
        logging.info("  Num split examples = %d", num_train_features)
        logging.info("  Batch size = %d", FLAGS.train_batch_size)
        logging.info("  Num steps = %d", num_train_steps)
        train_filenames = tf.gfile.Glob(FLAGS.train_records_file)
        train_input_fn = tf_io.input_fn_builder(
            input_file=train_filenames,
            seq_length=FLAGS.max_seq_length,
            is_training=True,
            drop_remainder=True)
        estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)

    if FLAGS.do_predict:
        if not FLAGS.precomputed_predict_file:
            predict_examples_iter = preproc.read_tydi_examples(
                input_file=FLAGS.predict_file,
                is_training=False,
                max_passages=FLAGS.max_passages,
                max_position=FLAGS.max_position,
                fail_on_invalid=FLAGS.fail_on_invalid,
                open_fn=tf_io.gopen)
            shards_iter = write_tf_feature_files(predict_examples_iter)
        else:
            # Uses zeros for example and feature counts since they're unknown, and
            # we only use them for logging anyway.
            shards_iter = enumerate(
                ((f, 0, 0)
                 for f in tf.gfile.Glob(FLAGS.precomputed_predict_file)), 1)

        # Accumulates all of the prediction results to be written to the output.
        full_tydi_pred_dict = {}
        total_num_examples = 0
        for shard_num, (shard_filename, shard_num_examples,
                        shard_num_features) in shards_iter:
            total_num_examples += shard_num_examples
            logging.info(
                "Shard %d: Running prediction for %s; %d examples, %d features.",
                shard_num, shard_filename, shard_num_examples,
                shard_num_features)

            # Runs the model on the shard and stores the individual results.
            # If running predict on TPU, you will need to specify the number of steps.
            predict_input_fn = tf_io.input_fn_builder(
                input_file=[shard_filename],
                seq_length=FLAGS.max_seq_length,
                is_training=False,
                drop_remainder=False)
            all_results = []
            for result in estimator.predict(predict_input_fn,
                                            yield_single_examples=True):
                if len(all_results) % 10000 == 0:
                    logging.info("Shard %d: Predicting for feature %d/%s",
                                 shard_num, len(all_results),
                                 shard_num_features)
                unique_id = int(result["unique_ids"])
                start_logits = [float(x) for x in result["start_logits"].flat]
                end_logits = [float(x) for x in result["end_logits"].flat]
                answer_type_logits = [
                    float(x) for x in result["answer_type_logits"].flat
                ]
                all_results.append(
                    tydi_modeling.RawResult(
                        unique_id=unique_id,
                        start_logits=start_logits,
                        end_logits=end_logits,
                        answer_type_logits=answer_type_logits))

            # Reads the prediction candidates from the (entire) prediction input file.
            candidates_dict = read_candidates(FLAGS.predict_file)
            predict_features = [
                tf.train.Example.FromString(r)
                for r in tf.python_io.tf_record_iterator(shard_filename)
            ]
            logging.info("Shard %d: Post-processing predictions.", shard_num)
            logging.info(
                "  Num candidate examples loaded (includes all shards): %d",
                len(candidates_dict))
            logging.info("  Num candidate features loaded: %d",
                         len(predict_features))
            logging.info("  Num prediction result features: %d",
                         len(all_results))
            logging.info("  Num shard features: %d", shard_num_features)

            tydi_pred_dict = postproc.compute_pred_dict(
                candidates_dict,
                predict_features, [r._asdict() for r in all_results],
                candidate_beam=FLAGS.candidate_beam)

            logging.info("Shard %d: Post-processed predictions.", shard_num)
            logging.info("  Num shard examples: %d", shard_num_examples)
            logging.info("  Num post-processed results: %d",
                         len(tydi_pred_dict))
            if shard_num_examples != len(tydi_pred_dict):
                logging.warning("  Num missing predictions: %d",
                                shard_num_examples - len(tydi_pred_dict))
            for key, value in tydi_pred_dict.items():
                if key in full_tydi_pred_dict:
                    logging.warning(
                        "ERROR: '%s' already in full_tydi_pred_dict!", key)
                full_tydi_pred_dict[key] = value

        logging.info("Prediction finished for all shards.")
        logging.info("  Total input examples: %d", total_num_examples)
        logging.info("  Total output predictions: %d",
                     len(full_tydi_pred_dict))

        with tf.gfile.Open(FLAGS.output_prediction_file, "w") as output_file:
            for prediction in full_tydi_pred_dict.values():
                output_file.write((json.dumps(prediction) + "\n").encode())
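
Unlike Example #3, this version writes JSON Lines: one prediction object per line, encoded to bytes. A read-back sketch (path illustrative):

import json

predictions = []
with open("predictions.jsonl", "rb") as f:
    for line in f:
        predictions.append(json.loads(line))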
Example #5
  def test_offset_wp_mapping(self):
    """Test the mapping from wordpiece to plaintext offsets."""
    testdata = os.path.join(
        FLAGS.test_srcdir, ".//"
        "small_gold_annotation.jsonl")
    vocab_file = self._get_vocab_file()
    examples = preproc.read_tydi_examples(
        testdata,
        is_training=False,
        max_passages=45,
        max_position=45,
        fail_on_invalid=False,
        open_fn=tf_io.gopen)
    tokenizer = tokenization.TyDiTokenizer(vocab_file=vocab_file)
    for tydi_example in examples:
      wordpieces, start_offsets, end_offsets, offset_to_wp = (
          tokenizer.tokenize_with_offsets(tydi_example.contexts))

      # Check invariants.
      for i in start_offsets:
        if i > 0:
          self.assertLess(
              i, len(tydi_example.context_to_plaintext_offset),
              "Expected start offset {} to be in `context_to_plaintext_offset` "
              "byte_len(contexts)={} Context@{}='{}' Have={}".format(
                  i, data.byte_len(tydi_example.contexts), i,
                  data.byte_slice(
                      tydi_example.contexts, i, i + 100,
                      errors="ignore").encode("utf8"),
                  tydi_example.context_to_plaintext_offset))
      for i in end_offsets:
        if i > 0:
          self.assertLess(
              i, len(tydi_example.context_to_plaintext_offset),
              "Expected end offset {} to be in `context_to_plaintext_offset` "
              "byte_len(contexts)={} Have={}".format(
                  i, data.byte_len(tydi_example.contexts),
                  tydi_example.context_to_plaintext_offset))

      wp_start_offsets, wp_end_offsets = (
          preproc.create_mapping(start_offsets, end_offsets,
                                 tydi_example.context_to_plaintext_offset))
      wp_count = 0
      for wp_s, wp_e in zip(wp_start_offsets, wp_end_offsets):
        if wp_s >= 0 and wp_e >= 0 and wp_count < 20:
          wp_txt = wordpieces[wp_count]
          if isinstance(wp_txt, str):
            if "##" not in wp_txt and wp_txt != "[UNK]":
              self.assertEqual(tydi_example.plaintext[wp_s:wp_e + 1], wp_txt)
        wp_count += 1

      for offset in offset_to_wp:
        self.assertLess(offset, data.byte_len(tydi_example.contexts))
        self.assertGreaterEqual(offset, 0)
        matching_wp = offset_to_wp[offset]
        if matching_wp == -1:
          continue
        if wp_end_offsets[matching_wp] == -1:
          continue
        if wp_start_offsets[matching_wp] == -1:
          continue
        self.assertGreaterEqual(wp_end_offsets[matching_wp],
                                wp_start_offsets[matching_wp])
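
The offsets compared throughout this test are byte offsets into the UTF-8 encoding of the text, which is why `data.byte_len` and `data.byte_slice` appear instead of plain `len` and string slicing. A pure-Python sketch of what such helpers are assumed to do:

def byte_len(text):
    return len(text.encode("utf-8"))

def byte_slice(text, start, end, errors="replace"):
    # Slices by UTF-8 byte offsets rather than character offsets.
    return text.encode("utf-8")[start:end].decode("utf-8", errors=errors)

assert byte_len("héllo") == 6  # "é" occupies two bytes in UTF-8.
assert byte_slice("héllo", 0, 3) == "hé"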
Example #6
def main(_):

    logging.set_verbosity(logging.INFO)
    bert_config = bert_modeling.BertConfig.from_json_file(
        FLAGS.bert_config_file)
    validate_flags_or_throw(bert_config)
    tf.gfile.MakeDirs(FLAGS.output_dir)

    tpu_cluster_resolver = None
    if FLAGS.use_tpu and FLAGS.tpu_name:
        tpu_cluster_resolver = tf_contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    is_per_host = tf_contrib.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf_contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=FLAGS.master,
        model_dir=FLAGS.output_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        tpu_config=tf_contrib.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            per_host_input_for_training=is_per_host))

    num_train_steps = None
    num_warmup_steps = None
    '''if FLAGS.do_train:
    with tf.gfile.Open(FLAGS.record_count_file, "r") as f:
      num_train_features = int(f.read().strip())
    num_train_steps = int(num_train_features / FLAGS.train_batch_size *
                          FLAGS.num_train_epochs)
    logging.info("record_count_file: %s", FLAGS.record_count_file)
    logging.info("num_records (features): %d", num_train_features)
    logging.info("num_train_epochs: %d", FLAGS.num_train_epochs)
    logging.info("train_batch_size: %d", FLAGS.train_batch_size)
    logging.info("num_train_steps: %d", num_train_steps)

    num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

  model_fn = tydi_modeling.model_fn_builder(
      bert_config=bert_config,
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=num_train_steps,
      num_warmup_steps=num_warmup_steps,
      use_tpu=FLAGS.use_tpu,
      use_one_hot_embeddings=FLAGS.use_tpu)

  # If TPU is not available, this falls back to normal Estimator on CPU or GPU.
  estimator = tf_contrib.tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=model_fn,
      config=run_config,
      train_batch_size=FLAGS.train_batch_size,
      predict_batch_size=FLAGS.predict_batch_size)'''

    if FLAGS.do_train:

        with tf.gfile.Open(FLAGS.record_count_file, "r") as f:
            num_train_features = int(f.read().strip())
        num_train_steps = int(num_train_features / FLAGS.train_batch_size *
                              FLAGS.num_train_epochs)
        logging.info("record_count_file: %s", FLAGS.record_count_file)
        logging.info("num_records (features): %d", num_train_features)
        logging.info("num_train_epochs: %d", FLAGS.num_train_epochs)
        logging.info("train_batch_size: %d", FLAGS.train_batch_size)
        logging.info("num_train_steps: %d", num_train_steps)

        num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

        logging.info("Running training on precomputed features")
        logging.info("  Num split examples = %d", num_train_features)
        logging.info("  Batch size = %d", FLAGS.train_batch_size)
        logging.info("  Num steps = %d", num_train_steps)
        train_filenames = tf.gfile.Glob(FLAGS.train_records_file)

        model_fn = tydi_modeling.model_fn_builder(
            bert_config=bert_config,
            init_checkpoint=FLAGS.init_checkpoint,
            learning_rate=FLAGS.learning_rate,
            num_train_steps=num_train_steps,
            num_warmup_steps=num_warmup_steps,
            use_tpu=FLAGS.use_tpu,
            use_one_hot_embeddings=FLAGS.use_tpu)

        def input_fn_builder(input_file, seq_length, is_training,
                             drop_remainder):
            """Creates an `input_fn` closure to be passed to TPUEstimator."""

            # This needs to be kept in sync with `FeatureWriter`.
            name_to_features = {
                "language_id": tf.FixedLenFeature([], tf.int64),
                "unique_ids": tf.FixedLenFeature([], tf.int64),
                "input_ids": tf.FixedLenFeature([seq_length], tf.int64),
                "input_mask": tf.FixedLenFeature([seq_length], tf.int64),
                "segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
            }

            if is_training:
                name_to_features["start_positions"] = tf.FixedLenFeature(
                    [], tf.int64)
                name_to_features["end_positions"] = tf.FixedLenFeature(
                    [], tf.int64)
                name_to_features["answer_types"] = tf.FixedLenFeature([],
                                                                      tf.int64)

            def _decode_record(record, name_to_features):
                example = tf.parse_single_example(record, name_to_features)

                # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
                # So cast all int64 to int32.
                for name in list(example.keys()):
                    t = example[name]
                    if t.dtype == tf.int64:
                        t = tf.to_int32(t)
                        example[name] = t

                return example

            # For training, we want a lot of parallel reading and shuffling.
            # For eval, we want no shuffling and parallel reading doesn't matter.
            d = tf.data.TFRecordDataset(input_file)
            if is_training:
                d = d.shuffle(buffer_size=100)
            d = d.map(lambda record: _decode_record(record, name_to_features))
            return d

        def split_train_dev(dataset, train_port, num_dev):
            """Splits `dataset` into a train portion and `num_dev` dev shards."""
            train_size = int(DATASET_SIZE * train_port)

            full_dataset = dataset.shuffle(100)
            train_dataset = full_dataset.take(train_size)
            test_dataset = full_dataset.skip(train_size)

            shard_devs = []
            for i in range(num_dev):
                shard_devs.append(
                    test_dataset.shard(num_shards=num_dev, index=i))
            return train_dataset, shard_devs

        def count_dataset(dataset):
            """Counts the elements of `dataset` by iterating over it once.

            Assumes eager execution, as elsewhere in this function.
            """
            cnt = 0
            for _ in dataset.repeat(1):
                cnt += 1
            return cnt

        def split_langs(dataset):
            def dataset_fn(ds, i):
                return ds.filter(lambda x: tf.equal(x['language_id'], i))

            data_set_lst = []
            for i in range(11):
                dataset_filter_lang = dataset.apply(lambda x: dataset_fn(x, i))
                data_set_lst.append(dataset_filter_lang)
            return data_set_lst

        DATASET_SIZE = num_train_features

        NUM_LANGS = 11

        tf_dataset = input_fn_builder(FLAGS.train_records_file, 512, True,
                                      False)

        train_set, dev_shards = split_train_dev(tf_dataset, FLAGS.train_size,
                                                FLAGS.num_dev_sets)
        total_num_train_samples = FLAGS.train_size * DATASET_SIZE
        total_num_dev_samples = DATASET_SIZE - total_num_train_samples

        logging.info("Nums of examples in train dataset = %d",
                     total_num_train_samples)
        logging.info("Total numbers of examples in dev set = %d",
                     total_num_dev_samples)

        train_set_langs = split_langs(train_set)
        dev_set_langs = []

        for dev_shard in dev_shards:
            dev_set_langs.append(split_langs(dev_shard))

        # `train_set_langs` is a 1-D list and `dev_set_langs` is a 2-D list;
        # the corresponding language ids can be found in the data file.

        def sample_lang_id(lang_freq):
            # `choices` is random.choices (requires `from random import
            # choices`); returns a one-element list of sampled language ids.
            return choices(list(range(NUM_LANGS)), lang_freq)

        # Count the number of training examples in each language.

        lang_sample_dist = []
        for lang in train_set_langs:
            lang_cnt = count_dataset(lang)
            logging.info("Language example count: %d", lang_cnt)
            lang_sample_dist.append(lang_cnt / total_num_train_samples)

        train_samplers = list(
            map(lambda x: x.repeat().batch(1).make_one_shot_iterator(),
                train_set_langs))

        dev_samplers = []
        for dev_set in dev_set_langs:
            dev_samplers.append(
                list(
                    map(
                        lambda x: iter(x.repeat().batch(1).
                                       make_one_shot_iterator()), dev_set)))

        global_step = tf.train.get_or_create_global_step()
        learning_rate = tf.constant(value=FLAGS.learning_rate,
                                    shape=[],
                                    dtype=tf.float32)
        learning_rate = tf.train.polynomial_decay(learning_rate,
                                                  global_step,
                                                  num_train_steps,
                                                  end_learning_rate=0.0,
                                                  power=1.0,
                                                  cycle=False)

        #uncomment this to enable warm-up steps
        '''if num_warmup_steps:
        global_steps_int = tf.cast(global_step, tf.int32)
        warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)

        global_steps_float = tf.cast(global_steps_int, tf.float32)
        warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)

        warmup_percent_done = global_steps_float / warmup_steps_float
        warmup_learning_rate = FLAGS.learning_rate * warmup_percent_done

        is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
        learning_rate = (
                (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate)'''

        optimizer = opt.AdamWeightDecayOptimizer(
            learning_rate=learning_rate,
            weight_decay_rate=0.01,
            beta_1=0.9,
            beta_2=0.999,
            epsilon=1e-6,
            exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
        if FLAGS.use_tpu:
            optimizer = tf_contrib.tpu.CrossShardOptimizer(optimizer)

        # Do MultiDDS training: initialize the sampling distribution
        # parameters phi (one logit per language).
        phi = tf.get_variable(
            "phi", [11],
            initializer=tf.truncated_normal_initializer(stddev=0.02))

        opt_scorer = tf.train.AdamOptimizer(learning_rate=0.001,
                                            beta1=0.9,
                                            beta2=0.999,
                                            epsilon=1e-08,
                                            use_locking=False,
                                            name='Adam')

        while global_step < num_train_steps:
            if not tf.equal(global_step, 0):
                lang_sample_dist = list(tf.nn.softmax(phi).numpy())

            # Load training data according to phi.
            logging.info('Sampling from train data')
            data_lst = []
            while len(data_lst) < FLAGS.M:
                # Choose a language to sample from.
                cur_lang = sample_lang_id(lang_sample_dist)
                data_lst.append(train_samplers[cur_lang[0]].get_next())

            logging.info('Train mBERT for multiple steps')
            for data in data_lst:
                with tf.GradientTape() as tape:
                    # `_` stands in for the unused labels/params arguments of
                    # the modified model_fn signature assumed here.
                    tvars, loss = model_fn(data, _,
                                           tf.estimator.ModeKeys.TRAIN, _,
                                           global_step)

                grads = tape.gradient(loss, tvars)
                # Named `_unused_norm` so the `_` placeholder above is not
                # clobbered by the clipping result.
                grads, _unused_norm = tf.clip_by_global_norm(grads,
                                                             clip_norm=1.0)
                optimizer.apply_gradients(zip(grads, tvars),
                                          global_step=global_step)

            logging.info('Estimate the effect of each language')
            rewards = []
            for i in range(NUM_LANGS):
                gradient_dev = None
                gradient_train = None
                # Some languages might not have samples.
                try:
                    train_test = train_samplers[i].get_next()
                    with tf.GradientTape() as tape:
                        tvars, loss = model_fn(train_test, _,
                                               tf.estimator.ModeKeys.TRAIN, _,
                                               global_step)

                    grads = tape.gradient(loss, tvars)
                    # Flatten the per-variable gradients into a single vector
                    # so the cosine similarity below is well defined.
                    gradient_train = tf.concat(
                        [tf.reshape(g, [-1]) for g in grads], axis=0)
                    # Not sure whether to add this line or not
                    # TODO: modify me to allow functions
                    grads, _unused_norm = tf.clip_by_global_norm(
                        grads, clip_norm=1.0)
                    optimizer.apply_gradients(zip(grads, tvars),
                                              global_step=global_step)

                    logging.info("Testing effect on other languages")

                    for k in range(len(dev_samplers)):
                        for j in range(NUM_LANGS):
                            try:
                                dev_data = dev_samplers[k][j].get_next()
                                with tf.GradientTape() as tape:
                                    tvars, loss = model_fn(
                                        dev_data, _,
                                        tf.estimator.ModeKeys.TRAIN,
                                        _, global_step)
                                grads = tape.gradient(loss, tvars)
                                flat = tf.concat(
                                    [tf.reshape(g, [-1]) for g in grads],
                                    axis=0)
                                gradient_dev = (flat if gradient_dev is None
                                                else gradient_dev + flat)
                            except tf.errors.OutOfRangeError:
                                logging.info(
                                    "Language %d does not exist in dev set %d",
                                    j, k)
                except tf.errors.OutOfRangeError:
                    logging.warning("No data for train language %d", i)

                # Append this language's reward: the cosine similarity between
                # its training gradient and the aggregated dev gradient.
                if gradient_train is None or gradient_dev is None:
                    rewards.append(tf.constant(0.0))
                    continue
                normalize_a = tf.nn.l2_normalize(gradient_dev, 0)
                normalize_b = tf.nn.l2_normalize(gradient_train, 0)
                cos_similarity = tf.reduce_sum(
                    tf.multiply(normalize_a, normalize_b))
                rewards.append(cos_similarity)

            logging.info("Optimize phi!")
            grad_phi = 0
            for i in range(NUM_LANGS):
                log_i = tf.log(tf.nn.softmax(phi))[i]
                with tf.GradientTape() as tape:
                    grads = tape.gradient(log_i, phi)
                grad_phi += grads * rewards[i]
            opt_scorer.apply_gradient(zip(grad_phi, phi),
                                      global_step=global_step)

            new_global_step = global_step + 1
            global_step.assign(new_global_step)

        #estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)

    if FLAGS.do_predict:
        # NOTE: this branch calls `estimator.predict` below, but the
        # TPUEstimator construction above is commented out, so `model_fn` and
        # `estimator` must be rebuilt before prediction can run.
        if not FLAGS.precomputed_predict_file:
            predict_examples_iter = preproc.read_tydi_examples(
                input_file=FLAGS.predict_file,
                is_training=False,
                max_passages=FLAGS.max_passages,
                max_position=FLAGS.max_position,
                fail_on_invalid=FLAGS.fail_on_invalid,
                open_fn=tf_io.gopen)
            shards_iter = write_tf_feature_files(predict_examples_iter)
        else:
            # Uses zeros for example and feature counts since they're unknown, and
            # we only use them for logging anyway.
            shards_iter = enumerate(
                ((f, 0, 0)
                 for f in tf.gfile.Glob(FLAGS.precomputed_predict_file)), 1)

        # Accumulates all of the prediction results to be written to the output.
        full_tydi_pred_dict = {}
        total_num_examples = 0
        for shard_num, (shard_filename, shard_num_examples,
                        shard_num_features) in shards_iter:
            total_num_examples += shard_num_examples
            logging.info(
                "Shard %d: Running prediction for %s; %d examples, %d features.",
                shard_num, shard_filename, shard_num_examples,
                shard_num_features)

            # Runs the model on the shard and stores the individual results.
            # If running predict on TPU, you will need to specify the number of steps.
            predict_input_fn = tf_io.input_fn_builder(
                input_file=[shard_filename],
                seq_length=FLAGS.max_seq_length,
                is_training=False,
                drop_remainder=False)
            all_results = []
            for result in estimator.predict(predict_input_fn,
                                            yield_single_examples=True):
                if len(all_results) % 10000 == 0:
                    logging.info("Shard %d: Predicting for feature %d/%s",
                                 shard_num, len(all_results),
                                 shard_num_features)
                unique_id = int(result["unique_ids"])
                start_logits = [float(x) for x in result["start_logits"].flat]
                end_logits = [float(x) for x in result["end_logits"].flat]
                answer_type_logits = [
                    float(x) for x in result["answer_type_logits"].flat
                ]
                all_results.append(
                    tydi_modeling.RawResult(
                        unique_id=unique_id,
                        start_logits=start_logits,
                        end_logits=end_logits,
                        answer_type_logits=answer_type_logits))

            # Reads the prediction candidates from the (entire) prediction input file.
            candidates_dict = read_candidates(FLAGS.predict_file)
            predict_features = [
                tf.train.Example.FromString(r)
                for r in tf.python_io.tf_record_iterator(shard_filename)
            ]
            logging.info("Shard %d: Post-processing predictions.", shard_num)
            logging.info(
                "  Num candidate examples loaded (includes all shards): %d",
                len(candidates_dict))
            logging.info("  Num candidate features loaded: %d",
                         len(predict_features))
            logging.info("  Num prediction result features: %d",
                         len(all_results))
            logging.info("  Num shard features: %d", shard_num_features)

            tydi_pred_dict = postproc.compute_pred_dict(
                candidates_dict,
                predict_features, [r._asdict() for r in all_results],
                candidate_beam=FLAGS.candidate_beam)

            logging.info("Shard %d: Post-processed predictions.", shard_num)
            logging.info("  Num shard examples: %d", shard_num_examples)
            logging.info("  Num post-processed results: %d",
                         len(tydi_pred_dict))
            if shard_num_examples != len(tydi_pred_dict):
                logging.warning("  Num missing predictions: %d",
                                shard_num_examples - len(tydi_pred_dict))
            for key, value in tydi_pred_dict.items():
                if key in full_tydi_pred_dict:
                    logging.warning(
                        "ERROR: '%s' already in full_tydi_pred_dict!", key)
                full_tydi_pred_dict[key] = value

        logging.info("Prediction finished for all shards.")
        logging.info("  Total input examples: %d", total_num_examples)
        logging.info("  Total output predictions: %d",
                     len(full_tydi_pred_dict))

        with tf.gfile.Open(FLAGS.output_prediction_file, "w") as output_file:
            for prediction in full_tydi_pred_dict.values():
                output_file.write((json.dumps(prediction) + "\n").encode())
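
The reward that drives the phi update in this example is the cosine similarity between a language's training gradient and the aggregated dev gradient (the MultiDDS signal). A framework-free sketch of that computation (function name and numpy usage are illustrative):

import numpy as np

def grad_cosine_reward(train_grads, dev_grads):
    """Cosine similarity between two lists of gradient arrays."""
    a = np.concatenate([g.ravel() for g in train_grads])
    b = np.concatenate([g.ravel() for g in dev_grads])
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    return float(a.dot(b) / denom) if denom else 0.0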