def connect_banyan(cl, swb_ins, swb_outs, bw): I = int(2 * cg.clog2(bw) - 2) J = int(bw / 2) for i in range(cg.clog2(J)): r = J / (2**i) for j in range(J): t = (j % r) >= (r / 2) # straight out_i = int((i * bw) + (2 * j) + t) in_i = int((i * bw + bw) + (2 * j) + t) cl.connect(swb_outs[out_i], swb_ins[in_i]) # cross out_i = int((i * bw) + (2 * j) + (1 - t) + ((r - 1) * ((1 - t) * 2 - 1))) in_i = int((i * bw + bw) + (2 * j) + (1 - t)) cl.connect(swb_outs[out_i], swb_ins[in_i]) if r > 2: # straight out_i = int(((I * J * 2) - ((2 + i) * bw)) + (2 * j) + t) in_i = int(((I * J * 2) - ((1 + i) * bw)) + (2 * j) + t) cl.connect(swb_outs[out_i], swb_ins[in_i]) # cross out_i = int(((I * J * 2) - ((2 + i) * bw)) + (2 * j) + (1 - t) + ((r - 1) * ((1 - t) * 2 - 1))) in_i = int(((I * J * 2) - ((1 + i) * bw)) + (2 * j) + (1 - t)) cl.connect(swb_outs[out_i], swb_ins[in_i])
def mux(w): """ Create a mux. Parameters ---------- w: the width of the mux. Returns ------- a `CircuitGraph` mux. """ c = Circuit(name="mux") # create inputs for i in range(w): c.add(f"in_{i}", "input") sels = [] for i in range(clog2(w)): c.add(f"sel_{i}", "input") c.add(f"not_sel_{i}", "not", fanin=f"sel_{i}") sels.append([f"not_sel_{i}", f"sel_{i}"]) # create output or c.add("out", "or", output=True) i = 0 for sel in product(*sels[::-1]): c.add(f"and_{i}", "and", fanin=[*sel, f"in_{i}"], fanout="out") i += 1 if i == w: break return c
def test_sensitivity_transform(self): # pick random node and input value n = choice(tuple(self.s27.nodes() - self.s27.startpoints())) nstartpoints = self.s27.startpoints(n) while len(nstartpoints) < 1: n = choice(tuple(self.s27.nodes() - self.s27.startpoints())) nstartpoints = self.s27.startpoints(n) input_val = {i: randint(0, 1) for i in nstartpoints} # build sensitivity circuit s = sensitivity_transform(self.s27, n) # find sensitivity at an input model = sat(s, input_val) sen_s = sum(model[o] for o in s.outputs() if "dif_out" in o) # try inputs Hamming distance 1 away output_val = sat(self.s27, input_val)[n] sen_sim = 0 for i in nstartpoints: neighbor_input_val = { g: v if g != i else not v for g, v in input_val.items() } neighbor_output_val = sat(self.s27, neighbor_input_val)[n] if neighbor_output_val != output_val: sen_sim += 1 # check answer self.assertEqual(sen_s, sen_sim) # find input with sensitivity vs = cg.int_to_bin(sen_s, cg.clog2(len(nstartpoints) + 1), True) model = sat(s, {f"sen_out_{i}": v for i, v in enumerate(vs)}) input_val = {i: model[i] for i in nstartpoints} # try inputs Hamming distance 1 away output_val = sat(self.s27, input_val)[n] sen_sim = 0 for i in nstartpoints: neighbor_input_val = { g: v if g != i else not v for g, v in input_val.items() } neighbor_output_val = sat(self.s27, neighbor_input_val)[n] if neighbor_output_val != output_val: sen_sim += 1 # check answer self.assertEqual(sen_s, sen_sim)
def lebl(c, bw, ng): """ Locks a circuitgraph with Logic-Enhanced Banyan Locking as outlined in Joseph Sweeney, Marijn J.H. Heule, and Lawrence Pileggi Modeling Techniques for Logic Locking. In Proceedings of the International Conference on Computer Aided Design 2020 (ICCAD-39). Parameters ---------- circuit: circuitgraph.CircuitGraph Circuit to lock. bw: int Width of Banyan network. lw: int Minimum number of gates mapped to network. Returns ------- circuitgraph.CircuitGraph, dict of str:bool the locked circuit and the correct key value for each key input """ # create copy to lock cl = cg.copy(c) # generate switch and mux s = cg.Circuit(name='switch') m2 = cg.strip_io(logic.mux(2)) s.extend(cg.relabel(m2, {n: f'm2_0_{n}' for n in m2.nodes()})) s.extend(cg.relabel(m2, {n: f'm2_1_{n}' for n in m2.nodes()})) m4 = cg.strip_io(logic.mux(4)) s.extend(cg.relabel(m4, {n: f'm4_0_{n}' for n in m4.nodes()})) s.extend(cg.relabel(m4, {n: f'm4_1_{n}' for n in m4.nodes()})) s.add('in_0', 'buf', fanout=['m2_0_in_0', 'm2_1_in_1']) s.add('in_1', 'buf', fanout=['m2_0_in_1', 'm2_1_in_0']) s.add('out_0', 'buf', fanin='m4_0_out') s.add('out_1', 'buf', fanin='m4_1_out') s.add('key_0', 'input', fanout=['m2_0_sel_0', 'm2_1_sel_0']) s.add('key_1', 'input', fanout=['m4_0_sel_0', 'm4_1_sel_0']) s.add('key_2', 'input', fanout=['m4_0_sel_1', 'm4_1_sel_1']) # generate banyan I = int(2 * cg.clog2(bw) - 2) J = int(bw / 2) # add switches and muxes for i in range(I * J): cl.extend(cg.relabel(s, {n: f'swb_{i}_{n}' for n in s})) # make connections swb_ins = [f'swb_{i//2}_in_{i%2}' for i in range(I * J * 2)] swb_outs = [f'swb_{i//2}_out_{i%2}' for i in range(I * J * 2)] connect_banyan(cl, swb_ins, swb_outs, bw) # get banyan io net_ins = swb_ins[:bw] net_outs = swb_outs[-bw:] # generate key key = { f'swb_{i//3}_key_{i%3}': choice([True, False]) for i in range(3 * I * J) } # generate connections between banyan nodes bfi = {n: set() for n in swb_outs + net_ins} bfo = {n: set() for n in swb_outs + net_ins} for n in swb_outs + net_ins: if cl.fanout(n): fo_node = cl.fanout(n).pop() swb_i = fo_node.split('_')[1] bfi[f'swb_{swb_i}_out_0'].add(n) bfi[f'swb_{swb_i}_out_1'].add(n) bfo[n].add(f'swb_{swb_i}_out_0') bfo[n].add(f'swb_{swb_i}_out_1') # find a mapping of circuit onto banyan net_map = IDPool() for bn in swb_outs + net_ins: for cn in c: net_map.id(f'm_{bn}_{cn}') # mapping implications clauses = [] for bn in swb_outs + net_ins: # fanin if bfi[bn]: for cn in c: if c.fanin(cn): for fcn in c.fanin(cn): clause = [-net_map.id(f'm_{bn}_{cn}')] clause += [ net_map.id(f'm_{fbn}_{fcn}') for fbn in bfi[bn] ] clause += [ net_map.id(f'm_{fbn}_{cn}') for fbn in bfi[bn] ] clauses.append(clause) else: clause = [-net_map.id(f'm_{bn}_{cn}')] clause += [net_map.id(f'm_{fbn}_{cn}') for fbn in bfi[bn]] clauses.append(clause) # fanout if bfo[bn]: for cn in c: clause = [-net_map.id(f'm_{bn}_{cn}')] clause += [net_map.id(f'm_{fbn}_{cn}') for fbn in bfo[bn]] for fcn in c.fanout(cn): clause += [net_map.id(f'm_{fbn}_{fcn}') for fbn in bfo[bn]] clauses.append(clause) # no feed through for cn in c: net_map.id(f'INPUT_OR_{cn}') net_map.id(f'OUTPUT_OR_{cn}') clauses.append([-net_map.id(f'INPUT_OR_{cn}')] + [net_map.id(f'm_{bn}_{cn}') for bn in net_ins]) clauses.append([-net_map.id(f'OUTPUT_OR_{cn}')] + [net_map.id(f'm_{bn}_{cn}') for bn in net_outs]) for bn in net_ins: clauses.append( [net_map.id(f'INPUT_OR_{cn}'), -net_map.id(f'm_{bn}_{cn}')]) for bn in net_outs: clauses.append( [net_map.id(f'OUTPUT_OR_{cn}'), -net_map.id(f'm_{bn}_{cn}')]) clauses.append( [-net_map.id(f'OUTPUT_OR_{cn}'), -net_map.id(f'INPUT_OR_{cn}')]) # at least ngates for bn in swb_outs + net_ins: net_map.id(f'NGATES_OR_{bn}') clauses.append([-net_map.id(f'NGATES_OR_{bn}')] + [net_map.id(f'm_{bn}_{cn}') for cn in c]) for cn in c: clauses.append( [net_map.id(f'NGATES_OR_{bn}'), -net_map.id(f'm_{bn}_{cn}')]) clauses += CardEnc.atleast( bound=ng, lits=[net_map.id(f'NGATES_OR_{bn}') for bn in swb_outs + net_ins], vpool=net_map).clauses # at most one mapping per out for bn in swb_outs + net_ins: clauses += CardEnc.atmost(lits=[ net_map.id(f'm_{bn}_{cn}') for cn in c ], vpool=net_map).clauses # limit number of times a gate is mapped to net outputs to fanout of gate for cn in c: lits = [net_map.id(f'm_{bn}_{cn}') for bn in net_outs] bound = len(c.fanout(cn)) if len(lits) < bound: continue clauses += CardEnc.atmost(bound=bound, lits=lits, vpool=net_map).clauses # prohibit outputs from net for bn in swb_outs + net_ins: for cn in c.outputs(): clauses += [[-net_map.id(f'm_{bn}_{cn}')]] # solve solver = Cadical(bootstrap_with=clauses) if not solver.solve(): print(f'no config for width: {bw}') core = solver.get_core() print(core) code.interact(local=dict(globals(), **locals())) model = solver.get_model() # get mapping mapping = {} for bn in swb_outs + net_ins: selected_gates = [ cn for cn in c if model[net_map.id(f'm_{bn}_{cn}') - 1] > 0 ] if len(selected_gates) > 1: print(f'multiple gates mapped to: {bn}') code.interact(local=dict(globals(), **locals())) mapping[bn] = selected_gates[0] if selected_gates else None potential_net_fanins = list(c.nodes() - (c.endpoints() | set(mapping.values()) | mapping.keys() | c.startpoints())) # connect net inputs for bn in net_ins: if mapping[bn]: cl.connect(mapping[bn], bn) else: cl.connect(choice(potential_net_fanins), bn) mapping.update({cl.fanin(bn).pop(): cl.fanin(bn).pop() for bn in net_ins}) potential_net_fanouts = list(c.nodes() - (c.startpoints() | set(mapping.values()) | mapping.keys() | c.endpoints())) #selected_fo = {} # connect switch boxes for i, bn in enumerate(swb_outs): # get keys if key[f'swb_{i//2}_key_1'] and key[f'swb_{i//2}_key_2']: k = 3 elif not key[f'swb_{i//2}_key_1'] and key[f'swb_{i//2}_key_2']: k = 2 elif key[f'swb_{i//2}_key_1'] and not key[f'swb_{i//2}_key_2']: k = 1 elif not key[f'swb_{i//2}_key_1'] and not key[f'swb_{i//2}_key_2']: k = 0 switch_key = 1 if key[f'swb_{i//2}_key_0'] == 1 else 0 mux_input = f'swb_{i//2}_m4_{i%2}_in_{k}' # connect inner nodes mux_gate_types = set() # constant output, hookup to a node that is already in the affected outputs fanin, not in others if not mapping[bn] and bn in net_outs: decoy_fanout_gate = choice(potential_net_fanouts) #selected_fo[bn] = decoy_fanout_gate cl.connect(bn, decoy_fanout_gate) if cl.type(decoy_fanout_gate) in ['and', 'nand']: cl.set_type(mux_input, '1') elif cl.type(decoy_fanout_gate) in ['or', 'nor', 'xor', 'xnor']: cl.set_type(mux_input, '0') elif cl.type(decoy_fanout_gate) in ['buf']: if randint(0, 1): cl.set_type(mux_input, '1') cl.set_type(decoy_fanout_gate, choice(['and', 'xnor'])) else: cl.set_type(mux_input, '0') cl.set_type(decoy_fanout_gate, choice(['or', 'xor'])) elif cl.type(decoy_fanout_gate) in ['not']: if randint(0, 1): cl.set_type(mux_input, '1') cl.set_type(decoy_fanout_gate, choice(['nand', 'xor'])) else: cl.set_type(mux_input, '0') cl.set_type(decoy_fanout_gate, choice(['nor', 'xnor'])) elif cl.type(decoy_fanout_gate) in ['0', '1']: cl.set_type(mux_input, cl.type(decoy_fanout_gate)) cl.set_type(decoy_fanout_gate, 'buf') else: print('gate error') code.interact(local=dict(globals(), **locals())) mux_gate_types.add(cl.type(mux_input)) # feedthrough elif mapping[bn] in [mapping[fbn] for fbn in bfi[bn]]: cl.set_type(mux_input, 'buf') mux_gate_types.add('buf') if mapping[cl.fanin(f'swb_{i//2}_in_0').pop()] == mapping[bn]: cl.connect(f'swb_{i//2}_m2_{switch_key}_out', mux_input) else: cl.connect(f'swb_{i//2}_m2_{1-switch_key}_out', mux_input) # gate elif mapping[bn]: cl.set_type(mux_input, cl.type(mapping[bn])) mux_gate_types.add(cl.type(mapping[bn])) gfi = cl.fanin(mapping[bn]) if mapping[cl.fanin(f'swb_{i//2}_in_0').pop()] in gfi: cl.connect(f'swb_{i//2}_m2_{switch_key}_out', mux_input) gfi.remove(mapping[cl.fanin(f'swb_{i//2}_in_0').pop()]) if mapping[cl.fanin(f'swb_{i//2}_in_1').pop()] in gfi: cl.connect(f'swb_{i//2}_m2_{1-switch_key}_out', mux_input) # mapped to None, any key works else: k = None # fill out random gates for j in range(4): if j != k: t = sample( set([ 'buf', 'or', 'nor', 'and', 'nand', 'not', 'xor', 'xnor', '0', '1' ]) - mux_gate_types, 1)[0] mux_gate_types.add(t) mux_input = f'swb_{i//2}_m4_{i%2}_in_{j}' cl.set_type(mux_input, t) if t == 'not' or t == 'buf': # pick a random fanin cl.connect(f'swb_{i//2}_m2_{randint(0,1)}_out', mux_input) elif t == '1' or t == '0': pass else: cl.connect(f'swb_{i//2}_m2_0_out', mux_input) cl.connect(f'swb_{i//2}_m2_1_out', mux_input) if [ n for n in cl if cl.type(n) in ['buf', 'not'] and len(cl.fanin(n)) > 1 ]: import code code.interact(local=dict(globals(), **locals())) # connect outputs non constant outs rev_mapping = {} for bn in net_outs: if mapping[bn]: if mapping[bn] not in rev_mapping: rev_mapping[mapping[bn]] = set() rev_mapping[mapping[bn]].add(bn) for cn in rev_mapping.keys(): #for fcn in cl.fanout(cn): # cl.connect(sample(rev_mapping[cn],1)[0],fcn) for fcn, bn in zip_longest(cl.fanout(cn), rev_mapping[cn], fillvalue=list(rev_mapping[cn])[0]): cl.connect(bn, fcn) # delete mapped gates deleted = True while deleted: deleted = False for n in cl.nodes(): # node and all fanout are in the net if n not in mapping and n in mapping.values(): if all(s not in mapping and s in mapping.values() for s in cl.fanout(n)): cl.remove(n) deleted = True # node in net fanout if n in [mapping[o] for o in net_outs] and n in cl: cl.remove(n) deleted = True cg.lint(cl) return cl, key