Example #1
class CraftDetection:
    def __init__(self):
        self.model = CRAFT()
        if pr.cuda:
            self.model.load_state_dict(
                copyStateDict(torch.load(pr.trained_model)))
            self.model.cuda()
            self.model = torch.nn.DataParallel(self.model)
            cudnn.benchmark = False
        else:
            self.model.load_state_dict(
                copyStateDict(torch.load(pr.trained_model,
                                         map_location='cpu')))
        self.model.eval()

        self.refine_model = None
        if pr.refine:
            self.refine_model = RefineNet()
            if pr.cuda:
                self.refine_model.load_state_dict(
                    copyStateDict(torch.load(pr.refiner_model)))
                self.refine_model = self.refine_model.cuda()
                self.refine_model = torch.nn.DataParallel(self.refine_model)
            else:
                self.refine_model.load_state_dict(
                    copyStateDict(
                        torch.load(pr.refiner_model, map_location='cpu')))

            self.refine_model.eval()
            pr.poly = True
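Every CRAFT-based snippet on this page calls a copyStateDict helper that is not shown here. A minimal sketch of the usual implementation, assuming the standard CRAFT convention (it strips the "module." prefix that torch.nn.DataParallel adds to parameter names, so a checkpoint saved from a wrapped model loads into a plain one):

from collections import OrderedDict

def copyStateDict(state_dict):
    # Checkpoints saved from a DataParallel-wrapped model prefix every
    # key with "module."; strip it so keys match an unwrapped module.
    start_idx = 1 if list(state_dict.keys())[0].startswith("module") else 0
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        new_state_dict[".".join(k.split(".")[start_idx:])] = v
    return new_state_dict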
Example #2
def main(opt):
    load_epoch = opt.load_epoch
    test_dataset = P2PDataset(dataset_path=cfg.DATA_DIR,
                              root_idx=cfg.DATASET.ROOT_IDX)
    test_loader = DataLoader(test_dataset,
                             batch_size=cfg.TEST.BATCH_SIZE,
                             shuffle=False)
    model = RefineNet()
    model = model.cuda()

    min_root_error = 1000
    min_idx = 0
    while True:
        ckpt_file = os.path.join(cfg.CHECKPOINT_DIR,
                                 "RefineNet_epoch_%03d.pth" % load_epoch)
        if not os.path.exists(ckpt_file):
            print("No ckpt of epoch {}".format(load_epoch))
            print("Best real_error iter is {}, error is {}".format(
                min_idx, min_root_error))
            break
        load_state_dict(model, torch.load(ckpt_file))
        model.eval()

        count = 0
        root_error = 0
        time_total = 0.0
        for i, (inp, gt_t) in enumerate(test_loader):
            inp = inp.cuda()
            with torch.no_grad():
                start_time = time.time()
                pred_t = model(inp)
                time_total += time.time() - start_time
                pred_t = pred_t.cpu()
                # loss = criterion(pred, gt)
                for j in range(len(pred_t)):
                    gt = copy.deepcopy(gt_t[j].numpy())
                    gt.resize((15, 3))
                    pred = copy.deepcopy(pred_t[j].numpy())
                    pred.resize((15, 3))
                    count += 1
                    root_error += np.linalg.norm(pred - gt, axis=1)  # per-joint L2 error, shape (15,)

        print_root_error = root_error / count
        mean_root_error = np.mean(print_root_error)
        print("Root error of epoch {} is {}, mean is {}".format(
            load_epoch, print_root_error, mean_root_error))
        if mean_root_error < min_root_error:
            min_root_error = mean_root_error
            min_idx = load_epoch
        load_epoch += cfg.SAVE_FREQ
        print("Time per inference is {}".format(time_total / len(test_loader)))
Example #3
class PoseEstimation(object):
    def __init__(self):
        self.model = SMAP(cfg, run_efficient=cfg.RUN_EFFICIENT)
        self.device = torch.device(cfg.MODEL.DEVICE)
        self.model.to(self.device)

        if cfg.REFINE:
            self.refine_model = RefineNet()
            self.refine_model.to(self.device)
            refine_model_file = rospy.get_param(
                '~refinenet_model',
                '/ros/src/smap_ros/resources/RefineNet.pth')
        else:
            self.refine_model = None
            refine_model_file = ""
        smap_model = rospy.get_param(
            '~smap_model', '/ros/src/smap_ros/resources/SMAP_model.pth')

        if os.path.exists(smap_model):
            state_dict = torch.load(smap_model,
                                    map_location=lambda storage, loc: storage)
            state_dict = state_dict['model']
            self.model.load_state_dict(state_dict)
            if os.path.exists(refine_model_file):
                self.refine_model.load_state_dict(
                    torch.load(refine_model_file))
            elif self.refine_model is not None:
                rospy.logerr("No such RefineNet checkpoint of {}".format(
                    refine_model_file))
                return
        else:
            rospy.logerr("No such checkpoint of SMAP {}".format(smap_model))
            return

        rospy.Subscriber('~input', Image, self.callback)

        self.__pub = rospy.Publisher('~image', Image, queue_size=10)
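The map_location=lambda storage, loc: storage argument keeps every deserialized tensor on the CPU regardless of the device it was saved from; for loading purposes it is effectively the same as map_location='cpu'. A quick sketch, with a hypothetical checkpoint path:

import torch

# Both calls place all tensors on the CPU; 'model.pth' is a placeholder path.
sd_a = torch.load('model.pth', map_location=lambda storage, loc: storage)
sd_b = torch.load('model.pth', map_location='cpu')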
Example #4
File: train.py  Project: zju3dv/SMAP
def main():
    train_dataset = P2PDataset(dataset_path=cfg.DATA_DIR, root_idx=cfg.DATASET.ROOT_IDX)
    train_loader = DataLoader(train_dataset, batch_size=cfg.SOLVER.BATCH_SIZE, shuffle=True)
    
    model = RefineNet()
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)
    if len(cfg.MODEL.GPU_IDS) > 1:
        model = nn.parallel.DataParallel(model, device_ids=cfg.MODEL.GPU_IDS)
    
    optimizer = optim.Adam(model.parameters(), lr=cfg.SOLVER.BASE_LR, betas=(0.9, 0.999))
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=cfg.SOLVER.LR_STEP_SIZE, gamma=cfg.SOLVER.GAMMA, last_epoch=-1)
    
    criterion = nn.MSELoss()

    model.train()
    for epoch in range(1, cfg.SOLVER.NUM_EPOCHS+1):
        total_loss = 0
        count = 0
        for i, (inp, gt) in enumerate(train_loader):
            count += 1
            inp = inp.to(device)
            gt = gt.to(device)

            preds = model(inp)
            loss = criterion(preds, gt)
            total_loss += loss.data.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        scheduler.step()
        
        avg_loss = total_loss / count
        if epoch % cfg.PRINT_FREQ == 0:
            print("epoch: {} | loss: {}.".format(epoch, avg_loss))
        if epoch % cfg.SAVE_FREQ == 0 or epoch == cfg.SOLVER.NUM_EPOCHS:
            torch.save(model.state_dict(), osp.join(checkpoint_dir, "RefineNet_epoch_%03d.pth" % epoch))
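The StepLR scheduler above multiplies the learning rate by cfg.SOLVER.GAMMA every cfg.SOLVER.LR_STEP_SIZE epochs. A self-contained sketch with hypothetical values (step_size=10, gamma=0.1):

import torch
from torch import nn, optim

param = nn.Parameter(torch.zeros(1))
opt = optim.Adam([param], lr=1e-3)
sched = optim.lr_scheduler.StepLR(opt, step_size=10, gamma=0.1)
for epoch in range(30):
    opt.step()     # the training step would go here
    sched.step()   # decay fires after epochs 10, 20, 30
print(opt.param_groups[0]['lr'])  # 1e-3 * 0.1**3 = 1e-6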
Example #5
def try_load_model(save_dir,
                   step_ckpt=-1,
                   return_new_model=True,
                   verbose=True):
    """
    Tries to load a model from the provided directory, otherwise returns a new initialized model.
    :param save_dir: directory with checkpoints
    :param step_ckpt: step of checkpoint where to resume the model from
    :param verbose: true for printing the model summary
    :return:
    """
    import tensorflow as tf
    tf.compat.v1.enable_v2_behavior()
    if configs.config_values.model == 'baseline':
        configs.config_values.num_L = 1

    # initialize return values
    model_name = configs.config_values.model
    if model_name == 'resnet':
        model = ResNet(filters=configs.config_values.filters,
                       activation=tf.nn.elu)
    elif model_name in ['refinenet', 'baseline']:
        model = RefineNet(filters=configs.config_values.filters,
                          activation=tf.nn.elu)
    elif model_name == 'refinenet_twores':
        model = RefineNetTwoResidual(filters=configs.config_values.filters,
                                     activation=tf.nn.elu)

    optimizer = tf.keras.optimizers.Adam(
        learning_rate=configs.config_values.learning_rate)
    step = 0

    # if resuming training, overwrite model parameters from checkpoint
    if configs.config_values.resume:
        if step_ckpt == -1:
            print("Trying to load latest model from " + save_dir)
            checkpoint = tf.train.latest_checkpoint(save_dir)
        else:
            print("Trying to load checkpoint with step", step_ckpt,
                  " model from " + save_dir)
            onlyfiles = [
                f for f in os.listdir(save_dir)
                if os.path.isfile(os.path.join(save_dir, f))
            ]
            r = re.compile(".*step_{}-.*".format(step_ckpt))
            name_all_checkpoints = sorted(list(filter(r.match, onlyfiles)))
            # Retrieve name of the last checkpoint with that number of steps
            name_ckpt = name_all_checkpoints[-1][:-6]  # strip the ".index" suffix
            checkpoint = os.path.join(save_dir, name_ckpt)
        if checkpoint is None:
            print("No model found.")
            if return_new_model:
                print("Using a new model")
            else:
                print("Returning None")
                model = None
                optimizer = None
                step = None
        else:
            step = tf.Variable(0)
            ckpt = tf.train.Checkpoint(step=step,
                                       optimizer=optimizer,
                                       model=model)
            ckpt.restore(checkpoint)
            step = int(step)
            print("Loaded model: " + checkpoint)

    evaluate_print_model_summary(model, verbose)

    return model, optimizer, step
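The restore path relies on tf.train.Checkpoint tracking the step variable together with the model and optimizer. A minimal save/restore round trip, independent of the models on this page:

import tensorflow as tf

step = tf.Variable(0)
opt = tf.keras.optimizers.Adam(1e-3)
ckpt = tf.train.Checkpoint(step=step, optimizer=opt,
                           model=tf.keras.layers.Dense(4))

step.assign(42)
path = ckpt.save('/tmp/demo_ckpt/ckpt')  # e.g. '/tmp/demo_ckpt/ckpt-1'

step.assign(0)
ckpt.restore(path)
print(int(step))  # 42 again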
Example #6
def try_load_model(save_dir, step_ckpt=-1, return_new_model=True, verbose=True, ocnn=False):
    """
    Tries to load a model from the provided directory, otherwise returns a new initialized model.
    :param save_dir: directory with checkpoints
    :param step_ckpt: step of checkpoint where to resume the model from
    :param verbose: true for printing the model summary
    :return:
    """
    ocnn_model = None
    ocnn_optimizer = None

    import tensorflow as tf
    tf.compat.v1.enable_v2_behavior()
    if configs.config_values.model == 'baseline':
        configs.config_values.num_L = 1

    splits = False
    if configs.config_values.y_cond:
        splits = dict_splits[configs.config_values.dataset]

    # initialize return values
    model_name = configs.config_values.model
    if model_name == 'resnet':
        model = ResNet(filters=configs.config_values.filters, activation=tf.nn.elu)
    elif model_name in ['refinenet', 'baseline']:
        model = RefineNet(filters=configs.config_values.filters,
                          activation=tf.nn.elu,
                          y_conditioned=configs.config_values.y_cond,
                          splits=splits)
    elif model_name == 'refinenet_twores':
        model = RefineNetTwoResidual(filters=configs.config_values.filters, activation=tf.nn.elu)
    elif model_name == 'masked_refinenet':
        print("Using Masked RefineNet...")
        # assert configs.config_values.y_cond 
        model = MaskedRefineNet(filters=configs.config_values.filters,
                                activation=tf.nn.elu,
                                splits=dict_splits[configs.config_values.dataset],
                                y_conditioned=configs.config_values.y_cond)

    optimizer = tf.keras.optimizers.Adamax(learning_rate=configs.config_values.learning_rate)
    step = 0
    evaluate_print_model_summary(model, verbose)
    
    if ocnn:
        from tensorflow.keras import Model
        from tensorflow.keras.layers import Input, Flatten, Dense, AvgPool2D
        # Building OCNN on top
        print("Building OCNN...")
        ocnn_inputs = [Input(name="images", shape=(28, 28, 1)),
                       Input(name="idx_sigmas", shape=(), dtype=tf.int32)]

        score_logits = model(ocnn_inputs)
        x = Flatten()(score_logits)
        x = Dense(128, activation="linear", name="embedding")(x)
        dist = Dense(1, activation="linear", name="distance")(x)
        ocnn_model = Model(inputs=ocnn_inputs, outputs=dist, name="OC-NN")
        ocnn_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5)
        evaluate_print_model_summary(ocnn_model, verbose=True)

    # if resuming training, overwrite model parameters from checkpoint
    if configs.config_values.resume:
        if step_ckpt == -1:
            print("Trying to load latest model from " + save_dir)
            checkpoint = tf.train.latest_checkpoint(str(save_dir))
        else:
            print("Trying to load checkpoint with step", step_ckpt, " model from " + save_dir)
            onlyfiles = [f for f in os.listdir(save_dir) if os.path.isfile(os.path.join(save_dir, f))]
            # r = re.compile(".*step_{}-.*".format(step_ckpt))
            r = re.compile("ckpt-{}\\..*".format(step_ckpt))

            name_all_checkpoints = sorted(list(filter(r.match, onlyfiles)))
            print(name_all_checkpoints)
            # Retrieve name of the last checkpoint with that number of steps
            name_ckpt = name_all_checkpoints[-1][:-6]  # strip the ".index" suffix
            checkpoint = os.path.join(save_dir, name_ckpt)
        if checkpoint is None:
            print("No model found.")
            if return_new_model:
                print("Using a new model")
            else:
                print("Returning None")
                model = None
                optimizer = None
                step = None
        else:
            step = tf.Variable(0)

            if ocnn:
                ckpt = tf.train.Checkpoint(step=step, optimizer=optimizer, model=model,
                ocnn_model=ocnn_model, ocnn_optimizer=ocnn_optimizer)
            else:
                 ckpt = tf.train.Checkpoint(step=step, optimizer=optimizer, model=model)

            ckpt.restore(checkpoint)
            step = int(step)
            print("Loaded model: " + checkpoint)

    return model, optimizer, step, ocnn_model, ocnn_optimizer
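The name_all_checkpoints[-1][:-6] slice depends on TF2 checkpoint naming: saving ckpt-N produces a 'ckpt-N.index' file plus data shards, sorting places '.index' after '.data...', and dropping the 6-character '.index' suffix leaves the prefix that ckpt.restore expects. A quick illustration:

import re

files = ['ckpt-120.index', 'ckpt-120.data-00000-of-00001', 'ckpt-7.index']
r = re.compile(r"ckpt-120\..*")
matches = sorted(f for f in files if r.match(f))
print(matches[-1][:-6])  # 'ckpt-120', the restore prefix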
Example #7
class CraftDetection:
    def __init__(self):
        self.model = CRAFT()
        if pr.cuda:
            self.model.load_state_dict(
                copyStateDict(torch.load(pr.trained_model)))
            self.model.cuda()
            self.model = torch.nn.DataParallel(self.model)
            cudnn.benchmark = False
        else:
            self.model.load_state_dict(
                copyStateDict(torch.load(pr.trained_model,
                                         map_location='cpu')))
        self.model.eval()

        self.refine_model = None
        if pr.refine:
            self.refine_model = RefineNet()
            if pr.cuda:
                self.refine_model.load_state_dict(
                    copyStateDict(torch.load(pr.refiner_model)))
                self.refine_model = self.refine_model.cuda()
                self.refine_model = torch.nn.DataParallel(self.refine_model)
            else:
                self.refine_model.load_state_dict(
                    copyStateDict(
                        torch.load(pr.refiner_model, map_location='cpu')))

            self.refine_model.eval()
            pr.poly = True

    def text_detect(self, image):
        # if not os.path.exists(image_path):
        #     print("Not exists path")
        #     return []
        # image = imgproc.loadImage(image_path)       # numpy array img (RGB order)
        # image = cv2.imread()

        time0 = time.time()

        # resize
        img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(
            image,
            pr.canvas_size,
            interpolation=cv2.INTER_LINEAR,
            mag_ratio=pr.mag_ratio)
        print(img_resized.shape)
        ratio_h = ratio_w = 1 / target_ratio

        # preprocessing
        x = imgproc.normalizeMeanVariance(img_resized)
        x = torch.from_numpy(x).permute(2, 0, 1)  # [h, w, c] to [c, h, w]
        x = Variable(x.unsqueeze(0))  # [c, h, w] to [b, c, h, w]
        if pr.cuda:
            x = x.cuda()

        # forward pass
        with torch.no_grad():
            y, feature = self.model(x)

        # make score and link map
        score_text = y[0, :, :, 0].cpu().data.numpy()
        score_link = y[0, :, :, 1].cpu().data.numpy()

        # refine link
        if self.refine_model is not None:
            with torch.no_grad():
                y_refiner = self.refine_model(y, feature)
            score_link = y_refiner[0, :, :, 0].cpu().data.numpy()

        time0 = time.time() - time0
        time1 = time.time()

        # Post-processing
        boxes, polys = craft_utils.getDetBoxes(score_text, score_link,
                                               pr.text_threshold,
                                               pr.link_threshold, pr.low_text,
                                               pr.poly)

        # coordinate adjustment
        boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
        polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
        for k in range(len(polys)):
            if polys[k] is None:
                polys[k] = boxes[k]

        # expand box: poly  = np.array([[l, t], [r, t], [r, b], [l, b]], dtype=np.float32)

        if pr.horizontal_mode:
            if self.check_horizontal(polys):
                height, width, channel = image.shape
                new_polys = []
                for box in polys:
                    [[l1, t1], [r1, t2], [r2, b1], [l2, b2]] = box
                    if t1 < t2:
                        l, r, t, b = l2, r1, t1, b1
                    elif t1 > t2:
                        l, r, t, b = l1, r2, t2, b2
                    else:
                        l, r, t, b = l1, r1, t1, b1
                    h_box = abs(b - t)
                    t = max(0, t - h_box * pr.expand_ratio)
                    b = min(b + h_box * pr.expand_ratio, height)
                    x_min, y_min, x_max, y_max = l, t, r, b
                    new_box = [x_min, y_min, x_max, y_max]
                    new_polys.append(new_box)

                polys = np.array(new_polys, dtype=np.float32)


        time1 = time.time() - time1
        total_time = round(time0 + time1, 2)

        # render results (optional)
        render_img = score_text.copy()
        render_img = np.hstack((render_img, score_link))
        ret_score_text = imgproc.cvt2HeatmapImg(render_img)

        if pr.show_time:
            print("\ninfer/postproc time : {:.3f}/{:.3f}".format(time0, time1))
        if pr.folder_test:
            return boxes, polys, ret_score_text

        if pr.visualize:
            img_draw = displayResult(img=image[:, :, ::-1], boxes=polys)
            plt.imshow(cv2.cvtColor(img_draw, cv2.COLOR_RGB2BGR))
            plt.show()

        result_boxes = []
        for box in polys:
            result_boxes.append(box.tolist())
        return result_boxes, total_time

    def test_folder(self, folder_path):

        image_list, _, _ = file_utils.get_files(folder_path)
        if not os.path.exists(pr.result_folder):
            os.mkdir(pr.result_folder)
        t = time.time()

        # load data
        for k, image_path in enumerate(image_list):
            print("Test image {:d}/{:d}: {:s}".format(k + 1, len(image_list),
                                                      image_path),
                  end='\r')

            # text_detect expects an image array, not a path (pr.folder_test
            # must be set so it returns boxes, polys and the score map)
            image = imgproc.loadImage(image_path)
            bboxes, polys, score_text = self.text_detect(image)

            # save score text
            filename, file_ext = os.path.splitext(os.path.basename(image_path))
            mask_file = pr.result_folder + "/res_" + filename + '_mask.jpg'
            cv2.imwrite(mask_file, score_text)
            file_utils.saveResult(image_path,
                                  image[:, :, ::-1],
                                  polys,
                                  dirname=pr.result_folder)

        print("elapsed time : {}s".format(time.time() - t))

    def check_horizontal(self, boxes):
        total_box = len(boxes)
        if total_box == 0:  # avoid division by zero below
            return False
        num_box_horizontal = 0
        for box in boxes:
            [[l1, t1], [r1, t2], [r2, b1], [l2, b2]] = box
            if t1 == t2:
                num_box_horizontal += 1

        ratio_box_horizontal = num_box_horizontal / float(total_box)
        print("Ratio box horizontal: ", ratio_box_horizontal)
        return ratio_box_horizontal >= pr.ratio_box_horizontal
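A typical driver for this class, assuming pr.folder_test and pr.visualize are disabled so text_detect returns (result_boxes, total_time), and that imgproc.loadImage yields an RGB array as the commented-out lines inside text_detect suggest:

detector = CraftDetection()
image = imgproc.loadImage('sample.jpg')  # placeholder path; RGB numpy array
result_boxes, total_time = detector.text_detect(image)
print(len(result_boxes), 'boxes in', total_time, 's')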
Example #8
class PoseEstimation(object):
    def __init__(self):
        self.model = SMAP(cfg, run_efficient=cfg.RUN_EFFICIENT)
        self.device = torch.device(cfg.MODEL.DEVICE)
        self.model.to(self.device)

        if cfg.REFINE:
            self.refine_model = RefineNet()
            self.refine_model.to(self.device)
            refine_model_file = rospy.get_param(
                '~refinenet_model',
                '/ros/src/smap_ros/resources/RefineNet.pth')
        else:
            self.refine_model = None
            refine_model_file = ""
        smap_model = rospy.get_param(
            '~smap_model', '/ros/src/smap_ros/resources/SMAP_model.pth')

        if os.path.exists(smap_model):
            state_dict = torch.load(smap_model,
                                    map_location=lambda storage, loc: storage)
            state_dict = state_dict['model']
            self.model.load_state_dict(state_dict)
            if os.path.exists(refine_model_file):
                self.refine_model.load_state_dict(
                    torch.load(refine_model_file))
            elif self.refine_model is not None:
                rospy.logerr("No such RefineNet checkpoint of {}".format(
                    refine_model_file))
                return
        else:
            rospy.logerr("No such checkpoint of SMAP {}".format(smap_model))
            return

        self.__bridge = CvBridge()  # CvBridge instance used by the image callback
        rospy.Subscriber('~input', Image, self.callback)

        self.__pub = rospy.Publisher('~image', Image, queue_size=10)

    def callback(self, msg):
        total_now = time.time()
        try:
            image_bgr = self.__bridge.imgmsg_to_cv2(msg, 'bgr8')
        except CvBridgeError as e:
            rospy.logwarn(e)
            return
        image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
        test_dataset = Dataset(image_rgb)
        data_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

        image_debug = generate_3d_point_pairs(self.model,
                                              self.refine_model,
                                              data_loader,
                                              cfg,
                                              self.device,
                                              output_dir=os.path.join(
                                                  cfg.OUTPUT_DIR, "result"))

        total_then = time.time()

        text = "{:03.2f} sec".format(total_then - total_now)
        rospy.loginfo(text)

        self.__pub.publish(self.__bridge.cv2_to_imgmsg(image_debug, 'bgr8'))
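A minimal entry point for this node, assuming the usual rospy bootstrap (the node name is arbitrary):

if __name__ == '__main__':
    rospy.init_node('pose_estimation')
    PoseEstimation()   # registers the subscriber and publisher
    rospy.spin()       # hand control to the image callback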
Example #9
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--test_mode",
        "-t",
        type=str,
        default="run_inference",
        choices=['generate_train', 'generate_result', 'run_inference'],
        help=
        'Type of test. One of "generate_train": generate refineNet datasets, '
        '"generate_result": save inference result and groundtruth, '
        '"run_inference": save inference result for input images.')
    parser.add_argument(
        "--data_mode",
        "-d",
        type=str,
        default="test",
        choices=['test', 'generation'],
        help=
        'Only used for "generate_train" test_mode, "generation" for refineNet train dataset,'
        '"test" for refineNet test dataset.')
    parser.add_argument("--SMAP_path",
                        "-p",
                        type=str,
                        default='log/SMAP.pth',
                        help='Path to SMAP model')
    parser.add_argument(
        "--RefineNet_path",
        "-rp",
        type=str,
        default='',
        help='Path to RefineNet model, empty means without RefineNet')
    parser.add_argument("--batch_size",
                        type=int,
                        default=1,
                        help='Batch_size of test')
    parser.add_argument("--do_flip",
                        type=float,
                        default=0,
                        help='Set to 1 to flip images at test time')
    parser.add_argument("--dataset_path",
                        type=str,
                        default="",
                        help='Image dir path of "run_inference" test mode')
    parser.add_argument("--json_name",
                        type=str,
                        default="",
                        help='Add a suffix to the result json.')
    args = parser.parse_args()
    cfg.TEST_MODE = args.test_mode
    cfg.DATA_MODE = args.data_mode
    cfg.REFINE = len(args.RefineNet_path) > 0
    cfg.DO_FLIP = args.do_flip
    cfg.JSON_SUFFIX_NAME = args.json_name
    cfg.TEST.IMG_PER_GPU = args.batch_size

    os.makedirs(cfg.TEST_DIR, exist_ok=True)
    logger = get_logger(cfg.DATASET.NAME, cfg.TEST_DIR, 0,
                        'test_log_{}.txt'.format(args.test_mode))

    model = SMAP(cfg, run_efficient=cfg.RUN_EFFICIENT)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    if args.test_mode == "run_inference":
        test_dataset = CustomDataset(cfg, args.dataset_path)
        data_loader = DataLoader(test_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False)
    else:
        data_loader = get_test_loader(cfg,
                                      num_gpu=1,
                                      local_rank=0,
                                      stage=args.data_mode)

    if cfg.REFINE:
        refine_model = RefineNet()
        refine_model.to(device)
        refine_model_file = args.RefineNet_path
    else:
        refine_model = None
        refine_model_file = ""

    model_file = args.SMAP_path
    if os.path.exists(model_file):
        state_dict = torch.load(model_file,
                                map_location=lambda storage, loc: storage)
        state_dict = state_dict['model']
        model.load_state_dict(state_dict)
        if os.path.exists(refine_model_file):
            refine_model.load_state_dict(torch.load(refine_model_file))
        elif refine_model is not None:
            logger.info("No such RefineNet checkpoint of {}".format(
                args.RefineNet_path))
            return
        generate_3d_point_pairs(model,
                                refine_model,
                                data_loader,
                                cfg,
                                logger,
                                device,
                                output_dir=os.path.join(
                                    cfg.OUTPUT_DIR, "result"))
    else:
        logger.info("No such checkpoint of SMAP {}".format(args.SMAP_path))
Example #10
import numpy as np
import tensorflow as tf
import datetime
from model.refinenet import RefineNet

refinenet = RefineNet(27)
refinenet.compile(optimizer='adam',
                  loss='mean_absolute_error',
                  metrics=['accuracy'])

log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir,
                                                      histogram_freq=1)

x_train = np.random.random((2, 224, 224, 3)).astype(np.float32)
y_train = np.random.random((2, 224, 224, 27)).astype(np.float32)

x_test = np.random.random((1, 224, 224, 3)).astype(np.float32)
y_test = np.random.random((1, 224, 224, 27)).astype(np.float32)

refinenet.fit(x=x_train,
              y=y_train,
              epochs=5,
              validation_data=(x_test, y_test),
              callbacks=[tensorboard_callback])
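After fitting, inference on the random test batch should mirror the label shape, assuming RefineNet(27) maps 224x224 RGB inputs to 27 output channels as the shapes above imply:

preds = refinenet.predict(x_test)
print(preds.shape)  # expected (1, 224, 224, 27), matching y_test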
Example #11
    net = CRAFT()
    if args.cuda:
        net.load_state_dict(
            copyStateDict(torch.load(args.trained_model)))
    else:
        net.load_state_dict(
            copyStateDict(torch.load(args.trained_model, map_location='cpu')))

    if args.cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False

    net.eval()

    # LinkRefiner
    refine_net = None
    if args.refine:
        from model.refinenet import RefineNet
        refine_net = RefineNet()
        print('Loading weights of refiner from checkpoint (' +
              args.refiner_model + ')')
        if args.cuda:
            refine_net.load_state_dict(
                copyStateDict(torch.load(args.refiner_model)))
            refine_net = refine_net.cuda()
            refine_net = torch.nn.DataParallel(refine_net)
        else:
            refine_net.load_state_dict(
                copyStateDict(
                    torch.load(args.refiner_model, map_location='cpu')))

        refine_net.eval()
        args.poly = True
Example #12
class CraftDetection:
    def __init__(self):
        self.model = CRAFT()
        if pr.cuda:
            self.model.load_state_dict(
                copyStateDict(torch.load(pr.trained_model)))
            self.model.cuda()
            self.model = torch.nn.DataParallel(self.model)
            cudnn.benchmark = False
        else:
            self.model.load_state_dict(
                copyStateDict(torch.load(pr.trained_model,
                                         map_location='cpu')))
        self.model.eval()

        self.refine_model = None
        if pr.refine:
            self.refine_model = RefineNet()
            if pr.cuda:
                self.refine_model.load_state_dict(
                    copyStateDict(torch.load(pr.refiner_model)))
                self.refine_model = self.refine_model.cuda()
                self.refine_model = torch.nn.DataParallel(self.refine_model)
            else:
                self.refine_model.load_state_dict(
                    copyStateDict(
                        torch.load(pr.refiner_model, map_location='cpu')))

            self.refine_model.eval()
            pr.poly = True

    def text_detect(self, image, have_cmnd=True):
        time0 = time.time()

        # resize
        img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(
            image,
            pr.canvas_size,
            interpolation=cv2.INTER_LINEAR,
            mag_ratio=pr.mag_ratio)
        print(img_resized.shape)
        ratio_h = ratio_w = 1 / target_ratio

        # preprocessing
        x = imgproc.normalizeMeanVariance(img_resized)
        x = torch.from_numpy(x).permute(2, 0, 1)  # [h, w, c] to [c, h, w]
        x = Variable(x.unsqueeze(0))  # [c, h, w] to [b, c, h, w]
        if pr.cuda:
            x = x.cuda()

        # forward pass
        with torch.no_grad():
            y, feature = self.model(x)

        # make score and link map
        score_text = y[0, :, :, 0].cpu().data.numpy()
        score_link = y[0, :, :, 1].cpu().data.numpy()

        # refine link
        if self.refine_model is not None:
            with torch.no_grad():
                y_refiner = self.refine_model(y, feature)
            score_link = y_refiner[0, :, :, 0].cpu().data.numpy()

        # Post-processing
        boxes, polys = craft_utils.getDetBoxes(score_text, score_link,
                                               pr.text_threshold,
                                               pr.link_threshold, pr.low_text,
                                               pr.poly)

        # coordinate adjustment
        boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
        polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
        for k in range(len(polys)):
            if polys[k] is None:
                polys[k] = boxes[k]

        # get box + extend
        list_box = []
        for box in polys:
            [[l1, t1], [r1, t2], [r2, b1], [l2, b2]] = box
            if t1 < t2:
                l, r, t, b = l2, r1, t1, b1
            elif t1 > t2:
                l, r, t, b = l1, r2, t2, b2
            else:
                l, r, t, b = l1, r1, t1, b1

            xmin, ymin, xmax, ymax = l, t, r, b
            w_box = xmax - xmin
            h_box = ymax - ymin
            xmin = max(0, xmin - int(w_box * pr.expand_ratio))
            ymin = max(0, ymin - int(h_box * pr.expand_ratio))
            xmax = xmax + int(w_box * pr.expand_ratio)
            ymax = ymax + int(h_box * pr.expand_ratio)
            list_box.append([xmin, ymin, xmax, ymax])

        # sort line
        dict_cum_sorted = self.sort_line_cmnd(list_box)
        list_box_optim = []
        for cum in dict_cum_sorted:
            for box in cum:
                list_box_optim.append(box)

        # draw box on image
        img_res = image.copy()
        img_res = np.ascontiguousarray(img_res)
        for box in list_box_optim:
            xmin, ymin, xmax, ymax = box
            cv2.rectangle(img_res, (int(xmin), int(ymin)),
                          (int(xmax), int(ymax)), (29, 187, 255), 2, 2)

        # crop image

        result_list_img_cum = []
        image_PIL = Image.fromarray(image)
        for cum in dict_cum_sorted:
            list_img = []
            for box in cum:
                xmin, ymin, xmax, ymax = box
                list_img.append(image_PIL.copy().crop(
                    (xmin, ymin, xmax, ymax)))
            result_list_img_cum.append(list_img)
        return result_list_img_cum, img_res, None

    def sort_line_cmnd(self, boxes):

        if len(boxes) == 0:
            return []
        boxes = sorted(boxes, key=lambda x: x[1])  # sort by ymin
        lines = [[]]

        # y_center = (boxes[0][1] + boxes[0][3]) / 2.0
        y_max_base = boxes[0][3]  # y_max
        i = 0
        for box in boxes:
            # a box joins the current line while its y_min plus half its
            # height still lies above the line's baseline y_max
            if box[1] + 0.5 * abs(box[3] - box[1]) <= y_max_base:
                lines[i].append(box)
            else:
                lines[i] = sorted(lines[i], key=lambda x: x[0])
                # y_center = (box[1] + box[3]) / 2.0
                y_max_base = box[3]
                lines.append([])
                i += 1
                lines[i].append(box)

        temp = []

        for line in lines:
            temp.append(line[0][1])
        index_sort = np.argsort(np.array(temp)).tolist()
        lines_new = [self.remove(lines[i]) for i in index_sort]

        return lines_new

    def remove(self, line):
        line = sorted(line, key=lambda x: x[0])
        result = []
        check_index = -1
        for index in range(len(line)):
            if check_index == index:
                pass
            else:
                result.append(line[index])
                check_index = index
            if index == len(line) - 1:
                break
            if self.compute_iou(line[index], line[index + 1]) > 0.25:
                s1 = (line[index][2] - line[index][0] +
                      1) * (line[index][3] - line[index][1] + 1)
                s2 = (line[index + 1][2] - line[index + 1][0] +
                      1) * (line[index + 1][3] - line[index + 1][1] + 1)
                if s2 > s1:
                    del result[-1]
                    result.append(line[index + 1])
                check_index = index + 1
        result = sorted(result, key=lambda x: x[0])
        return result

    def compute_iou(self, box1, box2):

        x_min_inter = max(box1[0], box2[0])
        y_min_inter = max(box1[1], box2[1])
        x_max_inter = min(box1[2], box2[2])
        y_max_inter = min(box1[3], box2[3])

        inter_area = max(0, x_max_inter - x_min_inter + 1) * max(
            0, y_max_inter - y_min_inter + 1)

        s1 = (box1[2] - box1[0] + 1) * (box1[3] - box1[1] + 1)
        s2 = (box2[2] - box2[0] + 1) * (box2[3] - box2[1] + 1)
        iou = float(inter_area / (s1 + s2 - inter_area))

        return iou

    def sort_line(self, boxes):
        if len(boxes) == 0:
            return []
        boxes = sorted(boxes, key=lambda x: x[1])
        lines = [[]]

        y_center = (boxes[0][1] + boxes[0][3]) / 2.0
        i = 0
        for box in boxes:
            if box[1] < y_center:
                lines[i].append(box)
            else:
                lines[i] = sorted(lines[i], key=lambda x: x[0])
                y_center = (box[1] + box[3]) / 2.0
                lines.append([])
                i += 1
                lines[i].append(box)

        temp = []

        for line in lines:
            temp.append(line[0][1])
        index_sort = np.argsort(np.array(temp)).tolist()
        lines_new = [self.remove(lines[i]) for i in index_sort]

        return lines_new
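The "+ 1" terms in compute_iou treat box coordinates as inclusive pixel indices. A worked check of the arithmetic with two hand-picked boxes:

box1 = [0, 0, 9, 9]    # (9 - 0 + 1) * (9 - 0 + 1) = 100 px
box2 = [5, 0, 14, 9]   # also 100 px
inter = (min(9, 14) - max(0, 5) + 1) * (min(9, 9) - max(0, 0) + 1)  # 5 * 10 = 50
print(inter / (100 + 100 - 50))  # 0.333..., same as compute_iou(box1, box2)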