def test_recordio_data_reader(self):
    num_records = 128
    with tempfile.TemporaryDirectory() as temp_dir_name:
        shard_name = create_recordio_file(
            num_records,
            DatasetName.TEST_MODULE,
            1,
            temp_dir=temp_dir_name,
        )

        # Test shards creation
        expected_shards = {shard_name: (0, num_records)}
        reader = RecordIODataReader(data_dir=temp_dir_name)
        self.assertEqual(expected_shards, reader.create_shards())

        # Test records reading
        records = list(
            reader.read_records(
                _MockedTask(
                    0, num_records, shard_name, elasticdl_pb2.TRAINING
                )
            )
        )
        self.assertEqual(len(records), num_records)
        for record in records:
            parsed_record = tf.io.parse_single_example(
                record,
                {
                    "x": tf.io.FixedLenFeature([1], tf.float32),
                    "y": tf.io.FixedLenFeature([1], tf.float32),
                },
            )
            for k, v in parsed_record.items():
                self.assertEqual(len(v.numpy()), 1)
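# Illustrative sketch (not from the original test): each record written by
# create_recordio_file is assumed to be a serialized tf.train.Example with
# scalar float features "x" and "y", which is why the test parses it with
# FixedLenFeature([1], tf.float32). Building and parsing one such record by
# hand would look roughly like this:
#
#     example = tf.train.Example(
#         features=tf.train.Features(
#             feature={
#                 "x": tf.train.Feature(
#                     float_list=tf.train.FloatList(value=[0.5])
#                 ),
#                 "y": tf.train.Feature(
#                     float_list=tf.train.FloatList(value=[1.0])
#                 ),
#             }
#         )
#     )
#     parsed = tf.io.parse_single_example(
#         example.SerializeToString(),
#         {
#             "x": tf.io.FixedLenFeature([1], tf.float32),
#             "y": tf.io.FixedLenFeature([1], tf.float32),
#         },
#     )
#     assert parsed["x"].numpy().shape == (1,)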
def create_data_reader(data_origin, records_per_task=None, **kwargs):
    """Create a data reader to read records.

    Args:
        data_origin: The origin of the data, e.g. the location of files,
            the table name in the database, etc.
        records_per_task: The number of records to create a task.
        kwargs: Data reader parameters; the supported keys are
            "columns", "partition" and "reader_type".
    """
    reader_type = kwargs.get("reader_type", None)
    if reader_type is None:
        if is_odps_configured():
            return ODPSDataReader(
                project=os.environ[MaxComputeConfig.PROJECT_NAME],
                access_id=os.environ[MaxComputeConfig.ACCESS_ID],
                access_key=os.environ[MaxComputeConfig.ACCESS_KEY],
                table=data_origin,
                endpoint=os.environ.get(MaxComputeConfig.ENDPOINT),
                tunnel_endpoint=os.environ.get(
                    MaxComputeConfig.TUNNEL_ENDPOINT, None
                ),
                records_per_task=records_per_task,
                **kwargs,
            )
        elif data_origin and data_origin.endswith(".csv"):
            return TextDataReader(
                filename=data_origin,
                records_per_task=records_per_task,
                **kwargs,
            )
        else:
            return RecordIODataReader(data_dir=data_origin)
    elif reader_type == ReaderType.CSV_READER:
        return TextDataReader(
            filename=data_origin,
            records_per_task=records_per_task,
            **kwargs,
        )
    elif reader_type == ReaderType.ODPS_READER:
        if not is_odps_configured():
            raise ValueError(
                "MAXCOMPUTE_AK, MAXCOMPUTE_SK and MAXCOMPUTE_PROJECT "
                "must be configured in envs"
            )
        return ODPSDataReader(
            project=os.environ[MaxComputeConfig.PROJECT_NAME],
            access_id=os.environ[MaxComputeConfig.ACCESS_ID],
            access_key=os.environ[MaxComputeConfig.ACCESS_KEY],
            table=data_origin,
            endpoint=os.environ.get(MaxComputeConfig.ENDPOINT),
            records_per_task=records_per_task,
            **kwargs,
        )
    elif reader_type == ReaderType.RECORDIO_READER:
        return RecordIODataReader(data_dir=data_origin)
    else:
        raise ValueError(
            "The reader type {} is not supported".format(reader_type)
        )
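# Usage sketch (illustrative only, not part of the original module): the path
# "/tmp/iris.csv" and the records_per_task value below are made up. With no
# MaxCompute environment variables set and no explicit reader_type,
# create_data_reader dispatches on the ".csv" suffix and returns a
# TextDataReader; passing reader_type=ReaderType.RECORDIO_READER (or a
# directory of RecordIO files as data_origin) would yield a RecordIODataReader
# instead.
#
#     reader = create_data_reader("/tmp/iris.csv", records_per_task=200)
#     print(type(reader).__name__)  # expected: TextDataReader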
def __init__(self, **kwargs):
    RecordIODataReader.__init__(self, **kwargs)