Пример #1
0
def main(_):
    """Create training instances from the input files and write them out.

    Input paths come from globbing every comma-separated pattern in
    `FLAGS.input_file`; output paths are the comma-separated entries of
    `FLAGS.output_file`.

    Args:
      _: Unused positional argument.
    """
    tf.logging.set_verbosity(tf.logging.INFO)

    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

    # Expand each comma-separated glob pattern into concrete file paths.
    source_paths = [
        matched_path
        for pattern in FLAGS.input_file.split(",")
        for matched_path in tf.gfile.Glob(pattern)
    ]

    tf.logging.info("*** Reading from input files ***")
    for source_path in source_paths:
        tf.logging.info("  %s", source_path)

    # A seeded generator is handed to the instance builder.
    seeded_rng = random.Random(FLAGS.random_seed)
    instances = create_training_instances(
        source_paths,
        tokenizer,
        FLAGS.max_seq_length,
        FLAGS.dupe_factor,
        FLAGS.short_seq_prob,
        FLAGS.masked_lm_prob,
        FLAGS.max_predictions_per_seq,
        seeded_rng)

    target_paths = FLAGS.output_file.split(",")
    tf.logging.info("*** Writing to output files ***")
    for target_path in target_paths:
        tf.logging.info("  %s", target_path)

    write_instance_to_example_files(
        instances,
        tokenizer,
        FLAGS.max_seq_length,
        FLAGS.max_predictions_per_seq,
        target_paths)
def create_tokenizer_from_hub_module(bert_hub_module_handle):
  """Build a FullTokenizer from the tokenization info of a Hub module.

  Args:
    bert_hub_module_handle: Handle passed to `hub.Module`.

  Returns:
    A `tokenization.FullTokenizer` configured with the module's
    `vocab_file` and `do_lower_case` values.
  """
  # A throwaway graph keeps the module's ops out of the default graph.
  with tf.Graph().as_default():
    module = hub.Module(bert_hub_module_handle)
    info = module(signature="tokenization_info", as_dict=True)
    with tf.Session() as session:
      fetches = [info["vocab_file"], info["do_lower_case"]]
      vocab_path, lower_case = session.run(fetches)

  return tokenization.FullTokenizer(
      vocab_file=vocab_path, do_lower_case=lower_case)
Пример #3
0
def main(argv: Sequence[str]) -> None:
    """Download the FEVER splits, process them and write one TFRecord each.

    Every URL in `URL_DICT` is fetched and parsed as one JSON object per
    line. Each split is then run through `process_data` and serialized to
    `FLAGS.save_dir/<split_name>` as `tf.train.Example` records whose
    features are all int64 lists.

    Args:
      argv: Command-line arguments; anything beyond the program name is
        rejected.

    Raises:
      app.UsageError: If more than one command-line argument is given.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Download FEVER data
    raw_data = {}
    for split_name, url in URL_DICT.items():
        logging.info('Downloading %s split', split_name)
        with request.urlopen(url) as open_url:
            raw_data[split_name] = [
                json.loads(line) for line in open_url.readlines()
            ]

    # Process FEVER data
    tokenizer = bert_tokenization.FullTokenizer(FLAGS.vocab_path,
                                                do_lower_case=True)
    spacy_model = spacy.load('en_core_web_md')

    processed_data = {}
    for split_name, split_data in raw_data.items():
        logging.info('Processing %s split', split_name)
        processed_data[split_name] = process_data(split_data, spacy_model,
                                                  tokenizer)

    # Create TFRecords
    tf.io.gfile.makedirs(FLAGS.save_dir)
    for split_name, split_data in processed_data.items():
        file_path = os.path.join(FLAGS.save_dir, split_name)
        logging.info('Writing %s split to %s', split_name, file_path)
        # The context manager closes the writer, so buffered records are
        # flushed to disk and the handle is released even on error.
        with tf.io.TFRecordWriter(file_path) as writer:
            for sample in split_data:
                features = tf.train.Features(
                    feature={
                        key: tf.train.Feature(int64_list=tf.train.Int64List(
                            value=value))
                        for key, value in sample.items()
                    })

                record_bytes = tf.train.Example(
                    features=features).SerializeToString()
                writer.write(record_bytes)
Пример #4
0
def main(argv: Sequence[str]) -> None:
    """Process JSON data splits into TFRecords and save the relation vocab.

    Each name in `FLAGS.split_file_names` is loaded from
    `FLAGS.data_dir/<name>.json`, run through `process_data` (which also
    threads a shared `relation_vocab` through all splits) and serialized to
    `FLAGS.save_dir/<name>` as `tf.train.Example` records whose features are
    all int64 lists. The final vocab is written to
    `FLAGS.save_dir/relation_vocab.json`.

    Args:
      argv: Command-line arguments; anything beyond the program name is
        rejected.

    Raises:
      app.UsageError: If more than one command-line argument is given.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    tokenizer = bert_tokenization.FullTokenizer(FLAGS.vocab_path,
                                                do_lower_case=True)
    spacy_model = spacy.load('en_core_web_md')

    raw_data = {}
    for split_file_name in FLAGS.split_file_names:
        path = os.path.join(FLAGS.data_dir, split_file_name + '.json')
        with tf.io.gfile.GFile(path, 'rb') as data_file:
            raw_data[split_file_name] = json.load(data_file)

    processed_data = {}
    relation_vocab = {}
    for split_name, split_data in raw_data.items():
        logging.info('Processing %s split', split_name)
        # The vocab returned for one split is fed into the next one.
        processed_split_data, relation_vocab = process_data(
            split_data, relation_vocab, spacy_model, tokenizer)
        processed_data[split_name] = processed_split_data

    # Create TFRecords
    tf.io.gfile.makedirs(FLAGS.save_dir)
    for split_name, split_data in processed_data.items():
        file_path = os.path.join(FLAGS.save_dir, split_name)
        logging.info('Writing %s split to %s', split_name, file_path)
        # The context manager closes the writer, so buffered records are
        # flushed to disk and the handle is released even on error.
        with tf.io.TFRecordWriter(file_path) as writer:
            for sample in split_data:
                features = tf.train.Features(
                    feature={
                        key: tf.train.Feature(int64_list=tf.train.Int64List(
                            value=value))
                        for key, value in sample.items()
                    })

                record_bytes = tf.train.Example(
                    features=features).SerializeToString()
                writer.write(record_bytes)

    # save label vocab
    vocab_path = os.path.join(FLAGS.save_dir, 'relation_vocab.json')
    with tf.io.gfile.GFile(vocab_path, 'w+') as vocab_file:
        json.dump(relation_vocab, vocab_file)
Пример #5
0
def main(_):
  """Train, evaluate and/or run prediction for the task in FLAGS.task_name.

  Which phases run is controlled by `FLAGS.do_train`, `FLAGS.do_eval` and
  `FLAGS.do_predict`. Files written under `FLAGS.output_dir`:
    * training:   `train.tf_record`
    * evaluation: `eval.tf_record`, `eval_results.txt`
    * prediction: `predict.tf_record`, `test_results.tsv`

  Args:
    _: Unused positional argument.

  Raises:
    ValueError: If none of the three phases is enabled, if
      `FLAGS.max_seq_length` exceeds the config's `max_position_embeddings`,
      or if the lower-cased task name is not a key of `processors`.
  """
  tf.logging.set_verbosity(tf.logging.INFO)

  # Maps lower-cased task name -> processor class.
  processors = {
      "cola": ColaProcessor,
      "mnli": MnliProcessor,
      "mrpc": MrpcProcessor,
      "xnli": XnliProcessor,
  }

  tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                FLAGS.init_checkpoint)

  # NOTE(review): the message below closes `do_predict with a single quote
  # instead of a backtick.
  if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
    raise ValueError(
        "At least one of `do_train`, `do_eval` or `do_predict' must be True.")

  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

  if FLAGS.max_seq_length > bert_config.max_position_embeddings:
    raise ValueError(
        "Cannot use sequence length %d because the BERT model "
        "was only trained up to sequence length %d" %
        (FLAGS.max_seq_length, bert_config.max_position_embeddings))

  tf.gfile.MakeDirs(FLAGS.output_dir)

  task_name = FLAGS.task_name.lower()

  if task_name not in processors:
    raise ValueError("Task not found: %s" % (task_name))

  processor = processors[task_name]()

  label_list = processor.get_labels()

  tokenizer = tokenization.FullTokenizer(
      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

  # A cluster resolver is only created when a TPU is requested by name.
  tpu_cluster_resolver = None
  if FLAGS.use_tpu and FLAGS.tpu_name:
    tpu_cluster_resolver = contrib_cluster_resolver.TPUClusterResolver(
        FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

  is_per_host = contrib_tpu.InputPipelineConfig.PER_HOST_V2
  run_config = contrib_tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      master=FLAGS.master,
      model_dir=FLAGS.output_dir,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      tpu_config=contrib_tpu.TPUConfig(
          tpu_job_name=FLAGS.tpu_job_name,
          iterations_per_loop=FLAGS.iterations_per_loop,
          num_shards=FLAGS.num_tpu_cores,
          per_host_input_for_training=is_per_host))

  # These stay None unless training is requested.
  train_examples = None
  num_train_steps = None
  num_warmup_steps = None
  if FLAGS.do_train:
    train_examples = processor.get_train_examples(FLAGS.data_dir)
    # steps = (examples / batch size) * epochs, truncated to an int.
    num_train_steps = int(
        len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs)
    num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

  model_fn = model_fn_builder(
      bert_config=bert_config,
      num_labels=len(label_list),
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=num_train_steps,
      num_warmup_steps=num_warmup_steps,
      use_tpu=FLAGS.use_tpu,
      use_one_hot_embeddings=FLAGS.use_tpu)

  # If TPU is not available, this will fall back to normal Estimator on CPU
  # or GPU.
  estimator = contrib_tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=model_fn,
      config=run_config,
      train_batch_size=FLAGS.train_batch_size,
      eval_batch_size=FLAGS.eval_batch_size,
      predict_batch_size=FLAGS.predict_batch_size)

  if FLAGS.do_train:
    # Features are written to a file and read back through the input_fn.
    train_file = os.path.join(FLAGS.output_dir, "train.tf_record")
    file_based_convert_examples_to_features(
        train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file)
    tf.logging.info("***** Running training *****")
    tf.logging.info("  Num examples = %d", len(train_examples))
    tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
    tf.logging.info("  Num steps = %d", num_train_steps)
    train_input_fn = file_based_input_fn_builder(
        input_file=train_file,
        seq_length=FLAGS.max_seq_length,
        is_training=True,
        drop_remainder=True)
    estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)

  if FLAGS.do_eval:
    eval_examples = processor.get_dev_examples(FLAGS.data_dir)
    # Remember the real count before any padding examples are appended.
    num_actual_eval_examples = len(eval_examples)
    if FLAGS.use_tpu:
      # TPU requires a fixed batch size for all batches, therefore the number
      # of examples must be a multiple of the batch size, or else examples
      # will get dropped. So we pad with fake examples which are ignored
      # later on. These do NOT count towards the metric (all tf.metrics
      # support a per-instance weight, and these get a weight of 0.0).
      while len(eval_examples) % FLAGS.eval_batch_size != 0:
        eval_examples.append(PaddingInputExample())

    eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")
    file_based_convert_examples_to_features(
        eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file)

    tf.logging.info("***** Running evaluation *****")
    tf.logging.info("  Num examples = %d (%d actual, %d padding)",
                    len(eval_examples), num_actual_eval_examples,
                    len(eval_examples) - num_actual_eval_examples)
    tf.logging.info("  Batch size = %d", FLAGS.eval_batch_size)

    # This tells the estimator to run through the entire set.
    eval_steps = None
    # However, if running eval on the TPU, you will need to specify the
    # number of steps.
    if FLAGS.use_tpu:
      assert len(eval_examples) % FLAGS.eval_batch_size == 0
      eval_steps = int(len(eval_examples) // FLAGS.eval_batch_size)

    eval_drop_remainder = True if FLAGS.use_tpu else False
    eval_input_fn = file_based_input_fn_builder(
        input_file=eval_file,
        seq_length=FLAGS.max_seq_length,
        is_training=False,
        drop_remainder=eval_drop_remainder)

    result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)

    # Each metric is both logged and written as "key = value", sorted by key.
    output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
    with tf.gfile.GFile(output_eval_file, "w") as writer:
      tf.logging.info("***** Eval results *****")
      for key in sorted(result.keys()):
        tf.logging.info("  %s = %s", key, str(result[key]))
        writer.write("%s = %s\n" % (key, str(result[key])))

  if FLAGS.do_predict:
    predict_examples = processor.get_test_examples(FLAGS.data_dir)
    # Remember the real count before any padding examples are appended.
    num_actual_predict_examples = len(predict_examples)
    if FLAGS.use_tpu:
      # TPU requires a fixed batch size for all batches, therefore the number
      # of examples must be a multiple of the batch size, or else examples
      # will get dropped. So we pad with fake examples which are ignored
      # later on.
      while len(predict_examples) % FLAGS.predict_batch_size != 0:
        predict_examples.append(PaddingInputExample())

    predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
    file_based_convert_examples_to_features(predict_examples, label_list,
                                            FLAGS.max_seq_length, tokenizer,
                                            predict_file)

    tf.logging.info("***** Running prediction*****")
    tf.logging.info("  Num examples = %d (%d actual, %d padding)",
                    len(predict_examples), num_actual_predict_examples,
                    len(predict_examples) - num_actual_predict_examples)
    tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)

    predict_drop_remainder = True if FLAGS.use_tpu else False
    predict_input_fn = file_based_input_fn_builder(
        input_file=predict_file,
        seq_length=FLAGS.max_seq_length,
        is_training=False,
        drop_remainder=predict_drop_remainder)

    result = estimator.predict(input_fn=predict_input_fn)

    # One tab-separated line of class probabilities per real example.
    output_predict_file = os.path.join(FLAGS.output_dir, "test_results.tsv")
    with tf.gfile.GFile(output_predict_file, "w") as writer:
      num_written_lines = 0
      tf.logging.info("***** Predict results *****")
      for (i, prediction) in enumerate(result):
        probabilities = prediction["probabilities"]
        # Stop once the real examples are exhausted; the rest is padding.
        if i >= num_actual_predict_examples:
          break
        output_line = "\t".join(
            str(class_probability)
            for class_probability in probabilities) + "\n"
        writer.write(output_line)
        num_written_lines += 1
    # Sanity check: every real example must have produced one output line.
    assert num_written_lines == num_actual_predict_examples
def main(_):
  """Run prediction over the input examples and dump per-token layer values.

  One JSON object per predicted example is written, newline-terminated, to
  `FLAGS.output_file`. Each object holds the example's unique id and, for
  every token, the rounded values of each layer listed in `FLAGS.layers`.

  Args:
    _: Unused positional argument.
  """
  tf.logging.set_verbosity(tf.logging.INFO)

  layer_indexes = [int(piece) for piece in FLAGS.layers.split(",")]

  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

  tokenizer = tokenization.FullTokenizer(
      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

  tpu_config = contrib_tpu.TPUConfig(
      num_shards=FLAGS.num_tpu_cores,
      per_host_input_for_training=contrib_tpu.InputPipelineConfig.PER_HOST_V2)
  run_config = contrib_tpu.RunConfig(
      master=FLAGS.master, tpu_config=tpu_config)

  examples = read_examples(FLAGS.input_file)

  features = convert_examples_to_features(
      examples=examples, seq_length=FLAGS.max_seq_length, tokenizer=tokenizer)

  # Index features by id so each prediction can be matched to its tokens.
  unique_id_to_feature = {
      feature.unique_id: feature for feature in features
  }

  model_fn = model_fn_builder(
      bert_config=bert_config,
      init_checkpoint=FLAGS.init_checkpoint,
      layer_indexes=layer_indexes,
      use_tpu=FLAGS.use_tpu,
      use_one_hot_embeddings=FLAGS.use_one_hot_embeddings)

  # If TPU is not available, this will fall back to normal Estimator on CPU
  # or GPU.
  estimator = contrib_tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=model_fn,
      config=run_config,
      predict_batch_size=FLAGS.batch_size)

  input_fn = input_fn_builder(
      features=features, seq_length=FLAGS.max_seq_length)

  output_stream = tf.gfile.Open(FLAGS.output_file, "w")
  with codecs.getwriter("utf-8")(output_stream) as writer:
    for result in estimator.predict(input_fn, yield_single_examples=True):
      unique_id = int(result["unique_id"])
      feature = unique_id_to_feature[unique_id]

      token_records = []
      for (token_pos, token) in enumerate(feature.tokens):
        layer_records = []
        for (slot, layer_index) in enumerate(layer_indexes):
          layer_output = result["layer_output_%d" % slot]
          # Slice out this token's row and flatten it to plain floats.
          token_slice = layer_output[token_pos:(token_pos + 1)]
          rounded = [round(float(value), 6) for value in token_slice.flat]
          layer_records.append(
              collections.OrderedDict([("index", layer_index),
                                       ("values", rounded)]))
        token_records.append(
            collections.OrderedDict([("token", token),
                                     ("layers", layer_records)]))

      output_json = collections.OrderedDict([("linex_index", unique_id),
                                             ("features", token_records)])
      writer.write(json.dumps(output_json) + "\n")
Пример #7
0
def main(_):
  """Train on SQuAD examples and/or write predictions, as selected by FLAGS.

  With `FLAGS.do_train`, examples from `FLAGS.train_file` are shuffled,
  converted to features in `train.tf_record` and used for training. With
  `FLAGS.do_predict`, examples from `FLAGS.predict_file` are converted to
  features in `eval.tf_record`, run through the estimator, and handed to
  `write_predictions` together with the paths `predictions.json`,
  `nbest_predictions.json` and `null_odds.json`. All files live under
  `FLAGS.output_dir`.

  Args:
    _: Unused positional argument.
  """
  tf.logging.set_verbosity(tf.logging.INFO)

  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

  validate_flags_or_throw(bert_config)

  tf.gfile.MakeDirs(FLAGS.output_dir)

  tokenizer = tokenization.FullTokenizer(
      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

  # A cluster resolver is only created when a TPU is requested by name.
  tpu_cluster_resolver = None
  if FLAGS.use_tpu and FLAGS.tpu_name:
    tpu_cluster_resolver = contrib_cluster_resolver.TPUClusterResolver(
        FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

  is_per_host = contrib_tpu.InputPipelineConfig.PER_HOST_V2
  run_config = contrib_tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      master=FLAGS.master,
      model_dir=FLAGS.output_dir,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      tpu_config=contrib_tpu.TPUConfig(
          iterations_per_loop=FLAGS.iterations_per_loop,
          num_shards=FLAGS.num_tpu_cores,
          per_host_input_for_training=is_per_host))

  # These stay None unless training is requested.
  train_examples = None
  num_train_steps = None
  num_warmup_steps = None
  if FLAGS.do_train:
    train_examples = read_squad_examples(
        input_file=FLAGS.train_file, is_training=True)
    # steps = (examples / batch size) * epochs, truncated to an int.
    num_train_steps = int(
        len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs)
    num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    # Pre-shuffle the input to avoid having to make a very large shuffle
    # buffer in in the `input_fn`.
    rng = random.Random(12345)
    rng.shuffle(train_examples)

  model_fn = model_fn_builder(
      bert_config=bert_config,
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=num_train_steps,
      num_warmup_steps=num_warmup_steps,
      use_tpu=FLAGS.use_tpu,
      use_one_hot_embeddings=FLAGS.use_tpu)

  # If TPU is not available, this will fall back to normal Estimator on CPU
  # or GPU.
  estimator = contrib_tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=model_fn,
      config=run_config,
      train_batch_size=FLAGS.train_batch_size,
      predict_batch_size=FLAGS.predict_batch_size)

  if FLAGS.do_train:
    # We write to a temporary file to avoid storing very large constant tensors
    # in memory.
    train_writer = FeatureWriter(
        filename=os.path.join(FLAGS.output_dir, "train.tf_record"),
        is_training=True)
    convert_examples_to_features(
        examples=train_examples,
        tokenizer=tokenizer,
        max_seq_length=FLAGS.max_seq_length,
        doc_stride=FLAGS.doc_stride,
        max_query_length=FLAGS.max_query_length,
        is_training=True,
        output_fn=train_writer.process_feature)
    train_writer.close()

    tf.logging.info("***** Running training *****")
    tf.logging.info("  Num orig examples = %d", len(train_examples))
    tf.logging.info("  Num split examples = %d", train_writer.num_features)
    tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
    tf.logging.info("  Num steps = %d", num_train_steps)
    # The examples are no longer needed once the features are on disk.
    del train_examples

    train_input_fn = input_fn_builder(
        input_file=train_writer.filename,
        seq_length=FLAGS.max_seq_length,
        is_training=True,
        drop_remainder=True)
    estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)

  if FLAGS.do_predict:
    eval_examples = read_squad_examples(
        input_file=FLAGS.predict_file, is_training=False)

    eval_writer = FeatureWriter(
        filename=os.path.join(FLAGS.output_dir, "eval.tf_record"),
        is_training=False)
    eval_features = []

    def append_feature(feature):
      # Keep each feature in memory and also write it to the record file.
      eval_features.append(feature)
      eval_writer.process_feature(feature)

    convert_examples_to_features(
        examples=eval_examples,
        tokenizer=tokenizer,
        max_seq_length=FLAGS.max_seq_length,
        doc_stride=FLAGS.doc_stride,
        max_query_length=FLAGS.max_query_length,
        is_training=False,
        output_fn=append_feature)
    eval_writer.close()

    tf.logging.info("***** Running predictions *****")
    tf.logging.info("  Num orig examples = %d", len(eval_examples))
    tf.logging.info("  Num split examples = %d", len(eval_features))
    tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)

    all_results = []

    predict_input_fn = input_fn_builder(
        input_file=eval_writer.filename,
        seq_length=FLAGS.max_seq_length,
        is_training=False,
        drop_remainder=False)

    # If running eval on the TPU, you will need to specify the number of
    # steps.
    # NOTE(review): `all_results` was already set to [] a few lines up; this
    # second assignment is redundant.
    all_results = []
    for result in estimator.predict(
        predict_input_fn, yield_single_examples=True):
      # Progress is logged once every 1000 collected results.
      if len(all_results) % 1000 == 0:
        tf.logging.info("Processing example: %d" % (len(all_results)))
      unique_id = int(result["unique_ids"])
      start_logits = [float(x) for x in result["start_logits"].flat]
      end_logits = [float(x) for x in result["end_logits"].flat]
      all_results.append(
          RawResult(
              unique_id=unique_id,
              start_logits=start_logits,
              end_logits=end_logits))

    output_prediction_file = os.path.join(FLAGS.output_dir, "predictions.json")
    output_nbest_file = os.path.join(FLAGS.output_dir, "nbest_predictions.json")
    output_null_log_odds_file = os.path.join(FLAGS.output_dir, "null_odds.json")

    write_predictions(eval_examples, eval_features, all_results,
                      FLAGS.n_best_size, FLAGS.max_answer_length,
                      FLAGS.do_lower_case, output_prediction_file,
                      output_nbest_file, output_null_log_odds_file)
Пример #8
0
    def __init__(self,
                 data_sources,
                 env_config,
                 debug_writer=None,
                 worker_idx=0):
        """Construct the StreetView environment.

        Args:
          data_sources: A list of strings containing one or more of 'train',
            'dev' and 'test'.
          env_config: The configuration object containing settings for the
            current run.
          debug_writer: Class for writing execution traces to dir.
          worker_idx: Restrict regions that this worker explores to those
            where worker_idx % region_idx == 0. Default 0 means include all
            regions. region_idx is defined as the order in a sorted list of
            region names. The value is forwarded to `_select_regions`.
        """
        super(StreetViewEnv, self).__init__()

        # Logging execution traces if enabled
        self._run_writer = debug_writer

        self._env_config = env_config

        # Action space
        self._panoramic_actions = env_config.panoramic_action_space
        self._panoramic_action_bins = env_config.panoramic_action_bins

        # Attributes prefixed by _all* are dictionaries indexed by region name
        self._all_regions = self._select_regions(env_config, worker_idx)
        self._all_graphs = self._init_graphs(env_config, self._all_regions)

        # Load entry sequences for each region, and activate the current region
        self._all_entry_sequences = self._init_entry_sequences(
            data_sources, env_config, self._all_regions)
        logging.info('Loaded %d regions with data sources %s',
                     len(self._all_entry_sequences), data_sources)
        for region_name, entry_sequence in self._all_entry_sequences.items():
            logging.info('Region %s has %d entries', region_name,
                         len(entry_sequence))

        self._all_feature_loaders = self._init_feature_loaders(
            env_config, self._all_regions)

        # Using BERT word piece tokenizer.
        self._tokenizer = tokenization.FullTokenizer(
            vocab_file=env_config.vocab_file, do_lower_case=True)
        # Constants from env_config.
        self._max_actions_per_episode = env_config.max_agent_actions
        self._instruction_tok_len = env_config.instruction_tok_len

        # Constant parameter
        self._base_yaw_angle = env_config.base_yaw_angle

        # Class members.
        # Variables that identify current instruction (assigned in reset)
        self._graph = None
        self._current_region = ''
        self._current_sequence_idx = -1
        self._current_entry_idx = 0

        # State variables
        self._frame_count = 0
        self._goal_pano_id = -1
        self._graph_state = GraphState(-1, 0., 0., 0.)
        self._distance_to_goal = 0.

        # Information about current instruction
        self._golden_actions = None
        self._golden_path = None
Пример #9
0
def main(argv: Sequence[str]) -> None:
    """Split `webred_21.tfrecord` into test/dev/train and write TFRecords.

    Samples are read from `FLAGS.data_dir/webred_21.tfrecord`, shuffled with
    a fixed NumPy seed and split 10% test / 10% dev / remainder train. Each
    split is run through `process_data` (which also threads a shared
    `relation_vocab` through all splits) and serialized to
    `FLAGS.save_dir/<split_name>` as `tf.train.Example` records whose
    features are all int64 lists. The final vocab is written to
    `FLAGS.save_dir/relation_vocab.json`.

    Args:
      argv: Command-line arguments; anything beyond the program name is
        rejected.

    Raises:
      app.UsageError: If more than one command-line argument is given.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    tokenizer = bert_tokenization.FullTokenizer(FLAGS.vocab_path,
                                                do_lower_case=True)
    spacy_model = spacy.load('en_core_web_md')

    def get_feature(sample, feature_name, idx=0):
        """Return element `idx` of whichever value list the feature holds."""
        feature = sample.features.feature[feature_name]
        return getattr(feature, feature.WhichOneof('kind')).value[idx]

    path = os.path.join(FLAGS.data_dir, 'webred_21.tfrecord')
    samples = []
    dataset = tf.data.TFRecordDataset(path)
    for raw_sample in dataset:
        example = tf.train.Example()
        example.ParseFromString(raw_sample.numpy())
        samples.append({
            'annotated_text':
                get_feature(example, 'sentence').decode('utf-8'),
            'relation':
                get_feature(example, 'relation_name').decode('utf-8'),
            'num_pos_raters':
                get_feature(example, 'num_pos_raters'),
        })

    # Fixed seed so the shuffle, and therefore the split, is reproducible.
    np.random.seed(0)
    shuffled_indices = np.random.permutation(len(samples))
    shuffled_samples = [samples[idx] for idx in shuffled_indices]
    eval_split_size = int(0.1 * len(samples))
    raw_data = {
        'test': shuffled_samples[:eval_split_size],
        'dev': shuffled_samples[eval_split_size:2 * eval_split_size],
        'train': shuffled_samples[2 * eval_split_size:],
    }

    processed_data = {}
    relation_vocab = {}

    for split_name, split_data in raw_data.items():
        logging.info('Processing %s split', split_name)
        # The vocab returned for one split is fed into the next one.
        processed_split_data, relation_vocab = process_data(
            split_data, relation_vocab, spacy_model, tokenizer)
        processed_data[split_name] = processed_split_data

    # Create TFRecords
    tf.io.gfile.makedirs(FLAGS.save_dir)
    for split_name, split_data in processed_data.items():
        file_path = os.path.join(FLAGS.save_dir, split_name)
        logging.info('Writing %s split to %s', split_name, file_path)
        # The context manager closes the writer, so buffered records are
        # flushed to disk and the handle is released even on error.
        with tf.io.TFRecordWriter(file_path) as writer:
            for sample in split_data:
                features = tf.train.Features(
                    feature={
                        key: tf.train.Feature(int64_list=tf.train.Int64List(
                            value=value))
                        for key, value in sample.items()
                    })

                record_bytes = tf.train.Example(
                    features=features).SerializeToString()
                writer.write(record_bytes)

    # save label vocab
    vocab_path = os.path.join(FLAGS.save_dir, 'relation_vocab.json')
    with tf.io.gfile.GFile(vocab_path, 'w+') as vocab_file:
        json.dump(relation_vocab, vocab_file)