Esempio n. 1
0
                          name="tau",
                          constraint=lambda t: tf.clip_by_value(t, 0, 1))
        # NOTE(review): `weights` is not read anywhere in the code shown here;
        # confirm it is used later or remove it.
        weights = tf.zeros_like(x_true)
        # Trainable float32 scalar initialised to 0.99. Like tau (whose
        # definition ends just above), its `constraint` callable clips the
        # value into [0, 1]. Per the tf.Variable documentation, a constraint
        # is applied by the optimizer after each update, not when the
        # variable is read.
        gamma = tf.Variable(0.99,
                            dtype="float32",
                            name="gamma",
                            constraint=lambda t: tf.clip_by_value(t, 0, 1))

    # Unrolled ADMM-style iterations: each pass refines x, then z, with small
    # learned conv nets. Each iteration gets its own variable scope
    # ('admm_iterations_<i>'), so variables created inside (presumably by
    # apply_conv / prelu) are distinct per iteration. reuse=tf.AUTO_REUSE
    # creates them on first use and reuses them if the scope is re-entered.
    for i in range(n_iter):
        # convolution layers
        with tf.variable_scope('admm_iterations_{}'.format(i),
                               reuse=tf.AUTO_REUSE):
            # first proximal layer
            # Step on x: x - tau * A^T(m / tau + A x - z), assuming
            # odl_op_layer / odl_op_layer_adjoint are a forward operator A and
            # its adjoint (TODO confirm).
            # NOTE(review): if tau is clipped to [0, 1] as its constraint
            # suggests, tau == 0 would make `m / tau` divide by zero.
            update2 = x - odl_op_layer_adjoint(m / tau + odl_op_layer(x) -
                                               z) * tau
            # Two PReLU conv layers, then a final conv with filters=1
            # (presumably a single output channel), produce the new x.
            update2 = prelu(apply_conv(update2), name='prelu_1')
            update2 = prelu(apply_conv(update2), name='prelu_2')
            update2 = apply_conv(update2, filters=1)
            x = update2

            # second proximal layer
            # sigma * (A x + m / sigma) equals sigma * A x + m (and, like
            # m / tau above, divides by sigma). It is concatenated with y_rt
            # along the last axis (presumably channels) before the conv net
            # that produces the new z.

            update = tf.concat([sigma * (odl_op_layer(x) + m / sigma), y_rt],
                               axis=-1)
            update = prelu(apply_conv(update), name='prelu_3')
            update = prelu(apply_conv(update), name='prelu_4')
            update = apply_conv(update, filters=1)
            z = update

            # dual update
# NOTE(review): primal_values / dual_values are never appended to or read in
# the code shown here; presumably meant to collect per-iteration iterates.
primal_values = []
dual_values = []

with tf.name_scope('tomography'):
    with tf.name_scope('initial_values'):
        # Zero-initialise the primal stack (n_primal copies of x_true's shape)
        # and the dual stack (n_dual copies of y_rt's shape), concatenated
        # along the last axis.
        primal = tf.concat([tf.zeros_like(x_true)] * n_primal, axis=-1)
        dual = tf.concat([tf.zeros_like(y_rt)] * n_dual, axis=-1)

    for i in range(n_iter):
        with tf.variable_scope('dual_iterate_{}'.format(i)):
            # Second primal component (slice 1:2 of the last axis) pushed
            # through exp(-mu_water * A x), floored at exp(-10) so it never
            # reaches 0. Assumes odl_op_layer is a forward operator A
            # (TODO confirm).
            evalpt = primal[..., 1:2]
            evalop = tf.maximum(tf.exp(-mu_water * odl_op_layer(evalpt)),
                                tf.exp(-10.0))
            update = tf.concat([dual, evalop, y_rt], axis=-1)

            # Learned additive (residual) update of the dual stack.
            update = prelu(apply_conv(update), name='prelu_1')
            update = prelu(apply_conv(update), name='prelu_2')
            update = apply_conv(update, filters=n_dual)
            dual = dual + update

        with tf.variable_scope('primal_iterate_{}'.format(i)):
            # -mu_water * exp(-mu_water * A x) at the first primal component:
            # the derivative of x -> exp(-mu_water * A x).
            # NOTE(review): unlike the dual step above, this exp is not
            # floored; confirm that is intended.
            evalpt_fwd = primal[..., 0:1]
            evalop_fwd = (-mu_water) * tf.exp(
                -mu_water * odl_op_layer(evalpt_fwd))

            # Adjoint of that derivative applied to the first dual component
            # (assuming odl_op_layer_adjoint is the adjoint of odl_op_layer).
            # NOTE(review): `evalpt` is assigned but the next line re-slices
            # dual[..., 0:1] instead of using it.
            evalpt = dual[..., 0:1]
            evalop = odl_op_layer_adjoint(evalop_fwd * dual[..., 0:1])
            update = tf.concat([primal, evalop], axis=-1)

            update = prelu(apply_conv(update), name='prelu_1')
            update = prelu(apply_conv(update), name='prelu_2')
        # Zero-initialise the primal stack (n_primal copies of x_true's shape)
        # and the dual stack (n_dual copies of y_rt's shape), concatenated
        # along the last axis. A filtered-backprojection start,
        # tf.concat([x_fbp] * n_primal, axis=3), is a possible alternative.
        # tf.zeros(tf.shape(...)) is used in place of tf.zeros_like(...) on
        # the assumption that it is faster.
        # NOTE(review): that speed claim is unverified; benchmark before
        # relying on it.
        # NOTE(review): at this indentation these lines sit inside the
        # preceding `for i in range(n_iter):` body, so primal and dual are
        # re-zeroed on every iteration. Confirm that is intended.
        primal = tf.concat([tf.zeros(tf.shape(x_true))] * n_primal, axis=-1)
        dual = tf.concat([tf.zeros(tf.shape(y_rt))] * n_dual, axis=-1)
        # tf.shape() returns a symbolic int32 Tensor, so printing it shows the
        # op rather than the sizes, and adds a Shape op to the graph each
        # time. Print the statically known shapes instead (unknown
        # dimensions appear as '?').
        print(primal.shape, dual.shape)

    # One dual update followed by one primal update per iteration, each with
    # its own variable scope named by the iteration index.
    for i in range(n_iter):
        with tf.variable_scope('dual_iterate_{}'.format(i)):
            evalpt = primal[..., 1:2]
            # Floor exp(-mu_water * A x) at exp(-15) so it never underflows
            # toward 0 when mu_water * A x is large. tf.maximum imposes a
            # lower bound here, not an overflow guard. Assumes odl_op_layer is
            # a forward operator A (TODO confirm).
            evalop = tf.maximum(tf.exp(-mu_water * odl_op_layer(evalpt)), tf.exp(-15.0))
            update = tf.concat([dual, evalop, y_rt], axis=-1)
            update = prelu(apply_conv(update), name='prelu_1')
            update = prelu(apply_conv(update), name='prelu_2')
            update = apply_conv(update, filters=n_dual)
            # Learned additive (residual) update of the dual stack.
            dual = dual + update

        with tf.variable_scope('primal_iterate_{}'.format(i)):
            evalpt_fwd = primal[..., 0:1]
            # -mu_water * exp(-mu_water * A x), i.e. the derivative of
            # x -> exp(-mu_water * A x), with the same exp(-15) floor as the
            # dual step.
            evalop_fwd = (-mu_water) * tf.maximum(tf.exp(-mu_water * odl_op_layer(evalpt_fwd)), tf.exp(-15.0))

            # Adjoint of that derivative applied to the first dual component
            # (assuming odl_op_layer_adjoint is the adjoint of odl_op_layer),
            # stacked with the current primal as input to the primal net.
            evalpt = dual[..., 0:1]
            evalop = odl_op_layer_adjoint(evalop_fwd * evalpt)
            update = tf.concat([primal, evalop], axis=-1)