예제 #1
0
        # Cycle-consistency for domain A: A -> B -> A should reconstruct real_A.
        loss_cyc_A = cycle_consistency_loss(cycle_A, real_A)

        # Second cycle direction: fake_A mapped back toward domain B.
        cycle_B = netG_A2B(fake_A)
        loss_cyc_B = cycle_consistency_loss(cycle_B, real_B)

        # full objective: both adversarial losses plus the cycle losses
        # weighted by opt.lambda1
        loss_G = loss_G_A2B + loss_G_B2A + opt.lambda1 * (loss_cyc_A +
                                                          loss_cyc_B)
        loss_G.backward()

        # Update both generators from the single combined backward pass.
        optimizerG_A2B.step()
        optimizerG_B2A.step()

        ### Discriminator updates ###
        # D_A: distinguish real domain-A images from generated ones.
        netD_A.zero_grad()

        real_out = netD_A(real_A)
        loss_DA_real = gan_loss(real_out, real_label)
        # presumably an image-history buffer (CycleGAN stabilization trick)
        # that mixes current and past fakes — confirm fake_buffer_A's impl.
        fake_A = fake_buffer_A.refresh(fake_A)
        # detach(): keep generator parameters out of the discriminator backward.
        fake_out = netD_A(fake_A.detach())
        loss_DA_fake = gan_loss(fake_out, fake_label)
        # 0.5 factor halves D's loss relative to G (common CycleGAN convention).
        loss_DA = (loss_DA_real + loss_DA_fake) * 0.5
        loss_DA.backward()

        optimizerD_A.step()

        # D_B: same real/fake procedure for domain B.
        netD_B.zero_grad()

        real_out = netD_B(real_B)
예제 #2
0
# Outer training loop: one iteration per "epoch" (fixed number of steps,
# not a full pass over the dataset — see args.epoch_length below).
for epoch in range(args.num_epochs):
    start_time = timer()

    # Variables for recording per-epoch statistics.
    average_discriminator_real_performance = 0.0
    average_discriminator_generated_performance = 0.0
    average_discriminator_loss = 0.0
    average_generator_loss = 0.0

    # Train: perform 'args.epoch_length' mini-batch updates per "epoch".
    for i in range(args.epoch_length):
        total_training_steps += 1

        # Train the discriminator:
        discriminator_model.zero_grad()

        # Evaluate a mini-batch of real images.
        # NOTE(review): np.random.choice samples WITH replacement here;
        # assumes `images` is an indexable array (e.g. numpy) — confirm
        # against the data-loading code.
        random_indexes = np.random.choice(len(images), args.mini_batch_size)
        real_images = torch.tensor(images[random_indexes], device=DEVICE)

        real_predictions = discriminator_model(real_images)

        # Evaluate a mini-batch of generated images.
        # Latent vectors shaped (N, 512, 1, 1) — presumably the generator's
        # first layer is convolutional; verify against its definition.
        random_latent_space_vectors = torch.randn(args.mini_batch_size,
                                                  512,
                                                  1,
                                                  1,
                                                  device=DEVICE)
        generated_images = generator_model(random_latent_space_vectors)
예제 #3
0
            train_dec = False
        # If neither decoder nor discriminator is scheduled for this step,
        # re-enable both so training never stalls entirely.
        if train_dec is False and train_dis is False:
            train_dis = True
            train_dec = True

        # Encoder update (runs every step). retain_graph=True because the
        # decoder/discriminator backward passes below reuse the same graph.
        NetE.zero_grad()
        loss_encoder.backward(retain_graph=True)
        # NOTE(review): 'encorder' is a typo in the optimizer's name, defined
        # elsewhere in the file — rename at its definition site if cleaned up.
        optimizer_encorder.step()

        # Decoder update, gated by the training schedule above.
        if train_dec:
            NetG.zero_grad()
            loss_decoder.backward(retain_graph=True)
            optimizer_decoder.step()

        # Discriminator update — last backward, so no retain_graph needed.
        if train_dis:
            NetD.zero_grad()
            loss_discriminator.backward()
            optimizer_discriminator.step()

        # Per-iteration progress report.
        print(
            '[%d/%d][%d/%d] loss_discriminator: %.4f loss_decoder: %.4f loss_encoder: %.4f D_x: %.4f D_G_z1: %.4f  D_G_z2: %.4f'
            % (epoch, opt.niter, i, len(dataloader), loss_discriminator.item(),
               loss_decoder.item(), loss_encoder.item(), D_x, D_G_z1, D_G_z2))

    # End of epoch: push a fixed batch through encode -> sample -> decode
    # and save the reconstructions. Sampler presumably applies the
    # reparameterization trick to (mu, logvar) — confirm its definition.
    mu, logvar = NetE(fixed_batch)
    sample = Sampler([mu, logvar], device)
    rec_real = NetG(sample)
    vutils.save_image(rec_real,
                      '%s/rec_real_epoch_%03d.png' % (opt.outf, epoch),
                      normalize=True)
    if epoch % 10 == 0:
예제 #4
0
# Main training loop over the texture dataloader.
for epoch in range(opt.niter):
    for i, data in enumerate(dataloader, 0):
        t0 = time.time()
        sys.stdout.flush()
        # Pull one batch of content images from the content dataloader.
        # NOTE(review): next(iter(...)) rebuilds the iterator every step,
        # so this always yields the first batch — confirm intent.
        content = next(iter(cdataloader))[0]
        content = content.to(device)

        content,templatePatch = randCrop(content,templates,opt.imageSize,targetMosaic)
        templatePatch =templatePatch.to(device)  # randCrop returns a fresh float tensor, so move it to the device
        if opt.trainOverfit:
            content = content.to(device)

        # One-time shape diagnostic on the very first step.
        if epoch==0 and i==0:
            print ("template size",templatePatch.shape)
        # train with real
        netD.zero_grad()
        text, _ = data
        batch_size = content.size(0)  # NOTE(review): texture/content batches of different sizes would need trimming
        text=text.to(device) 
        output = netD(text)  # forward pass also determines the label tensor's shape below
        # output.detach()*0 + real_label builds a real-label tensor matching
        # the discriminator's output shape (handles patch-wise outputs).
        errD_real = criterion(output, output.detach()*0+real_label)
        errD_real.backward()
        D_x = output.mean()

        # train with fake
        noise=setNoise(noise)
        fake, alpha, A, mixedI = famosGeneration(content, noise, templatePatch, True)
        output = netD(fake.detach())  # detach(): keep the generator out of the discriminator's backward pass
        errD_fake = criterion(output, output.detach()*0+fake_label)
        errD_fake.backward()
예제 #5
0
class GAN:
    """DCGAN-style trainer pairing a Generator and a Discriminator.

    Reads all hyperparameters from the module-level ``Params`` object and
    uses the project-defined ``Generator`` / ``Discriminator`` modules.
    Call ``load_image_set()`` before ``train()``.
    """

    def __init__(self):
        # Networks, moved to the configured device.
        self.generator = Generator()
        self.generator.to(Params.Device)

        self.discriminator = Discriminator()
        self.discriminator.to(Params.Device)

        # Binary cross-entropy for the real=1 / fake=0 GAN objective.
        self.loss_fn = nn.BCELoss()

        self.optimizer_g = torch.optim.Adam(self.generator.parameters(), Params.LearningRateG, betas=(Params.Beta, 0.999))
        self.optimizer_d = torch.optim.Adam(self.discriminator.parameters(), Params.LearningRateD, betas=(Params.Beta, 0.999))

        # Fixed latent batch so checkpoint preview images are comparable
        # across epochs.
        self.exemplar_latent_vectors = torch.randn( (64, Params.LatentVectorSize), device=Params.Device )


    def load_image_set(self):
        """Build the ImageFolder dataset and its DataLoader on ``self``."""
        transforms = torchvision.transforms.Compose( [
            torchvision.transforms.Resize(Params.ImageSize),
            torchvision.transforms.CenterCrop(Params.ImageSize),
            torchvision.transforms.ToTensor(),
            # Map pixels to [-1, 1] (matches a tanh-output generator).
            torchvision.transforms.Normalize( (0.5,0.5,0.5), (0.5,0.5,0.5) )
        ] )
        self.dataset = torchvision.datasets.ImageFolder(
            Params.ImagePath,
            transforms
        )
        self.loader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size = Params.BatchSize,
            shuffle=True,
            num_workers=Params.NumWorkers
        )


    def train(self, start_epoch=0):
        """Run the training loop from ``start_epoch`` to ``Params.NumEpochs``.

        Saves a checkpoint every ``Params.CheckpointEvery`` epochs and once
        more after the final epoch. Populates ``self.loss_record`` with
        per-batch D and G losses.
        """
        self.generator.train()
        self.discriminator.train()

        # (Fixed: a second, unused nn.BCELoss() was created here — the loop
        # uses self.loss_fn exclusively.)

        # Per-batch loss history, appended to by train_epoch().
        self.loss_record = {
            'D': [],
            'G': []
        }

        for cur_epoch in range(start_epoch, Params.NumEpochs):
            tic = time.perf_counter()

            self.train_epoch(cur_epoch)

            toc = time.perf_counter()
            print( f"Last epoch took: {toc - tic:.0f} seconds" )

            if cur_epoch % Params.CheckpointEvery == 0:
                self.save_checkpoint(f"epoch{cur_epoch+1:03d}")

        self.save_checkpoint(f"Final ({Params.NumEpochs} epochs)")


    def train_epoch(self, cur_epoch):
        """One pass over the loader: alternate D and G updates per batch."""
        for (i, real_images) in enumerate(self.loader):
            # --- Discriminator step: maximize log D(x) + log(1 - D(G(z))) ---
            self.discriminator.zero_grad()
            # ImageFolder yields (images, class_labels); only images are used.
            real_images = real_images[0].to(Params.Device)
            batch_size = real_images.size(0)
            labels = torch.full(
                (batch_size,),
                1,
                dtype=torch.float, device=Params.Device
            )
            D_real = self.discriminator(real_images).view(-1)
            D_loss_real = self.loss_fn(D_real, labels)
            D_loss_real.backward()

            latent_vectors = torch.randn( (batch_size, Params.LatentVectorSize), device=Params.Device )
            fake_images = self.generator(latent_vectors)
            # detach(): keep generator parameters out of D's backward pass.
            D_fake = self.discriminator(fake_images.detach()).view(-1)
            labels.fill_(0)
            D_loss_fake = self.loss_fn(D_fake, labels)
            D_loss_fake.backward()

            self.optimizer_d.step()

            # --- Generator step: non-saturating loss, fakes labeled real ---
            self.generator.zero_grad()
            labels.fill_(1)
            D_fake = self.discriminator(fake_images).view(-1)
            G_loss = self.loss_fn(D_fake, labels)
            G_loss.backward()

            self.optimizer_g.step()

            # Convert to plain floats for recording/reporting.
            D_loss = D_loss_real.item() + D_loss_fake.item()
            G_loss = G_loss.item()

            if (i % Params.ReportEvery == 0) and i>0:
                print( f"Epoch[{cur_epoch}/{Params.NumEpochs}] Batch[{i}/{len(self.loader)}]" )
                print( f"\tD Loss: {D_loss}\tG Loss: {G_loss}\n" )

            self.loss_record['D'].append(D_loss)
            self.loss_record['G'].append(G_loss)


    def save_checkpoint(self, checkpoint_name):
        """Save model/optimizer state and a grid of exemplar images.

        Writes ``model_params.pt`` and ``gen_images.png`` into a directory
        named after ``checkpoint_name`` under ``Params.CheckpointPath``.
        """
        dir_path = os.path.join(Params.CheckpointPath, checkpoint_name)
        os.makedirs(dir_path, exist_ok=True)
        print( f"Saving snapshot to: {dir_path}" )

        save_state = dict()
        save_state['g_state'] = self.generator.state_dict()
        save_state['g_optimizer_state'] = self.optimizer_g.state_dict()
        save_state['d_state'] = self.discriminator.state_dict()
        save_state['d_optimizer_state'] = self.optimizer_d.state_dict()

        torch.save(
            save_state,
            os.path.join(dir_path, 'model_params.pt')
        )

        # Render the fixed exemplar batch; switch to eval() for layers like
        # BatchNorm, then restore training mode.
        with torch.no_grad():
            self.generator.eval()
            gen_images = self.generator(self.exemplar_latent_vectors)
            self.generator.train()
            torchvision.utils.save_image(
                gen_images,
                os.path.join(dir_path, 'gen_images.png'),
                nrow = 8,
                normalize=True
            )
# Put the discriminator in training mode.
D.train()
""" Loss for GAN """
BCE_loss = nn.BCELoss().cuda()
L1_loss = nn.L1Loss().cuda()
""" tensorboard visualize """
writer = SummaryWriter(log_dir=opt.logdir)
""" start training """
print('training start!')

for epoch in range(opt.train_epoch):
    local_iter = 0

    for x, y in train_loader:
        t0 = time.time()
        """ training Discriminator D"""
        D.zero_grad()

        # NOTE(review): torch.autograd.Variable is a no-op wrapper on modern
        # PyTorch; tensors could be used directly here.
        x, y = Variable(x.cuda()), Variable(y.cuda())

        # Conditional discriminator: scores the (input, target) pair.
        D_real_result = D(x, y).squeeze()
        D_real_loss = BCE_loss(
            D_real_result,
            Variable(torch.ones(D_real_result.size()).cuda()))  # log(D(x, y))
        # opt.D_lambda scales the discriminator objective.
        D_real_loss *= opt.D_lambda

        # G(x).detach(): the generator output is a constant for D's update.
        D_fake_result = D(x, G(x).detach()).squeeze()
        D_fake_loss = BCE_loss(
            D_fake_result, Variable(torch.zeros(
                D_fake_result.size()).cuda()))  # -log(1-D(x, G(x)))
        D_fake_loss *= opt.D_lambda