Esempio n. 1
0
    def generate_send_gradients_op(self, feature_to_gradient_dict):
        """Build the op that sends per-table gradients to TPU embedding.

        Tables are visited in the iteration order of
        ``self._table_to_features_dict``. For each table, the gradient of
        every feature mapped to it is looked up in
        ``feature_to_gradient_dict``. Those gradients are concatenated along
        axis 0, giving one tensor per table, in table order.

        Args:
          feature_to_gradient_dict: dict mapping feature names to gradient wrt
            activations.

        Returns:
          SendTPUEmbeddingGradients Op.

        Raises:
          RuntimeError: If `mode` is not `TRAINING`.
          KeyError: If a feature assigned to some table has no entry in
            `feature_to_gradient_dict`.
        """
        # Gradients are only meaningful while training; fail fast otherwise.
        if self._mode != TRAINING:
            message = ('Only in training mode gradients need to '
                       'be sent to TPU embedding; got mode {}.'.format(
                           self._mode))
            raise RuntimeError(message)

        tables = self._table_to_features_dict
        # One tensor per table: the table's feature gradients joined
        # end-to-end along axis 0 (first feature's rows first, and so on).
        per_table_gradients = [
            array_ops.concat(
                [feature_to_gradient_dict[name] for name in tables[table]],
                axis=0)
            for table in tables
        ]
        serialized_config = self.config_proto.SerializeToString()
        return tpu_ops.send_tpu_embedding_gradients(
            inputs=per_table_gradients, config=serialized_config)
Esempio n. 2
0
  def generate_send_gradients_op(self, feature_to_gradient_dict):
    """Build the op sending interleaved per-table gradients to TPU embedding.

    Tables are visited in the iteration order of
    ``self._table_to_features_dict``. For each table, the gradients of its
    features are looked up in ``feature_to_gradient_dict``. They are stacked
    along a new axis 1 and reshaped to ``[-1, d]``, where ``d`` is the
    static size of axis 1 of the table's first feature gradient. One such
    tensor per table is handed to the send op.

    Args:
      feature_to_gradient_dict: dict mapping feature names to gradient wrt
        activations.

    Returns:
      SendTPUEmbeddingGradients Op.

    Raises:
      RuntimeError: If `mode` is not `TRAINING`.
      KeyError: If a feature assigned to some table has no entry in
        `feature_to_gradient_dict`.
    """
    # Gradients are only meaningful while training; fail fast otherwise.
    if self._mode != TRAINING:
      error_message = ('Only in training mode gradients need to '
                       'be sent to TPU embedding; got mode {}.'
                       .format(self._mode))
      raise RuntimeError(error_message)

    def _interleave(grads):
      # Assumes each gradient is 2-D [rows, dim] with identical shapes
      # (TODO confirm against caller). Stacking on axis 1 gives
      # [rows, num_features, dim]. Flattening the leading two axes then
      # orders rows as: row 0 of feature 0, row 0 of feature 1, ...,
      # row 1 of feature 0, ...
      stacked = array_ops.stack(grads, axis=1)
      return array_ops.reshape(stacked, [-1, grads[0].shape[1]])

    tables = self._table_to_features_dict
    per_table = []
    for table in tables:
      grads = [feature_to_gradient_dict[name] for name in tables[table]]
      per_table.append(_interleave(grads))
    return tpu_ops.send_tpu_embedding_gradients(
        inputs=per_table, config=self.config_proto.SerializeToString())