Ejemplo n.º 1
0
    def create(source_set, sink_set, layout, connections):
        """Build and configure a ``Hub`` for a group of sources and sinks.

        The sources are sorted and flattened, and the buffer structure of each
        flat source is looked up in ``layout``. All sources of one hub must
        share the same buffer type and the same ``is_backward_only`` flag; the
        hub's context size is the maximum context size over its sources.

        Args:
            source_set: collection of sources; entries may be nested groups
                (they are passed through ``flatten`` and
                ``convert_to_nested_indices``).
            sink_set: collection of sinks; passed to ``Hub`` sorted.
            layout: structure from which ``get_by_path`` retrieves the entry
                of every flat source.
            connections: forwarded to ``Hub.setup``.

        Returns:
            The configured ``Hub`` with ``sizes``, ``size`` and
            ``is_backward_only`` filled in.

        Raises:
            AssertionError: if the sources disagree on buffer type or on
                ``is_backward_only``.
        """
        def ensure_uniform(values, what):
            # Explicit raise instead of a bare ``assert``: asserts are stripped
            # under ``python -O``, which would silently pick ``values[0]``.
            # AssertionError is kept so existing callers see the same type.
            lowest, highest = min(values), max(values)
            if lowest != highest:
                raise AssertionError(
                    "All sources of a hub must share the same {}, "
                    "got both {!r} and {!r}".format(what, lowest, highest))
            return values[0]

        sorted_sources = sorted(source_set)
        flat_sources = list(flatten(sorted_sources))
        nesting = convert_to_nested_indices(sorted_sources)

        # get buffer type for hub and make sure it is uniform
        structs = [BufferStructure.from_layout(get_by_path(layout, s))
                   for s in flat_sources]
        btype = ensure_uniform([s.buffer_type for s in structs],
                               'buffer type')
        # max context size over all sources
        context_size = max(s.context_size for s in structs)

        hub = Hub(flat_sources, nesting, sorted(sink_set), btype, context_size)
        hub.setup(connections)
        # hub.perm (presumably populated during setup) indexes into structs,
        # so per-source properties are listed in the hub's buffer order.
        hub.sizes = [structs[i].feature_size for i in hub.perm]
        hub.size = sum(hub.sizes)
        hub.is_backward_only = ensure_uniform(
            [structs[i].is_backward_only for i in hub.perm],
            'is_backward_only flag')
        return hub
Ejemplo n.º 2
0
    def permute_rows(self):
        """
        Given a list of sources and a connection table, find a permutation of
        the sources, such that they can be connected to the sinks via a single
        buffer.

        Permutations of ``self.nesting`` are tried in
        ``itertools.permutations`` order. Each top-level entry is flattened
        only after permuting, so nested groups stay contiguous. The first
        candidate whose permuted connection table passes
        ``Hub.can_be_connected_with_single_buffer`` is committed:
        ``self.perm``, ``self.connection_table`` and ``self.flat_sources`` are
        updated together.

        The search is exhaustive, so its cost grows factorially with the
        number of top-level entries in ``self.nesting``.

        Raises:
            NetworkValidationError: if no permutation works. The instance
                attributes are then left unchanged.
        """
        for candidate in itertools.permutations(self.nesting):
            perm = list(flatten(candidate))
            ct = np.atleast_2d(self.connection_table[perm])
            if Hub.can_be_connected_with_single_buffer(ct):
                # Commit all three attributes together, only on success, so a
                # failed search never leaves a perm that mismatches the table.
                self.perm = perm
                self.connection_table = ct
                self.flat_sources = [self.flat_sources[i] for i in perm]
                return

        raise NetworkValidationError("Failed to lay out buffers. "
                                     "Please change connectivity.")
Ejemplo n.º 3
0
def get_all_sources(forced_orders, connections, layout):
    """Gather all sources while preserving order of the sources.

    Walks the array nodes yielded by ``gather_array_nodes(layout)`` in order
    and skips:

    - sinks (the second item of every connection);
    - ``'parameters'`` and ``'gradients'``.

    Every remaining node is handled as follows:

    - It is skipped if it is already covered by a previously added entry
      (this check only applies when ``forced_orders`` is non-empty).
    - Otherwise, if it belongs to a group in ``forced_orders``, the first
      such group is added as a whole.
    - Otherwise the node itself is added.

    Args:
        forced_orders: iterable of groups (containers supporting ``in``)
            whose members must be kept together.
        connections: sequence of tuples whose item at index 1 is a sink;
            may be empty or ``None``.
        layout: structure passed to ``gather_array_nodes``.

    Returns:
        A list mixing individual nodes and forced-order groups, in order of
        first appearance.
    """
    excluded = {'parameters', 'gradients'}
    if connections:
        excluded.update(conn[1] for conn in connections)
    # Materialise once: the groups are scanned again for every node.
    forced_orders = list(forced_orders)

    all_sources = []
    # Incrementally maintained equivalent of set(flatten(all_sources)); the
    # previous version rebuilt it for every (node, group) pair.
    covered = set()
    for s in gather_array_nodes(layout):
        if s in excluded:
            continue
        if forced_orders and s in covered:
            continue
        group = next((fo for fo in forced_orders if s in fo), None)
        entry = s if group is None else group
        all_sources.append(entry)
        covered.update(flatten([entry]))

    return all_sources
Ejemplo n.º 4
0
    def permute_rows(self):
        """
        Given a list of sources and a connection table, find a permutation of
        the sources, such that they can be connected to the sinks via a single
        buffer.

        Permutations of ``self.nesting`` are tried in
        ``itertools.permutations`` order. Each top-level entry is flattened
        only after permuting, so nested groups stay contiguous. The first
        candidate whose permuted connection table passes
        ``Hub.can_be_connected_with_single_buffer`` is committed:
        ``self.perm``, ``self.connection_table`` and ``self.flat_sources`` are
        updated together.

        The search is exhaustive, so its cost grows factorially with the
        number of top-level entries in ``self.nesting``.

        Raises:
            NetworkValidationError: if no permutation works. The instance
                attributes are then left unchanged.
        """
        for candidate in itertools.permutations(self.nesting):
            perm = list(flatten(candidate))
            ct = np.atleast_2d(self.connection_table[perm])
            if Hub.can_be_connected_with_single_buffer(ct):
                # Commit all three attributes together, only on success, so a
                # failed search never leaves a perm that mismatches the table.
                self.perm = perm
                self.connection_table = ct
                self.flat_sources = [self.flat_sources[i] for i in perm]
                return

        raise NetworkValidationError("Failed to lay out buffers. "
                                     "Please change connectivity.")
Ejemplo n.º 5
0
def get_all_sources(forced_orders, connections, layout):
    """Gather all sources while preserving order of the sources.

    Walks the array nodes yielded by ``gather_array_nodes(layout)`` in order
    and skips:

    - sinks (the second item of every connection);
    - ``'parameters'`` and ``'gradients'``.

    Every remaining node is handled as follows:

    - It is skipped if it is already covered by a previously added entry
      (this check only applies when ``forced_orders`` is non-empty).
    - Otherwise, if it belongs to a group in ``forced_orders``, the first
      such group is added as a whole.
    - Otherwise the node itself is added.

    Args:
        forced_orders: iterable of groups (containers supporting ``in``)
            whose members must be kept together.
        connections: sequence of tuples whose item at index 1 is a sink;
            may be empty or ``None``.
        layout: structure passed to ``gather_array_nodes``.

    Returns:
        A list mixing individual nodes and forced-order groups, in order of
        first appearance.
    """
    excluded = {'parameters', 'gradients'}
    if connections:
        excluded.update(conn[1] for conn in connections)
    # Materialise once: the groups are scanned again for every node.
    forced_orders = list(forced_orders)

    all_sources = []
    # Incrementally maintained equivalent of set(flatten(all_sources)); the
    # previous version rebuilt it for every (node, group) pair.
    covered = set()
    for s in gather_array_nodes(layout):
        if s in excluded:
            continue
        if forced_orders and s in covered:
            continue
        group = next((fo for fo in forced_orders if s in fo), None)
        entry = s if group is None else group
        all_sources.append(entry)
        covered.update(flatten([entry]))

    return all_sources
Ejemplo n.º 6
0
    def create(source_set, sink_set, layout, connections):
        """Build and configure a ``Hub`` for a group of sources and sinks.

        The sources are sorted and flattened, and the buffer structure of each
        flat source is looked up in ``layout``. All sources of one hub must
        share the same buffer type and the same ``is_backward_only`` flag; the
        hub's context size is the maximum context size over its sources.

        Args:
            source_set: collection of sources; entries may be nested groups
                (they are passed through ``flatten`` and
                ``convert_to_nested_indices``).
            sink_set: collection of sinks; passed to ``Hub`` sorted.
            layout: structure from which ``get_by_path`` retrieves the entry
                of every flat source.
            connections: forwarded to ``Hub.setup``.

        Returns:
            The configured ``Hub`` with ``sizes``, ``size`` and
            ``is_backward_only`` filled in.

        Raises:
            AssertionError: if the sources disagree on buffer type or on
                ``is_backward_only``.
        """
        def ensure_uniform(values, what):
            # Explicit raise instead of a bare ``assert``: asserts are stripped
            # under ``python -O``, which would silently pick ``values[0]``.
            # AssertionError is kept so existing callers see the same type.
            lowest, highest = min(values), max(values)
            if lowest != highest:
                raise AssertionError(
                    "All sources of a hub must share the same {}, "
                    "got both {!r} and {!r}".format(what, lowest, highest))
            return values[0]

        sorted_sources = sorted(source_set)
        flat_sources = list(flatten(sorted_sources))
        nesting = convert_to_nested_indices(sorted_sources)

        # get buffer type for hub and make sure it is uniform
        structs = [BufferStructure.from_layout(get_by_path(layout, s))
                   for s in flat_sources]
        btype = ensure_uniform([s.buffer_type for s in structs],
                               'buffer type')
        # max context size over all sources
        context_size = max(s.context_size for s in structs)

        hub = Hub(flat_sources, nesting, sorted(sink_set), btype, context_size)
        hub.setup(connections)
        # hub.perm (presumably populated during setup) indexes into structs,
        # so per-source properties are listed in the hub's buffer order.
        hub.sizes = [structs[i].feature_size for i in hub.perm]
        hub.size = sum(hub.sizes)
        hub.is_backward_only = ensure_uniform(
            [structs[i].is_backward_only for i in hub.perm],
            'is_backward_only flag')
        return hub
Ejemplo n.º 7
0
def test_flatten():
    """flatten() should yield the leaves of nested lists and tuples in order."""
    nested = [0, (1, 2, 3), 4, [5, (6, 7), 8]]
    expected = [0, 1, 2, 3, 4, 5, 6, 7, 8]
    assert list(flatten(nested)) == expected
Ejemplo n.º 8
0
def test_flatten():
    """flatten() should yield the leaves of nested lists and tuples in order."""
    nested = [0, (1, 2, 3), 4, [5, (6, 7), 8]]
    expected = [0, 1, 2, 3, 4, 5, 6, 7, 8]
    assert list(flatten(nested)) == expected