Example #1
def test_config_create():
    with pytest.raises(ValueError) as excinfo:
        Config(value="test")
    assert "no strategy" in str(excinfo.value).lower()

    with pytest.raises(ValueError) as excinfo:
        Config(strategy="not_available")
    assert "unknown strategy" in str(excinfo.value).lower()

    conf = Config(**TEST_CONFIG)

    assert isinstance(conf.strategy, UnkReplacement)
    assert conf.batch_size == 32
    assert conf.seed == 1234
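
The TEST_CONFIG fixture used by this test (and by Examples #2-#4 below) is not shown in the snippet. A minimal sketch consistent with the assertions, assuming the project maps an "unk" strategy name to UnkReplacement, could be:

# Hypothetical fixture; the exact strategy key depends on the project's registry.
TEST_CONFIG = {
    "strategy": "unk",  # assumed name that resolves to UnkReplacement
    "batch_size": 32,
    "seed": 1234,
}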
Example #2
def test_engine():
    config = Config.from_dict(TEST_CONFIG)
    engine = Engine(config, batcher=batcher)

    input_instance = InputInstance(id_=1,
                                   sent1=["a", "b", "c"],
                                   sent2=["d", "e", "f"])

    occluded_instances, instance_probabilities = engine.run([input_instance])

    assert len(occluded_instances) == 7

    relevances = engine.relevances(occluded_instances, instance_probabilities)
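
The batcher passed to Engine here is defined outside the snippet. As a rough stand-in, assuming the engine expects a callable that maps a batch of occluded instances to per-class probabilities, a dummy version might look like:

# Hypothetical dummy batcher: returns a uniform two-class distribution for every
# occluded instance instead of querying a real model.
def batcher(instances):
    return [[0.5, 0.5] for _ in instances]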
Example #3
def test_config_create_from_json_file():
    path = os.path.join(FIXTURES_ROOT, "test_config.json")
    conf = Config.from_json_file(path)
    assert isinstance(conf.strategy, UnkReplacement)
    assert conf.batch_size == 32
    assert conf.seed == 1234
Example #4
def test_config_create_from_dict():
    conf = Config.from_dict(TEST_CONFIG)
    assert isinstance(conf.strategy, UnkReplacement)
    assert conf.batch_size == 32
    assert conf.seed == 1234
Example #5
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="The input data dir. Should contain the .tsv files for the CoLA task.")
    parser.add_argument("--model_name_or_path",
                        default=None,
                        type=str,
                        required=True,
                        help="Path to pre-trained model or shortcut name.")
    parser.add_argument("--strategy",
                        default=None,
                        type=str,
                        required=True,
                        help="The explainability strategy to use.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the results will be written.")

    # Other parameters
    parser.add_argument("--do_run",
                        action='store_true',
                        help="Whether to run the explainability strategy.")
    parser.add_argument(
        "--do_relevances",
        action='store_true',
        help="Whether to compute relevances from the run results.")
    parser.add_argument(
        "--cache_dir",
        default=None,
        type=str,
        help="The cache dir. Should contain the candidate_instances.pkl file of a strategy.")

    # Optional parameters
    parser.add_argument("--cuda_device",
                        default=0,
                        type=int,
                        help="The default cuda device.")
    parser.add_argument("--overwrite_output_dir",
                        action="store_true",
                        help="Overwrite the content of the output directory")

    args = parser.parse_args()

    if args.strategy.lower() not in ALL_STRATEGIES:
        raise ValueError("Explainability strategy: '{}' not in {}".format(
            args.strategy, ALL_STRATEGIES))

    # if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and not args.overwrite_output_dir:
    #     raise ValueError("Output directory ({}) already exists and is not empty. "
    #                      "Use --overwrite_output_dir to overcome.".format(args.output_dir))

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    dataset = read_cola_dataset(os.path.join(args.data_dir, "dev.tsv"))
    input_instances = dataset_to_input_instances(dataset)
    labels = get_labels(dataset)

    tokenizer = RobertaTokenizer.from_pretrained(args.model_name_or_path)
    model = RobertaForSequenceClassification.from_pretrained(
        args.model_name_or_path).to(args.cuda_device)

    if args.strategy.lower() in GRAD_STRATEGIES:
        config_dict = ROBERTA_GRADIENT_CONFIG
        config = Config.from_dict(config_dict)

        # output_getter extracts the logits (the first entry of the model's return tuple)
        # and applies a softmax to turn them into probabilities
        explainer = {
            "grad": VanillaGradExplainer,
            "gradxinput": GradxInputExplainer,
            "saliency": SaliencyExplainer,
            "integratedgrad": IntegrateGradExplainer,
        }[args.strategy](model=model,
                         input_key="inputs_embeds",
                         output_getter=lambda x: F.softmax(x[0], dim=-1))

        batcher = partial(batcher_gradient,
                          labels=labels,
                          tokenizer=tokenizer,
                          model=model,
                          explainer=explainer,
                          cuda_device=args.cuda_device)
    else:
        config_dict = {
            "unk": ROBERTA_UNK_CONFIG,
            "delete": ROBERTA_DEL_CONFIG,
            "resampling": ROBERTA_RESAMPLING_CONFIG,
            "resampling_std": ROBERTA_RESAMPLING_STD_CONFIG,
        }[args.strategy.lower()]
        config = Config.from_dict(config_dict)

        batcher = partial(batcher_occlusion,
                          labels=labels,
                          tokenizer=tokenizer,
                          model=model,
                          cuda_device=args.cuda_device)

    engine = Engine(config, batcher)

    candidate_results_file = os.path.join(args.output_dir,
                                          "candidate_instances.pkl")

    with open(os.path.join(args.output_dir, "args.json"), "w") as out_f:
        json.dump(vars(args), out_f)

    with open(os.path.join(args.output_dir, "config.json"), "w") as out_f:
        json.dump(config_dict, out_f)

    if args.do_run:
        candidate_instances, candidate_results = engine.run(input_instances)
        with open(candidate_results_file, "wb") as out_f:
            dill.dump((candidate_instances, candidate_results), out_f)

    if args.do_relevances:
        if args.cache_dir is not None:
            candidate_results_file = os.path.join(args.cache_dir,
                                                  "candidate_instances.pkl")

        with open(candidate_results_file, "rb") as in_f:
            candidate_instances, candidate_results = dill.load(in_f)

        relevances = engine.relevances(candidate_instances, candidate_results)

        with open(os.path.join(args.output_dir, "relevances.pkl"),
                  "wb") as out_f:
            dill.dump(relevances, out_f)
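
The pickled outputs written above can be loaded back with dill for later analysis; a small sketch (the path is a placeholder) might be:

import dill

# Load the relevances produced by a previous --do_relevances run.
with open("output/relevances.pkl", "rb") as in_f:
    relevances = dill.load(in_f)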
Example #6
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="The input data dir. Should contain the .tsv files for the MNLI task.")
    parser.add_argument("--model_name_or_path",
                        default=None,
                        type=str,
                        required=True,
                        help="Path to pre-trained model or shortcut name.")
    parser.add_argument("--strategy",
                        default=None,
                        type=str,
                        required=True,
                        help="The explainability strategy to use.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the results will be written.")

    # Other parameters
    parser.add_argument("--do_run",
                        action='store_true',
                        help="Whether to run the explainability strategy.")
    parser.add_argument(
        "--do_relevances",
        action='store_true',
        help="Whether to compute relevances from the run results.")
    parser.add_argument(
        "--cache_dir",
        default=None,
        type=str,
        help="The cache dir. Should contain the candidate_instances.pkl file of a strategy.")

    # Optional parameters
    parser.add_argument("--cuda_device",
                        default=0,
                        type=int,
                        help="The default cuda device.")
    parser.add_argument("--overwrite_output_dir",
                        action="store_true",
                        help="Overwrite the content of the output directory")
    parser.add_argument(
        "--predictor_name",
        default="sst_text_classifier",
        type=str,
        help="The predictor name. Defaults to sst_text_classifier.")

    args = parser.parse_args()

    if args.strategy.lower() not in ALL_STRATEGIES:
        raise ValueError("Explainability strategy: '{}' not in {}".format(
            args.strategy, ALL_STRATEGIES))

    # if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and not args.overwrite_output_dir:
    #     raise ValueError("Output directory ({}) already exists and is not empty. "
    #                      "Use --overwrite_output_dir to overcome.".format(args.output_dir))

    # Disable cuDNN when running on a GPU, because a backward pass cannot be done
    # when the model is not in train mode.
    if args.cuda_device >= 0:
        torch.backends.cudnn.enabled = False

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    archive = load_archive(archive_file=os.path.join(args.model_name_or_path,
                                                     "model.tar.gz"),
                           cuda_device=args.cuda_device)
    predictor = Predictor.from_archive(archive, args.predictor_name)

    dataset = predictor._dataset_reader.read(
        os.path.join(args.data_dir, "dev.tsv"))

    input_instances = dataset_to_input_instances(dataset)
    labels = get_labels(dataset)

    if args.strategy.lower() in GRAD_STRATEGIES:
        config_dict = SST2_GRADIENT_CONFIG
        config = Config.from_dict(config_dict)

        # output_getter extracts the logits from the model's output dict
        # and applies a softmax to turn them into probabilities
        explainer = {
            "grad": AllenNLPVanillaGradExplainer,
            "gradxinput": AllenNLPGradxInputExplainer,
            "saliency": AllenNLPSaliencyExplainer,
            "integratedgrad": AllenNLPIntegrateGradExplainer,
        }[args.strategy](
            predictor=predictor,
            output_getter=lambda x: F.softmax(x["logits"], dim=-1))

        batcher = partial(batcher_gradient,
                          labels=labels,
                          predictor=predictor,
                          explainer=explainer,
                          cuda_device=args.cuda_device)
    else:
        config_dict = {
            "unk": SST2_UNK_CONFIG,
            "delete": SST2_DEL_CONFIG,
            "resampling": SST2_RESAMPLING_CONFIG,
            "resampling_std": SST2_RESAMPLING_STD_CONFIG,
        }[args.strategy.lower()]
        config = Config.from_dict(config_dict)

        batcher = partial(batcher_occlusion,
                          labels=labels,
                          predictor=predictor)

    engine = Engine(config, batcher)

    candidate_results_file = os.path.join(args.output_dir,
                                          "candidate_instances.pkl")

    with open(os.path.join(args.output_dir, "args.json"), "w") as out_f:
        json.dump(vars(args), out_f)

    with open(os.path.join(args.output_dir, "config.json"), "w") as out_f:
        json.dump(config_dict, out_f)

    if args.do_run:
        candidate_instances, candidate_results = engine.run(input_instances)
        with open(candidate_results_file, "wb") as out_f:
            dill.dump((candidate_instances, candidate_results), out_f)

    if args.do_relevances:
        if args.cache_dir is not None:
            candidate_results_file = os.path.join(args.cache_dir,
                                                  "candidate_instances.pkl")

        with open(candidate_results_file, "rb") as in_f:
            candidate_instances, candidate_results = dill.load(in_f)

        relevances = engine.relevances(candidate_instances, candidate_results)

        with open(os.path.join(args.output_dir, "relevances.pkl"),
                  "wb") as out_f:
            dill.dump(relevances, out_f)
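
Both scripts turn raw logits into probabilities inside output_getter via F.softmax; a minimal, self-contained illustration of that transformation:

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, -1.0]])
probs = F.softmax(logits, dim=-1)  # each row of probs now sums to 1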