Ejemplo n.º 1
0
def _make_sequence_graph(transition=None,
                         model_state=None,
                         x=None,
                         initial_xelt=None,
                         context=None,
                         length=None,
                         temperature=1.0,
                         hp=None,
                         back_prop=False):
    """Construct the graph to process a sequence of categorical integers.

    If `x` is given, the graph processes the sequence `x` one element at a
    time.  At step `i`, the model receives the `i`th element as input, and its
    output is used to predict the `i + 1`th element.

    The last element is not processed, as there would be no further element
    available to compare against and compute loss. To ensure all data is
    processed during TBPTT, segments `x` fed into successive computations of
    the graph should overlap by 1.

    If `x` is not given, `initial_xelt` must be given as the first input to the
    model, and `length` must be given because it cannot be inferred.  Further
    elements are constructed from the model's predictions.

    Args:
      transition: model transition function mapping (xelt, model_state,
          context) to (output, new_model_state).
      model_state: initial state of the model.
      x: Sequence of integer (categorical) inputs. Not needed if sampling.
          Axes [time, batch].
      initial_xelt: When sampling, x is not given; initial_xelt specifies
          the input x[0] to the first timestep.
      context: a `Tensor` denoting context, e.g. for conditioning.
      length: Optional length of sequence. Inferred from `x` if `x` is given;
          required otherwise.
      temperature: Softmax temperature to use for sampling.
      hp: Model hyperparameters.
      back_prop: Whether the graph will be backpropagated through.

    Raises:
      ValueError: if `x` is not given and either `length` or `initial_xelt`
          is missing.

    Returns:
      Namespace containing relevant symbolic variables: `final_state` (the
      loop state after the last step), `xhat` (all entries of the `xhats`
      array except the first), `final_xhatelt` (the entry at `length - 1`),
      and, when `x` is given, `loss`, `error`, `final_x` (`x` itself) and
      `final_xelt` (`x[length - 1]`).
    """
    # Fail fast with a clear message, before touching any graph state. Without
    # these checks the same mistakes only surface further down as opaque
    # failures from `tf.shape(None)` or `TensorArray.write(0, None)`.
    if x is None:
        if length is None:
            raise ValueError(
                "length must be given when x is not given; it cannot be inferred.")
        if initial_xelt is None:
            raise ValueError(
                "initial_xelt must be given when x is not given.")

    with tf.variable_scope("seq") as scope:
        # if the caching device is not set explicitly, set it such that the
        # variables for the RNN are all cached locally.
        if scope.caching_device is None:
            scope.set_caching_device(lambda op: op.device)

        if length is None:
            # x is guaranteed to be given here (checked above); use its
            # dynamic time dimension.
            length = tf.shape(x)[0]

        def _make_ta(name, **kwargs):
            # infer_shape=False because it is too strict; it considers unknown
            # dimensions to be incompatible with anything else. Effectively that
            # requires all shapes to be fully defined at graph construction time.
            return tf.TensorArray(tensor_array_name=name,
                                  infer_shape=False,
                                  **kwargs)

        # Loop state: step counter `i`, the model state, and the TensorArrays
        # that accumulate per-step results.
        state = NS(i=tf.constant(0), model=model_state)

        # xhats has one slot per sequence element. Slot 0 is seeded with the
        # first input (given explicitly when sampling, otherwise x[0]).
        state.xhats = _make_ta("xhats",
                               dtype=tf.int32,
                               size=length,
                               clear_after_read=False)
        state.xhats = state.xhats.write(0,
                                        initial_xelt if x is None else x[0, :])

        # The remaining arrays have one slot per loop iteration, i.e.
        # `length - LEFTOVER` slots (LEFTOVER is defined elsewhere in the file).
        state.exhats = _make_ta("exhats",
                                dtype=tf.float32,
                                size=length - LEFTOVER)

        # Losses and errors are only tracked when targets are available.
        if x is not None:
            state.losses = _make_ta("losses",
                                    dtype=tf.float32,
                                    size=length - LEFTOVER)
            state.errors = _make_ta("errors",
                                    dtype=tf.bool,
                                    size=length - LEFTOVER)

        # The per-step graph is built by `make_transition_graph` (defined
        # elsewhere); everything except the loop state is bound up front.
        state = tfutil.while_loop(
            cond=lambda state: state.i < length - LEFTOVER,
            body=ft.partial(make_transition_graph,
                            transition=transition,
                            x=x,
                            context=context,
                            temperature=temperature,
                            hp=hp),
            loop_vars=state,
            back_prop=back_prop)

        # pack TensorArrays into regular tensors; losses/errors are absent
        # when sampling.
        for key in "exhats xhats losses errors".split():
            if key in state:
                state[key] = state[key].pack()

        ts = NS()
        ts.final_state = state
        # skip slot 0, which holds the seed input written above.
        ts.xhat = state.xhats[1:, :]
        ts.final_xhatelt = state.xhats[length - 1, :]
        if x is not None:
            ts.loss = tf.reduce_mean(state.losses)
            ts.error = tf.reduce_mean(tf.to_float(state.errors))
            ts.final_x = x
            # expose the final, unprocessed element of x for convenience
            ts.final_xelt = x[length - 1, :]
        return ts
Ejemplo n.º 2
0
    def _make_sequence_graph_with_unroll(self,
                                         model_state=None,
                                         x=None,
                                         initial_xelt=None,
                                         context=None,
                                         length=None,
                                         temperature=1.0,
                                         hp=None,
                                         back_prop=False):
        """Create a sequence graph by unrolling upper layers.

        This method is similar to `_make_sequence_graph`, except that `length`
        must be provided. The resulting graph behaves in the same way as that
        constructed by `_make_sequence_graph`, except that the upper layers are
        outside of the while loop and so the gradient can actually be truncated
        between runs of lower layers.

        If `x` is given, the graph processes the sequence `x` one element at a
        time.  At step `i`, the model receives the `i`th element as input, and
        its output is used to predict the `i + 1`th element.

        The last element is not processed, as there would be no further element
        available to compare against and compute loss. To ensure all data is
        processed during TBPTT, segments `x` fed into successive computations
        of the graph should overlap by 1.

        If `x` is not given, `initial_xelt` must be given as the first input to
        the model.  Further elements are constructed from the model's
        predictions.

        Args:
          model_state: initial state of the model. Defaults to
              `self.state_placeholders()`.
          x: Sequence of integer (categorical) inputs. Not needed if sampling.
              Axes [time, batch].
          initial_xelt: When sampling, x is not given; initial_xelt specifies
              the input x[0] to the first timestep.
          context: a `Tensor` denoting context, e.g. for conditioning.
              Axes [batch, features].
          length: Length of the sequence. Required, and must be a Python int
              known at graph construction time.
          temperature: Softmax temperature to use for sampling.
          hp: Model hyperparameters. `hp.periods` and `hp.boundaries` are read
              here.
          back_prop: Whether the graph will be backpropagated through.

        Raises:
          ValueError: if `length` is not an int.
          AssertionError: if `length - LEFTOVER` is not a multiple of
              `self.period`, or if the truncation boundaries derived from
              `hp.boundaries` are out of range or not in descending order.

        Returns:
          Namespace containing relevant symbolic variables: `xhat` (the
          concatenation of the inner graphs' `xhat`), `final_xhatelt`,
          `final_state`, and, when `x` is given, `final_x` (the last
          subsequence fed to the inner graph), `final_xelt`, `loss` and
          `error`.
        """
        if length is None or not isinstance(length, int):
            raise ValueError(
                "For partial unrolling, length must be known at graph construction time."
            )

        if model_state is None:
            model_state = self.state_placeholders()

        state = NS(model=model_state,
                   inner_initial_xelt=initial_xelt,
                   xhats=[],
                   losses=[],
                   errors=[])

        # i suspect ugly gradient biases may occur if gradients are truncated
        # somewhere halfway through the cycle. ensure we start at a cycle boundary.
        state.model.time = tfutil.assertion(state.model.time,
                                            tf.equal(state.model.time, 0),
                                            [state.model.time],
                                            name="outer_alignment_assertion")
        # ensure we end at a cycle boundary too.
        assert (length - LEFTOVER) % self.period == 0, (
            "length - LEFTOVER must be a multiple of the model period")

        # number of sequence elements consumed per step of the lowest outer layer.
        inner_period = int(np.prod(hp.periods[:self.outer_indices[0] + 1]))

        # hp.boundaries specifies truncation boundaries relative to the end of the sequence and in terms
        # of each layer's own steps; translate this to be relative to the beginning of the sequence and
        # in terms of sequence elements. note that due to the dynamic unrolling of the inner graph, the
        # inner layers necessarily get truncated at the topmost inner layer's boundary.
        boundaries = [
            length - 1 - hp.boundaries[i] * int(np.prod(hp.periods[:i + 1]))
            for i in range(len(hp.periods))
        ]
        assert all(0 <= boundary < length - LEFTOVER
                   for boundary in boundaries), (
            "truncation boundaries must fall within the processed sequence")
        assert boundaries == sorted(boundaries, reverse=True), (
            "truncation boundaries must be in descending order")

        # Single parenthesised argument: prints the same text under the
        # Python 2 print statement and the Python 3 print function.
        print("length %s periods %s boundaries %s %s inner period %s" % (
            length, hp.periods, hp.boundaries, boundaries, inner_period))

        # Loop-invariant: the closure only depends on `self` and `hp`, so it is
        # defined once rather than on every outer step.
        def _inner_transition(input_, state, context=None):
            # context is consumed by the outer layers only (`above=context`
            # in the outer transition below).
            assert not context
            state.cells = Wayback.transition(state.time,
                                             state.cells,
                                             self.cells,
                                             below=input_,
                                             above=None,
                                             subset=self.inner_indices,
                                             hp=hp,
                                             symbolic=True)
            state.time += 1
            state.time %= self.period
            h = self.get_output(state)
            return h, state

        # NOTE(review): floor division drops the leftover element(s) only if
        # LEFTOVER < inner_period -- confirm against the definition of LEFTOVER.
        outer_step_count = length // inner_period
        for outer_time in range(outer_step_count):
            if outer_time > 0:
                tf.get_variable_scope().reuse_variables()

            # update outer layers (wrap in seq scope to be consistent with the fully
            # symbolic version of this graph)
            with tf.variable_scope("seq"):
                # truncate gradient (only effective on outer layers)
                for i in range(len(self.cells)):
                    if outer_time * inner_period <= boundaries[i]:
                        state.model.cells[i] = list(
                            map(tf.stop_gradient, state.model.cells[i]))

                state.model.cells = Wayback.transition(
                    outer_time * inner_period,
                    state.model.cells,
                    self.cells,
                    below=None,
                    above=context,
                    subset=self.outer_indices,
                    hp=hp,
                    symbolic=False)

            # run inner layers on subsequence
            if x is None:
                inner_x = None
                # When sampling there is no input to infer the length from, so
                # pass the same static length the slice below would have had.
                # Without this the inner graph receives neither `x` nor
                # `length` and cannot be built.
                inner_length = inner_period + LEFTOVER
            else:
                start = inner_period * outer_time
                stop = inner_period * (outer_time + 1) + LEFTOVER
                inner_x = x[start:stop, :]
                # let the inner graph infer the length from the slice, as before.
                inner_length = None

            # grab a copy of the outer states. they will not be updated in the inner
            # loop, so we can put back the copy after the inner loop completes.
            # this avoids the gradient truncation due to calling `while_loop` with
            # `back_prop=False`.
            outer_cell_states = NS.Copy(state.model.cells[self.outer_slice])

            # only backpropagate through inner loops that lie at or after the
            # topmost inner layer's truncation boundary.
            inner_back_prop = back_prop and outer_time * inner_period >= boundaries[
                self.inner_indices[-1]]
            inner_ts = _make_sequence_graph(
                transition=_inner_transition,
                model_state=state.model,
                x=inner_x,
                initial_xelt=state.inner_initial_xelt,
                length=inner_length,
                temperature=temperature,
                hp=hp,
                back_prop=inner_back_prop)

            state.model = inner_ts.final_state.model
            # the next inner run starts from the last (unprocessed) input if we
            # have inputs, otherwise from the last sampled element.
            state.inner_initial_xelt = inner_ts.final_xelt if x is not None else inner_ts.final_xhatelt
            state.final_xhatelt = inner_ts.final_xhatelt
            if x is not None:
                state.final_x = inner_x
                state.final_xelt = inner_ts.final_xelt
                # track only losses and errors after the boundary to avoid bypassing the truncation boundary.
                if inner_back_prop:
                    state.losses.append(inner_ts.loss)
                    state.errors.append(inner_ts.error)
            state.xhats.append(inner_ts.xhat)

            # restore static outer states
            state.model.cells[self.outer_slice] = outer_cell_states

            # double check alignment to be safe. the shape of the inner input
            # is only reported when there is an inner input to take the shape of.
            assertion_data = [state.model.time]
            if inner_x is not None:
                assertion_data.append(tf.shape(inner_x))
            state.model.time = tfutil.assertion(
                state.model.time,
                tf.equal(state.model.time % inner_period, 0),
                assertion_data,
                name="inner_alignment_assertion")

        ts = NS()
        ts.xhat = tf.concat(0, state.xhats)
        ts.final_xhatelt = state.final_xhatelt
        ts.final_state = state
        if x is not None:
            ts.final_x = state.final_x
            ts.final_xelt = state.final_xelt
            # inner means are all on the same sample size, so taking their mean is valid
            # NOTE(review): losses/errors are only collected when
            # `inner_back_prop` is true, so with `back_prop=False` both lists
            # are empty here and these are means over nothing -- confirm that
            # callers never read loss/error in that configuration.
            ts.loss = tf.reduce_mean(state.losses)
            ts.error = tf.reduce_mean(state.errors)
        return ts