Code example #1
    def __init__(self,
                 model_cfg,
                 ckpt_path,
                 device,
                 line_detector_cfg,
                 line_matcher_cfg,
                 multiscale=False,
                 scales=[1., 2.]):
        # Get loss weights if dynamic weighting
        _, loss_weights = get_loss_and_weights(model_cfg, device)
        self.device = device

        # Initialize the cnn backbone
        self.model = get_model(model_cfg, loss_weights)
        checkpoint = torch.load(ckpt_path, map_location=self.device)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.model = self.model.to(self.device)
        self.model = self.model.eval()

        self.grid_size = model_cfg["grid_size"]
        self.junc_detect_thresh = model_cfg["detection_thresh"]
        self.max_num_junctions = model_cfg.get("max_num_junctions", 300)

        # Initialize the line detector
        self.line_detector = LineSegmentDetectionModule(**line_detector_cfg)
        self.multiscale = multiscale
        self.scales = scales

        # Initialize the line matcher
        self.line_matcher = WunschLineMatcher(**line_matcher_cfg)

        # Print some debug messages
        for key, val in line_detector_cfg.items():
            print(f"[Debug] {key}: {val}")
Code example #2
    def __init__(self,
                 model_cfg,
                 ckpt_path,
                 device,
                 line_detector_cfg,
                 junc_detect_thresh=None):
        """ SOLD² line detector taking raw images as input.
        Parameters:
            model_cfg: config for CNN model
            ckpt_path: path to the weights
            line_detector_cfg: config file for the line detection module
        """
        # Get loss weights if dynamic weighting
        _, loss_weights = get_loss_and_weights(model_cfg, device)
        self.device = device

        # Initialize the cnn backbone
        self.model = get_model(model_cfg, loss_weights)
        checkpoint = torch.load(ckpt_path, map_location=self.device)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.model = self.model.to(self.device)
        self.model = self.model.eval()

        self.grid_size = model_cfg["grid_size"]

        if junc_detect_thresh is not None:
            self.junc_detect_thresh = junc_detect_thresh
        else:
            self.junc_detect_thresh = model_cfg["detection_thresh"]
            self.junc_detect_thresh = 1 / 65

        # Initialize the line detector
        self.line_detector_cfg = line_detector_cfg
        self.line_detector = LineSegmentDetectionModule(**line_detector_cfg)
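Again only __init__ is shown. A short usage sketch, assuming the enclosing class is named LineDetector (not visible in the snippet); passing junc_detect_thresh explicitly overrides the model config, otherwise the config value (with a 1/65 fallback) is used.

import torch

model_cfg = {"grid_size": 8, "detection_thresh": 1 / 65}  # placeholder values
line_detector_cfg = {}  # kwargs for LineSegmentDetectionModule (project-specific)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
detector = LineDetector(model_cfg, "path/to/checkpoint.tar", device,
                        line_detector_cfg, junc_detect_thresh=0.02)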
Code example #3
def train(base_path, config):

    # We do not retrieve the color from the config (it should not be specified anyway)
    # because we render our own segmentation images using white color

    image_extension = config.IMAGE_EXTENSION
    object_model_path = config.OBJECT_MODEL_PATH
    ground_truth_path = config.GT_PATH
    data_path = config.DATA_PATH
    weights_path = config.WEIGHTS_PATH
    output_path = os.path.join(base_path, config.OUTPUT_PATH)

    assert os.path.exists(
        object_model_path), "The object model file {} does not exist.".format(
            object_model_path)
    assert os.path.exists(
        ground_truth_path), "The ground-truth file {} does not exist.".format(
            ground_truth_path)

    if weights_path != "":
        assert os.path.exists(
            weights_path), "The weights file {} does not exist.".format(
                weights_path)

    # The paths where we store the generated object coordinate images as well as
    # the cropped original and segmentation images
    images_path = os.path.join(data_path, "images")
    segmentations_path = os.path.join(data_path, "segmentations")
    obj_coords_path = os.path.join(data_path, "obj_coords")

    # Retrieve the rendered images to add them to the datasets
    images = util.get_files_at_path_of_extensions(images_path,
                                                  [image_extension])
    util.sort_list_by_num_in_string_entries(images)
    segmentation_renderings = util.get_files_at_path_of_extensions(
        segmentations_path, ['png'])
    util.sort_list_by_num_in_string_entries(segmentation_renderings)
    obj_coordinate_renderings = util.get_files_at_path_of_extensions(
        obj_coords_path, ['tiff'])
    util.sort_list_by_num_in_string_entries(obj_coordinate_renderings)

    assert len(images) == len(segmentation_renderings) == len(obj_coordinate_renderings), "Number of files in input " \
                                                                                          "folders does not match."

    print("Populating datasets.")

    train_dataset = dataset.Dataset()
    val_dataset = dataset.Dataset()

    # Open the json files that hold the filenames for the respective datasets
    with open(os.path.join(base_path, config.TRAIN_FILE), 'r') as train_file, \
         open(os.path.join(base_path, config.VAL_FILE), 'r') as val_file:
        train_filenames = json.load(train_file)
        val_filenames = json.load(val_file)

        # Fill training dict
        for i, image in enumerate(images):
            segmentation_image = segmentation_renderings[i]
            loaded_segmentation_image = cv2.imread(
                os.path.join(segmentations_path, segmentation_image))
            # We do not want to scale object coordinates in the network because that creates
            # imprecisions, i.e. we can only pad object coordinates to fill the image size
            # but not resize them. As a result, the object coordinates would not fit in the
            # batch arrays when width or height exceed the dimensions specified in the config.
            if loaded_segmentation_image.shape[0] > config.IMAGE_DIM or \
               loaded_segmentation_image.shape[1] > config.IMAGE_DIM:
                raise Exception(
                    "Image dimension exceeds image dim {} specified in config. "
                    "File: {} with size {}".format(
                        config.IMAGE_DIM, image, loaded_segmentation_image.shape))

            # TODO: add check if segmentation image contains color

            object_coordinate_image = obj_coordinate_renderings[i]

            image_path = os.path.join(images_path, image)
            segmentation_path = os.path.join(segmentations_path,
                                             segmentation_image)
            obj_coord_path = os.path.join(obj_coords_path,
                                          object_coordinate_image)

            # Check both cases, it might be that the image is not to be added at all
            if image in train_filenames:
                train_dataset.add_training_example(image_path,
                                                   segmentation_path,
                                                   obj_coord_path)
            elif image in val_filenames:
                val_dataset.add_training_example(image_path, segmentation_path,
                                                 obj_coord_path)

    print("Added {} images for training and {} images for validation.". \
                        format(train_dataset.size(), val_dataset.size()))

    # Here we import the requested model
    model = model_util.get_model(config.MODEL)
    network_model = model.FlowerPowerCNN('training', config, output_path)
    if weights_path != "":
        network_model.load_weights(
            weights_path,
            by_name=True,
            exclude=config.LAYERS_TO_EXCLUDE_FROM_WEIGHT_LOADING)

    print("Starting training.")
    network_model.train(train_dataset, val_dataset, config)
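train() takes all of its settings from a config object whose attributes are accessed directly. Below is a hypothetical minimal config for illustration; only the attribute names are taken from the function above, all values are placeholders, and FlowerPowerCNN may read further attributes during training.

class TrainConfig:
    # Hypothetical example values; only the attribute names come from train() above.
    IMAGE_EXTENSION = "jpg"
    OBJECT_MODEL_PATH = "data/model.ply"
    GT_PATH = "data/gt.json"
    DATA_PATH = "data"
    WEIGHTS_PATH = ""                 # empty string: train from scratch
    OUTPUT_PATH = "output"
    TRAIN_FILE = "train_images.json"
    VAL_FILE = "val_images.json"
    IMAGE_DIM = 224
    MODEL = "flowerpower"             # placeholder model name for model_util.get_model
    LAYERS_TO_EXCLUDE_FROM_WEIGHT_LOADING = []

# train(base_path="/path/to/project", config=TrainConfig())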
Code example #4
def inference(base_path, config):

    images_path = config.IMAGES_PATH
    image_extension = config.IMAGE_EXTENSION 
    segmentation_images_path = config.SEGMENTATION_IMAGES_PATH
    segmentation_image_extension = config.SEGMENTATION_IMAGE_EXTENSION
    segmentation_color = config.SEGMENTATION_COLOR
    object_model_path = config.OBJECT_MODEL_PATH 
    weights_path = config.WEIGHTS_PATH 
    batch_size = config.BATCH_SIZE
    image_list = os.path.join(base_path, config.IMAGE_LIST)
    output_path = os.path.join(base_path, config.OUTPUT_PATH)

    if not os.path.exists(output_path):
        os.makedirs(output_path)

    assert os.path.exists(images_path), \
            "The images path {} does not exist.".format(images_path)
    assert image_extension in ['png', 'jpg', 'jpeg'], \
            "Unkown image extension."
    assert os.path.exists(segmentation_images_path), \
            "The segmentation images path {} does not exist.".format(segmentation_images_path)
    assert segmentation_image_extension in ['png', 'jpg', 'jpeg'], \
            "Unkown segmentation image extension."
    assert os.path.exists(object_model_path), \
            "The object model file {} does not exist.".format(object_model_path)
    assert os.path.exists(weights_path), \
            "The weights file {} does not exist.".format(weights_path)

    image_paths = util.get_files_at_path_of_extensions(images_path, image_extension)
    util.sort_list_by_num_in_string_entries(image_paths)
    segmentation_image_paths = util.get_files_at_path_of_extensions(segmentation_images_path, segmentation_image_extension)
    util.sort_list_by_num_in_string_entries(segmentation_image_paths)

    images = []
    segmentation_images = []
    cropped_segmentation_images = []
    # Bounding boxes
    bbs = []

    print("Preparing data.")

    temp_image_paths = []
    temp_segmentation_image_paths = []

    with open(image_list, "r") as loaded_image_list:
        images_to_process = json.load(loaded_image_list)
        for index in range(len(image_paths)):
            if image_paths[index] in images_to_process:
                temp_image_paths.append(image_paths[index])
                temp_segmentation_image_paths.append(segmentation_image_paths[index])

    image_paths = temp_image_paths
    segmentation_image_paths = temp_segmentation_image_paths

    # Prepare data, i.e. crop images to the segmentation mask
    for index in range(len(image_paths)):
        image_path = image_paths[index]
        image = cv2.imread(os.path.join(images_path, image_path))
        segmentation_image_path = segmentation_image_paths[index]
        segmentation_image = cv2.imread(os.path.join(segmentation_images_path, segmentation_image_path))

        # TODO: add check if segmentation image contains color

        image, frame = util.crop_image_on_segmentation_color(
                        image, segmentation_image, segmentation_color, return_frame=True)
        bbs.append(frame)
        cropped_segmentation_image = util.crop_image_on_segmentation_color(
                                segmentation_image, segmentation_image, segmentation_color)
        images.append(image)
        segmentation_images.append(segmentation_image)
        cropped_segmentation_images.append(cropped_segmentation_image)

    # Otherwise datatype is int64 which is not JSON serializable
    bbs = np.array(bbs).astype(np.int32)

    print("Running network inference.")
    # Here we import the requested model
    model = model_util.get_model(config.MODEL)
    network_model = model.FlowerPowerCNN('inference', config, output_path)
    network_model.load_weights(weights_path, by_name=True)

    results = []
    # We only store the filename + extension
    object_model_path = os.path.basename(object_model_path)

    current_batch = 0
    batch_size = min(batch_size, len(images))
    # Set batch size for the network to the current batch size
    network_model.config.BATCH_SIZE = batch_size

    while current_batch < len(images):
        # Account for the fact that the number of images might not be divisible by the batch size
        batch_size = min(batch_size, len(images) - current_batch)
        batch_start = current_batch
        current_batch += batch_size
        batch_end = current_batch
        network_model.config.BATCH_SIZE = batch_size
        predictions = network_model.predict(images[batch_start:batch_end], cropped_segmentation_images[batch_start:batch_end], verbose=1)
        for index in range(len(predictions)):
            results.append({    "prediction" : predictions[index], 
                                "image" : image_paths[batch_start + index], 
                                "segmentation_image" : segmentation_image_paths[batch_start + index],
                                "object_model" : object_model_path,
                                "bb" : bbs[batch_start + index]
                            })

    return results
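inference() returns a list of dicts whose values are partly NumPy data (the bounding boxes certainly are), so a small conversion step is needed before dumping them to JSON. A sketch, assuming the predictions are NumPy arrays as well:

import json
import numpy as np

def save_results(results, path):
    # Serialize inference results to JSON, converting NumPy arrays to lists.
    serializable = []
    for entry in results:
        serializable.append({
            key: value.tolist() if isinstance(value, np.ndarray) else value
            for key, value in entry.items()
        })
    with open(path, "w") as f:
        json.dump(serializable, f, indent=2)

# results = inference(base_path, config)
# save_results(results, "results.json")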
Code example #5
File: export.py, Project: wx-b/SOLD2
def export_predictions(args, dataset_cfg, model_cfg, output_path,
                       export_dataset_mode):
    """ Export predictions. """
    # Get the test configuration
    test_cfg = model_cfg["test"]

    # Create the dataset and dataloader based on the export_dataset_mode
    print("\t Initializing dataset and dataloader")
    batch_size = 4
    export_dataset, collate_fn = get_dataset(export_dataset_mode, dataset_cfg)
    export_loader = DataLoader(export_dataset, batch_size=batch_size,
                               num_workers=test_cfg.get("num_workers", 4),
                               shuffle=False, pin_memory=False,
                               collate_fn=collate_fn)
    print("\t Successfully intialized dataset and dataloader.")

    # Initialize model and load the checkpoint
    model = get_model(model_cfg, mode="test")
    checkpoint = get_latest_checkpoint(args.resume_path, args.checkpoint_name)
    model = restore_weights(model, checkpoint["model_state_dict"])
    model = model.cuda()
    model.eval()
    print("\t Successfully initialized model")

    # Start the export process
    print("[Info] Start exporting predictions")
    output_dataset_path = output_path + ".h5"
    filename_idx = 0
    with h5py.File(output_dataset_path, "w", libver="latest", swmr=True) as f:
        # Iterate through all the data in dataloader
        for data in tqdm(export_loader, ascii=True):
            # Fetch the data
            junc_map = data["junction_map"]
            heatmap = data["heatmap"]
            valid_mask = data["valid_mask"]
            input_images = data["image"].cuda()

            # Run the forward pass
            with torch.no_grad():
                outputs = model(input_images)

            # Convert predictions
            junc_np = convert_junc_predictions(
                outputs["junctions"], model_cfg["grid_size"],
                model_cfg["detection_thresh"], 300)
            junc_map_np = junc_map.numpy().transpose(0, 2, 3, 1)
            heatmap_np = softmax(outputs["heatmap"].detach(),
                                 dim=1).cpu().numpy().transpose(0, 2, 3, 1)
            heatmap_gt_np = heatmap.numpy().transpose(0, 2, 3, 1)
            valid_mask_np = valid_mask.numpy().transpose(0, 2, 3, 1)

            # Data entries to save
            current_batch_size = input_images.shape[0]
            for batch_idx in range(current_batch_size):
                output_data = {
                    "image": input_images.cpu().numpy().transpose(0, 2, 3, 1)[batch_idx],
                    "junc_gt": junc_map_np[batch_idx],
                    "junc_pred": junc_np["junc_pred"][batch_idx],
                    "junc_pred_nms": junc_np["junc_pred_nms"][batch_idx].astype(np.float32),
                    "heatmap_gt": heatmap_gt_np[batch_idx],
                    "heatmap_pred": heatmap_np[batch_idx],
                    "valid_mask": valid_mask_np[batch_idx],
                    "junc_points": data["junctions"][batch_idx].numpy()[0].round().astype(np.int32),
                    "line_map": data["line_map"][batch_idx].numpy()[0].astype(np.int32)
                }

                # Save data to h5 dataset
                num_pad = math.ceil(math.log10(len(export_loader))) + 1
                output_key = get_padded_filename(num_pad, filename_idx)
                f_group = f.create_group(output_key)

                # Store data
                for key, value in output_data.items():
                    f_group.create_dataset(key, data=value,
                                           compression="gzip")
                filename_idx += 1
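The exported .h5 file can be read back group by group; the per-entry dataset names are exactly the keys written into output_data above. A small consumer sketch (the file name is a placeholder):

import h5py

with h5py.File("export_output.h5", "r") as f:  # placeholder file name
    for group_key in sorted(f.keys()):
        group = f[group_key]
        junc_pred = group["junc_pred"][()]        # predicted junction probability map
        heatmap_pred = group["heatmap_pred"][()]  # predicted line heatmap (after softmax)
        junc_points = group["junc_points"][()]    # ground-truth junction coordinates
        print(group_key, junc_pred.shape, heatmap_pred.shape, len(junc_points))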
Code example #6
File: export.py, Project: wx-b/SOLD2
def export_homograpy_adaptation(args, dataset_cfg, model_cfg, output_path,
                                export_dataset_mode, device):
    """ Export homography adaptation results. """
    # Check if the export_dataset_mode is supported
    supported_modes = ["train", "test"]
    if export_dataset_mode not in supported_modes:
        raise ValueError(
            "[Error] The specified export_dataset_mode is not supported.")

    # Get the test configuration
    test_cfg = model_cfg["test"]

    # Get the homography adaptation configurations
    homography_cfg = dataset_cfg.get("homography_adaptation", None)
    if homography_cfg is None:
        raise ValueError(
            "[Error] Empty homography_adaptation entry in config.")

    # Create the dataset and dataloader based on the export_dataset_mode
    print("\t Initializing dataset and dataloader")
    batch_size = args.export_batch_size

    export_dataset, collate_fn = get_dataset(export_dataset_mode, dataset_cfg)
    export_loader = DataLoader(export_dataset, batch_size=batch_size,
                               num_workers=test_cfg.get("num_workers", 4),
                               shuffle=False, pin_memory=False,
                               collate_fn=collate_fn)
    print("\t Successfully intialized dataset and dataloader.")

    # Initialize model and load the checkpoint
    model = get_model(model_cfg, mode="test")
    checkpoint = get_latest_checkpoint(args.resume_path, args.checkpoint_name,
                                       device)
    model = restore_weights(model, checkpoint["model_state_dict"])
    model = model.to(device).eval()
    print("\t Successfully initialized model")

    # Start the export process
    print("[Info] Start exporting predictions")    
    output_dataset_path = output_path + ".h5"
    with h5py.File(output_dataset_path, "w", libver="latest") as f:
        f.swmr_mode = True
        for _, data in enumerate(tqdm(export_loader, ascii=True)):
            input_images = data["image"].to(device)
            file_keys = data["file_key"]
            batch_size = input_images.shape[0]
            
            # Run the homography adaptation
            outputs = homography_adaptation(input_images, model,
                                            model_cfg["grid_size"],
                                            homography_cfg)

            # Save the entries
            for batch_idx in range(batch_size):
                # Get the save key
                save_key = file_keys[batch_idx]
                output_data = {
                    "image": input_images.cpu().numpy().transpose(0, 2, 3, 1)[batch_idx],
                    "junc_prob_mean": outputs["junc_probs_mean"].cpu().numpy().transpose(0, 2, 3, 1)[batch_idx],
                    "junc_prob_max": outputs["junc_probs_max"].cpu().numpy().transpose(0, 2, 3, 1)[batch_idx],
                    "junc_count": outputs["junc_counts"].cpu().numpy().transpose(0, 2, 3, 1)[batch_idx],
                    "heatmap_prob_mean": outputs["heatmap_probs_mean"].cpu().numpy().transpose(0, 2, 3, 1)[batch_idx],
                    "heatmap_prob_max": outputs["heatmap_probs_max"].cpu().numpy().transpose(0, 2, 3, 1)[batch_idx],
                    "heatmap_cout": outputs["heatmap_counts"].cpu().numpy().transpose(0, 2, 3, 1)[batch_idx]
                }

                # Create group and write data
                f_group = f.create_group(save_key)
                for key, value in output_data.items():
                    f_group.create_dataset(key, data=value,
                                           compression="gzip")
Code example #7
File: train.py, Project: wx-b/SOLD2
def train_net(args, dataset_cfg, model_cfg, output_path):
    """ Main training function. """
    # Add some version compatibility check
    if model_cfg.get("weighting_policy") is None:
        # Default to static
        model_cfg["weighting_policy"] = "static"

    # Get the train, val, test config
    train_cfg = model_cfg["train"]
    test_cfg = model_cfg["test"]

    # Create train and test dataset
    print("\t Initializing dataset...")
    train_dataset, train_collate_fn = get_dataset("train", dataset_cfg)
    test_dataset, test_collate_fn = get_dataset("test", dataset_cfg)

    # Create the dataloader
    train_loader = DataLoader(train_dataset,
                              batch_size=train_cfg["batch_size"],
                              num_workers=8,
                              shuffle=True,
                              pin_memory=True,
                              collate_fn=train_collate_fn)
    test_loader = DataLoader(test_dataset,
                             batch_size=test_cfg.get("batch_size", 1),
                             num_workers=test_cfg.get("num_workers", 1),
                             shuffle=False,
                             pin_memory=False,
                             collate_fn=test_collate_fn)
    print("\t Successfully intialized dataloaders.")

    # Get the loss function and weight first
    loss_funcs, loss_weights = get_loss_and_weights(model_cfg)

    # If resuming from a checkpoint.
    if args.resume:
        # Create model and load the state dict
        checkpoint = get_latest_checkpoint(args.resume_path,
                                           args.checkpoint_name)
        model = get_model(model_cfg, loss_weights)
        model = restore_weights(model, checkpoint["model_state_dict"])
        model = model.cuda()
        optimizer = torch.optim.Adam([{
            "params": model.parameters(),
            "initial_lr": model_cfg["learning_rate"]
        }],
                                     model_cfg["learning_rate"],
                                     amsgrad=True)
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        # Optionally get the learning rate scheduler
        scheduler = get_lr_scheduler(lr_decay=model_cfg.get("lr_decay", False),
                                     lr_decay_cfg=model_cfg.get(
                                         "lr_decay_cfg", None),
                                     optimizer=optimizer)
        # If we start to use learning rate scheduler from the middle
        if ((scheduler is not None) and
            (checkpoint.get("scheduler_state_dict", None) is not None)):
            scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        start_epoch = checkpoint["epoch"] + 1
    # Initialize all the components.
    else:
        # Create model and optimizer
        model = get_model(model_cfg, loss_weights)
        # Optionally get the pretrained weights
        if args.pretrained:
            print("\t [Debug] Loading pretrained weights...")
            checkpoint = get_latest_checkpoint(args.pretrained_path,
                                               args.checkpoint_name)
            # If auto weighting, restore from non-auto weighting
            model = restore_weights(model,
                                    checkpoint["model_state_dict"],
                                    strict=False)
            print("\t [Debug] Finished loading pretrained weights!")

        model = model.cuda()
        optimizer = torch.optim.Adam([{
            "params": model.parameters(),
            "initial_lr": model_cfg["learning_rate"]
        }],
                                     model_cfg["learning_rate"],
                                     amsgrad=True)
        # Optionally get the learning rate scheduler
        scheduler = get_lr_scheduler(lr_decay=model_cfg.get("lr_decay", False),
                                     lr_decay_cfg=model_cfg.get(
                                         "lr_decay_cfg", None),
                                     optimizer=optimizer)
        start_epoch = 0

    print("\t Successfully initialized model")

    # Define the total loss
    policy = model_cfg.get("weighting_policy", "static")
    loss_func = TotalLoss(loss_funcs, loss_weights, policy).cuda()
    if "descriptor_decoder" in model_cfg:
        metric_func = Metrics(model_cfg["detection_thresh"],
                              model_cfg["prob_thresh"],
                              model_cfg["descriptor_loss_cfg"]["grid_size"],
                              desc_metric_lst='all')
    else:
        metric_func = Metrics(model_cfg["detection_thresh"],
                              model_cfg["prob_thresh"], model_cfg["grid_size"])

    # Define the summary writer
    logdir = os.path.join(output_path, "log")
    writer = SummaryWriter(logdir=logdir)

    # Start the training loop
    for epoch in range(start_epoch, model_cfg["epochs"]):
        # Record the learning rate
        current_lr = optimizer.state_dict()["param_groups"][0]["lr"]
        writer.add_scalar("LR/lr", current_lr, epoch)

        # Train for one epoch
        print("\n\n================== Training ====================")
        train_single_epoch(model=model,
                           model_cfg=model_cfg,
                           optimizer=optimizer,
                           loss_func=loss_func,
                           metric_func=metric_func,
                           train_loader=train_loader,
                           writer=writer,
                           epoch=epoch)

        # Do the validation
        print("\n\n================== Validation ==================")
        validate(model=model,
                 model_cfg=model_cfg,
                 loss_func=loss_func,
                 metric_func=metric_func,
                 val_loader=test_loader,
                 writer=writer,
                 epoch=epoch)

        # Update the scheduler
        if scheduler is not None:
            scheduler.step()

        # Save checkpoints
        file_name = os.path.join(output_path,
                                 "checkpoint-epoch%03d-end.tar" % (epoch))
        print("[Info] Saving checkpoint %s ..." % file_name)
        save_dict = {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "model_cfg": model_cfg
        }
        if scheduler is not None:
            save_dict.update({"scheduler_state_dict": scheduler.state_dict()})
        torch.save(save_dict, file_name)

        # Remove the outdated checkpoints
        remove_old_checkpoints(output_path, model_cfg.get("max_ckpt", 15))