예제 #1
0
 def test_misc(self):
     """Check assorted runtime helpers on secure ints, fixed-points and field elements.

     Covers raw output, _reshare, all/any, sum/prod with a start value,
     min/max/argmin/argmax/min_max, and sorted with key/reverse.
     """

     def reveal(shares):
         # Open the given secure value(s) and hand back the plain result.
         return mpc.run(mpc.output(shares))

     secint = mpc.SecInt()
     secfxp = mpc.SecFxp()
     secfld = mpc.SecFld()

     # Expectations shared by all three secure number types.
     for sectype in (secint, secfxp, secfld):
         raw_zero = mpc.run(mpc.output(sectype(0), raw=True))
         self.assertEqual(type(raw_zero), sectype.field)
         self.assertEqual(reveal(mpc._reshare([sectype(0)])), [0])
         self.assertEqual(reveal(mpc.all(sectype(1) for _ in range(5))), True)
         self.assertEqual(reveal(mpc.all([sectype(1), sectype(1), sectype(0)])), False)
         self.assertEqual(reveal(mpc.any(sectype(0) for _ in range(5))), False)
         self.assertEqual(reveal(mpc.any([sectype(0), sectype(1), sectype(1)])), True)
         self.assertEqual(reveal(mpc.sum([sectype(1)], start=1)), 2)
         self.assertEqual(reveal(mpc.prod([sectype(1)], start=1)), 1)
         self.assertEqual(reveal(mpc.sum([sectype(1)], start=sectype(1))), 2)

     # Extremes and their positions, fed from generators.
     self.assertEqual(reveal(mpc.min(secint(v) for v in range(-1, 2))), -1)
     self.assertEqual(reveal(mpc.argmin(secint(v) for v in range(-1, 2))[0]), 0)
     self.assertEqual(reveal(mpc.max(secfxp(v) for v in range(-1, 2))), 1)
     self.assertEqual(reveal(mpc.argmax(secfxp(v) for v in range(-1, 2))[0]), 2)
     self.assertEqual(reveal(list(mpc.min_max(map(secfxp, range(5))))), [0, 4])

     # Sorting by a secure key, and in reverse order.
     expected_by_key = [0, -1, 1, -2, 2, -3]
     ints = (secint(v) for v in range(-3, 3))
     self.assertEqual(reveal(mpc.sorted(ints, key=lambda a: a*(2*a+1))), expected_by_key)
     fxps = (secfxp(v) for v in range(5))
     self.assertEqual(reveal(mpc.sorted(fxps, reverse=True)), [4, 3, 2, 1, 0])

     # Sums and products over iterables.
     self.assertEqual(reveal(mpc.sum(map(secint, range(5)))), 10)
     self.assertEqual(reveal(mpc.sum([secfxp(2.75)], start=3.125)), 5.875)
     self.assertEqual(int(reveal(mpc.prod(map(secfxp, range(1, 5))))), 24)
     self.assertEqual(int(reveal(mpc.prod([secfxp(1.414214)]*4))), 4)
예제 #2
0
파일: id3gini.py 프로젝트: xujiangyu/mpyc
async def id3(T, R) -> asyncio.Future:
    """Recursively build the decision (sub)tree for T using the attributes in R.

    Returns either an opened class label (leaf) or a pair
    (attribute, list of subtrees). S, C, args, GI and SecureFraction are
    taken from the enclosing module.

    NOTE(review): T is presumably a secret-shared 0/1 vector selecting the
    samples of this subtree, and R a set of unused attributes -- confirm
    against the caller.
    """
    class_sizes = [mpc.in_prod(T, v) for v in S[C]]
    best, largest = mpc.argmax(class_sizes)
    total = mpc.sum(class_sizes)
    stop = (total <= int(args.epsilon * len(T))) + (largest == total)

    # Leaf: no attributes left, or the (publicly tested) stop value is nonzero.
    if not R or not await mpc.is_zero_public(stop):
        i = await mpc.output(best)
        logging.info(f'Leaf node label {i}')
        return i

    # Internal node: split T on every candidate attribute and open only the
    # index of the candidate with the best gain.
    splits = [[mpc.schur_prod(T, v) for v in S[A]] for A in R]
    gains = [GI(mpc.matrix_prod(T_A, S[C], True)) for T_A in splits]
    k = await mpc.output(mpc.argmax(gains, key=SecureFraction)[0])
    chosen_split = splits[k]
    del splits, gains  # release memory
    A = list(R)[k]
    logging.info(f'Attribute node {A}')

    if args.parallel_subtrees:
        pending = [id3(Tj, R.difference([A])) for Tj in chosen_split]
        subtrees = await mpc.gather(pending)
    else:
        subtrees = [await id3(Tj, R.difference([A])) for Tj in chosen_split]
    return A, subtrees
예제 #3
0
async def main():
    """Classify a batch of images with a four-layer network under MPC.

    Command-line options select the batch size, the offset of the batch in
    the data files, and how the sign activation is computed (Legendre-based
    via bsgn_*/vector_bsgn_*, or MPyC integer comparison, element by element
    or vectorized). Labels and images are read from the gzip files under
    data/cnn; weights and biases come from load_W/load_b, defined elsewhere
    in this module. For each image the predicted label and the opened output
    vector are printed.

    Sets the module-level name ``secint`` to the secure integer type in use.
    """
    global secint

    parser = argparse.ArgumentParser()
    parser.add_argument('-b',
                        '--batch-size',
                        type=int,
                        metavar='B',
                        help='number of images to classify')
    parser.add_argument(
        '-o',
        '--offset',
        type=int,
        metavar='O',
        help='offset for batch (otherwise random in [0,10000-B])')
    parser.add_argument(
        '-d',
        '--d-k-star',
        type=int,
        metavar='D',
        help='k=D=0,1,2 for Legendre-based comparison using d_k^*')
    parser.add_argument('--no-legendre',
                        action='store_true',
                        help='disable Legendre-based comparison')
    parser.add_argument('--no-vectorization',
                        action='store_true',
                        help='disable vectorization of comparisons')
    parser.set_defaults(batch_size=1, offset=-1, d_k_star=1)
    args = parser.parse_args()

    batch_size = args.batch_size
    offset = args.offset
    # bsgn/vector_bsgn are only bound (and only used) in Legendre mode.
    if args.no_legendre:
        secint = mpc.SecInt(14)  # using vectorized MPyC integer comparison
    elif args.d_k_star == 0:
        secint = mpc.SecInt(
            14, p=3546374752298322551)  # Legendre-0 range [-134, 134]
        bsgn = bsgn_0
        vector_bsgn = vector_bsgn_0
    elif args.d_k_star == 1:
        secint = mpc.SecInt(
            14, p=9409569905028393239)  # Legendre-1 range [-383, 383]
        bsgn = bsgn_1
        vector_bsgn = vector_bsgn_1
    else:
        secint = mpc.SecInt(
            14, p=15569949805843283171)  # Legendre-2 range [-594, 594]
        bsgn = bsgn_2
        vector_bsgn = vector_bsgn_2

    one_by_one = args.no_vectorization

    await mpc.start()

    if offset < 0:
        # Party 0 picks the random offset and sends it to all parties so
        # that everyone works on the same batch.
        offset = random.randrange(10001 - batch_size) if mpc.pid == 0 else None
        offset = await mpc.transfer(offset, senders=0)

    logging.info('--------------- INPUT   -------------')
    print(
        f'Type = {secint.__name__}, range = ({offset}, {offset + batch_size})')
    # read batch_size labels and images at given offset; the files are
    # closed again as soon as the needed bytes have been extracted
    with gzip.open(os.path.join('data', 'cnn',
                                't10k-labels-idx1-ubyte.gz')) as df:
        # labels start after the first 8 bytes, one byte per label
        d = df.read()[8 + offset:8 + offset + batch_size]
    labels = list(map(int, d))
    print('Labels:', labels)
    with gzip.open(os.path.join('data', 'cnn',
                                't10k-images-idx3-ubyte.gz')) as df:
        # images start after the first 16 bytes, 28*28 bytes per image
        d = df.read()[16 + offset * 28**2:16 + (offset + batch_size) * 28**2]
    L = np.array(list(d)).reshape(batch_size, 28**2)
    if batch_size == 1:
        # Show the single image as a 28x28 grid of 0/1 characters.
        x = np.array(L[0]).reshape(28, 28)
        print(
            np.array2string(np.vectorize(lambda a: int(bool((a / 255))))(x),
                            separator=''))

    L = np.vectorize(lambda a: secint(int(a)))(L).tolist()

    logging.info('--------------- LAYER 1 -------------')
    logging.info('- - - - - - - - fc      - - - - - - -')
    L = mpc.matrix_prod(L, load_W('fc1'))
    L = mpc.matrix_add(L, [load_b('fc1')] * len(L))
    logging.info('- - - - - - - - bsgn    - - - - - - -')
    # The first layer always uses the comparison-based sign.
    if one_by_one:
        L = np.vectorize(lambda a: (a >= 0) * 2 - 1)(L).tolist()
    else:
        L = [vector_sge(_) for _ in L]
    await mpc.barrier()

    # Layers 2 and 3 are processed identically: fully connected layer,
    # barrier, sign activation, barrier.
    for layer in (2, 3):
        logging.info(f'--------------- LAYER {layer} -------------')
        logging.info('- - - - - - - - fc      - - - - - - -')
        L = mpc.matrix_prod(L, load_W(f'fc{layer}'))
        L = mpc.matrix_add(L, [load_b(f'fc{layer}')] * len(L))
        await mpc.barrier()
        logging.info('- - - - - - - - bsgn    - - - - - - -')
        if args.no_legendre:
            # NOTE(review): presumably the values fit in 10 bits here, which
            # makes the integer comparisons cheaper -- confirm.
            secint.bit_length = 10
            if one_by_one:
                activate = np.vectorize(lambda a: (a >= 0) * 2 - 1)
                L = activate(L).tolist()
            else:
                L = [vector_sge(_) for _ in L]
        else:
            if one_by_one:
                activate = np.vectorize(bsgn)
                L = activate(L).tolist()
            else:
                L = [vector_bsgn(_) for _ in L]
        await mpc.barrier()

    logging.info('--------------- LAYER 4 -------------')
    logging.info('- - - - - - - - fc      - - - - - - -')
    L = mpc.matrix_prod(L, load_W('fc4'))
    L = mpc.matrix_add(L, [load_b('fc4')] * len(L))
    await mpc.barrier()

    logging.info('--------------- OUTPUT  -------------')
    if args.no_legendre:
        # Restore the bit length the type was created with before argmax.
        secint.bit_length = 14
    for i in range(batch_size):
        prediction = await mpc.output(mpc.argmax(L[i])[0])
        error = '******* ERROR *******' if prediction != labels[i] else ''
        print(
            f'Image #{offset+i} with label {labels[i]}: {prediction} predicted. {error}'
        )
        print(await mpc.output(L[i]))

    await mpc.shutdown()
예제 #4
0
 def test_misc(self):
     """Check assorted runtime helpers on secure ints, fixed-points and field elements.

     Covers raw output, _reshare, all/any, sum/prod with a start value,
     find (with e, f, cs_f and bits options), min/max/argmin/argmax/min_max,
     and sorted with key/reverse.
     """

     def reveal(shares):
         # Open the given secure value(s) and hand back the plain result.
         return mpc.run(mpc.output(shares))

     secint = mpc.SecInt()
     secfxp = mpc.SecFxp()
     secfld = mpc.SecFld()

     # Expectations shared by all three secure number types.
     for sectype in (secint, secfxp, secfld):
         raw_zero = mpc.run(mpc.output(sectype(0), raw=True))
         self.assertEqual(type(raw_zero), sectype.field)
         self.assertEqual(reveal(mpc._reshare([sectype(0)])), [0])
         self.assertEqual(reveal(mpc.all(sectype(1) for _ in range(5))), True)
         self.assertEqual(
             reveal(mpc.all([sectype(1), sectype(1), sectype(0)])), False)
         self.assertEqual(reveal(mpc.any(sectype(0) for _ in range(5))), False)
         self.assertEqual(
             reveal(mpc.any([sectype(0), sectype(1), sectype(1)])), True)
         self.assertEqual(reveal(mpc.sum([sectype(1)], start=1)), 2)
         self.assertEqual(reveal(mpc.prod([sectype(1)], start=1)), 1)
         self.assertEqual(reveal(mpc.sum([sectype(1)], start=sectype(1))), 2)
         self.assertEqual(reveal(mpc.find([sectype(1)], 0, e=-1)), -1)
         self.assertEqual(reveal(mpc.find([sectype(1)], 1)), 0)
         self.assertEqual(
             reveal(mpc.find([sectype(1)], 1, f=lambda i: i)), 0)

     # Extremes and their positions, fed from generators.
     self.assertEqual(reveal(mpc.min(secint(v) for v in range(-1, 2))), -1)
     self.assertEqual(
         reveal(mpc.argmin(secint(v) for v in range(-1, 2))[0]), 0)
     self.assertEqual(reveal(mpc.max(secfxp(v) for v in range(-1, 2))), 1)
     self.assertEqual(
         reveal(mpc.argmax(secfxp(v) for v in range(-1, 2))[0]), 2)
     self.assertEqual(
         reveal(list(mpc.min_max(map(secfxp, range(5))))), [0, 4])

     # Sorting by a secure key, and in reverse order.
     expected_by_key = [0, -1, 1, -2, 2, -3]
     ints = (secint(v) for v in range(-3, 3))
     self.assertEqual(
         reveal(mpc.sorted(ints, key=lambda a: a * (2 * a + 1))),
         expected_by_key)
     fxps = (secfxp(v) for v in range(5))
     self.assertEqual(reveal(mpc.sorted(fxps, reverse=True)), [4, 3, 2, 1, 0])

     # Sums and products over iterables.
     self.assertEqual(reveal(mpc.sum(map(secint, range(5)))), 10)
     self.assertEqual(reveal(mpc.sum([secfxp(2.75)], start=3.125)), 5.875)
     self.assertEqual(int(reveal(mpc.prod(map(secfxp, range(1, 5))))), 24)
     self.assertEqual(int(reveal(mpc.prod([secfxp(1.414214)] * 4))), 4)

     # find() on an empty list yields plain values, so nothing is opened.
     self.assertEqual(mpc.find([], 0), 0)
     self.assertEqual(mpc.find([], 0, e=None), (1, 0))
     self.assertEqual(
         reveal(list(mpc.find([secfld(1)], 1, e=None))), [0, 0])
     self.assertEqual(reveal(mpc.find([secfld(2)], 2, bits=False)), 0)

     # find() with a custom index map f, and with its cs_f counterpart.
     def index_map(i):
         return [i**2, 3**i]

     def cond_step(b, i):
         return [b * (2 * i + 1) + i**2, (b * 2 + 1) * 3**i]

     haystack = [secint(v) for v in range(5)]
     self.assertEqual(
         reveal(mpc.find(haystack, 2, bits=False, f=index_map)), [4, 9])
     self.assertEqual(
         reveal(mpc.find(haystack, 2, bits=False, cs_f=cond_step)), [4, 9])
예제 #5
0
    def test_secfxp(self):
        """Check arithmetic, comparison and conversion on secure fixed-point numbers.

        The first part uses the default SecFxp() type for input and to_bits;
        the loop then repeats a set of checks for SecFxp(2 * f) with
        f in 8, 16, 32, 64: addition, multiplication, vector/matrix products,
        comparisons and Boolean combinations, division (within a tolerance),
        sgn/abs, min/max/argmin/argmax, and %, // and divmod.
        """
        secfxp = mpc.SecFxp()
        self.assertEqual(
            mpc.run(mpc.output(mpc.input(secfxp(7.75), senders=0))), 7.75)
        c = mpc.to_bits(secfxp(0),
                        0)  # mpc.output() only works for nonempty lists
        self.assertEqual(c, [])
        # As the expected lists below show, to_bits yields 32 bits for the
        # default type, least significant bit first: secfxp(0.5) sets index 15
        # and secfxp(1) sets index 16.
        c = mpc.run(mpc.output(mpc.to_bits(secfxp(0))))
        self.assertEqual(c, [0.0] * 32)
        c = mpc.run(mpc.output(mpc.to_bits(secfxp(1))))
        self.assertEqual(c, [0.0] * 16 + [1.0] + [0.0] * 15)
        c = mpc.run(mpc.output(mpc.to_bits(secfxp(0.5))))
        self.assertEqual(c, [0.0] * 15 + [1.0] + [0.0] * 16)
        c = mpc.run(mpc.output(mpc.to_bits(secfxp(8113))))
        self.assertEqual(c, [0.0] * 16 +
                         [1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0])
        c = mpc.run(mpc.output(mpc.to_bits(secfxp(2**15 - 1))))
        self.assertEqual(c, [0] * 16 + [1] * 15 + [0])
        c = mpc.run(mpc.output(mpc.to_bits(secfxp(-1))))
        self.assertEqual(c, [0] * 16 + [1] * 16)
        c = mpc.run(mpc.output(mpc.to_bits(secfxp(-2**15))))
        self.assertEqual(c, [0] * 31 + [1])

        # NOTE(review): f is presumably the number of fractional bits of
        # SecFxp(2 * f); the checks below use 2**-f as the smallest step.
        for f in [8, 16, 32, 64]:
            secfxp = mpc.SecFxp(2 * f)
            c = mpc.run(mpc.output(secfxp(1) + secfxp(1)))
            self.assertEqual(c, 2)
            c = mpc.run(mpc.output(secfxp(2**-f) + secfxp(1)))
            if f != 64:  # NB: 1 + 2**-64 == 1 in Python
                self.assertEqual(c, 1 + 2**-f)
            self.assertEqual(mpc.run(mpc.output(secfxp(0.5) * secfxp(2.0))), 1)
            self.assertEqual(mpc.run(mpc.output(secfxp(2.0) * secfxp(0.5))), 1)
            c = mpc.run(
                mpc.output(
                    secfxp(2**(f // 2 - 1) - 0.5) *
                    secfxp(-2**(f // 2) + 0.5)))
            self.assertEqual(c, -2**(f - 1) + 1.5 * 2**(f // 2 - 1) - 0.25)

            s = [10.75, -3.375, 0.125, -0.125]
            self.assertEqual(mpc.run(mpc.output(list(map(secfxp, s)))), s)

            # a, b, c, d are the secure counterparts of the plain values in s;
            # t holds the plain expected result of each check.
            s = [10.5, -3.25, 0.125, -0.125]
            a, b, c, d = list(map(secfxp, s))
            t = [v * v for v in s]
            self.assertEqual(mpc.run(mpc.output([a * a, b * b, c * c, d * d])),
                             t)
            x = [a, b, c, d]
            self.assertEqual(mpc.run(mpc.output(mpc.schur_prod(x, x))), t)
            self.assertEqual(mpc.run(mpc.output(mpc.schur_prod(x, x[:]))), t)
            t = sum(t)
            self.assertEqual(mpc.run(mpc.output(mpc.in_prod(x, x))), t)
            self.assertEqual(mpc.run(mpc.output(mpc.in_prod(x, x[:]))), t)
            self.assertEqual(
                mpc.run(mpc.output(mpc.matrix_prod([x], [x], True)[0])), [t])
            u = mpc.unit_vector(secfxp(3), 4)
            self.assertEqual(
                mpc.run(mpc.output(mpc.matrix_prod([x], [u], True)[0])),
                [s[3]])
            self.assertEqual(
                mpc.run(mpc.output(mpc.matrix_prod([u], [x], True)[0])),
                [s[3]])
            self.assertEqual(
                mpc.run(mpc.output(mpc.gauss([[a]], b, [a], [b])[0])), [0])
            # The comprehensions rebind a, b, c, d to the plain values of s
            # in their own scope only, so the same expressions can be reused
            # to compute the expected plain results.
            t = [_ for a, b, c, d in [s] for _ in [a + b, a * b, a - b]]
            self.assertEqual(mpc.run(mpc.output([a + b, a * b, a - b])), t)
            t = [
                _ for a, b, c, d in [s]
                for _ in [(a + b)**2, (a + b)**2 + 3 * c]
            ]
            self.assertEqual(
                mpc.run(mpc.output([(a + b)**2, (a + b)**2 + 3 * c])), t)
            t = [_ for a, b, c, d in [s] for _ in [a < b, b < c, c < d]]
            self.assertEqual(mpc.run(mpc.output([a < b, b < c, c < d])), t)
            t = s[0] < s[1] and s[1] < s[2]
            self.assertEqual(mpc.run(mpc.output((a < b) & (b < c))), t)
            t = s[0] < s[1] or s[1] < s[2]
            self.assertEqual(mpc.run(mpc.output((a < b) | (b < c))), t)
            t = (int(s[0] < s[1]) ^ int(s[1] < s[2]))
            self.assertEqual(mpc.run(mpc.output((a < b) ^ (b < c))), t)
            t = (int(not s[0] < s[1]) ^ int(s[1] < s[2]))
            # Parentheses around (b < c) are required: comparison operators
            # bind less tightly than ^, so "~(a < b) ^ b < c" would be
            # evaluated as "(~(a < b) ^ b) < c".
            self.assertEqual(mpc.run(mpc.output(~(a < b) ^ (b < c))), t)
            t = [s[0] > 1, 10 * s[1] < 5, 10 * s[0] == 5]
            self.assertEqual(
                mpc.run(mpc.output([a > 1, 10 * b < 5, 10 * a == 5])), t)

            # Division results are only compared within a tolerance that
            # depends on f.
            s[3] = -0.120
            d = secfxp(s[3])
            t = s[3] / 0.25
            self.assertAlmostEqual(mpc.run(mpc.output(d / 0.25)),
                                   t,
                                   delta=2**(1 - f))
            t = round(s[3] / s[2] + s[0])
            self.assertEqual(round(mpc.run(mpc.output(d / c + a))), t)
            t = ((s[0] + s[1])**2 + 3 * s[2]) / s[2]
            self.assertAlmostEqual(mpc.run(mpc.output(
                ((a + b)**2 + 3 * c) / c)),
                                   t,
                                   delta=2**(8 - f))
            t = 1 / s[3]
            self.assertAlmostEqual(mpc.run(mpc.output(1 / d)),
                                   t,
                                   delta=2**(6 - f))
            t = s[2] / s[3]
            self.assertAlmostEqual(mpc.run(mpc.output(c / d)),
                                   t,
                                   delta=2**(3 - f))
            t = -s[3] / s[2]
            self.assertAlmostEqual(mpc.run(mpc.output(-d / c)),
                                   t,
                                   delta=2**(3 - f))

            self.assertEqual(mpc.run(mpc.output(mpc.sgn(+a))), s[0] > 0)
            self.assertEqual(mpc.run(mpc.output(mpc.sgn(-a))), -(s[0] > 0))
            self.assertEqual(mpc.run(mpc.output(mpc.sgn(secfxp(0)))), 0)
            self.assertEqual(mpc.run(mpc.output(abs(secfxp(-1.5)))), 1.5)

            self.assertEqual(mpc.run(mpc.output(mpc.min(a, b, c, d))), min(s))
            self.assertEqual(mpc.run(mpc.output(mpc.min(a, 0))), min(s[0], 0))
            self.assertEqual(mpc.run(mpc.output(mpc.min(0, b))), min(0, s[1]))
            self.assertEqual(mpc.run(mpc.output(mpc.max(a, b, c, d))), max(s))
            self.assertEqual(mpc.run(mpc.output(mpc.max(a, 0))), max(s[0], 0))
            self.assertEqual(mpc.run(mpc.output(mpc.max(0, b))), max(0, s[1]))
            self.assertEqual(
                mpc.run(mpc.output(list(mpc.min_max(a, b, c, d)))),
                [min(s), max(s)])
            self.assertEqual(mpc.run(mpc.output(mpc.argmin([a, b, c, d])[0])),
                             1)
            self.assertEqual(
                mpc.run(mpc.output(mpc.argmin([a, b], key=operator.neg)[1])),
                max(s))
            self.assertEqual(mpc.run(mpc.output(mpc.argmax([a, b, c, d])[0])),
                             0)
            self.assertEqual(
                mpc.run(mpc.output(mpc.argmax([a, b], key=operator.neg)[1])),
                min(s))

            # Modulo, floor division and divmod with public divisors.
            self.assertEqual(mpc.run(mpc.output(secfxp(5) % 2)), 1)
            self.assertEqual(mpc.run(mpc.output(secfxp(1) % 2**(1 - f))), 0)
            self.assertEqual(mpc.run(mpc.output(secfxp(2**-f) % 2**(1 - f))),
                             2**-f)
            self.assertEqual(
                mpc.run(mpc.output(secfxp(2 * 2**-f) % 2**(1 - f))), 0)
            self.assertEqual(mpc.run(mpc.output(secfxp(1) // 2**(1 - f))),
                             2**(f - 1))
            self.assertEqual(mpc.run(mpc.output(secfxp(27.0) % 7.0)), 6.0)
            self.assertEqual(mpc.run(mpc.output(secfxp(-27.0) // 7.0)), -4.0)
            self.assertEqual(
                mpc.run(mpc.output(list(divmod(secfxp(27.0), 6.0)))),
                [4.0, 3.0])
            self.assertEqual(mpc.run(mpc.output(secfxp(21.5) % 7.5)), 6.5)
            self.assertEqual(mpc.run(mpc.output(secfxp(-21.5) // 7.5)), -3.0)
            self.assertEqual(
                mpc.run(mpc.output(list(divmod(secfxp(21.5), 0.5)))),
                [43.0, 0.0])
예제 #6
0
async def main():
    """Classify a batch of images with a convolutional network under MPC.

    Command line: an optional first argument k gives the batch size; a
    fractional part of .5 (e.g. 1.5) selects secure fixed-point numbers
    instead of secure integers. An optional second argument is the offset
    of the batch in the data files; without it, party 0 draws a random
    offset and sends it to all parties. Labels and images are read from the
    gzip files under data/cnn. The network layers (convolvetensor, maxpool,
    ReLU, tensormatrix_prod) and the parameter loader are defined elsewhere
    in this module. For each image the predicted label and the opened output
    vector are printed.

    Sets the module-level name ``secnum`` to the secure number type in use.
    """
    global secnum

    k = 1 if len(sys.argv) == 1 else float(sys.argv[1])
    if k - int(k) == 0.5:
        secnum = mpc.SecFxp(10, 4)
    else:
        secnum = mpc.SecInt(37)
    # Subtracting 0.01 makes x.5 round down to x, independent of the
    # round-half-to-even rule.
    batch_size = round(k - 0.01)
    # The type is fixed from here on, so test its kind only once.
    is_secint = secnum.__name__.startswith('SecInt')

    await mpc.start()

    if len(sys.argv) <= 2:
        offset = random.randrange(10001 - batch_size) if mpc.pid == 0 else None
        offset = await mpc.transfer(offset, senders=0)
    else:
        offset = int(sys.argv[2])

    f = 6  # inputs are scaled by 1 << f; f is also passed to load()

    logging.info('--------------- INPUT   -------------')
    print(
        f'Type = {secnum.__name__}, range = ({offset}, {offset + batch_size})')
    # read batch_size labels and images at given offset; the files are
    # closed again as soon as the needed bytes have been extracted
    with gzip.open(os.path.join('data', 'cnn',
                                't10k-labels-idx1-ubyte.gz')) as df:
        # labels start after the first 8 bytes, one byte per label
        d = df.read()[8 + offset:8 + offset + batch_size]
    labels = list(map(int, d))
    print('Labels:', labels)
    with gzip.open(os.path.join('data', 'cnn',
                                't10k-images-idx3-ubyte.gz')) as df:
        # images start after the first 16 bytes, 28*28 bytes per image
        d = df.read()[16 + offset * 28**2:16 + (offset + batch_size) * 28**2]
    x = list(map(lambda a: a / 255, d))
    x = np.array(x).reshape(batch_size, 1, 28, 28)
    if batch_size == 1:
        # Show the single image as a 28x28 grid of 0/1 values.
        print(np.vectorize(lambda a: int(bool(a)))(x[0, 0]))
    x = scale_to_int(1 << f)(x)

    # NOTE(review): the bit_length assignments below presumably bound the
    # size of the intermediate values for the secure integer type, to keep
    # the comparisons in maxpool/ReLU/argmax cheap -- confirm.
    logging.info('--------------- LAYER 1 -------------')
    W, b = load('conv1', f)
    x = convolvetensor(x, W, b)
    await mpc.barrier()
    if is_secint:
        secnum.bit_length = 16
    x = maxpool(x)
    await mpc.barrier()
    x = ReLU(x)
    await mpc.barrier()

    logging.info('--------------- LAYER 2 -------------')
    W, b = load('conv2', f, 3)
    x = convolvetensor(x, W, b)
    await mpc.barrier()
    if is_secint:
        secnum.bit_length = 23
    x = maxpool(x)
    await mpc.barrier()
    x = ReLU(x)
    await mpc.barrier()

    logging.info('--------------- LAYER 3 -------------')
    x = x.reshape(batch_size, 64 * 7**2)
    W, b = load('fc1', f, 4)
    x = tensormatrix_prod(x, W, b)
    if is_secint:
        secnum.bit_length = 30
    x = ReLU(x)
    await mpc.barrier()

    logging.info('--------------- LAYER 4 -------------')
    W, b = load('fc2', f, 5)
    x = tensormatrix_prod(x, W, b)

    logging.info('--------------- OUTPUT  -------------')
    if is_secint:
        # Back to the bit length the type was created with.
        secnum.bit_length = 37
    for i in range(batch_size):
        prediction = int(await mpc.output(mpc.argmax(x[i])[0]))
        error = '******* ERROR *******' if prediction != labels[i] else ''
        print(
            f'Image #{offset+i} with label {labels[i]}: {prediction} predicted. {error}'
        )
        print(await mpc.output(x[i]))

    await mpc.shutdown()