コード例 #1
0
ファイル: random_test.py プロジェクト: zizai/jax
 def testNoOpByOpUnderHash(self):
   # threefry_2x32 must not dispatch primitives op-by-op: swap in a stub for
   # xla.apply_primitive that fails if called, and always restore the original.
   def fail(*args, **kwargs):
     # Explicit raise rather than `assert False`, which `python -O` strips
     # (the stub would then silently return None instead of failing).
     raise AssertionError("unexpected op-by-op call to xla.apply_primitive")
   apply_primitive, xla.apply_primitive = xla.apply_primitive, fail
   try:
     random.threefry_2x32(np.zeros(2, np.uint32), np.arange(10, dtype=np.uint32))
   finally:
     xla.apply_primitive = apply_primitive
コード例 #2
0
ファイル: random_test.py プロジェクト: zizai/jax
 def testThreefry2x32Empty(self):
   # Regression test: hashing an empty count array op-by-op (with jit
   # disabled) used to crash in CUDA mode.
   with api.disable_jit():
     key = (np.uint32(0x13198a2e), np.uint32(0x03707344))
     empty_count = jnp.ones((10, 0,), jnp.uint32)
     result = random.threefry_2x32(key, empty_count)
   expected = np.zeros((10, 0,), dtype=np.uint32)
   np.testing.assert_equal(result, expected)
コード例 #3
0
ファイル: random_test.py プロジェクト: tpanthera/jax
    def testNoOpByOpUnderHash(self):
        # threefry_2x32 must not dispatch primitives op-by-op: swap in a stub
        # for xla.apply_primitive that fails if called.
        def fail(*args, **kwargs):
            # Accept any call signature so the failure is this clear error,
            # not a TypeError; raise explicitly since `python -O` strips asserts.
            raise AssertionError("unexpected op-by-op call to xla.apply_primitive")

        apply_primitive, xla.apply_primitive = xla.apply_primitive, fail
        try:
            random.threefry_2x32(onp.zeros(2, onp.uint32),
                                 onp.arange(10, dtype=onp.uint32))
        finally:
            # Always restore, otherwise a failure above would leak the stub
            # into every test that runs afterwards.
            xla.apply_primitive = apply_primitive
コード例 #4
0
ファイル: random_test.py プロジェクト: uafpdivad/jax
 def testNoOpByOpUnderHash(self):
   if not config.omnistaging_enabled:
     raise SkipTest("test requires omnistaging")
   # threefry_2x32 must not dispatch primitives op-by-op: swap in a stub for
   # xla.apply_primitive that fails if called, and always restore the original.
   def fail(*args, **kwargs):
     # Explicit raise rather than `assert False`, which `python -O` strips
     # (the stub would then silently return None instead of failing).
     raise AssertionError("unexpected op-by-op call to xla.apply_primitive")
   apply_primitive, xla.apply_primitive = xla.apply_primitive, fail
   try:
     random.threefry_2x32(np.zeros(2, np.uint32), np.arange(10, dtype=np.uint32))
   finally:
     xla.apply_primitive = apply_primitive
コード例 #5
0
ファイル: random_test.py プロジェクト: zizai/jax
 def testThreefry2x32Large(self):
   # Hash a count array of 2*n words whose two halves repeat the two count
   # words of the known-answer vector; each half of the output must then be
   # filled with the matching known output word.
   n = 10000000
   key = (np.uint32(0x13198a2e), np.uint32(0x03707344))
   first_half = jnp.full((n,), 0x243f6a88, jnp.uint32)
   second_half = jnp.full((n,), 0x85a308d3, jnp.uint32)
   result = random.threefry_2x32(key, jnp.concatenate([first_half, second_half]))

   def expect_filled_with(actual, word):
     np.testing.assert_equal(actual, np.full((n,), word, dtype=np.uint32))

   expect_filled_with(result[:n], 0xc4923a9c)
   expect_filled_with(result[n:], 0x483df7a0)
コード例 #6
0
ファイル: random_test.py プロジェクト: samuela/jax
    def testThreefry2x32(self):
        # We test the hash by comparing to known values provided in the test code of
        # the original reference implementation of Threefry. For the values, see
        # https://github.com/DEShawResearch/Random123-Boost/blob/65e3d874b67aa7b3e02d5ad8306462f52d2079c0/libs/random/test/test_threefry.cpp#L30-L32
        def result_to_hex(result):
            # int() accepts NumPy/JAX integer scalars; Python 3's hex() never
            # appends the Python 2 "L" long suffix, so nothing needs stripping.
            return tuple(hex(int(x)) for x in result)

        # All-ones words are spelled 0xFFFFFFFF rather than -1: converting a
        # negative Python int to uint32 is deprecated since NumPy 1.24 and
        # raises OverflowError in NumPy 2.
        max_u32 = 0xFFFFFFFF
        # (key words, count words, expected output words as hex strings)
        cases = [
            ([0, 0], [0, 0], ("0x6b200159", "0x99ba4efe")),
            ([max_u32, max_u32], [max_u32, max_u32], ("0x1cb996fc", "0xbb002be7")),
            ([0x13198a2e, 0x03707344], [0x243f6a88, 0x85a308d3],
             ("0xc4923a9c", "0x483df7a0")),
        ]
        for key, count, expected in cases:
            result = random.threefry_2x32(np.uint32(key), np.uint32(count))
            self.assertEqual(expected, result_to_hex(result))
コード例 #7
0
ファイル: jax_call.py プロジェクト: nicolasvasilache/dex-lang
def eval_for(op):
  """Evaluate one primitive op with jax.numpy.

  Args:
    op: the op to evaluate. Reads ``op.op_name`` and ``op.all_idxs``; operands
      come from ``op.args`` (each exposing ``.idxs`` and ``.atom.val``) and,
      for ``Iota``, the length comes from ``op.size_args``.

  Returns:
    The op's result broadcast over ``op.all_idxs`` via ``broadcast_dims``;
    ``Get`` instead returns the array gathered from its first operand.

  Raises:
    Exception: if ``op.op_name`` is not one of the supported ops.
  """
  name = op.op_name
  # Elementwise binary ops; integer and float variants share a jnp function.
  binary_ops = {
      "IAdd": jnp.add,
      "FAdd": jnp.add,
      "IMul": jnp.multiply,
      "FMul": jnp.multiply,
      "FDiv": jnp.divide,
  }
  if name in binary_ops:
    x, y = op.args
    x_bc = broadcast_dims(op.all_idxs, x.idxs, x.atom.val)
    y_bc = broadcast_dims(op.all_idxs, y.idxs, y.atom.val)
    return binary_ops[name](x_bc, y_bc)

  if name == "Iota":
    n, = op.size_args
    # The range has no operand indices, so it is broadcast from an empty list.
    return broadcast_dims(op.all_idxs, [], jnp.arange(n))

  if name == "Id":
    x, = op.args
    return broadcast_dims(op.all_idxs, x.idxs, x.atom.val)

  if name == "Get":
    x, idx = op.args
    out_shape = [i.size for i in op.all_idxs]
    x_idxs_used = get_stack_idxs_used(op.all_idxs, x.idxs)
    # One iota index array per used index selects x's leading axes; the
    # broadcast payload index array then selects the axis after them.
    leading_idx_arrays = [
        nth_iota(out_shape, i)
        for i, idx_used in enumerate(x_idxs_used)
        if idx_used
    ]
    payload_idx_array = broadcast_dims(op.all_idxs, idx.idxs, idx.atom.val)
    return x.atom.val[tuple(leading_idx_arrays) + (payload_idx_array,)]

  if name == "IntToReal":
    x, = op.args
    real_val = jnp.array(x.atom.val, dtype="float32")
    return broadcast_dims(op.all_idxs, x.idxs, real_val)

  if name in ("FNeg", "INeg"):
    x, = op.args
    return broadcast_dims(op.all_idxs, x.idxs, jnp.negative(x.atom.val))

  if name == "ThreeFry2x32":
    x, y = op.args
    # Reinterpret each 64-bit value as uint32 words (native byte order),
    # since threefry_2x32 works on 32-bit words.
    # NOTE(review): np.array([v]).view(np.uint32) yields exactly two words
    # only when v is a scalar -- presumably operands are scalars; confirm.
    key = np.array([x.atom.val]).view(np.uint32)
    count = np.array([y.atom.val]).view(np.uint32)
    hashed = random.threefry_2x32(key, count)
    # Pack the uint32 output words back into a single int64.
    result = np.int64(np.array(hashed).view(np.int64).item())
    # NOTE(review): the result depends on both x and y but is broadcast using
    # x.idxs only (unchanged from the original) -- confirm y never carries
    # indices that x lacks.
    return broadcast_dims(op.all_idxs, x.idxs, result)

  raise Exception("Unrecognized op: {}".format(name))