def generate_pred(model_path):
    """Run the configured test split through the model checkpointed at
    ``model_path`` and serialize the per-instance predicted classes and raw
    logits as JSON to ``<model_path>.predictions``.

    Relies on the module-level ``args``, ``tokenizer`` and ``device``.
    """
    checkpoint = torch.load(model_path,
                            map_location=lambda storage, loc: storage)
    model_args = Namespace(**checkpoint['args'])

    # Rebuild the architecture the checkpoint was trained with, then load
    # its weights.
    if args.model == 'lstm':
        model = LSTM_MODEL(tokenizer, model_args,
                           n_labels=checkpoint['args']['labels']).to(device)
        model.load_state_dict(checkpoint['model'])
    elif args.model == 'trans':
        transformer_config = BertConfig.from_pretrained(
            'bert-base-uncased', num_labels=model_args.labels)
        model_cp = BertForSequenceClassification.from_pretrained(
            'bert-base-uncased', config=transformer_config).to(device)
        # NOTE: the checkpoint is already in memory from the load above; the
        # original code reloaded it here redundantly.
        model_cp.load_state_dict(checkpoint['model'])
        model = BertModelWrapper(model_cp)
    else:
        model = CNN_MODEL(tokenizer, model_args,
                          n_labels=checkpoint['args']['labels']).to(device)
        model.load_state_dict(checkpoint['model'])

    # BUG FIX: the original called model.train() here, which leaves dropout
    # (and batch-norm updates, if any) active during prediction and makes
    # the serialized predictions nondeterministic. Inference must run in
    # eval mode.
    model.eval()

    pad_to_max = False
    coll_call = get_collate_fn(dataset=args.dataset, model=args.model)
    # Attention masks are only meaningful for the transformer model.
    return_attention_masks = args.model == 'trans'
    collate_fn = partial(coll_call, tokenizer=tokenizer, device=device,
                         return_attention_masks=return_attention_masks,
                         pad_to_max_length=pad_to_max)
    test = get_dataset(path=args.dataset_dir, mode=args.split,
                       dataset=args.dataset)
    # Prefer an explicit CLI batch size; fall back to the one the model was
    # trained with.
    batch_size = args.batch_size if args.batch_size is not None \
        else model_args.batch_size
    test_dl = DataLoader(batch_size=batch_size, dataset=test, shuffle=False,
                         collate_fn=collate_fn)

    # PREDICTIONS
    predictions_path = model_path + '.predictions'
    predictions = defaultdict(list)
    with torch.no_grad():  # no gradients needed for pure prediction
        for batch in tqdm(test_dl, desc='Running test prediction... '):
            if args.model == 'trans':
                logits = model(batch[0], attention_mask=batch[1],
                               labels=batch[2].long())
            else:
                logits = model(batch[0])
            logits = logits.detach().cpu().numpy().tolist()
            predicted = np.argmax(np.array(logits), axis=-1)
            predictions['class'] += predicted.tolist()
            predictions['logits'] += logits

    with open(predictions_path, 'w') as out:
        json.dump(predictions, out)
def get_model(model_path):
    """Rebuild the architecture checkpointed at ``model_path``, load its
    weights, and return ``(model, model_args)``.

    ``model_args`` is a Namespace of the hyper-parameters serialized with
    the checkpoint. Uses the module-level ``args``, ``tokenizer`` and
    ``device``.
    """
    state = torch.load(model_path,
                       map_location=lambda storage, loc: storage)
    saved_args = Namespace(**state['args'])
    n_labels = state['args']['labels']

    # Instantiate whichever architecture the run was configured for.
    if args.model == 'trans':
        bert_cfg = BertConfig.from_pretrained('bert-base-uncased',
                                              num_labels=saved_args.labels)
        net = BertForSequenceClassification.from_pretrained(
            'bert-base-uncased', config=bert_cfg).to(device)
    elif args.model == 'lstm':
        net = LSTM_MODEL(tokenizer, saved_args,
                         n_labels=n_labels).to(device)
    else:
        net = CNN_MODEL(tokenizer, saved_args,
                        n_labels=n_labels).to(device)

    net.load_state_dict(state['model'])
    return net, saved_args
# NOTE(review): this chunk is a fragment of a larger script — the `if` header
# matching the `else:` below (presumably an `if args.model == 'lstm':` model
# selection) and the end of the final BucketBatchSampler(...) call both lie
# outside the visible region. Indentation is reconstructed from context;
# verify against the full file.
    model = LSTM_MODEL(tokenizer, args, n_labels=args.labels).to(device)
else:
    model = CNN_MODEL(tokenizer, args, n_labels=args.labels).to(device)

if args.mode == 'test':
    # Evaluate every checkpoint listed in args.model_path on the test split
    # and report mean/std of precision, recall, F1 and loss across them.
    test = NLIDataset(args.dataset_dir, type='test')
    test_dl = BucketBatchSampler(batch_size=args.batch_size,
                                 sort_key=sort_key, dataset=test,
                                 collate_fn=collate_fn)
    # NOTE(review): this optimizer is created but never stepped in the
    # visible test path — presumably eval_model does not use it; confirm.
    optimizer = AdamW(model.parameters(), lr=args.lr)
    scores = []
    for model_path in args.model_path:
        checkpoint = torch.load(model_path)
        model.load_state_dict(checkpoint['model'])
        p, r, f1, loss, _, _ = eval_model(model, test_dl)
        scores.append((p, r, f1, loss))
    # Aggregate each metric column across the evaluated checkpoints.
    for i, name in zip(range(len(scores[0])), ['p', 'r', 'f1', 'loss']):
        l = [model_scores[i] for model_scores in scores]
        print(name, np.average(l), np.std(l))
else:
    # Training path: load train/dev splits (the rest of this branch is
    # outside the visible region).
    print("Loading datasets...")
    train = NLIDataset(args.dataset_dir, type='train')
    dev = NLIDataset(args.dataset_dir, type='dev')
    train_dl = BucketBatchSampler(batch_size=args.batch_size,
                                  sort_key=sort_key, dataset=train,
def generate_saliency(model_path, saliency_path):
    """Generate LIME token saliencies for every test instance and write them,
    one JSON object per line ({'tokens': [...]}), to ``saliency_path``.

    Returns the list of per-instance FLOP counts (empty when ``args.no_time``
    is set). Relies on module-level ``args``, ``tokenizer`` and ``device``.
    """
    test = get_dataset(path=args.dataset_dir, mode=args.split,
                       dataset=args.dataset)
    checkpoint = torch.load(model_path,
                            map_location=lambda storage, loc: storage)
    model_args = argparse.Namespace(**checkpoint['args'])

    # Rebuild the checkpointed architecture; batch sizes here are the
    # per-model-type values used when LIME queries the wrapped model.
    if args.model == 'trans':
        model_args.batch_size = 7
        transformer_config = BertConfig.from_pretrained(
            'bert-base-uncased', num_labels=model_args.labels)
        model = BertForSequenceClassification.from_pretrained(
            'bert-base-uncased', config=transformer_config).to(device)
        model.load_state_dict(checkpoint['model'])
        modelw = BertModelWrapper(model, device, tokenizer, model_args)
    else:
        if args.model == 'lstm':
            model_args.batch_size = 200
            model = LSTM_MODEL(tokenizer, model_args,
                               n_labels=checkpoint['args']['labels'],
                               device=device).to(device)
        else:
            model_args.batch_size = 300
            model = CNN_MODEL(tokenizer, model_args,
                              n_labels=checkpoint['args']['labels']).to(device)
        model.load_state_dict(checkpoint['model'])
        modelw = ModelWrapper(model, device, tokenizer, model_args)

    modelw.eval()
    explainer = LimeTextExplainer()
    saliency_flops = []

    with open(saliency_path, 'w') as out:
        for instance in tqdm(test):
            # SALIENCY
            # Start counting floating-point ops for this instance.
            if not args.no_time:
                high.start_counters([events.PAPI_FP_OPS, ])
            saliencies = []
            # Single-text datasets encode one field; pair datasets (e.g. NLI)
            # encode both.
            if args.dataset in ['imdb', 'tweet']:
                token_ids = tokenizer.encode(instance[0])
            else:
                token_ids = tokenizer.encode(instance[0], instance[1])
            # Pad very short inputs to at least 6 tokens — presumably the
            # minimum length the models/LIME sampling can handle; confirm.
            if len(token_ids) < 6:
                token_ids = token_ids + [tokenizer.pad_token_id] * (
                        6 - len(token_ids))
            try:
                # LIME operates on text, so the token ids are passed as a
                # space-joined string of integers; the wrapper decodes them.
                exp = explainer.explain_instance(
                    " ".join([str(i) for i in token_ids]), modelw,
                    num_features=len(token_ids), top_labels=args.labels)
            except Exception as e:
                # Best-effort fallback: if LIME fails on this instance, log
                # the error, still record the FLOP count, and emit an
                # all-zero saliency row so output stays aligned with the
                # dataset, then move on.
                print(e)
                if not args.no_time:
                    x = high.stop_counters()[0]
                    saliency_flops.append(x)
                for token_id in token_ids:
                    token_id = int(token_id)
                    token_saliency = {
                        'token': tokenizer.ids_to_tokens[token_id]
                    }
                    for cls_ in range(args.labels):
                        token_saliency[int(cls_)] = 0
                    saliencies.append(token_saliency)
                out.write(json.dumps({'tokens': saliencies}) + '\n')
                out.flush()
                continue

            if not args.no_time:
                x = high.stop_counters()[0]
                saliency_flops.append(x)

            # SERIALIZE
            # Re-index LIME's (word, score) lists as {token_id: score} per
            # class for O(1) lookup below.
            explanation = {}
            for cls_ in range(args.labels):
                cls_expl = {}
                for (w, s) in exp.as_list(label=cls_):
                    cls_expl[int(w)] = s
                explanation[cls_] = cls_expl

            # One record per token; missing attributions serialize as None.
            for token_id in token_ids:
                token_id = int(token_id)
                token_saliency = {'token': tokenizer.ids_to_tokens[token_id]}
                for cls_ in range(args.labels):
                    token_saliency[int(cls_)] = explanation[cls_].get(
                        token_id, None)
                saliencies.append(token_saliency)
            out.write(json.dumps({'tokens': saliencies}) + '\n')
            out.flush()
    return saliency_flops
def generate_saliency(model_path, saliency_path, saliency, aggregation):
    """Compute gradient-based (or occlusion) saliency maps with captum for
    the whole test split and write them, one JSON object per instance
    ({'tokens': [...]}), to ``saliency_path``.

    ``saliency`` selects the attribution method ('deeplift', 'guided',
    'sal', 'inputx', 'occlusion'); ``aggregation`` selects how per-embedding
    attributions are summarized into one score per token. Returns the list
    of per-instance FLOP counts (empty when ``args.no_time`` is set).
    Relies on module-level ``args``, ``tokenizer`` and ``device``.
    """
    checkpoint = torch.load(model_path,
                            map_location=lambda storage, loc: storage)
    model_args = Namespace(**checkpoint['args'])

    # Rebuild the checkpointed architecture and load its weights.
    if args.model == 'lstm':
        model = LSTM_MODEL(tokenizer, model_args,
                           n_labels=checkpoint['args']['labels']).to(device)
        model.load_state_dict(checkpoint['model'])
    elif args.model == 'trans':
        transformer_config = BertConfig.from_pretrained(
            'bert-base-uncased', num_labels=model_args.labels)
        model_cp = BertForSequenceClassification.from_pretrained(
            'bert-base-uncased', config=transformer_config).to(device)
        checkpoint = torch.load(model_path,
                                map_location=lambda storage, loc: storage)
        model_cp.load_state_dict(checkpoint['model'])
        model = BertModelWrapper(model_cp)
    else:
        model = CNN_MODEL(tokenizer, model_args,
                          n_labels=checkpoint['args']['labels']).to(device)
        model.load_state_dict(checkpoint['model'])
    # NOTE(review): train mode during attribution keeps dropout active and
    # makes attributions stochastic — confirm this is intentional (some
    # backprop-based ablators require non-inference mode).
    model.train()

    pad_to_max = False
    # Select the captum attribution method.
    if saliency == 'deeplift':
        ablator = DeepLift(model)
    elif saliency == 'guided':
        ablator = GuidedBackprop(model)
    elif saliency == 'sal':
        ablator = Saliency(model)
    elif saliency == 'inputx':
        ablator = InputXGradient(model)
    elif saliency == 'occlusion':
        ablator = Occlusion(model)

    coll_call = get_collate_fn(dataset=args.dataset, model=args.model)
    # Attention masks only apply to the transformer model.
    return_attention_masks = args.model == 'trans'
    collate_fn = partial(coll_call, tokenizer=tokenizer, device=device,
                         return_attention_masks=return_attention_masks,
                         pad_to_max_length=pad_to_max)
    test = get_dataset(path=args.dataset_dir, mode=args.split,
                       dataset=args.dataset)
    # Prefer an explicit CLI batch size; fall back to the training one.
    batch_size = args.batch_size if args.batch_size != None else \
        model_args.batch_size
    test_dl = DataLoader(batch_size=batch_size, dataset=test, shuffle=False,
                         collate_fn=collate_fn)

    # PREDICTIONS
    # Cache test-set predictions next to the checkpoint; skip if present.
    predictions_path = model_path + '.predictions'
    if not os.path.exists(predictions_path):
        predictions = defaultdict(lambda: [])
        for batch in tqdm(test_dl, desc='Running test prediction... '):
            if args.model == 'trans':
                logits = model(batch[0], attention_mask=batch[1],
                               labels=batch[2].long())
            else:
                logits = model(batch[0])
            logits = logits.detach().cpu().numpy().tolist()
            predicted = np.argmax(np.array(logits), axis=-1)
            predictions['class'] += predicted.tolist()
            predictions['logits'] += logits
        with open(predictions_path, 'w') as out:
            json.dump(predictions, out)

    # COMPUTE SALIENCY
    # Gradient-based methods attribute w.r.t. embeddings, so swap the
    # embedding layer for captum's interpretable wrapper; occlusion works
    # directly on token ids and needs no wrapping.
    if saliency != 'occlusion':
        embedding_layer_name = 'model.bert.embeddings' if args.model == \
                                                          'trans' else \
            'embedding'
        interpretable_embedding = configure_interpretable_embedding_layer(
            model, embedding_layer_name)

    class_attr_list = defaultdict(lambda: [])  # class -> per-instance scores
    token_ids = []
    saliency_flops = []

    for batch in tqdm(test_dl, desc='Running Saliency Generation...'):
        # Extra forward args the wrapped model expects, per architecture.
        if args.model == 'cnn':
            additional = None
        elif args.model == 'trans':
            additional = (batch[1], batch[2])
        else:
            additional = batch[-1]

        token_ids += batch[0].detach().cpu().numpy().tolist()
        if saliency != 'occlusion':
            input_embeddings = interpretable_embedding.indices_to_embeddings(
                batch[0])

        # Count floating-point ops across all target classes of this batch.
        if not args.no_time:
            high.start_counters([events.PAPI_FP_OPS, ])
        # One attribution pass per target class.
        for cls_ in range(checkpoint['args']['labels']):
            if saliency == 'occlusion':
                attributions = ablator.attribute(
                    batch[0], sliding_window_shapes=(args.sw, ),
                    target=cls_, additional_forward_args=additional)
            else:
                attributions = ablator.attribute(
                    input_embeddings, target=cls_,
                    additional_forward_args=additional)
            # Collapse embedding-dimension attributions to one scalar per
            # token according to the requested aggregation.
            attributions = summarize_attributions(
                attributions, type=aggregation, model=model,
                tokens=batch[0]).detach().cpu().numpy().tolist()
            class_attr_list[cls_] += [[_li for _li in _l] for _l in
                                      attributions]
        if not args.no_time:
            # Normalize the FLOP count by the batch size → per-instance cost.
            saliency_flops.append(
                sum(high.stop_counters()) / batch[0].shape[0])

    # Restore the original embedding layer before serialization.
    if saliency != 'occlusion':
        remove_interpretable_embedding_layer(model, interpretable_embedding)

    # SERIALIZE
    print('Serializing...', flush=True)
    with open(saliency_path, 'w') as out:
        for instance_i, _ in enumerate(test):
            saliencies = []
            for token_i, token_id in enumerate(token_ids[instance_i]):
                token_sal = {'token': tokenizer.ids_to_tokens[token_id]}
                for cls_ in range(checkpoint['args']['labels']):
                    token_sal[int(
                        cls_)] = class_attr_list[cls_][instance_i][token_i]
                saliencies.append(token_sal)
            out.write(json.dumps({'tokens': saliencies}) + '\n')
            out.flush()
    return saliency_flops