def __init__(self, model_cfg, ckpt_path, device, line_detector_cfg,
             line_matcher_cfg, multiscale=False, scales=[1., 2.]):
    # Get loss weights if dynamic weighting
    _, loss_weights = get_loss_and_weights(model_cfg, device)
    self.device = device

    # Initialize the CNN backbone
    self.model = get_model(model_cfg, loss_weights)
    checkpoint = torch.load(ckpt_path, map_location=self.device)
    self.model.load_state_dict(checkpoint["model_state_dict"])
    self.model = self.model.to(self.device)
    self.model = self.model.eval()

    self.grid_size = model_cfg["grid_size"]
    self.junc_detect_thresh = model_cfg["detection_thresh"]
    self.max_num_junctions = model_cfg.get("max_num_junctions", 300)

    # Initialize the line detector
    self.line_detector = LineSegmentDetectionModule(**line_detector_cfg)
    self.multiscale = multiscale
    self.scales = scales

    # Initialize the line matcher
    self.line_matcher = WunschLineMatcher(**line_matcher_cfg)

    # Print the line detector configuration for debugging
    for key, val in line_detector_cfg.items():
        print(f"[Debug] {key}: {val}")
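# The checkpoint-restore pattern above (torch.load with map_location, then
# load_state_dict and eval()) can be shown in isolation. A minimal sketch,
# assuming an arbitrary torch.nn.Module and a hypothetical checkpoint path;
# the helper name is illustrative and not part of this repo:
import torch


def load_eval_model(model, ckpt_path, device):
    """ Restore weights from a checkpoint dict and switch to eval mode. """
    # map_location lets CPU-only machines load GPU-trained checkpoints
    checkpoint = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    return model.to(device).eval()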
def __init__(self, model_cfg, ckpt_path, device, line_detector_cfg,
             junc_detect_thresh=None):
    """ SOLD² line detector taking raw images as input.
    Parameters:
        model_cfg: config for CNN model
        ckpt_path: path to the weights
        device: torch device to run the model on
        line_detector_cfg: config file for the line detection module
        junc_detect_thresh: optional junction detection threshold
            overriding the one from model_cfg
    """
    # Get loss weights if dynamic weighting
    _, loss_weights = get_loss_and_weights(model_cfg, device)
    self.device = device

    # Initialize the CNN backbone
    self.model = get_model(model_cfg, loss_weights)
    checkpoint = torch.load(ckpt_path, map_location=self.device)
    self.model.load_state_dict(checkpoint["model_state_dict"])
    self.model = self.model.to(self.device)
    self.model = self.model.eval()

    self.grid_size = model_cfg["grid_size"]
    if junc_detect_thresh is not None:
        self.junc_detect_thresh = junc_detect_thresh
    else:
        self.junc_detect_thresh = model_cfg["detection_thresh"]

    # Initialize the line detector
    self.line_detector_cfg = line_detector_cfg
    self.line_detector = LineSegmentDetectionModule(**line_detector_cfg)
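# The threshold fallback above (an explicit override wins, otherwise the
# value from the model config is used) in isolation. A minimal, runnable
# sketch; the function name is illustrative, not part of the repo, and
# 1 / 65 is just a sample config value:
def resolve_junc_thresh(model_cfg, junc_detect_thresh=None):
    """ Return the override threshold if given, else the config default. """
    if junc_detect_thresh is not None:
        return junc_detect_thresh
    return model_cfg["detection_thresh"]


assert resolve_junc_thresh({"detection_thresh": 1 / 65}) == 1 / 65
assert resolve_junc_thresh({"detection_thresh": 1 / 65}, 0.5) == 0.5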
def train(base_path, config):
    # We do not retrieve the color from the config (it should not be
    # specified anyway) because we render our own segmentation images
    # using white color.
    image_extension = config.IMAGE_EXTENSION
    object_model_path = config.OBJECT_MODEL_PATH
    ground_truth_path = config.GT_PATH
    data_path = config.DATA_PATH
    weights_path = config.WEIGHTS_PATH
    output_path = os.path.join(base_path, config.OUTPUT_PATH)

    assert os.path.exists(object_model_path), \
        "The object model file {} does not exist.".format(object_model_path)
    assert os.path.exists(ground_truth_path), \
        "The ground-truth file {} does not exist.".format(ground_truth_path)
    if weights_path != "":
        assert os.path.exists(weights_path), \
            "The weights file {} does not exist.".format(weights_path)

    # The paths where we store the generated object coordinate images as
    # well as the cropped original and segmentation images
    images_path = os.path.join(data_path, "images")
    segmentations_path = os.path.join(data_path, "segmentations")
    obj_coords_path = os.path.join(data_path, "obj_coords")

    # Retrieve the rendered images to add them to the datasets
    images = util.get_files_at_path_of_extensions(images_path,
                                                  [image_extension])
    util.sort_list_by_num_in_string_entries(images)
    segmentation_renderings = util.get_files_at_path_of_extensions(
        segmentations_path, ['png'])
    util.sort_list_by_num_in_string_entries(segmentation_renderings)
    obj_coordinate_renderings = util.get_files_at_path_of_extensions(
        obj_coords_path, ['tiff'])
    util.sort_list_by_num_in_string_entries(obj_coordinate_renderings)

    assert len(images) == len(segmentation_renderings) \
        == len(obj_coordinate_renderings), \
        "Number of files in input folders does not match."

    print("Populating datasets.")
    train_dataset = dataset.Dataset()
    val_dataset = dataset.Dataset()

    # Open the json files that hold the filenames for the respective datasets
    with open(os.path.join(base_path, config.TRAIN_FILE), 'r') as train_file, \
         open(os.path.join(base_path, config.VAL_FILE), 'r') as val_file:
        train_filenames = json.load(train_file)
        val_filenames = json.load(val_file)

    # Fill the training dict
    for i, image in enumerate(images):
        segmentation_image = segmentation_renderings[i]
        loaded_segmentation_image = cv2.imread(
            os.path.join(segmentations_path, segmentation_image))
        # We do not want to scale object coordinates in the network because
        # that creates imprecisions, i.e. we can only pad object coordinates
        # to fill the image size but not resize them. Otherwise, the object
        # coordinates would not fit in the batch arrays when width or height
        # exceed the dimensions specified in the config.
        if loaded_segmentation_image.shape[0] > config.IMAGE_DIM or \
           loaded_segmentation_image.shape[1] > config.IMAGE_DIM:
            raise Exception(
                "Image dimension exceeds image dim {} specified in config. "
                "File: {} with size {}".format(
                    config.IMAGE_DIM, image,
                    loaded_segmentation_image.shape))
        # TODO: add check if segmentation image contains color
        object_coordinate_image = obj_coordinate_renderings[i]
        image_path = os.path.join(images_path, image)
        segmentation_path = os.path.join(segmentations_path,
                                         segmentation_image)
        obj_coord_path = os.path.join(obj_coords_path,
                                      object_coordinate_image)
        # Check both cases, it might be that the image is not to be added
        # at all
        if image in train_filenames:
            train_dataset.add_training_example(image_path,
                                               segmentation_path,
                                               obj_coord_path)
        elif image in val_filenames:
            val_dataset.add_training_example(image_path,
                                             segmentation_path,
                                             obj_coord_path)

    print("Added {} images for training and {} images for validation.".format(
        train_dataset.size(), val_dataset.size()))

    # Here we import the requested model
    model = model_util.get_model(config.MODEL)
    network_model = model.FlowerPowerCNN('training', config, output_path)
    if weights_path != "":
        network_model.load_weights(
            weights_path, by_name=True,
            exclude=config.LAYERS_TO_EXCLUDE_FROM_WEIGHT_LOADING)
    print("Starting training.")
    network_model.train(train_dataset, val_dataset, config)
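# A hypothetical way to invoke train(): the project presumably provides its
# own config class; a SimpleNamespace stands in here, and every path and the
# model name below are placeholder assumptions, not files or values shipped
# with the repo:
from types import SimpleNamespace

example_train_config = SimpleNamespace(
    IMAGE_EXTENSION="png",
    OBJECT_MODEL_PATH="/data/models/object.ply",  # placeholder
    GT_PATH="/data/gt.json",                      # placeholder
    DATA_PATH="/data",                            # placeholder
    WEIGHTS_PATH="",                              # empty: train from scratch
    OUTPUT_PATH="output",
    TRAIN_FILE="train.json",
    VAL_FILE="val.json",
    IMAGE_DIM=512,
    MODEL="flowerpower",                          # placeholder model name
    LAYERS_TO_EXCLUDE_FROM_WEIGHT_LOADING=[],
)
# train("/path/to/experiment", example_train_config)  # needs real data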
def inference(base_path, config):
    images_path = config.IMAGES_PATH
    image_extension = config.IMAGE_EXTENSION
    segmentation_images_path = config.SEGMENTATION_IMAGES_PATH
    segmentation_image_extension = config.SEGMENTATION_IMAGE_EXTENSION
    segmentation_color = config.SEGMENTATION_COLOR
    object_model_path = config.OBJECT_MODEL_PATH
    weights_path = config.WEIGHTS_PATH
    batch_size = config.BATCH_SIZE
    image_list = os.path.join(base_path, config.IMAGE_LIST)
    output_path = os.path.join(base_path, config.OUTPUT_PATH)
    if not os.path.exists(output_path):
        os.makedirs(output_path)

    assert os.path.exists(images_path), \
        "The images path {} does not exist.".format(images_path)
    assert image_extension in ['png', 'jpg', 'jpeg'], \
        "Unknown image extension."
    assert os.path.exists(segmentation_images_path), \
        "The segmentation images path {} does not exist.".format(
            segmentation_images_path)
    assert segmentation_image_extension in ['png', 'jpg', 'jpeg'], \
        "Unknown segmentation image extension."
    assert os.path.exists(object_model_path), \
        "The object model file {} does not exist.".format(object_model_path)
    assert os.path.exists(weights_path), \
        "The weights file {} does not exist.".format(weights_path)

    # Pass the extensions as lists, matching the usage in train()
    image_paths = util.get_files_at_path_of_extensions(images_path,
                                                       [image_extension])
    util.sort_list_by_num_in_string_entries(image_paths)
    segmentation_image_paths = util.get_files_at_path_of_extensions(
        segmentation_images_path, [segmentation_image_extension])
    util.sort_list_by_num_in_string_entries(segmentation_image_paths)

    images = []
    segmentation_images = []
    cropped_segmentation_images = []
    # Bounding boxes
    bbs = []

    print("Preparing data.")

    # Keep only the images listed in the image list file
    temp_image_paths = []
    temp_segmentation_image_paths = []
    with open(image_list, "r") as loaded_image_list:
        images_to_process = json.load(loaded_image_list)
    for index in range(len(image_paths)):
        if image_paths[index] in images_to_process:
            temp_image_paths.append(image_paths[index])
            temp_segmentation_image_paths.append(
                segmentation_image_paths[index])
    image_paths = temp_image_paths
    segmentation_image_paths = temp_segmentation_image_paths

    # Prepare data, i.e. crop images to the segmentation mask
    for index in range(len(image_paths)):
        image_path = image_paths[index]
        image = cv2.imread(os.path.join(images_path, image_path))
        segmentation_image_path = segmentation_image_paths[index]
        segmentation_image = cv2.imread(
            os.path.join(segmentation_images_path, segmentation_image_path))
        # TODO: add check if segmentation image contains color
        image, frame = util.crop_image_on_segmentation_color(
            image, segmentation_image, segmentation_color, return_frame=True)
        bbs.append(frame)
        cropped_segmentation_image = util.crop_image_on_segmentation_color(
            segmentation_image, segmentation_image, segmentation_color)
        images.append(image)
        segmentation_images.append(segmentation_image)
        cropped_segmentation_images.append(cropped_segmentation_image)

    # Otherwise the datatype is int64, which is not JSON serializable
    bbs = np.array(bbs).astype(np.int32)

    print("Running network inference.")
    # Here we import the requested model
    model = model_util.get_model(config.MODEL)
    network_model = model.FlowerPowerCNN('inference', config, output_path)
    network_model.load_weights(weights_path, by_name=True)

    results = []
    # We only store the filename + extension
    object_model_path = os.path.basename(object_model_path)
    current_batch = 0
    batch_size = min(batch_size, len(images))
    # Set the batch size for the network to the current batch size
    network_model.config.BATCH_SIZE = batch_size
    while current_batch < len(images):
        # Account for the number of images not being divisible by the
        # batch size
        batch_size = min(batch_size, len(images) - current_batch)
        batch_start = current_batch
        current_batch += batch_size
        batch_end = current_batch
        network_model.config.BATCH_SIZE = batch_size
        predictions = network_model.predict(
            images[batch_start:batch_end],
            cropped_segmentation_images[batch_start:batch_end],
            verbose=1)
        for index in range(len(predictions)):
            results.append({
                "prediction": predictions[index],
                "image": image_paths[batch_start + index],
                "segmentation_image":
                    segmentation_image_paths[batch_start + index],
                "object_model": object_model_path,
                "bb": bbs[batch_start + index]
            })
    return results
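# The while-loop above walks the images in batches and shrinks the last
# batch when len(images) is not divisible by batch_size. The same indexing
# logic in isolation (self-contained sketch, not repo code):
def iter_batches(n_items, batch_size):
    """ Yield (start, end) index pairs covering n_items in batches. """
    current = 0
    while current < n_items:
        size = min(batch_size, n_items - current)
        yield current, current + size
        current += size


assert list(iter_batches(10, 4)) == [(0, 4), (4, 8), (8, 10)]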
def export_predictions(args, dataset_cfg, model_cfg, output_path,
                       export_dataset_mode):
    """ Export predictions. """
    # Get the test configuration
    test_cfg = model_cfg["test"]

    # Create the dataset and dataloader based on the export_dataset_mode
    print("\t Initializing dataset and dataloader")
    batch_size = 4
    export_dataset, collate_fn = get_dataset(export_dataset_mode, dataset_cfg)
    export_loader = DataLoader(export_dataset, batch_size=batch_size,
                               num_workers=test_cfg.get("num_workers", 4),
                               shuffle=False, pin_memory=False,
                               collate_fn=collate_fn)
    print("\t Successfully initialized dataset and dataloader.")

    # Initialize the model and load the checkpoint
    model = get_model(model_cfg, mode="test")
    checkpoint = get_latest_checkpoint(args.resume_path, args.checkpoint_name)
    model = restore_weights(model, checkpoint["model_state_dict"])
    model = model.cuda()
    model.eval()
    print("\t Successfully initialized model")

    # Start the export process
    print("[Info] Start exporting predictions")
    output_dataset_path = output_path + ".h5"
    filename_idx = 0
    with h5py.File(output_dataset_path, "w", libver="latest") as f:
        # The swmr keyword only applies to read mode; enable it on the
        # handle instead (same pattern as the homography export below)
        f.swmr_mode = True
        # Iterate through all the data in the dataloader
        for data in tqdm(export_loader, ascii=True):
            # Fetch the data
            junc_map = data["junction_map"]
            heatmap = data["heatmap"]
            valid_mask = data["valid_mask"]
            input_images = data["image"].cuda()

            # Run the forward pass
            with torch.no_grad():
                outputs = model(input_images)

            # Convert predictions
            junc_np = convert_junc_predictions(
                outputs["junctions"], model_cfg["grid_size"],
                model_cfg["detection_thresh"], 300)
            junc_map_np = junc_map.numpy().transpose(0, 2, 3, 1)
            heatmap_np = softmax(
                outputs["heatmap"].detach(),
                dim=1).cpu().numpy().transpose(0, 2, 3, 1)
            heatmap_gt_np = heatmap.numpy().transpose(0, 2, 3, 1)
            valid_mask_np = valid_mask.numpy().transpose(0, 2, 3, 1)

            # Data entries to save
            current_batch_size = input_images.shape[0]
            for batch_idx in range(current_batch_size):
                output_data = {
                    "image": input_images.cpu().numpy().transpose(
                        0, 2, 3, 1)[batch_idx],
                    "junc_gt": junc_map_np[batch_idx],
                    "junc_pred": junc_np["junc_pred"][batch_idx],
                    "junc_pred_nms": junc_np["junc_pred_nms"][
                        batch_idx].astype(np.float32),
                    "heatmap_gt": heatmap_gt_np[batch_idx],
                    "heatmap_pred": heatmap_np[batch_idx],
                    "valid_mask": valid_mask_np[batch_idx],
                    "junc_points": data["junctions"][batch_idx].numpy()[
                        0].round().astype(np.int32),
                    "line_map": data["line_map"][batch_idx].numpy()[
                        0].astype(np.int32)
                }

                # Save data to the h5 dataset
                num_pad = math.ceil(math.log10(len(export_loader))) + 1
                output_key = get_padded_filename(num_pad, filename_idx)
                f_group = f.create_group(output_key)

                # Store data (without shadowing the dict we iterate over)
                for key, value in output_data.items():
                    f_group.create_dataset(key, data=value,
                                           compression="gzip")
                filename_idx += 1
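# get_padded_filename() is imported from the repo's utilities. A plausible
# stand-in, assuming it zero-pads the sample index to num_pad digits (this
# behavior is an assumption, not confirmed by the code above):
def get_padded_filename_sketch(num_pad, idx):
    """ Zero-pad idx to num_pad digits, e.g. (5, 42) -> "00042". """
    return str(idx).zfill(num_pad)


assert get_padded_filename_sketch(5, 42) == "00042"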
def export_homograpy_adaptation(args, dataset_cfg, model_cfg, output_path,
                                export_dataset_mode, device):
    """ Export homography adaptation results. """
    # Check if the export_dataset_mode is supported
    supported_modes = ["train", "test"]
    if export_dataset_mode not in supported_modes:
        raise ValueError(
            "[Error] The specified export_dataset_mode is not supported.")

    # Get the test configuration
    test_cfg = model_cfg["test"]
    # Get the homography adaptation configuration
    homography_cfg = dataset_cfg.get("homography_adaptation", None)
    if homography_cfg is None:
        raise ValueError(
            "[Error] Empty homography_adaptation entry in config.")

    # Create the dataset and dataloader based on the export_dataset_mode
    print("\t Initializing dataset and dataloader")
    batch_size = args.export_batch_size
    export_dataset, collate_fn = get_dataset(export_dataset_mode, dataset_cfg)
    export_loader = DataLoader(export_dataset, batch_size=batch_size,
                               num_workers=test_cfg.get("num_workers", 4),
                               shuffle=False, pin_memory=False,
                               collate_fn=collate_fn)
    print("\t Successfully initialized dataset and dataloader.")

    # Initialize the model and load the checkpoint
    model = get_model(model_cfg, mode="test")
    checkpoint = get_latest_checkpoint(args.resume_path, args.checkpoint_name,
                                       device)
    model = restore_weights(model, checkpoint["model_state_dict"])
    model = model.to(device).eval()
    print("\t Successfully initialized model")

    # Start the export process
    print("[Info] Start exporting predictions")
    output_dataset_path = output_path + ".h5"
    with h5py.File(output_dataset_path, "w", libver="latest") as f:
        f.swmr_mode = True
        for data in tqdm(export_loader, ascii=True):
            input_images = data["image"].to(device)
            file_keys = data["file_key"]
            batch_size = input_images.shape[0]

            # Run the homography adaptation
            outputs = homography_adaptation(input_images, model,
                                            model_cfg["grid_size"],
                                            homography_cfg)

            # Save the entries
            for batch_idx in range(batch_size):
                # Get the save key
                save_key = file_keys[batch_idx]
                output_data = {
                    "image": input_images.cpu().numpy().transpose(
                        0, 2, 3, 1)[batch_idx],
                    "junc_prob_mean": outputs["junc_probs_mean"].cpu()
                        .numpy().transpose(0, 2, 3, 1)[batch_idx],
                    "junc_prob_max": outputs["junc_probs_max"].cpu()
                        .numpy().transpose(0, 2, 3, 1)[batch_idx],
                    "junc_count": outputs["junc_counts"].cpu()
                        .numpy().transpose(0, 2, 3, 1)[batch_idx],
                    "heatmap_prob_mean": outputs["heatmap_probs_mean"].cpu()
                        .numpy().transpose(0, 2, 3, 1)[batch_idx],
                    "heatmap_prob_max": outputs["heatmap_probs_max"].cpu()
                        .numpy().transpose(0, 2, 3, 1)[batch_idx],
                    "heatmap_count": outputs["heatmap_counts"].cpu()
                        .numpy().transpose(0, 2, 3, 1)[batch_idx]
                }

                # Create the group and write the data (without shadowing
                # the dict we iterate over)
                f_group = f.create_group(save_key)
                for key, value in output_data.items():
                    f_group.create_dataset(key, data=value,
                                           compression="gzip")
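# Reading one exported group back from the h5 file written above. A minimal
# sketch; the file path "export.h5" and the group key are placeholders:
import h5py


def read_export_group(path, key):
    """ Load every dataset of one group into a dict of numpy arrays. """
    with h5py.File(path, "r") as f:
        return {name: ds[()] for name, ds in f[key].items()}

# data = read_export_group("export.h5", "some_file_key")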
def train_net(args, dataset_cfg, model_cfg, output_path):
    """ Main training function. """
    # Add a version compatibility check
    if model_cfg.get("weighting_policy") is None:
        # Default to static weighting
        model_cfg["weighting_policy"] = "static"

    # Get the train and test configs
    train_cfg = model_cfg["train"]
    test_cfg = model_cfg["test"]

    # Create the train and test datasets
    print("\t Initializing dataset...")
    train_dataset, train_collate_fn = get_dataset("train", dataset_cfg)
    test_dataset, test_collate_fn = get_dataset("test", dataset_cfg)

    # Create the dataloaders
    train_loader = DataLoader(train_dataset,
                              batch_size=train_cfg["batch_size"],
                              num_workers=8, shuffle=True, pin_memory=True,
                              collate_fn=train_collate_fn)
    test_loader = DataLoader(test_dataset,
                             batch_size=test_cfg.get("batch_size", 1),
                             num_workers=test_cfg.get("num_workers", 1),
                             shuffle=False, pin_memory=False,
                             collate_fn=test_collate_fn)
    print("\t Successfully initialized dataloaders.")

    # Get the loss functions and weights first
    loss_funcs, loss_weights = get_loss_and_weights(model_cfg)

    # If resuming, restore the model and optimizer states.
    if args.resume:
        # Create the model and load the state dict
        checkpoint = get_latest_checkpoint(args.resume_path,
                                           args.checkpoint_name)
        model = get_model(model_cfg, loss_weights)
        model = restore_weights(model, checkpoint["model_state_dict"])
        model = model.cuda()
        optimizer = torch.optim.Adam(
            [{"params": model.parameters(),
              "initial_lr": model_cfg["learning_rate"]}],
            model_cfg["learning_rate"], amsgrad=True)
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

        # Optionally get the learning rate scheduler
        scheduler = get_lr_scheduler(
            lr_decay=model_cfg.get("lr_decay", False),
            lr_decay_cfg=model_cfg.get("lr_decay_cfg", None),
            optimizer=optimizer)
        # If we start to use the learning rate scheduler from the middle
        if ((scheduler is not None)
                and (checkpoint.get("scheduler_state_dict", None)
                     is not None)):
            scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        start_epoch = checkpoint["epoch"] + 1
    # Otherwise, initialize all the components from scratch.
    else:
        # Create the model and optimizer
        model = get_model(model_cfg, loss_weights)
        # Optionally load the pretrained weights
        if args.pretrained:
            print("\t [Debug] Loading pretrained weights...")
            checkpoint = get_latest_checkpoint(args.pretrained_path,
                                               args.checkpoint_name)
            # If auto weighting, restore from non-auto weighting
            model = restore_weights(model, checkpoint["model_state_dict"],
                                    strict=False)
            print("\t [Debug] Finished loading pretrained weights!")
        model = model.cuda()
        optimizer = torch.optim.Adam(
            [{"params": model.parameters(),
              "initial_lr": model_cfg["learning_rate"]}],
            model_cfg["learning_rate"], amsgrad=True)
        # Optionally get the learning rate scheduler
        scheduler = get_lr_scheduler(
            lr_decay=model_cfg.get("lr_decay", False),
            lr_decay_cfg=model_cfg.get("lr_decay_cfg", None),
            optimizer=optimizer)
        start_epoch = 0

    print("\t Successfully initialized model")

    # Define the total loss
    policy = model_cfg.get("weighting_policy", "static")
    loss_func = TotalLoss(loss_funcs, loss_weights, policy).cuda()
    if "descriptor_decoder" in model_cfg:
        metric_func = Metrics(model_cfg["detection_thresh"],
                              model_cfg["prob_thresh"],
                              model_cfg["descriptor_loss_cfg"]["grid_size"],
                              desc_metric_lst='all')
    else:
        metric_func = Metrics(model_cfg["detection_thresh"],
                              model_cfg["prob_thresh"],
                              model_cfg["grid_size"])

    # Define the summary writer
    logdir = os.path.join(output_path, "log")
    writer = SummaryWriter(logdir=logdir)

    # Start the training loop
    for epoch in range(start_epoch, model_cfg["epochs"]):
        # Record the learning rate
        current_lr = optimizer.state_dict()["param_groups"][0]["lr"]
        writer.add_scalar("LR/lr", current_lr, epoch)

        # Train for one epoch
        print("\n\n================== Training ====================")
        train_single_epoch(model=model, model_cfg=model_cfg,
                           optimizer=optimizer, loss_func=loss_func,
                           metric_func=metric_func,
                           train_loader=train_loader,
                           writer=writer, epoch=epoch)

        # Do the validation
        print("\n\n================== Validation ==================")
        validate(model=model, model_cfg=model_cfg, loss_func=loss_func,
                 metric_func=metric_func, val_loader=test_loader,
                 writer=writer, epoch=epoch)

        # Update the scheduler
        if scheduler is not None:
            scheduler.step()

        # Save checkpoints
        file_name = os.path.join(output_path,
                                 "checkpoint-epoch%03d-end.tar" % epoch)
        print("[Info] Saving checkpoint %s ..." % file_name)
        save_dict = {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "model_cfg": model_cfg
        }
        if scheduler is not None:
            save_dict.update(
                {"scheduler_state_dict": scheduler.state_dict()})
        torch.save(save_dict, file_name)

        # Remove the outdated checkpoints
        remove_old_checkpoints(output_path, model_cfg.get("max_ckpt", 15))
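# The checkpoint saved at the end of each epoch can be read back the same
# way the resume branch does. A minimal sketch; the file name below is a
# placeholder:
import torch


def load_checkpoint_sketch(path):
    """ Return the fields written by the save loop above. """
    ckpt = torch.load(path, map_location="cpu")
    return (ckpt["epoch"], ckpt["model_state_dict"],
            ckpt["optimizer_state_dict"],
            ckpt.get("scheduler_state_dict"))

# epoch, model_sd, optim_sd, sched_sd = load_checkpoint_sketch(
#     "checkpoint-epoch000-end.tar")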