Example #1
    def _make_portal_worker(self):
        portal_worker_options = dp_pb.DataPortalWorkerOptions(
            raw_data_options=dj_pb.RawDataOptions(raw_data_iter="TF_RECORD",
                                                  compressed_type=''),
            writer_options=dj_pb.WriterOptions(output_writer="TF_RECORD"),
            batch_processor_options=dj_pb.BatchProcessorOptions(
                batch_size=128, max_flying_item=300000),
            merge_buffer_size=4096,
            merger_read_ahead_size=1000000)

        self._portal_worker = DataPortalWorker(portal_worker_options,
                                               "localhost:5005", 0,
                                               "test_portal_worker_0",
                                               "portal_worker_0",
                                               "localhost:2379", True)
Example #2
    def _make_portal_worker(self):
        portal_worker_options = dp_pb.DataPortalWorkerOptions(
            raw_data_options=dj_pb.RawDataOptions(raw_data_iter="TF_RECORD",
                                                  read_ahead_size=1 << 20,
                                                  read_batch_size=128),
            writer_options=dj_pb.WriterOptions(output_writer="TF_RECORD"),
            batch_processor_options=dj_pb.BatchProcessorOptions(
                batch_size=128, max_flying_item=300000),
            merger_read_ahead_size=1000000,
            merger_read_batch_size=128)

        self._portal_worker = DataPortalWorker(portal_worker_options,
                                               "localhost:5005", 0,
                                               "test_portal_worker_0",
                                               "portal_worker_0",
                                               "localhost:2379", "test_user",
                                               "test_password", True)
Example #3
    def _make_portal_worker(self, raw_data_iter, validation_ratio):
        portal_worker_options = dp_pb.DataPortalWorkerOptions(
            raw_data_options=dj_pb.RawDataOptions(
                raw_data_iter=raw_data_iter,
                read_ahead_size=1 << 20,
                read_batch_size=128,
                optional_fields=['label'],
                validation_ratio=validation_ratio,
            ),
            writer_options=dj_pb.WriterOptions(output_writer="TF_RECORD"),
            batch_processor_options=dj_pb.BatchProcessorOptions(
                batch_size=128, max_flying_item=300000),
            merger_read_ahead_size=1000000,
            merger_read_batch_size=128)

        os.environ['ETCD_BASE_DIR'] = "portal_worker_0"
        self._portal_worker = DataPortalWorker(portal_worker_options,
                                               "localhost:5005", 0, "etcd",
                                               True)
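
Example #3 parametrizes the fixture over the iterator type and a validation ratio. A hedged sketch of how a test case might drive it; the test class scaffolding and concrete values are illustrative, and start() is the only worker method that appears elsewhere in this listing (Examples #4 and #5):

    import unittest

    class DataPortalWorkerTest(unittest.TestCase):
        # _make_portal_worker from Example #3 is assumed to live here.

        def test_tf_record_with_validation(self):
            # Hold out 10% of records for validation; the exact semantics
            # of validation_ratio are an assumption.
            self._make_portal_worker('TF_RECORD', 0.1)
            self._portal_worker.start()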
Example #4
    set_logger()
    if args.input_data_file_iter == 'TF_RECORD' or \
            args.output_builder == 'TF_RECORD':
        import tensorflow
        tensorflow.compat.v1.enable_eager_execution()

    optional_fields = list(
        field for field in map(str.strip, args.optional_fields.split(','))
        if field != '')

    portal_worker_options = dp_pb.DataPortalWorkerOptions(
        raw_data_options=dj_pb.RawDataOptions(
            raw_data_iter=args.input_data_file_iter,
            compressed_type=args.compressed_type,
            read_ahead_size=args.read_ahead_size,
            read_batch_size=args.read_batch_size,
            optional_fields=optional_fields),
        writer_options=dj_pb.WriterOptions(
            output_writer=args.output_builder,
            compressed_type=args.builder_compressed_type),
        batch_processor_options=dj_pb.BatchProcessorOptions(
            batch_size=args.batch_size, max_flying_item=-1),
        merger_read_ahead_size=args.merger_read_ahead_size,
        merger_read_batch_size=args.merger_read_batch_size,
        memory_limit_ratio=args.memory_limit_ratio / 100)
    data_portal_worker = DataPortalWorker(portal_worker_options,
                                          args.master_addr, args.rank_id,
                                          args.kvstore_type,
                                          (args.kvstore_type == 'mock'))
    data_portal_worker.start()
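
Example #4 builds the worker from command-line arguments; note how optional_fields is parsed from a comma-separated flag. The same parsing, extracted as a standalone runnable sketch (the sample input string is illustrative):

    def parse_optional_fields(raw):
        # Split on commas, strip surrounding whitespace, and drop empty
        # entries, mirroring the comprehension in Example #4.
        return [field for field in map(str.strip, raw.split(','))
                if field != '']

    assert parse_optional_fields('label, example_id,,') == ['label', 'example_id']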
Example #5
                        help='the builder for output file')
    parser.add_argument("--batch_size", type=int, default=1024,
                        help="the batch size for raw data reader")
    parser.add_argument("--max_flying_item", type=int, default=1048576,
                        help='the maximum items processed at the same time')
    args = parser.parse_args()

    portal_worker_options = dp_pb.DataPortalWorkerOptions(
        raw_data_options=dj_pb.RawDataOptions(
            raw_data_iter=args.input_data_file_iter,
            compressed_type=args.compressed_type
        ),
        writer_options=dj_pb.WriterOptions(
            output_writer=args.output_builder,
            compressed_type=args.builder_compressed_type
        ),
        batch_processor_options=dj_pb.BatchProcessorOptions(
            batch_size=args.batch_size,
            max_flying_item=args.max_flying_item
        ),
        merge_buffer_size=args.merge_buffer_size,
        write_buffer_size=args.write_buffer_size,
        merger_read_ahead_size=args.merger_read_ahead_size
    )

    data_portal_worker = DataPortalWorker(
            portal_worker_options, args.master_addr,
            args.rank_id, args.etcd_name, args.etcd_base_dir,
            args.etcd_addrs, args.use_mock_etcd
        )
    data_portal_worker.start()
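
The parser in Example #5 is shown only from its last few arguments onward. A hypothetical reconstruction of a few of the missing add_argument calls, with flag names mirroring the args.* attributes the snippet uses; the types, defaults, and help strings are assumptions:

    import argparse

    parser = argparse.ArgumentParser(description='DataPortalWorker cli.')
    parser.add_argument('--master_addr', type=str, required=True,
                        help='the address of the data portal master')
    parser.add_argument('--rank_id', type=int, default=0,
                        help='the rank id of this worker')
    parser.add_argument('--merge_buffer_size', type=int, default=4096,
                        help='the buffer size used by the merger')
    parser.add_argument('--use_mock_etcd', action='store_true',
                        help='use a mock etcd backend for testing')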