Exemplo n.º 1
0
def unroll_state_saver(input_layer, name, state_shapes, template, lengths=None):
    """Unrolls the given function with state taken from the state saver.

    Args:
      input_layer: The input sequence; each item of `input_layer.sequence` is
        fed to `template` as one unrolled step.
      name: The name of this layer; used to derive the saved state names.
      state_shapes: A list of shapes, one for each state variable. The first
        dimension is replaced by -1 when the saved state is reshaped.
      template: A template with unbound variables for input and states that
        returns a RecurrentResult.
      lengths: Optional lengths of each item in the batch (anything accepted by
        `tf.reduce_max`, typically a tensor). If provided, every step after the
        first whose index is not below the batch-wide maximum length skips
        `template` and carries the previous output and states forward.
    Returns:
      A sequence from applying the given template to each item in the input
      sequence. Evaluating its first element also runs the state-saver
      updates.
    """
    state_saver = input_layer.bookkeeper.recurrent_state
    state_names = [STATE_NAME % name + "_%d" % i for i in xrange(len(state_shapes))]
    if isinstance(state_saver, bookkeeper.SimpleStateSaver):
        for state_name, state_shape in zip(state_names, state_shapes):
            state_saver.AddState(state_name, input_layer.dtype, state_shape)
    # `lengths` is normally a tensor, so test it against None explicitly:
    # using a tf.Tensor as a Python bool raises a TypeError in graph mode.
    if lengths is not None:
        max_length = tf.reduce_max(lengths)
    else:
        max_length = None

    results = []
    prev_states = []
    for state_name, state_shape in zip(state_names, state_shapes):
        # Keep the per-example dimensions but let the batch dimension be
        # inferred from the saved state.
        my_shape = list(state_shape)
        my_shape[0] = -1
        prev_states.append(tf.reshape(state_saver.state(state_name), my_shape))

    parameters = None
    for i, layer in enumerate(input_layer.sequence):
        with input_layer.g.name_scope("unroll_%00d" % i):
            # `max_length` is a tensor; compare with None rather than relying
            # on its truthiness (which raises in graph mode).
            if i > 0 and max_length is not None:
                # TODO(eiderman): Right now the everything after length is undefined.
                # If we can efficiently propagate the last result to the end, then
                # models with only a final output would require a single softmax
                # computation.
                # pylint: disable=cell-var-from-loop
                # Both lambdas are invoked by `cond` immediately, so the
                # closures see this iteration's `layer` and the previous
                # step's `out` / `prev_states`.
                result = control_flow_ops.cond(
                    i < max_length,
                    lambda: unwrap_all(*template(layer, *prev_states).flatten()),
                    lambda: unwrap_all(out, *prev_states),
                )
                out = result[0]
                prev_states = result[1:]
            else:
                out, prev_states = template(layer, *prev_states)
        if parameters is None:
            # Outputs produced through `cond` are unwrapped tensors without a
            # `layer_parameters` attribute; fall back to None instead of
            # raising AttributeError.
            parameters = getattr(out, "layer_parameters", None)
        results.append(prettytensor.unwrap(out))

    updates = [
        state_saver.save_state(state_name, prettytensor.unwrap(prev_state))
        for state_name, prev_state in zip(state_names, prev_states)
    ]

    # Set it up so that update is evaluated when the result of this method is
    # evaluated by injecting a dependency on an arbitrary result.
    with tf.control_dependencies(updates):
        results[0] = tf.identity(results[0])
    return input_layer.with_sequence(results, parameters=parameters)
Exemplo n.º 2
0
def unroll_state_saver(input_layer,
                       name,
                       state_shapes,
                       template,
                       lengths=None):
    """Unrolls `template` over a sequence, threading state via a state saver.

    Args:
      input_layer: The input sequence; each item of `input_layer.sequence` is
        fed to `template` as one unrolled step.
      name: The name of this layer; used to derive the saved state names.
      state_shapes: A list of shapes, one per state variable. The leading
        dimension is passed as the batch size when registering a state and is
        replaced by -1 when the saved state is reshaped.
      template: A template with unbound variables for input and states that
        returns a RecurrentResult.
      lengths: Optional per-example lengths. When given, every step after the
        first whose index is not below `tf.reduce_max(lengths)` skips
        `template` and carries the previous output and states forward.
    Returns:
      A sequence (built with `input_layer.with_sequence`) of per-step outputs.
      Evaluating its first element also runs the state-saver updates.
    """
    saver = input_layer.bookkeeper.recurrent_state
    names = []
    for idx in xrange(len(state_shapes)):
        names.append(STATE_NAME % name + '_%d' % idx)

    if hasattr(saver, 'add_state'):
        # Register every state with a zero initial value; the leading
        # dimension of each shape is handed over as the batch size.
        for state_name, shape in zip(names, state_shapes):
            saver.add_state(
                state_name,
                initial_state=tf.zeros(shape[1:], dtype=input_layer.dtype),
                batch_size=shape[0])

    max_length = None if lengths is None else tf.reduce_max(lengths)

    def _reshaped_state(state_name, shape):
        # Keep per-example dims; let the batch dim be inferred.
        target = list(shape)
        target[0] = -1
        return tf.reshape(saver.state(state_name), target)

    states = [_reshaped_state(n, s) for n, s in zip(names, state_shapes)]

    layer_params = None
    outputs = []
    for step, item in enumerate(input_layer.sequence):
        with input_layer.g.name_scope('unroll_%00d' % step):
            guarded = step > 0 and max_length is not None
            if not guarded:
                step_out, states = template(item, *states)
            else:
                # TODO(eiderman): Right now the everything after length is undefined.
                # If we can efficiently propagate the last result to the end, then
                # models with only a final output would require a single softmax
                # computation.
                # pylint: disable=cell-var-from-loop
                # `cond` calls both lambdas right away, so they capture this
                # step's `item` and the previous step's `step_out` / `states`.
                branches = control_flow_ops.cond(
                    step < max_length,
                    lambda: unwrap_all(*template(item, *states).flatten()),
                    lambda: unwrap_all(step_out, *states))
                step_out, states = branches[0], branches[1:]
        if layer_params is None:
            layer_params = step_out.layer_parameters
        outputs.append(prettytensor.unwrap(step_out))

    updates = [saver.save_state(n, prettytensor.unwrap(s))
               for n, s in zip(names, states)]

    # Make the saver updates run whenever the returned sequence is evaluated
    # by hanging a control dependency on its first element.
    with tf.control_dependencies(updates):
        outputs[0] = tf.identity(outputs[0])
    return input_layer.with_sequence(outputs, parameters=layer_params)
Exemplo n.º 3
0
def unwrap_all(*args):
    """Returns a new list of `prettytensor.unwrap` applied to each argument in order."""
    return list(map(prettytensor.unwrap, args))
Exemplo n.º 4
0
def unwrap_all(*args):
    """Unwraps every positional argument.

    Returns:
      A new list whose i-th element is `prettytensor.unwrap(args[i])`.
    """
    unwrapped = []
    for value in args:
        unwrapped.append(prettytensor.unwrap(value))
    return unwrapped