Пример #1
0
def run_eval(
    loader: data.DataLoader,
    compressor: nn.Module,
) -> Tuple[network.Bits, int]:
    """Run the compressor over every batch of `loader` in eval mode.

    Gradients are disabled for the whole pass. After each batch a progress
    line (running bpsp, batches seen, mean timed duration) is rewritten in
    place on stdout.

    Args:
        loader: Yields `(_, x)` pairs; only the tensor `x` is used.
        compressor: Module called as `compressor(x)` on the CUDA copy of `x`;
            its result is accumulated with `network.Bits.add_bits`.

    Returns:
        A tuple `(bits_keeper, cur_agg_size)`: the `network.Bits` accumulator
        holding the bits of all batches, and the summed element count
        (product of `x.size()`) over all batches.
    """
    batch_timer = timer.TimeAccumulator()
    compressor.eval()
    cur_agg_size = 0

    with torch.no_grad():
        # Accumulates the bits reported by every eval iteration.
        total_bits = network.Bits()

        for batch_idx, (_, batch) in enumerate(loader):
            # Element count is taken before the forward pass.
            cur_agg_size += np.prod(batch.size())

            # Only the host-to-GPU transfer and the forward pass are timed.
            with batch_timer.execute():
                batch = batch.cuda()
                batch_bits = compressor(batch)
            total_bits.add_bits(batch_bits)

            running_bpsp = total_bits.get_total_bpsp(cur_agg_size)
            progress = (
                f"Bpsp: {running_bpsp.item():.3f}; "
                f"Number of Images: {batch_idx + 1}; "
                f"Batch Time: {batch_timer.mean_time_spent()}"
            )
            # Carriage return keeps the progress on a single terminal line.
            print(progress, end="\r")

        # Move past the in-place progress line.
        print()

    return total_bits, cur_agg_size
 def __init__(
     self,
     compressor: network.Compressor,
 ) -> None:
     """Keep a reference to `compressor` and zero all running statistics.

     Args:
         compressor: The compressor stored on the instance as
             `self.compressor`.
     """
     self.compressor = compressor

     # Running counters, all starting from zero.
     self.total_num_bytes = 0
     self.total_num_subpixels = 0

     # Accumulator for log-likelihood bits.
     self.log_likelihood_bits = network.Bits()

     # Starts as a scalar float although annotated as an array
     # (presumably arrays are added onto it later -- confirm in the
     # rest of the class).
     self.file_sizes: np.ndarray = 0.  # type: ignore

     # One timer for each of the `configs.scale + 1` entries.
     self.scale_timers = []
     for _ in range(configs.scale + 1):
         self.scale_timers.append(timer.TimeAccumulator())
Пример #3
0
def run_eval(
    loader: data.DataLoader,
    compressor: nn.Module,
    max_images: int = 1000,
) -> Tuple[network.Bits, int, list]:
    """Run an eval pass and also record per-batch bit breakdowns.

    The compressor is put in eval mode and run, with gradients disabled, on
    at most `max_images` batches of `loader`. A progress line is rewritten
    in place on stdout after every batch.

    Args:
        loader: Yields `(_, x)` pairs; only the tensor `x` is used.
        compressor: Module called as `compressor(x)` on the CUDA copy of `x`.
            Its result must provide `get_bits(key)` and `key_to_bits[key]`
            for the "eval/..." keys read below.
        max_images: Maximum number of loader batches to evaluate. Defaults
            to 1000, the limit that used to be hard-coded.

    Returns:
        A tuple `(bits_keeper, cur_agg_size, individual_bpsps)`:
        the `network.Bits` accumulator over all evaluated batches, the summed
        element count (product of `x.size()`), and a list of `max_images`
        dictionaries. Entry `i` holds the keys "rounding", "image_3",
        "image_2", "image_1" and "image_0" for batch `i`; entries beyond the
        number of batches actually evaluated stay empty.
    """
    time_accumulator = timer.TimeAccumulator()
    compressor.eval()
    cur_agg_size = 0

    # One (initially empty) record per batch that may be evaluated.
    individual_bpsps = [{} for _ in range(max_images)]

    # "image_2" is built from the "eval/0_*" keys, "image_1" from "eval/1_*"
    # and "image_0" from "eval/2_*".
    image_keys = ("image_2", "image_1", "image_0")

    with torch.no_grad():
        # Accumulates the bits reported by every eval iteration.
        bits_keeper = network.Bits()

        # zip() stops as soon as range() is exhausted, so no batch beyond
        # `max_images` is ever requested from the loader.
        for i, (_, x) in zip(range(max_images), loader):
            cur_agg_size += np.prod(x.size())
            with time_accumulator.execute():
                x = x.cuda()
                bits = compressor(x)
            bits_keeper.add_bits(bits)

            bpsp = bits_keeper.get_total_bpsp(cur_agg_size)

            record = individual_bpsps[i]
            # Copies are kept for get_bits() results so the record does not
            # alias whatever object `bits` hands out.
            record["rounding"] = (
                copy.deepcopy(bits.get_bits("eval/0_rounding"))
                + copy.deepcopy(bits.get_bits("eval/1_rounding"))
                + copy.deepcopy(bits.get_bits("eval/2_rounding"))
            )
            record["image_3"] = copy.deepcopy(bits.get_bits("eval/codes_0"))
            # `.item()` already yields plain Python scalars, so no copy is
            # needed for these sums.
            for scale_idx, image_key in enumerate(image_keys):
                record[image_key] = (
                    bits.key_to_bits[f"eval/{scale_idx}_0"].item()
                    + bits.key_to_bits[f"eval/{scale_idx}_1"].item()
                    + bits.key_to_bits[f"eval/{scale_idx}_2"].item()
                )

            # Carriage return keeps the progress on a single terminal line.
            print(
                f"Bpsp: {bpsp.item():.3f}; Number of Images: {i + 1}; "
                f"Batch Time: {time_accumulator.mean_time_spent()}",
                end="\r")

        # Move past the in-place progress line.
        print()
    return bits_keeper, cur_agg_size, individual_bpsps
Пример #4
0
def run_eval(
    eval_loader: data.DataLoader,
    compressor: nn.Module,
    train_iter: int,
    plotter: tensorboard.SummaryWriter,
    epoch: int,
) -> None:
    """Evaluate on the whole eval set, log the results, keep the best model.

    The compressor is switched to eval mode and run on every batch with
    gradients disabled. The resulting total bpsp and the mean timed batch
    duration are written to `plotter` at step `train_iter`. When the bpsp is
    lower than `configs.best_bpsp`, that value is updated and a checkpoint is
    written to "best.pth" inside `configs.plot`.

    Args:
        eval_loader: Yields `(_, x)` pairs; only the tensor `x` is used.
        compressor: Module called as `compressor(x)`; its `nets` attribute
            provides the state dict that gets checkpointed.
        train_iter: Step index used for printing and for all scalars.
        plotter: Summary writer receiving the scalars.
        epoch: Stored in the checkpoint under the "epoch" key.
    """
    batch_timer = timer.TimeAccumulator()
    compressor.eval()
    total_elements = 0

    with torch.no_grad():
        # Accumulates the bits reported by every eval iteration.
        accumulated_bits = network.Bits()
        for _, batch in eval_loader:
            total_elements += np.prod(batch.size())
            # Only the host-to-GPU transfer and the forward pass are timed.
            with batch_timer.execute():
                batch = batch.cuda()
                batch_bits = compressor(batch)
            accumulated_bits.add_bits(batch_bits)

        total_bpsp = accumulated_bits.get_total_bpsp(total_elements)
        eval_bpsp = total_bpsp.item()

        print(f"Iteration {train_iter} bpsp: {total_bpsp}")
        plotter.add_scalar("eval/bpsp", eval_bpsp, train_iter)
        mean_batch_time = batch_timer.mean_time_spent()
        plotter.add_scalar("eval/batch_time", mean_batch_time, train_iter)
        plot_bpsp(plotter, accumulated_bits, total_elements, train_iter)

        # Nothing more to do unless this run beats the best bpsp so far.
        if not configs.best_bpsp > eval_bpsp:
            return

        configs.best_bpsp = eval_bpsp
        checkpoint = {
            "nets": compressor.nets.state_dict(),  # type: ignore
            "best_bpsp": configs.best_bpsp,
            "epoch": epoch,
        }
        torch.save(checkpoint, os.path.join(configs.plot, "best.pth"))
Пример #5
0
def main(
    path: str,
    file,
    resblocks: int,
    n_feats: int,
    scale: int,
    load: str,
    k: int,
    crop: int,
    log_likelihood: bool,
    decode: bool,
    save_path: str,
) -> None:
    """Encode every listed image to a `.srec` file and report statistics.

    Loads a checkpoint into a `network.Compressor`, encodes each image of
    the dataset with `bitcoding.Bitcoding`, optionally decodes it again and
    checks the round trip, and prints running and final statistics.

    Args:
        path: Passed to `lc_data.ImageFolder` together with the file names.
        file: Iterable of file-name lines; each line is stripped before use.
        resblocks: Stored in `configs.resblocks`.
        n_feats: Stored in `configs.n_feats`.
        scale: Stored in `configs.scale` and passed to the dataset.
        load: Path of the checkpoint; it must contain the keys "epoch" and
            "nets".
        k: Stored in `configs.K`.
        crop: When positive, images are center-cropped to this size.
        log_likelihood: Stored in `configs.log_likelihood`; when true, the
            log-likelihood bits are accumulated and a theoretical bpsp is
            printed.
        decode: When true, every encoded file is decoded again and asserted
            to equal the input, and decode timings are reported.
        save_path: Directory (created if missing) receiving the `.srec`
            files.
    """
    configs.n_feats = n_feats
    configs.resblocks = resblocks
    configs.K = k
    configs.scale = scale
    configs.log_likelihood = log_likelihood
    configs.collect_probs = True

    print(sys.argv)

    checkpoint = torch.load(load)
    print(f"Loaded model from {load}.")
    print("Epoch:", checkpoint["epoch"])

    compressor = network.Compressor()
    compressor.nets.load_state_dict(checkpoint["nets"])
    compressor = compressor.cuda()

    print(compressor.nets)

    transforms = []  # type: ignore
    if crop > 0:
        transforms.insert(0, T.CenterCrop(crop))

    dataset = lc_data.ImageFolder(path,
                                  [filename.strip() for filename in file],
                                  scale, T.Compose(transforms))
    # One image per batch, in dataset order.
    loader = data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        drop_last=False,
    )
    print(f"Loaded directory with {len(dataset)} images")

    os.makedirs(save_path, exist_ok=True)

    coder = bitcoding.Bitcoding(compressor)
    encoder_time_accumulator = timer.TimeAccumulator()
    decoder_time_accumulator = timer.TimeAccumulator()
    total_file_bytes = 0
    total_num_subpixels = 0
    # Starts as a scalar; arrays are added onto it inside the loop.
    total_entropy_coding_bytes: np.ndarray = 0  # type: ignore
    total_log_likelihood_bits = network.Bits()

    for i, (filenames, x) in enumerate(loader):
        assert len(filenames) == 1, filenames
        filename = filenames[0]
        # Strip only the final extension. Cutting at the first dot would map
        # "a.v1.png" and "a.v2.png" to the same output file and would turn
        # "./a.png" into an empty id.
        file_id = os.path.splitext(filename)[0]
        filepath = os.path.join(save_path, f"{file_id}.srec")

        with encoder_time_accumulator.execute():
            log_likelihood_bits, entropy_coding_bytes = coder.encode(
                x, filepath)

        total_file_bytes += os.path.getsize(filepath)
        total_entropy_coding_bytes += np.array(entropy_coding_bytes)
        total_num_subpixels += np.prod(x.size())
        if configs.log_likelihood:
            total_log_likelihood_bits.add_bits(log_likelihood_bits)

        if decode:
            with decoder_time_accumulator.execute():
                y = coder.decode(filepath)
                y = y.cpu()
            # Round-trip check: the decoded tensor must equal the input; on
            # failure the mismatching values of both tensors are reported.
            assert torch.all(x == y), (x[x != y], y[x != y])

        # The pieces below are printed without newlines and finished with a
        # carriage return so the progress stays on one terminal line.
        if configs.log_likelihood:
            theoretical_bpsp = total_log_likelihood_bits.get_total_bpsp(
                total_num_subpixels).item()
            print(f"Theoretical Bpsp: {theoretical_bpsp:.3f};\t", end="")
        print(
            f"Bpsp: {total_file_bytes*8/total_num_subpixels:.3f};\t"
            f"Images: {i + 1};\t"
            f"Comp: {encoder_time_accumulator.mean_time_spent():.3f};\t",
            end="")
        if decode:
            print(
                "Decomp: "
                f"{decoder_time_accumulator.mean_time_spent():.3f}",
                end="")
        print(end="\r")
    # Move past the in-place progress line.
    print()

    if decode:
        print("Decomp Time By Scale: ", end="")
        print(", ".join(f"{scale_time:.3f}"
                        for scale_time in coder.decomp_scale_times()))
    elif total_num_subpixels == 0:
        # Nothing was encoded: there is no per-scale rate to report, and the
        # division below would raise ZeroDivisionError.
        print("Scale Bpsps: n/a (no images were processed)")
    else:
        print("Scale Bpsps: ", end="")
        print(", ".join(f"{scale_bpsp:.3f}"
                        for scale_bpsp in total_entropy_coding_bytes * 8 /
                        total_num_subpixels))