コード例 #1
0
ファイル: test_transformers.py プロジェクト: markusnagel/fuel
 def test_one_hot_batches_invalid_input(self):
     wrapper = OneHotEncoding(
         DataStream(IndexableDataset(self.data),
                    iteration_scheme=SequentialScheme(4, 2)),
         num_classes=2,
         which_sources=('targets',))
     assert_raises(ValueError, list, wrapper.get_epoch_iterator())
コード例 #2
0
ファイル: test_transformers.py プロジェクト: basveeling/fuel
 def test_one_hot_batches_invalid_input(self):
     wrapper = OneHotEncoding(DataStream(IndexableDataset(self.data),
                                         iteration_scheme=SequentialScheme(
                                             4, 2)),
                              num_classes=2,
                              which_sources=('targets', ))
     assert_raises(ValueError, list, wrapper.get_epoch_iterator())
コード例 #3
0
ファイル: test_transformers.py プロジェクト: basveeling/fuel
 def test_one_hot_examples(self):
     wrapper = OneHotEncoding(DataStream(
         IndexableDataset(self.data),
         iteration_scheme=SequentialExampleScheme(4)),
                              num_classes=4,
                              which_sources=('targets', ))
     assert_equal(list(wrapper.get_epoch_iterator()),
                  [(numpy.ones((2, 2)), numpy.array([[1, 0, 0, 0]])),
                   (numpy.ones((2, 2)), numpy.array([[0, 1, 0, 0]])),
                   (numpy.ones((2, 2)), numpy.array([[0, 0, 1, 0]])),
                   (numpy.ones((2, 2)), numpy.array([[0, 0, 0, 1]]))])
コード例 #4
0
ファイル: test_transformers.py プロジェクト: markusnagel/fuel
 def test_one_hot_batches(self):
     wrapper = OneHotEncoding(
         DataStream(IndexableDataset(self.data),
                    iteration_scheme=SequentialScheme(4, 2)),
         num_classes=4,
         which_sources=('targets',))
     assert_equal(
         list(wrapper.get_epoch_iterator()),
         [(numpy.ones((2, 2, 2)),
           numpy.array([[1, 0, 0, 0], [0, 1, 0, 0]])),
          (numpy.ones((2, 2, 2)),
           numpy.array([[0, 0, 1, 0], [0, 0, 0, 1]]))])