示例#1
0
    def test_save_load(self):
        """Round-trip an SGraph through the binary, CSV and JSON save paths.

        * binary: reload with ``load_graph(..., 'binary')`` and compare the
          summary and field names.
        * csv: rebuild a graph from the written ``vertices.csv`` and
          ``edges.csv`` and compare the summary and field names.
        * json: check the written document has top-level ``vertices`` and
          ``edges`` keys.
        """
        expected_summary = {'num_vertices': 4, 'num_edges': 3}
        expected_fields = {'__id', '__src_id', '__dst_id', 'color', 'vec', 'weight'}

        g = SGraph().add_vertices(self.vertices,
                                  'vid').add_edges(self.edges, 'src_id',
                                                   'dst_id')

        # Binary format.
        with util.TempDirectory() as f:
            g.save(f)
            g2 = load_graph(f, 'binary')
            self.assertEqual(g2.summary(), expected_summary)
            self.assertItemsEqual(g2.get_fields(), expected_fields)

        # CSV format: the save is expected to produce vertices.csv and
        # edges.csv inside the target directory.
        with util.TempDirectory() as f:
            g.save(f, format='csv')
            vertices = SFrame.read_csv(f + "/vertices.csv")
            edges = SFrame.read_csv(f + "/edges.csv")
            g2 = SGraph().add_edges(edges, '__src_id',
                                    '__dst_id').add_vertices(vertices, '__id')
            self.assertEqual(g2.summary(), expected_summary)
            self.assertItemsEqual(g2.get_fields(), expected_fields)

        # JSON format: no explicit format is passed, so this relies on the
        # '.json' suffix selecting JSON output. The file is written inside a
        # temp directory rather than to an open NamedTemporaryFile, because
        # on Windows an open NamedTemporaryFile cannot be reopened by name.
        with util.TempDirectory() as d:
            json_path = d + "/graph.json"
            g.save(json_path)
            with open(json_path, 'r') as f2:
                g2 = json.load(f2)
            self.assertIn("vertices", g2)
            self.assertIn("edges", g2)
示例#2
0
 def __test_model_save_load_helper__(self, model):
     """Save ``model``, reload it with ``get_unity().load_model`` and compare.

     Both models must list the same fields. Each field value is then compared
     according to its type:

     * SGraph: equal ``summary()`` dicts and the same ``get_fields()``.
     * SFrame: same length and column names, and equal contents after
       converting to pandas and indexing both frames by their first column.
     * pandas DataFrame: ``assert_frame_equal``.
     * anything else: ``assertEqual``.
     """
     with util.TempDirectory() as f:
         model.save(f)
         m2 = get_unity().load_model(f)
         self.assertItemsEqual(model.list_fields(), m2.list_fields())
         for key in model.list_fields():
             expected = model.get(key)
             actual = m2.get(key)
             if type(expected) is SGraph:
                 # summary() returns a dict. assertItemsEqual on a dict only
                 # compares its keys, so assertEqual is needed to also check
                 # the vertex/edge counts.
                 self.assertEqual(expected.summary(), actual.summary())
                 self.assertItemsEqual(expected.get_fields(),
                                       actual.get_fields())
             elif type(expected) is SFrame:
                 self.assertEqual(len(expected), len(actual))
                 self.assertItemsEqual(expected.column_names(),
                                       actual.column_names())
                 df1 = expected.to_dataframe()
                 df2 = actual.to_dataframe()
                 # Index both frames by their first column before comparing.
                 df1 = df1.set_index(df1.columns[0])
                 df2 = df2.set_index(df2.columns[0])
                 assert_frame_equal(df1, df2)
             elif type(expected) is pd.DataFrame:
                 assert_frame_equal(expected, actual)
             else:
                 self.assertEqual(expected, actual)
示例#3
0
    def test_basic_save_load(self):
        """Round-trip the pagerank and connected-component models via save/load_model.

        The pagerank model is also saved to a relative path, an absolute path
        and a ``remote://`` URL, and each copy must reload to an equal model.
        """
        # save and load the pagerank model
        with util.TempDirectory() as tmp_pr_model_file:
            self.pr_model.save(tmp_pr_model_file)
            pr_model2 = gl.load_model(tmp_pr_model_file)
            self.__assert_model_equals__(self.pr_model, pr_model2)

        # save and load the connected_component model
        with util.TempDirectory() as tmp_cc_model_file:
            self.cc_model.save(tmp_cc_model_file)
            cc_model2 = gl.load_model(tmp_cc_model_file)
            self.__assert_model_equals__(self.cc_model, cc_model2)

        # handle different types of urls.
        # TODO: test hdfs and s3 urls.
        import shutil
        local_urls = [
            './tmp_model-%d' % temp_number,
            '/tmp/tmp_model-%d' % temp_number,
        ]
        # NOTE(review): where the remote:// save lands on disk is not visible
        # here, so it is not cleaned up below — confirm and remove if local.
        remote_url = 'remote:///tmp/tmp_model2-%d' % temp_number
        try:
            for url in local_urls + [remote_url]:
                self.pr_model.save(url)
                self.__assert_model_equals__(self.pr_model, gl.load_model(url))
        finally:
            # Remove the model directories written to the local filesystem
            # so repeated runs do not leave artifacts behind.
            for url in local_urls:
                shutil.rmtree(url, ignore_errors=True)
示例#4
0
    def test_exception(self):
        """Loading from or saving to unusable locations must raise IOError."""
        # A freshly created temp directory holds no model, so loading fails.
        with util.TempDirectory() as empty_location:
            with self.assertRaises(IOError):
                gl.load_model(empty_location)

        # A path that does not exist; delete any leftover from earlier runs
        # first so the load really targets a missing location.
        missing = './tmp_model-%d' % temp_number
        if os.path.exists(missing):
            shutil.rmtree(missing)
        with self.assertRaises(IOError):
            gl.load_model(missing)

        # Destinations that cannot be written to.
        # NOTE(review): '/root/tmp/testmodel' assumes the suite is not run as root.
        unwritable = ['http://test', '/root/tmp/testmodel']
        for target in unwritable:
            with self.assertRaises(IOError):
                self.pr_model.save(target)
示例#5
0
文件: test_util.py 项目: yuwin/SFrame
 def _validate_gl_object_type(self, obj, expected):
     """Save ``obj`` to a temp directory and assert its detected object type.

     :param obj: object exposing ``save(path)``.
     :param expected: value ``get_graphlab_object_type`` must return for the
         saved directory.
     """
     with util.TempDirectory() as temp_dir:
         obj.save(temp_dir)
         t = get_graphlab_object_type(temp_dir)
         # assertEquals is a deprecated alias (removed in Python 3.12).
         self.assertEqual(t, expected)