Ejemplo n.º 1
0
    def call(self, trial_split: TrialSplit, epoch_number: int, total_epochs: int, input_train: DataContainer,
             pred_train: DataContainer, input_val: DataContainer, pred_val: DataContainer, context: ExecutionContext,
             is_finished_and_fitted: bool):
        """
        Score the train and validation predictions with ``self.metric_function`` and record both scores
        on the trial split.

        :param trial_split: trial split on which the train and validation metric results are recorded
        :param epoch_number: current epoch number (not used by this callback)
        :param total_epochs: total number of epochs (not used by this callback)
        :param input_train: train input container (not used by this callback)
        :param pred_train: train container whose ``expected_outputs`` and ``data_inputs`` are fed to the metric
        :param input_val: validation input container (not used by this callback)
        :param pred_val: validation container whose ``expected_outputs`` and ``data_inputs`` are fed to the metric
        :param context: execution context; forwarded to the metric as ``context=`` only when
            ``self.pass_context_to_metric_function`` is truthy
        :param is_finished_and_fitted: completion flag (not used by this callback)
        :return: always ``False`` (NOTE(review): presumably means "do not stop training" -- confirm against the
            callback base class)
        """
        # Decide once whether the metric gets the execution context; an empty dict forwards nothing.
        metric_kwargs = {'context': context} if self.pass_context_to_metric_function else {}

        def score_of(predictions):
            # Metric is called as (expected outputs, data_inputs); for the pred_* containers the data_inputs
            # presumably hold the model predictions -- verify against the pipeline that builds them.
            return self.metric_function(predictions.expected_outputs, predictions.data_inputs, **metric_kwargs)

        train_score = score_of(pred_train)
        validation_score = score_of(pred_val)

        def record(add_metric_results, score):
            # Both the train and validation recorders take the same keyword arguments.
            add_metric_results(
                name=self.name,
                score=score,
                higher_score_is_better=self.higher_score_is_better,
                log_metric=self.log_metrics
            )

        record(trial_split.add_metric_results_train, train_score)
        record(trial_split.add_metric_results_validation, validation_score)

        return False
Ejemplo n.º 2
0
    def call(self, trial: TrialSplit, epoch_number: int, total_epochs: int,
             input_train: DataContainer, pred_train: DataContainer,
             input_val: DataContainer, pred_val: DataContainer,
             is_finished_and_fitted: bool):
        """
        Compute the metric on the train and validation predictions, record both scores on the trial,
        and optionally print them.

        :param trial: trial split on which the train and validation metric results are recorded
        :param epoch_number: current epoch number (not used by this callback)
        :param total_epochs: total number of epochs (not used by this callback)
        :param input_train: train input container (not used by this callback)
        :param pred_train: train container whose ``expected_outputs`` and ``data_inputs`` are fed to the metric
        :param input_val: validation input container (not used by this callback)
        :param pred_val: validation container whose ``expected_outputs`` and ``data_inputs`` are fed to the metric
        :param is_finished_and_fitted: completion flag (not used by this callback)
        :return: always ``False`` (NOTE(review): presumably means "do not stop training" -- confirm against the
            callback base class)
        """
        # Train is scored first, then validation; the metric is called as (expected outputs, data_inputs).
        train_score, validation_score = (
            self.metric_function(container.expected_outputs, container.data_inputs)
            for container in (pred_train, pred_val)
        )

        trial.add_metric_results_train(
            name=self.name, score=train_score,
            higher_score_is_better=self.higher_score_is_better)
        trial.add_metric_results_validation(
            name=self.name, score=validation_score,
            higher_score_is_better=self.higher_score_is_better)

        # Printing is opt-in via ``self.print_metrics``; otherwise we are done.
        if not self.print_metrics:
            return False

        self.print_function('{} train: {}'.format(self.name, train_score))
        self.print_function('{} validation: {}'.format(self.name, validation_score))
        return False