예제 #1
0
    def forward_eval_layer_with_inputs_helper(self, model, inputs_to_test):
        """Check ``_forward_layer_eval`` on ``ModelInputWrapper`` input layers.

        For every valid way of passing the inputs either positionally or by
        keyword, evaluate the wrapper's per-input layers (``input_maps``) once
        with default settings and once with ``attribute_to_layer_input=True``,
        and assert both results equal each other and the original inputs.

        Args:
            model: model to wrap with ``ModelInputWrapper``.
            inputs_to_test: mapping from input name to the tensor passed for
                it. Its iteration order is used as the positional order, so
                it presumably must match the model's forward signature -- TODO
                confirm with callers.
        """
        n_inputs = len(inputs_to_test)
        # Each pattern assigns 0 (pass positionally) or 1 (pass by keyword) to
        # each input, left to right. Positional arguments must precede keyword
        # ones, so a valid pattern never has a 0 after a 1: it is some zeros
        # followed by ones. E.g. 1 input -> [[0], [1]];
        # 2 inputs -> [[0, 0], [0, 1], [1, 1]]; works for any input count.
        use_args_or_kwargs = [
            [0] * (n_inputs - n_kwargs) + [1] * n_kwargs
            for n_kwargs in range(n_inputs + 1)
        ]

        model = ModelInputWrapper(model)

        # Loop-invariant: names, tensors and wrapper layers, in mapping order.
        input_names = list(inputs_to_test.keys())
        input_values = tuple(inputs_to_test.values())
        input_layers = [model.input_maps[name] for name in input_names]

        def forward_func(*args, args_or_kwargs=None):
            # Re-route each input to *args or **kwargs according to the pattern
            # to exercise ModelInputWrapper's argument-forwarding behavior.
            new_args = []
            new_kwargs = {}
            for use_kwarg, name, inp in zip(args_or_kwargs, input_names, args):
                if use_kwarg:
                    new_kwargs[name] = inp
                else:
                    new_args.append(inp)
            return model(*new_args, **new_kwargs)

        for args_or_kwargs in use_args_or_kwargs:
            with self.subTest(args_or_kwargs=args_or_kwargs):
                func = functools.partial(forward_func,
                                         args_or_kwargs=args_or_kwargs)

                layer_outputs = _forward_layer_eval(
                    func,
                    inputs=input_values,
                    layer=input_layers,
                )
                layer_inputs = _forward_layer_eval(
                    func,
                    inputs=input_values,
                    layer=input_layers,
                    attribute_to_layer_input=True,
                )

                # Each per-layer result is indexed with [0] (presumably the
                # layer's single tensor -- TODO confirm). The input layers
                # should pass their tensor through unchanged.
                for out, inp, expected in zip(layer_outputs, layer_inputs,
                                              input_values):
                    self.assertTrue((out[0] == inp[0]).all())
                    self.assertTrue((out[0] == expected).all())
예제 #2
0
def main(cfg):
    """Run Layer Integrated Gradients on a pretrained TabTransformer.

    Builds the Blastchar dataset and a TabTransformer (with an MLP head), loads
    weights from ``cfg.params.weights``, wraps the model in
    ``ModelInputWrapper``, and attributes the model's own predictions for a
    batch made of the first dataset sample duplicated twice to the model's
    ``embed`` and ``layer_norm`` submodules. Prints the attribution shapes and
    the combined per-column attributions.

    Args:
        cfg: config object read for ``cfg.dataset.path``,
            ``cfg.params.dropout`` and ``cfg.params.weights``.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Initialize the dataset
    blastchar_dataset = BlastcharDataset(cfg.dataset.path)
    NUM_CATEGORICAL_COLS = blastchar_dataset.num_categorical_cols
    NUM_CONTINUOUS_COLS = blastchar_dataset.num_continious_cols
    EMBED_DIM = 32

    # MLP head: its input size is the flattened categorical embeddings
    # concatenated with the continuous columns.
    mlp = nn.Sequential(
        nn.Linear(NUM_CATEGORICAL_COLS * EMBED_DIM + NUM_CONTINUOUS_COLS, 50),
        nn.ReLU(), nn.BatchNorm1d(50), nn.Dropout(cfg.params.dropout),
        nn.Linear(50, 20), nn.ReLU(), nn.BatchNorm1d(20),
        nn.Dropout(cfg.params.dropout),
        nn.Linear(20, blastchar_dataset.num_classes))

    model = TabTransformer(blastchar_dataset.num_categories,
                           mlp,
                           embed_dim=EMBED_DIM,
                           num_cont_cols=NUM_CONTINUOUS_COLS)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    load_result = model.load_state_dict(
        torch.load(cfg.params.weights, map_location=device), strict=False)
    # strict=False tolerates key mismatches; report them so a wrong checkpoint
    # does not silently leave layers with their initial weights.
    if load_result.missing_keys or load_result.unexpected_keys:
        print(f'missing keys: {load_result.missing_keys}, '
              f'unexpected keys: {load_result.unexpected_keys}')
    model = model.to(device)
    model.eval()

    model = ModelInputWrapper(model)

    # Add a batch dimension, cast to the dtypes used below, duplicate the
    # sample into a batch of two, and move the inputs to the model's device.
    cat, cont, _ = blastchar_dataset[0]
    cat, cont = cat.unsqueeze(0).long(), cont.unsqueeze(0).float()
    cat = torch.cat((cat, cat), dim=0).to(device)
    cont = torch.cat((cont, cont), dim=0).to(device)
    inputs = (cat, cont)

    # Predictions only pick the attribution targets; no gradients needed.
    with torch.no_grad():
        preds = model(*inputs).argmax(-1)

    attr = LayerIntegratedGradients(
        model, [model.module.embed, model.module.layer_norm])

    # Baselines are all-zero tensors; zeros_like keeps them on the inputs'
    # device. The convergence delta is requested but not used.
    attributions, _ = attr.attribute(
        inputs=inputs,
        baselines=(torch.zeros_like(cat, dtype=torch.long),
                   torch.zeros_like(cont, dtype=torch.float32)),
        target=preds,
        n_steps=30,
        return_convergence_delta=True)

    print(f'attributions: {attributions[0].shape, attributions[1].shape}')
    # Sum the first layer's attributions over dim 2 (presumably the embedding
    # dimension -- TODO confirm) and join with the second layer's along dim 1.
    pprint(torch.cat((attributions[0].sum(dim=2), attributions[1]), dim=1))
예제 #3
0
    def layer_method_with_input_layer_patches(
        self,
        layer_method_class: Callable,
        equiv_method_class: Callable,
        multi_layer: bool,
    ) -> None:
        """Compare a layer attribution method on input layers to its
        input-level counterpart.

        The model is wrapped in ``ModelInputWrapper``, and the layer method
        attributes to the per-input layers exposed by ``input_maps``. It is
        called once as-is and once with ``attribute_to_layer_input=True``.
        Both results must match each other and the result of the equivalent
        method applied to the same wrapped model.

        Args:
            layer_method_class: layer attribution class, constructed as
                ``cls(model, layer=...)``.
            equiv_method_class: matching input attribution class, constructed
                as ``cls(model)``.
            multi_layer: if True, use the four-input model and attribute to a
                list of four input layers; otherwise use the single-input
                model and a single layer.
        """
        if multi_layer:
            wrapped = ModelInputWrapper(BasicModel_MultiLayer_TrueMultiInput())
            names = ["x1", "x2", "x3", "x4"]
        else:
            wrapped = ModelInputWrapper(BasicModel())
            names = ["input"]

        input_layers = [wrapped.input_maps[name] for name in names]
        target_layer = input_layers if multi_layer else input_layers[0]
        layer_method = layer_method_class(wrapped, layer=target_layer)
        equivalent_method = equiv_method_class(wrapped)

        inputs = tuple(torch.rand(5, 3) for _ in names)
        zero_baselines = tuple(torch.zeros(5, 3) for _ in names)

        # Inspect the undecorated signature (``__wrapped__``) to decide
        # whether baselines can be passed positionally after the inputs.
        accepted = inspect.getfullargspec(
            equivalent_method.attribute.__wrapped__).args
        call_args = ([inputs, zero_baselines]
                     if "baselines" in accepted else [inputs])

        attrs_default = layer_method.attribute(*call_args, target=0)
        attrs_layer_input = layer_method.attribute(
            *call_args, target=0, attribute_to_layer_input=True)
        expected = equivalent_method.attribute(*call_args, target=0)

        # Normalize single-tensor results to 1-tuples for the comparisons.
        if not isinstance(attrs_default, tuple):
            attrs_default, attrs_layer_input = ((attrs_default, ),
                                                (attrs_layer_input, ))
        if not isinstance(expected, tuple):
            expected = (expected, )

        assertTensorTuplesAlmostEqual(self, attrs_default, attrs_layer_input)
        assertTensorTuplesAlmostEqual(self, attrs_default, expected)
예제 #4
0
파일: test_config.py 프로젝트: xvdp/captum
        "name": "basic_layer_ig_multi_layer_multi_output",
        "algorithms": [LayerIntegratedGradients],
        "model": BasicModel_MultiLayer_TrueMultiInput(),
        "layer": ["m1", "m234"],
        "attribute_args": {
            "inputs": (
                torch.randn(5, 3),
                torch.randn(5, 3),
                torch.randn(5, 3),
                torch.randn(5, 3),
            ),
            "target": 0,
        },
    },
    {
        "name": "basic_layer_ig_multi_layer_multi_output_with_input_wrapper",
        "algorithms": [LayerIntegratedGradients],
        "model": ModelInputWrapper(BasicModel_MultiLayer_TrueMultiInput()),
        "layer": ["module.m1", "module.m234"],
        "attribute_args": {
            "inputs": (
                torch.randn(5, 3),
                torch.randn(5, 3),
                torch.randn(5, 3),
                torch.randn(5, 3),
            ),
            "target": 0,
        },
    },
]