コード例 #1
0
  def testShapeInference(self):
    """Checks that statically inferred shapes match the evaluated shapes.

    For each dtype a [3, 1] constant is combined with a [1, 3] constant through
    every binary bitwise op. The static shape of each output tensor is then
    compared with the shape of the evaluated result and with [3, 3].
    """
    integer_dtypes = (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
                      dtypes.uint8, dtypes.uint16)
    binary_ops = (bitwise_ops.bitwise_and, bitwise_ops.bitwise_or,
                  bitwise_ops.bitwise_xor, bitwise_ops.left_shift,
                  bitwise_ops.right_shift)

    with self.session(use_gpu=True) as sess:
      for dtype in integer_dtypes:
        column = constant_op.constant([[0], [3], [5]], dtype=dtype)
        row = constant_op.constant([[1, 2, 4]], dtype=dtype)

        # Build all ops first so that one run call evaluates them together.
        outputs = [op(column, row) for op in binary_ops]
        values = sess.run(outputs)

        for output, value in zip(outputs, values):
          inferred_shape = output.get_shape().as_list()
          # Static shape must agree with the shape actually produced ...
          self.assertAllEqual(inferred_shape, value.shape)
          # ... and with the expected [3, 3] shape.
          self.assertAllEqual(inferred_shape, [3, 3])
コード例 #2
0
  def testShapeInference(self):
    """Checks that statically inferred shapes match the evaluated shapes.

    For each dtype a [3, 1] constant is combined with a [1, 3] constant through
    every binary bitwise op. The static shape of each output tensor is then
    compared with the shape of the evaluated result and with [3, 3].
    """
    dtype_list = [dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
                  dtypes.uint8, dtypes.uint16]

    # `self.session` is used instead of the deprecated `self.test_session`.
    with self.session(use_gpu=True) as sess:
      for dtype in dtype_list:
        lhs = constant_op.constant([[0], [3], [5]], dtype=dtype)
        rhs = constant_op.constant([[1, 2, 4]], dtype=dtype)

        tensors = [
            bitwise_ops.bitwise_and(lhs, rhs),
            bitwise_ops.bitwise_or(lhs, rhs),
            bitwise_ops.bitwise_xor(lhs, rhs),
            bitwise_ops.left_shift(lhs, rhs),
            bitwise_ops.right_shift(lhs, rhs),
        ]
        # Evaluate all five ops in a single run call.
        results = sess.run(tensors)

        # Compare shape inference with result
        for tensor, result in zip(tensors, results):
          static_shape = tensor.get_shape().as_list()
          self.assertAllEqual(static_shape, result.shape)
          self.assertAllEqual(static_shape, [3, 3])
コード例 #3
0
  def testShiftsWithNegativeLHS(self):
    """Compares shifts of negative signed values against NumPy.

    For every signed integer dtype, the left and right shift ops are evaluated
    on negative left-hand operands and the results are required to equal
    `np.left_shift` / `np.right_shift` on the same arrays.
    """
    dtype_list = [np.int8, np.int16, np.int32, np.int64]

    # `self.session` is used instead of the deprecated `self.test_session`.
    with self.session(use_gpu=True) as sess:
      for dtype in dtype_list:
        lhs = np.array([-1, -5, -3, -14], dtype=dtype)
        rhs = np.array([5, 0, 7, 11], dtype=dtype)
        # Evaluate both shift directions in one run call.
        left_shift_result, right_shift_result = sess.run(
            [bitwise_ops.left_shift(lhs, rhs),
             bitwise_ops.right_shift(lhs, rhs)])
        self.assertAllEqual(left_shift_result, np.left_shift(lhs, rhs))
        self.assertAllEqual(right_shift_result, np.right_shift(lhs, rhs))
コード例 #4
0
  def testShiftsWithNegativeLHS(self):
    """Compares shifts of negative signed values against NumPy.

    Each signed integer dtype is exercised with negative left-hand operands;
    the evaluated left and right shifts must equal `np.left_shift` and
    `np.right_shift` applied to the same arrays.
    """
    signed_dtypes = (np.int8, np.int16, np.int32, np.int64)

    with self.session(use_gpu=True) as sess:
      for dtype in signed_dtypes:
        values = np.array([-1, -5, -3, -14], dtype=dtype)
        shifts = np.array([5, 0, 7, 11], dtype=dtype)

        # Both directions are fetched by a single run call.
        shift_ops = [
            bitwise_ops.left_shift(values, shifts),
            bitwise_ops.right_shift(values, shifts),
        ]
        shifted_left, shifted_right = sess.run(shift_ops)

        self.assertAllEqual(shifted_left, np.left_shift(values, shifts))
        self.assertAllEqual(shifted_right, np.right_shift(values, shifts))
コード例 #5
0
ファイル: xla.py プロジェクト: AnishShah/tensorflow
def _shift_right_arithmetic_helper(x, y, name=None):
  """Right-shifts `x` by `y` arithmetically for signed or unsigned inputs.

  When the dtype of `x` has an entry in `_UNSIGNED_TO_SIGNED_TABLE`, both
  operands are cast to the mapped dtype before shifting and the result is cast
  back to the original dtype.

  Args:
    x: Value to shift.
    y: Shift amounts; must have the same dtype as `x` (asserted).
    name: Optional name forwarded to the right-shift op.

  Returns:
    The shifted value, cast back to the dtype of `x` when a cast was applied.
  """
  assert y.dtype == x.dtype
  input_dtype = x.dtype

  if input_dtype not in _UNSIGNED_TO_SIGNED_TABLE:
    # No table entry for this dtype: shift the operands as they are.
    return bitwise_ops.right_shift(x, y, name=name)

  # Shift in the mapped dtype, then restore the caller's dtype.
  shift_dtype = _UNSIGNED_TO_SIGNED_TABLE[input_dtype]
  shifted = bitwise_ops.right_shift(
      math_ops.cast(x, shift_dtype), math_ops.cast(y, shift_dtype), name=name)
  return math_ops.cast(shifted, input_dtype)
コード例 #6
0
ファイル: xla.py プロジェクト: AnishShah/tensorflow
def _shift_right_logical_helper(x, y, name=None):
  """Right-shifts `x` by `y` logically for signed or unsigned inputs.

  When the dtype of `x` has an entry in `_SIGNED_TO_UNSIGNED_TABLE`, both
  operands are cast to the mapped dtype before shifting and the result is cast
  back to the original dtype.

  Args:
    x: Value to shift.
    y: Shift amounts; must have the same dtype as `x` (asserted).
    name: Optional name forwarded to the right-shift op.

  Returns:
    The shifted value, cast back to the dtype of `x` when a cast was applied.
  """
  assert y.dtype == x.dtype
  input_dtype = x.dtype

  if input_dtype not in _SIGNED_TO_UNSIGNED_TABLE:
    # No table entry for this dtype: shift the operands as they are.
    return bitwise_ops.right_shift(x, y, name=name)

  # Shift in the mapped dtype, then restore the caller's dtype.
  shift_dtype = _SIGNED_TO_UNSIGNED_TABLE[input_dtype]
  shifted = bitwise_ops.right_shift(
      math_ops.cast(x, shift_dtype), math_ops.cast(y, shift_dtype), name=name)
  return math_ops.cast(shifted, input_dtype)
コード例 #7
0
def _shift_right_arithmetic_helper(x, y, name=None):
  """Applies an arithmetic right shift regardless of the signedness of `x`.

  Operands whose dtype is a key of `_UNSIGNED_TO_SIGNED_TABLE` are shifted in
  the mapped dtype, and the result is converted back afterwards.

  Args:
    x: Value to shift.
    y: Shift amounts; asserted to share the dtype of `x`.
    name: Optional name passed to the right-shift op.

  Returns:
    The shifted value; converted back to the dtype of `x` if a cast was made.
  """
  assert y.dtype == x.dtype
  source_dtype = x.dtype
  needs_cast = source_dtype in _UNSIGNED_TO_SIGNED_TABLE

  if needs_cast:
    target_dtype = _UNSIGNED_TO_SIGNED_TABLE[source_dtype]
    x, y = math_ops.cast(x, target_dtype), math_ops.cast(y, target_dtype)

  shifted = bitwise_ops.right_shift(x, y, name=name)
  # Undo the temporary cast so callers get back their own dtype.
  return math_ops.cast(shifted, source_dtype) if needs_cast else shifted
コード例 #8
0
def _shift_right_logical_helper(x, y, name=None):
  """Applies a logical right shift regardless of the signedness of `x`.

  Operands whose dtype is a key of `_SIGNED_TO_UNSIGNED_TABLE` are shifted in
  the mapped dtype, and the result is converted back afterwards.

  Args:
    x: Value to shift.
    y: Shift amounts; asserted to share the dtype of `x`.
    name: Optional name passed to the right-shift op.

  Returns:
    The shifted value; converted back to the dtype of `x` if a cast was made.
  """
  assert y.dtype == x.dtype
  source_dtype = x.dtype
  needs_cast = source_dtype in _SIGNED_TO_UNSIGNED_TABLE

  if needs_cast:
    target_dtype = _SIGNED_TO_UNSIGNED_TABLE[source_dtype]
    x, y = math_ops.cast(x, target_dtype), math_ops.cast(y, target_dtype)

  shifted = bitwise_ops.right_shift(x, y, name=name)
  # Undo the temporary cast so callers get back their own dtype.
  return math_ops.cast(shifted, source_dtype) if needs_cast else shifted
コード例 #9
0
  def testImplementationDefinedShiftsDoNotCrash(self):
    """Runs shifts with negative and very large shift amounts.

    Only successful evaluation is checked; no output values are asserted.
    """
    dtype_list = [np.int8, np.int16, np.int32, np.int64]

    # `self.session` is used instead of the deprecated `self.test_session`.
    with self.session(use_gpu=True) as sess:
      for dtype in dtype_list:
        lhs = np.array([-1, -5, -3, -14], dtype=dtype)
        rhs = np.array([-2, 64, 101, 32], dtype=dtype)
        # We intentionally do not test for specific values here since the exact
        # outputs are implementation-defined. However, we should not crash or
        # trigger an undefined-behavior error from tools such as
        # AddressSanitizer.
        sess.run([bitwise_ops.left_shift(lhs, rhs),
                  bitwise_ops.right_shift(lhs, rhs)])
コード例 #10
0
  def testImplementationDefinedShiftsDoNotCrash(self):
    """Runs shifts with negative and very large shift amounts.

    Only successful evaluation is checked; no output values are asserted.
    """
    signed_dtypes = (np.int8, np.int16, np.int32, np.int64)

    with self.session(use_gpu=True) as sess:
      for dtype in signed_dtypes:
        values = np.array([-1, -5, -3, -14], dtype=dtype)
        shifts = np.array([-2, 64, 101, 32], dtype=dtype)
        shift_ops = [
            bitwise_ops.left_shift(values, shifts),
            bitwise_ops.right_shift(values, shifts),
        ]
        # The exact outputs are implementation-defined, so no values are
        # compared. The run must simply finish without crashing or triggering
        # an undefined-behavior error from tools such as AddressSanitizer.
        sess.run(shift_ops)
コード例 #11
0
  def testShiftsWithPositiveLHS(self):
    """Compares shifts of non-negative values against NumPy.

    Each signed and unsigned integer dtype is exercised; the evaluated left
    and right shifts must equal `np.left_shift` and `np.right_shift` applied
    to the same arrays.
    """
    integer_dtypes = (np.int8, np.int16, np.int32, np.int64,
                      np.uint8, np.uint16, np.uint32, np.uint64)

    with self.session() as sess:
      for dtype in integer_dtypes:
        values = np.array([0, 5, 3, 14], dtype=dtype)
        shifts = np.array([5, 0, 7, 3], dtype=dtype)

        # Both directions are fetched by a single run call.
        shift_ops = [
            bitwise_ops.left_shift(values, shifts),
            bitwise_ops.right_shift(values, shifts),
        ]
        shifted_left, shifted_right = sess.run(shift_ops)

        self.assertAllEqual(shifted_left, np.left_shift(values, shifts))
        self.assertAllEqual(shifted_right, np.right_shift(values, shifts))