예제 #1
0
    def __init__(self, cfg):
        """Build the networks, optimizers, LR schedules and losses from *cfg*.

        Reads the ``'model'``, ``'schedule'`` and ``'data'`` sections of the
        config. All networks are moved to CUDA as they are constructed.

        Attributes set here:
            srntt, discriminator, vgg: the networks.
            init_epoch, num_epochs: copied from the schedule config.
            optimizer_init: plain Adam over the trainable SRNTT parameters.
            optimizer, optimizer_d: MultiStepLR schedulers (not bare
                optimizers) wrapping Adam for SRNTT and the discriminator;
                the LR is multiplied by 0.1 at epoch ``num_epochs // 2``.
            reconst_loss, bp_loss, texture_loss, adv_loss: loss modules.
            loss_weights: copied from the schedule config.
        """
        super(RefSRSolver, self).__init__(cfg)

        model_cfg = cfg['model']
        sched_cfg = cfg['schedule']

        # Networks are created in a fixed order; keep it, since parameter
        # initialisation draws from the global RNG.
        self.srntt = SRNTT(model_cfg['n_resblocks'],
                           sched_cfg['use_weights'],
                           sched_cfg['concat']).cuda()
        self.discriminator = Discriminator(cfg['data']['input_size']).cuda()
        self.vgg = VGG19(model_cfg['final_layer'],
                         model_cfg['prev_layer'], True).cuda()

        # Only these four SRNTT sub-modules are handed to the generator optimizers.
        trainable_modules = (
            self.srntt.texture_transfer,
            self.srntt.texture_fusion_medium,
            self.srntt.texture_fusion_large,
            self.srntt.srntt_out,
        )
        params = [param
                  for module in trainable_modules
                  for param in module.parameters()]

        solver_sched = self.cfg['schedule']
        self.init_epoch = solver_sched['init_epoch']
        self.num_epochs = solver_sched['num_epochs']

        learning_rate = sched_cfg['lr']
        decay_epoch = self.num_epochs // 2

        self.optimizer_init = torch.optim.Adam(params, lr=learning_rate)

        # NOTE: despite their names these are schedulers; the wrapped Adam
        # instances are reachable through their `.optimizer` attribute.
        generator_adam = torch.optim.Adam(params, lr=learning_rate)
        self.optimizer = torch.optim.lr_scheduler.MultiStepLR(
            generator_adam, [decay_epoch], gamma=0.1)
        discriminator_adam = torch.optim.Adam(self.discriminator.parameters(),
                                              lr=learning_rate)
        self.optimizer_d = torch.optim.lr_scheduler.MultiStepLR(
            discriminator_adam, [decay_epoch], gamma=0.1)

        self.reconst_loss = nn.L1Loss()
        self.bp_loss = BackProjectionLoss()
        self.texture_loss = TextureLoss(solver_sched['use_weights'], 80)
        self.adv_loss = AdvLoss(solver_sched['is_WGAN_GP'])
        self.loss_weights = solver_sched['loss_weights']
예제 #2
0
        recall = TP / (TP + FN)
        specificity = TN / (FP + TN)

        accuracy = 100 * (correct / total)
        epoch_loss = epoch_loss / n_test

    print(f'Accuracy test on the images : {accuracy}')
    print(f'Test loss : {epoch_loss}')
    print(
        f'Precision : {precision}, Recall : {recall}, Specificity : {specificity}'
    )
    save_intermediate_test(num, accuracy, epoch_loss, precision, recall, specificity,\
                        TP, TN, FP, FN, dic_path, is_fm)
    y_true = np.delete(y_true, [0, 0], axis=0)
    y_score = np.delete(y_score, [0, 0], axis=0)
    return y_true, y_score


if __name__ == '__main__':
    # Script entry point: set up logging, build the VGG19 network (with
    # `n_classes`), place it on `device`, then hand everything to main().
    logging.basicConfig(format='%(levelname)s: %(message)s',
                        level=logging.INFO)
    args = get_args()

    net = VGG19(n_classes)
    net.to(device=device)

    main(args=args, num=args.num_exp, net=net, device=device)
    del net
예제 #3
0
        transforms.CenterCrop(256),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    generator = DualStyleGAN(1024, 512, 8, 2, res_index=6).to(device)
    generator.eval()

    ckpt = torch.load(args.ckpt)
    generator.load_state_dict(ckpt["g_ema"])
    noises_single = generator.make_noise()

    percept = lpips.PerceptualLoss(model="net-lin",
                                   net="vgg",
                                   use_gpu=device.startswith("cuda"))
    vggloss = VGG19().to(device).eval()

    print('Load models successfully!')

    datapath = os.path.join(args.data_path, args.style, 'images/train')
    exstyles_dict = np.load(args.exstyle_path, allow_pickle='TRUE').item()
    instyles_dict = np.load(args.instyle_path, allow_pickle='TRUE').item()
    files = list(exstyles_dict.keys())

    dict = {}
    for ii in range(0, len(files), args.batch):
        batchfiles = files[ii:ii + args.batch]
        imgs = []
        exstyles = []
        instyles = []
        for file in batchfiles:
def _save_triplet(panels, path, tight_layout=False):
    """Save a 1x3 figure of ``(tensor, title)`` panels to *path*.

    Each tensor goes through ``tensor2image`` before plotting. The figure is
    closed after saving: pyplot otherwise keeps drawing onto the same current
    figure, so every call would stack more images and memory would grow.
    """
    for position, (tensor, title) in zip((131, 132, 133), panels):
        plt.subplot(position)
        plt.imshow(tensor2image(tensor))
        plt.title(title)
    if tight_layout:
        plt.tight_layout()
    plt.savefig(path)
    plt.close()


def main(config):
    """Train the generator with CXLoss style/content terms over VGG19 features.

    Each epoch makes one pass over the training set. It then renders every
    test image with the current generator into
    ``<config.result_path>/epoch-<epoch>/``.

    Args:
        config: options object. Attributes read here: image_size, train_dir,
            style_image, test_dir, batch_size, num_workers, vgg, load_model,
            lr, ckpt_path, result_path, start_idx, epochs, lambda_style and
            lambda_content. The feature keys come from the module-level names
            ``style_layer`` and ``content_layer`` (defined elsewhere).
    """
    print(config)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    transform = transforms.Compose([
        transforms.Resize((config.image_size, config.image_size)),
        transforms.ToTensor(),
    ])

    train_dataset = TrainDataset(train_dir=config.train_dir,
                                 style_image=config.style_image,
                                 transforms=transform)
    test_dataset = TestDataset(test_dir=config.test_dir, transforms=transform)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=config.batch_size,
                                               shuffle=True,
                                               num_workers=config.num_workers)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1)

    vgg19 = VGG19().to(device)
    # Loading pretrained model
    vgg19.load_model(config.vgg)
    # eval() only switches layer behaviour; it does not freeze weights. VGG is
    # a fixed feature extractor that is not in the optimizer, so also stop
    # autograd from accumulating gradients into its parameters. Gradients
    # still flow through its activations back to the generator.
    vgg19.eval()
    for param in vgg19.parameters():
        param.requires_grad_(False)

    generator = Generator(image_size=config.image_size).to(device)
    CX_loss = CXLoss(sigma=0.5).to(device)

    if config.load_model:
        # map_location lets a GPU-saved checkpoint load on the CPU fallback.
        generator.load_state_dict(
            torch.load(config.load_model, map_location=device))

    optimizer = optim.Adam(params=generator.parameters(), lr=config.lr)

    os.makedirs(config.ckpt_path, exist_ok=True)
    os.makedirs(config.result_path, exist_ok=True)

    for epoch in range(1 + config.start_idx, config.epochs + 1):

        for i, data in enumerate(train_loader):
            source = data['source'].to(device).float()
            style = data['style'].to(device).float()

            optimizer.zero_grad()

            vgg_source = vgg19(source)
            vgg_style = vgg19(style)

            fake = generator(source)
            vgg_fake = vgg19(fake)

            cx_style_loss = sum(CX_loss(vgg_style[s], vgg_fake[s])
                                for s in style_layer) * config.lambda_style
            cx_content_loss = sum(CX_loss(vgg_source[s], vgg_fake[s])
                                  for s in content_layer) * config.lambda_content

            loss = cx_style_loss + cx_content_loss

            loss.backward()
            optimizer.step()

            if i % 100 == 0:
                print(
                    "Epoch: %d/%d | Step: %d/%d | Style loss: %f | Content loss: %f | Loss: %f"
                    % (epoch, config.epochs, i, len(train_loader),
                       cx_style_loss.item(), cx_content_loss.item(),
                       loss.item()))
                _save_triplet(
                    [(source, 'source'), (style, 'style'), (fake, 'fake')],
                    join(config.result_path,
                         'epoch-%d-step-%d.png' % (epoch, i)),
                    tight_layout=True)

            if (i + 1) % 500 == 0:
                torch.save(generator.state_dict(),
                           join(config.ckpt_path, 'epoch-%d.pkl' % epoch))

        generator.eval()
        result_path = join(config.result_path, 'epoch-%d' % epoch)
        os.makedirs(result_path, exist_ok=True)
        # Inference only: no_grad avoids building an autograd graph per image.
        with torch.no_grad():
            for i, data in enumerate(test_loader):
                source = data['source'].to(device).float()
                fake = generator(source)
                # `style` is the last training batch's style tensor, presumably
                # the same single image TrainDataset was given (style_image).
                _save_triplet(
                    [(source, 'content'), (style, 'style'), (fake, 'fake')],
                    join(result_path, 'step-%d.png' % (i + 1)))
        generator.train()
예제 #5
0
# Output directory for the per-image .npz files; exist_ok avoids the
# check-then-create race of a separate exists() test.
makedirs(save_path, exist_ok=True)

input_files = sorted(glob(join(input_path, '*.png')))
ref_files = sorted(glob(join(ref_path, '*.png')))
n_files = len(input_files)
# Inputs and references are paired by sorted position (input_files[i] with
# ref_files[i]), so the counts must match. Raise explicitly: an `assert`
# would be stripped when Python runs with -O.
if n_files != len(ref_files):
    raise ValueError(
        'Expected one reference per input image, got %d inputs in %s and '
        '%d references in %s' % (n_files, input_path, len(ref_files), ref_path))

srntt = SRNTT(16).cuda()
print('Loading SRNTT ...')
# NOTE(review): machine-specific absolute checkpoint path; consider making it
# configurable.
ckpt = torch.load('/home/zwj/Projects/Python/SRNTT_Pytorch/log/srntt_vgg19_div2k/2019-09-20-10:06:34/' +
                  'checkpoint/best.pth')
srntt.load_state_dict(ckpt['srntt'])
print('Done.')
print('Loading VGG19 ...')
net_vgg19 = VGG19('relu_5-1', ['relu_1-1', 'relu_2-1', 'relu_3-1'], True).cuda()
print('Done.')
swaper = Swap(3, 1)

# Zero-padded "current/total" progress format, e.g. '%03d/%03d' for 100+ files.
index_width = len(str(n_files))
print_format = '%%0%dd/%%0%dd' % (index_width, index_width)
for i in range(n_files):
    file_name = join(save_path, split(input_files[i])[-1].replace('.png', '.npz'))
    if exists(file_name):
        continue
    print(print_format % (i + 1, n_files))
    img_in_lr = imresize(imread(input_files[i], mode='RGB'), (input_size, input_size), interp='bicubic')
    img_in_lr = img_in_lr.astype(np.float32) / 127.5 - 1
    img_ref = imresize(imread(ref_files[i], mode='RGB'), (input_size * 4, input_size * 4), interp='bicubic')
    img_ref = img_ref.astype(np.float32) / 127.5 - 1
    img_ref_lr = imresize(img_ref, (input_size, input_size), interp='bicubic')
    img_ref_lr = img_ref_lr.astype(np.float32) / 127.5 - 1