コード例 #1
0
    def test_handles_single_tower(self):
        """Check `_eval_spec` aggregation when only one tower is present.

        Builds a single tower spec from `create_constant_loss(5)` and
        `create_eval_metrics(0.2)`, aggregates it with
        `replicate_model_fn._eval_spec` while requesting '/device:GPU:0',
        then asserts that the 'accuracy' and 'auc' metric tensors are placed
        on '/device:CPU:0', that after one update they evaluate to roughly
        (4 - 1) / 4 and exactly 0, and that the aggregated loss is 5.
        """
        with self.test_session() as sess:
            # One (loss, metrics) pair per tower, fed into that tower's spec.
            losses = map(self.create_constant_loss, [5])
            metrics = map(self.create_eval_metrics, [0.2])
            specs = [self.create_estimator_spec(*pair)
                     for pair in zip(losses, metrics)]

            # NOTE(review): runs before `_eval_spec`, so this assumes the
            # aggregation below creates no new local variables -- confirm.
            sess.run(variables.local_variables_initializer())

            aggregated = replicate_model_fn._eval_spec(
                specs, aggregation_device='/device:GPU:0')
            metric_ops = aggregated.eval_metric_ops

            # Each entry unpacks as (metric value tensor, update op); the
            # update ops are run before the values are read below.
            accuracy_value, accuracy_update = metric_ops['accuracy']
            auc_value, auc_update = metric_ops['auc']

            # Metric tensors must land on the CPU even though a GPU was
            # requested as the aggregation device.
            for value_tensor in (accuracy_value, auc_value):
                self.assertEqual('/device:CPU:0', value_tensor.device)

            sess.run([accuracy_update, auc_update])
            accuracy = sess.run(accuracy_value)
            auc = sess.run(auc_value)

            # The expected value relies on true division (presumably
            # `from __future__ import division` under Python 2 -- confirm).
            self.assertNear((4 - 1) / 4, accuracy, 0.01)
            self.assertEqual(0, auc)
            self.assertEqual(5, sess.run(aggregated.loss))
コード例 #2
0
  def test_handles_single_tower(self):
    """Check `_eval_spec` aggregation when only one tower is present.

    Builds a single tower spec from `create_constant_loss(5)` and
    `create_eval_metrics(0.2)`, aggregates it with
    `replicate_model_fn._eval_spec` while requesting '/device:GPU:0',
    then asserts that the 'accuracy' and 'auc' metric tensors are placed
    on '/device:CPU:0', that after one update they evaluate to roughly
    (4 - 1) / 4 and exactly 0, and that the aggregated loss is 5.
    """
    with self.test_session() as sess:
      # One (loss, metrics) pair per tower, fed into that tower's spec.
      losses = map(self.create_constant_loss, [5])
      metrics = map(self.create_eval_metrics, [0.2])
      specs = [self.create_estimator_spec(*pair)
               for pair in zip(losses, metrics)]

      # NOTE(review): runs before `_eval_spec`, so this assumes the
      # aggregation below creates no new local variables -- confirm.
      sess.run(variables.local_variables_initializer())

      aggregated = replicate_model_fn._eval_spec(
          specs, aggregation_device='/device:GPU:0')
      metric_ops = aggregated.eval_metric_ops

      # Each entry unpacks as (metric value tensor, update op); the update
      # ops are run before the values are read below.
      accuracy_value, accuracy_update = metric_ops['accuracy']
      auc_value, auc_update = metric_ops['auc']

      # Metric tensors must land on the CPU even though a GPU was requested
      # as the aggregation device.
      for value_tensor in (accuracy_value, auc_value):
        self.assertEqual('/device:CPU:0', value_tensor.device)

      sess.run([accuracy_update, auc_update])
      accuracy = sess.run(accuracy_value)
      auc = sess.run(auc_value)

      # The expected value relies on true division (presumably
      # `from __future__ import division` under Python 2 -- confirm).
      self.assertNear((4 - 1) / 4, accuracy, 0.01)
      self.assertEqual(0, auc)
      self.assertEqual(5, sess.run(aggregated.loss))
コード例 #3
0
    def test_example(self):
        """Check `_eval_spec` aggregation across three towers.

        Builds tower specs from `create_constant_loss` over [2, 4, 6] and
        `create_eval_metrics` over [0, 0.2, 0.3], aggregates them with
        `replicate_model_fn._eval_spec` while requesting '/device:GPU:0',
        then asserts that the 'accuracy' and 'auc' metric tensors are placed
        on '/device:CPU:0', that after one update they evaluate to roughly
        (12 - 2) / 12 and exactly 0, and that the aggregated loss equals
        2 + 4 + 6.
        """
        with self.test_session() as sess:
            # One (loss, metrics) pair per tower, fed into that tower's spec.
            losses = map(self.create_constant_loss, [2, 4, 6])
            metrics = map(self.create_eval_metrics, [0, 0.2, 0.3])
            specs = [self.create_estimator_spec(*pair)
                     for pair in zip(losses, metrics)]

            # NOTE(review): runs before `_eval_spec`, so this assumes the
            # aggregation below creates no new local variables -- confirm.
            sess.run(variables.local_variables_initializer())

            aggregated = replicate_model_fn._eval_spec(
                specs, aggregation_device='/device:GPU:0')
            metric_ops = aggregated.eval_metric_ops

            # Each entry unpacks as (metric value tensor, update op); the
            # update ops are run before the values are read below.
            accuracy_value, accuracy_update = metric_ops['accuracy']
            auc_value, auc_update = metric_ops['auc']

            # Metric tensors must land on the CPU even though a GPU was
            # requested as the aggregation device.
            for value_tensor in (accuracy_value, auc_value):
                self.assertEqual('/device:CPU:0', value_tensor.device)

            sess.run([accuracy_update, auc_update])
            accuracy, auc = sess.run([accuracy_value, auc_value])

            # The expected value relies on true division (presumably
            # `from __future__ import division` under Python 2 -- confirm).
            self.assertNear((12 - 2) / 12, accuracy, 0.01)
            self.assertEqual(0, auc)
            self.assertEqual(2 + 4 + 6, sess.run(aggregated.loss))
コード例 #4
0
  def test_example(self):
    """Check `_eval_spec` aggregation across three towers.

    Builds tower specs from `create_constant_loss` over [2, 4, 6] and
    `create_eval_metrics` over [0, 0.2, 0.3], aggregates them with
    `replicate_model_fn._eval_spec` while requesting '/device:GPU:0',
    then asserts that the 'accuracy' and 'auc' metric tensors are placed
    on '/device:CPU:0', that after one update they evaluate to roughly
    (12 - 2) / 12 and exactly 0, and that the aggregated loss equals
    2 + 4 + 6.
    """
    with self.test_session() as sess:
      # One (loss, metrics) pair per tower, fed into that tower's spec.
      losses = map(self.create_constant_loss, [2, 4, 6])
      metrics = map(self.create_eval_metrics, [0, 0.2, 0.3])
      specs = [self.create_estimator_spec(*pair)
               for pair in zip(losses, metrics)]

      # NOTE(review): runs before `_eval_spec`, so this assumes the
      # aggregation below creates no new local variables -- confirm.
      sess.run(variables.local_variables_initializer())

      aggregated = replicate_model_fn._eval_spec(
          specs, aggregation_device='/device:GPU:0')
      metric_ops = aggregated.eval_metric_ops

      # Each entry unpacks as (metric value tensor, update op); the update
      # ops are run before the values are read below.
      accuracy_value, accuracy_update = metric_ops['accuracy']
      auc_value, auc_update = metric_ops['auc']

      # Metric tensors must land on the CPU even though a GPU was requested
      # as the aggregation device.
      for value_tensor in (accuracy_value, auc_value):
        self.assertEqual('/device:CPU:0', value_tensor.device)

      sess.run([accuracy_update, auc_update])
      accuracy, auc = sess.run([accuracy_value, auc_value])

      # The expected value relies on true division (presumably
      # `from __future__ import division` under Python 2 -- confirm).
      self.assertNear((12 - 2) / 12, accuracy, 0.01)
      self.assertEqual(0, auc)
      self.assertEqual(2 + 4 + 6, sess.run(aggregated.loss))