コード例 #1
0
    def test_reduce_is_idempotent(self):
        """Running the metric reduction many times matches running it once.

        Metrics for three towers are created and initialized, then the
        reduction is built and run 20 times. The metric variables must hold
        the same values the single-reduction test expects: summed totals in
        the first tower's variables and zeros in the other towers' variables.
        """
        with self.test_session() as session:
            for tower_id in range(3):
                self.create_tower_metrics(tower_id)

            session.run(
                variables.variables_initializer(
                    ops_lib.get_collection(
                        ops_lib.GraphKeys.METRIC_VARIABLES)))

            # Repeated reductions must not keep accumulating into the totals.
            for _ in range(20):
                session.run(
                    replicate_model_fn._reduce_metric_variables(
                        number_of_towers=3))

            # 1st tower = 1.3, 2.3,  [3.3, 3.5, 3.7]
            # 2nd tower = 2.6, 4.6,  [6.6, 7.0, 7.4]
            # 3rd tower = 3.9, 6.9,  [9.9, 10.5, 11.1]
            # Reduced =   7.8, 13.8, [19.8, 21.0, 22.2]
            local_metrics = session.run(
                ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))

            # First tower's variables carry the reduced totals.
            self.assertNear(7.8, local_metrics[0], 0.01)
            self.assertNear(13.8, local_metrics[1], 0.01)
            self.assertAllClose([19.8, 21.0, 22.2], local_metrics[2], 0.01)
            # Second and third towers' variables are zeroed.
            self.assertNear(0.0, local_metrics[3], 0.01)
            self.assertNear(0.0, local_metrics[4], 0.01)
            self.assertAllClose([0.0, 0.0, 0.0], local_metrics[5], 0.01)
            self.assertNear(0.0, local_metrics[6], 0.01)
            self.assertNear(0.0, local_metrics[7], 0.01)
            self.assertAllClose([0.0, 0.0, 0.0], local_metrics[8], 0.01)
コード例 #2
0
  def test_reduce_is_idempotent(self):
    """Running the metric reduction many times matches running it once.

    Metrics for three towers are created and initialized, then the reduction
    is built and run 20 times. The metric variables must hold the same values
    the single-reduction test expects: summed totals in the first tower's
    variables and zeros in the other towers' variables.
    """
    with self.test_session() as session:
      for tower_id in range(3):
        self.create_tower_metrics(tower_id)

      session.run(
          variables.variables_initializer(
              ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)))

      # Repeated reductions must not keep accumulating into the totals.
      for _ in range(20):
        session.run(
            replicate_model_fn._reduce_metric_variables(number_of_towers=3))

      # 1st tower = 1.3, 2.3,  [3.3, 3.5, 3.7]
      # 2nd tower = 2.6, 4.6,  [6.6, 7.0, 7.4]
      # 3rd tower = 3.9, 6.9,  [9.9, 10.5, 11.1]
      # Reduced =   7.8, 13.8, [19.8, 21.0, 22.2]
      local_metrics = session.run(
          ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))

      # First tower's variables carry the reduced totals.
      self.assertNear(7.8, local_metrics[0], 0.01)
      self.assertNear(13.8, local_metrics[1], 0.01)
      self.assertAllClose([19.8, 21.0, 22.2], local_metrics[2], 0.01)
      # Second and third towers' variables are zeroed.
      self.assertNear(0.0, local_metrics[3], 0.01)
      self.assertNear(0.0, local_metrics[4], 0.01)
      self.assertAllClose([0.0, 0.0, 0.0], local_metrics[5], 0.01)
      self.assertNear(0.0, local_metrics[6], 0.01)
      self.assertNear(0.0, local_metrics[7], 0.01)
      self.assertAllClose([0.0, 0.0, 0.0], local_metrics[8], 0.01)
コード例 #3
0
    def test_example(self):
        """A single reduction over three towers sums into the first tower.

        After the reduction, the first tower's metric variables hold the
        per-variable sums across towers and the remaining towers' metric
        variables are zero.
        """
        with self.test_session() as session:
            for tower_id in range(3):
                self.create_tower_metrics(tower_id)

            session.run(
                variables.variables_initializer(
                    ops_lib.get_collection(
                        ops_lib.GraphKeys.METRIC_VARIABLES)))

            session.run(
                replicate_model_fn._reduce_metric_variables(
                    number_of_towers=3))

            # 1st tower = 1.3, 2.3,  [3.3, 3.5, 3.7]
            # 2nd tower = 2.6, 4.6,  [6.6, 7.0, 7.4]
            # 3rd tower = 3.9, 6.9,  [9.9, 10.5, 11.1]
            # Reduced =   7.8, 13.8, [19.8, 21.0, 22.2]
            # Towers are accumulated in the first tower.
            local_metrics = session.run(
                ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))

            # First tower's variables carry the reduced totals.
            self.assertNear(7.8, local_metrics[0], 0.01)
            self.assertNear(13.8, local_metrics[1], 0.01)
            self.assertAllClose([19.8, 21.0, 22.2], local_metrics[2], 0.01)
            # Second and third towers' variables are zeroed.
            self.assertNear(0.0, local_metrics[3], 0.01)
            self.assertNear(0.0, local_metrics[4], 0.01)
            self.assertAllClose([0.0, 0.0, 0.0], local_metrics[5], 0.01)
            self.assertNear(0.0, local_metrics[6], 0.01)
            self.assertNear(0.0, local_metrics[7], 0.01)
            self.assertAllClose([0.0, 0.0, 0.0], local_metrics[8], 0.01)
コード例 #4
0
  def test_example(self):
    """A single reduction over three towers sums into the first tower.

    After the reduction, the first tower's metric variables hold the
    per-variable sums across towers and the remaining towers' metric
    variables are zero.
    """
    with self.test_session() as session:
      for tower_id in range(3):
        self.create_tower_metrics(tower_id)

      session.run(
          variables.variables_initializer(
              ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)))

      session.run(
          replicate_model_fn._reduce_metric_variables(number_of_towers=3))

      # 1st tower = 1.3, 2.3,  [3.3, 3.5, 3.7]
      # 2nd tower = 2.6, 4.6,  [6.6, 7.0, 7.4]
      # 3rd tower = 3.9, 6.9,  [9.9, 10.5, 11.1]
      # Reduced =   7.8, 13.8, [19.8, 21.0, 22.2]
      # Towers are accumulated in the first tower.
      local_metrics = session.run(
          ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))

      # First tower's variables carry the reduced totals.
      self.assertNear(7.8, local_metrics[0], 0.01)
      self.assertNear(13.8, local_metrics[1], 0.01)
      self.assertAllClose([19.8, 21.0, 22.2], local_metrics[2], 0.01)
      # Second and third towers' variables are zeroed.
      self.assertNear(0.0, local_metrics[3], 0.01)
      self.assertNear(0.0, local_metrics[4], 0.01)
      self.assertAllClose([0.0, 0.0, 0.0], local_metrics[5], 0.01)
      self.assertNear(0.0, local_metrics[6], 0.01)
      self.assertNear(0.0, local_metrics[7], 0.01)
      self.assertAllClose([0.0, 0.0, 0.0], local_metrics[8], 0.01)
コード例 #5
0
  def test_doesnt_accept_uneven_number_of_variables(self):
    """Building the reduction raises ValueError for an uneven variable count.

    Three towers' metrics are created plus one extra 'oddball' metric
    variable, and reducing with number_of_towers=3 must raise ValueError.
    """
    with self.test_session() as session:
      for tower_id in range(3):
        self.create_tower_metrics(tower_id)
      # One extra metric variable beyond the three towers' metrics.
      self.create_metric_variable(-1.0, 'oddball')

      session.run(
          variables.variables_initializer(
              ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)))

      # Any ValueError is accepted; the error message is not checked.
      with self.assertRaises(ValueError):
        session.run(
            replicate_model_fn._reduce_metric_variables(number_of_towers=3))
コード例 #6
0
  def test_doesnt_accept_uneven_number_of_variables(self):
    """Building the reduction raises ValueError for an uneven variable count.

    Three towers' metrics are created plus one extra 'oddball' metric
    variable, and reducing with number_of_towers=3 must raise ValueError.
    """
    with self.test_session() as session:
      for tower_id in range(3):
        self.create_tower_metrics(tower_id)
      # One extra metric variable beyond the three towers' metrics.
      self.create_metric_variable(-1.0, 'oddball')

      session.run(
          variables.variables_initializer(
              ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)))

      # Any ValueError is accepted; the error message is not checked.
      with self.assertRaises(ValueError):
        session.run(
            replicate_model_fn._reduce_metric_variables(number_of_towers=3))
コード例 #7
0
  def test_handles_single_tower(self):
    """Reducing with one tower leaves that tower's metric values in place."""
    with self.test_session() as sess:
      collection_key = ops_lib.GraphKeys.METRIC_VARIABLES
      self.create_tower_metrics(0)

      init_op = variables.variables_initializer(
          ops_lib.get_collection(collection_key))
      sess.run(init_op)

      reduce_op = replicate_model_fn._reduce_metric_variables(
          number_of_towers=1)
      sess.run(reduce_op)

      # Query the collection again after the reduction has run.
      values = sess.run(ops_lib.get_collection(collection_key))

      self.assertNear(1.3, values[0], 0.01)
      self.assertNear(2.3, values[1], 0.01)
      self.assertAllClose([3.3, 3.5, 3.7], values[2], 0.01)
コード例 #8
0
  def test_handles_single_tower(self):
    """A single-tower reduction must not alter the tower's metric values."""

    def metric_vars():
      # Fresh snapshot of the metric-variable collection on each call.
      return ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)

    with self.test_session() as sess:
      self.create_tower_metrics(0)
      sess.run(variables.variables_initializer(metric_vars()))

      sess.run(
          replicate_model_fn._reduce_metric_variables(number_of_towers=1))

      reduced = sess.run(metric_vars())

      self.assertNear(1.3, reduced[0], 0.01)
      self.assertNear(2.3, reduced[1], 0.01)
      self.assertAllClose([3.3, 3.5, 3.7], reduced[2], 0.01)