Example #1
 def __init__(self,
              name,
              src_task=get_task("rgb"),
              dest_task=get_task("normal")):
     self.name = name
     self.src_task, self.dest_task = src_task, dest_task
     self.load_dataset()
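
Both task defaults are evaluated once, at function definition time, via get_task. A minimal usage sketch; the enclosing class name PairedTaskDataset is a hypothetical stand-in, only the __init__ signature above is given:

# Hypothetical class name; defaults give the rgb -> normal pairing.
dataset = PairedTaskDataset(name="rgb2normal")
custom = PairedTaskDataset(name="rgb2depth",
                           src_task=get_task("rgb"),
                           dest_task=get_task("depth_zbuffer"))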
Example #2
    def __init__(self, buildings, tasks=[get_task("rgb"), get_task("normal")], data_dirs=DATA_DIRS,
            building_files=None, convert_path=None, use_raid=USE_RAID, resize=None, unpaired=False, shuffle=True):

        super().__init__()
        self.buildings, self.tasks, self.data_dirs = buildings, tasks, data_dirs
        self.building_files = building_files or self.building_files
        self.convert_path = convert_path or self.convert_path
        self.resize = resize
        if use_raid:
            self.convert_path = self.convert_path_raid
            self.building_files = self.building_files_raid

        self.file_map = {}
        for data_dir in self.data_dirs:
            for file in glob.glob(f'{data_dir}/*'):
                res = parse.parse("{building}_{task}", file[len(data_dir)+1:])
                if res is None: continue
                self.file_map[file[len(data_dir)+1:]] = data_dir

        # NOTE: only the first task's files are indexed below; the original
        # cross-task intersection filtering is no longer applied.
        task = tasks[0]
        task_files = []
        for building in buildings:
            task_files += self.building_files(task, building)
        print(f"    {task.name} file len: {len(task_files)}")
        self.idx_files = task_files
        if not shuffle: self.idx_files = sorted(task_files)

        print("    Intersection files len: ", len(self.idx_files))
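
This signature matches the TaskDataset constructor called in the later examples; an instantiation sketch under that assumption:

# Sketch, assuming the class is the TaskDataset used in Example #3.
dataset = TaskDataset(
    buildings=["almena"],
    tasks=[get_task("rgb"), get_task("normal")],
    shuffle=False,  # sorted file order, reproducible indexing
)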
Example #3
def load_train_val(train_tasks, val_tasks=None, fast=False,
        train_buildings=None, val_buildings=None, split_file="config/split.txt",
        dataset_cls=None, batch_size=32, batch_transforms=cycle,
        subset=None, subset_size=None, dataaug=False,
    ):

    dataset_cls = dataset_cls or TaskDataset
    train_cls = TrainTaskDataset if dataaug else dataset_cls
    train_tasks = [get_task(t) if isinstance(t, str) else t for t in train_tasks]
    if val_tasks is None: val_tasks = train_tasks
    val_tasks = [get_task(t) if isinstance(t, str) else t for t in val_tasks]  
    data = yaml.safe_load(open(split_file))  # yaml.load() without a Loader is deprecated/unsafe
    train_buildings = train_buildings or (["almena"] if fast else data["train_buildings"])
    val_buildings = val_buildings or (["almena"] if fast else data["val_buildings"])
    print("number of train images:")
    train_loader = train_cls(buildings=train_buildings, tasks=train_tasks)
    print("number of val images:")
    val_loader = dataset_cls(buildings=val_buildings, tasks=val_tasks)

    if subset_size is not None or subset is not None:
        train_loader = torch.utils.data.Subset(train_loader,
            random.sample(range(len(train_loader)), subset_size or int(len(train_loader)*subset)),
        )

    train_step = int(len(train_loader) // (400 * batch_size))
    val_step = int(len(val_loader) // (400 * batch_size))
    print("Train step: ", train_step)
    print("Val step: ", val_step)
    if fast: train_step, val_step = 8, 8

    return train_loader, val_loader, train_step, val_step
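
Despite the names, train_loader and val_loader returned here are datasets, not DataLoaders. A sketch of how a caller might wrap them; the batch size and worker count are illustrative assumptions:

train_set, val_set, train_step, val_step = load_train_val(
    ["rgb", "normal"], batch_size=32, fast=True)
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=32, shuffle=True, num_workers=8, pin_memory=True)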
Example #4
def load_sintel_train_val_test(source_task, dest_task, 
        batch_size=64, batch_transforms=cycle
    ):

    if isinstance(source_task, str) and isinstance(dest_task, str):
        source_task, dest_task = get_task(source_task), get_task(dest_task)
    
    buildings = sorted([x.split('/')[-1] for x in glob.glob("mount/sintel/training/depth/*")])
    # train_test_split comes from sklearn.model_selection
    train_buildings, val_buildings = train_test_split(buildings, test_size=0.2)
    print(len(train_buildings))
    print(len(val_buildings))

    train_loader = torch.utils.data.DataLoader(
        SintelDataset(buildings=train_buildings, tasks=[source_task, dest_task]),
        batch_size=batch_size,
        num_workers=64, shuffle=True, pin_memory=True
    )
    val_loader = torch.utils.data.DataLoader(
        SintelDataset(buildings=val_buildings, tasks=[source_task, dest_task]),
        batch_size=batch_size,
        num_workers=64, shuffle=True, pin_memory=True
    )

    train_step = int(2248616 // (100 * batch_size))
    val_step = int(245592 // (100 * batch_size))
    print("Train step: ", train_step)
    print("Val step: ", val_step)

    test_set = list(itertools.islice(val_loader, 1))
    test_images = torch.cat([x for x, y in test_set], dim=0)
    
    return train_loader, val_loader, train_step, val_step, test_set, test_images
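
The step counts come from hard-coded dataset sizes, not from the loaders; for the default batch_size=64 the arithmetic works out as below (the task names are the ones used elsewhere in these examples):

# 2248616 // (100 * 64) == 351  -> train_step
# 245592  // (100 * 64) == 38   -> val_step
train_loader, val_loader, train_step, val_step, test_set, test_images = \
    load_sintel_train_val_test("rgb", "normal", batch_size=64)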
Example #5
def load_train_val(train_tasks, val_tasks=None, fast=False, 
        train_buildings=None, val_buildings=None, split_file="data/split.txt", 
        dataset_cls=None, batch_size=64, batch_transforms=cycle,
        subset=None, subset_size=None,
    ):
    
    dataset_cls = dataset_cls or TaskDataset
    train_tasks = [get_task(t) if isinstance(t, str) else t for t in train_tasks]
    if val_tasks is None: val_tasks = train_tasks
    val_tasks = [get_task(t) if isinstance(t, str) else t for t in val_tasks]

    data = yaml.safe_load(open(split_file))  # yaml.load() without a Loader is deprecated/unsafe
    train_buildings = train_buildings or (["almena"] if fast else data["train_buildings"])
    val_buildings = val_buildings or (["almena"] if fast else data["val_buildings"])
    train_loader = dataset_cls(buildings=train_buildings, tasks=train_tasks)
    val_loader = dataset_cls(buildings=val_buildings, tasks=val_tasks)

    if subset_size is not None or subset is not None:
        train_loader = torch.utils.data.Subset(train_loader, 
            random.sample(range(len(train_loader)), subset_size or int(len(train_loader)*subset)),
        )
        # val_loader = torch.utils.data.Subset(val_loader, 
        #     random.sample(range(len(val_loader)), subset_size or int(len(val_loader)*subset)),
        # )

    train_step = int(2248616 // (100 * batch_size))
    val_step = int(245592 // (100 * batch_size))
    print("Train step: ", train_step)
    print("Val step: ", val_step)
    if fast: train_step, val_step = 8, 8

    return train_loader, val_loader, train_step, val_step
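
The subset branch draws a fixed random sample of indices without replacement. The same pattern in isolation, with a stand-in dataset so the sketch runs on its own:

import random
import torch
full = torch.utils.data.TensorDataset(torch.randn(1000, 3))  # stand-in dataset
indices = random.sample(range(len(full)), int(len(full) * 0.1))  # keep 10%
small = torch.utils.data.Subset(full, indices)
print(len(small))  # 100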
Example #6
    def __init__(self,
                 src_task,
                 dest_task,
                 checkpoint=True,
                 name=None,
                 model_type=None,
                 path=None,
                 pretrained=True,
                 finetuned=False):
        super().__init__()
        if isinstance(src_task, str) and isinstance(dest_task, str):
            src_task, dest_task = get_task(src_task), get_task(dest_task)

        self.src_task, self.dest_task, self.checkpoint = src_task, dest_task, checkpoint
        self.name = name or f"{src_task.name}2{dest_task.name}"
        saved_type, saved_path = None, None
        if model_type is None and path is None:
            saved_type, saved_path = pretrained_transfers.get(
                (src_task.name, dest_task.name), (None, None))

        self.model_type, self.path = model_type or saved_type, path or saved_path
        self.model = None

        if finetuned:
            path = f"{MODELS_DIR}/ft_perceptual/{src_task.name}2{dest_task.name}.pth"
            if os.path.exists(path):
                self.model_type, self.path = saved_type or (
                    lambda: get_model(src_task, dest_task)), path
                return

        if self.model_type is None:

            if src_task.kind == dest_task.kind and src_task.resize != dest_task.resize:

                class Module(TrainableModel):
                    def __init__(self):
                        super().__init__()

                    def forward(self, x):
                        return resize(x, val=dest_task.resize)

                self.model_type = lambda: Module()
                self.path = None

            path = f"{MODELS_DIR}/{src_task.name}2{dest_task.name}.pth"
            if src_task.name == "keypoints2d" or dest_task.name == "keypoints2d":
                path = f"{MODELS_DIR}/{src_task.name}2{dest_task.name}_new.pth"
            if os.path.exists(path):
                self.model_type, self.path = lambda: get_model(
                    src_task, dest_task), path
            # else:
            #     self.model_type = lambda: get_model(src_task, dest_task)
            #     print ("Not using pretrained b/c no model file avail")

        if not pretrained:
            print("Not using pretrained [heavily discouraged]")
            self.path = None
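
The (model_type, path) resolution assumes pretrained_transfers maps (src_name, dest_name) pairs to (constructor, checkpoint_path) tuples. An illustration of that shape only; the actual entries live in the project's configs:

# Shape illustration; the entry and path are assumptions.
pretrained_transfers = {
    ("rgb", "normal"): (lambda: get_model("rgb", "normal"),
                        f"{MODELS_DIR}/rgb2normal.pth"),
}
saved_type, saved_path = pretrained_transfers.get(("rgb", "normal"), (None, None))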
Example #7
def main():

    logger = VisdomLogger("train", env=JOB)

    test_loader, test_images = load_doom()
    test_images = torch.cat(test_images, dim=0)
    src_task, dest_task = get_task("rgb"), get_task("normal")

    print(test_images.shape)
    src_task.plot_func(test_images, "images", logger, resize=128)

    paths = ["F(RC(x))", "F(EC(a(x)))", "n(x)", "npstep(x)"]

    for path_str in paths:

        path_list = path_str.replace(')', '').split('(')[::-1][1:]
        path = [TRANSFER_MAP[name] for name in path_list]

        class PathModel(TrainableModel):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                with torch.no_grad():
                    for f in path:
                        x = f(x)
                return x

            def loss(self, pred, target):
                loss = torch.tensor(0.0, device=pred.device)
                return loss, (loss.detach(), )

        model = PathModel()

        preds = model.predict(test_loader)
        dest_task.plot_func(preds, f"preds_{path_str}", logger, resize=128)
        transform = transforms.ToPILImage()
        os.makedirs(f"{BASE_DIR}/doom_processed/{path_str}/video2",
                    exist_ok=True)

        for image, file in zip(preds, test_loader.dataset.files):
            image = transform(image.cpu())
            filename = file.split("/")[-1]
            print(filename)
            image.save(
                f"{BASE_DIR}/doom_processed/{path_str}/video2/{filename}")

    print(preds.shape)
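
The path-string parsing reverses the functional nesting into an inner-to-outer list of transfer names, e.g.:

path_str = "F(RC(x))"
path_list = path_str.replace(')', '').split('(')[::-1][1:]
print(path_list)  # ['RC', 'F'] -- innermost transfer is applied first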
Example #8
    def __init__(self, buildings, tasks=[get_task("rgb"), get_task("normal")], data_dirs=DATA_DIRS, 
            building_files=None, convert_path=None, use_raid=USE_RAID, resize=None, unpaired=False, shuffle=True):

        super().__init__()
        self.buildings, self.tasks, self.data_dirs = buildings, tasks, data_dirs
        self.building_files = building_files or self.building_files
        self.convert_path = convert_path or self.convert_path
        self.resize = resize
        if use_raid:
            self.convert_path = self.convert_path_raid
            self.building_files = self.building_files_raid
        
        # Build a map from buildings to directories
        # self.file_map = {}
        # for data_dir in self.data_dirs:
        #     for file in glob.glob(f'{data_dir}/*'):
        #         res = parse.parse("{building}_{task}", file[len(data_dir)+1:])
        #         if res is None: continue
        #         self.file_map[file[len(data_dir)+1:]] = data_dir
        # filtered_files = set()
        # for i, task in enumerate(tasks):
        #     task_files = []
        #     for building in buildings:
        #         task_files += sorted(self.building_files(task, building))
        #     print(f"{task.name} file len: {len(task_files)}")
        #     task_set = {self.convert_path(x, tasks[0]) for x in task_files}
        #     filtered_files = filtered_files.intersection(task_set) if i != 0 else task_set
        # self.idx_files = sorted(list(filtered_files))
        # print ("Intersection files len: ", len(self.idx_files))

        self.file_map = {}
        for data_dir in self.data_dirs:
            for file in glob.glob(f'{data_dir}/*'):
                res = parse.parse("{building}_{task}", file[len(data_dir)+1:])
                if res is None: continue
                self.file_map[file[len(data_dir)+1:]] = data_dir
            print(f'{data_dir}/*')
        #print(self.file_map)
        # NOTE: only the first task's files are indexed below; the commented
        # intersection logic above is no longer applied.
        task = tasks[0]
        task_files = []
        for building in buildings:
            task_files += self.building_files(task, building)
        print(f"{task.name} file len: {len(task_files)}")
        self.idx_files = task_files
        if not shuffle: self.idx_files = sorted(task_files)

        print("Intersection files len: ", len(self.idx_files))
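
file_map keys are matched against a "{building}_{task}" pattern with the parse library; a small sketch of what a match yields, using an illustrative file name:

import parse
res = parse.parse("{building}_{task}", "almena_rgb")  # illustrative name
if res is not None:
    print(res["building"], res["task"])  # almena rgb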
Example #9
def load_test(all_tasks, buildings=["almena", "albertville", "espanola"], sample=4):

    all_tasks = [get_task(t) if isinstance(t, str) else t for t in all_tasks]
    print(f"number of images in {buildings[0]}:")
    test_loader1 = torch.utils.data.DataLoader(
        TaskDataset(buildings=[buildings[0]], tasks=all_tasks, shuffle=False),
        batch_size=sample,
        num_workers=0, shuffle=False, pin_memory=True,
    )
    print(f"number of images in {buildings[1]}:")
    test_loader2 = torch.utils.data.DataLoader(
        TaskDataset(buildings=[buildings[1]], tasks=all_tasks, shuffle=False),
        batch_size=sample,
        num_workers=0, shuffle=False, pin_memory=True,
    )
    print(f"number of images in {buildings[2]}:")
    test_loader3 = torch.utils.data.DataLoader(
        TaskDataset(buildings=[buildings[2]], tasks=all_tasks, shuffle=False),
        batch_size=sample,
        num_workers=0, shuffle=False, pin_memory=True,
    )
    set1 = list(itertools.islice(test_loader1, 1))[0]
    set2 = list(itertools.islice(test_loader2, 1))[0]
    set3 = list(itertools.islice(test_loader3, 1))[0]
    test_set = tuple(torch.cat([x, y, z], dim=0) for x, y, z in zip(set1, set2, set3))
    return test_set
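
Each islice call pulls exactly one batch per building, and zip lines up same-task tensors before concatenation. The three-way merge can be written as a loop over any number of loaders; a sketch assuming the loaders built above:

import itertools
import torch
loaders = (test_loader1, test_loader2, test_loader3)
batches = [list(itertools.islice(loader, 1))[0] for loader in loaders]
test_set = tuple(torch.cat(group, dim=0) for group in zip(*batches))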
Example #10
def get_model(src_task, dest_task):

    if isinstance(src_task, str) and isinstance(dest_task, str):
        src_task, dest_task = get_task(src_task), get_task(dest_task)

    if (src_task.name, dest_task.name) in model_types:
        return model_types[(src_task.name, dest_task.name)]()

    elif isinstance(src_task, ImageTask) and isinstance(dest_task, ImageTask):
        return UNet(downsample=3,
                    in_channels=src_task.shape[0],
                    out_channels=dest_task.shape[0])

    elif isinstance(src_task, ImageTask) and isinstance(dest_task, ClassTask):
        return ResNet(in_channels=src_task.shape[0],
                      out_channels=dest_task.classes)

    elif isinstance(src_task, ImageTask) and isinstance(
            dest_task, PointInfoTask):
        return ResNet(out_channels=dest_task.out_channels)

    return None
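
A call sketch: assuming "rgb" and "normal" both resolve to ImageTasks with no entry in model_types, the lookup falls through to the UNet branch:

model = get_model("rgb", "normal")
# -> UNet(downsample=3, in_channels=3, out_channels=3), assuming both
#    tasks report shape[0] == 3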
Example #11
def load_test(all_tasks, buildings=["almena", "albertville"], sample=4):

    all_tasks = [get_task(t) if isinstance(t, str) else t for t in all_tasks]
    test_loader1 = torch.utils.data.DataLoader(
        TaskDataset(buildings=[buildings[0]], tasks=all_tasks, shuffle=False),
        batch_size=sample,
        num_workers=sample, shuffle=False, pin_memory=True,
    )
    test_loader2 = torch.utils.data.DataLoader(
        TaskDataset(buildings=[buildings[1]], tasks=all_tasks, shuffle=False),
        batch_size=sample,
        num_workers=sample, shuffle=False, pin_memory=True,
    )
    set1 = list(itertools.islice(test_loader1, 1))[0]
    set2 = list(itertools.islice(test_loader2, 1))[0]
    test_set = tuple(torch.cat([x, y], dim=0) for x, y in zip(set1, set2))
    return test_set
Example #12
def plot_images(
    model,
    logger,
    test_set,
    dest_task="normal",
    ood_images=None,
    show_masks=False,
    loss_models={},
    preds_name=None,
    target_name=None,
    ood_name=None,
):

    from task_configs import get_task, ImageTask

    test_images, preds, targets, losses, _ = model.predict_with_data(test_set)

    if isinstance(dest_task, str):
        dest_task = get_task(dest_task)

    if show_masks and isinstance(dest_task, ImageTask):
        test_masks = ImageTask.build_mask(targets,
                                          dest_task.mask_val,
                                          tol=1e-3)
        logger.images(test_masks.float(), f"{dest_task}_masks", resize=64)

    dest_task.plot_func(preds, preds_name or f"{dest_task.name}_preds", logger)
    dest_task.plot_func(targets, target_name or f"{dest_task.name}_target",
                        logger)

    if ood_images is not None:
        ood_preds = model.predict(ood_images)
        dest_task.plot_func(ood_preds, ood_name
                            or f"{dest_task.name}_ood_preds", logger)

    for name, loss_model in loss_models.items():
        with torch.no_grad():
            output = loss_model(preds, targets, test_images)
            if hasattr(output, "task"):
                output.task.plot_func(output, name, logger, resize=128)
            else:
                logger.images(output.clamp(min=0, max=1), name, resize=128)
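
A minimal call sketch, assuming a trained model, a VisdomLogger, and a test_set built by load_test as in the surrounding examples:

test_set = load_test(["rgb", "normal"])
plot_images(model, logger, test_set,
            dest_task="normal", show_masks=True)  # masks need an ImageTask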
Example #13
def main():

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda x: logger.step(), feature="loss", freq=25)

    resize = 256
    ood_images = load_ood()[0]
    tasks = [
        get_task(name) for name in [
            'rgb', 'normal', 'principal_curvature', 'depth_zbuffer',
            'sobel_edges', 'reshading', 'keypoints3d', 'keypoints2d'
        ]
    ]

    test_loader = torch.utils.data.DataLoader(TaskDataset(['almena'], tasks),
                                              batch_size=64,
                                              num_workers=12,
                                              shuffle=False,
                                              pin_memory=True)
    imgs = list(itertools.islice(test_loader, 1))[0]
    gt = {tasks[i].name: batch.cuda() for i, batch in enumerate(imgs)}
    num_plot = 4

    logger.images(ood_images, "x", nrow=2, resize=resize)
    edges = finetuned_transfers

    def get_nbrs(task, edges):
        res = []
        for e in edges:
            if task == e.src_task:
                res.append(e)
        return res

    max_depth = 10
    mse_dict = defaultdict(list)

    def search_small(x, task, prefix, visited, depth, endpoint):

        if task.name == 'normal':
            interleave = torch.stack([
                val for pair in zip(x[:num_plot], gt[task.name][:num_plot])
                for val in pair
            ])
            logger.images(interleave.clamp(max=1, min=0),
                          prefix,
                          nrow=2,
                          resize=resize)
            mse, _ = task.loss_func(x, gt[task.name])
            mse_dict[task.name].append(
                (mse.detach().data.cpu().numpy(), prefix))

        for transfer in get_nbrs(task, edges):
            preds = transfer(x)
            next_prefix = f'{transfer.name}({prefix})'
            print(f"{transfer.src_task.name}2{transfer.dest_task.name}",
                  next_prefix)

            if transfer.dest_task.name not in visited:
                visited.add(transfer.dest_task.name)
                res = search_small(preds, transfer.dest_task, next_prefix,
                                   visited, depth + 1, endpoint)
                visited.remove(transfer.dest_task.name)

        return endpoint == task

    def search_full(x, task, prefix, visited, depth, endpoint):
        for transfer in get_nbrs(task, edges):
            preds = transfer(x)
            next_prefix = f'{transfer.name}({prefix})'
            print(f"{transfer.src_task.name}2{transfer.dest_task.name}",
                  next_prefix)
            if transfer.dest_task.name == 'normal':
                interleave = torch.stack([
                    val for pair in zip(preds[:num_plot], gt[
                        transfer.dest_task.name][:num_plot]) for val in pair
                ])
                logger.images(interleave.clamp(max=1, min=0),
                              next_prefix,
                              nrow=2,
                              resize=resize)
                mse, _ = task.loss_func(preds, gt[transfer.dest_task.name])
                mse_dict[transfer.dest_task.name].append(
                    (mse.detach().data.cpu().numpy(), next_prefix))
            if transfer.dest_task.name not in visited:
                visited.add(transfer.dest_task.name)
                res = search_full(preds, transfer.dest_task, next_prefix,
                                  visited, depth + 1, endpoint)
                visited.remove(transfer.dest_task.name)

        return endpoint == task

    def search(x, task, prefix, visited, depth):
        for transfer in get_nbrs(task, edges):
            preds = transfer(x)
            next_prefix = f'{transfer.name}({prefix})'
            print(f"{transfer.src_task.name}2{transfer.dest_task.name}",
                  next_prefix)
            if transfer.dest_task.name == 'normal':
                logger.images(preds.clamp(max=1, min=0),
                              next_prefix,
                              nrow=2,
                              resize=resize)
            if transfer.dest_task.name not in visited:
                visited.add(transfer.dest_task.name)
                res = search(preds, transfer.dest_task, next_prefix, visited,
                             depth + 1)
                visited.remove(transfer.dest_task.name)

    with torch.no_grad():
        # search_full(gt['rgb'], TASK_MAP['rgb'], 'x', {'rgb'}, 1, TASK_MAP['normal'])
        # NOTE: set('rgb') would build {'r', 'g', 'b'}; the visited set must
        # contain the task name itself, hence {'rgb'}.
        search(ood_images, get_task('rgb'), 'x', {'rgb'}, 1)

    for name, mse_list in mse_dict.items():
        mse_list.sort()
        print(name)
        print(mse_list)
        if len(mse_list) == 1: mse_list.append((0, '-'))
        rownames = [pair[1] for pair in mse_list]
        data = [pair[0] for pair in mse_list]
        print(data, rownames)
        logger.bar(data, f'{name}_path_mse', opts={'rownames': rownames})
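
Why the visited set above is written {'rgb'} rather than set('rgb'):

print(set('rgb'))  # {'r', 'g', 'b'} -- set() iterates the string's characters
print({'rgb'})     # {'rgb'}         -- a one-element set holding the task name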