Example 1
            exit()

    opt = functions.post_config(opt)
    Gs = []
    Zs, NoiseAmp = {}, {}
    reals_list = torch.FloatTensor(num_images, 1, int(opt.nc_im), int(opt.size_image), int(opt.size_image)).cuda()


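    # read each numbered training image and adjust the scale-pyramid parameters to it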
    for i in range(num_images):

        real = img.imread("%s/%s_%d.png" % (opt.input_dir, opt.input_name[:-4], i))
        real = functions.np2torch(real, opt)
        real = real[:, 0:3, :, :]
        functions.adjust_scales2image(real, opt)

    dir2save = functions.generate_dir2save(opt)
    reals = {}
    # build the list of transformation parameter tuples; the last factor (0-3)
    # selects one of four transformations
    generator = itertools.product((0,), (False,), (0,), (0,), (0, 1, 2, 3))
    lst = list(generator)
    opt.list_transformations = lst
    print(opt.list_transformations)
    print("num transformations: ", len(lst))

    if opt.mode == 'train':
        train(opt, Gs, Zs, reals, NoiseAmp)
    if dataset == 'mvtec':
        defect_detection(opt.input_name, opt.test_size, opt)
    else:
        print("this file is just for mvtec dataset")
        exit()


def train(opt, Gs, Zs, reals, NoiseAmp):
    reals, Zs, NoiseAmp, in_s, scale_num = functions.collect_reals(
        opt, reals, Zs, NoiseAmp)
    nfc_prev = 0
    for index_image in range(int(opt.num_images)):
        NoiseAmp[index_image] = []

    while scale_num < opt.stop_scale + 1:

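        # grow the channel counts every 4 scales, capped at opt.size_image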
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)),
                      opt.size_image)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)),
                          opt.size_image)
        D_curr, G_curr = init_models(opt)

        index_arr_flag = False  # flips to True after the first epoch at this scale

        for epoch in range(opt.num_epochs):

            optimizerD = optim.Adam(D_curr.parameters(),
                                    lr=opt.lr_d,
                                    betas=(opt.beta1, 0.999))
            optimizerG = optim.Adam(G_curr.parameters(),
                                    lr=opt.lr_g,
                                    betas=(opt.beta1, 0.999))
            schedulerD = torch.optim.lr_scheduler.MultiStepLR(
                optimizer=optimizerD, milestones=[500], gamma=opt.gamma)
            schedulerG = torch.optim.lr_scheduler.MultiStepLR(
                optimizer=optimizerG, milestones=[1000], gamma=opt.gamma)

            print(" ")
            print("this is class: ", opt.pos_class)
            print("size image: ", opt.size_image)
            print("num images: ", opt.num_images)
            print("index of download: ", opt.index_download)
            print("num transforms: ", opt.num_transforms)
            print(" ")

            index_image = range(int(opt.num_images))
            G_curr = functions.reset_grads(G_curr, True)
            D_curr = functions.reset_grads(D_curr, True)

            opt.out_ = functions.generate_dir2save(opt)
            opt.global_outf = '%s/%d' % (opt.out_, scale_num)
            if os.path.exists(opt.global_outf):
                shutil.rmtree(opt.global_outf)
            opt.outf = [
                '%s/%d/index_image_%d' % (opt.out_, scale_num, id)
                for id in index_image
            ]
            try:
                for j in opt.outf:
                    os.makedirs(j, exist_ok=True)
            except OSError as err:
                print("OS error: {0}".format(err))

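            # save the real image of each training sample at the current scale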
            for id in index_image:
                plt.imsave('%s/real_scale%d.png' % (opt.outf[id], id),
                           functions.convert_image_np(reals[id][scale_num]),
                           vmin=0,
                           vmax=1)
            if (nfc_prev == opt.nfc):
                G_curr.load_state_dict(
                    torch.load('%s/%d/netG.pth' % (opt.out_, scale_num - 1)))
                D_curr.load_state_dict(
                    torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))
                nfc_prev = 0

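            # first epoch at this scale: draw fresh reconstruction noise z_opt;
            # later epochs reuse the per-image noise already stored in Zs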
            if not index_arr_flag:
                real = torch.cat([reals[id][scale_num] for id in index_image],
                                 dim=0)
                opt.nzx = real.shape[2]
                opt.nzy = real.shape[3]
                opt.receptive_field = opt.ker_size + (
                    (opt.ker_size - 1) *
                    (opt.num_layer - 1)) * opt.stride  # 11
                pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)  # 5
                m_noise = nn.ZeroPad2d(int(pad_noise))
                z_opt = functions.generate_noise([1, opt.nzx, opt.nzy],
                                                 device=opt.device)
                z_opt = m_noise(
                    z_opt.expand(real.shape[0], 3, opt.nzx, opt.nzy))

            elif index_arr_flag:
                z_opt = torch.cat(([Zs[id][scale_num] for id in index_image]),
                                  dim=0)

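            # one epoch of adversarial training at the current scale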
            in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs,
                                              in_s, NoiseAmp, opt, index_image,
                                              z_opt, optimizerD, optimizerG,
                                              schedulerD, schedulerG,
                                              index_arr_flag)
            if not index_arr_flag:
                for id in index_image:
                    Zs[id].append(z_opt[id:id + 1])
                    NoiseAmp[id].append(opt.noise_amp[id:id + 1])
            else:
                for id in index_image:
                    NoiseAmp[id][scale_num] = opt.noise_amp[id:id + 1]

            index_arr_flag = True

            G_curr = functions.reset_grads(G_curr, False)
            G_curr.eval()
            D_curr = functions.reset_grads(D_curr, False)
            D_curr.eval()

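        # checkpoint the multi-scale state after finishing this scale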
        Gs.append(G_curr)
        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        scale_num += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr, optimizerD, optimizerG
        torch.cuda.empty_cache()

    return
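
The snippet above omits its header; a minimal sketch of the imports it appears to rely on (the repo-local module paths are assumptions):

import math
import os
import shutil
import itertools

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from skimage import io as img        # assumption: `img.imread(...)` above suggests this alias

from SinGAN import functions          # assumption: project-local helpers (np2torch, generate_noise, ...)
from SinGAN.imresize import imresize  # assumption: used by SinGAN_generate below
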
Example 3
def SinGAN_generate(Gs, Zs, reals, NoiseAmp, opt, in_s=None, scale_v=1, scale_h=1,
                    n=0, gen_start_scale=0, num_samples=200):

    if in_s is None:
        in_s = torch.full(reals[0][0].shape, 0, device=opt.device, dtype=torch.float32)  # zero image for the coarsest scale; float so it can be summed with generated tensors
    images_cur = []
    index_image = range(int(opt.num_images))

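    # walk the generator pyramid from the coarsest to the finest scale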
    for scale_idx, G in enumerate(Gs):
        Z_opt = torch.cat([Zs[idx][scale_idx] for idx in index_image], dim=0)
        noise_amp = torch.cat(([NoiseAmp[id][scale_idx] for id in range(opt.num_images)]), dim=0).cuda()
        pad1 = ((opt.ker_size-1)*opt.num_layer)/2
        m = nn.ZeroPad2d(int(pad1))
        nzx = (Z_opt.shape[2]-pad1*2)*scale_v
        nzy = (Z_opt.shape[3]-pad1*2)*scale_h
        images_prev = images_cur
        images_cur = []

        for i in range(num_samples):
            if n == 0:
                z_curr = functions.generate_noise([1,nzx,nzy], device=opt.device)
                z_curr = z_curr.expand(1,3,z_curr.shape[2],z_curr.shape[3])
                z_curr = m(z_curr)
            else:
                z_curr = functions.generate_noise([opt.nc_z,nzx,nzy], device=opt.device)
                z_curr = m(z_curr)

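            # start from the all-zero image at the coarsest scale, otherwise upsample the previous scale's samples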
            if not images_prev:
                I_prev = m(in_s)
            else:
                I_prev_temp = images_prev[i]
                I_prev = imresize(torch.unsqueeze(I_prev_temp[0], dim=0), 1 / opt.scale_factor, opt)
                for id in range(1, opt.num_images):
                    I_prev = torch.cat((I_prev, imresize(torch.unsqueeze(I_prev_temp[id], dim=0), 1 / opt.scale_factor, opt)),
                                    dim=0)
                if opt.mode != "SR":
                    I_prev = I_prev[:, :, 0:round(scale_v * reals[0][n].shape[2]), 0:round(scale_h * reals[0][n].shape[3])]
                    I_prev = m(I_prev)
                    I_prev = I_prev[:,:,0:z_curr.shape[2],0:z_curr.shape[3]]
                    I_prev = functions.upsampling(I_prev,z_curr.shape[2],z_curr.shape[3])
                else:
                    I_prev = m(I_prev)

            if n < gen_start_scale:
                z_curr = Z_opt

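            # broadcast each image's scalar noise amplitude to a full noise-shaped tensor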
            noise_amp_tensor = torch.full([1, z_curr.shape[1], z_curr.shape[2], z_curr.shape[3]], noise_amp[0][0].item(),
                                          dtype=torch.float).cuda()
            for j in range(1, opt.num_images):
                temp = torch.full([1, z_curr.shape[1], z_curr.shape[2], z_curr.shape[3]],
                                  noise_amp[j][0].item(), dtype=torch.float).cuda()
                noise_amp_tensor = torch.cat((noise_amp_tensor, temp), dim=0)
            z_in = noise_amp_tensor*(z_curr)+I_prev
            padded_id_Z_opt = pad_image_id(z_in, index_image)


            I_curr = G(padded_id_Z_opt.detach(),I_prev)

            if n == len(Gs)-1:
                if opt.mode == 'train':
                    dir2save = '%s/RandomSamples/%s/gen_start_scale=%d' % (opt.out, opt.input_name[:-4], gen_start_scale)
                else:
                    dir2save = functions.generate_dir2save(opt)
                try:
                    os.makedirs(dir2save)
                except OSError:
                    pass
                if opt.mode not in ("harmonization", "editing", "SR", "paint2image"):
                    for j in range(opt.num_images):
                        plt.imsave('%s/%d%d.png' % (dir2save, i, j),
                                   functions.convert_image_np(I_curr[j].unsqueeze(dim=0).detach()),
                                   vmin=0, vmax=1)
            images_cur.append(I_curr)
        n+=1
    return I_curr.detach()
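
A minimal sketch of how SinGAN_generate might be invoked after training, assuming the checkpoints written by train() in Example 1 and an already-configured opt (the load calls and sample count here are illustrative, not part of the original code):

# reload the multi-scale state saved by train()
Gs = torch.load('%s/Gs.pth' % opt.out_)
Zs = torch.load('%s/Zs.pth' % opt.out_)
reals = torch.load('%s/reals.pth' % opt.out_)
NoiseAmp = torch.load('%s/NoiseAmp.pth' % opt.out_)

# generate a handful of random samples per training image at full resolution
samples = SinGAN_generate(Gs, Zs, reals, NoiseAmp, opt, num_samples=5)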