Beispiel #1
0
def decoder_test_each(codes,
                      output_dir,
                      model,
                      rnn_type,
                      use_cuda,
                      iterations,
                      network,
                      code_size=32):
    """Decode a stack of codes and return the reconstruction after each step.

    Args:
        codes: 5-D tensor whose size unpacks as
            (steps, batch, channels, height, width).
        output_dir: Not used by this function; kept so existing callers
            keep working.
        model: Path of the decoder state dict, passed to ``torch.load``.
        rnn_type: Forwarded to the ``CompressionDecoder`` constructor.
        use_cuda: When true, the decoder, the codes and the hidden states
            are moved to the GPU.
        iterations: Upper bound on the number of code slices decoded; at
            most ``codes.size(0)`` slices are used.
        network: ``'Big'`` selects ``compression_net``; any other value
            selects ``compression_net_smaller``.
        code_size: Forwarded to the ``'Big'`` decoder only.

    Returns:
        A list with one numpy array per decoded step, each holding 0.5 plus
        the sum of all decoder outputs produced so far.
    """
    _, batch_size, _, height, width = codes.size()

    # The reconstructed image is 16x larger than the code grid per side.
    height = height * 16
    width = width * 16

    codes = Variable(codes, volatile=True)

    if network == 'Big':
        decoder = compression_net.CompressionDecoder(rnn_type, code_size)
    else:
        decoder = compression_net_smaller.CompressionDecoder(rnn_type)
    decoder.eval()

    # Without map_location a checkpoint saved from GPU tensors cannot be
    # restored on a CPU-only host, so remap storages when CUDA is not used.
    if use_cuda:
        state_dict = torch.load(model)
    else:
        state_dict = torch.load(model,
                                map_location=lambda storage, loc: storage)
    decoder.load_state_dict(state_dict)

    def zero_state(channels, stride):
        """Return a pair of zero tensors for one recurrent layer."""
        shape = (batch_size, channels, height // stride, width // stride)
        pair = (Variable(torch.zeros(*shape), volatile=True),
                Variable(torch.zeros(*shape), volatile=True))
        if use_cuda:
            pair = (pair[0].cuda(), pair[1].cuda())
        return pair

    decoder_h_1 = zero_state(512, 16)
    decoder_h_2 = zero_state(512, 8)
    decoder_h_3 = zero_state(256, 4)
    decoder_h_4 = zero_state(128, 2)

    if use_cuda:
        decoder = decoder.cuda()
        codes = codes.cuda()

    # Start from mid-grey and add the decoder output of every step.
    image = torch.zeros(1, 3, height, width) + 0.5
    images = []
    for step in range(min(iterations, codes.size(0))):
        output, decoder_h_1, decoder_h_2, decoder_h_3, decoder_h_4 = decoder(
            codes[step], decoder_h_1, decoder_h_2, decoder_h_3, decoder_h_4)
        # ``image`` is rebound (not modified in place), so arrays appended
        # earlier are not affected by later steps.
        image = image + output.data.cpu()
        images.append(image.numpy())
    return images
def encoder_test_each(image,
                      model,
                      iterations,
                      rnn_type,
                      use_cuda=True,
                      network='Big',
                      code_size=32):
    """Run the iterative encoder on an image tensor and return its codes.

    Args:
        image: 4-D tensor whose shape unpacks as
            (batch, channels, height, width); height and width must be
            multiples of 32.
        model: Path of the encoder state dict.  The binarizer and decoder
            paths are derived by replacing every ``'encoder'`` substring of
            this path with ``'binarizer'`` / ``'decoder'``.
        iterations: Number of encode/binarize/decode passes to run.
        rnn_type: Forwarded to the encoder and decoder constructors.
        use_cuda: When true, modules, image and hidden states are moved to
            the GPU.
        network: ``'Big'`` selects ``compression_net``; any other value
            selects ``compression_net_smaller``.
        code_size: Forwarded to the ``'Big'`` binarizer and decoder only.

    Returns:
        ``int8`` numpy array stacking the binarizer output of every pass
        along a new leading axis, transformed by ``(value + 1) // 2``.

    Raises:
        AssertionError: If height or width is not a multiple of 32.
    """
    batch_size, _, height, width = image.shape
    assert height % 32 == 0 and width % 32 == 0

    image = Variable(image, volatile=True)

    if network == 'Big':
        encoder = compression_net.CompressionEncoder(rnn_type=rnn_type)
        binarizer = compression_net.CompressionBinarizer(code_size)
        decoder = compression_net.CompressionDecoder(rnn_type=rnn_type,
                                                     code_size=code_size)
    else:
        encoder = compression_net_smaller.CompressionEncoder(rnn_type=rnn_type)
        binarizer = compression_net_smaller.CompressionBinarizer()
        decoder = compression_net_smaller.CompressionDecoder(rnn_type=rnn_type)

    encoder.eval()
    binarizer.eval()
    decoder.eval()

    def load_weights(module, path):
        """Load ``path`` into ``module``, remapping to CPU when needed."""
        # Without map_location a checkpoint saved from GPU tensors cannot
        # be restored on a CPU-only host.
        if use_cuda:
            state_dict = torch.load(path)
        else:
            state_dict = torch.load(
                path, map_location=lambda storage, loc: storage)
        module.load_state_dict(state_dict)

    load_weights(encoder, model)
    load_weights(binarizer, model.replace('encoder', 'binarizer'))
    load_weights(decoder, model.replace('encoder', 'decoder'))

    def zero_state(channels, stride):
        """Return a pair of zero tensors for one recurrent layer."""
        shape = (batch_size, channels, height // stride, width // stride)
        pair = (Variable(torch.zeros(*shape), volatile=True),
                Variable(torch.zeros(*shape), volatile=True))
        if use_cuda:
            pair = (pair[0].cuda(), pair[1].cuda())
        return pair

    encoder_h_1 = zero_state(256, 4)
    encoder_h_2 = zero_state(512, 8)
    encoder_h_3 = zero_state(512, 16)

    decoder_h_1 = zero_state(512, 16)
    decoder_h_2 = zero_state(512, 8)
    decoder_h_3 = zero_state(256, 4)
    decoder_h_4 = zero_state(128, 2)

    if use_cuda:
        encoder = encoder.cuda()
        binarizer = binarizer.cuda()
        decoder = decoder.cuda()
        image = image.cuda()

    codes = []
    # The residual starts as the image shifted by -0.5 and shrinks by the
    # decoder output after every pass.
    res = image - 0.5

    for _ in range(iterations):
        encoded, encoder_h_1, encoder_h_2, encoder_h_3 = encoder(
            res, encoder_h_1, encoder_h_2, encoder_h_3)

        code = binarizer(encoded)

        output, decoder_h_1, decoder_h_2, decoder_h_3, decoder_h_4 = decoder(
            code, decoder_h_1, decoder_h_2, decoder_h_3, decoder_h_4)

        res = res - output
        codes.append(code.data.cpu().numpy())

    # (value + 1) // 2 maps -1 -> 0 and +1 -> 1.
    # NOTE(review): assumes the binarizer emits only -1/+1 -- confirm.
    return (np.stack(codes).astype(np.int8) + 1) // 2
Beispiel #3
0
def encoder_test(input,
                 output_path,
                 model,
                 iterations,
                 rnn_type,
                 use_cuda=True,
                 network='Big',
                 code_size=32):
    """Encode an image file and write its packed codes to a ``.npz`` file.

    The image is zero-padded at the bottom/right to a multiple of 32 in
    both dimensions, run through ``iterations`` encode/binarize/decode
    passes, and the resulting codes are bit-packed and saved with
    ``np.savez_compressed`` under the keys ``shape``, ``codes`` and
    ``delta`` (the number of padded rows and columns).

    Args:
        input: Path of the image, read with ``imread(..., mode='RGB')``.
        output_path: Destination passed to ``np.savez_compressed``.
        model: Path of the encoder state dict.  The binarizer and decoder
            paths are derived by replacing every ``'encoder'`` substring of
            this path with ``'binarizer'`` / ``'decoder'``.
        iterations: Number of encode/binarize/decode passes to run.
        rnn_type: Forwarded to the encoder and decoder constructors.
        use_cuda: When true, modules, image and hidden states are moved to
            the GPU; otherwise checkpoints are remapped to CPU storage.
        network: ``'Big'`` selects ``compression_net`` (three encoder and
            four decoder hidden states); any other value selects
            ``compression_net_smaller`` (one extra state on each side).
        code_size: Forwarded to the ``'Big'`` binarizer and decoder only.
    """
    raw_image = imread(input, mode='RGB')
    h, w, c = raw_image.shape
    # Always pads by 1..32 pixels per axis (an already aligned size gains a
    # full 32), so the stored delta is never zero.
    new_h = (h // 32 + 1) * 32
    new_w = (w // 32 + 1) * 32
    padded = np.zeros((new_h, new_w, c), dtype=np.float32)
    padded[:h, :w, :] = raw_image
    # HWC in [0, 255] -> 1xCxHxW in [0, 1].
    image = torch.from_numpy(
        np.expand_dims(np.transpose(padded / 255.0, (2, 0, 1)), 0))
    batch_size, _, height, width = image.size()
    assert height % 32 == 0 and width % 32 == 0

    image = Variable(image, volatile=True)
    is_big = network == 'Big'
    if is_big:
        encoder = compression_net.CompressionEncoder(rnn_type=rnn_type)
        binarizer = compression_net.CompressionBinarizer(code_size=code_size)
        decoder = compression_net.CompressionDecoder(rnn_type=rnn_type,
                                                     code_size=code_size)
    else:
        encoder = compression_net_smaller.CompressionEncoder(rnn_type=rnn_type)
        binarizer = compression_net_smaller.CompressionBinarizer()
        decoder = compression_net_smaller.CompressionDecoder(rnn_type=rnn_type)

    encoder.eval()
    binarizer.eval()
    decoder.eval()

    def load_weights(module, path):
        """Load ``path`` into ``module``, remapping to CPU when needed."""
        if use_cuda:
            state_dict = torch.load(path)
        else:
            state_dict = torch.load(
                path, map_location=lambda storage, loc: storage)
        module.load_state_dict(state_dict)

    load_weights(encoder, model)
    load_weights(binarizer, model.replace('encoder', 'binarizer'))
    load_weights(decoder, model.replace('encoder', 'decoder'))

    def zero_state(channels, stride):
        """Return a pair of zero tensors for one recurrent layer."""
        shape = (batch_size, channels, height // stride, width // stride)
        pair = (Variable(torch.zeros(*shape), volatile=True),
                Variable(torch.zeros(*shape), volatile=True))
        # Every state goes through this helper, so encoder_h_4 and
        # decoder_h_0 are moved to the GPU together with the others; leaving
        # them on the CPU would mix devices on the non-'Big' CUDA path.
        if use_cuda:
            pair = (pair[0].cuda(), pair[1].cuda())
        return pair

    encoder_h_1 = zero_state(256, 4)
    encoder_h_2 = zero_state(512, 8)
    encoder_h_3 = zero_state(512, 16)
    encoder_h_4 = zero_state(1024, 32)

    decoder_h_0 = zero_state(512, 32)
    decoder_h_1 = zero_state(512, 16)
    decoder_h_2 = zero_state(512, 8)
    decoder_h_3 = zero_state(256, 4)
    decoder_h_4 = zero_state(128, 2)

    if use_cuda:
        encoder = encoder.cuda()
        binarizer = binarizer.cuda()
        decoder = decoder.cuda()
        image = image.cuda()

    codes = []
    # The residual starts as the image shifted by -0.5 and shrinks by the
    # decoder output after every pass.
    res = image - 0.5

    for _ in range(iterations):
        if is_big:
            encoded, encoder_h_1, encoder_h_2, encoder_h_3 = encoder(
                res, encoder_h_1, encoder_h_2, encoder_h_3)
        else:
            (encoded, encoder_h_1, encoder_h_2, encoder_h_3,
             encoder_h_4) = encoder(res, encoder_h_1, encoder_h_2,
                                    encoder_h_3, encoder_h_4)

        code = binarizer(encoded)

        if is_big:
            (output, decoder_h_1, decoder_h_2, decoder_h_3,
             decoder_h_4) = decoder(code, decoder_h_1, decoder_h_2,
                                    decoder_h_3, decoder_h_4)
        else:
            (output, decoder_h_0, decoder_h_1, decoder_h_2, decoder_h_3,
             decoder_h_4) = decoder(code, decoder_h_0, decoder_h_1,
                                    decoder_h_2, decoder_h_3, decoder_h_4)

        res = res - output
        codes.append(code.data.cpu().numpy())

    # (value + 1) // 2 maps -1 -> 0 and +1 -> 1 so the codes can be packed
    # eight per byte.
    # NOTE(review): assumes the binarizer emits only -1/+1 -- confirm.
    codes = (np.stack(codes).astype(np.int8) + 1) // 2
    export = np.packbits(codes.reshape(-1))
    # Padding per axis is at most 32, so it fits in uint8.
    delta = np.array([new_h - h, new_w - w], dtype=np.uint8)
    np.savez_compressed(output_path,
                        shape=codes.shape,
                        codes=export,
                        delta=delta)
Beispiel #4
0
def decoder_test(input,
                 output_path,
                 model,
                 iterations,
                 rnn_type,
                 use_cuda=True,
                 code_size=32):
    """Decode a packed ``.npz`` code file and save the reconstructed image.

    Args:
        input: Path of an archive readable by ``np.load`` holding the
            bit-packed ``codes`` and their unpacked ``shape``.
        output_path: Destination handed to ``imsave``.
        model: Path of the decoder state dict.
        iterations: Upper bound on the number of code slices decoded; at
            most the number of slices stored in the archive is used.
        rnn_type: Forwarded to the ``CompressionDecoder`` constructor.
        use_cuda: When true, decoding runs on the GPU; otherwise the
            checkpoint is remapped to CPU storage.
        code_size: Forwarded to the ``CompressionDecoder`` constructor.
    """
    archive = np.load(input)
    bits = np.unpackbits(archive['codes'])
    # Stored bits 0/1 become -1/+1 floats.
    signed = np.reshape(bits, archive['shape']).astype(np.float32) * 2 - 1

    codes = torch.from_numpy(signed)
    _, batch_size, _, code_h, code_w = codes.size()
    # The output image is 16x larger than the code grid per side.
    height = code_h * 16
    width = code_w * 16

    codes = Variable(codes, volatile=True)

    decoder = compression_net.CompressionDecoder(rnn_type, code_size)
    decoder.eval()

    if use_cuda:
        weights = torch.load(model)
    else:
        weights = torch.load(model,
                             map_location=lambda storage, loc: storage)
    decoder.load_state_dict(weights)

    def blank_state(channels, stride):
        """Return a CPU pair of zero tensors for one recurrent layer."""
        dims = (batch_size, channels, height // stride, width // stride)
        return (Variable(torch.zeros(*dims), volatile=True),
                Variable(torch.zeros(*dims), volatile=True))

    states = [
        blank_state(512, 16),
        blank_state(512, 8),
        blank_state(256, 4),
        blank_state(128, 2),
    ]

    if use_cuda:
        decoder = decoder.cuda()
        codes = codes.cuda()
        states = [(first.cuda(), second.cuda()) for first, second in states]

    # Start from mid-grey and add the decoder output of every step.
    image = torch.zeros(1, 3, height, width) + 0.5
    steps = min(iterations, codes.size(0))
    for step in range(steps):
        output, h_1, h_2, h_3, h_4 = decoder(codes[step], *states)
        states = [h_1, h_2, h_3, h_4]
        image = image + output.data.cpu()

    # [0, 1] floats -> uint8 pixels, dropping size-1 axes and going CHW -> HWC.
    scaled = image.numpy().clip(0, 1) * 255.0
    pixels = np.squeeze(scaled).astype(np.uint8).transpose(1, 2, 0)
    imsave(output_path, pixels)
Beispiel #5
0
def main():
    """Train the encoder/binarizer/decoder and save periodic checkpoints.

    Parses the command line with the module-level ``parser`` and publishes
    the result as the global ``args``.  When ``args.resume`` is set and a
    ``checkpoint`` directory exists, the three state dicts of that epoch
    are loaded and training continues from the following epoch.
    Checkpoints are written to
    ``checkpoint/<rnn_type>_<loss_type>_<code_size>/`` every
    ``args.save_freq`` epochs and after the last epoch of this run.
    """
    global args
    args = parser.parse_args()

    train_transform = transforms.Compose([
        transforms.RandomCrop((32, 32)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()
    ])
    train_dataset = dataset.ImageFolder(roots=args.train_paths,
                                        transform=train_transform)
    train_loader = data.DataLoader(dataset=train_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers)

    encoder = compression_net.CompressionEncoder(rnn_type=args.rnn_type).cuda()
    binarizer = compression_net.CompressionBinarizer(
        code_size=args.code_size).cuda()
    decoder = compression_net.CompressionDecoder(
        rnn_type=args.rnn_type, code_size=args.code_size).cuda()
    # File-name prefix -> module, in the order used for loading and saving.
    modules = (('encoder', encoder), ('binarizer', binarizer),
               ('decoder', decoder))

    optimizer = optim.Adam(
        [{'params': module.parameters()} for _, module in modules],
        lr=args.lr)

    run_dir = 'checkpoint/{}_{}_{}'.format(args.rnn_type, args.loss_type,
                                           args.code_size)

    def checkpoint_path(prefix, epoch):
        """Return the state-dict path of ``prefix`` for ``epoch``."""
        return '{}/{}_{:08d}.pth'.format(run_dir, prefix, epoch)

    start_epoch = args.start_epoch
    if args.resume:
        if os.path.isdir('checkpoint'):
            print("=> loading checkpoint '{}'".format(args.resume))
            for prefix, module in modules:
                module.load_state_dict(
                    torch.load(checkpoint_path(prefix, args.resume)))
            print("=> loaded checkpoint '{}'".format(args.resume))
            start_epoch = args.resume + 1
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    torch.manual_seed(23)
    scheduler = LS.MultiStepLR(optimizer,
                               milestones=[3, 10, 20, 50, 100],
                               gamma=0.5)

    # The loop is offset by start_epoch, so the final epoch of this run is
    # start_epoch + epochs - 1; comparing against epochs - 1 alone would
    # miss the final save after a resume.
    last_epoch = start_epoch + args.epochs - 1
    for epoch in range(start_epoch, start_epoch + args.epochs):
        # NOTE(review): newer PyTorch releases expect scheduler.step() after
        # the optimizer has stepped; the original ordering is kept here.
        scheduler.step()
        train(train_loader, encoder, binarizer, decoder, epoch, optimizer)
        if epoch % args.save_freq == 0 or epoch == last_epoch:
            if not os.path.exists(run_dir):
                os.makedirs(run_dir)
            for prefix, module in modules:
                torch.save(module.state_dict(),
                           checkpoint_path(prefix, epoch))
def decoder_test(input,
                 output_dir,
                 model,
                 iterations,
                 rnn_type,
                 use_cuda=True,
                 network='Big',
                 code_size=32):
    """Decode a packed ``.npz`` code file, saving one PNG per decoding step.

    After every step the running reconstruction is cropped by the stored
    ``delta`` (rows, columns removed from the bottom/right) and written to
    ``<output_dir>/<step as two digits>.png``.

    Args:
        input: Path of an archive readable by ``np.load`` holding the
            bit-packed ``codes``, their unpacked ``shape`` and ``delta``.
        output_dir: Directory receiving the per-step images.
        model: Path of the decoder state dict.
        iterations: Upper bound on the number of code slices decoded; at
            most the number of slices stored in the archive is used.
        rnn_type: Forwarded to the ``CompressionDecoder`` constructor.
        use_cuda: When true, decoding runs on the GPU; otherwise the
            checkpoint is remapped to CPU storage.
        network: ``'Big'`` selects ``compression_net`` (four hidden states);
            any other value selects ``compression_net_smaller`` (five).
        code_size: Forwarded to the ``'Big'`` decoder only.

    Returns:
        The number of decoding steps performed.
    """
    content = np.load(input)
    codes = np.unpackbits(content['codes'])
    # Stored bits 0/1 become -1/+1 floats.
    codes = np.reshape(codes, content['shape']).astype(np.float32) * 2 - 1
    delta = content['delta']
    delta_h = int(delta[0])
    delta_w = int(delta[1])

    codes = torch.from_numpy(codes)
    _, batch_size, _, height, width = codes.size()
    # The output image is 16x larger than the code grid per side.
    height = height * 16
    width = width * 16

    codes = Variable(codes, volatile=True)

    is_big = network == 'Big'
    if is_big:
        decoder = compression_net.CompressionDecoder(rnn_type, code_size)
    else:
        decoder = compression_net_smaller.CompressionDecoder(rnn_type)
    decoder.eval()

    if use_cuda:
        state_dict = torch.load(model)
    else:
        state_dict = torch.load(model,
                                map_location=lambda storage, loc: storage)
    decoder.load_state_dict(state_dict)

    def zero_state(channels, stride):
        """Return a pair of zero tensors for one recurrent layer."""
        shape = (batch_size, channels, height // stride, width // stride)
        pair = (Variable(torch.zeros(*shape), volatile=True),
                Variable(torch.zeros(*shape), volatile=True))
        if use_cuda:
            pair = (pair[0].cuda(), pair[1].cuda())
        return pair

    decoder_h_0 = zero_state(512, 32)
    decoder_h_1 = zero_state(512, 16)
    decoder_h_2 = zero_state(512, 8)
    decoder_h_3 = zero_state(256, 4)
    decoder_h_4 = zero_state(128, 2)

    if use_cuda:
        decoder = decoder.cuda()
        codes = codes.cuda()

    # Explicit end indices instead of ``:-delta``: a delta of 0 would turn
    # ``:-0`` into ``:0`` and produce an empty image.  For delta >= 1 the
    # result is identical to the negative-index form.
    crop_h = max(height - delta_h, 0)
    crop_w = max(width - delta_w, 0)

    steps = min(iterations, codes.size(0))
    # Start from mid-grey and add the decoder output of every step.
    image = torch.zeros(1, 3, height, width) + 0.5
    for step in range(steps):
        if is_big:
            (output, decoder_h_1, decoder_h_2, decoder_h_3,
             decoder_h_4) = decoder(codes[step], decoder_h_1, decoder_h_2,
                                    decoder_h_3, decoder_h_4)
        else:
            (output, decoder_h_0, decoder_h_1, decoder_h_2, decoder_h_3,
             decoder_h_4) = decoder(codes[step], decoder_h_0, decoder_h_1,
                                    decoder_h_2, decoder_h_3, decoder_h_4)
        image = image + output.data.cpu()
        image_save = image[:, :, :crop_h, :crop_w]

        # [0, 1] floats -> uint8 pixels, dropping size-1 axes, CHW -> HWC.
        imsave(
            os.path.join(output_dir, '{:02d}.png'.format(step)),
            np.squeeze(image_save.numpy().clip(0, 1) * 255.0).astype(
                np.uint8).transpose(1, 2, 0))

    return steps