def test_inconsistent_samples(self):
    """load_raw raises IndexError for a class file inconsistent with the data.

    Overwrites the class file with a 3-cluster header (cat, dog, bear) and
    four labelled samples with ids 0-3, one of which carries label 3.
    NOTE(review): presumably the extra sample and/or the label outside a
    0-based range of 3 clusters is what conflicts with the fixture; the exact
    trigger lives in load_raw and is not visible here.
    """
    with open(self.classfile, 'w') as fh:
        fh.write("% This is a file\n%Clusters 3 cat,dog,bear\n% foobar\n")
        fh.write("0 0\n1 1\n2 1\n3 3")
    td = TrainingData()
    # Keep only load_raw inside the context manager: an IndexError raised by
    # the constructor must not be able to make this test pass.
    with self.assertRaises(IndexError):
        td.load_raw(self.dirpath)
def test_spit(self):
    """train_test_split with shuffle=False keeps row order.

    With test_frac=0.33 the first two loaded rows are expected in the
    training set and the last row in the test set.
    """
    # NOTE(review): the name looks like a typo for "test_split"; it is kept
    # as-is so the test id used by runners and CI filters stays stable.
    exp_trainx = numpy.array([[1.0, 2.0], [0.0, 2.0]])
    exp_trainy = numpy.array([0, 1])
    exp_testx = numpy.array([[5.0, 0]])
    exp_testy = numpy.array([1])
    td = TrainingData()
    td.load_raw(self.dirpath)
    trainx, trainy, testx, testy = td.train_test_split(test_frac=0.33, shuffle=False)

    def check(label, expected, actual):
        # Same comparison as before (shape and values via numpy.array_equal),
        # but a failure now names the array and shows both values instead of
        # only "False is not true".
        self.assertTrue(
            numpy.array_equal(expected, actual),
            '{}: expected {!r}, got {!r}'.format(label, expected, actual),
        )

    check('trainx', exp_trainx, trainx)
    check('trainy', exp_trainy, trainy)
    check('testx', exp_testx, testx)
    check('testy', exp_testy, testy)
def test_export_csv(self):
    """export() writes a file whose contents match the reference CSV fixture.

    The data is loaded from the raw fixture directory and exported to
    'export.csv' inside that same directory.
    """
    export_file = os.path.join(self.dirpath, 'export.csv')
    td = TrainingData()
    td.load_raw(self.dirpath)
    td.export(export_file)
    # shallow=False forces a byte-by-byte comparison. The default shallow mode
    # declares files equal when their os.stat() signatures (type, size, mtime)
    # match, without reading them. A wrong export of the right size written
    # within the filesystem's mtime granularity could then pass.
    self.assertTrue(filecmp.cmp(self.csvfile, export_file, shallow=False))
def test_load_csv(self):
    """Loading the reference CSV reproduces the expected dataframe."""
    loaded = TrainingData()
    loaded.load_csv(self.csvfile)
    expected = self.expected_df
    actual = loaded.dataframe
    self.assertTrue(expected.equals(actual))
def test_feature_names(self):
    """After load_raw, feature_names equals the fixture's names in order."""
    wanted_names = ['fluffy', 'scary']
    loaded = TrainingData()
    loaded.load_raw(self.dirpath)
    self.assertEqual(wanted_names, loaded.feature_names)
def test_max_value(self):
    """After load_raw, feature_max_value reports 5 for the fixture data."""
    wanted_max = 5
    loaded = TrainingData()
    loaded.load_raw(self.dirpath)
    self.assertEqual(wanted_max, loaded.feature_max_value)
def test_inconsistent_terms(self):
    """load_raw raises IndexError when the terms file lists an extra term.

    Overwrites the terms file with three terms (fluffy, scary, toothy). The
    unmodified fixture yields only two feature names (fluffy, scary).
    NOTE(review): presumably the third term has no matching data column;
    confirm against load_raw.
    """
    with open(self.termsfile, 'w') as fh:
        fh.write("fluffy\nscary\ntoothy\n")
    td = TrainingData()
    # Only load_raw is expected to raise; constructing outside the context
    # keeps a constructor IndexError from producing a false pass.
    with self.assertRaises(IndexError):
        td.load_raw(self.dirpath)
def test_missing_file(self):
    """load_raw raises IOError when the matrix file has been removed."""
    os.remove(self.mtxfile)
    td = TrainingData()
    # Narrow the expectation to load_raw itself, so an IOError from anywhere
    # else cannot satisfy the assertion.
    with self.assertRaises(IOError):
        td.load_raw(self.dirpath)
def test_good_case(self):
    """A well-formed raw fixture directory loads into the expected dataframe."""
    loaded = TrainingData()
    loaded.load_raw(self.dirpath)
    expected = self.expected_df
    actual = loaded.dataframe
    self.assertTrue(expected.equals(actual))
def test_split_negative(self):
    """train_test_split rejects a negative test fraction with ValueError."""
    loaded = TrainingData()
    loaded.load_raw(self.dirpath)
    self.assertRaises(
        ValueError,
        loaded.train_test_split,
        test_frac=-0.5,
        shuffle=False,
    )
def test_clean(self):
    """After clean(), the dataframe attribute is None."""
    loaded = TrainingData()
    loaded.load_raw(self.dirpath)
    loaded.clean()
    remaining = loaded.dataframe
    self.assertIsNone(remaining)