Пример #1
0
    def __init__(self, root_dir, augmentation_params, image_shape=(64, 64, 3), is_train=True,
                 random_seed=0, pairs_list=None, transform=None):
        """Index the images under ``root_dir`` and select the train or test split.

        If ``root_dir`` contains ``train`` and ``test`` sub-directories they are
        used as a predefined split and ``self.root_dir`` is redirected to the
        selected one. Otherwise the top-level entries of ``root_dir`` are split
        80/20 with ``train_test_split`` seeded by ``random_seed``.

        Args:
            root_dir: Dataset directory.
            augmentation_params: Keyword arguments for ``AllAugmentationTransform``;
                only used when training without an explicit ``transform``.
            image_shape: Image shape, stored as a tuple on ``self.image_shape``.
            is_train: Select the training split (``True``) or the test split.
            random_seed: Seed for the random split.
            pairs_list: Stored unchanged on ``self.pairs_list``.
            transform: When given, used instead of the default transform.

        Raises:
            FileNotFoundError: ``root_dir/train`` exists but ``root_dir/test`` does not.
        """
        self.root_dir = root_dir
        # os.listdir returns entries in arbitrary, filesystem-dependent order;
        # sort so the same random_seed always yields the same split.
        self.images = sorted(os.listdir(root_dir))
        self.image_shape = tuple(image_shape)
        self.pairs_list = pairs_list

        train_dir = os.path.join(root_dir, 'train')
        test_dir = os.path.join(root_dir, 'test')
        if os.path.exists(train_dir):
            # Explicit check rather than `assert`, which is stripped under -O.
            if not os.path.exists(test_dir):
                raise FileNotFoundError(f"Found {train_dir} but not {test_dir}.")
            print("Use predefined train-test split.")
            train_images = sorted(os.listdir(train_dir))
            test_images = sorted(os.listdir(test_dir))
            self.root_dir = train_dir if is_train else test_dir
        else:
            print("Use random train-test split.")
            train_images, test_images = train_test_split(self.images, random_state=random_seed, test_size=0.2)

        self.images = train_images if is_train else test_images

        if transform is not None:
            self.transform = transform
        elif is_train:
            self.transform = AllAugmentationTransform(**augmentation_params)
        else:
            self.transform = VideoToTensor()
Пример #2
0
    def __init__(self, root_dir, frame_shape=(256, 256, 3), id_sampling=False, is_train=True,
                 random_seed=0, pairs_list=None, augmentation_params=None):
        """Index the videos under ``root_dir`` and select the train or test split.

        If ``root_dir`` contains ``train`` and ``test`` sub-directories they are
        used as a predefined split and ``self.root_dir`` is redirected to the
        selected one. Otherwise the top-level entries of ``root_dir`` are split
        80/20 with ``train_test_split`` seeded by ``random_seed``.

        Args:
            root_dir: Dataset directory.
            frame_shape: Frame shape, stored as a tuple on ``self.frame_shape``.
            id_sampling: With a predefined split, reduce the training entries to
                the unique prefixes before ``'#'`` in their names.
            is_train: Select the training split (``True``) or the test split.
            random_seed: Seed for the random split.
            pairs_list: Stored unchanged on ``self.pairs_list``.
            augmentation_params: Keyword arguments for ``AllAugmentationTransform``;
                must be a mapping when ``is_train`` is true, unused otherwise.

        Raises:
            FileNotFoundError: ``root_dir/train`` exists but ``root_dir/test`` does not.
        """
        self.root_dir = root_dir
        # os.listdir order is arbitrary; sort so random_seed gives a reproducible split.
        self.videos = sorted(os.listdir(root_dir))
        self.frame_shape = tuple(frame_shape)
        self.pairs_list = pairs_list
        self.id_sampling = id_sampling

        train_dir = os.path.join(root_dir, 'train')
        test_dir = os.path.join(root_dir, 'test')
        if os.path.exists(train_dir):
            # Explicit check rather than `assert`, which is stripped under -O.
            if not os.path.exists(test_dir):
                raise FileNotFoundError(f"Found {train_dir} but not {test_dir}.")
            print("Use predefined train-test split.")
            if id_sampling:
                # Keep only the part of each name before '#' (presumably an
                # identity id); sorted because set iteration order is not stable.
                train_videos = sorted({os.path.basename(video).split('#')[0]
                                       for video in os.listdir(train_dir)})
            else:
                train_videos = sorted(os.listdir(train_dir))
            test_videos = sorted(os.listdir(test_dir))
            self.root_dir = train_dir if is_train else test_dir
        else:
            print("Use random train-test split.")
            train_videos, test_videos = train_test_split(self.videos, random_state=random_seed, test_size=0.2)

        self.videos = train_videos if is_train else test_videos
        self.is_train = is_train

        # Augmentation is only used for training; evaluation gets no transform.
        self.transform = AllAugmentationTransform(**augmentation_params) if is_train else None
Пример #3
0
class FramesDataset(Dataset):
    """
    Dataset of videos, each video can be represented as:
      - an image of concatenated frames
      - '.mp4' or '.gif'
      - folder with all frames

    Training items hold ``source`` and ``target`` (two frames of one video,
    augmented together), ``driving`` (the same frame as ``target`` with its own
    augmentation), ``driving2`` (one frame of a different video), ``name`` and
    ``name2``. Evaluation items hold ``video`` (all frames of one video) and
    ``name``.
    """

    # Upper bound on attempts to read a random selection of frames from a
    # frame folder before giving up (avoids spinning forever on an empty or
    # unreadable folder).
    max_load_attempts = 100

    def __init__(self,
                 root_dir,
                 frame_shape=(256, 256, 3),
                 id_sampling=False,
                 is_train=True,
                 random_seed=0,
                 pairs_list=None,
                 augmentation_params=None):
        """Index the videos under ``root_dir`` and select the train or test split.

        If ``root_dir`` contains ``train`` and ``test`` sub-directories they are
        used as a predefined split and ``self.root_dir`` is redirected to the
        selected one. Otherwise the top-level entries are split 80/20 with
        ``train_test_split`` seeded by ``random_seed``.

        Args:
            root_dir: Dataset directory.
            frame_shape: Frame shape passed to ``read_video``.
            id_sampling: With a predefined split, reduce training entries to
                the unique prefixes before ``'#'``; in training, items then pick
                a random ``<prefix>*.mp4`` entry.
            is_train: Select the training split (``True``) or the test split.
            random_seed: Seed for the random split.
            pairs_list: Stored unchanged on ``self.pairs_list``.
            augmentation_params: Keyword arguments for ``AllAugmentationTransform``;
                must be a mapping when ``is_train`` is true, unused otherwise.

        Raises:
            FileNotFoundError: ``root_dir/train`` exists but ``root_dir/test`` does not.
        """
        self.root_dir = root_dir
        # os.listdir order is arbitrary; sort so random_seed gives a reproducible split.
        self.videos = sorted(os.listdir(root_dir))
        self.frame_shape = tuple(frame_shape)
        self.pairs_list = pairs_list
        self.id_sampling = id_sampling

        train_dir = os.path.join(root_dir, 'train')
        test_dir = os.path.join(root_dir, 'test')
        if os.path.exists(train_dir):
            # Explicit check rather than `assert`, which is stripped under -O.
            if not os.path.exists(test_dir):
                raise FileNotFoundError(f"Found {train_dir} but not {test_dir}.")
            print("Use predefined train-test split.")
            if id_sampling:
                train_videos = sorted({
                    os.path.basename(video).split('#')[0]
                    for video in os.listdir(train_dir)
                })
            else:
                train_videos = sorted(os.listdir(train_dir))
            test_videos = sorted(os.listdir(test_dir))
            self.root_dir = train_dir if is_train else test_dir
        else:
            print("Use random train-test split.")
            train_videos, test_videos = train_test_split(
                self.videos, random_state=random_seed, test_size=0.2)

        self.videos = train_videos if is_train else test_videos
        self.is_train = is_train
        # Only training uses the transform; building it unconditionally made
        # evaluation crash with the default augmentation_params=None.
        self.transform = (AllAugmentationTransform(**augmentation_params)
                          if is_train else None)

    def __len__(self):
        return len(self.videos)

    def _resolve_path(self, name):
        """Return the on-disk path of video entry ``name``.

        With id sampling in training, ``name`` is a prefix and one of the
        matching ``<name>*.mp4`` entries is chosen at random.
        """
        if self.is_train and self.id_sampling:
            pattern = glob.escape(os.path.join(self.root_dir, name)) + '*.mp4'
            return np.random.choice(glob.glob(pattern))
        return os.path.join(self.root_dir, name)

    def _other_index(self, idx):
        """Return a random dataset index different from ``idx``.

        Raises:
            ValueError: fewer than two entries, so no other index exists.
        """
        if len(self.videos) < 2:
            raise ValueError('Training needs at least two videos to sample driving2.')
        idx2 = idx
        while idx2 == idx:
            idx2 = random.randint(0, len(self.videos) - 1)
        return idx2

    def _sample_frames(self, path, count):
        """Return ``count`` random frames of ``path`` (with replacement, indices sorted).

        ``path`` is either a folder of image files or something ``read_video``
        accepts.

        Raises:
            RuntimeError: a frame folder could not be read within
                ``max_load_attempts`` tries (e.g. empty or unreadable).
        """
        if os.path.isdir(path):
            frames = os.listdir(path)
            for _ in range(self.max_load_attempts):
                try:
                    frame_idx = np.sort(
                        np.random.choice(len(frames), replace=True, size=count))
                    return [
                        img_as_float32(io.imread(os.path.join(path, frames[i])))
                        for i in frame_idx
                    ]
                except Exception:
                    # Unreadable frame: retry with a fresh random selection.
                    print('error loading, trying again')
            raise RuntimeError(f'Could not load {count} frame(s) from {path} '
                               f'after {self.max_load_attempts} attempts.')
        video_array = read_video(path, frame_shape=self.frame_shape)
        frame_idx = np.sort(
            np.random.choice(len(video_array), replace=True, size=count))
        return video_array[frame_idx]

    def __getitem__(self, idx):
        """Build the sample dict for entry ``idx`` (keys listed in the class docstring)."""
        path = self._resolve_path(self.videos[idx])

        out = {}
        if self.is_train:
            # The second frame always comes from a different entry; previously
            # this only worked with id_sampling on frame folders (NameError otherwise).
            path2 = self._resolve_path(self.videos[self._other_index(idx)])
            video_array = self.transform.flip(self._sample_frames(path, 2))
            video_array2 = self.transform.flip(self._sample_frames(path2, 1))

            source = np.array(video_array[0], dtype='float32')
            driving = np.array(video_array[1], dtype='float32')
            driving2 = np.array(video_array2[0], dtype='float32')
            target = driving.copy()

            # source and target go through transform_source_and_targets as one
            # stacked batch (presumably so they share one augmentation — confirm).
            source_and_targets = self.transform.transform_source_and_targets(
                np.stack((source, target), axis=0))
            source = source_and_targets[0]
            target = source_and_targets[1]

            driving = self.transform.transform_driving(
                np.expand_dims(driving, axis=0))[0]
            driving2 = self.transform.transform_driving2(
                np.expand_dims(driving2, axis=0))[0]

            # HWC -> CHW
            out['driving'] = driving.transpose((2, 0, 1))
            out['driving2'] = driving2.transpose((2, 0, 1))
            out['source'] = source.transpose((2, 0, 1))
            out['target'] = target.transpose((2, 0, 1))
            out['name2'] = os.path.basename(path2)
        else:
            video = np.array(read_video(path, frame_shape=self.frame_shape),
                             dtype='float32')
            # Moves the last axis first; assumes read_video returns (T, H, W, C).
            out['video'] = video.transpose((3, 0, 1, 2))

        out['name'] = os.path.basename(path)
        return out
    def __init__(self,
                 root_dir,
                 frame_shape=(256, 256, 3),
                 id_sampling=False,
                 is_train=True,
                 random_seed=0,
                 pairs_list=None,
                 augmentation_params=None):
        data_dir = os.environ.get("DATA_DIR")
        if data_dir is not None:
            root_dir = data_dir

        print(f'Dataset root dir {root_dir}.')

        self.root_dir = os.path.join(root_dir)
        self.videos = os.listdir(root_dir)
        self.frame_shape = tuple(frame_shape)
        if pairs_list:
            self.pairs_list = os.path.join(self.root_dir, pairs_list)
        else:
            pairs_list = None
        self.id_sampling = id_sampling
        if os.path.exists(os.path.join(root_dir, 'train')):
            assert os.path.exists(os.path.join(root_dir, 'test'))
            print("Use predefined train-test split.")
            if id_sampling:
                train_videos = {
                    os.path.basename(video).split('#')[0]
                    for video in os.listdir(os.path.join(root_dir, 'train'))
                }
                train_videos = list(train_videos)
            else:
                train_videos = os.listdir(os.path.join(root_dir, 'train'))
            test_videos = os.listdir(os.path.join(root_dir, 'test'))
            self.root_dir = os.path.join(self.root_dir,
                                         'train' if is_train else 'test')
        else:
            print("Use random train-test split.")
            train_videos, test_videos = train_test_split(
                self.videos, random_state=random_seed, test_size=0.2)

        if is_train:
            self.videos = train_videos
        else:
            self.videos = test_videos

        name = self.videos[0]
        path = glob.glob(os.path.join(self.root_dir, name, '**/*.jpg'),
                         recursive=True)
        if len(path) > 0:
            print(f'Detected frame dataset by {path[0]}.')
            self.video_dataset = False
        else:
            print(f'Detected video dataset.')
            self.video_dataset = True

        self.is_train = is_train

        if self.is_train:
            self.transform = AllAugmentationTransform(**augmentation_params)
        else:
            self.transform = None