Exemplo n.º 1
0
  def test_convert_axpy(self):
    """jax_to_hlo lowers axpy and numbers parameters in the requested order."""

    def axpy(a, x, y):
      return a * x + y[:, jnp.newaxis]

    hlo_proto, hlo_text = jax_to_hlo(
        axpy, [
            ('y', xla_client.Shape('f32[128]')),
            ('a', xla_client.Shape('f32[]')),
            ('x', xla_client.Shape('f32[128,2]')),
        ])

    # Check that hlo_text contains a broadcast, add, and multiply.
    self.assertIn('broadcast', hlo_text)
    self.assertIn('add', hlo_text)
    self.assertIn('multiply', hlo_text)

    # Check that the HLO parameters are numbered in the order given to the
    # jax_to_hlo call (y, a, x), not in axpy's positional order (a, x, y).
    self.assertIn('f32[128]{0} parameter(0)', hlo_text)
    self.assertIn('f32[] parameter(1)', hlo_text)
    self.assertIn('f32[128,2]{1,0} parameter(2)', hlo_text)

    # TODO(jlebar): Ideally we'd check that hlo_proto can be deserialized to a
    # valid HLO proto, but we don't seem to have access to hlo_pb2 at the
    # moment, so the best we seem to be able to do is check that it's nonempty.
    # Use the unittest assertion (not a bare `assert`, which `python -O` strips).
    self.assertTrue(hlo_proto)
Exemplo n.º 2
0
    def test_convert_with_constants(self):
        """Values passed via `constants` are folded into the emitted HLO."""

        def fn(a, b, x, y):
            return a / b * x + y

        # Only `x` and `y` become HLO parameters; `a` and `b` are supplied as
        # concrete values through `constants`.
        shapes = [
            ('x', xla_client.Shape('f32[128]')),
            ('y', xla_client.Shape('f32[128]')),
        ]
        folded = {
            'a': 123456,
            'b': 4,
        }
        _proto, hlo_text = jax_to_hlo(fn, input_shapes=shapes, constants=folded)

        # a / b is computed from the concrete constants (123456 / 4 == 30864),
        # so only the quotient should appear in the HLO, never the numerator.
        self.assertIn('constant(30864)', hlo_text)
        self.assertNotIn('123456', hlo_text)
Exemplo n.º 3
0
def _as_array_if_list(value):
    """Return `value` converted with `np.asarray` if it is a list, else as-is."""
    return np.asarray(value) if isinstance(value, list) else value


def _parse_constants(constants_str, evaled_constants_str):
    """Parse the --constants and --evaled_constants flag values into one dict.

    Args:
      constants_str: Python-literal dict string; each value is used as
        parsed, except that lists are converted to NumPy arrays.
      evaled_constants_str: Python-literal dict string; string values are
        parsed a second time with `literal_eval`, then lists are converted
        to NumPy arrays.

    Returns:
      A dict mapping argument name to constant value.

    Raises:
      ValueError: if a name appears in both flag dicts.
    """
    constants = {
        k: _as_array_if_list(v) for k, v in literal_eval(constants_str).items()
    }

    for k, v in literal_eval(evaled_constants_str).items():
        if isinstance(v, str):
            v = literal_eval(v)
        v = _as_array_if_list(v)
        if k in constants:
            raise ValueError(
                'Argument appears in both --constants and --evaled_constants: %s'
                % k)
        constants[k] = v

    return constants


def main(argv):
    """Lower the function named by --fn to HLO and write the requested outputs.

    Reads --fn, --input_shapes, --constants, --evaled_constants,
    --hlo_proto_dest and --hlo_text_dest from FLAGS.

    Args:
      argv: Command-line arguments left after flag parsing; only the program
        name is allowed.

    Raises:
      app.UsageError: if positional arguments are given, or if --fn is not a
        dotted "module.function" name.
      app.Error: if neither --hlo_proto_dest nor --hlo_text_dest is set.
      ValueError: if a name appears in both --constants and --evaled_constants.
    """
    if len(argv) != 1:
        raise app.UsageError('No positional arguments are accepted.')

    if not FLAGS.hlo_proto_dest and not FLAGS.hlo_text_dest:
        raise app.Error('At least one of --hlo_proto_dest and '
                        '--hlo_text_dest is required.')

    # --fn must be "module.path.fn_name". Without this check, a name with no
    # dot fails with an opaque unpacking error, and a leading or trailing dot
    # fails later inside import_module/getattr.
    module_name, _, fn_name = FLAGS.fn.rpartition('.')
    if not module_name or not fn_name:
        raise app.UsageError(
            '--fn must be a fully-qualified name like "module.fn", got %r'
            % FLAGS.fn)
    module = importlib.import_module(module_name)
    fn = getattr(module, fn_name)

    input_shapes = [(name, xla_client.Shape(shape_str))
                    for name, shape_str in literal_eval(FLAGS.input_shapes)]

    constants = _parse_constants(FLAGS.constants, FLAGS.evaled_constants)

    hlo_proto, hlo_text = jax_to_hlo(fn, input_shapes, constants)

    if FLAGS.hlo_proto_dest:
        with open(FLAGS.hlo_proto_dest, 'wb') as f:
            f.write(hlo_proto)

    if FLAGS.hlo_text_dest:
        # Use an explicit encoding so the output does not depend on the locale.
        with open(FLAGS.hlo_text_dest, 'w', encoding='utf-8') as f:
            f.write(hlo_text)