import os

# MaxComputeConfig, ReaderType, is_odps_configured and the reader classes
# (ODPSDataReader, CSVDataReader, RecordIODataReader) are provided by the
# surrounding elasticdl data-reader modules.


def create_data_reader(data_origin, records_per_task=None, **kwargs):
    """Create a data reader to read records.

    Args:
        data_origin: The origin of the data, e.g. the location of files,
            a table name in the database, etc.
        records_per_task: The number of records to assign to each task.
        **kwargs: Extra data reader parameters; the supported keys are
            "columns", "partition" and "reader_type".
    """
    reader_type = kwargs.get("reader_type", None)
    if reader_type is None:
        if is_odps_configured():
            return ODPSDataReader(
                project=os.environ[MaxComputeConfig.PROJECT_NAME],
                access_id=os.environ[MaxComputeConfig.ACCESS_ID],
                access_key=os.environ[MaxComputeConfig.ACCESS_KEY],
                table=data_origin,
                endpoint=os.environ.get(MaxComputeConfig.ENDPOINT),
                tunnel_endpoint=os.environ.get(
                    MaxComputeConfig.TUNNEL_ENDPOINT, None),
                records_per_task=records_per_task,
                **kwargs,
            )
        elif data_origin and data_origin.endswith(".csv"):
            return CSVDataReader(data_dir=data_origin, **kwargs)
        else:
            return RecordIODataReader(data_dir=data_origin)
    elif reader_type == ReaderType.CSV_READER:
        return CSVDataReader(data_dir=data_origin, **kwargs)
    elif reader_type == ReaderType.ODPS_READER:
        if not is_odps_configured():
            raise ValueError(
                "MAXCOMPUTE_AK, MAXCOMPUTE_SK and MAXCOMPUTE_PROJECT "
                "must be set as environment variables"
            )
        return ODPSDataReader(
            project=os.environ[MaxComputeConfig.PROJECT_NAME],
            access_id=os.environ[MaxComputeConfig.ACCESS_ID],
            access_key=os.environ[MaxComputeConfig.ACCESS_KEY],
            table=data_origin,
            endpoint=os.environ.get(MaxComputeConfig.ENDPOINT),
            records_per_task=records_per_task,
            **kwargs,
        )
    elif reader_type == ReaderType.RECORDIO_READER:
        return RecordIODataReader(data_dir=data_origin)
    else:
        raise ValueError(
            "The reader type {} is not supported".format(reader_type))
Example #2

    # This test method belongs to a unittest.TestCase class and requires
    # tempfile, tensorflow (as tf) and elasticdl_pb2, plus CSVDataReader,
    # _MockedTask and create_iris_csv_file from the elasticdl test utilities.
    def test_csv_data_reader(self):
        with tempfile.TemporaryDirectory() as temp_dir_name:
            num_records = 128
            columns = [
                "sepal_length",
                "sepal_width",
                "petal_length",
                "petal_width",
                "class",
            ]
            iris_file_name = create_iris_csv_file(
                size=num_records, columns=columns, temp_dir=temp_dir_name
            )
            csv_data_reader = CSVDataReader(columns=columns, sep=",")
            task = _MockedTask(
                0, num_records, iris_file_name, elasticdl_pb2.TRAINING
            )

            def _gen():
                for record in csv_data_reader.read_records(task):
                    yield record

            def _dataset_fn(dataset, mode, metadata):
                def _parse_data(record):
                    features = tf.strings.to_number(record[0:-1], tf.float32)
                    label = tf.strings.to_number(record[-1], tf.float32)
                    return features, label

                dataset = dataset.map(_parse_data)
                dataset = dataset.batch(10)
                return dataset

            dataset = tf.data.Dataset.from_generator(
                _gen, csv_data_reader.records_output_types
            )
            dataset = _dataset_fn(dataset, None, None)
            for features, labels in dataset:
                self.assertEqual(features.shape.as_list(), [10, 4])
                self.assertEqual(labels.shape.as_list(), [10])
                break
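# `create_iris_csv_file` above is a helper from the surrounding test suite
# whose implementation is not shown here. A minimal sketch of what such a
# helper might look like, assuming it writes `size` header-less rows of
# random iris-style values and returns the file path (assumed behavior,
# not the project's actual code):
import csv
import os
import random


def create_iris_csv_file(size, columns, temp_dir):
    """Write `size` rows of random iris-like records to a CSV file."""
    file_path = os.path.join(temp_dir, "iris.csv")
    with open(file_path, "w", newline="") as f:
        writer = csv.writer(f)
        for _ in range(size):
            # Four float features followed by an integer class label.
            features = [round(random.uniform(0.1, 8.0), 1)
                        for _ in range(len(columns) - 1)]
            writer.writerow(features + [random.randint(0, 2)])
    return file_path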