def plot_confusion_matrix(ax, confusion_matrix, class_names=None, normalize=False,
                          title=None, cmap=plt.cm.Oranges):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    based on: https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html
    """
    if not title:
        if normalize:
            title = "Normalized class distribution"
        else:
            title = "class distribution, without normalization"
    if confusion_matrix.shape[0] != confusion_matrix.shape[1]:
        log.warn("number of found classes does not match the number of predictions")
    # the confusion matrix is already computed by the caller
    cm = confusion_matrix
    # Only use the labels that appear in the data
    if class_names is None:
        classes = [str(i) for i in range(confusion_matrix.shape[0])]
    else:
        classes = class_names
    if normalize:
        cm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print("Confusion matrix, without normalization")
    im = ax.imshow(cm, interpolation="nearest", cmap=cmap)
    ax.figure.colorbar(im, ax=ax)
    # We want to show all ticks...
    ax.set(
        xticks=np.arange(cm.shape[1]),
        yticks=np.arange(cm.shape[0]),
        # ... and label them with the respective list entries
        xticklabels=classes,
        yticklabels=classes,
        title=title,
        ylabel="True label",
        xlabel="Predicted label",
    )
    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
    # Loop over data dimensions and create text annotations.
    # fmt = '.2f' if normalize else 'd'
    fmt = ".2f"
    thresh = cm.max() / 2.0
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], fmt),
                    ha="center", va="center",
                    color="white" if cm[i, j] > thresh else "black")
    return ax
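# Usage sketch (hypothetical data; assumes the matplotlib/numpy imports of this
# module and a confusion matrix computed elsewhere, e.g. with
# sklearn.metrics.confusion_matrix):
#
#   cm = np.array([[13, 2], [4, 11]])
#   fig, ax = plt.subplots()
#   plot_confusion_matrix(ax, cm, class_names=["reach", "grasp"], normalize=True)
#   fig.tight_layout()
#   fig.savefig("confusion_matrix.png")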
def get_skill_dataloader(dir_vids, num_views, batch_size, use_cuda, img_size,
                         filter_func, label_funcs, num_domain_frames=1, stride=1):
    # sampler with random frames from all tasks
    transformer_train = get_train_transformer(img_size=img_size)
    # sample different views
    transformed_dataset_train_domain = DoubleViewPairDataset(
        vid_dir=dir_vids,
        number_views=num_views,
        # std_similar_frame_margin_distribution=sim_frames,
        transform_frames=transformer_train,
        lable_funcs=label_funcs,
        filter_func=filter_func)
    sampler = None
    drop_last = True
    log.info('transformed_dataset_train_domain len: {}'.format(
        len(transformed_dataset_train_domain)))
    if num_domain_frames > 1:
        assert batch_size % num_domain_frames == 0, 'wrong batch size for multi frames'
        sampler = SkillViewPairSequenceSampler(
            dataset=transformed_dataset_train_domain,
            stride=stride,
            allow_same_frames_in_seq=True,
            sequence_length=num_domain_frames,
            sequences_per_vid_in_batch=1,
            batch_size=batch_size)
        log.info('use multi frame dir {} len sampler: {}'.format(
            dir_vids, len(sampler)))
        drop_last = len(sampler) >= batch_size
    # randomly sample videos
    dataloader_train_domain = DataLoader(
        transformed_dataset_train_domain,
        drop_last=drop_last,
        batch_size=batch_size,
        shuffle=sampler is None,
        num_workers=4,
        sampler=sampler,
        pin_memory=use_cuda)
    if sampler is not None and len(sampler) <= batch_size:
        log.warn("sampler length {} is not larger than the batch size {}".format(
            len(sampler), batch_size))
    return dataloader_train_domain
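# Usage sketch (hypothetical directory and filter; label_funcs depends on the
# dataset and is passed through to DoubleViewPairDataset unchanged):
#
#   loader = get_skill_dataloader(
#       dir_vids="~/data/multiview-vids",
#       num_views=2,
#       batch_size=32,
#       use_cuda=torch.cuda.is_available(),
#       img_size=299,
#       filter_func=lambda comm_name, vid_lens: True,  # keep every view pair
#       label_funcs=None,
#       num_domain_frames=1)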
def tb_write_class_dist(writer, class_predictions, data_cnt, step_writer, task_names, plot_name):
    if 0 not in class_predictions:
        log.warn("class distribution does not start with label zero, missing classes in the dataloader!")
    cm = [d / data_cnt for c, d in sorted(class_predictions.items())]
    cm = np.vstack(cm)
    fig, ax = plt.subplots()
    if task_names is None:
        class_names = [str(c) for c in class_predictions.keys()]
    else:
        class_names = task_names
    print("class_names: {}".format(class_names))
    plot_confusion_matrix(ax, cm, class_names=class_names)
    fig.tight_layout()
    writer.add_figure(plot_name, fig, step_writer)
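# Usage sketch (hypothetical counts; assumes a tensorboardX SummaryWriter and that
# class_predictions maps each true label to its per-class prediction counts, which
# is what the np.vstack above implies):
#
#   writer = SummaryWriter(log_dir="runs/example")
#   class_predictions = {0: np.array([90, 10]), 1: np.array([20, 80])}
#   tb_write_class_dist(writer, class_predictions, data_cnt=100, step_writer=0,
#                       task_names=["task A", "task B"], plot_name="val/class_dist")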
def _filter_view_pairs(video_paths_pair, frames_length_pair, filter_func):
    filtered_paths = []
    filtered_vid_len = []
    all_comm_names = _get_all_comm_view_pair_names(video_paths_pair)
    for vp, comm_name, len_views in zip(video_paths_pair, all_comm_names, frames_length_pair):
        if filter_func(str(comm_name), list(len_views)):
            assert len(vp) < 2 or vp[0] != vp[1], "view pair with same video: {}".format(vp)
            filtered_paths.append(vp)
            filtered_vid_len.append(len_views)
    if len(video_paths_pair) != len(filtered_paths):
        log.warn('dataset filtered videos from {} to {}'.format(
            len(video_paths_pair), len(filtered_paths)))
    else:
        log.warn("no videos filtered, but the filter function is not None")
    assert len(filtered_paths) > 0
    return filtered_paths, filtered_vid_len
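# Sketch of a filter_func as used above (hypothetical criterion: keep only view
# pairs whose common name contains "task1" and whose views have the same length):
#
#   def example_filter(comm_name, view_lens):
#       return "task1" in comm_name and len(set(view_lens)) == 1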
def __init__(self, vid_path, resize_shape=None, dtype=np.uint8, to_rgb=True, torch_transformer=None):
    self._vid_path = os.path.expanduser(vid_path)
    assert os.path.isfile(self._vid_path), "vid does not exist: {}".format(
        self._vid_path)
    self.fps = None
    self.to_rgb = to_rgb
    self.approximate_frameCount = None
    self.frameWidth = None
    self.frameHeight = None
    self.resize_shape = resize_shape  # cv2 resize
    self.torch_transformer = torch_transformer  # torchvision transform
    if self.torch_transformer is not None:
        assert dtype == np.uint8 and self.resize_shape is None
        if not to_rgb:
            log.warn("BGR image used with a torch transformer")
    self.dtype = dtype
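# Usage sketch, assuming this constructor belongs to the VideoFrameSampler class
# used in get_frame() below (hypothetical video path):
#
#   sampler = VideoFrameSampler("~/data/vids/task1_view0.mp4", resize_shape=(299, 299))
#   first_frame = sampler.get_frame(0)  # RGB uint8 array, resized with cv2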
def get_frame(vid_file, frame, use_image_if_exists):
    """ load frames from a vid file, or load from image files if they exist (set in .csv) """
    _, tail = os.path.split(vid_file)
    vid_file_comm, view_num_i, _, _ = split_view_file_name(tail)
    csv_file_dir = os.path.dirname(vid_file)
    csv_file = get_video_csv_file(csv_file_dir, vid_file_comm)
    if csv_file is not None and use_image_if_exists:
        # read the image file listed in the csv
        try:
            key_image = "image_file_view_{}".format(view_num_i)
            image_file = get_state_labled_frame(csv_file, frame, key=key_image)
            image_file = os.path.join(csv_file_dir, image_file)
            image_file = os.path.abspath(image_file)
            # check for a failed read before the BGR->RGB flip
            img_bgr = cv2.imread(image_file)
            if img_bgr is None:
                raise ValueError("file not found from csv {}".format(image_file))
            rgb_unit8 = img_bgr[..., ::-1]
            # load image file
            return vid_file_comm, rgb_unit8
        except KeyError:
            log.warn("no image key in csv for {}".format(vid_file))
    vid = VideoFrameSampler(vid_file)
    return vid_file_comm, vid.get_frame(frame)
def get_all(self):
    tmp = self.get_frame(0)
    # TODO CAP_PROP_FRAME_COUNT is an approximation
    if not isinstance(tmp, torch.Tensor):
        all_frames = np.empty((len(self), ) + tmp.shape, self.dtype)
        for i, rgb in enumerate(self):
            all_frames[i, :, :, :] = rgb
    else:
        all_frames = torch.zeros((len(self), ) + tmp.size(), dtype=tmp.dtype)
        try:
            backend = "ffmpeg" if skvideo._HAS_FFMPEG else "libav"
            if not self.to_rgb:
                # converting to BGR is not supported here TODO
                raise ValueError()
            vid = skvideo.io.vread(self._vid_path, backend=backend)
        except ValueError:
            log.warn("skvideo failed, falling back to cv2")
            vid = self
        for i, rgb in enumerate(vid):
            if self.torch_transformer is not None:
                rgb = self.torch_transformer(rgb)
            all_frames[i, :, :, :] = rgb
    return all_frames
def visualize_embeddings(
    func_model_forward,
    data_loader,
    summary_writer=None,
    global_step=0,
    seq_len=None,
    stride=None,
    label_func=None,
    save_dir=None,
    tag="",
    emb_size=32,
):
    """visualize embeddings with tensorboardX

    Args:
        summary_writer(tensorboardX.SummaryWriter):
        data_loader(ViewPairDataset): with shuffle false
        label_func: function to label a frame: input is
            (vid_file_comm, frame_idx=None, vid_len=None, csv_file=None, state_label=None)
    Returns:
        None

    :param func_model_forward:
    :param global_step:
    :param seq_len:
    :param stride:
    :param save_dir:
    :param tag:
    :param emb_size:
    """
    assert isinstance(
        data_loader.dataset,
        ViewPairDataset), "dataset must be of type ViewPairDataset"
    data_len = len(data_loader.dataset)
    vid_dir = data_loader.dataset.vid_dir
    if seq_len:
        assert stride is not None
        # cut off first frames
        data_len -= seq_len * stride * len(data_loader.dataset.video_paths)
    embeddings = np.empty((data_len, emb_size))
    img_size = 50  # image size to plot
    frames = torch.empty((data_len, 3, img_size, img_size))
    # transform the image to plot it later
    trans = transforms.Compose([
        transforms.ToPILImage(),  # expects rgb, moves channel to front
        transforms.Resize(img_size),
        transforms.ToTensor(),  # image 0-255 to 0. - 1.0
    ])
    cnt_data = 0
    labels = []
    view_pair_name_labels = []
    labels_frame_idx = []
    vid_len_frame_idx = []
    with tqdm(total=len(data_loader),
              desc="computing embeddings for {} frames".format(
                  len(data_loader))) as pbar:
        for i, data in enumerate(data_loader):
            # compute the emb for a batch
            frames_batch = data["frame"]
            if seq_len is None:
                emb = func_model_forward(frames_batch)
                # add emb to dict and to queue if all frames
                # for e, name, view, last in zip(emb, data["common name"], data["view"].numpy(), data['is last frame'].numpy()):
                # transform all frames to a smaller image to plt later
                for e, frame in zip(emb, frames_batch):
                    embeddings[cnt_data] = e
                    # transform only possible for one img at a time
                    frames[cnt_data] = trans(frame).cpu()
                    cnt_data += 1
                    if data_len == cnt_data:
                        break
                state_label = data.get("state lable", None)
                comm_name = data["common name"]
                frame_idx = data["frame index"]
                vid_len = data["video len"]
                labels_frame_idx.extend(frame_idx.numpy())
                vid_len_frame_idx.extend(vid_len.numpy())
                if label_func is not None:
                    state_label = len(comm_name) * [
                        None
                    ] if state_label is None else state_label
                    state_label = [
                        label_func(c, i, v_len, get_video_csv_file(vid_dir, c), la)
                        for c, la, i, v_len in zip(comm_name, state_label, frame_idx, vid_len)
                    ]
                else:
                    state_label = comm_name
                labels.extend(state_label)
                view_pair_name_labels.extend(comm_name)
                if data_len == cnt_data:
                    break
            else:
                raise NotImplementedError()
            pbar.update(1)
    log.info("number of found labels: {}".format(len(labels)))
    if len(labels) != len(embeddings):
        # in case of the rnn seq cut off at the end, or in case of drop last
        log.warn(
            "number of labels {} smaller than embeddings, changing embeddings size"
            .format(len(labels)))
        embeddings = embeddings[:len(labels)]
        frames = frames[:len(labels)]
    if len(labels) == 0:
        log.warn("length of labels is zero!")
    else:
        log.info("start TSNE fit")
        labels = labels[:data_len]
        imgs = flip_imgs(frames.numpy(), rgb_to_front=False)
        rnn_tag = "_seq{}_stride{}".format(
            seq_len, stride) if seq_len is not None else ""
        X_tsne = TSNE_multi(n_jobs=4, perplexity=40).fit_transform(
            embeddings)  # perplexity = 40, theta=0.5
        create_time_vid(X_tsne, labels_frame_idx, vid_len_frame_idx)
        plot_embedding(
            X_tsne,
            labels,
            title=tag + "multi-t-sne_perplexity40_theta0.5_step" +
            str(global_step) + rnn_tag,
            imgs=imgs,
            save_dir=save_dir,
            frame_lable=labels_frame_idx,
            max_frame=vid_len_frame_idx,
            vid_lable=view_pair_name_labels,
        )
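# Sketch of a label_func with the signature described in the visualize_embeddings
# docstring (hypothetical labelling: bucket each frame into an early/late phase of
# its video, falling back to the csv state label when present):
#
#   def example_label_func(vid_file_comm, frame_idx=None, vid_len=None,
#                          csv_file=None, state_label=None):
#       if state_label is not None:
#           return state_label
#       return "early" if frame_idx < vid_len // 2 else "late"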
def main():
    args = get_args()
    args.out_dir = os.path.expanduser(args.out_dir)
    ports = list(map(int, args.ports.split(",")))
    log.info("ports: {}".format(ports))
    sample_events = [multiprocessing.Event() for _ in ports]
    num_frames = args.max_frame
    if args.display:
        disp_q = multiprocessing.Queue()
        p = Process(target=display_worker, args=(disp_q, ), daemon=True)
        p.start()
    # process to save images as a file
    im_data_q, im_file_q = multiprocessing.Queue(), multiprocessing.Queue()
    img_folder = os.path.join(args.out_dir, "images", args.set_name, args.tag)
    vid_folder = os.path.join(args.out_dir, "videos", args.set_name)
    img_args = (ports, img_folder, args.tag, im_data_q, im_file_q)
    p = Process(target=save_img_worker, args=img_args, daemon=True)
    p.start()
    log.info("img_folder: {}".format(img_folder))
    log.info("vid_folder: {}".format(vid_folder))
    log.info("fps: {}".format(args.fps))
    try:
        time_prev = time.time()
        # loop to sample frames with events
        for frame_cnt, port_data in enumerate(
                sample_frames(ports, sample_events, num_frames)):
            sample_time_dt = time.time() - time_prev
            if frame_cnt % 10 == 0:
                log.info("frame {} time_prev: {}".format(
                    frame_cnt, time.time() - time_prev))
            time_prev = time.time()
            # set events to trigger cams
            for e in sample_events:
                e.set()
            if frame_cnt == 0:
                # skip first frame because it is not synchronized with the event
                log.info("START: {}".format(frame_cnt))
                continue
            elif (sample_time_dt - 1.0 / args.fps) > 0.1:
                log.warn("sampling a frame takes too long for the requested fps")
            # check the sample time difference between cameras
            if len(ports) > 1:
                dt = [
                    np.abs(p1["time"] - p2["time"])
                    for p1, p2 in combinations(port_data.values(), 2)
                ]
                # log.info('dt: {}'.format(np.mean(dt)))
                if np.max(dt) > 0.1:
                    log.warn(
                        "camera sample max time dt: {}, check light condition and camera models"
                        .format(np.max(dt)))
            assert all(frame_cnt == d["num"]
                       for d in port_data.values()), "out of sync"
            im_data_q.put(port_data)
            if args.display:
                disp_q.put(port_data)
            time.sleep(1.0 / args.fps)
    except KeyboardInterrupt:
        # create vids from the images saved before
        im_shape = {p: d["frame"].shape for p, d in port_data.items()}
        img_files = defaultdict(list)
        for d in get_all_queue_result(im_file_q):
            for p, f in d.items():
                img_files[p].append(f)
        # TODO start a process for each view and join them
        for view_i, p in enumerate(port_data.keys()):
            save_vid_worker(img_files[p], view_i, vid_folder, args.tag,
                            im_shape[p], args.fps)
        cv2.destroyAllWindows()