def create_data_reader(data_origin, records_per_task=None, **kwargs):
    """Create a data reader to read records.

    Args:
        data_origin: The origin of the data, e.g. the location of files,
            a table name in the database, etc.
        records_per_task: The number of records used to create a task.
        kwargs: Data reader params; the supported keys are
            "columns", "partition" and "reader_type".
    """
    reader_type = kwargs.get("reader_type", None)
    if reader_type is None:
        # No explicit reader type: infer it from the environment and
        # the data origin.
        if is_odps_configured():
            return ODPSDataReader(
                project=os.environ[MaxComputeConfig.PROJECT_NAME],
                access_id=os.environ[MaxComputeConfig.ACCESS_ID],
                access_key=os.environ[MaxComputeConfig.ACCESS_KEY],
                table=data_origin,
                endpoint=os.environ.get(MaxComputeConfig.ENDPOINT),
                tunnel_endpoint=os.environ.get(
                    MaxComputeConfig.TUNNEL_ENDPOINT, None
                ),
                records_per_task=records_per_task,
                **kwargs,
            )
        elif data_origin and data_origin.endswith(".csv"):
            return CSVDataReader(data_dir=data_origin, **kwargs)
        else:
            return RecordIODataReader(data_dir=data_origin)
    elif reader_type == ReaderType.CSV_READER:
        return CSVDataReader(data_dir=data_origin, **kwargs)
    elif reader_type == ReaderType.ODPS_READER:
        if not is_odps_configured():
            raise ValueError(
                "MAXCOMPUTE_AK, MAXCOMPUTE_SK and MAXCOMPUTE_PROJECT "
                "must be configured in envs"
            )
        return ODPSDataReader(
            project=os.environ[MaxComputeConfig.PROJECT_NAME],
            access_id=os.environ[MaxComputeConfig.ACCESS_ID],
            access_key=os.environ[MaxComputeConfig.ACCESS_KEY],
            table=data_origin,
            endpoint=os.environ.get(MaxComputeConfig.ENDPOINT),
            records_per_task=records_per_task,
            **kwargs,
        )
    elif reader_type == ReaderType.RECORDIO_READER:
        return RecordIODataReader(data_dir=data_origin)
    else:
        raise ValueError(
            "The reader type {} is not supported".format(reader_type)
        )
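# A minimal usage sketch, not part of the original module: the helper name
# `make_csv_reader`, the `csv_path` argument, and the records_per_task value
# are illustrative assumptions. With reader_type omitted and no MaxCompute
# environment configured, create_data_reader falls back to CSVDataReader
# because the data origin ends with ".csv"; passing
# reader_type=ReaderType.CSV_READER would select it explicitly.
def make_csv_reader(csv_path):
    return create_data_reader(
        data_origin=csv_path,
        records_per_task=128,
        columns=[
            "sepal_length",
            "sepal_width",
            "petal_length",
            "petal_width",
            "class",
        ],
        sep=",",
    )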
def test_csv_data_reader(self):
    # Write an iris-style CSV file and read it back through a mocked task
    # that covers all of its records.
    with tempfile.TemporaryDirectory() as temp_dir_name:
        num_records = 128
        columns = [
            "sepal_length",
            "sepal_width",
            "petal_length",
            "petal_width",
            "class",
        ]
        iris_file_name = create_iris_csv_file(
            size=num_records, columns=columns, temp_dir=temp_dir_name
        )
        csv_data_reader = CSVDataReader(columns=columns, sep=",")
        task = _MockedTask(
            0, num_records, iris_file_name, elasticdl_pb2.TRAINING
        )

        def _gen():
            for record in csv_data_reader.read_records(task):
                yield record

        def _dataset_fn(dataset, mode, metadata):
            def _parse_data(record):
                # The first four columns are features; the last column
                # ("class") is the label.
                features = tf.strings.to_number(record[0:-1], tf.float32)
                label = tf.strings.to_number(record[-1], tf.float32)
                return features, label

            dataset = dataset.map(_parse_data)
            dataset = dataset.batch(10)
            return dataset

        dataset = tf.data.Dataset.from_generator(
            _gen, csv_data_reader.records_output_types
        )
        dataset = _dataset_fn(dataset, None, None)
        for features, labels in dataset:
            self.assertEqual(features.shape.as_list(), [10, 4])
            self.assertEqual(labels.shape.as_list(), [10])
            break