# Example no. 1
0
class DirectFF_sig(Unit):
    """Output ``o`` driven by a single internal signal ``r`` (``def_val=0``).

    On every rising edge of ``clk`` the signal ``r`` is assigned its own
    value, so (assuming ``def_val`` is its initial value) ``o`` never
    changes from 0.
    """

    def _declr(self):
        # Keep the declaration order (o, then clk): it may define the port
        # order of the generated HDL entity.
        self.o = Signal()._m()
        self.clk = Signal()

    def _impl(self):
        reg = self._sig("r", def_val=0)
        clk_rise = self.clk._onRisingEdge()
        # NOTE(review): the register is fed back to itself on purpose?
        # Looks like a serializer/test fixture for a "direct" flip-flop — confirm.
        If(clk_rise, reg(reg))
        self.o(reg)
# Example no. 2
0
class MaxPoolUnit(Unit):
    """
    Max-pooling unit over four ``width``-bit lanes packed into ``input``.

    Lane ``i`` is the slice ``input[(i + 1) * width:i * width]``. Lanes 0/1
    and 2/3 are reduced by combinational comparators, and the comparison of
    those two partial results is registered into ``output`` on a rising
    ``clk`` edge while ``en_pool`` is high. ``rst`` (checked outside the
    clock-edge condition) clears the registered result to 0.

    With ``binary=True`` the comparator is a bitwise OR (the max of unsigned
    bits) instead of a ``>`` comparison.

    .. hwt-schematic::
    """
    def __init__(self, width=16, binary=False, **kwargs):
        self.logger = logging.getLogger(self.__class__.__name__)
        self.width = width
        self.binary = binary
        self.top_entity = False
        # NOTE(review): _declr reads self.layer_id, which is not set here;
        # presumably print_info populates it from kwargs — confirm.
        print_info(self, **kwargs)
        super().__init__()

    def _declr(self):
        self.clk = Signal()
        self.rst = Signal()
        self.en_pool = Signal()
        self.input = VectSignal(self.width * 4)
        self.output = VectSignal(self.width)._m()

        name = f"MaxPoolUnitL{self.layer_id}"
        self._name = name
        self._hdl_module_name = name

    def __comparison(self, param_a, param_b, out):
        """Return a statement driving ``out`` with the larger of the two params."""
        return If(param_a > param_b, out(param_a)).Else(out(param_b))

    def __bin_comparison(self, param_a, param_b, out):
        """Return a statement driving ``out`` with the max of two binary values.

        For unsigned bits ``max(a, b) == a | b``. For 1-bit inputs this equals
        what ``__comparison`` yields (assuming the vectors compare as
        unsigned), and for wider vectors it is applied bit by bit.
        """
        # A conditional on ``a & b`` cannot express max: ``a & b`` is only set
        # when both bits are 1, so ``out`` would always receive ``b`` and the
        # case a=1, b=0 would produce 0.
        return out(param_a | param_b)

    def _impl(self):
        lane_t = Bits(bit_length=self.width, force_vector=True)
        inputs = [
            self._sig(name=f"input{i}", dtype=lane_t) for i in range(4)
        ]
        for i, lane in enumerate(inputs):
            lane(self.input[(i + 1) * self.width:i * self.width])

        first_pool0 = self._sig(name="first_pool0", dtype=lane_t)
        first_pool1 = self._sig(name="first_pool1", dtype=lane_t)
        pool_result = self._sig(name="pool_result", dtype=lane_t)

        comparison = self.__bin_comparison if self.binary else self.__comparison

        # First stage: combinational reduction of lane pairs.
        comparison(inputs[0], inputs[1], first_pool0)
        comparison(inputs[2], inputs[3], first_pool1)

        # Second stage: registered, enabled by en_pool, cleared by rst.
        If(self.rst, pool_result(0)).Else(
            If(
                self.clk._onRisingEdge(),
                If(self.en_pool,
                   comparison(first_pool0, first_pool1, pool_result)),
            ))

        self.output(pool_result)
# Example no. 3
0
class ConvUnit(Unit):
    """
    Multiply-accumulate unit: sums ``size`` products of input lanes and kernels.

    ``input`` carries ``size`` lanes of ``width`` bits. Lane ``i`` is fed to
    its own ``FixedPointMultiplier`` together with ``kernel_i``. Each product
    is registered while ``en_mult`` is high, and the registered products are
    summed by an adder tree whose result is registered into ``output`` while
    ``en_sum`` is high. ``rst`` clears both register stages to 0.

    .. hwt-schematic::
    """
    def __init__(self, size=9, width=16, **kwargs):
        self.logger = logging.getLogger(self.__class__.__name__)
        self.size = size
        self.width = width
        self.top_entity = False

        # NOTE(review): _declr reads layer_id, unit_id, channel_id, process_id
        # and log_level, none of which are set here; presumably print_info
        # populates them from kwargs — confirm.
        print_info(self, **kwargs)
        super().__init__()

    def _declr(self):
        self.clk = Signal()
        self.rst = Signal()
        self.en_mult = Signal()
        self.en_sum = Signal()
        self.input = VectSignal(self.width * self.size)
        self.output = VectSignal(self.width, signed=True)._m()

        for i in range(self.size):
            setattr(self, f"kernel_{i}", VectSignal(self.width))

        self.multiplier = HObjList(
            FixedPointMultiplier(
                width=self.width,
                layer_id=self.layer_id,
                unit_id=self.unit_id,
                channel_id=self.channel_id,
                process_id=self.process_id,
                pixel_id=i,
                log_level=self.log_level + 1,
            ) for i in range(self.size))

        name = f"ConvUnitL{self.layer_id}"
        self._name = name
        self._hdl_module_name = name

    def __calc_tree_adders(self, signal_list, signal_width):
        """Sum exactly nine signals with a fixed three-level adder tree.

        Kept for ``size == 9`` so the generated signal names
        (``first_sum_*``, ``second_sum_*``, ``third_sum``) stay stable.
        """
        first_sum = [
            self._sig(name=f"first_sum_{i}", dtype=signal_width)
            for i in range(4)
        ]
        first_sum[0](signal_list[0] + signal_list[1])
        first_sum[1](signal_list[2] + signal_list[3])
        first_sum[2](signal_list[4] + signal_list[5])
        first_sum[3](signal_list[6] + signal_list[7])

        second_sum = [
            self._sig(name=f"second_sum_{i}", dtype=signal_width)
            for i in range(2)
        ]
        second_sum[0](first_sum[0] + first_sum[1])
        second_sum[1](first_sum[2] + first_sum[3])

        third_sum = self._sig(name="third_sum", dtype=signal_width)
        # The odd ninth operand is folded into the last adder.
        third_sum(second_sum[0] + second_sum[1] + signal_list[8])
        return third_sum

    def __calc_balanced_adder_tree(self, signal_list, signal_width):
        """Sum any number of signals with a balanced tree of two-input adders.

        Each level adds adjacent pairs; an odd leftover operand is passed
        through to the next level unchanged. A single input is returned
        as-is, without creating any adder.
        """
        level_operands = list(signal_list)
        level = 0
        while len(level_operands) > 1:
            next_operands = []
            for j in range(0, len(level_operands) - 1, 2):
                partial = self._sig(name=f"tree_sum_{level}_{j // 2}",
                                    dtype=signal_width)
                partial(level_operands[j] + level_operands[j + 1])
                next_operands.append(partial)
            if len(level_operands) % 2:
                next_operands.append(level_operands[-1])
            level_operands = next_operands
            level += 1
        return level_operands[0]

    def _impl(self):
        signal_width = Bits(bit_length=self.width)
        product_list = [
            self._sig(name=f"product_{i}", dtype=signal_width)
            for i in range(self.size)
        ]

        for i, product in enumerate(product_list):
            multiplier = self.multiplier[i]
            multiplier.clk(self.clk)
            multiplier.rst(self.rst)
            multiplier.param_a(self.input[self.width * (i + 1):self.width * i])
            multiplier.param_b(getattr(self, f"kernel_{i}"))

            If(self.rst, product(0)).Else(
                If(
                    self.clk._onRisingEdge(),
                    If(self.en_mult, product(multiplier.product)),
                ))

        if self.size == 9:
            kernel_sum = self.__calc_tree_adders(product_list, signal_width)
        else:
            # Previously only product_list[0] reached the output for any
            # size != 9, silently dropping the other products. For size == 1
            # this still returns product_list[0] directly.
            kernel_sum = self.__calc_balanced_adder_tree(product_list,
                                                         signal_width)

        If(self.rst, self.output(0)).Else(
            If(
                self.clk._onRisingEdge(),
                If(self.en_sum, self.output(kernel_sum)),
            ))