Пример #1
0
    def _initialize(self, system):
        """Create the views into ``self.array`` for the variables of *system*.

        Views are created in two passes:

        1. Each entry of ``system.vector_vars`` whose local size on this MPI
           rank is nonzero gets a contiguous slice of ``self.array``; the
           slices are laid out back to back in the iteration order of
           ``vector_vars``.
        2. If ``vector_vars`` is non-empty, every name in ``system.flat_vars``
           that is not in ``vector_vars`` is added via ``self._add_subview``.

        During pass 1, subvars (``base[...]`` names) whose base variable is
        not itself in the vector are checked for overlapping flattened
        indices, because overlapping subvars would store redundant copies of
        the same data in the vector.

        Raises:
            RuntimeError: if two subvars of the same base share indices, or if
                a view would extend past the end of ``self.array``.
        """
        scope = system.scope
        flat_vars = system.flat_vars
        vector_vars = system.vector_vars
        self.app_ordering = system.app_ordering
        rank = system.mpi.rank

        # vector_vars keys are indexed with [0] to obtain the source var name
        vec_srcs = {n[0] for n in vector_vars}

        # To detect overlapping index sets we need the bases of all subvars
        # that are NOT themselves in the vector.  If a base IS in the vector,
        # its subvars are just subviews of the base view, so they cannot
        # cause redundant data in the vector.
        bases = {}
        for src in vec_srcs:
            base = src.split('[', 1)[0]
            if base not in vec_srcs:
                bases[base] = []

        # base name -> (shape of the base value, set of flat indices already
        # claimed by its subvars).  Filled lazily once a base has 2+ subvars,
        # so each subvar's indices are computed only once.
        claimed = {}

        def flat_idxs(subname, shape):
            # flattened indices of subvar *subname* within a base of *shape*
            _, idx = get_val_and_index(scope, subname)
            return set(get_flattened_index(idx, shape, cvt_to_slice=False))

        # First, add views for vars whose sizes are added to the total,
        # i.e., either they are basevars or their basevars are not included
        # in the vector.
        start = end = 0
        for ivar, name in enumerate(vector_vars):
            sz = system.local_var_sizes[rank, ivar]
            if sz <= 0:
                continue
            end += sz
            view = self.array[start:end]
            # store the view, its local start index and its requested size
            self._info[name] = ViewInfo(view, start, slice(None),
                                        end - start, False)

            subvars = bases.get(name[0].split('[', 1)[0])
            if subvars is not None:
                base = name[0].split('[', 1)[0]
                subvars.append(name[0])
                if len(subvars) > 1:
                    # check the newest subvar against all earlier ones
                    if base not in claimed:
                        shape = get_shape(scope.get(base))
                        claimed[base] = (shape, flat_idxs(subvars[0], shape))
                    shape, idxset = claimed[base]
                    idxs = flat_idxs(subvars[-1], shape)
                    if not idxset.isdisjoint(idxs):
                        raise RuntimeError("Subvars %s share overlapping indices. Try reformulating the problem to prevent this." %
                                           list(subvars))
                    idxset.update(idxs)

            # slicing past the end of self.array silently truncates the view
            if end - start > view.size:
                raise RuntimeError("size mismatch: in system %s view for %s is %s, size=%d" %
                                   (system.name, name, [start, end], self[name].size))
            start += sz

        # Now add views for subvars that are subviews of their basevars.
        if vector_vars:
            for name in flat_vars:
                if name not in vector_vars:
                    self._add_subview(scope, name)
    def test_get_flattened_index(self):
        """Check that get_flattened_index maps an index into an array of the
        given shape onto the equivalent slice of the row-major flattened
        array, and raises IndexError for an out-of-range index."""
        # (expected flat slice, index, shape of the unflattened array)
        cases = [
            (slice(4, 5, None), 4, (10,)),
            (slice(9, 10, None), -1, (10,)),
            (slice(90, 100, 1), -1, (10, 10)),
            (slice(22, 23, None), (2, 2), (10, 10)),
            (slice(42, 63, 10), (slice(4, 7), 2), (10, 10)),
            (slice(40, 61, 10), (slice(4, 7), 0), (10, 10)),
            (slice(4, 11, 2), slice(4, 11, 2), (20,)),
            (slice(40, 50, 1), slice(4, 5, 2), (20, 10)),
            (slice(1, 2, None), 1, (5,)),
            (slice(6, 7, None), [1, 2], (3, 4)),
            (slice(62, 63, None), [-1, -1], (9, 7)),
            (slice(3, 25, 7), [slice(None), 3], (4, 7)),
            (slice(48, 49, None), -2, (50,)),
            (slice(3, 44, 5), slice(3, -3, 5), (50,)),
        ]
        for expected, index, shape in cases:
            self.assertEqual(expected, get_flattened_index(index, shape),
                             "index=%r, shape=%r" % (index, shape))

        # Some versions of numpy have slightly different messages, so as
        # long as it is an IndexError, we are fine.
        with self.assertRaises(IndexError):
            get_flattened_index(10, (10, 10))
    def test_get_flattened_index(self):
        """Check that get_flattened_index maps an index into an array of the
        given shape onto the equivalent slice of the row-major flattened
        array, and raises IndexError for an out-of-range index."""
        # (expected flat slice, index, shape of the unflattened array)
        cases = [
            (slice(4, 5, None), 4, (10, )),
            (slice(9, 10, None), -1, (10, )),
            (slice(90, 100, 1), -1, (10, 10)),
            (slice(22, 23, None), (2, 2), (10, 10)),
            (slice(42, 63, 10), (slice(4, 7), 2), (10, 10)),
            (slice(40, 61, 10), (slice(4, 7), 0), (10, 10)),
            (slice(4, 11, 2), slice(4, 11, 2), (20, )),
            (slice(40, 50, 1), slice(4, 5, 2), (20, 10)),
            (slice(1, 2, None), 1, (5, )),
            (slice(6, 7, None), [1, 2], (3, 4)),
            (slice(62, 63, None), [-1, -1], (9, 7)),
            (slice(3, 25, 7), [slice(None), 3], (4, 7)),
            (slice(48, 49, None), -2, (50, )),
            (slice(3, 44, 5), slice(3, -3, 5), (50, )),
        ]
        for expected, index, shape in cases:
            self.assertEqual(expected, get_flattened_index(index, shape),
                             'index=%r, shape=%r' % (index, shape))

        # Some versions of numpy have slightly different messages, so as
        # long as it is an IndexError, we are fine.
        with self.assertRaises(IndexError):
            get_flattened_index(10, (10, 10))