Ejemplo n.º 1
0
 def __init__(self, path_to_model):
     """Create an inference wrapper around a NIMA model restored from disk.

     Args:
         path_to_model: Location of a saved ``state_dict`` for ``NIMA``.

     The network is constructed with ``pretrained_base_model=False`` because
     all of its weights come from the saved state dict. Storages are kept
     where ``torch.load`` deserialises them (the CPU), so GPU-saved weights
     still load on CPU-only machines. The model is switched to eval mode
     before being stored.
     """
     def _keep_storage(storage, location):
         # Identity map_location: ignore the saved device tag.
         return storage

     self.transform = Transform().val_transform
     network = NIMA(pretrained_base_model=False)
     weights = torch.load(path_to_model, map_location=_keep_storage)
     network.load_state_dict(weights)
     network.eval()
     self.model = network
Ejemplo n.º 2
0
def get_dataloaders(
        path_to_save_csv: Path, path_to_images: Path, batch_size: int,
        num_workers: int) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Build the train, validation and test data loaders for the AVA dataset.

    Args:
        path_to_save_csv: Directory containing ``train.csv``, ``val.csv`` and
            ``test.csv``. A plain ``str`` is also accepted and converted to a
            ``Path``.
        path_to_images: Image location, passed unchanged to every
            ``AVADataset``.
        batch_size: Batch size shared by all three loaders.
        num_workers: Worker count shared by all three loaders.

    Returns:
        ``(train_loader, val_loader, test_loader)``. The training split uses
        ``Transform().train_transform`` and is shuffled. The validation and
        test splits use ``Transform().val_transform`` and keep file order.
    """
    # The CSV paths are built with ``/``, which a str does not support.
    if isinstance(path_to_save_csv, str):
        path_to_save_csv = Path(path_to_save_csv)

    transform = Transform()

    train_ds = AVADataset(path_to_save_csv / "train.csv", path_to_images,
                          transform.train_transform)
    val_ds = AVADataset(path_to_save_csv / "val.csv", path_to_images,
                        transform.val_transform)
    test_ds = AVADataset(path_to_save_csv / "test.csv", path_to_images,
                         transform.val_transform)

    train_loader = DataLoader(train_ds,
                              batch_size=batch_size,
                              num_workers=num_workers,
                              shuffle=True)
    val_loader = DataLoader(val_ds,
                            batch_size=batch_size,
                            num_workers=num_workers,
                            shuffle=False)
    # Keep the dataset and its loader under distinct names, matching the
    # ``*_loader`` naming of the other two splits.
    test_loader = DataLoader(test_ds,
                             batch_size=batch_size,
                             num_workers=num_workers,
                             shuffle=False)
    return train_loader, val_loader, test_loader
Ejemplo n.º 3
0
 def __init__(self, path_to_model_state: Path):
     """Create an inference wrapper from a saved training checkpoint.

     Args:
         path_to_model_state: Checkpoint file. It is read as a mapping with
             a ``"model_type"`` entry, used to pick the architecture via
             ``create_model``, and a ``"state_dict"`` entry holding the
             weights.

     The model is built with ``drop_out=0``, loaded with the checkpoint
     weights, moved to ``device`` and put in eval mode.
     NOTE(review): ``device`` is not defined in this method; presumably a
     module-level global — confirm.
     """
     def _keep_storage(storage, location):
         # Identity map_location: leave storages where they deserialise.
         return storage

     self.transform = Transform().val_transform
     checkpoint = torch.load(path_to_model_state, map_location=_keep_storage)
     network = create_model(model_type=checkpoint["model_type"], drop_out=0)
     network.load_state_dict(checkpoint["state_dict"])
     network = network.to(device)
     network.eval()
     self.model = network
Ejemplo n.º 4
0
def _create_val_data_part(params: TrainParams):
    """Create the validation and test data loaders described by ``params``.

    Both splits are read from ``val.csv`` and ``test.csv`` inside
    ``params.path_to_save_csv``. Both use ``Transform().val_transform`` and
    neither loader shuffles. Batch size and worker count come from
    ``params``.

    Returns:
        A ``(val_loader, test_loader)`` tuple.
    """
    transform = Transform()
    loaders = []
    for csv_name in ('val.csv', 'test.csv'):
        dataset = AVADataset(os.path.join(params.path_to_save_csv, csv_name),
                             params.path_to_images,
                             transform.val_transform)
        loaders.append(DataLoader(dataset,
                                  batch_size=params.batch_size,
                                  num_workers=params.num_workers,
                                  shuffle=False))
    return tuple(loaders)
 def __init__(self, device):
     """Create an inference wrapper around a NIMA model with a pretrained base.

     Args:
         device: Target passed to ``.to()`` when placing the model.

     The model is built with ``pretrained_base_model=True``, moved to
     ``device`` and put in eval mode.
     """
     # NOTE(review): the other loaders in this file read ``val_transform``.
     # Confirm that this ``Transform`` really exposes ``eval_transform``.
     preprocess = Transform().eval_transform
     self.transform = preprocess
     network = NIMA(pretrained_base_model=True)
     network = network.to(device)
     network.eval()
     self.model = network