Example #1
    def run(self, session, primers, length, temperature, hp=None):
        batch_size = len(primers)
        # process in segments to avoid tensorflow eating all the memory
        max_segment_length = min(10000, hp.segment_length)

        print "conditioning..."
        segment_length = min(max_segment_length,
                             max(len(primer[0]) for primer in primers))
        # ensure segment_length is a multiple of chunk_size
        segment_length -= segment_length % hp.chunk_size

        state = NS(model=self.model.initial_state(batch_size))
        for segment in util.segments(primers,
                                     segment_length,
                                     overlap=hp.chunk_size):
            x, = util.examples_as_arrays(segment)
            feed_dict = {self.tensors.x: x.T}
            feed_dict.update(self.model.feed_dict(state.model))
            values = NS.FlatCall(
                ft.partial(session.run, feed_dict=feed_dict),
                self.tensors.cond.Extract("final_state.model final_xchunk"))
            state.model = values.final_state.model
            sys.stderr.write(".")
        sys.stderr.write("\n")

        cond_values = values

        # make sure length is a multiple of chunk_size
        chunky_length = length + hp.chunk_size - length % hp.chunk_size

        print "sampling..."
        length_left = chunky_length
        xhats = []
        state = NS(model=cond_values.final_state.model,
                   initial_xchunk=cond_values.final_xchunk)
        while length_left > 0:
            segment_length = min(max_segment_length, length_left)
            length_left -= segment_length

            feed_dict = {
                self.tensors.initial_xchunk: state.initial_xchunk,
                self.tensors.length: segment_length,
                self.tensors.temperature: temperature
            }
            feed_dict.update(self.model.feed_dict(state.model))
            sample_values = NS.FlatCall(
                ft.partial(session.run, feed_dict=feed_dict),
                self.tensors.sample.Extract(
                    "final_state.model xhat final_xhatchunk"))
            state.model = sample_values.final_state.model
            state.initial_xchunk = sample_values.final_xhatchunk

            xhats.append(sample_values.xhat)
            sys.stderr.write(".")
        sys.stderr.write("\n")

        xhat = np.concatenate(xhats, axis=0)
        # truncate from chunky_length to the desired sample length
        xhat = xhat[:length]
        return xhat.T
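The chunky_length computation above rounds the requested length up to a whole number of chunks (adding a full extra chunk when length is already a multiple), and the sampled xhat is cut back to length at the end. A minimal sketch of that arithmetic with made-up numbers:

    # hypothetical chunk_size/length values, illustrating the chunky_length arithmetic above
    chunk_size = 8
    for length in (100, 104):
        chunky_length = length + chunk_size - length % chunk_size
        # 100 -> 104, 104 -> 112: always a multiple of chunk_size and never shorter
        assert chunky_length % chunk_size == 0 and chunky_length >= length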
Example #2
 def testGet(self):
     ns = NS(foo=NS(bar="baz"))
     self.assertRaises(KeyError, lambda: ns["foo"]["baz"])
     self.assertIsNone(ns.Get("foo.baz"))
     x = object()
     self.assertEqual(x, ns.Get("foo.baz", x))
     self.assertEqual("baz", ns.Get("foo.bar"))
     self.assertEqual(NS(bar="baz"), ns.Get("foo"))
Example #3
    def run(self,
            session,
            examples,
            max_step_count=None,
            hp=None,
            aggregates=None):
        aggregates = NS(aggregates or {})
        for key in "loss error".split():
            if key not in aggregates:
                aggregates[key] = util.MeanAggregate()

        tensors = self.tensors.Extract(*[key for key in aggregates.Keys()])
        tensors.Update(self.tensors.Extract("final_state.model"))

        state = NS(step=0, model=self.model.initial_state(hp.batch_size))

        try:
            for batch in util.batches(examples, hp.batch_size):
                for segment in util.segments(batch,
                                             hp.segment_length,
                                             overlap=hp.chunk_size):
                    if max_step_count is not None and state.step >= max_step_count:
                        raise StopIteration()

                    x, = util.examples_as_arrays(segment)
                    feed_dict = {self.tensors.x: x.T}
                    feed_dict.update(self.model.feed_dict(state.model))
                    values = NS.FlatCall(
                        ft.partial(session.run, feed_dict=feed_dict), tensors)

                    for key in aggregates.Keys():
                        aggregates[key].add(values[key])

                    sys.stderr.write(".")
                    state.model = values.final_state.model
                    state.step += 1
        except StopIteration:
            pass

        sys.stderr.write("\n")

        values = NS(
            (key, aggregate.value) for key, aggregate in aggregates.Items())

        values.summaries = [
            tf.Summary.Value(tag="%s_valid" % key, simple_value=values[key])
            for key in "loss error".split()
        ]
        print "### evaluation loss %6.5f error %6.5f" % (values.loss,
                                                         values.error)

        if np.isnan(values.loss):
            raise ValueError("loss has become NaN")

        return values
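Each aggregate only needs an add() method and a value attribute; that is how the evaluation loop above uses them. A hedged stand-in illustrating that interface (not the project's util.MeanAggregate):

    class MeanAggregateSketch(object):
        """Stand-in with the add()/value interface the evaluation loop relies on."""
        def __init__(self):
            self.total, self.count = 0.0, 0

        def add(self, x):
            self.total += float(x)
            self.count += 1

        @property
        def value(self):
            return self.total / max(self.count, 1)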
Example #4
    def run(self, session, primers, length, temperature, hp=None):
        batch_size = len(primers)
        # process in segments to avoid tensorflow eating all the memory
        max_segment_length = min(10000, hp.segment_length)

        print "conditioning..."
        segment_length = min(max_segment_length,
                             max(len(primer[0]) for primer in primers))

        state = NS(model=self.model.initial_state(batch_size))
        for segment in util.segments(primers, segment_length,
                                     overlap=LEFTOVER):
            x, = util.examples_as_arrays(segment)
            feed_dict = {self.tensors.x: x.T}
            feed_dict.update(self.model.feed_dict(state.model))
            values = tfutil.run(session,
                                tensors=self.tensors.cond.Extract(
                                    "final_state.model final_xelt"),
                                feed_dict=feed_dict)
            state.model = values.final_state.model
            sys.stderr.write(".")
        sys.stderr.write("\n")

        cond_values = values

        print "sampling..."
        length_left = length + LEFTOVER
        xhats = []
        state = NS(model=cond_values.final_state.model,
                   initial_xelt=cond_values.final_xelt)
        while length_left > 0:
            segment_length = min(max_segment_length, length_left)
            length_left -= segment_length

            feed_dict = {
                self.tensors.initial_xelt: state.initial_xelt,
                self.tensors.length: segment_length,
                self.tensors.temperature: temperature
            }
            feed_dict.update(self.model.feed_dict(state.model))
            sample_values = tfutil.run(
                session,
                tensors=self.tensors.sample.Extract(
                    "final_state.model xhat final_xhatelt"),
                feed_dict=feed_dict)
            state.model = sample_values.final_state.model
            state.initial_xelt = sample_values.final_xhatelt

            xhats.append(sample_values.xhat)
            sys.stderr.write(".")
        sys.stderr.write("\n")

        xhat = np.concatenate(xhats, axis=0)
        return xhat.T
Example #5
    def test_parse(self):
        self.assertEqual(
            hyperparameters.parse_value(
                """{a: 1, b: [2e0, 3., four], c: {d: "five", "e": False}}"""),
            NS(a=1, b=[2., 3., "four"], c=NS(d="five", e=False)))

        self.assertRaises(
            hyperparameters.ParseError,
            ft.partial(hyperparameters.parse_value, """{a:1, b: [fn()]}"""))
        self.assertRaises(
            hyperparameters.ParseError,
            ft.partial(hyperparameters.parse_value, """{a:1, b: dict(c=2)}"""))
Example #6
def get_defaults(**overrides):
    """Get default hyperparameters.

  Args:
    **overrides: overrides for a subset of hyperparameters.

  Raises:
    ValueError: If an override refers to a nonexistent hyperparameter or the
                specified value is of a different type than the default value.

  Returns:
    A Namespace with (possibly overridden) defaults.
  """
    hp = NS((name, hyperparameter.default)
            for name, hyperparameter in schema.AsDict().items())
    for name, value in overrides.items():
        if name not in hp:
            raise ValueError(
                "value provided for nonexistent hyperparameter %s" % name)
        # TODO(cotim): deep typecheck
        if type(value) is not type(hp[name]):
            raise ValueError(
                "value %s (%s) provided for hyperparameter %s does not"
                " match type of default %s (%s)" %
                (value, type(value), name, hp[name], type(hp[name])))
        hp[name] = value
    return hp
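A short usage sketch of get_defaults, assuming batch_size is a schema entry with an integer default (see Example #21 below); the failing calls illustrate the two ValueError paths:

    # usage sketch; assumes batch_size is a schema entry with an int default (Example #21)
    hp = get_defaults(batch_size=50)
    assert hp.batch_size == 50
    try:
        get_defaults(batch_sz=50)        # unknown name -> ValueError
    except ValueError:
        pass
    try:
        get_defaults(batch_size="50")    # str vs int default -> ValueError
    except ValueError:
        pass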
Example #7
    def _make(self, hp, global_step=None):
        ts = NS()
        ts.global_step = global_step
        ts.x = tf.placeholder(dtype=tf.int32, name="x")
        ts.seq = self.model.make_training_graph(x=ts.x,
                                                length=self.segment_length)
        ts.final_state = ts.seq.final_state
        ts.loss = ts.seq.loss
        ts.error = ts.seq.error

        ts.learning_rate = tf.Variable(hp.initial_learning_rate,
                                       dtype=tf.float32,
                                       trainable=False,
                                       name="learning_rate")
        ts.decay_op = tf.assign(ts.learning_rate,
                                ts.learning_rate * hp.decay_rate)
        ts.optimizer = tf.train.AdamOptimizer(ts.learning_rate)
        ts.params = tf.trainable_variables()
        print [param.name for param in ts.params]

        ts.gradients = tf.gradients(
            ts.loss,
            ts.params,
            # secret memory-conserving sauce
            aggregation_method=tf.AggregationMethod.EXPERIMENTAL_TREE)

        loose_params = [
            param for param, gradient in util.equizip(ts.params, ts.gradients)
            if gradient is None
        ]
        if loose_params:
            raise ValueError("loose parameters: %s" %
                             " ".join(param.name for param in loose_params))

        # tensorflow fails miserably to compute gradient for these
        for reg_var in tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES):
            ts.gradients[ts.params.index(reg_var)] += (
                hp.weight_decay *
                tf.gradients(tf.sqrt(tf.reduce_sum(reg_var**2)), [reg_var])[0])

        ts.clipped_gradients, _ = tf.clip_by_global_norm(
            ts.gradients, hp.clip_norm)
        ts.training_op = ts.optimizer.apply_gradients(
            util.equizip(ts.clipped_gradients, ts.params),
            global_step=ts.global_step)

        ts.summaries = [
            tf.summary.scalar("loss_train", ts.loss),
            tf.summary.scalar("error_train", ts.error),
            tf.summary.scalar("learning_rate", ts.learning_rate)
        ]
        for parameter, gradient in util.equizip(ts.params, ts.gradients):
            ts.summaries.append(
                tf.summary.scalar("meanlogabs_%s" % parameter.name,
                                  tfutil.meanlogabs(parameter)))
            ts.summaries.append(
                tf.summary.scalar("meanlogabsgrad_%s" % parameter.name,
                                  tfutil.meanlogabs(gradient)))

        return ts
Example #8
    def run(self, session, examples, max_step_count=None, hooks=None, hp=None):
        tensors = self.tensors.Extract(
            "loss error summaries global_step training_op learning_rate final_state.model"
        )
        state = NS(global_step=tf.train.global_step(session,
                                                    self.tensors.global_step),
                   model=self.model.initial_state(hp.batch_size))
        while True:
            for batch in util.batches(examples, hp.batch_size):
                for segment in util.segments(batch,
                                             self.segment_length,
                                             overlap=LEFTOVER):
                    if max_step_count is not None and state.global_step >= max_step_count:
                        return

                    hooks.Get("step.before", util.noop)(state)
                    x, = util.examples_as_arrays(segment)
                    feed_dict = {self.tensors.x: x.T}
                    feed_dict.update(self.model.feed_dict(state.model))
                    values = tfutil.run(session, tensors, feed_dict=feed_dict)
                    state.model = values.final_state.model
                    state.global_step = values.global_step
                    hooks.Get("step.after", util.noop)(state, values)

                    print("step #%d loss %f error %f learning rate %e" %
                          (values.global_step, values.loss, values.error,
                           values.learning_rate))

                    if np.isnan(values.loss):
                        raise ValueError("loss has become NaN")
Example #9
 def _make(self, unused_hp):
     ts = NS()
     ts.x = tf.placeholder(dtype=tf.int32, name="x")
     ts.seq = self.model.make_evaluation_graph(x=ts.x)
     ts.final_state = ts.seq.final_state
     ts.loss = ts.seq.loss
     ts.error = ts.seq.error
     return ts
Example #10
    def testMisc(self):
        ns = NS()
        ns.w = 0
        ns["x"] = 3
        ns.x = 1
        ns.y = NS(z=2)
        self.assertEqual(list(ns.Keys()), [("w", ), ("x", ), ("y", "z")])
        self.assertEqual(list(ns.Values()), [0, 1, 2])
        self.assertEqual(list(ns.Items()), [(("w", ), 0), (("x", ), 1),
                                            (("y", "z"), 2)])
        self.assertEqual(ns.AsDict(),
                         OrderedDict([("w", 0), ("x", 1), ("y", NS(z=2))]))
        ns.Update(ns.y)
        self.assertEqual(list(ns), [("w", ), ("x", ), ("y", "z"), ("z", )])
        self.assertEqual(list(ns.Keys()), [("w", ), ("x", ), ("y", "z"),
                                           ("z", )])
        self.assertEqual(list(ns.Values()), [0, 1, 2, 2])
        self.assertEqual(list(ns.Items()), [(("w", ), 0), (("x", ), 1),
                                            (("y", "z"), 2), (("z", ), 2)])
        self.assertEqual(
            ns.AsDict(),
            OrderedDict([("w", 0), ("x", 1), ("y", NS(z=2)), ("z", 2)]))

        ns = NS(v=2, w=NS(x=1, y=[3, NS(z=0)]))
        self.assertItemsEqual([("v", ), ("w", "x"), ("w", "y", 0),
                               ("w", "y", 1, "z")], list(ns.Keys()))
Example #11
    def __init__(self, cells_, hp):
        """Initialize a `Stack` instance.

    Args:
      cells_: recurrent transition cells, from bottom to top.
      hp: model hyperparameters.
    """
        super(Stack, self).__init__(hp)
        self.cells = list(cells_)
        self._state_placeholders = NS(
            cells=[cell.state_placeholders for cell in self.cells])
Example #12
 def testFlatCallFlatZip(self):
     before = NS(v=2, w=NS(x=1, y=NS(z=0)))
     after = NS.FlatCall(lambda xs: [2 * x for x in xs], before)
     self.assertEqual(NS(v=4, w=NS(x=2, y=NS(z=0))), after)
     self.assertItemsEqual([(2, 4), (1, 2), (0, 0)],
                           list(NS.FlatZip([before, after])))
     after.w.y.a = 6
     self.assertRaises(ValueError, lambda: NS.FlatZip([before, after]))
     self.assertItemsEqual([(2, 4), (0, 0)],
                           list(NS.FlatZip([before, after], "v w.y.z")))
Example #13
    def __init__(self, cells_, hp):
        """Initialize a `Wayback` instance.

    The following hyperparameters are specific to this model:
      periods: update interval of each layer, from bottom to top. As layer 0
          always runs at every step, periods[0] gives the number of steps
          of layer 0 before layer 1 is updated. periods[-1] gives the
          number of steps to run at the highest layer before the model
          should be considered to have completed a cycle.
      unroll_layer_count: number of upper layers to unroll. Unrolling allows
          for gradient truncation on the levels below.
      carry: whether to carry over each cell's state from one cycle to the next
          or break the chain and compute new initial states based on the state
          of the cell above.

    Args:
      cells_: recurrent transition cells, from bottom to top.
      hp: model hyperparameters.

    Raises:
      ValueError: If the number of cells and the number of periods differ.
    """
        super(Wayback, self).__init__(hp)

        if len(self.hp.periods) != len(cells_):
            raise ValueError("must specify one period for each cell")
        if len(self.hp.boundaries) != len(cells_):
            raise ValueError("must specify one boundary for each cell")
        self.cells = list(cells_)

        cutoff = len(cells_) - self.hp.unroll_layer_count
        self.inner_indices = list(range(cutoff))
        self.outer_indices = list(range(cutoff, len(cells_)))
        self.inner_slice = slice(cutoff)
        self.outer_slice = slice(cutoff, len(cells_))

        self._state_placeholders = NS(
            time=tf.placeholder(dtype=tf.int32, name="time"),
            cells=[cell.state_placeholders for cell in self.cells])
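Per the docstring, each entry of periods gives how many steps of the layer below pass between updates of the layer above, so one full model cycle spans the product of all periods. A worked sketch with made-up values:

    import numpy as np
    # made-up periods, ordered bottom to top as in the schema (Example #21)
    periods = [4, 2, 3]
    # layer 1 ticks every 4 bottom steps, layer 2 every 4 * 2 = 8 bottom steps,
    # and one full cycle spans prod(periods) = 24 bottom-layer steps
    steps_per_cycle = int(np.prod(periods))
    assert steps_per_cycle == 24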
Example #14
    def _make(self, hp):
        ts = NS()
        ts.x = tf.placeholder(dtype=tf.int32, name="x")

        # conditioning graph
        ts.cond = self.model.make_evaluation_graph(x=ts.x)

        # generation graph
        tf.get_variable_scope().reuse_variables()
        ts.initial_xelt = tf.placeholder(dtype=tf.int32,
                                         name="initial_xelt",
                                         shape=[None])
        ts.length = tf.placeholder(dtype=tf.int32, name="length", shape=[])
        ts.temperature = tf.placeholder(dtype=tf.float32,
                                        name="temperature",
                                        shape=[])
        ts.sample = self.model.make_sampling_graph(
            initial_xelt=ts.initial_xelt,
            length=ts.length,
            temperature=ts.temperature)

        return ts
Example #15
def parse_value(expr):
    """Parse a hyperparameter value.

  A value can be any Python literal. Barewords are converted to strings.
  Dictionaries are converted to Namespaces.

  Args:
    expr: value expression as a string or `ast.expr`

  Raises:
    ParseError: if `expr` is not a literal expression.

  Returns:
    The value represented by `expr`.
  """
    if isinstance(expr, basestring):
        expr = ast.parse(expr).body[0].value

    if isinstance(expr, ast.Num):
        return expr.n
    elif isinstance(expr, ast.Str):
        return expr.s
    elif isinstance(expr, ast.Name):
        try:
            # True/False are represented as Names -_-
            return ast.literal_eval(expr.id)
        except ValueError:
            # interpret as string
            return expr.id
    elif isinstance(expr, ast.List):
        return list(map(parse_value, expr.elts))
    elif isinstance(expr, ast.Tuple):
        return tuple(map(parse_value, expr.elts))
    elif isinstance(expr, ast.Dict):
        return NS((parse_key(key), parse_value(value))
                  for key, value in zip(expr.keys, expr.values))
    else:
        raise ParseError("invalid value", expr)
Example #16
    def run(self, session, examples, max_step_count=None, hooks=None, hp=None):
        state = NS(global_step=tf.train.global_step(session,
                                                    self.tensors.global_step),
                   model=self.model.initial_state(hp.batch_size))
        while True:
            for batch in util.batches(examples, hp.batch_size):
                for segment in util.segments(
                        batch,
                        # the last chunk is not processed, so grab
                        # one more to ensure we backpropagate
                        # through at least one full model cycle.
                        # TODO(cotim): rename segment_length to
                        # backprop_length?
                        hp.segment_length + hp.chunk_size,
                        overlap=hp.chunk_size):
                    if max_step_count is not None and state.global_step >= max_step_count:
                        return

                    hooks.Get("step.before", util.noop)(state)
                    x, = util.examples_as_arrays(segment)
                    feed_dict = {self.tensors.x: x.T}
                    feed_dict.update(self.model.feed_dict(state.model))
                    values = NS.FlatCall(
                        ft.partial(session.run, feed_dict=feed_dict),
                        self.tensors.Extract(
                            "loss error summaries global_step training_op learning_rate final_state.model"
                        ))
                    state.model = values.final_state.model
                    state.global_step = values.global_step
                    hooks.Get("step.after", util.noop)(state, values)

                    print("step #%d loss %f error %f learning rate %e" %
                          (values.global_step, values.loss, values.error,
                           values.learning_rate))

                    if np.isnan(values.loss):
                        raise ValueError("loss has become NaN")
Example #17
def main(argv):
    assert not argv[1:]

    hp = hyperparameters.parse(FLAGS.hyperparameters)

    print "loading data from %s" % FLAGS.data_dir
    dataset = datasets.construct(FLAGS.data_type,
                                 directory=FLAGS.data_dir,
                                 frequency=hp.sampling_frequency,
                                 bit_depth=hp.bit_depth)
    print "done"
    hp.data_dim = dataset.data_dim

    model_name = get_model_name(hp)
    print model_name
    output_dir = os.path.join(FLAGS.base_output_dir, model_name)

    if not FLAGS.resume:
        if tf.gfile.Exists(output_dir):
            tf.gfile.DeleteRecursively(output_dir)
    if not tf.gfile.Exists(output_dir):
        tf.gfile.MakeDirs(output_dir)

    hyperparameters.dump(os.path.join(output_dir, "hyperparameters.yaml"), hp)

    model = models.construct(hp)

    print "constructing graph..."
    global_step = tf.Variable(0, trainable=False, name="global_step")
    trainer = training.Trainer(model, hp=hp, global_step=global_step)
    tf.get_variable_scope().reuse_variables()
    evaluator = evaluation.Evaluator(model, hp=hp)
    print "done"

    best_saver = tf.train.Saver()
    supervisor = tf.train.Supervisor(logdir=output_dir, summary_op=None)
    session = supervisor.PrepareSession()

    tracking = NS(best_loss=None, reset_time=0)

    def track(loss, step):
        if step % FLAGS.tracking_interval == 0:
            if tracking.best_loss is None or loss < tracking.best_loss:
                tracking.best_loss = loss
                tracking.reset_time = step
                best_saver.save(session,
                                os.path.join(
                                    os.path.dirname(supervisor.save_path),
                                    "best_%i_%s.ckpt" % (step, loss)),
                                global_step=supervisor.global_step)
            elif step - tracking.reset_time > hp.decay_patience:
                session.run(trainer.tensors.decay_op)
                tracking.reset_time = step

    def maybe_validate(state):
        if state.global_step % FLAGS.validation_interval == 0:
            aggregates = {}
            if FLAGS.dump_predictions:
                # extract final exhats and losses for debugging
                aggregates.update(
                    (key, util.LastAggregate()) for key in
                    "seq.final_x final_state.exhats final_state.losses".split(
                    ))
            values = evaluator.run(examples=dataset.examples.valid,
                                   session=session,
                                   hp=hp,
                                   aggregates=aggregates,
                                   max_step_count=FLAGS.max_validation_steps)
            supervisor.summary_computed(session,
                                        tf.Summary(value=values.summaries))
            if FLAGS.dump_predictions:
                np.savez_compressed(
                    os.path.join(os.path.dirname(supervisor.save_path),
                                 "xhats_%i.npz" % state.global_step),
                    # i'm sure we'll get the idea from 100 steps of 10 examples
                    xs=values.seq.final_x[:100, :10],
                    exhats=values.final_state.exhats[:100, :10],
                    losses=values.final_state.losses[:100, :10])
            # track validation loss
            track(values.loss, state.global_step)

    def maybe_stop(_):
        if supervisor.ShouldStop():
            raise StopTraining()

    def before_step_hook(state):
        maybe_validate(state)
        maybe_stop(state)

    def after_step_hook(state, values):
        for summary in values.summaries:
            supervisor.summary_computed(session, summary)
        # track training loss
        #track(values.loss, state.global_step)

    print "training."
    try:
        trainer.run(
            examples=dataset.examples.train[:FLAGS.max_examples],
            session=session,
            hp=hp,
            max_step_count=FLAGS.max_step_count,
            hooks=NS(step=NS(before=before_step_hook, after=after_step_hook)))
    except StopTraining:
        pass
Example #18
def _make_sequence_graph(transition=None,
                         model_state=None,
                         x=None,
                         initial_xelt=None,
                         context=None,
                         length=None,
                         temperature=1.0,
                         hp=None,
                         back_prop=False):
    """Construct the graph to process a sequence of categorical integers.

  If `x` is given, the graph processes the sequence `x` one element at a time.  At step `i`, the
  model receives the `i`th element as input, and its output is used to predict the `i + 1`th
  element.

  The last element is not processed, as there would be no further element available to compare
  against and compute loss. To ensure all data is processed during TBPTT, segments `x` fed into
  successive computations of the graph should overlap by 1.

  If `x` is not given, `initial_xelt` must be given as the first input to the model.  Further
  elements are constructed from the model's predictions.

  Args:
    transition: model transition function mapping (xelt, model_state,
        context) to (output, new_model_state).
    model_state: initial state of the model.
    x: Sequence of integer (categorical) inputs. Not needed if sampling.
        Axes [time, batch].
    initial_xelt: When sampling, x is not given; initial_xelt specifies
        the input x[0] to the first timestep.
    context: a `Tensor` denoting context, e.g. for conditioning.
    length: Optional length of sequence. Inferred from `x` if possible.
    temperature: Softmax temperature to use for sampling.
    hp: Model hyperparameters.
    back_prop: Whether the graph will be backpropagated through.

  Returns:
    Namespace containing relevant symbolic variables.
  """
    with tf.variable_scope("seq") as scope:
        # if the caching device is not set explicitly, set it such that the
        # variables for the RNN are all cached locally.
        if scope.caching_device is None:
            scope.set_caching_device(lambda op: op.device)

        if length is None:
            length = tf.shape(x)[0]

        def _make_ta(name, **kwargs):
            # infer_shape=False because it is too strict; it considers unknown
            # dimensions to be incompatible with anything else. Effectively that
            # requires all shapes to be fully defined at graph construction time.
            return tf.TensorArray(tensor_array_name=name,
                                  infer_shape=False,
                                  **kwargs)

        state = NS(i=tf.constant(0), model=model_state)

        state.xhats = _make_ta("xhats",
                               dtype=tf.int32,
                               size=length,
                               clear_after_read=False)
        state.xhats = state.xhats.write(0,
                                        initial_xelt if x is None else x[0, :])

        state.exhats = _make_ta("exhats",
                                dtype=tf.float32,
                                size=length - LEFTOVER)

        if x is not None:
            state.losses = _make_ta("losses",
                                    dtype=tf.float32,
                                    size=length - LEFTOVER)
            state.errors = _make_ta("errors",
                                    dtype=tf.bool,
                                    size=length - LEFTOVER)

        state = tfutil.while_loop(
            cond=lambda state: state.i < length - LEFTOVER,
            body=ft.partial(make_transition_graph,
                            transition=transition,
                            x=x,
                            context=context,
                            temperature=temperature,
                            hp=hp),
            loop_vars=state,
            back_prop=back_prop)

        # pack TensorArrays
        for key in "exhats xhats losses errors".split():
            if key in state:
                state[key] = state[key].pack()

        ts = NS()
        ts.final_state = state
        ts.xhat = state.xhats[1:, :]
        ts.final_xhatelt = state.xhats[length - 1, :]
        if x is not None:
            ts.loss = tf.reduce_mean(state.losses)
            ts.error = tf.reduce_mean(tf.to_float(state.errors))
            ts.final_x = x
            # expose the final, unprocessed element of x for convenience
            ts.final_xelt = x[length - 1, :]
        return ts
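The docstring's overlap requirement can be pictured with plain Python slicing. This is only an illustration of the overlap-by-1 convention, not the project's util.segments:

    # sketch of the overlap-by-1 segmentation described above; util.segments is
    # assumed to behave like this simple slicing
    LEFTOVER = 1
    x = list(range(10))           # stand-in for a sequence along the time axis
    segment_length = 4
    segments = [x[start:start + segment_length]
                for start in range(0, len(x) - LEFTOVER, segment_length - LEFTOVER)]
    assert segments == [[0, 1, 2, 3], [3, 4, 5, 6], [6, 7, 8, 9]]
    # the last element of each segment is fed again as the first input of the
    # next, so every prediction target is covered exactly once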
Example #19
    def _make_sequence_graph_with_unroll(self,
                                         model_state=None,
                                         x=None,
                                         initial_xelt=None,
                                         context=None,
                                         length=None,
                                         temperature=1.0,
                                         hp=None,
                                         back_prop=False):
        """Create a sequence graph by unrolling upper layers.

    This method is similar to `_make_sequence_graph`, except that `length` must be provided. The
    resulting graph behaves in the same way as that constructed by `_make_sequence_graph`, except
    that the upper layers are outside of the while loop and so the gradient can actually be
    truncated between runs of lower layers.

    If `x` is given, the graph processes the sequence `x` one element at a time.  At step `i`, the
    model receives the `i`th element as input, and its output is used to predict the `i + 1`th
    element.

    The last element is not processed, as there would be no further element available to compare
    against and compute loss. To ensure all data is processed during TBPTT, segments `x` fed into
    successive computations of the graph should overlap by 1.

    If `x` is not given, `initial_xelt` must be given as the first input to the model.  Further
    elements are constructed from the model's predictions.

    Args:
      model_state: initial state of the model.
      x: Sequence of integer (categorical) inputs. Not needed if sampling.
          Axes [time, batch].
      initial_xelt: When sampling, x is not given; initial_xelt specifies
          the input x[0] to the first timestep.
      context: a `Tensor` denoting context, e.g. for conditioning.
          Axes [batch, features].
      length: Optional length of sequence. Inferred from `x` if possible.
      temperature: Softmax temperature to use for sampling.
      hp: Model hyperparameters.
      back_prop: Whether the graph will be backpropagated through.

    Raises:
      ValueError: if `length` is not an int.

    Returns:
      Namespace containing relevant symbolic variables.
    """
        if length is None or not isinstance(length, int):
            raise ValueError(
                "For partial unrolling, length must be known at graph construction time."
            )

        if model_state is None:
            model_state = self.state_placeholders()

        state = NS(model=model_state,
                   inner_initial_xelt=initial_xelt,
                   xhats=[],
                   losses=[],
                   errors=[])

        # i suspect ugly gradient biases may occur if gradients are truncated
        # somewhere halfway through the cycle. ensure we start at a cycle boundary.
        state.model.time = tfutil.assertion(state.model.time,
                                            tf.equal(state.model.time, 0),
                                            [state.model.time],
                                            name="outer_alignment_assertion")
        # ensure we end at a cycle boundary too.
        assert (length - LEFTOVER) % self.period == 0

        inner_period = int(np.prod(hp.periods[:self.outer_indices[0] + 1]))

        # hp.boundaries specifies truncation boundaries relative to the end of the sequence and in terms
        # of each layer's own steps; translate this to be relative to the beginning of the sequence and
        # in terms of sequence elements. note that due to the dynamic unrolling of the inner graph, the
        # inner layers necessarily get truncated at the topmost inner layer's boundary.
        boundaries = [
            length - 1 - hp.boundaries[i] * int(np.prod(hp.periods[:i + 1]))
            for i in range(len(hp.periods))
        ]
        assert all(0 <= boundary and boundary < length - LEFTOVER
                   for boundary in boundaries)
        assert boundaries == list(reversed(sorted(boundaries)))

        print "length %s periods %s boundaries %s %s inner period %s" % (
            length, hp.periods, hp.boundaries, boundaries, inner_period)

        outer_step_count = length // inner_period
        for outer_time in range(outer_step_count):
            if outer_time > 0:
                tf.get_variable_scope().reuse_variables()

            # update outer layers (wrap in seq scope to be consistent with the fully
            # symbolic version of this graph)
            with tf.variable_scope("seq"):
                # truncate gradient (only effective on outer layers)
                for i in range(len(self.cells)):
                    if outer_time * inner_period <= boundaries[i]:
                        state.model.cells[i] = list(
                            map(tf.stop_gradient, state.model.cells[i]))

                state.model.cells = Wayback.transition(
                    outer_time * inner_period,
                    state.model.cells,
                    self.cells,
                    below=None,
                    above=context,
                    subset=self.outer_indices,
                    hp=hp,
                    symbolic=False)

            # run inner layers on subsequence
            if x is None:
                inner_x = None
            else:
                start = inner_period * outer_time
                stop = inner_period * (outer_time + 1) + LEFTOVER
                inner_x = x[start:stop, :]

            # grab a copy of the outer states. they will not be updated in the inner
            # loop, so we can put back the copy after the inner loop completes.
            # this avoids the gradient truncation due to calling `while_loop` with
            # `back_prop=False`.
            outer_cell_states = NS.Copy(state.model.cells[self.outer_slice])

            def _inner_transition(input_, state, context=None):
                assert not context
                state.cells = Wayback.transition(state.time,
                                                 state.cells,
                                                 self.cells,
                                                 below=input_,
                                                 above=None,
                                                 subset=self.inner_indices,
                                                 hp=hp,
                                                 symbolic=True)
                state.time += 1
                state.time %= self.period
                h = self.get_output(state)
                return h, state

            inner_back_prop = back_prop and outer_time * inner_period >= boundaries[
                self.inner_indices[-1]]
            inner_ts = _make_sequence_graph(
                transition=_inner_transition,
                model_state=state.model,
                x=inner_x,
                initial_xelt=state.inner_initial_xelt,
                temperature=temperature,
                hp=hp,
                back_prop=inner_back_prop)

            state.model = inner_ts.final_state.model
            state.inner_initial_xelt = inner_ts.final_xelt if x is not None else inner_ts.final_xhatelt
            state.final_xhatelt = inner_ts.final_xhatelt
            if x is not None:
                state.final_x = inner_x
                state.final_xelt = inner_ts.final_xelt
                # track only losses and errors after the boundary to avoid bypassing the truncation boundary.
                if inner_back_prop:
                    state.losses.append(inner_ts.loss)
                    state.errors.append(inner_ts.error)
            state.xhats.append(inner_ts.xhat)

            # restore static outer states
            state.model.cells[self.outer_slice] = outer_cell_states

            # double check alignment to be safe
            state.model.time = tfutil.assertion(
                state.model.time,
                tf.equal(state.model.time % inner_period, 0),
                [state.model.time, tf.shape(inner_x)],
                name="inner_alignment_assertion")

        ts = NS()
        ts.xhat = tf.concat(0, state.xhats)
        ts.final_xhatelt = state.final_xhatelt
        ts.final_state = state
        if x is not None:
            ts.final_x = state.final_x
            ts.final_xelt = state.final_xelt
            # inner means are all on the same sample size, so taking their mean is valid
            ts.loss = tf.reduce_mean(state.losses)
            ts.error = tf.reduce_mean(state.errors)
        return ts
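hp.boundaries count backwards from the end of the segment in each layer's own steps; the list comprehension above turns them into absolute element indices (a larger index means earlier gradient truncation). A numeric sketch with made-up values:

    import numpy as np
    # made-up numbers satisfying the assertions above: LEFTOVER = 1,
    # (length - LEFTOVER) divisible by prod(periods), boundaries non-increasing
    LEFTOVER = 1
    length, periods, hp_boundaries = 25, [4, 2, 3], [2, 1, 1]
    boundaries = [length - 1 - hp_boundaries[i] * int(np.prod(periods[:i + 1]))
                  for i in range(len(periods))]
    assert boundaries == [16, 16, 0]
    # layers 0 and 1 are truncated before element 16; the top layer (boundary 0)
    # backpropagates through the entire sequence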
Example #20
 def testFlattenUnflatten(self):
     before = NS(v=2, w=NS(x=1, y=NS(z=0)))
     flat = NS.Flatten(before)
     after = NS.UnflattenLike(before, flat)
     self.assertEqual(before, after)
Example #21
schema = NS((
    name, NS(name=name, description=description, default=default)
) for name, (description, default) in dict(
    sampling_frequency=("desired waveform time resolution in Hz", 44100),
    bit_depth=("desired waveform amplitude resolution in bits", 8),
    data_dim=("data dimensionality (usually inferred)", 256),
    initial_learning_rate=("initial learning rate", 0.002),
    decay_patience=(
        "how long to wait for improvement before decaying the learning rate",
        100),
    decay_rate=("rate of decay of learning rate", 0.1),
    clip_norm=("ratio for gradient clipping_by_norm", 1),
    batch_size=("number of examples in minibatch", 100),
    use_bn=("whether to use batch normalizatin", False),
    activation=("recurrent activation function to use (tanh/elu/identity)",
                "tanh"),
    io_sizes=("layer sizes for input and output MLPs", [512]),
    weight_decay=("L2 weight decay coefficient", 1e-7),
    segment_length=("length of truncated backpropagation", 1000),
    chunk_size=("number of samples per model step", 1),
    layout=("recurrent connection pattern (stack/wayback)", "stack"),
    cell=("recurrent cell (rnn/lstm/gru)", "lstm"),
    layer_sizes=("number of hidden units in each layer, from bottom to top.",
                 [1000]),
    vskip=("vertical skip connections between all layers", False),
    periods=
    ("update interval for each layer, from bottom to top. only used for the wayback layout",
     [1000]),
    boundaries=
    ("number of periods to backprop through for each layer, from bottom to top. only used for the wayback layout",
     [1]),
    unroll_layer_count=
    ("number of upper layers to move outside the while loop. only used for the wayback layout",
     0),
    carry=
    ("whether to carry state between cycles or restart based on context. only used for the wayback layout",
     True)).items())
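Each schema entry is itself a Namespace keyed by name, so individual defaults and descriptions can be read off directly. A small sketch assuming the schema above is in scope:

    # assumes the schema Namespace defined above
    assert schema.cell.default == "lstm"
    assert schema.batch_size.default == 100
    print(schema.cell.description)   # "recurrent cell (rnn/lstm/gru)"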
Example #22
 def testExtract(self):
     ns = NS(v=2, w=NS(x=1, y=NS(z=0))).Extract("w.y v")
     self.assertEqual(ns.v, 2)
     self.assertEqual(ns.w, NS(y=NS(z=0)))
     self.assertEqual(ns.w.y, NS(z=0))
Example #23
 def testCopy(self):
     before = NS(v=2, w=NS(x=1, y=NS(z=0)))
     after = NS.Copy(before)
     self.assertEqual(before, after)
     self.assertTrue(
         all(a is b for a, b in zip(NS.Flatten(after), NS.Flatten(before))))
Example #24
 def initial_state(self, batch_size):
     return NS(
         time=0,
         cells=[cell.initial_state(batch_size) for cell in self.cells])