def _setup_data_feeder(self, X, y):
    """Create the data feeder used to sample inputs from the dataset.

    If ``X`` is an iterator, a ``data_feeder.StreamingDataFeeder`` is
    created over ``X`` and ``y``. Otherwise a ``data_feeder.DataFeeder`` is
    created. Either way the feeder is built with ``self.n_classes`` and
    ``self.batch_size`` and stored on ``self._data_feeder``.

    Args:
        X: Input features, either an in-memory dataset or an iterator.
        y: Targets; must also be an iterator when ``X`` is one.

    Raises:
        ValueError: If ``X`` is an iterator but ``y`` is not.
    """
    def _is_iterator(obj):
        # Python 3 iterators (including generators) expose ``__next__``;
        # ``next`` is the Python 2 spelling, kept for backward compatibility.
        return hasattr(obj, '__next__') or hasattr(obj, 'next')

    if _is_iterator(X):
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if not _is_iterator(y):
            raise ValueError(
                'X is an iterator, so y must be an iterator as well.')
        self._data_feeder = data_feeder.StreamingDataFeeder(
            X, y, self.n_classes, self.batch_size)
    else:
        self._data_feeder = data_feeder.DataFeeder(
            X, y, self.n_classes, self.batch_size)
def test_streaming_data_feeder(self):
    # Two feature rows and their two targets, produced lazily by generators
    # so the feeder is exercised on genuine iterators.
    features = (np.array(row) for row in ([1, 2], [3, 4]))
    targets = (np.array(label) for label in ([1], [2]))
    feeder = data_feeder.StreamingDataFeeder(
        features, targets, n_classes=0, batch_size=2)

    input_placeholder = MockPlaceholder(name='input')
    output_placeholder = MockPlaceholder(name='output')
    make_feed_dict = feeder.get_feed_dict_fn(
        input_placeholder, output_placeholder)

    # With batch_size=2, a single call should drain both generators.
    batch = make_feed_dict()
    self.assertAllClose(batch['input'], [[1, 2], [3, 4]])
    self.assertAllClose(batch['output'], [1, 2])