Example #1
def main():
    #####################################################################################################
    # Options
    #####################################################################################################

    # Parse command line arguments.
    process_arguments()

    split = Parameters.generate_split.value.value  # the second .value yields the lowercase string representation

    # Model checkpoint to use
    saved_model = get_saved_model()
    # Dataset dir
    dataset_base_dir = Parameters.path.dataset_base_directory.value

    # Image dimensions to which we crop the input images, such that they are divisible by 64
    image_height = Parameters.alignment.image_height.value
    image_width = Parameters.alignment.image_width.value

    if Parameters.deform_net.gn_max_matches_eval.value != 100000:
        raise ValueError(
            f"Due to a legacy-code constraint, deform_net.gn_max_matches_eval must be exactly "
            f"100000, but it is currently set to {Parameters.deform_net.gn_max_matches_eval.value}."
        )

    if Parameters.deform_net.threshold_mask_predictions.value:
        raise ValueError(
            "Due to a legacy-code constraint, deform_net.threshold_mask_predictions "
            "must be set to False for generate.py."
        )

    #####################################################################################################
    # Read labels and assert existence of output dir
    #####################################################################################################

    labels_json = os.path.join(dataset_base_dir, f"{split}_graphs.json")

    assert os.path.isfile(
        labels_json
    ), f"{labels_json} does not exist! Make sure you specified the correct 'dataset_base_directory'."

    with open(labels_json, 'r') as f:
        labels = json.load(f)

    # Output dir
    output_dir = os.path.join(Parameters.path.nn_data_directory.value,
                              "models", Parameters.model.model_name.value)
    output_dir = f"{output_dir}/evaluation/{split}"
    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)
        print("Created output dir", output_dir)
        print()

    #####################################################################################################
    # Model
    #####################################################################################################

    assert os.path.isfile(saved_model), f"Model {saved_model} does not exist."

    # Construct the alignment network
    model = load_default_nnrt_network(o3c.Device.DeviceType.CUDA)

    #####################################################################################################
    # Go over dataset
    #####################################################################################################

    for label in tqdm(labels):
        src_color_image_path = os.path.join(dataset_base_dir,
                                            label["source_color"])
        src_depth_image_path = os.path.join(dataset_base_dir,
                                            label["source_depth"])
        tgt_color_image_path = os.path.join(dataset_base_dir,
                                            label["target_color"])
        tgt_depth_image_path = os.path.join(dataset_base_dir,
                                            label["target_depth"])
        graph_nodes_path = os.path.join(dataset_base_dir, label["graph_nodes"])
        graph_edges_path = os.path.join(dataset_base_dir, label["graph_edges"])
        graph_edges_weights_path = os.path.join(dataset_base_dir,
                                                label["graph_edges_weights"])
        graph_clusters_path = os.path.join(dataset_base_dir,
                                           label["graph_clusters"])
        pixel_anchors_path = os.path.join(dataset_base_dir,
                                          label["pixel_anchors"])
        pixel_weights_path = os.path.join(dataset_base_dir,
                                          label["pixel_weights"])

        intrinsics = label["intrinsics"]

        print(src_color_image_path)

        # Source color and depth
        source, _, cropper = DeformDataset.load_image(src_color_image_path,
                                                      src_depth_image_path,
                                                      intrinsics, image_height,
                                                      image_width)

        source_points = np.copy(source[3:, :, :])  # 3, h, w

        # Target color and depth (and boundary mask)
        target, _, _ = DeformDataset.load_image(tgt_color_image_path,
                                                tgt_depth_image_path,
                                                intrinsics,
                                                image_height,
                                                image_width,
                                                cropper=cropper,
                                                max_boundary_distance=None,
                                                compute_boundary_mask=False)

        # Graph
        graph_nodes, graph_edges, graph_edges_weights, _, graph_clusters = \
            DeformDataset.load_graph_data(
                graph_nodes_path, graph_edges_path, graph_edges_weights_path, None, graph_clusters_path
            )
        pixel_anchors, pixel_weights = DeformDataset.load_anchors_and_weights(
            pixel_anchors_path, pixel_weights_path, cropper)

        num_nodes = np.array(graph_nodes.shape[0], dtype=np.int64)

        # Update intrinsics to reflect the crops
        fx, fy, cx, cy = image_processing.modify_intrinsics_due_to_cropping(
            intrinsics['fx'],
            intrinsics['fy'],
            intrinsics['cx'],
            intrinsics['cy'],
            image_height,
            image_width,
            original_h=cropper.h,
            original_w=cropper.w)

        intrinsics = np.zeros(4, dtype=np.float32)
        intrinsics[0] = fx
        intrinsics[1] = fy
        intrinsics[2] = cx
        intrinsics[3] = cy

        #####################################################################################################
        # Predict deformation
        #####################################################################################################

        # Move to device and unsqueeze in the batch dimension (to have batch size 1)
        source_cuda = torch.from_numpy(source).cuda().unsqueeze(0)
        target_cuda = torch.from_numpy(target).cuda().unsqueeze(0)
        graph_nodes_cuda = torch.from_numpy(graph_nodes).cuda().unsqueeze(0)
        graph_edges_cuda = torch.from_numpy(graph_edges).cuda().unsqueeze(0)
        graph_edges_weights_cuda = torch.from_numpy(graph_edges_weights).cuda().unsqueeze(0)
        graph_clusters_cuda = torch.from_numpy(graph_clusters).cuda().unsqueeze(0)
        pixel_anchors_cuda = torch.from_numpy(pixel_anchors).cuda().unsqueeze(0)
        pixel_weights_cuda = torch.from_numpy(pixel_weights).cuda().unsqueeze(0)
        intrinsics_cuda = torch.from_numpy(intrinsics).cuda().unsqueeze(0)

        num_nodes_cuda = torch.from_numpy(num_nodes).cuda().unsqueeze(0)

        with torch.no_grad():
            model_data = model(source_cuda,
                               target_cuda,
                               graph_nodes_cuda,
                               graph_edges_cuda,
                               graph_edges_weights_cuda,
                               graph_clusters_cuda,
                               pixel_anchors_cuda,
                               pixel_weights_cuda,
                               num_nodes_cuda,
                               intrinsics_cuda,
                               evaluate=True,
                               split="test")

        # Get predicted graph deformation
        node_rotations_pred = model_data["node_rotations"].view(
            num_nodes, 3, 3).cpu().numpy()
        node_translations_pred = model_data["node_translations"].view(
            num_nodes, 3).cpu().numpy()

        # Warp source points with predicted graph deformation
        warped_source_points = image_processing.warp_deform_3d(
            source, pixel_anchors, pixel_weights, graph_nodes,
            node_rotations_pred, node_translations_pred)

        # Compute dense 3d flow
        scene_flow_pred = warped_source_points - source_points

        # Save predictions
        seq_id = label["seq_id"]
        object_id = label["object_id"]
        source_id = label["source_id"]
        target_id = label["target_id"]

        sample_name = f"{seq_id}_{object_id}_{source_id}_{target_id}"

        node_translations_pred_file = os.path.join(
            output_dir, f"{sample_name}_node_translations.bin")
        scene_flow_pred_file = os.path.join(output_dir,
                                            f"{sample_name}_sceneflow.sflow")

        io.save_graph_node_deformations(node_translations_pred_file,
                                        node_translations_pred)

        io.save_flow(scene_flow_pred_file, scene_flow_pred)
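
# The dense warp used above (image_processing.warp_deform_3d) follows the standard
# embedded-deformation model: each source point is moved by a weighted blend of the
# rigid transforms of its anchor nodes. Below is a minimal NumPy sketch of that
# blending, written for illustration only; it assumes (h, w, k)-shaped anchor and
# weight arrays and may differ from the library's actual implementation.
import numpy as np

def warp_deform_3d_sketch(source, pixel_anchors, pixel_weights, graph_nodes,
                          node_rotations, node_translations):
    # source: (6, h, w) RGB-XYZ image; channels 3: hold the 3D points.
    # pixel_anchors: (h, w, k) node indices per pixel (-1 marks an unused anchor).
    # pixel_weights: (h, w, k) blending weights per pixel.
    points = np.moveaxis(source[3:, :, :], 0, -1)      # (h, w, 3)
    warped = np.zeros_like(points)
    for j in range(pixel_anchors.shape[2]):
        anchors = pixel_anchors[:, :, j]               # (h, w) node indices
        weights = pixel_weights[:, :, j]               # (h, w)
        valid = anchors >= 0
        a = np.where(valid, anchors, 0)
        g = graph_nodes[a]                             # (h, w, 3) anchor node positions
        R = node_rotations[a]                          # (h, w, 3, 3)
        t = node_translations[a]                       # (h, w, 3)
        # Rigid transform about each anchor node: R (p - g) + g + t
        transformed = np.einsum('hwij,hwj->hwi', R, points - g) + g + t
        warped += weights[..., None] * np.where(valid[..., None], transformed, 0.0)
    return np.moveaxis(warped, -1, 0)                  # (3, h, w)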
Example #2
avg_distance = None

GOAL_INCHES = 18

# a = 122520
# b = -3203.4
# c = 23.766

x = -620090
y = -98.894
c = 620650
d = 603600
g = -719350
h = -0.22849
""" PROCESS COMMAND LINE FLAGS """
settings.process_arguments(sys.argv)
settings.print_settings()
""" INITIALIZE MODULES """
# serialoutput.init_serial(settings.port, settings.baudrate)
videoinput.open_stream(settings.device)

if settings.save_video:
    t0 = Thread(target=videooutput.start_recording, args=(settings.codec, ))
    t0.daemon = True
    t0.start()
    # videooutput.start_recording(settings.codec)

t1 = Thread(target=networking.try_connection_streaming)
t2 = Thread(target=networking.try_connection_start_stop_writing)

t1.daemon = True  # thread automatically closes when main thread closes
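
# The fragment above stops before the remaining threads are launched. A minimal
# sketch of the usual continuation of this daemon-thread pattern (an illustrative
# assumption, not the original code): mark t2 as a daemon as well, start both
# threads, and let the main thread run the frame-processing loop.
t2.daemon = True
t1.start()
t2.start()
# Daemon threads are terminated automatically once the main thread exits.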
Example #3
def main():
    #####################################################################################################
    # Options
    #####################################################################################################

    # Parse command line arguments.
    process_arguments()

    split = Parameters.evaluate_split.value.value

    dataset_base_dir = Parameters.path.dataset_base_directory.value
    experiments_dir = Parameters.path.nn_data_directory.value
    model_name = Parameters.model.model_name.value
    gn_max_depth = Parameters.deform_net.gn_max_depth.value

    # Image dimensions to which we crop the input images, such that they are divisible by 64
    alignment_image_height = Parameters.alignment.image_height.value
    alignment_image_width = Parameters.alignment.image_width.value

    #####################################################################################################
    # Read labels and check existence of output dir
    #####################################################################################################

    labels_json = os.path.join(dataset_base_dir, f"{split}_graphs.json")

    assert os.path.isfile(
        labels_json
    ), f"{labels_json} does not exist! Make sure you specified the correct 'dataset_base_directory' in the configuration."

    with open(labels_json, 'r') as f:
        labels = json.load(f)

    # Output dir
    model_base_dir = os.path.join(experiments_dir, "models", model_name)
    predictions_dir = f"{model_base_dir}/evaluation/{split}"
    if not os.path.isdir(predictions_dir):
        raise Exception(
            f"Predictions directory {predictions_dir} does not exist. Please generate predictions with 'run_generate.sh' first."
        )

    #####################################################################################################
    # Go over dataset
    #####################################################################################################

    # Graph error (EPE on graph nodes)
    graph_error_3d_sum = 0.0
    total_num_nodes = 0

    # Dense EPE 3D
    epe3d_sum = 0.0
    total_num_points = 0

    for label in tqdm(labels):
        assert "graph_node_deformations" in label, "It's highly probable that you're running this script with 'split' set to 'test'. " \
                                                   "but the public dataset does not provide gt for the test set. Plase choose 'val' if " \
                                                   "you want to compute metrics."

        ##############################################################################################
        # Load gt
        ##############################################################################################
        src_color_image_path = os.path.join(dataset_base_dir,
                                            label["source_color"])
        src_depth_image_path = os.path.join(dataset_base_dir,
                                            label["source_depth"])
        graph_nodes_path = os.path.join(dataset_base_dir, label["graph_nodes"])
        graph_edges_path = os.path.join(dataset_base_dir, label["graph_edges"])
        graph_edges_weights_path = os.path.join(dataset_base_dir,
                                                label["graph_edges_weights"])
        graph_node_deformations_path = os.path.join(
            dataset_base_dir, label["graph_node_deformations"])
        graph_clusters_path = os.path.join(dataset_base_dir,
                                           label["graph_clusters"])
        pixel_anchors_path = os.path.join(dataset_base_dir,
                                          label["pixel_anchors"])
        pixel_weights_path = os.path.join(dataset_base_dir,
                                          label["pixel_weights"])
        optical_flow_image_path = os.path.join(dataset_base_dir,
                                               label["optical_flow"])
        scene_flow_image_path = os.path.join(dataset_base_dir,
                                             label["scene_flow"])

        intrinsics = label["intrinsics"]

        # Source color and depth
        source, _, cropper = DeformDataset.load_image(src_color_image_path,
                                                      src_depth_image_path,
                                                      intrinsics,
                                                      alignment_image_height,
                                                      alignment_image_width)
        source_points = source[3:, :, :]

        # Graph
        graph_nodes, graph_edges, graph_edges_weights, graph_node_deformations, graph_clusters = \
            DeformDataset.load_graph_data(
                graph_nodes_path, graph_edges_path, graph_edges_weights_path, graph_node_deformations_path, graph_clusters_path
            )

        pixel_anchors, pixel_weights = DeformDataset.load_anchors_and_weights(
            pixel_anchors_path, pixel_weights_path, cropper)

        optical_flow_gt, optical_flow_mask, scene_flow_gt, scene_flow_mask = DeformDataset.load_flow(
            optical_flow_image_path, scene_flow_image_path, cropper)

        # mask is duplicated across feature dimension, so we can safely take the first channel
        scene_flow_mask = scene_flow_mask[0].astype(bool)
        optical_flow_mask = optical_flow_mask[0].astype(bool)

        # All points that have valid optical flow should also have valid scene flow
        assert np.array_equal(scene_flow_mask, optical_flow_mask)

        num_source_points = np.sum(scene_flow_mask)

        # if num_source_points > 100000:
        #     print(label["source_color"], num_source_points)

        ##############################################################################################
        # Load predictions
        ##############################################################################################
        seq_id = label["seq_id"]
        object_id = label["object_id"]
        source_id = label["source_id"]
        target_id = label["target_id"]

        sample_name = f"{seq_id}_{object_id}_{source_id}_{target_id}"

        node_translations_pred_file = os.path.join(
            predictions_dir, f"{sample_name}_node_translations.bin")
        scene_flow_pred_file = os.path.join(predictions_dir,
                                            f"{sample_name}_sceneflow.sflow")

        assert os.path.isfile(
            node_translations_pred_file
        ), f"{node_translations_pred_file} does not exist. Make sure you are not missing any prediction."
        assert os.path.isfile(
            scene_flow_pred_file
        ), f"{scene_flow_pred_file} does not exist. Make sure you are not missing any prediction."

        node_translations_pred = io.load_graph_node_deformations(
            node_translations_pred_file)

        scene_flow_pred = io.load_flow(scene_flow_pred_file)

        ##############################################################################################
        # Compute metrics
        ##############################################################################################

        ######################
        # Node translations (graph_node_deformations are the ground-truth graph node translations)
        ######################
        graph_error_3d_dict = EPE_3D_eval(graph_node_deformations,
                                          node_translations_pred)
        graph_error_3d_sum += graph_error_3d_dict["sum"]
        total_num_nodes += graph_error_3d_dict["num"]

        ######################
        # Scene flow
        ######################

        # First, get valid source points
        source_anchor_validity = np.all(pixel_anchors >= 0.0, axis=2)

        valid_source_points = np.logical_and.reduce([
            source_points[2, :, :] > 0.0,
            source_points[2, :, :] <= gn_max_depth, source_anchor_validity,
            scene_flow_mask, optical_flow_mask
        ])

        scene_flow_gt = np.moveaxis(scene_flow_gt, 0, -1)
        scene_flow_pred = np.moveaxis(scene_flow_pred, 0, -1)

        deformed_points_gt = scene_flow_gt[valid_source_points]
        deformed_points_pred = scene_flow_pred[valid_source_points]

        epe_3d_dict = EPE_3D_eval(deformed_points_gt, deformed_points_pred)
        epe3d_sum += epe_3d_dict["sum"]
        total_num_points += epe_3d_dict["num"]

    # Compute average errors
    graph_error_3d_avg = graph_error_3d_sum / total_num_nodes
    epe3d_avg = epe3d_sum / total_num_points

    print(f"Graph Error 3D (mm): {graph_error_3d_avg * 1000.0}")
    print(f"EPE 3D (mm):         {epe3d_avg * 1000.0}")

    # Write to file
    with open(f"{model_base_dir}/{model_name}__ON__{split}.txt", "w") as f:
        f.write("\n")
        f.write("Evaluation results:\n\n")
        f.write("\n")
        f.write("Model: {0}\n".format(model_name))
        f.write("Split: {0}\n".format(split))
        f.write("\n")
        f.write("{:<40} {}\n".format("Graph Error 3D (mm)",
                                     graph_error_3d_avg * 1000.0))
        f.write("{:<40} {}\n".format("EPE 3D (mm)", epe3d_avg * 1000.0))
Example #4
def main():
    process_arguments()
    #####################################################################################################
    # Options
    #####################################################################################################

    # Source-target example
    frame_pair_preset: FramePairPreset = FramePairPreset.RED_SHORTS_200_400
    frame_pair_name = frame_pair_preset.name.lower()
    frame_pair_dataset: FramePairDataset = frame_pair_preset.value
    frame_pair_dataset.load()

    save_node_transformations = False

    source_frame_index = frame_pair_dataset.source_frame_index
    target_frame_index = frame_pair_dataset.target_frame_index
    segment_name = frame_pair_dataset.segment_name

    #####################################################################################################
    # Load alignment
    #####################################################################################################

    saved_model = get_saved_model()

    assert os.path.isfile(saved_model), f"Model {saved_model} does not exist."
    pretrained_dict = torch.load(saved_model)

    # Construct the alignment network
    model = load_default_nnrt_network(o3c.Device.CUDA)

    #####################################################################################################
    # Load example dataset
    #####################################################################################################
    intrinsics = load_intrinsic_matrix_entries_as_dict_from_text_4x4_matrix(
        frame_pair_dataset.get_intrinsics_path())

    alignment_image_height = Parameters.alignment.image_height.value
    alignment_image_width = Parameters.alignment.image_width.value
    max_boundary_distance = Parameters.alignment.max_boundary_distance.value

    src_color_image_path = frame_pair_dataset.get_source_color_image_path()
    src_depth_image_path = frame_pair_dataset.get_source_depth_image_path()
    tgt_color_image_path = frame_pair_dataset.get_target_color_image_path()
    tgt_depth_image_path = frame_pair_dataset.get_target_depth_image_path()

    # Source color and depth
    source_rgbxyz, _, cropper = DeformDataset.load_image(
        src_color_image_path, src_depth_image_path, intrinsics,
        alignment_image_height, alignment_image_width)

    # Target color and depth (and boundary mask)
    target_rgbxyz, target_boundary_mask, _ = DeformDataset.load_image(
        tgt_color_image_path,
        tgt_depth_image_path,
        intrinsics,
        alignment_image_height,
        alignment_image_width,
        cropper=cropper,
        max_boundary_distance=max_boundary_distance,
        compute_boundary_mask=True)

    # Graph
    graph_nodes, graph_edges, graph_edges_weights, _, graph_clusters = \
        DeformDataset.load_graph_data(frame_pair_dataset.get_sequence_directory(),
                                      frame_pair_dataset.graph_filename, False)

    pixel_anchors, pixel_weights = DeformDataset.load_anchors_and_weights_from_sequence_directory_and_graph_filename(
        frame_pair_dataset.get_sequence_directory(),
        frame_pair_dataset.graph_filename, cropper)

    num_nodes = np.array(graph_nodes.shape[0], dtype=np.int64)

    # Update intrinsics to reflect the crops
    fx, fy, cx, cy = image_processing.modify_intrinsics_due_to_cropping(
        intrinsics['fx'],
        intrinsics['fy'],
        intrinsics['cx'],
        intrinsics['cy'],
        alignment_image_height,
        alignment_image_width,
        original_h=cropper.h,
        original_w=cropper.w)

    intrinsics = np.zeros(4, dtype=np.float32)
    intrinsics[0] = fx
    intrinsics[1] = fy
    intrinsics[2] = cx
    intrinsics[3] = cy

    #####################################################################################################
    # region ======= Predict deformation ================================================================
    #####################################################################################################

    # Move to device and unsqueeze in the batch dimension (to have batch size 1)
    source_rgbxyz_cuda = torch.from_numpy(source_rgbxyz).cuda().unsqueeze(0)
    target_rgbxyz_cuda = torch.from_numpy(target_rgbxyz).cuda().unsqueeze(0)
    target_boundary_mask_cuda = torch.from_numpy(
        target_boundary_mask).cuda().unsqueeze(0)
    graph_nodes_cuda = torch.from_numpy(graph_nodes).cuda().unsqueeze(0)
    graph_edges_cuda = torch.from_numpy(graph_edges).cuda().unsqueeze(0)
    graph_edges_weights_cuda = torch.from_numpy(
        graph_edges_weights).cuda().unsqueeze(0)
    graph_clusters_cuda = torch.from_numpy(graph_clusters).cuda().unsqueeze(0)
    pixel_anchors_cuda = torch.from_numpy(pixel_anchors).cuda().unsqueeze(0)
    pixel_weights_cuda = torch.from_numpy(pixel_weights).cuda().unsqueeze(0)
    intrinsics_cuda = torch.from_numpy(intrinsics).cuda().unsqueeze(0)

    num_nodes_cuda = torch.from_numpy(num_nodes).cuda().unsqueeze(0)

    with torch.no_grad():
        model_data = model(source_rgbxyz_cuda,
                           target_rgbxyz_cuda,
                           graph_nodes_cuda,
                           graph_edges_cuda,
                           graph_edges_weights_cuda,
                           graph_clusters_cuda,
                           pixel_anchors_cuda,
                           pixel_weights_cuda,
                           num_nodes_cuda,
                           intrinsics_cuda,
                           evaluate=True,
                           split="test")

    # Get some of the results
    rotations_pred = model_data["node_rotations"].view(num_nodes, 3,
                                                       3).cpu().numpy()
    translations_pred = model_data["node_translations"].view(num_nodes,
                                                             3).cpu().numpy()
    if save_node_transformations:
        # Save rotations & translations
        with open(
                'output/{:s}_{:s}_{:06d}_{:06d}_rotations.np'.format(
                    frame_pair_name, segment_name, source_frame_index,
                    target_frame_index), 'wb') as file:
            np.save(file, rotations_pred)
        with open(
                'output/{:s}_{:s}_{:06d}_{:06d}_translations.np'.format(
                    frame_pair_name, segment_name, source_frame_index,
                    target_frame_index), 'wb') as file:
            np.save(file, translations_pred)

    mask_pred = model_data["mask_pred"]
    assert mask_pred is not None, "Make sure deform_net.use_mask is set to true in the configuration"
    mask_pred = mask_pred.view(-1, alignment_image_height,
                               alignment_image_width).cpu().numpy()

    # Compute mask gt for mask baseline
    _, source_points, valid_source_points, target_matches, valid_target_matches, valid_correspondences, _, _ \
        = model_data["correspondence_info"]

    target_matches = target_matches.view(-1, alignment_image_height,
                                         alignment_image_width).cpu().numpy()
    valid_source_points = valid_source_points.view(
        -1, alignment_image_height, alignment_image_width).cpu().numpy()
    valid_target_matches = valid_target_matches.view(
        -1, alignment_image_height, alignment_image_width).cpu().numpy()
    valid_correspondences = valid_correspondences.view(
        -1, alignment_image_height, alignment_image_width).cpu().numpy()

    # Delete tensors to free up memory
    del source_rgbxyz_cuda
    del target_rgbxyz_cuda
    del target_boundary_mask_cuda
    del graph_nodes_cuda
    del graph_edges_cuda
    del graph_edges_weights_cuda
    del graph_clusters_cuda
    del pixel_anchors_cuda
    del pixel_weights_cuda
    del intrinsics_cuda

    del model

    # endregion

    tracking_viz.visualize_tracking(source_rgbxyz, target_rgbxyz,
                                    pixel_anchors, pixel_weights, graph_nodes,
                                    graph_edges, rotations_pred,
                                    translations_pred, mask_pred,
                                    valid_source_points, valid_correspondences,
                                    target_matches)
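
# Both this example and Example #1 adjust the pinhole intrinsics after cropping the
# input images to a 64-divisible resolution. A sketch of the arithmetic such an
# adjustment typically performs, assuming a centered crop from (original_w,
# original_h) to (image_width, image_height); the library's exact convention may
# differ:

def modify_intrinsics_due_to_cropping_sketch(fx, fy, cx, cy, image_height,
                                             image_width, original_h, original_w):
    # A pure crop leaves the focal lengths untouched; only the principal point
    # shifts by the crop offset.
    cx_new = cx - (original_w - image_width) / 2.0
    cy_new = cy - (original_h - image_height) / 2.0
    return fx, fy, cx_new, cy_new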
Example #5
def evaluate(model, criterion, dataloader, batch_num, split):
    process_arguments()
    dataset_obj = dataloader.dataset
    dataset_batch_size = dataloader.batch_size
    total_size = len(dataset_obj)

    # Losses
    loss_sum = 0.0
    loss_flow_sum = 0.0
    loss_graph_sum = 0.0
    loss_warp_sum = 0.0
    loss_mask_sum = 0.0

    max_num_batches = int(math.ceil(total_size / dataset_batch_size))
    total_num_batches = batch_num if batch_num != -1 else max_num_batches
    total_num_batches = min(max_num_batches, total_num_batches)

    # Metrics
    epe2d_sum_0 = 0.0
    epe2d_sum_2 = 0.0
    epe3d_sum = 0.0
    epe_warp_sum = 0.0

    total_num_pixels_0 = 0
    total_num_pixels_2 = 0
    total_num_nodes = 0
    total_num_points = 0

    num_valid_solves = 0
    num_total_solves = 0

    total_corres_weight_sum = 0.0
    total_corres_valid_num = 0

    print()

    for i, data in enumerate(dataloader):
        if i >= total_num_batches:
            break

        sys.stdout.write("\r############# Eval iteration: {0} / {1}".format(
            i + 1, total_num_batches))
        sys.stdout.flush()

        source, target, target_boundary_mask, \
            optical_flow_gt, optical_flow_mask, scene_flow_gt, scene_flow_mask, \
            graph_nodes, graph_edges, graph_edges_weights, translations_gt, graph_clusters, \
            pixel_anchors, pixel_weights, num_nodes, intrinsics, sample_idx = data

        source = source.cuda()
        target = target.cuda()
        target_boundary_mask = target_boundary_mask.cuda()
        optical_flow_gt = optical_flow_gt.cuda()
        optical_flow_mask = optical_flow_mask.cuda()
        scene_flow_gt = scene_flow_gt.cuda()
        scene_flow_mask = scene_flow_mask.cuda()
        graph_nodes = graph_nodes.cuda()
        graph_edges = graph_edges.cuda()
        graph_edges_weights = graph_edges_weights.cuda()
        translations_gt = translations_gt.cuda()
        graph_clusters = graph_clusters.cuda()
        pixel_anchors = pixel_anchors.cuda()
        pixel_weights = pixel_weights.cuda()
        intrinsics = intrinsics.cuda()

        batch_size = source.shape[0]

        alignment_image_width = Parameters.alignment.image_width.value
        alignment_image_height = Parameters.alignment.image_height.value

        # Build data for coarser level
        assert alignment_image_height % 64 == 0 and alignment_image_width % 64 == 0
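        # The 1/20 factor below matches the network's normalized coarse flow output;
        # the coarse prediction is scaled back by 20 after upsampling further down.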
        optical_flow_gt2 = torch.nn.functional.interpolate(
            input=optical_flow_gt.clone() / 20.0,
            size=(alignment_image_height // 4, alignment_image_width // 4),
            mode='nearest')
        optical_flow_mask2 = torch.nn.functional.interpolate(
            input=optical_flow_mask.clone().float(),
            size=(alignment_image_height // 4, alignment_image_width // 4),
            mode='nearest').bool()
        assert (torch.isfinite(optical_flow_gt2).all().item())
        assert (torch.isfinite(optical_flow_mask2).all().item())

        with torch.no_grad():
            # Predictions.
            model_data = model(source,
                               target,
                               graph_nodes,
                               graph_edges,
                               graph_edges_weights,
                               graph_clusters,
                               pixel_anchors,
                               pixel_weights,
                               num_nodes,
                               intrinsics,
                               evaluate=True,
                               split=split)
            optical_flow_pred2 = model_data["flow_data"][0]
            optical_flow_pred = 20.0 * torch.nn.functional.interpolate(
                input=optical_flow_pred2,
                size=(alignment_image_height, alignment_image_width),
                mode='bilinear',
                align_corners=False)

            translations_pred = model_data["node_translations"]

            total_corres_weight_sum += model_data["weight_info"]["total_corres_weight"]
            total_corres_valid_num += model_data["weight_info"]["total_corres_num"]

            # Compute mask gt for mask baseline
            xy_coords_warped, source_points, valid_source_points, target_matches, \
                valid_target_matches, valid_correspondences, deformed_points_idxs, \
                deformed_points_subsampled = model_data["correspondence_info"]

            mask_gt, valid_mask_pixels = nn_utilities.compute_baseline_mask_gt(
                xy_coords_warped, target_matches, valid_target_matches,
                source_points, valid_source_points, scene_flow_gt,
                scene_flow_mask, target_boundary_mask,
                Parameters.training.baseline.max_pos_flowed_source_to_target_dist.value,
                Parameters.training.baseline.min_neg_flowed_source_to_target_dist.value)

            # Compute deformed point gt
            deformed_points_gt, deformed_points_mask = nn_utilities.compute_deformed_points_gt(
                source_points, scene_flow_gt, model_data["valid_solve"],
                valid_correspondences, deformed_points_idxs,
                deformed_points_subsampled)

            # Loss.
            loss, loss_flow, loss_graph, loss_warp, loss_mask = criterion(
                [optical_flow_gt], [optical_flow_pred], [optical_flow_mask],
                translations_gt,
                model_data["node_translations"],
                model_data["deformations_validity"],
                deformed_points_gt,
                model_data["deformed_points_pred"],
                deformed_points_mask,
                model_data["valid_solve"],
                num_nodes,
                model_data["mask_pred"],
                mask_gt,
                valid_mask_pixels,
                evaluate=True)

            loss_sum += loss.item()
            loss_flow_sum += loss_flow.item() if Parameters.training.loss.use_flow_loss.value else -1
            loss_graph_sum += loss_graph.item() if Parameters.training.loss.use_graph_loss.value else -1
            loss_warp_sum += loss_warp.item() if Parameters.training.loss.use_warp_loss.value else -1
            loss_mask_sum += loss_mask.item() if Parameters.training.loss.use_mask_loss.value else -1

            # Metrics.
            # A.1) End Point Error in Optical Flow for FINEST level
            epe2d_dict = criterion.epe_2d(optical_flow_gt, optical_flow_pred,
                                          optical_flow_mask)
            epe2d_sum_0 += epe2d_dict["sum"]
            total_num_pixels_0 += epe2d_dict["num"]

            # A.2) End Point Error in Optical Flow for PYRAMID level 2 (4 times lower rez than finest level)
            epe2d_dict = criterion.epe_2d(optical_flow_gt2, optical_flow_pred2,
                                          optical_flow_mask2)
            epe2d_sum_2 += epe2d_dict["sum"]
            total_num_pixels_2 += epe2d_dict["num"]

            # B) Deformation translation/angle difference.
            # Important: we also evaluate nodes that were filtered out at optimization
            # (and were assigned the identity deformation).
            for k in range(batch_size):
                # We validate node deformation of both valid and invalid solves (for invalid
                # solves, the prediction should be identity transformation).
                num_total_solves += 1

                t_gt = translations_gt[k].view(1, -1, 3)
                t_pred = translations_pred[k].view(1, -1, 3)
                deformations_validity = model_data["deformations_validity"][k].view(1, -1)

                # For evaluation, all nodes (valid or invalid) are taken into account.
                deformations_validity_all = torch.zeros_like(
                    deformations_validity)
                deformations_validity_all[:, :int(num_nodes[k])] = 1

                epe3d_dict = criterion.epe_3d(t_gt, t_pred,
                                              deformations_validity_all)
                epe3d_sum += epe3d_dict["sum"]
                total_num_nodes += epe3d_dict["num"]

                # If valid solve, add to valid solves.
                if not model_data["valid_solve"][k]: continue
                num_valid_solves += 1

            # C) End Point Error in Warped 3D Points
            epe_warp_dict = criterion.epe_warp(
                deformed_points_gt, model_data["deformed_points_pred"],
                deformed_points_mask)
            if epe_warp_dict is not None:
                epe_warp_sum += epe_warp_dict["sum"]
                total_num_points += epe_warp_dict["num"]

    # Losses
    loss_avg = loss_sum / total_num_batches
    loss_flow_avg = loss_flow_sum / total_num_batches
    loss_graph_avg = loss_graph_sum / total_num_batches
    loss_warp_avg = loss_warp_sum / total_num_batches
    loss_mask_avg = loss_mask_sum / total_num_batches

    losses = {
        "total": loss_avg,
        "flow": loss_flow_avg,
        "graph": loss_graph_avg,
        "warp": loss_warp_avg,
        "mask": loss_mask_avg
    }

    # Metrics.
    epe2d_avg_0 = epe2d_sum_0 / total_num_pixels_0 if total_num_pixels_0 > 0 else -1.0
    epe2d_avg_2 = epe2d_sum_2 / total_num_pixels_2 if total_num_pixels_2 > 0 else -1.0
    epe3d_avg = epe3d_sum / total_num_nodes if total_num_nodes > 0 else -1.0
    epe_warp_avg = epe_warp_sum / total_num_points if total_num_points > 0 else -1.0
    valid_ratio = num_valid_solves / num_total_solves if num_total_solves > 0 else -1

    if total_corres_valid_num > 0:
        print(" Average correspondence weight: {0:.3f}".format(
            total_corres_weight_sum / total_corres_valid_num))

    metrics = {
        "epe2d_0": epe2d_avg_0,
        "epe2d_2": epe2d_avg_2,
        "epe3d": epe3d_avg,
        "epe_warp": epe_warp_avg,
        "num_valid_solves": num_valid_solves,
        "num_total_solves": num_total_solves,
        "valid_ratio": valid_ratio,
    }

    return losses, metrics
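
# criterion.epe_2d above evidently mirrors the 3D metric: it returns the summed
# per-pixel endpoint error and the count of valid pixels, which the caller pools
# into dataset-level averages. A minimal sketch consistent with that usage,
# assuming (b, 2, h, w) flow tensors and a boolean mask duplicated across the
# channel dimension (the actual DeformLoss method may differ):
import torch

def epe_2d_sketch(flow_gt, flow_pred, flow_mask):
    distances = torch.norm(flow_gt - flow_pred, p=2, dim=1)  # (b, h, w)
    valid = flow_mask[:, 0, :, :]                            # (b, h, w) boolean
    return {"sum": distances[valid].sum().item(), "num": int(valid.sum().item())}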
Example #6
from alignment import evaluate, nn_utilities

import torch
import open3d.core as o3c
from tensorboardX import SummaryWriter

from alignment.default import load_default_nnrt_network
from data import DeformDataset
from alignment import DeformLoss, SnapshotManager, TimeStatistics
from settings import process_arguments, Parameters
import ext_argparse

from settings.model import get_saved_model

if __name__ == "__main__":
    args = process_arguments()
    torch.set_num_threads(Parameters.training.num_threads.value)
    torch.backends.cudnn.benchmark = False

    # Training set
    train_labels_name = Parameters.training.train_labels_name.value

    # Validation set
    validation_labels_name = Parameters.training.validation_labels_name.value

    timestamp = Parameters.training.timestamp.value

    experiment_name = Parameters.training.experiment.value

    #####################################################################################
    # Ask user input regarding the use of data augmentation
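    # (The listing is cut off here. The comment above announces a confirmation
    # prompt; a minimal, purely illustrative sketch of such a prompt follows.
    # The original script's actual continuation is not shown.)
    reply = input("Use data augmentation during training? (y/n): ").strip().lower()
    use_data_augmentation = reply in ("y", "yes")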