Esempio n. 1
0
    def interp(self, m):
        """Evaluate the piecewise-linear interpolation at ``m``.

        Each interior knot ``self.points[i]`` (``i = 1 .. interp_N - 2``) gets a
        triangular "hat" basis function: it is 1 at that knot and falls
        linearly to 0 at the neighbouring knots. The result is
        ``sum_i hat_i(m) * p_i`` with ``p = self.point_value()``.

        Real and imaginary parts of ``p`` are interpolated separately and
        recombined.

        Args:
            m: Tensor of evaluation points.

        Returns:
            Complex tensor. It is presumably shaped like ``m``, assuming ``p``
            holds one value per interior knot on its last axis — TODO confirm
            against ``point_value()``.
        """
        p = self.point_value()
        zeros = tf.zeros_like(m)
        last = self.interp_N - 1  # index of the last knot

        def hat(i, xi):
            """Hat basis function centred on knot ``xi[i]``, evaluated at ``m``."""
            total = zeros
            # Segment j spans [xi[j], xi[j + 1]). The hat rises on segment
            # i - 1 and falls on segment i.
            for j in (i - 1, i):
                # Both segment endpoints must be valid knot indices.
                if j < 0 or j + 1 > last:
                    continue
                # Linear Lagrange factor: 1 at xi[i], 0 at the segment's other end.
                other = j if j != i else j + 1
                seg = (m - xi[other]) / (xi[i] - xi[other])
                inside = (m >= xi[j]) & (m < xi[j + 1])
                total = total + tf.where(inside, seg, zeros)
            return total

        # One basis column per interior knot, stacked on a new last axis.
        # stop_gradient: no gradient flows through the basis weights, only
        # through the point values.
        h = tf.stack([hat(i, self.points) for i in range(1, last)], axis=-1)
        h = tf.stop_gradient(h)
        ret_r = tf.reduce_sum(h * tf.math.real(p), axis=-1)
        ret_i = tf.reduce_sum(h * tf.math.imag(p), axis=-1)
        return tf.complex(ret_r, ret_i)
Esempio n. 2
0
 def interp(self, m):
     """Evaluate the interpolation at ``m`` using the precomputed ``self.h_matrix``.

     Features built by ``spline_x_matrix(m, self.points)`` are contracted
     against ``self.h_matrix``. That yields weights which multiply the real
     and imaginary parts of ``self.point_value()``, summed over the last axis.

     Returns:
         Complex tensor. Its shape depends on ``spline_x_matrix`` and
         ``h_matrix`` (presumably shaped like ``m`` — TODO confirm).
     """
     p = self.point_value()
     p_r = tf.math.real(p)
     p_i = tf.math.imag(p)
     # Trailing axis added before multiplying with h_matrix (presumably for
     # broadcasting against its last axis — verify against h_matrix's shape).
     x_m = tf.expand_dims(spline_x_matrix(m, self.points), axis=-1)
     m_xi = tf.reduce_sum(self.h_matrix * x_m, axis=[-3, -2])
     # Weights are treated as constants; gradients flow only through p.
     m_xi = tf.stop_gradient(m_xi)
     ret_r = tf.reduce_sum(tf.cast(m_xi, p_r.dtype) * p_r, axis=-1)
     ret_i = tf.reduce_sum(tf.cast(m_xi, p_i.dtype) * p_i, axis=-1)
     return tf.complex(ret_r, ret_i)
Esempio n. 3
0
 def interp(self, m):
     """Evaluate the interpolation at ``m`` as an affine map of the point values.

     ``get_matrix_interp1d3_v2(m, self.points)`` returns a matrix ``h`` and an
     offset ``b``. The real part and the imaginary part ``x`` of
     ``self.point_value()`` are each flattened to a column, mapped to
     ``h @ x + b``, and reshaped like ``b``.

     Returns:
         Complex tensor shaped like ``b``.
     """
     p = self.point_value()
     h, b = get_matrix_interp1d3_v2(m, self.points)
     # h depends only on m and the knots; gradients flow through p only.
     h = tf.stop_gradient(h)
     # Dynamic shape: b.shape would fail in tf.reshape whenever b's static
     # shape has unknown dimensions (e.g. inside tf.function).
     out_shape = tf.shape(b)

     def apply(x):
         # Flatten to a column vector, apply h, restore b's shape, add offset.
         col = tf.reshape(x, (-1, 1))
         y = tf.matmul(tf.cast(h, x.dtype), col)
         return tf.reshape(y, out_shape) + tf.cast(b, x.dtype)

     return tf.complex(apply(tf.math.real(p)), apply(tf.math.imag(p)))
Esempio n. 4
0
 def get_bin_index(self, m):
     """Return, for each value in ``m``, the index of the knot interval holding it.

     With knots ``x_0 .. x_{N-1}`` from ``self.points`` (``N = self.interp_N``)
     and assuming they are increasing, the result is ``j`` with
     ``x_j <= m < x_{j+1}``. Values below ``x_0`` map to ``-1`` and values
     ``>= x_{N-1}`` map to ``N - 1``.

     When ``self.fix_width`` is set, only ``points[0]`` and ``points[-1]`` are
     read, and the knots are treated as evenly spaced between them.

     Returns:
         int64 tensor shaped like ``m``, with gradients stopped.
     """
     if self.fix_width:
         m_min = tf.convert_to_tensor(self.points[0], m.dtype)
         m_max = tf.convert_to_tensor(self.points[-1], m.dtype)
         delta_width = (m_max - m_min) / (self.interp_N - 1)
         # N + 1 bins of width delta_width starting one bin below m_min.
         # Out-of-range values are clipped by the op into the first/last bin,
         # which yields -1 / N - 1 after the shift below.
         bin_idx = tf.histogram_fixed_width_bins(
             m,
             [m_min - delta_width, m_max + delta_width],
             nbins=self.interp_N + 1,
             dtype=tf.dtypes.int64,
         )
     else:
         # np.digitize gives i with points[i-1] <= m < points[i] for increasing knots.
         bin_idx = tf.numpy_function(np.digitize, [m, self.points], tf.int64)
         # numpy_function drops static shape info; digitize preserves the input shape.
         bin_idx.set_shape(m.shape)
     # Shift so that index j means points[j] <= m < points[j + 1].
     bin_idx = bin_idx - 1
     return tf.stop_gradient(bin_idx)
Esempio n. 5
0
    def interp(self, m):
        """Evaluate a nearest-knot (piecewise-constant) interpolation at ``m``.

        Interior knot ``self.points[i + 1]`` (``i = 0 .. interp_N - 3``) owns
        the interval ``(mid_i, mid_{i+1}]``. That interval lies between the
        midpoints to its two neighbours, and inside it the result equals
        that knot's point value.

        Real and imaginary parts of ``self.point_value()`` are handled
        separately and recombined into a complex tensor.
        """
        p = self.point_value()
        one = tf.ones_like(m)
        nil = tf.zeros_like(m)
        pts = self.points

        # mids[k] is the midpoint between knots k and k + 1.
        mids = [(pts[k] + pts[k + 1]) / 2 for k in range(self.interp_N - 1)]

        # Indicator of (lo, hi] for every pair of adjacent midpoints.
        masks = []
        for lo, hi in zip(mids[:-1], mids[1:]):
            masks.append(tf.where((m > lo) & (m <= hi), one, nil))
        selector = tf.stop_gradient(tf.stack(masks, axis=-1))

        real_part = tf.reduce_sum(selector * tf.math.real(p), axis=-1)
        imag_part = tf.reduce_sum(selector * tf.math.imag(p), axis=-1)
        return tf.complex(real_part, imag_part)