Example #1
    def _mapfill_hprobs_atom(self, array_to_fill, dest_indices,
                             dest_param_indices1, dest_param_indices2,
                             layout_atom, param_indices1, param_indices2,
                             resource_alloc, eps):
        """
        Helper function for populating hessian values by block.
        """
        shared_mem_leader = resource_alloc.is_host_leader if (
            resource_alloc is not None) else True

        if param_indices1 is None:
            param_indices1 = list(range(self.model.num_params))
        if param_indices2 is None:
            param_indices2 = list(range(self.model.num_params))
        if dest_param_indices1 is None:
            dest_param_indices1 = list(range(_slct.length(param_indices1)))
        if dest_param_indices2 is None:
            dest_param_indices2 = list(range(_slct.length(param_indices2)))

        param_indices1 = _slct.to_array(param_indices1)
        dest_param_indices1 = _slct.to_array(dest_param_indices1)
        #dest_param_indices2 = _slct.to_array(dest_param_indices2)  # OK if a slice

        #Get a map from global parameter indices to the desired
        # final index within array_to_fill (fpoffset = final parameter offset)
        iParamToFinal = {
            i: dest_index
            for i, dest_index in zip(param_indices1, dest_param_indices1)
        }

        nEls = layout_atom.num_elements
        nP2 = _slct.length(param_indices2) if isinstance(
            param_indices2, slice) else len(param_indices2)
        dprobs, shm = _smt.create_shared_ndarray(resource_alloc, (nEls, nP2),
                                                 'd')
        dprobs2, shm2 = _smt.create_shared_ndarray(resource_alloc, (nEls, nP2),
                                                   'd')
        self.calclib.mapfill_dprobs_atom(self, dprobs, slice(0, nEls), None,
                                         layout_atom, param_indices2,
                                         resource_alloc, eps)

        orig_vec = self.model.to_vector().copy()
        for i in range(self.model.num_params):
            if i in iParamToFinal:
                iFinal = iParamToFinal[i]
                vec = orig_vec.copy()
                vec[i] += eps
                self.model.from_vector(vec, close=True)
                self.calclib.mapfill_dprobs_atom(self, dprobs2, slice(0, nEls),
                                                 None, layout_atom,
                                                 param_indices2,
                                                 resource_alloc, eps)
                if shared_mem_leader:
                    _fas(array_to_fill,
                         [dest_indices, iFinal, dest_param_indices2],
                         (dprobs2 - dprobs) / eps)
        self.model.from_vector(orig_vec)
        _smt.cleanup_shared_ndarray(shm)
        _smt.cleanup_shared_ndarray(shm2)
def mapfill_dprobs_atom(fwdsim, mx_to_fill, dest_indices, dest_param_indices,
                        layout_atom, param_indices, resource_alloc, eps):
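    """
    Fill `mx_to_fill` with forward-finite-difference derivatives of the
    probabilities computed by `mapfill_probs_atom`: each selected parameter
    is perturbed by `eps` and the resulting probability change is divided by `eps`.
    """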

    #eps = 1e-7
    #shared_mem_leader = resource_alloc.is_host_leader if (resource_alloc is not None) else True

    if param_indices is None:
        param_indices = list(range(fwdsim.model.num_params))
    if dest_param_indices is None:
        dest_param_indices = list(range(_slct.length(param_indices)))

    param_indices = _slct.to_array(param_indices)
    dest_param_indices = _slct.to_array(dest_param_indices)

    #Get a map from global parameter indices to the desired
    # final index within mx_to_fill (fpoffset = final parameter offset)
    iParamToFinal = {
        i: dest_index
        for i, dest_index in zip(param_indices, dest_param_indices)
    }

    orig_vec = fwdsim.model.to_vector().copy()
    fwdsim.model.from_vector(
        orig_vec, close=False)  # ensure we call with close=False first

    #Note: no real need for using shared memory here except so that we can pass
    # `resource_alloc` to mapfill_probs_block and have it potentially use multiple procs.
    nEls = layout_atom.num_elements
    probs, shm = _smt.create_shared_ndarray(resource_alloc, (nEls, ),
                                            'd',
                                            memory_tracker=None)
    probs2, shm2 = _smt.create_shared_ndarray(resource_alloc, (nEls, ),
                                              'd',
                                              memory_tracker=None)
    mapfill_probs_atom(fwdsim, probs, slice(0, nEls), layout_atom,
                       resource_alloc)  # probs != shared

    for i in range(fwdsim.model.num_params):
        #print("dprobs cache %d of %d" % (i,self.Np))
        if i in iParamToFinal:
            iFinal = iParamToFinal[i]
            vec = orig_vec.copy()
            vec[i] += eps
            fwdsim.model.from_vector(vec, close=True)
            mapfill_probs_atom(fwdsim, probs2, slice(0, nEls), layout_atom,
                               resource_alloc)
            _fas(mx_to_fill, [dest_indices, iFinal], (probs2 - probs) / eps)
    fwdsim.model.from_vector(orig_vec, close=True)
    _smt.cleanup_shared_ndarray(shm)
    _smt.cleanup_shared_ndarray(shm2)
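
Both functions above use the same forward-difference recipe: save the model's parameter vector, bump one parameter by `eps`, re-run the probability (or first-derivative) fill, and divide the difference by `eps`. Stripped of the shared-memory and layout machinery, the core pattern looks like this (a minimal sketch; `probs_fn` and `finite_diff_jacobian` are illustrative stand-ins, not pyGSTi API):

import numpy as np

def finite_diff_jacobian(probs_fn, param_vec, eps=1e-7):
    """Forward-difference Jacobian of probs_fn at param_vec."""
    base_probs = probs_fn(param_vec)          # analogue of the initial mapfill_probs_atom call
    jac = np.empty((base_probs.size, param_vec.size))
    for i in range(param_vec.size):
        perturbed = param_vec.copy()
        perturbed[i] += eps                   # bump a single parameter
        jac[:, i] = (probs_fn(perturbed) - base_probs) / eps
    return jac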
Example #3
    def global_svd_dot(self, jac_v, minus_jtf):
        """
        Gathers the dot product between a `jtj`-type matrix and a `jtf`-type vector into a global result array.

        This is typically used within SVD-defined basis calculations, where `jac_v` is the "V"
        matrix of the SVD of a Jacobian, and `minus_jtf` is the negative dot product between the
        Jacobian matrix and the objective function vector.

        Parameters
        ----------
        jac_v : numpy.ndarray or LocalNumpyArray
            An array of `jtj`-type.

        minus_jtf : numpy.ndarray or LocalNumpyArray
            An array of `jtf`-type.

        Returns
        -------
        numpy.ndarray
            The global (gathered) parameter vector `dot(jac_v.T, minus_jtf)`.
        """
        # Assumes jac_v is 'jtj' type and minus_jtf is 'jtf' type.
        # Returns a *global* parameter array that is dot(jac_v.T, minus_jtf)
        local_dot = _np.dot(jac_v.T, minus_jtf)  # (nP, nP_fine) * (nP_fine) = (nP,)

        #Note: Could make this more efficient by being given a shared array like this as the destination
        result, result_shm = _smt.create_shared_ndarray(self.resource_alloc, (jac_v.shape[1],), 'd')
        self.resource_alloc.allreduce_sum(result, local_dot,
                                          unit_ralloc=self.layout.resource_alloc('param-fine'))
        ret = result.copy()
        self.resource_alloc.host_comm_barrier()  # make sure we don't cleanup too quickly
        _smt.cleanup_shared_ndarray(result_shm)
        return ret
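
The body of `global_svd_dot` is an instance of a pattern used throughout this module: compute a local partial result, `allreduce_sum` it into a freshly allocated shared array, copy the answer out, and barrier before freeing the shared buffer. In plain mpi4py, without pyGSTi's shared-memory layer, the same reduction could be sketched as follows (the data decomposition is an assumption for illustration, not pyGSTi's actual one):

import numpy as np
from mpi4py import MPI

def global_svd_dot_plain(jac_v_local, minus_jtf_local, comm=MPI.COMM_WORLD):
    # Each rank holds a horizontal slab of jac_v and the matching entries of
    # minus_jtf, so its dot product is a partial sum of dot(jac_v.T, minus_jtf).
    local_dot = np.dot(jac_v_local.T, minus_jtf_local)
    result = np.empty_like(local_dot)
    comm.Allreduce(local_dot, result, op=MPI.SUM)  # sum the partial dots across ranks
    return result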
Example #4
    def _iter_hprobs_by_rectangle(self, layout, wrt_slices_list, return_dprobs_12):
        # Just needed for compatibility - so base `iter_hprobs_by_rectangle` knows to loop over atoms
        # Similar to _iter_atom_hprobs_by_rectangle but runs over all atoms before yielding and
        #  yielded array has leading dim == # of local elements instead of just 1 atom's # elements.
        nElements = layout.num_elements
        resource_alloc = layout.resource_alloc()
        for wrtSlice1, wrtSlice2 in wrt_slices_list:

            if return_dprobs_12:
                dprobs1, dprobs1_shm = _smt.create_shared_ndarray(resource_alloc, (nElements, _slct.length(wrtSlice1)),
                                                                  'd', zero_out=True)
                dprobs2, dprobs2_shm = _smt.create_shared_ndarray(resource_alloc, (nElements, _slct.length(wrtSlice2)),
                                                                  'd', zero_out=True)
            else:
                dprobs1 = dprobs2 = dprobs1_shm = dprobs2_shm = None

            hprobs, hprobs_shm = _smt.create_shared_ndarray(
                resource_alloc, (nElements, _slct.length(wrtSlice1), _slct.length(wrtSlice2)),
                'd', zero_out=True)

            for atom in layout.atoms:
                self._bulk_fill_hprobs_dprobs_atom(hprobs[atom.element_slice, :, :],
                                                   dprobs1[atom.element_slice, :] if (dprobs1 is not None) else None,
                                                   dprobs2[atom.element_slice, :] if (dprobs2 is not None) else None,
                                                   atom, wrtSlice1, wrtSlice2, resource_alloc)
            #Note: we give resource_alloc as our local `resource_alloc` above because all the arrays
            # have been allocated based on just this subset of processors, unlike a call to bulk_fill_hprobs(...)
            # where the probs & dprobs are memory allocated and filled by a larger group of processors.  (the main
            # function of these args is to know which procs work together to fill the *same* values and which of
            # these are on the *same* host so that only one per host actually writes to the assumed-shared memory.)

            if return_dprobs_12:
                dprobs12 = dprobs1[:, :, None] * dprobs2[:, None, :]  # (KM,N,1) * (KM,1,N') = (KM,N,N')
                yield wrtSlice1, wrtSlice2, hprobs, dprobs12
            else:
                yield wrtSlice1, wrtSlice2, hprobs

            _smt.cleanup_shared_ndarray(dprobs1_shm)
            _smt.cleanup_shared_ndarray(dprobs2_shm)
            _smt.cleanup_shared_ndarray(hprobs_shm)
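
One important consequence of the cleanup calls after each `yield`: the yielded arrays live in shared memory that is freed as soon as the loop advances, so a consumer must copy anything it wants to keep. A hypothetical usage sketch (the `fwdsim`, `layout`, and `wrt_slices_list` objects are assumed to already exist):

import numpy as np

nP = fwdsim.model.num_params
hessian = np.zeros((layout.num_elements, nP, nP))
for wrt1, wrt2, hprobs in fwdsim._iter_hprobs_by_rectangle(layout, wrt_slices_list,
                                                           return_dprobs_12=False):
    hessian[:, wrt1, wrt2] = hprobs  # copy now -- hprobs is freed on the next iteration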
Example #5
    def bcast(self, value, root=0):
        """
        Broadcasts a value from the root processor/host to the others in this resource allocation.

        This is similar to a usual MPI broadcast, except it takes advantage of shared memory when
        it is available.  When shared memory is being used, i.e. when this :class:`ResourceAllocation`
        object has a nontrivial inter-host comm, then this routine places `value` in a shared memory
        buffer and uses the resource allocation's inter-host communicator to broadcast the result
        from the root *host* to all the other hosts, using all the processors on the root host in
        parallel (all processors with the same intra-host rank participate in an MPI broadcast).

        Parameters
        ----------
        value : numpy.ndarray
            The value to broadcast.  May be a shared-memory array but doesn't need to be.
            Only the rank-`root` processor needs to supply a meaningful value; other
            processors may pass any value for this argument (it is unused there).

        root : int
            The rank of the processor whose `value` will be broadcast.

        Returns
        -------
        numpy.ndarray
            The broadcast value, in a new, non-shared-memory array.
        """
        if self.host_comm is not None:
            bcast_shape, bcast_dtype = self.comm.bcast(
                (value.shape, value.dtype) if self.comm.rank == root else None,
                root=root)
            #FUTURE: check whether `value` is already shared memory  (or add flag?) and if so don't allocate/free `ar`
            from pygsti.tools import sharedmemtools as _smt
            ar, ar_shm = _smt.create_shared_ndarray(self, bcast_shape,
                                                    bcast_dtype)
            if self.comm.rank == root:
                ar[(slice(None),) * value.ndim] = value  # put our value into the shared memory.

            self.host_comm.barrier()  # wait until shared mem is written to on all root-host procs
            interhost_root = self.host_index_for_rank[root]  # (b/c host_index == interhost.rank)
            ret = self.interhost_comm.bcast(ar, root=interhost_root)
            self.comm.barrier()  # wait until everyone's values are ready
            _smt.cleanup_shared_ndarray(ar_shm)
            return ret
        elif self.comm is not None:
            return self.comm.bcast(value, root=root)
        else:
            return value
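
A sketch of how `bcast` might be invoked (hypothetical setup; `ResourceAllocation` wraps an MPI communicator, and if intra-host shared-memory comms haven't been set up, `host_comm` is None and the routine falls back to a plain `comm.bcast` as shown in the code above):

import numpy as np
from mpi4py import MPI
from pygsti.baseobjs.resourceallocation import ResourceAllocation

ralloc = ResourceAllocation(MPI.COMM_WORLD)
value = np.arange(6.0).reshape(2, 3) if ralloc.comm.rank == 0 else None  # only root's value matters
value = ralloc.bcast(value, root=0)  # every rank now owns a private (non-shared) copy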
Example #6
    def _iter_atom_hprobs_by_rectangle(self, atom, wrt_slices_list, return_dprobs_12, resource_alloc):

        #FUTURE could make a resource_alloc.check_can_allocate_memory call here for ('epp', 'epp')?
        nElements = atom.num_elements
        for wrtSlice1, wrtSlice2 in wrt_slices_list:

            if return_dprobs_12:
                dprobs1, dprobs1_shm = _smt.create_shared_ndarray(resource_alloc, (nElements, _slct.length(wrtSlice1)),
                                                                  'd', zero_out=True)
                dprobs2, dprobs2_shm = _smt.create_shared_ndarray(resource_alloc, (nElements, _slct.length(wrtSlice2)),
                                                                  'd', zero_out=True)
            else:
                dprobs1 = dprobs2 = dprobs1_shm = dprobs2_shm = None

            hprobs, hprobs_shm = _smt.create_shared_ndarray(
                resource_alloc, (nElements, _slct.length(wrtSlice1), _slct.length(wrtSlice2)),
                'd', zero_out=True)

            # Note: no need to index w/ [atom.element_slice,...] (compare with _iter_hprobs_by_rectangles)
            # since these arrays are already sized to this particular atom (not to all the host's atoms)
            self._bulk_fill_hprobs_dprobs_atom(hprobs, dprobs1, dprobs2, atom,
                                               wrtSlice1, wrtSlice2, resource_alloc)
            #Note: we give resource_alloc as our local `resource_alloc` above because all the arrays
            # have been allocated based on just this subset of processors, unlike a call to bulk_fill_hprobs(...)
            # where the probs & dprobs are memory allocated and filled by a larger group of processors.  (the main
            # function of these args is to know which procs work together to fill the *same* values and which of
            # these are on the *same* host so that only one per host actually writes to the assumed-shared memory.)

            if return_dprobs_12:
                dprobs12 = dprobs1[:, :, None] * dprobs2[:, None, :]  # (KM,N,1) * (KM,1,N') = (KM,N,N')
                yield wrtSlice1, wrtSlice2, hprobs, dprobs12
            else:
                yield wrtSlice1, wrtSlice2, hprobs

            _smt.cleanup_shared_ndarray(dprobs1_shm)
            _smt.cleanup_shared_ndarray(dprobs2_shm)
            _smt.cleanup_shared_ndarray(hprobs_shm)
Example #7
    def norm2_jtj(self, jtj):
        """
        Compute the Frobenius norm squared of a `jtj`-type matrix.

        Parameters
        ----------
        jtj : numpy.ndarray or LocalNumpyArray
            The array to operate on.

        Returns
        -------
        float
        """
        local_norm2 = _np.array(_np.linalg.norm(jtj)**2)
        local_norm2.shape = (1,)  # for compatibility with allreduce_sum
        result, result_shm = _smt.create_shared_ndarray(self.resource_alloc, (1,), 'd')
        self.resource_alloc.allreduce_sum(result, local_norm2,
                                          unit_ralloc=self.layout.resource_alloc('param-fine'))
        ret = result[0]  # "copies" the single returned element
        self.resource_alloc.host_comm_barrier()  # make sure we don't cleanup too quickly
        _smt.cleanup_shared_ndarray(result_shm)
        return ret
Example #8
    def norm2_f(self, f):
        """
        Compute the Frobenius norm squared of an `f`-type vector.

        Parameters
        ----------
        f : numpy.ndarray or LocalNumpyArray
            The vector to operate on.

        Returns
        -------
        float
        """
        local_dot = _np.array(_np.dot(f, f))
        local_dot.shape = (1,)  # for compatibility with allreduce_sum
        result, result_shm = _smt.create_shared_ndarray(self.resource_alloc, (1,), 'd')
        self.resource_alloc.allreduce_sum(result, local_dot,
                                          unit_ralloc=self.layout.resource_alloc('atom-processing'))
        ret = result[0]  # "copies" the single returned element
        self.resource_alloc.host_comm_barrier()  # make sure we don't cleanup too quickly
        _smt.cleanup_shared_ndarray(result_shm)
        return ret
Example #9
    def dot_x(self, x1, x2):
        """
        Take the dot product of two `x`-type vectors.

        Parameters
        ----------
        x1, x2 : numpy.ndarray or LocalNumpyArray
            The vectors to operate on.

        Returns
        -------
        float
        """
        # assumes x's are in "fine" mode
        local_dot = _np.array(_np.dot(x1, x2))
        local_dot.shape = (1,)  # for compatibility with allreduce_sum
        result, result_shm = _smt.create_shared_ndarray(self.resource_alloc, (1,), 'd')
        self.resource_alloc.allreduce_sum(result, local_dot,
                                          unit_ralloc=self.layout.resource_alloc('param-fine'))
        ret = result[0]  # "copies" the single returned element
        self.resource_alloc.host_comm_barrier()  # make sure we don't cleanup too quickly
        _smt.cleanup_shared_ndarray(result_shm)
        return ret
Example #10
    def max_x(self, x):
        """
        Compute the maximum of an `x`-type vector.

        Parameters
        ----------
        x : numpy.ndarray or LocalNumpyArray
            The vector to operate on.

        Returns
        -------
        float
        """
        # assumes x's are in "fine" mode
        local_max = _np.array(_np.max(x))
        local_max.shape = (1,)  # for compatibility with allreduce_sum
        result, result_shm = _smt.create_shared_ndarray(self.resource_alloc, (1,), 'd')
        self.resource_alloc.allreduce_max(result, local_max,
                                          unit_ralloc=self.layout.resource_alloc('param-fine'))
        ret = result[0]  # "copies" the single returned element
        self.resource_alloc.host_comm_barrier()  # make sure we don't cleanup too quickly
        _smt.cleanup_shared_ndarray(result_shm)
        return ret
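
Examples #7 through #10 (and `global_svd_dot` above) all wrap one idiom: form a length-1 local array, reduce it into a one-element shared array, copy the scalar out, and barrier before cleanup. A condensed sketch of that idiom as a hypothetical helper (not actual pyGSTi API, though every `_smt` and `resource_alloc` call it uses appears in the examples above):

def _reduce_scalar(local_value, resource_alloc, unit_ralloc, op='sum'):
    local = _np.array([local_value], 'd')  # length-1, as allreduce_sum/max expect
    result, result_shm = _smt.create_shared_ndarray(resource_alloc, (1,), 'd')
    if op == 'sum':
        resource_alloc.allreduce_sum(result, local, unit_ralloc=unit_ralloc)
    else:
        resource_alloc.allreduce_max(result, local, unit_ralloc=unit_ralloc)
    ret = result[0]  # copy the scalar out of shared memory
    resource_alloc.host_comm_barrier()  # make sure we don't clean up too quickly
    _smt.cleanup_shared_ndarray(result_shm)
    return ret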
Example #11
def custom_solve(a, b, x, ari, resource_alloc, proc_threshold=100):
    """
    Simple parallel Gaussian Elimination with pivoting.

    This function was built to provide a parallel alternative to
    `scipy.linalg.solve`, and can achieve faster runtimes compared
    with the serial SciPy routine when the number of available processors
    and problem size are large enough.

    When the number of processors is greater than `proc_threshold` (below
    this number the routine just calls `scipy.linalg.solve` on the root
    processor) the method works as follows:

    - each processor "owns" some subset of the rows of `a` and `b`.
    - iteratively (over pivot columns), the best pivot row is found, and this row is used to
      eliminate all other elements in the current pivot column.  This procedure operates on
      the augmented matrix `a|b`, and when it completes the matrix `a` is in reduced row echelon
      form (RREF).
    - back substitution (trivial because `a` is in *reduced* REF) is performed to find
      the solution `x` such that `a @ x = b`.

    Parameters
    ----------
    a : LocalNumpyArray
        A 2D array with the `'jtj'` distribution, holding the rows of the `a` matrix belonging
        to the current processor.  (This belonging is dictated by the "fine" distribution in
        a distributed layout.)

    b : LocalNumpyArray
        A 1D array with the `'jtf'` distribution, holding the rows of the `b` vector belonging
        to the current processor.

    x : LocalNumpyArray
        A 1D array with the `'jtf'` distribution, holding the rows of the `x` vector belonging
        to the current processor.  This vector is filled by this function.

    ari : ArraysInterface
        An object that provides an interface for creating and manipulating data arrays.

    resource_alloc : ResourceAllocation
        Gives the resources (e.g., processors and memory) available for use.

    proc_threshold : int, optional
        Below this number of processors this routine will simply gather `a` and `b` to a single
        (the rank 0) processor, call SciPy's serial linear solver, `scipy.linalg.solve`, and
        scatter the results back onto all the processors.

    Returns
    -------
    None
    """

    #DEBUG
    #for i in range(a.shape[1]):
    #    print(i, " = ", _np.linalg.norm(a[:,i]))
    #assert(False), "STOP"

    pivot_row_indices = []
    #potential_pivot_indices = list(range(a.shape[0]))  # *local* row indices of rows not already chosen as pivot rows
    potential_pivot_mask = _np.ones(a.shape[0], dtype=bool)  # masks *local* rows not already chosen as pivot rows
    all_row_indices = _np.arange(a.shape[0])
    my_row_slice = ari.jtf_param_slice()

    comm = resource_alloc.comm
    host_comm = resource_alloc.host_comm
    ok_buf = _np.empty(1, _np.int64)

    if comm is None or isinstance(ari, _UndistributedArraysInterface):
        x[:] = _scipy.linalg.solve(a, b, assume_a='pos')
        return

    #Just gather everything to one processor and compute there:
    if comm.size < proc_threshold and a.shape[1] < 10000:
        # We're not exactly sure where scipy is better, but until we speed up / change gaussian-elim
        # alg the scipy alg is much faster for small numbers of procs and so should be used unless
        # A is too large to be gathered to the root proc.
        global_a, a_shm = ari.gather_jtj(a, return_shared=True)
        global_b, b_shm = ari.gather_jtf(b, return_shared=True)
        #global_a = ari.gather_jtj(a)
        #global_b = ari.gather_jtf(b)
        if comm.rank == 0:
            try:
                global_x = _scipy.linalg.solve(global_a,
                                               global_b,
                                               assume_a='pos')
                ok_buf[0] = 1  # ok
            except _scipy.linalg.LinAlgError as e:
                ok_buf[0] = 0  # failure!
                err = e
        else:
            global_x = None
            err = _scipy.linalg.LinAlgError(
                "Linear solver fail on root proc!")  # just in case...

        comm.Bcast(ok_buf, root=0)
        if ok_buf[0] == 0:
            _smt.cleanup_shared_ndarray(a_shm)
            _smt.cleanup_shared_ndarray(b_shm)
            raise err  # all procs must raise in sync

        ari.scatter_x(global_x, x)
        _smt.cleanup_shared_ndarray(a_shm)
        _smt.cleanup_shared_ndarray(b_shm)
        return

    if host_comm is not None:
        shared_floats, shared_floats_shm = _smt.create_shared_ndarray(
            resource_alloc, (host_comm.size, ), 'd')
        shared_ints, shared_ints_shm = _smt.create_shared_ndarray(
            resource_alloc, (max(host_comm.size, 3), ), _np.int64)
        shared_rowb, shared_rowb_shm = _smt.create_shared_ndarray(
            resource_alloc, (a.shape[1] + 1, ), 'd')

        # Scratch buffers
        host_index_buf = _np.empty((resource_alloc.interhost_comm.size, 2), _np.int64) \
            if resource_alloc.interhost_comm.rank == 0 else None
        host_val_buf = _np.empty((resource_alloc.interhost_comm.size, 1), 'd') \
            if resource_alloc.interhost_comm.rank == 0 else None
    else:
        shared_floats = shared_ints = shared_rowb = None

        host_index_buf = _np.empty(
            (comm.size, 1), _np.int64) if comm.rank == 0 else None
        host_val_buf = _np.empty(
            (comm.size, 1), 'd') if comm.rank == 0 else None

    # Idea: bring a|b into RREF then back-substitute to get x.

    # for each column, find the best "pivot" row to use to eliminate other rows.
    # (note: column pivoting is not used)
    a_orig = a.copy()  # So we can restore original values of a and b
    b_orig = b.copy()  # (they're updated as we go)

    #Scratch space
    local_pivot_rowb = _np.empty(a.shape[1] + 1, 'd')
    smbuf1 = _np.empty(1, 'd')
    smbuf2 = _np.empty(2, _np.int64)
    smbuf3 = _np.empty(3, _np.int64)

    for icol in range(a.shape[1]):

        # Step 1: find the index of the row that is the best pivot.
        # each proc looks for its best pivot (Note: it should not consider rows already pivoted on)
        potential_pivot_indices = all_row_indices[potential_pivot_mask]
        ibest_global, ibest_local, h, k = _find_pivot(
            a, b, icol, potential_pivot_indices, my_row_slice, shared_floats,
            shared_ints, resource_alloc, comm, host_comm, smbuf1, smbuf2,
            smbuf3, host_index_buf, host_val_buf)

        # Step 2: proc that owns best row (holds that row and is root of param-fine comm) broadcasts it
        pivot_row, pivot_b = _broadcast_pivot_row(
            a, b, ibest_local, h, k, shared_rowb, local_pivot_rowb,
            potential_pivot_mask, resource_alloc, comm, host_comm)

        if abs(pivot_row[icol]) < 1e-6:
            # There's no non-zero element in this column to use as a pivot - the column is all zeros.
            # By convention, we just set the corresponding x-value to zero (below) and don't need to do Step 3.
            # NOTE: it's possible that a previous pivot row could have a non-zero element in the icol-th column,
            #  and we could still get here (because we don't consider previously chosen rows as pivot-row candidates).
            #  But this is ok, since we set the corresponding x-values to 0 so the end result is effectively in RREF.
            pivot_row_indices.append(-1)
            continue

        pivot_row_indices.append(ibest_global)

        # Step 3: all procs update their rows based on the pivot row (including `b`)
        #  - need to update non-pivot rows to eliminate iCol-th entry: row -= alpha * pivot_row
        #    where alpha = row[iCol] / pivot_row[iCol]
        # (Note: don't do this when there isn't a nonzero pivot)
        ipivot_local = ibest_global - my_row_slice.start  # *local* row index of pivot row (ok if negative)
        _update_rows(a, b, icol, ipivot_local, pivot_row, pivot_b)

    # Back substitution:
    # We've accumulated a list of (global) row indices of the rows containing a nonzero
    # element in a given column and zeros in prior columns.
    pivot_row_indices = _np.array(pivot_row_indices)
    _back_substitution(a, b, x, pivot_row_indices, my_row_slice, ari,
                       resource_alloc, host_comm)

    a[:, :] = a_orig  # restore original values of a and b
    b[:] = b_orig  # so they're the same as when we were called.
    # Note: maybe we could just use array copies in the algorithm, but we may need to use the
    # real a and b because they can be shared mem (check?)

    if host_comm is not None:
        _smt.cleanup_shared_ndarray(shared_floats_shm)
        _smt.cleanup_shared_ndarray(shared_ints_shm)
        _smt.cleanup_shared_ndarray(shared_rowb_shm)
    return
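
For reference, the serial skeleton that `custom_solve` distributes (pick the best pivot per column, eliminate that column from every other row of the augmented system, then read the solution off the pivot entries) can be sketched in plain NumPy as follows, for a square, full-rank `a` and without the zero-pivot or restore-inputs bookkeeping:

import numpy as np

def ge_solve(a, b):
    a, b = a.astype(float).copy(), b.astype(float).copy()
    pivot_rows = []
    for icol in range(a.shape[1]):
        candidates = [r for r in range(a.shape[0]) if r not in pivot_rows]
        ipiv = max(candidates, key=lambda r: abs(a[r, icol]))  # best pivot for this column
        pivot_rows.append(ipiv)
        for r in range(a.shape[0]):  # eliminate the icol-th entry of every other row
            if r != ipiv and a[r, icol] != 0.0:
                alpha = a[r, icol] / a[ipiv, icol]
                a[r, :] -= alpha * a[ipiv, :]
                b[r] -= alpha * b[ipiv]
    # a|b is now in RREF up to row scaling, so back-substitution is trivial:
    return np.array([b[ipiv] / a[ipiv, icol] for icol, ipiv in enumerate(pivot_rows)])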
Example #12
def mpidot(a, b, loc_row_slice, loc_col_slice, slice_tuples_by_rank, comm,
           out=None, out_shm=None):
    """
    Performs a distributed dot product, dot(a,b).

    Parameters
    ----------
    a : numpy.ndarray
        First array to dot together.

    b : numpy.ndarray
        Second array to dot together.

    loc_row_slice, loc_col_slice : slice
        Specify the row or column indices, respectively, of the
        resulting dot product that are computed by this processor (the
        rows of `a` and columns of `b` that are used). Obtained from
        :func:`distribute_for_dot`.

    slice_tuples_by_rank : list
        A list of (row_slice, col_slice) tuples, one per processor within this
        processor's broadcast group, ordered by rank.  Provided by :func:`distribute_for_dot`.

    comm : mpi4py.MPI.Comm or ResourceAllocation or None
        The communicator used to parallelize the dot product.  If a
        :class:`ResourceAllocation` object is given, then a shared
        memory result will be returned when appropriate.

    out : numpy.ndarray, optional
        If not None, the array to use for the result.  This should be the
        same type of array (size, and whether it's shared or not) as this
        function would have created if `out` were `None`.

    out_shm : multiprocessing.shared_memory.SharedMemory, optional
        The shared memory object corresponding to `out` when it uses
        shared memory.

    Returns
    -------
    result : numpy.ndarray
        The resulting array.
    shm : multiprocessing.shared_memory.SharedMemory
        A shared memory object needed to cleanup the shared memory.  If
        a normal array is created, this is `None`.  Provide this to
        :func:`cleanup_shared_ndarray` to ensure the array is deallocated properly.
    """
    # R_ij = sum_k A_ik * B_kj
    from ..baseobjs.resourceallocation import ResourceAllocation as _ResourceAllocation
    if isinstance(comm, _ResourceAllocation):
        ralloc = comm
        comm = ralloc.comm
    else:
        ralloc = None

    if comm is None or comm.Get_size() == 1:
        return _np.dot(a, b), None

    if out is None:
        if ralloc is None:
            result, result_shm = _np.zeros((a.shape[0], b.shape[1]), a.dtype), None
        else:
            result, result_shm = _smt.create_shared_ndarray(ralloc, (a.shape[0], b.shape[1]), a.dtype,
                                                            zero_out=True)
    else:
        result = out
        result_shm = out_shm

    rshape = (_slct.length(loc_row_slice), _slct.length(loc_col_slice))
    loc_result_flat = _np.empty(rshape[0] * rshape[1], a.dtype)
    loc_result = loc_result_flat.view()
    loc_result.shape = rshape
    loc_result[:, :] = _np.dot(a[loc_row_slice, :], b[:, loc_col_slice])

    # broadcast_comm defines the group of processors this processor communicates with.
    # Without shared memory, this is *all* the other processors.  With shared memory, this
    # is one processor on each host.  This code is identical to that in distribute_for_dot.
    if ralloc is None:
        broadcast_comm = comm
    else:
        broadcast_comm = comm if (ralloc.interhost_comm is None) else ralloc.interhost_comm

    comm.barrier()  # wait for all ranks to do their work (get their loc_result)
    for r, (cur_row_slice, cur_col_slice) in enumerate(slice_tuples_by_rank):
        # for each member of the group that will communicate results
        cur_shape = (_slct.length(cur_row_slice), _slct.length(cur_col_slice))
        buf = loc_result_flat if (broadcast_comm.rank == r) else _np.empty(cur_shape[0] * cur_shape[1], a.dtype)
        broadcast_comm.Bcast(buf, root=r)
        if broadcast_comm.rank != r:
            buf.shape = cur_shape
        else:
            buf = loc_result  # already of correct shape
        result[cur_row_slice, cur_col_slice] = buf
    comm.barrier()  # wait for all ranks to finish writing to result

    #assert(_np.linalg.norm(_np.dot(a,b) - result)/(_np.linalg.norm(result) + result.size) < 1e-6),\
    #    "DEBUG: %g, %g, %d" % (_np.linalg.norm(_np.dot(a,b) - result), _np.linalg.norm(result), result.size)
    return result, result_shm
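
A hypothetical call sequence for two ranks (normally :func:`distribute_for_dot` supplies the slice assignments; here they are written out by hand so the sketch is self-contained):

import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD  # run with, e.g., `mpiexec -n 2`
a = np.arange(12.0).reshape(4, 3)
b = np.arange(6.0).reshape(3, 2)

# Split the result's rows between the two ranks; every rank must know all assignments.
slices_by_rank = [(slice(0, 2), slice(0, 2)), (slice(2, 4), slice(0, 2))]
row_slice, col_slice = slices_by_rank[comm.rank]

result, shm = mpidot(a, b, row_slice, col_slice, slices_by_rank, comm)
assert np.allclose(result, a @ b)  # every rank ends up holding the full product
# shm is None here (a plain comm was passed, not a ResourceAllocation), so no cleanup is needed.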