def fetch_differentiable_fixed_embeddings(comp, state, stride, during_training):
  """Embeds every fixed feature channel via its own differentiable lookup.

  The component's fixed features are first validated by
  `_validate_embedded_fixed_features`. Each channel of
  `comp.spec.fixed_feature` is then looked up against its own embedding
  variable. Channels whose spec sets `is_constant` have the looked-up
  embedding wrapped in `tf.stop_gradient`.

  Args:
    comp: Component whose fixed features we wish to look up.
    state: Live MasterState object for the component. Its `handle` attribute
      is replaced by the handle returned from the bulk feature op.
    stride: Tensor containing current batch * beam size.
    during_training: True if this is being called from a training code path.
      When True, channels with a non-negative `dropout_id` go through
      feature ID dropout before the lookup.

  Returns:
    A pair (handle, embeddings):
      handle: Updated state handle to be used after this call.
      embeddings: List of NamedTensor objects, one per fixed feature channel
        in spec order; empty if the component has no fixed features.
  """
  _validate_embedded_fixed_features(comp)

  feature_specs = comp.spec.fixed_feature
  channel_count = len(feature_specs)
  if channel_count == 0:
    return state.handle, []

  state.handle, indices, ids, weights, num_steps = (
      dragnn_ops.bulk_fixed_features(
          state.handle, component=comp.name, num_channels=channel_count))

  embeddings = []
  for channel, spec in enumerate(feature_specs):
    if spec.is_constant:
      kind = 'constant'
    else:
      kind = 'differentiable'
    tf.logging.info('[%s] Adding %s fixed feature "%s"', comp.name, kind,
                    spec.name)

    channel_ids = ids[channel]
    channel_weights = weights[channel]
    # Dropout is applied only on training paths, and only to channels whose
    # dropout_id is non-negative.
    if during_training and spec.dropout_id >= 0:
      channel_ids, channel_weights = network_units.apply_feature_id_dropout(
          channel_ids, channel_weights, spec)

    lookup_size = stride * num_steps * spec.size
    embedding = network_units.embedding_lookup(
        comp.get_variable(network_units.fixed_embeddings_name(channel)),
        indices[channel], channel_ids, channel_weights, lookup_size)
    # Constant channels still contribute embeddings, but no gradient flows
    # back through the lookup.
    if spec.is_constant:
      embedding = tf.stop_gradient(embedding)
    embeddings.append(network_units.NamedTensor(embedding, spec.name))

  return state.handle, embeddings
def fetch_differentiable_fixed_embeddings(comp, state, stride, during_training):
  """Embeds every fixed feature channel via its own differentiable lookup.

  The component's fixed features are first validated by
  `_validate_embedded_fixed_features`. Each channel of
  `comp.spec.fixed_feature` is then looked up against its own embedding
  variable. Channels whose spec sets `is_constant` have the looked-up
  embedding wrapped in `tf.stop_gradient`.

  Args:
    comp: Component whose fixed features we wish to look up.
    state: Live MasterState object for the component. Its `handle` attribute
      is replaced by the handle returned from the bulk feature op.
    stride: Tensor containing current batch * beam size.
    during_training: True if this is being called from a training code path.
      When True, channels with a non-negative `dropout_id` go through
      feature ID dropout before the lookup.

  Returns:
    A pair (handle, embeddings):
      handle: Updated state handle to be used after this call.
      embeddings: List of NamedTensor objects, one per fixed feature channel
        in spec order; empty if the component has no fixed features.
  """
  _validate_embedded_fixed_features(comp)

  feature_specs = comp.spec.fixed_feature
  channel_count = len(feature_specs)
  if channel_count == 0:
    return state.handle, []

  state.handle, indices, ids, weights, num_steps = (
      dragnn_ops.bulk_fixed_features(
          state.handle, component=comp.name, num_channels=channel_count))

  embeddings = []
  for channel, spec in enumerate(feature_specs):
    if spec.is_constant:
      kind = 'constant'
    else:
      kind = 'differentiable'
    tf.logging.info('[%s] Adding %s fixed feature "%s"', comp.name, kind,
                    spec.name)

    channel_ids = ids[channel]
    channel_weights = weights[channel]
    # Dropout is applied only on training paths, and only to channels whose
    # dropout_id is non-negative.
    if during_training and spec.dropout_id >= 0:
      channel_ids, channel_weights = network_units.apply_feature_id_dropout(
          channel_ids, channel_weights, spec)

    lookup_size = stride * num_steps * spec.size
    embedding = network_units.embedding_lookup(
        comp.get_variable(network_units.fixed_embeddings_name(channel)),
        indices[channel], channel_ids, channel_weights, lookup_size)
    # Constant channels still contribute embeddings, but no gradient flows
    # back through the lookup.
    if spec.is_constant:
      embedding = tf.stop_gradient(embedding)
    embeddings.append(network_units.NamedTensor(embedding, spec.name))

  return state.handle, embeddings
def extract_fixed_feature_ids(comp, state, stride):
  """Extracts raw (non-embedded) fixed feature IDs, one tensor per channel.

  Every fixed feature channel must have size 1 and a negative
  embedding_dim. These preconditions are checked with `check.Eq` and
  `check.Lt` before the bulk feature op is built.

  Args:
    comp: Component whose fixed feature IDs we wish to extract.
    state: Live MasterState object for the component. Its `handle` attribute
      is replaced by the handle returned from the bulk feature op.
    stride: Tensor containing current batch * beam size.

  Returns:
    A pair (handle, ids):
      handle: Updated state handle to be used after this call.
      ids: List of NamedTensor objects (built with dim=1), one per channel in
        spec order. Each wraps a [stride * num_steps, 1] tensor of feature
        IDs. Missing IDs (e.g., due to batch padding) are set to -1. The list
        is empty if the component has no fixed features.
  """
  specs = comp.spec.fixed_feature
  channel_count = len(specs)
  if channel_count == 0:
    return state.handle, []

  for spec in specs:
    check.Eq(spec.size, 1, 'All features must have size=1')
    check.Lt(spec.embedding_dim, 0, 'All features must be non-embedded')

  state.handle, indices, ids, _, num_steps = dragnn_ops.bulk_fixed_features(
      state.handle, component=comp.name, num_channels=channel_count)
  num_rows = stride * num_steps

  id_tensors = []
  for channel, spec in enumerate(specs):
    tf.logging.info('[%s] Adding fixed feature IDs "%s"', comp.name,
                    spec.name)
    # IDs are shifted up by one before the segment sum and back down after
    # it. A row that received no ID therefore sums to 0 and comes out as -1.
    #
    # TODO(googleuser): This formula breaks if multiple IDs are extracted at
    # some step. Try using tf.unique() to enforce the unique-IDs precondition.
    shifted_sums = tf.unsorted_segment_sum(ids[channel] + 1, indices[channel],
                                           num_rows)
    column = tf.expand_dims(shifted_sums - 1, axis=1)
    id_tensors.append(network_units.NamedTensor(column, spec.name, dim=1))

  return state.handle, id_tensors
def extract_fixed_feature_ids(comp, state, stride):
  """Extracts raw (non-embedded) fixed feature IDs, one tensor per channel.

  Every fixed feature channel must have size 1 and a negative
  embedding_dim. These preconditions are checked with `check.Eq` and
  `check.Lt` before the bulk feature op is built.

  Args:
    comp: Component whose fixed feature IDs we wish to extract.
    state: Live MasterState object for the component. Its `handle` attribute
      is replaced by the handle returned from the bulk feature op.
    stride: Tensor containing current batch * beam size.

  Returns:
    A pair (handle, ids):
      handle: Updated state handle to be used after this call.
      ids: List of NamedTensor objects (built with dim=1), one per channel in
        spec order. Each wraps a [stride * num_steps, 1] tensor of feature
        IDs. Missing IDs (e.g., due to batch padding) are set to -1. The list
        is empty if the component has no fixed features.
  """
  specs = comp.spec.fixed_feature
  channel_count = len(specs)
  if channel_count == 0:
    return state.handle, []

  for spec in specs:
    check.Eq(spec.size, 1, 'All features must have size=1')
    check.Lt(spec.embedding_dim, 0, 'All features must be non-embedded')

  state.handle, indices, ids, _, num_steps = dragnn_ops.bulk_fixed_features(
      state.handle, component=comp.name, num_channels=channel_count)
  num_rows = stride * num_steps

  id_tensors = []
  for channel, spec in enumerate(specs):
    tf.logging.info('[%s] Adding fixed feature IDs "%s"', comp.name,
                    spec.name)
    # IDs are shifted up by one before the segment sum and back down after
    # it. A row that received no ID therefore sums to 0 and comes out as -1.
    #
    # TODO(googleuser): This formula breaks if multiple IDs are extracted at
    # some step. Try using tf.unique() to enforce the unique-IDs precondition.
    shifted_sums = tf.unsorted_segment_sum(ids[channel] + 1, indices[channel],
                                           num_rows)
    column = tf.expand_dims(shifted_sums - 1, axis=1)
    id_tensors.append(network_units.NamedTensor(column, spec.name, dim=1))

  return state.handle, id_tensors