Esempio n. 1
0
  def run_test_case(self, dataset, truth_data, batch_size=None):
    """Iterate `dataset` and compare every produced row against `truth_data`.

    Args:
      dataset: dataset under test. For each column in `dataset.columns`, the
        fetched element is indexed positionally (i-th column -> value[i]).
      truth_data: expected results. `truth_data.data[col][row]` is the
        expected value for a column/row, and `output_shapes[col]` /
        `output_types[col]` decide how that column is compared.
      batch_size: if None, each `sess.run` is treated as yielding one row.
        Otherwise each `sess.run` is treated as yielding a batch, and the
        rows are unpacked from it one at a time, `batch_size` rows per fetch.

    Only columns whose expected shape has rank 0 or 1 are checked. Columns
    with a float16/32/64 dtype are compared to 4 decimal places; all other
    columns must match exactly. Columns of any other rank are skipped.
    """
    iterator = data.make_one_shot_iterator(dataset)
    next_element = iterator.get_next()

    def is_float(dtype):
      return dtype in [dtypes.float16, dtypes.float32, dtypes.float64]

    # Start "full" so that the very first row triggers a batch fetch.
    batch_counter = batch_size

    with self.test_session() as sess:
      # NOTE(review): the row count is taken from truth_data.data[0]; this
      # assumes key/column 0 exists and that all columns have equal length.
      for row in range(len(truth_data.data[0])):
        if batch_size is None:
          value = sess.run(next_element)
        else:
          if batch_counter == batch_size:
            # Current batch exhausted: fetch the next one.
            value_batch = sess.run(next_element)
            batch_counter = 0
          # Slice this row out of every component of the batch.
          value = [v[batch_counter] for v in value_batch]
          batch_counter += 1
        for i, col in enumerate(dataset.columns):
          if truth_data.output_shapes[col].ndims == 0:
            if is_float(truth_data.output_types[col]):
              self.assertAlmostEqual(value[i], truth_data.data[col][row], 4)
            else:
              self.assertEqual(value[i], truth_data.data[col][row])
          elif truth_data.output_shapes[col].ndims == 1:
            if is_float(truth_data.output_types[col]):
              for j, v in enumerate(value[i]):
                self.assertAlmostEqual(v, truth_data.data[col][row][j], 4)
            else:
              self.assertListEqual(value[i].tolist(), truth_data.data[col][row])
Esempio n. 2
0
    def run_test_case(self, dataset, truth_data):
        """Fetch one element per expected row and check it against truth_data.

        For every column in `dataset.columns`, the positionally matching
        component of the fetched element is compared with
        `truth_data.data[col][row]`. Only rank-0 and rank-1 columns are
        checked. Float16/32/64 columns are compared to 4 decimal places;
        all other columns must match exactly.
        """
        next_element = data.make_one_shot_iterator(dataset).get_next()

        def _is_floating(dtype):
            # These three dtypes get an approximate comparison.
            return dtype in (dtypes.float16, dtypes.float32, dtypes.float64)

        with self.test_session() as sess:
            num_rows = len(truth_data.data[0])
            for row in range(num_rows):
                fetched = sess.run(next_element)
                for idx, col in enumerate(dataset.columns):
                    rank = truth_data.output_shapes[col].ndims
                    if rank not in (0, 1):
                        # Higher or unknown rank: not checked.
                        continue
                    floating = _is_floating(truth_data.output_types[col])
                    expected = truth_data.data[col][row]
                    actual = fetched[idx]
                    if rank == 0:
                        if floating:
                            self.assertAlmostEqual(actual, expected, 4)
                        else:
                            self.assertEqual(actual, expected)
                    elif floating:
                        for pos, item in enumerate(actual):
                            self.assertAlmostEqual(item, expected[pos], 4)
                    else:
                        self.assertListEqual(actual.tolist(), expected)
Esempio n. 3
0
 def pipeline(*args):
     """Append the next (example, label) pair to `args`, then run every stage.

     `dataset` and `stages` are not parameters; they are resolved from the
     enclosing scope. Each stage is called with the previous outputs, after
     converting them to a list, as positional arguments. The result of the
     last stage is returned.
     """
     # TF2 replacement for: iterator = dataset.make_one_shot_iterator()
     one_shot = compat_v1_data.make_one_shot_iterator(dataset)
     example, label = one_shot.get_next()
     stage_io = functional_ops._convert_to_list(args)  # pylint: disable=W0212
     stage_io += [example, label]
     for run_stage in stages:
         stage_io = run_stage(
             *functional_ops._convert_to_list(stage_io))  # pylint: disable=W0212
     return stage_io
Esempio n. 4
0
    def testBufferDataset(self):
        """A BufferDataset of size 2 over a 3-element dataset yields 2 values.

        The first fetch is expected to start with zeros and the second with
        ones. The third fetch must raise OutOfRangeError.
        """
        dataset = tu.create_single_increasing_dataset(10, shape=[4, 4])
        dataset = dataset.take(3)
        dataset = ipu.data.ops.dataset_ops.BufferDataset(dataset, 2)
        itr = compat_v1_data.make_one_shot_iterator(dataset)

        next_data = itr.get_next()
        with self.session() as sess:
            # Only entry [0] of each fetched result is compared.
            self.assertAllEqual(sess.run(next_data)[0], np.zeros([4, 4]))
            self.assertAllEqual(sess.run(next_data)[0], np.ones([4, 4]))
            # sess.run itself must raise. Wrapping its (already numpy) result
            # in self.evaluate would only mask a missing exception with an
            # unrelated fetch error.
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(next_data)