def _CreateQStateVar(self, t_name, suffix, params):
  """Creates a non-trainable QState variable named '<t_name>_<suffix>'."""
  name = t_name + '_' + suffix
  assert name not in self._qvars, 'QState var already exists: %s' % name
  var_name = self._qvars_scope.name + '/' + name
  with tf.variable_scope(py_utils.GetGlobalVariableScope()):
    v = py_utils.CreateVariable(var_name, params, trainable=False)
  self._qvars[name] = v
  return v
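# Hedged usage sketch (hypothetical tensor name and params, not from this
# module): a quantization domain might track a scalar state for a tensor named
# 'ffn_out' under the suffix 'max' like this.
#
#   v = self._CreateQStateVar(
#       'ffn_out', 'max',
#       py_utils.WeightParams([], py_utils.WeightInit.Constant(0.0),
#                             tf.float32))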
def CreateTaskGlobalStep(task_name):
  """Create if needed and return the global_step."""
  with tf.name_scope(None), tf.variable_scope(
      py_utils.GetGlobalVariableScope()):
    graph_collections = [tf.GraphKeys.GLOBAL_VARIABLES, 'TASK_GLOBAL_STEP']
    _, v = py_utils.CreateVariable(
        name=task_name + '_global_step',
        params=py_utils.WeightParams([], py_utils.WeightInit.Constant(0),
                                     tf.int64),
        trainable=False,
        collections=graph_collections)
    summary_utils.scalar(v.name, v)
    return v
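# Hedged usage sketch (the task name below is hypothetical): CreateTaskGlobalStep
# is called with the task's name while the graph is being built.
#
#   mt_step = CreateTaskGlobalStep('mt')
#   # mt_step is an int64, non-trainable variable initialized to 0, named
#   # 'mt_global_step', and also tracked in the 'TASK_GLOBAL_STEP' collection.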
def __init__(self, params):
  assert issubclass(params.cls, BaseTask)
  # Ensure global_step exists before calling super.
  py_utils.GetOrCreateGlobalStepVar()
  super().__init__(params)

  p = self.params

  self._encoder = None
  self._online_encoder = None
  self._decoder = None

  self._loss = None
  self._num_predictions = None
  self._train_op = None
  self._post_train_ops = []
  self._eval_metrics = {}
  self._per_example = {}

  # Create the gradient mask.
  self._per_input_gradient_mask = None

  if p.task_global_step:
    with tf.name_scope(None), tf.variable_scope(
        py_utils.GetGlobalVariableScope()):
      var_name = p.name + '_global_step'
      # Create the variable immediately.
      self._CreateVariableInternal(
          var_name,
          base_layer.CreateVariableMeta(
              var_params=py_utils.WeightParams(
                  [], py_utils.WeightInit.Constant(0), tf.int64),
              theta_fn=None,
              kwargs=dict(
                  trainable=False,
                  collections=[tf.GraphKeys.GLOBAL_VARIABLES])))
      summary_utils.scalar(var_name, self._private_vars[var_name])
      self._global_step_var = self._private_vars[var_name]
  else:
    self._global_step_var = py_utils.GetOrCreateGlobalStepVar()

  if p.input:
    # TODO(zhifengc): Consider a simpler way to ensure the input
    # generator stops after one epoch.
    if self.do_eval and p.eval:
      seq_inp = issubclass(p.input.cls,
                           base_input_generator.BaseInputGeneratorFromFiles)
      if p.input.num_samples == 0:
        # Dataset size is unknown. Eval summaries are computed over
        # p.eval.samples_per_summary samples instead.
        assert p.eval.samples_per_summary > 0
      elif (p.eval.samples_per_summary == 0) or (
          p.input.num_samples < p.eval.samples_per_summary):
        # If we know the dataset size and we want to evaluate the full
        # set, we need to coordinate the input generator to flush out
        # all samples so the evaler and decoder compute metrics on the
        # whole set for each summary step.
        if seq_inp:
          p.input.flush_every_n = p.input.num_samples
        p.eval.samples_per_summary = p.input.num_samples
      if seq_inp and p.input.num_batcher_threads > 1:
        tf.logging.warning(
            'input.num_batcher_threads > 1 inside eval mode. '
            'The input generator may not iterate over exactly '
            'one epoch per run.')
    tf.logging.info('input_params: %s', p.input)
    input_params = self.cluster.PlaceInput(p.input)

    # For TPU training, we create the input generator in a
    # different scope and AddChild it in later.
    if 'skip_create_child' not in p.input:
      self.CreateChild('input', input_params)

  tp = p.train
  # p.train can be None if this task is the teacher/student task in a
  # DistillationTask.
  if tp:
    self._SetLearnerFromLegacyParams(tp)
    if tp.learner is not None:
      if isinstance(tp.learner, (list, tuple)):
        self.CreateChildren('learners', tp.learner)
      else:
        self.CreateChildren('learners', [tp.learner])
  self._UpdateVnConfig()
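# Hedged configuration sketch (assumes a hypothetical task subclass named
# MyTask): setting `task_global_step` makes __init__ above create a private
# '<name>_global_step' variable instead of reusing the shared global step.
#
#   p = MyTask.Params()
#   p.name = 'my_task'
#   p.task_global_step = True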