예제 #1
0
    def __init__(self):
        """Build the CRAFT detector (and optionally RefineNet) in eval mode.

        Settings are read from the module-level ``pr`` object:

        * ``pr.trained_model`` / ``pr.refiner_model`` -- checkpoint paths.
        * ``pr.cuda`` -- load onto the GPU and wrap in ``DataParallel``;
          otherwise checkpoints are remapped onto the CPU.
        * ``pr.refine`` -- also build the RefineNet refiner. When enabled,
          ``pr.poly`` is set to ``True`` as a side effect.

        After construction ``self.model`` is the detector and
        ``self.refine_model`` is the refiner or ``None``.
        """
        self.model = CRAFT()
        if pr.cuda:
            self.model.load_state_dict(
                copyStateDict(torch.load(pr.trained_model)))
            self.model.cuda()
            self.model = torch.nn.DataParallel(self.model)
            # Disable cuDNN autotuning.
            cudnn.benchmark = False
        else:
            # Remap any GPU-saved tensors onto the CPU.
            self.model.load_state_dict(
                copyStateDict(torch.load(pr.trained_model,
                                         map_location='cpu')))
        self.model.eval()

        self.refine_model = None
        if pr.refine:
            self.refine_model = RefineNet()
            if pr.cuda:
                self.refine_model.load_state_dict(
                    copyStateDict(torch.load(pr.refiner_model)))
                # Move the refiner itself (the only refiner attribute on self).
                self.refine_model = self.refine_model.cuda()
                self.refine_model = torch.nn.DataParallel(self.refine_model)
            else:
                self.refine_model.load_state_dict(
                    copyStateDict(
                        torch.load(pr.refiner_model, map_location='cpu')))

            self.refine_model.eval()
            # Polygon output is switched on whenever the refiner is active.
            pr.poly = True
예제 #2
0
 def __init__(self, data_loader, opt):
     """Store options and data loader, and build a CRAFT model with Adam.

     Args:
         data_loader: Source of training batches; stored as ``self.dataloader``.
         opt: Options object passed to ``CRAFT``; must expose ``device``.
     """
     self.opt = opt
     # NOTE: the attribute is ``dataloader`` while the parameter is
     # ``data_loader``; assign from the parameter.
     self.dataloader = data_loader
     self.model = CRAFT(opt).to(opt.device)
     self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
예제 #3
0
    def train(self, buffer_dict):
        """Train CRAFT with DistributedDataParallel on the custom dataset.

        In ``"weak_supervision"`` mode, a second CRAFT copy (the supervision
        model) is placed on GPU ``total_gpu_num // 2 + self.gpu``. It is handed
        to the real dataset via ``update_model`` / ``update_device``. After
        every pass over the real loader, it is refreshed with the current
        training weights.

        Each real batch can optionally be concatenated with a SynthText batch.
        AMP is supported. Training resumes from ``self.net_param`` when
        ``config.train.ckpt_path`` is set. Rank 0 logs every 5 steps. Every
        ``eval_interval`` steps, rank 0 saves a checkpoint and all ranks run
        ``self.iou_eval`` on ``"custom_data"``. A final checkpoint is saved by
        rank 0.

        Args:
            buffer_dict: Mapping of dataset name to a shared, indexable buffer
                used for distributed evaluation. Rank 0 resets every slot to
                None before each evaluation.

        Raises:
            Exception: On an unsupported backbone or real dataset.
        """
        torch.cuda.set_device(self.gpu)
        total_gpu_num = torch.cuda.device_count()

        # MODEL -------------------------------------------------------------------------------------------------------#
        # SUPERVISION model
        if self.config.mode == "weak_supervision":
            if self.config.train.backbone == "vgg":
                supervision_model = CRAFT(pretrained=False,
                                          amp=self.config.train.amp)
            else:
                raise Exception("Undefined architecture")

            # NOTE: only work on half GPU assign train / half GPU assign supervision setting
            supervision_device = total_gpu_num // 2 + self.gpu
            if self.config.train.ckpt_path is not None:
                supervision_param = self.get_load_param(supervision_device)
                supervision_model.load_state_dict(
                    copyStateDict(supervision_param["craft"]))
            # Move the model even when no checkpoint was loaded, so it actually
            # lives on the device reported to the dataset via update_device.
            supervision_model = supervision_model.to(
                f"cuda:{supervision_device}")
            print(f"Supervision model loading on : gpu {supervision_device}")
        else:
            supervision_model, supervision_device = None, None

        # TRAIN model
        if self.config.train.backbone == "vgg":
            craft = CRAFT(pretrained=False, amp=self.config.train.amp)
        else:
            raise Exception("Undefined architecture")

        if self.config.train.ckpt_path is not None:
            craft.load_state_dict(copyStateDict(self.net_param["craft"]))

        craft = nn.SyncBatchNorm.convert_sync_batchnorm(craft)
        craft = craft.cuda()
        craft = torch.nn.parallel.DistributedDataParallel(
            craft, device_ids=[self.gpu])

        torch.backends.cudnn.benchmark = True

        # DATASET -----------------------------------------------------------------------------------------------------#

        if self.config.train.use_synthtext:
            trn_syn_loader = self.get_synth_loader()
            batch_syn = iter(trn_syn_loader)

        if self.config.train.real_dataset == "custom":
            trn_real_dataset = self.get_custom_dataset()
        else:
            raise Exception("Undefined dataset")

        if self.config.mode == "weak_supervision":
            trn_real_dataset.update_model(supervision_model)
            trn_real_dataset.update_device(supervision_device)

        trn_real_sampler = torch.utils.data.distributed.DistributedSampler(
            trn_real_dataset)
        trn_real_loader = torch.utils.data.DataLoader(
            trn_real_dataset,
            batch_size=self.config.train.batch_size,
            shuffle=False,  # ordering is controlled by the sampler
            num_workers=self.config.train.num_workers,
            sampler=trn_real_sampler,
            drop_last=False,
            pin_memory=True,
        )

        # OPTIMIZER ---------------------------------------------------------------------------------------------------#
        optimizer = optim.Adam(
            craft.parameters(),
            lr=self.config.train.lr,
            weight_decay=self.config.train.weight_decay,
        )

        if self.config.train.ckpt_path is not None and self.config.train.st_iter != 0:
            optimizer.load_state_dict(
                copyStateDict(self.net_param["optimizer"]))
            # Depending on the PyTorch version, Adam's "step" may be stored as
            # a tensor; int() keeps train_step (used in file names) a plain int.
            self.config.train.st_iter = int(
                self.net_param["optimizer"]["state"][0]["step"])
            self.config.train.lr = self.net_param["optimizer"]["param_groups"][
                0]["lr"]

        # LOSS --------------------------------------------------------------------------------------------------------#
        # mixed precision
        if self.config.train.amp:
            scaler = torch.cuda.amp.GradScaler()

            if (self.config.train.ckpt_path is not None
                    and self.config.train.st_iter != 0):
                scaler.load_state_dict(copyStateDict(self.net_param["scaler"]))
        else:
            scaler = None

        criterion = self.get_loss()

        # TRAIN -------------------------------------------------------------------------------------------------------#
        train_step = self.config.train.st_iter
        whole_training_step = self.config.train.end_iter
        update_lr_rate_step = 0
        training_lr = self.config.train.lr
        loss_value = 0
        batch_time = 0
        window_steps = 0  # steps accumulated since the last log line

        print(
            "================================ Train start ================================"
        )
        while train_step < whole_training_step:
            trn_real_sampler.set_epoch(train_step)
            for (images, region_scores, affinity_scores,
                 confidence_masks) in trn_real_loader:
                # Per-step timer (reset every iteration so avg_batch_time is a
                # true per-batch average, not a sum of cumulative elapsed time).
                start_time = time.time()
                craft.train()
                if train_step > 0 and train_step % self.config.train.lr_decay == 0:
                    update_lr_rate_step += 1
                    training_lr = self.adjust_learning_rate(
                        optimizer,
                        self.config.train.gamma,
                        update_lr_rate_step,
                        self.config.train.lr,
                    )

                images = images.cuda(non_blocking=True)
                region_scores = region_scores.cuda(non_blocking=True)
                affinity_scores = affinity_scores.cuda(non_blocking=True)
                confidence_masks = confidence_masks.cuda(non_blocking=True)

                if self.config.train.use_synthtext:
                    # Synth image load; restart the iterator when the synth
                    # loader is exhausted instead of failing with StopIteration.
                    try:
                        syn_batch = next(batch_syn)
                    except StopIteration:
                        batch_syn = iter(trn_syn_loader)
                        syn_batch = next(batch_syn)
                    (syn_image, syn_region_label, syn_affi_label,
                     syn_confidence_mask) = syn_batch
                    syn_image = syn_image.cuda(non_blocking=True)
                    syn_region_label = syn_region_label.cuda(non_blocking=True)
                    syn_affi_label = syn_affi_label.cuda(non_blocking=True)
                    syn_confidence_mask = syn_confidence_mask.cuda(
                        non_blocking=True)

                    # concat syn & custom image along the batch dimension
                    images = torch.cat((syn_image, images), 0)
                    region_image_label = torch.cat(
                        (syn_region_label, region_scores), 0)
                    affinity_image_label = torch.cat(
                        (syn_affi_label, affinity_scores), 0)
                    confidence_mask_label = torch.cat(
                        (syn_confidence_mask, confidence_masks), 0)
                else:
                    region_image_label = region_scores
                    affinity_image_label = affinity_scores
                    confidence_mask_label = confidence_masks

                if self.config.train.amp:
                    with torch.cuda.amp.autocast():

                        output, _ = craft(images)
                        out1 = output[:, :, :, 0]
                        out2 = output[:, :, :, 1]

                        loss = criterion(
                            region_image_label,
                            affinity_image_label,
                            out1,
                            out2,
                            confidence_mask_label,
                            self.config.train.neg_rto,
                            self.config.train.n_min_neg,
                        )

                    optimizer.zero_grad()
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()

                else:
                    output, _ = craft(images)
                    out1 = output[:, :, :, 0]
                    out2 = output[:, :, :, 1]
                    # Same criterion arguments as the AMP branch.
                    loss = criterion(
                        region_image_label,
                        affinity_image_label,
                        out1,
                        out2,
                        confidence_mask_label,
                        self.config.train.neg_rto,
                        self.config.train.n_min_neg,
                    )

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                end_time = time.time()
                loss_value += loss.item()
                batch_time += end_time - start_time
                window_steps += 1

                if train_step > 0 and train_step % 5 == 0 and self.gpu == 0:
                    # Divide by the actual number of accumulated steps: the
                    # first window (and one after a resume) is not 5 steps.
                    mean_loss = loss_value / window_steps
                    avg_batch_time = batch_time / window_steps
                    loss_value = 0
                    batch_time = 0
                    window_steps = 0

                    print(
                        "{}, training_step: {}|{}, learning rate: {:.8f}, "
                        "training_loss: {:.5f}, avg_batch_time: {:.5f}".format(
                            time.strftime("%Y-%m-%d:%H:%M:%S",
                                          time.localtime(time.time())),
                            train_step,
                            whole_training_step,
                            training_lr,
                            mean_loss,
                            avg_batch_time,
                        ))

                    if self.config.wandb_opt:
                        wandb.log({
                            "train_step": train_step,
                            "mean_loss": mean_loss
                        })

                if (train_step % self.config.train.eval_interval == 0
                        and train_step != 0):

                    craft.eval()
                    # Rank 0 resets the shared buffers and writes the checkpoint.
                    if self.gpu == 0:
                        for buffer in buffer_dict.values():
                            for i in range(len(buffer)):
                                buffer[i] = None

                        print("Saving state, index:", train_step)
                        save_param_dic = {
                            "iter": train_step,
                            "craft": craft.state_dict(),
                            "optimizer": optimizer.state_dict(),
                        }
                        save_param_path = (self.config.results_dir +
                                           "/CRAFT_clr_" + repr(train_step) +
                                           ".pth")

                        if self.config.train.amp:
                            save_param_dic["scaler"] = scaler.state_dict()
                            save_param_path = (self.config.results_dir +
                                               "/CRAFT_clr_amp_" +
                                               repr(train_step) + ".pth")

                        torch.save(save_param_dic, save_param_path)

                    # validation (all ranks)
                    self.iou_eval(
                        "custom_data",
                        train_step,
                        buffer_dict["custom_data"],
                        craft,
                    )

                train_step += 1
                if train_step >= whole_training_step:
                    break

            if self.config.mode == "weak_supervision":
                # Refresh the supervision model with the latest weights
                # (``.module`` unwraps DistributedDataParallel).
                state_dict = craft.module.state_dict()
                supervision_model.load_state_dict(state_dict)
                trn_real_dataset.update_model(supervision_model)

        # save last model
        if self.gpu == 0:
            save_param_dic = {
                "iter": train_step,
                "craft": craft.state_dict(),
                "optimizer": optimizer.state_dict(),
            }
            save_param_path = (self.config.results_dir + "/CRAFT_clr_" +
                               repr(train_step) + ".pth")

            if self.config.train.amp:
                save_param_dic["scaler"] = scaler.state_dict()
                save_param_path = (self.config.results_dir +
                                   "/CRAFT_clr_amp_" + repr(train_step) +
                                   ".pth")
            torch.save(save_param_dic, save_param_path)
예제 #4
0
    t1 = time.time() - t1

    # render results (optional)
    render_img = score_text.copy()
    render_img = np.hstack((render_img, score_link))
    ret_score_text = imgproc.cvt2HeatmapImg(render_img)

    if args.show_time:
        print("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1))

    return boxes, polys, ret_score_text


if __name__ == '__main__':
    # Build the detector and restore its weights from the checkpoint named on
    # the command line. Without CUDA, stored tensors are remapped to the CPU.
    net = CRAFT()

    print('Loading weights from checkpoint (' + args.trained_model + ')')
    checkpoint = torch.load(args.trained_model,
                            map_location=None if args.cuda else 'cpu')
    net.load_state_dict(copyStateDict(checkpoint))

    if args.cuda:
        # Move to the GPU, replicate with DataParallel and keep cuDNN
        # autotuning disabled.
        net = torch.nn.DataParallel(net.cuda())
        cudnn.benchmark = False

    net.eval()
예제 #5
0
def main_eval(model_path, backbone, config, evaluator, result_dir, buffer,
              model, mode):
    """Run CRAFT on the custom test set and return IoU detection metrics.

    Two modes:

    * ``model is None`` -- standalone evaluation. A CRAFT model is built from
      ``model_path`` and run over every test image.
    * ``model`` given -- distributed evaluation during training. This process
      handles only its slice of the images, chosen by the current CUDA device
      index. It writes its predictions into the shared ``buffer`` and then
      waits until every slot of ``buffer`` is filled.

    Args:
        model_path: Checkpoint path, read only when ``model`` is None. The
            weights are read from its ``"craft"`` entry.
        backbone: Backbone name; only ``"vgg"`` is supported.
        config: Test configuration (``text_threshold``, ``link_threshold``,
            ``low_text``, ``cuda``, ``poly``, ``canvas_size``, ``mag_ratio``,
            ``vis_opt``).
        evaluator: Provides ``evaluate_image(gt, pred)`` and
            ``combine_results(results)``.
        result_dir: Directory for visualizations; created if missing.
        buffer: Shared, indexable sequence with one slot per test image, or
            None. When ``model`` is given, all slots must be None on entry.
        model: Already-built model to evaluate, or None.
        mode: Training mode. In ``"weak_supervision"`` with more than one GPU,
            only half of the GPUs are counted for splitting the test set.

    Returns:
        The value of ``evaluator.combine_results`` over all images.

    Raises:
        Exception: If ``backbone`` is not ``"vgg"``.
        FileNotFoundError: If a test image cannot be read.
    """
    import time

    os.makedirs(result_dir, exist_ok=True)

    total_imgs_bboxes_gt, total_imgs_path = load_test_dataset_iou(
        "custom_data", config)

    if mode == "weak_supervision" and torch.cuda.device_count() != 1:
        # The other half of the GPUs hosts supervision models.
        gpu_count = torch.cuda.device_count() // 2
    else:
        gpu_count = torch.cuda.device_count()
    gpu_idx = torch.cuda.current_device()
    torch.cuda.set_device(gpu_idx)

    # Only evaluation time
    if model is None:
        # Standalone: this process covers the whole test set from index 0.
        start_idx = 0
        piece_imgs_path = total_imgs_path

        if backbone == "vgg":
            model = CRAFT()
        else:
            raise Exception("Undefined architecture")

        print("Loading weights from checkpoint (" + model_path + ")")
        net_param = torch.load(model_path, map_location=f"cuda:{gpu_idx}")
        model.load_state_dict(copyStateDict(net_param["craft"]))

        if config.cuda:
            model = model.cuda()
            cudnn.benchmark = False

    # Distributed evaluation in the middle of training time
    else:
        if buffer is not None:
            # check all buffer value is None for distributed evaluation
            assert all(
                v is None
                for v in buffer), "Buffer already filled with another value."
        slice_idx = len(total_imgs_bboxes_gt) // gpu_count
        start_idx = gpu_idx * slice_idx

        # The last GPU also takes the remainder of the images.
        if gpu_idx == gpu_count - 1:
            piece_imgs_path = total_imgs_path[start_idx:]
        else:
            piece_imgs_path = total_imgs_path[start_idx:start_idx + slice_idx]

    model.eval()

    # -----------------------------------------------------------------------------------------------------------------#
    total_imgs_bboxes_pre = []
    for k, img_path in enumerate(tqdm(piece_imgs_path)):
        # Position of this image in the full test set (GT list and buffer).
        global_idx = start_idx + k
        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError(f"Could not read test image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        bboxes, polys, score_text = test_net(
            model,
            image,
            config.text_threshold,
            config.link_threshold,
            config.low_text,
            config.cuda,
            config.poly,
            config.canvas_size,
            config.mag_ratio,
        )

        single_img_bbox = [{
            "points": box,
            "text": "###",
            "ignore": False
        } for box in bboxes]
        total_imgs_bboxes_pre.append(single_img_bbox)
        # Distributed evaluation: publish this image's prediction.
        if buffer is not None:
            buffer[global_idx] = single_img_bbox

        if config.vis_opt:
            viz_test(
                image,
                score_text,
                pre_box=polys,
                gt_box=total_imgs_bboxes_gt[global_idx],
                img_name=img_path,
                result_dir=result_dir,
                test_folder_name="custom_data",
            )

    # When distributed evaluation mode, wait until buffer is full filled.
    # Sleep between polls instead of spinning on the shared buffer.
    if buffer is not None:
        while None in buffer:
            time.sleep(0.05)
        total_imgs_bboxes_pre = buffer

    results = [
        evaluator.evaluate_image(gt, pred)
        for gt, pred in zip(total_imgs_bboxes_gt, total_imgs_bboxes_pre)
    ]

    metrics = evaluator.combine_results(results)
    print(metrics)
    return metrics
예제 #6
0
    def train(self, buffer_dict):
        """Train CRAFT with DistributedDataParallel on ``self.trn_loader``.

        Supports AMP. Training resumes from ``self.net_param`` when
        ``config.train.ckpt_path`` is set. Rank 0 logs every 5 steps. Every
        ``eval_interval`` steps, rank 0 saves a checkpoint and all ranks run
        ``self.iou_eval`` on ``"icdar2013"``. A final checkpoint is saved by
        rank 0.

        Args:
            buffer_dict: Mapping of dataset name to a shared, indexable buffer
                used for distributed evaluation. Rank 0 resets every slot to
                None before each evaluation.

        Raises:
            Exception: On an unsupported backbone.
        """
        torch.cuda.set_device(self.gpu)

        # DATASET -----------------------------------------------------------------------------------------------------#
        trn_loader = self.trn_loader

        # MODEL -------------------------------------------------------------------------------------------------------#
        if self.config.train.backbone == "vgg":
            craft = CRAFT(pretrained=True, amp=self.config.train.amp)
        else:
            raise Exception("Undefined architecture")

        if self.config.train.ckpt_path is not None:
            craft.load_state_dict(copyStateDict(self.net_param["craft"]))
        craft = nn.SyncBatchNorm.convert_sync_batchnorm(craft)
        craft = craft.cuda()
        craft = torch.nn.parallel.DistributedDataParallel(craft, device_ids=[self.gpu])

        torch.backends.cudnn.benchmark = True

        # OPTIMIZER----------------------------------------------------------------------------------------------------#

        optimizer = optim.Adam(
            craft.parameters(),
            lr=self.config.train.lr,
            weight_decay=self.config.train.weight_decay,
        )

        if self.config.train.ckpt_path is not None and self.config.train.st_iter != 0:
            optimizer.load_state_dict(copyStateDict(self.net_param["optimizer"]))
            # Depending on the PyTorch version, Adam's "step" may be stored as
            # a tensor; int() keeps train_step (used in file names) a plain int.
            self.config.train.st_iter = int(
                self.net_param["optimizer"]["state"][0]["step"]
            )
            self.config.train.lr = self.net_param["optimizer"]["param_groups"][0]["lr"]

        # LOSS --------------------------------------------------------------------------------------------------------#
        # mixed precision
        if self.config.train.amp:
            scaler = torch.cuda.amp.GradScaler()

            # load model
            if (
                self.config.train.ckpt_path is not None
                and self.config.train.st_iter != 0
            ):
                scaler.load_state_dict(copyStateDict(self.net_param["scaler"]))
        else:
            scaler = None

        criterion = self.get_loss()

        # TRAIN -------------------------------------------------------------------------------------------------------#
        train_step = self.config.train.st_iter
        whole_training_step = self.config.train.end_iter
        update_lr_rate_step = 0
        training_lr = self.config.train.lr
        loss_value = 0
        batch_time = 0
        window_steps = 0  # steps accumulated since the last log line

        while train_step < whole_training_step:
            self.trn_sampler.set_epoch(train_step)
            for image, region_image, affinity_image, confidence_mask in trn_loader:
                # Per-step timer (reset every iteration so avg_batch_time is a
                # true per-batch average, not a sum of cumulative elapsed time).
                start_time = time.time()
                craft.train()
                if train_step > 0 and train_step % self.config.train.lr_decay == 0:
                    update_lr_rate_step += 1
                    training_lr = self.adjust_learning_rate(
                        optimizer,
                        self.config.train.gamma,
                        update_lr_rate_step,
                        self.config.train.lr,
                    )

                images = image.cuda(non_blocking=True)
                region_image_label = region_image.cuda(non_blocking=True)
                affinity_image_label = affinity_image.cuda(non_blocking=True)
                confidence_mask_label = confidence_mask.cuda(non_blocking=True)

                if self.config.train.amp:
                    with torch.cuda.amp.autocast():

                        output, _ = craft(images)
                        out1 = output[:, :, :, 0]
                        out2 = output[:, :, :, 1]

                        loss = criterion(
                            region_image_label,
                            affinity_image_label,
                            out1,
                            out2,
                            confidence_mask_label,
                            self.config.train.neg_rto,
                            self.config.train.n_min_neg,
                        )

                    optimizer.zero_grad()
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()

                else:
                    output, _ = craft(images)
                    out1 = output[:, :, :, 0]
                    out2 = output[:, :, :, 1]
                    # Same criterion arguments as the AMP branch.
                    loss = criterion(
                        region_image_label,
                        affinity_image_label,
                        out1,
                        out2,
                        confidence_mask_label,
                        self.config.train.neg_rto,
                        self.config.train.n_min_neg,
                    )

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                end_time = time.time()
                loss_value += loss.item()
                batch_time += end_time - start_time
                window_steps += 1

                if train_step > 0 and train_step % 5 == 0 and self.gpu == 0:
                    # Divide by the actual number of accumulated steps: the
                    # first window (and one after a resume) is not 5 steps.
                    mean_loss = loss_value / window_steps
                    avg_batch_time = batch_time / window_steps
                    loss_value = 0
                    batch_time = 0
                    window_steps = 0

                    print(
                        "{}, training_step: {}|{}, learning rate: {:.8f}, "
                        "training_loss: {:.5f}, avg_batch_time: {:.5f}".format(
                            time.strftime(
                                "%Y-%m-%d:%H:%M:%S", time.localtime(time.time())
                            ),
                            train_step,
                            whole_training_step,
                            training_lr,
                            mean_loss,
                            avg_batch_time,
                        )
                    )
                    if self.config.wandb_opt:
                        wandb.log({"train_step": train_step, "mean_loss": mean_loss})

                if (
                    train_step % self.config.train.eval_interval == 0
                    and train_step != 0
                ):
                    craft.eval()

                    # The checkpoint path is needed by iou_eval on every rank.
                    save_param_path = (
                        self.config.results_dir
                        + "/CRAFT_clr_"
                        + repr(train_step)
                        + ".pth"
                    )
                    if self.config.train.amp:
                        save_param_path = (
                            self.config.results_dir
                            + "/CRAFT_clr_amp_"
                            + repr(train_step)
                            + ".pth"
                        )

                    # Only rank 0 resets the shared buffers and writes the
                    # checkpoint, so ranks never write the same file at once.
                    if self.gpu == 0:
                        for buffer in buffer_dict.values():
                            for i in range(len(buffer)):
                                buffer[i] = None

                        print("Saving state, index:", train_step)
                        save_param_dic = {
                            "iter": train_step,
                            "craft": craft.state_dict(),
                            "optimizer": optimizer.state_dict(),
                        }
                        if self.config.train.amp:
                            save_param_dic["scaler"] = scaler.state_dict()

                        torch.save(save_param_dic, save_param_path)

                    # validation (all ranks)
                    self.iou_eval(
                        "icdar2013",
                        train_step,
                        save_param_path,
                        buffer_dict["icdar2013"],
                        craft,
                    )

                train_step += 1
                if train_step >= whole_training_step:
                    break

        # save last model
        if self.gpu == 0:
            save_param_dic = {
                "iter": train_step,
                "craft": craft.state_dict(),
                "optimizer": optimizer.state_dict(),
            }
            save_param_path = (
                self.config.results_dir + "/CRAFT_clr_" + repr(train_step) + ".pth"
            )

            if self.config.train.amp:
                save_param_dic["scaler"] = scaler.state_dict()
                save_param_path = (
                    self.config.results_dir
                    + "/CRAFT_clr_amp_"
                    + repr(train_step)
                    + ".pth"
                )
            torch.save(save_param_dic, save_param_path)