def test_prove(self):
        """Prove a single small amount under a fixed mask, then verify it."""
        builder = bp.BulletProofBuilder()
        amount, blinding = crypto.Scalar(123), crypto.Scalar(432)
        proof = builder.prove(amount, blinding)
        builder.verify(proof)
    def mask_consistency_check(self, bpi):
        """Check that the builder's vectors evaluate consistently.

        Sets ``bpi`` up for a single value, then asserts that:
        * reading the same sL / sR element twice (into different buffers)
          gives the same result, while different indices, or sL vs sR, differ;
        * ``vector_exponent`` is reproducible both for (aL, aR) and (sL, sR).

        Helper (no ``test_`` prefix): call it with a builder to check.
        """
        sv = [crypto.Scalar(123)]
        gamma = [crypto.Scalar(432)]

        bpi.prove_setup(sv, gamma)
        # Two independent destination buffers for reading the same element.
        x = bp._ensure_dst_key()
        y = bp._ensure_dst_key()

        sL = bpi.sL_vct(64)
        sR = bpi.sR_vct(64)

        self.assertEqual(sL.to(0, x), sL.to(0, y))
        self.assertEqual(sL.to(1, x), sL.to(1, y))
        self.assertEqual(sL.to(63, x), sL.to(63, y))
        self.assertNotEqual(sL.to(1, x), sL.to(0, y))
        self.assertNotEqual(sL.to(10, x), sL.to(0, y))

        self.assertEqual(sR.to(0, x), sR.to(0, y))
        self.assertEqual(sR.to(1, x), sR.to(1, y))
        self.assertEqual(sR.to(63, x), sR.to(63, y))
        self.assertNotEqual(sR.to(1, x), sR.to(0, y))

        self.assertNotEqual(sL.to(0, x), sR.to(0, y))
        self.assertNotEqual(sL.to(1, x), sR.to(1, y))
        self.assertNotEqual(sL.to(63, x), sR.to(63, y))

        ve1 = bp._ensure_dst_key()
        ve2 = bp._ensure_dst_key()
        bpi.vector_exponent(bpi.aL, bpi.aR, ve1)
        bpi.vector_exponent(bpi.aL, bpi.aR, ve2)
        # Compare now: ve1/ve2 are reused as destinations for (sL, sR) below,
        # so this result would otherwise be discarded unchecked.
        self.assertEqual(ve1, ve2)

        bpi.vector_exponent(sL, sR, ve1)
        bpi.vector_exponent(sL, sR, ve2)
        self.assertEqual(ve1, ve2)
    def test_bpp_bprime(self):
        """The lazily evaluated b' vector returns identical values on re-access.

        Each element is (aR[i] + z) combined with d[i] and y^(MN - i) through
        ``sc_muladd_into`` (presumably d[i] * y^(MN - i) + acc). Elements are
        read out of order (64, 65, 128, back to 65/64, a jump to 89, then
        128 again) and the repeated reads must match the first ones.
        """
        N, M = 64, 4
        MN = N * M
        y = unhexlify(
            b'60421950bee0aab949e63336db1eb9532dba6b4599c5cd9fb1dbde909114100e'
        )
        z = unhexlify(
            b'e0408b528e9d35ccb8386b87f39b85c724740644f4db412483a8852cdb3ceb00'
        )
        zc = crypto.decodeint_into(None, z)
        z_sq = bp._sc_mul(None, z, z)
        sv = [
            crypto.encodeint_into(None, crypto.Scalar(v))
            for v in (1234, 8789, 4455, 6697)
        ]

        num_inp = len(sv)
        sc_zero = crypto.decodeint_into_noreduce(None, bp._ZERO)
        sc_mone = crypto.decodeint_into_noreduce(None, bp._MINUS_ONE)

        def e_xL(idx, d=None):
            # aR entry: 0 where the amount bit is set, -1 where it is clear
            # or where idx falls in the padding past the last input.
            j, i = divmod(idx, bp._BP_N)
            bit_set = j < num_inp and sv[j][i // 8] & (1 << i % 8)
            r = sc_zero if bit_set else sc_mone
            if d:
                return crypto.sc_copy(d, r)
            return r

        aR = bp.KeyVEval(MN, lambda i, d: e_xL(i, d), raw=True)
        d_vct = bp.VctD(N, M, z_sq, raw=True)
        ypow_back = bp.KeyVPowersBackwards(MN + 1, y, raw=True)
        acc = crypto.Scalar()

        def aR1_fnc(i, d):
            crypto.sc_add_into(acc, aR.to(i), zc)
            crypto.sc_muladd_into(acc, d_vct[i], ypow_back[MN - i], acc)
            return crypto.encodeint_into(d, acc)

        bprime = bp.KeyVEval(MN, aR1_fnc, raw=False)  # aR1
        first_reads = [bp._copy_key(None, bprime.to(k)) for k in (64, 65, 128)]
        again_65 = bp._copy_key(None, bprime.to(65))
        again_64 = bp._copy_key(None, bprime.to(64))
        _ = bprime[89]  # jump forward before re-reading index 128
        again_128 = bp._copy_key(None, bprime.to(128))

        self.assertEqual(first_reads[0], again_64)
        self.assertEqual(first_reads[1], again_65)
        self.assertEqual(first_reads[2], again_128)
    def test_prove_2(self):
        """Round-trip a proof for 2**30 + 15 blinded by a random mask."""
        builder = bp.BulletProofBuilder()
        proof = builder.prove(crypto.Scalar((1 << 30) - 1 + 16),
                              crypto.random_scalar())
        builder.verify(proof)
Example #5
0
 def test_clsag_invalid_Cp(self):
     """Verification must reject a signature whose pseudo-out commitment Cp
     was shifted by G after signing."""
     res = self.gen_clsag_sig(ring_size=11, index=5)
     msg, scalars, sc1, sI, sD, ring2, Cp = res
     # Tamper outside assertRaises: only a ValueError from the verifier may
     # satisfy the test, not one raised while building the bad input.
     Cp = crypto.point_add_into(
         None, Cp, crypto.scalarmult_base_into(None, crypto.Scalar(1)))
     with self.assertRaises(ValueError):
         self.verify_clsag(msg, scalars, sc1, sI, sD, ring2, Cp)
    def test_prove_random_masks(self):
        """Prove and verify with non-deterministic mask vectors."""
        builder = bp.BulletProofBuilder()
        builder.use_det_masks = False  # truly randomly generated mask vectors
        amount = crypto.Scalar((1 << 30) - 1 + 16)
        proof = builder.prove(amount, crypto.random_scalar())
        builder.verify(proof)
    def test_sc_inversion(self):
        """sc_inv_into reproduces a fixed scalar-inverse test vector."""
        encoded = unhexlify(
            b"3482fb9735ef879fcae5ec7721b5d3646e155c4fb58d6cc11c732c9c9b76620a"
        )
        inverse = crypto.Scalar()
        crypto.sc_inv_into(inverse, crypto_helpers.decodeint(encoded))

        expected = b"bcf365a551e6358f3f281a6241d4a25eded60230b60a1d48c67b51a85e33d70e"
        self.assertEqual(hexlify(crypto_helpers.encodeint(inverse)), expected)
def _ecdh_encode(amount: int, amount_key: bytes) -> EcdhTuple:
    """
    Build the EcdhTuple from which output recipients decode the amount.

    The amount is encoded as a 32-byte scalar and XOR-ed in place with
    ``_ecdh_hash(amount_key)`` via ``crypto_helpers.xor8`` (presumably
    touching only the leading 8 bytes). The mask field is set to
    ``crypto_helpers.NULL_KEY_ENC``.
    """
    from apps.monero.xmr.serialize_messages.tx_ecdh import EcdhTuple

    amount_buf = bytearray(32)
    crypto.encodeint_into(amount_buf, crypto.Scalar(amount))
    crypto_helpers.xor8(amount_buf, _ecdh_hash(amount_key))
    return EcdhTuple(mask=crypto_helpers.NULL_KEY_ENC, amount=amount_buf)
    def ctest_multiexp(self):
        """Compare ``bp.multiexp`` against ``bp.vector_exponent_custom``.

        Six scalars and six base-multiple points are combined both through a
        MultiExp and as two 3-element halves (A/a and B/b); the results must
        match. The ``ctest_`` prefix presumably keeps the test runner from
        collecting it.
        """
        scalars = [0, 1, 2, 3, 4, 99]
        point_base = [0, 2, 4, 7, 12, 18]
        scalar_sc = [crypto.Scalar(x) for x in scalars]
        points = [
            crypto.scalarmult_base_into(None, crypto.Scalar(x))
            for x in point_base
        ]

        # Allocating encoders live in crypto_helpers; the crypto *_into
        # encoders take the destination buffer FIRST: encode*_into(dst, src).
        muex = bp.MultiExp(
            scalars=[crypto_helpers.encodeint(x) for x in scalar_sc],
            point_fnc=lambda i, d: crypto_helpers.encodepoint(points[i]))

        self.assertEqual(len(muex), len(scalars))
        res = bp.multiexp(None, muex)
        res2 = bp.vector_exponent_custom(
            A=bp.KeyVEval(
                3, lambda i, d: crypto.encodepoint_into(d, points[i])),
            B=bp.KeyVEval(
                3, lambda i, d: crypto.encodepoint_into(d, points[3 + i])),
            a=bp.KeyVEval(
                3, lambda i, d: crypto.encodeint_into(d, scalar_sc[i])),
            b=bp.KeyVEval(
                3, lambda i, d: crypto.encodeint_into(d, scalar_sc[3 + i])),
        )
        self.assertEqual(res, res2)
def prove_range_bp_batch(
        amounts: list[int],
        masks: list[crypto.Scalar],
        bp_plus: bool = False) -> Bulletproof | BulletproofPlus:
    """Compute one batched range proof over ``amounts`` blinded by ``masks``.

    Uses the Bulletproof+ builder when ``bp_plus`` is set and the classic
    Bulletproof builder otherwise. The builder and the bulletproof module
    reference are dropped and ``gc.collect()`` is run before returning.
    """
    from apps.monero.xmr import bulletproof as bp

    if bp_plus:
        bpi = bp.BulletProofPlusBuilder()
    else:
        bpi = bp.BulletProofBuilder()
    amount_scalars = [crypto.Scalar(a) for a in amounts]
    bp_proof = bpi.prove_batch(amount_scalars, masks)

    del bpi, bp, amount_scalars
    gc.collect()
    return bp_proof
    def test_dvct_skips(self):
        """VctD yields the same element whether reached by stepping or jumping.

        Two VctD instances are built from identical inputs. ``d_vct0`` is
        walked around (linear scans, forward and backward jumps), while
        ``d_vct1`` is mostly read directly at the target index. Each
        comparison checks that the encoded scalars agree.
        """
        z_sq = unhexlify(
            b'e0408b528e9d35ccb8386b87f39b85c724740644f4db412483a8852cdb3ceb00'
        )
        # VctD(N, M, z_sq) as in test_bpp_bprime; here N=64, M=8.
        d_vct0 = bp.VctD(64, 8, z_sq, raw=True)
        d_vct1 = bp.VctD(64, 8, z_sq, raw=True)
        # Placeholder only: overwritten by the first loop iteration below.
        tmp = crypto.Scalar()

        # Linear scan vs jump
        # tmp ends up holding d_vct0[64]; d_vct0 is not touched again before
        # the comparison, so no copy is needed here.
        for i in range(65):
            tmp = d_vct0[i]
        self.assertEqual(crypto.encodeint_into(None, tmp),
                         crypto.encodeint_into(None, d_vct1[64]))

        # Jumping around
        # Forward to 128 and back to 64 on d_vct0; d_vct1 is still at 64.
        _ = d_vct0[128]
        self.assertEqual(crypto.encodeint_into(None, d_vct0[64]),
                         crypto.encodeint_into(None, d_vct1[64]))

        # Sync on the same jump
        # Both instances step 64 -> 65; the repeated read hits the same index.
        self.assertEqual(crypto.encodeint_into(None, d_vct0[65]),
                         crypto.encodeint_into(None, d_vct1[65]))
        self.assertEqual(crypto.encodeint_into(None, d_vct0[65]),
                         crypto.encodeint_into(None, d_vct1[65]))

        # Jump vs linear again, move_one vs move_more
        # d_vct0 steps 66..74 one at a time; d_vct1 jumps straight to 74.
        for i in range(1, 10):
            tmp = d_vct0[65 + i]
        self.assertEqual(crypto.encodeint_into(None, tmp),
                         crypto.encodeint_into(None, d_vct1[74]))

        # Different forward jump sizes on each instance before reading 95.
        _ = d_vct0[85]
        _ = d_vct1[89]  # different jump sizes, internal state management test
        self.assertEqual(crypto.encodeint_into(None, d_vct0[95]),
                         crypto.encodeint_into(None, d_vct1[95]))

        # 319 -> 320 crosses a multiple of 64 (320 = 5 * 64).
        _ = d_vct0[
            319]  # move_one mults by z_sq then; enforce z component updates
        self.assertEqual(crypto.encodeint_into(None, d_vct0[320]),
                         crypto.encodeint_into(None, d_vct1[320]))

        # sc_copy: the value returned by indexing may be an internal buffer
        # that the next d_vct0 access overwrites (presumably) -- so snapshot
        # it before moving d_vct0 again and comparing against itself.
        tmp = crypto.sc_copy(None, d_vct0[64])  # another jump back and forth
        _ = d_vct0[127]
        self.assertEqual(crypto.encodeint_into(None, d_vct0[64]),
                         crypto.encodeint_into(None, tmp))

        # Reset both to the start, move d_vct0 away again, then compare a
        # small index reached by a backward jump vs a forward step.
        _ = d_vct0[0]
        _ = d_vct1[0]
        _ = d_vct0[64]
        self.assertEqual(crypto.encodeint_into(None, d_vct0[5]),
                         crypto.encodeint_into(None, d_vct1[5]))
Example #12
0
async def get_tx_keys(ctx: wire.Context, msg: MoneroGetTxKeyRequest,
                      keychain: Keychain) -> MoneroGetTxKeyAck:
    """Hand the transaction private keys, or their derivations, back to the host.

    The host sends the tx private keys encrypted under a key computed from
    the spend private key, the tx prefix hash and two salts. After user
    confirmation they are decrypted and, when ``msg.reason`` requests
    derivations, each key is replaced in place by ``tx_priv * view_pub``.
    The buffer is then re-encrypted under a key derived from the view
    private key and returned with its salt.

    Raises:
        wire.DataError: derivations requested but ``view_public_key`` missing.
    """
    await paths.validate_path(ctx, keychain, msg.address_n)

    want_derivations = msg.reason == _GET_TX_KEY_REASON_TX_DERIVATION
    await layout.require_confirm_tx_key(ctx, export_key=not want_derivations)

    creds = misc.get_creds(keychain, msg.address_n, msg.network_type)
    tx_enc_key = misc.compute_tx_key(
        creds.spend_key_private,
        msg.tx_prefix_hash,
        msg.salt1,
        crypto_helpers.decodeint(msg.salt2),
    )

    # Holds the decrypted tx private keys (32 bytes each); overwritten in
    # place with the derivations when those were requested.
    key_buf = chacha_poly.decrypt_pack(tx_enc_key, msg.tx_enc_keys)
    utils.ensure(len(key_buf) % 32 == 0, "Tx key buffer has invalid size")
    msg.tx_enc_keys = b""

    if want_derivations:
        if msg.view_public_key is None:
            raise wire.DataError("Missing view public key")

        key_buf = bytearray(key_buf)
        view_pub = crypto_helpers.decodepoint(msg.view_public_key)
        tx_priv = crypto.Scalar()
        derivation = crypto.Point()
        n_keys = len(key_buf) // 32
        for offset in range(0, 32 * n_keys, 32):
            crypto.decodeint_into(tx_priv, key_buf, offset)
            crypto.scalarmult_into(derivation, view_pub, tx_priv)
            crypto.encodepoint_into(key_buf, derivation, offset)

    # Re-encrypt for the host under a view-key based password.
    tx_enc_key_host, salt = misc.compute_enc_key_host(creds.view_key_private,
                                                      msg.tx_prefix_hash)
    encrypted = chacha_poly.encrypt_pack(tx_enc_key_host, key_buf)

    res_msg = MoneroGetTxKeyAck(salt=salt)
    if want_derivations:
        res_msg.tx_derivations = encrypted
    else:
        res_msg.tx_keys = encrypted
    return res_msg
async def all_inputs_set(state: State) -> MoneroTransactionAllInputsSetAck:
    """Close the input phase and move the protocol to STEP_ALL_IN.

    Only legal right after STEP_VINI, once the last input (index
    ``input_count - 1``) has been processed.

    Raises:
        ValueError: wrong previous step or not all inputs processed.
    """
    state.mem_trace(0)
    await layout.transaction_step(state, state.STEP_ALL_IN)

    from trezor.messages import MoneroTransactionAllInputsSetAck

    if state.last_step != state.STEP_VINI:
        raise ValueError("Invalid state transition")
    if state.current_input_index != state.input_count - 1:
        raise ValueError("Invalid input count")

    # Start accumulating output masks from zero; their sum must later match
    # the sum of the inputs' pseudo-out masks.
    state.sumout = crypto.Scalar()
    state.last_step = state.STEP_ALL_IN
    return MoneroTransactionAllInputsSetAck()
Example #14
0
    def __init__(self, ctx: Context) -> None:
        from apps.monero.xmr.keccak_hasher import KeccakXmrArchive
        from apps.monero.xmr.mlsag_hasher import PreMlsagHasher

        self.ctx = ctx

        # Account credentials (AccountCreds): view private/public key,
        # spend private/public key and the corresponding address.
        self.creds: AccountCreds | None = None

        # HMAC/encryption keys used to protect offloaded data.
        self.key_hmac: bytes | None = None
        self.key_enc: bytes | None = None

        # Transaction keys, also denoted r/R:
        #  - tx_priv is a random number;
        #  - tx_pub is r*G, or r*D for subaddresses (there r is commonly
        #    denoted s, but it is still just a random number);
        #  - they derive the one-time address and its keys (P = H(A*r)*G + B).
        self.tx_priv: Scalar | None = None
        self.tx_pub: Point | None = None
        # Some subaddress configurations need more tx keys (see step 1).
        self.need_additional_txkeys = False

        # Connected client version; hard fork defaults to 12.
        self.client_version = 0
        self.hard_fork = 12

        # Input/output counts and progress counters.
        self.input_count: int | None = 0
        self.output_count = 0
        self.progress_total = 0
        self.progress_cur = 0

        self.output_change: "MoneroTransactionDestinationEntry" | None = None
        self.fee: int | None = 0
        self.tx_type = 0

        # Wallet sub-address major index.
        self.account_idx: int | None = 0

        # Additional tx keys, used when need_additional_txkeys is True.
        self.additional_tx_private_keys: list[Scalar] = []
        self.additional_tx_public_keys: list[bytes] | None = []

        # Index of the input/output currently processed (starts at -1).
        self.current_input_index = -1
        self.current_output_index: int | None = -1
        self.is_processing_offloaded = False

        # Last input amount, for pseudo_out recomputation from a new mask.
        self.input_last_amount: int | None = 0

        self.summary_inputs_money: int | None = 0
        self.summary_outs_money: int | None = 0

        # Output commitments.
        self.output_pk_commitments: list[bytes] | None = []

        self.output_amounts: list[int] | None = []
        # Output *range proof* masks; HP10+ makes them deterministic.
        self.output_masks: list[Scalar] | None = []

        # Range proofs are computed in batches; this is the grouping.
        self.rsig_grouping: list[int] | None = []
        # Whether range-proof computation is offloaded, and BP+ is used.
        self.rsig_offload: bool | None = False
        self.rsig_is_bp_plus: bool | None = False

        # Sum of all inputs' pseudo-out masks.
        self.sumpouts_alphas: Scalar = crypto.Scalar(0)
        # Sum of all outputs' pseudo-out masks.
        self.sumout: Scalar = crypto.Scalar(0)

        self.subaddresses: Subaddresses | None = {}

        # TX_EXTRA_NONCE field for tx.extra, kept separately due to sort_tx_extra().
        self.extra_nonce: bytes | None = None

        # Last key image seen; used for the input permutation correctness check.
        self.last_ki: bytes | None = None

        # Encryption key released to the host once the protocol ends without error.
        self.opening_key: bytes | None = None

        # Step transition automaton.
        self.last_step: int | None = self.STEP_INIT

        # Tx prefix hasher/hash: hashed incrementally, the final digest is
        # stored in tx_prefix_hash (Monero-Trezor documentation, section 3.3).
        self.tx_prefix_hasher: KeccakXmrArchive | None = KeccakXmrArchive()
        self.tx_prefix_hash: bytes | None = None

        # Full message hasher/hash to be signed with MLSAG; it includes
        # tx_prefix_hash (Monero-Trezor documentation, section 3.3).
        self.full_message_hasher: PreMlsagHasher | None = PreMlsagHasher()
        self.full_message: bytes | None = None
Example #15
0
    async def diag(ctx, msg, **kwargs) -> Failure:
        """Debug diagnostics selected by ``msg.ins``.

        0/1: memory check (1 also dumps ``micropython.mem_info``);
        2: log separator lines; 3: no-op; 4: list loaded modules;
        5/6/7: Bulletproof prove+verify (testnet single, single, batch of 2)
        with memory checkpoints. Every path replies with ``retit()``.
        """
        log.debug(__name__, "----diagnostics")
        gc.collect()
        ins = msg.ins

        if ins == 0:
            check_mem(0)

        elif ins == 1:
            check_mem(1)
            micropython.mem_info(1)

        elif ins == 2:
            for _ in range(3):
                log.debug(__name__, "_____________________________________________")

        elif ins == 4:
            total = monero = 0
            for name, module in sys.modules.items():
                log.info(__name__, "Mod[%s]: %s", name, module)
                total += 1
                if name.startswith("apps.monero"):
                    monero += 1
            log.info(__name__, "Total modules: %s, Monero modules: %s", total, monero)

        elif ins in (5, 6, 7):
            check_mem()
            from apps.monero.xmr import bulletproof as bp

            check_mem("BP Imported")
            from apps.monero.xmr import crypto

            check_mem("Crypto Imported")

            bpi = bp.BulletProofBuilder()
            bpi.gc_fnc = gc.collect
            bpi.gc_trace = log_trace

            vals = [crypto.Scalar((1 << 30) - 1 + 16), crypto.Scalar(22222)]
            masks = [crypto.random_scalar(), crypto.random_scalar()]
            check_mem("BP pre input")

            if ins == 5:
                proof = bpi.prove_testnet(vals[0], masks[0])
                check_mem("BP post prove")
                bpi.verify_testnet(proof)
            else:
                if ins == 6:
                    proof = bpi.prove(vals[0], masks[0])
                else:
                    proof = bpi.prove_batch(vals, masks)
                check_mem("BP post prove")
                bpi.verify(proof)
            check_mem("BP post verify")

        # ins == 3 and unknown codes fall through to the same reply.
        return retit()
Example #16
0
    def gen_clsag_sig(self, ring_size=11, index=None):
        """Build a random ring and a CLSAG signature over it for tests.

        A real key/commitment pair is inserted at ``index`` among
        ``ring_size - 1`` random decoys; when ``index`` is None, a random
        position in 0 .. ring_size - 1 is used (including the last slot).

        Returns:
            (msg, scalars, sc1, sI, sD, ring2, Cp): the message, the
            per-member response scalars, c1, the key image I = priv * H(P),
            the decoded D point, a shallow copy of the ring and the
            pseudo-out commitment Cp.
        """
        msg = random.bytes(32)
        amnt = crypto.Scalar(random.uniform(0xFFFFFF) + 12)
        priv = crypto.random_scalar()
        msk = crypto.random_scalar()
        alpha = crypto.random_scalar()
        P = crypto.scalarmult_base_into(None, priv)
        C = crypto.add_keys2_into(None, msk, amnt, crypto.xmr_H())
        Cp = crypto.add_keys2_into(None, alpha, amnt, crypto.xmr_H())

        ring = [
            TmpKey(
                crypto_helpers.encodepoint(
                    crypto.scalarmult_base_into(None, crypto.random_scalar())),
                crypto_helpers.encodepoint(
                    crypto.scalarmult_base_into(None, crypto.random_scalar())),
            ) for _ in range(ring_size - 1)
        ]

        # The decoy list has ring_size - 1 entries, i.e. ring_size valid
        # insert positions; random.uniform(n) draws from 0 .. n - 1.
        if index is None:
            index = random.uniform(ring_size)
        ring.insert(
            index,
            TmpKey(crypto_helpers.encodepoint(P),
                   crypto_helpers.encodepoint(C)))
        ring2 = list(ring)
        mg_buffer = []

        # Sanity: the inserted member really belongs to priv / (msk - alpha).
        self.assertTrue(
            crypto.point_eq(
                crypto.scalarmult_base_into(None, priv),
                crypto_helpers.decodepoint(ring[index].dest),
            ))
        self.assertTrue(
            crypto.point_eq(
                crypto.scalarmult_base_into(
                    None, crypto.sc_sub_into(None, msk, alpha)),
                crypto.point_sub_into(
                    None, crypto_helpers.decodepoint(ring[index].commitment),
                    Cp),
            ))

        clsag.generate_clsag_simple(
            msg,
            ring,
            CtKey(priv, msk),
            alpha,
            Cp,
            index,
            mg_buffer,
        )

        # mg_buffer layout: [ring size, s_0 .. s_{n-1}, c1, D].
        sD = crypto_helpers.decodepoint(mg_buffer[-1])
        sc1 = crypto_helpers.decodeint(mg_buffer[-2])
        scalars = [crypto_helpers.decodeint(x) for x in mg_buffer[1:-2]]
        H = crypto.Point()
        sI = crypto.Point()

        crypto.hash_to_point_into(H, crypto_helpers.encodepoint(P))
        crypto.scalarmult_into(sI, H, priv)  # I = p*H
        return msg, scalars, sc1, sI, sD, ring2, Cp
Example #17
0
    def verify_clsag(self, msg, ss, sc1, sI, sD, pubs, C_offset):
        """Reference CLSAG verifier used by the tests.

        Starting from ``sc1``, recomputes the challenge around the whole ring
        and raises ``ValueError("Signature error")`` unless it closes back on
        ``sc1``. ``pubs`` members expose encoded ``dest`` / ``commitment``;
        ``ss`` are the per-member responses, ``sI`` the key image and ``sD``
        the stored D point, which is multiplied by 8 before use.
        """
        c = crypto.Scalar()
        D_8 = crypto.Point()
        tmp_bf = bytearray(32)
        C_offset_bf = crypto_helpers.encodepoint(C_offset)

        crypto.sc_copy(c, sc1)
        point_mul8_into(D_8, sD)

        # Aggregation coefficients mu_P / mu_C: domain, P..., C..., I, D, C_offset
        agg_P = crypto_helpers.get_keccak()
        agg_C = crypto_helpers.get_keccak()
        agg_P.update(clsag._HASH_KEY_CLSAG_AGG_0)
        agg_C.update(clsag._HASH_KEY_CLSAG_AGG_1)

        def absorb(data):
            agg_P.update(data)
            agg_C.update(data)

        for member in pubs:
            absorb(member.dest)
        for member in pubs:
            absorb(member.commitment)
        # tmp_bf is reused; each update consumes it before it is rewritten.
        absorb(crypto.encodepoint_into(tmp_bf, sI))
        absorb(crypto.encodepoint_into(tmp_bf, sD))
        absorb(C_offset_bf)
        mu_P = crypto_helpers.decodeint(agg_P.digest())
        mu_C = crypto_helpers.decodeint(agg_C.digest())

        # Round-hash prefix shared by every step: domain, P, C, C_offset, msg.
        round_prefix = crypto_helpers.get_keccak()
        round_prefix.update(clsag._HASH_KEY_CLSAG_ROUND)
        for member in pubs:
            round_prefix.update(member.dest)
        for member in pubs:
            round_prefix.update(member.commitment)
        round_prefix.update(C_offset_bf)
        round_prefix.update(msg)

        c_p = crypto.Scalar()
        c_c = crypto.Scalar()
        L = crypto.Point()
        R = crypto.Point()
        tmp_pt = crypto.Point()
        for i, member in enumerate(pubs):
            crypto.sc_mul_into(c_p, mu_P, c)
            crypto.sc_mul_into(c_c, mu_C, c)

            # L = s_i*G + c_p*P_i + c_c*(C_i - C_offset)
            C_P = crypto.point_sub_into(
                None, crypto.decodepoint_into(tmp_pt, member.commitment),
                C_offset)
            crypto.add_keys2_into(
                L, ss[i], c_p, crypto.decodepoint_into(tmp_pt, member.dest))
            crypto.point_add_into(L, L,
                                  crypto.scalarmult_into(tmp_pt, C_P, c_c))

            # R = s_i*H(P_i) + c_p*I + c_c*(8*D)
            HP = crypto.hash_to_point_into(None, member.dest)
            crypto.add_keys3_into(R, ss[i], HP, c_p, sI)
            crypto.point_add_into(R, R,
                                  crypto.scalarmult_into(tmp_pt, D_8, c_c))

            step = round_prefix.copy()
            step.update(crypto.encodepoint_into(tmp_bf, L))
            step.update(crypto.encodepoint_into(tmp_bf, R))
            crypto.decodeint_into(c, step.digest())

        res = crypto.sc_sub_into(None, c, sc1)
        if not crypto.sc_eq(res, crypto.Scalar(0)):
            raise ValueError("Signature error")
 def test_prove_plus_2(self):
     """A BP+ batch proof over two amounts verifies."""
     builder = bp.BulletProofPlusBuilder()
     amounts = [crypto.Scalar(v) for v in (123, 768)]
     blindings = [crypto.Scalar(v) for v in (456, 901)]
     proof = builder.prove_batch(amounts, blindings)
     builder.verify_batch([proof])
 def test_prove_batch16(self):
     """A classic Bulletproof batch over 16 amounts verifies."""
     count = 16
     builder = bp.BulletProofBuilder()
     amounts = [crypto.Scalar(v) for v in range(0, 137 * count, 137)]
     blindings = [crypto.Scalar(v) for v in range(0, 991 * count, 991)]
     proof = builder.prove_batch(amounts, blindings)
     builder.verify_batch([proof])
 def test_prove_plus_16(self):
     """A BP+ batch over 16 amounts verifies."""
     count = 16
     builder = bp.BulletProofPlusBuilder()
     amounts = [crypto.Scalar(k * 123 + 45) for k in range(count)]
     blindings = [crypto.Scalar(k * 456 * 17) for k in range(count)]
     proof = builder.prove_batch(amounts, blindings)
     builder.verify_batch([proof])
Example #21
0
def _generate_clsag(
    message: bytes,
    P: list[bytes],
    p: crypto.Scalar,
    C_nonzero: list[bytes],
    z: crypto.Scalar,
    Cout: crypto.Point,
    index: int,
    mg_buff: list[bytearray],
) -> list[bytes]:
    """
    Generate a CLSAG signature and serialize it into ``mg_buff``.

    Args:
        message: message bound into every round challenge.
        P: encoded ring public keys; P[index] is the signer's key.
        p: secret key for P[index]; key image I = p * H, H = hash_to_point(P[index]).
        C_nonzero: encoded ring commitments, parallel to ``P``.
        z: commitment-side secret; D = z * H and it enters the final response.
        Cout: pseudo-output commitment; ring commitments are used as C_nonzero[i] - Cout.
        index: signer position in the ring.
        mg_buff: output list. Response i is written to mg_buff[i + 1], so
            presumably it must be empty on entry -- TODO confirm callers.

    Returns:
        ``mg_buff`` after appending: the uvarint ring size, one 32-byte
        response per ring member, encoded c1, and encoded D/8.

    Side effects:
        Every entry of ``P`` and ``C_nonzero`` except position ``index`` is
        set to ``None`` once consumed (presumably so it can be collected;
        gc.collect() runs every 4 steps).
    """
    sI = crypto.Point()  # sig.I
    sD = crypto.Point()  # sig.D
    sc1 = crypto.Scalar()  # sig.c1
    a = crypto.random_scalar()  # signer's nonce
    H = crypto.Point()
    D = crypto.Point()
    Cout_bf = crypto_helpers.encodepoint(Cout)

    tmp_sc = crypto.Scalar()
    tmp = crypto.Point()
    tmp_bf = bytearray(32)

    crypto.hash_to_point_into(H, P[index])
    crypto.scalarmult_into(sI, H, p)  # I = p*H
    crypto.scalarmult_into(D, H, z)  # D = z*H
    crypto.sc_mul_into(tmp_sc, z, crypto_helpers.INV_EIGHT_SC)  # 1/8*z
    crypto.scalarmult_into(sD, H, tmp_sc)  # sig.D = 1/8*z*H
    sD = crypto_helpers.encodepoint(sD)

    # Aggregation coefficients mu_P / mu_C over the same data, different domains.
    hsh_P = crypto_helpers.get_keccak()  # domain, I, D, P, C, C_offset
    hsh_C = crypto_helpers.get_keccak()  # domain, I, D, P, C, C_offset
    hsh_P.update(_HASH_KEY_CLSAG_AGG_0)
    hsh_C.update(_HASH_KEY_CLSAG_AGG_1)

    def hsh_PC(x):
        nonlocal hsh_P, hsh_C
        hsh_P.update(x)
        hsh_C.update(x)

    for x in P:
        hsh_PC(x)

    for x in C_nonzero:
        hsh_PC(x)

    hsh_PC(crypto.encodepoint_into(tmp_bf, sI))
    hsh_PC(sD)
    hsh_PC(Cout_bf)
    mu_P = crypto_helpers.decodeint(hsh_P.digest())
    mu_C = crypto_helpers.decodeint(hsh_C.digest())

    del (hsh_PC, hsh_P, hsh_C)
    # Round-hash prefix shared by every step; each step copies it and adds L, R.
    c_to_hash = crypto_helpers.get_keccak()  # domain, P, C, C_offset, message, aG, aH
    c_to_hash.update(_HASH_KEY_CLSAG_ROUND)
    for i in range(len(P)):
        c_to_hash.update(P[i])
    for i in range(len(P)):
        c_to_hash.update(C_nonzero[i])
    c_to_hash.update(Cout_bf)
    c_to_hash.update(message)

    # Initial challenge from the signer's nonce: c = Hs(prefix || aG || aH).
    chasher = c_to_hash.copy()
    crypto.scalarmult_base_into(tmp, a)
    chasher.update(crypto.encodepoint_into(tmp_bf, tmp))  # aG
    crypto.scalarmult_into(tmp, H, a)
    chasher.update(crypto.encodepoint_into(tmp_bf, tmp))  # aH
    c = crypto_helpers.decodeint(chasher.digest())
    del (chasher, H)

    L = crypto.Point()
    R = crypto.Point()
    c_p = crypto.Scalar()
    c_c = crypto.Scalar()
    # Walk the ring from index + 1 around to index - 1; c1 is the challenge
    # held when the walk reaches position 0.
    i = (index + 1) % len(P)
    if i == 0:
        crypto.sc_copy(sc1, c)

    # Layout: [ring size, s_0 .. s_{n-1}]; c1 and D are appended at the end.
    mg_buff.append(int_serialize.dump_uvarint_b(len(P)))
    for _ in range(len(P)):
        mg_buff.append(bytearray(32))

    while i != index:
        # Random response for decoy position i.
        crypto.random_scalar(tmp_sc)
        crypto.encodeint_into(mg_buff[i + 1], tmp_sc)

        crypto.sc_mul_into(c_p, mu_P, c)
        crypto.sc_mul_into(c_c, mu_C, c)

        # L = tmp_sc * G + c_P * P[i] + c_c * C[i]
        crypto.add_keys2_into(L, tmp_sc, c_p, crypto.decodepoint_into(tmp, P[i]))
        crypto.decodepoint_into(tmp, C_nonzero[i])  # C = C_nonzero - Cout
        crypto.point_sub_into(tmp, tmp, Cout)
        crypto.scalarmult_into(tmp, tmp, c_c)
        crypto.point_add_into(L, L, tmp)

        # R = tmp_sc * HP + c_p * I + c_c * D
        crypto.hash_to_point_into(tmp, P[i])
        crypto.add_keys3_into(R, tmp_sc, tmp, c_p, sI)
        crypto.point_add_into(R, R, crypto.scalarmult_into(tmp, D, c_c))

        chasher = c_to_hash.copy()
        chasher.update(crypto.encodepoint_into(tmp_bf, L))
        chasher.update(crypto.encodepoint_into(tmp_bf, R))
        crypto.decodeint_into(c, chasher.digest())

        # Entry consumed; drop the reference.
        P[i] = None  # type: ignore
        C_nonzero[i] = None  # type: ignore

        i = (i + 1) % len(P)
        if i == 0:
            crypto.sc_copy(sc1, c)

        if i & 3 == 0:
            gc.collect()

    # Final scalar = a - c * (mu_P * p + mu_C * z)
    crypto.sc_mul_into(tmp_sc, mu_P, p)
    crypto.sc_muladd_into(tmp_sc, mu_C, z, tmp_sc)
    crypto.sc_mulsub_into(tmp_sc, c, tmp_sc, a)
    crypto.encodeint_into(mg_buff[index + 1], tmp_sc)

    if TYPE_CHECKING:
        assert list_of_type(mg_buff, bytes)

    mg_buff.append(crypto_helpers.encodeint(sc1))
    mg_buff.append(sD)
    return mg_buff
Example #22
0
def generate_ring_signature(
    prefix_hash: bytes,
    image: crypto.Point,
    pubs: list[crypto.Point],
    sec: crypto.Scalar,
    sec_idx: int,
    test: bool = False,
) -> Sig:
    """
    Generates ring signature with key image.
    void crypto_ops::generate_ring_signature()

    Args:
        prefix_hash: message prefix; hashed together with all commitments.
        image: key image of ``sec``.
        pubs: ring public keys; ``pubs[sec_idx]`` must belong to ``sec``.
        sec: signer's secret key.
        sec_idx: signer position in ``pubs``, 0 <= sec_idx < len(pubs).
        test: when True, additionally check that ``sec`` and ``image``
            match ``pubs[sec_idx]`` and run ``ge25519_check`` on every key.

    Returns:
        One ``[c, r]`` scalar pair per ring member.

    Raises:
        ValueError: ``sec_idx`` out of range, or (with ``test``) the secret
            key or the key image does not match.
    """
    from trezor.utils import memcpy

    # A negative sec_idx never matches the loop index below, so the nonce k
    # would stay zero and the signer's r = k - c*sec = -c*sec would expose
    # the secret key. Reject any out-of-range index up front.
    if not 0 <= sec_idx < len(pubs):
        raise ValueError("Invalid sec_idx")

    if test:
        t = crypto.scalarmult_base_into(None, sec)
        if not crypto.point_eq(t, pubs[sec_idx]):
            raise ValueError("Invalid sec key")

        k_i = monero.generate_key_image(
            crypto_helpers.encodepoint(pubs[sec_idx]), sec)
        if not crypto.point_eq(k_i, image):
            raise ValueError("Key image invalid")
        for pub in pubs:
            crypto.ge25519_check(pub)

    # Hash input: prefix_hash followed by two 32-byte points per ring member.
    buff_off = len(prefix_hash)
    buff = bytearray(buff_off + 2 * 32 * len(pubs))
    memcpy(buff, 0, prefix_hash, 0, buff_off)
    mvbuff = memoryview(buff)

    c_sum = crypto.Scalar(0)  # sum of the decoy challenges
    k = crypto.Scalar(0)
    sig = [[crypto.Scalar(0), crypto.Scalar(0)] for _ in pubs]  # c, r

    for i, pub in enumerate(pubs):
        if i == sec_idx:
            # Signer: commit to k*G and k*H(P) with a fresh nonce k.
            k = crypto.random_scalar()
            tmp3 = crypto.scalarmult_base_into(None, k)
            crypto.encodepoint_into(mvbuff[buff_off:buff_off + 32], tmp3)
            buff_off += 32

            tmp3 = crypto.hash_to_point_into(
                None, crypto_helpers.encodepoint(pub))
            tmp2 = crypto.scalarmult_into(None, tmp3, k)
            crypto.encodepoint_into(mvbuff[buff_off:buff_off + 32], tmp2)
            buff_off += 32

        else:
            # Decoy: random (c, r), commitments c*P + r*G and r*H(P) + c*I.
            sig[i] = [crypto.random_scalar(), crypto.random_scalar()]
            tmp2 = crypto.ge25519_double_scalarmult_vartime_into(
                None, pub, sig[i][0], sig[i][1])
            crypto.encodepoint_into(mvbuff[buff_off:buff_off + 32], tmp2)
            buff_off += 32

            tmp3 = crypto.hash_to_point_into(None,
                                             crypto_helpers.encodepoint(pub))
            tmp2 = crypto.add_keys3_into(None, sig[i][1], tmp3, sig[i][0],
                                         image)
            crypto.encodepoint_into(mvbuff[buff_off:buff_off + 32], tmp2)
            buff_off += 32

            crypto.sc_add_into(c_sum, c_sum, sig[i][0])

    # Close the ring: signer's c = H(buff) - sum(decoy c), r = k - c*sec.
    h = crypto.hash_to_scalar_into(None, buff)
    sig[sec_idx][0] = crypto.sc_sub_into(None, h, c_sum)
    sig[sec_idx][1] = crypto.sc_mulsub_into(None, sig[sec_idx][0], sec, k)
    return sig