예제 #1
0
    def _init_colored_approximations(self, system):
        """
        Build the per-color column groups used for a colored jacobian approximation.

        Always resets ``self._colored_approx_groups`` to an empty list and
        ``self._j_colored``, ``self._j_data_sizes`` and ``self._j_data_offsets`` to
        None.  If ``system._coloring_info['coloring']`` is a ``Coloring`` object, one
        ``(data, cols, tmpJ, idx_info, nzrows)`` tuple is appended to
        ``self._colored_approx_groups`` for every ``(cols, nzrows)`` pair yielded by
        ``coloring.color_nonzero_iter('fwd')``.

        Parameters
        ----------
        system : System
            System whose jacobian is being approximated.  Any Group sets
            ``is_total``; the top level Group (empty pathname) additionally takes
            the "top level approx totals" branches below.
        """
        from openmdao.core.group import Group
        from openmdao.core.implicitcomponent import ImplicitComponent

        self._colored_approx_groups = []
        self._j_colored = None
        self._j_data_sizes = None
        self._j_data_offsets = None

        # don't do anything if the coloring doesn't exist yet
        coloring = system._coloring_info['coloring']
        if not isinstance(coloring, coloring_mod.Coloring):
            return

        outputs = system._outputs
        inputs = system._inputs
        abs2meta = system._var_allprocs_abs2meta
        prom2abs_out = system._var_allprocs_prom2abs_list['output']
        prom2abs_in = system._var_allprocs_prom2abs_list['input']
        approx_wrt_idx = system._owns_approx_wrt_idx

        out_slices = outputs.get_slice_dict()
        # NOTE(review): in_slices is not used anywhere below in this method.
        in_slices = inputs.get_slice_dict()

        # True for every Group, not only the top level one.
        is_total = isinstance(system, Group)

        # refresh system._coloring_info['wrt_matches'] before reading it
        system._update_wrt_matches()
        # NOTE(review): wrt_matches is used with `in` and `len()` below, so it is
        # assumed to never be None here — confirm.
        wrt_matches = system._coloring_info['wrt_matches']

        # Collect every approximation whose wrt (key[0]) is colored and whose first
        # entry's options request coloring.  `data` comes from the first such
        # approximation; `keys` gathers the first element of each of their entries.
        data = None
        keys = set()
        for key, apprx in iteritems(self._exec_dict):
            if key[0] in wrt_matches:
                options = apprx[0][1]
                if 'coloring' in options:
                    if data is None:
                        # data is the same for all colored approxs so we only need the first
                        data = self._get_approx_data(system, key)
                    keys.update(a[0] for a in apprx)

        # Determine the 'of' and 'wrt' variables.  For top level totals the full
        # wrt list is all outputs followed by all inputs; otherwise the partials
        # wrt names are converted to absolute names (first match, inputs checked
        # before outputs).
        if is_total and system.pathname == '':  # top level approx totals
            of_names = system._owns_approx_of
            full_wrts = system._var_allprocs_abs_names['output'] + \
                system._var_allprocs_abs_names['input']
            wrt_names = system._owns_approx_wrt
        else:
            of_names, wrt_names = system._get_partials_varlists()
            wrt_names = [
                prom2abs_in[n][0] if n in prom2abs_in else prom2abs_out[n][0]
                for n in wrt_names
            ]
            full_wrts = wrt_names

        # Bookkeeping dict shared (same object) by every color group appended below.
        tmpJ = {
            '@nrows': coloring._shape[0],
            '@ncols': coloring._shape[1],
            '@out_slices': out_slices,
            '@approxs': keys,
            '@jac_slices': {},
        }

        # FIXME: need to deal with mix of local/remote indices

        len_full_ofs = len(system._var_allprocs_abs_names['output'])

        # Fill '@jac_slices' with (of, wrt) -> (row slice, col slice).  For Groups,
        # when 'of' indices are restricted or not every output is an 'of' variable,
        # also record each 'of' variable's output-vector indices so their
        # concatenation can map jac rows to output-vector rows ('@row_idx_map').
        full_idxs = []
        approx_of_idx = system._owns_approx_of_idx
        jac_slices = tmpJ['@jac_slices']
        for abs_of, roffset, rend, _ in system._jacobian_of_iter():
            rslice = slice(roffset, rend)
            for abs_wrt, coffset, cend, _ in system._jacobian_wrt_iter(
                    wrt_matches):
                jac_slices[(abs_of, abs_wrt)] = (rslice, slice(coffset, cend))

            if is_total and (approx_of_idx or len_full_ofs > len(of_names)):
                slc = out_slices[abs_of]
                if abs_of in approx_of_idx:
                    full_idxs.append(
                        np.arange(slc.start, slc.stop)[approx_of_idx[abs_of]])
                else:
                    full_idxs.append(range(slc.start, slc.stop))
        if full_idxs:
            tmpJ['@row_idx_map'] = np.hstack(full_idxs)

        # A column map is only needed when the colored wrt set differs in length
        # from the full wrt list, or when wrt indices are restricted.
        if len(full_wrts) != len(wrt_matches) or approx_wrt_idx:
            if is_total and system.pathname == '':  # top level approx totals
                # NOTE(review): full_wrts here lists ALL outputs + inputs, but these
                # sizes are built from wrt_names (system._owns_approx_wrt).  If
                # sub2full_indices pairs names and sizes positionally, the two may
                # not line up — confirm against sub2full_indices.
                full_wrt_sizes = [abs2meta[wrt]['size'] for wrt in wrt_names]
            else:
                _, full_wrt_sizes = system._get_partials_var_sizes()

            # need mapping from coloring jac columns (subset) to full jac columns
            col_map = sub2full_indices(full_wrts, wrt_matches, full_wrt_sizes,
                                       approx_wrt_idx)
        else:
            col_map = None

        # get groups of columns from the coloring and compute proper indices into
        # the inputs and outputs vectors.
        # is_semi is the (truthy) pathname for a non-top-level Group, falsy otherwise.
        is_semi = is_total and system.pathname
        use_full_cols = isinstance(system, ImplicitComponent) or is_semi
        for cols, nzrows in coloring.color_nonzero_iter('fwd'):
            ccols = cols if col_map is None else col_map[cols]
            idx_info = get_input_idx_split(ccols, inputs, outputs,
                                           use_full_cols, is_total)
            self._colored_approx_groups.append(
                (data, cols, tmpJ, idx_info, nzrows))
예제 #2
0
    def _init_colored_approximations(self, system):
        """
        Build the per-color column groups used for a colored jacobian approximation.

        Resets ``self._colored_approx_groups`` to an empty list.  If ``system``
        already has a ``Coloring`` object, appends one
        ``(data, jaccols, vec_ind_list, nzrows)`` tuple to it for every
        ``(cols, nzrows)`` pair yielded by ``coloring.color_nonzero_iter('fwd')``.
        Here ``jaccols`` are full-jacobian column indices and ``nzrows`` are mapped
        to full-jacobian row indices.

        Parameters
        ----------
        system : System
            System whose jacobian is being approximated.  The top level Group
            (empty pathname) approximates totals; any other Group approximates
            semi-totals.
        """
        from openmdao.core.group import Group
        from openmdao.core.implicitcomponent import ImplicitComponent

        is_group = isinstance(system, Group)
        is_total = is_group and system.pathname == ''
        is_semi = is_group and not is_total
        use_full_cols = is_semi or isinstance(system, ImplicitComponent)

        self._colored_approx_groups = []

        # don't do anything if the coloring doesn't exist yet
        coloring = system._coloring_info['coloring']
        if not isinstance(coloring, coloring_mod.Coloring):
            return

        system._update_wrt_matches(system._coloring_info)
        wrt_matches = system._coloring_info['wrt_matches']
        out_slices = system._outputs.get_slice_dict()

        # ccol2jcol: colored jac column -> full jac column.  Only needed when just a
        # subset of the wrt variables is colored (wrt_matches is not None).
        # ccol2vcol: colored jac column -> index into the outputs vector.  It is
        # indexed in the final loop whenever is_total is True, so it must be built
        # for totals even when wrt_matches is None (all wrt vars colored).
        ncolored_cols = coloring._shape[1]
        ccol2jcol = None
        ccol2vcol = None
        if wrt_matches is not None:
            ccol2jcol = np.empty(ncolored_cols, dtype=INT_DTYPE)
        if is_total:
            ccol2vcol = np.empty(ncolored_cols, dtype=INT_DTYPE)

        if ccol2jcol is not None or ccol2vcol is not None:
            colored_start = colored_end = 0
            for abs_wrt, cstart, cend, _, cinds in system._jac_wrt_iter():
                if wrt_matches is not None and abs_wrt not in wrt_matches:
                    continue  # this wrt var is not part of the coloring
                colored_end += cend - cstart
                if ccol2jcol is not None:
                    ccol2jcol[colored_start:colored_end] = np.arange(
                        cstart, cend, dtype=INT_DTYPE)
                # NOTE(review): for totals, colored columns of wrt vars that are not
                # in out_slices are left uninitialized (np.empty) — confirm they are
                # never consumed downstream.
                if ccol2vcol is not None and abs_wrt in out_slices:
                    slc = out_slices[abs_wrt]
                    rng = np.arange(slc.start, slc.stop)
                    if cinds is not None:
                        rng = rng[cinds]
                    ccol2vcol[colored_start:colored_end] = rng
                colored_start = colored_end

        # Map rows of the colored jac to rows of the full jac, skipping any row
        # variables that are not part of the coloring.
        row_var_sizes = dict(zip(coloring._row_vars, coloring._row_var_sizes))
        row_map = np.empty(coloring._shape[0], dtype=INT_DTYPE)
        abs2prom = system._var_allprocs_abs2prom['output']

        if is_total:
            row_vars = [(of, end - start)
                        for of, start, end, _ in system._jac_of_iter()]
        else:
            row_vars = [(n, arr.size) for n, arr in system._outputs._abs_item_iter()]

        start = end = colorstart = colorend = 0
        for name, sz in row_vars:
            end += sz
            # coloring row vars are absolute 'of' names for totals, promoted
            # output names otherwise
            prom = name if is_total else abs2prom[name]
            if prom in row_var_sizes:
                colorend += row_var_sizes[prom]
                row_map[colorstart:colorend] = np.arange(start, end, dtype=INT_DTYPE)
                colorstart = colorend
            start = end

        # data is the same for all colored approxs so we only need the first.
        # Default to None so an empty match set does not leave `data` unbound.
        data = None
        for wrt, meta in self._wrt_meta.items():
            if wrt_matches is None or wrt in wrt_matches:
                data = self._get_approx_data(system, wrt, meta)
                break

        outputs = system._outputs
        inputs = system._inputs

        for cols, nzrows in coloring.color_nonzero_iter('fwd'):
            nzrows = [row_map[r] for r in nzrows]
            jaccols = cols if ccol2jcol is None else ccol2jcol[cols]
            vcols = ccol2vcol[cols] if is_total else jaccols
            vec_ind_list = get_input_idx_split(vcols, inputs, outputs,
                                               use_full_cols, is_total)
            self._colored_approx_groups.append(
                (data, jaccols, vec_ind_list, nzrows))