def test_video_clips(self):
        with get_list_of_videos(num_videos=3) as video_list:
            video_clips = VideoClips(video_list, 5, 5)
            self.assertEqual(video_clips.num_clips(), 1 + 2 + 3)
            for i, (v_idx, c_idx) in enumerate([(0, 0), (1, 0), (1, 1), (2, 0),
                                                (2, 1), (2, 2)]):
                video_idx, clip_idx = video_clips.get_clip_location(i)
                self.assertEqual(video_idx, v_idx)
                self.assertEqual(clip_idx, c_idx)

            video_clips = VideoClips(video_list, 6, 6)
            self.assertEqual(video_clips.num_clips(), 0 + 1 + 2)
            for i, (v_idx, c_idx) in enumerate([(1, 0), (2, 0), (2, 1)]):
                video_idx, clip_idx = video_clips.get_clip_location(i)
                self.assertEqual(video_idx, v_idx)
                self.assertEqual(clip_idx, c_idx)

            video_clips = VideoClips(video_list, 6, 1)
            self.assertEqual(video_clips.num_clips(),
                             0 + (10 - 6 + 1) + (15 - 6 + 1))
            for i, v_idx, c_idx in [(0, 1, 0), (4, 1, 4), (5, 2, 0),
                                    (6, 2, 1)]:
                video_idx, clip_idx = video_clips.get_clip_location(i)
                self.assertEqual(video_idx, v_idx)
                self.assertEqual(clip_idx, c_idx)
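# A minimal sketch of the index arithmetic the assertions above exercise, assuming
# a VideoClips instance like the one built in the test: get_clip_location() maps a
# flat clip index to (video_idx, clip_idx) by bisecting the cumulative per-video
# clip counts stored in cumulative_sizes.
import bisect

def clip_location(video_clips, idx):
    # cumulative_sizes[i] holds the total number of clips in videos 0..i
    video_idx = bisect.bisect_right(video_clips.cumulative_sizes, idx)
    clip_idx = idx if video_idx == 0 else idx - video_clips.cumulative_sizes[video_idx - 1]
    return video_idx, clip_idx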
Example #2
class Mice(VisionDataset):
    def __init__(self,
                 root,
                 frames_per_clip,
                 step_between_clips=1,
                 frame_rate=None,
                 extensions=("mp4", ),
                 transform=None,
                 _precomputed_metadata=None,
                 num_workers=1,
                 _video_width=0,
                 _video_height=0,
                 _video_min_dimension=0,
                 _audio_samples=0,
                 _audio_channels=0):
        super(Mice, self).__init__(root)
        classes = list(sorted(list_dir(root)))
        class_to_idx = {classes[i]: i for i in range(len(classes))}
        self.samples = make_dataset(self.root,
                                    class_to_idx,
                                    extensions,
                                    is_valid_file=None)
        self.classes = classes
        video_list = [x[0] for x in self.samples]

        self.video_clips = VideoClips(
            video_list,
            frames_per_clip,
            step_between_clips,
            frame_rate,
            _precomputed_metadata,
            num_workers=num_workers,
            _video_width=_video_width,
            _video_height=_video_height,
            _video_min_dimension=_video_min_dimension,
            _audio_samples=_audio_samples,
            _audio_channels=_audio_channels,
        )
        self.transform = transform

    @property
    def metadata(self):
        return self.video_clips.metadata

    def __len__(self):
        return self.video_clips.num_clips()

    def __getitem__(self, idx):
        video, _, _, video_idx = self.video_clips.get_clip(idx)
        video_idx, clip_idx = self.video_clips.get_clip_location(idx)
        label = self.samples[video_idx][1]

        if self.transform is not None:
            video = self.transform(video)

        return video, label, video_idx, clip_idx
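# A usage sketch for the Mice dataset above, assuming a root directory of mp4 files
# organized into one sub-folder per class (the path "data/mice" is an assumption);
# the DataLoader batches the (video, label, video_idx, clip_idx) tuples returned by
# __getitem__.
from torch.utils.data import DataLoader

dataset = Mice("data/mice", frames_per_clip=16, step_between_clips=8, num_workers=4)
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)
for video, label, video_idx, clip_idx in loader:
    print(video.shape, label)
    break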
Example #3
def DownsampleClipSampler(video_clips: VideoClips, labels: List[int]):
    vc_labels = [
        labels[video_clips.get_clip_location(idx)[0]]
        for idx in range(video_clips.num_clips())
    ]
    cnt = min(vc_labels.count(a) for a in set(labels))
    indices = []
    for a in set(labels):
        indices += random.sample(
            [i for i, c in enumerate(vc_labels) if c == a], cnt)
    return SubsetRandomSampler(indices)
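# A usage sketch, assuming `dataset` wraps the same VideoClips object and `labels`
# holds one integer label per video (both names are assumptions): the function
# above returns a SubsetRandomSampler, which replaces shuffle=True in the
# DataLoader so each label contributes an equally sized pool of clips per epoch.
from torch.utils.data import DataLoader

sampler = DownsampleClipSampler(dataset.video_clips, labels)
loader = DataLoader(dataset, batch_size=8, sampler=sampler, num_workers=2)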
Example #4
def BalancedClipSampler(video_clips: VideoClips,
                        clip_labels: List[int],
                        num_samples=None,
                        log_weight=False):
    assert len(video_clips.clips) == len(clip_labels)
    vc_labels = [
        clip_labels[video_clips.get_clip_location(idx)[0]]
        for idx in range(video_clips.num_clips())
    ]
    if num_samples is None:
        num_samples = len(video_clips.video_paths)
    return BalancedSampler(vc_labels, num_samples, log_weight)
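# BalancedSampler itself is not shown in these snippets; a plausible stand-in,
# sketched here as an assumption, draws clip indices with replacement via torch's
# WeightedRandomSampler so that rare labels are over-sampled. The log_weight
# behaviour (softening the re-weighting) is likewise a guess based only on the
# parameter name above.
import math
from collections import Counter

import torch
from torch.utils.data import WeightedRandomSampler

def BalancedSampler(labels, num_samples, log_weight=False):
    counts = Counter(labels)
    if log_weight:
        weights = [1.0 / math.log1p(counts[lbl]) for lbl in labels]
    else:
        weights = [1.0 / counts[lbl] for lbl in labels]
    return WeightedRandomSampler(torch.as_tensor(weights, dtype=torch.double),
                                 num_samples, replacement=True)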
Example #5
def BalancedPathSampler(video_clips: VideoClips,
                        clip_labels: List[int],
                        num_samples=None,
                        log_weight=False):
    assert len(video_clips.clips) == len(clip_labels)
    vc_labels = []
    for idx in range(video_clips.num_clips()):
        vidx, _ = video_clips.get_clip_location(idx)
        vc_labels.append((clip_labels[vidx], video_clips.video_paths[vidx]))

    if num_samples is None:
        num_samples = len(video_clips.video_paths)
    return BalancedSampler(vc_labels, num_samples, log_weight)
    def test_video_clips(self, tmpdir):
        video_list = get_list_of_videos(tmpdir, num_videos=3)
        video_clips = VideoClips(video_list, 5, 5, num_workers=2)
        assert video_clips.num_clips() == 1 + 2 + 3
        for i, (v_idx, c_idx) in enumerate([(0, 0), (1, 0), (1, 1), (2, 0), (2, 1), (2, 2)]):
            video_idx, clip_idx = video_clips.get_clip_location(i)
            assert video_idx == v_idx
            assert clip_idx == c_idx

        video_clips = VideoClips(video_list, 6, 6)
        assert video_clips.num_clips() == 0 + 1 + 2
        for i, (v_idx, c_idx) in enumerate([(1, 0), (2, 0), (2, 1)]):
            video_idx, clip_idx = video_clips.get_clip_location(i)
            assert video_idx == v_idx
            assert clip_idx == c_idx

        video_clips = VideoClips(video_list, 6, 1)
        assert video_clips.num_clips() == 0 + (10 - 6 + 1) + (15 - 6 + 1)
        for i, v_idx, c_idx in [(0, 1, 0), (4, 1, 4), (5, 2, 0), (6, 2, 1)]:
            video_idx, clip_idx = video_clips.get_clip_location(i)
            assert video_idx == v_idx
            assert clip_idx == c_idx
class VideoIter(data.Dataset):
    def __init__(self,
                 clip_length,
                 frame_stride,
                 dataset_path=None,
                 annotation_path=None,
                 video_transform=None,
                 name="<NO_NAME>",
                 shuffle_list_seed=None,
                 single_load=False):
        super(VideoIter, self).__init__()
        self.dataset_path = dataset_path
        self.frames_stride = frame_stride
        self.video_transform = video_transform
        self.rng = np.random.RandomState(
            shuffle_list_seed if shuffle_list_seed else 0)

        # load video list
        if dataset_path is not None:
            self.video_list = self._get_video_list(
                dataset_path=self.dataset_path)

        elif isinstance(annotation_path, list):
            self.video_list = annotation_path
        else:
            self.video_list = [annotation_path]

        self.total_clip_length_in_frames = clip_length * frame_stride

        if single_load:
            print("loading each file at a time")
            self.video_clips = VideoClips(
                video_paths=[self.video_list[0]],
                clip_length_in_frames=self.total_clip_length_in_frames,
                frames_between_clips=self.total_clip_length_in_frames)
            with tqdm(total=len(self.video_list[1:]) + 1,
                      desc=' total % of videos loaded') as pbar1:
                for video_list_used in self.video_list[1:]:
                    print(video_list_used)
                    pbar1.update(1)
                    video_clips_out = VideoClips(
                        video_paths=[video_list_used],
                        clip_length_in_frames=self.total_clip_length_in_frames,
                        frames_between_clips=self.total_clip_length_in_frames)
                    self.video_clips.clips.append(video_clips_out.clips[0])
                    self.video_clips.cumulative_sizes.append(
                        self.video_clips.cumulative_sizes[-1] +
                        video_clips_out.cumulative_sizes[0])
                    self.video_clips.resampling_idxs.append(
                        video_clips_out.resampling_idxs[0])
                    self.video_clips.video_fps.append(
                        video_clips_out.video_fps[0])
                    self.video_clips.video_paths.append(
                        video_clips_out.video_paths[0])
                    self.video_clips.video_pts.append(
                        video_clips_out.video_pts[0])
        else:
            print("single loader used")
            self.video_clips = VideoClips(
                video_paths=self.video_list,
                clip_length_in_frames=self.total_clip_length_in_frames,
                frames_between_clips=self.total_clip_length_in_frames)

        logging.info(
            "VideoIter:: iterator initialized (phase: '{:s}', num: {:d})".
            format(name, len(self.video_list)))

    def getitem_from_raw_video(self, idx):
        # get current video info
        video, _, _, _ = self.video_clips.get_clip(idx)
        video_idx, clip_idx = self.video_clips.get_clip_location(idx)
        video_path = self.video_clips.video_paths[video_idx]
        in_clip_frames = list(
            range(0, self.total_clip_length_in_frames, self.frames_stride))
        video = video[in_clip_frames]
        if self.video_transform is not None:
            video = self.video_transform(video)

        label = 0 if "Normal" in video_path else 1

        dir, file = video_path.split(os.sep)[-2:]
        file = file.split('.')[0]

        return video, label, clip_idx, dir, file

    def __len__(self):
        return len(self.video_clips)

    def __getitem__(self, index):
        succ = False
        while not succ:
            try:
                clip_input, label, sampled_idx, dir, file = self.getitem_from_raw_video(
                    index)
                succ = True
            except Exception as e:
                index = self.rng.choice(range(0, self.__len__()))
                logging.warning(
                    "VideoIter:: ERROR!! (Force using another index:\n{})\n{}".
                    format(index, e))

        return clip_input, label, sampled_idx, dir, file

    @staticmethod
    def _get_video_list(dataset_path):
        assert os.path.exists(
            dataset_path), "VideoIter:: failed to locate: `{}'".format(
                dataset_path)
        vid_list = []
        for path, subdirs, files in os.walk(dataset_path):
            for name in files:
                vid_list.append(os.path.join(path, name))

        return vid_list
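# A usage sketch for the VideoIter above, assuming a folder of mp4 videos in which
# anomalous clips live outside a "Normal" sub-directory (that is how
# getitem_from_raw_video derives the label); the dataset path is an assumption.
from torch.utils.data import DataLoader

train_data = VideoIter(clip_length=16, frame_stride=2,
                       dataset_path="data/videos", single_load=False)
train_loader = DataLoader(train_data, batch_size=4, shuffle=True, num_workers=2)
for clip, label, clip_idx, dirname, filename in train_loader:
    print(clip.shape, label)
    break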
class Dataset(VisionDataset):
    def __init__(self, datapath, annotations_path, transforms,
                 cached_all_train_data_name='cached_all_train_data.pt',
                 cached_valid_train_data_name='cached_valid_train_data.pt',
                 cached_all_val_data_name='cached_all_val_data.pt',
                 cached_valid_val_data_name='cached_valid_val_data.pt',
                 get_video_wise=False, val=False, fps=None,
                 frames_per_clip=None, step_between_clips=None, start_id=0):

        self.get_video_wise = get_video_wise
        self.start_id = start_id
        self.transforms = transforms

        # Load annotations (called fails_data in the original file)
        with open(annotations_path) as f:
            self.annotations = json.load(f)

        #Load videos
        if fps is None:
            fps = 16
        if frames_per_clip is None:
            frames_per_clip = fps
        if step_between_clips is None:  
            step_between_clips = int(fps * 0.25)    # FPS X seconds = frames
        else:          
            step_between_clips = int(fps * step_between_clips)    # FPS X seconds = frames
        
        #For train_data
        if not val:

            if os.path.exists(os.path.join(datapath,'train',cached_valid_train_data_name)):
                self.video_clips = torch.load(os.path.join(datapath,'train',cached_valid_train_data_name)) 
                print('\nLoaded Valid train data from cache...')
            else:
                #Load all train data
                all_video_list = glob(os.path.join(datapath, 'train', '**', '*.mp4'), recursive=True)

                if os.path.exists(os.path.join(datapath,'train',cached_all_train_data_name)):
                    self.all_video_clips = torch.load(os.path.join(datapath,'train',cached_all_train_data_name))
                    print('\nLoaded all train data from cache...')
                else:
                    print('\nProcessing all train data...')
                    self.all_video_clips = VideoClips(all_video_list, frames_per_clip, step_between_clips, fps)
                    torch.save(self.all_video_clips, os.path.join(datapath,'train',cached_all_train_data_name))

                #Separate out all valid videos  
                print('\nSEPARATING VALID VIDEOS... VAL=',val)
                valid_video_paths = []
                print('Computing all clips...')
                self.all_video_clips.compute_clips(frames_per_clip, step_between_clips, fps)
                for video_idx, vid_clips in tqdm(enumerate(self.all_video_clips.clips), total=len(self.all_video_clips.clips)):
                    video_path = self.all_video_clips.video_paths[video_idx]
                    
                    # Ignore if the annotation doesn't exist
                    if os.path.splitext(os.path.basename(video_path))[0] not in self.annotations:
                        continue
                    # Skip videos that raise moov-atom (or other) decoding errors
                    try:
                        # Skip if the video attributes don't qualify
                        t_unit = av.open(video_path, metadata_errors='ignore').streams[0].time_base
                        t_fail = sorted(self.annotations[os.path.splitext(os.path.basename(video_path))[0]]['t'])
                        t_fail = t_fail[len(t_fail) // 2]
                        if t_fail < 0 or not 0.01 <= statistics.median(self.annotations[os.path.splitext(os.path.basename(video_path))[0]]['rel_t']) <= 0.99 or \
                                                    self.annotations[os.path.splitext(os.path.basename(video_path))[0]]['len'] < 3.2 or \
                                                    self.annotations[os.path.splitext(os.path.basename(video_path))[0]]['len'] > 30:
                            continue
                    except:
                        continue            
                    #If none of the above happens, then save the video path
                    valid_video_paths.append(video_path)

                self.video_clips = VideoClips(valid_video_paths, frames_per_clip, step_between_clips, fps)
                torch.save(self.video_clips, os.path.join(datapath,'train',cached_valid_train_data_name))
                print('Saved valid train data in cache.')

        #For test data
        else:        
            if os.path.exists(os.path.join(datapath,'val',cached_valid_val_data_name)):
                self.video_clips = torch.load(os.path.join(datapath,'val',cached_valid_val_data_name)) 
                print('\nLoaded Valid Val data from cache...')
            else:
                #Load all val data
                all_video_list = glob(os.path.join(datapath, 'val', '**', '*.mp4'), recursive=True)

                if os.path.exists(os.path.join(datapath,'val',cached_all_val_data_name)):
                    self.all_video_clips = torch.load(os.path.join(datapath,'val',cached_all_val_data_name))
                    print('\nLoaded all val data from cache...')
                else:
                    print('\nProcessing all val data...')
                    self.all_video_clips = VideoClips(all_video_list, frames_per_clip, step_between_clips, fps)
                    torch.save(self.all_video_clips, os.path.join(datapath,'val',cached_all_val_data_name))

                #Separate out all valid videos  
                print('\nSEPARATING VALID VIDEOS... VAL=',val)
                valid_video_paths = []
                print('Computing all clips...')
                self.all_video_clips.compute_clips(frames_per_clip, step_between_clips, fps)
                for video_idx, vid_clips in tqdm(enumerate(self.all_video_clips.clips), total=len(self.all_video_clips.clips)):
                    video_path = self.all_video_clips.video_paths[video_idx]
                    
                    # Ignore if the annotation doesn't exist
                    if os.path.splitext(os.path.basename(video_path))[0] not in self.annotations:
                        continue
                    
                    # Skip videos that raise moov-atom (or other) decoding errors
                    try:
                        # Skip if the video attributes don't qualify
                        t_unit = av.open(video_path, metadata_errors='ignore').streams[0].time_base
                        t_fail = sorted(self.annotations[os.path.splitext(os.path.basename(video_path))[0]]['t'])
                        t_fail = t_fail[len(t_fail) // 2]
                        if t_fail < 0 or not 0.01 <= statistics.median(self.annotations[os.path.splitext(os.path.basename(video_path))[0]]['rel_t']) <= 0.99 or \
                                                    self.annotations[os.path.splitext(os.path.basename(video_path))[0]]['len'] < 3.2 or \
                                                    self.annotations[os.path.splitext(os.path.basename(video_path))[0]]['len'] > 30:
                            continue
                    except:
                        continue

                    # If a moov-atom exception occurs, skip the clip
                    try:
                        temp = av.open(video_path, metadata_errors='ignore').streams[0].time_base
                    except:
                        continue

                    # Ignore video-attribute checks (video length, median rel_t) for test data
                    valid_video_paths.append(video_path)

                self.video_clips = VideoClips(valid_video_paths, frames_per_clip, step_between_clips, fps)
                torch.save(self.video_clips, os.path.join(datapath,'val',cached_valid_val_data_name))
                print('Saved valid val data in cache.')

        #Load borders.json : LATER

        #Generate all mini-clips of size frames_per_clip from all video clips
        print('\nGenerating VALID mini-clips and labels from',len(self.video_clips.clips),'videos... VAL=',val)
        self.video_clips.compute_clips(frames_per_clip, step_between_clips, fps)
        self.video_clips.labels = []
        for video_idx, vid_clips in tqdm(enumerate(self.video_clips.clips), total=len(self.video_clips.clips)):

            video_path = self.video_clips.video_paths[video_idx]
           
            t_unit = av.open(video_path, metadata_errors='ignore').streams[0].time_base
            t_fail = sorted(self.annotations[os.path.splitext(os.path.basename(video_path))[0]]['t'])
            t_fail = t_fail[len(t_fail) // 2]                
            prev_label = 0
            first_one_idx = len(vid_clips)
            first_two_idx = len(vid_clips)
            for clip_idx, clip in enumerate(vid_clips): #clip == timestamps
                start_pts = clip[0].item()
                end_pts = clip[-1].item()
                t_start = float(t_unit * start_pts)
                t_end = float(t_unit * end_pts)
                label = 0
                if t_start <= t_fail <= t_end:
                    label = 1
                elif t_start > t_fail:
                    label = 2
                if label == 1 and prev_label == 0:
                    first_one_idx = clip_idx
                elif label == 2 and prev_label == 1:
                    first_two_idx = clip_idx
                    break
                prev_label = label

            self.video_clips.labels.append(
                [0 for i in range(first_one_idx)] + [1 for i in range(first_one_idx, first_two_idx)] +
                [2 for i in range(first_two_idx, len(vid_clips))])

            # The balance_fails_only handling from the original implementation is omitted here

        print('\nNumber of CLIPS generated:', self.video_clips.num_clips())


    def __len__(self):
        if self.get_video_wise:
            return len(self.video_clips.labels) - self.start_id
        else:
            return self.video_clips.num_clips()

    def __getitem__(self, idx):
        idx = self.start_id + idx

        if self.get_video_wise:             # Return all clips of a single video

            labels = self.video_clips.labels[idx]   #here idx is video_idx
            num_of_clips = len(labels)
            
            num_of_clips_before_this_video = 0
            for l in self.video_clips.labels[:idx]:
                num_of_clips_before_this_video += len(l)

            start_clip_id = num_of_clips_before_this_video
            end_clip_id = num_of_clips_before_this_video + num_of_clips 

            video = []
            for idx in range(start_clip_id, end_clip_id):
                clip, _, _, _  = self.video_clips.get_clip(idx)
                if self.transforms:
                    clip = self.transforms(clip)
                    clip = clip.permute(1,0,2,3)
                video.append(clip.unsqueeze(0))
            video = torch.cat(video, dim=0)
            #labels = torch.cat(labels)

            return video, labels

        else:
            video_idx, clip_idx = self.video_clips.get_clip_location(idx)
            video, audio, info, video_idx = self.video_clips.get_clip(idx)
            video_path = self.video_clips.video_paths[video_idx]
            label = self.video_clips.labels[video_idx][clip_idx]

            if self.transforms is not None:
                video = self.transforms(video)

            video = video.permute(1,0,2,3)

            return video, label
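# A usage sketch for the Dataset above in video-wise mode, assuming the expected
# directory layout (datapath/{train,val}/**/*.mp4) plus an annotation json; both
# paths below are assumptions. With get_video_wise=True each item is the full
# stack of clips for one video, so batch_size stays at 1.
from torch.utils.data import DataLoader

ds = Dataset("data/fails", "data/annotations.json", transforms=None,
             get_video_wise=True, val=True)
loader = DataLoader(ds, batch_size=1, shuffle=False)
for clip_stack, clip_labels in loader:
    print(clip_stack.shape, len(clip_labels))
    break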
Example #9
class VideoDataset(data.Dataset):
    def __init__(self, opt, transforms, subset, fraction=1.):
        """file_list is a list of [/path/to/mp4 key-to-df]"""
        self.subset = subset
        self.video_info_path = opt["video_info"]
        self.mode = opt["mode"]
        self.boundary_ratio = opt['boundary_ratio']
        self.skip_videoframes = opt['skip_videoframes']
        self.num_videoframes = opt['num_videoframes']
        self.dist_videoframes = opt['dist_videoframes']
        self.fraction = fraction

        subset_translate = {'train': 'training', 'val': 'validation'}
        self.anno_df = pd.read_csv(self.video_info_path)
        print(self.anno_df)
        print(subset, subset_translate.get(subset))
        if subset != 'full':
            self.anno_df = self.anno_df[self.anno_df.subset ==
                                        subset_translate[subset]]
            print(self.anno_df)

        file_loc = opt['%s_video_file_list' % subset]
        with open(file_loc, 'r') as f:
            lines = [k.strip() for k in f.readlines()]

        file_list = [k.split(' ')[0] for k in lines]
        keys_list = [k.split(' ')[1][:-4] for k in lines]
        print(keys_list[:5])
        valid_key_indices = [num for num, k in enumerate(keys_list) \
                             if k in set(self.anno_df.video.unique())]
        self.keys_list = [keys_list[num] for num in valid_key_indices]
        self.file_list = [file_list[num] for num in valid_key_indices]
        print('Number of indices: ', len(valid_key_indices), subset)

        video_info_dir = '/'.join(self.video_info_path.split('/')[:-1])
        clip_length_in_frames = self.num_videoframes * self.skip_videoframes
        frames_between_clips = self.dist_videoframes
        saved_video_clips = os.path.join(
            video_info_dir, 'video_clips.%s.%df.%ds.pkl' %
            (subset, clip_length_in_frames, frames_between_clips))
        if os.path.exists(saved_video_clips):
            print('Path Exists for video_clips: ', saved_video_clips)
            self.video_clips = pickle.load(open(saved_video_clips, 'rb'))
        else:
            print('Path does NOT exist for video_clips: ', saved_video_clips)
            self.video_clips = VideoClips(
                self.file_list,
                clip_length_in_frames=clip_length_in_frames,
                frames_between_clips=frames_between_clips,
                frame_rate=opt['fps'])
            pickle.dump(self.video_clips, open(saved_video_clips, 'wb'))
        print('Length of vid clips: ', self.video_clips.num_clips(),
              self.subset)

        if self.mode == "train":
            self.datums = self._retrieve_valid_datums()
            self.datum_indices = list(range(len(self.datums)))
            if fraction < 1:
                print('DOING the subset dataset on %s ...' % subset)
                self._subset_dataset(fraction)
            print('Len of %s datums: ' % subset, len(self.datum_indices))

        self.transforms = transforms

    def _subset_dataset(self, fraction):
        num_datums = int(len(self.datums) * fraction)
        self.datum_indices = list(range(len(self.datums)))
        random.shuffle(self.datum_indices)
        self.datum_indices = self.datum_indices[:num_datums]
        print('These indices: ', len(self.datum_indices), num_datums,
              len(self.datums))
        print(sorted(self.datum_indices)[:10])
        print(sorted(self.datum_indices)[-10:])

    def __len__(self):
        if self.mode == 'train':
            return len(self.datum_indices)
        else:
            return self.video_clips.num_clips()

    def _retrieve_valid_datums(self):
        video_info_dir = '/'.join(self.video_info_path.split('/')[:-1])
        num_clips = self.video_clips.num_clips()
        saved_data_path = os.path.join(
            video_info_dir, 'saved.%s.nf%d.sf%d.df%d.vid%d.pkl' %
            (self.subset, self.num_videoframes, self.skip_videoframes,
             self.dist_videoframes, num_clips))
        print(saved_data_path)
        if os.path.exists(saved_data_path):
            print('Got saved data.')
            with open(saved_data_path, 'rb') as f:
                return pickle.load(f)

        ret = []
        for flat_index in range(num_clips):
            video_idx, clip_idx = self.video_clips.get_clip_location(
                flat_index)
            start_frame = clip_idx * self.dist_videoframes
            snippets = [
                start_frame + self.skip_videoframes * i
                for i in range(self.num_videoframes)
            ]
            key = self.keys_list[video_idx]
            training_anchors = self._get_training_anchors(snippets, key)
            if not training_anchors:
                continue

            anchor_xmins, anchor_xmaxs, gt_bbox = training_anchors
            ret.append((flat_index, anchor_xmins, anchor_xmaxs, gt_bbox))

        print('Size of data: ', len(ret), flush=True)
        with open(saved_data_path, 'wb') as f:
            pickle.dump(ret, f)
        print('Dumped data...')
        return ret

    def __getitem__(self, index):
        # The video_data retrieved has shape [nf * sf, w, h, c].
        # We want to pick every sf'th frame out of that.
        if self.mode == "train":
            datum_index = self.datum_indices[index]
            flat_index, anchor_xmin, anchor_xmax, gt_bbox = self.datums[
                datum_index]
        else:
            flat_index = index

        video, _, _, video_idx = self.video_clips.get_clip(flat_index)

        video_data = video[0::self.skip_videoframes]
        print('Bef transform: ', video_data, type(video_data))
        video_data = self.transforms(video_data)
        print('AFt transform: ', video_data, type(video_data))
        video_data = torch.transpose(video_data, 0, 1)

        _, clip_idx = self.video_clips.get_clip_location(flat_index)
        start_frame = clip_idx * self.dist_videoframes
        snippets = [
            start_frame + self.skip_videoframes * i
            for i in range(self.num_videoframes)
        ]
        if self.mode == "train":
            match_score_action, match_score_start, match_score_end = self._get_train_label(
                gt_bbox, anchor_xmin, anchor_xmax)
            return video_data, match_score_action, match_score_start, match_score_end
        else:
            try:
                video_name = self.keys_list[video_idx]
            except Exception as e:
                print('Whoops: VideoReader ...', video_idx,
                      len(self.keys_list), index, flat_index)
            return flat_index, video_data, video_name, snippets

    def _get_training_anchors(self, snippets, key):
        tmp_anchor_xmins = np.array(snippets) - self.skip_videoframes / 2.
        tmp_anchor_xmaxs = np.array(snippets) + self.skip_videoframes / 2.
        tmp_gt_bbox = []
        tmp_ioa_list = []
        anno_df_video = self.anno_df[self.anno_df.video == key]
        gt_xmins = anno_df_video.startFrame.values[:]
        gt_xmaxs = anno_df_video.endFrame.values[:]
        if len(gt_xmins) == 0:
            raise ValueError('No ground-truth segments found for key: {}'.format(key))

        for idx in range(len(gt_xmins)):
            tmp_ioa = ioa_with_anchors(gt_xmins[idx], gt_xmaxs[idx],
                                       tmp_anchor_xmins[0],
                                       tmp_anchor_xmaxs[-1])
            tmp_ioa_list.append(tmp_ioa)
            if tmp_ioa > 0:
                tmp_gt_bbox.append([gt_xmins[idx], gt_xmaxs[idx]])

        # print(len(tmp_gt_bbox), max(tmp_ioa_list), tmp_ioa_list)
        if len(tmp_gt_bbox) > 0:
            # NOTE: the original 0.9 IoA threshold has been removed here
            return tmp_anchor_xmins, tmp_anchor_xmaxs, tmp_gt_bbox
        return None

    def _get_train_label(self, gt_bbox, anchor_xmin, anchor_xmax):
        gt_bbox = np.array(gt_bbox)
        gt_xmins = gt_bbox[:, 0]
        gt_xmaxs = gt_bbox[:, 1]
        # same as gt_len, computed the way the THUMOS code repo does it
        gt_duration = gt_xmaxs - gt_xmins
        gt_duration_boundary = np.maximum(self.skip_videoframes,
                                          gt_duration * self.boundary_ratio)
        gt_start_bboxs = np.stack((gt_xmins - gt_duration_boundary / 2,
                                   gt_xmins + gt_duration_boundary / 2),
                                  axis=1)
        gt_end_bboxs = np.stack((gt_xmaxs - gt_duration_boundary / 2,
                                 gt_xmaxs + gt_duration_boundary / 2),
                                axis=1)

        match_score_action = [
            np.max(
                ioa_with_anchors(anchor_xmin[jdx], anchor_xmax[jdx], gt_xmins,
                                 gt_xmaxs)) for jdx in range(len(anchor_xmin))
        ]

        match_score_start = [
            np.max(
                ioa_with_anchors(anchor_xmin[jdx], anchor_xmax[jdx],
                                 gt_start_bboxs[:, 0], gt_start_bboxs[:, 1]))
            for jdx in range(len(anchor_xmin))
        ]

        match_score_end = [
            np.max(
                ioa_with_anchors(anchor_xmin[jdx], anchor_xmax[jdx],
                                 gt_end_bboxs[:, 0], gt_end_bboxs[:, 1]))
            for jdx in range(len(anchor_xmin))
        ]

        return torch.Tensor(match_score_action), torch.Tensor(
            match_score_start), torch.Tensor(match_score_end)
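# ioa_with_anchors is not defined in this snippet; the conventional definition
# used by BSN/BMN-style codebases (intersection length normalised by the anchor
# length) is sketched below under that assumption, and it matches what the
# match-score computations above expect.
import numpy as np

def ioa_with_anchors(anchors_min, anchors_max, box_min, box_max):
    # Intersection-over-anchor between the span [anchors_min, anchors_max] and
    # ground-truth spans [box_min, box_max]; inputs broadcast element-wise.
    len_anchors = anchors_max - anchors_min
    int_xmin = np.maximum(anchors_min, box_min)
    int_xmax = np.minimum(anchors_max, box_max)
    inter_len = np.maximum(int_xmax - int_xmin, 0.0)
    return np.divide(inter_len, len_anchors)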
class VideoDataset(data.Dataset):
    """
    Process raw videos to get videoclips
    """
    def __init__(self,
                 clip_length,
                 frame_stride,
                 frame_rate=None,
                 dataset_path=None,
                 spatial_transform=None,
                 temporal_transform=None,
                 return_label=False,
                 video_formats=["avi", "mp4"]):
        super(VideoDataset, self).__init__()
        # video clip properties
        self.frames_stride = frame_stride
        self.total_clip_length_in_frames = clip_length * frame_stride
        self.spatial_transform = spatial_transform
        self.temporal_transform = temporal_transform
        self.video_formats = video_formats
        # IO
        self.dataset_path = dataset_path
        self.video_list = self._get_video_list(dataset_path=self.dataset_path)
        # print("video_list:", self.video_list, len(self.video_list))
        self.return_label = return_label

        # data loading
        self.video_clips = VideoClips(video_paths=self.video_list,
                                      clip_length_in_frames=self.total_clip_length_in_frames,
                                      frames_between_clips=self.total_clip_length_in_frames,
                                      frame_rate=frame_rate)

    @property
    def video_count(self):
        return len(self.video_list)

    def getitem_from_raw_video(self, idx):
        video, _, _, _ = self.video_clips.get_clip(idx)
        video_idx, clip_idx = self.video_clips.get_clip_location(idx)

        video_path = self.video_clips.video_paths[video_idx]

        in_clip_frames = list(range(0, self.total_clip_length_in_frames, self.frames_stride))

        # print("idx: {}, video_path: {}, video_idx: {}, clip_idx: {}, in_clip_frames: {}".format(idx, video_path, video_idx, clip_idx, in_clip_frames))

        video = video[in_clip_frames]
        # print('video: ', video.size(), video.dtype)
        if self.temporal_transform:
            video = self.temporal_transform(video)
        
        if self.spatial_transform:
            video = self.spatial_transform(video)

        dir, file = video_path.split(os.sep)[-2:]
        file = file.split('.')[0]

        # if self.return_label:
        #     label = 0 if "Normal" in video_path else 1
        #     return video, label, clip_idx, dir, file
        label = 0 if "Normal" in video_path else 1

        return video, label, (clip_idx, dir, file)

    def __len__(self):
        return len(self.video_clips)

    def __getitem__(self, index):
        succ = False
        while not succ:
            try:
                batch = self.getitem_from_raw_video(index)
                succ = True
            except Exception as e:
                index = np.random.choice(range(0, self.__len__()))
                trace_back = sys.exc_info()[2]
                line = trace_back.tb_lineno
                logging.warning(f"VideoIter:: ERROR (line number {line}) !! (Force using another index:\n{index})\n{e}")

        return batch

    def _get_video_list(self, dataset_path):
        assert os.path.exists(dataset_path), "VideoIter:: failed to locate: `{}'".format(dataset_path)
        vid_list = []
        for path, subdirs, files in os.walk(dataset_path):
            for name in files:
                if name.startswith('.') or not any(fmt in name for fmt in self.video_formats):
                    continue
                vid_list.append(os.path.join(path, name))
        return vid_list
Example #11
class VideoIter(data.Dataset):
    def __init__(self,
                 clip_length,
                 frame_stride,
                 dataset_path=None,
                 video_transform=None,
                 return_label=False):
        super(VideoIter, self).__init__()
        # video clip properties
        self.frames_stride = frame_stride
        self.total_clip_length_in_frames = clip_length * frame_stride
        self.video_transform = video_transform

        # IO
        self.dataset_path = dataset_path
        self.video_list = self._get_video_list(dataset_path=self.dataset_path)
        self.return_label = return_label

        # data loading
        self.video_clips = VideoClips(
            video_paths=self.video_list,
            clip_length_in_frames=self.total_clip_length_in_frames,
            frames_between_clips=self.total_clip_length_in_frames,
        )
        #
        # if os.path.exists('video_clips.file'):
        #     with open('video_clips.file', 'rb') as fp:
        #         self.video_clips = pickle.load(fp)
        # else:
        #     self.video_clips = VideoClips(video_paths=self.video_list,
        #                                   clip_length_in_frames=self.total_clip_length_in_frames,
        #                                   frames_between_clips=self.total_clip_length_in_frames,)
        #
        # if not os.path.exists('video_clips.file'):
        #     with open('video_clips.file', 'wb') as fp:
        #         pickle.dump(self.video_clips, fp, protocol=pickle.HIGHEST_PROTOCOL)

    @property
    def video_count(self):
        return len(self.video_list)

    def getitem_from_raw_video(self, idx):
        video, _, _, _ = self.video_clips.get_clip(idx)
        video_idx, clip_idx = self.video_clips.get_clip_location(idx)
        video_path = self.video_clips.video_paths[video_idx]
        in_clip_frames = list(
            range(0, self.total_clip_length_in_frames, self.frames_stride))
        video = video[in_clip_frames]
        if self.video_transform is not None:
            video = self.video_transform(video)

        dir, file = video_path.split(os.sep)[-2:]
        file = file.split('.')[0]

        if self.return_label:
            label = 0 if "Normal" in video_path else 1
            return video, label, clip_idx, dir, file

        return video, clip_idx, dir, file

    def __len__(self):
        return len(self.video_clips)

    def __getitem__(self, index):
        succ = False
        while not succ:
            try:
                batch = self.getitem_from_raw_video(index)
                succ = True
            except Exception as e:
                index = np.random.choice(range(0, self.__len__()))
                trace_back = sys.exc_info()[2]
                line = trace_back.tb_lineno
                logging.warning(
                    f"VideoIter:: ERROR (line number {line}) !! (Force using another index:\n{index})\n{e}"
                )

        return batch

    def _get_video_list(self, dataset_path):
        # features_path = r'/Users/eitankosman/PycharmProjects/anomaly_features'
        # existing_features = np.concatenate(
        #     [[file.split('.')[0] for file in files] for path, subdirs, files in os.walk(features_path)])
        # print(len(existing_features))
        assert os.path.exists(
            dataset_path), "VideoIter:: failed to locate: `{}'".format(
                dataset_path)
        vid_list = []
        # skp = 0
        for path, subdirs, files in os.walk(dataset_path):
            for name in files:
                if 'mp4' not in name:
                    continue
                # if name.split('.')[0] in existing_features:
                # print(f"Skipping {name}")
                # skp += 1
                # continue
                vid_list.append(os.path.join(path, name))

        # print(f"Skipped {skp}")
        return vid_list
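# The commented-out pickle caching above stores the whole VideoClips object; a
# lighter alternative, sketched here as an assumption about the workflow, saves
# only VideoClips.metadata and rebuilds via the _precomputed_metadata argument so
# the expensive per-video scan is skipped on later runs (the cache filename is an
# assumption, and the cache must be rebuilt if video_paths changes).
import os
import torch
from torchvision.datasets.video_utils import VideoClips

def build_video_clips(video_paths, clip_len, step, cache_file="video_clips_meta.pt"):
    metadata = torch.load(cache_file) if os.path.exists(cache_file) else None
    clips = VideoClips(video_paths,
                       clip_length_in_frames=clip_len,
                       frames_between_clips=step,
                       _precomputed_metadata=metadata)
    if metadata is None:
        torch.save(clips.metadata, cache_file)
    return clips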
Example #12
class VideoIterVal(data.Dataset):
    def __init__(self,
                 dataset_path,
                 annotation_path,
                 clip_length,
                 frame_stride,
                 video_transform=None,
                 name="<NO_NAME>",
                 return_item_subpath=False,
                 shuffle_list_seed=None):
        super(VideoIterVal, self).__init__()
        # load params
        self.frames_stride = frame_stride
        self.dataset_path = dataset_path
        self.video_transform = video_transform
        self.return_item_subpath = return_item_subpath
        self.rng = np.random.RandomState(
            shuffle_list_seed if shuffle_list_seed else 0)
        # load video list
        self.video_list = self._get_video_list(dataset_path=self.dataset_path,
                                               annotation_path=annotation_path)
        self.total_clip_length_in_frames = clip_length * frame_stride
        self.video_clips = VideoClips(
            video_paths=self.video_list,
            clip_length_in_frames=self.total_clip_length_in_frames,
            frames_between_clips=self.total_clip_length_in_frames)
        logging.info(
            "VideoIter:: iterator initialized (phase: '{:s}', num: {:d})".
            format(name, len(self.video_list)))

    def getitem_from_raw_video(self, idx):
        # get current video info
        video, _, _, _ = self.video_clips.get_clip(idx)
        video_idx, clip_idx = self.video_clips.get_clip_location(idx)
        video_path = self.video_clips.video_paths[video_idx]
        if self.video_transform is not None:
            video = self.video_transform(video)

        if "Normal" not in video_path:
            label = 1
        else:
            label = 0

        dir, file = video_path.split(os.sep)[-2:]
        file = file.split('.')[0]
        in_clip_frames = list(
            range(0, self.total_clip_length_in_frames, self.frames_stride))
        return video[in_clip_frames], label, clip_idx, dir, file

    def __getitem__(self, index):
        succ = False
        while not succ:
            try:
                clip_input, label, sampled_idx, dir, file = self.getitem_from_raw_video(
                    index)
                succ = True
            except Exception as e:
                index = self.rng.choice(range(0, self.__len__()))
                logging.warning(
                    "VideoIter:: ERROR!! (Force using another index:\n{})\n{}".
                    format(index, e))

        return clip_input, label, sampled_idx, dir, file

    def __len__(self):
        # one entry per clip so indices line up with get_clip(idx) in __getitem__
        return self.video_clips.num_clips()

    def _get_video_list(self, dataset_path, annotation_path):
        assert os.path.exists(
            dataset_path
        )  # , "VideoIter:: failed to locate: `{}'".format(dataset_path)
        assert os.path.exists(
            annotation_path
        )  # , "VideoIter:: failed to locate: `{}'".format(annotation_path)
        v_id = 0
        vid_list = []
        with open(annotation_path, 'r') as f:
            for line in f:
                items = line.split()
                path = os.path.join(dataset_path, items[0])
                vid_list.append(path.strip('\n'))
        return vid_list
class VideoIter(data.Dataset):
    def __init__(self,
                 clip_length,
                 frame_stride,
                 dataset_path=None,
                 video_transform=None,
                 return_label=False):
        super(VideoIter, self).__init__()
        # video clip properties
        self.frames_stride = frame_stride
        self.total_clip_length_in_frames = clip_length * frame_stride
        self.video_transform = video_transform

        # IO
        self.dataset_path = dataset_path
        self.video_list = self._get_video_list(dataset_path=self.dataset_path)
        self.return_label = return_label

        # data loading
        self.video_clips = VideoClips(
            video_paths=self.video_list,
            clip_length_in_frames=self.total_clip_length_in_frames,
            frames_between_clips=self.total_clip_length_in_frames,
        )

    @property
    def video_count(self):
        return len(self.video_list)

    def getitem_from_raw_video(self, idx):
        video, _, _, _ = self.video_clips.get_clip(idx)
        video_idx, clip_idx = self.video_clips.get_clip_location(idx)
        video_path = self.video_clips.video_paths[video_idx]
        in_clip_frames = list(
            range(0, self.total_clip_length_in_frames, self.frames_stride))
        video = video[in_clip_frames]
        if self.video_transform is not None:
            video = self.video_transform(video)

        dir, file = video_path.split(os.sep)[-2:]
        file = file.split('.')[0]

        if self.return_label:
            label = 0 if "Normal" in video_path else 1
            return video, label, clip_idx, dir, file

        return video, clip_idx, dir, file

    def __len__(self):
        return len(self.video_clips)

    def __getitem__(self, index):
        succ = False
        while not succ:
            try:
                batch = self.getitem_from_raw_video(index)
                succ = True
            except Exception as e:
                index = np.random.choice(range(0, self.__len__()))
                logging.warning(
                    "VideoIter:: ERROR!! (Force using another index:\n{})\n{}".
                    format(index, e))

        return batch

    def _get_video_list(self, dataset_path):
        features_path = r'/Users/eitankosman/PycharmProjects/anomaly_features'
        existing_features = np.concatenate(
            [[file.split('.')[0] for file in files]
             for path, subdirs, files in os.walk(features_path)])
        print(len(existing_features))
        assert os.path.exists(
            dataset_path), "VideoIter:: failed to locate: `{}'".format(
                dataset_path)
        vid_list = []
        skp = 0
        for path, subdirs, files in os.walk(dataset_path):
            for name in files:
                if 'mp4' not in name:
                    continue
                if name.split('.')[0] in existing_features:
                    print(f"Skipping {name}")
                    skp += 1
                    continue
                vid_list.append(os.path.join(path, name))

        print(f"Skipped {skp}")
        return vid_list
Example #14
class KineticsAndFails(VisionDataset):
    FLOW_FPS = 8

    def __init__(self,
                 fails_path,
                 kinetics_path,
                 frames_per_clip,
                 step_between_clips,
                 fps,
                 transform=None,
                 extensions=('.mp4', ),
                 video_clips=None,
                 fails_only=False,
                 val=False,
                 balance_fails_only=False,
                 get_clip_times=False,
                 fails_video_list=None,
                 fns_to_remove=None,
                 load_flow=False,
                 flow_histogram=False,
                 fails_flow_path=None,
                 all_fail_videos=False,
                 selfsup_loss=None,
                 clip_interval_factor=None,
                 labeled_fails=True,
                 debug_dataset=False,
                 anticipate_label=0,
                 data_proportion=1,
                 **kwargs):
        self.clip_len = frames_per_clip / fps
        self.clip_step = step_between_clips / fps
        self.clip_interval_factor = clip_interval_factor
        self.fps = fps
        self.t = transform
        self.load_flow = load_flow
        self.flow_histogram = flow_histogram
        self.video_clips = None
        self.fails_path = fails_path
        self.fails_flow_path = fails_flow_path
        self.selfsup_loss = selfsup_loss
        self.get_clip_times = get_clip_times
        self.anticipate_label = anticipate_label
        data_proportion = 1 if val else data_proportion
        if video_clips:
            self.video_clips = video_clips
        else:
            assert fails_path is None or fails_video_list is None
            video_list = fails_video_list or glob(
                os.path.join(fails_path, '**', '*.mp4'), recursive=True)
            if not fails_only:
                kinetics_cls = torch.load("PATH/TO/kinetics_classes.pt")
                kinetics_dist = torch.load("PATH/TO/dist.pt")
                s = len(video_list)
                for i, n in kinetics_dist.items():
                    n *= s
                    video_list += sorted(
                        glob(os.path.join(kinetics_path, '**', kinetics_cls[i],
                                          '*.mp4'),
                             recursive=True))[:round(n)]
            self.video_clips = VideoClips(video_list, frames_per_clip,
                                          step_between_clips, fps)
        with open("PATH/TO/borders.json") as f:
            self.fails_borders = json.load(f)
        with open("PATH/TO/all_mturk_data.json") as f:
            self.fails_data = json.load(f)
        self.fails_only = fails_only
        self.t_from_clip_idx = lambda idx: (
            (step_between_clips * idx) / fps,
            (step_between_clips * idx + frames_per_clip) / fps)
        if not balance_fails_only:  # no support for recompute clips after balance calc yet
            self.video_clips.compute_clips(frames_per_clip, step_between_clips,
                                           fps)
        if video_clips is None and fails_only and labeled_fails:
            # if True:
            if not all_fail_videos:
                idxs = []
                for i, video_path in enumerate(self.video_clips.video_paths):
                    video_path = os.path.splitext(
                        os.path.basename(video_path))[0]
                    if video_path in self.fails_data:
                        idxs.append(i)
                self.video_clips = self.video_clips.subset(idxs)
            # if not val and balance_fails_only:  # balance dataset
            # ratios = {0: 0.3764, 1: 0.0989, 2: 0.5247}
            self.video_clips.labels = []
            self.video_clips.compute_clips(frames_per_clip, step_between_clips,
                                           fps)
            for video_idx, vid_clips in tqdm(enumerate(self.video_clips.clips),
                                             total=len(
                                                 self.video_clips.clips)):
                video_path = self.video_clips.video_paths[video_idx]
                if all_fail_videos and os.path.splitext(
                        os.path.basename(
                            video_path))[0] not in self.fails_data:
                    self.video_clips.labels.append([-1 for _ in vid_clips])
                    continue
                t_unit = av.open(video_path,
                                 metadata_errors='ignore').streams[0].time_base
                t_fail = sorted(self.fails_data[os.path.splitext(
                    os.path.basename(video_path))[0]]['t'])
                t_fail = t_fail[len(t_fail) // 2]
                if t_fail < 0 or not 0.01 <= statistics.median(
                        self.fails_data[os.path.splitext(os.path.basename(video_path))[0]]['rel_t']) <= 0.99 or \
                        self.fails_data[os.path.splitext(os.path.basename(video_path))[0]]['len'] < 3.2 or \
                        self.fails_data[os.path.splitext(os.path.basename(video_path))[0]]['len'] > 30:
                    self.video_clips.clips[video_idx] = torch.Tensor()
                    self.video_clips.resampling_idxs[video_idx] = torch.Tensor(
                    )
                    self.video_clips.labels.append([])
                    continue
                prev_label = 0
                first_one_idx = len(vid_clips)
                first_two_idx = len(vid_clips)
                for clip_idx, clip in enumerate(vid_clips):
                    start_pts = clip[0].item()
                    end_pts = clip[-1].item()
                    t_start = float(t_unit * start_pts)
                    t_end = float(t_unit * end_pts)
                    label = 0
                    if t_start <= t_fail <= t_end:
                        label = 1
                    elif t_start > t_fail:
                        label = 2
                    if label == 1 and prev_label == 0:
                        first_one_idx = clip_idx
                    elif label == 2 and prev_label == 1:
                        first_two_idx = clip_idx
                        break
                    prev_label = label
                self.video_clips.labels.append(
                    [0 for i in range(first_one_idx)] +
                    [1 for i in range(first_one_idx, first_two_idx)] +
                    [2 for i in range(first_two_idx, len(vid_clips))])
                if balance_fails_only and not val:
                    balance_idxs = []
                    counts = (first_one_idx, first_two_idx - first_one_idx,
                              len(vid_clips) - first_two_idx)
                    offsets = torch.LongTensor([0] + list(counts)).cumsum(
                        0)[:-1].tolist()
                    ratios = (1, 0.93, 1 / 0.93)
                    labels = (0, 1, 2)
                    lbl_mode = max(labels, key=lambda i: counts[i])
                    for i in labels:
                        if i != lbl_mode and counts[i] > 0:
                            n_to_add = round(
                                counts[i] *
                                ((counts[lbl_mode] * ratios[i] / counts[i]) -
                                 1))
                            tmp = list(
                                range(offsets[i], counts[i] + offsets[i]))
                            random.shuffle(tmp)
                            tmp_bal_idxs = []
                            while len(tmp_bal_idxs) < n_to_add:
                                tmp_bal_idxs += tmp
                            tmp_bal_idxs = tmp_bal_idxs[:n_to_add]
                            balance_idxs += tmp_bal_idxs
                    if not balance_idxs:
                        continue
                    t = torch.cat(
                        (vid_clips,
                         torch.stack([vid_clips[i] for i in balance_idxs])))
                    self.video_clips.clips[video_idx] = t
                    vid_resampling_idxs = self.video_clips.resampling_idxs[
                        video_idx]
                    try:
                        t = torch.cat(
                            (vid_resampling_idxs,
                             torch.stack([
                                 vid_resampling_idxs[i] for i in balance_idxs
                             ])))
                        self.video_clips.resampling_idxs[video_idx] = t
                    except IndexError:
                        pass
                    self.video_clips.labels[-1] += [
                        self.video_clips.labels[-1][i] for i in balance_idxs
                    ]
            clip_lengths = torch.as_tensor(
                [len(v) for v in self.video_clips.clips])
            self.video_clips.cumulative_sizes = clip_lengths.cumsum(0).tolist()
        fns_removed = 0
        if fns_to_remove and not val:
            for i, video_path in enumerate(self.video_clips.video_paths):
                if fns_removed > len(self.video_clips.video_paths) // 4:
                    break
                video_path = os.path.splitext(os.path.basename(video_path))[0]
                if video_path in fns_to_remove:
                    fns_removed += 1
                    self.video_clips.clips[i] = torch.Tensor()
                    self.video_clips.resampling_idxs[i] = torch.Tensor()
                    self.video_clips.labels[i] = []
            clip_lengths = torch.as_tensor(
                [len(v) for v in self.video_clips.clips])
            self.video_clips.cumulative_sizes = clip_lengths.cumsum(0).tolist()
            if kwargs['local_rank'] <= 0:
                print(
                    f'removed videos from {fns_removed} out of {len(self.video_clips.video_paths)} files'
                )
        # if not fails_path.startswith("PATH/TO/scenes"):
        for i, p in enumerate(self.video_clips.video_paths):
            self.video_clips.video_paths[i] = p.replace(
                "PATH/TO/scenes", os.path.dirname(fails_path))
        self.debug_dataset = debug_dataset
        if debug_dataset:
            # self.video_clips = self.video_clips.subset([0])
            pass
        if data_proportion < 1:
            rng = random.Random()
            rng.seed(23719)
            lbls = self.video_clips.labels
            subset_idxs = rng.sample(
                range(len(self.video_clips.video_paths)),
                int(len(self.video_clips.video_paths) * data_proportion))
            self.video_clips = self.video_clips.subset(subset_idxs)
            self.video_clips.labels = [lbls[i] for i in subset_idxs]

    def trim_borders(self, img, fn):
        l, r = self.fails_borders[os.path.splitext(os.path.basename(fn))[0]]
        w = img.shape[2]  # THWC
        if l > 0 and r > 0:
            img = img[:, :, round(w * l):round(w * r)]
        return img

    def __len__(self):
        return self.video_clips.num_clips()

    def compute_clip_times(self, video_idx, clip_idx):
        video_path = self.video_clips.video_paths[video_idx]
        video_path = os.path.join(
            self.fails_path,
            os.path.sep.join(video_path.rsplit(os.path.sep, 2)[-2:]))
        clip_pts = self.video_clips.clips[video_idx][clip_idx]
        start_pts = clip_pts[0].item()
        end_pts = clip_pts[-1].item()
        t_unit = av.open(video_path,
                         metadata_errors='ignore').streams[0].time_base
        t_start = float(t_unit * start_pts)
        t_end = float(t_unit * end_pts)
        return t_start, t_end

    def __getitem__(self, idx):
        if self.load_flow:
            video_idx, clip_idx = self.video_clips.get_clip_location(idx)
            video_path = self.video_clips.video_paths[video_idx]
            video_path = os.path.join(
                self.fails_path,
                os.path.sep.join(video_path.rsplit(os.path.sep, 2)[-2:]))
            label = self.video_clips.labels[video_idx][clip_idx]
            flow_path = os.path.join(
                self.fails_flow_path,
                os.path.sep.join(
                    os.path.splitext(video_path)[0].rsplit(os.path.sep,
                                                           2)[-2:]))
            t_start, t_end = self.compute_clip_times(video_idx, clip_idx)
            frame_start = round(t_start * self.FLOW_FPS)
            n_frames = round(self.clip_len * self.FLOW_FPS)
            flow = []
            for frame_i in range(frame_start, frame_start + n_frames):
                frame_fn = os.path.join(flow_path, f'{frame_i:06}.flo')
                try:
                    flow.append(
                        torch.load(frame_fn,
                                   map_location=torch.device('cpu')).permute(
                                       1, 2, 0).data.numpy())
                except Exception:
                    # Missing or unreadable flow frame; skip it.
                    pass
            if not flow:
                raise RuntimeError(f'no flow frames found in {flow_path}')
            # Pad by repeating frames if some were skipped, then truncate.
            while len(flow) < n_frames:
                flow += flow
            flow = flow[:n_frames]
            flow = torch.Tensor(flow)
            flow = self.trim_borders(flow, video_path)
            if self.t is not None:
                flow = self.t(flow)
            return flow, label, (flow_path, t_start, t_end)
        else:
            video_idx, clip_idx = self.video_clips.get_clip_location(idx)
            if self.anticipate_label:
                assert not self.selfsup_loss, 'no anticipation with self supervision'
                video_path = self.video_clips.video_paths[video_idx]
                label = self.video_clips.labels[video_idx][clip_idx]
                idx -= round(self.anticipate_label / self.clip_step)
                new_video_idx, new_clip_idx = self.video_clips.get_clip_location(
                    idx)
                video, *_ = self.video_clips.get_clip(idx)
                video = self.trim_borders(video, video_path)
                if self.t is not None:
                    video = self.t(video)
                new_t_start, new_t_end = self.compute_clip_times(
                    new_video_idx, new_clip_idx)
                old_t_start, old_t_end = self.compute_clip_times(
                    video_idx, clip_idx)
                if new_video_idx != video_idx or new_t_start > old_t_start:
                    label = -1
                return video, label, (video_path, new_t_start, new_t_end, [])

            video, audio, info, video_idx = self.video_clips.get_clip(idx)
            video_path = self.video_clips.video_paths[video_idx]
            # print(video_path)
            try:
                label = self.video_clips.labels[video_idx][clip_idx]
                # if self.anticipate_label:
                #     video_path = self.video_clips.video_paths[video_idx]
                #     t_fail = statistics.median(self.fails_data[os.path.splitext(os.path.basename(video_path))[0]]['t'])
                #     t_start, t_end = self.compute_clip_times(video_idx, clip_idx)
                #     t_start += self.anticipate_label
                #     t_end += self.anticipate_label
                #     label = 0
                #     if t_start <= t_fail <= t_end:
                #         label = 1
                #     elif t_start > t_fail:
                #         label = 2
            except Exception:  # no label available for this clip
                label = -1

            if label == 0 or self.fails_only:
                video = self.trim_borders(video, video_path)
            if self.debug_dataset:
                pass
                # video[:] = 0
                # video[..., 0] = 255
            if self.t is not None:
                video = self.t(video)

            t_start = t_end = -1
            if self.get_clip_times:
                t_start, t_end = self.compute_clip_times(video_idx, clip_idx)

            other = []

            if self.selfsup_loss in ('pred_middle', 'sort', 'ctc'):
                k = round(self.clip_len / self.clip_step *
                          self.clip_interval_factor)
                video_l = [video]
                try:
                    pvideo, paudio, pinfo, pvideo_idx = self.video_clips.get_clip(
                        idx - k)
                except Exception:  # previous clip index out of range
                    pvideo_idx = -1
                try:
                    nvideo, naudio, ninfo, nvideo_idx = self.video_clips.get_clip(
                        idx + k)
                except Exception:  # next clip index out of range
                    nvideo_idx = -1
                t_start, _ = self.compute_clip_times(
                    *self.video_clips.get_clip_location(idx))
                try:
                    p_t_start, _ = self.compute_clip_times(
                        *self.video_clips.get_clip_location(idx - k))
                except Exception:
                    # Sentinel: guarantees the p_t_start < t_start check fails.
                    p_t_start = float('inf')
                try:
                    n_t_start, _ = self.compute_clip_times(
                        *self.video_clips.get_clip_location(idx + k))
                except Exception:
                    # Sentinel: guarantees the t_start < n_t_start check fails.
                    n_t_start = float('-inf')
                # if pvideo_idx == video_idx:
                #     assert p_t_start < t_start, f"{t_start} <= prev video time {p_t_start}"
                # if nvideo_idx == video_idx:
                #     assert t_start < n_t_start, f"{t_start} >= next video time {n_t_start}"
                if pvideo_idx == video_idx and p_t_start < t_start:
                    pvideo = self.trim_borders(pvideo, video_path)
                    if self.t is not None:
                        pvideo = self.t(pvideo)
                    video_l.insert(0, pvideo)
                else:
                    video_l.insert(0, torch.full_like(video, -1))
                if nvideo_idx == video_idx and t_start < n_t_start:
                    nvideo = self.trim_borders(nvideo, video_path)
                    if self.t is not None:
                        nvideo = self.t(nvideo)
                    video_l.append(nvideo)
                else:
                    video_l.append(torch.full_like(video, -1))
                video_l = torch.stack(video_l)
                video = video_l
                other = [nvideo_idx == video_idx and pvideo_idx == video_idx]

            if self.selfsup_loss == 'fps':
                other = [self.fps]

            other.append(idx)

            return video, label, (video_path, t_start, t_end, *other)
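
# A minimal usage sketch for the dataset class above; "FailsDataset" is only a
# placeholder for the real class name defined earlier in this example, and the
# constructor arguments and batch size are illustrative.
#
#     from torch.utils.data import DataLoader
#
#     dataset = FailsDataset(...)  # constructor arguments omitted here
#     loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)
#     clip, label, (video_path, t_start, t_end, *other) = dataset[0]
#     # label is -1 when no annotation is available for the clip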
Example #15
class VideoDataset(VisionDataset):

    def __init__(self, root, train, frames_per_clip=16, step_between_clips=1, frame_rate=16, transform=None,
                 extensions=('mp4',), label_fn=lambda x, *_: x, local_rank=-1, get_label_only=False):
        train_or_val = 'train' if train else 'val'
        root = os.path.join(root, train_or_val)
        self.root = root

        super().__init__(root)

        self.transform = transform
        # Function that takes in __getitem__ idx and returns auxiliary label information in the form of a tensor
        self.label_fn = MethodType(label_fn, self)
        self.get_label_only = get_label_only

        clips_fn = os.path.join(root, f'clips_{train_or_val}_{frames_per_clip}_{step_between_clips}_{frame_rate}.pt')

        try:
            self.video_clips = torch.load(clips_fn)
        except FileNotFoundError:
            video_list = list(
                map(str, itertools.chain.from_iterable(Path(root).rglob(f'*.{ext}') for ext in extensions)))
            random.shuffle(video_list)
            if local_rank <= 0:
                print('Generating video clips file: ' + clips_fn)
            self.video_clips = VideoClips(
                video_list,
                frames_per_clip,
                step_between_clips,
                frame_rate,
                num_workers=32
            )
            torch.save(self.video_clips, clips_fn)

        clip_lengths = torch.as_tensor([len(v) for v in self.video_clips.clips])
        self.video_clips.clip_sizes = clip_lengths
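        # clip_sizes is a custom attribute layered on top of torchvision's
        # VideoClips; update_subset() below zeroes entries in it and rebuilds
        # cumulative_sizes to drop whole videos without re-decoding anything.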

    def __len__(self):
        return self.video_clips.num_clips()

    def __getitem__(self, idx):
        if self.get_label_only:
            return torch.Tensor([0]), torch.Tensor([0]), self.label_fn(idx)

        try:
            video, audio, info, video_idx = self.video_clips.get_clip(idx)  # Takes in index w.r.t orig clip sizes
        except IndexError as e:
            # Off by one bug in VideoClips object
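            # The stored resampling index can point one frame past the decoded
            # video; trimming the last index lets the retry below succeed.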
            vi, ci = self.video_clips.get_clip_location(idx)
            self.video_clips.resampling_idxs[vi][ci][-1] -= 1
            video, audio, info, video_idx = self.video_clips.get_clip(idx)

        if self.transform is not None:
            video = self.transform(video)

        return video, torch.Tensor([0]), self.label_fn(idx)

    def update_subset(self, paths, path_transform=None):
        paths = set(paths)
        for i, path in enumerate(self.video_clips.video_paths):
            if path_transform:
                path = path_transform(path)
            if path not in paths:
                self.video_clips.clip_sizes[i] = 0
        self.video_clips.cumulative_sizes = self.video_clips.clip_sizes.cumsum(0).tolist()

    def use_partial_data(self, fraction):
        self.update_subset(self.video_clips.video_paths[:round(fraction * len(self.video_clips.video_paths))])
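
# A minimal usage sketch for VideoDataset above; the root path (expected to
# contain train/ and val/ subfolders of mp4 files) and the numbers are
# illustrative, not taken from this file.
#
#     dataset = VideoDataset('/data/videos', train=True, frames_per_clip=16,
#                            step_between_clips=16, frame_rate=16)
#     dataset.use_partial_data(0.1)   # keep clips from roughly 10% of the videos
#     video, _, label_info = dataset[0]
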
class VideoIterTrain(data.Dataset):
    def __init__(self,
                 dataset_path,
                 annotation_path,
                 clip_length,
                 frame_stride,
                 video_transform=None,
                 name="<NO_NAME>",
                 return_item_subpath=False,
                 shuffle_list_seed=None,
                 single_load=False):
        super(VideoIterTrain, self).__init__()

        self.force_color = True
        if dataset_path is not None:
            self.dataset_path = dataset_path
        self.frames_stride = frame_stride
        self.video_transform = video_transform
        self.return_item_subpath = return_item_subpath
        self.rng = np.random.RandomState(
            shuffle_list_seed if shuffle_list_seed else 0)
        # load video list
        if dataset_path is not None:
            self.video_list = self._get_video_list(
                dataset_path=self.dataset_path,
                annotation_path=annotation_path)
        elif isinstance(annotation_path, list):
            self.video_list = annotation_path
        else:
            self.video_list = [annotation_path]

        self.total_clip_length_in_frames = clip_length * frame_stride

        #size_list=[]
        if single_load:
            print("loading one file at a time")
            self.video_clips = VideoClips(
                video_paths=[self.video_list[0]],
                clip_length_in_frames=self.total_clip_length_in_frames,
                frames_between_clips=self.total_clip_length_in_frames)
            with tqdm(total=len(self.video_list[1:]) + 1,
                      desc=' total % of videos loaded') as pbar1:
                for video_list_used in self.video_list[1:]:
                    #blockPrint()
                    print(video_list_used)
                    #print("size "+str(os.path.getsize(video_list_used)))
                    #size_list.append(os.path.getsize(video_list_used))
                    #print(max(size_list))
                    pbar1.update(1)
                    video_clips_out = VideoClips(
                        video_paths=[video_list_used],
                        clip_length_in_frames=self.total_clip_length_in_frames,
                        frames_between_clips=self.total_clip_length_in_frames)
                    # if video_list_used =="/media/peter/Maxtor/AD-pytorch/UCF_Crimes/Videos/Training_Normal_Videos_Anomaly/Normal_Videos547_x264.mp4":
                    #     continue
                    # #enablePrint()
                    self.video_clips.clips.append(video_clips_out.clips[0])
                    #print(self.video_clips.cumulative_sizes)
                    self.video_clips.cumulative_sizes.append(
                        self.video_clips.cumulative_sizes[-1] +
                        video_clips_out.cumulative_sizes[0])
                    self.video_clips.resampling_idxs.append(
                        video_clips_out.resampling_idxs[0])
                    self.video_clips.video_fps.append(
                        video_clips_out.video_fps[0])
                    self.video_clips.video_paths.append(
                        video_clips_out.video_paths[0])
                    self.video_clips.video_pts.append(
                        video_clips_out.video_pts[0])
        else:
            print("single loader used")
            self.video_clips = VideoClips(
                video_paths=self.video_list,
                clip_length_in_frames=self.total_clip_length_in_frames,
                frames_between_clips=self.total_clip_length_in_frames)

        logging.info(
            "VideoIter:: iterator initialized (phase: '{:s}', num: {:d})".
            format(name, len(self.video_list)))

    def getitem_from_raw_video(self, idx):
        # get current video info
        video, _, _, _ = self.video_clips.get_clip(idx)
        video_idx, clip_idx = self.video_clips.get_clip_location(idx)
        in_clip_frames = list(
            range(0, self.total_clip_length_in_frames, self.frames_stride))
        video_path = self.video_clips.video_paths[video_idx]
        logging.debug("idx=%d video_idx=%d path=%s", idx, video_idx, video_path)
        video = video[in_clip_frames]
        if self.video_transform is not None:
            video = self.video_transform(video)

        if "Normal" not in video_path:
            label = 1
        else:
            label = 0

        dir, file = video_path.split(os.sep)[-2:]
        file = file.split('.')[0]

        #video=video.numpy()
        #test=video.shape
        #t=video[:][0]
        #video[in_clip_frames]
        return video, label, clip_idx, dir, file  #video[:, in_clip_frames, :, :], label, clip_idx, dir, file

    def __getitem__(self, index):
        succ = False
        while not succ:
            try:
                clip_input, label, sampled_idx, dir, file = self.getitem_from_raw_video(
                    index)
                succ = True
            except Exception as e:
                index = self.rng.choice(range(0, self.__len__()))
                logging.warning(
                    "VideoIter:: ERROR!! (Force using another index:\n{})\n{}".
                    format(index, e))

        return clip_input, label, sampled_idx, dir, file

    def __len__(self):
        return len(self.video_list)

    def _get_video_list(self, dataset_path, annotation_path):

        assert os.path.exists(dataset_path), \
            "VideoIter:: failed to locate: `{}'".format(dataset_path)
        assert os.path.exists(annotation_path), \
            "VideoIter:: failed to locate: `{}'".format(annotation_path)
        vid_list = []
        with open(annotation_path, 'r') as f:
            for line in f:
                items = line.split()

                path = os.path.join(dataset_path, items[0])
                vid_list.append(path.strip('\n'))
        return vid_list  #set(vid_list)
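
# A minimal usage sketch for VideoIterTrain above; the paths and clip settings
# are illustrative, not taken from this file.
#
#     train_set = VideoIterTrain(dataset_path='/data/UCF_Crimes/Videos',
#                                annotation_path='/data/UCF_Crimes/train_split.txt',
#                                clip_length=16, frame_stride=2,
#                                video_transform=None, name='train')
#     clip, label, clip_idx, dir_name, file_name = train_set[0]
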
class GymnasticsVideo(data.Dataset):
    def __init__(self,
                 transforms=None,
                 train=True,
                 test=False,
                 count_videos=-1,
                 count_clips=-1,
                 skip_videoframes=5,
                 num_videoframes=100,
                 dist_videoframes=50,
                 video_directory=None,
                 fps=5):
        # If count_videos <= 0, use all the videos. If count_clips <= 0, use
        # all the clips from all the videos.
        self.train = train
        self.transforms = transforms
        self.video_directory = video_directory
        self.skip_videoframes = skip_videoframes
        self.num_videoframes = num_videoframes
        self.dist_videoframes = dist_videoframes

        self.video_files = sorted([
            os.path.join(video_directory, f) for f in os.listdir(video_directory) \
            if f.endswith('mp4')
        ])
        if count_videos > 0:
            self.video_files = self.video_files[:count_videos]

        clip_length_in_frames = self.num_videoframes * self.skip_videoframes
        frames_between_clips = self.dist_videoframes
        self.saved_video_clips = os.path.join(
            video_directory, 'video_clips.%dnf.%df.%ds.pkl' %
            (count_videos, clip_length_in_frames, frames_between_clips))
        if os.path.exists(self.saved_video_clips):
            print('Path Exists for video_clips: ', self.saved_video_clips)
            with open(self.saved_video_clips, 'rb') as f:
                self.video_clips = pickle.load(f)
        else:
            print('Path does NOT exist for video_clips: ',
                  self.saved_video_clips)
            self.video_clips = VideoClips(
                self.video_files,
                clip_length_in_frames=clip_length_in_frames,
                frames_between_clips=frames_between_clips,
                frame_rate=fps)
            with open(self.saved_video_clips, 'wb') as f:
                pickle.dump(self.video_clips, f)
        self.datums = self._retrieve_valid_datums(count_videos, count_clips)
        print(self.datums)

    def __len__(self):
        return len(self.datums)

    def _retrieve_valid_datums(self, count_videos, count_clips):
        num_clips = self.video_clips.num_clips()
        ret = []
        for flat_index in range(num_clips):
            video_idx, clip_idx = self.video_clips.get_clip_location(
                flat_index)
            if count_videos > 0 and video_idx >= count_videos:
                # We reached the max number of videos we want.
                break
            if count_clips > 0 and clip_idx >= count_clips:
                # We reached the max number of clips for this video.
                continue
            ret.append((flat_index, video_idx, clip_idx))

        return ret

    def __getitem__(self, index):
        # The video_data retrieved has shape [nf * sf, w, h, c].
        # We want to pick every sf'th frame out of that.
        flat_idx, video_idx, clip_idx = self.datums[index]
        video, _, _, _ = self.video_clips.get_clip(flat_idx)
        # video_data is [100, 360, 640, 3] --> num_videoframes, w, h, ch.
        video_data = video[0::self.skip_videoframes]
        # now video_data is [ch, num_videoframes, 64, 64]
        video_data = self.transforms(video_data)
        # now it's [num_videoframes, ch, 64, 64]
        video_data = torch.transpose(video_data, 0, 1)
        # path = '/misc/kcgscratch1/ChoGroup/resnick/v%d.c%d.npy' % (video_idx, clip_idx)
        # if not os.path.exists(path):
        #     np.save(path, video_data.numpy())
        return video_data, index
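
# A minimal usage sketch for GymnasticsVideo above; the directory, fps and the
# 64x64 transform are assumptions that only mirror the shape comments in
# __getitem__, not anything fixed by this class.
#
#     transform = ...  # expected to map [T, H, W, C] frames to [C, T, 64, 64]
#     dataset = GymnasticsVideo(transforms=transform, train=True,
#                               count_videos=2, count_clips=10,
#                               video_directory='/data/gymnastics/videos', fps=5)
#     video_data, index = dataset[0]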