Example no. 1
def train(opt):
    print("Training model with the following parameters:")
    print("\t number of stages: {}".format(opt.train_stages))
    print("\t number of concurrently trained stages: {}".format(opt.train_depth))
    print("\t learning rate scaling: {}".format(opt.lr_scale))
    print("\t non-linearity: {}".format(opt.activation))

    real, real2 = functions.read_two_domains(opt)
    # real = functions.read_image(opt)
    # print(0, real.shape)
    real = functions.adjust_scales2image(real, opt)
    reals = functions.create_reals_pyramid(real, opt)

    real2 = functions.adjust_scales2image(real2, opt)
    reals2 = functions.create_reals_pyramid(real2, opt)

    generator, generator2 = init_G(opt)
    fixed_noise = []
    noise_amp = []
    fixed_noise2 = []
    noise_amp2 = []
    for scale_num in range(opt.stop_scale+1):
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_,scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError as e:
            print(e)
        functions.save_image('{}/real_scale.jpg'.format(opt.outf), reals[scale_num])

        d_curr, d_curr2 = init_D(opt)
        if scale_num > 0:
            d_curr.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_,scale_num-1)))
            generator.init_next_stage()
            d_curr2.load_state_dict(torch.load('%s/%d/netD2.pth' % (opt.out_,scale_num-1)))
            generator2.init_next_stage()

        writer = SummaryWriter(log_dir=opt.outf)
        fixed_noise, noise_amp, generator, d_curr, fixed_noise2, noise_amp2, generator2, d_curr2 = \
            train_single_scale(d_curr, generator, reals, fixed_noise, noise_amp, d_curr2, generator2, reals2,
                               fixed_noise2, noise_amp2, opt, scale_num, writer)

        torch.save(fixed_noise, '%s/fixed_noise.pth' % (opt.out_))
        torch.save(generator, '%s/G.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(noise_amp, '%s/noise_amp.pth' % (opt.out_))
        torch.save(fixed_noise2, '%s/fixed_noise2.pth' % (opt.out_))
        torch.save(generator2, '%s/G2.pth' % (opt.out_))
        torch.save(reals2, '%s/reals2.pth' % (opt.out_))
        torch.save(noise_amp2, '%s/noise_amp2.pth' % (opt.out_))
        del d_curr, d_curr2
    writer.close()
    return
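Note: each stage warm-starts its discriminator from the previous stage's weights while the generator grows by one block. A minimal, self-contained sketch of that pattern (toy module and hypothetical names, not the repository's actual classes):

    import torch.nn as nn

    class ToyD(nn.Module):
        def __init__(self):
            super().__init__()
            self.net = nn.Conv2d(3, 1, 3, padding=1)

        def forward(self, x):
            return self.net(x)

    prev_state = None
    for scale in range(3):
        d_curr = ToyD()
        if scale > 0:
            # warm-start from the previous scale, as the
            # d_curr.load_state_dict(torch.load(...)) calls above do
            d_curr.load_state_dict(prev_state)
        # ... train d_curr at this scale ...
        prev_state = d_curr.state_dict()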
Example no. 2
def generate_samples(netG,
                     reals_shapes,
                     noise_amp,
                     scale_w=1.0,
                     scale_h=1.0,
                     reconstruct=False,
                     n=50):
    if reconstruct:
        reconstruction = netG(fixed_noise, reals_shapes, noise_amp)
        if opt.train_mode == "generation" or opt.train_mode == "retarget":
            functions.save_image('{}/reconstruction.jpg'.format(dir2save),
                                 reconstruction.detach())
            functions.save_image('{}/real_image.jpg'.format(dir2save),
                                 reals[-1].detach())
        elif opt.train_mode == "harmonization" or opt.train_mode == "editing":
            functions.save_image('{}/{}_wo_mask.jpg'.format(dir2save, _name),
                                 reconstruction.detach())
            functions.save_image(
                '{}/real_image.jpg'.format(dir2save),
                imresize_to_shape(real, reals_shapes[-1][2:], opt).detach())
        return reconstruction

    if scale_w == 1. and scale_h == 1.:
        dir2save_parent = os.path.join(dir2save, "random_samples")
    else:
        reals_shapes = [[
            r_shape[0], r_shape[1],
            int(r_shape[2] * scale_h),
            int(r_shape[3] * scale_w)
        ] for r_shape in reals_shapes]
        dir2save_parent = os.path.join(
            dir2save,
            "random_samples_scale_h_{}_scale_w_{}".format(scale_h, scale_w))

    make_dir(dir2save_parent)

    for idx in range(n):
        noise = functions.sample_random_noise(opt.train_stages - 1,
                                              reals_shapes, opt)
        sample = netG(noise, reals_shapes, noise_amp)
        functions.save_image(
            '{}/gen_sample_{}.jpg'.format(dir2save_parent, idx),
            sample.detach())
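Note: the scale_w/scale_h branch above simply rescales every pyramid shape before sampling. Isolated as a pure-Python helper (hypothetical name):

    def rescale_shapes(reals_shapes, scale_h=1.0, scale_w=1.0):
        # each shape is (batch, channels, height, width)
        return [[s[0], s[1], int(s[2] * scale_h), int(s[3] * scale_w)]
                for s in reals_shapes]

    print(rescale_shapes([[1, 3, 32, 32], [1, 3, 64, 64]], scale_h=1.5, scale_w=2.0))
    # [[1, 3, 48, 64], [1, 3, 96, 128]]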
Example no. 3
def generate_samples(netG, opt, depth, noise_amp, writer, reals, iter, n=25):
    opt.out_ = functions.generate_dir2save(opt)
    dir2save = '{}/gen_samples_stage_{}'.format(opt.out_, depth)
    reals_shapes = [r.shape for r in reals]
    all_images = []
    try:
        os.makedirs(dir2save)
    except OSError:
        pass
    with torch.no_grad():
        for idx in range(n):
            noise = functions.sample_random_noise(depth, reals_shapes, opt)
            sample = netG(noise, reals_shapes, noise_amp)
            all_images.append(sample)
            functions.save_image('{}/gen_sample_{}.jpg'.format(dir2save, idx), sample.detach())

        all_images = torch.cat(all_images, 0)
        all_images[0] = reals[depth].squeeze()
        grid = make_grid(all_images, nrow=min(5, n), normalize=True)
        writer.add_image('gen_images_{}'.format(depth), grid, iter)
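Note: the grid logging above combines torchvision's make_grid with TensorBoard. A minimal standalone version, with random tensors standing in for generated samples and a hypothetical log directory:

    import torch
    from torchvision.utils import make_grid
    from torch.utils.tensorboard import SummaryWriter

    samples = torch.rand(25, 3, 64, 64)           # stand-ins for generated images
    grid = make_grid(samples, nrow=5, normalize=True)
    writer = SummaryWriter(log_dir="logs")        # hypothetical directory
    writer.add_image("gen_images_0", grid, 0)     # tag, image, global step
    writer.close()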
Example no. 4
def train_single_scale(netD, netG, reals, fixed_noise, noise_amp, netD2, netG2,
                       reals2, fixed_noise2, noise_amp2, opt, depth, writer,
                       fakes, fakes2, in_s, in_s2):
    reals_shapes = [real.shape for real in reals]
    real = reals[depth]
    # reals_shapes2 = [real2.shape for real2 in reals2]
    real2 = reals2[depth]

    # alpha = opt.alpha
    # lambda_idt = opt.lambda_idt
    # lambda_cyc = opt.lambda_cyc
    # lambda_tv = opt.lambda_tv
    lambda_idt = 1
    lambda_cyc = 1
    lambda_tv = 1

    ############################
    # define z_opt for training on reconstruction
    ###########################
    z_opt = functions.generate_noise(
        [3,  # opt.nfc
         reals_shapes[depth][2], reals_shapes[depth][3]],
        device=opt.device)

    z_opt2 = functions.generate_noise(
        [3, reals_shapes[depth][2], reals_shapes[depth][3]], device=opt.device)

    fixed_noise.append(z_opt.detach())
    fixed_noise2.append(z_opt2.detach())

    ############################
    # define optimizers, learning rate schedulers, and learning rates for lower stages
    ###########################
    # setup optimizers for D
    optimizerD = optim.Adam(itertools.chain(netD.parameters(),
                                            netD2.parameters()),
                            lr=opt.lr_d,
                            betas=(opt.beta1, 0.999))

    # setup optimizers for G
    # remove gradients from stages that are not trained
    for block in netG.body[:-opt.train_depth]:
        for param in block.parameters():
            param.requires_grad = False

    # set different learning rate for lower stages
    parameter_list = [{"params": block.parameters(),
                       "lr": opt.lr_g * (opt.lr_scale ** (len(netG.body[-opt.train_depth:]) - 1 - idx))}
                      for idx, block in enumerate(netG.body[-opt.train_depth:])]

    # add parameters of head and tail to training
    # if depth - opt.train_depth < 0:
    #     parameter_list += [{"params": netG.head.parameters(), "lr": opt.lr_g * (opt.lr_scale**depth)}]
    parameter_list += [{"params": netG.tail.parameters(), "lr": opt.lr_g}]
    parameter_list += [{"params": netG.head2.parameters(), "lr": opt.lr_g}]
    parameter_list += [{"params": netG.body2.parameters(), "lr": opt.lr_g}]
    parameter_list += [{"params": netG.tail2.parameters(), "lr": opt.lr_g}]

    for block in netG2.body[:-opt.train_depth]:
        for param in block.parameters():
            param.requires_grad = False

    # set different learning rate for lower stages
    parameter_list2 = [{"params": block.parameters(),
                        "lr": opt.lr_g * (opt.lr_scale ** (len(netG2.body[-opt.train_depth:]) - 1 - idx))}
                       for idx, block in enumerate(netG2.body[-opt.train_depth:])]

    # add parameters of head and tail to training
    # if depth - opt.train_depth < 0:
    #     parameter_list2 += [{"params": netG2.head.parameters(), "lr": opt.lr_g * (opt.lr_scale**depth)}]
    parameter_list2 += [{"params": netG2.tail.parameters(), "lr": opt.lr_g}]
    parameter_list2 += [{"params": netG2.head2.parameters(), "lr": opt.lr_g}]
    parameter_list2 += [{"params": netG2.body2.parameters(), "lr": opt.lr_g}]
    parameter_list2 += [{"params": netG2.tail2.parameters(), "lr": opt.lr_g}]

    optimizerG = optim.Adam(itertools.chain(parameter_list, parameter_list2),
                            lr=opt.lr_g,
                            betas=(opt.beta1, 0.999))

    # define learning rate schedules
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(
        optimizer=optimizerD, milestones=[0.8 * opt.niter], gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(
        optimizer=optimizerG, milestones=[0.8 * opt.niter], gamma=opt.gamma)

    ############################
    # calculate noise_amp
    ###########################
    # if depth == 0:
    #     noise_amp.append(1)
    # else:
    #     noise_amp.append(0)

    # start training
    _iter = tqdm(range(opt.niter))
    loss_print = {}
    for iter in _iter:
        _iter.set_description('stage [{}/{}]:'.format(depth, opt.stop_scale))

        ############################
        # (0) sample noise for unconditional generation
        ###########################
        # noise = functions.sample_random_noise(reals, depth, reals_shapes, opt, noise_amp)
        # noise2 = functions.sample_random_noise(reals2, depth, reals_shapes, opt, noise_amp)
        # 1.1.1     1.2.1   2.1.1   2.2.1

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            # netD.zero_grad()
            optimizerD.zero_grad()

            output = netD(real2).to(opt.device)
            errD_real = -output.mean()
            errD_real.backward(retain_graph=True)
            loss_print['errD_real'] = errD_real.item()

            output2 = netD2(real).to(opt.device)
            errD_real2 = -output2.mean()
            errD_real2.backward(retain_graph=True)
            loss_print['errD_real2'] = errD_real2.item()

            if (j == 0) & (iter == 0):
                if depth == 0:  # 1 opt.bsz   1.1.1
                    noise_amp.append(1)
                    noise_amp2.append(1)
                    prev = torch.zeros([1, opt.nc_im, reals_shapes[0][2], reals_shapes[0][3]],
                                       device=opt.device)
                    in_s.append(prev)
                    # in_s_ = prev

                    prev2 = torch.zeros([1, opt.nc_im, reals_shapes[0][2], reals_shapes[0][3]],
                                        device=opt.device)
                    in_s2.append(prev2)
                    # in_s2_ = prev2

                    c_prev = torch.zeros([1, opt.nc_im, reals_shapes[depth][2], reals_shapes[depth][3]],
                                         device=opt.device)
                    z_prev = torch.zeros([1, opt.nc_im, reals_shapes[depth][2], reals_shapes[depth][3]],
                                         device=opt.device)

                    c_prev2 = torch.zeros([1, opt.nc_im, reals_shapes[depth][2], reals_shapes[depth][3]],
                                          device=opt.device)
                    z_prev2 = torch.zeros([1, opt.nc_im, reals_shapes[depth][2], reals_shapes[depth][3]],
                                          device=opt.device)

                else:  # 2.1.1
                    # in_s2 = in_s2_
                    # in_s = in_s_
                    prev2, c_prev2 = cycle_rec(netG2, netG, fixed_noise2,
                                               reals2, noise_amp2, opt, depth,
                                               reals_shapes, in_s2)
                    prev, c_prev = cycle_rec(netG, netG2, fixed_noise, reals,
                                             noise_amp, opt, depth,
                                             reals_shapes, in_s)
                    z_prev2 = draw_concat(netG, reals2, 'rec', opt, depth,
                                          reals_shapes, in_s2)
                    z_prev = draw_concat(netG2, reals, 'rec', opt, depth,
                                         reals_shapes, in_s)
            else:  # 1.1.2     1.1.3       2.1.2       2.1.3   2.2.1
                if len(in_s2) > 1:
                    ins2_index = in_s2[:-1]
                    ins_index = in_s[:-1]
                else:
                    ins2_index = in_s2
                    ins_index = in_s
                prev2, c_prev2 = cycle_rec(netG2, netG, fixed_noise2, reals2,
                                           noise_amp2, opt, depth,
                                           reals_shapes, ins2_index)
                prev, c_prev = cycle_rec(netG, netG2, fixed_noise, reals,
                                         noise_amp, opt, depth, reals_shapes,
                                         ins_index)

            if j == 0:  # 1.1.1    1.2.1     2.1.1      2.2.1
                if depth > 0:
                    in_s_ = torch.zeros([1, opt.nc_im, reals_shapes[depth][2], reals_shapes[depth][3]],
                                        device=opt.device)
                    in_s2_ = torch.zeros([1, opt.nc_im, reals_shapes[depth][2], reals_shapes[depth][3]],
                                         device=opt.device)
                    noise_amp.append(0)
                    noise_amp2.append(0)
                    z_reconstruction = netG2(fixed_noise, in_s_, reals_shapes)
                    z_reconstruction2 = netG(fixed_noise2, in_s2_,
                                             reals_shapes)
                    if iter != 0:
                        in_s.pop()
                        in_s2.pop()
                    in_s.append(in_s_)
                    in_s2.append(in_s2_)

                    criterion = nn.MSELoss()
                    rec_loss = criterion(z_reconstruction, real)
                    rec_loss2 = criterion(z_reconstruction2, real2)

                    RMSE = torch.sqrt(rec_loss).detach()
                    RMSE2 = torch.sqrt(rec_loss2).detach()
                    _noise_amp = 0.1 * RMSE  # opt.noise_amp_init
                    _noise_amp2 = 0.1 * RMSE2
                    noise_amp[-1] = _noise_amp
                    noise_amp2[-1] = _noise_amp2
                noise = functions.sample_random_noise(reals, depth,
                                                      reals_shapes, opt,
                                                      noise_amp)
                noise2 = functions.sample_random_noise(reals2, depth,
                                                       reals_shapes, opt,
                                                       noise_amp2)

            # train with fake
            if j == opt.Dsteps - 1:
                fake = netG(noise, prev, reals_shapes)
                fake2 = netG2(noise2, prev2, reals_shapes)
            else:
                with torch.no_grad():
                    fake = netG(noise, prev, reals_shapes)
                    fake2 = netG2(noise2, prev2, reals_shapes)
            output = netD(fake.detach())
            errD_fake = output.mean()
            errD_fake.backward(retain_graph=True)
            loss_print['errD_fake'] = errD_fake.item()

            gradient_penalty = functions.calc_gradient_penalty(
                netD, real2, fake, opt.lambda_grad, opt.device)
            gradient_penalty.backward()
            loss_print['gradient_penalty'] = gradient_penalty.item()

            output2 = netD2(fake2.detach())
            errD_fake2 = output2.mean()
            errD_fake2.backward(retain_graph=True)
            loss_print['errD_fake2'] = errD_fake2.item()

            gradient_penalty2 = functions.calc_gradient_penalty(
                netD2, real, fake2, opt.lambda_grad, opt.device)
            gradient_penalty2.backward()
            loss_print['gradient_penalty2'] = gradient_penalty2.item()

            optimizerD.step()

        if iter != 0:
            fakes.pop()
            fakes2.pop()
        fakes.append(fake)
        fakes2.append(fake2)

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################
        optimizerG.zero_grad()
        loss_tv = TVLoss()

        output = netD(fake)
        errG = -output.mean() + lambda_tv * loss_tv(fake)
        errG.backward(retain_graph=True)
        loss_print['errG'] = errG.item()

        output2 = netD2(fake2)
        errG2 = -output2.mean() + lambda_tv * loss_tv(fake2)
        errG2.backward(retain_graph=True)
        loss_print['errG2'] = errG2.item()

        loss = nn.L1Loss()  # nn.MSELoss()

        rec = netG(fixed_noise2, z_prev2, reals_shapes)
        rec_loss = lambda_idt * loss(rec, real2)
        rec_loss.backward(retain_graph=True)
        loss_print['rec_loss'] = rec_loss.item()
        rec_loss = rec_loss.detach()

        cyc = netG(fakes2, c_prev2, reals_shapes)
        cyc_loss = lambda_cyc * loss(cyc, real2)
        cyc_loss.backward(retain_graph=True)
        loss_print['cyc_loss'] = cyc_loss.item()
        cyc_loss = cyc_loss.detach()

        rec2 = netG2(fixed_noise, z_prev, reals_shapes)
        rec_loss2 = lambda_idt * loss(rec2, real)
        rec_loss2.backward(retain_graph=True)
        loss_print['rec_loss2'] = rec_loss2.item()
        rec_loss2 = rec_loss2.detach()

        cyc2 = netG2(fakes, c_prev, reals_shapes)
        cyc_loss2 = lambda_cyc * loss(cyc2, real)
        cyc_loss2.backward(retain_graph=True)
        loss_print['cyc_loss2'] = cyc_loss2.item()
        cyc_loss2 = cyc_loss2.detach()

        for _ in range(opt.Gsteps):
            optimizerG.step()

        ############################
        # (3) Log Results
        ###########################
        if iter % 500 == 0 or iter == (opt.niter - 1):
            functions.save_image(
                '{}/fake_sample_{}.jpg'.format(opt.outf, iter + 1),
                fake.detach())
            functions.save_image(
                '{}/fake_sample2_{}.jpg'.format(opt.outf, iter + 1),
                fake2.detach())
            # functions.save_image('{}/reconstruction_{}.jpg'.format(opt.outf, iter+1), rec.detach())
            # functions.save_image('{}/reconstruction2_{}.jpg'.format(opt.outf, iter+1), rec2.detach())
            # generate_samples(netG, opt, depth, noise_amp, writer, reals, iter+1)

            log = " Iteration [{}/{}]".format(iter, opt.niter)
            for tag, value in loss_print.items():
                log += ", {}: {:.4f}".format(tag, value)
            print(log)
        # if iter % 250 == 0 or iter+1 == opt.niter:
        #     writer.add_scalar('Loss/train/D/real/{}'.format(j), -errD_real.item(), iter+1)
        #     writer.add_scalar('Loss/train/D/fake/{}'.format(j), errD_fake.item(), iter+1)
        #     writer.add_scalar('Loss/train/D/gradient_penalty/{}'.format(j), gradient_penalty.item(), iter+1)
        #     writer.add_scalar('Loss/train/D/real2/{}'.format(j), -errD_real2.item(), iter+1)
        #     writer.add_scalar('Loss/train/D/fake2/{}'.format(j), errD_fake2.item(), iter+1)
        #     writer.add_scalar('Loss/train/D/gradient_penalty2/{}'.format(j), gradient_penalty2.item(), iter+1)
        #
        #     writer.add_scalar('Loss/train/G/gen', errG.item(), iter+1)
        #     writer.add_scalar('Loss/train/G/reconstruction', rec_loss.item(), iter+1)
        #     writer.add_scalar('Loss/train/G/cycle', cyc_loss.item(), iter+1)
        #     writer.add_scalar('Loss/train/G/gen2', errG2.item(), iter+1)
        #     writer.add_scalar('Loss/train/G/reconstruction2', rec_loss2.item(), iter+1)
        #     writer.add_scalar('Loss/train/G/cycle2', cyc_loss2.item(), iter+1)
        #
        # if iter % 500 == 0 or iter+1 == opt.niter:
        #     functions.save_image('{}/fake_sample_{}.jpg'.format(opt.outf, iter+1), fake.detach())
        #     functions.save_image('{}/reconstruction_{}.jpg'.format(opt.outf, iter+1), rec.detach())
        #     # generate_samples(netG, opt, depth, noise_amp, writer, reals, iter+1)

        schedulerD.step()
        schedulerG.step()

    functions.save_networks(netG, netD, z_opt, netG2, netD2, z_opt2, opt)
    return fixed_noise, noise_amp, netG, netD, fixed_noise2, noise_amp2, netG2, netD2, in_s, in_s2
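Note: functions.calc_gradient_penalty is not shown in these examples. Judging by its call signature, it is presumably the standard WGAN-GP penalty (Gulrajani et al., 2017); a sketch under that assumption, which may differ from the repository's implementation:

    import torch

    def calc_gradient_penalty(netD, real, fake, lambda_grad, device):
        # interpolate between real and fake samples
        alpha = torch.rand(real.size(0), 1, 1, 1, device=device)
        interp = (alpha * real + (1 - alpha) * fake.detach()).requires_grad_(True)
        d_interp = netD(interp)
        grads = torch.autograd.grad(outputs=d_interp, inputs=interp,
                                    grad_outputs=torch.ones_like(d_interp),
                                    create_graph=True, retain_graph=True)[0]
        # penalize deviation of the per-sample gradient norm from 1
        grad_norm = grads.view(grads.size(0), -1).norm(2, dim=1)
        return lambda_grad * ((grad_norm - 1) ** 2).mean()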
Example no. 5
                                              functions.generate_noise([opt.nc_im, fixed_noise[0].shape[2],
                                                                        fixed_noise[0].shape[3]],
                                                                        device=opt.device)

        out = generate_samples(netG, reals_shapes, noise_amp, reconstruct=True)

        mask_file_name = '{}_mask{}'.format(opt.naive_img[:-4],
                                            opt.naive_img[-4:])
        if os.path.exists(mask_file_name):
            mask = functions.read_image_dir(mask_file_name, opt)
            if mask.shape[3] != out.shape[3]:
                mask = imresize_to_shape(mask, [out.shape[2], out.shape[3]],
                                         opt)
            mask = functions.dilate_mask(mask, opt)
            out = (1 - mask) * reals[-1] + mask * out
            functions.save_image('{}/{}_w_mask.jpg'.format(dir2save, _name),
                                 out.detach())
        else:
            print("Warning: mask {} not found.".format(mask_file_name))
            print("Harmonization/Editing only performed without mask.")

    elif opt.train_mode == "animation":
        print("Generating GIFs...")
        for _start_scale in range(3):
            for _beta in range(80, 100, 5):
                functions.generate_gif(dir2save,
                                       netG,
                                       fixed_noise,
                                       reals,
                                       noise_amp,
                                       opt,
                                       alpha=0.1,
Example no. 6
def train_single_scale(netD, netG, reals, fixed_noise, noise_amp, opt, depth, writer):
    reals_shapes = [real.shape for real in reals]
    real = reals[depth]

    alpha = opt.alpha

    ############################
    # define z_opt for training on reconstruction
    ###########################
    if depth == 0:
        if opt.train_mode == "generation" or opt.train_mode == "retarget":
            z_opt = reals[0]
        elif opt.train_mode == "animation":
            z_opt = functions.generate_noise([opt.nc_im, reals_shapes[depth][2], reals_shapes[depth][3]],
                                             device=opt.device).detach()
    else:
        if opt.train_mode == "generation" or opt.train_mode == "animation":
            z_opt = functions.generate_noise([opt.nfc,
                                              reals_shapes[depth][2]+opt.num_layer*2,
                                              reals_shapes[depth][3]+opt.num_layer*2],
                                              device=opt.device)
        else:
            z_opt = functions.generate_noise([opt.nfc, reals_shapes[depth][2], reals_shapes[depth][3]],
                                              device=opt.device).detach()
    fixed_noise.append(z_opt.detach())

    ############################
    # define optimizers, learning rate schedulers, and learning rates for lower stages
    ###########################
    # setup optimizers for D
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999))

    # setup optimizers for G
    # remove gradients from stages that are not trained
    for block in netG.body[:-opt.train_depth]:
        for param in block.parameters():
            param.requires_grad = False

    # set different learning rate for lower stages
    parameter_list = [{"params": block.parameters(), "lr": opt.lr_g * (opt.lr_scale**(len(netG.body[-opt.train_depth:])-1-idx))}
               for idx, block in enumerate(netG.body[-opt.train_depth:])]

    # add parameters of head and tail to training
    if depth - opt.train_depth < 0:
        parameter_list += [{"params": netG.head.parameters(), "lr": opt.lr_g * (opt.lr_scale**depth)}]
    parameter_list += [{"params": netG.tail.parameters(), "lr": opt.lr_g}]
    optimizerG = optim.Adam(parameter_list, lr=opt.lr_g, betas=(opt.beta1, 0.999))

    # define learning rate schedules
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD, milestones=[0.8*opt.niter], gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG, milestones=[0.8*opt.niter], gamma=opt.gamma)

    ############################
    # calculate noise_amp
    ###########################
    if depth == 0:
        noise_amp.append(1)
    else:
        noise_amp.append(0)
        z_reconstruction = netG(fixed_noise, reals_shapes, noise_amp)

        criterion = nn.MSELoss()
        rec_loss = criterion(z_reconstruction, real)

        RMSE = torch.sqrt(rec_loss).detach()
        _noise_amp = opt.noise_amp_init * RMSE
        noise_amp[-1] = _noise_amp

    # start training
    _iter = tqdm(range(opt.niter))
    for iter in _iter:
        _iter.set_description('stage [{}/{}]:'.format(depth, opt.stop_scale))

        ############################
        # (0) sample noise for unconditional generation
        ###########################
        noise = functions.sample_random_noise(depth, reals_shapes, opt)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            netD.zero_grad()
            output = netD(real)
            errD_real = -output.mean()

            # train with fake
            if j == opt.Dsteps - 1:
                fake = netG(noise, reals_shapes, noise_amp)
            else:
                with torch.no_grad():
                    fake = netG(noise, reals_shapes, noise_amp)

            output = netD(fake.detach())
            errD_fake = output.mean()

            gradient_penalty = functions.calc_gradient_penalty(netD, real, fake, opt.lambda_grad, opt.device)
            errD_total = errD_real + errD_fake + gradient_penalty
            errD_total.backward()
            optimizerD.step()

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################
        output = netD(fake)
        errG = -output.mean()

        if alpha != 0:
            loss = nn.MSELoss()
            rec = netG(fixed_noise, reals_shapes, noise_amp)
            rec_loss = alpha * loss(rec, real)
        else:
            rec_loss = 0

        netG.zero_grad()
        errG_total = errG + rec_loss
        errG_total.backward()

        for _ in range(opt.Gsteps):
            optimizerG.step()

        ############################
        # (3) Log Results
        ###########################
        if iter % 250 == 0 or iter+1 == opt.niter:
            writer.add_scalar('Loss/train/D/real/{}'.format(j), -errD_real.item(), iter+1)
            writer.add_scalar('Loss/train/D/fake/{}'.format(j), errD_fake.item(), iter+1)
            writer.add_scalar('Loss/train/D/gradient_penalty/{}'.format(j), gradient_penalty.item(), iter+1)
            writer.add_scalar('Loss/train/G/gen', errG.item(), iter+1)
            writer.add_scalar('Loss/train/G/reconstruction', rec_loss.item(), iter+1)
        if iter % 500 == 0 or iter+1 == opt.niter:
            functions.save_image('{}/fake_sample_{}.jpg'.format(opt.outf, iter+1), fake.detach())
            functions.save_image('{}/reconstruction_{}.jpg'.format(opt.outf, iter+1), rec.detach())
            generate_samples(netG, opt, depth, noise_amp, writer, reals, iter+1)

        schedulerD.step()
        schedulerG.step()
        # break

    functions.save_networks(netG, netD, z_opt, opt)
    return fixed_noise, noise_amp, netG, netD
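Note: for depth > 0 the injected noise amplitude is tied to how well the already-trained stages reconstruct the image from the fixed noise. Isolated as a helper (hypothetical name; opt.noise_amp_init supplies the base factor in the code above):

    import torch
    import torch.nn as nn

    def stage_noise_amp(reconstruction, real, noise_amp_init):
        # RMSE of the fixed-noise reconstruction scales the new stage's noise
        rmse = torch.sqrt(nn.MSELoss()(reconstruction, real)).detach()
        return noise_amp_init * rmse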
Example no. 7
def generate_samples(netG,
                     img_to_augment,
                     naive_img,
                     naive_img_large,
                     aug,
                     opt,
                     depth,
                     noise_amp,
                     writer,
                     reals,
                     iter,
                     n=16):
    opt.out_ = functions.generate_dir2save(opt)
    dir2save = '{}/harmonized_samples_stage_{}'.format(opt.out_, depth)
    reals_shapes = [r.shape for r in reals]
    _name = "harmonized" if opt.train_mode == "harmonization" else "edited"
    images = []
    try:
        os.makedirs(dir2save)
    except OSError:
        pass

    if naive_img is not None:
        n = n - 1
    if opt.fine_tune:
        n = 1
    with torch.no_grad():
        for idx in range(n):
            noise = []
            for d in range(depth + 1):
                if d == 0:
                    if opt.fine_tune:
                        if opt.train_mode == "harmonization":
                            augmented_image = functions.np2torch(
                                naive_img, opt)
                            noise.append(augmented_image)
                        elif opt.train_mode == "editing":
                            augmented_image = functions.np2torch(
                                naive_img, opt)
                            noise.append(augmented_image + opt.noise_scaling *
                                         functions.generate_noise(
                                             [
                                                 opt.nc_im, reals_shapes[d][2],
                                                 reals_shapes[d][3]
                                             ],
                                             device=opt.device).detach())
                    else:
                        if opt.train_mode == "harmonization":
                            data = {"image": img_to_augment}
                            augmented = aug.transform(**data)
                            augmented_image = functions.np2torch(
                                augmented["image"], opt)
                            noise.append(augmented_image)
                        elif opt.train_mode == "editing":
                            image = functions.shuffle_grid(img_to_augment)
                            augmented_image = functions.np2torch(image, opt)
                            noise.append(augmented_image + opt.noise_scaling *
                                         functions.generate_noise(
                                             [
                                                 opt.nc_im, reals_shapes[d][2],
                                                 reals_shapes[d][3]
                                             ],
                                             device=opt.device).detach())
                else:
                    noise.append(
                        functions.generate_noise(
                            [opt.nfc, reals_shapes[d][2], reals_shapes[d][3]],
                            device=opt.device).detach())
            sample = netG(noise, reals_shapes, noise_amp)
            functions.save_image(
                '{}/{}_naive_sample.jpg'.format(dir2save, idx),
                augmented_image)
            functions.save_image(
                '{}/{}_{}_sample.jpg'.format(dir2save, idx, _name),
                sample.detach())
            augmented_image = imresize_to_shape(augmented_image,
                                                sample.shape[2:], opt)
            images.append(augmented_image)
            images.append(sample.detach())

        if opt.fine_tune:
            mask_file_name = '{}_mask{}'.format(opt.naive_img[:-4],
                                                opt.naive_img[-4:])
            augmented_image = imresize_to_shape(naive_img_large,
                                                sample.shape[2:], opt)
            if os.path.exists(mask_file_name):
                mask = get_mask(mask_file_name, augmented_image, opt)
                sample_w_mask = (
                    1 - mask) * augmented_image + mask * sample.detach()
                functions.save_image(
                    '{}/{}_sample_w_mask_{}.jpg'.format(dir2save, _name, iter),
                    sample_w_mask.detach())
                images = torch.cat(
                    [augmented_image,
                     sample.detach(), sample_w_mask], 0)
                grid = make_grid(images, nrow=3, normalize=True)
                writer.add_image('{}_images_{}'.format(_name, depth), grid,
                                 iter)
            else:
                print(
                    "Warning: no mask with name {} exists for image {}".format(
                        mask_file_name, opt.input_name))
                print("Only showing results without mask.")
                images = torch.cat([augmented_image, sample.detach()], 0)
                grid = make_grid(images, nrow=2, normalize=True)
                writer.add_image('{}_images_{}'.format(_name, depth), grid,
                                 iter)
            functions.save_image(
                '{}/{}_sample_{}.jpg'.format(dir2save, _name, iter),
                sample.detach())
        else:
            if naive_img is not None:
                noise = []
                for d in range(depth + 1):
                    if d == 0:
                        if opt.train_mode == "harmonization":
                            noise.append(functions.np2torch(naive_img, opt))
                        elif opt.train_mode == "editing":
                            noise.append(functions.np2torch(naive_img, opt) + opt.noise_scaling * \
                                              functions.generate_noise([opt.nc_im, reals_shapes[d][2],
                                                                        reals_shapes[d][3]],
                                                                        device=opt.device).detach())
                    else:
                        noise.append(
                            functions.generate_noise(
                                [
                                    opt.nfc, reals_shapes[d][2],
                                    reals_shapes[d][3]
                                ],
                                device=opt.device).detach())
                sample = netG(noise, reals_shapes, noise_amp)
                _naive_img = imresize_to_shape(naive_img_large,
                                               sample.shape[2:], opt)
                images.insert(0, sample.detach())
                images.insert(0, _naive_img)
                functions.save_image(
                    '{}/{}_sample_{}.jpg'.format(dir2save, _name, iter),
                    sample.detach())

                mask_file_name = '{}_mask{}'.format(opt.naive_img[:-4],
                                                    opt.naive_img[-4:])
                if os.path.exists(mask_file_name):
                    mask = get_mask(mask_file_name, _naive_img, opt)
                    sample_w_mask = (
                        1 - mask) * _naive_img + mask * sample.detach()
                    functions.save_image(
                        '{}/{}_sample_w_mask_{}.jpg'.format(
                            dir2save, _name, iter), sample_w_mask)

            images = torch.cat(images, 0)
            grid = make_grid(images, nrow=4, normalize=True)
            writer.add_image('{}_images_{}'.format(_name, depth), grid, iter)
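Note: the masked results above all use the same compositing rule: keep the naive image outside the mask and the generated sample inside it. As a standalone helper (hypothetical name; mask values assumed in [0, 1]):

    def composite(naive_img, sample, mask):
        # naive_img, sample: (1, C, H, W) tensors; mask broadcasts over them
        return (1 - mask) * naive_img + mask * sample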
Example no. 8
def train(opt):
    print("Training model with the following parameters:")
    print("\t number of stages: {}".format(opt.train_stages))
    print("\t number of concurrently trained stages: {}".format(
        opt.train_depth))
    print("\t learning rate scaling: {}".format(opt.lr_scale))
    print("\t non-linearity: {}".format(opt.activation))

    real = functions.read_image(opt)
    real = functions.adjust_scales2image(real, opt)
    reals = functions.create_reals_pyramid(real, opt)
    print("Training on image pyramid: {}".format([r.shape for r in reals]))
    print("")

    if opt.naive_img != "":
        naive_img = functions.read_image_dir(opt.naive_img, opt)
        naive_img_large = imresize_to_shape(naive_img, reals[-1].shape[2:],
                                            opt)
        naive_img = imresize_to_shape(naive_img, reals[0].shape[2:], opt)
        naive_img = functions.convert_image_np(naive_img) * 255.0
    else:
        naive_img = None
        naive_img_large = None

    if opt.fine_tune:
        img_to_augment = naive_img
    else:
        img_to_augment = functions.convert_image_np(reals[0]) * 255.0

    if opt.train_mode == "editing":
        opt.noise_scaling = 0.1

    generator = init_G(opt)
    if opt.fine_tune:
        for _ in range(opt.train_stages - 1):
            generator.init_next_stage()
        generator.load_state_dict(
            torch.load(
                '{}/{}/netG.pth'.format(opt.model_dir, opt.train_stages - 1),
                map_location="cuda:{}".format(torch.cuda.current_device())))

    fixed_noise = []
    noise_amp = []

    for scale_num in range(opt.start_scale, opt.train_stages):
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError as e:
            print(e)
        functions.save_image('{}/real_scale.jpg'.format(opt.outf),
                             reals[scale_num])

        d_curr = init_D(opt)
        if opt.fine_tune:
            d_curr.load_state_dict(
                torch.load('{}/{}/netD.pth'.format(opt.model_dir,
                                                   opt.train_stages - 1),
                           map_location="cuda:{}".format(
                               torch.cuda.current_device())))
        elif scale_num > 0:
            d_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))
            generator.init_next_stage()

        writer = SummaryWriter(log_dir=opt.outf)
        fixed_noise, noise_amp, generator, d_curr = train_single_scale(
            d_curr, generator, reals, img_to_augment, naive_img,
            naive_img_large, fixed_noise, noise_amp, opt, scale_num, writer)

        torch.save(fixed_noise, '%s/fixed_noise.pth' % (opt.out_))
        torch.save(generator, '%s/G.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(noise_amp, '%s/noise_amp.pth' % (opt.out_))
        del d_curr
    writer.close()
    return
Example no. 9
def train_single_scale(netD, netG, reals, fixed_noise, noise_amp, netD2, netG2, reals2, fixed_noise2, noise_amp2, opt, depth, writer):
    reals_shapes = [real.shape for real in reals]
    real = reals[depth]
    reals_shapes2 = [real2.shape for real2 in reals2]
    real2 = reals2[depth]

    # alpha = opt.alpha
    lambda_idt = opt.lambda_idt
    lambda_cyc = opt.lambda_cyc
    lambda_tv = opt.lambda_tv

    ############################
    # define z_opt for training on reconstruction
    ###########################
    if depth == 0:
        if opt.train_mode == "generation" or opt.train_mode == "retarget":
            z_opt = reals[0]
            z_opt2 = reals2[0]
        elif opt.train_mode == "animation":
            z_opt = functions.generate_noise([opt.nc_im, reals_shapes[depth][2], reals_shapes[depth][3]],
                                             device=opt.device).detach()
            z_opt2 = functions.generate_noise([opt.nc_im, reals_shapes2[depth][2], reals_shapes2[depth][3]],
                                             device=opt.device).detach()
    else:
        if opt.train_mode == "generation" or opt.train_mode == "animation":
            z_opt0 = functions.generate_noise([opt.nfc,
                                              reals_shapes[depth][2]+opt.num_layer*2,
                                              reals_shapes[depth][3]+opt.num_layer*2],
                                              device=opt.device)

            fixed_noise.append(z_opt0.detach())
            # fakes_shapes = [fake.shape for fake in fixed_noise]
            noise_amp_f = [0.1] * 15
            z_opt = netG(reals, fixed_noise, reals_shapes, noise_amp_f)
            fixed_noise = fixed_noise[: -1]

            z_opt02 = functions.generate_noise([opt.nfc,
                                              reals_shapes2[depth][2]+opt.num_layer*2,
                                              reals_shapes2[depth][3]+opt.num_layer*2],
                                              device=opt.device)

            # fixed_noise2.append(z_opt02.detach())
            # fakes_shapes2 = [fake2.shape for fake2 in fixed_noise2]
            z_opt2 = netG2(reals[1:], z_opt, noise_amp_f)
            fixed_noise2 = fixed_noise2[: -1]
            # criterion = nn.MSELoss()
            # rec_loss = criterion(z_opt1, real)
            #
            # RMSE = torch.sqrt(rec_loss).detach()
            # _noise_amp = opt.noise_amp_init * RMSE
            # noise_amp_f[-1] = _noise_amp
            # fixed_noise.pop()
        else:
            z_opt = functions.generate_noise([opt.nfc, reals_shapes[depth][2], reals_shapes[depth][3]],
                                              device=opt.device).detach()
            # not updated yet
    fixed_noise.append(z_opt.detach())
    fixed_noise2.append(z_opt2.detach())

    ############################
    # define optimizers, learning rate schedulers, and learning rates for lower stages
    ###########################
    # setup optimizers for D
    optimizerD = optim.Adam(itertools.chain(netD.parameters(),netD2.parameters()), lr=opt.lr_d, betas=(opt.beta1, 0.999))

    # setup optimizers for G
    # remove gradients from stages that are not trained
    for block in netG.body[:-opt.train_depth]:
        for param in block.parameters():
            param.requires_grad = False

    # set different learning rate for lower stages
    parameter_list = [{"params": block.parameters(), "lr": opt.lr_g * (opt.lr_scale**(len(netG.body[-opt.train_depth:])-1-idx))}
               for idx, block in enumerate(netG.body[-opt.train_depth:])]

    # add parameters of head and tail to training
    if depth - opt.train_depth < 0:
        parameter_list += [{"params": netG.head.parameters(), "lr": opt.lr_g * (opt.lr_scale**depth)}]
    parameter_list += [{"params": netG.tail.parameters(), "lr": opt.lr_g}]

    for block in netG2.body[:-opt.train_depth]:
        for param in block.parameters():
            param.requires_grad = False

    # set different learning rate for lower stages
    parameter_list2 = [{"params": block.parameters(), "lr": opt.lr_g * (opt.lr_scale**(len(netG2.body[-opt.train_depth:])-1-idx))}
               for idx, block in enumerate(netG2.body[-opt.train_depth:])]

    # add parameters of head and tail to training
    if depth - opt.train_depth < 0:
        parameter_list2 += [{"params": netG2.head.parameters(), "lr": opt.lr_g * (opt.lr_scale**depth)}]
    parameter_list2 += [{"params": netG2.tail.parameters(), "lr": opt.lr_g}]

    optimizerG = optim.Adam(itertools.chain(parameter_list, parameter_list2), lr=opt.lr_g, betas=(opt.beta1, 0.999))

    # define learning rate schedules
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD, milestones=[0.8*opt.niter], gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG, milestones=[0.8*opt.niter], gamma=opt.gamma)

    ############################
    # calculate noise_amp
    ###########################
    if depth == 0:
        noise_amp.append(1)
    else:
        noise_amp.append(0)
        z_reconstruction = netG(fixed_noise, reals_shapes, noise_amp)

        criterion = nn.MSELoss()
        rec_loss = criterion(z_reconstruction, real)

        RMSE = torch.sqrt(rec_loss).detach()
        _noise_amp = opt.noise_amp_init * RMSE
        noise_amp[-1] = _noise_amp

    # start training
    _iter = tqdm(range(opt.niter))
    for iter in _iter:
        _iter.set_description('stage [{}/{}]:'.format(depth, opt.stop_scale))

        ############################
        # (0) sample noise for unconditional generation
        ###########################
        noise = functions.sample_random_noise(depth, reals_shapes, opt)
        noise2 = functions.sample_random_noise(depth, reals_shapes, opt)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            # netD.zero_grad()
            optimizerD.zero_grad()

            output = netD(real2)
            errD_real = -output.mean()
            output2 = netD2(real)  # netD2 (not netD) scores domain-1 reals, cf. gradient_penalty2 below
            errD_real2 = -output2.mean()

            # train with fake
            if j == opt.Dsteps - 1:
                fake = netG(reals, reals_shapes, noise_amp, add_noise=True)  # noise + real image
                # fake2, _ = netG(noise2, reals_shapes2, noise_amp2)
            else:
                with torch.no_grad():
                    fake = netG(reals, reals_shapes, noise_amp, add_noise=True)
                    # fake2, _ = netG(noise2, reals_shapes2, noise_amp2)

            output = netD(fake.detach())
            errD_fake = output.mean()
            gradient_penalty = functions.calc_gradient_penalty(netD, real2, fake, opt.lambda_grad, opt.device)

            if j == opt.Dsteps - 1:
                fake2 = netG2(reals2, reals_shapes2, noise_amp2, add_noise=True)
            else:
                with torch.no_grad():
                    fake2 = netG2(reals2, reals_shapes2, noise_amp2, add_noise=True)

            output2 = netD2(fake2.detach())
            errD_fake2 = output2.mean()
            gradient_penalty2 = functions.calc_gradient_penalty(netD2, real, fake2, opt.lambda_grad, opt.device)


            errD_total = errD_real + errD_fake + gradient_penalty + errD_real2 + errD_fake2 + gradient_penalty2
            errD_total.backward()
            optimizerD.step()

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################
        optimizerG.zero_grad()
        loss_tv = TVLoss()

        output = netD(fake)
        errG = -output.mean() + lambda_tv * loss_tv(fake)

        output2 = netD2(fake2)
        errG2 = -output2.mean() + lambda_tv * loss_tv(fake2)

        loss = nn.L1Loss()  # nn.MSELoss()

        # identity and cycle terms; detaching them before backward() would
        # strip their gradients, so keep them attached until after the update
        rec = netG(real2, reals_shapes2, noise_amp2)
        rec_loss = lambda_idt * loss(rec, real2)

        cyc = netG(fake2, reals_shapes2, noise_amp2)
        cyc_loss = lambda_cyc * loss(cyc, real2)

        rec2 = netG2(real, reals_shapes, noise_amp)
        rec_loss2 = lambda_idt * loss(rec2, real)

        cyc2 = netG2(fake, reals_shapes, noise_amp)
        cyc_loss2 = lambda_cyc * loss(cyc2, real)

        errG_total = errG + errG2 + rec_loss + rec_loss2 + cyc_loss + cyc_loss2
        errG_total.backward()

        for _ in range(opt.Gsteps):
            optimizerG.step()

        ############################
        # (3) Log Results
        ###########################
        if iter % 250 == 0 or iter+1 == opt.niter:
            writer.add_scalar('Loss/train/D/real/{}'.format(j), -errD_real.item(), iter+1)
            writer.add_scalar('Loss/train/D/fake/{}'.format(j), errD_fake.item(), iter+1)
            writer.add_scalar('Loss/train/D/gradient_penalty/{}'.format(j), gradient_penalty.item(), iter+1)
            writer.add_scalar('Loss/train/D/real2/{}'.format(j), -errD_real2.item(), iter+1)
            writer.add_scalar('Loss/train/D/fake2/{}'.format(j), errD_fake2.item(), iter+1)
            writer.add_scalar('Loss/train/D/gradient_penalty2/{}'.format(j), gradient_penalty2.item(), iter+1)

            writer.add_scalar('Loss/train/G/gen', errG.item(), iter+1)
            writer.add_scalar('Loss/train/G/reconstruction', rec_loss.item(), iter+1)
            writer.add_scalar('Loss/train/G/cycle', cyc_loss.item(), iter+1)
            writer.add_scalar('Loss/train/G/gen2', errG2.item(), iter+1)
            writer.add_scalar('Loss/train/G/reconstruction2', rec_loss2.item(), iter+1)
            writer.add_scalar('Loss/train/G/cycle2', cyc_loss2.item(), iter+1)

        if iter % 500 == 0 or iter+1 == opt.niter:
            functions.save_image('{}/fake_sample_{}.jpg'.format(opt.outf, iter+1), fake.detach())
            functions.save_image('{}/reconstruction_{}.jpg'.format(opt.outf, iter+1), rec.detach())
            generate_samples(netG, opt, depth, noise_amp, writer, reals, iter+1)

        schedulerD.step()
        schedulerG.step()
        # break

    functions.save_networks(netG, netD, z_opt, netG2, netD2, z_opt2, opt)
    return fixed_noise, noise_amp, netG, netD, fixed_noise2, noise_amp2, netG2, netD2
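Note: Example no. 9 assembles a CycleGAN-style generator objective: two adversarial terms plus identity and cycle L1 terms weighted by lambda_idt and lambda_cyc. Collected into one hypothetical helper for clarity:

    import torch.nn as nn

    def generator_objective(errG, errG2, rec, rec2, cyc, cyc2,
                            real, real2, lambda_idt=1.0, lambda_cyc=1.0):
        l1 = nn.L1Loss()
        return (errG + errG2
                + lambda_idt * (l1(rec, real2) + l1(rec2, real))
                + lambda_cyc * (l1(cyc, real2) + l1(cyc2, real)))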