# Example #1
# 0
def exp_pxy():
  """Toy experiment: regress coefficients (Px, Py) from point/image pairs.

  Draws n = 1000 random coefficient pairs P and random integer points, maps
  each point through ``trans`` and trains a 4 -> 1024 -> 1024 -> 2 dense
  Lasagne network with Adam (learning rate 1e-4) to predict P * 1e4 from the
  row ``[x, y, x', y']``. Prints ``(epoch, mean minibatch loss)`` for each of
  100 epochs and returns None.

  NOTE(review): results depend on the exact order of global NumPy RNG draws
  (seed, then P, then the points, then two `randn` calls per point, then the
  permutation) and on the order of the `theano.config.floatX` assignments,
  so reorder statements with care.
  """
  def trans(x, y, Px, Py):
    # Map (x, y) -> (x / p, y / p) with p = 1 + Px*x + Py*y.
    p = 1 + Px * x + Py * y
    # The `np.random.randn() * 0` terms contribute zero noise but still
    # advance the global RNG (two draws per call).
    return [float(x / (p + np.random.randn() * 0)), float(y / (p + np.random.randn() * 0))]

  # Float type used for the generated data and the Theano graph below.
  dtype = 'float32'
  def sampling(w, h, P):
    # Draw one random integer point (x, y) in [0, 256) per row of P; each
    # output row is [x, y, x', y'] with (x', y') = trans(x, y, Px, Py).
    # NOTE(review): `w` and `h` are unused -- the 256 bound is hard-coded.
    X = np.random.randint(0, 256, size = P.shape)
    return np.array([[x[0], x[1]] + trans(x[0], x[1], p[0], p[1]) for x, p in zip(X, P)], dtype=dtype)

  np.random.seed(2012310818)
  n = 1000
  # Coefficients uniform in [0, 1e-3); for coordinates below 256 this keeps
  # p within [1, ~1.51).
  P = np.random.rand(n, 2).astype(dtype) * 1e-3
  X = sampling(256, 256, P)
  # Rescale the regression targets to roughly [0, 10) after the samples were
  # generated from the unscaled coefficients.
  P *= 10000

  network = OrderedDict()

  # Symbolic network inputs (rows of X) and targets (rows of P), typed to
  # match `dtype`.
  if dtype == 'float64':
    input_var = T.dmatrix('input')
    psp_var = T.dmatrix('psp')
  else:
    input_var = T.fmatrix('input')
    psp_var = T.fmatrix('psp')
    
  # Build the graph while floatX == dtype (presumably so layer parameters are
  # created with that type); it is reset to 'float32' after compilation.
  # NOTE(review): floatX is not restored if an exception occurs in between.
  theano.config.floatX = dtype
  network['input'] = lasagne.layers.InputLayer(shape = (None, 4), input_var = input_var)
  network['fc1'] = lasagne.layers.DenseLayer(network['input'], num_units=1024,
                                             nonlinearity=lasagne.nonlinearities.rectify)
  network['fc2'] = lasagne.layers.DenseLayer(network['fc1'], num_units=1024,
                                             nonlinearity=lasagne.nonlinearities.rectify)
  # ReLU on the 2-unit output layer; the rescaled targets are non-negative.
  network['fc3'] = lasagne.layers.DenseLayer(network['fc2'], num_units=2,
                                             nonlinearity=lasagne.nonlinearities.rectify)

  pred = lasagne.layers.get_output(network['fc3'])
  # Mean squared error between targets and predictions.
  loss = lasagne.objectives.squared_error(psp_var, pred).mean()
  paras = lasagne.layers.get_all_params(network['fc3'])
  # Adam with learning rate 1e-4.
  updates = lasagne.updates.adam(loss, paras, 0.0001)
  # ftrain(inputs, targets) -> [loss, pred]; applies one Adam step per call.
  ftrain = theano.function([input_var, psp_var], [loss, pred], updates = updates)

  theano.config.floatX = 'float32'

  def iterate_minibatch(idx, batchsize, l):
    # Yield (inputs, targets) for consecutive slices of `idx` of length
    # `batchsize`; the final batch may be shorter.
    for start in range(0, l, batchsize):
      yidx = idx[start : start + batchsize]
      yield (X[yidx], P[yidx])
      
  # A single shuffle of the sample indices, reused for every epoch.
  idx = np.random.permutation(np.arange(0, n))
  for i in range(100):
    lval = 0
    cnt = 0
    for batch in iterate_minibatch(idx, 32, n):
      # Element [0] of ftrain's output is the batch loss.
      lval += ftrain(batch[0], batch[1])[0]
      cnt += 1
    print(i, lval / cnt)
# Example #2
# 0
def exp_raw(dtype):
  """Train an image -> 8-parameter perspective regressor on real patches.

  Builds a Lasagne network (six 3x3 conv layers interleaved with four 2x2
  max-pools and batch normalization, then three 1024-unit dense layers with
  dropout and an 8-unit output), compiles an RMSE training function using the
  file's ``adam`` helper (learning rate 1e-4 per parameter), and trains it for
  10 epochs on ``dat['train']`` loaded from a hard-coded data directory.

  Args:
    dtype: float type name assigned to ``theano.config.floatX`` while the
      dense head is constructed; floatX is reset to 'float32' afterwards.

  Returns:
    The compiled ``ftrain(images, targets) -> [loss, predictions]`` function,
    which applies one update step per call.
  """
  shp = (None, 3, 256, 256)
  input_var = T.tensor4('input_var', dtype = 'float32')
  # Targets are a float64 matrix; get_inputs builds them as float64 to match.
  psp = T.dmatrix("psp")
  network = OrderedDict()
  network['input'] = lasagne.layers.InputLayer(shape = shp, input_var = input_var)
  # network = make_vgg16(network, 'model/vgg16_weights_from_caffe.h5')
  # First conv and segmentation part
  network['conv1_1'] = lasagne.layers.Conv2DLayer(network['input'],
    num_filters = 64, filter_size = (3, 3),nonlinearity = lasagne.nonlinearities.rectify,
    W=lasagne.init.GlorotUniform())
  network['conv1_2'] = lasagne.layers.Conv2DLayer(network['conv1_1'],
    num_filters = 64, filter_size = (3, 3), nonlinearity = lasagne.nonlinearities.rectify)
  network['pool1_1'] = lasagne.layers.MaxPool2DLayer(network['conv1_2'], pool_size = (2, 2))
  network['norm1_1'] = lasagne.layers.BatchNormLayer(network['pool1_1'])

  network['conv1_3'] = lasagne.layers.Conv2DLayer(network['norm1_1'],
    num_filters = 128, filter_size = (3, 3), nonlinearity = lasagne.nonlinearities.rectify)
  network['conv1_4'] = lasagne.layers.Conv2DLayer(network['conv1_3'],
    num_filters = 128, filter_size = (3, 3), nonlinearity = lasagne.nonlinearities.rectify)
  network['pool1_2'] = lasagne.layers.MaxPool2DLayer(network['conv1_4'], pool_size = (2, 2))
  network['norm1_2'] = lasagne.layers.BatchNormLayer(network['pool1_2'])

  network['conv1_5'] = lasagne.layers.Conv2DLayer(network['norm1_2'],
    num_filters = 256, filter_size = (3, 3), nonlinearity = lasagne.nonlinearities.rectify)
  network['pool1_3'] = lasagne.layers.MaxPool2DLayer(network['conv1_5'], pool_size = (2, 2))

  network['conv1_6'] = lasagne.layers.Conv2DLayer(network['pool1_3'],
    num_filters = 256, filter_size = (3, 3), nonlinearity = lasagne.nonlinearities.rectify)
  network['pool1_4'] = lasagne.layers.MaxPool2DLayer(network['conv1_6'], pool_size = (2, 2))

  # Perspective Transform
  network['norm2'] = lasagne.layers.BatchNormLayer(network['pool1_4'])
  # network['cast'] = CastingLayer(network['norm2'], dtype)
  # Build the dense head under the requested float type and always reset
  # floatX afterwards, even if layer construction raises.
  theano.config.floatX = dtype
  try:
    network['pfc2_1'] = lasagne.layers.DenseLayer(
      lasagne.layers.dropout(network['norm2'], p = 0.05),
      num_units = 1024, nonlinearity = lasagne.nonlinearities.rectify)
    network['pfc2_2'] = lasagne.layers.DenseLayer(
      lasagne.layers.dropout(network['pfc2_1'], p=0.05),
      num_units = 1024, nonlinearity = lasagne.nonlinearities.rectify)
    network['pfc2_3'] = lasagne.layers.DenseLayer(
      lasagne.layers.dropout(network['pfc2_2'], p=0.05),
      num_units = 1024, nonlinearity = lasagne.nonlinearities.rectify)
    # loss target 2
    network['pfc_out'] = lasagne.layers.DenseLayer(
      lasagne.layers.dropout(network['pfc2_3'], p = 0.05),
      num_units = 8, nonlinearity = lasagne.nonlinearities.rectify)
  finally:
    theano.config.floatX = 'float32'

  predict = lasagne.layers.get_output(network['pfc_out'])
  # Root-mean-squared error over all outputs in the batch.
  loss = T.sqrt(lasagne.objectives.squared_error(predict, psp).mean())
  paras = lasagne.layers.get_all_params(network['pfc_out'], trainable = True)
  # One shared learning-rate variable (1e-4) per trainable parameter.
  updates = adam(loss, paras, [theano.shared(np.float32(0.0001)) for _ in paras])
  ftrain = theano.function([input_var, psp], [loss, predict], updates = updates)

  def get_inputs(meta, batch, path):
    """Load the image patches and perspective targets for one minibatch.

    Returns ``(images, targets)``. ``images`` is float32. ``targets`` is a
    float64 array holding, per key, the first 8 flattened values of
    ``meta[key][0]``, with entries 6 and 7 mapped through ``(v + 1e-3) * 1e4``.
    """
    images = np.array([read_image(path + 'patch/' + idx + '.jpg', shape = (256, 256))
      for idx in batch]).astype(np.float32)
    # NOTE(review): the mask images are loaded but not used by this experiment.
    seg = np.array([read_image(path + 'pmask/' + idx + '.jpg', shape = (256, 256))
      for idx in batch]).astype(np.float32)
    entries = [meta[key] for key in batch]
    # Explicit float64: matches the dmatrix `psp` input and prevents integer
    # metadata from being truncated by the in-place rescale below.
    Ps = np.array([np.array(entry[0]).flatten()[0 : 8] for entry in entries],
                  dtype=np.float64)
    for P in Ps:
      # Each row is a view into Ps, so this rescales Ps in place.
      P[6 : 8] = (P[6 : 8] + 1e-3) * 1e4
    return images, Ps

  path = '/home/yancz/text_generator/data/real/'
  dat, meta = load_data(path, 10000, False)
  for epoch in range(10):
    # Separate accumulator so the symbolic `loss` above is not shadowed.
    epoch_loss = 0.0
    n_batches = 0
    for batch in iterate_minibatch(dat['train'], 32, len(dat['train'])):
      inputs = get_inputs(meta, batch, path)
      l, valp = ftrain(*inputs)
      log(l)
      print(valp)
      epoch_loss += l
      n_batches += 1
    if n_batches:
      epoch_loss /= n_batches
    # Log the mean loss over all batches of this epoch.
    log('loss ' + str(epoch) + ' ' + str(epoch_loss))
  return ftrain
# Example #3
# 0
def evaluate(test_data, mask, model_save_path, model_file):
    """Restore the latest checkpoint from ``model_save_path`` and evaluate it.

    Builds the graph returned by ``inference.inference`` and restores the most
    recent checkpoint. When this module runs as a script, it then executes one
    of two hard-coded modes (selected by ``test_case``):

    * ``'check_loss'``: iterate shuffled minibatches of ``test_data`` and print
      loss / PSNR / MSE / SSIM for each minibatch.
    * ``'show image'`` (active): evaluate up to the first 100 samples one at a
      time, print timing and metrics, save the per-sample metric lists as .mat
      files under ``result/quantization/<model_file>``, and save the stored
      train/validation loss curves to ``result/images/<model_file>/loss.tif``
      if that file does not already exist.

    Args:
        test_data: indexable ground-truth samples; each is given a leading
            batch axis and passed to ``train.prep_input``.
        mask: undersampling mask forwarded to ``train.prep_input``.
        model_save_path: directory searched by ``tf.train.get_checkpoint_state``.
        model_file: model name used for the result and model sub-directories.
    """

    def _print_metrics(loss_value, base_mse, test_mse, base_psnr, test_psnr,
                       base_ssim, test_ssim):
        # Shared metric report used by both evaluation modes.
        print("test loss:\t\t{:.6f}".format(loss_value))
        print("test psnr:\t\t{:.6f}".format(test_psnr))
        print("base psnr:\t\t{:.6f}".format(base_psnr))
        print("base mse:\t\t{:.6f}".format(base_mse))
        print("test mse:\t\t{:.6f}".format(test_mse))
        print("base ssim:\t\t{:.6f}".format(base_ssim))
        print("test ssim:\t\t{:.6f}".format(test_ssim))

    def _save_loss_curve(project_root, figure_save_path):
        # Plot the stored train/validation loss curves once and write
        # loss.tif unless it already exists.
        model_dir = join(project_root, 'models/%s' % model_file)
        train_plot = np.load(join(model_dir, 'train_plot.npy'))
        validate_plot = np.load(join(model_dir, 'validate_plot.npy'))
        # Unpacking the shape enforces that both curves are 1-D.
        (num_train_plot,) = train_plot.shape
        (num_validate_plot,) = validate_plot.shape
        x1 = np.arange(1, num_train_plot + 1)
        x2 = np.arange(1, num_validate_plot + 1)

        plt.figure(15)
        l1, = plt.plot(x1, train_plot)
        l2, = plt.plot(x2, validate_plot)
        plt.legend(handles=[l1, l2],
                   labels=['train loss', 'validation loss'],
                   loc=1)
        plt.xlabel('epoch')
        plt.ylabel('loss')
        plt.title('loss')
        loss_figure = join(figure_save_path, 'loss.tif')
        if not os.path.exists(loss_figure):
            plt.savefig(loss_figure, dpi=300)

    with tf.Graph().as_default():
        y_ = tf.placeholder(tf.float32,
                            shape=[None, 6, 117, 120, 2],
                            name='y-label')
        mask_p = tf.placeholder(tf.complex64,
                                shape=[None, 6, 117, 120],
                                name='mask')
        kspace_p = tf.placeholder(tf.complex64,
                                  shape=[None, 6, 117, 120],
                                  name='kspace')
        kspace_full = tf.placeholder(tf.complex64,
                                     shape=[None, 6, 117, 120],
                                     name='kspace_full')

        # Only the final output `y` is fetched below; the intermediate block
        # outputs are not needed for these metrics.
        y, block_1, block_2, block_3, block_k_1 = inference.inference(
            mask_p, kspace_p, None)

        # Mean over all non-batch axes, then over the batch.
        loss = tf.reduce_mean(
            tf.reduce_mean(tf.squared_difference(y, y_), [1, 2, 3, 4]))

        with tf.Session() as sess:
            ckpt = tf.train.get_checkpoint_state(model_save_path)
            saver = tf.train.Saver()
            test_case = 'show image'
            if not (ckpt and ckpt.model_checkpoint_path):
                print("No checkpoint file found")
                return
            saver.restore(sess, ckpt.model_checkpoint_path)
            # NOTE(review): evaluation only runs when this module is executed
            # as a script; importing callers get the restore and nothing else.
            if __name__ != '__main__':
                return

            if test_case == 'check_loss':
                # `batch_size` is expected to be a module-level global.
                batches = train.iterate_minibatch(test_data,
                                                  batch_size,
                                                  shuffle=True)
                for count, ys in enumerate(batches, start=1):
                    # Prepare inputs from this minibatch (not the whole test
                    # set) so predictions line up with the ground truth `ys`.
                    xs_l, kspace_l, mask_l, ys_l, k_full_l = train.prep_input(
                        ys, mask)
                    loss_value, y_pred = sess.run(
                        [loss, y],
                        feed_dict={
                            y_: ys_l,
                            mask_p: mask_l,
                            kspace_p: kspace_l,
                            kspace_full: k_full_l
                        })
                    print("The loss of No.{} test data = {}".format(
                        count, loss_value))

                    base_mse, test_mse, base_psnr, \
                    test_psnr, base_ssim, test_ssim = performance(
                        real2complex(xs_l), real2complex(y_pred), ys)
                    _print_metrics(loss_value, base_mse, test_mse, base_psnr,
                                   test_psnr, base_ssim, test_ssim)
            elif test_case == 'show image':
                project_root = '.'

                figure_save_path = join(
                    project_root, 'result/images/%s' % model_file)
                mat_save_path = join(project_root,
                                     'result/mat/%s' % model_file)
                quantization_save_path = join(
                    project_root, 'result/quantization/%s' % model_file)
                for save_dir in (figure_save_path, mat_save_path,
                                 quantization_save_path):
                    if not os.path.isdir(save_dir):
                        os.makedirs(save_dir)

                Test_MSE, Test_PSNR, Test_SSIM = [], [], []
                Base_MSE, Base_PSNR, Base_SSIM = [], [], []

                # Evaluate at most the first 100 samples, one at a time.
                for order in range(min(100, len(test_data))):
                    ys = test_data[order][np.newaxis, :]
                    xs_l, kspace_l, mask_l, ys_l, k_full_l = train.prep_input(
                        ys, mask)
                    time_start = time.time()
                    loss_value, y_pred = sess.run(
                        [loss, y],
                        feed_dict={
                            y_: ys_l,
                            mask_p: mask_l,
                            kspace_p: kspace_l,
                            kspace_full: k_full_l
                        })
                    time_end = time.time()
                    y_pred_new = real2complex(y_pred)
                    xs = real2complex(xs_l)

                    base_mse, test_mse, base_psnr, \
                    test_psnr, base_ssim, test_ssim = performance(xs, y_pred_new, ys)

                    print("test time:\t\t{:.6f}".format(time_end -
                                                        time_start))
                    _print_metrics(loss_value, base_mse, test_mse, base_psnr,
                                   test_psnr, base_ssim, test_ssim)

                    # MSE values are stored as 6-decimal strings (PSNR/SSIM as
                    # numbers), preserving the existing .mat output format.
                    Test_MSE.append("%.6f" % test_mse)
                    Test_PSNR.append(test_psnr)
                    Test_SSIM.append(test_ssim)

                    Base_MSE.append("%.6f" % base_mse)
                    Base_PSNR.append(base_psnr)
                    Base_SSIM.append(base_ssim)

                for file_name, key, values in (
                        ('Test_MSE', 'test_mse', Test_MSE),
                        ('Test_PSNR', 'test_psnr', Test_PSNR),
                        ('Test_SSIM', 'test_ssim', Test_SSIM),
                        ('Base_MSE', 'base_mse', Base_MSE),
                        ('Base_PSNR', 'base_psnr', Base_PSNR),
                        ('Base_SSIM', 'base_ssim', Base_SSIM)):
                    scio.savemat(join(quantization_save_path, file_name),
                                 {key: values})

                # The loss curves do not depend on the test sample, so plot
                # them once rather than on every loop iteration.
                _save_loss_curve(project_root, figure_save_path)
# Example #4
# 0
def evaluate(test_data, mask, model_save_path, model_file):
    with tf.Graph().as_default() as g:
        #x = tf.placeholder(tf.float32, shape=[None, 6, 117, 120, 2], name='x-input')
        y_ = tf.placeholder(tf.float32,
                            shape=[None, 6, 117, 120, 2],
                            name='y-label')
        mask_p = tf.placeholder(tf.complex64,
                                shape=[None, 6, 117, 120],
                                name='mask')
        kspace_p = tf.placeholder(tf.complex64,
                                  shape=[None, 6, 117, 120],
                                  name='kspace')
        kspace_full = tf.placeholder(tf.complex64,
                                     shape=[None, 6, 117, 120],
                                     name='kspace_full')

        y, block_1, block_2, block_3, block_k_1 = inference.inference(
            mask_p, kspace_p, None)

        loss = tf.reduce_mean(
            tf.reduce_mean(tf.squared_difference(y, y_), [1, 2, 3, 4]))

        with tf.Session() as sess:
            ckpt = tf.train.get_checkpoint_state(model_save_path)
            saver = tf.train.Saver()
            test_case = 'show image'
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
                if __name__ == '__main__':
                    if test_case == 'check_loss':
                        count = 0
                        for ys in train.iterate_minibatch(test_data,
                                                          batch_size,
                                                          shuffle=True):
                            xs_l, kspace_l, mask_l, ys_l, k_full_l = train.prep_input(
                                test_data, mask)
                            loss_value, y_pred, block_1_pred, block_2_pred, block_3_pred, block_k_pred = sess.run(
                                [
                                    loss, y, block_1, block_2, block_3,
                                    block_k_1
                                ],
                                feed_dict={
                                    y_: ys_l,
                                    mask_p: mask_l,
                                    kspace_p: kspace_l,
                                    kspace_full: k_full_l
                                })
                            print("The loss of No.{} test data = {}".format(
                                count + 1, loss_value))

                            y_c = real2complex(y_pred)
                            xs_c = real2complex(xs_l)
                            base_mse, test_mse, base_psnr, \
                            test_psnr, base_ssim, test_ssim = performance(xs_c, y_c, ys)
                            print("test loss:\t\t{:.6f}".format(loss_value))
                            print("test psnr:\t\t{:.6f}".format(test_psnr))
                            print("base psnr:\t\t{:.6f}".format(base_psnr))
                            print("base mse:\t\t{:.6f}".format(base_mse))
                            print("test mse:\t\t{:.6f}".format(test_mse))
                            print("base ssim:\t\t{:.6f}".format(base_ssim))
                            print("test ssim:\t\t{:.6f}".format(test_ssim))
                            count += 1
                    elif test_case == 'show image':
                        project_root = '.'

                        figure_save_path = join(
                            project_root, 'result/images/%s' % model_file)
                        if not os.path.isdir(figure_save_path):
                            os.makedirs(figure_save_path)

                        mat_save_path = join(project_root,
                                             'result/mat/%s' % model_file)
                        if not os.path.isdir(mat_save_path):
                            os.makedirs(mat_save_path)

                        quantization_save_path = join(
                            project_root,
                            'result/quantization/%s' % model_file)
                        if not os.path.isdir(quantization_save_path):
                            os.makedirs(quantization_save_path)

                        Test_MSE = []
                        Test_PSNR = []
                        Test_SSIM = []

                        K_Test_MSE = []
                        K_Test_PSNR = []
                        K_Test_SSIM = []

                        Block1_Test_MSE = []
                        Block1_Test_PSNR = []
                        Block1_Test_SSIM = []

                        Block2_Test_MSE = []
                        Block2_Test_PSNR = []
                        Block2_Test_SSIM = []

                        Block3_Test_MSE = []
                        Block3_Test_PSNR = []
                        Block3_Test_SSIM = []

                        Base_MSE = []
                        Base_PSNR = []
                        Base_SSIM = []

                        for order in range(0, 100):
                            ys = test_data[order]
                            ys = ys[np.newaxis, :]
                            xs_l, kspace_l, mask_l, ys_l, k_full_l = train.prep_input(
                                ys, mask)
                            time_start = time.time()
                            loss_value, y_pred, block_1_pred, block_2_pred, block_3_pred, k_recon_pred = sess.run(
                                [
                                    loss, y, block_1, block_2, block_3,
                                    block_k_1
                                ],
                                feed_dict={
                                    y_: ys_l,
                                    mask_p: mask_l,
                                    kspace_p: kspace_l,
                                    kspace_full: k_full_l
                                })
                            time_end = time.time()
                            y_pred_new = real2complex(y_pred)
                            k_recon_pred = real2complex(k_recon_pred)
                            block_1_pred = real2complex(block_1_pred)
                            block_2_pred = real2complex(block_2_pred)
                            block_3_pred = real2complex(block_3_pred)
                            xs = real2complex(xs_l)
                            if order == 0:
                                order_x = 100
                            elif order == 1:
                                order_x = 60
                            elif order == 2:
                                order_x = 85
                            elif order == 6:
                                order_x = 40
                            else:
                                order_x = 55
                            # order_x = 55 # (order, order_x): (0, 100), (1, 60), (6, 40), (7, 55)
                            ys_t = ys[:, :, order_x, :]
                            y_pred_t = y_pred_new[:, :, order_x, :]
                            xs_t = xs[:, :, order_x, :]
                            xs_t_error = ys_t - xs_t
                            y_pred_error = ys_t - y_pred_t

                            base_mse, test_mse, base_psnr, \
                            test_psnr, base_ssim, test_ssim = performance(xs, y_pred_new, ys)

                            base_mse, k_test_mse, base_psnr, \
                            k_test_psnr, base_ssim, k_test_ssim = performance(xs, k_recon_pred, ys)

                            base_mse, block1_test_mse, base_psnr, \
                            block1_test_psnr, base_ssim, block1_test_ssim = performance(xs, block_1_pred, ys)

                            base_mse, block2_test_mse, base_psnr, \
                            block2_test_psnr, base_ssim, block2_test_ssim = performance(xs, block_2_pred, ys)

                            base_mse, block3_test_mse, base_psnr, \
                            block3_test_psnr, base_ssim, block3_test_ssim = performance(xs, block_3_pred, ys)

                            print("test time:\t\t{:.6f}".format(time_end -
                                                                time_start))
                            print("test loss:\t\t{:.6f}".format(loss_value))
                            print("test psnr:\t\t{:.6f}".format(test_psnr))
                            print("base psnr:\t\t{:.6f}".format(base_psnr))
                            print("base mse:\t\t{:.6f}".format(base_mse))
                            print("test mse:\t\t{:.6f}".format(test_mse))
                            print("base ssim:\t\t{:.6f}".format(base_ssim))
                            print("test ssim:\t\t{:.6f}".format(test_ssim))
                            base_mse = ("%.6f" % base_mse)
                            test_mse = ("%.6f" % test_mse)
                            k_test_mse = ("%.6f" % k_test_mse)
                            block1_test_mse = ("%.6f" % block1_test_mse)
                            block2_test_mse = ("%.6f" % block2_test_mse)
                            block3_test_mse = ("%.6f" % block3_test_mse)

                            Test_MSE.append(test_mse)
                            Test_PSNR.append(test_psnr)
                            Test_SSIM.append(test_ssim)

                            K_Test_MSE.append(k_test_mse)
                            K_Test_PSNR.append(k_test_psnr)
                            K_Test_SSIM.append(k_test_ssim)

                            Block1_Test_MSE.append(block1_test_mse)
                            Block1_Test_PSNR.append(block1_test_psnr)
                            Block1_Test_SSIM.append(block1_test_ssim)

                            Block2_Test_MSE.append(block2_test_mse)
                            Block2_Test_PSNR.append(block2_test_psnr)
                            Block2_Test_SSIM.append(block2_test_ssim)

                            Block3_Test_MSE.append(block3_test_mse)
                            Block3_Test_PSNR.append(block3_test_psnr)
                            Block3_Test_SSIM.append(block3_test_ssim)

                            Base_MSE.append(base_mse)
                            Base_PSNR.append(base_psnr)
                            Base_SSIM.append(base_ssim)

                            mask_shift = mymath.fftshift(mask, axes=(-1, -2))
                            gamma = 1
                            plt.figure(1)
                            plt.subplot(221)
                            plt.imshow(
                                exposure.adjust_gamma(np.abs(ys[0][0]), gamma),
                                plt.cm.gray)
                            plt.xticks([])
                            plt.yticks([])
                            plt.title('ground truth')
                            plt.subplot(222)
                            plt.imshow(
                                exposure.adjust_gamma(abs(mask_shift[0][0]),
                                                      gamma), plt.cm.gray)
                            plt.xticks([])
                            plt.yticks([])
                            plt.title('mask')
                            plt.subplot(223)
                            plt.imshow(
                                exposure.adjust_gamma(abs(xs[0][0]), gamma),
                                plt.cm.gray)
                            plt.xticks([])
                            plt.yticks([])
                            plt.title("undersampling")
                            plt.subplot(224)
                            plt.imshow(
                                exposure.adjust_gamma(abs(y_pred_new[0][0]),
                                                      gamma), plt.cm.gray)
                            plt.xticks([])
                            plt.yticks([])
                            plt.title("reconstruction")
                            plt.savefig(join(figure_save_path,
                                             'test%s.tif' % order),
                                        dpi=300)

                            plt.figure(2)
                            plt.imshow(
                                exposure.adjust_gamma(np.abs(ys[0][0]), gamma),
                                plt.cm.gray)
                            plt.xticks([])
                            plt.yticks([])
                            plt.title('ground truth')
                            scio.savemat(join(mat_save_path, 'gr%s' % order),
                                         {'gr': abs(ys[0][0])})
                            plt.savefig(join(figure_save_path,
                                             'gr%s.tif' % order),
                                        dpi=300)

                            plt.figure(3)
                            plt.imshow(
                                exposure.adjust_gamma(abs(xs[0][0]), gamma),
                                plt.cm.gray)
                            plt.xticks([])
                            plt.yticks([])
                            plt.title("undersampling: " + base_mse + ' ' +
                                      str(round(base_psnr, 5)) + ' ' +
                                      str(round(base_ssim, 4)))
                            scio.savemat(
                                join(mat_save_path, 'under%s' % order),
                                {'under': abs(xs[0][0])})
                            plt.savefig(join(figure_save_path,
                                             'under%s.tif' % order),
                                        dpi=300)

                            plt.figure(4)
                            plt.imshow(
                                exposure.adjust_gamma(abs(y_pred_new[0][0]),
                                                      gamma), plt.cm.gray)
                            plt.xticks([])
                            plt.yticks([])
                            plt.title("reconstruction: " + test_mse + ' ' +
                                      str(round(test_psnr, 5)) + ' ' +
                                      str(round(test_ssim, 4)))
                            scio.savemat(
                                join(mat_save_path, 'recon%s' % order),
                                {'recon': abs(y_pred_new[0][0])})
                            plt.savefig(join(figure_save_path,
                                             'recon%s.tif' % order),
                                        dpi=300)

                            plt.figure(5)
                            plt.imshow(
                                exposure.adjust_gamma(abs(k_recon_pred[0][0]),
                                                      gamma), plt.cm.gray)
                            plt.xticks([])
                            plt.yticks([])
                            plt.title("k_recon: " + k_test_mse + ' ' +
                                      str(round(k_test_psnr, 5)) + ' ' +
                                      str(round(k_test_ssim, 4)))
                            scio.savemat(
                                join(mat_save_path, 'k_recon%s' % order),
                                {'k_recon': abs(k_recon_pred[0][0])})
                            plt.savefig(join(figure_save_path,
                                             'k_recon%s.tif' % order),
                                        dpi=300)

                            plt.figure(6)
                            plt.imshow(
                                exposure.adjust_gamma(abs(block_1_pred[0][0]),
                                                      gamma), plt.cm.gray)
                            plt.xticks([])
                            plt.yticks([])
                            plt.title("block1_recon: " + block1_test_mse +
                                      ' ' + str(round(block1_test_psnr, 5)) +
                                      ' ' + str(round(block1_test_ssim, 4)))
                            scio.savemat(
                                join(mat_save_path, 'block1_recon%s' % order),
                                {'block1_recon': abs(block_1_pred[0][0])})
                            plt.savefig(join(figure_save_path,
                                             'block1_recon%s.tif' % order),
                                        dpi=300)

                            plt.figure(7)
                            plt.imshow(
                                exposure.adjust_gamma(abs(block_2_pred[0][0]),
                                                      gamma), plt.cm.gray)
                            plt.xticks([])
                            plt.yticks([])
                            plt.title("block2_recon: " + block2_test_mse +
                                      ' ' + str(round(block2_test_psnr, 5)) +
                                      ' ' + str(round(block2_test_ssim, 4)))
                            scio.savemat(
                                join(mat_save_path, 'block2_recon%s' % order),
                                {'block2_recon': abs(block_2_pred[0][0])})
                            plt.savefig(join(figure_save_path,
                                             'block2_recon%s.tif' % order),
                                        dpi=300)

                            plt.figure(8)
                            plt.imshow(
                                exposure.adjust_gamma(abs(block_3_pred[0][0]),
                                                      gamma), plt.cm.gray)
                            plt.xticks([])
                            plt.yticks([])
                            plt.title("block3_recon: " + block3_test_mse +
                                      ' ' + str(round(block3_test_psnr, 5)) +
                                      ' ' + str(round(block3_test_ssim, 4)))
                            scio.savemat(
                                join(mat_save_path, 'block3_recon%s' % order),
                                {'block3_recon': abs(block_3_pred[0][0])})
                            plt.savefig(join(figure_save_path,
                                             'block3_recon%s.tif' % order),
                                        dpi=300)

                            plt.figure(9)
                            plt.imshow(exposure.adjust_gamma(
                                abs(abs(ys[0][0]) - abs(y_pred_new[0][0])),
                                gamma),
                                       vmin=0,
                                       vmax=0.07)
                            plt.xticks([])
                            plt.yticks([])
                            plt.title("error: " + test_mse + ' ' +
                                      str(round(test_psnr, 5)) + ' ' +
                                      str(round(test_ssim, 4)))
                            scio.savemat(
                                join(mat_save_path, 'error%s' % order), {
                                    'error':
                                    abs(abs(ys[0][0]) - abs(y_pred_new[0][0]))
                                })
                            plt.savefig(join(figure_save_path,
                                             'error%s.tif' % order),
                                        dpi=300)

                            plt.figure(10)
                            plt.imshow(exposure.adjust_gamma(
                                abs(abs(ys[0][0]) - abs(k_recon_pred[0][0])),
                                gamma),
                                       vmin=0,
                                       vmax=0.07)
                            plt.xticks([])
                            plt.yticks([])
                            plt.title("k_error: " + k_test_mse + ' ' +
                                      str(round(k_test_psnr, 5)) + ' ' +
                                      str(round(k_test_ssim, 4)))
                            scio.savemat(
                                join(mat_save_path, 'k_error%s' % order), {
                                    'k_error':
                                    abs(
                                        abs(ys[0][0]) -
                                        abs(k_recon_pred[0][0]))
                                })
                            plt.savefig(join(figure_save_path,
                                             'k_error%s.tif' % order),
                                        dpi=300)

                            plt.figure(11)
                            plt.imshow(exposure.adjust_gamma(
                                abs(abs(ys[0][0]) - abs(block_1_pred[0][0])),
                                gamma),
                                       vmin=0,
                                       vmax=0.07)
                            plt.xticks([])
                            plt.yticks([])
                            plt.title("block1_error: " + block1_test_mse +
                                      ' ' + str(round(block1_test_psnr, 5)) +
                                      ' ' + str(round(block1_test_ssim, 4)))
                            scio.savemat(
                                join(mat_save_path, 'block1_error%s' % order),
                                {
                                    'block1_error':
                                    abs(
                                        abs(ys[0][0]) -
                                        abs(block_1_pred[0][0]))
                                })
                            plt.savefig(join(figure_save_path,
                                             'block1_error%s.tif' % order),
                                        dpi=300)

                            plt.figure(12)
                            plt.imshow(exposure.adjust_gamma(
                                abs(abs(ys[0][0]) - abs(block_2_pred[0][0])),
                                gamma),
                                       vmin=0,
                                       vmax=0.07)
                            plt.xticks([])
                            plt.yticks([])
                            plt.title("block2_error: " + block2_test_mse +
                                      ' ' + str(round(block2_test_psnr, 5)) +
                                      ' ' + str(round(block2_test_ssim, 4)))
                            scio.savemat(
                                join(mat_save_path, 'block2_error%s' % order),
                                {
                                    'block2_error':
                                    abs(
                                        abs(ys[0][0]) -
                                        abs(block_2_pred[0][0]))
                                })
                            plt.savefig(join(figure_save_path,
                                             'block2_error%s.tif' % order),
                                        dpi=300)

                            plt.figure(13)
                            plt.imshow(exposure.adjust_gamma(
                                abs(abs(ys[0][0]) - abs(block_3_pred[0][0])),
                                gamma),
                                       vmin=0,
                                       vmax=0.07)
                            plt.xticks([])
                            plt.yticks([])
                            plt.title("block3_error: " + block3_test_mse +
                                      ' ' + str(round(block3_test_psnr, 5)) +
                                      ' ' + str(round(block3_test_ssim, 4)))
                            scio.savemat(
                                join(mat_save_path, 'block3_error%s' % order),
                                {
                                    'block3_error':
                                    abs(
                                        abs(ys[0][0]) -
                                        abs(block_3_pred[0][0]))
                                })
                            plt.savefig(join(figure_save_path,
                                             'block3_error%s.tif' % order),
                                        dpi=300)

                            plt.figure(14)
                            plt.subplot(511)
                            plt.imshow(np.abs(ys_t[0]), plt.cm.gray)
                            plt.xticks([])
                            plt.yticks([])
                            plt.title("gnd_t_y")
                            plt.subplot(512)
                            plt.imshow(np.abs(xs_t[0]), plt.cm.gray)
                            plt.xticks([])
                            plt.yticks([])
                            plt.title("under_t_y")
                            plt.subplot(513)
                            plt.imshow(np.abs(xs_t_error[0]))
                            plt.xticks([])
                            plt.yticks([])
                            plt.title("under_t_y_error")
                            plt.subplot(514)
                            plt.imshow(np.abs(y_pred_t[0]), plt.cm.gray)
                            plt.xticks([])
                            plt.yticks([])
                            plt.title("recon_t_y")
                            plt.subplot(515)
                            plt.imshow(np.abs(y_pred_error[0]))
                            plt.xticks([])
                            plt.yticks([])
                            plt.title("recon_t_y_error")
                            plt.savefig(
                                join(figure_save_path, 't_y%s.tif' % order))
                            train_plot = np.load(
                                join(project_root, 'models/%s' % model_file,
                                     'train_plot.npy'))
                            validate_plot = np.load(
                                join(project_root, 'models/%s' % model_file,
                                     'validate_plot.npy'))
                            [
                                num_train_plot,
                            ] = train_plot.shape
                            [
                                num_validate_plot,
                            ] = validate_plot.shape
                            x1 = np.arange(1, num_train_plot + 1)
                            x2 = np.arange(1, num_validate_plot + 1)

                            plt.figure(15)
                            l1, = plt.plot(x1, train_plot)
                            l2, = plt.plot(x2, validate_plot)
                            plt.legend(
                                handles=[
                                    l1,
                                    l2,
                                ],
                                labels=['train loss', 'validation loss'],
                                loc=1)
                            plt.xlabel('epoch')
                            plt.ylabel('loss')
                            plt.title('loss')
                            if not os.path.exists(
                                    join(figure_save_path, 'loss.tif')):
                                plt.savefig(join(figure_save_path, 'loss.tif'),
                                            dpi=300)
                            #plt.show()

                        scio.savemat(join(quantization_save_path, 'Test_MSE'),
                                     {'test_mse': Test_MSE})
                        scio.savemat(join(quantization_save_path, 'Test_PSNR'),
                                     {'test_psnr': Test_PSNR})
                        scio.savemat(join(quantization_save_path, 'Test_SSIM'),
                                     {'test_ssim': Test_SSIM})

                        scio.savemat(
                            join(quantization_save_path, 'K_Test_MSE'),
                            {'k_test_mse': K_Test_MSE})
                        scio.savemat(
                            join(quantization_save_path, 'K_Test_PSNR'),
                            {'k_test_psnr': K_Test_PSNR})
                        scio.savemat(
                            join(quantization_save_path, 'K_Test_SSIM'),
                            {'k_test_ssim': K_Test_SSIM})

                        scio.savemat(
                            join(quantization_save_path, 'Block1_Test_MSE'),
                            {'block1_test_mse': Block1_Test_MSE})
                        scio.savemat(
                            join(quantization_save_path, 'Block1_Test_PSNR'),
                            {'block1_test_psnr': Block1_Test_PSNR})
                        scio.savemat(
                            join(quantization_save_path, 'Block1_Test_SSIM'),
                            {'block1_test_ssim': Block1_Test_SSIM})

                        scio.savemat(
                            join(quantization_save_path, 'Block2_Test_MSE'),
                            {'block2_test_mse': Block2_Test_MSE})
                        scio.savemat(
                            join(quantization_save_path, 'Block2_Test_PSNR'),
                            {'block2_test_psnr': Block2_Test_PSNR})
                        scio.savemat(
                            join(quantization_save_path, 'Block2_Test_SSIM'),
                            {'block2_test_ssim': Block2_Test_SSIM})

                        scio.savemat(
                            join(quantization_save_path, 'Block3_Test_MSE'),
                            {'block3_test_mse': Block3_Test_MSE})
                        scio.savemat(
                            join(quantization_save_path, 'Block3_Test_PSNR'),
                            {'block3_test_psnr': Block3_Test_PSNR})
                        scio.savemat(
                            join(quantization_save_path, 'Block3_Test_SSIM'),
                            {'block3_test_ssim': Block3_Test_SSIM})

                        scio.savemat(join(quantization_save_path, 'Base_MSE'),
                                     {'base_mse': Base_MSE})
                        scio.savemat(join(quantization_save_path, 'Base_PSNR'),
                                     {'base_psnr': Base_PSNR})
                        scio.savemat(join(quantization_save_path, 'Base_SSIM'),
                                     {'base_ssim': Base_SSIM})
                        #elif test_case == "Save image":

            else:
                print("No checkpoint file found")
示例#5
0
def evaluate(test_data, mask):
    """Restore a trained checkpoint and evaluate it on `test_data`.

    Builds a fresh TF graph with placeholders for the label, the sampling
    mask and the undersampled k-space, runs `inference.inference` on them,
    restores the checkpoint found in the module-level `model_save_path` and
    then, depending on the hard-coded `test_case`:

    * ``'check_loss'`` -- iterates `test_data` in shuffled mini-batches of the
      module-level `batch_size` and prints the loss and MSE/PSNR/SSIM metrics
      for each batch.
    * ``'show image'`` -- reconstructs the single sample ``test_data[79]``,
      prints its timing and metrics, and saves comparison PNG figures plus
      the train/validation loss curve.

    NOTE(review): `test_case` is fixed to ``'show image'`` below, and all work
    after the restore only runs when this module is executed as a script
    (``__name__ == '__main__'``); when imported, the function restores the
    checkpoint and returns without evaluating.

    NOTE(review): a second ``def evaluate(test_data, mask)`` appears later in
    this file and rebinds the name, so this definition is shadowed once the
    module has been fully executed.

    Args:
        test_data: indexable test set; ``test_data[i]`` (with a new leading
            axis) is fed to `train.prep_input`. Presumably complex volumes of
            shape (6, 117, 120) to match the placeholders -- TODO confirm.
        mask: undersampling mask forwarded to `train.prep_input`; it is also
            fft-shifted over its last two axes for display.

    Returns:
        None. Results are printed and written to disk.
    """
    with tf.Graph().as_default() as g:
        # Label placeholder; the trailing axis of size 2 presumably holds the
        # real/imaginary parts (network outputs go through real2complex).
        y_ = tf.placeholder(tf.float32,
                            shape=[None, 6, 117, 120, 2],
                            name='y-label')
        mask_p = tf.placeholder(tf.complex64,
                                shape=[None, 6, 117, 120],
                                name='mask')
        kspace_p = tf.placeholder(tf.complex64,
                                  shape=[None, 6, 117, 120],
                                  name='kspace')

        y = inference.inference(mask_p, kspace_p, None)

        # Per-sample mean squared error over all non-batch axes, then
        # averaged over the batch.
        loss = tf.reduce_mean(
            tf.reduce_mean(tf.squared_difference(y, y_), [1, 2, 3, 4]))

        with tf.Session() as sess:
            ckpt = tf.train.get_checkpoint_state(model_save_path)
            saver = tf.train.Saver()
            # Hard-coded evaluation mode: 'check_loss' or 'show image'.
            test_case = 'show image'
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
                # Evaluation only runs when this module is the main script.
                if __name__ == '__main__':
                    if test_case == 'check_loss':
                        # Print loss and image-quality metrics per mini-batch.
                        count = 0
                        for ys in train.iterate_minibatch(test_data,
                                                          batch_size,
                                                          shuffle=True):
                            xs_l, kspace_l, mask_l, ys_l = train.prep_input(
                                ys, mask)
                            loss_value, y_pred = sess.run([loss, y],
                                                          feed_dict={
                                                              y_: ys_l,
                                                              mask_p: mask_l,
                                                              kspace_p:
                                                              kspace_l
                                                          })
                            print("The loss of No.{} test data = {}".format(
                                count + 1, loss_value))

                            y_c = real2complex(y_pred)
                            xs_c = real2complex(xs_l)
                            # Presumably "base" scores the undersampled input
                            # and "test" the network output, both against ys.
                            base_mse, test_mse, base_psnr, \
                            test_psnr, base_ssim, test_ssim = performance(xs_c, y_c, ys)
                            print("test loss:\t\t{:.6f}".format(loss_value))
                            print("test psnr:\t\t{:.6f}".format(test_psnr))
                            print("base psnr:\t\t{:.6f}".format(base_psnr))
                            print("base mse:\t\t{:.6f}".format(base_mse))
                            print("test mse:\t\t{:.6f}".format(test_mse))
                            print("base ssim:\t\t{:.6f}".format(base_ssim))
                            print("test ssim:\t\t{:.6f}".format(test_ssim))
                            count += 1
                    elif test_case == 'show image':
                        # Output directory for the figures (absolute,
                        # machine-specific path).
                        final_result_path = '/media/keziwen/86AA9651AA963E1D/Tensorflow/MyDeepMRI-KI_Net V2/for KI/final results'
                        figure_save_path = join(final_result_path,
                                                'KI vs KI_with_KLoss(e-1)',
                                                'KI')
                        if not os.path.isdir(figure_save_path):
                            os.makedirs(figure_save_path)
                        # Index of the single test sample to visualise.
                        order = 79
                        ys = test_data[order]
                        # Add a leading batch axis of size 1.
                        ys = ys[np.newaxis, :]
                        xs_l, kspace_l, mask_l, ys_l = train.prep_input(
                            ys, mask)
                        # Wall-clock time of the single forward pass.
                        time_start = time.time()
                        loss_value, y_pred = sess.run([loss, y],
                                                      feed_dict={
                                                          y_: ys_l,
                                                          mask_p: mask_l,
                                                          kspace_p: kspace_l
                                                      })
                        time_end = time.time()
                        y_pred = real2complex(y_pred)
                        xs = real2complex(xs_l)
                        # Index along axis 2 used for the profile plots
                        # (figure 6); chosen per sample, 55 otherwise.
                        if order == 0:
                            order_x = 100
                        elif order == 1:
                            order_x = 60
                        elif order == 2:
                            order_x = 85
                        elif order == 6:
                            order_x = 40
                        else:
                            order_x = 55
                        #order_x = 55 # (order, order_x): (0, 100), (1, 60), (6, 40), (7, 55)
                        # Fix axis 2 at order_x, keeping every entry of axis 1,
                        # for ground truth, reconstruction and undersampled
                        # input, plus the two residuals against ground truth.
                        ys_t = ys[:, :, order_x, :]
                        y_pred_t = y_pred[:, :, order_x, :]
                        xs_t = xs[:, :, order_x, :]
                        xs_t_error = ys_t - xs_t
                        y_pred_error = ys_t - y_pred_t

                        base_mse, test_mse, base_psnr,\
                        test_psnr, base_ssim, test_ssim = performance(xs, y_pred, ys)
                        print("test time:\t\t{:.6f}".format(time_end -
                                                            time_start))
                        print("test loss:\t\t{:.6f}".format(loss_value))
                        print("test psnr:\t\t{:.6f}".format(test_psnr))
                        print("base psnr:\t\t{:.6f}".format(base_psnr))
                        print("base mse:\t\t{:.6f}".format(base_mse))
                        print("test mse:\t\t{:.6f}".format(test_mse))
                        print("base ssim:\t\t{:.6f}".format(base_ssim))
                        print("test ssim:\t\t{:.6f}".format(test_ssim))
                        # Shift the mask over its last two axes for display.
                        mask_shift = mymath.fftshift(mask, axes=(-1, -2))
                        # Display gamma passed to exposure.adjust_gamma.
                        gamma = 1
                        # Figure 1: 2x2 overview (ground truth, mask,
                        # undersampled input, reconstruction) of [0][0].
                        plt.figure(1)
                        plt.subplot(221)
                        plt.imshow(
                            exposure.adjust_gamma(np.abs(ys[0][0]), gamma),
                            plt.cm.gray)
                        plt.title('ground truth')
                        plt.subplot(222)
                        plt.imshow(
                            exposure.adjust_gamma(abs(mask_shift[0][0]),
                                                  gamma), plt.cm.gray)
                        plt.title('mask')
                        plt.subplot(223)
                        plt.imshow(exposure.adjust_gamma(abs(xs[0][0]), gamma),
                                   plt.cm.gray)
                        plt.title("undersampling")
                        plt.subplot(224)
                        plt.imshow(
                            exposure.adjust_gamma(abs(y_pred[0][0]), gamma),
                            plt.cm.gray)
                        plt.title("reconstruction")
                        plt.savefig(
                            join(figure_save_path, 'test%s.png' % order))
                        # Figures 2-4: the same three images saved
                        # individually.
                        plt.figure(2)
                        plt.imshow(
                            exposure.adjust_gamma(np.abs(ys[0][0]), gamma),
                            plt.cm.gray)
                        plt.title('ground truth')
                        plt.savefig(join(figure_save_path, 'gr%s.png' % order))
                        plt.figure(3)
                        plt.imshow(exposure.adjust_gamma(abs(xs[0][0]), gamma),
                                   plt.cm.gray)
                        plt.title("undersampling")
                        plt.savefig(
                            join(figure_save_path, 'under%s.png' % order))
                        plt.figure(4)
                        plt.imshow(
                            exposure.adjust_gamma(abs(y_pred[0][0]), gamma),
                            plt.cm.gray)
                        plt.title("reconstruction")
                        plt.savefig(
                            join(figure_save_path, 'recon%s.png' % order))
                        # Figure 5: absolute magnitude error, default colormap.
                        plt.figure(5)
                        plt.imshow(
                            exposure.adjust_gamma(
                                abs(np.abs(ys[0][0]) - abs(y_pred[0][0])),
                                gamma))
                        plt.title("error")
                        plt.savefig(
                            join(figure_save_path, 'error%s.png' % order))
                        # Figure 6: the order_x slices and their residuals,
                        # stacked in five rows.
                        plt.figure(6)
                        plt.subplot(511)
                        plt.imshow(np.abs(ys_t[0]), plt.cm.gray)
                        plt.title("gnd_t_y")
                        plt.subplot(512)
                        plt.imshow(np.abs(xs_t[0]), plt.cm.gray)
                        plt.title("under_t_y")
                        plt.subplot(513)
                        plt.imshow(np.abs(xs_t_error[0]))
                        plt.title("under_t_y_error")
                        plt.subplot(514)
                        plt.imshow(np.abs(y_pred_t[0]), plt.cm.gray)
                        plt.title("recon_t_y")
                        plt.subplot(515)
                        plt.imshow(np.abs(y_pred_error[0]))
                        plt.title("recon_t_y_error")
                        plt.savefig(join(figure_save_path,
                                         't_y%s.png' % order))
                        # Loss histories recorded during training, read from
                        # models/<model_file>/ under the module-level
                        # project_root.
                        train_plot = np.load(
                            join(project_root, 'models/%s' % model_file,
                                 'train_plot.npy'))
                        validate_plot = np.load(
                            join(project_root, 'models/%s' % model_file,
                                 'validate_plot.npy'))
                        # Single-target unpacking: both histories must be 1-D.
                        [
                            num_train_plot,
                        ] = train_plot.shape
                        [
                            num_validate_plot,
                        ] = validate_plot.shape
                        # 1-based epoch indices for the x axis.
                        x1 = np.arange(1, num_train_plot + 1)
                        x2 = np.arange(1, num_validate_plot + 1)
                        # Figure 7: train vs. validation loss curves.
                        plt.figure(7)
                        l1, = plt.plot(x1, train_plot)
                        l2, = plt.plot(x2, validate_plot)
                        plt.legend(handles=[
                            l1,
                            l2,
                        ],
                                   labels=['train loss', 'validation loss'],
                                   loc=1)
                        plt.xlabel('epoch')
                        plt.ylabel('loss')
                        plt.title('loss')
                        # Never overwrite an existing loss plot.
                        if not os.path.exists(
                                join(figure_save_path, 'loss.png')):
                            plt.savefig(join(figure_save_path, 'loss.png'))
                        #plt.show()
                    #elif test_case == "Save image":

            else:
                print("No checkpoint file found")
def _print_eval_metrics(loss_value, base_mse, test_mse, base_psnr, test_psnr,
                        base_ssim, test_ssim):
    """Print the loss and the "base"/"test" PSNR, MSE and SSIM values.

    Every value is printed on its own line with six decimals, in a fixed
    order shared by all evaluation modes of `evaluate`.
    """
    print("test loss:\t\t{:.6f}".format(loss_value))
    print("test psnr:\t\t{:.6f}".format(test_psnr))
    print("base psnr:\t\t{:.6f}".format(base_psnr))
    print("base mse:\t\t{:.6f}".format(base_mse))
    print("test mse:\t\t{:.6f}".format(test_mse))
    print("base ssim:\t\t{:.6f}".format(base_ssim))
    print("test ssim:\t\t{:.6f}".format(test_ssim))


def evaluate(test_data, mask):
    """Restore a trained checkpoint and evaluate it on `test_data`.

    Builds a fresh TF graph, runs `inference.inference` (its first of five
    outputs, `y`, is compared against the label), restores the checkpoint
    found in the module-level `model_save_path` and then, depending on the
    hard-coded `test_case`:

    * ``'check_loss'`` -- iterates `test_data` in shuffled mini-batches of the
      module-level `batch_size` and prints loss and MSE/PSNR/SSIM per batch.
    * ``'show image'`` -- evaluates up to the first 100 samples one at a
      time, prints timing and metrics for each, and saves the per-sample
      metric lists as ``.mat`` files under
      ``./result/quantization/<model_file>``, suffixed with the module-level
      `lambda_num`.

    NOTE(review): `test_case` is fixed to ``'show image'`` and the evaluation
    only runs when this module is executed as a script
    (``__name__ == '__main__'``); when imported, the checkpoint is restored
    and the function returns without evaluating.

    Args:
        test_data: indexable, sized test set; each item (with a new leading
            axis) is passed to `train.prep_input`. Presumably complex volumes
            of shape (6, 117, 120) matching the placeholders -- TODO confirm.
        mask: undersampling mask forwarded to `train.prep_input`.

    Returns:
        None. Results are printed and written to disk.
    """
    with tf.Graph().as_default():
        y_ = tf.placeholder(tf.float32, shape=[None, 6, 117, 120, 2], name='y-label')
        mask_p = tf.placeholder(tf.complex64, shape=[None, 6, 117, 120], name='mask')
        kspace_p = tf.placeholder(tf.complex64, shape=[None, 6, 117, 120], name='kspace')
        # Fed from the fifth output of train.prep_input (presumably the fully
        # sampled k-space); it is not an input of `inference.inference` here.
        kspace_full = tf.placeholder(tf.complex64, shape=[None, 6, 117, 120], name='kspace_full')

        # Only the final output `y` is evaluated; the other four outputs are
        # unpacked to pin the expected arity of the network.
        y, k_recon, block_1, block_2, block_3 = inference.inference(mask_p, kspace_p, None)

        # Per-sample MSE over all non-batch axes, averaged over the batch.
        loss = tf.reduce_mean(tf.reduce_mean(tf.squared_difference(y, y_), [1, 2, 3, 4]))

        with tf.Session() as sess:
            ckpt = tf.train.get_checkpoint_state(model_save_path)
            saver = tf.train.Saver()
            # Hard-coded evaluation mode: 'check_loss' or 'show image'.
            test_case = 'show image'
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
                # Evaluation only runs when this module is the main script.
                if __name__ == '__main__':
                    if test_case == 'check_loss':
                        count = 0
                        for ys in train.iterate_minibatch(test_data, batch_size, shuffle=True):
                            # Prepare only the current mini-batch so that the
                            # predictions line up with `ys` in `performance`.
                            xs_l, kspace_l, mask_l, ys_l, k_full_l = train.prep_input(ys, mask)
                            loss_value, y_pred = sess.run(
                                [loss, y],
                                feed_dict={y_: ys_l, mask_p: mask_l,
                                           kspace_p: kspace_l, kspace_full: k_full_l})
                            print("The loss of No.{} test data = {}".format(count + 1, loss_value))

                            y_c = real2complex(y_pred)
                            xs_c = real2complex(xs_l)
                            (base_mse, test_mse, base_psnr,
                             test_psnr, base_ssim, test_ssim) = performance(xs_c, y_c, ys)
                            _print_eval_metrics(loss_value, base_mse, test_mse, base_psnr,
                                                test_psnr, base_ssim, test_ssim)
                            count += 1
                    elif test_case == 'show image':
                        project_root = '.'

                        quantization_save_path = join(project_root, 'result/quantization/%s' % model_file)
                        if not os.path.isdir(quantization_save_path):
                            os.makedirs(quantization_save_path)
                        Test_MSE, Test_PSNR, Test_SSIM = [], [], []
                        Base_MSE, Base_PSNR, Base_SSIM = [], [], []
                        # Evaluate the first 100 samples one at a time; stop
                        # early (instead of raising IndexError) on smaller sets.
                        for order in range(min(100, len(test_data))):
                            # Add a leading batch axis of size 1.
                            ys = test_data[order][np.newaxis, :]
                            xs_l, kspace_l, mask_l, ys_l, k_full_l = train.prep_input(ys, mask)
                            time_start = time.time()
                            loss_value, y_pred = sess.run(
                                [loss, y],
                                feed_dict={y_: ys_l, mask_p: mask_l,
                                           kspace_p: kspace_l, kspace_full: k_full_l})
                            time_end = time.time()
                            y_pred_new = real2complex(y_pred)
                            xs = real2complex(xs_l)

                            (base_mse, test_mse, base_psnr,
                             test_psnr, base_ssim, test_ssim) = performance(xs, y_pred_new, ys)

                            print("test time:\t\t{:.6f}".format(time_end - time_start))
                            _print_eval_metrics(loss_value, base_mse, test_mse, base_psnr,
                                                test_psnr, base_ssim, test_ssim)

                            # NOTE(review): MSE values are stored as 6-decimal
                            # strings while PSNR/SSIM stay numeric, so the
                            # *_MSE .mat files hold character data. Kept as-is;
                            # confirm whether downstream readers expect this.
                            Test_MSE.append("%.6f" % test_mse)
                            Test_PSNR.append(test_psnr)
                            Test_SSIM.append(test_ssim)

                            Base_MSE.append("%.6f" % base_mse)
                            Base_PSNR.append(base_psnr)
                            Base_SSIM.append(base_ssim)

                        # One .mat file per metric: <Stem>_<lambda_num>.mat
                        # containing a single variable named by `key`.
                        metric_tables = (
                            ('Test_MSE', 'test_mse', Test_MSE),
                            ('Test_PSNR', 'test_psnr', Test_PSNR),
                            ('Test_SSIM', 'test_ssim', Test_SSIM),
                            ('Base_MSE', 'base_mse', Base_MSE),
                            ('Base_PSNR', 'base_psnr', Base_PSNR),
                            ('Base_SSIM', 'base_ssim', Base_SSIM),
                        )
                        for file_stem, key, values in metric_tables:
                            scio.savemat(
                                join(quantization_save_path, '%s_%s' % (file_stem, lambda_num)),
                                {key: values})

            else:
                print("No checkpoint file found")