def chain(n, state_name='x', action='a', start=None, clip=True, can_stay=True):
    """Build an ``n``-bit chain circuit whose state is shifted by an action.

    The state is an ``n``-bit unsigned word initialised with a single bit set.
    The 2-bit ``action`` input drives it. If bit 1 ("forward") is set, the word
    is shifted left. Otherwise it is shifted right.

    Args:
        n: Word length of the state.
        state_name: Name of the state input/output that is fed back.
        action: Name of the 2-bit action input. Bit 0 is read as "backward"
            and bit 1 as "forward".
        start: 1-indexed position of the initially set bit. When omitted,
            the 0-indexed position ``n // 2`` is used.
        clip: If true, a move whose result is the all-zero word is replaced
            by keeping the current state.
        can_stay: If true, the current state is kept when neither action bit
            is set.

    Returns:
        The circuit with ``state_name`` fed back, initialised to ``start``,
        and still exposed as an output (``keep_outputs=True``).
    """
    if start is None:
        start = n // 2
    else:
        start -= 1  # caller-supplied position is 1-indexed
    start = encode_int(n, 1 << start, signed=False)

    x = atom(n, state_name, signed=False)
    a = atom(2, action, signed=False)

    backward, forward = a[0], a[1]
    # "forward" wins; every other action value (backward or none) shifts right.
    x2 = ite(forward, x << 1, x >> 1)

    # Start from constant *false* and OR in each enabled reason to stay.
    # Starting from true would make ``clip=False, can_stay=True`` hold the
    # state forever, since ``1 | anything`` is always 1.
    stay = atom(1, 0, signed=False)
    if clip:
        stay = (x2 == 0)
    if can_stay:
        stay |= ~(forward | backward)

    if clip or can_stay:
        x2 = ite(stay, x, x2)

    circ = x2.aigbv['o', {x2.output: state_name}]
    return circ.feedback(
        inputs=[state_name],
        outputs=[state_name],
        initials=[start],
        keep_outputs=True,
    )
def test_closed_system():
    """Close the chain dynamics with a coin-driven policy and check the interface."""
    # Dynamics: the chain's action 'a' is ANDed with coin 'c1'.
    c1 = aiger_bv.atom(1, 'c1', signed=False)
    act = aiger_bv.atom(1, 'a', signed=False)

    dyn = circ2mdp(chain(n=4, state_name='s', action='a'))
    dyn <<= (c1 & act).with_output('a').aigbv
    c1_coin = coin((1, 8), name='c1')
    dyn <<= c1_coin

    assert dyn.inputs == {'a'}
    assert dyn.outputs == {'s'}

    before = {'s_prev': (True, False, False, False, False)}
    after = {'s_prev': (False, True, False, False, False)}
    assert dyn.prob(before, {'a': (0, )}, after) == 0
    assert dyn.prob(before, {'a': (1, )}, after) == c1_coin.prob() == 1 / 8

    # Policy: emit constant 0 when s == 0b00001, otherwise pass coin 'c2' through.
    c2 = aiger_bv.atom(1, 'c2', signed=False)
    const_false = aiger_bv.atom(1, 0, signed=False)
    state = aiger_bv.atom(5, 's', signed=False)
    at_first = state == 0b00001
    policy = circ2mdp(aiger_bv.ite(at_first, const_false, c2).with_output('a'))
    policy <<= coin((1, 8), name='c2')

    closed = (policy >> dyn).feedback(
        inputs=['s'],
        outputs=['s'],
        latches=['s_prev2'],
        keep_outputs=True,
    )
    assert closed.inputs == set()
    assert closed.outputs == {'s'}
def chain(n, state_name='x', action='H'):
    """Return an ``(n + 1)``-bit chain circuit that advances on ``action``.

    The state starts at ``1``. When the 1-bit ``action`` input is high, the
    word is shifted left by one. Otherwise it is held unchanged. The state is
    fed back through a latch named ``f"{state_name}_prev"`` and is also kept
    as an output.
    """
    width = n + 1
    x = atom(width, state_name, signed=False)
    step = atom(1, action, signed=False)
    nxt = ite(step, x << 1, x)

    renamed = nxt.aigbv['o', {nxt.output: state_name}]
    return renamed.feedback(
        inputs=[state_name],
        outputs=[state_name],
        latches=[f"{state_name}_prev"],
        initials=[encode_int(width, 1, signed=False)],
        keep_outputs=True,
    )
def onehot_gadget(output: str):
    """Wrap the 1-bit signal ``output`` as a 2-bit one-hot output named ``'sat'``.

    ``False`` maps to ``0b01`` and ``True`` to ``0b10``. The attached encoding
    turns a bool into that word (``1 << int(x)``) and decodes by reading bit 1.
    """
    flag = BV.uatom(1, output)
    when_false = BV.uatom(2, 0b01)
    when_true = BV.uatom(2, 0b10)
    onehot = BV.ite(flag, when_true, when_false).with_output('sat')

    def _encode(x):
        return 1 << int(x)

    def _decode(x):
        return bool((x >> 1) & 1)

    sat_encoding = D.Encoding(encode=_encode, decode=_decode)
    return D.from_aigbv(onehot.aigbv, output_encodings={'sat': sat_encoding})
def ite(test, idx):
    """Return ``bits[idx]`` if ``expr[idx]`` holds, otherwise ``test``.

    NOTE(review): ``expr`` and ``bits`` are not parameters. They are presumably
    bound in an enclosing scope (e.g. a closure); confirm. The
    ``(accumulator, index)`` signature suggests use as a ``functools.reduce``
    step. If this function lives at module scope next to ``chain``, it shadows
    the 3-argument ``ite`` that ``chain`` calls. Verify that is not the case.
    """
    cond = expr[idx]
    chosen = bits[idx]
    return BV.ite(cond, chosen, test)
def max_op(lhs, rhs):
    """Return an expression that selects ``lhs`` when ``lhs > rhs``, else ``rhs``."""
    lhs_larger = lhs > rhs
    return BV.ite(lhs_larger, lhs, rhs)
def min_op(lhs, rhs):
    """Return an expression that selects ``lhs`` when ``lhs < rhs``, else ``rhs``."""
    lhs_smaller = lhs < rhs
    return BV.ite(lhs_smaller, lhs, rhs)