def test_bolt_01_csv():
    """BOLT #1 messages survive an encode/decode round trip.

    Each case is ``[text]`` or ``[text, canonical]``: the final element is the
    string expected back after decoding.  Decoding is checked against both the
    builtin ``bolt1.namespace`` and a namespace built from ``bolt1.csv``.
    """
    # We can create a namespace from the csv.
    csv_namespace = MessageNamespace(bolt1.csv)

    cases = [
        ['init globalfeatures= features=80',
         'init globalfeatures= features=80 tlvs={}'],
        ['init globalfeatures= features=80 tlvs={}'],
        ['init globalfeatures= features=80 tlvs={networks={chains=['
         '6fe28c0ab6f1b372c1a6a246ae63f74f931e8365e15a089c68d6190000000000'
         ']}}'],
        ['init globalfeatures= features=80 tlvs={networks={chains=['
         '6fe28c0ab6f1b372c1a6a246ae63f74f931e8365e15a089c68d6190000000000'
         ','
         '1fe28c0ab6f1b372c1a6a246ae63f74f931e8365e15a089c68d6190000000000'
         ']}}'],
        ['error channel_id=' + '0' * 64 + ' data=00'],
        ['ping num_pong_bytes=0 ignored='],
        ['ping num_pong_bytes=3 ignored=0000'],
        ['pong ignored='],
        ['pong ignored=000000'],
    ]

    for case in cases:
        text, canonical = case[0], case[-1]
        wire = io.BytesIO()
        Message.from_str(bolt1.namespace, text).write(wire)

        # Works with the builtin namespace and our manually-made one.
        for space in (bolt1.namespace, csv_namespace):
            wire.seek(0)
            assert Message.read(space, wire).to_str() == canonical
def test_subtype_array():
    """A message whose field is an array of subtypes that themselves hold arrays.

    Checks text round-tripping, the exact wire encoding, and decoding back to
    the same text.
    """
    ns = MessageNamespace()
    ns.load_csv([
        'msgtype,tx_signatures,1',
        'msgdata,tx_signatures,num_witnesses,u16,',
        'msgdata,tx_signatures,witness_stack,witness_stack,num_witnesses',
        'subtype,witness_stack',
        'subtypedata,witness_stack,num_input_witness,u16,',
        'subtypedata,witness_stack,witness_element,witness_element,num_input_witness',
        'subtype,witness_element',
        'subtypedata,witness_element,len,u16,',
        'subtypedata,witness_element,witness,byte,len',
    ])

    sig_witness = ('3045022100ac0fdee3e157f50be3214288cb7f11b03ce33e13b39dadccfcdb1a174fd3729a'
                   '02200b69b286ac9f0fc5c51f9f04ae5a9827ac11d384cc203a0eaddff37e8d15c1ac01')
    key_witness = '02d6a3c2d0cf7904ab6af54d7c959435a452b24a63194e1c4e7c337d3ebbb3017b'

    text = ("tx_signatures witness_stack=[{witness_element=[{witness=" + sig_witness
            + "},{witness=" + key_witness + "}]}]")
    # type(1), num_witnesses(1), num_input_witness(2), then len-prefixed witnesses.
    wire = bytes.fromhex('0001' + '0001' + '0002'
                         + '0048' + sig_witness
                         + '0021' + key_witness)

    for text_form, wire_form in [(text, wire)]:
        m = Message.from_str(ns, text_form)
        assert m.to_str() == text_form

        out = io.BytesIO()
        m.write(out)
        assert out.getvalue().hex() == wire_form.hex()

        assert Message.read(ns, io.BytesIO(wire_form)).to_str() == text_form
def test_subtype():
    """A fixed-size (4-element) array of a two-u32 subtype, plus missing-field detection."""
    ns = MessageNamespace()
    ns.load_csv([
        'msgtype,test1,1',
        'msgdata,test1,test_sub,channel_update_timestamps,4',
        'subtype,channel_update_timestamps',
        'subtypedata,channel_update_timestamps,timestamp_node_id_1,u32,',
        'subtypedata,channel_update_timestamps,timestamp_node_id_2,u32,',
    ])

    text = ("test1 test_sub=["
            "{timestamp_node_id_1=1,timestamp_node_id_2=2},"
            "{timestamp_node_id_1=3,timestamp_node_id_2=4},"
            "{timestamp_node_id_1=5,timestamp_node_id_2=6},"
            "{timestamp_node_id_1=7,timestamp_node_id_2=8}]")
    # Message type 1, then eight big-endian u32 values 1..8.
    wire = bytes.fromhex('0001' + ''.join('{:08x}'.format(v) for v in range(1, 9)))

    for text_form, wire_form in ((text, wire),):
        m = Message.from_str(ns, text_form)
        assert m.to_str() == text_form

        out = io.BytesIO()
        m.write(out)
        assert out.getvalue() == wire_form

        assert Message.read(ns, io.BytesIO(wire_form)).to_str() == text_form

    # Test missing field logic.
    bare = Message.from_str(ns, "test1", incomplete_ok=True)
    assert bare.missing_fields()
def cmp_msg(msg: Message, expected: Message) -> Optional[str]:
    """Compare a received message against a (possibly partial) expected one.

    Returns None when the message types agree and ``cmp_obj`` finds every
    field of ``expected`` matched in ``msg``; otherwise returns a
    human-readable complaint string.
    """
    wanted_type = expected.messagetype
    if msg.messagetype != wanted_type:
        return "Expected {}, got {}".format(wanted_type, msg.messagetype)
    # Field-by-field comparison is delegated to cmp_obj on the Python forms.
    return cmp_obj(msg.to_py(), expected.to_py(), wanted_type.name)
def action(self, runner: 'Runner') -> bool:
    """Build this event's message, send it to the runner, and stash it.

    Raises SpecFileError if the resolved arguments leave any required field
    unset.  Always returns True on success.
    """
    super().action(runner)

    # Only now is the runner available to resolve the message arguments.
    resolved = self.resolve_args(runner, self.kwargs)
    outgoing = Message(self.msgtype, **resolved)

    absent = outgoing.missing_fields()
    if absent:
        raise SpecFileError(self, "Missing fields {}".format(absent))

    wire = io.BytesIO()
    outgoing.write(wire)
    runner.recv(self, self.find_conn(runner), wire.getvalue())

    msg_to_stash(runner, self, outgoing)
    return True
def test_fundamental():
    """Every fundamental field type parses and prints back unchanged."""
    fieldtypes = ['byte', 'u16', 'u32', 'u64', 'chain_hash', 'channel_id',
                  'sha256', 'signature', 'point', 'short_channel_id']
    ns = MessageNamespace()
    # One field per type, named test_<type>.
    ns.load_csv(['msgtype,test,1']
                + ['msgdata,test,test_{0},{0},'.format(t) for t in fieldtypes])

    hash32 = '0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20'
    sig64 = ('0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20'
             '2122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f40')
    point33 = '0201030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f2021'

    mstr = ' '.join([
        'test',
        'test_byte=255',
        'test_u16=65535',
        'test_u32=4294967295',
        'test_u64=18446744073709551615',
        'test_chain_hash=' + hash32,
        'test_channel_id=' + hash32,
        'test_sha256=' + hash32,
        'test_signature=' + sig64,
        'test_point=' + point33,
        'test_short_channel_id=1x2x3',
    ])
    m = Message.from_str(ns, mstr)

    # Same (ignoring whitespace differences)
    assert m.to_str().split() == mstr.split()
def test_tlv_complex():
    """A real reply_channel_range example from the spec re-encodes byte-identically."""
    ns = MessageNamespace([
        "msgtype,reply_channel_range,264,gossip_queries",
        "msgdata,reply_channel_range,chain_hash,chain_hash,",
        "msgdata,reply_channel_range,first_blocknum,u32,",
        "msgdata,reply_channel_range,number_of_blocks,u32,",
        "msgdata,reply_channel_range,full_information,byte,",
        "msgdata,reply_channel_range,len,u16,",
        "msgdata,reply_channel_range,encoded_short_ids,byte,len",
        "msgdata,reply_channel_range,tlvs,reply_channel_range_tlvs,",
        "tlvtype,reply_channel_range_tlvs,timestamps_tlv,1",
        "tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,byte,",
        "tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoded_timestamps,byte,...",
        "tlvtype,reply_channel_range_tlvs,checksums_tlv,3",
        "tlvdata,reply_channel_range_tlvs,checksums_tlv,checksums,channel_update_checksums,...",
        "subtype,channel_update_timestamps",
        "subtypedata,channel_update_timestamps,timestamp_node_id_1,u32,",
        "subtypedata,channel_update_timestamps,timestamp_node_id_2,u32,",
        "subtype,channel_update_checksums",
        "subtypedata,channel_update_checksums,checksum_node_id_1,u32,",
        "subtypedata,channel_update_checksums,checksum_node_id_2,u32,",
    ])

    wire = bytes.fromhex(
        '0108'                                  # type 264
        '06226e46111a0b59caaf126043eb5bbf28c34f3a5e332a1fc7b2b73cf188910f'  # chain_hash
        '00000067'                              # first_blocknum
        '00000007'                              # number_of_blocks
        '01'                                    # full_information
        '0011'                                  # len
        '00000067000001000000006d0000010000'    # encoded_short_ids
        '03' '10' '1112fa300000000022d7a4a79bece840'  # checksums_tlv
    )

    decoded = Message.read(ns, io.BytesIO(wire))
    reencoded = io.BytesIO()
    decoded.write(reencoded)
    assert reencoded.getvalue() == wire
def msg_to_stash(runner: 'Runner', event: Event, msg: Message) -> None:
    """Append ``msg`` to the runner's stash list keyed by the event's class name.

    Each entry is a ``(message type name, msg.to_py())`` tuple, so stashed
    messages are kept in the order they were sent/received.
    """
    as_py = msg.to_py()
    key = type(event).__name__
    history = runner.get_stash(event, key, [])
    history.append((msg.messagetype.name, as_py))
    runner.add_stash(key, history)
def action(self, runner: 'Runner') -> bool:
    """Wait for a message from the runner that matches this event.

    Messages the ``ignore`` hook accepts are skipped; any responses it returns
    are sent back on the same connection first.  Raises EventError when no
    message arrives, a message is banned by one of the connection's
    ``must_not_events``, the bytes cannot be parsed, or the first
    non-ignored message fails to match.
    """
    super().action(runner)
    conn = self.find_conn(runner)

    while True:
        raw = runner.get_output_message(conn, self)
        if raw is None:
            raise EventError(self, "Did not receive a message from runner")

        banned_by = next((ev for ev in conn.must_not_events if ev.matches(raw)), None)
        if banned_by is not None:
            raise EventError(self, "Got msg banned by {}: {}"
                             .format(banned_by, raw.hex()))

        # Might be completely unknown to namespace.
        try:
            parsed = Message.read(event_namespace, io.BytesIO(raw))
        except ValueError as ve:
            raise EventError(self, "Runner gave bad msg {}: {}".format(raw.hex(), ve))

        # Ignore function may tell us to respond.
        replies = self.ignore(parsed)
        if replies is not None:
            for reply in replies:
                out = io.BytesIO()
                reply.write(out)
                runner.recv(self, conn, out.getvalue())
            continue

        problem = self.message_match(runner, parsed)
        if problem:
            raise EventError(self, "{}: message was {}".format(problem, parsed.to_str()))
        return True
def has_feature(featurebits: List[int], event: Event, msg: Message,
                runner: Optional['Runner'] = None) -> None:
    """Raise EventError unless every bit in ``featurebits`` is set in msg's ``features``.

    Meant to be used as an ``if_match`` callback, presumably with
    ``featurebits`` pre-bound (e.g. via functools.partial).  ``runner`` is
    unused and defaults to None.  That way the callback works both when the
    caller passes only ``(event, msg)``, as ``message_match`` does, and when
    it also passes the runner.
    """
    for bit in featurebits:
        if not has_bit(msg.fields['features'], bit):
            raise EventError(
                event, "features set bit {} unset: {}".format(bit, msg.to_str()))
def test_message_constructor():
    """A TLV stream given as a constructor kwarg (with an unknown odd type 4) encodes canonically."""
    ns = MessageNamespace([
        'msgtype,test1,1',
        'msgdata,test1,tlvs,test_tlvstream,',
        'tlvtype,test_tlvstream,tlv1,1',
        'tlvdata,test_tlvstream,tlv1,field1,byte,4',
        'tlvdata,test_tlvstream,tlv1,field2,u32,',
        'tlvtype,test_tlvstream,tlv2,255',
        'tlvdata,test_tlvstream,tlv2,field3,byte,...',
    ])
    msgtype = ns.get_msgtype('test1')
    m = Message(msgtype,
                tlvs='{tlv1={field1=01020304,field2=5},tlv2={field3=01020304},4=010203}')

    out = io.BytesIO()
    m.write(out)

    # Records come out in ascending type order: 1, 4, then 255 (varint fd00ff).
    expected = bytes.fromhex('0001'
                             + '01' '08' '01020304' '00000005'
                             + '04' '03' '010203'
                             + 'fd00ff' '04' '01020304')
    assert out.getvalue() == expected
def message_match(self, runner: 'Runner', msg: Message) -> Optional[str]:
    """Does this message match what we expect?

    Builds the expected (possibly partial) message from this event's resolved
    kwargs and compares it with ``cmp_msg``.  On a match, runs the
    ``if_match`` hook, stashes ``msg`` and returns None; otherwise returns
    the complaint string without side effects.
    """
    expected = Message(self.msgtype, **self.resolve_args(runner, self.kwargs))
    complaint = cmp_msg(msg, expected)
    if complaint is not None:
        return complaint

    self.if_match(self, msg)
    msg_to_stash(runner, self, msg)
    return None
def test_tlv():
    """TLV streams round-trip through text and wire forms, and encode in canonical order."""
    ns = MessageNamespace()
    ns.load_csv([
        'msgtype,test1,1',
        'msgdata,test1,tlvs,test_tlvstream,',
        'tlvtype,test_tlvstream,tlv1,1',
        'tlvdata,test_tlvstream,tlv1,field1,byte,4',
        'tlvdata,test_tlvstream,tlv1,field2,u32,',
        'tlvtype,test_tlvstream,tlv2,255',
        'tlvdata,test_tlvstream,tlv2,field3,byte,...',
    ])

    header = '0001'
    tlv1 = '01' '08' '01020304' '00000005'   # type 1, len 8
    unknown4 = '04' '03' '010203'            # unknown odd type 4, len 3
    tlv2 = 'fd00ff' '04' '01020304'          # type 255 as varint, len 4

    cases = [
        ("test1 tlvs={tlv1={field1=01020304,field2=5}}",
         bytes.fromhex(header + tlv1)),
        ("test1 tlvs={tlv1={field1=01020304,field2=5},tlv2={field3=01020304}}",
         bytes.fromhex(header + tlv1 + tlv2)),
        ("test1 tlvs={tlv1={field1=01020304,field2=5},4=010203,tlv2={field3=01020304}}",
         bytes.fromhex(header + tlv1 + unknown4 + tlv2)),
    ]
    for text, wire in cases:
        m = Message.from_str(ns, text)
        assert m.to_str() == text

        out = io.BytesIO()
        m.write(out)
        assert out.getvalue() == wire

        assert Message.read(ns, io.BytesIO(wire)).to_str() == text

    # Ordering test (turns into canonical ordering)
    shuffled = Message.from_str(
        ns, 'test1 tlvs={tlv1={field1=01020304,field2=5},tlv2={field3=01020304},4=010203}')
    out = io.BytesIO()
    shuffled.write(out)
    assert out.getvalue() == bytes.fromhex(header + tlv1 + unknown4 + tlv2)
def has_one_feature(featurebits: List[int], event: Event, msg: Message,
                    runner: Optional['Runner'] = None) -> None:
    """Raise EventError unless at least one bit in ``featurebits`` is set in msg's ``features``.

    Meant to be used as an ``if_match`` callback.  ``runner`` is unused and
    defaults to None.  That way the callback works both when the caller
    passes only ``(event, msg)``, as ``message_match`` does, and when it also
    passes the runner.  An empty ``featurebits`` always raises.
    """
    if not any(has_bit(msg.fields['features'], bit) for bit in featurebits):
        raise EventError(
            event, "none of {} set: {}".format(featurebits, msg.to_str()))
def calc_checksum(update: Message) -> int:
    """Return the BOLT #7 CRC32C checksum of a ``channel_update``.

    BOLT #7: The checksum of a `channel_update` is the CRC32C checksum as
    specified in [RFC3720](https://tools.ietf.org/html/rfc3720#appendix-B.4)
    of this `channel_update` without its `signature` and `timestamp` fields.
    """
    wire = io.BytesIO()
    update.write(wire)
    raw = wire.getvalue()

    # Serialized layout (BOLT #7, after the 2-byte message type):
    #   signature(64) chain_hash(32) short_channel_id(8) timestamp(4) ...
    type_len, sig_len, chain_hash_len, scid_len, timestamp_len = 2, 64, 32, 8, 4
    kept_start = type_len + sig_len
    timestamp_start = kept_start + chain_hash_len + scid_len
    timestamp_end = timestamp_start + timestamp_len

    # Skip type+signature, keep chain_hash+short_channel_id, skip timestamp,
    # keep everything after it.
    return crc32c.crc32(raw[kept_start:timestamp_start] + raw[timestamp_end:])
def test_static_array():
    """Fixed-length arrays of bytes and of short_channel_ids round-trip."""
    ns = MessageNamespace()
    ns.load_csv(['msgtype,test1,1',
                 'msgdata,test1,test_arr,byte,4'])
    ns.load_csv(['msgtype,test2,2',
                 'msgdata,test2,test_arr,short_channel_id,4'])

    # A short_channel_id is block(3 bytes) tx(3 bytes) output(2 bytes).
    scids = ('0000000000010002'     # 0x1x2
             '0000040000050006'     # 4x5x6
             '0000070000080009'     # 7x8x9
             '00000a00000b000c')    # 10x11x12
    cases = (
        ("test1 test_arr=00010203", bytes.fromhex('0001' + '00010203')),
        ("test2 test_arr=[0x1x2,4x5x6,7x8x9,10x11x12]", bytes.fromhex('0002' + scids)),
    )
    for text, wire in cases:
        m = Message.from_str(ns, text)
        assert m.to_str() == text

        out = io.BytesIO()
        m.write(out)
        assert out.getvalue() == wire

        assert Message.read(ns, io.BytesIO(wire)).to_str() == text
def node_announcement(self, side: Side, features: str, rgb_color: Tuple[int, int, int], alias: str, addresses: bytes, timestamp: int) -> Message:
    """Build a node_announcement for ``side``, signed with that side's node key.

    ``alias`` is UTF-8 encoded and padded to 32 bytes with trailing zero
    bytes.  Raises ValueError if the encoded alias is longer than 32 bytes.
    """
    # BOLT #7: the 32-byte alias is a UTF-8 string whose unused trailing
    # bytes are 0.  (bytes.zfill would instead LEFT-pad with ASCII '0' = 0x30,
    # turning "foo" into "0000...foo".)
    alias_bytes = bytes(alias, encoding='utf-8')
    if len(alias_bytes) > 32:
        raise ValueError("alias {!r} is {} bytes when UTF-8 encoded; at most 32 allowed"
                         .format(alias, len(alias_bytes)))

    # Begin with a fake signature.
    ann = Message(namespace().get_msgtype('node_announcement'),
                  signature=Sig(bytes(64)),
                  features=features,
                  timestamp=timestamp,
                  node_id=self.node_id(side).format().hex(),
                  rgb_color=bytes(rgb_color).hex(),
                  alias=alias_bytes.ljust(32, b'\x00'),
                  addresses=addresses)

    # BOLT #7:
    #   - MUST set `signature` to the signature of the double-SHA256 of the entire
    #   remaining packet after `signature` (using the key given by `node_id`).
    buf = io.BytesIO()
    ann.write(buf)
    # Note the first two 'type' bytes!
    h = sha256(sha256(buf.getvalue()[2 + 64:]).digest()).digest()
    ann.set_field('signature', Sig(self.node_privkeys[side].secret.hex(), h.hex()))
    return ann
def _unsigned_channel_announcment(self, features: str, short_channel_id: str) -> Message:
    """Produce a channel_announcement message with dummy sigs.

    All four signature fields are zero-filled 64-byte placeholders.  The node
    ids come from ``self.node_ids()`` and the bitcoin keys from
    ``self.funding_pubkeys_for_gossip()``, in the order those return them.
    """
    ids = self.node_ids()
    keys = self.funding_pubkeys_for_gossip()
    msgtype = event_namespace.get_msgtype('channel_announcement')

    # A distinct placeholder Sig object for each signature field.
    fields = {name: Sig(bytes(64))
              for name in ('node_signature_1', 'node_signature_2',
                           'bitcoin_signature_1', 'bitcoin_signature_2')}
    fields.update(features=features,
                  chain_hash=self.chain_hash,
                  short_channel_id=short_channel_id,
                  node_id_1=ids[0].format(),
                  node_id_2=ids[1].format(),
                  bitcoin_key_1=keys[0].format(),
                  bitcoin_key_2=keys[1].format())
    return Message(msgtype, **fields)
def get_output_message(self, conn: Conn, event: ExpectMsg) -> Optional[bytes]:
    """Fabricate the serialized message ``event`` is waiting for.

    Starts from the event's own resolved fields and fills every field still
    missing with ``self.fake_field`` for that field's type.  In verbose mode,
    logs the call first.
    """
    if self.config.getoption('verbose'):
        print("[GET_OUTPUT_MESSAGE {}]".format(conn))

    # We make the message they were expecting.
    fabricated = Message(event.msgtype, **event.resolve_args(self, event.kwargs))

    # Fake up the other fields.
    for absent in fabricated.missing_fields():
        field_def = fabricated.messagetype.find_field(absent.name)
        fabricated.set_field(absent.name, self.fake_field(field_def.fieldtype))

    wire = io.BytesIO()
    fabricated.write(wire)
    return wire.getvalue()
def action(self, runner: 'Runner') -> bool: super().action(runner) # Check they all use the same conn! conn: Optional[Conn] = None for s in self.sequences: c = cast(ExpectMsg, s.events[0]).find_conn(runner) if conn is None: conn = c elif c != conn: raise SpecFileError(self, "sequences do not all use the same conn?") assert conn while True: binmsg = runner.get_output_message(conn, self.sequences[0].events[0]) if binmsg is None: raise EventError(self, "Did not receive a message from runner") try: msg = Message.read(namespace(), io.BytesIO(binmsg)) except ValueError as ve: raise EventError(self, "Invalid msg {}: {}".format(binmsg.hex(), ve)) ignored = Sequence.ignored_by_all(msg, self.enabled_sequences(runner)) # If they gave us responses, send those now. if ignored is not None: for msg in ignored: binm = io.BytesIO() msg.write(binm) runner.recv(self, conn, binm.getvalue()) continue seq = Sequence.match_which_sequence(runner, msg, self.enabled_sequences(runner)) if seq is not None: # We found the sequence, run it return seq.action(runner, skip_first=True) raise EventError(self, "None of the sequences {} matched {}".format(self.enabled_sequences(runner), msg.to_str()))
def action(self, runner: 'Runner') -> bool:
    """Run every enabled sequence, in whatever order their first messages arrive.

    All sequences must begin with an event on the same connection, otherwise
    SpecFileError is raised.  Each incoming message is handled in one of
    three ways:

    - If it is ignored by all remaining sequences, it is skipped, and any
      responses the ignore hooks produced (e.g. a pong for a ping) are sent
      back on the connection.
    - If it starts one of the remaining sequences, that sequence is run
      (skipping its already-matched first event) and removed from the set.
    - Otherwise EventError is raised.

    Returns True only if every sequence's action returned True.
    """
    super().action(runner)

    # Check they all use the same conn!
    conn: Optional[Conn] = None
    for s in self.sequences:
        c = cast(ExpectMsg, s.events[0]).find_conn(runner)
        if conn is None:
            conn = c
        elif c != conn:
            raise SpecFileError(self, "sequences do not all use the same conn?")
    assert conn

    all_done = True
    # Private copy: completed sequences are removed below, and we must not
    # mutate whatever list enabled_sequences() handed back.
    sequences = list(self.enabled_sequences(runner))
    while sequences:
        # Get message
        binmsg = runner.get_output_message(conn, sequences[0].events[0])
        if binmsg is None:
            raise EventError(self, "Did not receive a message from runner, still expecting {}"
                             .format([s.events[0] for s in sequences]))

        try:
            msg = Message.read(namespace(), io.BytesIO(binmsg))
        except ValueError as ve:
            raise EventError(self, "Invalid msg {}: {}".format(binmsg.hex(), ve))

        # ignored_by_all returns None for "not ignored".  Otherwise it returns
        # a list, possibly empty, of responses we must send back before
        # waiting again.  Testing its truthiness would drop those responses
        # and treat an empty list as "not ignored".
        ignored = Sequence.ignored_by_all(msg, sequences)
        if ignored is not None:
            for response in ignored:
                binm = io.BytesIO()
                response.write(binm)
                runner.recv(self, conn, binm.getvalue())
            continue

        seq = Sequence.match_which_sequence(runner, msg, sequences)
        if seq is not None:
            sequences.remove(seq)
            all_done &= seq.action(runner, skip_first=True)
            continue

        raise EventError(self,
                         "Message did not match any sequences {}: {}"
                         .format([s.events[0] for s in sequences], msg.to_str()))
    return all_done
def test_dynamic_array():
    """Test that dynamic array types enforce matching lengths"""
    ns = MessageNamespace([
        'msgtype,test1,1',
        'msgdata,test1,count,u16,',
        'msgdata,test1,arr1,byte,count',
        'msgdata,test1,arr2,u32,count',
    ])
    msgtype = ns.get_msgtype('test1')

    # This one is fine: both arrays have 4 elements, so count is 4.
    good = Message(msgtype, arr1='01020304', arr2='[1,2,3,4]')
    out = io.BytesIO()
    good.write(out)
    assert out.getvalue() == bytes.fromhex(
        '0001' + '0004' + '01020304'
        + '00000001' '00000002' '00000003' '00000004')

    # These ones are not: arr2 disagrees with arr1 on the shared count.
    for bad_arr2 in ('[1,2,3]', '[1,2,3,4,5]'):
        with pytest.raises(ValueError, match='Inconsistent length.*count'):
            Message(msgtype, arr1='01020304', arr2=bad_arr2)
def ignore_pings(msg: Message) -> Optional[List[Message]]:
    """Function to ignore pings (and respond with pongs appropriately).

    Returns None for any non-ping message (i.e. "not ignored").  For a ping it
    returns the list of messages to send in reply: one pong, or nothing when
    the requested size is too large.
    """
    if msg.messagetype.name == 'ping':
        wanted = msg.fields['num_pong_bytes']
        # BOLT #1:
        # A node receiving a `ping` message:
        # ...
        #  - if `num_pong_bytes` is less than 65532:
        #    - MUST respond by sending a `pong` message, with `byteslen` equal
        #      to `num_pong_bytes`.
        #  - otherwise (`num_pong_bytes` is **not** less than 65532):
        #    - MUST ignore the `ping`.
        if wanted >= 65532:
            return []

        # A node sending a `pong` message:
        #  - SHOULD set `ignored` to 0s.
        #  - MUST NOT set `ignored` to sensitive data such as secrets or
        #    portions of initialized memory.
        pong = Message(event_namespace.get_msgtype('pong'), ignored='00' * wanted)
        return [pong]

    return None
def channel_update(self,
                   short_channel_id: str,
                   side: Side,
                   disable: bool,
                   cltv_expiry_delta: int,
                   htlc_minimum_msat: int,
                   fee_base_msat: int,
                   fee_proportional_millionths: int,
                   timestamp: int,
                   htlc_maximum_msat: Optional[int]) -> Message:
    """Build a channel_update originating from ``side``, signed with its node key.

    ``htlc_maximum_msat`` of None omits that optional field and leaves the
    ``option_channel_htlc_max`` message flag clear.  Any int, including 0, is
    included and sets the flag.
    """
    # BOLT #7: The `channel_flags` bitfield is used to indicate the
    # direction of the channel: it identifies the node that this update
    # originated from and signals various options concerning the
    # channel. The following table specifies the meaning of its individual
    # bits:
    #
    # | Bit Position  | Name        | Meaning                          |
    # | ------------- | ----------- | -------------------------------- |
    # | 0             | `direction` | Direction this update refers to. |
    # | 1             | `disable`   | Disable the channel.             |

    # BOLT #7:
    #   - if the origin node is `node_id_1` in the message:
    #     - MUST set the `direction` bit of `channel_flags` to 0.
    #   - otherwise:
    #     - MUST set the `direction` bit of `channel_flags` to 1.
    if self.funding_pubkey(side) == self.funding_pubkeys_for_gossip()[0]:
        channel_flags = 0
    else:
        channel_flags = 1

    if disable:
        channel_flags |= 2

    # BOLT #7: The `message_flags` bitfield is used to indicate the
    # presence of optional fields in the `channel_update` message:
    #
    # | Bit Position  | Name                      | Field                            |
    # | ------------- | ------------------------- | -------------------------------- |
    # | 0             | `option_channel_htlc_max` | `htlc_maximum_msat`              |
    message_flags = 0
    # Test against None, not truthiness: an explicit 0 is still a value the
    # caller asked to send and must not silently drop the field.
    if htlc_maximum_msat is not None:
        message_flags |= 1

    # Begin with a fake signature.
    update = Message(namespace().get_msgtype('channel_update'),
                     short_channel_id=short_channel_id,
                     signature=Sig(bytes(64)),
                     chain_hash=self.chain_hash,
                     timestamp=timestamp,
                     message_flags=message_flags,
                     channel_flags=channel_flags,
                     cltv_expiry_delta=cltv_expiry_delta,
                     htlc_minimum_msat=htlc_minimum_msat,
                     fee_base_msat=fee_base_msat,
                     fee_proportional_millionths=fee_proportional_millionths)
    if htlc_maximum_msat is not None:
        update.set_field('htlc_maximum_msat', htlc_maximum_msat)

    # BOLT #7:
    #  - MUST set `signature` to the signature of the double-SHA256 of the
    #    entire remaining packet after `signature`, using its own `node_id`.
    buf = io.BytesIO()
    update.write(buf)
    # Note the first two 'type' bytes!
    h = sha256(sha256(buf.getvalue()[2 + 64:]).digest()).digest()

    update.set_field('signature', Sig(self.node_privkeys[side].secret.hex(), h.hex()))

    return update
def no_feature(featurebits: List[int], event: Event, msg: Message,
               runner: Optional['Runner'] = None) -> None:
    """Raise EventError if any bit in ``featurebits`` is set in msg's ``features``.

    Meant to be used as an ``if_match`` callback.  ``runner`` is unused and
    defaults to None.  It is accepted so this callback has the same signature
    as has_feature/has_one_feature, and so works whether or not the caller
    passes the runner.
    """
    for bit in featurebits:
        if has_bit(msg.fields['features'], bit):
            raise EventError(
                event, "features set bit {} unexpected: {}".format(bit, msg.to_str()))