Ejemplo n.º 1
0
  def get_activations(self):
    """Get activations for features.

    This should be called within `computation` that is passed to
      `tpu.replicate` and friends.

    Receives `len(self._table_to_config_dict)` activation outputs and, for
    each table, slices that table's output row-wise among the features that
    look it up: the feature at position `lookup_id` in the table's feature
    list gets rows `lookup_id, lookup_id + stride, ...`, where `stride` is
    the number of features for that table.

    Returns:
      A dictionary mapping from `String` of feature name to `Tensor`
        of activation. Entries are inserted table by table (in the iteration
        order of `self._table_to_features_dict`), then by feature position
        within the table.
    """
    recv_activations = tpu_ops.recv_tpu_embedding_activations(
        num_outputs=len(self._table_to_config_dict),
        config=self._config_proto.SerializeToString())

    activations = collections.OrderedDict()
    # NOTE(review): `table_id` indexes outputs whose count comes from
    # `_table_to_config_dict`, while tables are enumerated from
    # `_table_to_features_dict`; this assumes both hold the same tables in the
    # same order -- confirm.
    for table_id, table in enumerate(self._table_to_features_dict):
      features = self._table_to_features_dict[table]
      # Rows of this table's output are interleaved across its features, so
      # the step between a feature's rows is the feature count. It depends
      # only on the table, so compute it once here rather than per feature.
      stride = len(features)
      for lookup_id, feature in enumerate(features):
        # `[..., :]` keeps every column; the output is at least 2-D.
        activations[feature] = recv_activations[table_id][lookup_id::stride, :]
    return activations
Ejemplo n.º 2
0
  def get_activations(self):
    """Get activations for features.

    This should be called within `computation` that is passed to
      `tpu.replicate` and friends.

    Receives `len(self._table_to_config_dict)` activation outputs and, for
    each table, slices that table's output row-wise among the features that
    look it up: the feature at position `lookup_id` in the table's feature
    list gets rows `lookup_id, lookup_id + stride, ...`, where `stride` is
    the number of features for that table.

    Returns:
      A dictionary mapping from `String` of feature name to `Tensor`
        of activation. Entries are inserted table by table (in the iteration
        order of `self._table_to_features_dict`), then by feature position
        within the table.
    """
    recv_activations = tpu_ops.recv_tpu_embedding_activations(
        num_outputs=len(self._table_to_config_dict),
        config=self._config_proto.SerializeToString())

    activations = collections.OrderedDict()
    # NOTE(review): `table_id` indexes outputs whose count comes from
    # `_table_to_config_dict`, while tables are enumerated from
    # `_table_to_features_dict`; this assumes both hold the same tables in the
    # same order -- confirm.
    for table_id, table in enumerate(self._table_to_features_dict):
      features = self._table_to_features_dict[table]
      # Rows of this table's output are interleaved across its features, so
      # the step between a feature's rows is the feature count. It depends
      # only on the table, so compute it once here rather than per feature.
      stride = len(features)
      for lookup_id, feature in enumerate(features):
        # `[..., :]` keeps every column; the output is at least 2-D.
        activations[feature] = recv_activations[table_id][lookup_id::stride, :]
    return activations