Example #1
0
    def __init__(
            self,
            # exp params
            exp_name="u50_block",
            # arch params
            backbone="resnet50",
            backbone_kwargs=None,
            dim_embedding=256,
            feature_spatial_scale=0.25,
            max_junctions=512,
            junction_pooling_threshold=0.2,
            junc_pooling_size=15,
            attention_sigma=1.,
            junction_heatmap_criterion="binary_cross_entropy",
            block_inference_size=64,
            adjacency_matrix_criterion="binary_cross_entropy",
            # data params
            data_root=r"/home/ziheng/indoorDist_new2",
            img_size=416,
            junc_sigma=3.,
            batch_size=2,
            # train params
            gpus=None,
            num_workers=5,
            resume_epoch="latest",
            is_train_junc=True,
            is_train_adj=True,
            # vis params
            vis_junc_th=0.3,
            vis_line_th=0.3):
        """Build the LSD model, datasets, loaders and logging for one run.

        Fix: ``backbone_kwargs`` and ``gpus`` previously used mutable default
        arguments (``{}`` / ``[0]``) which are shared across all calls; they
        now default to ``None`` and are materialised per call with the same
        effective defaults.  Passing an explicit empty ``gpus`` list still
        selects CPU mode.
        """
        # Materialise per-call defaults (mutable-default pitfall).
        if backbone_kwargs is None:
            backbone_kwargs = {}
        if gpus is None:
            gpus = [0]

        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(c) for c in gpus)

        # CPU mode when no GPU ids were given.
        self.is_cuda = bool(gpus)

        # NOTE(review): ``weight_fn`` is not a parameter of this method — it
        # must resolve at module level; confirm it is defined/imported there.
        self.model = LSDModule(
            backbone=backbone,
            dim_embedding=dim_embedding,
            backbone_kwargs=backbone_kwargs,
            junction_pooling_threshold=junction_pooling_threshold,
            max_junctions=max_junctions,
            feature_spatial_scale=feature_spatial_scale,
            junction_heatmap_criterion=junction_heatmap_criterion,
            junction_pooling_size=junc_pooling_size,
            attention_sigma=attention_sigma,
            block_inference_size=block_inference_size,
            adjacency_matrix_criterion=adjacency_matrix_criterion,
            weight_fn=weight_fn,
            is_train_adj=is_train_adj,
            is_train_junc=is_train_junc)

        # Per-experiment logging / checkpoint directories.
        self.exp_name = exp_name
        os.makedirs(os.path.join("log", exp_name), exist_ok=True)
        os.makedirs(os.path.join("ckpt", exp_name), exist_ok=True)
        self.writer = SummaryWriter(log_dir=os.path.join("log", exp_name))

        # Training state; updated below if a checkpoint is resumed.
        self.states = dict(last_epoch=-1, elapsed_time=0, state_dict=None)

        resume_path = os.path.join("ckpt", exp_name,
                                   f"train_states_{resume_epoch}.pth")
        if resume_epoch and os.path.isfile(resume_path):
            states = torch.load(resume_path)
            print(f"resume training from epoch {states['last_epoch']}")
            self.model.load_state_dict(states["state_dict"])
            self.states.update(states)

        # Datasets / loaders: train split gets flip + color augmentation,
        # val split only the resize.
        self.train_data = SISTLine(data_root=data_root,
                                   transforms=tf.Compose(
                                       tf.Resize((img_size, img_size)),
                                       tf.RandomHorizontalFlip(),
                                       tf.RandomColorAug()),
                                   phase="train",
                                   sigma_junction=junc_sigma,
                                   max_junctions=max_junctions)

        self.train_loader = DataLoader(self.train_data,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       num_workers=num_workers,
                                       pin_memory=True)

        self.eval_data = SISTLine(data_root=data_root,
                                  transforms=tf.Compose(
                                      tf.Resize((img_size, img_size)), ),
                                  phase="val",
                                  sigma_junction=junc_sigma,
                                  max_junctions=max_junctions)

        self.eval_loader = DataLoader(self.eval_data,
                                      batch_size=batch_size,
                                      shuffle=False,
                                      num_workers=num_workers,
                                      pin_memory=True)

        # Visualisation / inference knobs.
        self.vis_junc_th = vis_junc_th
        self.vis_line_th = vis_line_th
        self.block_size = block_inference_size
        self.max_junctions = max_junctions
        self.is_train_junc = is_train_junc
        self.is_train_adj = is_train_adj
Example #2
0
class LSDTrainer(object):
    """Training/evaluation driver for the LSD line-segment model."""

    def __init__(
            self,
            # exp params
            exp_name="u50_block",
            # arch params
            backbone="resnet50",
            backbone_kwargs=None,
            dim_embedding=256,
            feature_spatial_scale=0.25,
            max_junctions=512,
            junction_pooling_threshold=0.2,
            junc_pooling_size=15,
            attention_sigma=1.,
            junction_heatmap_criterion="binary_cross_entropy",
            block_inference_size=64,
            adjacency_matrix_criterion="binary_cross_entropy",
            # data params
            data_root=r"/home/ziheng/indoorDist_new2",
            img_size=416,
            junc_sigma=3.,
            batch_size=2,
            # train params
            gpus=None,
            num_workers=5,
            resume_epoch="latest",
            is_train_junc=True,
            is_train_adj=True,
            # vis params
            vis_junc_th=0.3,
            vis_line_th=0.3):
        """Build the LSD model, datasets, loaders and logging for one run.

        Fix: ``backbone_kwargs`` and ``gpus`` previously used mutable default
        arguments (``{}`` / ``[0]``) which are shared across all calls; they
        now default to ``None`` and are materialised per call with the same
        effective defaults.  Passing an explicit empty ``gpus`` list still
        selects CPU mode.
        """
        # Materialise per-call defaults (mutable-default pitfall).
        if backbone_kwargs is None:
            backbone_kwargs = {}
        if gpus is None:
            gpus = [0]

        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(c) for c in gpus)

        # CPU mode when no GPU ids were given.
        self.is_cuda = bool(gpus)

        # NOTE(review): ``weight_fn`` is not a parameter of this method — it
        # must resolve at module level; confirm it is defined/imported there.
        self.model = LSDModule(
            backbone=backbone,
            dim_embedding=dim_embedding,
            backbone_kwargs=backbone_kwargs,
            junction_pooling_threshold=junction_pooling_threshold,
            max_junctions=max_junctions,
            feature_spatial_scale=feature_spatial_scale,
            junction_heatmap_criterion=junction_heatmap_criterion,
            junction_pooling_size=junc_pooling_size,
            attention_sigma=attention_sigma,
            block_inference_size=block_inference_size,
            adjacency_matrix_criterion=adjacency_matrix_criterion,
            weight_fn=weight_fn,
            is_train_adj=is_train_adj,
            is_train_junc=is_train_junc)

        # Per-experiment logging / checkpoint directories.
        self.exp_name = exp_name
        os.makedirs(os.path.join("log", exp_name), exist_ok=True)
        os.makedirs(os.path.join("ckpt", exp_name), exist_ok=True)
        self.writer = SummaryWriter(log_dir=os.path.join("log", exp_name))

        # Training state; updated below if a checkpoint is resumed.
        self.states = dict(last_epoch=-1, elapsed_time=0, state_dict=None)

        resume_path = os.path.join("ckpt", exp_name,
                                   f"train_states_{resume_epoch}.pth")
        if resume_epoch and os.path.isfile(resume_path):
            states = torch.load(resume_path)
            print(f"resume training from epoch {states['last_epoch']}")
            self.model.load_state_dict(states["state_dict"])
            self.states.update(states)

        # Datasets / loaders: train split gets flip + color augmentation,
        # val split only the resize.
        self.train_data = SISTLine(data_root=data_root,
                                   transforms=tf.Compose(
                                       tf.Resize((img_size, img_size)),
                                       tf.RandomHorizontalFlip(),
                                       tf.RandomColorAug()),
                                   phase="train",
                                   sigma_junction=junc_sigma,
                                   max_junctions=max_junctions)

        self.train_loader = DataLoader(self.train_data,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       num_workers=num_workers,
                                       pin_memory=True)

        self.eval_data = SISTLine(data_root=data_root,
                                  transforms=tf.Compose(
                                      tf.Resize((img_size, img_size)), ),
                                  phase="val",
                                  sigma_junction=junc_sigma,
                                  max_junctions=max_junctions)

        self.eval_loader = DataLoader(self.eval_data,
                                      batch_size=batch_size,
                                      shuffle=False,
                                      num_workers=num_workers,
                                      pin_memory=True)

        # Visualisation / inference knobs.
        self.vis_junc_th = vis_junc_th
        self.vis_line_th = vis_line_th
        self.block_size = block_inference_size
        self.max_junctions = max_junctions
        self.is_train_junc = is_train_junc
        self.is_train_adj = is_train_adj

    @staticmethod
    def _group_weight(module, lr):
        group_decay = []
        group_no_decay = []
        for m in module.modules():
            if isinstance(m, nn.Linear):
                group_decay.append(m.weight)
                if m.bias is not None:
                    group_no_decay.append(m.bias)
            elif isinstance(m, nn.modules.conv._ConvNd):
                group_decay.append(m.weight)
                if m.bias is not None:
                    group_no_decay.append(m.bias)
            elif isinstance(m, nn.modules.batchnorm._BatchNorm) or isinstance(
                    m, nn.GroupNorm):
                if m.weight is not None:
                    group_no_decay.append(m.weight)
                if m.bias is not None:
                    group_no_decay.append(m.bias)

        assert len(list(
            module.parameters())) == len(group_decay) + len(group_no_decay)
        groups = [
            dict(params=group_decay, lr=lr),
            dict(params=group_no_decay, lr=lr, weight_decay=.0)
        ]
        return groups

    def end(self):
        self.writer.close()
        return "command queue finished."

    def _train_epoch(self):
        """Run one training epoch over ``self.train_loader``.

        Rebuilds the optimizer each epoch from the hyper-parameters stored by
        :meth:`train` (``self.lr``, ``self.solver``, ...), forwards every
        batch, backprops the combined heatmap/adjacency loss, and logs
        scalars, images and a text summary to TensorBoard.
        """
        net_time = AverageMeter()
        data_time = AverageMeter()
        vis_time = AverageMeter()

        epoch = self.states["last_epoch"]
        data_loader = self.train_loader
        if self.is_cuda:
            self.model = self.model.cuda()
        # Parameter groups: the backbone always trains; the junction /
        # adjacency heads only when the corresponding flag is enabled.
        params = self._group_weight(self.model.backbone, self.lr)
        if self.is_train_junc:
            params += self._group_weight(self.model.junc_infer, self.lr)
        if self.is_train_adj:
            params += self._group_weight(self.model.adj_infer, self.lr)
            params += self._group_weight(self.model.adj_embed, self.lr)
        # Adadelta takes no momentum argument; the other optimizers used
        # through this path do.
        if self.solver == "Adadelta":
            solver = optim.__dict__[self.solver](
                params, weight_decay=self.weight_decay)
        else:
            solver = optim.__dict__[self.solver](
                params, weight_decay=self.weight_decay, momentum=self.momentum)

        # main loop
        torch.set_grad_enabled(True)
        tic = time.time()
        print(f"start training epoch: {epoch}", flush=True)

        # Wrap in DataParallel only when running on GPU(s).
        if self.is_cuda:
            model = nn.DataParallel(self.model).train()
        else:
            model = self.model.train()

        for i, batch in enumerate(data_loader):
            if self.is_cuda:
                img = batch["image"].cuda()
                heatmap_gt = batch["heatmap"].cuda()
                adj_mtx_gt = batch["adj_mtx"].cuda()
                junctions_gt = batch["junctions"].cuda()
            else:
                img = batch["image"]
                heatmap_gt = batch["heatmap"]
                adj_mtx_gt = batch["adj_mtx"]
                junctions_gt = batch["junctions"]

            # measure elapsed time
            data_time.update(time.time() - tic)
            tic = time.time()

            # Forward pass: the module computes both losses internally,
            # weighted by lambda_heatmap / lambda_adj.
            junc_pred, heatmap_pred, adj_mtx_pred, loss_hm, loss_adj = model(
                img, heatmap_gt, adj_mtx_gt, self.lambda_heatmap,
                self.lambda_adj, junctions_gt)

            model.zero_grad()
            # Reduce (possibly per-replica) loss tensors to scalars.
            loss_adj = loss_adj.mean()
            loss_hm = loss_hm.mean()
            # Only the enabled heads contribute to the optimized loss.
            loss = (loss_hm if self.is_train_junc else
                    0) + (loss_adj if self.is_train_adj else 0)
            loss.backward()
            solver.step()

            # measure elapsed time
            net_time.update(time.time() - tic)
            tic = time.time()

            # visualize result
            if i % self.vis_line_interval == 0:
                # NOTE: ``img``/``heatmap_pred`` are rebound to CPU/numpy
                # copies here; later uses in this iteration see those copies.
                img = img.cpu().numpy()
                heatmap_pred = heatmap_pred.detach().cpu()
                adj_mtx_pred = adj_mtx_pred.detach().cpu().numpy()
                junctions_gt = junctions_gt.cpu().numpy()
                adj_mtx_gt = adj_mtx_gt.cpu().numpy()
                self._vis_train(epoch, i, len(data_loader), img, heatmap_pred,
                                adj_mtx_pred, junctions_gt, adj_mtx_gt)

            # Heatmaps reshaped to (N, 1, H, W) image grids for TensorBoard.
            vis_heatmap_gt = vutils.make_grid(
                heatmap_gt.view(heatmap_gt.size(0), 1, heatmap_gt.size(1),
                                heatmap_gt.size(2)))
            vis_heatmap_pred = vutils.make_grid(
                heatmap_pred.view(heatmap_gt.size(0), 1, heatmap_gt.size(1),
                                  heatmap_gt.size(2)))

            # Losses are logged un-weighted (divided back by their lambdas).
            self.writer.add_scalar(self.exp_name + "/" + "train/loss_total",
                                   loss.item(),
                                   epoch * len(data_loader) + i)
            self.writer.add_scalar(
                self.exp_name + "/" + "train/loss_heatmap",
                loss_hm.item() /
                self.lambda_heatmap if self.lambda_heatmap else 0,
                epoch * len(data_loader) + i)
            self.writer.add_scalar(
                self.exp_name + "/" + "train/loss_adj_mtx",
                loss_adj.item() / self.lambda_adj if self.lambda_adj else 0,
                epoch * len(data_loader) + i)
            self.writer.add_image(self.exp_name + "/" + "train/heatmap_gt",
                                  vis_heatmap_gt,
                                  epoch * len(data_loader) + i)
            self.writer.add_image(self.exp_name + "/" + "train/heatmap_pred",
                                  vis_heatmap_pred,
                                  epoch * len(data_loader) + i)

            vis_time.update(time.time() - tic)
            info = f"epoch: [{epoch}][{i}/{len(data_loader)}], lr: {self.lr}, " \
                   f"time_total: {net_time.average() + data_time.average() + vis_time.average():.2f}, " \
                   f"time_data: {data_time.average():.2f}, time_net: {net_time.average():.2f}, " \
                   f"time_vis: {vis_time.average():.2f}, " \
                   f"loss: {loss.item():.4f}, " \
                   f"loss_heatmap: {loss_hm.item() / self.lambda_heatmap if self.lambda_heatmap else 0:.4f}, " \
                   f"loss_adj_mtx: {loss_adj.item() / self.lambda_adj if self.lambda_adj else 0:.4f}"
            self.writer.add_text(self.exp_name + "/" + "train/info", info,
                                 epoch * len(data_loader) + i)
            print(info, flush=True)
            # measure elapsed time
            tic = time.time()

    def _vis_train(self, epoch, i, len_loader, img, heatmap, adj_mtx,
                   junctions_gt, adj_mtx_gt):
        """Log GT/predicted line drawings and mean junction/line scores
        for one training batch to TensorBoard.
        """
        step = epoch * len_loader + i
        junctions_gt = np.int32(junctions_gt)

        # Ground-truth lines, then predicted lines (GT junctions are used
        # for both; only the adjacency differs).
        gt_lines, gt_scores = graph2line(junctions_gt, adj_mtx_gt)
        vis_line_gt = vutils.make_grid(draw_lines(img, gt_lines, gt_scores))
        pred_lines, score_pred = graph2line(junctions_gt,
                                            adj_mtx,
                                            threshold=self.vis_line_th)
        vis_line_pred = vutils.make_grid(
            draw_lines(img, pred_lines, score_pred))

        # Heatmap values sampled at valid (non-padded) GT junctions.
        junc_score = []
        for m, juncs in zip(heatmap, junctions_gt):
            juncs = juncs[juncs.sum(axis=1) > 0]  # drop zero-padded rows
            junc_score += m[juncs[:, 1], juncs[:, 0]].tolist()
        line_score = [v for s in score_pred for v in s.tolist()]

        prefix = self.exp_name + "/" + "train/"
        self.writer.add_image(prefix + "lines_gt", vis_line_gt, step)
        self.writer.add_image(prefix + "lines_pred", vis_line_pred, step)
        self.writer.add_scalar(prefix + "mean_junc_score",
                               np.mean(junc_score), step)
        self.writer.add_scalar(prefix + "mean_line_score",
                               np.mean(line_score), step)

    def _checkpoint(self):
        """Persist training state under ``ckpt/<exp_name>/`` and reload it.

        Two copies are written: ``train_states_latest.pth`` and an
        epoch-numbered snapshot.  NOTE(review): ``self.model.cpu()`` moves the
        model to CPU as a side effect; the training loop re-transfers it to
        the GPU at the start of each epoch.
        """
        print('Saving checkpoints...')

        ckpt_dir = os.path.join("ckpt", self.exp_name)
        latest_path = os.path.join(ckpt_dir, "train_states_latest.pth")
        epoch_path = os.path.join(
            ckpt_dir, f"train_states_{self.states['last_epoch']}.pth")

        states = self.states
        states["state_dict"] = self.model.cpu().state_dict()
        torch.save(states, latest_path)
        torch.save(states, epoch_path)

        # Round-trip through the just-written file, as the original did.
        reloaded = torch.load(latest_path)
        self.model.load_state_dict(reloaded["state_dict"])

    def train(
        self,
        end_epoch=20,
        solver="SGD",
        lr=1.,
        weight_decay=5e-4,
        momentum=0.9,
        lambda_heatmap=1.,
        lambda_adj=1.,
        vis_line_interval=20,
    ):
        self.vis_line_interval = vis_line_interval
        self.end_epoch = end_epoch
        self.lr = lr
        self.weight_decay = weight_decay
        self.momentum = momentum
        self.lambda_heatmap = lambda_heatmap
        self.lambda_adj = lambda_adj
        self.solver = solver

        start_epoch = self.states["last_epoch"] + 1

        for epoch in range(start_epoch, end_epoch):
            self.states["last_epoch"] = epoch
            self._train_epoch()
            self._checkpoint()

        return self

    def _vis_eval(self, epoch, i, len_loader, img, heatmap, adj_mtx,
                  junctions_pred, junctions_gt, adj_mtx_gt):
        """Log evaluation visualisations: GT lines, predicted lines drawn
        over junction-annotated images, junction pooling, and mean scores.
        """
        step = epoch * len_loader + i
        junctions_gt = np.int32(junctions_gt)

        gt_lines, gt_scores = graph2line(junctions_gt,
                                         adj_mtx_gt,
                                         threshold=self.vis_junc_th)
        vis_line_gt = vutils.make_grid(draw_lines(img, gt_lines, gt_scores))

        # Overlay predicted junctions on the image, then draw predicted
        # lines on top.  NOTE(review): the [:, ::-1, :, :] slice flips the
        # channel axis — presumably a BGR<->RGB swap; confirm against
        # draw_jucntions' output convention.
        img_with_junc = draw_jucntions(img, junctions_pred)
        img_with_junc = torch.stack(img_with_junc,
                                    dim=0).numpy()[:, ::-1, :, :]
        pred_lines, score_pred = graph2line(junctions_pred, adj_mtx)
        vis_line_pred = vutils.make_grid(
            draw_lines(img_with_junc, pred_lines, score_pred))

        # Heatmap values sampled at valid (non-padded) GT junctions.
        junc_score = []
        for m, juncs in zip(heatmap, junctions_gt):
            juncs = juncs[juncs.sum(axis=1) > 0]  # drop zero-padded rows
            junc_score += m[juncs[:, 1], juncs[:, 0]].tolist()
        line_score = [v for s in score_pred for v in s.tolist()]

        junc_pooling = vutils.make_grid(draw_jucntions(heatmap,
                                                       junctions_pred))

        prefix = self.exp_name + "/" + "eval/"
        self.writer.add_image(prefix + "junction_pooling", junc_pooling, step)
        self.writer.add_image(prefix + "lines_gt", vis_line_gt, step)
        self.writer.add_image(prefix + "lines_pred", vis_line_pred, step)
        self.writer.add_scalar(prefix + "mean_junc_score",
                               np.mean(junc_score), step)
        self.writer.add_scalar(prefix + "mean_line_score",
                               np.mean(line_score), step)

    def eval(self,
             lambda_heatmap=1.,
             lambda_adj=1.,
             off_line=False,
             epoch=None):
        """Evaluate the model over ``self.eval_loader`` and log to TensorBoard.

        In on-line mode (``off_line=False``) the call is a no-op unless the
        trainer has just finished epoch ``epoch - 1``; the loss weights set
        by :meth:`train` are reused.  In off-line mode the given
        ``lambda_heatmap`` / ``lambda_adj`` are installed first.

        Fixes: the predicted-heatmap grid was built from a numpy array via
        ``ndarray.view(N, 1, H, W)`` — numpy's ``view`` takes dtype/type, not
        a shape, so that call raised TypeError; the tensor prediction is
        reshaped instead (mirroring ``_train_epoch``).  Also guards
        ``epoch=None`` in on-line mode (previously ``None - 1`` raised
        TypeError).

        Returns ``self`` so calls can be chained.
        """
        if not off_line:
            # Only evaluate immediately after the matching training epoch.
            if epoch is None or self.states["last_epoch"] != epoch - 1:
                return self
        else:
            self.lambda_heatmap = lambda_heatmap
            self.lambda_adj = lambda_adj

        net_time = AverageMeter()
        data_time = AverageMeter()
        vis_time = AverageMeter()
        ave_loss = AverageMeter()
        ave_loss_heatmap = AverageMeter()
        ave_loss_adj_mtx = AverageMeter()

        epoch = self.states["last_epoch"]
        data_loader = self.eval_loader

        # main loop — evaluation runs without gradients.
        torch.set_grad_enabled(False)
        tic = time.time()
        print(f"start evaluating epoch: {epoch}", flush=True)

        # NOTE(review): the model is left in .train() mode here; .eval() is
        # probably intended (BatchNorm/Dropout behave differently).  Left
        # unchanged because switching would shift all logged metrics —
        # confirm before changing.
        if self.is_cuda:
            model = nn.DataParallel(self.model.cuda()).train()
        else:
            model = self.model.train()

        for i, batch in enumerate(data_loader):
            if self.is_cuda:
                img = batch["image"].cuda()
                heatmap_gt = batch["heatmap"].cuda()
                adj_mtx_gt = batch["adj_mtx"].cuda()
                junctions_gt = batch["junctions"].cuda()
            else:
                img = batch["image"]
                heatmap_gt = batch["heatmap"]
                adj_mtx_gt = batch["adj_mtx"]
                junctions_gt = batch["junctions"]

            # measure elapsed time
            data_time.update(time.time() - tic)
            tic = time.time()

            junc_pred, heatmap_pred, adj_mtx_pred, loss_hm, loss_adj = model(
                img, heatmap_gt, adj_mtx_gt, self.lambda_heatmap,
                self.lambda_adj, junctions_gt)

            # Reduce (possibly per-replica) loss tensors to scalars; averages
            # are logged un-weighted (divided back by their lambdas).
            loss_adj = loss_adj.mean()
            loss_hm = loss_hm.mean()
            loss = loss_adj + loss_hm
            ave_loss_adj_mtx.update(loss_adj.item() /
                                    self.lambda_adj if self.lambda_adj else 0)
            ave_loss_heatmap.update(
                loss_hm.item() /
                self.lambda_heatmap if self.lambda_heatmap else 0)
            ave_loss.update(loss.item())

            # measure elapsed time
            net_time.update(time.time() - tic)
            tic = time.time()

            # visualize eval
            img = img.cpu().numpy()
            heatmap = heatmap_pred.detach().cpu().numpy()
            junctions_pred = junc_pred.detach().cpu().numpy()
            adj_mtx = adj_mtx_pred.detach().cpu().numpy()
            junctions_gt = junctions_gt.cpu().numpy()
            adj_mtx_gt = adj_mtx_gt.cpu().numpy()
            self._vis_eval(epoch, i, len(data_loader), img, heatmap, adj_mtx,
                           junctions_pred, junctions_gt, adj_mtx_gt)

            vis_heatmap_gt = vutils.make_grid(
                heatmap_gt.view(heatmap_gt.size(0), 1, heatmap_gt.size(1),
                                heatmap_gt.size(2)))
            # Fix: reshape the *tensor* prediction — ``heatmap`` is numpy
            # here and ndarray.view() does not accept a shape.
            vis_heatmap_pred = vutils.make_grid(
                heatmap_pred.detach().cpu().view(heatmap_gt.size(0), 1,
                                                 heatmap_gt.size(1),
                                                 heatmap_gt.size(2)))

            self.writer.add_scalar(self.exp_name + "/" + "eval/loss_total",
                                   loss.item(),
                                   epoch * len(data_loader) + i)
            self.writer.add_scalar(
                self.exp_name + "/" + "eval/loss_heatmap",
                loss_hm.item() /
                self.lambda_heatmap if self.lambda_heatmap else 0,
                epoch * len(data_loader) + i)
            self.writer.add_scalar(
                self.exp_name + "/" + "eval/loss_adj_mtx",
                loss_adj.item() / self.lambda_adj if self.lambda_adj else 0,
                epoch * len(data_loader) + i)
            self.writer.add_image(self.exp_name + "/" + "eval/heatmap_gt",
                                  vis_heatmap_gt,
                                  epoch * len(data_loader) + i)
            self.writer.add_image(self.exp_name + "/" + "eval/heatmap_pred",
                                  vis_heatmap_pred,
                                  epoch * len(data_loader) + i)

            vis_time.update(time.time() - tic)
            info = f"epoch: [{epoch}][{i}/{len(data_loader)}], " \
                   f"time_total: {net_time.average() + data_time.average() + vis_time.average():.2f}, " \
                   f"time_data: {data_time.average():.2f}, time_net: {net_time.average():.2f}, " \
                   f"time_vis: {vis_time.average():.2f}, " \
                   f"loss: {loss.item():.4f}, " \
                   f"loss_heatmap: {loss_hm.item() / self.lambda_heatmap if self.lambda_heatmap else 0:.4f}, " \
                   f"loss_adj_mtx: {loss_adj.item() / self.lambda_adj if self.lambda_adj else 0:.4f}"
            if i == len(data_loader) - 1:
                # Append epoch-level averages on the last batch.
                info += f"\n*[{epoch}] " \
                        f"ave_loss: {ave_loss.average():.4f}, " \
                        f"ave_loss_heatmap: {ave_loss_heatmap.average():.4f}, " \
                        f"ave_loss_adj_mtx: {ave_loss_adj_mtx.average():.4f}"

            self.writer.add_text(self.exp_name + "/" + "eval/info", info,
                                 epoch * len(data_loader) + i)
            print(info, flush=True)
            # measure elapsed time
            tic = time.time()

        return self