Example #1
 def parallel_record_records(self, task, num_processes, shard_size,
                             transform_fn):
     check_required_kwargs(["project", "access_id", "access_key"],
                           self._kwargs)
     start = task.start
     end = task.end
     # The shard name encodes "<project>.<table>"; keep only the table part.
     table = self._get_odps_table_name(task.shard_name)
     table = table.split(".")[1]
     project = self._kwargs["project"]
     access_id = self._kwargs["access_id"]
     access_key = self._kwargs["access_key"]
     endpoint = self._kwargs.get("endpoint")
     partition = self._kwargs.get("partition", None)
     columns = self._kwargs.get("columns", None)
     reader = ODPSReader(
         access_id=access_id,
         access_key=access_key,
         project=project,
         endpoint=endpoint,
         table=table,
         partition=partition,
         num_processes=num_processes,
         transform_fn=transform_fn,
         columns=columns,
     )
     # Restrict the reader to this task's record range, then stream the
     # records shard by shard.
     reader.reset((start, end - start), shard_size)
     shard_count = reader.get_shards_count()
     for _ in range(shard_count):
         records = reader.get_records()
         for record in records:
             yield record
     reader.stop()
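
Every example on this page gates on check_required_kwargs before touching self._kwargs. The helper itself is not shown here; below is a minimal sketch of the behavior these call sites rely on, assuming it simply raises when a required key is absent:

 def check_required_kwargs(required_args, kwargs):
     # Hypothetical sketch; the real helper is not shown in these examples.
     # Fail fast with one readable message listing every missing key.
     missing = [k for k in required_args if k not in kwargs]
     if missing:
         raise ValueError(
             "The following required arguments are missing: %s"
             % ", ".join(missing)
         )
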
Example #2
 def __init__(self, **kwargs):
     """
     Args:
         kwargs should contains "sep" and "columns" like
         'sep=",",column=["sepal.length", "sepal.width", "variety"]'
     """
     AbstractDataReader.__init__(self, **kwargs)
     check_required_kwargs(["sep", "columns"], kwargs)
     self.sep = kwargs.get("sep", ",")
     self.selected_columns = kwargs.get("columns", None)
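
A reader of this kind might be constructed as follows; the class name CSVDataReader is an assumption, since the example only shows the __init__ body:

 # Hypothetical usage; the class name is assumed, not shown above.
 reader = CSVDataReader(
     sep=",",
     columns=["sepal.length", "sepal.width", "variety"],
 )
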
Example #3
 def _get_reader(self, table_name):
     check_required_kwargs(["project", "access_id", "access_key"],
                           self._kwargs)
     return ODPSReader(
         project=self._kwargs["project"],
         access_id=self._kwargs["access_id"],
         access_key=self._kwargs["access_key"],
         table=table_name,
         endpoint=self._kwargs.get("endpoint"),
         partition=self._kwargs.get("partition", None),
         num_processes=self._kwargs.get("num_processes", 1),
     )
Example #4
    def default_dataset_fn(self):
        check_required_kwargs(["label_col"], self._kwargs)

        def dataset_fn(dataset, mode, metadata):
            def _parse_data(record):
                label_col_name = self._kwargs["label_col"]
                # Each record arrives as a string tensor; parse it to floats.
                record = tf.strings.to_number(record, tf.float32)

                def _get_features_without_labels(
                    record, label_col_idx, features_shape
                ):
                    features = [
                        record[:label_col_idx],
                        record[label_col_idx + 1 :],  # noqa: E203
                    ]
                    features = tf.concat(features, -1)
                    return tf.reshape(features, features_shape)

                features_shape = (len(metadata.column_names) - 1, 1)
                labels_shape = (1,)
                if mode == Mode.PREDICTION:
                    if label_col_name in metadata.column_names:
                        label_col_idx = metadata.column_names.index(
                            label_col_name
                        )
                        return _get_features_without_labels(
                            record, label_col_idx, features_shape
                        )
                    else:
                        return tf.reshape(record, features_shape)
                else:
                    if label_col_name not in metadata.column_names:
                        raise ValueError(
                            "Missing the label column '%s' in the retrieved "
                            "ODPS table during %s mode."
                            % (label_col_name, mode)
                        )
                    label_col_idx = metadata.column_names.index(label_col_name)
                    labels = tf.reshape(record[label_col_idx], labels_shape)
                    return (
                        _get_features_without_labels(
                            record, label_col_idx, features_shape
                        ),
                        labels,
                    )

            dataset = dataset.map(_parse_data)

            if mode == Mode.TRAINING:
                dataset = dataset.shuffle(buffer_size=200)
            return dataset

        return dataset_fn
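
The heart of _parse_data is the slice-and-concat in _get_features_without_labels, which drops the label column and keeps every other value. The same logic in plain Python, with made-up values for a four-column record whose label sits at index 2:

 record = [5.1, 3.5, 1.4, 0.2]  # hypothetical parsed row
 label_col_idx = 2
 features = record[:label_col_idx] + record[label_col_idx + 1:]
 label = record[label_col_idx]
 # features == [5.1, 3.5, 0.2], label == 1.4
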
Example #5
    def _init_reader(self, table_name, task_type):
        if (table_name in self._table_readers
                and task_type in self._table_readers[table_name]):
            return

        self._table_readers.setdefault(table_name, {})

        check_required_kwargs(["project", "access_id", "access_key"],
                              self._kwargs)
        reader = self.get_odps_reader(table_name)

        # Tasks with the same table but different types must not share a
        # reader; reusing one can cause subtle errors.
        self._table_readers[table_name][task_type] = reader
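
After a few calls the cache is a two-level dict, keyed first by table name and then by task type; its shape is roughly the following (table and type names are made up):

 # Illustrative shape only; keys are hypothetical.
 # self._table_readers == {
 #     "iris_table": {
 #         "training": <ODPSReader>,
 #         "evaluation": <ODPSReader>,
 #     },
 # }
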
Example #6
 def create_shards(self):
     check_required_kwargs(["table", "records_per_task"], self._kwargs)
     reader = self._get_reader(self._kwargs["table"])
     shard_name_prefix = self._kwargs["table"] + ":shard_"
     table_size = reader.get_table_size()
     records_per_task = self._kwargs["records_per_task"]
     shards = {}
     num_shards = table_size // records_per_task
     start_ind = 0
     for shard_id in range(num_shards):
         shards[shard_name_prefix + str(shard_id)] = (
             start_ind,
             records_per_task,
         )
         start_ind += records_per_task
     num_records_left = table_size % records_per_task
     if num_records_left != 0:
         shards[shard_name_prefix + str(num_shards)] = (
             start_ind,
             num_records_left,
         )
     return shards
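
To make the shard arithmetic concrete, suppose table_size is 10 and records_per_task is 3 for a table named "iris" (all values hypothetical):

 table_size, records_per_task = 10, 3
 # 10 // 3 == 3 full shards and 10 % 3 == 1 leftover record, giving:
 # {
 #     "iris:shard_0": (0, 3),
 #     "iris:shard_1": (3, 3),
 #     "iris:shard_2": (6, 3),
 #     "iris:shard_3": (9, 1),
 # }
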
Example #7
 def create_shards(self):
     check_required_kwargs(["table", "records_per_task"], self._kwargs)
     table_name = self._kwargs["table"]
     reader = self.get_odps_reader(table_name)
     table_size = reader.get_table_size()
     records_per_task = self._kwargs["records_per_task"]
     shards = []
     num_shards = table_size // records_per_task
     start_ind = 0
     for shard_id in range(num_shards):
         shards.append((
             table_name,
             start_ind,
             records_per_task,
         ))
         start_ind += records_per_task
     num_records_left = table_size % records_per_task
     if num_records_left != 0:
         shards.append((
             table_name,
             start_ind,
             num_records_left,
         ))
     return shards
Example #8
 def __init__(self, **kwargs):
     AbstractDataReader.__init__(self, **kwargs)
     self._kwargs = kwargs
     check_required_kwargs(["data_dir"], self._kwargs)
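
Given the sketch of check_required_kwargs above, constructing a reader of this kind without data_dir fails immediately; the class name is again hypothetical:

 try:
     reader = ImageDataReader()  # the required data_dir is missing
 except ValueError as err:
     print(err)  # e.g. "The following required arguments are missing: data_dir"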