예제 #1
0
  def test_convert_axpy(self):
    """Convert a broadcasting axpy function to HLO and inspect the result.

    The input shapes are passed to jax_to_hlo in the order (y, a, x), which
    differs from the function signature (a, x, y); the assertions below check
    that the HLO parameter numbering follows the order passed to jax_to_hlo.
    """

    def axpy(a, x, y):
      # y gets a new trailing axis before being added to a * x.
      return a * x + y[:, jnp.newaxis]

    hlo_proto, hlo_text = jax_to_hlo(
        axpy, [
            ('y', xla_client.Shape('f32[128]')),
            ('a', xla_client.Shape('f32[]')),
            ('x', xla_client.Shape('f32[128,2]')),
        ])

    # Check that hlo_text contains a broadcast, add, and multiply.
    self.assertIn('broadcast', hlo_text)
    self.assertIn('add', hlo_text)
    self.assertIn('multiply', hlo_text)

    # Check that the HLO parameters are in the order we specified in the
    # jax_to_hlo call.
    self.assertIn('f32[128]{0} parameter(0)', hlo_text)
    self.assertIn('f32[] parameter(1)', hlo_text)
    self.assertIn('f32[128,2]{1,0} parameter(2)', hlo_text)

    # TODO(jlebar): Ideally we'd check that hlo_proto can be deserialized to a
    # valid HLO proto, but we don't seem to have access to hlo_pb2 at the
    # moment, so the best we seem to be able to do is check that it's nonempty.
    # Use assertTrue rather than a bare `assert`, which is stripped under
    # `python -O` and would silently skip this check.
    self.assertTrue(hlo_proto)
예제 #2
0
    def test_convert_with_constants(self):
        """Check that arguments supplied as constants are folded into the HLO.

        `a` and `b` are passed through the `constants` argument while `x` and
        `y` are passed as shaped inputs; the emitted HLO text is expected to
        contain the folded quotient and not the original value of `a`.
        """
        def fn(a, b, x, y):
            return a / b * x + y

        # Only x and y are described by shapes; a and b are fixed values.
        shaped_inputs = [
            ('x', xla_client.Shape('f32[128]')),
            ('y', xla_client.Shape('f32[128]')),
        ]
        fixed_values = {
            'a': 123456,
            'b': 4,
        }

        _proto, hlo_text = jax_to_hlo(
            fn, input_shapes=shaped_inputs, constants=fixed_values)

        # Because we passed `a` and `b` as constants, they get constant-folded
        # away by Python/JAX to a/b = 30864.
        self.assertIn('constant(30864)', hlo_text)
        self.assertNotIn('123456', hlo_text)
예제 #3
0
    def test_convert_axpy(self):
        """Convert a simple axpy function to HLO and inspect the result.

        The shapes for `a` (scalar), `x` and `y` (both f32[128]) are passed to
        jax_to_hlo as a dict keyed by argument name.
        """
        def axpy(a, x, y):
            return a * x + y

        hlo_proto, hlo_text = jax_to_hlo(
            axpy, {
                'a': xla_client.Shape('f32[]'),
                'x': xla_client.Shape('f32[128]'),
                'y': xla_client.Shape('f32[128]'),
            })

        # Check that hlo_text contains a broadcast, add, and multiply.
        self.assertIn('broadcast', hlo_text)
        self.assertIn('add', hlo_text)
        self.assertIn('multiply', hlo_text)

        # TODO(jlebar): Ideally we'd check that hlo_proto can be deserialized to a
        # valid HLO proto, but we don't seem to have access to hlo_pb2 at the
        # moment, so the best we seem to be able to do is check that it's nonempty.
        # Use assertTrue rather than a bare `assert`, which is stripped under
        # `python -O` and would silently skip this check.
        self.assertTrue(hlo_proto)