def create_pn_dataloaders(x: Tensor, y: Tensor, bs: int) \
        -> Tuple[DeviceDataLoader, DeviceDataLoader]:
    r""" Creates training and validation PN \p DataLoader objects from the specified tensors

    :param x: Feature values tensor
    :param y: Labels tensor
    :param bs: \p DataLoader's batch size
    :return: Tuple of the training and validation \p DataLoader objects respectively
    """
    num_ele = x.shape[0]
    assert num_ele == y.shape[0], "Mismatch in number of elements"
    tr_size = int(round(num_ele * (1. - config.VALIDATION_SPLIT_RATIO)))

    x, y = shuffle_tensors(x, y)
    tk = sigma = -torch.ones(y.shape).float().cpu()

    tensor_ds = TensorDataset(x[:tr_size], y[:tr_size], sigma[:tr_size], tk[:tr_size])
    train_dl = DeviceDataLoader.create(dataset=tensor_ds, shuffle=True, drop_last=True, bs=bs,
                                       num_workers=NUM_WORKERS, device=TORCH_DEVICE)

    tensor_ds = TensorDataset(x[tr_size:], y[tr_size:], sigma[tr_size:], tk[tr_size:])
    valid_dl = DeviceDataLoader.create(dataset=tensor_ds, shuffle=False, drop_last=False, bs=bs,
                                       num_workers=NUM_WORKERS, device=TORCH_DEVICE)
    return train_dl, valid_dl
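
# Hedged usage sketch (not part of the original source): shows the expected call pattern,
# assuming `config.VALIDATION_SPLIT_RATIO` and `config.BATCH_SIZE` are already populated.
# The feature/label tensors below are synthetic placeholders.
def _example_create_pn_dataloaders():
    feats = torch.randn(1000, 20)                         # hypothetical feature matrix
    lbls = torch.randint(0, 2, (1000,)).float() * 2 - 1   # hypothetical +/-1 labels
    train_dl, valid_dl = create_pn_dataloaders(feats, lbls, bs=config.BATCH_SIZE)
    # Each batch also yields the placeholder sigma/top-k weights (all -1) alongside x and y
    for xs, ys, sigma, tk in train_dl:
        pass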
def _flatten_cifar(tensor_path: Path, dest_dir: Path, device: torch.device):
    r""" Flattens CIFAR into preprocessed vectors """
    # `body` is the base layers of the specified model
    body = fastai.vision.create_body(MODEL)
    body.add_module("Flatten", ViewTo1D())
    body.eval()
    body.to(device)

    tensor_ds = TensorDataset(*torch.load(tensor_path))
    with torch.no_grad():
        dl = DeviceDataLoader.create(tensor_ds, bs=config.BATCH_SIZE, num_workers=0,
                                     shuffle=False, drop_last=False, device=device)
        flat_x, flat_y = [], []
        for xs, ys in dl:
            flat_x.append(body.forward(xs).cpu())
            flat_y.append(ys.cpu())

    # Concatenate all objects
    dest_dir.mkdir(exist_ok=True, parents=True)
    # Path to write the processed tensor
    dest_path = dest_dir / tensor_path.name
    flat_x, flat_y = torch.cat(flat_x, dim=0).cpu(), torch.cat(flat_y, dim=0).cpu()
    torch.save((flat_x, flat_y), dest_path)
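
# Hedged usage sketch (not part of the original source): runs a saved (x, y) tensor pair
# through the frozen backbone and reloads the flattened result. Paths are hypothetical.
def _example_flatten_cifar():
    src = Path("tensors/cifar10_train.pt")    # hypothetical serialized (x, y) tuple
    dest = Path("tensors/flattened")
    _flatten_cifar(src, dest, device=torch.device("cpu"))
    flat_x, flat_y = torch.load(dest / src.name)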
def calculate_results(tg: TensorGroup, our_learner, puc_learners: List[PUcLearner],
                      dest_dir: Optional[Union[Path, str]] = None,
                      exclude_puc: bool = False) -> dict:
    r""" Calculates and writes to disk the model's results

    :param tg: Tensor group containing the test conditions
    :param our_learner: PURR, PU2aPNU, or PU2wUU learner
    :param puc_learners: Learner(s) implementing the PUc algorithm
    :param dest_dir: Location to write the results
    :param exclude_puc: DEBUG ONLY. Exclude the PUc results.
    :return: Dictionary containing results of all experiments
    """
    if dest_dir is None:
        dest_dir = RES_DIR
    dest_dir = Path(dest_dir)

    our_learner.eval()
    all_res = dict()

    ds_flds = (("unlabel_train", TensorDataset(tg.u_tr_x, tg.u_tr_y)),
               ("tr_test", TensorDataset(tg.test_x_tr, tg.test_y_tr)),
               ("unlabel_test", TensorDataset(tg.u_te_x, tg.u_te_y)),
               ("test", TensorDataset(tg.test_x, tg.test_y)))

    for block_name, block in our_learner.blocks():
        res = LearnerResults()
        res.loss_name = block.loss.name()
        res.valid_loss = block.best_loss

        for ds_name, ds in ds_flds:
            # noinspection PyTypeChecker
            dl = DeviceDataLoader.create(ds, shuffle=False, drop_last=False,
                                         bs=config.BATCH_SIZE, num_workers=0,
                                         device=TORCH_DEVICE)
            all_y, dec_scores = [], []
            with torch.no_grad():
                for xs, ys in dl:
                    all_y.append(ys)
                    dec_scores.append(block.forward(xs))

            # Iterator transforms the label so transform it back
            y = torch.cat(all_y, dim=0).squeeze().cpu().numpy()
            dec_scores = torch.cat(dec_scores, dim=0).squeeze().cpu()
            y_hat, dec_scores = dec_scores.sign().cpu().numpy(), dec_scores.cpu().numpy()

            # Store the results under the dataset's name, e.g., "unlabel_train" or "test"
            res.__setattr__(ds_name, _single_ds_results(block, ds_name, y, y_hat, dec_scores))

        if config.DATASET.is_synthetic():
            log_decision_boundary(block.module, name=block_name)
        all_res[block_name] = res

    if not exclude_puc:
        for puc in puc_learners:
            all_res[puc.name()] = _build_puc_results(puc, ds_flds)
            if config.DATASET.is_synthetic():
                log_decision_boundary(puc, name=puc.name())

    config.print_configuration()
    _write_results_to_disk(dest_dir, our_learner.train_start_time(), all_res)
    return all_res
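
# Hedged usage sketch (not part of the original source): `tg`, `learner`, and `puc_learners`
# are hypothetical pre-trained objects.
# all_res = calculate_results(tg, learner, puc_learners, dest_dir="results/")
# print(all_res.keys())  # one LearnerResults entry per learner block plus one per PUc learner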
# Assumes `contextmanager` is in scope (from contextlib import contextmanager); without the
# decorator, `with ...temporary_validation_set(...)` would fail on the bare generator.
@contextmanager
def temporary_validation_set(self, x_dataset, pause_train=True):
    """Override the validation set for the duration of this context.

    This is usually done to perform evaluation of a specific set of data.
    If pause_train is true, the model is also set to eval for the duration
    (and restored to its previous setting at the end)."""
    stashed_train_state = self.model.training
    stashed_validation_loader = self.data.valid_dl
    self.data.valid_dl = DeviceDataLoader(x_dataset.as_loader(), self.data.device)
    if pause_train:
        self.model.eval()
    try:
        yield True
    finally:
        # Restore the stashed loader and train/eval mode even if evaluation raises
        self.data.valid_dl = stashed_validation_loader
        if pause_train:
            self.model.train(stashed_train_state)
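
# Hedged usage sketch (not part of the original source): `learner` and `eval_ds` are
# hypothetical, with `eval_ds` assumed to expose the `as_loader()` method required above.
# with learner.temporary_validation_set(eval_ds):
#     loss = learner.validate()   # runs against eval_ds instead of the stashed valid_dl
# # On exit, the original valid_dl and the train/eval mode are restored automatically.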
def set_sigmas(self, sigma_module: nn.Module, config, device: torch.device) -> None:
    r""" Sets all weights in the \p TensorGroup

    :param sigma_module: Module that represents :math:`\sigma(x) = \Pr[Y = +1 | x]`
                         in the aPU paper.
    :param config: Configuration settings for the learner
    :param device: Device where to run the weight calculations
    """
    sigma_module.eval()

    # Iterate through each X vector and build the weights
    all_priors = (1., config.TRAIN_PRIOR, config.TEST_PRIOR)
    for ds_name, prior in zip(("p", "u_tr", "u_te"), all_priors):
        x = self.__getattribute__(f"{ds_name}_x")
        assert x is not None, f"{ds_name}_x cannot be None"

        dl = DeviceDataLoader.create(dataset=TensorDataset(x), shuffle=False, drop_last=False,
                                     bs=config.BATCH_SIZE, num_workers=0, device=device)
        all_sigma = []
        with torch.no_grad():
            for xs, in dl:
                sig_vals = sigma_module.calc_cal_weights(xs)
                all_sigma.append(sig_vals)
        w = torch.cat(all_sigma).detach().cpu()
        tk = self._calc_topk(sigma=w, prior=prior)

        for _w, suffix in [(w, "sigma"), (tk, "tk")]:
            if len(_w.shape) > 1:
                _w = _w.squeeze(dim=1)
            attr_name = f"{ds_name}_{suffix}"

            # Sanity check the weights information
            assert float(_w.min().item()) >= 0, "Min sigma must be greater than or equal to 0"
            assert float(_w.max().item()) <= 1, "Max sigma must be less than or equal to 1"
            assert _w.numel() == x.shape[0], "Num. of weights does not match num. of elements"
            assert len(_w.shape) == 1, f"Strange size for {attr_name} vector"
            assert self.__getattribute__(attr_name) is None, f"{attr_name} is not None"
            self.__setattr__(attr_name, _w)
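
# Hedged usage sketch (not part of the original source): `tg` is a populated TensorGroup and
# `sigma_net` any nn.Module exposing `calc_cal_weights`; both names are hypothetical.
# tg.set_sigmas(sigma_net, config, device=TORCH_DEVICE)
# assert tg.p_sigma is not None and tg.u_tr_tk is not None  # weights are now attached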
def create_apu_dataloaders(ts_grp: TensorGroup, bs: int, inc_cal: bool = False) \
        -> Tuple[DeviceDataLoader, DeviceDataLoader, Optional[DeviceDataLoader]]:
    r""" Creates the training and validation dataloaders

    :param ts_grp: Stores the raw tensor information
    :param bs: \p DataLoader's batch size
    :param inc_cal: If \p True, generate a calibration \p DataLoader
    :return: Training, validation, and optionally calibration \p DataLoader objects respectively
    """
    assert not inc_cal or config.CALIBRATION_SPLIT_RATIO is not None, "Calibration mismatch"
    all_tr, all_val, all_cal = [], [], []

    # Split the tensors into train/validation
    flds = (("p", Labels.Training.POS),
            ("u_tr", Labels.Training.U_TRAIN),
            ("u_te", Labels.Training.U_TEST))
    for ds_name, lbl in flds:
        x = ts_grp.__getattribute__(f"{ds_name}_x")
        sigma = ts_grp.__getattribute__(f"{ds_name}_sigma")
        if sigma is None:
            sigma = -torch.ones(x.shape[:1], dtype=torch.float, device=TORCH_DEVICE)

        spl_tr, spl_val, spl_cal = _split_tensor(x, lbl.value, sigma, inc_cal)
        all_tr.append(spl_tr)
        all_val.append(spl_val)
        assert (inc_cal and spl_cal is not None) or (not inc_cal and spl_cal is None), \
            "Cal invalid"
        if inc_cal:
            all_cal.append(spl_cal)

    # Construct the individual dataloaders
    dls = []
    flds = ((all_tr, True), (all_val, False), (all_cal, True))
    for spl_info, shuffle in flds:
        if not spl_info:
            dls.append(None)
            continue
        x = torch.cat([info.x for info in spl_info], dim=0).cpu()
        y = torch.cat([info.y for info in spl_info], dim=0).cpu()
        sigma = torch.cat([info.sigma for info in spl_info], dim=0).cpu()
        dl = DeviceDataLoader.create(dataset=TensorDataset(x, y, sigma), shuffle=shuffle,
                                     drop_last=shuffle, bs=bs, num_workers=NUM_WORKERS,
                                     device=TORCH_DEVICE)
        dls.append(dl)

    # train, validation, and calibration dataloaders respectively
    # noinspection PyTypeChecker
    return tuple(dls)
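
# Hedged usage sketch (not part of the original source): `ts_grp` is a populated TensorGroup.
# With inc_cal=False the third element of the returned tuple is None.
# train_dl, valid_dl, cal_dl = create_apu_dataloaders(ts_grp, bs=config.BATCH_SIZE, inc_cal=True)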
def create_apu_dataloaders(tg: TensorGroup, bs: int) -> Tuple[DeviceDataLoader, DeviceDataLoader]:
    r""" Creates the training and validation dataloaders

    :param tg: Stores the raw tensor information
    :param bs: \p DataLoader's batch size
    :return: Training and validation \p DataLoader objects respectively
    """
    all_tr, all_val = [], []

    # Split the tensors into train/validation
    flds = (("p", Labels.Training.POS),
            ("u_tr", Labels.Training.U_TRAIN),
            ("u_te", Labels.Training.U_TEST))
    for ds_name, lbl in flds:
        x = tg.__getattribute__(f"{ds_name}_x")
        sigma = tg.__getattribute__(f"{ds_name}_sigma")
        if sigma is None:
            sigma = -torch.ones(x.shape[:1], dtype=torch.float, device=TORCH_DEVICE)
        tk = tg.__getattribute__(f"{ds_name}_tk")
        if tk is None:
            tk = -torch.ones(x.shape[:1], dtype=torch.float, device=TORCH_DEVICE)

        spl_tr, spl_val = _split_tensor(x, lbl.value, sigma, tk)
        all_tr.append(spl_tr)
        all_val.append(spl_val)

    # Construct the individual dataloaders
    dls = []
    flds = ((all_tr, True), (all_val, False))
    for spl_info, shuffle in flds:
        if not spl_info:
            dls.append(None)
            continue
        x = torch.cat([info.x for info in spl_info], dim=0).cpu()
        y = torch.cat([info.y for info in spl_info], dim=0).cpu()
        sigma = torch.cat([info.sigma for info in spl_info], dim=0).cpu()
        tk = torch.cat([info.tk for info in spl_info], dim=0).cpu()
        dl = DeviceDataLoader.create(dataset=TensorDataset(x, y, sigma, tk), shuffle=shuffle,
                                     drop_last=shuffle, bs=bs, num_workers=NUM_WORKERS,
                                     device=TORCH_DEVICE)
        dls.append(dl)

    # train and validation dataloaders respectively
    # noinspection PyTypeChecker
    return tuple(dls)
def create_pu_dataloaders(p_x: Tensor, u_x: Tensor, bs: int) \
        -> Tuple[DeviceDataLoader, DeviceDataLoader]:
    r""" Simple method that splits the positive and unlabeled sets into stratified training
    and validation \p DataLoader objects

    :param p_x: Feature vectors for the positive (labeled) examples
    :param u_x: Feature vectors for the unlabeled examples
    :param bs: \p DataLoader's batch size
    :return: Training and validation \p DataLoader objects respectively
    """
    tr_x, tr_y, val_x, val_y = [], [], [], []
    for x, lbl in ((p_x, Labels.Training.POS), (u_x, Labels.Training.U_TRAIN)):
        num_ex = x.shape[0]
        tr_size = int((1 - config.VALIDATION_SPLIT_RATIO) * num_ex)

        x = shuffle_tensors(x)
        tr_x.append(x[:tr_size])
        tr_y.append(torch.full([tr_size], lbl.value, dtype=torch.int))
        val_x.append(x[tr_size:])
        val_y.append(torch.full([num_ex - tr_size], lbl.value, dtype=torch.int))

    def _cat_tensors(lst_tensors: List[Tensor]) -> Tensor:
        return torch.cat(lst_tensors, dim=0).cpu()

    tr_x, tr_y = shuffle_tensors(_cat_tensors(tr_x), _cat_tensors(tr_y))
    val_x, val_y = shuffle_tensors(_cat_tensors(val_x), _cat_tensors(val_y))

    # Create the training and validation dataloaders respectively
    dls = []
    for x, y, shuffle in ((tr_x, tr_y, True), (val_x, val_y, False)):
        tk = w = -torch.ones_like(y).float().cpu()
        dls.append(DeviceDataLoader.create(TensorDataset(x.cpu(), y.cpu(), w.cpu(), tk.cpu()),
                                           shuffle=shuffle,
                                           drop_last=shuffle and not config.DATASET.is_synthetic(),
                                           bs=bs, num_workers=NUM_WORKERS, device=TORCH_DEVICE))
    # noinspection PyTypeChecker
    return tuple(dls)
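
# Hedged usage sketch (not part of the original source); the positive/unlabeled tensors are
# synthetic placeholders.
def _example_create_pu_dataloaders():
    p_x = torch.randn(200, 20)    # hypothetical labeled-positive features
    u_x = torch.randn(800, 20)    # hypothetical unlabeled features
    train_dl, valid_dl = create_pu_dataloaders(p_x, u_x, bs=config.BATCH_SIZE)
    # Batches are (x, y, w, tk) with w and tk set to -1 placeholders
    for xs, ys, w, tk in train_dl:
        pass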
def construct_loader(ds: Union[TensorDataset, TextDataset], bs: int, shuffle: bool = True,
                     drop_last: bool = False) -> Union[DeviceDataLoader, Iterator]:
    r""" Construct an \p Iterator which emulates a \p DataLoader """
    if isinstance(ds, TextDataset):
        return Iterator(dataset=ds, batch_size=bs, shuffle=shuffle, device=TORCH_DEVICE)

    dl = DataLoader(dataset=ds, batch_size=bs, shuffle=shuffle, drop_last=drop_last,
                    num_workers=NUM_WORKERS, pin_memory=False)
    # noinspection PyArgumentList
    return DeviceDataLoader(dl=dl, device=TORCH_DEVICE)
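
# Hedged usage sketch (not part of the original source): construct_loader dispatches on the
# dataset type, so a TensorDataset yields a DeviceDataLoader while a TextDataset yields a
# torchtext-style Iterator. The tensors below are synthetic placeholders.
def _example_construct_loader():
    ds = TensorDataset(torch.randn(64, 10), torch.ones(64))
    dl = construct_loader(ds, bs=16, shuffle=True, drop_last=False)
    for xs, ys in dl:   # batches arrive already moved to TORCH_DEVICE
        pass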
set_seed(100)
train_ds = MoleculeDataset(train_mol_ids, gb_mol_sc, gb_mol_atom, gb_mol_bond, gb_mol_struct,
                           gb_mol_angle_in, gb_mol_angle_out, gb_mol_graph_dist)
val_ds = MoleculeDataset(val_mol_ids, gb_mol_sc, gb_mol_atom, gb_mol_bond, gb_mol_struct,
                         gb_mol_angle_in, gb_mol_angle_out, gb_mol_graph_dist)
test_ds = MoleculeDataset(test_mol_ids, test_gb_mol_sc, gb_mol_atom, gb_mol_bond, gb_mol_struct,
                          gb_mol_angle_in, gb_mol_angle_out, gb_mol_graph_dist)
train_dl = DataLoader(train_ds, args.batch_size, shuffle=True, num_workers=8)
val_dl = DataLoader(val_ds, args.batch_size, num_workers=8)
test_dl = DeviceDataLoader.create(test_ds, args.batch_size, num_workers=8,
                                  collate_fn=partial(collate_parallel_fn, test=True))
db = DataBunch(train_dl, val_dl, collate_fn=collate_parallel_fn)
db.test_dl = test_dl

# set up model
set_seed(100)
d_model = args.d_model
enn_args = dict(layers=3 * [d_model], dropout=3 * [0.0], layer_norm=True)
ann_args = dict(layers=1 * [d_model], dropout=1 * [0.0], layer_norm=True, out_act=nn.Tanh())
model = Transformer(C.N_ATOM_FEATURES,
def main():
    global model_to_save
    global experiment
    global rabbit
    rabbit = MyRabbit(args)
    if rabbit.model_params.dont_limit_num_uniq_tokens:
        raise NotImplementedError()
    if rabbit.model_params.frame_as_qa:
        raise NotImplementedError
    if rabbit.run_params.drop_val_loss_calc:
        raise NotImplementedError
    if rabbit.run_params.use_softrank_influence and not rabbit.run_params.freeze_all_but_last_for_influence:
        raise NotImplementedError
    if rabbit.train_params.weight_influence:
        raise NotImplementedError
    experiment = Experiment(rabbit.train_params + rabbit.model_params + rabbit.run_params)
    print('Model name:', experiment.model_name)
    use_pretrained_doc_encoder = rabbit.model_params.use_pretrained_doc_encoder
    use_pointwise_loss = rabbit.train_params.use_pointwise_loss
    query_token_embed_len = rabbit.model_params.query_token_embed_len
    document_token_embed_len = rabbit.model_params.document_token_embed_len
    _names = []
    if not rabbit.model_params.dont_include_titles:
        _names.append('with_titles')
    if rabbit.train_params.num_doc_tokens_to_consider != -1:
        _names.append('num_doc_toks_' + str(rabbit.train_params.num_doc_tokens_to_consider))
    if not rabbit.run_params.just_caches:
        if rabbit.model_params.dont_include_titles:
            document_lookup = read_cache(name('./doc_lookup.json', _names),
                                         get_robust_documents)
        else:
            document_lookup = read_cache(name('./doc_lookup.json', _names),
                                         get_robust_documents_with_titles)
    num_doc_tokens_to_consider = rabbit.train_params.num_doc_tokens_to_consider
    document_title_to_id = read_cache(
        './document_title_to_id.json',
        lambda: create_id_lookup(document_lookup.keys()))
    with open('./caches/106756_most_common_doc.json', 'r') as fh:
        doc_token_set = set(json.load(fh))
    tokenizer = Tokenizer()
    tokenized = set(
        sum(tokenizer.process_all(list(get_robust_eval_queries().values())), []))
    doc_token_set = doc_token_set.union(tokenized)
    use_bow_model = not any([
        rabbit.model_params[attr]
        for attr in ['use_doc_out', 'use_cnn', 'use_lstm', 'use_pretrained_doc_encoder']
    ])
    use_bow_model = use_bow_model and not rabbit.model_params.dont_use_bow
    if use_bow_model:
        documents, document_token_lookup = read_cache(
            name(f'./docs_fs_tokens_limit_uniq_toks_qrels_and_106756.pkl', _names),
            lambda: prepare_fs(document_lookup, document_title_to_id,
                               num_tokens=num_doc_tokens_to_consider,
                               token_set=doc_token_set))
        if rabbit.model_params.keep_top_uniq_terms is not None:
            documents = [
                dict(nlargest(rabbit.model_params.keep_top_uniq_terms,
                              _.to_pairs(doc), itemgetter(1)))
                for doc in documents
            ]
    else:
        documents, document_token_lookup = read_cache(
            name(f'./parsed_docs_{num_doc_tokens_to_consider}_tokens_limit_uniq_toks_qrels_and_106756.json',
                 _names),
            lambda: prepare(document_lookup, document_title_to_id,
                            num_tokens=num_doc_tokens_to_consider,
                            token_set=doc_token_set))
    if not rabbit.run_params.just_caches:
        train_query_lookup = read_cache('./robust_train_queries.json',
                                        get_robust_train_queries)
        train_query_name_to_id = read_cache(
            './train_query_name_to_id.json',
            lambda: create_id_lookup(train_query_lookup.keys()))
        train_queries, query_token_lookup = read_cache(
            './parsed_robust_queries_dict.json',
            lambda: prepare(train_query_lookup, train_query_name_to_id,
                            token_lookup=document_token_lookup,
                            token_set=doc_token_set, drop_if_any_unk=True))
    query_tok_to_doc_tok = {
        idx: document_token_lookup.get(query_token) or document_token_lookup['<unk>']
        for query_token, idx in query_token_lookup.items()
    }
    names = [RANKER_NAME_TO_SUFFIX[rabbit.train_params.ranking_set]]
    if rabbit.train_params.use_pointwise_loss or not rabbit.run_params.just_caches:
        train_data = read_cache(
            name('./robust_train_query_results_tokens_qrels_and_106756.json', names),
            lambda: read_query_result(
                train_query_name_to_id, document_title_to_id, train_queries,
                path='./indri/query_result' + RANKER_NAME_TO_SUFFIX[rabbit.train_params.ranking_set]))
    else:
        train_data = []
    q_embed_len = rabbit.model_params.query_token_embed_len
    doc_embed_len = rabbit.model_params.document_token_embed_len
    if rabbit.model_params.append_difference or rabbit.model_params.append_hadamard:
        assert q_embed_len == doc_embed_len, \
            'Must use same size doc and query embeds when appending diff or hadamard'
    if q_embed_len == doc_embed_len:
        glove_lookup = get_glove_lookup(
            embedding_dim=q_embed_len,
            use_large_embed=rabbit.model_params.use_large_embed,
            use_word2vec=rabbit.model_params.use_word2vec)
        q_glove_lookup = glove_lookup
        doc_glove_lookup = glove_lookup
    else:
        q_glove_lookup = get_glove_lookup(
            embedding_dim=q_embed_len,
            use_large_embed=rabbit.model_params.use_large_embed,
            use_word2vec=rabbit.model_params.use_word2vec)
        doc_glove_lookup = get_glove_lookup(
            embedding_dim=doc_embed_len,
            use_large_embed=rabbit.model_params.use_large_embed,
            use_word2vec=rabbit.model_params.use_word2vec)
    num_query_tokens = len(query_token_lookup)
    num_doc_tokens = len(document_token_lookup)
    doc_encoder = None
    if use_pretrained_doc_encoder or rabbit.model_params.use_doc_out:
        doc_encoder, document_token_embeds = get_doc_encoder_and_embeddings(
            document_token_lookup, rabbit.model_params.only_use_last_out)
        if rabbit.model_params.use_glove:
            query_token_embeds_init = init_embedding(q_glove_lookup, query_token_lookup,
                                                     num_query_tokens, query_token_embed_len)
        else:
            query_token_embeds_init = from_doc_to_query_embeds(
                document_token_embeds, document_token_lookup, query_token_lookup)
        if not rabbit.train_params.dont_freeze_pretrained_doc_encoder:
            dont_update(doc_encoder)
        if rabbit.model_params.use_doc_out:
            doc_encoder = None
    else:
        document_token_embeds = init_embedding(doc_glove_lookup, document_token_lookup,
                                               num_doc_tokens, document_token_embed_len)
        if rabbit.model_params.use_single_word_embed_set:
            query_token_embeds_init = document_token_embeds
        else:
            query_token_embeds_init = init_embedding(q_glove_lookup, query_token_lookup,
                                                     num_query_tokens, query_token_embed_len)
    if not rabbit.train_params.dont_freeze_word_embeds:
        dont_update(document_token_embeds)
        dont_update(query_token_embeds_init)
    else:
        do_update(document_token_embeds)
        do_update(query_token_embeds_init)
    if rabbit.train_params.add_rel_score:
        query_token_embeds, additive = get_additive_regularized_embeds(query_token_embeds_init)
        rel_score = RelScore(query_token_embeds, document_token_embeds,
                             rabbit.model_params, rabbit.train_params)
    else:
        query_token_embeds = query_token_embeds_init
        additive = None
        rel_score = None
    eval_query_lookup = get_robust_eval_queries()
    eval_query_name_document_title_rels = get_robust_rels()
    test_query_names = []
    val_query_names = []
    for query_name in eval_query_lookup:
        if len(val_query_names) >= 50:
            test_query_names.append(query_name)
        else:
            val_query_names.append(query_name)
    test_query_name_document_title_rels = _.pick(eval_query_name_document_title_rels,
                                                 test_query_names)
    test_query_lookup = _.pick(eval_query_lookup, test_query_names)
    test_query_name_to_id = create_id_lookup(test_query_lookup.keys())
    test_queries, __ = prepare(test_query_lookup, test_query_name_to_id,
                               token_lookup=query_token_lookup)
    eval_ranking_candidates = read_query_test_rankings(
        './indri/query_result_test' + RANKER_NAME_TO_SUFFIX[rabbit.train_params.ranking_set])
    test_candidates_data = read_query_result(
        test_query_name_to_id, document_title_to_id,
        dict(zip(range(len(test_queries)), test_queries)),
        path='./indri/query_result_test' + RANKER_NAME_TO_SUFFIX[rabbit.train_params.ranking_set])
    test_ranking_candidates = process_raw_candidates(test_query_name_to_id, test_queries,
                                                     document_title_to_id, test_query_names,
                                                     eval_ranking_candidates)
    test_data = process_rels(test_query_name_document_title_rels, document_title_to_id,
                             test_query_name_to_id, test_queries)
    val_query_name_document_title_rels = _.pick(eval_query_name_document_title_rels,
                                                val_query_names)
    val_query_lookup = _.pick(eval_query_lookup, val_query_names)
    val_query_name_to_id = create_id_lookup(val_query_lookup.keys())
    val_queries, __ = prepare(val_query_lookup, val_query_name_to_id,
                              token_lookup=query_token_lookup)
    val_candidates_data = read_query_result(
        val_query_name_to_id, document_title_to_id,
        dict(zip(range(len(val_queries)), val_queries)),
        path='./indri/query_result_test' + RANKER_NAME_TO_SUFFIX[rabbit.train_params.ranking_set])
    val_ranking_candidates = process_raw_candidates(val_query_name_to_id, val_queries,
                                                    document_title_to_id, val_query_names,
                                                    eval_ranking_candidates)
    val_data = process_rels(val_query_name_document_title_rels, document_title_to_id,
                            val_query_name_to_id, val_queries)
    train_normalized_score_lookup = read_cache(
        name('./train_normalized_score_lookup.pkl', names),
        lambda: get_normalized_score_lookup(train_data))
    test_normalized_score_lookup = get_normalized_score_lookup(test_candidates_data)
    val_normalized_score_lookup = get_normalized_score_lookup(val_candidates_data)
    if use_pointwise_loss:
        normalized_train_data = read_cache(
            name('./normalized_train_query_data_qrels_and_106756.json', names),
            lambda: normalize_scores_query_wise(train_data))
        collate_fn = lambda samples: collate_query_samples(
            samples, use_bow_model=use_bow_model, use_dense=rabbit.model_params.use_dense)
        train_dl = build_query_dataloader(
            documents, normalized_train_data, rabbit.train_params, rabbit.model_params,
            cache=name('train_ranking_qrels_and_106756.json', names), limit=10,
            query_tok_to_doc_tok=query_tok_to_doc_tok,
            normalized_score_lookup=train_normalized_score_lookup,
            use_bow_model=use_bow_model, collate_fn=collate_fn, is_test=False)
        test_dl = build_query_dataloader(
            documents, test_data, rabbit.train_params, rabbit.model_params,
            query_tok_to_doc_tok=query_tok_to_doc_tok,
            normalized_score_lookup=test_normalized_score_lookup,
            use_bow_model=use_bow_model, collate_fn=collate_fn, is_test=True)
        val_dl = build_query_dataloader(
            documents, val_data, rabbit.train_params, rabbit.model_params,
            query_tok_to_doc_tok=query_tok_to_doc_tok,
            normalized_score_lookup=val_normalized_score_lookup,
            use_bow_model=use_bow_model, collate_fn=collate_fn, is_test=True)
        model = PointwiseScorer(query_token_embeds, document_token_embeds, doc_encoder,
                                rabbit.model_params, rabbit.train_params)
    else:
        if rabbit.train_params.use_noise_aware_loss:
            ranker_query_str_to_rankings = get_ranker_query_str_to_rankings(
                train_query_name_to_id, document_title_to_id, train_queries,
                limit=rabbit.train_params.num_snorkel_train_queries)
            query_names = reduce(
                lambda acc, query_to_ranking: acc.intersection(set(query_to_ranking.keys()))
                if len(acc) != 0 else set(query_to_ranking.keys()),
                ranker_query_str_to_rankings.values(), set())
            all_ranked_lists_by_ranker = _.map_values(
                ranker_query_str_to_rankings,
                lambda query_to_ranking: [query_to_ranking[query] for query in query_names])
            ranker_query_str_to_pairwise_bins = get_ranker_query_str_to_pairwise_bins(
                train_query_name_to_id, document_title_to_id, train_queries,
                limit=rabbit.train_params.num_train_queries)
            snorkeller = Snorkeller(ranker_query_str_to_pairwise_bins)
            snorkeller.train(all_ranked_lists_by_ranker)
            calc_marginals = snorkeller.calc_marginals
        else:
            calc_marginals = None
        collate_fn = lambda samples: collate_query_pairwise_samples(
            samples, use_bow_model=use_bow_model, calc_marginals=calc_marginals,
            use_dense=rabbit.model_params.use_dense)
        if rabbit.run_params.load_influences:
            try:
                with open(rabbit.run_params.influences_path) as fh:
                    pairs_to_flip = defaultdict(set)
                    for pair, influence in json.load(fh):
                        if rabbit.train_params.use_pointwise_loss:
                            condition = True
                        else:
                            condition = influence < rabbit.train_params.influence_thresh
                        if condition:
                            query = tuple(pair[1])
                            pairs_to_flip[query].add(tuple(pair[0]))
            except FileNotFoundError:
                pairs_to_flip = None
        else:
            pairs_to_flip = None
        train_dl = build_query_pairwise_dataloader(
            documents, train_data, rabbit.train_params, rabbit.model_params,
            pairs_to_flip=pairs_to_flip,
            cache=name('train_ranking_qrels_and_106756.json', names), limit=10,
            query_tok_to_doc_tok=query_tok_to_doc_tok,
            normalized_score_lookup=train_normalized_score_lookup,
            use_bow_model=use_bow_model, collate_fn=collate_fn, is_test=False)
        test_dl = build_query_pairwise_dataloader(
            documents, test_data, rabbit.train_params, rabbit.model_params,
            query_tok_to_doc_tok=query_tok_to_doc_tok,
            normalized_score_lookup=test_normalized_score_lookup,
            use_bow_model=use_bow_model, collate_fn=collate_fn, is_test=True)
        val_dl = build_query_pairwise_dataloader(
            documents, val_data, rabbit.train_params, rabbit.model_params,
            query_tok_to_doc_tok=query_tok_to_doc_tok,
            normalized_score_lookup=val_normalized_score_lookup,
            use_bow_model=use_bow_model, collate_fn=collate_fn, is_test=True)
        val_rel_dl = build_query_pairwise_dataloader(
            documents, val_data, rabbit.train_params, rabbit.model_params,
            query_tok_to_doc_tok=query_tok_to_doc_tok,
            normalized_score_lookup=val_normalized_score_lookup,
            use_bow_model=use_bow_model, collate_fn=collate_fn, is_test=True,
            rel_vs_irrel=True, candidates=val_ranking_candidates,
            num_to_rank=rabbit.run_params.num_to_rank)
        model = PairwiseScorer(query_token_embeds, document_token_embeds, doc_encoder,
                               rabbit.model_params, rabbit.train_params,
                               use_bow_model=use_bow_model)
    train_ranking_dataset = RankingDataset(
        documents, train_dl.dataset.rankings, rabbit.train_params, rabbit.model_params,
        rabbit.run_params, query_tok_to_doc_tok=query_tok_to_doc_tok,
        normalized_score_lookup=train_normalized_score_lookup,
        use_bow_model=use_bow_model, use_dense=rabbit.model_params.use_dense)
    test_ranking_dataset = RankingDataset(
        documents, test_ranking_candidates, rabbit.train_params, rabbit.model_params,
        rabbit.run_params, relevant=test_dl.dataset.rankings,
        query_tok_to_doc_tok=query_tok_to_doc_tok, cheat=rabbit.run_params.cheat,
        normalized_score_lookup=test_normalized_score_lookup,
        use_bow_model=use_bow_model, use_dense=rabbit.model_params.use_dense)
    val_ranking_dataset = RankingDataset(
        documents, val_ranking_candidates, rabbit.train_params, rabbit.model_params,
        rabbit.run_params, relevant=val_dl.dataset.rankings,
        query_tok_to_doc_tok=query_tok_to_doc_tok, cheat=rabbit.run_params.cheat,
        normalized_score_lookup=val_normalized_score_lookup,
        use_bow_model=use_bow_model, use_dense=rabbit.model_params.use_dense)
    if rabbit.train_params.memorize_test:
        train_dl = test_dl
        train_ranking_dataset = test_ranking_dataset
    model_data = DataBunch(
        train_dl, val_rel_dl, test_dl, collate_fn=collate_fn,
        device=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))
    multi_objective_model = MultiObjective(model, rabbit.train_params, rel_score, additive)
    model_to_save = multi_objective_model
    if rabbit.train_params.memorize_test:
        try:
            del train_data
        except:
            pass
    if not rabbit.run_params.just_caches:
        del document_lookup
        del train_query_lookup
        del query_token_lookup
        del document_token_lookup
        del train_queries
    try:
        del glove_lookup
    except UnboundLocalError:
        del q_glove_lookup
        del doc_glove_lookup
    if rabbit.run_params.load_model:
        try:
            multi_objective_model.load_state_dict(torch.load(rabbit.run_params.load_path))
        except RuntimeError:
            dp = nn.DataParallel(multi_objective_model)
            dp.load_state_dict(torch.load(rabbit.run_params.load_path))
            multi_objective_model = dp.module
    else:
        train_model(multi_objective_model, model_data, train_ranking_dataset,
                    val_ranking_dataset, test_ranking_dataset, rabbit.train_params,
                    rabbit.model_params, rabbit.run_params, experiment)
        if rabbit.train_params.fine_tune_on_val:
            fine_tune_model_data = DataBunch(
                val_rel_dl, val_rel_dl, test_dl, collate_fn=collate_fn,
                device=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))
            train_model(multi_objective_model, fine_tune_model_data, val_ranking_dataset,
                        val_ranking_dataset, test_ranking_dataset, rabbit.train_params,
                        rabbit.model_params, rabbit.run_params, experiment,
                        load_path=rabbit.run_params.load_path)
    multi_objective_model.eval()
    device = model_data.device
    gpu_multi_objective_model = multi_objective_model.to(device)
    if rabbit.run_params.calc_influence:
        if rabbit.run_params.freeze_all_but_last_for_influence:
            last_layer_idx = _.find_last_index(
                multi_objective_model.model.pointwise_scorer.layers,
                lambda layer: isinstance(layer, nn.Linear))
            to_last_layer = lambda x: gpu_multi_objective_model(*x, to_idx=last_layer_idx)
            last_layer = gpu_multi_objective_model.model.pointwise_scorer.layers[last_layer_idx]
            diff_wrt = [p for p in last_layer.parameters() if p.requires_grad]
        else:
            diff_wrt = None
        test_hvps = calc_test_hvps(
            multi_objective_model.loss, gpu_multi_objective_model,
            DeviceDataLoader(train_dl, device, collate_fn=collate_fn), val_rel_dl,
            rabbit.run_params, diff_wrt=diff_wrt, show_progress=True,
            use_softrank_influence=rabbit.run_params.use_softrank_influence)
        influences = []
        if rabbit.train_params.use_pointwise_loss:
            num_real_samples = len(train_dl.dataset)
        else:
            num_real_samples = train_dl.dataset._num_pos_pairs
        if rabbit.run_params.freeze_all_but_last_for_influence:
            _sampler = SequentialSamplerWithLimit(train_dl.dataset, num_real_samples)
            _batch_sampler = BatchSampler(_sampler, rabbit.train_params.batch_size, False)
            _dl = DataLoader(train_dl.dataset, batch_sampler=_batch_sampler,
                             collate_fn=collate_fn)
            sequential_train_dl = DeviceDataLoader(_dl, device, collate_fn=collate_fn)
            influences = calc_dataset_influence(gpu_multi_objective_model, to_last_layer,
                                                sequential_train_dl, test_hvps,
                                                sum_p=True).tolist()
        else:
            for i in progressbar(range(num_real_samples)):
                train_sample = train_dl.dataset[i]
                x, labels = to_device(collate_fn([train_sample]), device)
                device_train_sample = (x, labels.squeeze())
                influences.append(
                    calc_influence(multi_objective_model.loss, gpu_multi_objective_model,
                                   device_train_sample, test_hvps,
                                   diff_wrt=diff_wrt).sum().tolist())
        with open(rabbit.run_params.influences_path, 'w+') as fh:
            json.dump([[train_dl.dataset[idx][1], influence]
                       for idx, influence in enumerate(influences)], fh)
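
# Hedged usage sketch (not part of the original source): the influence dump written above can
# be re-read to find the most harmful training pairs; the path below is hypothetical.
# with open('influences.json') as fh:
#     pairs = sorted(json.load(fh), key=itemgetter(1))   # most negative influence first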