def test_multimodal_cand(caplog):
    """Test multimodal candidate generation.

    Parses the radiology HTML fixture into the database, then extracts one
    mention type per structural context (document, caption, section, table,
    figure, paragraph, sentence, cell) with a matcher that accepts everything,
    and checks how many mentions of each type were stored.
    """
    caplog.set_level(logging.INFO)
    PARALLEL = 4
    max_docs = 1
    session = Meta.init("postgresql://localhost:5432/" + DB).Session()

    docs_path = "tests/data/pure_html/radiology.html"

    logger.info("Parsing...")
    preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    parser = Parser(session, structural=True, lingual=True)
    parser.apply(preprocessor, parallelism=PARALLEL)

    assert session.query(Document).count() == max_docs
    assert session.query(Sentence).count() == 35
    documents = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    # Subclasses and mention spaces are created in this fixed order.
    subclass_names = [
        "m_doc", "m_sec", "m_tab", "m_fig", "m_cell", "m_para", "m_cap", "m_sent",
    ]
    subclasses = {name: mention_subclass(name) for name in subclass_names}
    spaces = {
        "m_doc": MentionDocuments(),
        "m_sec": MentionSections(),
        "m_tab": MentionTables(),
        "m_fig": MentionFigures(),
        "m_cell": MentionCells(),
        "m_para": MentionParagraphs(),
        "m_cap": MentionCaptions(),
        "m_sent": MentionSentences(),
    }

    # Order in which the mention types are handed to the extractor, paired
    # with the number of mentions expected for each type.
    expected_counts = {
        "m_doc": 1,
        "m_cap": 2,
        "m_sec": 5,
        "m_tab": 2,
        "m_fig": 2,
        "m_para": 30,
        "m_sent": 35,
        "m_cell": 21,
    }
    ms = [subclasses[name] for name in expected_counts]
    m = [spaces[name] for name in expected_counts]
    # A single matcher instance is shared by all mention types.
    matchers = [DoNothingMatcher()] * len(expected_counts)

    extractor = MentionExtractor(session, ms, m, matchers, parallelism=PARALLEL)
    extractor.apply(documents)

    for name, count in expected_counts.items():
        assert session.query(subclasses[name]).count() == count
def test_multimodal_cand():
    """Test multimodal candidate generation on an in-memory Document.

    Parses the radiology HTML fixture with ``parse_doc`` (no database), runs
    ``MentionExtractorUDF`` with one mention type per structural context and a
    matcher that accepts everything, and checks the per-type mention lists
    attached to the returned document.
    """
    file_name = "radiology"
    docs_path = f"tests/data/pure_html/{file_name}.html"
    doc = parse_doc(docs_path, file_name)
    assert len(doc.sentences) == 35

    # Mention Extraction
    # Subclasses and mention spaces are created in this fixed order.
    subclass_names = [
        "m_doc", "m_sec", "m_tab", "m_fig", "m_cell", "m_para", "m_cap", "m_sent",
    ]
    subclasses = {name: mention_subclass(name) for name in subclass_names}
    spaces = {
        "m_doc": MentionDocuments(),
        "m_sec": MentionSections(),
        "m_tab": MentionTables(),
        "m_fig": MentionFigures(),
        "m_cell": MentionCells(),
        "m_para": MentionParagraphs(),
        "m_cap": MentionCaptions(),
        "m_sent": MentionSentences(),
    }

    # Order in which the mention types are handed to the UDF, paired with the
    # expected number of mentions. The document exposes each type's mentions
    # under the subclass name plus an "s" (e.g. ``doc.m_docs``).
    expected_counts = {
        "m_doc": 1,
        "m_cap": 2,
        "m_sec": 5,
        "m_tab": 2,
        "m_fig": 2,
        "m_para": 30,
        "m_sent": 35,
        "m_cell": 21,
    }
    ms = [subclasses[name] for name in expected_counts]
    m = [spaces[name] for name in expected_counts]
    # A single matcher instance is shared by all mention types.
    matchers = [DoNothingMatcher()] * len(expected_counts)

    udf = MentionExtractorUDF(ms, m, matchers)
    doc = udf.apply(doc)

    for name, count in expected_counts.items():
        assert len(getattr(doc, f"{name}s")) == count
def apply(self, doc: Document) -> Iterator[TemporaryContext]:
    """Generate n-gram, document and figure mentions from a Document.

    The three parent mention spaces are applied one after another:
    ``MentionNgrams.apply`` (n-grams from the document's Sentences), then
    ``MentionDocuments.apply``, then ``MentionFigures.apply``. Every
    ``TemporaryContext`` they yield is passed through unchanged, in that order.

    Because this is a generator, the type check below runs when iteration
    starts, not when ``apply`` is called.

    :param doc: The ``Document`` to parse.
    :type doc: ``Document``
    :raises TypeError: If the input doc is not of type ``Document``.
    :return: An iterator over the generated ``TemporaryContext`` objects.
    """
    if not isinstance(doc, Document):
        raise TypeError(
            "Input Contexts to MentionNgrams.apply() must be of type Document"
        )
    # Each parent implementation is called explicitly, so all three mention
    # spaces contribute their contexts.
    yield from MentionNgrams.apply(self, doc)
    yield from MentionDocuments.apply(self, doc)
    yield from MentionFigures.apply(self, doc)
def test_cand_gen(caplog):
    """Test extracting candidates from mentions from documents.

    Parses one datasheet (HTML + PDF) into the database and extracts
    Part/Temp/Volt/Fig mentions. It then extracts PartTemp/PartVolt candidates
    under three throttler configurations and checks the stored counts.
    Finally it checks that deleting candidates keeps their mentions, while
    deleting mentions removes the candidates built on them.
    """
    caplog.set_level(logging.INFO)

    if platform == "darwin":
        logger.info("Using single core.")
        PARALLEL = 1
    else:
        logger.info("Using two cores.")
        PARALLEL = 2  # Travis only gives 2 cores

    def do_nothing_matcher(fig):
        return True

    max_docs = 1
    session = Meta.init("postgresql://localhost:5432/" + DB).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    # Parsing
    logger.info("Parsing...")
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    corpus_parser = Parser(
        session, structural=True, lingual=True, visual=True, pdf_path=pdf_path
    )
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)
    assert session.query(Document).count() == max_docs
    assert session.query(Sentence).count() == 799
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)
    volt_ngrams = MentionNgramsVolt(n_max=1)
    figs = MentionFigures(types="png")

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    Volt = mention_subclass("Volt")
    Fig = mention_subclass("Fig")

    fig_matcher = LambdaFunctionFigureMatcher(func=do_nothing_matcher)

    # Mismatched arities must be rejected when the extractor is constructed.
    with pytest.raises(ValueError):
        MentionExtractor(
            session,
            [Part, Temp, Volt],
            [part_ngrams, volt_ngrams],  # Fail, mismatched arity
            [part_matcher, temp_matcher, volt_matcher],
        )
    with pytest.raises(ValueError):
        MentionExtractor(
            session,
            [Part, Temp, Volt],
            # Pass valid mention spaces so the only defect is the matcher arity.
            [part_ngrams, temp_ngrams, volt_ngrams],
            [part_matcher, temp_matcher],  # Fail, mismatched arity
        )

    mention_extractor = MentionExtractor(
        session,
        [Part, Temp, Volt, Fig],
        [part_ngrams, temp_ngrams, volt_ngrams, figs],
        [part_matcher, temp_matcher, volt_matcher, fig_matcher],
    )
    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Part).count() == 70
    assert session.query(Volt).count() == 33
    assert session.query(Temp).count() == 23
    assert session.query(Fig).count() == 31
    # .first() fetches only the first row instead of materialising every row.
    part = session.query(Part).order_by(Part.id).first()
    volt = session.query(Volt).order_by(Volt.id).first()
    temp = session.query(Temp).order_by(Temp.id).first()
    logger.info(f"Part: {part.context}")
    logger.info(f"Volt: {volt.context}")
    logger.info(f"Temp: {temp.context}")

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    PartVolt = candidate_subclass("PartVolt", [Part, Volt])

    with pytest.raises(ValueError):
        CandidateExtractor(
            session,
            [PartTemp, PartVolt],
            throttlers=[
                temp_throttler,
                volt_throttler,
                volt_throttler,
            ],  # Fail, mismatched arity
        )
    with pytest.raises(ValueError):
        CandidateExtractor(
            session,
            [PartTemp],  # Fail, mismatched arity
            throttlers=[temp_throttler, volt_throttler],
        )

    # Test that no throttler in candidate extractor
    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt]
    )  # Pass, no throttler
    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    assert session.query(PartTemp).count() == 1610
    assert session.query(PartVolt).count() == 2310
    assert session.query(Candidate).count() == 3920
    candidate_extractor.clear_all(split=0)
    assert session.query(Candidate).count() == 0
    assert session.query(PartTemp).count() == 0
    assert session.query(PartVolt).count() == 0

    # Test with None in throttlers in candidate extractor
    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt], throttlers=[temp_throttler, None]
    )
    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)
    assert session.query(PartTemp).count() == 1432
    assert session.query(PartVolt).count() == 2310
    assert session.query(Candidate).count() == 3742
    candidate_extractor.clear_all(split=0)
    assert session.query(Candidate).count() == 0

    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt], throttlers=[temp_throttler, volt_throttler]
    )
    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    assert session.query(PartTemp).count() == 1432
    assert session.query(PartVolt).count() == 1993
    assert session.query(Candidate).count() == 3425
    assert docs[0].name == "112823"
    assert len(docs[0].parts) == 70
    assert len(docs[0].volts) == 33
    assert len(docs[0].temps) == 23

    # Test that deletion of a Candidate does not delete the Mention
    session.query(PartTemp).delete(synchronize_session="fetch")
    assert session.query(PartTemp).count() == 0
    assert session.query(Temp).count() == 23
    assert session.query(Part).count() == 70

    # Test deletion of Candidate if Mention is deleted
    assert session.query(PartVolt).count() == 1993
    assert session.query(Volt).count() == 33
    session.query(Volt).delete(synchronize_session="fetch")
    assert session.query(Volt).count() == 0
    assert session.query(PartVolt).count() == 0
def test_too_many_clients_error_should_not_happen():
    """Too many clients error should not happen.

    Runs parsing, mention extraction and candidate extraction with a
    parallelism of 32 against the database at ``CONN_STRING``. The test has no
    assertions: it passes as long as none of these steps raises.
    """
    PARALLEL = 32
    # f-string so the actual degree of parallelism is logged.
    logger.info(f"Parallel: {PARALLEL}")

    def do_nothing_matcher(fig):
        return True

    max_docs = 1
    session = Meta.init(CONN_STRING).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    # Parsing
    logger.info("Parsing...")
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    corpus_parser = Parser(
        session, structural=True, lingual=True, visual=True, pdf_path=pdf_path
    )
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)
    volt_ngrams = MentionNgramsVolt(n_max=1)
    figs = MentionFigures(types="png")

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    Volt = mention_subclass("Volt")
    Fig = mention_subclass("Fig")

    fig_matcher = LambdaFunctionFigureMatcher(func=do_nothing_matcher)

    mention_extractor = MentionExtractor(
        session,
        [Part, Temp, Volt, Fig],
        [part_ngrams, temp_ngrams, volt_ngrams, figs],
        [part_matcher, temp_matcher, volt_matcher, fig_matcher],
    )
    mention_extractor.apply(docs, parallelism=PARALLEL)

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    PartVolt = candidate_subclass("PartVolt", [Part, Volt])

    # Test that no throttler in candidate extractor
    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt]
    )  # Pass, no throttler
    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)
    candidate_extractor.clear_all(split=0)

    # Test with None in throttlers in candidate extractor
    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt], throttlers=[temp_throttler, None]
    )
    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)
def main(
    conn_string,
    max_docs=float("inf"),
    parse=False,
    first_time=False,
    gpu=None,
    parallel=4,
    log_dir=None,
    verbose=False,
):
    """Train and evaluate a ResNet-18 thumbnail classifier with MeTaL.

    Steps:

    * Obtain the train/dev/test documents via ``parse_dataset``.
    * Extract one ``Thumbnails`` mention per figure whose image is found
      locally and is larger than 50 px on its shorter side.
    * Label the dev and test candidates from ``data/ground_truth.txt``.
    * Run a random hyper-parameter search for an ``EndModel`` that is trained
      and validated on the dev split.
    * Log the test precision, recall and F1.

    :param conn_string: Database connection string passed to ``Meta.init``.
    :param max_docs: Maximum number of documents, passed to ``parse_dataset``.
    :param parse: Unused; kept for interface compatibility.
    :param first_time: Passed to ``parse_dataset``. When true, mention and
        candidate extraction are also run. Otherwise the candidates already
        stored in the database are read back.
    :param gpu: GPU index. When truthy, training uses ``"cuda"`` on that
        device; otherwise it uses the CPU.
    :param parallel: Degree of parallelism for parsing and extraction.
    :param log_dir: Log directory relative to this file (default ``"logs"``).
    :param verbose: Log at INFO level instead of WARNING.
    """
    if not log_dir:
        log_dir = "logs"

    if verbose:
        level = logging.INFO
    else:
        level = logging.WARNING

    dirname = os.path.dirname(os.path.abspath(__file__))
    init_logging(log_dir=os.path.join(dirname, log_dir), level=level)

    tuner_config = {"max_search": 3}

    em_config = {
        # GENERAL
        "seed": None,
        "verbose": True,
        "show_plots": True,
        # Network
        # The first value is the output dim of the input module (or the sum of
        # the output dims of all the input modules if multitask=True and
        # multiple input modules are provided). The last value is the
        # output dim of the head layer (i.e., the cardinality of the
        # classification task). The remaining values are the output dims of
        # middle layers (if any). The number of middle layers will be inferred
        # from this list.
        # "layer_out_dims": [10, 2],
        # Input layer configs
        "input_layer_config": {
            "input_relu": False,
            "input_batchnorm": False,
            "input_dropout": 0.0,
        },
        # Middle layer configs
        "middle_layer_config": {
            "middle_relu": False,
            "middle_batchnorm": False,
            "middle_dropout": 0.0,
        },
        # Can optionally skip the head layer completely, for e.g. running baseline
        # models...
        "skip_head": True,
        # GPU
        "device": "cpu",
        # MODEL CLASS: "resnet18" (the network itself is built below via
        # get_cnn). This must stay a comment: a bare string literal here would
        # be implicitly concatenated with the next key ("resnet18src").
        # DATA CONFIG
        "src": "gm",
        # TRAINING
        "train_config": {
            # Display
            "print_every": 1,  # Print after this many epochs
            "disable_prog_bar": False,  # Disable progress bar each epoch
            # Dataloader
            "data_loader_config": {
                "batch_size": 32,
                "num_workers": 8,
                "sampler": None,
            },
            # Loss weights
            "loss_weights": [0.5, 0.5],
            # Train Loop
            "n_epochs": 20,
            # 'grad_clip': 0.0,
            "l2": 0.0,
            # "lr": 0.01,
            "validation_metric": "accuracy",
            "validation_freq": 1,  # Evaluate dev for during training every this many epochs
            # Optimizer
            "optimizer_config": {
                "optimizer": "adam",
                "optimizer_common": {"lr": 0.01},
                # Optimizer - SGD
                "sgd_config": {"momentum": 0.9},
                # Optimizer - Adam
                "adam_config": {"betas": (0.9, 0.999)},
            },
            # Scheduler
            "scheduler_config": {
                "scheduler": "reduce_on_plateau",
                # ['constant', 'exponential', 'reduce_on_plateu']
                # Freeze learning rate initially this many epochs
                "lr_freeze": 0,
                # Scheduler - exponential
                "exponential_config": {"gamma": 0.9},  # decay rate
                # Scheduler - reduce_on_plateau
                "plateau_config": {
                    "factor": 0.5,
                    "patience": 1,
                    "threshold": 0.0001,
                    "min_lr": 1e-5,
                },
            },
            # Checkpointer
            "checkpoint": True,
            "checkpoint_config": {
                "checkpoint_min": -1,  # The initial best score to beat to merit checkpointing
                "checkpoint_runway": 0,  # Don't start taking checkpoints until after this many epochs
            },
        },
    }

    session = Meta.init(conn_string).Session()

    # From here on, the relative "data/..." paths resolve against this file's
    # directory.
    os.chdir(os.path.dirname(os.path.abspath(__file__)))
    logger.info(f"CWD: {os.getcwd()}")

    dirname = "."

    docs, train_docs, dev_docs, test_docs = parse_dataset(
        session, dirname, first_time=first_time, parallel=parallel, max_docs=max_docs
    )

    logger.info(f"# of train Documents: {len(train_docs)}")
    logger.info(f"# of dev Documents: {len(dev_docs)}")
    logger.info(f"# of test Documents: {len(test_docs)}")

    logger.info(f"Documents: {session.query(Document).count()}")
    logger.info(f"Sections: {session.query(Section).count()}")
    logger.info(f"Paragraphs: {session.query(Paragraph).count()}")
    logger.info(f"Sentences: {session.query(Sentence).count()}")
    logger.info(f"Figures: {session.query(Figure).count()}")

    Thumbnails = mention_subclass("Thumbnails")

    thumbnails_img = MentionFigures()

    class HasFigures(_Matcher):
        """Accept figures whose image exists locally and whose shorter side exceeds 50 px."""

        def _f(self, m):
            file_path = ""
            # If the image exists under several prefixes, the last one wins.
            for prefix in ["data/train/html/", "data/dev/html/", "data/test/html/"]:
                if os.path.exists(prefix + m.figure.url):
                    file_path = prefix + m.figure.url
            if file_path == "":
                return False
            # Close the file handle right away; this runs once per figure.
            with Image.open(file_path) as img:
                width, height = img.size
            min_value = min(width, height)
            return min_value > 50

    mention_extractor = MentionExtractor(
        session, [Thumbnails], [thumbnails_img], [HasFigures()], parallelism=parallel
    )

    if first_time:
        mention_extractor.apply(docs)

    logger.info("Total Mentions: {}".format(session.query(Mention).count()))

    ThumbnailLabel = candidate_subclass("ThumbnailLabel", [Thumbnails])

    candidate_extractor = CandidateExtractor(
        session, [ThumbnailLabel], throttlers=[None], parallelism=parallel
    )

    if first_time:
        candidate_extractor.apply(train_docs, split=0)
        candidate_extractor.apply(dev_docs, split=1)
        candidate_extractor.apply(test_docs, split=2)

    train_cands = candidate_extractor.get_candidates(split=0)
    dev_cands = candidate_extractor.get_candidates(split=1)
    test_cands = candidate_extractor.get_candidates(split=2)

    logger.info("Total train candidate:\t{}".format(len(train_cands[0])))
    logger.info("Total dev candidate:\t{}".format(len(dev_cands[0])))
    logger.info("Total test candidate:\t{}".format(len(test_cands[0])))

    # Each line's whitespace-separated fields are lower-cased and joined with
    # "::" so they can be compared with the keys built in LF_gt_label.
    with open("data/ground_truth.txt", "r") as fin:
        gt = {"::".join(line.lower().split()) for line in fin}

    def LF_gt_label(c):
        doc_file_id = (
            f"{c[0].context.figure.document.name.lower()}.pdf::"
            f"{os.path.basename(c[0].context.figure.url.lower())}"
        )
        return TRUE if doc_file_id in gt else FALSE

    # The dev split is used for training (one-hot label vectors) and for
    # validation (hard labels: 1.0 for TRUE, 2.0 otherwise).
    gt_dev_pb = []
    gt_dev = []
    for cand in dev_cands[0]:
        if LF_gt_label(cand) == TRUE:
            gt_dev_pb.append([1.0, 0.0])
            gt_dev.append(1.0)
        else:
            gt_dev_pb.append([0.0, 1.0])
            gt_dev.append(2.0)

    gt_test = [LF_gt_label(cand) for cand in test_cands[0]]

    batch_size = 64
    input_size = 224

    train_loader = torch.utils.data.DataLoader(
        ImageList(
            data=dev_cands[0],
            label=torch.Tensor(gt_dev_pb),
            transform=transform(input_size),
            prefix="data/dev/html/",
        ),
        batch_size=batch_size,
        shuffle=False,
    )

    dev_loader = torch.utils.data.DataLoader(
        ImageList(
            data=dev_cands[0],
            label=gt_dev,
            transform=transform(input_size),
            prefix="data/dev/html/",
        ),
        batch_size=batch_size,
        shuffle=False,
    )

    test_loader = torch.utils.data.DataLoader(
        ImageList(
            data=test_cands[0],
            label=gt_test,
            transform=transform(input_size),
            prefix="data/test/html/",
        ),
        batch_size=100,
        shuffle=False,
    )

    search_space = {
        "l2": [0.001, 0.0001, 0.00001],  # linear range
        "lr": {"range": [0.0001, 0.1], "scale": "log"},  # log range
    }

    train_config = em_config["train_config"]

    # Defining network parameters
    num_classes = 2
    # fc_size = 2
    # hidden_size = 2
    pretrained = True

    # Set CUDA device
    # NOTE(review): an integer gpu=0 is falsy and therefore falls back to CPU;
    # confirm how callers pass this value.
    if gpu:
        em_config["device"] = "cuda"
        torch.cuda.set_device(int(gpu))

    # Initializing input module
    input_module = get_cnn("resnet18", pretrained=pretrained, num_classes=num_classes)

    # Initializing model object
    init_args = [[num_classes]]
    init_kwargs = {"input_module": input_module}
    init_kwargs.update(em_config)

    # Searching model
    log_config = {"log_dir": os.path.join(dirname, log_dir), "run_name": "image"}
    searcher = RandomSearchTuner(EndModel, **log_config)

    end_model = searcher.search(
        search_space,
        dev_loader,
        train_args=[train_loader],
        init_args=init_args,
        init_kwargs=init_kwargs,
        train_kwargs=train_config,
        max_search=tuner_config["max_search"],
    )

    # Evaluating model
    scores = end_model.score(
        test_loader,
        metric=["accuracy", "precision", "recall", "f1"],
        break_ties="abstain",
    )
    logger.warning("End Model Score:")
    logger.warning(f"precision: {scores[1]:.3f}")
    logger.warning(f"recall: {scores[2]:.3f}")
    logger.warning(f"f1: {scores[3]:.3f}")
def test_cand_gen():
    """Test extracting candidates from mentions from documents.

    Works on a single parsed ``Document`` without a database:

    * Extract Part/Temp/Volt/Fig mentions with ``MentionExtractorUDF``.
    * Extract PartTemp/PartVolt candidates with ``CandidateExtractorUDF``
      under three throttler configurations.
    * After each configuration, check the candidate counts attached to the
      document.
    """

    def do_nothing_matcher(fig):
        return True

    docs_path = "tests/data/html/112823.html"
    pdf_path = "tests/data/pdf/112823.pdf"

    doc = parse_doc(docs_path, "112823", pdf_path)

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)
    volt_ngrams = MentionNgramsVolt(n_max=1)
    figs = MentionFigures(types="png")

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    Volt = mention_subclass("Volt")
    Fig = mention_subclass("Fig")

    fig_matcher = LambdaFunctionFigureMatcher(func=do_nothing_matcher)

    # Mismatched arities must be rejected at construction time. The string
    # "dummy" stands in for a real session.
    with pytest.raises(ValueError):
        MentionExtractor(
            "dummy",
            [Part, Temp, Volt],
            [part_ngrams, volt_ngrams],  # Fail, mismatched arity
            [part_matcher, temp_matcher, volt_matcher],
        )
    with pytest.raises(ValueError):
        MentionExtractor(
            "dummy",
            [Part, Temp, Volt],
            # Pass valid mention spaces so the only defect is the matcher arity.
            [part_ngrams, temp_ngrams, volt_ngrams],
            [part_matcher, temp_matcher],  # Fail, mismatched arity
        )

    mention_extractor_udf = MentionExtractorUDF(
        [Part, Temp, Volt, Fig],
        [part_ngrams, temp_ngrams, volt_ngrams, figs],
        [part_matcher, temp_matcher, volt_matcher, fig_matcher],
    )
    doc = mention_extractor_udf.apply(doc)

    assert len(doc.parts) == 70
    assert len(doc.volts) == 33
    assert len(doc.temps) == 23
    assert len(doc.figs) == 31
    part = doc.parts[0]
    volt = doc.volts[0]
    temp = doc.temps[0]
    logger.info(f"Part: {part.context}")
    logger.info(f"Volt: {volt.context}")
    logger.info(f"Temp: {temp.context}")

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    PartVolt = candidate_subclass("PartVolt", [Part, Volt])

    with pytest.raises(ValueError):
        CandidateExtractor(
            "dummy",
            [PartTemp, PartVolt],
            throttlers=[
                temp_throttler,
                volt_throttler,
                volt_throttler,
            ],  # Fail, mismatched arity
        )

    with pytest.raises(ValueError):
        CandidateExtractor(
            "dummy",
            [PartTemp],  # Fail, mismatched arity
            throttlers=[temp_throttler, volt_throttler],
        )

    # Test that no throttler in candidate extractor
    # NOTE(review): the three positional booleans after the throttlers are
    # opaque from here; confirm their meaning against CandidateExtractorUDF.
    candidate_extractor_udf = CandidateExtractorUDF(
        [PartTemp, PartVolt], [None, None], False, False, True  # Pass, no throttler
    )

    doc = candidate_extractor_udf.apply(doc, split=0)

    assert len(doc.part_temps) == 1610
    assert len(doc.part_volts) == 2310

    # Clear
    doc.part_temps = []
    doc.part_volts = []

    # Test with None in throttlers in candidate extractor
    candidate_extractor_udf = CandidateExtractorUDF(
        [PartTemp, PartVolt], [temp_throttler, None], False, False, True
    )

    doc = candidate_extractor_udf.apply(doc, split=0)
    assert len(doc.part_temps) == 1432
    assert len(doc.part_volts) == 2310

    # Clear
    doc.part_temps = []
    doc.part_volts = []

    candidate_extractor_udf = CandidateExtractorUDF(
        [PartTemp, PartVolt], [temp_throttler, volt_throttler], False, False, True
    )

    doc = candidate_extractor_udf.apply(doc, split=0)

    assert len(doc.part_temps) == 1432
    assert len(doc.part_volts) == 1993
    assert len(doc.parts) == 70
    assert len(doc.volts) == 33
    assert len(doc.temps) == 23
def main(
    conn_string,
    max_docs=float("inf"),
    parse=False,
    first_time=False,
    gpu=None,
    parallel=4,
    log_dir=None,
    verbose=False,
):
    """Train and evaluate a ResNet-18 thumbnail classifier with Emmental.

    Steps:

    * Obtain the train/dev/test documents via ``parse_dataset``.
    * Extract one ``Thumbnails`` mention per figure whose image is found
      locally and is larger than 50 px on its shorter side.
    * Label the dev and test candidates from ``data/ground_truth.txt``.
    * Train on the dev split (with augmentation) and validate on the dev split.
    * Log the test precision, recall and F1, plus wall-clock timings of each
      stage.

    :param conn_string: Database connection string passed to ``Meta.init``.
    :param max_docs: Maximum number of documents, passed to ``parse_dataset``.
    :param parse: Unused; kept for interface compatibility.
    :param first_time: Passed to ``parse_dataset``. When true, mention and
        candidate extraction are also run. Otherwise the candidates already
        stored in the database are read back.
    :param gpu: Unused in this function; kept for interface compatibility.
    :param parallel: Degree of parallelism for parsing and extraction.
    :param log_dir: Log directory relative to this file (default ``"logs"``).
    :param verbose: Log at INFO level instead of WARNING.
    """
    if not log_dir:
        log_dir = "logs"

    if verbose:
        level = logging.INFO
    else:
        level = logging.WARNING

    dirname = os.path.dirname(os.path.abspath(__file__))
    init_logging(log_dir=os.path.join(dirname, log_dir), level=level)

    session = Meta.init(conn_string).Session()

    # Parsing
    logger.info("Starting parsing...")
    start = timer()
    docs, train_docs, dev_docs, test_docs = parse_dataset(
        session, dirname, first_time=first_time, parallel=parallel, max_docs=max_docs
    )
    end = timer()
    logger.warning(f"Parse Time (min): {((end - start) / 60.0):.1f}")

    logger.info(f"# of train Documents: {len(train_docs)}")
    logger.info(f"# of dev Documents: {len(dev_docs)}")
    logger.info(f"# of test Documents: {len(test_docs)}")

    logger.info(f"Documents: {session.query(Document).count()}")
    logger.info(f"Sections: {session.query(Section).count()}")
    logger.info(f"Paragraphs: {session.query(Paragraph).count()}")
    logger.info(f"Sentences: {session.query(Sentence).count()}")
    logger.info(f"Figures: {session.query(Figure).count()}")

    # The "Candidate Extraction Time" logged below also covers mention extraction.
    start = timer()

    Thumbnails = mention_subclass("Thumbnails")

    thumbnails_img = MentionFigures()

    class HasFigures(_Matcher):
        """Accept figures whose image exists locally and whose shorter side exceeds 50 px."""

        def _f(self, m):
            file_path = ""
            # If the image exists under several prefixes, the last one wins.
            for prefix in [
                f"{dirname}/data/train/html/",
                f"{dirname}/data/dev/html/",
                f"{dirname}/data/test/html/",
            ]:
                if os.path.exists(prefix + m.figure.url):
                    file_path = prefix + m.figure.url
            if file_path == "":
                return False
            # Close the file handle right away; this runs once per figure.
            with Image.open(file_path) as img:
                width, height = img.size
            min_value = min(width, height)
            return min_value > 50

    mention_extractor = MentionExtractor(
        session, [Thumbnails], [thumbnails_img], [HasFigures()], parallelism=parallel
    )

    if first_time:
        mention_extractor.apply(docs)

    logger.info("Total Mentions: {}".format(session.query(Mention).count()))

    ThumbnailLabel = candidate_subclass("ThumbnailLabel", [Thumbnails])

    candidate_extractor = CandidateExtractor(
        session, [ThumbnailLabel], throttlers=[None], parallelism=parallel
    )

    if first_time:
        candidate_extractor.apply(train_docs, split=0)
        candidate_extractor.apply(dev_docs, split=1)
        candidate_extractor.apply(test_docs, split=2)

    train_cands = candidate_extractor.get_candidates(split=0)
    # Sort the dev_cands, which are used for training, for deterministic behavior
    dev_cands = candidate_extractor.get_candidates(split=1, sort=True)
    test_cands = candidate_extractor.get_candidates(split=2)

    end = timer()
    logger.warning(f"Candidate Extraction Time (min): {((end - start) / 60.0):.1f}")

    logger.info("Total train candidate:\t{}".format(len(train_cands[0])))
    logger.info("Total dev candidate:\t{}".format(len(dev_cands[0])))
    logger.info("Total test candidate:\t{}".format(len(test_cands[0])))

    # Each line's whitespace-separated fields are lower-cased and joined with
    # "::" so they can be compared with the keys built in LF_gt_label.
    with open(f"{dirname}/data/ground_truth.txt", "r") as fin:
        gt = {"::".join(line.lower().split()) for line in fin}

    # Labeling
    start = timer()

    def LF_gt_label(c):
        doc_file_id = (
            f"{c[0].context.figure.document.name.lower()}.pdf::"
            f"{os.path.basename(c[0].context.figure.url.lower())}"
        )
        return TRUE if doc_file_id in gt else FALSE

    gt_dev = [LF_gt_label(cand) for cand in dev_cands[0]]
    gt_test = [LF_gt_label(cand) for cand in test_cands[0]]

    end = timer()
    logger.warning(f"Supervision Time (min): {((end - start) / 60.0):.1f}")

    batch_size = 64
    input_size = 224
    K = 2

    emmental.init(log_dir=Meta.log_path, config=emmental_config)

    emmental.Meta.config["learner_config"]["task_scheduler_config"][
        "task_scheduler"
    ] = DauphinScheduler(augment_k=K, enlarge=1)

    # The dev split serves as both the training set and the validation set.
    train_dataset = ThumbnailDataset(
        "Thumbnail",
        dev_cands[0],
        gt_dev,
        "train",
        prob_label=True,
        prefix=f"{dirname}/data/dev/html/",
        input_size=input_size,
        transform_cls=Augmentation(2),
        k=K,
    )

    val_dataset = ThumbnailDataset(
        "Thumbnail",
        dev_cands[0],
        gt_dev,
        "valid",
        prob_label=False,
        prefix=f"{dirname}/data/dev/html/",
        input_size=input_size,
        k=1,
    )

    test_dataset = ThumbnailDataset(
        "Thumbnail",
        test_cands[0],
        gt_test,
        "test",
        prob_label=False,
        prefix=f"{dirname}/data/test/html/",
        input_size=input_size,
        k=1,
    )

    # One loader per split; only the training loader shuffles.
    dataloaders = [
        EmmentalDataLoader(
            task_to_label_dict={"Thumbnail": "labels"},
            dataset=dataset,
            split=split,
            shuffle=shuffle,
            batch_size=batch_size,
            num_workers=1,
        )
        for dataset, split, shuffle in [
            (train_dataset, "train", True),
            (val_dataset, "valid", False),
            (test_dataset, "test", False),
        ]
    ]

    model = EmmentalModel(name="Thumbnail")
    model.add_task(
        create_task("Thumbnail", n_class=2, model="resnet18", pretrained=True)
    )

    emmental_learner = EmmentalLearner()
    emmental_learner.learn(model, dataloaders)

    scores = model.score(dataloaders)

    logger.warning("Model Score:")
    logger.warning(f"precision: {scores['Thumbnail/Thumbnail/test/precision']:.3f}")
    logger.warning(f"recall: {scores['Thumbnail/Thumbnail/test/recall']:.3f}")
    logger.warning(f"f1: {scores['Thumbnail/Thumbnail/test/f1']:.3f}")