Ejemplo n.º 1
0
 def visit_GenerateIf(self, node):
     """Lower a generate-if node into a ``vast.GenerateStatement``.

     The condition is visited with a fresh ``VerilogBlockingVisitor``. The
     true branch is built from ``node`` and the false branch from
     ``node.Else``; the false block is ``None`` when the else branch
     produces no items.
     """
     cond = VerilogBlockingVisitor().visit(node.cond)

     then_block = vast.Block(self._visit_Generate(node),
                             scope=node.true_scope)

     else_items = self._visit_Generate(node.Else)
     if else_items:
         else_block = vast.Block(else_items, scope=node.Else.false_scope)
     else:
         else_block = None

     if_stmt = vast.IfStatement(cond, then_block, else_block)
     return vast.GenerateStatement((if_stmt,))
Ejemplo n.º 2
0
def test():
    """Hand-build the ``top`` module AST, generate code and check it.

    The generated text is printed and then compared against the
    module-level ``expected`` string.
    """
    # Module header: parameter DATAWID = 32 and ports CLK, RST, led[7:0].
    params = vast.Paramlist([
        vast.Parameter('DATAWID', vast.Rvalue(vast.IntConst('32'))),
    ])
    ports = vast.Portlist([
        vast.Ioport(vast.Input('CLK')),
        vast.Ioport(vast.Input('RST')),
        vast.Ioport(vast.Output(
            'led',
            width=vast.Width(vast.IntConst('7'), vast.IntConst('0')))),
    ])

    # reg [DATAWID-1:0] count
    count_msb = vast.Minus(vast.Identifier('DATAWID'), vast.IntConst('1'))
    count_reg = vast.Reg(
        'count', width=vast.Width(count_msb, vast.IntConst('0')))

    # assign led = count[DATAWID-1:DATAWID-8]
    led_slice = vast.Partselect(
        vast.Identifier('count'),
        vast.Minus(vast.Identifier('DATAWID'), vast.IntConst('1')),
        vast.Minus(vast.Identifier('DATAWID'), vast.IntConst('8')))
    led_assign = vast.Assign(
        vast.Lvalue(vast.Identifier('led')), vast.Rvalue(led_slice))

    # Reset branch: count <= 0
    on_reset = vast.Block([
        vast.NonblockingSubstitution(
            vast.Lvalue(vast.Identifier('count')),
            vast.Rvalue(vast.IntConst('0'))),
    ])

    # Normal branch: count <= ((count + 1) * 2) + 1
    next_count = vast.Plus(
        vast.Times(
            vast.Plus(vast.Identifier('count'), vast.IntConst('1')),
            vast.IntConst('2')),
        vast.IntConst('1'))
    on_count = vast.Block([
        vast.NonblockingSubstitution(
            vast.Lvalue(vast.Identifier('count')),
            vast.Rvalue(next_count)),
    ])

    # always @(posedge CLK) with the RST-selected branches above
    always_body = vast.Block([
        vast.IfStatement(vast.Identifier('RST'), on_reset, on_count),
    ])
    clocked = vast.Always(
        vast.SensList([vast.Sens(vast.Identifier('CLK'), type='posedge')]),
        always_body)

    module = vast.ModuleDef(
        "top", params, ports, [count_reg, led_assign, clocked])

    generated = ASTCodeGenerator().visit(module)
    print(generated)

    assert expected == generated
Ejemplo n.º 3
0
    def umc_module(self, dip_list):
        """Assemble the ``umc`` module AST and store it on ``self.umc``.

        ``dip_list`` is iterated as a sequence of sequences: each outer
        element yields one ``dip<i>`` memory declaration and one
        ``dip_checker`` instantiation, and each inner element yields one
        assignment into that memory. Several items are emitted as raw
        Verilog text wrapped in ``vast.Identifier``.
        """
        ports = vast.Portlist([
            vast.Ioport(vast.Input('clk')),
            vast.Ioport(vast.Input('iv', width=self.inp_width)),
        ])

        # Instance "dg" of dip_generator; port names are left empty.
        dg_args = [
            vast.PortArg("", vast.Identifier(net))
            for net in ('clk', 'iv', 'k1', 'k2')
        ]
        dg = vast.Instance('dip_generator', 'dg', dg_args, "")
        items = [vast.InstanceList('dip_generator', "", [dg])]

        # One memory declaration per entry of dip_list.
        items.extend(
            vast.Identifier('reg [{}:0] dip{} [0:{}];'.format(
                self.input_size_msb, idx, len(seq) - 1))
            for idx, seq in enumerate(dip_list))

        # Clocked block that increments the cycle counter.
        counter_body = vast.Block([vast.Identifier('cycle <= cycle + 1;')])
        items.append(vast.Always(
            vast.SensList([vast.Sens(vast.Identifier('clk'), type='posedge')]),
            counter_body))

        # Block with sensitivity type 'all' holding the key constraints text.
        constraint_body = vast.Block([vast.Identifier(self.key_constraints)])
        items.append(vast.Always(
            vast.SensList([vast.Sens(None, type='all')]),
            constraint_body))

        # One dip_checker instantiation per entry of dip_list.
        items.extend(
            vast.Identifier(
                'dip_checker dc{} (clk, dip{}[cycle], k1, k2);'.format(
                    idx, idx))
            for idx, _ in enumerate(dip_list))

        # Memory contents are loaded inside a posedge-clk always block.
        loads = [
            vast.Identifier('dip{}[{}] <= {};'.format(idx, pos, value))
            for idx, seq in enumerate(dip_list)
            for pos, value in enumerate(seq)
        ]
        items.append(vast.Always(
            vast.SensList([vast.Sens(vast.Identifier('clk'), type='posedge')]),
            vast.Block(loads)))

        self.umc = vast.ModuleDef("umc", None, ports, items)
Ejemplo n.º 4
0
 def visit_Forever(self, node):
     """Lower a forever node.

     When ``self.for_verilator`` is truthy the loop is replaced by a single
     ``write`` system call with an empty string; otherwise the visited body
     is wrapped in a ``vast.ForeverStatement``.
     """
     if not self.for_verilator:
         visited = tuple(self.visit(stmt) for stmt in node.statement)
         body = self._optimize_block(vast.Block(visited))
         return vast.ForeverStatement(body)

     empty_write = vast.SystemCall('write', (vast.StringConst(''), ))
     return vast.SingleStatement(empty_write)
Ejemplo n.º 5
0
    def dip_chk_module(self):
        """Assemble the ``dip_checker`` module AST into ``self.dip_chk``.

        The module has ports ``clk``, ``iv``, ``k1`` and ``k2``, declares
        wires ``ov0``..``ov2``, includes ``self.org0``, ``self.obf1`` and
        ``self.obf2`` as items, and ends with a posedge-clk always block
        holding two ``assume`` lines (emitted as raw text).
        """
        ports = vast.Portlist([
            vast.Ioport(vast.Input('clk')),
            vast.Ioport(vast.Input('iv', width=self.inp_width)),
            vast.Ioport(vast.Input('k1', width=self.key_width)),
            vast.Ioport(vast.Input('k2', width=self.key_width)),
        ])

        items = [
            vast.Wire(wire_name, width=self.out_width)
            for wire_name in ('ov0', 'ov1', 'ov2')
        ]
        items += [self.org0, self.obf1, self.obf2]

        on_clk = vast.SensList(
            [vast.Sens(vast.Identifier('clk'), type='posedge')])
        assumptions = vast.Block([
            vast.Identifier('assume (ov0 == ov1);'),
            vast.Identifier('assume (ov0 == ov2);'),
        ])
        items.append(vast.Always(on_clk, assumptions))

        self.dip_chk = vast.ModuleDef("dip_checker", None, ports, items)
Ejemplo n.º 6
0
    def dip_gen_module_uc(self):
        """Assemble the assume-based ``dip_generator`` AST into ``self.dip_gen``.

        Per the original author's note, this variant is slower than
        ``dip_gen_module()`` because of the ``assume``.
        """
        ports = vast.Portlist([
            vast.Ioport(vast.Input('clk')),
            vast.Ioport(vast.Input('iv', width=self.inp_width)),
            vast.Ioport(vast.Input('k1', width=self.key_width)),
            vast.Ioport(vast.Input('k2', width=self.key_width)),
        ])

        items = [
            vast.Wire('ov1', width=self.out_width),
            vast.Wire('ov2', width=self.out_width),
        ]
        items += [self.obf1, self.obf2]

        # Property lines are emitted as raw text.
        checks = vast.Block([
            vast.Identifier('assume (k1 != k2);'),
            vast.Identifier('assert (ov1 == ov2);'),
        ])

        # Triggered on posedge clk; the original kept a disabled
        # Sens(None, type='all') alternative next to this line.
        on_clk = vast.SensList(
            [vast.Sens(vast.Identifier('clk'), type='posedge')])
        items.append(vast.Always(on_clk, checks))

        self.dip_gen = vast.ModuleDef("dip_generator", None, ports, items)
Ejemplo n.º 7
0
 def visit_Wait(self, node):
     """Lower a wait node into a ``vast.WaitStatement``.

     The body is ``None`` when ``node.statement`` is ``None``; otherwise it
     is the optimized block of visited statements.
     """
     cond = self.visit(node.condition)

     body = None
     if node.statement is not None:
         visited = tuple(self.visit(stmt) for stmt in node.statement)
         body = self._optimize_block(vast.Block(visited))

     return vast.WaitStatement(cond, body)
Ejemplo n.º 8
0
def gen_lat(path):
    """Build the ``lat`` module AST and hand it to ``write_verilog``.

    The module has ports ``en``, ``Q``, ``D`` and ``rst`` and a single
    always block, level-sensitive to ``en``, ``rst`` and ``D``. The result
    is written as ``'lat.v'`` with ``path`` forwarded to ``write_verilog``.
    """
    ports = vast.Portlist([
        vast.Ioport(vast.Input('en')),
        vast.Ioport(vast.Output('Q')),
        vast.Ioport(vast.Input('D')),
        vast.Ioport(vast.Input('rst')),
    ])

    # Declaration of Q is emitted as raw text.
    q_decl = vast.Identifier('reg Q = 0;')

    triggers = vast.SensList([
        vast.Sens(vast.Identifier(signal), type='level')
        for signal in ('en', 'rst', 'D')
    ])

    load_q = vast.NonblockingSubstitution(
        vast.Lvalue(vast.Identifier('Q')),
        vast.Rvalue(vast.Identifier('D')))

    # if (rst) Q <= 0; else if (en) Q <= D;
    when_enabled = vast.IfStatement(vast.Identifier('en'), load_q, None)
    # The trailing None is a fourth positional argument, kept exactly as in
    # the original call (presumably the lineno slot -- confirm against vast).
    reset_or_load = vast.IfStatement(
        vast.Identifier('rst'), vast.Identifier('Q <= 0;'),
        when_enabled, None)

    always = vast.Always(triggers, vast.Block([reset_or_load]))

    module = vast.ModuleDef("lat", None, ports, [q_decl, always])
    write_verilog(module, 'lat.v', path)
Ejemplo n.º 9
0
 def visit_Task(self, node):
     """Lower a task node into a ``vast.Task``.

     The task body is the I/O declarations (the ``.first`` of each visited
     I/O variable), then the visited local variables, then one optimized
     block of the statements visited with ``self.blocking_visitor``.
     """
     task_name = node.name
     io_decls = [self.visit(var).first for var in node.io_variable.values()]
     local_decls = [self.visit(var) for var in node.variable.values()]

     visited = tuple(
         self.blocking_visitor.visit(stmt) for stmt in node.statement)
     body = self._optimize_block(vast.Block(visited))

     return vast.Task(task_name, io_decls + local_decls + [body])
Ejemplo n.º 10
0
    def fk_module(self, dip_list, skip_cycles, equal_keys):
        """Assemble the port-less ``ce`` module AST into ``self.ce``.

        Per the original author's note, this module finds the correct keys
        after termination.

        ``dip_list`` is iterated as a sequence of sequences (one ``dip<i>``
        memory and one ``dip_checker`` line per outer element).
        ``equal_keys`` selects between ``assume (k1 == k2);`` and
        ``assume (k1 != k2);``. ``skip_cycles`` is interpolated into the
        ``if (cycle == ...)`` line that precedes ``cover(1);``.
        """
        # One memory declaration per entry of dip_list (raw text).
        items = [
            vast.Identifier('reg [{}:0] dip{} [0:{}];'.format(
                self.input_size_msb, idx, len(seq) - 1))
            for idx, seq in enumerate(dip_list)
        ]

        # Clocked block: key assumption, cycle counter and cover point.
        if equal_keys:
            key_assumption = 'assume (k1 == k2);'
        else:
            key_assumption = 'assume (k1 != k2);'
        clocked_body = vast.Block([
            vast.Identifier(key_assumption),
            vast.Identifier('cycle <= cycle + 1;'),
            vast.Identifier('if (cycle == {})'.format(skip_cycles)),
            vast.Identifier('  cover(1);'),
        ])
        items.append(vast.Always(
            vast.SensList([vast.Sens(vast.Identifier('clk'), type='posedge')]),
            clocked_body))

        # Block with sensitivity type 'all' holding the key constraints text.
        constraint_body = vast.Block([vast.Identifier(self.key_constraints)])
        items.append(vast.Always(
            vast.SensList([vast.Sens(None, type='all')]),
            constraint_body))

        # One dip_checker instantiation per entry of dip_list.
        items.extend(
            vast.Identifier(
                'dip_checker dc{} (clk, dip{}[cycle], k1, k2);'.format(
                    idx, idx))
            for idx, _ in enumerate(dip_list))

        # Memory contents are assigned inside a posedge-clk always block.
        on_clk = vast.SensList(
            [vast.Sens(vast.Identifier('clk'), type='posedge')])
        loads = [
            vast.Identifier('dip{}[{}] = {};'.format(idx, pos, value))
            for idx, seq in enumerate(dip_list)
            for pos, value in enumerate(seq)
        ]
        items.append(vast.Always(on_clk, vast.Block(loads)))

        self.ce = vast.ModuleDef("ce", None, vast.Portlist([]), items)
Ejemplo n.º 11
0
 def visit_Initial(self, node):
     """Lower an initial node into a ``vast.Initial``.

     ``self.blocking_visitor.in_initial`` is set to True while the
     statements are visited and reset to False afterwards. The reset runs
     in a ``finally`` clause so that an exception raised while visiting a
     statement cannot leave the shared blocking visitor stuck with
     ``in_initial`` still True.
     """
     self.blocking_visitor.in_initial = True
     try:
         statement = self._optimize_block(
             vast.Block(
                 tuple([self.blocking_visitor.visit(s)
                        for s in node.statement])))
     finally:
         # Clear the flag on both the success and the error path.
         self.blocking_visitor.in_initial = False
     return vast.Initial(statement)
Ejemplo n.º 12
0
 def visit_Always(self, node):
     """Lower an always node into a ``vast.Always``.

     Each entry of ``node.sensitivity`` that is a ``vtypes.Sensitive`` is
     visited directly; any other entry is visited with
     ``self.always_visitor`` and wrapped in a ``'level'`` ``vast.Sens``.
     A falsy sensitivity produces a single ``Sens(None, 'all')``.
     """
     if node.sensitivity:
         sens = tuple(
             self.visit(entry) if isinstance(entry, vtypes.Sensitive)
             else vast.Sens(self.always_visitor.visit(entry), 'level')
             for entry in node.sensitivity)
     else:
         sens = (vast.Sens(None, 'all'),)

     visited = tuple(
         self.always_visitor.visit(stmt) for stmt in node.statement)
     body = self._optimize_block(vast.Block(visited))

     return vast.Always(vast.SensList(sens), body)
Ejemplo n.º 13
0
 def visit_GenerateFor(self, node):
     """Lower a generate-for node into a ``vast.GenerateStatement``.

     ``node.pre``, ``node.cond`` and ``node.post`` are visited, in that
     order, with one fresh ``VerilogBlockingVisitor``; the loop body is a
     block of the generated items scoped by ``node.scope``.
     """
     header_visitor = VerilogBlockingVisitor()
     init = header_visitor.visit(node.pre)
     test = header_visitor.visit(node.cond)
     step = header_visitor.visit(node.post)

     body = vast.Block(self._visit_Generate(node), scope=node.scope)

     for_stmt = vast.ForStatement(init, test, step, body)
     return vast.GenerateStatement((for_stmt,))
Ejemplo n.º 14
0
def lower_action(fsm_reg: vast.Reg, done_stage: vast.IntConst, action: Action):
    """Lower one FSM action into a ``vast.Block``.

    The block holds ``lower_update(...)`` of each entry in
    ``action.updates`` followed by a non-blocking assignment of
    ``lower_next(done_stage, action.next_state)`` to the FSM register.
    """
    statements = [lower_update(update) for update in action.updates]

    # The write that advances the FSM register goes last in the block.
    statements.append(
        vast.NonblockingSubstitution(
            vast.Lvalue(vast.Identifier(fsm_reg.name)),
            lower_next(done_stage, action.next_state),
        ))
    return vast.Block(statements)
Ejemplo n.º 15
0
 def _optimize_block(self, node):
     """Flatten directly nested, unscoped blocks by one level.

     Anything that is not a ``vast.Block`` is returned unchanged. For a
     block, each child that is itself a ``vast.Block`` with ``scope is
     None`` is replaced by its own statements; all other children are
     kept as they are. A new ``vast.Block`` is returned.
     """
     if isinstance(node, vast.Block):
         flattened = []
         for stmt in node.statements:
             is_unscoped_block = (isinstance(stmt, vast.Block)
                                  and stmt.scope is None)
             if is_unscoped_block:
                 flattened += stmt.statements
             else:
                 flattened.append(stmt)
         # NOTE(review): the new Block is built without a scope argument, so
         # a scope on the incoming node is not forwarded -- confirm intended.
         return vast.Block(tuple(flattened))
     return node
Ejemplo n.º 16
0
 def visit_Function(self, node):
     """Lower a function node into a ``vast.Function``.

     The return width is ``None`` when ``node.width`` is ``None``; it is
     built from ``width_msb``/``width_lsb`` when both are set, and from
     ``node.width`` otherwise. The body is the I/O declarations, the local
     variables, then one optimized block of the statements.
     """
     func_name = node.name

     if node.width is None:
         return_width = None
     elif node.width_msb is not None and node.width_lsb is not None:
         return_width = self.make_width_msb_lsb(node.width_msb,
                                                node.width_lsb)
     else:
         return_width = self.make_width(node.width)

     io_decls = [self.visit(var).first for var in node.io_variable.values()]
     local_decls = [self.visit(var) for var in node.variable.values()]

     visited = tuple(
         self.blocking_visitor.visit(stmt) for stmt in node.statement)
     body = self._optimize_block(vast.Block(visited))

     return vast.Function(func_name, return_width,
                          io_decls + local_decls + [body])
Ejemplo n.º 17
0
 def visit_Function(self, node):
     """Lower a function node into a ``vast.Function``.

     The return width comes from ``self.make_width(node)``. The body is the
     I/O declarations (the ``.first`` of each visited I/O variable), the
     visited local variables, then one optimized block of the statements
     visited with ``self.blocking_visitor``.
     """
     func_name = node.name
     return_width = self.make_width(node)

     io_decls = [self.visit(var).first for var in node.io_variable.values()]
     local_decls = [self.visit(var) for var in node.variable.values()]

     visited = tuple(
         self.blocking_visitor.visit(stmt) for stmt in node.statement)
     body = self._optimize_block(vast.Block(visited))

     return vast.Function(func_name, return_width,
                          io_decls + local_decls + [body])
Ejemplo n.º 18
0
def gen_dff(path):
    """Build the ``dff`` module AST and hand it to ``write_verilog``.

    The module has ports ``clk``, ``Q`` and ``D`` and one posedge-clk
    always block that assigns ``D`` to ``Q`` (non-blocking). The result is
    written as ``'dff.v'`` with ``path`` forwarded to ``write_verilog``.
    """
    ports = vast.Portlist([
        vast.Ioport(vast.Input('clk')),
        vast.Ioport(vast.Output('Q')),
        vast.Ioport(vast.Input('D')),
    ])

    # Declaration of Q is emitted as raw text.
    q_decl = vast.Identifier('reg Q = 0;')

    on_clk = vast.SensList(
        [vast.Sens(vast.Identifier('clk'), type='posedge')])

    load_q = vast.NonblockingSubstitution(
        vast.Lvalue(vast.Identifier('Q')),
        vast.Rvalue(vast.Identifier('D')))
    always = vast.Always(on_clk, vast.Block([load_q]))

    module = vast.ModuleDef("dff", None, ports, [q_decl, always])
    write_verilog(module, 'dff.v', path)
Ejemplo n.º 19
0
    def dip_gen_module(self):
        """Assemble the ``dip_generator`` module AST into ``self.dip_gen``.

        The original header comment describes this as the more complicated
        counterpart of ``dip_gen_module_uc`` that uses ``if`` in place of
        ``assume`` (faster, but not usable for uc termination). The body as
        written emits the ``assume``/``assert`` pair instead.
        """
        ports = vast.Portlist([
            vast.Ioport(vast.Input('clk')),
            vast.Ioport(vast.Input('iv', width=self.inp_width)),
            vast.Ioport(vast.Input('k1', width=self.key_width)),
            vast.Ioport(vast.Input('k2', width=self.key_width)),
        ])

        items = [
            vast.Wire('ov1', width=self.out_width),
            vast.Wire('ov2', width=self.out_width),
        ]
        items += [self.obf1, self.obf2]

        # TODO: changed for latch locking -- a nested-IfStatement form
        # (k1 != k2, then ov1 != ov2, then the assert) was disabled in
        # favour of these two raw property lines.
        checks = vast.Block([
            vast.Identifier('assume (k1 != k2);'),
            vast.Identifier('assert (ov1 == ov2);'),
        ])

        # TODO: posedge in case of latch; the Sens(None, type='all')
        # alternative is currently disabled.
        on_clk = vast.SensList(
            [vast.Sens(vast.Identifier('clk'), type='posedge')])
        items.append(vast.Always(on_clk, checks))

        self.dip_gen = vast.ModuleDef("dip_generator", None, ports, items)
Ejemplo n.º 20
0
 def visit_Forever(self, node):
     """Lower a forever node into a ``vast.ForeverStatement``.

     The body is the optimized block of the visited statements.
     """
     visited = tuple(self.visit(stmt) for stmt in node.statement)
     body = self._optimize_block(vast.Block(visited))
     return vast.ForeverStatement(body)
Ejemplo n.º 21
0
    def ce_module(self, module_name, dip_list):
        """Assemble the ``ce`` checking module AST and store it on ``self.ce``.

        ``module_name`` is used to derive the name of the instantiated
        module (``module_name + '_ce'``); the generated module itself is
        always named ``"ce"``. ``dip_list`` is iterated as a sequence of
        sequences: one ``dip<i>`` memory and one ``dip_checker`` line per
        outer element, one assignment per inner element.

        Many items are emitted as raw Verilog text wrapped in
        ``vast.Identifier``.
        """
        # module for checking ce
        state_width = vast.Width(vast.IntConst(self.state_size_msb),
                                 vast.IntConst('0'))
        ce_name = module_name + '_ce'
        # Number of unrolled steps; fixed to 1, so each loop below runs once.
        step = 1

        # module port list
        portslist = []
        portslist.append(vast.Ioport(vast.Input('clk')))
        portslist.append(
            vast.Ioport(vast.Input('ce_iv_s0', width=self.inp_width)))
        portslist.append(
            vast.Ioport(vast.Input('ce_state_s0', width=state_width)))
        portslist = vast.Portlist(portslist)

        # declare the state/output wires and tie both step-0 states to the
        # ce_state_s0 input
        inst_list = []
        inst_list.append(vast.Wire('ce_state1_s0', width=state_width))
        inst_list.append(vast.Wire('ce_state2_s0', width=state_width))
        inst_list.append(vast.Wire('ce_state1_s1', width=state_width))
        inst_list.append(vast.Wire('ce_state2_s1', width=state_width))

        inst_list.append(vast.Wire('ce_ov1_s0', width=self.out_width))
        inst_list.append(vast.Wire('ce_ov2_s0', width=self.out_width))

        inst_list.append(vast.Identifier('assign ce_state1_s0 = ce_state_s0;'))
        inst_list.append(vast.Identifier('assign ce_state2_s0 = ce_state_s0;'))

        for s in range(step):
            # create instances for obf1_ce, obf2_ce
            # First copy: driven by k1, state ce_state1_s<s> -> ce_state1_s<s+1>.
            ports = create_ports('ce_iv_s' + str(s), 'ce_ov1_s' + str(s),
                                 self.orcl_cir)
            key_ports = [vast.PortArg("", vast.Identifier('k1'))]
            state_ports = [
                vast.PortArg("", vast.Identifier('ce_state1_s{}'.format(s)))
            ]
            nstate_ports = [
                vast.PortArg("",
                             vast.Identifier('ce_state1_s{}'.format(s + 1)))
            ]
            inst = vast.Instance(
                ce_name, "obf1_ce_s{}".format(s),
                ports + key_ports + state_ports + nstate_ports, "")
            obf1_ce = vast.InstanceList(ce_name, "", [inst])

            # Second copy: driven by k2, state ce_state2_s<s> -> ce_state2_s<s+1>.
            ports = create_ports('ce_iv_s' + str(s), 'ce_ov2_s' + str(s),
                                 self.orcl_cir)
            state_ports = [
                vast.PortArg("", vast.Identifier('ce_state2_s{}'.format(s)))
            ]
            key_ports = [vast.PortArg("", vast.Identifier('k2'))]
            nstate_ports = [
                vast.PortArg("",
                             vast.Identifier('ce_state2_s{}'.format(s + 1)))
            ]
            inst = vast.Instance(
                ce_name, 'obf2_ce_s{}'.format(s),
                ports + key_ports + state_ports + nstate_ports, "")
            obf2_ce = vast.InstanceList(ce_name, "", [inst])

            inst_list.extend([obf1_ce, obf2_ce])

        # sensitivity list for the clocked blocks; this `senslist` object is
        # reused by the last two always blocks appended below
        sens = vast.Sens(vast.Identifier('clk'), type='posedge')
        senslist = vast.SensList([sens])

        # assertions that both copies agree on outputs and next state
        blocks = []
        for s in range(step):
            blocks.append(
                vast.Identifier('assert (ce_ov1_s{} == ce_ov2_s{});'.format(
                    s, s)))
            blocks.append(
                vast.Identifier(
                    'assert (ce_state1_s{} == ce_state2_s{});'.format(
                        s + 1, s + 1)))

        # one memory declaration per entry of dip_list
        if len(dip_list) > 0:
            for i, dip in enumerate(dip_list):
                inst_list.append(
                    vast.Identifier('reg [{}:0] dip{} [0:{}];'.format(
                        self.input_size_msb, i,
                        len(dip) - 1)))

        # add always block: the assertions above plus the cycle increment
        # NOTE(review): no declaration of `cycle` (or of k1/k2) is emitted by
        # this method -- confirm they are provided elsewhere.
        blocks.append(vast.Identifier('cycle <= cycle + 1;'))
        statement = vast.Block(blocks)

        sens = vast.Sens(vast.Identifier('clk'), type='posedge')
        inst_list.append(vast.Always(vast.SensList([sens]), statement))

        # one dip_checker instantiation per entry of dip_list
        if len(dip_list) > 0:
            for i, dip in enumerate(dip_list):
                inst_list.append(
                    vast.Identifier(
                        'dip_checker dc{} (clk, dip{}[cycle], k1, k2);'.format(
                            i, i)))

        # add always block that assigns the dip memory contents
        blocks = []
        for i, dip in enumerate(dip_list):
            for j, inp in enumerate(dip):
                blocks.append(
                    vast.Identifier('dip{}[{}] = {};'.format(i, j, inp)))
        statement = vast.Block(blocks)
        inst_list.append(vast.Always(senslist, statement))

        # for s in range(step):
        #     blocks.append(vast.Identifier('assume (ce_state1_s{} != ce_state2_s{});'.format(s+1, s+1)))

        # NOTE(review): with the loop above disabled, `blocks` is unchanged,
        # so this appends a second always block with the same statements as
        # the one just added -- looks like leftover code; confirm.
        statement = vast.Block(blocks)
        inst_list.append(vast.Always(senslist, statement))
        self.ce = vast.ModuleDef("ce", None, portslist, inst_list)
Ejemplo n.º 22
0
 def visit_list(self, node):
     """Visit every element of ``node`` and return them as one optimized block."""
     visited = tuple(self.visit(item) for item in node)
     return self._optimize_block(vast.Block(visited))
Ejemplo n.º 23
0
def lower_fsm(fsm: FSM):
    """Lower ``fsm`` into a ``vast.ModuleDef`` named ``main``.

    The module has ports ``clk`` and ``done``, the fixed data registers
    ``a``, ``b``, ``tmp`` and ``_cond``, an ``fsm_reg`` state register, an
    initial block zeroing those registers, and one posedge-clk always block
    with a case statement over ``fsm_reg``: one case per entry of
    ``fsm.actions`` plus a default case that holds every register and sets
    ``done_out`` to 1.

    Asserts that ``fsm.start_state`` is 0.
    """
    assert fsm.start_state == 0, "Starting state is not 0"

    # Shared LSB constant used by every Width built below.
    zero = vast.IntConst(0)

    def declare(reg):
        # A register whose width is 1 is declared without an explicit Width.
        if reg.width - 1 != 0:
            width = vast.Width(vast.IntConst(reg.width - 1), zero)
        else:
            width = None
        return vast.Reg(reg.name, width)

    data_regs = [
        declare(reg)
        for reg in (
            Register("a", 8),
            Register("b", 8),
            Register("tmp", 8),
            Register("_cond", 1),
        )
    ]

    ports = vast.Portlist([
        vast.Ioport(vast.Input('clk')),
        # XXX(rachit): AST can't represent `output reg done`
        # so assign to a local register and use a wire.
        vast.Ioport(vast.Output('done')),
    ])

    # The "done" state is one past the largest action key.
    done_state = max(fsm.actions.keys()) + 1
    done_reg = vast.Reg('done_out')
    done_wire_assign = vast.Assign(
        vast.Lvalue(vast.Identifier('done')),
        vast.Rvalue(vast.Identifier('done_out')),
    )

    # State register sized from done_state.
    state_bits = int(math.ceil(math.log2(done_state))) + 1
    fsm_reg = vast.Reg(
        name="fsm_reg",
        width=vast.Width(vast.IntConst(state_bits - 1), zero))

    all_regs = data_regs + [fsm_reg]

    # initial block: every register in all_regs is set to 0.
    zero_inits = [
        vast.BlockingSubstitution(
            vast.Lvalue(vast.Identifier(reg.name)),
            vast.Rvalue(vast.IntConst(0)),
        )
        for reg in all_regs
    ]
    initial = vast.Initial(vast.Block(zero_inits))

    # default case: hold every register and raise done_out.
    hold_regs = [
        vast.NonblockingSubstitution(
            vast.Lvalue(vast.Identifier(reg.name)),
            vast.Rvalue(vast.Identifier(reg.name)),
        )
        for reg in all_regs
    ]
    raise_done = vast.NonblockingSubstitution(
        vast.Lvalue(vast.Identifier('done_out')),
        vast.Rvalue(vast.IntConst(1)))
    done = vast.IntConst(done_state)
    default_case = vast.Case(cond=None,
                             statement=vast.Block(hold_regs + [raise_done]))

    # One case per FSM action, keyed by its state number.
    action_cases = [
        vast.Case([vast.IntConst(state)],
                  lower_action(fsm_reg, done, action))
        for state, action in fsm.actions.items()
    ]
    dispatch = vast.CaseStatement(comp=vast.Identifier(fsm_reg.name),
                                  caselist=action_cases + [default_case])
    clocked = vast.Always(
        vast.SensList([vast.Sens(vast.Identifier('clk'), 'posedge')]),
        vast.Block([dispatch]))

    return vast.ModuleDef(
        name="main",
        paramlist=vast.Paramlist([]),
        portlist=ports,
        items=all_regs + [done_reg, done_wire_assign, initial, clocked])