def test_JesterDataset_with_padding(self):
    """Clips requested longer than the shortest video still have full length.

    Asks for more frames than ``self.min_number_of_frames`` (presumably the
    shortest clip length in the test fixture), so padding with the dataset's
    default ``frame_padding`` should kick in. Every training sample must then
    contain exactly the requested number of frames and a known label.
    """
    requested_frames = 15
    # Precondition: the request must exceed the minimum so padding is exercised.
    self.assertGreater(requested_frames, self.min_number_of_frames)

    train_set = JesterDataset(
        self.train_csv_file,
        self.video_dir,
        number_of_frames=requested_frames,
        video_transform=None,
    )
    self.assertEqual(len(train_set), self.train_test_validation_size[0])

    for clip, target in train_set:
        self.assertEqual(len(clip), requested_frames)
        self.assertIn(target, self.testDataset.labels)
def test_JesterDataset_with_validation_set(self):
    """Validation samples have the requested length and a ``None`` label.

    The requested length is kept at or below ``self.min_number_of_frames``,
    so presumably no padding is needed. The dataset must yield ``None`` as
    the label for every validation sample.
    """
    clip_length = 10
    self.assertLessEqual(clip_length, self.min_number_of_frames)

    validation_set = JesterDataset(
        self.validation_csv_file,
        self.video_dir,
        number_of_frames=clip_length,
        video_transform=None,
    )
    self.assertEqual(len(validation_set), self.train_test_validation_size[2])

    for clip, target in validation_set:
        self.assertEqual(len(clip), clip_length)
        # Presumably the validation CSV carries no labels, hence None targets.
        self.assertIsNone(target)
def test_JesterDataset_with_frame_select_from_end(self):
    """FROM_END frame selection yields clips of exactly the requested length.

    Fewer frames than ``self.min_number_of_frames`` are requested, so frames
    must be selected (using the FROM_END strategy) rather than padded.
    """
    wanted = 5
    self.assertLess(wanted, self.min_number_of_frames)

    from_end = JesterDataset.FrameSelectStrategy.FROM_END
    dataset_under_test = JesterDataset(
        self.train_csv_file,
        self.video_dir,
        number_of_frames=wanted,
        frame_select_strategy=from_end,
        video_transform=None,
    )
    self.assertEqual(len(dataset_under_test), self.train_test_validation_size[0])

    for clip, target in dataset_under_test:
        self.assertEqual(len(clip), wanted)
        self.assertIn(target, self.testDataset.labels)
def data_processing(
        batch_size=4,
        channel_nb=3,
        frame_nb=16,
        frame_size=(112, 112),
        csv_file='./annotation_jester/annotation_train.csv',
        video_dir='./jester_data/20bn-jester-v1',
        frame_select_strategy=JesterDataset.FrameSelectStrategy.RANDOM,
        frame_padding=JesterDataset.FramePadding.REPEAT_END,
        shuffle=True,
        num_workers=4,
        mean=None,
        std=None):
    """
    Build a DataLoader over the Jester videos listed in ``csv_file``.

    Each clip is resized to ``frame_size``, converted to a tensor with
    ``ClipToTensor`` and normalized per channel, then batched by a
    ``DataLoader``.

    :param batch_size: the number of clips per batch
    :param channel_nb: the number of channels passed to ClipToTensor; must
        equal the length of ``mean`` and ``std``
    :param frame_nb: the number of frames taken from each video
    :param frame_size: the target size of each frame, e.g., (100, 120)
    :param csv_file: the annotation csv file (e.g. train, validation or test)
    :param video_dir: the directory containing the videos
    :param frame_select_strategy: a JesterDataset.FrameSelectStrategy member
        (FROM_BEGINNING, FROM_END or RANDOM)
    :param frame_padding: a JesterDataset.FramePadding member
        (e.g. REPEAT_END or REPEAT_BEGINNING)
    :param shuffle: whether the DataLoader shuffles the clips
    :param num_workers: the number of DataLoader worker processes
    :param mean: per-channel normalization means; defaults to the ImageNet
        RGB means [0.485, 0.456, 0.406]
    :param std: per-channel normalization standard deviations; defaults to
        the ImageNet RGB stds [0.229, 0.224, 0.225]
    :return: data loader, the batch in it is
        (batch_size, channel_nb, frame_nb, height, width), e.g., (4,3,16,100,120)
    :raises ValueError: if ``mean`` or ``std`` does not have ``channel_nb``
        entries
    """
    # The ImageNet defaults only describe 3-channel RGB input; any other
    # channel count needs explicit statistics.
    mean = [0.485, 0.456, 0.406] if mean is None else list(mean)
    std = [0.229, 0.224, 0.225] if std is None else list(std)
    if len(mean) != channel_nb or len(std) != channel_nb:
        # Fail early with a clear message instead of a broadcasting error
        # (or silently wrong normalization) deep inside a worker process.
        raise ValueError(
            f"mean and std must have channel_nb={channel_nb} entries, "
            f"got {len(mean)} and {len(std)}")

    video_transform = Compose([
        Resize(frame_size),
        ClipToTensor(channel_nb=channel_nb),
        Normalize(mean=mean, std=std),
    ])
    dataset = JesterDataset(csv_file=csv_file,
                            video_dir=video_dir,
                            number_of_frames=frame_nb,
                            video_transform=video_transform,
                            frame_select_strategy=frame_select_strategy,
                            frame_padding=frame_padding)
    return DataLoader(dataset,
                      batch_size=batch_size,
                      shuffle=shuffle,
                      num_workers=num_workers)
def test_JesterDataset_with_dataloader(self):
    """Batches from a DataLoader never exceed the requested batch size.

    Also requires at least one batch to be completely full, which checks
    that batching actually groups samples together.
    """
    dataset = JesterDataset(self.train_csv_file, self.video_dir,
                            video_transform=ClipToTensor())
    preferred_batch_size = 4
    loader = DataLoader(dataset, batch_size=preferred_batch_size,
                        shuffle=True, num_workers=1)

    observed_sizes = []
    # Each batch unpacks into (clips, labels); only the clips are inspected.
    for clips, _labels in loader:
        size = len(clips)
        observed_sizes.append(size)
        self.assertLessEqual(size, preferred_batch_size)

    full_batches = observed_sizes.count(preferred_batch_size)
    self.assertGreaterEqual(full_batches, 1)
from torch.utils.data.dataloader import DataLoader

from jesterdataset import JesterDataset
from torchvideotransforms.volume_transforms import ClipToTensor

BATCH_SIZE = 4


def main():
    """Iterate the Jester training split in batches and print each batch size."""
    dataset = JesterDataset("./jester_data/jester-v1-train.csv",
                            "./jester_data/20bn-jester-v1",
                            video_transform=ClipToTensor())
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True,
                            num_workers=4)
    for i_batch, (batch, _label) in enumerate(dataloader):
        # Only the last batch may be smaller (drop_last defaults to False);
        # an explicit check survives `python -O`, unlike `assert`.
        if len(batch) > BATCH_SIZE:
            raise RuntimeError(
                f"Batch number {i_batch} has {len(batch)} items, "
                f"more than the requested {BATCH_SIZE}")
        print(f"Batch number {i_batch} has a batch size of {len(batch)}")


# The guard is required because num_workers > 0 starts worker processes:
# with the "spawn" start method (Windows, macOS) each worker re-imports this
# module, and unguarded top-level code would re-run the whole loop.
if __name__ == "__main__":
    main()