예제 #1
0
    def test_dump_qid(self):
        """Round-trip a query-id dataset through ``dump_svmlight_file``.

        Loads ``qid_datafile`` with query ids, dumps it with 1-based feature
        indices (``zero_based=False``) to a temporary file, reloads that file
        with scikit-learn's loader and checks that the features, targets and
        query ids all survive the round trip.
        """
        import tempfile

        # A unique temporary file instead of a fixed "/tmp/..." path: it is
        # portable (there is no /tmp on Windows) and does not collide when
        # several test runs execute concurrently.
        fd, tmpfile = tempfile.mkstemp(suffix=".txt")
        try:
            # Only the path is handed to dump_svmlight_file; release our
            # handle so the file can be reopened (required on Windows).
            os.close(fd)

            # loads from file
            Xs, y, q = load_svmlight_file(qid_datafile, query_id=True)

            # dumps to file
            dump_svmlight_file(Xs,
                               y,
                               tmpfile,
                               query_id=list(q),
                               zero_based=False)

            # loads them back as a CSR matrix with scikit-learn
            X2, y2, q2 = sk_load_svmlight_file(tmpfile, query_id=True)

            # densify so it can be compared element-wise with Xs
            X3 = X2.toarray()

            # check assertions
            assert_array_almost_equal(Xs, X3)
            assert_array_almost_equal(y, y2)
            assert_array_equal(q, q2)
        finally:
            if os.path.exists(tmpfile):
                os.remove(tmpfile)
예제 #2
0
    def test_load_svmlight_file_descriptor(self):
        """Loading from an open binary file object instead of a path.

        ``datafile`` is opened in ``'rb'`` mode and the resulting file
        object is passed to ``load_svmlight_file``; the loaded data must
        have the expected shape and targets.
        """
        with open(datafile, 'rb') as handle:
            features, targets = load_svmlight_file(handle)

        # 3 rows (one per target below) by 20 columns
        assert_array_equal(features.shape, (3, 20))

        assert_array_equal(targets, [1, 2, 3])
예제 #3
0
    def test_load_svmlight_file(self):
        """Load ``datafile`` by path and check the shape of X and y."""
        features, targets = load_svmlight_file(datafile)

        assert_array_equal(features.shape, (3, 20))

        # TODO: checks on X's non-zero values, its zero values, and that
        # X's values can be modified are planned but not implemented yet.

        assert_array_equal(targets, [1, 2, 3])
예제 #4
0
    def test_load_svmlight_file_empty_qid(self):
        """Request query ids from ``datafile`` and expect none back.

        X and y must load as usual, while the returned query-id array is
        expected to be empty.
        """
        features, targets, qids = load_svmlight_file(datafile, query_id=True)

        assert_array_equal(features.shape, (3, 20))

        # TODO: value-level checks on X (non-zeros, zeros, mutability)
        # are planned but not implemented yet.

        assert_array_equal(targets, [1, 2, 3])

        # an empty query-id array is expected for this file
        assert_equal(qids.shape[0], 0)
예제 #5
0
    def test_load_svmlight_qid_file(self):
        """Load ``qid_datafile`` with query ids and check X, y and q."""
        features, targets, qids = load_svmlight_file(qid_datafile,
                                                     query_id=True)

        assert_array_equal(features.shape, (4, 33))

        # TODO: value-level checks on X (non-zeros, zeros, mutability)
        # are planned but not implemented yet.

        assert_array_equal(targets, [1, 2, 0, 3])

        # one query id per row of X
        assert_array_equal(qids, [1, 37, 37, 12])
예제 #6
0
    def test_dump(self):
        """Round-trip ``datafile`` through ``dump_svmlight_file``.

        Loads the data, dumps it with 1-based feature indices
        (``zero_based=False``) to a temporary file, reloads it with
        scikit-learn's loader and checks that features and targets are
        preserved.
        """
        import tempfile

        # A unique temporary file instead of "tmp_dump.txt" in the current
        # directory: the CWD may be read-only, parallel runs would collide,
        # and the cleanup below would delete a pre-existing user file of
        # that name.
        fd, tmpfile = tempfile.mkstemp(suffix=".txt")
        try:
            # Only the path is handed to dump_svmlight_file; release our
            # handle so the file can be reopened (required on Windows).
            os.close(fd)

            # loads from file
            Xs, y = load_svmlight_file(datafile)

            # dumps to file
            dump_svmlight_file(Xs, y, tmpfile, zero_based=False)

            # loads them back as a CSR matrix
            X2, y2 = sk_load_svmlight_file(tmpfile)

            # densify so it can be compared element-wise with Xs
            X3 = X2.toarray()

            # check assertions
            assert_array_almost_equal(Xs, X3)
            assert_array_almost_equal(y, y2)
        finally:
            if os.path.exists(tmpfile):
                os.remove(tmpfile)
예제 #7
0
 def test_load_invalid_file(self):
     """Loading the malformed ``invalidfile`` must raise RuntimeError."""
     try:
         load_svmlight_file(invalidfile)
     except RuntimeError:
         pass  # expected failure mode
     else:
         # `assert False` would be stripped under `python -O`, letting this
         # test pass silently; raise the failure explicitly instead.
         raise AssertionError(
             "load_svmlight_file(invalidfile) did not raise RuntimeError")
예제 #8
0
 def test_invalid_filename(self):
     # Load from a nonsense path (presumably one that does not exist).
     # NOTE(review): nothing is asserted in this body; the expected
     # exception is presumably checked outside this view (e.g. a
     # decorator) -- confirm.
     bogus_path = "trou pic nic douille"
     load_svmlight_file(bogus_path)
예제 #9
0
 def test_not_a_filename(self):
     # Pass an int where a filename or file object would be expected.
     # NOTE(review): nothing is asserted in this body; the expected
     # exception is presumably checked outside this view (e.g. a
     # decorator) -- confirm.
     not_a_path = 1
     load_svmlight_file(not_a_path)