Example no. 1
        def _WaitTillInit(job=None):
            """Wait until the model is ready."""
            try:
                if py_utils.IsEagerMode():
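                    # Eager mode: initialize the TPU system directly through the cluster resolver.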
                    topology = tf.tpu.experimental.initialize_tpu_system(
                        resolver)
                else:
                    # tpu.initialize_system() is called with None as embedding_config, as
                    # embedding_config is not available yet. Later in _Loop, it is called
                    # with the correct embedding_config. Since it cannot be called twice
                    # in the same graph with different embedding_config, we use a
                    # dummy_graph here.
                    dummy_graph = tf.Graph()
                    with dummy_graph.as_default():
                        tpu_initialize_system_op = tf.tpu.initialize_system(
                            embedding_config=None, job=job)

                    with self._GetSession(graph=dummy_graph) as sess:
                        topology = sess.run(tpu_initialize_system_op)

                if train_cfg.train.tpu_computation_shape is None:
                    computation_shape = py_utils.ComputationShape(
                        num_devices_per_split, topology)
                else:
                    computation_shape = train_cfg.train.tpu_computation_shape
                    assert num_devices_per_split == np.prod(computation_shape)

                if train_cfg.train.tpu_device_order_mode is None:
                    self.device_assignment = device_assignment_lib.device_assignment(
                        topology,
                        computation_shape=computation_shape,
                        num_replicas=data_parallelism)
                else:
                    self.device_assignment = device_assignment_lib.device_assignment(
                        topology,
                        computation_shape=computation_shape,
                        num_replicas=data_parallelism,
                        device_order_mode=train_cfg.train.tpu_device_order_mode
                    )
                py_utils.SetTpuDeviceAssignment(self.device_assignment, job)
                tf.logging.info('device_assignment.core_assignment: %s',
                                str(self.device_assignment.core_assignment))
                tf.logging.info(
                    'device_assignment.topology.device_coordinates: %s',
                    str(self.device_assignment.topology.device_coordinates))
            except py_utils.transient_tf_errors as e:
                tf.logging.info('TPU initialization failed: %s', e)
                raise
Example no. 2
    def _get_device_assignment(self):
        """Gets the (maybe cached) TPU device assignment."""
        master = self._get_master_address()
        device_assignment = self._lazy_device_assignment_dict.get(master)
        if device_assignment is not None:
            return device_assignment

        tpu_system_metadata = self._get_tpu_system_metadata()

        device_assignment = tpu_device_assignment.device_assignment(
            tpu_system_metadata.topology,
            computation_shape=self._computation_shape,
            num_replicas=self.num_replicas)

        tf.compat.v1.logging.info(
            'num_cores_per_replica: %s',
            str(self._config.tpu_config.num_cores_per_replica))
        tf.compat.v1.logging.info('computation_shape: %s',
                                  str(self._computation_shape))
        tf.compat.v1.logging.info('num_replicas: %d', self.num_replicas)
        tf.compat.v1.logging.info(
            'device_assignment.topology.device_coordinates: %s',
            str(device_assignment.topology.device_coordinates))
        tf.compat.v1.logging.info('device_assignment.core_assignment: %s',
                                  str(device_assignment.core_assignment))

        self._lazy_device_assignment_dict[master] = device_assignment
        return device_assignment
Example no. 3
  def _get_device_assignment(self):
    """Gets the (maybe cached) TPU device assignment."""
    master = self._get_master_address()
    device_assignment = self._lazy_device_assignment_dict.get(master)
    if device_assignment is not None:
      return device_assignment

    tpu_system_metadata = self._get_tpu_system_metadata()

    device_assignment = tpu_device_assignment.device_assignment(
        tpu_system_metadata.topology,
        computation_shape=self._computation_shape,
        num_replicas=self.num_replicas)

    logging.info('num_cores_per_replica: %s',
                 str(self._config.tpu_config.num_cores_per_replica))
    logging.info('computation_shape: %s', str(self._computation_shape))
    logging.info('num_replicas: %d', self.num_replicas)
    logging.info('device_assignment.topology.device_coordinates: %s',
                 str(device_assignment.topology.device_coordinates))
    logging.info('device_assignment.core_assignment: %s',
                 str(device_assignment.core_assignment))

    self._lazy_device_assignment_dict[master] = device_assignment
    return device_assignment
Example no. 4
    def test_adam(self):
        self.lowering = None

        def create_computation_fn(device_assignment):
            def computation_fn():
                graph = mtf.Graph()
                mesh = mtf.Mesh(graph, 'my_mesh')
                mesh_shape = mtf.convert_to_shape('all:2')
                layout = 'none:all'
                mesh_devices = [''] * mesh_shape.size
                mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                    mesh_shape, mtf.convert_to_layout_rules(layout),
                    mesh_devices, device_assignment)
                hidden_dim = mtf.Dimension('hidden', 3)
                w = mtf.get_variable(mesh,
                                     'w',
                                     shape=[hidden_dim],
                                     initializer=tf.constant_initializer(
                                         [0.1, -0.2, -0.1]))
                x = mtf.constant(mesh, [0.4, 0.2, -0.5], [hidden_dim],
                                 dtype=tf.float32)
                loss = mtf.reduce_mean(mtf.square(x - w))
                var_grads = mtf.gradients(
                    [loss], [v.outputs[0] for v in graph.trainable_variables])
                optimizer = mtf_optimize.AdamWeightDecayOptimizer(
                    learning_rate=0.2)
                update_ops = optimizer.apply_grads(var_grads,
                                                   graph.trainable_variables)
                self.lowering = mtf.Lowering(graph, {mesh: mesh_impl})
                tf_update_ops = [
                    self.lowering.lowered_operation(op) for op in update_ops
                ]
                return tf.group(tf_update_ops)

            return computation_fn

        with self.test_session() as sess:
            topology = sess.run(tf.tpu.initialize_system())
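            # Two replicas, each using a single TPU core ([1, 1, 1] computation shape).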
            device_assignment = tpu_device_assignment.device_assignment(
                topology, computation_shape=[1, 1, 1], num_replicas=2)
            tpu_computation_fn = tf.tpu.batch_parallel(
                create_computation_fn(device_assignment),
                inputs=None,
                num_shards=2,
                infeed_queue=None,
                device_assignment=device_assignment)

            sess.run(tf.global_variables_initializer())
            sess.run(self.lowering.copy_masters_to_slices())

            for _ in range(100):
                sess.run(tpu_computation_fn)

            sess.run(self.lowering.copy_slices_to_masters())
            w_np = sess.run(tf.global_variables()[0])
            self.assertAllClose([0.4, 0.2, -0.5],
                                w_np.flat,
                                rtol=1e-2,
                                atol=1e-2)
            sess.run(tf.tpu.shutdown_system())
Example no. 5
    def test_optimizer(self):
        self.lowering = None

        def create_computation_fn(device_assignment):
            def computation_fn():
                graph = mtf.Graph()
                mesh = mtf.Mesh(graph, 'my_mesh')
                mesh_shape = mtf.convert_to_shape('all:2')
                layout = 'none:all'
                mesh_devices = [''] * mesh_shape.size
                mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                    mesh_shape, mtf.convert_to_layout_rules(layout),
                    mesh_devices, device_assignment)
                hidden_dim = mtf.Dimension('hidden', 3)
                w = mtf.get_variable(mesh,
                                     'w',
                                     shape=[hidden_dim],
                                     initializer=tf.constant_initializer(
                                         [0.1, -0.2, -0.1]))
                x = mtf.constant(mesh, [0.4, 0.2, -0.5], [hidden_dim],
                                 dtype=tf.float32)
                loss = mtf.reduce_mean(mtf.square(x - w))

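                # The learning-rate values asserted at the end of the test correspond to
                # a linear warmup over the first 10 steps followed by decay to zero by step 100.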
                lr, update_ops = optimization_lib.create_optimizer(
                    loss, 0.2, 100, 10)
                self.lowering = mtf.Lowering(graph, {mesh: mesh_impl})

                tf_update_ops = [
                    self.lowering.lowered_operation(op) for op in update_ops
                ]
                tf_update_ops.append(
                    tf.assign_add(tf.train.get_or_create_global_step(), 1))
                train_op = tf.group(tf_update_ops)

                return lr, train_op

            return computation_fn

        with self.test_session() as sess:
            topology = sess.run(tf.tpu.initialize_system())
            device_assignment = tpu_device_assignment.device_assignment(
                topology, computation_shape=[1, 1, 1], num_replicas=2)
            tpu_computation_fn = tf.tpu.batch_parallel(
                create_computation_fn(device_assignment),
                inputs=None,
                num_shards=2,
                infeed_queue=None,
                device_assignment=device_assignment)

            sess.run(tf.global_variables_initializer())
            sess.run(self.lowering.copy_masters_to_slices())
            lrs = []
            for _ in range(100):
                lr = sess.run(tpu_computation_fn)
                lrs.append(lr[0][0])
            self.assertAllClose(0.02, lrs[0])
            self.assertAllClose(0.18, lrs[8])
            self.assertAllClose(0., lrs[99])
            sess.run(tf.tpu.shutdown_system())
Example no. 6
 def test_infeed_uneven_partition(self):
   """Tests uneven infeed tensors partition."""
   ds = device_assignment(
       self._topology_2x2x2, num_replicas=1, computation_shape=[2, 2, 1, 2])
   input_partition_dims = [[4, 2]]
   # pylint: disable=protected-access
   partitioned_infeed = tpu_feed._PartitionedInfeedQueue(
       number_of_tuple_elements=1,
       host_id=0,
       input_partition_dims=input_partition_dims,
       device_assignment=ds)
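   # Partitioning a (14, 5) tensor with dims [4, 2] yields 4 * 2 = 8 host tensors;
   # the trailing partition is only (2, 2) because neither dimension divides evenly.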
   x = array_ops.zeros((14, 5))
   tensors = partitioned_infeed._check_dims_and_partition_or_replicate_on_host(
       x, dims=input_partition_dims[0])
   self.assertEqual(8, len(tensors))
   self.assertEqual((2, 2), tensors[-1].shape)
Example no. 7
 def test_infeed_tailing_zero_partition(self):
   """Tests infeed tensors partition which causes zero-size tensors."""
   ds = device_assignment(
       self._topology_2x2x2, num_replicas=1, computation_shape=[1, 2, 1, 2])
   input_partition_dims = [[4, 1]]
   # pylint: disable=protected-access
   partitioned_infeed = tpu_feed._PartitionedInfeedQueue(
       number_of_tuple_elements=1,
       host_id=0,
       input_partition_dims=input_partition_dims,
       device_assignment=ds)
   x = array_ops.zeros((5, 5))
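   # Splitting 5 rows into 4 chunks leaves a zero-size trailing chunk, as asserted below.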
   tensors = partitioned_infeed._check_dims_and_partition_or_replicate_on_host(
       x, dims=input_partition_dims[0])
   self.assertEqual(4, len(tensors))
   self.assertEqual((1, 5), tensors[2].shape)
   self.assertEqual((0, 5), tensors[3].shape)
Example no. 8
 def _WaitTillInit():
   """Wait until the model is ready."""
   try:
     with self._graph.as_default(), self._GetSession(
         cluster_def=self._cluster_def) as sess:
       topology = sess.run(
           tf.tpu.initialize_system(embedding_config=None, job=None))
       device_assignment = device_assignment_lib.device_assignment(
           topology,
           computation_shape=ComputationShape(num_devices_per_split),
           num_replicas=data_parallelism)
       py_utils.SetTpuDeviceAssignment(device_assignment)
       tf.logging.info('device_assignment.core_assignment: %s',
                       str(device_assignment.core_assignment))
       tf.logging.info('device_assignment.topology.device_coordinates: %s',
                       str(device_assignment.topology.device_coordinates))
   except py_utils.transient_tf_errors as e:
     tf.logging.info('TPU initialization failed: %s', e)
     raise
Example no. 9
 def _DeviceAssignment(self):
     """A context for tpu device assignment of a JF 8x8 slice."""
     mesh_shape = [8, 8, 1, 2]
     device_coordinates = np.zeros([16, 8, 4], dtype=np.int32)
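     # Each linear index maps to chip coordinates (x, y) and a core id; every task
     # (host) owns a 2x2 block of chips with two cores per chip.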
     for i in range(np.prod(mesh_shape)):
         x = i // 16
         y = i % 16 // 2
         core = i % 2
         task = x // 2 * 4 + y // 2
         device = x % 2 * 4 + y % 2 * 2 + core
         device_coordinates[task, device] = [x, y, 0, core]
     topology = tf.tpu.experimental.Topology(
         mesh_shape=mesh_shape, device_coordinates=device_coordinates)
     assignment = device_assignment.device_assignment(
         topology, computation_shape=[1, 1, 1, 1], num_replicas=128)
     py_utils.SetTpuDeviceAssignment(assignment)
     try:
         yield
     finally:
         py_utils.SetTpuDeviceAssignment(None)
Example no. 10
    def _init_tpu(self, num_partitions, device_order_mode):
        """Initialize tpu device assignment."""
        tf.logging.info('Initializing TPU to get device assignment: start')

        graph = tf.Graph()
        with graph.as_default():
            init_tpu_op = tf.tpu.initialize_system()
        try:
            sess = tf.Session(target=self._tpu,
                              graph=graph,
                              config=self._no_opt_sess_cfg())
            topology = sess.run(init_tpu_op)
        except Exception as e:
            tf.logging.fatal('TPU initialization failed: %s', e)
            raise

        topology_proto = topology_pb2.TopologyProto()
        topology_proto.ParseFromString(topology)
        tf.logging.info('topology.num_tasks: %r', topology_proto.num_tasks)
        tf.logging.info('topology.num_tpu_devices_per_task: %r',
                        topology_proto.num_tpu_devices_per_task)
        tf.logging.info('topology.mesh_shape: %r', topology_proto.mesh_shape)
        self.cluster_params = self._configure_cluster_params(
            tpu_cores=(topology_proto.num_tpu_devices_per_task *
                       topology_proto.num_tasks),
            cpu_hosts=topology_proto.num_tasks)

        # We assume the topology and device assignment do not change
        # for a single address space.
        device_assignment = tpu_device_assignment.device_assignment(
            topology,
            computation_shape=py_utils.ComputationShape(
                num_partitions, topology),
            num_replicas=1,
            device_order_mode=device_order_mode)
        py_utils.SetTpuDeviceAssignment(device_assignment)

        tf.logging.info('Initializing TPU to get device assignment: done')
Example no. 11
    def test_get_laidout_tensors(self, is_eval_mode):
        mesh_shape = "mesh_x:2, mesh_y:1"
        layout = "batch:mesh_x, io:mesh_y"
        batch_io_dim = 4

        with tf.Session() as sess:
            topology, num_cores = self.initialize_system(sess)

            # Get a device_assignment object for mtf.
            d_assignment = device_assignment.device_assignment(
                topology,
                computation_shape=[1] * mtf.utils.topology_rank(topology),
                num_replicas=num_cores)

            # Hacked dataset creator: creates different datasets for the first and
            # second call, in order to test SimdMeshImplInputReader.
            self.sub_batch_created_times = 0

            def stateful_ds_creator():
                whole_batch = tf.eye(batch_io_dim, dtype=tf.float32)
                sub_batch = tf.slice(whole_batch,
                                     [self.sub_batch_created_times * 2, 0],
                                     [2, 4])
                self.sub_batch_created_times += 1
                return tf.data.Dataset.from_tensors(
                    sub_batch).repeat().unbatch()

            batch_dim = mtf.Dimension("batch", batch_io_dim)
            io_dim = mtf.Dimension("io", batch_io_dim)
            mtf_input_shapes = [mtf.Shape([batch_dim, io_dim])]

            # Get mesh_impl.
            mesh_shape = mtf.convert_to_shape(mesh_shape)
            layout_rules = mtf.convert_to_layout_rules(layout)
            mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                mesh_shape, layout_rules, None, d_assignment)

            simd_input_reader = input_reader.SimdMeshImplInputReader(
                mesh_impl,
                stateful_ds_creator,
                mtf_input_shapes,
                external_worker=False,
                is_eval_mode=is_eval_mode)

            def model_fn(features):
                return features

            replicated_computation = tpu.replicate(
                computation=model_fn,
                inputs=[[]] * num_cores,
                infeed_queue=simd_input_reader.infeed_queue,
                device_assignment=d_assignment)

            simd_input_reader.start_infeed_thread(sess, 1)
            results = sess.run(replicated_computation)
            print("results: {}".format(results))

            core_0_data = results[0][0]
            core_1_data = results[1][0]
            print("core_0_data: {}".format(core_0_data))
            print("core_1_data: {}".format(core_1_data))

            if is_eval_mode:
                # If there is only one dataset object, then the stateful_ds_creator()
                # should be called only once.
                self.assertAllClose(
                    np.array([[1, 0, 0, 0], [0, 1, 0, 0]], dtype=np.float32),
                    core_0_data)
                self.assertAllClose(
                    np.array([[1, 0, 0, 0], [0, 1, 0, 0]], dtype=np.float32),
                    core_1_data)
            else:
                # If there are two dataset objects, then the stateful_ds_creator()
                # should be called twice.
                self.assertAllClose(
                    np.array([[1, 0, 0, 0], [0, 1, 0, 0]], dtype=np.float32),
                    core_0_data)
                self.assertAllClose(
                    np.array([[0, 0, 1, 0], [0, 0, 0, 1]], dtype=np.float32),
                    core_1_data)

            sess.run(tf.tpu.shutdown_system())
Example no. 12
    def initialize(self,
                   train_input_fn,
                   eval_input_fn,
                   model_fn,
                   train_batch_size,
                   eval_batch_size,
                   input_partition_dims=None,
                   init_fn=None,
                   train_has_labels=True,
                   eval_has_labels=True,
                   params=None,
                   num_partitions=None):
        """Build graphs for the TPU device and the input pipelines."""
        if input_partition_dims:
            num_cores_per_replica = functools.reduce(
                operator.mul, input_partition_dims)
        elif num_partitions:
            num_cores_per_replica = num_partitions
        else:
            num_cores_per_replica = 1

        self.device_assignment = device_assignment.device_assignment(
            topology=self.device_topology,
            computation_shape=_NUM_CORES_TO_COMPUTATION_SHAPE[
                num_cores_per_replica],
            num_replicas=self.num_replicas)
        self.train_batch_size = train_batch_size
        self.eval_batch_size = eval_batch_size
        self.eval_has_labels = eval_has_labels
        self.model_fn = model_fn

        if params is None:
            params = {}
        params[
            "dataset_num_shards"] = self.num_replicas // FLAGS.replicas_per_host
        per_replica_train_batch_size = train_batch_size // self.num_replicas
        per_replica_eval_batch_size = eval_batch_size // self.num_replicas
        for i in range(self.num_replicas // FLAGS.replicas_per_host):
            params["dataset_index"] = i
            params["batch_size"] = per_replica_train_batch_size
            self.build_enqueue_ops(train_input_fn, True, input_partition_dims,
                                   params)
            if self.eval_steps > 0:
                params["batch_size"] = per_replica_eval_batch_size
                self.build_enqueue_ops(eval_input_fn, False,
                                       input_partition_dims, params)

        def train_step(_):
            """One train step."""
            inp = self.infeed_op[True].generate_dequeue_op()
            flatten_structure = tf.nest.flatten(self.feature_structure[True])
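            # Trim each dequeued tensor back to the shape recorded in the feature structure.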
            inp = [
                tf.slice(i, [0] * i.shape.ndims, j.shape)
                for i, j in zip(inp, flatten_structure)
            ]
            if train_has_labels:
                features, labels = tf.nest.pack_sequence_as(
                    self.feature_structure[True], inp)
            else:
                features = tf.nest.pack_sequence_as(
                    self.feature_structure[True], inp)
                labels = None
            self.maybe_add_embedding_features(features, True)
            train_op, _ = model_fn(features, labels, True)
            embedding_train_op = self.maybe_get_embedding_train_op()
            with tf.device(device_for_tpu_core(self.get_host(0))):
                with tf.control_dependencies([train_op, embedding_train_op]):
                    return tf.constant(0)

        @tpu_function.on_device_training_loop
        def train_loop():
            return training_loop.repeat(self.iterations_per_loop, train_step,
                                        tf.constant(0))

        def train_eval_step():
            with tf.control_dependencies(train_loop()):
                if self.eval_steps > 0:
                    return self.eval_loop()
                else:
                    return tf.no_op()

        @on_device_train_and_eval_loops
        def train_eval_loop():
            return training_loop.repeat(self.max_train_iterations,
                                        train_eval_step)

        with self.graph.as_default():
            (self.train_eval_op, ) = tpu.shard(
                train_eval_loop,
                inputs=[],
                num_shards=self.num_replicas,
                outputs_from_all_shards=False,
                device_assignment=self.device_assignment)
            if FLAGS.model_dir:
                tf.io.write_graph(self.graph, FLAGS.model_dir, "graph.pbtxt")

        output_graph = tf.Graph()
        if self.eval_steps > 0:
            with output_graph.as_default():
                flatten_output = tf.nest.flatten(self.predict_output)
                self.dequeue_ops = [[] for _ in flatten_output]
                tensor_dtypes = [v.dtype for v in flatten_output]
                tensor_shapes = [v.shape for v in flatten_output]
                is_padded_index = flatten_output.index(
                    self.predict_output[_IS_PADDED]
                ) if _IS_PADDED in self.predict_output else -1
                for i in range(self.num_replicas // FLAGS.replicas_per_host):
                    with tf.device(device_for_host(self.get_host(i))):
                        host_dequeue_ops = [[] for _ in flatten_output]
                        for j in range(FLAGS.replicas_per_host):
                            replica_id = self.device_assignment.lookup_replicas(
                                i, 0)[j]
                            ordinal = self.device_assignment.tpu_ordinal(
                                replica=replica_id, logical_core=0)
                            dequeue_ops = tpu_ops.outfeed_dequeue_tuple(
                                dtypes=tensor_dtypes,
                                shapes=tensor_shapes,
                                device_ordinal=ordinal)
                            if is_padded_index >= 0:
                                num_non_pad = tf.shape(
                                    dequeue_ops[is_padded_index]
                                )[0] - tf.reduce_sum(
                                    tf.cast(dequeue_ops[is_padded_index],
                                            tf.int32))
                                dequeue_ops = [
                                    tf.slice(k, [0] * k.shape.ndims,
                                             [num_non_pad] + [-1] *
                                             (k.shape.ndims - 1))
                                    for k in dequeue_ops
                                ]
                            for k, item in enumerate(dequeue_ops):
                                host_dequeue_ops[k].append(item)
                        for k in range(len(self.predict_output)):
                            self.dequeue_ops[k].append(
                                tf.concat(host_dequeue_ops[k], axis=0))

        self.sess = tf.Session(self.master,
                               graph=self.graph,
                               config=self.config)
        for is_training in [True, False]:
            if is_training or self.eval_steps > 0:
                for i in range(self.num_input_graphs):
                    with self.input_graph[is_training][i].as_default():
                        self.input_sess[is_training][i] = tf.Session(
                            self.master,
                            graph=self.input_graph[is_training][i],
                            config=self.config)
                        self.input_sess[is_training][i].run(
                            self.dataset_initializer[is_training][i])
        self.output_sess = tf.Session(self.master,
                                      graph=output_graph,
                                      config=self.config)

        with self.graph.as_default():
            _ = tf.train.get_or_create_global_step()
            if init_fn:
                init_fn()
            checkpoint_path = tf.train.latest_checkpoint(
                FLAGS.model_dir) if FLAGS.model_dir else None
            if FLAGS.restore_checkpoint and checkpoint_path:
                tf.train.Saver().restore(self.sess, checkpoint_path)
            else:
                self.sess.run(tf.global_variables_initializer())
                self.sess.run(tf.local_variables_initializer())
            self.maybe_load_embedding_vars()
            self.global_step = self.sess.run(
                tf.train.get_global_step(self.graph))

        def train_eval_thread_fn(sess, train_eval_op):
            sess.run([train_eval_op])

        # Start the just-in-time compilation of the model function.
        self.train_eval_thread = threading.Thread(target=train_eval_thread_fn,
                                                  args=(self.sess,
                                                        self.train_eval_op))
        self.train_eval_thread.start()

        # Sleep to give the just-in-time compilation time to finish.
        time.sleep(FLAGS.sleep_after_init)
Example no. 13
    def test_bert_forward(self):
        def create_computation_fn(device_assignment):
            d_model = 128
            num_blocks = 2
            seq_length = 128
            batch_size = 2
            vocab_size = 30522

            bert_config = bert_lib.BertConfig(
                vocab_size=vocab_size,
                d_model=int(d_model),
                num_blocks=int(num_blocks),
                attention_num_heads=int(d_model / 64),
                feedforward_intermediate_size=int(d_model * 4),
                feedforward_intermediate_act='relu',
                feedforward_intermediate_dropout_prob=0.1,
                attention_probs_dropout_prob=0.1,
                max_position_embeddings=seq_length,
                type_vocab_size=2,
                initializer_range=0.02)

            def computation_fn():
                graph = mtf.Graph()
                mesh = mtf.Mesh(graph, 'my_mesh')
                mesh_shape = mtf.convert_to_shape('all:2')
                layout = 'num_heads:all'
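                # Shard the attention-head dimension across the two cores of the 'all' mesh dimension.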
                mesh_devices = [''] * mesh_shape.size
                mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                    mesh_shape, mtf.convert_to_layout_rules(layout),
                    mesh_devices, device_assignment)
                batch_dim = mtf.Dimension('batch', batch_size)
                seq_dim = mtf.Dimension('seq', seq_length)

                input_ids = tf.random.uniform((batch_size, seq_length),
                                              minval=0,
                                              maxval=vocab_size,
                                              dtype=tf.int32)
                mtf_input_ids = mtf.import_tf_tensor(mesh, input_ids,
                                                     [batch_dim, seq_dim])

                model = bert_lib.BertModel(config=bert_config,
                                           is_training=True,
                                           input_ids=mtf_input_ids,
                                           input_mask=None,
                                           token_type_ids=None)
                pooled = model.get_pooled_output()
                lowering = mtf.Lowering(graph, {mesh: mesh_impl})
                return lowering.export_to_tf_tensor(pooled)

            return computation_fn

        with self.test_session() as sess:
            topology = sess.run(tf.tpu.initialize_system())
            device_assignment = tpu_device_assignment.device_assignment(
                topology, computation_shape=[1, 1, 1], num_replicas=2)
            tpu_computation_fn = tf.tpu.batch_parallel(
                create_computation_fn(device_assignment),
                inputs=None,
                num_shards=2,
                infeed_queue=None,
                device_assignment=device_assignment)

            sess.run(tf.global_variables_initializer())
            sess.run(tf.variables_initializer(tf.get_collection('TPU_VAR')))

            print('TPU', sess.run(tpu_computation_fn))
            sess.run(tf.tpu.shutdown_system())