Example #1
0
def main(args):
    """Train a Glow model on the CelebA sample split, evaluating after every epoch.

    Args:
        args: Parsed command-line namespace. Fields read here: ``gpu_ids``,
            ``batch_size``, ``seed``, ``mode``, ``num_workers``,
            ``num_channels``, ``num_levels``, ``num_steps``, ``benchmark``,
            ``resume``, ``lr``, ``warm_up``, ``num_epochs`` and
            ``max_grad_norm``.

    Side effects:
        Multiplies ``args.batch_size`` in place by the number of GPUs. When
        ``args.resume`` is set, overwrites the module-level globals
        ``best_loss`` and ``global_step`` from ``ckpts/best.pth.tar``.

    Raises:
        FileNotFoundError: if ``args.resume`` is set and there is no
            ``ckpts`` directory.
    """
    # Declared up front: both globals are reassigned when resuming.
    global best_loss
    global global_step

    # Set up main device and scale batch size
    device = 'cuda' if torch.cuda.is_available() and args.gpu_ids else 'cpu'
    # DataParallel splits each batch across gpu_ids, so grow the global batch
    # with the number of GPUs.
    args.batch_size *= max(1, len(args.gpu_ids))

    # Set random seeds
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    trainset = ImgDatasets(root_dir='data/celeba_sample',
                           files='train_files.txt',
                           mode=args.mode)
    trainloader = data.DataLoader(trainset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers)

    testset = ImgDatasets(root_dir='data/celeba_sample',
                          files='test_files.txt',
                          mode=args.mode)
    testloader = data.DataLoader(testset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.num_workers)

    # Model
    print('Building model..')
    net = Glow(num_channels=args.num_channels,
               num_levels=args.num_levels,
               num_steps=args.num_steps,
               mode=args.mode)
    net = net.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net, args.gpu_ids)
        cudnn.benchmark = args.benchmark

    start_epoch = 0
    if args.resume:
        # Load checkpoint.
        print('Resuming from checkpoint at ckpts/best.pth.tar...')
        # Explicit check instead of `assert`, which is stripped under `-O`.
        if not os.path.isdir('ckpts'):
            raise FileNotFoundError('Error: no checkpoint directory found!')
        # map_location lets a checkpoint written on a GPU be restored on a
        # CPU-only host.
        checkpoint = torch.load('ckpts/best.pth.tar', map_location=device)
        # The checkpoint may come from a DataParallel-wrapped model (keys
        # prefixed with 'module.') or a bare one. Strip the prefix and load
        # into the underlying module so either layout works on CPU or GPU.
        prefix = 'module.'
        state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v
                      for k, v in checkpoint['net'].items()}
        bare_net = net.module if isinstance(net, torch.nn.DataParallel) else net
        bare_net.load_state_dict(state_dict)
        best_loss = checkpoint['test_loss']
        start_epoch = checkpoint['epoch']
        # Measured in samples: completed epochs x dataset size.
        global_step = start_epoch * len(trainset)

    loss_fn = util.NLLLoss().to(device)
    optimizer = optim.Adam(net.parameters(), lr=args.lr)
    # Linear LR warm-up over the first `args.warm_up` scheduler steps.
    # A non-positive warm_up disables warm-up instead of dividing by zero
    # (or producing a negative learning-rate factor).
    scheduler = sched.LambdaLR(
        optimizer,
        lambda s: min(1., s / args.warm_up) if args.warm_up > 0 else 1.)

    for epoch in range(start_epoch, start_epoch + args.num_epochs):
        train(epoch, net, trainloader, device, optimizer, scheduler, loss_fn,
              args.max_grad_norm)
        test(epoch, net, testloader, device, loss_fn, args.mode)
Example #2
0
def main(args):
    """Train a Glow model on CIFAR-10, evaluating after every epoch.

    Args:
        args: Parsed command-line namespace. Fields read here: ``gpu_ids``,
            ``batch_size``, ``seed``, ``num_workers``, ``num_channels``,
            ``num_levels``, ``num_steps``, ``benchmark``, ``resume``, ``lr``,
            ``warm_up``, ``num_epochs``, ``max_grad_norm`` and
            ``num_samples``.

    Side effects:
        Multiplies ``args.batch_size`` in place by the number of GPUs,
        downloads CIFAR-10 into ``data/`` if missing, and, when
        ``args.resume`` is set, overwrites the module-level globals
        ``best_loss`` and ``global_step`` from ``ckpts/best.pth.tar``.

    Raises:
        FileNotFoundError: if ``args.resume`` is set and there is no
            ``ckpts`` directory.
    """
    # Declared up front: both globals are reassigned when resuming.
    global best_loss
    global global_step

    # Set up main device and scale batch size
    device = 'cuda' if torch.cuda.is_available() and args.gpu_ids else 'cpu'
    # DataParallel splits each batch across gpu_ids, so grow the global batch
    # with the number of GPUs.
    args.batch_size *= max(1, len(args.gpu_ids))

    # Set random seeds
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # No normalization applied, since Glow expects inputs in (0, 1)
    transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor()
    ])

    trainset = torchvision.datasets.CIFAR10(root='data', train=True, download=True, transform=transform_train)
    trainloader = data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)

    testset = torchvision.datasets.CIFAR10(root='data', train=False, download=True, transform=transform_test)
    testloader = data.DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

    # Model
    print('Building model..')
    net = Glow(num_channels=args.num_channels,
               num_levels=args.num_levels,
               num_steps=args.num_steps)
    net = net.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net, args.gpu_ids)
        cudnn.benchmark = args.benchmark

    start_epoch = 0
    if args.resume:
        # Load checkpoint.
        print('Resuming from checkpoint at ckpts/best.pth.tar...')
        # Explicit check instead of `assert`, which is stripped under `-O`.
        if not os.path.isdir('ckpts'):
            raise FileNotFoundError('Error: no checkpoint directory found!')
        # map_location lets a checkpoint written on a GPU be restored on a
        # CPU-only host.
        checkpoint = torch.load('ckpts/best.pth.tar', map_location=device)
        # The checkpoint may come from a DataParallel-wrapped model (keys
        # prefixed with 'module.') or a bare one. Strip the prefix and load
        # into the underlying module so either layout works on CPU or GPU.
        prefix = 'module.'
        state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v
                      for k, v in checkpoint['net'].items()}
        bare_net = net.module if isinstance(net, torch.nn.DataParallel) else net
        bare_net.load_state_dict(state_dict)
        best_loss = checkpoint['test_loss']
        start_epoch = checkpoint['epoch']
        # Measured in samples: completed epochs x dataset size.
        global_step = start_epoch * len(trainset)

    loss_fn = util.NLLLoss().to(device)
    optimizer = optim.Adam(net.parameters(), lr=args.lr)
    # Linear LR warm-up over the first `args.warm_up` scheduler steps.
    # A non-positive warm_up disables warm-up instead of dividing by zero
    # (or producing a negative learning-rate factor).
    scheduler = sched.LambdaLR(
        optimizer,
        lambda s: min(1., s / args.warm_up) if args.warm_up > 0 else 1.)

    for epoch in range(start_epoch, start_epoch + args.num_epochs):
        train(epoch, net, trainloader, device, optimizer, scheduler,
              loss_fn, args.max_grad_norm)
        test(epoch, net, testloader, device, loss_fn, args.num_samples)
Example #3
0
def main(args):
    """Run conditional style transfer with a trained Glow checkpoint and save the images.

    Args:
        args: Parsed command-line namespace. Fields read here:
            ``ckpt_path`` (checkpoint holding ``"args"`` and ``"net"``),
            ``cond_data`` (file holding ``"original"`` and ``"cond_img"``),
            ``index`` (forwarded to ``style_transfer`` as ``target_index``)
            and ``output_dir`` (directory that receives the PNG files).

    Writes ``original.png``, ``synthesized.png`` and ``cond_img.png`` into
    ``args.output_dir``, creating that directory if it does not exist.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # map_location lets a checkpoint written on a GPU be restored on a
    # CPU-only host.
    ckpt = torch.load(args.ckpt_path, map_location=device)
    # Rebuild the network from the hyper-parameters stored alongside the
    # weights so the state dict matches the architecture.
    ckpt_args = ckpt["args"]
    net = Glow(num_channels=ckpt_args.num_channels,
               num_levels=ckpt_args.num_levels,
               num_steps=ckpt_args.num_steps,
               img_size=ckpt_args.img_size,
               dec_size=ckpt_args.dec_size).to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net, ckpt_args.gpu_ids)
    # Accept checkpoints saved with or without a DataParallel wrapper by
    # stripping the 'module.' prefix and loading into the bare model.
    prefix = 'module.'
    state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v
                  for k, v in ckpt['net'].items()}
    bare_net = net.module if isinstance(net, torch.nn.DataParallel) else net
    bare_net.load_state_dict(state_dict)
    # Inference only: put every submodule in evaluation mode.
    net.eval()

    cond_data = torch.load(args.cond_data)
    # Only the conditioning image is moved to `device`; `original` is passed
    # through as loaded (NOTE(review): confirm style_transfer handles it).
    original = cond_data["original"]
    cond_img = cond_data["cond_img"].to(device)

    # style transfer
    synth_img, target = style_transfer(net,
                                       original,
                                       cond_img,
                                       target_index=args.index)

    # Save results. os.path.join keeps the paths correct whether or not
    # output_dir ends in a separator; an empty output_dir means the CWD.
    output_dir = args.output_dir
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    # Tile each batch into a grid with 4 images per row.
    origin_concat = torchvision.utils.make_grid(original,
                                                nrow=4,
                                                padding=2,
                                                pad_value=255)
    img_concat = torchvision.utils.make_grid(synth_img,
                                             nrow=4,
                                             padding=2,
                                             pad_value=255)
    torchvision.utils.save_image(origin_concat,
                                 os.path.join(output_dir, 'original.png'))
    torchvision.utils.save_image(img_concat,
                                 os.path.join(output_dir, 'synthesized.png'))
    torchvision.utils.save_image(target,
                                 os.path.join(output_dir, 'cond_img.png'))
Example #4
0
        #print(x.size())
        return x, 0


# Resize every image to 32x32 and convert it to a tensor.
# transforms.Resize replaces the deprecated transforms.Scale alias, which has
# been removed from recent torchvision releases.
transform = transforms.Compose(
    [transforms.Resize((32, 32)),
     transforms.ToTensor()])

for i in range(3):
    net = Glow(num_channels=512, num_levels=3, num_steps=16)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = 'cpu'
    #net.to(device)
    if i == 0:
        net.load_state_dict({
            k.replace('module.', ''): v
            for k, v in torch.load("ckpts/-2.pth.tar")['net'].items()
        })
    if i == 1:
        net.load_state_dict({
            k.replace('module.', ''): v
            for k, v in torch.load("ckpts/-1.pth.tar")['net'].items()
        })
    net.eval()
    #testset = dataset(-2, transform, test=True,rotation_data=True)
    testset = torchvision.datasets.CIFAR10(root='dataset/cifar10-torchvision',
                                           train=False,
                                           download=True,
                                           transform=transform)
    #testset = imagenet_val(transform)
    testloader = data.DataLoader(testset,
                                 batch_size=64,
Example #5
0
                                  shuffle=False,
                                  num_workers=8)

    testset = dataset(num_class % 10, transform, test=True)
    testloader = data.DataLoader(testset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=8)

    net = Glow(num_channels=512, num_levels=3, num_steps=16)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = 'cpu'
    net.to(device)
    net.load_state_dict({
        k.replace('module.', ''): v
        for k, v in torch.load("ckpts/" + str(num_class) +
                               ".pth.tar")['net'].items()
    })
    net.eval()
    q = np.load('rotation.npy')

    n = 0
    for i, (image, label) in enumerate(trainloader):
        print('num_class:{},train:{},'.format(num_class, n))

        z, _ = net(image, reverse=False)
        z = z.view(-1, 3 * 32 * 32)
        z = z.detach().numpy()
        for i in range(z.shape[0]):
            z[i] = q.dot(z[i])
        z = torch.from_numpy(z).view(-1, 3, 32, 32)