コード例 #1
0
 def input_fn(params):
     """Build the input dataset from fixed-length int64 features.

     ``flags``, ``is_training``, ``input_files``, ``num_cpu_threads`` and
     ``format_dataset`` are resolved from the enclosing scope.

     Args:
         params: Mapping that must provide a ``"batch_size"`` entry.

     Returns:
         The value returned by ``format_dataset`` for this feature spec.
     """
     batch_size = params["batch_size"]
     seq_len = flags.max_seq_length

     # The three sequence inputs share the same length and dtype.
     feature_spec = {
         key: tf.io.FixedLenFeature([seq_len], tf.int64)
         for key in ("input_ids", "input_mask", "segment_ids")
     }
     # One int64 value per record.
     feature_spec["instance_id"] = tf.io.FixedLenFeature([1], tf.int64)

     return format_dataset(
         feature_spec,
         batch_size,
         is_training,
         flags,
         input_files,
         num_cpu_threads,
     )
コード例 #2
0
 def input_fn(params):
     """Build the input dataset with sequence inputs plus a ``use_context`` flag.

     ``total_sequence_length``, ``flags``, ``is_training``, ``input_files``,
     ``num_cpu_threads`` and ``format_dataset`` come from the enclosing scope.

     Args:
         params: Mapping that must provide a ``"batch_size"`` entry.

     Returns:
         The value returned by ``format_dataset`` for this feature spec.
     """
     batch_size = params["batch_size"]
     make_feature = tf.io.FixedLenFeature

     def int64_feature(length):
         # Fixed-length int64 feature holding ``length`` values.
         return make_feature([length], tf.int64)

     feature_spec = {
         "input_ids": int64_feature(total_sequence_length),
         "input_mask": int64_feature(total_sequence_length),
         "segment_ids": int64_feature(total_sequence_length),
         "use_context": int64_feature(1),
     }
     return format_dataset(
         feature_spec,
         batch_size,
         is_training,
         flags,
         input_files,
         num_cpu_threads,
     )
コード例 #3
0
ファイル: dict_model_fn.py プロジェクト: clover3/Chair
    def input_fn(params):
        """Build the input dataset, selecting features according to ``flags``.

        A full table of fixed-length int64 features is declared first; only
        the names enabled by the flag settings are handed to
        ``format_dataset``.  ``flags``, ``is_training``, ``input_files``,
        ``num_cpu_threads`` and ``format_dataset`` come from the enclosing
        scope.

        Args:
            params: Mapping that must provide a ``"batch_size"`` entry.

        Returns:
            The value returned by ``format_dataset`` for the selected
            features.
        """
        batch_size = params["batch_size"]
        seq_len = flags.max_seq_length
        num_predictions = flags.max_predictions_per_seq
        def_len = flags.max_def_length
        loc_len = flags.max_loc_length
        word_len = flags.max_word_length
        make_feature = tf.io.FixedLenFeature

        def int64_feature(length):
            # Fixed-length int64 feature holding ``length`` values.
            return make_feature([length], tf.int64)

        feature_table = {
            "input_ids": int64_feature(seq_len),
            "input_mask": int64_feature(seq_len),
            "segment_ids": int64_feature(seq_len),
            "d_input_ids": int64_feature(def_len),
            "d_input_mask": int64_feature(def_len),
            "d_segment_ids": int64_feature(def_len),
            "d_location_ids": int64_feature(loc_len),
            "next_sentence_labels": int64_feature(1),
            "masked_lm_positions": int64_feature(num_predictions),
            "masked_lm_ids": int64_feature(num_predictions),
            "lookup_idx": int64_feature(1),
            "selected_word": int64_feature(word_len),
        }

        # Features that are always read.
        wanted = [
            "input_ids",
            "input_mask",
            "segment_ids",
            "d_input_ids",
            "d_input_mask",
            "d_location_ids",
            "next_sentence_labels",
        ]

        fixed_mask = flags.fixed_mask
        train_op = flags.train_op
        needs_lookup = train_op == "lookup" or train_op == "entry_prediction"

        # The masked-LM features are needed for a fixed mask and for both
        # lookup-style train ops; lookup_idx only for the latter.
        if fixed_mask or needs_lookup:
            wanted.extend(["masked_lm_positions", "masked_lm_ids"])
        if needs_lookup:
            wanted.append("lookup_idx")

        if flags.use_d_segment_ids:
            wanted.append("d_segment_ids")

        if word_len > 0:
            wanted.append("selected_word")

        chosen = {name: feature_table[name] for name in wanted}

        # NOTE(review): shuffling / parallel-read policy for training vs. eval
        # is presumably decided inside format_dataset via is_training — confirm.
        return format_dataset(
            chosen,
            batch_size,
            is_training,
            flags,
            input_files,
            num_cpu_threads,
        )
コード例 #4
0
    def input_fn(params):
        """Build the input dataset with LM features plus two float loss features.

        The base feature specs are obtained from ``get_lm_basic_features`` and
        ``get_lm_mask_features``; ``loss_base`` and ``loss_target`` are added
        as float32 features of length ``flags.max_predictions_per_seq``.
        ``flags``, ``is_training``, ``input_files``, ``num_cpu_threads`` and
        ``format_dataset`` come from the enclosing scope.

        Args:
            params: Mapping that must provide a ``"batch_size"`` entry.

        Returns:
            The value returned by ``format_dataset`` for the selected
            features.
        """
        batch_size = params["batch_size"]

        # Merge both helper outputs; later entries override earlier ones.
        feature_pool = dict(get_lm_basic_features(flags))
        feature_pool.update(get_lm_mask_features(flags))

        wanted = (
            "input_ids",
            "input_mask",
            "segment_ids",
            "next_sentence_labels",
            "masked_lm_positions",
            "masked_lm_ids",
            "masked_lm_weights",
        )
        chosen = {name: feature_pool[name] for name in wanted}

        make_feature = tf.io.FixedLenFeature
        num_predictions = flags.max_predictions_per_seq
        for loss_name in ("loss_base", "loss_target"):
            chosen[loss_name] = make_feature([num_predictions], tf.float32)

        return format_dataset(
            chosen,
            batch_size,
            is_training,
            flags,
            input_files,
            num_cpu_threads,
        )
コード例 #5
0
    def input_fn(params):
        """Build the input dataset with plain and ``nli_``-prefixed inputs.

        ``next_sentence_labels`` is included only when the enclosing-scope
        variable ``use_next_sentence_labels`` is truthy.  ``flags``,
        ``is_training``, ``input_files``, ``num_cpu_threads`` and
        ``format_dataset`` also come from the enclosing scope.

        Args:
            params: Mapping that must provide a ``"batch_size"`` entry.

        Returns:
            The value returned by ``format_dataset`` for this feature spec.
        """
        batch_size = params["batch_size"]
        seq_len = flags.max_seq_length
        make_feature = tf.io.FixedLenFeature

        # All six sequence inputs share the same length and dtype.
        sequence_keys = (
            "input_ids",
            "input_mask",
            "segment_ids",
            "nli_input_ids",
            "nli_input_mask",
            "nli_segment_ids",
        )
        feature_spec = {
            key: make_feature([seq_len], tf.int64) for key in sequence_keys
        }
        feature_spec["label_ids"] = make_feature([1], tf.int64)

        if use_next_sentence_labels:
            feature_spec["next_sentence_labels"] = make_feature([1], tf.int64)

        return format_dataset(
            feature_spec,
            batch_size,
            is_training,
            flags,
            input_files,
            num_cpu_threads,
        )
コード例 #6
0
    def input_fn(params):
        """Build the input dataset for two input triples plus two strict labels.

        ``max_seq_length``, ``flags``, ``is_training``, ``input_files``,
        ``num_cpu_threads`` and ``format_dataset`` come from the enclosing
        scope.

        Args:
            params: Mapping that must provide a ``"batch_size"`` entry.

        Returns:
            The value returned by ``format_dataset`` for this feature spec.
        """
        batch_size = params["batch_size"]

        # Two (ids, mask, segment) triples, all of length max_seq_length.
        sequence_keys = (
            "input_ids1",
            "input_mask1",
            "segment_ids1",
            "input_ids2",
            "input_mask2",
            "segment_ids2",
        )
        feature_spec = {
            key: tf.io.FixedLenFeature([max_seq_length], tf.int64)
            for key in sequence_keys
        }
        # One int64 value per record for each label.
        for label_key in ("strict_good", "strict_bad"):
            feature_spec[label_key] = tf.io.FixedLenFeature([1], tf.int64)

        return format_dataset(
            feature_spec,
            batch_size,
            is_training,
            flags,
            input_files,
            num_cpu_threads,
        )