#                                        momentum=0.9,use_nesterov=True)
        # train = opt_func.minimize(loss)


        tvars = tf.trainable_variables()
        grads, _ = tf.clip_by_global_norm(tf.gradients(loss, tvars), 1)
        optimizer = opt_func.apply_gradients(zip(grads, tvars),
                                             global_step=global_step)


# Summaries
# tensorboard --logdir=...

with tf.name_scope('summaries'):
    # Scalar summaries: the training loss and the PSNR of the reconstruction
    # against the ground truth.
    tf.summary.scalar('loss', loss)
    tf.summary.scalar('psnr', psnr(x_result, x_true, 'compute_psnr'))

    # Image summaries, registered in the same order as the tags listed here.
    for _summary_tag, _summary_image in (('x_result', x_result),
                                         ('x_true', x_true),
                                         ('squared_error', squared_error),
                                         ('residual', residual)):
        tf.summary.image(_summary_tag, _summary_image)

    merged_summary = tf.summary.merge_all()

    # Two event directories ('test' and 'train') under this run's TensorBoard
    # directory; only the test writer also records the session graph.
    test_summary_writer = tf.summary.FileWriter(
        adler.tensorflow.util.default_tensorboard_dir(name) + '/test',
        sess.graph)
    train_summary_writer = tf.summary.FileWriter(
        adler.tensorflow.util.default_tensorboard_dir(name) + '/train')

# Initialize all TF variables
sess.run(tf.global_variables_initializer())

# Add op to save and restore
    [primal_values, dual_values],
    feed_dict={
        x_true: x_true_arr_validate,
        y_rt: y_arr_validate,
        is_training: False
    })

import matplotlib.pyplot as plt
from skimage.measure import compare_ssim as ssim
from skimage.measure import compare_psnr as psnr

# Quality of the final primal iterate on the first validation sample.
# Both metrics use the ground truth's value range as ``data_range``. Without
# it, ``compare_ssim`` falls back to a range implied by the array dtype, which
# for float images generally differs from the data's actual range. The old
# ``dynamic_range`` keyword of ``compare_psnr`` is deprecated in favour of
# ``data_range``.
x_final_validate = primal_values_result[-1][0, ..., 0]
x_true_validate_0 = x_true_arr_validate[0, ..., 0]
validate_data_range = np.max(x_true_arr_validate) - np.min(x_true_arr_validate)

print(ssim(x_final_validate, x_true_validate_0,
           data_range=validate_data_range))
# compare_psnr expects (im_true, im_test); with an explicit data range the
# PSNR value does not depend on the order, but follow the documented API.
print(psnr(x_true_validate_0, x_final_validate,
           data_range=validate_data_range))


def normalized(val, sign=False):
    """Return ``val`` standardized to zero mean and unit standard deviation.

    Parameters
    ----------
    val : array-like
        Values to standardize; anything accepted by ``np.mean``/``np.std``
        and the arithmetic operators.
    sign : bool, optional
        If True, first multiply ``val`` by the sign of its mean, so that the
        data being standardized has a non-negative mean. A mean of exactly
        zero leaves ``val`` unchanged.

    Returns
    -------
    The standardized values. If ``val`` has zero standard deviation
    (constant input), the centered values (all zeros) are returned instead
    of dividing by zero.
    """
    if sign:
        mean_sign = np.sign(np.mean(val))
        # np.sign(0) == 0: multiplying by it would zero the data and make the
        # result all NaN, so a zero-mean input is left as-is.
        if mean_sign != 0:
            val = val * mean_sign
    centered = val - np.mean(val)
    std = np.std(val)
    if std == 0:
        # Constant input: avoid 0/0 -> NaN; the centered values are zeros.
        return centered
    return centered / std


# Save two figures per primal iterate into the directory ``name``: one with
# the default color limits and one windowed to clim=[0.8, 1.2].
path = name
for i in range(n_iter):
    vals = primal_values_result[i]
    # NOTE(review): ``[..., 0]`` keeps any leading batch axis, whereas the
    # metric printout indexes ``[0, ..., 0]``; this presumably relies on a
    # validation batch of size 1 — confirm.
    x_elem = space.element(vals[..., 0])
    plain_file = '{}/x_{}'.format(path, i)
    windowed_file = '{}/x_windowed_{}'.format(path, i)
    x_elem.show(saveto=plain_file)
    x_elem.show(clim=[0.8, 1.2], saveto=windowed_file)