예제 #1
0
def main(
    path: str,
    file,
    workers: int,
    resblocks: int,
    n_feats: int,
    scale: int,
    load: str,
    k: int,
    crop: int,
) -> None:
    """Evaluate a compressor on a list of images and print its bpsp figures.

    Args:
        path: Image location handed to ``lc_data.ImageFolder``.
        file: Iterable of lines; each line is stripped and used as one
            image filename.
        workers: Number of ``DataLoader`` worker processes.
        resblocks: Stored in ``configs.resblocks``.
        n_feats: Stored in ``configs.n_feats``.
        scale: Stored in ``configs.scale`` and passed to the dataset.
        load: Checkpoint path. The literal ``"/dev/null"`` means "do not
            load anything"; the freshly constructed network is evaluated.
        k: Stored in ``configs.K``.
        crop: Center-crop size; a value <= 0 disables cropping.

    Side effects:
        Prints progress to stdout and pickles the third value returned by
        ``run_eval`` to ``test.p`` in the current working directory.
    """
    # Publish the hyper-parameters on the shared configs module before the
    # network is constructed.
    for attr, value in (
        ("n_feats", n_feats),
        ("resblocks", resblocks),
        ("K", k),
        ("scale", scale),
    ):
        setattr(configs, attr, value)

    print(sys.argv)
    public_config_names = []
    for attr_name in dir(configs):
        if not attr_name.startswith("__"):
            public_config_names.append(attr_name)
    print(public_config_names)

    # An empty dict stands for "no checkpoint".
    state = {}
    if load != "/dev/null":
        state = torch.load(load)
        print(f"Loaded model from {load}.")
        print("Epoch:", state["epoch"])

    model = network.Compressor()
    if state:
        model.nets.load_state_dict(state["nets"])
    model = model.cuda()

    print(f"Number of parameters: {count_params(model.nets)}")
    print(model.nets)

    pipeline = [T.CenterCrop(crop)] if crop > 0 else []
    image_names = [line.strip() for line in file]
    dataset = lc_data.ImageFolder(
        path, image_names, scale, T.Compose(pipeline))
    loader = data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=workers,
        drop_last=False,
    )
    print(f"Loaded dataset with {len(dataset)} images")

    bits, inp_size, individual_bpsps = run_eval(loader, model)
    for key in bits.get_keys():
        scaled_bpsp = bits.get_scaled_bpsp(key, inp_size)
        print(f"{key}:", scaled_bpsp)

    total_bpsp = bits.get_total_bpsp(inp_size)
    print("Average total bpsp is: ", total_bpsp)

    # Persist the per-image results so they can be inspected later.
    with open('test.p', 'wb') as out_file:
        pickle.dump(individual_bpsps, out_file)
예제 #2
0
파일: decode.py 프로젝트: zeta1999/SReC
def main(
        path: str, file, resblocks: int, n_feats: int, scale: int,
        load: str, k: int, save_path: str,
) -> None:
    """Decode a list of ``.srec`` files to PNG images and report statistics.

    Args:
        path: Directory the ``.srec`` filenames are joined onto.
        file: Iterable of lines; each line is stripped and used as one
            ``.srec`` filename.
        resblocks: Stored in ``configs.resblocks``.
        n_feats: Stored in ``configs.n_feats``.
        scale: Stored in ``configs.scale``.
        load: Checkpoint path passed to ``torch.load``.
        k: Stored in ``configs.K``.
        save_path: Output directory for the PNG files; created if missing.

    Raises:
        AssertionError: If a listed filename does not end in ``.srec``.
    """
    ImageFile.LOAD_TRUNCATED_IMAGES = True

    for attr, value in (
            ("n_feats", n_feats),
            ("resblocks", resblocks),
            ("K", k),
            ("scale", scale),
            ("collect_probs", False),
    ):
        setattr(configs, attr, value)

    print(sys.argv)

    state = torch.load(load)
    print(f"Loaded model from {load}.")
    print("Epoch:", state["epoch"])

    model = network.Compressor()
    model.nets.load_state_dict(state["nets"])
    model = model.cuda()
    print(model.nets)

    srec_names = [line.strip() for line in file]
    print(f"Loaded directory with {len(srec_names)} images")

    os.makedirs(save_path, exist_ok=True)

    coder = bitcoding.Bitcoding(model)
    decode_timer = timer.TimeAccumulator()
    bytes_so_far = 0
    subpixels_so_far = 0

    for srec_name in srec_names:
        assert srec_name.endswith(".srec"), (
            f"{srec_name} is not a .srec file")
        srec_path = os.path.join(path, srec_name)

        # The timed region covers decoding plus the conversion to a byte
        # tensor on the CPU.
        with decode_timer.execute():
            pixels = coder.decode(srec_path).byte().squeeze(0).cpu()

        # Swap the ".srec" suffix for ".png".
        png_name = srec_name[:-len(".srec")] + ".png"
        T.functional.to_pil_image(pixels).save(
            os.path.join(save_path, png_name))

        mean_decode = decode_timer.mean_time_spent()
        scale_report = ", ".join(
            f"{elapsed:.3f}" for elapsed in coder.decomp_scale_times())
        print(
            f"Decomp: {mean_decode:.3f};\t"
            f"Decomp Time By Scale: {scale_report}; ",
            end="")

        bytes_so_far += os.path.getsize(srec_path)
        subpixels_so_far += np.prod(pixels.size())

        # Running bits-per-subpixel over everything decoded so far; the
        # carriage return lets the next iteration overwrite this line.
        running_bpsp = bytes_so_far * 8 / subpixels_so_far
        print(f"Bpsp: {running_bpsp:.3f}", end="\r")
    print()
예제 #3
0
def main(
    train_path: str,
    eval_path: str,
    train_file,
    eval_file,
    batch: int,
    workers: int,
    plot: str,
    epochs: int,
    resblocks: int,
    n_feats: int,
    scale: int,
    load: str,
    lr: float,
    eval_iters: int,
    lr_epochs: int,
    plot_iters: int,
    k: int,
    clip: float,
    crop: int,
    gd: str,
) -> None:
    """Train a compressor with periodic checkpointing and evaluation.

    Args:
        train_path: Image location for the training ``ImageFolder``.
        eval_path: Image location for the evaluation ``ImageFolder``.
        train_file: Iterable of lines; each stripped line is one training
            image filename.
        eval_file: Iterable of lines; each stripped line is one evaluation
            image filename.
        batch: Training batch size.
        workers: Number of ``DataLoader`` worker processes (both loaders).
        plot: Output directory (created if missing). TensorBoard events and
            the ``train.pth`` checkpoint are written here.
        epochs: Epoch index at which training stops.
        resblocks: Stored in ``configs.resblocks``.
        n_feats: Stored in ``configs.n_feats``.
        scale: Stored in ``configs.scale`` and passed to both datasets.
        load: Checkpoint to resume from. Overridden by ``<plot>/train.pth``
            when that file exists; ignored when it is not an existing file
            or is ``"/dev/null"``.
        lr: Initial learning rate.
        eval_iters: Run evaluation every this many training iterations;
            ``0`` means once per pass over the training loader.
        lr_epochs: ``StepLR`` step size in epochs (decay factor is 0.75).
        plot_iters: Log the learning rate and save a checkpoint every this
            many training iterations; also forwarded to ``train_loop``.
        k: Stored in ``configs.K``.
        clip: Forwarded to ``train_loop`` (presumably a gradient-clipping
            threshold -- confirm against ``train_loop``).
        crop: Side length for ``T.RandomCrop`` on training images.
        gd: Optimizer name: ``"adam"``, ``"sgd"`` or ``"rmsprop"``.

    Raises:
        NotImplementedError: If ``gd`` is not one of the supported names.
    """
    ImageFile.LOAD_TRUNCATED_IMAGES = True

    configs.n_feats = n_feats
    configs.scale = scale
    configs.resblocks = resblocks
    configs.K = k
    configs.plot = plot

    print(sys.argv)

    os.makedirs(plot, exist_ok=True)
    # A checkpoint already present in the output directory takes priority
    # over the one requested by the caller, so a restarted run resumes.
    model_load = os.path.join(plot, "train.pth")
    if os.path.isfile(model_load):
        load = model_load
    if os.path.isfile(load) and load != "/dev/null":
        checkpoint = torch.load(load)
        print(f"Loaded model from {load}.")
        print("Epoch:", checkpoint["epoch"])
        if checkpoint.get("best_bpsp") is None:
            print("Warning: best_bpsp not found!")
        else:
            configs.best_bpsp = checkpoint["best_bpsp"]
            print("Best bpsp:", configs.best_bpsp)
    else:
        # Empty dict == "start from scratch"; every later lookup uses
        # ``.get`` with a fallback.
        checkpoint = {}

    compressor = network.Compressor()
    if checkpoint:
        compressor.nets.load_state_dict(checkpoint["nets"])
    compressor = compressor.cuda()

    optimizer: optim.Optimizer  # type: ignore
    if gd == "adam":
        optimizer = optim.Adam(compressor.parameters(), lr=lr, weight_decay=0)
    elif gd == "sgd":
        optimizer = optim.SGD(compressor.parameters(),
                              lr=lr,
                              momentum=0.9,
                              nesterov=True)
    elif gd == "rmsprop":
        optimizer = optim.RMSprop(  # type: ignore
            compressor.parameters(), lr=lr)
    else:
        raise NotImplementedError(gd)

    starting_epoch = checkpoint.get("epoch") or 0

    print(compressor)

    train_dataset = lc_data.ImageFolder(
        train_path,
        [filename.strip() for filename in train_file],
        scale,
        T.Compose([
            T.RandomHorizontalFlip(),
            T.RandomCrop(crop),
        ]),
    )
    # Restore the sampling order and position from the checkpoint when
    # available; otherwise start a fresh random permutation at index 0.
    dataset_index = checkpoint.get("index") or 0
    train_sampler = lc_data.PreemptiveRandomSampler(
        checkpoint.get("sampler_indices")
        or torch.randperm(len(train_dataset)).tolist(),
        dataset_index,
    )
    train_loader = data.DataLoader(
        train_dataset,
        batch_size=batch,
        sampler=train_sampler,
        num_workers=workers,
        drop_last=True,
    )
    print(f"Loaded training dataset with {len(train_loader)} batches "
          f"and {len(train_loader.dataset)} images")
    eval_dataset = lc_data.ImageFolder(
        eval_path,
        [filename.strip() for filename in eval_file],
        scale,
        T.Lambda(lambda x: x),
    )
    eval_loader = data.DataLoader(
        eval_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=workers,
        drop_last=False,
    )
    print(f"Loaded eval dataset with {len(eval_loader)} batches "
          f"and {len(eval_dataset)} images")

    lr_scheduler = optim.lr_scheduler.StepLR(optimizer, lr_epochs, gamma=0.75)

    # Fast-forward the schedule to the epoch being resumed.
    for _ in range(starting_epoch):
        lr_scheduler.step()  # type: ignore

    train_iter = checkpoint.get("train_iter") or 0
    if eval_iters == 0:
        eval_iters = len(train_loader)

    for epoch in range(starting_epoch, epochs):
        with tensorboard.SummaryWriter(plot) as plotter:
            # input: List[Tensor], downsampled images.
            # sizes: N scale 4
            for _, inputs in train_loader:
                train_iter += 1
                batch_size = inputs[0].shape[0]

                train_loop(inputs, compressor, optimizer, train_iter, plotter,
                           plot_iters, clip)
                # Increment dataset_index before checkpointing because
                # dataset_index is starting index of index of the FIRST
                # unseen piece of data.
                dataset_index += batch_size

                if train_iter % plot_iters == 0:
                    # Log the rate the optimizer is actually using.  Reading
                    # it from the param group works on every PyTorch
                    # version, whereas ``lr_scheduler.get_lr()`` called
                    # outside ``step()`` warns on newer releases and can be
                    # off by a factor of gamma on decay epochs.
                    plotter.add_scalar(
                        "train/lr",
                        optimizer.param_groups[0]["lr"],
                        train_iter)
                    save(compressor, train_sampler.indices, dataset_index,
                         epoch, train_iter, plot, "train.pth")

                if train_iter % eval_iters == 0:
                    run_eval(eval_loader, compressor, train_iter, plotter,
                             epoch)

            lr_scheduler.step()  # type: ignore
            # A new epoch starts from the beginning of the sampler order.
            dataset_index = 0

    # Final evaluation and checkpoint, tagged with the terminal epoch.
    with tensorboard.SummaryWriter(plot) as plotter:
        run_eval(eval_loader, compressor, train_iter, plotter, epochs)
    save(compressor, train_sampler.indices, train_sampler.index, epochs,
         train_iter, plot, "train.pth")
    print("training done")
예제 #4
0
def main(
    path: str,
    file,
    resblocks: int,
    n_feats: int,
    scale: int,
    load: str,
    k: int,
    crop: int,
    log_likelihood: bool,
    decode: bool,
    save_path: str,
) -> None:
    """Encode a list of images to ``.srec`` files and report statistics.

    Args:
        path: Image location handed to ``lc_data.ImageFolder``.
        file: Iterable of lines; each stripped line is one image filename.
        resblocks: Stored in ``configs.resblocks``.
        n_feats: Stored in ``configs.n_feats``.
        scale: Stored in ``configs.scale`` and passed to the dataset.
        load: Checkpoint path passed to ``torch.load``.
        k: Stored in ``configs.K``.
        crop: Center-crop size; a value <= 0 disables cropping.
        log_likelihood: Stored in ``configs.log_likelihood``. When true, the
            log-likelihood bits returned by the coder are accumulated and a
            "Theoretical Bpsp" figure is printed.
        decode: When true, every encoded file is decoded again and compared
            with the input, and decode timings are printed.
        save_path: Output directory for the ``.srec`` files; created if
            missing.

    Raises:
        AssertionError: If the loader yields more than one filename per
            batch, or (with ``decode``) the round trip is not lossless.
    """

    configs.n_feats = n_feats
    configs.resblocks = resblocks
    configs.K = k
    configs.scale = scale
    configs.log_likelihood = log_likelihood
    configs.collect_probs = True

    print(sys.argv)

    checkpoint = torch.load(load)
    print(f"Loaded model from {load}.")
    print("Epoch:", checkpoint["epoch"])

    compressor = network.Compressor()
    compressor.nets.load_state_dict(checkpoint["nets"])
    compressor = compressor.cuda()

    print(compressor.nets)

    transforms = []  # type: ignore
    if crop > 0:
        transforms.insert(0, T.CenterCrop(crop))

    dataset = lc_data.ImageFolder(path,
                                  [filename.strip() for filename in file],
                                  scale, T.Compose(transforms))
    loader = data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        drop_last=False,
    )
    print(f"Loaded directory with {len(dataset)} images")

    os.makedirs(save_path, exist_ok=True)

    coder = bitcoding.Bitcoding(compressor)
    encoder_time_accumulator = timer.TimeAccumulator()
    decoder_time_accumulator = timer.TimeAccumulator()
    total_file_bytes = 0
    total_num_subpixels = 0
    # Starts as the scalar 0 and becomes an array after the first addition.
    total_entropy_coding_bytes: np.ndarray = 0  # type: ignore
    total_log_likelihood_bits = network.Bits()

    for i, (filenames, x) in enumerate(loader):
        assert len(filenames) == 1, filenames
        filename = filenames[0]
        # Strip only the final extension.  Cutting at the first dot would
        # map "a.1.png" and "a.2.png" to the same output file and turn
        # "./img.png" into an empty id.
        file_id = os.path.splitext(filename)[0]
        filepath = os.path.join(save_path, f"{file_id}.srec")

        with encoder_time_accumulator.execute():
            log_likelihood_bits, entropy_coding_bytes = coder.encode(
                x, filepath)

        total_file_bytes += os.path.getsize(filepath)
        total_entropy_coding_bytes += np.array(entropy_coding_bytes)
        total_num_subpixels += np.prod(x.size())
        if configs.log_likelihood:
            total_log_likelihood_bits.add_bits(log_likelihood_bits)

        if decode:
            # Round-trip check: the decoded tensor must match the input
            # exactly; on failure the differing values are reported.
            with decoder_time_accumulator.execute():
                y = coder.decode(filepath)
                y = y.cpu()
            assert torch.all(x == y), (x[x != y], y[x != y])

        # Running statistics on one line; the trailing carriage return lets
        # the next iteration overwrite it.
        if configs.log_likelihood:
            theoretical_bpsp = total_log_likelihood_bits.get_total_bpsp(
                total_num_subpixels).item()
            print(f"Theoretical Bpsp: {theoretical_bpsp:.3f};\t", end="")
        print(
            f"Bpsp: {total_file_bytes*8/total_num_subpixels:.3f};\t"
            f"Images: {i + 1};\t"
            f"Comp: {encoder_time_accumulator.mean_time_spent():.3f};\t",
            end="")
        if decode:
            print(
                "Decomp: "
                f"{decoder_time_accumulator.mean_time_spent():.3f}",
                end="")
        print(end="\r")
    print()

    if decode:
        print("Decomp Time By Scale: ", end="")
        print(", ".join(f"{scale_time:.3f}"
                        for scale_time in coder.decomp_scale_times()))
    else:
        print("Scale Bpsps: ", end="")
        print(", ".join(f"{scale_bpsp:.3f}"
                        for scale_bpsp in total_entropy_coding_bytes * 8 /
                        total_num_subpixels))