def _make_sequence_graph(transition=None, model_state=None, x=None,
                         initial_xelt=None, context=None, length=None,
                         temperature=1.0, hp=None, back_prop=False):
  """Construct the graph to process a sequence of categorical integers.

  If `x` is given, the graph processes the sequence `x` one element at a time.
  At step `i`, the model receives the `i`th element as input, and its output
  is used to predict the `i + 1`th element.

  The last element is not processed, as there would be no further element
  available to compare against and compute loss. To ensure all data is
  processed during TBPTT, segments of `x` fed into successive computations of
  the graph should overlap by 1.

  If `x` is not given, `initial_xelt` must be given as the first input to the
  model. Further elements are constructed from the model's predictions.

  Args:
    transition: model transition function mapping (xelt, model_state, context)
        to (output, new_model_state).
    model_state: initial state of the model.
    x: Sequence of integer (categorical) inputs. Not needed if sampling.
        Axes [time, batch].
    initial_xelt: When sampling, `x` is not given; `initial_xelt` specifies
        the input `x[0]` to the first timestep.
    context: a `Tensor` denoting context, e.g. for conditioning.
    length: Optional length of sequence. Inferred from `x` if possible.
    temperature: Softmax temperature to use for sampling.
    hp: Model hyperparameters.
    back_prop: Whether the graph will be backpropagated through.

  Returns:
    Namespace containing relevant symbolic variables.
  """
  with tf.variable_scope("seq") as scope:
    # if the caching device is not set explicitly, set it such that the
    # variables for the RNN are all cached locally.
    if scope.caching_device is None:
      scope.set_caching_device(lambda op: op.device)

    if length is None:
      length = tf.shape(x)[0]

    def _make_ta(name, **kwargs):
      # infer_shape=False because it is too strict; it considers unknown
      # dimensions to be incompatible with anything else. Effectively that
      # requires all shapes to be fully defined at graph construction time.
      return tf.TensorArray(tensor_array_name=name, infer_shape=False, **kwargs)

    state = NS(i=tf.constant(0), model=model_state)

    state.xhats = _make_ta("xhats", dtype=tf.int32, size=length,
                           clear_after_read=False)
    state.xhats = state.xhats.write(0, initial_xelt if x is None else x[0, :])

    state.exhats = _make_ta("exhats", dtype=tf.float32, size=length - LEFTOVER)

    if x is not None:
      state.losses = _make_ta("losses", dtype=tf.float32, size=length - LEFTOVER)
      state.errors = _make_ta("errors", dtype=tf.bool, size=length - LEFTOVER)

    state = tfutil.while_loop(
        cond=lambda state: state.i < length - LEFTOVER,
        body=ft.partial(make_transition_graph,
                        transition=transition, x=x, context=context,
                        temperature=temperature, hp=hp),
        loop_vars=state,
        back_prop=back_prop)

    # pack TensorArrays
    for key in "exhats xhats losses errors".split():
      if key in state:
        state[key] = state[key].pack()

    ts = NS()
    ts.final_state = state
    ts.xhat = state.xhats[1:, :]
    ts.final_xhatelt = state.xhats[length - 1, :]
    if x is not None:
      ts.loss = tf.reduce_mean(state.losses)
      ts.error = tf.reduce_mean(tf.to_float(state.errors))
      ts.final_x = x
      # expose the final, unprocessed element of x for convenience
      ts.final_xelt = x[length - 1, :]
    return ts
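# A minimal sketch, not part of this module's API, of how a caller might cut a
# long sequence into TBPTT segments that overlap by one element, as the
# docstring above requires. The helper name `overlapping_segments` and the
# `segment_length` parameter are illustrative assumptions.
def overlapping_segments(sequence, segment_length):
  """Yield consecutive segments of `sequence` that overlap by one element.

  The last element of each segment becomes the first element of the next, so
  the element left unprocessed by one computation of the graph is processed
  by the next.
  """
  start = 0
  while start + 1 < len(sequence):
    yield sequence[start:start + segment_length]
    # advance by segment_length - 1 so successive segments share one element
    start += segment_length - 1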
def _make_sequence_graph_with_unroll(self, model_state=None, x=None,
                                     initial_xelt=None, context=None,
                                     length=None, temperature=1.0, hp=None,
                                     back_prop=False):
  """Create a sequence graph by unrolling upper layers.

  This method is similar to `_make_sequence_graph`, except that `length` must
  be provided. The resulting graph behaves in the same way as that constructed
  by `_make_sequence_graph`, except that the upper layers are outside of the
  while loop and so the gradient can actually be truncated between runs of
  lower layers.

  If `x` is given, the graph processes the sequence `x` one element at a time.
  At step `i`, the model receives the `i`th element as input, and its output
  is used to predict the `i + 1`th element.

  The last element is not processed, as there would be no further element
  available to compare against and compute loss. To ensure all data is
  processed during TBPTT, segments of `x` fed into successive computations of
  the graph should overlap by 1.

  If `x` is not given, `initial_xelt` must be given as the first input to the
  model. Further elements are constructed from the model's predictions.

  Args:
    model_state: initial state of the model.
    x: Sequence of integer (categorical) inputs. Not needed if sampling.
        Axes [time, batch].
    initial_xelt: When sampling, `x` is not given; `initial_xelt` specifies
        the input `x[0]` to the first timestep.
    context: a `Tensor` denoting context, e.g. for conditioning.
        Axes [batch, features].
    length: Length of the sequence. Must be an integer known at graph
        construction time.
    temperature: Softmax temperature to use for sampling.
    hp: Model hyperparameters.
    back_prop: Whether the graph will be backpropagated through.

  Raises:
    ValueError: if `length` is not an int.

  Returns:
    Namespace containing relevant symbolic variables.
  """
  if length is None or not isinstance(length, int):
    raise ValueError("For partial unrolling, length must be known at graph "
                     "construction time.")

  if model_state is None:
    model_state = self.state_placeholders()

  state = NS(model=model_state, inner_initial_xelt=initial_xelt,
             xhats=[], losses=[], errors=[])

  # I suspect ugly gradient biases may occur if gradients are truncated
  # somewhere halfway through the cycle. Ensure we start at a cycle boundary.
  state.model.time = tfutil.assertion(state.model.time,
                                      tf.equal(state.model.time, 0),
                                      [state.model.time],
                                      name="outer_alignment_assertion")
  # ensure we end at a cycle boundary too.
  assert (length - LEFTOVER) % self.period == 0

  inner_period = int(np.prod(hp.periods[:self.outer_indices[0] + 1]))

  # hp.boundaries specifies truncation boundaries relative to the end of the
  # sequence and in terms of each layer's own steps; translate this to be
  # relative to the beginning of the sequence and in terms of sequence
  # elements. Note that due to the dynamic unrolling of the inner graph, the
  # inner layers necessarily get truncated at the topmost inner layer's
  # boundary.
  boundaries = [length - 1 - hp.boundaries[i] * int(np.prod(hp.periods[:i + 1]))
                for i in range(len(hp.periods))]
  assert all(0 <= boundary < length - LEFTOVER for boundary in boundaries)
  assert boundaries == list(reversed(sorted(boundaries)))

  print("length %s periods %s boundaries %s %s inner period %s"
        % (length, hp.periods, hp.boundaries, boundaries, inner_period))

  outer_step_count = length // inner_period
  for outer_time in range(outer_step_count):
    if outer_time > 0:
      tf.get_variable_scope().reuse_variables()

    # update outer layers (wrap in seq scope to be consistent with the fully
    # symbolic version of this graph)
    with tf.variable_scope("seq"):
      # truncate gradient (only effective on outer layers)
      for i in range(len(self.cells)):
        if outer_time * inner_period <= boundaries[i]:
          state.model.cells[i] = list(map(tf.stop_gradient, state.model.cells[i]))

      state.model.cells = Wayback.transition(
          outer_time * inner_period, state.model.cells, self.cells,
          below=None, above=context, subset=self.outer_indices, hp=hp,
          symbolic=False)

    # run inner layers on subsequence
    if x is None:
      inner_x = None
    else:
      start = inner_period * outer_time
      stop = inner_period * (outer_time + 1) + LEFTOVER
      inner_x = x[start:stop, :]

    # grab a copy of the outer states. they will not be updated in the inner
    # loop, so we can put back the copy after the inner loop completes.
    # this avoids the gradient truncation due to calling `while_loop` with
    # `back_prop=False`.
    outer_cell_states = NS.Copy(state.model.cells[self.outer_slice])

    def _inner_transition(input_, state, context=None):
      assert not context
      state.cells = Wayback.transition(
          state.time, state.cells, self.cells, below=input_, above=None,
          subset=self.inner_indices, hp=hp, symbolic=True)
      state.time += 1
      state.time %= self.period
      h = self.get_output(state)
      return h, state

    inner_back_prop = (back_prop and
                       outer_time * inner_period >= boundaries[self.inner_indices[-1]])
    inner_ts = _make_sequence_graph(
        transition=_inner_transition, model_state=state.model, x=inner_x,
        initial_xelt=state.inner_initial_xelt, temperature=temperature,
        hp=hp, back_prop=inner_back_prop)

    state.model = inner_ts.final_state.model
    state.inner_initial_xelt = (inner_ts.final_xelt if x is not None
                                else inner_ts.final_xhatelt)
    state.final_xhatelt = inner_ts.final_xhatelt
    if x is not None:
      state.final_x = inner_x
      state.final_xelt = inner_ts.final_xelt
      # track only losses and errors after the boundary to avoid bypassing
      # the truncation boundary.
      if inner_back_prop:
        state.losses.append(inner_ts.loss)
        state.errors.append(inner_ts.error)
    state.xhats.append(inner_ts.xhat)

    # restore static outer states
    state.model.cells[self.outer_slice] = outer_cell_states

    # double check alignment to be safe
    state.model.time = tfutil.assertion(state.model.time,
                                        tf.equal(state.model.time % inner_period, 0),
                                        [state.model.time, tf.shape(inner_x)],
                                        name="inner_alignment_assertion")

  ts = NS()
  ts.xhat = tf.concat(0, state.xhats)
  ts.final_xhatelt = state.final_xhatelt
  ts.final_state = state
  if x is not None:
    ts.final_x = state.final_x
    ts.final_xelt = state.final_xelt
    # inner means are all on the same sample size, so taking their mean is valid
    ts.loss = tf.reduce_mean(state.losses)
    ts.error = tf.reduce_mean(state.errors)
  return ts
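# A small worked example, with assumed numbers, of the boundary translation in
# `_make_sequence_graph_with_unroll`: `hp.boundaries` counts steps of each
# layer back from the end of the sequence, and the list comprehension converts
# that into sequence-element indices counted from the start. The helper name
# `_translate_boundaries` is hypothetical and reuses the module's `np` import.
def _translate_boundaries(length, periods, per_layer_boundaries):
  return [length - 1 - per_layer_boundaries[i] * int(np.prod(periods[:i + 1]))
          for i in range(len(periods))]

# e.g. _translate_boundaries(33, [4, 2], [2, 3]) == [24, 8]: with periods
# [4, 2], layer 0 advances once per 4 sequence elements and layer 1 once per
# 4 * 2 = 8 elements, so truncating 2 layer-0 steps from the end lands at
# element 32 - 2 * 4 = 24 and 3 layer-1 steps at 32 - 3 * 8 = 8. The result is
# non-increasing, as the assertions above require.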