Example #1
def while_loop(cond, body, loop_vars, **kwargs):
    """Like `tf.while_loop` but with structured `loop_vars`.

  Args:
    cond: as in `tf.while_loop`, but takes a single `loop_vars` argument.
    body: as in `tf.while_loop`, but takes and returns a single `loop_vars`
          tree which it is allowed to modify.
    loop_vars: as in `tf.while_loop`, but consists of a Namespace tree.
    **kwargs: passed onto `tf.while_loop`.

  Returns:
    A Namespace tree structure containing the final values of the loop
    variables.
  """
    def _cond(*flat_vars):
        return cond(NS.UnflattenLike(loop_vars, flat_vars))

    def _body(*flat_vars):
        return NS.Flatten(body(NS.UnflattenLike(loop_vars, flat_vars)))

    return NS.UnflattenLike(
        loop_vars,
        tf.while_loop(cond=_cond,
                      body=_body,
                      loop_vars=NS.Flatten(loop_vars),
                      **kwargs))
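A minimal usage sketch of this wrapper (hypothetical tensors; assumes `tf` and `NS` are imported as in the surrounding examples, and that the body returns a tree with the same structure as `loop_vars`):

loop_vars = NS(i=tf.constant(0), acc=NS(total=tf.constant(0.0)))
final = while_loop(
    cond=lambda vs: vs.i < 10,
    body=lambda vs: NS(i=vs.i + 1, acc=NS(total=vs.acc.total + 2.0)),
    loop_vars=loop_vars)
# final.i and final.acc.total are the loop variables after the last iteration.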
Example #2
    def run(self, session, examples, max_step_count=None, hooks=None, hp=None):
        tensors = self.tensors.Extract(
            "loss error summaries global_step training_op learning_rate final_state.model"
        )
        state = NS(global_step=tf.train.global_step(session,
                                                    self.tensors.global_step),
                   model=self.model.initial_state(hp.batch_size))
        while True:
            for batch in util.batches(examples, hp.batch_size):
                for segment in util.segments(batch,
                                             self.segment_length,
                                             overlap=LEFTOVER):
                    if max_step_count is not None and state.global_step >= max_step_count:
                        return

                    hooks.Get("step.before", util.noop)(state)
                    x, = util.examples_as_arrays(segment)
                    feed_dict = {self.tensors.x: x.T}
                    feed_dict.update(self.model.feed_dict(state.model))
                    values = tfutil.run(session, tensors, feed_dict=feed_dict)
                    state.model = values.final_state.model
                    state.global_step = values.global_step
                    hooks.Get("step.after", util.noop)(state, values)

                    print("step #%d loss %f error %f learning rate %e" %
                          (values.global_step, values.loss, values.error,
                           values.learning_rate))

                    if np.isnan(values.loss):
                        raise ValueError("loss has become NaN")
Example #3
 def testGet(self):
     ns = NS(foo=NS(bar="baz"))
     self.assertRaises(KeyError, lambda: ns["foo"]["baz"])
     self.assertIsNone(ns.Get("foo.baz"))
     x = object()
     self.assertEqual(x, ns.Get("foo.baz", x))
     self.assertEqual("baz", ns.Get("foo.bar"))
     self.assertEqual(NS(bar="baz"), ns.Get("foo"))
Example #4
 def testFlatCallFlatZip(self):
     before = NS(v=2, w=NS(x=1, y=NS(z=0)))
     after = NS.FlatCall(lambda xs: [2 * x for x in xs], before)
     self.assertEqual(NS(v=4, w=NS(x=2, y=NS(z=0))), after)
     self.assertItemsEqual([(2, 4), (1, 2), (0, 0)],
                           list(NS.FlatZip([before, after])))
     after.w.y.a = 6
     self.assertRaises(ValueError, lambda: NS.FlatZip([before, after]))
     self.assertItemsEqual([(2, 4), (0, 0)],
                           list(NS.FlatZip([before, after], "v w.y.z")))
Example #5
    def run(self, session, primers, length, temperature, hp=None):
        batch_size = len(primers)
        # process in segments to avoid tensorflow eating all the memory
        max_segment_length = min(10000, hp.segment_length)

        print "conditioning..."
        segment_length = min(max_segment_length,
                             max(len(primer[0]) for primer in primers))

        state = NS(model=self.model.initial_state(batch_size))
        for segment in util.segments(primers, segment_length,
                                     overlap=LEFTOVER):
            x, = util.examples_as_arrays(segment)
            feed_dict = {self.tensors.x: x.T}
            feed_dict.update(self.model.feed_dict(state.model))
            values = tfutil.run(session,
                                tensors=self.tensors.cond.Extract(
                                    "final_state.model final_xelt"),
                                feed_dict=feed_dict)
            state.model = values.final_state.model
            sys.stderr.write(".")
        sys.stderr.write("\n")

        cond_values = values

        print "sampling..."
        length_left = length + LEFTOVER
        xhats = []
        state = NS(model=cond_values.final_state.model,
                   initial_xelt=cond_values.final_xelt)
        while length_left > 0:
            segment_length = min(max_segment_length, length_left)
            length_left -= segment_length

            feed_dict = {
                self.tensors.initial_xelt: state.initial_xelt,
                self.tensors.length: segment_length,
                self.tensors.temperature: temperature
            }
            feed_dict.update(self.model.feed_dict(state.model))
            sample_values = tfutil.run(
                session,
                tensors=self.tensors.sample.Extract(
                    "final_state.model xhat final_xhatelt"),
                feed_dict=feed_dict)
            state.model = sample_values.final_state.model
            state.initial_xelt = sample_values.final_xhatelt

            xhats.append(sample_values.xhat)
            sys.stderr.write(".")
        sys.stderr.write("\n")

        xhat = np.concatenate(xhats, axis=0)
        return xhat.T
Example #6
    def run(self,
            session,
            examples,
            max_step_count=None,
            hp=None,
            aggregates=None):
        aggregates = NS(aggregates or {})
        for key in "loss error".split():
            if key not in aggregates:
                aggregates[key] = util.MeanAggregate()

        tensors = self.tensors.Extract(*[key for key in aggregates.Keys()])
        tensors.Update(self.tensors.Extract("final_state.model"))

        state = NS(step=0, model=self.model.initial_state(hp.batch_size))

        try:
            for batch in util.batches(examples, hp.batch_size):
                for segment in util.segments(batch,
                                             hp.segment_length,
                                             overlap=hp.chunk_size):
                    if max_step_count is not None and state.step >= max_step_count:
                        raise StopIteration()

                    x, = util.examples_as_arrays(segment)
                    feed_dict = {self.tensors.x: x.T}
                    feed_dict.update(self.model.feed_dict(state.model))
                    values = NS.FlatCall(
                        ft.partial(session.run, feed_dict=feed_dict), tensors)

                    for key in aggregates.Keys():
                        aggregates[key].add(values[key])

                    sys.stderr.write(".")
                    state.model = values.final_state.model
                    state.step += 1
        except StopIteration:
            pass

        sys.stderr.write("\n")

        values = NS(
            (key, aggregate.value) for key, aggregate in aggregates.Items())

        values.summaries = [
            tf.Summary.Value(tag="%s_valid" % key, simple_value=values[key])
            for key in "loss error".split()
        ]
        print "### evaluation loss %6.5f error %6.5f" % (values.loss,
                                                         values.error)

        if np.isnan(values.loss):
            raise ValueError("loss has become NaN")

        return values
Example #7
    def state_placeholders(self):
        """Get the Tensorflow placeholders for the model's states.

    Returns:
      A Namespace tree containing the placeholders.
    """
        return NS.Copy(self._state_placeholders)
Example #8
def cond(pred, fn1, fn2, prototype, **kwargs):
    """Like `tf.cond` but with structured collections of variables.

  Args:
    pred: boolean Tensor, as in `tf.cond`.
    fn1: a callable representing the `then` branch as in `tf.cond`, but
         may return an arbitrary Namespace tree.
    fn2: a callable representing the `else` branch as in `tf.cond`, but
         may return an arbitrary Namespace tree.
    prototype: an example Namespace tree to indicate the structure of the
               values returned from `fn1` and `fn2`.
    **kwargs: passed onto `tf.cond`.

  Returns:
    Like `tf.cond`, except structured like `prototype`.
  """
    def wrap_branch(fn):
        def wrapped_branch():
            tree = fn()
            liszt = NS.Flatten(tree)
            return liszt

        return wrapped_branch

    results = tf.cond(pred, wrap_branch(fn1), wrap_branch(fn2), **kwargs)
    # tf.cond unpacks singleton lists returned from fn1, fn2 -_-
    if not isinstance(results, (tuple, list)):
        results = [results]
    # need a prototype to unflatten because at this point neither fn1 nor fn2
    # have been called
    tree3 = NS.UnflattenLike(prototype, results)
    return tree3
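A small usage sketch (hypothetical values; `pred` is some boolean scalar Tensor, and `prototype` only has to mirror the structure that both branches return):

prototype = NS(mean=tf.constant(0.0), extra=NS(count=tf.constant(0)))
chosen = cond(pred,
              lambda: NS(mean=tf.constant(1.0), extra=NS(count=tf.constant(2))),
              lambda: NS(mean=tf.constant(3.0), extra=NS(count=tf.constant(4))),
              prototype=prototype)
# chosen.mean and chosen.extra.count come from whichever branch `pred` selects.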
Example #9
    def run(self, session, primers, length, temperature, hp=None):
        batch_size = len(primers)
        # process in segments to avoid tensorflow eating all the memory
        max_segment_length = min(10000, hp.segment_length)

        print "conditioning..."
        segment_length = min(max_segment_length,
                             max(len(primer[0]) for primer in primers))
        # ensure segment_length is a multiple of chunk_size
        segment_length -= segment_length % hp.chunk_size

        state = NS(model=self.model.initial_state(batch_size))
        for segment in util.segments(primers,
                                     segment_length,
                                     overlap=hp.chunk_size):
            x, = util.examples_as_arrays(segment)
            feed_dict = {self.tensors.x: x.T}
            feed_dict.update(self.model.feed_dict(state.model))
            values = NS.FlatCall(
                ft.partial(session.run, feed_dict=feed_dict),
                self.tensors.cond.Extract("final_state.model final_xchunk"))
            state.model = values.final_state.model
            sys.stderr.write(".")
        sys.stderr.write("\n")

        cond_values = values

        # make sure length is a multiple of chunk_size
        chunky_length = length + hp.chunk_size - length % hp.chunk_size

        print "sampling..."
        length_left = chunky_length
        xhats = []
        state = NS(model=cond_values.final_state.model,
                   initial_xchunk=cond_values.final_xchunk)
        while length_left > 0:
            segment_length = min(max_segment_length, length_left)
            length_left -= segment_length

            feed_dict = {
                self.tensors.initial_xchunk: state.initial_xchunk,
                self.tensors.length: segment_length,
                self.tensors.temperature: temperature
            }
            feed_dict.update(self.model.feed_dict(state.model))
            sample_values = NS.FlatCall(
                ft.partial(session.run, feed_dict=feed_dict),
                self.tensors.sample.Extract(
                    "final_state.model xhat final_xhatchunk"))
            state.model = sample_values.final_state.model
            state.initial_xchunk = sample_values.final_xhatchunk

            xhats.append(sample_values.xhat)
            sys.stderr.write(".")
        sys.stderr.write("\n")

        xhat = np.concatenate(xhats, axis=0)
        # truncate from chunky_length to the desired sample length
        xhat = xhat[:length]
        return xhat.T
Example #10
    def feed_dict(self, state):
        """Construct a feed dict for the model's states.

    Args:
      state: the model state.

    Returns:
      A feed dict mapping each of the model's placeholders to the corresponding
      numerical value in `state`.
    """
        return util.odict(NS.FlatZip([self.state_placeholders(), state]))
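A sketch of what `NS.FlatZip` produces here (hypothetical placeholder names; assumes the placeholder tree and `state` share the same Namespace structure):

placeholders = NS(h=tf.placeholder(tf.float32, [None, 4]),
                  c=tf.placeholder(tf.float32, [None, 4]))
state = NS(h=np.zeros([1, 4]), c=np.zeros([1, 4]))
feed = util.odict(NS.FlatZip([placeholders, state]))
# feed pairs each placeholder with the matching value, roughly
# OrderedDict([(placeholders.h, state.h), (placeholders.c, state.c)]).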
Example #11
 def __call__(self, x, state, context=None):
     # construct the usual graph without unrolling
     state = NS.Copy(state)
     state.cells = Wayback.transition(state.time,
                                      state.cells,
                                      self.cells,
                                      below=x,
                                      above=context,
                                      hp=self.hp,
                                      symbolic=True)
     state.time += 1
     state.time %= self.period
     return state
Example #12
 def _make(self, unused_hp):
     ts = NS()
     ts.x = tf.placeholder(dtype=tf.int32, name="x")
     ts.seq = self.model.make_evaluation_graph(x=ts.x)
     ts.final_state = ts.seq.final_state
     ts.loss = ts.seq.loss
     ts.error = ts.seq.error
     return ts
Example #13
    def run(self, session, examples, max_step_count=None, hooks=None, hp=None):
        state = NS(global_step=tf.train.global_step(session,
                                                    self.tensors.global_step),
                   model=self.model.initial_state(hp.batch_size))
        while True:
            for batch in util.batches(examples, hp.batch_size):
                for segment in util.segments(
                        batch,
                        # the last chunk is not processed, so grab
                        # one more to ensure we backpropagate
                        # through at least one full model cycle.
                        # TODO(cotim): rename segment_length to
                        # backprop_length?
                        hp.segment_length + hp.chunk_size,
                        overlap=hp.chunk_size):
                    if max_step_count is not None and state.global_step >= max_step_count:
                        return

                    hooks.Get("step.before", util.noop)(state)
                    x, = util.examples_as_arrays(segment)
                    feed_dict = {self.tensors.x: x.T}
                    feed_dict.update(self.model.feed_dict(state.model))
                    values = NS.FlatCall(
                        ft.partial(session.run, feed_dict=feed_dict),
                        self.tensors.Extract(
                            "loss error summaries global_step training_op learning_rate final_state.model"
                        ))
                    state.model = values.final_state.model
                    state.global_step = values.global_step
                    hooks.Get("step.after", util.noop)(state, values)

                    print("step #%d loss %f error %f learning rate %e" %
                          (values.global_step, values.loss, values.error,
                           values.learning_rate))

                    if np.isnan(values.loss):
                        raise ValueError("loss has become NaN")
Example #14
def make_transition_graph(state,
                          transition,
                          x=None,
                          context=None,
                          temperature=1.0,
                          hp=None):
    """Make the graph that processes a single sequence element.

  Args:
    state: `_make_sequence_graph` loop state.
    transition: Model transition function mapping (xelt, model_state,
        context) to (output, new_model_state).
    x: Sequence of integer (categorical) inputs. Axes [time, batch].
    context: Optional Tensor denoting context, shaped [batch, ?].
    temperature: Softmax temperature to use for sampling.
    hp: Model hyperparameters.

  Returns:
    Updated loop state.
  """
    state = NS.Copy(state)

    xelt = tfutil.shaped_one_hot(
        state.xhats.read(state.i) if x is None else x[state.i, :],
        [None, hp.data_dim])
    embedding = tfutil.layers([xelt], sizes=hp.io_sizes, use_bn=hp.use_bn)
    h, state.model = transition(embedding, state.model, context=context)

    # predict the next elt
    with tf.variable_scope("xhat") as scope:
        embedding = tfutil.layers([h], sizes=hp.io_sizes, use_bn=hp.use_bn)
        exhat = tfutil.project(embedding, output_dim=hp.data_dim)
        xhat = tfutil.sample(exhat, temperature)
        state.xhats = state.xhats.write(state.i + LEFTOVER, xhat)

    if x is not None:
        target = tfutil.shaped_one_hot(x[state.i + 1], [None, hp.data_dim])
        state.losses = state.losses.write(
            state.i, tf.nn.softmax_cross_entropy_with_logits(exhat, target))
        state.errors = state.errors.write(
            state.i,
            tf.not_equal(tf.nn.top_k(exhat)[1],
                         tf.nn.top_k(target)[1]))
        state.exhats = state.exhats.write(state.i, exhat)

    state.i += 1
    return state
Example #15
def run(session, tensors, **run_kwargs):
  # too damn big
  trace = False

  if trace:
    run_metadata = tf.RunMetadata()
    run_kwargs["options"] = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    run_kwargs["run_metadata"] = run_metadata

  values = NS.FlatCall(ft.partial(session.run, **run_kwargs), tensors)
  
  if trace:
    from tensorflow.python.client import timeline
    trace = timeline.Timeline(step_stats=run_metadata.step_stats)
    with open("timeline_%s.json" % "_".join(".".join(map(str, key)) for key in tensors.Keys()), "w") as trace_file:
      trace_file.write(trace.generate_chrome_trace_format())

  return values
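A usage sketch (all names hypothetical): `tensors` is a Namespace tree of fetches, and the returned `values` mirrors its structure with the numerical results.

fetches = NS(loss=loss_tensor, final_state=NS(h=h_tensor))
values = run(session, fetches, feed_dict={x_placeholder: x_batch})
# values.loss and values.final_state.h follow the structure of `fetches`.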
Example #16
 def __call__(self, x, state, context=None):
     state = NS.Copy(state)
     for i, _ in enumerate(self.cells):
         cell_inputs = []
         if i == 0:
             cell_inputs.append(x)
         if context is not None and i == len(self.cells) - 1:
             cell_inputs.append(context)
         if self.hp.vskip:
             # feed in state of all other layers
             cell_inputs.extend(self.cells[j].get_output(state.cells[j])
                                for j in range(len(self.cells)) if j != i)
         else:
             # feed in state of layer below
             if i > 0:
                 cell_inputs.append(self.cells[i - 1].get_output(
                     state.cells[i - 1]))
         state.cells[i] = self.cells[i].transition(cell_inputs,
                                                   state.cells[i],
                                                   scope="cell%i" % i)
     return state
Example #17
    def _make(self, hp):
        ts = NS()
        ts.x = tf.placeholder(dtype=tf.int32, name="x")

        # conditioning graph
        ts.cond = self.model.make_evaluation_graph(x=ts.x)

        # generation graph
        tf.get_variable_scope().reuse_variables()
        ts.initial_xelt = tf.placeholder(dtype=tf.int32,
                                         name="initial_xelt",
                                         shape=[None])
        ts.length = tf.placeholder(dtype=tf.int32, name="length", shape=[])
        ts.temperature = tf.placeholder(dtype=tf.float32,
                                        name="temperature",
                                        shape=[])
        ts.sample = self.model.make_sampling_graph(
            initial_xelt=ts.initial_xelt,
            length=ts.length,
            temperature=ts.temperature)

        return ts
Example #18
    def _make(self, hp, global_step=None):
        ts = NS()
        ts.global_step = global_step
        ts.x = tf.placeholder(dtype=tf.int32, name="x")
        length = hp.segment_length + hp.chunk_size
        ts.seq = self.model.make_training_graph(x=ts.x, length=length)
        ts.final_state = ts.seq.final_state
        ts.loss = ts.seq.loss
        ts.error = ts.seq.error

        ts.learning_rate = tf.Variable(hp.initial_learning_rate,
                                       dtype=tf.float32,
                                       trainable=False,
                                       name="learning_rate")
        ts.decay_op = tf.assign(ts.learning_rate,
                                ts.learning_rate * hp.decay_rate)
        ts.optimizer = tf.train.AdamOptimizer(ts.learning_rate)
        ts.params = tf.trainable_variables()
        print([param.name for param in ts.params])

        ts.gradients = tf.gradients(ts.loss, ts.params)

        loose_params = [
            param for param, gradient in util.equizip(ts.params, ts.gradients)
            if gradient is None
        ]
        if loose_params:
            raise ValueError("loose parameters: %s" %
                             " ".join(param.name for param in loose_params))

        # tensorflow fails miserably to compute gradient for these
        for reg_var in tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES):
            ts.gradients[ts.params.index(reg_var)] += (
                hp.weight_decay *
                tf.gradients(tf.sqrt(tf.reduce_sum(reg_var**2)), [reg_var])[0])

        ts.clipped_gradients, _ = tf.clip_by_global_norm(
            ts.gradients, hp.clip_norm)
        ts.training_op = ts.optimizer.apply_gradients(
            util.equizip(ts.clipped_gradients, ts.params),
            global_step=ts.global_step)

        ts.summaries = [
            tf.scalar_summary("loss_train", ts.loss),
            tf.scalar_summary("error_train", ts.error),
            tf.scalar_summary("learning_rate", ts.learning_rate)
        ]
        for parameter, gradient in util.equizip(ts.params, ts.gradients):
            ts.summaries.append(
                tf.scalar_summary("meanlogabs_%s" % parameter.name,
                                  tfutil.meanlogabs(parameter)))
            ts.summaries.append(
                tf.scalar_summary("meanlogabsgrad_%s" % parameter.name,
                                  tfutil.meanlogabs(gradient)))

        return ts
Example #19
def make_transition_graph(state,
                          transition,
                          x=None,
                          context=None,
                          temperature=1.0,
                          hp=None):
    """Make the graph that processes a single sequence element.

  Args:
    state: `_make_sequence_graph` loop state.
    transition: Model transition function mapping (xchunk, model_state,
        context) to (output, new_model_state).
    x: Sequence of integer (categorical) inputs. Axes [time, batch].
    context: Optional Tensor denoting context, shaped [batch, ?].
    temperature: Softmax temperature to use for sampling.
    hp: Model hyperparameters.

  Returns:
    Updated loop state.
  """
    state = NS.Copy(state)

    xchunk = _get_flat_chunk(state.xhats if x is None else x,
                             state.i * hp.chunk_size,
                             hp.chunk_size,
                             depth=hp.data_dim)
    embedding = tfutil.layers([xchunk], sizes=hp.io_sizes, use_bn=hp.use_bn)
    h, state.model = transition(embedding, state.model, context=context)

    # predict the next chunk
    exhats = []
    with tf.variable_scope("xhat") as scope:
        for j in range(hp.chunk_size):
            if j > 0:
                scope.reuse_variables()

            xchunk = _get_flat_chunk(state.xhats if x is None else x,
                                     state.i * hp.chunk_size + j,
                                     hp.chunk_size,
                                     depth=hp.data_dim)
            embedding = tfutil.layers([h, xchunk],
                                      sizes=hp.io_sizes,
                                      use_bn=hp.use_bn)
            exhat = tfutil.project(embedding, output_dim=hp.data_dim)
            exhats.append(exhat)

            state.xhats = state.xhats.write((state.i + 1) * hp.chunk_size + j,
                                            tfutil.sample(exhat, temperature))

    if x is not None:
        targets = tf.unpack(_get_1hot_chunk(x, (state.i + 1) * hp.chunk_size,
                                            hp.chunk_size,
                                            depth=hp.data_dim),
                            num=hp.chunk_size,
                            axis=1)
        state.losses = _put_chunk(state.losses, state.i * hp.chunk_size, [
            tf.nn.softmax_cross_entropy_with_logits(exhat, target)
            for exhat, target in util.equizip(exhats, targets)
        ])
        state.errors = _put_chunk(state.errors, state.i * hp.chunk_size, [
            tf.not_equal(tf.nn.top_k(exhat)[1],
                         tf.nn.top_k(target)[1])
            for exhat, target in util.equizip(exhats, targets)
        ])
        state.exhats = _put_chunk(state.exhats, state.i * hp.chunk_size,
                                  exhats)

    state.i += 1
    return state
Example #20
 def wrapped_branch():
     tree = fn()
     liszt = NS.Flatten(tree)
     return liszt
Example #21
def _make_sequence_graph(transition=None,
                         model_state=None,
                         x=None,
                         initial_xelt=None,
                         context=None,
                         length=None,
                         temperature=1.0,
                         hp=None,
                         back_prop=False):
    """Construct the graph to process a sequence of categorical integers.

  If `x` is given, the graph processes the sequence `x` one element at a time.  At step `i`, the
  model receives the `i`th element as input, and its output is used to predict the `i + 1`th
  element.

  The last element is not processed, as there would be no further element available to compare
  against and compute loss. To ensure all data is processed during TBPTT, segments `x` fed into
  successive computations of the graph should overlap by 1.

  If `x` is not given, `initial_xelt` must be given as the first input to the model.  Further
  elements are constructed from the model's predictions.

  Args:
    transition: model transition function mapping (xelt, model_state,
        context) to (output, new_model_state).
    model_state: initial state of the model.
    x: Sequence of integer (categorical) inputs. Not needed if sampling.
        Axes [time, batch].
    initial_xelt: When sampling, x is not given; initial_xelt specifies
        the input x[0] to the first timestep.
    context: a `Tensor` denoting context, e.g. for conditioning.
    length: Optional length of sequence. Inferred from `x` if possible.
    temperature: Softmax temperature to use for sampling.
    hp: Model hyperparameters.
    back_prop: Whether the graph will be backpropagated through.

  Returns:
    Namespace containing relevant symbolic variables.
  """
    with tf.variable_scope("seq") as scope:
        # if the caching device is not set explicitly, set it such that the
        # variables for the RNN are all cached locally.
        if scope.caching_device is None:
            scope.set_caching_device(lambda op: op.device)

        if length is None:
            length = tf.shape(x)[0]

        def _make_ta(name, **kwargs):
            # infer_shape=False because it is too strict; it considers unknown
            # dimensions to be incompatible with anything else. Effectively that
            # requires all shapes to be fully defined at graph construction time.
            return tf.TensorArray(tensor_array_name=name,
                                  infer_shape=False,
                                  **kwargs)

        state = NS(i=tf.constant(0), model=model_state)

        state.xhats = _make_ta("xhats",
                               dtype=tf.int32,
                               size=length,
                               clear_after_read=False)
        state.xhats = state.xhats.write(0,
                                        initial_xelt if x is None else x[0, :])

        state.exhats = _make_ta("exhats",
                                dtype=tf.float32,
                                size=length - LEFTOVER)

        if x is not None:
            state.losses = _make_ta("losses",
                                    dtype=tf.float32,
                                    size=length - LEFTOVER)
            state.errors = _make_ta("errors",
                                    dtype=tf.bool,
                                    size=length - LEFTOVER)

        state = tfutil.while_loop(
            cond=lambda state: state.i < length - LEFTOVER,
            body=ft.partial(make_transition_graph,
                            transition=transition,
                            x=x,
                            context=context,
                            temperature=temperature,
                            hp=hp),
            loop_vars=state,
            back_prop=back_prop)

        # pack TensorArrays
        for key in "exhats xhats losses errors".split():
            if key in state:
                state[key] = state[key].pack()

        ts = NS()
        ts.final_state = state
        ts.xhat = state.xhats[1:, :]
        ts.final_xhatelt = state.xhats[length - 1, :]
        if x is not None:
            ts.loss = tf.reduce_mean(state.losses)
            ts.error = tf.reduce_mean(tf.to_float(state.errors))
            ts.final_x = x
            # expose the final, unprocessed element of x for convenience
            ts.final_xelt = x[length - 1, :]
        return ts
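To illustrate the overlap-by-1 convention described in the docstring, here is a plain NumPy sketch (hypothetical toy sequence and segment length) of how successive segments would be cut so that every processed element has a target:

x = np.arange(10)  # toy sequence of categorical ids
segment_length = 4
segments = [x[start:start + segment_length]
            for start in range(0, len(x) - 1, segment_length - 1)]
# segments -> [0 1 2 3], [3 4 5 6], [6 7 8 9]; the last element of each
# segment is reused as the first element of the next.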
Example #22
    def testMisc(self):
        ns = NS()
        ns.w = 0
        ns["x"] = 3
        ns.x = 1
        ns.y = NS(z=2)
        self.assertEqual(list(ns.Keys()), [("w", ), ("x", ), ("y", "z")])
        self.assertEqual(list(ns.Values()), [0, 1, 2])
        self.assertEqual(list(ns.Items()), [(("w", ), 0), (("x", ), 1),
                                            (("y", "z"), 2)])
        self.assertEqual(ns.AsDict(),
                         OrderedDict([("w", 0), ("x", 1), ("y", NS(z=2))]))
        ns.Update(ns.y)
        self.assertEqual(list(ns), [("w", ), ("x", ), ("y", "z"), ("z", )])
        self.assertEqual(list(ns.Keys()), [("w", ), ("x", ), ("y", "z"),
                                           ("z", )])
        self.assertEqual(list(ns.Values()), [0, 1, 2, 2])
        self.assertEqual(list(ns.Items()), [(("w", ), 0), (("x", ), 1),
                                            (("y", "z"), 2), (("z", ), 2)])
        self.assertEqual(
            ns.AsDict(),
            OrderedDict([("w", 0), ("x", 1), ("y", NS(z=2)), ("z", 2)]))

        ns = NS(v=2, w=NS(x=1, y=[3, NS(z=0)]))
        self.assertItemsEqual([("v", ), ("w", "x"), ("w", "y", 0),
                               ("w", "y", 1, "z")], list(ns.Keys()))
Example #23
    def _make_sequence_graph_with_unroll(self,
                                         model_state=None,
                                         x=None,
                                         initial_xelt=None,
                                         context=None,
                                         length=None,
                                         temperature=1.0,
                                         hp=None,
                                         back_prop=False):
        """Create a sequence graph by unrolling upper layers.

    This method is similar to `_make_sequence_graph`, except that `length` must be provided. The
    resulting graph behaves in the same way as that constructed by `_make_sequence_graph`, except
    that the upper layers are outside of the while loop and so the gradient can actually be
    truncated between runs of lower layers.

    If `x` is given, the graph processes the sequence `x` one element at a time.  At step `i`, the
    model receives the `i`th element as input, and its output is used to predict the `i + 1`th
    element.

    The last element is not processed, as there would be no further element available to compare
    against and compute loss. To ensure all data is processed during TBPTT, segments `x` fed into
    successive computations of the graph should overlap by 1.

    If `x` is not given, `initial_xelt` must be given as the first input to the model.  Further
    elements are constructed from the model's predictions.

    Args:
      model_state: initial state of the model.
      x: Sequence of integer (categorical) inputs. Not needed if sampling.
          Axes [time, batch].
      initial_xelt: When sampling, x is not given; initial_xelt specifies
          the input x[0] to the first timestep.
      context: a `Tensor` denoting context, e.g. for conditioning.
          Axes [batch, features].
      length: Optional length of sequence. Inferred from `x` if possible.
      temperature: Softmax temperature to use for sampling.
      hp: Model hyperparameters.
      back_prop: Whether the graph will be backpropagated through.

    Raises:
      ValueError: if `length` is not an int.

    Returns:
      Namespace containing relevant symbolic variables.
    """
        if length is None or not isinstance(length, int):
            raise ValueError(
                "For partial unrolling, length must be known at graph construction time."
            )

        if model_state is None:
            model_state = self.state_placeholders()

        state = NS(model=model_state,
                   inner_initial_xelt=initial_xelt,
                   xhats=[],
                   losses=[],
                   errors=[])

        # i suspect ugly gradient biases may occur if gradients are truncated
        # somewhere halfway through the cycle. ensure we start at a cycle boundary.
        state.model.time = tfutil.assertion(state.model.time,
                                            tf.equal(state.model.time, 0),
                                            [state.model.time],
                                            name="outer_alignment_assertion")
        # ensure we end at a cycle boundary too.
        assert (length - LEFTOVER) % self.period == 0

        inner_period = int(np.prod(hp.periods[:self.outer_indices[0] + 1]))

        # hp.boundaries specifies truncation boundaries relative to the end of the sequence and in terms
        # of each layer's own steps; translate this to be relative to the beginning of the sequence and
        # in terms of sequence elements. note that due to the dynamic unrolling of the inner graph, the
        # inner layers necessarily get truncated at the topmost inner layer's boundary.
        boundaries = [
            length - 1 - hp.boundaries[i] * int(np.prod(hp.periods[:i + 1]))
            for i in range(len(hp.periods))
        ]
        assert all(0 <= boundary and boundary < length - LEFTOVER
                   for boundary in boundaries)
        assert boundaries == list(reversed(sorted(boundaries)))

        print "length %s periods %s boundaries %s %s inner period %s" % (
            length, hp.periods, hp.boundaries, boundaries, inner_period)

        outer_step_count = length // inner_period
        for outer_time in range(outer_step_count):
            if outer_time > 0:
                tf.get_variable_scope().reuse_variables()

            # update outer layers (wrap in seq scope to be consistent with the fully
            # symbolic version of this graph)
            with tf.variable_scope("seq"):
                # truncate gradient (only effective on outer layers)
                for i in range(len(self.cells)):
                    if outer_time * inner_period <= boundaries[i]:
                        state.model.cells[i] = list(
                            map(tf.stop_gradient, state.model.cells[i]))

                state.model.cells = Wayback.transition(
                    outer_time * inner_period,
                    state.model.cells,
                    self.cells,
                    below=None,
                    above=context,
                    subset=self.outer_indices,
                    hp=hp,
                    symbolic=False)

            # run inner layers on subsequence
            if x is None:
                inner_x = None
            else:
                start = inner_period * outer_time
                stop = inner_period * (outer_time + 1) + LEFTOVER
                inner_x = x[start:stop, :]

            # grab a copy of the outer states. they will not be updated in the inner
            # loop, so we can put back the copy after the inner loop completes.
            # this avoids the gradient truncation due to calling `while_loop` with
            # `back_prop=False`.
            outer_cell_states = NS.Copy(state.model.cells[self.outer_slice])

            def _inner_transition(input_, state, context=None):
                assert not context
                state.cells = Wayback.transition(state.time,
                                                 state.cells,
                                                 self.cells,
                                                 below=input_,
                                                 above=None,
                                                 subset=self.inner_indices,
                                                 hp=hp,
                                                 symbolic=True)
                state.time += 1
                state.time %= self.period
                h = self.get_output(state)
                return h, state

            inner_back_prop = back_prop and outer_time * inner_period >= boundaries[
                self.inner_indices[-1]]
            inner_ts = _make_sequence_graph(
                transition=_inner_transition,
                model_state=state.model,
                x=inner_x,
                initial_xelt=state.inner_initial_xelt,
                temperature=temperature,
                hp=hp,
                back_prop=inner_back_prop)

            state.model = inner_ts.final_state.model
            state.inner_initial_xelt = inner_ts.final_xelt if x is not None else inner_ts.final_xhatelt
            state.final_xhatelt = inner_ts.final_xhatelt
            if x is not None:
                state.final_x = inner_x
                state.final_xelt = inner_ts.final_xelt
                # track only losses and errors after the boundary to avoid bypassing the truncation boundary.
                if inner_back_prop:
                    state.losses.append(inner_ts.loss)
                    state.errors.append(inner_ts.error)
            state.xhats.append(inner_ts.xhat)

            # restore static outer states
            state.model.cells[self.outer_slice] = outer_cell_states

            # double check alignment to be safe
            state.model.time = tfutil.assertion(
                state.model.time,
                tf.equal(state.model.time % inner_period, 0),
                [state.model.time, tf.shape(inner_x)],
                name="inner_alignment_assertion")

        ts = NS()
        ts.xhat = tf.concat(0, state.xhats)
        ts.final_xhatelt = state.final_xhatelt
        ts.final_state = state
        if x is not None:
            ts.final_x = state.final_x
            ts.final_xelt = state.final_xelt
            # inner means are all on the same sample size, so taking their mean is valid
            ts.loss = tf.reduce_mean(state.losses)
            ts.error = tf.reduce_mean(state.errors)
        return ts
Example #24
 def testCopy(self):
     before = NS(v=2, w=NS(x=1, y=NS(z=0)))
     after = NS.Copy(before)
     self.assertEqual(before, after)
     self.assertTrue(
         all(a is b for a, b in zip(NS.Flatten(after), NS.Flatten(before))))
Example #25
 def _body(*flat_vars):
     return NS.Flatten(body(NS.UnflattenLike(loop_vars, flat_vars)))
Example #26
 def testFlattenUnflatten(self):
     before = NS(v=2, w=NS(x=1, y=NS(z=0)))
     flat = NS.Flatten(before)
     after = NS.UnflattenLike(before, flat)
     self.assertEqual(before, after)
Example #27
 def _cond(*flat_vars):
     return cond(NS.UnflattenLike(loop_vars, flat_vars))