Exemplo n.º 1
0
 def y_pred_multilabel(self, threshold: float = 0.5) -> pd.DataFrame:
     """Convert stored predicted probabilities into a multilabel indicator matrix.

     Args:
       threshold: Probability cutoff used to decide whether a label is predicted.

     Returns:
       Indicator matrix representing the predicted labels for each observation
       using the given (optional) threshold.
     """
     probabilities = self.y_pred_proba
     return pred_prob_to_pred_multilabel(probabilities, threshold)
Exemplo n.º 2
0
 def y_pred_multilabel(self) -> pd.DataFrame:
     """Derive hard multilabel predictions from the stored probabilities.

     Returns:
       Indicator dataframe containing a 0 if each label wasn't predicted and 1 if
       it was for each observation.
     """
     indicator_df = pred_prob_to_pred_multilabel(self.y_pred_proba)
     # Cast the boolean indicators to integer 0/1 values.
     return indicator_df.astype("int")
Exemplo n.º 3
0
    def _train(self, train_input: gobbli.io.TrainInput,
               context: ContainerTaskContext) -> gobbli.io.TrainOutput:
        """Run supervised fastText training and score the validation set.

        Args:
          train_input: Train/validation texts and labels plus training options
            (checkpoint, epoch count, multilabel flag).
          context: Container task context supplying host- and container-side
            input/output directories.

        Returns:
          A :class:`gobbli.io.TrainOutput` carrying the training loss, the
          validation loss (negated validation accuracy -- see below), the
          validation accuracy, the label set, the multilabel flag, the host
          checkpoint path, and the combined train/predict console logs.
        """
        # Write the train and validation sets to input files on the host side
        # so the container can read them.
        self._write_input(
            train_input.X_train,
            train_input.y_train_multilabel,
            context.host_input_dir / FastText._TRAIN_INPUT_FILE,
        )
        self._write_input(
            train_input.X_valid,
            train_input.y_valid_multilabel,
            context.host_input_dir / FastText._VALID_INPUT_FILE,
        )

        # Path of the validation file as seen from inside the container.
        container_validation_input_path = (context.container_input_dir /
                                           FastText._VALID_INPUT_FILE)
        train_logs, train_loss = self._run_supervised(
            train_input.checkpoint,
            context.container_input_dir / FastText._TRAIN_INPUT_FILE,
            context.container_output_dir / FastText._CHECKPOINT_BASE,
            context,
            train_input.num_train_epochs,
            autotune_validation_file_path=container_validation_input_path,
        )

        # Checkpoint location on the host filesystem, for callers to reuse.
        host_checkpoint_path = context.host_output_dir / f"{FastText._CHECKPOINT_BASE}"

        labels = train_input.labels()

        # Calculate validation accuracy on our own, since the CLI only provides
        # precision/recall
        predict_logs, pred_prob_df = self._run_predict_prob(
            host_checkpoint_path, labels, container_validation_input_path,
            context)

        if train_input.multilabel:
            # Compare indicator matrices: thresholded predictions vs. gold
            # labels expanded to an indicator dataframe over the label set.
            pred_labels = pred_prob_to_pred_multilabel(pred_prob_df)
            gold_labels = multilabel_to_indicator_df(
                train_input.y_valid_multilabel, labels)
        else:
            # Multiclass: compare single predicted labels directly.
            pred_labels = pred_prob_to_pred_label(pred_prob_df)
            gold_labels = train_input.y_valid_multiclass

        valid_accuracy = accuracy_score(gold_labels, pred_labels)

        # Not ideal, but fastText doesn't provide a way to get validation loss;
        # Negate the validation accuracy instead
        valid_loss = -valid_accuracy

        return gobbli.io.TrainOutput(
            train_loss=train_loss,
            valid_loss=valid_loss,
            valid_accuracy=valid_accuracy,
            labels=labels,
            multilabel=train_input.multilabel,
            checkpoint=host_checkpoint_path,
            _console_output="\n".join((train_logs, predict_logs)),
        )
Exemplo n.º 4
0
    def y_pred_multilabel(self, threshold: float = 0.5) -> List[str]:
        """Return the labels predicted for this observation.

        Args:
          threshold: The predicted probability threshold for predictions

        Returns:
          The predicted labels for this observation (predicted probability greater than
          the given threshold)
        """
        probabilities = self.y_pred_proba
        return pred_prob_to_pred_multilabel(probabilities, threshold)