Example #1
def test_setup_device():
    import torch
    from pecos.utils import torch_util

    if torch.cuda.is_available():  # GPU machine
        device, n_active_gpu = torch_util.setup_device(
            use_gpu_if_available=True)
        assert device == torch.device("cuda")
        assert n_active_gpu == torch.cuda.device_count()
        device, n_active_gpu = torch_util.setup_device(
            use_gpu_if_available=False)
        assert device == torch.device("cpu")
        assert n_active_gpu == 0
    else:
        device, n_active_gpu = torch_util.setup_device(
            use_gpu_if_available=True)
        assert device == torch.device("cpu")
        assert n_active_gpu == 0
        device, n_active_gpu = torch_util.setup_device(
            use_gpu_if_available=False)
        assert device == torch.device("cpu")
        assert n_active_gpu == 0
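
The test above exercises pecos.utils.torch_util.setup_device on both GPU and CPU machines. Below is a minimal sketch of the typical follow-up pattern; the Linear model and random tensor are hypothetical placeholders, not part of pecos:

import torch
from pecos.utils import torch_util

# Resolve the compute device once; n_active_gpu is 0 on CPU-only machines.
device, n_active_gpu = torch_util.setup_device(use_gpu_if_available=True)

# Move work onto the chosen device (placeholder model and inputs).
model = torch.nn.Linear(4, 2).to(device)
inputs = torch.randn(3, 4, device=device)
with torch.no_grad():
    outputs = model(inputs)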
Example #2
def test_bert(tmpdir):
    from pecos.utils import torch_util

    _, n_gpu = torch_util.setup_device()
    # test on CPU
    xtransformer_cli(tmpdir.join("sparse_cpu"), bert_model_path,
                     train_feat_file, 0)
    xtransformer_cli(tmpdir.join("dense_cpu"), bert_model_path,
                     train_dense_feat_file, 0)

    if n_gpu > 0:
        # test on all GPUs
        xtransformer_cli(tmpdir.join("sparse_gpu"), bert_model_path,
                         train_feat_file, n_gpu)
        xtransformer_cli(tmpdir.join("dense_gpu"), bert_model_path,
                         train_dense_feat_file, n_gpu)

    if n_gpu > 1:
        # test on single GPU when multi-GPU available
        xtransformer_cli(tmpdir.join("sparse_single_gpu"), bert_model_path,
                         train_feat_file, 1)
        xtransformer_cli(tmpdir.join("dense_single_gpu"), bert_model_path,
                         train_dense_feat_file, 1)
Example #3
    def predict(
        self,
        corpus,
        batch_size=8,
        truncate_length=300,
        use_gpu_if_available=True,
        **kwargs,
    ):
        """Vectorizer a corpus.

        Args:
            corpus (list): List of strings to vectorize.
            batch_size (int, optional): Default is 8.
            truncate_length (int, optional): Default is 300.
            use_gpu_if_available (bool, optional): Default is True.

        Returns:
            numpy.ndarray: Matrix of features.
        """

        if self.model.config.max_position_embeddings > 0:
            truncate_length = min(truncate_length,
                                  self.model.config.max_position_embeddings)

        # generate feature batches
        feature_tensors = self.tokenizer.batch_encode_plus(
            batch_text_or_text_pairs=corpus,
            return_tensors="pt",
            return_attention_mask=True,
            return_token_type_ids=True,
            add_special_tokens=True,
            max_length=truncate_length,
            truncation=True,
            padding="longest",
        )
        # setup device
        device, n_active_gpu = torch_util.setup_device(
            use_gpu_if_available=use_gpu_if_available)
        # start eval
        transformer_type = self.transformer_options["transformer_type"]
        norm = self.transformer_options["norm"]
        pooling = self.transformer_options["pooling"]

        batch_size = batch_size * max(1, n_active_gpu)
        data = TensorDataset(
            feature_tensors["input_ids"],
            feature_tensors["attention_mask"],
            feature_tensors["token_type_ids"],
        )

        sampler = SequentialSampler(data)
        dataloader = DataLoader(data,
                                sampler=sampler,
                                batch_size=batch_size,
                                num_workers=4)

        # multi-gpu eval
        if n_active_gpu > 1 and not isinstance(self.model,
                                               torch.nn.parallel.DataParallel):
            model = torch.nn.parallel.DataParallel(self.model)
        else:
            model = self.model

        model.eval()
        model.to(device)
        embeddings = []
        for batch in dataloader:
            batch = tuple(t.to(device) for t in batch)

            with torch.no_grad():
                inputs = {
                    "input_ids": batch[0],
                    "attention_mask": batch[1],
                    "token_type_ids": batch[2],
                }
                if transformer_type == "distilbert":
                    outputs = model(
                        input_ids=inputs["input_ids"],
                        attention_mask=inputs["attention_mask"],
                    )
                elif transformer_type in [
                        "bert",
                        "roberta",
                        "xlm-roberta",
                        "albert",
                        "xlm",
                        "xlnet",
                ]:
                    outputs = model(
                        input_ids=inputs["input_ids"],
                        attention_mask=inputs["attention_mask"],
                        token_type_ids=inputs["token_type_ids"],
                    )
                else:
                    raise NotImplementedError(
                        "Unsupported transformer_type {}".format(
                            transformer_type))

                # get the embeddings from model output
                # REF: https://huggingface.co/transformers/v2.3.0/model_doc/bert.html#bertmodel
                # For bert,roberta,xlm-roberta,albert:  outputs = last_hidden_states, pooled_output, (hidden_states)
                # For xlm,xlnet,distilbert: outputs = last_hidden_states, (hidden_states), (attentions)
                if pooling == "mean":
                    pooled_output = outputs[0].mean(dim=1)
                elif pooling == "mask-mean":
                    last_hidden_states = torch_util.apply_mask(
                        outputs[0], inputs["attention_mask"])
                    pooled_output = last_hidden_states.sum(dim=1)
                    masked_length = inputs["attention_mask"].sum(dim=1)
                    pooled_output = pooled_output / masked_length.unsqueeze(
                        1).float()
                elif pooling == "first":
                    pooled_output = outputs[0][:, 0, :]
                elif pooling == "last":
                    pooled_output = outputs[0][:, -1, :]
                elif pooling == "cls":
                    assert transformer_type in [
                        "bert",
                        "roberta",
                        "xlm-roberta",
                        "albert",
                    ], "Only {} models have [CLS] token.".format(
                        ["bert", "roberta", "xlm-roberta", "albert"])
                    # get the [CLS] embedding for the document
                    pooled_output = outputs[1]
                else:
                    raise NotImplementedError(
                        "Unsupported pooling method {}".format(pooling))

                embeddings.append(pooled_output.cpu().numpy())

        # construct dense output
        embeddings = np.concatenate(embeddings, axis=0)
        if norm is not None:
            embeddings = normalize(embeddings, norm=norm, axis=1, copy=False)
        return embeddings
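
A hedged usage sketch for the method above. The vectorizer object is assumed to be an already-constructed instance of the class that defines this predict(); the class name and its construction are not shown in this excerpt:

# Hypothetical: `vectorizer` is an instance of the class defining predict() above.
corpus = [
    "PECOS is a library for extreme multi-label classification.",
    "Transformer embeddings can be pooled in several ways.",
]
embeddings = vectorizer.predict(
    corpus,
    batch_size=16,             # per-device batch size; scaled by the number of active GPUs internally
    truncate_length=128,       # clipped to the model's max_position_embeddings when necessary
    use_gpu_if_available=True,
)
print(embeddings.shape)        # (len(corpus), hidden_dim), dense numpy.ndarray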
Example #4
File: model.py  Project: OctoberChang/pecos
    def encode(
        self,
        X_text,
        pred_params=None,
        **kwargs,
    ):
        """Use the Transformer text encoder to generate embeddings for input data.

        Args:
            X_text (iterable over str): instance text input to predict on
            pred_params (XTransformer.PredParams, optional): instance of
                XTransformer.PredParams. Default None to use the pred_params stored
                during model training.
            kwargs:
                saved_pt (str, optional): if given, will try to load encoded tensors and skip text encoding.
                batch_size (int, optional): per-device batch size for transformer evaluation. Default 8.
                batch_gen_workers (int, optional): number of CPUs to use for batch generation. Default 4.
                use_gpu (bool, optional): use GPU if available. Default True.
                max_pred_chunk (int, optional): max number of instances to predict at once.
                    Set to None to ignore. Default 10^7.

        Returns:
            embeddings (ndarray): instance embeddings of the input text, shape = (nr_inst, hidden_dim).
        """
        saved_pt = kwargs.get("saved_pt", None)
        batch_size = kwargs.get("batch_size", 8)
        batch_gen_workers = kwargs.get("batch_gen_workers", 4)
        use_gpu = kwargs.get("use_gpu", True)
        max_pred_chunk = kwargs.get("max_pred_chunk", 10**7)
        device, n_gpu = torch_util.setup_device(use_gpu)

        # get the override pred_params
        if pred_params is None:
            pred_params = self.get_pred_params()
        else:
            pred_params = self.PredParams.from_dict(pred_params)
        pred_params.override_with_kwargs(kwargs)

        LOGGER.debug(f"Encode with pred_params: {json.dumps(pred_params.to_dict(), indent=True)}")
        if isinstance(pred_params.matcher_params_chain, list):
            encoder_pred_params = pred_params.matcher_params_chain[-1]
        else:
            encoder_pred_params = pred_params.matcher_params_chain

        # generate instance-to-cluster prediction
        if saved_pt and os.path.isfile(saved_pt):
            text_tensors = torch.load(saved_pt)
            LOGGER.info("Text tensors loaded_from {}".format(saved_pt))
        else:
            text_tensors = self.text_encoder.text_to_tensor(
                X_text,
                num_workers=batch_gen_workers,
                max_length=encoder_pred_params.truncate_length,
            )

        self.text_encoder.to_device(device, n_gpu=n_gpu)
        _, embeddings = self.text_encoder.predict(
            text_tensors,
            pred_params=encoder_pred_params,
            batch_size=batch_size * max(1, n_gpu),
            batch_gen_workers=batch_gen_workers,
            max_pred_chunk=max_pred_chunk,
            only_embeddings=True,
        )
        return embeddings
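
A usage sketch for encode(). The model directory and corpus below are placeholders, and the import path is an assumption; XTransformer.load() is inferred from the cls.load() call used elsewhere in this file:

# Hypothetical paths and inputs.
from pecos.xmc.xtransformer.model import XTransformer  # assumed import path

xtf = XTransformer.load("./xtransformer_model")
corpus = ["query text one", "query text two"]
emb = xtf.encode(
    corpus,
    batch_size=16,           # per-device batch size
    batch_gen_workers=4,     # CPUs used for batch generation
    use_gpu=True,
)
print(emb.shape)             # (nr_inst, hidden_dim)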
Example #5
File: model.py  Project: OctoberChang/pecos
    def predict(
        self,
        X_text,
        X_feat=None,
        pred_params=None,
        **kwargs,
    ):
        """Use the XR-Transformer model to predict on given data.

        Args:
            X_text (iterable over str): instance text input to predict on
            X_feat (csr_matrix or ndarray): instance feature matrix (nr_insts, feature_dim)
            pred_params (XTransformer.PredParams, optional): instance of
                XTransformer.PredParams. Default None to use the pred_params stored
                during model training.
            kwargs:
                beam_size (int, optional): override the beam size specified in the model.
                    Default None to disable overriding.
                only_topk (int, optional): override the only_topk specified in the model.
                    Default None to disable overriding.
                post_processor (str, optional): override the post_processor specified in the model.
                    Default None to disable overriding.
                saved_pt (str, optional): if given, will try to load encoded tensors and skip text encoding.
                batch_size (int, optional): per-device batch size for transformer evaluation. Default 8.
                batch_gen_workers (int, optional): number of CPUs to use for batch generation. Default 4.
                use_gpu (bool, optional): use GPU if available. Default True.
                max_pred_chunk (int, optional): max number of instances to predict at once.
                    Set to None to ignore. Default 10^7.
                threads (int, optional): number of threads to use for linear model prediction.
                    Default -1 to use all available threads.
        Returns:
            pred_csr (csr_matrix): instance to label prediction (csr_matrix, nr_insts * nr_labels)
        """
        if not isinstance(self.concat_model, XLinearModel):
            raise TypeError("concat_model is not present in current XTransformer model!")

        saved_pt = kwargs.get("saved_pt", None)
        batch_size = kwargs.get("batch_size", 8)
        batch_gen_workers = kwargs.get("batch_gen_workers", 4)
        use_gpu = kwargs.get("use_gpu", True)
        max_pred_chunk = kwargs.get("max_pred_chunk", 10**7)
        device, n_gpu = torch_util.setup_device(use_gpu)

        # get the override pred_params
        if pred_params is None:
            pred_params = self.get_pred_params()
        else:
            pred_params = self.PredParams.from_dict(pred_params)
        pred_params.override_with_kwargs(kwargs)

        LOGGER.debug(
            f"Prediction with pred_params: {json.dumps(pred_params.to_dict(), indent=True)}"
        )
        if isinstance(pred_params.matcher_params_chain, list):
            encoder_pred_params = pred_params.matcher_params_chain[-1]
        else:
            encoder_pred_params = pred_params.matcher_params_chain

        # generate instance-to-cluster prediction
        if saved_pt and os.path.isfile(saved_pt):
            text_tensors = torch.load(saved_pt)
            LOGGER.info("Text tensors loaded_from {}".format(saved_pt))
        else:
            text_tensors = self.text_encoder.text_to_tensor(
                X_text,
                num_workers=batch_gen_workers,
                max_length=encoder_pred_params.truncate_length,
            )

        pred_csr = None
        self.text_encoder.to_device(device, n_gpu=n_gpu)
        _, embeddings = self.text_encoder.predict(
            text_tensors,
            pred_params=encoder_pred_params,
            batch_size=batch_size * max(1, n_gpu),
            batch_gen_workers=batch_gen_workers,
            max_pred_chunk=max_pred_chunk,
            only_embeddings=True,
        )

        cat_embeddings = TransformerMatcher.concat_features(
            X_feat,
            embeddings,
            normalize_emb=True,
        )
        LOGGER.debug(
            "Constructed instance feature matrix with shape={}".format(cat_embeddings.shape)
        )
        pred_csr = self.concat_model.predict(
            cat_embeddings,
            pred_params=None if pred_params is None else pred_params.ranker_params,
            max_pred_chunk=max_pred_chunk,
            threads=kwargs.get("threads", -1),
        )
        return pred_csr
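
A usage sketch for predict(). The import path and inputs are assumptions; the sparse feature matrix here is random only to keep the snippet self-contained:

# Hypothetical inputs; X_feat must have shape (nr_insts, feature_dim) matching the trained model.
import scipy.sparse as smat
from pecos.xmc.xtransformer.model import XTransformer  # assumed import path

xtf = XTransformer.load("./xtransformer_model")
corpus = ["instance text one", "instance text two"]
X_feat = smat.random(len(corpus), 1000, density=0.01, format="csr", dtype="float32")

P = xtf.predict(
    corpus,
    X_feat=X_feat,
    beam_size=20,    # kwargs can override the beam size stored in pred_params
    only_topk=10,
    use_gpu=True,
)
print(P.shape)       # csr_matrix of shape (nr_insts, nr_labels)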
Example #6
File: model.py  Project: OctoberChang/pecos
    def train(
        cls,
        prob,
        clustering=None,
        val_prob=None,
        train_params=None,
        pred_params=None,
        **kwargs,
    ):
        """Train the XR-Transformer model with the given input data.

        Args:
            prob (MLProblemWithText): ML problem to solve.
            clustering (ClusterChain, optional): preliminary hierarchical label tree
                on which the transformer is fine-tuned.
            val_prob (MLProblemWithText, optional): ML problem for validation.
            train_params (XTransformer.TrainParams, optional): training parameters for XTransformer.
            pred_params (XTransformer.PredParams, optional): prediction parameters for XTransformer.
            kwargs:
                label_feat (ndarray or csr_matrix, optional): label features from which to generate the preliminary HLT.
                saved_trn_pt (str, optional): path to save the tokenized training text. Uses a tempdir if not given.
                saved_val_pt (str, optional): path to save the tokenized validation text. Uses a tempdir if not given.
                matmul_threads (int, optional): number of threads to use for
                    constructing the label tree. Defaults to at most 32 threads.
                beam_size (int, optional): override only_topk for all models except
                    the bottom-layer one.

        Returns:
            XTransformer
        """
        # tempdir to save tokenized text
        temp_dir = tempfile.TemporaryDirectory()
        saved_trn_pt = kwargs.get("saved_trn_pt", "")
        if not saved_trn_pt:
            saved_trn_pt = f"{temp_dir.name}/X_trn.pt"

        saved_val_pt = kwargs.get("saved_val_pt", "")
        if not saved_val_pt:
            saved_val_pt = f"{temp_dir.name}/X_val.pt"

        # construct train_params
        if train_params is None:
            # fill all BaseParams class with their default value
            train_params = cls.TrainParams.from_dict(dict(), recursive=True)
        else:
            train_params = cls.TrainParams.from_dict(train_params)
        # construct pred_params
        if pred_params is None:
            # fill all BaseParams with their default value
            pred_params = cls.PredParams.from_dict(dict(), recursive=True)
        else:
            pred_params = cls.PredParams.from_dict(pred_params)

        if not train_params.do_fine_tune:
            if isinstance(train_params.matcher_params_chain, list):
                matcher_train_params = train_params.matcher_params_chain[-1]
            else:
                matcher_train_params = train_params.matcher_params_chain

            if isinstance(pred_params.matcher_params_chain, list):
                matcher_pred_params = pred_params.matcher_params_chain[-1]
            else:
                matcher_pred_params = pred_params.matcher_params_chain

            device, n_gpu = torch_util.setup_device(matcher_train_params.use_gpu)

            if matcher_train_params.init_model_dir:
                parent_model = cls.load(matcher_train_params.init_model_dir)
                LOGGER.info("Loaded encoder from {}.".format(matcher_train_params.init_model_dir))
            else:
                parent_model = TransformerMatcher.download_model(
                    matcher_train_params.model_shortcut,
                )
                LOGGER.info(
                    "Downloaded encoder from {}.".format(matcher_train_params.model_shortcut)
                )

            parent_model.to_device(device, n_gpu=n_gpu)
            _, inst_embeddings = parent_model.predict(
                prob.X_text,
                pred_params=matcher_pred_params,
                batch_size=matcher_train_params.batch_size * max(1, n_gpu),
                batch_gen_workers=matcher_train_params.batch_gen_workers,
                only_embeddings=True,
            )
            if val_prob:
                _, val_inst_embeddings = parent_model.predict(
                    val_prob.X_text,
                    pred_params=matcher_pred_params,
                    batch_size=matcher_train_params.batch_size * max(1, n_gpu),
                    batch_gen_workers=matcher_train_params.batch_gen_workers,
                    only_embeddings=True,
                )
        else:
            # 1. construct the preliminary Hierarchical Label Tree
            if clustering is None:
                label_feat = kwargs.get("label_feat", None)
                if label_feat is None:
                    if prob.X_feat is None:
                        raise ValueError(
                            "Instance features are required to generate label features!"
                        )
                    label_feat = LabelEmbeddingFactory.pifa(prob.Y, prob.X_feat)

                clustering = Indexer.gen(
                    label_feat,
                    train_params=train_params.preliminary_indexer_params,
                )
            else:
                # assert cluster chain in clustering is valid
                clustering = ClusterChain(clustering)
                if clustering[-1].shape[0] != prob.nr_labels:
                    raise ValueError("nr_labels mismatch!")
            prelim_hierarchy = [cc.shape[0] for cc in clustering]
            LOGGER.info("Hierarchical label tree: {}".format(prelim_hierarchy))

            # get the fine-tuning task numbers
            nr_transformers = sum(i <= train_params.max_match_clusters for i in prelim_hierarchiy)

            LOGGER.info(
                "Fine-tune Transformers with nr_labels={}".format(
                    [cc.shape[0] for cc in clustering[:nr_transformers]]
                )
            )

            steps_scale = kwargs.get("steps_scale", None)
            if steps_scale is None:
                steps_scale = [1.0] * nr_transformers
            if len(steps_scale) != nr_transformers:
                raise ValueError(f"steps-scale length error: {len(steps_scale)}!={nr_transformers}")

            # construct fields with chain now we know the depth
            train_params = HierarchicalMLModel._duplicate_fields_with_name_ending_with_chain(
                train_params, cls.TrainParams, nr_transformers
            )

            LOGGER.debug(
                f"XTransformer train_params: {json.dumps(train_params.to_dict(), indent=True)}"
            )

            pred_params = HierarchicalMLModel._duplicate_fields_with_name_ending_with_chain(
                pred_params, cls.PredParams, nr_transformers
            )
            pred_params = pred_params.override_with_kwargs(kwargs)

            LOGGER.debug(
                f"XTransformer pred_params: {json.dumps(pred_params.to_dict(), indent=True)}"
            )

            def get_negative_samples(mat_true, mat_pred, scheme):
                if scheme == "tfn":
                    result = smat_util.binarized(mat_true)
                elif scheme == "man":
                    result = smat_util.binarized(mat_pred)
                elif "tfn" in scheme and "man" in scheme:
                    result = smat_util.binarized(mat_true) + smat_util.binarized(mat_pred)
                else:
                    raise ValueError("Unrecognized negative sampling method {}".format(scheme))
                LOGGER.debug(
                    f"Construct {scheme} with shape={result.shape} avr_M_nnz={result.nnz/result.shape[0]}"
                )
                return result

            # construct label chain for training and validation set
            # avoid large matmul_threads to prevent overhead in Y.dot(C) and save memory
            matmul_threads = kwargs.get("threads", os.cpu_count())
            matmul_threads = min(32, matmul_threads)
            YC_list = [prob.Y]
            for cur_C in reversed(clustering[1:]):
                Y_t = clib.sparse_matmul(YC_list[-1], cur_C, threads=matmul_threads).tocsr()
                YC_list.append(Y_t)
            YC_list.reverse()

            if val_prob is not None:
                val_YC_list = [val_prob.Y]
                for cur_C in reversed(clustering[1:]):
                    Y_t = clib.sparse_matmul(val_YC_list[-1], cur_C, threads=matmul_threads).tocsr()
                    val_YC_list.append(Y_t)
                val_YC_list.reverse()

            parent_model = None
            M, val_M = None, None
            M_pred, val_M_pred = None, None
            bootstrapping, inst_embeddings = None, None
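            # 2. fine-tune the transformer matchers level by level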
            for i in range(nr_transformers):
                cur_train_params = train_params.matcher_params_chain[i]
                cur_pred_params = pred_params.matcher_params_chain[i]
                cur_train_params.max_steps = steps_scale[i] * cur_train_params.max_steps
                cur_train_params.num_train_epochs = (
                    steps_scale[i] * cur_train_params.num_train_epochs
                )

                cur_ns = cur_train_params.negative_sampling

                # construct train and val problem for level i
                # note that the final layer does not need X_feat
                if i > 0:
                    M = get_negative_samples(YC_list[i - 1], M_pred, cur_ns)

                cur_prob = MLProblemWithText(
                    prob.X_text,
                    YC_list[i],
                    X_feat=None if i == nr_transformers - 1 else prob.X_feat,
                    C=clustering[i],
                    M=M,
                )
                if val_prob is not None:
                    if i > 0:
                        val_M = get_negative_samples(val_YC_list[i - 1], val_M_pred, cur_ns)
                    cur_val_prob = MLProblemWithText(
                        val_prob.X_text,
                        val_YC_list[i],
                        X_feat=None if i == nr_transformers - 1 else val_prob.X_feat,
                        C=clustering[i],
                        M=val_M,
                    )
                else:
                    cur_val_prob = None

                avr_trn_labels = (
                    float(cur_prob.M.nnz) / YC_list[i].shape[0]
                    if cur_prob.M is not None
                    else YC_list[i].shape[1]
                )
                LOGGER.info(
                    "Fine-tuning XR-Transformer with {} at level {}, nr_labels={}, avr_M_nnz={}".format(
                        cur_ns, i, YC_list[i].shape[1], avr_trn_labels
                    )
                )

                # bootstrapping with previous text_encoder and instance embeddings
                if parent_model is not None:
                    init_encoder = deepcopy(parent_model.text_encoder)
                    init_text_model = deepcopy(parent_model.text_model)
                    bootstrapping = (init_encoder, inst_embeddings, init_text_model)

                # determine whether train prediction and instance embeddings are needed
                return_train_pred = (
                    i + 1 < nr_transformers
                ) and "man" in train_params.matcher_params_chain[i + 1].negative_sampling
                return_train_embeddings = (
                    i + 1 == nr_transformers
                ) or "linear" in cur_train_params.bootstrap_method

                res_dict = TransformerMatcher.train(
                    cur_prob,
                    csr_codes=M_pred,
                    val_prob=cur_val_prob,
                    val_csr_codes=val_M_pred,
                    train_params=cur_train_params,
                    pred_params=cur_pred_params,
                    bootstrapping=bootstrapping,
                    return_dict=True,
                    return_train_pred=return_train_pred,
                    return_train_embeddings=return_train_embeddings,
                    saved_trn_pt=saved_trn_pt,
                    saved_val_pt=saved_val_pt,
                )
                parent_model = res_dict["matcher"]
                M_pred = res_dict["trn_pred"]
                val_M_pred = res_dict["val_pred"]
                inst_embeddings = res_dict["trn_embeddings"]
                val_inst_embeddings = res_dict["val_embeddings"]

        if train_params.save_emb_dir:
            os.makedirs(train_params.save_emb_dir, exist_ok=True)
            if inst_embeddings is not None:
                smat_util.save_matrix(
                    os.path.join(train_params.save_emb_dir, "X.trn.npy"),
                    inst_embeddings,
                )
                LOGGER.info(f"Trn embeddings saved to {train_params.save_emb_dir}/X.trn.npy")
            if val_inst_embeddings is not None:
                smat_util.save_matrix(
                    os.path.join(train_params.save_emb_dir, "X.val.npy"),
                    val_inst_embeddings,
                )
                LOGGER.info(f"Val embeddings saved to {train_params.save_emb_dir}/X.val.npy")

        ranker = None
        if not train_params.only_encoder:
            # construct X_concat
            X_concat = TransformerMatcher.concat_features(
                prob.X_feat,
                inst_embeddings,
                normalize_emb=True,
            )
            del inst_embeddings
            LOGGER.info("Constructed instance feature matrix with shape={}".format(X_concat.shape))

            # 3. construct the refined HLT (kept as-is when fix_clustering is set)
            if not train_params.fix_clustering:
                clustering = Indexer.gen(
                    LabelEmbeddingFactory.pifa(prob.Y, X_concat),
                    train_params=train_params.refined_indexer_params,
                )
            LOGGER.info(
                "Hierarchical label tree for ranker: {}".format([cc.shape[0] for cc in clustering])
            )

            # the HLT could have changed depth
            train_params.ranker_params.hlm_args = (
                HierarchicalMLModel._duplicate_fields_with_name_ending_with_chain(
                    train_params.ranker_params.hlm_args,
                    HierarchicalMLModel.TrainParams,
                    len(clustering),
                )
            )
            pred_params.ranker_params.hlm_args = (
                HierarchicalMLModel._duplicate_fields_with_name_ending_with_chain(
                    pred_params.ranker_params.hlm_args,
                    HierarchicalMLModel.PredParams,
                    len(clustering),
                )
            )
            pred_params.ranker_params.override_with_kwargs(kwargs)

            # train the ranker
            LOGGER.info("Start training ranker...")

            ranker = XLinearModel.train(
                X_concat,
                prob.Y,
                C=clustering,
                train_params=train_params.ranker_params,
                pred_params=pred_params.ranker_params,
            )

        return cls(parent_model, ranker)
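
A training usage sketch for the classmethod above, with toy placeholder data. The import paths are assumptions based on the repository layout, MLProblemWithText is constructed the same way the method does internally, and save() is assumed as the counterpart of load():

# Hypothetical toy example; real training expects a much larger corpus.
import numpy as np
import scipy.sparse as smat
from pecos.xmc.xtransformer.model import XTransformer          # assumed import path
from pecos.xmc.xtransformer.module import MLProblemWithText    # assumed import path

X_text = ["document one", "document two", "document three", "document four"]
X_feat = smat.random(4, 64, density=0.1, format="csr", dtype="float32")  # (nr_insts, feature_dim)
Y = smat.csr_matrix(np.eye(4, dtype="float32"))                          # (nr_insts, nr_labels)

prob = MLProblemWithText(X_text, Y, X_feat=X_feat)
xtf = XTransformer.train(prob)        # default TrainParams/PredParams when none are given
xtf.save("./xtransformer_model")      # assumed save() counterpart to XTransformer.load()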