import json
import os
import random
import statistics
from glob import glob

import av
import torch
from torchvision.datasets.video_utils import VideoClips
from torchvision.datasets.vision import VisionDataset
from tqdm import tqdm


class Dataset(VisionDataset):
    def __init__(self, datapath, annotations_path, transforms,
                 cached_all_train_data_name='cached_all_train_data.pt',
                 cached_valid_train_data_name='cached_valid_train_data.pt',
                 cached_all_val_data_name='cached_all_val_data.pt',
                 cached_valid_val_data_name='cached_valid_val_data.pt',
                 get_video_wise=False, val=False, fps=None,
                 frames_per_clip=None, step_between_clips=None, start_id=0):
        self.get_video_wise = get_video_wise
        self.start_id = start_id
        self.transforms = transforms

        # Load annotations (called fails_data in the original file)
        with open(annotations_path) as f:
            self.annotations = json.load(f)

        # Clip parameters
        if fps is None:
            fps = 16
        if frames_per_clip is None:
            frames_per_clip = fps
        if step_between_clips is None:
            step_between_clips = int(fps * 0.25)  # fps x seconds = frames
        else:
            step_between_clips = int(fps * step_between_clips)  # fps x seconds = frames

        # Training data
        if not val:
            if os.path.exists(os.path.join(datapath, 'train', cached_valid_train_data_name)):
                self.video_clips = torch.load(os.path.join(datapath, 'train', cached_valid_train_data_name))
                print('\nLoaded Valid train data from cache...')
            else:
                # Load all train videos
                all_video_list = glob(os.path.join(datapath, 'train', '**', '*.mp4'), recursive=True)
                if os.path.exists(os.path.join(datapath, 'train', cached_all_train_data_name)):
                    self.all_video_clips = torch.load(os.path.join(datapath, 'train', cached_all_train_data_name))
                    print('\nLoaded all train data from cache...')
                else:
                    print('\nProcessing all train data...')
                    self.all_video_clips = VideoClips(all_video_list, frames_per_clip, step_between_clips, fps)
                    torch.save(self.all_video_clips, os.path.join(datapath, 'train', cached_all_train_data_name))

                # Separate out all valid videos
                print('\nSEPARATING VALID VIDEOS... VAL=', val)
                valid_video_paths = []
                print('Computing all clips...')
                self.all_video_clips.compute_clips(frames_per_clip, step_between_clips, fps)
                for video_idx, vid_clips in tqdm(enumerate(self.all_video_clips.clips),
                                                 total=len(self.all_video_clips.clips)):
                    video_path = self.all_video_clips.video_paths[video_idx]
                    video_key = os.path.splitext(os.path.basename(video_path))[0]
                    # Ignore the video if no annotation exists for it
                    if video_key not in self.annotations:
                        continue
                    # Ignore the video on a moov-atom (decoding) error
                    try:
                        # Ignore the video if its attributes do not qualify
                        t_unit = av.open(video_path, metadata_errors='ignore').streams[0].time_base
                        t_fail = sorted(self.annotations[video_key]['t'])
                        t_fail = t_fail[len(t_fail) // 2]
                        if t_fail < 0 or \
                                not 0.01 <= statistics.median(self.annotations[video_key]['rel_t']) <= 0.99 or \
                                self.annotations[video_key]['len'] < 3.2 or \
                                self.annotations[video_key]['len'] > 30:
                            continue
                    except:
                        continue
                    # If none of the above applies, keep the video path
                    valid_video_paths.append(video_path)

                self.video_clips = VideoClips(valid_video_paths, frames_per_clip, step_between_clips, fps)
                torch.save(self.video_clips, os.path.join(datapath, 'train', cached_valid_train_data_name))
                print('Saved valid train data in cache.')

        # Validation / test data
        else:
            if os.path.exists(os.path.join(datapath, 'val', cached_valid_val_data_name)):
                self.video_clips = torch.load(os.path.join(datapath, 'val', cached_valid_val_data_name))
                print('\nLoaded Valid Val data from cache...')
            else:
                # Load all val videos
                all_video_list = glob(os.path.join(datapath, 'val', '**', '*.mp4'), recursive=True)
                if os.path.exists(os.path.join(datapath, 'val', cached_all_val_data_name)):
                    self.all_video_clips = torch.load(os.path.join(datapath, 'val', cached_all_val_data_name))
                    print('\nLoaded all val data from cache...')
                else:
                    print('\nProcessing all val data...')
                    self.all_video_clips = VideoClips(all_video_list, frames_per_clip, step_between_clips, fps)
                    torch.save(self.all_video_clips, os.path.join(datapath, 'val', cached_all_val_data_name))

                # Separate out all valid videos
                print('\nSEPARATING VALID VIDEOS... VAL=', val)
                valid_video_paths = []
                print('Computing all clips...')
                self.all_video_clips.compute_clips(frames_per_clip, step_between_clips, fps)
                for video_idx, vid_clips in tqdm(enumerate(self.all_video_clips.clips),
                                                 total=len(self.all_video_clips.clips)):
                    video_path = self.all_video_clips.video_paths[video_idx]
                    video_key = os.path.splitext(os.path.basename(video_path))[0]
                    # Ignore the video if no annotation exists for it
                    if video_key not in self.annotations:
                        continue
                    # Ignore the video on a moov-atom (decoding) error
                    try:
                        # Ignore the video if its attributes do not qualify
                        t_unit = av.open(video_path, metadata_errors='ignore').streams[0].time_base
                        t_fail = sorted(self.annotations[video_key]['t'])
                        t_fail = t_fail[len(t_fail) // 2]
                        if t_fail < 0 or \
                                not 0.01 <= statistics.median(self.annotations[video_key]['rel_t']) <= 0.99 or \
                                self.annotations[video_key]['len'] < 3.2 or \
                                self.annotations[video_key]['len'] > 30:
                            continue
                    except:
                        continue
                    # If a moov-atom exception occurs, ignore the clip
                    try:
                        temp = av.open(video_path, metadata_errors='ignore').streams[0].time_base
                    except:
                        continue
                    # Video attribute checks (video length, median(rel_t)) are intended to be skipped for test data
                    valid_video_paths.append(video_path)

                self.video_clips = VideoClips(valid_video_paths, frames_per_clip, step_between_clips, fps)
                torch.save(self.video_clips, os.path.join(datapath, 'val', cached_valid_val_data_name))
                print('Saved valid val data in cache.')

        # Loading borders.json: deferred for now

        # Generate all mini-clips of length frames_per_clip from the valid videos, and label them
        print('\nGenerating VALID mini-clips and labels from', len(self.video_clips.clips), 'videos... VAL=', val)
        self.video_clips.compute_clips(frames_per_clip, step_between_clips, fps)
        self.video_clips.labels = []
        for video_idx, vid_clips in tqdm(enumerate(self.video_clips.clips), total=len(self.video_clips.clips)):
            video_path = self.video_clips.video_paths[video_idx]
            video_key = os.path.splitext(os.path.basename(video_path))[0]
            t_unit = av.open(video_path, metadata_errors='ignore').streams[0].time_base
            t_fail = sorted(self.annotations[video_key]['t'])
            t_fail = t_fail[len(t_fail) // 2]  # median annotated failure time
            prev_label = 0
            first_one_idx = len(vid_clips)
            first_two_idx = len(vid_clips)
            for clip_idx, clip in enumerate(vid_clips):  # clip holds the frame timestamps (pts)
                start_pts = clip[0].item()
                end_pts = clip[-1].item()
                t_start = float(t_unit * start_pts)
                t_end = float(t_unit * end_pts)
                # Label 0: clip ends before the failure, 1: clip contains the failure, 2: clip starts after it
                label = 0
                if t_start <= t_fail <= t_end:
                    label = 1
                elif t_start > t_fail:
                    label = 2
                if label == 1 and prev_label == 0:
                    first_one_idx = clip_idx
                elif label == 2 and prev_label == 1:
                    first_two_idx = clip_idx
                    break
                prev_label = label
            self.video_clips.labels.append(
                [0 for i in range(first_one_idx)] +
                [1 for i in range(first_one_idx, first_two_idx)] +
                [2 for i in range(first_two_idx, len(vid_clips))])
            # The balance_fails_only logic of the original implementation is not reproduced here

        print('\nNumber of CLIPS generated:', self.video_clips.num_clips())

    def __len__(self):
        if self.get_video_wise:
            return len(self.video_clips.labels) - self.start_id
        else:
            return self.video_clips.num_clips()

    def __getitem__(self, idx):
        idx = self.start_id + idx
        if self.get_video_wise:
            # Return all clips of a single video; here idx is a video index
            labels = self.video_clips.labels[idx]
            num_of_clips = len(labels)
            num_of_clips_before_this_video = 0
            for l in self.video_clips.labels[:idx]:
                num_of_clips_before_this_video += len(l)
            start_clip_id = num_of_clips_before_this_video
            end_clip_id = num_of_clips_before_this_video + num_of_clips
            video = []
            for clip_id in range(start_clip_id, end_clip_id):
                clip, _, _, _ = self.video_clips.get_clip(clip_id)
                if self.transforms:
                    clip = self.transforms(clip)
                clip = clip.permute(1, 0, 2, 3)
                video.append(clip.unsqueeze(0))
            video = torch.cat(video, dim=0)
            # labels = torch.cat(labels)
            return video, labels
        else:
            video_idx, clip_idx = self.video_clips.get_clip_location(idx)
            video, audio, info, video_idx = self.video_clips.get_clip(idx)
            video_path = self.video_clips.video_paths[video_idx]
            label = self.video_clips.labels[video_idx][clip_idx]
            if self.transforms is not None:
                video = self.transforms(video)
            video = video.permute(1, 0, 2, 3)
            return video, label
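

# Usage sketch (added for illustration; not part of the original file). It shows one way the
# Dataset above might be wrapped in a standard PyTorch DataLoader. The directory layout
# ('<datapath>/train/**/*.mp4' plus a JSON annotation file), the batch size, and the clip
# settings are assumptions; adapt them to your own setup.
def _example_dataset_usage(datapath, annotations_path):
    from torch.utils.data import DataLoader

    # transforms=None here; in practice you would pass a video transform that resizes/crops
    # frames to a fixed size so clips from different videos can be collated into one batch.
    train_set = Dataset(datapath, annotations_path, transforms=None, fps=16)
    loader = DataLoader(train_set, batch_size=8, shuffle=True, num_workers=4)
    # Each item is (clip_tensor, label) with label 0 = before the failure,
    # 1 = clip containing the failure, 2 = after the failure.
    clips, labels = next(iter(loader))
    return clips, labels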
class KineticsAndFails(VisionDataset):
    FLOW_FPS = 8

    def __init__(self, fails_path, kinetics_path, frames_per_clip, step_between_clips, fps,
                 transform=None, extensions=('.mp4',), video_clips=None, fails_only=False, val=False,
                 balance_fails_only=False, get_clip_times=False, fails_video_list=None, fns_to_remove=None,
                 load_flow=False, flow_histogram=False, fails_flow_path=None, all_fail_videos=False,
                 selfsup_loss=None, clip_interval_factor=None, labeled_fails=True, debug_dataset=False,
                 anticipate_label=0, data_proportion=1, **kwargs):
        self.clip_len = frames_per_clip / fps      # clip length in seconds
        self.clip_step = step_between_clips / fps  # stride between clips in seconds
        self.clip_interval_factor = clip_interval_factor
        self.fps = fps
        self.t = transform
        self.load_flow = load_flow
        self.flow_histogram = flow_histogram
        self.video_clips = None
        self.fails_path = fails_path
        self.fails_flow_path = fails_flow_path
        self.selfsup_loss = selfsup_loss
        self.get_clip_times = get_clip_times
        self.anticipate_label = anticipate_label
        data_proportion = 1 if val else data_proportion

        if video_clips:
            self.video_clips = video_clips
        else:
            assert fails_path is None or fails_video_list is None
            video_list = fails_video_list or glob(os.path.join(fails_path, '**', '*.mp4'), recursive=True)
            if not fails_only:
                # Mix in Kinetics videos according to the precomputed class distribution
                kinetics_cls = torch.load("PATH/TO/kinetics_classes.pt")
                kinetics_dist = torch.load("PATH/TO/dist.pt")
                s = len(video_list)
                for i, n in kinetics_dist.items():
                    n *= s
                    video_list += sorted(glob(os.path.join(kinetics_path, '**', kinetics_cls[i], '*.mp4'),
                                              recursive=True))[:round(n)]
            self.video_clips = VideoClips(video_list, frames_per_clip, step_between_clips, fps)

        with open("PATH/TO/borders.json") as f:
            self.fails_borders = json.load(f)
        with open("PATH/TO/all_mturk_data.json") as f:
            self.fails_data = json.load(f)

        self.fails_only = fails_only
        self.t_from_clip_idx = lambda idx: ((step_between_clips * idx) / fps,
                                            (step_between_clips * idx + frames_per_clip) / fps)
        if not balance_fails_only:  # no support for recomputing clips after the balance calculation yet
            self.video_clips.compute_clips(frames_per_clip, step_between_clips, fps)

        if video_clips is None and fails_only and labeled_fails:
            # if True:
            if not all_fail_videos:
                # Keep only videos that have annotations
                idxs = []
                for i, video_path in enumerate(self.video_clips.video_paths):
                    video_path = os.path.splitext(os.path.basename(video_path))[0]
                    if video_path in self.fails_data:
                        idxs.append(i)
                self.video_clips = self.video_clips.subset(idxs)
            # if not val and balance_fails_only:  # balance dataset
            #     ratios = {0: 0.3764, 1: 0.0989, 2: 0.5247}
            self.video_clips.labels = []
            self.video_clips.compute_clips(frames_per_clip, step_between_clips, fps)
            for video_idx, vid_clips in tqdm(enumerate(self.video_clips.clips),
                                             total=len(self.video_clips.clips)):
                video_path = self.video_clips.video_paths[video_idx]
                if all_fail_videos and \
                        os.path.splitext(os.path.basename(video_path))[0] not in self.fails_data:
                    self.video_clips.labels.append([-1 for _ in vid_clips])
                    continue
                t_unit = av.open(video_path, metadata_errors='ignore').streams[0].time_base
                t_fail = sorted(self.fails_data[os.path.splitext(os.path.basename(video_path))[0]]['t'])
                t_fail = t_fail[len(t_fail) // 2]  # median annotated failure time
                # Drop videos with bad annotations or out-of-range lengths
                if t_fail < 0 or not 0.01 <= statistics.median(
                        self.fails_data[os.path.splitext(os.path.basename(video_path))[0]]['rel_t']) <= 0.99 or \
                        self.fails_data[os.path.splitext(os.path.basename(video_path))[0]]['len'] < 3.2 or \
                        self.fails_data[os.path.splitext(os.path.basename(video_path))[0]]['len'] > 30:
                    self.video_clips.clips[video_idx] = torch.Tensor()
                    self.video_clips.resampling_idxs[video_idx] = torch.Tensor()
                    self.video_clips.labels.append([])
                    continue
                # Label each clip: 0 = before the failure, 1 = contains the failure, 2 = after the failure
                prev_label = 0
                first_one_idx = len(vid_clips)
                first_two_idx = len(vid_clips)
                for clip_idx, clip in enumerate(vid_clips):
                    start_pts = clip[0].item()
                    end_pts = clip[-1].item()
                    t_start = float(t_unit * start_pts)
                    t_end = float(t_unit * end_pts)
                    label = 0
                    if t_start <= t_fail <= t_end:
                        label = 1
                    elif t_start > t_fail:
                        label = 2
                    if label == 1 and prev_label == 0:
                        first_one_idx = clip_idx
                    elif label == 2 and prev_label == 1:
                        first_two_idx = clip_idx
                        break
                    prev_label = label
                self.video_clips.labels.append(
                    [0 for i in range(first_one_idx)] +
                    [1 for i in range(first_one_idx, first_two_idx)] +
                    [2 for i in range(first_two_idx, len(vid_clips))])
                if balance_fails_only and not val:
                    # Oversample clips of the minority labels so the per-video label counts are roughly balanced
                    balance_idxs = []
                    counts = (first_one_idx, first_two_idx - first_one_idx, len(vid_clips) - first_two_idx)
                    offsets = torch.LongTensor([0] + list(counts)).cumsum(0)[:-1].tolist()
                    ratios = (1, 0.93, 1 / 0.93)
                    labels = (0, 1, 2)
                    lbl_mode = max(labels, key=lambda i: counts[i])
                    for i in labels:
                        if i != lbl_mode and counts[i] > 0:
                            n_to_add = round(counts[i] * ((counts[lbl_mode] * ratios[i] / counts[i]) - 1))
                            tmp = list(range(offsets[i], counts[i] + offsets[i]))
                            random.shuffle(tmp)
                            tmp_bal_idxs = []
                            while len(tmp_bal_idxs) < n_to_add:
                                tmp_bal_idxs += tmp
                            tmp_bal_idxs = tmp_bal_idxs[:n_to_add]
                            balance_idxs += tmp_bal_idxs
                    if not balance_idxs:
                        continue
                    t = torch.cat((vid_clips, torch.stack([vid_clips[i] for i in balance_idxs])))
                    self.video_clips.clips[video_idx] = t
                    vid_resampling_idxs = self.video_clips.resampling_idxs[video_idx]
                    try:
                        t = torch.cat((vid_resampling_idxs,
                                       torch.stack([vid_resampling_idxs[i] for i in balance_idxs])))
                        self.video_clips.resampling_idxs[video_idx] = t
                    except IndexError:
                        pass
                    self.video_clips.labels[-1] += [self.video_clips.labels[-1][i] for i in balance_idxs]

            clip_lengths = torch.as_tensor([len(v) for v in self.video_clips.clips])
            self.video_clips.cumulative_sizes = clip_lengths.cumsum(0).tolist()

            fns_removed = 0
            if fns_to_remove and not val:
                # Remove at most a quarter of the videos whose filenames appear in fns_to_remove
                for i, video_path in enumerate(self.video_clips.video_paths):
                    if fns_removed > len(self.video_clips.video_paths) // 4:
                        break
                    video_path = os.path.splitext(os.path.basename(video_path))[0]
                    if video_path in fns_to_remove:
                        fns_removed += 1
                        self.video_clips.clips[i] = torch.Tensor()
                        self.video_clips.resampling_idxs[i] = torch.Tensor()
                        self.video_clips.labels[i] = []
                clip_lengths = torch.as_tensor([len(v) for v in self.video_clips.clips])
                self.video_clips.cumulative_sizes = clip_lengths.cumsum(0).tolist()
                if kwargs['local_rank'] <= 0:
                    print(f'removed videos from {fns_removed} out of {len(self.video_clips.video_paths)} files')

        # if not fails_path.startswith("PATH/TO/scenes"):
        for i, p in enumerate(self.video_clips.video_paths):
            self.video_clips.video_paths[i] = p.replace("PATH/TO/scenes", os.path.dirname(fails_path))

        self.debug_dataset = debug_dataset
        if debug_dataset:
            # self.video_clips = self.video_clips.subset([0])
            pass

        if data_proportion < 1:
            # Deterministically subsample a fraction of the videos
            rng = random.Random()
            rng.seed(23719)
            lbls = self.video_clips.labels
            subset_idxs = rng.sample(range(len(self.video_clips.video_paths)),
                                     int(len(self.video_clips.video_paths) * data_proportion))
            self.video_clips = self.video_clips.subset(subset_idxs)
            self.video_clips.labels = [lbls[i] for i in subset_idxs]

    def trim_borders(self, img, fn):
        l, r = self.fails_borders[os.path.splitext(os.path.basename(fn))[0]]
        w = img.shape[2]  # THWC
        if l > 0 and r > 0:
            img = img[:, :, round(w * l):round(w * r)]
        return img

    def __len__(self):
        return self.video_clips.num_clips()

    def compute_clip_times(self, video_idx, clip_idx):
        video_path = self.video_clips.video_paths[video_idx]
        video_path = os.path.join(self.fails_path,
                                  os.path.sep.join(video_path.rsplit(os.path.sep, 2)[-2:]))
        clip_pts = self.video_clips.clips[video_idx][clip_idx]
        start_pts = clip_pts[0].item()
        end_pts = clip_pts[-1].item()
        t_unit = av.open(video_path, metadata_errors='ignore').streams[0].time_base
        t_start = float(t_unit * start_pts)
        t_end = float(t_unit * end_pts)
        return t_start, t_end

    def __getitem__(self, idx):
        if self.load_flow:
            # Load precomputed optical-flow frames (.flo tensors) for this clip
            video_idx, clip_idx = self.video_clips.get_clip_location(idx)
            video_path = self.video_clips.video_paths[video_idx]
            video_path = os.path.join(self.fails_path,
                                      os.path.sep.join(video_path.rsplit(os.path.sep, 2)[-2:]))
            label = self.video_clips.labels[video_idx][clip_idx]
            flow_path = os.path.join(
                self.fails_flow_path,
                os.path.sep.join(os.path.splitext(video_path)[0].rsplit(os.path.sep, 2)[-2:]))
            t_start, t_end = self.compute_clip_times(video_idx, clip_idx)
            frame_start = round(t_start * self.FLOW_FPS)
            n_frames = round(self.clip_len * self.FLOW_FPS)
            flow = []
            for frame_i in range(frame_start, frame_start + n_frames):
                frame_fn = os.path.join(flow_path, f'{frame_i:06}.flo')
                try:
                    flow.append(torch.load(frame_fn,
                                           map_location=torch.device('cpu')).permute(1, 2, 0).data.numpy())
                except:
                    pass
            while len(flow) < n_frames:  # pad by repetition if some frames are missing
                flow += flow
            flow = flow[:n_frames]
            flow = torch.Tensor(flow)
            flow = self.trim_borders(flow, video_path)
            if self.t is not None:
                flow = self.t(flow)
            return flow, label, (flow_path, t_start, t_end)
        else:
            video_idx, clip_idx = self.video_clips.get_clip_location(idx)
            if self.anticipate_label:
                # Return an earlier clip while keeping the current clip's label, so the model
                # must anticipate the label from frames anticipate_label seconds in the past
                assert not self.selfsup_loss, 'no anticipation with self supervision'
                video_path = self.video_clips.video_paths[video_idx]
                label = self.video_clips.labels[video_idx][clip_idx]
                idx -= round(self.anticipate_label / self.clip_step)
                new_video_idx, new_clip_idx = self.video_clips.get_clip_location(idx)
                video, *_ = self.video_clips.get_clip(idx)
                video = self.trim_borders(video, video_path)
                if self.t is not None:
                    video = self.t(video)
                new_t_start, new_t_end = self.compute_clip_times(new_video_idx, new_clip_idx)
                old_t_start, old_t_end = self.compute_clip_times(video_idx, clip_idx)
                if new_video_idx != video_idx or new_t_start > old_t_start:
                    label = -1
                return video, label, (video_path, new_t_start, new_t_end, [])

            video, audio, info, video_idx = self.video_clips.get_clip(idx)
            video_path = self.video_clips.video_paths[video_idx]
            # print(video_path)
            try:
                label = self.video_clips.labels[video_idx][clip_idx]
                # if self.anticipate_label:
                #     video_path = self.video_clips.video_paths[video_idx]
                #     t_fail = statistics.median(self.fails_data[os.path.splitext(os.path.basename(video_path))[0]]['t'])
                #     t_start, t_end = self.compute_clip_times(video_idx, clip_idx)
                #     t_start += self.anticipate_label
                #     t_end += self.anticipate_label
                #     label = 0
                #     if t_start <= t_fail <= t_end:
                #         label = 1
                #     elif t_start > t_fail:
                #         label = 2
            except:
                label = -1
            if label == 0 or self.fails_only:
                video = self.trim_borders(video, video_path)
            if self.debug_dataset:
                pass
                # video[:] = 0
                # video[..., 0] = 255
            if self.t is not None:
                video = self.t(video)
            t_start = t_end = -1
            if self.get_clip_times:
                t_start, t_end = self.compute_clip_times(video_idx, clip_idx)
            other = []
            if self.selfsup_loss == 'pred_middle' or self.selfsup_loss == 'sort' or self.selfsup_loss == 'ctc':
                # Fetch the clips k steps before and after the current one for the self-supervised losses
                k = round(self.clip_len / self.clip_step * self.clip_interval_factor)
                video_l = [video]
                try:
                    pvideo, paudio, pinfo, pvideo_idx = self.video_clips.get_clip(idx - k)
                except:
                    pvideo_idx = -1
                try:
                    nvideo, naudio, ninfo, nvideo_idx = self.video_clips.get_clip(idx + k)
                except:
                    nvideo_idx = -1
                t_start, _ = self.compute_clip_times(*self.video_clips.get_clip_location(idx))
                try:
                    p_t_start, _ = self.compute_clip_times(*self.video_clips.get_clip_location(idx - k))
                except:
                    p_t_start = 1000000000
                try:
                    n_t_start, _ = self.compute_clip_times(*self.video_clips.get_clip_location(idx + k))
                except:
                    n_t_start = -1000000000
                # if pvideo_idx == video_idx:
                #     assert p_t_start < t_start, f"{t_start} <= prev video time {p_t_start}"
                # if nvideo_idx == video_idx:
                #     assert t_start < n_t_start, f"{t_start} >= next video time {n_t_start}"
                if pvideo_idx == video_idx and p_t_start < t_start:
                    pvideo = self.trim_borders(pvideo, video_path)
                    if self.t is not None:
                        pvideo = self.t(pvideo)
                    video_l.insert(0, pvideo)
                else:
                    video_l.insert(0, torch.full_like(video, -1))
                if nvideo_idx == video_idx and t_start < n_t_start:
                    nvideo = self.trim_borders(nvideo, video_path)
                    if self.t is not None:
                        nvideo = self.t(nvideo)
                    video_l.append(nvideo)
                else:
                    video_l.append(torch.full_like(video, -1))
                video_l = torch.stack(video_l)
                video = video_l
                other = [nvideo_idx == video_idx and pvideo_idx == video_idx]
            if self.selfsup_loss == 'fps':
                other = [self.fps]
            other.append(idx)
            return video, label, (video_path, t_start, t_end, *other)
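

# Usage sketch (added for illustration; not part of the original file). It instantiates the
# labeled, fails-only configuration of KineticsAndFails. The clip settings below (16 fps,
# 16-frame clips, 4-frame stride) are assumptions, and the hard-coded "PATH/TO/..." annotation
# files in __init__ are assumed to have been replaced with real paths; kinetics_path is unused
# when fails_only=True.
def _example_kinetics_and_fails_usage(fails_path, transform=None):
    dataset = KineticsAndFails(
        fails_path=fails_path,
        kinetics_path=None,
        frames_per_clip=16,
        step_between_clips=4,
        fps=16,
        transform=transform,
        fails_only=True,
        labeled_fails=True,
        val=False,
        get_clip_times=True,
        local_rank=-1,  # only read from **kwargs when fns_to_remove is given
    )
    # Each item is (video, label, (video_path, t_start, t_end, idx)); label is 0/1/2 as above,
    # or -1 when no label is available.
    video, label, (video_path, t_start, t_end, *rest) = dataset[0]
    return video, label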