def check_dict_dataset(self, x, y):
    """Assert that a DictDataset built from ``x`` and ``y`` mirrors them.

    The dataset must report the same length as ``x``, and every example
    must be a mapping holding the ``'x'`` and ``'y'`` entries that equal
    the corresponding items of the inputs.
    """
    dataset = datasets.DictDataset(x=x, y=y)
    self.assertEqual(len(dataset), len(x))

    # Expected source array for each key of an example.
    sources = (('x', x), ('y', y))

    for index in range(len(x)):
        item = dataset[index]

        # Check that both keys exist before comparing any values.
        for key, _ in sources:
            self.assertIn(key, item)

        # Move both sides to the CPU so they can be compared with numpy.
        for key, source in sources:
            numpy.testing.assert_array_equal(
                cuda.to_cpu(item[key]), cuda.to_cpu(source[index]))
def test_dict_dtaset_overrun(self):
    """Indexing the dataset at position 3 must raise ``IndexError``.

    NOTE(review): presumably the ``self.x`` / ``self.y`` fixtures hold
    three examples, making 3 the first out-of-range index -- confirm
    against ``setUp``.
    """
    # NOTE(review): "dtaset" in the method name looks like a typo for
    # "dataset"; it is kept as-is so existing test IDs do not change.
    dataset = datasets.DictDataset(x=self.x, y=self.y)
    self.assertRaises(IndexError, lambda: dataset[3])
def test_dict_dataset_len_mismatch(self):
    """Building a DictDataset from ``self.x`` and ``self.z`` must fail.

    NOTE(review): presumably the two fixtures have different lengths,
    which is what triggers the ``ValueError`` -- confirm against
    ``setUp``.
    """
    # Callable form: assertRaises forwards the keyword arguments to the
    # DictDataset constructor.
    self.assertRaises(
        ValueError, datasets.DictDataset, x=self.x, z=self.z)