コード例 #1
0
 def test_radd_operator_with_datasets_should_return_dataset_with_all_samples(
         self):
     """Check that sum() over Datasets yields a Dataset with every sample.

     sum() starts its accumulation from 0, so this exercises adding a
     Dataset to a non-Dataset left operand (the reflected add).
     """
     first_samples = [Sample(1, 1), Sample(2, 2)]
     second_samples = [Sample(3, 3), Sample(4, 4)]
     combined = sum([Dataset(list(first_samples)),
                     Dataset(list(second_samples))])
     expected = first_samples + second_samples
     self.assertListEqual(expected, combined.get_samples())
コード例 #2
0
 def test_merge_given_dataset_should_extend_original_dataset_with_given_one(
         self):
     """Check that merge() leaves the original samples first, then the merged ones."""
     own_samples = [Sample(1, 1), Sample(2, 2)]
     other_samples = [Sample(3, 3), Sample(4, 4)]
     sut = Dataset(list(own_samples))
     sut.merge(Dataset(list(other_samples)))
     expected = own_samples + other_samples
     self.assertListEqual(expected, sut.get_samples())
コード例 #3
0
 def test_get_output_given_offset_should_return_array_without_offset_first_elems(
         self):
     """Check that get_output(offset=1) skips the first sample's output."""
     # Outputs run 1..5; every sample shares the same input.
     sut = Dataset([Sample(1, output) for output in range(1, 6)])
     expected = [[output] for output in range(2, 6)]
     self.assertListEqual(expected, sut.get_output(offset=1))
コード例 #4
0
 def test_get_input_given_offset_should_return_array_without_offset_first_elems(
         self):
     """Check that get_input(offset=1) skips the first sample's input."""
     # Inputs run 1..5; no sample carries an output.
     sut = Dataset([Sample(value, None) for value in range(1, 6)])
     expected = [[value] for value in range(2, 6)]
     self.assertListEqual(expected, sut.get_input(offset=1))
コード例 #5
0
 def test_get_output_given_offset_and_num_elems_should_return_chunked_array(
         self):
     """Check that offset plus num_elems selects a window of outputs."""
     sut = Dataset([Sample(1, output) for output in range(1, 6)])
     chunk = sut.get_output(offset=1, num_elems=3)
     self.assertListEqual([[output] for output in (2, 3, 4)], chunk)
コード例 #6
0
 def test_get_input_given_offset_and_num_elems_should_return_chunked_array(
         self):
     """Check that offset plus num_elems selects a window of inputs."""
     sut = Dataset([Sample(value, None) for value in range(1, 6)])
     chunk = sut.get_input(offset=1, num_elems=3)
     self.assertListEqual([[value] for value in (2, 3, 4)], chunk)
コード例 #7
0
ファイル: persistence.py プロジェクト: ram-iyer/DeepFramework
    def load(self, path):
        """Create a Dataset object from the data saved in an HDF5 file.

        The dataset will contain plain Sample objects with the raw data.
        The input dataset (``INPUT_DATASET_NAME``) must exist in the file;
        the output dataset (``OUTPUT_DATASET_NAME``) is optional.

        Args:
            path: Path of the HDF5 file to read.

        Returns:
            Dataset: one Sample per row of the input dataset. When an output
            dataset is stored, each Sample is paired with the output row at
            the same index.
        """

        # NOTE(review): presumably the base class validates ``path``;
        # confirm against the parent PersistenceManager.
        super(H5pyPersistenceManager, self).load(path)
        with h5py.File(path, 'r') as f:
            # Read each HDF5 dataset into memory with one bulk read.
            # Iterating or indexing the h5py Dataset objects directly issues
            # a separate HDF5 read per row, which is very slow on large files.
            inputs = f[self.INPUT_DATASET_NAME][()]
            if self.OUTPUT_DATASET_NAME in f:
                outputs = f[self.OUTPUT_DATASET_NAME][()]
                samples = [Sample(sample_input, outputs[idx])
                           for idx, sample_input in enumerate(inputs)]
            else:
                samples = [Sample(sample_input) for sample_input in inputs]

        # The data is fully in memory, so the file can be closed first.
        return Dataset(samples)
コード例 #8
0
 def test_save_given_dataset_with_input_output_should_persist_both(self):
     """Check that saving samples that have outputs writes both HDF5 datasets."""
     self.sut.save(Dataset([Sample([1, 2], 1), Sample([3, 4], 3)]),
                   self.file_path)
     expected_names = (H5pyPersistenceManager.INPUT_DATASET_NAME,
                       H5pyPersistenceManager.OUTPUT_DATASET_NAME)
     with h5py.File(self.file_path, 'r') as h5_file:
         for dataset_name in expected_names:
             self.assertIn(dataset_name, h5_file)
コード例 #9
0
 def test_get_input_given_axis_samples_false_should_return_array_with_input_as_first_axis(
         self):
     """Check that axis_samples=False returns the inputs transposed (feature-major)."""
     sut = Dataset([Sample([1, 2], 1), Sample([3, 4], 3)])
     feature_major = sut.get_input(axis_samples=False)
     self.assertListEqual([[1, 3], [2, 4]], feature_major)
コード例 #10
0
 def test_get_input_given_axis_samples_false_and_inconsisten_samples_should_raise_exception(
         self):
     """Check that get_input(axis_samples=False) raises ValueError on mixed input shapes."""
     # One sample has a scalar input, the other a two-element list.
     samples = [Sample(1, 1), Sample([3, 4], 3)]
     sut = Dataset(samples)
     # Pass axis_samples by keyword. A positional False binds to get_input's
     # first parameter, which may be offset (see the offset= tests), and then
     # the test would not exercise axis_samples at all.
     self.assertRaises(ValueError, sut.get_input, axis_samples=False)
コード例 #11
0
 def test_get_output_without_output_should_raise_exception(self):
     """Check that a Sample built without an output raises AttributeError on get_output()."""
     output_less = Sample(1)
     with self.assertRaises(AttributeError):
         output_less.get_output()
コード例 #12
0
 def test_load_given_existing_path_should_return_dataset(self):
     """Check that loading a file containing a pickled Dataset returns it with both samples."""
     # Pickle data must be written through a binary handle ('wb'); text mode
     # can apply platform newline translation and corrupt the stream.
     with open(self.file_path, 'wb') as f:
         cPickle.dump(Dataset([Sample([1, 2]), Sample([3, 4])]), f)
     dataset = self.sut.load(self.file_path)
     self.assertIsInstance(dataset, Dataset)
     # assertEqual reports the actual length on failure; assertTrue does not.
     self.assertEqual(2, dataset.len())
コード例 #13
0
 def test_get_data_with_none_should_return_none(self):
     """Check that _get_data passes None straight through."""
     extracted = Sample._get_data(None)
     self.assertIsNone(extracted)
コード例 #14
0
 def test_get_output_without_output_samples_should_raise_exception(self):
     """Check that Dataset.get_output raises TypeError when no sample has an output."""
     sut = Dataset([Sample(value) for value in (1, 2)])
     with self.assertRaises(TypeError):
         sut.get_output()
コード例 #15
0
 def test_get_output_given_axis_samples_true_should_return_array_with_sample_as_first_axis(
         self):
     """Check that axis_samples=True returns one output row per sample."""
     sut = Dataset([Sample(1, [1, 2]), Sample(3, [3, 4])])
     sample_major = sut.get_output(axis_samples=True)
     self.assertListEqual([[1, 2], [3, 4]], sample_major)
コード例 #16
0
 def test_get_data_with_primitive_should_return_primitive(self):
     """Check that _get_data wraps a single primitive in a one-element list."""
     primitive = 1
     wrapped = Sample._get_data(primitive)
     self.assertEqual([primitive], wrapped)
コード例 #17
0
 def test_get_data_with_list_primitives_should_return_list_primitives(self):
     """Check that a list of primitives comes back with the same elements."""
     primitives = [1, 2, 3]
     extracted = Sample._get_data(primitives)
     self.assertItemsEqual(primitives, extracted)
コード例 #18
0
 def test_save_given_dataset_should_persist_it(self):
     """Check that save() writes a file that unpickles back into a Dataset."""
     dataset = Dataset([Sample([1, 2]), Sample([3, 4])])
     self.sut.save(dataset, self.file_path)
     # Read the pickle through a binary handle ('rb'); text mode can apply
     # newline translation and break binary pickle protocols.
     with open(self.file_path, 'rb') as f:
         d = cPickle.load(f)
         self.assertIsInstance(d, Dataset)
コード例 #19
0
 def test_get_data_with_mixed_list_should_return_list_data(self):
     """Check that a list mixing primitives and value objects yields their raw data."""
     mixed = [1, DummyValue([1, 2, 3])]
     extracted = Sample._get_data(mixed)
     self.assertItemsEqual([1, [1, 2, 3]], extracted)
コード例 #20
0
 def test_get_data_with_list_value_objects_should_return_list_data(self):
     """Check that a list of value objects yields a plain list of their data."""
     payload = [1, 2, 3]
     # Both value objects deliberately share the same payload list.
     value_objects = [DummyValue(payload) for _ in range(2)]
     self.assertIsInstance(Sample._get_data(value_objects), list)
     self.assertItemsEqual([payload, payload],
                           Sample._get_data(value_objects))
コード例 #21
0
 def test_get_data_with_value_object_should_return_its_data(self):
     """Check that a single value object yields its data wrapped in a list."""
     payload = [1, 2, 3]
     extracted = Sample._get_data(DummyValue(payload))
     self.assertItemsEqual([payload], extracted)