Example #1
    def _LoopEnqueue(self, op, session_override=None):
        """Runs the enqueue op in a loop."""
        p = self.params
        sess = session_override or self._GetSession()

        with tf.container(self._container_id), sess:
            if self._initialize_tables is not None:
                sess.run(self._initialize_tables)
            gsteps = py_utils.GetGlobalStep()
            local_enqueue_steps = 0

            # global_enqueue_steps measures how many global steps' worth of data
            # have already been enqueued. We use it to decide when to terminate;
            # note that the enqueue op may hang in session.run if we do not
            # terminate with this check.
            global_enqueue_steps = None

            tf.logging.info(
                'params.train.max_steps: %d, enqueue_max_steps: %d',
                p.train.max_steps, p.train.enqueue_max_steps)
            while True:
                if self._dequeue_thread_complete:
                    tf.logging.info(
                        'LoopEnqueue done since consuming thread is done.')
                    return

                global_step = sess.run(gsteps)
                if global_enqueue_steps is None:
                    global_enqueue_steps = global_step
                if local_enqueue_steps % 1000 == 0:
                    tf.logging.info(
                        'Current global_enqueue_steps: %d, '
                        'local_enqueue_steps: %d, global_step: %d',
                        global_enqueue_steps, local_enqueue_steps, global_step)

                if py_utils.use_tpu():
                    global_steps_with_available_data = int(
                        global_enqueue_steps // p.train.tpu_steps_per_loop *
                        p.train.tpu_steps_per_loop)
                else:
                    global_steps_with_available_data = global_enqueue_steps

                if (self._ShouldStop(sess, global_steps_with_available_data)
                        or self._ShouldStop(sess, global_step)):
                    tf.logging.info('Done. ShouldStop is True.')
                    tf.logging.info('Enqueue loop sleeping')
                    time.sleep(15)
                    continue
                if (p.train.enqueue_max_steps > 0
                        and local_enqueue_steps >= p.train.enqueue_max_steps):
                    tf.logging.info('Done. train.enqueue_max_steps reached.')
                    tf.logging.info('Enqueue loop sleeping')
                    time.sleep(15)
                    continue
                local_enqueue_steps += 1

                # There are tpu_infeed_parallelism parallel threads enqueuing.
                # We account for all of them when updating global_enqueue_steps.
                global_enqueue_steps += p.input.tpu_infeed_parallelism

                sess.run([op])
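
The TPU branch above rounds global_enqueue_steps down to a whole number of infeed loops, since the trainer only consumes data in units of tpu_steps_per_loop. A worked example with hypothetical numbers:

# Hypothetical numbers, for illustration only: with tpu_steps_per_loop = 100
# and global_enqueue_steps = 250, only two complete infeed loops' worth of
# data is available, so the value rounds down to 200.
tpu_steps_per_loop = 100
global_enqueue_steps = 250
global_steps_with_available_data = (
    global_enqueue_steps // tpu_steps_per_loop * tpu_steps_per_loop)
assert global_steps_with_available_data == 200
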
Example #2
  def _Loop(self):
    with tf.container(self._container_id), self._GetSession(
        cluster_def=self._cluster_def,
        disable_meta_optimizer=FLAGS.disable_meta_optimizer_in_executor
    ) as sess:
      # Initialize the variables first, if needed.
      for program in self._programs:
        program.RestoreIfNeeded(sess)
        program.Compile(sess)
      sess.run(self._initialize_tables)
      sess.run(self._initialize_local_vars)

      while True:
        global_step = sess.run(py_utils.GetGlobalStep())
        if self._ShouldStop(sess, global_step):
          tf.logging.info('Training finished.')
          if not self._ml_perf_log:
            self.save_only_checkpointer.Save(sess, global_step)
          return

        # If a task is explicitly selected, only run the programs associated
        # with that task.
        if self._single_task_mode or self._model_task_name:
          tf.logging.info('Single task mode: %s', self._model_task_name)
          program_schedule = self._program_schedule_dict[self._model_task_name]
        else:
          # Otherwise, sample a task.
          model_task = self.task_scheduler.Sample(global_step)
          tf.logging.info('Sampled %s', model_task)
          program_schedule = self._program_schedule_dict[model_task]

        done = program_schedule.Run(sess)
        if done:
          tf.logging.info('Program schedule told us to stop.')
          return

        # TODO(blee): More complex saving rules. Currently, we assume
        # we save after every task's program schedule execution.
        #
        # The global_step local variable above is the result of a sess.run, not
        # a tf variable, so by the time we call save_only_checkpointer.Save(...)
        # here, py_utils.GetGlobalStep() is already ahead of it by
        #   (train_executions_per_eval * train_steps_per_loop)
        # steps, due to program_schedule.Run(sess).
        #
        if not self._ml_perf_log:
          self.save_only_checkpointer.Save(sess, py_utils.GetGlobalStep())
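
When no task is explicitly selected, the loop relies on task_scheduler.Sample(global_step) to pick which task's programs to run. Below is a minimal sketch of a sampler with that interface, assuming a fixed probability table; the class name and task_probs argument are hypothetical, and Lingvo's actual schedulers may also vary the probabilities with the global step.

import numpy as np

class ConstantTaskSamplerSketch:
  """Sketch of a task sampler with the Sample(global_step) interface above."""

  def __init__(self, task_probs):
    # task_probs: dict mapping task name -> unnormalized sampling weight.
    self._names = list(task_probs)
    probs = np.array([task_probs[n] for n in self._names], dtype=np.float64)
    self._probs = probs / probs.sum()

  def Sample(self, global_step):
    del global_step  # This constant schedule ignores the step.
    return np.random.choice(self._names, p=self._probs)

# Usage: sampler = ConstantTaskSamplerSketch({'mt': 0.7, 'asr': 0.3})
#        model_task = sampler.Sample(global_step)
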
Example #3
    def __init__(self, params):
        """Layer constructor.

    Sub-classes of BaseLayer should decorator its __init__ with
    @base_layer.initializer

    Args:
      params: A params used to construct this layer.
    """
        assert params.name, ('Layer params for %s must have a "name"' %
                             self.__class__.__name__)

        tf_module_name = params.name
        tf_module_name = re.sub('[^a-zA-Z0-9_]+', '_', tf_module_name)
        tf_module_name = 'bbf_' + self.__class__.__name__ + '_' + tf_module_name
        py_utils.NestedMap.CheckKey(tf_module_name)

        # initialize the base class.
        super(BaseLayer, self).__init__(tf_module_name)

        # Note: AutoTracking doesn't work properly due to its inability to walk
        # through py_utils.NestedMap data structures, which are used widely
        # throughout the Lingvo codebase. There also seems to be some performance
        # hit from turning on auto-tracking when constructing graphs. For now, we
        # disable auto-tracking.
        # TODO(lingvo): Re-enable auto-tracking once fuller support is added for
        # the key data structures used in Lingvo and the performance issue is
        # better understood.
        self._setattr_tracking = False

        self._parent = (_LAYER_STACK.layer_stack[-2]
                        if len(_LAYER_STACK.layer_stack) > 1 else None)
        assert self._parent is not self
        self._params = params.Copy()
        tf.logging.debug('Creating layer %s with params: \n %s \n',
                         self.__class__.__name__, str(params))
        # Vars created by this layer.
        self._private_vars = py_utils.NestedMap()
        # Theta derived from this layer's vars.
        self._private_theta = py_utils.NestedMap()
        # Child layers created by this layer through CreateChild/CreateChildren.
        self._private_children = py_utils.NestedMap()
        # Child layers created by this layer. In a well-formed layer,
        # self._private_children should equal self._children_list, i.e.,
        # all child layers are created using CreateChild/CreateChildren.
        self._children_list = []
        # Extra thetas that do not directly correspond to any underlying vars,
        # for example, the concatenated sharded variables.
        self._extra_theta = py_utils.NestedMap()
        # All registered accumulators.
        self._private_accumulators = py_utils.NestedMap()
        # Layer-private functions. Add with AddFunction.
        self._private_fns = dict()
        # Mapping from variable names to their symbolic shapes.
        # self._var_symbolic_shape_map['var_name'] will be a tuple of integers or
        # symbolic expressions, one for each dimension of the variable.
        self._var_symbolic_shape_map = dict()

        self.AddExtraTheta('global_step', py_utils.GetGlobalStep())
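
Registering 'global_step' as an extra theta means FProp code can read the current step from theta rather than reaching into the graph directly. A minimal sketch of a hypothetical subclass (StepAnnealedScaleLayer is not a Lingvo layer; tf and BaseLayer are assumed to be in scope as above) that uses it for a warm-up style schedule:

class StepAnnealedScaleLayer(BaseLayer):
  """Hypothetical layer: scales its input by a step-dependent factor."""

  def FProp(self, theta, inputs):
    # theta.global_step is the tensor registered via AddExtraTheta above.
    step = tf.cast(theta.global_step, inputs.dtype)
    # Hypothetical 10k-step linear warm-up from 0 to 1.
    scale = tf.minimum(step / 10000.0, 1.0)
    return inputs * scale
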
Example #4
def GetOverWriteGlobalStep(graph=None):
  """Returns the overwrite global step if one is set, else the true global step."""
  graph = graph or tf.get_default_graph()
  mb_tensors = graph.get_collection_ref(_OVERWRITE_GLOBAL_STEP_COLLECTION)
  if len(mb_tensors) == 1:
    mb_tensor = mb_tensors[0]
  else:
    mb_tensor = py_utils.GetGlobalStep()
  return mb_tensor
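
The matching setter, SetOverWriteGlobalStep, is used in Example #7 but not shown here. Under the collection convention above it would presumably replace the single element of _OVERWRITE_GLOBAL_STEP_COLLECTION; the sketch below is that assumption, not Lingvo's actual implementation.

def SetOverWriteGlobalStepSketch(tensor, graph=None):
  """Installs `tensor` as the overwrite global step for GetOverWriteGlobalStep."""
  graph = graph or tf.get_default_graph()
  mb_tensors = graph.get_collection_ref(_OVERWRITE_GLOBAL_STEP_COLLECTION)
  # get_collection_ref returns the collection's mutable list, so clear it
  # in place before adding the new tensor.
  del mb_tensors[:]
  graph.add_to_collection(_OVERWRITE_GLOBAL_STEP_COLLECTION, tensor)
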
Example #5
    def _WaitUntilInit(self, sess, start_up_delay_steps=None):
        """Wait until the model is ready."""
        try:
            global_step = sess.run(py_utils.GetGlobalStep())
        except tf.errors.FailedPreconditionError as e:
            tf.logging.info(
                '%s: Probably the expected race on global_step: %s',
                self._job_name, e)
            raise
        msg = 'step:%6d' % global_step
        self._SetStatusMessage(msg)
        if start_up_delay_steps:
            if global_step < start_up_delay_steps:
                msg = 'global step (%d) has not reached start up delay steps (%d)' % (
                    global_step, start_up_delay_steps)
                tf.logging.info('%s: %s', self._job_name, msg)
                raise tf.errors.FailedPreconditionError(node_def=None,
                                                        op=None,
                                                        message=msg)
        return global_step
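
The FailedPreconditionError raised above signals "not ready yet" rather than a fatal error, so a caller is expected to retry after a pause. A minimal sketch of such a retry loop (the helper name and the 30-second poll interval are arbitrary choices, not Lingvo code):

import time

def _WaitWithRetrySketch(job, sess, start_up_delay_steps):
  """Polls _WaitUntilInit until the start-up delay has elapsed."""
  while True:
    try:
      # In Lingvo this retry would live inside the job itself; calling the
      # private method here is just for illustration.
      return job._WaitUntilInit(sess, start_up_delay_steps)
    except tf.errors.FailedPreconditionError:
      time.sleep(30)  # Not ready yet; try again shortly.
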
Example #6
    def RestoreGlobalStepIfNeeded(self, sess):
        """If global step is not initialized, load it from the checkpoint.

    Args:
      sess: tf.Session.
    """
        assert not self._save_only
        uninitialized_vars = self._GetUninitializedVarNames(sess)
        if six.ensure_binary('global_step') not in uninitialized_vars:
            return

        with sess.graph.as_default():
            gstep = py_utils.GetGlobalStep()

        path = tf.train.latest_checkpoint(self._train_dir)
        if path:
            reader = tf.train.NewCheckpointReader(path)
            value = reader.get_tensor('global_step')
            tf.logging.info('Restoring global step: %s', value)
            sess.run(gstep.assign(value))
        else:
            tf.logging.info('Initializing global step')
            sess.run(gstep.initializer)
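
The same restore-or-initialize pattern, written as a standalone helper for illustration; the helper name is hypothetical, while the 'global_step' checkpoint key and the gstep.assign/initializer calls mirror the code above.

def RestoreGlobalStepSketch(sess, train_dir):
  """Assigns the saved global_step from the latest checkpoint, if any."""
  gstep = py_utils.GetGlobalStep()
  path = tf.train.latest_checkpoint(train_dir)
  if path:
    value = tf.train.NewCheckpointReader(path).get_tensor('global_step')
    sess.run(gstep.assign(value))
    return value
  sess.run(gstep.initializer)
  return None
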
Example #7
  def FProp(self, theta, *args):
    """Run multiple cells in different devices in a pipelining manner.

    Args:
      theta: A NestedMap object containing weights' values of this layer and its
        children layers.
      *args: Non-keyworded variable length argument list of input tensors.

    Returns:
      A list of output tensors.
    """
    # TODO(huangyp): handle optional None inputs.
    p = self.params
    if self.do_eval:
      outputs = copy.copy(args)
      for (name, l) in self._before_layers + self._cells:
        outputs = _ToTuple(outputs)
        outputs = l.FProp(theta[name], *outputs)
      return outputs

    num_cells = len(p.cell_tpl)
    cluster = self.cluster

    # Compute shapes of input and output tensors.
    input_shapes = self._get_input_shapes(*args)
    state_dtype = self._get_state_dtype(*args)
    state_shapes = self._CalculateOutputShapes(input_shapes)
    tf.logging.info('state_shapes={}'.format(state_shapes))

    def GetCellFn(i):
      """Get the ith feature extraction layer."""

      def CellFn(theta, state0, inputs):
        """A cell fn is exectued inside of StackedRecurrent."""
        del state0

        def _FPropInputSetShape(name, t_shape):
          if t_shape is None:
            return None
          inputs[name].set_shape(t_shape.ToTensorShape().as_list())
          return inputs[name]

        if p.nested_map_fprop:
          # pylint: disable=protected-access
          fprop_inputs = state_shapes[i]._RecursiveMap(_FPropInputSetShape)
          # pylint: enable=protected-access
        else:
          fprop_inputs = []
          for input_idx, input_shape in enumerate(state_shapes[i]):
            name = 's{}'.format(input_idx)
            fprop_inputs.append(_FPropInputSetShape(name, input_shape))

        with py_utils.RemoveAssertContext(remove=True):
          with CellFnFPropOpReplacementWrapper():
            tf.logging.info('cell {} input {}'.format(i, fprop_inputs))
            mb_tensor = inputs[_MICRO_BATCH_STATE_NAME]
            SetOverWriteGlobalStep(mb_tensor)
            _, cell = self._cells[i]
            fprop_inputs = _ToTuple(fprop_inputs)
            outputs = cell.FProp(theta, *fprop_inputs)

        if p.nested_map_fprop:
          assert py_utils.IsCompatible(outputs, state_shapes[i + 1])
          state1 = outputs.Filter(lambda x: x is not None)
        else:
          state1 = py_utils.NestedMap()
          outputs = _ToTuple(outputs)
          assert len(outputs) == len(state_shapes[i + 1])
          for output_idx in range(len(outputs)):
            if outputs[output_idx] is not None:
              name = 's{}'.format(output_idx)
              state1[name] = outputs[output_idx]
        state1[_MICRO_BATCH_STATE_NAME] = mb_tensor
        return state1, py_utils.NestedMap()

      return CellFn

    cell_fns = []
    accumulator_layers = []
    thetas = []
    init_states = []
    devices = []
    for cell_idx in range(num_cells):
      cell_name, cell = self._cells[cell_idx]
      accumulator_layers.append(cell)
      cell_fns.append(GetCellFn(cell_idx))
      thetas.append(theta[cell_name])

      def _TfZeros(t_shape):
        if t_shape is None:
          return None
        return tf.zeros(t_shape.ToTensorShape().as_list(), dtype=state_dtype)

      if p.nested_map_fprop:
        init_state = py_utils.Transform(_TfZeros, state_shapes[cell_idx + 1])
        init_state = init_state.Filter(lambda x: x is not None)
      else:
        init_state = py_utils.NestedMap()
        for output_idx, state in enumerate(state_shapes[cell_idx + 1]):
          state = _TfZeros(state)
          if state is not None:
            name = 's{}'.format(output_idx)
            init_state[name] = state
      init_state[_MICRO_BATCH_STATE_NAME] = tf.cast(0, dtype=state_dtype)
      init_states.append(init_state)

      devices.append(cluster.WorkerDeviceInModelSplit(cell_idx))

    cell_grads = [None] * num_cells
    cell_outs = [lambda x: x] * num_cells
    cell_out_grads = [lambda x: x] * num_cells

    with tf.device(devices[0]):
      previous = _ToTuple(args)
      for (name, l) in self._before_layers:
        previous = l.FProp(theta[name], *previous)
        previous = _ToTuple(previous)

      def _StackAndSplit(x):
        # Split tensors into microbatches.
        if x is None:
          return None
        return tf.stack(tf.split(x, p.num_micro_batches, axis=p.batch_dim))

      if p.nested_map_fprop:
        inputs = py_utils.Transform(_StackAndSplit, previous[0])
        inputs = inputs.Filter(lambda x: x is not None)
      else:
        inputs = py_utils.NestedMap()
        for output_idx, output_tensor in enumerate(previous):
          output_tensor = _StackAndSplit(output_tensor)
          if output_tensor is not None:
            name = 's{}'.format(output_idx)
            inputs[name] = output_tensor
      gs_tensor = py_utils.GetGlobalStep()
      inputs[_MICRO_BATCH_STATE_NAME] = tf.stack([
          tf.cast(gs_tensor * p.num_micro_batches + t, dtype=state_dtype)
          for t in range(p.num_micro_batches)
      ])
    tf.logging.info('pipeline input = {}'.format(inputs))
    output_state, _ = recurrent.StackedRecurrent(
        devices=devices,
        cell_fns=cell_fns,
        cell_grads=cell_grads,
        cell_outs=cell_outs,
        cell_out_grads=cell_out_grads,
        thetas=thetas,
        init_states=init_states,
        inputs=inputs,
        accumulator_layers=accumulator_layers,
        unused_acc_state=True)

    with tf.device(devices[-1]):

      def _ReshapeRetVal(name, t_shape):
        """Restore shape for tensors in microbatches."""
        if t_shape is None:
          return None
        output_tensor = output_state[name]
        if p.batch_dim != 0:
          perm = list(range(1, p.batch_dim + 1)) + [0]
          perm += list(range(p.batch_dim + 1, t_shape.rank + 1))
          output_tensor = tf.transpose(output_tensor, perm=perm)
        output_shape = t_shape.ToTensorShape().as_list()
        output_shape[p.batch_dim] *= p.num_micro_batches
        output_tensor = tf.reshape(output_tensor, output_shape)
        return output_tensor

      # Construct the final return values from output_state.
      if p.nested_map_fprop:
        # pylint: disable=protected-access
        output_tensors = state_shapes[-1]._RecursiveMap(_ReshapeRetVal)
        # pylint: enable=protected-access
      else:
        output_tensors = []
        for output_idx, state_shape in enumerate(state_shapes[-1]):
          output_name = 's{}'.format(output_idx)
          output_tensor = _ReshapeRetVal(output_name, state_shape)
          output_tensors.append(output_tensor)
        if len(output_tensors) == 1:
          output_tensors = output_tensors[0]
        else:
          output_tensors = tuple(output_tensors)
      tf.logging.info('pipeline output = {}'.format(output_tensors))
      return output_tensors
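
A worked example, with hypothetical numbers, of the micro-batch state built above: each micro-batch t is tagged with global_step * num_micro_batches + t, which the cell fn installs via SetOverWriteGlobalStep so that step-dependent ops inside the pipeline see a distinct value per micro-batch.

# Hypothetical numbers: num_micro_batches = 4, current global_step = 10.
num_micro_batches = 4
global_step = 10
micro_batch_steps = [
    global_step * num_micro_batches + t for t in range(num_micro_batches)
]
assert micro_batch_steps == [40, 41, 42, 43]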