def forward_eval_layer_with_inputs_helper(self, model, inputs_to_test):
    # hard coding for simplicity
    # 0 if using args, 1 if using kwargs
    #   => no 0s after first 1 (left to right)
    #
    # used to test utilization of args/kwargs
    use_args_or_kwargs = [
        [[0], [1]],
        [
            [0, 0],
            [0, 1],
            [1, 1],
        ],
    ]

    model = ModelInputWrapper(model)

    def forward_func(*args, args_or_kwargs=None):
        # convert to args or kwargs to test *args and **kwargs wrapping behavior
        new_args = []
        new_kwargs = {}
        for args_or_kwarg, name, inp in zip(
            args_or_kwargs, inputs_to_test.keys(), args
        ):
            if args_or_kwarg:
                new_kwargs[name] = inp
            else:
                new_args.append(inp)
        return model(*new_args, **new_kwargs)

    for args_or_kwargs in use_args_or_kwargs[len(inputs_to_test) - 1]:
        with self.subTest(args_or_kwargs=args_or_kwargs):
            inputs = _forward_layer_eval(
                functools.partial(forward_func, args_or_kwargs=args_or_kwargs),
                inputs=tuple(inputs_to_test.values()),
                layer=[model.input_maps[name] for name in inputs_to_test.keys()],
            )

            inputs_with_attrib_to_inp = _forward_layer_eval(
                functools.partial(forward_func, args_or_kwargs=args_or_kwargs),
                inputs=tuple(inputs_to_test.values()),
                layer=[model.input_maps[name] for name in inputs_to_test.keys()],
                attribute_to_layer_input=True,
            )

            for i1, i2, i3 in zip(
                inputs, inputs_with_attrib_to_inp, inputs_to_test.values()
            ):
                self.assertTrue((i1[0] == i2[0]).all())
                self.assertTrue((i1[0] == i3).all())
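
# A hedged usage sketch, not part of the original test file: one way a test case
# might invoke the helper above. BasicModel and its single forward argument name
# "input" are taken from the tests further below; the tensor shape is illustrative.
def test_forward_eval_single_input_sketch(self):
    self.forward_eval_layer_with_inputs_helper(
        BasicModel(), {"input": torch.rand(5, 3)}
    )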
def main(cfg):
    # Initialize the dataset
    blastchar_dataset = BlastcharDataset(cfg.dataset.path)
    NUM_CATEGORICAL_COLS = blastchar_dataset.num_categorical_cols
    NUM_CONTINIOUS_COLS = blastchar_dataset.num_continious_cols
    EMBED_DIM = 32

    # initialize the model with its arguments
    mlp = nn.Sequential(
        nn.Linear(NUM_CATEGORICAL_COLS * EMBED_DIM + NUM_CONTINIOUS_COLS, 50),
        nn.ReLU(),
        nn.BatchNorm1d(50),
        nn.Dropout(cfg.params.dropout),
        nn.Linear(50, 20),
        nn.ReLU(),
        nn.BatchNorm1d(20),
        nn.Dropout(cfg.params.dropout),
        nn.Linear(20, blastchar_dataset.num_classes),
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = TabTransformer(
        blastchar_dataset.num_categories,
        mlp,
        embed_dim=EMBED_DIM,
        num_cont_cols=NUM_CONTINIOUS_COLS,
    )
    model.load_state_dict(torch.load(cfg.params.weights), strict=False)
    model = model.to(device)
    model.eval()

    # wrap the model so each forward argument gets its own identity layer
    model = ModelInputWrapper(model)

    cat, cont, _ = blastchar_dataset[0]
    cat, cont = cat.unsqueeze(0).long(), cont.unsqueeze(0).float()
    cat = torch.cat((cat, cat), dim=0)
    cont = torch.cat((cont, cont), dim=0)
    # move the inputs to the same device as the model
    cat, cont = cat.to(device), cont.to(device)
    input = (cat, cont)

    outs = model(*input)
    preds = outs.argmax(-1)

    attr = LayerIntegratedGradients(
        model, [model.module.embed, model.module.layer_norm]
    )
    attributions, _ = attr.attribute(
        inputs=(cat, cont),
        baselines=(
            torch.zeros_like(cat, dtype=torch.long),
            torch.zeros_like(cont, dtype=torch.float32),
        ),
        target=preds.detach(),
        n_steps=30,
        return_convergence_delta=True,
    )

    print(f"attributions: {attributions[0].shape, attributions[1].shape}")
    pprint(torch.cat((attributions[0].sum(dim=2), attributions[1]), dim=1))
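
# A minimal invocation sketch (an assumption, not part of the original script):
# `cfg` looks like an OmegaConf/Hydra-style config object, so one way to drive
# main() directly is to build the config in code. The paths and values below are
# placeholders; only the field names mirror those accessed in main() above.
if __name__ == "__main__":
    from omegaconf import OmegaConf  # assumed dependency for the config object

    cfg = OmegaConf.create(
        {
            "dataset": {"path": "data/blastchar.csv"},  # placeholder path
            "params": {
                "dropout": 0.1,  # placeholder value
                "weights": "weights/tab_transformer.pt",  # placeholder path
            },
        }
    )
    main(cfg)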
def layer_method_with_input_layer_patches(
    self,
    layer_method_class: Callable,
    equiv_method_class: Callable,
    multi_layer: bool,
) -> None:
    model = BasicModel_MultiLayer_TrueMultiInput() if multi_layer else BasicModel()

    input_names = ["x1", "x2", "x3", "x4"] if multi_layer else ["input"]
    model = ModelInputWrapper(model)

    layers = [model.input_maps[inp] for inp in input_names]
    layer_method = layer_method_class(
        model, layer=layers if multi_layer else layers[0]
    )
    equivalent_method = equiv_method_class(model)

    inputs = tuple(torch.rand(5, 3) for _ in input_names)
    baseline = tuple(torch.zeros(5, 3) for _ in input_names)

    args = inspect.getfullargspec(equivalent_method.attribute.__wrapped__).args

    args_to_use = [inputs]
    if "baselines" in args:
        args_to_use += [baseline]

    a1 = layer_method.attribute(*args_to_use, target=0)
    a2 = layer_method.attribute(
        *args_to_use, target=0, attribute_to_layer_input=True
    )

    real_attributions = equivalent_method.attribute(*args_to_use, target=0)

    if not isinstance(a1, tuple):
        a1 = (a1,)
        a2 = (a2,)

    if not isinstance(real_attributions, tuple):
        real_attributions = (real_attributions,)

    assertTensorTuplesAlmostEqual(self, a1, a2)
    assertTensorTuplesAlmostEqual(self, a1, real_attributions)
"name": "basic_layer_ig_multi_layer_multi_output", "algorithms": [LayerIntegratedGradients], "model": BasicModel_MultiLayer_TrueMultiInput(), "layer": ["m1", "m234"], "attribute_args": { "inputs": ( torch.randn(5, 3), torch.randn(5, 3), torch.randn(5, 3), torch.randn(5, 3), ), "target": 0, }, }, { "name": "basic_layer_ig_multi_layer_multi_output_with_input_wrapper", "algorithms": [LayerIntegratedGradients], "model": ModelInputWrapper(BasicModel_MultiLayer_TrueMultiInput()), "layer": ["module.m1", "module.m234"], "attribute_args": { "inputs": ( torch.randn(5, 3), torch.randn(5, 3), torch.randn(5, 3), torch.randn(5, 3), ), "target": 0, }, }, ]