def _WaitTillInit(job=None):
  """Wait until the model is ready."""
  try:
    if py_utils.IsEagerMode():
      topology = tf.tpu.experimental.initialize_tpu_system(resolver)
    else:
      # tpu.initialize_system() is called with None as embedding_config, as
      # embedding_config is not available yet. Later in _Loop, it is called
      # with the correct embedding_config. Since it cannot be called twice
      # in the same graph with different embedding_config, we use a
      # dummy_graph here.
      dummy_graph = tf.Graph()
      with dummy_graph.as_default():
        tpu_initialize_system_op = tf.tpu.initialize_system(
            embedding_config=None, job=job)

      with self._GetSession(graph=dummy_graph) as sess:
        topology = sess.run(tpu_initialize_system_op)

    if train_cfg.train.tpu_computation_shape is None:
      computation_shape = py_utils.ComputationShape(num_devices_per_split,
                                                    topology)
    else:
      computation_shape = train_cfg.train.tpu_computation_shape
      assert num_devices_per_split == np.prod(computation_shape)

    if train_cfg.train.tpu_device_order_mode is None:
      self.device_assignment = device_assignment_lib.device_assignment(
          topology,
          computation_shape=computation_shape,
          num_replicas=data_parallelism)
    else:
      self.device_assignment = device_assignment_lib.device_assignment(
          topology,
          computation_shape=computation_shape,
          num_replicas=data_parallelism,
          device_order_mode=train_cfg.train.tpu_device_order_mode)
    py_utils.SetTpuDeviceAssignment(self.device_assignment, job)
    tf.logging.info('device_assignment.core_assignment: %s',
                    str(self.device_assignment.core_assignment))
    tf.logging.info('device_assignment.topology.device_coordinates: %s',
                    str(self.device_assignment.topology.device_coordinates))
  except py_utils.transient_tf_errors as e:
    tf.logging.info('TPU initialization failed: %s', e)
    raise
def _get_device_assignment(self):
  """Gets the (maybe cached) TPU device assignment."""
  master = self._get_master_address()
  device_assignment = self._lazy_device_assignment_dict.get(master)
  if device_assignment is not None:
    return device_assignment

  tpu_system_metadata = self._get_tpu_system_metadata()

  device_assignment = tpu_device_assignment.device_assignment(
      tpu_system_metadata.topology,
      computation_shape=self._computation_shape,
      num_replicas=self.num_replicas)

  tf.compat.v1.logging.info(
      'num_cores_per_replica: %s',
      str(self._config.tpu_config.num_cores_per_replica))
  tf.compat.v1.logging.info('computation_shape: %s',
                            str(self._computation_shape))
  tf.compat.v1.logging.info('num_replicas: %d', self.num_replicas)
  tf.compat.v1.logging.info(
      'device_assignment.topology.device_coordinates: %s',
      str(device_assignment.topology.device_coordinates))
  tf.compat.v1.logging.info('device_assignment.core_assignment: %s',
                            str(device_assignment.core_assignment))

  self._lazy_device_assignment_dict[master] = device_assignment
  return device_assignment
def _get_device_assignment(self):
  """Gets the (maybe cached) TPU device assignment."""
  master = self._get_master_address()
  device_assignment = self._lazy_device_assignment_dict.get(master)
  if device_assignment is not None:
    return device_assignment

  tpu_system_metadata = self._get_tpu_system_metadata()

  device_assignment = tpu_device_assignment.device_assignment(
      tpu_system_metadata.topology,
      computation_shape=self._computation_shape,
      num_replicas=self.num_replicas)

  logging.info('num_cores_per_replica: %s',
               str(self._config.tpu_config.num_cores_per_replica))
  logging.info('computation_shape: %s', str(self._computation_shape))
  logging.info('num_replicas: %d', self.num_replicas)
  logging.info('device_assignment.topology.device_coordinates: %s',
               str(device_assignment.topology.device_coordinates))
  logging.info('device_assignment.core_assignment: %s',
               str(device_assignment.core_assignment))

  self._lazy_device_assignment_dict[master] = device_assignment
  return device_assignment
def test_adam(self):
  self.lowering = None

  def create_computation_fn(device_assignment):

    def computation_fn():
      graph = mtf.Graph()
      mesh = mtf.Mesh(graph, 'my_mesh')
      mesh_shape = mtf.convert_to_shape('all:2')
      layout = 'none:all'
      mesh_devices = [''] * mesh_shape.size
      mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
          mesh_shape, mtf.convert_to_layout_rules(layout), mesh_devices,
          device_assignment)
      hidden_dim = mtf.Dimension('hidden', 3)
      w = mtf.get_variable(
          mesh,
          'w',
          shape=[hidden_dim],
          initializer=tf.constant_initializer([0.1, -0.2, -0.1]))
      x = mtf.constant(mesh, [0.4, 0.2, -0.5], [hidden_dim], dtype=tf.float32)
      loss = mtf.reduce_mean(mtf.square(x - w))

      var_grads = mtf.gradients(
          [loss], [v.outputs[0] for v in graph.trainable_variables])
      optimizer = mtf_optimize.AdamWeightDecayOptimizer(learning_rate=0.2)
      update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)
      self.lowering = mtf.Lowering(graph, {mesh: mesh_impl})
      tf_update_ops = [
          self.lowering.lowered_operation(op) for op in update_ops
      ]
      return tf.group(tf_update_ops)

    return computation_fn

  with self.test_session() as sess:
    topology = sess.run(tf.tpu.initialize_system())
    device_assignment = tpu_device_assignment.device_assignment(
        topology, computation_shape=[1, 1, 1], num_replicas=2)
    tpu_computation_fn = tf.tpu.batch_parallel(
        create_computation_fn(device_assignment),
        inputs=None,
        num_shards=2,
        infeed_queue=None,
        device_assignment=device_assignment)

    sess.run(tf.global_variables_initializer())
    sess.run(self.lowering.copy_masters_to_slices())
    for _ in range(100):
      sess.run(tpu_computation_fn)
    sess.run(self.lowering.copy_slices_to_masters())
    w_np = sess.run(tf.global_variables()[0])
    self.assertAllClose([0.4, 0.2, -0.5], w_np.flat, rtol=1e-2, atol=1e-2)
    sess.run(tf.tpu.shutdown_system())
def test_optimizer(self):
  self.lowering = None

  def create_computation_fn(device_assignment):

    def computation_fn():
      graph = mtf.Graph()
      mesh = mtf.Mesh(graph, 'my_mesh')
      mesh_shape = mtf.convert_to_shape('all:2')
      layout = 'none:all'
      mesh_devices = [''] * mesh_shape.size
      mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
          mesh_shape, mtf.convert_to_layout_rules(layout), mesh_devices,
          device_assignment)
      hidden_dim = mtf.Dimension('hidden', 3)
      w = mtf.get_variable(
          mesh,
          'w',
          shape=[hidden_dim],
          initializer=tf.constant_initializer([0.1, -0.2, -0.1]))
      x = mtf.constant(mesh, [0.4, 0.2, -0.5], [hidden_dim], dtype=tf.float32)
      loss = mtf.reduce_mean(mtf.square(x - w))

      lr, update_ops = optimization_lib.create_optimizer(loss, 0.2, 100, 10)
      self.lowering = mtf.Lowering(graph, {mesh: mesh_impl})

      tf_update_ops = [
          self.lowering.lowered_operation(op) for op in update_ops
      ]
      tf_update_ops.append(
          tf.assign_add(tf.train.get_or_create_global_step(), 1))
      train_op = tf.group(tf_update_ops)

      return lr, train_op

    return computation_fn

  with self.test_session() as sess:
    topology = sess.run(tf.tpu.initialize_system())
    device_assignment = tpu_device_assignment.device_assignment(
        topology, computation_shape=[1, 1, 1], num_replicas=2)
    tpu_computation_fn = tf.tpu.batch_parallel(
        create_computation_fn(device_assignment),
        inputs=None,
        num_shards=2,
        infeed_queue=None,
        device_assignment=device_assignment)

    sess.run(tf.global_variables_initializer())
    sess.run(self.lowering.copy_masters_to_slices())
    lrs = []
    for _ in range(100):
      lr = sess.run(tpu_computation_fn)
      lrs.append(lr[0][0])
    self.assertAllClose(0.02, lrs[0])
    self.assertAllClose(0.18, lrs[8])
    self.assertAllClose(0., lrs[99])
    sess.run(tf.tpu.shutdown_system())
def test_infeed_uneven_partition(self):
  """Tests uneven infeed tensors partition."""
  ds = device_assignment(
      self._topology_2x2x2, num_replicas=1, computation_shape=[2, 2, 1, 2])
  input_partition_dims = [[4, 2]]
  # pylint: disable=protected-access
  partitioned_infeed = tpu_feed._PartitionedInfeedQueue(
      number_of_tuple_elements=1,
      host_id=0,
      input_partition_dims=input_partition_dims,
      device_assignment=ds)
  x = array_ops.zeros((14, 5))
  tensors = partitioned_infeed._check_dims_and_partition_or_replicate_on_host(
      x, dims=input_partition_dims[0])
  self.assertEqual(8, len(tensors))
  self.assertEqual((2, 2), tensors[-1].shape)
def test_infeed_tailing_zero_partition(self):
  """Tests infeed tensors partition which causes zero-size tensors."""
  ds = device_assignment(
      self._topology_2x2x2, num_replicas=1, computation_shape=[1, 2, 1, 2])
  input_partition_dims = [[4, 1]]
  # pylint: disable=protected-access
  partitioned_infeed = tpu_feed._PartitionedInfeedQueue(
      number_of_tuple_elements=1,
      host_id=0,
      input_partition_dims=input_partition_dims,
      device_assignment=ds)
  x = array_ops.zeros((5, 5))
  tensors = partitioned_infeed._check_dims_and_partition_or_replicate_on_host(
      x, dims=input_partition_dims[0])
  self.assertEqual(4, len(tensors))
  self.assertEqual((1, 5), tensors[2].shape)
  self.assertEqual((0, 5), tensors[3].shape)
def _WaitTillInit():
  """Wait until the model is ready."""
  try:
    with self._graph.as_default(), self._GetSession(
        cluster_def=self._cluster_def) as sess:
      topology = sess.run(
          tf.tpu.initialize_system(embedding_config=None, job=None))
      device_assignment = device_assignment_lib.device_assignment(
          topology,
          computation_shape=ComputationShape(num_devices_per_split),
          num_replicas=data_parallelism)
      py_utils.SetTpuDeviceAssignment(device_assignment)
      tf.logging.info('device_assignment.core_assignment: %s',
                      str(device_assignment.core_assignment))
      tf.logging.info('device_assignment.topology.device_coordinates: %s',
                      str(device_assignment.topology.device_coordinates))
  except py_utils.transient_tf_errors as e:
    tf.logging.info('TPU initialization failed: %s', e)
    raise
def _DeviceAssignment(self):
  """A context for tpu device assignment of a JF 8x8 slice."""
  mesh_shape = [8, 8, 1, 2]
  device_coordinates = np.zeros([16, 8, 4], dtype=np.int32)
  # Each of the 16 hosts owns a 2x2 block of chips (8 cores). Map the linear
  # core index i to mesh coordinates (x, y, 0, core) and to the (task, device)
  # slot that owns that core.
  for i in range(np.prod(mesh_shape)):
    x = i // 16
    y = i % 16 // 2
    core = i % 2
    task = x // 2 * 4 + y // 2
    device = x % 2 * 4 + y % 2 * 2 + core
    device_coordinates[task, device] = [x, y, 0, core]
  topology = tf.tpu.experimental.Topology(
      mesh_shape=mesh_shape, device_coordinates=device_coordinates)
  assignment = device_assignment.device_assignment(
      topology, computation_shape=[1, 1, 1, 1], num_replicas=128)
  py_utils.SetTpuDeviceAssignment(assignment)
  try:
    yield
  finally:
    py_utils.SetTpuDeviceAssignment(None)
def _init_tpu(self, num_partitions, device_order_mode):
  """Initialize tpu device assignment."""
  tf.logging.info('Initializing TPU to get device assignment: start')

  graph = tf.Graph()
  with graph.as_default():
    init_tpu_op = tf.tpu.initialize_system()
  try:
    sess = tf.Session(
        target=self._tpu, graph=graph, config=self._no_opt_sess_cfg())
    topology = sess.run(init_tpu_op)
  except Exception as e:
    tf.logging.fatal('TPU initialization failed: %s', e)
    raise

  topology_proto = topology_pb2.TopologyProto()
  topology_proto.ParseFromString(topology)
  tf.logging.info('topology.num_tasks: %r', topology_proto.num_tasks)
  tf.logging.info('topology.num_tpu_devices_per_task: %r',
                  topology_proto.num_tpu_devices_per_task)
  tf.logging.info('topology.mesh_shape: %r', topology_proto.mesh_shape)
  self.cluster_params = self._configure_cluster_params(
      tpu_cores=(topology_proto.num_tpu_devices_per_task *
                 topology_proto.num_tasks),
      cpu_hosts=topology_proto.num_tasks)

  # We assume the topology and device assignment do not change
  # for a single address space.
  device_assignment = tpu_device_assignment.device_assignment(
      topology,
      computation_shape=py_utils.ComputationShape(num_partitions, topology),
      num_replicas=1,
      device_order_mode=device_order_mode)
  py_utils.SetTpuDeviceAssignment(device_assignment)

  tf.logging.info('Initializing TPU to get device assignment: done')
def test_get_laidout_tensors(self, is_eval_mode):
  mesh_shape = "mesh_x:2, mesh_y:1"
  layout = "batch:mesh_x, io:mesh_y"
  batch_io_dim = 4

  with tf.Session() as sess:
    topology, num_cores = self.initialize_system(sess)

    # Get a device_assignment object for mtf.
    d_assignment = device_assignment.device_assignment(
        topology,
        computation_shape=[1] * mtf.utils.topology_rank(topology),
        num_replicas=num_cores)

    # Hacked dataset creator: creates different datasets for the first and
    # second call, in order to test SimdMeshImplInputReader.
    self.sub_batch_created_times = 0

    def stateful_ds_creator():
      whole_batch = tf.eye(batch_io_dim, dtype=tf.float32)
      sub_batch = tf.slice(whole_batch,
                           [self.sub_batch_created_times * 2, 0], [2, 4])
      self.sub_batch_created_times += 1
      return tf.data.Dataset.from_tensors(sub_batch).repeat().unbatch()

    batch_dim = mtf.Dimension("batch", batch_io_dim)
    io_dim = mtf.Dimension("io", batch_io_dim)
    mtf_input_shapes = [mtf.Shape([batch_dim, io_dim])]

    # Get mesh_impl.
    mesh_shape = mtf.convert_to_shape(mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(layout)
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, None, d_assignment)

    simd_input_reader = input_reader.SimdMeshImplInputReader(
        mesh_impl,
        stateful_ds_creator,
        mtf_input_shapes,
        external_worker=False,
        is_eval_mode=is_eval_mode)

    def model_fn(features):
      return features

    replicated_computation = tpu.replicate(
        computation=model_fn,
        inputs=[[]] * num_cores,
        infeed_queue=simd_input_reader.infeed_queue,
        device_assignment=d_assignment)

    simd_input_reader.start_infeed_thread(sess, 1)
    results = sess.run(replicated_computation)
    print("results: {}".format(results))

    core_0_data = results[0][0]
    core_1_data = results[1][0]
    print("core_0_data: {}".format(core_0_data))
    print("core_1_data: {}".format(core_1_data))

    if is_eval_mode:
      # If there is only one dataset object, then the stateful_ds_creator()
      # should be called only once.
      self.assertAllClose(
          np.array([[1, 0, 0, 0], [0, 1, 0, 0]], dtype=np.float32),
          core_0_data)
      self.assertAllClose(
          np.array([[1, 0, 0, 0], [0, 1, 0, 0]], dtype=np.float32),
          core_1_data)
    else:
      # If there are two dataset objects, then the stateful_ds_creator()
      # should be called twice.
      self.assertAllClose(
          np.array([[1, 0, 0, 0], [0, 1, 0, 0]], dtype=np.float32),
          core_0_data)
      self.assertAllClose(
          np.array([[0, 0, 1, 0], [0, 0, 0, 1]], dtype=np.float32),
          core_1_data)

    sess.run(tf.tpu.shutdown_system())
def initialize(self,
               train_input_fn,
               eval_input_fn,
               model_fn,
               train_batch_size,
               eval_batch_size,
               input_partition_dims=None,
               init_fn=None,
               train_has_labels=True,
               eval_has_labels=True,
               params=None,
               num_partitions=None):
  """Build graphs for the TPU device and the input pipelines."""
  num_cores_per_replica = 1
  num_cores_per_replica = functools.reduce(
      operator.mul, input_partition_dims
  ) if input_partition_dims else num_partitions if num_partitions else 1

  self.device_assignment = device_assignment.device_assignment(
      topology=self.device_topology,
      computation_shape=_NUM_CORES_TO_COMPUTATION_SHAPE[num_cores_per_replica],
      num_replicas=self.num_replicas)
  self.train_batch_size = train_batch_size
  self.eval_batch_size = eval_batch_size
  self.eval_has_labels = eval_has_labels
  self.model_fn = model_fn

  if params is None:
    params = {}
  params["dataset_num_shards"] = self.num_replicas // FLAGS.replicas_per_host
  per_replica_train_batch_size = train_batch_size // self.num_replicas
  per_replica_eval_batch_size = eval_batch_size // self.num_replicas
  for i in range(self.num_replicas // FLAGS.replicas_per_host):
    params["dataset_index"] = i
    params["batch_size"] = per_replica_train_batch_size
    self.build_enqueue_ops(train_input_fn, True, input_partition_dims, params)
    if self.eval_steps > 0:
      params["batch_size"] = per_replica_eval_batch_size
      self.build_enqueue_ops(eval_input_fn, False, input_partition_dims,
                             params)

  def train_step(_):
    """One train step."""
    inp = self.infeed_op[True].generate_dequeue_op()
    flatten_structure = tf.nest.flatten(self.feature_structure[True])
    inp = [
        tf.slice(i, [0] * i.shape.ndims, j.shape)
        for i, j in zip(inp, flatten_structure)
    ]
    if train_has_labels:
      features, labels = tf.nest.pack_sequence_as(
          self.feature_structure[True], inp)
    else:
      features = tf.nest.pack_sequence_as(self.feature_structure[True], inp)
      labels = None
    self.maybe_add_embedding_features(features, True)
    train_op, _ = model_fn(features, labels, True)
    embedding_train_op = self.maybe_get_embedding_train_op()
    with tf.device(device_for_tpu_core(self.get_host(0))):
      with tf.control_dependencies([train_op, embedding_train_op]):
        return tf.constant(0)

  @tpu_function.on_device_training_loop
  def train_loop():
    return training_loop.repeat(self.iterations_per_loop, train_step,
                                tf.constant(0))

  def train_eval_step():
    with tf.control_dependencies(train_loop()):
      if self.eval_steps > 0:
        return self.eval_loop()
      else:
        return tf.no_op()

  @on_device_train_and_eval_loops
  def train_eval_loop():
    return training_loop.repeat(self.max_train_iterations, train_eval_step)

  with self.graph.as_default():
    (self.train_eval_op,) = tpu.shard(
        train_eval_loop,
        inputs=[],
        num_shards=self.num_replicas,
        outputs_from_all_shards=False,
        device_assignment=self.device_assignment)
    if FLAGS.model_dir:
      tf.io.write_graph(self.graph, FLAGS.model_dir, "graph.pbtxt")

  output_graph = tf.Graph()
  if self.eval_steps > 0:
    with output_graph.as_default():
      flatten_output = tf.nest.flatten(self.predict_output)
      self.dequeue_ops = [[] for _ in flatten_output]
      tensor_dtypes = [v.dtype for v in flatten_output]
      tensor_shapes = [v.shape for v in flatten_output]
      is_padded_index = flatten_output.index(
          self.predict_output[_IS_PADDED]
      ) if _IS_PADDED in self.predict_output else -1
      for i in range(self.num_replicas // FLAGS.replicas_per_host):
        with tf.device(device_for_host(self.get_host(i))):
          host_dequeue_ops = [[] for _ in flatten_output]
          for j in range(FLAGS.replicas_per_host):
            replica_id = self.device_assignment.lookup_replicas(i, 0)[j]
            ordinal = self.device_assignment.tpu_ordinal(
                replica=replica_id, logical_core=0)
            dequeue_ops = tpu_ops.outfeed_dequeue_tuple(
                dtypes=tensor_dtypes,
                shapes=tensor_shapes,
                device_ordinal=ordinal)
            if is_padded_index >= 0:
              num_non_pad = tf.shape(
                  dequeue_ops[is_padded_index])[0] - tf.reduce_sum(
                      tf.cast(dequeue_ops[is_padded_index], tf.int32))
              dequeue_ops = [
                  tf.slice(k, [0] * k.shape.ndims,
                           [num_non_pad] + [-1] * (k.shape.ndims - 1))
                  for k in dequeue_ops
              ]
            for k, item in enumerate(dequeue_ops):
              host_dequeue_ops[k].append(item)
          for k in range(len(self.predict_output)):
            self.dequeue_ops[k].append(tf.concat(host_dequeue_ops[k], axis=0))

  self.sess = tf.Session(self.master, graph=self.graph, config=self.config)
  for is_training in [True, False]:
    if is_training or self.eval_steps > 0:
      for i in range(self.num_input_graphs):
        with self.input_graph[is_training][i].as_default():
          self.input_sess[is_training][i] = tf.Session(
              self.master,
              graph=self.input_graph[is_training][i],
              config=self.config)
          self.input_sess[is_training][i].run(
              self.dataset_initializer[is_training][i])
  self.output_sess = tf.Session(
      self.master, graph=output_graph, config=self.config)

  with self.graph.as_default():
    _ = tf.train.get_or_create_global_step()
    if init_fn:
      init_fn()
    checkpoint_path = tf.train.latest_checkpoint(
        FLAGS.model_dir) if FLAGS.model_dir else None
    if FLAGS.restore_checkpoint and checkpoint_path:
      tf.train.Saver().restore(self.sess, checkpoint_path)
    else:
      self.sess.run(tf.global_variables_initializer())
      self.sess.run(tf.local_variables_initializer())
    self.maybe_load_embedding_vars()
    self.global_step = self.sess.run(tf.train.get_global_step(self.graph))

  def train_eval_thread_fn(sess, train_eval_op):
    sess.run([train_eval_op])

  # Start the just in time compilation of the model function
  self.train_eval_thread = threading.Thread(
      target=train_eval_thread_fn, args=(self.sess, self.train_eval_op))
  self.train_eval_thread.start()
  # Sleep for JTC to finish
  time.sleep(FLAGS.sleep_after_init)
def test_bert_forward(self):

  def create_computation_fn(device_assignment):
    d_model = 128
    num_blocks = 2
    seq_length = 128
    batch_size = 2
    vocab_size = 30522
    bert_config = bert_lib.BertConfig(
        vocab_size=vocab_size,
        d_model=int(d_model),
        num_blocks=int(num_blocks),
        attention_num_heads=int(d_model / 64),
        feedforward_intermediate_size=int(d_model * 4),
        feedforward_intermediate_act='relu',
        feedforward_intermediate_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=seq_length,
        type_vocab_size=2,
        initializer_range=0.02)

    def computation_fn():
      graph = mtf.Graph()
      mesh = mtf.Mesh(graph, 'my_mesh')
      mesh_shape = mtf.convert_to_shape('all:2')
      layout = 'num_heads:all'
      mesh_devices = [''] * mesh_shape.size
      mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
          mesh_shape, mtf.convert_to_layout_rules(layout), mesh_devices,
          device_assignment)
      batch_dim = mtf.Dimension('batch', batch_size)
      seq_dim = mtf.Dimension('seq', seq_length)

      input_ids = tf.random.uniform((batch_size, seq_length),
                                    minval=0,
                                    maxval=vocab_size,
                                    dtype=tf.int32)
      mtf_input_ids = mtf.import_tf_tensor(mesh, input_ids,
                                           [batch_dim, seq_dim])

      model = bert_lib.BertModel(
          config=bert_config,
          is_training=True,
          input_ids=mtf_input_ids,
          input_mask=None,
          token_type_ids=None)
      pooled = model.get_pooled_output()
      lowering = mtf.Lowering(graph, {mesh: mesh_impl})
      return lowering.export_to_tf_tensor(pooled)

    return computation_fn

  with self.test_session() as sess:
    topology = sess.run(tf.tpu.initialize_system())
    device_assignment = tpu_device_assignment.device_assignment(
        topology, computation_shape=[1, 1, 1], num_replicas=2)
    tpu_computation_fn = tf.tpu.batch_parallel(
        create_computation_fn(device_assignment),
        inputs=None,
        num_shards=2,
        infeed_queue=None,
        device_assignment=device_assignment)
    sess.run(tf.global_variables_initializer())
    sess.run(tf.variables_initializer(tf.get_collection('TPU_VAR')))
    print('TPU', sess.run(tpu_computation_fn))
    sess.run(tf.tpu.shutdown_system())