Exemplo n.º 1
0
    def decode(self, obj: bytes) -> Message:
        """
        Decode bytes into a 'ml_trade' message.

        The payload is a UTF-8 JSON object with a "performative" field plus
        performative-specific fields. The "query", "terms" and "data" fields
        are base64-encoded pickles; "tx_digest" is read as a plain JSON value.

        :param obj: the serialized message bytes.
        :return: the decoded MLTradeMessage.
        :raises KeyError: if a field required by the performative is missing.
        :raises ValueError: if the performative is not recognized.
        """
        json_body = json.loads(obj.decode("utf-8"))

        def _unpickle(key: str):
            # SECURITY(review): pickle.loads can execute arbitrary code embedded
            # in the payload. These bytes presumably come from other agents, so
            # this is only safe between mutually trusted peers; consider a
            # non-executable format for untrusted input.
            return pickle.loads(base64.b64decode(json_body[key]))  # nosec

        msg_type = MLTradeMessage.Performative(json_body["performative"])
        body = {"performative": msg_type}
        if msg_type == MLTradeMessage.Performative.CFT:
            body["query"] = _unpickle("query")
        elif msg_type == MLTradeMessage.Performative.TERMS:
            body["terms"] = _unpickle("terms")
        elif msg_type == MLTradeMessage.Performative.ACCEPT:
            body["terms"] = _unpickle("terms")
            body["tx_digest"] = json_body["tx_digest"]
        elif msg_type == MLTradeMessage.Performative.DATA:
            # decoding terms, then the data batch
            body["terms"] = _unpickle("terms")
            body["data"] = _unpickle("data")
        else:  # pragma: no cover
            raise ValueError("Type not recognized.")

        return MLTradeMessage(performative=msg_type, body=body)
Exemplo n.º 2
0
def test_ml_messge_consistency():
    """Check that _is_consistent fails when no Performative comparison can succeed."""
    data_model = DataModel("ml_datamodel", [Attribute("dataset_id", str, True)])
    fmnist_only = Constraint("dataset_id", ConstraintType("==", "fmnist"))
    cft_query = Query([fmnist_only], model=data_model)
    message = MLTradeMessage(
        performative=MLTradeMessage.Performative.CFT,
        query=cft_query,
    )
    # Patch Performative equality to always be False; presumably the
    # consistency check relies on == against Performative members.
    eq_patch = mock.patch.object(
        MLTradeMessage.Performative, "__eq__", return_value=False
    )
    with eq_patch:
        assert not message._is_consistent()
Exemplo n.º 3
0
    def _handle_search(self, agents: List[str]) -> None:
        """
        React to the result of an agent search.

        With no results, log and keep searching. Otherwise mark the strategy as
        no longer searching and send a call-for-terms (CFT) message, carrying
        the strategy's service query, to every agent found.

        :param agents: the agent addresses returned by the search
        :return: None
        """
        ctx = self.context
        if not agents:
            ctx.logger.info(
                "[{}]: found no agents, continue searching.".format(ctx.agent_name)
            )
            return

        address_suffixes = [address[-5:] for address in agents]
        ctx.logger.info(
            "[{}]: found agents={}, stopping search.".format(
                ctx.agent_name, address_suffixes
            )
        )
        strategy = cast(Strategy, ctx.strategy)
        strategy.is_searching = False
        service_query = strategy.get_service_query()
        for address in agents:
            ctx.logger.info(
                "[{}]: sending CFT to agent={}".format(ctx.agent_name, address[-5:])
            )
            cft_message = MLTradeMessage(
                performative=MLTradeMessage.Performative.CFT, query=service_query
            )
            ctx.outbox.put_message(
                to=address,
                sender=ctx.agent_address,
                protocol_id=MLTradeMessage.protocol_id,
                message=MLTradeSerializer().encode(cft_message),
            )
Exemplo n.º 4
0
    def _handle_accept(self, ml_trade_msg: MLTradeMessage) -> None:
        """
        Handle an Accept message.

        If the strategy confirms the accepted terms are valid, sample
        ``terms.values["batch_size"]`` items from the strategy and reply to the
        counterparty with a DATA message carrying the terms and the sample.
        Otherwise the message is ignored.

        :param ml_trade_msg: the ml trade message
        :return: None
        """
        accepted_terms = ml_trade_msg.terms
        sender_suffix = ml_trade_msg.counterparty[-5:]
        self.context.logger.info(
            "Got an Accept from {}: {}".format(sender_suffix, accepted_terms.values)
        )
        strategy = cast(Strategy, self.context.strategy)
        if not strategy.is_valid_terms(accepted_terms):
            return

        sample = strategy.sample_data(accepted_terms.values["batch_size"])
        self.context.logger.info(
            "[{}]: sending to address={} a Data message: shape={}".format(
                self.context.agent_name, sender_suffix, sample[0].shape
            )
        )
        reply = MLTradeMessage(
            performative=MLTradeMessage.Performative.DATA,
            terms=accepted_terms,
            data=sample,
        )
        self.context.outbox.put_message(
            to=ml_trade_msg.counterparty,
            sender=self.context.agent_address,
            protocol_id=MLTradeMessage.protocol_id,
            message=MLTradeSerializer().encode(reply),
        )
Exemplo n.º 5
0
    def _handle_cft(self, ml_trade_msg: MLTradeMessage) -> None:
        """
        Handle a Call For Terms (CFT).

        If the strategy's supply matches the incoming query, generate terms and
        send them back to the requester in a TERMS message. Otherwise the CFT
        is ignored.

        :param ml_trade_msg: the ml trade message
        :return: None
        """
        incoming_query = ml_trade_msg.query
        requester_suffix = ml_trade_msg.counterparty[-5:]
        self.context.logger.info(
            "Got a Call for Terms from {}: query={}".format(
                requester_suffix, incoming_query
            )
        )
        strategy = cast(Strategy, self.context.strategy)
        if not strategy.is_matching_supply(incoming_query):
            return

        offered_terms = strategy.generate_terms()
        self.context.logger.info(
            "[{}]: sending to the address={} a Terms message: {}".format(
                self.context.agent_name, requester_suffix, offered_terms.values
            )
        )
        reply = MLTradeMessage(
            performative=MLTradeMessage.Performative.TERMS,
            terms=offered_terms,
        )
        self.context.outbox.put_message(
            to=ml_trade_msg.counterparty,
            sender=self.context.agent_address,
            protocol_id=MLTradeMessage.protocol_id,
            message=MLTradeSerializer().encode(reply),
        )
Exemplo n.º 6
0
def test_ml_message_creation():
    """Test that every ml_trade performative survives an encode/decode round trip."""

    def round_trip(message):
        # Serialize and deserialize with fresh serializers, as the skills do.
        return MLTradeSerializer().decode(MLTradeSerializer().encode(message))

    dm = DataModel("ml_datamodel", [Attribute("dataset_id", str, True)])
    query = Query([Constraint("dataset_id", ConstraintType("==", "fmnist"))],
                  model=dm)
    msg = MLTradeMessage(performative=MLTradeMessage.Performative.CFT,
                         query=query)
    assert round_trip(msg) == msg

    terms = Description({
        "batch_size": 5,
        "price": 10,
        "seller_tx_fee": 5,
        "buyer_tx_fee": 2,
        "currency_id": "FET",
        "ledger_id": "fetch",
        "address": "agent1",
    })

    msg = MLTradeMessage(performative=MLTradeMessage.Performative.TERMS,
                         terms=terms)
    assert round_trip(msg) == msg

    tx_digest = "This is the transaction digest."
    msg = MLTradeMessage(
        performative=MLTradeMessage.Performative.ACCEPT,
        terms=terms,
        tx_digest=tx_digest,
    )
    assert round_trip(msg) == msg

    data = np.zeros((5, 2)), np.zeros((5, 2))
    msg = MLTradeMessage(performative=MLTradeMessage.Performative.DATA,
                         terms=terms,
                         data=data)
    # Decode outside pytest.raises so that a decoding failure is reported
    # as a real error instead of being absorbed by the expected exception.
    recovered_msg = round_trip(msg)
    # Whole-message equality presumably compares the numpy arrays in `data`
    # element-wise, whose truth value is ambiguous, so `==` raises ValueError.
    # Compare the fields explicitly instead.
    with pytest.raises(ValueError):
        recovered_msg == msg  # noqa: B015
    assert recovered_msg.terms == msg.terms
    assert len(recovered_msg.data) == len(msg.data)
    for recovered_array, original_array in zip(recovered_msg.data, msg.data):
        assert np.array_equal(recovered_array, original_array)
Exemplo n.º 7
0
    def handle(self, message: Message) -> None:
        """
        React to the decision maker's reply to a proposed transaction.

        On SUCCESSFUL_SETTLEMENT, send an ACCEPT message, carrying the
        transaction digest and the terms stored under ``info["terms"]``, to the
        transaction counterparty. Any other performative is only logged.

        :param message: the message (treated as a TransactionMessage)
        :return: None
        """
        tx_reply = cast(TransactionMessage, message)
        settled = (
            tx_reply.performative
            == TransactionMessage.Performative.SUCCESSFUL_SETTLEMENT
        )
        if not settled:
            self.context.logger.info(
                "[{}]: transaction was not successful.".format(self.context.agent_name)
            )
            return

        self.context.logger.info(
            "[{}]: transaction was successful.".format(self.context.agent_name)
        )
        accepted_terms = cast(Description, tx_reply.info.get("terms"))
        accept_msg = MLTradeMessage(
            performative=MLTradeMessage.Performative.ACCEPT,
            tx_digest=tx_reply.tx_digest,
            terms=accepted_terms,
        )
        counterparty = tx_reply.tx_counterparty_addr
        self.context.outbox.put_message(
            to=counterparty,
            sender=self.context.agent_address,
            protocol_id=MLTradeMessage.protocol_id,
            message=MLTradeSerializer().encode(accept_msg),
        )
        self.context.logger.info(
            "[{}]: Sending accept to counterparty={} with transaction digest={} and terms={}.".format(
                self.context.agent_name,
                counterparty[-5:],
                tx_reply.tx_digest,
                accepted_terms.values,
            )
        )
Exemplo n.º 8
0
    def _handle_terms(self, ml_trade_msg: MLTradeMessage) -> None:
        """
        Handle the terms of the request.

        The terms are rejected (logged and dropped) unless the strategy finds
        them both acceptable and affordable. Accepted terms are then either:
        - proposed to the decision maker for settlement on the ledger, when
          ``strategy.is_ledger_tx`` is set; or
        - accepted immediately with ``DUMMY_DIGEST`` as the transaction
          digest, with no settlement.

        :param ml_trade_msg: the ml trade message carrying the terms
        :return: None
        """
        terms = ml_trade_msg.terms
        self.context.logger.info(
            "Received terms message from {}: terms={}".format(
                ml_trade_msg.counterparty[-5:], terms.values))

        strategy = cast(Strategy, self.context.strategy)
        acceptable = strategy.is_acceptable_terms(terms)
        affordable = strategy.is_affordable_terms(terms)
        # Reject unless the terms are BOTH acceptable and affordable.
        if not (acceptable and affordable):
            self.context.logger.info(
                "[{}]: rejecting, terms are not acceptable and/or affordable".
                format(self.context.agent_name))
            return

        if strategy.is_ledger_tx:
            # propose the transaction to the decision maker for settlement on the ledger
            tx_msg = TransactionMessage(
                performative=TransactionMessage.Performative.
                PROPOSE_FOR_SETTLEMENT,
                skill_callback_ids=[PublicId("fetchai", "ml_train", "0.1.0")],
                tx_id=strategy.get_next_transition_id(),
                tx_sender_addr=self.context.agent_addresses[
                    terms.values["ledger_id"]],
                tx_counterparty_addr=terms.values["address"],
                tx_amount_by_currency_id={
                    terms.values["currency_id"]: -terms.values["price"]
                },
                tx_sender_fee=terms.values["buyer_tx_fee"],
                tx_counterparty_fee=terms.values["seller_tx_fee"],
                tx_quantities_by_good_id={},
                ledger_id=terms.values["ledger_id"],
                # The terms travel with the transaction so they can be sent to
                # the counterparty later: the seller is stateless and must know
                # which terms have been accepted.
                info={
                    "terms": terms,
                    "counterparty_addr": ml_trade_msg.counterparty
                },
            )
            self.context.decision_maker_message_queue.put_nowait(tx_msg)
            self.context.logger.info(
                "[{}]: proposing the transaction to the decision maker. Waiting for confirmation ..."
                .format(self.context.agent_name))
        else:
            # accept directly with a dummy transaction digest, no settlement
            ml_accept = MLTradeMessage(
                performative=MLTradeMessage.Performative.ACCEPT,
                tx_digest=DUMMY_DIGEST,
                terms=terms,
            )
            self.context.outbox.put_message(
                to=ml_trade_msg.counterparty,
                sender=self.context.agent_address,
                protocol_id=MLTradeMessage.protocol_id,
                message=MLTradeSerializer().encode(ml_accept),
            )
            self.context.logger.info(
                "[{}]: sending dummy transaction digest ...".format(
                    self.context.agent_name))
Exemplo n.º 9
0
def test_ml_wrong_message_creation():
    """Check that a CFT message built with an empty-string query raises AssertionError."""
    invalid_query = ""
    with pytest.raises(AssertionError):
        MLTradeMessage(
            performative=MLTradeMessage.Performative.CFT,
            query=invalid_query,
        )