Example #1
    def test_jax_to_tf_axpy(self):
        tf_proto, tf_text = jax_to_ir.jax_to_tf(axpy, [
            ('y', jax_to_ir.parse_shape_str('f32[128]')),
            ('a', jax_to_ir.parse_shape_str('f32[]')),
            ('x', jax_to_ir.parse_shape_str('f32[128,2]')),
        ])

        # Check that tf debug txt contains a broadcast, add, and multiply.
        self.assertIn('name: "BroadcastTo"', tf_text)
        self.assertIn('name: "AddV2"', tf_text)
        self.assertIn('name: "Mul"', tf_text)

        # Check that we can re-import our graphdef.
        gdef = tf.compat.v1.GraphDef()
        gdef.ParseFromString(tf_proto)
        g = tf.Graph()
        with g.as_default():
            tf.import_graph_def(gdef, name='')

        # Check that the HLO parameters are named as we specified.
        ops = {
            o.name: o
            for o in g.get_operations()
            if o.name in ('y', 'a', 'x', 'jax2tf_out')
        }
        self.assertLen(ops, 4)
        self.assertIdentityOp(ops['y'], [128], jnp.float32)
        self.assertIdentityOp(ops['a'], [], jnp.float32)
        self.assertIdentityOp(ops['x'], [128, 2], jnp.float32)
        self.assertIdentityOp(ops['jax2tf_out'], [128, 2], jnp.float32)
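
The axpy function exercised here (and again in Example 3) is defined elsewhere in the test module. A minimal sketch consistent with the shapes passed in (a: f32[], x: f32[128,2], y: f32[128]) and with the broadcast, add, and multiply assertions would be:

def axpy(a, x, y):
    # Multiply x by the scalar a, then add y broadcast along a new trailing
    # axis, giving an f32[128,2] result that matches jax2tf_out above.
    return a * x + y[:, jnp.newaxis]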
Example #2
    def test_jax_to_hlo_with_constants(self):
        def fn(a, b, x, y):
            return a / b * x + y

        _, hlo_text = jax_to_ir.jax_to_hlo(
            fn,
            input_shapes=[
                ('x', jax_to_ir.parse_shape_str('f32[128]')),
                ('y', jax_to_ir.parse_shape_str('f32[128]')),
            ],
            constants={
                'a': 123456,
                'b': 4,
            })
        # Because we passed `a` and `b` as constants, they get constant-folded away
        # by Python/JAX to a/b = 30864.
        self.assertIn('constant(30864)', hlo_text)
        self.assertNotIn('123456', hlo_text)
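
Supplying constants behaves roughly like binding those arguments before tracing, so only x and y remain as HLO parameters. An illustrative (not exact) equivalent using the same fn:

import functools

# Bind a=123456 and b=4 positionally; the traced function then takes only
# x and y, just as in the constants= form above.
fn_ab = functools.partial(fn, 123456, 4)
_, hlo_text = jax_to_ir.jax_to_hlo(
    fn_ab,
    input_shapes=[
        ('x', jax_to_ir.parse_shape_str('f32[128]')),
        ('y', jax_to_ir.parse_shape_str('f32[128]')),
    ])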
Example #3
    def test_jax_to_hlo_axpy(self):
        hlo_proto, hlo_text = jax_to_ir.jax_to_hlo(axpy, [
            ('y', jax_to_ir.parse_shape_str('f32[128]')),
            ('a', jax_to_ir.parse_shape_str('f32[]')),
            ('x', jax_to_ir.parse_shape_str('f32[128,2]')),
        ])

        # Check that hlo_text contains a broadcast, add, and multiply.
        self.assertIn('broadcast', hlo_text)
        self.assertIn('add', hlo_text)
        self.assertIn('multiply', hlo_text)

        # Check that the HLO parameters are in the order we specified in the
        # jax_to_hlo call.
        self.assertIn('f32[128]{0} parameter(0)', hlo_text)
        self.assertIn('f32[] parameter(1)', hlo_text)
        self.assertIn('f32[128,2]{1,0} parameter(2)', hlo_text)

        # TODO(jlebar): Ideally we'd check that hlo_proto can be deserialized to a
        # valid HLO proto, but we don't seem to have access to hlo_pb2 at the
        # moment, so the best we seem to be able to do is check that it's nonempty.
        assert hlo_proto
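
If the hlo_pb2 bindings mentioned in the TODO were importable, the stronger check might look roughly like the following (hypothetical; the proto returned by jax_to_hlo is a serialized HloModuleProto):

    # Hypothetical, assuming hlo_pb2 is available in the test environment.
    hlo_module = hlo_pb2.HloModuleProto()
    hlo_module.ParseFromString(hlo_proto)
    assert hlo_module.computations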
Example #4
    def test_parse_shape_str_invalid(self):
        with self.assertRaisesRegex(ValueError, 'Invalid shape.*foo'):
            jax_to_ir.parse_shape_str('foo[]')
Example #5
    def assertParsedShape(self, s: str, expected_shape, expected_dtype):
        p = jax_to_ir.parse_shape_str(s)
        self.assertEqual(p.shape, tuple(expected_shape))
        self.assertEqual(p.dtype, expected_dtype)
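
A hypothetical test using this helper, mirroring the shapes that appear elsewhere on this page, might read:

    def test_parse_shape_str(self):
        # Scalar and rank-2 shapes, both float32.
        self.assertParsedShape('f32[]', [], jnp.float32)
        self.assertParsedShape('f32[128,2]', [128, 2], jnp.float32)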
Example #6
def encoder_from_file(config,
                      batch_size=8,
                      encode_length=16,
                      use_bfloat16=True,
                      use_xla_optimizations=True):
    """Generates HLO for just the encoder of the WMT model.

    Args:
      config: A ConfigDict instance.
      batch_size: Batch size.
      encode_length: Max length of an input sentence.
      use_bfloat16: Use bfloat16 mixed precision training instead of float32.
      use_xla_optimizations: Whether to use XLA optimizations.
    """
    if FLAGS.checkpoint:
        raise app.UsageError('Checkpoints not yet supported for WMT encoder.')

    input_shape = (batch_size, encode_length)
    rng = jax.random.PRNGKey(0)
    hparams = hparams_utils.load_dataclass_from_config_dict(
        training_hparams.TrainingHParams, config)
    model_hparams = hparams.model_hparams
    model = models.Encoder(vocab_size=32711,
                           hparams=model_hparams.encoder,
                           shared_embedding=None,
                           use_bfloat16=use_bfloat16,
                           emb_dim=model_hparams.emb_dim,
                           num_heads=model_hparams.num_heads,
                           qkv_dim=model_hparams.qkv_dim,
                           mlp_dim=model_hparams.mlp_dim,
                           max_len=encode_length,
                           train=False,
                           dropout_rate=0.1,
                           attention_dropout_rate=0.1,
                           quant_context=quant_config.QuantContext(
                               update_bounds=False,
                               collect_acts_stats=False,
                               quantize_acts=True))
    init_state = model.init(rng, jnp.ones(input_shape, jnp.float32))

    def _fn(state, inputs):
        return model.apply(state, inputs, mutable=False)

    if not use_xla_optimizations:
        computation = jax.xla_computation(_fn)(
            init_state, jnp.ones(input_shape, jnp.float32))
        hlo_utils.output_hlo(computation, FLAGS.hlo_output)

    else:

        def _wrapped_fn(inputs):
            return _fn(init_state, inputs)

        def to_shape_str(shape_tuple):
            return 'f32[%s]' % ','.join(map(str, shape_tuple))

        hlo_module_proto_str, hlo_txt = jax_to_ir.jax_to_hlo(
            _wrapped_fn,
            [('inputs', jax_to_ir.parse_shape_str(to_shape_str(input_shape)))])
        hlo_utils.output_hlo_to_file(hlo_module_proto_str, hlo_txt,
                                     FLAGS.hlo_output)
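
For the default batch_size=8 and encode_length=16, the helper builds the shape string 'f32[8,16]', which parse_shape_str turns back into a concrete input spec:

    to_shape_str((8, 16))                    # -> 'f32[8,16]'
    jax_to_ir.parse_shape_str('f32[8,16]')   # shape (8, 16), dtype float32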