def parse_sample_writer(self, values) -> Optional[callbacks_writer_pb2.WriterMessage]:
    """Convert raw sample-writer values into a protobuf message.

    The first call stores ``values`` in ``self.sample_fields`` (the header with
    the sample column names) and returns ``None``. Every later call returns a
    ``WriterMessage`` with topic ``SAMPLE``:

    * While adaptation output is being processed, all values go into the
      unnamed ``""`` feature, as strings or doubles depending on the type of
      the first value.
    * Otherwise, string payloads (e.g. "Elapsed Time: ...") go into the ``""``
      string feature, and numeric draws are stored as one feature per column
      name from ``self.sample_fields``.

    Args:
        values: Decoded payload of one ``sample_writer`` message.

    Returns:
        ``None`` for the header message, otherwise the protobuf message.
    """
    if self.sample_fields is None:
        # The first message is the header with the sample column names.
        self.sample_fields = values
        return None
    # Whether adaptation output follows is decided from the first non-empty
    # message after the header; an empty payload leaves it undecided instead
    # of crashing on values[0].
    if self.processing_adaptation is None and values:
        self.processing_adaptation = bool(re.match(r"^Adaptation terminated", str(values[0])))
    message = callbacks_writer_pb2.WriterMessage(topic=TopicEnum.Value("SAMPLE"))
    # The type of the first element decides how the whole payload is stored
    # (presumably payloads are homogeneous -- TODO confirm). Guarding on
    # emptiness avoids an IndexError; the check is loop-invariant, so hoist it.
    is_text = bool(values) and isinstance(values[0], str)
    if self.processing_adaptation:
        # Detect if we are on the last adaptation message: the one that follows
        # the "Diagonal elements of inverse mass matrix:" line.
        is_last_adaptation_message = re.match(
            r"sample_writer:Diagonal elements of inverse mass matrix:", self.previous_message
        )
        if is_last_adaptation_message:
            self.processing_adaptation = False
        for value in values:
            if is_text:
                message.feature[""].string_list.value.append(value)
            else:
                message.feature[""].double_list.value.append(value)
        return message
    if is_text:
        # After sampling, we get messages such as "Elapsed Time: ...".
        for value in values:
            message.feature[""].string_list.value.append(value)
        return message
    # Typical case: a draw, stored as one feature per sample column.
    for key, value in zip(self.sample_fields, values):
        message.feature[key].double_list.value.append(value)
    return message
def parse_init_writer(self, values) -> callbacks_writer_pb2.WriterMessage:
    """Build an ``INITIALIZATION`` message from raw writer values.

    Each element of ``values`` is appended, in order, to the ``double_list``
    of the unnamed ``''`` feature. When ``values`` is empty, no feature entry
    is created.
    """
    message = callbacks_writer_pb2.WriterMessage()
    message.topic = TopicEnum.Value('INITIALIZATION')
    features = message.feature
    for number in values:
        features[''].double_list.value.append(number)
    return message
def parse_logger(self, values) -> callbacks_writer_pb2.WriterMessage:
    """Build a ``LOGGER`` message from raw writer values.

    Each element of ``values`` is appended, in order, to the ``string_list``
    of the unnamed ``''`` feature. When ``values`` is empty, no feature entry
    is created.
    """
    message = callbacks_writer_pb2.WriterMessage()
    message.topic = TopicEnum.Value('LOGGER')
    features = message.feature
    for text in values:
        features[''].string_list.value.append(text)
    return message
def extract_protobuf_messages(fit_bytes):
    """Yield ``WriterMessage`` objects from a length-delimited byte stream.

    ``fit_bytes`` is read as a sequence of records. Each record is a varint32
    length prefix followed by that many bytes of a serialized
    ``callbacks_writer_pb2.WriterMessage``.

    Raises:
        google.protobuf.message.DecodeError: if a record's payload is shorter
            than its length prefix declares (truncated stream).
    """
    from google.protobuf.message import DecodeError

    decode_varint32 = google.protobuf.internal.decoder._DecodeVarint32
    total = len(fit_bytes)
    pos = 0
    while pos < total:
        # The decoder returns (decoded value, position just past the varint);
        # the decoded value is the byte length of the following message.
        msg_len, pos = decode_varint32(fit_bytes, pos)
        msg_end = pos + msg_len
        if msg_end > total:
            # Slicing past the end silently returns a short buffer, which can
            # parse into an incomplete message if the cut hits a field boundary.
            raise DecodeError(
                "Truncated message: expected %d bytes, only %d available" % (msg_len, total - pos)
            )
        msg = callbacks_writer_pb2.WriterMessage()
        msg.ParseFromString(fit_bytes[pos:msg_end])
        yield msg
        pos = msg_end
def test_callbacks_writer_parser_message_writer():
    """A ``logger:`` line parses into a LOGGER message carrying its text."""
    raw = "logger:Gradient evaluation took 4.7e-05 seconds"
    expected = callbacks_writer_pb2.WriterMessage(topic=TopicEnum.Value("LOGGER"))
    _, text = raw.split(":", 1)
    expected.feature[""].string_list.value.append(text)
    observed = httpstan.callbacks_writer_parser.WriterParser().parse(raw)
    assert observed == expected
def parse_diagnostic_writer(self, values) -> Optional[callbacks_writer_pb2.WriterMessage]:
    """Convert raw diagnostic-writer values into a protobuf message.

    The first call stores ``values`` in ``self.diagnostic_fields`` (the header
    with the diagnostic column names) and returns ``None``. Later calls return
    a ``WriterMessage`` with topic ``DIAGNOSTIC``:

    * string payloads (e.g. "Elapsed Time: ...") go into the unnamed ``""``
      string feature;
    * numeric payloads are stored as one feature per column name from
      ``self.diagnostic_fields``.

    Args:
        values: Decoded payload of one diagnostic-writer message.

    Returns:
        ``None`` for the header message, otherwise the protobuf message.
    """
    if self.diagnostic_fields is None:
        # The first message is the header with the diagnostic column names.
        self.diagnostic_fields = values
        return None
    message = callbacks_writer_pb2.WriterMessage(topic=TopicEnum.Value("DIAGNOSTIC"))
    # Guard on emptiness so an empty payload does not raise IndexError.
    if values and isinstance(values[0], str):
        # After sampling, we get messages such as "Elapsed Time: ...".
        for value in values:
            message.feature[""].string_list.value.append(value)
        return message
    for key, value in zip(self.diagnostic_fields, values):
        message.feature[key].double_list.value.append(value)
    return message
def test_callbacks_writer_parser_sample_writer_adapt():
    """Header yields None; adaptation lines become SAMPLE string messages."""
    header = '''sample_writer:["lp__", "accept_stat__"]'''
    adaptation_lines = (
        '''sample_writer:Adaptation terminated''',
        '''sample_writer:Step size = 0.809818''',
    )
    parser = httpstan.callbacks_writer_parser.WriterParser()
    observed = [parser.parse(line) for line in (header, *adaptation_lines)]

    def expected_message(line):
        # Everything after the first ':' is the payload text.
        msg = callbacks_writer_pb2.WriterMessage(topic=TopicEnum.Value('SAMPLE'))
        msg.feature[''].string_list.value.append(line.partition(':')[2])
        return msg

    assert observed == [None] + [expected_message(line) for line in adaptation_lines]
def test_callbacks_writer_parser_sample_writer():
    """Header yields None; a numeric draw becomes one feature per column."""
    header = """sample_writer:["lp__","accept_stat__","y"]"""  # noqa
    draw = """sample_writer:[-3.16745e-06,0.999965,0.00251692]"""
    columns = ["lp__", "accept_stat__", "y"]  # noqa
    numbers = json.loads("""[-3.16745e-06,0.999965,0.00251692]""")
    parser = httpstan.callbacks_writer_parser.WriterParser()
    observed = [parser.parse(header), parser.parse(draw)]
    draw_pb = callbacks_writer_pb2.WriterMessage(topic=TopicEnum.Value("SAMPLE"))
    for column, number in zip(columns, numbers):
        draw_pb.feature[column].double_list.value.append(number)
    assert observed == [None, draw_pb]