コード例 #1
0
def main():
    """Split the dataset once, reload the held-out split, and preview images."""
    # Build the dataset with a 10% split and persist the split record file
    split_source = sunnerData.ImageDataset(
        root=[['/home/sunner/Music/waiting_for_you_dataset/wait'],
              ['/home/sunner/Music/waiting_for_you_dataset/real_world']],
        transform=None,
        split_ratio=0.1,
        save_file=True)
    del split_source

    # Reload the held-out portion from the saved record file
    pipeline = transforms.Compose([
        sunnertransforms.Resize((160, 320)),
        sunnertransforms.ToTensor(),
        sunnertransforms.Transpose(sunnertransforms.BHWC2BCHW),
        sunnertransforms.Normalize(),
    ])
    held_out = sunnerData.ImageDataset(file_name='.split.pkl',
                                       transform=pipeline)

    # Wrap in a loader that is limited to a single iteration
    batch_source = sunnerData.DataLoader(held_out,
                                         batch_size=32,
                                         shuffle=False,
                                         num_workers=2)
    batch_source = sunnerData.IterationLoader(batch_source, max_iter=1)

    # Render the first image of each batch (channel flip: RGB -> BGR for OpenCV)
    for images, _ in batch_source:
        images = sunnertransforms.asImg(images, size=(160, 320))
        cv2.imshow('show_window', images[0][:, :, ::-1])
        cv2.waitKey(0)
コード例 #2
0
def main():
    """Load paired image/tag datasets and display one decoded tag image."""
    # Read the pallete object from the json file
    pallete = sunnertransforms.getCategoricalMapping(
        path="ear-pen-pallete.json")[0]

    # Preprocessing for the raw images
    img_pipeline = transforms.Compose([
        sunnertransforms.Resize((260, 195)),
        sunnertransforms.ToTensor(),
        sunnertransforms.ToFloat(),
        sunnertransforms.Normalize(mean=[0.5, 0.5, 0.5],
                                   std=[0.5, 0.5, 0.5]),
    ])
    img_dataset = sunnerData.ImageDataset(
        root=[
            ['/home/sunner/Music/Ear-Pen-master/generate/train/img'],
        ],
        transforms=img_pipeline)

    # Preprocessing for the tags, ending with a color -> index mapping
    tag_pipeline = transforms.Compose([
        sunnertransforms.Resize((260, 195)),
        sunnertransforms.ToTensor(),
        sunnertransforms.ToFloat(),
        sunnertransforms.Normalize(mean=[0.5, 0.5, 0.5],
                                   std=[0.5, 0.5, 0.5]),
        sunnertransforms.CategoricalTranspose(
            pallete=pallete,
            direction=sunnertransforms.COLOR2INDEX,
            index_default=0)
    ])
    tag_dataset = sunnerData.ImageDataset(
        root=[['/home/sunner/Music/Ear-Pen-master/generate/train/tag']],
        transforms=tag_pipeline)

    # Iterate both datasets in lock-step
    loader = sunnerData.MultiLoader([img_dataset, tag_dataset],
                                    batch_size=1,
                                    shuffle=False,
                                    num_workers=2)

    # Operator that maps index maps back to their pallete colors
    index_to_color = sunnertransforms.CategoricalTranspose(
        pallete=pallete,
        direction=sunnertransforms.INDEX2COLOR,
        index_default=0)

    # Decode and display the first tag batch only (RGB -> BGR for OpenCV)
    for (_, batch_tag) in loader:
        batch_tag = index_to_color(batch_tag)
        batch_tag = sunnertransforms.asImg(batch_tag, size=(260, 195))
        cv2.imshow('show_window', batch_tag[0][:, :, ::-1])
        cv2.waitKey(0)
        break
コード例 #3
0
def main():
    """Preview one batch from the two waiting_for_you folders."""
    # Preprocess: resize, float tensor, normalize to [-1, 1]
    pipeline = transforms.Compose([
        sunnertransforms.Resize((160, 320)),
        sunnertransforms.ToTensor(),
        sunnertransforms.ToFloat(),
        sunnertransforms.Normalize(mean=[0.5, 0.5, 0.5],
                                   std=[0.5, 0.5, 0.5]),
    ])
    dataset = sunnerData.ImageDataset(
        root=[['/home/sunner/Music/waiting_for_you_dataset/wait'],
              ['/home/sunner/Music/waiting_for_you_dataset/real_world']],
        transforms=pipeline)
    loader = sunnerData.DataLoader(dataset,
                                   batch_size=32,
                                   shuffle=False,
                                   num_workers=2)

    # Limit the loop to exactly one iteration
    loader = sunnerData.IterationLoader(loader, max_iter=1)

    # Show each batch: one image via OpenCV, then a 2x5 grid of the first ten
    for batch_tensor, _ in loader:
        batch_img = sunnertransforms.asImg(batch_tensor, size=(160, 320))
        cv2.imshow('show_window', batch_img[0][:, :, ::-1])
        cv2.waitKey(0)

        # Or show multiple image in one line
        sunnertransforms.show(batch_tensor[:10], row=2, column=5)
コード例 #4
0
def demo(args):
    """
        Run the demo loop: load images, restore the model, visualize anomalies.

        Arg:    args    (Namespace) - The arguments
    """
    # Build the demo data loader (resize -> float tensor -> [-1, 1])
    pipeline = transforms.Compose([
        sunnerTransforms.Resize(output_size=(args.H, args.W)),
        sunnerTransforms.ToTensor(),
        sunnerTransforms.ToFloat(),
        sunnerTransforms.Normalize(mean=[0.5, 0.5, 0.5],
                                   std=[0.5, 0.5, 0.5]),
    ])
    loader = sunnerData.DataLoader(
        dataset=sunnerData.ImageDataset(root=[[args.demo]],
                                        transforms=pipeline),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=2)

    # Restore the trained model weights
    model = GANomaly2D(r=args.r, device=args.device)
    model.IO(args.resume, direction='load')

    # Run inference without gradients and visualize every batch
    progress = tqdm(loader)
    model.eval()
    with torch.no_grad():
        for (img, ) in progress:
            z, z_ = model.forward(img)
            img, img_ = model.getImg()
            visualizeAnomalyImage(img, img_, z, z_)
コード例 #5
0
def main(args):
    """
        Train the 2nd step for LaDo.

        Loads the stage-1 pre-trained model (raises if absent), calls
        model.copy() / model.fix() (presumably copying and freezing the
        stage-1 weights — confirm against the LaDo class), then runs the
        stage-2 loop.  Epoch 0 feeds data through setInput_2nd without a
        backward pass; losses are printed from epoch 1 on, and a checkpoint
        is written to <root_folder>/models_2nd every epoch.

        Arg:    args    (Namespace)     - The argument object
    """
    # Create the data loader and the pre-trained model
    # NOTE(review): OVER_SAMPLING presumably repeats the shorter folder so
    # src/tar pairs stay aligned — confirm against sunnerData docs
    loader = Data.DataLoader(dataset=sunnerData.ImageDataset(
        root=[[args.src_pair_image_folder], [args.tar_pair_image_folder]],
        transform=transforms.Compose([
            transforms.Resize((args.img_size, args.img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ]),
        sample_method=sunnerData.OVER_SAMPLING),
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=8)
    model = LaDo(args.content_dims, args.appearance_dims, args.batch_size)
    # Stage-2 training requires the stage-1 checkpoint to exist
    if os.path.exists(args.model_path_1st):
        model.load(args.model_path_1st, stage=1)
        model.copy()
        model.fix()
    else:
        raise Exception(
            "Pre-trained model didn't exist, please train for the 1st step first!"
        )

    # Loop
    for ep in range(args.total_epoch + 1):
        bar = tqdm(loader)
        for i, (src_img, tar_img) in enumerate(bar):
            model.setInput_2nd(src_img, tar_img)
            # Epoch 0 is forward-only (no parameter update)
            if ep != 0:
                model.backward_2nd()

        # Print average loss
        if ep != 0:
            print("=====" * 20)
            print("<< Epoch {} average >>".format(ep) + "     " * 20)
            string = ""
            # Prints the loss table roughly two entries per line: an
            # odd-indexed entry starts a fresh line buffer, an even-indexed
            # entry appends and flushes it.  NOTE(review): `continue` on
            # zero losses shifts the parity, and entry i == 0 is printed on
            # a line of its own — confirm this layout is intended.
            for i, (key, loss) in enumerate(
                    model.getLoss(stage=2, normalize=True).items()):
                if loss == 0.0:
                    continue
                loss = round(loss, 10)
                if i % 2 == 1:
                    string = "{:>20}: {:>15} \t".format(key[14:], str(loss))
                else:
                    string += "{:>20}: {:>15} \t".format(key[14:], str(loss))
                    print(string)
            model.finishEpoch()

        # Save
        # Create the checkpoint folder lazily on first save
        if not os.path.exists(os.path.join(args.root_folder, 'models_2nd')):
            os.mkdir(os.path.join(args.root_folder, 'models_2nd'))
        model.save(path=os.path.join(args.root_folder, 'models_2nd',
                                     num2Str(ep) + '.pth'))
コード例 #6
0
def main():
    """Derive the pallete from the tag folder, then round-trip a batch through it."""
    # Temporary loader whose only purpose is generating the pallete
    probe = sunnerData.DataLoader(
        sunnerData.ImageDataset(
            root=[tag_folder],
            transform=transforms.Compose([
                sunnertransforms.ToTensor(),
            ]),
            save_file=False  # Don't save the record file, be careful!
        ),
        batch_size=2,
        shuffle=False,
        num_workers=2)
    pallete = sunnertransforms.getCategoricalMapping(probe,
                                                     path='pallete.json')[0]
    del probe

    # The loader that will actually be iterated
    pipeline = transforms.Compose([
        sunnertransforms.Resize((512, 1024)),
        sunnertransforms.ToTensor(),
        sunnertransforms.Transpose(sunnertransforms.BHWC2BCHW),
        sunnertransforms.Normalize(),
    ])
    loader = sunnerData.DataLoader(sunnerData.ImageDataset(
        root=[img_folder, tag_folder], transform=pipeline),
                                   batch_size=32,
                                   shuffle=False,
                                   num_workers=2)

    # Forward (color -> one-hot) and reverse (one-hot -> color) operators
    goto_op = sunnertransforms.CategoricalTranspose(
        pallete=pallete, direction=sunnertransforms.COLOR2ONEHOT)
    back_op = sunnertransforms.CategoricalTranspose(
        pallete=pallete, direction=sunnertransforms.ONEHOT2COLOR)

    # Round-trip the first tag batch and display it (RGB -> BGR for OpenCV)
    for _, batch_index in loader:
        batch_img = back_op(goto_op(batch_index))
        batch_img = sunnertransforms.asImg(batch_img, size=(512, 1024))
        cv2.imshow('show_window', batch_img[0][:, :, ::-1])
        cv2.waitKey(0)
        break
コード例 #7
0
def main():
    """
        Generate the pallete from the tag folder, then verify it by mapping
        tags to one-hot and back before displaying the result.

        Formatting normalized to PEP 8 (no spaces around keyword '=',
        no trailing whitespace) to match the other examples in this file;
        behavior is unchanged.
    """
    # Define the loader to generate the pallete object
    loader = sunnerData.DataLoader(sunnerData.ImageDataset(
        root=[tag_folder],
        transforms=transforms.Compose([
            sunnertransforms.ToTensor(),
            sunnertransforms.ToFloat(),
            # Remember to transfer back to [0~255] before generate pallete
            sunnertransforms.UnNormalize(mean=[0, 0, 0], std=[255, 255, 255]),
            # Remember to transfer back to BHWC before generate pallete
            sunnertransforms.Transpose(sunnertransforms.BCHW2BHWC)
        ])),
                                   batch_size=2,
                                   shuffle=False,
                                   num_workers=2)
    pallete = sunnertransforms.getCategoricalMapping(loader,
                                                     path='pallete.json')[0]
    del loader

    # Define the actual loader
    loader = sunnerData.DataLoader(sunnerData.ImageDataset(
        root=[img_folder, tag_folder],
        transforms=transforms.Compose([
            sunnertransforms.Resize((512, 1024)),
            sunnertransforms.ToTensor(),
            sunnertransforms.ToFloat(),
            sunnertransforms.Normalize(mean=[0.5, 0.5, 0.5],
                                       std=[0.5, 0.5, 0.5]),
        ])),
                                   batch_size=32,
                                   shuffle=False,
                                   num_workers=2)

    # Define the reverse operator
    goto_op = sunnertransforms.CategoricalTranspose(
        pallete=pallete, direction=sunnertransforms.COLOR2ONEHOT)
    back_op = sunnertransforms.CategoricalTranspose(
        pallete=pallete, direction=sunnertransforms.ONEHOT2COLOR)

    # Show!
    for _, batch_index in loader:
        batch_img = back_op(goto_op(batch_index))
        batch_img = sunnertransforms.asImg(batch_img, size=(512, 1024))
        cv2.imshow('show_window', batch_img[0][:, :, ::-1])
        cv2.waitKey(0)
        break
コード例 #8
0
def train(args):
    """
        This function define the training process.

        Removed the blocks of commented-out augmentation code (dead code)
        and fixed the docstring typo; the active pipeline is unchanged.

        Arg:    args    (Namespace) - The arguments
    """
    # Create the data loader (resize -> float tensor -> normalize to [-1, 1])
    loader = sunnerData.DataLoader(
        dataset=sunnerData.ImageDataset(
            root=[[args.train]],
            transforms=transforms.Compose([
                sunnerTransforms.Resize(output_size=(args.H, args.W)),
                sunnerTransforms.ToTensor(),
                sunnerTransforms.ToFloat(),
                sunnerTransforms.Normalize(mean=[0.5, 0.5, 0.5],
                                           std=[0.5, 0.5, 0.5]),
            ])),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=2)
    loader = sunnerData.IterationLoader(loader, max_iter=args.n_iter)

    # Create the model and load the resume checkpoint if present
    model = GANomaly2D(r=args.r, device=args.device)
    model.IO(args.resume, direction='load')
    model.train()

    # Train!
    bar = tqdm(loader)
    for i, (normal_img, ) in enumerate(bar):
        model.forward(normal_img)
        model.backward()
        loss_G, loss_D = model.getLoss()
        bar.set_description("Loss_G: " + str(loss_G) + " loss_D: " +
                            str(loss_D))
        bar.refresh()
        # Periodically visualize reconstructions and checkpoint the model
        if i % args.record_iter == 0:
            model.eval()
            with torch.no_grad():
                z, z_ = model.forward(normal_img)
                img, img_ = model.getImg()
                visualizeEncoderDecoder(img, img_, z, z_, i)
            model.train()
            model.IO(args.det, direction='save')
    # Final save after the last iteration
    model.IO(args.det, direction='save')
コード例 #9
0
def main():
    """Load both waiting_for_you folders and preview a single batch."""
    # Preprocess: resize, tensor, layout transpose, normalize
    pipeline = transforms.Compose([
        sunnertransforms.Resize((160, 320)),
        sunnertransforms.ToTensor(),
        sunnertransforms.Transpose(sunnertransforms.BHWC2BCHW),
        sunnertransforms.Normalize(),
    ])
    dataset = sunnerData.ImageDataset(
        root=[['/home/sunner/Music/waiting_for_you_dataset/wait'],
              ['/home/sunner/Music/waiting_for_you_dataset/real_world']],
        transform=pipeline)
    loader = sunnerData.DataLoader(dataset,
                                   batch_size=32,
                                   shuffle=False,
                                   num_workers=2)

    # Restrict the loader to exactly one iteration
    loader = sunnerData.IterationLoader(loader, max_iter=1)

    # Display the first image of each batch (RGB -> BGR for OpenCV)
    for images, _ in loader:
        images = sunnertransforms.asImg(images, size=(160, 320))
        cv2.imshow('show_window', images[0][:, :, ::-1])
        cv2.waitKey(0)
コード例 #10
0
ファイル: eval.py プロジェクト: micklexqg/P-Conv
def evalModel(args, model):
    """Evaluate the model over all image/mask pairs and report the mean PSNR."""
    # Build the paired image/mask loader (over-sampling balances folder sizes)
    loader = sunnerData.ImageLoader(sunnerData.ImageDataset(
        root_list=[args.folder_path, args.mask_path],
        transform=transforms.Compose([
            sunnertransforms.Rescale((args.size, args.size)),
            sunnertransforms.ToTensor(),
            sunnertransforms.Transpose(sunnertransforms.BHWC2BCHW),
            sunnertransforms.Normalize()
        ]),
        sample_method=sunnerData.OVER_SAMPLING),
                                    batch_size=1,
                                    shuffle=False,
                                    num_workers=2)

    # Accumulate the PSNR of every reconstruction
    scores = []
    progress = tqdm(loader)
    for image, mask in progress:
        # Duplicate the single-sample batch so BatchNorm sees two samples
        image = torch.cat([image, image], 0)
        mask = torch.cat([mask, mask], 0)

        # Map the mask from [-1, 1] back to [0, 1], then run the model
        mask = (mask + 1) / 2
        model.setInput(target=image, mask=mask)
        model.forward()
        _, recon_img, _ = model.getOutput()
        scores.append(
            compare_psnr(image[0].detach().cpu().numpy(),
                         recon_img[0].detach().cpu().numpy()))

    # Show the result
    print('\n\n')
    print('-' * 20, 'Complete evaluation', '-' * 20)
    print('Testing average psnr: %.4f' % np.mean(scores))
コード例 #11
0
def main(opts):
    """
        Train StyleGAN with the non-saturating logistic (softplus) loss and
        optional R1/R2 gradient penalties, saving fixed-noise samples and a
        checkpoint every epoch.

        Arg:    opts    (Namespace) - Expects path, batch_size, resume,
                device, epoch, r1_gamma, r2_gamma, det
    """
    # Create the data loader
    loader = sunnerData.DataLoader(
        sunnerData.ImageDataset(root=[[opts.path]],
                                transform=transforms.Compose([
                                    sunnertransforms.Resize((1024, 1024)),
                                    sunnertransforms.ToTensor(),
                                    sunnertransforms.ToFloat(),
                                    sunnertransforms.Transpose(
                                        sunnertransforms.BHWC2BCHW),
                                    sunnertransforms.Normalize(),
                                ])),
        batch_size=opts.batch_size,
        shuffle=True,
    )

    # Create the model
    start_epoch = 0
    G = StyleGenerator()
    D = StyleDiscriminator()

    # Load the pre-trained weight
    if os.path.exists(opts.resume):
        INFO("Load the pre-trained weight!")
        state = torch.load(opts.resume)
        G.load_state_dict(state['G'])
        D.load_state_dict(state['D'])
        start_epoch = state['start_epoch']
    else:
        INFO(
            "Pre-trained weight cannot load successfully, train from scratch!")

    # Multi-GPU support
    if torch.cuda.device_count() > 1:
        INFO("Multiple GPU:" + str(torch.cuda.device_count()) + "\t GPUs")
        G = nn.DataParallel(G)
        D = nn.DataParallel(D)
    G.to(opts.device)
    D.to(opts.device)

    # Create the criterion, optimizer and scheduler
    optim_D = optim.Adam(D.parameters(), lr=0.00001, betas=(0.5, 0.999))
    optim_G = optim.Adam(G.parameters(), lr=0.00001, betas=(0.5, 0.999))
    scheduler_D = optim.lr_scheduler.ExponentialLR(optim_D, gamma=0.99)
    scheduler_G = optim.lr_scheduler.ExponentialLR(optim_G, gamma=0.99)

    # Train
    # fix_z is a constant latent batch so per-epoch samples are comparable
    fix_z = torch.randn([opts.batch_size, 512]).to(opts.device)
    softplus = nn.Softplus()
    Loss_D_list = [0.0]
    Loss_G_list = [0.0]
    for ep in range(start_epoch, opts.epoch):
        bar = tqdm(loader)
        loss_D_list = []
        loss_G_list = []
        for i, (real_img, ) in enumerate(bar):
            # =======================================================================================================
            #   (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            # =======================================================================================================
            # Compute adversarial loss toward discriminator
            # softplus(fake) + softplus(-real) is the logistic GAN loss for D
            D.zero_grad()
            real_img = real_img.to(opts.device)
            real_logit = D(real_img)
            fake_img = G(torch.randn([real_img.size(0), 512]).to(opts.device))
            fake_logit = D(fake_img.detach())
            d_loss = softplus(fake_logit).mean()
            d_loss = d_loss + softplus(-real_logit).mean()

            # R1: gradient penalty on real samples, weighted by gamma/2
            if opts.r1_gamma != 0.0:
                r1_penalty = R1Penalty(real_img.detach(), D)
                d_loss = d_loss + r1_penalty * (opts.r1_gamma * 0.5)

            # R2: gradient penalty on fake samples, weighted by gamma/2
            if opts.r2_gamma != 0.0:
                r2_penalty = R2Penalty(fake_img.detach(), D)
                d_loss = d_loss + r2_penalty * (opts.r2_gamma * 0.5)

            loss_D_list.append(d_loss.item())

            # Update discriminator
            d_loss.backward()
            optim_D.step()

            # =======================================================================================================
            #   (2) Update G network: maximize log(D(G(z)))
            # =======================================================================================================
            # G is updated once every CRITIC_ITER discriminator steps;
            # fake_logit here is recomputed with the just-updated D
            if i % CRITIC_ITER == 0:
                G.zero_grad()
                fake_logit = D(fake_img)
                g_loss = softplus(-fake_logit).mean()
                loss_G_list.append(g_loss.item())

                # Update generator
                g_loss.backward()
                optim_G.step()

            # Output training stats
            bar.set_description("Epoch {} [{}, {}] [G]: {} [D]: {}".format(
                ep, i + 1, len(loader), loss_G_list[-1], loss_D_list[-1]))

        # Save the result
        Loss_G_list.append(np.mean(loss_G_list))
        Loss_D_list.append(np.mean(loss_D_list))

        # Check how the generator is doing by saving G's output on fixed_noise
        with torch.no_grad():
            fake_img = G(fix_z).detach().cpu()
            save_image(fake_img,
                       os.path.join(opts.det, 'images',
                                    str(ep) + '.png'),
                       nrow=4,
                       normalize=True)

        # Save model
        # start_epoch is stored so a resumed run continues from this epoch
        state = {
            'G': G.state_dict(),
            'D': D.state_dict(),
            'Loss_G': Loss_G_list,
            'Loss_D': Loss_D_list,
            'start_epoch': ep,
        }
        torch.save(state, os.path.join(opts.det, 'models', 'latest.pth'))

        # Decay both learning rates once per epoch
        scheduler_D.step()
        scheduler_G.step()

    # Plot the total loss curve
    # Drop the initial 0.0 placeholder before plotting
    Loss_D_list = Loss_D_list[1:]
    Loss_G_list = Loss_G_list[1:]
    plotLossCurve(opts, Loss_D_list, Loss_G_list)
コード例 #12
0
def train(opts):
    """
        Train StyleGAN with softplus (non-saturating logistic) losses and
        optional R1/R2 penalties, logging to TensorBoard and checkpointing
        every epoch (plus a numbered checkpoint every 10 epochs).

        Fixes: the bare ``except:`` around checkpoint loading is narrowed to
        ``except Exception`` (it used to swallow KeyboardInterrupt/SystemExit),
        deprecated ``logger.warn`` becomes ``logger.warning``, and the
        pointless f-prefix on constant TensorBoard tags is dropped.

        Arg:    opts    (Namespace) - Expects output, input, imsize,
                batch_size, G, D, start_epoch, resume, device, g_lr, d_lr,
                betas, g_lrdecay, d_lrdecay, r1gamma, r2gamma, critic_iters,
                show_interval, epochs
    """
    def log(string, name="stylegan.log"):
        # Append one line to the plain-text training log
        with open(name, 'a') as f:
            f.write(string + '\n')

    writer = SummaryWriter(str(opts.output))
    loader = dataset.DataLoader(dataset=dataset.ImageDataset(
        [[opts.input]],
        transform=transforms.Compose([
            trans.Resize((opts.imsize, opts.imsize)),
            trans.ToTensor(),
            trans.ToFloat(),
            trans.Transpose(trans.BHWC2BCHW),
            trans.Normalize()
        ])),
                                batch_size=opts.batch_size,
                                shuffle=True)
    G = opts.G
    D = opts.D
    step = 0
    start_epoch = opts.start_epoch
    if opts.resume:
        try:
            assert os.path.exists(opts.resume)
            state = torch.load(opts.resume)
            G.load_state_dict(state['G'])
            D.load_state_dict(state['D'])
            start_epoch = state['start_epoch']
            logger.info("Load Pretrained Weight")
        except Exception:
            # Narrowed from a bare except: missing/corrupt checkpoints fall
            # back to scratch, but Ctrl-C and SystemExit propagate again.
            logger.warning("Resume Files cannot Load")
            logger.info("Train from Scratch")
    else:
        logger.info("Train from Scratch")
    if torch.cuda.device_count() > 1 and opts.device == 'cuda':
        logger.info(f"{torch.cuda.device_count()} GPUs found.")
        G = nn.DataParallel(G)
        D = nn.DataParallel(D)
    if opts.device == 'cuda':
        torch.backends.cudnn.benchmark = True
    G.to(opts.device)
    D.to(opts.device)
    optimG = Adam(G.parameters(), lr=opts.g_lr, betas=opts.betas)
    optimD = Adam(D.parameters(), lr=opts.d_lr, betas=opts.betas)
    schedulerG = lr_scheduler.ExponentialLR(optimG, gamma=opts.g_lrdecay)
    schedulerD = lr_scheduler.ExponentialLR(optimD, gamma=opts.d_lrdecay)
    # Constant latent batch so the "fixed" sample grids are comparable
    fixed_z = torch.randn([opts.batch_size, 512]).to(opts.device)
    sp = nn.Softplus()
    g_store = [0.0]
    d_store = [0.0]
    for epoch in range(start_epoch, opts.epochs + 1):
        bar = tqdm(loader)
        glosses = []
        dlosses = []
        for i, (real, ) in enumerate(bar):
            step += 1
            # ---- Discriminator update (every step) ----
            D.zero_grad()
            real = real.to(opts.device)
            Dr = D(real)
            # NOTE(review): add_graph is invoked every iteration here and
            # below; it is normally called once — confirm this is intended.
            writer.add_graph(D, real)
            z = torch.randn([real.size(0), 512]).to(opts.device)
            fake = G(z)
            writer.add_graph(G, z)
            Df = D(fake.detach())
            Dloss = sp(Df).mean() + sp(-Dr).mean()
            if opts.r1gamma > 0:
                r1 = r1_penalty(real.detach(), D)
                Dloss = Dloss + r1 * (opts.r1gamma * .5)
            if opts.r2gamma > 0:
                r2 = r2_penalty(fake.detach(), D)
                Dloss = Dloss + r2 * (opts.r2gamma * .5)
            dlosses.append(Dloss.item())
            Dloss.backward()
            optimD.step()
            # ---- Generator update (once every critic_iters steps) ----
            if i % opts.critic_iters == 0:
                G.zero_grad()
                Df = D(fake)
                Gloss = sp(-Df).mean()
                glosses.append(Gloss.item())
                Gloss.backward()
                optimG.step()
            # ---- Periodic sample dumps: a random z and the fixed z ----
            if i % opts.show_interval == 0:
                with torch.no_grad():
                    nr = int(math.ceil(math.sqrt(opts.batch_size)))
                    z = torch.randn([real.size(0), 512]).to(opts.device)
                    img = G(z)
                    save_image(img.detach().cpu(),
                               os.path.join(opts.output, 'images', 'normal',
                                            f'{epoch:04}_{i:06}.png'),
                               nrow=nr,
                               normalize=True)
                    fakes = utils.make_grid(img, nr, padding=0)
                    fakes = fakes.to(torch.float32).cpu().numpy()
                    fakes = np.clip((fakes / 2) + 0.5, 0, 1)
                    writer.add_image(f"EPOCH{epoch}/Random",
                                     torch.from_numpy(fakes), i)
                    img = G(fixed_z)
                    save_image(img.detach().cpu(),
                               os.path.join(opts.output, 'images', 'fixed',
                                            f'{epoch:04}_{i:06}.png'),
                               nrow=nr,
                               normalize=True)
                    fakes = utils.make_grid(img, nr, padding=0)
                    fakes = fakes.to(torch.float32).cpu().numpy()
                    fakes = np.clip((fakes / 2) + 0.5, 0, 1)
                    writer.add_image(f"EPOCH{epoch}/Fixed",
                                     torch.from_numpy(fakes), i)
            # Scalar tags are constants; no f-string needed
            writer.add_scalar("LOSS/Generator",
                              Gloss.item(),
                              global_step=step)
            writer.add_scalar("LOSS/Discriminator",
                              Dloss.item(),
                              global_step=step)
            bar.set_description(
                f"Epoch {epoch}/{opts.epochs} G: {glosses[-1]:.6f} D: {dlosses[-1]:.6f}"
            )
        g_store.append(np.mean(glosses))
        d_store.append(np.mean(dlosses))
        # Checkpoint every epoch; keep a numbered copy every 10 epochs
        state = {
            'G': G.state_dict(),
            'D': D.state_dict(),
            'Loss_G': g_store,
            'Loss_D': d_store,
            'start_epoch': epoch,
            'opts': opts
        }
        torch.save(state, os.path.join(opts.output, 'models', 'latest.pth'))
        if epoch % 10 == 0:
            torch.save(state,
                       os.path.join(opts.output, 'models', f'{epoch:04}.pth'))
        schedulerD.step()
        schedulerG.step()
コード例 #13
0
def main(opts):
    """
        Train a WGAN-GP: the critic is updated every step and the generator
        once every CRITIC_ITER steps, saving fixed-noise samples and the
        model weights at the end of every epoch.

        Arg:    opts    (Namespace) - Expects path, batch_size, type, device,
                resume, epoch, det
    """
    # Create the data loader
    loader = sunnerData.DataLoader(sunnerData.ImageDataset(
        root=[[opts.path]],
        transform=transforms.Compose([
            sunnertransforms.Resize((128, 128)),
            sunnertransforms.ToTensor(),
            sunnertransforms.ToFloat(),
            sunnertransforms.Transpose(sunnertransforms.BHWC2BCHW),
            sunnertransforms.Normalize(),
        ])),
                                   batch_size=opts.batch_size,
                                   shuffle=True,
                                   num_workers=4)

    # Create the model (style-based or plain generator, chosen by opts.type)
    if opts.type == 'style':
        G = StyleGenerator().to(opts.device)
    else:
        G = Generator().to(opts.device)
    D = Discriminator().to(opts.device)

    # Load the pre-trained weight
    if os.path.exists(opts.resume):
        INFO("Load the pre-trained weight!")
        state = torch.load(opts.resume)
        G.load_state_dict(state['G'])
        D.load_state_dict(state['D'])
    else:
        INFO(
            "Pre-trained weight cannot load successfully, train from scratch!")

    # Create the criterion, optimizer and scheduler
    optim_D = optim.Adam(D.parameters(), lr=0.0001, betas=(0.5, 0.999))
    optim_G = optim.Adam(G.parameters(), lr=0.0001, betas=(0.5, 0.999))
    scheduler_D = optim.lr_scheduler.ExponentialLR(optim_D, gamma=0.99)
    scheduler_G = optim.lr_scheduler.ExponentialLR(optim_G, gamma=0.99)

    # Train
    # fix_z is a constant latent batch so per-epoch samples are comparable
    fix_z = torch.randn([opts.batch_size, 512]).to(opts.device)
    Loss_D_list = [0.0]
    Loss_G_list = [0.0]
    for ep in range(opts.epoch):
        bar = tqdm(loader)
        loss_D_list = []
        loss_G_list = []
        for i, (real_img, ) in enumerate(bar):
            # =======================================================================================================
            #   Update discriminator
            # =======================================================================================================
            # Compute adversarial loss toward discriminator
            # Wasserstein critic loss plus gradient penalty weighted by 10
            real_img = real_img.to(opts.device)
            real_logit = D(real_img)
            fake_img = G(torch.randn([real_img.size(0), 512]).to(opts.device))
            fake_logit = D(fake_img.detach())
            d_loss = -(real_logit.mean() -
                       fake_logit.mean()) + gradient_penalty(
                           real_img.data, fake_img.data, D) * 10.0
            loss_D_list.append(d_loss.item())

            # Update discriminator
            optim_D.zero_grad()
            d_loss.backward()
            optim_D.step()

            # =======================================================================================================
            #   Update generator
            # =======================================================================================================
            if i % CRITIC_ITER == 0:
                # Compute adversarial loss toward generator
                fake_img = G(
                    torch.randn([opts.batch_size, 512]).to(opts.device))
                fake_logit = D(fake_img)
                g_loss = -fake_logit.mean()
                loss_G_list.append(g_loss.item())

                # Update generator
                # D.zero_grad() discards the gradients that g_loss.backward()
                # pushes into the critic, so only G's weights move here
                D.zero_grad()
                optim_G.zero_grad()
                g_loss.backward()
                optim_G.step()
            bar.set_description(" {} [G]: {} [D]: {}".format(
                ep, loss_G_list[-1], loss_D_list[-1]))

        # Save the result
        Loss_G_list.append(np.mean(loss_G_list))
        Loss_D_list.append(np.mean(loss_D_list))
        # NOTE(review): sampling runs without torch.no_grad(), so this
        # forward pass keeps its autograd graph alive until fake_img is
        # rebound — confirm whether that is intended
        fake_img = G(fix_z)
        save_image(fake_img,
                   os.path.join(opts.det, 'images',
                                str(ep) + '.png'),
                   nrow=4,
                   normalize=True)
        state = {
            'G': G.state_dict(),
            'D': D.state_dict(),
            'Loss_G': Loss_G_list,
            'Loss_D': Loss_D_list,
        }
        torch.save(state, os.path.join(opts.det, 'models', 'latest.pth'))

        # Decay both learning rates once per epoch
        scheduler_D.step()
        scheduler_G.step()

    # Plot the total loss curve (drop the initial 0.0 placeholder)
    Loss_D_list = Loss_D_list[1:]
    Loss_G_list = Loss_G_list[1:]
    plotLossCurve(opts, Loss_D_list, Loss_G_list)
コード例 #14
0
        discriminator.module.load_state_dict(ckpt["discriminator"])
        g_running.load_state_dict(ckpt["g_running"])
        g_optimizer.load_state_dict(ckpt["g_optimizer"])
        d_optimizer.load_state_dict(ckpt["d_optimizer"])

    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
    ])

    dataset = sunnerData.ImageDataset(
        root=[[args.path]],
        transform=transforms.Compose([
            sunnertransforms.Resize((128, 128)),
            sunnertransforms.ToTensor(),
            sunnertransforms.ToFloat(),
            sunnertransforms.Transpose(sunnertransforms.BHWC2BCHW),
            sunnertransforms.Normalize()
        ]),
    )

    if args.sched:
        args.lr = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}
        args.batch = {
            4: 512,
            8: 256,
            16: 128,
            32: 64,
            64: 32,
            128: 32,
            256: 32
コード例 #15
0
ファイル: train.py プロジェクト: ahuirecome/StyleGAN2_PyTorch
def main(opts):
    """
    Train StyleGAN2 with the non-saturating logistic loss plus R1 gradient
    penalty (D_logistic_r1), saving a sample grid and a checkpoint each epoch.

    Arg:    opts    (Namespace) - parsed options; must provide path,
                    resolution, fmap_base, mapping_layers, batch_size,
                    epoch, resume, det and device
    """
    # Create the data loader
    loader = sunnerData.DataLoader(sunnerData.ImageDataset(
        root=[[opts.path]],
        transform=transforms.Compose([
            sunnertransforms.Resize((opts.resolution, opts.resolution)),
            sunnertransforms.ToTensor(),
            sunnertransforms.ToFloat(),
            sunnertransforms.Transpose(sunnertransforms.BHWC2BCHW),
            sunnertransforms.Normalize(),
        ])),
        batch_size=opts.batch_size,
        shuffle=True,
        drop_last=True
    )

    # Create the model (G also returns the dlatents: return_dlatents=True)
    start_epoch = 0
    G = G_stylegan2(fmap_base=opts.fmap_base,
                    resolution=opts.resolution,
                    mapping_layers=opts.mapping_layers,
                    opts=opts,
                    return_dlatents=True)
    D = D_stylegan2(fmap_base=opts.fmap_base,
                    resolution=opts.resolution,
                    structure='resnet')

    # Load the pre-trained weight
    if os.path.exists(opts.resume):
        INFO("Load the pre-trained weight!")
        state = torch.load(opts.resume)
        G.load_state_dict(state['G'])
        D.load_state_dict(state['D'])
        start_epoch = state['start_epoch']
    else:
        INFO("Pre-trained weight cannot load successfully, train from scratch!")

    # Multi-GPU support
    if torch.cuda.device_count() > 1:
        INFO("Multiple GPU:" + str(torch.cuda.device_count()) + "\t GPUs")
        G = torch.nn.DataParallel(G)
        D = torch.nn.DataParallel(D)
    G.to(opts.device)
    D.to(opts.device)

    # Create the criterion, optimizer and scheduler.
    # BUG FIX: after DataParallel wrapping the sub-modules live on G.module,
    # so unwrap before building the per-module parameter groups.
    g_core = G.module if isinstance(G, torch.nn.DataParallel) else G
    d_core = D.module if isinstance(D, torch.nn.DataParallel) else D
    lr_D = 0.0015
    lr_G = 0.0015
    optim_D = torch.optim.Adam(D.parameters(), lr=lr_D, betas=(0.9, 0.999))
    # g_mapping has 100x lower learning rate
    params_G = [{"params": g_core.g_synthesis.parameters()},
                {"params": g_core.g_mapping.parameters(), "lr": lr_G * 0.01}]
    optim_G = torch.optim.Adam(params_G, lr=lr_G, betas=(0.9, 0.999))
    scheduler_D = optim.lr_scheduler.ExponentialLR(optim_D, gamma=0.99)
    scheduler_G = optim.lr_scheduler.ExponentialLR(optim_G, gamma=0.99)

    # Train
    fix_z = torch.randn([opts.batch_size, 512]).to(opts.device)
    softplus = torch.nn.Softplus()
    Loss_D_list = [0.0]     # index 0 is a dummy entry, dropped before plotting
    Loss_G_list = [0.0]
    for ep in range(start_epoch, opts.epoch):
        bar = tqdm(loader)
        loss_D_list = []
        loss_G_list = []
        for i, (real_img,) in enumerate(bar):
            real_img = real_img.to(opts.device)
            latents = torch.randn([real_img.size(0), 512]).to(opts.device)

            # =======================================================================================================
            #   (1) Update D network: D_logistic_r1(default)
            # =======================================================================================================
            # softplus(fake) + softplus(-real) is the non-saturating logistic
            # discriminator loss; the fake pass is detached so only D updates.
            real_logit = D(real_img)
            fake_img, fake_dlatent = G(latents)
            fake_logit = D(fake_img.detach())

            d_loss = softplus(fake_logit)
            d_loss = d_loss + softplus(-real_logit)

            # R1 penalty: gradient penalty computed on the real samples only
            r1_penalty = D_logistic_r1(real_img.detach(), D)
            d_loss = (d_loss + r1_penalty).mean()

            loss_D_list.append(d_loss.item())

            # Update discriminator
            optim_D.zero_grad()
            d_loss.backward()
            optim_D.step()

            # =======================================================================================================
            #   (2) Update G network: G_logistic_ns(default)
            # =======================================================================================================
            G.zero_grad()
            fake_scores_out = D(fake_img)
            g_loss = softplus(-fake_scores_out).mean()
            loss_G_list.append(g_loss.item())

            # Update generator.  The generator graph is backwarded exactly
            # once per iteration, so retain_graph is not needed (the original
            # retain_graph=True only wasted memory).
            g_loss.backward()
            optim_G.step()

            # Output training stats
            bar.set_description(
                "Epoch {} [{}, {}] [G]: {} [D]: {}".format(ep, i + 1, len(loader), loss_G_list[-1], loss_D_list[-1]))

        # Save the epoch-average result
        Loss_G_list.append(np.mean(loss_G_list))
        Loss_D_list.append(np.mean(loss_D_list))

        # Check how the generator is doing by saving G's output on fixed_noise
        # (G returns a tuple because return_dlatents=True; [0] is the image)
        with torch.no_grad():
            fake_img = G(fix_z)[0].detach().cpu()
            save_image(fake_img, os.path.join(opts.det, 'images', str(ep) + '.png'), nrow=4, normalize=True)

        # Save model.
        # BUG FIX: store ep + 1 so that resuming continues with the next
        # epoch instead of repeating the one just finished, and save the
        # unwrapped modules so the checkpoint loads without 'module.' prefixes.
        state = {
            'G': g_core.state_dict(),
            'D': d_core.state_dict(),
            'Loss_G': Loss_G_list,
            'Loss_D': Loss_D_list,
            'start_epoch': ep + 1,
        }
        torch.save(state, os.path.join(opts.det, 'models', 'latest.pth'))

        scheduler_D.step()
        scheduler_G.step()

    # Plot the total loss curve (drop the dummy first entries)
    Loss_D_list = Loss_D_list[1:]
    Loss_G_list = Loss_G_list[1:]
    plotLossCurve(opts, Loss_D_list, Loss_G_list)
コード例 #16
0
    return args


if __name__ == '__main__':
    args = parse()

    # Create data loader
    sunnerData.quiet()
    sunnertransforms.quiet()
    loader = sunnerData.ImageLoader(
        sunnerData.ImageDataset(
            root_list=[args.image_folder, args.mask_folder],
            transform=transforms.Compose([
                # sunnertransforms.Rescale((360, 640)),
                sunnertransforms.Rescale((256, 256)),
                sunnertransforms.ToTensor(),

                # BHWC -> BCHW
                sunnertransforms.Transpose(sunnertransforms.BHWC2BCHW),
                sunnertransforms.Normalize()
            ]),
            sample_method=sunnerData.OVER_SAMPLING),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=2)

    # Load model
    if args.model_type == 'pconv':
        model = PartialUNet(style_list=args.style,
                            base=64,
                            style_weight=args.lambda_style,
                            freeze=args.freeze)
コード例 #17
0
def main(args):
    """
        Train the 1st step for LaDo

        Arg:    args    (Namespace)     - The argument object
    """
    # Create the data loader and the model
    loader = Data.DataLoader(dataset=sunnerData.ImageDataset(
        root=[[args.src_image_folder], [args.src_pair_image_folder],
              [args.tar_pair_image_folder]],
        transform=transforms.Compose([
            transforms.Resize((args.img_size, args.img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ]),
        sample_method=sunnerData.OVER_SAMPLING),
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=8)
    model = LaDo(args.content_dims, args.appearance_dims, args.batch_size)

    # Loop.  Epoch 0 performs a single forward pass only (warm-up / sanity
    # check); optimisation starts from epoch 1.
    for ep in range(args.total_epoch + 1):
        bar = tqdm_table(loader)
        bar.set_table_setting(150)
        for i, (src_img, src_img_pair, tar_img_pair) in enumerate(bar):
            # Skip degenerate size-1 batches
            if len(src_img) == 1:
                continue
            model.setInput_1st(src_img, src_img_pair, tar_img_pair)
            if ep != 0:
                # Print total loss and update
                # (k[14:] strips the loss-name prefix for display — presumably
                # a fixed-width tag; confirm against LaDo.getLoss keys)
                model.backward_1st()
                bar.set_table_info(
                    {k[14:]: v
                     for k, v in model.getLoss(stage=1).items()})
            else:
                model.forward_1st()
                break

        # Print the epoch-average losses, two entries per printed row
        if ep != 0:
            print("=====" * 20)
            print("<< Epoch {} average >>".format(ep) + "     " * 20)
            string = ""
            had_write = False
            for i, (key, loss) in enumerate(
                    model.getLoss(stage=1, normalize=True).items()):
                if loss == 0.0:
                    continue
                loss = round(loss, 10)
                if i % 2 == 1:
                    # Odd index: complete the row and flush it
                    string += "{:>20}: {:>15} \t".format(key[14:], str(loss))
                    print(string)
                    had_write = True
                else:
                    # Even index: start a new (half) row
                    string = "{:>20}: {:>15} \t".format(key[14:], str(loss))
                    had_write = False
            # Flush a dangling half row when the last entry was not printed
            if not had_write:
                print(string)
            model.finishEpoch()

        # Save (makedirs with exist_ok replaces the racy check-then-mkdir
        # pairs and creates the parent folder in the same call)
        os.makedirs(os.path.join(args.root_folder, 'models_1st'), exist_ok=True)
        model.save(path=os.path.join(args.root_folder, 'models_1st',
                                     str(ep) + '.pth'))
コード例 #18
0
ファイル: train.py プロジェクト: itsss/StyleGAN2
def main(opts):
    """
    Train StyleGAN2 (non-saturating logistic loss + R1 gradient penalty) and
    save a sample grid plus a checkpoint after every epoch.

    Arg:    opts    (Namespace) - parsed options; must provide path,
                    resolution, fmap_base, mapping_layers, batch_size,
                    epoch, resume, det and device
    """
    # Data load
    loader = sunnerData.DataLoader(sunnerData.ImageDataset(
        root=[[opts.path]],
        transform=transforms.Compose([
            sunnertransforms.Resize((opts.resolution, opts.resolution)),
            sunnertransforms.ToTensor(),
            sunnertransforms.ToFloat(),
            sunnertransforms.Transpose(sunnertransforms.BHWC2BCHW),
            sunnertransforms.Normalize()
        ])),
                                   batch_size=opts.batch_size,
                                   shuffle=True,
                                   drop_last=True)

    # model generation (G also returns the dlatents: return_dlatents=True)
    start_epoch = 0
    G = Generator_stylegan2(fmap_base=opts.fmap_base,
                            resol=opts.resolution,
                            mapping_layers=opts.mapping_layers,
                            opts=opts,
                            return_dlatents=True)
    D = Discriminator_stylegan2(fmap_base=opts.fmap_base,
                                resol=opts.resolution,
                                structure='resnet')

    # pre-trained weight loading
    if os.path.exists(opts.resume):
        INFO("Load the pre-trained weight!")
        state = torch.load(opts.resume)
        G.load_state_dict(state['G'])
        D.load_state_dict(state['D'])
        start_epoch = state['start_epoch']
    else:
        INFO("pre-trained weight error")

    # multiple GPU support
    # BUG FIX: "torch.nn.DataParrlel" was a typo and raised AttributeError
    # whenever more than one GPU was available.
    if torch.cuda.device_count() > 1:
        INFO("multiple GPU detected! Total " + str(torch.cuda.device_count()) +
             '\t GPUs!')
        G = torch.nn.DataParallel(G)
        D = torch.nn.DataParallel(D)
    G.to(opts.device)
    D.to(opts.device)

    # optimizer, scheduler.
    # BUG FIX: after DataParallel wrapping the sub-modules live on G.module,
    # so unwrap before building the parameter groups (g_mapping runs at a
    # 100x lower learning rate).
    g_core = G.module if isinstance(G, torch.nn.DataParallel) else G
    d_core = D.module if isinstance(D, torch.nn.DataParallel) else D
    lr_D = 0.0015
    lr_G = 0.0015
    optim_D = torch.optim.Adam(D.parameters(), lr=lr_D, betas=(0.9, 0.999))
    params_G = [{
        "params": g_core.g_synthesis.parameters()
    }, {
        "params": g_core.g_mapping.parameters(),
        "lr": lr_G * 0.01
    }]
    optim_G = torch.optim.Adam(params_G, lr=lr_G, betas=(0.9, 0.999))
    scheduler_D = optim.lr_scheduler.ExponentialLR(optim_D, gamma=0.99)
    scheduler_G = optim.lr_scheduler.ExponentialLR(optim_G, gamma=0.99)

    # start training
    fix_z = torch.randn([opts.batch_size, 512]).to(opts.device)
    softplus = torch.nn.Softplus()
    Loss_D_list = [0.0]     # index 0 is a dummy entry, dropped before plotting
    Loss_G_list = [0.0]
    for ep in range(start_epoch, opts.epoch):
        bar = tqdm(loader)
        loss_D_list = []
        loss_G_list = []
        for i, (real_img, ) in enumerate(bar):
            real_img = real_img.to(opts.device)
            latents = torch.randn([real_img.size(0), 512]).to(opts.device)

            # Discriminator Network: non-saturating logistic loss + R1
            # penalty; the fake pass is detached so only D updates here.
            real_logit = D(real_img)
            fake_img, fake_dlatent = G(latents)
            fake_logit = D(fake_img.detach())

            d_loss = softplus(fake_logit)
            d_loss = d_loss + softplus(-real_logit)

            r1_penalty = D_logistic_r1(real_img.detach(), D)
            d_loss = (d_loss + r1_penalty).mean()

            loss_D_list.append(d_loss.item())

            optim_D.zero_grad()
            d_loss.backward()
            optim_D.step()

            # Generator Network
            G.zero_grad()
            fake_scores_out = D(fake_img)
            g_loss = softplus(-fake_scores_out).mean()
            loss_G_list.append(g_loss.item())

            g_loss.backward()
            optim_G.step()

            bar.set_description("Epoch {} [{}, {}] [G]: {} [D]: {}".format(
                ep, i + 1, len(loader), loss_G_list[-1], loss_D_list[-1]))

        # save epoch-average result
        Loss_G_list.append(np.mean(loss_G_list))
        Loss_D_list.append(np.mean(loss_D_list))

        # sample grid from the fixed latent (G returns a tuple; [0] is image)
        with torch.no_grad():
            fake_img = G(fix_z)[0].detach().cpu()
            save_image(fake_img,
                       os.path.join(opts.det, 'images',
                                    str(ep) + '.png'),
                       nrow=4,
                       normalize=True)

        # save model.
        # BUG FIX: store ep + 1 so resuming continues with the next epoch
        # rather than repeating the finished one, and save the unwrapped
        # modules so the checkpoint loads without 'module.' prefixes.
        state = {
            'G': g_core.state_dict(),
            'D': d_core.state_dict(),
            'Loss_G': Loss_G_list,
            'Loss_D': Loss_D_list,
            'start_epoch': ep + 1,
        }
        torch.save(state, os.path.join(opts.det, 'models', 'latest.pth'))

        scheduler_D.step()
        scheduler_G.step()

    # drop the dummy first entries before plotting the loss curves
    Loss_D_list = Loss_D_list[1:]
    Loss_G_list = Loss_G_list[1:]
    plotLossCurve(opts, Loss_D_list, Loss_G_list)