Example #1
def train_paint(opt, Gs, Zs, reals, NoiseAmp, centers, paint_inject_scale):
    in_s = torch.full(reals[0].shape, 0, device=opt.device)
    scale_num = 0
    nfc_prev = 0

    while scale_num < opt.stop_scale + 1:
        if scale_num != paint_inject_scale:
            scale_num += 1
            nfc_prev = opt.nfc
            continue
        else:
            opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)),
                          128)
            opt.min_nfc = min(
                opt.min_nfc_init * pow(2, math.floor(scale_num / 4)), 128)

            opt.out_ = functions.generate_dir2save(opt)
            opt.outf = '%s/%d' % (opt.out_, scale_num)
            try:
                os.makedirs(opt.outf)
            except OSError:
                pass

            #plt.imsave('%s/in.png' %  (opt.out_), functions.convert_image_np(real), vmin=0, vmax=1)
            #plt.imsave('%s/original.png' %  (opt.out_), functions.convert_image_np(real_), vmin=0, vmax=1)
            plt.imsave('%s/in_scale.png' % (opt.outf),
                       functions.convert_image_np(reals[scale_num]),
                       vmin=0,
                       vmax=1)

            D_curr, G_curr = init_models(opt)

            z_curr, in_s, G_curr = train_single_scale(D_curr,
                                                      G_curr,
                                                      reals[:scale_num + 1],
                                                      Gs[:scale_num],
                                                      Zs[:scale_num],
                                                      in_s,
                                                      NoiseAmp[:scale_num],
                                                      opt,
                                                      centers=centers)

            G_curr = functions.reset_grads(G_curr, False)
            G_curr.eval()
            D_curr = functions.reset_grads(D_curr, False)
            D_curr.eval()

            Gs[scale_num] = G_curr
            Zs[scale_num] = z_curr
            NoiseAmp[scale_num] = opt.noise_amp

            torch.save(Zs, '%s/Zs.pth' % (opt.out_))
            torch.save(Gs, '%s/Gs.pth' % (opt.out_))
            torch.save(reals, '%s/reals.pth' % (opt.out_))
            torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

            scale_num += 1
            nfc_prev = opt.nfc
        del D_curr, G_curr
    return
Example #2
def train(opt, Gs, Zs, reals, NoiseAmp):
    real_ = functions.read_image(opt)
    in_s = 0
    scale_num = 0
    real = imresize(real_, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(real, reals, opt)
    nfc_prev = 0

    while scale_num < opt.stop_scale + 1:
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)),
                          128)
        '''
        if (scale_num == opt.stop_scale):
            opt.nfc = 128
            opt.min_nfc = 128
        '''
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        #plt.imsave('%s/in.png' %  (opt.out_), functions.convert_image_np(real), vmin=0, vmax=1)
        #plt.imsave('%s/original.png' %  (opt.out_), functions.convert_image_np(real_), vmin=0, vmax=1)
        plt.imsave('%s/in_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[scale_num]),
                   vmin=0,
                   vmax=1)

        D_curr, G_curr = init_models(opt)
        if (nfc_prev == opt.nfc):
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, scale_num - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs,
                                                  Zs, in_s, NoiseAmp, opt)

        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        scale_num += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr
    return
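
All of these variants share the same channel schedule: the width of every conv block doubles every four scales and is capped at 128. A minimal sketch of the schedule (plain Python; the starting width of 32 and the pyramid depth of 10 are assumptions for illustration, the repos read them from opt.nfc_init and opt.stop_scale):

import math

nfc_init = 32   # assumed default; the code above reads opt.nfc_init
stop_scale = 9  # assumed pyramid depth, for illustration only

for scale_num in range(stop_scale + 1):
    # same formula as the loop above: double every 4 scales, cap at 128
    nfc = min(nfc_init * pow(2, math.floor(scale_num / 4)), 128)
    print(scale_num, nfc)
# scales 0-3 -> 32, scales 4-7 -> 64, scales 8-9 -> 128

This is also why the nfc_prev == opt.nfc check gates the load_state_dict calls: the previous scale's weights are only shape-compatible while the schedule stays flat.
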
Example #3
def train(opt,Gs,Zs,reals,NoiseAmp):
    real_ = functions.read_image(opt)
    in_s = 0
    # cur_scale_level: current level, from coarsest to finest.
    cur_scale_level = 0
    # scale1: ratio of the largest patch size w.r.t. the image shape
    reals = functions.creat_reals_pyramid(real_,reals,opt)
    nfc_prev = 0

    # Train including opt.stop_scale
    while cur_scale_level < opt.stop_scale+1:
        # nfc: number of out channels in conv block
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(cur_scale_level / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(cur_scale_level / 4)), 128)

        # out_: output directory
        # outf: output folder, with scale
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_,cur_scale_level)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        #plt.imsave('%s/in.png' %  (opt.out_), functions.convert_image_np(real), vmin=0, vmax=1)
        #plt.imsave('%s/original.png' %  (opt.out_), functions.convert_image_np(real_), vmin=0, vmax=1)
        plt.imsave('%s/real_scale.png' %  (opt.outf), functions.convert_image_np(reals[cur_scale_level]), vmin=0, vmax=1)

        D_curr,G_curr = init_models(opt)
        # Note: as the level increases, the CNN block architecture may change (every 4 levels, per the paper)
        if (nfc_prev==opt.nfc):
            G_curr.load_state_dict(torch.load('%s/%d/netG.pth' % (opt.out_,cur_scale_level-1)))
            D_curr.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_,cur_scale_level-1)))

        # in_s: presumably the initial signal; it remains a zero tensor throughout training
        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt)

        G_curr = functions.reset_grads(G_curr,False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr,False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        cur_scale_level+=1
        nfc_prev = opt.nfc
        del D_curr,G_curr
        torch.cuda.empty_cache()
    return
Example #4
def train(opt, Gs, Zs, reals, NoiseAmp):
    real_ = functions.read_image(
        opt)  # convert the input PNG into a (rows, cols, channels) tensor and
    # normalize its values into [-1, 1] via the norm operation and clamp
    #print(real_)
    in_s = 0
    scale_num = 0
    real = imresize(real_, opt.scale1, opt)  # start resizing the image
    reals = functions.creat_reals_pyramid(
        real, reals, opt)  # build the pyramid of resized images, one per scale
    nfc_prev = 0

    while scale_num < opt.stop_scale + 1:
        #print('real_ value is {}'.format(real_))
        #print('reals value is {}'.format(reals))
        #print('scale_num value is {}'.format(scale_num))
        #print('stop_scale value is {}'.format(opt.stop_scale))
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)),
                          128)
        #print('opt.nfc value is {}'.format(opt.nfc))
        #print('opt.min_nfc value is {}'.format(opt.min_nfc))
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        #plt.imsave('%s/in.png' %  (opt.out_), functions.convert_image_np(real), vmin=0, vmax=1)
        #plt.imsave('%s/original.png' %  (opt.out_), functions.convert_image_np(real_), vmin=0, vmax=1)
        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[scale_num]),
                   vmin=0,
                   vmax=1)

        D_curr, G_curr = init_models(opt)
        if (nfc_prev == opt.nfc):
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, scale_num - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs,
                                                  Zs, in_s, NoiseAmp, opt)

        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)  # Zs: the noise maps
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        scale_num += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr
    return
Example #5
def train(opt,Gs,Zs,reals1, reals2,NoiseAmp):
    logger.info("Starting to train...")

    reals1 = get_reals(reals1, opt, opt.input_name1)
    reals2 = get_reals(reals2, opt, opt.input_name2)
    in_s1 = 0
    in_s2 = 0
    scale_num = 0
    nfc_prev = 0

    while scale_num<opt.stop_scale+1:
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)), 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_,scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        #plt.imsave('%s/in.png' %  (opt.out_), functions.convert_image_np(real), vmin=0, vmax=1)
        #plt.imsave('%s/original.png' %  (opt.out_), functions.convert_image_np(real_), vmin=0, vmax=1)
        plt.imsave('%s/real_scale1.png' %  (opt.outf), functions.convert_image_np(reals1[scale_num]), vmin=0, vmax=1)
        plt.imsave('%s/real_scale2.png' % (opt.outf), functions.convert_image_np(reals2[scale_num]), vmin=0, vmax=1)

        D_curr, D_mask1_curr, D_mask2_curr, G_curr = init_models(opt, reals1[len(Gs)].shape)
        if (nfc_prev==opt.nfc):
            G_curr.load_state_dict(torch.load('%s/%d/netG.pth' % (opt.out_,scale_num-1)))
            D_curr.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_,scale_num-1)))
            D_mask1_curr.load_state_dict(torch.load('%s/%d/netD_mask1.pth' % (opt.out_, scale_num - 1)))
            D_mask2_curr.load_state_dict(torch.load('%s/%d/netD_mask2.pth' % (opt.out_, scale_num - 1)))

        logger.info(f"Starting to train scale {scale_num}")
        z_curr_tuple, in_s_tuple, G_curr = train_single_scale(
            D_curr, D_mask1_curr, D_mask2_curr, G_curr, reals1, reals2, Gs, Zs,
            in_s1, in_s2, NoiseAmp, opt)
        in_s1, in_s2 = in_s_tuple
        logger.info(f"Done training scale {scale_num}")

        G_curr = functions.reset_grads(G_curr,False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr,False)
        D_curr.eval()
        D_mask1_curr = functions.reset_grads(D_mask1_curr,False)
        D_mask1_curr.eval()
        D_mask2_curr = functions.reset_grads(D_mask2_curr,False)
        D_mask2_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr_tuple)
        NoiseAmp.append((opt.noise_amp1, opt.noise_amp2))

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals1, '%s/reals1.pth' % (opt.out_))
        torch.save(reals2, '%s/reals2.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        scale_num+=1
        nfc_prev = opt.nfc
        del D_curr, D_mask1_curr, D_mask2_curr,G_curr
    return
Example #6
def train(opt, Gs, Zs, reals, NoiseAmp):
    real_ = functions.read_image(opt)
    in_s = 0
    scale_num = 0
    real = imresize(real_, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(real, reals, opt)
    nfc_prev = 0
    # create a pyramid of masks the same way as for the image, so that training
    # uses only the valid pixels at every scale
    if opt.inpainting:
        m = functions.read_image_dir(
            '%s/%s_mask%s' %
            (opt.ref_dir, opt.input_name[:-4], opt.input_name[-4:]), opt)
        m = imresize(m, opt.scale1, opt)
        m_s = []  #pyramid of masks
        opt.m_s = functions.creat_reals_pyramid(m, m_s, opt)

    while scale_num < opt.stop_scale + 1:
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)),
                          128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        #plt.imsave('%s/in.png' %  (opt.out_), functions.convert_image_np(real), vmin=0, vmax=1)
        #plt.imsave('%s/original.png' %  (opt.out_), functions.convert_image_np(real_), vmin=0, vmax=1)
        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[scale_num]),
                   vmin=0,
                   vmax=1)

        D_curr, G_curr = init_models(opt)
        if (nfc_prev == opt.nfc):
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, scale_num - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs,
                                                  Zs, in_s, NoiseAmp, opt)

        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        scale_num += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr
    return
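
The mask path above is composed by splitting opt.input_name into stem and extension. A quick sketch of that string manipulation, with example values that are assumptions (the real values come from opt):

ref_dir = 'Input/Inpainting'   # assumed example value for opt.ref_dir
input_name = 'birds.png'       # assumed example value for opt.input_name

# input_name[:-4] strips the 4-character extension, input_name[-4:] keeps it
mask_path = '%s/%s_mask%s' % (ref_dir, input_name[:-4], input_name[-4:])
print(mask_path)  # Input/Inpainting/birds_mask.png
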
Example #7
def train(opt,Gs,Zs,reals,NoiseAmp):
    real_ = functions.read_image(opt)
    in_s = 0
    # cur_scale_level: current level, from coarsest to finest.
    cur_scale_level = 0
    # scale1: ratio of the largest patch size w.r.t. the image shape
    reals = functions.creat_reals_pyramid(real_,reals,opt)
    nfc_prev = 0

    # Train including opt.stop_scale
    while cur_scale_level < opt.stop_scale+1:
        # nfc: number of out channels in conv block
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(cur_scale_level / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(cur_scale_level / 4)), 128)

        # out_: output directory
        # outf: output folder, with scale
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_,cur_scale_level)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        #plt.imsave('%s/in.png' %  (opt.out_), functions.convert_image_np(real), vmin=0, vmax=1)
        #plt.imsave('%s/original.png' %  (opt.out_), functions.convert_image_np(real_), vmin=0, vmax=1)
        plt.imsave('%s/real_scale.png' %  (opt.outf), functions.convert_image_np(reals[cur_scale_level]), vmin=0, vmax=1)

        D_curr,G_curr = init_models(opt)
        # Note: as the level increases, the CNN block architecture may change (every 4 levels, per the paper)
        if (nfc_prev==opt.nfc):
            G_curr.load_state_dict(torch.load('%s/%d/netG.pth' % (opt.out_,cur_scale_level-1)))
            D_curr.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_,cur_scale_level-1)))

        # in_s: presumably the initial signal; it remains a zero tensor throughout training
        # fine_tune, warmup_steps, structured, and pruning_amount are module-level
        # settings in this variant; they are not defined inside this function
        if fine_tune:
            z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt, warmup_steps)
        else:
            z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt, opt.niter)


        G_curr = functions.reset_grads(G_curr,False)
        # D_curr = functions.reset_grads(D_curr,False)
        G_curr.eval()
        # D_curr.eval()

        #################################################################################
        # Visualize weights
        def visualize_weights(modules, fig_name):
            ori_weights = torch.tensor([]).cuda()
            for m in modules:
                cur_params = m.weight.data.flatten()
                ori_weights = torch.cat((ori_weights, cur_params))
                # cur_params = m.bias.data.flatten()
                # ori_weights = torch.cat((ori_weights, cur_params))
            sparsity = torch.sum(ori_weights == 0) * 1.0 / (ori_weights.nelement())
            print(sparsity, ori_weights.nelement())
            ori_weights = ori_weights.cpu().numpy()
            ori_weights = plt.hist(ori_weights[ori_weights != 0], bins=100)
            plt.savefig("%s/%s.png" % (opt.outf, fig_name))
            plt.close()

        # Pruning weights Structured or Non-structured
        if not structured:
            modules = [G_curr.head.conv, G_curr.head.norm,
                    G_curr.body.block1.conv, G_curr.body.block1.norm,
                    G_curr.body.block2.conv, G_curr.body.block2.norm,
                    G_curr.body.block3.conv, G_curr.body.block3.norm,
                    G_curr.tail[0]]
            parameters_to_prune = (
                (G_curr.head.conv, 'weight'),
                (G_curr.head.norm, 'weight'),
                (G_curr.body.block1.conv, 'weight'),
                (G_curr.body.block1.norm, 'weight'),
                (G_curr.body.block2.conv, 'weight'),
                (G_curr.body.block2.norm, 'weight'),
                (G_curr.body.block3.conv, 'weight'),
                (G_curr.body.block3.norm, 'weight'),
                (G_curr.tail[0], 'weight'),
                (G_curr.head.conv, 'bias'),
                (G_curr.head.norm, 'bias'),
                (G_curr.body.block1.conv, 'bias'),
                (G_curr.body.block1.norm, 'bias'),
                (G_curr.body.block2.conv, 'bias'),
                (G_curr.body.block2.norm, 'bias'),
                (G_curr.body.block3.conv, 'bias'),
                (G_curr.body.block3.norm, 'bias'),
                (G_curr.tail[0], 'bias'),
            )

            visualize_weights(modules, 'ori')

            # Prune weights
            prune.global_unstructured(
                parameters_to_prune,
                pruning_method=prune.L1Unstructured,
                amount=pruning_amount,
            )
        else:
            modules = [G_curr.head.conv,
            G_curr.body.block1.conv,
            G_curr.body.block2.conv,
            G_curr.body.block3.conv]

            visualize_weights(modules, 'ori')
            # pytorch_total_params = sum(p.numel() for p in G_curr.parameters())
            # print(pytorch_total_params)

            for module in modules:
                m = prune.ln_structured(module, name="weight", amount=pruning_amount, n=1, dim=0)
                # m = prune.ln_structured(module, name="bias", amount=pruning_amount, n=1, dim=0)

        torch.save(G_curr.state_dict(), '%s/raw_prune_netG.pth' % (opt.outf))
        visualize_weights(modules, 'raw-prune')
        if cur_scale_level > 0:
            fake_Gs = Gs.copy()
            fake_Gs.append(G_curr) 
            fake_Zs = Zs.copy()
            fake_Zs.append(z_curr)
            fake_noise = NoiseAmp.copy()
            fake_noise.append(opt.noise_amp)
            fake_reals = reals[:cur_scale_level+1].copy()
            prune_SinGAN_generate(fake_Gs, fake_Zs, fake_reals, fake_noise, opt, gen_start_scale=0, num_samples=1, level=cur_scale_level)

        # Fine-tuning
        if fine_tune:
            G_curr = functions.reset_grads(G_curr, True)
            G_curr.train()

            if not structured:
                # Keep training using inherited weights
                z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt, opt.niter - warmup_steps, prune=True)
            else:
                # Training from scratch
                # G_curr.apply(models.weights_init)
                # D_curr.apply(models.weights_init)
                z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt, opt.niter, prune=True)
        G_curr = functions.reset_grads(G_curr,False)
        G_curr.eval()
        visualize_weights(modules, 'fine-tune')

        for m in modules:
            prune.remove(m, 'weight')
            if not structured:
                prune.remove(m, 'bias')
        
        # pytorch_total_params = sum(p.numel() for p in G_curr.parameters())
        # print(pytorch_total_params)

        #################################################################################
        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/pruned_Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        cur_scale_level+=1
        nfc_prev = opt.nfc
        del D_curr,G_curr
    return
Example #8
def train(opt,Gs,Zs,reals,NoiseAmp):
    real_ = functions.read_image(opt)
    in_s = 0
    scale_num = 0
    real = imresize(real_,opt.scale1,opt)
    print ("real.shape:",real.shape)  #real.shape: torch.Size([1, 3, 186, 248])
    reals = functions.creat_reals_pyramid(real,reals,opt)
    for i in reals: print (i.shape)
    '''
    torch.Size([1, 3, 20, 27])
    torch.Size([1, 3, 27, 36])
    torch.Size([1, 3, 35, 47])
    torch.Size([1, 3, 47, 62])
    torch.Size([1, 3, 61, 82])
    torch.Size([1, 3, 81, 108])
    torch.Size([1, 3, 107, 143])
    torch.Size([1, 3, 141, 188])
    torch.Size([1, 3, 186, 248])
    '''
    nfc_prev = 0
    print ("total %d scales.."% (opt.stop_scale))
    while scale_num<opt.stop_scale+1:
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        if opt.fast_training:
            if (scale_num > 0) & (scale_num % 4==0):
                opt.niter = opt.niter//2

        '''
        if (scale_num == opt.stop_scale):
            opt.nfc = 128
            opt.min_nfc = 128
        '''
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_,scale_num)
        print ("out dir:",opt.outf)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        #plt.imsave('%s/in.png' %  (opt.out_), functions.convert_image_np(real), vmin=0, vmax=1)
        #plt.imsave('%s/original.png' %  (opt.out_), functions.convert_image_np(real_), vmin=0, vmax=1)
        plt.imsave('%s/real_scale.png' %  (opt.outf), functions.convert_image_np(reals[scale_num]), vmin=0, vmax=1)
        # create D and G; since the networks are fully convolutional, the input size does not matter
        D_curr,G_curr = init_models(opt)
        if (nfc_prev==opt.nfc):  # if the channel counts match, fine-tune from the previous scale's weights
            G_curr.load_state_dict(torch.load('%s/%d/netG.pth' % (opt.out_,scale_num-1)))
            D_curr.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_,scale_num-1)))
            
        # train this scale's networks; training one scale at a time keeps GPU memory usage low (under 3 GB)
        z_curr,in_s,G_curr = train_single_scale(D_curr,G_curr,reals,Gs,Zs,in_s,NoiseAmp,opt)

        G_curr = functions.reset_grads(G_curr,False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr,False)
        D_curr.eval()
        
        # keep the trained G network
        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        scale_num+=1
        nfc_prev = opt.nfc
        
        # delete D and G here to free GPU memory?
        del D_curr,G_curr
    return
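
The fast_training branch above halves opt.niter every four scales, so finer scales train for fewer iterations. A sketch of the resulting schedule (the starting value of 2000 is an assumption; the code reads opt.niter):

niter = 2000  # assumed starting value, read from opt.niter in the code above
for scale_num in range(10):
    if (scale_num > 0) & (scale_num % 4 == 0):
        niter = niter // 2
    print(scale_num, niter)
# scales 0-3 train for 2000 iterations, 4-7 for 1000, 8-9 for 500
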
Example #9
def train(opt, Gs, Zs, reals, crops, masks, NoiseAmp):
    real_ = functions.read_image(opt)
    real = imresize(real_, opt.scale1, opt)

    #real, _ , _ = functions.random_crop(real, opt.crop_size)
    mask_ = functions.read_mask(opt)
    #eye_ = functions.generate_eye_mask(opt, mask_, 0)
    crop_ = torch.zeros(
        (1, 1, opt.crop_size,
         opt.crop_size))  #Used just for size reference when downsizing
    #eye_color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
    eye_color = functions.get_eye_color(real)
    opt.eye_color = eye_color
    #torch.autograd.set_detect_anomaly(True)

    in_s = 0
    scale_num = 0

    reals = functions.create_pyramid(real, reals, opt)
    masks = functions.create_pyramid(mask_, masks, opt, mode="mask")
    #eyes = functions.create_pyramid(eye_,eyes,opt, mode = "mask")

    # Shortcut to get sizes of corresponding crops for each scale
    crops = functions.create_pyramid(crop_, crops, opt, mode="mask")

    nfc_prev = 0

    while scale_num < opt.stop_scale + 1:
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)),
                          128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        #plt.imsave('%s/in.png' %  (opt.out_), functions.convert_image_np(real), vmin=0, vmax=1)
        #plt.imsave('%s/original.png' %  (opt.out_), functions.convert_image_np(real_), vmin=0, vmax=1)
        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[scale_num]),
                   vmin=0,
                   vmax=1)

        D_curr, G_curr = init_models(opt)
        if (nfc_prev == opt.nfc):
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, scale_num - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, crops,
                                                  masks, eye_color, Gs, Zs,
                                                  in_s, NoiseAmp, opt)

        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        scale_num += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr

    return
Example #10
def train(opt, Gs, Zs, reals, NoiseAmp):
    real_ = functions.read_image(opt)
    in_s = 0
    # cur_scale_level: current level, from coarsest to finest.
    cur_scale_level = 0
    # real = imresize(real_,opt.scale1,opt)
    # scale1: ratio of the largest patch size w.r.t. the image shape
    reals = functions.creat_reals_pyramid(real_, reals, opt)
    upsamples = [reals[0]]
    diffs = [reals[0]]

    # One residual is needed for every level above the base, so upsample opt.stop_scale times
    for i in range(opt.stop_scale):
        cur_img = reals[i]
        next_img = reals[i + 1]
        _, b, c, d = next_img.shape
        upsampled_real = imresize_to_shape(cur_img, (c, d, b), opt)
        upsamples.append(upsampled_real)
        diff = (next_img - upsampled_real).abs() - 1  # [-1, 1]
        diffs.append(diff)

    # Train including opt.stop_scale
    while cur_scale_level < opt.stop_scale + 1:
        # nfc: number of out channels in conv block
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(cur_scale_level / 4)),
                      128)
        opt.min_nfc = min(
            opt.min_nfc_init * pow(2, math.floor(cur_scale_level / 4)), 128)

        # out_: output directory
        # outf: output folder, with scale
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, cur_scale_level)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[cur_scale_level]),
                   vmin=0,
                   vmax=1)
        plt.imsave('%s/diff.png' % (opt.outf),
                   functions.convert_image_np(diffs[cur_scale_level].detach()),
                   vmin=0,
                   vmax=1)
        plt.imsave('%s/upsampled.png' % (opt.outf),
                   functions.convert_image_np(
                       upsamples[cur_scale_level].detach()),
                   vmin=0,
                   vmax=1)

        D_curr, G_curr = init_models(opt)
        # Note: as the level increases, the CNN block architecture may change (every 4 levels, per the paper)

        # No need to reload, since training in parallel
        # if (nfc_prev==opt.nfc):
        #     G_curr.load_state_dict(torch.load('%s/%d/netG.pth' % (opt.out_,cur_scale_level-1)))
        #     D_curr.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_,cur_scale_level-1)))

        # in_s: presumably the initial signal; it remains a zero tensor throughout training
        # z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt)
        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals,
                                                  upsamples, cur_scale_level,
                                                  in_s, NoiseAmp, opt)

        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        cur_scale_level += 1
        # nfc_prev = opt.nfc
        del D_curr, G_curr
        torch.cuda.empty_cache()
    return
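
This variant trains on a residual pyramid: each level stores the difference between the real image at that scale and the upsampled real from the scale below, remapped into [-1, 1]. A minimal sketch of the residual computation on toy tensors, using torch.nn.functional.interpolate as a stand-in for the repo's imresize_to_shape:

import torch
import torch.nn.functional as F

# toy stand-ins for reals[i] and reals[i + 1], with values in [-1, 1]
coarse = torch.rand(1, 3, 16, 16) * 2 - 1
fine = torch.rand(1, 3, 32, 32) * 2 - 1

# upsample the coarser level to the finer resolution
upsampled = F.interpolate(coarse, size=fine.shape[2:], mode='bilinear',
                          align_corners=False)

# |fine - upsampled| lies in [0, 2]; subtracting 1 remaps it to [-1, 1],
# matching the "# [-1, 1]" comment in the loop above
diff = (fine - upsampled).abs() - 1
assert diff.min() >= -1 and diff.max() <= 1
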
Example #11
def train(opt, Gs, Zs, reals, NoiseAmp):
    real_ = functions.read_image(opt)
    in_s = 0
    scale_num = 0
    real = imresize(real_, opt.scale1, opt)

    # list of the image resized to each scale
    reals = functions.creat_reals_pyramid(real, reals, opt)
    # print('reals', reals)  # list of the images at every scale

    # plt.imsave('Output/real_scale_0.png', functions.convert_image_np(reals[0]), vmin=0, vmax=1)
    # plt.imsave('Output/real_scale_1.png', functions.convert_image_np(reals[1]), vmin=0, vmax=1)
    # plt.imsave('Output/real_scale_2.png', functions.convert_image_np(reals[2]), vmin=0, vmax=1)
    # plt.imsave('Output/real_scale_3.png', functions.convert_image_np(reals[3]), vmin=0, vmax=1)
    # plt.imsave('Output/real_scale_4.png', functions.convert_image_np(reals[4]), vmin=0, vmax=1)
    # plt.imsave('Output/real_scale_5.png', functions.convert_image_np(reals[5]), vmin=0, vmax=1)
    # plt.imsave('Output/real_scale_6.png', functions.convert_image_np(reals[6]), vmin=0, vmax=1)
    # plt.imsave('Output/real_scale_7.png', functions.convert_image_np(reals[7]), vmin=0, vmax=1)
    # plt.imsave('Output/real_scale_8.png', functions.convert_image_np(reals[8]), vmin=0, vmax=1)
    # plt.imsave('Output/real_scale_9.png', functions.convert_image_np(reals[9]), vmin=0, vmax=1)

    nfc_prev = 0

    # opt.stop_scale = 9, so the loop runs 10 times (scales 0..9)
    while scale_num < opt.stop_scale + 1:
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        print('opt.nfc', opt.nfc)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)),
                          128)
        print('opt.min_nfc', opt.min_nfc)
        if opt.fast_training:
            if (scale_num > 0) & (scale_num % 4 == 0):
                opt.niter = opt.niter // 2

        # out_ is the root output directory
        opt.out_ = functions.generate_dir2save(opt)
        # outf is the per-scale output directory
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        # plt.imsave('%s/in.png' %  (opt.out_), functions.convert_image_np(real), vmin=0, vmax=1)
        # save the original image
        plt.imsave('%s/original.png' % (opt.out_),
                   functions.convert_image_np(real_),
                   vmin=0,
                   vmax=1)
        # save real_scale in each scale's directory
        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[scale_num]),
                   vmin=0,
                   vmax=1)

        # init_models returns netD, netG: the current D_curr and G_curr
        D_curr, G_curr = init_models(opt)
        if (nfc_prev == opt.nfc):
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, scale_num - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))

        # train_single_scale() returns: z_opt, in_s, netG
        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs,
                                                  Zs, in_s, NoiseAmp, opt)

        G_curr = functions.reset_grads(G_curr, False)
        print('G_curr', G_curr)
        G_curr.eval()
        print(G_curr.eval())
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        scale_num += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr
    return
Example #12
def invert_model(test_image,
                 model_name,
                 scales2invert=None,
                 penalty=1e-3,
                 show=True):
    '''test_image is an array, model_name is a name'''
    Noise_Solutions = []

    parser = get_arguments()
    parser.add_argument('--input_dir',
                        help='input image dir',
                        default='Input/Images')

    parser.add_argument('--mode', default='RandomSamples')
    opt = parser.parse_args("")
    opt.input_name = model_name
    opt.reg = penalty

    if model_name == 'islands2_basis_2.jpg':  #HARDCODED
        opt.scale_factor = 0.6

    opt = functions.post_config(opt)

    ### Loading in Generators
    Gs, Zs, reals, NoiseAmp = functions.load_trained_pyramid(opt)
    for G in Gs:
        G = functions.reset_grads(G, False)
        G.eval()

    ### Loading in Ground Truth Test Images
    reals = []  #deleting old real images
    real = functions.np2torch(test_image, opt)
    functions.adjust_scales2image(real, opt)

    real_ = functions.np2torch(test_image, opt)
    real = imresize(real_, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(real, reals, opt)

    ### General Padding
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    m_noise = nn.ZeroPad2d(int(pad_noise))

    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    m_image = nn.ZeroPad2d(int(pad_image))

    I_prev = None
    REC_ERROR = 0

    if scales2invert is None:
        scales2invert = opt.stop_scale + 1

    for scale in range(scales2invert):
        #for scale in range(3):

        #Get X, G
        X = reals[scale]
        G = Gs[scale]
        noise_amp = NoiseAmp[scale]

        #Defining Dimensions
        opt.nc_z = X.shape[1]
        opt.nzx = X.shape[2]
        opt.nzy = X.shape[3]

        #getting parameters for prior distribution penalty
        pdf = torch.distributions.Normal(0, 1)
        alpha = opt.reg
        #alpha = 1e-2

        #Defining Z
        if scale == 0:
            z_init = functions.generate_noise(
                [1, opt.nzx, opt.nzy], device=opt.device)  #only 1D noise
        else:
            z_init = functions.generate_noise(
                [3, opt.nzx, opt.nzy],
                device=opt.device)  #otherwise move up to 3d noise

        z_init = Variable(z_init.cuda(),
                          requires_grad=True)  #variable to optimize

        #Building I_prev
        if I_prev is None:  #first scale scenario
            in_s = torch.full(reals[0].shape, 0, device=opt.device)  #all zeros
            I_prev = in_s
            I_prev = m_image(I_prev)  #padding

        else:  #otherwise take the output from the previous scale and upsample
            I_prev = imresize(I_prev, 1 / opt.scale_factor, opt)  #upsamples
            I_prev = m_image(I_prev)
            I_prev = I_prev[:, :, 0:X.shape[2] + 10, 0:X.shape[3] +
                            10]  #making sure that precision errors don't mess anything up
            I_prev = functions.upsampling(I_prev, X.shape[2] + 10,
                                          X.shape[3] + 10)  #seems to be redundant

        LR = [2e-3, 2e-2, 2e-1, 2e-1, 2e-1, 2e-1, 2e-1, 2e-1, 2e-1, 2e-1, 2e-1]
        Zoptimizer = torch.optim.RMSprop([z_init],
                                         lr=LR[scale])  #Defining Optimizer
        x_loss = []  #for plotting
        epochs = []  #for plotting

        niter = [
            200, 400, 400, 400, 400, 400, 400, 400, 400, 400, 400, 400, 400,
            400, 400
        ]
        for epoch in range(niter[scale]):  #Gradient Descent on Z

            if scale == 0:
                noise_input = m_noise(z_init.expand(1, 3, opt.nzx,
                                                    opt.nzy))  #expand and pad
            else:
                noise_input = m_noise(z_init)  #padding

            z_in = noise_amp * noise_input + I_prev
            G_z = G(z_in, I_prev)

            x_recLoss = F.mse_loss(G_z, X)  #MSE loss

            logProb = pdf.log_prob(z_init).mean()  #Gaussian loss

            loss = x_recLoss - (alpha * logProb.mean())

            Zoptimizer.zero_grad()
            loss.backward()
            Zoptimizer.step()

            #losses['rec'].append(x_recLoss.data[0])
            #print('Image loss: [%d] loss: %0.5f' % (epoch, x_recLoss.item()))
            #print('Noise loss: [%d] loss: %0.5f' % (epoch, z_recLoss.item()))
            x_loss.append(loss.item())
            epochs.append(epoch)

            REC_ERROR = x_recLoss

        if show:
            plt.plot(epochs, x_loss, label='x_loss')
            plt.legend()
            plt.show()

        I_prev = G_z.detach()  #take the final output; this line may need revisiting, something seems very fishy here

        _ = show_image(X, show, 'target')
        reconstructed_image = show_image(I_prev, show, 'output')
        _ = show_image(noise_input.detach().cpu(), show, 'noise')

        Noise_Solutions.append(noise_input.detach())
    return Noise_Solutions, reconstructed_image, REC_ERROR
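
The inversion loop above does gradient descent on the input noise: z is the only variable being optimized, the generator stays frozen, and the objective is reconstruction MSE minus a scaled Gaussian log-prior on z. A minimal self-contained sketch of that objective (the one-layer generator here is a hypothetical stand-in; the real G comes from functions.load_trained_pyramid):

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# hypothetical frozen generator standing in for one scale of the pyramid
G = nn.Conv2d(3, 3, 3, padding=1)
for p in G.parameters():
    p.requires_grad_(False)

X = torch.rand(1, 3, 16, 16)                       # target image at this scale
z = torch.randn(1, 3, 16, 16, requires_grad=True)  # the variable we optimize
pdf = torch.distributions.Normal(0, 1)
alpha = 1e-3                                       # prior penalty, opt.reg above

Zoptimizer = torch.optim.RMSprop([z], lr=2e-2)
for epoch in range(200):
    G_z = G(z)  # the full code also scales z by noise_amp and adds I_prev
    loss = F.mse_loss(G_z, X) - alpha * pdf.log_prob(z).mean()
    Zoptimizer.zero_grad()
    loss.backward()
    Zoptimizer.step()
print('final reconstruction loss:', F.mse_loss(G(z), X).item())
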
Example #13
def train(opt, Gs, Zs, reals, NoiseAmp):
    real_ = functions.read_image(opt)
    in_s = 0
    scale_num = 0  # iterator through the pyramid
    real = imresize(real_, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(real, reals, opt)
    nfc_prev = 0

    while scale_num < opt.stop_scale + 1:
        opt.nfc = min(
            opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128
        )  # every 4 levels in the pyramid, double the filter count, capped at 128 (not too wide)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)),
                          128)
        if opt.fast_training:
            if (scale_num > 0) & (scale_num % 4 == 0):
                # every 4 scales, halve the iteration count (train less for the finer details)
                opt.niter = opt.niter // 2

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        #plt.imsave('%s/in.png' %  (opt.out_), functions.convert_image_np(real), vmin=0, vmax=1)
        #plt.imsave('%s/original.png' %  (opt.out_), functions.convert_image_np(real_), vmin=0, vmax=1)
        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[scale_num]),
                   vmin=0,
                   vmax=1)

        D_curr, G_curr = init_models(opt)
        if (nfc_prev == opt.nfc):
            # if the channel counts match, load the previous scale's weights as the initialization
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, scale_num - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs,
                                                  Zs, in_s, NoiseAmp,
                                                  opt)  # real work

        G_curr = functions.reset_grads(
            G_curr,
            False)  # G no longer requires gradients, so its weights are frozen
        G_curr.eval()
        D_curr = functions.reset_grads(
            D_curr, False)  # note this D_curr is not the trained one...?
        D_curr.eval()

        Gs.append(
            G_curr
        )  # append G after training each scale; note this G is no longer trainable
        Zs.append(
            z_curr)  # Zs: the collection of z_opt up to the current layer
        NoiseAmp.append(opt.noise_amp)  # the noise amplitude is set inside train_single_scale?

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        scale_num += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr
    return
Example #14
def train(opt, Gs, Zs, reals, NoiseAmp):
    #load the input image by name
    real_ = functions.read_image(opt)
    in_s = 0
    scale_num = 0
    #scale1 is defined from adjust2scale, saved in opt
    real = imresize(real_, opt.scale1, opt)
    # a list of resized images
    reals = functions.creat_reals_pyramid(real, reals, opt)
    nfc_prev = 0

    #for scale 0 to stop scale
    for scale_num in tqdm_notebook(range(opt.stop_scale + 1),
                                   desc=opt.input_name,
                                   leave=True):

        #define the number of channels in this scale
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)

        #define the minimum number of channels in this scale
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)),
                          128)

        #the output main directory
        opt.out_ = functions.generate_dir2save(opt)

        #the output sub directory for each scale
        opt.outf = f'{opt.out_}/{scale_num}'

        #if need create the directory
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        #save the resized original image for this scale
        plt.imsave(f'{opt.outf}/real_scale.png',
                   functions.convert_image_np(reals[scale_num]),
                   vmin=0,
                   vmax=1)

        #create the generator and discriminator
        D_curr, G_curr = init_models(opt)

        #if the previous scale's channel count equals the current nfc
        if (nfc_prev == opt.nfc):

            #load the weights directly from the previous scale's model
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, scale_num - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))

        #train a single scale, get the current z, in_s, generator
        z_curr, in_s, G_curr = train_single_scale(
            D_curr,  #current discriminator
            G_curr,  #current generator
            reals,  #the list of all resized data
            Gs,  #generator list
            Zs,  # a list initialized as []
            in_s,  #initial signal
            NoiseAmp,  #noise amplitude list
            opt  #parameters
        )

        #make the current G and D untrainable and set them to eval mode
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        # save them into the list
        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        #save the checkpoints
        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        nfc_prev = opt.nfc
        #delete the D and G for memory
        del D_curr, G_curr
    return
Example #15
def train(opt, Gs, Zs, reals, NoiseAmp):
    real_ = functions.read_image(opt)
    in_s = 0
    scale_num = 0
    real = imresize(real_, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(real, reals, opt)
    nfc_prev = 0

    #If training for inpainting
    if opt.mode == "inpainting":
        #Importing mask image in space [0,255]

        if opt.on_drive is not None:
            mask = img.imread('%s/%s/%s' %
                              (opt.on_drive, opt.input_dir, opt.mask_name))
        else:
            mask = img.imread('%s/%s' % (opt.input_dir, opt.mask_name))

        #Convert mask to [0,1] space: 0 is the masked-out area, 1 everywhere else
        mask = 1 - (mask / 255)
        #Loading mask to torch tensor
        mask = torch.from_numpy(mask)

        #Resizing the initial mask
        mask = mask[:, :, :, None].view([1, 3, mask.shape[0], mask.shape[1]])
        mask = imresize(mask, opt.scale1, opt)

        #Creating mask pyramid
        opt.masks = functions.creat_reals_pyramid(mask, [], opt)

    while scale_num < opt.stop_scale + 1:
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)),
                          128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        #plt.imsave('%s/in.png' %  (opt.out_), functions.convert_image_np(real), vmin=0, vmax=1)
        #plt.imsave('%s/original.png' %  (opt.out_), functions.convert_image_np(real_), vmin=0, vmax=1)
        #plt.imsave('%s/real_scale.png' %  (opt.outf), functions.convert_image_np(reals[scale_num]), vmin=0, vmax=1)

        D_curr, G_curr = init_models(opt)
        if (nfc_prev == opt.nfc):
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, scale_num - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs,
                                                  Zs, in_s, NoiseAmp, opt)

        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        scale_num += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr
    return
Example #16
def train(opt, Gs, Zs, reals, NoiseAmp):
    real_ = functions.read_image(opt)
    in_s = 0
    scale_num = 0
    real = imresize(real_, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(real, reals, opt)
    nfc_prev = 0

    memory = []  ##storing memory
    time = []  ##storing time
    while scale_num < opt.stop_scale + 1:
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)),
                          128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        #plt.imsave('%s/in.png' %  (opt.out_), functions.convert_image_np(real), vmin=0, vmax=1)
        #plt.imsave('%s/original.png' %  (opt.out_), functions.convert_image_np(real_), vmin=0, vmax=1)
        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[scale_num]),
                   vmin=0,
                   vmax=1)

        D_curr, G_curr = init_models(opt)
        if (nfc_prev == opt.nfc):
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, scale_num - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))
        start = datetime.datetime.now()
        z_curr, in_s, G_curr, mbs, percent = train_single_scale(
            D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt)
        memory.append([mbs, percent])
        end = datetime.datetime.now()
        elapsed = end - start
        time.append(elapsed)
        print(f'time: {elapsed}')
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))
        torch.save(memory, '%s/full_memory.pth' % (opt.out_))
        torch.save(time, '%s/full_time.pth' % (opt.out_))

        scale_num += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr
    #torch.save(memory, '%s/full_memory.pk' % (opt.out_))
    #torch.save(time, '%s/full_time.pk' % (opt.out_))
    print(memory)
    print(time)
    #pk.dump(memory, open('Full_memory', 'wb'))
    return
Example #17
def train(opt, Gs, Zs, reals, NoiseAmp):
    print('train() current parameters')
    print(opt)
    real_ = functions.read_image(opt)
    in_s = 0
    if 'scale_num' in opt and opt.scale_num > 0:
        # EXPERIMENTAL: if we are in 'continue' mode
        in_s = torch.full(reals[0].shape, 0, device=opt.device)
    else:
        opt.scale_num = 0
    real = imresize(real_, opt.scale1, opt)
    reals = functions.create_reals_pyramid(real, reals, opt)
    if 'nfc_prev' not in opt:
        opt.nfc_prev = 0

    while opt.scale_num < opt.stop_scale + 1:
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(opt.scale_num / 4)),
                      128)
        opt.min_nfc = min(
            opt.min_nfc_init * pow(2, math.floor(opt.scale_num / 4)), 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, opt.scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError:
            print('directory %s already exists' % opt.outf)
            pass

        #plt.imsave('%s/in.png' %  (opt.out_), functions.convert_image_np(real), vmin=0, vmax=1)
        #plt.imsave('%s/original.png' %  (opt.out_), functions.convert_image_np(real_), vmin=0, vmax=1)
        plt.imsave('%s/real_scale.png' % opt.outf,
                   functions.convert_image_np(reals[opt.scale_num]),
                   vmin=0,
                   vmax=1)

        D_curr, G_curr = init_models(opt)
        if opt.nfc_prev == opt.nfc:
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, opt.scale_num - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, opt.scale_num - 1)))

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs,
                                                  Zs, in_s, NoiseAmp, opt)

        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % opt.out_)
        torch.save(Gs, '%s/Gs.pth' % opt.out_)
        torch.save(reals, '%s/reals.pth' % opt.out_)
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % opt.out_)

        opt.scale_num += 1
        opt.nfc_prev = opt.nfc
        del D_curr, G_curr
    return
Example #18
def train(opt, Gs, Zs, reals, NoiseAmp):
    real_ = functions.read_image(opt)
    in_s = 0
    # cur_scale_level: current level, from coarsest to finest.
    cur_scale_level = 0
    # scale1: ratio of the largest patch size w.r.t. the image shape
    reals = functions.creat_reals_pyramid(real_, reals, opt)
    nfc_prev = 0

    # Train including opt.stop_scale
    while cur_scale_level < opt.stop_scale + 1:
        # nfc: number of out channels in conv block
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(cur_scale_level / 4)),
                      128)
        opt.min_nfc = min(
            opt.min_nfc_init * pow(2, math.floor(cur_scale_level / 4)), 128)

        # out_: output directory
        # outf: output folder, with scale
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, cur_scale_level)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        #plt.imsave('%s/in.png' %  (opt.out_), functions.convert_image_np(real), vmin=0, vmax=1)
        #plt.imsave('%s/original.png' %  (opt.out_), functions.convert_image_np(real_), vmin=0, vmax=1)
        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[cur_scale_level]),
                   vmin=0,
                   vmax=1)

        D_curr, G_curr = init_models(opt)
        # Note: as the level increases, the CNN block architecture may change (every 4 levels, per the paper)
        if (nfc_prev == opt.nfc):
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, cur_scale_level - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, cur_scale_level - 1)))

        # in_s: presumably the initial signal; it remains a zero tensor throughout training
        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs,
                                                  Zs, in_s, NoiseAmp, opt)

        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()

        #################################################################################
        # Visualize weights
        def visualize_weights(modules, fig_name):
            ori_weights = torch.tensor([]).cuda()
            for m in modules:
                cur_params = m.weight.data.flatten()
                ori_weights = torch.cat((ori_weights, cur_params))
                cur_params = m.bias.data.flatten()
                ori_weights = torch.cat((ori_weights, cur_params))
            # sparsity = torch.sum(ori_weights == 0) * 1.0 / (ori_weights.nelement())
            ori_weights = ori_weights.cpu().numpy()
            ori_weights = plt.hist(ori_weights[ori_weights != 0], bins=100)
            plt.savefig("%s/%s.png" % (opt.outf, fig_name))
            plt.close()

        # Pruning all weights
        modules = [
            G_curr.head.conv, G_curr.head.norm, G_curr.body.block1.conv,
            G_curr.body.block1.norm, G_curr.body.block2.conv,
            G_curr.body.block2.norm, G_curr.body.block3.conv,
            G_curr.body.block3.norm, G_curr.tail[0]
        ]
        parameters_to_prune = (
            (G_curr.head.conv, 'weight'), (G_curr.head.conv, 'bias'),
            (G_curr.head.norm, 'weight'), (G_curr.head.norm, 'bias'),
            (G_curr.body.block1.conv, 'weight'), (G_curr.body.block1.conv, 'bias'),
            (G_curr.body.block1.norm, 'weight'), (G_curr.body.block1.norm, 'bias'),
            (G_curr.body.block2.conv, 'weight'), (G_curr.body.block2.conv, 'bias'),
            (G_curr.body.block2.norm, 'weight'), (G_curr.body.block2.norm, 'bias'),
            (G_curr.body.block3.conv, 'weight'), (G_curr.body.block3.conv, 'bias'),
            (G_curr.body.block3.norm, 'weight'), (G_curr.body.block3.norm, 'bias'),
            (G_curr.tail[0], 'weight'), (G_curr.tail[0], 'bias'))

        visualize_weights(modules, 'ori')

        # Prune weights
        prune.global_unstructured(
            parameters_to_prune,
            pruning_method=prune.L1Unstructured,
            amount=0.2,
        )

        for m in modules:
            prune.remove(m, 'weight')
            prune.remove(m, 'bias')

        visualize_weights(modules, 'prune')
        G_curr.half()
        #################################################################################
        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/pruned_Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        cur_scale_level += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr
        torch.cuda.empty_cache()
    return
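
Both pruning variants rely on torch.nn.utils.prune: global_unstructured zeroes the weights with the smallest L1 magnitude pooled across all listed parameters, and prune.remove then folds the mask into the plain .weight tensor. A minimal sketch on a toy module (the two convolutions are stand-ins for the generator blocks listed above):

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

# toy stand-ins for the generator blocks pruned above
conv1 = nn.Conv2d(3, 8, 3)
conv2 = nn.Conv2d(8, 3, 3)
parameters_to_prune = ((conv1, 'weight'), (conv2, 'weight'))

# zero out the 20% of weights with the smallest L1 magnitude, pooled globally
prune.global_unstructured(parameters_to_prune,
                          pruning_method=prune.L1Unstructured,
                          amount=0.2)

# make the pruning permanent: fold weight_orig * weight_mask back into .weight
for m in (conv1, conv2):
    prune.remove(m, 'weight')

total = sum(m.weight.nelement() for m in (conv1, conv2))
zeros = sum(int((m.weight == 0).sum()) for m in (conv1, conv2))
print('sparsity: %.2f' % (zeros / total))  # ~0.20
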