def test_oneof_nested_oneof_messages_are_serialized_with_defaults():
    """
    A oneof member that is itself a message containing a oneof must be
    reported identically before and after a JSON round trip, even when
    every value involved is a default (zero) value.
    """

    def _nested_default() -> NestedMessage:
        # Build a fresh instance each time so the expected value is never
        # the very same object that is stored on the message under test.
        return NestedMessage(id=0, wrapped_message_value=Message(value=0))

    message = Test(wrapped_nested_message_value=_nested_default())

    selected_before = betterproto.which_one_of(message, "value_type")
    round_tripped = Test().from_json(message.to_json())
    selected_after = betterproto.which_one_of(round_tripped, "value_type")

    expected = ("wrapped_nested_message_value", _nested_default())
    assert selected_before == selected_after == expected
def test_which_one_of_returns_second_field_when_set():
    """
    Loading the default "oneof_enum" test case populates `move`, leaves
    `signal` at 0, and makes which_one_of report `move` for the "action"
    group.
    """
    parsed = Test()
    parsed.from_json(get_test_case_json_data("oneof_enum"))

    assert parsed.move == Move(x=2, y=3)
    assert parsed.signal == 0

    selected = betterproto.which_one_of(parsed, "action")
    assert selected == ("move", Move(x=2, y=3))
def test_which_one_of_returns_enum_with_non_default_value():
    """
    returns first field when it is enum and set with non default value
    """
    message = Test()
    message.from_json(
        get_test_case_json_data("oneof_enum", "oneof_enum-enum-1.json"))

    # The other member of the "action" group must stay unset.
    assert message.move is None

    # which_one_of hands back the attribute's own value, so the direct
    # attribute check has to agree with the tuple asserted below. The
    # previous expectation (Signal.PASS) contradicted it.
    assert message.signal == Signal.RESIGN
    assert betterproto.which_one_of(message, "action") == ("signal",
                                                           Signal.RESIGN)
def parse_tags(feature: TileFeature, layer: TileLayer,
               show_names: bool) -> dict:
    """
    Build a flat dict describing a tile feature.

    The result starts with three fixed entries ('*ID*', 'GeoSize',
    'GeoType') followed by the feature's decoded tags. `feature.tags` is
    read as consecutive (key index, value index) pairs that point into
    `layer.keys` and `layer.values`; each value is unwrapped through its
    "val" oneof. Keys starting with "name:" are skipped unless
    `show_names` is true.
    """
    result = {
        '*ID*': feature.id,
        'GeoSize': f"{len(feature.geometry):,}",
        'GeoType': TileGeomType(feature.type).name,
    }

    for key_pos in range(0, len(feature.tags), 2):
        key = layer.keys[feature.tags[key_pos]]
        if not show_names and key.startswith("name:"):
            continue
        wrapped_value = layer.values[feature.tags[key_pos + 1]]
        result[key] = which_one_of(wrapped_value, "val")[1]

    return result
def test_which_one_of_returns_enum_with_non_default_value():
    """
    When the enum member of the "action" group is set to a non-default
    value, which_one_of reports that member and its value.
    """
    parsed = Test()
    fixture = get_test_case_json_data("oneof_enum", "oneof_enum-enum-1.json")
    parsed.from_json(fixture)

    # Proto3 has no null, so the unselected message field reads back as
    # its default-constructed value.
    assert parsed.move == Move(x=0, y=0)
    assert parsed.signal == Signal.RESIGN

    selected = betterproto.which_one_of(parsed, "action")
    assert selected == ("signal", Signal.RESIGN)
def test_oneof_default_value_set_causes_writes_wire():
    """
    Explicitly setting a oneof member to its default value must still put
    it on the wire, and which_one_of must agree before and after a binary
    round trip. A message with nothing set serializes to empty bytes and
    reports no selection.
    """

    @dataclass
    class Foo(betterproto.Message):
        bar: int = betterproto.int32_field(1, group="group1")
        baz: str = betterproto.string_field(2, group="group1")

    def _selected_after_round_trip(msg):
        reparsed = Foo().parse(bytes(msg))
        return betterproto.which_one_of(reparsed, "group1")

    # (message, expected wire bytes, expected which_one_of result)
    cases = [
        (Foo(bar=0), b"\x08\x00", ("bar", 0)),
        (Foo(baz=""), b"\x12\x00", ("baz", "")),  # baz is an empty string
        (Foo(), b"", ("", None)),
    ]

    for msg, wire, selection in cases:
        assert bytes(msg) == wire
        assert (betterproto.which_one_of(msg, "group1") ==
                _selected_after_round_trip(msg) == selection)
def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool:
    """
    True if proto_field_obj is a OneOf, otherwise False.

    .. warning::
        Because the message from protoc is defined in proto2, and
        betterproto works with proto3, and interpreting the
        FieldDescriptorProto.oneof_index field requires distinguishing
        between default and unset values (which proto3 doesn't support),
        we have to hack the generated FieldDescriptorProto class for this
        to work. The hack consists of setting group="oneof_index" in the
        field metadata, essentially making oneof_index the sole member of
        a one_of group, which allows us to tell whether it was set, via
        the which_one_of interface.
    """
    selected_name, _ = which_one_of(proto_field_obj, "oneof_index")
    return selected_name == "oneof_index"
def parse_tags(feature: TileFeature, layer: TileLayer, show_names: bool,
               summary: bool) -> dict:
    """
    Build a dict describing a tile feature.

    `feature.tags` is read as consecutive (key index, value index) pairs
    that point into `layer.keys` and `layer.values`; each value is
    unwrapped through its "val" oneof.

    With `summary` set, "name:" keys are always kept, 'GeoSize' is the raw
    geometry length, and the decoded tags are nested under 'tags'.
    Otherwise 'GeoSize' is a comma-grouped string, "name:" keys are kept
    only when `show_names` is true, and the tags are merged into the top
    level of the result.
    """
    keep_names = show_names or summary
    geo_size = len(feature.geometry)

    if summary:
        size_entry = geo_size
    else:
        size_entry = f'{geo_size :,}'

    res = {
        '*ID*': feature.id,
        'GeoSize': size_entry,
        'GeoType': TileGeomType(feature.type).name,
    }

    tags = {}
    for key_pos in range(0, len(feature.tags), 2):
        key = layer.keys[feature.tags[key_pos]]
        if not keep_names and key.startswith('name:'):
            continue
        wrapped_value = layer.values[feature.tags[key_pos + 1]]
        tags[key] = which_one_of(wrapped_value, 'val')[1]

    if summary:
        res['tags'] = tags
    else:
        res.update(tags)
    return res
def preprocess_message(self, update: pb.FeedMessage):
    """
    Inspect an incoming feed message and emit an XChangeWarning when it
    signals something the competitor should notice: a denied request, a
    liquidation, a generic event that is not a plain message, or market
    snapshots arriving later than the smallest delay seen so far.
    """
    msg_type, _ = betterproto.which_one_of(update, "msg")

    if msg_type == "request_failed_msg":
        warnings.warn(
            "REQUEST_DENIED: " + update.request_failed_msg.message,
            XChangeWarning,
        )
    elif msg_type == "liquidation_msg":
        warnings.warn(
            "LIQUIDATION: " + update.liquidation_msg.message,
            XChangeWarning,
        )
    elif msg_type == "generic_msg":
        event_type = update.generic_msg.event_type
        if event_type != pb.GenericMessageType.MESSAGE:
            event_name = pb.GenericMessageType(event_type).name
            warnings.warn(
                f"{event_name}: {update.generic_msg.message}",
                XChangeWarning,
            )
    elif msg_type == "market_snapshot_msg":
        sent_at = datetime.fromisoformat(update.market_snapshot_msg.timestamp)
        # Gap between the local clock and the snapshot's own timestamp.
        lag = datetime.now().timestamp() - sent_at.timestamp()

        baseline = self.__time_differential
        if baseline is None:
            # First snapshot seen: it becomes the reference gap.
            baseline = lag
        elif lag - baseline > 2:
            behind = int(lag - baseline)
            warnings.warn(
                f"DESYNC: bot is receiving messages faster than it is processing them ({behind}s behind)",
                XChangeWarning,
            )
        # Keep the smallest gap observed as the reference.
        self.__time_differential = min(baseline, lag)
def test_oneof_support():
    """
    Exercise oneof groups end to end: setting one member clears its
    siblings, separate groups do not affect each other, an explicitly set
    zero value is serialized, and the selection survives a binary round
    trip.
    """

    @dataclass
    class Sub(betterproto.Message):
        val: int = betterproto.int32_field(1)

    @dataclass
    class Foo(betterproto.Message):
        bar: int = betterproto.int32_field(1, group="group1")
        baz: str = betterproto.string_field(2, group="group1")
        sub: Sub = betterproto.message_field(3, group="group2")
        abc: str = betterproto.string_field(4, group="group2")

    foo = Foo()

    assert betterproto.which_one_of(foo, "group1")[0] == ""

    foo.bar = 1
    foo.baz = "test"

    # Other oneof fields should now be unset
    assert foo.bar == 0
    assert betterproto.which_one_of(foo, "group1")[0] == "baz"

    foo.sub.val = 1
    assert betterproto.serialized_on_wire(foo.sub)

    foo.abc = "test"

    # Group 1 shouldn't be touched, group 2 should have reset
    assert betterproto.which_one_of(foo, "group1")[0] == "baz"
    assert foo.sub.val == 0
    assert betterproto.serialized_on_wire(foo.sub) is False
    assert betterproto.which_one_of(foo, "group2")[0] == "abc"

    # Zero value should always serialize for one-of
    foo = Foo(bar=0)
    assert betterproto.which_one_of(foo, "group1")[0] == "bar"
    assert bytes(foo) == b"\x08\x00"

    # Round trip should also work. The value check is made on the parsed
    # copy: asserting on `foo` itself would pass regardless of what
    # parsing produced.
    foo2 = Foo().parse(bytes(foo))
    assert betterproto.which_one_of(foo2, "group1")[0] == "bar"
    assert foo2.bar == 0
    assert betterproto.which_one_of(foo2, "group2")[0] == ""
def test_which_name():
    """
    Loading "oneof-name.json" selects the `name` member of group "foo".
    """
    parsed = Test()
    fixture = get_test_case_json_data("oneof", "oneof-name.json")
    parsed.from_json(fixture)

    selected = betterproto.which_one_of(parsed, "foo")
    assert selected == ("name", "foobar")
def test_oneof_no_default_values_passed():
    """
    A message with no oneof member set reports ("", None) for the
    "value_type" group, both directly and after a JSON round trip.
    """
    message = Test()

    selected_before = betterproto.which_one_of(message, "value_type")
    round_tripped = Test().from_json(message.to_json())
    selected_after = betterproto.which_one_of(round_tripped, "value_type")

    assert selected_before == selected_after == ("", None)
async def handle_exchange_update(self, update: pb.FeedMessage):
    """
    React to one exchange feed message, dispatching on which member of the
    "msg" oneof is populated.

    - pnl_msg: recompute a local mark-to-market figure from tracked cash,
      positions and mids, and print it next to the exchange's numbers.
    - fill_msg: update cash and position, re-quote the futures once the
      module-level `checked` counter exceeds 300, then run the spot logic.
    - market_snapshot_msg: refresh the stored mid price for each asset.
    - order_cancelled_msg / request_failed_msg: print a short notice.
    - generic_msg: handle competition events (a daily rates line or a new
      federal funds target); anything else is printed unchanged.
    """
    global checked

    kind, _ = betterproto.which_one_of(update, "msg")
    # Possible exchange updates: 'market_snapshot_msg', 'fill_msg',
    # 'liquidation_msg', 'generic_msg', 'trade_msg', 'pnl_msg', etc.

    if kind == "pnl_msg":
        # Calculate PnL based upon mark-to-market contracts and tracked cash
        suffixes = ["H", "M", "U", "Z"]
        my_m2m = self.cash
        for asset in ([i + j for i in ["6R", "6H"] for j in suffixes] +
                      ["RORUSD"]):
            if self.mid[asset] is not None:
                my_m2m += self.mid[asset] * self.pos[asset]
        # The RH* contracts are additionally scaled by the RORUSD mid, so
        # both mids must be known before they contribute.
        for asset in ["RH" + j for j in suffixes]:
            if (self.mid[asset] is not None
                    and self.mid["RORUSD"] is not None):
                my_m2m += (self.mid[asset] * self.pos[asset] *
                           self.mid["RORUSD"])
        print("M2M", update.pnl_msg.realized_pnl, update.pnl_msg.m2m_pnl,
              my_m2m)

    # Update position upon fill messages of your trades
    elif kind == "fill_msg":
        fill = update.fill_msg
        traded_value = fill.filled_qty * float(fill.price)
        if fill.order_side == pb.FillMessageSide.BUY:
            self.cash -= traded_value
            self.pos[fill.asset] += fill.filled_qty
        else:
            self.cash += traded_value
            self.pos[fill.asset] -= fill.filled_qty

        # Re-quote all futures only once the fill counter passes 300,
        # then start counting again.
        if checked > 300:
            await self.place_bids(FUTURES)
            await self.place_asks(FUTURES)
            checked = 0
        checked += 1
        await self.spot_market()

    # Identify mid price through order book updates
    elif kind == "market_snapshot_msg":
        for asset in FUTURES + ["RORUSD"]:
            book = update.market_snapshot_msg.books[asset]
            best_ask = float(book.asks[0].px) if len(book.asks) > 0 else None
            best_bid = float(book.bids[0].px) if len(book.bids) > 0 else None

            mid: "Optional[float]"
            if best_ask is not None and best_bid is not None:
                mid = (best_ask + best_bid) / 2
            elif best_ask is not None:
                mid = best_ask
            else:
                # One-sided on the bid, or an empty book (None).
                mid = best_bid
            self.mid[asset] = mid

    elif kind == "order_cancelled_msg":
        print('order cancelled')

    elif kind == "request_failed_msg":
        print('request failed')

    # Competition event messages
    elif kind == "generic_msg":
        message = update.generic_msg.message
        data = message.split(',')

        if 0 < IsInt(data[0]):
            # Comma-separated line: day, then the ROR, HAP and USD rates.
            self.today = int(data[0])
            self.interestRates['ROR'] = daily_rate(float(data[1]))
            self.interestRates['HAP'] = daily_rate(float(data[2]))
            self.interestRates['USD'] = daily_rate(float(data[3]))
            print(message)
            await self.evaluate_fairs()
            await self.place_bids(FUTURES)
            await self.place_asks(FUTURES)
            await self.spot_market()
        elif "New Federal Funds Target" in message:
            # Work on the raw text: `data` is a list, so it has no
            # .split(), and a list membership test would only match a
            # whole comma-separated field.
            # NOTE(review): assumes the currency is the first word and the
            # target is the last word of the message - confirm the format.
            words = message.split(" ")
            currency = words[0]
            try:
                target = int(words[-1])
            except ValueError:
                # Unexpected layout: fall back to showing the raw message
                # instead of crashing the handler.
                print(message)
            else:
                self.federalRate[currency] = (target, self.today)
                print(f"!!!!! New Federal Funds Target for {currency}:{target}")
        else:
            print(message)
async def handle_exchange_update(self, update: pb.FeedMessage):
    """
    React to one exchange feed message, dispatching on which member of the
    "msg" oneof is populated.

    - pnl_msg: recompute a local mark-to-market figure from tracked cash,
      positions and mids, and print it next to the exchange's numbers.
    - fill_msg: update cash and position, re-quote the filled side of that
      asset (except for 'RORUSD') and record the new order ids, then run
      the spot logic.
    - market_snapshot_msg: refresh the stored mid price for each asset.
    - generic_msg: handle competition events (a daily rates line or a new
      federal funds target); anything else is printed unchanged.
    """
    kind, _ = betterproto.which_one_of(update, "msg")
    # Possible exchange updates: 'market_snapshot_msg', 'fill_msg',
    # 'liquidation_msg', 'generic_msg', 'trade_msg', 'pnl_msg', etc.

    if kind == "pnl_msg":
        # Calculate PnL based upon mark-to-market contracts and tracked cash
        suffixes = ["H", "M", "U", "Z"]
        my_m2m = self.cash
        for asset in ([i + j for i in ["6R", "6H"] for j in suffixes] +
                      ["RORUSD"]):
            if self.mid[asset] is not None:
                my_m2m += self.mid[asset] * self.pos[asset]
        # The RH* contracts are additionally scaled by the RORUSD mid, so
        # both mids must be known before they contribute.
        for asset in ["RH" + j for j in suffixes]:
            if (self.mid[asset] is not None
                    and self.mid["RORUSD"] is not None):
                my_m2m += (self.mid[asset] * self.pos[asset] *
                           self.mid["RORUSD"])
        print("M2M", update.pnl_msg.realized_pnl, update.pnl_msg.m2m_pnl,
              my_m2m)

    # Update position upon fill messages of your trades
    elif kind == "fill_msg":
        fill = update.fill_msg
        traded_value = fill.filled_qty * float(fill.price)
        if fill.order_side == pb.FillMessageSide.BUY:
            self.cash -= traded_value
            self.pos[fill.asset] += fill.filled_qty
            requote, order_ids = self.place_bids, self.bidorderid
        else:
            self.cash += traded_value
            self.pos[fill.asset] -= fill.filled_qty
            requote, order_ids = self.place_asks, self.askorderid

        if fill.asset != 'RORUSD':
            # Replace the quotes on the side that was just filled and
            # remember the ids the exchange assigned to them.
            reqs = await requote(fill.asset)
            resps = await asyncio.gather(*reqs)
            for i, resp in enumerate(resps):
                order_ids[fill.asset][i] = resp.order_id
        await self.spot_market()

    # Identify mid price through order book updates
    elif kind == "market_snapshot_msg":
        for asset in FUTURES + ["RORUSD"]:
            book = update.market_snapshot_msg.books[asset]
            best_ask = float(book.asks[0].px) if len(book.asks) > 0 else None
            best_bid = float(book.bids[0].px) if len(book.bids) > 0 else None

            mid: "Optional[float]"
            if best_ask is not None and best_bid is not None:
                mid = (best_ask + best_bid) / 2
            elif best_ask is not None:
                mid = best_ask
            else:
                # One-sided on the bid, or an empty book (None).
                mid = best_bid
            self.mid[asset] = mid

    # Competition event messages
    elif kind == "generic_msg":
        data = update.generic_msg.message.split(',')

        if 0 < IsInt(data[0]):
            # NOTE(review): this binds a function-local name only. If a
            # module-level TODAY is meant to be updated here, a
            # `global TODAY` declaration is required - confirm intent.
            TODAY = data[0]
            self.interestRates['ROR'] = daily_rate(float(data[1]))
            self.interestRates['HAP'] = daily_rate(float(data[2]))
            self.interestRates['USD'] = daily_rate(float(data[3]))
            print(update.generic_msg.message)
            await self.evaluate_fairs()

            bid_reqs = []
            ask_reqs = []
            for asset in FUTURES:
                bid_reqs += await self.place_bids(asset)
                ask_reqs += await self.place_asks(asset)
            bid_resps = await asyncio.gather(*bid_reqs)
            ask_resps = await asyncio.gather(*ask_reqs)

            # Responses are mapped back to (asset, slot) assuming each
            # asset contributed exactly two requests, in FUTURES order.
            for idx, resp in enumerate(bid_resps):
                asset_idx, slot = divmod(idx, 2)
                self.bidorderid[FUTURES[asset_idx]][slot] = resp.order_id
            for idx, resp in enumerate(ask_resps):
                asset_idx, slot = divmod(idx, 2)
                self.askorderid[FUTURES[asset_idx]][slot] = resp.order_id
            await self.spot_market()
        elif "New Federal Funds Target" in data[0]:
            # Split the text field that was just matched: `data` itself is
            # a list and has no .split().
            # NOTE(review): assumes the currency is the first word and the
            # target is the last word of that field - confirm the format.
            d = data[0].split(" ")
            currency = d[0]
            target = d[-1]
            print("New Federal Funds Target for " + currency + ":" + target)
        else:
            print(update.generic_msg.message)
def get_one_of(base: T, sealed=True) -> T:
    """
    Return the value currently held by a oneof group of `base`.

    The group looked up is "sealed_value" when `sealed` is truthy and
    "value" otherwise.
    """
    prefix = "sealed_" if sealed else ""
    return which_one_of(base, prefix + "value")[1]
async def stream_messages(self):
    """
    Consume events from `self.seabird.stream_events()` and schedule
    `handle_command` as a task for every event whose "inner" oneof is a
    command. Events of any other kind are ignored.
    """
    async for event in self.seabird.stream_events():
        kind, payload = betterproto.which_one_of(event, "inner")
        if kind != "command":
            continue
        # Scheduled as a task rather than awaited, so the stream keeps
        # being read while the command is handled.
        asyncio.create_task(self.handle_command(payload))
def test_which_count():
    """
    Loading the default "oneof" test case selects the `count` member of
    group "foo" with value 100.
    """
    parsed = Test()
    parsed.from_json(get_test_case_json_data("oneof"))

    selected = betterproto.which_one_of(parsed, "foo")
    assert selected == ("count", 100)
def assert_round_trip_serialization_works(message: Test) -> None:
    """
    Assert that the "value_type" oneof selection of `message` is unchanged
    after serializing it to JSON and parsing it into a fresh Test.
    """
    selected_before = betterproto.which_one_of(message, "value_type")
    round_tripped = Test().from_json(message.to_json())
    selected_after = betterproto.which_one_of(round_tripped, "value_type")
    assert selected_before == selected_after
def test_which_name():
    """
    Loading the first "oneof_name.json" test case selects the `pitier`
    member of group "foo".
    """
    first_case = get_test_case_json_data("oneof", "oneof_name.json")[0]

    parsed = Test()
    parsed.from_json(first_case.json)

    selected = betterproto.which_one_of(parsed, "foo")
    assert selected == ("pitier", "Mr. T")