Example #1
    def _download_url_to_file(self, url, dst, progress):
        file_size = None
        u = urlopen(url)
        meta = u.info()
        if hasattr(meta, 'getheaders'):
            content_length = meta.getheaders("Content-Length")
        else:
            content_length = meta.get_all("Content-Length")
        if content_length is not None and len(content_length) > 0:
            file_size = int(content_length[0])

        f = tempfile.NamedTemporaryFile(delete=False)
        try:
            with tqdm(total=file_size, disable=not progress) as pbar:
                while True:
                    buffer = u.read(8192)
                    if len(buffer) == 0:
                        break
                    f.write(buffer)
                    pbar.update(len(buffer))

            f.close()
            shutil.move(f.name, dst)
        finally:
            f.close()
            if os.path.exists(f.name):
                os.remove(f.name)
def get_image_level_data(study_type):
    """
    Returns a dict with keys 'train' and 'valid' whose values are image-level dataframes;
    each dataframe has the columns 'Path', 'Label', 'Study_Type', 'Study_Type_OH'.
    Args:
        study_type (string): one of the seven study type folder names in 'train/valid/test' dataset
    """
    image_data = {}
    study_label = {"positive": 1, "negative": 0}
    for phase in categories:
        BASE_DIR = "{}/{}/{}".format(MURA_BASE, phase, study_type)
        patients = list(
            os.walk(BASE_DIR))[0][1]  # list of patient folder names
        image_data[phase] = pd.DataFrame(
            columns=["Path", "Label", "Study_Type", "Study_Type_OH"])
        i = 0
        for patient in tqdm(patients):  # for each patient folder
            for study in os.listdir(os.path.join(
                    BASE_DIR,
                    patient)):  # for each study in that patient folder
                label = study_label[study.split('_')[1]]  # get label 0 or 1
                study_path = os.path.join(BASE_DIR, patient,
                                          study) + "/"  # path to this study
                for image in os.listdir(study_path):
                    path = os.path.join(study_path, image)
                    image_data[phase].loc[i] = [
                        path, label, study_type,
                        body_part_to_one_hot(study_type)
                    ]  # add new row
                    i += 1
    return image_data
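A hedged usage sketch (the study-type name is an assumption; `MURA_BASE`, `categories` and `body_part_to_one_hot` are module-level definitions not shown here):

data = get_image_level_data("XR_WRIST")  # "XR_WRIST" is assumed to be one of the study-type folders
train_df = data["train"]
print(train_df[["Path", "Label"]].head())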
def _compute_aspect_ratios_slow(dataset, indices=None):
    print("Your dataset doesn't support the fast path for "
          "computing the aspect ratios, so will iterate over "
          "the full dataset and load every image instead. "
          "This might take some time...")
    if indices is None:
        indices = range(len(dataset))

    class SubsetSampler(Sampler):
        def __init__(self, indices):
            self.indices = indices

        def __iter__(self):
            return iter(self.indices)

        def __len__(self):
            return len(self.indices)

    sampler = SubsetSampler(indices)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        sampler=sampler,
        num_workers=14,  # you might want to increase it for faster processing
        collate_fn=lambda x: x[0])
    aspect_ratios = []
    with tqdm(total=len(dataset)) as pbar:
        for _i, (img, _) in enumerate(data_loader):
            pbar.update(1)
            height, width = img.shape[-2:]
            aspect_ratio = float(width) / float(height)
            aspect_ratios.append(aspect_ratio)
    return aspect_ratios
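The returned ratios are usually used to group images of similar shape into the same batch; a minimal sketch, assuming NumPy and hypothetical bin edges:

import numpy as np

ratios = _compute_aspect_ratios_slow(dataset)   # dataset is assumed to yield (image, target) pairs
bins = [0.5, 1.0, 2.0]                          # hypothetical aspect-ratio bin edges
group_ids = np.digitize(ratios, bins)           # one group id per image, usable by a grouped batch sampler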
Example #4
    def _compute_frame_pts(self):
        self.video_pts = []
        self.video_fps = []

        # strategy: use a DataLoader to parallelize read_video_timestamps
        # so need to create a dummy dataset first
        class DS(object):
            def __init__(self, x):
                self.x = x

            def __len__(self):
                return len(self.x)

            def __getitem__(self, idx):
                try:
                    return read_video_timestamps(self.x[idx])
                except Exception:
                    print('Got a problem at:', self.x[idx])
                    return None

        import torch.utils.data
        dl = torch.utils.data.DataLoader(DS(self.video_paths),
                                         batch_size=16,
                                         num_workers=self.num_workers,
                                         collate_fn=lambda x: x)
        # print(len(dl))
        with tqdm(total=len(dl)) as pbar:
            for batch in dl:
                pbar.update(1)
                clips, fps = list(zip(*batch))
                clips = [torch.as_tensor(c) for c in clips]
                self.video_pts.extend(clips)
                self.video_fps.extend(fps)
Example #5
def download_google_drive_url(url: str, output_path: str, output_file_name: str):
    """
    Download a file from Google Drive

    Downloading a URL from Google Drive requires confirmation when
    the file is too big (Google Drive notifies that antivirus
    checks cannot be performed on such files)
    """
    import requests

    with requests.Session() as session:

        # First get the confirmation token and append it to the URL
        with session.get(url, stream=True, allow_redirects=True) as response:
            for k, v in response.cookies.items():
                if k.startswith("download_warning"):
                    url = url + "&confirm=" + v

        # Then download the content of the file
        with session.get(url, stream=True, verify=True) as response:
            makedir(output_path)
            path = os.path.join(output_path, output_file_name)
            total_size = int(response.headers.get("Content-length", 0))
            with open(path, "wb") as file:
                from tqdm import tqdm

                with tqdm(total=total_size) as progress_bar:
                    for block in response.iter_content(
                        chunk_size=io.DEFAULT_BUFFER_SIZE
                    ):
                        file.write(block)
                        progress_bar.update(len(block))
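A minimal usage sketch; the file id, output names and the `uc?export=download&id=` URL form are assumptions about how the caller builds the link:

file_id = "FILE_ID"  # placeholder Google Drive file id
download_google_drive_url(
    url=f"https://drive.google.com/uc?export=download&id={file_id}",
    output_path="./downloads",
    output_file_name="checkpoint.torch",
)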
Example #6
def evaluate(
    org_seq_path: Path,
    dec_seq_path: Path,
    bitstream_path: Path,
    cuda: bool = False,
) -> Dict[str, Any]:
    # load original and decoded sequences
    org_seq = RawVideoSequence.from_file(str(org_seq_path))
    dec_seq = RawVideoSequence.new_like(org_seq, str(dec_seq_path))

    max_val = 2**org_seq.bitdepth - 1
    num_frames = len(org_seq)

    if len(dec_seq) != num_frames:
        raise RuntimeError(
            "Invalid number of frames in decoded sequence "
            f"({num_frames}!={len(dec_seq)})"
        )

    if org_seq.format != VideoFormat.YUV420:
        raise NotImplementedError(f"Unsupported video format: {org_seq.format}")

    # compute metrics for each frame
    results = defaultdict(list)
    device = "cuda" if cuda else "cpu"
    with tqdm(total=num_frames) as pbar:
        for i in range(num_frames):
            org_frame = to_tensors(org_seq[i], device=device)
            dec_frame = to_tensors(dec_seq[i], device=device)
            metrics = compute_metrics_for_frame(org_frame, dec_frame, org_seq.bitdepth)
            for k, v in metrics.items():
                results[k].append(v)
            pbar.update(1)

    # compute average metrics for sequence
    seq_results: Dict[str, Any] = {
        k: torch.mean(torch.stack(v)) for k, v in results.items()
    }
    filesize = get_filesize(bitstream_path)
    seq_results["bitrate"] = float(
        filesize * 8 * org_seq.framerate / (num_frames * 1000)
    )

    seq_results["psnr-rgb"] = (
        20 * np.log10(max_val) - 10 * torch.log10(seq_results.pop("mse-rgb")).item()
    )
    for component in "yuv":
        seq_results[f"psnr-{component}"] = (
            20 * np.log10(max_val)
            - 10 * torch.log10(seq_results.pop(f"mse-{component}")).item()
        )
    seq_results["psnr-yuv"] = (
        4 * seq_results["psnr-y"] + seq_results["psnr-u"] + seq_results["psnr-v"]
    ) / 6
    for k, v in seq_results.items():
        if isinstance(v, torch.Tensor):
            seq_results[k] = v.item()
    return seq_results
Example #7
def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
    with open(filename, "wb") as fh:
        with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
            with tqdm(total=response.length) as pbar:
                for chunk in iter(lambda: response.read(chunk_size), b""):
                    if not chunk:
                        break
                    fh.write(chunk)
                    pbar.update(len(chunk))
Example #8
def calc_normalization(train_dl: torch.utils.data.DataLoader):
    "Calculate the mean and std of each channel on images from `train_dl`"
    mean = torch.zeros(3)
    m2 = torch.zeros(3)
    n = len(train_dl)
    for images, labels in tqdm(train_dl, "Compute normalization"):
        mean += images.mean([0, 2, 3]) / n
        m2 += (images ** 2).mean([0, 2, 3]) / n
    var = m2 - mean ** 2
    return mean, var.sqrt()
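The mean/std pair is typically fed straight into a normalization transform; a sketch assuming torchvision is available:

import torchvision.transforms as T

mean, std = calc_normalization(train_dl)
normalize = T.Normalize(mean=mean.tolist(), std=std.tolist())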
Example #9
def _save_response_content(response, destination, chunk_size=32768):
    with open(destination, "wb") as f:
        pbar = tqdm(total=None)
        progress = 0
        for chunk in response.iter_content(chunk_size):
            if chunk:  # filter out keep-alive new chunks
                f.write(chunk)
                progress += len(chunk)
                pbar.update(progress - pbar.n)
        pbar.close()
def gen_bar_updater(total) -> Callable[[int, int, int], None]:
    pbar = tqdm(total=total, unit='Byte')

    def bar_update(count, block_size, total_size):
        if pbar.total is None and total_size:
            pbar.total = total_size
        progress_bytes = count * block_size
        pbar.update(progress_bytes - pbar.n)

    return bar_update
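The returned `bar_update` closure matches the `reporthook(count, block_size, total_size)` signature of `urllib.request.urlretrieve`; a minimal sketch with a placeholder URL:

import urllib.request

urllib.request.urlretrieve(
    "https://example.com/data.zip",  # placeholder URL
    "data.zip",
    reporthook=gen_bar_updater(total=None),
)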
Example #11
def gen_bar_updater():
    pbar = tqdm(total=None)

    def bar_update(count, block_size, total_size):
        if pbar.total is None and total_size:
            pbar.total = total_size
        progress_bytes = count * block_size
        pbar.update(progress_bytes - pbar.n)

    return bar_update
Example #12
    def __gen_bar_updater(self):  # pylint: disable=no-self-use
        pbar = tqdm(total=None)

        def bar_update(count, block_size, total_size):
            if pbar.total is None and total_size:
                pbar.total = total_size
            progress_bytes = count * block_size
            pbar.update(progress_bytes - pbar.n)

        return bar_update
Example #13
    def on_loader_start(self, runner: "IRunner"):
        """Init tqdm progress bar."""
        self.step = 0
        self.tqdm = tqdm(
            total=runner.loader_batch_len,
            desc=f"{runner.stage_epoch_step}/{runner.stage_epoch_len}"
            f" * Epoch ({runner.loader_key})",
            # leave=True,
            # ncols=0,
            # file=sys.stdout,
        )
Example #14
def gen_bar_updater():
    """@TODO: Docs. Contribution is welcome."""
    pbar = tqdm(total=None)

    def bar_update(count, block_size, total_size):
        if pbar.total is None and total_size:
            pbar.total = total_size
        progress_bytes = count * block_size
        pbar.update(progress_bytes - pbar.n)

    return bar_update
Example #15
def copy_file_list(file_list, src_dir, dest_dir):
    with tqdm(total=len(file_list)) as pbar:
        for i, filename in enumerate(file_list):
            filename = filename.strip()
            if filename:
                # convert / to os-specific dir separator
                filename_parts = (filename + '.jpg').split('/')
                target = os.path.join(dest_dir, *filename_parts)
                if not os.path.isfile(target):
                    utils.copy_file(os.path.join(src_dir, *filename_parts),
                                    target)
            pbar.update(1)
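A hedged usage sketch, assuming a split file that lists one image name (without extension, '/'-separated) per line:

with open("val_split.txt") as split_file:  # hypothetical split file
    copy_file_list(split_file.readlines(),
                   src_dir="images_full",
                   dest_dir="images_val")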
Example #16
def _save_response_content(
    response: "requests.models.Response", destination: str, chunk_size: int = 32768,  # type: ignore[name-defined]
) -> None:
    with open(destination, "wb") as f:
        pbar = tqdm(total=None)
        progress = 0
        for chunk in response.iter_content(chunk_size):
            if chunk:  # filter out keep-alive new chunks
                f.write(chunk)
                progress += len(chunk)
                pbar.update(progress - pbar.n)
        pbar.close()
Example #17
def _save_response_content(
    content: Iterator[bytes],
    destination: str,
    length: Optional[int] = None,
) -> None:
    with open(destination, "wb") as fh, tqdm(total=length) as pbar:
        for chunk in content:
            # filter out keep-alive new chunks
            if not chunk:
                continue

            fh.write(chunk)
            pbar.update(len(chunk))
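A minimal sketch of how such a helper is typically driven from a streamed requests response (the URL and filename are placeholders):

import requests

with requests.get("https://example.com/model.pth", stream=True) as response:
    length = int(response.headers.get("Content-Length", 0)) or None
    _save_response_content(response.iter_content(chunk_size=32768),
                           "model.pth",
                           length=length)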
Example #18
def gen_bar_updater() -> Callable[[int, int, int], None]:
    warnings.warn(
        "The function `gen_bar_update` is deprecated since 0.13 and will be removed in 0.15."
    )
    pbar = tqdm(total=None)

    def bar_update(count, block_size, total_size):
        if pbar.total is None and total_size:
            pbar.total = total_size
        progress_bytes = count * block_size
        pbar.update(progress_bytes - pbar.n)

    return bar_update
Example #19
def convert_to_coco_api(ds):
    from torch.utils.model_zoo import tqdm
    coco_ds = COCO()
    # annotation IDs need to start at 1, not 0, see torchvision issue #1530
    ann_id = 1
    dataset = {'images': [], 'categories': [], 'annotations': []}
    categories = set()
    for img_idx in tqdm(range(len(ds))):
        # find better way to get target
        # targets = ds.get_annotations(img_idx)
        img, targets = ds[img_idx]['images'], ds[img_idx]['targets']
        image_id = targets["image_id"].item()
        img_dict = {}
        img_dict['id'] = image_id
        img_dict['height'] = img.shape[-2]
        img_dict['width'] = img.shape[-1]
        dataset['images'].append(img_dict)
        bboxes = targets["boxes"]
        bboxes[:, 2:] -= bboxes[:, :2]
        bboxes = bboxes.tolist()
        labels = targets['labels'].tolist()
        areas = targets['area'].tolist()
        iscrowd = targets['iscrowd'].tolist()
        if 'masks' in targets:
            masks = targets['masks']
            # make masks Fortran contiguous for coco_mask
            masks = masks.permute(0, 2, 1).contiguous().permute(0, 2, 1)
        if 'keypoints' in targets:
            keypoints = targets['keypoints']
            keypoints = keypoints.reshape(keypoints.shape[0], -1).tolist()
        num_objs = len(bboxes)
        for i in range(num_objs):
            ann = {}
            ann['image_id'] = image_id
            ann['bbox'] = bboxes[i]
            ann['category_id'] = labels[i]
            categories.add(labels[i])
            ann['area'] = areas[i]
            ann['iscrowd'] = iscrowd[i]
            ann['id'] = ann_id
            if 'masks' in targets:
                ann["segmentation"] = coco_mask.encode(masks[i].numpy())
            if 'keypoints' in targets:
                ann['keypoints'] = keypoints[i]
                ann['num_keypoints'] = sum(k != 0 for k in keypoints[i][2::3])
            dataset['annotations'].append(ann)
            ann_id += 1
    dataset['categories'] = [{'id': i} for i in sorted(categories)]
    coco_ds.dataset = dataset
    coco_ds.createIndex()
    return coco_ds
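The converted object can then be evaluated with pycocotools; a hedged sketch, assuming detection results have already been written in COCO result format:

from pycocotools.cocoeval import COCOeval

coco_gt = convert_to_coco_api(val_dataset)     # val_dataset is assumed to follow the dict layout above
coco_dt = coco_gt.loadRes("predictions.json")  # hypothetical results file
coco_eval = COCOeval(coco_gt, coco_dt, iouType="bbox")
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()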
Example #20
def _save_response_content(
        response_gen: Iterator[bytes],
        destination: str,  # type: ignore[name-defined]
) -> None:
    with open(destination, "wb") as f:
        pbar = tqdm(total=None)
        progress = 0

        for chunk in response_gen:
            if chunk:  # filter out keep-alive new chunks
                f.write(chunk)
                progress += len(chunk)
                pbar.update(progress - pbar.n)
        pbar.close()
Example #21
def decode_video(f, codec: CodecInfo, output):
    # read number of coded frames
    num_frames = read_uints(f, 1)[0]

    avg_frame_dec_time = []

    with torch.no_grad():
        x_ref = None
        with tqdm(total=num_frames) as pbar:
            for i in range(num_frames):
                frm_dec_start = time.time()

                if i == 0:
                    strings, shape = read_body(f)
                    x_out = codec.net.decode_keyframe(strings, shape)
                else:
                    mstrings, mshape = read_body(f)
                    rstrings, rshape = read_body(f)
                    inter_strings = {"motion": mstrings, "residual": rstrings}
                    inter_shapes = {"motion": mshape, "residual": rshape}

                    x_out = codec.net.decode_inter(x_ref, inter_strings,
                                                   inter_shapes)

                x_ref = x_out.clamp(0, 1)

                avg_frame_dec_time.append((time.time() - frm_dec_start))

                x_hat = crop(x_out, codec.original_size)
                img = torch2img(x_hat)

                if output is not None:
                    if Path(output).suffix == ".yuv":
                        rec = convert_rgb_yuv420(x_hat)
                        wopt = "wb" if i == 0 else "ab"
                        with Path(output).open(wopt) as fout:
                            write_frame(fout, rec, codec.original_bitdepth)
                    else:
                        img.save(output)

                pbar.update(1)

    return {"img": img, "avg_frm_dec_time": np.mean(avg_frame_dec_time)}
def eval_model_entropy_estimation(net: nn.Module, sequence: Path) -> Dict[str, Any]:
    org_seq = RawVideoSequence.from_file(str(sequence))

    if org_seq.format != VideoFormat.YUV420:
        raise NotImplementedError(f"Unsupported video format: {org_seq.format}")

    device = next(net.parameters()).device
    num_frames = len(org_seq)
    max_val = 2**org_seq.bitdepth - 1
    results = defaultdict(list)
    print(f" encoding {sequence.stem}", file=sys.stderr)
    with tqdm(total=num_frames) as pbar:
        for i in range(num_frames):
            x_cur = convert_yuv420_to_rgb(org_seq[i], device, max_val)
            x_cur, padding = pad(x_cur)

            if i == 0:
                x_rec, likelihoods = net.forward_keyframe(x_cur)  # type:ignore
            else:
                x_rec, likelihoods = net.forward_inter(x_cur, x_rec)  # type:ignore

            x_rec = x_rec.clamp(0, 1)

            metrics = compute_metrics_for_frame(
                org_seq[i],
                crop(x_rec, padding),
                device,
                max_val,
            )
            metrics["bitrate"] = estimate_bits_frame(likelihoods)

            for k, v in metrics.items():
                results[k].append(v)
            pbar.update(1)

    seq_results: Dict[str, Any] = {
        k: torch.mean(torch.stack(v)) for k, v in results.items()
    }
    seq_results["bitrate"] = float(seq_results["bitrate"]) * org_seq.framerate / 1000
    for k, v in seq_results.items():
        if isinstance(v, torch.Tensor):
            seq_results[k] = v.item()
    return seq_results
Example #23
def test_anomaly_detection(opt,
                           generator,
                           discriminator,
                           encoder,
                           dataloader,
                           device,
                           kappa=1.0):
    generator.load_state_dict(torch.load("results/generator"))
    discriminator.load_state_dict(torch.load("results/discriminator"))
    encoder.load_state_dict(torch.load("results/encoder"))

    generator.to(device).eval()
    discriminator.to(device).eval()
    encoder.to(device).eval()

    criterion = nn.MSELoss()

    with open("results/score.csv", "w") as f:
        f.write("label,img_distance,anomaly_score,z_distance\n")

    for (img, label) in tqdm(dataloader):

        real_img = img.to(device)

        real_z = encoder(real_img)
        fake_img = generator(real_z)
        fake_z = encoder(fake_img)

        real_feature = discriminator.forward_features(real_img)
        fake_feature = discriminator.forward_features(fake_img)

        # Scores for anomaly detection
        img_distance = criterion(fake_img, real_img)
        loss_feature = criterion(fake_feature, real_feature)
        anomaly_score = img_distance + kappa * loss_feature

        z_distance = criterion(fake_z, real_z)

        with open("results/score.csv", "a") as f:
            f.write(f"{label.item()},{img_distance},"
                    f"{anomaly_score},{z_distance}\n")
def train(train_loader, model, criterion, optimizer, scheduler):
    total_images = len(train_loader)

    ## set model to train
    model.train(True)

    ## initialize values
    losses = RunningAverage("train loss")
    top1 = RunningAverage("train top1")
    top5 = RunningAverage("train top5")

    for (inputs, labels) in tqdm(train_loader):
        if use_cuda:
            inputs = inputs.cuda()
            labels = labels.cuda()

        ## zero the parameter gradients
        optimizer.zero_grad()

        ## forward
        outputs = model(inputs)

        loss = criterion(outputs, labels)

        batch_size = inputs.size(0)
        (acc1, acc5) = accuracy(outputs.data.cpu(),
                                labels.data.cpu(),
                                topk=(1, 5))

        ## propagate loss backward
        loss.backward()
        optimizer.step()
        scheduler.step()

        ## statistics
        losses.update(loss.item(), batch_size)
        top1.update(acc1[0], batch_size)
        top5.update(acc5[0], batch_size)

    return (losses.avg, top1.avg, top5.avg)
def evaluate(val_loader, model, criterion):
    model.eval()

    total_images = len(val_loader)

    ## initialize values
    losses = RunningAverage("val loss")
    top1 = RunningAverage("val top1")
    top5 = RunningAverage("val top5")

    ## confusion matrix = during training
    correct = 0
    targets, preds = [], []

    with torch.no_grad():
        for (images, target) in tqdm(val_loader):
            if use_cuda:
                images = images.cuda()
                target = target.cuda()

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output.data.cpu(),
                                  target.data.cpu(),
                                  topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            ## build confusion matrix
            pred = output.max(
                1, keepdim=True)[1]  # get the index of the max log-probability
            targets += list(target.cpu().numpy())
            preds += list(pred.cpu().numpy())
            confusion_mtx = sm.confusion_matrix(targets, preds)

    return (losses.avg, top1.avg, top5.avg, confusion_mtx)
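A hedged sketch of how `train` and `evaluate` are typically driven per epoch, assuming the model, criterion, optimizer, scheduler and data loaders are already built:

num_epochs = 10  # placeholder
for epoch in range(num_epochs):
    train_loss, train_top1, train_top5 = train(train_loader, model, criterion, optimizer, scheduler)
    val_loss, val_top1, val_top5, confusion = evaluate(val_loader, model, criterion)
    print(f"epoch {epoch}: train top1 {train_top1}, val top1 {val_top1}")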
Example #26
def stream_url(url: str,
               start_byte: Optional[int] = None,
               block_size: int = 32 * 1024,
               progress_bar: bool = True) -> Iterable:
    """Stream url by chunk

    Args:
        url (str): Url.
        start_byte (int, optional): Start streaming at that point (Default: ``None``).
        block_size (int, optional): Size of chunks to stream (Default: ``32 * 1024``).
        progress_bar (bool, optional): Display a progress bar (Default: ``True``).
    """

    # If we already have the whole file, there is no need to download it again
    req = urllib.request.Request(url, method="HEAD")
    with urllib.request.urlopen(req) as response:
        url_size = int(response.info().get("Content-Length", -1))
    if url_size == start_byte:
        return

    req = urllib.request.Request(url)
    if start_byte:
        req.headers["Range"] = "bytes={}-".format(start_byte)

    with urllib.request.urlopen(req) as upointer, tqdm(
            unit="B",
            unit_scale=True,
            unit_divisor=1024,
            total=url_size,
            disable=not progress_bar,
    ) as pbar:

        num_bytes = 0
        while True:
            chunk = upointer.read(block_size)
            if not chunk:
                break
            yield chunk
            num_bytes += len(chunk)
            pbar.update(len(chunk))
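A minimal usage sketch showing a resumable download (URL and output filename are placeholders):

import os

dst = "dataset.tar.gz"
start = os.path.getsize(dst) if os.path.exists(dst) else None
with open(dst, "ab" if start else "wb") as f:
    for chunk in stream_url("https://example.com/dataset.tar.gz", start_byte=start):
        f.write(chunk)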
Example #27
def _compute_aspect_ratios_slow(dataset, indices=None):
    logger.info('Your dataset doesn\'t support the fast path for '
                'computing the aspect ratios, so will iterate over '
                'the full dataset and load every image instead. '
                'This might take some time...')
    if indices is None:
        indices = range(len(dataset))

    sampler = _SubsetSampler(indices)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        sampler=sampler,
        num_workers=14,  # you might want to increase it for faster processing
        collate_fn=lambda x: x[0])
    aspect_ratios = []
    with tqdm(total=len(dataset)) as pbar:
        for _i, tuple_item in enumerate(data_loader):
            img = tuple_item[0]
            pbar.update(1)
            height, width = img.shape[-2:]
            aspect_ratio = float(width) / float(height)
            aspect_ratios.append(aspect_ratio)
    return aspect_ratios
criterion = torch.nn.CrossEntropyLoss()

model.eval()

total_images = len(dataloader)

## initialize values
losses = RunningAverage("test loss")
top1 = RunningAverage("test top1")
top5 = RunningAverage("test top5")

## confusion matrix = during training
correct = 0
targets, preds = [], []
with torch.no_grad():
    for (images, target) in tqdm(dataloader):
        if use_cuda:
            images = images.cuda()
            target = target.cuda()

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output.data.cpu(),
                              target.data.cpu(),
                              topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))
def eval_model(
    net: nn.Module, sequence: Path, binpath: Path, keep_binaries: bool = False
) -> Dict[str, Any]:
    org_seq = RawVideoSequence.from_file(str(sequence))

    if org_seq.format != VideoFormat.YUV420:
        raise NotImplementedError(f"Unsupported video format: {org_seq.format}")

    device = next(net.parameters()).device
    num_frames = len(org_seq)
    max_val = 2**org_seq.bitdepth - 1
    results = defaultdict(list)

    f = binpath.open("wb")

    print(f" encoding {sequence.stem}", file=sys.stderr)
    # write original image size
    write_uints(f, (org_seq.height, org_seq.width))
    # write original bitdepth
    write_uchars(f, (org_seq.bitdepth,))
    # write number of coded frames
    write_uints(f, (num_frames,))
    with tqdm(total=num_frames) as pbar:
        for i in range(num_frames):
            x_cur = convert_yuv420_to_rgb(org_seq[i], device, max_val)
            x_cur, padding = pad(x_cur)

            if i == 0:
                x_rec, enc_info = net.encode_keyframe(x_cur)
                write_body(f, enc_info["shape"], enc_info["strings"])
                # x_rec = net.decode_keyframe(enc_info["strings"], enc_info["shape"])
            else:
                x_rec, enc_info = net.encode_inter(x_cur, x_rec)
                for shape, out in zip(
                    enc_info["shape"].items(), enc_info["strings"].items()
                ):
                    write_body(f, shape[1], out[1])
                # x_rec = net.decode_inter(x_rec, enc_info["strings"], enc_info["shape"])

            x_rec = x_rec.clamp(0, 1)
            metrics = compute_metrics_for_frame(
                org_seq[i],
                crop(x_rec, padding),
                device,
                max_val,
            )

            for k, v in metrics.items():
                results[k].append(v)
            pbar.update(1)
    f.close()

    seq_results: Dict[str, Any] = {
        k: torch.mean(torch.stack(v)) for k, v in results.items()
    }

    seq_results["bitrate"] = (
        float(filesize(binpath)) * 8 * org_seq.framerate / (num_frames * 1000)
    )

    if not keep_binaries:
        binpath.unlink()

    for k, v in seq_results.items():
        if isinstance(v, torch.Tensor):
            seq_results[k] = v.item()
    return seq_results
Example #30
def encode_video(input, codec: CodecInfo, output):
    if Path(input).suffix != ".yuv":
        raise NotImplementedError(
            f"Unsupported video file extension: {Path(input).suffix}")

    # encode frames of YUV sequence only
    org_seq = RawVideoSequence.from_file(input)
    bitdepth = org_seq.bitdepth
    max_val = 2**bitdepth - 1
    if org_seq.format != VideoFormat.YUV420:
        raise NotImplementedError(
            f"Unsupported video format: {org_seq.format}")

    num_frames = codec.codec_header[2]
    if num_frames < 0:
        num_frames = org_seq.total_frms

    avg_frame_enc_time = []

    f = Path(output).open("wb")
    with torch.no_grad():
        # Write Video Header
        write_uchars(f, codec.codec_header[0:2])
        # write original image size
        write_uints(f, (org_seq.height, org_seq.width))
        # write original bitdepth
        write_uchars(f, (bitdepth, ))
        # write number of coded frames
        write_uints(f, (num_frames, ))

        x_ref = None
        with tqdm(total=num_frames) as pbar:
            for i in range(num_frames):
                frm_enc_start = time.time()

                x_cur = convert_yuv420_rgb(org_seq[i], codec.device, max_val)
                h, w = x_cur.size(2), x_cur.size(3)
                p = 128  # maximum 7 strides of 2
                x_cur = pad(x_cur, p)

                if i == 0:
                    x_out, out_info = codec.net.encode_keyframe(x_cur)
                    write_body(f, out_info["shape"], out_info["strings"])
                else:
                    x_out, out_info = codec.net.encode_inter(x_cur, x_ref)
                    for shape, out in zip(out_info["shape"].items(),
                                          out_info["strings"].items()):
                        write_body(f, shape[1], out[1])

                x_ref = x_out.clamp(0, 1)

                avg_frame_enc_time.append((time.time() - frm_enc_start))

                pbar.update(1)

        org_seq.close()
    f.close()

    size = filesize(output)
    bpp = float(size) * 8 / (h * w * num_frames)

    return {"bpp": bpp, "avg_frm_enc_time": np.mean(avg_frame_enc_time)}