Exemplo n.º 1
0
 def _makeDataset(self,
                  inputter,
                  data_file,
                  data_config=None,
                  dataset_size=1,
                  shapes=None):
     """Builds features and inputs for `inputter` from `data_file` and checks them.

     Args:
       inputter: The inputter under test.
       data_file: The data file passed to the inputter's dataset methods.
       data_config: Optional configuration; when not None it is passed to
         `inputter.initialize` before any dataset is built.
       dataset_size: Expected value of `inputter.get_dataset_size(data_file)`.
       shapes: Optional expected shapes. When given, they are checked with
         `self._checkFeatures` against both the batched features and the
         eagerly built features.

     Returns:
       The result of `self.evaluate((features, inputs))`, where `features` is
       the first batch of the mapped and batched dataset and `inputs` is
       `inputter(features, training=True)`.
     """
     if data_config is not None:
         inputter.initialize(data_config)
     self.assertEqual(inputter.get_dataset_size(data_file), dataset_size)
     dataset = inputter.make_dataset(data_file)

     # Build features eagerly from the first raw element, then add a leading
     # dimension of size 1 to each tensor so they can be validated against the
     # same `shapes` as the batch-of-1 features built below.
     eager_features = inputter.make_features(next(iter(dataset)),
                                             training=True)
     eager_features = tf.nest.map_structure(lambda t: tf.expand_dims(t, 0),
                                            eager_features)

     # Same feature construction, but through `dataset.map` instead of eagerly.
     # NOTE(review): `item_or_tuple` presumably unwraps single-component
     # elements so single- and multi-component datasets both work; confirm.
     dataset = dataset.map(lambda *args: inputter.make_features(
         item_or_tuple(args), training=True))
     # NOTE(review): presumably batches elements in groups of 1; confirm
     # against dataset_util.batch_dataset.
     dataset = dataset.apply(dataset_util.batch_dataset(1))
     features = next(iter(dataset))

     if shapes is not None:
         self._checkFeatures(features, shapes)
         self._checkFeatures(eager_features, shapes)

     # The training filter must produce a boolean value.
     keep = inputter.keep_for_training(features)
     self.assertIs(keep.dtype, tf.bool)

     inputs = inputter(features, training=True)
     # NOTE(review): ExampleInputter is excluded from the serving check; the
     # reason is not visible here.
     if not isinstance(inputter, inputters.ExampleInputter):
         self._testServing(inputter)
     return self.evaluate((features, inputs))
Exemplo n.º 2
0
 def _makeDataset(self,
                  inputter,
                  data_file,
                  data_config=None,
                  dataset_size=1,
                  shapes=None):
     """Builds features and inputs for `inputter` from `data_file`.

     Args:
       inputter: The inputter under test.
       data_file: The data file passed to `inputter.make_dataset`.
       data_config: Optional configuration; when not None it is passed to
         `inputter.initialize` before the dataset is built.
       dataset_size: Accepted but not used by this version of the helper.
       shapes: Optional expected shapes, checked with `self._checkFeatures`
         against the first batch of features.

     Returns:
       The result of `self.evaluate((features, inputs))`, where `features` is
       the first batch of the mapped and batched dataset and `inputs` is
       `inputter(features, training=True)`.
     """
     # NOTE(review): `dataset_size` is never checked here; it is kept only so
     # the signature matches callers that pass it.
     if data_config is not None:
         inputter.initialize(data_config)
     dataset = inputter.make_dataset(data_file)
     # NOTE(review): `item_or_tuple` presumably unwraps single-component
     # elements so single- and multi-component datasets both work; confirm.
     dataset = dataset.map(lambda *args: inputter.make_features(
         item_or_tuple(args), training=True))
     # NOTE(review): presumably batches elements in groups of 1; confirm
     # against dataset_util.batch_dataset.
     dataset = dataset.apply(dataset_util.batch_dataset(1))
     features = next(iter(dataset))
     if shapes is not None:
         self._checkFeatures(features, shapes)
     inputs = inputter(features, training=True)
     return self.evaluate((features, inputs))
Exemplo n.º 3
0
 def _serving_fun(features):
     """Apply the enclosing-scope `inputter` to serving `features`.

     Calls `inputter.make_features` on a shallow copy of `features`, then calls
     `inputter` on the result and returns its output.
     """
     # Pass a copy, presumably so that keys added by make_features do not
     # leak into the caller's mapping.
     prepared = inputter.make_features(features=features.copy())
     return inputter(prepared)