Exemplo n.º 1
0
    def __getitem__(self, index):
        """Load the sample at ``index`` and assemble its target tensors.

        The image and its JSON label are located through the id stored in
        ``self.text_source[index]``. An image that is taller than it is wide
        is rotated by 90 degrees (its boxes are rotated with it) before the
        resize/padding step and ``self.transform`` are applied.

        Returns:
            An ``(image, targets)`` tuple. ``targets`` holds the transformed
            image under ``"image"``, the class map under ``"hm"`` and four
            attention maps under ``"a2"`` .. ``"a5"``. When ``self.train`` is
            falsy it additionally holds ``"labels"``: the tokenizer encoding
            of every element of the label text.
        """
        sample_id = self.text_source[index]
        annotation = load_json_file(f'{self.data_dir}/labels/{sample_id}.json')
        text, boxes = annotation['text'], annotation['bboxes']
        image = load_image(f'{self.data_dir}/imgs/{sample_id}.jpg')

        width, height = image.size
        if height > width:
            # Lay vertical text flat.
            boxes = box_util.bbox_rotate90(boxes, image.size)
            image = image.rotate(90, expand=True)

        image, boxes = self.resize_and_padding_image(
            image, boxes=boxes, with_box=True)
        image = self.transform(image)

        cls_map = self.get_cls_map(text, boxes)
        # One attention map per scale factor handed to get_attn_map
        # (presumably the strides of the feature levels -- confirm).
        attn_maps = [
            (key, self.get_attn_map(boxes, scale))
            for key, scale in (("a2", 1 / 4), ("a3", 1 / 8),
                               ('a4', 1 / 16), ('a5', 1 / 32))
        ]

        targets = {"image": image, "hm": torch.LongTensor(cls_map)}
        for key, attn_map in attn_maps:
            targets[key] = torch.FloatTensor(attn_map)

        if not self.train:
            # Encoded labels are only attached outside of training.
            targets['labels'] = [self.tokenizer.encode(item) for item in text]
        return image, targets
Exemplo n.º 2
0
 def __init__(self, opt):
     """Set up the trainer from the parsed options object ``opt``.

     Picks the torch device, creates the logger, then runs the
     ``_build_*`` helpers in a fixed order and finally creates the
     learning-rate scheduler from the JSON file named by
     ``opt.lr_scheduler_config``.
     """
     self.opt = opt

     # CUDA is used only when requested *and* actually available.
     use_gpu = opt.with_cuda and torch.cuda.is_available()
     self.device = torch.device('cuda' if use_gpu else 'cpu')

     self.logger = get_logger(opt.log)
     arch_config = opt.arch_config
     self.arch_config = arch_config

     self._build_model(opt, arch_config)
     self._build_optimizer(opt)
     self._build_criterion(opt, arch_config)
     self._build_converter(opt, arch_config)
     self._build_dataloader(opt)
     self._build_summary_writer(opt)

     # Progress counters, both starting at zero.
     self.client_state = dict(epoch=0, step=0)

     # self.optimizer is presumably set by _build_optimizer above.
     scheduler_config = load_json_file(opt.lr_scheduler_config)
     self.scheduler = get_lr_schedule(self.optimizer, scheduler_config)
Exemplo n.º 3
0
                        default="data/generated")
    # Scan <base_dir>/labels, keep the ids whose image exists and whose label
    # loads, and write them to dataset_ids.txt plus a random train/val split.
    args = parser.parse_args()
    BASE_DIR = args.base_dir
    IMG_DIR = f'{BASE_DIR}/imgs'
    LABEL_DIR = f'{BASE_DIR}/labels'
    # Build the pattern with os.path.join: a hard-coded '\\' separator only
    # matches anything on Windows.
    label_files = glob(os.path.join(LABEL_DIR, '*.*'))

    # The context managers make sure all three id lists are flushed and
    # closed, even when a label fails part-way through the loop.
    with open(f'{BASE_DIR}/dataset_ids.txt', 'w', encoding='utf-8') as full_id_writer, \
            open(f'{BASE_DIR}/train_ids.txt', 'w', encoding='utf-8') as train_id_writer, \
            open(f'{BASE_DIR}/val_ids.txt', 'w', encoding='utf-8') as val_id_writer:
        for label_file in tqdm(label_files):
            # The id is the file name up to its first dot, whatever path
            # separator the platform uses.
            img_id = os.path.basename(label_file).split('.')[0]
            img_path = f'{IMG_DIR}/{img_id}.jpg'
            if not os.path.exists(img_path):
                continue
            label = load_json_file(label_file)
            if label is None:
                continue
            # Touch the required fields so a label missing one of them still
            # fails fast with KeyError.
            _ = (label['text'], label['bboxes'])

            id_line = f'{img_id}\n'
            full_id_writer.write(id_line)
            # Roughly 1% of the samples go to the validation list.
            if random() < 0.01:
                val_id_writer.write(id_line)
            else:
                train_id_writer.write(id_line)
Exemplo n.º 4
0
        default=1e-3,
        help="weight decay (L2 penalty).")

    # Learning-rate finder options. The --find_lr help text used to describe
    # gradient accumulation (copy-paste slip); it now describes the flag.
    parser.add_argument('--find_lr', action='store_true', help='whether to run the learning rate finder instead of training')
    parser.add_argument('--lr_finder_conf', type=str, default='config/lr_find_config.json', help='learning rate finder config file')
    parser.add_argument('--num_iters', type=int, default=4000, help='number of iterations for lr finder')

    opt = parser.parse_args()

    # Create exps/<exp_name>/ and place the log file inside it.
    exp_dir = os.path.join('exps', opt.exp_name)
    opt.exp_dir = exp_dir
    os.makedirs(opt.exp_dir, exist_ok=True)
    opt.log = os.path.join(exp_dir, 'log.txt')
    # Replace the vocab option with whatever load_cafcn_vocab returns for it.
    opt.vocab = load_cafcn_vocab(opt.vocab)
    if opt.use_accum:
        # Divide the batch size by the number of accumulation steps
        # (presumably so the accumulated batch matches the requested size).
        opt.batch_size = int(opt.batch_size / opt.accum_steps)

    apply_seed(opt.seed)
    # NOTE(review): benchmark=True lets cuDNN auto-tune its algorithm choice,
    # which may differ between runs; PyTorch's reproducibility notes suggest
    # benchmark=False when run-to-run determinism matters -- confirm intent.
    cudnn.benchmark = True
    cudnn.deterministic = True

    trainer = CAFCNTrainer(opt)
    # Exactly one mode runs: lr finder, validation only, or training.
    if opt.find_lr:
        lr_finder_conf = load_json_file(opt.lr_finder_conf)
        trainer.find_lr(lr_finder_conf, opt.num_iters)
    elif opt.val:
        loss, score = trainer.validate()
        print(f'Val loss:{loss}, accuracy:{score}')
    else:
        trainer.train()