Ejemplo n.º 1
0
 def body(i, rTr, x, r, p):
     """Run one conjugate-gradient iteration (shaped as a loop-body callable).

     Uses the operator ``M`` from the enclosing scope (not visible here) and
     the complex helpers from ``dlmri_tutorial``.

     Args:
         i: iteration counter; returned incremented by one.
         rTr: squared residual norm carried from the previous iteration
             (presumably the ``rTrNew`` value this function returned last time).
         x: current solution estimate.
         r: current residual.
         p: current search direction.

     Returns:
         Tuple ``(i + 1, rTrNew, x, r, p)`` holding the updated state, where
         ``rTrNew`` is the real part of ``complex_dot(r, r)`` for the new residual.
     """
     with tf.name_scope('cgBody'):
         Mp = M(p)
         # Curvature along p; assumes complex_dot(p, Mp) is p^H M p -- TODO confirm.
         curvature = tf.math.real(dlmri_tutorial.complex_dot(p, Mp))
         step = rTr / curvature
         x = x + dlmri_tutorial.complex_scale(p, step)
         r = r - dlmri_tutorial.complex_scale(Mp, step)
         rTr_next = tf.math.real(dlmri_tutorial.complex_dot(r, r))
         # Next direction: new residual plus p weighted by the residual-norm ratio.
         p = r + dlmri_tutorial.complex_scale(p, rTr_next / rTr)
     return i + 1, rTr_next, x, r, p
Ejemplo n.º 2
0
 def call(self, inputs, scale=1.0):
     """Take one gradient-descent step on the data-consistency term.

     Computes ``x - complex_scale(AH(A(x, *c) - y, *c), self.weight * scale)``.

     Args:
         inputs: sequence ``(x, y, *c)``. Entries after the first two are
             forwarded unchanged to ``self.A`` and ``self.AH``.
         scale: extra factor multiplied into ``self.weight``.

     Returns:
         The updated estimate of ``x``.
     """
     x, y = inputs[0], inputs[1]
     constants = inputs[2:]
     residual = self.A(x, *constants) - y
     gradient = self.AH(residual, *constants)
     return x - dlmri_tutorial.complex_scale(gradient, self.weight * scale)
Ejemplo n.º 3
0
 def call(self, kspace, *args):
     """Shifted inverse 2-D FFT, multiplied by sqrt of the last-two-dims size.

     Applies ``ifftshift`` -> ``ifft2d`` -> ``fftshift``, with both shifts
     over the last two axes of ``kspace``. The result is multiplied by
     ``sqrt(shape[-2] * shape[-1])``. Any extra positional ``args`` are ignored.

     NOTE(review): presumably combined with ifft2d's own 1/N normalisation
     this yields unitary scaling -- confirm against the fft helpers in use.
     """
     rank = tf.rank(kspace)
     # The shift helpers need non-negative axis indices, so derive them from rank.
     axes = [rank - 2, rank - 1]
     real_dtype = tf.math.real(kspace).dtype
     n_pixels = tf.math.reduce_prod(tf.shape(kspace)[-2:])
     scale = tf.math.sqrt(tf.cast(n_pixels, real_dtype))
     image = fftshift(ifft2d(ifftshift(kspace, axes=axes)), axes=axes)
     return dlmri_tutorial.complex_scale(image, scale)
Ejemplo n.º 4
0
        def fn(inputs):
            """Run the regularised CG solve for one element of ``inputs``.

            Builds ``rhs = AH(y, *c) + lambdaa * x`` and passes it, together
            with the operator ``M(p) = AH(A(p, *c), *c) + lambdaa * p``, to
            ``cg``. Presumably ``cg`` solves ``M(out) = rhs``; its definition
            is outside this view.

            ``self`` and ``lambdaa`` come from the enclosing scope.

            Returns:
                Tuple ``(out, rhs)``.
            """
            x, y = inputs[0], inputs[1]
            constants = inputs[2:]

            def M(p):
                # Regularised normal operator: AH(A(p)) + lambdaa * p.
                normal = self.AH(self.A(p, *constants), *constants)
                return normal + dlmri_tutorial.complex_scale(p, lambdaa)

            rhs = (self.AH(y, *constants)
                   + dlmri_tutorial.complex_scale(x, lambdaa))
            out = cg(M, rhs, self.max_iter, self.tol)
            return out, rhs
Ejemplo n.º 5
0
        def grad(e):
            """Backward pass for the regularised CG solve.

            Reads ``x``, ``rhs``, ``lambdaa``, ``constants`` and ``self`` from
            the enclosing scope, which is not visible here. For each element
            mapped by ``tf.map_fn`` it computes:

            * ``Qe = cg(M, e)``
            * ``QQe = cg(M, Qe)``

            where ``M(p) = AH(A(p)) + lambdaa * p``. It then forms:

            * ``dx = lambdaa * Qe``
            * ``dlambdaa = Re(sum(complex_dot(Qe, x)) - sum(complex_dot(QQe, rhs)))``

            The dot products reduce over axes ``1..rank(x)-1``.

            Args:
                e: upstream gradient, mapped over its leading axis.

            Returns:
                ``[dlambdaa, dx, None]`` followed by one ``None`` per entry of
                ``constants``. NOTE(review): presumably this matches the forward
                inputs ``(lambdaa, x, y, *constants)`` -- confirm.
            """
            def per_sample(inputs):
                e_i = inputs[0]
                sample_constants = inputs[1:]

                def M(p):
                    normal = self.AH(self.A(p, *sample_constants),
                                     *sample_constants)
                    return normal + dlmri_tutorial.complex_scale(p, lambdaa)

                Qe_i = cg(M, e_i, self.max_iter, self.tol)
                return Qe_i, cg(M, Qe_i, self.max_iter, self.tol)

            Qe, QQe = tf.map_fn(
                per_sample,
                (e, *constants),
                fn_output_signature=(x.dtype, x.dtype),
                name='mapFnGrad',
                parallel_iterations=self.parallel_iterations)

            # Reduce over every axis except the leading one.
            reduce_axes = tf.range(1, tf.rank(x))
            Qe_dot_x = tf.reduce_sum(
                dlmri_tutorial.complex_dot(Qe, x, axis=reduce_axes))
            QQe_dot_rhs = tf.reduce_sum(
                dlmri_tutorial.complex_dot(QQe, rhs, axis=reduce_axes))
            dlambdaa = tf.math.real(Qe_dot_x - QQe_dot_rhs)
            dx = dlmri_tutorial.complex_scale(Qe, lambdaa)
            return [dlambdaa, dx, None, *(None for _ in constants)]
Ejemplo n.º 6
0
 def M(p):
     """Apply the regularised normal operator ``AH(A(p)) + lambdaa * p``.

     ``self``, ``constants`` and ``lambdaa`` are read from the enclosing
     scope, which is not visible here.
     """
     normal = self.AH(self.A(p, *constants), *constants)
     return normal + dlmri_tutorial.complex_scale(p, lambdaa)
Ejemplo n.º 7
0
 def call(self, image, *args):
     """Forward 2-D FFT divided by sqrt of the last-two-dims size.

     Returns ``fft2d(image) / sqrt(shape[-2] * shape[-1])``, applied via
     ``dlmri_tutorial.complex_scale``. Any extra positional ``args`` are
     ignored. NOTE(review): presumably ``fft2d`` transforms the last two
     axes -- confirm.
     """
     real_dtype = tf.math.real(image).dtype
     n_pixels = tf.math.reduce_prod(tf.shape(image)[-2:])
     norm = tf.math.sqrt(tf.cast(n_pixels, real_dtype))
     spectrum = fft2d(image)
     return dlmri_tutorial.complex_scale(spectrum, 1 / norm)
Ejemplo n.º 8
0
 def call(self, kspace, *args):
     """Inverse 2-D FFT multiplied by sqrt of the last-two-dims size.

     Returns ``ifft2d(kspace) * sqrt(shape[-2] * shape[-1])``, applied via
     ``dlmri_tutorial.complex_scale``. Unlike the shifted variant, no
     fftshift/ifftshift is applied. Any extra positional ``args`` are ignored.
     """
     real_dtype = tf.math.real(kspace).dtype
     n_pixels = tf.math.reduce_prod(tf.shape(kspace)[-2:])
     norm = tf.math.sqrt(tf.cast(n_pixels, real_dtype))
     image = ifft2d(kspace)
     return dlmri_tutorial.complex_scale(image, norm)
Ejemplo n.º 9
0
 def call(self, kspace, mask):
     """Apply the sampling ``mask`` to ``kspace``.

     Args:
         kspace: complex k-space data.
         mask: factor passed to ``dlmri_tutorial.complex_scale``. Presumably
             a real-valued sampling mask -- confirm with callers.

     Returns:
         ``complex_scale(kspace, mask)``.
     """
     masked = dlmri_tutorial.complex_scale(kspace, mask)
     return masked