예제 #1
0
    def _inference(self, x, dropout):
        """Build the forward pass: a Fourier-domain filter layer and a dense layer.

        Args:
            x: input batch; it is reshaped to ``[-1, 28, 28]``, so each sample
                must hold 28 * 28 values (NFEATURES is presumably 28 * 28 --
                confirm against the module constant).
            dropout: accepted for interface compatibility; not used here.

        Returns:
            Output of the final linear layer (no activation applied), shaped
            NSAMPLES x NCLASSES.
        """
        # Only the first half of the spectrum is filtered explicitly.
        half = int(NFEATURES / 2)

        with tf.name_scope('conv1'):
            # Spatial domain -> Fourier domain.
            image = tf.reshape(x, [-1, 28, 28])
            spectrum_2d = tf.fft2d(tf.complex(image, 0))
            spectrum = tf.reshape(spectrum_2d, [-1, NFEATURES])
            spectrum = tf.expand_dims(spectrum, 1)  # NSAMPLES x 1 x NFEATURES
            spectrum = tf.transpose(spectrum)  # NFEATURES x 1 x NSAMPLES

            # One complex weight per (frequency, filter) pair.
            w_re = self._weight_variable([half, self.F, 1])
            w_im = self._weight_variable([half, self.F, 1])
            filters = tf.complex(w_re, w_im)

            # Batched matmul: one (F x 1) @ (1 x NSAMPLES) product per frequency.
            spectrum = spectrum[:half, :, :]
            filtered = tf.matmul(filters, spectrum)
            # The second half of the spectrum is filled with the conjugate of
            # the first half.
            filtered = tf.concat([filtered, tf.conj(filtered)], axis=0)
            filtered = tf.transpose(filtered)  # NSAMPLES x NFILTERS x NFEATURES
            filtered_2d = tf.reshape(filtered, [-1, 28, 28])

            # Fourier domain -> spatial domain; keep the real part only.
            spatial = tf.real(tf.ifft2d(filtered_2d))
            y = tf.reshape(spatial, [-1, self.F, NFEATURES])

            # One bias per filter, broadcast over samples and features.
            bias = self._bias_variable([1, self.F, 1])
            y = y + bias  # NSAMPLES x NFILTERS x NFEATURES
            y = tf.nn.relu(y)

        with tf.name_scope('fc1'):
            flat_dim = self.F * NFEATURES
            dense_w = self._weight_variable([flat_dim, NCLASSES])
            dense_b = self._bias_variable([NCLASSES])
            y = tf.reshape(y, [-1, flat_dim])
            y = tf.matmul(y, dense_w) + dense_b

        return y
def _squared_frobenius_norm(matrices):
    """Sum ``|entry|**2`` over axes 2 and 3 of ``matrices``."""
    return tf.reduce_sum(tf.square(tf.abs(matrices)), [2, 3])


def _two_grid_iteration_matrix(P_stencil, n, grid_size, A_matrices,
                               S_matrices):
    """Build ``(I - P (P^H A P)^-1 P^H A) S S`` with one frequency removed.

    ``P`` comes from ``utils.compute_p2LFA(P_stencil, n, grid_size)``. All
    operands are reshaped to ``(-1, num_theta**2, rows, cols)`` and the entry
    at ``index_to_remove`` along axis 1 is dropped before the solve.
    """
    half = n // (2 * grid_size)
    # Number of sampled frequencies per direction. Only the count is needed,
    # so the frequency values themselves are never materialised.
    num_theta = len(range(-n // (grid_size * 2) + 1, half + 1))
    num_freq = num_theta ** 2
    # When n is a multiple of 2 * grid_size this is the flattened position of
    # the zero frequency (i == 0 in both directions).
    # NOTE(review): presumably dropped because the coarse operator is singular
    # there -- confirm.
    index_to_remove = num_theta * (half - 1) + half - 1

    fine = grid_size ** 2
    coarse = (grid_size // 2) ** 2

    def flatten_and_drop(tensor, rows, cols):
        # Collapse the frequency axes into axis 1, then cut out one entry.
        flat = tf.reshape(tensor, (-1, num_freq, rows, cols))
        return tf.concat(
            [flat[:, :index_to_remove], flat[:, index_to_remove + 1:]], 1)

    P_matrix = utils.compute_p2LFA(P_stencil, n, grid_size)
    P_matrix = tf.transpose(P_matrix, [2, 0, 1, 3, 4])
    # Conjugate transpose of the two trailing (matrix) axes.
    P_matrix_t = tf.transpose(P_matrix, [0, 1, 2, 4, 3], conjugate=True)
    # Coarse operator P^H A P.
    A_c = tf.matmul(tf.matmul(P_matrix_t, A_matrices), P_matrix)

    A_c_removed = flatten_and_drop(A_c, coarse, coarse)
    P_matrix_t_removed = flatten_and_drop(P_matrix_t, coarse, fine)
    P_matrix_removed = flatten_and_drop(P_matrix, fine, coarse)
    A_matrices_removed = flatten_and_drop(A_matrices, fine, fine)
    S_removed = flatten_and_drop(S_matrices, fine, fine)

    # Solve A_c X = P^H rather than forming the inverse explicitly.
    A_coarse_inv_removed = tf.matrix_solve(A_c_removed, P_matrix_t_removed)
    CGC_removed = tf.eye(fine, dtype=tf.complex128) - tf.matmul(
        tf.matmul(P_matrix_removed, A_coarse_inv_removed),
        A_matrices_removed)
    return tf.matmul(tf.matmul(CGC_removed, S_removed), S_removed)


def loss(model,
         n,
         A_stencil,
         A_matrices,
         S_matrices,
         index=None,
         pos=-1.,
         phase="Training",
         epoch=-1,
         grid_size=8,
         remove=True):
    """Compute the loss of the prolongation stencil predicted by ``model``.

    The predicted stencil is turned into iteration matrices
    ``(I - P (P^H A P)^-1 P^H A) S S`` (one frequency is dropped), and the
    loss is built from their squared Frobenius norms. Calls ``.numpy()`` on
    the result, so it requires eager execution.

    Args:
        model: callable returning the stencil. Called as
            ``model(A_stencil, True)`` when ``phase == "Test"`` and
            ``epoch == 0``; otherwise with ``phase=phase`` (plus ``index`` and
            ``pos`` when ``index`` is given).
        n: integer grid parameter; with ``grid_size`` it fixes the number of
            frequencies.
        A_stencil: input passed to ``model``.
        A_matrices: system matrices; conjugated before use. Presumably
            complex128, since they are combined with a complex128 identity --
            confirm.
        S_matrices: matrices applied twice on the right of the correction
            term; conjugated before use.
        index: optional value forwarded to ``model``.
        pos: forwarded to ``model`` only when ``index`` is not ``None``.
        phase: phase label forwarded to ``model``.
        epoch: epoch number; only compared against ``0`` in the Test phase.
        grid_size: block size; matrices are ``grid_size**2`` square on the
            fine level and ``(grid_size // 2)**2`` on the coarse level.
        remove: if true, take the maximum over frequencies before averaging;
            otherwise take the mean and print the value.

    Returns:
        Tuple ``(loss_tensor, loss_value)``. In the epoch-0 Test branch the
        tensor is ``tf.constant([0.])`` and the value is the mean-based loss
        as a NumPy value.
    """
    with tf.device(DEVICE):
        A_matrices = tf.conj(A_matrices)
        S_matrices = tf.conj(S_matrices)

        if phase == "Test" and epoch == 0:
            P_stencil = model(A_stencil, True)
            iteration_matrix = _two_grid_iteration_matrix(
                P_stencil, n, grid_size, A_matrices, S_matrices)
            loss_test = tf.reduce_mean(
                tf.reduce_mean(_squared_frobenius_norm(iteration_matrix), 1))
            return tf.constant([0.]), loss_test.numpy()

        if index is not None:
            P_stencil = model(A_stencil, index=index, pos=pos, phase=phase)
        else:
            P_stencil = model(A_stencil, phase=phase)

        iteration_matrix = _two_grid_iteration_matrix(
            P_stencil, n, grid_size, A_matrices, S_matrices)
        per_frequency = _squared_frobenius_norm(iteration_matrix)

        if remove:
            # Worst frequency per batch entry, averaged over the batch.
            loss_value = tf.reduce_mean(tf.reduce_max(per_frequency, 1))
        else:
            loss_value = tf.reduce_mean(tf.reduce_mean(per_frequency, 1))
            print("Real loss: ", loss_value.numpy())

        return loss_value, loss_value.numpy()
예제 #3
0
 def _matmul_sparse(self, x, adjoint=False, adjoint_arg=False):
     """Multiply the stored diagonal with ``x`` via ``utils.matmul_diag_sparse``.

     Args:
         x: right-hand operand handed to ``utils.matmul_diag_sparse``.
         adjoint: if true, the complex conjugate of the diagonal is used.
         adjoint_arg: must be false; taking the adjoint of ``x`` is not
             supported by this method.

     Returns:
         The result of ``utils.matmul_diag_sparse``.
     """
     if adjoint:
         # A diagonal operator's adjoint is its element-wise conjugate.
         diag_mat = tf.conj(self._diag)
     else:
         diag_mat = self._diag
     assert not adjoint_arg
     return utils.matmul_diag_sparse(diag_mat, x)
예제 #4
0
 def _ccorr(self, a, b):
     """Return ``real(ifft(conj(fft(a)) * fft(b)))``.

     This is the FFT form of the circular cross-correlation of ``a`` and
     ``b``. Both inputs are cast to ``complex64`` first.
     """
     a_complex = tf.cast(a, tf.complex64)
     b_complex = tf.cast(b, tf.complex64)
     a_spectrum_conj = tf.conj(tf.fft(a_complex))
     b_spectrum = tf.fft(b_complex)
     correlation = tf.ifft(a_spectrum_conj * b_spectrum)
     return tf.real(correlation)
예제 #5
0
 def _matmul_right(self, x, adjoint=False, adjoint_arg=False):
     """Multiply the stored diagonal element-wise with ``x``.

     Args:
         x: operand multiplied by the diagonal.
         adjoint: if true, the complex conjugate of the diagonal is used.
         adjoint_arg: if true, ``x`` is replaced by ``linalg.adjoint(x)``
             before the product.

     Returns:
         The product ``diag * x`` (with the requested adjoints applied).
     """
     if adjoint:
         diag_mat = tf.conj(self._diag)
     else:
         diag_mat = self._diag
     if adjoint_arg:
         x = linalg.adjoint(x)
     return diag_mat * x