Example #1
def plot_confusion_matrix(ax, confusion_matrix, class_names=None, normalize=False, title=None, cmap=plt.cm.Oranges):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    based on : https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html
    """
    if not title:
        if normalize:
            title = "Normalized class distribustion"
        else:
            title = "class distribustion, without normalization"
    if confusion_matrix.shape[0] != confusion_matrix.shape[1]:
        log.warn("foun classes not same as predictions")
    # Compute confusion matrix
    cm = confusion_matrix
    # Only use the labels that appear in the data
    if class_names is None:
        classes = [str(i) for i in range(confusion_matrix.shape[0])]
    else:
        classes = class_names
    if normalize:
        cm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print("Confusion matrix, without normalization")

    im = ax.imshow(cm, interpolation="nearest", cmap=cmap)
    ax.figure.colorbar(im, ax=ax)
    # We want to show all ticks...
    ax.set(
        xticks=np.arange(cm.shape[1]),
        yticks=np.arange(cm.shape[0]),
        # ... and label them with the respective list entries
        xticklabels=classes,
        yticklabels=classes,
        title=title,
        ylabel="True label",
        xlabel="Predicted label",
    )

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # Loop over data dimensions and create text annotations.
    # fmt = '.2f' if normalize else 'd'
    fmt = ".2f"
    thresh = cm.max() / 2.0
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(
                j, i, format(cm[i, j], fmt), ha="center", va="center", color="white" if cm[i, j] > thresh else "black"
            )
    return ax
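
A minimal usage sketch for the function above, using a made-up 3x3 confusion matrix (only numpy and matplotlib are required):

import matplotlib.pyplot as plt
import numpy as np

# hypothetical values; rows are true labels, columns are predicted labels
cm = np.array([[13, 1, 0],
               [2, 10, 3],
               [0, 2, 11]])
fig, ax = plt.subplots()
plot_confusion_matrix(ax, cm, class_names=["cat", "dog", "bird"], normalize=True)
fig.tight_layout()
plt.show()
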
def get_skill_dataloader(dir_vids,
                         num_views,
                         batch_size,
                         use_cuda,
                         img_size,
                         filter_func,
                         label_funcs,
                         num_domain_frames=1,
                         stride=1):
    # sampler with random frames from all tasks
    transformer_train = get_train_transformer(img_size=img_size)
    # sample different views
    transformed_dataset_train_domain = DoubleViewPairDataset(
        vid_dir=dir_vids,
        number_views=num_views,
        # std_similar_frame_margin_distribution=sim_frames,
        transform_frames=transformer_train,
        lable_funcs=label_funcs,
        filter_func=filter_func)

    sampler = None
    drop_last = True
    log.info('transformed_dataset_train_domain len: {}'.format(
        len(transformed_dataset_train_domain)))
    if num_domain_frames > 1:
        assert batch_size % num_domain_frames == 0, 'batch size must be divisible by num_domain_frames'
        sampler = SkillViewPairSequenceSampler(
            dataset=transformed_dataset_train_domain,
            stride=stride,
            allow_same_frames_in_seq=True,
            sequence_length=num_domain_frames,
            sequences_per_vid_in_batch=1,
            batch_size=batch_size)
        log.info('using multi-frame sampler for dir {}, sampler len: {}'.format(
            dir_vids, len(sampler)))
        drop_last = len(sampler) >= batch_size

    # randomly sample videos
    dataloader_train_domain = DataLoader(
        transformed_dataset_train_domain,
        drop_last=drop_last,
        batch_size=batch_size,
        shuffle=sampler is None,
        num_workers=4,
        sampler=sampler,
        pin_memory=use_cuda)

    if sampler is not None and len(sampler) <= batch_size:
        log.warn("dataset sampler batch size")
    return dataloader_train_domain
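
A hypothetical call sketch for get_skill_dataloader; the video directory, filter function, and label functions below are placeholders for the project-specific ones:

import torch

use_cuda = torch.cuda.is_available()
dataloader = get_skill_dataloader(
    dir_vids="~/data/multiview-skills",            # placeholder path
    num_views=2,
    batch_size=32,                                 # must be divisible by num_domain_frames
    use_cuda=use_cuda,
    img_size=128,
    filter_func=lambda comm_name, vid_lens: True,  # keep every view pair
    label_funcs=None,                              # placeholder for project-specific label functions
    num_domain_frames=4,
    stride=1)
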
Example #3
def tb_write_class_dist(writer, class_predictions, data_cnt, step_writer, task_names, plot_name):
    if 0 not in class_predictions:
        log.warn("distri class not starts with label zero, missing classe in dataloader!")
    cm = [d / data_cnt for c, d in sorted(class_predictions.items())]
    cm = np.vstack(cm)
    fig, ax = plt.subplots()
    if task_names is None:
        class_names = [str(c) for c in class_predictions.keys()]
    else:
        class_names = task_names
        print("class_names: {}".format(class_names))

    plot_confusion_matrix(ax, cm, class_names=class_names)
    fig.tight_layout()
    writer.add_figure(plot_name, fig, step_writer)
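
A minimal usage sketch (all values made up): class_predictions maps each true class index to a vector of prediction counts, and data_cnt normalizes the rows before plotting:

import numpy as np
from tensorboardX import SummaryWriter

writer = SummaryWriter(log_dir="runs/class_dist")    # hypothetical log dir
class_predictions = {0: np.array([80, 20]),          # true class 0
                     1: np.array([10, 90])}          # true class 1
tb_write_class_dist(writer, class_predictions, data_cnt=100,
                    step_writer=0, task_names=["task A", "task B"],
                    plot_name="class_distribution")
writer.close()
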
def _filter_view_pairs(video_paths_pair, frames_length_pair, filter_func):
    filtered_paths = []
    filtered_vid_len = []
    all_comm_names = _get_all_comm_view_pair_names(video_paths_pair)
    for vp, comm_name, len_views in zip(video_paths_pair, all_comm_names,
                                        frames_length_pair):
        if filter_func(str(comm_name), list(len_views)):
            assert len(vp) < 2 or vp[0] != vp[
                1], "view pair with same video: {}".format(vp)
            filtered_paths.append(vp)
            filtered_vid_len.append(len_views)
    if len(video_paths_pair) != len(filtered_paths):
        log.warn('dataset filtered videos from {} to {}'.format(
            len(video_paths_pair), len(filtered_paths)))
    else:
        log.warn("no videos filtered, but filter function is not not None")
    assert len(filtered_paths) > 0
    return filtered_paths, filtered_vid_len
Example #5
    def __init__(self,
                 vid_path,
                 resize_shape=None,
                 dtype=np.uint8,
                 to_rgb=True,
                 torch_transformer=None):
        self._vid_path = os.path.expanduser(vid_path)
        assert os.path.isfile(self._vid_path), "vid does not exist: {}".format(
            self._vid_path)
        self.fps = None
        self.to_rgb = to_rgb
        self.approximate_frameCount = None
        self.frameWidth = None
        self.frameHeight = None
        self.resize_shape = resize_shape  # cv2 resize
        self.torch_transformer = torch_transformer  # torch vision
        if self.torch_transformer is not None:
            assert dtype == np.uint8 and self.resize_shape is None
            if not to_rgb:
                log.warn("BGR images used with a torch transformer")
        self.dtype = dtype
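
A hypothetical construction sketch, assuming this __init__ belongs to the VideoFrameSampler class used in get_frame() below; the video path is a placeholder:

vid = VideoFrameSampler("~/data/task_0_view_0.mp4",  # placeholder path
                        resize_shape=(299, 299),     # cv2 resize target
                        to_rgb=True)
first_frame = vid.get_frame(0)
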
def get_frame(vid_file, frame, use_image_if_exists):
    """ load frames form vid file or load from images if exists (set in .csv) """
    _, tail = os.path.split(vid_file)
    vid_file_comm, view_num_i, _, _ = split_view_file_name(tail)
    csv_file_dir = os.path.dirname(vid_file)
    csv_file = get_video_csv_file(csv_file_dir, vid_file_comm)
    if csv_file is not None:
        # read image file from csv
        try:
            key_image = "image_file_view_{}".format(view_num_i)
            image_file = get_state_labled_frame(csv_file, frame, key=key_image)
            image_file = os.path.join(csv_file_dir, image_file)
            image_file = os.path.abspath(image_file)
            bgr_img = cv2.imread(image_file)
            if bgr_img is None:
                raise ValueError(
                    "image file from csv not found: {}".format(image_file))
            # load image file, convert BGR to RGB
            rgb_uint8 = bgr_img[..., ::-1]
            return vid_file_comm, rgb_uint8
        except KeyError:
            log.warn("no image key in csv for {}".format(vid_file))
    vid = VideoFrameSampler(vid_file)
    return vid_file_comm, vid.get_frame(frame)
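
A hypothetical usage sketch; the file name is a placeholder that split_view_file_name() is assumed to be able to parse:

vid_file = os.path.expanduser("~/data/task_7_view_0.mp4")  # placeholder path
comm_name, rgb_frame = get_frame(vid_file, frame=25, use_image_if_exists=True)
print(comm_name, rgb_frame.shape)
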
Example #7
    def get_all(self):
        tmp = self.get_frame(0)
        # TODO CAP_PROP_FRAME_COUNT is only an approximation
        if not isinstance(tmp, torch.Tensor):
            all_frames = np.empty((len(self), ) + tmp.shape, self.dtype)
            for i, rgb in enumerate(self):
                all_frames[i, :, :, :] = rgb
        else:
            all_frames = torch.zeros((len(self), ) + tmp.size(),
                                     dtype=tmp.dtype)
            try:
                backend = "ffmpeg" if skvideo._HAS_FFMPEG else "libav"
                if not self.to_rgb:  # converting to BGR is not supported here TODO
                    raise ValueError()
                vid = skvideo.io.vread(self._vid_path, backend=backend)
            except ValueError as e:
                log.warn("skvideo failed, falling back to cv2")
                vid = self
            for i, rgb in enumerate(vid):
                if self.torch_transformer is not None:
                    rgb = self.torch_transformer(rgb)
                all_frames[i, :, :, :] = rgb

        return all_frames
def visualize_embeddings(
    func_model_forward,
    data_loader,
    summary_writer=None,
    global_step=0,
    seq_len=None,
    stride=None,
    label_func=None,
    save_dir=None,
    tag="",
    emb_size=32,
):
    """visualize embeddings with tensorboardX

    Args:
        summary_writer(tensorboardX.SummaryWriter):
        data_loader(ViewPairDataset): with shuffle false
        label_func: function to label a frame: input is (vid_file_comm,frame_idx=None,vid_len=None,csv_file=None,state_label=None)
    Returns:
        None
        :param func_model_forward:
        :param global_step:
        :param seq_len:
        :param stride:
        :param save_dir:
        :param tag:
        :param emb_size:

    """
    assert isinstance(
        data_loader.dataset,
        ViewPairDataset), "dataset must be of type ViewPairDataset"
    data_len = len(data_loader.dataset)
    vid_dir = data_loader.dataset.vid_dir

    if seq_len:
        assert stride is not None
        # cut off first frames
        data_len -= seq_len * stride * len(data_loader.dataset.video_paths)
    embeddings = np.empty((data_len, emb_size))
    img_size = 50  # image size to plot
    frames = torch.empty((data_len, 3, img_size, img_size))
    # transform the image so it can be plotted later
    trans = transforms.Compose([
        transforms.ToPILImage(),  # expects a CHW tensor (or HWC ndarray)
        transforms.Resize(img_size),
        transforms.ToTensor(),  # image 0-255 to 0. - 1.0
    ])
    cnt_data = 0
    labels = []
    view_pair_name_labels = []
    labels_frame_idx = []
    vid_len_frame_idx = []
    with tqdm(total=len(data_loader),
              desc="computing embeddings for {} frames".format(
                  len(data_loader))) as pbar:
        for i, data in enumerate(data_loader):
            # compute the emb for a batch
            frames_batch = data["frame"]
            if seq_len is None:
                emb = func_model_forward(frames_batch)
                # add emb to dict and to queue if all frames
                # for e, name, view, last in zip(emb, data["common name"], data["view"].numpy(), data['is last frame'].numpy()):
                # transform all frames to a smaller image to plot later
                for e, frame in zip(emb, frames_batch):
                    embeddings[cnt_data] = e
                    # transform is only possible for one image at a time
                    frames[cnt_data] = trans(frame).cpu()
                    cnt_data += 1
                    if data_len == cnt_data:
                        break
                state_label = data.get("state lable", None)
                comm_name = data["common name"]
                frame_idx = data["frame index"]
                vid_len = data["video len"]
                labels_frame_idx.extend(frame_idx.numpy())
                vid_len_frame_idx.extend(vid_len.numpy())
                if label_func is not None:
                    state_label = len(comm_name) * [
                        None
                    ] if state_label is None else state_label
                    state_label = [
                        label_func(c, i, v_len, get_video_csv_file(vid_dir, c),
                                   la) for c, la, i, v_len in
                        zip(comm_name, state_label, frame_idx, vid_len)
                    ]
                else:
                    state_label = comm_name
                labels.extend(state_label)
                view_pair_name_labels.extend(comm_name)
                if data_len == cnt_data:
                    break
            else:
                raise NotImplementedError()

            pbar.update(1)

    log.info("number of found labels: {}".format(len(labels)))
    if len(labels) != len(embeddings):
        # in case of rnn seq cut off at the end, or in case of drop_last
        log.warn(
            "number of labels {} smaller than embeddings, changing embeddings size"
            .format(len(labels)))
        embeddings = embeddings[:len(labels)]
        frames = frames[:len(labels)]
    if len(labels) == 0:
        log.warn("length of labels is zero!")
    else:
        log.info("start TSNE fit")
        labels = labels[:data_len]
        imgs = flip_imgs(frames.numpy(), rgb_to_front=False)
        rnn_tag = "_seq{}_stride{}".format(
            seq_len, stride) if seq_len is not None else ""
        X_tsne = TSNE_multi(n_jobs=4, perplexity=40).fit_transform(
            embeddings)  # perplexity = 40, theta=0.5
        create_time_vid(X_tsne, labels_frame_idx, vid_len_frame_idx)
        plot_embedding(
            X_tsne,
            labels,
            title=tag + "multi-t-sne_perplexity40_theta0.5_step" +
            str(global_step) + rnn_tag,
            imgs=imgs,
            save_dir=save_dir,
            frame_lable=labels_frame_idx,
            max_frame=vid_len_frame_idx,
            vid_lable=view_pair_name_labels,
        )
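
A hypothetical call sketch for visualize_embeddings; dataset is assumed to be a ViewPairDataset created elsewhere and model an embedding network returning a (batch, emb_size) output:

from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter

loader = DataLoader(dataset, batch_size=32, shuffle=False, num_workers=4)  # dataset: ViewPairDataset (assumed)
writer = SummaryWriter(log_dir="runs/embeddings")                          # hypothetical log dir
visualize_embeddings(
    func_model_forward=lambda frames: model(frames).detach().cpu().numpy(),  # model: embedding net (assumed)
    data_loader=loader,
    summary_writer=writer,
    global_step=0,
    save_dir="plots",
    tag="val_",
    emb_size=32)
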
def main():
    args = get_args()
    args.out_dir = os.path.expanduser(args.out_dir)
    ports = list(map(int, args.ports.split(",")))
    log.info("ports: {}".format(ports))
    sample_events = [multiprocessing.Event() for _ in ports]
    num_frames = args.max_frame
    if args.display:
        disp_q = multiprocessing.Queue()
        p = Process(target=display_worker, args=(disp_q, ), daemon=True)
        p.start()
    # process to save images to files
    im_data_q, im_file_q = multiprocessing.Queue(), multiprocessing.Queue()
    img_folder = os.path.join(args.out_dir, "images", args.set_name, args.tag)
    vid_folder = os.path.join(args.out_dir, "videos", args.set_name)
    img_args = (ports, img_folder, args.tag, im_data_q, im_file_q)
    p = Process(target=save_img_worker, args=img_args, daemon=True)
    p.start()

    log.info("img_folder: {}".format(img_folder))
    log.info("vid_folder: {}".format(vid_folder))
    log.info("fps: {}".format(args.fps))

    try:
        time_prev = time.time()
        # loop to sample frames with events
        for frame_cnt, port_data in enumerate(
                sample_frames(ports, sample_events, num_frames)):
            sample_time_dt = time.time() - time_prev
            if frame_cnt % 10 == 0:
                log.info("frame {} time_prev: {}".format(
                    frame_cnt,
                    time.time() - time_prev))

            time_prev = time.time()
            # set events to trigger cams
            for e in sample_events:
                e.set()

            if frame_cnt == 0:
                # skip the first frame because it is not synchronized with the event
                log.info("START: {}".format(frame_cnt))
                continue
            elif (sample_time_dt - 1.0 / args.fps) > 0.1:
                log.warn("sampling frame taks too long for fps")
            # check sample time diff
            if len(ports) > 1:
                dt = [
                    np.abs(p1["time"] - p2["time"])
                    for p1, p2 in combinations(port_data.values(), 2)
                ]
                # log.info('dt: {}'.format(np.mean(dt)))
                if np.max(dt) > 0.1:
                    log.warn(
                        "camera sample max time dt: {}, check light condition and camera models"
                        .format(np.max(dt)))
            assert all(frame_cnt == d["num"]
                       for d in port_data.values()), "out of sync"

            im_data_q.put(port_data)
            if args.display:
                disp_q.put(port_data)

            time.sleep(1.0 / args.fps)
    except KeyboardInterrupt:
        # create vids from the images saved before
        im_shape = {p: d["frame"].shape for p, d in port_data.items()}
        img_files = defaultdict(list)
        for d in get_all_queue_result(im_file_q):
            for p, f in d.items():
                img_files[p].append(f)
        # TODO start a process for each view and join
        for view_i, p in enumerate(port_data.keys()):
            save_vid_worker(img_files[p], view_i, vid_folder, args.tag,
                            im_shape[p], args.fps)

    cv2.destroyAllWindows()
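
When the example above is run as a standalone script, the usual entry-point guard would call main():

if __name__ == "__main__":
    main()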