예제 #1
0
  def __init__(self, master, component_spec):
    """Creates a builder that extracts fixed feature IDs without embedding them.

    After the base initializer runs, the component spec (read via self.spec)
    is validated: it may not declare any linked features, and every fixed
    feature must have a negative embedding_dim, i.e. be non-embedded. Failed
    validations are reported through the `check` helpers.

    Args:
      master: dragnn.MasterBuilder object.
      component_spec: dragnn.ComponentSpec proto to be built.
    """
    parent = super(BulkFeatureIdExtractorComponentBuilder, self)
    parent.__init__(master, component_spec)

    # This builder only understands fixed features.
    linked_count = len(self.spec.linked_feature)
    check.Eq(linked_count, 0, 'Linked features are forbidden')

    # IDs are emitted raw, so no fixed feature may request an embedding.
    for fixed_spec in self.spec.fixed_feature:
      embedding_dim = fixed_spec.embedding_dim
      check.Lt(embedding_dim, 0,
               'Features must be non-embedded: %s' % fixed_spec)
예제 #2
0
def extract_fixed_feature_ids(comp, state, stride):
    """Runs the bulk fixed-feature op and returns raw per-channel feature IDs.

    Every fixed feature of `comp` must have size=1 and be non-embedded
    (embedding_dim < 0). Both conditions are checked before the bulk feature
    op is invoked.

    Args:
      comp: Component whose fixed feature IDs we wish to extract.
      state: Live MasterState object for the component.  When the component
        has fixed features, `state.handle` is overwritten with the handle
        returned by the bulk feature op.
      stride: Tensor containing current batch * beam size.

    Returns:
      A (handle, fixed_ids) tuple:
        handle: Updated state handle to be used after this call.
        fixed_ids: List with one network_units.NamedTensor (dim=1) per fixed
          feature channel, each wrapping a [stride * num_steps, 1] tensor of
          feature IDs.  Missing IDs (e.g., due to batch padding) are set to
          -1.  The list is empty when the component has no fixed features.
    """
    channel_specs = comp.spec.fixed_feature
    num_channels = len(channel_specs)
    if num_channels == 0:
        # Nothing to extract: the op is not run and the handle is unchanged.
        return state.handle, []

    # Validate every channel up front, before invoking the bulk op.
    for spec in channel_specs:
        check.Eq(spec.size, 1, 'All features must have size=1')
        check.Lt(spec.embedding_dim, 0, 'All features must be non-embedded')

    state.handle, indices, ids, _, num_steps = dragnn_ops.bulk_fixed_features(
        state.handle, component=comp.name, num_channels=num_channels)
    num_rows = stride * num_steps

    fixed_ids = []
    for channel_index, spec in enumerate(channel_specs):
        tf.logging.info('[%s] Adding fixed feature IDs "%s"', comp.name,
                        spec.name)

        # Shifting IDs up by one before the segment sum and back down after
        # makes every row that received no ID (a segment sum of 0) come out
        # as -1.
        #
        # TODO(googleuser): This formula breaks if multiple IDs are extracted
        # at some step.  Try using tf.unique() to enforce the unique-IDs
        # precondition.
        shifted = ids[channel_index] + 1
        row_sums = tf.unsorted_segment_sum(shifted, indices[channel_index],
                                           num_rows)
        column = tf.expand_dims(row_sums - 1, axis=1)
        fixed_ids.append(network_units.NamedTensor(column, spec.name, dim=1))

    return state.handle, fixed_ids
예제 #3
0
 def testCheckLt(self):
     """check.Lt passes when the first value is smaller and raises otherwise.

     Without an explicit error class the failure raises ValueError carrying
     the given message; when an error class is passed as the fourth argument,
     that class is raised instead.
     """
     # assertRaisesRegexp is a deprecated alias that was removed in Python
     # 3.12; prefer assertRaisesRegex and fall back to the old name only on
     # interpreters (e.g. Python 2) that lack the new one.
     assert_raises_regex = (getattr(self, 'assertRaisesRegex', None) or
                            self.assertRaisesRegexp)

     check.Lt(1, 2, 'foo')
     with assert_raises_regex(ValueError, 'bar'):
         check.Lt(1, 1, 'bar')
     with assert_raises_regex(RuntimeError, 'baz'):
         check.Lt(1, -1, 'baz', RuntimeError)