Example #1
    def setUp(self):
        self.project = os.environ[MaxComputeConfig.PROJECT_NAME]
        access_id = os.environ[MaxComputeConfig.ACCESS_ID]
        access_key = os.environ[MaxComputeConfig.ACCESS_KEY]
        endpoint = os.environ.get(MaxComputeConfig.ENDPOINT)
        tunnel_endpoint = os.environ.get(MaxComputeConfig.TUNNEL_ENDPOINT,
                                         None)
        self.test_table = "test_odps_data_reader_%d_%d" % (
            int(time.time()),
            random.randint(1, 101),
        )
        self.odps_client = ODPS(access_id, access_key, self.project, endpoint)
        create_iris_odps_table(self.odps_client, self.project, self.test_table)
        self.records_per_task = 50

        self.reader = ODPSDataReader(
            project=self.project,
            access_id=access_id,
            access_key=access_key,
            endpoint=endpoint,
            table=self.test_table,
            tunnel_endpoint=tunnel_endpoint,
            num_processes=1,
            records_per_task=self.records_per_task,
        )
Example #2
def create_data_reader(data_origin, records_per_task=None, **kwargs):
    """Create a data reader to read records
    Args:
        data_origin: The origin of the data, e.g. location to files,
            table name in the database, etc.
        records_per_task: The number of records to create a task
        kwargs: data reader params, the supported keys are
            "columns", "partition", "reader_type"
    """
    reader_type = kwargs.get("reader_type", None)
    if reader_type is None:
        if is_odps_configured():
            return ODPSDataReader(
                project=os.environ[MaxComputeConfig.PROJECT_NAME],
                access_id=os.environ[MaxComputeConfig.ACCESS_ID],
                access_key=os.environ[MaxComputeConfig.ACCESS_KEY],
                table=data_origin,
                endpoint=os.environ.get(MaxComputeConfig.ENDPOINT),
                tunnel_endpoint=os.environ.get(
                    MaxComputeConfig.TUNNEL_ENDPOINT, None),
                records_per_task=records_per_task,
                **kwargs,
            )
        elif data_origin and data_origin.endswith(".csv"):
            return TextDataReader(
                filename=data_origin,
                records_per_task=records_per_task,
                **kwargs,
            )
        else:
            return RecordIODataReader(data_dir=data_origin)
    elif reader_type == ReaderType.CSV_READER:
        return TextDataReader(filename=data_origin,
                              records_per_task=records_per_task,
                              **kwargs)
    elif reader_type == ReaderType.ODPS_READER:
        if not is_odps_configured():
            raise ValueError(
                "MAXCOMPUTE_AK, MAXCOMPUTE_SK and MAXCOMPUTE_PROJECT "
                "must be configured in envs"
            )
        return ODPSDataReader(
            project=os.environ[MaxComputeConfig.PROJECT_NAME],
            access_id=os.environ[MaxComputeConfig.ACCESS_ID],
            access_key=os.environ[MaxComputeConfig.ACCESS_KEY],
            table=data_origin,
            endpoint=os.environ.get(MaxComputeConfig.ENDPOINT),
            records_per_task=records_per_task,
            **kwargs,
        )
    elif reader_type == ReaderType.RECORDIO_READER:
        return RecordIODataReader(data_dir=data_origin)
    else:
        raise ValueError(
            "The reader type {} is not supported".format(reader_type))
Example #3
class ODPSDataReaderTest(unittest.TestCase):
    def setUp(self):
        self.project = os.environ[MaxComputeConfig.PROJECT_NAME]
        access_id = os.environ[MaxComputeConfig.ACCESS_ID]
        access_key = os.environ[MaxComputeConfig.ACCESS_KEY]
        endpoint = os.environ.get(MaxComputeConfig.ENDPOINT)
        tunnel_endpoint = os.environ.get(MaxComputeConfig.TUNNEL_ENDPOINT,
                                         None)
        self.test_table = "test_odps_data_reader_%d_%d" % (
            int(time.time()),
            random.randint(1, 101),
        )
        self.odps_client = ODPS(access_id, access_key, self.project, endpoint)
        create_iris_odps_table(self.odps_client, self.project, self.test_table)
        self.records_per_task = 50

        self.reader = ODPSDataReader(
            project=self.project,
            access_id=access_id,
            access_key=access_key,
            endpoint=endpoint,
            table=self.test_table,
            tunnel_endpoint=tunnel_endpoint,
            num_processes=1,
            records_per_task=self.records_per_task,
        )

    def test_odps_data_reader_shards_creation(self):
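        # The iris test table created in setUp holds 110 rows, so with
        # records_per_task=50 create_shards() is expected to return two full
        # 50-record shards plus a trailing 10-record shard.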
        expected_shards = {
            self.test_table + ":shard_0": (0, self.records_per_task),
            self.test_table + ":shard_1": (50, self.records_per_task),
            self.test_table + ":shard_2": (100, 10),
        }
        self.assertEqual(expected_shards, self.reader.create_shards())

    def test_odps_data_reader_records_reading(self):
        records = list(
            self.reader.read_records(
                _MockedTask(0, 2, self.test_table + ":shard_0",
                            elasticdl_pb2.TRAINING)))
        records = np.array(records, dtype="float").tolist()
        self.assertEqual([[6.4, 2.8, 5.6, 2.2, 2], [5.0, 2.3, 3.3, 1.0, 1]],
                         records)
        self.assertEqual(self.reader.metadata.column_names,
                         IRIS_TABLE_COLUMN_NAMES)
        self.assertListEqual(
            list(self.reader.metadata.column_dtypes.values()),
            [
                odps.types.double,
                odps.types.double,
                odps.types.double,
                odps.types.double,
                odps.types.bigint,
            ],
        )
        self.assertEqual(
            self.reader.metadata.get_tf_dtype_from_maxcompute_column(
                self.reader.metadata.column_names[0]),
            tf.float64,
        )
        self.assertEqual(
            self.reader.metadata.get_tf_dtype_from_maxcompute_column(
                self.reader.metadata.column_names[-1]),
            tf.int64,
        )

    def test_create_data_reader(self):
        reader = create_data_reader(
            data_origin=self.test_table,
            records_per_task=10,
            columns=["sepal_length", "sepal_width"],
            label_col="class",
        )
        self.assertEqual(reader._kwargs["columns"],
                         ["sepal_length", "sepal_width"])
        self.assertEqual(reader._kwargs["label_col"], "class")
        self.assertEqual(reader._kwargs["records_per_task"], 10)
        reader = create_data_reader(data_origin=self.test_table,
                                    records_per_task=10)
        self.assertEqual(reader._kwargs["records_per_task"], 10)
        self.assertTrue("columns" not in reader._kwargs)

    def test_odps_data_reader_integration_with_local_keras(self):
        num_records = 2
        model_spec = load_module(
            os.path.join(
                os.path.dirname(os.path.realpath(__file__)),
                "../../../model_zoo",
                "odps_iris_dnn_model/odps_iris_dnn_model.py",
            )).__dict__
        model = model_spec["custom_model"]()
        optimizer = model_spec["optimizer"]()
        loss = model_spec["loss"]
        reader = create_data_reader(
            data_origin=self.test_table,
            records_per_task=10,
            columns=IRIS_TABLE_COLUMN_NAMES,
            label_col="class",
        )
        dataset_fn = reader.default_dataset_fn()

        def _gen():
            for data in self.reader.read_records(
                    _MockedTask(
                        0,
                        num_records,
                        self.test_table + ":shard_0",
                        elasticdl_pb2.TRAINING,
                    )):
                if data is not None:
                    yield data

        dataset = tf.data.Dataset.from_generator(_gen, tf.string)
        dataset = dataset_fn(dataset, None,
                             Metadata(column_names=IRIS_TABLE_COLUMN_NAMES))
        dataset = dataset.batch(1)

        loss_history = []
        grads = None
        for features, labels in dataset:
            with tf.GradientTape() as tape:
                logits = model(features, training=True)
                loss_value = loss(labels, logits)
            loss_history.append(loss_value.numpy())
            grads = tape.gradient(loss_value, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

        self.assertEqual(len(loss_history), num_records)
        self.assertEqual(len(grads), num_records)
        self.assertEqual(len(model.trainable_variables), num_records)

    def tearDown(self):
        self.odps_client.delete_table(self.test_table,
                                      self.project,
                                      if_exists=True)