def gen_lat(path):
    """Build a level-sensitive D latch module ``lat`` and emit it.

    Ports, in order: ``en`` (input), ``Q`` (output), ``D`` (input),
    ``rst`` (input).  The always block is sensitive to the level of
    ``en``, ``rst`` and ``D``.  While ``rst`` is high ``Q`` is cleared;
    otherwise ``Q`` follows ``D`` while ``en`` is high.

    The AST is handed to ``write_verilog`` with file name ``'lat.v'`` and
    *path* (presumably the output directory -- confirm in write_verilog).
    """
    header = vast.Portlist([
        vast.Ioport(vast.Input('en')),
        vast.Ioport(vast.Output('Q')),
        vast.Ioport(vast.Input('D')),
        vast.Ioport(vast.Input('rst')),
    ])

    # Q is also declared as an initialised reg, emitted as raw text.
    q_decl = vast.Identifier('reg Q = 0;')

    level_sens = vast.SensList([
        vast.Sens(vast.Identifier(sig_name), type='level')
        for sig_name in ('en', 'rst', 'D')
    ])

    follow_d = vast.NonblockingSubstitution(
        vast.Lvalue(vast.Identifier('Q')),
        vast.Rvalue(vast.Identifier('D')))
    enabled = vast.IfStatement(vast.Identifier('en'), follow_d, None)
    # Reset takes priority over the enable; the clear is emitted as raw text.
    reset_first = vast.IfStatement(
        vast.Identifier('rst'),
        vast.Identifier('Q <= 0;'),
        enabled,
        None)

    latch_proc = vast.Always(level_sens, vast.Block([reset_first]))
    module = vast.ModuleDef("lat", None, header, [q_decl, latch_proc])
    write_verilog(module, 'lat.v', path)
def visit_Always(self, node):
    """Lower an ``Always`` node to a pyverilog ``vast.Always``.

    Sensitivity handling:
      * entries that are ``vtypes.Sensitive`` are visited with ``self.visit``;
      * any other entry is visited with ``self.always_visitor`` and wrapped
        as a level-sensitive ``vast.Sens``;
      * an empty or missing sensitivity list becomes ``@(*)``.

    The body statements are visited with ``self.always_visitor`` and the
    resulting block is passed through ``self._optimize_block``.
    """
    if not node.sensitivity:
        sens_entries = (vast.Sens(None, 'all'),)
    else:
        collected = []
        for entry in node.sensitivity:
            if isinstance(entry, vtypes.Sensitive):
                collected.append(self.visit(entry))
            else:
                collected.append(
                    vast.Sens(self.always_visitor.visit(entry), 'level'))
        sens_entries = tuple(collected)
    sensitivity = vast.SensList(sens_entries)

    raw_block = vast.Block(
        tuple(self.always_visitor.visit(stmt) for stmt in node.statement))
    return vast.Always(sensitivity, self._optimize_block(raw_block))
def umc_module(self, dip_list):
    """Build the ``umc`` module and store it on ``self.umc``.

    Ports: ``clk`` and ``iv`` (width ``self.inp_width``).  Items, in order:

      1. a ``dip_generator`` instance ``dg`` wired positionally to
         clk, iv, k1, k2;
      2. one memory ``dip<i>`` per DIP in *dip_list*, sized to that DIP;
      3. a clocked always block incrementing ``cycle``;
      4. an ``always @(*)`` block containing ``self.key_constraints``
         verbatim;
      5. one ``dip_checker dc<i>`` per DIP, fed ``dip<i>[cycle]``;
      6. a clocked always block (not an initial block) loading every DIP
         vector into its memory with non-blocking assignments.

    NOTE(review): ``k1``, ``k2`` and ``cycle`` are referenced but not
    declared in this function; presumably they come from
    ``self.key_constraints`` or implicit nets -- confirm.
    """
    ports = vast.Portlist([
        vast.Ioport(vast.Input('clk')),
        vast.Ioport(vast.Input('iv', width=self.inp_width)),
    ])

    items = []

    dg_args = [vast.PortArg("", vast.Identifier(sig))
               for sig in ('clk', 'iv', 'k1', 'k2')]
    dg = vast.Instance('dip_generator', 'dg', dg_args, "")
    items.append(vast.InstanceList('dip_generator', "", [dg]))

    # One memory per DIP, one word per input vector in that DIP.
    for dip_idx, dip in enumerate(dip_list):
        items.append(vast.Identifier('reg [{}:0] dip{} [0:{}];'.format(
            self.input_size_msb, dip_idx, len(dip) - 1)))

    counter_body = vast.Block([vast.Identifier('cycle <= cycle + 1;')])
    counter_sens = vast.Sens(vast.Identifier('clk'), type='posedge')
    items.append(vast.Always(vast.SensList([counter_sens]), counter_body))

    constraint_body = vast.Block([vast.Identifier(self.key_constraints)])
    items.append(vast.Always(vast.SensList([vast.Sens(None, type='all')]),
                             constraint_body))

    for dip_idx, _ in enumerate(dip_list):
        items.append(vast.Identifier(
            'dip_checker dc{} (clk, dip{}[cycle], k1, k2);'.format(
                dip_idx, dip_idx)))

    # Clocked load of every DIP vector into its memory word.
    loads = [
        vast.Identifier('dip{}[{}] <= {};'.format(dip_idx, word_idx, vec))
        for dip_idx, dip in enumerate(dip_list)
        for word_idx, vec in enumerate(dip)
    ]
    load_sens = vast.Sens(vast.Identifier('clk'), type='posedge')
    items.append(vast.Always(vast.SensList([load_sens]), vast.Block(loads)))

    self.umc = vast.ModuleDef("umc", None, ports, items)
def visit_Event(self, node):
    """Lower an ``Event`` node to a pyverilog ``vast.EventStatement``.

    Entries that are ``vtypes.Sensitive`` (edge/all markers) are visited
    directly.  Any other expression is emitted as a level-sensitive entry,
    i.e. ``@(sig)``, which in Verilog triggers on any value change of
    ``sig``.
    """
    # The 'level' type must be passed explicitly: pyverilog's vast.Sens
    # defaults its type to 'posedge', which would silently turn a plain
    # ``@(sig)`` into ``@(posedge sig)``.
    entries = tuple(
        self.visit(item) if isinstance(item, vtypes.Sensitive)
        else vast.Sens(self.visit(item), 'level')
        for item in node.sensitivity)
    return vast.EventStatement(vast.SensList(entries))
def dip_chk_module(self):
    """Build the ``dip_checker`` module and store it on ``self.dip_chk``.

    Ports: ``clk``, ``iv`` (``self.inp_width``), ``k1`` and ``k2``
    (``self.key_width``).  Declares output wires ``ov0``..``ov2``
    (``self.out_width``), includes ``self.org0``, ``self.obf1`` and
    ``self.obf2``, and on every rising ``clk`` edge assumes that both
    ``ov1`` and ``ov2`` equal ``ov0``.
    """
    inputs = [vast.Ioport(vast.Input('clk')),
              vast.Ioport(vast.Input('iv', width=self.inp_width))]
    inputs.extend(vast.Ioport(vast.Input(key_name, width=self.key_width))
                  for key_name in ('k1', 'k2'))
    header = vast.Portlist(inputs)

    body = [vast.Wire(wire_name, width=self.out_width)
            for wire_name in ('ov0', 'ov1', 'ov2')]
    body += [self.org0, self.obf1, self.obf2]

    on_clk = vast.SensList([vast.Sens(vast.Identifier('clk'), type='posedge')])
    assumptions = vast.Block([
        vast.Identifier('assume (ov0 == ov1);'),
        vast.Identifier('assume (ov0 == ov2);'),
    ])
    body.append(vast.Always(on_clk, assumptions))

    self.dip_chk = vast.ModuleDef("dip_checker", None, header, body)
def dip_gen_module_uc(self):
    """Build the assume-based ``dip_generator`` module into ``self.dip_gen``.

    Ports: ``clk``, ``iv`` (``self.inp_width``), ``k1`` and ``k2``
    (``self.key_width``).  Declares wires ``ov1``/``ov2``
    (``self.out_width``), includes ``self.obf1`` and ``self.obf2``, and on
    every rising ``clk`` edge assumes ``k1 != k2`` while asserting
    ``ov1 == ov2``.

    The author notes this variant is slower than ``dip_gen_module()``
    because of the assume.
    NOTE(review): ``dip_gen_module`` currently emits the same assume/assert
    pair, so confirm whether that remark still applies.
    """
    header = vast.Portlist([
        vast.Ioport(vast.Input('clk')),
        vast.Ioport(vast.Input('iv', width=self.inp_width)),
        vast.Ioport(vast.Input('k1', width=self.key_width)),
        vast.Ioport(vast.Input('k2', width=self.key_width)),
    ])

    contents = [vast.Wire(w, width=self.out_width) for w in ('ov1', 'ov2')]
    contents.extend([self.obf1, self.obf2])

    property_block = vast.Block([
        vast.Identifier('assume (k1 != k2);'),
        vast.Identifier('assert (ov1 == ov2);'),
    ])
    # Clocked rather than @(*); the combinational variant was disabled.
    rising_clk = vast.Sens(vast.Identifier('clk'), type='posedge')
    contents.append(vast.Always(vast.SensList([rising_clk]), property_block))

    self.dip_gen = vast.ModuleDef("dip_generator", None, header, contents)
def fk_module(self, dip_list, skip_cycles, equal_keys):
    """Build the port-less module used to find the correct keys after termination.

    The module is named ``ce`` and is stored on ``self.ce``.  Items, in
    order:

      1. one memory ``dip<i>`` per DIP in *dip_list*;
      2. a clocked always block that assumes ``k1 == k2`` (if *equal_keys*)
         or ``k1 != k2``, increments ``cycle`` and covers the point where
         ``cycle == skip_cycles``;
      3. an ``always @(*)`` block containing ``self.key_constraints``;
      4. one ``dip_checker dc<i>`` per DIP, fed ``dip<i>[cycle]``;
      5. a clocked always block loading every DIP vector with blocking
         assignments.

    Args:
        dip_list: sequence of DIPs, each a sequence of input-vector literals.
        skip_cycles: cycle count at which the cover statement fires.
        equal_keys: whether to assume the two keys are equal.
    """
    items = [
        vast.Identifier('reg [{}:0] dip{} [0:{}];'.format(
            self.input_size_msb, n, len(dip) - 1))
        for n, dip in enumerate(dip_list)
    ]

    key_relation = ('assume (k1 == k2);' if equal_keys
                    else 'assume (k1 != k2);')
    # The if/cover pair is emitted as two raw-text statements so the
    # cover line lands as the body of the if.
    main_stmts = [
        vast.Identifier(key_relation),
        vast.Identifier('cycle <= cycle + 1;'),
        vast.Identifier('if (cycle == {})'.format(skip_cycles)),
        vast.Identifier(' cover(1);'),
    ]
    main_sens = vast.Sens(vast.Identifier('clk'), type='posedge')
    items.append(vast.Always(vast.SensList([main_sens]),
                             vast.Block(main_stmts)))

    items.append(vast.Always(
        vast.SensList([vast.Sens(None, type='all')]),
        vast.Block([vast.Identifier(self.key_constraints)])))

    for n, _ in enumerate(dip_list):
        items.append(vast.Identifier(
            'dip_checker dc{} (clk, dip{}[cycle], k1, k2);'.format(n, n)))

    load_sens = vast.SensList(
        [vast.Sens(vast.Identifier('clk'), type='posedge')])
    loads = []
    for n, dip in enumerate(dip_list):
        for k, vec in enumerate(dip):
            loads.append(vast.Identifier('dip{}[{}] = {};'.format(n, k, vec)))
    items.append(vast.Always(load_sens, vast.Block(loads)))

    self.ce = vast.ModuleDef("ce", None, vast.Portlist([]), items)
def test():
    """Hand-build a ``top`` module, regenerate its Verilog and compare it.

    The module has parameter ``DATAWID = 32``; ports ``CLK``, ``RST`` and
    an 8-bit ``led`` output; a ``count`` register; a continuous assign of
    the top 8 bits of ``count`` to ``led``; and a clocked always block
    that clears ``count`` on ``RST`` and otherwise sets it to
    ``(count + 1) * 2 + 1``.

    The generated text is printed and compared to ``expected``, which is
    not defined here and must be a module-level reference string.
    """
    def datawid_minus(k):
        # DATAWID - k, with k given as an integer literal string.
        return vast.Minus(vast.Identifier('DATAWID'), vast.IntConst(k))

    def count_gets(rhs):
        # begin count <= rhs; end
        return vast.Block([vast.NonblockingSubstitution(
            vast.Lvalue(vast.Identifier('count')), vast.Rvalue(rhs))])

    params = vast.Paramlist([
        vast.Parameter('DATAWID', vast.Rvalue(vast.IntConst('32'))),
    ])

    led_width = vast.Width(vast.IntConst('7'), vast.IntConst('0'))
    ports = vast.Portlist([
        vast.Ioport(vast.Input('CLK')),
        vast.Ioport(vast.Input('RST')),
        vast.Ioport(vast.Output('led', width=led_width)),
    ])

    # reg [DATAWID-1:0] count;
    count = vast.Reg(
        'count', width=vast.Width(datawid_minus('1'), vast.IntConst('0')))

    # assign led = count[DATAWID-1:DATAWID-8];
    assign = vast.Assign(
        vast.Lvalue(vast.Identifier('led')),
        vast.Rvalue(vast.Partselect(
            vast.Identifier('count'), datawid_minus('1'), datawid_minus('8'))))

    # (count + 1) * 2 + 1
    next_count = vast.Plus(
        vast.Times(
            vast.Plus(vast.Identifier('count'), vast.IntConst('1')),
            vast.IntConst('2')),
        vast.IntConst('1'))

    reset_or_step = vast.IfStatement(
        vast.Identifier('RST'),
        count_gets(vast.IntConst('0')),
        count_gets(next_count))
    always = vast.Always(
        vast.SensList([vast.Sens(vast.Identifier('CLK'), type='posedge')]),
        vast.Block([reset_or_step]))

    module = vast.ModuleDef("top", params, ports, [count, assign, always])
    rslt = ASTCodeGenerator().visit(module)
    print(rslt)
    assert expected == rslt
def gen_dff(path):
    """Build a rising-edge D flip-flop module ``dff`` and emit it.

    Ports, in order: ``clk`` (input), ``Q`` (output), ``D`` (input).  On
    every rising edge of ``clk``, ``Q <= D``.  The AST is handed to
    ``write_verilog`` with file name ``'dff.v'`` and *path* (presumably the
    output directory -- confirm in write_verilog).
    """
    header = vast.Portlist([
        vast.Ioport(vast.Input('clk')),
        vast.Ioport(vast.Output('Q')),
        vast.Ioport(vast.Input('D')),
    ])

    capture = vast.NonblockingSubstitution(
        vast.Lvalue(vast.Identifier('Q')),
        vast.Rvalue(vast.Identifier('D')))
    rising_clk = vast.SensList(
        [vast.Sens(vast.Identifier('clk'), type='posedge')])
    flop = vast.Always(rising_clk, vast.Block([capture]))

    # Q is also declared as an initialised reg, emitted as raw text.
    contents = [vast.Identifier('reg Q = 0;'), flop]
    write_verilog(vast.ModuleDef("dff", None, header, contents), 'dff.v', path)
def dip_gen_module(self):
    """Build the ``dip_generator`` module into ``self.dip_gen``.

    Ports: ``clk``, ``iv`` (``self.inp_width``), ``k1`` and ``k2``
    (``self.key_width``).  Declares wires ``ov1``/``ov2``
    (``self.out_width``), includes ``self.obf1`` and ``self.obf2``, and on
    every rising ``clk`` edge assumes ``k1 != k2`` while asserting
    ``ov1 == ov2``.

    This was designed as a faster if-based variant of
    ``dip_gen_module_uc`` (``if (k1 != k2) if (ov1 != ov2) assert ...``)
    that cannot be used for UC termination.  For latch locking it
    currently emits the assume/assert pair instead, and is clocked on
    ``posedge clk`` rather than ``@(*)``.
    TODO: restore the if-based body and @(*) when latch locking no longer
    needs this.
    """
    port_specs = [('clk', None), ('iv', self.inp_width),
                  ('k1', self.key_width), ('k2', self.key_width)]
    ports = []
    for port_name, port_width in port_specs:
        if port_width is None:
            ports.append(vast.Ioport(vast.Input(port_name)))
        else:
            ports.append(vast.Ioport(vast.Input(port_name, width=port_width)))
    header = vast.Portlist(ports)

    module_items = [
        vast.Wire('ov1', width=self.out_width),
        vast.Wire('ov2', width=self.out_width),
        self.obf1,
        self.obf2,
    ]

    checks = vast.Block([
        vast.Identifier('assume (k1 != k2);'),
        vast.Identifier('assert (ov1 == ov2);'),
    ])
    edge = vast.Sens(vast.Identifier('clk'), type='posedge')
    module_items.append(vast.Always(vast.SensList([edge]), checks))

    self.dip_gen = vast.ModuleDef("dip_generator", None, header, module_items)
def visit_Negedge(self, node):
    """Lower a ``Negedge`` node to a ``negedge <sig>`` sensitivity entry.

    The signal expression ``node.name`` is lowered with ``self.visit``.
    """
    return vast.Sens(self.visit(node.name), 'negedge')
def visit_SensitiveAll(self, node):
    """Lower ``SensitiveAll`` to the ``*`` sensitivity entry (no signal).

    *node* is not inspected.
    """
    return vast.Sens(None, 'all')
def visit_Posedge(self, node):
    """Lower a ``Posedge`` node to a ``posedge <sig>`` sensitivity entry.

    The signal expression ``node.name`` is lowered with
    ``self.bind_visitor``.
    """
    return vast.Sens(self.bind_visitor.visit(node.name), 'posedge')
def ce_module(self, module_name, dip_list):
    """Build the ``ce`` checking module and store it on ``self.ce``.

    Two instances of ``<module_name>_ce`` -- one keyed with ``k1``, one with
    ``k2`` -- start from the same state ``ce_state_s0`` and see the same
    input ``ce_iv_s0``.  On every rising ``clk`` edge their outputs and next
    states are asserted equal and ``cycle`` is incremented.  Each DIP in
    *dip_list* gets a memory ``dip<i>``, a ``dip_checker dc<i>`` fed
    ``dip<i>[cycle]``, and is loaded by a single clocked always block.

    Args:
        module_name: base name; the instantiated module is
            ``module_name + '_ce'``.
        dip_list: sequence of DIPs, each a sequence of input-vector literals.

    NOTE(review): ``k1``, ``k2`` and ``cycle`` are referenced but not
    declared here, and only ``ce_iv_s0`` exists as an input, so ``step``
    values above 1 would also need more inputs -- confirm before changing it.
    """
    state_width = vast.Width(vast.IntConst(self.state_size_msb),
                             vast.IntConst('0'))
    ce_name = module_name + '_ce'
    step = 1  # number of unrolled transitions that are checked

    def posedge_clk():
        # Fresh sensitivity list for an ``always @(posedge clk)`` block.
        return vast.SensList(
            [vast.Sens(vast.Identifier('clk'), type='posedge')])

    def ce_copy(copy, key, s):
        # Instance obf<copy>_ce_s<s> of <module_name>_ce.  Positional ports
        # are: create_ports(iv, ov), key, current state, next state.
        ports = create_ports('ce_iv_s' + str(s),
                             'ce_ov{}_s{}'.format(copy, s), self.orcl_cir)
        extra = [
            vast.PortArg("", vast.Identifier(key)),
            vast.PortArg("", vast.Identifier(
                'ce_state{}_s{}'.format(copy, s))),
            vast.PortArg("", vast.Identifier(
                'ce_state{}_s{}'.format(copy, s + 1))),
        ]
        inst = vast.Instance(ce_name, 'obf{}_ce_s{}'.format(copy, s),
                             ports + extra, "")
        return vast.InstanceList(ce_name, "", [inst])

    portslist = vast.Portlist([
        vast.Ioport(vast.Input('clk')),
        vast.Ioport(vast.Input('ce_iv_s0', width=self.inp_width)),
        vast.Ioport(vast.Input('ce_state_s0', width=state_width)),
    ])

    inst_list = []
    # State wires for s0..s<step>, then output wires for s0..s<step-1>.
    for s in range(step + 1):
        inst_list.append(vast.Wire('ce_state1_s{}'.format(s), width=state_width))
        inst_list.append(vast.Wire('ce_state2_s{}'.format(s), width=state_width))
    for s in range(step):
        inst_list.append(vast.Wire('ce_ov1_s{}'.format(s), width=self.out_width))
        inst_list.append(vast.Wire('ce_ov2_s{}'.format(s), width=self.out_width))
    # Both copies start from the shared input state.
    inst_list.append(vast.Identifier('assign ce_state1_s0 = ce_state_s0;'))
    inst_list.append(vast.Identifier('assign ce_state2_s0 = ce_state_s0;'))

    for s in range(step):
        inst_list.extend([ce_copy(1, 'k1', s), ce_copy(2, 'k2', s)])

    checks = []
    for s in range(step):
        checks.append(vast.Identifier(
            'assert (ce_ov1_s{} == ce_ov2_s{});'.format(s, s)))
        checks.append(vast.Identifier(
            'assert (ce_state1_s{} == ce_state2_s{});'.format(s + 1, s + 1)))

    for i, dip in enumerate(dip_list):
        inst_list.append(vast.Identifier('reg [{}:0] dip{} [0:{}];'.format(
            self.input_size_msb, i, len(dip) - 1)))

    checks.append(vast.Identifier('cycle <= cycle + 1;'))
    inst_list.append(vast.Always(posedge_clk(), vast.Block(checks)))

    for i, _ in enumerate(dip_list):
        inst_list.append(vast.Identifier(
            'dip_checker dc{} (clk, dip{}[cycle], k1, k2);'.format(i, i)))

    # Load the DIP vectors exactly once.  A second always block writing the
    # same dip regs would give them duplicate drivers.
    loads = [
        vast.Identifier('dip{}[{}] = {};'.format(i, j, inp))
        for i, dip in enumerate(dip_list)
        for j, inp in enumerate(dip)
    ]
    inst_list.append(vast.Always(posedge_clk(), vast.Block(loads)))

    self.ce = vast.ModuleDef("ce", None, portslist, inst_list)
def lower_fsm(fsm: FSM):
    """Lower *fsm* to a pyverilog ``ModuleDef`` named ``main``.

    The module has ports ``clk`` (input) and ``done`` (output).  It declares
    a fixed register file (``a``, ``b``, ``tmp``: 8 bits; ``_cond``: 1 bit),
    a state register ``fsm_reg`` and a ``done_out`` register driving
    ``done``.  All of these registers, including ``done_out``, are zeroed in
    an ``initial`` block.

    On every rising ``clk`` edge, a case statement on ``fsm_reg`` runs
    ``lower_action`` for each ``(state, action)`` in ``fsm.actions``.  The
    default branch holds every register and raises ``done_out``.  The done
    state is one past the largest action key.

    Raises:
        AssertionError: if ``fsm.start_state`` is not 0.
        ValueError: if ``fsm.actions`` is empty (from ``max``).
    """
    assert fsm.start_state == 0, "Starting state is not 0"
    zero = vast.IntConst(0)

    all_registers = [
        Register("a", 8),
        Register("b", 8),
        Register("tmp", 8),
        Register("_cond", 1),
    ]
    # 1-bit registers are declared without a width.
    register_defs = [
        vast.Reg(
            reg.name,
            vast.Width(vast.IntConst(reg.width - 1), zero)
            if reg.width - 1 != 0 else None)
        for reg in all_registers
    ]

    ports = vast.Portlist([
        vast.Ioport(vast.Input('clk')),
        # XXX(rachit): AST can't represent `output reg done`
        # so assign to a local register and use a wire.
        vast.Ioport(vast.Output('done')),
    ])

    done_state = max(fsm.actions.keys()) + 1
    done_reg = vast.Reg('done_out')
    hook_up_done = vast.Assign(
        vast.Lvalue(vast.Identifier('done')),
        vast.Rvalue(vast.Identifier('done_out')),
    )

    # State register; ceil(log2(done_state)) + 1 bits is enough to hold
    # every state 0..done_state (it can be one bit wider than needed).
    fsm_reg_size = int(math.ceil(math.log2(done_state))) + 1
    fsm_reg = vast.Reg(name="fsm_reg",
                       width=vast.Width(vast.IntConst(fsm_reg_size - 1), zero))

    reg_decl = register_defs + [fsm_reg]

    # Zero every register.  done_out is included so that `done` reads 0,
    # not X, until the FSM reaches the done state.
    init_names = [reg.name for reg in reg_decl] + [done_reg.name]
    inits = vast.Initial(
        vast.Block([
            vast.BlockingSubstitution(
                vast.Lvalue(vast.Identifier(name)),
                vast.Rvalue(vast.IntConst(0)),
            ) for name in init_names
        ]))

    # Default case: hold every register and raise done_out.
    done = vast.IntConst(done_state)
    default_case = vast.Case(cond=None,
                             statement=vast.Block([
                                 vast.NonblockingSubstitution(
                                     vast.Lvalue(vast.Identifier(reg.name)),
                                     vast.Rvalue(vast.Identifier(reg.name)),
                                 ) for reg in reg_decl
                             ] + [
                                 vast.NonblockingSubstitution(
                                     vast.Lvalue(vast.Identifier('done_out')),
                                     vast.Rvalue(vast.IntConst(1)))
                             ]))

    # One case item per FSM state.
    cases = [
        vast.Case([vast.IntConst(cond_val)],
                  lower_action(fsm_reg, done, action))
        for (cond_val, action) in fsm.actions.items()
    ]
    case_statement = vast.CaseStatement(comp=vast.Identifier(fsm_reg.name),
                                        caselist=cases + [default_case])
    always_ff = vast.Always(
        vast.SensList([vast.Sens(vast.Identifier('clk'), 'posedge')]),
        vast.Block([case_statement]))

    return vast.ModuleDef(name="main",
                          paramlist=vast.Paramlist([]),
                          portlist=ports,
                          items=reg_decl +
                          [done_reg, hook_up_done, inits, always_ff])