def _Apply():
  """Use the matched optimizer to apply the gradients.

  Closure: reads ``self._optimizer_map``, ``self._lr_map``, ``var_grad_map``,
  ``tf_optimizer_map`` and ``var_grad`` from the enclosing scope.

  For every optimizer key (regex or ``'default_optimizer'``) with at least one
  assigned (grad, var) pair, creates its ``apply_gradients`` op and emits that
  optimizer's summaries over the variables it is responsible for.

  Returns:
    A ``tf.group`` of all per-optimizer train ops, named
    'composite_optimizer_train_op'.
  """
  train_ops = []
  non_default_regex = [
      regex for regex in self._optimizer_map if regex != 'default_optimizer'
  ]
  for regex in self._optimizer_map:
    if not var_grad_map[regex]:
      continue
    opt = tf_optimizer_map[regex]
    train_ops.append(opt.apply_gradients(var_grad_map[regex]))
    if regex == 'default_optimizer':
      # The default optimizer is presumably assigned the variables that match
      # none of the explicit regexes, so summarize exactly those. (The
      # previous filter selected the opposite set.)
      filtered_var_grad = var_grad.FilterKeyVal(
          lambda k, v: not any(
              re.match(i, v.var.name) for i in non_default_regex))
    else:
      # Bind the current regex as a default argument so the predicate does
      # not depend on late lookup of the loop variable.
      filtered_var_grad = var_grad.FilterKeyVal(
          lambda k, v, pattern=regex: bool(re.match(pattern, v.var.name)))
    self._optimizer_map[regex].AddSummary(self._lr_map[regex], opt,
                                          filtered_var_grad)
  return tf.group(*train_ops, name='composite_optimizer_train_op')
def PostTrainingStepUpdate(self, global_step):
  """Runs the parent update, then records and summarizes each tracked tensor.

  Args:
    global_step: Forwarded unchanged to the parent implementation.

  Returns:
    A ``tf.group`` over the parent's update op and every op returned by
    ``self._RecordTensor`` for the names in ``self._t_names``.
  """
  parent_update = super(PassiveAsymQDomain,
                        self).PostTrainingStepUpdate(global_step)
  update_ops = [parent_update]
  for tensor_name in self._t_names:
    # Record before summarizing, matching the original per-name order.
    update_ops += self._RecordTensor(tensor_name)
    self._SummarizeTensor(tensor_name)
  return tf.group(update_ops)
def CreateTpuEmbeddingEnqueueOps(self):
  """Creates the TpuEmbedding enqueue ops on the host.

  Note that this must be called after the instantiation of the monolithic
  TPUEmbeddingLayer.

  Each embedding input feature is split evenly across the cores of a host and
  converted from dense to sparse form by treating -1 as the padding id. The
  grouped enqueue ops for all infeed hosts are appended to
  ``self._tpu_infeed_op``. Returns early (adding nothing) when no TPU
  embedding is registered in the ``py_utils.TPU_EMBEDDING`` collection.

  Raises:
    ValueError: if a TPU embedding is registered, there is more than one TPU
      host, and ``p.use_per_host_infeed`` is False.
  """
  p = self.params
  cluster = self.cluster
  num_tpu_hosts = cluster.num_tpu_hosts
  num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

  tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
  tpu_embedding = (tpu_embedding_collection[0]
                   if tpu_embedding_collection else None)

  enqueue_ops = []

  if (num_tpu_hosts > 1 and tpu_embedding is not None and
      not p.use_per_host_infeed):
    # tf.logging.fatal does not raise, so execution would continue with
    # num_infeed_hosts == 1 and only task 0 would receive enqueue ops.
    # Fail explicitly instead.
    raise ValueError(
        'TPU Embedding must be used with per_host_infeed with multiple '
        'TPU host topologies.')

  tpu_emb_input_keys = (list(tpu_embedding.feature_to_config_dict.keys())
                        if tpu_embedding is not None else [])
  tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)
  if not tpu_embedding:
    return

  # Loop-invariant: read once rather than per host.
  num_cores_per_host = tpu_embedding.num_cores_per_host
  for task_id in range(num_infeed_hosts):
    host_device = '/task:{}/device:CPU:0'.format(task_id)
    with tf.device(host_device):
      if isinstance(self._batch, py_utils.NestedMap):
        # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
        # Note that when MultiTaskData is used, bucket_keys will be at the
        # second level of the dictionary.
        self._batch = self._batch.FilterKeyVal(
            lambda k, _: not k.endswith('bucket_keys'))
      tf.logging.info('host_device: %s, batch: %r', host_device, self._batch)

      enqueue_dict_per_core = [{} for _ in range(num_cores_per_host)]
      for key in tpu_emb_input_keys:
        feat = self._batch[key]
        tpu_emb_feat_splitted = tf.split(feat, num_cores_per_host)
        for core, split in enumerate(tpu_emb_feat_splitted):
          # Dense to sparse. Note the assumption of a padding id (-1).
          sample_indices = tf.where(tf.not_equal(split, -1))
          embedding_indices = tf.gather_nd(split, sample_indices)
          enqueue_data = tpu_embedding_lib.EnqueueData(
              embedding_indices, sample_indices)
          enqueue_dict_per_core[core][key] = enqueue_data
      enqueue_ops += tpu_embedding.generate_enqueue_ops(enqueue_dict_per_core)
  self._tpu_infeed_op.append(tf.group(*enqueue_ops))
def _ApplyAndReset():
  """Applies the averaged accumulated gradients, then zeroes the accumulators.

  Closure: reads ``self._opt``, ``lr``, ``p.accum_steps`` and ``var_grad``
  from the enclosing scope.

  Returns:
    A ``tf.group`` of accumulator-reset assigns, each created under a control
    dependency on the optimizer's apply op.
  """
  # Scale the summed gradients by 1 / accum_steps to get their mean.
  averaged_var_grad = py_utils.ApplyGradMultiplier(var_grad,
                                                   1. / p.accum_steps)
  apply_op = self._opt.Apply(lr, averaged_var_grad)
  with tf.control_dependencies([apply_op]):
    reset_ops = [
        tf.assign(accumulator, tf.zeros_like(accumulator))
        for _, accumulator in var_grad.Flatten()
    ]
    return tf.group(*reset_ops)
def ApplyPostTrainingLoop(self, global_step):
  """Gathers each sub-optimizer's post-training-loop op into a single group.

  Intended to run after each TPU training loop.

  Args:
    global_step: Global step variable, forwarded to every sub-optimizer.

  Returns:
    Ops to run after training loop ends, as one ``tf.group``.
  """
  loop_end_ops = []
  for sub_optimizer in self._optimizer_map.values():
    loop_end_ops.append(sub_optimizer.ApplyPostTrainingLoop(global_step))
  return tf.group(*loop_end_ops)
def PostTrainingStepUpdate(self, global_step):
  """Updates moving_mean, moving_variance after each training step.

  Folds the sufficient statistics accumulated over the step's microbatches
  into the moving statistics, then resets the accumulators.

  Args:
    global_step: Not referenced in this body; part of the
      PostTrainingStepUpdate interface.

  Returns:
    A ``tf.group`` of the moving_mean and moving_variance update ops.
  """
  p = self.params
  accumulators = self.accumulators
  # Sufficient stats accumulated over microbatches.
  counts = accumulators.counts.GetValue()
  mean_ss = accumulators.mean_ss.GetValue()
  variance_ss = accumulators.variance_ss.GetValue()
  # Batch mean / variance recovered from the sufficient stats.
  mean, variance = tf.nn.normalize_moments(counts, mean_ss, variance_ss, None)
  decay = tf.convert_to_tensor(1.0 - p.decay, p.dtype)

  def _UpdateMoving(moving, batch_stat, op_name):
    # moving -= (moving - batch_stat) * decay when counts > 0.5; otherwise
    # subtract zeros so the moving statistic is left unchanged.
    with tf.ops.colocate_with(moving):
      return tf.assign_sub(
          moving,
          tf.where(
              tf.greater(counts, 0.5),
              (moving - tf.cast(batch_stat, p.dtype)) * decay,
              tf.zeros_like(moving)),
          name=op_name)

  with tf.name_scope(p.name) as scope:
    mean_update = _UpdateMoving(self.vars.moving_mean, mean,
                                'moving_mean_update')
    var_update = _UpdateMoving(self.vars.moving_variance, variance,
                               'moving_variance_update')

  # NOTE(review): the values returned by CheckNumerics are discarded and are
  # not in the returned group, so in graph mode these checks may never be
  # executed. Confirm whether that is intended.
  py_utils.CheckNumerics(
      self.vars.moving_mean,
      'moving mean of {} failed numeric check'.format(scope))
  py_utils.CheckNumerics(
      self.vars.moving_variance,
      'moving variance of {} failed numeric check'.format(scope))

  accumulators.counts.Reset()
  accumulators.mean_ss.Reset()
  accumulators.variance_ss.Reset()
  return tf.group(mean_update, var_update)
def Apply(self, lr, var_grad):
  """Accumulates gradients and applies them once every ``p.accum_steps`` steps.

  Every call adds each gradient into a per-variable non-trainable
  'grad_accumulator'. When ``global_step % accum_steps == accum_steps - 1``,
  the wrapped optimizer ``self._opt`` is applied to the accumulated gradients
  scaled by ``1 / accum_steps``, and the accumulators are then zeroed.

  Args:
    lr: Learning rate, forwarded to ``self._opt.Apply``.
    var_grad: Container of (variable, gradient) pairs supporting
      ``Transform`` / ``Flatten``.

  Returns:
    The ``tf.cond`` op selecting between apply-and-reset and a no-op group.
  """
  p = self.params

  def _Acc(vg):
    """Updating accumulators."""
    var, grad = vg
    with tf.variable_scope(var.op.name):
      _, accumulator = py_utils.CreateVariable(
          'grad_accumulator',
          py_utils.WeightParams(var.get_shape(),
                                py_utils.WeightInit.Constant(0.0), p.dtype),
          trainable=False)
      accumulated = tf.assign_add(accumulator, grad)
    return py_utils.VarGrad(var, accumulated)

  accumulated_var_grad = var_grad.Transform(_Acc)

  def _ApplyAndReset():
    # Ops built here live inside the true branch of the tf.cond below.
    averaged = py_utils.ApplyGradMultiplier(accumulated_var_grad,
                                            1. / p.accum_steps)
    with tf.control_dependencies([self._opt.Apply(lr, averaged)]):
      return tf.group(*[
          tf.assign(acc, tf.zeros_like(acc))
          for _, acc in accumulated_var_grad.Flatten()
      ])

  is_last_microstep = tf.equal(
      tf.math.floormod(self.theta.global_step, p.accum_steps),
      p.accum_steps - 1)
  return tf.cond(is_last_microstep, _ApplyAndReset,
                 lambda: tf.group(tf.no_op()))
def Export(cls,
           model_cfg,
           model_task_name=None,
           device_options=InferenceDeviceOptions(
               device='',
               retain_device_placement=False,
               var_options=None,
               gen_init_op=True,
               dtype_override=None),
           freeze_checkpoint=None,
           freeze_defaults=False,
           export_path=None,
           subgraph_filter=None,
           random_seed=None,
           disable_packed_input=True):
  """Exports an InferenceGraph proto with piecewise subgraphs.

  Sets FLAGS.enable_asserts to False unless user explicitly sets it to True.
  Note that ``model_cfg`` is modified in place (random_seed, is_inference,
  packed_input flags and, with a bfloat16 override, its dtypes).

  Args:
    model_cfg: a Params instance as returned by
      model_registry.GetParams(modelname, 'Test') or model_params.Model().
    model_task_name: The task to generate an inference graph for. Should be
      None for single-task models.
    device_options: Device options for the accelerator used for serving.
    freeze_checkpoint: The checkpoint to load. Loads and freezes the model if
      given.
    freeze_defaults: Default initializes the graph and freeze. Useful for
      early testing of downstream tools without having a checkpoint.
    export_path: If not None, write the inference graph in ASCII to this path.
    subgraph_filter: A list of subgraph names. If not None or empty, export
      only this list of inference subgraphs. Names not produced by the model
      are silently ignored.
    random_seed: Fixes the random seed in the exported inference graph.
    disable_packed_input: Disable packed input for inference writing purposes.

  Returns:
    InferenceGraph proto.

  Raises:
    ValueError: if freezing is requested (freeze_checkpoint or
      freeze_defaults) and cls._DeviceSupportsFreezing(device_options) returns
      True. Helpers called here may raise other errors that are not visible
      from this function.
  """
  # model_cfg.cls must be a BaseModel subclass (an assert, so it is stripped
  # under -O).
  assert issubclass(model_cfg.cls, base_model.BaseModel)

  # Disable assertions unless user explicitly enables it.
  if FLAGS['enable_asserts'].using_default_value:
    FLAGS.enable_asserts = False

  # TODO(laurenzo): Work out how much we need to specify here in terms of
  # cluster configuration.
  cls._SetClusterParams(model_cfg.cluster, device_options)

  # Configure the model.
  model_cfg.random_seed = random_seed
  model_cfg.is_inference = True

  if disable_packed_input:

    def _DisablePackedInput(task):
      # Turns off packed_input on the task's encoder/decoder when present.
      if (_ParamExists(task, 'encoder') and
          _ParamExists(task.encoder, 'packed_input')):
        task.encoder.packed_input = False
      if (_ParamExists(task, 'decoder') and
          _ParamExists(task.decoder, 'packed_input')):
        task.decoder.packed_input = False

    if issubclass(model_cfg.cls, base_model.MultiTaskModel):
      for _, task_param in model_cfg.task_params.IterParams():
        _DisablePackedInput(task_param)
    else:
      _DisablePackedInput(model_cfg.task)

  tf.logging.info('Model %s params:', model_cfg.name)
  for line in model_cfg.ToText().split('\n'):
    tf.logging.info('%s', line)

  # Instantiate the graph.
  graph = tf.Graph()
  with graph.as_default():
    tf.random.set_seed(random_seed)
    cluster = model_cfg.cluster.Instantiate()
    device = cluster.GetPlacer()
    tpu_const_scope = _DummyScope()
    if (IsTpu(device_options) and
        device_options.var_options == 'AS_CONSTANTS'):
      # Do not specify devices for variables if we are marking them as
      # constants.
      device = ''
      tpu_const_scope = ConstGuaranteeScope()

    with cluster, tf.device(device), tpu_const_scope:

      bfloat16_override = ShouldForceBfloat16ForWeightsAndActivations(
          device_options)

      if bfloat16_override:
        py_utils.UpdateDtype(model_cfg, tf.bfloat16)
        py_utils.UpdateFpropDtype(model_cfg, tf.bfloat16)

      # Hard-code TPU-related flags prior to instantiating model.
      # NOTE(review): the flags are changed before the try/finally below, so
      # an exception from GetOrCreateGlobalStepVar/tf.identity would leave
      # them modified.
      old_enable_asserts = FLAGS.enable_asserts
      old_xla_device = FLAGS.xla_device
      if IsTpu(device_options):
        FLAGS.enable_asserts = False
        FLAGS.xla_device = 'tpu'

      # Ensure the global_step variable is created.
      global_step_var = py_utils.GetOrCreateGlobalStepVar()
      global_step = tf.identity(global_step_var, name='global_step_tensor')

      with py_utils.GlobalStepContext(global_step):
        try:
          mdl = model_cfg.Instantiate()
          # Restore EMA shadow variables when the model has EMA enabled.
          variables_to_restore = (
              _MakeVariableDictionary(tf.global_variables()) if not mdl.ema
              else mdl.ema.variables_to_restore(mdl.variables_for_ema))

          if bfloat16_override:
            saver_var_spec = (
                bfloat16_variables
                .get_saver_spec_for_variables_with_bf16_overrides(
                    variables_to_restore))
          else:
            saver_var_spec = variables_to_restore

          saver = tf.train.Saver(saver_var_spec)
          # Named so the non-frozen path below can keep it when pruning.
          tf.variables_initializer(
              tf.global_variables(), name='init_all_variables')
          if IsTpu(device_options) and device_options.gen_init_op:
            tf.group(tf.tpu.initialize_system(), name='tpu_init_op')

          model_task = mdl.GetTask(model_task_name)

          inference_graph_proto = inference_graph_pb2.InferenceGraph()
          subgraphs_proto = model_task.Inference()
          if isinstance(subgraphs_proto, dict):
            subgraphs_proto = ConvertSubgraphDictToProto(subgraphs_proto)
          for name, subgraph in subgraphs_proto.subgraphs.items():
            # Unknown names in subgraph_filter are skipped without error.
            if not subgraph_filter or name in subgraph_filter:
              inference_graph_proto.subgraphs[name].CopyFrom(subgraph)

          # Add a table init op and global variable init op to the graph.
          # Tables can be declared anywhere in the graph, so this op has to be
          # added last.
          tf.tables_initializer(name='init_all_tables')
        finally:
          # Reset TPU-related flags after model instantiation.
          FLAGS.enable_asserts = old_enable_asserts
          FLAGS.xla_device = old_xla_device

  tf.logging.info('Graph contains ops: %r',
                  [op.name for op in graph.get_operations()])

  inference_graph_proto.saver_def.CopyFrom(saver.as_saver_def())

  # Freezing.
  if freeze_defaults or freeze_checkpoint:
    output_op_names = GetOutputOpNames(
        graph, inference_graph_proto, preserve_colocation_nodes=False)
    # NOTE(review): this raises when _DeviceSupportsFreezing(...) is True,
    # which reads inverted relative to the helper's name. Confirm the helper's
    # semantics (defined outside this view) before relying on it.
    if cls._DeviceSupportsFreezing(device_options):
      raise ValueError('freeze_checkpoint cannot be used with device ' +
                       device_options.device)
    if freeze_checkpoint:
      tf.logging.info('Freezing graph from checkpoint: %s', freeze_checkpoint)
      graph_def = _FreezeGraphFromCheckpoint(graph, saver, freeze_checkpoint,
                                             output_op_names)
    elif freeze_defaults:
      tf.logging.info('Default initializing graph and freezing.')
      graph_def = _FreezeDefaults(graph, output_op_names)
  else:
    output_op_names = GetOutputOpNames(graph, inference_graph_proto)

    # Prune the graph to just the parts we need.
    # To support restoring, we have to not prune out the restore node.
    output_op_names.append('init_all_tables')
    output_op_names.append('init_all_variables')
    output_op_names.append('save/control_dependency')
    output_op_names.append('save/restore_all')
    if IsTpu(device_options) and device_options.gen_init_op:
      output_op_names.append('tpu_init_op')
    graph_def = graph.as_graph_def()
    tf.logging.info('Pruning graph to output ops: %r', output_op_names)
    graph_def = tf.graph_util.extract_sub_graph(graph_def, output_op_names)

  if not device_options.retain_device_placement:
    # Clear the device so that the runtime can choose.
    tf.logging.info('Clearing device placement for: %s',
                    device_options.device)
    for node in graph_def.node:
      node.ClearField('device')
    # Also clear devices on nodes inside graph library functions.
    for function in graph_def.library.function:
      for node_def in function.node_def:
        node_def.ClearField('device')

  inference_graph_proto.graph_def.CopyFrom(graph_def)

  if export_path:
    with tf.io.gfile.GFile(export_path, 'w') as f:
      f.write(text_format.MessageToString(inference_graph_proto))
  return inference_graph_proto