    def _make_portal_worker(self):
        portal_worker_options = dp_pb.DataPortalWorkerOptions(
            raw_data_options=dj_pb.RawDataOptions(raw_data_iter="TF_RECORD",
                                                  compressed_type=''),
            writer_options=dj_pb.WriterOptions(output_writer="TF_RECORD"),
            batch_processor_options=dj_pb.BatchProcessorOptions(
                batch_size=128, max_flying_item=300000),
            merge_buffer_size=4096,
            merger_read_ahead_size=1000000)

        self._portal_worker = DataPortalWorker(portal_worker_options,
                                               "localhost:5005", 0,
                                               "test_portal_worker_0",
                                               "portal_worker_0",
                                               "localhost:2379", True)
Example #2
    def _make_portal_worker(self):
        portal_worker_options = dp_pb.DataPortalWorkerOptions(
            raw_data_options=dj_pb.RawDataOptions(raw_data_iter="TF_RECORD",
                                                  read_ahead_size=1 << 20,
                                                  read_batch_size=128),
            writer_options=dj_pb.WriterOptions(output_writer="TF_RECORD"),
            batch_processor_options=dj_pb.BatchProcessorOptions(
                batch_size=128, max_flying_item=300000),
            merger_read_ahead_size=1000000,
            merger_read_batch_size=128)

        self._portal_worker = DataPortalWorker(portal_worker_options,
                                               "localhost:5005", 0,
                                               "test_portal_worker_0",
                                               "portal_worker_0",
                                               "localhost:2379", "test_user",
                                               "test_password", True)
Example #3
    def _make_portal_worker(self, raw_data_iter, validation_ratio):
        portal_worker_options = dp_pb.DataPortalWorkerOptions(
            raw_data_options=dj_pb.RawDataOptions(
                raw_data_iter=raw_data_iter,
                read_ahead_size=1 << 20,
                read_batch_size=128,
                optional_fields=['label'],
                validation_ratio=validation_ratio,
            ),
            writer_options=dj_pb.WriterOptions(output_writer="TF_RECORD"),
            batch_processor_options=dj_pb.BatchProcessorOptions(
                batch_size=128, max_flying_item=300000),
            merger_read_ahead_size=1000000,
            merger_read_batch_size=128)

        os.environ['ETCD_BASE_DIR'] = "portal_worker_0"
        self._portal_worker = DataPortalWorker(portal_worker_options,
                                               "localhost:5005", 0, "etcd",
                                               True)
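
Example #3 parameterizes the iterator type and adds optional_fields and validation_ratio to the reader options. The constructor is slimmer here: the kvstore base directory moves into the ETCD_BASE_DIR environment variable, leaving only the master address, rank id, kvstore type, and mock flag (compare the kvstore_type-driven call in Example #4).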
Example #4
    set_logger()
    if args.input_data_file_iter == 'TF_RECORD' or \
            args.output_builder == 'TF_RECORD':
        import tensorflow
        tensorflow.compat.v1.enable_eager_execution()

    optional_fields = [
        field for field in map(str.strip, args.optional_fields.split(','))
        if field
    ]

    portal_worker_options = dp_pb.DataPortalWorkerOptions(
        raw_data_options=dj_pb.RawDataOptions(
            raw_data_iter=args.input_data_file_iter,
            compressed_type=args.compressed_type,
            read_ahead_size=args.read_ahead_size,
            read_batch_size=args.read_batch_size,
            optional_fields=optional_fields),
        writer_options=dj_pb.WriterOptions(
            output_writer=args.output_builder,
            compressed_type=args.builder_compressed_type),
        batch_processor_options=dj_pb.BatchProcessorOptions(
            batch_size=args.batch_size, max_flying_item=-1),
        merger_read_ahead_size=args.merger_read_ahead_size,
        merger_read_batch_size=args.merger_read_batch_size,
        memory_limit_ratio=args.memory_limit_ratio / 100)
    data_portal_worker = DataPortalWorker(portal_worker_options,
                                          args.master_addr, args.rank_id,
                                          args.kvstore_type,
                                          (args.kvstore_type == 'mock'))
    data_portal_worker.start()
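
Example #4 reads its configuration from command-line flags that are defined outside this excerpt. A minimal sketch of the parser it assumes, using only the flag names referenced above (types, defaults, and help strings here are illustrative guesses, not the project's actual values):

import argparse

parser = argparse.ArgumentParser(description='DataPortalWorker sketch')
parser.add_argument('--master_addr', type=str, required=True,
                    help='address of the data portal master')
parser.add_argument('--rank_id', type=int, default=0,
                    help='rank id of this worker')
parser.add_argument('--kvstore_type', type=str, default='etcd',
                    help="kvstore backend; 'mock' enables the mock kvstore")
parser.add_argument('--input_data_file_iter', type=str, default='TF_RECORD',
                    help='iterator type for the input raw data')
parser.add_argument('--compressed_type', type=str, default='',
                    help='compression type of the input files')
parser.add_argument('--read_ahead_size', type=int, default=1 << 20,
                    help='read-ahead buffer size of the raw data reader')
parser.add_argument('--read_batch_size', type=int, default=128,
                    help='batch size of the raw data reader')
parser.add_argument('--optional_fields', type=str, default='',
                    help='comma-separated optional fields to extract')
parser.add_argument('--output_builder', type=str, default='TF_RECORD',
                    help='builder type for the output files')
parser.add_argument('--builder_compressed_type', type=str, default='',
                    help='compression type of the output files')
parser.add_argument('--batch_size', type=int, default=1024,
                    help='batch size of the batch processor')
parser.add_argument('--merger_read_ahead_size', type=int, default=1000000,
                    help='read-ahead size of the merger')
parser.add_argument('--merger_read_batch_size', type=int, default=128,
                    help='read batch size of the merger')
parser.add_argument('--memory_limit_ratio', type=int, default=70,
                    help='memory limit as a percentage (divided by 100 above)')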
Example #5
# Imports assumed for this test module (module paths follow the fedlearner
# repository layout and may differ across versions):
import os
import random
import logging
import unittest

import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import gfile
from cityhash import CityHash32

from fedlearner.common import data_portal_service_pb2 as dp_pb
from fedlearner.common import data_join_service_pb2 as dj_pb
from fedlearner.data_join import common
from fedlearner.data_join.data_portal_worker import DataPortalWorker
from fedlearner.data_join.raw_data_iter_impl.tf_record_iter import \
    TfExampleItem

class TestDataPortalWorker(unittest.TestCase):
    def _get_input_fpath(self, partition_id):
        return "{}/raw_data_partition_{}".format(self._input_dir, partition_id)

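    # Writes one partition of TFRecord examples with strictly increasing
    # example_id/raw_id and a random event_time; returns the last id used.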
    def _generate_one_partition(self, partition_id, example_id, num_examples):
        fpath = self._get_input_fpath(partition_id)
        with tf.io.TFRecordWriter(fpath) as writer:
            for i in range(num_examples):
                example_id += random.randint(1, 5)
                event_time = 150000000 + random.randint(10000000, 20000000)
                feat = {}
                feat['example_id'] = tf.train.Feature(
                    bytes_list=tf.train.BytesList(
                        value=[str(example_id).encode('utf-8')]))
                feat['raw_id'] = tf.train.Feature(
                    bytes_list=tf.train.BytesList(
                        value=[str(example_id).encode('utf-8')]))
                feat['event_time'] = tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[event_time]))
                example = tf.train.Example(features=tf.train.Features(
                    feature=feat))
                writer.write(example.SerializeToString())
        return example_id

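    # Generates self._input_partition_num TFRecord partitions plus a
    # _SUCCESS flag file under the input directory.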
    def _generate_input_data(self):
        self._partition_item_num = 1 << 16
        self._clean_up()
        gfile.MakeDirs(self._input_dir)
        success_flag_fpath = "{}/_SUCCESS".format(self._input_dir)
        example_id = 1000001
        for partition_id in range(self._input_partition_num):
            example_id = self._generate_one_partition(partition_id, example_id,
                                                      self._partition_item_num)

        with gfile.GFile(success_flag_fpath, 'w') as fh:
            fh.write('')

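    # Builds a DataPortalWorker wired to a local master; the trailing True
    # presumably enables the mock kvstore backend.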
    def _make_portal_worker(self):
        portal_worker_options = dp_pb.DataPortalWorkerOptions(
            raw_data_options=dj_pb.RawDataOptions(raw_data_iter="TF_RECORD",
                                                  read_ahead_size=1 << 20,
                                                  read_batch_size=128),
            writer_options=dj_pb.WriterOptions(output_writer="TF_RECORD"),
            batch_processor_options=dj_pb.BatchProcessorOptions(
                batch_size=128, max_flying_item=300000),
            merger_read_ahead_size=1000000,
            merger_read_batch_size=128)

        self._portal_worker = DataPortalWorker(portal_worker_options,
                                               "localhost:5005", 0,
                                               "test_portal_worker_0",
                                               "portal_worker_0",
                                               "localhost:2379", "test_user",
                                               "test_password", True)

    def _clean_up(self):
        if gfile.Exists(self._input_dir):
            gfile.DeleteRecursively(self._input_dir)
        if gfile.Exists(self._partition_output_dir):
            gfile.DeleteRecursively(self._partition_output_dir)
        if gfile.Exists(self._merge_output_dir):
            gfile.DeleteRecursively(self._merge_output_dir)

    def _prepare_test(self):
        self._input_dir = './portal_worker_input'
        self._partition_output_dir = './portal_worker_partition_output'
        self._merge_output_dir = './portal_worker_merge_output'
        self._input_partition_num = 4
        self._output_partition_num = 2
        self._generate_input_data()
        self._make_portal_worker()

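    # Verifies the partitioner output: every record lands in the partition
    # chosen by CityHash32(raw_id) % output_partition_num, records within a
    # segment are ordered by event_time, and no record is lost.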
    def _check_partitioner(self, map_task):
        output_partitions = gfile.ListDirectory(map_task.output_base_dir)
        output_partitions = [
            x for x in output_partitions if "SUCCESS" not in x
        ]
        self.assertEqual(len(output_partitions), map_task.output_partition_num)

        total_cnt = 0
        for partition in output_partitions:
            dpath = "{}/{}".format(map_task.output_base_dir, partition)
            partition_id = partition.split("_")[-1]
            partition_id = int(partition_id)
            segments = gfile.ListDirectory(dpath)
            for segment in segments:
                fpath = "{}/{}".format(dpath, segment)
                event_time = 0
                for record in tf.python_io.tf_record_iterator(fpath):
                    tf_item = TfExampleItem(record)
                    self.assertTrue(
                        tf_item.event_time >= event_time,
                        "{}, {}".format(tf_item.event_time, event_time))
                    event_time = tf_item.event_time  # track ordering watermark
                    self.assertEqual(partition_id, CityHash32(tf_item.raw_id) \
                        % map_task.output_partition_num)
                    total_cnt += 1
        self.assertEqual(total_cnt,
                         self._partition_item_num * self._input_partition_num)

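    # Verifies that one merged partition is globally ordered by event_time
    # and returns its record count.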
    def _check_merge(self, reduce_task):
        dpath = os.path.join(self._merge_output_dir, \
            common.partition_repr(reduce_task.partition_id))
        fpaths = gfile.ListDirectory(dpath)
        fpaths = sorted(fpaths)
        event_time = 0
        total_cnt = 0
        for fpath in fpaths:
            fpath = os.path.join(dpath, fpath)
            logging.info("check merge path:{}".format(fpath))
            for record in tf.python_io.tf_record_iterator(fpath):
                tf_item = TfExampleItem(record)
                self.assertTrue(tf_item.event_time >= event_time)
                event_time = tf_item.event_time
                total_cnt += 1
        return total_cnt

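    # End-to-end check: run the map (partition) task, validate its output,
    # then run one reduce (merge) task per output partition.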
    def test_portal_worker(self):
        self._prepare_test()
        map_task = dp_pb.MapTask()
        map_task.output_base_dir = self._partition_output_dir
        map_task.output_partition_num = self._output_partition_num
        map_task.partition_id = 0
        map_task.task_name = 'map_part_{}'.format(map_task.partition_id)
        map_task.part_field = 'example_id'
        map_task.data_portal_type = dp_pb.DataPortalType.Streaming
        for partition_id in range(self._input_partition_num):
            map_task.fpaths.append(self._get_input_fpath(partition_id))

        # partitioner
        task = dp_pb.NewTaskResponse()
        task.map_task.CopyFrom(map_task)
        self._portal_worker._run_map_task(task.map_task)

        self._check_partitioner(task.map_task)

        # merge
        total_cnt = 0
        for partition_id in range(self._output_partition_num):
            reduce_task = dp_pb.ReduceTask()
            reduce_task.map_base_dir = self._partition_output_dir
            reduce_task.reduce_base_dir = self._merge_output_dir
            reduce_task.partition_id = partition_id
            reduce_task.task_name = 'reduce_part_{}'.format(partition_id)
            self._portal_worker._run_reduce_task(reduce_task)
            total_cnt += self._check_merge(reduce_task)

        self.assertEqual(total_cnt,
                         self._partition_item_num * self._input_partition_num)
        self._clean_up()
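
A standard unittest entry point presumably closes the test module; a minimal sketch:

if __name__ == '__main__':
    unittest.main()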
Example #6
    portal_worker_options = dp_pb.DataPortalWorkerOptions(
        raw_data_options=dj_pb.RawDataOptions(
            raw_data_iter=args.input_data_file_iter,
            compressed_type=args.compressed_type,
            read_ahead_size=args.read_ahead_size,
            read_batch_size=args.read_batch_size,
            optional_fields=optional_fields,
            input_data_stat_sample_ratio=args.input_data_stat_sample_ratio
        ),
        writer_options=dj_pb.WriterOptions(
            output_writer=args.output_builder,
            compressed_type=args.builder_compressed_type
        ),
        batch_processor_options=dj_pb.BatchProcessorOptions(
            batch_size=args.batch_size,
            max_flying_item=-1
        ),
        merger_read_ahead_size=args.merger_read_ahead_size,
        merger_read_batch_size=args.merger_read_batch_size,
        memory_limit_ratio=args.memory_limit_ratio / 100
    )
    db_database, db_addr, db_username, db_password, db_base_dir = \
        get_kvstore_config(args.kvstore_type)
    data_portal_worker = DataPortalWorker(
            portal_worker_options, args.master_addr,
            args.rank_id, db_database, db_base_dir,
            db_addr, db_username, db_password,
            (args.kvstore_type == 'mock')
        )
    data_portal_worker.start()
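
Example #6 resolves the kvstore connection parameters (database, address, credentials, base dir) from get_kvstore_config(args.kvstore_type) instead of taking them as separate flags; the mock flag is again derived from kvstore_type == 'mock'.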
    parser.add_argument("--batch_size", type=int, default=1024,
                        help="the batch size for raw data reader")
    parser.add_argument("--max_flying_item", type=int, default=1048576,
                        help='the maximum items processed at the same time')
    args = parser.parse_args()

    portal_worker_options = dp_pb.DataPortalWorkerOptions(
        raw_data_options=dj_pb.RawDataOptions(
            raw_data_iter=args.input_data_file_iter,
            compressed_type=args.compressed_type
        ),
        writer_options=dj_pb.WriterOptions(
            output_writer=args.output_builder,
            compressed_type=args.builder_compressed_type
        ),
        batch_processor_options=dj_pb.BatchProcessorOptions(
            batch_size=args.batch_size,
            max_flying_item=args.max_flying_item
        ),
        merge_buffer_size=args.merge_buffer_size,
        write_buffer_size=args.write_buffer_size,
        merger_read_ahead_size=args.merger_read_ahead_size
    )

    data_portal_worker = DataPortalWorker(
            portal_worker_options, args.master_addr,
            args.rank_id, args.etcd_name, args.etcd_base_dir,
            args.etcd_addrs, args.use_mock_etcd
        )
    data_portal_worker.start()
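
This final variant shows the oldest constructor in the set: buffering is controlled by merge_buffer_size and write_buffer_size, and the etcd connection (etcd_name, etcd_base_dir, etcd_addrs, use_mock_etcd) is passed positionally, matching the seven-argument call in Example #1.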