Example #1
 def test_size(self):
     dset = dataset.tuple_dataset(self.context,
                                  self.clean_examples, [str, str],
                                  field_separator='|',
                                  entity_separator=',',
                                  normalize_outputs=False)
     self.assertEqual(len(self.as_list(dset)), len(self.clean_examples))
Example #2
File: nell995.py  Project: yyht/language
 def dset_fn():
     dset = dataset.tuple_dataset(context, full_filename, data_specs)
     if shuffle:
         dset = dset.shuffle(1000, seed=SEED)
     if n_take > 0:
         dset = dset.take(n_take)
     dset = dset.batch(FLAGS.minibatch_size)
     dset = dset.repeat(epochs)
     return dset.map(lambda _, q, s, ans: ({'seeds': s}, ans))
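Note: the final .map() above keeps only the seed and answer columns of each four-column element and wraps the seeds in a feature dict. The sketch below is a self-contained illustration of that transformation using made-up tensors in place of the NQL encodings produced by dataset.tuple_dataset.

import tensorflow as tf

# Toy four-column dataset; the values are placeholders for illustration only.
toy = tf.data.Dataset.from_tensor_slices((
    tf.constant(['x1', 'x2']),    # first column, unused by the map
    tf.constant(['q1', 'q2']),    # question text, also unused here
    tf.random.uniform((2, 4)),    # stand-in for the seed encodings
    tf.random.uniform((2, 4)),    # stand-in for the answer encodings
))
mapped = toy.map(lambda _, q, s, ans: ({'seeds': s}, ans))
for features, answer in mapped:
    print(features['seeds'].shape, answer.shape)  # (4,) (4,)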
Example #3
 def test_idempotent_file_usage(self):
     data_tempfile = tempfile.mktemp('tuple_data')
     with open(data_tempfile, 'w') as fh:
         fh.write('\n'.join(self.clean_examples))
     dset = dataset.tuple_dataset(self.context,
                                  data_tempfile, [str, str],
                                  field_separator='|',
                                  entity_separator=',',
                                  normalize_outputs=False)
     dset = dset.repeat(2)
     self.assertEqual(len(self.as_list(dset)), 2 * len(self.clean_examples))
Example #4
 def test_lookup(self):
     dset = dataset.tuple_dataset(self.context,
                                  self.clean_examples, [str, 'uc_t'],
                                  field_separator='|',
                                  entity_separator=',',
                                  normalize_outputs=False)
     instances = self.as_list(dset)
     self.assertEqual(len(instances), len(self.clean_examples))
     exp_values = {
         b'a': NP_A,
         b'b': NP_B,
         b'c': NP_C + NP_D,
     }
     self.check_instances(instances, exp_values)
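Note: given field_separator='|' and entity_separator=',', each input line has the form "x|y1,y2,...". The fixture below is a hypothetical reconstruction consistent with exp_values; the real self.clean_examples is not shown in this listing.

# Hypothetical fixture (assumption): maps a->{a}, b->{b}, c->{c, d}.
clean_examples = ['a|a', 'b|b', 'c|c,d']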
Example #5
 def test_empty_recovery(self):
     dset = dataset.tuple_dataset(self.context,
                                  self.empty_examples, [str, 'uc_t'],
                                  field_separator='|',
                                  entity_separator=',',
                                  normalize_outputs=False)
     instances = self.as_list(dset)
     self.assertEqual(len(instances), self.empty_examples_good_count)
     exp_values = {
         b'a': NP_A,
         b'b': NP_UNK,
         b'c': NP_NONE,
     }
     self.check_instances(instances, exp_values)
Example #6
 def test_lookup(self):
     dset = dataset.tuple_dataset(self.context,
                                  self.clean_examples, [str, 'uc_t'],
                                  field_separator='|',
                                  entity_separator=',',
                                  normalize_outputs=False)
     instances = self.as_list(dset)
     self.assertEqual(len(instances), len(self.clean_examples))
     exp_values = {
         b'a': NP_A,
         b'b': NP_B,
         b'c': NP_C + NP_D,
     }
     for (s, a) in instances:
         self.assertEqual(type(s), self.tf_string_type)
         np.testing.assert_array_equal(a, exp_values[s])
Example #7
 def test_empty_recovery(self):
     dset = dataset.tuple_dataset(self.context,
                                  self.empty_examples, [str, 'uc_t'],
                                  field_separator='|',
                                  entity_separator=',',
                                  normalize_outputs=False)
     instances = self.as_list(dset)
     self.assertEqual(len(instances), self.empty_examples_good_count)
     exp_values = {
         b'a': NP_A,
         b'b': NP_UNK,
         b'c': NP_NONE,
     }
     for (s, a) in instances:
         self.assertEqual(type(s), self.tf_string_type)
         np.testing.assert_array_equal(a, exp_values[s])
Example #8
def simple_tf_dataset(context,
                      tuple_input,
                      x_type,
                      y_type,
                      normalize_outputs=False,
                      batch_size=1,
                      shuffle_buffer_size=1000,
                      feature_key=None,
                      field_separator="\t"):
    """A dataset with just two columns, x and y.

  Args:
    context: a NeuralQueryContext
    tuple_input: passed to util.tuple_dataset
    x_type: type of entities x
    y_type: type of entities y1,...,yk
    normalize_outputs: make the encoding of {y1,...,yk} sum to 1
    batch_size: size of minibatches
    shuffle_buffer_size: if zero, do not shuffle the dataset. Otherwise, this is
      passed in as argument to shuffle
    feature_key: if not None, wrap the x part of the minibatch in a dictionary
      with the given key
    field_separator: passed in to dataset.tuple_dataset

  Returns:
    a tf.data.Dataset formed by wrapping the generator
  """
    dset = dataset.tuple_dataset(context,
                                 tuple_input, [x_type, y_type],
                                 normalize_outputs=normalize_outputs,
                                 field_separator=field_separator)
    if shuffle_buffer_size > 0:
        dset = dset.shuffle(shuffle_buffer_size)
    dset = dset.batch(batch_size)
    if feature_key is None:
        return dset
    else:
        wrap_x_in_dict = lambda x, y: ({feature_key: x}, y)
        return dset.map(wrap_x_in_dict)
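Note: a minimal usage sketch for simple_tf_dataset, assuming a NeuralQueryContext named context has already been constructed and that '/tmp/pairs.tsv' contains tab-separated "x<TAB>y1,y2,..." lines. The file path, the feature key, and the y_type below are illustrative (the type 'uc_t' is borrowed from the tests above), not prescribed by the source.

# Hypothetical call; context, the path, and 'x_ids' are assumptions.
dset = simple_tf_dataset(
    context,
    '/tmp/pairs.tsv',
    str,                      # x_type, as in the tests above
    'uc_t',                   # y_type, as in the tests above
    normalize_outputs=True,
    batch_size=32,
    feature_key='x_ids')
for features, y in dset.take(1):
    print(features['x_ids'], y.shape)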
Example #9
    def dset_fn():
        """Construct a tf dataset.

    Returns:
      a tf dataset
    """

        dset = dataset.tuple_dataset(context, full_filename, data_specs)
        if shuffle:
            dset = dset.shuffle(1000, seed=SEED)
        if n_take > 0:
            dset = dset.take(n_take)
        dset = dset.batch(FLAGS.minibatch_size)
        dset = dset.repeat(epochs)

        def feature_dict_mapper(q, s):
            feature_dict = {
                question_name: tf.strings.regex_replace(q, r'\[([^]]+)\]', ''),
                'seeds': s,
            }
            return feature_dict

        question_name = get_text_module_input_name()
        return dset.map(lambda _, q, s, ans: (feature_dict_mapper(q, s), ans))
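Note: the regex in feature_dict_mapper strips bracketed spans such as "[entity mention]" from the raw question string before it reaches the text module. The snippet below checks that behavior on made-up question strings.

import tensorflow as tf

# Made-up questions; only the regex behavior is being demonstrated.
q = tf.constant(['who founded [Acme Corp] ?', 'where is [Springfield] ?'])
print(tf.strings.regex_replace(q, r'\[([^]]+)\]', '').numpy())
# [b'who founded  ?' b'where is  ?']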