コード例 #1
0
    def test_no_name_for_indexslices(self):
        """Passing `name` alongside IndexedSlices inputs must raise ValueError.

        Two IndexedSlices are summed with an explicit op name; the test pins
        that `_compute_sum_on_device` rejects that combination.
        """
        a = ops_lib.IndexedSlices(constant_op.constant([1.0, 2.0]), [0, 1],
                                  dense_shape=constant_op.constant([2]))
        b = ops_lib.IndexedSlices(constant_op.constant([3.0, 4.0]), [0, 1])

        # Only the exception type is checked. assertRaises avoids the
        # deprecated assertRaisesRegexp alias (removed in Python 3.12), and an
        # empty regex would match any message anyway.
        with self.assertRaises(ValueError):
            replicate_model_fn._compute_sum_on_device(
                [a, b], device='/device:GPU:0', name='cant_name_indexslices')
コード例 #2
0
    def test_vectors(self):
        """Summing a list of Python floats yields a named, placed tensor.

        The result must live on the requested device, carry the requested op
        name, and evaluate to the total of the inputs.
        """
        with self.test_session() as sess:
            addends = [1.0, 2.0, 3.0, 4.0]
            result = replicate_model_fn._compute_sum_on_device(
                addends, device='/device:GPU:0', name='test_sum')

            placed_on = result.device
            self.assertEqual('/device:GPU:0', placed_on)
            self.assertEqual('test_sum', result.op.name)

            evaluated = sess.run(result)
            self.assertEqual(10.0, evaluated)
コード例 #3
0
  def test_tensors(self):
    """Summing nested Python lists adds them element-wise on the device.

    Each inner list is one addend, so the expected total is [1+3, 2+4]. The
    result must also be placed on, and named as, requested.
    """
    with self.test_session() as sess:
      pieces = [[1.0, 2.0], [3.0, 4.0]]
      summed = replicate_model_fn._compute_sum_on_device(
          pieces, device='/device:GPU:0', name='test_sum')

      placed_on = summed.device
      self.assertEqual('/device:GPU:0', placed_on)
      self.assertEqual('test_sum', summed.op.name)

      evaluated = sess.run(summed)
      self.assertAllEqual([4.0, 6.0], evaluated)
コード例 #4
0
  def test_no_name_for_indexslices(self):
    """Passing `name` alongside IndexedSlices inputs must raise ValueError.

    Two IndexedSlices are summed with an explicit op name; the test pins that
    `_compute_sum_on_device` rejects that combination.
    """
    a = ops_lib.IndexedSlices(
        constant_op.constant([1.0, 2.0]), [0, 1],
        dense_shape=constant_op.constant([2]))
    b = ops_lib.IndexedSlices(constant_op.constant([3.0, 4.0]), [0, 1])

    # Only the exception type is checked. assertRaises avoids the deprecated
    # assertRaisesRegexp alias (removed in Python 3.12), and an empty regex
    # would match any message anyway.
    with self.assertRaises(ValueError):
      replicate_model_fn._compute_sum_on_device(
          [a, b], device='/device:GPU:0', name='cant_name_indexslices')
コード例 #5
0
    def test_indexedslices_some_dont_overlap(self):
        """Partially overlapping IndexedSlices densify to a per-index sum.

        Index 0 appears in both inputs (1.0 + 3.0). Indices 1 and 3 appear in
        one input each, and index 2 in neither, so it stays 0.0.
        """
        with self.test_session() as sess:
            first = ops_lib.IndexedSlices(
                constant_op.constant([1.0, 2.0]), [0, 3],
                dense_shape=constant_op.constant([4]))
            second = ops_lib.IndexedSlices(
                constant_op.constant([3.0, 4.0]), [0, 1])

            summed = replicate_model_fn._compute_sum_on_device(
                [first, second], device='/device:GPU:0')
            self.assertEqual('/device:GPU:0', summed.device)

            dense = sess.run(ops_lib.convert_to_tensor(summed))
            self.assertAllEqual([4.0, 4.0, 0.0, 2.0], dense)
コード例 #6
0
  def test_indexedslices_some_dont_overlap(self):
    """Partially overlapping IndexedSlices densify to a per-index sum.

    Row 0 gets contributions from both inputs (1.0 + 3.0). Rows 1 and 3 come
    from one input each, and row 2 is never addressed, so it stays 0.0.
    """
    with self.test_session() as sess:
      lhs = ops_lib.IndexedSlices(
          constant_op.constant([1.0, 2.0]), [0, 3],
          dense_shape=constant_op.constant([4]))
      rhs = ops_lib.IndexedSlices(constant_op.constant([3.0, 4.0]), [0, 1])
      gpu = '/device:GPU:0'

      combined = replicate_model_fn._compute_sum_on_device(
          [lhs, rhs], device=gpu)
      self.assertEqual(gpu, combined.device)

      densified = sess.run(ops_lib.convert_to_tensor(combined))
      self.assertAllEqual([4.0, 4.0, 0.0, 2.0], densified)
コード例 #7
0
  def test_indexedslices_higher_dimensions(self):
    """Summing rank-2 IndexedSlices that share indices adds matching rows.

    Both inputs address rows 0 and 1, so the densified sum is the element-wise
    sum of corresponding rows.
    """
    with self.test_session() as session:
      # Each value row has 2 elements, so the dense shape is [2, 2]: an
      # IndexedSlices must satisfy dense_shape[1:] == values.shape[1:].
      a = ops_lib.IndexedSlices(
          constant_op.constant([[1.0, 5.0], [2.0, 6.0]]), [0, 1],
          dense_shape=constant_op.constant([2, 2]))
      b = ops_lib.IndexedSlices(
          constant_op.constant([[3.0, 7.0], [4.0, 8.0]]), [0, 1])

      total = replicate_model_fn._compute_sum_on_device(
          [a, b], device='/device:GPU:0')

      self.assertEqual('/device:GPU:0', total.device)
      self.assertAllEqual([[4.0, 12.0], [6.0, 14.0]],
                          session.run(ops_lib.convert_to_tensor(total)))
コード例 #8
0
  def test_indexedslices_higher_dimensions(self):
    """Summing rank-2 IndexedSlices that share indices adds matching rows.

    Both inputs address rows 0 and 1, so the densified sum is the element-wise
    sum of corresponding rows.
    """
    with self.test_session() as session:
      # The fixture's dense_shape must agree with its values: rows of length 2
      # mean dense_shape[1:] == [2], giving [2, 2] for two rows.
      a = ops_lib.IndexedSlices(
          constant_op.constant([[1.0, 5.0], [2.0, 6.0]]), [0, 1],
          dense_shape=constant_op.constant([2, 2]))
      b = ops_lib.IndexedSlices(
          constant_op.constant([[3.0, 7.0], [4.0, 8.0]]), [0, 1])

      total = replicate_model_fn._compute_sum_on_device(
          [a, b], device='/device:GPU:0')

      self.assertEqual('/device:GPU:0', total.device)
      self.assertAllEqual([[4.0, 12.0], [6.0, 14.0]],
                          session.run(ops_lib.convert_to_tensor(total)))