Example #1
File: hidden.py Project: dlshu/RS-GAN-v1
    def __init__(self, configuration: HiDDenConfiguration,
                 device: torch.device, noiser: Noiser, tb_logger):
        """
        :param configuration: Configuration for the net, such as the size of the input image, number of channels in the intermediate layers, etc.
        :param device: torch.device object, CPU or GPU
        :param noiser: Object representing stacked noise layers.
        :param tb_logger: Optional TensorboardX logger object; if specified, enables TensorBoard logging
        """
        super(Hidden, self).__init__()

        self.encoder_decoder = EncoderDecoder(configuration, noiser).to(device)
        self.discriminator = Discriminator(configuration).to(device)

        self.optimizer_enc_dec = torch.optim.Adam(
            self.encoder_decoder.parameters())
        self.optimizer_discrim = torch.optim.Adam(
            self.discriminator.parameters())

        if configuration.use_vgg:
            self.vgg_loss = VGGLoss(3, 1, False)
            self.vgg_loss.to(device)
        else:
            self.vgg_loss = None

        self.config = configuration
        self.device = device

        self.bce_with_logits_loss = nn.BCEWithLogitsLoss().to(device)
        self.mse_loss = nn.MSELoss().to(device)
        self.ce_loss = nn.CrossEntropyLoss().to(device)

        # Labels used for training the discriminator / adversarial loss
        self.cover_label = 1
        self.encoded_label = 0

        self.tb_logger = tb_logger
        if tb_logger is not None:
            encoder_final = self.encoder_decoder.encoder._modules[
                'module'].final_layer
            encoder_final.weight.register_hook(
                tb_logger.grad_hook_by_name('grads/encoder_out'))
            decoder_final = self.encoder_decoder.decoder._modules[
                'module'].linear
            decoder_final.weight.register_hook(
                tb_logger.grad_hook_by_name('grads/decoder_out'))
            discrim_final = self.discriminator._modules['linear']
            discrim_final.weight.register_hook(
                tb_logger.grad_hook_by_name('grads/discrim_out'))
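
The hook registration above is plain PyTorch: `Tensor.register_hook` attaches a callback that receives the gradient of that tensor during `backward()`, which is how the final-layer gradients end up in TensorBoard. A minimal, self-contained sketch of the same pattern (the toy `nn.Linear` and the print-based hook stand in for the real logger):

import torch
import torch.nn as nn

layer = nn.Linear(8, 2)

def grad_hook(name):
    # Mimics tb_logger.grad_hook_by_name: returns a closure that
    # receives the gradient of the hooked tensor on backward().
    def hook(grad):
        print(name, grad.norm().item())
    return hook

layer.weight.register_hook(grad_hook('grads/toy_out'))

loss = layer(torch.randn(4, 8)).sum()
loss.backward()  # triggers the hook with d(loss)/d(weight)
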
Example #2
    def __init__(self, image_dimension, downsample_factor):
        self.image_dimension = image_dimension
        self.downsample_factor = downsample_factor
        self.latent_dimension = (
            image_dimension[0] // self.downsample_factor,
            image_dimension[1] // self.downsample_factor,
            image_dimension[2],
        )
        self.generator = Generator(self.latent_dimension).model
        self.discriminator = Discriminator(self.image_dimension).model
        self.vgg_loss = VGGLoss(self.image_dimension)
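
For concreteness, the latent-shape arithmetic works out as below; the dimensions are hypothetical, chosen only to illustrate the integer division:

image_dimension = (128, 128, 3)  # H, W, C (made-up values)
downsample_factor = 4

latent_dimension = (
    image_dimension[0] // downsample_factor,  # 128 // 4 = 32
    image_dimension[1] // downsample_factor,  # 128 // 4 = 32
    image_dimension[2],                       # channel count unchanged
)
assert latent_dimension == (32, 32, 3)
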
Example #3
def main():
    batch_size = 16
    generator = Generator().cuda()
    discriminator = Discriminator().cuda()
    optimizer_G = optim.Adam(generator.parameters(), lr=1e-4)
    optimizer_D = optim.Adam(discriminator.parameters(), lr=1e-4)
    dataset = FaceData()
    # dataset = TrainData()
    data_loader = DataLoader(dataset,
                             batch_size,
                             shuffle=True,
                             num_workers=0,
                             pin_memory=True,
                             drop_last=True)
    MSE = nn.MSELoss()
    BCE = nn.BCELoss()

    vgg_loss_function = VGGLoss()
    vgg_loss_function.eval()

    discriminator.train()
    generator.train()
    optimizer_G.zero_grad()
    optimizer_D.zero_grad()

    print("Start Training")
    current_epoch = 0
    for epoch in range(current_epoch, 100):
        for step, (img_Input, img_GT) in tqdm(enumerate(data_loader)):

            img_GT = img_GT.cuda()
            img_Input = img_Input.cuda()

            # if epoch < 10:
            #     # SRResnet Initialize Generator update
            #     generator.zero_grad()
            #     img_SR = generator(img_Input)
            #     loss_content = MSE(img_SR, img_GT)
            #     loss_content.backward()
            #     optimizer_G.step()

            #     if step%100 == 0:
            #         print()
            #         print("Loss_content : {}".format(loss_content.item()))
            #     continue

            # Discriminator update
            discriminator.zero_grad()
            D_real = discriminator(img_GT)
            loss_Dreal = 0.1 * BCE(D_real, torch.ones(batch_size, 1).cuda())
            loss_Dreal.backward()
            D_x = D_real.mean().item()

            img_SR = generator(img_Input)
            D_fake = discriminator(img_SR.detach())
            loss_Dfake = 0.1 * BCE(D_fake, torch.zeros(batch_size, 1).cuda())
            loss_Dfake.backward()
            DG_z = D_fake.mean().item()

            loss_D = (loss_Dfake + loss_Dreal)
            optimizer_D.step()

            # Generator update
            generator.zero_grad()
            loss_content = MSE(img_SR, img_GT)
            loss_vgg = vgg_loss_function(img_SR, img_GT)

            # reuse img_SR from the discriminator step; its autograd graph is still intact
            G_fake = discriminator(img_SR)
            loss_Gfake = BCE(G_fake, torch.ones(batch_size, 1).cuda())

            loss_G = loss_content + 0.006 * loss_vgg + 0.001 * loss_Gfake
            loss_G.backward()
            optimizer_G.step()

            if step % 10 == 0:
                print()
                print("fake out : {}".format(DG_z))
                print("real out : {}".format(D_x))
                print("Loss_Dfake :   {}".format(loss_Dfake.item()))
                print("Loss_Dreal :   {}".format(loss_Dreal.item()))
                print("Loss_D :       {}".format(loss_D.item()))
                print("Loss_content : {}".format(loss_content.item()))
                print("Loss_vgg :     {}".format(0.006 * loss_vgg.item()))
                print("Loss_Gfake :   {}".format(0.001 * loss_Gfake.item()))
                print("Loss_G :       {}".format(loss_G.item()))
                print("Loss_Total :   {}".format((loss_G + loss_D).item()))
                # print("Loss_D : {:.4f}".format(loss_D.item()))
                # print("Loss : {:.4f}".format(loss_total.item()))

        with torch.no_grad():
            generator.eval()
            save_image(denorm(img_SR[0].cpu()),
                       "./Result/{0}_SR.png".format(epoch))
            save_image(denorm(img_GT[0].cpu()),
                       "./Result/{0}_GT.png".format(epoch))
            generator.train()
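
A detail worth highlighting in this loop is `img_SR.detach()` in the discriminator step: it blocks gradients of `loss_Dfake` from reaching the generator, while the undetached `img_SR` used in the generator step keeps its graph intact, so no second forward pass is needed. A minimal sketch of the pattern with toy one-layer networks (shapes are made up for illustration):

import torch
import torch.nn as nn

G = nn.Linear(4, 4)  # toy generator
D = nn.Linear(4, 1)  # toy discriminator

fake = G(torch.randn(2, 4))

# Discriminator step: detach so the generator gets no gradient.
D(fake.detach()).mean().backward()
assert G.weight.grad is None       # generator untouched

# Generator step: the same tensor, graph still intact.
D(fake).mean().backward()
assert G.weight.grad is not None   # generator now receives gradients
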
Example #4
class Hidden:
    def __init__(self, configuration: HiDDenConfiguration,
                 device: torch.device, noiser: Noiser, tb_logger):
        """
        :param configuration: Configuration for the net, such as the size of the input image, number of channels in the intermediate layers, etc.
        :param device: torch.device object, CPU or GPU
        :param noiser: Object representing stacked noise layers.
        :param tb_logger: Optional TensorboardX logger object; if specified, enables TensorBoard logging
        """
        super(Hidden, self).__init__()

        self.encoder_decoder = EncoderDecoder(configuration, noiser).to(device)
        self.optimizer_enc_dec = torch.optim.Adam(
            self.encoder_decoder.parameters())

        self.discriminator = Discriminator(configuration).to(device)
        self.optimizer_discrim = torch.optim.Adam(
            self.discriminator.parameters())

        if configuration.use_vgg:
            self.vgg_loss = VGGLoss(3, 1, False)
            self.vgg_loss.to(device)
        else:
            self.vgg_loss = None

        self.config = configuration
        self.device = device

        self.bce_with_logits_loss = nn.BCEWithLogitsLoss()
        self.mse_loss = nn.MSELoss()

        # Labels used for training the discriminator / adversarial loss
        self.cover_label = 1
        self.encoded_label = 0

        self.tb_logger = tb_logger
        if tb_logger is not None:
            encoder_final = self.encoder_decoder.encoder._modules[
                'final_layer']
            encoder_final.weight.register_hook(
                tb_logger.grad_hook_by_name('grads/encoder_out'))
            decoder_final = self.encoder_decoder.decoder._modules['linear']
            decoder_final.weight.register_hook(
                tb_logger.grad_hook_by_name('grads/decoder_out'))
            discrim_final = self.discriminator._modules['linear']
            discrim_final.weight.register_hook(
                tb_logger.grad_hook_by_name('grads/discrim_out'))

    def train_on_batch(self, batch: list):
        """
        Trains the network on a single batch consisting of images and messages
        :param batch: batch of training data, in the form [images, messages]
        :return: dictionary of error metrics from Encoder, Decoder, and Discriminator on the current batch
        """
        images, messages = batch
        batch_size = images.shape[0]
        with torch.enable_grad():
            # ---------------- Train the discriminator -----------------------------
            self.optimizer_discrim.zero_grad()
            # train on cover
            d_target_label_cover = torch.full((batch_size, 1),
                                              self.cover_label,
                                              device=self.device)
            d_on_cover = self.discriminator(images)
            d_loss_on_cover = self.bce_with_logits_loss(
                d_on_cover, d_target_label_cover)
            d_loss_on_cover.backward()

            # train on fake
            encoded_images, noised_images, decoded_messages = self.encoder_decoder(
                images, messages)
            d_target_label_encoded = torch.full((batch_size, 1),
                                                self.encoded_label,
                                                device=self.device)
            d_on_encoded = self.discriminator(encoded_images.detach())
            d_loss_on_encoded = self.bce_with_logits_loss(
                d_on_encoded, d_target_label_encoded)
            d_loss_on_encoded.backward()
            self.optimizer_discrim.step()

            # --------------Train the generator (encoder-decoder) ---------------------
            self.optimizer_enc_dec.zero_grad()
            # target label for encoded images should be 'cover', because we want to fool the discriminator
            g_target_label_encoded = torch.full((batch_size, 1),
                                                self.cover_label,
                                                device=self.device)
            d_on_encoded_for_enc = self.discriminator(encoded_images)
            g_loss_adv = self.bce_with_logits_loss(d_on_encoded_for_enc,
                                                   g_target_label_encoded)

            if self.vgg_loss is None:
                g_loss_enc = self.mse_loss(encoded_images, images)
            else:
                vgg_on_cov = self.vgg_loss(images)
                vgg_on_enc = self.vgg_loss(encoded_images)
                g_loss_enc = self.mse_loss(vgg_on_cov, vgg_on_enc)

            g_loss_dec = self.mse_loss(decoded_messages, messages)


            g_loss = self.config.adversarial_loss * g_loss_adv + self.config.encoder_loss * g_loss_enc \
                     + self.config.decoder_loss * g_loss_dec
            g_loss.backward()
            self.optimizer_enc_dec.step()

        decoded_rounded = decoded_messages.detach().cpu().numpy().round().clip(
            0, 1)
        bitwise_avg_err = np.sum(
            np.abs(decoded_rounded - messages.detach().cpu().numpy())) / (
                batch_size * messages.shape[1])

        losses = {
            'loss           ': g_loss.item(),
            'encoder_mse    ': g_loss_enc.item(),
            'dec_mse        ': g_loss_dec.item(),
            'bitwise-error  ': bitwise_avg_err,
            'adversarial_bce': g_loss_adv.item(),
            'discr_cover_bce': d_loss_on_cover.item(),
            'discr_encod_bce': d_loss_on_encoded.item()
        }
        return losses, (encoded_images, noised_images, decoded_messages)

    def validate_on_batch(self, batch: list):
        """
        Runs validation on a single batch of data consisting of images and messages
        :param batch: batch of validation data, in the form [images, messages]
        :return: dictionary of error metrics from Encoder, Decoder, and Discriminator on the current batch
        """

        # if TensorboardX logging is enabled, save some of the tensors.
        if self.tb_logger is not None:
            encoder_final = self.encoder_decoder.encoder._modules[
                'final_layer']
            self.tb_logger.add_tensor('weights/encoder_out',
                                      encoder_final.weight)
            decoder_final = self.encoder_decoder.decoder._modules['linear']
            self.tb_logger.add_tensor('weights/decoder_out',
                                      decoder_final.weight)
            discrim_final = self.discriminator._modules['linear']
            self.tb_logger.add_tensor('weights/discrim_out',
                                      discrim_final.weight)

        images, messages = batch
        batch_size = images.shape[0]

        with torch.no_grad():
            d_target_label_cover = torch.full((batch_size, 1),
                                              self.cover_label,
                                              device=self.device)
            d_on_cover = self.discriminator(images)
            d_loss_on_cover = self.bce_with_logits_loss(
                d_on_cover, d_target_label_cover)

            encoded_images, noised_images, decoded_messages = self.encoder_decoder(
                images, messages)
            d_target_label_encoded = torch.full((batch_size, 1),
                                                self.encoded_label,
                                                device=self.device)
            d_on_encoded = self.discriminator(encoded_images)
            d_loss_on_encoded = self.bce_with_logits_loss(
                d_on_encoded, d_target_label_encoded)

            g_target_label_encoded = torch.full((batch_size, 1),
                                                self.cover_label,
                                                device=self.device)
            d_on_encoded_for_enc = self.discriminator(encoded_images)
            g_loss_adv = self.bce_with_logits_loss(d_on_encoded_for_enc,
                                                   g_target_label_encoded)

            if self.vgg_loss is None:
                g_loss_enc = self.mse_loss(encoded_images, images)
            else:
                vgg_on_cov = self.vgg_loss(images)
                vgg_on_enc = self.vgg_loss(encoded_images)
                g_loss_enc = self.mse_loss(vgg_on_cov, vgg_on_enc)

            g_loss_dec = self.mse_loss(decoded_messages, messages)
            g_loss = self.config.adversarial_loss * g_loss_adv + self.config.encoder_loss * g_loss_enc \
                     + self.config.decoder_loss * g_loss_dec

        decoded_rounded = decoded_messages.detach().cpu().numpy().round().clip(
            0, 1)
        bitwise_avg_err = np.sum(
            np.abs(decoded_rounded - messages.detach().cpu().numpy())) / (
                batch_size * messages.shape[1])

        losses = {
            'loss           ': g_loss.item(),
            'encoder_mse    ': g_loss_enc.item(),
            'dec_mse        ': g_loss_dec.item(),
            'bitwise-error  ': bitwise_avg_err,
            'adversarial_bce': g_loss_adv.item(),
            'discr_cover_bce': d_loss_on_cover.item(),
            'discr_encod_bce': d_loss_on_encoded.item()
        }
        return losses, (encoded_images, noised_images, decoded_messages)

    def to_string(self):
        return '{}\n{}'.format(str(self.encoder_decoder),
                               str(self.discriminator))
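
The `bitwise_avg_err` computed in both methods is simply the fraction of message bits that flip after rounding and clipping the decoder output. A NumPy-only sketch of the metric on made-up data:

import numpy as np

messages = np.array([[0., 1., 1., 0.],
                     [1., 0., 0., 1.]])
decoded = np.array([[0.1, 0.8, 0.4, 0.2],   # 0.4 rounds to 0: one bit flips
                    [0.9, 0.2, 0.1, 0.7]])

decoded_rounded = decoded.round().clip(0, 1)
bitwise_avg_err = np.sum(np.abs(decoded_rounded - messages)) / messages.size
assert bitwise_avg_err == 1 / 8  # 1 wrong bit out of 8
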
Example #5
File: dalle.py Project: deepglugs/dalle
def train_vae(path, vocab, args):

    vae = get_vae(args)

    tforms = transforms.Compose([
        transforms.Resize((args.size, args.size)),
        transforms.RandomHorizontalFlip(),
        #transforms.RandomVerticalFlip(),
        #transforms.RandomRotation((-180, 180)),
        # transforms.ColorJitter(0.15, 0.15, 0.15, 0.15),
        transforms.ToTensor(),
        transforms.Normalize((0.5, ) * 3, (0.5, ) * 3)
    ])


    generator = ImageFolder(args.source, tforms)

    dl = torch.utils.data.DataLoader(generator,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=8)

    optimizer = Adam(vae.parameters(), lr=args.lr)

    vgg_loss = None

    if args.vgg_loss:
        vgg_loss = VGGLoss(device=args.device)

    rate = deque([1], maxlen=5)
    disp_size = 4
    step = 0

    if args.tempsched:
        vae.temperature = args.temperature
        dk = 0.7**(1 / len(generator))
        print('Scale Factor:', dk)

    for epoch in range(1, args.epochs):
        for i, data in enumerate(dl):
            step += 1

            t1 = time.monotonic()

            images, labels = data

            images = images.to(args.device)
            labels = labels.to(args.device)

            recons = vae(images)
            loss = loss_fn(images, recons)

            if vgg_loss is not None:
                loss += vgg_loss(images, recons)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            t2 = time.monotonic()
            rate.append(round(1.0 / (t2 - t1), 2))

            if step % 100 == 0:
                print("epoch {}/{} step {} loss: {} - {}it/s".format(
                    epoch, args.epochs, step,
                    round(loss.item() / len(images), 6), round(np.mean(rate)),
                    1))

            if step % 1000 == 0:
                with torch.no_grad():
                    codes = vae.get_codebook_indices(images[:disp_size])
                    imgx = vae.decode(codes)

                grid = torch.cat(
                    [images[:disp_size], recons[:disp_size], imgx])
                grid = make_grid(grid,
                                 nrow=disp_size,
                                 normalize=True,
                                 range=(-1, 1))
                VTF.to_pil_image(grid).save(
                    os.path.join(args.samples_out,
                                 f"vae_{epoch}_{int(step / epoch)}.png"))
                print("saving checkpoint...")
                torch.save(vae.cpu().state_dict(), args.vae)
                vae.to(args.device)
                print("saving complete")

        if args.tempsched:
            vae.temperature *= dk
            print("Current temperature: ", vae.temperature)

    torch.save(vae.cpu().state_dict(), args.vae)
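
The temperature schedule deserves a note: `dk = 0.7 ** (1 / len(generator))` applied once per epoch means the temperature decays by a factor of 0.7 every `len(generator)` epochs. A quick standalone check (dataset length and starting temperature are made-up numbers):

n = 5000           # hypothetical len(generator)
temperature = 1.0  # hypothetical starting temperature
dk = 0.7 ** (1 / n)

for epoch in range(n):  # apply the per-epoch decay n times
    temperature *= dk

print(round(temperature, 6))  # ~0.7 after n epochs
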
Example #6
File: dalle.py Project: deepglugs/dalle
def train_dalle(vae, args):

    if args.vocab is None:
        args.vocab = args.source
    else:
        assert os.path.isfile(args.vocab)

    if args.tags_source is None:
        args.tags_source = args.source

    imgs = get_images(args.source)
    txts = get_images(args.tags_source, exts=".txt")
    vocab = get_vocab(args.vocab, top=args.vocab_limit)

    tforms = transforms.Compose([
        transforms.Resize((args.size, args.size)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, ) * 3, (0.5, ) * 3)
    ])

    def txt_xforms(txt):
        # print(f"txt: {txt}")
        txt = txt.split(", ")
        if args.shuffle_tags:
            np.random.shuffle(txt)
        txt = tokenize(txt, vocab, offset=1)

        return txt

    data = ImageLabelDataset(imgs,
                             txts,
                             vocab,
                             dim=(args.size, args.size),
                             transform=tforms,
                             channels_first=True,
                             return_raw_txt=True)

    dl = torch.utils.data.DataLoader(data,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=0)

    dalle = get_dalle(vae, vocab, args)

    optimizer = Adam(dalle.parameters(), lr=args.lr)
    if args.vgg_loss:
        vgg_loss = VGGLoss(device=args.device)

    disp_size = 4 if args.batch_size > 4 else args.batch_size

    amp_scaler = GradScaler(enabled=args.fp16)

    for epoch in range(1, args.epochs + 1):

        batch_idx = 0
        train_loss = 0

        for image, labels in dl:
            text_ = []
            for label in labels:
                text_.append(txt_xforms(label))
            text = torch.LongTensor(text_).to(args.device)
            image = image.to(args.device)

            mask = torch.ones_like(text).bool().to(args.device)

            with autocast(enabled=args.fp16):
                loss = dalle(text, image, mask=mask, return_loss=True)

            train_loss += loss.item()

            optimizer.zero_grad()
            amp_scaler.scale(loss).backward()

            amp_scaler.step(optimizer)
            amp_scaler.update()

            if batch_idx % 100 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(image), len(data),
                    100. * batch_idx / int(round(len(data) / args.batch_size)),
                    loss.item() / len(image)))

                # sample generations and checkpoint on the same interval
                oimgs = dalle.generate_images(text, mask=mask)
                grid = oimgs[:disp_size]
                grid = make_grid(grid,
                                 nrow=disp_size,
                                 normalize=True,
                                 range=(-1, 1))
                VTF.to_pil_image(grid).save(
                    os.path.join(args.samples_out,
                                 f"dalle_{epoch}_{int(batch_idx)}.png"))
                torch.save(dalle.cpu().state_dict(), args.dalle)
                dalle.to(args.device)

            batch_idx += 1

        print('====> Epoch: {} Average loss: {:.4f}'.format(
            epoch, train_loss / len(data)))

    torch.save(dalle.cpu().state_dict(), args.dalle)
    dalle.to(args.device)
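
The `autocast`/`GradScaler` pairing here is the standard torch.cuda.amp mixed-precision recipe: run the forward pass under autocast, scale the loss before `backward()` so fp16 gradients do not underflow, then let the scaler unscale and step. A minimal self-contained sketch with a toy model (`enabled` is tied to CUDA availability so it also runs, as a no-op, on CPU):

import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = nn.Linear(8, 1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = GradScaler(enabled=(device.type == 'cuda'))

x = torch.randn(4, 8, device=device)
y = torch.randn(4, 1, device=device)

with autocast(enabled=(device.type == 'cuda')):
    loss = nn.functional.mse_loss(model(x), y)

optimizer.zero_grad()
scaler.scale(loss).backward()  # scale the loss to avoid fp16 underflow
scaler.step(optimizer)         # unscales gradients, then optimizer.step()
scaler.update()                # adjusts the scale factor for the next step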