コード例 #1
0
ファイル: jax_to_hlo.py プロジェクト: jheek/jax
def main(argv):
  """Command-line entry point: lowers a JAX function to HLO and writes it out.

  Configuration is read from FLAGS:
    --fn: dotted path ``module.fn_name`` of the function to lower; the module
      is imported with ``importlib`` and the function fetched with ``getattr``.
    --input_shapes: Python dict literal mapping argument names to shape
      strings, each parsed with ``xla_client.Shape``.
    --hlo_proto_dest / --hlo_text_dest: output paths for the serialized HLO
      proto (written in binary mode) and the HLO text; at least one is
      required.

  Args:
    argv: Positional arguments as passed by ``app.run``; must contain exactly
      one element (conventionally the program name).

  Raises:
    app.UsageError: if positional arguments are given, if neither output
      destination is set, or if --fn is not of the form ``module.fn_name``.
  """
  if len(argv) != 1:
    raise app.UsageError('No positional arguments are accepted.')

  # UsageError is a subclass of app.Error, so existing handlers still match,
  # but app.run reports it with the usage summary instead of a traceback.
  if not FLAGS.hlo_proto_dest and not FLAGS.hlo_text_dest:
    raise app.UsageError('At least one of --hlo_proto_dest and '
                         '--hlo_text_dest is required.')

  # rpartition never raises, so an unset or malformed --fn (no dot, or an
  # empty module/function part) becomes a clear usage error instead of an
  # opaque unpacking ValueError / AttributeError.
  module_name, _, fn_name = (FLAGS.fn or '').rpartition('.')
  if not module_name or not fn_name:
    raise app.UsageError(
        '--fn must be a dotted path of the form module.fn_name, got %r'
        % FLAGS.fn)
  module = importlib.import_module(module_name)
  fn = getattr(module, fn_name)

  # literal_eval only accepts Python literals, so the flag cannot run code.
  input_shapes = {
    name: xla_client.Shape(shape_str)
    for name, shape_str in literal_eval(FLAGS.input_shapes).items()
  }
  hlo_proto, hlo_text = jax_to_hlo(fn, input_shapes)

  if FLAGS.hlo_proto_dest:
    with open(FLAGS.hlo_proto_dest, 'wb') as f:
      f.write(hlo_proto)

  if FLAGS.hlo_text_dest:
    # Explicit encoding: the default text encoding is locale-dependent.
    with open(FLAGS.hlo_text_dest, 'w', encoding='utf-8') as f:
      f.write(hlo_text)
コード例 #2
0
ファイル: jax_to_hlo.py プロジェクト: xwinxu/jax
def main(argv):
    """Command-line entry point: lowers a JAX function to HLO and writes it out.

    Configuration is read from FLAGS:
      --fn: dotted path ``module.fn_name`` of the function to lower; the
        module is imported with ``importlib`` and the function fetched with
        ``getattr``.
      --input_shapes: Python literal for a sequence of
        (argument name, shape string) pairs; each shape string is parsed with
        ``xla_client.Shape`` and the pair order is kept.
      --constants: Python dict literal of argument name -> constant value;
        list values are converted with ``np.asarray``.
      --evaled_constants: like --constants, except that string values are
        first parsed with ``literal_eval`` (lists are then converted too).
      --hlo_proto_dest / --hlo_text_dest: output paths for the serialized HLO
        proto (written in binary mode) and the HLO text; at least one is
        required.

    Args:
      argv: Positional arguments as passed by ``app.run``; must contain
        exactly one element (conventionally the program name).

    Raises:
      app.UsageError: if positional arguments are given, if neither output
        destination is set, or if --fn is not of the form ``module.fn_name``.
      ValueError: if a name appears in both --constants and
        --evaled_constants.
    """
    if len(argv) != 1:
        raise app.UsageError('No positional arguments are accepted.')

    # UsageError is a subclass of app.Error, so existing handlers still match,
    # but app.run reports it with the usage summary instead of a traceback.
    if not FLAGS.hlo_proto_dest and not FLAGS.hlo_text_dest:
        raise app.UsageError('At least one of --hlo_proto_dest and '
                             '--hlo_text_dest is required.')

    # rpartition never raises, so an unset or malformed --fn (no dot, or an
    # empty module/function part) becomes a clear usage error instead of an
    # opaque unpacking ValueError / AttributeError.
    module_name, _, fn_name = (FLAGS.fn or '').rpartition('.')
    if not module_name or not fn_name:
        raise app.UsageError(
            '--fn must be a dotted path of the form module.fn_name, got %r'
            % FLAGS.fn)
    module = importlib.import_module(module_name)
    fn = getattr(module, fn_name)

    # Kept as an ordered list of (name, shape) pairs, in flag order.
    input_shapes = [(name, xla_client.Shape(shape_str))
                    for name, shape_str in literal_eval(FLAGS.input_shapes)]

    # Parse --constants and --evaled_constants.
    constants = {}
    for k, v in literal_eval(FLAGS.constants).items():
        if isinstance(v, list):
            v = np.asarray(v)
        constants[k] = v

    for k, v in literal_eval(FLAGS.evaled_constants).items():
        # Detect the conflict before evaluating v, so a duplicated name is
        # reported even when its evaled value is itself malformed.
        if k in constants:
            raise ValueError(
                'Argument appears in both --constants and --evaled_constants: %s'
                % k)
        if isinstance(v, str):
            v = literal_eval(v)
        if isinstance(v, list):
            v = np.asarray(v)
        constants[k] = v

    hlo_proto, hlo_text = jax_to_hlo(fn, input_shapes, constants)

    if FLAGS.hlo_proto_dest:
        with open(FLAGS.hlo_proto_dest, 'wb') as f:
            f.write(hlo_proto)

    if FLAGS.hlo_text_dest:
        # Explicit encoding: the default text encoding is locale-dependent.
        with open(FLAGS.hlo_text_dest, 'w', encoding='utf-8') as f:
            f.write(hlo_text)