예제 #1
0
def visualize_contours(image, inputs, preds, configs, save_path=None):
    """Overlay initial (red) and predicted (blue) contour points on an image.

    Args:
        image: image batch; element 0 is transposed with axes (1, 2, 0)
            (presumably CHW -> HWC) and displayed in grayscale.
        inputs: sequence of initial contours; each one is squeezed on dim 0,
            scaled by ``configs.im_size`` and plotted from its x/y columns.
        preds: sequence of predicted contours with the same length and
            layout as ``inputs``.
        configs: object providing ``im_size``.
        save_path: if truthy, the figure is saved to this path.

    The figure is always closed before returning.
    """
    background = np.copy(image)[0].transpose([1, 2, 0])
    _, ax = plt.subplots(figsize=(7, 7))
    ax.imshow(background, cmap=plt.cm.gray)

    assert len(inputs) == len(preds)
    for initial, predicted in zip(inputs, preds):
        init_pts = to_numpy(initial.squeeze(0) * configs.im_size)
        pred_pts = to_numpy(predicted.squeeze(0) * configs.im_size)
        ax.plot(init_pts[:, 0], init_pts[:, 1], 'ro', markersize=0.5)
        ax.plot(pred_pts[:, 0], pred_pts[:, 1], 'bo', markersize=0.5)

    if save_path:
        plt.savefig(save_path)
    plt.close()
 def test(epoch):
     """Render and save the model output for ``fixed_latent`` on test epochs.

     Does nothing unless ``epoch`` is a multiple of ``args.epochs`` or of
     ``args.test_freq``. Otherwise the first element of the model output is
     plotted against ``output_arr``. When ``args.animate`` is set, it is drawn
     as an animation frame numbered ``epoch // args.test_freq``. The output is
     then saved to ``<run_dir>/epoch<epoch>.npy``.

     Args:
         epoch: current epoch number.
     """
     if not (epoch % args.epochs == 0 or epoch % args.test_freq == 0):
         return
     prediction = to_numpy(model(fixed_latent))
     plot_kwargs = dict(plot_fn='imshow',
                        cmap=args.cmap,
                        same_scale=args.same_scale)
     if args.animate:
         frame = epoch // args.test_freq
         plot_prediction_det_animate2(run_dir, output_arr, prediction[0],
                                      epoch, args.idx, frame, **plot_kwargs)
     else:
         plot_prediction_det(run_dir, output_arr, prediction[0], epoch,
                             args.idx, **plot_kwargs)
     np.save(run_dir + f'/epoch{epoch}.npy', prediction[0])
예제 #3
0
def save_error_mask(error_mask, save_path):
    """Write ``error_mask`` to ``save_path`` as a borderless 'jet' colour map.

    Args:
        error_mask: image-like array to render. A ``torch.Tensor`` is first
            converted with ``to_numpy``.
        save_path: destination file passed to ``plt.savefig``.

    Draws on the current pyplot figure and clears it afterwards. The figure
    itself is left open for reuse.
    """
    mask = (to_numpy(error_mask)
            if isinstance(error_mask, torch.Tensor) else error_mask)
    plt.axis('off')
    # Padding is cropped when saving, so the figure size is not derived from
    # the shape of error_mask.
    plt.imshow(mask, cmap='jet')
    plt.savefig(save_path, bbox_inches='tight', pad_inches=0.0)
    plt.clf()
예제 #4
0
    def test(epoch):
        """Evaluate the physics-constrained surrogate on ``test_loader``.

        The per-batch test loss is the PDE residual (constitutive plus
        continuity constraints) plus the boundary-condition penalty weighted
        by ``args.weight_bound``. It is averaged over batches.

        The per-channel relative L2 error and the R^2 score (against
        ``y_test_variation``) are accumulated over the whole test set. On
        plotting epochs (multiples of ``args.plot_freq`` and the final epoch),
        a few random predictions from the last batch are plotted: 6 on the
        final epoch, otherwise 2. The metrics are printed and, every
        ``args.log_freq`` epochs, appended to ``logger``.

        Args:
            epoch: current epoch number.
        """
        model.eval()
        loss_test = 0.
        relative_l2, err2 = [], []
        for batch_idx, (input, target) in enumerate(test_loader):
            input, target = input.to(device), target.to(device)
            output = model(input)
            loss_pde = constitutive_constraint(input, output, sobel_filter) \
                + continuity_constraint(output, sobel_filter)
            loss_dirichlet, loss_neumann = boundary_condition(output)
            loss_boundary = loss_dirichlet + loss_neumann
            loss = loss_pde + loss_boundary * args.weight_bound
            loss_test += loss.item()
            # sum over H, W --> (B, C)
            err2_sum = torch.sum((output - target)**2, [-1, -2])
            relative_l2.append(torch.sqrt(err2_sum /
                                          (target**2).sum([-1, -2])))
            err2.append(err2_sum)
            # plot predictions from the last batch on plotting epochs
            if (epoch % args.plot_freq == 0 or epoch == args.epochs) and \
                    batch_idx == len(test_loader) - 1:
                n_samples = 6 if epoch == args.epochs else 2
                idx = torch.randperm(input.size(0))[:n_samples]
                samples_output = output.data.cpu()[idx].numpy()
                samples_target = target.data.cpu()[idx].numpy()
                for i in range(n_samples):
                    print('epoch {}: plotting prediction {}'.format(epoch, i))
                    plot_prediction_det(args.pred_dir,
                                        samples_target[i],
                                        samples_output[i],
                                        epoch,
                                        i,
                                        plot_fn=args.plot_fn)

        loss_test /= (batch_idx + 1)
        relative_l2 = to_numpy(torch.cat(relative_l2, 0).mean(0))
        r2_score = 1 - to_numpy(torch.cat(err2, 0).sum(0)) / y_test_variation
        print(f"Epoch: {epoch}, test r2-score:  {r2_score}")
        print(f"Epoch: {epoch}, test relative-l2:  {relative_l2}")
        # Report the averaged *test* loss. The component losses shown next to
        # it are those of the last test batch.
        print(f'Epoch {epoch}: test loss: {loss_test:.6f}, loss_pde: {loss_pde.item():.6f}, '
              f'dirichlet {loss_dirichlet:.6f}, nuemann {loss_neumann.item():.6f}')

        if epoch % args.log_freq == 0:
            logger['loss_test'].append(loss_test)
            logger['r2_test'].append(r2_score)
            logger['nrmse_test'].append(relative_l2)
    def test(epoch):
        """Evaluate the model on ``test_loader`` using a plain MSE loss.

        Averages the per-batch MSE and computes, over the whole test set, the
        per-channel relative L2 error and the R^2 score (against
        ``y_test_variation``). These are printed and, every ``args.log_freq``
        epochs, appended to ``logger``. On plotting epochs (multiples of
        ``args.plot_freq`` and the final epoch), a few random predictions
        from the last batch are plotted: 6 on the final epoch, otherwise 2.

        Args:
            epoch: current epoch number.
        """
        model.eval()
        total_loss = 0.
        rel_err_batches, sq_err_batches = [], []
        for batch_idx, (x, y) in enumerate(test_loader):
            x, y = x.to(device), y.to(device)
            y_hat = model(x)
            total_loss += F.mse_loss(y_hat, y).item()
            # squared error summed over the two trailing dims (H, W) -> (B, C)
            sq_err = torch.sum((y_hat - y)**2, [-1, -2])
            rel_err_batches.append(
                torch.sqrt(sq_err / (y**2).sum([-1, -2])))
            sq_err_batches.append(sq_err)

            is_plot_epoch = epoch % args.plot_freq == 0 or epoch == args.epochs
            if is_plot_epoch and batch_idx == len(test_loader) - 1:
                n_show = 6 if epoch == args.epochs else 2
                picks = torch.randperm(x.size(0))[:n_show]
                shown_pred = y_hat.data.cpu()[picks].numpy()
                shown_true = y.data.cpu()[picks].numpy()
                for k in range(n_show):
                    print(f'epoch {epoch}: plotting prediction {k}')
                    plot_prediction_det(args.pred_dir,
                                        shown_true[k],
                                        shown_pred[k],
                                        epoch,
                                        k,
                                        plot_fn=args.plot_fn)

        total_loss /= (batch_idx + 1)
        rel_l2 = to_numpy(torch.cat(rel_err_batches, 0).mean(0))
        r2 = 1 - to_numpy(torch.cat(sq_err_batches, 0).sum(0)) / y_test_variation
        print(f"Epoch: {epoch}, test r2-score:  {r2}, relative-l2:  {rel_l2}")
        if epoch % args.log_freq == 0:
            logger['loss_test'].append(total_loss)
            logger['r2_test'].append(r2)
            logger['nrmse_test'].append(rel_l2)
예제 #6
0
    def test_metric(self, handle_nan=True):
        """Score the predictive mean on the test set and save the metrics.

        For every test batch, ``self.model.predict`` is called with
        ``self.n_samples`` samples at ``self.temperature``; only the returned
        predictive mean is scored. The per-(sample, channel) relative L2
        errors are averaged over the test set, and the R^2 score is computed
        against ``self.y_test_variation``. Both are printed and written to
        ``nrmse_test.txt`` / ``r2_test.txt`` in ``self.post_dir``.

        Args:
            handle_nan: if True, test samples whose predictive mean contains
                any NaN or infinite value are excluded from the metrics.
                Their count and the fraction of the test set they represent
                are printed and saved to ``log_stats.txt``.
        """
        relative_l2, err2 = [], []
        num_nan_inf = 0
        for input, target in self.test_loader:
            input, target = input.to(self.device), target.to(self.device)
            pred_mean, _ = self.model.predict(
                input, n_samples=self.n_samples, temperature=self.temperature)
            if handle_nan:
                # A sample is abnormal when any predicted value is NaN/+-inf.
                # Bool masks are inverted with `~`: PyTorch rejects `1 - mask`
                # for bool tensors.
                abnormal = (~torch.isfinite(pred_mean)).flatten(1).any(1)
                normal = ~abnormal
                pred_mean, target = pred_mean[normal], target[normal]
                # keep a plain int so printing/np.savetxt work for CUDA tensors
                num_nan_inf += int(abnormal.sum().item())

            # squared error summed over the two trailing dims -> (B, C)
            err2_sum = torch.sum((pred_mean - target)**2, [-1, -2])
            relative_l2.append(torch.sqrt(err2_sum /
                                          (target**2).sum([-1, -2])))
            err2.append(err2_sum)

        relative_l2 = to_numpy(torch.cat(relative_l2, 0).mean(0))
        r2_score = 1 - to_numpy(torch.cat(err2,
                                          0).sum(0)) / self.y_test_variation
        print(relative_l2)
        print(r2_score)
        np.savetxt(self.post_dir + '/nrmse_test.txt', relative_l2)
        np.savetxt(self.post_dir + '/r2_test.txt', r2_score)
        if handle_nan:
            n_test = len(self.test_loader.dataset)
            abnormal_rate = num_nan_inf / n_test
            print(f'num_nan_inf: {num_nan_inf}')
            print(f'abnormal rate: {abnormal_rate:.6f}')
            np.savetxt(self.post_dir + '/log_stats.txt',
                       [num_nan_inf, n_test, abnormal_rate])
예제 #7
0
def cal_ssim(clean, noisy, normalized=True):
    """Compute the SSIM of each image pair in a batch.

    Uses ``compare_ssim`` (``skimage.measure.compare_ssim``) on the first
    channel of every image, with ``data_range=255``.

    Args:
        clean (Tensor): reference images, (B, 1, H, W).
        noisy (Tensor): images to score, (B, 1, H, W).
        normalized (bool): if True, the inputs are expected in [-0.5, 0.5]
            and are mapped to [0, 255] (then clamped) first. Otherwise they
            are expected to already be in [0, 255].

    Returns:
        numpy array of shape (B,) holding one SSIM value per image.
    """
    if normalized:
        clean, noisy = (t.add(0.5).mul(255).clamp(0, 255)
                        for t in (clean, noisy))

    clean_np = to_numpy(clean)
    noisy_np = to_numpy(noisy)
    scores = []
    for k in range(clean_np.shape[0]):
        scores.append(compare_ssim(clean_np[k, 0], noisy_np[k, 0],
                                   data_range=255))
    return np.array(scores)
예제 #8
0
def visualize_contours_using_snake(image,
                                   inputs,
                                   vf,
                                   configs,
                                   save_path=None,
                                   corner_path=None):
    """Refine each initial contour with ``active_contour`` and plot the result.

    Args:
        image: image batch; element 0 is transposed with axes (1, 2, 0)
            (presumably CHW -> HWC) and used both as the backdrop and as the
            snake's input image.
        inputs: iterable of initial contours; each one is squeezed on dim 0
            and scaled by ``configs.im_size`` before refinement.
        vf: vector-field tensor; element 0 is transposed like ``image``.
        configs: object providing ``im_size``.
        save_path: if truthy, the figure is saved to this path.
        corner_path: optional path of a corner-map image. Pixels >= 128
            inside the top-left ``im_size`` x ``im_size`` region are drawn as
            red dots. Nothing is drawn when it is not given.

    The refined contours are drawn as blue dots, and the figure is always
    closed before returning.
    """
    img = np.copy(image)[0].transpose([1, 2, 0])
    vf = vf.numpy()[0].transpose([1, 2, 0])
    fig, ax = plt.subplots(figsize=(7, 7))
    ax.imshow(img, cmap=plt.cm.gray)

    for init in inputs:
        pred = active_contour(img,
                              to_numpy(init.squeeze(0) * configs.im_size),
                              vf,
                              w_edge=0.5,
                              gamma=0.1)
        ax.plot(pred[:, 0], pred[:, 1], 'bo', markersize=1.0)

    # corner_path defaults to None, so it is only read when provided.
    if corner_path:
        corner_img = imread(corner_path)
        size = configs.im_size
        # One plot call for all corner pixels instead of one per pixel.
        # Assumes a single-channel corner map - TODO confirm.
        ys, xs = np.nonzero(corner_img[:size, :size] >= 128)
        if xs.size:
            ax.plot(xs, ys, 'ro', markersize=0.5)

    if save_path:
        plt.savefig(save_path)
    plt.close(fig)
예제 #9
0
def visualize_vector_fields(room_map,
                            vf,
                            inputs,
                            preds,
                            configs,
                            save_path=None):
    """Draw a room map with a subsampled vector field and contour points.

    Args:
        room_map: tensor whose element 0 is scaled by 20 and resized to
            64x64 for display.
        vf: vector-field tensor; element 0 is transposed with axes (1, 2, 0).
            Every 4th vector in each direction is drawn with ``quiver``.
        inputs: sequence of initial contours (red), scaled by
            ``configs.im_size`` and then divided by 4 for display.
        preds: sequence of predicted contours (blue), same length as
            ``inputs``.
        configs: object providing ``im_size``.
        save_path: if truthy, the figure is saved (tight bounding box).

    The figure is always closed before returning.
    """
    room = room_map.numpy()[0]
    field = vf.numpy()[0].transpose([1, 2, 0])
    _, ax = plt.subplots()
    ax.set_axis_off()
    plt.imshow(imresize(room * 20, [64, 64]))
    # Every 4th vector in each direction. The y component is negated,
    # presumably so arrows follow image row order on screen.
    plt.quiver(field[::4, ::4, 0], -field[::4, ::4, 1], units='width')
    assert len(inputs) == len(preds)

    for contour_in, contour_out in zip(inputs, preds):
        start = to_numpy(contour_in.squeeze(0) * configs.im_size)
        end = to_numpy(contour_out.squeeze(0) * configs.im_size)
        # Divided by 4, presumably to map im_size pixel coordinates onto the
        # 64x64 display.
        ax.plot(start[:, 0] / 4, start[:, 1] / 4, 'ro', markersize=0.5)
        ax.plot(end[:, 0] / 4, end[:, 1] / 4, 'bo', markersize=0.5)

    if save_path:
        plt.savefig(save_path, bbox_inches='tight')
    plt.close()
예제 #10
0
                    target_y**2, [-1, -2])
            return relative_x, relative_y

        # lr scheduling: the scheduler is queried with the fraction of all
        # training steps completed so far and returns the lr for this step.
        step = (epoch - 1) * len(train_loader) + batch_idx
        pct = step / total_steps
        lr = scheduler.step(pct)
        adjust_learning_rate(optimizer, lr)
        # clip the total gradient norm to 10 before the parameter update
        torch.nn.utils.clip_grad_norm_(model.parameters(), 10)
        optimizer.step()
        loss_train += loss.item()

    # NOTE(review): if batch_idx comes from a 0-based enumerate() over
    # train_loader (the loop header is outside this excerpt), this divides by
    # one less than the number of batches. It also raises ZeroDivisionError
    # for a single-batch loader; `batch_idx + 1` is probably intended.
    loss_train /= batch_idx

    rel2_cat = torch.cat(relative_l2, 0)  # torch.Size([1344, 2])
    # mean relative L2 error per column: column 0 -> ux, column 1 -> uy
    re_l2 = to_numpy(torch.mean(rel2_cat, 0))
    relative_ux = re_l2[0]
    relative_uy = re_l2[1]

    print(
        f'Epoch {epoch}: training loss: {loss_train:.6f} ' \
        f'relative-ux: {relative_ux: .5f}, relative-uy: {relative_uy: .5f}')

    # append the same epoch summary to a running log file
    logf = open(args.train_dir + "/running_log.txt", 'a')
    logf.write(
        f'Epoch {epoch}: training loss: {loss_train:.6f}, ' \
        f'relative-ux: {relative_ux: .5f}, relative-uy: {relative_uy: .5f} \n')
    logf.close()

    if epoch % args.log_freq == 0:
        logger['loss_train'].append(loss_train)
예제 #11
0
            # e2 = torch.sum((output[:,:2] - target[:,:2]) ** 2 , [-1, -2])
            # torch.sqrt(e2 / (target[:,:2] ** 2).sum([-1, -2]))
            # squared error summed over the two trailing dims -> one value
            # per (sample, channel)
            err2_sum = torch.sum((output - target)**2, [-1, -2])
            # relative_l2.append(err2_sum)
            relative_l2.append(torch.sqrt(err2_sum /
                                          (target**2).sum([-1, -2])))
            # lr scheduling: the scheduler is queried with the fraction of
            # all training steps completed and returns the lr for this step
            step = (epoch - 1) * len(train_loader) + batch_idx
            pct = step / total_steps
            lr = scheduler.step(pct)
            adjust_learning_rate(optimizer, lr)
            optimizer.step()
            loss_train += loss.item()

        # NOTE(review): if batch_idx comes from a 0-based enumerate() over
        # train_loader (the loop header is outside this excerpt), this
        # divides by one less than the number of batches. It also raises
        # ZeroDivisionError for a single-batch loader; `batch_idx + 1` is
        # probably intended.
        loss_train /= batch_idx
        relative_l2 = to_numpy(torch.cat(relative_l2, 0).mean(0))
        # channels [:2] are averaged into relative_u and the remaining ones
        # into relative_s (presumably displacement vs. stress - TODO confirm)
        relative_u = np.mean(relative_l2[:2])
        relative_s = np.mean(relative_l2[2:])
        # lr of the last optimisation step of this epoch
        print(f'Epoch {epoch}, lr {lr:.6f}')
        # print(f'Epoch {epoch}: training loss: {loss_train:.6f}, pde1: {loss_pde1:.6f}, pde2: {loss_pde2:.6f}, '\
        #     # f'dirichlet {loss_dirichlet:.6f}, nuemann {loss_neumann:.6f}')
        #     f'boundary: {loss_boundary:.6f}, relative-u: {relative_u: .5f}, relative_s: {relative_s: .5f}')
        print(f'Epoch {epoch}: training loss: {loss_train:.6f}, pde1: {loss_pde1:.6f} '\
            f'boundary: {loss_boundary:.6f}, relative-u: {relative_u: .5f}')
        if epoch % args.log_freq == 0:
            logger['loss_train'].append(loss_train)
            logger['loss_pde1'].append(loss_pde1)
            # logger['loss_pde2'].append(loss_pde2)
            logger['loss_b'].append(loss_boundary)
            logger['u_l2loss'].append(relative_u)
            logger['s_l2loss'].append(relative_s)
예제 #12
0
        # time per 512x512 (training image is 256x256)
        time_taken /= (n_test_samples / multiplier)
        psnr = psnr / n_test_samples
        ssim = ssim / n_test_samples

        result = {'psnr': psnr, 'ssim': ssim, 'time': time_taken}
        case.update(result)
        pprint(result)
        logger.update({f'noise{noise_level}_{image_type}': case})

        # fixed test: (n_plots, 4, 1, 256, 256)
        for i in range(n_plots):
            print(f'plot {i}-th denoising: [noisy, denoised, clean]')
            fixed_denoised = model(fixed_test_noisy[i])
            fixed_noisy_stitched = stitch_pathes(to_numpy(fixed_test_noisy[i]))
            fixed_denoised_stitched = stitch_pathes(to_numpy(fixed_denoised))
            fixed_clean_stitched = stitch_pathes(to_numpy(fixed_test_clean[i]))
            plot_row(np.concatenate(
                (fixed_noisy_stitched, fixed_denoised_stitched,
                 fixed_clean_stitched)),
                     test_case_dir,
                     f'denoising{i}',
                     same_range=True,
                     plot_fn='imshow',
                     cmap=cmap,
                     colorbar=False)

        with open(
                test_dir + "/results_{}.txt".format(
                    'cpu' if args_test.no_cuda else 'gpu'), 'w') as args_file:
def _packed_data_and_sizes(tensor, lengths):
    """Pack a batch-first padded tensor and return ``(data, batch_sizes)``.

    Reads the ``PackedSequence`` attributes instead of tuple-unpacking it.
    Since PyTorch 1.1, ``PackedSequence`` also carries
    ``sorted_indices``/``unsorted_indices``, so ``a, b = packed`` fails.
    """
    packed = pack_padded_sequence(tensor, lengths, batch_first=True)
    return packed.data, packed.batch_sizes


def _format_gaze_metrics(auc_scores, distances, inout_groundtruth,
                         inout_pred):
    """Return one tab-separated line of the gaze-following metrics.

    Args:
        auc_scores: per-frame AUC values collected so far.
        distances: per-frame L2 distances between ground-truth and predicted
            gaze points.
        inout_groundtruth: accumulated in/out ground-truth labels.
        inout_pred: accumulated in/out predictions.

    Returns:
        The mean AUC, the mean distance and ``evaluation.ap`` of the in/out
        predictions, each formatted with 4 decimals.
    """
    return ("\tAUC:{:.4f}"
            "\tdist:{:.4f}"
            "\tin vs out AP:{:.4f}".format(
                torch.mean(torch.tensor(auc_scores)),
                torch.mean(torch.tensor(distances)),
                evaluation.ap(inout_groundtruth, inout_pred)))


def test():
    """Evaluate a trained ModelSpatioTemporal on the VideoAttentionTarget val set.

    Loads the weights from ``args.model_weights`` and runs every validation
    video through the model, ``chunk_size`` time steps at a time, carrying
    the recurrent state between chunks. For every frame whose ground truth
    is 'inside', it collects the heatmap AUC and the L2 distance between
    the ground-truth gaze point and the heatmap argmax. In/out labels and
    predictions are collected for every frame. Running metrics are printed
    after each batch and a summary at the end.

    Relies on module-level names defined outside this function (e.g. ``args``
    and ``output_resolution``).
    """
    transform = _get_transform()

    # Prepare data
    print("Loading Data")
    val_dataset = VideoAttTarget_video(videoattentiontarget_val_data,
                                       videoattentiontarget_val_label,
                                       transform=transform,
                                       test=True,
                                       seq_len_limit=50)
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=0,
                                             collate_fn=video_pack_sequences)

    # Define device
    device = torch.device('cuda', args.device)

    # Load model
    num_lstm_layers = 2
    print("Constructing model")
    model = ModelSpatioTemporal(num_lstm_layers=num_lstm_layers)
    model.cuda(device)

    print("Loading weights")
    model_dict = model.state_dict()
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    snapshot = torch.load(args.model_weights)
    snapshot = snapshot['model']
    # Overlay the checkpoint onto the freshly initialised state, so that keys
    # absent from the snapshot keep their initial values.
    model_dict.update(snapshot)
    model.load_state_dict(model_dict)

    print('Evaluation in progress ...')
    model.train(False)
    AUC = []
    in_vs_out_groundtruth = []
    in_vs_out_pred = []
    distance = []
    chunk_size = 3  # packed time steps fed to the model per forward pass
    with torch.no_grad():
        for batch_val, (img_val, face_val, head_channel_val, gaze_heatmap_val,
                        cont_gaze, inout_label_val,
                        lengths_val) in enumerate(val_loader):
            print('\tprogress = ', batch_val + 1, '/', len(val_loader))
            # Packed data stores time step 0 of every sequence, then step 1,
            # and so on. X_pad_sizes[t] is the number of sequences still
            # active at step t.
            X_pad_data_img, X_pad_sizes = _packed_data_and_sizes(
                img_val, lengths_val)
            X_pad_data_head, _ = _packed_data_and_sizes(
                head_channel_val, lengths_val)
            X_pad_data_face, _ = _packed_data_and_sizes(face_val, lengths_val)
            Y_pad_data_cont_gaze, _ = _packed_data_and_sizes(
                cont_gaze, lengths_val)
            Y_pad_data_inout, _ = _packed_data_and_sizes(
                inout_label_val, lengths_val)

            # zero initial recurrent states: (num_layers, batch_size, 512, 7, 7)
            hidden_shape = (num_lstm_layers, args.batch_size, 512, 7, 7)
            hx = (torch.zeros(hidden_shape).cuda(device),
                  torch.zeros(hidden_shape).cuda(device))
            last_index = 0
            previous_hx_size = args.batch_size

            for i in range(0, lengths_val[0], chunk_size):
                X_pad_sizes_slice = X_pad_sizes[i:i + chunk_size].cuda(device)
                # number of packed rows covered by these time steps
                curr_length = np.sum(X_pad_sizes_slice.cpu().detach().numpy())
                rows = slice(last_index, last_index + curr_length)
                X_pad_data_slice_img = X_pad_data_img[rows].cuda(device)
                X_pad_data_slice_head = X_pad_data_head[rows].cuda(device)
                X_pad_data_slice_face = X_pad_data_face[rows].cuda(device)
                Y_pad_data_slice_cont_gaze = Y_pad_data_cont_gaze[rows].cuda(
                    device)
                Y_pad_data_slice_inout = Y_pad_data_inout[rows].cuda(device)
                last_index += curr_length

                # Packed batch sizes are non-increasing, so only the first
                # X_pad_sizes_slice[0] sequences continue into this chunk.
                # Truncate the carried state to them and detach it.
                n_active = min(X_pad_sizes_slice[0], previous_hx_size)
                prev_hx = (hx[0][:, :n_active, :, :, :].detach(),
                           hx[1][:, :n_active, :, :, :].detach())

                # forward pass
                deconv, inout_val, hx = model(X_pad_data_slice_img,
                                              X_pad_data_slice_head,
                                              X_pad_data_slice_face,
                                              hidden_scene=prev_hx,
                                              batch_sizes=X_pad_sizes_slice)

                for b_i in range(len(Y_pad_data_slice_cont_gaze)):
                    # AUC and distance are only evaluated for 'inside' frames
                    if not Y_pad_data_slice_inout[b_i]:
                        continue
                    gaze_x = Y_pad_data_slice_cont_gaze[b_i, 0]
                    gaze_y = Y_pad_data_slice_cont_gaze[b_i, 1]

                    # AUC: area under the ROC curve of the upsampled heatmap
                    # against a binarised Gaussian label drawn at the GT point
                    multi_hot = torch.zeros(output_resolution,
                                            output_resolution)
                    multi_hot = imutils.draw_labelmap(
                        multi_hot,
                        [gaze_x * output_resolution,
                         gaze_y * output_resolution],
                        3,
                        type='Gaussian')
                    multi_hot = (multi_hot > 0).float() * 1  # binary GT labels
                    multi_hot = misc.to_numpy(multi_hot)

                    scaled_heatmap = imresize(
                        deconv[b_i].squeeze(),
                        (output_resolution, output_resolution),
                        interp='bilinear')
                    AUC.append(evaluation.auc(scaled_heatmap, multi_hot))

                    # distance: L2 distance between the GT point and the
                    # heatmap argmax divided by output_resolution
                    pred_x, pred_y = evaluation.argmax_pts(
                        deconv[b_i].squeeze())
                    norm_p = [
                        pred_x / output_resolution,
                        pred_y / output_resolution
                    ]
                    distance.append(evaluation.L2_dist(
                        Y_pad_data_slice_cont_gaze[b_i], norm_p).item())

                # in vs out classification
                in_vs_out_groundtruth.extend(
                    Y_pad_data_slice_inout.cpu().numpy())
                in_vs_out_pred.extend(inout_val.cpu().numpy())

                previous_hx_size = X_pad_sizes_slice[-1]

            # Running metrics are best-effort progress output. Report why
            # they could not be computed instead of silently hiding it.
            try:
                print(_format_gaze_metrics(AUC, distance,
                                           in_vs_out_groundtruth,
                                           in_vs_out_pred))
            except Exception as err:
                print(f"\t(running metrics unavailable: {err!r})")

    print("Summary ")
    print(_format_gaze_metrics(AUC, distance, in_vs_out_groundtruth,
                               in_vs_out_pred))
예제 #14
0
    def test(epoch):
        """Evaluate the conditional generative surrogate on ``test_loader``.

        For each batch the prediction is the mean of 20 samples on every 10th
        epoch, and a single generated sample otherwise. The test loss is the
        PDE loss (residual plus boundary penalty weighted by
        ``args.weight_bound``), scaled by ``args.beta``, plus the negative
        predictive entropy. The entropy is estimated from the log-likelihood
        returned by ``model.generate``.

        Relative L2 error and R^2 (against ``y_test_variation``) are
        accumulated over the whole set. On plotting epochs, a few random
        inputs of the first batch are plotted with their predictive mean and
        variance, and with posterior samples.

        Args:
            epoch: current epoch number.
        """
        model.eval()
        loss_test = 0.
        relative_l2, err2 = [], []
        for batch_idx, (input, target) in enumerate(test_loader):
            input, target = input.to(device), target.to(device)
            if epoch % 10 == 0:
                # every 10 epochs evaluate the mean accurately
                output_samples = model.sample(input,
                                              n_samples=20,
                                              temperature=1.0)
                output = output_samples.mean(0)
                # The sample mean has no likelihood of its own; draw one
                # generated sample only to estimate the entropy term.
                _, log_likelihood = model.generate(input)
            else:
                # evaluate with one output sample
                # assumes generate() returns (sample, log-likelihood) - TODO confirm
                output, log_likelihood = model.generate(input)

            residual_norm = constitutive_constraint(input, output, sobel_filter) \
                + continuity_constraint(output, sobel_filter)
            loss_dirichlet, loss_neumann = boundary_condition(output)
            loss_boundary = loss_dirichlet + loss_neumann
            loss_pde = residual_norm + loss_boundary * args.weight_bound
            # evaluate predictive entropy: E_p(y|x) [log p(y|x)], divided by
            # ln 2 and n_out_pixels (bits per output pixel if log_likelihood
            # is a natural log)
            neg_entropy = log_likelihood.mean() / math.log(2.) / n_out_pixels
            loss = loss_pde * args.beta + neg_entropy

            loss_test += loss.item()
            err2_sum = torch.sum((output - target)**2, [-1, -2])
            relative_l2.append(torch.sqrt(err2_sum /
                                          (target**2).sum([-1, -2])))
            err2.append(err2_sum)

            # plot predictions from the first batch on plotting epochs
            if (epoch % args.plot_freq == 0
                    or epoch % args.epochs == 0) and batch_idx == 0:
                n_samples = 6 if epoch == args.epochs else 2
                idx = np.random.permutation(input.size(0))[:n_samples]
                samples_target = target.data.cpu()[idx].numpy()

                # idx is shorter than n_samples when the batch is small
                for i, sample_idx in enumerate(idx):
                    print('epoch {}: plotting prediction {}'.format(epoch, i))
                    pred_mean, pred_var = model.predict(input[[sample_idx]])
                    plot_prediction_bayes2(args.pred_dir,
                                           samples_target[i],
                                           pred_mean[0],
                                           pred_var[0],
                                           epoch,
                                           sample_idx,
                                           plot_fn='imshow',
                                           cmap='jet',
                                           same_scale=False)
                    # plot samples p(y|x) next to the target
                    samples_pred = model.sample(input[[sample_idx]],
                                                n_samples=15)[:, 0]
                    samples = torch.cat((target[[sample_idx]], samples_pred), 0)
                    save_samples(args.pred_dir,
                                 samples,
                                 epoch,
                                 sample_idx,
                                 'samples',
                                 nrow=4,
                                 heatmap=True,
                                 cmap='jet')

        loss_test /= (batch_idx + 1)
        relative_l2 = to_numpy(torch.cat(relative_l2, 0).mean(0))
        r2_score = 1 - to_numpy(torch.cat(err2, 0).sum(0)) / y_test_variation

        print(f"Epoch {epoch}: test r2-score:  {r2_score}")
        print(f"Epoch {epoch}: test relative l2:  {relative_l2}")
        # residual / boundary / entropy below are those of the last batch
        print(f'Epoch {epoch}: test loss: {loss_test:.6f}, residual: {residual_norm.item():.6f}, '
              f'boundary {loss_boundary.item():.6f}, neg entropy {neg_entropy.item():.6f}')

        if epoch % args.log_freq == 0:
            logger['loss_test'].append(loss_test)
            logger['r2_test'].append(r2_score)
            logger['nrmse_test'].append(relative_l2)
            logger['entropy_test'].append(-neg_entropy.item())
        tic = time.time()
        if image_type != 'Confocal_FISH':
            noisy_file = data_dir + f'/{image_type}/raw/19/HV110_P0500510000.png'
        else:
            noisy_file = data_dir + f'/{image_type}/raw/19/HV140_P100510000.png'
        clean_file = data_dir + f'/{image_type}/gt/19/avg50.png'
        noisy = four_crop(Image.open(noisy_file)).to(device)
        clean = four_crop(Image.open(clean_file)).to(device)
        print(noisy.shape)
        print(clean.shape)

        denoised = model(noisy.to(device))
        psnr = cal_psnr(clean, denoised).mean(0)
        ssim = cal_ssim(clean, denoised).mean(0)

        denoised = stitch_pathes(to_numpy(denoised))[0]
        noisy = stitch_pathes(to_numpy(noisy))[0]
        clean = stitch_pathes(to_numpy(clean))[0]

        print(image_type)
        print(f'psnr: {psnr}')
        print(f'ssim: {ssim}')
        if i < 3:
            psnr_3c += psnr.item()
            ssim_3c += ssim.item()
        save_file = test_dir + f'/{args_test.model}_noise{noise_level}_{image_type}_test19_idx0_denoised.png'
        plt.imsave(save_file, denoised, format="png", cmap="gray")
    print('Confocal BPAE avg of 3 channels')
    print(f'PSNR avg: {psnr_3c / 3.}')
    print(f'SSIM avg: {ssim_3c / 3.}')