def test_convert_axpy(self):
    """Converts a mixed-rank axpy and checks its ops and HLO parameter order."""

    def axpy(a, x, y):
        # `y` is rank 1 while `x` is rank 2; the trailing new axis lets `y`
        # broadcast across x's second dimension.
        return a * x + y[:, jnp.newaxis]

    hlo_proto, hlo_text = jax_to_hlo(
        axpy,
        [
            # Listed in a different order from axpy's signature (a, x, y); the
            # parameter checks below verify the HLO follows this list's order.
            ('y', xla_client.Shape('f32[128]')),
            ('a', xla_client.Shape('f32[]')),
            ('x', xla_client.Shape('f32[128,2]')),
        ],
    )

    # Check that hlo_text contains a broadcast, add, and multiply.
    self.assertIn('broadcast', hlo_text)
    self.assertIn('add', hlo_text)
    self.assertIn('multiply', hlo_text)

    # Check that the HLO parameters are in the order we specified in the
    # jax_to_hlo call.
    self.assertIn('f32[128]{0} parameter(0)', hlo_text)
    self.assertIn('f32[] parameter(1)', hlo_text)
    self.assertIn('f32[128,2]{1,0} parameter(2)', hlo_text)

    # TODO(jlebar): Ideally we'd check that hlo_proto can be deserialized to a
    # valid HLO proto, but we don't seem to have access to hlo_pb2 at the
    # moment, so the best we seem to be able to do is check that it's nonempty.
    # Use the TestCase assertion rather than a bare `assert`, which is
    # stripped when Python runs with -O.
    self.assertTrue(hlo_proto)
def test_convert_with_constants(self):
    """Values supplied via `constants` are folded away rather than emitted."""

    def fn(a, b, x, y):
        return a / b * x + y

    traced_shapes = [
        ('x', xla_client.Shape('f32[128]')),
        ('y', xla_client.Shape('f32[128]')),
    ]
    baked_in = {
        'a': 123456,
        'b': 4,
    }
    _, hlo_text = jax_to_hlo(fn, input_shapes=traced_shapes, constants=baked_in)

    # `a` and `b` arrive as constants, so Python/JAX folds them down to the
    # single quotient a / b = 30864; the raw numerator must not survive.
    self.assertIn('constant(30864)', hlo_text)
    self.assertNotIn('123456', hlo_text)
def test_convert_axpy_with_dict_shapes(self):
    """Converts a scalar-times-vector axpy whose input shapes are given as a dict."""

    def axpy(a, x, y):
        return a * x + y

    # NOTE(review): input shapes are passed here as a name -> Shape mapping;
    # confirm jax_to_hlo accepts a mapping as well as a sequence of
    # (name, shape) pairs.
    hlo_proto, hlo_text = jax_to_hlo(
        axpy,
        {
            'a': xla_client.Shape('f32[]'),
            'x': xla_client.Shape('f32[128]'),
            'y': xla_client.Shape('f32[128]'),
        },
    )

    # Check that hlo_text contains a broadcast, add, and multiply. `a` is a
    # scalar (f32[]), so `a * x` needs it broadcast to x's shape.
    self.assertIn('broadcast', hlo_text)
    self.assertIn('add', hlo_text)
    self.assertIn('multiply', hlo_text)

    # TODO(jlebar): Ideally we'd check that hlo_proto can be deserialized to a
    # valid HLO proto, but we don't seem to have access to hlo_pb2 at the
    # moment, so the best we seem to be able to do is check that it's nonempty.
    # Use the TestCase assertion rather than a bare `assert`, which is
    # stripped when Python runs with -O.
    self.assertTrue(hlo_proto)