コード例 #1
0
def dataset_to_input_instances(dataset: List[Tuple[List[str], List[str], str]]) -> List[InputInstance]:
    """Convert a dataset of sentence pairs into ``InputInstance`` objects.

    Each dataset entry is a 3-tuple whose first two elements are the
    sentences; the third element is ignored. Both sentences are passed
    through ``web_tokenizer`` and the entry's position in ``dataset`` is used
    as the instance id.

    Args:
        dataset: The entries to convert.

    Returns:
        One ``InputInstance`` per entry, in dataset order.
    """
    return [
        InputInstance(id_=position,
                      sent1=web_tokenizer(first),
                      sent2=web_tokenizer(second))
        for position, (first, second, _) in enumerate(dataset)
    ]
コード例 #2
0
ファイル: unk_replace_test.py プロジェクト: DFKI-NLP/OLM
def test_unk_replacement():
    """Check the candidates ``UnkReplacement`` builds for a two-sentence instance."""
    original = InputInstance(id_=1,
                             sent1=["a", "b", "c"],
                             sent2=["d", "e", "f"])
    candidates = UnkReplacement(
        unk_token=UNK_TOKEN).get_candidate_instances(original)

    # Six input tokens yield seven candidates; presumably one per replaced
    # token plus one extra instance -- TODO confirm against UnkReplacement.
    assert len(candidates) == 7

    # The candidate at index 1 must have only the first token of sent1
    # swapped for the unknown token, leaving sent2 untouched.
    first_replaced = candidates[1]
    assert first_replaced.sent1.tokens == [UNK_TOKEN, "b", "c"]
    assert first_replaced.sent2.tokens == original.sent2.tokens
コード例 #3
0
def test_engine():
    """Run the engine on a single two-sentence instance and check its output size."""
    engine = Engine(Config.from_dict(TEST_CONFIG), batcher=batcher)

    instances = [
        InputInstance(id_=1,
                      sent1=["a", "b", "c"],
                      sent2=["d", "e", "f"]),
    ]
    occluded_instances, instance_probabilities = engine.run(instances)

    assert len(occluded_instances) == 7

    # The result is not asserted on in the visible code; the call is kept so
    # that computing relevances is at least exercised without raising.
    engine.relevances(occluded_instances, instance_probabilities)
コード例 #4
0
def test_occluded_instance():
    """Check ``OccludedInstance.from_input_instance`` argument validation and indices."""
    source = InputInstance(id_=1, sent1=SENT_1, sent2=SENT_2)

    # Passing an occlusion token without saying where to put it must fail.
    with pytest.raises(ValueError) as raised:
        OccludedInstance.from_input_instance(source,
                                             occlude_token="occluded",
                                             weight=5)
    expected_message = "'occlude_token' requires setting 'occlude_field_index'"
    assert expected_message in str(raised.value)

    # With the field index supplied, construction succeeds and the occluded
    # position is reported both on the field and on the instance.
    field_index = ("sent1", 1)
    occluded = OccludedInstance.from_input_instance(
        source,
        occlude_token="occluded",
        occlude_field_index=field_index,
        weight=5)

    assert occluded.sent1.occluded_index == 1

    assert occluded.occluded_indices == ("sent1", 1)
コード例 #5
0
def test_input_instance():
    """Check field access on ``InputInstance`` and on an occluded copy of it."""
    source = InputInstance(id_=1, sent1=SENT_1, sent2=SENT_2)

    # The plain instance exposes its sentences both as attributes and through
    # the ``token_fields`` mapping, each wrapped in a ``TokenField``.
    assert source.id == 1
    assert source.sent1.tokens == SENT_1
    assert source.token_fields["sent1"].tokens == SENT_1
    assert source.token_fields["sent2"].tokens == SENT_2
    assert source.sent2.tokens == SENT_2
    assert isinstance(source.sent1, TokenField)
    assert isinstance(source.sent2, TokenField)

    occluded = OccludedInstance.from_input_instance(
        source,
        occlude_token="occluded",
        occlude_field_index=("sent1", 1),
        weight=5)

    # Only sent1 is occluded: it becomes an ``OccludedTokenField`` with the
    # occluded tokens, while sent2 keeps its original tokens.
    assert occluded.id == 1
    assert occluded.sent1.tokens == OCCLUDED_SENT_1
    assert occluded.token_fields["sent1"].tokens == OCCLUDED_SENT_1
    assert occluded.token_fields["sent2"].tokens == SENT_2
    assert occluded.sent2.tokens == SENT_2
    assert isinstance(occluded.sent1, OccludedTokenField)
    assert isinstance(occluded.sent2, TokenField)
コード例 #6
0
def dataset_to_input_instances(dataset: List[Instance]) -> List[InputInstance]:
    """Convert dataset instances into ``InputInstance`` objects.

    For every instance, the ``text`` attribute of each token in its
    ``"tokens"`` field is collected into a list, and the instance's position
    in ``dataset`` is used as the id.

    Args:
        dataset: The instances to convert; each must provide
            ``fields["tokens"].tokens``.

    Returns:
        One ``InputInstance`` per dataset entry, in dataset order.
    """
    converted = []
    for position, entry in enumerate(dataset):
        token_texts = [token.text for token in entry.fields["tokens"].tokens]
        converted.append(InputInstance(id_=position, text=token_texts))
    return converted