예제 #1
0
def get_data_loader(config):
    """Create the validation/test data loader.

    Args:
        config: mapping that must provide ``'image_size'``, which is passed
            to ``Resize``.

    Returns:
        A ``torch.utils.data.DataLoader`` over ``CustomDataset("test.txt")``
        with batch size 1 and shuffling enabled.
    """
    data_transform = torchvision.transforms.Compose(
        [Resize(config['image_size']), ToTorchFormatTensor()])
    # NOTE(review): the sample list path is hard-coded; presumably resolved
    # relative to the working directory by CustomDataset — confirm.
    val_generator = CustomDataset("test.txt", transform=data_transform)
    val_loader = torch.utils.data.DataLoader(val_generator, batch_size=1,
                                             shuffle=True)
    # Only a validation loader is built here; the previous message wrongly
    # claimed a training loader was created as well.
    print('validation dataloader created')
    return val_loader
예제 #2
0
def get_data_loader(config):
    """Build the training and validation data loaders.

    Both splits share one preprocessing pipeline (``Resize`` followed by
    ``ToTorchFormatTensor``). Both loaders use ``config['batch_size']``,
    ``num_workers=4`` and ``shuffle=True``.

    Args:
        config: mapping providing ``'image_size'``, ``'train_dataset'``,
            ``'validation_dataset'`` and ``'batch_size'``.

    Returns:
        ``(train_loader, val_loader)``.
    """
    preprocess = torchvision.transforms.Compose(
        [Resize(config['image_size']),
         ToTorchFormatTensor()])

    # Build both datasets first, then both loaders.
    train_set, val_set = (
        CustomDataset(config[split_key], transform=preprocess)
        for split_key in ('train_dataset', 'validation_dataset'))

    loaders = [
        torch.utils.data.DataLoader(split_set,
                                    batch_size=config['batch_size'],
                                    num_workers=4,
                                    shuffle=True)
        for split_set in (train_set, val_set)
    ]

    print('training and validation dataloader created')
    return loaders[0], loaders[1]
예제 #3
0
from Dataset import CustomDataset
import os
import torch
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from wct import wct

# Demo: run WCT on pairs of samples from ../data and display the result.
my_dataset = CustomDataset(path='../data')
# drop_last=True: each batch is split into two samples (f1, f2) below, so a
# trailing single-sample batch (odd-sized dataset) would raise IndexError.
dataloader = DataLoader(my_dataset, batch_size=2, shuffle=True,
                        drop_last=True)
for i_batch, batch in enumerate(dataloader):
    # batch is the collated output of CustomDataset; batch[0] stacks the first
    # field returned for every sample in the batch (presumably the image
    # arrays — confirm against CustomDataset.__getitem__).
    print("batch " + str(i_batch))
    tensors = batch[0]
    f1 = tensors[0]
    f2 = tensors[1]
    # NOTE(review): meaning of the 0.2 argument and of wct's outputs is
    # defined in the wct module; only the first output is used here.
    res = wct(0.2, f1, f2)[0].numpy()
    # NOTE(review): astype('uint8') does not clip, and every imshow draws on
    # the same current axes while plt.show() runs only after the loop, so
    # later results are drawn over earlier ones.
    plt.imshow(res.astype('uint8'))
plt.show()
예제 #4
0
def main(device):
    """Evaluate every (augmentation, model) pair over five folds and report.

    For each ``augment`` in ``augment_list`` and ``model_name`` in
    ``model_list``, loads the checkpoints
    ``./model_save/<augment>/<model_name>/0.01_030(030)(fold=<k>).pth`` for
    k = 1..5. Each one is run through ``evaluate`` on a shared, unshuffled,
    batch-size-1 loader. The per-fold metrics and their mean/std are printed,
    and the collected summary is passed to ``reporter``.

    Args:
        device: forwarded unchanged to ``evaluate``.

    Returns:
        Nested dict ``results[augment][model_name][metric] = [mean, std]``
        for metrics ``accuracy``, ``precision``, ``recall``, ``f1_score`` and
        ``auc``; mean and std are rounded to 4 decimals.
    """
    dataset = CustomDataset(transform=transform)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

    results = dict()
    for augment in augment_list:
        augment_results = results.setdefault(augment, dict())
        for model_name in model_list:
            model_results = augment_results.setdefault(model_name, dict())
            accuracies, precisions, recalls, f1_scores, aucs = (
                list(), list(), list(), list(), list())
            for fold in range(1, 6):
                print("Fold = {}".format(fold))
                model_path = os.path.join(
                    './model_save', augment, model_name,
                    '0.01_030(030)(fold={}).pth'.format(fold))
                check_point = torch.load(model_path)
                # The checkpoint dict holds a model object under 'model' and
                # a state dict under 'model_state_dict'; the latter is
                # applied to the former.
                model = check_point['model']
                model.load_state_dict(check_point['model_state_dict'])
                accuracy, precision, recall, f1_score, auc = evaluate(
                    device, model, dataloader)

                accuracies.append(accuracy)
                precisions.append(precision)
                recalls.append(recall)
                f1_scores.append(f1_score)
                aucs.append(auc)

            # Compute each [mean, std] once; reused for printing and storing.
            summary = {
                'accuracy': _mean_std(accuracies),
                'precision': _mean_std(precisions),
                'recall': _mean_std(recalls),
                'f1_score': _mean_std(f1_scores),
                'auc': _mean_std(aucs),
            }

            print('\n============ TEST REPORT ============\n')
            print("test accuracy = ", accuracies)
            print("test precision = ", precisions)
            print("test recall = ", recalls)
            print("test f1_score = ", f1_scores)
            print("test auc = ", aucs)

            print("\ntest accuracy mean = {}, std = {}".format(
                *summary['accuracy']))
            print("test precision mean = {}, std = {}".format(
                *summary['precision']))
            print("test recall mean = {}, std = {}".format(
                *summary['recall']))
            print("test f1 score mean = {}, std = {}".format(
                *summary['f1_score']))
            print("test auc mean = {}, std = {}".format(*summary['auc']))

            model_results.update(summary)
            print('\n============ TEST REPORT ============\n')

    reporter(results)

    return results


def _mean_std(values):
    """Return ``[mean, std]`` of *values*, each rounded to 4 decimals."""
    return [np.round(np.mean(values), 4), np.round(np.std(values), 4)]
예제 #5
0
def test(opt):
    """Repeatedly run the inpainting generator on random batches and plot them.

    Builds a shuffled loader over ``CustomDataset(opt.data_dir, opt.mask_dir,
    opt.img_size)`` and optionally loads generator weights from
    ``opt.checkpoint``. It then loops forever. Each pass draws a batch,
    predicts from the masked image and mask, prints PSNR/SSIM against the
    ground truth, and shows prediction / ground truth / masked input side by
    side. There is no exit condition; the loop runs until interrupted.

    Args:
        opt: options namespace; reads ``data_dir``, ``mask_dir``,
            ``img_size``, ``batch_size``, ``input_nc`` and ``checkpoint``.
    """
    img_data = CustomDataset(opt.data_dir, opt.mask_dir, opt.img_size)
    custom_loader = DataLoader(img_data,
                               batch_size=opt.batch_size,
                               shuffle=True,
                               num_workers=0,
                               drop_last=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Model(opt.input_nc)
    model.init('xavier')

    # Only the generator is used here, so only it is moved to the device.
    Punet = model.model
    Punet.to(device)

    if opt.checkpoint:
        # map_location lets a GPU-saved checkpoint load on a CPU-only host,
        # matching the CPU fallback chosen for `device` above.
        log_state = torch.load(opt.checkpoint, map_location=device)
        # Merge into the current state dict so keys absent from the
        # checkpoint keep their freshly initialised values.
        model_dict = Punet.state_dict()
        model_dict.update(log_state)
        Punet.load_state_dict(model_dict)
        print("load checkpoint.")

    # NOTE(review): Punet is never switched to eval(); if it contains
    # BatchNorm/Dropout, outputs use training-mode behaviour — confirm intent.
    print("Testing...................")

    while True:
        # A fresh iterator each pass yields a newly shuffled batch. next(...)
        # is required: DataLoader iterators have no .next() method in current
        # PyTorch releases.
        test_batch = next(iter(custom_loader))
        for idx, item in enumerate(test_batch):
            test_batch[idx] = normalize(item)
        # The 2x image and mask (items 3 and 4) are not used for testing.
        test_img, test_mask, test_masked, _, _ = test_batch
        test_mask = test_mask.type(torch.FloatTensor).to(device)
        test_masked = test_masked.type(torch.FloatTensor).to(device)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            pred = Punet(test_masked, test_mask)

        pred = de_normalize(pred.cpu())
        test_img = de_normalize(test_img.type(torch.FloatTensor))
        test_masked = de_normalize(test_masked.cpu())

        psnr, ssim = evaluate_result(pred, test_img)
        print("PSNR: {:.5f}, SSIM: {:.5f}".format(psnr, ssim))
        # To average PSNR/SSIM over the test set, accumulate psnr/ssim here
        # and break after a fixed number of batches.

        plt.figure(figsize=(32, 16))
        plt.axis('off')
        plt.title('fake image')
        plt.subplot(1, 3, 1)
        plt.imshow(np.transpose(pred[0], (1, 2, 0)))
        plt.subplot(1, 3, 2)
        plt.imshow(np.transpose(test_img[0], (1, 2, 0)))
        plt.subplot(1, 3, 3)
        plt.imshow(np.transpose(test_masked[0], (1, 2, 0)))
        plt.show()
예제 #6
0
def train(opt):
    """Train the Stage2Network refinement net on top of a frozen generator.

    The generator ``Punet`` (``model.model``) is loaded from
    ``./model/recent/_gnet.pth`` and frozen. The discriminator ``dnet`` is
    resumed from ``./model/_dnet.pth``. ``refine_net`` starts from its initial
    weights. Each step then:

    1. updates ``dnet`` on real 2x images vs. detached refined predictions;
    2. updates ``refine_net`` with the per-pixel criterion terms (valid, hole,
       perceptual, TV) plus an adversarial term from ``dnet``.

    ``refine_net`` and ``dnet`` are saved to ``./model/_gnet.pth`` and
    ``./model/_dnet.pth`` whenever the in-epoch step index is a multiple of
    ``opt.save_per_iter``. Running loss means are printed, and a sample grid
    for one fixed batch is written to ``./sample/``. Both happen every 200
    in-epoch steps and at the very last step.

    Args:
        opt: options namespace; reads ``data_dir``, ``mask_dir``,
            ``img_size``, ``batch_size``, ``input_nc``, ``lr``, ``epochs`` and
            ``save_per_iter``.
    """
    img_data = CustomDataset(opt.data_dir, opt.mask_dir, opt.img_size)
    custom_loader = DataLoader(img_data,
                               batch_size=opt.batch_size,
                               shuffle=True,
                               num_workers=0,
                               drop_last=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Model(opt.input_nc)
    model.init('xavier')

    Punet = model.model
    dnet = model.dnet
    refine_net = Stage2Network(opt.input_nc)

    Punet.to(device)
    dnet.to(device)
    refine_net.to(device)

    # map_location lets GPU-saved weights load on a CPU-only host, matching
    # the CPU fallback chosen for `device` above.
    Punet.load_state_dict(
        torch.load("./model/recent/_gnet.pth", map_location=device))
    # The discriminator is always resumed; refine_net is not loaded.
    dnet.load_state_dict(torch.load('./model/_dnet.pth', map_location=device))
    print("load checkpoint.")

    # The generator is frozen: only refine_net and dnet are optimised.
    for param in Punet.parameters():
        param.requires_grad = False

    print("Training...................")

    # One fixed batch, reused for the periodic sample images. next(...) is
    # required: DataLoader iterators have no .next() method in current PyTorch.
    test_img, test_mask, test_masked, test_img_2x, test_mask_2x = next(
        iter(custom_loader))
    test_img = test_img.type(torch.FloatTensor).to(device)
    test_mask = test_mask.type(torch.FloatTensor).to(device)
    test_masked = test_masked.type(torch.FloatTensor).to(device)
    test_img_2x = test_img_2x.type(torch.FloatTensor).to(device)
    test_mask_2x = test_mask_2x.type(torch.FloatTensor).to(device)

    optimizer = torch.optim.Adam(refine_net.parameters(), lr=opt.lr)
    optimizer_D = torch.optim.Adam(dnet.parameters(), lr=2 * opt.lr)

    criterion = CriterionPerPixel(use_gram=True)
    criterion_D = CriterionD(False)
    criterionGAN = CriterionGAN(use_lsgan=False)

    # Per-step loss values, averaged and cleared at every report.
    losses = []
    losses_valid = []
    losses_hole = []
    losses_perceptual = []
    losses_tv = []
    losses_gan = []
    losses_d = []

    for epoch in range(opt.epochs):
        for idx, data in enumerate(custom_loader):
            img = data[0].type(torch.FloatTensor).to(device)
            mask = data[1].type(torch.FloatTensor).to(device)
            masked = data[2].type(torch.FloatTensor).to(device)
            img_2x = data[3].type(torch.FloatTensor).to(device)
            mask_2x = data[4].type(torch.FloatTensor).to(device)

            # --- discriminator step ---
            set_requires_grad(dnet, requires_grad=True)
            dnet.zero_grad()
            real_out, _ = dnet(img_2x)
            # Float target matching real_out's shape, dtype and device.
            # torch.full(..., 1, ...) infers int64 from the integer fill value
            # in recent PyTorch, which float-valued losses reject as a target.
            label = torch.ones_like(real_out)
            real_loss = criterion_D(real_out, label)
            real_loss.backward()

            pred = Punet(masked, mask)
            pred_2x, blur_2x = refine_net(img, mask, img_2x, mask_2x, pred)

            fake_out, _ = dnet(pred_2x.detach())
            label.fill_(0)
            fake_loss = criterion_D(fake_out, label)
            fake_loss.backward()

            err_d = real_loss + fake_loss
            optimizer_D.step()

            # --- refinement-network step ---
            set_requires_grad(dnet, requires_grad=False)
            refine_net.zero_grad()
            fake_out_g, _ = dnet(pred_2x)
            label.fill_(1)
            l_gan = criterionGAN(fake_out_g, label)
            loss_valid, loss_hole, loss_perceptual, loss_tv = criterion(
                pred_2x, img_2x, mask_2x)
            loss = (loss_valid + 6 * loss_hole + 0.05 * loss_perceptual +
                    0.3 * loss_tv + 2.5 * l_gan)
            loss.backward()
            optimizer.step()

            losses.append(loss.item())
            losses_valid.append(loss_valid.item())
            losses_hole.append(loss_hole.item())
            losses_perceptual.append(loss_perceptual.item())
            losses_tv.append(loss_tv.item())
            losses_gan.append(l_gan.item())
            losses_d.append(err_d.item())

            if idx % opt.save_per_iter == 0:
                # refine_net's weights are written to _gnet.pth.
                torch.save(refine_net.state_dict(), './model/' + '_gnet.pth')
                torch.save(dnet.state_dict(), './model/' + '_dnet.pth')
            if idx % 200 == 0 or (epoch == opt.epochs - 1
                                  and idx == len(custom_loader) - 1):
                print(
                    "Epoch: {}, step: {}, loss_valid: {:.5f}, loss_hole: {:.5f}, loss_perceptual: {:.5f}, loss_total: {:.5f}, "
                    "loss_tv: {:.5f}, l_gan: {:.5f}, l_d: {:.5f}".format(
                        epoch, idx,
                        np.mean(losses_valid), np.mean(losses_hole),
                        np.mean(losses_perceptual), np.mean(losses),
                        np.mean(losses_tv), np.mean(losses_gan),
                        np.mean(losses_d)))
                losses.clear()
                losses_valid.clear()
                losses_hole.clear()
                losses_perceptual.clear()
                losses_tv.clear()
                losses_gan.clear()
                losses_d.clear()
                with torch.no_grad():
                    pred = Punet(test_masked, test_mask)
                    pred_2x, blur_2x = refine_net(test_img, test_mask,
                                                  test_img_2x, test_mask_2x,
                                                  pred)
                    pred_2x = pred_2x.detach().cpu()
                    blur_2x = blur_2x.detach().cpu()
                    target = test_img_2x.detach().cpu()
                    # Third sample row shows the 2x mask.
                    masked = test_mask_2x.detach().cpu()
                sample_pred = vutils.make_grid(pred_2x,
                                               padding=2,
                                               normalize=False)
                sample_target = vutils.make_grid(target,
                                                 padding=2,
                                                 normalize=False)
                sample_masked = vutils.make_grid(masked,
                                                 padding=2,
                                                 normalize=False)
                sample_blur = vutils.make_grid(blur_2x,
                                               padding=2,
                                               normalize=False)
                plt.figure(figsize=(32, 16))
                plt.axis('off')
                plt.title('fake image')
                plt.subplot(4, 1, 1)
                plt.imshow(np.transpose(sample_pred, (1, 2, 0)))
                plt.subplot(4, 1, 2)
                plt.imshow(np.transpose(sample_target, (1, 2, 0)))
                plt.subplot(4, 1, 3)
                plt.imshow(np.transpose(sample_masked, (1, 2, 0)))
                plt.subplot(4, 1, 4)
                plt.imshow(np.transpose(sample_blur, (1, 2, 0)))
                plt.savefig("./sample/epoch_{}_iter_{}.png".format(epoch, idx))
                plt.close()
예제 #7
0
                        help="'cuda' for cuda, 'cpu' for cpu, default = cuda",
                        default='cuda')
    parser.add_argument('--batch_size',
                        help="batchsize, default = 1",
                        default=1,
                        type=int)
    args = parser.parse_args()

    print(args)
    print(os.getcwd())
    device = torch.device(args.cuda)
    # model_dir = 'models/07121619/25epo_210000step.ckpt'
    state_dict = torch.load(args.model_dir)
    model = Unet().to(device)
    model.load_state_dict(state_dict)
    dataset = CustomDataset(root_dir=args.image_dir)
    dataloader = DataLoader(dataset, args.batch_size, shuffle=False)
    writer = SummaryWriter('log/Att_test')
    model.eval()
    for i, batch in enumerate(dataloader):
        img = batch.to(device)
        # mask = batch['mask'].to(device)
        with torch.no_grad():
            pred, loss, attention = model(img)
        for j, _attention in enumerate(attention):
            size = _attention.size(
            )  # batch, (org_size, org_size), (224, ) or (13*224//org_size, )
            if size[3] != 224:
                patch = F.pad(img, (6 * 224 // size[2], 6 * 224 // size[2],
                                    6 * 224 // size[2], 6 * 224 // size[2]))
                patch = patch.unfold(2, size[3], 224 // size[1]) \
예제 #8
0
def train(opt):
    """Train the inpainting generator with a full-image and a patch discriminator.

    Each step:

    1. updates ``dnet`` on (masked input, real image) vs. (masked input,
       detached prediction) pairs;
    2. updates ``pdnet`` on six random 64x64 crops of (image, image) vs.
       (image, detached prediction);
    3. updates the generator ``Punet`` with the per-pixel criterion terms
       (valid, hole, perceptual, style) plus both adversarial losses.

    The three networks are saved to ``./model/_gnet.pth``, ``_dnet.pth`` and
    ``_pdnet.pth`` whenever the in-epoch step index is a multiple of
    ``opt.save_per_iter``. Running loss means are printed, and a sample grid
    for one fixed batch is written to ``./sample/``. Both happen every 200
    in-epoch steps and at the very last step.

    Args:
        opt: options namespace; reads ``data_dir``, ``mask_dir``, ``img_size``,
            ``batch_size``, ``input_nc``, ``checkpoint``, ``lr``, ``epochs``
            and ``save_per_iter``.
    """
    img_data = CustomDataset(opt.data_dir, opt.mask_dir, opt.img_size)
    custom_loader = DataLoader(img_data,
                               batch_size=opt.batch_size,
                               shuffle=True,
                               num_workers=0,
                               drop_last=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Model(opt.input_nc)
    model.init('xavier')

    Punet = model.model
    dnet = model.dnet
    pdnet = model.dnet_

    Punet.to(device)
    dnet.to(device)
    pdnet.to(device)

    if opt.checkpoint:
        # map_location lets GPU-saved weights load on a CPU-only host,
        # matching the CPU fallback chosen for `device` above.
        log_state = torch.load(opt.checkpoint, map_location=device)
        # Merge into the current state dict so generator keys absent from
        # the checkpoint keep their freshly initialised values.
        model_dict = Punet.state_dict()
        model_dict.update(log_state)
        Punet.load_state_dict(model_dict)
        # The discriminators are resumed from fixed paths, not opt.checkpoint.
        dnet.load_state_dict(
            torch.load('./model/_dnet.pth', map_location=device))
        pdnet.load_state_dict(
            torch.load('./model/_pdnet.pth', map_location=device))
        print("load checkpoint.")

    print("Training...................")

    # One fixed batch, reused for the periodic sample images. next(...) is
    # required: DataLoader iterators have no .next() method in current PyTorch.
    test_batch = next(iter(custom_loader))
    for idx, item in enumerate(test_batch):
        test_batch[idx] = normalize(item)
    # The 2x image and mask (items 3 and 4) are not used by this trainer.
    test_img, test_mask, test_masked, _, _ = test_batch
    test_img = test_img.type(torch.FloatTensor).to(device)
    test_mask = test_mask.type(torch.FloatTensor).to(device)
    test_masked = test_masked.type(torch.FloatTensor).to(device)

    optimizer = torch.optim.Adam(Punet.parameters(),
                                 lr=opt.lr,
                                 betas=(0.9, 0.99))
    optimizer_D = torch.optim.Adam(dnet.parameters(),
                                   lr=opt.lr * 1.3,
                                   betas=(0.9, 0.99))
    optimizer_DP = torch.optim.Adam(pdnet.parameters(),
                                    lr=opt.lr * 1.3,
                                    betas=(0.9, 0.99))

    criterion = CriterionPerPixel(use_gram=True)
    criterion_D = criterion_GAN(use_lsgan=False)
    criterionGAN = criterion_GAN(use_lsgan=False)

    # Per-step loss values, averaged and cleared at every report.
    losses = []
    losses_valid = []
    losses_hole = []
    losses_perceptual = []
    losses_style = []
    losses_gan = []
    losses_d = []
    losses_dp = []

    for epoch in range(opt.epochs):
        for idx, data in enumerate(custom_loader):
            for i, item in enumerate(data):
                data[i] = normalize(item)
            img = data[0].type(torch.FloatTensor).to(device)
            mask = data[1].type(torch.FloatTensor).to(device)
            masked = data[2].type(torch.FloatTensor).to(device)

            # --- full-image discriminator step ---
            set_requires_grad(dnet, requires_grad=True)
            dnet.zero_grad()
            pred = Punet(masked, mask)
            fake_out = dnet(torch.cat((masked, pred.detach()), dim=1))
            real_out = dnet(torch.cat((masked, img), dim=1))

            real_loss = criterion_D(real_out, True)
            fake_loss = criterion_D(fake_out, False)

            err_d = (real_loss + fake_loss) / 2
            err_d.backward()
            optimizer_D.step()

            # --- patch discriminator step ---
            set_requires_grad(pdnet, requires_grad=True)
            pdnet.zero_grad()

            fake = torch.cat((img, pred.detach()), dim=1)
            real = torch.cat((img, img), dim=1)

            # Six random 64x64 window corners, bounded by the tensor's actual
            # spatial size (previously hard-coded for 256x256 inputs). The
            # same windows are reused for the generator's patch loss below.
            max_row = fake.size(2) - 64
            max_col = fake.size(3) - 64
            coord = [(random.randint(0, max_row), random.randint(0, max_col))
                     for _ in range(6)]

            real_loss = criterion_D(pdnet(_crop_patches(real, coord)), True)
            fake_loss = criterion_D(pdnet(_crop_patches(fake, coord)), False)

            err_dp = (real_loss + fake_loss) / 2
            err_dp.backward()
            optimizer_DP.step()

            # --- generator step ---
            set_requires_grad(dnet, requires_grad=False)
            set_requires_grad(pdnet, requires_grad=False)
            Punet.zero_grad()
            fake_loss_g = dnet(torch.cat((masked, pred), dim=1))
            l_gan = criterionGAN(fake_loss_g, True)

            fake_patches = _crop_patches(torch.cat((img, pred), dim=1), coord)
            l_gan2 = criterionGAN(pdnet(fake_patches), True)

            loss_valid, loss_hole, loss_perceptual, loss_style = criterion(
                pred, img, mask)
            loss = (loss_valid + 3 * loss_hole + 0.7 * loss_perceptual +
                    50 * loss_style + 0.2 * l_gan + 0.1 * l_gan2)
            loss.backward()
            optimizer.step()

            losses.append(loss.item())
            losses_valid.append(loss_valid.item())
            losses_hole.append(loss_hole.item())
            losses_perceptual.append(loss_perceptual.item())
            losses_style.append(loss_style.item())
            losses_gan.append(l_gan.item())
            losses_d.append(err_d.item())
            losses_dp.append(err_dp.item())

            if idx % opt.save_per_iter == 0:
                torch.save(Punet.state_dict(), './model/' + '_gnet.pth')
                torch.save(dnet.state_dict(), './model/' + '_dnet.pth')
                torch.save(pdnet.state_dict(), './model/' + '_pdnet.pth')
            if idx % 200 == 0 or (epoch == opt.epochs - 1
                                  and idx == len(custom_loader) - 1):
                print(
                    "Epoch: {}, step: {}, loss_valid: {:.5f}, loss_hole: {:.5f}, loss_perceptual: {:.5f}, loss_total: {:.5f}, "
                    "loss_style: {:.5f}, l_gan: {:.5f}, l_d: {:.5f}, l_dp: {:.5f}"
                    .format(epoch, idx, np.mean(losses_valid),
                            np.mean(losses_hole), np.mean(losses_perceptual),
                            np.mean(losses), np.mean(losses_style),
                            np.mean(losses_gan), np.mean(losses_d),
                            np.mean(losses_dp)))

                losses.clear()
                losses_valid.clear()
                losses_hole.clear()
                losses_perceptual.clear()
                losses_style.clear()
                losses_gan.clear()
                losses_d.clear()
                losses_dp.clear()

                with torch.no_grad():
                    pred = Punet(test_masked, test_mask)
                    pred = pred.detach().cpu()
                    target = test_img.detach().cpu()
                    masked = test_masked.detach().cpu()

                    pred = de_normalize(pred)
                    target = de_normalize(target)
                    masked = de_normalize(masked)

                sample_pred = vutils.make_grid(pred,
                                               padding=2,
                                               normalize=False)
                sample_target = vutils.make_grid(target,
                                                 padding=2,
                                                 normalize=False)
                sample_masked = vutils.make_grid(masked,
                                                 padding=2,
                                                 normalize=False)
                plt.figure(figsize=(32, 16))
                plt.axis('off')
                plt.title('fake image')
                plt.subplot(3, 1, 1)
                plt.imshow(np.transpose(sample_pred, (1, 2, 0)))
                plt.subplot(3, 1, 2)
                plt.imshow(np.transpose(sample_target, (1, 2, 0)))
                plt.subplot(3, 1, 3)
                plt.imshow(np.transpose(sample_masked, (1, 2, 0)))
                plt.savefig("./sample/epoch_{}_iter_{}.png".format(epoch, idx))
                plt.close()


def _crop_patches(images, coords, size=64):
    """Crop a ``size`` x ``size`` window at each ``(row, col)`` in *coords*.

    *images* is indexed as ``[:, :, row, col]`` (assumes NCHW — confirm).
    The crops for each coordinate are concatenated along dim 0 in *coords*
    order, so k coordinates give ``k * images.size(0)`` patches.
    """
    return torch.cat([images[:, :, r:r + size, c:c + size]
                      for r, c in coords], dim=0)