示例#1
0
    def test_recordio_data_reader(self):
        """Round-trip a RecordIO shard through ``RecordIODataReader``.

        Writes ``num_records`` records with ``create_recordio_file`` into a
        temporary directory, then checks that ``create_shards`` reports the
        expected single shard and that ``read_records`` yields every record,
        each decoding to length-1 ``x`` and ``y`` float features.
        """
        num_records = 128
        # The decoding spec is the same for every record, so build it once
        # instead of re-creating the dict on each loop iteration.
        feature_spec = {
            "x": tf.io.FixedLenFeature([1], tf.float32),
            "y": tf.io.FixedLenFeature([1], tf.float32),
        }
        with tempfile.TemporaryDirectory() as temp_dir_name:
            shard_name = create_recordio_file(num_records,
                                              DatasetName.TEST_MODULE,
                                              1,
                                              temp_dir=temp_dir_name)

            # Test shards creation: the directory holds exactly one shard.
            expected_shards = {shard_name: (0, num_records)}
            reader = RecordIODataReader(data_dir=temp_dir_name)
            self.assertEqual(expected_shards, reader.create_shards())

            # Test records reading
            records = list(
                reader.read_records(
                    _MockedTask(0, num_records, shard_name,
                                elasticdl_pb2.TRAINING)))
            self.assertEqual(len(records), num_records)
            for record in records:
                parsed_record = tf.io.parse_single_example(
                    record, feature_spec)
                # Only the decoded values are checked; the keys are fixed
                # by feature_spec.
                for value in parsed_record.values():
                    self.assertEqual(len(value.numpy()), 1)
def _create_odps_data_reader(data_origin, records_per_task, **kwargs):
    """Build an ``ODPSDataReader`` from the MaxCompute environment variables.

    Shared by both ODPS paths of ``create_data_reader`` (auto-detected and
    explicit ``ReaderType.ODPS_READER``) so they construct the reader
    identically, including ``tunnel_endpoint``.

    Args:
        data_origin: The table to read from.
        records_per_task: The number of records to create a task.
        kwargs: Extra parameters forwarded unchanged to ``ODPSDataReader``.

    Returns:
        A new ``ODPSDataReader``.

    Raises:
        KeyError: If the project name, access id or access key environment
            variable is unset.
    """
    return ODPSDataReader(
        project=os.environ[MaxComputeConfig.PROJECT_NAME],
        access_id=os.environ[MaxComputeConfig.ACCESS_ID],
        access_key=os.environ[MaxComputeConfig.ACCESS_KEY],
        table=data_origin,
        endpoint=os.environ.get(MaxComputeConfig.ENDPOINT),
        tunnel_endpoint=os.environ.get(MaxComputeConfig.TUNNEL_ENDPOINT),
        records_per_task=records_per_task,
        **kwargs,
    )


def create_data_reader(data_origin, records_per_task=None, **kwargs):
    """Create a data reader to read records.

    When ``reader_type`` is not given, the reader is chosen automatically:
    an ``ODPSDataReader`` if MaxCompute is configured, a ``TextDataReader``
    if ``data_origin`` ends with ``.csv``, and a ``RecordIODataReader``
    otherwise.

    Args:
        data_origin: The origin of the data, e.g. location to files,
            table name in the database, etc.
        records_per_task: The number of records to create a task. Not
            passed to ``RecordIODataReader``.
        kwargs: data reader params, the supported keys are
            "columns", "partition", "reader_type". All of them (including
            "reader_type") are forwarded to the ODPS and CSV readers.

    Returns:
        An ``ODPSDataReader``, ``TextDataReader`` or ``RecordIODataReader``.

    Raises:
        ValueError: If ``reader_type`` is ``ReaderType.ODPS_READER`` but
            MaxCompute is not configured, or if ``reader_type`` is not
            supported.
    """
    reader_type = kwargs.get("reader_type", None)
    if reader_type is None:
        if is_odps_configured():
            return _create_odps_data_reader(
                data_origin, records_per_task, **kwargs)
        if data_origin and data_origin.endswith(".csv"):
            return TextDataReader(
                filename=data_origin,
                records_per_task=records_per_task,
                **kwargs,
            )
        return RecordIODataReader(data_dir=data_origin)

    if reader_type == ReaderType.CSV_READER:
        return TextDataReader(filename=data_origin,
                              records_per_task=records_per_task,
                              **kwargs)

    if reader_type == ReaderType.ODPS_READER:
        # The predicate must be called: a bare function reference is always
        # truthy, which would make this guard dead code.
        if not is_odps_configured():
            raise ValueError(
                "MAXCOMPUTE_AK, MAXCOMPUTE_SK and MAXCOMPUTE_PROJECT "
                "must be configured in envs"
            )
        return _create_odps_data_reader(
            data_origin, records_per_task, **kwargs)

    if reader_type == ReaderType.RECORDIO_READER:
        return RecordIODataReader(data_dir=data_origin)

    raise ValueError(
        "The reader type {} is not supported".format(reader_type))
示例#3
0
 def __init__(self, **kwargs):
     """Initialize by delegating all keyword arguments to RecordIODataReader.

     NOTE(review): presumably the enclosing class subclasses
     RecordIODataReader -- confirm. The base initializer is called
     explicitly rather than via super(), so if the class has other bases
     their __init__ is not reached through this line.
     """
     # Forward every keyword argument unchanged to the base initializer.
     RecordIODataReader.__init__(self, **kwargs)