def main(
    path: str,
    file,
    workers: int,
    resblocks: int,
    n_feats: int,
    scale: int,
    load: str,
    k: int,
    crop: int,
) -> None:
    """Evaluate the compressor on a list of images and report bpsp.

    The network is optionally initialised from a checkpoint and moved to the
    GPU (``.cuda()``). Every image listed in ``file`` is run through
    ``run_eval``. The scaled bpsp for each key reported by the returned
    bits object is printed, followed by the average total bpsp. The
    per-image bpsp values returned by ``run_eval`` are pickled to
    ``test.p`` in the current working directory.

    Args:
        path: Directory the image names in ``file`` are relative to.
        file: Iterable of lines; each whitespace-stripped line names one image.
        workers: Number of DataLoader worker processes.
        resblocks: Stored as ``configs.resblocks``.
        n_feats: Stored as ``configs.n_feats``.
        scale: Stored as ``configs.scale`` and passed to the dataset.
        load: Checkpoint path; the literal ``"/dev/null"`` means "no
            checkpoint" (the network keeps its freshly constructed weights).
        k: Stored as ``configs.K``.
        crop: When positive, images are centre-cropped to this size.
    """
    # Push the CLI hyper-parameters into the shared config before the
    # network is constructed.
    for attr_name, attr_value in (
        ("n_feats", n_feats),
        ("resblocks", resblocks),
        ("K", k),
        ("scale", scale),
    ):
        setattr(configs, attr_name, attr_value)

    print(sys.argv)
    public_names = [name for name in dir(configs) if not name.startswith("__")]
    print(public_names)

    checkpoint = {}
    if load != "/dev/null":
        checkpoint = torch.load(load)
        print(f"Loaded model from {load}.")
        print("Epoch:", checkpoint["epoch"])

    model = network.Compressor()
    if checkpoint:
        model.nets.load_state_dict(checkpoint["nets"])
    model = model.cuda()

    print(f"Number of parameters: {count_params(model.nets)}")
    print(model.nets)

    # An empty Compose is the identity; a centre crop is the only optional step.
    preprocessing = T.Compose([T.CenterCrop(crop)] if crop > 0 else [])
    image_names = [line.strip() for line in file]
    dataset = lc_data.ImageFolder(path, image_names, scale, preprocessing)
    loader = data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=workers,
        drop_last=False,
    )
    print(f"Loaded dataset with {len(dataset)} images")

    bits, inp_size, individual_bpsps = run_eval(loader, model)
    for bits_key in bits.get_keys():
        print(f"{bits_key}:", bits.get_scaled_bpsp(bits_key, inp_size))
    print("Average total bpsp is: ", bits.get_total_bpsp(inp_size))

    # Persist the per-image bpsp values for offline analysis.
    with open("test.p", "wb") as out_file:
        pickle.dump(individual_bpsps, out_file)
def main(
    path: str,
    file,
    resblocks: int,
    n_feats: int,
    scale: int,
    load: str,
    k: int,
    save_path: str,
) -> None:
    """Decompress a list of ``.srec`` files back into PNG images.

    Each whitespace-stripped line of ``file`` names a ``.srec`` file
    relative to ``path``. The decoded image is written under ``save_path``
    with the same relative name, ``.srec`` replaced by ``.png``. A running
    status line shows:

    * the mean decode time;
    * the per-scale decode times reported by the coder;
    * the cumulative bpsp (compressed file bits / decoded subpixels).

    Args:
        path: Directory the ``.srec`` names in ``file`` are relative to.
        file: Iterable of lines, one ``.srec`` file name per line.
        resblocks: Stored as ``configs.resblocks``.
        n_feats: Stored as ``configs.n_feats``.
        scale: Stored as ``configs.scale``.
        load: Checkpoint path; must contain ``"epoch"`` and ``"nets"``.
        k: Stored as ``configs.K``.
        save_path: Output directory for the PNGs (created if missing,
            including any subdirectories named in ``file``).

    Raises:
        ValueError: If a listed name does not end in ``.srec``. This is
            checked for every entry before any file is decoded.
    """
    srec_suffix = ".srec"

    ImageFile.LOAD_TRUNCATED_IMAGES = True
    configs.n_feats = n_feats
    configs.resblocks = resblocks
    configs.K = k
    configs.scale = scale
    configs.collect_probs = False
    print(sys.argv)

    checkpoint = torch.load(load)
    print(f"Loaded model from {load}.")
    print("Epoch:", checkpoint["epoch"])

    compressor = network.Compressor()
    compressor.nets.load_state_dict(checkpoint["nets"])
    compressor = compressor.cuda()
    print(compressor.nets)

    filenames = [filename.strip() for filename in file]
    print(f"Loaded directory with {len(filenames)} images")
    # Validate with a real exception (``assert`` is stripped by
    # ``python -O``). Do it up front so a bad entry is reported before
    # time is spent decoding the entries that precede it.
    for filename in filenames:
        if not filename.endswith(srec_suffix):
            raise ValueError(f"{filename} is not a .srec file")

    os.makedirs(save_path, exist_ok=True)

    coder = bitcoding.Bitcoding(compressor)
    decoder_time_accumulator = timer.TimeAccumulator()
    total_num_bytes = 0
    total_num_subpixels = 0
    for filename in filenames:
        filepath = os.path.join(path, filename)
        with decoder_time_accumulator.execute():
            x = coder.decode(filepath)
            # The conversion and copy back to CPU are part of the timed region.
            x = x.byte().squeeze(0).cpu()
        img = T.functional.to_pil_image(x)

        png_path = os.path.join(
            save_path, f"{filename[:-len(srec_suffix)]}.png")
        # Listed names may include subdirectories; mirror them under
        # save_path so img.save() does not fail on a missing parent.
        png_dir = os.path.dirname(png_path)
        if png_dir:
            os.makedirs(png_dir, exist_ok=True)
        img.save(png_path)

        print(
            "Decomp: "
            f"{decoder_time_accumulator.mean_time_spent():.3f};\t"
            "Decomp Time By Scale: ",
            end="")
        decomp_scale_times = coder.decomp_scale_times()
        print(
            ", ".join(f"{scale_time:.3f}" for scale_time in decomp_scale_times),
            end="; ")

        total_num_bytes += os.path.getsize(filepath)
        total_num_subpixels += np.prod(x.size())

        # Cumulative bits per subpixel over all files decoded so far.
        print(
            f"Bpsp: {total_num_bytes*8/total_num_subpixels:.3f}", end="\r")
    print()
def main(
    train_path: str,
    eval_path: str,
    train_file,
    eval_file,
    batch: int,
    workers: int,
    plot: str,
    epochs: int,
    resblocks: int,
    n_feats: int,
    scale: int,
    load: str,
    lr: float,
    eval_iters: int,
    lr_epochs: int,
    plot_iters: int,
    k: int,
    clip: float,
    crop: int,
    gd: str,
) -> None:
    """Train the compressor, resuming from a checkpoint when one exists.

    Checkpointing, logging and evaluation:

    * Checkpoints go to ``<plot>/train.pth`` every ``plot_iters``
      iterations and once more at the end.
    * TensorBoard summaries are written to ``plot``.
    * If ``<plot>/train.pth`` already exists it is resumed from in
      preference to ``load``.
    * The learning rate is multiplied by 0.75 every ``lr_epochs`` epochs
      (``StepLR``).
    * The model is evaluated every ``eval_iters`` iterations and once after
      the final epoch.

    Args:
        train_path: Directory the names in ``train_file`` are relative to.
        eval_path: Directory the names in ``eval_file`` are relative to.
        train_file: Iterable of lines, one training image name per line.
        eval_file: Iterable of lines, one evaluation image name per line.
        batch: Training batch size; incomplete final batches are dropped.
        workers: DataLoader worker processes for both loaders.
        plot: Output directory for checkpoints and TensorBoard logs.
        epochs: Total epoch count, including epochs already completed by a
            resumed checkpoint.
        resblocks: Stored as ``configs.resblocks``.
        n_feats: Stored as ``configs.n_feats``.
        scale: Stored as ``configs.scale`` and passed to both datasets.
        load: Checkpoint to start from when ``<plot>/train.pth`` is absent;
            ``"/dev/null"`` means start from scratch.
        lr: Initial learning rate.
        eval_iters: Evaluate every this many iterations; 0 means once per
            epoch.
        lr_epochs: ``StepLR`` step size, in epochs.
        plot_iters: Log the learning rate and checkpoint every this many
            iterations; also forwarded to ``train_loop``.
        k: Stored as ``configs.K``.
        clip: Forwarded to ``train_loop`` (presumably a gradient-clipping
            threshold - TODO confirm).
        crop: Side length of the random training crop.
        gd: Optimizer name: ``"adam"``, ``"sgd"`` or ``"rmsprop"``.

    Raises:
        NotImplementedError: If ``gd`` is not one of the supported names.

    NOTE(review): only ``checkpoint["nets"]`` is restored. The optimizer
    state is not, so Adam/momentum statistics restart on every resume.
    """
    ImageFile.LOAD_TRUNCATED_IMAGES = True
    configs.n_feats = n_feats
    configs.scale = scale
    configs.resblocks = resblocks
    configs.K = k
    configs.plot = plot

    print(sys.argv)

    os.makedirs(plot, exist_ok=True)
    # A checkpoint already in the output directory wins over --load, so an
    # interrupted run continues where it stopped.
    model_load = os.path.join(plot, "train.pth")
    if os.path.isfile(model_load):
        load = model_load
    if os.path.isfile(load) and load != "/dev/null":
        checkpoint = torch.load(load)
        print(f"Loaded model from {load}.")
        print("Epoch:", checkpoint["epoch"])
        if checkpoint.get("best_bpsp") is None:
            print("Warning: best_bpsp not found!")
        else:
            configs.best_bpsp = checkpoint["best_bpsp"]
            print("Best bpsp:", configs.best_bpsp)
    else:
        if load != "/dev/null":
            # Surface a missing/mistyped --load path instead of silently
            # training from scratch.
            print(f"Warning: checkpoint {load} not found; "
                  "training from scratch.")
        checkpoint = {}

    compressor = network.Compressor()
    if checkpoint:
        compressor.nets.load_state_dict(checkpoint["nets"])
    compressor = compressor.cuda()

    optimizer: optim.Optimizer  # type: ignore
    if gd == "adam":
        optimizer = optim.Adam(compressor.parameters(), lr=lr, weight_decay=0)
    elif gd == "sgd":
        optimizer = optim.SGD(compressor.parameters(), lr=lr,
                              momentum=0.9, nesterov=True)
    elif gd == "rmsprop":
        optimizer = optim.RMSprop(  # type: ignore
            compressor.parameters(), lr=lr)
    else:
        raise NotImplementedError(gd)

    starting_epoch = checkpoint.get("epoch") or 0

    print(compressor)

    train_dataset = lc_data.ImageFolder(
        train_path,
        [filename.strip() for filename in train_file],
        scale,
        T.Compose([
            T.RandomHorizontalFlip(),
            T.RandomCrop(crop),
        ]),
    )
    # Resume mid-epoch by reusing the saved sample permutation together with
    # the saved position in it (see the dataset_index comment in the loop).
    dataset_index = checkpoint.get("index") or 0
    train_sampler = lc_data.PreemptiveRandomSampler(
        checkpoint.get("sampler_indices")
        or torch.randperm(len(train_dataset)).tolist(),
        dataset_index,
    )
    train_loader = data.DataLoader(
        train_dataset,
        batch_size=batch,
        sampler=train_sampler,
        num_workers=workers,
        drop_last=True,
    )
    print(f"Loaded training dataset with {len(train_loader)} batches "
          f"and {len(train_loader.dataset)} images")

    # T.Compose([]) is an identity transform. Unlike a lambda, it can be
    # pickled when DataLoader workers are started with the "spawn" method.
    eval_dataset = lc_data.ImageFolder(
        eval_path,
        [filename.strip() for filename in eval_file],
        scale,
        T.Compose([]),
    )
    eval_loader = data.DataLoader(
        eval_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=workers,
        drop_last=False,
    )
    print(f"Loaded eval dataset with {len(eval_loader)} batches "
          f"and {len(eval_dataset)} images")

    lr_scheduler = optim.lr_scheduler.StepLR(optimizer, lr_epochs, gamma=0.75)
    # Fast-forward the schedule to the resumed epoch.
    for _ in range(starting_epoch):
        lr_scheduler.step()  # type: ignore

    train_iter = checkpoint.get("train_iter") or 0
    if eval_iters == 0:
        eval_iters = len(train_loader)

    for epoch in range(starting_epoch, epochs):
        with tensorboard.SummaryWriter(plot) as plotter:
            # input: List[Tensor], downsampled images.
            # sizes: N scale 4
            for _, inputs in train_loader:
                train_iter += 1
                batch_size = inputs[0].shape[0]

                train_loop(inputs, compressor, optimizer, train_iter,
                           plotter, plot_iters, clip)
                # Increment dataset_index before checkpointing because
                # dataset_index is the index of the FIRST unseen piece of
                # data.
                dataset_index += batch_size

                if train_iter % plot_iters == 0:
                    # Read the LR the optimizer actually applies.
                    # StepLR.get_lr() is only meant to be called inside
                    # step(); called here it warns and returns lr * gamma on
                    # decay epochs.
                    plotter.add_scalar(
                        "train/lr", optimizer.param_groups[0]["lr"],
                        train_iter)
                    save(compressor, train_sampler.indices, dataset_index,
                         epoch, train_iter, plot, "train.pth")

                if train_iter % eval_iters == 0:
                    run_eval(eval_loader, compressor, train_iter, plotter,
                             epoch)

            lr_scheduler.step()  # type: ignore
            # The next epoch starts from the beginning of the permutation.
            dataset_index = 0

    with tensorboard.SummaryWriter(plot) as plotter:
        run_eval(eval_loader, compressor, train_iter, plotter, epochs)
    save(compressor, train_sampler.indices, train_sampler.index, epochs,
         train_iter, plot, "train.pth")
    print("training done")
def main(
    path: str,
    file,
    resblocks: int,
    n_feats: int,
    scale: int,
    load: str,
    k: int,
    crop: int,
    log_likelihood: bool,
    decode: bool,
    save_path: str,
) -> None:
    """Losslessly compress every listed image to a ``.srec`` file.

    Images named by ``file`` (relative to ``path``) are encoded one at a
    time into ``<save_path>/<name without extension>.srec``. A running
    status line shows:

    * the file-size bpsp (file bits / input subpixels);
    * the image count and the mean encode time;
    * the theoretical bpsp from the model log-likelihood, when
      ``log_likelihood`` is set;
    * the mean decode time, when ``decode`` is set.

    At the end it prints either the per-scale decode times (``decode``) or
    the per-scale bpsp of the entropy-coded streams.

    Args:
        path: Directory the image names in ``file`` are relative to.
        file: Iterable of lines; each whitespace-stripped line names an image.
        resblocks: Stored as ``configs.resblocks``.
        n_feats: Stored as ``configs.n_feats``.
        scale: Stored as ``configs.scale`` and passed to the dataset.
        load: Checkpoint path; must contain ``"epoch"`` and ``"nets"``.
        k: Stored as ``configs.K``.
        crop: When positive, images are centre-cropped to this size.
        log_likelihood: Stored as ``configs.log_likelihood``; enables the
            theoretical-bpsp report.
        decode: Also decode each written file and verify it matches the input.
        save_path: Output directory for ``.srec`` files (created if missing,
            including any subdirectories named in ``file``).

    Raises:
        RuntimeError: If ``decode`` is set and a decoded image differs from
            its input.
    """
    configs.n_feats = n_feats
    configs.resblocks = resblocks
    configs.K = k
    configs.scale = scale
    configs.log_likelihood = log_likelihood
    configs.collect_probs = True
    print(sys.argv)

    checkpoint = torch.load(load)
    print(f"Loaded model from {load}.")
    print("Epoch:", checkpoint["epoch"])

    compressor = network.Compressor()
    compressor.nets.load_state_dict(checkpoint["nets"])
    compressor = compressor.cuda()
    print(compressor.nets)

    transforms = []  # type: ignore
    if crop > 0:
        transforms.insert(0, T.CenterCrop(crop))
    dataset = lc_data.ImageFolder(path,
                                  [filename.strip() for filename in file],
                                  scale,
                                  T.Compose(transforms))
    loader = data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        drop_last=False,
    )
    print(f"Loaded directory with {len(dataset)} images")

    os.makedirs(save_path, exist_ok=True)

    coder = bitcoding.Bitcoding(compressor)
    encoder_time_accumulator = timer.TimeAccumulator()
    decoder_time_accumulator = timer.TimeAccumulator()
    total_file_bytes = 0
    total_num_subpixels = 0
    # Starts as scalar 0 and becomes an array after the first += of the
    # per-stream byte counts returned by coder.encode().
    total_entropy_coding_bytes: np.ndarray = 0  # type: ignore
    total_log_likelihood_bits = network.Bits()

    for i, (filenames, x) in enumerate(loader):
        assert len(filenames) == 1, filenames  # batch_size=1 above
        filename = filenames[0]
        # splitext drops only the final extension. split(".")[0] turned
        # "./img.png" into "" and made "a.b.png" / "a.c.png" collide.
        file_id = os.path.splitext(filename)[0]
        filepath = os.path.join(save_path, f"{file_id}.srec")
        # Names may include subdirectories; create them under save_path.
        os.makedirs(os.path.dirname(filepath), exist_ok=True)

        with encoder_time_accumulator.execute():
            log_likelihood_bits, entropy_coding_bytes = coder.encode(
                x, filepath)

        total_file_bytes += os.path.getsize(filepath)
        total_entropy_coding_bytes += np.array(entropy_coding_bytes)
        total_num_subpixels += np.prod(x.size())
        if configs.log_likelihood:
            total_log_likelihood_bits.add_bits(log_likelihood_bits)

        if decode:
            with decoder_time_accumulator.execute():
                y = coder.decode(filepath)
                y = y.cpu()
            # Use an explicit exception: this lossless round-trip check must
            # not disappear under ``python -O`` the way an assert would.
            mismatch = x != y
            if mismatch.any():
                raise RuntimeError(
                    f"Decoded {filename} does not match the input: "
                    f"{x[mismatch]} vs {y[mismatch]}")

        if configs.log_likelihood:
            theoretical_bpsp = total_log_likelihood_bits.get_total_bpsp(
                total_num_subpixels).item()
            print(f"Theoretical Bpsp: {theoretical_bpsp:.3f};\t", end="")
        print(
            f"Bpsp: {total_file_bytes*8/total_num_subpixels:.3f};\t"
            f"Images: {i + 1};\t"
            f"Comp: {encoder_time_accumulator.mean_time_spent():.3f};\t",
            end="")
        if decode:
            print(
                "Decomp: "
                f"{decoder_time_accumulator.mean_time_spent():.3f}",
                end="")
        print(end="\r")
    print()

    if decode:
        print("Decomp Time By Scale: ", end="")
        print(", ".join(f"{scale_time:.3f}"
                        for scale_time in coder.decomp_scale_times()))
    elif total_num_subpixels == 0:
        # No images were processed; the ratio below would divide by zero.
        print("Scale Bpsps: n/a (no images)")
    else:
        print("Scale Bpsps: ", end="")
        print(", ".join(f"{scale_bpsp:.3f}"
                        for scale_bpsp in
                        total_entropy_coding_bytes * 8 / total_num_subpixels))