Example #1
    def Apply(self, lr, var_grad):
        """Applies the gradient to the variable.

    Args:
      lr: A scalar or a callable that returns the learning rate.
      var_grad: A `.NestedMap` of (var, grad) pairs.

    Returns:
      The variable update op.

    Raises:
      RuntimeError: When `lr` is not a callable in Eager mode and user did not
        enable the scalar lr option.
    """

        if py_utils.IsEagerMode() and not callable(lr):
            # Ensure that the learning rate is always updated in Eager mode.
            raise RuntimeError('In Eager mode, `lr` must be a callable.')

        # In Graph mode, always re-create the optimizer to remain consistent with
        # the old logic for the Graph trainer.
        # TODO(jiaweix): Recreating optimizers in Graph mode seems unnecessary.
        if self._optimizer is None or not py_utils.IsEagerMode():
            self._optimizer = self.GetOptimizer(lr)

        def _Apply():
            if not var_grad.Flatten():
                tf.logging.warning(
                    'No gradients are available for optimizer.Apply(). '
                    'Make sure this is expected.')
                return tf.no_op()
            if self.params.use_bf16_gradients_ar:
                return self._optimizer.apply_gradients(
                    [(tf.cast(g, tf.float32), v)
                     for (v, g) in var_grad.Flatten()],
                    name='meta_backprop')
            else:
                return self._optimizer.apply_gradients(
                    [(g, v) for (v, g) in var_grad.Flatten()],
                    name='meta_backprop')

        clear_variable_scope = self.params.clear_variable_scope
        if clear_variable_scope is None:
            clear_variable_scope = not py_utils.IsEagerMode()
        if clear_variable_scope:
            # Many optimizers, e.g., Adam, Adagrad, etc., create
            # variables. We need to ensure name scope and variable scope are
            # cleared. Otherwise, tpu.batch_parallel does not work.
            with tf.name_scope(None):
                with tf.variable_scope(
                        tf.VariableScope(use_resource=True,
                                         reuse=self.VarReuseForSlotVars())):
                    var_update_op = _Apply()
        else:
            var_update_op = _Apply()

        if self.params.add_summary_in_apply:
            lr_value = GetLrValue(lr)
            self.AddSummary(lr_value, self._optimizer, var_grad)
        return var_update_op
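A minimal call-site sketch (not from the source): `opt` stands for an instance of this optimizer wrapper and `var_grad` for a `.NestedMap` of (var, grad) pairs built by the training loop. As the docstring notes, Eager mode requires a callable learning rate, which is the same pattern used in Example #24 below.

if py_utils.IsEagerMode():
    # Eager mode: the learning rate must be passed as a callable.
    lr_or_callable = lambda: 0.001
else:
    # Graph mode: a plain scalar works.
    lr_or_callable = 0.001
var_update_op = opt.Apply(lr_or_callable, var_grad)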
Example #2
def ReplicatedGenericInput(processor, num_replicas, replica_device_fn,
                           **kwargs):
    """Builds a replicated input pipeline.

  This is similar to GenericInput, except that the input processing can be
  distributed across devices and then concatenated at the current device.

  Args:
    processor: see comments for GenericInput.
    num_replicas: the number of input processing replicas. Usually set to the
      number of infeed hosts.
    replica_device_fn: an int -> string function that takes the replica index
      in range [0, num_replicas) and returns a TF device string, e.g.,
      lambda i: '/task:{}/device:CPU:0'.format(i)
    **kwargs: additional keyword args for x_ops.generic_input.

  Returns:
    A tuple of (outputs, bucket_keys):

    - outputs: a NestedMap or a list of tensors, similar to `processor`'s
      return, except every tensor will have an additional dimension 0 that
      represents the batch dimension. The batch size will be
      (num_replicas * bucket_batch_limit[...]), i.e.,
      kwargs['bucket_batch_limit'] specifies the per-replica batch size.
    - bucket_keys: a tf.int32 vector.

  Raises:
    RuntimeError: If called in pure Eager/tf.function mode without
      `generic_input_v2_key` defined.
  """
    if num_replicas > 1 and 'bucket_batch_limit' in kwargs:
        assert all(b == max(kwargs['bucket_batch_limit'])
                   for b in kwargs['bucket_batch_limit'])
    replica_outputs = []
    if py_utils.IsEagerMode():
        current_key = kwargs.pop('generic_input_v2_key', None)
        if current_key is None:
            raise RuntimeError(_MISSING_KEY_ERR)

    for replica_i in range(num_replicas):
        # Blend `replica_i` into the key for _GENERIC_CACHE_V2 to distinguish
        # different GenericInputV2 ops in the same Datasource object.
        if py_utils.IsEagerMode():
            kwargs['generic_input_v2_key'] = (current_key, replica_i)
        replica_device = replica_device_fn(replica_i)
        with tf.device(replica_device):
            replica_outputs.append(GenericInput(processor, **kwargs))

    output_nmaps, output_bucket_keys = zip(*replica_outputs)
    concat_nmap = tf.nest.map_structure(lambda *t: tf.concat(t, axis=0),
                                        *output_nmaps)
    concat_bucket_keys = tf.concat(output_bucket_keys, axis=0)
    return concat_nmap, concat_bucket_keys
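A hedged usage sketch: `_Process` is a hypothetical record processor, and the file pattern and bucketing kwargs below are illustrative assumptions forwarded to x_ops.generic_input, not values from the source. Note that `bucket_batch_limit` is per replica, so this call would yield batches of 4 * 8 = 32 examples.

# Hypothetical call site; all argument values are assumptions.
outputs, bucket_keys = ReplicatedGenericInput(
    processor=_Process,
    num_replicas=4,
    replica_device_fn=lambda i: '/task:{}/device:CPU:0'.format(i),
    file_pattern='tfrecord:/path/to/train-*',
    bucket_upper_bound=[512],
    bucket_batch_limit=[8])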
Example #3
def scalar(name, value, while_loop_reduce='mean'):
    """Adds summary scalar.

  Outside of tpu_summary.context() does nothing.

  Args:
    name: string name
    value: scalar tensor value
    while_loop_reduce: optional argument, determines what to do when this
      summary appears inside a tf.while_loop. Can be 'mean' or 'sum'.

  Raises:
    RuntimeError: if the function is called in Eager mode.
  """
    if py_utils.IsEagerMode():
        raise RuntimeError(EAGER_MODE_EXCEPTION_STR)

    assert while_loop_reduce in ('mean', 'sum')
    ctx = TpuSummaryContext.current()
    if ctx is None:
        return
    x = TpuSummaryScalar()
    x.name = str(name)
    x.value = tf.convert_to_tensor(value)
    if x.value.shape != ():  # pylint: disable=g-explicit-bool-comparison
        raise ValueError('use tpu_summary.tensor() instead: %r' % value)
    x.name_scope = tf.get_default_graph().get_name_scope()
    x.while_loop_reduce = while_loop_reduce
    ctx.summary_tensors.append(x)
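For context, a hedged sketch of a call site. The docstring above states the function is a no-op outside of tpu_summary.context(), so it is typically invoked inside that context; the module alias `tpu_summary` and the `loss` tensor are assumed here.

# Does nothing unless a summary context is active.
with tpu_summary.context():
    tpu_summary.scalar('loss', loss, while_loop_reduce='mean')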
Example #4
    def _CreateVariableInternal(self, name: str,
                                meta: CreateVariableMeta) -> None:
        """Immediately creates the variable described by `meta`.

    DO NOT OVERRIDE. For internal use only. Subclasses of BaseLayer should use
    self.CreateVariable() to create variables.

    Args:
      name: The variable name.
      meta: A CreateVariableMeta describing the variable to be created.
    """
        meta.kwargs.setdefault('default_seed', self.params.random_seed)
        var = py_utils.CreateVariable(name, meta.var_params, **meta.kwargs)
        self._private_vars[name] = var

        if py_utils.IsEagerMode():
            # With eager trainer, always use the variable directly.
            value = var
        else:
            if self.cluster.params.worker.gpus_per_replica > 0:
                # On GPU (which always trains a single step per session.run()),
                # reference a tensor in FProp to cache it on device and avoid extraneous
                # sends from reading variables from ps multiple times.
                with tf.device(var.device):
                    value = tf.identity(var, name=name)
            elif self.params.add_name_to_theta:
                value = tf.identity(var, name=name)
            else:
                value = var

        if meta.theta_fn is not None:
            self._private_theta_fn[name] = meta.theta_fn

        self._private_theta[name] = value
Example #5
    def _MaybeConstructSharedModel(self, train_cfg):
        """Construct a single shared copy of the model if this is a MultiTaskModel.

    If the share_model_object parameter is set, for MultiTaskModels,
    we create a MultiTaskSubModel for each task, but construct the model only
    once.

    Args:
      train_cfg: The params for a SingleTaskModel or MultiTaskModel.

    Returns:
      A MultiTaskModel, if train_cfg is a MultiTaskModel params object.
    """
        if not issubclass(train_cfg.cls, base_model.MultiTaskModel):
            return None

        if not train_cfg.share_model_object:
            return None

        with self._cluster, tf.container(
                self._container_id), contextlib.ExitStack() as stack:
            if not py_utils.IsEagerMode():
                stack.enter_context(self._graph.as_default())
                stack.enter_context(tf.device(self._cluster.GetPlacer()))
            with py_utils.VariableStore(), py_utils.VariableRenameScope(
                    self._variable_renaming_rules):
                py_utils.GetOrCreateGlobalStepVar()
                shared_model = train_cfg.Instantiate()

        return shared_model
Example #6
    def available_devices(self):
        """Returns all compute devices available in a 2D array.

    Returns:
      A 2D array (python list of python lists) of strings. ret[i, j]
      is the j-th visible device on the i-th visible replica.
    """
        from lingvo.core import py_utils  # pylint: disable=g-import-not-at-top
        if self.job_spec.tpus_per_replica and not py_utils.IsEagerMode():
            ret = np.empty((1, self.num_devices_per_split), dtype=object)
            for i in range(self.num_devices_per_split):
                ret[0, i] = tf.tpu.core(i)
            return ret

        if self.job == 'trainer' and self.asynchronous:
            # In async mode, each trainer task can only use its own devices.
            return self.ListDevices(self.job_spec)[self.task:(self.task +
                                                              1), :]

        if self.job == 'trainer_client' and self.synchronous:
            # In sync mode, trainer_client can use every device.
            return self.ListDevices(self.job_spec)

        if self.job == 'executor_tpu' and self.synchronous:
            # executor_tpu can use every device.
            return self.ListDevices(self.job_spec)

        if self.job in ('controller', 'train_summaries', 'evaler', 'decoder'):
            # Our current policy is that each controller/evaler/decoder task
            # only uses 1 replica.
            return self.ListDevices(self.job_spec)[self.task:(self.task +
                                                              1), :]

        assert False, (self.job, self.mode)
Example #7
 def _GetSession(self, **kwargs):
     if py_utils.IsEagerMode():
         raise ValueError('_GetSession is not supported in eager mode.')
     graph = kwargs.pop('graph', self._graph)
     return tf.Session(self._tf_master,
                       graph=graph,
                       config=py_utils.SessionConfig(**kwargs))
Example #8
 def _CreateCheckpointer(self, train_dir, model, init_op=None):
     """Wrapper method for override purposes."""
     if py_utils.IsEagerMode():
         if FLAGS.write_v2_checkpoints:
             return checkpointer.EagerCheckpointerV2(
                 train_dir, model, init_op)
         return checkpointer.EagerCheckpointerV1(train_dir, model, init_op)
     return checkpointer.Checkpointer(train_dir, model, init_op)
Example #9
    def UpdateClusterParamsFromFlags(self, cluster, job_name):
        """Update `cluster` with a training cluster configuration from flags."""
        cluster.mode = FLAGS.mode
        cluster.job = job_name
        cluster.task = FLAGS.task
        cluster.do_eval = job_name in ['evaler', 'decoder']
        cluster.logdir = FLAGS.logdir

        cluster.controller.name = FLAGS.controller_job
        cluster.controller.gpus_per_replica = FLAGS.controller_gpus

        cluster.worker.name = FLAGS.worker_job
        cluster.worker.replicas = FLAGS.worker_replicas
        cluster.worker.gpus_per_replica = FLAGS.worker_gpus
        cluster.worker.tpus_per_replica = FLAGS.worker_tpus
        cluster.worker.num_tpu_hosts = FLAGS.worker_num_tpu_hosts
        cluster.worker.devices_per_split = FLAGS.worker_split_size
        if FLAGS.additional_worker_jobs:
            for additional_job in FLAGS.additional_worker_jobs:
                cluster.worker.additional_worker_names.append(additional_job)

        if FLAGS.tpu:
            job_name = cluster.worker.name.replace('/job:', '', 1)
            worker_hosts = _GetClusterSpecDict()[job_name]
            if FLAGS.additional_worker_jobs:
                for additional_job in cluster.worker.additional_worker_names:
                    additional_job_name = additional_job.replace(
                        '/job:', '', 1)
                    worker_hosts.extend(
                        _GetClusterSpecDict()[additional_job_name])
            cluster.worker.targets = ','.join('grpc://{}'.format(host)
                                              for host in worker_hosts)

        cluster.ps.name = FLAGS.ps_job
        cluster.ps.replicas = FLAGS.ps_replicas
        cluster.ps.gpus_per_replica = FLAGS.ps_gpus

        cluster.input.name = FLAGS.input_job
        cluster.input.replicas = FLAGS.input_replicas
        cluster.input.targets = FLAGS.input_targets

        if py_utils.IsEagerMode():
            cluster.evaler.name = '/job:localhost'
            cluster.decoder.name = '/job:localhost'
        else:
            cluster.evaler.name = FLAGS.evaler_job
            cluster.decoder.name = FLAGS.decoder_job

        cluster.evaler.replicas = FLAGS.evaler_replicas
        cluster.evaler.gpus_per_replica = FLAGS.evaler_gpus
        cluster.decoder.replicas = FLAGS.decoder_replicas
        cluster.decoder.gpus_per_replica = FLAGS.decoder_gpus

        cluster.tf_data_service_address = FLAGS.tf_data_service_address

        cluster.add_summary = FLAGS.add_summary
        cluster.reporting_job = FLAGS.vizier_reporting_job
Example #10
    def Apply(self, lr, var_grad):
        """Applies the gradient to the variable.

    Args:
      lr: A scalar or callable that returns the base learning rate.
      var_grad: A `.NestedMap` of (var, grad) pairs.

    Returns:
      The variable update op.

    Raises:
      RuntimeError: When `lr` is not a callable in Eager mode.
    """

        # In Graph mode, always re-create the optimizer to remain consistent with
        # the old logic for the Graph trainer.
        # TODO(jiaweix): Recreating optimizers in Graph mode seems unnecessary.
        if self._optimizer is None or not py_utils.IsEagerMode():
            self._optimizer = self.GetOptimizer(lr)

        def _Apply():
            return self._optimizer.apply_gradients(
                [(g, v) for (v, g) in var_grad.Flatten()],
                name='meta_backprop')

        clear_variable_scope = self.params.clear_variable_scope
        if clear_variable_scope is None:
            clear_variable_scope = not py_utils.IsEagerMode()
        if clear_variable_scope:
            # Many optimizers, e.g., Adam, Adagrad, etc., create
            # variables. We need to ensure name scope and variable scope are
            # cleared. Otherwise, tpu.batch_parallel does not work.
            with tf.name_scope(None):
                with tf.variable_scope(
                        tf.VariableScope(use_resource=True,
                                         reuse=self.VarReuseForSlotVars())):
                    var_update_op = _Apply()
        else:
            var_update_op = _Apply()

        if self.params.add_summary_in_apply:
            lr_value = GetLrValue(lr)
            self.AddSummary(lr_value, self._optimizer, var_grad)
        return var_update_op
Example #11
 def Finalize(self):
     """Finishes creation of the overall figure, returning the image summary."""
     rendered = self.FinalizeImage()
     if py_utils.IsEagerMode():
         return tf.compat.v2.summary.image(self._name,
                                           rendered,
                                           max_outputs=self._max_outputs,
                                           step=py_utils.GetGlobalStep())
     else:
         return tf.summary.image(self._name,
                                 rendered,
                                 max_outputs=self._max_outputs)
Example #12
def pw_tensor(name, value):
    """Adds summary tensor."""
    if py_utils.IsEagerMode():
        raise RuntimeError(EAGER_MODE_EXCEPTION_STR)
    ctx = TpuSummaryContext.current()
    if ctx is None:
        return
    x = PwTpuSummaryTensor()
    x.name = str(name)
    x.value = tf.convert_to_tensor(value)
    x.name_scope = tf.get_default_graph().get_name_scope()
    ctx.summary_tensors.append(x)
Example #13
def tensor(name, value):
    """Adds summary tensor. Similar to scalar() but allows other shapes."""
    if py_utils.IsEagerMode():
        raise RuntimeError(EAGER_MODE_EXCEPTION_STR)
    ctx = TpuSummaryContext.current()
    if ctx is None:
        return
    x = TpuSummaryScalar()
    x.name = str(name)
    x.value = tf.convert_to_tensor(value)
    x.name_scope = tf.get_default_graph().get_name_scope()
    x.while_loop_reduce = 'stack'
    ctx.summary_tensors.append(x)
Example #14
    def CreateVariable(self, name: str, var_params: hyperparams.Params,
                       **kwargs) -> None:
        """Create a variable of this layer according to the parameter `var_params`.

    E.g.::

        def __init__(self, ...):    # A layer's constructor
          self.CreateVariable(
              'weight', py_utils.WeightParams(shape=[100, 100]))

    Args:
      name: Variable name which is used as the key into vars/theta.
      var_params: `Params` used to create the variable.
      **kwargs: Keyword args passed to `.py_utils.CreateVariable`.
    """
        kwargs.setdefault('default_seed', self.params.random_seed)
        if self.params.device_mesh is not None:
            if (len([dim for dim in var_params.shape if dim > 1]) > 1
                    and var_params.tensor_split_dims_mapping is None):
                tf.logging.warning(
                    'tensor_split_dims_mapping missing for %s.%s: shape=%s',
                    self.path, name, var_params.shape)
        self._CheckName(name)
        if (self.params.skip_lp_regularization and
                py_utils.SKIP_LP_REGULARIZATION not in var_params.collections):
            var_params = py_utils.WeightParams(
                shape=var_params.shape,
                dtype=var_params.dtype,
                init=var_params.init,
                collections=(var_params.collections +
                             [py_utils.SKIP_LP_REGULARIZATION]))
        self._var_symbolic_shape_map[name] = var_params.shape

        var = py_utils.CreateVariable(name, var_params, **kwargs)
        self._private_vars[name] = var

        if py_utils.IsEagerMode():
            # With eager trainer, always use the variable directly.
            value = var
        else:
            if self.cluster.params.worker.gpus_per_replica > 0:
                # On GPU (which always trains a single step per session.run()),
                # reference a tensor in FProp to cache it on device and avoid extraneous
                # sends from reading variables from ps multiple times.
                with tf.device(var.device):
                    value = tf.identity(var, name=name)
            else:
                value = var

        self._private_theta[name] = value
Example #15
                def RunSave(sess, global_step):
                    # Run TPU embedding retrieve ops.
                    # NOTE: this is expensive, so only run it when we're checkpointing.
                    if not py_utils.IsEagerMode():
                        tf.logging.info('Retrieve params.')
                        sess.run(self._retrieve_ops)
                        tf.logging.info('Retrieve params done.')

                    # Save program state first, so it's recoverable after we restore
                    # from checkpoint.
                    for program in self._programs:
                        program.SaveProgramState(sess, global_step)
                    # Save the checkpoints asynchronously.
                    self._checkpointer.Save(sess, global_step, sync=False)
Example #16
    def __init__(self,
                 train_dir,
                 models,
                 init_op=None,
                 train_params=None,
                 save_only=False):
        """Initialize Checkpointer.

    Args:
     train_dir: Training directory for saving checkpoints.
     models: One or a list of BaseModel instances. Cannot be empty. If there is
       more than one model and `train_params` is None, the save intervals will
       be determined only by the first model.
     init_op: The initialize variables op. If unset, it will call
       tf.global_variables_initializer().
     train_params: If specified, use these training params instead of those in
       the `model`.
     save_only: This checkpointer is only intended for saving checkpoints.
    """
        self._train_dir = train_dir
        self._save_only = save_only

        if init_op:
            self._init_op = init_op
        else:
            self._init_op = tf.global_variables_initializer()

        self._save_path = os.path.join(self._train_dir, 'ckpt')

        if not isinstance(models, (list, tuple)):
            models = [models]
        self._models = models

        if train_params:
            self._train_params = train_params
        else:
            self._train_params = models[0].params.train

        self._next_checkpoint_seconds = 0
        self._save_interval_seconds = self._train_params.save_interval_seconds
        self._save_interval_steps = self._train_params.save_interval_steps
        self._prev_ckpt_step = None
        self._saver = self._GetSaver()

        if not py_utils.IsEagerMode():
            self._uninitialized_vars = tf.report_uninitialized_variables(
                tf.global_variables())

        self._BuildInitFromCheckpointRules()
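A construction and usage sketch, assuming `train_dir`, `model`, `sess`, and `global_step` already exist in the calling runner; the Restore/ShouldSave/Save calls mirror how the checkpointer is driven in the executor loop shown in Example #26 below.

# Hypothetical caller.
ckpt = Checkpointer(train_dir, model)
path = ckpt.Restore(sess)  # Restore the latest checkpoint, if any.
if ckpt.ShouldSave(global_step):
    ckpt.Save(sess, global_step, sync=False)  # Save asynchronously.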
Example #17
        def _WaitTillInit(job=None):
            """Wait until the model is ready."""
            try:
                if py_utils.IsEagerMode():
                    topology = tf.tpu.experimental.initialize_tpu_system(
                        resolver)
                else:
                    # tpu.initialize_system() is called with None as embedding_config, as
                    # embedding_config is not available yet. Later in _Loop, it is called
                    # with the correct embedding_config. Since it cannot be called twice
                    # in the same graph with different embedding_config, we use a
                    # dummy_graph here.
                    dummy_graph = tf.Graph()
                    with dummy_graph.as_default():
                        tpu_initialize_system_op = tf.tpu.initialize_system(
                            embedding_config=None, job=job)

                    with self._GetSession(graph=dummy_graph) as sess:
                        topology = sess.run(tpu_initialize_system_op)

                if train_cfg.train.tpu_computation_shape is None:
                    computation_shape = py_utils.ComputationShape(
                        num_devices_per_split, topology)
                else:
                    computation_shape = train_cfg.train.tpu_computation_shape
                    assert num_devices_per_split == np.prod(computation_shape)

                if train_cfg.train.tpu_device_order_mode is None:
                    self.device_assignment = device_assignment_lib.device_assignment(
                        topology,
                        computation_shape=computation_shape,
                        num_replicas=data_parallelism)
                else:
                    self.device_assignment = device_assignment_lib.device_assignment(
                        topology,
                        computation_shape=computation_shape,
                        num_replicas=data_parallelism,
                        device_order_mode=train_cfg.train.tpu_device_order_mode
                    )
                py_utils.SetTpuDeviceAssignment(self.device_assignment, job)
                tf.logging.info('device_assignment.core_assignment: %s',
                                str(self.device_assignment.core_assignment))
                tf.logging.info(
                    'device_assignment.topology.device_coordinates: %s',
                    str(self.device_assignment.topology.device_coordinates))
            except py_utils.transient_tf_errors as e:
                tf.logging.info('TPU initialization failed: %s', e)
                raise
Example #18
 def Stop(self, session=None):
   """Returns true if stop criterion is met."""
   if self._node is not None:
     if py_utils.IsEagerMode():
       self._best_step, self._last_step = self._node()
     else:
       self._best_step, self._last_step = session.run(self._node())
     s = (
         self._last_step - self._best_step > self.params.window and
         self._last_step >= self.params.min_steps)
     if self.params.verbose:
       tf.logging.info(
           'early stop check: best_step=%d, last_step=%d, stop=%d',
           self._best_step, self._last_step, s)
     return s
   else:
     return False
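A hedged call-site sketch; the `early_stop` object, `sess`, and the surrounding training loop are hypothetical. In Graph mode the active session is passed so the stop-criterion tensors can be evaluated, while in Eager mode the values are read directly.

# Hypothetical training loop using the early-stop check above.
while True:
    ...  # Run a training step.
    sess_or_none = None if py_utils.IsEagerMode() else sess
    if early_stop.Stop(sess_or_none):
        break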
Example #19
def CollectVarHistogram(vs_gs):
    """Adds histogram summaries for variables and gradients."""

    for name, (var, grad) in vs_gs.FlattenItems():
        name = py_utils.SanitizeScopeKey(name)
        with tf.device(var.device), tf.name_scope(name + '/summary'):
            if isinstance(grad, tf.IndexedSlices):
                var = tf.gather(var, grad.indices)
                grad = grad.values
            if var.dtype.is_complex:
                var = tf.abs(var)
                grad = tf.abs(grad)

        if py_utils.IsEagerMode():
            histogram_v2(f'var_hist/{name}', var)
            histogram_v2(f'grad_hist/{name}', grad)
        else:
            histogram(f'var_hist/{name}', var)
            histogram(f'grad_hist/{name}', grad)
Example #20
def AddNormSummary(name, vs_gs):
    """"Returns and creates summary for norms of vs and their gradients gs.

  Args:
    name: A name string for summary.
    vs_gs: A `.NestedMap` or a list of `.NestedMap` of (variable, gradient).

  Returns:
    norm of variables, and norm of gradients.
  """
    flatten = py_utils.Flatten(vs_gs)
    v_norm = tf.sqrt(py_utils.SumSquared([v for (v, _) in flatten]))
    g_norm = tf.sqrt(py_utils.SumSquared([g for (_, g) in flatten]))
    if py_utils.IsEagerMode():
        scalar_v2(f'var_norm/{name}', v_norm)
        scalar_v2(f'grad_norm/{name}', g_norm)
    else:
        scalar(f'var_norm/{name}', v_norm)
        scalar(f'grad_norm/{name}', g_norm)
    return v_norm, g_norm
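For illustration, a hedged call site; the `var_grads` NestedMap of (var, grad) pairs is assumed to come from gradient computation, as in the optimizer examples on this page.

# Emits var_norm/train and grad_norm/train summaries and returns both norms.
v_norm, g_norm = AddNormSummary('train', var_grads)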
Example #21
    def _ShouldStop(self, sess=None, step=None, check_early_stop=True):
        """Check if the runner should stop.

    Args:
      sess: tf.Session.
      step: The current GlobalStep.
      check_early_stop: Whether or not we want to check the EarlyStop condition.
        In TPU training, we don't want to check this at every-step granularity
        in the enqueue thread, as this may starve the TPU training loop, which
        by default operates at a 1000-step granularity.

    Returns:
      Whether runner should stop.
    """
        if step is None:
            if py_utils.IsEagerMode():
                step = py_utils.GetGlobalStep().numpy()
            else:
                step = sess.run(py_utils.GetGlobalStep())

        if step >= self.params.train.max_steps:
            tf.logging.info('ShouldStop: step:%6d params.train.max_steps:%6d',
                            step, self.params.train.max_steps)
            return True

        if (self._max_steps_for_early_stop
                and step >= self._max_steps_for_early_stop):
            tf.logging.info(
                'ShouldStop: step:%6d _max_steps_for_early_stop:%6d', step,
                self._max_steps_for_early_stop)
            return True

        if check_early_stop and self._ShouldEarlyStop(sess):
            tf.logging.info('ShouldStop: Early stopping.')
            return True

        if self._trial and self._trial.ShouldStop():
            tf.logging.info('ShouldStop: Trial finished.')
            return True

        return False
Example #22
 def RetryLoop():
     try:
         if py_utils.IsEagerMode():
             global_step = py_utils.GetGlobalStep().numpy()
         else:
             global_step = sess.run(py_utils.GetGlobalStep())
     except tf.errors.FailedPreconditionError as e:
         tf.logging.info(
             '%s: Probably the expected race on global_step: %s',
             self._job_name, e)
         raise
     msg = 'step:%6d' % global_step
     self._SetStatusMessage(msg)
     if start_up_delay_steps:
         if global_step < start_up_delay_steps:
             msg = 'global step (%d) has not reached start up delay steps (%d)' % (
                 global_step, self._start_up_delay_steps)
             tf.logging.info('%s: %s', self._job_name, msg)
             raise tf.errors.FailedPreconditionError(node_def=None,
                                                     op=None,
                                                     message=msg)
     return global_step
Example #23
  def GetNext(self):
    """Return input batch from p.file_patterns list weighted by p.weights.

    Examples in the batch will be mixed together from the different file_pattern
    sources in proportion to the weights.

    Returns:
      An input batch.
    """
    p = self.params
    file_patterns = ','.join(p.file_pattern)
    if p.file_type:
      file_patterns = f'{p.file_type}:{file_patterns}'

    extra_args = dict()
    if p.source_id_offset != 0:
      extra_args['input_source_id_offset'] = p.source_id_offset

    # In TF1 mode, the python method `GetNext` only gets called once for each
    # DataSource object during graph construction.
    # In TF2 mode, however, `GetNext` can be called many times. We must specify
    # keys to uniquely identify its `GenericInputV2` resource. This
    # ensures that the resource is properly reused.
    if py_utils.IsEagerMode():
      # The current DataSource object is used as the key to GenericInputV2 ops.
      extra_args['generic_input_v2_key'] = self

    if p.weights:
      # Within-batch mixing.
      batch = self._input_generator._DataSourceFromFilePattern(  # pylint: disable=protected-access
          file_patterns,
          input_source_weights=p.weights,
          **extra_args)
    else:
      # Default.
      batch = self._input_generator._DataSourceFromFilePattern(  # pylint: disable=protected-access
          file_patterns, **extra_args)

    return batch
Example #24
    def Apply(self, metrics, vmap, gradient_mask=None, gradient_adjuster=None):
        """Computes updates on 'vmap' to optimize 'loss'.

    TODO(rpang): explore merging gradient_mask and gradient_adjuster.

    Args:
      metrics: A Dict[str, (value, weight)], from which loss can be extracted
        according to p.loss_name.
      vmap: A `.NestedMap` object containing variables to optimize.
      gradient_mask: if not None, a dict mapping variable names to a 0/1 scalar.
      gradient_adjuster: if not None, a function that mutates a given var_grads.

    Returns:
      (losses, op, eval_metrics), where
        - losses is a list of scalar tensors;
        - op is a tf.Operation to update variables;
        - eval_metrics is a Dict[str, (value, weight)], where each value/weight
          is a scalar tensor.
    """
        # We apply gradients outside the name_scope to maintain backwards
        # compatibility on variables created by self.optimizer.Apply().
        losses, var_grads, eval_metrics = self._ComputeLossesAndGradients(
            metrics, vmap)
        if 'tpu_embedding_var_grads' in var_grads:
            tpu_embedding_var_grads = var_grads.tpu_embedding_var_grads
            del var_grads.tpu_embedding_var_grads

            tpu_embedding_collection = py_utils.GetTpuEmbeddingGraphCollection()[0]
            assert tpu_embedding_collection
            tpu_emb_update_op, stats = tpu_embedding_collection.ApplyGradients(
                py_utils.GetTaskCallScope(),
                tpu_embedding_var_grads.Transform(
                    lambda var_grad: var_grad.grad))
            eval_metrics.update(stats)
        else:
            tpu_emb_update_op = tf.no_op()

        assert py_utils.GetGlobalStep() is not None
        lr = self.LearningRate()

        var_grads, stats = self.AdjustGradients(
            var_grads,
            gradient_mask=gradient_mask,
            gradient_adjuster=gradient_adjuster)
        eval_metrics.update(stats)
        self._var_grads = var_grads

        eval_metrics['learning_rate'] = (tf.convert_to_tensor(lr),
                                         tf.convert_to_tensor(1.))

        if py_utils.IsEagerMode():
            # Use a callback for learning_rate so we can keep just one instance of
            # optimizer object in the whole trainer lifetime in Eager mode.
            lr_or_callable = self.LearningRate
        else:
            lr_or_callable = lr

        with self._SelfVariableScope():
            var_update_op = tf.group([
                tpu_emb_update_op,
                self.optimizer.Apply(lr_or_callable, var_grads)
            ])

        return losses, var_update_op, eval_metrics
Example #25
    def Apply(self, lr, var_grad):
        """For each optimizer, apply the gradient to the variable.

    Args:
      lr: A scalar. The base learning rate.
      var_grad: A `.NestedMap` of (var, grad) pairs.

    Returns:
      The variable update op.

    Raises:
      Exception: When the regex overlaps with or does not cover all variables.
    """
        # Override inherited GetOptimizer even though learning rate is unused.
        if not self._tf_optimizer_map or not py_utils.IsEagerMode():
            self._tf_optimizer_map = self.GetOptimizer(0)

        var_grad_map = {regex: [] for regex in self._lingvo_optimizer_map}

        for (v, g) in var_grad.Flatten():
            regex_match = 0
            for regex in self._lingvo_optimizer_map:
                if re.match(regex, v.name):
                    var_grad_map[regex].append((g, v))
                    regex_match += 1
            if regex_match == 0:
                var_grad_map['default_optimizer'].append((g, v))
            if regex_match > 1:
                raise Exception(
                    'Variable {} is matched {} times by regex {}'.format(
                        v.name, regex_match,
                        list(self._lingvo_optimizer_map.keys())))

        def _Apply():
            """Use the matched optimizer to apply the gradients."""
            train_ops = []
            non_default_regex = [
                regex for regex in self._lingvo_optimizer_map
                if regex != 'default_optimizer'
            ]
            for regex in self._lingvo_optimizer_map:
                if var_grad_map[regex]:
                    opt = self._tf_optimizer_map[regex]
                    train_ops.append(opt.apply_gradients(var_grad_map[regex]))
                    # pylint: disable=cell-var-from-loop, g-long-lambda
                    if regex == 'default_optimizer':
                        filtered_var_grad = var_grad.FilterKeyVal(
                            lambda k, v: not any([
                                re.match(i, v.var.name)
                                for i in non_default_regex
                            ]))
                    else:
                        filtered_var_grad = var_grad.FilterKeyVal(
                            lambda k, v: (re.match(regex, v.var.name)))
                    # pylint: enable=cell-var-from-loop, g-long-lambda
                    self._lingvo_optimizer_map[regex].AddSummary(
                        self._lr_map[regex], opt, filtered_var_grad)
            return tf.group(*train_ops, name='composite_optimizer_train_op')

        # Many optimizers, e.g., Adam, Adagrad, etc., create
        # variables. We need to ensure name scope and variable scope are
        # cleared. Otherwise, tpu.batch_parallel does not work.
        var_reuse = False
        if py_utils.GetOpportunisticVariableReuse():
            var_reuse = tf.AUTO_REUSE
        with tf.name_scope(None):
            with tf.variable_scope(
                    tf.VariableScope(use_resource=True, reuse=var_reuse)):
                var_update_op = _Apply()
        return var_update_op
Example #26
    def _Loop(self):
        with self._cluster, tf.container(
                self._container_id), contextlib.ExitStack() as stack:
            if py_utils.IsEagerMode():
                sess = None
            else:
                sess = self._GetSession(
                    disable_meta_optimizer=FLAGS.disable_meta_optimizer_in_executor)
                stack.enter_context(sess)
                sess.reset(self._tf_master)
                config_proto = (self._tpu_embedding.config_proto
                                if self._tpu_embedding is not None else None)
                for worker in self._cluster.all_worker_names:
                    sess.run(
                        tf.tpu.initialize_system(embedding_config=config_proto,
                                                 job=worker))

            # Initialize the variables first, if needed.
            # Need to call create global step again because this is run in a thread.
            py_utils.GetOrCreateGlobalStepVar()

            if self._checkpoint_to_load:
                path = self._checkpointer.RestoreFromPath(
                    sess, checkpoint_path=self._checkpoint_to_load)
            else:
                path = self._checkpointer.Restore(sess)

            # Run the compiles in parallel.
            compile_fns = []
            for program in self._programs:
                program.LoadProgramState(path, sess)
                compile_fns += [program.Compile]
            threadpool = multiprocessing.dummy.Pool(len(compile_fns))
            futures = []
            tf.logging.info(
                f'Compiling {len(compile_fns)} programs in parallel.')
            for fn in compile_fns:
                futures += [threadpool.apply_async(fn, args=(sess, ))]
            for future in futures:
                future.get()

            if not py_utils.IsEagerMode():
                sess.run(self._initialize_tables)
                sess.run(self._initialize_local_vars)
                sess.run(self._load_ops)

            program_schedule = None
            # Threadpool to run code in programs async with TF Sessions (on TPUs).
            # This unblocks TPU from waiting for CPU processing on "main" thread, and
            # saves time for time-consuming CPU steps (e.g. PostProcessDecodeOut).
            program_threadpool = multiprocessing.dummy.Pool(1)
            start_time = time.time()
            while True:
                cycle_start_time = time.time()
                if py_utils.IsEagerMode():
                    global_step = py_utils.GetGlobalStep().numpy()
                else:
                    global_step = sess.run(py_utils.GetGlobalStep())

                def RunSave(sess, global_step):
                    # Run TPU embedding retrieve ops.
                    # NOTE: this is expensive, so only run it when we're checkpointing.
                    if not py_utils.IsEagerMode():
                        tf.logging.info('Retrieve params.')
                        sess.run(self._retrieve_ops)
                        tf.logging.info('Retrieve params done.')

                    # Save program state first, so it's recoverable after we restore
                    # from checkpoint.
                    for program in self._programs:
                        program.SaveProgramState(sess, global_step)
                    # Save the checkpoints asynchronously.
                    self._checkpointer.Save(sess, global_step, sync=False)

                checkpoint_write_secs = 0.0
                if not self._ml_perf_log and self._checkpointer.ShouldSave(
                        global_step):
                    checkpoint_write_start = time.time()
                    RunSave(sess, global_step)
                    checkpoint_write_secs = time.time() - checkpoint_write_start

                # If a task is explicitly selected, only run the programs associated
                # with that task.
                if self._single_task_mode or self._model_task_name:
                    tf.logging.info('Single task mode: %s',
                                    self._model_task_name)
                    program_schedule = self._program_schedule_dict[
                        self._model_task_name]
                else:
                    # Otherwise, sample a task.
                    model_task = self.task_scheduler.Sample(global_step)
                    tf.logging.info('Sampled %s', model_task)
                    program_schedule = self._program_schedule_dict[model_task]

                done, train_time_in_secs, eval_time_in_secs = program_schedule.Run(
                    sess, program_threadpool)

                executor_cycle_in_secs = time.time() - cycle_start_time
                self._ExportMetrics(
                    executor_cycle_secs=executor_cycle_in_secs,
                    executor_train_time_secs=train_time_in_secs,
                    executor_eval_time_secs=eval_time_in_secs,
                    checkpoint_write_secs=checkpoint_write_secs,
                )

                def _ShutDown():
                    # Wait for the save ops to finish before exit.
                    self._checkpointer.Sync()
                    program_threadpool.close()
                    program_threadpool.join()
                    tf.logging.info(
                        'Program schedule told us to stop.\n'
                        'Shutting down programs after running %f seconds.',
                        time.time() - start_time)
                    program_schedule.Shutdown()

                if done:
                    tf.logging.info(
                        'Program done after %f seconds. Waiting for threads to end.',
                        time.time() - start_time)
                    _ShutDown()
                    return

                if py_utils.IsEagerMode():
                    global_step = py_utils.GetGlobalStep().numpy()
                else:
                    global_step = sess.run(py_utils.GetGlobalStep())
                if self._ShouldStop(sess, global_step):
                    tf.logging.info('Training finished.')
                    if not self._ml_perf_log:
                        RunSave(sess, global_step)
                    tf.logging.info(
                        'Program finished after %f seconds. Waiting for threads to end.',
                        time.time() - start_time)
                    _ShutDown()
                    return
Example #27
 def Evaler(self):
     return eager_runners.Evaler if py_utils.IsEagerMode() else runners.Evaler
Example #28
 def _GetSession(self, **kwargs):
     if py_utils.IsEagerMode():
         raise ValueError('Eager mode does not support _GetSession.')
     return super()._GetSession(cluster_def=self._worker_cluster_def,
                                **kwargs)
Example #29
 def Decoder(self):
     return eager_runners.Decoder if py_utils.IsEagerMode() else runners.Decoder
Example #30
    def __init__(self, train_cfg, ps_params_dict, *args, **kwargs):
        """Construct an ExecutorTpu BaseRunner.

    Args:
      train_cfg: SingleTaskModelParams or MultiTaskModelParams
      ps_params_dict: A dict of top-level task name -> ProgramSchedule params,
        if train_cfg is a SingleTaskModelParams, we expect only one entry.
      *args: List args to pass through to BaseRunner.
      **kwargs: keyword args to pass through to BaseRunner.
    """
        if py_utils.IsEagerMode():
            assert tf.executing_eagerly()
            tf.logging.info(f'FLAGS.tf_master: {FLAGS.tf_master}')

            # Connect to the TPU runtime.
            resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
                FLAGS.tf_master, job_name=FLAGS.worker_job[len('/job:'):])
            tf.config.experimental_connect_to_cluster(resolver)

        super().__init__(train_cfg, *args, **kwargs)

        data_parallelism = self._cluster.num_splits_per_client
        assert data_parallelism
        num_devices_per_split = self._cluster.num_devices_per_split
        tf.logging.info('data_parallelism: %d, num_devices_per_split: %d',
                        data_parallelism, num_devices_per_split)

        self.task_scheduler = None
        self._checkpoint_dir = os.path.join(self._logdir, 'train')

        self._variable_renaming_rules = []

        self._ml_perf = None

        # If this is a multi-task model, grab the params for the TaskScheduler.
        if issubclass(train_cfg.cls, base_model.SingleTaskModel):
            tf.logging.info('single_task_model')
            assert len(ps_params_dict) == 1
            self._model_task_name = list(ps_params_dict.keys())[0]
            self._single_task_mode = True
        elif issubclass(train_cfg.cls, base_model.MultiTaskModel):
            tf.logging.info('multi_task_model')

            if issubclass(train_cfg.cls,
                          multitask_model.RegExSharedVariableModel):
                self._variable_renaming_rules = train_cfg.variable_renaming_rules

            if train_cfg.task_schedule is None:
                task_schedule_params = task_scheduler.ConstantScheduler.Params()
                task_schedule_params.task_probs = sorted(
                    list(train_cfg.task_probs.IterParams()))
            else:
                task_schedule_params = train_cfg.task_schedule
            self.task_scheduler = task_schedule_params.Instantiate()
            self._single_task_mode = False
        else:
            tf.logging.fatal(
                'Model %s is not a sub-class of SingleTaskModel or MultiTaskModel',
                train_cfg.cls)

        tf.logging.info('train_cfg.cls: %s', train_cfg.cls)

        self._WriteToLog(train_cfg.ToText(), self._checkpoint_dir,
                         'trainer_params.txt')
        self._WriteToLog(
            text_format.MessageToString(train_cfg.ToProto(), as_utf8=True),
            self._checkpoint_dir, 'trainer_params.pbtxt')
        if self._ml_perf is not None:
            self._ml_perf_log = True
            mlp_log.mlperf_print(key='benchmark',
                                 value=self._ml_perf.benchmark_name)
        else:
            self._ml_perf_log = False

        train_cfg = self.params

        @py_utils.RetryOnTransientTfError()
        def _WaitTillInit(job=None):
            """Wait until the model is ready."""
            try:
                if py_utils.IsEagerMode():
                    topology = tf.tpu.experimental.initialize_tpu_system(
                        resolver)
                else:
                    # tpu.initialize_system() is called with None as embedding_config, as
                    # embedding_config is not available yet. Later in _Loop, it is called
                    # with the correct embedding_config. Since it cannot be called twice
                    # in the same graph with different embedding_config, we use a
                    # dummy_graph here.
                    dummy_graph = tf.Graph()
                    with dummy_graph.as_default():
                        tpu_initialize_system_op = tf.tpu.initialize_system(
                            embedding_config=None, job=job)

                    with self._GetSession(graph=dummy_graph) as sess:
                        topology = sess.run(tpu_initialize_system_op)

                if train_cfg.train.tpu_computation_shape is None:
                    computation_shape = py_utils.ComputationShape(
                        num_devices_per_split, topology)
                else:
                    computation_shape = train_cfg.train.tpu_computation_shape
                    assert num_devices_per_split == np.prod(computation_shape)

                if train_cfg.train.tpu_device_order_mode is None:
                    self.device_assignment = device_assignment_lib.device_assignment(
                        topology,
                        computation_shape=computation_shape,
                        num_replicas=data_parallelism)
                else:
                    self.device_assignment = device_assignment_lib.device_assignment(
                        topology,
                        computation_shape=computation_shape,
                        num_replicas=data_parallelism,
                        device_order_mode=train_cfg.train.tpu_device_order_mode
                    )
                py_utils.SetTpuDeviceAssignment(self.device_assignment, job)
                tf.logging.info('device_assignment.core_assignment: %s',
                                str(self.device_assignment.core_assignment))
                tf.logging.info(
                    'device_assignment.topology.device_coordinates: %s',
                    str(self.device_assignment.topology.device_coordinates))
            except py_utils.transient_tf_errors as e:
                tf.logging.info('TPU initialization failed: %s', e)
                raise

        if self._ml_perf_log:
            mlp_log.mlperf_print(key='init_start', value=None)
        if len(self._cluster.all_worker_names) > 1:
            for worker in self._cluster.all_worker_names:
                _WaitTillInit(worker)
        else:
            _WaitTillInit(None)

        shared_model = self._MaybeConstructSharedModel(train_cfg)

        self._program_schedule_dict = {}
        self._programs = []
        self._ckpt_programs = []

        self._checkpoint_to_load = None
        with self._cluster:
            # Create the ExponentialMovingAverage singleton shared by all programs, if
            # applicable.
            ema = py_utils.CreateEMAForModel(train_cfg, self._global_step_var)
            for task_string, program_schedule_params in ps_params_dict.items():
                program_schedule_params.logdir = self._logdir
                program_schedule_params.num_splits_per_client = data_parallelism
                program_schedule_params.task_name = task_string
                # If the model was created above, we'll inject it here as a
                # shared_model.
                ps = program_schedule_params.Instantiate(
                    shared_model=shared_model,
                    trial=self._trial,
                    ema=ema,
                    tf_master=self._tf_master)
                self._program_schedule_dict[task_string] = ps
                tf.logging.info('program_schedule_params: %s',
                                program_schedule_params.ToText())
                self._programs += ps.Programs()
                if ps.train_program:
                    self._ckpt_programs.append(ps.train_program)
                else:
                    self._ckpt_programs += ps.Programs()
                if program_schedule_params.ml_perf.benchmark_name is not None:
                    self._ml_perf = program_schedule_params.ml_perf
                if ('checkpoint_to_load' in program_schedule_params
                        and program_schedule_params.checkpoint_to_load):
                    if (self._checkpoint_to_load
                            and (self._checkpoint_to_load !=
                                 program_schedule_params.checkpoint_to_load)):
                        raise ValueError(
                            f'Multiple values found for checkpoint_to_load: '
                            f'{self._checkpoint_to_load}, '
                            f'{program_schedule_params.checkpoint_to_load}.')
                    self._checkpoint_to_load = program_schedule_params.checkpoint_to_load

        tf.logging.info('num_programs: %d', len(self._programs))

        # When running in a vizier trainer, the executor reports infeasible runs
        # in case of errors. The programs report metrics and normal completions.
        for program in self._programs:
            if program._should_report_metrics:
                self._should_report_metrics = True

        with self._cluster, tf.container(
                self._container_id), contextlib.ExitStack() as stack:
            if not py_utils.IsEagerMode():
                stack.enter_context(self._graph.as_default())

                if FLAGS.use_tpu_mirrored_vars:
                    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
                        FLAGS.tf_master,
                        job_name=FLAGS.worker_job[len('/job:'):])
                    self._tpu_strategy = tf.distribute.experimental.TPUStrategy(
                        resolver, device_assignment=self.device_assignment)
                    stack.enter_context(self._tpu_strategy.scope())
                    stack.enter_context(
                        tpu_strategy._TPUReplicaContext(self._tpu_strategy))
                else:
                    stack.enter_context(tf.device(self._cluster.GetPlacer()))

            if FLAGS.pdb_on_exception:
                stack.enter_context(pdb_wrapper.catch_post_mortem())
            with py_utils.VariableStore(), py_utils.VariableRenameScope(
                    self._variable_renaming_rules):
                # `BuildTpuSubgraph` has to be called before checkpoint restore, so that
                # the optimizer slot variables are guaranteed to be initialized before
                # they get loaded. Otherwise, the optimizers' slot variables will not
                # be properly loaded when V1 checkpoint is used.
                for program in self._programs:
                    program.BuildTpuSubgraph()
                    py_utils.ClearTpuSummaryTensors()

            if not py_utils.IsEagerMode():
                self._initialize_tables = tf.tables_initializer()
                self._initialize_local_vars = tf.local_variables_initializer()
                self._initialize_global_vars = tf.global_variables_initializer()

            checkpointer_models = [
                program.GetModel() for program in self._ckpt_programs
            ]

            if py_utils.IsEagerMode():
                if FLAGS.use_v2_checkpoints_in_eager:
                    self._checkpointer = checkpointer.EagerCheckpointerV2(
                        self._checkpoint_dir,
                        models=checkpointer_models,
                        init_op=None,
                        train_params=train_cfg.train,
                        save_only=False)
                else:
                    self._checkpointer = checkpointer.EagerCheckpointerV1(
                        self._checkpoint_dir,
                        models=checkpointer_models,
                        init_op=None,
                        train_params=train_cfg.train,
                        save_only=False)
            else:
                self._checkpointer = checkpointer.Checkpointer(
                    self._checkpoint_dir,
                    models=checkpointer_models,
                    init_op=self._initialize_global_vars,
                    train_params=train_cfg.train,
                    save_only=False)

            for program in self._programs:
                program.SetStatusMessageFn(self._SetStatusMessage)

            tpu_embedding_collection = (
                tpu_embedding_layers.TpuEmbeddingCollection.Get())
            self._load_ops = tpu_embedding_collection.load_ops
            self._retrieve_ops = tpu_embedding_collection.retrieve_ops
            self._tpu_embedding = tpu_embedding_collection.tpu_embedding