def main(argv):
  """Lowers the function named by --fn to HLO and writes the results.

  The function is located by importing the module part of --fn and looking up
  its final attribute. Its argument shapes come from --input_shapes, which is
  parsed with `literal_eval` as a dict mapping argument name to a shape string
  for `xla_client.Shape`. The (proto, text) pair returned by `jax_to_hlo` is
  written to --hlo_proto_dest (binary) and/or --hlo_text_dest (text).

  Args:
    argv: Arguments left over after flag parsing; only the program name
      (argv[0]) is allowed.

  Raises:
    app.UsageError: If positional arguments are given, if neither
      --hlo_proto_dest nor --hlo_text_dest is set, or if --fn is not of the
      form 'module.function'.
  """
  if len(argv) != 1:
    raise app.UsageError('No positional arguments are accepted.')
  if not FLAGS.hlo_proto_dest and not FLAGS.hlo_text_dest:
    # A missing required flag is a usage problem; UsageError (a subclass of
    # app.Error) lets absl print the usage text instead of a traceback.
    raise app.UsageError('At least one of --hlo_proto_dest and '
                         '--hlo_text_dest is required.')

  # Validate --fn up front. A bare rsplit/unpack would otherwise fail with an
  # opaque unpacking error, an empty module import, or an empty getattr.
  module_name, _, fn_name = FLAGS.fn.rpartition('.')
  if not module_name or not fn_name:
    raise app.UsageError(
        "--fn must be of the form 'module.function', got %r." % FLAGS.fn)
  module = importlib.import_module(module_name)
  fn = getattr(module, fn_name)

  input_shapes = {
      name: xla_client.Shape(shape_str)
      for name, shape_str in literal_eval(FLAGS.input_shapes).items()
  }

  hlo_proto, hlo_text = jax_to_hlo(fn, input_shapes)

  if FLAGS.hlo_proto_dest:
    with open(FLAGS.hlo_proto_dest, 'wb') as f:
      f.write(hlo_proto)

  if FLAGS.hlo_text_dest:
    # Use an explicit encoding so the output does not depend on the locale.
    with open(FLAGS.hlo_text_dest, 'w', encoding='utf-8') as f:
      f.write(hlo_text)
def main(argv):
  """Lowers the function named by --fn to HLO and writes the results.

  The function is located by importing the module part of --fn and looking up
  its final attribute.

  --input_shapes is parsed with `literal_eval` as a sequence of
  (name, shape_str) pairs. Each shape string is wrapped in `xla_client.Shape`,
  and the pair order is kept.

  --constants and --evaled_constants are parsed with `literal_eval` as dicts
  mapping argument name to value:
    * In both flags, list values are converted with `np.asarray`.
    * In --evaled_constants only, string values are themselves passed through
      `literal_eval` a second time, before the list conversion.

  The (proto, text) pair returned by `jax_to_hlo` is written to
  --hlo_proto_dest (binary) and/or --hlo_text_dest (text).

  Args:
    argv: Arguments left over after flag parsing; only the program name
      (argv[0]) is allowed.

  Raises:
    app.UsageError: If positional arguments are given, if neither
      --hlo_proto_dest nor --hlo_text_dest is set, or if --fn is not of the
      form 'module.function'.
    ValueError: If a name appears in both --constants and --evaled_constants.
  """
  if len(argv) != 1:
    raise app.UsageError('No positional arguments are accepted.')
  if not FLAGS.hlo_proto_dest and not FLAGS.hlo_text_dest:
    # A missing required flag is a usage problem; UsageError (a subclass of
    # app.Error) lets absl print the usage text instead of a traceback.
    raise app.UsageError('At least one of --hlo_proto_dest and '
                         '--hlo_text_dest is required.')

  # Validate --fn up front. A bare rsplit/unpack would otherwise fail with an
  # opaque unpacking error, an empty module import, or an empty getattr.
  module_name, _, fn_name = FLAGS.fn.rpartition('.')
  if not module_name or not fn_name:
    raise app.UsageError(
        "--fn must be of the form 'module.function', got %r." % FLAGS.fn)
  module = importlib.import_module(module_name)
  fn = getattr(module, fn_name)

  # A list (not a dict) keeps the order of the pairs as written in the flag.
  input_shapes = [(name, xla_client.Shape(shape_str))
                  for name, shape_str in literal_eval(FLAGS.input_shapes)]

  # Parse --constants and --evaled_constants.
  constants = {}
  for k, v in literal_eval(FLAGS.constants).items():
    if isinstance(v, list):
      v = np.asarray(v)
    constants[k] = v

  for k, v in literal_eval(FLAGS.evaled_constants).items():
    # String values hold a Python literal of their own; evaluate it too.
    if isinstance(v, str):
      v = literal_eval(v)
    if isinstance(v, list):
      v = np.asarray(v)
    if k in constants:
      raise ValueError(
          'Argument appears in both --constants and --evaled_constants: %s' % k)
    constants[k] = v

  hlo_proto, hlo_text = jax_to_hlo(fn, input_shapes, constants)

  if FLAGS.hlo_proto_dest:
    with open(FLAGS.hlo_proto_dest, 'wb') as f:
      f.write(hlo_proto)

  if FLAGS.hlo_text_dest:
    # Use an explicit encoding so the output does not depend on the locale.
    with open(FLAGS.hlo_text_dest, 'w', encoding='utf-8') as f:
      f.write(hlo_text)