def _LoopEnqueue(self, op, session_override=None):
  """Runs the enqueue op in a loop."""
  p = self.params
  sess = session_override or self._GetSession()
  with tf.container(self._container_id), sess:
    if self._initialize_tables is not None:
      sess.run(self._initialize_tables)
    gsteps = py_utils.GetGlobalStep()
    local_enqueue_steps = 0

    # global_enqueue_steps measures how many global steps' worth of data have
    # already been enqueued. We use this to terminate; note that the enqueue
    # op may hang in session.run if we do not terminate with this check.
    global_enqueue_steps = None

    tf.logging.info('params.train.max_steps: %d, enqueue_max_steps: %d',
                    p.train.max_steps, p.train.enqueue_max_steps)
    while True:
      if self._dequeue_thread_complete:
        tf.logging.info('LoopEnqueue done since consuming thread is done.')
        return

      global_step = sess.run(gsteps)
      if global_enqueue_steps is None:
        global_enqueue_steps = global_step
      if local_enqueue_steps % 1000 == 0:
        tf.logging.info(
            'Current global_enqueue_steps: %d, '
            'local_enqueue_steps: %d, global_step: %d', global_enqueue_steps,
            local_enqueue_steps, global_step)

      if py_utils.use_tpu():
        global_steps_with_available_data = int(
            global_enqueue_steps // p.train.tpu_steps_per_loop *
            p.train.tpu_steps_per_loop)
      else:
        global_steps_with_available_data = global_enqueue_steps

      if (self._ShouldStop(sess, global_steps_with_available_data) or
          self._ShouldStop(sess, global_step)):
        tf.logging.info('Done. ShouldStop is True.')
        tf.logging.info('Enqueue loop sleeping')
        time.sleep(15)
        continue
      if (p.train.enqueue_max_steps > 0 and
          local_enqueue_steps >= p.train.enqueue_max_steps):
        tf.logging.info('Done. train.enqueue_max_steps reached.')
        tf.logging.info('Enqueue loop sleeping')
        time.sleep(15)
        continue
      local_enqueue_steps += 1

      # There are tpu_infeed_parallelism parallel threads enqueuing, so we
      # account for all of them when updating global_enqueue_steps.
      global_enqueue_steps += p.input.tpu_infeed_parallelism

      sess.run([op])
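# Illustrative sketch, not from the original file: an enqueue loop like
# _LoopEnqueue above is typically driven from one daemon thread per infeed
# op, so a blocked sess.run([op]) cannot prevent process shutdown. The
# `runner` object and its `_enqueue_ops` attribute are assumptions made for
# this example.
import threading


def StartEnqueueThreads(runner):
  """Starts one daemon thread per enqueue op; returns the threads."""
  threads = []
  for enqueue_op in runner._enqueue_ops:
    t = threading.Thread(target=runner._LoopEnqueue, args=(enqueue_op,))
    t.daemon = True  # Daemon threads do not block interpreter exit.
    t.start()
    threads.append(t)
  return threads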
def _Loop(self):
  with tf.container(self._container_id), self._GetSession(
      cluster_def=self._cluster_def,
      disable_meta_optimizer=FLAGS.disable_meta_optimizer_in_executor
  ) as sess:
    # Initialize the variables first, if needed.
    for program in self._programs:
      program.RestoreIfNeeded(sess)
      program.Compile(sess)
    sess.run(self._initialize_tables)
    sess.run(self._initialize_local_vars)

    while True:
      global_step = sess.run(py_utils.GetGlobalStep())
      if self._ShouldStop(sess, global_step):
        tf.logging.info('Training finished.')
        if not self._ml_perf_log:
          self.save_only_checkpointer.Save(sess, global_step)
        return

      # If a task is explicitly selected, only run the programs associated
      # with that task.
      if self._single_task_mode or self._model_task_name:
        tf.logging.info('Single task mode: %s', self._model_task_name)
        program_schedule = self._program_schedule_dict[self._model_task_name]
      else:
        # Otherwise, sample a task.
        model_task = self.task_scheduler.Sample(global_step)
        tf.logging.info('Sampled %s', model_task)
        program_schedule = self._program_schedule_dict[model_task]

      done = program_schedule.Run(sess)
      if done:
        tf.logging.info('Program schedule told us to stop.')
        return

      # TODO(blee): More complex saving rules. Currently, we assume we save
      # after every task's program schedule execution.
      #
      # Note: the `global_step` local variable above is the result of a
      # sess.run, not a tf.Variable, so by the time we call
      # save_only_checkpointer.Save(...) here, py_utils.GetGlobalStep() is
      # already ahead of it by
      # (train_executions_per_eval * train_steps_per_loop) steps, due to
      # program_schedule.Run(sess).
      if not self._ml_perf_log:
        self.save_only_checkpointer.Save(sess, py_utils.GetGlobalStep())
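# Hedged sketch of the contract _Loop relies on; this stub is an assumption
# made for illustration, not the real lingvo ProgramSchedule. Each schedule
# runs its programs once per outer-loop iteration and returns True to stop.
class ProgramScheduleStub:
  """Minimal stand-in showing the interface used by _Loop above."""

  def __init__(self, programs):
    self._programs = programs

  def Run(self, sess):
    # True iff any program asks the executor loop to terminate.
    return any(program.Run(sess) for program in self._programs)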
def __init__(self, params):
  """Layer constructor.

  Sub-classes of BaseLayer should decorate their __init__ with
  @base_layer.initializer.

  Args:
    params: A Params object used to construct this layer.
  """
  assert params.name, ('Layer params for %s must have a "name"' %
                       self.__class__.__name__)

  tf_module_name = params.name
  tf_module_name = re.sub('[^a-zA-Z0-9_]+', '_', tf_module_name)
  tf_module_name = 'bbf_' + self.__class__.__name__ + '_' + tf_module_name
  py_utils.NestedMap.CheckKey(tf_module_name)

  # Initialize the base class.
  super(BaseLayer, self).__init__(tf_module_name)

  # Note: AutoTracking doesn't work properly due to its inability to walk
  # through py_utils.NestedMap data structures, which are used widely
  # throughout the Lingvo codebase. There also seems to be a performance hit
  # from turning on auto-tracking when constructing graphs. For now, we
  # disable auto-tracking.
  # TODO(lingvo): Re-enable auto-tracking when fuller support is added for
  # the key data structures used in Lingvo and the performance issue is
  # debugged and understood better.
  self._setattr_tracking = False

  self._parent = (
      _LAYER_STACK.layer_stack[-2]
      if len(_LAYER_STACK.layer_stack) > 1 else None)
  assert self._parent is not self
  self._params = params.Copy()
  tf.logging.debug('Creating layer %s with params: \n %s \n',
                   self.__class__.__name__, str(params))
  # Vars created by this layer.
  self._private_vars = py_utils.NestedMap()
  # Theta derived from this layer's vars.
  self._private_theta = py_utils.NestedMap()
  # Child layers created by this layer through CreateChild/CreateChildren.
  self._private_children = py_utils.NestedMap()
  # Child layers created by this layer. A well-formed layer should have
  # self._private_children equal to self._children_list, i.e., all child
  # layers are created using CreateChild/CreateChildren.
  self._children_list = []
  # Extra theta's that do not directly correspond to any underlying vars,
  # e.g., the concatenated sharded variables.
  self._extra_theta = py_utils.NestedMap()
  # All registered accumulators.
  self._private_accumulators = py_utils.NestedMap()
  # Layer-private functions. Add with AddFunction.
  self._private_fns = dict()
  # Mapping from variable names to their symbolic shapes.
  # self._var_symbolic_shape_map['var_name'] will be a tuple of integers or
  # symbolic expressions, one for each dimension of the variable.
  self._var_symbolic_shape_map = dict()

  self.AddExtraTheta('global_step', py_utils.GetGlobalStep())
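# Worked example (hypothetical layer name) of the tf module name derivation
# in __init__ above: runs of non-alphanumeric characters collapse to '_',
# then the class name is prefixed.
import re


def DeriveTfModuleName(class_name, params_name):
  tf_module_name = re.sub('[^a-zA-Z0-9_]+', '_', params_name)
  return 'bbf_' + class_name + '_' + tf_module_name


assert DeriveTfModuleName('ConvLayer', 'conv/3x3') == 'bbf_ConvLayer_conv_3x3'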
def GetOverWriteGlobalStep(graph=None):
  """Returns the overwrite global-step tensor, if one has been registered.

  Falls back to the true global step when no overwrite tensor is present in
  the collection (see SetOverWriteGlobalStep, used by the pipelined FProp
  below).
  """
  graph = graph or tf.get_default_graph()
  mb_tensors = graph.get_collection_ref(_OVERWRITE_GLOBAL_STEP_COLLECTION)
  if len(mb_tensors) == 1:
    mb_tensor = mb_tensors[0]
  else:
    mb_tensor = py_utils.GetGlobalStep()
  return mb_tensor
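# The setter counterpart, SetOverWriteGlobalStep, is referenced by FProp
# below. A plausible sketch under the same single-element-collection
# convention as GetOverWriteGlobalStep; not copied from the real source:
def SetOverWriteGlobalStep(tensor, graph=None):
  graph = graph or tf.get_default_graph()
  mb_tensors = graph.get_collection_ref(_OVERWRITE_GLOBAL_STEP_COLLECTION)
  if len(mb_tensors) == 1:
    # get_collection_ref returns the mutable list, so this replaces the
    # registered tensor in place.
    mb_tensors[0] = tensor
  else:
    graph.add_to_collection(_OVERWRITE_GLOBAL_STEP_COLLECTION, tensor)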
def _WaitUntilInit(self, sess, start_up_delay_steps=None):
  """Waits until the model is ready."""
  try:
    global_step = sess.run(py_utils.GetGlobalStep())
  except tf.errors.FailedPreconditionError as e:
    tf.logging.info('%s: Probably the expected race on global_step: %s',
                    self._job_name, e)
    raise

  msg = 'step:%6d' % global_step
  self._SetStatusMessage(msg)

  if start_up_delay_steps:
    if global_step < start_up_delay_steps:
      msg = 'global step (%d) has not reached start up delay steps (%d)' % (
          global_step, start_up_delay_steps)
      tf.logging.info('%s: %s', self._job_name, msg)
      raise tf.errors.FailedPreconditionError(
          node_def=None, op=None, message=msg)
  return global_step
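# Hedged usage sketch (assumed caller, not from the original file): callers
# typically retry _WaitUntilInit until the variable-initialization race is
# over and any start-up delay has elapsed.
def WaitForModel(job, sess, start_up_delay_steps=None, poll_secs=30):
  while True:
    try:
      return job._WaitUntilInit(sess, start_up_delay_steps)
    except tf.errors.FailedPreconditionError:
      time.sleep(poll_secs)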
def RestoreGlobalStepIfNeeded(self, sess):
  """If global step is not initialized, load it from the checkpoint.

  Args:
    sess: tf.Session.
  """
  assert not self._save_only
  uninitialized_vars = self._GetUninitializedVarNames(sess)
  if six.ensure_binary('global_step') not in uninitialized_vars:
    return

  with sess.graph.as_default():
    gstep = py_utils.GetGlobalStep()

  path = tf.train.latest_checkpoint(self._train_dir)
  if path:
    reader = tf.train.NewCheckpointReader(path)
    value = reader.get_tensor('global_step')
    tf.logging.info('Restoring global step: %s', value)
    sess.run(gstep.assign(value))
  else:
    tf.logging.info('Initializing global step')
    sess.run(gstep.initializer)
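# Hedged standalone sketch using only the checkpoint APIs seen above: read
# the saved 'global_step' scalar from a train directory without building a
# graph. `train_dir` is a hypothetical argument name.
def PeekGlobalStep(train_dir):
  path = tf.train.latest_checkpoint(train_dir)
  if path is None:
    return None  # No checkpoint written yet.
  reader = tf.train.NewCheckpointReader(path)
  return reader.get_tensor('global_step')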
def FProp(self, theta, *args):
  """Runs multiple cells on different devices in a pipelined manner.

  Args:
    theta: A NestedMap object containing weights' values of this layer and
      its children layers.
    *args: Non-keyworded variable length argument list of input tensors.

  Returns:
    A list of output tensors.
  """
  # TODO(huangyp): handle optional None inputs.
  p = self.params
  if self.do_eval:
    outputs = copy.copy(args)
    for (name, l) in self._before_layers + self._cells:
      outputs = _ToTuple(outputs)
      outputs = l.FProp(theta[name], *outputs)
    return outputs

  num_cells = len(p.cell_tpl)
  cluster = self.cluster

  # Compute shapes of input and output tensors.
  input_shapes = self._get_input_shapes(*args)
  state_dtype = self._get_state_dtype(*args)
  state_shapes = self._CalculateOutputShapes(input_shapes)
  tf.logging.info('state_shapes={}'.format(state_shapes))

  def GetCellFn(i):
    """Gets the i-th feature extraction layer."""

    def CellFn(theta, state0, inputs):
      """A cell fn is executed inside of StackedRecurrent."""
      del state0

      def _FPropInputSetShape(name, t_shape):
        if t_shape is None:
          return None
        inputs[name].set_shape(t_shape.ToTensorShape().as_list())
        return inputs[name]

      if p.nested_map_fprop:
        # pylint: disable=protected-access
        fprop_inputs = state_shapes[i]._RecursiveMap(_FPropInputSetShape)
        # pylint: enable=protected-access
      else:
        fprop_inputs = []
        for input_idx, input_shape in enumerate(state_shapes[i]):
          name = 's{}'.format(input_idx)
          fprop_inputs.append(_FPropInputSetShape(name, input_shape))

      with py_utils.RemoveAssertContext(remove=True):
        with CellFnFPropOpReplacementWrapper():
          tf.logging.info('cell {} input {}'.format(i, fprop_inputs))
          mb_tensor = inputs[_MICRO_BATCH_STATE_NAME]
          SetOverWriteGlobalStep(mb_tensor)
          _, cell = self._cells[i]
          fprop_inputs = _ToTuple(fprop_inputs)
          outputs = cell.FProp(theta, *fprop_inputs)

      if p.nested_map_fprop:
        assert py_utils.IsCompatible(outputs, state_shapes[i + 1])
        state1 = outputs.Filter(lambda x: x is not None)
      else:
        state1 = py_utils.NestedMap()
        outputs = _ToTuple(outputs)
        assert len(outputs) == len(state_shapes[i + 1])
        for output_idx in range(len(outputs)):
          if outputs[output_idx] is not None:
            name = 's{}'.format(output_idx)
            state1[name] = outputs[output_idx]
      state1[_MICRO_BATCH_STATE_NAME] = mb_tensor
      return state1, py_utils.NestedMap()

    return CellFn

  cell_fns = []
  accumulator_layers = []
  thetas = []
  init_states = []
  devices = []
  for cell_idx in range(num_cells):
    cell_name, cell = self._cells[cell_idx]
    accumulator_layers.append(cell)
    cell_fns.append(GetCellFn(cell_idx))
    thetas.append(theta[cell_name])

    def _TfZeros(t_shape):
      if t_shape is None:
        return None
      return tf.zeros(t_shape.ToTensorShape().as_list(), dtype=state_dtype)

    if p.nested_map_fprop:
      init_state = py_utils.Transform(_TfZeros, state_shapes[cell_idx + 1])
      init_state = init_state.Filter(lambda x: x is not None)
    else:
      init_state = py_utils.NestedMap()
      for output_idx, state in enumerate(state_shapes[cell_idx + 1]):
        state = _TfZeros(state)
        if state is not None:
          name = 's{}'.format(output_idx)
          init_state[name] = state
    init_state[_MICRO_BATCH_STATE_NAME] = tf.cast(0, dtype=state_dtype)
    init_states.append(init_state)
    devices.append(cluster.WorkerDeviceInModelSplit(cell_idx))

  cell_grads = [None] * num_cells
  cell_outs = [lambda x: x] * num_cells
  cell_out_grads = [lambda x: x] * num_cells

  with tf.device(devices[0]):
    previous = _ToTuple(args)
    for (name, l) in self._before_layers:
      previous = l.FProp(theta[name], *previous)
      previous = _ToTuple(previous)

    def _StackAndSplit(x):
      # Split tensors into microbatches.
      if x is None:
        return None
      return tf.stack(tf.split(x, p.num_micro_batches, axis=p.batch_dim))

    if p.nested_map_fprop:
      inputs = py_utils.Transform(_StackAndSplit, previous[0])
      inputs = inputs.Filter(lambda x: x is not None)
    else:
      inputs = py_utils.NestedMap()
      for output_idx, output_tensor in enumerate(previous):
        output_tensor = _StackAndSplit(output_tensor)
        if output_tensor is not None:
          name = 's{}'.format(output_idx)
          inputs[name] = output_tensor
    gs_tensor = py_utils.GetGlobalStep()
    inputs[_MICRO_BATCH_STATE_NAME] = tf.stack([
        tf.cast(gs_tensor * p.num_micro_batches + t, dtype=state_dtype)
        for t in range(p.num_micro_batches)
    ])
  tf.logging.info('pipeline input = {}'.format(inputs))
  output_state, _ = recurrent.StackedRecurrent(
      devices=devices,
      cell_fns=cell_fns,
      cell_grads=cell_grads,
      cell_outs=cell_outs,
      cell_out_grads=cell_out_grads,
      thetas=thetas,
      init_states=init_states,
      inputs=inputs,
      accumulator_layers=accumulator_layers,
      unused_acc_state=True)

  with tf.device(devices[-1]):

    def _ReshapeRetVal(name, t_shape):
      """Restores the shape for tensors in microbatches."""
      if t_shape is None:
        return None
      output_tensor = output_state[name]
      if p.batch_dim != 0:
        perm = list(range(1, p.batch_dim + 1)) + [0]
        perm += list(range(p.batch_dim + 1, t_shape.rank + 1))
        output_tensor = tf.transpose(output_tensor, perm=perm)
      output_shape = t_shape.ToTensorShape().as_list()
      output_shape[p.batch_dim] *= p.num_micro_batches
      output_tensor = tf.reshape(output_tensor, output_shape)
      return output_tensor

    # Construct the final return values from output_state.
    if p.nested_map_fprop:
      # pylint: disable=protected-access
      output_tensors = state_shapes[-1]._RecursiveMap(_ReshapeRetVal)
      # pylint: enable=protected-access
    else:
      output_tensors = []
      for output_idx, state_shape in enumerate(state_shapes[-1]):
        output_name = 's{}'.format(output_idx)
        output_tensor = _ReshapeRetVal(output_name, state_shape)
        output_tensors.append(output_tensor)
      if len(output_tensors) == 1:
        output_tensors = output_tensors[0]
      else:
        output_tensors = tuple(output_tensors)
    tf.logging.info('pipeline output = {}'.format(output_tensors))
    return output_tensors
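# Worked shape example for the microbatching above, with hypothetical sizes:
# num_micro_batches=4 and batch_dim=0. _StackAndSplit turns a [32, 16] input
# into [4, 8, 16] (leading microbatch axis); _ReshapeRetVal folds the
# microbatch axis back into the batch dimension, recovering [32, 16]. When
# batch_dim != 0, FProp additionally transposes before reshaping.
def DemoMicrobatchShapes():
  num_micro_batches = 4
  x = tf.zeros([32, 16])
  stacked = tf.stack(tf.split(x, num_micro_batches, axis=0))
  assert stacked.shape.as_list() == [4, 8, 16]
  restored = tf.reshape(stacked, [32, 16])
  assert restored.shape.as_list() == [32, 16]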