Exemplo n.º 1
0
 def test_fixed_effect_lr_lbfgs_model_creation(self):
     """A fixed-effect, logistic-regression training stage must produce a FixedEffectLRModelLBFGS."""
     training_params = setup_fake_base_training_params(
         training_stage=constants.FIXED_EFFECT,
         model_type=constants.LOGISTIC_REGRESSION,
     )
     model = ModelFactory.get_model(base_training_params=training_params,
                                    raw_model_params=self.model_params)
     # The factory should dispatch on (stage, model type) to the L-BFGS fixed-effect LR model.
     self.assertIsInstance(model, FixedEffectLRModelLBFGS)
Exemplo n.º 2
0
 def test_random_effect_custom_logistic_regression_model_creation(self):
     """A random-effect, logistic-regression training stage must produce a RandomEffectLRLBFGSModel."""
     training_params = setup_fake_base_training_params(
         training_stage=constants.RANDOM_EFFECT,
         model_type=constants.LOGISTIC_REGRESSION,
     )
     model = ModelFactory.get_model(base_training_params=training_params,
                                    raw_model_params=self.model_params)
     # The factory should dispatch on (stage, model type) to the L-BFGS random-effect LR model.
     self.assertIsInstance(model, RandomEffectLRLBFGSModel)
Exemplo n.º 3
0
    def get_driver(base_training_params, raw_model_params):
        """
        Create driver and associated dependencies, based on type. Only linear and DeText models are supported
        for now
        :param base_training_params:      Parsed base training parameters common to all models. This could include
        path to training data, validation data, metadata file path, learning rate etc. Its ``stage`` attribute
        selects the driver class from ``DriverFactory.drivers``.
        :param raw_model_params:          Raw model parameters, representing model-specific requirements. For example, a
        CNN might expose filter_size as a parameter, a text-based model might expose the size of its word embedding
        matrix as a parameter
        :return:            Fixed or Random effect driver
        :raises KeyError:   if ``base_training_params.stage`` has no registered driver
        """
        stage = base_training_params.stage
        try:
            driver = DriverFactory.drivers[stage]
        except KeyError:
            # Re-raise the same exception type (so existing callers still catch it), but with a message
            # that names the bad stage and the supported ones instead of a bare key.
            raise KeyError(
                f"Unknown training stage {stage!r}; supported stages: {list(DriverFactory.drivers)}"
            ) from None
        model = ModelFactory.get_model(base_training_params, raw_model_params)
        # Lazy %-formatting: the repr of model/driver is only built if INFO is enabled.
        logger.info("Instantiating model %s and driver %s", model, driver)
        return driver(base_training_params=base_training_params, model=model)
Exemplo n.º 4
0
    def get_driver(base_training_params, raw_model_params):
        """
        Create driver and associated dependencies, based on type. Only linear, estimator-based models supported
        for now
        :param base_training_params:      Parsed base training parameters common to all models. This could include
        path to training data, validation data, metadata file path, learning rate etc. The value stored under
        ``constants.STAGE`` selects the driver (fixed effect or random effect).
        :param raw_model_params:          Raw model parameters, representing model-specific requirements. For example, a
        CNN might expose filter_size as a parameter, a text-based model might expose the size of its word embedding
        matrix as a parameter
        :return:            Fixed or Random effect driver
        :raises ValueError: if the training stage is neither ``constants.FIXED_EFFECT`` nor
                            ``constants.RANDOM_EFFECT``
        """
        driver_type = base_training_params[constants.STAGE]
        # Resolve the driver class before building the model, so an unknown stage fails fast with a clear
        # message instead of after (or inside) model construction.
        if driver_type == constants.FIXED_EFFECT:
            driver_cls, effect_name = FixedEffectDriver, "fixed"
        elif driver_type == constants.RANDOM_EFFECT:
            driver_cls, effect_name = RandomEffectDriver, "random"
        else:
            # ValueError subclasses Exception, so callers catching Exception keep working.
            raise ValueError(f"Unknown training stage: {driver_type!r}")
        model = ModelFactory.get_model(base_training_params, raw_model_params)
        logger.info(f"Instantiating {effect_name} effect model and driver")
        return driver_cls(base_training_params=base_training_params, model=model)