Example #1
def generate_pred(model_path):
    checkpoint = torch.load(model_path,
                            map_location=lambda storage, loc: storage)
    model_args = Namespace(**checkpoint['args'])
    if args.model == 'lstm':
        model = LSTM_MODEL(tokenizer, model_args,
                           n_labels=checkpoint['args']['labels']).to(device)
        model.load_state_dict(checkpoint['model'])
    elif args.model == 'trans':
        transformer_config = BertConfig.from_pretrained('bert-base-uncased',
                                                        num_labels=model_args.labels)
        model_cp = BertForSequenceClassification.from_pretrained(
            'bert-base-uncased', config=transformer_config).to(device)
        model_cp.load_state_dict(checkpoint['model'])
        model = BertModelWrapper(model_cp)
    else:
        model = CNN_MODEL(tokenizer, model_args,
                          n_labels=checkpoint['args']['labels']).to(device)
        model.load_state_dict(checkpoint['model'])

    model.eval()  # inference only: eval mode disables dropout during prediction

    pad_to_max = False

    coll_call = get_collate_fn(dataset=args.dataset, model=args.model)

    return_attention_masks = args.model == 'trans'

    collate_fn = partial(coll_call, tokenizer=tokenizer, device=device,
                         return_attention_masks=return_attention_masks,
                         pad_to_max_length=pad_to_max)
    test = get_dataset(path=args.dataset_dir, mode=args.split,
                       dataset=args.dataset)
    batch_size = args.batch_size if args.batch_size is not None else \
        model_args.batch_size
    test_dl = DataLoader(batch_size=batch_size, dataset=test, shuffle=False,
                         collate_fn=collate_fn)

    # PREDICTIONS
    predictions_path = model_path + '.predictions'
    predictions = defaultdict(list)
    for batch in tqdm(test_dl, desc='Running test prediction... '):
        if args.model == 'trans':
            logits = model(batch[0], attention_mask=batch[1],
                           labels=batch[2].long())
        else:
            logits = model(batch[0])
        logits = logits.detach().cpu().numpy().tolist()
        predicted = np.argmax(np.array(logits), axis=-1)
        predictions['class'] += predicted.tolist()
        predictions['logits'] += logits

    # serialize all predictions once, after the loop has finished
    with open(predictions_path, 'w') as out:
        json.dump(predictions, out)
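The function leans on module-level `args`, `tokenizer`, and `device`. A minimal sketch of the surrounding entry point, with argument names inferred from the attributes referenced above (an assumption, not the script's actual CLI):

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', choices=['lstm', 'trans', 'cnn'],
                        default='cnn')
    parser.add_argument('--dataset', default='imdb')
    parser.add_argument('--dataset_dir', default='data/')
    parser.add_argument('--split', default='test')
    parser.add_argument('--batch_size', type=int, default=None)
    parser.add_argument('--model_path', required=True)
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    generate_pred(args.model_path)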
Example #2
def get_model(model_path):
    checkpoint = torch.load(model_path,
                            map_location=lambda storage, loc: storage)
    model_args = Namespace(**checkpoint['args'])
    if args.model == 'lstm':
        model_cp = LSTM_MODEL(tokenizer,
                              model_args,
                              n_labels=checkpoint['args']['labels']).to(device)
    elif args.model == 'trans':
        transformer_config = BertConfig.from_pretrained(
            'bert-base-uncased', num_labels=model_args.labels)
        model_cp = BertForSequenceClassification.from_pretrained(
            'bert-base-uncased', config=transformer_config).to(device)
    else:
        model_cp = CNN_MODEL(tokenizer,
                             model_args,
                             n_labels=checkpoint['args']['labels']).to(device)

    model_cp.load_state_dict(checkpoint['model'])

    return model_cp, model_args
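A brief illustration of restoring a checkpoint with `get_model` for evaluation (the path and the printed attribute names are assumptions):

model, model_args = get_model('snapshots/imdb_cnn.pt')
model.eval()
print(f'restored {type(model).__name__}, '
      f'trained with batch_size={model_args.batch_size}')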
Example #3
def get_model():
    if args.model == 'trans':
        transformer_config = BertConfig.from_pretrained('bert-base-uncased',
                                                        num_labels=args.labels)
        if args.init_only:
            model = BertForSequenceClassification(
                config=transformer_config).to(device)
        else:
            model = BertForSequenceClassification.from_pretrained(
                'bert-base-uncased', config=transformer_config).to(device)
        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in param_optimizer
                           if not any(nd in n for nd in no_decay)],
                'weight_decay': 0.01
            },
            {
                'params': [p for n, p in param_optimizer
                           if any(nd in n for nd in no_decay)],
                'weight_decay': 0.0
            },
        ]

        optimizer = AdamW(optimizer_grouped_parameters, lr=args.lr)
        es = EarlyStopping(patience=args.patience,
                           percentage=False,
                           mode='max',
                           min_delta=0.0)
        # NB: num_warmup_steps expects an integer step count; 0.05 looks like
        # it was intended as a fraction of the total number of training steps
        scheduler = get_constant_schedule_with_warmup(optimizer,
                                                      num_warmup_steps=0.05)
    else:
        if args.model == 'cnn':
            model = CNN_MODEL(tokenizer, args, n_labels=args.labels).to(device)
        elif args.model == 'lstm':
            model = LSTM_MODEL(tokenizer, args,
                               n_labels=args.labels).to(device)

        optimizer = AdamW(model.parameters(), lr=args.lr)
        # note: ReduceLROnPlateau defaults to mode='min'; pass mode='max' if it
        # steps on the same maximized metric as the early stopping below
        scheduler = ReduceLROnPlateau(optimizer, verbose=True)
        es = EarlyStopping(patience=args.patience,
                           percentage=False,
                           mode='max',
                           min_delta=0.0)

    return model, optimizer, scheduler, es
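The returned quadruple slots into an ordinary epoch loop. A minimal sketch, assuming validation accuracy is the maximized metric and that `EarlyStopping.step` returns True once patience is exhausted (`train_epoch` and `evaluate` are hypothetical helpers):

model, optimizer, scheduler, es = get_model()
for epoch in range(args.epochs):
    train_epoch(model, optimizer, train_dl)    # hypothetical helper
    val_acc = evaluate(model, dev_dl)          # hypothetical helper
    if args.model == 'trans':
        scheduler.step()           # advance the warmup schedule
    else:
        scheduler.step(val_acc)    # ReduceLROnPlateau consumes the metric
    if es.step(val_acc):           # stop once val_acc plateaus
        break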
Example #4
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = True

    device = torch.device("cuda") if args.gpu else torch.device("cpu")
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    collate_fn = partial(collate_nli,
                         tokenizer=tokenizer,
                         device=device,
                         return_attention_masks=False,
                         pad_to_max_length=False)
    sort_key = lambda x: len(x[0]) + len(x[1])

    if args.model == 'lstm':
        model = LSTM_MODEL(tokenizer, args, n_labels=args.labels).to(device)
    else:
        model = CNN_MODEL(tokenizer, args, n_labels=args.labels).to(device)

    if args.mode == 'test':
        test = NLIDataset(args.dataset_dir, type='test')
        test_dl = BucketBatchSampler(batch_size=args.batch_size,
                                     sort_key=sort_key,
                                     dataset=test,
                                     collate_fn=collate_fn)
        optimizer = AdamW(model.parameters(), lr=args.lr)

        scores = []
        for model_path in args.model_path:
            checkpoint = torch.load(model_path)
            model.load_state_dict(checkpoint['model'])
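The `collate_nli` helper is imported from elsewhere in the repository; a sketch of the shape such a collate function plausibly takes, given the keyword arguments above (an illustration, not the actual implementation):

def collate_nli(batch, tokenizer, device, return_attention_masks=True,
                pad_to_max_length=False):
    # batch is a list of (premise, hypothesis, label) triples
    enc = tokenizer([x[0] for x in batch], [x[1] for x in batch],
                    padding='max_length' if pad_to_max_length else True,
                    truncation=True, return_tensors='pt')
    labels = torch.tensor([x[2] for x in batch], device=device)
    input_ids = enc['input_ids'].to(device)
    if return_attention_masks:
        return input_ids, enc['attention_mask'].to(device), labels
    return input_ids, labels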
Example #5
def generate_saliency(model_path, saliency_path):
    test = get_dataset(path=args.dataset_dir, mode=args.split,
                       dataset=args.dataset)
    checkpoint = torch.load(model_path,
                            map_location=lambda storage, loc: storage)
    model_args = argparse.Namespace(**checkpoint['args'])

    if args.model == 'trans':
        model_args.batch_size = 7
        transformer_config = BertConfig.from_pretrained('bert-base-uncased',
                                                        num_labels=model_args.labels)
        model = BertForSequenceClassification.from_pretrained(
            'bert-base-uncased', config=transformer_config).to(device)
        model.load_state_dict(checkpoint['model'])
        modelw = BertModelWrapper(model, device, tokenizer, model_args)
    else:
        if args.model == 'lstm':
            model_args.batch_size = 200
            model = LSTM_MODEL(tokenizer, model_args,
                               n_labels=checkpoint['args']['labels'],
                               device=device).to(device)
        else:
            model_args.batch_size = 300
            model = CNN_MODEL(tokenizer, model_args,
                              n_labels=checkpoint['args']['labels']).to(device)

        model.load_state_dict(checkpoint['model'])
        modelw = ModelWrapper(model, device, tokenizer, model_args)

    modelw.eval()

    explainer = LimeTextExplainer()
    saliency_flops = []

    with open(saliency_path, 'w') as out:
        for instance in tqdm(test):
            # SALIENCY
            if not args.no_time:
                high.start_counters([events.PAPI_FP_OPS, ])

            saliencies = []
            if args.dataset in ['imdb', 'tweet']:
                token_ids = tokenizer.encode(instance[0])
            else:
                token_ids = tokenizer.encode(instance[0], instance[1])

            if len(token_ids) < 6:
                token_ids = token_ids + [tokenizer.pad_token_id] * (
                            6 - len(token_ids))
            try:
                exp = explainer.explain_instance(
                    " ".join([str(i) for i in token_ids]), modelw,
                    num_features=len(token_ids),
                    top_labels=args.labels)
            except Exception as e:
                print(e)
                if not args.no_time:
                    x = high.stop_counters()[0]
                    saliency_flops.append(x)

                for token_id in token_ids:
                    token_id = int(token_id)
                    token_saliency = {
                        'token': tokenizer.ids_to_tokens[token_id]
                    }
                    for cls_ in range(args.labels):
                        token_saliency[int(cls_)] = 0
                    saliencies.append(token_saliency)

                out.write(json.dumps({'tokens': saliencies}) + '\n')
                out.flush()

                continue

            if not args.no_time:
                x = high.stop_counters()[0]
                saliency_flops.append(x)

            # SERIALIZE
            explanation = {}
            for cls_ in range(args.labels):
                cls_expl = {}
                for (w, s) in exp.as_list(label=cls_):
                    cls_expl[int(w)] = s
                explanation[cls_] = cls_expl

            for token_id in token_ids:
                token_id = int(token_id)
                token_saliency = {'token': tokenizer.ids_to_tokens[token_id]}
                for cls_ in range(args.labels):
                    token_saliency[int(cls_)] = explanation[cls_].get(token_id,
                                                                      None)
                saliencies.append(token_saliency)

            out.write(json.dumps({'tokens': saliencies}) + '\n')
            out.flush()

    return saliency_flops
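`LimeTextExplainer.explain_instance` expects its second argument to be a classifier function mapping a list of strings to an (n_samples, n_classes) array of probabilities, which is what the `ModelWrapper`/`BertModelWrapper` callables presumably provide here. A minimal sketch of such a wrapper, assuming the underlying model returns logits for a batch of token ids (an illustration, not the repository's class):

class ModelWrapper:
    def __init__(self, model, device, tokenizer, args):
        self.model = model
        self.device = device
        self.tokenizer = tokenizer
        self.args = args

    def eval(self):
        self.model.eval()

    def __call__(self, texts):
        # texts are whitespace-joined token-id strings, as produced above
        probs = []
        with torch.no_grad():
            for text in texts:
                ids = [int(t) for t in text.split() if t.isdigit()]
                ids = torch.tensor([ids or [self.tokenizer.pad_token_id]],
                                   device=self.device)
                logits = self.model(ids)
                probs.append(torch.softmax(logits, dim=-1)[0].cpu().numpy())
        return np.array(probs)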
Example #6
def generate_saliency(model_path, saliency_path, saliency, aggregation):
    checkpoint = torch.load(model_path,
                            map_location=lambda storage, loc: storage)
    model_args = Namespace(**checkpoint['args'])
    if args.model == 'lstm':
        model = LSTM_MODEL(tokenizer,
                           model_args,
                           n_labels=checkpoint['args']['labels']).to(device)
        model.load_state_dict(checkpoint['model'])
    elif args.model == 'trans':
        transformer_config = BertConfig.from_pretrained(
            'bert-base-uncased', num_labels=model_args.labels)
        model_cp = BertForSequenceClassification.from_pretrained(
            'bert-base-uncased', config=transformer_config).to(device)
        model_cp.load_state_dict(checkpoint['model'])
        model = BertModelWrapper(model_cp)
    else:
        model = CNN_MODEL(tokenizer,
                          model_args,
                          n_labels=checkpoint['args']['labels']).to(device)
        model.load_state_dict(checkpoint['model'])

    # train() is deliberate here: the gradient-based ablators below require
    # backward passes, and cuDNN RNNs only support backward in training mode
    model.train()

    pad_to_max = False
    if saliency == 'deeplift':
        ablator = DeepLift(model)
    elif saliency == 'guided':
        ablator = GuidedBackprop(model)
    elif saliency == 'sal':
        ablator = Saliency(model)
    elif saliency == 'inputx':
        ablator = InputXGradient(model)
    elif saliency == 'occlusion':
        ablator = Occlusion(model)

    coll_call = get_collate_fn(dataset=args.dataset, model=args.model)

    return_attention_masks = args.model == 'trans'

    collate_fn = partial(coll_call,
                         tokenizer=tokenizer,
                         device=device,
                         return_attention_masks=return_attention_masks,
                         pad_to_max_length=pad_to_max)
    test = get_dataset(path=args.dataset_dir,
                       mode=args.split,
                       dataset=args.dataset)
    batch_size = args.batch_size if args.batch_size is not None else \
        model_args.batch_size
    test_dl = DataLoader(batch_size=batch_size,
                         dataset=test,
                         shuffle=False,
                         collate_fn=collate_fn)

    # PREDICTIONS
    predictions_path = model_path + '.predictions'
    if not os.path.exists(predictions_path):
        predictions = defaultdict(list)
        for batch in tqdm(test_dl, desc='Running test prediction... '):
            if args.model == 'trans':
                logits = model(batch[0],
                               attention_mask=batch[1],
                               labels=batch[2].long())
            else:
                logits = model(batch[0])
            logits = logits.detach().cpu().numpy().tolist()
            predicted = np.argmax(np.array(logits), axis=-1)
            predictions['class'] += predicted.tolist()
            predictions['logits'] += logits

        with open(predictions_path, 'w') as out:
            json.dump(predictions, out)

    # COMPUTE SALIENCY
    if saliency != 'occlusion':
        embedding_layer_name = ('model.bert.embeddings'
                                if args.model == 'trans' else 'embedding')
        interpretable_embedding = configure_interpretable_embedding_layer(
            model, embedding_layer_name)

    class_attr_list = defaultdict(list)
    token_ids = []
    saliency_flops = []

    for batch in tqdm(test_dl, desc='Running Saliency Generation...'):
        if args.model == 'cnn':
            additional = None
        elif args.model == 'trans':
            additional = (batch[1], batch[2])
        else:
            additional = batch[-1]

        token_ids += batch[0].detach().cpu().numpy().tolist()
        if saliency != 'occlusion':
            input_embeddings = interpretable_embedding.indices_to_embeddings(
                batch[0])

        if not args.no_time:
            high.start_counters([
                events.PAPI_FP_OPS,
            ])
        for cls_ in range(checkpoint['args']['labels']):
            if saliency == 'occlusion':
                attributions = ablator.attribute(
                    batch[0],
                    sliding_window_shapes=(args.sw, ),
                    target=cls_,
                    additional_forward_args=additional)
            else:
                attributions = ablator.attribute(
                    input_embeddings,
                    target=cls_,
                    additional_forward_args=additional)

            attributions = summarize_attributions(
                attributions, type=aggregation, model=model,
                tokens=batch[0]).detach().cpu().numpy().tolist()
            class_attr_list[cls_] += attributions

        if not args.no_time:
            saliency_flops.append(
                sum(high.stop_counters()) / batch[0].shape[0])

    if saliency != 'occlusion':
        remove_interpretable_embedding_layer(model, interpretable_embedding)

    # SERIALIZE
    print('Serializing...', flush=True)
    with open(saliency_path, 'w') as out:
        for instance_i, _ in enumerate(test):
            saliencies = []
            for token_i, token_id in enumerate(token_ids[instance_i]):
                token_sal = {'token': tokenizer.ids_to_tokens[token_id]}
                for cls_ in range(checkpoint['args']['labels']):
                    token_sal[int(
                        cls_)] = class_attr_list[cls_][instance_i][token_i]
                saliencies.append(token_sal)

            out.write(json.dumps({'tokens': saliencies}) + '\n')
            out.flush()

    return saliency_flops
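`summarize_attributions` is imported from elsewhere in the repository; the common Captum recipe it presumably follows is to collapse the embedding dimension and normalise per instance. A rough sketch of that idea (an assumption about its behaviour, not the actual code):

def summarize_attributions(attributions, type='mean', model=None, tokens=None):
    if attributions.dim() == 3:
        # gradient methods return (batch, seq_len, emb_dim); collapse emb_dim
        if type == 'mean':
            attributions = attributions.mean(dim=-1)
        else:
            attributions = attributions.norm(p=2, dim=-1)
    # L2-normalise each instance so scores are comparable across examples
    return attributions / (attributions.norm(dim=-1, keepdim=True) + 1e-12)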
Example #7
def generate_saliency(model_path, saliency_path):
    checkpoint = torch.load(model_path,
                            map_location=lambda storage, loc: storage)
    model_args = Namespace(**checkpoint['args'])
    model_args.batch_size = args.batch_size if args.batch_size is not None \
        else model_args.batch_size

    if args.model == 'transformer':
        transformer_config = BertConfig.from_pretrained(
            'bert-base-uncased', num_labels=model_args.labels)
        modelb = BertForSequenceClassification.from_pretrained(
            'bert-base-uncased', config=transformer_config).to(device)
        modelb.load_state_dict(checkpoint['model'])
        model = BertModelWrapper(modelb)
    elif args.model == 'lstm':
        model = LSTM_MODEL(tokenizer,
                           model_args,
                           n_labels=checkpoint['args']['labels'],
                           device=device).to(device)
        model.load_state_dict(checkpoint['model'])
        model.train()
        model = ModelWrapper(model)
    else:
        # model_args.batch_size = 1000
        model = CNN_MODEL(tokenizer,
                          model_args,
                          n_labels=checkpoint['args']['labels']).to(device)
        model.load_state_dict(checkpoint['model'])
        model.train()
        model = ModelWrapper(model)

    ablator = ShapleyValueSampling(model)

    coll_call = get_collate_fn(dataset=args.dataset, model=args.model)

    collate_fn = partial(coll_call,
                         tokenizer=tokenizer,
                         device=device,
                         return_attention_masks=False,
                         pad_to_max_length=False)

    test = get_dataset(args.dataset_dir, mode=args.split)
    test_dl = DataLoader(batch_size=model_args.batch_size,
                         dataset=test,
                         shuffle=False,
                         collate_fn=collate_fn)

    # PREDICTIONS
    predictions_path = model_path + '.predictions'
    if not os.path.exists(predictions_path):
        predictions = defaultdict(list)
        for batch in tqdm(test_dl, desc='Running test prediction... '):
            logits = model(batch[0])
            logits = logits.detach().cpu().numpy().tolist()
            predicted = np.argmax(np.array(logits), axis=-1)
            predictions['class'] += predicted.tolist()
            predictions['logits'] += logits

        with open(predictions_path, 'w') as out:
            json.dump(predictions, out)

    # COMPUTE SALIENCY

    saliency_flops = []

    with open(saliency_path, 'w') as out_mean:
        for batch in tqdm(test_dl, desc='Running Saliency Generation...'):
            class_attr_list = defaultdict(list)

            if args.model == 'lstm':
                # the LSTM forward needs the sequence lengths appended by the
                # collate function
                additional = batch[-1]
            else:
                additional = None

            if not args.no_time:
                high.start_counters([events.PAPI_FP_OPS])
            token_ids = batch[0].detach().cpu().numpy().tolist()

            for cls_ in range(args.labels):
                attributions = ablator.attribute(
                    batch[0].float(),
                    target=cls_,
                    additional_forward_args=additional)
                attributions = attributions.detach().cpu().numpy().tolist()
                class_attr_list[cls_] += attributions

            if not args.no_time:
                x = sum(high.stop_counters())
                saliency_flops.append(x / batch[0].shape[0])

            for i in range(len(batch[0])):
                saliencies = []
                for token_i, token_id in enumerate(token_ids[i]):
                    if token_id == tokenizer.pad_token_id:
                        continue
                    token_sal = {'token': tokenizer.ids_to_tokens[token_id]}
                    for cls_ in range(args.labels):
                        token_sal[int(
                            cls_)] = class_attr_list[cls_][i][token_i]
                    saliencies.append(token_sal)

                out_mean.write(json.dumps({'tokens': saliencies}) + '\n')
                out_mean.flush()

    return saliency_flops
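`ShapleyValueSampling` perturbs the float copy of `batch[0]` before calling the wrapped forward, so the `ModelWrapper` here presumably casts the input back to integer ids. A minimal sketch of that wrapper (an illustration under that assumption, not the repository's class):

class ModelWrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, lengths=None):
        # Captum hands us a perturbed float tensor; embeddings need long ids
        input_ids = input_ids.long()
        if lengths is not None:
            return self.model(input_ids, lengths)
        return self.model(input_ids)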