def testGetOutputAlternatives(self):
  """Checks `_get_output_alternatives` across prediction/problem types.

  Each record is (prediction type, problem type, prediction dict, expected
  result). In the SINGLE_VALUE records the expected value is a dict with the
  single key 'dynamic_rnn_output', holding the problem type and the input
  dict minus its state entries. The MULTIPLE_VALUE record expects None.
  """
  pk = prediction_key.PredictionKey
  state_name = dynamic_rnn_estimator._get_state_name
  single_value = rnn_common.PredictionType.SINGLE_VALUE

  records = [
      # Single-value classification: the one state entry is absent from the
      # expected alternative.
      (single_value,
       constants.ProblemType.CLASSIFICATION,
       {pk.CLASSES: True, pk.PROBABILITIES: True, state_name(0): True},
       {'dynamic_rnn_output': (
           constants.ProblemType.CLASSIFICATION,
           {pk.CLASSES: True, pk.PROBABILITIES: True})}),
      # Single-value regression with two state entries, neither expected.
      (single_value,
       constants.ProblemType.LINEAR_REGRESSION,
       {pk.SCORES: True, state_name(0): True, state_name(1): True},
       {'dynamic_rnn_output': (
           constants.ProblemType.LINEAR_REGRESSION,
           {pk.SCORES: True})}),
      # Multiple-value predictions are expected to yield None.
      (rnn_common.PredictionType.MULTIPLE_VALUE,
       constants.ProblemType.CLASSIFICATION,
       {pk.CLASSES: True, pk.PROBABILITIES: True, state_name(0): True},
       None),
  ]

  for record in records:
    # First three fields are the positional arguments; the last is expected.
    call_args, expected = record[:3], record[3]
    self.assertEqual(
        expected, dynamic_rnn_estimator._get_output_alternatives(*call_args))
예제 #2
0
    def testGetOutputAlternatives(self):
        """Checks `_get_output_alternatives` for each prediction type.

        `inputs[i]` holds the (prediction type, problem type, prediction
        dict) arguments and `expectations[i]` the matching expected result.
        SINGLE_VALUE cases expect a 'dynamic_rnn_output' entry whose dict
        drops the state keys present in the input. The MULTIPLE_VALUE case
        expects None.
        """
        keys = prediction_key.PredictionKey
        problem = constants.ProblemType
        single = dynamic_rnn_estimator.PredictionType.SINGLE_VALUE
        multiple = dynamic_rnn_estimator.PredictionType.MULTIPLE_VALUE
        state = dynamic_rnn_estimator._get_state_name

        # Each input dict is written out separately so no case shares a dict
        # object with another.
        inputs = [
            (single, problem.CLASSIFICATION,
             {keys.CLASSES: True, keys.PROBABILITIES: True, state(0): True}),
            (single, problem.LINEAR_REGRESSION,
             {keys.SCORES: True, state(0): True, state(1): True}),
            (multiple, problem.CLASSIFICATION,
             {keys.CLASSES: True, keys.PROBABILITIES: True, state(0): True}),
        ]
        expectations = [
            {'dynamic_rnn_output': (
                problem.CLASSIFICATION,
                {keys.CLASSES: True, keys.PROBABILITIES: True})},
            {'dynamic_rnn_output': (
                problem.LINEAR_REGRESSION,
                {keys.SCORES: True})},
            None,
        ]

        for args, expected in zip(inputs, expectations):
            actual = dynamic_rnn_estimator._get_output_alternatives(*args)
            self.assertEqual(expected, actual)