Example #1
def test_lrp_ixg_equivalency(self) -> None:
    model, inputs = _get_simple_model()
    lrp = LRP(model)
    attributions_lrp = lrp.attribute(inputs)
    ixg = InputXGradient(model)
    attributions_ixg = ixg.attribute(inputs)
    # For this simple (linear) model, LRP with the epsilon rule reduces to
    # Input X Gradient, so the two attribution tensors should match.
    assertTensorAlmostEqual(self, attributions_lrp, attributions_ixg)
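A minimal stand-in for the _get_simple_model helper (the real one lives in Captum's test suite, so this sketch is an assumption): with bias-free linear layers, LRP-epsilon redistributes exactly input * gradient, which is what makes the equivalence above hold.

import torch
import torch.nn as nn

def _get_simple_model():
    # Bias-free linear layers keep LRP-epsilon equal to input * gradient.
    model = nn.Sequential(nn.Linear(3, 3, bias=False),
                          nn.Linear(3, 1, bias=False))
    inputs = torch.tensor([[1.0, 2.0, 3.0]])
    return model, inputs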
Example #2
class Explainer:
    def __init__(self, model):
        self.model = model
        self.explain = InputXGradient(model)

    def get_attribution_map(self, img, target=None):
        # Default to the model's predicted class when no target is given.
        if target is None:
            target = torch.argmax(self.model(img), 1)
        attributions = self.explain.attribute(img, target=target)
        return attributions
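A usage sketch; model (any classifier returning logits) and img (a batched input tensor) are stand-ins for illustration:

explainer = Explainer(model)
attr = explainer.get_attribution_map(img)                 # target defaults to the argmax class
attr_cls3 = explainer.get_attribution_map(img, target=3)  # or pass an explicit class index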
Example #3
def __init__(self, model):
    # Keep a private copy so hooks or mutations elsewhere cannot affect the
    # explained model, and attribute the copy rather than the original.
    self.model = copy.deepcopy(model)
    self.explain = InputXGradient(self.model)
Example #4
 Occlusion.get_name():
 ConfigParameters(
     params={
         "sliding_window_shapes": StrConfig(value=""),
         "strides": StrConfig(value=""),
         "perturbations_per_eval": NumberConfig(value=1, limit=(1, 100)),
     },
     post_process={
         "sliding_window_shapes": _str_to_tuple,
         "strides": _str_to_tuple,
         "perturbations_per_eval": int,
     },
 ),
 GuidedBackprop.get_name():
 ConfigParameters(params={}),
 InputXGradient.get_name():
 ConfigParameters(params={}),
 Saliency.get_name():
 ConfigParameters(
     params={"abs": StrEnumConfig(limit=["True", "False"], value="True")},
     post_process={"abs": _str_to_bool}),
 # Won't work as Relu is being used in multiple places (same layer can't be shared)
 # DeepLift.get_name(): ConfigParameters(
 #     params={}
 # ),
 LayerIntegratedGradients.get_name():
 ConfigParameters(
     params={
         "n_steps":
         NumberConfig(value=25, limit=(2, None)),
         "method":
Example #5
def generate_saliency(model_path, saliency_path, saliency, aggregation):
    checkpoint = torch.load(model_path,
                            map_location=lambda storage, loc: storage)
    model_args = Namespace(**checkpoint['args'])
    if args.model == 'lstm':
        model = LSTM_MODEL(tokenizer,
                           model_args,
                           n_labels=checkpoint['args']['labels']).to(device)
        model.load_state_dict(checkpoint['model'])
    elif args.model == 'trans':
        transformer_config = BertConfig.from_pretrained(
            'bert-base-uncased', num_labels=model_args.labels)
        model_cp = BertForSequenceClassification.from_pretrained(
            'bert-base-uncased', config=transformer_config).to(device)
        checkpoint = torch.load(model_path,
                                map_location=lambda storage, loc: storage)
        model_cp.load_state_dict(checkpoint['model'])
        model = BertModelWrapper(model_cp)
    else:
        model = CNN_MODEL(tokenizer,
                          model_args,
                          n_labels=checkpoint['args']['labels']).to(device)
        model.load_state_dict(checkpoint['model'])

    # train(), not eval(): cuDNN only supports backward through RNNs in
    # training mode, which gradient-based saliency for the LSTM requires.
    model.train()

    pad_to_max = False
    if saliency == 'deeplift':
        ablator = DeepLift(model)
    elif saliency == 'guided':
        ablator = GuidedBackprop(model)
    elif saliency == 'sal':
        ablator = Saliency(model)
    elif saliency == 'inputx':
        ablator = InputXGradient(model)
    elif saliency == 'occlusion':
        ablator = Occlusion(model)
    else:
        raise ValueError(f'Unknown saliency method: {saliency}')

    coll_call = get_collate_fn(dataset=args.dataset, model=args.model)

    return_attention_masks = args.model == 'trans'

    collate_fn = partial(coll_call,
                         tokenizer=tokenizer,
                         device=device,
                         return_attention_masks=return_attention_masks,
                         pad_to_max_length=pad_to_max)
    test = get_dataset(path=args.dataset_dir,
                       mode=args.split,
                       dataset=args.dataset)
    batch_size = (args.batch_size
                  if args.batch_size is not None else model_args.batch_size)
    test_dl = DataLoader(batch_size=batch_size,
                         dataset=test,
                         shuffle=False,
                         collate_fn=collate_fn)

    # PREDICTIONS
    predictions_path = model_path + '.predictions'
    if not os.path.exists(predictions_path):
        predictions = defaultdict(list)
        for batch in tqdm(test_dl, desc='Running test prediction... '):
            if args.model == 'trans':
                logits = model(batch[0],
                               attention_mask=batch[1],
                               labels=batch[2].long())
            else:
                logits = model(batch[0])
            logits = logits.detach().cpu().numpy().tolist()
            predicted = np.argmax(np.array(logits), axis=-1)
            predictions['class'] += predicted.tolist()
            predictions['logits'] += logits

        with open(predictions_path, 'w') as out:
            json.dump(predictions, out)

    # COMPUTE SALIENCY
    if saliency != 'occlusion':
        embedding_layer_name = ('model.bert.embeddings'
                                if args.model == 'trans' else 'embedding')
        interpretable_embedding = configure_interpretable_embedding_layer(
            model, embedding_layer_name)

    class_attr_list = defaultdict(list)
    token_ids = []
    saliency_flops = []

    for batch in tqdm(test_dl, desc='Running Saliency Generation...'):
        if args.model == 'cnn':
            additional = None
        elif args.model == 'trans':
            additional = (batch[1], batch[2])
        else:
            additional = batch[-1]

        token_ids += batch[0].detach().cpu().numpy().tolist()
        if saliency != 'occlusion':
            input_embeddings = interpretable_embedding.indices_to_embeddings(
                batch[0])

        if not args.no_time:
            high.start_counters([
                events.PAPI_FP_OPS,
            ])
        for cls_ in range(checkpoint['args']['labels']):
            if saliency == 'occlusion':
                attributions = ablator.attribute(
                    batch[0],
                    sliding_window_shapes=(args.sw, ),
                    target=cls_,
                    additional_forward_args=additional)
            else:
                attributions = ablator.attribute(
                    input_embeddings,
                    target=cls_,
                    additional_forward_args=additional)

            attributions = summarize_attributions(
                attributions, type=aggregation, model=model,
                tokens=batch[0]).detach().cpu().numpy().tolist()
            # attributions is already a list of per-token lists after tolist()
            class_attr_list[cls_] += attributions

        if not args.no_time:
            saliency_flops.append(
                sum(high.stop_counters()) / batch[0].shape[0])

    if saliency != 'occlusion':
        remove_interpretable_embedding_layer(model, interpretable_embedding)

    # SERIALIZE
    print('Serializing...', flush=True)
    with open(saliency_path, 'w') as out:
        for instance_i, _ in enumerate(test):
            saliencies = []
            for token_i, token_id in enumerate(token_ids[instance_i]):
                token_sal = {'token': tokenizer.ids_to_tokens[token_id]}
                for cls_ in range(checkpoint['args']['labels']):
                    token_sal[int(
                        cls_)] = class_attr_list[cls_][instance_i][token_i]
                saliencies.append(token_sal)

            out.write(json.dumps({'tokens': saliencies}) + '\n')
            out.flush()

    return saliency_flops
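A hypothetical invocation, assuming the surrounding script has set up the module-level args, tokenizer, and device used by the function (the paths and the aggregation name are placeholders):

flops = generate_saliency('snapshots/lstm_sst2.pt',
                          'saliencies/lstm_sst2_inputx.jsonl',
                          saliency='inputx',
                          aggregation='mean')
print('Average per-instance FP ops:', sum(flops) / max(len(flops), 1))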
Example #6
def __init__(self, model):
    self.model = model
    self.explain = InputXGradient(model)
Example #7
def __init__(self, trainer):
    CaptumDerivative.__init__(self, trainer)
    InputXGradient.__init__(self, self.trainer)
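The initializer above only makes sense inside a class deriving from both bases; a hedged sketch of the implied shape (that CaptumDerivative stores trainer is an assumption read off the calls):

class InputXGradientDerivative(CaptumDerivative, InputXGradient):
    # Hypothetical enclosing class. Explicit base-class __init__ calls are
    # used because the two bases do not share a cooperative super() chain.
    def __init__(self, trainer):
        CaptumDerivative.__init__(self, trainer)    # assumed to set self.trainer
        InputXGradient.__init__(self, self.trainer)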
Example #8
def get_attribution(real_img, 
                    fake_img, 
                    real_class, 
                    fake_class, 
                    net_module, 
                    checkpoint_path, 
                    input_shape, 
                    channels,
                    methods=["ig", "grads", "gc", "ggc", "dl", "ingrad", "random", "residual"],
                    output_classes=6,
                    downsample_factors=[(2,2), (2,2), (2,2), (2,2)]):


    imgs = [image_to_tensor(normalize_image(real_img).astype(np.float32)), 
            image_to_tensor(normalize_image(fake_img).astype(np.float32))]

    classes = [real_class, fake_class]
    net = init_network(checkpoint_path, input_shape, net_module, channels,
                       output_classes=output_classes, eval_net=True,
                       require_grad=False, downsample_factors=downsample_factors)

    attrs = []
    attrs_names = []

    if "residual" in methods:
        res = np.abs(real_img - fake_img)
        res = res - np.min(res)
        attrs.append(torch.tensor(res/np.max(res)))
        attrs_names.append("residual")

    if "random" in methods:
        rand = np.abs(np.random.randn(*np.shape(real_img)))
        rand = np.abs(scipy.ndimage.gaussian_filter(rand, 4))
        rand = rand - np.min(rand)
        rand = rand/np.max(np.abs(rand))
        attrs.append(torch.tensor(rand))
        attrs_names.append("random")

    if "gc" in methods:
        net.zero_grad()
        last_conv_layer = [(name, module)
                           for name, module in net.named_modules()
                           if isinstance(module, torch.nn.Conv2d)][-1]
        layer_name = last_conv_layer[0]
        layer = last_conv_layer[1]
        layer_gc = LayerGradCam(net, layer)
        gc_real = layer_gc.attribute(imgs[0], target=classes[0])
        gc_fake = layer_gc.attribute(imgs[1], target=classes[1])

        gc_real = project_layer_activations_to_input_rescale(gc_real.cpu().detach().numpy(), (input_shape[0], input_shape[1]))
        gc_fake = project_layer_activations_to_input_rescale(gc_fake.cpu().detach().numpy(), (input_shape[0], input_shape[1]))

        attrs.append(torch.tensor(gc_real[0,0,:,:]))
        attrs_names.append("gc_real")

        attrs.append(torch.tensor(gc_fake[0,0,:,:]))
        attrs_names.append("gc_fake")

        # SCAM
        gc_diff_0, gc_diff_1 = get_sgc(real_img, fake_img, real_class, 
                                     fake_class, net_module, checkpoint_path, 
                                     input_shape, channels, None, output_classes=output_classes,
                                     downsample_factors=downsample_factors)
        attrs.append(gc_diff_0)
        attrs_names.append("gc_diff_0")

        attrs.append(gc_diff_1)
        attrs_names.append("gc_diff_1")

    if "ggc" in methods:
        net.zero_grad()
        last_conv = [module for module in net.modules()
                     if isinstance(module, torch.nn.Conv2d)][-1]
        guided_gc = GuidedGradCam(net, last_conv)
        ggc_real = guided_gc.attribute(imgs[0], target=classes[0])
        ggc_fake = guided_gc.attribute(imgs[1], target=classes[1])

        attrs.append(ggc_real[0,0,:,:])
        attrs_names.append("ggc_real")

        attrs.append(ggc_fake[0,0,:,:])
        attrs_names.append("ggc_fake")

        net.zero_grad()
        gbp = GuidedBackprop(net)
        gbp_real = gbp.attribute(imgs[0], target=classes[0])
        gbp_fake = gbp.attribute(imgs[1], target=classes[1])
        
        attrs.append(gbp_real[0,0,:,:])
        attrs_names.append("gbp_real")

        attrs.append(gbp_fake[0,0,:,:])
        attrs_names.append("gbp_fake")

        # NOTE: gc_diff_0 / gc_diff_1 come from the "gc" branch above, so
        # "ggc" requires "gc" to also be present in `methods`.
        ggc_diff_0 = gbp_real[0,0,:,:] * gc_diff_0
        ggc_diff_1 = gbp_fake[0,0,:,:] * gc_diff_1

        attrs.append(ggc_diff_0)
        attrs_names.append("ggc_diff_0")

        attrs.append(ggc_diff_1)
        attrs_names.append("ggc_diff_1")

    # IG
    if "ig" in methods:
        baseline = image_to_tensor(np.zeros(input_shape, dtype=np.float32))
        net.zero_grad()
        ig = IntegratedGradients(net)
        ig_real, delta_real = ig.attribute(imgs[0], baseline, target=classes[0], return_convergence_delta=True)
        ig_fake, delta_fake = ig.attribute(imgs[1], baseline, target=classes[1], return_convergence_delta=True)
        ig_diff_0, delta_diff = ig.attribute(imgs[0], imgs[1], target=classes[0], return_convergence_delta=True)
        ig_diff_1, delta_diff = ig.attribute(imgs[1], imgs[0], target=classes[1], return_convergence_delta=True)

        attrs.append(ig_real[0,0,:,:])
        attrs_names.append("ig_real")

        attrs.append(ig_fake[0,0,:,:])
        attrs_names.append("ig_fake")

        attrs.append(ig_diff_0[0,0,:,:])
        attrs_names.append("ig_diff_0")

        attrs.append(ig_diff_1[0,0,:,:])
        attrs_names.append("ig_diff_1")

        
    # DL
    if "dl" in methods:
        net.zero_grad()
        dl = DeepLift(net)
        dl_real = dl.attribute(imgs[0], target=classes[0])
        dl_fake = dl.attribute(imgs[1], target=classes[1])
        dl_diff_0 = dl.attribute(imgs[0], baselines=imgs[1], target=classes[0])
        dl_diff_1 = dl.attribute(imgs[1], baselines=imgs[0], target=classes[1])

        attrs.append(dl_real[0,0,:,:])
        attrs_names.append("dl_real")

        attrs.append(dl_fake[0,0,:,:])
        attrs_names.append("dl_fake")

        attrs.append(dl_diff_0[0,0,:,:])
        attrs_names.append("dl_diff_0")

        attrs.append(dl_diff_1[0,0,:,:])
        attrs_names.append("dl_diff_1")

    # INGRAD
    if "ingrad" in methods:
        net.zero_grad()
        saliency = Saliency(net)
        grads_real = saliency.attribute(imgs[0], 
                                        target=classes[0]) 
        grads_fake = saliency.attribute(imgs[1], 
                                        target=classes[1]) 

        attrs.append(grads_real[0,0,:,:])
        attrs_names.append("grads_real")

        attrs.append(grads_fake[0,0,:,:])
        attrs_names.append("grads_fake")

        net.zero_grad()
        input_x_gradient = InputXGradient(net)
        ingrad_real = input_x_gradient.attribute(imgs[0], target=classes[0])
        ingrad_fake = input_x_gradient.attribute(imgs[1], target=classes[1])

        ingrad_diff_0 = grads_fake * (imgs[0] - imgs[1])
        ingrad_diff_1 = grads_real * (imgs[1] - imgs[0])

        attrs.append(torch.abs(ingrad_real[0,0,:,:]))
        attrs_names.append("ingrad_real")

        attrs.append(torch.abs(ingrad_fake[0,0,:,:]))
        attrs_names.append("ingrad_fake")

        attrs.append(torch.abs(ingrad_diff_0[0,0,:,:]))
        attrs_names.append("ingrad_diff_0")

        attrs.append(torch.abs(ingrad_diff_1[0,0,:,:]))
        attrs_names.append("ingrad_diff_1")

    attrs = [a.detach().cpu().numpy() for a in attrs]
    attrs_norm = [a/np.max(np.abs(a)) for a in attrs]

    return attrs_norm, attrs_names
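A hypothetical call; every argument value is a placeholder, since the network name, checkpoint path, and shapes depend on the surrounding project:

attrs_norm, attrs_names = get_attribution(real_img, fake_img,
                                          real_class=0, fake_class=1,
                                          net_module='Vgg2D',
                                          checkpoint_path='checkpoints/net_last',
                                          input_shape=(128, 128),
                                          channels=1,
                                          methods=["ig", "ingrad"])
for name, attr in zip(attrs_names, attrs_norm):
    print(name, attr.shape)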
Example #9
def evaluation_ten_classes(initiate_or_load_model,
                           config_data,
                           singleton_scope=False,
                           reshape_size=None,
                           FIND_OPTIM_BRANCH_MODEL=False,
                           realtime_update=False,
                           ALLOW_ADHOC_NOPTIM=False):
    from pipeline.training.training_utils import prepare_save_dirs
    xai_mode = config_data['xai_mode']
    MODEL_DIR, INFO_DIR, CACHE_FOLDER_DIR = prepare_save_dirs(config_data)

    ############################
    VERBOSE = 0
    ############################

    if not FIND_OPTIM_BRANCH_MODEL:
        print(
            'Using the model from (only) continuous training for xai evaluation [%s]'
            % (str(xai_mode)))
        net, evaluator = initiate_or_load_model(MODEL_DIR,
                                                INFO_DIR,
                                                config_data,
                                                verbose=VERBOSE)
    else:
        BRANCH_FOLDER_DIR = MODEL_DIR[:MODEL_DIR.find('.model')] + '.%s' % (
            str(config_data['branch_name_label']))
        BRANCH_MODEL_DIR = os.path.join(
            BRANCH_FOLDER_DIR,
            '%s.%s.model' % (str(config_data['model_name']),
                             str(config_data['branch_name_label'])))
        # BRANCH_MODEL_DIR = MODEL_DIR[:MODEL_DIR.find('.model')] + '.%s.model'%(str(config_data['branch_name_label']))

        if ALLOW_ADHOC_NOPTIM:  # this is intended only for debug runs
            print('<< [EXY1] ALLOWING ADHOC NOPTIM >>')
            import shutil
            shutil.copyfile(BRANCH_MODEL_DIR, BRANCH_MODEL_DIR + '.noptim')

        if os.path.exists(BRANCH_MODEL_DIR + '.optim'):
            BRANCH_MODEL_DIR = BRANCH_MODEL_DIR + '.optim'
            print(
                '  Using the OPTIMIZED branch model for [%s] xai evaluation: %s'
                % (str(xai_mode), str(BRANCH_MODEL_DIR)))
        elif os.path.exists(BRANCH_MODEL_DIR + '.noptim'):
            BRANCH_MODEL_DIR = BRANCH_MODEL_DIR + '.noptim'
            print(
                '  Using the partially optimized branch model for [%s] xai evaluation: %s'
                % (str(xai_mode), str(BRANCH_MODEL_DIR)))
        else:
            raise RuntimeError(
                'Attempting to find .optim or .noptim model, but not found.')
        if VERBOSE >= 250:
            print(
                '  """You may see a warning by pytorch for ReLu backward hook. It has been fixed externally, so you can ignore it."""'
            )
        net, evaluator = initiate_or_load_model(BRANCH_MODEL_DIR,
                                                INFO_DIR,
                                                config_data,
                                                verbose=VERBOSE)

    if xai_mode == 'Saliency':
        attrmodel = Saliency(net)
    elif xai_mode == 'IntegratedGradients':
        attrmodel = IntegratedGradients(net)
    elif xai_mode == 'InputXGradient':
        attrmodel = InputXGradient(net)
    elif xai_mode == 'DeepLift':
        attrmodel = DeepLift(net)
    elif xai_mode == 'GuidedBackprop':
        attrmodel = GuidedBackprop(net)
    elif xai_mode == 'GuidedGradCam':
        attrmodel = GuidedGradCam(net, net.select_first_layer())  # first layer
    elif xai_mode == 'Deconvolution':
        attrmodel = Deconvolution(net)
    elif xai_mode == 'GradientShap':
        attrmodel = GradientShap(net)
    elif xai_mode == 'DeepLiftShap':
        attrmodel = DeepLiftShap(net)
    else:
        raise RuntimeError('No valid attribution selected.')

    if singleton_scope:  # just to observe a single datapoint, mostly for debugging
        singleton_scope_oberservation(net, attrmodel, config_data,
                                      CACHE_FOLDER_DIR)
    else:
        aggregate_evaluation(net,
                             attrmodel,
                             config_data,
                             CACHE_FOLDER_DIR,
                             reshape_size=reshape_size,
                             realtime_update=realtime_update,
                             EVALUATE_BRANCH=FIND_OPTIM_BRANCH_MODEL)
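The if/elif ladder above can also be written as a lookup table, which keeps the supported methods in one place; a sketch using the same Captum classes (GuidedGradCam needs its extra layer argument, hence the lambda):

ATTRIBUTORS = {
    'Saliency': Saliency,
    'IntegratedGradients': IntegratedGradients,
    'InputXGradient': InputXGradient,
    'DeepLift': DeepLift,
    'GuidedBackprop': GuidedBackprop,
    'GuidedGradCam': lambda n: GuidedGradCam(n, n.select_first_layer()),
    'Deconvolution': Deconvolution,
    'GradientShap': GradientShap,
    'DeepLiftShap': DeepLiftShap,
}
if xai_mode not in ATTRIBUTORS:
    raise RuntimeError('No valid attribution selected.')
attrmodel = ATTRIBUTORS[xai_mode](net)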