Example #1
def create_pn_dataloaders(x: Tensor, y: Tensor, bs: int) \
        -> Tuple[DeviceDataLoader, DeviceDataLoader]:
    r"""
    Creates training and validation PN \p DataLoader objects from the specified tensors
    :param x: Feature values tensor
    :param y: Labels tensor
    :param bs: \p DataLoader's batch size
    :return: Tuple of the training and validation \p DataLoader objects respectively
    """
    num_ele = x.shape[0]
    assert num_ele == y.shape[0], "Mismatch in number of elements"

    tr_size = int(round(num_ele * (1. - config.VALIDATION_SPLIT_RATIO)))
    x, y = shuffle_tensors(x, y)
    # Placeholder sigma/tk columns (-1): PN data carries no aPU weighting info
    tk = sigma = -torch.ones(y.shape).float().cpu()

    tensor_ds = TensorDataset(x[:tr_size], y[:tr_size], sigma[:tr_size],
                              tk[:tr_size])
    train_dl = DeviceDataLoader.create(dataset=tensor_ds,
                                       shuffle=True,
                                       drop_last=True,
                                       bs=bs,
                                       num_workers=NUM_WORKERS,
                                       device=TORCH_DEVICE)

    tensor_ds = TensorDataset(x[tr_size:], y[tr_size:], sigma[tr_size:],
                              tk[tr_size:])
    valid_dl = DeviceDataLoader.create(dataset=tensor_ds,
                                       shuffle=False,
                                       drop_last=False,
                                       bs=bs,
                                       num_workers=NUM_WORKERS,
                                       device=TORCH_DEVICE)

    return train_dl, valid_dl
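A hedged usage sketch for the helper above; the tensor shapes, batch size, and the ±1 label convention are illustrative assumptions, and `config.VALIDATION_SPLIT_RATIO` must already be set by the surrounding project:

import torch

# Illustrative inputs: 1000 examples with 20 features, labels in {-1, +1}
x = torch.randn(1000, 20)
y = torch.where(torch.rand(1000) < 0.5, torch.tensor(-1.), torch.tensor(1.))

train_dl, valid_dl = create_pn_dataloaders(x, y, bs=64)
for xs, ys, sigma, tk in train_dl:  # batches carry the placeholder sigma/tk columns
    pass  # training step would go here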
Example #2
def _flatten_cifar(tensor_path: Path, dest_dir: Path, device: torch.device):
    r""" Flattens CIFAR into preprocessed vectors """
    # `body` is the base layers of the specified model
    body = fastai.vision.create_body(MODEL)
    body.add_module("Flatten", ViewTo1D())
    body.eval()
    body.to(device)

    # Load the serialized (x, y) tensors and wrap them in a dataset
    tensor_ds = TensorDataset(*torch.load(tensor_path))
    with torch.no_grad():
        dl = DeviceDataLoader.create(tensor_ds,
                                     bs=config.BATCH_SIZE,
                                     num_workers=0,
                                     shuffle=False,
                                     drop_last=False,
                                     device=device)
        flat_x, flat_y = [], []
        for xs, ys in dl:
            flat_x.append(body.forward(xs).cpu())
            flat_y.append(ys.cpu())

    # Concatenate the per-batch outputs and write them to the destination
    dest_dir.mkdir(exist_ok=True, parents=True)
    dest_path = dest_dir / tensor_path.name
    flat_x, flat_y = torch.cat(flat_x, dim=0), torch.cat(flat_y, dim=0)
    torch.save((flat_x, flat_y), dest_path)
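For readers without the fastai `create_body`/`DeviceDataLoader` helpers, here is the same freeze-and-flatten preprocessing pattern in plain PyTorch; the ResNet-18 backbone, file name, and tensor shapes are assumptions for illustration:

import torch
import torchvision
from torch.utils.data import DataLoader, TensorDataset

# Frozen backbone: strip the classification head and flatten the pooled features
backbone = torch.nn.Sequential(
    *list(torchvision.models.resnet18(weights=None).children())[:-1],
    torch.nn.Flatten())
backbone.eval()

xs = torch.randn(256, 3, 32, 32)   # CIFAR-shaped dummy images
ys = torch.randint(0, 10, (256,))
dl = DataLoader(TensorDataset(xs, ys), batch_size=64, shuffle=False)

flat_x, flat_y = [], []
with torch.no_grad():              # no gradients needed for preprocessing
    for bx, by in dl:
        flat_x.append(backbone(bx).cpu())
        flat_y.append(by.cpu())
torch.save((torch.cat(flat_x), torch.cat(flat_y)), "flat_cifar.pt")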
Example #3
def calculate_results(tg: TensorGroup, our_learner, puc_learners: List[PUcLearner],
                      dest_dir: Optional[Union[Path, str]] = None,
                      exclude_puc: bool = False) -> dict:
    r"""
    Calculates and writes to disk the model's results

    :param tg: Tensor group containing the test conditions
    :param our_learner: PURR, PU2aPNU, or PU2wUU
    :param puc_learners: Learner(s) implementing the PUc algorithm
    :param dest_dir: Location to write the results
    :param exclude_puc: DEBUG ONLY. Exclude the PUc results.
    :return: Dictionary containing results of all experiments
    """
    if dest_dir is None: dest_dir = RES_DIR
    dest_dir = Path(dest_dir)

    our_learner.eval()

    all_res = dict()
    ds_flds = (("unlabel_train", TensorDataset(tg.u_tr_x, tg.u_tr_y)),
               ("tr_test", TensorDataset(tg.test_x_tr, tg.test_y_tr)),
               ("unlabel_test", TensorDataset(tg.u_te_x, tg.u_te_y)),
               ("test", TensorDataset(tg.test_x, tg.test_y)))

    for block_name, block in our_learner.blocks():
        res = LearnerResults()
        res.loss_name = block.loss.name()
        res.valid_loss = block.best_loss

        for ds_name, ds in ds_flds:
            # noinspection PyTypeChecker
            dl = DeviceDataLoader.create(ds, shuffle=False, drop_last=False, bs=config.BATCH_SIZE,
                                         num_workers=0, device=TORCH_DEVICE)
            all_y, dec_scores = [], []
            with torch.no_grad():
                for xs, ys in dl:
                    all_y.append(ys)
                    dec_scores.append(block.forward(xs))

            # The data iterator transformed the labels, so convert them back to numpy
            y = torch.cat(all_y, dim=0).squeeze().cpu().numpy()
            dec_scores = torch.cat(dec_scores, dim=0).squeeze().cpu()
            y_hat, dec_scores = dec_scores.sign().numpy(), dec_scores.numpy()
            # Store the results under the dataset's name, e.g. "unlabel_train" or "test"
            res.__setattr__(ds_name, _single_ds_results(block, ds_name, y, y_hat, dec_scores))

        if config.DATASET.is_synthetic():
            log_decision_boundary(block.module, name=block_name)
        all_res[block_name] = res

    if not exclude_puc:
        for puc in puc_learners:
            all_res[puc.name()] = _build_puc_results(puc, ds_flds)
            if config.DATASET.is_synthetic():
                log_decision_boundary(puc, name=puc.name())

    config.print_configuration()
    _write_results_to_disk(dest_dir, our_learner.train_start_time(), all_res)

    return all_res
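The evaluation loop above reduces decision scores to ±1 predictions via `sign`; a toy standalone illustration of that step (the `sklearn` metric here is only a guess at what `_single_ds_results` might compute):

import numpy as np
from sklearn.metrics import accuracy_score

# Toy decision scores and true labels in the same {-1, +1} convention
dec_scores = np.array([0.7, -0.2, 1.3, -0.9])
y = np.array([1, -1, 1, 1])
y_hat = np.sign(dec_scores)      # threshold at zero, as in the loop above
print(accuracy_score(y, y_hat))  # -> 0.75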
Example #4
    @contextmanager
    def temporary_validation_set(self, x_dataset, pause_train=True):
        """Override the validation set for the duration of this context, usually
        to evaluate a specific set of data. If ``pause_train`` is true, the model
        is also set to eval mode for the duration and restored to its previous
        setting at the end."""
        stashed_train_state = self.model.training
        stashed_validation_loader = self.data.valid_dl
        self.data.valid_dl = DeviceDataLoader(x_dataset.as_loader(),
                                              self.data.device)
        if pause_train: self.model.eval()
        try:
            yield True
        finally:
            # Restore the original loader and train/eval mode even on error
            self.data.valid_dl = stashed_validation_loader
            if pause_train: self.model.train(stashed_train_state)
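Assuming the generator is wrapped with `contextlib.contextmanager` as restored above, usage would look like this (the `learner` and `extra_ds` names are hypothetical):

# Evaluate on an ad-hoc dataset without disturbing the learner's state
with learner.temporary_validation_set(extra_ds):
    metrics = learner.validate()
# The original valid_dl and train/eval mode are restored on exit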
Example #5
    def set_sigmas(self, sigma_module: nn.Module, config,
                   device: torch.device) -> None:
        r"""
        Sets all weights in the \p TensorGroup

        :param sigma_module: Module that represents :math:`\sigma(x) = \Pr[Y = +1 | x]` in the
                             aPU paper.
        :param config: Configuration settings for the learner
        :param device: Device where to run the weight calculations
        """
        sigma_module.eval()

        # Iterate through each X vector and build the weights
        all_priors = (1., config.TRAIN_PRIOR, config.TEST_PRIOR)
        for ds_name, prior in zip(("p", "u_tr", "u_te"), all_priors):
            x = self.__getattribute__(f"{ds_name}_x")
            assert x is not None, f"{ds_name}_x cannot be None"

            dl = DeviceDataLoader.create(dataset=TensorDataset(x),
                                         shuffle=False,
                                         drop_last=False,
                                         bs=config.BATCH_SIZE,
                                         num_workers=0,
                                         device=device)
            all_sigma = []
            with torch.no_grad():
                for xs, in dl:
                    sig_vals = sigma_module.calc_cal_weights(xs)
                    all_sigma.append(sig_vals)

            w = torch.cat(all_sigma).detach().cpu()
            tk = self._calc_topk(sigma=w, prior=prior)

            for _w, suffix in [(w, "sigma"), (tk, "tk")]:
                if len(_w.shape) > 1:
                    _w = _w.squeeze(dim=1)
                attr_name = f"{ds_name}_{suffix}"

                # Sanity check the weights information
                assert float(_w.min().item()) >= 0, "Min sigma must be greater than or equal to 0"
                assert float(_w.max().item()) <= 1, "Max sigma must be less than or equal to 1"
                assert _w.numel() == x.shape[0], "Num. of weights does not match num. of elements"
                assert len(_w.shape) == 1, f"Strange size for {attr_name} vector"

                assert self.__getattribute__(attr_name) is None, f"{attr_name} is not None"
                self.__setattr__(attr_name, _w)
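The `_calc_topk` helper is not shown here; given its prior argument and the [0, 1] sanity checks above, a plausible reading is a hard top-k thresholding of sigma, sketched below purely as an assumption, not the project's actual code:

import torch

def calc_topk_sketch(sigma: torch.Tensor, prior: float) -> torch.Tensor:
    # Hypothetical: flag the top prior-fraction of sigma values with 1, the rest 0
    sigma = sigma.flatten()
    k = int(round(prior * sigma.numel()))
    tk = torch.zeros_like(sigma)
    if k > 0:
        _, idx = sigma.topk(k)
        tk[idx] = 1.
    return tk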
Example #6
def create_apu_dataloaders(ts_grp: TensorGroup, bs: int, inc_cal: bool = False) \
        -> Tuple[DeviceDataLoader, DeviceDataLoader, Optional[DeviceDataLoader]]:
    r"""
    Creates the training and validation dataloaders
    :param ts_grp: Stores the raw tensor information
    :param bs: \p DataLoader's batch size
    :param inc_cal: If \p True, generate a calibration \p DataLoader
    :return: Training, validation and optionally calibration \p DataLoader objects respectively
    """
    assert not inc_cal or config.CALIBRATION_SPLIT_RATIO is not None, "Calibration mismatch"

    all_tr, all_val, all_cal = [], [], []
    # Split the tensors into train/validation
    flds = (("p", Labels.Training.POS), ("u_tr", Labels.Training.U_TRAIN),
            ("u_te", Labels.Training.U_TEST))
    for ds_name, lbl in flds:
        x = ts_grp.__getattribute__(f"{ds_name}_x")
        sigma = ts_grp.__getattribute__(f"{ds_name}_sigma")
        if sigma is None:
            sigma = -torch.ones(
                x.shape[:1], dtype=torch.float, device=TORCH_DEVICE)

        spl_tr, spl_val, spl_cal = _split_tensor(x, lbl.value, sigma, inc_cal)
        all_tr.append(spl_tr)
        all_val.append(spl_val)

        assert inc_cal == (spl_cal is not None), "Cal invalid"
        if inc_cal: all_cal.append(spl_cal)

    # Construct the individual dataloaders
    dls = []
    flds = ((all_tr, True), (all_val, False), (all_cal, True))
    for spl_info, shuffle in flds:
        if not spl_info:
            dls.append(None)
            continue

        x = torch.cat([info.x for info in spl_info], dim=0).cpu()
        y = torch.cat([info.y for info in spl_info], dim=0).cpu()
        sigma = torch.cat([info.sigma for info in spl_info], dim=0).cpu()
        dl = DeviceDataLoader.create(dataset=TensorDataset(x, y, sigma),
                                     shuffle=shuffle,
                                     drop_last=shuffle,
                                     bs=bs,
                                     num_workers=NUM_WORKERS,
                                     device=TORCH_DEVICE)
        dls.append(dl)
    # train, validation, and calibration dataloaders respectively
    # noinspection PyTypeChecker
    return tuple(dls)
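A usage sketch; constructing a `TensorGroup` is project-specific, so only the call shape and the optional third loader are shown:

# With inc_cal=True a calibration loader is also built
train_dl, valid_dl, cal_dl = create_apu_dataloaders(ts_grp, bs=128, inc_cal=True)

# With the default inc_cal=False the third element is None
train_dl, valid_dl, cal_dl = create_apu_dataloaders(ts_grp, bs=128)
assert cal_dl is None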
Example #7
def create_apu_dataloaders(
        tg: TensorGroup, bs: int) -> Tuple[DeviceDataLoader, DeviceDataLoader]:
    r"""
    Creates the training and validation dataloaders
    :param tg: Stores the raw tensor information
    :param bs: \p DataLoader's batch size
    :return: Training, validation and optionally calibration \p DataLoader objects respectively
    """
    all_tr, all_val = [], []
    # Split the tensors into train/validation
    flds = (("p", Labels.Training.POS), ("u_tr", Labels.Training.U_TRAIN),
            ("u_te", Labels.Training.U_TEST))
    for ds_name, lbl in flds:
        x = tg.__getattribute__(f"{ds_name}_x")
        sigma = tg.__getattribute__(f"{ds_name}_sigma")
        if sigma is None:
            sigma = -torch.ones(
                x.shape[:1], dtype=torch.float, device=TORCH_DEVICE)
        tk = tg.__getattribute__(f"{ds_name}_tk")
        if tk is None:
            tk = -torch.ones(
                x.shape[:1], dtype=torch.float, device=TORCH_DEVICE)

        spl_tr, spl_val = _split_tensor(x, lbl.value, sigma, tk)
        all_tr.append(spl_tr)
        all_val.append(spl_val)

    # Construct the individual dataloaders
    dls = []
    flds = ((all_tr, True), (all_val, False))
    for spl_info, shuffle in flds:
        if not spl_info:
            dls.append(None)
            continue

        x = torch.cat([info.x for info in spl_info], dim=0).cpu()
        y = torch.cat([info.y for info in spl_info], dim=0).cpu()
        sigma = torch.cat([info.sigma for info in spl_info], dim=0).cpu()
        tk = torch.cat([info.tk for info in spl_info], dim=0).cpu()

        dl = DeviceDataLoader.create(dataset=TensorDataset(x, y, sigma, tk),
                                     shuffle=shuffle,
                                     drop_last=shuffle,
                                     bs=bs,
                                     num_workers=NUM_WORKERS,
                                     device=TORCH_DEVICE)
        dls.append(dl)
    # train and validation dataloaders respectively
    # noinspection PyTypeChecker
    return tuple(dls)
Example #8
def create_pu_dataloaders(p_x: Tensor, u_x: Tensor, bs: int) \
        -> Tuple[DeviceDataLoader, DeviceDataLoader]:
    r"""
    Simple method that splits the positive and unlabeled sets into stratified training and
    validation \p DataLoader objects

    :param p_x: Feature vectors for the positive (labeled) examples
    :param u_x: Feature vectors for the unlabeled examples
    :param bs: \p DataLoader's batch size
    :return: Training and validation \p DataLoader objects respectively
    """
    tr_x, tr_y, val_x, val_y = [], [], [], []
    for x, lbl in ((p_x, Labels.Training.POS), (u_x, Labels.Training.U_TRAIN)):
        num_ex = x.shape[0]
        tr_size = int((1 - config.VALIDATION_SPLIT_RATIO) * num_ex)
        x = shuffle_tensors(x)

        tr_x.append(x[:tr_size])
        tr_y.append(torch.full([tr_size], lbl.value, dtype=torch.int))

        val_x.append(x[tr_size:])
        val_y.append(torch.full([num_ex - tr_size], lbl.value,
                                dtype=torch.int))

    def _cat_tensors(lst_tensors: List[Tensor]) -> Tensor:
        return torch.cat(lst_tensors, dim=0).cpu()

    tr_x, tr_y = shuffle_tensors(_cat_tensors(tr_x), _cat_tensors(tr_y))
    val_x, val_y = shuffle_tensors(_cat_tensors(val_x), _cat_tensors(val_y))

    # Create the training and validation dataloaders respectively
    dls = []
    for x, y, shuffle in ((tr_x, tr_y, True), (val_x, val_y, False)):
        # Placeholder weight (w) and top-k (tk) columns, both filled with -1
        tk = w = -torch.ones_like(y).float().cpu()
        dls.append(
            DeviceDataLoader.create(TensorDataset(x.cpu(), y.cpu(), w.cpu(),
                                                  tk.cpu()),
                                    shuffle=shuffle,
                                    drop_last=shuffle
                                    and not config.DATASET.is_synthetic(),
                                    bs=bs,
                                    num_workers=NUM_WORKERS,
                                    device=TORCH_DEVICE))
    # noinspection PyTypeChecker
    return tuple(dls)
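`shuffle_tensors` is used throughout these examples but never defined here; a minimal sketch of what such a helper presumably does, i.e. apply one shared random permutation along dim 0 (an assumption, not the project's code):

import torch

def shuffle_tensors_sketch(*tensors: torch.Tensor):
    # Hypothetical: one shared permutation keeps (x, y, ...) rows aligned
    perm = torch.randperm(tensors[0].shape[0])
    out = tuple(t[perm] for t in tensors)
    return out[0] if len(out) == 1 else out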
Example #9
def construct_loader(
        ds: Union[TensorDataset, TextDataset],
        bs: int,
        shuffle: bool = True,
        drop_last: bool = False) -> Union[DeviceDataLoader, Iterator]:
    r""" Construct \p Iterator which emulates a \p DataLoader """
    if isinstance(ds, TextDataset):
        return Iterator(dataset=ds,
                        batch_size=bs,
                        shuffle=shuffle,
                        device=TORCH_DEVICE)

    dl = DataLoader(dataset=ds,
                    batch_size=bs,
                    shuffle=shuffle,
                    drop_last=drop_last,
                    num_workers=NUM_WORKERS,
                    pin_memory=False)
    # noinspection PyArgumentList
    return DeviceDataLoader(dl=dl, device=TORCH_DEVICE)
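Usage of the dispatcher above; the text branch requires a torchtext-style `TextDataset`, so only the tensor path is sketched, with illustrative shapes:

# TensorDataset inputs take the plain DataLoader -> DeviceDataLoader path
ds = TensorDataset(torch.randn(100, 8), torch.randint(0, 2, (100,)))
dl = construct_loader(ds, bs=16, shuffle=True)
for xs, ys in dl:  # batches arrive already moved to TORCH_DEVICE
    pass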
Example #10
set_seed(100)
train_ds = MoleculeDataset(train_mol_ids, gb_mol_sc, gb_mol_atom, gb_mol_bond,
                           gb_mol_struct, gb_mol_angle_in, gb_mol_angle_out,
                           gb_mol_graph_dist)
val_ds = MoleculeDataset(val_mol_ids, gb_mol_sc, gb_mol_atom, gb_mol_bond,
                         gb_mol_struct, gb_mol_angle_in, gb_mol_angle_out,
                         gb_mol_graph_dist)
test_ds = MoleculeDataset(test_mol_ids, test_gb_mol_sc, gb_mol_atom,
                          gb_mol_bond, gb_mol_struct, gb_mol_angle_in,
                          gb_mol_angle_out, gb_mol_graph_dist)

train_dl = DataLoader(train_ds, args.batch_size, shuffle=True, num_workers=8)
val_dl = DataLoader(val_ds, args.batch_size, num_workers=8)
test_dl = DeviceDataLoader.create(test_ds,
                                  args.batch_size,
                                  num_workers=8,
                                  collate_fn=partial(collate_parallel_fn,
                                                     test=True))

db = DataBunch(train_dl, val_dl, collate_fn=collate_parallel_fn)
db.test_dl = test_dl

# set up model
set_seed(100)
d_model = args.d_model
enn_args = dict(layers=3 * [d_model], dropout=3 * [0.0], layer_norm=True)
ann_args = dict(layers=1 * [d_model],
                dropout=1 * [0.0],
                layer_norm=True,
                out_act=nn.Tanh())
model = Transformer(C.N_ATOM_FEATURES,
Example #11
def main():
    global model_to_save
    global experiment
    global rabbit
    rabbit = MyRabbit(args)
    if rabbit.model_params.dont_limit_num_uniq_tokens:
        raise NotImplementedError()
    if rabbit.model_params.frame_as_qa: raise NotImplementedError
    if rabbit.run_params.drop_val_loss_calc: raise NotImplementedError
    if rabbit.run_params.use_softrank_influence and not rabbit.run_params.freeze_all_but_last_for_influence:
        raise NotImplementedError
    if rabbit.train_params.weight_influence: raise NotImplementedError
    experiment = Experiment(rabbit.train_params + rabbit.model_params +
                            rabbit.run_params)
    print('Model name:', experiment.model_name)
    use_pretrained_doc_encoder = rabbit.model_params.use_pretrained_doc_encoder
    use_pointwise_loss = rabbit.train_params.use_pointwise_loss
    query_token_embed_len = rabbit.model_params.query_token_embed_len
    document_token_embed_len = rabbit.model_params.document_token_embed_len
    _names = []
    if not rabbit.model_params.dont_include_titles:
        _names.append('with_titles')
    if rabbit.train_params.num_doc_tokens_to_consider != -1:
        _names.append('num_doc_toks_' +
                      str(rabbit.train_params.num_doc_tokens_to_consider))
    if not rabbit.run_params.just_caches:
        if rabbit.model_params.dont_include_titles:
            document_lookup = read_cache(name('./doc_lookup.json', _names),
                                         get_robust_documents)
        else:
            document_lookup = read_cache(name('./doc_lookup.json', _names),
                                         get_robust_documents_with_titles)
    num_doc_tokens_to_consider = rabbit.train_params.num_doc_tokens_to_consider
    document_title_to_id = read_cache(
        './document_title_to_id.json',
        lambda: create_id_lookup(document_lookup.keys()))
    with open('./caches/106756_most_common_doc.json', 'r') as fh:
        doc_token_set = set(json.load(fh))
        tokenizer = Tokenizer()
        tokenized = set(
            sum(
                tokenizer.process_all(list(
                    get_robust_eval_queries().values())), []))
        doc_token_set = doc_token_set.union(tokenized)
    use_bow_model = not any([
        rabbit.model_params[attr] for attr in
        ['use_doc_out', 'use_cnn', 'use_lstm', 'use_pretrained_doc_encoder']
    ])
    use_bow_model = use_bow_model and not rabbit.model_params.dont_use_bow
    if use_bow_model:
        documents, document_token_lookup = read_cache(
            name('./docs_fs_tokens_limit_uniq_toks_qrels_and_106756.pkl',
                 _names),
            lambda: prepare_fs(document_lookup,
                               document_title_to_id,
                               num_tokens=num_doc_tokens_to_consider,
                               token_set=doc_token_set))
        if rabbit.model_params.keep_top_uniq_terms is not None:
            documents = [
                dict(
                    nlargest(rabbit.model_params.keep_top_uniq_terms,
                             _.to_pairs(doc), itemgetter(1)))
                for doc in documents
            ]
    else:
        documents, document_token_lookup = read_cache(
            name(
                f'./parsed_docs_{num_doc_tokens_to_consider}_tokens_limit_uniq_toks_qrels_and_106756.json',
                _names), lambda: prepare(document_lookup,
                                         document_title_to_id,
                                         num_tokens=num_doc_tokens_to_consider,
                                         token_set=doc_token_set))
    if not rabbit.run_params.just_caches:
        train_query_lookup = read_cache('./robust_train_queries.json',
                                        get_robust_train_queries)
        train_query_name_to_id = read_cache(
            './train_query_name_to_id.json',
            lambda: create_id_lookup(train_query_lookup.keys()))
    train_queries, query_token_lookup = read_cache(
        './parsed_robust_queries_dict.json',
        lambda: prepare(train_query_lookup,
                        train_query_name_to_id,
                        token_lookup=document_token_lookup,
                        token_set=doc_token_set,
                        drop_if_any_unk=True))
    query_tok_to_doc_tok = {
        idx: document_token_lookup.get(query_token)
        or document_token_lookup['<unk>']
        for query_token, idx in query_token_lookup.items()
    }
    names = [RANKER_NAME_TO_SUFFIX[rabbit.train_params.ranking_set]]
    if rabbit.train_params.use_pointwise_loss or not rabbit.run_params.just_caches:
        train_data = read_cache(
            name('./robust_train_query_results_tokens_qrels_and_106756.json',
                 names), lambda: read_query_result(
                     train_query_name_to_id,
                     document_title_to_id,
                     train_queries,
                     path='./indri/query_result' + RANKER_NAME_TO_SUFFIX[
                         rabbit.train_params.ranking_set]))
    else:
        train_data = []
    q_embed_len = rabbit.model_params.query_token_embed_len
    doc_embed_len = rabbit.model_params.document_token_embed_len
    if rabbit.model_params.append_difference or rabbit.model_params.append_hadamard:
        assert q_embed_len == doc_embed_len, 'Must use same size doc and query embeds when appending diff or hadamard'
    if q_embed_len == doc_embed_len:
        glove_lookup = get_glove_lookup(
            embedding_dim=q_embed_len,
            use_large_embed=rabbit.model_params.use_large_embed,
            use_word2vec=rabbit.model_params.use_word2vec)
        q_glove_lookup = glove_lookup
        doc_glove_lookup = glove_lookup
    else:
        q_glove_lookup = get_glove_lookup(
            embedding_dim=q_embed_len,
            use_large_embed=rabbit.model_params.use_large_embed,
            use_word2vec=rabbit.model_params.use_word2vec)
        doc_glove_lookup = get_glove_lookup(
            embedding_dim=doc_embed_len,
            use_large_embed=rabbit.model_params.use_large_embed,
            use_word2vec=rabbit.model_params.use_word2vec)
    num_query_tokens = len(query_token_lookup)
    num_doc_tokens = len(document_token_lookup)
    doc_encoder = None
    if use_pretrained_doc_encoder or rabbit.model_params.use_doc_out:
        doc_encoder, document_token_embeds = get_doc_encoder_and_embeddings(
            document_token_lookup, rabbit.model_params.only_use_last_out)
        if rabbit.model_params.use_glove:
            query_token_embeds_init = init_embedding(q_glove_lookup,
                                                     query_token_lookup,
                                                     num_query_tokens,
                                                     query_token_embed_len)
        else:
            query_token_embeds_init = from_doc_to_query_embeds(
                document_token_embeds, document_token_lookup,
                query_token_lookup)
        if not rabbit.train_params.dont_freeze_pretrained_doc_encoder:
            dont_update(doc_encoder)
        if rabbit.model_params.use_doc_out:
            doc_encoder = None
    else:
        document_token_embeds = init_embedding(doc_glove_lookup,
                                               document_token_lookup,
                                               num_doc_tokens,
                                               document_token_embed_len)
        if rabbit.model_params.use_single_word_embed_set:
            query_token_embeds_init = document_token_embeds
        else:
            query_token_embeds_init = init_embedding(q_glove_lookup,
                                                     query_token_lookup,
                                                     num_query_tokens,
                                                     query_token_embed_len)
    if not rabbit.train_params.dont_freeze_word_embeds:
        dont_update(document_token_embeds)
        dont_update(query_token_embeds_init)
    else:
        do_update(document_token_embeds)
        do_update(query_token_embeds_init)
    if rabbit.train_params.add_rel_score:
        query_token_embeds, additive = get_additive_regularized_embeds(
            query_token_embeds_init)
        rel_score = RelScore(query_token_embeds, document_token_embeds,
                             rabbit.model_params, rabbit.train_params)
    else:
        query_token_embeds = query_token_embeds_init
        additive = None
        rel_score = None
    eval_query_lookup = get_robust_eval_queries()
    eval_query_name_document_title_rels = get_robust_rels()
    test_query_names = []
    val_query_names = []
    for query_name in eval_query_lookup:
        if len(val_query_names) >= 50: test_query_names.append(query_name)
        else: val_query_names.append(query_name)
    test_query_name_document_title_rels = _.pick(
        eval_query_name_document_title_rels, test_query_names)
    test_query_lookup = _.pick(eval_query_lookup, test_query_names)
    test_query_name_to_id = create_id_lookup(test_query_lookup.keys())
    test_queries, __ = prepare(test_query_lookup,
                               test_query_name_to_id,
                               token_lookup=query_token_lookup)
    eval_ranking_candidates = read_query_test_rankings(
        './indri/query_result_test' +
        RANKER_NAME_TO_SUFFIX[rabbit.train_params.ranking_set])
    test_candidates_data = read_query_result(
        test_query_name_to_id,
        document_title_to_id,
        dict(zip(range(len(test_queries)), test_queries)),
        path='./indri/query_result_test' +
        RANKER_NAME_TO_SUFFIX[rabbit.train_params.ranking_set])
    test_ranking_candidates = process_raw_candidates(test_query_name_to_id,
                                                     test_queries,
                                                     document_title_to_id,
                                                     test_query_names,
                                                     eval_ranking_candidates)
    test_data = process_rels(test_query_name_document_title_rels,
                             document_title_to_id, test_query_name_to_id,
                             test_queries)
    val_query_name_document_title_rels = _.pick(
        eval_query_name_document_title_rels, val_query_names)
    val_query_lookup = _.pick(eval_query_lookup, val_query_names)
    val_query_name_to_id = create_id_lookup(val_query_lookup.keys())
    val_queries, __ = prepare(val_query_lookup,
                              val_query_name_to_id,
                              token_lookup=query_token_lookup)
    val_candidates_data = read_query_result(
        val_query_name_to_id,
        document_title_to_id,
        dict(zip(range(len(val_queries)), val_queries)),
        path='./indri/query_result_test' +
        RANKER_NAME_TO_SUFFIX[rabbit.train_params.ranking_set])
    val_ranking_candidates = process_raw_candidates(val_query_name_to_id,
                                                    val_queries,
                                                    document_title_to_id,
                                                    val_query_names,
                                                    eval_ranking_candidates)
    val_data = process_rels(val_query_name_document_title_rels,
                            document_title_to_id, val_query_name_to_id,
                            val_queries)
    train_normalized_score_lookup = read_cache(
        name('./train_normalized_score_lookup.pkl', names),
        lambda: get_normalized_score_lookup(train_data))
    test_normalized_score_lookup = get_normalized_score_lookup(
        test_candidates_data)
    val_normalized_score_lookup = get_normalized_score_lookup(
        val_candidates_data)
    if use_pointwise_loss:
        normalized_train_data = read_cache(
            name('./normalized_train_query_data_qrels_and_106756.json', names),
            lambda: normalize_scores_query_wise(train_data))
        collate_fn = lambda samples: collate_query_samples(
            samples,
            use_bow_model=use_bow_model,
            use_dense=rabbit.model_params.use_dense)
        train_dl = build_query_dataloader(
            documents,
            normalized_train_data,
            rabbit.train_params,
            rabbit.model_params,
            cache=name('train_ranking_qrels_and_106756.json', names),
            limit=10,
            query_tok_to_doc_tok=query_tok_to_doc_tok,
            normalized_score_lookup=train_normalized_score_lookup,
            use_bow_model=use_bow_model,
            collate_fn=collate_fn,
            is_test=False)
        test_dl = build_query_dataloader(
            documents,
            test_data,
            rabbit.train_params,
            rabbit.model_params,
            query_tok_to_doc_tok=query_tok_to_doc_tok,
            normalized_score_lookup=test_normalized_score_lookup,
            use_bow_model=use_bow_model,
            collate_fn=collate_fn,
            is_test=True)
        val_dl = build_query_dataloader(
            documents,
            val_data,
            rabbit.train_params,
            rabbit.model_params,
            query_tok_to_doc_tok=query_tok_to_doc_tok,
            normalized_score_lookup=val_normalized_score_lookup,
            use_bow_model=use_bow_model,
            collate_fn=collate_fn,
            is_test=True)
        model = PointwiseScorer(query_token_embeds, document_token_embeds,
                                doc_encoder, rabbit.model_params,
                                rabbit.train_params)
    else:
        if rabbit.train_params.use_noise_aware_loss:
            ranker_query_str_to_rankings = get_ranker_query_str_to_rankings(
                train_query_name_to_id,
                document_title_to_id,
                train_queries,
                limit=rabbit.train_params.num_snorkel_train_queries)
            query_names = reduce(
                lambda acc, query_to_ranking: acc.intersection(
                    set(query_to_ranking.keys()))
                if len(acc) != 0 else set(query_to_ranking.keys()),
                ranker_query_str_to_rankings.values(), set())
            all_ranked_lists_by_ranker = _.map_values(
                ranker_query_str_to_rankings, lambda query_to_ranking:
                [query_to_ranking[query] for query in query_names])
            ranker_query_str_to_pairwise_bins = get_ranker_query_str_to_pairwise_bins(
                train_query_name_to_id,
                document_title_to_id,
                train_queries,
                limit=rabbit.train_params.num_train_queries)
            snorkeller = Snorkeller(ranker_query_str_to_pairwise_bins)
            snorkeller.train(all_ranked_lists_by_ranker)
            calc_marginals = snorkeller.calc_marginals
        else:
            calc_marginals = None
        collate_fn = lambda samples: collate_query_pairwise_samples(
            samples,
            use_bow_model=use_bow_model,
            calc_marginals=calc_marginals,
            use_dense=rabbit.model_params.use_dense)
        if rabbit.run_params.load_influences:
            try:
                with open(rabbit.run_params.influences_path) as fh:
                    pairs_to_flip = defaultdict(set)
                    for pair, influence in json.load(fh):
                        if rabbit.train_params.use_pointwise_loss:
                            condition = True
                        else:
                            condition = influence < rabbit.train_params.influence_thresh
                        if condition:
                            query = tuple(pair[1])
                            pairs_to_flip[query].add(tuple(pair[0]))
            except FileNotFoundError:
                pairs_to_flip = None
        else:
            pairs_to_flip = None
        train_dl = build_query_pairwise_dataloader(
            documents,
            train_data,
            rabbit.train_params,
            rabbit.model_params,
            pairs_to_flip=pairs_to_flip,
            cache=name('train_ranking_qrels_and_106756.json', names),
            limit=10,
            query_tok_to_doc_tok=query_tok_to_doc_tok,
            normalized_score_lookup=train_normalized_score_lookup,
            use_bow_model=use_bow_model,
            collate_fn=collate_fn,
            is_test=False)
        test_dl = build_query_pairwise_dataloader(
            documents,
            test_data,
            rabbit.train_params,
            rabbit.model_params,
            query_tok_to_doc_tok=query_tok_to_doc_tok,
            normalized_score_lookup=test_normalized_score_lookup,
            use_bow_model=use_bow_model,
            collate_fn=collate_fn,
            is_test=True)
        val_dl = build_query_pairwise_dataloader(
            documents,
            val_data,
            rabbit.train_params,
            rabbit.model_params,
            query_tok_to_doc_tok=query_tok_to_doc_tok,
            normalized_score_lookup=val_normalized_score_lookup,
            use_bow_model=use_bow_model,
            collate_fn=collate_fn,
            is_test=True)
        val_rel_dl = build_query_pairwise_dataloader(
            documents,
            val_data,
            rabbit.train_params,
            rabbit.model_params,
            query_tok_to_doc_tok=query_tok_to_doc_tok,
            normalized_score_lookup=val_normalized_score_lookup,
            use_bow_model=use_bow_model,
            collate_fn=collate_fn,
            is_test=True,
            rel_vs_irrel=True,
            candidates=val_ranking_candidates,
            num_to_rank=rabbit.run_params.num_to_rank)
        model = PairwiseScorer(query_token_embeds,
                               document_token_embeds,
                               doc_encoder,
                               rabbit.model_params,
                               rabbit.train_params,
                               use_bow_model=use_bow_model)
    train_ranking_dataset = RankingDataset(
        documents,
        train_dl.dataset.rankings,
        rabbit.train_params,
        rabbit.model_params,
        rabbit.run_params,
        query_tok_to_doc_tok=query_tok_to_doc_tok,
        normalized_score_lookup=train_normalized_score_lookup,
        use_bow_model=use_bow_model,
        use_dense=rabbit.model_params.use_dense)
    test_ranking_dataset = RankingDataset(
        documents,
        test_ranking_candidates,
        rabbit.train_params,
        rabbit.model_params,
        rabbit.run_params,
        relevant=test_dl.dataset.rankings,
        query_tok_to_doc_tok=query_tok_to_doc_tok,
        cheat=rabbit.run_params.cheat,
        normalized_score_lookup=test_normalized_score_lookup,
        use_bow_model=use_bow_model,
        use_dense=rabbit.model_params.use_dense)
    val_ranking_dataset = RankingDataset(
        documents,
        val_ranking_candidates,
        rabbit.train_params,
        rabbit.model_params,
        rabbit.run_params,
        relevant=val_dl.dataset.rankings,
        query_tok_to_doc_tok=query_tok_to_doc_tok,
        cheat=rabbit.run_params.cheat,
        normalized_score_lookup=val_normalized_score_lookup,
        use_bow_model=use_bow_model,
        use_dense=rabbit.model_params.use_dense)
    if rabbit.train_params.memorize_test:
        train_dl = test_dl
        train_ranking_dataset = test_ranking_dataset
    model_data = DataBunch(train_dl,
                           val_rel_dl,
                           test_dl,
                           collate_fn=collate_fn,
                           device=torch.device('cuda') if
                           torch.cuda.is_available() else torch.device('cpu'))
    multi_objective_model = MultiObjective(model, rabbit.train_params,
                                           rel_score, additive)
    model_to_save = multi_objective_model
    if rabbit.train_params.memorize_test:
        try:
            del train_data
        except NameError:
            pass
    if not rabbit.run_params.just_caches:
        del document_lookup
        del train_query_lookup
    del query_token_lookup
    del document_token_lookup
    del train_queries
    try:
        del glove_lookup
    except UnboundLocalError:
        del q_glove_lookup
        del doc_glove_lookup
    if rabbit.run_params.load_model:
        try:
            multi_objective_model.load_state_dict(
                torch.load(rabbit.run_params.load_path))
        except RuntimeError:
            dp = nn.DataParallel(multi_objective_model)
            dp.load_state_dict(torch.load(rabbit.run_params.load_path))
            multi_objective_model = dp.module
    else:
        train_model(multi_objective_model, model_data, train_ranking_dataset,
                    val_ranking_dataset, test_ranking_dataset,
                    rabbit.train_params, rabbit.model_params,
                    rabbit.run_params, experiment)
    if rabbit.train_params.fine_tune_on_val:
        fine_tune_model_data = DataBunch(
            val_rel_dl,
            val_rel_dl,
            test_dl,
            collate_fn=collate_fn,
            device=torch.device('cuda')
            if torch.cuda.is_available() else torch.device('cpu'))
        train_model(multi_objective_model,
                    fine_tune_model_data,
                    val_ranking_dataset,
                    val_ranking_dataset,
                    test_ranking_dataset,
                    rabbit.train_params,
                    rabbit.model_params,
                    rabbit.run_params,
                    experiment,
                    load_path=rabbit.run_params.load_path)
    multi_objective_model.eval()
    device = model_data.device
    gpu_multi_objective_model = multi_objective_model.to(device)
    if rabbit.run_params.calc_influence:
        if rabbit.run_params.freeze_all_but_last_for_influence:
            last_layer_idx = _.find_last_index(
                multi_objective_model.model.pointwise_scorer.layers,
                lambda layer: isinstance(layer, nn.Linear))
            to_last_layer = lambda x: gpu_multi_objective_model(
                *x, to_idx=last_layer_idx)
            last_layer = gpu_multi_objective_model.model.pointwise_scorer.layers[
                last_layer_idx]
            diff_wrt = [p for p in last_layer.parameters() if p.requires_grad]
        else:
            diff_wrt = None
        test_hvps = calc_test_hvps(
            multi_objective_model.loss,
            gpu_multi_objective_model,
            DeviceDataLoader(train_dl, device, collate_fn=collate_fn),
            val_rel_dl,
            rabbit.run_params,
            diff_wrt=diff_wrt,
            show_progress=True,
            use_softrank_influence=rabbit.run_params.use_softrank_influence)
        influences = []
        if rabbit.train_params.use_pointwise_loss:
            num_real_samples = len(train_dl.dataset)
        else:
            num_real_samples = train_dl.dataset._num_pos_pairs
        if rabbit.run_params.freeze_all_but_last_for_influence:
            _sampler = SequentialSamplerWithLimit(train_dl.dataset,
                                                  num_real_samples)
            _batch_sampler = BatchSampler(_sampler,
                                          rabbit.train_params.batch_size,
                                          False)
            _dl = DataLoader(train_dl.dataset,
                             batch_sampler=_batch_sampler,
                             collate_fn=collate_fn)
            sequential_train_dl = DeviceDataLoader(_dl,
                                                   device,
                                                   collate_fn=collate_fn)
            influences = calc_dataset_influence(gpu_multi_objective_model,
                                                to_last_layer,
                                                sequential_train_dl,
                                                test_hvps,
                                                sum_p=True).tolist()
        else:
            for i in progressbar(range(num_real_samples)):
                train_sample = train_dl.dataset[i]
                x, labels = to_device(collate_fn([train_sample]), device)
                device_train_sample = (x, labels.squeeze())
                influences.append(
                    calc_influence(multi_objective_model.loss,
                                   gpu_multi_objective_model,
                                   device_train_sample,
                                   test_hvps,
                                   diff_wrt=diff_wrt).sum().tolist())
        with open(rabbit.run_params.influences_path, 'w+') as fh:
            json.dump([[train_dl.dataset[idx][1], influence]
                       for idx, influence in enumerate(influences)], fh)
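The influence pass at the end follows the standard influence-function recipe (Koh & Liang, 2017): a training point's influence on the test loss is approximated by the dot product of its loss gradient with a precomputed inverse-Hessian-vector product (the `test_hvps` above). A schematic sketch under that assumption, with hypothetical names:

import torch

def influence_sketch(model, loss_fn, train_sample, test_ihvp, params):
    # Hypothetical: influence ~= grad_train . H^{-1} grad_test (the IHVP)
    x, y = train_sample
    loss = loss_fn(model(x), y)
    grads = torch.autograd.grad(loss, params)
    return sum((g * h).sum() for g, h in zip(grads, test_ihvp)).item()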