def get_activations(self):
    """Return the per-feature embedding activations received from the TPU.

    Intended to be invoked inside the `computation` that is passed to
    `tpu.replicate` (or a similar replication helper).

    Returns:
      A `collections.OrderedDict` mapping each feature name (`String`) to
      its activation `Tensor`. Features appear in the iteration order of
      `self._table_to_features_dict`, table by table.
    """
    # Request one output per entry of the table config dict.
    received = tpu_ops.recv_tpu_embedding_activations(
        num_outputs=len(self._table_to_config_dict),
        config=self._config_proto.SerializeToString())

    result = collections.OrderedDict()
    # NOTE(review): the i-th received output is paired with the i-th table of
    # `_table_to_features_dict`; this assumes that dict iterates tables in the
    # same order as `_table_to_config_dict` -- confirm against the caller.
    for table_index, table_features in enumerate(
            self._table_to_features_dict.values()):
        num_features = len(table_features)
        for position, feature_name in enumerate(table_features):
            # Feature at `position` takes rows position, position + n,
            # position + 2n, ... of its table's output (n = num_features),
            # keeping every column.
            result[feature_name] = (
                received[table_index][position::num_features, :])
    return result