def FLRound(x, mode):
    """Round a secret float to an integral value, keeping float output.

    *mode*: 0 -> floor, 1 -> ceil, -1 -> trunc (round toward zero).

    :param x: float object exposing mantissa ``v``, exponent ``p``, zero
        flag ``z``, sign bit ``s``, mantissa length ``vlen``, exponent
        length ``plen`` and security parameter ``kappa``.
    :param mode: rounding mode, one of 0, 1, -1 (see above).
    :returns: tuple ``(v, p, z, s)`` holding the rounded mantissa, exponent,
        zero flag and sign bit.
    """
    v1, p1, z1, s1, l, k = x.v, x.p, x.z, x.s, x.vlen, x.plen
    # a: secret bit written by LTZ, presumably (p1 < 0), i.e. x may have a
    # fractional part.
    a = types.sint()
    AdvInteger.LTZ(a, p1, k, x.kappa)
    # b: (p1 < 1 - l). For an l-bit normalized mantissa this presumably
    # means |x| < 1, so no integer bits survive.
    b = p1.less_than(-l + 1, k, x.kappa)
    # Truncate by -p1 bits only when a=1 and b=0; otherwise the shift is 0.
    # NOTE(review): v2 appears to be the discarded low-order bits and
    # inv_2pow_p1 the weight 2**(-p1) (see how both are used below) —
    # confirm against Oblivious_Trunc.
    v2, inv_2pow_p1 = AdvInteger.Oblivious_Trunc(v1, l, -a * (1 - b) * x.p, x.kappa, True)
    # c: the truncated-away part is zero (x was already integral).
    c = AdvInteger.EQZ(v2, l, x.kappa)
    if mode == -1:
        # trunc never rounds away from zero. Reusing the sign as `mode`
        # makes the sign update below clear the sign when b is set.
        away_from_zero = 0
        mode = x.s
    else:
        # mode XOR s1 for bit inputs: floor of a negative or ceil of a
        # positive value increases the magnitude.
        away_from_zero = mode + s1 - 2 * mode * s1
    # Drop the fractional part and add one integer unit when something was
    # dropped and rounding goes away from zero.
    v = v1 - v2 + (1 - c) * inv_2pow_p1 * away_from_zero
    # d: the mantissa overflowed to 2**l. Renormalize to 2**(l-1); the
    # exponent is bumped by d * a further down.
    d = v.equal(AdvInteger.two_power(l), l + 1, x.kappa)
    v = d * AdvInteger.two_power(l - 1) + (1 - d) * v
    # Select the mantissa:
    # - a=0: the original v1 (non-negative exponent, already integral);
    # - a=1, b=1: 2**(l-1) when rounding away from zero, 0 otherwise;
    # - a=1, b=0: the rounded v computed above.
    v = a * ((1 - b) * v + b * away_from_zero * AdvInteger.two_power(l - 1)) + (1 - a) * v1
    # Clear the sign when b is set and `mode` is 1.
    s = (1 - b * mode) * s1
    z = AdvInteger.or_op(AdvInteger.EQZ(v, l, x.kappa), z1)
    v = v * (1 - z)
    # Exponent: p1 (+1 on mantissa overflow) in the normal case; 1 - l for
    # the away-from-zero tiny case. Forced to 0 when the result is zero.
    p = ((p1 + d * a) * (1 - b) + b * away_from_zero * (1 - l)) * (1 - z)
    return v, p, z, s
def Int2FL(a, gamma, l, kappa):
    """Convert a secret signed integer into float components.

    :param a: secret integer of bit length ``gamma`` (sign included).
    :param gamma: bit length of ``a``.
    :param l: target mantissa bit length.
    :param kappa: security parameter forwarded to the comparison,
        bit-decomposition and truncation helpers.
    :returns: tuple ``(v, p, z, s)`` holding the mantissa, exponent, zero flag
        and sign bit. The exponent is forced to 0 when ``a`` is zero.
    """
    magnitude_bits = gamma - 1

    # Sign bit (written by LTZ, presumably a < 0) and zero flag of the input.
    sign = types.sint()
    AdvInteger.LTZ(sign, a, gamma, kappa)
    is_zero = AdvInteger.EQZ(a, gamma, kappa)

    # Absolute value: multiply by +1 or -1 depending on the sign bit.
    magnitude = (1 - 2 * sign) * a

    # Bits of |a|, reversed so the prefix OR runs from the top bit down.
    # Assuming PreOR returns running ORs, entry i is 0 exactly while only
    # leading zeros have been seen.
    msb_first = AdvInteger.BitDec(magnitude, magnitude_bits, magnitude_bits, kappa)
    msb_first.reverse()
    seen_one = AdvInteger.PreOR(msb_first)

    # Under that assumption the factor below equals
    # 2**(number of leading zeros), which left-aligns |a|.
    shift_factor = 1 + sum(2**i * (1 - bit) for i, bit in enumerate(seen_one))
    normalized = magnitude * shift_factor
    exponent = -(magnitude_bits - sum(seen_one))

    if magnitude_bits <= l:
        # Enough room: scale up to the target mantissa width.
        mantissa = 2**(l - gamma + 1) * normalized
    elif types.sfloat.round_nearest:
        mantissa, overflow = TruncRoundNearestAdjustOverflow(
            normalized, gamma - 1, l, kappa)
        exponent = exponent + overflow
    else:
        mantissa = types.sint()
        # NOTE(review): a previous TODO asked whether the bit-length argument
        # here should be gamma rather than gamma - 1; left unchanged.
        AdvInteger.Trunc(mantissa, normalized, gamma - 1, gamma - l - 1, kappa, False)

    exponent = (exponent + gamma - 1 - l) * (1 - is_zero)
    return mantissa, exponent, is_zero, sign
def FLLT(fl_a, fl_b):
    """Secret comparison ``fl_a < fl_b`` on float components.

    The result is assembled arithmetically from three disjoint cases:
    only ``fl_a`` is zero, only ``fl_b`` is zero, or neither is zero.

    :param fl_a: float exposing ``v``, ``p``, ``z``, ``s``, ``err``,
        ``vlen``, ``plen`` and ``kappa``.
    :param fl_b: float exposing ``v``, ``p``, ``z`` and ``s``. Its ``err`` is
        only read when it is a ``types.sfloat``.
    :returns: the comparison bit.
    """
    # no_error is (sum of the consulted error flags == 0).
    err_total = fl_a.err
    if isinstance(fl_b, types.sfloat):
        err_total = err_total + fl_b.err
    no_error = err_total == 0

    z1, z2 = fl_a.z, fl_b.z
    s1, s2 = fl_a.s, fl_b.s

    exp_lt = fl_a.p.less_than(fl_b.p, fl_a.plen, fl_a.kappa)
    exp_eq = AdvInteger.EQZ(fl_a.p - fl_b.p, fl_a.plen, fl_a.kappa)
    # Compare the mantissas with their signs applied (one extra bit of width).
    signed_a = (1 - 2 * s1) * fl_a.v
    mant_lt = signed_a.less_than((1 - 2 * s2) * fl_b.v, fl_a.vlen + 1, fl_a.kappa)

    eq_and_mant_lt = exp_eq * mant_lt
    eq_and_exp_lt = exp_eq * exp_lt
    # Both positive: smaller exponent, or equal exponent and smaller mantissa.
    pos_lt = eq_and_mant_lt + exp_lt - eq_and_exp_lt
    # Both negative: larger exponent, or equal exponent and smaller signed mantissa.
    neg_lt = eq_and_mant_lt + 1 + eq_and_exp_lt - exp_eq - exp_lt
    both_neg = s1 * s2
    both_zero = z1 * z2

    only_a_zero = z1 - both_zero
    only_b_zero = z2 - both_zero
    neither_zero = 1 + both_zero - z1 - z2
    # With neither operand zero: a negative, b positive gives 1; otherwise
    # use pos_lt (both positive) or neg_lt (both negative).
    nonzero_case = s1 - both_neg + (1 + both_neg - s1 - s2) * pos_lt + both_neg * neg_lt
    # NOTE(review): no_error only gates the "neither zero" term; the two
    # zero-operand terms are returned regardless of the error flags.
    return (only_a_zero * (1 - s2) + only_b_zero * s1
            + neither_zero * nonzero_case * no_error)