def _OutfeedEnqueue(self, per_example_tensors):
  """Builds the outfeed enqueue op for `per_example_tensors`.

  Args:
    per_example_tensors: Value accepted by `py_utils.NestedMap`; may be
      empty or None.

  Returns:
    `tf.no_op()` when `per_example_tensors` is falsy, otherwise the result of
    `tpu_ops.outfeed_enqueue_tuple` on the flattened NestedMap.
  """
  if per_example_tensors:
    nested = py_utils.NestedMap(per_example_tensors)
    # Use core 0 when `self.spmd` is truthy, an empty device string otherwise.
    if self.spmd:
      target_device = tpu.core(0)
    else:
      target_device = ''
    with tf.device(target_device):
      return tpu_ops.outfeed_enqueue_tuple(nested.Flatten())
  return tf.no_op()
def _DecodeStep():
  """Decode call to be compiled for TPU.

  Dequeues one input batch from the task's input, runs `self._task.Decode`
  on it, and stores the result as a NestedMap on `self.metrics_nm`.

  Returns:
    A one-element list holding the outfeed enqueue op for the flattened
    metrics.
  """
  batch = self._task.input.TpuDequeueBatch()
  decoded = self._task.Decode(batch)
  self.metrics_nm = py_utils.NestedMap(decoded)
  # Use core 0 when `self.spmd` is truthy, an empty device string otherwise.
  if self.spmd:
    target_device = tpu.core(0)
  else:
    target_device = ''
  with tf.device(target_device):
    return [tpu_ops.outfeed_enqueue_tuple(self.metrics_nm.Flatten())]
def decode_fn(*infeed_batch):
  """Configures the outfeed from the infeed batch and enqueues it.

  Stores the result of `self._config_outfeed(xformer, infeed_batch)` on
  `self.outfeed`.

  Args:
    *infeed_batch: Positional infeed values forwarded to
      `self._config_outfeed`.

  Returns:
    A one-element list holding the outfeed enqueue op, created under
    `tf.tpu.core(0)`.
  """
  # Length 6 is passed when there is no tgt_mask (e.g. decoding) and
  # length 7 is passed when there is a tgt_mask (e.g. fprop).
  self.outfeed = self._config_outfeed(xformer, infeed_batch)
  with tf.device(tf.tpu.core(0)):
    flat_outfeed = tf.nest.flatten(self.outfeed)
    return [tpu_ops.outfeed_enqueue_tuple(flat_outfeed)]
def eval_loop(self):
  """Repeats `self.eval_step` and enqueues the predictions on the outfeed.

  Returns:
    A `tf.no_op()` that has a control dependency on the outfeed enqueue of
    the predictions produced by the repeated eval step.
  """
  batch_per_replica = self.eval_batch_size // self.num_replicas
  tf.get_variable_scope().reuse_variables()
  # Initial loop state: a zero-filled predictions buffer and a zero constant.
  initial_predictions = tf.zeros([self.eval_steps, batch_per_replica, 2])
  loop_state = [tf.constant(0), initial_predictions]
  loop_result = training_loop.repeat(
      int(self.eval_steps), self.eval_step, loop_state)
  _, predictions = loop_result
  enqueue_op = tpu_ops.outfeed_enqueue_tuple([predictions])
  with tf.control_dependencies([enqueue_op]):
    return tf.no_op()
def _DecodeStep():
  """Decode call to be compiled for TPU.

  Instantiates the model's variables, dequeues one input batch, runs
  `self._task.Decode` on it, and stores the result as a NestedMap on
  `self.metrics_nm`.

  Returns:
    A one-element list holding the outfeed enqueue op for the flattened
    metrics.
  """
  with py_utils.OpportunisticVariableReuseScope(True):
    self._model.InstantiateVariables()
    batch = self._task.input.TpuDequeueBatch()
    decoded = self._task.Decode(batch)
  self.metrics_nm = py_utils.NestedMap(decoded)
  # Use core 0 when `self.spmd` is truthy, an empty device string otherwise.
  if self.spmd:
    target_device = tpu.core(0)
  else:
    target_device = ''
  with tf.device(target_device):
    return [tpu_ops.outfeed_enqueue_tuple(self.metrics_nm.Flatten())]
def tpu_eval_step():
  """Generate the TPU graph.

  Dequeues the infeed values, restores their structure, runs `model_fn` in
  PREDICT mode on the "features" entry, and appends every prediction name and
  tensor to `self.outfeed_names` / `self.outfeed_tensors`.

  Returns:
    A `tf.no_op()` that has a control dependency on the outfeed enqueue of
    `self.outfeed_tensors`.
  """
  dequeued = self.eval_infeed_queue[0].generate_dequeue_op(tpu_device=0)
  inputs = data_nest.pack_sequence_as(self.eval_feature_structure, dequeued)
  spec = model_fn(
      inputs["features"], None, tf.estimator.ModeKeys.PREDICT, params)
  for name, tensor in six.iteritems(spec.predictions):
    self.outfeed_names.append(name)
    self.outfeed_tensors.append(tensor)

  host_device = utils.device_for_tpu_core(self._get_host(0))
  with tf.device(host_device):
    enqueue_op = tpu_ops.outfeed_enqueue_tuple(self.outfeed_tensors)
  with tf.control_dependencies([enqueue_op]):
    return tf.no_op()
def eval_step(self):
  """One evaluation step.

  Dequeues the eval infeed, slices each dequeued tensor to the shape of the
  matching entry in the eval feature structure, runs `self.model_fn`, and
  stores its second return value on `self.predict_output`.

  Returns:
    A one-element list holding the outfeed enqueue op for the flattened
    `self.predict_output`.
  """
  structure = self.feature_structure[False]
  dequeued = self.infeed_op[False].generate_dequeue_op()
  reference = tf.nest.flatten(structure)
  sliced = []
  for tensor, ref in zip(dequeued, reference):
    # Slice from the origin down to the shape of the reference entry.
    origin = [0] * tensor.shape.ndims
    sliced.append(tf.slice(tensor, origin, ref.shape))

  packed = tf.nest.pack_sequence_as(structure, sliced)
  if self.eval_has_labels:
    features, labels = packed
  else:
    features, labels = packed, None

  self.maybe_add_embedding_features(features, False)
  _, self.predict_output = self.model_fn(features, labels, False)
  # One empty list per entry yielded by iterating `self.predict_output`.
  self.dequeue_ops.extend([] for _ in self.predict_output)

  host_device = device_for_tpu_core(self.get_host(0))
  with tf.device(host_device):
    flat_output = tf.nest.flatten(self.predict_output)
    return [tpu_ops.outfeed_enqueue_tuple(flat_output)]
def _OutfeedEnqueue(self, per_example_tensors):
  """Builds the outfeed enqueue op for `per_example_tensors`.

  Args:
    per_example_tensors: Value accepted by `py_utils.NestedMap`; may be
      empty or None.

  Returns:
    `tf.no_op()` when `per_example_tensors` is falsy, otherwise the result of
    `tpu_ops.outfeed_enqueue_tuple` on the flattened NestedMap.
  """
  if per_example_tensors:
    nested = py_utils.NestedMap(per_example_tensors)
    return tpu_ops.outfeed_enqueue_tuple(nested.Flatten())
  return tf.no_op()
def _model_fn(input_fea, input_lab):
  """Creates a model, add summary, modes (train or eval), and hooks.

  Reads `train_or_eval`, `FLAGS`, `mesh_context`, `captured_hooks` and
  `captured_output_dtypes_shapes` from the enclosing scope.

  Args:
    input_fea: Input features; wrapped in a list when not already one.
    input_lab: Input labels; wrapped in a list when not already one.

  Returns:
    For 'train', a `tf.group` over the loss and the lowered update ops
    (including the global-step increment). Otherwise the predictions with the
    loss and global step appended: as an outfeed enqueue op when
    `FLAGS.use_tpu` is set, or as the list of tensors itself when it is not.
  """
  # input_fea and input_lab should be a list (laid_out_tensors).
  if not isinstance(input_fea, list):
    input_fea = [input_fea]
  if not isinstance(input_lab, list):
    input_lab = [input_lab]

  def _add_summary(lowering, train_or_eval, tf_loss, scalars, global_step):
    """Add all summaries.

    Exports every non-`tf.Tensor` entry of `scalars` to a float32 tf tensor
    (in place), writes the summaries, and returns the loss tensor that
    carries a control dependency on the summary ops.
    """
    for name, value in list(scalars.items()):
      if not isinstance(value, tf.Tensor):
        scalars[name] = tf.cast(
            lowering.export_to_tf_tensor(value), tf.float32)

    def _host_loss_summary(global_step, tf_loss, **scalars):
      """Add summary.scalar in host side."""
      gs = tf.cast(global_step, tf.int64)
      sum_loss = contrib_summary.scalar(
          '{}_loss'.format(train_or_eval), tf_loss, step=gs)
      sum_ops = [sum_loss.op]
      # `dict.items()` works on both Python 2 and 3; `iteritems()` does not
      # exist on Python 3 dicts.
      for description, tf_metric in scalars.items():
        sum_metric = contrib_summary.scalar(
            '{}_{}'.format(train_or_eval, description), tf_metric, step=gs)
        sum_ops.append(sum_metric)
      with tf.control_dependencies(sum_ops):
        return tf.identity(tf_loss)

    if FLAGS.use_tpu:
      # Cast the global step to tf.int32, since
      # outside_compilation does not support tf.int64.
      tf_loss = tpu.outside_compilation(
          _host_loss_summary,
          tf.cast(global_step, tf.int32),
          tf_loss,
          **scalars)
    else:
      tf_loss = _host_loss_summary(
          tf.cast(global_step, tf.int32),
          tf_loss,
          **scalars)

    return tf_loss

  global_step = tf.train.get_or_create_global_step()
  graph, mesh, mesh_impl = mesh_context.create_graph_mesh_and_mesh_impl()

  with mtf.utils.outside_all_rewrites():
    # Do not tpu_rewrite this part. Inside this unet, if you use TensorFlow
    # instead of Mesh-TensorFlow, it will cause host to tpu send/recv.
    preds, loss, scalars, bn_update_ops = (
        unet.unet_with_spatial_partition(
            mesh, mesh_impl, train_or_eval, input_fea, input_lab))

  if train_or_eval == 'train':
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in graph.trainable_variables])

    lr = FLAGS.lr * tf.pow(
        FLAGS.lr_drop_rate,
        tf.floor(tf.cast(global_step, tf.float32) / FLAGS.lr_drop_steps))
    scalars['learning_rate'] = lr

    optimizer = mtf.optimize.AdafactorOptimizer(learning_rate=lr)
    update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)

    # This is where the actual tf graph got built.
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    tf_update_ops.extend(
        [lowering.lowered_operation(op) for op in bn_update_ops])

  else:  # train_or_eval == 'eval':
    preds = [mtf.anonymize(pred) for pred in preds]

    # This is where the actual tf graph got built.
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    tf_preds = [tf.cast(
        lowering.export_to_tf_tensor(pred), tf.float32) for pred in preds]

  tf_loss = tf.cast(lowering.export_to_tf_tensor(loss), tf.float32)
  if FLAGS.write_summary:
    tf_loss = _add_summary(
        lowering, train_or_eval, tf_loss, scalars, global_step)
  master_to_slice_hook = mtf.MtfRestoreHook(lowering)

  if train_or_eval == 'train':
    # NOTE(review): the capture and return are assumed to sit inside the
    # outside_all_rewrites scope -- confirm against the original layout.
    with mtf.utils.outside_all_rewrites():
      saver = tf.train.Saver(tf.global_variables(),
                             save_relative_paths=True)
      tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
      saver_listener = mtf.MtfCheckpointSaverListener(lowering)
      slice_to_master_hook = tf.train.CheckpointSaverHook(
          FLAGS.checkpoint_dir,
          save_steps=FLAGS.save_checkpoints_steps,
          saver=saver, listeners=[saver_listener])
      captured_hooks.capture([master_to_slice_hook, slice_to_master_hook])
      return tf.group([tf_loss] + tf_update_ops)

  else:  # train_or_eval == 'eval':
    if FLAGS.use_tpu:
      tf_preds.extend([tf_loss, global_step])
      tf_preds_dtypes = [tf_pred.dtype for tf_pred in tf_preds]
      tf_preds_shapes = [tf_pred.shape for tf_pred in tf_preds]
      captured_hooks.capture([master_to_slice_hook, None])
      captured_output_dtypes_shapes.capture(
          [tf_preds_dtypes, tf_preds_shapes])
      return tpu_ops.outfeed_enqueue_tuple(tf_preds)

    else:
      tf_preds.extend([tf_loss, global_step])
      captured_hooks.capture([master_to_slice_hook, None])
      return tf_preds
def _model_fn(input_fea, input_lab):
  """Creates a model, add summary, modes (train or eval), and hooks.

  Reads `train_or_eval`, `FLAGS`, `num_hosts`, `cpu_devices`,
  `d_assignment`, `captured_hooks` and `captured_output_dtypes_shapes` from
  the enclosing scope.

  Args:
    input_fea: Input features, forwarded to
      `unet.unet_with_spatial_partition`.
    input_lab: Input labels, forwarded to `unet.unet_with_spatial_partition`.

  Returns:
    For 'train', the `tf.group` of the lowered update ops (including the
    global-step increment). Otherwise the outfeed enqueue op for the
    predictions with the loss and global step appended.
  """

  def _add_summary(lowering, train_or_eval, tf_loss, scalars, global_step):
    """Add all summaries.

    Exports every non-`tf.Tensor` entry of `scalars` to a float32 tf tensor
    (in place), writes the summaries through `tpu.outside_compilation`, and
    returns the resulting loss tensor.
    """
    for name, value in list(scalars.items()):
      if not isinstance(value, tf.Tensor):
        scalars[name] = tf.cast(
            lowering.export_to_tf_tensor(value), tf.float32)

    def _host_loss_summary(global_step, tf_loss, **scalars):
      """Add summary.scalar in host side."""
      gs = tf.cast(global_step, tf.int64)
      sum_loss = tf.contrib.summary.scalar(
          '{}_loss'.format(train_or_eval), tf_loss, step=gs)
      sum_ops = [sum_loss.op]
      # `dict.items()` works on both Python 2 and 3; `iteritems()` does not
      # exist on Python 3 dicts.
      for description, tf_metric in scalars.items():
        sum_metric = tf.contrib.summary.scalar('{}_{}'.format(
            train_or_eval, description), tf_metric, step=gs)
        sum_ops.append(sum_metric)
      with tf.control_dependencies(sum_ops):
        return tf.identity(tf_loss)

    # Cast the global step to tf.int32, since
    # outside_compilation does not support tf.int64.
    tf_loss = tpu.outside_compilation(_host_loss_summary,
                                      tf.cast(global_step, tf.int32),
                                      tf_loss, **scalars)
    return tf_loss

  global_step = tf.train.get_or_create_global_step()
  graph = mtf.Graph()

  # Worker 0 caches all the TPU binaries.
  replica_cache_size = 300 * 1024 * 1024  # 300M per replica.
  worker0_mem = replica_cache_size * 8 * num_hosts
  devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
  tf.logging.info('cpu_devices: {}, devices_mem: {}'.format(
      cpu_devices, devices_memory_usage))
  var_placer = mtf.utils.BalancedVariablePlacer(cpu_devices,
                                                devices_memory_usage)

  mesh = mtf.Mesh(graph, 'my_mesh', var_placer)
  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  layout_rules = unet.get_layout()
  mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(mesh_shape, layout_rules,
                                              None, d_assignment)

  with mtf.utils.outside_all_rewrites():
    # Do not tpu_rewrite this part.
    preds, loss, scalars, bn_update_ops = (
        unet.unet_with_spatial_partition(mesh, train_or_eval, input_fea,
                                         input_lab))

  if train_or_eval == 'train':
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in graph.trainable_variables])

    lr = FLAGS.lr * tf.pow(
        FLAGS.lr_drop_rate,
        tf.floor(
            tf.cast(global_step, tf.float32) / FLAGS.lr_drop_steps))
    scalars['learning_rate'] = lr

    optimizer = mtf.optimize.AdafactorOptimizer(learning_rate=lr)
    update_ops = optimizer.apply_grads(var_grads,
                                       graph.trainable_variables)

    # This is where the actual tf graph got built.
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    tf_update_ops = [
        lowering.lowered_operation(op) for op in update_ops
    ]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    tf_update_ops.extend(
        [lowering.lowered_operation(op) for op in bn_update_ops])
    tf_update_ops_group = tf.group(tf_update_ops)

  else:  # train_or_eval == 'eval':
    preds = [mtf.anonymize(pred) for pred in preds]

    # This is where the actual tf graph got built.
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    tf_preds = [
        tf.cast(lowering.export_to_tf_tensor(pred), tf.float32)
        for pred in preds
    ]

  tf_loss = tf.cast(lowering.export_to_tf_tensor(loss), tf.float32)
  if FLAGS.write_summary:
    tf_loss = _add_summary(lowering, train_or_eval, tf_loss, scalars,
                           global_step)
  master_to_slice_hook = mtf.MtfRestoreHook(lowering)

  if train_or_eval == 'train':
    with mtf.utils.outside_all_rewrites():
      saver = tf.train.Saver(tf.global_variables(),
                             save_relative_paths=True)
      tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
      saver_listener = mtf.MtfCheckpointSaverListener(lowering)
      slice_to_master_hook = tf.train.CheckpointSaverHook(
          FLAGS.checkpoint_dir,
          save_steps=FLAGS.save_checkpoints_steps,
          saver=saver,
          listeners=[saver_listener])
      captured_hooks.capture(
          [master_to_slice_hook, slice_to_master_hook])
      return tf_update_ops_group

  else:  # train_or_eval == 'eval':
    tf_preds.extend([tf_loss, global_step])
    tf_preds_dtypes = [tf_pred.dtype for tf_pred in tf_preds]
    tf_preds_shapes = [tf_pred.shape for tf_pred in tf_preds]
    captured_hooks.capture([master_to_slice_hook, None])
    captured_output_dtypes_shapes.capture(
        [tf_preds_dtypes, tf_preds_shapes])
    return tpu_ops.outfeed_enqueue_tuple(tf_preds)