Ejemplo n.º 1
0
class Arbiter(Module):
    """Select one incoming message per cycle and route it by destination PE.

    The network input wins when it is valid and its ``roundpar`` equals
    ``network_round``; otherwise the local input is used when its
    ``roundpar`` equals ``network_round``.  The selected message is handed to
    the embedded ``Barriercounter`` when ``dest_pe`` equals ``pe_id`` and to
    the network output in every other case.  While ``start_message.select``
    is asserted, ``start_message`` drives ``apply_interface_out`` instead of
    the barrier counter.
    """

    def __init__(self, pe_id, config):
        layout = config.addresslayout

        # Network-facing and PE-facing message ports.
        self.network_interface_in = NetworkInterface(name="left_in", **layout.get_params())
        self.network_interface_out = NetworkInterface(name="right_out", **layout.get_params())
        self.local_interface_in = NetworkInterface(name="local_in", **layout.get_params())
        local_interface_out = NetworkInterface(name="local_out", **layout.get_params())

        # Output towards apply, plus the override port used to inject a
        # start message.
        self.apply_interface_out = ApplyInterface(**layout.get_params())
        self.start_message = ApplyInterface(**layout.get_params())
        self.start_message.select = Signal()

        self.network_round = Signal(layout.channel_bits)
        self.round_accepting = Signal(layout.channel_bits)

        # Internal link carrying whichever input was selected.
        mux_interface = NetworkInterface(**layout.get_params())

        inject = Signal()
        net_in = self.network_interface_in
        local_in = self.local_interface_in

        # Input selection: network input first, local input second.
        self.comb += inject.eq(local_in.msg.roundpar == self.network_round)
        network_match = net_in.valid & (net_in.msg.roundpar == self.network_round)
        self.comb += If(
            network_match,
            net_in.connect(mux_interface),
        ).Elif(
            inject,
            local_in.connect(mux_interface),
        )

        # Destination routing: keep the message here or pass it on.
        addressed_here = mux_interface.dest_pe == pe_id
        self.comb += If(
            addressed_here,
            mux_interface.connect(local_interface_out),
        ).Else(
            mux_interface.connect(self.network_interface_out),
        )

        self.submodules.barriercounter = Barriercounter(config)
        counter_in = self.barriercounter.apply_interface_in

        # Hand locally addressed messages to the barrier counter.
        self.comb += local_interface_out.msg.connect(counter_in.msg)
        self.comb += counter_in.valid.eq(local_interface_out.valid)
        self.comb += local_interface_out.ack.eq(counter_in.ack)

        # The start message overrides the regular output when selected.
        self.comb += If(
            self.start_message.select,
            self.start_message.connect(self.apply_interface_out),
        ).Else(
            self.barriercounter.apply_interface_out.connect(self.apply_interface_out),
        )
        self.comb += self.round_accepting.eq(self.barriercounter.round_accepting)

    def gen_selfcheck(self, tb):
        """Simulation generator that performs no checks; it yields once and ends."""
        yield
Ejemplo n.º 2
0
class Arbiter(Module):
    """Pick one input message per cycle and steer it to its destination.

    Inputs are the network port (``network_interface_in``) and the local port
    (``local_interface_in``).  The network port is chosen when it is valid and
    its ``roundpar`` equals ``network_round``; failing that, the local port is
    chosen when its ``roundpar`` equals ``network_round``.  Messages whose
    ``dest_pe`` equals ``pe_id`` go to the embedded ``Barriercounter``; all
    others leave through ``network_interface_out``.  ``start_message`` takes
    over ``apply_interface_out`` while its ``select`` signal is high.
    """

    def __init__(self, pe_id, config):
        addresslayout = config.addresslayout

        self.network_interface_in = NetworkInterface(
            name="left_in", **addresslayout.get_params()
        )
        self.network_interface_out = NetworkInterface(
            name="right_out", **addresslayout.get_params()
        )
        self.local_interface_in = NetworkInterface(
            name="local_in", **addresslayout.get_params()
        )
        local_interface_out = NetworkInterface(
            name="local_out", **addresslayout.get_params()
        )

        self.apply_interface_out = ApplyInterface(**addresslayout.get_params())
        self.start_message = ApplyInterface(**addresslayout.get_params())
        self.start_message.select = Signal()

        self.network_round = Signal(addresslayout.channel_bits)
        self.round_accepting = Signal(addresslayout.channel_bits)

        mux_interface = NetworkInterface(**addresslayout.get_params())

        inject = Signal()

        # All combinatorial statements are collected in order and attached to
        # the module in one go at the end of the constructor.
        statements = []

        # 1) choose the message source
        statements.append(
            inject.eq(self.local_interface_in.msg.roundpar == self.network_round)
        )
        statements.append(
            If(
                self.network_interface_in.valid
                & (self.network_interface_in.msg.roundpar == self.network_round),
                self.network_interface_in.connect(mux_interface),
            ).Elif(
                inject,
                self.local_interface_in.connect(mux_interface),
            )
        )

        # 2) choose the message sink
        statements.append(
            If(
                mux_interface.dest_pe == pe_id,
                mux_interface.connect(local_interface_out),
            ).Else(
                mux_interface.connect(self.network_interface_out),
            )
        )

        self.submodules.barriercounter = Barriercounter(config)
        barriercounter_in = self.barriercounter.apply_interface_in

        # 3) feed the barrier counter and drive the apply-side output
        statements.append(local_interface_out.msg.connect(barriercounter_in.msg))
        statements.append(barriercounter_in.valid.eq(local_interface_out.valid))
        statements.append(local_interface_out.ack.eq(barriercounter_in.ack))
        statements.append(
            If(
                self.start_message.select,
                self.start_message.connect(self.apply_interface_out),
            ).Else(
                self.barriercounter.apply_interface_out.connect(
                    self.apply_interface_out
                ),
            )
        )
        statements.append(
            self.round_accepting.eq(self.barriercounter.round_accepting)
        )

        self.comb += statements

    def gen_selfcheck(self, tb):
        """No-op simulation generator: yields a single time and finishes."""
        yield
Ejemplo n.º 3
0
class Arbiter(Module):
    """Feed incoming apply messages through a ``Barriercounter``.

    ``apply_interface_in`` is wired to the barrier counter's input and the
    counter's output drives ``apply_interface_out``, unless
    ``start_message.select`` is asserted, in which case ``start_message``
    drives the output instead.  ``current_round`` mirrors the barrier
    counter's ``round_accepting``.

    Args:
        pe_id: identifier of the owning PE; stored and used for the name of
            the simulation logger.
        config: configuration object providing ``addresslayout``.
    """

    def __init__(self, pe_id, config):
        addresslayout = config.addresslayout
        self.pe_id = pe_id

        # input (n channels)
        self.apply_interface_in = ApplyInterface(name="arbiter_in", **addresslayout.get_params())

        # output
        self.apply_interface_out = ApplyInterface(name="arbiter_out", **addresslayout.get_params())

        # input override for injecting the message starting the computation
        self.start_message = ApplyInterface(name="start_message", **addresslayout.get_params())
        self.start_message.select = Signal()

        self.submodules.barriercounter = Barriercounter(config)
        self.current_round = Signal(addresslayout.channel_bits)

        self.comb += [
            self.barriercounter.apply_interface_in.msg.raw_bits().eq(self.apply_interface_in.msg.raw_bits()),
            self.barriercounter.apply_interface_in.valid.eq(self.apply_interface_in.valid),
            self.apply_interface_in.ack.eq(self.barriercounter.apply_interface_in.ack),
            self.current_round.eq(self.barriercounter.round_accepting)
        ]

        # choose between init and regular message channel
        self.comb += \
            If(self.start_message.select,
                self.start_message.connect(self.apply_interface_out)
            ).Else(
                self.barriercounter.apply_interface_out.connect(self.apply_interface_out)
            )

    def gen_selfcheck(self, tb):
        """Simulation generator that watches the messages handed to apply.

        Runs until ``tb.global_inactive`` reads true.  Every accepted
        (valid and acked) barrier increments a local level counter; every
        accepted non-barrier message has its ``roundpar`` compared against
        that counter and a warning is logged when the check trips.
        """
        logger = logging.getLogger("sim.arbiter" + str(self.pe_id))
        level = 0
        num_cycles = 0
        out = self.apply_interface_out

        while not (yield tb.global_inactive):
            num_cycles += 1

            if (yield out.valid) and (yield out.ack):
                if (yield out.msg.barrier):
                    level += 1
                    logger.debug("{}: Barrier passed to apply".format(num_cycles))
                else:
                    # Sample the field once and reuse it for both the check
                    # and the log message.
                    roundpar = (yield out.msg.roundpar)
                    # NOTE(review): the warning fires when the parities are
                    # EQUAL although the text says "does not match";
                    # presumably roundpar is offset by one from the level
                    # counted here -- confirm against the sender side.
                    if level % 2 == roundpar:
                        logger.warning("{}: received message's parity ({}) does not match current round ({})".format(num_cycles, roundpar, level))
            yield
class Barriercounter(Module):
    """Count per-PE updates and barriers and emit one combined barrier.

    Messages arriving on ``apply_interface_in`` are buffered in a small FIFO
    and forwarded to ``apply_interface_out``.  Barrier messages are not
    forwarded individually: each one marks its sending PE as "barrier seen"
    and records, from its ``dest_id`` field, how many updates that PE expects
    to have delivered.  Once a barrier has been seen from every PE and the
    per-PE update counts equal the expected counts, a single barrier is sent
    downstream, ``round_accepting`` advances and all bookkeeping is cleared.
    """

    def __init__(self, config):
        """Build the interfaces, bookkeeping registers and the control FSM.

        Args:
            config: configuration object; ``config.addresslayout`` supplies
                the interface parameters, ``num_pe``, ``nodeidsize``,
                ``channel_bits``, ``num_channels`` and ``pe_adr``.
        """
        self.apply_interface_in = ApplyInterface(
            name="barriercounter_in", **config.addresslayout.get_params())
        self.apply_interface_out = ApplyInterface(
            name="barriercounter_out", **config.addresslayout.get_params())
        # Round counter, advanced (with wrap-around) each time a combined
        # barrier is passed on.
        self.round_accepting = Signal(config.addresslayout.channel_bits)

        # Two-entry input buffer between the external interface and the FSM.
        apply_interface_in_fifo = InterfaceFIFO(
            layout=self.apply_interface_in.layout, depth=2)
        self.submodules += apply_interface_in_fifo
        self.comb += self.apply_interface_in.connect(
            apply_interface_in_fifo.din)

        num_pe = config.addresslayout.num_pe

        # Per-PE bookkeeping:
        #   barrier_from_pe[i]      - a barrier from PE i has been consumed
        #   num_from_pe[i]          - non-barrier messages counted from PE i
        #   num_expected_from_pe[i] - count announced by PE i's barrier
        #   all_from_pe[i]          - counted == expected for PE i
        self.barrier_from_pe = Array(Signal() for _ in range(num_pe))
        self.num_from_pe = Array(
            Signal(config.addresslayout.nodeidsize) for _ in range(num_pe))
        self.num_expected_from_pe = Array(
            Signal(config.addresslayout.nodeidsize) for _ in range(num_pe))
        self.all_from_pe = Array(Signal() for _ in range(num_pe))
        self.all_messages_recvd = Signal()
        self.all_barriers_recvd = Signal()

        # AND-reduce the per-PE flags into two global conditions.
        self.comb += [
            self.all_barriers_recvd.eq(reduce(and_, self.barrier_from_pe)),
            self.all_messages_recvd.eq(reduce(and_, self.all_from_pe)),
        ]

        self.comb += [
            self.all_from_pe[i].eq(
                self.num_from_pe[i] == self.num_expected_from_pe[i])
            for i in range(num_pe)
        ]

        # Register copied into the halt field of the outgoing barrier; it is
        # set when a barrier is passed and cleared by any incoming barrier
        # whose halt field is not set.
        halt = Signal()

        # PE index of the sender of the message at the FIFO output.
        sender_pe = config.addresslayout.pe_adr(
            apply_interface_in_fifo.dout.msg.sender)

        # High while the FSM is in WAIT_FOR_STRAGGLER (read by gen_selfcheck).
        self.waiting_for_stragglers = Signal()

        self.submodules.fsm = FSM()

        # DEFAULT: when downstream acks, pop the FIFO and register its message
        # on the output.  Barriers are recorded but not forwarded (output
        # valid is masked with ~barrier); other messages bump the sender's
        # counter.
        self.fsm.act(
            "DEFAULT",
            If(
                self.apply_interface_out.ack,
                apply_interface_in_fifo.dout.ack.eq(1),
                NextValue(self.apply_interface_out.msg.raw_bits(),
                          apply_interface_in_fifo.dout.msg.raw_bits()),
                NextValue(
                    self.apply_interface_out.valid,
                    apply_interface_in_fifo.dout.valid
                    & ~apply_interface_in_fifo.dout.msg.barrier),
                If(
                    apply_interface_in_fifo.dout.valid,
                    If(
                        apply_interface_in_fifo.dout.msg.barrier,
                        NextValue(self.barrier_from_pe[sender_pe], 1),
                        # The barrier's dest_id field carries the expected
                        # update count for the sending PE.
                        NextValue(self.num_expected_from_pe[sender_pe],
                                  apply_interface_in_fifo.dout.msg.dest_id),
                        If(~apply_interface_in_fifo.dout.msg.halt,
                           NextValue(halt, 0)), NextState("CHK_BARRIER")).Else(
                               NextValue(self.num_from_pe[sender_pe],
                                         self.num_from_pe[sender_pe] + 1)))))

        # CHK_BARRIER: hold the FIFO, drop output valid and branch on whether
        # a barrier has been recorded for every PE.
        self.fsm.act(
            "CHK_BARRIER", apply_interface_in_fifo.dout.ack.eq(0),
            If(
                self.apply_interface_out.ack,
                NextValue(self.apply_interface_out.valid, 0),
                If(self.all_barriers_recvd,
                   NextState("PASS_BARRIER")).Else(NextState("DEFAULT"))))

        # PASS_BARRIER: if all announced updates have arrived, advance the
        # round (wrapping after num_channels - 1), emit a barrier downstream
        # and clear the bookkeeping; otherwise wait for the missing updates.
        self.fsm.act(
            "PASS_BARRIER", apply_interface_in_fifo.dout.ack.eq(0),
            If(
                self.apply_interface_out.ack,
                If(
                    self.all_messages_recvd,
                    If(
                        self.round_accepting <
                        config.addresslayout.num_channels - 1,
                        NextValue(self.round_accepting,
                                  self.round_accepting + 1)).Else(
                                      NextValue(self.round_accepting, 0)),
                    # The outgoing barrier carries the current halt value;
                    # the register itself is re-armed to 1 for the next round.
                    NextValue(halt, 1),
                    NextValue(self.apply_interface_out.msg.halt, halt),
                    NextValue(self.apply_interface_out.msg.barrier, 1),
                    NextValue(self.apply_interface_out.valid, 1), [
                        NextValue(self.barrier_from_pe[i], 0)
                        for i in range(num_pe)
                    ],
                    [NextValue(self.num_from_pe[i], 0) for i in range(num_pe)],
                    NextState("DEFAULT")).Else(
                        NextValue(self.apply_interface_out.valid, 0),
                        NextState("WAIT_FOR_STRAGGLER"))))

        # WAIT_FOR_STRAGGLER: all barriers are in but some updates are still
        # missing; keep forwarding and counting messages.
        self.fsm.act(
            "WAIT_FOR_STRAGGLER",
            self.waiting_for_stragglers.eq(1),
            If(
                self.apply_interface_out.ack,
                apply_interface_in_fifo.dout.ack.eq(1),
                NextValue(self.apply_interface_out.msg.raw_bits(),
                          apply_interface_in_fifo.dout.msg.raw_bits()),
                NextValue(self.apply_interface_out.valid,
                          apply_interface_in_fifo.dout.valid),
                If(
                    apply_interface_in_fifo.dout.valid,
                    NextValue(self.num_from_pe[sender_pe],
                              self.num_from_pe[sender_pe] + 1),
                    NextState(
                        "CHK_BARRIER"
                    )  # this gratuitously checks all_barriers_recvd again, but we need to wait an extra cycle for all_messages_recvd to be updated
                )))

    @passive
    def gen_selfcheck(self, tb):
        """Passive simulation generator reporting straggler waits.

        In every cycle in which ``waiting_for_stragglers`` is high, a warning
        is logged, followed by one warning per PE whose received count
        differs from its expected count.
        """
        logger = logging.getLogger('sim.barriercounter')
        while True:
            if (yield self.waiting_for_stragglers):
                logger.warning("Barriercounter is waiting for stragglers:")
                for i in range(tb.config.addresslayout.num_pe):
                    received = (yield self.num_from_pe[i])
                    expected = (yield self.num_expected_from_pe[i])
                    if received != expected:
                        logger.warning(
                            "Only {} of {} updates received from PE {}".format(
                                received, expected, i))
            yield
Ejemplo n.º 5
0
class Apply(Module):
    """Gather/apply pipeline stage of one processing element.

    Incoming messages (``apply_interface``) are buffered, the addressed
    node's stored state is read from a local memory, and message plus state
    are handed to the user-supplied gather/apply kernel.  State written back
    by the kernel goes to the same memory; updates produced by the kernel are
    sent through a barrier distributor and an output FIFO to
    ``scatter_interface``.  A control FSM switches between gathering
    messages, waiting for in-flight writes to finish, running the kernel over
    every local node and sending a barrier.
    """

    def __init__(self, config, pe_id):
        """Build the pipeline for PE ``pe_id``.

        Args:
            config: configuration object; provides ``addresslayout``,
                ``adj_dict``, ``init_nodedata``, the kernel classes,
                ``updates_in_hmc`` and (for the HMC variant)
                ``hmc_fifo_bits`` and ``platform``.
            pe_id: index of this processing element.
        """
        self.config = config
        self.pe_id = pe_id
        addresslayout = config.addresslayout
        nodeidsize = addresslayout.nodeidsize
        num_nodes_per_pe = addresslayout.num_nodes_per_pe
        # Memory depth: highest local node index + 1, but never less than 2.
        num_valid_nodes = max(2, config.addresslayout.max_node_per_pe(config.adj_dict)[self.pe_id]+1)

        # input Q interface
        self.apply_interface = ApplyInterface(name="apply_in", **addresslayout.get_params())

        # scatter interface
        # send self.update message to all neighbors
        # message format (sending_node_id) (normally would be (sending_node_id, weight), but for PR weight = sending_node_id)
        self.scatter_interface = ApplyInterface(name="apply_out", **addresslayout.get_params())

        # Raised when the output FIFO cannot take more data (see below).
        self.deadlock = Signal()

        ####

        # Eight-entry buffer in front of the pipeline.
        apply_interface_in_fifo = InterfaceFIFO(layout=self.apply_interface.layout, depth=8, name="apply_in_fifo")
        self.submodules += apply_interface_in_fifo
        self.comb += self.apply_interface.connect(apply_interface_in_fifo.din)

        # local node data storage
        self.specials.mem = Memory(layout_len(addresslayout.node_storage_layout), num_valid_nodes, init=config.init_nodedata[pe_id] if config.init_nodedata else None, name="vertex_data_{}".format(self.pe_id))
        rd_port = self.specials.rd_port = self.mem.get_port(has_re=True)
        wr_port = self.specials.wr_port = self.mem.get_port(write_capable=True)

        # The physical write port is driven either by the pipeline
        # (local_wr_port) or, while ``select`` is high, by external_wr_port.
        local_wr_port = Record(layout=get_mem_port_layout(wr_port))
        self.external_wr_port = Record(layout=get_mem_port_layout(wr_port) + [("select", 1)])

        self.comb += [
            If(self.external_wr_port.select,
                self.external_wr_port.connect(wr_port, omit={"select"})
            ).Else(
                local_wr_port.connect(wr_port)
            )
        ]

        # should pipeline advance?
        upstream_ack = Signal()
        collision_re = Signal()
        collision_en = Signal()

        # count levels
        self.level = Signal(32)

        ## Stage 1
        # rename some signals for easier reading, separate barrier and normal valid (for writing to state mem)
        dest_node_id = Signal(nodeidsize)
        sender = Signal(nodeidsize)
        payload = Signal(addresslayout.messagepayloadsize)
        roundpar = Signal(config.addresslayout.channel_bits)
        valid = Signal()
        barrier = Signal()

        # ``valid`` is a valid non-barrier message; ``barrier`` is a valid
        # barrier whose halt flag is not set.
        self.comb += [
            dest_node_id.eq(apply_interface_in_fifo.dout.msg.dest_id),
            sender.eq(apply_interface_in_fifo.dout.msg.sender),
            payload.eq(apply_interface_in_fifo.dout.msg.payload),
            roundpar.eq(apply_interface_in_fifo.dout.msg.roundpar),
            valid.eq(apply_interface_in_fifo.dout.valid & ~apply_interface_in_fifo.dout.msg.barrier),
            barrier.eq(apply_interface_in_fifo.dout.valid & apply_interface_in_fifo.dout.msg.barrier & ~apply_interface_in_fifo.dout.msg.halt),
        ]

        ## Stage 2
        # Registered copies of the stage-1 values, presented to the kernel.
        dest_node_id2 = Signal(nodeidsize)
        sender2 = Signal(nodeidsize)
        payload2 = Signal(addresslayout.messagepayloadsize)
        roundpar2 = Signal(config.addresslayout.channel_bits)
        barrier2 = Signal()
        valid2 = Signal()
        ready = Signal()
        msgvalid2 = Signal()
        statevalid2 = Signal()

        state_barrier = Signal()

        # Node index iterated over during the APPLY state.
        node_idx = Signal(nodeidsize)
        gather_done = Signal()

        # Round value following the incoming one, wrapping after
        # num_channels - 1.
        next_roundpar = Signal(config.addresslayout.channel_bits)
        self.comb += If(roundpar==config.addresslayout.num_channels-1, next_roundpar.eq(0)).Else(next_roundpar.eq(roundpar+1))

        self.submodules.fsm = FSM()
        # GATHER: stream messages from the FIFO into stage 2 while reading the
        # destination node's state.  A barrier stops the stream and moves on
        # to FLUSH.
        self.fsm.act("GATHER",
            rd_port.re.eq(upstream_ack),
            apply_interface_in_fifo.dout.ack.eq(upstream_ack),
            rd_port.adr.eq(addresslayout.local_adr(dest_node_id)),
            NextValue(collision_en, 1),
            If(~collision_re,
                NextValue(valid2, 0) # insert bubble if collision
            ).Elif(upstream_ack,
                NextValue(valid2, valid),
                NextValue(dest_node_id2, dest_node_id),
                NextValue(sender2, sender),
                NextValue(payload2, payload),
                NextValue(roundpar2, next_roundpar),
                NextValue(statevalid2, 1),
                NextValue(msgvalid2, ~barrier),
                If(barrier,
                    NextValue(collision_en, 0),
                    NextValue(valid2, 0),
                    NextState("FLUSH")
                )
            )
        )
        # FLUSH: stop reading, reset node_idx to this PE's first node id and
        # wait until the collision detector reports all_clear.
        self.fsm.act("FLUSH",
            rd_port.re.eq(0),
            NextValue(node_idx, pe_id << log2_int(num_nodes_per_pe)),
            apply_interface_in_fifo.dout.ack.eq(0),
            rd_port.adr.eq(addresslayout.local_adr(dest_node_id)),
            If(gather_done,
                NextState("APPLY")
            )
        )
        # APPLY: walk node_idx over the local nodes, presenting each stored
        # state to the kernel; once node_idx reaches the end of the range,
        # present a barrier instead and go to BARRIER_SEND.
        self.fsm.act("APPLY",
            rd_port.re.eq(ready),
            apply_interface_in_fifo.dout.ack.eq(0),
            rd_port.adr.eq(addresslayout.local_adr(node_idx)),
            If(ready,
                NextValue(valid2, 1),
                NextValue(dest_node_id2, node_idx),
                NextValue(node_idx, node_idx+1),
                If(node_idx==(num_valid_nodes + (pe_id << log2_int(num_nodes_per_pe))),
                    NextValue(statevalid2, 0),
                    NextValue(barrier2, 1),
                    NextValue(valid2, 1),
                    NextState("BARRIER_SEND")
                )
            )
        )
        # BARRIER_SEND: once the kernel is ready, withdraw the barrier; return
        # to GATHER if the kernel's state_barrier is already high, otherwise
        # wait for it in BARRIER_WAIT.
        self.fsm.act("BARRIER_SEND",
            apply_interface_in_fifo.dout.ack.eq(0),
            rd_port.adr.eq(addresslayout.local_adr(node_idx)),
            If(ready,
                NextValue(barrier2, 0),
                NextValue(valid2, 0),
                If(state_barrier,
                    NextValue(self.level, self.level+1),
                    NextState("GATHER")
                ).Else(
                    NextState("BARRIER_WAIT")
                )
            )
        )
        # BARRIER_WAIT: wait for state_barrier, then count the level and
        # resume gathering.
        self.fsm.act("BARRIER_WAIT",
            If(state_barrier,
                NextValue(self.level, self.level+1),
                NextState("GATHER")
            )
        )

        # collision handling (combinatorial)
        self.submodules.collisiondetector = CollisionDetector(addresslayout)

        # The detector is given the address about to be read and the address
        # being written; its ``re`` output gates the pipeline and its
        # ``all_clear`` output ends the FLUSH state.
        self.comb += [
            self.collisiondetector.read_adr.eq(addresslayout.local_adr(dest_node_id)),
            self.collisiondetector.read_adr_valid.eq(ready & valid & collision_en), # can't be rd_port.re because that uses collisiondetector.re -> comb loop
            self.collisiondetector.write_adr.eq(local_wr_port.adr),
            self.collisiondetector.write_adr_valid.eq(local_wr_port.we),
            collision_re.eq(self.collisiondetector.re),
            gather_done.eq(self.collisiondetector.all_clear)
        ]

        # User code
        # Use a combined kernel if the config provides one, otherwise wrap
        # the separate gather and apply kernels.
        if hasattr(config, "gatherapplykernel"):
            self.submodules.gatherapplykernel = config.gatherapplykernel(config)
        else:
            self.submodules.gatherapplykernel = GatherApplyWrapper(config.gatherkernel(config), config.applykernel(config))

        # Stage-2 registers and the memory read data feed the kernel; the
        # pipeline advances when the kernel is ready (or stage 2 is empty)
        # and no collision is pending.
        self.comb += [
            self.gatherapplykernel.level_in.eq(self.level),
            self.gatherapplykernel.nodeid_in.eq(dest_node_id2),
            self.gatherapplykernel.sender_in.eq(sender2),
            self.gatherapplykernel.message_in.raw_bits().eq(payload2),
            self.gatherapplykernel.message_in_valid.eq(msgvalid2),
            self.gatherapplykernel.state_in.raw_bits().eq(rd_port.dat_r),
            self.gatherapplykernel.state_in_valid.eq(statevalid2),
            self.gatherapplykernel.round_in.eq(roundpar2),
            self.gatherapplykernel.barrier_in.eq(barrier2),
            self.gatherapplykernel.valid_in.eq(valid2),
            ready.eq(self.gatherapplykernel.ready),
            upstream_ack.eq((self.gatherapplykernel.ready | ~valid2) & collision_re)
        ]

        # write state updates
        # State writes from the kernel are always acknowledged.
        self.comb += [
            local_wr_port.adr.eq(addresslayout.local_adr(self.gatherapplykernel.nodeid_out)),
            local_wr_port.dat_w.eq(self.gatherapplykernel.state_out.raw_bits()),
            state_barrier.eq(self.gatherapplykernel.state_barrier),
            local_wr_port.we.eq(self.gatherapplykernel.state_valid),
            self.gatherapplykernel.state_ack.eq(1)
        ]

        # Message assembled from the kernel's update outputs.
        applykernel_out = Message(**addresslayout.get_params())

        # halt and dest_id are driven to 0 here; barriers use this PE's first
        # node id as sender, updates use the kernel-supplied sender.
        self.comb += [
            applykernel_out.halt.eq(0),
            applykernel_out.barrier.eq(self.gatherapplykernel.barrier_out),
            applykernel_out.roundpar.eq(self.gatherapplykernel.update_round),
            applykernel_out.dest_id.eq(0),
            If(self.gatherapplykernel.barrier_out,
                applykernel_out.sender.eq(pe_id << log2_int(num_nodes_per_pe))
            ).Else(
                applykernel_out.sender.eq(self.gatherapplykernel.update_sender)
            ),
            applykernel_out.payload.eq(self.gatherapplykernel.update_out.raw_bits())
        ]

        self.submodules.barrierdistributor = BarrierDistributorApply(config)

        self.comb += [
            self.barrierdistributor.apply_interface_in.msg.eq(applykernel_out),
            self.barrierdistributor.apply_interface_in.valid.eq(self.gatherapplykernel.update_valid),
            self.gatherapplykernel.update_ack.eq(self.barrierdistributor.apply_interface_in.ack)
        ]

        outfifo_in = Message(**addresslayout.get_params())
        outfifo_out = Message(**addresslayout.get_params())

        # Output FIFO: HMC-backed or on-chip, depending on the configuration.
        if config.updates_in_hmc:
            self.submodules.outfifo = HMCBackedFIFO(width=len(outfifo_in), start_addr=pe_id*(1<<config.hmc_fifo_bits), end_addr=(pe_id + 1)*(1<<config.hmc_fifo_bits), port=config.platform.getHMCPort(pe_id))

            # Registered flag: set once the FIFO reports full; nothing in
            # this block clears it again.
            self.sync += [
                If(self.outfifo.full, self.deadlock.eq(1))
            ]
        else:
            self.submodules.outfifo = SyncFIFO(width=len(outfifo_in), depth=num_valid_nodes)
            # Combinatorial flag: high whenever the FIFO is not writable.
            self.comb += self.deadlock.eq(~self.outfifo.writable)

        self.comb += [
            self.outfifo.din.eq(outfifo_in.raw_bits()),
            outfifo_out.raw_bits().eq(self.outfifo.dout)
        ]

        # barrier distributor -> output FIFO
        self.comb += [
            self.barrierdistributor.apply_interface_out.msg.connect(outfifo_in),
            self.outfifo.we.eq(self.barrierdistributor.apply_interface_out.valid),
            self.barrierdistributor.apply_interface_out.ack.eq(self.outfifo.writable)
        ]

        # output FIFO -> scatter interface
        self.comb += [
            self.scatter_interface.msg.raw_bits().eq(self.outfifo.dout),
            self.scatter_interface.valid.eq(self.outfifo.readable),
            self.outfifo.re.eq(self.scatter_interface.ack)
        ]

    def gen_simulation(self, tb):
        """Simulation generator that logs the final node state of this PE.

        Idles until ``tb.global_inactive`` reads true, then reads every local
        memory entry whose global vertex id is in ``tb.config.graph`` and
        logs it (INFO for vertex ids below 32, DEBUG otherwise).  PE 0 also
        logs a header line.
        """
        logger = logging.getLogger('sim.apply')
        while not (yield tb.global_inactive):
            yield
        if self.pe_id == 0:
            logger.info("State at end of computation:")
        # Unlike the constructor, no lower bound of 2 is applied here.
        num_valid_nodes = tb.config.addresslayout.max_node_per_pe(tb.config.adj_dict)[self.pe_id] + 1
        for node in range(num_valid_nodes):
            vertexid = tb.config.addresslayout.global_adr(self.pe_id, node)
            if vertexid in tb.config.graph:
                # NOTE(review): if ``graph`` is a networkx graph, ``.node`` is
                # the legacy attribute accessor (removed in networkx 2.4) --
                # confirm the version this project pins.
                p = "{} (origin={}): ".format(vertexid, tb.config.graph.node[vertexid]["origin"])
                state = convert_int_to_record((yield self.mem[node]), tb.config.addresslayout.node_storage_layout)
                p += str(state)
                if vertexid < 32:
                    logger.info(p)
                else:
                    logger.debug(p)
class Barriercounter(Module):
    """Merge the barriers of all PEs into one and check update counts.

    Messages from ``apply_interface_in`` pass through a two-entry FIFO to
    ``apply_interface_out``.  Individual barrier messages are swallowed: each
    sets the "barrier seen" flag of its sending PE and stores the number of
    updates that PE announces in its ``dest_id`` field.  When every PE has
    sent its barrier and every per-PE update count equals the announced
    count, one barrier is emitted downstream, ``round_accepting`` advances
    and the flags and counters are cleared.
    """

    def __init__(self, config):
        """Create the interfaces, counters and the control FSM.

        Args:
            config: configuration object; ``config.addresslayout`` supplies
                the interface parameters, ``num_pe``, ``nodeidsize``,
                ``channel_bits``, ``num_channels`` and ``pe_adr``.
        """
        self.apply_interface_in = ApplyInterface(name="barriercounter_in", **config.addresslayout.get_params())
        self.apply_interface_out = ApplyInterface(name="barriercounter_out", **config.addresslayout.get_params())
        # Round counter; incremented with wrap-around whenever a combined
        # barrier is emitted.
        self.round_accepting = Signal(config.addresslayout.channel_bits)

        # Input buffer between the external interface and the FSM.
        apply_interface_in_fifo = InterfaceFIFO(layout=self.apply_interface_in.layout, depth=2)
        self.submodules += apply_interface_in_fifo
        self.comb += self.apply_interface_in.connect(apply_interface_in_fifo.din)

        num_pe = config.addresslayout.num_pe

        # Per-PE state: barrier seen, messages counted, messages announced,
        # and "counted == announced".
        self.barrier_from_pe = Array(Signal() for _ in range(num_pe))
        self.num_from_pe = Array(Signal(config.addresslayout.nodeidsize) for _ in range(num_pe))
        self.num_expected_from_pe = Array(Signal(config.addresslayout.nodeidsize) for _ in range(num_pe))
        self.all_from_pe = Array(Signal() for _ in range (num_pe))
        self.all_messages_recvd = Signal()
        self.all_barriers_recvd = Signal()

        # Global conditions are the AND over all PEs.
        self.comb += [
            self.all_barriers_recvd.eq(reduce(and_, self.barrier_from_pe)),
            self.all_messages_recvd.eq(reduce(and_, self.all_from_pe)),
        ]

        self.comb += [
            self.all_from_pe[i].eq(self.num_from_pe[i] == self.num_expected_from_pe[i]) for i in range(num_pe)
        ]

        # Value placed in the halt field of the emitted barrier: set to 1
        # after each emitted barrier, cleared by any incoming barrier that
        # does not have its halt field set.
        halt = Signal()

        # PE index of the sender of the message currently at the FIFO output.
        sender_pe = config.addresslayout.pe_adr(apply_interface_in_fifo.dout.msg.sender)

        # Asserted in the WAIT_FOR_STRAGGLER state; observed by gen_selfcheck.
        self.waiting_for_stragglers = Signal()

        self.submodules.fsm = FSM()

        # DEFAULT: on downstream ack, pop the FIFO and register the message on
        # the output.  Barriers are masked from the output and recorded;
        # other messages increment the sender's counter.
        self.fsm.act("DEFAULT",
            If(self.apply_interface_out.ack,
                apply_interface_in_fifo.dout.ack.eq(1),
                NextValue(self.apply_interface_out.msg.raw_bits(), apply_interface_in_fifo.dout.msg.raw_bits()),
                NextValue(self.apply_interface_out.valid, apply_interface_in_fifo.dout.valid & ~apply_interface_in_fifo.dout.msg.barrier),
                If(apply_interface_in_fifo.dout.valid,
                    If(apply_interface_in_fifo.dout.msg.barrier,
                        NextValue(self.barrier_from_pe[sender_pe], 1),
                        # dest_id of a barrier holds the announced update count
                        NextValue(self.num_expected_from_pe[sender_pe], apply_interface_in_fifo.dout.msg.dest_id),
                        If(~apply_interface_in_fifo.dout.msg.halt,
                            NextValue(halt, 0)
                        ),
                        NextState("CHK_BARRIER")
                    ).Else(
                        NextValue(self.num_from_pe[sender_pe], self.num_from_pe[sender_pe] + 1)
                    )
                )
            )
        )

        # CHK_BARRIER: FIFO is held; deassert output valid and decide whether
        # barriers from all PEs are in.
        self.fsm.act("CHK_BARRIER",
            apply_interface_in_fifo.dout.ack.eq(0),
            If(self.apply_interface_out.ack,
                NextValue(self.apply_interface_out.valid, 0),
                If(self.all_barriers_recvd,
                    NextState("PASS_BARRIER")
                ).Else(
                    NextState("DEFAULT")
                )
            )
        )

        # PASS_BARRIER: with all updates accounted for, advance the round,
        # emit the combined barrier and reset the bookkeeping; otherwise go
        # and wait for the outstanding updates.
        self.fsm.act("PASS_BARRIER",
            apply_interface_in_fifo.dout.ack.eq(0),
            If(self.apply_interface_out.ack,
                If(self.all_messages_recvd,
                    If(self.round_accepting < config.addresslayout.num_channels - 1,
                        NextValue(self.round_accepting, self.round_accepting + 1)
                    ).Else(
                        NextValue(self.round_accepting, 0)
                    ),
                    # outgoing barrier gets the current halt value; the
                    # register is re-armed for the next round
                    NextValue(halt, 1),
                    NextValue(self.apply_interface_out.msg.halt, halt),
                    NextValue(self.apply_interface_out.msg.barrier, 1),
                    NextValue(self.apply_interface_out.valid, 1),
                    [NextValue(self.barrier_from_pe[i], 0) for i in range(num_pe)],
                    [NextValue(self.num_from_pe[i], 0) for i in range(num_pe)],
                    NextState("DEFAULT")
                ).Else(
                    NextValue(self.apply_interface_out.valid, 0),
                    NextState("WAIT_FOR_STRAGGLER")
                )
            )
        )

        # WAIT_FOR_STRAGGLER: keep forwarding and counting messages until the
        # counts can be re-checked.
        self.fsm.act("WAIT_FOR_STRAGGLER",
            self.waiting_for_stragglers.eq(1),
            If(self.apply_interface_out.ack,
                apply_interface_in_fifo.dout.ack.eq(1),
                NextValue(self.apply_interface_out.msg.raw_bits(), apply_interface_in_fifo.dout.msg.raw_bits()),
                NextValue(self.apply_interface_out.valid, apply_interface_in_fifo.dout.valid),
                If(apply_interface_in_fifo.dout.valid,
                    NextValue(self.num_from_pe[sender_pe], self.num_from_pe[sender_pe] + 1),
                    NextState("CHK_BARRIER") # this gratuitously checks all_barriers_recvd again, but we need to wait an extra cycle for all_messages_recvd to be updated
                )
            )
        )

    @passive
    def gen_selfcheck(self, tb):
        """Passive simulation generator that reports straggler waits.

        Whenever ``waiting_for_stragglers`` reads high, logs a warning and
        then one warning for each PE whose received count differs from its
        expected count.  Runs forever; ``@passive`` marks it as such.
        """
        logger = logging.getLogger('sim.barriercounter')
        while True:
            if (yield self.waiting_for_stragglers):
                logger.warning("Barriercounter is waiting for stragglers:")
                for i in range(tb.config.addresslayout.num_pe):
                    received = (yield self.num_from_pe[i])
                    expected = (yield self.num_expected_from_pe[i])
                    if received != expected:
                        logger.warning("Only {} of {} updates received from PE {}".format(received, expected, i))
            yield
Ejemplo n.º 7
0
class Arbiter(Module):
    """Pass incoming apply messages through a ``Barriercounter``.

    ``apply_interface_in`` feeds the barrier counter, whose output drives
    ``apply_interface_out``.  While ``start_message.select`` is asserted,
    ``start_message`` drives the output instead.  ``current_round`` mirrors
    the barrier counter's ``round_accepting``.

    Args:
        pe_id: identifier of the owning PE; stored and used for the name of
            the simulation logger.
        config: configuration object providing ``addresslayout``.
    """

    def __init__(self, pe_id, config):
        addresslayout = config.addresslayout
        self.pe_id = pe_id

        # input (n channels)
        self.apply_interface_in = ApplyInterface(name="arbiter_in",
                                                 **addresslayout.get_params())

        # output
        self.apply_interface_out = ApplyInterface(name="arbiter_out",
                                                  **addresslayout.get_params())

        # input override for injecting the message starting the computation
        self.start_message = ApplyInterface(name="start_message",
                                            **addresslayout.get_params())
        self.start_message.select = Signal()

        self.submodules.barriercounter = Barriercounter(config)
        self.current_round = Signal(addresslayout.channel_bits)

        self.comb += [
            self.barriercounter.apply_interface_in.msg.raw_bits().eq(
                self.apply_interface_in.msg.raw_bits()),
            self.barriercounter.apply_interface_in.valid.eq(
                self.apply_interface_in.valid),
            self.apply_interface_in.ack.eq(
                self.barriercounter.apply_interface_in.ack),
            self.current_round.eq(self.barriercounter.round_accepting)
        ]

        # choose between init and regular message channel
        self.comb += \
            If(self.start_message.select,
                self.start_message.connect(self.apply_interface_out)
            ).Else(
                self.barriercounter.apply_interface_out.connect(self.apply_interface_out)
            )

    def gen_selfcheck(self, tb):
        """Simulation generator that monitors messages delivered to apply.

        Runs until ``tb.global_inactive`` reads true.  Each accepted (valid
        and acked) barrier increments a local level counter; each accepted
        non-barrier message has its ``roundpar`` compared with that counter
        and a warning is logged when the check trips.
        """
        logger = logging.getLogger("sim.arbiter" + str(self.pe_id))
        level = 0
        num_cycles = 0
        out = self.apply_interface_out

        while not (yield tb.global_inactive):
            num_cycles += 1

            if (yield out.valid) and (yield out.ack):
                if (yield out.msg.barrier):
                    level += 1
                    logger.debug(
                        "{}: Barrier passed to apply".format(num_cycles))
                else:
                    # Read the field once; the value is used both for the
                    # check and for the log message.
                    roundpar = (yield out.msg.roundpar)
                    # NOTE(review): the warning is issued when the parities
                    # are EQUAL although the text says "does not match";
                    # presumably roundpar is offset by one from the level
                    # counted here -- confirm against the sender side.
                    if level % 2 == roundpar:
                        logger.warning(
                            "{}: received message's parity ({}) does not match current round ({})"
                            .format(num_cycles, roundpar, level))
            yield
Ejemplo n.º 8
0
class Barriercounter(Module):
    """Collect barrier markers and message counts per sending PE.

    Messages arriving on ``apply_interface_in`` are buffered in a depth-2
    ``InterfaceFIFO`` and handed to ``apply_interface_out`` through a
    registered output stage.  Incoming barrier messages are recorded per
    sending PE instead of being forwarded.  Once a barrier has been recorded
    for every PE and, for every PE, the number of counted messages equals the
    expected count taken from that PE's barrier (its ``dest_id`` field), a
    single barrier is emitted on ``apply_interface_out`` and
    ``round_accepting`` advances (wrapping to 0 after ``num_channels - 1``).

    The ``halt`` bit of the emitted barrier is the value of an internal
    ``halt`` register, which is set to 1 whenever a barrier is emitted and
    cleared by any incoming barrier whose own ``halt`` bit is 0.
    """
    def __init__(self, config):
        """Build the FIFO, the per-PE bookkeeping and the control FSM.

        Args:
            config: project configuration; this block reads
                ``config.addresslayout`` (``get_params()``, ``channel_bits``,
                ``num_pe``, ``nodeidsize``, ``num_channels``, ``pe_adr``).
        """
        self.apply_interface_in = ApplyInterface(
            name="barriercounter_in", **config.addresslayout.get_params())
        self.apply_interface_out = ApplyInterface(
            name="barriercounter_out", **config.addresslayout.get_params())
        # Round counter, incremented (with wrap-around) on every emitted barrier.
        self.round_accepting = Signal(config.addresslayout.channel_bits)

        # Input buffer: everything below consumes from the FIFO's dout side.
        apply_interface_in_fifo = InterfaceFIFO(
            layout=self.apply_interface_in.layout, depth=2)
        self.submodules += apply_interface_in_fifo
        self.comb += self.apply_interface_in.connect(
            apply_interface_in_fifo.din)

        num_pe = config.addresslayout.num_pe

        # Per-PE bookkeeping, indexed by the PE that sent the message:
        #   barrier_from_pe      - a barrier from this PE has been recorded
        #   num_from_pe          - non-barrier messages counted from this PE
        #   num_expected_from_pe - count announced by this PE's barrier
        #   all_from_pe          - num_from_pe == num_expected_from_pe
        self.barrier_from_pe = Array(Signal() for _ in range(num_pe))
        self.num_from_pe = Array(
            Signal(config.addresslayout.nodeidsize) for _ in range(num_pe))
        self.num_expected_from_pe = Array(
            Signal(config.addresslayout.nodeidsize) for _ in range(num_pe))
        self.all_from_pe = Array(Signal() for _ in range(num_pe))
        self.all_messages_recvd = Signal()
        self.all_barriers_recvd = Signal()

        # AND-reduce the per-PE flags into the two global conditions.
        self.comb += [
            self.all_barriers_recvd.eq(reduce(and_, self.barrier_from_pe)),
            self.all_messages_recvd.eq(reduce(and_, self.all_from_pe)),
        ]

        self.comb += [
            self.all_from_pe[i].eq(
                self.num_from_pe[i] == self.num_expected_from_pe[i])
            for i in range(num_pe)
        ]

        # Halt vote accumulator; see the class docstring.
        halt = Signal()

        # PE index derived from the sender field of the message at the FIFO output.
        sender_pe = config.addresslayout.pe_adr(
            apply_interface_in_fifo.dout.msg.sender)

        self.submodules.fsm = FSM()

        # DEFAULT: while downstream acks, pop the FIFO and register the message
        # on the output.  Output valid is only raised for non-barrier messages.
        # A barrier sets the sender's barrier flag, stores its dest_id as the
        # expected message count, clears `halt` if its halt bit is 0, and moves
        # to CHK_BARRIER.  Any other valid message bumps the sender's counter.
        self.fsm.act(
            "DEFAULT",
            If(
                self.apply_interface_out.ack,
                apply_interface_in_fifo.dout.ack.eq(1),
                NextValue(self.apply_interface_out.msg.raw_bits(),
                          apply_interface_in_fifo.dout.msg.raw_bits()),
                NextValue(
                    self.apply_interface_out.valid,
                    apply_interface_in_fifo.dout.valid
                    & ~apply_interface_in_fifo.dout.msg.barrier),
                If(
                    apply_interface_in_fifo.dout.valid,
                    If(
                        apply_interface_in_fifo.dout.msg.barrier,
                        NextValue(self.barrier_from_pe[sender_pe], 1),
                        NextValue(self.num_expected_from_pe[sender_pe],
                                  apply_interface_in_fifo.dout.msg.dest_id),
                        If(~apply_interface_in_fifo.dout.msg.halt,
                           NextValue(halt, 0)), NextState("CHK_BARRIER")).Else(
                               NextValue(self.num_from_pe[sender_pe],
                                         self.num_from_pe[sender_pe] + 1)))))

        # CHK_BARRIER: the FIFO is not popped.  When downstream acks, drop output
        # valid and continue to PASS_BARRIER if every PE's barrier flag is set,
        # otherwise return to DEFAULT.
        self.fsm.act(
            "CHK_BARRIER", apply_interface_in_fifo.dout.ack.eq(0),
            If(
                self.apply_interface_out.ack,
                NextValue(self.apply_interface_out.valid, 0),
                If(self.all_barriers_recvd,
                   NextState("PASS_BARRIER")).Else(NextState("DEFAULT"))))

        # PASS_BARRIER: the FIFO is not popped.  When downstream acks and all
        # expected messages have been counted: advance round_accepting, emit a
        # barrier carrying the current `halt` value, set `halt` back to 1, clear
        # all barrier flags and message counters, and return to DEFAULT.  Only
        # the halt and barrier fields of the output message are written here.
        # If messages are still missing, go to WAIT_FOR_STRAGGLER.
        self.fsm.act(
            "PASS_BARRIER", apply_interface_in_fifo.dout.ack.eq(0),
            If(
                self.apply_interface_out.ack,
                If(
                    self.all_messages_recvd,
                    If(
                        self.round_accepting <
                        config.addresslayout.num_channels - 1,
                        NextValue(self.round_accepting,
                                  self.round_accepting + 1)).Else(
                                      NextValue(self.round_accepting, 0)),
                    NextValue(halt, 1),
                    NextValue(self.apply_interface_out.msg.halt, halt),
                    NextValue(self.apply_interface_out.msg.barrier, 1),
                    NextValue(self.apply_interface_out.valid, 1), [
                        NextValue(self.barrier_from_pe[i], 0)
                        for i in range(num_pe)
                    ],
                    [NextValue(self.num_from_pe[i], 0) for i in range(num_pe)],
                    NextState("DEFAULT")).Else(
                        NextValue(self.apply_interface_out.valid, 0),
                        NextState("WAIT_FOR_STRAGGLER"))))

        # WAIT_FOR_STRAGGLER: while downstream acks, pop the FIFO and forward
        # whatever is at its output; each valid message bumps the sender's
        # counter and sends the FSM back to CHK_BARRIER.
        self.fsm.act(
            "WAIT_FOR_STRAGGLER",
            If(
                self.apply_interface_out.ack,
                apply_interface_in_fifo.dout.ack.eq(1),
                NextValue(self.apply_interface_out.msg.raw_bits(),
                          apply_interface_in_fifo.dout.msg.raw_bits()),
                NextValue(self.apply_interface_out.valid,
                          apply_interface_in_fifo.dout.valid),
                If(
                    apply_interface_in_fifo.dout.valid,
                    NextValue(self.num_from_pe[sender_pe],
                              self.num_from_pe[sender_pe] + 1),
                    NextState(
                        "CHK_BARRIER"
                    )  # this needlessly re-checks all_barriers_recvd, but an extra cycle is needed for all_messages_recvd to be updated
                )))
class Barriercounter(Module):
    """Collect barrier markers and message counts per sending PE.

    Messages arriving on ``apply_interface_in`` are buffered in a depth-2
    ``InterfaceFIFO`` and handed to ``apply_interface_out`` through a
    registered output stage.  Incoming barrier messages are recorded per
    sending PE instead of being forwarded.  Once a barrier has been recorded
    for every PE and, for every PE, the number of counted messages equals the
    expected count taken from that PE's barrier (its ``dest_id`` field), a
    single barrier is emitted on ``apply_interface_out`` and
    ``round_accepting`` advances (wrapping to 0 after ``num_channels - 1``).

    The ``halt`` bit of the emitted barrier is the value of an internal
    ``halt`` register, which is set to 1 whenever a barrier is emitted and
    cleared by any incoming barrier whose own ``halt`` bit is 0.
    """
    def __init__(self, config):
        """Build the FIFO, the per-PE bookkeeping and the control FSM.

        Args:
            config: project configuration; this block reads
                ``config.addresslayout`` (``get_params()``, ``channel_bits``,
                ``num_pe``, ``nodeidsize``, ``num_channels``, ``pe_adr``).
        """
        self.apply_interface_in = ApplyInterface(name="barriercounter_in", **config.addresslayout.get_params())
        self.apply_interface_out = ApplyInterface(name="barriercounter_out", **config.addresslayout.get_params())
        # Round counter, incremented (with wrap-around) on every emitted barrier.
        self.round_accepting = Signal(config.addresslayout.channel_bits)

        # Input buffer: everything below consumes from the FIFO's dout side.
        apply_interface_in_fifo = InterfaceFIFO(layout=self.apply_interface_in.layout, depth=2)
        self.submodules += apply_interface_in_fifo
        self.comb += self.apply_interface_in.connect(apply_interface_in_fifo.din)

        num_pe = config.addresslayout.num_pe

        # Per-PE bookkeeping, indexed by the PE that sent the message:
        #   barrier_from_pe      - a barrier from this PE has been recorded
        #   num_from_pe          - non-barrier messages counted from this PE
        #   num_expected_from_pe - count announced by this PE's barrier
        #   all_from_pe          - num_from_pe == num_expected_from_pe
        self.barrier_from_pe = Array(Signal() for _ in range(num_pe))
        self.num_from_pe = Array(Signal(config.addresslayout.nodeidsize) for _ in range(num_pe))
        self.num_expected_from_pe = Array(Signal(config.addresslayout.nodeidsize) for _ in range(num_pe))
        self.all_from_pe = Array(Signal() for _ in range (num_pe))
        self.all_messages_recvd = Signal()
        self.all_barriers_recvd = Signal()

        # AND-reduce the per-PE flags into the two global conditions.
        self.comb += [
            self.all_barriers_recvd.eq(reduce(and_, self.barrier_from_pe)),
            self.all_messages_recvd.eq(reduce(and_, self.all_from_pe)),
        ]

        self.comb += [
            self.all_from_pe[i].eq(self.num_from_pe[i] == self.num_expected_from_pe[i]) for i in range(num_pe)
        ]

        # Halt vote accumulator; see the class docstring.
        halt = Signal()

        # PE index derived from the sender field of the message at the FIFO output.
        sender_pe = config.addresslayout.pe_adr(apply_interface_in_fifo.dout.msg.sender)

        self.submodules.fsm = FSM()

        # DEFAULT: while downstream acks, pop the FIFO and register the message
        # on the output.  Output valid is only raised for non-barrier messages.
        self.fsm.act("DEFAULT",
            If(self.apply_interface_out.ack,
                apply_interface_in_fifo.dout.ack.eq(1),
                NextValue(self.apply_interface_out.msg.raw_bits(), apply_interface_in_fifo.dout.msg.raw_bits()),
                NextValue(self.apply_interface_out.valid, apply_interface_in_fifo.dout.valid & ~apply_interface_in_fifo.dout.msg.barrier),
                If(apply_interface_in_fifo.dout.valid,
                    If(apply_interface_in_fifo.dout.msg.barrier,
                        # Barrier: flag the sender and store the barrier's
                        # dest_id as that sender's expected message count.
                        NextValue(self.barrier_from_pe[sender_pe], 1),
                        NextValue(self.num_expected_from_pe[sender_pe], apply_interface_in_fifo.dout.msg.dest_id),
                        # A barrier without the halt bit clears the halt vote.
                        If(~apply_interface_in_fifo.dout.msg.halt,
                            NextValue(halt, 0)
                        ),
                        NextState("CHK_BARRIER")
                    ).Else(
                        # Ordinary message: count it for its sender.
                        NextValue(self.num_from_pe[sender_pe], self.num_from_pe[sender_pe] + 1)
                    )
                )
            )
        )

        # CHK_BARRIER: the FIFO is not popped.  When downstream acks, drop output
        # valid and continue only if every PE's barrier flag is set.
        self.fsm.act("CHK_BARRIER",
            apply_interface_in_fifo.dout.ack.eq(0),
            If(self.apply_interface_out.ack,
                NextValue(self.apply_interface_out.valid, 0),
                If(self.all_barriers_recvd,
                    NextState("PASS_BARRIER")
                ).Else(
                    NextState("DEFAULT")
                )
            )
        )

        # PASS_BARRIER: the FIFO is not popped.  Emit the collective barrier once
        # all expected messages have been counted, otherwise wait for stragglers.
        self.fsm.act("PASS_BARRIER",
            apply_interface_in_fifo.dout.ack.eq(0),
            If(self.apply_interface_out.ack,
                If(self.all_messages_recvd,
                    # Advance the round counter, wrapping after num_channels - 1.
                    If(self.round_accepting < config.addresslayout.num_channels - 1,
                        NextValue(self.round_accepting, self.round_accepting + 1)
                    ).Else(
                        NextValue(self.round_accepting, 0)
                    ),
                    # Emit the barrier with the current halt vote and set the
                    # vote back to 1.  Only the halt and barrier fields of the
                    # output message are written here.
                    NextValue(halt, 1),
                    NextValue(self.apply_interface_out.msg.halt, halt),
                    NextValue(self.apply_interface_out.msg.barrier, 1),
                    NextValue(self.apply_interface_out.valid, 1),
                    # Clear all per-PE barrier flags and message counters.
                    [NextValue(self.barrier_from_pe[i], 0) for i in range(num_pe)],
                    [NextValue(self.num_from_pe[i], 0) for i in range(num_pe)],
                    NextState("DEFAULT")
                ).Else(
                    NextValue(self.apply_interface_out.valid, 0),
                    NextState("WAIT_FOR_STRAGGLER")
                )
            )
        )

        # WAIT_FOR_STRAGGLER: while downstream acks, pop the FIFO and forward
        # whatever is at its output; each valid message bumps the sender's
        # counter and sends the FSM back to CHK_BARRIER.
        self.fsm.act("WAIT_FOR_STRAGGLER",
            If(self.apply_interface_out.ack,
                apply_interface_in_fifo.dout.ack.eq(1),
                NextValue(self.apply_interface_out.msg.raw_bits(), apply_interface_in_fifo.dout.msg.raw_bits()),
                NextValue(self.apply_interface_out.valid, apply_interface_in_fifo.dout.valid),
                If(apply_interface_in_fifo.dout.valid,
                    NextValue(self.num_from_pe[sender_pe], self.num_from_pe[sender_pe] + 1),
                    NextState("CHK_BARRIER") # this needlessly re-checks all_barriers_recvd, but an extra cycle is needed for all_messages_recvd to be updated
                )
            )
        )
Ejemplo n.º 10
0
class Apply(Module):
    def __init__(self, config, pe_id):
        """Build the gather/apply pipeline for processing element ``pe_id``.

        Incoming messages (``apply_interface``) are buffered, the destination
        node's state is read from a local memory, and both are handed to the
        user-supplied gather/apply kernel.  State written back by the kernel
        goes to the same memory; updates produced by the kernel are queued in
        an output FIFO and presented on ``scatter_interface``.

        Args:
            config: project configuration.  This block reads
                ``addresslayout``, ``adj_idx``, ``init_nodedata``,
                ``updates_in_hmc`` and either ``gatherapplykernel`` or
                ``gatherkernel``/``applykernel``; with ``updates_in_hmc`` set
                it also reads ``hmc_fifo_bits`` and ``platform``.
            pe_id: index of this processing element.
        """
        self.config = config
        self.pe_id = pe_id
        addresslayout = config.addresslayout
        nodeidsize = addresslayout.nodeidsize
        num_nodes_per_pe = addresslayout.num_nodes_per_pe
        # Memory/FIFO depth: one more than the number of local nodes, at least 2.
        num_valid_nodes = max(2, len(config.adj_idx[pe_id])+1)

        # input Q interface
        self.apply_interface = ApplyInterface(name="apply_in", **addresslayout.get_params())

        # scatter interface
        # send self.update message to all neighbors
        # message format (sending_node_id) (normally would be (sending_node_id, weight), but for PR weight = sending_node_id)
        self.scatter_interface = ScatterInterface(name="apply_out", **addresslayout.get_params())

        # Output-FIFO overflow indicator; driven in the output handling section.
        self.deadlock = Signal()

        ####

        # Depth-8 buffer between the external interface and the pipeline.
        apply_interface_in_fifo = InterfaceFIFO(layout=self.apply_interface.layout, depth=8, name="apply_in_fifo")
        self.submodules += apply_interface_in_fifo
        self.comb += self.apply_interface.connect(apply_interface_in_fifo.din)

        # local node data storage
        self.specials.mem = Memory(layout_len(addresslayout.node_storage_layout), num_valid_nodes, init=config.init_nodedata[pe_id] if config.init_nodedata else None, name="vertex_data_{}".format(self.pe_id))
        rd_port = self.specials.rd_port = self.mem.get_port(has_re=True)
        wr_port = self.specials.wr_port = self.mem.get_port(write_capable=True)

        # The memory write port is shared: `external_wr_port.select` hands it to
        # the external port, otherwise the pipeline's own `local_wr_port` drives it.
        local_wr_port = Record(layout=get_mem_port_layout(wr_port))
        self.external_wr_port = Record(layout=get_mem_port_layout(wr_port) + [("select", 1)])

        self.comb += [
            If(self.external_wr_port.select,
                self.external_wr_port.connect(wr_port, omit={"select"})
            ).Else(
                local_wr_port.connect(wr_port)
            )
        ]

        # detect termination (now done by collating votes to halt in barriercounter - if barrier is passed on with halt bit set, don't propagate)
        # `inactive` is set when a barrier with the halt bit is dequeued; nothing
        # in this block clears it again.
        self.inactive = Signal()
        self.sync += If(apply_interface_in_fifo.dout.valid & apply_interface_in_fifo.dout.ack & apply_interface_in_fifo.dout.msg.barrier & apply_interface_in_fifo.dout.msg.halt,
            self.inactive.eq(1)
        )

        # should pipeline advance?
        upstream_ack = Signal()
        collision_re = Signal()
        collision_en = Signal()

        # count levels
        self.level = Signal(32)

        ## Stage 1
        # rename some signals for easier reading, separate barrier and normal valid (for writing to state mem)
        dest_node_id = Signal(nodeidsize)
        sender = Signal(nodeidsize)
        payload = Signal(addresslayout.messagepayloadsize)
        roundpar = Signal(config.addresslayout.channel_bits)
        valid = Signal()
        barrier = Signal()

        # `valid` marks a non-barrier message; `barrier` marks a barrier whose
        # halt bit is not set (halting barriers are not propagated).
        self.comb += [
            dest_node_id.eq(apply_interface_in_fifo.dout.msg.dest_id),
            sender.eq(apply_interface_in_fifo.dout.msg.sender),
            payload.eq(apply_interface_in_fifo.dout.msg.payload),
            roundpar.eq(apply_interface_in_fifo.dout.msg.roundpar),
            valid.eq(apply_interface_in_fifo.dout.valid & ~apply_interface_in_fifo.dout.msg.barrier),
            barrier.eq(apply_interface_in_fifo.dout.valid & apply_interface_in_fifo.dout.msg.barrier & ~apply_interface_in_fifo.dout.msg.halt),
        ]

        ## Stage 2
        # Registered copies of the stage-1 values, presented to the kernel.
        dest_node_id2 = Signal(nodeidsize)
        sender2 = Signal(nodeidsize)
        payload2 = Signal(addresslayout.messagepayloadsize)
        roundpar2 = Signal(config.addresslayout.channel_bits)
        barrier2 = Signal()
        valid2 = Signal()
        ready = Signal()
        msgvalid2 = Signal()
        statevalid2 = Signal()

        state_barrier = Signal()

        # Iterator over local nodes during the APPLY phase.
        node_idx = Signal(nodeidsize)
        gather_done = Signal()

        # Round following the incoming message's round, wrapping after num_channels-1.
        next_roundpar = Signal(config.addresslayout.channel_bits)
        self.comb += If(roundpar==config.addresslayout.num_channels-1, next_roundpar.eq(0)).Else(next_roundpar.eq(roundpar+1))

        self.submodules.fsm = FSM()
        # GATHER: consume messages from the input FIFO, reading the destination
        # node's state from memory.  A (non-halt) barrier disables collision
        # checking and moves on to FLUSH.
        self.fsm.act("GATHER",
            rd_port.re.eq(upstream_ack),
            apply_interface_in_fifo.dout.ack.eq(upstream_ack),
            rd_port.adr.eq(addresslayout.local_adr(dest_node_id)),
            NextValue(collision_en, 1),
            If(~collision_re,
                NextValue(valid2, 0) # insert bubble if collision
            ).Elif(upstream_ack,
                NextValue(valid2, valid),
                NextValue(dest_node_id2, dest_node_id),
                NextValue(sender2, sender),
                NextValue(payload2, payload),
                NextValue(roundpar2, next_roundpar),
                NextValue(statevalid2, 1),
                NextValue(msgvalid2, ~barrier),
                If(barrier,
                    NextValue(collision_en, 0),
                    NextValue(valid2, 0),
                    NextState("FLUSH")
                )
            )
        )
        # FLUSH: stop consuming input, reset node_idx to this PE's first node id
        # and wait until the collision detector reports all_clear.
        self.fsm.act("FLUSH",
            rd_port.re.eq(0),
            NextValue(node_idx, pe_id << log2_int(num_nodes_per_pe)),
            apply_interface_in_fifo.dout.ack.eq(0),
            rd_port.adr.eq(addresslayout.local_adr(dest_node_id)),
            If(gather_done,
                NextState("APPLY")
            )
        )
        # APPLY: step node_idx through the local nodes, reading each node's state
        # and feeding it to the kernel.  When node_idx reaches the PE's base id
        # plus the number of local nodes, issue a barrier to the kernel instead.
        self.fsm.act("APPLY",
            rd_port.re.eq(ready),
            apply_interface_in_fifo.dout.ack.eq(0),
            rd_port.adr.eq(addresslayout.local_adr(node_idx)),
            If(ready,
                NextValue(valid2, 1),
                NextValue(dest_node_id2, node_idx),
                NextValue(node_idx, node_idx+1),
                If(node_idx==(len(config.adj_idx[pe_id]) + (pe_id << log2_int(num_nodes_per_pe))),
                    NextValue(statevalid2, 0),
                    NextValue(barrier2, 1),
                    NextValue(valid2, 1),
                    NextState("BARRIER_SEND")
                )
            )
        )
        # BARRIER_SEND: once the kernel is ready, withdraw the barrier; return to
        # GATHER (incrementing level) if the kernel already signals state_barrier,
        # otherwise wait for it in BARRIER_WAIT.
        self.fsm.act("BARRIER_SEND",
            apply_interface_in_fifo.dout.ack.eq(0),
            rd_port.adr.eq(addresslayout.local_adr(node_idx)),
            If(ready,
                NextValue(barrier2, 0),
                NextValue(valid2, 0),
                If(state_barrier,
                    NextValue(self.level, self.level+1),
                    NextState("GATHER")
                ).Else(
                    NextState("BARRIER_WAIT")
                )
            )
        )
        # BARRIER_WAIT: wait for the kernel's state_barrier, then start the next level.
        self.fsm.act("BARRIER_WAIT",
            If(state_barrier,
                NextValue(self.level, self.level+1),
                NextState("GATHER")
            )
        )

        # collision handling (combinatorial)
        self.submodules.collisiondetector = CollisionDetector(addresslayout)

        # The detector compares the address about to be read with the kernel's
        # write-back address; its `re` output gates the pipeline and its
        # `all_clear` output ends the FLUSH state.
        self.comb += [
            self.collisiondetector.read_adr.eq(addresslayout.local_adr(dest_node_id)),
            self.collisiondetector.read_adr_valid.eq(ready & valid & collision_en), # can't be rd_port.re because that uses collisiondetector.re -> comb loop
            self.collisiondetector.write_adr.eq(local_wr_port.adr),
            self.collisiondetector.write_adr_valid.eq(local_wr_port.we),
            collision_re.eq(self.collisiondetector.re),
            gather_done.eq(self.collisiondetector.all_clear)
        ]

        # User code
        # Use a combined kernel if the config provides one, otherwise wrap the
        # separate gather and apply kernels.
        if hasattr(config, "gatherapplykernel"):
            self.submodules.gatherapplykernel = config.gatherapplykernel(config)
        else:
            self.submodules.gatherapplykernel = GatherApplyWrapper(config.gatherkernel(config), config.applykernel(config))

        # Feed stage-2 registers and the memory read data into the kernel.  The
        # pipeline advances when the kernel is ready or stage 2 holds a bubble,
        # and no collision is pending.
        self.comb += [
            self.gatherapplykernel.level_in.eq(self.level),
            self.gatherapplykernel.nodeid_in.eq(dest_node_id2),
            self.gatherapplykernel.sender_in.eq(sender2),
            self.gatherapplykernel.message_in.raw_bits().eq(payload2),
            self.gatherapplykernel.message_in_valid.eq(msgvalid2),
            self.gatherapplykernel.state_in.raw_bits().eq(rd_port.dat_r),
            self.gatherapplykernel.state_in_valid.eq(statevalid2),
            self.gatherapplykernel.round_in.eq(roundpar2),
            self.gatherapplykernel.barrier_in.eq(barrier2),
            self.gatherapplykernel.valid_in.eq(valid2),
            ready.eq(self.gatherapplykernel.ready),
            upstream_ack.eq((self.gatherapplykernel.ready | ~valid2) & collision_re)
        ]

        # write state updates
        # The kernel's state output is always acknowledged (state_ack tied to 1).
        self.comb += [
            local_wr_port.adr.eq(addresslayout.local_adr(self.gatherapplykernel.nodeid_out)),
            local_wr_port.dat_w.eq(self.gatherapplykernel.state_out.raw_bits()),
            state_barrier.eq(self.gatherapplykernel.state_barrier),
            local_wr_port.we.eq(self.gatherapplykernel.state_valid),
            self.gatherapplykernel.state_ack.eq(1)
        ]

        # output handling
        # Record layout of one output-FIFO entry.
        _layout = [
        ( "barrier", 1, DIR_M_TO_S ),
        ( "roundpar", config.addresslayout.channel_bits, DIR_M_TO_S ),
        ( "sender", "nodeidsize", DIR_M_TO_S ),
        ( "msg" , addresslayout.updatepayloadsize, DIR_M_TO_S )
        ]
        outfifo_in = Record(set_layout_parameters(_layout, **addresslayout.get_params()))
        outfifo_out = Record(set_layout_parameters(_layout, **addresslayout.get_params()))

        # Two FIFO back-ends: an HMC-backed FIFO (one address window of
        # 2**hmc_fifo_bits per local PE) or an on-chip SyncFIFO.
        if config.updates_in_hmc:
            fpga_id = pe_id//config.addresslayout.num_pe_per_fpga
            local_pe_id = pe_id % config.addresslayout.num_pe_per_fpga
            self.submodules.outfifo = HMCBackedFIFO(width=len(outfifo_in), start_addr=local_pe_id*(1<<config.hmc_fifo_bits), end_addr=(local_pe_id + 1)*(1<<config.hmc_fifo_bits), port=config.platform[fpga_id].getHMCPort(local_pe_id))

            # Registered: once the FIFO reports full, deadlock is set and this
            # block never clears it.
            self.sync += [
                If(self.outfifo.full, self.deadlock.eq(1))
            ]
        else:
            self.submodules.outfifo = SyncFIFO(width=len(outfifo_in), depth=num_valid_nodes)
            # Combinatorial: deadlock follows the FIFO being non-writable.
            self.comb += self.deadlock.eq(~self.outfifo.writable)

        self.comb += [
            self.outfifo.din.eq(outfifo_in.raw_bits()),
            outfifo_out.raw_bits().eq(self.outfifo.dout)
        ]

        # Queue the kernel's updates.  For barriers the sender field is replaced
        # by this PE's base node id (pe_id shifted by log2(num_nodes_per_pe)).
        self.comb += [
            outfifo_in.msg.eq(self.gatherapplykernel.update_out.raw_bits()),
            If(self.gatherapplykernel.barrier_out, outfifo_in.sender.eq(pe_id << log2_int(num_nodes_per_pe))
            ).Else(outfifo_in.sender.eq(self.gatherapplykernel.update_sender)),
            self.outfifo.we.eq(self.gatherapplykernel.update_valid),
            outfifo_in.roundpar.eq(self.gatherapplykernel.update_round),
            outfifo_in.barrier.eq(self.gatherapplykernel.barrier_out),
            self.gatherapplykernel.update_ack.eq(self.outfifo.writable)
        ]

        # Output register stage between the FIFO and the scatter interface.
        payload4 = Signal(addresslayout.updatepayloadsize)
        sender4 = Signal(addresslayout.nodeidsize)
        roundpar4 = Signal(config.addresslayout.channel_bits)
        barrier4 = Signal()
        valid4 = Signal()

        # The register stage only loads while the scatter interface acks.
        self.sync += If(self.scatter_interface.ack,
            payload4.eq(outfifo_out.msg),
            sender4.eq(outfifo_out.sender),
            roundpar4.eq(outfifo_out.roundpar),
            barrier4.eq(outfifo_out.barrier),
            valid4.eq(self.outfifo.readable)
        )

        self.comb += [
            self.scatter_interface.payload.eq(payload4),
            self.scatter_interface.sender.eq(sender4),
            self.scatter_interface.roundpar.eq(roundpar4),
            self.scatter_interface.barrier.eq(barrier4),
            self.scatter_interface.valid.eq(valid4)
        ]

        # send from fifo when receiver ready
        self.comb += self.outfifo.re.eq(self.scatter_interface.ack)

    def gen_simulation(self, tb):
        """Simulation generator that logs this PE's node state after the run.

        Idles, one cycle at a time, until ``tb.global_inactive`` reads true.
        Then, for every local node index whose global vertex id is contained
        in ``tb.config.graph``, reads the node's word from ``self.mem``,
        converts it with ``convert_int_to_record`` and logs it: at INFO level
        for vertex ids below 32, at DEBUG level otherwise.  PE 0 additionally
        logs a header line first.

        Args:
            tb: testbench object; ``global_inactive`` and ``config`` are read.
        """
        logger = logging.getLogger('sim.apply')

        # Wait for the whole system to go inactive.
        while True:
            if (yield tb.global_inactive):
                break
            yield

        if self.pe_id == 0:
            logger.info("State at end of computation:")

        config = tb.config
        num_local_nodes = len(config.adj_idx[self.pe_id])
        for node in range(num_local_nodes):
            vertexid = config.addresslayout.global_adr(self.pe_id, node)
            if vertexid not in config.graph:
                continue

            # Build the prefix before the memory read, as the original did.
            prefix = "{} (origin={}): ".format(vertexid, config.graph.node[vertexid]["origin"])
            raw_state = yield self.mem[node]
            state = convert_int_to_record(raw_state, config.addresslayout.node_storage_layout)
            line = prefix + str(state)

            # Only the first 32 vertex ids are reported at INFO level.
            report = logger.info if vertexid < 32 else logger.debug
            report(line)