示例#1
0
文件: utils.py 项目: SlipknotTN/MUNIT
def load_vgg16(model_dir):
    """Return a ``Vgg16`` loaded with weights from the fast-neural-style Lua model.

    Model source: https://github.com/abhiskk/fast-neural-style/blob/master/neural_style/utils.py

    The first call downloads ``vgg16.t7`` into *model_dir* with ``wget``,
    copies its parameters into a fresh ``Vgg16`` and caches the result as
    ``vgg16.weight``. Later calls reuse whichever of those files already exist.

    Args:
        model_dir: Cache directory for ``vgg16.t7`` and ``vgg16.weight``;
            created (with any missing parents) if it does not exist.

    Returns:
        A ``Vgg16`` instance with the cached state dict loaded.

    Raises:
        FileNotFoundError: if the ``wget`` executable cannot be found.
        subprocess.CalledProcessError: if ``wget`` exits with a non-zero status.
    """
    import subprocess  # function-scope import keeps this loader self-contained

    t7_path = os.path.join(model_dir, 'vgg16.t7')
    weight_path = os.path.join(model_dir, 'vgg16.weight')

    # makedirs also creates missing parents; exist_ok tolerates a concurrent mkdir.
    os.makedirs(model_dir, exist_ok=True)
    if not os.path.exists(weight_path):
        if not os.path.exists(t7_path):
            try:
                # argv list, no shell: the '?' in the URL is not glob-expanded
                # and model_dir may safely contain spaces or shell metacharacters.
                subprocess.run(
                    ['wget',
                     'https://www.dropbox.com/s/76l3rt4kyi3s8x7/vgg16.t7?dl=1',
                     '-O', t7_path],
                    check=True)
            except subprocess.CalledProcessError:
                # `wget -O` creates the output file even when the download
                # fails; remove it so the next call retries the download
                # instead of trying to parse a truncated file.
                if os.path.exists(t7_path):
                    os.remove(t7_path)
                raise
        vgglua = torchfile.load(t7_path)
        vgg = Vgg16()
        # Tensors are copied pairwise in order; zip stops at the shorter
        # sequence, so surplus tensors on either side are skipped silently.
        for src, dst in zip(vgglua.parameters()[0], vgg.parameters()):
            dst.data[:] = src
        torch.save(vgg.state_dict(), weight_path)
    vgg = Vgg16()
    vgg.load_state_dict(torch.load(weight_path))
    return vgg
示例#2
0
def init_vgg16(model_dir):
    """Create the cached ``vgg16.weight`` state dict in *model_dir* if it is missing.

    When needed, downloads the fast-neural-style Lua model ``vgg16.t7`` with
    ``wget``, copies its parameters into a fresh ``Vgg16`` and saves the
    resulting state dict as ``vgg16.weight``. Does nothing if
    ``vgg16.weight`` already exists.

    Args:
        model_dir: Cache directory; created (with any missing parents) if it
            does not exist, because ``wget -O`` cannot write into a missing
            directory.

    Raises:
        FileNotFoundError: if the ``wget`` executable cannot be found.
        subprocess.CalledProcessError: if ``wget`` exits with a non-zero status.
    """
    import subprocess  # function-scope import keeps this helper self-contained

    t7_path = os.path.join(model_dir, 'vgg16.t7')
    weight_path = os.path.join(model_dir, 'vgg16.weight')
    if os.path.exists(weight_path):
        return

    os.makedirs(model_dir, exist_ok=True)
    if not os.path.exists(t7_path):
        try:
            # argv list, no shell: the '?' in the URL is not glob-expanded and
            # model_dir may safely contain spaces or shell metacharacters.
            subprocess.run(
                ['wget',
                 'https://www.dropbox.com/s/76l3rt4kyi3s8x7/vgg16.t7?dl=1',
                 '-O', t7_path],
                check=True)
        except subprocess.CalledProcessError:
            # `wget -O` leaves a (possibly empty) file behind on failure;
            # remove it so the next call retries instead of parsing garbage.
            if os.path.exists(t7_path):
                os.remove(t7_path)
            raise
    vgglua = torchfile.load(t7_path)
    vgg = Vgg16()
    # Tensors are copied pairwise in order; zip stops at the shorter sequence.
    for src, dst in zip(vgglua.parameters()[0], vgg.parameters()):
        dst.data[:] = src
    torch.save(vgg.state_dict(), weight_path)
示例#3
0
def load_vgg16(model_dir, pth_path='/media/HDD1/convert_torch_to_pytorch/vgg16.pth'):
    """Return a ``Vgg16`` whose weights come from a converted PyTorch VGG16 file.

    Adapted from https://github.com/abhiskk/fast-neural-style/blob/master/neural_style/utils.py.
    Instead of parsing the Lua ``vgg16.t7``, this variant reads a model that
    was already converted to PyTorch format (*pth_path*).

    On the first call the tensors stored in *pth_path* are copied, in order,
    into a fresh ``Vgg16`` and the result is cached as
    ``model_dir/vgg16.weight``; later calls load that cache directly.

    Args:
        model_dir: Cache directory for ``vgg16.weight``; created (with any
            missing parents) if it does not exist.
        pth_path: Converted VGG16 file. It is read with ``torch.load`` and
            iterated via ``.values()``, so it must hold a mapping such as a
            state dict. Defaults to the path that used to be hard-coded here.

    Returns:
        A ``Vgg16`` instance with the cached state dict loaded.
    """
    os.makedirs(model_dir, exist_ok=True)
    weight_path = os.path.join(model_dir, 'vgg16.weight')
    if not os.path.exists(weight_path):
        # The Lua vgg16.t7 is never read by this variant, so it is not downloaded.
        converted = torch.load(pth_path)
        vgg = Vgg16()
        # Copy tensors pairwise in the mapping's order. zip stops at the
        # shorter sequence, so surplus entries are skipped silently
        # (NOTE(review): presumably the converted model's classifier layers).
        for src, dst in zip(converted.values(), vgg.parameters()):
            dst.data[:] = src
        torch.save(vgg.state_dict(), weight_path)
    vgg = Vgg16()
    vgg.load_state_dict(torch.load(weight_path))
    return vgg
示例#4
0
def inference(config):
    """Blend the user's picture with a celebrity's picture and display the result.

    Loads the two segmentation models via ``load_model`` and the epoch-20
    ResNet-18 / GeneratorUNet checkpoints from ``config.checkpoints``. It then
    synthesizes a new image and shows the user picture, the celebrity picture
    and the synthesized picture side by side with matplotlib.

    Args:
        config: Object providing ``your_pic``, ``celeb_pic``, ``image_size``,
            ``device`` and ``checkpoints`` attributes (plus whatever
            ``load_model`` reads).

    Returns:
        None. Returns early, after printing a message, if either picture
        cannot be opened or the checkpoints cannot be read.
    """
    hair_model, skin_model = load_model(config)

    try:
        your_pic = Image.open(config.your_pic)
        celeb_pic = Image.open(config.celeb_pic)
    except Exception as err:
        # Still best-effort, but reported. Unlike a bare `except:`, this does
        # not swallow KeyboardInterrupt/SystemExit.
        print('Could not open the input pictures: %s' % err)
        return

    your_pic, your_pic_mask, celeb_pic, celeb_pic_mask = DataLoader(
        transform(your_pic, celeb_pic, config.image_size, hair_model,
                  skin_model, config.device))

    resnet = ResNet18(requires_grad=True, pretrained=True).to(config.device)
    generator = GeneratorUNet().to(config.device)

    try:
        for net, name in ((resnet, 'resnet'), (generator, 'generator')):
            path = os.path.join(config.checkpoints,
                                'epoch_%d_%s.pth' % (20, name))
            # map_location lets GPU-saved checkpoints load on a CPU-only device.
            net.load_state_dict(torch.load(path, map_location=config.device))
    except OSError:
        print('Check if your pretrained weight is in the right place.')
        # Without the checkpoints the networks keep their initial weights,
        # so the synthesized picture would be meaningless.
        return

    with torch.no_grad():  # pure inference: no autograd graph is needed
        z1 = resnet(your_pic * your_pic_mask)  # skin
        z2 = resnet(celeb_pic * celeb_pic_mask)  # hair
        fake_im = generator(your_pic, z1, z2)  # z1 is skin, z2 is hair

    images = [your_pic[0], celeb_pic[0], fake_im[0]]
    titles = ['Your picture', 'Celebrity picture', 'Synthesized picture']

    _, axes = plt.subplots(1, len(titles))
    for ax, im, title in zip(axes, images, titles):
        # Reorder axes (C, H, W) -> (H, W, C), the layout imshow expects.
        im = im.detach().cpu().numpy().transpose(1, 2, 0)
        # Assumes pixel values in [-1, 1]; rescale to [0, 1] for display.
        ax.imshow((im + 1) / 2)
        ax.axis('off')
        ax.set_title(title)

    plt.show()
示例#5
0
def load_vgg16(model_dir):
    """Return the convolutional trunk of torchvision's pretrained VGG16.

    The result is ``models.vgg16(pretrained=True).features`` with its last
    module removed (in torchvision's VGG16 that is the final max-pooling
    layer). The older fast-neural-style ``.t7`` loading path that used to
    follow the ``return`` here was unreachable and has been removed.

    Args:
        model_dir: Unused. Kept so existing callers remain compatible; the
            weights are cached wherever torchvision stores downloads.

    Returns:
        A ``torch.nn.Sequential`` slice of ``vgg16.features``.
    """
    # NOTE(review): newer torchvision versions deprecate `pretrained=` in
    # favour of `weights=`; switch once the project's minimum version allows.
    import torchvision.models as models
    vgg16 = models.vgg16(pretrained=True)
    return vgg16.features[:-1]
示例#6
0
import torch
import os
from torch.utils.serialization import load_lua
from torchvision.models import resnet18
from networks import Vgg16, ResNet
import torch.utils.model_zoo as model_zoo

# Debug/conversion script: copy the Lua VGG16 parameters into the project's
# Vgg16 (printing each source tensor to lua.txt along the way), then load the
# ResNet-18 checkpoint from download.pytorch.org into the project's ResNet.
# The copied Vgg16 weights are not saved anywhere by this script.
# NOTE(review): torch.utils.serialization.load_lua was removed in PyTorch 1.0,
# so this script only runs on older PyTorch versions.
model_dir = 'models/'

with open('lua.txt', 'w') as luaf:
    vgglua = load_lua(os.path.join(model_dir, 'vgg16.t7'))
    vgg = Vgg16()
    # Tensors are paired in order; zip stops at the shorter sequence.
    for (src, dst) in zip(vgglua.parameters()[0], vgg.parameters()):
        print(src, file=luaf)
        # Blocks on stdin: press Enter to step to the next tensor.
        input()
        dst.data[:] = src

resnet = ResNet()
# The second positional argument of model_zoo.load_url is model_dir, the
# directory where the downloaded checkpoint is cached.
resnet.load_state_dict(
    model_zoo.load_url(
        'https://download.pytorch.org/models/resnet18-5c106cde.pth',
        model_dir))
def _evaluate_classifier(model, loader, criterion, device):
    """Return ``(mean loss, mean accuracy)`` of *model* over every batch in *loader*.

    Switches *model* to eval mode and leaves it there; the training loop calls
    ``model.train()`` again at the start of each epoch. Accuracy comes from
    ``caculate_accuracy``. For an empty *loader* numpy returns ``nan`` means.
    """
    model.eval()
    losses, accuracies = [], []
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            losses.append(criterion(outputs, targets).item())
            accuracies.append(caculate_accuracy(outputs, targets))
    return np.array(losses).mean(), np.array(accuracies).mean()


def main(trainfile,
         validfile,
         batch_size,
         epochs,
         valid_frequency=2,
         write_frequency=1000,
         checkpoint_path=None,
         logdir='logs/',
         model_root='checkpoints'):
    """Fine-tune a 10-class ``Vgg16`` classifier, logging to TensorBoard.

    Args:
        trainfile: Image/label list for training (``DatasetFromImageLabelList``).
        validfile: Image/label list for validation.
        batch_size: Training batch size; validation uses twice this size.
        epochs: Index of the last epoch (exclusive) to train up to.
        valid_frequency: Validate (and possibly checkpoint) every this many epochs.
        write_frequency: Log running train loss/accuracy every this many steps.
        checkpoint_path: Optional checkpoint to resume model/optimizer from;
            training restarts at the epoch returned by ``load_model_state_dict``.
        logdir: TensorBoard log directory.
        model_root: Directory passed to ``save_model_state_dict``.

    A checkpoint is saved whenever the validation loss is no higher AND the
    validation accuracy no lower than the best seen so far (initially loss 1,
    accuracy 0).
    """
    # Datasets
    trainset = DatasetFromImageLabelList(trainfile, transform['train'],
                                         target_transform)
    validset = DatasetFromImageLabelList(validfile, transform['valid'],
                                         target_transform)

    # Dataloaders
    trainloader = DataLoader(trainset,
                             batch_size=batch_size,
                             shuffle=True,
                             num_workers=4)
    validloader = DataLoader(validset,
                             batch_size=batch_size * 2,
                             shuffle=False,
                             num_workers=4)

    writer = SummaryWriter(log_dir=logdir, flush_secs=5)

    model = Vgg16(num_classes=10, pretrained=True)
    writer.add_graph(model, torch.rand((1, 3, 224, 224), dtype=torch.float32))

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-2)
    # Stepped once per validation pass with the validation loss. Stepping it
    # per batch on the noisy batch loss cut the LR after a few unlucky batches.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    steps = 1  # global step counter across epochs
    loss_record = []
    accuracy_record = []
    best_loss = 1
    best_accuracy = 0

    start_epoch = 0
    if checkpoint_path:
        start_epoch = load_model_state_dict(checkpoint_path, model, optimizer)

    try:
        for epoch in range(start_epoch, epochs):
            model.train()
            # Running train metrics are per epoch.
            loss_record.clear()
            accuracy_record.clear()

            for inputs, targets in trainloader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                batch_loss = criterion(outputs, targets)

                optimizer.zero_grad()
                batch_loss.backward()
                optimizer.step()

                loss_record.append(batch_loss.item())
                accuracy_record.append(caculate_accuracy(outputs, targets))

                if steps % write_frequency == 0:
                    writer.add_scalar('train loss',
                                      np.array(loss_record).mean(),
                                      global_step=steps)
                    writer.add_scalar('train accuracy',
                                      np.array(accuracy_record).mean(),
                                      global_step=steps)
                steps += 1

            if epoch % valid_frequency != 0:
                continue

            # Validation metrics are computed from validation batches only,
            # not mixed with this epoch's training records.
            loss, accuracy = _evaluate_classifier(model, validloader,
                                                  criterion, device)
            writer.add_scalar('valid loss', loss, global_step=steps)
            writer.add_scalar('valid accuracy', accuracy, global_step=steps)
            print('valid:', epoch, steps, loss, accuracy)
            scheduler.step(loss)

            # Checkpoint condition: no worse on both metrics than the best so far.
            if loss <= best_loss and accuracy >= best_accuracy:
                best_loss, best_accuracy = loss, accuracy
                model.to('cpu')
                save_model_state_dict(model_root, model, optimizer, epoch,
                                      loss, accuracy, 'vgg16')
                model.to(device)
    finally:
        writer.close()  # flush pending TensorBoard events
示例#8
0
def _adversarial_targets(batch_size, patch, device):
    """Return ``(valid, fake)`` PatchGAN targets of shape ``(batch_size, *patch)``.

    ``valid`` is all ones and ``fake`` all zeros. With the MSE criterion used
    in training, this pushes discriminator scores for real patches towards 1
    and for generated patches towards 0.
    """
    shape = (batch_size, *patch)
    return (torch.ones(shape, device=device),
            torch.zeros(shape, device=device))


def _generator_losses(config, epoch, im1, m1, im2, m2, valid, resnet,
                      generator, discriminator, seg_model, vgg, criterion_GAN):
    """Synthesize images for one batch and compute the generator losses.

    Shared by the training step and the validation preview so both compute
    the same quantities.

    Returns:
        ``(fake_im, gan_loss, hair_loss, face_loss, loss)`` where
        ``loss = gan_loss + hair_loss + face_loss``. ``gan_loss`` is zero
        once ``epoch >= config.gan_epochs``.
    """
    z = resnet(im2 * m2)
    if epoch < config.gan_epochs:
        # GAN phase: the generator sees im1 with its m1 region blanked out.
        fake_im = generator(im1 * (1 - m1), im2 * m2, z)
        pred_fake = discriminator(fake_im, im2)
        gan_loss = config.lambda_gan * criterion_GAN(pred_fake, valid)
    else:
        fake_im = generator(im1, im2, z)
        gan_loss = torch.Tensor([0]).to(config.device)

    # Mask predicted on the synthesized image. argmax is not differentiable,
    # so the segmentation forward pass needs no autograd graph.
    with torch.no_grad():
        fake_m2 = torch.argmax(seg_model(fake_im),
                               1).unsqueeze(1).type(torch.uint8).repeat(
                                   1, 3, 1, 1).to(config.device)
    # Trust the predicted mask only if its area is within 50%-150% of m1's
    # area; otherwise fall back to m1.
    if 0.5 * torch.sum(m1) <= torch.sum(fake_m2) <= 1.5 * torch.sum(m1):
        hair_mask = fake_m2
    else:
        hair_mask = m1
    hair_loss = config.lambda_style * calc_style_loss(
        fake_im * hair_mask, im2 * m2, vgg) + calc_content_loss(
            fake_im * hair_mask, im2 * m2, vgg)
    face_loss = calc_content_loss(fake_im, im1, vgg)
    hair_loss *= config.lambda_hair
    face_loss *= config.lambda_face

    loss = gan_loss + hair_loss + face_loss
    return fake_im, gan_loss, hair_loss, face_loss, loss


def _validation_preview(config, epoch, batch_idx, val_loader, patch, resnet,
                        generator, discriminator, seg_model, vgg,
                        criterion_GAN):
    """Print the losses on the first validation batch and save a sample image.

    The networks are put in eval mode for this pass and back in train mode
    afterwards. Layers such as BatchNorm/Dropout (if present) therefore use
    inference behaviour, and validation data does not update running
    statistics.
    """
    nets = (resnet, generator, discriminator)
    for net in nets:
        net.eval()
    try:
        with torch.no_grad():
            for im1, m1, im2, m2 in val_loader:
                valid, _ = _adversarial_targets(im1.size(0), patch,
                                                config.device)
                fake_im, gan_loss, hair_loss, face_loss, loss = \
                    _generator_losses(config, epoch, im1, m1, im2, m2, valid,
                                      resnet, generator, discriminator,
                                      seg_model, vgg, criterion_GAN)
                msg = ("Validation || Gan loss: %.6f, hair loss: %.6f, face loss: %.6f, loss: %.6f\n" %
                       (gan_loss.item(), hair_loss.item(), face_loss.item(), loss.item()))
                sys.stdout.write(msg)
                fname = os.path.join(
                    config.save_path,
                    "Validation_Epoch:%d_Batch:%d.png" % (epoch, batch_idx))
                sample_images([im1[0], im2[0], fake_im[0]],
                              ["img1", "img2", "img1+img2"], fname)
                break  # only the first validation batch is previewed
    finally:
        for net in nets:
            net.train()


def main(config):
    """Train the ResNet-18 encoder, U-Net generator and PatchGAN discriminator.

    Trains epochs ``config.epoch`` .. ``config.n_epochs - 1``, resuming from
    the ``epoch - 1`` checkpoints when ``config.epoch != 0``. The GAN loss
    and discriminator updates are used only for epochs ``< config.gan_epochs``.
    Every ``config.sample_interval`` batches, losses are printed and training
    and validation sample images are saved to ``config.save_path``. Every
    ``config.checkpoint_interval`` epochs, weights are saved to
    ``config.checkpoints`` (the discriminator only during the GAN phase).
    """
    model = load_model(config)
    train_loader, val_loader = get_loaders(model, config)

    # exist_ok=True already covers the "directory exists" case.
    os.makedirs(config.checkpoints, exist_ok=True)
    os.makedirs(config.save_path, exist_ok=True)

    criterion_GAN = mse_loss

    # Output shape of the PatchGAN discriminator (input downsampled by 2**4).
    patch = (1, config.image_size // 2**4, config.image_size // 2**4)

    vgg = Vgg16().to(config.device)
    resnet = ResNet18(requires_grad=True, pretrained=True).to(config.device)
    generator = GeneratorUNet().to(config.device)
    discriminator = Discriminator().to(config.device)

    def checkpoint_file(epoch, name):
        return os.path.join(config.checkpoints,
                            'epoch_%d_%s.pth' % (epoch, name))

    if config.epoch != 0:
        prev_epoch = config.epoch - 1

        def restore(net, name):
            # map_location lets GPU-saved checkpoints load on another device.
            net.load_state_dict(torch.load(checkpoint_file(prev_epoch, name),
                                           map_location=config.device))

        restore(resnet, 'resnet')
        restore(generator, 'generator')
        # Discriminator checkpoints are only written during the GAN phase (see
        # the saving code at the end of each epoch). After that phase the
        # discriminator is unused, so only restore it if one was written.
        if prev_epoch < config.gan_epochs:
            restore(discriminator, 'discriminator')
    else:
        # resnet keeps its pretrained weights.
        generator.apply(weights_init_normal)
        discriminator.apply(weights_init_normal)

    optimizer_resnet = torch.optim.Adam(resnet.parameters(),
                                        lr=config.lr,
                                        betas=(config.b1, config.b2))
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=config.lr,
                                   betas=(config.b1, config.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=config.lr,
                                   betas=(config.b1, config.b2))

    # ----------
    #  Training
    # ----------
    resnet.train()
    generator.train()
    discriminator.train()
    for epoch in range(config.epoch, config.n_epochs):
        in_gan_phase = epoch < config.gan_epochs
        for i, (im1, m1, im2, m2) in enumerate(train_loader):
            assert im1.size(0) == im2.size(0)
            valid, fake = _adversarial_targets(im1.size(0), patch,
                                               config.device)

            # ------------------
            #  Train Generators
            # ------------------
            optimizer_resnet.zero_grad()
            optimizer_G.zero_grad()
            fake_im, gan_loss, hair_loss, face_loss, loss = _generator_losses(
                config, epoch, im1, m1, im2, m2, valid, resnet, generator,
                discriminator, model, vgg, criterion_GAN)
            loss.backward()
            optimizer_resnet.step()
            optimizer_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------
            if in_gan_phase:
                # zero_grad also discards gradients the generator step left
                # in the discriminator's parameters.
                optimizer_D.zero_grad()
                # "Real" sample: im1 outside m1 combined with im2 inside m2.
                pred_real = discriminator(im1 * (1 - m1) + im2 * m2, im2)
                loss_real = criterion_GAN(pred_real, valid)
                pred_fake = discriminator(fake_im.detach(), im2)
                loss_fake = criterion_GAN(pred_fake, fake)
                loss_D = 0.5 * (loss_real + loss_fake)
                loss_D.backward()
                optimizer_D.step()

            if i % config.sample_interval == 0:
                msg = ("Train || Gan loss: %.6f, hair loss: %.6f, face loss: %.6f, loss: %.6f\n" %
                       (gan_loss.item(), hair_loss.item(), face_loss.item(), loss.item()))
                sys.stdout.write("Epoch: %d || Batch: %d\n" % (epoch, i))
                sys.stdout.write(msg)
                fname = os.path.join(
                    config.save_path,
                    "Train_Epoch:%d_Batch:%d.png" % (epoch, i))
                sample_images([im1[0], im2[0], fake_im[0]],
                              ["img1", "img2", "img1+img2"], fname)
                _validation_preview(config, epoch, i, val_loader, patch,
                                    resnet, generator, discriminator, model,
                                    vgg, criterion_GAN)

        if epoch % config.checkpoint_interval == 0:
            if in_gan_phase:
                models = [resnet, generator, discriminator]
                names = ['resnet', 'generator', 'discriminator']
            else:
                models = [resnet, generator]
                names = ['resnet', 'generator']
            save_weights(models, [checkpoint_file(epoch, s) for s in names])
示例#9
0
def load_vgg16(model_path, map_location=None):
    """Build a ``Vgg16`` and load the state dict stored at *model_path*.

    Args:
        model_path: File produced by ``torch.save(<Vgg16>.state_dict(), ...)``.
        map_location: Forwarded to ``torch.load``. For example, ``'cpu'``
            loads GPU-saved weights on a CPU-only machine. The default
            ``None`` keeps ``torch.load``'s default device handling.

    Returns:
        The ``Vgg16`` instance with the weights loaded.
    """
    vgg = Vgg16()
    state_dict = torch.load(model_path, map_location=map_location)
    vgg.load_state_dict(state_dict)
    return vgg