def get_activations(self):
    """Get activations for features.

    This should be called within `computation` that is passed to
    `tpu.replicate` and friends.

    A single `tpu_ops.recv_tpu_embedding_activations` op is issued with one
    output per entry of `self._table_to_config_dict`. Then, for every table
    in `self._table_to_features_dict` and every feature of that table, a
    block of `self._batch_size_per_core` rows is sliced out of that table's
    received tensor. The block is passed, together with the table's entry
    in `self._dummy_table_variables`, to
    `gen_tpu_ops.tpu_embedding_activations`.

    Returns:
      A `collections.OrderedDict` mapping from `String` of feature name to
      `Tensor` of activation. Entries are inserted in table iteration order
      and, within a table, in the order that table's features are listed.
    """
    rows_per_lookup = self._batch_size_per_core
    received = tpu_ops.recv_tpu_embedding_activations(
        num_outputs=len(self._table_to_config_dict),
        config=self._config_proto.SerializeToString())

    result = collections.OrderedDict()
    # NOTE(review): the enumeration index is used as both the recv output
    # index and the `table_id`; this assumes `_table_to_features_dict`
    # iterates tables in the same order as the embedding config — confirm.
    for table_index, (_, table_features) in enumerate(
            self._table_to_features_dict.items()):
        for lookup_index, feature_name in enumerate(table_features):
            # Each lookup owns a contiguous run of `rows_per_lookup` rows
            # along axis 0 of its table's received tensor (presumably one
            # per-core batch per feature — confirm against the config).
            first_row = lookup_index * rows_per_lookup
            result[feature_name] = gen_tpu_ops.tpu_embedding_activations(
                self._dummy_table_variables[table_index],
                received[table_index][first_row:first_row + rows_per_lookup, :],
                table_id=table_index,
                lookup_id=lookup_index)
    return result