def __init__(self, name, src_task=get_task("rgb"), dest_task=get_task("normal")):
    """Record the transfer's name and endpoint tasks, then load its data.

    Args:
        name: identifier for this instance.
        src_task: source task object (default: rgb).
        dest_task: destination task object (default: normal).
    """
    self.name = name
    self.src_task = src_task
    self.dest_task = dest_task
    self.load_dataset()
def __init__(self, buildings, tasks=None, data_dirs=DATA_DIRS,
             building_files=None, convert_path=None, use_raid=USE_RAID,
             resize=None, unpaired=False, shuffle=True):
    """Index dataset files for `buildings`, keyed by the FIRST task's files.

    Args:
        buildings: building names to include in the dataset.
        tasks: task objects to load per item; defaults to [rgb, normal].
        data_dirs: directories scanned for "{building}_{task}" entries.
        building_files: optional override for the file-listing callable.
        convert_path: optional override for the path-conversion callable.
        use_raid: if True, switch to the RAID-specific path helpers.
        resize: optional resize target stored for later loading.
        unpaired: accepted for interface compatibility (unused here).
        shuffle: if False, the file index is sorted deterministically.
    """
    super().__init__()
    # FIX: the original default `tasks=[get_task("rgb"), get_task("normal")]`
    # was a mutable default argument evaluated once at import time; resolve
    # the default lazily instead (backward compatible for all callers).
    if tasks is None:
        tasks = [get_task("rgb"), get_task("normal")]
    self.buildings, self.tasks, self.data_dirs = buildings, tasks, data_dirs
    self.building_files = building_files or self.building_files
    self.convert_path = convert_path or self.convert_path
    self.resize = resize
    if use_raid:
        self.convert_path = self.convert_path_raid
        self.building_files = self.building_files_raid

    # Map each "{building}_{task}" directory entry to the data_dir holding it.
    self.file_map = {}
    for data_dir in self.data_dirs:
        for file in glob.glob(f'{data_dir}/*'):
            res = parse.parse("{building}_{task}", file[len(data_dir)+1:])
            if res is None:
                continue
            self.file_map[file[len(data_dir)+1:]] = data_dir

    # The dataset is indexed by the files of the first task only.
    task = tasks[0]
    task_files = []
    for building in buildings:
        task_files += self.building_files(task, building)
    print(f" {task.name} file len: {len(task_files)}")
    self.idx_files = task_files if shuffle else sorted(task_files)
    print (" Intersection files len: ", len(self.idx_files))
def load_train_val(train_tasks, val_tasks=None, fast=False,
                   train_buildings=None, val_buildings=None,
                   split_file="config/split.txt", dataset_cls=None,
                   batch_size=32, batch_transforms=cycle,
                   subset=None, subset_size=None, dataaug=False):
    """Build train/val datasets and derive per-epoch step counts.

    Args:
        train_tasks: tasks (names or task objects) for the train split.
        val_tasks: tasks for the val split; defaults to `train_tasks`.
        fast: use a single building and fixed tiny step counts (smoke test).
        train_buildings / val_buildings: override the split-file buildings.
        split_file: YAML file listing "train_buildings" and "val_buildings".
        dataset_cls: dataset class to instantiate; defaults to TaskDataset.
        batch_size: only used to derive train/val step counts.
        batch_transforms: accepted for interface compatibility (unused here).
        subset / subset_size: take a random subset of the training data.
        dataaug: use TrainTaskDataset (augmented) for the training split.

    Returns:
        (train_loader, val_loader, train_step, val_step)
    """
    dataset_cls = dataset_cls or TaskDataset
    train_cls = TrainTaskDataset if dataaug else dataset_cls
    train_tasks = [get_task(t) if isinstance(t, str) else t for t in train_tasks]
    if val_tasks is None:
        val_tasks = train_tasks
    val_tasks = [get_task(t) if isinstance(t, str) else t for t in val_tasks]

    # FIX: yaml.load() without an explicit Loader is deprecated/unsafe, and
    # open() without a context manager leaked the file handle.
    with open(split_file) as f:
        data = yaml.safe_load(f)
    train_buildings = train_buildings or (["almena"] if fast else data["train_buildings"])
    val_buildings = val_buildings or (["almena"] if fast else data["val_buildings"])

    print("number of train images:")
    train_loader = train_cls(buildings=train_buildings, tasks=train_tasks)
    print("number of val images:")
    val_loader = dataset_cls(buildings=val_buildings, tasks=val_tasks)

    if subset_size is not None or subset is not None:
        # Random training subset: explicit size wins over fractional subset.
        train_loader = torch.utils.data.Subset(
            train_loader,
            random.sample(range(len(train_loader)),
                          subset_size or int(len(train_loader) * subset)),
        )

    train_step = int(len(train_loader) // (400 * batch_size))
    val_step = int(len(val_loader) // (400 * batch_size))
    print("Train step: ", train_step)
    print("Val step: ", val_step)
    if fast:
        train_step, val_step = 8, 8
    return train_loader, val_loader, train_step, val_step
def load_sintel_train_val_test(source_task, dest_task, batch_size=64, batch_transforms=cycle ):
    """Create Sintel train/val DataLoaders plus a first test batch.

    Scenes under mount/sintel/training/depth are split 80/20 into train/val.
    Returns (train_loader, val_loader, train_step, val_step, test_set, test_images).
    """
    if isinstance(source_task, str) and isinstance(dest_task, str):
        source_task, dest_task = get_task(source_task), get_task(dest_task)

    scene_paths = glob.glob("mount/sintel/training/depth/*")
    buildings = sorted(p.split('/')[-1] for p in scene_paths)
    train_buildings, val_buildings = train_test_split(buildings, test_size=0.2)
    print (len(train_buildings))
    print (len(val_buildings))

    def make_loader(scenes):
        # Shuffled, pinned-memory loader over the given Sintel scenes.
        dataset = SintelDataset(buildings=scenes, tasks=[source_task, dest_task])
        return torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, num_workers=64,
            shuffle=True, pin_memory=True,
        )

    train_loader = make_loader(train_buildings)
    val_loader = make_loader(val_buildings)

    train_step = int(2248616 // (100 * batch_size))
    val_step = int(245592 // (100 * batch_size))
    print("Train step: ", train_step)
    print("Val step: ", val_step)

    test_set = list(itertools.islice(val_loader, 1))
    test_images = torch.cat([x for x, y in test_set], dim=0)
    return train_loader, val_loader, train_step, val_step, test_set, test_images
def load_train_val(train_tasks, val_tasks=None, fast=False,
                   train_buildings=None, val_buildings=None,
                   split_file="data/split.txt", dataset_cls=None,
                   batch_size=64, batch_transforms=cycle,
                   subset=None, subset_size=None, ):
    """Build train/val datasets with fixed, hard-coded step counts.

    Args:
        train_tasks: tasks (names or task objects) for the train split.
        val_tasks: tasks for the val split; defaults to `train_tasks`.
        fast: use a single building and fixed tiny step counts (smoke test).
        train_buildings / val_buildings: override the split-file buildings.
        split_file: YAML file listing "train_buildings" and "val_buildings".
        dataset_cls: dataset class to instantiate; defaults to TaskDataset.
        batch_size: only used to derive train/val step counts.
        batch_transforms: accepted for interface compatibility (unused here).
        subset / subset_size: take a random subset of the training data.

    Returns:
        (train_loader, val_loader, train_step, val_step)
    """
    dataset_cls = dataset_cls or TaskDataset
    train_tasks = [get_task(t) if isinstance(t, str) else t for t in train_tasks]
    if val_tasks is None:
        val_tasks = train_tasks
    val_tasks = [get_task(t) if isinstance(t, str) else t for t in val_tasks]

    # FIX: yaml.load() without an explicit Loader is deprecated/unsafe, and
    # open() without a context manager leaked the file handle.
    with open(split_file) as f:
        data = yaml.safe_load(f)
    train_buildings = train_buildings or (["almena"] if fast else data["train_buildings"])
    val_buildings = val_buildings or (["almena"] if fast else data["val_buildings"])

    train_loader = dataset_cls(buildings=train_buildings, tasks=train_tasks)
    val_loader = dataset_cls(buildings=val_buildings, tasks=val_tasks)

    if subset_size is not None or subset is not None:
        # Random training subset: explicit size wins over fractional subset.
        train_loader = torch.utils.data.Subset(
            train_loader,
            random.sample(range(len(train_loader)),
                          subset_size or int(len(train_loader) * subset)),
        )

    # Step counts use the hard-coded full-corpus sizes, not len(loader);
    # presumably the historical taskonomy train/val image counts — confirm.
    train_step = int(2248616 // (100 * batch_size))
    val_step = int(245592 // (100 * batch_size))
    print("Train step: ", train_step)
    print("Val step: ", val_step)
    if fast:
        train_step, val_step = 8, 8
    return train_loader, val_loader, train_step, val_step
def __init__(self, src_task, dest_task, checkpoint=True, name=None,
             model_type=None, path=None, pretrained=True, finetuned=False):
    """Resolve the model factory and checkpoint path for a src->dest transfer.

    Resolution order: explicit `model_type`/`path` arguments, then the
    registered `pretrained_transfers` entry, then (optionally) a finetuned
    checkpoint, then an on-disk "{src}2{dest}.pth" file, and finally — for
    same-kind tasks that differ only in resolution — a pure resize module.
    """
    super().__init__()
    if isinstance(src_task, str) and isinstance(dest_task, str):
        src_task, dest_task = get_task(src_task), get_task(dest_task)
    self.src_task, self.dest_task, self.checkpoint = src_task, dest_task, checkpoint
    self.name = name or f"{src_task.name}2{dest_task.name}"
    saved_type, saved_path = None, None
    if model_type is None and path is None:
        # Look up a registered (model factory, weights path) pair for this pair
        # of task names; falls back to (None, None) when unregistered.
        saved_type, saved_path = pretrained_transfers.get(
            (src_task.name, dest_task.name), (None, None))
    self.model_type, self.path = model_type or saved_type, path or saved_path
    self.model = None  # lazily instantiated elsewhere
    if finetuned:
        # Prefer a perceptually-finetuned checkpoint when one exists on disk.
        path = f"{MODELS_DIR}/ft_perceptual/{src_task.name}2{dest_task.name}.pth"
        if os.path.exists(path):
            self.model_type, self.path = saved_type or (
                lambda: get_model(src_task, dest_task)), path
            return
    if self.model_type is None:
        if src_task.kind == dest_task.kind and src_task.resize != dest_task.resize:
            # Same task kind at a different resolution: the "model" is just a
            # resize, with no weights to load (path stays None).
            class Module(TrainableModel):
                def __init__(self):
                    super().__init__()
                def forward(self, x):
                    return resize(x, val=dest_task.resize)
            self.model_type = lambda: Module()
            self.path = None
        path = f"{MODELS_DIR}/{src_task.name}2{dest_task.name}.pth"
        if src_task.name == "keypoints2d" or dest_task.name == "keypoints2d":
            # keypoints2d transfers use the retrained "_new" checkpoints.
            path = f"{MODELS_DIR}/{src_task.name}2{dest_task.name}_new.pth"
        if os.path.exists(path):
            self.model_type, self.path = lambda: get_model(
                src_task, dest_task), path
        # else:
        #     self.model_type = lambda: get_model(src_task, dest_task)
        #     print ("Not using pretrained b/c no model file avail")
    if not pretrained:
        print("Not using pretrained [heavily discouraged]")
        self.path = None
def main():
    """Run each transfer path over the doom test set and save the outputs."""
    logger = VisdomLogger("train", env=JOB)
    test_loader, test_images = load_doom()
    test_images = torch.cat(test_images, dim=0)
    src_task, dest_task = get_task("rgb"), get_task("normal")
    print(test_images.shape)
    src_task.plot_func(test_images, f"images", logger, resize=128)

    paths = ["F(RC(x))", "F(EC(a(x)))", "n(x)", "npstep(x)"]
    for path_str in paths:
        # Parse e.g. "F(RC(x))" into the transfer chain [RC, F] (innermost first).
        path_list = path_str.replace(')', '').split('(')[::-1][1:]
        path = [TRANSFER_MAP[name] for name in path_list]

        class PathModel(TrainableModel):
            # Applies the transfer chain in order, without gradients.
            def __init__(self):
                super().__init__()
            def forward(self, x):
                with torch.no_grad():
                    for f in path:
                        x = f(x)
                return x
            def loss(self, pred, target):
                # No training happens here; the loss is a constant zero.
                loss = torch.tensor(0.0, device=pred.device)
                return loss, (loss.detach(), )

        model = PathModel()
        preds = model.predict(test_loader)
        dest_task.plot_func(preds, f"preds_{path_str}", logger, resize=128)

        transform = transforms.ToPILImage()
        os.makedirs(f"{BASE_DIR}/doom_processed/{path_str}/video2", exist_ok=True)
        for image, file in zip(preds, test_loader.dataset.files):
            image = transform(image.cpu())
            filename = file.split("/")[-1]
            print(filename)
            # FIX: `filename` was computed but never used in the output path,
            # so every frame was written to the same constant filename.
            image.save(
                f"{BASE_DIR}/doom_processed/{path_str}/video2/{filename}")
    print(preds.shape)
def __init__(self, buildings, tasks=None, data_dirs=DATA_DIRS,
             building_files=None, convert_path=None, use_raid=USE_RAID,
             resize=None, unpaired=False, shuffle=True):
    """Index dataset files for `buildings`, keyed by the FIRST task's files.

    Args:
        buildings: building names to include in the dataset.
        tasks: task objects to load per item; defaults to [rgb, normal].
        data_dirs: directories scanned for "{building}_{task}" entries.
        building_files: optional override for the file-listing callable.
        convert_path: optional override for the path-conversion callable.
        use_raid: if True, switch to the RAID-specific path helpers.
        resize: optional resize target stored for later loading.
        unpaired: accepted for interface compatibility (unused here).
        shuffle: if False, the file index is sorted deterministically.
    """
    super().__init__()
    # FIX: the original default `tasks=[get_task("rgb"), get_task("normal")]`
    # was a mutable default argument evaluated once at import time; resolve
    # the default lazily instead (backward compatible for all callers).
    if tasks is None:
        tasks = [get_task("rgb"), get_task("normal")]
    self.buildings, self.tasks, self.data_dirs = buildings, tasks, data_dirs
    self.building_files = building_files or self.building_files
    self.convert_path = convert_path or self.convert_path
    self.resize = resize
    if use_raid:
        self.convert_path = self.convert_path_raid
        self.building_files = self.building_files_raid

    # Map each "{building}_{task}" directory entry to the data_dir holding it.
    # (The former multi-task intersection logic was dead commented-out code
    # and has been removed.)
    self.file_map = {}
    for data_dir in self.data_dirs:
        for file in glob.glob(f'{data_dir}/*'):
            res = parse.parse("{building}_{task}", file[len(data_dir)+1:])
            if res is None:
                continue
            self.file_map[file[len(data_dir)+1:]] = data_dir
        print(f'{data_dir}/*')

    # The dataset is indexed by the files of the first task only.
    task = tasks[0]
    task_files = []
    for building in buildings:
        task_files += self.building_files(task, building)
    print(f"{task.name} file len: {len(task_files)}")
    self.idx_files = task_files if shuffle else sorted(task_files)
    print ("Intersection files len: ", len(self.idx_files))
def load_test(all_tasks, buildings=["almena", "albertville", "espanola"], sample=4):
    """Return one batch of `sample` items per building, concatenated per task.

    Generalized from three hard-coded, copy-pasted loaders to a loop over
    any number of buildings (default behavior is unchanged).

    Args:
        all_tasks: tasks (names or task objects) to load for each item.
        buildings: buildings to draw one batch from each.
        sample: batch size taken from each building.

    Returns:
        Tuple with one tensor per task, each of len(buildings) * sample items.
    """
    all_tasks = [get_task(t) if isinstance(t, str) else t for t in all_tasks]
    batches = []
    for building in buildings:
        print(f"number of images in {building}:")
        loader = torch.utils.data.DataLoader(
            TaskDataset(buildings=[building], tasks=all_tasks, shuffle=False),
            batch_size=sample,
            num_workers=0,
            shuffle=False,
            pin_memory=True,
        )
        batches.append(list(itertools.islice(loader, 1))[0])
    test_set = tuple(torch.cat(parts, dim=0) for parts in zip(*batches))
    return test_set
def get_model(src_task, dest_task):
    """Instantiate a fresh, untrained model for the src->dest task pair.

    A registered architecture in `model_types` takes precedence; otherwise
    the architecture is chosen from the task kinds. Returns None when no
    architecture is known for the pair.
    """
    if isinstance(src_task, str) and isinstance(dest_task, str):
        src_task, dest_task = get_task(src_task), get_task(dest_task)

    pair = (src_task.name, dest_task.name)
    if pair in model_types:
        return model_types[pair]()

    if isinstance(src_task, ImageTask):
        if isinstance(dest_task, ImageTask):
            # Image-to-image: UNet sized by the tasks' channel counts.
            return UNet(downsample=3, in_channels=src_task.shape[0],
                        out_channels=dest_task.shape[0])
        if isinstance(dest_task, ClassTask):
            return ResNet(in_channels=src_task.shape[0],
                          out_channels=dest_task.classes)
        if isinstance(dest_task, PointInfoTask):
            return ResNet(out_channels=dest_task.out_channels)
    return None
def load_test(all_tasks, buildings=["almena", "albertville"], sample=4):
    """Return one batch of `sample` items per building, concatenated per task.

    Generalized from two hard-coded, copy-pasted loaders to a loop over any
    number of buildings (default behavior is unchanged).

    Args:
        all_tasks: tasks (names or task objects) to load for each item.
        buildings: buildings to draw one batch from each.
        sample: batch size (and worker count) per building loader.

    Returns:
        Tuple with one tensor per task, each of len(buildings) * sample items.
    """
    all_tasks = [get_task(t) if isinstance(t, str) else t for t in all_tasks]
    batches = []
    for building in buildings:
        loader = torch.utils.data.DataLoader(
            TaskDataset(buildings=[building], tasks=all_tasks, shuffle=False),
            batch_size=sample,
            num_workers=sample,
            shuffle=False,
            pin_memory=True,
        )
        batches.append(list(itertools.islice(loader, 1))[0])
    test_set = tuple(torch.cat(parts, dim=0) for parts in zip(*batches))
    return test_set
def plot_images( model, logger, test_set, dest_task="normal", ood_images=None, show_masks=False, loss_models={}, preds_name=None, target_name=None, ood_name=None, ):
    """Log model predictions and targets (plus optional masks, OOD preds,
    and loss-model visualizations) to the given logger."""
    from task_configs import get_task, ImageTask

    test_images, preds, targets, losses, _ = model.predict_with_data(test_set)
    if isinstance(dest_task, str):
        dest_task = get_task(dest_task)

    if show_masks and isinstance(dest_task, ImageTask):
        # Visualize the validity mask derived from the target's mask value.
        test_masks = ImageTask.build_mask(targets, dest_task.mask_val, tol=1e-3)
        logger.images(test_masks.float(), f"{dest_task}_masks", resize=64)

    prediction_label = preds_name or f"{dest_task.name}_preds"
    target_label = target_name or f"{dest_task.name}_target"
    dest_task.plot_func(preds, prediction_label, logger)
    dest_task.plot_func(targets, target_label, logger)

    if ood_images is not None:
        ood_preds = model.predict(ood_images)
        dest_task.plot_func(ood_preds, ood_name or f"{dest_task.name}_ood_preds", logger)

    for name, loss_model in loss_models.items():
        with torch.no_grad():
            output = loss_model(preds, targets, test_images)
            if hasattr(output, "task"):
                output.task.plot_func(output, name, logger, resize=128)
            else:
                logger.images(output.clamp(min=0, max=1), name, resize=128)
def main():
    """Search the transfer graph from rgb, logging normal-space results."""
    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda x: logger.step(), feature="loss", freq=25)
    resize = 256

    ood_images = load_ood()[0]
    tasks = [
        get_task(name) for name in [
            'rgb', 'normal', 'principal_curvature', 'depth_zbuffer',
            'sobel_edges', 'reshading', 'keypoints3d', 'keypoints2d'
        ]
    ]
    test_loader = torch.utils.data.DataLoader(TaskDataset(['almena'], tasks),
                                              batch_size=64,
                                              num_workers=12,
                                              shuffle=False,
                                              pin_memory=True)
    imgs = list(itertools.islice(test_loader, 1))[0]
    # Ground truth per task name, moved to GPU.
    gt = {tasks[i].name: batch.cuda() for i, batch in enumerate(imgs)}
    num_plot = 4
    logger.images(ood_images, f"x", nrow=2, resize=resize)

    edges = finetuned_transfers

    def get_nbrs(task, edges):
        # All transfers whose source task is `task`.
        res = []
        for e in edges:
            if task == e.src_task:
                res.append(e)
        return res

    max_depth = 10
    mse_dict = defaultdict(list)

    def search_small(x, task, prefix, visited, depth, endpoint):
        # DFS that logs/scores at 'normal' nodes BEFORE expanding them.
        if task.name == 'normal':
            interleave = torch.stack([
                val for pair in zip(x[:num_plot], gt[task.name][:num_plot])
                for val in pair
            ])
            logger.images(interleave.clamp(max=1, min=0),
                          prefix, nrow=2, resize=resize)
            mse, _ = task.loss_func(x, gt[task.name])
            mse_dict[task.name].append(
                (mse.detach().data.cpu().numpy(), prefix))
        for transfer in get_nbrs(task, edges):
            preds = transfer(x)
            next_prefix = f'{transfer.name}({prefix})'
            print(f"{transfer.src_task.name}2{transfer.dest_task.name}",
                  next_prefix)
            if transfer.dest_task.name not in visited:
                visited.add(transfer.dest_task.name)
                res = search_small(preds, transfer.dest_task, next_prefix,
                                   visited, depth + 1, endpoint)
                visited.remove(transfer.dest_task.name)
        return endpoint == task

    def search_full(x, task, prefix, visited, depth, endpoint):
        # DFS that logs/scores each edge landing on 'normal' as it is taken.
        for transfer in get_nbrs(task, edges):
            preds = transfer(x)
            next_prefix = f'{transfer.name}({prefix})'
            print(f"{transfer.src_task.name}2{transfer.dest_task.name}",
                  next_prefix)
            if transfer.dest_task.name == 'normal':
                interleave = torch.stack([
                    val for pair in zip(preds[:num_plot],
                                        gt[transfer.dest_task.name][:num_plot])
                    for val in pair
                ])
                logger.images(interleave.clamp(max=1, min=0),
                              next_prefix, nrow=2, resize=resize)
                # NOTE(review): this scores with the *source* task's
                # loss_func; transfer.dest_task.loss_func looks intended —
                # confirm before changing.
                mse, _ = task.loss_func(preds, gt[transfer.dest_task.name])
                mse_dict[transfer.dest_task.name].append(
                    (mse.detach().data.cpu().numpy(), next_prefix))
            if transfer.dest_task.name not in visited:
                visited.add(transfer.dest_task.name)
                res = search_full(preds, transfer.dest_task, next_prefix,
                                  visited, depth + 1, endpoint)
                visited.remove(transfer.dest_task.name)
        return endpoint == task

    def search(x, task, prefix, visited, depth):
        # DFS that only logs 'normal' predictions (no MSE; used for OOD).
        for transfer in get_nbrs(task, edges):
            preds = transfer(x)
            next_prefix = f'{transfer.name}({prefix})'
            print(f"{transfer.src_task.name}2{transfer.dest_task.name}",
                  next_prefix)
            if transfer.dest_task.name == 'normal':
                logger.images(preds.clamp(max=1, min=0),
                              next_prefix, nrow=2, resize=resize)
            if transfer.dest_task.name not in visited:
                visited.add(transfer.dest_task.name)
                res = search(preds, transfer.dest_task, next_prefix,
                             visited, depth + 1)
                visited.remove(transfer.dest_task.name)

    with torch.no_grad():
        # search_full(gt['rgb'], TASK_MAP['rgb'], 'x', {'rgb'}, 1, TASK_MAP['normal'])
        # FIX: set('rgb') builds {'r','g','b'} — the characters, not the task
        # name — so 'rgb' was never actually marked as visited. Use a set
        # literal containing the task name.
        search(ood_images, get_task('rgb'), 'x', {'rgb'}, 1)

    for name, mse_list in mse_dict.items():
        mse_list.sort()
        print(name)
        print(mse_list)
        if len(mse_list) == 1:
            # The bar plot needs at least two rows; pad with a dummy entry.
            mse_list.append((0, '-'))
        rownames = [pair[1] for pair in mse_list]
        data = [pair[0] for pair in mse_list]
        print(data, rownames)
        logger.bar(data, f'{name}_path_mse', opts={'rownames': rownames})