Example #1
class ODPSIOTest(unittest.TestCase):
    def setUp(self):
        self._project = os.environ[ODPSConfig.PROJECT_NAME]
        self._access_id = os.environ[ODPSConfig.ACCESS_ID]
        self._access_key = os.environ[ODPSConfig.ACCESS_KEY]
        self._endpoint = os.environ.get(ODPSConfig.ENDPOINT)
        self._test_read_table = "test_odps_reader_%d_%d" % (
            int(time.time()),
            random.randint(1, 101),
        )
        self._test_write_table = "test_odps_writer_%d_%d" % (
            int(time.time()),
            random.randint(1, 101),
        )
        self._odps_client = ODPS(self._access_id, self._access_key,
                                 self._project, self._endpoint)
        create_iris_odps_table(self._odps_client, self._project,
                               self._test_read_table)

    def test_read_to_iterator(self):
        reader = ODPSReader(
            self._project,
            self._access_id,
            self._access_key,
            self._endpoint,
            self._test_read_table,
            None,
            4,
            None,
        )
        records_iter = reader.to_iterator(1, 0, 50, 2, False, None)
        records = list(records_iter)
        self.assertEqual(len(records), 6,
                         "Unexpected number of batches: %d" % len(records))
        flattened_records = [record for batch in records for record in batch]
        self.assertEqual(
            len(flattened_records),
            220,
            "Unexpected number of total records: %d" % len(flattened_records),
        )

    def test_write_odps_to_recordio_shards_from_iterator(self):
        reader = ODPSReader(
            self._project,
            self._access_id,
            self._access_key,
            self._endpoint,
            self._test_read_table,
            None,
            4,
            None,
        )
        records_iter = reader.to_iterator(1, 0, 50, 2, False, None)
        with tempfile.TemporaryDirectory() as output_dir:
            write_recordio_shards_from_iterator(
                records_iter,
                ["f" + str(i) for i in range(5)],
                output_dir,
                records_per_shard=50,
            )
            self.assertEqual(len(os.listdir(output_dir)), 5)

    def test_write_from_iterator(self):
        columns = ["num", "num2"]
        column_types = ["bigint", "double"]

        # If the table doesn't exist yet
        writer = ODPSWriter(
            self._project,
            self._access_id,
            self._access_key,
            self._endpoint,
            self._test_write_table,
            columns,
            column_types,
        )
        writer.from_iterator(iter([[1, 0.5], [2, 0.6]]), 2)
        table = self._odps_client.get_table(self._test_write_table,
                                            self._project)
        self.assertEqual(table.schema.names, columns)
        self.assertEqual(table.schema.types, column_types)
        self.assertEqual(table.to_df().count(), 1)

        # If the table already exists
        writer = ODPSWriter(
            self._project,
            self._access_id,
            self._access_key,
            self._endpoint,
            self._test_write_table,
        )
        writer.from_iterator(iter([[1, 0.5], [2, 0.6]]), 2)
        table = self._odps_client.get_table(self._test_write_table,
                                            self._project)
        self.assertEqual(table.schema.names, columns)
        self.assertEqual(table.schema.types, column_types)
        self.assertEqual(table.to_df().count(), 2)

    def tearDown(self):
        self._odps_client.delete_table(self._test_write_table,
                                       self._project,
                                       if_exists=True)
        self._odps_client.delete_table(self._test_read_table,
                                       self._project,
                                       if_exists=True)
Example #2
class ODPSDataReaderTest(unittest.TestCase):
    def setUp(self):
        self.project = os.environ[MaxComputeConfig.PROJECT_NAME]
        access_id = os.environ[MaxComputeConfig.ACCESS_ID]
        access_key = os.environ[MaxComputeConfig.ACCESS_KEY]
        endpoint = os.environ.get(MaxComputeConfig.ENDPOINT)
        tunnel_endpoint = os.environ.get(MaxComputeConfig.TUNNEL_ENDPOINT,
                                         None)
        self.test_table = "test_odps_data_reader_%d_%d" % (
            int(time.time()),
            random.randint(1, 101),
        )
        self.odps_client = ODPS(access_id, access_key, self.project, endpoint)
        create_iris_odps_table(self.odps_client, self.project, self.test_table)
        self.records_per_task = 50

        self.reader = ODPSDataReader(
            project=self.project,
            access_id=access_id,
            access_key=access_key,
            endpoint=endpoint,
            table=self.test_table,
            tunnel_endpoint=tunnel_endpoint,
            num_processes=1,
            records_per_task=self.records_per_task,
        )

    def test_odps_data_reader_shards_creation(self):
        expected_shards = {
            self.test_table + ":shard_0": (0, self.records_per_task),
            self.test_table + ":shard_1": (50, self.records_per_task),
            self.test_table + ":shard_2": (100, 10),
        }
        self.assertEqual(expected_shards, self.reader.create_shards())

    def test_odps_data_reader_records_reading(self):
        records = list(
            self.reader.read_records(
                _MockedTask(0, 2, self.test_table + ":shard_0",
                            elasticdl_pb2.TRAINING)))
        records = np.array(records, dtype="float").tolist()
        self.assertEqual([[6.4, 2.8, 5.6, 2.2, 2], [5.0, 2.3, 3.3, 1.0, 1]],
                         records)
        self.assertEqual(self.reader.metadata.column_names,
                         IRIS_TABLE_COLUMN_NAMES)
        self.assertListEqual(
            list(self.reader.metadata.column_dtypes.values()),
            [
                odps.types.double,
                odps.types.double,
                odps.types.double,
                odps.types.double,
                odps.types.bigint,
            ],
        )
        self.assertEqual(
            self.reader.metadata.get_tf_dtype_from_maxcompute_column(
                self.reader.metadata.column_names[0]),
            tf.float64,
        )
        self.assertEqual(
            self.reader.metadata.get_tf_dtype_from_maxcompute_column(
                self.reader.metadata.column_names[-1]),
            tf.int64,
        )

    def test_create_data_reader(self):
        reader = create_data_reader(data_origin=self.test_table,
                                    records_per_task=10,
                                    **{
                                        "columns":
                                        ["sepal_length", "sepal_width"],
                                        "label_col": "class",
                                    })
        self.assertEqual(reader._kwargs["columns"],
                         ["sepal_length", "sepal_width"])
        self.assertEqual(reader._kwargs["label_col"], "class")
        self.assertEqual(reader._kwargs["records_per_task"], 10)
        reader = create_data_reader(data_origin=self.test_table,
                                    records_per_task=10)
        self.assertEqual(reader._kwargs["records_per_task"], 10)
        self.assertTrue("columns" not in reader._kwargs)

    def test_odps_data_reader_integration_with_local_keras(self):
        num_records = 2
        model_spec = load_module(
            os.path.join(
                os.path.dirname(os.path.realpath(__file__)),
                "../../../model_zoo",
                "odps_iris_dnn_model/odps_iris_dnn_model.py",
            )).__dict__
        model = model_spec["custom_model"]()
        optimizer = model_spec["optimizer"]()
        loss = model_spec["loss"]
        reader = create_data_reader(data_origin=self.test_table,
                                    records_per_task=10,
                                    **{
                                        "columns": IRIS_TABLE_COLUMN_NAMES,
                                        "label_col": "class"
                                    })
        dataset_fn = reader.default_dataset_fn()

        def _gen():
            for data in self.reader.read_records(
                    _MockedTask(
                        0,
                        num_records,
                        self.test_table + ":shard_0",
                        elasticdl_pb2.TRAINING,
                    )):
                if data is not None:
                    yield data

        dataset = tf.data.Dataset.from_generator(_gen, tf.string)
        dataset = dataset_fn(dataset, None,
                             Metadata(column_names=IRIS_TABLE_COLUMN_NAMES))
        dataset = dataset.batch(1)

        loss_history = []
        grads = None
        for features, labels in dataset:
            with tf.GradientTape() as tape:
                logits = model(features, training=True)
                loss_value = loss(labels, logits)
            loss_history.append(loss_value.numpy())
            grads = tape.gradient(loss_value, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

        self.assertEqual(len(loss_history), num_records)
        self.assertEqual(len(grads), num_records)
        self.assertEqual(len(model.trainable_variables), num_records)

    def tearDown(self):
        self.odps_client.delete_table(self.test_table,
                                      self.project,
                                      if_exists=True)
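
The setUp methods above rely on a create_iris_odps_table helper that is not included in these excerpts. A minimal sketch of such a helper, based on the inline version shown in Example #6 below and with only two sample rows for brevity, might look like this:

def create_iris_odps_table(odps_client, project_name, table_name):
    # Sketch only: (re)create the iris test table and insert sample rows.
    # The real table used by these tests contains many more rows (see Example #6).
    sql_tmpl = """
    DROP TABLE IF EXISTS {PROJECT_NAME}.{TABLE_NAME};
    CREATE TABLE {PROJECT_NAME}.{TABLE_NAME} (
           sepal_length DOUBLE,
           sepal_width  DOUBLE,
           petal_length DOUBLE,
           petal_width  DOUBLE,
           class BIGINT);

    INSERT INTO {PROJECT_NAME}.{TABLE_NAME} VALUES
    (6.4,2.8,5.6,2.2,2),
    (5.0,2.3,3.3,1.0,1);
    """
    odps_client.execute_sql(
        sql_tmpl.format(PROJECT_NAME=project_name, TABLE_NAME=table_name),
        hints={"odps.sql.submit.mode": "script"},
    )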
Example #3
class ODPSDataReaderTest(unittest.TestCase):
    def setUp(self):
        self.project = os.environ[ODPSConfig.PROJECT_NAME]
        access_id = os.environ[ODPSConfig.ACCESS_ID]
        access_key = os.environ[ODPSConfig.ACCESS_KEY]
        endpoint = os.environ.get(ODPSConfig.ENDPOINT)
        self.test_table = "test_odps_data_reader_%d_%d" % (
            int(time.time()),
            random.randint(1, 101),
        )
        self.odps_client = ODPS(access_id, access_key, self.project, endpoint)
        create_iris_odps_table(self.odps_client, self.project, self.test_table)
        self.records_per_task = 50

        self.reader = ODPSDataReader(
            project=self.project,
            access_id=access_id,
            access_key=access_key,
            endpoint=endpoint,
            table=self.test_table,
            num_processes=1,
            records_per_task=self.records_per_task,
        )

    def test_odps_data_reader_shards_creation(self):
        expected_shards = {
            "shard_0": (0, self.records_per_task),
            "shard_1": (50, self.records_per_task),
            "shard_2": (100, 10),
        }
        self.assertEqual(expected_shards, self.reader.create_shards())

    def test_odps_data_reader_records_reading(self):
        records = list(self.reader.read_records(_MockedTask(0, 2, "shard_0")))
        self.assertEqual([[6.4, 2.8, 5.6, 2.2, 2], [5.0, 2.3, 3.3, 1.0, 1]],
                         records)

    def test_odps_data_reader_integration_with_local_keras(self):
        num_records = 2
        model_spec = load_module(
            os.path.join(
                os.path.dirname(os.path.realpath(__file__)),
                "odps_test_module.py",
            )).__dict__
        model = model_spec["custom_model"]()
        optimizer = model_spec["optimizer"]()
        loss = model_spec["loss"]
        dataset_fn = model_spec["dataset_fn"]

        def _gen():
            for data in self.reader.read_records(
                    _MockedTask(0, num_records, "shard_0")):
                if data is not None:
                    yield data

        dataset = tf.data.Dataset.from_generator(_gen, (tf.float32))
        dataset = dataset_fn(dataset, None)

        loss_history = []
        grads = None
        for features, labels in dataset:
            with tf.GradientTape() as tape:
                logits = model(features, training=True)
                loss_value = loss(logits, labels)
            loss_history.append(loss_value.numpy())
            grads = tape.gradient(loss_value, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

        self.assertEqual(len(loss_history), num_records)
        self.assertEqual(len(grads), num_records)
        self.assertEqual(len(model.trainable_variables), num_records)

    def tearDown(self):
        self.odps_client.delete_table(self.test_table,
                                      self.project,
                                      if_exists=True)
Example #4
class ODPSIOTest(unittest.TestCase):
    def setUp(self):
        self._project = os.environ[ODPSConfig.PROJECT_NAME]
        self._access_id = os.environ[ODPSConfig.ACCESS_ID]
        self._access_key = os.environ[ODPSConfig.ACCESS_KEY]
        self._endpoint = os.environ[ODPSConfig.ENDPOINT]
        self._test_read_table = "chicago_taxi_train_data"
        self._test_write_table = "test_odps_writer_%d_%d" % (
            int(time.time()),
            random.randint(1, 101),
        )
        self._odps_client = ODPS(self._access_id, self._access_key,
                                 self._project, self._endpoint)

    def test_read_to_iterator(self):
        reader = ODPSReader(
            self._project,
            self._access_id,
            self._access_key,
            self._endpoint,
            self._test_read_table,
            None,
            4,
            None,
        )
        records_iter = reader.to_iterator(1, 0, 200, 2, False, None)
        for batch in records_iter:
            self.assertEqual(len(batch), 200,
                             "incompatible size: %d" % len(batch))

    def test_write_odps_to_recordio_shards_from_iterator(self):
        reader = ODPSReader(
            self._project,
            self._access_id,
            self._access_key,
            self._endpoint,
            self._test_read_table,
            None,
            4,
            None,
        )
        records_iter = reader.to_iterator(1, 0, 200, 2, False, None)
        with tempfile.TemporaryDirectory() as output_dir:
            write_recordio_shards_from_iterator(
                records_iter,
                ["f" + str(i) for i in range(18)],
                output_dir,
                records_per_shard=200,
            )
            self.assertEqual(len(os.listdir(output_dir)), 100)

    def test_write_from_iterator(self):
        columns = ["num", "num2"]
        column_types = ["bigint", "double"]

        # If the table doesn't exist yet
        writer = ODPSWriter(
            self._project,
            self._access_id,
            self._access_key,
            self._endpoint,
            self._test_write_table,
            columns,
            column_types,
        )
        writer.from_iterator(iter([[1, 0.5], [2, 0.6]]), 2)
        table = self._odps_client.get_table(self._test_write_table,
                                            self._project)
        self.assertEqual(table.schema.names, columns)
        self.assertEqual(table.schema.types, column_types)
        self.assertEqual(table.to_df().count(), 1)

        # If the table already exists
        writer = ODPSWriter(
            self._project,
            self._access_id,
            self._access_key,
            self._endpoint,
            self._test_write_table,
        )
        writer.from_iterator(iter([[1, 0.5], [2, 0.6]]), 2)
        table = self._odps_client.get_table(self._test_write_table,
                                            self._project)
        self.assertEqual(table.schema.names, columns)
        self.assertEqual(table.schema.types, column_types)
        self.assertEqual(table.to_df().count(), 2)

    def tearDown(self):
        self._odps_client.delete_table(self._test_write_table,
                                       self._project,
                                       if_exists=True)
Example #5
from odps import ODPS
from odps.models import Column, Partition, Schema
import pandas as pd
import sys


class myOdps:

    # Initialize an ODPS connection object
    def __init__(self, access_id, secret_access_key, project):
        self.odps = ODPS(access_id=access_id,
                         secret_access_key=secret_access_key,
                         project=project,
                         endpoint="http://service.odps.aliyun.com/api")

    # Get all table names
    def get_all_tabel(self):
        # return: a list of all table names
        table_name = []
        for table in self.odps.list_tables():
            table_name.append(table.name)
        return table_name

    # Create a table
    def creat_table(self, table_name, columns=None, if_not_exists=True):
        # table_name: the table name
        # columns: a tuple of column and partition definitions,
        #          e.g. ('num bigint, num2 double', 'pt string')
        # if_not_exists: True  only create the table if it does not exist
        # lifecycle: 28  table lifecycle in days
        # return: the table object
        try:
            return self.odps.create_table(table_name,
                                          columns,
                                          if_not_exists=if_not_exists)
        except Exception:
            # fall back to the existing table if creation fails
            return self.odps.get_table(table_name)

    # Get a table directly by its name
    def get_a_table(self, table_name):
        # table_name: the table name
        # return: the table object
        return self.odps.get_table(table_name)

    # Drop a table
    def drop_a_table(self, table_name):
        # table_name: the table name
        # return: the result of dropping the table
        return self.odps.delete_table(table_name)

    # Get all partitions of a table
    def get_partitions(self, table):
        # table: the table object
        # return: all partitions of the table
        partitions = []
        for partition in table.partitions:
            partitions.append(partition.name)
        return partitions

    # ============= Data upload ============

    # Upload a CSV to ODPS and create the table; the CSV must have a header row

    def uploadCSV(self, csvFilename, tableName, sep=",", pt=None):
        """
        :param csvFilename: 传入本地csv的路径,必须要有表头
        :param tableName:  上传到odps时的表名
        :param sep:   csv的分隔符
        :param pt:   是否创建分区
        """
        print("start upload ...\n")
        df = pd.read_csv(csvFilename, sep=sep)
        shape0 = df.shape[0]
        columns = [
            Column(name=f"{x}", type='string', comment='the column')
            for x in df.columns
        ]

        if pt:
            partitions = [
                Partition(name='pt', type='string', comment='the partition')
            ]
            schema = Schema(columns=columns, partitions=partitions)
            table = self.creat_table(tableName, schema)
            table.create_partition(f"pt={pt}", if_not_exists=True)
            table_columns = [i.name for i in table.schema.columns]
            with table.open_writer(partition=f"pt={pt}") as writer:
                for index in df.index:
                    print(f"{index+1}/{shape0} in {tableName}  ...")
                    item_dict = dict(df.loc[index])
                    item = []
                    for field in table_columns[:-1]:
                        item.append(item_dict.get(field, ''))
                    item.append(pt)
                    writer.write(item)
        else:
            schema = Schema(columns=columns)
            table = self.creat_table(tableName, schema)
            table_columns = [i.name for i in table.schema.columns]
            with table.open_writer(partition=None) as writer:
                for index in df.index:
                    print(f"{index+1}/{shape0} in {tableName}  ...")
                    item_dict = dict(df.loc[index])
                    item = []
                    # no partition column here, so write every data column
                    for field in table_columns:
                        item.append(item_dict.get(field, ''))
                    writer.write(item)
        print("\n\n upload finish ...")

    # Download while uploading: the field at column index 1 (by default) holds a
    # URL whose content is fetched, then the complete row is uploaded
    def downloaAndUp(self,
                     csvFilename,
                     tableName,
                     sep=",",
                     urlIndex=1,
                     pt=None):
        """
        :param csvFilename: 传入本地csv的路径,必须要有表头
        :param tableName:  上传到odps时的表名
        :param sep:   csv的分隔符
        :param urlIndex: url字段的坐标位置
        """
        print("start upload ...\n")
        f = open(csvFilename, encoding='utf-8')
        first_line = f.readline().strip('\n').split(sep)
        columns = [
            Column(name=f"{x}", type='string', comment='the column')
            for x in first_line
        ]

        if pt:
            partitions = [
                Partition(name='pt', type='string', comment='the partition')
            ]
            schema = Schema(columns=columns, partitions=partitions)
            table = self.creat_table(tableName, schema)
            table.create_partition(f"pt={pt}", if_not_exists=True)
            with table.open_writer(partition=f"pt={pt}") as writer:
                for index, line in enumerate(f):
                    print(f"{index} in {tableName}  ...")
                    item = line.strip('\n').split(sep)
                    item.append(pt)
                    # `download` is an external helper (not shown) assumed to
                    # return an HTTP response object with a .text attribute
                    resp = download(item[urlIndex])
                    data = resp.text
                    # skip payloads that exceed the ~8 MB ODPS string field limit
                    if sys.getsizeof(data) <= 8 * 1024 * 1000:
                        item[urlIndex] = data
                    else:
                        print(f"failed in {item[0]}")
                    writer.write(item)
        else:
            schema = Schema(columns=columns)
            table = self.creat_table(tableName, schema)
            with table.open_writer(partition=None) as writer:
                for index, line in enumerate(f):
                    print(f"{index}  in {tableName}  ...")
                    item = line.strip('\n').split(sep)
                    resp = download(item[urlIndex])
                    data = resp.text
                    if sys.getsizeof(data) <= 8 * 1024 * 1000:
                        item[urlIndex] = data
                    else:
                        print(f"failed in {item[0]}")
                    writer.write(item)
        print("\n\n upload finish ...")
        f.close()

    # =========== Run SQL =========
    # SQL query
    def select_sql(self, sql):
        # return: the query results, materialized before the reader is closed
        with self.odps.execute_sql(sql).open_reader() as reader:
            return list(reader)
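
A short usage sketch for the myOdps wrapper above; the credentials, project, file and table names below are placeholders rather than values from the original example:

# Placeholder credentials and names, for illustration only.
my_odps = myOdps("MY_ACCESS_ID", "MY_SECRET_ACCESS_KEY", "my_project")
print(my_odps.get_all_tabel())  # list existing tables
my_odps.uploadCSV("iris.csv", "iris_table")  # create the table and upload the CSV
for record in my_odps.select_sql("SELECT COUNT(*) FROM iris_table"):
    print(record)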
Example #6
class ODPSIOTest(unittest.TestCase):
    def setUp(self):
        self._project = os.environ[ODPSConfig.PROJECT_NAME]
        self._access_id = os.environ[ODPSConfig.ACCESS_ID]
        self._access_key = os.environ[ODPSConfig.ACCESS_KEY]
        self._endpoint = os.environ.get(ODPSConfig.ENDPOINT)
        self._test_read_table = "test_odps_reader_%d_%d" % (
            int(time.time()),
            random.randint(1, 101),
        )
        self._test_write_table = "test_odps_writer_%d_%d" % (
            int(time.time()),
            random.randint(1, 101),
        )
        self._odps_client = ODPS(self._access_id, self._access_key,
                                 self._project, self._endpoint)
        self.create_iris_odps_table()

    def test_read_to_iterator(self):
        reader = ODPSReader(
            self._project,
            self._access_id,
            self._access_key,
            self._endpoint,
            self._test_read_table,
            None,
            4,
            None,
        )
        records_iter = reader.to_iterator(1, 0, 50, 2, False, None)
        records = list(records_iter)
        self.assertEqual(len(records), 6,
                         "Unexpected number of batches: %d" % len(records))
        flattened_records = [record for batch in records for record in batch]
        self.assertEqual(
            len(flattened_records),
            220,
            "Unexpected number of total records: %d" % len(flattened_records),
        )

    def test_write_odps_to_recordio_shards_from_iterator(self):
        reader = ODPSReader(
            self._project,
            self._access_id,
            self._access_key,
            self._endpoint,
            self._test_read_table,
            None,
            4,
            None,
        )
        records_iter = reader.to_iterator(1, 0, 50, 2, False, None)
        with tempfile.TemporaryDirectory() as output_dir:
            write_recordio_shards_from_iterator(
                records_iter,
                ["f" + str(i) for i in range(5)],
                output_dir,
                records_per_shard=50,
            )
            self.assertEqual(len(os.listdir(output_dir)), 5)

    def test_write_from_iterator(self):
        columns = ["num", "num2"]
        column_types = ["bigint", "double"]

        # If the table doesn't exist yet
        writer = ODPSWriter(
            self._project,
            self._access_id,
            self._access_key,
            self._endpoint,
            self._test_write_table,
            columns,
            column_types,
        )
        writer.from_iterator(iter([[1, 0.5], [2, 0.6]]), 2)
        table = self._odps_client.get_table(self._test_write_table,
                                            self._project)
        self.assertEqual(table.schema.names, columns)
        self.assertEqual(table.schema.types, column_types)
        self.assertEqual(table.to_df().count(), 1)

        # If the table already exists
        writer = ODPSWriter(
            self._project,
            self._access_id,
            self._access_key,
            self._endpoint,
            self._test_write_table,
        )
        writer.from_iterator(iter([[1, 0.5], [2, 0.6]]), 2)
        table = self._odps_client.get_table(self._test_write_table,
                                            self._project)
        self.assertEqual(table.schema.names, columns)
        self.assertEqual(table.schema.types, column_types)
        self.assertEqual(table.to_df().count(), 2)

    def create_iris_odps_table(self):
        sql_tmpl = """
        DROP TABLE IF EXISTS {PROJECT_NAME}.{TABLE_NAME};
        CREATE TABLE {PROJECT_NAME}.{TABLE_NAME} (
               sepal_length DOUBLE,
               sepal_width  DOUBLE,
               petal_length DOUBLE,
               petal_width  DOUBLE,
               class BIGINT);

        INSERT INTO {PROJECT_NAME}.{TABLE_NAME} VALUES
        (6.4,2.8,5.6,2.2,2),
        (5.0,2.3,3.3,1.0,1),
        (4.9,2.5,4.5,1.7,2),
        (4.9,3.1,1.5,0.1,0),
        (5.7,3.8,1.7,0.3,0),
        (4.4,3.2,1.3,0.2,0),
        (5.4,3.4,1.5,0.4,0),
        (6.9,3.1,5.1,2.3,2),
        (6.7,3.1,4.4,1.4,1),
        (5.1,3.7,1.5,0.4,0),
        (5.2,2.7,3.9,1.4,1),
        (6.9,3.1,4.9,1.5,1),
        (5.8,4.0,1.2,0.2,0),
        (5.4,3.9,1.7,0.4,0),
        (7.7,3.8,6.7,2.2,2),
        (6.3,3.3,4.7,1.6,1),
        (6.8,3.2,5.9,2.3,2),
        (7.6,3.0,6.6,2.1,2),
        (6.4,3.2,5.3,2.3,2),
        (5.7,4.4,1.5,0.4,0),
        (6.7,3.3,5.7,2.1,2),
        (6.4,2.8,5.6,2.1,2),
        (5.4,3.9,1.3,0.4,0),
        (6.1,2.6,5.6,1.4,2),
        (7.2,3.0,5.8,1.6,2),
        (5.2,3.5,1.5,0.2,0),
        (5.8,2.6,4.0,1.2,1),
        (5.9,3.0,5.1,1.8,2),
        (5.4,3.0,4.5,1.5,1),
        (6.7,3.0,5.0,1.7,1),
        (6.3,2.3,4.4,1.3,1),
        (5.1,2.5,3.0,1.1,1),
        (6.4,3.2,4.5,1.5,1),
        (6.8,3.0,5.5,2.1,2),
        (6.2,2.8,4.8,1.8,2),
        (6.9,3.2,5.7,2.3,2),
        (6.5,3.2,5.1,2.0,2),
        (5.8,2.8,5.1,2.4,2),
        (5.1,3.8,1.5,0.3,0),
        (4.8,3.0,1.4,0.3,0),
        (7.9,3.8,6.4,2.0,2),
        (5.8,2.7,5.1,1.9,2),
        (6.7,3.0,5.2,2.3,2),
        (5.1,3.8,1.9,0.4,0),
        (4.7,3.2,1.6,0.2,0),
        (6.0,2.2,5.0,1.5,2),
        (4.8,3.4,1.6,0.2,0),
        (7.7,2.6,6.9,2.3,2),
        (4.6,3.6,1.0,0.2,0),
        (7.2,3.2,6.0,1.8,2),
        (5.0,3.3,1.4,0.2,0),
        (6.6,3.0,4.4,1.4,1),
        (6.1,2.8,4.0,1.3,1),
        (5.0,3.2,1.2,0.2,0),
        (7.0,3.2,4.7,1.4,1),
        (6.0,3.0,4.8,1.8,2),
        (7.4,2.8,6.1,1.9,2),
        (5.8,2.7,5.1,1.9,2),
        (6.2,3.4,5.4,2.3,2),
        (5.0,2.0,3.5,1.0,1),
        (5.6,2.5,3.9,1.1,1),
        (6.7,3.1,5.6,2.4,2),
        (6.3,2.5,5.0,1.9,2),
        (6.4,3.1,5.5,1.8,2),
        (6.2,2.2,4.5,1.5,1),
        (7.3,2.9,6.3,1.8,2),
        (4.4,3.0,1.3,0.2,0),
        (7.2,3.6,6.1,2.5,2),
        (6.5,3.0,5.5,1.8,2),
        (5.0,3.4,1.5,0.2,0),
        (4.7,3.2,1.3,0.2,0),
        (6.6,2.9,4.6,1.3,1),
        (5.5,3.5,1.3,0.2,0),
        (7.7,3.0,6.1,2.3,2),
        (6.1,3.0,4.9,1.8,2),
        (4.9,3.1,1.5,0.1,0),
        (5.5,2.4,3.8,1.1,1),
        (5.7,2.9,4.2,1.3,1),
        (6.0,2.9,4.5,1.5,1),
        (6.4,2.7,5.3,1.9,2),
        (5.4,3.7,1.5,0.2,0),
        (6.1,2.9,4.7,1.4,1),
        (6.5,2.8,4.6,1.5,1),
        (5.6,2.7,4.2,1.3,1),
        (6.3,3.4,5.6,2.4,2),
        (4.9,3.1,1.5,0.1,0),
        (6.8,2.8,4.8,1.4,1),
        (5.7,2.8,4.5,1.3,1),
        (6.0,2.7,5.1,1.6,1),
        (5.0,3.5,1.3,0.3,0),
        (6.5,3.0,5.2,2.0,2),
        (6.1,2.8,4.7,1.2,1),
        (5.1,3.5,1.4,0.3,0),
        (4.6,3.1,1.5,0.2,0),
        (6.5,3.0,5.8,2.2,2),
        (4.6,3.4,1.4,0.3,0),
        (4.6,3.2,1.4,0.2,0),
        (7.7,2.8,6.7,2.0,2),
        (5.9,3.2,4.8,1.8,1),
        (5.1,3.8,1.6,0.2,0),
        (4.9,3.0,1.4,0.2,0),
        (4.9,2.4,3.3,1.0,1),
        (4.5,2.3,1.3,0.3,0),
        (5.8,2.7,4.1,1.0,1),
        (5.0,3.4,1.6,0.4,0),
        (5.2,3.4,1.4,0.2,0),
        (5.3,3.7,1.5,0.2,0),
        (5.0,3.6,1.4,0.2,0),
        (5.6,2.9,3.6,1.3,1),
        (4.8,3.1,1.6,0.2,0);
        """
        self._odps_client.execute_sql(
            sql_tmpl.format(PROJECT_NAME=self._project,
                            TABLE_NAME=self._test_read_table),
            hints={"odps.sql.submit.mode": "script"},
        )

    def tearDown(self):
        self._odps_client.delete_table(self._test_write_table,
                                       self._project,
                                       if_exists=True)
        self._odps_client.delete_table(self._test_read_table,
                                       self._project,
                                       if_exists=True)
Example #7
class ODPSIOTest(unittest.TestCase):
    def setUp(self):
        self._project = os.environ[MaxComputeConfig.PROJECT_NAME]
        self._access_id = os.environ[MaxComputeConfig.ACCESS_ID]
        self._access_key = os.environ[MaxComputeConfig.ACCESS_KEY]
        self._endpoint = os.environ.get(MaxComputeConfig.ENDPOINT)
        self._test_read_table = "test_odps_reader_%d_%d" % (
            int(time.time()),
            random.randint(1, 101),
        )
        self._test_write_table = "test_odps_writer_%d_%d" % (
            int(time.time()),
            random.randint(1, 101),
        )
        self._odps_client = ODPS(self._access_id, self._access_key,
                                 self._project, self._endpoint)
        create_iris_odps_table(self._odps_client, self._project,
                               self._test_read_table)

    def test_parallel_read(self):
        def transform(record):
            return float(record[0]) + 1

        start = 0
        end = 100
        shard_size = (end - start) // 4

        pd = ODPSReader(
            access_id=self._access_id,
            access_key=self._access_key,
            project=self._project,
            endpoint=self._endpoint,
            table=self._test_read_table,
            num_processes=2,
            transform_fn=transform,
        )

        results = []
        pd.reset((start, end - start), shard_size)
        shard_count = pd.get_shards_count()
        for i in range(shard_count):
            records = pd.get_records()
            for record in records:
                results.append(record)
        pd.stop()

        self.assertEqual(len(results), 100)

    def test_write_from_iterator(self):
        columns = ["num", "num2"]
        column_types = ["bigint", "double"]

        # If the table doesn't exist yet
        writer = ODPSWriter(
            self._project,
            self._access_id,
            self._access_key,
            self._endpoint,
            self._test_write_table,
            columns,
            column_types,
        )
        writer.from_iterator(iter([[1, 0.5], [2, 0.6]]), 2)
        table = self._odps_client.get_table(self._test_write_table,
                                            self._project)
        self.assertEqual(table.schema.names, columns)
        self.assertEqual(table.schema.types, column_types)
        self.assertEqual(table.to_df().count(), 1)

        # If the table already exists
        writer = ODPSWriter(
            self._project,
            self._access_id,
            self._access_key,
            self._endpoint,
            self._test_write_table,
        )
        writer.from_iterator(iter([[1, 0.5], [2, 0.6]]), 2)
        table = self._odps_client.get_table(self._test_write_table,
                                            self._project)
        self.assertEqual(table.schema.names, columns)
        self.assertEqual(table.schema.types, column_types)
        self.assertEqual(table.to_df().count(), 2)

    def tearDown(self):
        self._odps_client.delete_table(self._test_write_table,
                                       self._project,
                                       if_exists=True)
        self._odps_client.delete_table(self._test_read_table,
                                       self._project,
                                       if_exists=True)