class Mode():
    def __init__(self, config, is_training):
        self.config = config
        self.is_training = is_training
        use_gpu = config.use_gpu
        self.net = Model(self.config, is_training=self.is_training)
        if self.is_training:
            self.net.train(is_training)
        else:
            self.net.eval()

        self.net.init_weights(gpu=use_gpu)

        if self.is_training:
            self.optimizer = self._get_optimizer()

        if len(self.config.parallels) > 0:
            self.net = nn.DataParallel(self.net)
            if use_gpu:
                self.net = self.net.cuda()

        self.yolo_loss = []
        for i in range(3):
            self.yolo_loss.append(
                YOLOLoss(config.anchors[i], config.image_size,
                         config.num_classes))
        #if is_refine:
        #    self.refine_loss = RefineLoss(config.anchors, config.num_classes, (config.image_size, config.image_size))

        if config.pretrained_weights:
            logging.info("Load pretrained weights from {}".format(
                config.pretrained_weights))
            if use_gpu:
                checkpoint = torch.load(config.pretrained_weights)
            else:
                checkpoint = torch.load(config.pretrained_weights,
                                        map_location=torch.device('cpu'))
            state_dict = checkpoint['state_dict']
            self.net.load_state_dict(state_dict)
            self.epoch = checkpoint["epoch"] + 1
            self.global_step = checkpoint['global step'] + 1
        else:
            self.epoch = 0
            self.global_step = 0

        if config.official_weights:
            logging.info("Loading official weights from {}".format(
                config.official_weights))
            self.net.load_state_dict(torch.load(config.official_weights))
            self.global_step = 20000

        #self.pre_prune_weights()
        #self.prune_weights_in_training_perc()
        self.prune_weights_in_training_thresh()
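
The constructor ends with self.prune_weights_in_training_thresh(), which is not
shown in this snippet. A minimal sketch of what a threshold-based magnitude
pruning pass could look like, assuming it zeroes convolution weights below a
fixed cutoff (the standalone helper name, threshold value, and Conv2d filter
are illustrative, not the original implementation):

import torch
import torch.nn as nn

def prune_by_threshold(net: nn.Module, threshold: float = 1e-3) -> int:
    """Zero every Conv2d weight with |w| < threshold; return the pruned count."""
    pruned = 0
    with torch.no_grad():
        for module in net.modules():
            if isinstance(module, nn.Conv2d):
                mask = module.weight.abs() < threshold
                module.weight[mask] = 0.0
                pruned += int(mask.sum().item())
    return pruned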
Example #2
def test(opt):
    """ model configuration """
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))
    opt.exp_name = '_'.join(opt.saved_model.split('/')[1:])
    # print(model)
    """ keep evaluation model and result logs """
    os.makedirs(f'./result/{opt.exp_name}', exist_ok=True)
    os.system(f'cp {opt.saved_model} ./result/{opt.exp_name}/')
    """ setup loss """
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)
    """ evaluation """
    model.eval()
    with torch.no_grad():
        if opt.benchmark_all_eval:
            benchmark_all_eval(model, criterion, converter, opt)
        else:
            log = open(f'./result/{opt.exp_name}/log_evaluation.txt', 'a')
            AlignCollate_evaluation = AlignCollate(imgH=opt.imgH,
                                                   imgW=opt.imgW,
                                                   keep_ratio_with_pad=opt.PAD)
            eval_data, eval_data_log = hierarchical_dataset(root=opt.eval_data,
                                                            opt=opt)
            evaluation_loader = torch.utils.data.DataLoader(
                eval_data,
                batch_size=opt.batch_size,
                shuffle=False,
                num_workers=int(opt.workers),
                collate_fn=AlignCollate_evaluation,
                pin_memory=True)
            _, accuracy_by_best_model, _, _, _, _, _, _ = validation(
                model, criterion, evaluation_loader, converter, opt)
            log.write(eval_data_log)
            print(f'{accuracy_by_best_model:0.3f}')
            log.write(f'{accuracy_by_best_model:0.3f}\n')
            log.close()
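
converter.decode above maps index sequences back to strings. For the CTC case,
greedy decoding collapses repeated indices and drops the blank symbol; a
self-contained sketch of that rule (a simplified stand-in for
CTCLabelConverter.decode, with an assumed character set and blank at index 0):

import torch

def greedy_ctc_decode(preds_index, characters, blank=0):
    """Collapse repeats, then drop blanks: [0, 1, 1, 0, 1, 2] -> 'aab'."""
    decoded = []
    prev = None
    for idx in preds_index.tolist():
        if idx != prev and idx != blank:
            decoded.append(characters[idx])
        prev = idx
    return ''.join(decoded)

# With characters='-ab' (index 0 standing in for the CTC blank):
print(greedy_ctc_decode(torch.tensor([0, 1, 1, 0, 1, 2]), '-ab'))  # 'aab'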
Example #3
class Mode():
    def __init__(self, config, is_training):
        self.config = config
        self.is_training = is_training
        self.net = Model(self.config, is_training=self.is_training)
        if self.is_training:
            self.net.train(is_training)
        else:
            self.net.eval()
        self.net.init_weights()
        if self.is_training:
            self.optimizer = self._get_optimizer()

        if len(self.config.parallels) > 0:
            self.net = nn.DataParallel(self.net)
            self.net = self.net.cuda()

        self.yolo_loss = []
        for i in range(3):
            self.yolo_loss.append(
                YOLOLoss(config.anchors[i], config.image_size,
                         config.num_classes))
        #if is_refine:
        #    self.refine_loss = RefineLoss(config.anchors, config.num_classes, (config.image_size, config.image_size))

        if config.pretrained_weights:
            logging.info("Load pretrained weights from {}".format(
                config.pretrained_weights))
            checkpoint = torch.load(config.pretrained_weights)
            state_dict = checkpoint['state_dict']
            self.net.load_state_dict(state_dict)
            self.epoch = checkpoint["epoch"] + 1
            self.global_step = checkpoint['global step'] + 1
        else:
            self.epoch = 0
            self.global_step = 0

        if config.official_weights:
            logging.info("Loading official weights from {}".format(
                config.official_weights))
            self.net.load_state_dict(torch.load(config.official_weights))
            self.global_step = 20000

    def _get_optimizer(self):
        optimizer = None

        # Assign different lr for each layer
        params = None
        base_params = list(map(id, self.net.backbone.parameters()))
        logits_params = filter(lambda p: id(p) not in base_params,
                               self.net.parameters())

        if not self.config.freeze_backbone:
            params = [
                {
                    "params": self.net.parameters(),
                    "lr": self.config.learning_rate
                },
            ]
        else:
            logging.info("freeze backbone's parameters.")
            for p in self.net.backbone.parameters():
                p.requires_grad = False
            params = [
                {
                    "params": logits_params,
                    "lr": self.config.learning_rate
                },
            ]

        # Initialize optimizer class
        if self.config.optimizer == "adam":
            optimizer = optim.Adam(params,
                                   weight_decay=self.config.weight_decay)
        elif self.config.optimizer == "amsgrad":
            optimizer = optim.Adam(params,
                                   weight_decay=self.config.weight_decay,
                                   amsgrad=True)
        elif self.config.optimizer == "rmsprop":
            optimizer = optim.RMSprop(params,
                                      weight_decay=self.config.weight_decay)
        else:
            # Default to sgd
            logging.info("Using SGD optimizer.")
            optimizer = optim.SGD(
                params,
                momentum=self.config.momentum,
                weight_decay=self.config.weight_decay,
                nesterov=(self.config.optimizer == "nesterov"))

        return optimizer
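
    # Note: when freeze_backbone is set, only the non-backbone ("logits")
    # parameters above are handed to the optimizer; the backbone tensors get
    # requires_grad=False and therefore receive no updates.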

    def train(self, train_dataloader, val_dataloader):

        # Optimizer
        def adjust_learning_rate(optimizer, config, global_step):
            lr = config.learning_rate
            if global_step < config.burn_in:
                lr = lr * (global_step / config.burn_in) * (global_step /
                                                            config.burn_in)
            elif global_step < config.decay_step[0]:
                pass  # keep the base learning rate in this stage
            elif global_step < config.decay_step[1]:
                lr = config.decay_gamma * lr
            else:
                lr = config.decay_gamma * config.decay_gamma * lr
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            return lr
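
        # Worked example for adjust_learning_rate: with learning_rate=1e-3 and
        # burn_in=1000, global_step=500 gives lr = 1e-3 * 0.5 * 0.5 = 2.5e-4;
        # beyond decay_step[1] the rate becomes decay_gamma**2 * 1e-3.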

        summary = SummaryWriter(self.config.write)

        logging.info("Start training")
        while self.global_step < self.config.max_iter:
            #train step
            train_dataloader.dataset.random_shuffle()
            train_dataloader.dataset.update(self.global_step)
            for step, samples in enumerate(train_dataloader):
                images, labels = samples['image'], samples['label']
                images = images.cuda()
                image_size = images.size(2)
                batch_size = images.size(0)
                start_time = time.time()
                lr = adjust_learning_rate(self.optimizer, self.config,
                                          self.global_step)
                self.optimizer.zero_grad()
                outputs = self.net(images)
                losses_name = [
                    "total_loss", "x", "y", "w", "h", "conf", "cls", "a"
                ]
                losses = []
                for _ in range(len(losses_name)):
                    losses.append([])
                for i in range(3):
                    _loss_item = self.yolo_loss[i](outputs[i], labels,
                                                   self.global_step)
                    for j, l in enumerate(_loss_item):
                        losses[j].append(l)
                losses = [sum(l) for l in losses]
                loss = losses[0]
                loss.backward()
                self.optimizer.step()
                #memory_usage_psutil()
                if step % 10 == 0:
                    _loss = loss.item()
                    time_per_example = float(time.time() -
                                             start_time) / batch_size
                    logging.info(
                        "epoch [%.3d] step = %d size = %d loss = %.2f time/example = %.3f lr = %.5f loss_x = %.3f loss_y = %.3f loss_w = %.3f loss_h = %.3f loss_conf = %.3f loss_cls = %.3f loss_a = %.3f"
                        %
                        (self.epoch, step, image_size, _loss, time_per_example,
                         lr, losses[1], losses[2], losses[3], losses[4],
                         losses[5], losses[6], losses[7]))

                    summary.add_scalar("lr", lr, self.global_step)
                    for i, name in enumerate(losses_name):
                        v = _loss if i == 0 else losses[i]
                        summary.add_scalar(name, v, self.global_step)

                    if step > 0 and step % 1000 == 0:
                        checkpoint_path = os.path.join(self.config.save_dir,
                                                       "model_backup.pth")
                        checkpoint = {
                            'state_dict': self.net.state_dict(),
                            'epoch': self.epoch,
                            "global step": self.global_step
                        }
                        torch.save(checkpoint, checkpoint_path)
                        logging.info("Model checkpoint saved to {}".format(
                            checkpoint_path))

                self.global_step += 1

            checkpoint_path = os.path.join(self.config.save_dir,
                                           "model_{}.pth".format(self.epoch))
            checkpoint = {
                'state_dict': self.net.state_dict(),
                'epoch': self.epoch,
                "global step": self.global_step
            }
            torch.save(checkpoint, checkpoint_path)
            logging.info(
                "Model checkpoint saved to {}".format(checkpoint_path))

            #val every epoch
            logging.info('Start validating after epoch {}'.format(self.epoch))
            val_losses = []
            val_num = len(val_dataloader)
            for step, samples in enumerate(val_dataloader):
                images, labels = samples['image'], samples['label']
                if torch.cuda.is_available():
                    images = images.cuda()
                with torch.no_grad():
                    outputs = self.net(images)
                    losses_name = [
                        "total_loss", "x", "y", "w", "h", "conf", "cls", "a"
                    ]
                    losses = []
                    for _ in range(len(losses_name)):
                        losses.append([])
                    for i in range(3):
                        _loss_item = self.yolo_loss[i](outputs[i], labels)
                        for j, l in enumerate(_loss_item):
                            losses[j].append(l)
                    losses = [sum(l) for l in losses]
                    val_loss = losses[0].item()
                    if step > 0 and step % 10 == 0:
                        logging.info("Having validated [%.3d/%.3d]" %
                                     (step, val_num))
                    val_losses.append(val_loss)
            val_loss = np.mean(np.asarray(val_losses))
            logging.info("val loss = %.2f at epoch [%.3d]" %
                         (val_loss, self.epoch))
            self.epoch += 1

    #def inference(self, inputs):
    #    with torch.no_grad():
    #        outputs = self.net(inputs)
    #        output = self.yolo_loss(outputs)
    #        detections =

    def eval_coco(self, val_dataset):
        with open("coco_index2category.json") as f:
            index2category = json.load(f)
        logging.info('Start evaluating')
        coco_result = []
        coco_img_ids = set([])

        for step, samples in enumerate(val_dataset):
            images, labels = samples['image'], samples['label']
            if torch.cuda.is_available():
                images = images.cuda()
            image_size = images.size(2)
            image_paths = samples['image_path']
            origin_sizes = samples['origin_size']
            with torch.no_grad():
                outputs = self.net(images)
                #output = self.yolo_loss(outputs)
                output_list = []
                for i in range(3):
                    output_list.append(self.yolo_loss[i](outputs[i]))
                output = torch.cat(output_list, 1)
                batch_detections = non_max_suppression(output,
                                                       self.config.num_classes,
                                                       conf_thres=0.001,
                                                       nms_thres=0.45)
            for idx, detections in enumerate(batch_detections):
                image_id = int(os.path.basename(image_paths[idx])[-16:-4])
                coco_img_ids.add(image_id)
                if detections is not None:
                    origin_size = eval(origin_sizes[idx])
                    detections = detections.cpu().numpy()
                    dim_diff = np.abs(origin_size[0] - origin_size[1])
                    pad1, pad2 = dim_diff // 2, dim_diff - dim_diff // 2
                    # undo the letterbox padding added along the shorter side
                    pad = (((pad1, pad2), (0, 0), (0, 0))
                           if origin_size[1] <= origin_size[0] else
                           ((0, 0), (pad1, pad2), (0, 0)))
                    scale = max(origin_size[0], origin_size[1])
                    for x1, y1, x2, y2, conf, cls_conf, cls_pred in detections:
                        x1 = x1 / self.config.image_size * scale
                        x2 = x2 / self.config.image_size * scale
                        y1 = y1 / self.config.image_size * scale
                        y2 = y2 / self.config.image_size * scale
                        x1 -= pad[1][0]
                        y1 -= pad[0][0]
                        x2 -= pad[1][0]
                        y2 -= pad[0][0]
                        w = x2 - x1
                        h = y2 - y1
                        coco_result.append({
                            "image_id":
                            image_id,
                            "category_id":
                            index2category[str(int(cls_pred.item()))],
                            "bbox": (float(x1), float(y1), float(w), float(h)),
                            "score":
                            float(conf),
                        })
            logging.info("Now have finished [%.3d/%.3d]" %
                         (step, len(val_dataset)))
        save_path = "coco_results.json"
        with open(save_path, "w") as f:
            json.dump(coco_result,
                      f,
                      sort_keys=True,
                      indent=4,
                      separators=(',', ':'))
        logging.info('Saved results to {}'.format(save_path))

        logging.info('Using the COCO API to evaluate')
        cocoGt = COCO(self.config.annotation)
        cocoDt = cocoGt.loadRes(save_path)
        cocoEval = COCOeval(cocoGt, cocoDt, "bbox")
        cocoEval.params.imgIds = list(coco_img_ids)
        cocoEval.evaluate()
        cocoEval.accumulate()
        cocoEval.summarize()

    def eval_voc(self, val_dataset, classes, iou_thresh=0.5):
        logging.info('Start evaluating')
        results = {}

        def voc_ap(rec, prec, use_07_metric=False):
            """ ap = voc_ap(rec, prec, [use_07_metric])
            Compute VOC AP given precision and recall.
            If use_07_metric is true, uses the
            VOC 07 11 point method (default:False).
            """
            if use_07_metric:
                # 11 point metric
                ap = 0.
                for t in np.arange(0., 1.1, 0.1):
                    if np.sum(rec >= t) == 0:
                        p = 0
                    else:
                        p = np.max(prec[rec >= t])
                    ap = ap + p / 11.
            else:
                # correct AP calculation
                # first append sentinel values at the end
                mrec = np.concatenate(([0.], rec, [1.]))
                mpre = np.concatenate(([0.], prec, [0.]))

                # compute the precision envelope
                for i in range(mpre.size - 1, 0, -1):
                    mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])

                # to calculate area under PR curve, look for points
                # where X axis (recall) changes value
                i = np.where(mrec[1:] != mrec[:-1])[0]

                # and sum (\Delta recall) * prec
                ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])

            return ap
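
        # Hand-checked example for voc_ap: rec=[0.5, 1.0], prec=[1.0, 0.5]
        # yields the precision envelope mpre=[1.0, 1.0, 0.5, 0.0], so
        # AP = 0.5 * 1.0 + 0.5 * 0.5 = 0.75.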

        def calculate_ap(correct, conf, pred_cls, total, classes):
            correct = np.array(correct)
            conf = np.array(conf)
            pred_cls = np.array(pred_cls)
            index = np.argsort(-conf)
            correct, conf, pred_cls = correct[index], conf[index], pred_cls[index]

            ap = []
            AP = {}
            for i, c in enumerate(classes):
                k = pred_cls == i
                n_gt = total[c]
                n_p = sum(k)

                if n_gt == 0 and n_p == 0:
                    continue
                elif n_p == 0 or n_gt == 0:
                    ap.append(0)
                    AP[c] = 0
                else:
                    fpc = np.cumsum(1 - correct[k])
                    tpc = np.cumsum(correct[k])

                    rec = tpc / n_gt
                    prec = tpc / (tpc + fpc)

                    _ap = voc_ap(rec, prec)
                    ap.append(_ap)
                    AP[c] = _ap
            mAP = np.array(ap).mean()
            return mAP, AP
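
        # Example for calculate_ap: correct=[1, 0, 1], conf=[0.9, 0.8, 0.7],
        # pred_cls=[0, 0, 0] and total={'person': 2} (classes=['person']) give
        # rec=[0.5, 0.5, 1.0], prec=[1.0, 0.5, 0.667], and the AP via voc_ap.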

        def parse_rec(imagename, classes):
            filename = imagename.replace('jpg', 'xml')
            tree = ET.parse(filename)
            objects = []
            for obj in tree.findall('object'):
                difficult = obj.find('difficult').text
                cls = obj.find('name').text
                if cls not in classes or int(difficult) == 1:
                    continue
                cls_id = classes.index(cls)
                xmlbox = obj.find('bndbox')
                obj = [
                    float(xmlbox.find('xmin').text),
                    float(xmlbox.find('xmax').text),
                    float(xmlbox.find('ymin').text),
                    float(xmlbox.find('ymax').text), cls_id
                ]
                objects.append(obj)
            return np.asarray(objects)
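
        # Note: parse_rec returns each box as [xmin, xmax, ymin, ymax, cls_id]
        # (both x coordinates first, then both y), not the more common
        # [xmin, ymin, xmax, ymax] order.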

        total = {}
        for cls in classes:
            total[cls] = 0

        correct = []
        conf_list = []
        pred_list = []
        for step, samples in enumerate(val_dataset):
            images, labels = samples['image'], samples['label']
            if torch.cuda.is_available():
                images = images.cuda()
            image_paths = samples['image_path']
            origin_sizes = samples['origin_size']

            logging.info("Now have finished [%.3d/%.3d]" %
                         (step, len(val_dataset)))
            with torch.no_grad():
                outputs = self.net(images)
                output_list = []
                for i in range(3):
                    output_list.append(self.yolo_loss[i](outputs[i]))
                output = torch.cat(output_list, 1)
                batch_detections = non_max_suppression(output,
                                                       self.config.num_classes,
                                                       conf_thres=0.001,
                                                       nms_thres=0.4)

            for idx, detections in enumerate(batch_detections):
                image_path = image_paths[idx]
                label = labels[idx]
                for t in range(label.size(0)):
                    if label[t, :].sum() == 0:
                        label = label[:t, :]
                        break
                label_cls = np.array(label[:, 0])
                for cls_id in label_cls:
                    total[classes[int(cls_id)]] += 1
                if detections is None:
                    if label.size(0) != 0:
                        label_cls = np.unique(label_cls)
                        for cls_id in label_cls:
                            correct.append(0)
                            conf_list.append(1)
                            pred_list.append(int(cls_id))
                    continue
                if label.size(0) == 0:
                    for *pred_box, conf, cls_conf, cls_pred in detections:
                        correct.append(0)
                        conf_list.append(conf)
                        pred_list.append(int(cls_pred))
                else:
                    detections = detections[np.argsort(-detections[:, 4])]
                    detected = []

                    for *pred_box, conf, cls_conf, cls_pred in detections:
                        pred_box = torch.FloatTensor(pred_box).view(1, -1)
                        pred_box[:, 2:] = pred_box[:, 2:] - pred_box[:, :2]
                        pred_box[:, :2] = pred_box[:, :2] + pred_box[:, 2:] / 2
                        pred_box = pred_box / self.config.image_size
                        ious = bbox_iou(pred_box, label[:, 1:])
                        best_i = np.argmax(ious)
                        if ious[best_i] > iou_thresh and int(cls_pred) == int(
                                label[best_i, 0]) and best_i not in detected:
                            correct.append(1)
                            detected.append(best_i)
                        else:
                            correct.append(0)
                        pred_list.append(int(cls_pred))
                        conf_list.append(float(conf))

        results['correct'] = correct
        results['conf'] = conf_list
        results['pred_cls'] = pred_list
        results['total'] = total
        with open('results.json', 'w') as f:
            json.dump(results, f)
            logging.info('Saved detection results to results.json')

        logging.info('Calculating mAP...')
        with open('results.json', 'r') as result_file:
            results = json.load(result_file)

        mAP, AP_class = calculate_ap(correct=results['correct'],
                                     conf=results['conf'],
                                     pred_cls=results['pred_cls'],
                                     total=results['total'],
                                     classes=classes)
        logging.info('mAP(IoU=0.5):{:.1f}'.format(mAP * 100))

    def inference(self, image, classes, colors):

        image_origin = image
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image,
                           (self.config.image_size, self.config.image_size),
                           interpolation=cv2.INTER_LINEAR)
        image = np.expand_dims(image, 0)
        image = image.astype(np.float32) / 255.0
        image = np.transpose(image, (0, 3, 1, 2))
        image = torch.from_numpy(image)

        start_time = time.time()
        if torch.cuda.is_available():
            image = image.cuda()
        with torch.no_grad():
            outputs = self.net(image)
            output_list = []
            for i in range(3):
                output_list.append(self.yolo_loss[i](outputs[i]))
            output = torch.cat(output_list, 1)
            batch_detections = non_max_suppression(output,
                                                   self.config.num_classes,
                                                   conf_thres=0.5,
                                                   nms_thres=0.4)
            elapsed_time = float(time.time() - start_time)
        detection = batch_detections[0]
        if detection is not None:
            origin_size = image_origin.shape[:2]
            detection = detection.cpu().numpy()
            for x1, y1, x2, y2, conf, cls_conf, cls_pred in detection:
                x1 = int(x1 / self.config.image_size * origin_size[1])
                x2 = int(x2 / self.config.image_size * origin_size[1])
                y1 = int(y1 / self.config.image_size * origin_size[0])
                y2 = int(y2 / self.config.image_size * origin_size[0])
                color = colors[int(cls_pred)]
                image_origin = cv2.rectangle(image_origin, (x1, y1), (x2, y2),
                                             color, 3)
                image_origin = cv2.rectangle(image_origin, (x1, y1),
                                             (x2, y1 + 20),
                                             color,
                                             thickness=-1)
                caption = "{}:{:.2f}".format(classes[int(cls_pred)], cls_conf)
                image_origin = cv2.putText(image_origin, caption,
                                           (x1, y1 + 15),
                                           cv2.FONT_HERSHEY_SIMPLEX, 0.6,
                                           (255, 255, 255), 2)
        return image_origin, elapsed_time
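
A hedged usage sketch for the class above: read an image with OpenCV, run
inference, and save the annotated result. The config object, class names, and
colors below are placeholders for the project's real configuration:

import cv2

detector = Mode(config, is_training=False)   # config must supply image_size,
                                             # num_classes, anchors, parallels, ...
classes = ["person", "car"]                  # assumed class names
colors = [(0, 255, 0), (0, 0, 255)]          # one BGR color per class

image = cv2.imread("example.jpg")            # hypothetical input path
annotated, elapsed = detector.inference(image, classes, colors)
cv2.imwrite("example_out.jpg", annotated)
print("inference took {:.3f}s".format(elapsed))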
Example #4
def train(opt):
    """ Dataset Preparation """
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    log = open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a')
    AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                      imgW=opt.imgW,
                                      keep_ratio_with_pad=opt.PAD)
    valid_dataset, valid_dataset_log = hierarchical_dataset(
        root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=True)
    log.write(valid_dataset_log)
    print('-' * 80)
    log.write('-' * 80 + '\n')
    log.close()
    """ Model Configuration """
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            converter = CTCLabelConverterForBaiduWarpctc(opt.character)
        else:
            converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # batchnorm params: kaiming init fails, fill weight with 1
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))
    print("Model:")
    print(model)
    """ Setup Loss """
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            # need to install warpctc. see our guideline.
            from warpctc_pytorch import CTCLoss
            criterion = CTCLoss()
        else:
            criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(
            device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # keep only the parameters that require gradients (gradient descent updates)
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))
    # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())]

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters,
                               lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters,
                                   lr=opt.lr,
                                   rho=opt.rho,
                                   eps=opt.eps)
    print("Optimizer:")
    print(optimizer)
    """ Final Options """
    # print(opt)
    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)
    """ Start Training """
    start_iter = 0
    if opt.saved_model != '':
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    iteration = start_iter

    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels,
                                        batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            if opt.baiduCTC:
                preds = preds.permute(1, 0, 2)  # to use CTCLoss format
                cost = criterion(preds, text, preds_size, length) / batch_size
            else:
                preds = preds.log_softmax(2).permute(1, 0, 2)
                cost = criterion(preds, text, preds_size, length)

        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]),
                             target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        # validate periodically; also at iteration 0 to see initial training progress
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.exp_name}/log_train.txt',
                      'a') as log:
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                        model, criterion, valid_loader, converter, opt)
                model.train()

                # training loss and validation loss
                loss_log = f'[{iteration + 1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.exp_name}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.exp_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5],
                                                confidence_score[:5]):
                    if 'Attn' in opt.Prediction:
                        gt = gt[:gt.find('[s]')]
                        pred = pred[:pred.find('[s]')]

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save the model every 100,000 iterations
        if (iteration + 1) % 100000 == 0:
            torch.save(
                model.state_dict(),
                f'./saved_models/{opt.exp_name}/iter_{iteration + 1}.pth')

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
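
The CTC branch above reshapes the network output to (time, batch, class)
log-probabilities before calling torch.nn.CTCLoss. A self-contained sketch of
the expected shapes with dummy tensors (all sizes are arbitrary):

import torch

T, N, C = 26, 4, 38                   # time steps, batch size, alphabet size
preds = torch.randn(N, T, C)          # model output: (batch, time, class)
log_probs = preds.log_softmax(2).permute(1, 0, 2)  # CTCLoss wants (T, N, C)

targets = torch.randint(1, C, (N, 10), dtype=torch.long)   # dummy labels, 0 = blank
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.full((N,), 10, dtype=torch.long)

criterion = torch.nn.CTCLoss(zero_infinity=True)
loss = criterion(log_probs, targets, input_lengths, target_lengths)
print(loss.item())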
Example #5
def demo(opt):
    """ Model Configuration """
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))

    AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(
        demo_data, batch_size=opt.batch_size,
        shuffle=False,
        num_workers=0,  # on Linux use int(opt.workers); on Windows keep 0
        collate_fn=AlignCollate_demo, pin_memory=True)

    # predict
    model.eval()
    with torch.no_grad():
        for image_tensors, image_path_list in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)

            if 'CTC' in opt.Prediction:
                preds = model(image, text_for_pred)

                # select max probability (greedy decoding), then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.max(2)
                # preds_index = preds_index.view(-1)
                preds_str = converter.decode(preds_index, preds_size)

            else:
                preds = model(image, text_for_pred, is_train=False)

                # select max probability (greedy decoding), then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

            log = open(f'./log_demo_result.txt', 'a')
            dashed_line = '-' * 80
            head = f'{"image_path":25s}\t{"predicted_labels":25s}\tconfidence score'

            print(f'{dashed_line}\n{head}\n{dashed_line}')
            log.write(f'{dashed_line}\n{head}\n{dashed_line}\n')

            preds_prob = F.softmax(preds, dim=2)
            preds_max_prob, _ = preds_prob.max(dim=2)
            for img_name, pred, pred_max_prob in zip(image_path_list, preds_str, preds_max_prob):
                if 'Attn' in opt.Prediction:
                    pred_EOS = pred.find('[s]')
                    pred = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
                    pred_max_prob = pred_max_prob[:pred_EOS]

                # confidence score = product of the per-step max probabilities
                confidence_score = pred_max_prob.cumprod(dim=0)[-1]

                print(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}')
                log.write(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}\n')

            log.close()
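
The confidence score above is the product of the per-step maximum
probabilities, obtained as cumprod(dim=0)[-1]. A tiny sketch showing the
equivalence on made-up values:

import torch

pred_max_prob = torch.tensor([0.9, 0.8, 0.95, 0.99])  # per-step max probabilities
confidence = pred_max_prob.cumprod(dim=0)[-1]
assert torch.isclose(confidence, pred_max_prob.prod())
print(f'{confidence.item():0.4f}')  # ~0.6772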