Example #1
0
    def _finish_greater_than_equal(self, results, l):
        """Finish the calculation.

        *results* starts with the opened value ``T`` followed by the bit
        values ``b_i``. The first *l* of them are combined with the bits
        of ``T`` into pairs ``(c_i, T_i)``, where ``c_i`` wraps
        ``b_i ^ T_i`` in a GF256 :class:`Share`. The pairs are reduced
        with :meth:`_diamond`. The return value is ``T_l ^ (b_l ^ x)``,
        where ``x`` is the second component of the reduced pair.
        """
        opened, masks = results[0], results[1:]

        # Start from the pair (0, 0) and add one pair per bit.
        pairs = [(GF256(0), GF256(0))]
        for idx, mask_bit in enumerate(masks[:l]):
            t_bit = GF256(opened.bit(idx))
            pairs.append((Share(self, GF256, mask_bit ^ t_bit), t_bit))

        # The diamond operator is non-commutative, so neighbours are
        # combined pairwise, level by level, strictly left to right. An
        # odd element at the end of a level is carried over unchanged.
        while len(pairs) > 1:
            level = [self._diamond(pairs[pos], pairs[pos + 1])
                     for pos in range(0, len(pairs) - 1, 2)]
            if len(pairs) % 2 == 1:
                level.append(pairs[-1])
            pairs = level

        return GF256(opened.bit(l)) ^ (masks[l] ^ pairs[0][1])
Example #2
0
    def test_prss_share_random_bit(self, runtime):
        """Check that a binary PRSS share over GF256 opens to 0 or 1."""
        share = runtime.prss_share_random(field=GF256, binary=True)
        self.assert_type(share, Share)

        # Returning the opened value makes the test wait for the check.
        result = runtime.open(share)
        result.addCallback(self.assertIn, [GF256(0), GF256(1)])
        return result
Example #3
0
    def test_add(self):
        """Test addition, plain and in place."""
        cases = [(0, 0, 0), (1, 1, 0), (100, 100, 0), (0, 1, 1), (1, 2, 3)]
        for left, right, expected in cases:
            self._test_binary_operator(operator.add, left, right, expected)

        # In-place addition of an element to itself gives zero.
        value = GF256(10)
        value += GF256(10)
        self.assertEquals(value, GF256(0))
Example #4
0
    def test_prss_share_random_multi_bit(self, runtime):
        """Tests the sharing of several 0/1 GF256 elements using PRSS.

        Each of the 8 shares is opened and checked to be 0 or 1. The
        opened values are gathered and returned, so the test only
        completes after every check has run and any failed assertion is
        reported. Previously only the unopened shares were returned.
        """
        a_list = runtime.prss_share_random_multi(field=GF256,
                                                 quantity=8,
                                                 binary=True)

        opened = []
        for a in a_list:
            self.assert_type(a, Share)
            opened_a = runtime.open(a)
            opened_a.addCallback(self.assertIn, [GF256(0), GF256(1)])
            opened.append(opened_a)

        return gather_shares(opened)
Example #5
0
    def key_expansion(self, key, new_length=None):
        """Rijndael key expansion.

        Input and output are lists of 4-byte columns (words).
        *new_length* is the round for which the key should be expanded.
        If omitted, the key is expanded for all rounds.

        Expansion continues from ``len(key)``, so a partially expanded
        key can be passed in to expand it further. *key* is extended in
        place and is also the returned list.
        """

        assert len(key) >= self.n_k, "Wrong key size."
        assert len(key[0]) == 4, "Key must consist of 4-byte words."

        expanded_key = key

        if new_length is None:
            new_length = self.rounds

        for i in xrange(len(key), self.n_b * (new_length + 1)):
            temp = list(expanded_key[i - 1])

            if i % self.n_k == 0:
                # Rotate the word, substitute its bytes and add the
                # round constant 2 ** (i // n_k - 1). Floor division
                # keeps the exponent an int under true division.
                temp.append(temp.pop(0))
                self.byte_sub([temp])
                temp[0] += GF256(2) ** (i // self.n_k - 1)
            elif self.n_k > 6 and i % self.n_k == 4:
                # Extra byte substitution for keys longer than 6 words.
                self.byte_sub([temp])

            new_word = [expanded_key[i - self.n_k][j] + temp[j]
                        for j in xrange(4)]
            expanded_key.append(new_word)

        return expanded_key
Example #6
0
def protocol(rt):
    l = 7

    rand = dict([(i, random.Random(i)) for i in players])

    inputs = []
    for i in range(count):
        input = dict([(j, rand[j].randint(0, pow(2, l))) for j in players])
        inputs.append(input)

    # Fixed input for easier debugging
    inputs.append({1: 20, 2: 25, 3: 0})

    print "I am player %d, will compare %d numbers" % (id, len(inputs))

    bits = []
    for input in inputs:
        x, y, z = rt.shamir_share([1, 2, 3], Zp, input[id])
        bit = rt.open(x >= y)
        bit.addCallback(lambda b: b == GF256(1))
        bit.addCallback(lambda b, x, y: "%3d >= %3d: %-5s (%s)" \
                            % (x, y, b, b == (x >= y)), input[1], input[2])
        dprint("%s", bit)
        bits.append(bit)

    results = gatherResults(bits)
    results.addCallback(lambda _: rt.shutdown())
Example #7
0
    def test_key_expansion(self, runtime):
        """Compare the shared 256-bit key schedule with ``rijndael``."""
        aes = AES(runtime, 256, quiet=True)
        key = []
        ascii_key = []

        for word_index in xrange(8):
            word = []
            key.append(word)
            for byte_index in xrange(4):
                value = 15 * word_index + byte_index
                word.append(Share(runtime, GF256, GF256(value)))
                ascii_key.append(chr(value))

        result = aes.key_expansion(key)

        # The reference schedule stores each word as an integer. Split
        # it into four bytes, most significant byte first.
        expected_result = []
        for round_key in rijndael(ascii_key).Ke:
            for packed in round_key:
                split_word = []
                for _ in xrange(4):
                    packed, low = divmod(packed, 256)
                    split_word.insert(0, low)
                expected_result.append(split_word)

        self.verify(runtime, result, expected_result)
Example #8
0
    def test_prss_share_zero_bit(self, runtime):
        """Check that a PRSS zero-sharing over GF256 opens to zero."""
        shares = runtime.prss_share_zero(GF256, 1)
        zero = shares[0]
        self.assert_type(zero, Share)

        # Opened with threshold 2 * runtime.threshold, not the default.
        opened = runtime.open(zero, threshold=2 * runtime.threshold)
        opened.addCallback(self.assertEquals, GF256(0))
        return opened
Example #9
0
 def preprocess(self, input):
     """Normalize *input* into a list of GF256 values or shares.

     A string is converted character by character (via ``ord``) into
     GF256 shares. Anything else must already consist of GF256
     elements or shares and is returned unchanged.
     """
     if not isinstance(input, str):
         for byte in input:
             assert byte.field == GF256, \
                 "Input must be a list of GF256 elements " \
                 "or of shares thereof."
         return input
     return [Share(self.runtime, GF256, GF256(ord(c))) for c in input]
Example #10
0
    def _test_binary_operator(self, operation, a, b, expected):
        """Test C{operation} with and without coerced operands."""
        want = GF256(expected)
        # Both operands as field elements, then a plain int on the
        # right, then a plain int on the left.
        operand_pairs = [(GF256(a), GF256(b)),
                         (GF256(a), b),
                         (a, GF256(b))]
        for left, right in operand_pairs:
            self.assertEquals(operation(left, right), want)
Example #11
0
def protocol(rt):
    print "Starting protocol"
    Zp = GF(11)
    a, b, c = rt.prss_share([1, 2, 3], Zp, 0)
    x, y, z = rt.prss_share([1, 2, 3], Zp, 1)

    a_b = rt.open(rt.convert_bit_share(a, GF256))
    b_b = rt.open(rt.convert_bit_share(b, GF256))
    c_b = rt.open(rt.convert_bit_share(c, GF256))

    x_b = rt.open(rt.convert_bit_share(x, GF256))
    y_b = rt.open(rt.convert_bit_share(y, GF256))
    z_b = rt.open(rt.convert_bit_share(z, GF256))

    def check(result, variable, expected):
        if result == expected:
            print "%s: %s (correct)" % (variable, result)
        else:
            print "%s: %s (incorrect, expected %d)" \
                % (variable, result, expected)

    a_b.addCallback(check, "a_b", GF256(0))
    b_b.addCallback(check, "b_b", GF256(0))
    c_b.addCallback(check, "c_b", GF256(0))

    x_b.addCallback(check, "x_b", GF256(1))
    y_b.addCallback(check, "y_b", GF256(1))
    z_b.addCallback(check, "z_b", GF256(1))

    rt.wait_for(a_b, b_b, c_b, x_b, y_b, z_b)
Example #12
0
 def test_convert_bit_share(self, runtime):
     """Test conversion of 0/1 elements from Zp to GF256."""
     # TODO: test conversion from GF256 to Zp and between Zp and Zq
     # fields.
     def convert_and_open(value):
         share = Share(runtime, self.Zp, self.Zp(value))
         converted = runtime.convert_bit_share(share, GF256)
         self.assertEquals(converted.field, GF256)
         opened = runtime.open(converted)
         opened.addCallback(self.assertEquals, GF256(value))
         return opened

     return gatherResults([convert_and_open(bit) for bit in (0, 1)])
Example #13
0
    def synchronize(self):
        """Introduce a synchronization point.

        Returns a :class:`Deferred` which will trigger if and when all
        other players have made their calls to :meth:`synchronize`. By
        adding callbacks to the returned :class:`Deferred`, one can
        divide a protocol execution into disjoint phases.
        """
        self.increment_pc()
        # Exchange a dummy GF256(0) with every player. Only completion
        # of the exchanges matters; the gathered values are dropped.
        pending = []
        for player in self.players:
            pending.append(self._exchange_shares(player, GF256(0)))
        barrier = gather_shares(pending)
        barrier.addCallback(lambda _: None)
        return barrier
Example #14
0
    def _test_byte_sub(self, runtime, aes):
        """Run ByteSub on a 4x4 state and compare with the table ``S``."""
        values = [[60 * row + col for col in range(4)] for row in range(4)]
        # Shares are created row by row, as before.
        results = [[Share(runtime, GF256, GF256(v)) for v in row]
                   for row in values]
        expected_results = [[S[v] for v in row] for row in values]

        aes.byte_sub(results)
        self.verify(runtime, results, expected_results)
Example #15
0
def invert(rt):
    aes = AES(rt, 192, use_exponentiation=options.exponentiation)
    bytes = [
        Share(rt, GF256, GF256(random.randint(0, 255)))
        for i in range(options.count)
    ]

    start = time.time()

    done = gather_shares([aes.invert(byte) for byte in bytes])

    def finish(_):
        duration = time.time() - start
        print "Finished after %.3f s." % duration
        print "Time per inversion: %.3f ms" % (1000 * duration / options.count)
        rt.shutdown()

    done.addCallback(finish)
Example #16
0
    def prss_share_bit_double(self, field):
        """Share a random bit over *field* and GF256.

        The protocol is described in "Efficient Conversion of
        Secret-shared Values Between Different Fields" by Ivan Damgård
        and Rune Thorbek available as `Cryptology ePrint Archive,
        Report 2008/221 <http://eprint.iacr.org/2008/221>`__.

        Returns a pair: the bit shared over *field*, and the same bit
        shared over GF256.
        """
        num_players = self.num_players
        security = self.options.security_parameter
        prfs = self.players[self.id].prfs(2**security)
        prss_key = self.prss_key()

        bit_field = self.prss_share_random(field, binary=True)
        mask_field, mask_lsb = prss_lsb(num_players, self.id, field,
                                        prfs, prss_key)

        masked = self.open(bit_field + mask_field)
        # Extract least significant bit and change field to GF256.
        masked.addCallback(lambda opened: GF256(opened.value & 1))
        masked.field = GF256

        # XOR with the GF256 sharing of the mask's lsb to flip as needed.
        return (bit_field, masked ^ mask_lsb)
Example #17
0
    def prss_shamir_share_bit_double(self, field):
        """Shamir share a random bit over *field* and GF256.

        Every player Shamir-shares a random integer in
        ``[0, 2**k - 1]`` over *field* and its least significant bit
        over GF256. The sums of these sharings act as a mask ``r`` and
        its lsb. Returns the bit shared over *field* and the same bit
        shared over GF256.
        """
        n = self.num_players
        k = self.options.security_parameter
        # NOTE(review): prfs and prss_key are not used below. The calls
        # are kept in case they have side effects (e.g. on the program
        # counter); confirm before removing them.
        prfs = self.players[self.id].prfs(2**k)
        prss_key = self.prss_key()
        inputters = range(1, n + 1)

        my_random = rand.randint(0, 2**k - 1)
        shares_field = self.shamir_share(inputters, field, my_random)
        shares_lsb = self.shamir_share(inputters, GF256, my_random & 1)

        r_p = reduce(self.add, shares_field)
        r_lsb = reduce(self.add, shares_lsb)

        b_p = self.prss_share_random(field, binary=True)
        masked = self.open(b_p + r_p)
        # Extract least significant bit and change field to GF256.
        masked.addCallback(lambda opened: GF256(opened.value & 1))
        masked.field = GF256

        # XOR with the GF256 sharing of the mask's lsb to flip as needed.
        return (b_p, masked ^ r_lsb)
Example #18
0
    def decompose(byte, bits):
        # Fire c_bits[0..7] with the bits of the opened byte, least
        # significant bit first.
        # NOTE(review): the *bits* parameter is unused; c_bits comes from
        # the enclosing scope. Confirm the caller passes c_bits here.
        value = byte.value
        for position in range(8):
            c_bits[position].callback(GF256((value >> position) & 1))
Example #19
0
 def test_prss_lsb(self):
     """prss_lsb (under its fake) yields field(7) and GF256(1)."""
     share, bit = prss.prss_lsb(None, None, self.field, None, None)
     self.assertEquals(share, self.field(7))
     self.assertEquals(bit, GF256(1))
Example #20
0
 def test_neg(self):
     """Negation leaves GF256 elements unchanged."""
     for value in (0, 10, 100):
         self.assertEquals(-GF256(value), GF256(value))
Example #21
0
 def test_str(self):
     """GF256 elements print as their value in square brackets."""
     for value, text in ((0, "[0]"), (1, "[1]"), (10, "[10]")):
         self.assertEquals(str(GF256(value)), text)
Example #22
0
 def test_bit_decomposition(self, runtime):
     """Decompose 99 (binary 1100011) into bits, least significant first."""
     byte = Share(runtime, GF256, GF256(99))
     expected_bits = [1, 1, 0, 0, 0, 1, 1, 0]
     return self.verify(runtime, bit_decompose(byte), expected_bits)
Example #23
0
 def check_outputs(outputs):
     # Each entry of *outputs* must report all three opened values.
     expected_lines = ["opened a: %s" % GF256(17),
                       "opened b: %s" % GF256(40),
                       "opened c: %s" % GF256(235)]
     for output in outputs:
         for line in expected_lines:
             self.assertIn(line, output)
Example #24
0
    return convert_replicated_shamir(n, j, field, rep_shares)

def prss_multi(n, j, field, prfs, key, modulus, quantity):
    """Does the same as :meth:`prss`, but multiple times in order to
    call the PRFs less frequently.

    Each PRF output is split into *quantity* base-*modulus* digits,
    least significant first. Digit *i* of every output goes into the
    replicated sharing number *i*, and each of those sharings is passed
    through ``convert_replicated_shamir``. A list of *quantity* results
    is returned.
    """
    prf_results = random_replicated_sharing(j, prfs, key)
    rep_shares_list = [[] for _ in range(quantity)]
    for subset, result in prf_results:
        for rep_shares in rep_shares_list:
            # divmod uses floor division, so *result* stays an integer
            # even under true division (Python 3 or
            # ``from __future__ import division``), where ``/=`` would
            # turn it into a float and corrupt the remaining digits.
            result, digit = divmod(result, modulus)
            rep_shares.append((subset, digit))
    return [convert_replicated_shamir(n, j, field, rep_shares)
            for rep_shares in rep_shares_list]

@fake(lambda n, j, field, prfs, key: (field(7), GF256(1)))
def prss_lsb(n, j, field, prfs, key):
    """Share a pseudo-random number and its least significant bit.

    The random number is shared over *field* and its least significant
    bit is shared over :class:`viff.field.GF256`. It is important the
    *prfs* generate numbers much less than the size of *field* -- we
    must be able to do an addition for each PRF without overflow in
    *field*.

    >>> from field import GF
    >>> Zp = GF(23)
    >>> prfs = {frozenset([1,2]): PRF("a", 7),
    ...         frozenset([1,3]): PRF("b", 7),
    ...         frozenset([2,3]): PRF("c", 7)}
    >>> prss_lsb(3, 1, Zp, prfs, "key")
Example #25
0
 def test_construct(self):
     """Values of 256 and above wrap around on construction."""
     for raw, reduced in ((256, 0), (257, 1)):
         self.assertEquals(GF256(raw), GF256(reduced))
Example #26
0
 def test_invert(self):
     """Test inverse operation, including inverting zero."""
     def invert_zero():
         return ~GF256(0)
     self.assertRaises(ZeroDivisionError, invert_zero)
     self.assertEquals(~GF256(1), GF256(1))
Example #27
0
 def test_field(self):
     """Both the class and its instances report GF256 as their field."""
     for obj in (GF256, GF256(10)):
         self.assertIdentical(obj.field, GF256)
Example #28
0
    def test_div(self):
        """Test division, including division by zero."""
        self.assertRaises(ZeroDivisionError, lambda: GF256(10) / GF256(0))

        for divisor, quotient in ((10, 1), (9, 208), (5, 2)):
            self.assertEquals(GF256(10) / GF256(divisor), GF256(quotient))

        # A plain int dividend is coerced into the field.
        self.assertEquals(10 / GF256(5), GF256(2))
Example #29
0
class AES:
    """AES instantiation.

    This class is used together with a :class:`~viff.runtime.Runtime`
    object::

        aes = AES(runtime, 192)
        cleartext = [Share(runtime, GF256, GF256(0)) for i in range(128/8)]
        key = [runtime.prss_share_random(GF256) for i in range(192/8)]
        ciphertext = aes.encrypt("abcdefghijklmnop", key)
        ciphertext = aes.encrypt(cleartext, "keykeykeykeykeykeykeykey")
        ciphertext = aes.encrypt(cleartext, key)

    In every case *ciphertext* will be a list of shares over GF256.
    """

    def __init__(self, runtime, key_size, block_size=128,
                 use_exponentiation=False, quiet=False):
        """Initialize Rijndael.

        AES(runtime, key_size, block_size), whereas key size and block
        size must be given in bits. Block size defaults to 128."""

        assert key_size in [128, 192, 256], \
            "Key size must be 128, 192 or 256"
        assert block_size in [128, 192, 256], \
            "Block size be 128, 192 or 256"

        self.n_k = key_size / 32
        self.n_b = block_size / 32
        self.rounds = max(self.n_k, self.n_b) + 6
        self.runtime = runtime

        if use_exponentiation is not False:
            if (isinstance(use_exponentiation, int) and
                use_exponentiation < len(AES.exponentiation_variants)):
                use_exponentiation = \
                    AES.exponentiation_variants[use_exponentiation]
            elif use_exponentiation not in AES.exponentation_variants:
                use_exponentiation = "shortest_sequential_chain"

            if not quiet:
                print "Use %s for inversion by exponentiation." % \
                    use_exponentiation

            if use_exponentiation == "standard_square_and_multiply":
                self.invert = lambda byte: byte ** 254
            elif use_exponentiation == "shortest_chain_with_least_rounds":
                self.invert = self.invert_by_exponentiation_with_less_rounds
            elif use_exponentiation == "chain_with_least_rounds":
                self.invert = self.invert_by_exponentiation_with_least_rounds
            elif use_exponentiation == "masked":
                self.invert = self.invert_by_masked_exponentiation
            elif use_exponentiation == "masked_online":
                self.invert = self.invert_by_masked_exponentiation_online
            else:
                self.invert = self.invert_by_exponentiation
        else:
            self.invert = self.invert_by_masking

            if not quiet:
                print "Use inversion by masking."

    exponentiation_variants = ["standard_square_and_multiply",
                               "shortest_sequential_chain",
                               "shortest_chain_with_least_rounds",
                               "chain_with_least_rounds",
                               "masked",
                               "masked_online"]

    def invert_by_masking(self, byte):
        bits = bit_decompose(byte)

        for j in range(len(bits)):
            bits[j].addCallback(lambda x: 1 - x)
#            bits[j] = 1 - bits[j]

        while len(bits) > 1:
            bits.append(bits.pop(0) * bits.pop(0))

        # b == 1 if byte is 0, b == 0 else
        b = bits[0]

        r = Share(self.runtime, GF256)
        c = Share(self.runtime, GF256)

        def get_masked_byte(c_opened, r_related, c, r, byte):
            if c_opened == 0:
                r_trial = self.runtime.prss_share_random(GF256)
                c_trial = self.runtime.open((byte + b) * r_trial)
                self.runtime.schedule_callback(c_trial, get_masked_byte,
                                               r_trial, c, r, byte)
            else:
                r_related.addCallback(r.callback)
                c.callback(~c_opened)

        get_masked_byte(0, None, c, r, byte)

        # necessary to avoid communication in multiplication
        # was: return c * r - b
        result = gather_shares([c, r, b])
        result.addCallback(lambda (c, r, b): c * r - b)
        return result

    def invert_by_masked_exponentiation(self, byte):
        def add_and_multiply(masked_byte, random_powers, prep):
            masked_powers = self.runtime.powerchain(masked_byte, 7)
            byte_powers = map(operator.add, masked_powers, random_powers)[1:]
            if prep:
                byte_powers = [Share(self.runtime, GF256, value)
                               for value in byte_powers]
            while len(byte_powers) > 1:
                byte_powers.append(byte_powers.pop(0) * byte_powers.pop(0))
            return byte_powers[0]

        random_powers, prep = self.runtime.prss_powerchain()
        masked_byte = self.runtime.open(byte + random_powers[0])
        return self.runtime.schedule_callback(
            masked_byte, add_and_multiply, random_powers, prep)

    # constants for efficient computation of x^2, x^4, x^8 etc.
    powers_of_two = [[GF256(2**j)**(2**i) for j in range(8)] for i in range(8)]

    def invert_by_masked_exponentiation_online(self, byte):
        bits = bit_decompose(byte)
        byte_powers = []

        for i in range(1,8):
            byte_powers.append(self.runtime.lin_comb(AES.powers_of_two[i], bits))

        while len(byte_powers) > 1:
            byte_powers.append(byte_powers.pop(0) * byte_powers.pop(0))

        return byte_powers[0]

    def invert_by_exponentiation(self, byte):
        byte_2 = byte * byte
        byte_3 = byte_2 * byte
        byte_6 = byte_3 * byte_3
        byte_12 = byte_6 * byte_6
        byte_15 = byte_12 * byte_3
        byte_30 = byte_15 * byte_15
        byte_60 = byte_30 * byte_30
        byte_63 = byte_60 * byte_3
        byte_126 = byte_63 * byte_63
        byte_252 = byte_126 * byte_126
        byte_254 = byte_252 * byte_2
        return byte_254

    def invert_by_exponentiation_with_less_rounds(self, byte):
        byte_2 = byte * byte
        byte_4 = byte_2 * byte_2
        byte_8 = byte_4 * byte_4
        byte_9 = byte_8 * byte
        byte_18 = byte_9 * byte_9
        byte_19 = byte_18 * byte
        byte_36 = byte_18 * byte_18
        byte_55 = byte_36 * byte_19
        byte_72 = byte_36 * byte_36
        byte_127 = byte_72 * byte_55
        byte_254 = byte_127 * byte_127
        return byte_254

    def invert_by_exponentiation_with_least_rounds(self, byte):
        byte_2 = byte * byte
        byte_3 = byte_2 * byte
        byte_4 = byte_2 * byte_2
        byte_7 = byte_4 * byte_3
        byte_8 = byte_4 * byte_4
        byte_15 = byte_8 * byte_7
        byte_16 = byte_8 * byte_8
        byte_31 = byte_16 * byte_15
        byte_32 = byte_16 * byte_16
        byte_63 = byte_32 * byte_31
        byte_64 = byte_32 * byte_32
        byte_127 = byte_64 * byte_63
        byte_254 = byte_127 * byte_127
        return byte_254

    # matrix for byte_sub, the last column is the translation vector
    A = Matrix([[1,0,0,0,1,1,1,1],
                [1,1,0,0,0,1,1,1],
                [1,1,1,0,0,0,1,1],
                [1,1,1,1,0,0,0,1],
                [1,1,1,1,1,0,0,0],
                [0,1,1,1,1,1,0,0],
                [0,0,1,1,1,1,1,0],
                [0,0,0,1,1,1,1,1]])

    # anticipate bit recombination
    for i, row in enumerate(A.rows):
        for j in range(len(row)):
            row[j] *= 2 ** i

    def byte_sub(self, state, use_lin_comb=True):
        """ByteSub operation of Rijndael.

        The first argument should be a matrix consisting of elements
        of GF(2^8)."""

        for h in range(len(state)):
            row = state[h]

            for i in range(len(row)):
                bits = bit_decompose(self.invert(row[i]))

                if use_lin_comb:
                    row[i] = self.runtime.lin_comb(sum(AES.A.rows, []),
                                                   bits * len(AES.A.rows))
                else:
                    # caution: order is lsb first
                    vector = AES.A * Matrix(zip(bits))
                    bits = zip(*vector.rows)[0]
                    row[i] = sum(bits)

                row[i].addCallback(lambda x: 0x63 + x)

    def shift_row(self, state):
        """Rijndael ShiftRow.

        State should be a list of 4 rows."""

        assert len(state) == 4, "Wrong state size."

        if self.n_b in [4,6]:
            offsets = [0, 1, 2, 3]
        else:
            offsets = [0, 1, 3, 4]

        for i, row in enumerate(state):
            for j in range(offsets[i]):
                row.append(row.pop(0))

    # matrix for mix_column
    C = [[2, 3, 1, 1],
         [1, 2, 3, 1],
         [1, 1, 2, 3],
         [3, 1, 1, 2]]

    C = Matrix(C)

    def mix_column(self, state, use_lin_comb=True):
        """Rijndael MixColumn.

        Input should be a list of 4 rows."""

        assert len(state) == 4, "Wrong state size."

        if use_lin_comb:
            columns = zip(*state)

            for i, row in enumerate(state):
                row[:] = [self.runtime.lin_comb(AES.C.rows[i], column)
                          for column in columns]
        else:
            state[:] = (AES.C * Matrix(state)).rows

    def add_round_key(self, state, round_key):
        """Rijndael AddRoundKey.

        State should be a list of 4 rows and round_key a list of
        4-byte columns (words)."""

        assert len(round_key) == self.n_b, "Wrong key size."
        assert len(round_key[0]) == 4, "Key must consist of 4-byte words."

        state[:] = (Matrix(state) + Matrix(zip(*round_key))).rows

    def key_expansion(self, key, new_length=None):
        """Rijndael key expansion.

        Input and output are lists of 4-byte columns (words).
        *new_length* is the round for which the key should be expanded.
        If ommitted, the key is expanded for all rounds."""

        assert len(key) >= self.n_k, "Wrong key size."
        assert len(key[0]) == 4, "Key must consist of 4-byte words."

        expanded_key = key

        if new_length == None:
            new_length = self.rounds

        for i in xrange(len(key), self.n_b * (new_length + 1)):
            temp = list(expanded_key[i - 1])

            if i % self.n_k == 0:
                temp.append(temp.pop(0))
                self.byte_sub([temp])
                temp[0] += GF256(2) ** (i / self.n_k - 1)
            elif self.n_k > 6 and i % self.n_k == 4:
                self.byte_sub([temp])

            new_word = []

            for j in xrange(4):
                new_word.append(expanded_key[i - self.n_k][j] + temp[j])

            expanded_key.append(new_word)

        return expanded_key

    def preprocess(self, input):
        if isinstance(input, str):
            return [Share(self.runtime, GF256, GF256(ord(c)))
                    for c in input]
        else:
            for byte in input:
                assert byte.field == GF256, \
                    "Input must be a list of GF256 elements " \
                    "or of shares thereof."
            return input

    def encrypt(self, cleartext, key, benchmark=False, prepare_at_once=False):
        """Rijndael encryption.

        Cleartext and key should be either a string or a list of bytes
        (possibly shared as elements of GF256)."""

        start = time.time()
        self.runtime.increment_pc()
        self.runtime.fork_pc()

        assert len(cleartext) == 4 * self.n_b, "Wrong length of cleartext."
        assert len(key) == 4 * self.n_k, "Wrong length of key."

        cleartext = self.preprocess(cleartext)
        key = self.preprocess(key)

        state = [cleartext[i::4] for i in xrange(4)]
        key = [key[4*i:4*i+4] for i in xrange(self.n_k)]

        if benchmark:
            global preparation, communication
            preparation = 0
            communication = 0

            def progress(x, i, start_round):
                time_diff = time.time() - start_round
                global communication
                communication += time_diff
                print "Round %2d: %f, %f" % \
                    (i, time_diff, time.time() - start)
                return x

            def prep_progress(i, start_round):
                time_diff = time.time() - start_round
                global preparation
                preparation += time_diff
                print "Round %2d preparation: %f, %f" % \
                    (i, time_diff, time.time() - start)
        else:
            progress = lambda x, i, start_round: x
            prep_progress = lambda i, start_round: None

        expanded_key = self.key_expansion(key[:], 0)
        self.add_round_key(state, expanded_key[0:self.n_b])

        prep_progress(0, start)

        def get_trigger(state):
            return gather_shares(reduce(operator.add, state))

        def round(_, state, i):
            start_round = time.time()

            self.key_expansion(expanded_key, i)

            self.byte_sub(state)
            self.shift_row(state)
            self.mix_column(state)
            self.add_round_key(state, expanded_key[i*self.n_b:(i+1)*self.n_b])

            if not prepare_at_once:
                trigger = get_trigger(state)
                trigger.addCallback(progress, i, time.time())

                if i < self.rounds - 1:
                    self.runtime.schedule_complex_callback(trigger, round, state, i + 1)
                else:
                    self.runtime.schedule_complex_callback(trigger, final_round, state)

            prep_progress(i, start_round)

            return _

        def final_round(_, state):
            start_round = time.time()

            self.key_expansion(expanded_key, self.rounds)

            self.byte_sub(state)
            self.shift_row(state)
            self.add_round_key(state, expanded_key[self.rounds*self.n_b:])

            trigger = get_trigger(state)
            trigger.addCallback(progress, self.rounds, time.time())

            if benchmark:
                trigger.addCallback(finish, state)

            # connect to final result
            for a, b in zip(reduce(operator.add, zip(*state)), result):
                a.addCallback(b.callback)

            prep_progress(self.rounds, start_round)

            return _

        def finish(_, state):
            print "Total preparation time: %f" % preparation
            print "Total communication time: %f" % communication

            return _

        result = [Share(self.runtime, GF256) for i in xrange(4 * self.n_b)]

        if prepare_at_once:
            for i in range(1, self.rounds):
                round(None, state, i)

            final_round(None, state)
        else:
            round(None, state, 1)

        self.runtime.unfork_pc()
        return result
Example #30
0
 def test_pow(self):
     """Test exponentiation against repeated multiplication."""
     three = GF256(3)
     self.assertEquals(three ** 3, three * three * three)
     base = GF256(27)
     half = base ** 50
     self.assertEquals(base ** 100, half * half)