def test_convert_axpy(self):
    """jax_to_hlo lowers axpy with HLO parameters in the caller's order."""

    def axpy(a, x, y):
        return a * x + y[:, jnp.newaxis]

    hlo_proto, hlo_text = jax_to_hlo(
        axpy, [
            ('y', xla_client.Shape('f32[128]')),
            ('a', xla_client.Shape('f32[]')),
            ('x', xla_client.Shape('f32[128,2]')),
        ])

    # Check that hlo_text contains a broadcast, add, and multiply.
    self.assertIn('broadcast', hlo_text)
    self.assertIn('add', hlo_text)
    self.assertIn('multiply', hlo_text)

    # Check that the HLO parameters are in the order we specified in the
    # jax_to_hlo call (y, a, x), not the order of axpy's signature (a, x, y).
    self.assertIn('f32[128]{0} parameter(0)', hlo_text)
    self.assertIn('f32[] parameter(1)', hlo_text)
    self.assertIn('f32[128,2]{1,0} parameter(2)', hlo_text)

    # TODO(jlebar): Ideally we'd check that hlo_proto can be deserialized to a
    # valid HLO proto, but we don't seem to have access to hlo_pb2 at the
    # moment, so the best we seem to be able to do is check that it's nonempty.
    # Use the unittest assertion rather than a bare `assert`, which is
    # stripped under `python -O` and bypasses the test framework's reporting.
    self.assertTrue(hlo_proto)
def test_convert_with_constants(self):
    """Values passed via `constants` are folded before reaching the HLO."""

    def fn(a, b, x, y):
        return a / b * x + y

    shapes = [
        ('x', xla_client.Shape('f32[128]')),
        ('y', xla_client.Shape('f32[128]')),
    ]
    folded_constants = {'a': 123456, 'b': 4}

    _unused_proto, hlo_text = jax_to_hlo(
        fn, input_shapes=shapes, constants=folded_constants)

    # `a` and `b` are bound as constants rather than traced inputs, so
    # Python/JAX evaluates a / b up front (123456 / 4 == 30864). Only the
    # quotient should appear in the HLO, never the original numerator.
    self.assertIn('constant(30864)', hlo_text)
    self.assertNotIn('123456', hlo_text)
def main(argv):
    """Lower the function named by --fn to HLO and write it to disk.

    Configuration comes from absl FLAGS:
      --fn: dotted path `module.function` of the function to lower.
      --input_shapes: Python literal holding (name, shape_string) pairs; each
        shape string is parsed with `xla_client.Shape`.
      --constants: Python-literal dict of name -> value; list values are
        converted with `np.asarray`.
      --evaled_constants: like --constants, except that string values are
        themselves parsed with `literal_eval` first.
      --hlo_proto_dest / --hlo_text_dest: output paths; at least one must be
        set.

    Args:
      argv: arguments remaining after flag parsing. Only the program name
        (argv[0]) is accepted.

    Raises:
      app.UsageError: on positional arguments, when neither output flag is
        set, or when --fn is not of the form `module.function`.
      ValueError: if a name appears in both --constants and
        --evaled_constants.
    """
    if len(argv) != 1:
        raise app.UsageError('No positional arguments are accepted.')
    if not FLAGS.hlo_proto_dest and not FLAGS.hlo_text_dest:
        # UsageError subclasses app.Error, so existing handlers still match,
        # and absl's app.run reports it as a usage error, not a traceback.
        raise app.UsageError('At least one of --hlo_proto_dest and '
                             '--hlo_text_dest is required.')

    # rpartition never raises. Check both halves explicitly so that a
    # malformed --fn gives a clear message instead of an unpacking,
    # import, or getattr error.
    module_name, sep, fn_name = FLAGS.fn.rpartition('.')
    if not sep or not module_name or not fn_name:
        raise app.UsageError(
            '--fn must be of the form module.function, got: %r' % FLAGS.fn)
    module = importlib.import_module(module_name)
    fn = getattr(module, fn_name)

    input_shapes = [(name, xla_client.Shape(shape_str))
                    for name, shape_str in literal_eval(FLAGS.input_shapes)]

    # Parse --constants and --evaled_constants. literal_eval only accepts
    # Python literals, so neither flag can execute arbitrary code.
    constants = {}
    for k, v in literal_eval(FLAGS.constants).items():
        if isinstance(v, list):
            v = np.asarray(v)
        constants[k] = v

    for k, v in literal_eval(FLAGS.evaled_constants).items():
        # String values carry a nested literal that needs a second parse.
        if isinstance(v, str):
            v = literal_eval(v)
        if isinstance(v, list):
            v = np.asarray(v)
        if k in constants:
            raise ValueError(
                'Argument appears in both --constants and --evaled_constants: %s'
                % k)
        constants[k] = v

    hlo_proto, hlo_text = jax_to_hlo(fn, input_shapes, constants)

    if FLAGS.hlo_proto_dest:
        with open(FLAGS.hlo_proto_dest, 'wb') as f:
            f.write(hlo_proto)

    if FLAGS.hlo_text_dest:
        # Pin the encoding so the output does not depend on the platform
        # locale.
        with open(FLAGS.hlo_text_dest, 'w', encoding='utf-8') as f:
            f.write(hlo_text)