Ejemplo n.º 1
0
def preprocess(path, output, num_words, augmentation=False, workers=None):
    """Preprocess the LRW train/val/test splits into a single HDF5 file.

    Each split is loaded with ``LRWDataset`` (pose estimation enabled) and
    handed to ``preprocess_hdf5``, which writes it to the table named after
    the split inside one output file.

    Args:
        path: Dataset root directory, forwarded to ``LRWDataset``.
        output: Directory that receives the HDF5 file; created if missing.
        num_words: Vocabulary size forwarded to ``LRWDataset``; also embedded
            in the output file name.
        augmentation: Forwarded as ``augmentations`` to ``LRWDataset``. When
            truthy the file is named ``lrw_aug_<num_words>.h5``, otherwise
            ``lrw_<num_words>.h5``.
        workers: Worker count for ``preprocess_hdf5``. ``None`` means
            ``psutil.cpu_count()``.

    Raises:
        ValueError: If a split's word list differs from the previous split's.
    """
    if workers is None:
        workers = psutil.cpu_count()
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(output, exist_ok=True)

    prefix = "lrw_aug" if augmentation else "lrw"
    output_path = os.path.join(output, "%s_%d.h5" % (prefix, num_words))
    # Delete any file left by a previous run before the splits are written.
    if os.path.exists(output_path):
        os.remove(output_path)

    words = None
    for mode in ['train', 'val', 'test']:
        print("Generating %s data" % mode)
        dataset = LRWDataset(path=path,
                             num_words=num_words,
                             mode=mode,
                             augmentations=augmentation,
                             estimate_pose=True)
        # All splits must share one vocabulary (presumably so label indices
        # agree across tables). An explicit raise still runs under
        # ``python -O``, unlike ``assert``.
        if words is not None and words != dataset.words:
            raise ValueError(
                "Word list of the %s split does not match the previous splits" % mode)
        words = dataset.words
        preprocess_hdf5(
            dataset=dataset,
            output_path=output_path,
            table=mode,
            workers=workers,
        )
    print("Saved preprocessed file: %s" % output_path)
Ejemplo n.º 2
0
def extract_angles(path, output_path, num_workers, seed, num_words=500, batch_size=256):
    """Estimate the head yaw of every LRW sample and write one text file per split.

    For each split ('train', 'val', 'test') the file ``<output_path>/<mode>.txt``
    is written. It holds one ``<file>,<yaw>`` line per sample, with the yaw
    formatted to two decimals.

    Args:
        path: Dataset root directory, forwarded to ``LRWDataset``.
        output_path: Directory that receives the per-split files; created if
            missing.
        num_workers: Number of ``DataLoader`` worker processes.
        seed: Seed forwarded to ``LRWDataset``.
        num_words: Vocabulary size forwarded to ``LRWDataset`` (default 500).
        batch_size: Samples per batch fed to the head-pose model (default 256).

    Raises:
        ValueError: If a split's word list differs from the previous split's.
    """
    import os

    from src.preprocess.head_pose.hopenet import HeadPose
    head_pose = HeadPose()
    os.makedirs(output_path, exist_ok=True)

    words = None
    for mode in ['train', 'val', 'test']:
        dataset = LRWDataset(path=path,
                             num_words=num_words,
                             mode=mode,
                             estimate_pose=True,
                             seed=seed)
        # An explicit raise still runs under ``python -O``, unlike ``assert``.
        if words is not None and words != dataset.words:
            raise ValueError(
                f"Word list of the {mode} split does not match the previous splits")
        words = dataset.words
        data_loader = DataLoader(dataset,
                                 batch_size=batch_size,
                                 shuffle=False,
                                 num_workers=num_workers)
        # Collect lines in a list and join once; repeated ``str +=`` is quadratic.
        lines = []
        with tqdm(total=len(dataset)) as progress:
            for batch in data_loader:
                files = batch['file']
                yaws = head_pose.predict(batch['angle_frame'])['yaw']
                # Index yaws by position so a length mismatch still raises
                # instead of being silently truncated.
                for i, file_name in enumerate(files):
                    lines.append(f"{file_name},{yaws[i].item():.2f}\n")
                progress.update(len(files))
        # The file is written only after the whole split is processed; ``with``
        # closes the handle even if the write fails.
        with open(f"{output_path}/{mode}.txt", "w") as out_file:
            out_file.write("".join(lines))
Ejemplo n.º 3
0
 def test_dataloader(self):
     """Return an unshuffled DataLoader over the LRW 'test' split.

     The dataset path, vocabulary size, seed and worker count come from
     ``self.hparams``. The batch size is twice ``hparams.batch_size``.
     """
     hp = self.hparams
     dataset = LRWDataset(
         path=hp.data,
         num_words=hp.words,
         mode='test',
         seed=hp.seed,
     )
     return DataLoader(
         dataset,
         shuffle=False,
         batch_size=hp.batch_size * 2,
         num_workers=hp.workers,
     )
Ejemplo n.º 4
0
 def val_dataloader(self):
     """Return an unshuffled DataLoader over the LRW 'val' split.

     The dataset path, vocabulary size, seed and worker count come from
     ``self.hparams``. The batch size is twice ``hparams.batch_size``.
     """
     hp = self.hparams
     dataset = LRWDataset(
         path=hp.data,
         num_words=hp.words,
         mode='val',
         seed=hp.seed,
     )
     return DataLoader(
         dataset,
         shuffle=False,
         batch_size=hp.batch_size * 2,
         num_workers=hp.workers,
     )
Ejemplo n.º 5
0
 def train_dataloader(self):
     """Return a shuffled, memory-pinned DataLoader over the LRW training data.

     ``mode`` is not passed, so ``LRWDataset`` uses its own default split
     (presumably 'train' — confirm against LRWDataset).
     """
     hp = self.hparams
     dataset = LRWDataset(path=hp.data, num_words=hp.words, seed=hp.seed)
     loader_options = dict(
         shuffle=True,
         batch_size=hp.batch_size,
         num_workers=hp.workers,
         pin_memory=True,
     )
     return DataLoader(dataset, **loader_options)
Ejemplo n.º 6
0
 def train_dataloader(self):
     """Return a shuffled, memory-pinned DataLoader over the LRW training data.

     The dataset is configured from ``self.hparams`` (data path, vocabulary
     size, seed) plus this object's ``in_channels``, ``augmentations`` and
     ``query`` attributes.
     """
     hp = self.hparams
     dataset = LRWDataset(
         path=hp.data,
         num_words=hp.words,
         in_channels=self.in_channels,
         augmentations=self.augmentations,
         query=self.query,
         seed=hp.seed,
     )
     loader_options = dict(
         shuffle=True,
         batch_size=hp.batch_size,
         num_workers=hp.workers,
         pin_memory=True,
     )
     return DataLoader(dataset, **loader_options)