def test_oneof_nested_oneof_messages_are_serialized_with_defaults():
    """
    A oneof member that is itself a nested message carrying a oneof, with
    every value left at its default, must report the same active field
    before and after a JSON round trip.
    """
    original = Test(
        wrapped_nested_message_value=NestedMessage(
            id=0,
            wrapped_message_value=Message(value=0),
        ),
    )

    active_before = betterproto.which_one_of(original, "value_type")
    round_tripped = Test().from_json(original.to_json())
    active_after = betterproto.which_one_of(round_tripped, "value_type")

    expected = (
        "wrapped_nested_message_value",
        NestedMessage(id=0, wrapped_message_value=Message(value=0)),
    )
    assert active_before == active_after == expected
def test_which_one_of_returns_second_field_when_set():
    """
    After loading the ``oneof_enum`` fixture, ``move`` must be the active
    member of the ``action`` group.
    """
    fixture_json = get_test_case_json_data("oneof_enum")
    parsed = Test()
    parsed.from_json(fixture_json)

    assert parsed.move == Move(x=2, y=3)
    assert parsed.signal == 0

    active = betterproto.which_one_of(parsed, "action")
    assert active == ("move", Move(x=2, y=3))
def test_which_one_of_returns_enum_with_non_default_value():
    """
    Returns the first field when it is an enum set to a non-default value.

    Loads the ``oneof_enum-enum-1.json`` fixture and checks that ``signal``
    is reported as the active member of the ``action`` group, and that the
    attribute read directly from the message agrees with that report.
    """
    message = Test()
    message.from_json(
        get_test_case_json_data("oneof_enum", "oneof_enum-enum-1.json"))
    assert message.move is None
    # The attribute must match the value which_one_of reports below; the
    # previous expectation (Signal.PASS) contradicted the final assertion.
    assert message.signal == Signal.RESIGN
    assert betterproto.which_one_of(message,
                                    "action") == ("signal", Signal.RESIGN)
Esempio n. 4
0
def parse_tags(feature: TileFeature, layer: TileLayer,
               show_names: bool) -> dict:
    """
    Build a flat dict describing a tile feature and its decoded tags.

    ``feature.tags`` is read as consecutive (key index, value index) pairs
    that are resolved through ``layer.keys`` and ``layer.values``. Each value
    is unwrapped with ``which_one_of(..., "val")``. Keys starting with
    ``name:`` are dropped unless ``show_names`` is true.
    """
    described = {
        '*ID*': feature.id,
        'GeoSize': f"{len(feature.geometry):,}",
        'GeoType': TileGeomType(feature.type).name,
    }

    tag_indexes = feature.tags
    for pos in range(0, len(tag_indexes), 2):
        key = layer.keys[tag_indexes[pos]]
        if not show_names and key.startswith("name:"):
            continue
        wrapped_value = layer.values[tag_indexes[pos + 1]]
        described[key] = which_one_of(wrapped_value, "val")[1]

    return described
def test_which_one_of_returns_enum_with_non_default_value():
    """
    The first field of the group is returned when it is an enum holding a
    non-default value.
    """
    fixture_json = get_test_case_json_data("oneof_enum",
                                           "oneof_enum-enum-1.json")
    parsed = Test()
    parsed.from_json(fixture_json)

    # Proto3 will default this as there is no null
    assert parsed.move == Move(x=0, y=0)
    assert parsed.signal == Signal.RESIGN

    active = betterproto.which_one_of(parsed, "action")
    assert active == ("signal", Signal.RESIGN)
Esempio n. 6
0
def test_oneof_default_value_set_causes_writes_wire():
    """
    A oneof member explicitly set to its default value is still written to
    the wire and is still reported as the active field after a binary round
    trip; a message with nothing set serializes to empty bytes.
    """
    @dataclass
    class Foo(betterproto.Message):
        bar: int = betterproto.int32_field(1, group="group1")
        baz: str = betterproto.string_field(2, group="group1")

    def _round_trip_serialization(foo: Foo) -> Foo:
        return Foo().parse(bytes(foo))

    def _active(foo: Foo):
        return betterproto.which_one_of(foo, "group1")

    # (message, expected wire bytes, expected active field of "group1")
    cases = (
        (Foo(bar=0), b"\x08\x00", ("bar", 0)),
        (Foo(baz=""), b"\x12\x00", ("baz", "")),  # Baz is just an empty string
        (Foo(), b"", ("", None)),
    )

    for message, wire, expected in cases:
        assert bytes(message) == wire
        assert (_active(message)
                == _active(_round_trip_serialization(message))
                == expected)
def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool:
    """
    Return True if proto_field_obj is a OneOf, otherwise False.

    .. warning::
        Because the message from protoc is defined in proto2, and betterproto
        works with proto3, and interpreting the
        FieldDescriptorProto.oneof_index field requires distinguishing between
        default and unset values (which proto3 doesn't support), we have to
        hack the generated FieldDescriptorProto class for this to work.
        The hack consists of setting group="oneof_index" in the field
        metadata, essentially making oneof_index the sole member of a one_of
        group, which allows us to tell whether it was set, via the
        which_one_of interface.
    """
    active_field_name = which_one_of(proto_field_obj, "oneof_index")[0]
    return active_field_name == "oneof_index"
Esempio n. 8
0
def parse_tags(feature: TileFeature, layer: TileLayer, show_names: bool,
               summary: bool) -> dict:
    """
    Describe a tile feature and its decoded tags as a dict.

    ``feature.tags`` is read as consecutive (key index, value index) pairs
    resolved through ``layer.keys`` / ``layer.values``; each value is
    unwrapped with ``which_one_of(..., 'val')``.

    With ``summary`` true, ``name:`` keys are always kept, ``GeoSize`` stays
    an int, and the tags are nested under ``'tags'``. Otherwise ``GeoSize``
    is a comma-grouped string, the tags are merged into the top level, and
    ``name:`` keys are kept only when ``show_names`` is true.
    """
    keep_names = True if summary else show_names
    geometry_len = len(feature.geometry)

    described = {
        '*ID*': feature.id,
        'GeoSize': geometry_len if summary else f'{geometry_len :,}',
        'GeoType': TileGeomType(feature.type).name,
    }

    decoded_tags = {}
    for pos in range(0, len(feature.tags), 2):
        key = layer.keys[feature.tags[pos]]
        if keep_names or not key.startswith('name:'):
            wrapped_value = layer.values[feature.tags[pos + 1]]
            decoded_tags[key] = which_one_of(wrapped_value, 'val')[1]

    if summary:
        described['tags'] = decoded_tags
        return described

    described.update(decoded_tags)
    return described
Esempio n. 9
0
    def preprocess_message(self, update: pb.FeedMessage):
        """
        Inspect a feed message and warn on the console about anything
        competitors should know.

        An ``XChangeWarning`` is issued for failed requests, liquidations,
        generic messages whose event type is not ``MESSAGE``, and market
        snapshots whose lag exceeds the smallest lag seen so far by more than
        two seconds. The smallest observed lag is kept on the instance.
        """
        msg_type, _ = betterproto.which_one_of(update, "msg")

        if msg_type == "request_failed_msg":
            text = "REQUEST_DENIED: " + update.request_failed_msg.message
            warnings.warn(text, XChangeWarning)
        elif msg_type == "liquidation_msg":
            text = "LIQUIDATION: " + update.liquidation_msg.message
            warnings.warn(text, XChangeWarning)
        elif msg_type == "generic_msg":
            event_type = update.generic_msg.event_type
            if event_type == pb.GenericMessageType.MESSAGE:
                return
            warnings.warn(
                f"{pb.GenericMessageType(event_type).name}: {update.generic_msg.message}",
                XChangeWarning,
            )
        elif msg_type == "market_snapshot_msg":
            snapshot_time = datetime.fromisoformat(
                update.market_snapshot_msg.timestamp)
            # Seconds between local "now" and the snapshot's own timestamp.
            diff = datetime.now().timestamp() - snapshot_time.timestamp()

            if self.__time_differential is None:
                # First snapshot seen: it becomes the baseline lag.
                self.__time_differential = diff
                return

            if diff - self.__time_differential > 2:
                warnings.warn(
                    f"DESYNC: bot is receiving messages faster than it is processing them ({int(diff - self.__time_differential)}s behind)",
                    XChangeWarning,
                )

            # Keep the smallest lag observed as the baseline.
            self.__time_differential = min(self.__time_differential, diff)
Esempio n. 10
0
def test_oneof_support():
    """
    Exercise oneof groups end to end: setting one member clears its
    siblings, groups are independent of each other, a zero value in a oneof
    is still serialized, and the active field survives a binary round trip.
    """
    @dataclass
    class Sub(betterproto.Message):
        val: int = betterproto.int32_field(1)

    @dataclass
    class Foo(betterproto.Message):
        bar: int = betterproto.int32_field(1, group="group1")
        baz: str = betterproto.string_field(2, group="group1")
        sub: Sub = betterproto.message_field(3, group="group2")
        abc: str = betterproto.string_field(4, group="group2")

    def active_field(message, group):
        # Name of the member currently set in `group` ("" when none is set).
        return betterproto.which_one_of(message, group)[0]

    foo = Foo()
    assert active_field(foo, "group1") == ""

    # Assigning a second member of the group must unset the first one.
    foo.bar = 1
    foo.baz = "test"
    assert foo.bar == 0
    assert active_field(foo, "group1") == "baz"

    foo.sub.val = 1
    assert betterproto.serialized_on_wire(foo.sub)

    # Switching group2 to `abc` resets `sub` and leaves group1 alone.
    foo.abc = "test"
    assert foo.sub.val == 0
    assert betterproto.serialized_on_wire(foo.sub) is False
    assert active_field(foo, "group2") == "abc"

    # A zero value inside a one-of is always written to the wire.
    foo = Foo(bar=0)
    assert active_field(foo, "group1") == "bar"
    assert bytes(foo) == b"\x08\x00"

    # Parsing the bytes back restores the same active field.
    foo2 = Foo().parse(bytes(foo))
    assert active_field(foo2, "group1") == "bar"
    assert foo.bar == 0
    assert active_field(foo2, "group2") == ""
Esempio n. 11
0
def test_which_name():
    """Loading ``oneof-name.json`` must leave ``name`` active in group ``foo``."""
    fixture_json = get_test_case_json_data("oneof", "oneof-name.json")
    parsed = Test()
    parsed.from_json(fixture_json)

    active = betterproto.which_one_of(parsed, "foo")
    assert active == ("name", "foobar")
def test_oneof_no_default_values_passed():
    """
    A message with nothing set reports no active ``value_type`` member, both
    directly and after a JSON round trip.
    """
    empty = Test()

    active_before = betterproto.which_one_of(empty, "value_type")
    round_tripped = Test().from_json(empty.to_json())
    active_after = betterproto.which_one_of(round_tripped, "value_type")

    assert active_before == active_after == ("", None)
Esempio n. 13
0
    async def handle_exchange_update(self, update: pb.FeedMessage):
        """
        React to one exchange feed message, dispatching on the active field
        of its ``msg`` oneof.

        Possible exchange updates include 'market_snapshot_msg', 'fill_msg',
        'liquidation_msg', 'generic_msg', 'trade_msg', 'pnl_msg', etc. The
        branches handled here are:

        * ``pnl_msg`` - compute a mark-to-market value from tracked cash,
          mid prices and positions, and print it next to the exchange's PnL.
        * ``fill_msg`` - update cash and position for one of our fills.
        * ``market_snapshot_msg`` - refresh the mid price of every tracked
          asset from the top of its order book.
        * ``order_cancelled_msg`` / ``request_failed_msg`` - print a notice.
        * ``generic_msg`` - competition event text: either a comma-separated
          daily update (day number plus three interest rates) or a
          "New Federal Funds Target" announcement.

        Any other kind is ignored.
        """
        # Module-level counter read and rebound by the fill branch below.
        global checked

        kind, _ = betterproto.which_one_of(update, "msg")

        if kind == "pnl_msg":
            my_m2m = self.cash
            for asset in (
                [i + j for i in ["6R", "6H"]
                 for j in ["H", "M", "U", "Z"]] + ["RORUSD"]):
                # Assets without a known mid contribute nothing.
                my_m2m += self.mid[asset] * self.pos[asset] if self.mid[
                    asset] is not None else 0
            for asset in (["RH" + j for j in ["H", "M", "U", "Z"]]):
                # These positions are additionally scaled by the RORUSD mid.
                my_m2m += (self.mid[asset] * self.pos[asset] *
                           self.mid["RORUSD"] if
                           (self.mid[asset] is not None
                            and self.mid["RORUSD"] is not None) else 0)
            print("M2M", update.pnl_msg.realized_pnl, update.pnl_msg.m2m_pnl,
                  my_m2m)
        # Update position upon fill messages of your trades
        elif kind == "fill_msg":
            if update.fill_msg.order_side == pb.FillMessageSide.BUY:
                self.cash -= update.fill_msg.filled_qty * float(
                    update.fill_msg.price)
                self.pos[update.fill_msg.asset] += update.fill_msg.filled_qty
            else:
                self.cash += update.fill_msg.filled_qty * float(
                    update.fill_msg.price)
                self.pos[update.fill_msg.asset] -= update.fill_msg.filled_qty
                # NOTE(review): the quote refresh and spot_market() call run
                # only for non-BUY fills - confirm this asymmetry is intended.
                if checked > 300:
                    await self.place_bids(FUTURES)
                    await self.place_asks(FUTURES)
                    checked = 0
                checked += 1
                await self.spot_market()
        # Identify mid price through order book updates
        elif kind == "market_snapshot_msg":
            for asset in (FUTURES + ["RORUSD"]):
                book = update.market_snapshot_msg.books[asset]

                mid: "Optional[float]"
                if len(book.asks) > 0:
                    if len(book.bids) > 0:
                        mid = (float(book.asks[0].px) +
                               float(book.bids[0].px)) / 2
                    else:
                        mid = float(book.asks[0].px)
                elif len(book.bids) > 0:
                    mid = float(book.bids[0].px)
                else:
                    # Empty book: no price information available.
                    mid = None

                self.mid[asset] = mid
        elif kind == "order_cancelled_msg":
            print('order cancelled')
        elif kind == "request_failed_msg":
            print('request failed')
        # Competition event messages
        elif kind == "generic_msg":
            message = update.generic_msg.message
            data = message.split(',')
            if (0 < IsInt(data[0])):
                self.today = int(data[0])
                self.interestRates['ROR'] = daily_rate(float(data[1]))
                self.interestRates['HAP'] = daily_rate(float(data[2]))
                self.interestRates['USD'] = daily_rate(float(data[3]))
                print(message)
                await self.evaluate_fairs()
                await self.place_bids(FUTURES)
                await self.place_asks(FUTURES)
                await self.spot_market()
            elif "New Federal Funds Target" in message:
                # The old code tested list membership on `data`, then called
                # .split on that list and concatenated str + int, so this
                # branch could only raise. Parse the message text instead.
                # Assumes a space-separated message with the currency first
                # and an integer target last - TODO confirm against the
                # exchange's announcement format.
                words = message.split(" ")
                currency = words[0]
                try:
                    target = int(words[-1])
                except ValueError:
                    # Unexpected format: fall back to echoing the message.
                    print(message)
                else:
                    self.federalRate[currency] = (target, self.today)
                    print("!!!!! New Federal Funds Target for " + currency +
                          ":" + str(target))
            else:
                print(message)
Esempio n. 14
0
    async def handle_exchange_update(self, update: pb.FeedMessage):
        """
        React to one exchange feed message, dispatching on the active field
        of its ``msg`` oneof.

        Possible exchange updates include 'market_snapshot_msg', 'fill_msg',
        'liquidation_msg', 'generic_msg', 'trade_msg', 'pnl_msg', etc. The
        branches handled here are:

        * ``pnl_msg`` - compute a mark-to-market value from tracked cash,
          mid prices and positions, and print it next to the exchange's PnL.
        * ``fill_msg`` - update cash and position, re-quote the filled side
          for any asset other than RORUSD and record the new order ids, then
          run ``spot_market``.
        * ``market_snapshot_msg`` - refresh the mid price of every tracked
          asset from the top of its order book.
        * ``generic_msg`` - competition event text: either a comma-separated
          daily update (day number plus three interest rates), after which
          fairs are re-evaluated and all futures re-quoted, or a
          "New Federal Funds Target" announcement, which is printed.

        Any other kind is ignored.
        """
        kind, _ = betterproto.which_one_of(update, "msg")

        if kind == "pnl_msg":
            my_m2m = self.cash
            for asset in (
                [i + j for i in ["6R", "6H"]
                 for j in ["H", "M", "U", "Z"]] + ["RORUSD"]):
                # Assets without a known mid contribute nothing.
                my_m2m += self.mid[asset] * self.pos[asset] if self.mid[
                    asset] is not None else 0
            for asset in (["RH" + j for j in ["H", "M", "U", "Z"]]):
                # These positions are additionally scaled by the RORUSD mid.
                my_m2m += (self.mid[asset] * self.pos[asset] *
                           self.mid["RORUSD"] if
                           (self.mid[asset] is not None
                            and self.mid["RORUSD"] is not None) else 0)
            print("M2M", update.pnl_msg.realized_pnl, update.pnl_msg.m2m_pnl,
                  my_m2m)
        # Update position upon fill messages of your trades
        elif kind == "fill_msg":
            if update.fill_msg.order_side == pb.FillMessageSide.BUY:
                self.cash -= update.fill_msg.filled_qty * float(
                    update.fill_msg.price)
                self.pos[update.fill_msg.asset] += update.fill_msg.filled_qty
                if update.fill_msg.asset != 'RORUSD':
                    # place_bids yields awaitables; their responses carry the
                    # ids of the freshly placed orders.
                    reqs = await self.place_bids(update.fill_msg.asset)
                    resps = await asyncio.gather(*reqs)
                    for i, resp in enumerate(resps):
                        self.bidorderid[
                            update.fill_msg.asset][i] = resp.order_id
            else:
                self.cash += update.fill_msg.filled_qty * float(
                    update.fill_msg.price)
                self.pos[update.fill_msg.asset] -= update.fill_msg.filled_qty
                if update.fill_msg.asset != 'RORUSD':
                    reqs = await self.place_asks(update.fill_msg.asset)
                    resps = await asyncio.gather(*reqs)
                    for i, resp in enumerate(resps):
                        self.askorderid[
                            update.fill_msg.asset][i] = resp.order_id
            await self.spot_market()
        # Identify mid price through order book updates
        elif kind == "market_snapshot_msg":
            for asset in (FUTURES + ["RORUSD"]):
                book = update.market_snapshot_msg.books[asset]

                mid: "Optional[float]"
                if len(book.asks) > 0:
                    if len(book.bids) > 0:
                        mid = (float(book.asks[0].px) +
                               float(book.bids[0].px)) / 2
                    else:
                        mid = float(book.asks[0].px)
                elif len(book.bids) > 0:
                    mid = float(book.bids[0].px)
                else:
                    # Empty book: no price information available.
                    mid = None

                self.mid[asset] = mid
        # Competition event messages
        elif kind == "generic_msg":
            data = update.generic_msg.message.split(',')
            if (0 < IsInt(data[0])):
                # NOTE(review): this binds a function-local name that is never
                # read; if a module-level TODAY was intended, a `global`
                # declaration is missing - confirm before changing.
                TODAY = data[0]
                self.interestRates['ROR'] = daily_rate(float(data[1]))
                self.interestRates['HAP'] = daily_rate(float(data[2]))
                self.interestRates['USD'] = daily_rate(float(data[3]))
                print(update.generic_msg.message)
                await self.evaluate_fairs()
                bid_reqs = []
                ask_reqs = []
                for asset in FUTURES:
                    bid_reqs += await self.place_bids(asset)
                    ask_reqs += await self.place_asks(asset)
                bid_resps = await asyncio.gather(*bid_reqs)
                ask_resps = await asyncio.gather(*ask_reqs)
                # Responses are mapped back as two per asset, in FUTURES
                # order - assumes place_bids/place_asks each return exactly
                # two requests; TODO confirm.
                for idx, resp in enumerate(bid_resps):
                    asset = FUTURES[idx // 2]
                    self.bidorderid[asset][idx % 2] = resp.order_id
                for idx, resp in enumerate(ask_resps):
                    asset = FUTURES[idx // 2]
                    self.askorderid[asset][idx % 2] = resp.order_id
                await self.spot_market()
            elif ("New Federal Funds Target" in data[0]):
                # `data` is a list, so the text to tokenise is data[0]
                # (calling .split on the list raised AttributeError).
                # Assumes the currency is the first word and the target the
                # last one - TODO confirm against the announcement format.
                d = data[0].split(" ")
                currency = d[0]
                target = d[-1]
                print("New Federal Funds Target for " + currency + ":" +
                      target)
            else:
                print(update.generic_msg.message)
Esempio n. 15
0
def get_one_of(base: T, sealed=True) -> T:
    """
    Return the value of the active field of ``base``'s ``sealed_value``
    group, or of its ``value`` group when ``sealed`` is falsy.
    """
    prefix = "sealed_" if sealed else ""
    group_name = prefix + "value"
    _field_name, active_value = which_one_of(base, group_name)
    return active_value
Esempio n. 16
0
 async def stream_messages(self):
     """
     Consume events from ``self.seabird.stream_events()`` and schedule
     ``handle_command`` as a task for each event whose active ``inner``
     field is ``command``. Other events are skipped.
     """
     async for event in self.seabird.stream_events():
         field_name, payload = betterproto.which_one_of(event, "inner")
         if field_name != "command":
             continue
         asyncio.create_task(self.handle_command(payload))
Esempio n. 17
0
def test_which_count():
    """Loading the ``oneof`` fixture must leave ``count`` active in group ``foo``."""
    fixture_json = get_test_case_json_data("oneof")
    parsed = Test()
    parsed.from_json(fixture_json)

    active = betterproto.which_one_of(parsed, "foo")
    assert active == ("count", 100)
def assert_round_trip_serialization_works(message: Test) -> None:
    """
    Assert that ``message`` reports the same active ``value_type`` field
    before and after being serialized to JSON and parsed back.
    """
    active_before = betterproto.which_one_of(message, "value_type")
    round_tripped = Test().from_json(message.to_json())
    active_after = betterproto.which_one_of(round_tripped, "value_type")
    assert active_before == active_after
Esempio n. 19
0
def test_which_name():
    """
    Loading the first ``oneof_name.json`` test case must leave ``pitier``
    active in group ``foo``.
    """
    test_cases = get_test_case_json_data("oneof", "oneof_name.json")
    parsed = Test()
    parsed.from_json(test_cases[0].json)

    active = betterproto.which_one_of(parsed, "foo")
    assert active == ("pitier", "Mr. T")