def unroll_state_saver(input_layer, name, state_shapes, template, lengths=None):
  """Unrolls the given function with state taken from the state saver.

  NOTE(review): a second definition of `unroll_state_saver` appears later in
  this file and shadows this one — confirm which is intended to survive.

  Args:
    input_layer: The input sequence.
    name: The name of this layer.
    state_shapes: A list of shapes, one for each state variable.
    template: A template with unbound variables for input and states that
      returns a RecurrentResult.
    lengths: The length of each item in the batch. If provided, use this to
      truncate computation.
  Returns:
    A sequence from applying the given template to each item in the input
    sequence.
  """
  state_saver = input_layer.bookkeeper.recurrent_state
  # One saved-state slot per state variable, named STATE_NAME<name>_<i>.
  state_names = [STATE_NAME % name + "_%d" % i
                 for i in xrange(len(state_shapes))]
  # SimpleStateSaver requires states to be registered before first use.
  if isinstance(state_saver, bookkeeper.SimpleStateSaver):
    for state_name, state_shape in zip(state_names, state_shapes):
      state_saver.AddState(state_name, input_layer.dtype, state_shape)
  # BUG FIX: `lengths` may be a tf.Tensor, whose truth value cannot be
  # evaluated (`if lengths:` raises). Compare against None explicitly.
  if lengths is not None:
    max_length = tf.reduce_max(lengths)
  else:
    max_length = None
  results = []
  prev_states = []
  for state_name, state_shape in zip(state_names, state_shapes):
    # Replace the batch dimension with -1 so a variable batch size works.
    my_shape = list(state_shape)
    my_shape[0] = -1
    prev_states.append(tf.reshape(state_saver.state(state_name), my_shape))
  parameters = None
  for i, layer in enumerate(input_layer.sequence):
    with input_layer.g.name_scope("unroll_%00d" % i):
      # BUG FIX: `max_length` is a tf.Tensor when lengths was given; test
      # `is not None` instead of tensor truthiness (which raises).
      if i > 0 and max_length is not None:
        # TODO(eiderman): Right now the everything after length is undefined.
        # If we can efficiently propagate the last result to the end, then
        # models with only a final output would require a single softmax
        # computation.
        # The lambdas intentionally close over the current loop variables.
        # pylint: disable=cell-var-from-loop
        result = control_flow_ops.cond(
            i < max_length,
            lambda: unwrap_all(*template(layer, *prev_states).flatten()),
            lambda: unwrap_all(out, *prev_states))
        out = result[0]
        prev_states = result[1:]
      else:
        out, prev_states = template(layer, *prev_states)
      # Capture the template's parameters once, from the first application.
      if parameters is None:
        parameters = out.layer_parameters
      results.append(prettytensor.unwrap(out))
  # Persist the final states back into the state saver.
  updates = [
      state_saver.save_state(state_name, prettytensor.unwrap(prev_state))
      for state_name, prev_state in zip(state_names, prev_states)
  ]
  # Set it up so that update is evaluated when the result of this method is
  # evaluated by injecting a dependency on an arbitrary result.
  with tf.control_dependencies(updates):
    results[0] = tf.identity(results[0])
  return input_layer.with_sequence(results, parameters=parameters)
def unroll_state_saver(input_layer, name, state_shapes, template, lengths=None):
  """Unrolls the given function with state taken from the state saver.

  Args:
    input_layer: The input sequence.
    name: The name of this layer.
    state_shapes: A list of shapes, one for each state variable.
    template: A template with unbound variables for input and states that
      returns a RecurrentResult.
    lengths: The length of each item in the batch. If provided, use this to
      truncate computation.
  Returns:
    A sequence from applying the given template to each item in the input
    sequence.
  """
  state_saver = input_layer.bookkeeper.recurrent_state
  # One saved-state slot per state variable, named STATE_NAME<name>_<i>.
  state_names = [
      STATE_NAME % name + '_%d' % i for i in xrange(len(state_shapes))
  ]
  # Duck-typed check: savers exposing add_state need states registered with a
  # zero initial value before first use.
  if hasattr(state_saver, 'add_state'):
    for state_name, state_shape in zip(state_names, state_shapes):
      # Initial state is zeros with the per-item shape (batch dim stripped).
      initial_state = tf.zeros(state_shape[1:], dtype=input_layer.dtype)
      state_saver.add_state(state_name,
                            initial_state=initial_state,
                            batch_size=state_shape[0])
  # `lengths` may be a Tensor, so compare against None rather than truthiness.
  if lengths is not None:
    max_length = tf.reduce_max(lengths)
  else:
    max_length = None
  results = []
  prev_states = []
  for state_name, state_shape in zip(state_names, state_shapes):
    # Replace the batch dimension with -1 so a variable batch size works.
    my_shape = list(state_shape)
    my_shape[0] = -1
    prev_states.append(tf.reshape(state_saver.state(state_name), my_shape))
  my_parameters = None
  for i, layer in enumerate(input_layer.sequence):
    with input_layer.g.name_scope('unroll_%00d' % i):
      # After the first step, guard each application behind a cond so items
      # past their length stop updating state.
      if i > 0 and max_length is not None:
        # TODO(eiderman): Right now the everything after length is undefined.
        # If we can efficiently propagate the last result to the end, then
        # models with only a final output would require a single softmax
        # computation.
        # The lambdas intentionally close over the current loop variables
        # (`layer`, `prev_states`) and `out` from the previous iteration.
        # pylint: disable=cell-var-from-loop
        result = control_flow_ops.cond(
            i < max_length,
            lambda: unwrap_all(*template(
                layer, *prev_states).flatten()),
            lambda: unwrap_all(out, *prev_states))
        out = result[0]
        prev_states = result[1:]
      else:
        out, prev_states = template(layer, *prev_states)
      # Capture the template's parameters once, from the first application.
      if my_parameters is None:
        my_parameters = out.layer_parameters
      results.append(prettytensor.unwrap(out))
  # Persist the final states back into the state saver.
  updates = [
      state_saver.save_state(state_name, prettytensor.unwrap(prev_state))
      for state_name, prev_state in zip(state_names, prev_states)
  ]
  # Set it up so that update is evaluated when the result of this method is
  # evaluated by injecting a dependency on an arbitrary result.
  with tf.control_dependencies(updates):
    results[0] = tf.identity(results[0])
  return input_layer.with_sequence(results, parameters=my_parameters)
def unwrap_all(*args):
  """Applies prettytensor.unwrap to every argument and returns them as a list."""
  unwrapped = []
  for tensor in args:
    unwrapped.append(prettytensor.unwrap(tensor))
  return unwrapped