def get_compiled_multi_io_model_temporal(sample_weight_mode):
    model = get_multi_io_temporal_model()
    model.compile(optimizer=optimizer_v2.gradient_descent.SGD(0.1),
                  loss='mae',
                  metrics=[metrics.MeanAbsoluteError(name='mae')],
                  weighted_metrics=[metrics.MeanAbsoluteError(name='mae_2')],
                  sample_weight_mode=sample_weight_mode,
                  run_eagerly=testing_utils.should_run_eagerly(),
                  run_distributed=testing_utils.should_run_distributed())
    return model
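
For context, a minimal self-contained sketch of what sample_weight_mode='temporal' does, using a toy single-output sequence model instead of the multi-IO test model above (all shapes, layer choices, and data here are illustrative; the argument belongs to the older tf.keras compile API):

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.SimpleRNN(4, return_sequences=True, input_shape=(3, 2)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer='sgd', loss='mae', sample_weight_mode='temporal')

x = np.random.rand(8, 3, 2).astype('float32')
y = np.random.rand(8, 3, 1).astype('float32')
# With 'temporal' mode, sample_weight is 2-D: one weight per sample per timestep.
w = np.ones((8, 3), dtype='float32')
model.fit(x, y, sample_weight=w, epochs=1, verbose=0)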
  def metrics(self, regularization_losses=None):
    """Creates metrics. See `base_head.Head` for details."""
    keys = metric_keys.MetricKeys
    with ops.name_scope('metrics', values=(regularization_losses,)):
      # Mean metric.
      eval_metrics = {}
      eval_metrics[self._loss_mean_key] = metrics.Mean(name=keys.LOSS_MEAN)
      eval_metrics[self._prediction_mean_key] = metrics.Mean(
          name=keys.PREDICTION_MEAN)
      eval_metrics[self._label_mean_key] = metrics.Mean(name=keys.LABEL_MEAN)
      if regularization_losses is not None:
        eval_metrics[self._loss_regularization_key] = metrics.Mean(
            name=keys.LOSS_REGULARIZATION)
      if (self._model_hparams.da_tlen > 0 and
          not self._model_hparams.event_relation):
        eval_metrics[self._auc_roc_24] = metrics.AUC(name=AUC_ROC % '24')
        eval_metrics[self._auc_roc_48] = metrics.AUC(name=AUC_ROC % '48')
        eval_metrics[self._auc_pr] = metrics.AUC(
            curve='PR', name=AUC_PR % 'avg')
        eval_metrics[self._auc_roc] = metrics.AUC(name=AUC_ROC % 'avg')
        eval_metrics[self._mean_abs_error] = metrics.MeanAbsoluteError(
            name=MEAN_ABS_ERROR % 'avg')

        for i in range(int(self._model_hparams.da_tlen / SLOT_TO_WINDOW) + 1):
          eval_metrics[self._probablity_within_window_list[i]] = metrics.Mean(
              name=PROBABILITY_AT_WINDOW % i)

    return eval_metrics
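
The AUC_ROC, AUC_PR, MEAN_ABS_ERROR, PROBABILITY_AT_WINDOW and SLOT_TO_WINDOW names above are module-level constants that the snippet does not show. A hypothetical set of definitions, only to make the naming convention concrete (the real module defines its own values):

# Hypothetical name templates and constants assumed by the snippet above.
AUC_ROC = 'auc_roc_%s'
AUC_PR = 'auc_pr_%s'
MEAN_ABS_ERROR = 'mean_abs_error_%s'
PROBABILITY_AT_WINDOW = 'probability_at_window_%d'
SLOT_TO_WINDOW = 24  # hypothetical number of time slots per window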
Example #3
    def test_config(self):
        mae_obj = metrics.MeanAbsoluteError(name='my_mae', dtype=dtypes.int32)
        self.assertEqual(mae_obj.name, 'my_mae')
        self.assertEqual(mae_obj._dtype, dtypes.int32)

        # Check save and restore config
        mae_obj2 = metrics.MeanAbsoluteError.from_config(mae_obj.get_config())
        self.assertEqual(mae_obj2.name, 'my_mae')
        self.assertEqual(mae_obj2._dtype, dtypes.int32)
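
The same round trip works through the public tf.keras API; a quick standalone check (a sketch, not part of the test above):

import tensorflow as tf

mae = tf.keras.metrics.MeanAbsoluteError(name='my_mae', dtype='int32')
config = mae.get_config()  # carries the metric's name and dtype
mae2 = tf.keras.metrics.MeanAbsoluteError.from_config(config)
assert mae2.name == 'my_mae'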
Example #4
    def test_weighted(self):
        mae_obj = metrics.MeanAbsoluteError()
        self.evaluate(variables.variables_initializer(mae_obj.variables))
        y_true = constant_op.constant(((0, 1, 0, 1, 0), (0, 0, 1, 1, 1),
                                       (1, 1, 1, 1, 0), (0, 0, 0, 0, 1)))
        y_pred = constant_op.constant(((0, 0, 1, 1, 0), (1, 1, 1, 1, 1),
                                       (0, 1, 0, 1, 0), (1, 1, 1, 1, 1)))
        sample_weight = constant_op.constant((1., 1.5, 2., 2.5))
        result = mae_obj(y_true, y_pred, sample_weight=sample_weight)
        self.assertAllClose(0.54285, self.evaluate(result), atol=1e-5)
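
Where 0.54285 comes from: the four rows have 2, 2, 2 and 4 mismatched entries out of 5, giving per-row mean absolute errors of 0.4, 0.4, 0.4 and 0.8, so the weighted mean is (0.4*1 + 0.4*1.5 + 0.4*2 + 0.8*2.5) / (1 + 1.5 + 2 + 2.5) = 3.8 / 7 ≈ 0.54286.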
Example #5
    def test_unweighted(self):
        mae_obj = metrics.MeanAbsoluteError()
        self.evaluate(variables.variables_initializer(mae_obj.variables))
        y_true = constant_op.constant(((0, 1, 0, 1, 0), (0, 0, 1, 1, 1),
                                       (1, 1, 1, 1, 0), (0, 0, 0, 0, 1)))
        y_pred = constant_op.constant(((0, 0, 1, 1, 0), (1, 1, 1, 1, 1),
                                       (0, 1, 0, 1, 0), (1, 1, 1, 1, 1)))

        update_op = mae_obj.update_state(y_true, y_pred)
        self.evaluate(update_op)
        result = mae_obj.result()
        self.assertAllClose(0.5, result, atol=1e-5)
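
Here the expected value is easy to check by hand: the rows contribute 2, 2, 2 and 4 absolute-error units over 20 elements in total, so the unweighted mean is 10 / 20 = 0.5.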
    def metrics(self, regularization_losses=None):
        """Creates metrics. See `base_head.Head` for details."""
        keys = metric_keys.MetricKeys
        with ops.name_scope('metrics', values=(regularization_losses,)):
            # Mean metric.
            eval_metrics = {}
            eval_metrics[self._loss_mean_key] = metrics.Mean(
                name=keys.LOSS_MEAN)
            eval_metrics[self._prediction_mean_key] = metrics.Mean(
                name=keys.PREDICTION_MEAN)
            eval_metrics[self._label_mean_key] = metrics.Mean(
                name=keys.LABEL_MEAN)
            if regularization_losses is not None:
                eval_metrics[self._loss_regularization_key] = metrics.Mean(
                    name=keys.LOSS_REGULARIZATION)
            eval_metrics[self._mean_abs_error] = metrics.MeanAbsoluteError(
                name=MEAN_ABS_ERROR, dtype=tf.float32)

        return eval_metrics
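
A minimal sketch of the update/result cycle that an eval_metrics dict like this relies on, written against the public tf.keras metric API (standalone and illustrative; the head's actual update logic lives elsewhere):

import tensorflow as tf

mae = tf.keras.metrics.MeanAbsoluteError(name='mean_abs_error', dtype=tf.float32)
loss_mean = tf.keras.metrics.Mean(name='loss_mean')

# During evaluation, each batch updates the accumulated metric state...
mae.update_state(y_true=[[0., 1.]], y_pred=[[0.5, 0.5]])
loss_mean.update_state(0.25)
# ...and result() folds that state into a scalar.
print(float(mae.result()))        # 0.5
print(float(loss_mean.result()))  # 0.25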