Beispiel #1
0
  def testAdd(self):
    """Checks xla.add against hand-computed sums for each numeric dtype.

    For every dtype in self.numeric_types three cases are run through
    self._assertOpOutputMatchesExpected: two same-shape vectors, and a 2x2
    matrix plus a length-2 vector with broadcast_dims=(0,) and with
    broadcast_dims=(1,).
    """

    def add_broadcast_dim0(x, y):
      return xla.add(x, y, broadcast_dims=(0,))

    def add_broadcast_dim1(x, y):
      return xla.add(x, y, broadcast_dims=(1,))

    for dtype in self.numeric_types:
      # Case 1: operands of identical shape.
      lhs = np.array([1, 2, 3], dtype=dtype)
      rhs = np.array([4, 5, 6], dtype=dtype)
      want = np.array([5, 7, 9], dtype=dtype)
      self._assertOpOutputMatchesExpected(
          xla.add, args=(lhs, rhs), expected=want)

      # Case 2: the expected result has vector entry i added across row i.
      matrix = np.array([[1, 2], [3, 4]], dtype=dtype)
      vector = np.array([7, 11], dtype=dtype)
      want = np.array([[8, 9], [14, 15]], dtype=dtype)
      self._assertOpOutputMatchesExpected(
          add_broadcast_dim0, args=(matrix, vector), expected=want)

      # Case 3: the expected result has vector entry j added down column j.
      matrix = np.array([[1, 2], [3, 4]], dtype=dtype)
      vector = np.array([7, 11], dtype=dtype)
      want = np.array([[8, 13], [10, 15]], dtype=dtype)
      self._assertOpOutputMatchesExpected(
          add_broadcast_dim1, args=(matrix, vector), expected=want)
Beispiel #2
0
  def testAdd(self):
    """Verifies xla.add with and without broadcast_dims, per numeric dtype.

    Each dtype in self.numeric_types is run through three fixed cases; every
    case is handed to self._assertOpOutputMatchesExpected together with its
    operands and the literal expected array.
    """

    def row_wise_add(x, y):
      return xla.add(x, y, broadcast_dims=(0,))

    def column_wise_add(x, y):
      return xla.add(x, y, broadcast_dims=(1,))

    for dtype in self.numeric_types:

      def arr(values):
        # All operands and expectations in one pass share the current dtype.
        return np.array(values, dtype=dtype)

      # Two length-3 vectors, no broadcasting requested.
      operands = (arr([1, 2, 3]), arr([4, 5, 6]))
      self._assertOpOutputMatchesExpected(
          xla.add, args=operands, expected=arr([5, 7, 9]))

      # broadcast_dims=(0,): expected values add 7 to row 0 and 11 to row 1.
      operands = (arr([[1, 2], [3, 4]]), arr([7, 11]))
      self._assertOpOutputMatchesExpected(
          row_wise_add, args=operands, expected=arr([[8, 9], [14, 15]]))

      # broadcast_dims=(1,): expected values add 7 to column 0 and 11 to
      # column 1.
      operands = (arr([[1, 2], [3, 4]]), arr([7, 11]))
      self._assertOpOutputMatchesExpected(
          column_wise_add, args=operands, expected=arr([[8, 13], [10, 15]]))
    def testAdd(self):
        """Checks xla.add results for each numeric dtype; skipped on ROCm.

        The test is skipped when xla_test.test.is_built_with_rocm() is true.
        Otherwise, for every dtype in self.numeric_types, three cases go
        through self._assertOpOutputMatchesExpected: same-shape vectors, and
        a 2x2 matrix plus a length-2 vector with broadcast_dims=(0, ) and
        with broadcast_dims=(1, ).
        """
        rocm_build = xla_test.test.is_built_with_rocm()
        if rocm_build:
            self.skipTest('Broken with rocm')

        def add_broadcast_dim0(x, y):
            return xla.add(x, y, broadcast_dims=(0, ))

        def add_broadcast_dim1(x, y):
            return xla.add(x, y, broadcast_dims=(1, ))

        for dtype in self.numeric_types:
            # Same-shape operands.
            first = np.array([1, 2, 3], dtype=dtype)
            second = np.array([4, 5, 6], dtype=dtype)
            self._assertOpOutputMatchesExpected(
                xla.add,
                args=(first, second),
                expected=np.array([5, 7, 9], dtype=dtype))

            # Expected values have vector entry i added across row i.
            grid = np.array([[1, 2], [3, 4]], dtype=dtype)
            offsets = np.array([7, 11], dtype=dtype)
            self._assertOpOutputMatchesExpected(
                add_broadcast_dim0,
                args=(grid, offsets),
                expected=np.array([[8, 9], [14, 15]], dtype=dtype))

            # Expected values have vector entry j added down column j.
            grid = np.array([[1, 2], [3, 4]], dtype=dtype)
            offsets = np.array([7, 11], dtype=dtype)
            self._assertOpOutputMatchesExpected(
                add_broadcast_dim1,
                args=(grid, offsets),
                expected=np.array([[8, 13], [10, 15]], dtype=dtype))