def __init__(self, config, fifos):
        """Merge the per-source message FIFOs of one sink PE into one stream.

        Normal messages are picked round-robin among the FIFOs that are
        readable and do not hold a barrier at their head. Barriers are held
        back until every FIFO presents one; then all of them are popped in the
        same cycle and a single barrier message is written to the output FIFO.
        While ``start_message.select`` is asserted, the start message replaces
        all of this.

        Args:
            config: configuration object; only ``config.addresslayout`` is
                used (``num_pe`` and ``get_params()``).
            fifos: one FIFO per source PE, each exposing ``dout`` (a message
                record with ``barrier`` and ``roundpar`` fields), ``re`` and
                ``readable``.
        """
        addresslayout = config.addresslayout
        num_pe = addresslayout.num_pe

        # output
        self.apply_interface = ApplyInterface(**addresslayout.get_params())

        # input override for injecting the message starting the computation
        self.start_message = ApplyInterface(**addresslayout.get_params())
        self.start_message.select = Signal()

        self.fifos = fifos

        self.submodules.roundrobin = RoundRobin(num_pe, switch_policy=SP_CE)

        # arrays for choosing incoming fifo to use
        array_data = Array(fifo.dout.raw_bits() for fifo in fifos)
        array_re = Array(fifo.re for fifo in fifos)
        array_readable = Array(fifo.readable for fifo in fifos)
        array_barrier = Array(fifo.dout.barrier for fifo in fifos)

        # every input FIFO is readable and has a barrier at its head
        barrier_reached = Signal()
        self.comb += barrier_reached.eq(
            reduce(and_, array_barrier) & reduce(and_, array_readable))

        self.submodules.outfifo = RecordFIFO(
            layout=Message(**addresslayout.get_params()).layout, depth=8)

        grant = self.roundrobin.grant
        self.comb += If(
            self.start_message.select,  # override
            self.outfifo.din.raw_bits().eq(self.start_message.msg.raw_bits()),
            self.outfifo.we.eq(self.start_message.valid),
            self.start_message.ack.eq(self.outfifo.writable),
            self.roundrobin.ce.eq(0)
        ).Elif(
            barrier_reached,
            # Merge the barriers of all inputs into one output barrier. Keep
            # the round parity they carry instead of letting it default to 0;
            # it is taken from the first input (barriers of one round are
            # assumed to agree).
            self.outfifo.din.barrier.eq(1),
            self.outfifo.din.roundpar.eq(fifos[0].dout.roundpar),
            self.outfifo.we.eq(1),
            # pop all barriers together, but only when the output accepts
            [array_re[i].eq(self.outfifo.writable)
             for i in range(len(fifos))]
        ).Else(  # normal roundrobin
            self.outfifo.din.raw_bits().eq(array_data[grant]),
            # a barrier is never forwarded on its own; it waits for the others
            self.outfifo.we.eq(array_readable[grant] & ~array_barrier[grant]),
            array_re[grant].eq(self.outfifo.writable & ~array_barrier[grant]),
            [
                self.roundrobin.request[i].eq(array_readable[i]
                                              & ~array_barrier[i])
                for i in range(len(fifos))
            ],
            self.roundrobin.ce.eq(self.outfifo.writable)
        )

        # drain the output FIFO into the apply interface
        self.comb += [
            self.apply_interface.msg.raw_bits().eq(
                self.outfifo.dout.raw_bits()),
            self.apply_interface.valid.eq(self.outfifo.readable),
            self.outfifo.re.eq(self.apply_interface.ack)
        ]
Example #2
0
    def __init__(self, pe_id, config):
        """Ring-network node serving PE ``pe_id``.

        Selects one message per cycle from either the ring input
        (``network_interface_in``) or the local injection port
        (``local_interface_in``). A message only enters when its ``roundpar``
        equals ``network_round``, and ring traffic takes priority. The selected
        message is delivered locally (through a ``Barriercounter`` to
        ``apply_interface_out``) if its ``dest_pe`` equals ``pe_id``; otherwise
        it is forwarded on ``network_interface_out``. While
        ``start_message.select`` is asserted, ``start_message`` drives
        ``apply_interface_out`` instead of the barrier counter.

        Args:
            pe_id: index of this PE; compared with each message's ``dest_pe``.
            config: configuration object; ``config.addresslayout`` supplies
                interface parameters and ``channel_bits``, and ``config`` is
                passed to the ``Barriercounter``.
        """
        self.network_interface_in = NetworkInterface(
            name="left_in", **config.addresslayout.get_params())
        self.network_interface_out = NetworkInterface(
            name="right_out", **config.addresslayout.get_params())
        self.local_interface_in = NetworkInterface(
            name="local_in", **config.addresslayout.get_params())
        local_interface_out = NetworkInterface(
            name="local_out", **config.addresslayout.get_params())

        self.apply_interface_out = ApplyInterface(
            **config.addresslayout.get_params())
        self.start_message = ApplyInterface(
            **config.addresslayout.get_params())
        self.start_message.select = Signal()

        # round currently carried by the ring (driven from outside)
        self.network_round = Signal(config.addresslayout.channel_bits)
        # round the local barrier counter is accepting (mirrored below)
        self.round_accepting = Signal(config.addresslayout.channel_bits)

        # the message selected for this cycle, before routing
        mux_interface = NetworkInterface(**config.addresslayout.get_params())

        # local messages may only enter when they belong to network_round
        inject = Signal()
        self.comb += [
            inject.eq(
                self.local_interface_in.msg.roundpar == self.network_round),
            # ring traffic of the current round has priority over injection;
            # if neither is selected, mux_interface stays idle
            If(
                self.network_interface_in.valid &
                (self.network_interface_in.msg.roundpar == self.network_round),
                self.network_interface_in.connect(mux_interface)).Elif(
                    inject, self.local_interface_in.connect(mux_interface))
        ]

        # deliver locally or pass on to the next node of the ring
        self.comb += [
            If(mux_interface.dest_pe == pe_id,
               mux_interface.connect(local_interface_out)).Else(
                   mux_interface.connect(self.network_interface_out))
        ]

        self.submodules.barriercounter = Barriercounter(config)

        self.comb += [
            # local delivery feeds the barrier counter
            local_interface_out.msg.connect(
                self.barriercounter.apply_interface_in.msg),
            self.barriercounter.apply_interface_in.valid.eq(
                local_interface_out.valid),
            local_interface_out.ack.eq(
                self.barriercounter.apply_interface_in.ack),
            # start message overrides the regular output while selected
            If(self.start_message.select,
               self.start_message.connect(self.apply_interface_out)).Else(
                   self.barriercounter.apply_interface_out.connect(
                       self.apply_interface_out)),
            self.round_accepting.eq(self.barriercounter.round_accepting)
        ]
Example #3
0
    def __init__(self, config):
        """Build a unidirectional ring of per-PE arbiters.

        Arbiter ``i`` forwards ring traffic through ``fifos[i]`` (depth 8) to
        arbiter ``(i + 1) % num_pe``. Local traffic enters arbiter ``i`` through
        ``network_interface[i]`` and leaves through ``apply_interface[i]``.

        ``network_round`` is broadcast to every arbiter. It advances to the
        next round (wrapping after ``num_channels - 1``) once every arbiter's
        ``round_accepting`` already equals that next round.

        Args:
            config: configuration object; ``config.addresslayout`` supplies
                ``num_pe``, ``channel_bits``, ``num_channels`` and
                ``get_params()``, and ``config`` is passed to each ``Arbiter``.
        """
        num_pe = config.addresslayout.num_pe

        # per-PE output towards the apply stage
        self.apply_interface = [
            ApplyInterface(name="network_out",
                           **config.addresslayout.get_params())
            for _ in range(num_pe)
        ]
        # per-PE input for locally produced messages
        self.network_interface = [
            NetworkInterface(name="network_in",
                             **config.addresslayout.get_params())
            for _ in range(num_pe)
        ]

        # fifos[i] buffers the ring link from arbiter i to its right neighbour
        fifos = [
            InterfaceFIFO(layout=set_layout_parameters(
                _network_layout, **config.addresslayout.get_params()),
                          depth=8) for _ in range(num_pe)
        ]
        self.submodules.fifos = fifos
        self.submodules.arbiter = [
            Arbiter(sink, config) for sink in range(num_pe)
        ]

        for i in range(num_pe):
            j = (i + 1) % num_pe  # right-hand neighbour, wrapping around

            self.comb += [
                self.arbiter[i].network_interface_out.connect(fifos[i].din),
                fifos[i].dout.connect(self.arbiter[j].network_interface_in),
                self.network_interface[i].connect(
                    self.arbiter[i].local_interface_in),
                self.arbiter[i].apply_interface_out.connect(
                    self.apply_interface[i])
            ]

        network_round = Signal(config.addresslayout.channel_bits)
        next_round = Signal(config.addresslayout.channel_bits)
        proceed = Signal()

        self.comb += [
            # advance only once every arbiter is already accepting next_round
            proceed.eq(
                reduce(and_,
                       [a.round_accepting == next_round
                        for a in self.arbiter])),
            # next_round = (network_round + 1) modulo num_channels
            If(network_round < config.addresslayout.num_channels - 1,
               next_round.eq(network_round + 1)).Else(next_round.eq(0)),
            [
                self.arbiter[i].network_round.eq(network_round)
                for i in range(num_pe)
            ]
        ]

        self.sync += If(proceed, network_round.eq(next_round))
Example #4
0
    def __init__(self, pe_id, config):
        """Per-PE arbiter that routes incoming messages through a Barriercounter.

        Messages on ``apply_interface_in`` are copied bit-for-bit into the
        barrier counter; valid/ack are passed through. The counter's output
        drives ``apply_interface_out`` unless ``start_message.select`` is
        asserted, in which case ``start_message`` is forwarded instead.

        Args:
            pe_id: index of the PE this arbiter serves (kept as ``self.pe_id``).
            config: configuration object; ``config.addresslayout`` supplies
                interface parameters and ``channel_bits``, and ``config`` is
                passed to the ``Barriercounter``.
        """
        addresslayout = config.addresslayout
        self.pe_id = pe_id

        # input from the network
        self.apply_interface_in = ApplyInterface(name="arbiter_in",
                                                 **addresslayout.get_params())

        # output
        self.apply_interface_out = ApplyInterface(name="arbiter_out",
                                                  **addresslayout.get_params())

        # input override for injecting the message starting the computation
        self.start_message = ApplyInterface(name="start_message",
                                            **addresslayout.get_params())
        self.start_message.select = Signal()

        self.submodules.barriercounter = Barriercounter(config)
        # round currently accepted by the barrier counter
        self.current_round = Signal(config.addresslayout.channel_bits)

        self.comb += [
            self.barriercounter.apply_interface_in.msg.raw_bits().eq(
                self.apply_interface_in.msg.raw_bits()),
            self.barriercounter.apply_interface_in.valid.eq(
                self.apply_interface_in.valid),
            self.apply_interface_in.ack.eq(
                self.barriercounter.apply_interface_in.ack),
            self.current_round.eq(self.barriercounter.round_accepting)
        ]

        # choose between init and regular message channel
        self.comb += If(
            self.start_message.select,
            self.start_message.connect(self.apply_interface_out)
        ).Else(
            self.barriercounter.apply_interface_out.connect(
                self.apply_interface_out)
        )
Example #5
0
    def __init__(self, config):
        """Create one ``Arbiter`` per PE and expose each arbiter's output.

        Args:
            config: configuration object; ``config.addresslayout`` supplies
                ``num_pe`` and ``get_params()``, and ``config`` is passed to
                each ``Arbiter``.
        """
        num_pe = config.addresslayout.num_pe

        self.apply_interface = [
            ApplyInterface(name="network_out",
                           **config.addresslayout.get_params())
            for _ in range(num_pe)
        ]
        # NOTE(review): network_interface (and the arbiters' inputs) are not
        # connected inside this constructor; confirm they are wired elsewhere.
        self.network_interface = [
            NetworkInterface(name="network_in",
                             **config.addresslayout.get_params())
            for _ in range(num_pe)
        ]

        self.submodules.arbiter = [
            Arbiter(sink, config) for sink in range(num_pe)
        ]
        self.comb += [
            a.apply_interface_out.connect(self.apply_interface[i])
            for i, a in enumerate(self.arbiter)
        ]
    def __init__(self, config):
        """Count per-sender messages and barriers and emit one merged barrier.

        Incoming messages pass through a two-entry input FIFO and are
        registered onto ``apply_interface_out``. For every sender PE the unit
        counts received non-barrier messages (``num_from_pe``). It also records
        that sender's barrier (``barrier_from_pe``) together with the message
        count the barrier announces in its ``dest_id`` field
        (``num_expected_from_pe``). In the DEFAULT state barriers are latched
        but not marked valid on the output. Once barriers from all PEs have
        arrived and every count matches, a single barrier is sent downstream,
        ``round_accepting`` advances (wrapping after ``num_channels - 1``), and
        ``barrier_from_pe``/``num_from_pe`` are cleared.

        The emitted barrier's ``halt`` bit is the current value of an internal
        register. That register is cleared by any incoming barrier without
        ``halt`` and set to 1 after each emitted barrier. Its reset value is 0,
        so the first emitted barrier never carries ``halt``.

        Args:
            config: configuration object; only ``config.addresslayout`` is used.
        """
        self.apply_interface_in = ApplyInterface(
            name="barriercounter_in", **config.addresslayout.get_params())
        self.apply_interface_out = ApplyInterface(
            name="barriercounter_out", **config.addresslayout.get_params())
        # round parity currently accepted; advanced each time a merged
        # barrier is emitted
        self.round_accepting = Signal(config.addresslayout.channel_bits)

        # two-entry input buffer in front of the FSM
        apply_interface_in_fifo = InterfaceFIFO(
            layout=self.apply_interface_in.layout, depth=2)
        self.submodules += apply_interface_in_fifo
        self.comb += self.apply_interface_in.connect(
            apply_interface_in_fifo.din)

        num_pe = config.addresslayout.num_pe

        # per-sender bookkeeping, indexed by PE
        self.barrier_from_pe = Array(Signal() for _ in range(num_pe))
        self.num_from_pe = Array(
            Signal(config.addresslayout.nodeidsize) for _ in range(num_pe))
        self.num_expected_from_pe = Array(
            Signal(config.addresslayout.nodeidsize) for _ in range(num_pe))
        self.all_from_pe = Array(Signal() for _ in range(num_pe))
        self.all_messages_recvd = Signal()
        self.all_barriers_recvd = Signal()

        self.comb += [
            self.all_barriers_recvd.eq(reduce(and_, self.barrier_from_pe)),
            self.all_messages_recvd.eq(reduce(and_, self.all_from_pe)),
        ]

        # a sender is complete once its received count equals the count
        # announced by its barrier
        self.comb += [
            self.all_from_pe[i].eq(
                self.num_from_pe[i] == self.num_expected_from_pe[i])
            for i in range(num_pe)
        ]

        # halt vote register (see docstring); reset value 0
        halt = Signal()

        # NOTE(review): presumably pe_adr maps the sender node id to its PE
        # index -- confirm against addresslayout
        sender_pe = config.addresslayout.pe_adr(
            apply_interface_in_fifo.dout.msg.sender)

        # high while the FSM is in WAIT_FOR_STRAGGLER
        self.waiting_for_stragglers = Signal()

        self.submodules.fsm = FSM()

        # DEFAULT: whenever the output may advance, pop one input entry and
        # register it on the output (barriers are latched, not marked valid).
        # Normal messages bump their sender's count; a barrier records the
        # sender's announced count and moves on to CHK_BARRIER.
        self.fsm.act(
            "DEFAULT",
            If(
                self.apply_interface_out.ack,
                apply_interface_in_fifo.dout.ack.eq(1),
                NextValue(self.apply_interface_out.msg.raw_bits(),
                          apply_interface_in_fifo.dout.msg.raw_bits()),
                NextValue(
                    self.apply_interface_out.valid,
                    apply_interface_in_fifo.dout.valid
                    & ~apply_interface_in_fifo.dout.msg.barrier),
                If(
                    apply_interface_in_fifo.dout.valid,
                    If(
                        apply_interface_in_fifo.dout.msg.barrier,
                        NextValue(self.barrier_from_pe[sender_pe], 1),
                        NextValue(self.num_expected_from_pe[sender_pe],
                                  apply_interface_in_fifo.dout.msg.dest_id),
                        # any sender not voting to halt clears the halt vote
                        If(~apply_interface_in_fifo.dout.msg.halt,
                           NextValue(halt, 0)), NextState("CHK_BARRIER")).Else(
                               NextValue(self.num_from_pe[sender_pe],
                                         self.num_from_pe[sender_pe] + 1)))))

        # CHK_BARRIER: drop output valid; go on to PASS_BARRIER only if
        # every PE has sent its barrier, otherwise back to DEFAULT
        self.fsm.act(
            "CHK_BARRIER", apply_interface_in_fifo.dout.ack.eq(0),
            If(
                self.apply_interface_out.ack,
                NextValue(self.apply_interface_out.valid, 0),
                If(self.all_barriers_recvd,
                   NextState("PASS_BARRIER")).Else(NextState("DEFAULT"))))

        # PASS_BARRIER: if every count matches, advance the round, emit the
        # merged barrier and reset per-PE state; otherwise wait for messages
        # still in flight.
        # NOTE(review): only the barrier and halt fields of the output message
        # are set here; its other fields keep the last latched message's bits.
        self.fsm.act(
            "PASS_BARRIER", apply_interface_in_fifo.dout.ack.eq(0),
            If(
                self.apply_interface_out.ack,
                If(
                    self.all_messages_recvd,
                    If(
                        self.round_accepting <
                        config.addresslayout.num_channels - 1,
                        NextValue(self.round_accepting,
                                  self.round_accepting + 1)).Else(
                                      NextValue(self.round_accepting, 0)),
                    NextValue(halt, 1),
                    NextValue(self.apply_interface_out.msg.halt, halt),
                    NextValue(self.apply_interface_out.msg.barrier, 1),
                    NextValue(self.apply_interface_out.valid, 1), [
                        NextValue(self.barrier_from_pe[i], 0)
                        for i in range(num_pe)
                    ],
                    [NextValue(self.num_from_pe[i], 0) for i in range(num_pe)],
                    NextState("DEFAULT")).Else(
                        NextValue(self.apply_interface_out.valid, 0),
                        NextState("WAIT_FOR_STRAGGLER"))))

        # WAIT_FOR_STRAGGLER: forward and count late messages, re-checking
        # after each one
        self.fsm.act(
            "WAIT_FOR_STRAGGLER",
            self.waiting_for_stragglers.eq(1),
            If(
                self.apply_interface_out.ack,
                apply_interface_in_fifo.dout.ack.eq(1),
                NextValue(self.apply_interface_out.msg.raw_bits(),
                          apply_interface_in_fifo.dout.msg.raw_bits()),
                NextValue(self.apply_interface_out.valid,
                          apply_interface_in_fifo.dout.valid),
                If(
                    apply_interface_in_fifo.dout.valid,
                    NextValue(self.num_from_pe[sender_pe],
                              self.num_from_pe[sender_pe] + 1),
                    NextState(
                        "CHK_BARRIER"
                    )  # this re-checks all_barriers_recvd needlessly, but it gives all_messages_recvd the extra cycle it needs to update
                )))
Example #7
0
    def __init__(self, config):
        """Top-level accelerator core for a single FPGA.

        Instantiates a DDR port sharer, the message network, and one Apply and
        one Scatter unit per local PE, and wires them together (the wiring
        depends on ``config.inverted``). It also derives the status signals
        ``global_inactive``, ``kernel_error``, ``deadlock`` and
        ``total_num_messages``. While ``start`` is asserted, a start barrier
        is injected into every local PE. ``done`` and ``cycle_count`` report
        the progress of the computation.

        Args:
            config: configuration object; reads ``addresslayout``,
                ``memtype``, ``has_edgedata`` and ``inverted``, and is passed
                to the submodules.

        Raises:
            AssertionError: if ``config.addresslayout.num_fpga`` is not 1.
            NotImplementedError: if edge data is requested with a memory type
                other than ``"BRAM"``.
        """
        self.config = config
        # only single-FPGA configurations are handled here
        # NOTE(review): this check is an assert, so it is skipped under -O
        assert config.addresslayout.num_fpga == 1
        fpga_id = 0
        # global ids [pe_start, pe_end) of the PEs instantiated on this FPGA
        self.pe_start = pe_start = fpga_id * config.addresslayout.num_pe_per_fpga
        self.pe_end = pe_end = min(
            (fpga_id + 1) * config.addresslayout.num_pe_per_fpga,
            config.addresslayout.num_pe)
        num_local_pe = pe_end - pe_start

        if config.memtype != "BRAM" and config.has_edgedata:
            raise NotImplementedError()

        self.submodules.portsharer = DDRPortSharer(config=config,
                                                   num_ports=num_local_pe)

        # the inverted and regular variants use different module sets
        if config.inverted:
            from inverted_apply import Apply
            from inverted_scatter import Scatter
            from inverted_network import UpdateNetwork
            self.submodules.network = UpdateNetwork(config)
        else:
            from fifo_network import Network
            from core_apply import Apply
            from core_scatter import Scatter
            self.submodules.network = Network(config)

        self.submodules.apply = [
            Apply(config, i) for i in range(pe_start, pe_end)
        ]
        # each Scatter gets its own port on the shared DDR interface
        self.submodules.scatter = [
            Scatter(i, config, port=self.portsharer.get_port(i - pe_start))
            for i in range(pe_start, pe_end)
        ]

        if config.inverted:
            # connect among PEs
            for i in range(num_local_pe):
                self.comb += [
                    self.apply[i].scatter_interface.connect(
                        self.network.apply_interface_in[i]),
                    self.network.scatter_interface_out[i].connect(
                        self.scatter[i].scatter_interface)
                ]
            # connection within PEs is done at start_message
        else:
            # connect within PEs
            self.comb += [
                self.apply[i].scatter_interface.connect(
                    self.scatter[i].scatter_interface)
                for i in range(num_local_pe)
            ]

            # connect to network
            self.comb += [
                self.network.apply_interface[i].connect(
                    self.apply[i].apply_interface) for i in range(num_local_pe)
            ]
            self.comb += [
                self.scatter[i].network_interface.connect(
                    self.network.network_interface[i])
                for i in range(num_local_pe)
            ]

        # state of calculation
        self.global_inactive = Signal()
        if config.inverted:
            self.comb += self.global_inactive.eq(self.network.inactive)
        else:
            # inactive only when every local Apply unit is inactive
            self.comb += self.global_inactive.eq(
                reduce(and_, [pe.inactive for pe in self.apply]))

        # any kernel reporting an error raises kernel_error
        self.kernel_error = Signal()
        self.comb += self.kernel_error.eq(
            reduce(or_,
                   (a.gatherapplykernel.kernel_error for a in self.apply)))

        # any Apply unit reporting deadlock raises deadlock
        self.deadlock = Signal()
        self.comb += self.deadlock.eq(
            reduce(or_, [pe.deadlock for pe in self.apply]))

        self.total_num_messages = Signal(32)
        self.comb += [
            self.total_num_messages.eq(
                sum(scatter.total_num_messages for scatter in self.scatter))
        ]

        # one start-message port per local PE
        if config.inverted:
            start_message = [
                ApplyInterface(name="start_message",
                               **config.addresslayout.get_params())
                for i in range(num_local_pe)
            ]
            for i in range(num_local_pe):
                start_message[i].select = Signal()
                # start message overrides Scatter -> Apply while selected
                self.comb += [
                    If(start_message[i].select, start_message[i].connect(
                        self.apply[i].apply_interface)).Else(
                            self.scatter[i].apply_interface.connect(
                                self.apply[i].apply_interface))
                ]
        else:
            start_message = [a.start_message for a in self.network.arbiter]
        assert len(start_message) == num_local_pe

        # injected[i]: PE i has accepted its start message
        injected = [Signal() for i in range(num_local_pe)]

        self.start = Signal()
        init = Signal()
        self.done = Signal()
        self.cycle_count = Signal(32)

        # registered: high while start is asserted and some PE still has not
        # accepted its start message
        self.sync += [init.eq(self.start & ~reduce(and_, injected))]

        self.comb += [self.done.eq(~init & self.global_inactive)]

        # start message: a barrier tagged with the last round parity
        # (num_channels - 1), offered until the PE acknowledges it
        for i in range(num_local_pe):
            self.comb += [
                start_message[i].select.eq(init),
                start_message[i].msg.barrier.eq(1),
                start_message[i].msg.roundpar.eq(
                    config.addresslayout.num_channels - 1),
                start_message[i].valid.eq(~injected[i])
            ]

        # latch each acknowledgement; cycle_count stays 0 until every PE is
        # injected, then counts while any PE is still active
        self.sync += [[
            If(start_message[i].ack, injected[i].eq(1))
            for i in range(num_local_pe)
        ],
                      If(~reduce(and_, injected), self.cycle_count.eq(0)).Elif(
                          ~self.global_inactive,
                          self.cycle_count.eq(self.cycle_count + 1))]
 def __init__(self, config):
     """Declare the barrier counter's input/output ports and round register.

     Args:
         config: configuration object; ``config.addresslayout`` provides the
             interface parameters and ``channel_bits``.
     """
     layout = config.addresslayout
     self.apply_interface_in = ApplyInterface(name="barriercounter_in",
                                              **layout.get_params())
     self.apply_interface_out = ApplyInterface(name="barriercounter_out",
                                               **layout.get_params())
     # round parity currently accepted, channel_bits wide
     self.round_accepting = Signal(layout.channel_bits)
    def __init__(self, config):
        """Fully connected network built from one FIFO per (sink, source) pair.

        ``fifos[sink][source]`` buffers messages sent by PE ``source`` to PE
        ``sink``. Each sink's ``Arbiter`` merges its row of FIFOs into
        ``apply_interface[sink]``.

        A normal message on ``network_interface[source]`` is written into the
        FIFO of its ``dest_pe``. A barrier is broadcast: it is written (with its
        ``roundpar``) into the FIFO towards every sink. The source is
        acknowledged only once all of those FIFOs have taken it.

        Args:
            config: configuration object; ``config.addresslayout`` supplies
                ``num_pe``, ``peidsize`` and ``get_params()``, and ``config``
                is passed to each ``Arbiter``.
        """
        num_pe = config.addresslayout.num_pe

        self.apply_interface = [
            ApplyInterface(**config.addresslayout.get_params())
            for _ in range(num_pe)
        ]
        self.network_interface = [
            NetworkInterface(**config.addresslayout.get_params())
            for _ in range(num_pe)
        ]

        # fifos[sink][source]
        fifos = [[
            RecordFIFO(
                layout=Message(**config.addresslayout.get_params()).layout,
                depth=8) for _source in range(num_pe)
        ] for _sink in range(num_pe)]
        self.submodules.fifos = fifos
        self.submodules.arbiter = [
            Arbiter(config, fifos[sink]) for sink in range(num_pe)
        ]

        self.comb += [
            self.arbiter[i].apply_interface.connect(self.apply_interface[i])
            for i in range(num_pe)
        ]

        # connect fifos across PEs
        for source in range(num_pe):
            network_in = self.network_interface[source]

            # write sides of all FIFOs fed by this source, indexed by sink PE;
            # built once and shared by every Array below
            column = [fifos[dst][source] for dst in range(num_pe)]
            array_dest_id = Array(fifo.din.dest_id for fifo in column)
            array_sender = Array(fifo.din.sender for fifo in column)
            array_payload = Array(fifo.din.payload for fifo in column)
            array_roundpar = Array(fifo.din.roundpar for fifo in column)
            array_barrier = Array(fifo.din.barrier for fifo in column)
            array_we = Array(fifo.we for fifo in column)
            array_writable = Array(fifo.writable for fifo in column)

            have_barrier = Signal()
            # barrier_ack[i]: the FIFO towards sink i has taken the barrier
            barrier_ack = Array(Signal() for _ in range(num_pe))
            barrier_done = Signal()

            self.comb += [
                barrier_done.eq(reduce(and_, barrier_ack)),
                have_barrier.eq(network_in.msg.barrier & network_in.valid),
            ]

            # accumulate acknowledgements while a barrier is pending; clear
            # them once it has been delivered everywhere or none is pending
            self.sync += If(have_barrier & ~barrier_done, [
                barrier_ack[i].eq(barrier_ack[i] | array_writable[i])
                for i in range(num_pe)
            ]).Else([barrier_ack[i].eq(0) for i in range(num_pe)])

            sink = Signal(config.addresslayout.peidsize)

            self.comb += If(
                have_barrier,
                # broadcast to every sink that has not yet taken the barrier
                [array_barrier[i].eq(1) for i in range(num_pe)],
                [array_roundpar[i].eq(network_in.msg.roundpar)
                 for i in range(num_pe)],
                [array_we[i].eq(~barrier_ack[i]) for i in range(num_pe)],
                network_in.ack.eq(barrier_done)
            ).Else(
                # unicast into the FIFO of the destination PE
                sink.eq(network_in.dest_pe),
                array_dest_id[sink].eq(network_in.msg.dest_id),
                array_sender[sink].eq(network_in.msg.sender),
                array_payload[sink].eq(network_in.msg.payload),
                array_roundpar[sink].eq(network_in.msg.roundpar),
                array_we[sink].eq(network_in.valid),
                network_in.ack.eq(array_writable[sink])
            )
Example #10
0
    def __init__(self, config, pe_id):
        """Apply stage of processing element ``pe_id``.

        Incoming messages (``apply_interface``) are buffered in an 8-entry
        FIFO. In the GATHER state each message is passed, together with the
        stored state of its destination node, to the gather/apply kernel. A
        ``CollisionDetector`` stalls reads that conflict with pending writes.
        A barrier without ``halt`` moves the FSM through FLUSH (wait until the
        collision detector reports all clear) and APPLY (feed every local node
        to the kernel). It then passes through BARRIER_SEND/BARRIER_WAIT
        before gathering resumes. Updates produced by the kernel are queued in
        ``outfifo`` and registered onto ``scatter_interface``.

        Args:
            config: configuration object (address layout, adjacency index,
                initial node data, kernel factories, FIFO options).
            pe_id: index of this PE; selects its entries of
                ``config.adj_idx``/``config.init_nodedata`` and its node-id
                base ``pe_id << log2_int(num_nodes_per_pe)``.
        """
        self.config = config
        self.pe_id = pe_id
        addresslayout = config.addresslayout
        nodeidsize = addresslayout.nodeidsize
        num_nodes_per_pe = addresslayout.num_nodes_per_pe
        # one state entry per local node plus one, at least two
        # NOTE(review): presumably the extra entry covers the address read
        # when APPLY reaches node_idx == base + len(adj_idx) -- confirm
        num_valid_nodes = max(2, len(config.adj_idx[pe_id])+1)

        # input message-queue interface
        self.apply_interface = ApplyInterface(name="apply_in", **addresslayout.get_params())

        # scatter interface
        # send self.update message to all neighbors
        # message format (sending_node_id) (normally would be (sending_node_id, weight), but for PR weight = sending_node_id)
        self.scatter_interface = ScatterInterface(name="apply_out", **addresslayout.get_params())

        self.deadlock = Signal()

        ####

        apply_interface_in_fifo = InterfaceFIFO(layout=self.apply_interface.layout, depth=8, name="apply_in_fifo")
        self.submodules += apply_interface_in_fifo
        self.comb += self.apply_interface.connect(apply_interface_in_fifo.din)

        # local node data storage
        self.specials.mem = Memory(layout_len(addresslayout.node_storage_layout), num_valid_nodes, init=config.init_nodedata[pe_id] if config.init_nodedata else None, name="vertex_data_{}".format(self.pe_id))
        rd_port = self.specials.rd_port = self.mem.get_port(has_re=True)
        wr_port = self.specials.wr_port = self.mem.get_port(write_capable=True)

        # the write port is shared: an external writer takes over while
        # external_wr_port.select is set, otherwise the pipeline writes
        local_wr_port = Record(layout=get_mem_port_layout(wr_port))
        self.external_wr_port = Record(layout=get_mem_port_layout(wr_port) + [("select", 1)])

        self.comb += [
            If(self.external_wr_port.select,
                self.external_wr_port.connect(wr_port, omit={"select"})
            ).Else(
                local_wr_port.connect(wr_port)
            )
        ]

        # detect termination (now done by collating votes to halt in barriercounter - if barrier is passed on with halt bit set, don't propagate)
        # inactive becomes 1 once a barrier with halt set is consumed and is
        # never cleared in this block
        self.inactive = Signal()
        self.sync += If(apply_interface_in_fifo.dout.valid & apply_interface_in_fifo.dout.ack & apply_interface_in_fifo.dout.msg.barrier & apply_interface_in_fifo.dout.msg.halt,
            self.inactive.eq(1)
        )

        # should pipeline advance?
        upstream_ack = Signal()
        collision_re = Signal()
        collision_en = Signal()

        # count levels (incremented once per completed barrier round)
        self.level = Signal(32)

        ## Stage 1
        # rename some signals for easier reading, separate barrier and normal valid (for writing to state mem)
        dest_node_id = Signal(nodeidsize)
        sender = Signal(nodeidsize)
        payload = Signal(addresslayout.messagepayloadsize)
        roundpar = Signal(config.addresslayout.channel_bits)
        valid = Signal()
        barrier = Signal()

        # valid: a normal message; barrier: a barrier without halt set
        self.comb += [
            dest_node_id.eq(apply_interface_in_fifo.dout.msg.dest_id),
            sender.eq(apply_interface_in_fifo.dout.msg.sender),
            payload.eq(apply_interface_in_fifo.dout.msg.payload),
            roundpar.eq(apply_interface_in_fifo.dout.msg.roundpar),
            valid.eq(apply_interface_in_fifo.dout.valid & ~apply_interface_in_fifo.dout.msg.barrier),
            barrier.eq(apply_interface_in_fifo.dout.valid & apply_interface_in_fifo.dout.msg.barrier & ~apply_interface_in_fifo.dout.msg.halt),
        ]

        ## Stage 2
        dest_node_id2 = Signal(nodeidsize)
        sender2 = Signal(nodeidsize)
        payload2 = Signal(addresslayout.messagepayloadsize)
        roundpar2 = Signal(config.addresslayout.channel_bits)
        barrier2 = Signal()
        valid2 = Signal()
        ready = Signal()
        msgvalid2 = Signal()
        statevalid2 = Signal()

        state_barrier = Signal()

        node_idx = Signal(nodeidsize)
        gather_done = Signal()

        # next_roundpar = (roundpar + 1) modulo num_channels
        next_roundpar = Signal(config.addresslayout.channel_bits)
        self.comb += If(roundpar==config.addresslayout.num_channels-1, next_roundpar.eq(0)).Else(next_roundpar.eq(roundpar+1))

        self.submodules.fsm = FSM()
        # GATHER: read the destination node's state and hand the message to
        # stage 2; a collision inserts a bubble, a (non-halting) barrier
        # switches to FLUSH
        self.fsm.act("GATHER",
            rd_port.re.eq(upstream_ack),
            apply_interface_in_fifo.dout.ack.eq(upstream_ack),
            rd_port.adr.eq(addresslayout.local_adr(dest_node_id)),
            NextValue(collision_en, 1),
            If(~collision_re,
                NextValue(valid2, 0) # insert bubble if collision
            ).Elif(upstream_ack,
                NextValue(valid2, valid),
                NextValue(dest_node_id2, dest_node_id),
                NextValue(sender2, sender),
                NextValue(payload2, payload),
                NextValue(roundpar2, next_roundpar),
                NextValue(statevalid2, 1),
                NextValue(msgvalid2, ~barrier),
                If(barrier,
                    NextValue(collision_en, 0),
                    NextValue(valid2, 0),
                    NextState("FLUSH")
                )
            )
        )
        # FLUSH: stop consuming input, point node_idx at the first local node
        # and wait until the collision detector reports all clear
        self.fsm.act("FLUSH",
            rd_port.re.eq(0),
            NextValue(node_idx, pe_id << log2_int(num_nodes_per_pe)),
            apply_interface_in_fifo.dout.ack.eq(0),
            rd_port.adr.eq(addresslayout.local_adr(dest_node_id)),
            If(gather_done,
                NextState("APPLY")
            )
        )
        # APPLY: feed each local node to the kernel in turn; on reaching
        # base + len(adj_idx[pe_id]), issue the barrier instead
        self.fsm.act("APPLY",
            rd_port.re.eq(ready),
            apply_interface_in_fifo.dout.ack.eq(0),
            rd_port.adr.eq(addresslayout.local_adr(node_idx)),
            If(ready,
                NextValue(valid2, 1),
                NextValue(dest_node_id2, node_idx),
                NextValue(node_idx, node_idx+1),
                If(node_idx==(len(config.adj_idx[pe_id]) + (pe_id << log2_int(num_nodes_per_pe))),
                    NextValue(statevalid2, 0),
                    NextValue(barrier2, 1),
                    NextValue(valid2, 1),
                    NextState("BARRIER_SEND")
                )
            )
        )
        # BARRIER_SEND: once the kernel takes the barrier, resume gathering
        # if its state barrier is already visible, otherwise wait for it
        self.fsm.act("BARRIER_SEND",
            apply_interface_in_fifo.dout.ack.eq(0),
            rd_port.adr.eq(addresslayout.local_adr(node_idx)),
            If(ready,
                NextValue(barrier2, 0),
                NextValue(valid2, 0),
                If(state_barrier,
                    NextValue(self.level, self.level+1),
                    NextState("GATHER")
                ).Else(
                    NextState("BARRIER_WAIT")
                )
            )
        )
        # BARRIER_WAIT: wait for the kernel's state barrier
        self.fsm.act("BARRIER_WAIT",
            If(state_barrier,
                NextValue(self.level, self.level+1),
                NextState("GATHER")
            )
        )

        # collision handling (combinatorial)
        self.submodules.collisiondetector = CollisionDetector(addresslayout)

        self.comb += [
            self.collisiondetector.read_adr.eq(addresslayout.local_adr(dest_node_id)),
            self.collisiondetector.read_adr_valid.eq(ready & valid & collision_en), # can't be rd_port.re because that uses collisiondetector.re -> comb loop
            self.collisiondetector.write_adr.eq(local_wr_port.adr),
            self.collisiondetector.write_adr_valid.eq(local_wr_port.we),
            collision_re.eq(self.collisiondetector.re),
            gather_done.eq(self.collisiondetector.all_clear)
        ]

        # User code
        if hasattr(config, "gatherapplykernel"):
            self.submodules.gatherapplykernel = config.gatherapplykernel(config)
        else:
            self.submodules.gatherapplykernel = GatherApplyWrapper(config.gatherkernel(config), config.applykernel(config))

        # stage 2 registers and the memory read data feed the kernel; input
        # advances when the kernel is ready (or stage 2 is empty) and no
        # collision is pending
        self.comb += [
            self.gatherapplykernel.level_in.eq(self.level),
            self.gatherapplykernel.nodeid_in.eq(dest_node_id2),
            self.gatherapplykernel.sender_in.eq(sender2),
            self.gatherapplykernel.message_in.raw_bits().eq(payload2),
            self.gatherapplykernel.message_in_valid.eq(msgvalid2),
            self.gatherapplykernel.state_in.raw_bits().eq(rd_port.dat_r),
            self.gatherapplykernel.state_in_valid.eq(statevalid2),
            self.gatherapplykernel.round_in.eq(roundpar2),
            self.gatherapplykernel.barrier_in.eq(barrier2),
            self.gatherapplykernel.valid_in.eq(valid2),
            ready.eq(self.gatherapplykernel.ready),
            upstream_ack.eq((self.gatherapplykernel.ready | ~valid2) & collision_re)
        ]

        # write state updates (the kernel's state output is always accepted)
        self.comb += [
            local_wr_port.adr.eq(addresslayout.local_adr(self.gatherapplykernel.nodeid_out)),
            local_wr_port.dat_w.eq(self.gatherapplykernel.state_out.raw_bits()),
            state_barrier.eq(self.gatherapplykernel.state_barrier),
            local_wr_port.we.eq(self.gatherapplykernel.state_valid),
            self.gatherapplykernel.state_ack.eq(1)
        ]

        # output handling
        _layout = [
        ( "barrier", 1, DIR_M_TO_S ),
        ( "roundpar", config.addresslayout.channel_bits, DIR_M_TO_S ),
        ( "sender", "nodeidsize", DIR_M_TO_S ),
        ( "msg" , addresslayout.updatepayloadsize, DIR_M_TO_S )
        ]
        outfifo_in = Record(set_layout_parameters(_layout, **addresslayout.get_params()))
        outfifo_out = Record(set_layout_parameters(_layout, **addresslayout.get_params()))

        # update queue: HMC-backed (deadlock latches once it is full) or an
        # on-chip SyncFIFO (deadlock mirrors "not writable")
        if config.updates_in_hmc:
            fpga_id = pe_id//config.addresslayout.num_pe_per_fpga
            local_pe_id = pe_id % config.addresslayout.num_pe_per_fpga
            self.submodules.outfifo = HMCBackedFIFO(width=len(outfifo_in), start_addr=local_pe_id*(1<<config.hmc_fifo_bits), end_addr=(local_pe_id + 1)*(1<<config.hmc_fifo_bits), port=config.platform[fpga_id].getHMCPort(local_pe_id))

            self.sync += [
                If(self.outfifo.full, self.deadlock.eq(1))
            ]
        else:
            self.submodules.outfifo = SyncFIFO(width=len(outfifo_in), depth=num_valid_nodes)
            self.comb += self.deadlock.eq(~self.outfifo.writable)

        self.comb += [
            self.outfifo.din.eq(outfifo_in.raw_bits()),
            outfifo_out.raw_bits().eq(self.outfifo.dout)
        ]

        # barrier updates carry this PE's base node id as sender
        self.comb += [
            outfifo_in.msg.eq(self.gatherapplykernel.update_out.raw_bits()),
            If(self.gatherapplykernel.barrier_out, outfifo_in.sender.eq(pe_id << log2_int(num_nodes_per_pe))
            ).Else(outfifo_in.sender.eq(self.gatherapplykernel.update_sender)),
            self.outfifo.we.eq(self.gatherapplykernel.update_valid),
            outfifo_in.roundpar.eq(self.gatherapplykernel.update_round),
            outfifo_in.barrier.eq(self.gatherapplykernel.barrier_out),
            self.gatherapplykernel.update_ack.eq(self.outfifo.writable)
        ]

        # stage 4: output register towards scatter, loaded whenever the
        # receiver acknowledges
        payload4 = Signal(addresslayout.updatepayloadsize)
        sender4 = Signal(addresslayout.nodeidsize)
        roundpar4 = Signal(config.addresslayout.channel_bits)
        barrier4 = Signal()
        valid4 = Signal()

        self.sync += If(self.scatter_interface.ack,
            payload4.eq(outfifo_out.msg),
            sender4.eq(outfifo_out.sender),
            roundpar4.eq(outfifo_out.roundpar),
            barrier4.eq(outfifo_out.barrier),
            valid4.eq(self.outfifo.readable)
        )

        self.comb += [
            self.scatter_interface.payload.eq(payload4),
            self.scatter_interface.sender.eq(sender4),
            self.scatter_interface.roundpar.eq(roundpar4),
            self.scatter_interface.barrier.eq(barrier4),
            self.scatter_interface.valid.eq(valid4)
        ]

        # send from fifo when receiver ready
        self.comb += self.outfifo.re.eq(self.scatter_interface.ack)