class AuditBatchHandler(BatchRequestHandler):

    def __init__(self, database_manager: DatabaseManager):
        super().__init__(database_manager, AUDIT_LEDGER_ID)
        # TODO: move it to BatchRequestHandler
        self.tracker = LedgerUncommittedTracker(None, self.ledger.uncommitted_root_hash, self.ledger.size)

    def post_batch_applied(self, three_pc_batch: ThreePcBatch, prev_handler_result=None):
        txn = self._add_to_ledger(three_pc_batch)
        self.tracker.apply_batch(None, self.ledger.uncommitted_root_hash, self.ledger.uncommitted_size)
        logger.debug("applied audit txn {}; uncommitted root hash is {}; uncommitted size is {}".
                     format(str(txn), self.ledger.uncommitted_root_hash, self.ledger.uncommitted_size))

    def post_batch_rejected(self, ledger_id, prev_handler_result=None):
        _, _, txn_count = self.tracker.reject_batch()
        self.ledger.discardTxns(txn_count)
        logger.debug("rejected {} audit txns; uncommitted root hash is {}; uncommitted size is {}".
                     format(txn_count, self.ledger.uncommitted_root_hash, self.ledger.uncommitted_size))

    def commit_batch(self, three_pc_batch, prev_handler_result=None):
        _, _, txns_count = self.tracker.commit_batch()
        _, committedTxns = self.ledger.commitTxns(txns_count)
        logger.debug("committed {} audit txns; uncommitted root hash is {}; uncommitted size is {}".
                     format(txns_count, self.ledger.uncommitted_root_hash, self.ledger.uncommitted_size))
        return committedTxns

    def on_catchup_finished(self):
        self.tracker.set_last_committed(state_root=None,
                                        txn_root=self.ledger.uncommitted_root_hash,
                                        ledger_size=self.ledger.size)

    @staticmethod
    def transform_txn_for_ledger(txn):
        '''
        Makes sure that keys are integers after possible deserialization from JSON
        :param txn: txn to be transformed
        :return: transformed txn
        '''
        txn_data = get_payload_data(txn)
        txn_data[AUDIT_TXN_LEDGERS_SIZE] = {int(k): v for k, v in txn_data[AUDIT_TXN_LEDGERS_SIZE].items()}
        txn_data[AUDIT_TXN_LEDGER_ROOT] = {int(k): v for k, v in txn_data[AUDIT_TXN_LEDGER_ROOT].items()}
        txn_data[AUDIT_TXN_STATE_ROOT] = {int(k): v for k, v in txn_data[AUDIT_TXN_STATE_ROOT].items()}
        return txn
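    # e.g. (illustrative ledger ids and sizes): a deserialized payload holding
    #   {AUDIT_TXN_LEDGERS_SIZE: {"0": 5, "1": 12}}
    # becomes {AUDIT_TXN_LEDGERS_SIZE: {0: 5, 1: 12}} after the int-cast above.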

    def _add_to_ledger(self, three_pc_batch: ThreePcBatch):
        # if PRE-PREPARE doesn't have audit txn (probably old code) - do nothing
        # TODO: remove this check after all nodes support audit ledger
        if not three_pc_batch.has_audit_txn:
            logger.info("Has 3PC batch without audit ledger: {}".format(str(three_pc_batch)))
            return

        # 1. prepare AUDIT txn
        txn_data = self._create_audit_txn_data(three_pc_batch, self.ledger.get_last_txn())
        txn = init_empty_txn(txn_type=PlenumTransactions.AUDIT.value)
        txn = set_payload_data(txn, txn_data)

        # 2. Append txn metadata
        self.ledger.append_txns_metadata([txn], three_pc_batch.pp_time)

        # 3. Add to the Ledger
        self.ledger.appendTxns([txn])
        return txn

    def _create_audit_txn_data(self, three_pc_batch, last_audit_txn):
        # 1. general format and (view_no, pp_seq_no)
        txn = {
            TXN_VERSION: "1",
            AUDIT_TXN_VIEW_NO: three_pc_batch.view_no,
            AUDIT_TXN_PP_SEQ_NO: three_pc_batch.pp_seq_no,
            AUDIT_TXN_LEDGERS_SIZE: {},
            AUDIT_TXN_LEDGER_ROOT: {},
            AUDIT_TXN_STATE_ROOT: {},
            AUDIT_TXN_PRIMARIES: None
        }

        for lid, ledger in self.database_manager.ledgers.items():
            if lid == AUDIT_LEDGER_ID:
                continue
            # 2. ledger size
            txn[AUDIT_TXN_LEDGERS_SIZE][lid] = ledger.uncommitted_size

            # 3. ledger root (either the root_hash or a delta to the audit txn where it last changed)
            # TODO: support setting for multiple ledgers
            self.__fill_ledger_root_hash(txn, three_pc_batch, lid, last_audit_txn)

        # 4. state root hash
        txn[AUDIT_TXN_STATE_ROOT][three_pc_batch.ledger_id] = Ledger.hashToStr(three_pc_batch.state_root)

        # 5. set primaries field
        self.__fill_primaries(txn, three_pc_batch, last_audit_txn)

        return txn
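    # Shape of the resulting audit txn data (illustrative values only):
    #   {TXN_VERSION: "1", AUDIT_TXN_VIEW_NO: 0, AUDIT_TXN_PP_SEQ_NO: 12,
    #    AUDIT_TXN_LEDGERS_SIZE: {0: 4, 1: 25, 2: 3},
    #    AUDIT_TXN_LEDGER_ROOT: {0: 2, 1: "<txn root of ledger 1>", 2: 5},
    #    AUDIT_TXN_STATE_ROOT: {1: "<state root of ledger 1>"},
    #    AUDIT_TXN_PRIMARIES: <a list of primaries or an int delta, see __fill_primaries>}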

    def __fill_ledger_root_hash(self, txn, three_pc_batch, lid, last_audit_txn):
        target_ledger_id = three_pc_batch.ledger_id
        last_audit_txn_data = get_payload_data(last_audit_txn) if last_audit_txn is not None else None

        # 1. ledger is changed in this batch => root_hash
        if lid == target_ledger_id:
            txn[AUDIT_TXN_LEDGER_ROOT][lid] = Ledger.hashToStr(three_pc_batch.txn_root)

        # 2. This ledger is never audited, so do not add the key
        elif last_audit_txn_data is None or lid not in last_audit_txn_data[AUDIT_TXN_LEDGER_ROOT]:
            return

        # 3. ledger is not changed in last batch => delta = delta + 1
        elif isinstance(last_audit_txn_data[AUDIT_TXN_LEDGER_ROOT][lid], int):
            txn[AUDIT_TXN_LEDGER_ROOT][lid] = last_audit_txn_data[AUDIT_TXN_LEDGER_ROOT][lid] + 1

        # 4. ledger is changed in last batch but not changed now => delta = 1
        elif last_audit_txn_data:
            txn[AUDIT_TXN_LEDGER_ROOT][lid] = 1
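        # Worked example (illustrative seq_nos): if ledger 0's root hash was written in
        # audit txn #7 and the ledger has not changed since, audit txns #8, #9 and #10
        # store 1, 2 and 3 respectively, so the actual root is found at seq_no 10 - 3 = 7.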

    def __fill_primaries(self, txn, three_pc_batch, last_audit_txn):
        last_audit_txn_data = get_payload_data(last_audit_txn) if last_audit_txn is not None else None
        last_txn_value = last_audit_txn_data[AUDIT_TXN_PRIMARIES] if last_audit_txn_data else None
        current_primaries = three_pc_batch.primaries

        # 1. First audit txn
        if last_audit_txn_data is None:
            txn[AUDIT_TXN_PRIMARIES] = current_primaries

        # 2. Previous primaries field contains a list of primaries.
        # If the primaries did not change, we store the seq_no delta
        # between the current txn and the txn where primaries were last persisted,
        # i.e. the seq_no of the last actual primaries can be found as:
        # last_audit_txn_seq_no - last_audit_txn[AUDIT_TXN_PRIMARIES]
        # (see the resolution sketch after this class)
        elif isinstance(last_txn_value, Iterable):
            if last_txn_value == current_primaries:
                txn[AUDIT_TXN_PRIMARIES] = 1
            else:
                txn[AUDIT_TXN_PRIMARIES] = current_primaries

        # 3. Previous primaries field is delta
        elif isinstance(last_txn_value, int) and last_txn_value < self.ledger.uncommitted_size:
            last_primaries_seq_no = get_seq_no(last_audit_txn) - last_txn_value
            last_primaries = get_payload_data(
                self.ledger.get_by_seq_no_uncommitted(last_primaries_seq_no))[AUDIT_TXN_PRIMARIES]
            if isinstance(last_primaries, Iterable):
                if last_primaries == current_primaries:
                    txn[AUDIT_TXN_PRIMARIES] = last_txn_value + 1
                else:
                    txn[AUDIT_TXN_PRIMARIES] = current_primaries
            else:
                raise LogicError('The value in the primaries field must be the '
                                 'seq_no of a txn with primaries')

        # 4. This must not happen
        else:
            raise LogicError('Incorrect primaries field in audit ledger (seq_no: {}. value: {})'.format(
                get_seq_no(last_audit_txn), last_txn_value))
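
# The same delta idea is used for AUDIT_TXN_PRIMARIES: an int value N means the actual
# primaries list is stored N audit txns back. A minimal resolution sketch, reusing only
# helpers already referenced by the handler above; the function itself is illustrative
# and not part of the handler:
def resolve_primaries(audit_ledger, audit_txn):
    value = get_payload_data(audit_txn)[AUDIT_TXN_PRIMARIES]
    if isinstance(value, int):
        # follow the delta back to the audit txn that holds the actual list of primaries
        primaries_seq_no = get_seq_no(audit_txn) - value
        value = get_payload_data(
            audit_ledger.get_by_seq_no_uncommitted(primaries_seq_no))[AUDIT_TXN_PRIMARIES]
    return value
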
Example #2
class AuditBatchHandler(BatchRequestHandler):
    def __init__(self, database_manager: DatabaseManager):
        super().__init__(database_manager, AUDIT_LEDGER_ID)
        # TODO: move it to BatchRequestHandler
        self.tracker = LedgerUncommittedTracker(None, self.ledger.size)

    def post_batch_applied(self,
                           three_pc_batch: ThreePcBatch,
                           prev_handler_result=None):
        self._add_to_ledger(three_pc_batch)
        self.tracker.apply_batch(None, self.ledger.uncommitted_size)

    def post_batch_rejected(self, ledger_id, prev_handler_result=None):
        _, txn_count = self.tracker.reject_batch()
        self.ledger.discardTxns(txn_count)

    def commit_batch(self,
                     ledger_id,
                     txn_count,
                     state_root,
                     txn_root,
                     pp_time,
                     prev_result=None):
        _, txns_count = self.tracker.commit_batch()
        _, committedTxns = self.ledger.commitTxns(txns_count)
        return committedTxns

    def _add_to_ledger(self, three_pc_batch: ThreePcBatch):
        # if PRE-PREPARE doesn't have audit txn (probably old code) - do nothing
        # TODO: remove this check after all nodes support audit ledger
        if not three_pc_batch.has_audit_txn:
            return

        # 1. prepare AUDIT txn
        txn_data = self._create_audit_txn_data(three_pc_batch,
                                               self.ledger.get_last_txn())
        txn = init_empty_txn(txn_type=PlenumTransactions.AUDIT.value)
        txn = set_payload_data(txn, txn_data)

        # 2. Append txn metadata
        self.ledger.append_txns_metadata([txn], three_pc_batch.pp_time)

        # 3. Add to the Ledger
        self.ledger.appendTxns([txn])

    def _create_audit_txn_data(self, three_pc_batch, last_audit_txn):
        # 1. general format and (view_no, pp_seq_no)
        txn = {
            TXN_VERSION: "1",
            AUDIT_TXN_VIEW_NO: three_pc_batch.view_no,
            AUDIT_TXN_PP_SEQ_NO: three_pc_batch.pp_seq_no,
            AUDIT_TXN_LEDGERS_SIZE: {},
            AUDIT_TXN_LEDGER_ROOT: {},
            AUDIT_TXN_STATE_ROOT: {}
        }

        for lid, ledger in self.database_manager.ledgers.items():
            if lid == AUDIT_LEDGER_ID:
                continue
            # 2. ledger size
            txn[AUDIT_TXN_LEDGERS_SIZE][str(lid)] = ledger.uncommitted_size

            # 3. ledger root (either root_hash or seq_no to last changed)
            # TODO: support setting for multiple ledgers
            self.__fill_ledger_root_hash(txn, three_pc_batch, lid,
                                         last_audit_txn)

        # 4. state root hash
        txn[AUDIT_TXN_STATE_ROOT][str(
            three_pc_batch.ledger_id)] = Ledger.hashToStr(
                three_pc_batch.state_root)

        return txn

    def __fill_ledger_root_hash(self, txn, three_pc_batch, lid,
                                last_audit_txn):
        target_ledger_id = three_pc_batch.ledger_id
        last_audit_txn_data = get_payload_data(
            last_audit_txn) if last_audit_txn is not None else None

        # 1. ledger is changed in this batch => root_hash
        if lid == target_ledger_id:
            txn[AUDIT_TXN_LEDGER_ROOT][str(lid)] = Ledger.hashToStr(
                three_pc_batch.txn_root)

        # 2. This ledger is never audited, so do not add the key
        elif last_audit_txn_data is None or str(
                lid) not in last_audit_txn_data[AUDIT_TXN_LEDGER_ROOT]:
            return

        # 3. ledger is not changed in last batch => the same audit seq no
        elif isinstance(last_audit_txn_data[AUDIT_TXN_LEDGER_ROOT][str(lid)],
                        int):
            txn[AUDIT_TXN_LEDGER_ROOT][str(
                lid)] = last_audit_txn_data[AUDIT_TXN_LEDGER_ROOT][str(lid)]

        # 4. ledger is changed in last batch but not changed now => seq_no of last audit txn
        elif last_audit_txn_data:
            txn[AUDIT_TXN_LEDGER_ROOT][str(lid)] = get_seq_no(last_audit_txn)
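    # In this older variant the stored int is an absolute audit seq_no rather than a delta
    # (illustrative seq_nos): if ledger 0's root hash was written in audit txn #7 and the
    # ledger has not changed since, audit txns #8, #9 and #10 all store 7.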
Example #3
class StaticFeesReqHandler(FeeReqHandler):
    write_types = FeeReqHandler.write_types.union({SET_FEES, FEE_TXN})
    query_types = FeeReqHandler.query_types.union({GET_FEES, GET_FEE})
    set_fees_validator_cls = SetFeesMsg
    get_fee_validator_cls = GetFeeMsg
    state_serializer = JsonSerializer()

    def __init__(self,
                 ledger,
                 state,
                 token_ledger,
                 token_state,
                 utxo_cache,
                 domain_state,
                 bls_store,
                 node,
                 write_req_validator,
                 ts_store=None):

        super().__init__(ledger,
                         state,
                         domain_state,
                         idrCache=node.idrCache,
                         upgrader=node.upgrader,
                         poolManager=node.poolManager,
                         poolCfg=node.poolCfg,
                         write_req_validator=node.write_req_validator,
                         bls_store=bls_store,
                         ts_store=ts_store)

        self.token_ledger = token_ledger
        self.token_state = token_state
        self.utxo_cache = utxo_cache
        self.domain_state = domain_state
        self.bls_store = bls_store
        self.write_req_validator = write_req_validator

        self._add_query_handler(GET_FEES, self.get_fees)
        self._add_query_handler(GET_FEE, self.get_fee)

        # Tracks count of transactions paying sovtokenfees while a batch is being
        # processed. Reset to zero once a batch is created (not committed)
        self.fee_txns_in_current_batch = 0
        # Tracks amount of deducted sovtokenfees for a transaction
        self.deducted_fees = {}
        self.token_tracker = LedgerUncommittedTracker(
            token_state.committedHeadHash, token_ledger.uncommitted_root_hash,
            token_ledger.size)

    @property
    def fees(self):
        return self._get_fees(is_committed=False)

    @staticmethod
    def get_ref_for_txn_fees(ledger_id, seq_no):
        return '{}:{}'.format(ledger_id, seq_no)
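        # e.g. (illustrative values): get_ref_for_txn_fees(1, 301) -> '1:301',
        # i.e. the ledger id and seq_no of the txn the fees were paid for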

    def get_txn_fees(self, request) -> int:
        return self.fees.get(request.operation[TXN_TYPE], 0)

    # TODO: Fix this to match signature of `FeeReqHandler` and extract
    # the params from `kwargs`
    def deduct_fees(self, request, cons_time, ledger_id, seq_no, txn):
        txn_type = request.operation[TXN_TYPE]
        fees_key = "{}#{}".format(txn_type, seq_no)
        if txn_type != XFER_PUBLIC and FeesAuthorizer.has_fees(request):
            inputs, outputs, signatures = getattr(request, f.FEES.nm)
            # This is correct since FEES are changed via the config ledger, whose
            # transactions carry no fees
            fees = FeesAuthorizer.calculate_fees_from_req(
                self.utxo_cache, request)
            sigs = {i[ADDRESS]: s for i, s in zip(inputs, signatures)}
            txn = {
                OPERATION: {
                    TXN_TYPE: FEE_TXN,
                    INPUTS: inputs,
                    OUTPUTS: outputs,
                    REF: self.get_ref_for_txn_fees(ledger_id, seq_no),
                    FEES: fees,
                },
                f.SIGS.nm: sigs,
                f.REQ_ID.nm: get_req_id(txn),
                f.PROTOCOL_VERSION.nm: 2,
            }
            txn = reqToTxn(txn)
            self.token_ledger.append_txns_metadata([txn], txn_time=cons_time)
            _, txns = self.token_ledger.appendTxns(
                [TokenReqHandler.transform_txn_for_ledger(txn)])
            self.updateState(txns)
            self.fee_txns_in_current_batch += 1
            self.deducted_fees[fees_key] = fees
            return txn

    def doStaticValidation(self, request: Request):
        operation = request.operation
        if operation[TXN_TYPE] in (SET_FEES, GET_FEES, GET_FEE):
            try:
                if operation[TXN_TYPE] == SET_FEES:
                    self.set_fees_validator_cls(**request.operation)
                elif operation[TXN_TYPE] == GET_FEE:
                    self.get_fee_validator_cls(**request.operation)
            except TypeError as exc:
                raise InvalidClientRequest(request.identifier, request.reqId,
                                           exc)
        else:
            super().doStaticValidation(request)

    def _fees_specific_validation(self, request: Request):
        operation = request.operation
        current_fees = self._get_fees()
        constraint = self.get_auth_constraint(operation)
        wrong_aliases = []
        self._validate_metadata(self.fees, constraint, wrong_aliases)
        if len(wrong_aliases) > 0:
            raise InvalidClientMessageException(
                request.identifier, request.reqId,
                "Fees alias(es) {} does not exist in current fees {}. "
                "Please add the alias(es) via SET_FEES transaction first.".
                format(", ".join(wrong_aliases), current_fees))

    def _validate_metadata(self, current_fees, constraint: AuthConstraint,
                           wrong_aliases):
        if constraint.constraint_id != ConstraintsEnum.ROLE_CONSTRAINT_ID:
            for constr in constraint.auth_constraints:
                self._validate_metadata(current_fees, constr, wrong_aliases)
        else:
            meta_alias = constraint.metadata.get(FEES_FIELD_NAME, None)
            if meta_alias and meta_alias not in current_fees:
                wrong_aliases.append(meta_alias)

    def validate(self, request: Request):
        operation = request.operation
        if operation[TXN_TYPE] == SET_FEES:
            return self.write_req_validator.validate(request, [
                AuthActionEdit(
                    txn_type=SET_FEES, field="*", old_value="*", new_value="*")
            ])
        else:
            super().validate(request)
        if operation[TXN_TYPE] == AUTH_RULE:
            # metadata validation
            self._fees_specific_validation(request)

    def updateState(self, txns, isCommitted=False):
        for txn in txns:
            self._update_state_with_single_txn(txn, is_committed=isCommitted)
        super().updateState(txns, isCommitted=isCommitted)

    def get_fees(self, request: Request):
        fees, proof = self._get_fees(is_committed=True, with_proof=True)
        result = {
            f.IDENTIFIER.nm: request.identifier,
            f.REQ_ID.nm: request.reqId,
            FEES: fees
        }
        if proof:
            result[STATE_PROOF] = proof
        result.update(request.operation)
        return result

    def get_fee(self, request: Request):
        alias = request.operation.get(ALIAS)
        fee, proof = self._get_fee(alias, is_committed=True, with_proof=True)
        result = {
            f.IDENTIFIER.nm: request.identifier,
            f.REQ_ID.nm: request.reqId,
            FEE: fee
        }
        if proof:
            result[STATE_PROOF] = proof
        result.update(request.operation)
        return result

    def post_batch_created(self, ledger_id, state_root):
        # this means that all the tracker things were done in the onBatchCreated phase for TokenReqHandler
        self.token_tracker.apply_batch(self.token_state.headHash,
                                       self.token_ledger.uncommitted_root_hash,
                                       self.token_ledger.uncommitted_size)
        if ledger_id == TOKEN_LEDGER_ID:
            return
        if self.fee_txns_in_current_batch > 0:
            state_root = self.token_state.headHash
            TokenReqHandler.on_batch_created(self.utxo_cache, state_root)
            # TODO: investigate the effect of removing this reset to 0
            self.fee_txns_in_current_batch = 0

    def post_batch_rejected(self, ledger_id):
        uncommitted_hash, uncommitted_txn_root, txn_count = self.token_tracker.reject_batch(
        )
        if ledger_id == TOKEN_LEDGER_ID:
            # TODO: Improve this logic for the case of an XFER txn with fees.
            # All other txns with fees are applied in 2 steps: "apply txn" and "apply fees",
            # but an XFER txn with fees is applied in a single "apply fees with transfer" step.
            return
        if txn_count == 0 or self.token_ledger.uncommitted_root_hash == uncommitted_txn_root or \
                self.token_state.headHash == uncommitted_hash:
            return 0
        self.token_state.revertToHead(uncommitted_hash)
        self.token_ledger.discardTxns(txn_count)
        count_reverted = TokenReqHandler.on_batch_rejected(self.utxo_cache)
        logger.info("Reverted {} txns with fees".format(count_reverted))

    def post_batch_committed(self, ledger_id, pp_time, committed_txns,
                             state_root, txn_root):
        token_state_root, token_txn_root, _ = self.token_tracker.commit_batch()
        if ledger_id == TOKEN_LEDGER_ID:
            return
        committed_seq_nos_with_fees = [
            get_seq_no(t) for t in committed_txns
            if get_type(t) != XFER_PUBLIC and "{}#{}".format(
                get_type(t), get_seq_no(t)) in self.deducted_fees
        ]
        if len(committed_seq_nos_with_fees) > 0:
            r = TokenReqHandler.__commit__(
                self.utxo_cache, self.token_ledger, self.token_state,
                len(committed_seq_nos_with_fees), token_state_root,
                txn_root_serializer.serialize(token_txn_root), pp_time)
            i = 0
            for txn in committed_txns:
                if get_seq_no(txn) in committed_seq_nos_with_fees:
                    txn[FEES] = r[i]
                    i += 1
            self.fee_txns_in_current_batch = 0

    def _get_fees(self, is_committed=False, with_proof=False):
        result = self._get_fee_from_state(is_committed=is_committed,
                                          with_proof=with_proof)
        if with_proof:
            fees, proof = result
            return (fees, proof) if fees is not None else ({}, proof)
        else:
            return result if result is not None else {}

    def _get_fee(self, alias, is_committed=False, with_proof=False):
        return self._get_fee_from_state(fees_alias=alias,
                                        is_committed=is_committed,
                                        with_proof=with_proof)

    def _get_fee_from_state(self,
                            fees_alias=None,
                            is_committed=False,
                            with_proof=False):
        fees = None
        proof = None
        try:
            fees_key = build_path_for_set_fees(alias=fees_alias)
            if with_proof:
                proof, serz = self.state.generate_state_proof(fees_key,
                                                              serialize=True,
                                                              get_value=True)
                if serz:
                    serz = rlp_decode(serz)[0]
                root_hash = self.state.committedHeadHash if is_committed else self.state.headHash
                encoded_root_hash = state_roots_serializer.serialize(
                    bytes(root_hash))
                multi_sig = self.bls_store.get(encoded_root_hash)
                if multi_sig:
                    encoded_proof = proof_nodes_serializer.serialize(proof)
                    proof = {
                        MULTI_SIGNATURE: multi_sig.as_dict(),
                        ROOT_HASH: encoded_root_hash,
                        PROOF_NODES: encoded_proof
                    }
                else:
                    proof = {}
            else:
                serz = self.state.get(fees_key, isCommitted=is_committed)
            if serz:
                fees = self.state_serializer.deserialize(serz)
        except KeyError:
            pass
        if with_proof:
            return fees, proof
        return fees

    def _set_to_state(self, key, val):
        val = self.state_serializer.serialize(val)
        key = key.encode()
        self.state.set(key, val)

    def _update_state_with_single_txn(self, txn, is_committed=False):
        typ = get_type(txn)
        if typ == SET_FEES:
            payload = get_payload_data(txn)
            fees_from_req = payload.get(FEES)
            current_fees = self._get_fees()
            current_fees.update(fees_from_req)
            for fees_alias, fees_value in fees_from_req.items():
                self._set_to_state(build_path_for_set_fees(alias=fees_alias),
                                   fees_value)
            self._set_to_state(build_path_for_set_fees(), current_fees)

        elif typ == FEE_TXN:
            for utxo in txn[TXN_PAYLOAD][TXN_PAYLOAD_DATA][INPUTS]:
                TokenReqHandler.spend_input(state=self.token_state,
                                            utxo_cache=self.utxo_cache,
                                            address=utxo[ADDRESS],
                                            seq_no=utxo[SEQNO],
                                            is_committed=is_committed)
            seq_no = get_seq_no(txn)
            for output in txn[TXN_PAYLOAD][TXN_PAYLOAD_DATA][OUTPUTS]:
                TokenReqHandler.add_new_output(state=self.token_state,
                                               utxo_cache=self.utxo_cache,
                                               output=Output(
                                                   output[ADDRESS], seq_no,
                                                   output[AMOUNT]),
                                               is_committed=is_committed)

    @staticmethod
    def _handle_incorrect_funds(sum_inputs, sum_outputs, expected_amount,
                                required_fees, request):
        if sum_inputs < expected_amount:
            error = 'Insufficient funds, sum of inputs is {} ' \
                    'but required is {} (sum of outputs: {}, ' \
                    'fees: {})'.format(sum_inputs, expected_amount, sum_outputs, required_fees)
            raise InsufficientFundsError(request.identifier, request.reqId,
                                         error)
        if sum_inputs > expected_amount:
            error = 'Extra funds, sum of inputs is {} ' \
                    'but required is: {} -- sum of outputs: {} ' \
                    '-- fees: {})'.format(sum_inputs, expected_amount, sum_outputs, required_fees)
            raise ExtraFundsError(request.identifier, request.reqId, error)
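    # e.g. (illustrative amounts): sum_inputs=10 with expected_amount=12 raises
    # InsufficientFundsError, sum_inputs=15 with expected_amount=12 raises ExtraFundsError,
    # and sum_inputs == expected_amount passes silently.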

    @staticmethod
    def transform_txn_for_ledger(txn):
        """
        Some transactions need to be updated before they can be stored in the
        ledger
        """
        return txn

    def postCatchupCompleteClbk(self):
        self.token_tracker.set_last_committed(
            self.token_state.committedHeadHash,
            self.token_ledger.uncommitted_root_hash, self.token_ledger.size)
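
# The handlers in these examples drive the same LedgerUncommittedTracker lifecycle.
# A minimal, illustrative skeleton of that pattern, inferred only from the call sites
# visible above (the three-value variant; Example #2 shows an older two-value form);
# this is a sketch, not a definitive description of the tracker API:
class _TrackerLifecycleSketch:
    def __init__(self, tracker, ledger, state):
        self.tracker, self.ledger, self.state = tracker, ledger, state

    def on_batch_applied(self):
        # push a snapshot of the uncommitted position for this batch
        self.tracker.apply_batch(self.state.headHash,
                                 self.ledger.uncommitted_root_hash,
                                 self.ledger.uncommitted_size)

    def on_batch_rejected(self):
        # pop the snapshot and roll ledger and state back by the returned txn count
        # (the real handlers above also guard against txn_count == 0 first)
        prev_state_root, _, txn_count = self.tracker.reject_batch()
        self.state.revertToHead(prev_state_root)
        self.ledger.discardTxns(txn_count)

    def on_batch_committed(self):
        # pop the oldest snapshot and commit that many txns
        _, _, txn_count = self.tracker.commit_batch()
        self.ledger.commitTxns(txn_count)

    def on_catchup_finished(self):
        # re-anchor the tracker at the last committed position
        self.tracker.set_last_committed(self.state.committedHeadHash,
                                        self.ledger.uncommitted_root_hash,
                                        self.ledger.size)
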
Example #4
class TokenReqHandler(LedgerRequestHandler):
    write_types = {MINT_PUBLIC, XFER_PUBLIC}
    query_types = {
        GET_UTXO,
    }

    MinSendersForPublicMint = 3

    def __init__(self, ledger, state: PruningState, utxo_cache: UTXOCache,
                 domain_state, bls_store):
        super().__init__(ledger, state)
        self.utxo_cache = utxo_cache
        self.domain_state = domain_state
        self.bls_store = bls_store
        self.tracker = LedgerUncommittedTracker(state.committedHeadHash,
                                                ledger.uncommitted_root_hash,
                                                ledger.size)
        self.query_handlers = {
            GET_UTXO: self.get_all_utxo,
        }

    def handle_xfer_public_txn(self, request):
        # Currently only sum of inputs is matched with sum of outputs. If anything more is
        # needed then a new function should be created.
        try:
            sum_inputs = TokenReqHandler.sum_inputs(self.utxo_cache,
                                                    request,
                                                    is_committed=False)

            sum_outputs = TokenReqHandler.sum_outputs(request)
        except Exception as ex:
            if isinstance(ex, InvalidClientMessageException):
                raise ex
            error = 'Exception {} while processing inputs/outputs'.format(ex)
            raise InvalidClientMessageException(
                request.identifier, getattr(request, 'reqId', None), error)
        else:
            return TokenReqHandler._validate_xfer_public_txn(
                request, sum_inputs, sum_outputs)

    @staticmethod
    def _validate_xfer_public_txn(request: Request, sum_inputs: int,
                                  sum_outputs: int):
        TokenReqHandler.validate_given_inputs_outputs(sum_inputs, sum_outputs,
                                                      sum_outputs, request)

    @staticmethod
    def validate_given_inputs_outputs(inputs_sum,
                                      outputs_sum,
                                      required_amount,
                                      request,
                                      error_msg_suffix: Optional[str] = None):
        """
        Checks three sum values against a simple set of rules: inputs_sum must be equal to
        required_amount, and exceptions are raised if it is not. The outputs_sum is passed not
        for checks but to be included in error messages. This is confusing but is required in
        cases where the required amount differs from the sum of outputs (as in the case of fees).

        :param inputs_sum: the sum of inputs
        :param outputs_sum: the sum of outputs
        :param required_amount: the required amount to validate (could be equal to output_sum, but may be different)
        :param request: the request that is being validated
        :param error_msg_suffix: added message to the error message
        :return: returns if valid or will raise an exception
        """

        if inputs_sum == required_amount:
            return  # Equal is valid
        elif inputs_sum > required_amount:
            error = 'Extra funds, sum of inputs is {} ' \
                    'but required amount: {} -- sum of outputs: {}'.format(inputs_sum, required_amount, outputs_sum)
            if error_msg_suffix and isinstance(error_msg_suffix, str):
                error += ' ' + error_msg_suffix
            raise ExtraFundsError(getattr(request, f.IDENTIFIER.nm, None),
                                  getattr(request, f.REQ_ID.nm, None), error)

        elif inputs_sum < required_amount:
            error = 'Insufficient funds, sum of inputs is {}' \
                    'but required amount is {}. sum of outputs: {}'.format(inputs_sum, required_amount, outputs_sum)
            if error_msg_suffix and isinstance(error_msg_suffix, str):
                error += ' ' + error_msg_suffix
            raise InsufficientFundsError(
                getattr(request, f.IDENTIFIER.nm, None),
                getattr(request, f.REQ_ID.nm, None), error)

        raise InvalidClientMessageException(
            getattr(request, f.IDENTIFIER.nm, None),
            getattr(request, f.REQ_ID.nm, None),
            'Request does not meet minimum requirements')
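        # e.g. (illustrative amounts): inputs_sum=50, outputs_sum=45, required_amount=48
        # (45 to outputs plus 3 in fees) raises ExtraFundsError; inputs_sum=48 would pass.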

    def doStaticValidation(self, request: Request):
        static_req_validation(request)

    def validate(self, request: Request):
        req_type = request.operation[TXN_TYPE]
        if req_type == MINT_PUBLIC:
            return validate_multi_sig_txn(request, TRUSTEE, self.domain_state,
                                          self.MinSendersForPublicMint)

        elif req_type == XFER_PUBLIC:
            return self.handle_xfer_public_txn(request)

        raise InvalidClientMessageException(
            request.identifier, getattr(request, 'reqId', None),
            'Unsupported request type - {}'.format(req_type))

    @staticmethod
    def transform_txn_for_ledger(txn):
        """
        Token TXNs do not need to be transformed
        """
        return txn

    def _reqToTxn(self, req: Request):
        """
        Converts the request to a transaction. This is called by LedgerRequestHandler. Not a
        public method. TODO we should consider a more standard approach to inheritance.

        :param req:
        :return: the converted transaction from the Request
        """
        if req.operation[TXN_TYPE] == XFER_PUBLIC:
            sigs = req.operation.pop(SIGS)
        txn = reqToTxn(req)
        if req.operation[TXN_TYPE] == XFER_PUBLIC:
            req.operation[SIGS] = sigs
            sigs = [(i["address"], s)
                    for i, s in zip(req.operation[INPUTS], sigs)]
            add_sigs_to_txn(txn, sigs, sig_type=ED25519)
        return txn

    def _update_state_mint_public_txn(self, txn, is_committed=False):
        payload = get_payload_data(txn)
        seq_no = get_seq_no(txn)
        for output in payload[OUTPUTS]:
            self._add_new_output(Output(output["address"], seq_no,
                                        output["amount"]),
                                 is_committed=is_committed)

    def _update_state_xfer_public(self, txn, is_committed=False):
        payload = get_payload_data(txn)
        for inp in payload[INPUTS]:
            self._spend_input(inp["address"],
                              inp["seqNo"],
                              is_committed=is_committed)
        for output in payload[OUTPUTS]:
            seq_no = get_seq_no(txn)
            self._add_new_output(Output(output["address"], seq_no,
                                        output["amount"]),
                                 is_committed=is_committed)

    def updateState(self, txns, isCommitted=False):
        try:
            for txn in txns:
                typ = get_type(txn)
                if typ == MINT_PUBLIC:
                    self._update_state_mint_public_txn(
                        txn, is_committed=isCommitted)

                if typ == XFER_PUBLIC:
                    self._update_state_xfer_public(txn,
                                                   is_committed=isCommitted)
        except UTXOError as ex:
            error = 'Exception {} while updating state'.format(ex)
            raise OperationError(error)

    def _spend_input(self, address, seq_no, is_committed=False):
        self.spend_input(self.state,
                         self.utxo_cache,
                         address,
                         seq_no,
                         is_committed=is_committed)

    def _add_new_output(self, output: Output, is_committed=False):
        self.add_new_output(self.state,
                            self.utxo_cache,
                            output,
                            is_committed=is_committed)

    def onBatchCreated(self, state_root, txn_time):
        self.on_batch_created(self.utxo_cache, self.tracker, self.ledger,
                              state_root)

    def onBatchRejected(self):
        self.on_batch_rejected(self.utxo_cache, self.tracker, self.state,
                               self.ledger)

    def commit(self, txnCount, stateRoot, txnRoot, pptime) -> List:
        uncommitted_state, uncommitted_txn_root, _ = self.tracker.commit_batch(
        )
        return self.__commit__(self.utxo_cache, self.ledger, self.state,
                               txnCount, stateRoot, txnRoot, pptime,
                               self.ts_store)

    def get_query_response(self, request: Request):
        return self.query_handlers[request.operation[TXN_TYPE]](request)

    def get_all_utxo(self, request: Request):
        address = request.operation[ADDRESS]
        encoded_root_hash = state_roots_serializer.serialize(
            bytes(self.state.committedHeadHash))
        proof, rv = self.state.generate_state_proof_for_keys_with_prefix(
            address, serialize=True, get_value=True)
        multi_sig = self.bls_store.get(encoded_root_hash)
        if multi_sig:
            encoded_proof = proof_nodes_serializer.serialize(proof)
            proof = {
                MULTI_SIGNATURE: multi_sig.as_dict(),
                ROOT_HASH: encoded_root_hash,
                PROOF_NODES: encoded_proof
            }
        else:
            proof = {}

        # The outputs need to be returned in sorted order since each node's reply should be the same.
        # Since the number of outputs can be large, a conscious choice was made not to use
        # `operator.attrgetter` on an already constructed list.
        outputs = SortedItems()
        for k, v in rv.items():
            addr, seq_no = self.parse_state_key(k.decode())
            amount = rlp_decode(v)[0]
            if not amount:
                continue
            outputs.add(Output(addr, int(seq_no), int(amount)))

        result = {
            f.IDENTIFIER.nm: request.identifier,
            f.REQ_ID.nm: request.reqId,
            OUTPUTS: outputs.sorted_list
        }
        if proof:
            result[STATE_PROOF] = proof

        result.update(request.operation)
        return result

    def _sum_inputs(self, req: Request, is_committed=False) -> int:
        return self.sum_inputs(self.utxo_cache, req, is_committed=is_committed)

    @staticmethod
    def create_state_key(address: str, seq_no: int) -> bytes:
        return ':'.join([address, str(seq_no)]).encode()

    @staticmethod
    def parse_state_key(key: str) -> List[str]:
        return key.split(':')
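        # e.g. (illustrative address): create_state_key('addr1', 3) -> b'addr1:3' and
        # parse_state_key('addr1:3') -> ['addr1', '3']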

    @staticmethod
    def sum_inputs(utxo_cache: UTXOCache,
                   request: Request,
                   is_committed=False) -> int:
        try:
            inputs = request.operation[INPUTS]
            return utxo_cache.sum_inputs(inputs, is_committed=is_committed)
        except UTXOError as ex:
            raise InvalidFundsError(request.identifier, request.reqId,
                                    '{}'.format(ex))

    @staticmethod
    def sum_outputs(request: Request) -> int:
        return sum(o["amount"] for o in request.operation[OUTPUTS])

    @staticmethod
    def spend_input(state, utxo_cache, address, seq_no, is_committed=False):
        state_key = TokenReqHandler.create_state_key(address, seq_no)
        state.set(state_key, b'')
        utxo_cache.spend_output(Output(address, seq_no, None),
                                is_committed=is_committed)

    @staticmethod
    def add_new_output(state, utxo_cache, output: Output, is_committed=False):
        address = output.address
        seq_no = output.seqNo
        amount = output.amount
        state_key = TokenReqHandler.create_state_key(address, seq_no)
        state.set(state_key, str(amount).encode())
        utxo_cache.add_output(output, is_committed=is_committed)
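        # State layout used by spend_input/add_new_output above (illustrative values):
        # an unspent output lives under b'addr1:3' with value b'40'; spending it
        # overwrites the value with b''.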

    @staticmethod
    def __commit__(utxo_cache,
                   ledger,
                   state,
                   txnCount,
                   stateRoot,
                   txnRoot,
                   ppTime,
                   ts_store=None):
        r = LedgerRequestHandler._commit(ledger,
                                         state,
                                         txnCount,
                                         stateRoot,
                                         txnRoot,
                                         ppTime,
                                         ts_store=ts_store)
        TokenReqHandler._commit_to_utxo_cache(utxo_cache, stateRoot)
        return r

    @staticmethod
    def _commit_to_utxo_cache(utxo_cache, state_root):
        state_root = base58.b58decode(state_root.encode()) if isinstance(
            state_root, str) else state_root
        if utxo_cache.first_batch_idr != state_root:
            raise TokenValueError(
                'state_root', state_root,
                ("equal to utxo_cache.first_batch_idr hash {}".format(
                    utxo_cache.first_batch_idr)))
        utxo_cache.commit_batch()

    @staticmethod
    def on_batch_created(utxo_cache, tracker: LedgerUncommittedTracker,
                         ledger: Ledger, state_root):
        tracker.apply_batch(state_root, ledger.uncommitted_root_hash,
                            ledger.uncommitted_size)
        utxo_cache.create_batch_from_current(state_root)

    @staticmethod
    def on_batch_rejected(utxo_cache, tracker: LedgerUncommittedTracker,
                          state: PruningState, ledger: Ledger):
        uncommitted_hash, uncommitted_txn_root, txn_count = tracker.reject_batch(
        )
        if txn_count == 0 or ledger.uncommitted_root_hash == uncommitted_txn_root or \
            state.headHash == uncommitted_hash:
            return 0
        state.revertToHead(uncommitted_hash)
        ledger.discardTxns(txn_count)

        utxo_cache.reject_batch()
        return txn_count