Ejemplo n.º 1
0
 def from_args(cls, args_path: str):
     """Build an instance of ``cls`` from a JSON arguments file.

     The JSON object at ``args_path`` must provide the keys ``ontology``,
     ``with_other``, ``input_dim``, ``type_dim``, ``bottleneck_dim``,
     ``mention_pooling``, ``dropout_rate``, ``emb_dropout_rate``,
     ``margins``, ``num_negative_samples``, ``threshold_ratio``,
     ``relation_constraint_coef``, ``lift_other``, ``strategies``,
     ``max_branching_factors`` and ``delta``.

     ``with_context`` and ``compute_metric_when_training`` are not read
     from the file; both are always set to ``True``.

     :param args_path: path of the JSON file holding the arguments.
     :return: the object produced by calling ``cls`` with those arguments.
     :raises KeyError: if one of the keys listed above is missing.
     """
     # Use a context manager so the file handle is closed deterministically
     # instead of being left to the garbage collector.
     with open(args_path) as args_file:
         args = json.load(args_file)

     # The same hierarchy object is shared by the model and its decoder.
     hierarchy = Hierarchy.from_tree_file(args["ontology"], with_other=args["with_other"])
     return cls(
         hierarchy=hierarchy,
         input_dim=args["input_dim"],
         type_dim=args["type_dim"],
         bottleneck_dim=args["bottleneck_dim"],
         mention_pooling=args["mention_pooling"],
         with_context=True,
         dropout_rate=args["dropout_rate"],
         emb_dropout_rate=args["emb_dropout_rate"],
         margins_per_level=args["margins"],
         num_negative_samples=args["num_negative_samples"],
         threshold_ratio=args["threshold_ratio"],
         relation_constraint_coef=args["relation_constraint_coef"],
         lift_other=args["lift_other"],
         compute_metric_when_training=True,
         decoder=BeamDecoder(
             hierarchy=hierarchy,
             strategies=args["strategies"],
             max_branching_factors=args["max_branching_factors"],
             delta=args["delta"]
         )
     )
Ejemplo n.º 2
0
    def __init__(self,
                 hierarchy: Hierarchy,
                 input_dim: int,
                 type_dim: int,
                 bottleneck_dim: int,
                 mention_pooling: str,
                 with_context: bool,
                 dropout_rate: float,
                 emb_dropout_rate: float,
                 margins_per_level: List[float],
                 num_negative_samples: int,
                 threshold_ratio: float,
                 relation_constraint_coef: float,
                 lift_other: bool,
                 compute_metric_when_training: bool,
                 decoder: HierarchyDecoder
                 ):
        """Assemble the typer's feature extractor, scorer, losses and metric.

        :param hierarchy: type hierarchy; stored, and also handed to the
            mention feature extractor and the metric. ``hierarchy.size()``
            gives the number of rows of the type embedding table.
        :param input_dim: passed as ``dim`` to the mention feature extractor;
            the scorer's input width is twice this when ``with_context``.
        :param type_dim: width of the type embeddings.
        :param bottleneck_dim: forwarded to ``TypeScorer``.
        :param mention_pooling: forwarded to ``MentionFeatureExtractor``.
        :param with_context: forwarded to ``MentionFeatureExtractor``; also
            selects the scorer input width (see ``input_dim``).
        :param dropout_rate: dropout rate given to ``TypeScorer``.
        :param emb_dropout_rate: dropout rate given to
            ``MentionFeatureExtractor``.
        :param margins_per_level: list of margins; ``0.0`` is prepended and
            the result is given to ``IndexedHingeLoss`` as a float32 tensor.
        :param num_negative_samples: stored on the instance.
        :param threshold_ratio: stored on the instance.
        :param relation_constraint_coef: stored on the instance.
        :param lift_other: stored on the instance.
        :param compute_metric_when_training: stored on the instance.
        :param decoder: stored on the instance as ``self.decoder``.
        """
        super(HierarchicalTyper, self).__init__(vocab=None)

        # Plain configuration values kept for later use.
        self.hierarchy = hierarchy
        self.lift_other = lift_other
        self.num_negative_samples = num_negative_samples
        self.relation_constraint_coef = relation_constraint_coef
        self.threshold_ratio = threshold_ratio

        # Sub-components are created in a fixed order
        # (extractor -> scorer -> losses -> decoder -> metric).
        self.mention_feature_extractor = MentionFeatureExtractor(
            hierarchy=hierarchy,
            dim=input_dim,
            dropout_rate=emb_dropout_rate,
            mention_pooling=mention_pooling,
            with_context=with_context,
        )

        if with_context:
            scorer_input_dim = input_dim * 2
        else:
            scorer_input_dim = input_dim
        type_embedding_table = torch.nn.Embedding(hierarchy.size(), type_dim)
        self.type_scorer = TypeScorer(
            type_embeddings=type_embedding_table,
            input_dim=scorer_input_dim,
            type_dim=type_dim,
            bottleneck_dim=bottleneck_dim,
            dropout_rate=dropout_rate,
        )

        # A leading 0.0 is placed in front of the caller-supplied margins.
        margin_tensor = torch.tensor([0.0] + margins_per_level, dtype=torch.float32)
        self.loss = IndexedHingeLoss(margin_tensor)

        # The relation loss works on the embeddings exposed by the scorer.
        relation_model = ComplEx(self.type_scorer.type_embeddings.embedding_dim)
        self.rel_loss = RelationConstraintLoss(
            self.type_scorer.type_embeddings,
            relation_model,
        )

        self.decoder = decoder
        self.compute_metric_when_training = compute_metric_when_training
        self.metric = HierarchicalMetric(hierarchy)

        # No trainer is attached at construction time.
        self.trainer: MyTrainer = None
        self.current_epoch = 0
Ejemplo n.º 3
0
def main(*,
         model: str,
         model_file: str = "best.th",
         test: str,
         out: str,
         max_branching_factors: List[int],
         delta: List[float],
         strategies: List[str],
         other_delta: float = 0.0,
         seed: int = 0xDEADBEEF,
         batch_size: int = 256,
         gpuid: int = 0):
    """Evaluate a trained ``HierarchicalTyper`` on a test set and print metrics.

    Training-time arguments are read from ``{model}/args.json``; the decoder
    settings given here override the stored ones unless they are ``None``.

    :param model: directory containing the ``args.json`` written at training.
    :param model_file: file name of the saved state dict.
    :param test: test data location, passed to the mention reader.
    :param out: directory handed to the metric via ``set_serialization_dir``.
    :param max_branching_factors: decoder setting; ``None`` falls back to the
        value stored in ``args.json``.
    :param delta: decoder setting; ``None`` falls back to the stored value.
    :param strategies: decoder setting; ``None`` falls back to the stored
        value.
    :param other_delta: passed to ``BeamDecoder`` as ``top_other_delta``.
    :param seed: seed given to ``torch.manual_seed``.
    :param batch_size: batch size of the evaluation iterator.
    :param gpuid: CUDA device index selected with ``torch.cuda.set_device``.
    """
    # Must stay the first statement: it snapshots exactly the call arguments.
    TEST_ARGS = argparse.Namespace(**locals().copy())

    # Close the args file deterministically instead of leaking the handle.
    with open(f"{TEST_ARGS.model}/args.json", mode='r') as args_file:
        ARGS = argparse.Namespace(**json.load(args_file))

    for key, val in vars(ARGS).items():
        print(f"ARG {key}: {val}", file=sys.stderr)
    for key, val in vars(TEST_ARGS).items():
        print(f"TEST_ARG {key}: {val}", file=sys.stderr)

    torch.cuda.set_device(gpuid)
    torch.manual_seed(seed)

    # Decoder settings not supplied at test time default to the training ones.
    if TEST_ARGS.max_branching_factors is None:
        TEST_ARGS.max_branching_factors = ARGS.max_branching_factors
    if TEST_ARGS.delta is None:
        TEST_ARGS.delta = ARGS.delta
    if TEST_ARGS.strategies is None:
        TEST_ARGS.strategies = ARGS.strategies

    hierarchy: Hierarchy = Hierarchy.from_tree_file(filename=ARGS.ontology,
                                                    with_other=ARGS.with_other)

    # From here on ``model`` names the network, not the directory argument
    # (the directory remains available as ``TEST_ARGS.model``).
    model = HierarchicalTyper(
        hierarchy=hierarchy,
        input_dim=ARGS.input_dim,
        type_dim=ARGS.type_dim,
        bottleneck_dim=ARGS.bottleneck_dim,
        mention_pooling=ARGS.mention_pooling,
        with_context=True,
        dropout_rate=ARGS.dropout_rate,
        emb_dropout_rate=ARGS.emb_dropout_rate,
        margins_per_level=ARGS.margins,
        num_negative_samples=ARGS.num_negative_samples,
        threshold_ratio=ARGS.threshold_ratio,
        lift_other=ARGS.lift_other,
        relation_constraint_coef=ARGS.relation_constraint_coef,
        compute_metric_when_training=True,
        decoder=BeamDecoder(
            hierarchy=hierarchy,
            strategies=TEST_ARGS.strategies,
            max_branching_factors=TEST_ARGS.max_branching_factors,
            delta=TEST_ARGS.delta,
            top_other_delta=TEST_ARGS.other_delta))

    # NOTE(review): weights are read from ARGS.out (the output directory
    # recorded in args.json), not from TEST_ARGS.model -- confirm this is
    # intended when the model directory has been moved or copied.
    # The map_location callable keeps loaded storages on the CPU; the model is
    # moved to the GPU afterwards.
    model_state = torch.load(f"{ARGS.out}/{TEST_ARGS.model_file}",
                             map_location=lambda storage, loc: storage)
    model.load_state_dict(model_state)
    model.cuda()
    model.eval()

    model.metric.set_serialization_dir(TEST_ARGS.out)
    print("Model loaded.", file=sys.stderr)

    test_reader = CachedMentionReader(hierarchy=hierarchy,
                                      model=ARGS.contextualizer)
    iterator = BasicIterator(batch_size=TEST_ARGS.batch_size)

    with torch.no_grad():
        for batch in tqdm.tqdm(
                iterator(instances=test_reader.read(TEST_ARGS.test),
                         num_epochs=1,
                         shuffle=False)):
            # Move every entry that supports it onto the GPU; other entries
            # are left untouched.
            for k, v in batch.items():
                if hasattr(v, 'cuda'):
                    batch[k] = v.cuda()
            # The return value is discarded; results are read from
            # model.metric after the loop.
            model(**batch)

    for m, v in model.metric.get_metric(reset=False).items():
        print(f"METRIC {m}: {v}")
Ejemplo n.º 4
0
def main(*,
         ontology: str,
         train: str,
         dev: str,
         out: str,
         contextualizer: str = "elmo-original",
         input_dim: int = 3072,
         type_dim: int = 1024,
         bottleneck_dim: int = 0,
         with_other: bool = False,
         lift_other: bool = False,
         mention_pooling: str = "max",
         emb_dropout_rate: float = 0.3,
         dropout_rate: float = 0.3,
         margins: List[float] = [],
         threshold_ratio: float = 0.1,
         relation_constraint_coef: float = 0.1,
         num_negative_samples: int = 0,
         max_branching_factors: List[int] = [],
         delta: List[float] = [],
         strategies: List[str] = [],
         seed: int = 0xDEADBEEF,
         batch_size: int = 256,
         dev_batch_size: int = 256,
         num_epochs: int = 5,
         dev_metric: str = "+O_MiF",
         patience: int = 4,
         lr: float = 1e-5,
         regularizer: float = 0.1,
         gpuid: int = 0):
    """Train a ``HierarchicalTyper`` and record the run's arguments.

    All call arguments are written as JSON to ``{out}/args.json`` and echoed
    to stderr before training starts.

    :param ontology: ontology file given to ``Hierarchy.from_tree_file``.
    :param train: training data location, read through the mention reader.
    :param dev: validation data location, read through the mention reader.
    :param out: output directory; receives ``args.json`` and is used as the
        trainer's serialization directory.
    :param contextualizer: passed as ``model`` to ``CachedMentionReader``.
    :param margins: passed to the model as ``margins_per_level``.
    :param max_branching_factors: ``BeamDecoder`` setting.
    :param delta: ``BeamDecoder`` setting.
    :param strategies: ``BeamDecoder`` setting.
    :param seed: seed applied to torch, numpy and ``random``.
    :param batch_size: training batch size.
    :param dev_batch_size: validation batch size.
    :param num_epochs: forwarded to the trainer.
    :param dev_metric: forwarded to the trainer as ``validation_metric``.
    :param patience: forwarded to the trainer.
    :param lr: learning rate for ``AdamW``.
    :param regularizer: ``weight_decay`` for ``AdamW``.
    :param gpuid: CUDA device index.

    The remaining parameters (``input_dim``, ``type_dim``, ``bottleneck_dim``,
    ``with_other``, ``lift_other``, ``mention_pooling``, ``emb_dropout_rate``,
    ``dropout_rate``, ``threshold_ratio``, ``relation_constraint_coef``,
    ``num_negative_samples``) are forwarded unchanged to the hierarchy or the
    model constructor.
    """
    # Must stay the first statement so that only the call arguments are
    # captured.
    args = locals().copy()

    with open(f"{out}/args.json", mode='w') as args_out:
        # The original author observed that `locals()` appears to list the
        # arguments in reverse order, hence the reversal for display.
        echo_order = reversed(list(args.items()))
        for k, v in echo_order:
            print(f"{blue('--' + k)} \"{v}\"", file=sys.stderr)
        args_out.write(json.dumps(args, indent=2) + "\n")

    torch.cuda.set_device(gpuid)

    # Seed every RNG in use and pin cuDNN to deterministic behaviour.
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    logging.basicConfig(level=logging.INFO)

    hierarchy: Hierarchy = Hierarchy.from_tree_file(ontology,
                                                    with_other=with_other)
    print(hierarchy, file=sys.stderr)

    reader = CachedMentionReader(hierarchy, model=contextualizer)

    beam_decoder = BeamDecoder(
        hierarchy=hierarchy,
        strategies=strategies,
        max_branching_factors=max_branching_factors,
        delta=delta,
    )
    model = HierarchicalTyper(
        hierarchy=hierarchy,
        input_dim=input_dim,
        type_dim=type_dim,
        bottleneck_dim=bottleneck_dim,
        mention_pooling=mention_pooling,
        with_context=True,
        dropout_rate=dropout_rate,
        emb_dropout_rate=emb_dropout_rate,
        margins_per_level=margins,
        num_negative_samples=num_negative_samples,
        threshold_ratio=threshold_ratio,
        relation_constraint_coef=relation_constraint_coef,
        lift_other=lift_other,
        compute_metric_when_training=True,
        decoder=beam_decoder,
    )
    model.cuda()

    optimizer: Optimizer = AdamW(
        params=model.parameters(),
        lr=lr,
        weight_decay=regularizer,
    )

    train_iterator = BasicIterator(batch_size=batch_size)
    dev_iterator = BasicIterator(batch_size=dev_batch_size)
    train_instances = reader.read(train)
    dev_instances = reader.read(dev)

    trainer = MyTrainer(
        model=model,
        optimizer=optimizer,
        iterator=train_iterator,
        validation_iterator=dev_iterator,
        train_dataset=train_instances,
        validation_dataset=dev_instances,
        validation_metric=dev_metric,
        patience=patience,
        num_epochs=num_epochs,
        grad_norm=1.0,
        serialization_dir=out,
        num_serialized_models_to_keep=1,
        cuda_device=gpuid,
    )

    # Give the model a handle on its trainer and point the metric at the
    # trainer's serialization directory before training starts.
    model.set_trainer(trainer)
    model.metric.set_serialization_dir(trainer._serialization_dir)
    trainer.train()
Ejemplo n.º 5
0
def main(*,
         model: str,
         model_file: str = "best.th",
         test: str,
         out: str,
         max_branching_factors: List[int],
         layers: List[int],
         delta: List[float],
         strategies: List[str],
         other_delta: float = 0.0,
         seed: int = 0xDEADBEEF,
         batch_size: int = 256,
         gpuid: int = 0,
         out_file_name: str,
         cached: bool = True):
    """Run a trained ``HierarchicalTyper`` over a test set and dump predictions.

    Training-time arguments are read from ``{model}/args.json``. One JSON
    object per instance is written to ``out_file_name`` with the keys
    ``ins_types``, ``ins_sent``, ``ins_span`` and ``ins_span_text``.

    :param model: directory containing the ``args.json`` written at training.
    :param model_file: file name of the saved state dict.
    :param test: test data location, passed to the mention reader.
    :param out: directory handed to the metric via ``set_serialization_dir``.
    :param max_branching_factors: decoder setting; ``None`` falls back to the
        value stored in ``args.json``.
    :param layers: forwarded to ``UncachedMentionReader``; only used when
        ``cached`` is false.
    :param delta: decoder setting; ``None`` falls back to the stored value.
    :param strategies: decoder setting; ``None`` falls back to the stored
        value.
    :param other_delta: passed to ``BeamDecoder`` as ``top_other_delta``.
    :param seed: seed given to ``torch.manual_seed``.
    :param batch_size: batch size of the prediction iterator.
    :param gpuid: CUDA device index; only applied when ``cached`` is true.
    :param out_file_name: path of the JSON-lines prediction file to write.
    :param cached: selects ``CachedMentionReader`` (true) or
        ``UncachedMentionReader`` (false).
    """
    # Must stay the first statement: it snapshots exactly the call arguments.
    TEST_ARGS = argparse.Namespace(**locals().copy())

    # Close the args file deterministically instead of leaking the handle.
    with open(f"{TEST_ARGS.model}/args.json", mode='r') as args_file:
        ARGS = argparse.Namespace(**json.load(args_file))

    for key, val in vars(ARGS).items():
        print(f"ARG {key}: {val}", file=sys.stderr)
    for key, val in vars(TEST_ARGS).items():
        print(f"TEST_ARG {key}: {val}", file=sys.stderr)

    # NOTE(review): the device is only selected in cached mode, yet the model
    # is moved to CUDA unconditionally below -- confirm this is intended.
    if cached:
        torch.cuda.set_device(gpuid)
    torch.manual_seed(seed)

    # Decoder settings not supplied at test time default to the training ones.
    if TEST_ARGS.max_branching_factors is None:
        TEST_ARGS.max_branching_factors = ARGS.max_branching_factors
    if TEST_ARGS.delta is None:
        TEST_ARGS.delta = ARGS.delta
    if TEST_ARGS.strategies is None:
        TEST_ARGS.strategies = ARGS.strategies

    hierarchy: Hierarchy = Hierarchy.from_tree_file(filename=ARGS.ontology,
                                                    with_other=ARGS.with_other)

    # From here on ``model`` names the network, not the directory argument
    # (the directory remains available as ``TEST_ARGS.model``).
    model = HierarchicalTyper(
        hierarchy=hierarchy,
        input_dim=ARGS.input_dim,
        type_dim=ARGS.type_dim,
        bottleneck_dim=ARGS.bottleneck_dim,
        mention_pooling=ARGS.mention_pooling,
        with_context=True,
        dropout_rate=ARGS.dropout_rate,
        emb_dropout_rate=ARGS.emb_dropout_rate,
        margins_per_level=ARGS.margins,
        num_negative_samples=ARGS.num_negative_samples,
        threshold_ratio=ARGS.threshold_ratio,
        lift_other=ARGS.lift_other,
        relation_constraint_coef=ARGS.relation_constraint_coef,
        compute_metric_when_training=True,
        decoder=BeamDecoder(
            hierarchy=hierarchy,
            strategies=TEST_ARGS.strategies,
            max_branching_factors=TEST_ARGS.max_branching_factors,
            delta=TEST_ARGS.delta,
            top_other_delta=TEST_ARGS.other_delta))

    # NOTE(review): weights are read from ARGS.out (the output directory
    # recorded in args.json), not from TEST_ARGS.model -- confirm.
    # The map_location callable keeps loaded storages on the CPU; the model is
    # moved to the GPU afterwards.
    model_state = torch.load(f"{ARGS.out}/{TEST_ARGS.model_file}",
                             map_location=lambda storage, loc: storage)
    model.load_state_dict(model_state)
    model.cuda()
    model.eval()

    model.metric.set_serialization_dir(TEST_ARGS.out)
    print("Model loaded.", file=sys.stderr)

    if cached:
        test_reader = CachedMentionReader(hierarchy=hierarchy,
                                          model=ARGS.contextualizer)
    else:
        # NOTE(review): gpuid is hard-coded to 1 here rather than taken from
        # the ``gpuid`` argument -- confirm this is deliberate.
        test_reader = UncachedMentionReader(hierarchy=hierarchy,
                                            model=ARGS.contextualizer,
                                            layers=layers,
                                            gpuid=1)
    iterator = BasicIterator(batch_size=TEST_ARGS.batch_size)

    print("Batch size: ", TEST_ARGS.batch_size)

    print(f"Writing prediction results to: {out_file_name}")

    # Only the first batch is echoed to stdout as a sanity check.
    debug = True

    # The context manager guarantees the result file is flushed and closed
    # even if inference raises part-way through.
    with open(out_file_name, 'w', encoding='utf8') as result_fp, \
            torch.no_grad():
        for batch in tqdm.tqdm(
                iterator(instances=test_reader.read(TEST_ARGS.test),
                         num_epochs=1,
                         shuffle=False)):
            # Move every entry that supports it onto the GPU; other entries
            # are left untouched.
            for k, v in batch.items():
                if hasattr(v, 'cuda'):
                    batch[k] = v.cuda()
            batch_predicted_types = model.predict(
                **batch
            )  # List[Set(int)], length of list: Batch, set of positive type_ids
            sents = batch['sentence_text']
            span_texts = batch['span_text']
            span_lefts = batch['span_left'].tolist()
            span_rights = batch['span_right'].tolist()
            span_lens = batch['span_length'].tolist()

            if debug:
                for sent, left, right, text, length, predicted in zip(
                        sents, span_lefts, span_rights, span_texts,
                        span_lens, batch_predicted_types):
                    print("sents: ")
                    print(sent)
                    print("spans: ")
                    print(f"{left}:{right}")
                    print("span texts: ")
                    print(text)
                    print("span lengths: ")
                    print(length)
                    print("predicted types: ")
                    print(predicted)
                debug = False

            for type_ids, sent, left, right, text in zip(
                    batch_predicted_types, sents, span_lefts, span_rights,
                    span_texts):
                # below in ins_span the '-1' is needed because all span idxs
                # are previously moved right 1 slot to accommodate the [CLS]
                # token at front.
                res_ins = {
                    'ins_types':
                    [hierarchy.type_str(type_id) for type_id in type_ids],
                    'ins_sent': sent,
                    'ins_span': [left - 1, right - 1],
                    'ins_span_text': text
                }
                res_ins_line = json.dumps(res_ins, ensure_ascii=False)
                result_fp.write(res_ins_line + '\n')

    print("Finished!")