Code example #1
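The constructor below (apparently a per-client State object on the server side) stores the incoming MediaStreamDescriptor, inspects its data_transfer_properties to decide whether frames arrive as embedded bytes (BYTES), over a shared-memory buffer (REFERENCE), or as shared-memory segments (HANDLE), and opens a SharedMemoryManager for the buffer case.
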
    def __init__(self, media_stream_descriptor):
        try:
            # media descriptor holding input data format
            self.media_stream_descriptor = media_stream_descriptor

            # Determine how the frame data will be transferred
            if (self.media_stream_descriptor.WhichOneof(
                    "data_transfer_properties") is None):
                self.content_transfer_type = TransferType.BYTES
            elif self.media_stream_descriptor.HasField(
                    "shared_memory_buffer_transfer_properties"):
                self.content_transfer_type = TransferType.REFERENCE
            elif self.media_stream_descriptor.HasField(
                    "shared_memory_segments_transfer_properties"):
                self.content_transfer_type = TransferType.HANDLE

            # Set up the shared-memory accessor if shared memory is used
            if self.content_transfer_type == TransferType.REFERENCE:
                # Create shared memory accessor specific to the client
                self.shared_memory_manager = SharedMemoryManager(
                    name=self.media_stream_descriptor.
                    shared_memory_buffer_transfer_properties.handle_name,
                    size=self.media_stream_descriptor.
                    shared_memory_buffer_transfer_properties.length_bytes,
                )
            else:
                self.shared_memory_manager = None

        except:
            log_exception(get_logger("State"))
            raise
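
The TransferType values referenced above are not defined in this snippet. Below is a minimal sketch of what such an enum could look like, assuming a plain Python Enum; only the member names BYTES, REFERENCE, and HANDLE come from the code, everything else is illustrative.

from enum import Enum

class TransferType(Enum):
    # Hypothetical definition for illustration; the real sample may differ.
    BYTES = 1      # frame content embedded in the gRPC message
    REFERENCE = 2  # frame content referenced in a shared-memory buffer
    HANDLE = 3     # frame content referenced via shared-memory segments
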
Code example #2
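The method below converts a GVA sample (a video frame with inference results attached by the detection pipeline) into an extension_pb2.MediaStreamMessage: the frame's JSON message supplies the acknowledgement sequence number and timestamp, a detection tensor becomes an Entity with a tag and a normalized bounding box, and tracking IDs and classification tensors are appended as Attribute entries.
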
    def _generate_media_stream_message(self, gva_sample):
        message = json.loads(list(gva_sample.video_frame.messages())[0])

        msg = extension_pb2.MediaStreamMessage()
        msg.ack_sequence_number = message["sequence_number"]
        msg.media_sample.timestamp = message["timestamp"]

        for region in gva_sample.video_frame.regions():
            inference = msg.media_sample.inferences.add()

            attributes = []
            obj_label = None
            obj_confidence = 0
            obj_left = 0
            obj_width = 0
            obj_top = 0
            obj_height = 0

            for tensor in region.tensors():
                name = tensor.name()

                if (name == 'detection'):
                    obj_confidence = region.confidence()
                    obj_label = region.label()

                    obj_left, obj_top, obj_width, obj_height = region.normalized_rect()

                    inference.type = inferencing_pb2.Inference.InferenceType.ENTITY  # pylint: disable=no-member
                    if region.object_id():  # Tracking
                        obj_id = region.object_id()
                        attributes.append(["object_id", str(obj_id), 0])
                elif tensor["label"]:  # Classification
                    attr_name = name
                    attr_label = tensor["label"]
                    attr_confidence = region.confidence()
                    attributes.append([attr_name, attr_label, attr_confidence])

            if obj_label is not None:
                try:
                    entity = inferencing_pb2.Entity(
                        tag=inferencing_pb2.Tag(value=obj_label,
                                                confidence=obj_confidence),
                        box=inferencing_pb2.Rectangle(l=obj_left,
                                                      t=obj_top,
                                                      w=obj_width,
                                                      h=obj_height))

                    for attr in attributes:
                        attribute = inferencing_pb2.Attribute(
                            name=attr[0], value=attr[1], confidence=attr[2])
                        entity.attributes.append(attribute)

                except:
                    log_exception(self._logger)
                    raise

                inference.entity.CopyFrom(entity)

        return msg
Code example #3
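A small constructor, apparently for the client-side request generator: it records the stream descriptor, the shared-memory manager, and the frame queue, and starts the request sequence number at 1.
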
    def __init__(self, descriptor, shared_memory_manager, queue):
        try:
            self._request_seq_num = 1
            self._descriptor = descriptor
            self._shared_memory_manager = shared_memory_manager
            self._queue = queue
        except:
            log_exception()
            raise
Code example #4
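The server-side helper below turns an incoming request into a GvaFrameData sample. It obtains the raw frame bytes either directly from content_bytes or from the shared-memory buffer referenced by the request, and, for RAW-encoded frames, builds a GStreamer caps string from the pixel format and frame dimensions advertised in the media descriptor.
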
    def _generate_gva_sample(self, client_state, request):

        new_sample = None

        try:
            # Get reference to raw bytes
            if client_state.content_transfer_type == TransferType.BYTES:
                raw_bytes = memoryview(
                    request.media_sample.content_bytes.bytes)
            elif client_state.content_transfer_type == TransferType.REFERENCE:
                # Data sent over shared memory buffer
                address_offset = request.media_sample.content_reference.address_offset
                length_bytes = request.media_sample.content_reference.length_bytes

                # Get a read-only memory reference to the data sent over shared memory
                raw_bytes = client_state.shared_memory_manager.read_bytes(
                    address_offset, length_bytes)

            # Get encoding details of the media sent by client
            encoding = client_state.media_stream_descriptor.media_descriptor.video_frame_sample_format.encoding

            # Handle RAW content (placeholder for the user to handle each pixel-format variation)
            if encoding == client_state.media_stream_descriptor.media_descriptor.video_frame_sample_format.Encoding.RAW:
                pixel_format = client_state.media_stream_descriptor.media_descriptor\
                    .video_frame_sample_format.pixel_format
                caps_format = None

                if pixel_format == media_pb2.VideoFrameSampleFormat.PixelFormat.RGBA:
                    caps_format = 'RGBA'
                elif pixel_format == media_pb2.VideoFrameSampleFormat.PixelFormat.RGB24:
                    caps_format = 'RGB'
                elif pixel_format == media_pb2.VideoFrameSampleFormat.PixelFormat.BGR24:
                    caps_format = 'BGR'
                if caps_format is not None:
                    caps = ''.join(("video/x-raw,format=",
                                    caps_format,
                                    ",width=",
                                    str(client_state.media_stream_descriptor.media_descriptor\
                                        .video_frame_sample_format.dimensions.width),
                                    ",height=",
                                    str(client_state.media_stream_descriptor.media_descriptor\
                                        .video_frame_sample_format.dimensions.height)))
                    new_sample = GvaFrameData(
                        bytes(raw_bytes),
                        caps,
                        message={
                            'sequence_number': request.sequence_number,
                            'timestamp': request.media_sample.timestamp
                        })
            else:
                self._logger.info('Sample format is not RAW')
        except:
            log_exception(self._logger)
            raise
        return new_sample
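
For reference, the caps string assembled above would look like the following for a hypothetical 640x480 BGR frame (standalone illustration, not part of the sample):

caps_format = 'BGR'
width, height = 640, 480  # illustrative values only
caps = ''.join(("video/x-raw,format=", caps_format,
                ",width=", str(width), ",height=", str(height)))
print(caps)  # video/x-raw,format=BGR,width=640,height=480
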
Code example #5
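This helper reserves an empty slot in the shared-memory buffer for the given sequence number, copies the frame bytes into it, and returns the slot (or None when no slot is currently free).
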
        def get_memory_slot(self, sequence_number, content_bytes):
            try:
                memory_slot = self._shared_memory_manager.get_empty_slot(
                    sequence_number, len(content_bytes))
                if memory_slot is None:
                    return None

                self._shared_memory_manager.write_bytes(
                    memory_slot[0], content_bytes)

            except Exception:
                log_exception()
                raise
            return memory_slot
Code example #6
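A minimal client main(): it chooses an image or video frame source based on the sample file's extension, starts a MediaStreamProcessor, pushes frames into the frame queue, and writes results pulled from the result queue to the output file, using a None frame as the end-of-stream marker.
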
def main():
    try:
        args = parse_args()
        _log_options(args)
        frame_source = None
        frame_queue = queue.Queue()
        result_queue = queue.Queue()
        msp = MediaStreamProcessor(args.grpc_server_address,
                                   args.use_shared_memory)

        _, extension = os.path.splitext(args.sample_file)
        if extension in ['.png', '.jpg']:
            frame_source = ImageSource(args.sample_file, args.loop_count)
        elif extension in ['.mp4']:
            frame_source = VideoSource(args.sample_file)
        else:
            print("{}: unsupported file type".format(args.sample_file))
            sys.exit(1)

        width, height = frame_source.dimensions()
        print("{} {}".format(width, height))
        msp.start(width, height, frame_queue, result_queue)

        with open(args.output_file, "w") as output:
            image = frame_source.get_frame()
            while image:
                frame_queue.put(image)
                while not result_queue.empty():
                    result = result_queue.get()
                    print_result(result, output)
                image = frame_source.get_frame()
            frame_queue.put(None)

            result = result_queue.get()
            while result:
                print_result(result, output)
                result = result_queue.get()

        frame_source.close()

    except Exception:
        log_exception()
        sys.exit(-1)
Code example #7
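This method builds the MediaStreamDescriptor the client sends to the extension server: a graph identifier, the extension configuration, and a RAW/BGR24 video format with the given width and height, plus shared-memory buffer transfer properties when a shared-memory manager is in use.
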
    def get_media_stream_descriptor(self, width, height, extension_config):
        try:
            smbtp = None
            if self._shared_memory_manager:
                smbtp = extension_pb2.SharedMemoryBufferTransferProperties(
                    handle_name=self._shared_memory_manager.shm_file_name,
                    length_bytes=self._shared_memory_manager.shm_file_size,
                )
            media_stream_descriptor = extension_pb2.MediaStreamDescriptor(
                graph_identifier=extension_pb2.GraphIdentifier(
                    media_services_arm_id="",
                    graph_instance_name="SampleGraph1",
                    graph_node_name="SampleGraph1",
                ),
                extension_configuration=extension_config,
                media_descriptor=media_pb2.MediaDescriptor(
                    timescale=90000,
                    # pylint: disable=no-member
                    # E1101: Class 'VideoFrameSampleFormat' has no 'Encoding' member (no-member)
                    # E1101: Class 'VideoFrameSampleFormat' has no 'PixelFormat' member (no-member)
                    video_frame_sample_format=media_pb2.VideoFrameSampleFormat(
                        encoding=media_pb2.VideoFrameSampleFormat.Encoding.
                        Value("RAW"),
                        pixel_format=media_pb2.VideoFrameSampleFormat.
                        PixelFormat.Value("BGR24"),
                        dimensions=media_pb2.Dimensions(
                            width=width,
                            height=height,
                        ),
                    ),
                ),
                shared_memory_buffer_transfer_properties=smbtp,
            )
        except Exception:
            log_exception()
            raise

        return media_stream_descriptor
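
Example #9 below shows the server-side counterpart: the descriptor produced here is expected as the first message of ProcessMediaStream, and the server answers with a MediaStreamMessage whose own descriptor echoes the client's timescale.
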
Code example #8
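The MediaStreamProcessor constructor: when shared memory is requested it allocates a region sized as frame_queue_size * frame_size (or 100 frames' worth if no queue size is given), then opens an insecure gRPC channel and a MediaGraphExtension stub against the given server address.
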
    def __init__(self, grpc_server_address, use_shared_memory,
                 frame_queue_size, frame_size):
        try:
            # Full address including port number, e.g. "localhost:44001"
            self._grpc_server_address = grpc_server_address
            self._shared_memory_manager = None
            if use_shared_memory:
                shared_memory_size = (frame_queue_size *
                                      frame_size if frame_queue_size else 100 *
                                      frame_size)
                self._shared_memory_manager = SharedMemoryManager(
                    os.O_RDWR | os.O_SYNC | os.O_CREAT,
                    name=None,
                    size=shared_memory_size,
                )
            self._grpc_channel = grpc.insecure_channel(
                self._grpc_server_address)
            self._grpc_stub = extension_pb2_grpc.MediaGraphExtensionStub(
                self._grpc_channel)

        except Exception:
            log_exception()
            raise
Code example #9
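The core gRPC handler on the server. It consumes the initial MediaStreamDescriptor, replies with its own descriptor, starts a VA Serving object-detection pipeline wired to application source and destination queues (optionally loading extra pipeline parameters from a JSON file or string), then streams incoming frames into the pipeline and yields inference results back until the client disconnects or the pipeline stops, finally draining the output queue with a None sentinel.
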
    def ProcessMediaStream(self, requestIterator, context):
        requests_received = 0
        responses_sent = 0
        # The first message from the client must be a MediaStreamDescriptor
        request = next(requestIterator)
        requests_received += 1
        # Extract message IDs
        request_seq_num = request.sequence_number
        request_ack_seq_num = request.ack_sequence_number
        # State object per client
        client_state = State(request.media_stream_descriptor)
        self._logger.info("[Received] SeqNum: {0:07d} | "
                          "AckNum: {1}\nMediaStreamDescriptor:\n{2}".format(
                              request_seq_num,
                              request_ack_seq_num,
                              client_state.media_stream_descriptor,
                          ))
        # First message response ...
        media_stream_message = extension_pb2.MediaStreamMessage(
            sequence_number=1,
            ack_sequence_number=request_seq_num,
            media_stream_descriptor=extension_pb2.MediaStreamDescriptor(
                media_descriptor=media_pb2.MediaDescriptor(
                    timescale=client_state.media_stream_descriptor.
                    media_descriptor.timescale)),
        )
        responses_sent += 1
        yield media_stream_message

        final_pipeline_parameters = {}
        if self._version.startswith("debug"):
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
            location = os.path.join(tempfile.gettempdir(), "vaserving",
                                    self._version, timestamp)
            os.makedirs(os.path.abspath(location))
            final_pipeline_parameters = {
                "location": os.path.join(location, "frame_%07d.jpeg")
            }

        try:
            if self._pipeline_parameter_arg:
                pipeline_parameters = {}
                if os.path.isfile(self._pipeline_parameter_arg):
                    with open(self._pipeline_parameter_arg) as json_file:
                        pipeline_parameters = json.load(json_file)
                else:
                    pipeline_parameters = json.loads(
                        self._pipeline_parameter_arg)
                final_pipeline_parameters.update(pipeline_parameters)
        except ValueError:
            self._logger.error("Issue loading json parameters: {}".format(
                self._pipeline_parameter_arg))
            raise

        self._logger.info("Pipeline Name : {}".format(self._pipeline))
        self._logger.info("Pipeline Version : {}".format(self._version))
        self._logger.info(
            "Pipeline Parameters : {}".format(final_pipeline_parameters))
        detect_input = Queue(maxsize=self._input_queue_size)
        detect_output = Queue()
        # Start object detection pipeline
        # It will wait until it receives frames via the detect_input queue
        detect_pipeline = VAServing.pipeline(self._pipeline, self._version)
        detect_pipeline.start(
            source={
                "type": "application",
                "class": "GStreamerAppSource",
                "input": detect_input,
                "mode": "push",
            },
            destination={
                "type": "application",
                "class": "GStreamerAppDestination",
                "output": detect_output,
                "mode": "frames",
            },
            parameters=final_pipeline_parameters,
        )

        # Process the rest of the MediaStream message sequence
        for request in requestIterator:
            try:
                if requests_received - responses_sent >= self._input_queue_size:
                    queued_samples = self._get_queued_samples(detect_output,
                                                              block=True)
                else:
                    queued_samples = []
                # Read request id, sent by client
                request_seq_num = request.sequence_number
                self._logger.info(
                    "[Received] SeqNum: {0:07d}".format(request_seq_num))
                requests_received += 1
                gva_sample = self._generate_gva_sample(client_state, request)
                detect_input.put(gva_sample)
                queued_samples.extend(self._get_queued_samples(detect_output))
                if context.is_active():
                    for gva_sample in queued_samples:
                        if gva_sample:
                            media_stream_message = self._generate_media_stream_message(
                                gva_sample)
                            responses_sent += 1
                            self._logger.info(
                                "[Sent] AckSeqNum: {0:07d}".format(
                                    media_stream_message.ack_sequence_number))
                            yield media_stream_message
                else:
                    break
                if detect_pipeline.status().state.stopped():
                    break
            except:
                log_exception(self._logger)
                raise

        if detect_pipeline.status().state.stopped():
            try:
                raise Exception(detect_pipeline.status().state)
            except:
                log_exception(self._logger)
                raise

        # After the server has finished processing all requests from the iterator,
        # push a None sentinel into the input queue. When the None comes out of
        # the output queue, we know all requests have been processed.
        gva_sample = None
        if not detect_pipeline.status().state.stopped():
            detect_input.put(None)
            gva_sample = detect_output.get()
        while gva_sample:
            media_stream_message = self._generate_media_stream_message(
                gva_sample)
            responses_sent += 1
            self._logger.info("[Sent] AckSeqNum: {0:07d}".format(
                media_stream_message.ack_sequence_number))
            if context.is_active():
                yield media_stream_message
            gva_sample = detect_output.get()

        self._logger.info(
            "Done processing messages: Received: {}, Sent: {}".format(
                requests_received, responses_sent))
        self._logger.info("MediaStreamDescriptor:\n{0}".format(
            client_state.media_stream_descriptor))
Code example #10
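A fuller client main() than example #6: it adds frame-rate pacing, FPS logging, shared-memory sizing based on the first frame's length, an extension configuration payload, and handling of exceptions propagated through the result queue; it returns a non-zero code if the number of responses received does not match the number of frames sent.
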
def main():
    try:
        args = parse_args()
        _log_options(args)
        frame_delay = 1 / args.frame_rate if args.frame_rate > 0 else 0
        frame_source = None
        frame_queue = queue.Queue(args.frame_queue_size)
        result_queue = queue.Queue()
        frames_sent = 0
        frames_received = 0
        prev_fps_delta = 0
        start_time = None
        frame_source = VideoSource(args.sample_file, args.loop_count)
        width, height = frame_source.dimensions()
        image = frame_source.get_frame()

        if not image:
            logging.error(
                "Error getting frame from video source: {}".format(args.sample_file)
            )
            sys.exit(1)

        msp = MediaStreamProcessor(
            args.grpc_server_address,
            args.use_shared_memory,
            args.frame_queue_size,
            len(image),
        )

        extension_config = json.dumps(create_extension_config(args))

        msp.start(width, height, frame_queue, result_queue, extension_config)

        with open(args.output_file, "w") as output:
            start_time = time.time()
            result = True
            while image and result:
                frame_queue.put(image)
                while not result_queue.empty():
                    result = result_queue.get()
                    if isinstance(result, Exception):
                        logging.error(result)
                        frame_source.close()
                        sys.exit(1)
                    frames_received += 1
                    prev_fps_delta = _log_fps(
                        start_time, frames_received, prev_fps_delta, args.fps_interval
                    )
                    _log_result(result, output)
                image = frame_source.get_frame()
                time.sleep(frame_delay)
                frames_sent += 1

            if result:
                frame_queue.put(None)
                result = result_queue.get()
            while result:
                if isinstance(result, Exception):
                    logging.error(result)
                    frame_source.close()
                    sys.exit(1)
                frames_received += 1
                prev_fps_delta = _log_fps(
                    start_time, frames_received, prev_fps_delta, args.fps_interval
                )
                _log_result(result, output)
                result = result_queue.get()

        frame_source.close()
        delta = time.time() - start_time
        logging.info(
            "Start Time: {} End Time: {} Frames: Tx {} Rx {} FPS: {}".format(
                start_time,
                start_time + delta,
                frames_sent,
                frames_received,
                (frames_received / delta) if delta > 0 else None,
            )
        )

    except Exception:
        log_exception()
        return -1

    if frames_sent != frames_received:
        logging.error("Sent {} requests, received {} responses".format(
            frames_sent, frames_received
        ))
        return 1

    return 0