def test_dump_qid(self):
    """Round-trip a dataset with query ids through dump_svmlight_file.

    Loads ``qid_datafile`` with ``query_id=True``, dumps it to a temporary
    file, reloads that file with ``sk_load_svmlight_file`` and checks that
    the features, labels and query ids match what was loaded originally.
    """
    # Local imports keep this method self-contained; the rest of the file's
    # import block is not visible from here.
    import shutil
    import tempfile

    # A private scratch directory avoids the hard-coded "/tmp" path, which is
    # not portable and can collide with concurrent runs or leftover files.
    tmpdir = tempfile.mkdtemp()
    tmpfile = os.path.join(tmpdir, "tmp_dump.txt")
    try:
        # Load the reference data from file.
        Xs, y, q = load_svmlight_file(qid_datafile, query_id=True)

        # Dump it back out to the scratch file.
        dump_svmlight_file(Xs, y, tmpfile, query_id=list(q),
                           zero_based=False)

        # Reload as a sparse matrix and densify it for comparison.
        X2, y2, q2 = sk_load_svmlight_file(tmpfile, query_id=True)
        X3 = X2.toarray()

        # The reloaded data must match the original.
        assert_array_almost_equal(Xs, X3)
        assert_array_almost_equal(y, y2)
        assert_array_equal(q, q2)
    finally:
        # Always remove the scratch directory, even when an assertion fails.
        shutil.rmtree(tmpdir, ignore_errors=True)
def test_load_svmlight_file_descriptor(self):
    """Load ``datafile`` through an already-open binary file object.

    Checks that the loaded feature matrix has shape (3, 20) and that the
    labels are [1, 2, 3].
    """
    with open(datafile, 'rb') as stream:
        features, labels = load_svmlight_file(stream)

    # Shape of the feature matrix.
    expected_shape = (3, 20)
    assert_array_equal(features.shape, expected_shape)

    # Label vector.
    expected_labels = [1, 2, 3]
    assert_array_equal(labels, expected_labels)
def test_load_svmlight_file(self):
    """Load ``datafile`` by name and check its shape and labels.

    Only the matrix shape (3, 20) and the labels [1, 2, 3] are verified;
    individual feature values are not checked here.
    """
    features, labels = load_svmlight_file(datafile)

    # Shape of the feature matrix.
    expected_shape = (3, 20)
    assert_array_equal(features.shape, expected_shape)

    # Label vector.
    expected_labels = [1, 2, 3]
    assert_array_equal(labels, expected_labels)
def test_load_svmlight_file_empty_qid(self):
    """Load ``datafile`` with ``query_id=True`` and expect no query ids.

    Verifies the matrix shape (3, 20), the labels [1, 2, 3], and that the
    returned query-id array has length zero along its first axis.
    """
    features, labels, query_ids = load_svmlight_file(datafile,
                                                     query_id=True)

    # Shape of the feature matrix.
    expected_shape = (3, 20)
    assert_array_equal(features.shape, expected_shape)

    # Label vector.
    expected_labels = [1, 2, 3]
    assert_array_equal(labels, expected_labels)

    # The query-id array must be empty.
    n_query_ids = query_ids.shape[0]
    assert_equal(n_query_ids, 0)
def test_load_svmlight_qid_file(self):
    """Load ``qid_datafile`` with ``query_id=True`` and check its contents.

    Verifies the matrix shape (4, 33), the labels [1, 2, 0, 3] and the
    query ids [1, 37, 37, 12].
    """
    features, labels, query_ids = load_svmlight_file(qid_datafile,
                                                     query_id=True)

    # Shape of the feature matrix.
    expected_shape = (4, 33)
    assert_array_equal(features.shape, expected_shape)

    # Label vector.
    expected_labels = [1, 2, 0, 3]
    assert_array_equal(labels, expected_labels)

    # Query ids, one per row.
    expected_query_ids = [1, 37, 37, 12]
    assert_array_equal(query_ids, expected_query_ids)
def test_dump(self):
    """Round-trip a dataset through dump_svmlight_file.

    Loads ``datafile``, dumps it to a temporary file, reloads that file with
    ``sk_load_svmlight_file`` and checks that the features and labels match
    what was loaded originally.
    """
    # Local imports keep this method self-contained; the rest of the file's
    # import block is not visible from here.
    import shutil
    import tempfile

    # A private scratch directory avoids writing "tmp_dump.txt" into the
    # current working directory, which may be read-only or shared with a
    # concurrent test run.
    tmpdir = tempfile.mkdtemp()
    tmpfile = os.path.join(tmpdir, "tmp_dump.txt")
    try:
        # Load the reference data from file.
        Xs, y = load_svmlight_file(datafile)

        # Dump it back out to the scratch file.
        dump_svmlight_file(Xs, y, tmpfile, zero_based=False)

        # Reload as a sparse matrix and densify it for comparison.
        X2, y2 = sk_load_svmlight_file(tmpfile)
        X3 = X2.toarray()

        # The reloaded data must match the original.
        assert_array_almost_equal(Xs, X3)
        assert_array_almost_equal(y, y2)
    finally:
        # Always remove the scratch directory, even when an assertion fails.
        shutil.rmtree(tmpdir, ignore_errors=True)
def test_load_invalid_file(self):
    """Check that loading ``invalidfile`` raises RuntimeError.

    Any other exception type propagates and fails the test.
    """
    try:
        load_svmlight_file(invalidfile)
    except RuntimeError:
        # Expected outcome: the loader rejected the file.
        return

    # Raise explicitly instead of using ``assert False``: assert statements
    # are stripped under ``python -O``, which would let this test pass
    # silently when no exception is raised.
    raise AssertionError(
        "load_svmlight_file did not raise RuntimeError for an invalid file")
def test_invalid_filename(self):
    """Call load_svmlight_file with a string that is not a real file name.

    NOTE(review): nothing in this method asserts an outcome; presumably an
    expected-exception check is applied outside this view -- confirm.
    """
    bogus_path = "trou pic nic douille"
    load_svmlight_file(bogus_path)
def test_not_a_filename(self):
    """Call load_svmlight_file with an integer instead of a path or file.

    NOTE(review): nothing in this method asserts an outcome; presumably an
    expected-exception check is applied outside this view -- confirm.
    """
    not_a_path = 1
    load_svmlight_file(not_a_path)