예제 #1
0
파일: pipeline.py 프로젝트: fagan2888/ml4ir
    def get_relevance_model(self, feature_layer_keys_to_fns=None) -> RelevanceModel:
        """
        Creates RelevanceModel suited for classification use-case.

        Parameters
        ----------
        feature_layer_keys_to_fns : dict, optional
            Mapping forwarded as-is to the UnivariateInteractionModel
            (presumably custom feature function names mapped to functions;
            confirm against the InteractionModel). When omitted, a new empty
            dict is created for each call.

        Returns
        -------
        RelevanceModel
            Model assembled from the interaction model, loss, scorer,
            metrics and optimizer built in this method.

        NOTE: Override this method to create custom loss, scorer, model objects.
        """
        # Use a fresh dict per call. A `{}` default in the signature is a single
        # object shared by every call, and it is handed to the interaction model
        # below, so any mutation there would leak into later calls.
        if feature_layer_keys_to_fns is None:
            feature_layer_keys_to_fns = {}

        # Define interaction model
        interaction_model: InteractionModel = UnivariateInteractionModel(
            feature_config=self.feature_config,
            feature_layer_keys_to_fns=feature_layer_keys_to_fns,
            tfrecord_type=self.tfrecord_type,
            file_io=self.file_io,
        )

        # Define loss object from loss key
        loss: RelevanceLossBase = categorical_cross_entropy.get_loss(
            loss_key=self.loss_key)

        # Define scorer; the model architecture is read from the model config file
        scorer: ScorerBase = RelevanceScorer.from_model_config_file(
            model_config_file=self.model_config_file,
            interaction_model=interaction_model,
            loss=loss,
            output_name=self.args.output_name,
            logger=self.logger,
            file_io=self.file_io,
        )

        # Define metrics objects from metrics keys
        metrics: List[Union[Type[Metric], str]] = [
            metric_factory.get_metric(metric_key=metric_key)
            for metric_key in self.metrics_keys
        ]

        # Define optimizer
        optimizer: Optimizer = get_optimizer(
            optimizer_key=self.optimizer_key,
            learning_rate=self.args.learning_rate,
            learning_rate_decay=self.args.learning_rate_decay,
            learning_rate_decay_steps=self.args.learning_rate_decay_steps,
            gradient_clip_value=self.args.gradient_clip_value,
        )

        # Combine the above to define a RelevanceModel
        # NOTE(review): the model is given self.local_io while the interaction
        # model and scorer use self.file_io -- presumably intentional; confirm.
        relevance_model: RelevanceModel = RelevanceModel(
            feature_config=self.feature_config,
            scorer=scorer,
            metrics=metrics,
            optimizer=optimizer,
            tfrecord_type=self.tfrecord_type,
            model_file=self.args.model_file,
            compile_keras_model=self.args.compile_keras_model,
            output_name=self.args.output_name,
            file_io=self.local_io,
            logger=self.logger,
        )
        return relevance_model
예제 #2
0
    def get_relevance_model(self, feature_layer_keys_to_fns=None) -> RelevanceModel:
        """
        Creates a RelevanceModel that can be used for training and evaluating
        a classification model

        Parameters
        ----------
        feature_layer_keys_to_fns : dict of (str, function), optional
            dictionary of function names mapped to tensorflow compatible
            function definitions that can now be used in the InteractionModel
            as a feature function to transform input features.
            When omitted, a new empty dict is created for each call.

        Returns
        -------
        `RelevanceModel`
            RelevanceModel that can be used for training and evaluating
            a classification model

        Notes
        -----
        Override this method to create custom loss, scorer, model objects
        """
        # Use a fresh dict per call. A `{}` default in the signature is a single
        # object shared by every call, and it is handed to the interaction model
        # below, so any mutation there would leak into later calls.
        if feature_layer_keys_to_fns is None:
            feature_layer_keys_to_fns = {}

        # Define interaction model
        interaction_model: InteractionModel = UnivariateInteractionModel(
            feature_config=self.feature_config,
            feature_layer_keys_to_fns=feature_layer_keys_to_fns,
            tfrecord_type=self.tfrecord_type,
            file_io=self.file_io,
        )

        # Define loss object from loss key
        loss: RelevanceLossBase = categorical_cross_entropy.get_loss(loss_key=self.loss_key)

        # Define scorer
        scorer: ScorerBase = RelevanceScorer(
            feature_config=self.feature_config,
            model_config=self.model_config,
            interaction_model=interaction_model,
            loss=loss,
            output_name=self.args.output_name,
            logger=self.logger,
            file_io=self.file_io,
        )

        # Define metrics objects from metrics keys
        metrics: List[Union[Type[Metric], str]] = [
            metrics_factory.get_metric(metric_key=metric_key) for metric_key in self.metrics_keys
        ]

        # Define optimizer; its settings are taken from the model config
        optimizer: Optimizer = get_optimizer(model_config=self.model_config)

        # Combine the above to define a RelevanceModel
        # initialize_layers_dict / freeze_layers_list are parsed from self.args
        # with ast.literal_eval, which only accepts Python literals (expected to
        # be a dict and a list respectively, per their names).
        # NOTE(review): the model is given self.local_io while the interaction
        # model and scorer use self.file_io -- presumably intentional; confirm.
        relevance_model: RelevanceModel = ClassificationModel(
            feature_config=self.feature_config,
            scorer=scorer,
            metrics=metrics,
            optimizer=optimizer,
            tfrecord_type=self.tfrecord_type,
            model_file=self.args.model_file,
            initialize_layers_dict=ast.literal_eval(self.args.initialize_layers_dict),
            freeze_layers_list=ast.literal_eval(self.args.freeze_layers_list),
            compile_keras_model=self.args.compile_keras_model,
            output_name=self.args.output_name,
            file_io=self.local_io,
            logger=self.logger,
        )
        return relevance_model