def run(self, session, primers, length, temperature, hp=None):
  batch_size = len(primers)
  # process in segments to avoid tensorflow eating all the memory
  max_segment_length = min(10000, hp.segment_length)

  print "conditioning..."
  segment_length = min(max_segment_length,
                       max(len(primer[0]) for primer in primers))
  # ensure segment_length is a multiple of chunk_size
  segment_length -= segment_length % hp.chunk_size
  state = NS(model=self.model.initial_state(batch_size))
  for segment in util.segments(primers, segment_length,
                               overlap=hp.chunk_size):
    x, = util.examples_as_arrays(segment)
    feed_dict = {self.tensors.x: x.T}
    feed_dict.update(self.model.feed_dict(state.model))
    values = NS.FlatCall(
        ft.partial(session.run, feed_dict=feed_dict),
        self.tensors.cond.Extract("final_state.model final_xchunk"))
    state.model = values.final_state.model
    sys.stderr.write(".")
  sys.stderr.write("\n")
  cond_values = values

  # make sure length is a multiple of chunk_size
  chunky_length = length + hp.chunk_size - length % hp.chunk_size

  print "sampling..."
  length_left = chunky_length
  xhats = []
  state = NS(model=cond_values.final_state.model,
             initial_xchunk=cond_values.final_xchunk)
  while length_left > 0:
    segment_length = min(max_segment_length, length_left)
    length_left -= segment_length

    feed_dict = {self.tensors.initial_xchunk: state.initial_xchunk,
                 self.tensors.length: segment_length,
                 self.tensors.temperature: temperature}
    feed_dict.update(self.model.feed_dict(state.model))
    sample_values = NS.FlatCall(
        ft.partial(session.run, feed_dict=feed_dict),
        self.tensors.sample.Extract(
            "final_state.model xhat final_xhatchunk"))
    state.model = sample_values.final_state.model
    state.initial_xchunk = sample_values.final_xhatchunk

    xhats.append(sample_values.xhat)
    sys.stderr.write(".")
  sys.stderr.write("\n")

  xhat = np.concatenate(xhats, axis=0)
  # truncate from chunky_length to the desired sample length
  xhat = xhat[:length]
  return xhat.T
def testGet(self):
  ns = NS(foo=NS(bar="baz"))
  self.assertRaises(KeyError, lambda: ns["foo"]["baz"])
  self.assertIsNone(ns.Get("foo.baz"))
  x = object()
  self.assertEqual(x, ns.Get("foo.baz", x))
  self.assertEqual("baz", ns.Get("foo.bar"))
  self.assertEqual(NS(bar="baz"), ns.Get("foo"))
def run(self, session, examples, max_step_count=None, hp=None,
        aggregates=None):
  aggregates = NS(aggregates or {})
  for key in "loss error".split():
    if key not in aggregates:
      aggregates[key] = util.MeanAggregate()
  tensors = self.tensors.Extract(*[key for key in aggregates.Keys()])
  tensors.Update(self.tensors.Extract("final_state.model"))

  state = NS(step=0, model=self.model.initial_state(hp.batch_size))
  try:
    for batch in util.batches(examples, hp.batch_size):
      for segment in util.segments(batch, hp.segment_length,
                                   overlap=hp.chunk_size):
        if max_step_count is not None and state.step >= max_step_count:
          raise StopIteration()

        x, = util.examples_as_arrays(segment)
        feed_dict = {self.tensors.x: x.T}
        feed_dict.update(self.model.feed_dict(state.model))
        values = NS.FlatCall(
            ft.partial(session.run, feed_dict=feed_dict), tensors)

        for key in aggregates.Keys():
          aggregates[key].add(values[key])

        sys.stderr.write(".")
        state.model = values.final_state.model
        state.step += 1
  except StopIteration:
    pass
  sys.stderr.write("\n")

  values = NS((key, aggregate.value)
              for key, aggregate in aggregates.Items())
  values.summaries = [
      tf.Summary.Value(tag="%s_valid" % key, simple_value=values[key])
      for key in "loss error".split()]
  print "### evaluation loss %6.5f error %6.5f" % (values.loss, values.error)

  if np.isnan(values.loss):
    raise ValueError("loss has become NaN")

  return values
def run(self, session, primers, length, temperature, hp=None):
  batch_size = len(primers)
  # process in segments to avoid tensorflow eating all the memory
  max_segment_length = min(10000, hp.segment_length)

  print "conditioning..."
  segment_length = min(max_segment_length,
                       max(len(primer[0]) for primer in primers))
  state = NS(model=self.model.initial_state(batch_size))
  for segment in util.segments(primers, segment_length, overlap=LEFTOVER):
    x, = util.examples_as_arrays(segment)
    feed_dict = {self.tensors.x: x.T}
    feed_dict.update(self.model.feed_dict(state.model))
    values = tfutil.run(
        session,
        tensors=self.tensors.cond.Extract("final_state.model final_xelt"),
        feed_dict=feed_dict)
    state.model = values.final_state.model
    sys.stderr.write(".")
  sys.stderr.write("\n")
  cond_values = values

  print "sampling..."
  length_left = length + LEFTOVER
  xhats = []
  state = NS(model=cond_values.final_state.model,
             initial_xelt=cond_values.final_xelt)
  while length_left > 0:
    segment_length = min(max_segment_length, length_left)
    length_left -= segment_length

    feed_dict = {self.tensors.initial_xelt: state.initial_xelt,
                 self.tensors.length: segment_length,
                 self.tensors.temperature: temperature}
    feed_dict.update(self.model.feed_dict(state.model))
    sample_values = tfutil.run(
        session,
        tensors=self.tensors.sample.Extract(
            "final_state.model xhat final_xhatelt"),
        feed_dict=feed_dict)
    state.model = sample_values.final_state.model
    state.initial_xelt = sample_values.final_xhatelt

    xhats.append(sample_values.xhat)
    sys.stderr.write(".")
  sys.stderr.write("\n")

  xhat = np.concatenate(xhats, axis=0)
  return xhat.T
def test_parse(self):
  self.assertEqual(
      hyperparameters.parse_value(
          """{a: 1, b: [2e0, 3., four], c: {d: "five", "e": False}}"""),
      NS(a=1, b=[2., 3., "four"], c=NS(d="five", e=False)))
  self.assertRaises(
      hyperparameters.ParseError,
      ft.partial(hyperparameters.parse_value, """{a:1, b: [fn()]}"""))
  self.assertRaises(
      hyperparameters.ParseError,
      ft.partial(hyperparameters.parse_value, """{a:1, b: dict(c=2)}"""))
def get_defaults(**overrides):
  """Get default hyperparameters.

  Args:
    **overrides: overrides for a subset of hyperparameters.

  Raises:
    ValueError: If an override refers to a nonexistent hyperparameter or the
        specified value is of a different type than the default value.

  Returns:
    A Namespace with (possibly overridden) defaults.
  """
  hp = NS((name, hyperparameter.default)
          for name, hyperparameter in schema.AsDict().items())
  for name, value in overrides.items():
    if name not in hp:
      raise ValueError("value provided for nonexistent hyperparameter %s"
                       % name)
    # TODO(cotim): deep typecheck
    if type(value) is not type(hp[name]):
      raise ValueError("value %s (%s) provided for hyperparameter %s does not"
                       " match type of default %s (%s)"
                       % (value, type(value), name, hp[name], type(hp[name])))
    hp[name] = value
  return hp
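A minimal usage sketch, not part of the original source: the override value 50 is made up, but `batch_size` and `cell` are real schema entries (defaults 100 and "lstm"), so overriding one leaves the others at their schema defaults.

# usage sketch (hypothetical override value)
hp = get_defaults(batch_size=50)
assert hp.batch_size == 50
assert hp.cell == "lstm"          # untouched schema default
# get_defaults(batch_sise=50)     # would raise ValueError: nonexistent name
# get_defaults(batch_size=50.0)   # would raise ValueError: type mismatch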
def _make(self, hp, global_step=None):
  ts = NS()
  ts.global_step = global_step
  ts.x = tf.placeholder(dtype=tf.int32, name="x")
  ts.seq = self.model.make_training_graph(x=ts.x,
                                          length=self.segment_length)
  ts.final_state = ts.seq.final_state
  ts.loss = ts.seq.loss
  ts.error = ts.seq.error

  ts.learning_rate = tf.Variable(hp.initial_learning_rate,
                                 dtype=tf.float32,
                                 trainable=False,
                                 name="learning_rate")
  ts.decay_op = tf.assign(ts.learning_rate,
                          ts.learning_rate * hp.decay_rate)
  ts.optimizer = tf.train.AdamOptimizer(ts.learning_rate)
  ts.params = tf.trainable_variables()
  print [param.name for param in ts.params]

  ts.gradients = tf.gradients(
      ts.loss, ts.params,
      # secret memory-conserving sauce
      aggregation_method=tf.AggregationMethod.EXPERIMENTAL_TREE)

  loose_params = [param for param, gradient
                  in util.equizip(ts.params, ts.gradients)
                  if gradient is None]
  if loose_params:
    raise ValueError("loose parameters: %s" %
                     " ".join(param.name for param in loose_params))

  # tensorflow fails miserably to compute gradient for these
  for reg_var in tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES):
    ts.gradients[ts.params.index(reg_var)] += (
        hp.weight_decay *
        tf.gradients(tf.sqrt(tf.reduce_sum(reg_var**2)), [reg_var])[0])

  ts.clipped_gradients, _ = tf.clip_by_global_norm(ts.gradients,
                                                   hp.clip_norm)
  ts.training_op = ts.optimizer.apply_gradients(
      util.equizip(ts.clipped_gradients, ts.params),
      global_step=ts.global_step)

  ts.summaries = [
      tf.summary.scalar("loss_train", ts.loss),
      tf.summary.scalar("error_train", ts.error),
      tf.summary.scalar("learning_rate", ts.learning_rate)]
  for parameter, gradient in util.equizip(ts.params, ts.gradients):
    ts.summaries.append(
        tf.summary.scalar("meanlogabs_%s" % parameter.name,
                          tfutil.meanlogabs(parameter)))
    ts.summaries.append(
        tf.summary.scalar("meanlogabsgrad_%s" % parameter.name,
                          tfutil.meanlogabs(gradient)))

  return ts
def run(self, session, examples, max_step_count=None, hooks=None, hp=None):
  tensors = self.tensors.Extract(
      "loss error summaries global_step training_op learning_rate final_state.model")
  state = NS(global_step=tf.train.global_step(session,
                                              self.tensors.global_step),
             model=self.model.initial_state(hp.batch_size))
  while True:
    for batch in util.batches(examples, hp.batch_size):
      for segment in util.segments(batch, self.segment_length,
                                   overlap=LEFTOVER):
        if max_step_count is not None and state.global_step >= max_step_count:
          return

        hooks.Get("step.before", util.noop)(state)

        x, = util.examples_as_arrays(segment)
        feed_dict = {self.tensors.x: x.T}
        feed_dict.update(self.model.feed_dict(state.model))
        values = tfutil.run(session, tensors, feed_dict=feed_dict)
        state.model = values.final_state.model
        state.global_step = values.global_step

        hooks.Get("step.after", util.noop)(state, values)

        print("step #%d loss %f error %f learning rate %e"
              % (values.global_step, values.loss, values.error,
                 values.learning_rate))

        if np.isnan(values.loss):
          raise ValueError("loss has become NaN")
def _make(self, unused_hp):
  ts = NS()
  ts.x = tf.placeholder(dtype=tf.int32, name="x")
  ts.seq = self.model.make_evaluation_graph(x=ts.x)
  ts.final_state = ts.seq.final_state
  ts.loss = ts.seq.loss
  ts.error = ts.seq.error
  return ts
def testMisc(self):
  ns = NS()
  ns.w = 0
  ns["x"] = 3
  ns.x = 1
  ns.y = NS(z=2)
  self.assertEqual(list(ns.Keys()), [("w",), ("x",), ("y", "z")])
  self.assertEqual(list(ns.Values()), [0, 1, 2])
  self.assertEqual(list(ns.Items()),
                   [(("w",), 0), (("x",), 1), (("y", "z"), 2)])
  self.assertEqual(ns.AsDict(),
                   OrderedDict([("w", 0), ("x", 1), ("y", NS(z=2))]))

  ns.Update(ns.y)
  self.assertEqual(list(ns), [("w",), ("x",), ("y", "z"), ("z",)])
  self.assertEqual(list(ns.Keys()), [("w",), ("x",), ("y", "z"), ("z",)])
  self.assertEqual(list(ns.Values()), [0, 1, 2, 2])
  self.assertEqual(list(ns.Items()),
                   [(("w",), 0), (("x",), 1), (("y", "z"), 2), (("z",), 2)])
  self.assertEqual(
      ns.AsDict(),
      OrderedDict([("w", 0), ("x", 1), ("y", NS(z=2)), ("z", 2)]))

  ns = NS(v=2, w=NS(x=1, y=[3, NS(z=0)]))
  self.assertItemsEqual(
      [("v",), ("w", "x"), ("w", "y", 0), ("w", "y", 1, "z")],
      list(ns.Keys()))
def __init__(self, cells_, hp):
  """Initialize a `Stack` instance.

  Args:
    cells_: recurrent transition cells, from bottom to top.
    hp: model hyperparameters.
  """
  super(Stack, self).__init__(hp)
  self.cells = list(cells_)
  self._state_placeholders = NS(
      cells=[cell.state_placeholders for cell in self.cells])
def testFlatCallFlatZip(self):
  before = NS(v=2, w=NS(x=1, y=NS(z=0)))
  after = NS.FlatCall(lambda xs: [2 * x for x in xs], before)
  self.assertEqual(NS(v=4, w=NS(x=2, y=NS(z=0))), after)
  self.assertItemsEqual([(2, 4), (1, 2), (0, 0)],
                        list(NS.FlatZip([before, after])))
  after.w.y.a = 6
  self.assertRaises(ValueError, lambda: NS.FlatZip([before, after]))
  self.assertItemsEqual([(2, 4), (0, 0)],
                        list(NS.FlatZip([before, after], "v w.y.z")))
def __init__(self, cells_, hp):
  """Initialize a `Wayback` instance.

  The following hyperparameters are specific to this model:
    periods: update interval of each layer, from bottom to top. As layer 0
        always runs at every step, periods[0] gives the number of steps of
        layer 0 before layer 1 is updated. periods[-1] gives the number of
        steps to run at the highest layer before the model should be
        considered to have completed a cycle.
    boundaries: number of periods to backpropagate through for each layer,
        from bottom to top.
    unroll_layer_count: number of upper layers to unroll. Unrolling allows
        for gradient truncation on the levels below.
    carry: whether to carry over each cell's state from one cycle to the
        next or break the chain and compute new initial states based on the
        state of the cell above.

  Args:
    cells_: recurrent transition cells, from bottom to top.
    hp: model hyperparameters.

  Raises:
    ValueError: If the number of cells does not match the number of periods
        or boundaries.
  """
  super(Wayback, self).__init__(hp)

  if len(self.hp.periods) != len(cells_):
    raise ValueError("must specify one period for each cell")
  if len(self.hp.boundaries) != len(cells_):
    raise ValueError("must specify one boundary for each cell")

  self.cells = list(cells_)

  cutoff = len(cells_) - self.hp.unroll_layer_count
  self.inner_indices = list(range(cutoff))
  self.outer_indices = list(range(cutoff, len(cells_)))
  self.inner_slice = slice(cutoff)
  self.outer_slice = slice(cutoff, len(cells_))

  self._state_placeholders = NS(
      time=tf.placeholder(dtype=tf.int32, name="time"),
      cells=[cell.state_placeholders for cell in self.cells])
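A small sketch of the layer partition computed above; the numbers are hypothetical, but the arithmetic mirrors the constructor: with three cells and unroll_layer_count = 1, the top layer is unrolled outside the while loop while the two lower layers stay inside it.

# sketch with assumed values
cell_count, unroll_layer_count = 3, 1
cutoff = cell_count - unroll_layer_count
inner_indices = list(range(cutoff))               # [0, 1]: lower layers, inside the symbolic loop
outer_indices = list(range(cutoff, cell_count))   # [2]: top layer, unrolled by the graph builder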
def _make(self, hp):
  ts = NS()
  ts.x = tf.placeholder(dtype=tf.int32, name="x")

  # conditioning graph
  ts.cond = self.model.make_evaluation_graph(x=ts.x)

  # generation graph
  tf.get_variable_scope().reuse_variables()
  ts.initial_xelt = tf.placeholder(dtype=tf.int32, name="initial_xelt",
                                   shape=[None])
  ts.length = tf.placeholder(dtype=tf.int32, name="length", shape=[])
  ts.temperature = tf.placeholder(dtype=tf.float32, name="temperature",
                                  shape=[])
  ts.sample = self.model.make_sampling_graph(initial_xelt=ts.initial_xelt,
                                             length=ts.length,
                                             temperature=ts.temperature)

  return ts
def parse_value(expr):
  """Parse a hyperparameter value.

  A value can be any Python literal. Barewords are converted to strings.
  Dictionaries are converted to Namespaces.

  Args:
    expr: value expression as a string or `ast.expr`

  Raises:
    ParseError: if `expr` is not a literal expression.

  Returns:
    The value represented by `expr`.
  """
  if isinstance(expr, basestring):
    expr = ast.parse(expr).body[0].value
  if isinstance(expr, ast.Num):
    return expr.n
  elif isinstance(expr, ast.Str):
    return expr.s
  elif isinstance(expr, ast.Name):
    try:
      # True/False are represented as Names -_-
      return ast.literal_eval(expr.id)
    except ValueError:
      # interpret as string
      return expr.id
  elif isinstance(expr, ast.List):
    return list(map(parse_value, expr.elts))
  elif isinstance(expr, ast.Tuple):
    return tuple(map(parse_value, expr.elts))
  elif isinstance(expr, ast.Dict):
    return NS((parse_key(key), parse_value(value))
              for key, value in zip(expr.keys, expr.values))
  else:
    raise ParseError("invalid value", expr)
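For reference, a usage sketch mirroring the unit test elsewhere in this repository: barewords parse to strings, dictionaries parse to Namespaces, and anything that is not a literal raises ParseError.

# parse_value("""{a: 1, b: [2e0, 3., four], c: {d: "five", "e": False}}""")
#   -> NS(a=1, b=[2., 3., "four"], c=NS(d="five", e=False))
# parse_value("{a: 1, b: dict(c=2)}")
#   -> raises ParseError (function calls are not literals)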
def run(self, session, examples, max_step_count=None, hooks=None, hp=None):
  state = NS(global_step=tf.train.global_step(session,
                                              self.tensors.global_step),
             model=self.model.initial_state(hp.batch_size))
  while True:
    for batch in util.batches(examples, hp.batch_size):
      for segment in util.segments(
          batch,
          # the last chunk is not processed, so grab one more to ensure we
          # backpropagate through at least one full model cycle.
          # TODO(cotim): rename segment_length to backprop_length?
          hp.segment_length + hp.chunk_size,
          overlap=hp.chunk_size):
        if max_step_count is not None and state.global_step >= max_step_count:
          return

        hooks.Get("step.before", util.noop)(state)

        x, = util.examples_as_arrays(segment)
        feed_dict = {self.tensors.x: x.T}
        feed_dict.update(self.model.feed_dict(state.model))
        values = NS.FlatCall(
            ft.partial(session.run, feed_dict=feed_dict),
            self.tensors.Extract(
                "loss error summaries global_step training_op learning_rate final_state.model"))
        state.model = values.final_state.model
        state.global_step = values.global_step

        hooks.Get("step.after", util.noop)(state, values)

        print("step #%d loss %f error %f learning rate %e"
              % (values.global_step, values.loss, values.error,
                 values.learning_rate))

        if np.isnan(values.loss):
          raise ValueError("loss has become NaN")
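A small arithmetic sketch of the segment sizes fed above; the values are hypothetical (segment_length and chunk_size are real hyperparameters, their defaults differ).

# sketch with assumed hp.segment_length = 1000 and hp.chunk_size = 4
segment_span = 1000 + 4          # elements handed to the graph per training step
overlap = 4                      # hp.chunk_size
stride = segment_span - overlap  # 1000 fresh elements per step
# the unprocessed final chunk of one segment is re-fed as the first chunk of
# the next, so no chunk is skipped between consecutive training steps.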
def main(argv):
  assert not argv[1:]

  hp = hyperparameters.parse(FLAGS.hyperparameters)

  print "loading data from %s" % FLAGS.data_dir
  dataset = datasets.construct(FLAGS.data_type,
                               directory=FLAGS.data_dir,
                               frequency=hp.sampling_frequency,
                               bit_depth=hp.bit_depth)
  print "done"

  hp.data_dim = dataset.data_dim
  model_name = get_model_name(hp)
  print model_name

  output_dir = os.path.join(FLAGS.base_output_dir, model_name)
  if not FLAGS.resume:
    if tf.gfile.Exists(output_dir):
      tf.gfile.DeleteRecursively(output_dir)
  if not tf.gfile.Exists(output_dir):
    tf.gfile.MakeDirs(output_dir)
  hyperparameters.dump(os.path.join(output_dir, "hyperparameters.yaml"), hp)

  model = models.construct(hp)

  print "constructing graph..."
  global_step = tf.Variable(0, trainable=False, name="global_step")
  trainer = training.Trainer(model, hp=hp, global_step=global_step)
  tf.get_variable_scope().reuse_variables()
  evaluator = evaluation.Evaluator(model, hp=hp)
  print "done"

  best_saver = tf.train.Saver()
  supervisor = tf.train.Supervisor(logdir=output_dir, summary_op=None)
  session = supervisor.PrepareSession()

  tracking = NS(best_loss=None, reset_time=0)

  def track(loss, step):
    if step % FLAGS.tracking_interval == 0:
      if tracking.best_loss is None or loss < tracking.best_loss:
        tracking.best_loss = loss
        tracking.reset_time = step
        best_saver.save(session,
                        os.path.join(os.path.dirname(supervisor.save_path),
                                     "best_%i_%s.ckpt" % (step, loss)),
                        global_step=supervisor.global_step)
      elif step - tracking.reset_time > hp.decay_patience:
        session.run(trainer.tensors.decay_op)
        tracking.reset_time = step

  def maybe_validate(state):
    if state.global_step % FLAGS.validation_interval == 0:
      aggregates = {}
      if FLAGS.dump_predictions:
        # extract final exhats and losses for debugging
        aggregates.update(
            (key, util.LastAggregate()) for key in
            "seq.final_x final_state.exhats final_state.losses".split())
      values = evaluator.run(examples=dataset.examples.valid,
                             session=session, hp=hp,
                             aggregates=aggregates,
                             max_step_count=FLAGS.max_validation_steps)
      supervisor.summary_computed(session,
                                  tf.Summary(value=values.summaries))
      if FLAGS.dump_predictions:
        np.savez_compressed(
            os.path.join(os.path.dirname(supervisor.save_path),
                         "xhats_%i.npz" % state.global_step),
            # i'm sure we'll get the idea from 100 steps of 10 examples
            xs=values.seq.final_x[:100, :10],
            exhats=values.final_state.exhats[:100, :10],
            losses=values.final_state.losses[:100, :10])
      # track validation loss
      track(values.loss, state.global_step)

  def maybe_stop(_):
    if supervisor.ShouldStop():
      raise StopTraining()

  def before_step_hook(state):
    maybe_validate(state)
    maybe_stop(state)

  def after_step_hook(state, values):
    for summary in values.summaries:
      supervisor.summary_computed(session, summary)
    # track training loss
    #track(values.loss, state.global_step)

  print "training."
  try:
    trainer.run(examples=dataset.examples.train[:FLAGS.max_examples],
                session=session, hp=hp,
                max_step_count=FLAGS.max_step_count,
                hooks=NS(step=NS(before=before_step_hook,
                                 after=after_step_hook)))
  except StopTraining:
    pass
def _make_sequence_graph(transition=None, model_state=None, x=None,
                         initial_xelt=None, context=None, length=None,
                         temperature=1.0, hp=None, back_prop=False):
  """Construct the graph to process a sequence of categorical integers.

  If `x` is given, the graph processes the sequence `x` one element at a
  time.  At step `i`, the model receives the `i`th element as input, and its
  output is used to predict the `i + 1`th element.

  The last element is not processed, as there would be no further element
  available to compare against and compute loss. To ensure all data is
  processed during TBPTT, segments `x` fed into successive computations of
  the graph should overlap by 1.

  If `x` is not given, `initial_xelt` must be given as the first input to the
  model.  Further elements are constructed from the model's predictions.

  Args:
    transition: model transition function mapping (xelt, model_state,
        context) to (output, new_model_state).
    model_state: initial state of the model.
    x: Sequence of integer (categorical) inputs. Not needed if sampling.
        Axes [time, batch].
    initial_xelt: When sampling, x is not given; initial_xelt specifies
        the input x[0] to the first timestep.
    context: a `Tensor` denoting context, e.g. for conditioning.
    length: Optional length of sequence. Inferred from `x` if possible.
    temperature: Softmax temperature to use for sampling.
    hp: Model hyperparameters.
    back_prop: Whether the graph will be backpropagated through.

  Returns:
    Namespace containing relevant symbolic variables.
  """
  with tf.variable_scope("seq") as scope:
    # if the caching device is not set explicitly, set it such that the
    # variables for the RNN are all cached locally.
    if scope.caching_device is None:
      scope.set_caching_device(lambda op: op.device)

    if length is None:
      length = tf.shape(x)[0]

    def _make_ta(name, **kwargs):
      # infer_shape=False because it is too strict; it considers unknown
      # dimensions to be incompatible with anything else. Effectively that
      # requires all shapes to be fully defined at graph construction time.
      return tf.TensorArray(tensor_array_name=name, infer_shape=False,
                            **kwargs)

    state = NS(i=tf.constant(0), model=model_state)

    state.xhats = _make_ta("xhats", dtype=tf.int32, size=length,
                           clear_after_read=False)
    state.xhats = state.xhats.write(0, initial_xelt if x is None
                                    else x[0, :])

    state.exhats = _make_ta("exhats", dtype=tf.float32,
                            size=length - LEFTOVER)

    if x is not None:
      state.losses = _make_ta("losses", dtype=tf.float32,
                              size=length - LEFTOVER)
      state.errors = _make_ta("errors", dtype=tf.bool,
                              size=length - LEFTOVER)

    state = tfutil.while_loop(cond=lambda state: state.i < length - LEFTOVER,
                              body=ft.partial(make_transition_graph,
                                              transition=transition,
                                              x=x, context=context,
                                              temperature=temperature,
                                              hp=hp),
                              loop_vars=state,
                              back_prop=back_prop)

    # pack TensorArrays
    for key in "exhats xhats losses errors".split():
      if key in state:
        state[key] = state[key].pack()

    ts = NS()
    ts.final_state = state
    ts.xhat = state.xhats[1:, :]
    ts.final_xhatelt = state.xhats[length - 1, :]
    if x is not None:
      ts.loss = tf.reduce_mean(state.losses)
      ts.error = tf.reduce_mean(tf.to_float(state.errors))
      ts.final_x = x
      # expose the final, unprocessed element of x for convenience
      ts.final_xelt = x[length - 1, :]

    return ts
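A toy illustration of the overlap-by-one requirement from the docstring above; the helper name is hypothetical and is distinct from util.segments used elsewhere in this repository. The final element of each segment is never processed as input, so it must reappear as the first element of the next segment.

def _overlapping_segments(x, segment_length, overlap=1):
  # yields slices that overlap by `overlap` elements
  step = segment_length - overlap
  for start in range(0, len(x) - overlap, step):
    yield x[start:start + segment_length]

# list(_overlapping_segments(list(range(11)), 6))
#   -> [[0, 1, 2, 3, 4, 5], [5, 6, 7, 8, 9, 10]]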
def _make_sequence_graph_with_unroll(self, model_state=None, x=None,
                                     initial_xelt=None, context=None,
                                     length=None, temperature=1.0, hp=None,
                                     back_prop=False):
  """Create a sequence graph by unrolling upper layers.

  This method is similar to `_make_sequence_graph`, except that `length`
  must be provided. The resulting graph behaves in the same way as that
  constructed by `_make_sequence_graph`, except that the upper layers are
  outside of the while loop and so the gradient can actually be truncated
  between runs of lower layers.

  If `x` is given, the graph processes the sequence `x` one element at a
  time.  At step `i`, the model receives the `i`th element as input, and its
  output is used to predict the `i + 1`th element.

  The last element is not processed, as there would be no further element
  available to compare against and compute loss. To ensure all data is
  processed during TBPTT, segments `x` fed into successive computations of
  the graph should overlap by 1.

  If `x` is not given, `initial_xelt` must be given as the first input to the
  model.  Further elements are constructed from the model's predictions.

  Args:
    model_state: initial state of the model.
    x: Sequence of integer (categorical) inputs. Not needed if sampling.
        Axes [time, batch].
    initial_xelt: When sampling, x is not given; initial_xelt specifies
        the input x[0] to the first timestep.
    context: a `Tensor` denoting context, e.g. for conditioning.
        Axes [batch, features].
    length: Optional length of sequence. Inferred from `x` if possible.
    temperature: Softmax temperature to use for sampling.
    hp: Model hyperparameters.
    back_prop: Whether the graph will be backpropagated through.

  Raises:
    ValueError: if `length` is not an int.

  Returns:
    Namespace containing relevant symbolic variables.
  """
  if length is None or not isinstance(length, int):
    raise ValueError("For partial unrolling, length must be known at graph"
                     " construction time.")

  if model_state is None:
    model_state = self.state_placeholders()

  state = NS(model=model_state, inner_initial_xelt=initial_xelt,
             xhats=[], losses=[], errors=[])

  # i suspect ugly gradient biases may occur if gradients are truncated
  # somewhere halfway through the cycle. ensure we start at a cycle boundary.
  state.model.time = tfutil.assertion(state.model.time,
                                      tf.equal(state.model.time, 0),
                                      [state.model.time],
                                      name="outer_alignment_assertion")
  # ensure we end at a cycle boundary too.
  assert (length - LEFTOVER) % self.period == 0

  inner_period = int(np.prod(hp.periods[:self.outer_indices[0] + 1]))

  # hp.boundaries specifies truncation boundaries relative to the end of the
  # sequence and in terms of each layer's own steps; translate this to be
  # relative to the beginning of the sequence and in terms of sequence
  # elements. note that due to the dynamic unrolling of the inner graph, the
  # inner layers necessarily get truncated at the topmost inner layer's
  # boundary.
  boundaries = [length - 1 - hp.boundaries[i] * int(np.prod(hp.periods[:i + 1]))
                for i in range(len(hp.periods))]
  assert all(0 <= boundary and boundary < length - LEFTOVER
             for boundary in boundaries)
  assert boundaries == list(reversed(sorted(boundaries)))

  print "length %s periods %s boundaries %s %s inner period %s" % (
      length, hp.periods, hp.boundaries, boundaries, inner_period)

  outer_step_count = length // inner_period
  for outer_time in range(outer_step_count):
    if outer_time > 0:
      tf.get_variable_scope().reuse_variables()

    # update outer layers (wrap in seq scope to be consistent with the fully
    # symbolic version of this graph)
    with tf.variable_scope("seq"):
      # truncate gradient (only effective on outer layers)
      for i in range(len(self.cells)):
        if outer_time * inner_period <= boundaries[i]:
          state.model.cells[i] = list(map(tf.stop_gradient,
                                          state.model.cells[i]))

      state.model.cells = Wayback.transition(
          outer_time * inner_period, state.model.cells, self.cells,
          below=None, above=context, subset=self.outer_indices, hp=hp,
          symbolic=False)

    # run inner layers on subsequence
    if x is None:
      inner_x = None
    else:
      start = inner_period * outer_time
      stop = inner_period * (outer_time + 1) + LEFTOVER
      inner_x = x[start:stop, :]

    # grab a copy of the outer states. they will not be updated in the inner
    # loop, so we can put back the copy after the inner loop completes.
    # this avoids the gradient truncation due to calling `while_loop` with
    # `back_prop=False`.
    outer_cell_states = NS.Copy(state.model.cells[self.outer_slice])

    def _inner_transition(input_, state, context=None):
      assert not context
      state.cells = Wayback.transition(
          state.time, state.cells, self.cells, below=input_, above=None,
          subset=self.inner_indices, hp=hp, symbolic=True)
      state.time += 1
      state.time %= self.period
      h = self.get_output(state)
      return h, state

    inner_back_prop = (back_prop and outer_time * inner_period >=
                       boundaries[self.inner_indices[-1]])
    inner_ts = _make_sequence_graph(
        transition=_inner_transition, model_state=state.model,
        x=inner_x, initial_xelt=state.inner_initial_xelt,
        temperature=temperature, hp=hp, back_prop=inner_back_prop)

    state.model = inner_ts.final_state.model
    state.inner_initial_xelt = (inner_ts.final_xelt if x is not None
                                else inner_ts.final_xhatelt)
    state.final_xhatelt = inner_ts.final_xhatelt
    if x is not None:
      state.final_x = inner_x
      state.final_xelt = inner_ts.final_xelt
      # track only losses and errors after the boundary to avoid bypassing
      # the truncation boundary.
      if inner_back_prop:
        state.losses.append(inner_ts.loss)
        state.errors.append(inner_ts.error)
    state.xhats.append(inner_ts.xhat)

    # restore static outer states
    state.model.cells[self.outer_slice] = outer_cell_states

    # double check alignment to be safe
    state.model.time = tfutil.assertion(
        state.model.time, tf.equal(state.model.time % inner_period, 0),
        [state.model.time, tf.shape(inner_x)],
        name="inner_alignment_assertion")

  ts = NS()
  ts.xhat = tf.concat(0, state.xhats)
  ts.final_xhatelt = state.final_xhatelt
  ts.final_state = state
  if x is not None:
    ts.final_x = state.final_x
    ts.final_xelt = state.final_xelt
    # inner means are all on the same sample size, so taking their mean is
    # valid
    ts.loss = tf.reduce_mean(state.losses)
    ts.error = tf.reduce_mean(state.errors)
  return ts
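A worked example of the boundary arithmetic above, as a sketch assuming LEFTOVER == 1 and self.period == np.prod(hp.periods); the settings are hypothetical: two cells with hp.periods = [4, 3], hp.boundaries = [3, 1], hp.unroll_layer_count = 1, and length = 25 so that (length - 1) % 12 == 0.

# sketch with assumed values (not from the original source)
length, periods, backprop_periods = 25, [4, 3], [3, 1]
inner_period = int(np.prod(periods[:1 + 1]))   # outer_indices == [1] here, so 4 * 3 = 12
boundaries = [length - 1 - backprop_periods[i] * int(np.prod(periods[:i + 1]))
              for i in range(len(periods))]    # [24 - 3*4, 24 - 1*12] = [12, 12]
outer_step_count = length // inner_period      # 25 // 12 = 2
# outer_time 0 covers x[0:13]: both layers' states pass through
# tf.stop_gradient (0 <= 12) and inner_back_prop is False (0 < 12), so its
# loss is discarded.
# outer_time 1 covers x[12:25]: the gradient is cut exactly at the cycle
# boundary (12 <= 12) and inner_back_prop is True (12 >= 12), so the last 12
# predictions contribute loss, error, and gradients.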
def testFlattenUnflatten(self):
  before = NS(v=2, w=NS(x=1, y=NS(z=0)))
  flat = NS.Flatten(before)
  after = NS.UnflattenLike(before, flat)
  self.assertEqual(before, after)
schema = NS((name, NS(name=name, description=description, default=default))
            for name, (description, default) in dict(
                sampling_frequency=("desired waveform time resolution in Hz",
                                    44100),
                bit_depth=("desired waveform amplitude resolution in bits", 8),
                data_dim=("data dimensionality (usually inferred)", 256),
                initial_learning_rate=("initial learning rate", 0.002),
                decay_patience=("how long to wait for improvement before"
                                " decaying the learning rate", 100),
                decay_rate=("rate of decay of learning rate", 0.1),
                clip_norm=("threshold for gradient clipping by global norm",
                           1),
                batch_size=("number of examples in minibatch", 100),
                use_bn=("whether to use batch normalization", False),
                activation=("recurrent activation function to use"
                            " (tanh/elu/identity)", "tanh"),
                io_sizes=("layer sizes for input and output MLPs", [512]),
                weight_decay=("L2 weight decay coefficient", 1e-7),
                segment_length=("length of truncated backpropagation", 1000),
                chunk_size=("number of samples per model step", 1),
                layout=("recurrent connection pattern (stack/wayback)",
                        "stack"),
                cell=("recurrent cell (rnn/lstm/gru)", "lstm"),
                layer_sizes=("number of hidden units in each layer, from"
                             " bottom to top.", [1000]),
                vskip=("vertical skip connections between all layers", False),
                periods=("update interval for each layer, from bottom to top."
                         " only used for the wayback layout", [1000]),
                boundaries=("number of periods to backprop through for each"
                            " layer, from bottom to top. only used for the"
                            " wayback layout", [1]),
                unroll_layer_count=("number of upper layers to move outside"
                                    " the while loop. only used for the"
                                    " wayback layout", 0),
                carry=("whether to carry state between cycles or restart"
                       " based on context. only used for the wayback layout",
                       True)).items())
def testExtract(self):
  ns = NS(v=2, w=NS(x=1, y=NS(z=0))).Extract("w.y v")
  self.assertEqual(ns.v, 2)
  self.assertEqual(ns.w, NS(y=NS(z=0)))
  self.assertEqual(ns.w.y, NS(z=0))
def testCopy(self):
  before = NS(v=2, w=NS(x=1, y=NS(z=0)))
  after = NS.Copy(before)
  self.assertEqual(before, after)
  self.assertTrue(all(a is b for a, b in zip(NS.Flatten(after),
                                             NS.Flatten(before))))
def initial_state(self, batch_size):
  return NS(time=0,
            cells=[cell.initial_state(batch_size) for cell in self.cells])