def grid_search_ss_split(atlas, I, W, R, u, bounds=[-20, 20]):
    """Refine the pixel shifts ``u`` with an exhaustive integer grid search.

    For every pixel ``(i, j)`` of the local block, each integer offset
    ``(k, l)`` with ``bounds[0] <= k, l <= bounds[1]`` is added to the
    current shift ``u.arr[:, i, j]``.  The atlas is sampled at the rounded
    positions ``-R + shift + (i + ss_offset, j)`` and the offset with the
    smallest mean squared residual ``(atlas * w - I)**2`` is added to the
    output shift for that pixel.

    Parameters
    ----------
    atlas : 2-D array indexed as ``atlas[ss, fs]``.
    I : object exposing ``.arr`` (indexed ``[frame, i, j]``) and ``.shape``.
    W : object exposing ``.arr`` (2-D, or 3-D indexed ``[frame, i, j]``),
        ``.shape`` and ``.distribution.offsets``; a 2-D array is repeated
        ``I.shape[0]`` times along a new leading axis.
    R : array indexed as ``R[:, 0]`` (ss) and ``R[:, 1]`` (fs), one row per
        frame.
    u : object exposing ``.arr`` indexed ``[component, i, j]`` with the
        current ss (0) and fs (1) shifts.
    bounds : two-element sequence ``[lo, hi]``; both are truncated to int
        and the search window is the inclusive range ``lo..hi`` along each
        axis.

    Returns
    -------
    MpiArray
        ``MpiArray(u.arr, axis=1)`` with the best offset added per pixel.
        Pixels for which no offset produced a valid residual (no positive
        weight, or no overlap with positive atlas values) are left
        unchanged.

    Notes
    -----
    Relies on the module-level names ``rank``, ``comm``, ``MPI``,
    ``MpiArray`` and ``callback``.
    """
    ss_offset = W.distribution.offsets[rank]
    print('rank, W.shape, offset:', rank, W.shape, W.arr.shape, ss_offset,
          W.distribution.offsets[rank])

    # Error value reported when an offset has no usable data.
    no_data = 1e100

    def fun(ui, uj, i, j, w):
        # Nearest-pixel atlas coordinates for every frame.
        # (The builtin ``int`` replaces ``np.int``, removed in NumPy 1.24.)
        ss = np.rint(-R[:, 0] + ui + i + ss_offset).astype(int)
        fs = np.rint(-R[:, 1] + uj + j).astype(int)
        # Keep only the frames that land inside the atlas.
        # NOTE(review): index 0 is excluded by the strict ``> 0`` tests --
        # confirm whether that is intentional.
        m = (ss > 0) * (ss < atlas.shape[0]) * (fs > 0) * (fs < atlas.shape[1])
        forw = atlas[ss[m], fs[m]]
        # Only compare where both the atlas and the weight are positive.
        m1 = (forw > 0) * (w[m, i, j] > 0)
        n = np.sum(m1)
        if n > 0:
            err = np.sum(m1 * (forw * w[m, i, j] - I.arr[m, i, j])**2) / n
        else:
            err = no_data
        return err

    # define the search window: every integer (k, l) pair inside bounds
    steps = np.arange(int(bounds[0]), int(bounds[1]) + 1, 1)
    dk, dl = np.meshgrid(steps, steps, indexing='ij')
    kls = np.vstack((dk.ravel(), dl.ravel())).T

    # define the pupil indices
    ss = np.arange(W.arr.shape[-2])
    fs = np.arange(W.arr.shape[-1])

    if len(W.arr.shape) == 2:
        w = np.array([W.arr for _ in range(I.shape[0])])
    else:
        w = W.arr

    # Smallest "last local row index" over all ranks; used below so that
    # every rank stops calling gather() after the same row.
    ss_min = comm.allreduce(np.max(ss), op=MPI.MIN)
    errs = np.empty((len(kls), ), dtype=np.float64)
    u_out = MpiArray(u.arr, axis=1)
    for i in ss:
        print(rank, i, i + ss_offset)
        sys.stdout.flush()
        for j in fs:
            # Pixels without any positive weight can never yield a valid
            # error, so skip the whole window instead of re-testing this
            # for every candidate offset.
            if not np.any(w[:, i, j] > 0):
                continue

            u_ss = u.arr[0, i, j]
            u_fs = u.arr[1, i, j]
            for n_kl, kl in enumerate(kls):
                errs[n_kl] = fun(u_ss + kl[0], u_fs + kl[1], i, j, w)

            best = np.argmin(errs)
            # If every candidate reported "no data", argmin is 0 and would
            # silently apply the corner offset kls[0]; leave the shift
            # untouched in that case.
            if errs[best] < no_data:
                u_out.arr[:, i, j] += kls[best]

        # output after every ss row that all ranks own (presumably gather()
        # needs every rank to take part -- hence the i <= ss_min test)
        if i <= ss_min and callback is not None:
            ug = u_out.gather()
            callback(ug)

    return u_out
def grid_search_ss_split_sub_pix(atlas,
                                 I,
                                 W,
                                 R,
                                 u,
                                 bounds=[-1, 1],
                                 max_iters=100):
    """Refine the pixel shifts ``u`` to sub-pixel precision.

    For every pixel ``(i, j)`` of the local block a bounded minimisation is
    started from a zero correction.  The objective evaluates the atlas at
    ``-R + x + u.arr[:, i, j] + (i + ss_offset, j)`` through
    ``fm.sub_pixel_atlas_eval`` and returns the mean squared residual
    ``(atlas * w - I)**2`` over the frames where both the atlas value and
    the weight are positive.

    Parameters
    ----------
    atlas : array passed unchanged to ``fm.sub_pixel_atlas_eval``.
    I : object exposing ``.arr`` (indexed ``[frame, i, j]``) and ``.shape``.
    W : object exposing ``.arr`` (2-D, or 3-D indexed ``[frame, i, j]``)
        and ``.distribution.offsets``; a 2-D array is repeated
        ``I.shape[0]`` times along a new leading axis.
    R : array indexed as ``R[:, 0]`` (ss) and ``R[:, 1]`` (fs), one row per
        frame.
    u : object exposing ``.arr`` indexed ``[component, i, j]`` with the
        current ss (0) and fs (1) shifts.
    bounds : ``[lo, hi]`` limits applied to both components of the
        correction.
    max_iters : iteration limit handed to the minimiser.

    Returns
    -------
    MpiArray
        ``MpiArray(u.arr, axis=1)`` with the optimised correction added for
        every pixel whose final objective value is valid (``>= 0``).

    Notes
    -----
    Relies on the module-level names ``rank``, ``comm``, ``MPI``,
    ``MpiArray``, ``callback`` and ``fm``.
    """
    # Imported once per call instead of once per pixel.
    from scipy.optimize import minimize

    ss_offset = W.distribution.offsets[rank]

    def fun(x, i=0, j=0, w=0, uu=0):
        # ``forw`` is an output buffer filled by sub_pixel_atlas_eval.
        # (The builtin ``float`` replaces ``np.float``, removed in
        # NumPy 1.24.)
        forw = np.empty((R.shape[0], ), dtype=float)
        fm.sub_pixel_atlas_eval(atlas, forw,
                                -R[:, 0] + x[0] + uu[0] + i + ss_offset,
                                -R[:, 1] + x[1] + uu[1] + j)

        m = (forw > 0) * (w > 0)
        n = np.sum(m)
        if n > 0:
            err = np.sum((forw[m] * w[m] - I.arr[m, i, j])**2) / n
        else:
            # Sentinel for "no usable data".
            # NOTE(review): being smaller than every valid error, this can
            # attract the minimiser towards regions without data; such
            # results are discarded below.
            err = -1
        return err

    if len(W.arr.shape) == 2:
        w = np.array([W.arr for _ in range(I.shape[0])])
    else:
        w = W.arr

    # define the pupil indices
    ss = np.arange(W.arr.shape[-2])
    fs = np.arange(W.arr.shape[-1])

    # fs is a plain arange, so its middle element equals its middle index.
    fs_mid = fs.shape[0] // 2

    # 'eps' is the finite-difference step of L-BFGS-B.  The former 'xatol'
    # entry is a Nelder-Mead option that L-BFGS-B does not accept, so it
    # has been dropped.
    options = {'maxiter': max_iters, 'eps': 0.1}

    # Smallest "last local row index" over all ranks; used below so that
    # every rank stops calling gather() after the same row.
    ss_min = comm.allreduce(np.max(ss), op=MPI.MIN)
    u_out = MpiArray(u.arr, axis=1)
    for i in ss:
        for j in fs:
            x0 = np.array([0., 0.])
            w_ij = w[:, i, j]
            u_ij = u.arr[:, i, j]
            err0 = fun(x0, i, j, w_ij, u_ij)
            # L-BFGS-B is what minimize() selects by default when only
            # bounds are supplied; it is named here to make that explicit.
            res = minimize(fun,
                           x0,
                           method='L-BFGS-B',
                           bounds=[bounds, bounds],
                           args=(i, j, w_ij, u_ij),
                           options=options)
            if rank == 0 and j == fs_mid:
                print(err0, res)
                sys.stdout.flush()
            # -1 marks "no data"; a residual of exactly zero is a valid
            # (perfect) fit and must still be applied.
            if res.fun >= 0:
                u_out.arr[:, i, j] += res.x

        # output after every ss row that all ranks own (presumably gather()
        # needs every rank to take part -- hence the i <= ss_min test)
        if i <= ss_min and callback is not None:
            ug = u_out.gather()
            callback(ug)

    return u_out
                                             bounds=[(-m, m), (-m, m)],
                                             max_iters=params['max_iters'],
                                             sub_pixel=params['sub_pixel'])

    errs2 = errs.allgather()
    em = errs2.mean()
    weights = np.abs(errs.arr - em)
    weights = 1. - weights / weights.max() + 0.2
    weights = MpiArray(weights, axis=0)

    #atlas  = build_atlas_distortions_MpiArray(params['frames'],
    #                                          params['whitefield'],
    #                                          params['R_ss_fs'],
    #                                          params['pixel_shifts'],
    #                                          reg=reg, weights = weights.arr)

    params['R_ss_fs'] = params['R_ss_fs'].gather()
    errs = errs.gather()
    weights = weights.gather()

    # real-time output
    if rank == 0:
        out = {'R_ss_fs': params['R_ss_fs'], 'errs': errs, 'weights': weights}
        cmdline_config_cxi_reader.write_all(params, args.filename, out)
        print('display: ' + params['h5_group'] + '/R_ss_fs')
        sys.stdout.flush()
        print('')
        sys.stdout.flush()
        import time
        time.sleep(1)