def _build_select_slate_op(self):
  p_no_click = self._prob_no_click_ph
  p = self._doc_affinity_scores_ph
  q = self._net_outputs.q_values[0]
  with tf.name_scope('select_slate'):
    self._output_slate = self._select_slate_fn(self._slate_size, p_no_click,
                                               p, q)

  # Debug print; 'cp 1' is a checkpoint marker in the printed output.
  self._output_slate = tf.Print(
      self._output_slate, [tf.constant('cp 1'), self._output_slate, p, q],
      summarize=10000)
  self._output_slate = tf.reshape(self._output_slate, (self._slate_size,))

  # Track how often each candidate document has been placed on a slate:
  # add one one-hot row per slate position.
  self._action_counts = tf.get_variable(
      'action_counts',
      shape=[self._num_candidates],
      initializer=tf.zeros_initializer())
  output_slate = tf.reshape(self._output_slate, [-1])
  output_one_hot = tf.one_hot(output_slate, self._num_candidates)
  update_ops = []
  for i in range(self._slate_size):
    update_ops.append(tf.assign_add(self._action_counts, output_one_hot[i]))
  self._select_action_update_op = tf.group(*update_ops)
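
# A minimal NumPy sketch of the logic above, for illustration only.
# `greedy_select_slate` is a hypothetical stand-in for self._select_slate_fn
# (the greedy affinity * Q-value scoring rule is an assumption, not the
# actual implementation); `update_action_counts` has the same effect as the
# one-hot assign_add loop above.
import numpy as np


def greedy_select_slate(slate_size, p, q):
  """Hypothetical slate selector: top slate_size docs by affinity * Q-value."""
  return np.argsort(-p * q)[:slate_size]


def update_action_counts(action_counts, slate, num_candidates):
  """Adds one count per slate position, like the assign_add loop above."""
  one_hot = np.eye(num_candidates)[slate]  # [slate_size, num_candidates]
  return action_counts + one_hot.sum(axis=0)


p = np.array([0.1, 0.7, 0.2, 0.9])  # made-up affinity scores
q = np.array([1.0, 0.5, 2.0, 0.3])  # made-up Q-values
slate = greedy_select_slate(2, p, q)  # -> array([2, 1])
counts = update_action_counts(np.zeros(4), slate, 4)  # -> [0., 1., 1., 0.]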
# Note: closes over configuration from the enclosing scope (use_tpu,
# batch_size, outer_batch_size, sequence_length, mesh_shape, layout_rules,
# transformer_model, model_type, learning_rate, metric_names, autostack,
# model_dir, save_steps, checkpoints_to_keep, ...).
def my_model_fn(features, labels, mode, params=None, config=None):
  """Estimator model function.

  Args:
    features: input features dictionary
    labels: ignored
    mode: a tf.estimator.ModeKeys
    params: parameters passed by the (TPU)Estimator; on TPU this contains
      the "context" entry
    config: a tf.estimator.RunConfig (unused)

  Returns:
    a TPUEstimatorSpec (or a tf.estimator.EstimatorSpec on CPU/GPU)
  """
  del labels, config
  global_step = tf.train.get_global_step()
  if use_tpu:
    ctx = params["context"]
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
    # TODO(ylc): Better estimation of replica cache size?
    replica_cache_size = 300 * 1000000  # 300M per replica
    # Worker 0 caches all the TPU binaries.
    worker0_mem = replica_cache_size * ctx.num_replicas
    devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                  devices_memory_usage)
    mesh_devices = [""] * mesh_shape.size
    physical_shape = list(
        params["context"].device_assignment.topology.mesh_shape)
    logical_to_physical = _logical_to_physical(physical_shape, mesh_shape)
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape,
        layout_rules,
        mesh_devices,
        ctx.device_assignment,
        logical_to_physical=logical_to_physical)
  else:
    var_placer = None
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh", var_placer)

  def _import_feature(key, allow_missing=False):
    """Import a feature from the features dictionary into a mtf.Tensor.

    Args:
      key: a string
      allow_missing: a boolean

    Returns:
      a mtf.Tensor with dtype int32 and shape
      [outer_batch_dim, batch_dim, length_dim], or None if the feature is
      missing and allow_missing is True
    """
    outer_batch_dim = mtf.Dimension("outer_batch", outer_batch_size)
    batch_dim = mtf.Dimension("batch", batch_size // outer_batch_size)
    length_dim = mtf.Dimension("length", sequence_length)
    mtf_shape = mtf.Shape([outer_batch_dim, batch_dim, length_dim])
    if key not in features:
      if allow_missing:
        return None
      else:
        raise ValueError("feature not found: %s (features: %s)" %
                         (key, features))
    tf.logging.info("Import feature %s: %s" % (key, features[key]))
    x = tf.to_int32(features[key])
    x = tf.reshape(x, [outer_batch_size, batch_size // outer_batch_size, -1])
    if not use_tpu:
      x = tf.Print(x, [x], "import feature %s" % key,
                   summarize=1000, first_n=1)
    return mtf.import_fully_replicated(mesh, x, mtf_shape, name=key)

  if mode == tf.estimator.ModeKeys.PREDICT:
    inputs = _import_feature("inputs")
    inputs = mtf.reshape(
        inputs,
        mtf.Shape([
            mtf.Dimension("batch", batch_size),
            mtf.Dimension("length", sequence_length)
        ]))
    if isinstance(transformer_model, transformer.Unitransformer):
      mtf_samples = transformer_model.sample_autoregressive(
          inputs, variable_dtype=get_variable_dtype())
    elif isinstance(transformer_model, transformer.Bitransformer):
      mtf_samples = transformer_model.decode(
          inputs, variable_dtype=get_variable_dtype())
    else:
      raise ValueError("unrecognized class")
    mtf_samples = mtf.anonymize(mtf_samples)
    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
    outputs = lowering.export_to_tf_tensor(mtf_samples)
    predictions = {"outputs": outputs}
    return tpu_estimator.TPUEstimatorSpec(
        mode=tf.estimator.ModeKeys.PREDICT,
        predictions=predictions,
        prediction_hooks=[mtf.MtfRestoreHook(lowering)])

  targets = _import_feature("targets")
  anon_targets = mtf.anonymize(targets)
  if model_type == "lm":
    # Language model: inputs are the targets shifted right by one position.
    _, length_dim = targets.shape
    inputs = mtf.shift(targets, offset=1, dim=length_dim, wrap=False)
  else:
    inputs = _import_feature("inputs")

  if mode == tf.estimator.ModeKeys.EVAL:
    if isinstance(transformer_model, transformer.Unitransformer):
      mtf_samples = transformer_model.sample_autoregressive(
          inputs, variable_dtype=get_variable_dtype())
    elif isinstance(transformer_model, transformer.Bitransformer):
      mtf_samples = transformer_model.decode(
          inputs, variable_dtype=get_variable_dtype())
    else:
      raise ValueError("unrecognized class")
    mtf_samples = mtf.anonymize(mtf_samples)
    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
    outputs = lowering.export_to_tf_tensor(mtf_samples)
    labels = lowering.export_to_tf_tensor(anon_targets)
    restore_hook = mtf.MtfRestoreHook(lowering)

    # metric_names becomes locally scoped if we simply assign
    # ["token_accuracy"] to it conditioned on it being None.
    local_metric_names = metric_names or ["token_accuracy"]

    def metric_fn(labels, outputs):
      return get_metric_fns(local_metric_names, labels, outputs)

    eval_metrics = (metric_fn, [labels, outputs])
    return tpu_estimator.TPUEstimatorSpec(
        tf.estimator.ModeKeys.EVAL,
        # Unfortunately TPUEstimatorSpec requires us to provide a value for
        # loss when in EVAL mode. Since we are sampling or decoding from the
        # model, we don't have a loss to report.
        loss=tf.constant(0.),
        evaluation_hooks=[restore_hook],
        eval_metrics=eval_metrics)

  if isinstance(transformer_model, transformer.Unitransformer):
    position_kwargs = dict(
        sequence_id=_import_feature("targets_segmentation", True),
        position=_import_feature("targets_position", True),
    )
  elif isinstance(transformer_model, transformer.Bitransformer):
    position_kwargs = dict(
        encoder_sequence_id=_import_feature("inputs_segmentation", True),
        decoder_sequence_id=_import_feature("targets_segmentation", True),
        encoder_position=_import_feature("inputs_position", True),
        decoder_position=_import_feature("targets_position", True),
    )
  else:
    raise ValueError("unrecognized class")

  logits, loss = transformer_model.call_simple(
      inputs=inputs,
      targets=targets,
      compute_loss=True,
      mode=mode,
      variable_dtype=get_variable_dtype(),
      **position_kwargs)

  if use_tpu and logits is not None:
    logits = mtf.anonymize(logits)

  # TRAIN mode
  if mode == tf.estimator.ModeKeys.TRAIN:
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in graph.trainable_variables])
    optimizer = mtf.optimize.AdafactorOptimizer(learning_rate=learning_rate)
    update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)

  lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)

  tf_loss = lowering.export_to_tf_tensor(loss)
  tf_loss = tf.to_float(tf_loss)
  if not use_tpu:
    tf_loss = tf.Print(tf_loss, [tf_loss, tf.train.get_global_step()],
                       "step, tf_loss")
  if mode == tf.estimator.ModeKeys.TRAIN:
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    train_op = tf.group(tf_update_ops)

  with mtf.utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    saver = tf.train.Saver(
        tf.global_variables(),
        sharded=True,
        max_to_keep=checkpoints_to_keep,
        keep_checkpoint_every_n_hours=2,
        defer_build=False,
        save_relative_paths=True)
    tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
    saver_listener = mtf.MtfCheckpointSaverListener(lowering)
    saver_hook = tf.train.CheckpointSaverHook(
        model_dir,
        save_steps=save_steps,
        saver=saver,
        listeners=[saver_listener])
    gin_config_saver_hook = gin.tf.GinConfigSaverHook(
        model_dir, summarize_config=True)

    if mode == tf.estimator.ModeKeys.TRAIN:
      if use_tpu:
        return tpu_estimator.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_hooks=[
                restore_hook,
                saver_hook,
                gin_config_saver_hook,
            ])
      else:
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_chief_hooks=[
                restore_hook,
                saver_hook,
                gin_config_saver_hook,
            ])
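
# A toy check of the "lm" branch above, where inputs are the targets shifted
# right by one along the length dimension (mtf.shift with offset=1 and
# wrap=False zero-fills the vacated position). Values are made up.
import numpy as np

targets = np.array([[11, 12, 13, 14]])
inputs = np.roll(targets, shift=1, axis=-1)
inputs[:, 0] = 0  # wrap=False: pad with zero instead of wrapping around
# inputs == [[0, 11, 12, 13]]: position t sees only tokens before t,
# so the model learns to predict targets[t] from targets[:t].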
# A later variant of my_model_fn that imports all features up front and
# supports gradient accumulation over microbatches. Like the version above,
# it closes over configuration from the enclosing scope (use_tpu, batch_size,
# outer_batch_size, sequence_length, mesh_shape, layout_rules,
# transformer_model, model_type, optimizer, learning_rate_schedule, ...).
def my_model_fn(features, labels, mode, params=None, config=None):
  """Estimator model function.

  Args:
    features: input features dictionary
    labels: ignored
    mode: a tf.estimator.ModeKeys
    params: parameters passed by the (TPU)Estimator; on TPU this contains
      the "context" entry
    config: a tf.estimator.RunConfig (unused)

  Returns:
    a TPUEstimatorSpec (or a tf.estimator.EstimatorSpec on CPU/GPU)
  """
  del labels, config
  global_step = tf.train.get_global_step()
  if use_tpu:
    ctx = params["context"]
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
    # TODO(ylc): Better estimation of replica cache size?
    replica_cache_size = 300 * 1000000  # 300M per replica
    # Worker 0 caches all the TPU binaries.
    worker0_mem = replica_cache_size * ctx.num_replicas
    devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                  devices_memory_usage)
    mesh_devices = [""] * mesh_shape.size
    physical_shape = list(
        params["context"].device_assignment.topology.mesh_shape)
    logical_to_physical = _logical_to_physical(physical_shape, mesh_shape)
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape,
        layout_rules,
        mesh_devices,
        ctx.device_assignment,
        logical_to_physical=logical_to_physical)
  else:
    var_placer = None
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh", var_placer)

  outer_batch_dim = mtf.Dimension("outer_batch", outer_batch_size)
  batch_dim = mtf.Dimension("batch", batch_size // outer_batch_size)
  length_dim = mtf.Dimension("length", sequence_length)
  feature_shape = mtf.Shape([outer_batch_dim, batch_dim, length_dim])

  # Import every feature as a fully-replicated int32 mtf.Tensor of shape
  # [outer_batch, batch // outer_batch, length].
  mtf_features = {}
  for key, x in features.items():
    x = tf.to_int32(features[key])
    x = tf.reshape(
        x,
        [outer_batch_size, batch_size // outer_batch_size, sequence_length])
    if not use_tpu:
      x = tf.Print(x, [x], "import feature %s" % key,
                   summarize=1000, first_n=1)
    mtf_features[key] = mtf.import_fully_replicated(
        mesh, x, feature_shape, name=key)

  if mode == tf.estimator.ModeKeys.PREDICT:
    inputs = mtf_features["inputs"]
    inputs = mtf.reshape(
        inputs,
        mtf.Shape([
            mtf.Dimension("batch", batch_size),
            mtf.Dimension("length", sequence_length)
        ]))
    if isinstance(transformer_model, transformer.Unitransformer):
      mtf_samples = transformer_model.sample_autoregressive(
          inputs, variable_dtype=get_variable_dtype())
    elif isinstance(
        transformer_model,
        (transformer.Bitransformer, transformer.StudentTeacher)):
      mtf_samples = transformer_model.decode(
          inputs, variable_dtype=get_variable_dtype())
    else:
      raise ValueError("unrecognized class")
    mtf_samples = mtf.anonymize(mtf_samples)
    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
    outputs = lowering.export_to_tf_tensor(mtf_samples)
    predictions = {"outputs": outputs}
    return tpu_estimator.TPUEstimatorSpec(
        mode=tf.estimator.ModeKeys.PREDICT,
        predictions=predictions,
        prediction_hooks=[mtf.MtfRestoreHook(lowering)])

  elif mode == tf.estimator.ModeKeys.EVAL:
    raise NotImplementedError("We don't expect to use mode == eval.")

  else:
    assert mode == tf.estimator.ModeKeys.TRAIN
    num_microbatches = serialize_num_microbatches(
        batch_dim, length_dim, mesh_shape, layout_rules)

    def model_fn(mtf_features):
      """The kind of function we need for mtf.serialize_training_step.

      Args:
        mtf_features: a dictionary of mtf.Tensors keyed by feature name

      Returns:
        a dictionary containing the loss under the key "loss"
      """
      targets = mtf_features["targets"]
      if model_type == "lm":
        # Language model: inputs are the targets shifted right by one.
        _, _, length_dim = targets.shape
        inputs = mtf.shift(targets, offset=1, dim=length_dim, wrap=False)
      else:
        inputs = mtf_features["inputs"]

      if isinstance(transformer_model, transformer.Unitransformer):
        position_kwargs = dict(
            sequence_id=mtf_features.get("targets_segmentation", None),
            position=mtf_features.get("targets_position", None),
        )
      elif isinstance(
          transformer_model,
          transformer.Bitransformer) or model_type == "bi_student_teacher":
        position_kwargs = dict(
            encoder_sequence_id=mtf_features.get("inputs_segmentation", None),
            decoder_sequence_id=mtf_features.get("targets_segmentation",
                                                 None),
            encoder_position=mtf_features.get("inputs_position", None),
            decoder_position=mtf_features.get("targets_position", None),
        )
      else:
        raise ValueError("unrecognized class")

      logits, loss = transformer_model.call_simple(
          inputs=inputs,
          targets=targets,
          compute_loss=True,
          mode=mode,
          variable_dtype=get_variable_dtype(),
          **position_kwargs)
      if num_microbatches > 1:
        # Scale so that the accumulated gradient over all microbatches
        # equals the full-batch gradient.
        loss /= float(num_microbatches)
      del logits
      return {"loss": loss}

    if num_microbatches > 1:
      var_grads, loss_dict = mtf.serialize_training_step(
          mtf_features, model_fn, batch_dim, num_microbatches)
    else:
      loss_dict = model_fn(mtf_features)
      var_grads = mtf.gradients(
          [loss_dict["loss"]],
          [v.outputs[0] for v in graph.trainable_variables])

    loss = loss_dict["loss"]

    if callable(learning_rate_schedule):
      # the following happens on CPU since TPU can't handle summaries.
      with mtf.utils.outside_all_rewrites():
        learning_rate = learning_rate_schedule(
            step=tf.train.get_global_step())
        tf.summary.scalar("learning_rate", learning_rate)
    else:
      learning_rate = learning_rate_schedule

    update_ops = optimizer(learning_rate=learning_rate).apply_grads(
        var_grads, graph.trainable_variables)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)

    tf_loss = lowering.export_to_tf_tensor(loss)
    tf_loss = tf.to_float(tf_loss)
    if not use_tpu:
      tf_loss = tf.Print(tf_loss, [tf_loss, tf.train.get_global_step()],
                         "step, tf_loss")

    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    train_op = tf.group(tf_update_ops)

    if hasattr(transformer_model, "initialize"):
      with mtf.utils.outside_all_rewrites():
        transformer_model.initialize()

    with mtf.utils.outside_all_rewrites():
      # Copy master variables to slices. Must be called first.
      restore_hook = mtf.MtfRestoreHook(lowering)
      saver = tf.train.Saver(
          tf.global_variables(),
          sharded=True,
          max_to_keep=keep_checkpoint_max,
          keep_checkpoint_every_n_hours=2,
          defer_build=False,
          save_relative_paths=True)
      tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
      saver_listener = mtf.MtfCheckpointSaverListener(lowering)
      saver_hook = tf.train.CheckpointSaverHook(
          model_dir,
          save_steps=save_checkpoints_steps,
          saver=saver,
          listeners=[saver_listener])
      gin_config_saver_hook = gin.tf.GinConfigSaverHook(
          model_dir, summarize_config=True)

      if use_tpu:
        if tpu_summaries:
          tf.summary.scalar("loss", tf_loss)
          host_call = mtf.utils.create_host_call(model_dir)
          mtf.utils.remove_summaries()
        else:
          host_call = None
        return tpu_estimator.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            host_call=host_call,
            training_hooks=[
                restore_hook,
                saver_hook,
                gin_config_saver_hook,
            ])
      else:
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_chief_hooks=[
                restore_hook,
                saver_hook,
                gin_config_saver_hook,
            ])
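
# Why model_fn divides the loss by num_microbatches: summing the scaled
# per-microbatch gradients (which is what gradient accumulation over the
# serialized training step effectively computes) reproduces the full-batch
# gradient. A toy scalar model with made-up data, for illustration only.
import numpy as np

w = 2.0
xs = np.array([1.0, 2.0, 3.0, 4.0])
ys = np.array([3.0, 5.0, 7.0, 9.0])


def grad(w, x, y):
  """d/dw of the mean squared error over one (micro)batch."""
  return np.mean(2.0 * (w * x - y) * x)


num_microbatches = 2
full_batch = grad(w, xs, ys)
accumulated = sum(
    grad(w, xs[i::num_microbatches], ys[i::num_microbatches]) /
    num_microbatches for i in range(num_microbatches))
assert np.isclose(full_batch, accumulated)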