def decoder_test_each(codes, output_dir, model, rnn_type, use_cuda, iterations,
                      network, code_size=32):
    """Decode a stack of codes and return the reconstruction after every step.

    The reconstruction starts as a constant 0.5 canvas; each decoder output is
    added to it and a snapshot of the running sum is recorded.

    Args:
        codes: 5-D tensor unpacked as (steps, batch, channels, height, width).
            The reconstructed image is 16x larger along both spatial axes.
        output_dir: Unused; kept so existing callers keep working.
        model: Path to the decoder state dict (loaded with ``torch.load``).
        rnn_type: Forwarded to the decoder constructor.
        use_cuda: When true, run the decoder and its state on the GPU.
        iterations: Maximum number of decoding steps; capped at
            ``codes.size(0)``.
        network: ``'Big'`` selects ``compression_net``; any other value
            selects ``compression_net_smaller``.
        code_size: Forwarded to the ``'Big'`` decoder only.

    Returns:
        A list with one numpy array per executed step holding the cumulative
        (unclipped) reconstruction.
    """
    _, batch_size, _, code_h, code_w = codes.size()
    height = code_h * 16
    width = code_w * 16
    codes = Variable(codes, volatile=True)

    if network == 'Big':
        decoder = compression_net.CompressionDecoder(rnn_type, code_size)
    else:
        # NOTE(review): encoder_test in this file drives the smaller decoder
        # with an extra decoder_h_0 state, while this function passes four
        # states to both variants -- confirm the smaller decoder accepts it.
        decoder = compression_net_smaller.CompressionDecoder(rnn_type)
    decoder.eval()

    if use_cuda:
        state_dict = torch.load(model)
    else:
        # Remap storages onto the CPU so a checkpoint written from a GPU run
        # can still be loaded on a machine without CUDA.
        state_dict = torch.load(
            model, map_location=lambda storage, loc: storage)
    decoder.load_state_dict(state_dict)

    def _zero_state(num_channels, stride):
        """Return a pair of zero tensors at 1/stride of the image size."""
        shape = (batch_size, num_channels, height // stride, width // stride)
        pair = (Variable(torch.zeros(*shape), volatile=True),
                Variable(torch.zeros(*shape), volatile=True))
        if use_cuda:
            pair = (pair[0].cuda(), pair[1].cuda())
        return pair

    decoder_h_1 = _zero_state(512, 16)
    decoder_h_2 = _zero_state(512, 8)
    decoder_h_3 = _zero_state(256, 4)
    decoder_h_4 = _zero_state(128, 2)

    if use_cuda:
        decoder = decoder.cuda()
        codes = codes.cuda()

    # The canvas has batch size 1; adding the decoder output broadcasts it.
    image = torch.zeros(1, 3, height, width) + 0.5
    images = []
    for step in range(min(iterations, codes.size(0))):
        output, decoder_h_1, decoder_h_2, decoder_h_3, decoder_h_4 = decoder(
            codes[step], decoder_h_1, decoder_h_2, decoder_h_3, decoder_h_4)
        # A new tensor is created each step, so earlier snapshots are not
        # overwritten by later additions.
        image = image + output.data.cpu()
        images.append(image.numpy())
    return images
def encoder_test_each(image, model, iterations, rnn_type, use_cuda=True,
                      network='Big', code_size=32):
    """Encode an image tensor into a stack of binary codes, one per iteration.

    Each iteration encodes the current residual, binarizes it, decodes it and
    subtracts the decoder output from the residual before the next pass.

    Args:
        image: 4-D tensor unpacked as (batch, channels, height, width); height
            and width must both be multiples of 32.
        model: Path to the encoder state dict. The binarizer and decoder paths
            are derived by replacing ``'encoder'`` with ``'binarizer'`` /
            ``'decoder'`` in this string (every occurrence is replaced).
        iterations: Number of encode/decode passes to run.
        rnn_type: Forwarded to the encoder and decoder constructors.
        use_cuda: When true, run the networks and their state on the GPU.
        network: ``'Big'`` selects ``compression_net``; any other value
            selects ``compression_net_smaller``.
        code_size: Forwarded to the ``'Big'`` binarizer and decoder only.

    Returns:
        An int8 numpy array stacking the per-iteration codes along a new
        leading axis, remapped with ``(code + 1) // 2`` (so -1 becomes 0 and
        +1 becomes 1).

    Raises:
        AssertionError: If height or width is not a multiple of 32.
    """
    batch_size, _, height, width = image.shape
    assert height % 32 == 0 and width % 32 == 0
    image = Variable(image, volatile=True)

    if network == 'Big':
        encoder = compression_net.CompressionEncoder(rnn_type=rnn_type)
        binarizer = compression_net.CompressionBinarizer(code_size)
        decoder = compression_net.CompressionDecoder(rnn_type=rnn_type,
                                                     code_size=code_size)
    else:
        # NOTE(review): encoder_test in this file passes extra encoder_h_4 /
        # decoder_h_0 states to the smaller networks; this function does not
        # -- confirm which call signature the smaller networks expect.
        encoder = compression_net_smaller.CompressionEncoder(rnn_type=rnn_type)
        binarizer = compression_net_smaller.CompressionBinarizer()
        decoder = compression_net_smaller.CompressionDecoder(rnn_type=rnn_type)

    encoder.eval()
    binarizer.eval()
    decoder.eval()

    def _load_state(path):
        """Load a state dict, remapping storages to the CPU when not on GPU."""
        if use_cuda:
            return torch.load(path)
        # Lets checkpoints written from a GPU run load on CPU-only machines.
        return torch.load(path, map_location=lambda storage, loc: storage)

    encoder.load_state_dict(_load_state(model))
    binarizer.load_state_dict(_load_state(model.replace('encoder', 'binarizer')))
    decoder.load_state_dict(_load_state(model.replace('encoder', 'decoder')))

    def _zero_state(num_channels, stride):
        """Return a pair of zero tensors at 1/stride of the image size."""
        shape = (batch_size, num_channels, height // stride, width // stride)
        pair = (Variable(torch.zeros(*shape), volatile=True),
                Variable(torch.zeros(*shape), volatile=True))
        if use_cuda:
            pair = (pair[0].cuda(), pair[1].cuda())
        return pair

    encoder_h_1 = _zero_state(256, 4)
    encoder_h_2 = _zero_state(512, 8)
    encoder_h_3 = _zero_state(512, 16)

    decoder_h_1 = _zero_state(512, 16)
    decoder_h_2 = _zero_state(512, 8)
    decoder_h_3 = _zero_state(256, 4)
    decoder_h_4 = _zero_state(128, 2)

    if use_cuda:
        encoder = encoder.cuda()
        binarizer = binarizer.cuda()
        decoder = decoder.cuda()
        image = image.cuda()

    codes = []
    res = image - 0.5
    for _ in range(iterations):
        encoded, encoder_h_1, encoder_h_2, encoder_h_3 = encoder(
            res, encoder_h_1, encoder_h_2, encoder_h_3)
        code = binarizer(encoded)
        output, decoder_h_1, decoder_h_2, decoder_h_3, decoder_h_4 = decoder(
            code, decoder_h_1, decoder_h_2, decoder_h_3, decoder_h_4)
        # The next pass encodes whatever this pass failed to reconstruct.
        res = res - output
        codes.append(code.data.cpu().numpy())

    return (np.stack(codes).astype(np.int8) + 1) // 2
def encoder_test(input, output_path, model, iterations, rnn_type, use_cuda=True,
                 network='Big', code_size=32):
    """Compress an image file into a packed-bit ``.npz`` archive.

    The image is zero-padded on the bottom/right, run through ``iterations``
    encode -> binarize -> decode passes on the residual, and the resulting
    codes are bit-packed and saved together with the amount of padding.

    Args:
        input: Path of the image to read (loaded as RGB via ``imread``).
        output_path: Destination passed to ``np.savez_compressed``.
        model: Path to the encoder state dict. The binarizer and decoder paths
            are derived by replacing ``'encoder'`` with ``'binarizer'`` /
            ``'decoder'`` in this string (every occurrence is replaced).
        iterations: Number of encode/decode passes to run.
        rnn_type: Forwarded to the encoder and decoder constructors.
        use_cuda: When true, run on the GPU; otherwise checkpoints are
            remapped to the CPU while loading.
        network: ``'Big'`` selects ``compression_net``; any other value
            selects ``compression_net_smaller``, which is driven with one
            additional encoder state and one additional decoder state.
        code_size: Forwarded to the ``'Big'`` binarizer and decoder only.

    Returns:
        None. The archive holds ``shape`` (shape of the stacked codes),
        ``codes`` (the bit-packed codes) and ``delta`` (uint8 pair: padding
        rows, padding columns).
    """
    raw_image = imread(input, mode='RGB')
    h, w, c = raw_image.shape

    # Always pad by 1..32 pixels, even when a side is already a multiple of
    # 32: decoder_test in this file crops with ``[:-delta]``, which relies on
    # the stored padding never being zero.
    new_h = (h // 32 + 1) * 32
    new_w = (w // 32 + 1) * 32
    padded = np.zeros((new_h, new_w, c), dtype=np.float32)
    padded[:h, :w, :] = raw_image

    # HWC in [0, 255] -> 1xCxHxW in [0, 1].
    image = torch.from_numpy(
        np.expand_dims(np.transpose(padded / 255.0, (2, 0, 1)), 0))
    batch_size, _, height, width = image.size()
    image = Variable(image, volatile=True)

    is_big = network == 'Big'
    if is_big:
        encoder = compression_net.CompressionEncoder(rnn_type=rnn_type)
        binarizer = compression_net.CompressionBinarizer(code_size=code_size)
        decoder = compression_net.CompressionDecoder(rnn_type=rnn_type,
                                                     code_size=code_size)
    else:
        encoder = compression_net_smaller.CompressionEncoder(rnn_type=rnn_type)
        binarizer = compression_net_smaller.CompressionBinarizer()
        decoder = compression_net_smaller.CompressionDecoder(rnn_type=rnn_type)

    encoder.eval()
    binarizer.eval()
    decoder.eval()

    def _load_state(path):
        """Load a state dict, remapping storages to the CPU when not on GPU."""
        if use_cuda:
            return torch.load(path)
        return torch.load(path, map_location=lambda storage, loc: storage)

    encoder.load_state_dict(_load_state(model))
    binarizer.load_state_dict(_load_state(model.replace('encoder', 'binarizer')))
    decoder.load_state_dict(_load_state(model.replace('encoder', 'decoder')))

    def _zero_state(num_channels, stride):
        """Return a pair of zero tensors at 1/stride of the padded image size.

        Every state goes through this helper so that none of them can be left
        behind on the CPU when running on the GPU.
        """
        shape = (batch_size, num_channels, height // stride, width // stride)
        pair = (Variable(torch.zeros(*shape), volatile=True),
                Variable(torch.zeros(*shape), volatile=True))
        if use_cuda:
            pair = (pair[0].cuda(), pair[1].cuda())
        return pair

    encoder_h_1 = _zero_state(256, 4)
    encoder_h_2 = _zero_state(512, 8)
    encoder_h_3 = _zero_state(512, 16)

    decoder_h_1 = _zero_state(512, 16)
    decoder_h_2 = _zero_state(512, 8)
    decoder_h_3 = _zero_state(256, 4)
    decoder_h_4 = _zero_state(128, 2)

    if not is_big:
        # Extra states used only by the smaller network. Previously these
        # were never moved to the GPU, so the CUDA path of the smaller
        # network mixed CPU state with GPU input.
        encoder_h_4 = _zero_state(1024, 32)
        decoder_h_0 = _zero_state(512, 32)

    if use_cuda:
        encoder = encoder.cuda()
        binarizer = binarizer.cuda()
        decoder = decoder.cuda()
        image = image.cuda()

    codes = []
    res = image - 0.5
    for _ in range(iterations):
        if is_big:
            encoded, encoder_h_1, encoder_h_2, encoder_h_3 = encoder(
                res, encoder_h_1, encoder_h_2, encoder_h_3)
            code = binarizer(encoded)
            (output, decoder_h_1, decoder_h_2,
             decoder_h_3, decoder_h_4) = decoder(
                code, decoder_h_1, decoder_h_2, decoder_h_3, decoder_h_4)
        else:
            (encoded, encoder_h_1, encoder_h_2,
             encoder_h_3, encoder_h_4) = encoder(
                res, encoder_h_1, encoder_h_2, encoder_h_3, encoder_h_4)
            code = binarizer(encoded)
            (output, decoder_h_0, decoder_h_1, decoder_h_2,
             decoder_h_3, decoder_h_4) = decoder(
                code, decoder_h_0, decoder_h_1, decoder_h_2,
                decoder_h_3, decoder_h_4)
        # The next pass encodes whatever this pass failed to reconstruct.
        res = res - output
        codes.append(code.data.cpu().numpy())

    # (code + 1) // 2 turns -1 into 0 and +1 into 1 so the codes can be
    # packed eight to a byte.
    codes = (np.stack(codes).astype(np.int8) + 1) // 2
    export = np.packbits(codes.reshape(-1))
    delta = np.array([new_h - h, new_w - w], dtype=np.uint8)
    np.savez_compressed(output_path, shape=codes.shape, codes=export,
                        delta=delta)
def decoder_test(input, output_path, model, iterations, rnn_type, use_cuda=True,
                 code_size=32):
    """Reconstruct an image from a packed-code archive and write it to disk.

    Args:
        input: Path of an ``.npz`` archive holding ``codes`` (bit-packed
            codes) and ``shape`` (shape of the unpacked 5-D code array).
        output_path: Path the reconstructed image is written to via
            ``imsave``.
        model: Path to the decoder state dict.
        iterations: Maximum number of decoding steps; capped at the number of
            code steps stored in the archive.
        rnn_type: Forwarded to ``compression_net.CompressionDecoder``.
        use_cuda: When true, run on the GPU; otherwise the checkpoint is
            remapped to the CPU while loading.
        code_size: Forwarded to the decoder constructor.

    Returns:
        None.
    """
    # Read both arrays and close the archive immediately; np.load keeps the
    # underlying file open for as long as the NpzFile object is alive.
    with np.load(input) as content:
        packed = content['codes']
        code_shape = content['shape']

    # Bits are stored as 0/1; map them back to -1/+1 floats.
    codes = np.unpackbits(packed)
    codes = np.reshape(codes, code_shape).astype(np.float32) * 2 - 1
    codes = torch.from_numpy(codes)

    _, batch_size, _, code_h, code_w = codes.size()
    height = code_h * 16
    width = code_w * 16
    codes = Variable(codes, volatile=True)

    decoder = compression_net.CompressionDecoder(rnn_type, code_size)
    decoder.eval()
    if use_cuda:
        decoder.load_state_dict(torch.load(model))
    else:
        decoder.load_state_dict(
            torch.load(model, map_location=lambda storage, loc: storage))

    def _zero_state(num_channels, stride):
        """Return a pair of zero tensors at 1/stride of the image size."""
        shape = (batch_size, num_channels, height // stride, width // stride)
        pair = (Variable(torch.zeros(*shape), volatile=True),
                Variable(torch.zeros(*shape), volatile=True))
        if use_cuda:
            pair = (pair[0].cuda(), pair[1].cuda())
        return pair

    decoder_h_1 = _zero_state(512, 16)
    decoder_h_2 = _zero_state(512, 8)
    decoder_h_3 = _zero_state(256, 4)
    decoder_h_4 = _zero_state(128, 2)

    if use_cuda:
        decoder = decoder.cuda()
        codes = codes.cuda()

    # Start from a constant 0.5 canvas and accumulate every decoder output.
    image = torch.zeros(1, 3, height, width) + 0.5
    for step in range(min(iterations, codes.size(0))):
        output, decoder_h_1, decoder_h_2, decoder_h_3, decoder_h_4 = decoder(
            codes[step], decoder_h_1, decoder_h_2, decoder_h_3, decoder_h_4)
        image = image + output.data.cpu()

    # Clip to [0, 1], scale to 8-bit and convert CxHxW -> HxWxC for saving.
    pixels = np.squeeze(image.numpy().clip(0, 1) * 255.0).astype(np.uint8)
    imsave(output_path, pixels.transpose(1, 2, 0))
def main():
    """Train the encoder/binarizer/decoder and periodically save checkpoints.

    Parses the command line into the module-level ``args``, builds the data
    loader, the three networks (always placed on the GPU) and a shared Adam
    optimizer, optionally restores weights from the epoch named by
    ``args.resume``, then runs ``args.epochs`` training epochs.

    Checkpoints are written to ``checkpoint/<rnn_type>_<loss_type>_<code_size>/``
    every ``args.save_freq`` epochs and after the final epoch of the run.
    """
    global args
    args = parser.parse_args()

    train_transform = transforms.Compose([
        transforms.RandomCrop((32, 32)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    train_dataset = dataset.ImageFolder(roots=args.train_paths,
                                        transform=train_transform)
    train_loader = data.DataLoader(dataset=train_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers)

    encoder = compression_net.CompressionEncoder(rnn_type=args.rnn_type).cuda()
    binarizer = compression_net.CompressionBinarizer(
        code_size=args.code_size).cuda()
    decoder = compression_net.CompressionDecoder(
        rnn_type=args.rnn_type, code_size=args.code_size).cuda()
    networks = (('encoder', encoder),
                ('binarizer', binarizer),
                ('decoder', decoder))

    optimizer = optim.Adam(
        [{'params': net.parameters()} for _, net in networks],
        lr=args.lr)

    checkpoint_dir = 'checkpoint/{}_{}_{}'.format(
        args.rnn_type, args.loss_type, args.code_size)

    def _checkpoint_path(name, epoch_index):
        """Return the checkpoint file path for one network at one epoch."""
        return '{}/{}_{:08d}.pth'.format(checkpoint_dir, name, epoch_index)

    start_epoch = args.start_epoch
    if args.resume:
        if os.path.isdir('checkpoint'):
            print("=> loading checkpoint '{}'".format(args.resume))
            for name, net in networks:
                net.load_state_dict(
                    torch.load(_checkpoint_path(name, args.resume)))
            print("=> loaded checkpoint '{}'".format(args.resume))
            start_epoch = args.resume + 1
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    torch.manual_seed(23)

    # NOTE(review): the scheduler is created fresh and stepped without an
    # explicit epoch, so after a resume its milestones count from the start
    # of this run rather than from start_epoch -- confirm this is intended.
    scheduler = LS.MultiStepLR(optimizer, milestones=[3, 10, 20, 50, 100],
                               gamma=0.5)

    # The loop covers [start_epoch, start_epoch + args.epochs), so the final
    # epoch must be computed relative to start_epoch; comparing against
    # args.epochs - 1 skipped the final save whenever start_epoch != 0.
    last_epoch = start_epoch + args.epochs - 1
    for epoch in range(start_epoch, start_epoch + args.epochs):
        scheduler.step()
        train(train_loader, encoder, binarizer, decoder, epoch, optimizer)

        if epoch % args.save_freq == 0 or epoch == last_epoch:
            if not os.path.isdir(checkpoint_dir):
                os.makedirs(checkpoint_dir)
            for name, net in networks:
                torch.save(net.state_dict(), _checkpoint_path(name, epoch))
def decoder_test(input, output_dir, model, iterations, rnn_type, use_cuda=True,
                 network='Big', code_size=32):
    """Decode a packed-code archive, saving the reconstruction after each step.

    After every decoding step the running reconstruction is cropped by the
    padding recorded in the archive and written to
    ``<output_dir>/<step:02d>.png``.

    Args:
        input: Path of an ``.npz`` archive holding ``codes`` (bit-packed
            codes), ``shape`` (shape of the unpacked 5-D code array) and
            ``delta`` (padding rows, padding columns to crop away).
        output_dir: Existing directory that receives one PNG per step.
        model: Path to the decoder state dict.
        iterations: Maximum number of decoding steps; capped at the number of
            code steps stored in the archive.
        rnn_type: Forwarded to the decoder constructor.
        use_cuda: When true, run on the GPU; otherwise the checkpoint is
            remapped to the CPU while loading.
        network: ``'Big'`` selects ``compression_net``; any other value
            selects ``compression_net_smaller``, whose decoder is driven with
            an additional ``decoder_h_0`` state.
        code_size: Forwarded to the ``'Big'`` decoder only.

    Returns:
        The number of decoding steps executed (and images written).
    """
    # Read everything needed and close the archive immediately; np.load keeps
    # the underlying file open for as long as the NpzFile object is alive.
    with np.load(input) as content:
        packed = content['codes']
        code_shape = content['shape']
        delta = content['delta']
    delta_h = int(delta[0])
    delta_w = int(delta[1])

    # Bits are stored as 0/1; map them back to -1/+1 floats.
    codes = np.unpackbits(packed)
    codes = np.reshape(codes, code_shape).astype(np.float32) * 2 - 1
    codes = torch.from_numpy(codes)

    _, batch_size, _, code_h, code_w = codes.size()
    height = code_h * 16
    width = code_w * 16
    codes = Variable(codes, volatile=True)

    # Crop with explicit end indices: ``[:-delta]`` collapses to an empty
    # slice when delta is 0, which would write an empty image.
    crop_h = max(height - delta_h, 0)
    crop_w = max(width - delta_w, 0)

    is_big = network == 'Big'
    if is_big:
        decoder = compression_net.CompressionDecoder(rnn_type, code_size)
    else:
        decoder = compression_net_smaller.CompressionDecoder(rnn_type)
    decoder.eval()
    if use_cuda:
        decoder.load_state_dict(torch.load(model))
    else:
        decoder.load_state_dict(
            torch.load(model, map_location=lambda storage, loc: storage))

    def _zero_state(num_channels, stride):
        """Return a pair of zero tensors at 1/stride of the image size."""
        shape = (batch_size, num_channels, height // stride, width // stride)
        pair = (Variable(torch.zeros(*shape), volatile=True),
                Variable(torch.zeros(*shape), volatile=True))
        if use_cuda:
            pair = (pair[0].cuda(), pair[1].cuda())
        return pair

    def _save_snapshot(step, canvas):
        """Crop the padding off ``canvas`` and write it as ``<step>.png``."""
        cropped = canvas[:, :, :crop_h, :crop_w]
        pixels = np.squeeze(
            cropped.numpy().clip(0, 1) * 255.0).astype(np.uint8)
        imsave(os.path.join(output_dir, '{:02d}.png'.format(step)),
               pixels.transpose(1, 2, 0))

    decoder_h_1 = _zero_state(512, 16)
    decoder_h_2 = _zero_state(512, 8)
    decoder_h_3 = _zero_state(256, 4)
    decoder_h_4 = _zero_state(128, 2)
    if not is_big:
        # Extra state consumed only by the smaller decoder.
        decoder_h_0 = _zero_state(512, 32)

    if use_cuda:
        decoder = decoder.cuda()
        codes = codes.cuda()

    num_steps = min(iterations, codes.size(0))
    # Start from a constant 0.5 canvas and accumulate every decoder output.
    image = torch.zeros(1, 3, height, width) + 0.5
    for step in range(num_steps):
        if is_big:
            (output, decoder_h_1, decoder_h_2,
             decoder_h_3, decoder_h_4) = decoder(
                codes[step], decoder_h_1, decoder_h_2,
                decoder_h_3, decoder_h_4)
        else:
            (output, decoder_h_0, decoder_h_1, decoder_h_2,
             decoder_h_3, decoder_h_4) = decoder(
                codes[step], decoder_h_0, decoder_h_1, decoder_h_2,
                decoder_h_3, decoder_h_4)
        image = image + output.data.cpu()
        _save_snapshot(step, image)

    return num_steps