Ejemplo n.º 1
0
    def __init__(self):
        """Load the CRAFT detector and, when enabled, the link refiner.

        Everything is configured through the ``pr`` settings object:
        ``pr.cuda`` selects GPU or CPU loading, ``pr.trained_model`` and
        ``pr.refiner_model`` are the checkpoint paths, and ``pr.refine``
        switches the refiner on. Enabling the refiner also forces
        ``pr.poly = True``.
        """
        self.model = CRAFT()
        if pr.cuda:
            self.model.load_state_dict(
                copyStateDict(torch.load(pr.trained_model)))
            self.model.cuda()
            self.model = torch.nn.DataParallel(self.model)
            cudnn.benchmark = False
        else:
            # Remap tensors saved from a GPU onto the CPU.
            self.model.load_state_dict(
                copyStateDict(torch.load(pr.trained_model,
                                         map_location='cpu')))
        self.model.eval()

        # Stays None when refinement is disabled; callers test for that.
        self.refine_model = None
        if pr.refine:
            self.refine_model = RefineNet()
            if pr.cuda:
                self.refine_model.load_state_dict(
                    copyStateDict(torch.load(pr.refiner_model)))
                # Fixed: the original read the non-existent attribute
                # ``self.refine_net`` here and raised AttributeError.
                self.refine_model = self.refine_model.cuda()
                self.refine_model = torch.nn.DataParallel(self.refine_model)
            else:
                self.refine_model.load_state_dict(
                    copyStateDict(
                        torch.load(pr.refiner_model, map_location='cpu')))

            self.refine_model.eval()
            pr.poly = True
Ejemplo n.º 2
0
 def __init__(self, data_loader, opt):
     """Store the options and data loader, then build the model and optimizer.

     Args:
         data_loader: iterable of training batches, kept as
             ``self.dataloader``.
         opt: options object; ``opt.device`` is the device the model is
             moved to, and ``opt`` itself is passed to ``CRAFT``.
     """
     self.opt = opt
     # Fixed: the original assigned the undefined name ``dataloader``
     # instead of the ``data_loader`` parameter.
     self.dataloader = data_loader
     self.model = CRAFT(opt).to(opt.device)
     self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
Ejemplo n.º 3
0
class Trainer(object):
    """Training loop for a CRAFT model that outputs a mask and class logits."""

    def __init__(self, data_loader, opt):
        """Store the options and data loader, then build model and optimizer.

        Args:
            data_loader: iterable yielding
                ``(image, cur_masks, target_mask, target_class)`` batches.
            opt: options object; ``device``, ``n_epochs``,
                ``sample_interval`` and ``sample_path`` are read from it.
        """
        self.opt = opt
        # Fixed: the original assigned the undefined name ``dataloader``.
        self.dataloader = data_loader
        self.model = CRAFT(opt).to(opt.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)

    def train(self):
        """Train for ``opt.n_epochs`` epochs, periodically dumping samples.

        The loss is the sum of a binary cross-entropy on the sigmoid of the
        predicted mask and a cross-entropy on the class logits. Every
        ``opt.sample_interval`` batches the losses are printed and the first
        nine predicted and target masks are written to ``opt.sample_path``.
        """
        # Fixed: the original referred to bare ``opt`` / ``dataloader``,
        # which are not defined inside this method.
        opt = self.opt
        loader = self.dataloader

        self.model.train()
        criterion = nn.CrossEntropyLoss()

        batches_done = 0
        for epoch in range(opt.n_epochs):
            for i, (image, cur_masks, target_mask,
                    target_class) in enumerate(loader):
                # The network input is the image stacked with the current
                # masks along dim 1.
                x = torch.cat([image, cur_masks], dim=1)
                x = x.to(opt.device)
                target_mask = target_mask.to(opt.device)
                target_class = target_class.to(opt.device)

                y_mask, y_logits = self.model(x)
                loss_mask = F.binary_cross_entropy(torch.sigmoid(y_mask),
                                                   target_mask)
                loss_cat = criterion(y_logits, target_class)
                loss = loss_mask + loss_cat

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                if batches_done % opt.sample_interval == 0:
                    print(
                        "[Epoch %d/%d] [Batch %d/%d] [loss_cat: %f] [loss_mask: %f]"
                        % (epoch, opt.n_epochs, i, len(loader),
                           loss_cat.item(), loss_mask.item()))
                    save_image(
                        y_mask.data[:9],
                        os.path.join(opt.sample_path,
                                     "{:06d}_mask.png".format(batches_done)),
                        nrow=3,
                    )
                    save_image(
                        target_mask.data[:9],
                        os.path.join(
                            opt.sample_path,
                            "{:06d}_mask_real.png".format(batches_done)),
                        nrow=3,
                    )
                batches_done += 1
Ejemplo n.º 4
0
class CraftDetection:
    """CRAFT text detector with an optional link-refinement network.

    All settings (checkpoint paths, thresholds, CUDA usage, output modes)
    are read from the ``pr`` settings object.
    """

    def __init__(self):
        """Load the CRAFT detector and, when enabled, the link refiner.

        ``pr.cuda`` selects GPU or CPU loading, ``pr.trained_model`` and
        ``pr.refiner_model`` are the checkpoint paths, and ``pr.refine``
        switches the refiner on. Enabling the refiner also forces
        ``pr.poly = True``.
        """
        self.model = CRAFT()
        if pr.cuda:
            self.model.load_state_dict(
                copyStateDict(torch.load(pr.trained_model)))
            self.model.cuda()
            self.model = torch.nn.DataParallel(self.model)
            cudnn.benchmark = False
        else:
            # Remap tensors saved from a GPU onto the CPU.
            self.model.load_state_dict(
                copyStateDict(torch.load(pr.trained_model,
                                         map_location='cpu')))
        self.model.eval()

        # Stays None when refinement is disabled; text_detect tests for it.
        self.refine_model = None
        if pr.refine:
            self.refine_model = RefineNet()
            if pr.cuda:
                self.refine_model.load_state_dict(
                    copyStateDict(torch.load(pr.refiner_model)))
                # Fixed: the original read the non-existent attribute
                # ``self.refine_net`` here and raised AttributeError.
                self.refine_model = self.refine_model.cuda()
                self.refine_model = torch.nn.DataParallel(self.refine_model)
            else:
                self.refine_model.load_state_dict(
                    copyStateDict(
                        torch.load(pr.refiner_model, map_location='cpu')))

            self.refine_model.eval()
            pr.poly = True

    def text_detect(self, image):
        """Detect text regions in an already-loaded image.

        Args:
            image: image array indexed ``[row, column, channel]``
                (presumably RGB order -- confirm against callers).

        Returns:
            ``(boxes, polys, heatmap)`` when ``pr.folder_test`` is set;
            otherwise ``(result_boxes, total_time)`` where ``result_boxes``
            is ``polys`` converted to plain lists and ``total_time`` is the
            inference plus post-processing time in seconds, rounded to two
            decimals.
        """
        time0 = time.time()

        # resize
        img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(
            image,
            pr.canvas_size,
            interpolation=cv2.INTER_LINEAR,
            mag_ratio=pr.mag_ratio)
        print(img_resized.shape)
        ratio_h = ratio_w = 1 / target_ratio

        # preprocessing
        x = imgproc.normalizeMeanVariance(img_resized)
        x = torch.from_numpy(x).permute(2, 0, 1)  # [h, w, c] to [c, h, w]
        x = Variable(x.unsqueeze(0))  # [c, h, w] to [b, c, h, w]
        if pr.cuda:
            x = x.cuda()

        # forward pass
        with torch.no_grad():
            y, feature = self.model(x)

        # make score and link map
        score_text = y[0, :, :, 0].cpu().data.numpy()
        score_link = y[0, :, :, 1].cpu().data.numpy()

        # refine link
        if self.refine_model is not None:
            with torch.no_grad():
                y_refiner = self.refine_model(y, feature)
            score_link = y_refiner[0, :, :, 0].cpu().data.numpy()

        time0 = time.time() - time0
        time1 = time.time()

        # Post-processing
        boxes, polys = craft_utils.getDetBoxes(score_text, score_link,
                                               pr.text_threshold,
                                               pr.link_threshold, pr.low_text,
                                               pr.poly)

        # coordinate adjustment
        boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
        polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
        # Fall back to the plain box wherever no polygon was produced.
        for k, poly in enumerate(polys):
            if poly is None:
                polys[k] = boxes[k]

        # In horizontal mode every quadrilateral is replaced by a vertically
        # expanded [x_min, y_min, x_max, y_max] box, but only if enough of
        # the detections are horizontal.
        if pr.horizontal_mode and self.check_horizontal(polys):
            height = image.shape[0]
            new_polys = []
            for box in polys:
                [[l1, t1], [r1, t2], [r2, b1], [l2, b2]] = box
                if t1 < t2:
                    l, r, t, b = l2, r1, t1, b1
                elif t1 > t2:
                    l, r, t, b = l1, r2, t2, b2
                else:
                    l, r, t, b = l1, r1, t1, b1
                h_box = abs(b - t)
                # Grow the box vertically, clamped to the image.
                t = max(0, t - h_box * pr.expand_ratio)
                b = min(b + h_box * pr.expand_ratio, height)
                new_polys.append([l, t, r, b])

            polys = np.array(new_polys, dtype=np.float32)

        time1 = time.time() - time1
        total_time = round(time0 + time1, 2)

        # render results (optional)
        render_img = score_text.copy()
        render_img = np.hstack((render_img, score_link))
        ret_score_text = imgproc.cvt2HeatmapImg(render_img)

        if pr.show_time:
            print("\ninfer/postproc time : {:.3f}/{:.3f}".format(time0, time1))
        if pr.folder_test:
            return boxes, polys, ret_score_text

        if pr.visualize:
            img_draw = displayResult(img=image[:, :, ::-1], boxes=polys)
            plt.imshow(cv2.cvtColor(img_draw, cv2.COLOR_RGB2BGR))
            plt.show()

        result_boxes = [box.tolist() for box in polys]
        return result_boxes, total_time

    def test_folder(self, folder_path):
        """Run detection on every image in a folder and save the results.

        For each image the score heatmap is written to
        ``<pr.result_folder>/res_<name>_mask.jpg`` and the detections are
        saved through ``file_utils.saveResult``.

        NOTE: this unpacks three values from ``text_detect``, which only
        returns three values when ``pr.folder_test`` is set.

        Args:
            folder_path: directory handed to ``file_utils.get_files``.
        """
        image_list, _, _ = file_utils.get_files(folder_path)
        if not os.path.exists(pr.result_folder):
            os.mkdir(pr.result_folder)
        t = time.time()

        # load data
        for k, image_path in enumerate(image_list):
            print("Test image {:d}/{:d}: {:s}".format(k + 1, len(image_list),
                                                      image_path),
                  end='\r')

            # Fixed: text_detect needs the image array, not its path; the
            # original passed ``image_path`` and only loaded the image
            # afterwards.
            image = imgproc.loadImage(image_path)
            bboxes, polys, score_text = self.text_detect(image)

            # save score text
            filename, file_ext = os.path.splitext(os.path.basename(image_path))
            mask_file = pr.result_folder + "/res_" + filename + '_mask.jpg'
            cv2.imwrite(mask_file, score_text)
            file_utils.saveResult(image_path,
                                  image[:, :, ::-1],
                                  polys,
                                  dirname=pr.result_folder)

        print("elapsed time : {}s".format(time.time() - t))

    def check_horizontal(self, boxes):
        """Return True when enough of the boxes are exactly horizontal.

        A box counts as horizontal when its first two points share the same
        second coordinate. The share of such boxes is compared against
        ``pr.ratio_box_horizontal``.

        Args:
            boxes: sequence of four-point boxes, each point a pair.

        Returns:
            True if the horizontal share reaches the threshold; False
            otherwise, including when ``boxes`` is empty.
        """
        total_box = len(boxes)
        if total_box == 0:
            # Fixed: the original divided by zero when nothing was detected.
            return False

        num_box_horizontal = 0
        for box in boxes:
            [[l1, t1], [r1, t2], [r2, b1], [l2, b2]] = box
            if t1 == t2:
                num_box_horizontal += 1

        ratio_box_horizontal = num_box_horizontal / float(total_box)
        print("Ratio box horizontal: ", ratio_box_horizontal)
        return bool(ratio_box_horizontal >= pr.ratio_box_horizontal)
Ejemplo n.º 5
0
    def train(self, buffer_dict):
        """Run distributed CRAFT training on this process's GPU.

        Builds the model (plus a supervision copy in ``weak_supervision``
        mode), the real-data loader, optimizer, optional AMP grad scaler and
        loss, then trains from ``config.train.st_iter`` until
        ``config.train.end_iter``. Every ``config.train.eval_interval`` steps
        rank 0 clears the shared buffers and saves a checkpoint, and every
        rank calls ``self.iou_eval``. Rank 0 saves a final checkpoint at the
        end.

        Args:
            buffer_dict: mapping of dataset name to an index-assignable
                buffer; ``buffer_dict["custom_data"]`` is passed to
                ``self.iou_eval``.

        Raises:
            Exception: for an unsupported backbone or real dataset setting.
        """
        torch.cuda.set_device(self.gpu)
        total_gpu_num = torch.cuda.device_count()

        # MODEL -------------------------------------------------------------------------------------------------------#
        # SUPERVISION model
        if self.config.mode == "weak_supervision":
            if self.config.train.backbone == "vgg":
                supervision_model = CRAFT(pretrained=False,
                                          amp=self.config.train.amp)
            else:
                raise Exception("Undefined architecture")

            # NOTE: only works when half of the GPUs train and the other
            # half host the supervision models.
            supervision_device = total_gpu_num // 2 + self.gpu
            if self.config.train.ckpt_path is not None:
                supervision_param = self.get_load_param(supervision_device)
                supervision_model.load_state_dict(
                    copyStateDict(supervision_param["craft"]))
                supervision_model = supervision_model.to(
                    f"cuda:{supervision_device}")
            print(f"Supervision model loading on : gpu {supervision_device}")
        else:
            supervision_model, supervision_device = None, None

        # TRAIN model
        if self.config.train.backbone == "vgg":
            craft = CRAFT(pretrained=False, amp=self.config.train.amp)
        else:
            raise Exception("Undefined architecture")

        if self.config.train.ckpt_path is not None:
            craft.load_state_dict(copyStateDict(self.net_param["craft"]))

        craft = nn.SyncBatchNorm.convert_sync_batchnorm(craft)
        craft = craft.cuda()
        craft = torch.nn.parallel.DistributedDataParallel(
            craft, device_ids=[self.gpu])

        torch.backends.cudnn.benchmark = True

        # DATASET -----------------------------------------------------------------------------------------------------#

        if self.config.train.use_synthtext:
            trn_syn_loader = self.get_synth_loader()
            batch_syn = iter(trn_syn_loader)

        if self.config.train.real_dataset == "custom":
            trn_real_dataset = self.get_custom_dataset()
        else:
            raise Exception("Undefined dataset")

        if self.config.mode == "weak_supervision":
            trn_real_dataset.update_model(supervision_model)
            trn_real_dataset.update_device(supervision_device)

        trn_real_sampler = torch.utils.data.distributed.DistributedSampler(
            trn_real_dataset)
        trn_real_loader = torch.utils.data.DataLoader(
            trn_real_dataset,
            batch_size=self.config.train.batch_size,
            shuffle=False,
            num_workers=self.config.train.num_workers,
            sampler=trn_real_sampler,
            drop_last=False,
            pin_memory=True,
        )

        # OPTIMIZER ---------------------------------------------------------------------------------------------------#
        optimizer = optim.Adam(
            craft.parameters(),
            lr=self.config.train.lr,
            weight_decay=self.config.train.weight_decay,
        )

        # When resuming, take the step count and learning rate from the
        # saved optimizer state instead of the config.
        if self.config.train.ckpt_path is not None and self.config.train.st_iter != 0:
            optimizer.load_state_dict(
                copyStateDict(self.net_param["optimizer"]))
            self.config.train.st_iter = self.net_param["optimizer"]["state"][
                0]["step"]
            self.config.train.lr = self.net_param["optimizer"]["param_groups"][
                0]["lr"]

        # LOSS --------------------------------------------------------------------------------------------------------#
        # mixed precision
        if self.config.train.amp:
            scaler = torch.cuda.amp.GradScaler()

            if (self.config.train.ckpt_path is not None
                    and self.config.train.st_iter != 0):
                scaler.load_state_dict(copyStateDict(self.net_param["scaler"]))
        else:
            scaler = None

        criterion = self.get_loss()

        def save_checkpoint(step):
            """Save model, optimizer and (with AMP) scaler state for ``step``."""
            state = {
                "iter": step,
                "craft": craft.state_dict(),
                "optimizer": optimizer.state_dict(),
            }
            prefix = "/CRAFT_clr_"
            if self.config.train.amp:
                state["scaler"] = scaler.state_dict()
                prefix = "/CRAFT_clr_amp_"
            torch.save(state,
                       self.config.results_dir + prefix + repr(step) + ".pth")

        # TRAIN -------------------------------------------------------------------------------------------------------#
        train_step = self.config.train.st_iter
        whole_training_step = self.config.train.end_iter
        update_lr_rate_step = 0
        training_lr = self.config.train.lr
        loss_value = 0
        batch_time = 0
        start_time = time.time()

        print(
            "================================ Train start ================================"
        )
        while train_step < whole_training_step:
            trn_real_sampler.set_epoch(train_step)
            for (
                    index,
                (
                    images,
                    region_scores,
                    affinity_scores,
                    confidence_masks,
                ),
            ) in enumerate(trn_real_loader):
                craft.train()
                if train_step > 0 and train_step % self.config.train.lr_decay == 0:
                    update_lr_rate_step += 1
                    training_lr = self.adjust_learning_rate(
                        optimizer,
                        self.config.train.gamma,
                        update_lr_rate_step,
                        self.config.train.lr,
                    )

                images = images.cuda(non_blocking=True)
                region_scores = region_scores.cuda(non_blocking=True)
                affinity_scores = affinity_scores.cuda(non_blocking=True)
                confidence_masks = confidence_masks.cuda(non_blocking=True)

                if self.config.train.use_synthtext:
                    # Synth image load
                    syn_image, syn_region_label, syn_affi_label, syn_confidence_mask = next(
                        batch_syn)
                    syn_image = syn_image.cuda(non_blocking=True)
                    syn_region_label = syn_region_label.cuda(non_blocking=True)
                    syn_affi_label = syn_affi_label.cuda(non_blocking=True)
                    syn_confidence_mask = syn_confidence_mask.cuda(
                        non_blocking=True)

                    # concat syn & custom image
                    images = torch.cat((syn_image, images), 0)
                    region_image_label = torch.cat(
                        (syn_region_label, region_scores), 0)
                    affinity_image_label = torch.cat(
                        (syn_affi_label, affinity_scores), 0)
                    confidence_mask_label = torch.cat(
                        (syn_confidence_mask, confidence_masks), 0)
                else:
                    region_image_label = region_scores
                    affinity_image_label = affinity_scores
                    confidence_mask_label = confidence_masks

                if self.config.train.amp:
                    with torch.cuda.amp.autocast():

                        output, _ = craft(images)
                        out1 = output[:, :, :, 0]
                        out2 = output[:, :, :, 1]

                        loss = criterion(
                            region_image_label,
                            affinity_image_label,
                            out1,
                            out2,
                            confidence_mask_label,
                            self.config.train.neg_rto,
                            self.config.train.n_min_neg,
                        )

                    optimizer.zero_grad()
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()

                else:
                    output, _ = craft(images)
                    out1 = output[:, :, :, 0]
                    out2 = output[:, :, :, 1]
                    # Fixed: pass n_min_neg here too, so the loss receives
                    # the same arguments with and without AMP.
                    loss = criterion(
                        region_image_label,
                        affinity_image_label,
                        out1,
                        out2,
                        confidence_mask_label,
                        self.config.train.neg_rto,
                        self.config.train.n_min_neg,
                    )

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                end_time = time.time()
                loss_value += loss.item()
                batch_time += end_time - start_time

                if train_step > 0 and train_step % 5 == 0 and self.gpu == 0:
                    mean_loss = loss_value / 5
                    loss_value = 0
                    avg_batch_time = batch_time / 5
                    batch_time = 0

                    print(
                        "{}, training_step: {}|{}, learning rate: {:.8f}, "
                        "training_loss: {:.5f}, avg_batch_time: {:.5f}".format(
                            time.strftime("%Y-%m-%d:%H:%M:%S",
                                          time.localtime(time.time())),
                            train_step,
                            whole_training_step,
                            training_lr,
                            mean_loss,
                            avg_batch_time,
                        ))

                    if self.gpu == 0 and self.config.wandb_opt:
                        wandb.log({
                            "train_step": train_step,
                            "mean_loss": mean_loss
                        })

                if (train_step % self.config.train.eval_interval == 0
                        and train_step != 0):

                    craft.eval()
                    if self.gpu == 0:
                        # Reset every shared buffer slot to None before the
                        # evaluation refills it.
                        for buffer in buffer_dict.values():
                            for i in range(len(buffer)):
                                buffer[i] = None

                        print("Saving state, index:", train_step)
                        save_checkpoint(train_step)

                    # validation
                    self.iou_eval(
                        "custom_data",
                        train_step,
                        buffer_dict["custom_data"],
                        craft,
                    )

                # Fixed: restart the timer for the next batch. The original
                # set start_time once before the loop, so each step added the
                # time elapsed since training began.
                start_time = time.time()

                train_step += 1
                if train_step >= whole_training_step:
                    break

            if self.config.mode == "weak_supervision":
                # Refresh the supervision model with the weights just trained.
                state_dict = craft.module.state_dict()
                supervision_model.load_state_dict(state_dict)
                trn_real_dataset.update_model(supervision_model)

        # save last model
        if self.gpu == 0:
            save_checkpoint(train_step)
Ejemplo n.º 6
0
    t1 = time.time() - t1

    # render results (optional)
    render_img = score_text.copy()
    render_img = np.hstack((render_img, score_link))
    ret_score_text = imgproc.cvt2HeatmapImg(render_img)

    if args.show_time:
        print("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1))

    return boxes, polys, ret_score_text


if __name__ == '__main__':
    # Build the network, restore its weights, and switch it to eval mode.
    net = CRAFT()

    print('Loading weights from checkpoint (' + args.trained_model + ')')
    # Without CUDA, tensors saved from a GPU must be remapped onto the CPU.
    load_kwargs = {} if args.cuda else {'map_location': 'cpu'}
    checkpoint = torch.load(args.trained_model, **load_kwargs)
    net.load_state_dict(copyStateDict(checkpoint))

    if args.cuda:
        net = torch.nn.DataParallel(net.cuda())
        cudnn.benchmark = False

    net.eval()
Ejemplo n.º 7
0
def main_eval(model_path, backbone, config, evaluator, result_dir, buffer,
              model, mode):
    """Evaluate a CRAFT model on the "custom_data" test set.

    Runs in one of two modes:

    * ``model is None``: standalone evaluation. A model is built, its
      weights are loaded from ``model_path``, and every test image is
      processed by this process.
    * ``model`` given: evaluation during distributed training. Each process
      handles its own slice of the images and writes its predictions into
      the shared ``buffer``.

    Args:
        model_path: checkpoint path, used only when ``model`` is None.
        backbone: architecture name; only ``"vgg"`` is supported.
        config: evaluation settings (thresholds, ``cuda``, ``poly``,
            ``canvas_size``, ``mag_ratio``, ``vis_opt``).
        evaluator: object providing ``evaluate_image`` and
            ``combine_results``.
        result_dir: output directory, created if missing.
        buffer: index-assignable sequence with one slot per test image,
            or None.
        model: model to evaluate, or None to load one from ``model_path``.
        mode: training mode; ``"weak_supervision"`` halves the number of
            GPUs assumed to take part in evaluation.

    Returns:
        The combined metrics from ``evaluator.combine_results``.

    Raises:
        Exception: if ``backbone`` is not supported.
    """
    os.makedirs(result_dir, exist_ok=True)

    total_imgs_bboxes_gt, total_imgs_path = load_test_dataset_iou(
        "custom_data", config)

    if mode == "weak_supervision" and torch.cuda.device_count() != 1:
        gpu_count = torch.cuda.device_count() // 2
    else:
        gpu_count = torch.cuda.device_count()
    gpu_idx = torch.cuda.current_device()
    torch.cuda.set_device(gpu_idx)

    # Index in the full image list of the first image handled here.
    offset = 0

    # Only evaluation time
    if model is None:
        piece_imgs_path = total_imgs_path

        if backbone == "vgg":
            model = CRAFT()
        else:
            raise Exception("Undefined architecture")

        print("Loading weights from checkpoint (" + model_path + ")")
        net_param = torch.load(model_path, map_location=f"cuda:{gpu_idx}")
        model.load_state_dict(copyStateDict(net_param["craft"]))

        if config.cuda:
            model = model.cuda()
            cudnn.benchmark = False

    # Distributed evaluation in the middle of training time
    else:
        if buffer is not None:
            # check all buffer value is None for distributed evaluation
            assert all(
                v is None
                for v in buffer), "Buffer already filled with another value."
        slice_idx = len(total_imgs_bboxes_gt) // gpu_count
        offset = gpu_idx * slice_idx

        if gpu_idx == gpu_count - 1:
            # The last GPU also takes the remainder of the division.
            piece_imgs_path = total_imgs_path[offset:]
        else:
            piece_imgs_path = total_imgs_path[offset:offset + slice_idx]

    model.eval()

    # -----------------------------------------------------------------------------------------------------------------#
    total_imgs_bboxes_pre = []
    for k, img_path in enumerate(tqdm(piece_imgs_path)):
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        bboxes, polys, score_text = test_net(
            model,
            image,
            config.text_threshold,
            config.link_threshold,
            config.low_text,
            config.cuda,
            config.poly,
            config.canvas_size,
            config.mag_ratio,
        )

        single_img_bbox = [{
            "points": box,
            "text": "###",
            "ignore": False
        } for box in bboxes]
        total_imgs_bboxes_pre.append(single_img_bbox)

        # Distributed evaluation: publish this image's result at its
        # position in the full list.
        if buffer is not None:
            buffer[offset + k] = single_img_bbox

        if config.vis_opt:
            # Fixed: the ground truth is indexed over the full list, so the
            # local index k has to be shifted by this process's offset.
            viz_test(
                image,
                score_text,
                pre_box=polys,
                gt_box=total_imgs_bboxes_gt[offset + k],
                img_name=img_path,
                result_dir=result_dir,
                test_folder_name="custom_data",
            )

    # When distributed evaluation mode, wait until buffer is full filled
    if buffer is not None:
        import time  # local import: only this polling loop needs it

        while None in buffer:
            # Sleep briefly instead of spinning at full CPU.
            time.sleep(0.01)
        assert all(v is not None for v in buffer), "Buffer not filled"
        total_imgs_bboxes_pre = buffer

    results = [
        evaluator.evaluate_image(gt, pred)
        for gt, pred in zip(total_imgs_bboxes_gt, total_imgs_bboxes_pre)
    ]

    metrics = evaluator.combine_results(results)
    print(metrics)
    return metrics
Ejemplo n.º 8
0
    def train(self, buffer_dict):
        torch.cuda.set_device(self.gpu)

        # DATASET -----------------------------------------------------------------------------------------------------#
        trn_loader = self.trn_loader

        # MODEL -------------------------------------------------------------------------------------------------------#
        if self.config.train.backbone == "vgg":
            craft = CRAFT(pretrained=True, amp=self.config.train.amp)
        else:
            raise Exception("Undefined architecture")

        if self.config.train.ckpt_path is not None:
            craft.load_state_dict(copyStateDict(self.net_param["craft"]))
        craft = nn.SyncBatchNorm.convert_sync_batchnorm(craft)
        craft = craft.cuda()
        craft = torch.nn.parallel.DistributedDataParallel(craft, device_ids=[self.gpu])

        torch.backends.cudnn.benchmark = True

        # OPTIMIZER----------------------------------------------------------------------------------------------------#

        optimizer = optim.Adam(
            craft.parameters(),
            lr=self.config.train.lr,
            weight_decay=self.config.train.weight_decay,
        )

        if self.config.train.ckpt_path is not None and self.config.train.st_iter != 0:
            optimizer.load_state_dict(copyStateDict(self.net_param["optimizer"]))
            self.config.train.st_iter = self.net_param["optimizer"]["state"][0]["step"]
            self.config.train.lr = self.net_param["optimizer"]["param_groups"][0]["lr"]

        # LOSS --------------------------------------------------------------------------------------------------------#
        # mixed precision
        if self.config.train.amp:
            scaler = torch.cuda.amp.GradScaler()

            # load model
            if (
                self.config.train.ckpt_path is not None
                and self.config.train.st_iter != 0
            ):
                scaler.load_state_dict(copyStateDict(self.net_param["scaler"]))
        else:
            scaler = None

        criterion = self.get_loss()

        # TRAIN -------------------------------------------------------------------------------------------------------#
        train_step = self.config.train.st_iter
        whole_training_step = self.config.train.end_iter
        update_lr_rate_step = 0
        training_lr = self.config.train.lr
        loss_value = 0
        batch_time = 0
        epoch = 0
        start_time = time.time()

        while train_step < whole_training_step:
            self.trn_sampler.set_epoch(train_step)
            for (
                index,
                (image, region_image, affinity_image, confidence_mask,),
            ) in enumerate(trn_loader):
                craft.train()
                if train_step > 0 and train_step % self.config.train.lr_decay == 0:
                    update_lr_rate_step += 1
                    training_lr = self.adjust_learning_rate(
                        optimizer,
                        self.config.train.gamma,
                        update_lr_rate_step,
                        self.config.train.lr,
                    )

                images = image.cuda(non_blocking=True)
                region_image_label = region_image.cuda(non_blocking=True)
                affinity_image_label = affinity_image.cuda(non_blocking=True)
                confidence_mask_label = confidence_mask.cuda(non_blocking=True)

                if self.config.train.amp:
                    with torch.cuda.amp.autocast():

                        output, _ = craft(images)
                        out1 = output[:, :, :, 0]
                        out2 = output[:, :, :, 1]

                        loss = criterion(
                            region_image_label,
                            affinity_image_label,
                            out1,
                            out2,
                            confidence_mask_label,
                            self.config.train.neg_rto,
                            self.config.train.n_min_neg,
                        )

                    optimizer.zero_grad()
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()

                else:
                    output, _ = craft(images)
                    out1 = output[:, :, :, 0]
                    out2 = output[:, :, :, 1]
                    loss = criterion(
                        region_image_label,
                        affinity_image_label,
                        out1,
                        out2,
                        confidence_mask_label,
                        self.config.train.neg_rto,
                    )

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                end_time = time.time()
                loss_value += loss.item()
                batch_time += end_time - start_time

                if train_step > 0 and train_step % 5 == 0 and self.gpu == 0:
                    mean_loss = loss_value / 5
                    loss_value = 0
                    avg_batch_time = batch_time / 5
                    batch_time = 0

                    print(
                        "{}, training_step: {}|{}, learning rate: {:.8f}, "
                        "training_loss: {:.5f}, avg_batch_time: {:.5f}".format(
                            time.strftime(
                                "%Y-%m-%d:%H:%M:%S", time.localtime(time.time())
                            ),
                            train_step,
                            whole_training_step,
                            training_lr,
                            mean_loss,
                            avg_batch_time,
                        )
                    )
                    if self.gpu == 0 and self.config.wandb_opt:
                        wandb.log({"train_step": train_step, "mean_loss": mean_loss})

                if (
                    train_step % self.config.train.eval_interval == 0
                    and train_step != 0
                ):

                    # initialize all buffer value with zero
                    if self.gpu == 0:
                        for buffer in buffer_dict.values():
                            for i in range(len(buffer)):
                                buffer[i] = None

                    print("Saving state, index:", train_step)
                    save_param_dic = {
                        "iter": train_step,
                        "craft": craft.state_dict(),
                        "optimizer": optimizer.state_dict(),
                    }
                    save_param_path = (
                        self.config.results_dir
                        + "/CRAFT_clr_"
                        + repr(train_step)
                        + ".pth"
                    )

                    if self.config.train.amp:
                        save_param_dic["scaler"] = scaler.state_dict()
                        save_param_path = (
                            self.config.results_dir
                            + "/CRAFT_clr_amp_"
                            + repr(train_step)
                            + ".pth"
                        )

                    torch.save(save_param_dic, save_param_path)

                    # validation
                    self.iou_eval(
                        "icdar2013",
                        train_step,
                        save_param_path,
                        buffer_dict["icdar2013"],
                        craft,
                    )

                train_step += 1
                if train_step >= whole_training_step:
                    break
            epoch += 1

        # save last model
        if self.gpu == 0:
            save_param_dic = {
                "iter": train_step,
                "craft": craft.state_dict(),
                "optimizer": optimizer.state_dict(),
            }
            save_param_path = (
                self.config.results_dir + "/CRAFT_clr_" + repr(train_step) + ".pth"
            )

            if self.config.train.amp:
                save_param_dic["scaler"] = scaler.state_dict()
                save_param_path = (
                    self.config.results_dir
                    + "/CRAFT_clr_amp_"
                    + repr(train_step)
                    + ".pth"
                )
            torch.save(save_param_dic, save_param_path)
Ejemplo n.º 9
0
class CraftDetection:
    def __init__(self):
        """Load the CRAFT detector and, if ``pr.refine`` is set, the refiner.

        Checkpoint paths and the CUDA/refine switches are read from the
        module-level ``pr`` settings object.  When ``pr.cuda`` is set each
        network is moved to the GPU and wrapped in ``torch.nn.DataParallel``;
        otherwise the checkpoints are mapped onto the CPU while loading.
        Both networks are put in eval mode.

        Side effect: enabling the refiner forces ``pr.poly = True``.
        """
        # map_location=None is torch.load's default (keep the saved device);
        # without CUDA the tensors have to be remapped onto the CPU.
        map_location = None if pr.cuda else 'cpu'

        self.model = CRAFT()
        # NOTE(review): copyStateDict presumably normalises DataParallel key
        # prefixes in the checkpoint -- confirm against its definition.
        self.model.load_state_dict(
            copyStateDict(
                torch.load(pr.trained_model, map_location=map_location)))
        if pr.cuda:
            self.model.cuda()
            self.model = torch.nn.DataParallel(self.model)
            cudnn.benchmark = False
        self.model.eval()

        self.refine_model = None
        if pr.refine:
            self.refine_model = RefineNet()
            self.refine_model.load_state_dict(
                copyStateDict(
                    torch.load(pr.refiner_model, map_location=map_location)))
            if pr.cuda:
                # Fixed: this used to read the undefined ``self.refine_net``,
                # raising AttributeError whenever refine + cuda were enabled.
                self.refine_model = self.refine_model.cuda()
                self.refine_model = torch.nn.DataParallel(self.refine_model)

            self.refine_model.eval()
            pr.poly = True

    def text_detect(self, image, have_cmnd=True):
        """Detect text boxes in ``image`` and return them cropped, line by line.

        The image is resized and normalised through ``imgproc``, run through
        the CRAFT model (and the refiner when one was loaded), and the score
        maps are converted to boxes by ``craft_utils``.  Every detection is
        reduced to an axis-aligned rectangle, padded by ``pr.expand_ratio`` of
        its width/height on each side, grouped into lines by
        ``sort_line_cmnd`` and cropped out of the original image.

        Args:
            image: Input image as a numpy array (it is handed to
                ``Image.fromarray`` and ``cv2.rectangle``).
                NOTE(review): presumably H x W x 3 RGB -- confirm with callers.
            have_cmnd: Not used by this method; kept so existing callers that
                pass it keep working.

        Returns:
            tuple: ``(crops, annotated, None)``.  ``crops`` is a list of
            lines, each a list of PIL images in the order produced by
            ``sort_line_cmnd``; ``annotated`` is a copy of ``image`` with the
            kept rectangles drawn on it; the third element is always ``None``.
        """
        # resize
        img_resized, target_ratio, _ = imgproc.resize_aspect_ratio(
            image,
            pr.canvas_size,
            interpolation=cv2.INTER_LINEAR,
            mag_ratio=pr.mag_ratio)
        ratio_h = ratio_w = 1 / target_ratio

        # preprocessing
        x = imgproc.normalizeMeanVariance(img_resized)
        x = torch.from_numpy(x).permute(2, 0, 1)  # [h, w, c] to [c, h, w]
        x = x.unsqueeze(0)  # [c, h, w] to [b, c, h, w]
        if pr.cuda:
            x = x.cuda()

        # forward pass
        with torch.no_grad():
            y, feature = self.model(x)

        # make score and link map
        score_text = y[0, :, :, 0].cpu().data.numpy()
        score_link = y[0, :, :, 1].cpu().data.numpy()

        # refine link
        if self.refine_model is not None:
            with torch.no_grad():
                y_refiner = self.refine_model(y, feature)
            score_link = y_refiner[0, :, :, 0].cpu().data.numpy()

        # Post-processing
        boxes, polys = craft_utils.getDetBoxes(score_text, score_link,
                                               pr.text_threshold,
                                               pr.link_threshold, pr.low_text,
                                               pr.poly)

        # coordinate adjustment
        boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
        polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
        for k in range(len(polys)):
            if polys[k] is None:
                polys[k] = boxes[k]

        # get box + extend
        list_box = []
        for box in polys:
            # Axis-aligned bounding box of the detection.  For a 4-point
            # rotated rectangle this yields the same extremes the previous
            # tilt-based corner selection picked, and it also accepts polygons
            # with more than four points, which the old fixed 4-point
            # unpacking rejected with a ValueError.
            pts = np.asarray(box).reshape(-1, 2)
            xmin, ymin = pts.min(axis=0)
            xmax, ymax = pts.max(axis=0)

            # Pad every side by a fraction of the box size; only the top/left
            # edges are clamped to the image origin.
            pad_x = int((xmax - xmin) * pr.expand_ratio)
            pad_y = int((ymax - ymin) * pr.expand_ratio)
            list_box.append([max(0, xmin - pad_x), max(0, ymin - pad_y),
                             xmax + pad_x, ymax + pad_y])

        # sort line
        dict_cum_sorted = self.sort_line_cmnd(list_box)
        list_box_optim = [box for cum in dict_cum_sorted for box in cum]

        # draw box on image
        img_res = np.ascontiguousarray(image.copy())
        for xmin, ymin, xmax, ymax in list_box_optim:
            cv2.rectangle(img_res, (int(xmin), int(ymin)),
                          (int(xmax), int(ymax)), (29, 187, 255), 2, 2)

        # crop image
        result_list_img_cum = []
        image_PIL = Image.fromarray(image)
        for cum in dict_cum_sorted:
            # crop() already returns a new image, so the source does not need
            # to be copied for every box.
            list_img = [
                image_PIL.crop((xmin, ymin, xmax, ymax))
                for xmin, ymin, xmax, ymax in cum
            ]
            result_list_img_cum.append(list_img)
        return result_list_img_cum, img_res, None

    def sort_line_cmnd(self, boxes):

        if len(boxes) == 0:
            return []
        boxes = sorted(boxes, key=lambda x: x[1])  # sort by ymin
        lines = [[]]

        # y_center = (boxes[0][1] + boxes[0][3]) / 2.0
        y_max_base = boxes[0][3]  # y_max
        i = 0
        for box in boxes:
            if box[1] + 0.5 * abs(box[3] -
                                  box[1]) <= y_max_base:  # y_min <= y_max_base
                lines[i].append(box)
            else:
                lines[i] = sorted(lines[i], key=lambda x: x[0])
                # y_center = (box[1] + box[3]) / 2.0
                y_max_base = box[3]
                lines.append([])
                i += 1
                lines[i].append(box)

        temp = []

        for line in lines:
            temp.append(line[0][1])
        index_sort = np.argsort(np.array(temp)).tolist()
        lines_new = [self.remove(lines[i]) for i in index_sort]

        return lines_new
        # return lines

    def remove(self, line):
        line = sorted(line, key=lambda x: x[0])
        result = []
        check_index = -1
        for index in range(len(line)):
            if check_index == index:
                pass
            else:
                result.append(line[index])
                check_index = index
            if index == len(line) - 1:
                break
            if self.compute_iou(line[index], line[index + 1]) > 0.25:
                s1 = (line[index][2] - line[index][0] +
                      1) * (line[index][3] - line[index][1] + 1)
                s2 = (line[index + 1][2] - line[index + 1][0] +
                      1) * (line[index + 1][3] - line[index + 1][1] + 1)
                if s2 > s1:
                    del (result[-1])
                    result.append(line[index + 1])
                check_index = index + 1
        result = sorted(result, key=lambda x: x[0])
        return result

    def compute_iou(self, box1, box2):

        x_min_inter = max(box1[0], box2[0])
        y_min_inter = max(box1[1], box2[1])
        x_max_inter = min(box1[2], box2[2])
        y_max_inter = min(box1[3], box2[3])

        inter_area = max(0, x_max_inter - x_min_inter + 1) * max(
            0, y_max_inter - y_min_inter + 1)

        s1 = (box1[2] - box1[0] + 1) * (box1[3] - box1[1] + 1)
        s2 = (box2[2] - box2[0] + 1) * (box2[3] - box2[1] + 1)
        # print(inter_area)
        iou = float(inter_area / (s1 + s2 - inter_area))

        return iou

    def sort_line(self, boxes):
        if len(boxes) == 0:
            return []
        boxes = sorted(boxes, key=lambda x: x[1])
        lines = [[]]

        y_center = (boxes[0][1] + boxes[0][3]) / 2.0
        i = 0
        for box in boxes:
            if box[1] < y_center:
                lines[i].append(box)
            else:
                lines[i] = sorted(lines[i], key=lambda x: x[0])
                y_center = (box[1] + box[3]) / 2.0
                lines.append([])
                i += 1
                lines[i].append(box)

        temp = []

        for line in lines:
            temp.append(line[0][1])
        index_sort = np.argsort(np.array(temp)).tolist()
        lines_new = [self.remove(lines[i]) for i in index_sort]

        return lines_new