def calc_coeffs_no_i(self, wave_base_j_qq_ZS_at_xs, wave_dual_j_qq_ZS_at_xs, j, xs, i, balls_info, qq):
    """Calculate alphas (w/ dual) and alpha-duals (w/ base), leaving sample ``i`` out.

    Leave-one-out counterpart of ``calc_coeffs``: ``calc_omega`` is called with
    one sample fewer (``xs.shape[0] - 1``), weights come from
    ``balls_no_i(balls_info, i)``, and the i-th term is subtracted from each
    weighted sum.

    :param wave_base_j_qq_ZS_at_xs: mapping zs -> base wavelet values at the
        samples (indexed by sample position; presumably a numpy array — confirm)
    :param wave_dual_j_qq_ZS_at_xs: mapping zs -> dual wavelet values at the samples
    :param j: resolution level; scales are 2**(j + j0) for j0 in ``self.j0s``
    :param xs: sample points; only ``xs.shape[0]`` is used here
    :param i: index of the sample to leave out
    :param balls_info: passed to ``balls_no_i`` to get the per-sample weights
    :param qq: tensor index of the wavelet
    :return: dict zs -> (alpha_zs, alpha_d_zs), keyed by the zs of the dual range.
        If ``self.wave.orthogonal`` both entries are alpha_zs. Otherwise the
        second entry is recomputed from the base functions for every zs that is
        in both the dual and the base range; zs that are only in the dual range
        keep alpha_zs in both slots, and base-only zs are not added.
    """
    # Do not reuse the name ``j`` inside the comprehension: it is a parameter.
    jpow2 = np.array([2 ** (j + j0) for j0 in self.j0s])
    zs_min, zs_max = self.wave.z_range('dual', (qq, jpow2, None), self.minx, self.maxx)
    omega_no_i = calc_omega(xs.shape[0] - 1, self.k)
    resp = {}
    vol_no_i = balls_no_i(balls_info, i)
    # Loop-invariant: weight of the left-out sample, read once instead of per zs.
    vol_i = vol_no_i[i]

    def leave_i_out(wave_at_xs):
        # Full weighted sum minus the i-th term: removing the factor for i this
        # way has the biggest impact on performance (hot path).
        return omega_no_i * ((wave_at_xs * vol_no_i).sum() - wave_at_xs[i] * vol_i)

    for zs in itt.product(*all_zs_tensor(zs_min, zs_max)):
        # alpha_zs was also computed by calc_coeffs and could be derived from it
        # instead (possible further optimisation).
        alpha_zs = leave_i_out(wave_dual_j_qq_ZS_at_xs[zs])
        resp[zs] = (alpha_zs, alpha_zs)
    if self.wave.orthogonal:
        # we are done: alpha-duals are taken to equal the alphas
        return resp

    zs_min, zs_max = self.wave.z_range('base', (qq, jpow2, None), self.minx, self.maxx)
    for zs in itt.product(*all_zs_tensor(zs_min, zs_max)):
        if zs not in resp:
            # only zs present in the dual range are reported
            continue
        resp[zs] = (resp[zs][0], leave_i_out(wave_base_j_qq_ZS_at_xs[zs]))
    return resp
def calc_coeffs(self, wave_base_j_qq_ZS_at_xs, wave_dual_j_qq_ZS_at_xs, j, xs, balls_info, qq):
    """Calculate alphas (from the dual functions) and alpha-duals (from the base functions).

    Each coefficient is ``calc_omega(xs.shape[0], self.k)`` times the sum of the
    wavelet values weighted by ``balls_info.sqrt_vol_k``.

    :param wave_base_j_qq_ZS_at_xs: mapping zs -> base wavelet values at the samples
    :param wave_dual_j_qq_ZS_at_xs: mapping zs -> dual wavelet values at the samples
    :param j: resolution level; scales are 2**(j + j0) for j0 in ``self.j0s``
    :param xs: sample points; only ``xs.shape[0]`` is used here
    :param balls_info: provides the ``sqrt_vol_k`` weights
    :param qq: tensor index of the wavelet
    :return: dict zs -> (alpha_zs, alpha_d_zs), keyed by the zs of the dual range.
        For orthogonal wavelets both entries are alpha_zs; otherwise the second
        entry is replaced for zs also present in the base range.
    """
    scales = np.array([2 ** (j + offset) for offset in self.j0s])
    lo, hi = self.wave.z_range('dual', (qq, scales, None), self.minx, self.maxx)
    omega = calc_omega(xs.shape[0], self.k)
    coeffs = {}
    weights = balls_info.sqrt_vol_k

    for zs in itt.product(*all_zs_tensor(lo, hi)):
        alpha = omega * (wave_dual_j_qq_ZS_at_xs[zs] * weights).sum()
        coeffs[zs] = (alpha, alpha)

    if self.wave.orthogonal:
        # alpha-duals coincide with the alphas; nothing more to compute
        return coeffs

    lo, hi = self.wave.z_range('base', (qq, scales, None), self.minx, self.maxx)
    in_both = (zs for zs in itt.product(*all_zs_tensor(lo, hi)) if zs in coeffs)
    for zs in in_both:
        alpha_d = omega * (wave_base_j_qq_ZS_at_xs[zs] * weights).sum()
        coeffs[zs] = (coeffs[zs][0], alpha_d)
    return coeffs
def test_z_range_2d(wave, what, ix):
    """Test several facts about the range of z values returned by ``wave.z_range``.

    The region is the box with corners p1=(1/3, 1/4) and p2=(3/4, 2/3).
    One step below ``zs_min`` or above ``zs_max`` along either axis must give a
    function whose support misses the region; every z inside the range must
    give a function whose support intersects it.
    """
    minx, maxx = np.array((1 / 3, 1 / 4)), np.array((3 / 4, 2 / 3))
    region = (minx, maxx)
    qq, ss, _ = ix  # the z component of ix is not needed by these checks
    zs_min, zs_max = wave.z_range(what, ix, minx, maxx)

    units = (np.array([1, 0]), np.array([0, 1]))
    # Neighbours just outside the range, one axis at a time.
    outside = [zs_min - unit for unit in units] + [zs_max + unit for unit in units]
    for zs in outside:
        support = wave.fun_ix(what, (qq, ss, zs)).support
        assert not intersect_2d(region, support), \
            "z={} is outside [{}, {}] but its support meets the region".format(zs, zs_min, zs_max)

    for zs in itt.product(*all_zs_tensor(zs_min, zs_max)):
        support = wave.fun_ix(what, (qq, ss, zs)).support
        assert intersect_2d(region, support), \
            "z={} is inside [{}, {}] but its support misses the region".format(zs, zs_min, zs_max)
def calc_funs(self, j, qq):
    """Build the base and dual wavelet functions for every z in range.

    For each of 'dual' and 'base', ``self.wave.z_range`` gives the z bounds for
    ``self.minx``/``self.maxx`` at scales 2**(j + j0), j0 in ``self.j0s``. Each z
    in those bounds is mapped to ``self.wave.fun_ix``.

    :param j: int, resolution level
    :param qq: tensor index in R^d
    :return: (base funs, dual funs); each is a dict zs -> base|dual
        wave _{j,zs}^{(qq)} (i.e. wave_base_j_qq_ZS, wave_dual_j_qq_ZS)
    """
    scales = np.array([2 ** (j + offset) for offset in self.j0s])

    def build(what):
        # All functions of kind ``what`` whose index lies in the z range.
        lo, hi = self.wave.z_range(what, (qq, scales, None), self.minx, self.maxx)
        return {
            zs: self.wave.fun_ix(what, (qq, scales, zs))
            for zs in itt.product(*all_zs_tensor(lo, hi))
        }

    dual_funs = build('dual')
    base_funs = build('base')
    return base_funs, dual_funs
def test_z1():
    """Check the dual z range of a ('db1', 'db1') tensor-product wavelet.

    The region is the box with corners (0.2, 0.2) and (0.4, 0.6), and the index
    is qq=(0, 0), ss=(1, 2). One step below ``zs_min`` or above ``zs_max`` along
    either axis must miss the region; every z in the range must intersect it.
    """
    wave = WaveletTensorProduct(('db1', 'db1'))
    what = 'dual'
    ix = ((0, 0), (1, 2), (0, 0))
    minx, maxx = np.array((0.2, 0.2)), np.array((0.4, 0.6))
    region = (minx, maxx)
    qq, ss, _ = ix  # the z component of ix is not needed by these checks
    zs_min, zs_max = wave.z_range(what, ix, minx, maxx)

    units = (np.array([1, 0]), np.array([0, 1]))
    # Neighbours just outside the range, one axis at a time.
    outside = [zs_min - unit for unit in units] + [zs_max + unit for unit in units]
    for zs in outside:
        support = wave.fun_ix(what, (qq, ss, zs)).support
        # The range is reported in the message instead of via a debug print.
        assert not intersect_2d(region, support), \
            "z={} is outside [{}, {}] but its support meets the region".format(zs, zs_min, zs_max)

    for zs in itt.product(*all_zs_tensor(zs_min, zs_max)):
        support = wave.fun_ix(what, (qq, ss, zs)).support
        assert intersect_2d(region, support), \
            "z={} is inside [{}, {}] but its support misses the region".format(zs, zs_min, zs_max)