def Inference(self):
  """Builds an InferenceGraph proto with a single feed and fetch.

  Returns:
    An `inference_graph_pb2.InferenceGraph` whose 'default' subgraph maps
    'feed1' -> the placeholder and 'fetch1' -> its identity.
  """
  with tf.name_scope('inference'):
    placeholder = tf.placeholder(name='feed1_node', dtype=tf.float32, shape=[1])
    output = tf.identity(placeholder, name='fetch1_node')
    graph_proto = inference_graph_pb2.InferenceGraph()
    default_subgraph = graph_proto.subgraphs['default']
    default_subgraph.feeds['feed1'] = placeholder.name
    default_subgraph.fetches['fetch1'] = output.name
    return graph_proto
def PreBeamSearchStepCallback(unused_theta, unused_encoder_outputs,
                              unused_step_ids, states,
                              unused_num_hyps_per_beam):
  """Beam-search step callback returning fixed-seed random logits.

  Passes the attention probabilities through unchanged and draws logits from
  a seeded normal distribution so the result is deterministic.
  """
  probs = tf.identity(states.atten_probs)
  # Seeded so repeated runs produce identical logits.
  random_logits = tf.random.normal([tgt_batch_size, vocab_size], seed=8273747)
  step_outputs = py_utils.NestedMap({
      'atten_probs': probs,
      'log_probs': random_logits,
  })
  return step_outputs, states
def __AddVariable(self, var):
  """Registers `var` in the private var/theta maps, skipping duplicates."""
  var_name = self.get_name(var)
  if var_name in self._activated_var_names:
    # Already registered once; do not overwrite the existing entry.
    tf.logging.info(
        "Warning, already activated variable with name {}".format(var_name))
    return
  tf.logging.info("Adding variable {}".format(var_name))
  self._private_vars[var_name] = var
  self._private_theta[var_name] = tf.identity(var)
  self._activated_var_names.add(var_name)
def _ConstructPostTrainingLoop(train_loop_op, outfeed_dequeue_op): """Returns the op for tpu training with tail cpu computation.""" # Adds a tail computation that is run after the tpu_training loop # step finishes. This allows us to run certain computation that # acts on the variable between tpu_train_loop iterations and # amortizing the cost of the operations. Alternative of running # tpu.outside_compilation & using tf.cond is expenseive. with tf.control_dependencies(train_loop_op): self._model.ConstructPostTrainingLoop() with tf.control_dependencies([self._task.post_training_loop_op]): return ([[tf.identity(o) for o in train_loop_op], outfeed_dequeue_op])
def _CreateLayerVariables(self):
  """Creates (or reuses) per-host embedding table shards and slot/load ops."""
  p = self.params

  # Reuse the singleton table variables if they were created before.
  all_table_vars = self._tpu_embedding_collection.table_variables
  if self.table_name in all_table_vars:
    embedding_table_vars = all_table_vars[self.table_name]
  else:
    w_pc = py_utils.WeightParams(
        shape=[self._ids_per_shard, p.embedding_dim],
        init=p.params_init,
        dtype=p.dtype,
        collections=[self.__class__.__name__ + '_vars'])
    embedding_table_vars = []
    # One shard per TPU host, each placed on that host's device.
    for i in range(p.num_tpu_hosts):
      device_name = self.GetDeviceName(i)
      with tf.device(device_name), py_utils.outside_all_rewrites():
        var_name = self.GetVariableName(i)
        self.CreateVariable(var_name, w_pc)
        embedding_var = self.vars[var_name]
        embedding_table_vars.append(embedding_var)
        # Remove from _private_vars / _private_thetas to be added later as wm.
        _RemovePrivateVar(self, var_name)
    self._tpu_embedding_collection.AddTableVariables(self.table_name,
                                                     embedding_table_vars)

  if not _ShouldUseTpu(p):
    # We don't want to add this for TrainerTpu, otherwise the identity
    # reference leads to copying the embedding to the TPU for no reason.
    # However, this is needed for CPU (eval/decode/controller).
    self._private_vars['wm'] = embedding_table_vars
    self._private_theta['wm'] = [tf.identity(v) for v in embedding_table_vars]

  # If slot variables and load/retrieve ops were created before, maybe by a
  # different program or task, don't create it again.
  # Note that there should be only one copy of slot variables and
  # load/retrieve ops in the graph and they're shared by different
  # tasks/programs.
  all_load_ops = self._tpu_embedding_collection.load_ops
  if self.table_name not in all_load_ops:
    assert self.table_name not in self._tpu_embedding_collection.retrieve_ops
    # Only trainer and controller (for checkpointing) need slot variables.
    # Only trainer needs load/retrieve ops.
    if not self.do_eval and not p.is_inference:
      load_ops, retrieve_ops = self.optimizer.CreateSlotVariablesAndOps(
          embedding_table_vars, self)
      self._tpu_embedding_collection.AddLoadRetrieveOps(
          self.table_name, load_ops, retrieve_ops)
def CreateVariable(self, name: str, var_params: hyperparams.Params,
                   **kwargs) -> None:
  """Create a variable of this layer according to the parameter `var_params`.

  E.g.::

      def __init__(self, ...):    # A layer's constructor
        self.CreateVariable(
            'weight', py_utils.WeightParams(shape=[100, 100]))

  Args:
    name: Variable name which is used as the key into vars/theta.
    var_params: `Params` used to create the variable.
    **kwargs: Keyword args passed to `.py_utils.CreateVariable`.
  """
  kwargs.setdefault('default_seed', self.params.random_seed)
  if self.params.device_mesh is not None:
    # Warn when a tensor with more than one non-trivial dim has no split
    # annotation — such variables likely should be sharded over the mesh.
    if (len([dim for dim in var_params.shape if dim > 1]) > 1 and
        var_params.tensor_split_dims_mapping is None):
      tf.logging.warning(
          'tensor_split_dims_mapping missing for %s.%s: shape=%s',
          self.path, name, var_params.shape)
  self._CheckName(name)
  # Propagate the layer-level skip-LP-regularization flag into the
  # variable's collections if not already present.
  if (self.params.skip_lp_regularization and
      py_utils.SKIP_LP_REGULARIZATION not in var_params.collections):
    var_params = py_utils.WeightParams(
        shape=var_params.shape,
        dtype=var_params.dtype,
        init=var_params.init,
        collections=(var_params.collections +
                     [py_utils.SKIP_LP_REGULARIZATION]))
  self._var_symbolic_shape_map[name] = var_params.shape
  var = py_utils.CreateVariable(name, var_params, **kwargs)
  self._private_vars[name] = var
  if py_utils.IsEagerMode():
    # With eager trainer, always use the variable directly.
    value = var
  else:
    if self.cluster.params.worker.gpus_per_replica > 0:
      # On GPU (which always trains a single step per session.run()),
      # reference a tensor in FProp to cache it on device and avoid extraneous
      # sends from reading variables from ps multiple times.
      with tf.device(var.device):
        value = tf.identity(var, name=name)
    else:
      value = var
  self._private_theta[name] = value
def __init__(self, params):
  """Initializes this Model.

  Args:
    params: Layer params; `params.cls` must be a subclass of `BaseModel`.
  """
  assert issubclass(params.cls, BaseModel)
  # The global step variable must exist before super().__init__ so child
  # layers constructed there can reference it.
  self._global_step_var = py_utils.GetOrCreateGlobalStepVar()
  self._global_step = tf.identity(
      self._global_step_var, name='global_step_tensor')
  # Python 3 zero-argument super() — same MRO behavior as the legacy
  # super(BaseModel, self) form.
  super().__init__(params)
  self._ema = None
  tp = self.params.train
  tf.logging.info('Training parameters for %s: %s', params.cls, tp)
  if tp.ema_decay > 0:
    assert tp.ema_decay < 1.0
    self._ema = tf.train.ExponentialMovingAverage(
        decay=tp.ema_decay, num_updates=self.global_step)
def BeamSearchDecodeOutputToDecoderTopK(decoder_outs,
                                        *,
                                        ids_to_strings_fn,
                                        tag=''):
  """Converts BeamSearchDecodeOutput to DecoderTopK.

  As a side-effect, also creates TF nodes used by eval pipelines
  ("top_k_decoded" and "top_k_scores").

  Args:
    decoder_outs: a beam_search_helper.BeamSearchDecodeOutput instance.
    ids_to_strings_fn: a function of (ids, lens) -> strings, where ids has
      shape [batch, length], lens has shape [batch], and strings has shape
      [batch].
    tag: optional tag for tf.identity() names.

  Returns:
    A DecoderTopK instance.
  """
  hyps = decoder_outs.topk_hyps
  ids = decoder_outs.topk_ids
  lens = tf.identity(decoder_outs.topk_lens, name='TopKLabelLengths' + tag)
  scores = decoder_outs.topk_scores
  decoded = decoder_outs.topk_decoded
  if decoder_outs.topk_ids is not None:
    ids = tf.identity(ids, name='TopKLabelIds' + tag)
    # With the assumption that ids[-1] is always EOS token.
    # TODO(b/195027707): remove EOS token in better way.
    decoded = ids_to_strings_fn(ids, lens - 1)
    decoded = tf.identity(decoded, name='top_k_decoded%s' % tag)
    # NOTE(review): decoded is reshaped to scores' shape — presumably
    # [batch, num_hyps]; confirm against the decoder's output contract.
    decoded = tf.reshape(decoded, tf.shape(scores))
  if scores is not None and hyps is not None:
    # Flattened to lens' shape only so the named identity node has a stable
    # shape for eval pipelines; then restored to hyps' shape.
    scores = tf.identity(
        tf.reshape(scores, tf.shape(lens)), name='top_k_scores%s' % tag)
    scores = tf.reshape(scores, tf.shape(hyps))
  return DecoderTopK(hyps, ids, lens, scores, decoded)
def Inference(self):
  """Builds feed/fetch NestedMaps for the 'default' and 'unused' subgraphs.

  Raises:
    NotImplementedError: when running on TPU, which is unsupported here.
  """
  if py_utils.use_tpu():
    raise NotImplementedError('TPU is not supported.')
  with tf.name_scope('inference'):
    feed1 = tf.placeholder(name='feed1_node', dtype=tf.float32, shape=[1])
    fetch1 = tf.identity(feed1, name='fetch1_node')
    default_fetches = py_utils.NestedMap({
        'fetch1': fetch1,
        'fetch_op': fetch1.op,  # Tests that ops are supported.
    })
    default_feeds = py_utils.NestedMap({'feed1': feed1})
    empty_subgraph = (py_utils.NestedMap({}), py_utils.NestedMap({}))
    return {
        'default': (default_fetches, default_feeds),
        'unused': empty_subgraph,
    }
def _CreateVariableInternal(self, name, meta):
  """Immediately creates the variable described by `meta`.

  DO NOT OVERRIDE. For internal use only. Subclasses of BaseLayer should use
  self.CreateVariable() to create variables.

  Args:
    name: The variable name.
    meta: A CreateVariableMeta describing the variable to be created.
  """
  meta.kwargs.setdefault('default_seed', self.params.random_seed)
  created = py_utils.CreateVariable(name, meta.var_params, **meta.kwargs)
  self._private_vars[name] = created
  if not resource_variable_ops.is_resource_variable(created):
    # Non-resource variables get a device-pinned identity read.
    with tf.device(created.device):
      theta_value = tf.identity(created)
  else:
    theta_value = created
  if meta.theta_fn is not None:
    theta_value = meta.theta_fn(theta_value)
  self._private_theta[name] = theta_value
def _CreateVariableInternal(self, name: str,
                            meta: CreateVariableMeta) -> None:
  """Immediately creates the variable described by `meta`.

  DO NOT OVERRIDE. For internal use only. Subclasses of BaseLayer should use
  self.CreateVariable() to create variables.

  Args:
    name: The variable name.
    meta: A CreateVariableMeta describing the variable to be created.
  """
  meta.kwargs.setdefault('default_seed', self.params.random_seed)
  var = py_utils.CreateVariable(name, meta.var_params, **meta.kwargs)
  self._private_vars[name] = var
  if self.cluster.params.worker.gpus_per_replica > 0:
    # On GPU (which always trains a single step per session.run()), reference
    # a tensor in FProp to cache it on device and avoid extraneous sends from
    # reading variables from ps multiple times.
    with tf.device(var.device):
      value = tf.identity(var)
  else:
    # Pass the resource variable directly into the training loop.
    value = var
  # Due to b/174956514, we have to annotate the use of the variable once,
  # otherwise, the sharding annotation on the var will be ignored.
  # TODO(yonghui): Get rid of this once b/174956514 is fixed.
  if (meta.var_params.device_mesh is not None and
      var.shape.rank == len(meta.var_params.tensor_split_dims_mapping)):
    value = gshard_utils.MeshSplit(
        value,
        meta.var_params.device_mesh,
        meta.var_params.tensor_split_dims_mapping,
        use_sharding_op=True)
  if meta.theta_fn is not None:
    self._private_theta_fn[name] = meta.theta_fn
  self._private_theta[name] = value
def _WrapNonLingvoVars(dest_layer: base_layer.BaseLayer,
                       variables: Collection[tf.Variable],
                       trainable_variables: Collection[tf.Variable] = ()):
  """Adds variables to the given lingvo layer and appropriate graph collections.

  This function helps wrap variables created outside of lingvo so they are
  correctly handled by lingvo's trainer and checkpointer. It does the
  following:

    - makes all `variables` trackable through `dest_layer.vars`;
    - ensures `variables` are in the `tf.global_variables()` graph collection
      so the trainer can initialize them;
    - adds the `trainable_variables` subset to the `tf.trainable_variables()`
      graph collection, so they are visible to the learner (i.e. can be
      trained).

  Args:
    dest_layer: Lingvo layer to add the `variables` to.
    variables: The non-lingvo variables to wrap.
    trainable_variables: The subset of `variables` to ensure are trainable.
  """
  global_vars = set(tf.global_variables())
  # pylint: disable=protected-access
  for wrapped in variables:
    assert wrapped in global_vars
    short_name = wrapped.name.split(':')[0]
    dest_layer._private_vars[short_name] = wrapped
    with tf.device(wrapped.device):
      dest_layer._private_theta[short_name] = tf.identity(wrapped)
  # pylint: enable=protected-access

  trainable_collection = set(tf.trainable_variables())
  for candidate in trainable_variables:
    if candidate in trainable_collection:
      continue
    tf.logging.warning(
        'Wrapped var %s not in trainable collection; adding it.',
        candidate.name)
    tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES,
                                   candidate)
def _CreateLayerVariables(self):
  """Creates per-host embedding table shards plus slot/load/retrieve ops."""
  p = self.params
  weight_params = py_utils.WeightParams(
      shape=[self._ids_per_shard, p.embedding_dim],
      init=p.params_init,
      dtype=p.dtype,
      collections=[self.__class__.__name__ + '_vars'])

  # One embedding shard per TPU host, placed on that host's device.
  shard_vars = []
  for host_id in range(p.num_tpu_hosts):
    host_device = self.GetDeviceName(host_id)
    with tf.device(host_device), py_utils.outside_all_rewrites():
      shard_name = self.GetVariableName(host_id)
      self.CreateVariable(shard_name, weight_params)
      shard_vars.append(self.vars[shard_name])
      # Remove from _private_vars / _private_thetas to be added later as wm.
      del self._private_vars[shard_name]
      del self._private_theta[shard_name]
  self._tpu_embedding_collection.AddTableVariables(self.table_name,
                                                   shard_vars)

  if not py_utils.use_tpu():
    # We don't want to add this for TrainerTpu, otherwise the identity
    # reference leads to copying the embedding to the TPU for no reason.
    # However, this is needed for CPU (eval/decode/controller).
    self._private_vars['wm'] = shard_vars
    self._private_theta['wm'] = [tf.identity(v) for v in shard_vars]

  # Only trainer and controller need slot variables and load/retrieve ops.
  if not self.do_eval:
    self._load_op_list, self._retrieve_op_list = (
        self.optimizer.CreateSlotVariablesAndOps(shard_vars, self))
def Export(cls,
           model_cfg,
           model_task_name=None,
           device_options=InferenceDeviceOptions(
               device='',
               retain_device_placement=False,
               var_options=None,
               gen_init_op=True,
               dtype_override=None),
           freeze_checkpoint=None,
           freeze_defaults=False,
           export_path=None,
           subgraph_filter=None,
           random_seed=None,
           disable_packed_input=True):
  """Exports a InferenceGraph proto with piecewise subgraphs.

  Sets FLAGS.enable_asserts to False unless user explicitly sets it to True.

  Args:
    model_cfg: a Params instance as returned by
      model_registry.GetParams(modelname, 'Test') or model_params.Model().
    model_task_name: The task to generate an inference graph for. Should be
      None for single-task models.
    device_options: Device options for the accelerator used for serving.
    freeze_checkpoint: The checkpoint to load. Loads and freezes the model if
      given.
    freeze_defaults: Default initializes the graph and freeze. Useful for
      early testing of downstream tools without having a checkpoint.
    export_path: If not None, write the inference graph in ASCII to this path.
    subgraph_filter: A list of subgraph names. If not None or empty, export
      only this list of inference subgraphs.
    random_seed: Fixes the random seed in the exported inference graph.
    disable_packed_input: Disable packed input for inference writing purposes.

  Returns:
    InferenceGraph proto.

  Raises:
    ValueError: if the model does not support the listed subgraphs.
  """
  assert issubclass(model_cfg.cls, base_model.BaseModel)

  # Disable assertions unless user explicitly enables it.
  if FLAGS['enable_asserts'].using_default_value:
    FLAGS.enable_asserts = False

  # TODO(laurenzo): Work out how much we need to specify here in terms of
  # cluster configuration.
  cls._SetClusterParams(model_cfg.cluster, device_options)

  # Configure the model.
  model_cfg.random_seed = random_seed
  model_cfg.is_inference = True

  if disable_packed_input:

    def _DisablePackedInput(task):
      # Best-effort: only flip the flag where the param actually exists.
      if (_ParamExists(task, 'encoder') and
          _ParamExists(task.encoder, 'packed_input')):
        task.encoder.packed_input = False
      if (_ParamExists(task, 'decoder') and
          _ParamExists(task.decoder, 'packed_input')):
        task.decoder.packed_input = False

    if issubclass(model_cfg.cls, base_model.MultiTaskModel):
      for _, task_param in model_cfg.task_params.IterParams():
        _DisablePackedInput(task_param)
    else:
      _DisablePackedInput(model_cfg.task)

  tf.logging.info('Model %s. Params: %s', model_cfg.name, model_cfg.ToText())

  # Instantiate the graph.
  graph = tf.Graph()
  with graph.as_default():
    tf.random.set_seed(random_seed)
    cluster = model_cfg.cluster.Instantiate()
    device = cluster.GetPlacer()
    tpu_const_scope = _DummyScope()
    if (IsTpu(device_options) and
        device_options.var_options == 'AS_CONSTANTS'):
      # Do not specify devices for variables if we are marking them as
      # constants.
      device = ''
      tpu_const_scope = ConstGuaranteeScope()

    with cluster, tf.device(device), tpu_const_scope:

      bfloat16_override = ShouldForceBfloat16ForWeightsAndActivations(
          device_options)

      if bfloat16_override:
        py_utils.UpdateDtype(model_cfg, tf.bfloat16)
        py_utils.UpdateFpropDtype(model_cfg, tf.bfloat16)

      # Hard-code TPU-related flags prior to instantiating model.
      old_enable_asserts = FLAGS.enable_asserts
      old_xla_device = FLAGS.xla_device
      if IsTpu(device_options):
        FLAGS.enable_asserts = False
        FLAGS.xla_device = 'tpu'

      # Ensure the global_step variable is created.
      global_step_var = py_utils.GetOrCreateGlobalStepVar()
      global_step = tf.identity(global_step_var, name='global_step_tensor')
      with py_utils.GlobalStepContext(global_step):
        try:
          mdl = model_cfg.Instantiate()
          variables_to_restore = (
              _MakeVariableDictionary(tf.global_variables())
              if not mdl.ema else mdl.ema.variables_to_restore(
                  mdl.variables_for_ema))

          if bfloat16_override:
            saver_var_spec = (
                bfloat16_variables
                .get_saver_spec_for_variables_with_bf16_overrides(
                    variables_to_restore))
          else:
            saver_var_spec = variables_to_restore

          saver = tf.train.Saver(saver_var_spec)
          tf.variables_initializer(
              tf.global_variables(), name='init_all_variables')
          if IsTpu(device_options) and device_options.gen_init_op:
            tf.group(tf.tpu.initialize_system(), name='tpu_init_op')

          model_task = mdl.GetTask(model_task_name)

          inference_graph_proto = inference_graph_pb2.InferenceGraph()
          subgraphs_proto = model_task.Inference()
          if isinstance(subgraphs_proto, dict):
            subgraphs_proto = ConvertSubgraphDictToProto(subgraphs_proto)
          for name, subgraph in subgraphs_proto.subgraphs.items():
            if not subgraph_filter or name in subgraph_filter:
              inference_graph_proto.subgraphs[name].CopyFrom(subgraph)

          # Add a table init op and global variable init op to the graph.
          # Tables can be declared anywhere in the graph, so this op has to be
          # added last.
          tf.tables_initializer(name='init_all_tables')
        finally:
          # Reset TPU-related flags after model instantiation.
          FLAGS.enable_asserts = old_enable_asserts
          FLAGS.xla_device = old_xla_device

  tf.logging.info('Graph contains ops: %r',
                  [op.name for op in graph.get_operations()])

  inference_graph_proto.saver_def.CopyFrom(saver.as_saver_def())

  # Freezing.
  if freeze_defaults or freeze_checkpoint:
    output_op_names = GetOutputOpNames(
        graph, inference_graph_proto, preserve_colocation_nodes=False)
    if cls._DeviceSupportsFreezing(device_options):
      raise ValueError('freeze_checkpoint cannot be used with device ' +
                       device_options.device)
    if freeze_checkpoint:
      tf.logging.info('Freezing graph from checkpoint: %s', freeze_checkpoint)
      graph_def = _FreezeGraphFromCheckpoint(graph, saver, freeze_checkpoint,
                                             output_op_names)
    elif freeze_defaults:
      tf.logging.info('Default initializing graph and freezing.')
      graph_def = _FreezeDefaults(graph, output_op_names)
  else:
    output_op_names = GetOutputOpNames(graph, inference_graph_proto)

    # Prune the graph to just the parts we need.
    # To support restoring, we have to not prune out the restore node.
    output_op_names.append('init_all_tables')
    output_op_names.append('init_all_variables')
    output_op_names.append('save/control_dependency')
    output_op_names.append('save/restore_all')
    if IsTpu(device_options) and device_options.gen_init_op:
      output_op_names.append('tpu_init_op')
    graph_def = graph.as_graph_def()
    tf.logging.info('Pruning graph to output ops: %r', output_op_names)
    graph_def = tf.graph_util.extract_sub_graph(graph_def, output_op_names)

  if not device_options.retain_device_placement:
    # Clear the device so that the runtime can choose.
    tf.logging.info('Clearing device placement for: %s',
                    device_options.device)
    for node in graph_def.node:
      node.ClearField('device')
    for function in graph_def.library.function:
      for node_def in function.node_def:
        node_def.ClearField('device')

  inference_graph_proto.graph_def.CopyFrom(graph_def)

  if export_path:
    with tf.io.gfile.GFile(export_path, 'w') as f:
      f.write(text_format.MessageToString(inference_graph_proto))

  return inference_graph_proto
def FProp(self, theta, input_tensor):
  """Wraps `input_tensor` with a GradDrop custom-gradient identity.

  Raises:
    ValueError: if called more than once on this layer instance.
  """
  p = self.params
  if self._output_tensor is not None:
    raise ValueError('FProp was already called.')

  def _Gradient(inputs, _, original_grad):

    # Compute the gradients for each loss w.r.t. the inputs.
    # TODO(jngiam): Look into whether TF dedups this computation.
    per_loss_grads = []
    for loss, _ in self._losses:
      per_loss_grad = tf.gradients(loss, self._output_tensor)[0]
      if per_loss_grad is None:
        tf.logging.warning(
            'Loss %s did not result in a gradient during '
            'GradDrop computation.', loss)
      else:
        per_loss_grads.append(per_loss_grad)

    if not per_loss_grads:
      raise ValueError('No valid gradients for GradDrop.')

    # Multiply the gradients with the inputs.
    grads = per_loss_grads
    if p.use_input_sign_only:
      input_abs = tf.abs(
          tf.cast(tf.abs(inputs) <= p.epsilon, tf.float32) + inputs)
      grads = [grad * ((inputs) / (input_abs)) for grad in grads]
    else:
      grads = [grad * inputs for grad in grads]

    # Sum gradient over batch, assuming that batch is always on dim 0.
    if p.marginalize_batch_dim:
      grads = [tf.reduce_sum(grad, axis=0, keepdims=True) for grad in grads]

    # First discretize all gradients into their sign values.
    grad_sign_positive = [tf.cast(grad > 0.0, tf.float32) for grad in grads]
    grad_sign_negative = [tf.cast(grad < 0.0, tf.float32) for grad in grads]

    # Calculate the probability of positive gradients based on equation (1)
    # in the GradDrop paper.
    grad_abs_sum = tf.add_n([tf.abs(grad) for grad in grads])
    prob_pos = (tf.add_n(grads) / (2. * grad_abs_sum + p.epsilon))
    # Implementation of different scales for the keep function. Larger
    # scales result in steeper keep functions.
    prob_pos *= p.keep_prob_function_scale
    if p.keep_prob_function == 'sigmoid':
      # Standard sigmoid has derivative of 0.25 at 0 so the factor of 4.0
      # allows the function scale in sigmoid to be compatible with the
      # function scale in the linear case.
      prob_pos = tf.sigmoid(4.0 * prob_pos)
    elif p.keep_prob_function == 'linear':
      prob_pos += 0.5

    # The main, default mode of GradDrop. Only gradients of one sign are
    # kept, and which sign is calculated via equation (1) of the main paper.
    prob_pos = tf.cast(prob_pos >= tf.random.uniform(prob_pos.shape),
                       tf.float32) - 0.5
    grad_masks = [
        (gsp - gsn) * prob_pos >= 0
        for (gsn, gsp) in zip(grad_sign_negative, grad_sign_positive)
    ]

    # This diag value gives us the percentage of grads which are kept.
    gradmask_diag = [tf.cast(gm, tf.float32) for gm in grad_masks]
    diag = tf.reduce_mean(tf.add_n(gradmask_diag) / len(grad_masks))
    summary_utils.scalar('average_grad_mask', diag)
    leak_ratios = [leak_ratio for _, leak_ratio in self._losses]
    transformed_per_loss_grads = [
        grad * (leak + (1.0 - leak) * tf.cast(grad_mask, tf.float32))
        for (leak, grad, grad_mask) in zip(leak_ratios, per_loss_grads,
                                           grad_masks)
    ]

    transformed_grad = tf.cast(
        tf.add_n(transformed_per_loss_grads), original_grad.dtype)

    if not p.keep_gradnorm_constant:
      return transformed_grad

    # Rescale so the transformed gradient keeps the original gradient's norm.
    transformed_grad_norm = tf.sqrt(tf.reduce_sum(transformed_grad**2))
    original_grad_norm = tf.sqrt(tf.reduce_sum(original_grad**2))
    return transformed_grad * original_grad_norm / (
        transformed_grad_norm + p.epsilon)

  output_tensor = py_utils.CallDefun(tf.identity, input_tensor, _Gradient)
  self._output_tensor = tf.identity(output_tensor)
  return self._output_tensor
def try_apply_dense(self, grad, var):
  """Dry-runs the Adafactor dense update, collecting finiteness checks.

  Unlike the real apply op, no variables are assigned (the assignments are
  intentionally commented out); intermediate tensors are recorded instead.

  Returns:
    A tuple (is_finite_checks, stats) where is_finite_checks is a list of
    boolean tensors and stats maps stage names to intermediate tensors.
  """
  assert grad is not None

  cond = tf.constant(True)
  is_finite_checks = []
  stats = {}

  grad_dtype = var.dtype  # TODO(lepikhin): add to params
  grad = tf.cast(grad, grad_dtype)
  factored_dims = self._factored_dims(var.shape.as_list())
  if factored_dims:
    vr = self.get_slot(var, 'vr')
    vc = self.get_slot(var, 'vc')
  else:
    v = self.get_slot(var, 'v')
  if self._beta1:
    m = self.get_slot(var, 'm')

  def _Upd(c, k, x):
    # Records tensor `x` under key `k` and queues an is-finite check.
    stats[k] = x
    is_finite_checks.append(tf.reduce_all(tf.math.is_finite(x)))
    return c

  with tf.variable_scope(var.name[:-2] + '/Adafactor'):
    grad_squared = tf.math.square(grad) + tf.cast(self._epsilon1, grad_dtype)
    cond = _Upd(cond, 'grad_squared', grad_squared)  # 0 (factored)
    decay_rate = tf.cast(self._decay_rate, var.dtype)
    old_val = tf.identity(var)  # TODO(lepikhin): introduce gradient dtype
    assert self._multiply_by_parameter_scale
    lr = GetLrValue(self._learning_rate)
    if self._multiply_by_parameter_scale:
      parameter_scale = self._parameter_scale(old_val)
      cond = _Upd(cond, 'parameter_scale', parameter_scale)  # 1 (factored)
      update_scale = self._parameter_scale(old_val) * tf.cast(lr, grad_dtype)
    else:
      update_scale = lr
    mixing_rate = tf.cast(1.0 - decay_rate, grad_dtype)
    update_scale = tf.cast(update_scale, grad_dtype)
    if factored_dims:
      d0, d1 = factored_dims
      vr_axis, vc_axis = d0, d1
      grad_squared_row_mean = tf.reduce_mean(grad_squared, axis=vr_axis)
      grad_squared_col_mean = tf.reduce_mean(grad_squared, axis=vc_axis)
      # new_vr = (decay_rate * vr + mixing_rate * grad_squared_row_mean)
      new_vr = vr * decay_rate + grad_squared_row_mean * mixing_rate
      # new_vc = (decay_rate * vc + mixing_rate * grad_squared_col_mean)
      new_vc = vc * decay_rate + grad_squared_col_mean * mixing_rate
      cond = _Upd(cond, 'new_vr', new_vr)  # 2 (factored)
      cond = _Upd(cond, 'new_vc', new_vc)  # 3 (factored)
      # vr_update = _Wrap(tf.assign, vr, new_vr)
      # vc_update = _Wrap(tf.assign, vc, new_vc)
      # updates.extend([vr_update, vc_update])
      long_term_mean = tf.reduce_mean(new_vr, -1, keepdims=True)
      r_factor = tf.math.rsqrt(new_vr / long_term_mean)
      c_factor = tf.math.rsqrt(new_vc)
      mult = tf.expand_dims(r_factor, vr_axis) * tf.expand_dims(
          c_factor, vc_axis)
      cond = _Upd(cond, 'mult', mult)  # 4 (factored)
      x = grad * mult
    else:
      new_v = v * decay_rate + grad_squared * mixing_rate
      cond = _Upd(cond, 'new_v', new_v)
      # v_update = _Wrap(tf.assign, v, new_v)
      # updates.append(v_update)
      x = grad * tf.math.rsqrt(new_v)
    assert self._clipping_threshold is not None
    if self._clipping_threshold is not None:
      clipping_denom = tf.maximum(
          tf.constant(1.0, grad_dtype),
          py_utils.ReduceRms(x) /
          tf.constant(self._clipping_threshold, grad_dtype))
      x /= clipping_denom
    cond = _Upd(cond, 'x', x)
    subtrahend = x * update_scale
    if self._beta1:
      new_m = (m * tf.constant(self._beta1, dtype=grad_dtype) +
               subtrahend * tf.constant(1.0 - self._beta1, dtype=grad_dtype))
      subtrahend = new_m
      cond = _Upd(cond, 'new_m', new_m)
      # updates.append(_Wrap(tf.assign, m, new_m))
    # It is critical to use assign_sub instead of tf.assign(var - subtrahend)
    # for the case of bfloat16 activations, so as to avoid repeatedly
    # rounding the slice value, which results in poor quality.
    cond = _Upd(cond, 'subtrahend', subtrahend)  # 5 (factored)
    # var_update = _Wrap(tf.assign_sub, var, subtrahend)
    # updates.append(var_update)
    return is_finite_checks, stats
def _resource_apply_dense(self, grad, var):
  """Applies the Adafactor dense update to `var`.

  Returns:
    A grouped op of all slot and variable assignments, or [] when `grad`
    is None.
  """
  if grad is None:
    tf.logging.warning('Gradient is None for variable %s' % var.name)
    return []

  grad_dtype = var.dtype  # TODO(lepikhin): add to params
  grad = tf.cast(grad, grad_dtype)
  factored_dims = self._factored_dims(var.shape.as_list())
  if factored_dims:
    vr = self.get_slot(var, 'vr')
    vc = self.get_slot(var, 'vc')
  else:
    v = self.get_slot(var, 'v')
  if self._beta1:
    m = self.get_slot(var, 'm')

  cond = tf.constant(True)

  def _Upd(c, x):
    # Accumulates "all values of x are finite" into condition `c`.
    if not self._cond_is_finite:
      return c
    c = tf.math.logical_and(c, tf.reduce_all(tf.math.is_finite(x)))
    c = tf.math.logical_and(
        c, tf.reduce_all(tf.math.logical_not(tf.math.is_inf(x))))
    return c

  def _Wrap(fn, x, y):
    # Applies fn(x, y) only when `cond` holds (i.e. all stats were finite).
    if not self._cond_is_finite:
      return fn(x, y)
    return tf.cond(cond, lambda: fn(x, y), lambda: x)

  with tf.variable_scope(var.name[:-2] + '/Adafactor'):
    grad_squared = tf.math.square(grad) + tf.cast(self._epsilon1, grad_dtype)
    cond = _Upd(cond, grad_squared)
    decay_rate = tf.cast(self._decay_rate, var.dtype)
    old_val = tf.identity(var)  # TODO(lepikhin): introduce gradient dtype
    lr = GetLrValue(self._learning_rate)
    if self._multiply_by_parameter_scale:
      update_scale = self._parameter_scale(old_val) * tf.cast(lr, grad_dtype)
    else:
      update_scale = lr
    mixing_rate = tf.cast(1.0 - decay_rate, grad_dtype)
    update_scale = tf.cast(update_scale, grad_dtype)
    updates = []
    if factored_dims:
      d0, d1 = factored_dims
      vr_axis, vc_axis = d0, d1
      grad_squared_row_mean = tf.reduce_mean(grad_squared, axis=vr_axis)
      grad_squared_col_mean = tf.reduce_mean(grad_squared, axis=vc_axis)
      # new_vr = (decay_rate * vr + mixing_rate * grad_squared_row_mean)
      new_vr = vr * decay_rate + grad_squared_row_mean * mixing_rate
      # new_vc = (decay_rate * vc + mixing_rate * grad_squared_col_mean)
      new_vc = vc * decay_rate + grad_squared_col_mean * mixing_rate
      cond = _Upd(cond, new_vr)
      cond = _Upd(cond, new_vc)
      vr_update = _Wrap(tf.assign, vr, new_vr)
      vc_update = _Wrap(tf.assign, vc, new_vc)
      updates.extend([vr_update, vc_update])
      long_term_mean = tf.reduce_mean(new_vr, -1, keepdims=True)
      r_factor = tf.math.rsqrt(new_vr / long_term_mean)
      c_factor = tf.math.rsqrt(new_vc)
      x = grad * tf.expand_dims(r_factor, vr_axis) * tf.expand_dims(
          c_factor, vc_axis)
    else:
      new_v = v * decay_rate + grad_squared * mixing_rate
      cond = _Upd(cond, new_v)
      v_update = _Wrap(tf.assign, v, new_v)
      updates.append(v_update)
      x = grad * tf.math.rsqrt(new_v)
    if self._clipping_threshold is not None:
      clipping_denom = tf.maximum(
          tf.constant(1.0, grad_dtype),
          py_utils.ReduceRms(x) /
          tf.constant(self._clipping_threshold, grad_dtype))
      x /= clipping_denom
    subtrahend = x * update_scale
    if self._beta1:
      new_m = (m * tf.constant(self._beta1, dtype=grad_dtype) +
               subtrahend * tf.constant(1.0 - self._beta1, dtype=grad_dtype))
      subtrahend = new_m
      cond = _Upd(cond, new_m)
      updates.append(_Wrap(tf.assign, m, new_m))
    # It is critical to use assign_sub instead of tf.assign(var - subtrahend)
    # for the case of bfloat16 activations, so as to avoid repeatedly
    # rounding the slice value, which results in poor quality.
    cond = _Upd(cond, subtrahend)
    var_update = _Wrap(tf.assign_sub, var, subtrahend)
    updates.append(var_update)
  return tf.group(*updates)
def InstantiateVariables(self):
  """Creates layer variables with the global step tensor in scope."""
  step_tensor = tf.identity(self._global_step_var, name='global_step_tensor')
  with py_utils.GlobalStepContext(step_tensor):
    super().InstantiateVariables()
def __init__(self,
             learning_rate,
             momentum=0.0,
             initial_accumulator_value=0.0,
             start_preconditioning_steps=1000,
             statistics_computation_frequency=1,
             matrix_epsilon=1e-6,
             synchronous_preconditioning=False,
             second_moment_averaging=1.0,
             fallback_to_diagonal_dim=4096,
             max_any_dim=6656,
             block_size=4096,
             block_partition_threshold_size=1000000,
             global_step=None,
             exponent_multiplier=1.0,
             name="DistributedShampoo"):
  """Construct a DistributedShampoo optimizer.

  Args:
    learning_rate: A `Tensor` or a floating point value. The learning rate.
    momentum: A `Tensor` or a floating point value. Momentum is not applied
      to sparse updates.
    initial_accumulator_value: A floating point value.
    start_preconditioning_steps: A int32 value which indicates when to start
      preconditioning.
    statistics_computation_frequency: A int32 step value which indicates how
      often to compute statistics for preconditioning.
    matrix_epsilon: An epsilon regularizer to make the matrices positive
      definite.
    synchronous_preconditioning: Whether to run preconditioning
      synchronously.
    second_moment_averaging: 1.0 means sum of gradients squares, while less
      than 1.0 switches to RMSProp style exponential moving averages of the
      second moments.
    fallback_to_diagonal_dim: Fallback to diagonal version of AFMA if any of
      the dimensions is larger than fallback_to_diagonal_dim.
    max_any_dim: If maximum value for any dimension is greater than this
      value we skip preconditioning and fall back to the diagonal.
    block_size: Dimension of the partitioned tensors.
    block_partition_threshold_size: Partitions dimensions beyond this size.
    global_step: Global step for training.
    exponent_multiplier: A multiplier `e` for the exponent for the inverse
      calculation. e * -1/(2*rank). Only applies when calculating inverses
      through svd.
    name: Optional name prefix for the operations created when applying
      gradients.
  """
  super().__init__(False, name)
  self._learning_rate = learning_rate
  self._momentum = momentum
  self._initial_accumulator_value = initial_accumulator_value
  self._start_preconditioning_steps = start_preconditioning_steps
  self._matrix_epsilon = matrix_epsilon
  self._synchronous_preconditioning = synchronous_preconditioning
  self._second_moment_averaging = second_moment_averaging
  self._fallback_to_diagonal_dim = fallback_to_diagonal_dim
  self._max_any_dim = max_any_dim
  self._block_size = block_size
  # NOTE: On XLA - int64 is not handled properly.
  if global_step is not None:
    self._global_step = tf.cast(tf.identity(global_step), tf.int32)
  else:
    self._global_step = tf.cast(
        tf.identity(tf.train.get_or_create_global_step()), tf.int32)
  # True once we have passed the preconditioning warm-up step count.
  self._run_nondiagonal_update = tf.greater_equal(
      self._global_step, self._start_preconditioning_steps)
  start_steps_f = tf.cast(self._start_preconditioning_steps, tf.float32)
  global_step_f = tf.cast(self._global_step, tf.float32)
  # Linear warmup ramp in [0, 1] after preconditioning starts.
  self._run_nondiagonal_update_warmup = tf.minimum(
      1.0, tf.maximum((global_step_f - start_steps_f) / start_steps_f, 0.0))
  # Computes statistics every K steps.
  self._statistics_computation_frequency = statistics_computation_frequency
  self._run_statistics_computation = tf.equal(
      tf.math.floormod(self._global_step,
                       self._statistics_computation_frequency), 0)
  # All vars that are preconditioned.
  self._all_vars_for_preconditioning = []
  self._exponent_multiplier = exponent_multiplier
  self._partition_info = PartitionConfig(block_partition_threshold_size,
                                         block_size)
  self._partitioner_metadata = {}
def __init__(self, params):
  """Constructs the task from `params`.

  Args:
    params: A configuration object; `params.cls` must be a subclass of
      `BaseTask` (enforced by the assert below).
  """
  assert issubclass(params.cls, BaseTask)
  # Ensure global_step exists before calling super.
  py_utils.GetOrCreateGlobalStepVar()
  super(BaseTask, self).__init__(params)

  p = self.params

  if p.input:
    # TODO(zhifengc): Consider a simpler way to ensure the input
    # generator stops after one epoch.
    if p.is_eval and p.eval:
      seq_inp = issubclass(p.input.cls,
                           base_input_generator.BaseInputGeneratorFromFiles)
      if p.input.num_samples == 0:
        # Dataset size is unknown. Computes eval summary based on num_samples.
        assert p.eval.samples_per_summary > 0
      elif (p.eval.samples_per_summary == 0) or (p.input.num_samples <
                                                 p.eval.samples_per_summary):
        # If we know the dataset size and we want to evaluate the full
        # set, we need to coordinate the input generator to flush out
        # all samples so the evaler and decoder compute metrics on the
        # whole set for each summary step.
        if seq_inp:
          p.input.flush_every_n = p.input.num_samples
        p.eval.samples_per_summary = p.input.num_samples
      if seq_inp and p.input.num_batcher_threads > 1:
        # Multiple batcher threads can interleave examples across epochs.
        tf.logging.warning('input.num_batcher_threads > 1 inside eval mode. '
                           'The input generator may not iterate over exactly '
                           'one epoch per run')
    tf.logging.info('input_params: %s', p.input)
    input_params = self.cluster.PlaceInput(p.input)
    # Input placement is kept outside graph rewrites (e.g. TPU rewrites).
    with py_utils.outside_all_rewrites():
      self.CreateChild('input', input_params)

  # Lazily-populated model components and training/eval bookkeeping.
  self._encoder = None
  self._online_encoder = None
  self._decoder = None
  self._loss = None
  self._num_predictions = None
  self._train_op = None
  self._eval_metrics = {}
  self._per_example = {}
  self._trainer_verbose_tensors = {}

  # Create the gradient mask.
  self._per_input_gradient_mask = None
  # A task may register its own global step in the TASK_GLOBAL_STEP
  # collection; at most one entry per task name is allowed.
  task_global_step_list = tf.get_collection('TASK_GLOBAL_STEP',
                                            '^%s_global_step' % p.name)
  if len(task_global_step_list) > 1:
    raise ValueError('Found multiple task_global_step for task %s' % p.name)
  # Fall back to the model-wide global step when no task-specific one exists.
  self._global_step_var = (
      task_global_step_list[0] if len(task_global_step_list) == 1 else
      py_utils.GetOrCreateGlobalStepVar())
  self._global_step = tf.identity(
      self._global_step_var, name='global_step_tensor')
  tp = p.train
  # p.train can be None if this task is the teacher/student task in a
  # DistillationTask.
  if tp and self.cluster.job in ('worker', 'trainer', 'trainer_client',
                                 'controller', 'executor_tpu'):
    self._SetLearnerFromLegacyParams(tp)
    if tp.learner is not None:
      # tp.learner may be a single learner or a list of them; children are
      # always created from a list.
      if isinstance(tp.learner, (list, tuple)):
        self.CreateChildren('learners', tp.learner)
      else:
        self.CreateChildren('learners', [tp.learner])
  self._UpdateVnConfig()
def AddAttentionSummaryBatchMajor(name,
                                  attention_tensors,
                                  src_paddings,
                                  tgt_paddings,
                                  transcripts=None,
                                  max_outputs=3):
  """Adds an image summary showing the attention probability matrix and state.

  As opposed to AddAttentionSummary() takes all tensors with batch dimension in
  axis 0.

  Args:
    name: Summary name.
    attention_tensors: A list of 3D tensors shaped [batch_size, target_len,
      source_len] where attention[b, i, j] is the probability for the i-th
      output attending to the j-th input for element b in the batch.
    src_paddings: A tensor of binary paddings shaped [batch, source_len] for
      the source sequence. Or a list of tensors of the same length as
      attention_tensors with a separate paddings for each entry in
      attention_tensors.
    tgt_paddings: A tensor of binary paddings shaped [batch, target_len] for
      the target sequence. Or a list of tensors of the same length as
      attention_tensors with a separate paddings for each entry in
      attention_tensors.
    transcripts: Optional, transcripts shaped [batch, source_len] for the
      source sequence.
    max_outputs: Integer maximum number of elements of the batch to plot.
  """

  def _CheckPaddingsLen(paddings):
    # A bare tensor counts as a single entry; a list must either hold one
    # shared paddings tensor or one per attention tensor.
    num = len(paddings) if isinstance(paddings, list) else 1
    if num not in (1, len(attention_tensors)):
      raise ValueError('Bad length of paddings list {}'.format(num))

  def _Select(paddings, idx):
    # Returns the paddings for attention tensor idx, sharing entry 0 when a
    # single paddings tensor serves all attention tensors.
    if isinstance(paddings, list):
      return paddings[0] if len(paddings) == 1 else paddings[idx]
    return paddings

  _CheckPaddingsLen(src_paddings)
  _CheckPaddingsLen(tgt_paddings)

  # Attach runtime shape assertions: each attention tensor must be
  # [tgt_batch, tgt_len, src_len] consistent with its paddings.
  for idx, atten_t in enumerate(attention_tensors):
    src_pad = _Select(src_paddings, idx)
    tgt_pad = _Select(tgt_paddings, idx)
    tgt_shape = py_utils.GetShape(tgt_pad)
    expected = tgt_shape[:2] + [py_utils.GetShape(src_pad)[1]] + tgt_shape[2:]
    checked = py_utils.with_dependencies(
        [py_utils.assert_equal(py_utils.GetShape(atten_t), expected)], atten_t)
    # Strip the ":<output-index>" suffix so the identity op gets a clean name.
    attention_tensors[idx] = tf.identity(
        checked, re.sub(':.*$', '', GetTensorName(atten_t, name, idx)))

  if not _ShouldAddSummary():
    return

  def _PaddingsToLengths(paddings):
    # Normalizes to a list and converts each paddings tensor to lengths.
    if not isinstance(paddings, list):
      paddings = [paddings]
    return [SequenceLength(p) for p in paddings]

  def _Pick(lens, idx):
    return lens[idx] if len(lens) > 1 else lens[0]

  src_lens = _PaddingsToLengths(src_paddings)
  tgt_lens = _PaddingsToLengths(tgt_paddings)

  with plot.MatplotlibFigureSummary(name + '/Attention',
                                    max_outputs=max_outputs,
                                    gridspec_kwargs={'hspace': 0.3}) as fig:
    for idx, atten in enumerate(attention_tensors):
      # Diagnostic metric that decreases as attention picks up.
      max_entropy = tf.math.log(tf.cast(_Pick(src_lens, idx), tf.float32))
      max_entropy = tf.expand_dims(tf.expand_dims(max_entropy, -1), -1)
      normalized_entropy = -atten * tf.math.log(atten + 1e-10) / max_entropy
      scalar(name + '/Attention/average_normalized_entropy/%d' % idx,
             tf.reduce_mean(normalized_entropy))
      subplot_args = [atten, _Pick(src_lens, idx), _Pick(tgt_lens, idx)]
      # Transcripts are only attached to the first subplot.
      if transcripts is not None and idx == 0:
        subplot_args.append(transcripts)
      fig.AddSubplot(subplot_args,
                     TrimPaddingAndPlotAttention,
                     title=GetTensorName(atten, name, idx),
                     xlabel='Input',
                     ylabel='Output')