예제 #1
0
    def test_read_parquet_str(self):
        """Round-trip an XFrame through parquet and reload it by path.

        The frame is saved with ``format='parquet'`` and then constructed
        again from the same location with a ``.parquet`` suffix.  Length,
        column names, column types and row contents are checked.
        """
        t = XFrame({'id': [1, 2, 3], 'val': ['a', 'b', 'c']})
        path = '{}/tmp/frame-parquet'.format(hdfs_prefix)
        # The frame is read back from this suffixed location, so this is
        # also the location that has to be removed afterwards.
        parquet_path = path + '.parquet'
        t.save(path, format='parquet')

        try:
            res = XFrame(parquet_path)
            # results may not come back in the same order
            res = res.sort('id')
            self.assertEqualLen(3, res)
            self.assertListEqual(['id', 'val'], res.column_names())
            self.assertListEqual([int, str], res.column_types())
            self.assertDictEqual({'id': 1, 'val': 'a'}, res[0])
            self.assertDictEqual({'id': 2, 'val': 'b'}, res[1])
            self.assertDictEqual({'id': 3, 'val': 'c'}, res[2])
        finally:
            # Clean up even when an assertion fails, so a stale file does not
            # linger at this fixed path for later runs.
            fileio.delete(parquet_path)
예제 #2
0
    def predict_all(self, user):
        """
        Predict the rating of every item in ``self.items`` for one user.

        Parameters
        ----------
        user : int
            The user to make predictions for.

        Returns
        -------
        out : XFrame
            One row per prediction, with columns named after the user, item
            and rating columns of this model, in that order.
        """

        # Pair every item with the requested user.  The frame is filled
        # item-first and the two columns are then swapped.
        # NOTE(review): this relies on swap_columns reordering the frame in
        # place (its return value is not used) -- confirm.
        candidates = XFrame()
        candidates[self.item_col] = self.items
        candidates[self.user_col] = user
        candidates.swap_columns(self.item_col, self.user_col)

        # Flatten each rating object returned by the model into a plain tuple.
        predictions = self.model.predictAll(candidates.to_rdd())
        rows = predictions.map(lambda r: (r.user, r.product, r.rating))

        names = [self.user_col, self.item_col, self.rating_col]
        types = [self.users.dtype(), self.items.dtype(), float]
        return XFrame.from_rdd(rows, column_names=names, column_types=types)
예제 #3
0
 def test_save(self):
     """Save an XFrame in binary format and check its ``_metadata`` file.

     The ``_metadata`` entry under the saved path is unpickled and compared
     against the expected column names and column types.
     """
     t = XFrame({'id': [30, 20, 10], 'val': ['a', 'b', 'c']})
     path = '{}/tmp/frame'.format(hdfs_prefix)
     t.save(path, format='binary')
     try:
         with fileio.open_file(os.path.join(path, '_metadata')) as f:
             # Unpickling is acceptable here: the file was just written by
             # this test, it is not external input.
             metadata = pickle.load(f)
         self.assertListEqual([['id', 'val'], [int, str]], metadata)
         # TODO find some way to check the data
     finally:
         # Remove the saved frame even when the check above fails, so a
         # stale copy does not linger at this fixed path for later runs.
         fileio.delete(path)
예제 #4
0
 def test_construct_auto_str_psv(self):
     """Construct an XFrame from a ``.psv`` path without an explicit format."""
     frame = XFrame('{}/user/xpatterns/files/test-frame.psv'.format(hdfs_prefix))
     self.assertEqualLen(3, frame)
     self.assertListEqual(['id', 'val'], frame.column_names())
     self.assertListEqual([int, str], frame.column_types())
     # Rows are compared in file order.
     expected_rows = [
         {'id': 1, 'val': 'a'},
         {'id': 2, 'val': 'b'},
         {'id': 3, 'val': 'c'},
     ]
     for index, expected in enumerate(expected_rows):
         self.assertDictEqual(expected, frame[index])
예제 #5
0
 def test_construct_str_xframe(self):
     """Construct an XFrame from a saved xframe, naming the format explicitly."""
     location = '{}/user/xpatterns/files/test-frame'.format(hdfs_prefix)
     # Sort by id before comparing rows so the checks do not depend on the
     # order in which rows are loaded.
     frame = XFrame(location, format='xframe').sort('id')
     self.assertEqualLen(3, frame)
     self.assertListEqual(['id', 'val'], frame.column_names())
     self.assertListEqual([int, str], frame.column_types())
     expected_rows = [
         {'id': 1, 'val': 'a'},
         {'id': 2, 'val': 'b'},
         {'id': 3, 'val': 'c'},
     ]
     for index, expected in enumerate(expected_rows):
         self.assertDictEqual(expected, frame[index])
예제 #6
0
 def test_construct_str_csv(self):
     """Construct an XFrame from a ``.txt`` file, interpreting it as csv."""
     # The format is passed explicitly; the path itself ends in .txt.
     location = '{}/user/xpatterns/files/test-frame.txt'.format(hdfs_prefix)
     frame = XFrame(location, format='csv')
     self.assertEqualLen(3, frame)
     self.assertListEqual(['id', 'val'], frame.column_names())
     self.assertListEqual([int, str], frame.column_types())
     expected_rows = [
         {'id': 1, 'val': 'a'},
         {'id': 2, 'val': 'b'},
         {'id': 3, 'val': 'c'},
     ]
     for index, expected in enumerate(expected_rows):
         self.assertDictEqual(expected, frame[index])
예제 #7
0
 def test_construct_auto_str_xframe(self):
     """Construct an XFrame from a path with no recognized file extension."""
     # No format argument is given and the path carries no extension.
     location = '{}/user/xpatterns/files/test-frame'.format(hdfs_prefix)
     # Sort by id before comparing rows so the checks do not depend on the
     # order in which rows are loaded.
     frame = XFrame(location).sort('id')
     self.assertEqualLen(3, frame)
     self.assertListEqual(['id', 'val'], frame.column_names())
     self.assertListEqual([int, str], frame.column_types())
     expected_rows = [
         {'id': 1, 'val': 'a'},
         {'id': 2, 'val': 'b'},
         {'id': 3, 'val': 'c'},
     ]
     for index, expected in enumerate(expected_rows):
         self.assertDictEqual(expected, frame[index])
예제 #8
0
    def test_save(self):
        """Save an XFrame as csv and check the written file line by line.

        The file is opened at the save path plus a ``.csv`` suffix.  The
        first line is expected to be the header, followed by one line per
        row in the order the rows were given.
        """
        t = XFrame({'id': [30, 20, 10], 'val': ['a', 'b', 'c']})
        path = '{}/tmp/frame-csv'.format(hdfs_prefix)
        csv_path = path + '.csv'
        t.save(path, format='csv')

        try:
            with fileio.open_file(csv_path) as f:
                heading = f.readline().rstrip()
                self.assertEqual('id,val', heading)
                self.assertEqual('30,a', f.readline().rstrip())
                self.assertEqual('20,b', f.readline().rstrip())
                self.assertEqual('10,c', f.readline().rstrip())
        finally:
            # Remove the file even when a line check fails, so a stale copy
            # does not linger at this fixed path for later runs.
            fileio.delete(csv_path)
예제 #9
0
 def test_construct_auto_dataframe(self):
     """Construct an XFrame from a csv path and check the column types found.

     The file is expected to yield int, float, str, list and dict columns.
     """
     frame = XFrame('{}/user/xpatterns/files/test-frame-auto.csv'.format(hdfs_prefix))
     self.assertEqualLen(3, frame)

     expected_names = ['val_int', 'val_int_signed', 'val_float', 'val_float_signed',
                       'val_str', 'val_list', 'val_dict']
     self.assertListEqual(expected_names, frame.column_names())
     self.assertListEqual([int, int, float, float, str, list, dict], frame.column_types())

     # Row i holds the number i + 1 (plain and negated, as int and float)
     # and the matching letter (bare, in a list, and as a dict value).
     for index, letter in enumerate(['a', 'b', 'c']):
         number = index + 1
         expected = {'val_int': number, 'val_int_signed': -number,
                     'val_float': float(number), 'val_float_signed': -float(number),
                     'val_str': letter, 'val_list': [letter], 'val_dict': {number: letter}}
         self.assertDictEqual(expected, frame[index])
예제 #10
0
from xframes import XFrame

# Build a small two-column frame from a dict of column name -> values.
xf = XFrame({'id': [1, 2, 3], 'val': ['a', 'b', 'c']})
# Parenthesised single-argument form: valid as a Python 2 print statement
# and as a Python 3 print() call, with the same output in both.
print(xf)
예제 #11
0
 def test_save(self):
     """Save an XFrame in parquet format, then remove the written output.

     Nothing is asserted about the output yet; the test only checks that
     saving and deleting complete without raising.
     """
     frame = XFrame({'id': [30, 20, 10], 'val': ['a', 'b', 'c']})
     target = '{}/tmp/frame-parquet'.format(hdfs_prefix)
     frame.save(target, format='parquet')
     # TODO verify
     saved_location = target + '.parquet'
     fileio.delete(saved_location)