# Shared imports for the snippets below. Model classes (LSTM_MODEL, CNN_MODEL),
# wrappers (ModelWrapper, BertModelWrapper), data helpers (get_dataset,
# get_collate_fn, collate_nli, NLIDataset, BucketBatchSampler), EarlyStopping,
# summarize_attributions, eval_model, and the module-level `args`, `tokenizer`,
# and `device` are assumed to come from the surrounding project.
import argparse
import json
import os
from argparse import Namespace
from collections import defaultdict
from functools import partial

import numpy as np
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from tqdm import tqdm

from captum.attr import (DeepLift, GuidedBackprop, InputXGradient, Occlusion,
                         Saliency, ShapleyValueSampling,
                         configure_interpretable_embedding_layer,
                         remove_interpretable_embedding_layer)
from lime.lime_text import LimeTextExplainer
from pypapi import events, papi_high as high
from transformers import (AdamW, BertConfig, BertForSequenceClassification,
                          BertTokenizer, get_constant_schedule_with_warmup)


def generate_pred(model_path):
    """Runs the test set through a trained model and serializes predictions."""
    checkpoint = torch.load(model_path,
                            map_location=lambda storage, loc: storage)
    model_args = Namespace(**checkpoint['args'])

    if args.model == 'lstm':
        model = LSTM_MODEL(tokenizer, model_args,
                           n_labels=checkpoint['args']['labels']).to(device)
        model.load_state_dict(checkpoint['model'])
    elif args.model == 'trans':
        transformer_config = BertConfig.from_pretrained(
            'bert-base-uncased', num_labels=model_args.labels)
        model_cp = BertForSequenceClassification.from_pretrained(
            'bert-base-uncased', config=transformer_config).to(device)
        model_cp.load_state_dict(checkpoint['model'])
        model = BertModelWrapper(model_cp)
    else:
        model = CNN_MODEL(tokenizer, model_args,
                          n_labels=checkpoint['args']['labels']).to(device)
        model.load_state_dict(checkpoint['model'])

    # Inference only: eval mode disables dropout so predictions are
    # deterministic.
    model.eval()

    pad_to_max = False
    coll_call = get_collate_fn(dataset=args.dataset, model=args.model)
    return_attention_masks = args.model == 'trans'
    collate_fn = partial(coll_call, tokenizer=tokenizer, device=device,
                         return_attention_masks=return_attention_masks,
                         pad_to_max_length=pad_to_max)
    test = get_dataset(path=args.dataset_dir, mode=args.split,
                       dataset=args.dataset)
    batch_size = args.batch_size if args.batch_size is not None \
        else model_args.batch_size
    test_dl = DataLoader(batch_size=batch_size, dataset=test, shuffle=False,
                         collate_fn=collate_fn)

    # PREDICTIONS
    predictions_path = model_path + '.predictions'
    predictions = defaultdict(lambda: [])
    for batch in tqdm(test_dl, desc='Running test prediction...'):
        if args.model == 'trans':
            logits = model(batch[0], attention_mask=batch[1],
                           labels=batch[2].long())
        else:
            logits = model(batch[0])
        logits = logits.detach().cpu().numpy().tolist()
        predicted = np.argmax(np.array(logits), axis=-1)
        predictions['class'] += predicted.tolist()
        predictions['logits'] += logits

    with open(predictions_path, 'w') as out:
        json.dump(predictions, out)
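
# A small sketch of consuming the '<model_path>.predictions' file written by
# generate_pred() above; the path below is a hypothetical example.
import json

with open('cnn_model.pt.predictions') as f:
    preds = json.load(f)
print(preds['class'][:5])    # argmax class per test instance
print(len(preds['logits']))  # one logits list per test instance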
def get_model():
    if args.model == 'trans':
        transformer_config = BertConfig.from_pretrained(
            'bert-base-uncased', num_labels=args.labels)
        if args.init_only:
            # Architecture only, randomly initialized (no pre-trained
            # weights).
            model = BertForSequenceClassification(
                config=transformer_config).to(device)
        else:
            model = BertForSequenceClassification.from_pretrained(
                'bert-base-uncased', config=transformer_config).to(device)

        # Standard BERT fine-tuning recipe: no weight decay on biases and
        # LayerNorm parameters.
        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer
                        if not any(nd in n for nd in no_decay)],
             'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer
                        if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0}
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=args.lr)
        es = EarlyStopping(patience=args.patience, percentage=False,
                           mode='max', min_delta=0.0)
        # Note: get_constant_schedule_with_warmup expects an integer number
        # of warm-up steps; a float like 0.05 effectively disables warm-up.
        scheduler = get_constant_schedule_with_warmup(optimizer,
                                                      num_warmup_steps=0.05)
    else:
        if args.model == 'cnn':
            model = CNN_MODEL(tokenizer, args,
                              n_labels=args.labels).to(device)
        elif args.model == 'lstm':
            model = LSTM_MODEL(tokenizer, args,
                               n_labels=args.labels).to(device)
        optimizer = AdamW(model.parameters(), lr=args.lr)
        scheduler = ReduceLROnPlateau(optimizer, verbose=True)
        es = EarlyStopping(patience=args.patience, percentage=False,
                           mode='max', min_delta=0.0)

    return model, optimizer, scheduler, es
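
# A minimal, self-contained illustration of the "no weight decay on bias and
# LayerNorm" grouping used in get_model() above, on a toy module. ToyBlock and
# its parameter names are illustrative assumptions, not repo code; the
# attribute name LayerNorm mirrors BERT's naming so the substring match works.
import torch
from torch import nn

class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(8, 8)
        self.LayerNorm = nn.LayerNorm(8)

toy = ToyBlock()
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
groups = [
    {'params': [p for n, p in toy.named_parameters()
                if not any(nd in n for nd in no_decay)],
     'weight_decay': 0.01},  # dense.weight
    {'params': [p for n, p in toy.named_parameters()
                if any(nd in n for nd in no_decay)],
     'weight_decay': 0.0},   # dense.bias, LayerNorm.weight, LayerNorm.bias
]
optimizer = torch.optim.AdamW(groups, lr=2e-5)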
def get_model(model_path):
    """Restores a trained model (and its training args) from a checkpoint."""
    checkpoint = torch.load(model_path,
                            map_location=lambda storage, loc: storage)
    model_args = Namespace(**checkpoint['args'])

    if args.model == 'lstm':
        model_cp = LSTM_MODEL(tokenizer, model_args,
                              n_labels=checkpoint['args']['labels']).to(device)
    elif args.model == 'trans':
        transformer_config = BertConfig.from_pretrained(
            'bert-base-uncased', num_labels=model_args.labels)
        model_cp = BertForSequenceClassification.from_pretrained(
            'bert-base-uncased', config=transformer_config).to(device)
    else:
        model_cp = CNN_MODEL(tokenizer, model_args,
                             n_labels=checkpoint['args']['labels']).to(device)

    model_cp.load_state_dict(checkpoint['model'])
    return model_cp, model_args
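
# Hypothetical usage of the checkpoint loader above; the path is an example.
model, model_args = get_model('lstm_model.pt')
model.eval()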
torch.cuda.manual_seed_all(args.seed)
torch.backends.cudnn.deterministic = True

device = torch.device("cuda") if args.gpu else torch.device("cpu")

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
collate_fn = partial(collate_nli, tokenizer=tokenizer, device=device,
                     return_attention_masks=False, pad_to_max_length=False)
# Sort by combined premise/hypothesis length so buckets hold
# similarly-sized pairs.
sort_key = lambda x: len(x[0]) + len(x[1])

if args.model == 'lstm':
    model = LSTM_MODEL(tokenizer, args, n_labels=args.labels).to(device)
else:
    model = CNN_MODEL(tokenizer, args, n_labels=args.labels).to(device)

if args.mode == 'test':
    test = NLIDataset(args.dataset_dir, type='test')
    test_dl = BucketBatchSampler(batch_size=args.batch_size,
                                 sort_key=sort_key, dataset=test,
                                 collate_fn=collate_fn)
    optimizer = AdamW(model.parameters(), lr=args.lr)

    scores = []
    for model_path in args.model_path:
        checkpoint = torch.load(model_path)
        model.load_state_dict(checkpoint['model'])
        p, r, f1, loss, _, _ = eval_model(model, test_dl)
def generate_saliency(model_path, saliency_path):
    test = get_dataset(path=args.dataset_dir, mode=args.split,
                       dataset=args.dataset)

    checkpoint = torch.load(model_path,
                            map_location=lambda storage, loc: storage)
    model_args = argparse.Namespace(**checkpoint['args'])

    if args.model == 'trans':
        model_args.batch_size = 7
        transformer_config = BertConfig.from_pretrained(
            'bert-base-uncased', num_labels=model_args.labels)
        model = BertForSequenceClassification.from_pretrained(
            'bert-base-uncased', config=transformer_config).to(device)
        model.load_state_dict(checkpoint['model'])
        modelw = BertModelWrapper(model, device, tokenizer, model_args)
    else:
        if args.model == 'lstm':
            model_args.batch_size = 200
            model = LSTM_MODEL(tokenizer, model_args,
                               n_labels=checkpoint['args']['labels'],
                               device=device).to(device)
        else:
            model_args.batch_size = 300
            model = CNN_MODEL(tokenizer, model_args,
                              n_labels=checkpoint['args']['labels']).to(device)
        model.load_state_dict(checkpoint['model'])
        modelw = ModelWrapper(model, device, tokenizer, model_args)

    modelw.eval()

    explainer = LimeTextExplainer()
    saliency_flops = []

    with open(saliency_path, 'w') as out:
        for instance in tqdm(test):
            # SALIENCY
            if not args.no_time:
                high.start_counters([events.PAPI_FP_OPS])

            saliencies = []
            if args.dataset in ['imdb', 'tweet']:
                token_ids = tokenizer.encode(instance[0])
            else:
                token_ids = tokenizer.encode(instance[0], instance[1])
            # LIME masks individual tokens, so very short inputs are padded
            # to a minimum length.
            if len(token_ids) < 6:
                token_ids = token_ids + [tokenizer.pad_token_id] * \
                            (6 - len(token_ids))
            try:
                exp = explainer.explain_instance(
                    " ".join([str(i) for i in token_ids]), modelw,
                    num_features=len(token_ids), top_labels=args.labels)
            except Exception as e:
                # If LIME fails on an instance, emit zero saliencies so the
                # output file stays aligned with the dataset.
                print(e)
                if not args.no_time:
                    x = high.stop_counters()[0]
                    saliency_flops.append(x)
                for token_id in token_ids:
                    token_id = int(token_id)
                    token_saliency = {
                        'token': tokenizer.ids_to_tokens[token_id]
                    }
                    for cls_ in range(args.labels):
                        token_saliency[int(cls_)] = 0
                    saliencies.append(token_saliency)
                out.write(json.dumps({'tokens': saliencies}) + '\n')
                out.flush()
                continue

            if not args.no_time:
                x = high.stop_counters()[0]
                saliency_flops.append(x)

            # SERIALIZE
            explanation = {}
            for cls_ in range(args.labels):
                cls_expl = {}
                for (w, s) in exp.as_list(label=cls_):
                    cls_expl[int(w)] = s
                explanation[cls_] = cls_expl

            for token_id in token_ids:
                token_id = int(token_id)
                token_saliency = {'token': tokenizer.ids_to_tokens[token_id]}
                for cls_ in range(args.labels):
                    token_saliency[int(cls_)] = explanation[cls_].get(
                        token_id, None)
                saliencies.append(token_saliency)

            out.write(json.dumps({'tokens': saliencies}) + '\n')
            out.flush()

    return saliency_flops
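
# A minimal, self-contained LimeTextExplainer sketch mirroring the call
# pattern above, with a dummy probability scorer standing in for the wrapped
# model; the scorer and the token-id string are illustrative assumptions.
import numpy as np
from lime.lime_text import LimeTextExplainer

def dummy_predict_proba(texts):
    # Toy two-class scorer: probability of class 1 grows with text length.
    p1 = np.array([min(len(t) / 100.0, 1.0) for t in texts])
    return np.stack([1.0 - p1, p1], axis=1)

explainer = LimeTextExplainer()
exp = explainer.explain_instance('101 2023 2003 1037 2204 3185 102',
                                 dummy_predict_proba, num_features=7,
                                 top_labels=2)
print(exp.as_list(label=1))  # [(token-id string, weight), ...]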
def generate_saliency(model_path, saliency_path, saliency, aggregation):
    checkpoint = torch.load(model_path,
                            map_location=lambda storage, loc: storage)
    model_args = Namespace(**checkpoint['args'])

    if args.model == 'lstm':
        model = LSTM_MODEL(tokenizer, model_args,
                           n_labels=checkpoint['args']['labels']).to(device)
        model.load_state_dict(checkpoint['model'])
    elif args.model == 'trans':
        transformer_config = BertConfig.from_pretrained(
            'bert-base-uncased', num_labels=model_args.labels)
        model_cp = BertForSequenceClassification.from_pretrained(
            'bert-base-uncased', config=transformer_config).to(device)
        model_cp.load_state_dict(checkpoint['model'])
        model = BertModelWrapper(model_cp)
    else:
        model = CNN_MODEL(tokenizer, model_args,
                          n_labels=checkpoint['args']['labels']).to(device)
        model.load_state_dict(checkpoint['model'])

    # train() is needed here: the gradient-based attribution methods below
    # backpropagate through the model, and cuDNN RNNs only support backward
    # passes in training mode.
    model.train()

    pad_to_max = False
    if saliency == 'deeplift':
        ablator = DeepLift(model)
    elif saliency == 'guided':
        ablator = GuidedBackprop(model)
    elif saliency == 'sal':
        ablator = Saliency(model)
    elif saliency == 'inputx':
        ablator = InputXGradient(model)
    elif saliency == 'occlusion':
        ablator = Occlusion(model)

    coll_call = get_collate_fn(dataset=args.dataset, model=args.model)
    return_attention_masks = args.model == 'trans'
    collate_fn = partial(coll_call, tokenizer=tokenizer, device=device,
                         return_attention_masks=return_attention_masks,
                         pad_to_max_length=pad_to_max)
    test = get_dataset(path=args.dataset_dir, mode=args.split,
                       dataset=args.dataset)
    batch_size = args.batch_size if args.batch_size is not None \
        else model_args.batch_size
    test_dl = DataLoader(batch_size=batch_size, dataset=test, shuffle=False,
                         collate_fn=collate_fn)

    # PREDICTIONS
    predictions_path = model_path + '.predictions'
    if not os.path.exists(predictions_path):
        predictions = defaultdict(lambda: [])
        for batch in tqdm(test_dl, desc='Running test prediction...'):
            if args.model == 'trans':
                logits = model(batch[0], attention_mask=batch[1],
                               labels=batch[2].long())
            else:
                logits = model(batch[0])
            logits = logits.detach().cpu().numpy().tolist()
            predicted = np.argmax(np.array(logits), axis=-1)
            predictions['class'] += predicted.tolist()
            predictions['logits'] += logits

        with open(predictions_path, 'w') as out:
            json.dump(predictions, out)

    # COMPUTE SALIENCY
    if saliency != 'occlusion':
        # Swap the embedding layer for an identity wrapper so attributions
        # can be computed directly on the embedding vectors.
        embedding_layer_name = 'model.bert.embeddings' \
            if args.model == 'trans' else 'embedding'
        interpretable_embedding = configure_interpretable_embedding_layer(
            model, embedding_layer_name)

    class_attr_list = defaultdict(lambda: [])
    token_ids = []
    saliency_flops = []

    for batch in tqdm(test_dl, desc='Running Saliency Generation...'):
        if args.model == 'cnn':
            additional = None
        elif args.model == 'trans':
            additional = (batch[1], batch[2])
        else:
            additional = batch[-1]

        token_ids += batch[0].detach().cpu().numpy().tolist()
        if saliency != 'occlusion':
            input_embeddings = interpretable_embedding.indices_to_embeddings(
                batch[0])

        if not args.no_time:
            high.start_counters([events.PAPI_FP_OPS])
        for cls_ in range(checkpoint['args']['labels']):
            if saliency == 'occlusion':
                attributions = ablator.attribute(
                    batch[0], sliding_window_shapes=(args.sw,), target=cls_,
                    additional_forward_args=additional)
            else:
                attributions = ablator.attribute(
                    input_embeddings, target=cls_,
                    additional_forward_args=additional)
            attributions = summarize_attributions(
                attributions, type=aggregation, model=model,
                tokens=batch[0]).detach().cpu().numpy().tolist()
            class_attr_list[cls_] += [[_li for _li in _l]
                                      for _l in attributions]
        if not args.no_time:
            saliency_flops.append(
                sum(high.stop_counters()) / batch[0].shape[0])

    if saliency != 'occlusion':
        remove_interpretable_embedding_layer(model, interpretable_embedding)

    # SERIALIZE
    print('Serializing...', flush=True)
    with open(saliency_path, 'w') as out:
        for instance_i, _ in enumerate(test):
            saliencies = []
            for token_i, token_id in enumerate(token_ids[instance_i]):
                token_sal = {'token': tokenizer.ids_to_tokens[token_id]}
                for cls_ in range(checkpoint['args']['labels']):
                    token_sal[int(cls_)] = \
                        class_attr_list[cls_][instance_i][token_i]
                saliencies.append(token_sal)

            out.write(json.dumps({'tokens': saliencies}) + '\n')
            out.flush()

    return saliency_flops
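
# A minimal, self-contained Captum sketch of the embedding-layer swap used
# above: InputXGradient on a toy embedding classifier. ToyClassifier and its
# attribute names are illustrative assumptions, not repo code.
import torch
from torch import nn
from captum.attr import (InputXGradient,
                         configure_interpretable_embedding_layer,
                         remove_interpretable_embedding_layer)

class ToyClassifier(nn.Module):
    def __init__(self, vocab=100, dim=8, n_labels=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab, dim)
        self.out = nn.Linear(dim, n_labels)

    def forward(self, x):
        return self.out(self.embedding(x).mean(dim=1))

model = ToyClassifier()
# After the swap, the model's embedding layer passes inputs through
# unchanged, so the forward pass can consume precomputed embeddings.
wrapped = configure_interpretable_embedding_layer(model, 'embedding')
token_ids = torch.tensor([[5, 17, 42]])
input_embeddings = wrapped.indices_to_embeddings(token_ids)
attributions = InputXGradient(model).attribute(input_embeddings, target=1)
print(attributions.shape)  # (1, 3, 8): one attribution vector per token
remove_interpretable_embedding_layer(model, wrapped)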
def generate_saliency(model_path, saliency_path):
    checkpoint = torch.load(model_path,
                            map_location=lambda storage, loc: storage)
    model_args = Namespace(**checkpoint['args'])
    model_args.batch_size = args.batch_size if args.batch_size is not None \
        else model_args.batch_size

    if args.model == 'transformer':
        transformer_config = BertConfig.from_pretrained(
            'bert-base-uncased', num_labels=model_args.labels)
        modelb = BertForSequenceClassification.from_pretrained(
            'bert-base-uncased', config=transformer_config).to(device)
        modelb.load_state_dict(checkpoint['model'])
        model = BertModelWrapper(modelb)
    elif args.model == 'lstm':
        model = LSTM_MODEL(tokenizer, model_args,
                           n_labels=checkpoint['args']['labels'],
                           device=device).to(device)
        model.load_state_dict(checkpoint['model'])
        model.train()
        model = ModelWrapper(model)
    else:
        model = CNN_MODEL(tokenizer, model_args,
                          n_labels=checkpoint['args']['labels']).to(device)
        model.load_state_dict(checkpoint['model'])
        model.train()
        model = ModelWrapper(model)

    ablator = ShapleyValueSampling(model)

    coll_call = get_collate_fn(dataset=args.dataset, model=args.model)
    collate_fn = partial(coll_call, tokenizer=tokenizer, device=device,
                         return_attention_masks=False,
                         pad_to_max_length=False)
    test = get_dataset(args.dataset_dir, mode=args.split)
    test_dl = DataLoader(batch_size=model_args.batch_size, dataset=test,
                         shuffle=False, collate_fn=collate_fn)

    # PREDICTIONS
    predictions_path = model_path + '.predictions'
    if not os.path.exists(predictions_path):
        predictions = defaultdict(lambda: [])
        for batch in tqdm(test_dl, desc='Running test prediction...'):
            logits = model(batch[0])
            logits = logits.detach().cpu().numpy().tolist()
            predicted = np.argmax(np.array(logits), axis=-1)
            predictions['class'] += predicted.tolist()
            predictions['logits'] += logits

        with open(predictions_path, 'w') as out:
            json.dump(predictions, out)

    # COMPUTE SALIENCY
    saliency_flops = []
    with open(saliency_path, 'w') as out_mean:
        for batch in tqdm(test_dl, desc='Running Saliency Generation...'):
            class_attr_list = defaultdict(lambda: [])
            # Only the LSTM forward needs the sequence lengths as an extra
            # argument.
            if args.model == 'lstm':
                additional = batch[-1]
            else:
                additional = None

            if not args.no_time:
                high.start_counters([events.PAPI_FP_OPS])

            token_ids = batch[0].detach().cpu().numpy().tolist()
            for cls_ in range(args.labels):
                # Shapley value sampling perturbs the token ids directly, so
                # the input is passed as floats rather than embeddings.
                attributions = ablator.attribute(
                    batch[0].float(), target=cls_,
                    additional_forward_args=additional)
                attributions = attributions.detach().cpu().numpy().tolist()
                class_attr_list[cls_] += attributions

            if not args.no_time:
                x = sum(high.stop_counters())
                saliency_flops.append(x / batch[0].shape[0])

            for i in range(len(batch[0])):
                saliencies = []
                for token_i, token_id in enumerate(token_ids[i]):
                    if token_id == tokenizer.pad_token_id:
                        continue
                    token_sal = {'token': tokenizer.ids_to_tokens[token_id]}
                    for cls_ in range(args.labels):
                        token_sal[int(cls_)] = \
                            class_attr_list[cls_][i][token_i]
                    saliencies.append(token_sal)

                out_mean.write(json.dumps({'tokens': saliencies}) + '\n')
                out_mean.flush()

    return saliency_flops
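
# A minimal, self-contained ShapleyValueSampling sketch of the call above,
# using a toy scorer over raw inputs; the toy model is an illustrative
# assumption, not repo code.
import torch
from torch import nn
from captum.attr import ShapleyValueSampling

class ToyScorer(nn.Module):
    def forward(self, x):
        # Two-class logits from a linear function of the raw inputs.
        return torch.stack([x.sum(dim=1), -x.sum(dim=1)], dim=1)

ablator = ShapleyValueSampling(ToyScorer())
inputs = torch.tensor([[101.0, 2023.0, 102.0]])
attributions = ablator.attribute(inputs, target=0, n_samples=25)
print(attributions)  # one importance score per input feature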