  def generate_send_gradients_op(self, feature_to_gradient_dict):
    """Send gradient to TPU embedding.

    Args:
      feature_to_gradient_dict: dict mapping feature names to gradient wrt
        activations.

    Returns:
      SendTPUEmbeddingGradients Op.

    Raises:
      RuntimeError: If `mode` is not `TRAINING`.
    """
    if self._mode != TRAINING:
      raise RuntimeError('Only in training mode gradients need to '
                         'be sent to TPU embedding; got mode {}.'
                         .format(self._mode))
    gradients = []
    for table in self._table_to_features_dict:
      features = self._table_to_features_dict[table]
      table_gradients = [
          feature_to_gradient_dict[feature] for feature in features
      ]
      concat_table_grads = array_ops.concat(table_gradients, axis=0)
      gradients.append(concat_table_grads)
    return tpu_ops.send_tpu_embedding_gradients(
        inputs=gradients, config=self.config_proto.SerializeToString())
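
A minimal usage sketch of this variant follows; it is not part of the snippet above. `embedding`, `activations`, `loss`, and `apply_grads_op` are hypothetical stand-ins for objects built elsewhere in a TF1-style model, and only generate_send_gradients_op itself comes from the code above: the sketch differentiates the loss with respect to each feature's activation, builds the feature-to-gradient dict, and groups the resulting send op with the dense-parameter update.

# A minimal sketch, not from the original source. It assumes an `embedding`
# object exposing generate_send_gradients_op as defined above, plus
# hypothetical inputs: `activations` (dict of feature name -> activation
# Tensor), a scalar `loss`, and the dense optimizer's `apply_grads_op`.
import tensorflow.compat.v1 as tf


def build_train_op(embedding, activations, loss, apply_grads_op):
  # Gradient of the loss w.r.t. each feature's embedding activations.
  feature_to_gradient = {
      name: tf.gradients(loss, act)[0] for name, act in activations.items()
  }
  # generate_send_gradients_op concatenates these per table and returns the
  # SendTPUEmbeddingGradients op.
  send_op = embedding.generate_send_gradients_op(feature_to_gradient)
  # Run the dense update and the embedding gradient send in the same step.
  return tf.group(apply_grads_op, send_op)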
Example #3
  def generate_send_gradients_op(self, gradient_multipliers=None):
    """Retrieve gradients from collections and send them to TPU embedding.

    Args:
      gradient_multipliers: None, or dict mapping table names to gradient
        multiplier Tensors.

    Returns:
      SendTPUEmbeddingGradients Op.

    Raises:
      ValueError: If required gradients have not been defined.
      RuntimeError: If `mode` is not `TRAINING`.
    """
    if self._mode != TRAINING:
      raise RuntimeError('Only in training mode gradients need to '
                         'be sent to TPU embedding; got mode {}.'
                         .format(self._mode))

    g = ops.get_default_graph()
    gradients = list()
    for table_id, table in enumerate(self._table_to_config_dict):
      table_gradients = g.get_collection(
          'tpu_embedding_gradients_table_%d' % table_id)
      if any(gradient is None for gradient in table_gradients):
        raise ValueError(
            'Table {}/{} has undefined gradients: this is probably because the '
            'model asked TPUEmbedding to compute activations that were not '
            'used.'.format(table_id, table))
      concat_table_grads = array_ops.concat(table_gradients, axis=0)
      if gradient_multipliers is not None:
        concat_table_grads *= gradient_multipliers[table.name]
      gradients.append(concat_table_grads)

    return tpu_ops.send_tpu_embedding_gradients(
        inputs=gradients, config=self.config_proto.SerializeToString())
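
The collection-based variant can be driven the same way once the 'tpu_embedding_gradients_table_<id>' collections have been populated by the lookup code; the sketch below only illustrates the optional gradient_multipliers argument. The `embedding` object, the table name 'video_id_table', and the warm-up schedule are hypothetical, and note that when multipliers are supplied the loop above looks up a multiplier for every configured table.

# A minimal sketch, not from the original source. `embedding` is assumed to
# expose generate_send_gradients_op as defined above, and 'video_id_table' is
# a hypothetical table name used only for illustration.
import tensorflow.compat.v1 as tf


def build_send_gradients_op(embedding, global_step):
  # Linearly warm up the embedding updates over the first 1000 steps.
  warmup = tf.minimum(tf.cast(global_step, tf.float32) / 1000.0, 1.0)
  # Keys are table names (see the docstring above); the multiplier is
  # broadcast against that table's concatenated gradient.
  gradient_multipliers = {'video_id_table': warmup}
  return embedding.generate_send_gradients_op(
      gradient_multipliers=gradient_multipliers)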