예제 #1
0
    def global_svd_dot(self, jac_v, minus_jtf):
        """
        Compute `dot(jac_v.T, minus_jtf)` and return the fully-reduced result on every processor.

        Each processor contracts its local piece of `jac_v` against its local piece of
        `minus_jtf`.  The partial products are summed with `allreduce_sum` into a shared
        buffer, using the layout's 'param-fine' resource allocation as the unit allocation.
        The buffer is then copied into a private array before it is freed.

        This is typically used within SVD-defined basis calculations, where `jac_v` is the "V"
        matrix of the SVD of a jacobian, and `minus_jtf` is the negative dot product between the
        Jacobian matrix and objective function vector.

        Parameters
        ----------
        jac_v : numpy.ndarray or LocalNumpyArray
            An array of `jtj`-type.

        minus_jtf : numpy.ndarray or LocalNumpyArray
            An array of `jtf`-type.

        Returns
        -------
        numpy.ndarray
            The global (gathered) parameter vector `dot(jac_v.T, minus_jtf)`, of length
            `jac_v.shape[1]`, as a new (non-shared) array.
        """
        # jac_v.T is (nP, nP_fine) and minus_jtf is (nP_fine,), so this is this proc's (nP,) share
        partial_product = _np.dot(jac_v.T, minus_jtf)
        ralloc = self.resource_alloc
        num_out = jac_v.shape[1]

        # FUTURE: accepting a caller-provided shared destination array would avoid this allocation.
        summed, summed_shm = _smt.create_shared_ndarray(ralloc, (num_out,), 'd')
        fine_unit = self.layout.resource_alloc('param-fine')
        ralloc.allreduce_sum(summed, partial_product, unit_ralloc=fine_unit)

        gathered = summed.copy()  # private copy; the shared buffer is released below
        ralloc.host_comm_barrier()  # don't free the shared buffer before the host's procs are done with it
        _smt.cleanup_shared_ndarray(summed_shm)
        return gathered
예제 #2
0
    def _mapfill_hprobs_atom(self, array_to_fill, dest_indices,
                             dest_param_indices1, dest_param_indices2,
                             layout_atom, param_indices1, param_indices2,
                             resource_alloc, eps):
        """
        Helper function for populating hessian values by block.

        The Hessian block is obtained by forward finite differences of first derivatives.
        The unperturbed derivatives w.r.t. `param_indices2` are computed once.  Then, for
        each parameter in `param_indices1`, that parameter is increased by `eps`, the
        derivatives are recomputed, and `(dprobs_perturbed - dprobs) / eps` is written into
        `array_to_fill[dest_indices, <dest index of that parameter>, dest_param_indices2]`.

        The model's original parameter vector is always restored, and the scratch
        shared-memory buffers are always freed, even if an evaluation raises.

        Parameters
        ----------
        array_to_fill : array
            Destination array (presumably a numpy.ndarray, possibly in shared memory).  It is
            written only by the host leader of `resource_alloc`, or always when
            `resource_alloc` is None.

        dest_indices : slice or array-like
            Indices along the first axis of `array_to_fill` to write.

        dest_param_indices1 : slice, array-like, or None
            Destination indices along the second axis, parallel to `param_indices1`.
            `None` means `0 .. len(param_indices1)-1`.

        dest_param_indices2 : slice, array-like, or None
            Destination indices along the third axis, parallel to `param_indices2`.
            `None` means `0 .. len(param_indices2)-1`.

        layout_atom : layout atom
            The atom whose `num_elements` outcome probabilities are differentiated.

        param_indices1 : slice, array-like, or None
            Model parameters that are perturbed.  `None` means all model parameters.

        param_indices2 : slice, array-like, or None
            Model parameters that the inner first derivatives are taken with respect to.
            `None` means all model parameters.

        resource_alloc : ResourceAllocation or None
            Used to allocate the scratch arrays and passed to the inner derivative calls.

        eps : float
            Finite-difference step size.

        Returns
        -------
        None
        """
        shared_mem_leader = resource_alloc.is_host_leader if (
            resource_alloc is not None) else True

        if param_indices1 is None:
            param_indices1 = list(range(self.model.num_params))
        if param_indices2 is None:
            param_indices2 = list(range(self.model.num_params))
        if dest_param_indices1 is None:
            dest_param_indices1 = list(range(_slct.length(param_indices1)))
        if dest_param_indices2 is None:
            dest_param_indices2 = list(range(_slct.length(param_indices2)))

        param_indices1 = _slct.to_array(param_indices1)
        dest_param_indices1 = _slct.to_array(dest_param_indices1)
        # dest_param_indices2 is intentionally left as-is (it is OK for it to be a slice)

        # Map from global parameter index to the final index along axis 1 of array_to_fill
        iParamToFinal = dict(zip(param_indices1, dest_param_indices1))

        nEls = layout_atom.num_elements
        nP2 = _slct.length(param_indices2) if isinstance(
            param_indices2, slice) else len(param_indices2)

        shm = shm2 = None  # cleanup_shared_ndarray accepts None
        orig_vec = None
        try:
            dprobs, shm = _smt.create_shared_ndarray(resource_alloc, (nEls, nP2), 'd')
            dprobs2, shm2 = _smt.create_shared_ndarray(resource_alloc, (nEls, nP2), 'd')

            # Unperturbed first derivatives
            self.calclib.mapfill_dprobs_atom(self, dprobs, slice(0, nEls), None,
                                             layout_atom, param_indices2,
                                             resource_alloc, eps)

            orig_vec = self.model.to_vector().copy()
            for i in range(self.model.num_params):
                if i not in iParamToFinal:
                    continue
                iFinal = iParamToFinal[i]
                vec = orig_vec.copy()
                vec[i] += eps
                self.model.from_vector(vec, close=True)
                self.calclib.mapfill_dprobs_atom(self, dprobs2, slice(0, nEls),
                                                 None, layout_atom,
                                                 param_indices2,
                                                 resource_alloc, eps)
                if shared_mem_leader:  # presumably so only one proc per host writes shared memory
                    _fas(array_to_fill,
                         [dest_indices, iFinal, dest_param_indices2],
                         (dprobs2 - dprobs) / eps)
        finally:
            # Never leave the model at a perturbed point, and never leak the scratch buffers.
            if orig_vec is not None:
                self.model.from_vector(orig_vec)
            _smt.cleanup_shared_ndarray(shm)
            _smt.cleanup_shared_ndarray(shm2)
예제 #3
0
    def deallocate_jtj_shared_mem_buf(self, jtj_buf):
        """
        Frees the scratch memory allocated by :method:`allocate_jtj_shared_mem_buf`.

        Parameters
        ----------
        jtj_buf : tuple or None
            The value returned from :method:`allocate_jtj_shared_mem_buf`, i.e. a
            `(buffer, shared_memory_handle)` pair.  `None` is accepted and ignored.

        Returns
        -------
        None
        """
        if jtj_buf is None:
            # `None` is a documented input; there is no shared-memory handle to free.
            return
        _, buf_shm = jtj_buf
        _smt.cleanup_shared_ndarray(buf_shm)
예제 #4
0
    def bcast(self, value, root=0):
        """
        Broadcasts a value from the root processor/host to the others in this resource allocation.

        This is similar to a usual MPI broadcast, except it takes advantage of shared memory when
        it is available.  When this :class:`ResourceAllocation` has a `host_comm` (shared memory
        is in use), the following happens:

        - The shape and dtype of `value` are broadcast over `comm`.
        - The `root` processor copies `value` into a host-shared buffer.
        - After a host-level barrier, every processor calls `bcast` on its inter-host
          communicator, rooted at the root processor's host.

        This way all the processors on the root host take part in parallel (processors with the
        same intra-host rank form one MPI broadcast).

        Without a `host_comm` this is a plain `comm.bcast`, and without any `comm` the value
        is returned unchanged.

        Parameters
        ----------
        value : numpy.ndarray
            The value to broadcast.  May be shared memory but doesn't need to be.  Only
            needs to be specified on the rank `root` processor; other processors can provide
            any value for this argument (it's unused).

        root : int
            The rank (within `comm`) of the processor whose `value` is broadcast.

        Returns
        -------
        numpy.ndarray
            The broadcast value, in a new, non-shared-memory array.
        """
        if self.host_comm is None:
            # No shared memory: ordinary broadcast, or a pass-through when there is no comm at all.
            return value if self.comm is None else self.comm.bcast(value, root=root)

        am_root = self.comm.rank == root
        shape, dtype = self.comm.bcast((value.shape, value.dtype) if am_root else None, root=root)

        #FUTURE: check whether `value` is already shared memory (or add flag?) and if so don't allocate/free
        from pygsti.tools import sharedmemtools as _smt
        staging, staging_shm = _smt.create_shared_ndarray(self, shape, dtype)
        if am_root:
            staging[(slice(None, None), ) * value.ndim] = value  # copy the payload into shared memory

        self.host_comm.barrier()  # root's copy must be written before the root host broadcasts it
        source_host = self.host_index_for_rank[root]  # host index == rank within interhost_comm
        received = self.interhost_comm.bcast(staging, root=source_host)
        self.comm.barrier()  # everyone must be finished with `staging` before it is freed
        _smt.cleanup_shared_ndarray(staging_shm)
        return received
def mapfill_dprobs_atom(fwdsim, mx_to_fill, dest_indices, dest_param_indices,
                        layout_atom, param_indices, resource_alloc, eps):
    """
    Fill `mx_to_fill` with finite-difference derivatives of one layout atom's probabilities.

    For each model parameter in `param_indices`:

    - the parameter is increased by `eps`;
    - the atom's probabilities are recomputed with `mapfill_probs_atom`;
    - `(probs_perturbed - probs) / eps` is written to
      `mx_to_fill[dest_indices, <dest index of that parameter>]`.

    The model's original parameter vector is always restored, and the scratch
    shared-memory buffers are always freed, even if an evaluation raises.

    Parameters
    ----------
    fwdsim : forward simulator
        Supplies `fwdsim.model`, whose parameter vector is temporarily perturbed.

    mx_to_fill : array
        Destination array (presumably a numpy.ndarray).

    dest_indices : slice or array-like
        Indices along the first axis of `mx_to_fill` to write.

    dest_param_indices : slice, array-like, or None
        Destination indices along the second axis, parallel to `param_indices`.
        `None` means `0 .. len(param_indices)-1`.

    layout_atom : layout atom
        The atom whose `num_elements` probabilities are differentiated.

    param_indices : slice, array-like, or None
        Model parameters to differentiate with respect to.  `None` means all parameters.

    resource_alloc : ResourceAllocation or None
        Used for the scratch allocations and passed to `mapfill_probs_atom`.

    eps : float
        Finite-difference step size.

    Returns
    -------
    None
    """
    if param_indices is None:
        param_indices = list(range(fwdsim.model.num_params))
    if dest_param_indices is None:
        dest_param_indices = list(range(_slct.length(param_indices)))

    param_indices = _slct.to_array(param_indices)
    dest_param_indices = _slct.to_array(dest_param_indices)

    # Map from global parameter index to the final index along axis 1 of mx_to_fill
    iParamToFinal = dict(zip(param_indices, dest_param_indices))

    orig_vec = fwdsim.model.to_vector().copy()
    fwdsim.model.from_vector(
        orig_vec, close=False)  # ensure we call with close=False first

    # Note: no real need for using shared memory here except so that we can pass
    # `resource_alloc` to mapfill_probs_atom and have it potentially use multiple procs.
    nEls = layout_atom.num_elements
    shm = shm2 = None  # cleanup_shared_ndarray accepts None
    try:
        probs, shm = _smt.create_shared_ndarray(resource_alloc, (nEls, ),
                                                'd',
                                                memory_tracker=None)
        probs2, shm2 = _smt.create_shared_ndarray(resource_alloc, (nEls, ),
                                                  'd',
                                                  memory_tracker=None)
        mapfill_probs_atom(fwdsim, probs, slice(0, nEls), layout_atom,
                           resource_alloc)  # unperturbed probabilities

        for i in range(fwdsim.model.num_params):
            if i not in iParamToFinal:
                continue
            iFinal = iParamToFinal[i]
            vec = orig_vec.copy()
            vec[i] += eps
            fwdsim.model.from_vector(vec, close=True)
            mapfill_probs_atom(fwdsim, probs2, slice(0, nEls), layout_atom,
                               resource_alloc)
            _fas(mx_to_fill, [dest_indices, iFinal], (probs2 - probs) / eps)
    finally:
        # Never leave the model at a perturbed point, and never leak the scratch buffers.
        fwdsim.model.from_vector(orig_vec, close=True)
        _smt.cleanup_shared_ndarray(shm)
        _smt.cleanup_shared_ndarray(shm2)
예제 #6
0
    def norm2_jtj(self, jtj):
        """
        Compute the squared Frobenius norm of a `jtj`-type matrix across all processors.

        Each processor squares the Frobenius norm of its local piece.  The pieces are summed
        with `allreduce_sum`, using the layout's 'param-fine' resource allocation as the unit
        allocation.

        Parameters
        ----------
        jtj : numpy.ndarray or LocalNumpyArray
            The array to operate on.

        Returns
        -------
        float
        """
        sq_norm = _np.array(_np.linalg.norm(jtj) ** 2).reshape(1)  # 1-element buffer for allreduce_sum
        ralloc = self.resource_alloc
        total, total_shm = _smt.create_shared_ndarray(ralloc, (1,), 'd')
        ralloc.allreduce_sum(total, sq_norm, unit_ralloc=self.layout.resource_alloc('param-fine'))
        value = total[0]  # indexing copies the scalar out of the shared buffer
        ralloc.host_comm_barrier()  # don't free the shared buffer too early
        _smt.cleanup_shared_ndarray(total_shm)
        return value
예제 #7
0
    def norm2_f(self, f):
        """
        Compute the squared (Frobenius/Euclidean) norm of an `f`-type vector across all processors.

        Each processor computes `dot(f, f)` over its local piece.  The pieces are summed with
        `allreduce_sum`, using the layout's 'atom-processing' resource allocation as the unit
        allocation.

        Parameters
        ----------
        f : numpy.ndarray or LocalNumpyArray
            The vector to operate on.

        Returns
        -------
        float
        """
        self_dot = _np.array(_np.dot(f, f)).reshape(1)  # 1-element buffer for allreduce_sum
        ralloc = self.resource_alloc
        total, total_shm = _smt.create_shared_ndarray(ralloc, (1,), 'd')
        ralloc.allreduce_sum(total, self_dot, unit_ralloc=self.layout.resource_alloc('atom-processing'))
        value = total[0]  # indexing copies the scalar out of the shared buffer
        ralloc.host_comm_barrier()  # don't free the shared buffer too early
        _smt.cleanup_shared_ndarray(total_shm)
        return value
예제 #8
0
    def dot_x(self, x1, x2):
        """
        Take the dot product of two `x`-type vectors across all processors.

        The dot product of the local pieces is summed with `allreduce_sum`, using the
        layout's 'param-fine' resource allocation as the unit allocation.

        Parameters
        ----------
        x1, x2 : numpy.ndarray or LocalNumpyArray
            The vectors to operate on.

        Returns
        -------
        float
        """
        # The x vectors are assumed to be in the "fine" parameter distribution.
        contribution = _np.array(_np.dot(x1, x2)).reshape(1)  # 1-element buffer for allreduce_sum
        ralloc = self.resource_alloc
        acc, acc_shm = _smt.create_shared_ndarray(ralloc, (1,), 'd')
        ralloc.allreduce_sum(acc, contribution, unit_ralloc=self.layout.resource_alloc('param-fine'))
        dot_value = acc[0]  # indexing copies the scalar out of the shared buffer
        ralloc.host_comm_barrier()  # don't free the shared buffer too early
        _smt.cleanup_shared_ndarray(acc_shm)
        return dot_value
예제 #9
0
    def max_x(self, x):
        """
        Compute the maximum of an `x`-type vector across all processors.

        Each processor takes the maximum of its local piece, or `-inf` if its local piece is
        empty.  These are combined with `allreduce_max`, using the layout's 'param-fine'
        resource allocation as the unit allocation.

        Parameters
        ----------
        x : numpy.ndarray or LocalNumpyArray
            The vector to operate on.

        Returns
        -------
        float
            The global maximum (`-inf` if no processor holds any elements).
        """
        # assumes x's are in "fine" mode
        # `_np.max` raises on an empty array, and this processor's local piece of `x` may be
        # empty, so contribute -inf (the identity element of max) in that case.
        local_max = _np.array(_np.max(x) if _np.size(x) > 0 else -_np.inf)
        local_max.shape = (1,)  # for compatibility with allreduce_max
        result, result_shm = _smt.create_shared_ndarray(self.resource_alloc, (1,), 'd')
        self.resource_alloc.allreduce_max(result, local_max,
                                          unit_ralloc=self.layout.resource_alloc('param-fine'))
        ret = result[0]  # "copies" the single returned element
        self.resource_alloc.host_comm_barrier()  # make sure we don't cleanup too quickly
        _smt.cleanup_shared_ndarray(result_shm)
        return ret
예제 #10
0
    def _iter_hprobs_by_rectangle(self, layout, wrt_slices_list, return_dprobs_12):
        """
        Yield Hessian blocks covering all of `layout`'s atoms, one rectangle at a time.

        Just needed for compatibility - so base `iter_hprobs_by_rectangle` knows to loop over atoms.
        This is similar to `_iter_atom_hprobs_by_rectangle`, but it runs over all atoms before
        yielding.  The yielded array's leading dimension is therefore `layout.num_elements` (all
        local elements) rather than a single atom's number of elements.

        Parameters
        ----------
        layout : layout
            Supplies `num_elements`, `atoms`, and `resource_alloc()`.

        wrt_slices_list : iterable of pairs
            `(wrtSlice1, wrtSlice2)` parameter selections (slice-like; measured with
            `_slct.length`) defining each Hessian rectangle.

        return_dprobs_12 : bool
            Whether to also yield the per-element outer product of the first derivatives.

        Yields
        ------
        tuple
            `(wrtSlice1, wrtSlice2, hprobs)` or, when `return_dprobs_12` is True,
            `(wrtSlice1, wrtSlice2, hprobs, dprobs12)`.  `hprobs` lives in shared memory that is
            freed when the generator advances or is closed, so a consumer that needs it later
            must copy it.
        """
        nElements = layout.num_elements
        resource_alloc = layout.resource_alloc()
        for wrtSlice1, wrtSlice2 in wrt_slices_list:
            dprobs1 = dprobs2 = None
            dprobs1_shm = dprobs2_shm = hprobs_shm = None  # cleanup_shared_ndarray accepts None
            try:
                if return_dprobs_12:
                    dprobs1, dprobs1_shm = _smt.create_shared_ndarray(resource_alloc,
                                                                      (nElements, _slct.length(wrtSlice1)),
                                                                      'd', zero_out=True)
                    dprobs2, dprobs2_shm = _smt.create_shared_ndarray(resource_alloc,
                                                                      (nElements, _slct.length(wrtSlice2)),
                                                                      'd', zero_out=True)

                hprobs, hprobs_shm = _smt.create_shared_ndarray(
                    resource_alloc, (nElements, _slct.length(wrtSlice1), _slct.length(wrtSlice2)),
                    'd', zero_out=True)

                for atom in layout.atoms:
                    self._bulk_fill_hprobs_dprobs_atom(hprobs[atom.element_slice, :, :],
                                                       dprobs1[atom.element_slice, :] if (dprobs1 is not None) else None,
                                                       dprobs2[atom.element_slice, :] if (dprobs2 is not None) else None,
                                                       atom, wrtSlice1, wrtSlice2, resource_alloc)
                #Note: we give resource_alloc as our local `resource_alloc` above because all the arrays
                # have been allocated based on just this subset of processors, unlike a call to bulk_fill_hprobs(...)
                # where the probs & dprobs are memory allocated and filled by a larger group of processors.  (the main
                # function of these args is to know which procs work together to fill the *same* values and which of
                # these are on the *same* host so that only one per host actually writes to the assumed-shared memory.

                if return_dprobs_12:
                    dprobs12 = dprobs1[:, :, None] * dprobs2[:, None, :]  # (KM,N,1) * (KM,1,N') = (KM,N,N')
                    yield wrtSlice1, wrtSlice2, hprobs, dprobs12
                else:
                    yield wrtSlice1, wrtSlice2, hprobs
            finally:
                # Runs when the generator advances, and also when the consumer stops early
                # (GeneratorExit at the yield) or an exception propagates, so shared memory isn't leaked.
                _smt.cleanup_shared_ndarray(dprobs1_shm)
                _smt.cleanup_shared_ndarray(dprobs2_shm)
                _smt.cleanup_shared_ndarray(hprobs_shm)
예제 #11
0
    def _iter_atom_hprobs_by_rectangle(self, atom, wrt_slices_list, return_dprobs_12, resource_alloc):
        """
        Yield Hessian blocks for a single layout atom, one rectangle at a time.

        Parameters
        ----------
        atom : layout atom
            The atom to compute; its `num_elements` sets the leading dimension of the yielded arrays.

        wrt_slices_list : iterable of pairs
            `(wrtSlice1, wrtSlice2)` parameter selections (slice-like; measured with
            `_slct.length`) defining each Hessian rectangle.

        return_dprobs_12 : bool
            Whether to also yield the per-element outer product of the first derivatives.

        resource_alloc : ResourceAllocation
            Used to allocate the scratch arrays and passed to `_bulk_fill_hprobs_dprobs_atom`.

        Yields
        ------
        tuple
            `(wrtSlice1, wrtSlice2, hprobs)` or, when `return_dprobs_12` is True,
            `(wrtSlice1, wrtSlice2, hprobs, dprobs12)`.  `hprobs` lives in shared memory that is
            freed when the generator advances or is closed, so a consumer that needs it later
            must copy it.
        """
        #FUTURE could make a resource_alloc.check_can_allocate_memory call here for ('epp', 'epp')?
        nElements = atom.num_elements
        for wrtSlice1, wrtSlice2 in wrt_slices_list:
            dprobs1 = dprobs2 = None
            dprobs1_shm = dprobs2_shm = hprobs_shm = None  # cleanup_shared_ndarray accepts None
            try:
                if return_dprobs_12:
                    dprobs1, dprobs1_shm = _smt.create_shared_ndarray(resource_alloc,
                                                                      (nElements, _slct.length(wrtSlice1)),
                                                                      'd', zero_out=True)
                    dprobs2, dprobs2_shm = _smt.create_shared_ndarray(resource_alloc,
                                                                      (nElements, _slct.length(wrtSlice2)),
                                                                      'd', zero_out=True)

                hprobs, hprobs_shm = _smt.create_shared_ndarray(
                    resource_alloc, (nElements, _slct.length(wrtSlice1), _slct.length(wrtSlice2)),
                    'd', zero_out=True)

                # Note: no need to index w/ [atom.element_slice,...] (compare with _iter_hprobs_by_rectangles)
                # since these arrays are already sized to this particular atom (not to all the host's atoms)
                self._bulk_fill_hprobs_dprobs_atom(hprobs, dprobs1, dprobs2, atom,
                                                   wrtSlice1, wrtSlice2, resource_alloc)
                #Note: we give resource_alloc as our local `resource_alloc` above because all the arrays
                # have been allocated based on just this subset of processors, unlike a call to bulk_fill_hprobs(...)
                # where the probs & dprobs are memory allocated and filled by a larger group of processors.  (the main
                # function of these args is to know which procs work together to fill the *same* values and which of
                # these are on the *same* host so that only one per host actually writes to the assumed-shared memory.

                if return_dprobs_12:
                    dprobs12 = dprobs1[:, :, None] * dprobs2[:, None, :]  # (KM,N,1) * (KM,1,N') = (KM,N,N')
                    yield wrtSlice1, wrtSlice2, hprobs, dprobs12
                else:
                    yield wrtSlice1, wrtSlice2, hprobs
            finally:
                # Runs when the generator advances, and also when the consumer stops early
                # (GeneratorExit at the yield) or an exception propagates, so shared memory isn't leaked.
                _smt.cleanup_shared_ndarray(dprobs1_shm)
                _smt.cleanup_shared_ndarray(dprobs2_shm)
                _smt.cleanup_shared_ndarray(hprobs_shm)
예제 #12
0
def custom_solve(a, b, x, ari, resource_alloc, proc_threshold=100):
    """
    Simple parallel Gaussian Elimination with pivoting.

    This function was built to provide a parallel alternative to
    `scipy.linalg.solve`, and can achieve faster runtimes compared
    with the serial SciPy routine when the number of available processors
    and problem size are large enough.

    The routine picks one of three paths:

    - No communicator (`resource_alloc.comm is None`), or `ari` is an undistributed arrays
      interface: `scipy.linalg.solve` is called directly.
    - Fewer than `proc_threshold` processors and fewer than 10000 columns in `a`: `a` and
      `b` are gathered to the root (rank 0) processor, solved there with
      `scipy.linalg.solve`, and the result is scattered back.
    - Otherwise, parallel elimination is used, which works as follows:

      - each processor "owns" some subset of the rows of `a` and `b`.
      - iteratively (over pivot columns), the best pivot row is found, and this row is used to
        eliminate all other elements in the current pivot column.  This procedure operates on
        the joined matrix `a|b`, and when it completes the matrix `a` is in reduced row echelon
        form (RREF).
      - back substitution (trivial because `a` is in *reduced* REF) is performed to find
        the solution `x` such that `a @ x = b`.

    Note that the SciPy paths pass `assume_a='pos'`, i.e. they treat `a` as symmetric
    positive definite.

    Parameters
    ----------
    a : LocalNumpyArray
        A 2D array with the `'jtj'` distribution, holding the rows of the `a` matrix belonging
        to the current processor.  (This belonging is dictated by the "fine" distribution in
        a distributed layout.)

    b : LocalNumpyArray
        A 1D array with the `'jtf'` distribution, holding the rows of the `b` vector belonging
        to the current processor.

    x : LocalNumpyArray
        A 1D array with the `'jtf'` distribution, holding the rows of the `x` vector belonging
        to the current processor.  This vector is filled by this function.

    ari : ArraysInterface
        An object that provides an interface for creating and manipulating data arrays.

    resource_alloc : ResourceAllocation
        Gives the resources (e.g., processors and memory) available for use.

    proc_threshold : int, optional
        When the communicator has fewer than this many processors (and `a` has fewer than
        10000 columns), this routine will simply gather `a` and `b` to a single
        (the rank 0) processor, call SciPy's serial linear solver, `scipy.linalg.solve`, and
        scatter the results back onto all the processors.

    Returns
    -------
    None
        The solution is written into `x`.  On the elimination path `a` and `b` are
        restored to their original values before returning.

    Raises
    ------
    scipy.linalg.LinAlgError
        On the gather-to-root path, if the root's solve fails, every processor raises.  The
        root re-raises its own error, which may instead be a `ValueError` (e.g. for
        non-finite input).
    """
    pivot_row_indices = []
    potential_pivot_mask = _np.ones(
        a.shape[0], dtype=bool
    )  # *local* row indices of rows not already chosen pivot rows
    all_row_indices = _np.arange(a.shape[0])
    my_row_slice = ari.jtf_param_slice()

    comm = resource_alloc.comm
    host_comm = resource_alloc.host_comm
    ok_buf = _np.empty(1, _np.int64)

    if comm is None or isinstance(ari, _UndistributedArraysInterface):
        x[:] = _scipy.linalg.solve(a, b, assume_a='pos')
        return

    #Just gather everything to one processor and compute there:
    if comm.size < proc_threshold and a.shape[1] < 10000:
        # We're not exactly sure where scipy is better, but until we speed up / change gaussian-elim
        # alg the scipy alg is much faster for small numbers of procs and so should be used unless
        # A is too large to be gathered to the root proc.
        global_a, a_shm = ari.gather_jtj(a, return_shared=True)
        global_b, b_shm = ari.gather_jtf(b, return_shared=True)
        if comm.rank == 0:
            try:
                global_x = _scipy.linalg.solve(global_a,
                                               global_b,
                                               assume_a='pos')
                ok_buf[0] = 1  # ok
            except (_scipy.linalg.LinAlgError, ValueError) as e:
                # ValueError is also caught (e.g. non-finite entries) so that the failure is
                # broadcast below rather than leaving the other procs blocked in Bcast.
                ok_buf[0] = 0  # failure!
                err = e
        else:
            global_x = None
            err = _scipy.linalg.LinAlgError(
                "Linear solver fail on root proc!")  # just in case...

        comm.Bcast(ok_buf, root=0)
        if ok_buf[0] == 0:
            _smt.cleanup_shared_ndarray(a_shm)
            _smt.cleanup_shared_ndarray(b_shm)
            raise err  # all procs must raise in sync

        ari.scatter_x(global_x, x)
        _smt.cleanup_shared_ndarray(a_shm)
        _smt.cleanup_shared_ndarray(b_shm)
        return

    if host_comm is not None:
        shared_floats, shared_floats_shm = _smt.create_shared_ndarray(
            resource_alloc, (host_comm.size, ), 'd')
        shared_ints, shared_ints_shm = _smt.create_shared_ndarray(
            resource_alloc, (max(host_comm.size, 3), ), _np.int64)
        shared_rowb, shared_rowb_shm = _smt.create_shared_ndarray(
            resource_alloc, (a.shape[1] + 1, ), 'd')

        # Scratch buffers
        host_index_buf = _np.empty((resource_alloc.interhost_comm.size, 2), _np.int64) \
            if resource_alloc.interhost_comm.rank == 0 else None
        host_val_buf = _np.empty((resource_alloc.interhost_comm.size, 1), 'd') \
            if resource_alloc.interhost_comm.rank == 0 else None
    else:
        shared_floats = shared_ints = shared_rowb = None

        host_index_buf = _np.empty(
            (comm.size, 1), _np.int64) if comm.rank == 0 else None
        host_val_buf = _np.empty(
            (comm.size, 1), 'd') if comm.rank == 0 else None

    # Idea: bring a|b into RREF then back-substitute to get x.

    # for each column, find the best "pivot" row to use to eliminate other rows.
    # (note: column pivoting is not used)
    a_orig = a.copy()  # So we can restore original values of a and b
    b_orig = b.copy()  # (they're updated as we go)

    #Scratch space
    local_pivot_rowb = _np.empty(a.shape[1] + 1, 'd')
    smbuf1 = _np.empty(1, 'd')
    smbuf2 = _np.empty(2, _np.int64)
    smbuf3 = _np.empty(3, _np.int64)

    for icol in range(a.shape[1]):

        # Step 1: find the index of the row that is the best pivot.
        # each proc looks for its best pivot (Note: it should not consider rows already pivoted on)
        potential_pivot_indices = all_row_indices[potential_pivot_mask]
        ibest_global, ibest_local, h, k = _find_pivot(
            a, b, icol, potential_pivot_indices, my_row_slice, shared_floats,
            shared_ints, resource_alloc, comm, host_comm, smbuf1, smbuf2,
            smbuf3, host_index_buf, host_val_buf)

        # Step 2: proc that owns best row (holds that row and is root of param-fine comm) broadcasts it
        pivot_row, pivot_b = _broadcast_pivot_row(
            a, b, ibest_local, h, k, shared_rowb, local_pivot_rowb,
            potential_pivot_mask, resource_alloc, comm, host_comm)

        if abs(pivot_row[icol]) < 1e-6:
            # There's no non-zero element in this column to use as a pivot - the column is all zeros.
            # By convention, we just set the corresponding x-value to zero (below) and don't need to do Step 3.
            # NOTE: it's possible that a previous pivot row could have a non-zero element in the icol-th column,
            #  and we could still get here (because we don't consider previously chosen rows as pivot-row candidates).
            #  But this is ok, since we set the corresponding x-values to 0 so the end result is effectively in RREF.
            pivot_row_indices.append(-1)
            continue

        pivot_row_indices.append(ibest_global)

        # Step 3: all procs update their rows based on the pivot row (including `b`)
        #  - need to update non-pivot rows to eliminate iCol-th entry: row -= alpha * pivot_row
        #    where alpha = row[iCol] / pivot_row[iCol]
        # (Note: don't do this when there isn't a nonzero pivot)
        ipivot_local = ibest_global - my_row_slice.start  # *local* row index of pivot row (ok if negative)
        _update_rows(a, b, icol, ipivot_local, pivot_row, pivot_b)

    # Back substitution:
    # We've accumulated a list of (global) row indices of the rows containing a nonzero
    # element in a given column and zeros in prior columns.
    pivot_row_indices = _np.array(pivot_row_indices)
    _back_substitution(a, b, x, pivot_row_indices, my_row_slice, ari,
                       resource_alloc, host_comm)

    a[:, :] = a_orig  # restore original values of a and b
    b[:] = b_orig  # so they're the same as when we were called.
    # Note: maybe we could just use array copies in the algorithm, but we may need to use the
    # real a and b because they can be shared mem (check?)

    if host_comm is not None:
        _smt.cleanup_shared_ndarray(shared_floats_shm)
        _smt.cleanup_shared_ndarray(shared_ints_shm)
        _smt.cleanup_shared_ndarray(shared_rowb_shm)
    return