Example #1
    def _serialize_dataset(self, tasks, is_training, split):
        """Write out the dataset as tfrecords."""
        dataset_name = "_".join(sorted([task.name for task in tasks]))
        dataset_name += "_" + split
        dataset_prefix = os.path.join(
            self._config.preprocessed_data_dir, 'tfrecord')
        tfrecords_path = dataset_prefix + ".tfrecord"
        metadata_path = dataset_prefix + ".metadata"
        batch_size = (self._config.train_batch_size if is_training else
                      self._config.eval_batch_size)

        utils.log("Loading dataset")
        n_examples = None
        if tf.io.gfile.exists(metadata_path):
            # Reuse tfrecords from a previous run if they have already been written.
            n_examples = utils.load_json(metadata_path)["n_examples"]

        if n_examples is None:
            utils.log("Existing tfrecords not found so creating")
            examples = []
            for task in tasks:
                task_examples = task.get_examples(split)
                examples += task_examples
            utils.mkdir(tfrecords_path.rsplit("/", 1)[0])
            n_examples = self.serialize_examples(
                examples, is_training, tfrecords_path, batch_size)
            utils.write_json({"n_examples": n_examples}, metadata_path)

        input_fn = self._input_fn_builder(tfrecords_path, is_training)
        if is_training:
            steps = int(n_examples // batch_size * self._config.num_train_epochs)
        else:
            steps = n_examples // batch_size

        return input_fn, steps
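
A rough usage sketch for the pair returned above: the `input_fn` and step count are meant to be handed to a `tf.estimator` training or evaluation call. The `preprocessor`, `tasks`, and `estimator` objects below are assumed to exist elsewhere and are not part of this example.

# Sketch only: `preprocessor`, `tasks`, and `estimator` are assumed objects.
train_input_fn, train_steps = preprocessor._serialize_dataset(
    tasks, is_training=True, split="train")
estimator.train(input_fn=train_input_fn, max_steps=train_steps)

eval_input_fn, eval_steps = preprocessor._serialize_dataset(
    tasks, is_training=False, split="dev")
estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)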
Example #2
    def log(self, run_values=None):
        step = self._global_step if self._is_training else self._steps_run_so_far
        if step % self._log_every != 0:
            return
        msg = "{:}/{:} = {:.1f}%".format(
            step, self._n_steps, 100.0 * step / self._n_steps)
        time_elapsed = time.time() - self._start_time
        time_per_step = time_elapsed / (
            (step - self._start_step) if self._is_training else step)
        msg += ", SPS: {:.1f}".format(1 / time_per_step)
        msg += ", ELAP: " + secs_to_str(time_elapsed)
        msg += ", ETA: " + secs_to_str((self._n_steps - step) * time_per_step)
        if run_values is not None:
            for tag, value in run_values.results.items():
                msg += " - " + str(tag) + ": {:.4f}".format(value)
        utils.log(msg)
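
The hook above relies on a `secs_to_str` helper that is not shown in these examples. A minimal stand-in (an assumption, not necessarily the project's exact implementation) would format a duration in seconds as H:MM:SS:

import datetime

def secs_to_str(secs):
    # Render e.g. 3661.4 seconds as "1:01:01" (fractional seconds dropped).
    return str(datetime.timedelta(seconds=int(round(secs))))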
Example #3
    def serialize_examples(self, examples, is_training, output_file, batch_size):
        """Convert a set of `InputExample`s to a TFRecord file."""
        n_examples = 0
        with tf.io.TFRecordWriter(output_file) as writer:
            for (ex_index, example) in enumerate(examples):
                if ex_index % 2000 == 0:
                    utils.log("Writing example {:} of {:}".format(
                        ex_index, len(examples)))
                for tf_example in self._example_to_tf_example(
                        example, is_training,
                        log=self._config.log_examples and ex_index < 1):
                    writer.write(tf_example.SerializeToString())
                    n_examples += 1
            # add padding so the dataset is a multiple of batch_size
            while n_examples % batch_size != 0:
                writer.write(
                    self._make_tf_example(task_id=len(self._config.task_names))
                    .SerializeToString())
                n_examples += 1
        return n_examples
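
The padding loop at the end keeps fixed-size batching (as typically required on TPUs, where a trailing partial batch would otherwise be dropped) from discarding real examples; the padding examples are written with task_id equal to len(task_names), i.e. outside the range of real task ids. A tiny arithmetic sketch with made-up numbers:

# Made-up numbers illustrating how many padding examples get written.
n_examples = 10570                      # real examples in the tfrecord
batch_size = 32
n_padding = (-n_examples) % batch_size  # 22 extra examples needed
total = n_examples + n_padding          # 10592, an exact multiple of 32
assert total % batch_size == 0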
Example #4
    def get_examples(self, split):
        if split in self._examples:
            return self._examples[split]

        path = os.path.join(self.config.raw_data_dir, split + ".json")
        with tf.io.gfile.GFile(path, "r") as f:
            input_data = json.load(f)["data"]

        examples = []
        example_failures = [0]
        for entry in input_data:
            for paragraph in entry["paragraphs"]:
                self._add_examples(examples, example_failures, paragraph,
                                   split)
        self._examples[split] = examples
        utils.log("{:} examples created, {:} failures".format(
            len(examples), example_failures[0]))
        return examples
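
`get_examples` walks the SQuAD-style nesting of data -> paragraphs -> qas. A minimal, hand-written illustration of the structure the loop above expects (values are made up; the MRQA "detected_answers" variant handled elsewhere is omitted):

# Illustrative only: the minimal SQuAD-style structure for json.load(f)["data"].
input_data = [{
    "paragraphs": [{
        "context": "Steve Smith was born in 1989.",
        "qas": [{
            "id": "q1",
            "question": "When was Steve Smith born?",
            "answers": [{"text": "1989", "answer_start": 24}],
        }],
    }],
}]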
Example #5
    def model_fn(features, labels, mode, params):
        """The `model_fn` for TPUEstimator."""
        utils.log("Building model...")
        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        model = FinetuningModel(config, tasks, is_training, features,
                                num_train_steps)

        # Load pre-trained weights from a checkpoint, if one was provided.
        init_checkpoint = None
        if pretraining_config is not None:
            # init_checkpoint = tf.train.latest_checkpoint(pretraining_config.model_dir)
            init_checkpoint = pretraining_config['checkpoint']
            utils.log("Using checkpoint", init_checkpoint)
        tvars = tf.trainable_variables()
        scaffold_fn = None
        if init_checkpoint:
            assignment_map, _ = modeling.get_assignment_map_from_checkpoint(
                tvars, init_checkpoint)
            tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        # Build model for training or prediction
        if mode == tf.estimator.ModeKeys.TRAIN:
            train_op = optimization.create_optimizer(
                model.loss,
                config.learning_rate,
                num_train_steps,
                weight_decay_rate=config.weight_decay_rate,
                use_tpu=config.use_tpu,
                warmup_proportion=config.warmup_proportion,
                layerwise_lr_decay_power=config.layerwise_lr_decay,
                n_transformer_layers=model.bert_config.num_hidden_layers)
            output_spec = tf.estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=model.loss,
                train_op=train_op,
                scaffold_fn=scaffold_fn,
                training_hooks=[
                    training_utils.ETAHook(
                        {} if config.use_tpu else dict(loss=model.loss),
                        num_train_steps, config.iterations_per_loop,
                        config.use_tpu, 10)
                ])
        else:
            assert mode == tf.estimator.ModeKeys.PREDICT
            output_spec = tf.estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                predictions=utils.flatten_dict(model.outputs),
                scaffold_fn=scaffold_fn)

        utils.log("Building complete")
        return output_spec
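
A rough sketch of how a `model_fn` like the one above is typically wired into a `TPUEstimator`; this is not the project's full run configuration, and attribute names such as `config.model_dir` and `config.predict_batch_size` are assumptions here:

# Sketch only: `model_fn` is the function above; `config` is the finetuning config.
run_config = tf.estimator.tpu.RunConfig(
    model_dir=config.model_dir,
    tpu_config=tf.estimator.tpu.TPUConfig(
        iterations_per_loop=config.iterations_per_loop))
estimator = tf.estimator.tpu.TPUEstimator(
    use_tpu=config.use_tpu,
    model_fn=model_fn,
    config=run_config,
    train_batch_size=config.train_batch_size,
    predict_batch_size=config.predict_batch_size)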
Example #6
def get_final_text(config: electra.configure_finetuning.FinetuningConfig, pred_text,
                   orig_text):
    """Project the tokenized prediction back to the original text."""

    # When we created the data, we kept track of the alignment between original
    # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
    # now `orig_text` contains the span of our original text corresponding to the
    # span that we predicted.
    #
    # However, `orig_text` may contain extra characters that we don't want in
    # our prediction.
    #
    # For example, let's say:
    #   pred_text = steve smith
    #   orig_text = Steve Smith's
    #
    # We don't want to return `orig_text` because it contains the extra "'s".
    #
    # We don't want to return `pred_text` because it's already been normalized
    # (the SQuAD eval script also does punctuation stripping/lower casing but
    # our tokenizer does additional normalization like stripping accent
    # characters).
    #
    # What we really want to return is "Steve Smith".
    #
    # Therefore, we have to apply a semi-complicated alignment heuristic between
    # `pred_text` and `orig_text` to get a character-to-character alignment. This
    # can fail in certain cases, in which case we just return `orig_text`.

    def _strip_spaces(text):
        ns_chars = []
        ns_to_s_map = collections.OrderedDict()
        for i, c in enumerate(text):
            if c == " ":
                continue
            ns_to_s_map[len(ns_chars)] = i
            ns_chars.append(c)
        ns_text = "".join(ns_chars)
        return ns_text, dict(ns_to_s_map)

    # We first tokenize `orig_text`, strip whitespace from the result
    # and `pred_text`, and check if they are the same length. If they are
    # NOT the same length, the heuristic has failed. If they are the same
    # length, we assume the characters are one-to-one aligned.
    tokenizer = tokenization.BasicTokenizer(do_lower_case=config.do_lower_case)

    tok_text = " ".join(tokenizer.tokenize(orig_text))

    start_position = tok_text.find(pred_text)
    if start_position == -1:
        if config.debug:
            utils.log(
                "Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
        return orig_text
    end_position = start_position + len(pred_text) - 1

    (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
    (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)

    if len(orig_ns_text) != len(tok_ns_text):
        if config.debug:
            utils.log("Length not equal after stripping spaces: '%s' vs '%s'",
                      orig_ns_text, tok_ns_text)
        return orig_text

    # We then project the characters in `pred_text` back to `orig_text` using
    # the character-to-character alignment.
    tok_s_to_ns_map = {}
    for (i, tok_index) in six.iteritems(tok_ns_to_s_map):
        tok_s_to_ns_map[tok_index] = i

    orig_start_position = None
    if start_position in tok_s_to_ns_map:
        ns_start_position = tok_s_to_ns_map[start_position]
        if ns_start_position in orig_ns_to_s_map:
            orig_start_position = orig_ns_to_s_map[ns_start_position]

    if orig_start_position is None:
        if config.debug:
            utils.log("Couldn't map start position")
        return orig_text

    orig_end_position = None
    if end_position in tok_s_to_ns_map:
        ns_end_position = tok_s_to_ns_map[end_position]
        if ns_end_position in orig_ns_to_s_map:
            orig_end_position = orig_ns_to_s_map[ns_end_position]

    if orig_end_position is None:
        if config.debug:
            utils.log("Couldn't map end position")
        return orig_text

    output_text = orig_text[orig_start_position:(orig_end_position + 1)]
    return output_text
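
The docstring example can be traced by hand. The self-contained sketch below re-implements the `_strip_spaces` mapping and walks the "Steve Smith's" case end to end; the tokenizer output shown is what a lower-casing, punctuation-splitting BasicTokenizer would plausibly produce, not a call to the real tokenizer:

import collections

def _strip_spaces(text):
    # Remove spaces, remembering where each kept character sat in the original.
    ns_chars = []
    ns_to_s_map = collections.OrderedDict()
    for i, c in enumerate(text):
        if c == " ":
            continue
        ns_to_s_map[len(ns_chars)] = i
        ns_chars.append(c)
    return "".join(ns_chars), dict(ns_to_s_map)

orig_text = "Steve Smith's"
tok_text = "steve smith ' s"   # plausible lower-cased, punctuation-split tokenization
pred_text = "steve smith"

start = tok_text.find(pred_text)              # 0
end = start + len(pred_text) - 1              # 10
orig_ns, orig_map = _strip_spaces(orig_text)  # "SteveSmith's", 12 chars
tok_ns, tok_map = _strip_spaces(tok_text)     # "stevesmith's", 12 chars
tok_s_to_ns = {v: k for k, v in tok_map.items()}
orig_start = orig_map[tok_s_to_ns[start]]     # 0
orig_end = orig_map[tok_s_to_ns[end]]         # 10
print(orig_text[orig_start:orig_end + 1])     # Steve Smith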
Example #7
    def featurize(self,
                  example: QAExample,
                  is_training,
                  log=False,
                  for_eval=False):
        all_features = []
        query_tokens = self._tokenizer.tokenize(example.question_text)

        if len(query_tokens) > self.config.max_query_length:
            query_tokens = query_tokens[0:self.config.max_query_length]

        tok_to_orig_index = []
        orig_to_tok_index = []
        all_doc_tokens = []
        for (i, token) in enumerate(example.doc_tokens):
            orig_to_tok_index.append(len(all_doc_tokens))
            sub_tokens = self._tokenizer.tokenize(token)
            for sub_token in sub_tokens:
                tok_to_orig_index.append(i)
                all_doc_tokens.append(sub_token)

        tok_start_position = None
        tok_end_position = None
        if is_training and example.is_impossible:
            tok_start_position = -1
            tok_end_position = -1
        if is_training and not example.is_impossible:
            tok_start_position = orig_to_tok_index[example.start_position]
            if example.end_position < len(example.doc_tokens) - 1:
                tok_end_position = orig_to_tok_index[example.end_position +
                                                     1] - 1
            else:
                tok_end_position = len(all_doc_tokens) - 1
            (tok_start_position, tok_end_position) = _improve_answer_span(
                all_doc_tokens, tok_start_position, tok_end_position,
                self._tokenizer, example.orig_answer_text)

        # The -3 accounts for [CLS], [SEP] and [SEP]
        max_tokens_for_doc = self.config.max_seq_length - len(query_tokens) - 3

        # We can have documents that are longer than the maximum sequence length.
        # To deal with this we do a sliding window approach, where we take chunks
        # of up to our max length with a stride of `doc_stride`.
        _DocSpan = collections.namedtuple(  # pylint: disable=invalid-name
            "DocSpan", ["start", "length"])
        doc_spans = []
        start_offset = 0
        while start_offset < len(all_doc_tokens):
            length = len(all_doc_tokens) - start_offset
            if length > max_tokens_for_doc:
                length = max_tokens_for_doc
            doc_spans.append(_DocSpan(start=start_offset, length=length))
            if start_offset + length == len(all_doc_tokens):
                break
            start_offset += min(length, self.config.doc_stride)

        for (doc_span_index, doc_span) in enumerate(doc_spans):
            tokens = []
            token_to_orig_map = {}
            token_is_max_context = {}
            segment_ids = []
            tokens.append("[CLS]")
            segment_ids.append(0)
            for token in query_tokens:
                tokens.append(token)
                segment_ids.append(0)
            tokens.append("[SEP]")
            segment_ids.append(0)

            for i in range(doc_span.length):
                split_token_index = doc_span.start + i
                token_to_orig_map[len(
                    tokens)] = tok_to_orig_index[split_token_index]

                is_max_context = _check_is_max_context(doc_spans,
                                                       doc_span_index,
                                                       split_token_index)
                token_is_max_context[len(tokens)] = is_max_context
                tokens.append(all_doc_tokens[split_token_index])
                segment_ids.append(1)
            tokens.append("[SEP]")
            segment_ids.append(1)

            input_ids = self._tokenizer.convert_tokens_to_ids(tokens)

            # The mask has 1 for real tokens and 0 for padding tokens. Only real
            # tokens are attended to.
            input_mask = [1] * len(input_ids)

            # Zero-pad up to the sequence length.
            while len(input_ids) < self.config.max_seq_length:
                input_ids.append(0)
                input_mask.append(0)
                segment_ids.append(0)

            assert len(input_ids) == self.config.max_seq_length
            assert len(input_mask) == self.config.max_seq_length
            assert len(segment_ids) == self.config.max_seq_length

            start_position = None
            end_position = None
            if is_training and not example.is_impossible:
                # For training, if our document chunk does not contain an annotation
                # we throw it out, since there is nothing to predict.
                doc_start = doc_span.start
                doc_end = doc_span.start + doc_span.length - 1
                out_of_span = False
                if not (tok_start_position >= doc_start
                        and tok_end_position <= doc_end):
                    out_of_span = True
                if out_of_span:
                    start_position = 0
                    end_position = 0
                else:
                    doc_offset = len(query_tokens) + 2
                    start_position = tok_start_position - doc_start + doc_offset
                    end_position = tok_end_position - doc_start + doc_offset

            if is_training and example.is_impossible:
                start_position = 0
                end_position = 0

            if log:
                utils.log("*** Example ***")
                utils.log("doc_span_index: %s" % doc_span_index)
                utils.log(
                    "tokens: %s" %
                    " ".join([tokenization.printable_text(x) for x in tokens]))
                utils.log("token_to_orig_map: %s" % " ".join([
                    "%d:%d" % (x, y)
                    for (x, y) in six.iteritems(token_to_orig_map)
                ]))
                utils.log("token_is_max_context: %s" % " ".join([
                    "%d:%s" % (x, y)
                    for (x, y) in six.iteritems(token_is_max_context)
                ]))
                utils.log("input_ids: %s" %
                          " ".join([str(x) for x in input_ids]))
                utils.log("input_mask: %s" %
                          " ".join([str(x) for x in input_mask]))
                utils.log("segment_ids: %s" %
                          " ".join([str(x) for x in segment_ids]))
                if is_training and example.is_impossible:
                    utils.log("impossible example")
                if is_training and not example.is_impossible:
                    answer_text = " ".join(
                        tokens[start_position:(end_position + 1)])
                    utils.log("start_position: %d" % start_position)
                    utils.log("end_position: %d" % end_position)
                    utils.log("answer: %s" %
                              (tokenization.printable_text(answer_text)))

            features = {
                "task_id": self.config.task_names.index(self.name),
                self.name + "_eid": (1000 * example.eid) + doc_span_index,
                "input_ids": input_ids,
                "input_mask": input_mask,
                "segment_ids": segment_ids,
            }
            if for_eval:
                features.update({
                    self.name + "_doc_span_index": doc_span_index,
                    self.name + "_tokens": tokens,
                    self.name + "_token_to_orig_map": token_to_orig_map,
                    self.name + "_token_is_max_context": token_is_max_context,
                })
            if is_training:
                features.update({
                    self.name + "_start_positions": start_position,
                    self.name + "_end_positions": end_position,
                    self.name + "_is_impossible": example.is_impossible,
                })
            all_features.append(features)
        return all_features
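
The doc-span loop above is the standard sliding-window treatment of documents longer than the maximum sequence length. A standalone sketch of that loop with made-up sizes (a 300-token document, room for 256 document tokens per window, stride 128) shows the spans it produces:

import collections

_DocSpan = collections.namedtuple("DocSpan", ["start", "length"])

def make_doc_spans(n_doc_tokens, max_tokens_for_doc, doc_stride):
    # Same sliding-window logic as in featurize(), extracted for illustration.
    doc_spans = []
    start_offset = 0
    while start_offset < n_doc_tokens:
        length = min(n_doc_tokens - start_offset, max_tokens_for_doc)
        doc_spans.append(_DocSpan(start=start_offset, length=length))
        if start_offset + length == n_doc_tokens:
            break
        start_offset += min(length, doc_stride)
    return doc_spans

print(make_doc_spans(300, 256, 128))
# [DocSpan(start=0, length=256), DocSpan(start=128, length=172)]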
Example #8
    def _add_examples(self, examples, example_failures, paragraph, split):
        paragraph_text = paragraph["context"]
        doc_tokens = []
        char_to_word_offset = []
        prev_is_whitespace = True
        for c in paragraph_text:
            if is_whitespace(c):
                prev_is_whitespace = True
            else:
                if prev_is_whitespace:
                    doc_tokens.append(c)
                else:
                    doc_tokens[-1] += c
                prev_is_whitespace = False
            char_to_word_offset.append(len(doc_tokens) - 1)

        for qa in paragraph["qas"]:
            qas_id = qa["id"] if "id" in qa else None
            qid = qa["qid"] if "qid" in qa else None
            question_text = qa["question"]
            start_position = None
            end_position = None
            orig_answer_text = None
            is_impossible = False
            if split == "train":
                if self.v2:
                    is_impossible = qa["is_impossible"]
                if not is_impossible:
                    if "detected_answers" in qa:  # MRQA format
                        answer = qa["detected_answers"][0]
                        answer_offset = answer["char_spans"][0][0]
                    else:  # SQuAD format
                        answer = qa["answers"][0]
                        answer_offset = answer["answer_start"]
                    orig_answer_text = answer["text"]
                    answer_length = len(orig_answer_text)
                    start_position = char_to_word_offset[answer_offset]
                    if answer_offset + answer_length - 1 >= len(
                            char_to_word_offset):
                        utils.log("End position is out of document!")
                        example_failures[0] += 1
                        continue
                    end_position = char_to_word_offset[answer_offset +
                                                       answer_length - 1]

                    # Only add answers where the text can be exactly recovered from the
                    # document. If this CAN'T happen it's likely due to weird Unicode
                    # stuff so we will just skip the example.
                    #
                    # Note that this means for training mode, every example is NOT
                    # guaranteed to be preserved.
                    actual_text = " ".join(
                        doc_tokens[start_position:(end_position + 1)])
                    cleaned_answer_text = " ".join(
                        tokenization.whitespace_tokenize(orig_answer_text))
                    actual_text = actual_text.lower()
                    cleaned_answer_text = cleaned_answer_text.lower()
                    if actual_text.find(cleaned_answer_text) == -1:
                        utils.log("Could not find answer: '{:}' in doc vs. "
                                  "'{:}' in provided answer".format(
                                      tokenization.printable_text(actual_text),
                                      tokenization.printable_text(
                                          cleaned_answer_text)))
                        example_failures[0] += 1
                        continue
                else:
                    start_position = -1
                    end_position = -1
                    orig_answer_text = ""
            example = QAExample(task_name=self.name,
                                eid=len(examples),
                                qas_id=qas_id,
                                qid=qid,
                                question_text=question_text,
                                doc_tokens=doc_tokens,
                                orig_answer_text=orig_answer_text,
                                start_position=start_position,
                                end_position=end_position,
                                is_impossible=is_impossible)
            examples.append(example)
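
The whitespace-splitting pass at the top of `_add_examples` is what makes the later `char_to_word_offset[answer_offset]` lookup work: every character position in the paragraph is mapped to the index of the whitespace-delimited token containing it. A standalone illustration with a made-up paragraph (the `is_whitespace` definition here is an assumption for the sketch):

def is_whitespace(c):
    # Assumed definition for this sketch.
    return c in " \t\r\n" or ord(c) == 0x202F

paragraph_text = "Steve Smith was born in 1989."
doc_tokens = []
char_to_word_offset = []
prev_is_whitespace = True
for c in paragraph_text:
    if is_whitespace(c):
        prev_is_whitespace = True
    else:
        if prev_is_whitespace:
            doc_tokens.append(c)       # start a new token
        else:
            doc_tokens[-1] += c        # extend the current token
        prev_is_whitespace = False
    char_to_word_offset.append(len(doc_tokens) - 1)

print(doc_tokens)               # ['Steve', 'Smith', 'was', 'born', 'in', '1989.']
print(char_to_word_offset[24])  # 5 -> char offset 24 ("1") falls in token "1989."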