Example #1
    def run(self):

        self.model.train()

        for epoch in range(self.EPOCH_START, self.EPOCHS):

            self.epoch = epoch
            self.epoch_ts = time.time()
            self.running_loss = 0.0

            utils.iterate_loader(self.DEVICE, self.loaders["train"],
                                 self.__train_step_fn)

            if self.SHOULD_CHECKPOINT:
                checkpoint_file = f"epoch_{epoch+1}.pt"
                path = os.path.join(self.checkpoint_dir, checkpoint_file)
                print(f"Saving checkpoint: {path}")
                torch.save(
                    {
                        "epoch": epoch,
                        "model": self.model.state_dict(),
                        "optimizer": self.optimizer.state_dict()
                    }, path)

        if self.SHOULD_WRITE:
            self.writer.close()
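Every example in this set drives its loop through utils.iterate_loader, whose implementation is not shown. Judging from the call sites (a device, a DataLoader, a step callback, an optional args tuple, and an optional end index; Examples 4 and 5 pass the argparse namespace instead of a bare device), a minimal sketch could look like the following. The signature and body are assumptions, not the project's actual helper.

# Hypothetical sketch of utils.iterate_loader, inferred from its call sites above.
def iterate_loader(device, loader, step_fn, args=(), end=None):
    for step, inputs in enumerate(loader):
        if end is not None and step >= end:
            break
        # Assume each batch is a dict of tensors that must be moved to the device.
        inputs = {k: v.to(device) if torch.is_tensor(v) else v
                  for k, v in inputs.items()}
        step_fn(step, inputs, *args)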
Example #2
    def __validate(self):
        N = len(self.loaders["val"])
        val_metrics = {}
        train_metrics = {}

        utils.iterate_loader(self.DEVICE,
                             self.loaders["val"],
                             self.__val_step_fn,
                             args=(val_metrics, ))
        utils.iterate_loader(self.DEVICE,
                             self.loaders["train"],
                             self.__val_step_fn,
                             args=(train_metrics, ),
                             end=N)

        val_metrics = utils.map_dict(val_metrics, lambda v: v / N)
        train_metrics = utils.map_dict(train_metrics, lambda v: v / N)

        return val_metrics, train_metrics
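The per-batch sums gathered by __val_step_fn are turned into averages with utils.map_dict. From its use here it appears to apply a function to every value of a dict; a one-line sketch under that assumption:

# Hypothetical sketch of utils.map_dict: apply fn to every value of a dict.
def map_dict(d, fn):
    return {k: fn(v) for k, v in d.items()}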
Example #3
    def run(self):

        self.model.eval()

        utils.iterate_loader(self.DEVICE,
                             self.loader,
                             self._step_fn,
                             end=len(self.loader))

        mean_ates = np.mean(self.ates)
        std_ates = np.std(self.ates)
        mean_metrics = utils.dict_mean(self.metrics)
        std_metrics = utils.dict_std(self.metrics)

        print("\n" + "=" * 20)
        print(f"Trajectory error: {mean_ates:0.3f}, std: {std_ates:0.3f}")
        print("Depth metrics:")
        for key in self.metrics:
            print(
                f"{key} -> mean: {mean_metrics[key]:0.3f}, std: {std_metrics[key]:0.3f}"
            )
        print("=" * 20)
Example #4
def main():
    # Parse arguments
    args = options.get_args(description="Debug a network",
                            options=[
                                "batch",
                                "workers",
                                "device",
                                "load",
                                "net",
                                "loss",
                            ])

    if args.load == "":
        print("No model file specified to load!")
        exit()

    # Construct datasets
    random.seed(1337)
    _, _, test_loader = utils.get_kitti_split(args.batch, args.workers)

    # The model
    model = networks.architectures.get_net(args)
    loss_fn = sfm_loss.get_loss_fn(args)

    # Load
    checkpoint = torch.load(args.load, map_location=torch.device(args.device))
    model.load_state_dict(checkpoint["model"])

    # Window
    pango.CreateWindowAndBind('Main', int(640 * (3 / 2)), int(480 * (3 / 2)))
    gl.glEnable(gl.GL_DEPTH_TEST)

    # Define Projection and initial ModelView matrix
    scam = pango.OpenGlRenderState(
        pango.ProjectionMatrix(640, 480, 420, 420, 320, 240, 0.2, 200),
        pango.ModelViewLookAt(0, -0.5, -0.5, 0, 0, 1, 0, -1, 0))
    handler = pango.Handler3D(scam)

    # Create Interactive View in window
    dcam = pango.CreateDisplay()
    dcam.SetBounds(pango.Attach(0), pango.Attach(1), pango.Attach(0),
                   pango.Attach(1), -640.0 / 480.0)
    dcam.SetHandler(handler)

    # Test
    model.train()
    model = model.to(args.device)

    def step_fn(step, inputs):

        # Forward pass and loss
        with torch.no_grad():
            loss, data = utils.forward_pass(model, loss_fn, inputs)

        print("loss %f" % loss.item())

        print(data.keys())

        print(data["pose"].shape)
        for i in range(args.batch):
            print(list(data["pose"][i, 0, :].cpu().detach().numpy()))
            print(list(data["pose"][i, 1, :].cpu().detach().numpy()))
            print("--")

        depth_img = viz.tensor2depthimg(
            torch.cat((*data["depth"][0][:, 0], ), dim=0))
        tgt_img = viz.tensor2img(torch.cat((*data["tgt"], ), dim=1))
        img = np.concatenate((tgt_img, depth_img), axis=1)

        warp_imgs = []
        #diff_imgs = []
        for warp, diff in zip(data["warp"], data["diff"]):
            warp = restack(restack(warp, 1, -1), 0, -2)
            diff = restack(restack(diff, 1, -1), 0, -2)
            warp_imgs.append(viz.tensor2img(warp))
            #diff_imgs.append(viz.tensor2diffimg(diff))

        world = reconstruction.depth_to_3d_points(data["depth"][0], data["K"])
        points = world[0, :].view(3, -1).transpose(
            1, 0).cpu().detach().numpy().astype(np.float64)
        colors = (data["tgt"][0, :].view(3, -1).transpose(
            1, 0).cpu().detach().numpy().astype(np.float64) + 1) / 2

        loop = True
        while loop:
            key = cv2.waitKey(10)
            if key == 27 or pango.ShouldQuit():
                exit()
            elif key != -1:
                loop = False
            cv2.imshow("target and depth", img)
            #for i, (warp, diff) in enumerate(zip(warp_imgs, diff_imgs)):
            for i, warp in enumerate(warp_imgs):
                cv2.imshow("warp scale: %d" % i, warp)
                #cv2.imshow("diff scale: %d" % i, diff)

            gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT)
            gl.glClearColor(1.0, 1.0, 1.0, 1.0)
            dcam.Activate(scam)
            gl.glPointSize(5)
            pango.DrawPoints(points, colors)
            pose = np.identity(4)
            pose[:3, 3] = 0
            gl.glLineWidth(1)
            gl.glColor3f(0.0, 0.0, 1.0)
            pango.DrawCamera(pose, 0.5, 0.75, 0.8)
            pango.FinishFrame()

    utils.iterate_loader(args, test_loader, step_fn)
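reconstruction.depth_to_3d_points backprojects the predicted depth map into a camera-frame point cloud before Pangolin renders it. The standard pinhole unprojection it presumably implements is x = D(u, v) * K^{-1} [u, v, 1]^T applied at every pixel; a sketch under that assumption (the name, shapes, and signature here are guesses, not the project's API):

# Hypothetical sketch of pinhole backprojection for a (B, 1, H, W) depth map
# and (B, 3, 3) intrinsics K; returns (B, 3, H, W) camera-frame points.
def depth_to_3d_points_sketch(depth, K):
    b, _, h, w = depth.shape
    v, u = torch.meshgrid(torch.arange(h, dtype=depth.dtype),
                          torch.arange(w, dtype=depth.dtype),
                          indexing="ij")
    pix = torch.stack((u, v, torch.ones_like(u)), dim=0).reshape(3, -1)  # (3, H*W)
    rays = torch.inverse(K) @ pix.unsqueeze(0).to(depth.device)          # (B, 3, H*W)
    points = rays * depth.reshape(b, 1, -1)                              # scale rays by depth
    return points.reshape(b, 3, h, w)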
Example #5
def main():
    # Parse arguments
    args = options.get_args(description="Debug a network",
                            options=[
                                "batch",
                                "workers",
                                "device",
                                "load",
                                "smooth-weight",
                                "explain-weight",
                                "net",
                            ])

    if args.load == "":
        print("No model file specified to load!")
        exit()

    # Construct datasets
    random.seed(1337)
    _, _, test_loader = utils.get_kitti_split(args.batch, args.workers)

    # The model
    model = networks.architectures.get_net(args)
    loss_fn = sfm_loss.get_loss_fn(args)

    # Load
    checkpoint = torch.load(args.load, map_location=torch.device(args.device))
    model.load_state_dict(checkpoint["model"])

    fig = mlab.figure(figure=None,
                      bgcolor=(0, 0, 0),
                      fgcolor=None,
                      engine=None,
                      size=(1000, 500))

    # Test
    model.train()
    model = model.to(args.device)

    def step_fn(step, inputs):

        # Forward pass and loss
        with torch.no_grad():
            loss, data = utils.forward_pass(model, loss_fn, inputs)

        print("loss %f" % loss.item())

        print(data.keys())

        print(data["pose"].shape)
        for i in range(args.batch):
            print(list(data["pose"][i, 0, :].cpu().detach().numpy()))
            print(list(data["pose"][i, 1, :].cpu().detach().numpy()))
            print("--")

        depth_img = viz.tensor2depthimg(
            torch.cat((*data["depth"][0][:, 0], ), dim=0))
        tgt_img = viz.tensor2img(torch.cat((*data["tgt"], ), dim=1))
        img = np.concatenate((tgt_img, depth_img), axis=1)

        warp_imgs = []
        diff_imgs = []
        for warp, diff in zip(data["warp"], data["diff"]):
            warp = restack(restack(warp, 1, -1), 0, -2)
            diff = restack(restack(diff, 1, -1), 0, -2)
            warp_imgs.append(viz.tensor2img(warp))
            diff_imgs.append(viz.tensor2diffimg(diff))

        world = inverse_warp.depth_to_3d_points(data["depth"][0], data["K"])
        points = world[0, :].view(3, -1).transpose(
            1, 0).cpu().detach().numpy().astype(np.float64)
        colors = (data["tgt"][0, :].view(3, -1).transpose(
            1, 0).cpu().detach().numpy().astype(np.float64) + 1) / 2

        test_mayavi.draw_rgb_points(fig, points, colors)

        loop = True
        while loop:
            key = cv2.waitKey(10)
            if key == 27:
                exit()
            elif key != -1:
                loop = False
            cv2.imshow("target and depth", img)
            for i, (warp, diff) in enumerate(zip(warp_imgs, diff_imgs)):
                cv2.imshow("warp scale: %d" % i, warp)
                cv2.imshow("diff scale: %d" % i, diff)
            mlab.show(10)

    utils.iterate_loader(args, test_loader, step_fn)
Example #6
    def run(self):

        self.model.eval()

        utils.iterate_loader(self.DEVICE, self.loader, self._step_fn)