Example #1
    def visualize(self, use_model=False):
        if self.vis_index == len(self.dataset):
            raise Exception("No more examples to visualize in dataset")
        b = self.dataset[self.vis_index]
        self.vis_index += 1
        if "schematic" in b:
            self.sp.drawGeoscorerPlotly(b["schematic"])
        c_sl = b["context"].size()[0]
        self.sp.drawGeoscorerPlotly(b["context"])
        self.sp.drawGeoscorerPlotly(b["seg"])
        # Draw the segment combined into the context at the ground-truth target
        target_coord = su.index_to_coord(b["target"].item(), c_sl)
        combined_voxel = su.combine_seg_context(b["seg"],
                                                b["context"],
                                                target_coord,
                                                seg_mult=3)
        self.sp.drawGeoscorerPlotly(combined_voxel)

        if use_model:
            # Add a batch dim, score every candidate position, and draw the
            # segment placed at the model's argmax prediction
            b = {k: t.unsqueeze(0) for k, t in b.items()}
            targets, scores = tu.get_scores_from_datapoint(
                self.model, b, self.opts)
            max_ind = torch.argmax(scores, dim=1)
            pred_coord = su.index_to_coord(max_ind.item(), c_sl)
            b = {k: t.squeeze(0) for k, t in b.items()}
            predicted_voxel = su.combine_seg_context(b["seg"],
                                                     b["context"],
                                                     pred_coord,
                                                     seg_mult=3)
            self.sp.drawGeoscorerPlotly(predicted_voxel)
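
For context, a minimal driver for a visualizer like the one above might look as follows. This is a hypothetical sketch: the GeoscorerDatasetVisualizer name and its constructor arguments are assumptions, not taken from the snippet.

    # Hypothetical usage sketch; the class name and constructor are assumed.
    viz = GeoscorerDatasetVisualizer(dataset, model=model, opts=opts)
    for _ in range(len(dataset)):
        # Draws the ground-truth placement, then the model's prediction
        viz.visualize(use_model=True)
        input("Press enter for the next example.")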
Example #2
    def visualize(self, use_model=False, verbose=False):
        if self.vis_index == len(self.dataset):
            raise Exception("No more examples to visualize in dataset")
        b = self.dataset[self.vis_index]
        self.vis_index += 1
        c_sl = b["context"].size()[0]
        target_coord = su.index_to_coord(b["target"].item(), c_sl)
        if verbose:
            print("viewer_pos", b["viewer_pos"])
            print("viewer_look", b["viewer_look"])
            print("target_coord", target_coord)
            print("dir_vec", b["dir_vec"])
            print("-----------\n")
        if "schematic" in b:
            self.sp.drawGeoscorerPlotly(b["schematic"])
        self.sp.drawGeoscorerPlotly(b["context"])
        self.sp.drawGeoscorerPlotly(b["seg"])
        combined_voxel = su.combine_seg_context(b["seg"],
                                                b["context"],
                                                target_coord,
                                                seg_mult=3)
        # Mark the viewer look (block id 5) and a column at the viewer
        # position (block id 10) so both stand out in the plot
        vp = b["viewer_pos"].long()
        vl = b["viewer_look"].long()
        combined_voxel[vl[0], vl[1], vl[2]] = 5
        combined_voxel[vp[0], :, vp[2]] = 10
        self.sp.drawGeoscorerPlotly(combined_voxel)

        if use_model:
            # Add a batch dim, score every candidate position, and draw the
            # segment placed at the model's argmax prediction
            b = {k: t.unsqueeze(0) for k, t in b.items()}
            targets, scores = tu.get_scores_and_target_from_datapoint(
                self.model, b, self.opts)
            max_ind = torch.argmax(scores, dim=1)
            pred_coord = su.index_to_coord(max_ind.item(), c_sl)
            b = {k: t.squeeze(0) for k, t in b.items()}
            predicted_voxel = su.combine_seg_context(b["seg"],
                                                     b["context"],
                                                     pred_coord,
                                                     seg_mult=3)
            self.sp.drawGeoscorerPlotly(predicted_voxel)
Example #3
    def segment_context_to_pos(self, segment, context, dir_vec, viewer_pos):
        # Batch the single example to match the model's expected input shape
        batch = {
            "context": context.unsqueeze(0),
            "seg": segment.unsqueeze(0),
            "dir_vec": dir_vec.unsqueeze(0),
            "viewer_pos": viewer_pos.unsqueeze(0),
            "viewer_look": torch.tensor([16.0, 16.0, 16.0]).unsqueeze(0),
        }
        scores = tu.get_scores_from_datapoint(self.tms, batch, self.tms["opts"])
        # Convert the highest-scoring flattened index into a 3D coordinate,
        # then recover the segment origin that placement implies
        index = scores[0].flatten().max(0)[1]
        target_coord = su.index_to_coord(index.item(), self.context_sl)
        seg_origin = su.get_seg_origin_from_target_coord(segment, target_coord)
        return target_coord, seg_origin
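
These snippets repeatedly convert a flat argmax index into a 3D coordinate with su.index_to_coord. A minimal sketch of that conversion, assuming row-major flattening over a cube of side length sl (the real helper's axis order may differ):

    def index_to_coord_sketch(index, sl):
        # Invert row-major flattening of an (sl, sl, sl) cube. The axis
        # order here is an assumption, not the project's definition.
        x = index // (sl * sl)
        y = (index // sl) % sl
        z = index % sl
        return (x, y, z)

The inverse, index = x * sl * sl + y * sl + z, holds under the same assumption.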
Example #4
    def segment_context_to_pos(self, segment, context):
        # Coords are in Z, X, Y, so put segment into same coords
        segment = segment.permute(1, 2, 0).contiguous()

        c_embed = self.context_net(context.unsqueeze(0))
        s_embed = self.seg_net(segment.unsqueeze(0))
        scores = self.score_module([c_embed, s_embed])
        index = scores[0].flatten().max(0)[1]
        target_coord = index_to_coord(index.item(), self.context_sl)

        # Then take final coord back into X, Y, Z coords
        final_target_coord = (target_coord[2], target_coord[0],
                              target_coord[1])
        return final_target_coord
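
The tuple reindexing at the end of Example #4 is exactly the inverse of the permute at the top: for p = t.permute(1, 2, 0), p[i, j, k] == t[k, i, j], so a coordinate c = (i, j, k) found in permuted space maps back to (c[2], c[0], c[1]). A self-contained check in plain PyTorch:

    import torch

    t = torch.arange(24).reshape(2, 3, 4)
    p = t.permute(1, 2, 0)
    c = (1, 2, 0)  # an arbitrary coordinate in permuted space
    # The same element in the original tensor sits at (c[2], c[0], c[1]),
    # matching the final_target_coord construction above
    assert p[c] == t[c[2], c[0], c[1]]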
Example #5
    def _get_example(self):
        example = get_inst_seg_example(self.seg_data, self.drop_perc,
                                       self.c_sl, self.s_sl, self.useid)
        if self.use_direction:
            # Sample a viewer and derive a direction vector toward the target
            viewer_pos, viewer_look = du.get_random_viewer_info(self.c_sl)
            target_coord = torch.tensor(
                su.index_to_coord(example["target"], self.c_sl))
            example["dir_vec"] = du.get_sampled_direction_vec(
                viewer_pos, viewer_look, target_coord)
            example["viewer_pos"] = viewer_pos
        return example
Example #6
    def _get_example(self):
        # note that seg_sparse is not in target location
        context_sparse, c_sizes, seg_sparse, s_sizes = get_two_shape_sparse(
            self.c_sl, self.s_sl)
        viewer_pos, viewer_look = du.get_random_viewer_info(self.c_sl)
        dir_vec = du.random_dir_vec_tensor()
        target = get_shape_dir_target(
            viewer_pos, dir_vec, c_sizes, s_sizes, self.c_sl, self.max_shift
        )
        if self.ground_type is not None:
            target_coord = su.index_to_coord(target, self.c_sl)
            su.add_ground_to_context(
                context_sparse, target_coord, flat=(self.ground_type == "flat")
            )
        context = su.get_dense_array_from_sl(context_sparse, self.c_sl, self.useid)
        seg = su.get_dense_array_from_sl(seg_sparse, self.s_sl, self.useid)
        return {
            "context": torch.from_numpy(context),
            "seg": torch.from_numpy(seg),
            "target": torch.tensor([target]),
            "viewer_pos": viewer_pos,
            "dir_vec": dir_vec,
        }
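
su.get_dense_array_from_sl turns a sparse voxel list into a dense cube before the torch.from_numpy conversion. A minimal sketch of what such a densify step could look like, assuming the sparse format is a list of ((x, y, z), block_id) pairs; the real helper's format and dtype may differ:

    import numpy as np

    def get_dense_from_sparse_sketch(sparse_voxels, sl, useid):
        # Assumed sparse format: [((x, y, z), block_id), ...]; out-of-bounds
        # entries are dropped. Illustration only, not the project helper.
        dense = np.zeros((sl, sl, sl), dtype=np.float32)
        for (x, y, z), block_id in sparse_voxels:
            if 0 <= x < sl and 0 <= y < sl and 0 <= z < sl:
                dense[x, y, z] = block_id if useid else 1.0
        return dense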
Example #7
    def _get_example(self):
        # Get the raw context and seg
        context_sparse, seg_sparse = get_context_seg_sparse(
            self.seg_data, self.drop_perc)
        # Convert into an example
        example = su.sparse_context_seg_in_space_to_example(
            context_sparse,
            seg_sparse,
            self.c_sl,
            self.s_sl,
            self.useid,
            self.ground_type,
            self.random_ground_height,
        )

        # Add the direction info
        target_coord = torch.tensor(
            su.index_to_coord(example["target"], self.c_sl))
        example["viewer_pos"], example[
            "dir_vec"] = du.get_random_vp_and_max_dir_vec(
                example["viewer_look"], target_coord, self.c_sl)
        return example
Example #8
    def _get_example(self):
        # Get the raw context and seg
        uncentered_context_sparse, uncentered_seg_sparse = get_shape_segment(
            max_chunk=self.s_sl - 1, side_length=self.c_sl
        )
        # Convert into an example, without direction info
        example = su.sparse_context_seg_in_space_to_example(
            uncentered_context_sparse,
            uncentered_seg_sparse,
            self.c_sl,
            self.s_sl,
            self.useid,
            self.ground_type,
            self.random_ground_height,
        )

        # Add the direction info
        target_coord = torch.tensor(su.index_to_coord(example["target"], self.c_sl))
        example["viewer_pos"], example["dir_vec"] = du.get_random_vp_and_max_dir_vec(
            example["viewer_look"], target_coord, self.c_sl
        )
        return example
Example #9
def eval_loop(tms, DL, opts, vis=True, wrong_counts=False):
    tu.set_modules(tms, train=False)
    dlit = iter(DL)
    c_sl = opts["context_sidelength"]
    max_allowed_key = "min_dist_above_{}".format(opts["max_allowed_dist"])

    wrong_keys = []
    if wrong_counts:
        wrong_keys = (["dir_vec", "vpvl", "vpvl_dir_vec"], )

    metrics = EvalMetrics(
        [
            "overlap_count",
            "out_of_bounds_count",
            max_allowed_key,
            "dir_wrong_count",
            "dist_from_target",
        ],
        wrong_keys=wrong_keys,
    )
    n = len(DL)
    if vis:
        viz = GeoscorerVisualizer()

    for j in range(n):
        metrics.reset_batch()
        batch = next(dlit)
        targets, scores = tu.get_scores_and_target_from_datapoint(
            tms, batch, opts)
        max_ind = torch.argmax(scores, dim=1)

        if vis:
            # Only step through the first few batch elements when visualizing
            it = range(3)
        else:
            it = range(batch["context"].size()[0])

        if opts["tqdm"]:
            it = tqdm(it)

        for i in it:
            context = batch["context"][i]
            seg = batch["seg"][i]
            vp = batch["viewer_pos"][i]
            vl = batch["viewer_look"][i]
            dir_vec = batch["dir_vec"][i]

            predicted_ind = max_ind[i]
            target_ind = targets[i]
            predicted_coord = su.index_to_coord(predicted_ind.cpu().item(),
                                                c_sl)
            target_coord = su.index_to_coord(target_ind.cpu().item(), c_sl)

            if vis:
                viz.visualize(context)
                viz.visualize(seg)
                viz.visualize_combined(context, seg, target_coord, vp, vl)
                viz.visualize_combined(context, seg, predicted_coord)

            # TODO: some of these calcs can be dramatically sped up using tensors
            c_tuple = set([s[0] for s in su.sparsify_voxel(context)])
            s_tuple = set([s[0] for s in su.sparsify_voxel(seg)])
            # Rebuild a sparse voxel list with dummy block values, shifted to
            # the predicted placement and clipped to the context bounds
            predicted_seg_voxel = su.shift_sparse_voxel(
                [(s, (0, 0)) for s in s_tuple],
                predicted_coord,
                min_b=[0, 0, 0],
                max_b=[c_sl, c_sl, c_sl],
            )
            predicted_seg_tuple = set([v[0] for v in predicted_seg_voxel])

            res = {
                "overlap_count": len(predicted_seg_tuple & c_tuple),
                "out_of_bounds_count": len(s_tuple) - len(predicted_seg_tuple),
                max_allowed_key: 0,
                "dir_wrong_count": 0,
                "dist_from_target": su.euclid_dist(predicted_coord,
                                                   target_coord),
            }

            min_dist = su.get_min_block_pair_dist(predicted_seg_tuple, c_tuple)
            if min_dist > opts["max_allowed_dist"]:
                res[max_allowed_key] = 1
            if vis:
                print("batch info:")
                print("  viewer_pos", vp)
                print("  viewer_look", vl)
                print("  target_coord", target_coord)
                print("  predicted_coord", predicted_coord)
                print("  dir_vec", dir_vec)
                print("-------------\n")
            if not dir_correct(vp, vl, dir_vec, predicted_coord):
                res["dir_wrong_count"] = 1
                if wrong_counts:
                    dvl = dir_vec.tolist()
                    vpl = vp[[0, 2]].tolist() + vl[[0, 2]].tolist()
                    metrics.add_to_wrong_values("dir_vec", tuple(dvl))
                    metrics.add_to_wrong_values("vpvl", tuple(vpl))
                    metrics.add_to_wrong_values("vpvl_dir_vec",
                                                tuple(vpl + dvl))
            metrics.update_batch_elem(res)

        metrics.print_metrics(j)
        metrics.print_wrong_values()
        if vis:
            input("Press enter to continue.")
    # To include the last batch in the overall aggregates
    metrics.reset_batch()
    metrics.print_metrics(0, overall=True)
    metrics.print_wrong_values()
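
A hypothetical invocation of this loop; everything here other than eval_loop's own parameters (tms, DL, opts) is a placeholder for whatever setup the project uses:

    # Sketch only: `tms` (trained modules) and `dataset` are placeholders.
    opts = {
        "context_sidelength": 32,
        "max_allowed_dist": 8,
        "tqdm": False,
    }
    DL = torch.utils.data.DataLoader(dataset, batch_size=64, drop_last=True)
    eval_loop(tms, DL, opts, vis=False, wrong_counts=True)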