def validate(model, ids, bptt=2000):
    """
    Return the validation loss and perplexity of a model

    :param model: model to test
    :param ids: data on which to evaluate the model
    :param bptt: bptt for this evaluation (doesn't change the result, only the speed)

    From https://github.com/sgugger/Adam-experiments/blob/master/lm_val_fns.py#L34
    """
    data = TextReader(np.concatenate(ids), bptt)
    model.eval()
    model.reset()
    total_loss, num_examples = 0., 0
    for inputs, targets in tqdm(data):
        outputs, raws, outs = model(to_device(inputs, None))
        p_vocab = F.softmax(outputs, 1)
        for i, pv in enumerate(p_vocab):
            targ_pred = pv[targets[i]]
            total_loss -= torch.log(targ_pred.detach())
        num_examples += len(inputs)
    mean = total_loss / num_examples  # divide by total number of tokens
    return mean, np.exp(mean)
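
# Usage sketch for validate() (a sketch, not a fixed API: `lm` and `val_ids`
# are hypothetical names; the model is assumed to return the 3-tuple and
# support .reset() as above):
#
#     val_loss, val_ppl = validate(lm, val_ids, bptt=2000)
#     print(f'loss={float(val_loss):.4f} ppl={float(val_ppl):.2f}')
#
# Perplexity is exp(mean negative log-likelihood per token), which is why the
# function returns np.exp(mean) alongside the mean loss.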
def _calc_test_hvps_log_loss(criterion: Callable,
                             trained_model: nn.Module,
                             train_dataloader: DeviceDataLoader,
                             test_dataloader: DataLoader,
                             run_params: Mapping,
                             diff_wrt: Optional[List[torch.Tensor]] = None,
                             show_progress: bool = False):
    device = train_dataloader.device
    diff_wrt = maybe(diff_wrt,
                     [p for p in trained_model.parameters() if p.requires_grad])
    if run_params['use_gauss_newton']:
        matmul_class: Union[Type[GNP], Type[HVP]] = GNP
        damping = 0.001
    else:
        matmul_class = HVP
        damping = 0.01
    matmul = matmul_class(
        calc_loss=lambda xs, target: criterion(trained_model(*xs), target),
        parameters=diff_wrt,
        data=train_dataloader,
        data_len=len(train_dataloader),
        damping=damping,
        cache_batch=True)
    cg = CG(matmul=matmul,
            result_len=sum(p.numel() for p in diff_wrt),
            max_iters=maybe(run_params['max_cg_iters'],
                            sum(p.numel() for p in diff_wrt)))
    test_hvps: List[torch.Tensor] = []
    iterator = progressbar(test_dataloader) if show_progress else test_dataloader
    for batch in iterator:
        x_test, target = to_device(batch, device)
        loss_at_x_test = criterion(trained_model(*x_test), target.squeeze())
        grads = autograd.grad(loss_at_x_test, diff_wrt)
        grad_at_z_test = collect(grads)
        # Warm-start each solve from the previous batch's solution.
        if len(test_hvps) != 0:
            init = test_hvps[-1].detach().clone()
        else:
            init = torch.zeros_like(grad_at_z_test)

        def _min(x, grad_at_z_test=grad_at_z_test):
            x_tens = torch.tensor(
                x, device=device) if not isinstance(x, torch.Tensor) else x
            grad_tens = torch.tensor(
                grad_at_z_test, device=device) if not isinstance(
                    grad_at_z_test, torch.Tensor) else grad_at_z_test
            return np.array(0.5 * matmul(x_tens).dot(x_tens) -
                            grad_tens.dot(x_tens))

        def _grad(x, grad_at_z_test=grad_at_z_test):
            x_tens = torch.tensor(
                x, device=device) if not isinstance(x, torch.Tensor) else x
            return np.array(matmul(x_tens) - grad_at_z_test)

        def _hess(x, p):
            grad_tens = torch.tensor(
                p, device=device) if not isinstance(p, torch.Tensor) else p
            return np.array(matmul(grad_tens))

        if run_params.get('use_scipy', False):
            test_hvps.append(
                torch.tensor(fmin_ncg(f=_min,
                                      x0=init,
                                      fprime=_grad,
                                      fhess_p=_hess,
                                      avextol=1e-8,
                                      maxiter=100,
                                      disp=False),
                             device=device))
        else:
            test_hvps.append(
                cg.solve(grad_at_z_test,
                         test_hvps[-1] if len(test_hvps) != 0 else None))
        matmul.clear_batch()
    return torch.stack(test_hvps)
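
# In influence-function terms (Koh & Liang, 2017), the function above solves,
# for each test batch,
#
#     (H + damping * I) s_test = grad_theta L(z_test)
#
# so each row of the returned stack is s_test ~= H^{-1} grad_theta L(z_test),
# where H is the training-loss Hessian (or its Gauss-Newton approximation
# when run_params['use_gauss_newton'] is set). A minimal calling sketch
# (hypothetical loaders and params mapping, not a fixed API):
#
#     run_params = {'use_gauss_newton': True, 'max_cg_iters': 100}
#     s_tests = _calc_test_hvps_log_loss(F.cross_entropy, trained_model,
#                                        train_dl, test_dl, run_params)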
def main():
    global model_to_save
    global experiment
    global rabbit
    rabbit = MyRabbit(args)
    if rabbit.model_params.dont_limit_num_uniq_tokens:
        raise NotImplementedError()
    if rabbit.model_params.frame_as_qa:
        raise NotImplementedError()
    if rabbit.run_params.drop_val_loss_calc:
        raise NotImplementedError()
    if rabbit.run_params.use_softrank_influence and not rabbit.run_params.freeze_all_but_last_for_influence:
        raise NotImplementedError()
    if rabbit.train_params.weight_influence:
        raise NotImplementedError()
    experiment = Experiment(rabbit.train_params + rabbit.model_params +
                            rabbit.run_params)
    print('Model name:', experiment.model_name)
    use_pretrained_doc_encoder = rabbit.model_params.use_pretrained_doc_encoder
    use_pointwise_loss = rabbit.train_params.use_pointwise_loss
    query_token_embed_len = rabbit.model_params.query_token_embed_len
    document_token_embed_len = rabbit.model_params.document_token_embed_len
    _names = []
    if not rabbit.model_params.dont_include_titles:
        _names.append('with_titles')
    if rabbit.train_params.num_doc_tokens_to_consider != -1:
        _names.append('num_doc_toks_' +
                      str(rabbit.train_params.num_doc_tokens_to_consider))
    if not rabbit.run_params.just_caches:
        if rabbit.model_params.dont_include_titles:
            document_lookup = read_cache(name('./doc_lookup.json', _names),
                                         get_robust_documents)
        else:
            document_lookup = read_cache(name('./doc_lookup.json', _names),
                                         get_robust_documents_with_titles)
    num_doc_tokens_to_consider = rabbit.train_params.num_doc_tokens_to_consider
    document_title_to_id = read_cache(
        './document_title_to_id.json',
        lambda: create_id_lookup(document_lookup.keys()))
    with open('./caches/106756_most_common_doc.json', 'r') as fh:
        doc_token_set = set(json.load(fh))
    tokenizer = Tokenizer()
    tokenized = set(
        sum(tokenizer.process_all(list(get_robust_eval_queries().values())),
            []))
    doc_token_set = doc_token_set.union(tokenized)
    use_bow_model = not any([
        rabbit.model_params[attr] for attr in
        ['use_doc_out', 'use_cnn', 'use_lstm', 'use_pretrained_doc_encoder']
    ])
    use_bow_model = use_bow_model and not rabbit.model_params.dont_use_bow
    if use_bow_model:
        documents, document_token_lookup = read_cache(
            name('./docs_fs_tokens_limit_uniq_toks_qrels_and_106756.pkl',
                 _names),
            lambda: prepare_fs(document_lookup,
                               document_title_to_id,
                               num_tokens=num_doc_tokens_to_consider,
                               token_set=doc_token_set))
        if rabbit.model_params.keep_top_uniq_terms is not None:
            documents = [
                dict(
                    nlargest(rabbit.model_params.keep_top_uniq_terms,
                             _.to_pairs(doc), itemgetter(1)))
                for doc in documents
            ]
    else:
        documents, document_token_lookup = read_cache(
            name(
                f'./parsed_docs_{num_doc_tokens_to_consider}_tokens_limit_uniq_toks_qrels_and_106756.json',
                _names),
            lambda: prepare(document_lookup,
                            document_title_to_id,
                            num_tokens=num_doc_tokens_to_consider,
                            token_set=doc_token_set))
    if not rabbit.run_params.just_caches:
        train_query_lookup = read_cache('./robust_train_queries.json',
                                        get_robust_train_queries)
        train_query_name_to_id = read_cache(
            './train_query_name_to_id.json',
            lambda: create_id_lookup(train_query_lookup.keys()))
    train_queries, query_token_lookup = read_cache(
        './parsed_robust_queries_dict.json',
        lambda: prepare(train_query_lookup,
                        train_query_name_to_id,
                        token_lookup=document_token_lookup,
                        token_set=doc_token_set,
                        drop_if_any_unk=True))
    query_tok_to_doc_tok = {
        idx: document_token_lookup.get(query_token)
        or document_token_lookup['<unk>']
        for query_token, idx in query_token_lookup.items()
    }
    names = [RANKER_NAME_TO_SUFFIX[rabbit.train_params.ranking_set]]
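    # Training data comes from the cached Indri run for the chosen ranking
    # set; cache-only invocations that also skip pointwise training get an
    # empty list instead.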
    if rabbit.train_params.use_pointwise_loss or not rabbit.run_params.just_caches:
        train_data = read_cache(
            name('./robust_train_query_results_tokens_qrels_and_106756.json',
                 names),
            lambda: read_query_result(
                train_query_name_to_id,
                document_title_to_id,
                train_queries,
                path='./indri/query_result' +
                RANKER_NAME_TO_SUFFIX[rabbit.train_params.ranking_set]))
    else:
        train_data = []
    q_embed_len = rabbit.model_params.query_token_embed_len
    doc_embed_len = rabbit.model_params.document_token_embed_len
    if rabbit.model_params.append_difference or rabbit.model_params.append_hadamard:
        assert q_embed_len == doc_embed_len, 'Must use same size doc and query embeds when appending diff or hadamard'
    if q_embed_len == doc_embed_len:
        glove_lookup = get_glove_lookup(
            embedding_dim=q_embed_len,
            use_large_embed=rabbit.model_params.use_large_embed,
            use_word2vec=rabbit.model_params.use_word2vec)
        q_glove_lookup = glove_lookup
        doc_glove_lookup = glove_lookup
    else:
        q_glove_lookup = get_glove_lookup(
            embedding_dim=q_embed_len,
            use_large_embed=rabbit.model_params.use_large_embed,
            use_word2vec=rabbit.model_params.use_word2vec)
        doc_glove_lookup = get_glove_lookup(
            embedding_dim=doc_embed_len,
            use_large_embed=rabbit.model_params.use_large_embed,
            use_word2vec=rabbit.model_params.use_word2vec)
    num_query_tokens = len(query_token_lookup)
    num_doc_tokens = len(document_token_lookup)
    doc_encoder = None
    if use_pretrained_doc_encoder or rabbit.model_params.use_doc_out:
        doc_encoder, document_token_embeds = get_doc_encoder_and_embeddings(
            document_token_lookup, rabbit.model_params.only_use_last_out)
        if rabbit.model_params.use_glove:
            query_token_embeds_init = init_embedding(q_glove_lookup,
                                                     query_token_lookup,
                                                     num_query_tokens,
                                                     query_token_embed_len)
        else:
            query_token_embeds_init = from_doc_to_query_embeds(
                document_token_embeds, document_token_lookup,
                query_token_lookup)
        if not rabbit.train_params.dont_freeze_pretrained_doc_encoder:
            dont_update(doc_encoder)
        if rabbit.model_params.use_doc_out:
            doc_encoder = None
    else:
        document_token_embeds = init_embedding(doc_glove_lookup,
                                               document_token_lookup,
                                               num_doc_tokens,
                                               document_token_embed_len)
        if rabbit.model_params.use_single_word_embed_set:
            query_token_embeds_init = document_token_embeds
        else:
            query_token_embeds_init = init_embedding(q_glove_lookup,
                                                     query_token_lookup,
                                                     num_query_tokens,
                                                     query_token_embed_len)
    if not rabbit.train_params.dont_freeze_word_embeds:
        dont_update(document_token_embeds)
        dont_update(query_token_embeds_init)
    else:
        do_update(document_token_embeds)
        do_update(query_token_embeds_init)
    if rabbit.train_params.add_rel_score:
        query_token_embeds, additive = get_additive_regularized_embeds(
            query_token_embeds_init)
        rel_score = RelScore(query_token_embeds, document_token_embeds,
                             rabbit.model_params, rabbit.train_params)
    else:
        query_token_embeds = query_token_embeds_init
        additive = None
        rel_score = None
    eval_query_lookup = get_robust_eval_queries()
    eval_query_name_document_title_rels = get_robust_rels()
    test_query_names = []
    val_query_names = []
    for query_name in eval_query_lookup:
        if len(val_query_names) >= 50:
            test_query_names.append(query_name)
        else:
            val_query_names.append(query_name)
    test_query_name_document_title_rels = _.pick(
        eval_query_name_document_title_rels, test_query_names)
    test_query_lookup = _.pick(eval_query_lookup, test_query_names)
    test_query_name_to_id = create_id_lookup(test_query_lookup.keys())
    test_queries, __ = prepare(test_query_lookup,
                               test_query_name_to_id,
                               token_lookup=query_token_lookup)
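    # Candidate rankings for every held-out query come from the Indri test
    # run; they are split below into test and validation candidate sets.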
    eval_ranking_candidates = read_query_test_rankings(
        './indri/query_result_test' +
        RANKER_NAME_TO_SUFFIX[rabbit.train_params.ranking_set])
    test_candidates_data = read_query_result(
        test_query_name_to_id,
        document_title_to_id,
        dict(zip(range(len(test_queries)), test_queries)),
        path='./indri/query_result_test' +
        RANKER_NAME_TO_SUFFIX[rabbit.train_params.ranking_set])
    test_ranking_candidates = process_raw_candidates(test_query_name_to_id,
                                                     test_queries,
                                                     document_title_to_id,
                                                     test_query_names,
                                                     eval_ranking_candidates)
    test_data = process_rels(test_query_name_document_title_rels,
                             document_title_to_id, test_query_name_to_id,
                             test_queries)
    val_query_name_document_title_rels = _.pick(
        eval_query_name_document_title_rels, val_query_names)
    val_query_lookup = _.pick(eval_query_lookup, val_query_names)
    val_query_name_to_id = create_id_lookup(val_query_lookup.keys())
    val_queries, __ = prepare(val_query_lookup,
                              val_query_name_to_id,
                              token_lookup=query_token_lookup)
    val_candidates_data = read_query_result(
        val_query_name_to_id,
        document_title_to_id,
        dict(zip(range(len(val_queries)), val_queries)),
        path='./indri/query_result_test' +
        RANKER_NAME_TO_SUFFIX[rabbit.train_params.ranking_set])
    val_ranking_candidates = process_raw_candidates(val_query_name_to_id,
                                                    val_queries,
                                                    document_title_to_id,
                                                    val_query_names,
                                                    eval_ranking_candidates)
    val_data = process_rels(val_query_name_document_title_rels,
                            document_title_to_id, val_query_name_to_id,
                            val_queries)
    train_normalized_score_lookup = read_cache(
        name('./train_normalized_score_lookup.pkl', names),
        lambda: get_normalized_score_lookup(train_data))
    test_normalized_score_lookup = get_normalized_score_lookup(
        test_candidates_data)
    val_normalized_score_lookup = get_normalized_score_lookup(
        val_candidates_data)
    if use_pointwise_loss:
        normalized_train_data = read_cache(
            name('./normalized_train_query_data_qrels_and_106756.json',
                 names),
            lambda: normalize_scores_query_wise(train_data))
        collate_fn = lambda samples: collate_query_samples(
            samples,
            use_bow_model=use_bow_model,
            use_dense=rabbit.model_params.use_dense)
        train_dl = build_query_dataloader(
            documents,
            normalized_train_data,
            rabbit.train_params,
            rabbit.model_params,
            cache=name('train_ranking_qrels_and_106756.json', names),
            limit=10,
            query_tok_to_doc_tok=query_tok_to_doc_tok,
            normalized_score_lookup=train_normalized_score_lookup,
            use_bow_model=use_bow_model,
            collate_fn=collate_fn,
            is_test=False)
        test_dl = build_query_dataloader(
            documents,
            test_data,
            rabbit.train_params,
            rabbit.model_params,
            query_tok_to_doc_tok=query_tok_to_doc_tok,
            normalized_score_lookup=test_normalized_score_lookup,
            use_bow_model=use_bow_model,
            collate_fn=collate_fn,
            is_test=True)
        val_dl = build_query_dataloader(
            documents,
            val_data,
            rabbit.train_params,
            rabbit.model_params,
            query_tok_to_doc_tok=query_tok_to_doc_tok,
            normalized_score_lookup=val_normalized_score_lookup,
            use_bow_model=use_bow_model,
            collate_fn=collate_fn,
            is_test=True)
        model = PointwiseScorer(query_token_embeds, document_token_embeds,
                                doc_encoder, rabbit.model_params,
                                rabbit.train_params)
    else:
        if rabbit.train_params.use_noise_aware_loss:
            ranker_query_str_to_rankings = get_ranker_query_str_to_rankings(
                train_query_name_to_id,
                document_title_to_id,
                train_queries,
                limit=rabbit.train_params.num_snorkel_train_queries)
            query_names = reduce(
                lambda acc, query_to_ranking: acc.intersection(
                    set(query_to_ranking.keys()))
                if len(acc) != 0 else set(query_to_ranking.keys()),
                ranker_query_str_to_rankings.values(), set())
            all_ranked_lists_by_ranker = _.map_values(
                ranker_query_str_to_rankings, lambda query_to_ranking:
                [query_to_ranking[query] for query in query_names])
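            # The pairwise agreement bins between rankers parameterize the
            # Snorkeller below; training it on the per-ranker ranked lists
            # yields calc_marginals, which the pairwise collate function uses
            # to produce noise-aware label marginals.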
            ranker_query_str_to_pairwise_bins = get_ranker_query_str_to_pairwise_bins(
                train_query_name_to_id,
                document_title_to_id,
                train_queries,
                limit=rabbit.train_params.num_train_queries)
            snorkeller = Snorkeller(ranker_query_str_to_pairwise_bins)
            snorkeller.train(all_ranked_lists_by_ranker)
            calc_marginals = snorkeller.calc_marginals
        else:
            calc_marginals = None
        collate_fn = lambda samples: collate_query_pairwise_samples(
            samples,
            use_bow_model=use_bow_model,
            calc_marginals=calc_marginals,
            use_dense=rabbit.model_params.use_dense)
        if rabbit.run_params.load_influences:
            try:
                with open(rabbit.run_params.influences_path) as fh:
                    pairs_to_flip = defaultdict(set)
                    for pair, influence in json.load(fh):
                        if rabbit.train_params.use_pointwise_loss:
                            condition = True
                        else:
                            condition = influence < rabbit.train_params.influence_thresh
                        if condition:
                            query = tuple(pair[1])
                            pairs_to_flip[query].add(tuple(pair[0]))
            except FileNotFoundError:
                pairs_to_flip = None
        else:
            pairs_to_flip = None
        train_dl = build_query_pairwise_dataloader(
            documents,
            train_data,
            rabbit.train_params,
            rabbit.model_params,
            pairs_to_flip=pairs_to_flip,
            cache=name('train_ranking_qrels_and_106756.json', names),
            limit=10,
            query_tok_to_doc_tok=query_tok_to_doc_tok,
            normalized_score_lookup=train_normalized_score_lookup,
            use_bow_model=use_bow_model,
            collate_fn=collate_fn,
            is_test=False)
        test_dl = build_query_pairwise_dataloader(
            documents,
            test_data,
            rabbit.train_params,
            rabbit.model_params,
            query_tok_to_doc_tok=query_tok_to_doc_tok,
            normalized_score_lookup=test_normalized_score_lookup,
            use_bow_model=use_bow_model,
            collate_fn=collate_fn,
            is_test=True)
        val_dl = build_query_pairwise_dataloader(
            documents,
            val_data,
            rabbit.train_params,
            rabbit.model_params,
            query_tok_to_doc_tok=query_tok_to_doc_tok,
            normalized_score_lookup=val_normalized_score_lookup,
            use_bow_model=use_bow_model,
            collate_fn=collate_fn,
            is_test=True)
        val_rel_dl = build_query_pairwise_dataloader(
            documents,
            val_data,
            rabbit.train_params,
            rabbit.model_params,
            query_tok_to_doc_tok=query_tok_to_doc_tok,
            normalized_score_lookup=val_normalized_score_lookup,
            use_bow_model=use_bow_model,
            collate_fn=collate_fn,
            is_test=True,
            rel_vs_irrel=True,
            candidates=val_ranking_candidates,
            num_to_rank=rabbit.run_params.num_to_rank)
        model = PairwiseScorer(query_token_embeds,
                               document_token_embeds,
                               doc_encoder,
                               rabbit.model_params,
                               rabbit.train_params,
                               use_bow_model=use_bow_model)
    train_ranking_dataset = RankingDataset(
        documents,
        train_dl.dataset.rankings,
        rabbit.train_params,
        rabbit.model_params,
        rabbit.run_params,
        query_tok_to_doc_tok=query_tok_to_doc_tok,
        normalized_score_lookup=train_normalized_score_lookup,
        use_bow_model=use_bow_model,
        use_dense=rabbit.model_params.use_dense)
    test_ranking_dataset = RankingDataset(
        documents,
        test_ranking_candidates,
        rabbit.train_params,
        rabbit.model_params,
        rabbit.run_params,
        relevant=test_dl.dataset.rankings,
        query_tok_to_doc_tok=query_tok_to_doc_tok,
        cheat=rabbit.run_params.cheat,
        normalized_score_lookup=test_normalized_score_lookup,
        use_bow_model=use_bow_model,
        use_dense=rabbit.model_params.use_dense)
    val_ranking_dataset = RankingDataset(
        documents,
        val_ranking_candidates,
        rabbit.train_params,
        rabbit.model_params,
        rabbit.run_params,
        relevant=val_dl.dataset.rankings,
        query_tok_to_doc_tok=query_tok_to_doc_tok,
        cheat=rabbit.run_params.cheat,
        normalized_score_lookup=val_normalized_score_lookup,
        use_bow_model=use_bow_model,
        use_dense=rabbit.model_params.use_dense)
    if rabbit.train_params.memorize_test:
        train_dl = test_dl
        train_ranking_dataset = test_ranking_dataset
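    # Note that the validation slot of the DataBunch below holds the
    # rel-vs-irrel pairwise loader (val_rel_dl), not the plain val_dl.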
    model_data = DataBunch(
        train_dl,
        val_rel_dl,
        test_dl,
        collate_fn=collate_fn,
        device=torch.device('cuda')
        if torch.cuda.is_available() else torch.device('cpu'))
    multi_objective_model = MultiObjective(model, rabbit.train_params,
                                           rel_score, additive)
    model_to_save = multi_objective_model
    # Free the large lookups that are no longer needed before training.
    if rabbit.train_params.memorize_test:
        try:
            del train_data
        except NameError:
            pass
    if not rabbit.run_params.just_caches:
        del document_lookup
        del train_query_lookup
        del query_token_lookup
        del document_token_lookup
        del train_queries
    try:
        del glove_lookup
    except UnboundLocalError:
        del q_glove_lookup
        del doc_glove_lookup
    if rabbit.run_params.load_model:
        try:
            multi_objective_model.load_state_dict(
                torch.load(rabbit.run_params.load_path))
        except RuntimeError:
            # Checkpoints saved from nn.DataParallel carry a 'module.' prefix.
            dp = nn.DataParallel(multi_objective_model)
            dp.load_state_dict(torch.load(rabbit.run_params.load_path))
            multi_objective_model = dp.module
    else:
        train_model(multi_objective_model, model_data, train_ranking_dataset,
                    val_ranking_dataset, test_ranking_dataset,
                    rabbit.train_params, rabbit.model_params,
                    rabbit.run_params, experiment)
    if rabbit.train_params.fine_tune_on_val:
        fine_tune_model_data = DataBunch(
            val_rel_dl,
            val_rel_dl,
            test_dl,
            collate_fn=collate_fn,
            device=torch.device('cuda')
            if torch.cuda.is_available() else torch.device('cpu'))
        train_model(multi_objective_model,
                    fine_tune_model_data,
                    val_ranking_dataset,
                    val_ranking_dataset,
                    test_ranking_dataset,
                    rabbit.train_params,
                    rabbit.model_params,
                    rabbit.run_params,
                    experiment,
                    load_path=rabbit.run_params.load_path)
    multi_objective_model.eval()
    device = model_data.device
    gpu_multi_objective_model = multi_objective_model.to(device)
    if rabbit.run_params.calc_influence:
        if rabbit.run_params.freeze_all_but_last_for_influence:
            last_layer_idx = _.find_last_index(
                multi_objective_model.model.pointwise_scorer.layers,
                lambda layer: isinstance(layer, nn.Linear))
            to_last_layer = lambda x: gpu_multi_objective_model(
                *x, to_idx=last_layer_idx)
            last_layer = gpu_multi_objective_model.model.pointwise_scorer.layers[
                last_layer_idx]
            diff_wrt = [p for p in last_layer.parameters() if p.requires_grad]
        else:
            diff_wrt = None
        test_hvps = calc_test_hvps(
            multi_objective_model.loss,
            gpu_multi_objective_model,
            DeviceDataLoader(train_dl, device, collate_fn=collate_fn),
            val_rel_dl,
            rabbit.run_params,
            diff_wrt=diff_wrt,
            show_progress=True,
            use_softrank_influence=rabbit.run_params.use_softrank_influence)
        influences = []
        if rabbit.train_params.use_pointwise_loss:
            num_real_samples = len(train_dl.dataset)
        else:
            num_real_samples = train_dl.dataset._num_pos_pairs
        if rabbit.run_params.freeze_all_but_last_for_influence:
            _sampler = SequentialSamplerWithLimit(train_dl.dataset,
                                                  num_real_samples)
            _batch_sampler = BatchSampler(_sampler,
                                          rabbit.train_params.batch_size,
                                          False)
            _dl = DataLoader(train_dl.dataset,
                             batch_sampler=_batch_sampler,
                             collate_fn=collate_fn)
            sequential_train_dl = DeviceDataLoader(_dl,
                                                   device,
                                                   collate_fn=collate_fn)
            influences = calc_dataset_influence(gpu_multi_objective_model,
                                                to_last_layer,
                                                sequential_train_dl,
                                                test_hvps,
                                                sum_p=True).tolist()
        else:
            for i in progressbar(range(num_real_samples)):
                train_sample = train_dl.dataset[i]
                x, labels = to_device(collate_fn([train_sample]), device)
                device_train_sample = (x, labels.squeeze())
                influences.append(
                    calc_influence(multi_objective_model.loss,
                                   gpu_multi_objective_model,
                                   device_train_sample,
                                   test_hvps,
                                   diff_wrt=diff_wrt).sum().tolist())
        with open(rabbit.run_params.influences_path, 'w+') as fh:
            json.dump([[train_dl.dataset[idx][1], influence]
                       for idx, influence in enumerate(influences)], fh)
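
# The per-sample scores written above follow the usual influence-function
# estimate (up to sign conventions inside calc_influence):
#
#     I(z_train, z_test) ~= -grad_theta L(z_train) . H^{-1} grad_theta L(z_test)
#
# with H^{-1} grad_theta L(z_test) supplied by the precomputed test_hvps. The
# JSON file pairs each training sample's identifier (train_dl.dataset[idx][1])
# with its influence summed over the test batches.

# Hypothetical entry point (an assumption: `args` must already be parsed at
# module scope for MyRabbit):
if __name__ == '__main__':
    main()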