def test_remove_give_list_with_existent_samples_should_remove_them(self):
    """Removing a list of present samples drops exactly those samples."""
    original = [1, 2, 3]
    sut = Dataset(original[:])

    sut.remove([1, 2])

    # Everything that was not asked to be removed must survive, in order.
    expected = [item for item in original if item not in (1, 2)]
    self.assertListEqual(expected, sut.get_samples())
def test_radd_operator_with_datasets_should_return_dataset_with_all_samples(
        self):
    """sum() over datasets concatenates their samples in order.

    sum() starts from the integer 0, so the first addition exercises
    Dataset's reflected addition (0 + dataset).
    """
    first_samples = [Sample(1, 1), Sample(2, 2)]
    second_samples = [Sample(3, 3), Sample(4, 4)]
    datasets = [Dataset(first_samples[:]), Dataset(second_samples[:])]

    combined = sum(datasets)

    self.assertListEqual(first_samples + second_samples,
                         combined.get_samples())
def test_radd_operator_with_datasets_should_return_new_dataset_instance(
        self):
    """sum() over datasets must build a fresh Dataset, not reuse an operand.

    Identity is checked with assertIsNot. assertNotEqual would go through
    ``__eq__``: if Dataset ever compares by content, two empty datasets
    would be equal and the test would fail even though a new object was
    returned. The property under test is object identity, not value.
    """
    d1 = Dataset()
    d2 = Dataset()
    d3 = sum([d1, d2])
    self.assertIsInstance(d3, Dataset)
    self.assertIsNot(d3, d1)
    self.assertIsNot(d3, d2)
def test_get_output_given_offset_should_return_array_without_offset_first_elems(
        self):
    """get_output(offset=1) skips the first sample's output."""
    # Outputs 1..5, all with the same input.
    sut = Dataset([Sample(1, value) for value in range(1, 6)])
    expected = [[value] for value in range(2, 6)]
    self.assertListEqual(expected, sut.get_output(offset=1))
def test_get_input_given_offset_should_return_array_without_offset_first_elems(
        self):
    """get_input(offset=1) skips the first sample's input."""
    # Inputs 1..5, no outputs.
    sut = Dataset([Sample(value, None) for value in range(1, 6)])
    expected = [[value] for value in range(2, 6)]
    self.assertListEqual(expected, sut.get_input(offset=1))
def test_get_input_given_offset_and_num_elems_should_return_chunked_array(
        self):
    """get_input(offset, num_elems) returns a window of the inputs."""
    sut = Dataset([Sample(value, None) for value in range(1, 6)])
    chunk = sut.get_input(offset=1, num_elems=3)
    self.assertListEqual([[value] for value in range(2, 5)], chunk)
def test_get_output_given_offset_and_num_elems_should_return_chunked_array(
        self):
    """get_output(offset, num_elems) returns a window of the outputs."""
    sut = Dataset([Sample(1, value) for value in range(1, 6)])
    chunk = sut.get_output(offset=1, num_elems=3)
    self.assertListEqual([[value] for value in range(2, 5)], chunk)
def test_merge_given_dataset_should_extend_original_dataset_with_given_one(
        self):
    """merge() appends the other dataset's samples after the existing ones."""
    own = [Sample(1, 1), Sample(2, 2)]
    other = [Sample(3, 3), Sample(4, 4)]
    sut = Dataset(own[:])

    sut.merge(Dataset(other[:]))

    self.assertListEqual(own + other, sut.get_samples())
def load(self, path):
    """Creates a Dataset object from the data saved in HDF5 file.

    The dataset will contain plain Sample objects with the raw data.

    Each stored HDF5 dataset is read into memory with a single ``[()]``
    read, instead of one read per row while iterating.

    :param path: path of the HDF5 file to read.
    :returns: a Dataset whose samples pair each stored input with the
        output at the same index, or carry only the input when no output
        dataset was stored.
    :raises ValueError: if the stored inputs and outputs differ in length.
    """
    super(H5pyPersistenceManager, self).load(path)
    with h5py.File(path, 'r') as f:
        inputs = f[self.INPUT_DATASET_NAME][()]
        if self.OUTPUT_DATASET_NAME in f:
            outputs = f[self.OUTPUT_DATASET_NAME][()]
        else:
            outputs = None

    if outputs is None:
        return Dataset([Sample(sample_input) for sample_input in inputs])

    # Fail loudly on a malformed file. Otherwise inputs and outputs could be
    # silently mispaired or truncated.
    if len(inputs) != len(outputs):
        raise ValueError(
            'Mismatched HDF5 data: %d inputs but %d outputs'
            % (len(inputs), len(outputs)))
    return Dataset([Sample(sample_input, sample_output)
                    for sample_input, sample_output in zip(inputs, outputs)])
def test_get_input_given_axis_samples_false_should_return_array_with_input_as_first_axis(
        self):
    """With axis_samples=False the inputs come back transposed."""
    sut = Dataset([Sample([1, 2], 1), Sample([3, 4], 3)])
    transposed = sut.get_input(axis_samples=False)
    self.assertListEqual([[1, 3], [2, 4]], transposed)
def test_get_input_given_invalid_samples_should_raise_exception(self):
    """get_input on samples that are plain ints (not Sample) raises TypeError."""
    sut = Dataset([1, 2, 3])
    with self.assertRaises(TypeError):
        sut.get_input()
def test_construct_given_none_should_initialize_samples_as_list(self):
    """A Dataset built without arguments exposes an empty list of samples."""
    stored = Dataset().get_samples()
    self.assertIsInstance(stored, list)
    self.assertListEqual([], stored)
def test_remove_given_list_with_unexistent_sample_should_raise_exception(
        self):
    """remove() with a list containing an absent sample raises ValueError."""
    sut = Dataset([1, 2, 3])
    with self.assertRaises(ValueError):
        sut.remove([1, 4])
def test_load_given_existing_path_should_return_dataset(self):
    """load() returns the Dataset previously pickled to the file.

    The pickle is written in binary mode: pickle data is bytes, and text
    mode can translate line endings on some platforms and corrupt it.
    """
    with open(self.file_path, 'wb') as f:
        cPickle.dump(Dataset([Sample([1, 2]), Sample([3, 4])]), f)
    dataset = self.sut.load(self.file_path)
    self.assertIsInstance(dataset, Dataset)
    # Unlike assertTrue(... == 2), assertEqual reports the actual length on failure.
    self.assertEqual(2, dataset.len())
def test_merge_given_non_dataset_should_raise_exception(self):
    """merge() rejects arguments that are not Dataset instances."""
    sut = Dataset()
    with self.assertRaises(TypeError):
        sut.merge('no dataset')
def test_remove_given_existent_sample_should_remove_it(self):
    """Removing one present sample leaves the remaining ones in order."""
    sut = Dataset([1, 2, 3])
    sut.remove(3)
    self.assertListEqual([1, 2], sut.get_samples())
def test_add_given_single_item_should_add_elem(self):
    """add() with a single item appends it to the end."""
    initial = [1, 2, 3]
    sut = Dataset(initial[:])
    sut.add(4)
    self.assertListEqual(initial + [4], sut.get_samples())
def test_add_given_list_should_add_elements(self):
    """add() with a list appends each of its elements in order."""
    initial = [1, 2, 3]
    sut = Dataset(initial[:])
    sut.add([4, 5, 6])
    self.assertListEqual(initial + [4, 5, 6], sut.get_samples())
def test_get_output_given_axis_samples_true_should_return_array_with_sample_as_first_axis(
        self):
    """With axis_samples=True each row of the result is one sample's output."""
    sut = Dataset([Sample(1, [1, 2]), Sample(3, [3, 4])])
    per_sample = sut.get_output(axis_samples=True)
    self.assertListEqual([[1, 2], [3, 4]], per_sample)
def test_add_give_empty_list_should_do_nothing(self):
    """add([]) leaves the existing samples untouched."""
    sut = Dataset([1, 2, 3])
    sut.add([])
    self.assertListEqual([1, 2, 3], sut.get_samples())
def test_save_given_dataset_should_persist_it(self):
    """save() writes a pickle that loads back as a Dataset of the same size.

    The file is read in binary mode because pickle data is bytes; text
    mode can translate line endings on some platforms and corrupt it.
    """
    dataset = Dataset([Sample([1, 2]), Sample([3, 4])])
    self.sut.save(dataset, self.file_path)
    with open(self.file_path, 'rb') as f:
        d = cPickle.load(f)
    self.assertIsInstance(d, Dataset)
    # Check that the samples were persisted, not only the type.
    self.assertEqual(dataset.len(), d.len())
def test_add_operator_with_non_dataset_should_raise_exception(self):
    """Dataset.__add__ raises TypeError for a non-Dataset operand."""
    sut = Dataset()
    # Call __add__ directly rather than using `+`, so the check targets the
    # method itself and not Python's reflected-operand fallback.
    with self.assertRaises(TypeError):
        sut.__add__(1)
def test_construct_given_list_should_assign_samples(self):
    """Samples passed to the constructor are exposed as a list of the same items."""
    items = [1, 2, 3]
    sut = Dataset(items)
    result = sut.get_samples()
    self.assertIsInstance(result, list)
    self.assertListEqual(items, result)
def test_get_input_given_axis_samples_false_and_inconsisten_samples_should_raise_exception(
        self):
    """Transposing samples whose inputs differ in shape raises ValueError.

    axis_samples is passed by keyword. get_input also accepts offset and
    num_elems, and their positional order is not fixed by this test. A bare
    positional False could bind to one of those instead, so the test would
    not exercise axis_samples=False at all.
    """
    samples = [Sample(1, 1), Sample([3, 4], 3)]
    sut = Dataset(samples)
    self.assertRaises(ValueError, sut.get_input, axis_samples=False)
def test_remove_given_unexistent_sample_should_raise_exception(self):
    """remove() of a sample that is not present raises ValueError."""
    sut = Dataset([1, 2, 3])
    with self.assertRaises(ValueError):
        sut.remove(4)
def test_save_given_dataset_with_input_output_should_persist_both(self):
    """Saving samples with outputs writes both the input and output HDF5 datasets."""
    self.sut.save(Dataset([Sample([1, 2], 1), Sample([3, 4], 3)]),
                  self.file_path)
    with h5py.File(self.file_path, 'r') as stored:
        for name in (H5pyPersistenceManager.INPUT_DATASET_NAME,
                     H5pyPersistenceManager.OUTPUT_DATASET_NAME):
            self.assertIn(name, stored)
def test_get_output_without_output_samples_should_raise_exception(self):
    """get_output raises TypeError when samples were built without outputs."""
    sut = Dataset([Sample(1), Sample(2)])
    with self.assertRaises(TypeError):
        sut.get_output()