예제 #1
0
    def setup(self):
        """
        Loads the LCNN model and parameters.
        Must be called before any processing can occur.
        
        Returns:
         - bool resulting initialized status
        """
        if self.initialized:
            print("Wireframe already initialized!")
            return self.initialized

        try:
            C.update(C.from_yaml(filename=self._config_file))
            M.update(C.model)
            random.seed(0)
            np.random.seed(0)
            torch.manual_seed(0)

            device_name = "cpu"
            os.environ["CUDA_VISIBLE_DEVICES"] = self._gpu_devices
            if torch.cuda.is_available():
                device_name = "cuda"
                torch.backends.cudnn.deterministic = True
                torch.cuda.manual_seed(0)
                print("Let's use", torch.cuda.device_count(), "GPU(s)!")
            else:
                print("CUDA is not available")
            self._device = torch.device(device_name)
            self._checkpoint = torch.load(self._model_file,
                                          map_location=self._device)

            # Load model
            self._model = lcnn.models.hg(
                depth=M.depth,
                head=lambda c_in, c_out: MultitaskHead(c_in, c_out),
                num_stacks=M.num_stacks,
                num_blocks=M.num_blocks,
                num_classes=sum(sum(M.head_size, [])),
            )
            self._model = MultitaskLearner(self._model)
            self._model = LineVectorizer(self._model)
            self._model.load_state_dict(self._checkpoint["model_state_dict"])
            self._model = self._model.to(self._device)
            self._model.eval()

            self.initialized = True

        except Exception as e:
            self.error = e
            print("Setup failed. Check self.error for more information")
            self.initialized = False

        return self.initialized
예제 #2
0
def main():
    """Dump L-CNN heatmap predictions for the validation split.

    Command-line arguments are parsed from the module docstring by docopt.
    For every image, each entry of the model's ``"heatmaps"`` output is saved
    to ``<output-dir>/NNNNNN.npz``. With ``--plot``, the predicted lines of
    each image are drawn and shown as well.

    Raises:
        NotImplementedError: if ``M.backbone`` is not ``"stacked_hourglass"``.
    """
    args = docopt(__doc__)
    config_file = args["<yaml-config>"] or "config/wireframe.yaml"
    C.update(C.from_yaml(filename=config_file))
    M.update(C.model)
    pprint.pprint(C, indent=4)

    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)

    device_name = "cpu"
    os.environ["CUDA_VISIBLE_DEVICES"] = args["--devices"]
    if torch.cuda.is_available():
        device_name = "cuda"
        torch.backends.cudnn.deterministic = True
        torch.cuda.manual_seed(0)
        print("Let's use", torch.cuda.device_count(), "GPU(s)!")
    else:
        print("CUDA is not available")
    device = torch.device(device_name)

    if M.backbone == "stacked_hourglass":
        model = lcnn.models.hg(
            depth=M.depth,
            head=lambda c_in, c_out: MultitaskHead(c_in, c_out),
            num_stacks=M.num_stacks,
            num_blocks=M.num_blocks,
            num_classes=sum(sum(M.head_size, [])),
        )
    else:
        raise NotImplementedError

    # Without map_location, a checkpoint saved from GPU cannot be loaded on a
    # CPU-only host.
    checkpoint = torch.load(args["<checkpoint>"], map_location=device)
    model = MultitaskLearner(model)
    model = LineVectorizer(model)
    model.load_state_dict(checkpoint["model_state_dict"])
    model = model.to(device)
    model.eval()

    loader = torch.utils.data.DataLoader(
        WireframeDataset(args["<image-dir>"], split="valid"),
        shuffle=False,
        batch_size=M.batch_size,
        collate_fn=collate,
        # Worker subprocesses are disabled on Windows, as elsewhere.
        num_workers=C.io.num_workers if os.name != "nt" else 0,
        pin_memory=True,
    )
    os.makedirs(args["<output-dir>"], exist_ok=True)

    for batch_idx, (image, meta, target) in enumerate(loader):
        with torch.no_grad():
            input_dict = {
                "image": recursive_to(image, device),
                "meta": recursive_to(meta, device),
                "target": recursive_to(target, device),
                "do_evaluation": True,
            }
            H = model(input_dict)["heatmaps"]
            # Iterate over the images actually present: the final batch can
            # hold fewer than M.batch_size images.
            for i in range(len(image)):
                index = batch_idx * M.batch_size + i
                np.savez(
                    osp.join(args["<output-dir>"], f"{index:06}.npz"),
                    **{k: v[i].cpu().numpy()
                       for k, v in H.items()},
                )
                if not args["--plot"]:
                    continue
                # NOTE(review): the x4 presumably maps heatmap coordinates
                # back to input resolution — confirm against the model.
                lines = H["lines"][i].cpu().numpy() * 4
                scores = H["score"][i].cpu().numpy()
                if len(lines) > 0 and not (lines[0] == 0).all():
                    # Stop at the first repeat of lines[0] (it appears to
                    # mark padding at the end of the list).
                    for j, ((a, b), s) in enumerate(zip(lines, scores)):
                        if j > 0 and (lines[j] == lines[0]).all():
                            break
                        plt.plot([a[1], b[1]], [a[0], b[0]],
                                 c=c(s),
                                 linewidth=4)
                plt.show()
예제 #3
0
def main():
    """Run the Hough-voting L-CNN demo on the images given on the command line.

    Each image is resized to 512x512 and normalised before inference. The
    predicted lines are de-duplicated with ``postprocess``. For each score
    threshold in 0.94..0.99, the lines at or above that threshold are drawn
    over the image, shown, and saved next to it as ``<name>-<t>.svg``.
    """
    args = docopt(__doc__)
    config_file = args["<yaml-config>"] or "config/wireframe.yaml"
    C.update(C.from_yaml(filename=config_file))
    M.update(C.model)
    pprint.pprint(C, indent=4)

    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)

    device_name = "cpu"
    os.environ["CUDA_VISIBLE_DEVICES"] = args["--devices"]
    if torch.cuda.is_available():
        device_name = "cuda"
        torch.backends.cudnn.deterministic = True
        torch.cuda.manual_seed(0)
        print("Let's use", torch.cuda.device_count(), "GPU(s)!")
    else:
        print("CUDA is not available")
    device = torch.device(device_name)
    checkpoint = torch.load(args["<checkpoint>"], map_location=device)

    # Load the Hough vote_index matrix. If it is missing, build it once with
    # rows=128, cols=128, theta_res=3, rho_res=1 and cache it on disk.
    if os.path.isfile(C.io.vote_index):
        vote_index = sio.loadmat(C.io.vote_index)['vote_index']
    else:
        vote_index = hough_transform(rows=128,
                                     cols=128,
                                     theta_res=3,
                                     rho_res=1)
        sio.savemat(C.io.vote_index, {'vote_index': vote_index})
    vote_index = torch.from_numpy(vote_index).float().contiguous().to(device)
    print('load vote_index', vote_index.shape)

    # Load model
    model = hg(
        depth=M.depth,
        head=lambda c_in, c_out: MultitaskHead(c_in, c_out),
        num_stacks=M.num_stacks,
        num_blocks=M.num_blocks,
        num_classes=sum(sum(M.head_size, [])),
        vote_index=vote_index,
    )
    model = MultitaskLearner(model)
    model = LineVectorizer(model)
    model.load_state_dict(checkpoint["model_state_dict"])
    model = model.to(device)
    model.eval()

    for imname in args["<images>"]:
        print(f"Processing {imname}")
        im = skimage.io.imread(imname)
        if im.ndim == 2:
            # Grayscale -> 3 channels; extra channels (e.g. alpha) dropped.
            im = np.repeat(im[:, :, None], 3, 2)
        im = im[:, :, :3]
        im_resized = skimage.transform.resize(im, (512, 512)) * 255
        image = (im_resized - M.image.mean) / M.image.stddev
        image = torch.from_numpy(np.rollaxis(image, 2)[None].copy()).float()
        with torch.no_grad():
            # Only the image is real; meta/target are placeholders.
            input_dict = {
                "image":
                image.to(device),
                "meta": [{
                    "junc":
                    torch.zeros(1, 2).to(device),
                    "jtyp":
                    torch.zeros(1, dtype=torch.uint8).to(device),
                    "Lpos":
                    torch.zeros(2, 2, dtype=torch.uint8).to(device),
                    "Lneg":
                    torch.zeros(2, 2, dtype=torch.uint8).to(device),
                }],
                "target": {
                    "jmap": torch.zeros([1, 1, 128, 128]).to(device),
                    "joff": torch.zeros([1, 1, 2, 128, 128]).to(device),
                },
                "mode":
                "testing",
            }
            H = model(input_dict)["preds"]

        # Rescale from the 128x128 output grid to the original (h, w).
        lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
        scores = H["score"][0].cpu().numpy()
        # Truncate at the first repeat of lines[0].
        for i in range(1, len(lines)):
            if (lines[i] == lines[0]).all():
                lines = lines[:i]
                scores = scores[:i]
                break

        # postprocess lines to remove overlapped lines
        diag = (im.shape[0]**2 + im.shape[1]**2)**0.5
        nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)

        # Derive output names from the stem. str.replace(".png", ...) left
        # non-PNG names unchanged, so savefig overwrote the input image.
        stem = os.path.splitext(imname)[0]
        for t in (0.94, 0.95, 0.96, 0.97, 0.98, 0.99):
            plt.gca().set_axis_off()
            plt.subplots_adjust(top=1,
                                bottom=0,
                                right=1,
                                left=0,
                                hspace=0,
                                wspace=0)
            plt.margins(0, 0)
            for (a, b), s in zip(nlines, nscores):
                if s < t:
                    continue
                plt.plot([a[1], b[1]], [a[0], b[0]],
                         c=c(s),
                         linewidth=2,
                         zorder=s)
                plt.scatter(a[1], a[0], **PLTOPTS)
                plt.scatter(b[1], b[0], **PLTOPTS)
            plt.gca().xaxis.set_major_locator(plt.NullLocator())
            plt.gca().yaxis.set_major_locator(plt.NullLocator())
            plt.imshow(im)
            plt.savefig(f"{stem}-{t:.02f}.svg", bbox_inches="tight")
            plt.show()
            plt.close()
예제 #4
0
def main():
    """Train L-CNN, optionally resuming from ``C.io.resume_from``.

    Builds the train/validation loaders from ``C.io.datadir``, the
    stacked-hourglass model, and the optimizer named by ``C.optim.name``.
    These are handed to ``lcnn.trainer.Trainer``.

    If training raises before a *new* run has produced more than one
    visualisation, that fresh output directory is removed and the exception
    is re-raised. A resumed run's directory is never removed, because it
    holds the checkpoint being resumed from.

    Raises:
        NotImplementedError: for an unsupported backbone or optimizer.
    """
    args = docopt(__doc__)
    config_file = args["<yaml-config>"] or "config/wireframe.yaml"
    C.update(C.from_yaml(filename=config_file))
    M.update(C.model)
    pprint.pprint(C, indent=4)
    resume_from = C.io.resume_from

    # WARNING: L-CNN is still not deterministic
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)

    device_name = "cpu"
    os.environ["CUDA_VISIBLE_DEVICES"] = args["--devices"]
    if torch.cuda.is_available():
        device_name = "cuda"
        torch.backends.cudnn.deterministic = True
        torch.cuda.manual_seed(0)
        print("Let's use", torch.cuda.device_count(), "GPU(s)!")
    else:
        print("CUDA is not available")
    device = torch.device(device_name)

    # 1. dataset
    datadir = C.io.datadir
    kwargs = {
        "collate_fn": collate,
        # Worker subprocesses are disabled on Windows.
        "num_workers": C.io.num_workers if os.name != "nt" else 0,
        "pin_memory": True,
    }
    train_loader = torch.utils.data.DataLoader(
        WireframeDataset(datadir, split="train"),
        shuffle=True,
        batch_size=M.batch_size,
        **kwargs,
    )
    val_loader = torch.utils.data.DataLoader(
        WireframeDataset(datadir, split="valid"),
        shuffle=False,
        batch_size=2,
        **kwargs,
    )
    epoch_size = len(train_loader)

    if resume_from:
        # map_location allows resuming on a different device than the one
        # the checkpoint was written from.
        checkpoint = torch.load(
            osp.join(resume_from, "checkpoint_latest.pth.tar"),
            map_location=device,
        )

    # 2. model
    if M.backbone == "stacked_hourglass":
        model = lcnn.models.hg(
            depth=M.depth,
            head=MultitaskHead,
            num_stacks=M.num_stacks,
            num_blocks=M.num_blocks,
            num_classes=sum(sum(M.head_size, [])),
        )
    else:
        raise NotImplementedError

    model = MultitaskLearner(model)
    model = LineVectorizer(model)

    if resume_from:
        model.load_state_dict(checkpoint["model_state_dict"])
    model = model.to(device)

    # 3. optimizer
    if C.optim.name == "Adam":
        optim = torch.optim.Adam(
            model.parameters(),
            lr=C.optim.lr,
            weight_decay=C.optim.weight_decay,
            amsgrad=C.optim.amsgrad,
        )
    elif C.optim.name == "SGD":
        optim = torch.optim.SGD(
            model.parameters(),
            lr=C.optim.lr,
            weight_decay=C.optim.weight_decay,
            momentum=C.optim.momentum,
        )
    else:
        raise NotImplementedError

    if resume_from:
        optim.load_state_dict(checkpoint["optim_state_dict"])
    outdir = resume_from or get_outdir(args["--identifier"])
    print("outdir:", outdir)

    try:
        trainer = lcnn.trainer.Trainer(
            device=device,
            model=model,
            optimizer=optim,
            train_loader=train_loader,
            val_loader=val_loader,
            out=outdir,
        )
        if resume_from:
            trainer.iteration = checkpoint["iteration"]
            # Restart from the beginning of the interrupted epoch.
            if trainer.iteration % epoch_size != 0:
                print(
                    "WARNING: iteration is not a multiple of epoch_size, reset it"
                )
                trainer.iteration -= trainer.iteration % epoch_size
            trainer.best_mean_loss = checkpoint["best_mean_loss"]
            del checkpoint
        trainer.train()
    except BaseException:
        # Remove the directory of a fresh run that barely produced anything.
        # Never remove a resumed run: it contains the checkpoints we resumed
        # from. ignore_errors stops a cleanup failure from masking the
        # original exception.
        if not resume_from and len(glob.glob(f"{outdir}/viz/*")) <= 1:
            shutil.rmtree(outdir, ignore_errors=True)
        raise
def main():
    """Step through the validation set, showing each corner/center output.

    Loads ``./checkpoint_best.pth`` onto the CPU. For every validation batch:
    show the first input image, run the model in ``"testing"`` mode with
    all-zero targets, and show six predicted maps, each with a colorbar and
    titled by its output key. Then block on ``input()`` until Enter is
    pressed.

    Raises:
        NotImplementedError: if ``M.backbone`` is not ``"stacked_hourglass"``.
    """
    device = torch.device('cpu')
    C.update(C.from_yaml(filename="config/wireframe.yaml"))
    M.update(C.model)

    if M.backbone != "stacked_hourglass":
        raise NotImplementedError
    backbone = lcnn.models.hg(
        depth=M.depth,
        head=MultitaskHead,
        num_stacks=M.num_stacks,
        num_blocks=M.num_blocks,
        num_classes=sum(sum(M.head_size, [])),
    )
    model = MultitaskLearner(backbone)

    state_dict = torch.load('./checkpoint_best.pth',
                            map_location=torch.device('cpu'))['model_state_dict']
    model.load_state_dict(state_dict)
    model = model.to(torch.device('cpu'))
    model.eval()

    val_loader = torch.utils.data.DataLoader(
        WireframeDataset(C.io.datadir, split="valid"),
        shuffle=False,
        batch_size=M.batch_size_eval,
        collate_fn=collate,
        num_workers=0 if os.name == "nt" else C.io.num_workers,
        pin_memory=True,
    )

    # Shapes of the all-zero placeholder targets passed to the model.
    target_shapes = {
        "corner": [1, 1, 128, 128],
        "center": [1, 1, 128, 128],
        "corner_offset": [1, 1, 2, 128, 128],
        "corner_bin_offset": [1, 1, 2, 128, 128],
    }
    # (prediction key, index chain into that prediction) for each figure;
    # the key is also used as the figure title.
    panels = (
        ("corner", (0,)),
        ("center", (0,)),
        ("corner_offset", (0, 0, 0)),
        ("corner_offset", (0, 0, 1)),
        ("corner_bin_offset", (0, 0, 0)),
        ("corner_bin_offset", (0, 0, 1)),
    )

    for image, _target, _name in val_loader:
        plt.imshow(image[0].permute(1, 2, 0).detach().int())
        plt.show()

        with torch.no_grad():
            input_dict = {
                "image": image.to(device),
                "target": {
                    key: torch.zeros(shape).to(device)
                    for key, shape in target_shapes.items()
                },
                "mode": "testing",
            }
            H = model(input_dict)["preds"]

            for key, path in panels:
                data = H[key]
                for idx in path:
                    data = data[idx]
                plt.imshow(data.squeeze())
                plt.colorbar()
                plt.title(key)
                plt.show()

        # Wait for Enter before moving to the next batch.
        input()
def main():
    """Run the Hough-voting L-CNN on the test split and save its predictions.

    For every test image, each entry of the model's ``"preds"`` output is
    saved to ``C.io.outdir/NNNNNN.npz``. With ``--plot``, the predicted lines
    are drawn and shown as well.

    Raises:
        NotImplementedError: if ``M.backbone`` is not ``"stacked_hourglass"``.
    """
    args = docopt(__doc__)
    config_file = args["<yaml-config>"] or "config/wireframe.yaml"
    C.update(C.from_yaml(filename=config_file))
    M.update(C.model)
    pprint.pprint(C, indent=4)

    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)

    device_name = "cpu"
    os.environ["CUDA_VISIBLE_DEVICES"] = args["--devices"]
    if torch.cuda.is_available():
        device_name = "cuda"
        torch.backends.cudnn.deterministic = True
        torch.cuda.manual_seed(0)
        print("Let's use", torch.cuda.device_count(), "GPU(s)!")
    else:
        print("CUDA is not available")
    device = torch.device(device_name)

    # Load the vote_index matrix for the Hough transform. If it is missing,
    # build it once with the default settings (rows=128, cols=128,
    # theta_res=3, rho_res=1) and cache it on disk.
    if os.path.isfile(C.io.vote_index):
        vote_index = sio.loadmat(C.io.vote_index)['vote_index']
    else:
        vote_index = hough_transform(rows=128,
                                     cols=128,
                                     theta_res=3,
                                     rho_res=1)
        sio.savemat(C.io.vote_index, {'vote_index': vote_index})
    vote_index = torch.from_numpy(vote_index).float().contiguous().to(device)
    print('load vote_index', vote_index.shape)

    if M.backbone == "stacked_hourglass":
        model = lcnn.models.hg(
            depth=M.depth,
            head=MultitaskHead,
            num_stacks=M.num_stacks,
            num_blocks=M.num_blocks,
            num_classes=sum(sum(M.head_size, [])),
            vote_index=vote_index,
        )
    else:
        raise NotImplementedError

    # Without map_location, a checkpoint saved from GPU cannot be loaded on a
    # CPU-only host.
    checkpoint = torch.load(args["<checkpoint>"], map_location=device)
    model = MultitaskLearner(model)
    model = LineVectorizer(model)
    model.load_state_dict(checkpoint["model_state_dict"])
    model = model.to(device)
    model.eval()

    loader = torch.utils.data.DataLoader(
        WireframeDataset(rootdir=C.io.datadir, split="test"),
        shuffle=False,
        batch_size=M.batch_size,
        collate_fn=collate,
        num_workers=C.io.num_workers if os.name != "nt" else 0,
        pin_memory=True,
    )

    output_dir = C.io.outdir
    os.makedirs(output_dir, exist_ok=True)

    for batch_idx, (image, meta, target) in enumerate(loader):
        with torch.no_grad():
            input_dict = {
                "image": recursive_to(image, device),
                "meta": recursive_to(meta, device),
                "target": recursive_to(target, device),
                "mode": "validation",
            }
            H = model(input_dict)["preds"]
            # len(image) rather than M.batch_size: the last batch may be short.
            for i in range(len(image)):
                index = batch_idx * M.batch_size + i
                print('index', index)
                np.savez(
                    osp.join(output_dir, f"{index:06}.npz"),
                    **{k: v[i].cpu().numpy()
                       for k, v in H.items()},
                )
                if not args["--plot"]:
                    continue
                # NOTE(review): the x4 presumably maps heatmap coordinates
                # back to input resolution — confirm against the model.
                lines = H["lines"][i].cpu().numpy() * 4
                scores = H["score"][i].cpu().numpy()
                if len(lines) > 0 and not (lines[0] == 0).all():
                    # Stop at the first repeat of lines[0]. A separate index
                    # name avoids shadowing the outer image index ``i``.
                    for j, ((a, b), s) in enumerate(zip(lines, scores)):
                        if j > 0 and (lines[j] == lines[0]).all():
                            break
                        plt.plot([a[1], b[1]], [a[0], b[0]],
                                 c=c(s),
                                 linewidth=4)
                plt.show()
예제 #7
0
def main():
    """Plot ground-truth vs. predicted corner/center maps on the validation split.

    Any existing ``<output-dir>`` is deleted and recreated, with a
    ``test_result`` subdirectory. With ``--plot``, ``postprocess`` is called
    twice for each image: on the ground-truth maps (NMS off) and on the
    predicted maps (NMS on). It writes into ``test_result``, and the timing
    of each prediction is printed.

    Raises:
        NotImplementedError: if ``M.backbone`` is not ``"stacked_hourglass"``.
    """
    args = docopt(__doc__)
    config_file = args["<yaml-config>"] or "config/wireframe.yaml"
    C.update(C.from_yaml(filename=config_file))
    M.update(C.model)
    pprint.pprint(C, indent=4)

    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)

    if M.backbone == "stacked_hourglass":
        model = lcnn.models.hg(
            depth=M.depth,
            head=lambda c_in, c_out: MultitaskHead(c_in, c_out),
            num_stacks=M.num_stacks,
            num_blocks=M.num_blocks,
            num_classes=sum(sum(M.head_size, [])),
        )
    else:
        raise NotImplementedError

    model = MultitaskLearner(model)

    device_name = "cpu"
    os.environ["CUDA_VISIBLE_DEVICES"] = args["--devices"]
    if torch.cuda.is_available():
        device_name = "cuda"
        torch.backends.cudnn.deterministic = True
        torch.cuda.manual_seed(0)
        checkpoint = torch.load(args["<checkpoint>"])
        print("Let's use", torch.cuda.device_count(), "GPU(s)!")
    else:
        checkpoint = torch.load(args["<checkpoint>"],
                                map_location=torch.device('cpu'))
        print("CUDA is not available")
    device = torch.device(device_name)
    model.load_state_dict(checkpoint["model_state_dict"])
    model = model.to(device)
    model.eval()

    print(f'evaluation batch size {M.batch_size_eval}')
    loader = torch.utils.data.DataLoader(
        WireframeDataset(args["<image-dir>"], split="valid"),
        shuffle=False,
        batch_size=M.batch_size_eval,
        collate_fn=collate,
        num_workers=C.io.num_workers if os.name != "nt" else 0,
        pin_memory=True,
    )
    if os.path.exists(args["<output-dir>"]):
        shutil.rmtree(args["<output-dir>"])
    os.makedirs(args["<output-dir>"], exist_ok=False)
    outdir = os.path.join(args["<output-dir>"], 'test_result')
    os.mkdir(outdir)

    map_keys = ("center", "corner", "corner_offset", "corner_bin_offset")

    for image, target, iname in loader:
        with torch.no_grad():
            # predict given image; targets are zeroed so that no
            # ground-truth information reaches the model
            input_target = {k: torch.zeros_like(target[k]) for k in map_keys}
            input_dict = {
                "image": recursive_to(image, device),
                "target": recursive_to(input_target, device),
                "mode": "validation",
            }
            network_start_time = time()
            H = model(input_dict)["preds"]
            if device.type == "cuda":
                # CUDA runs asynchronously; wait for the forward pass to
                # finish so the measured time is meaningful.
                torch.cuda.synchronize()
            network_end_time = time()
            # NOTE: this is the time for the whole batch, reported per image.
            network_cost = network_end_time - network_start_time

            if not args["--plot"]:
                continue

            # plot gt & prediction
            for i in range(len(iname)):
                im = image[i].cpu().numpy().transpose(1, 2, 0)  # [512,512,3]
                # splitext handles names with no dot or several dots;
                # split('.') raised IndexError or truncated them.
                stem, ext = os.path.splitext(iname[i])

                # plot&process gt
                gt_maps = [target[k][i].cpu().numpy() for k in map_keys]
                postprocess([im, f"{stem}_gt{ext}"],
                            gt_maps,
                            outdir,
                            NMS=False,
                            plot=True)

                # plot&process pd
                pd_maps = [H[k][i].cpu().numpy() for k in map_keys]
                postprocess_start_time = time()
                postprocess([im, f"{stem}_pd{ext}"],
                            pd_maps,
                            outdir,
                            NMS=True,
                            plot=True)
                postprocess_cost = time() - postprocess_start_time
                print(
                    f'inference time is {postprocess_cost + network_cost}, network cost:{network_cost}, postprocessing cost:{postprocess_cost}'
                )

            # Evaluation:
            # eval() # TBD
    print('-----finished-----')
    return
예제 #8
0
def main():
    """Detect line segments in every image under ``images/`` and export them.

    Each image is resized to 512x512, run through L-CNN, and the predicted
    lines are de-duplicated with ``postprocess``. Segments scoring at least
    0.94 are written to a ``.txt`` file next to the image, one per line as
    ``x1 y1 x2 y2`` in original-image pixels. Afterwards, every ``.txt`` file
    under ``images/`` is moved into ``results/``.
    """
    args = docopt(__doc__)
    config_file = args["<yaml-config>"] or "config/wireframe.yaml"
    C.update(C.from_yaml(filename=config_file))
    M.update(C.model)
    pprint.pprint(C, indent=4)

    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)

    device_name = "cpu"
    os.environ["CUDA_VISIBLE_DEVICES"] = args["--devices"]
    if torch.cuda.is_available():
        device_name = "cuda"
        torch.backends.cudnn.deterministic = True
        torch.cuda.manual_seed(0)
        print("Let's use", torch.cuda.device_count(), "GPU(s)!")
    else:
        print("CUDA is not available")
    device = torch.device(device_name)
    checkpoint = torch.load(args["<checkpoint>"], map_location=device)

    # Load model
    model = lcnn.models.hg(
        depth=M.depth,
        head=lambda c_in, c_out: MultitaskHead(c_in, c_out),
        num_stacks=M.num_stacks,
        num_blocks=M.num_blocks,
        num_classes=sum(sum(M.head_size, [])),
    )
    model = MultitaskLearner(model)
    model = LineVectorizer(model)
    model.load_state_dict(checkpoint["model_state_dict"])
    model = model.to(device)
    model.eval()

    image_root = "images"
    result_dir = "results"
    score_threshold = 0.94

    for root, _dirs, files in os.walk(image_root):
        for fname in files:
            # Use the walked path directly. Passing it through glob.glob
            # silently skipped names containing glob metacharacters like '['.
            imname = os.path.join(root, fname)
            stem, ext = os.path.splitext(imname)
            # Exported .txt files (e.g. left behind by an interrupted run)
            # are not images; imread would fail on them.
            if ext == ".txt":
                continue
            print(f"Processing {imname}")
            im = skimage.io.imread(imname)
            if im.ndim == 2:
                im = np.repeat(im[:, :, None], 3, 2)
            im = im[:, :, :3]
            im_resized = skimage.transform.resize(im, (512, 512)) * 255
            image = (im_resized - M.image.mean) / M.image.stddev
            image = torch.from_numpy(np.rollaxis(image,
                                                 2)[None].copy()).float()
            with torch.no_grad():
                # Only the image is real; meta/target are placeholders.
                input_dict = {
                    "image":
                    image.to(device),
                    "meta": [{
                        "junc":
                        torch.zeros(1, 2).to(device),
                        "jtyp":
                        torch.zeros(1, dtype=torch.uint8).to(device),
                        "Lpos":
                        torch.zeros(2, 2, dtype=torch.uint8).to(device),
                        "Lneg":
                        torch.zeros(2, 2, dtype=torch.uint8).to(device),
                    }],
                    "target": {
                        "jmap": torch.zeros([1, 1, 128, 128]).to(device),
                        "joff": torch.zeros([1, 1, 2, 128,
                                             128]).to(device),
                    },
                    "mode":
                    "testing",
                }
                H = model(input_dict)["preds"]

            # Rescale from the 128x128 output grid to the original (h, w).
            lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
            scores = H["score"][0].cpu().numpy()
            # Truncate at the first repeat of lines[0].
            for i in range(1, len(lines)):
                if (lines[i] == lines[0]).all():
                    lines = lines[:i]
                    scores = scores[:i]
                    break

            # postprocess lines to remove overlapped lines
            diag = (im.shape[0]**2 + im.shape[1]**2)**0.5
            nlines, nscores = postprocess(lines, scores, diag * 0.01, 0,
                                          False)

            # splitext keeps dots in directory names or stems intact;
            # split(".")[0] truncated at the first dot in the whole path.
            with open(stem + ".txt", "w") as writeFile:
                for (a, b), s in zip(nlines, nscores):
                    if s < score_threshold:
                        continue
                    # (row, col) endpoints written as x y order.
                    print(a[1], a[0], b[1], b[0], file=writeFile)

    # Collect every exported .txt into result_dir. Create it first, and move
    # to an explicit file path so a rerun replaces stale results. Moving to
    # 'results/' failed if the directory or an equally named file existed.
    os.makedirs(result_dir, exist_ok=True)
    for root, _dirs, files in os.walk(image_root):
        for fname in files:
            if os.path.splitext(fname)[1] == '.txt':
                shutil.move(os.path.join(root, fname),
                            os.path.join(result_dir, fname))
def main():
    """Run the corner/center model on the validation split and score it.

    Any existing ``<output-dir>`` and the mAP ``detection-results`` /
    ``ground-truth`` folders are cleared first. For every image, the
    ground-truth file is copied over with ``move_to_mAP``, the predicted maps
    are post-processed (NMS, at most 10 detections) and written out with
    ``write_pd_to_mAP``. Finally ``evalCOCO()`` runs and the average
    per-image timings are printed.

    Raises:
        NotImplementedError: if ``M.backbone`` is not ``"stacked_hourglass"``.
    """
    args = docopt(__doc__)
    config_file = args["<yaml-config>"] or "config/wireframe.yaml"
    C.update(C.from_yaml(filename=config_file))
    M.update(C.model)
    pprint.pprint(C, indent=4)

    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)

    if M.backbone == "stacked_hourglass":
        model = lcnn.models.hg(
            depth=M.depth,
            head=lambda c_in, c_out: MultitaskHead(c_in, c_out),
            num_stacks=M.num_stacks,
            num_blocks=M.num_blocks,
            num_classes=sum(sum(M.head_size, [])),
        )
    else:
        raise NotImplementedError

    model = MultitaskLearner(model)

    device_name = "cpu"
    os.environ["CUDA_VISIBLE_DEVICES"] = args["--devices"]
    if torch.cuda.is_available():
        device_name = "cuda"
        torch.backends.cudnn.deterministic = True
        torch.cuda.manual_seed(0)
        checkpoint = torch.load(args["<checkpoint>"])
        print("Let's use", torch.cuda.device_count(), "GPU(s)!")
    else:
        checkpoint = torch.load(args["<checkpoint>"],
                                map_location=torch.device('cpu'))
        print("CUDA is not available")
    device = torch.device(device_name)
    model.load_state_dict(checkpoint["model_state_dict"])
    model = model.to(device)
    model.eval()

    loader = torch.utils.data.DataLoader(
        WireframeDataset(args["<image-dir>"], split="valid"),
        shuffle=False,
        batch_size=M.batch_size_eval,
        collate_fn=collate,
        num_workers=C.io.num_workers if os.name != "nt" else 0,
        pin_memory=True,
    )
    if os.path.exists(args["<output-dir>"]):
        shutil.rmtree(args["<output-dir>"])
    os.makedirs(args["<output-dir>"], exist_ok=False)
    outdir = os.path.join(args["<output-dir>"], 'test_result')
    os.mkdir(outdir)

    det_dir = os.path.join(C.io.mAP, 'detection-results')
    gt_dir = os.path.join(C.io.mAP, 'ground-truth')
    # clean previous files in mAP folders
    for mAP_folder in (det_dir, gt_dir):
        if os.path.exists(mAP_folder):
            shutil.rmtree(mAP_folder)
        os.makedirs(mAP_folder, exist_ok=False)

    valid_dir = os.path.join(args["<image-dir>"], 'valid')
    map_keys = ("center", "corner", "corner_offset", "corner_bin_offset")
    total_inference_time = 0
    time_cost_by_network = 0
    time_cost_by_post = 0
    for image, target, iname in loader:
        with torch.no_grad():
            # predict given image; targets are zeroed so that no
            # ground-truth information reaches the model
            input_target = {k: torch.zeros_like(target[k]) for k in map_keys}
            input_dict = {
                "image": recursive_to(image, device),
                "target": recursive_to(input_target, device),
                "mode": "validation",
            }
            # time cost by network
            timer_begin = time.time()
            H = model(input_dict)["preds"]
            if device.type == "cuda":
                # CUDA runs asynchronously; wait for the forward pass to
                # finish so the measured time is meaningful.
                torch.cuda.synchronize()
            timer_end = time.time()
            time_cost_by_network += timer_end - timer_begin
            total_inference_time += timer_end - timer_begin

            for i in range(len(iname)):
                im = image[i].cpu().numpy().transpose(1, 2, 0)  # [512,512,3]

                # move gt files to mAP folder for evaluation
                move_to_mAP(valid_dir, iname[i], gt_dir)

                # splitext handles names with no dot or several dots;
                # split('.') raised IndexError or truncated them.
                stem, ext = os.path.splitext(iname[i])
                pd_im_info = [im, f"{stem}_pd{ext}"]
                feature_maps = [H[k][i].cpu().numpy() for k in map_keys]

                # post processing without center prediction
                timer_begin = time.time()
                grouped_corners = postprocess(pd_im_info,
                                              feature_maps,
                                              outdir,
                                              maxDet=10,
                                              NMS=True,
                                              plot=args['--plot'])
                timer_end = time.time()
                time_cost_by_post += timer_end - timer_begin
                total_inference_time += timer_end - timer_begin
                write_pd_to_mAP(grouped_corners, iname[i], det_dir)

    # Evaluation:
    evalCOCO()  # TBD
    # Guard against an empty validation set (ZeroDivisionError otherwise).
    num_images = max(len(loader.dataset), 1)
    print("inference time is", total_inference_time / num_images,
          "s / img")
    print(
        f"time cost by network is {time_cost_by_network/num_images}, time cost by post-processing is {time_cost_by_post/num_images}"
    )
    print('-----finished-----')
    return