def interp(self, m):
    """Evaluate the piecewise-linear interpolation of the point values at ``m``.

    Each interior knot ``points[i]`` (``1 <= i <= interp_N - 2``) gets a hat
    basis function, assuming ``self.points`` is sorted ascending:
    - it rises linearly on ``[points[i-1], points[i])``;
    - it falls linearly on ``[points[i], points[i+1])``;
    - it is zero elsewhere.

    The first and last knots have no basis function of their own. The result
    is therefore zero for ``m`` outside ``[points[0], points[-1])``.
    Presumably the amplitude is meant to vanish at the end knots; confirm.

    Args:
        m: evaluation points (real tensor).

    Returns:
        Complex tensor whose real/imaginary parts are the basis-weighted sums
        of the real/imaginary parts of ``self.point_value()``. The last axis of
        ``point_value()`` is assumed to have length ``interp_N - 2``; confirm.
    """
    values = self.point_value()
    knots = self.points
    zero = tf.zeros_like(m)
    one = tf.ones_like(m)

    def segment_factor(node, seg):
        # Linear Lagrange factor of ``node`` on the segment between knots
        # ``seg`` and ``seg + 1``. It is masked to zero outside the half-open
        # interval [knots[seg], knots[seg + 1]).
        other = seg if node == seg + 1 else seg + 1
        factor = one * (m - knots[other]) / (knots[node] - knots[other])
        in_segment = (m >= knots[seg]) & (m < knots[seg + 1])
        return tf.where(in_segment, factor, zero)

    basis = []
    for node in range(1, self.interp_N - 1):
        hat = zero
        # Left (rising) piece, then right (falling) piece.
        for seg in (node - 1, node):
            hat = hat + segment_factor(node, seg)
        basis.append(hat)
    # The weights depend only on m and the knots. Block gradients through them
    # so that gradients reach only the point values.
    weights = tf.stop_gradient(tf.stack(basis, axis=-1))

    real_part = tf.reduce_sum(weights * tf.math.real(values), axis=-1)
    imag_part = tf.reduce_sum(weights * tf.math.imag(values), axis=-1)
    return tf.complex(real_part, imag_part)
def interp(self, m):
    """Evaluate the interpolated complex value at ``m``.

    How the result is computed:
    1. ``spline_x_matrix(m, self.points)`` gets a new trailing axis.
    2. It is multiplied with ``self.h_matrix``.
    3. The product is summed over the two axes before the last, giving the
       interpolation weights.
    4. The real and imaginary parts of ``self.point_value()`` are weighted
       and summed over the last axis.

    Args:
        m: evaluation points. Presumably a real tensor; confirm against
            ``spline_x_matrix``.

    Returns:
        Complex tensor built from the weighted sums of the real and imaginary
        point values.
    """
    p = self.point_value()
    p_r = tf.math.real(p)
    p_i = tf.math.imag(p)
    xi_m = self.h_matrix
    # NOTE(review): the shapes of h_matrix and of spline_x_matrix's output are
    # not visible here. The broadcast below assumes they align; confirm.
    x_m = tf.expand_dims(spline_x_matrix(m, self.points), axis=-1)
    m_xi = tf.reduce_sum(xi_m * x_m, axis=[-3, -2])
    # Treat the weights as constants: gradients flow only through the point
    # values, not through m or the weights.
    m_xi = tf.stop_gradient(m_xi)
    ret_r = tf.reduce_sum(tf.cast(m_xi, p_r.dtype) * p_r, axis=-1)
    ret_i = tf.reduce_sum(tf.cast(m_xi, p_i.dtype) * p_i, axis=-1)
    return tf.complex(ret_r, ret_i)
def interp(self, m):
    """Evaluate the interpolated complex value at ``m`` as an affine map.

    ``get_matrix_interp1d3_v2(m, self.points)`` supplies a matrix ``h`` and an
    additive term ``b``. Each of the real and imaginary parts ``x`` of
    ``self.point_value()`` is flattened to a column and mapped to ``h @ x``.
    The result is reshaped to ``b``'s shape, and ``b`` is added.

    Args:
        m: evaluation points. Presumably a real tensor; confirm against
            ``get_matrix_interp1d3_v2``.

    Returns:
        Complex tensor ``complex(h @ Re(p) + b, h @ Im(p) + b)`` with the shape
        of ``b``.
    """
    p = self.point_value()
    p_r = tf.math.real(p)
    p_i = tf.math.imag(p)
    h, b = get_matrix_interp1d3_v2(m, self.points)
    # Treat the interpolation matrix as a constant. Gradients through the
    # h @ x term reach only the point values.
    h = tf.stop_gradient(h)
    # NOTE(review): unlike h, b is not wrapped in stop_gradient. Confirm that
    # this is intended.
    # Use the dynamic shape so the reshape also works when b has unknown
    # (None) static dimensions inside a tf.function graph. In eager mode this
    # is identical to b.shape.
    out_shape = tf.shape(b)

    def apply_affine(x):
        # Flatten x to a column, apply h, restore b's shape, then add b.
        column = tf.reshape(x, (-1, 1))
        mapped = tf.matmul(tf.cast(h, x.dtype), column)
        return tf.reshape(mapped, out_shape) + tf.cast(b, x.dtype)

    ret_r = apply_affine(p_r)
    ret_i = apply_affine(p_i)
    return tf.complex(ret_r, ret_i)
def get_bin_index(self, m):
    """Return, for each ``m``, the index of the knot interval containing it.

    Index ``k`` means ``self.points[k] <= m < self.points[k + 1]``. Values
    below ``self.points[0]`` map to ``-1``. Values at or above
    ``self.points[-1]`` map to ``self.interp_N - 1``.

    Both branches assume ``self.points`` is ascending with ``interp_N``
    entries; confirm. The ``fix_width`` branch uses only the first and last
    knot, so it additionally assumes evenly spaced knots. Its arithmetic
    binning may differ from the exact comparisons at bin boundaries because
    of rounding.

    Args:
        m: tensor of evaluation points.

    Returns:
        int64 tensor of bin indices with gradients stopped.
    """
    if self.fix_width:
        # Equal-width bins of size delta_width, padded with one extra bin on
        # each side:
        # - bin 0 collects m < points[0];
        # - the last bin collects m >= points[-1].
        # histogram_fixed_width_bins clamps out-of-range values into the
        # edge bins.
        m_min = tf.convert_to_tensor(self.points[0], m.dtype)
        m_max = tf.convert_to_tensor(self.points[-1], m.dtype)
        delta_width = (m_max - m_min) / (self.interp_N - 1)
        bin_idx = tf.histogram_fixed_width_bins(
            m,
            [m_min - delta_width, m_max + delta_width],
            nbins=self.interp_N + 1,
            dtype=tf.dtypes.int64,
        )
    else:

        def _digitize(x, bins):
            # numpy_function requires outputs to match the declared Tout.
            # np.digitize returns the platform index type (np.intp), which is
            # not int64 everywhere, so cast explicitly.
            return np.digitize(x, bins).astype(np.int64, copy=False)

        bin_idx = tf.numpy_function(_digitize, [m, self.points], tf.int64)
        # numpy_function drops static shape information. np.digitize returns
        # an array with the same shape as its input.
        bin_idx.set_shape(m.shape)
    # Shift so that index k means points[k] <= m < points[k + 1].
    bin_idx = bin_idx - 1
    bin_idx = tf.stop_gradient(bin_idx)
    return bin_idx
def interp(self, m):
    """Evaluate a piecewise-constant interpolation of the point values at ``m``.

    Let ``mid_k = (points[k] + points[k + 1]) / 2``. For each
    ``0 <= i < interp_N - 2``, interior knot ``i + 1`` owns the left-open,
    right-closed interval ``(mid_i, mid_{i+1}]``. On that interval the result
    equals that knot's point value. For ascending ``self.points`` this is a
    nearest-knot rule; confirm the ordering. Outside ``(mid_0, mid_{interp_N-2}]``
    the result is zero.

    Args:
        m: evaluation points (real tensor).

    Returns:
        Complex tensor selecting the real/imaginary point value of the owning
        knot. The last axis of ``point_value()`` is assumed to have length
        ``interp_N - 2``; confirm.
    """
    values = self.point_value()
    one = tf.ones_like(m)
    zero = tf.zeros_like(m)
    knots = self.points

    def midpoint(k):
        return (knots[k] + knots[k + 1]) / 2

    indicators = []
    for i in range(self.interp_N - 2):
        lower, upper = midpoint(i), midpoint(i + 1)
        inside = (m > lower) & (m <= upper)
        indicators.append(tf.where(inside, one, zero))
    # The indicators depend only on m and the knots. Block gradients through
    # them so that gradients reach only the point values.
    mask = tf.stop_gradient(tf.stack(indicators, axis=-1))

    real_part = tf.reduce_sum(mask * tf.math.real(values), axis=-1)
    imag_part = tf.reduce_sum(mask * tf.math.imag(values), axis=-1)
    return tf.complex(real_part, imag_part)