Example #1
0
    def create_model(self, shared_rep):
        """Construct a ``QueryDocumentPairwiseRankingModel_Deprecated`` for tests.

        Args:
            shared_rep: whether the query and the responses share the
                representation layer; copied into the representation
                config's ``shared_representations`` field.

        Returns:
            Whatever ``QueryDocumentPairwiseRankingModel_Deprecated.from_config``
            returns for the assembled model/feature configs and the data
            handler's metadata.
        """
        ranking_model_cls = QueryDocumentPairwiseRankingModel_Deprecated

        # Representation: optionally shared between query and responses.
        rep_cfg = QueryDocumentPairwiseRankingRep.Config()
        rep_cfg.shared_representations = shared_rep

        # Decoder with a single hidden layer of width 64.
        decoder_cfg = MLPDecoderQueryResponse.Config()
        decoder_cfg.hidden_dims = [64]

        cfg = ranking_model_cls.Config()
        cfg.representation = rep_cfg
        cfg.decoder = decoder_cfg
        cfg.output_layer = PairwiseRankingOutputLayer.Config()

        # Word features for query and both responses; only pos_response
        # overrides embed_dim, the others keep WordFeatConfig's default.
        features = ModelInputConfig()
        features.query = WordFeatConfig()
        features.pos_response = WordFeatConfig()
        features.pos_response.embed_dim = 64
        features.neg_response = WordFeatConfig()

        return ranking_model_cls.from_config(
            cfg, features, self.data_handler.metadata
        )
    def setUp(self):
        """Create ``self.data_handler`` before each test.

        The handler is a ``QueryDocumentPairwiseRankingDataHandler`` built
        from default handler and model-input configs. Its featurizer is a
        ``SimpleFeaturizer`` with an empty ``split_regex`` and
        ``convert_to_bytes`` enabled.
        """
        featurizer_cfg = SimpleFeaturizer.Config()
        featurizer_cfg.convert_to_bytes = True
        featurizer_cfg.split_regex = r""
        featurizer = SimpleFeaturizer.from_config(featurizer_cfg, FeatureConfig())

        handler_cls = QueryDocumentPairwiseRankingDataHandler
        # NOTE(review): the empty list is the third positional argument of
        # from_config -- presumably label/target configs; confirm.
        self.data_handler = handler_cls.from_config(
            handler_cls.Config(),
            ModelInputConfig(),
            [],
            featurizer=featurizer,
        )
Example #3
0
    def setup_data(self):
        """Build ``self.data_handler`` and initialise its metadata from a TSV.

        The handler uses a ``SimpleFeaturizer`` with an empty ``split_regex``
        and ``convert_to_bytes`` enabled, and has shuffling disabled. Metadata
        is initialised from ``query_document_pairwise_ranking_tiny.tsv``,
        whose resolved path is stored on ``self.file_name``.
        """
        featurizer_cfg = SimpleFeaturizer.Config()
        featurizer_cfg.convert_to_bytes = True
        featurizer_cfg.split_regex = r""

        handler_cls = QueryDocumentPairwiseRankingDataHandler
        handler = handler_cls.from_config(
            handler_cls.Config(),
            ModelInputConfig(),
            [],
            featurizer=SimpleFeaturizer.from_config(featurizer_cfg, FeatureConfig()),
        )
        self.data_handler = handler

        self.file_name = tests_module.test_file(
            "query_document_pairwise_ranking_tiny.tsv"
        )
        handler.shuffle = False

        # The same file is passed for all three path arguments
        # (presumably train/eval/test -- confirm against the data handler).
        data_path = self.file_name
        handler.init_metadata_from_path(data_path, data_path, data_path)