Esempio n. 1
0
def main(config: ConfigParser, local_master: bool, logger=None):
    """Build the data pipeline, model and optimizer from ``config`` and run training.

    Args:
        config: Experiment configuration. Objects are instantiated through
            ``config.init_obj`` and the ``distributed`` and ``trainer``
            entries are read directly.
        local_master: When true, this process emits the progress log messages.
        logger: Logger used for the progress messages. When omitted, a
            standard-library logger for this module is used so that a
            ``local_master`` process does not fail on ``None.info``.
    """
    import logging

    if logger is None:
        logger = logging.getLogger(__name__)

    # setup dataset and data_loader instances
    distributed = config['distributed']
    train_dataset = config.init_obj('train_dataset', pick_dataset_module)
    train_sampler = None
    if distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)

    # A sampler and ``shuffle=True`` are mutually exclusive in DataLoader,
    # so shuffling is only requested when no distributed sampler is used.
    is_shuffle = not distributed
    # NOTE(review): batch_size is hard-coded here rather than read from the
    # config -- confirm this is intended.
    train_data_loader = config.init_obj('train_data_loader',
                                        torch.utils.data.dataloader,
                                        dataset=train_dataset,
                                        sampler=train_sampler,
                                        batch_size=8,
                                        shuffle=is_shuffle,
                                        collate_fn=BatchCollateFn())

    val_dataset = config.init_obj('validation_dataset', pick_dataset_module)
    val_data_loader = config.init_obj('val_data_loader',
                                      torch.utils.data.dataloader,
                                      dataset=val_dataset,
                                      collate_fn=BatchCollateFn())
    if local_master:
        logger.info(
            f'Dataloader instances created. '
            f'Train batch size: {train_data_loader.batch_size} '
            f'Val batch size: {val_data_loader.batch_size}.')
        logger.info(f'Train datasets: {len(train_dataset)} samples '
                    f'Validation datasets: {len(val_dataset)} samples.')

    # build model architecture
    pick_model = config.init_obj('model_arch', pick_arch_module)
    if local_master:
        logger.info(
            f'Model created, trainable parameters: {pick_model.model_parameters()}.'
        )

    # build optimizer, learning rate scheduler.
    optimizer = config.init_obj('optimizer', torch.optim,
                                pick_model.parameters())
    lr_scheduler = config.init_obj('lr_scheduler', torch.optim.lr_scheduler,
                                   optimizer)
    if local_master:
        logger.info('Optimizer and lr_scheduler created.')

        # print training related information
        trainer_cfg = config['trainer']
        logger.info(
            'Max_epochs: {} Log_per_step: {} Validation_per_step: {}.'.format(
                trainer_cfg['epochs'],
                trainer_cfg['log_step_interval'],
                trainer_cfg['val_step_interval']))

        logger.info('Training start...')

    trainer = Trainer(pick_model,
                      optimizer,
                      config=config,
                      data_loader=train_data_loader,
                      valid_data_loader=val_data_loader,
                      lr_scheduler=lr_scheduler)

    trainer.train()
    if local_master:
        logger.info('Training end...')
Esempio n. 2
0
def extractKeys(test_box_path, test_img_path, pick_model, args, device):
    """Run the model over a test set and write one ``<image stem>.txt`` per image.

    Each output line has the form ``<entity_name>:   <text>``; when the same
    entity name is predicted several times for an image, its texts are joined
    with single spaces on one line.

    Args:
        test_box_path: Folder passed to the dataset as
            ``boxes_and_transcripts_folder``.
        test_img_path: Folder passed to the dataset as ``images_folder``.
        pick_model: Model called on each batch; its
            ``decoder.crf_layer.viterbi_tags`` is used for decoding.
        args: Object providing ``bs`` (batch size) and ``output_folder``.
        device: Device that the batch values are moved to.
    """
    # setup dataset and data_loader instances
    test_dataset = PICKDataset(boxes_and_transcripts_folder=test_box_path,
                               images_folder=test_img_path,
                               resized_image_size=(480, 960),
                               ignore_error=False,
                               training=False)
    test_data_loader = DataLoader(test_dataset, batch_size=args.bs, shuffle=False,
                                  num_workers=2, collate_fn=BatchCollateFn(training=False))

    # setup output path
    output_path = Path(args.output_folder)
    output_path.mkdir(parents=True, exist_ok=True)

    # predict and save to file
    with torch.no_grad():
        for step_idx, input_data_item in tqdm(enumerate(test_data_loader)):
            for key, input_value in input_data_item.items():
                if input_value is not None:
                    input_data_item[key] = input_value.to(device)

            output = pick_model(**input_data_item)
            logits = output['logits']
            new_mask = output['new_mask']
            image_indexs = input_data_item['image_indexs']  # (B,)
            text_segments = input_data_item['text_segments']  # (B, num_boxes, T)
            mask = input_data_item['mask']
            # List[(List[int], torch.Tensor)]
            best_paths = pick_model.decoder.crf_layer.viterbi_tags(
                logits, mask=new_mask, logits_batch_first=True)
            # Only the tag paths are needed; the scores are discarded.
            predicted_tags = [path for path, _ in best_paths]

            # convert iob index to iob string
            decoded_tags_list = iob_index_to_str(predicted_tags)
            # union text as a sequence and convert index to string
            decoded_texts_list = text_index_to_str(text_segments, mask)

            for decoded_tags, decoded_texts, image_index in zip(
                    decoded_tags_list, decoded_texts_list, image_indexs):
                # List[ Tuple[str, Tuple[int, int]] ]
                spans = bio_tags_to_spans(decoded_tags, [])
                spans = sorted(spans, key=lambda x: x[1][0])

                # Group span texts by entity name (one name may occur several
                # times); dict insertion order keeps first-seen order.
                texts_by_entity = {}
                for entity_name, range_tuple in spans:
                    text = ''.join(
                        decoded_texts[range_tuple[0]:range_tuple[1] + 1])
                    texts_by_entity.setdefault(entity_name, []).append(text)

                result_file = output_path.joinpath(
                    Path(test_dataset.files_list[image_index]).stem + '.txt')
                # Explicit UTF-8 so non-ASCII transcripts do not depend on the
                # platform default encoding; the context manager closes the file.
                with open(result_file, 'w', encoding='utf-8') as fw:
                    for entity_name, texts in texts_by_entity.items():
                        fw.write(entity_name + ":   " + ' '.join(texts))
                        fw.write("\n")
Esempio n. 3
0
def test_model_forward():
    """Smoke-test a forward pass of the configured model on the training dataset.

    Parses the command-line options, builds the model and dataset from the
    resulting config, feeds every batch through the model and prints the
    shapes of the two loss tensors and of the logits.
    """
    # torch.backends.cudnn.benchmark = False
    parser = argparse.ArgumentParser(description='PICK parameters')
    parser.add_argument('-c', '--config', default='../config.json', type=str,
                        help='config file path (default: None)')
    parser.add_argument('-r', '--resume', default=None, type=str,
                        help='path to latest checkpoint (default: None)')
    parser.add_argument('-d', '--device', default='0', type=str,
                        help='indices of GPUs to enable (default: all)')

    CustomArgs = collections.namedtuple(
        'CustomArgs', ['flags', 'default', 'type', 'target', 'help'])
    extra_options = [
        CustomArgs(flags=['--local_world_size'], default=1, type=int,
                   target='local_world_size',
                   help='this is passed in explicitly'),
        CustomArgs(flags=['--local_rank'], default=0, type=int,
                   target='local_rank',
                   help='this is automatically passed in via launch.py'),
    ]
    config = ConfigParser.from_args(parser, extra_options)

    # CUDA is used only when it is available and local_rank is not -1.
    use_cuda = torch.cuda.is_available() and config['local_rank'] != -1
    if use_cuda:
        torch.cuda.set_device(config['local_rank'])
    device = torch.device('cuda' if use_cuda else 'cpu')

    pick_model = config.init_obj('model_arch', pick_arch)
    pick_model.to(device)

    dataset = config.init_obj('train_dataset', pick_dataset)
    loader = DataLoader(dataset, batch_size=2, collate_fn=BatchCollateFn(), num_workers=2)

    output_keys = ('logits', 'new_mask', 'adj', 'gl_loss', 'crf_loss')
    for idx, batch in tqdm(enumerate(loader)):
        # Move every entry of the batch to the target device before the call.
        batch_on_device = {name: value.to(device) for name, value in batch.items()}
        output = pick_model(**batch_on_device)

        # All five outputs are looked up so a missing key fails the test.
        logits, _new_mask, _adj, gl_loss, crf_loss = (
            output[name] for name in output_keys)
        print(gl_loss.shape, crf_loss.shape)
        print(logits.shape)
Esempio n. 4
0
def test_datasets():
    """Smoke-test dataset loading and batching on the bundled example samples.

    Iterates the example training list in batches of three, looks up the
    expected batch entries and prints the shape of ``whole_image``.
    """
    samples_list = Path(__file__).parent.parent / 'data/data_examples_root/train_samples_list.csv'
    dataset = PICKDataset(files_name=samples_list.as_posix(),
                          iob_tagging_type='box_level',
                          resized_image_size=(560, 784))
    loader = DataLoader(dataset, batch_size=3, collate_fn=BatchCollateFn(), num_workers=2)

    # Entries every batch is expected to provide; a missing one raises KeyError.
    # 'entity_types' (B, num_boxes) is intentionally not checked.
    expected_keys = (
        'whole_image',
        'relation_features',
        'text_segments',
        'text_length',
        'iob_tags_label',
        'boxes_coordinate',
        'mask',
    )
    for idx, batch in tqdm(enumerate(loader)):
        fields = {name: batch[name] for name in expected_keys}
        print(fields['whole_image'].shape)
Esempio n. 5
0
def main(args):
    """Run inference from a checkpoint and write entities with their coordinates.

    For every image one ``<image stem>.txt`` is written to
    ``args.output_folder``; each line holds a coordinate entry, the entity
    name and the entity text, separated by tabs. Timing information is
    printed per step and for the whole run.

    Args:
        args: Object providing ``checkpoint``, ``bt`` (boxes and transcripts
            folder), ``impt`` (images folder), ``bs`` (batch size) and
            ``output_folder``.

    Note:
        ``gpu``, ``kie_visualize``, ``rot_out_img_dir``, ``kie_out_txt_dir``
        and ``kie_out_viz_dir`` are not defined in this function; they are
        presumably module-level settings -- verify against the rest of the file.
    """
    device = torch.device(f'cuda:{gpu}' if gpu is not None else 'cpu')
    checkpoint = torch.load(args.checkpoint, map_location=device)

    config = checkpoint['config']
    state_dict = checkpoint['state_dict']
    monitor_best = checkpoint['monitor_best']
    print('Loading checkpoint: {} \nwith saved mEF {:.4f} ...'.format(
        args.checkpoint, monitor_best))

    # prepare model for testing
    pick_model = config.init_obj('model_arch', pick_arch_module)
    pick_model = pick_model.to(device)
    pick_model.load_state_dict(state_dict)
    pick_model.eval()

    # setup dataset and data_loader instances
    test_dataset = PICKDataset(boxes_and_transcripts_folder=args.bt,
                               images_folder=args.impt,
                               resized_image_size=(560, 784),
                               ignore_error=False,
                               training=False,
                               max_boxes_num=130,
                               max_transcript_len=70)
    test_data_loader = DataLoader(test_dataset,
                                  batch_size=args.bs,
                                  shuffle=False,
                                  num_workers=2,
                                  collate_fn=BatchCollateFn(training=False))

    # setup output path
    output_path = Path(args.output_folder)
    output_path.mkdir(parents=True, exist_ok=True)

    # predict and save to file
    now_start = time.time()
    with torch.no_grad():
        for step_idx, input_data_item in enumerate(test_data_loader):
            now = time.time()
            for key, input_value in input_data_item.items():
                if input_value is not None:
                    input_data_item[key] = input_value.to(device)
            output = pick_model(**input_data_item)
            logits = output['logits']
            new_mask = output['new_mask']
            image_indexs = input_data_item['image_indexs']  # (B,)
            text_segments = input_data_item[
                'text_segments']  # (B, num_boxes, T)
            mask = input_data_item['mask']
            text_length = input_data_item['text_length']
            # NOTE(review): only the first sample's boxes are taken, so this
            # looks like it assumes a batch size of 1 -- confirm.
            boxes_coors = input_data_item['boxes_coordinate'].cpu().numpy()[0]
            # List[(List[int], torch.Tensor)]
            best_paths = pick_model.decoder.crf_layer.viterbi_tags(
                logits, mask=new_mask, logits_batch_first=True)
            # Only the tag paths are needed; the scores are discarded.
            predicted_tags = [path for path, _ in best_paths]

            # convert iob index to iob string
            decoded_tags_list = iob_index_to_str(predicted_tags)
            # union text as a sequence and convert index to string
            decoded_texts_list = text_index_to_str(text_segments, mask)

            # Name of the last file written in this step, for the timing
            # message; stays None when the batch produced no items.
            base_filename = None
            for decoded_tags, decoded_texts, image_index in zip(
                    decoded_tags_list, decoded_texts_list, image_indexs):
                # List[ Tuple[str, Tuple[int, int]] ]
                spans, line_pos_from_bottom = bio_tags_to_spans2(
                    decoded_tags,
                    text_length.cpu().numpy())

                entities = []  # exists one to many case
                for entity_name, range_tuple in spans:
                    entity = dict(
                        entity_name=entity_name,
                        text=''.join(
                            decoded_texts[range_tuple[0]:range_tuple[1] + 1]))
                    entities.append(entity)

                result_file = output_path.joinpath(
                    Path(test_dataset.files_list[image_index]).stem + '.txt')
                base_filename = result_file.name
                # Swap only the final suffix; a plain str.replace would also
                # rewrite any '.txt' occurring inside the image stem.
                image_filename = result_file.with_suffix('.jpg').name
                list_coors = get_list_coors_from_line_pos_from_bottom(
                    args.impt, image_filename,
                    boxes_coors, line_pos_from_bottom)
                with result_file.open(mode='w', encoding='utf8') as f:
                    for jdx, item in enumerate(entities):
                        f.write('{}\t{}\t{}\n'.format(list_coors[jdx],
                                                      item['entity_name'],
                                                      item['text']))
            print(step_idx, base_filename, ", inference time:",
                  time.time() - now)
    print('time run program', time.time() - now_start)
    if kie_visualize:
        viz_output_of_pick(img_dir=rot_out_img_dir,
                           output_txt_dir=kie_out_txt_dir,
                           output_viz_dir=kie_out_viz_dir)
Esempio n. 6
0
def main(args):
    """Run inference from a checkpoint and write one entity file per image.

    For every image one ``<image stem>.txt`` is written to
    ``args.output_folder``; each line holds the entity name and the entity
    text separated by a tab.

    Args:
        args: Object providing ``gpu`` (device index, ``-1`` for CPU),
            ``checkpoint``, ``bt`` (boxes and transcripts folder), ``impt``
            (images folder), ``bs`` (batch size) and ``output_folder``.
    """
    device = torch.device(f'cuda:{args.gpu}' if args.gpu != -1 else 'cpu')
    checkpoint = torch.load(args.checkpoint, map_location=device)

    config = checkpoint['config']
    state_dict = checkpoint['state_dict']
    monitor_best = checkpoint['monitor_best']
    print('Loading checkpoint: {} \nwith saved mEF {:.4f} ...'.format(
        args.checkpoint, monitor_best))

    # prepare model for testing
    pick_model = config.init_obj('model_arch', pick_arch_module)
    pick_model = pick_model.to(device)
    pick_model.load_state_dict(state_dict)
    pick_model.eval()

    # setup dataset and data_loader instances
    # NOTE(review): the samples list path is hard-coded to a machine-specific
    # location -- consider making it an argument.
    test_dataset = PICKDataset(
        files_name="/mnt/dick/PICK_dataset/test/test_samples_list.csv",
        boxes_and_transcripts_folder=args.bt,
        iob_tagging_type="box_level",
        images_folder=args.impt,
        resized_image_size=(480, 960),
        ignore_error=False,
        training=False)
    test_data_loader = DataLoader(test_dataset,
                                  batch_size=args.bs,
                                  shuffle=False,
                                  num_workers=2,
                                  collate_fn=BatchCollateFn(training=False))

    # setup output path
    output_path = Path(args.output_folder)
    output_path.mkdir(parents=True, exist_ok=True)

    # predict and save to file
    with torch.no_grad():
        for step_idx, input_data_item in tqdm(enumerate(test_data_loader)):
            # Only tensors are moved to the device; other batch entries
            # (such as the file names) are left untouched.
            for key, input_value in input_data_item.items():
                if input_value is not None and isinstance(
                        input_value, torch.Tensor):
                    input_data_item[key] = input_value.to(device)

            # File names are used to name the outputs.
            image_names = input_data_item["filenames"]

            output = pick_model(**input_data_item)
            logits = output['logits']  # (B, N*T, out_dim)
            new_mask = output['new_mask']
            text_segments = input_data_item[
                'text_segments']  # (B, num_boxes, T)
            mask = input_data_item['mask']
            # List[(List[int], torch.Tensor)]
            best_paths = pick_model.decoder.crf_layer.viterbi_tags(
                logits, mask=new_mask, logits_batch_first=True)
            # Only the tag paths are needed; the scores are discarded.
            predicted_tags = [path for path, _ in best_paths]

            # convert iob index to iob string
            decoded_tags_list = iob_index_to_str(predicted_tags)
            # union text as a sequence and convert index to string
            decoded_texts_list = text_index_to_str(text_segments, mask)

            for decoded_tags, decoded_texts, image_name in zip(
                    decoded_tags_list, decoded_texts_list, image_names):
                # List[ Tuple[str, Tuple[int, int]] ]
                spans = bio_tags_to_spans(decoded_tags, [])
                spans = sorted(spans, key=lambda x: x[1][0])

                entities = []  # exists one to many case
                for entity_name, range_tuple in spans:
                    entity = dict(
                        entity_name=entity_name,
                        text=''.join(
                            decoded_texts[range_tuple[0]:range_tuple[1] + 1]))
                    entities.append(entity)

                # Use the real stem rather than chopping four characters, so
                # extensions of any length (e.g. '.jpeg') are handled.
                filename = Path(image_name).stem

                result_file = output_path.joinpath(filename + '.txt')
                with result_file.open(mode='w', encoding='utf8') as f:
                    for item in entities:
                        f.write('{}\t{}\n'.format(item['entity_name'],
                                                  item['text']))
Esempio n. 7
0
def eval(args):
    """Evaluate a checkpoint on a labelled set and save a confusion-matrix heatmap.

    Ground-truth and predicted tags are compared box by box (using the
    per-box text lengths) and accumulated into a
    ``(num_classes + 1) x (num_classes + 1)`` confusion matrix whose last
    class is 'other'. The heatmap is saved as ``<stem of args.fn>.png`` inside
    ``args.output_folder``.

    Args:
        args: Object providing ``gpu`` (device index, ``-1`` for CPU),
            ``checkpoint``, ``fn`` (samples list file), ``bs`` (batch size)
            and ``output_folder``.
    """
    device = torch.device(f'cuda:{args.gpu}' if args.gpu != -1 else 'cpu')
    checkpoint = torch.load(args.checkpoint, map_location=device)

    config = checkpoint['config']
    state_dict = checkpoint['state_dict']
    monitor_best = checkpoint['monitor_best']
    print('Loading checkpoint: {} \nwith saved mEF {:.4f} ...'.format(
        args.checkpoint, monitor_best))

    # prepare model for eval
    pick_model = config.init_obj('model_arch', pick_arch_module)
    pick_model = pick_model.to(device)
    pick_model.load_state_dict(state_dict)
    pick_model.eval()

    # setup dataset and data_loader instances
    test_dataset = PICKDataset(
        files_name=args.fn,
        boxes_and_transcripts_folder='boxes_and_transcripts',
        images_folder='images',
        entities_folder='entities',
        iob_tagging_type='box_and_within_box_level',
        ignore_error=False,
        training=True)
    test_data_loader = DataLoader(test_dataset,
                                  batch_size=args.bs,
                                  shuffle=False,
                                  num_workers=2,
                                  collate_fn=BatchCollateFn(training=True))

    # setup output path
    output_path = Path(args.output_folder)
    output_path.mkdir(parents=True, exist_ok=True)

    # NOTE(review): the number of entity classes is hard-coded -- confirm it
    # matches the label vocabulary in use.
    num_classes = 3
    confusion_matrix = torch.zeros([num_classes + 1, num_classes + 1])

    # calculate evaluation measure
    with torch.no_grad():
        for step_idx, input_data_item in tqdm(enumerate(test_data_loader)):
            for key, input_value in input_data_item.items():
                if input_value is not None and isinstance(
                        input_value, torch.Tensor):
                    input_data_item[key] = input_value.to(device)

            output = pick_model(**input_data_item)
            logits = output['logits']  # (B, N*T, out_dim)
            new_mask = output['new_mask']

            gt_masks = input_data_item['mask']
            gt_tags = input_data_item['iob_tags_label']
            gt_text_len = input_data_item['text_length']

            best_paths = pick_model.decoder.crf_layer.viterbi_tags(
                logits, mask=new_mask, logits_batch_first=True)
            # Keep only the tag paths, as tensors, so they can be masked and
            # sliced like the ground truth below.
            predicted_tags = [torch.Tensor(path) for path, _ in best_paths]

            B, N, T = gt_tags.shape
            gt_tags = gt_tags.reshape(B, N * T)
            gt_masks = gt_masks.reshape(B, N * T)
            new_gt_tags = torch.zeros_like(gt_tags, dtype=torch.int64)

            # Pack the valid (mask == 1) ground-truth tags of each document
            # to the front of its row.
            for i in range(B):
                doc_x = gt_tags[i]
                doc_mask_x = gt_masks[i]
                valid_doc_x = doc_x[doc_mask_x == 1]
                num_valid = valid_doc_x.size(0)
                new_gt_tags[i, :num_valid] = valid_doc_x

            # Tag indices below num_classes are shifted up by num_classes, in
            # both ground truth and prediction, before the comparison.
            new_gt_tags[new_gt_tags < num_classes] += num_classes
            for i in range(B):
                prev = 0
                predicted_tags[i][
                    predicted_tags[i] < num_classes] += num_classes
                for box_len in gt_text_len[i]:
                    # A box counts as matching when the tag sums over its
                    # span are equal.
                    if torch.sum(
                            new_gt_tags[i][prev:prev + box_len]) == torch.sum(
                                predicted_tags[i][prev:prev + box_len]):
                        if new_gt_tags[i][prev] == 2 * num_classes:
                            confusion_matrix[num_classes][
                                num_classes] += 1  # 'other' entities
                        else:
                            confusion_matrix[
                                new_gt_tags[i][prev] - num_classes][
                                    new_gt_tags[i][prev] -
                                    num_classes] += 1  # labeled entities
                    else:
                        # Otherwise use the most frequent tag in the box as
                        # its class on each side.
                        gt_class = torch.argmax(
                            torch.bincount(new_gt_tags[i][prev:prev +
                                                          box_len].int()))
                        pred_class = torch.argmax(
                            torch.bincount(predicted_tags[i][prev:prev +
                                                             box_len].int()))

                        if gt_class == 2 * num_classes:
                            gt_class = num_classes
                        else:
                            gt_class -= num_classes

                        if pred_class == 2 * num_classes:
                            pred_class = num_classes
                        else:
                            pred_class -= num_classes

                        confusion_matrix[gt_class][pred_class] += 1
                    prev += box_len

    # Columns are reversed for display; the column labels below are reversed
    # to match.
    confusion_matrix = torch.flip(confusion_matrix, [1])
    tag = [
        iob_labels_vocab_cls.itos[x].split('-')[1] for x in range(num_classes)
    ]
    tag.append('other')
    df_cm = pd.DataFrame(confusion_matrix.numpy(),
                         index=list(tag),
                         columns=list(reversed(tag)))
    plt.figure(figsize=(10, 7))
    sn.heatmap(df_cm, annot=True, fmt='g')
    # Name the figure after the samples list's stem so the file always lands
    # inside the output folder, even when args.fn contains directories or
    # additional dots.
    plt.savefig(str(output_path / (Path(args.fn).stem + '.png')))