def test_convert_empty(self):
    """Converting every record to ``None`` must leave nothing to iterate over."""
    source = BaseDataset(StubLoader(), self.SCHEMA)

    def drop_all(data):
        # One ``None`` per input record.
        return [None for _ in data]

    converted = source.convert(drop_all)

    # The converted dataset should be empty: any yielded item is a failure.
    for _ in converted:
        self.fail()
def test_static(self):
    """Check length, indexed access, tuple selection and iteration of a static dataset.

    Compared with a plain iteration check, this also verifies that the
    iteration visited every record, so the loop cannot pass vacuously.
    """
    loader = StubLoader()
    ds = BaseDataset(loader, self.SCHEMA)

    self.assertTrue(ds.is_static())
    self.assertEqual(3, len(ds))

    # ds[i] is checked through its num_values; ds.get(i) through the record
    # it returns (keyed by 'v').
    for i in range(3):
        self.assertEqual({'value': i + 1}, dict(ds[i].num_values))
    for i in range(3):
        self.assertEqual({'v': i + 1}, dict(ds.get(i)))

    # Indexing with a tuple of positions is expected to select rows 1 and 2.
    ds2 = ds[(1, 2)]
    self.assertEqual(2, len(ds2))
    for i in range(2):
        self.assertEqual({'value': i + 2}, dict(ds2[i].num_values))
    for i in range(2):
        self.assertEqual({'v': i + 2}, dict(ds2.get(i)))

    # Iteration yields (index, row) pairs with consecutive indices from 0.
    expected_idx = 0
    for (idx, row) in ds:
        self.assertEqual(expected_idx, idx)
        self.assertEqual({'value': idx + 1}, dict(row.num_values))
        expected_idx += 1

    # Without this check the loop above would pass even if iteration
    # yielded nothing or stopped early.
    self.assertEqual(3, expected_idx)
def test_nonstatic_ops(self):
    """Each listed operation must raise RuntimeError on a non-static dataset."""
    ds = BaseDataset(StubLoader(), self.SCHEMA, False)

    # One assertion per operation so a failure points at the offending call.
    self.assertRaises(RuntimeError, lambda: ds.shuffle(0))
    self.assertRaises(RuntimeError, lambda: ds.convert(lambda x: x))
    self.assertRaises(RuntimeError, lambda: ds.get(0))
    self.assertRaises(RuntimeError, lambda: len(ds))
    self.assertRaises(RuntimeError, lambda: ds[0])
def test_shuffle(self):
    """A shuffled dataset must still contain exactly the values 1, 2 and 3."""
    shuffled = BaseDataset(StubLoader(), self.SCHEMA).shuffle(0)

    # Take the value from the first (key, value) entry of each row's num_values.
    observed = [row.num_values[0][1] for (_, row) in shuffled]

    # Order is not checked, only the set of values.
    self.assertEqual([1, 2, 3], sorted(observed))
def test_convert(self):
    """convert() returns a dataset with incremented values; the source keeps its own."""
    original = BaseDataset(StubLoader(), self.SCHEMA)

    def increment_all(data):
        # Build a fresh record per input record with every value bumped by one.
        return [
            dict((key, value + 1) for (key, value) in record.items())
            for record in data
        ]

    incremented = original.convert(increment_all)

    self.assertEqual(1, original[0].num_values[0][1])
    self.assertEqual(2, incremented[0].num_values[0][1])
def test_infinite(self):
    """Iterate a dataset backed by StubInfiniteLoader and check the first 11 rows.

    The dataset must report itself as non-static. Iteration is cut off by
    an explicit ``break``; a final assertion confirms the loop really
    consumed 11 rows rather than ending early.
    """
    loader = StubInfiniteLoader()
    ds = BaseDataset(loader, self.SCHEMA)
    self.assertFalse(ds.is_static())

    expected_idx = 0
    for (idx, row) in ds:
        self.assertEqual(expected_idx, idx)
        self.assertEqual({'value': idx + 1}, dict(row.num_values))
        expected_idx += 1
        # Stop after the 11th row; the loader is not expected to end by itself.
        if 10 < expected_idx:
            break

    # Without this check the test would pass even if the loader ran dry
    # before the break (including yielding nothing at all).
    self.assertEqual(11, expected_idx)
def test_nonstatic(self):
    """Iterate a non-static dataset and check access to the row being visited.

    During iteration, both ``ds[idx]`` and ``ds.get(idx)`` are exercised
    for the index currently yielded. A final assertion confirms that the
    loop actually visited all records.
    """
    loader = StubLoader()
    ds = BaseDataset(loader, self.SCHEMA, False)
    self.assertFalse(ds.is_static())

    expected_idx = 0
    for (idx, row) in ds:
        self.assertEqual(expected_idx, idx)
        self.assertEqual({'value': idx + 1}, dict(row.num_values))
        # Indexed access is only exercised for the index just yielded.
        self.assertEqual(row.num_values, ds[idx].num_values)
        self.assertEqual({'v': idx + 1}, ds.get(idx))
        expected_idx += 1

    # Without this check zero iterations would pass. The expected count of 3
    # matches the length test_static asserts for a StubLoader-backed dataset.
    self.assertEqual(3, expected_idx)
def test_index_access(self):
    """Integer indexing yields a Datum; slicing yields a BaseDataset."""
    ds = BaseDataset(StubLoader(), self.SCHEMA)

    single = ds[0]
    self.assertTrue(isinstance(single, jubatus.common.Datum))

    sliced = ds[0:1]
    self.assertTrue(isinstance(sliced, BaseDataset))
def test_invalid_convert(self):
    """A converter that returns ``None`` must make convert() raise RuntimeError."""
    ds1 = BaseDataset(StubLoader(), self.SCHEMA)

    def returns_nothing(x):
        # Deliberately invalid: ignores its input and returns None.
        return None

    self.assertRaises(RuntimeError, ds1.convert, returns_nothing)
def test_get_schema(self):
    """get_schema() must return the schema passed to the constructor."""
    ds = BaseDataset(StubLoader(), self.SCHEMA)

    expected = self.SCHEMA
    actual = ds.get_schema()

    self.assertEqual(expected, actual)
def test_predict(self):
    """BaseDataset._predict must raise NotImplementedError (called with an empty dict)."""
    ds = BaseDataset(StubLoader(), self.SCHEMA)

    self.assertRaises(NotImplementedError, lambda: ds._predict({}))