def testInfeedThenOutfeedInALoop(self):
  """Round-trips `n` arrays through infeed/outfeed inside a jitted loop.

  A worker thread runs a jitted `fori_loop`. Each iteration reads a (3, 4)
  float32 array from infeed and writes twice that array to outfeed. The main
  thread feeds random inputs one at a time and checks each doubled output.
  """
  # Stop the host_callback outfeed receiver so this test can read the device
  # outfeed directly with `transfer_from_outfeed`.
  hcb.stop_outfeed_receiver()

  def doubler(_, tok):
    # The loop index is unused; the token orders the infeed/outfeed pair.
    fed, tok = lax.infeed(tok, shape=jax.ShapedArray((3, 4), jnp.float32))
    return lax.outfeed(tok, fed * np.float32(2))

  @jax.jit
  def f(num_iters):
    token = lax.fori_loop(0, num_iters, doubler, lax.create_token(num_iters))
    return num_iters

  device = jax.local_devices()[0]
  n = 10
  worker = threading.Thread(target=f, args=(n,))
  worker.start()
  for _ in range(n):
    fed = np.random.randn(3, 4).astype(np.float32)
    device.transfer_to_infeed((fed,))
    outfeed_shape = xla_client.shape_from_pyval(
        (fed,)).with_major_to_minor_layout_if_absent()
    (doubled,) = device.transfer_from_outfeed(outfeed_shape)
    self.assertAllClose(doubled, fed * np.float32(2))
  worker.join()
def test_add_numbers(self):
  """Adds 20 and 30 by running an XLA `Add` computation in `XlaExecutor`."""
  builder = xla_client.XlaBuilder('comp')
  # The single parameter is a tuple holding the two int32 operands.
  arg_tuple = xla_client.ops.Parameter(
      builder, 0,
      xla_client.shape_from_pyval((np.array(0, dtype=np.int32),) * 2))
  lhs, rhs = (xla_client.ops.GetTupleElement(arg_tuple, i) for i in (0, 1))
  xla_client.ops.Add(lhs, rhs)
  comp_type = computation_types.FunctionType((np.int32, np.int32), np.int32)
  comp_pb = xla_serialization.create_xla_tff_computation(
      builder.build(), [0, 1], comp_type)
  ex = executor.XlaExecutor()

  async def _add_20_and_30():
    fn_val = await ex.create_value(comp_pb, comp_type)
    lhs_val = await ex.create_value(20, np.int32)
    rhs_val = await ex.create_value(30, np.int32)
    pair_val = await ex.create_struct([lhs_val, rhs_val])
    call_val = await ex.create_call(fn_val, pair_val)
    return await call_val.compute()

  self.assertEqual(asyncio.run(_add_20_and_30()), 50)
def test_to_representation_for_type_with_noarg_to_int32_comp(self):
  """A no-arg XLA computation returning int32 10 becomes a Python callable."""
  builder = xla_client.XlaBuilder('comp')
  # With no arguments, the single parameter is an empty tuple.
  empty_tuple_shape = xla_client.shape_from_pyval(())
  xla_client.ops.Parameter(builder, 0, empty_tuple_shape)
  xla_client.ops.Constant(builder, np.int32(10))
  comp_type = computation_types.FunctionType(None, np.int32)
  comp_pb = xla_serialization.create_xla_tff_computation(
      builder.build(), [], comp_type)
  representation = executor.to_representation_for_type(
      comp_pb, comp_type, self._backend)
  self.assertTrue(callable(representation))
  self.assertEqual(representation(), 10)
def test_set_local_python_execution_context_and_run_simple_xla_computation(
    self):
  """Invokes a no-arg XLA computation under the local Python context."""
  builder = xla_client.XlaBuilder('comp')
  xla_client.ops.Parameter(builder, 0, xla_client.shape_from_pyval(()))
  xla_client.ops.Constant(builder, np.int32(10))
  comp_type = computation_types.FunctionType(None, np.int32)
  comp_pb = xla_serialization.create_xla_tff_computation(
      builder.build(), [], comp_type)
  comp = computation_impl.ConcreteComputation(
      comp_pb, context_stack_impl.context_stack)
  # NOTE(review): the execution context installed here is not restored by
  # this test.
  execution_contexts.set_local_python_execution_context()
  self.assertEqual(comp(), 10)
def test_computation_callable_return_one_number(self):
  """`ComputationCallable` wraps a no-arg XLA computation that returns 10."""
  builder = xla_client.XlaBuilder('comp')
  xla_client.ops.Parameter(builder, 0, xla_client.shape_from_pyval(()))
  xla_client.ops.Constant(builder, np.int32(10))
  comp_type = computation_types.FunctionType(None, np.int32)
  comp_pb = xla_serialization.create_xla_tff_computation(
      builder.build(), [], comp_type)
  comp_callable = runtime.ComputationCallable(
      comp_pb, comp_type, jax.lib.xla_bridge.get_backend())
  self.assertIsInstance(comp_callable, runtime.ComputationCallable)
  self.assertEqual(str(comp_callable.type_signature), '( -> int32)')
  self.assertEqual(comp_callable(), 10)
def testInfeedThenOutfeed(self):
  """Feeds one (3, 4) array in and expects it back, plus one, on outfeed."""

  @jax.jit
  def f(x):
    tok = lax.create_token(x)
    fed, tok = lax.infeed(tok, shape=jax.ShapedArray((3, 4), np.float32))
    tok = lax.outfeed(tok, fed + onp.float32(1))
    # Presumably ties the result to the outfeed token so the side effects are
    # not pruned from the jitted computation.
    return lax.tie_in(tok, x - 1)

  scalar = onp.float32(7.5)
  payload = onp.random.randn(3, 4).astype(onp.float32)
  worker = threading.Thread(target=f, args=(scalar,))
  worker.start()
  xla_client.transfer_to_infeed((payload,))
  outfeed_shape = xla_client.shape_from_pyval((payload,))
  (received,) = xla_client.transfer_from_outfeed(outfeed_shape)
  worker.join()
  self.assertAllClose(received, payload + onp.float32(1), check_dtypes=True)
def test_to_representation_for_type_with_2xint32_to_int32_comp(self):
  """A 2xint32 -> int32 XLA `Add` becomes a callable taking a Struct."""
  builder = xla_client.XlaBuilder('comp')
  int32_zero = np.array(0, dtype=np.int32)
  arg_tuple = xla_client.ops.Parameter(
      builder, 0, xla_client.shape_from_pyval((int32_zero, int32_zero)))
  xla_client.ops.Add(
      *(xla_client.ops.GetTupleElement(arg_tuple, i) for i in range(2)))
  comp_type = computation_types.FunctionType((np.int32, np.int32), np.int32)
  comp_pb = xla_serialization.create_xla_tff_computation(
      builder.build(), [0, 1], comp_type)
  representation = executor.to_representation_for_type(
      comp_pb, comp_type, self._backend)
  self.assertTrue(callable(representation))
  operands = structure.Struct([(None, np.int32(20)), (None, np.int32(30))])
  self.assertEqual(representation(operands), 50)
def test_to_representation_for_type_with_noarg_to_2xint32_comp(self):
  """A no-arg XLA computation returning two int32s yields a named struct.

  The computation takes the empty-tuple parameter and returns a tuple of the
  constants 10 and 20. `comp_type` types that tuple as `<a=int32,b=int32>`.
  """
  builder = xla_client.XlaBuilder('comp')
  xla_client.ops.Parameter(builder, 0, xla_client.shape_from_pyval(tuple()))
  xla_client.ops.Tuple(builder, [
      xla_client.ops.Constant(builder, np.int32(10)),
      xla_client.ops.Constant(builder, np.int32(20))
  ])
  xla_comp = builder.build()
  comp_type = computation_types.FunctionType(
      None, computation_types.StructType([('a', np.int32), ('b', np.int32)]))
  # The computation has no parameter, so there are no parameter tensors to
  # bind. Use an empty index list, as every other no-argument computation in
  # this file does.
  comp_pb = xla_serialization.create_xla_tff_computation(
      xla_comp, [], comp_type)
  rep = executor.to_representation_for_type(comp_pb, comp_type, self._backend)
  self.assertTrue(callable(rep))
  result = rep()
  self.assertEqual(str(result), '<a=10,b=20>')
def testInfeedThenOutfeed(self):
  """Feeds one (3, 4) array to a device and reads it back plus one."""

  @jax.jit
  def f(x):
    tok = lax.create_token(x)
    fed, tok = lax.infeed(tok, shape=jax.ShapedArray((3, 4), np.float32))
    tok = lax.outfeed(tok, fed + onp.float32(1))
    # Presumably ties the result to the outfeed token so the side effects are
    # not pruned from the jitted computation.
    return lax.tie_in(tok, x - 1)

  scalar = onp.float32(7.5)
  payload = onp.random.randn(3, 4).astype(onp.float32)
  worker = threading.Thread(target=f, args=(scalar,))
  worker.start()
  device = jax.local_devices()[0]
  device.transfer_to_infeed((payload,))
  outfeed_shape = xla_client.shape_from_pyval(
      (payload,)).with_major_to_minor_layout_if_absent()
  (received,) = device.transfer_from_outfeed(outfeed_shape)
  worker.join()
  self.assertAllClose(received, payload + onp.float32(1))
def test_computation_callable_add_two_numbers(self):
  """`ComputationCallable` wraps a 2xint32 -> int32 XLA `Add`."""
  builder = xla_client.XlaBuilder('comp')
  int32_zero = np.array(0, dtype=np.int32)
  arg_tuple = xla_client.ops.Parameter(
      builder, 0, xla_client.shape_from_pyval((int32_zero, int32_zero)))
  first = xla_client.ops.GetTupleElement(arg_tuple, 0)
  second = xla_client.ops.GetTupleElement(arg_tuple, 1)
  xla_client.ops.Add(first, second)
  comp_type = computation_types.FunctionType((np.int32, np.int32), np.int32)
  comp_pb = xla_serialization.create_xla_tff_computation(
      builder.build(), [0, 1], comp_type)
  comp_callable = runtime.ComputationCallable(
      comp_pb, comp_type, jax.lib.xla_bridge.get_backend())
  self.assertIsInstance(comp_callable, runtime.ComputationCallable)
  self.assertEqual(
      str(comp_callable.type_signature), '(<int32,int32> -> int32)')
  operands = structure.Struct([(None, np.int32(2)), (None, np.int32(3))])
  self.assertEqual(comp_callable(operands), 5)
def create_constant_from_scalar(
    self, value, type_spec: computation_types.Type
) -> local_computation_factory_base.ComputationProtoAndType:
  """Builds a no-arg XLA computation that fills `type_spec` with `value`.

  Every tensor in `type_spec` is materialized with `np.full(..., value)` and
  embedded as an XLA constant.

  Args:
    value: The scalar used as the fill value of every tensor in the result.
    type_spec: A `computation_types.TensorType` or a structure of them. Every
      tensor must have a fully-defined shape.

  Returns:
    A `(comp_pb, comp_type)` pair, where `comp_type` is a no-argument
    `computation_types.FunctionType` whose result type is `type_spec`.

  Raises:
    ValueError: If `type_spec` is not a tensor or a structure of tensors, or
      if any tensor in it has a shape that is not fully defined.
  """
  py_typecheck.check_type(type_spec, computation_types.Type)
  if not type_analysis.is_structure_of_tensors(type_spec):
    raise ValueError(
        'Not a tensor or a structure of tensors: {}'.format(
            str(type_spec)))

  if isinstance(type_spec, computation_types.TensorType):
    tensor_types = [type_spec]
  else:
    tensor_types = structure.flatten(type_spec)

  # Validate every tensor before touching the builder. A constant needs a
  # concrete shape, and an undefined one would otherwise fail inside
  # `np.full` with an opaque error.
  for tensor_type in tensor_types:
    py_typecheck.check_type(tensor_type, computation_types.TensorType)
    if not tensor_type.shape.is_fully_defined():
      raise ValueError(
          'Cannot create a constant of a tensor type whose shape is not '
          'fully defined: {}'.format(tensor_type))

  builder = xla_client.XlaBuilder('comp')

  # We maintain the convention that arguments are supplied as a tuple for the
  # sake of consistency and uniformity (see comments in `computation.proto`).
  # Since there are no arguments here, we create an empty tuple.
  xla_client.ops.Parameter(builder, 0, xla_client.shape_from_pyval(tuple()))

  def _constant_from_tensor(tensor_type):
    # Shape is known to be fully defined here, so `as_list()` yields ints.
    numpy_value = np.full(
        shape=tensor_type.shape.as_list(),
        fill_value=value,
        dtype=tensor_type.dtype.as_numpy_dtype)
    return xla_client.ops.Constant(builder, numpy_value)

  tensors = [_constant_from_tensor(t) for t in tensor_types]

  # Likewise, results are always returned as a single tuple with results.
  # This is always a flat tuple; the nested TFF structure is defined by the
  # binding.
  xla_client.ops.Tuple(builder, tensors)
  xla_computation = builder.build()

  comp_type = computation_types.FunctionType(None, type_spec)
  comp_pb = xla_serialization.create_xla_tff_computation(
      xla_computation, [], comp_type)
  return (comp_pb, comp_type)
def testInfeedThenOutfeedInALoop(self):
  """Round-trips `n` arrays through infeed/outfeed inside a jitted loop.

  A worker thread runs a jitted `fori_loop`. Each iteration reads a (3, 4)
  float32 array from infeed and writes twice that array to outfeed.
  """

  def doubler(_, tok):
    # The loop index is unused; the token orders the infeed/outfeed pair.
    fed, tok = lax.infeed(tok, shape=jax.ShapedArray((3, 4), np.float32))
    return lax.outfeed(tok, fed * onp.float32(2))

  @jax.jit
  def f(num_iters):
    tok = lax.fori_loop(0, num_iters, doubler, lax.create_token(num_iters))
    # Presumably ties the result to the final token so the loop's side
    # effects are not pruned from the jitted computation.
    return lax.tie_in(tok, num_iters)

  n = 10
  worker = threading.Thread(target=f, args=(n,))
  worker.start()
  for _ in range(n):
    fed = onp.random.randn(3, 4).astype(onp.float32)
    xla_client.transfer_to_infeed((fed,))
    outfeed_shape = xla_client.shape_from_pyval((fed,))
    (doubled,) = xla_client.transfer_from_outfeed(outfeed_shape)
    self.assertAllClose(doubled, fed * onp.float32(2), check_dtypes=True)
  worker.join()
def test_create_and_invoke_noarg_comp_returning_int32(self):
  """Embeds a no-arg XLA computation in `XlaExecutor` and calls it."""
  builder = xla_client.XlaBuilder('comp')
  xla_client.ops.Parameter(builder, 0, xla_client.shape_from_pyval(()))
  xla_client.ops.Constant(builder, np.int32(10))
  comp_type = computation_types.FunctionType(None, np.int32)
  comp_pb = xla_serialization.create_xla_tff_computation(
      builder.build(), [], comp_type)
  ex = executor.XlaExecutor()

  # The embedded value's internal representation is directly callable.
  comp_val = asyncio.run(ex.create_value(comp_pb, comp_type))
  self.assertIsInstance(comp_val, executor.XlaValue)
  self.assertEqual(str(comp_val.type_signature), str(comp_type))
  self.assertTrue(callable(comp_val.internal_representation))
  self.assertEqual(comp_val.internal_representation(), 10)

  # Calling it through the executor produces an int32 value that computes 10.
  call_val = asyncio.run(ex.create_call(comp_val))
  self.assertIsInstance(call_val, executor.XlaValue)
  self.assertEqual(str(call_val.type_signature), 'int32')
  self.assertEqual(asyncio.run(call_val.compute()), 10)
def check_outfeed(d, x):
  """Reads one single-element tuple from `d`'s outfeed and compares it to `x`.

  `self` is a free variable here; presumably it is the enclosing test case.
  """
  outfeed_shape = xla_client.shape_from_pyval(
      (x,)).with_major_to_minor_layout_if_absent()
  (received,) = d.transfer_from_outfeed(outfeed_shape)
  self.assertAllClose(x, received, check_dtypes=True)