Exemple #1
0
    def createmodel(self, quantize=True):
        """Build a MobileNetV2 learner on ``self.data`` and store it in ``self.learn``.

        Unless disabled, the freshly built model is also prepared for
        quantization aware training (QAT).

        Parameters
        ----------
        quantize : bool, optional
            When True (the default), attach the default QAT qconfig to the
            first child of the model and run ``prepare_qat`` on the model.
            When False, the learner is left as created.
        """
        print("Creating model..")

        # Replace the library-level body builder with this object's own
        # implementation before the learner is constructed.
        vision.learner.create_body = self.create_custom_body

        databunch = self.data
        architecture = models.mobilenet_v2
        tracked_metrics = [
            error_rate,
            FBeta(beta=1),
            Precision(),
            Recall(),
            AUROC(),
        ]
        self.learn = cnn_learner(
            databunch,
            architecture,
            pretrained=True,
            metrics=tracked_metrics,
            split_on=custom_split,
            model_dir=self.model_dir,
        )

        if not quantize:
            return

        # Only model[0] receives the qconfig (presumably the body created by
        # create_custom_body -- confirm against that helper).
        learner = self.learn
        learner.model[0].qconfig = torch.quantization.default_qat_qconfig
        learner.model = torch.quantization.prepare_qat(learner.model, inplace=True)
Exemple #2
0
    def __get_objective_func_name(self):
        """Resolve the stopping metric to a fastai metric and its name.

        The name of ``self.stopping_metric`` is looked up in a table of
        supported metrics. If it is missing, a default is substituted
        ('mean_squared_error' when ``self.problem_type`` is REGRESSION,
        'log_loss' otherwise) and a warning is logged.

        Returns
        -------
        tuple
            ``(nn_metric, objective_func_name)`` where ``nn_metric`` is the
            table entry for the resolved name (``None`` for 'log_loss') and
            ``objective_func_name`` is the resolved metric name.
        """
        from fastai.metrics import root_mean_squared_error, mean_squared_error, mean_absolute_error, accuracy, FBeta, AUROC, Precision, Recall, r2_score

        # Not supported: median_absolute_error
        regression_metrics = {
            'root_mean_squared_error': root_mean_squared_error,
            'mean_squared_error': mean_squared_error,
            'mean_absolute_error': mean_absolute_error,
            'r2': r2_score,
        }

        # Not supported: pac_score
        classification_metrics = {
            'accuracy': accuracy,

            'f1': FBeta(beta=1),
            'f1_macro': FBeta(beta=1, average='macro'),
            'f1_micro': FBeta(beta=1, average='micro'),
            'f1_weighted': FBeta(beta=1, average='weighted'),  # this one has some issues

            'roc_auc': AUROC(),

            'precision': Precision(),
            'precision_macro': Precision(average='macro'),
            'precision_micro': Precision(average='micro'),
            'precision_weighted': Precision(average='weighted'),

            'recall': Recall(),
            'recall_macro': Recall(average='macro'),
            'recall_micro': Recall(average='micro'),
            'recall_weighted': Recall(average='weighted'),
            # log_loss maps to None: no metric object is attached for it.
            'log_loss': None,
        }

        metrics_map = {**regression_metrics, **classification_metrics}

        objective_func_name = self.stopping_metric.name
        if objective_func_name not in metrics_map:
            # Unsupported metrics fall back to a default for the problem type.
            objective_func_name = 'mean_squared_error' if self.problem_type == REGRESSION else 'log_loss'
            logger.warning(f'Metric {self.stopping_metric.name} is not supported by this model - using {objective_func_name} instead')

        return metrics_map.get(objective_func_name), objective_func_name