Example 1
def logistic_regression_signature_fn(examples, unused_features, predictions):
  """Creates logistic regression signature from given examples and predictions.

  Args:
    examples: `Tensor`.
    unused_features: `dict` of `Tensor`s.
    predictions: `Tensor` of shape [batch_size, 2] of predicted probabilities.

  Returns:
    Tuple of default regression signature and named signature.
  """
  # predictions should have shape [batch_size, 2] where first column is P(Y=0|x)
  # while second column is P(Y=1|x). We are only interested in the second
  # column for inference.
  predictions_shape = predictions.get_shape()
  predictions_rank = len(predictions_shape)
  if predictions_rank != 2:
    logging.fatal(
        'Expected predictions to have rank 2, but received predictions with '
        'rank: {} and shape: {}'.format(predictions_rank, predictions_shape))
  if predictions_shape[1] != 2:
    logging.fatal(
        'Expected predictions to have 2nd dimension: 2, but received '
        'predictions with 2nd dimension: {} and shape: {}. Did you mean to use '
        'regression_signature_fn instead?'.format(predictions_shape[1],
                                                  predictions_shape))

  positive_predictions = predictions[:, 1]
  signatures = {}
  signatures['regression'] = exporter.regression_signature(examples,
                                                           positive_predictions)
  return signatures['regression'], signatures
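For context, a minimal sketch of the column slice this signature is built from, using a hypothetical [batch_size, 2] probabilities tensor (TF 1.x-style API; the names below are illustrative, not from the original module):

import tensorflow as tf

# Hypothetical probabilities: column 0 holds P(Y=0|x), column 1 holds P(Y=1|x).
probabilities = tf.constant([[0.9, 0.1],
                             [0.2, 0.8]])
positive_predictions = probabilities[:, 1]  # the tensor passed to the regression signature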
Example 2
def _create_optimizer(hyperparams, learning_rate_var, step_var=None):
  """Creates an optimizer object for a given spec, learning rate and step var.

  Args:
    hyperparams: a GridPoint proto containing optimizer spec, particularly
      learning_method to determine optimizer class to use.
    learning_rate_var: a `tf.Tensor`, the learning rate.
    step_var: a `tf.Variable`, global training step.

  Returns:
    a `tf.train.Optimizer` object that was built.
  """
  if hyperparams.learning_method == 'gradient_descent':
    return tf.train.GradientDescentOptimizer(
        learning_rate_var, use_locking=True)
  elif hyperparams.learning_method == 'adam':
    return tf.train.AdamOptimizer(
        learning_rate_var,
        beta1=hyperparams.adam_beta1,
        beta2=hyperparams.adam_beta2,
        epsilon=hyperparams.adam_eps,
        use_locking=True)
  elif hyperparams.learning_method == 'lazyadam':
    return tf.contrib.opt.LazyAdamOptimizer(
        learning_rate_var,
        beta1=hyperparams.adam_beta1,
        beta2=hyperparams.adam_beta2,
        epsilon=hyperparams.adam_eps,
        use_locking=True)
  elif hyperparams.learning_method == 'momentum':
    return tf.train.MomentumOptimizer(
        learning_rate_var, hyperparams.momentum, use_locking=True)
  else:
    logging.fatal('Unknown learning method: %s', hyperparams.learning_method)
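A hedged usage sketch, assuming a stand-in object exposing the same fields the GridPoint proto is read for above (TF 1.x optimizers; the class below is hypothetical):

import tensorflow as tf

class _FakeHyperparams(object):  # hypothetical stand-in for the GridPoint proto
  learning_method = 'adam'
  adam_beta1 = 0.9
  adam_beta2 = 0.999
  adam_eps = 1e-8

learning_rate = tf.constant(0.001)
optimizer = _create_optimizer(_FakeHyperparams(), learning_rate)
# optimizer is a tf.train.AdamOptimizer built with use_locking=True.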
Example 3
def _create_optimizer(hyperparams, learning_rate_var, step_var=None):
    """Creates an optimizer object for a given spec, learning rate and step var.

  Args:
    hyperparams: a GridPoint proto containing optimizer spec, particularly
      learning_method to determine optimizer class to use.
    learning_rate_var: a `tf.Tensor`, the learning rate.
    step_var: a `tf.Variable`, global training step.

  Returns:
    a `tf.train.Optimizer` object that was built.
  """
    if hyperparams.learning_method == 'gradient_descent':
        return tf.train.GradientDescentOptimizer(learning_rate_var,
                                                 use_locking=True)
    elif hyperparams.learning_method == 'adam':
        return tf.train.AdamOptimizer(learning_rate_var,
                                      beta1=hyperparams.adam_beta1,
                                      beta2=hyperparams.adam_beta2,
                                      epsilon=hyperparams.adam_eps,
                                      use_locking=True)
    elif hyperparams.learning_method == 'lazyadam':
        return tf.contrib.opt.LazyAdamOptimizer(learning_rate_var,
                                                beta1=hyperparams.adam_beta1,
                                                beta2=hyperparams.adam_beta2,
                                                epsilon=hyperparams.adam_eps,
                                                use_locking=True)
    elif hyperparams.learning_method == 'momentum':
        return tf.train.MomentumOptimizer(learning_rate_var,
                                          hyperparams.momentum,
                                          use_locking=True)
    else:
        logging.fatal('Unknown learning method: %s',
                      hyperparams.learning_method)
Example 4
def _heartbeat(
    period: int,  # in seconds
    timer: threading.Event,
    token: int,
    num_tasks: int,
    task_id: int,
    device: tf_device.DeviceSpec,
):
    """Periodically sends and receives a heartbeat signal."""
    logging.info('Starting a heartbeat thread')
    global _failure_count
    while True:
        # `timer.wait` blocks until one of two things happens.
        # It returns True if the timer is explicitly set at process exit, and we
        # should gracefully end this heartbeat thread.
        # Otherwise, it returns False when `period` has elapsed, meaning it's time
        # for the next heartbeat exchange.
        # See https://docs.python.org/3/library/threading.html#threading.Event.wait.
        if timer.wait(period):
            logging.info('Exiting the heartbeat thread normally')
            return

        # Every worker fills in one element of the signal with `token`.
        signal = np.zeros([num_tasks], dtype=np.int32)
        signal[task_id] = token

        logging.vlog(2, 'Sending heartbeat signal %s', signal)
        try:
            with ops.device(device):
                # Always use 0 for group and instance keys to reduce unnecessary
                # collective hangs and simplify failure analysis. This also avoids
                # collision with normal collectives.
                signal = all_reduce(constant_op.constant(signal),
                                    group_size=num_tasks,
                                    group_key=0,
                                    instance_key=0,
                                    timeout=max(period - 10, 2)).numpy()
        except Exception as e:  # pylint: disable=broad-except
            _failure_count += 1
            if _failure_count < _CONSECUTIVE_FAILURES_LIMIT:
                logging.warning(
                    'Heartbeat failure %d, %d more until limit: %s',
                    _failure_count,
                    _CONSECUTIVE_FAILURES_LIMIT - _failure_count, e)
            else:
                logging.fatal('Heartbeat failure %d, limit of %d reached: %s',
                              _failure_count, _CONSECUTIVE_FAILURES_LIMIT, e)
        logging.vlog(2, 'Received heartbeat signal %s', signal)

        # Out-of-sync workers will cause this; crash immediately.
        if not np.all(signal == token):
            logging.fatal('Unexpected heartbeat signal received: %s', signal)

        # Any success resets the failure counter.
        _failure_count = 0
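The graceful shutdown relies on threading.Event.wait returning True once the event is set and False when the timeout elapses; a self-contained sketch of that pattern (values are illustrative):

import threading
import time

timer = threading.Event()

def _loop(period):
  while True:
    if timer.wait(period):  # True only after timer.set(): exit gracefully
      print('exiting the loop')
      return
    print('period elapsed, do the next heartbeat exchange')

thread = threading.Thread(target=_loop, args=[1], daemon=True)
thread.start()
time.sleep(3)
timer.set()  # ask the loop to stop
thread.join()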
Example 5
def _create_optimizer(hyperparams, learning_rate_var, step_var=None):
    """Creates an optimizer object for a given spec, learning rate and step var.

  Args:
    hyperparams: a GridPoint proto containing optimizer spec, particularly
      learning_method to determine optimizer class to use.
    learning_rate_var: a `tf.Tensor`, the learning rate.
    step_var: a `tf.Variable`, global training step.

  Returns:
    a `tf.train.Optimizer` object that was built.
  """
    if hyperparams.learning_method == 'gradient_descent':
        return tf.train.GradientDescentOptimizer(learning_rate_var,
                                                 use_locking=True)
    elif hyperparams.learning_method == 'adam':
        return tf.train.AdamOptimizer(learning_rate_var,
                                      beta1=hyperparams.adam_beta1,
                                      beta2=hyperparams.adam_beta2,
                                      epsilon=hyperparams.adam_eps,
                                      use_locking=True)
    elif hyperparams.learning_method == 'lazyadam':
        return tf.contrib.opt.LazyAdamOptimizer(learning_rate_var,
                                                beta1=hyperparams.adam_beta1,
                                                beta2=hyperparams.adam_beta2,
                                                epsilon=hyperparams.adam_eps,
                                                use_locking=True)
    elif hyperparams.learning_method == 'momentum':
        return tf.train.MomentumOptimizer(learning_rate_var,
                                          hyperparams.momentum,
                                          use_locking=True)
    elif hyperparams.learning_method == 'composite':
        spec = hyperparams.composite_optimizer_spec
        optimizer1 = _create_optimizer(spec.method1, learning_rate_var,
                                       step_var)
        optimizer2 = _create_optimizer(spec.method2, learning_rate_var,
                                       step_var)
        if step_var is None:
            logging.fatal('step_var is required for CompositeOptimizer')
        switch = tf.less(step_var, spec.switch_after_steps)
        return composite_optimizer.CompositeOptimizer(optimizer1,
                                                      optimizer2,
                                                      switch,
                                                      use_locking=True)
    else:
        logging.fatal('Unknown learning method (optimizer)')
Example 6
    def get_layer_size(self, layer_name):
        if layer_name == 'logits':
            return self._component.num_actions

        if layer_name == 'last_layer':
            return self._hidden_layer_sizes[-1]

        if not layer_name.startswith('layer_'):
            logging.fatal(
                'Invalid layer name: "%s". Can only retrieve from "logits", '
                '"last_layer", and "layer_*".', layer_name)

        # NOTE(danielandor): Since get_layer_size is called before the
        # model has been built, we compute the layer size directly from
        # the hyperparameters rather than from self._layers.
        layer_index = int(layer_name.split('_')[1])
        return self._hidden_layer_sizes[layer_index]
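A small sketch of the name-based lookup, with a hypothetical list of hidden layer sizes:

hidden_layer_sizes = [512, 256, 128]  # hypothetical hyperparameter values

layer_name = 'layer_1'
layer_index = int(layer_name.split('_')[1])
assert hidden_layer_sizes[layer_index] == 256   # 'layer_1'
assert hidden_layer_sizes[-1] == 128            # what 'last_layer' resolves to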
Example 7
def logistic_regression_signature_fn(examples, unused_features, predictions):
    """Creates logistic regression signature from given examples and predictions.

  Args:
    examples: `Tensor`.
    unused_features: `dict` of `Tensor`s.
    predictions: `Tensor` of shape [batch_size, 2] of predicted probabilities or
      dict that contains the probabilities tensor as in
      {'probabilities': `Tensor`}.

  Returns:
    Tuple of default regression signature and named signature.

  Raises:
    ValueError: If examples is `None`.
  """
    if examples is None:
        raise ValueError(
            'examples cannot be None when using this signature fn.')

    if isinstance(predictions, dict):
        predictions_tensor = predictions['probabilities']
    else:
        predictions_tensor = predictions
    # predictions should have shape [batch_size, 2] where first column is P(Y=0|x)
    # while second column is P(Y=1|x). We are only interested in the second
    # column for inference.
    predictions_shape = predictions_tensor.get_shape()
    predictions_rank = len(predictions_shape)
    if predictions_rank != 2:
        logging.fatal(
            'Expected predictions to have rank 2, but received predictions with '
            'rank: {} and shape: {}'.format(predictions_rank,
                                            predictions_shape))
    if predictions_shape[1] != 2:
        logging.fatal(
            'Expected predictions to have 2nd dimension: 2, but received '
            'predictions with 2nd dimension: {} and shape: {}. Did you mean to use '
            'regression_signature_fn or classification_signature_fn_with_prob '
            'instead?'.format(predictions_shape[1], predictions_shape))

    positive_predictions = predictions_tensor[:, 1]
    default_signature = exporter.regression_signature(
        input_tensor=examples, output_tensor=positive_predictions)
    return default_signature, {}
Example 8
def logistic_regression_signature_fn(examples, unused_features, predictions):
    """Creates logistic regression signature from given examples and predictions.

  Args:
    examples: `Tensor`.
    unused_features: `dict` of `Tensor`s.
    predictions: `Tensor` of shape [batch_size, 2] of predicted probabilities or
      dict that contains the probabilities tensor as in
      {'probabilities': `Tensor`}.

  Returns:
    Tuple of default regression signature and named signature.

  Raises:
    ValueError: If examples is `None`.
  """
    if examples is None:
        raise ValueError("examples cannot be None when using this signature fn.")

    if isinstance(predictions, dict):
        predictions_tensor = predictions["probabilities"]
    else:
        predictions_tensor = predictions
    # predictions should have shape [batch_size, 2] where first column is P(Y=0|x)
    # while second column is P(Y=1|x). We are only interested in the second
    # column for inference.
    predictions_shape = predictions_tensor.get_shape()
    predictions_rank = len(predictions_shape)
    if predictions_rank != 2:
        logging.fatal(
            "Expected predictions to have rank 2, but received predictions with "
            "rank: {} and shape: {}".format(predictions_rank, predictions_shape)
        )
    if predictions_shape[1] != 2:
        logging.fatal(
            "Expected predictions to have 2nd dimension: 2, but received "
            "predictions with 2nd dimension: {} and shape: {}. Did you mean to use "
            "regression_signature_fn or classification_signature_fn_with_prob "
            "instead?".format(predictions_shape[1], predictions_shape)
        )

    positive_predictions = predictions_tensor[:, 1]
    default_signature = exporter.regression_signature(input_tensor=examples, output_tensor=positive_predictions)
    return default_signature, {}
Example 9
def _create_optimizer(hyperparams, learning_rate_var, step_var=None):
  """Creates an optimizer object for a given spec, learning rate and step var.

  Args:
    hyperparams: a GridPoint proto containing optimizer spec, particularly
      learning_method to determine optimizer class to use.
    learning_rate_var: a `tf.Tensor`, the learning rate.
    step_var: a `tf.Variable`, global training step.

  Returns:
    a `tf.train.Optimizer` object that was built.
  """
  if hyperparams.learning_method == 'gradient_descent':
    return tf.train.GradientDescentOptimizer(
        learning_rate_var, use_locking=True)
  elif hyperparams.learning_method == 'adam':
    return tf.train.AdamOptimizer(
        learning_rate_var,
        beta1=hyperparams.adam_beta1,
        beta2=hyperparams.adam_beta2,
        epsilon=hyperparams.adam_eps,
        use_locking=True)
  elif hyperparams.learning_method == 'lazyadam':
    return tf.contrib.opt.LazyAdamOptimizer(
        learning_rate_var,
        beta1=hyperparams.adam_beta1,
        beta2=hyperparams.adam_beta2,
        epsilon=hyperparams.adam_eps,
        use_locking=True)
  elif hyperparams.learning_method == 'momentum':
    return tf.train.MomentumOptimizer(
        learning_rate_var, hyperparams.momentum, use_locking=True)
  elif hyperparams.learning_method == 'composite':
    spec = hyperparams.composite_optimizer_spec
    optimizer1 = _create_optimizer(spec.method1, learning_rate_var, step_var)
    optimizer2 = _create_optimizer(spec.method2, learning_rate_var, step_var)
    if step_var is None:
      logging.fatal('step_var is required for CompositeOptimizer')
    switch = tf.less(step_var, spec.switch_after_steps)
    return composite_optimizer.CompositeOptimizer(
        optimizer1, optimizer2, switch, use_locking=True)
  else:
    logging.fatal('Unknown learning method (optimizer)')
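For the 'composite' branch, `switch` is a boolean tensor that stays true while the global step is below `switch_after_steps`, presumably telling CompositeOptimizer which of the two optimizers to apply; a hedged sketch of that predicate alone (TF 1.x, hypothetical threshold):

import tensorflow as tf

step_var = tf.Variable(0, trainable=False, name='global_step')
switch_after_steps = 10000                       # hypothetical value from the spec
switch = tf.less(step_var, switch_after_steps)   # scalar boolean tensor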
Example 10
def restore_model(checkpoint_paths,
                  variables_to_restore,
                  ignore_missing_vars=False,
                  num_streams=1,
                  checkpoint_style=None,
                  special_assign_vars=None):
    all_ops = []
    if len(checkpoint_paths) == 1 and num_streams > 1:
      logging.info('Provided one checkpoint for multi-stream '
                   'network. Will use this as a saved model '
                   'with this exact multi stream network.')
      all_ops.append(slim.assign_from_checkpoint_fn(
        checkpoint_paths[0],
        variables_to_restore,
        ignore_missing_vars=ignore_missing_vars))
    else:
      for sid in range(num_streams):
        this_checkpoint_style = checkpoint_style.split(',')[sid] if \
                                checkpoint_style is not None else None
        checkpoint_path = checkpoint_paths[sid]
        # assert tf.gfile.Exists(checkpoint_path)
        this_stream_name = 'stream%d/' % sid
        this_checkpoint_variables = [var for var in variables_to_restore
                                     if var in
                                     slim.get_model_variables(this_stream_name)]
        if checkpoint_path.endswith('.npy'):
          vars_to_restore_names = [
              el.name for el in this_checkpoint_variables]
          key_name_mapper = var_name_mapper.map()
          init_weights = np.load(checkpoint_path).item()
          init_weights_final = {}
          vars_restored = []
          for key in init_weights.keys():
            for subkey in init_weights[key].keys():
              prefix = this_stream_name
              if this_checkpoint_style == 'v2_withStream':
                prefix = 'stream0/'  # because any model trained with stream
                                     # will have that stream as 0
              final_key_name = prefix + key_name_mapper(
                  key + '/' + subkey)
              if final_key_name not in vars_to_restore_names:
                logging.error('Not using %s from npy' % final_key_name)
                continue
              
              target_shape = slim.get_model_variables(
                final_key_name)[0].get_shape().as_list()
              pretrained_wts = init_weights[key][subkey]
              target_shape_squeezed = np.delete(
                target_shape, np.where(np.array(target_shape) == 1))
              pretrained_shape_squeezed = np.delete(
                pretrained_wts.shape, np.where(np.array(pretrained_wts.shape) == 1))
              if np.any(target_shape_squeezed !=
                        pretrained_shape_squeezed):
                logging.error('Shape mismatch var: %s from npy [%s vs %s]' 
                              % (final_key_name, target_shape,
                                 pretrained_wts.shape))

              init_weights_final[final_key_name] = \
                  pretrained_wts
              vars_restored.append(final_key_name)
          init_weights = init_weights_final
          for v in vars_to_restore_names:
            if v not in vars_restored:
              logging.fatal('No weights found for %s' % v)
          all_ops.append(slim.assign_from_values_fn(
              init_weights))
        else:
          if this_checkpoint_style != 'v2_withStream':
            all_ops.append(slim.assign_from_checkpoint_fn(
                checkpoint_path,
                # stripping the stream name to map variables
                dict(
                  [('/'.join(el.name.split('/')[1:]).split(':')[0], el) for
                      el in this_checkpoint_variables]),
                ignore_missing_vars=ignore_missing_vars))
          else:
            all_ops.append(slim.assign_from_checkpoint_fn(
                checkpoint_path,
                # stripping the stream name to map variables, to stream0,
                # as the model is v2_withStream, hence must be trained with
                # stream0/ prefix
                dict(
                  [('/'.join(['stream0'] + el.name.split('/')[1:]).split(':')[0], el) for
                      el in this_checkpoint_variables]),
                ignore_missing_vars=ignore_missing_vars))
    if special_assign_vars is not None:
      all_ops.append(get_special_assigns(special_assign_vars))
    def combined(sess):
      for op in all_ops:
        op(sess)
    return combined
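The checkpoint branches build their variable maps by stripping the leading streamN/ scope and the :0 output suffix from each variable name; a standalone sketch of that string manipulation (the variable name is illustrative):

name = 'stream1/conv1/weights:0'  # hypothetical variable name

stripped = '/'.join(name.split('/')[1:]).split(':')[0]
assert stripped == 'conv1/weights'

remapped = '/'.join(['stream0'] + name.split('/')[1:]).split(':')[0]
assert remapped == 'stream0/conv1/weights'  # the v2_withStream mapping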
Example 11
def restore_model(checkpoint_path,
                  variables_to_restore,
                  ignore_missing_vars=False,
                  var_name_mapper_type=None):
  all_ops = []
  checkpoint_variables = variables_to_restore
  if checkpoint_path.endswith('.npy'):
    vars_to_restore_names = [
      el.name for el in checkpoint_variables]
    key_name_mapper = var_name_mapper.map(var_name_mapper_type)
    init_weights = np.load(checkpoint_path).item()
    init_weights_final = {}
    vars_restored = []
    for key in init_weights.keys():
      for subkey in init_weights[key].keys():
        final_key_name = key_name_mapper(
          key + '/' + subkey)
        if final_key_name not in vars_to_restore_names:
          logging.info('Not using %s from npy' % final_key_name)
          continue
        target_shape = slim.get_model_variables(
          final_key_name)[0].get_shape().as_list()
        pretrained_wts = init_weights[key][subkey].copy()
        target_shape_squeezed = np.delete(
          target_shape, np.where(np.array(target_shape) == 1))
        pretrained_shape_squeezed = np.delete(
          pretrained_wts.shape, np.where(np.array(pretrained_wts.shape) == 1))

        go_ahead = False  # whether or not I'll be able to copy these weights
        if np.any(target_shape_squeezed !=
                  pretrained_shape_squeezed):
          logging.info('Shape mismatch var: %s from npy [%s vs %s]. ' % (
                       final_key_name, target_shape,
                       pretrained_wts.shape))
          if pretrained_shape_squeezed[-2] != target_shape_squeezed[-2]:
            logging.info('Trying repeating channels to make it similar.')
            pretrained_wts = np.repeat(
              np.mean(pretrained_wts, axis=-2, keepdims=True),
              repeats=target_shape_squeezed[-2],
              axis=-2)
            if np.all(target_shape_squeezed == pretrained_wts.shape):
              logging.info('Success! Copying the hacked weights.')
              go_ahead = True
            else:
              logging.info('Still could not match the weights.')
        else:
          go_ahead = True
        if go_ahead:
          init_weights_final[final_key_name] = \
            pretrained_wts
          vars_restored.append(final_key_name)
    init_weights = init_weights_final
    for v in vars_to_restore_names:
      if v not in vars_restored:
        logging.fatal('No weights found for %s' % v)
        if not ignore_missing_vars:
          raise ValueError()
    all_ops.append(slim.assign_from_values_fn(init_weights))
  else:
    all_ops.append(assign_from_checkpoint_fn(
      checkpoint_path, checkpoint_variables,
      ignore_missing_vars=ignore_missing_vars,
      resize_variables=True))
  def combined(sess):
    for op in all_ops:
      op(sess)
  return combined
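The shape comparison above drops singleton dimensions before comparing the checkpoint weights against the model variable; a small sketch of that squeeze with hypothetical shapes:

import numpy as np

target_shape = [1, 1, 512, 10]  # hypothetical model variable shape
pretrained_shape = (512, 10)    # hypothetical .npy weight shape

target_squeezed = np.delete(
    target_shape, np.where(np.array(target_shape) == 1))
pretrained_squeezed = np.delete(
    pretrained_shape, np.where(np.array(pretrained_shape) == 1))
assert np.all(target_squeezed == pretrained_squeezed)  # [512, 10] on both sides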
Example 12
def restore_model(checkpoint_paths,
                  variables_to_restore,
                  ignore_missing_vars=False,
                  num_streams=1,
                  checkpoint_style=None,
                  special_assign_vars=None):
    all_ops = []
    if len(checkpoint_paths) == 1 and num_streams > 1:
        logging.info('Provided one checkpoint for multi-stream '
                     'network. Will use this as a saved model '
                     'with this exact multi stream network.')
        all_ops.append(
            slim.assign_from_checkpoint_fn(
                checkpoint_paths[0],
                variables_to_restore,
                ignore_missing_vars=ignore_missing_vars))
    else:
        for sid in range(num_streams):
            this_checkpoint_style = checkpoint_style.split(',')[sid] if \
                                    checkpoint_style is not None else None
            checkpoint_path = checkpoint_paths[sid]
            # assert tf.gfile.Exists(checkpoint_path)
            this_stream_name = 'stream%d/' % sid
            this_checkpoint_variables = [
                var for var in variables_to_restore
                if var in slim.get_model_variables(this_stream_name)
            ]
            if checkpoint_path.endswith('.npy'):
                vars_to_restore_names = [
                    el.name for el in this_checkpoint_variables
                ]
                key_name_mapper = var_name_mapper.map()
                init_weights = np.load(checkpoint_path).item()
                init_weights_final = {}
                vars_restored = []
                for key in init_weights.keys():
                    for subkey in init_weights[key].keys():
                        prefix = this_stream_name
                        if this_checkpoint_style == 'v2_withStream':
                            prefix = 'stream0/'  # because any model trained with stream
                            # will have that stream as 0
                        final_key_name = prefix + key_name_mapper(key + '/' +
                                                                  subkey)
                        if final_key_name not in vars_to_restore_names:
                            logging.error('Not using %s from npy' %
                                          final_key_name)
                            continue

                        target_shape = slim.get_model_variables(
                            final_key_name)[0].get_shape().as_list()
                        pretrained_wts = init_weights[key][subkey]
                        target_shape_squeezed = np.delete(
                            target_shape,
                            np.where(np.array(target_shape) == 1))
                        pretrained_shape_squeezed = np.delete(
                            pretrained_wts.shape,
                            np.where(np.array(pretrained_wts.shape) == 1))
                        if np.any(target_shape_squeezed !=
                                  pretrained_shape_squeezed):
                            logging.error(
                                'Shape mismatch var: %s from npy [%s vs %s]' %
                                (final_key_name, target_shape,
                                 pretrained_wts.shape))

                        init_weights_final[final_key_name] = \
                            pretrained_wts
                        vars_restored.append(final_key_name)
                init_weights = init_weights_final
                for v in vars_to_restore_names:
                    if v not in vars_restored:
                        logging.fatal('No weights found for %s' % v)
                all_ops.append(slim.assign_from_values_fn(init_weights))
            else:
                if this_checkpoint_style != 'v2_withStream':
                    all_ops.append(
                        slim.assign_from_checkpoint_fn(
                            checkpoint_path,
                            # stripping the stream name to map variables
                            dict([('/'.join(
                                el.name.split('/')[1:]).split(':')[0], el)
                                  for el in this_checkpoint_variables]),
                            ignore_missing_vars=ignore_missing_vars))
                else:
                    all_ops.append(
                        slim.assign_from_checkpoint_fn(
                            checkpoint_path,
                            # stripping the stream name to map variables, to stream0,
                            # as the model is v2_withStream, hence must be trained with
                            # stream0/ prefix
                            dict([(
                                '/'.join(['stream0'] +
                                         el.name.split('/')[1:]).split(':')[0],
                                el) for el in this_checkpoint_variables]),
                            ignore_missing_vars=ignore_missing_vars))
    if special_assign_vars is not None:
        all_ops.append(get_special_assigns(special_assign_vars))

    def combined(sess):
        for op in all_ops:
            op(sess)

    return combined
Example 13
def start(period: int) -> threading.Event:
    """Starts a persistent thread exchanging heartbeats between workers.

  Args:
    period: Heartbeat interval in seconds. Heartbeat timeout is set to the
      larger of `period` - 10 and 2s.

  Returns:
    A threading.Event object. Users can choose to call its set() method to shut
    down the heartbeat service gracefully. This isn't necessary in most cases,
    because the heartbeat service automatically shuts down at successful program
    exit through atexit handlers. But in situations when atexit handlers are not
    invoked, such as when multiprocessing processes exit in tests, users can
    manually request a shutdown.
  """
    global _heartbeat_timer
    if _heartbeat_timer is not None:
        logging.warning(
            'A heartbeat thread is already running, skipping this one.')
        return _heartbeat_timer

    task_id = api.client_id()
    num_tasks = api.num_clients()

    # Worker 0 generates a random token. All other workers receive that token.
    if task_id == 0:
        token = np.random.randint(0,
                                  pow(2, 16) - 1)  # reserve the other 16 bits
        signal = np.full([num_tasks], token, dtype=np.int32)
    else:
        signal = np.zeros([num_tasks], dtype=np.int32)
    logging.info('Initial heartbeat signal: %s', signal)

    device = tf_device.DeviceSpec(job=api.job_name(),
                                  replica=0,
                                  task=task_id,
                                  device_type='CPU',
                                  device_index=0)
    # Always use 0 for group and instance keys to reduce unnecessary
    # collective hangs and simplify failure analysis. This also avoids
    # collision with normal collectives.
    with ops.device(device):
        signal = all_reduce(constant_op.constant(signal),
                            group_size=num_tasks,
                            group_key=0,
                            instance_key=0,
                            timeout=max(period - 10, 2)).numpy()
    logging.info('Merged heartbeat signal %s', signal)

    # The merged signal should have equal elements. If not, some worker(s) may be
    # out of sync, and we should terminate all workers.
    if task_id == 0:
        if not np.all(signal == token):
            logging.fatal('Merged heartbeat signal has value != %d', token)
    else:
        if len(set(signal)) != 1:
            logging.fatal('Merged heartbeat signal has unequal elements')
        token = signal[0]

    # On normal main process exit, set the timer to stop the heartbeat thread.
    _heartbeat_timer = threading.Event()

    def stop_heartbeat():
        logging.info('Stopping the heartbeat thread')
        _heartbeat_timer.set()
        # Give the threads some time to clean up.
        time.sleep(max(period // 10, 2))

    atexit.register(stop_heartbeat)

    # Start the persistent heartbeat thread.
    thread = threading.Thread(
        target=_heartbeat,
        args=[period, _heartbeat_timer, token, num_tasks, task_id, device],
        daemon=True)
    thread.start()

    return _heartbeat_timer
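A small sketch of the token bootstrap at the top of `start`, with a hypothetical worker count (worker 0 fills the whole signal with a random 16-bit token, the other workers contribute zeros, so the summing all-reduce yields an array whose elements all equal the token):

import numpy as np

num_tasks, task_id = 4, 0  # hypothetical cluster layout
if task_id == 0:
  token = np.random.randint(0, pow(2, 16) - 1)
  signal = np.full([num_tasks], token, dtype=np.int32)
else:
  signal = np.zeros([num_tasks], dtype=np.int32)
# After the collective sum across all workers, every element of the merged
# signal equals `token`, which is what the equality checks above verify.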