コード例 #1
0
def hook_dummy_table_variables_to_activations(tpu_embedding, activations,
                                              dummy_table_variables):
    """Have activations depend on dummy table variables for gradient intercept.

    Every activation tensor is routed through
    ``tpu_ops.tpu_embedding_activations`` together with the dummy variable of
    the table its feature reads from, tagged with that table's position and
    the feature's position within the table.

    Args:
      tpu_embedding: TPUEmbedding, activations and dummy_table_variables are
        from tpu_embedding.
      activations: An OrderedDict of feature name String to activation tensors.
      dummy_table_variables: An OrderedDict of table name String to dummy table
        variables.

    Returns:
      An OrderedDict of feature name String to activation tensors, which can be
        used just as the activations input. Keys follow the iteration order of
        ``activations``.

    Raises:
      KeyError: if a feature is missing from
        ``tpu_embedding.feature_to_config_dict``, or its table is missing from
        ``dummy_table_variables``, ``tpu_embedding.table_to_config_dict`` or
        ``tpu_embedding.table_to_features_dict``.
      ValueError: if a feature is not listed under its table in
        ``tpu_embedding.table_to_features_dict``.
    """
    # Position of every table in table_to_config_dict, computed once up front
    # instead of rebuilding the list of table names for each feature.
    table_positions = {
        table: position
        for position, table in enumerate(tpu_embedding.table_to_config_dict)
    }
    new_activations = collections.OrderedDict()
    for feature, activation in activations.items():
        # The feature config's `table_id` is the table's key in the dicts
        # used below, not its integer position.
        table = tpu_embedding.feature_to_config_dict[feature].table_id
        new_activations[feature] = tpu_ops.tpu_embedding_activations(
            dummy_table_variables[table],
            activation,
            table_id=table_positions[table],
            # Index of this feature among the features sharing its table.
            lookup_id=tpu_embedding.table_to_features_dict[table].index(
                feature))
    return new_activations
コード例 #2
0
def hook_dummy_table_variables_to_activations(tpu_embedding, activations,
                                              dummy_table_variables):
  """Have activations depend on dummy table variables for gradient intercept.

  Every activation tensor is routed through
  ``tpu_ops.tpu_embedding_activations`` together with the dummy variable of the
  table its feature reads from, tagged with that table's position and the
  feature's position within the table.

  Args:
    tpu_embedding: TPUEmbedding, activations and dummy_table_variables are from
      tpu_embedding.
    activations: An OrderedDict of feature name String to activation tensors.
    dummy_table_variables: An OrderedDict of table name String to dummy table
      variables.

  Returns:
    An OrderedDict of feature name String to activation tensors, which can be
      used just as the activations input. Keys follow the iteration order of
      ``activations``.

  Raises:
    KeyError: if a feature is missing from
      ``tpu_embedding.feature_to_config_dict``, or its table is missing from
      ``dummy_table_variables``, ``tpu_embedding.table_to_config_dict`` or
      ``tpu_embedding.table_to_features_dict``.
    ValueError: if a feature is not listed under its table in
      ``tpu_embedding.table_to_features_dict``.
  """
  # Position of every table in table_to_config_dict, computed once up front
  # instead of rebuilding the list of table names for each feature.
  table_positions = {
      table: position
      for position, table in enumerate(tpu_embedding.table_to_config_dict)
  }
  new_activations = collections.OrderedDict()
  for feature, activation in activations.items():
    # The feature config's `table_id` is the table's key in the dicts used
    # below, not its integer position.
    table = tpu_embedding.feature_to_config_dict[feature].table_id
    new_activations[feature] = tpu_ops.tpu_embedding_activations(
        dummy_table_variables[table],
        activation,
        table_id=table_positions[table],
        # Index of this feature among the features sharing its table.
        lookup_id=tpu_embedding.table_to_features_dict[table].index(feature))
  return new_activations