def performance(xs, y, ys):
    """Score a baseline and a test reconstruction against a reference.

    Parameters
    ----------
    xs : array
        Baseline array compared against ``ys`` (presumably the zero-filled /
        undersampled input -- confirm against callers).
    y : array
        Test array compared against ``ys`` (presumably the reconstruction).
        Must be 4-D; its second axis is used as the frame count ``nt``.
    ys : array
        Reference array; passed as the first argument to every metric.

    Returns
    -------
    tuple
        ``(base_mse, test_mse, base_psnr, test_psnr, base_ssim, test_ssim)``.
        MSE and PSNR come from the ``mse`` / ``complex_psnr`` helpers applied
        to the full arrays (PSNR with ``peak='max'``). SSIM is the mean over
        the ``nt`` frames of the FIRST batch element only (index ``[0]``),
        computed on magnitudes cast to float64.

    Raises
    ------
    ValueError
        If ``y`` is not 4-D (from unpacking ``y.shape``).
    ZeroDivisionError
        If ``nt == 0`` (the SSIM sums are divided by ``nt``).
    """
    base_mse = mse(ys, xs)
    test_mse = mse(ys, y)
    base_psnr = complex_psnr(ys, xs, peak='max')
    test_psnr = complex_psnr(ys, y, peak='max')

    # Only the frame count is needed; unpacking all four axes keeps the
    # implicit check that ``y`` is 4-D.
    _, nt, _, _ = y.shape

    # NOTE(review): SSIM only looks at batch element 0, unlike MSE/PSNR
    # above; presumably callers pass batch size 1 -- confirm.
    # NOTE(review): no ``data_range`` is passed for float64 input; recent
    # scikit-image versions require it for floating-point images -- confirm
    # against the installed version.
    base_ssim = 0
    test_ssim = 0
    for i in range(nt):
        # The reference magnitude is shared by both SSIM calls; build it once.
        ref_mag = np.abs(ys[0][i]).astype('float64')
        base_ssim += compare_ssim(ref_mag, np.abs(xs[0][i]).astype('float64'))
        test_ssim += compare_ssim(ref_mag, np.abs(y[0][i]).astype('float64'))
    base_ssim /= nt
    test_ssim /= nt
    return base_mse, test_mse, base_psnr, test_psnr, base_ssim, test_ssim
示例#2
0
                k_u = Variable(k_und.type(Tensor))
                mask = Variable(mask.type(Tensor))
                gnd = Variable(x_gnd.type(Tensor))

                with torch.no_grad():
                    xf_out, img = xf_net(x_u, k_u, mask)

                test_loss.append(criterion(img['t%d' % (nc - 1)], gnd).item())

                im_und = from_tensor_format(x_und.numpy())
                im_gnd = from_tensor_format(x_gnd.numpy())
                im_rec = from_tensor_format(img['t%d' %
                                                (nc - 1)].data.cpu().numpy())

                for idx in range(im_und.shape[0]):
                    base_psnr.append(complex_psnr(im_gnd[idx], im_und[idx]))
                    epoch_psnr.append(complex_psnr(im_gnd[idx], im_rec[idx]))

            print("Epoch {}/{}".format(epoch + 1, num_epoch))
            print(" time: {}s".format(t_end - t_start))
            print(" training loss:\t\t{:.6f}".format(train_err))
            print(" testing loss:\t\t{:.6f}".format(np.mean(test_loss)))
            print(" base PSNR:\t\t{:.6f}".format(np.mean(base_psnr)))
            print(" test PSNR:\t\t{:.6f}".format(np.mean(epoch_psnr)))

            name = 'model_epoch_%d.npz' % epoch
            torch.save(xf_net.state_dict(), os.path.join(save_dir, name))
            print('model parameters saved at %s' %
                  os.path.join(save_dir, name))
            print('')
            if args.debug and validate_batches == 20:
                break

        # Testing
        vis = []
        test_err = 0
        base_psnr = 0
        test_psnr = 0
        test_batches = 0
        for im in iterate_minibatch(test, batch_size, shuffle=False):
            im_und, k_und, mask, im_gnd = prep_input(im, acc=acc)
            err, pred = val_fn(im_und, mask, k_und, im_gnd)
            test_err += err
            for im_i, und_i, pred_i in zip(im, from_lasagne_format(im_und),
                                           from_lasagne_format(pred)):
                base_psnr += complex_psnr(im_i, und_i, peak='max')
                test_psnr += complex_psnr(im_i, pred_i, peak='max')
            test_batches += 1

            if save_fig and test_batches % save_every == 0:
                vis.append((im[0], from_lasagne_format(pred)[0],
                            from_lasagne_format(im_und)[0],
                            from_lasagne_format(mask, mask=True)[0]))

            if args.debug and test_batches == 20:
                break

        t_end = time.time()

        train_err /= train_batches
        validate_err /= validate_batches
示例#4
0
                          batch_size)  # gnd(batch-size,2,160,224)
        test_err += float(criterion(pred, gnd_in))

        gnd_in = gnd_in.view(
            batch_size, 2, 2, Nx,
            Ny)  # (batch-size, 4, 160, 224)-->(batch-size, 2, 2, 160, 224)
        pred = pred.view(batch_size, 2, 2, Nx, Ny)
        gndb = from_tensor_format(gnd_in.data.cpu().numpy())
        predb = from_tensor_format(pred.data.cpu().numpy())

        for gnd_u, pred_u in zip(gndb, predb):
            test_psnr_fat, test_psnr_water, test_ssim_water, test_ssim_fat = 0, 0, 0, 0
            water_pred, fat_pred, water_gnd, fat_gnd = pred_u[0, :, :], pred_u[
                1, :, :], gnd_u[0, :, :], gnd_u[1, :, :]

            test_psnr_water = complex_psnr(water_gnd, water_pred, peak='max')
            test_psnr_fat = complex_psnr(fat_gnd, fat_pred, peak='max')
            vid.append((test_psnr_fat + test_psnr_water) / 2)

            test_ssim_water = compare_ssim(abs(water_pred),
                                           abs(water_gnd),
                                           multichannel=False)
            test_ssim_fat = compare_ssim(abs(fat_pred),
                                         abs(fat_gnd),
                                         multichannel=False)
            vib.append((test_ssim_fat + test_ssim_water) / 2)

            vis.append((gnd_u[0, :, :], gnd_u[1, :, :], pred_u[0, :, :],
                        pred_u[1, :, :]))

        test_batches += 1
示例#5
0
    # Compile function
    val_fn = compile_test_fn(net, net_config, args)

    # Create dataset
    test = create_dummy_data()

    vis = []
    base_psnr_list = []
    test_psnr_list = []
    for im in iterate_minibatch(test, batch_size, shuffle=False):
        im_und, k_und, mask, im_gnd = prep_input(im, acc=acc)
        err, pred = val_fn(im_und, mask, k_und, im_gnd)
        for im_i, und_i, pred_i in zip(im, from_lasagne_format(im_und),
                                       from_lasagne_format(pred)):
            base_psnr_list.append(complex_psnr(im_i, und_i, peak='max'))
            testpsnr = complex_psnr(im_i, pred_i, peak='max')
            test_psnr_list.append(testpsnr)

        for im_num in range(batch_size):
            vis.append((im[im_num], from_lasagne_format(pred)[im_num],
                        from_lasagne_format(im_und)[im_num],
                        from_lasagne_format(mask, mask=True)[im_num]))

    i = 0
    for im_i, pred_i, und_i, mask_i in vis:
        plt.imsave(join(save_dir, 'im{0}_test.png'.format(i)),
                   abs(np.concatenate([und_i, pred_i, im_i], 1)),
                   cmap='gray')
        plt.imsave(join(save_dir, 'mask{0}.png'.format(i)),
                   mask_i,
                break

        vis = []
        test_err = 0
        base_psnr = 0
        test_psnr = 0
        test_batches = 0
        for im in iterate_minibatch(test, batch_size, shuffle=False):
            im_und, k_und, mask, im_gnd = prep_input(im, acc=acc)

            err, pred = val_fn(im_und, mask, k_und, im_gnd)
            test_err += err
            for im_i, und_i, pred_i in zip(im,
                                           from_lasagne_format(im_und),
                                           from_lasagne_format(pred)):
                base_psnr += complex_psnr(im_i, und_i, peak='max')
                test_psnr += complex_psnr(im_i, pred_i, peak='max')
            test_batches += 1

            if save_fig and test_batches % save_every == 0:
                vis.append((im[0],
                            from_lasagne_format(pred)[0],
                            from_lasagne_format(im_und)[0],
                            from_lasagne_format(mask, mask=True)[0]))

            if args.debug and test_batches == 20:
                break

        t_end = time.time()

        train_err /= train_batches
示例#7
0
                masks = masks.permute(2, 0, 3, 4, 5, 1)
                x_gnd = x_gnd.permute(2, 0, 3, 4, 1)
                x_smaps = x_smaps.permute(2, 0, 3, 4, 5, 1)

                with torch.no_grad():
                    rec = model(x_und, k_und, masks, x_smaps, test=True)

                test_loss.append(criterion(rec + 1e-11, x_gnd).item())

                sense_recon = r2c(rec.data.to('cpu').numpy(), axis=-1)
                sense_gt = r2c(x_gnd.data.to('cpu').numpy(), axis=-1)
                sense_und = r2c(x_und.data.to('cpu').numpy(), axis=-1)

                for idx in range(x_gnd.shape[1]):
                    base_psnr.append(
                        complex_psnr(sense_gt[idx], sense_und[idx]))
                    test_psnr.append(
                        complex_psnr(sense_gt[idx], sense_recon[idx]))

            print("Epoch {}/{}".format(epoch + 1, n_epoch))
            print(" time: {}s".format(t_end - t_start))
            print(" training loss:\t\t{:.6f}".format(train_err))
            print(" testing loss:\t\t{:.6f}".format(np.mean(test_loss)))
            print(" base PSNR:\t\t{:.6f}".format(np.mean(base_psnr)))
            print(" test PSNR:\t\t{:.6f}".format(np.mean(test_psnr)))

            name = 'CTFNet_epoch_%d.npz' % epoch
            torch.save(model.state_dict(), os.path.join(save_dir, name))
            print('model parameters saved at %s' %
                  os.path.join(save_dir, name))
            print('')