def grid_search_ss_split(atlas, I, W, R, u, bounds=[-20, 20]):
    """Refine ``u`` by an exhaustive integer-offset search at every pixel.

    For each local pixel ``(i, j)`` every integer offset ``(dk, dl)`` with
    ``bounds[0] <= dk, dl <= bounds[1]`` is tried.  The candidate error is
    the mean of ``(atlas[ss, fs] * w - I)**2`` over the frames whose
    rounded atlas coordinates fall inside the atlas and whose atlas value
    and weight are both positive.  The offset with the smallest error is
    added to ``u`` at that pixel.

    Relies on the module-level names ``rank``, ``comm``, ``MPI``,
    ``MpiArray``, ``callback``, ``np`` and ``sys``.

    Parameters
    ----------
    atlas : 2-D array
        Indexed as ``atlas[ss, fs]``.
    I : distributed array
        ``I.arr`` is indexed ``[frame, i, j]``; ``I.shape[0]`` is used as
        the frame count when ``W`` is 2-D.
    W : distributed array
        Weights.  ``W.arr`` is either ``(ss, fs)`` (replicated for every
        frame) or already three-dimensional, and
        ``W.distribution.offsets[rank]`` gives this rank's offset along
        the ss axis.
    R : array
        ``R[:, 0]`` and ``R[:, 1]`` are subtracted from the ss and fs
        coordinates (presumably per-frame sample positions -- confirm
        against the caller).
    u : distributed array
        ``u.arr[0]`` / ``u.arr[1]`` hold the current ss / fs values per
        pixel.
    bounds : sequence of two numbers
        Inclusive search window, truncated to ints.  The default list is
        only read, never mutated.

    Returns
    -------
    MpiArray
        Built from ``u.arr`` with ``axis=1`` and updated with the best
        offset of every pixel that produced a valid error.
    """
    # Sentinel error for "no usable data for this candidate".
    no_data_err = 1e100

    ss_offset = W.distribution.offsets[rank]
    print('rank, W.shape, offset:', rank, W.shape, W.arr.shape, ss_offset, W.distribution.offsets[rank])

    def fun(ui, uj, i, j, w):
        # ``np.int`` was removed in NumPy 1.24; it was an alias of the
        # builtin ``int``, so the resulting dtype is unchanged.
        ss_idx = np.rint(-R[:, 0] + ui + i + ss_offset).astype(int)
        fs_idx = np.rint(-R[:, 1] + uj + j).astype(int)

        # NOTE(review): the strict ``> 0`` also rejects row/column 0 of the
        # atlas, which is a valid index.  Kept as-is to preserve results --
        # confirm whether ``>= 0`` was intended.
        m = (ss_idx > 0) * (ss_idx < atlas.shape[0]) * (fs_idx > 0) * (fs_idx < atlas.shape[1])

        forw = atlas[ss_idx[m], fs_idx[m]]
        valid = (forw > 0) * (w[m, i, j] > 0)
        n = np.sum(valid)
        if n > 0:
            return np.sum(valid * (forw * w[m, i, j] - I.arr[m, i, j])**2) / n
        return no_data_err

    # define the search window: every (dk, dl) pair, one pair per row
    steps = np.arange(int(bounds[0]), int(bounds[1]) + 1, 1)
    dk, dl = np.meshgrid(steps, steps, indexing='ij')
    kls = np.vstack((dk.ravel(), dl.ravel())).T

    # define the pupil indices
    ss = np.arange(W.arr.shape[-2])
    fs = np.arange(W.arr.shape[-1])

    # a 2-D weight map is shared by all frames
    if len(W.arr.shape) == 2:
        w = np.array([W.arr for _ in range(I.shape[0])])
    else:
        w = W.arr

    # Smallest "last local row index" over all ranks.  The gather below is
    # presumably a collective call, so it is only issued for rows that
    # every rank still owns.
    ss_min = comm.allreduce(np.max(ss), op=MPI.MIN)

    errs = np.empty((len(kls), ), dtype=np.float64)
    u_out = MpiArray(u.arr, axis=1)

    for i in ss:
        print(rank, i, i + ss_offset)
        sys.stdout.flush()

        for j in fs:
            # Pixels without any positive weight carry no information:
            # skip the whole search instead of testing per candidate.
            if not np.any(w[:, i, j] > 0):
                continue

            for n, kl in enumerate(kls):
                errs[n] = fun(u.arr[0, i, j] + kl[0], u.arr[1, i, j] + kl[1], i, j, w)

            best = np.argmin(errs)
            # If every candidate returned the sentinel, argmin is just the
            # first window corner; leave such a pixel untouched rather than
            # shifting it by (bounds[0], bounds[0]).
            if errs[best] < no_data_err:
                u_out.arr[:, i, j] += kls[best]

        # report progress after every ss row that all ranks share
        if i <= ss_min and callback is not None:
            ug = u_out.gather()
            callback(ug)

    return u_out
def grid_search_ss_split_sub_pix(atlas, I, W, R, u, bounds=[-1, 1], max_iters=100):
    """Refine ``u`` at every pixel with a bounded sub-pixel minimisation.

    For each local pixel ``(i, j)`` a two-component shift ``x`` is
    optimised with ``scipy.optimize.minimize`` starting from ``(0, 0)``.
    The objective is the mean of ``(forw * w - I)**2`` over the frames
    where both the atlas value ``forw`` (filled in by
    ``fm.sub_pixel_atlas_eval``) and the weight are positive.  The
    optimised shift is added to ``u`` only when the final objective value
    is strictly positive.

    Relies on the module-level names ``rank``, ``comm``, ``MPI``,
    ``MpiArray``, ``callback``, ``fm``, ``np`` and ``sys``.

    Parameters
    ----------
    atlas : array
        Passed through to ``fm.sub_pixel_atlas_eval``.
    I : distributed array
        ``I.arr`` is indexed ``[frame, i, j]``; ``I.shape[0]`` is used as
        the frame count when ``W`` is 2-D.
    W : distributed array
        Weights.  ``W.arr`` is either ``(ss, fs)`` (replicated for every
        frame) or already three-dimensional, and
        ``W.distribution.offsets[rank]`` gives this rank's offset along
        the ss axis.
    R : array
        ``R[:, 0]`` and ``R[:, 1]`` are subtracted from the ss and fs
        coordinates; ``R.shape[0]`` sets the length of the evaluation
        buffer.
    u : distributed array
        ``u.arr[:, i, j]`` is the current two-component value per pixel.
    bounds : sequence of two numbers
        ``(low, high)`` limits applied to both components of the shift.
        The default list is only read, never mutated.
    max_iters : int
        Forwarded to the solver as ``maxiter``.

    Returns
    -------
    MpiArray
        Built from ``u.arr`` with ``axis=1`` and updated with the accepted
        sub-pixel shifts.
    """
    # Imported once per call instead of once per pixel.
    from scipy.optimize import minimize

    ss_offset = W.distribution.offsets[rank]

    def fun(x, i=0, j=0, w=0, uu=0):
        # ``np.float`` was removed in NumPy 1.24; it was an alias of the
        # builtin ``float``, i.e. float64.
        forw = np.empty((R.shape[0], ), dtype=np.float64)
        fm.sub_pixel_atlas_eval(
            atlas, forw,
            -R[:, 0] + x[0] + uu[0] + i + ss_offset,
            -R[:, 1] + x[1] + uu[1] + j)

        m = (forw > 0) * (w > 0)
        n = np.sum(m)
        if n > 0:
            return np.sum((forw[m] * w[m] - I.arr[m, i, j])**2) / n
        # Negative sentinel: "no usable data".  The caller rejects any
        # result whose objective is not strictly positive.
        return -1

    # a 2-D weight map is shared by all frames
    if len(W.arr.shape) == 2:
        w = np.array([W.arr for _ in range(I.shape[0])])
    else:
        w = W.arr

    # define the pupil indices
    ss = np.arange(W.arr.shape[-2])
    fs = np.arange(W.arr.shape[-1])

    # Smallest "last local row index" over all ranks.  The gather below is
    # presumably a collective call, so it is only issued for rows that
    # every rank still owns.
    ss_min = comm.allreduce(np.max(ss), op=MPI.MIN)

    u_out = MpiArray(u.arr, axis=1)

    # ``fs`` is an arange, so its middle element equals its middle index.
    mid_j = fs.shape[0] // 2

    for i in ss:
        for j in fs:
            x0 = np.array([0., 0.])
            fit_args = (i, j, w[:, i, j], u.arr[:, i, j])
            err0 = fun(x0, *fit_args)

            # NOTE(review): 'xatol' is a Nelder-Mead option.  With bounds
            # and no explicit method SciPy presumably selects L-BFGS-B,
            # which would report it as an unknown option -- confirm the
            # intended solver.  Keys are kept exactly as they were.
            options = {'maxiter': max_iters, 'eps': 0.1, 'xatol': 0.1}
            res = minimize(fun, x0, bounds=[bounds, bounds], args=fit_args, options=options)

            # diagnostic output for the middle fs column on rank 0 only
            if rank == 0 and j == mid_j:
                print(err0, res)
                sys.stdout.flush()

            # accept the shift only when the fit saw valid data
            if res.fun > 0:
                u_out.arr[:, i, j] += res.x

        # report progress after every ss row that all ranks share
        if i <= ss_min and callback is not None:
            ug = u_out.gather()
            callback(ug)

    return u_out
bounds=[(-m, m), (-m, m)], max_iters=params['max_iters'], sub_pixel=params['sub_pixel']) errs2 = errs.allgather() em = errs2.mean() weights = np.abs(errs.arr - em) weights = 1. - weights / weights.max() + 0.2 weights = MpiArray(weights, axis=0) #atlas = build_atlas_distortions_MpiArray(params['frames'], # params['whitefield'], # params['R_ss_fs'], # params['pixel_shifts'], # reg=reg, weights = weights.arr) params['R_ss_fs'] = params['R_ss_fs'].gather() errs = errs.gather() weights = weights.gather() # real-time output if rank == 0: out = {'R_ss_fs': params['R_ss_fs'], 'errs': errs, 'weights': weights} cmdline_config_cxi_reader.write_all(params, args.filename, out) print('display: ' + params['h5_group'] + '/R_ss_fs') sys.stdout.flush() print('') sys.stdout.flush() import time time.sleep(1)