예제 #1
0
 def _setup_data_feeder(self, X, y):
     """Create the data feeder used to sample inputs from the dataset.

     If ``X`` is an iterator (e.g. a generator), a
     ``data_feeder.StreamingDataFeeder`` is built so the data is consumed
     lazily; otherwise a ``data_feeder.DataFeeder`` is built. Both are given
     ``self.n_classes`` and ``self.batch_size``, and the result is stored on
     ``self._data_feeder``.

     Args:
         X: Input features, either an iterator or a non-iterator dataset.
         y: Targets; must also be an iterator whenever ``X`` is one.

     Raises:
         ValueError: If ``X`` is an iterator but ``y`` is not.
     """
     def is_iterator(obj):
         # Python 3 iterators expose ``__next__``; Python 2 iterators expose
         # ``next``. Checking only ``next`` sent Python 3 generators down the
         # non-streaming branch.
         return hasattr(obj, '__next__') or hasattr(obj, 'next')

     if is_iterator(X):
         if not is_iterator(y):
             # Explicit check rather than ``assert`` so validation is not
             # stripped when running under ``python -O``.
             raise ValueError(
                 'y must be an iterator when X is an iterator, got %s'
                 % type(y).__name__)
         self._data_feeder = data_feeder.StreamingDataFeeder(
             X, y, self.n_classes, self.batch_size)
     else:
         self._data_feeder = data_feeder.DataFeeder(
             X, y, self.n_classes, self.batch_size)
예제 #2
0
    def test_streaming_data_feeder(self):
        """StreamingDataFeeder packs two streamed samples into one batch."""
        def features():
            # Two feature rows, streamed one at a time.
            for row in ([1, 2], [3, 4]):
                yield np.array(row)

        def targets():
            # One target per feature row, streamed in the same order.
            for label in ([1], [2]):
                yield np.array(label)

        feeder = data_feeder.StreamingDataFeeder(
            features(), targets(), n_classes=0, batch_size=2)
        input_placeholder = MockPlaceholder(name='input')
        output_placeholder = MockPlaceholder(name='output')
        make_feed_dict = feeder.get_feed_dict_fn(input_placeholder,
                                                 output_placeholder)
        batch = make_feed_dict()
        # With batch_size=2 both streamed samples land in a single batch.
        self.assertAllClose(batch['input'], [[1, 2], [3, 4]])
        self.assertAllClose(batch['output'], [1, 2])