def while_loop(cond, body, loop_vars, **kwargs):
  """Like `tf.while_loop` but with structured `loop_vars`.

  Args:
    cond: as in `tf.while_loop`, but takes a single `loop_vars` argument.
    body: as in `tf.while_loop`, but takes and returns a single `loop_vars`
        tree which it is allowed to modify.
    loop_vars: as in `tf.while_loop`, but consists of a Namespace tree.
    **kwargs: passed onto `tf.while_loop`.

  Returns:
    A Namespace tree structure containing the final values of the loop
    variables.
  """
  def _cond(*flat_vars):
    return cond(NS.UnflattenLike(loop_vars, flat_vars))

  def _body(*flat_vars):
    return NS.Flatten(body(NS.UnflattenLike(loop_vars, flat_vars)))

  return NS.UnflattenLike(
      loop_vars,
      tf.while_loop(cond=_cond, body=_body,
                    loop_vars=NS.Flatten(loop_vars),
                    **kwargs))
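# Hedged usage sketch (not part of the original source): assuming `NS` is the
# Namespace class used throughout, a structured loop state can be threaded
# through `while_loop` without flattening it by hand.  The function and
# variable names below are illustrative only.
def _example_structured_while_loop():
  def body(loop_vars):
    loop_vars.acc.total += tf.to_float(loop_vars.i)
    loop_vars.i += 1
    return loop_vars

  final = while_loop(cond=lambda loop_vars: loop_vars.i < 10,
                     body=body,
                     loop_vars=NS(i=tf.constant(0),
                                  acc=NS(total=tf.constant(0.0))))
  # `final` has the same tree structure as `loop_vars`; `final.acc.total`
  # evaluates to sum(range(10)) == 45.0.
  return final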
def run(self, session, examples, max_step_count=None, hooks=None, hp=None):
  tensors = self.tensors.Extract(
      "loss error summaries global_step training_op learning_rate"
      " final_state.model")
  state = NS(global_step=tf.train.global_step(session,
                                              self.tensors.global_step),
             model=self.model.initial_state(hp.batch_size))
  while True:
    for batch in util.batches(examples, hp.batch_size):
      for segment in util.segments(batch, self.segment_length,
                                   overlap=LEFTOVER):
        if max_step_count is not None and state.global_step >= max_step_count:
          return

        hooks.Get("step.before", util.noop)(state)

        x, = util.examples_as_arrays(segment)
        feed_dict = {self.tensors.x: x.T}
        feed_dict.update(self.model.feed_dict(state.model))
        values = tfutil.run(session, tensors, feed_dict=feed_dict)

        state.model = values.final_state.model
        state.global_step = values.global_step

        hooks.Get("step.after", util.noop)(state, values)

        print("step #%d loss %f error %f learning rate %e"
              % (values.global_step, values.loss, values.error,
                 values.learning_rate))

        if np.isnan(values.loss):
          raise ValueError("loss has become NaN")
def testGet(self):
  ns = NS(foo=NS(bar="baz"))
  self.assertRaises(KeyError, lambda: ns["foo"]["baz"])
  self.assertIsNone(ns.Get("foo.baz"))
  x = object()
  self.assertEqual(x, ns.Get("foo.baz", x))
  self.assertEqual("baz", ns.Get("foo.bar"))
  self.assertEqual(NS(bar="baz"), ns.Get("foo"))
def testFlatCallFlatZip(self):
  before = NS(v=2, w=NS(x=1, y=NS(z=0)))
  after = NS.FlatCall(lambda xs: [2 * x for x in xs], before)
  self.assertEqual(NS(v=4, w=NS(x=2, y=NS(z=0))), after)
  self.assertItemsEqual([(2, 4), (1, 2), (0, 0)],
                        list(NS.FlatZip([before, after])))
  after.w.y.a = 6
  self.assertRaises(ValueError, lambda: NS.FlatZip([before, after]))
  self.assertItemsEqual([(2, 4), (0, 0)],
                        list(NS.FlatZip([before, after], "v w.y.z")))
def run(self, session, primers, length, temperature, hp=None):
  batch_size = len(primers)
  # process in segments to avoid tensorflow eating all the memory
  max_segment_length = min(10000, hp.segment_length)

  print("conditioning...")
  segment_length = min(max_segment_length,
                       max(len(primer[0]) for primer in primers))
  state = NS(model=self.model.initial_state(batch_size))
  for segment in util.segments(primers, segment_length, overlap=LEFTOVER):
    x, = util.examples_as_arrays(segment)
    feed_dict = {self.tensors.x: x.T}
    feed_dict.update(self.model.feed_dict(state.model))
    values = tfutil.run(session,
                        tensors=self.tensors.cond.Extract(
                            "final_state.model final_xelt"),
                        feed_dict=feed_dict)
    state.model = values.final_state.model
    sys.stderr.write(".")
  sys.stderr.write("\n")

  cond_values = values

  print("sampling...")
  length_left = length + LEFTOVER
  xhats = []
  state = NS(model=cond_values.final_state.model,
             initial_xelt=cond_values.final_xelt)
  while length_left > 0:
    segment_length = min(max_segment_length, length_left)
    length_left -= segment_length

    feed_dict = {
        self.tensors.initial_xelt: state.initial_xelt,
        self.tensors.length: segment_length,
        self.tensors.temperature: temperature
    }
    feed_dict.update(self.model.feed_dict(state.model))
    sample_values = tfutil.run(session,
                               tensors=self.tensors.sample.Extract(
                                   "final_state.model xhat final_xhatelt"),
                               feed_dict=feed_dict)
    state.model = sample_values.final_state.model
    state.initial_xelt = sample_values.final_xhatelt

    xhats.append(sample_values.xhat)
    sys.stderr.write(".")
  sys.stderr.write("\n")

  xhat = np.concatenate(xhats, axis=0)
  return xhat.T
def run(self, session, examples, max_step_count=None, hp=None,
        aggregates=None):
  aggregates = NS(aggregates or {})
  for key in "loss error".split():
    if key not in aggregates:
      aggregates[key] = util.MeanAggregate()

  tensors = self.tensors.Extract(*[key for key in aggregates.Keys()])
  tensors.Update(self.tensors.Extract("final_state.model"))

  state = NS(step=0, model=self.model.initial_state(hp.batch_size))
  try:
    for batch in util.batches(examples, hp.batch_size):
      for segment in util.segments(batch, hp.segment_length,
                                   overlap=hp.chunk_size):
        if max_step_count is not None and state.step >= max_step_count:
          raise StopIteration()

        x, = util.examples_as_arrays(segment)
        feed_dict = {self.tensors.x: x.T}
        feed_dict.update(self.model.feed_dict(state.model))
        values = NS.FlatCall(
            ft.partial(session.run, feed_dict=feed_dict), tensors)

        for key in aggregates.Keys():
          aggregates[key].add(values[key])

        sys.stderr.write(".")
        state.model = values.final_state.model
        state.step += 1
  except StopIteration:
    pass
  sys.stderr.write("\n")

  values = NS((key, aggregate.value)
              for key, aggregate in aggregates.Items())
  values.summaries = [
      tf.Summary.Value(tag="%s_valid" % key, simple_value=values[key])
      for key in "loss error".split()
  ]
  print("### evaluation loss %6.5f error %6.5f"
        % (values.loss, values.error))

  if np.isnan(values.loss):
    raise ValueError("loss has become NaN")

  return values
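# Hedged sketch (not part of the original source): the evaluation loop above
# only assumes that each aggregate exposes `add(value)` and `value`, so a
# caller can supply a custom aggregate for an existing tensor key.  The
# `MaxAggregate` class below is illustrative and not part of `util`.
class MaxAggregate(object):
  def __init__(self):
    self.value = None

  def add(self, value):
    self.value = value if self.value is None else max(self.value, value)

# e.g. evaluator.run(session, examples, hp=hp,
#                    aggregates=dict(loss=MaxAggregate()))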
def state_placeholders(self):
  """Get the Tensorflow placeholders for the model's states.

  Returns:
    A Namespace tree containing the placeholders.
  """
  return NS.Copy(self._state_placeholders)
def cond(pred, fn1, fn2, prototype, **kwargs):
  """Like `tf.cond` but with structured collections of variables.

  Args:
    pred: boolean Tensor, as in `tf.cond`.
    fn1: a callable representing the `then` branch as in `tf.cond`, but may
        return an arbitrary Namespace tree.
    fn2: a callable representing the `else` branch as in `tf.cond`, but may
        return an arbitrary Namespace tree.
    prototype: an example Namespace tree to indicate the structure of the
        values returned from `fn1` and `fn2`.
    **kwargs: passed onto `tf.cond`.

  Returns:
    Like `tf.cond`, except structured like `prototype`.
  """
  def wrap_branch(fn):
    def wrapped_branch():
      return NS.Flatten(fn())
    return wrapped_branch

  results = tf.cond(pred, wrap_branch(fn1), wrap_branch(fn2), **kwargs)
  # tf.cond unpacks singleton lists returned from fn1, fn2 -_-
  if not isinstance(results, (tuple, list)):
    results = [results]
  # the branches return flat lists, so their tree structure is not
  # recoverable here; use `prototype` to unflatten the results.
  return NS.UnflattenLike(prototype, results)
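# Hedged usage sketch (not part of the original source): both branches must
# return Namespace trees shaped like `prototype`, with matching dtypes at
# corresponding leaves.  The names below are illustrative only.
def _example_structured_cond(pred):
  prototype = NS(scale=tf.constant(0.0), extra=NS(count=tf.constant(0)))
  result = cond(pred,
                lambda: NS(scale=tf.constant(1.0),
                           extra=NS(count=tf.constant(1))),
                lambda: NS(scale=tf.constant(2.0),
                           extra=NS(count=tf.constant(2))),
                prototype=prototype)
  # `result` is structured like `prototype`; `result.scale` and
  # `result.extra.count` are Tensors selected according to `pred`.
  return result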
def run(self, session, primers, length, temperature, hp=None):
  batch_size = len(primers)
  # process in segments to avoid tensorflow eating all the memory
  max_segment_length = min(10000, hp.segment_length)

  print("conditioning...")
  segment_length = min(max_segment_length,
                       max(len(primer[0]) for primer in primers))
  # ensure segment_length is a multiple of chunk_size
  segment_length -= segment_length % hp.chunk_size

  state = NS(model=self.model.initial_state(batch_size))
  for segment in util.segments(primers, segment_length,
                               overlap=hp.chunk_size):
    x, = util.examples_as_arrays(segment)
    feed_dict = {self.tensors.x: x.T}
    feed_dict.update(self.model.feed_dict(state.model))
    values = NS.FlatCall(
        ft.partial(session.run, feed_dict=feed_dict),
        self.tensors.cond.Extract("final_state.model final_xchunk"))
    state.model = values.final_state.model
    sys.stderr.write(".")
  sys.stderr.write("\n")

  cond_values = values

  # make sure length is a multiple of chunk_size
  chunky_length = length + hp.chunk_size - length % hp.chunk_size

  print("sampling...")
  length_left = chunky_length
  xhats = []
  state = NS(model=cond_values.final_state.model,
             initial_xchunk=cond_values.final_xchunk)
  while length_left > 0:
    segment_length = min(max_segment_length, length_left)
    length_left -= segment_length

    feed_dict = {
        self.tensors.initial_xchunk: state.initial_xchunk,
        self.tensors.length: segment_length,
        self.tensors.temperature: temperature
    }
    feed_dict.update(self.model.feed_dict(state.model))
    sample_values = NS.FlatCall(
        ft.partial(session.run, feed_dict=feed_dict),
        self.tensors.sample.Extract(
            "final_state.model xhat final_xhatchunk"))
    state.model = sample_values.final_state.model
    state.initial_xchunk = sample_values.final_xhatchunk

    xhats.append(sample_values.xhat)
    sys.stderr.write(".")
  sys.stderr.write("\n")

  xhat = np.concatenate(xhats, axis=0)
  # truncate from chunky_length to the desired sample length
  xhat = xhat[:length]
  return xhat.T
def feed_dict(self, state):
  """Construct a feed dict for the model's states.

  Args:
    state: the model state.

  Returns:
    A feed dict mapping each of the model's placeholders to the
    corresponding numerical value in `state`.
  """
  return util.odict(NS.FlatZip([self.state_placeholders(), state]))
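# Hedged sketch (not part of the original source): `NS.FlatZip` pairs the
# leaves of the two trees, so the resulting dict maps each placeholder to the
# corresponding array in `state`.  The structure below is illustrative only.
def _example_feed_dict_pairs():
  placeholders = NS(h=tf.placeholder(tf.float32),
                    c=tf.placeholder(tf.float32))
  state = NS(h=np.zeros([1, 4]), c=np.ones([1, 4]))
  # equivalent to {placeholders.h: state.h, placeholders.c: state.c}
  return util.odict(NS.FlatZip([placeholders, state]))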
def __call__(self, x, state, context=None):
  # construct the usual graph without unrolling
  state = NS.Copy(state)
  state.cells = Wayback.transition(state.time, state.cells, self.cells,
                                   below=x, above=context, hp=self.hp,
                                   symbolic=True)
  state.time += 1
  state.time %= self.period
  return state
def _make(self, unused_hp):
  ts = NS()
  ts.x = tf.placeholder(dtype=tf.int32, name="x")
  ts.seq = self.model.make_evaluation_graph(x=ts.x)
  ts.final_state = ts.seq.final_state
  ts.loss = ts.seq.loss
  ts.error = ts.seq.error
  return ts
def run(self, session, examples, max_step_count=None, hooks=None, hp=None):
  state = NS(global_step=tf.train.global_step(session,
                                              self.tensors.global_step),
             model=self.model.initial_state(hp.batch_size))
  while True:
    for batch in util.batches(examples, hp.batch_size):
      # the last chunk is not processed, so grab one more to ensure we
      # backpropagate through at least one full model cycle.
      # TODO(cotim): rename segment_length to backprop_length?
      for segment in util.segments(batch,
                                   hp.segment_length + hp.chunk_size,
                                   overlap=hp.chunk_size):
        if max_step_count is not None and state.global_step >= max_step_count:
          return

        hooks.Get("step.before", util.noop)(state)

        x, = util.examples_as_arrays(segment)
        feed_dict = {self.tensors.x: x.T}
        feed_dict.update(self.model.feed_dict(state.model))
        values = NS.FlatCall(
            ft.partial(session.run, feed_dict=feed_dict),
            self.tensors.Extract(
                "loss error summaries global_step training_op learning_rate"
                " final_state.model"))

        state.model = values.final_state.model
        state.global_step = values.global_step

        hooks.Get("step.after", util.noop)(state, values)

        print("step #%d loss %f error %f learning rate %e"
              % (values.global_step, values.loss, values.error,
                 values.learning_rate))

        if np.isnan(values.loss):
          raise ValueError("loss has become NaN")
def make_transition_graph(state, transition, x=None, context=None,
                          temperature=1.0, hp=None):
  """Make the graph that processes a single sequence element.

  Args:
    state: `_make_sequence_graph` loop state.
    transition: Model transition function mapping (xelt, model_state,
        context) to (output, new_model_state).
    x: Sequence of integer (categorical) inputs. Axes [time, batch].
    context: Optional Tensor denoting context, shaped [batch, ?].
    temperature: Softmax temperature to use for sampling.
    hp: Model hyperparameters.

  Returns:
    Updated loop state.
  """
  state = NS.Copy(state)

  xelt = tfutil.shaped_one_hot(
      state.xhats.read(state.i) if x is None else x[state.i, :],
      [None, hp.data_dim])
  embedding = tfutil.layers([xelt], sizes=hp.io_sizes, use_bn=hp.use_bn)
  h, state.model = transition(embedding, state.model, context=context)

  # predict the next elt
  with tf.variable_scope("xhat") as scope:
    embedding = tfutil.layers([h], sizes=hp.io_sizes, use_bn=hp.use_bn)
    exhat = tfutil.project(embedding, output_dim=hp.data_dim)
    xhat = tfutil.sample(exhat, temperature)
    state.xhats = state.xhats.write(state.i + LEFTOVER, xhat)

  if x is not None:
    target = tfutil.shaped_one_hot(x[state.i + 1], [None, hp.data_dim])
    state.losses = state.losses.write(
        state.i, tf.nn.softmax_cross_entropy_with_logits(exhat, target))
    state.errors = state.errors.write(
        state.i,
        tf.not_equal(tf.nn.top_k(exhat)[1], tf.nn.top_k(target)[1]))
    state.exhats = state.exhats.write(state.i, exhat)

  state.i += 1
  return state
def run(session, tensors, **run_kwargs):
  # too damn big
  trace = False
  if trace:
    run_metadata = tf.RunMetadata()
    run_kwargs["options"] = tf.RunOptions(
        trace_level=tf.RunOptions.FULL_TRACE)
    run_kwargs["run_metadata"] = run_metadata

  values = NS.FlatCall(ft.partial(session.run, **run_kwargs), tensors)

  if trace:
    from tensorflow.python.client import timeline
    trace = timeline.Timeline(step_stats=run_metadata.step_stats)
    with open("timeline_%s.json"
              % "_".join(".".join(map(str, key)) for key in tensors.Keys()),
              "w") as trace_file:
      trace_file.write(trace.generate_chrome_trace_format())

  return values
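# Hedged usage sketch (not part of the original source): `run` accepts any
# Namespace tree of fetches and returns values with the same structure, so
# nested fetches such as "final_state.model" can be read off directly.  The
# attribute names below are illustrative only.
def _example_structured_run(session, tensors, feed_dict):
  values = run(session, tensors.Extract("loss final_state.model"),
               feed_dict=feed_dict)
  return values.loss, values.final_state.model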
def __call__(self, x, state, context=None):
  state = NS.Copy(state)
  for i, _ in enumerate(self.cells):
    cell_inputs = []
    if i == 0:
      cell_inputs.append(x)
    if context is not None and i == len(self.cells) - 1:
      cell_inputs.append(context)
    if self.hp.vskip:
      # feed in state of all other layers
      cell_inputs.extend(self.cells[j].get_output(state.cells[j])
                         for j in range(len(self.cells))
                         if j != i)
    else:
      # feed in state of layer below
      if i > 0:
        cell_inputs.append(self.cells[i - 1].get_output(state.cells[i - 1]))
    state.cells[i] = self.cells[i].transition(cell_inputs, state.cells[i],
                                              scope="cell%i" % i)
  return state
def _make(self, hp):
  ts = NS()
  ts.x = tf.placeholder(dtype=tf.int32, name="x")

  # conditioning graph
  ts.cond = self.model.make_evaluation_graph(x=ts.x)

  # generation graph
  tf.get_variable_scope().reuse_variables()
  ts.initial_xelt = tf.placeholder(dtype=tf.int32, name="initial_xelt",
                                   shape=[None])
  ts.length = tf.placeholder(dtype=tf.int32, name="length", shape=[])
  ts.temperature = tf.placeholder(dtype=tf.float32, name="temperature",
                                  shape=[])
  ts.sample = self.model.make_sampling_graph(initial_xelt=ts.initial_xelt,
                                             length=ts.length,
                                             temperature=ts.temperature)

  return ts
def _make(self, hp, global_step=None):
  ts = NS()
  ts.global_step = global_step
  ts.x = tf.placeholder(dtype=tf.int32, name="x")

  length = hp.segment_length + hp.chunk_size
  ts.seq = self.model.make_training_graph(x=ts.x, length=length)
  ts.final_state = ts.seq.final_state
  ts.loss = ts.seq.loss
  ts.error = ts.seq.error

  ts.learning_rate = tf.Variable(hp.initial_learning_rate, dtype=tf.float32,
                                 trainable=False, name="learning_rate")
  ts.decay_op = tf.assign(ts.learning_rate,
                          ts.learning_rate * hp.decay_rate)
  ts.optimizer = tf.train.AdamOptimizer(ts.learning_rate)
  ts.params = tf.trainable_variables()
  print([param.name for param in ts.params])

  ts.gradients = tf.gradients(ts.loss, ts.params)

  loose_params = [param
                  for param, gradient in util.equizip(ts.params, ts.gradients)
                  if gradient is None]
  if loose_params:
    raise ValueError("loose parameters: %s"
                     % " ".join(param.name for param in loose_params))

  # tensorflow fails miserably to compute gradient for these
  for reg_var in tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES):
    ts.gradients[ts.params.index(reg_var)] += (
        hp.weight_decay
        * tf.gradients(tf.sqrt(tf.reduce_sum(reg_var**2)), [reg_var])[0])

  ts.clipped_gradients, _ = tf.clip_by_global_norm(ts.gradients,
                                                   hp.clip_norm)
  ts.training_op = ts.optimizer.apply_gradients(
      util.equizip(ts.clipped_gradients, ts.params),
      global_step=ts.global_step)

  ts.summaries = [
      tf.scalar_summary("loss_train", ts.loss),
      tf.scalar_summary("error_train", ts.error),
      tf.scalar_summary("learning_rate", ts.learning_rate)
  ]
  for parameter, gradient in util.equizip(ts.params, ts.gradients):
    ts.summaries.append(
        tf.scalar_summary("meanlogabs_%s" % parameter.name,
                          tfutil.meanlogabs(parameter)))
    ts.summaries.append(
        tf.scalar_summary("meanlogabsgrad_%s" % parameter.name,
                          tfutil.meanlogabs(gradient)))

  return ts
def make_transition_graph(state, transition, x=None, context=None,
                          temperature=1.0, hp=None):
  """Make the graph that processes a single sequence element.

  Args:
    state: `_make_sequence_graph` loop state.
    transition: Model transition function mapping (xchunk, model_state,
        context) to (output, new_model_state).
    x: Sequence of integer (categorical) inputs. Axes [time, batch].
    context: Optional Tensor denoting context, shaped [batch, ?].
    temperature: Softmax temperature to use for sampling.
    hp: Model hyperparameters.

  Returns:
    Updated loop state.
  """
  state = NS.Copy(state)

  xchunk = _get_flat_chunk(state.xhats if x is None else x,
                           state.i * hp.chunk_size, hp.chunk_size,
                           depth=hp.data_dim)
  embedding = tfutil.layers([xchunk], sizes=hp.io_sizes, use_bn=hp.use_bn)
  h, state.model = transition(embedding, state.model, context=context)

  # predict the next chunk
  exhats = []
  with tf.variable_scope("xhat") as scope:
    for j in range(hp.chunk_size):
      if j > 0:
        scope.reuse_variables()

      xchunk = _get_flat_chunk(state.xhats if x is None else x,
                               state.i * hp.chunk_size + j, hp.chunk_size,
                               depth=hp.data_dim)
      embedding = tfutil.layers([h, xchunk], sizes=hp.io_sizes,
                                use_bn=hp.use_bn)
      exhat = tfutil.project(embedding, output_dim=hp.data_dim)
      exhats.append(exhat)

      state.xhats = state.xhats.write((state.i + 1) * hp.chunk_size + j,
                                      tfutil.sample(exhat, temperature))

  if x is not None:
    targets = tf.unpack(_get_1hot_chunk(x, (state.i + 1) * hp.chunk_size,
                                        hp.chunk_size, depth=hp.data_dim),
                        num=hp.chunk_size, axis=1)
    state.losses = _put_chunk(
        state.losses, state.i * hp.chunk_size,
        [tf.nn.softmax_cross_entropy_with_logits(exhat, target)
         for exhat, target in util.equizip(exhats, targets)])
    state.errors = _put_chunk(
        state.errors, state.i * hp.chunk_size,
        [tf.not_equal(tf.nn.top_k(exhat)[1], tf.nn.top_k(target)[1])
         for exhat, target in util.equizip(exhats, targets)])
    state.exhats = _put_chunk(state.exhats, state.i * hp.chunk_size, exhats)

  state.i += 1
  return state
def _make_sequence_graph(transition=None, model_state=None, x=None,
                         initial_xelt=None, context=None, length=None,
                         temperature=1.0, hp=None, back_prop=False):
  """Construct the graph to process a sequence of categorical integers.

  If `x` is given, the graph processes the sequence `x` one element at a
  time.  At step `i`, the model receives the `i`th element as input, and its
  output is used to predict the `i + 1`th element.

  The last element is not processed, as there would be no further element
  available to compare against and compute loss.  To ensure all data is
  processed during TBPTT, segments `x` fed into successive computations of
  the graph should overlap by 1.

  If `x` is not given, `initial_xelt` must be given as the first input to
  the model.  Further elements are constructed from the model's predictions.

  Args:
    transition: model transition function mapping (xelt, model_state,
        context) to (output, new_model_state).
    model_state: initial state of the model.
    x: Sequence of integer (categorical) inputs. Not needed if sampling.
        Axes [time, batch].
    initial_xelt: When sampling, x is not given; initial_xelt specifies
        the input x[0] to the first timestep.
    context: a `Tensor` denoting context, e.g. for conditioning.
    length: Optional length of sequence. Inferred from `x` if possible.
    temperature: Softmax temperature to use for sampling.
    hp: Model hyperparameters.
    back_prop: Whether the graph will be backpropagated through.

  Returns:
    Namespace containing relevant symbolic variables.
  """
  with tf.variable_scope("seq") as scope:
    # if the caching device is not set explicitly, set it such that the
    # variables for the RNN are all cached locally.
    if scope.caching_device is None:
      scope.set_caching_device(lambda op: op.device)

    if length is None:
      length = tf.shape(x)[0]

    def _make_ta(name, **kwargs):
      # infer_shape=False because it is too strict; it considers unknown
      # dimensions to be incompatible with anything else. Effectively that
      # requires all shapes to be fully defined at graph construction time.
      return tf.TensorArray(tensor_array_name=name, infer_shape=False,
                            **kwargs)

    state = NS(i=tf.constant(0), model=model_state)

    state.xhats = _make_ta("xhats", dtype=tf.int32, size=length,
                           clear_after_read=False)
    state.xhats = state.xhats.write(0,
                                    initial_xelt if x is None else x[0, :])

    state.exhats = _make_ta("exhats", dtype=tf.float32,
                            size=length - LEFTOVER)

    if x is not None:
      state.losses = _make_ta("losses", dtype=tf.float32,
                              size=length - LEFTOVER)
      state.errors = _make_ta("errors", dtype=tf.bool,
                              size=length - LEFTOVER)

    state = tfutil.while_loop(cond=lambda state: state.i < length - LEFTOVER,
                              body=ft.partial(make_transition_graph,
                                              transition=transition, x=x,
                                              context=context,
                                              temperature=temperature,
                                              hp=hp),
                              loop_vars=state,
                              back_prop=back_prop)

    # pack TensorArrays
    for key in "exhats xhats losses errors".split():
      if key in state:
        state[key] = state[key].pack()

    ts = NS()
    ts.final_state = state
    ts.xhat = state.xhats[1:, :]
    ts.final_xhatelt = state.xhats[length - 1, :]
    if x is not None:
      ts.loss = tf.reduce_mean(state.losses)
      ts.error = tf.reduce_mean(tf.to_float(state.errors))
      ts.final_x = x
      # expose the final, unprocessed element of x for convenience
      ts.final_xelt = x[length - 1, :]
    return ts
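# Hedged worked example (not part of the original source) of the
# overlap-by-one convention described in the docstring above.  With a
# sequence of length 7 and segments of length 4, consecutive segments share
# one element, so every target x[i + 1] is predicted from x[i] exactly once:
#
#   segment 0: x[0] x[1] x[2] x[3]   -> predicts x[1] x[2] x[3]
#   segment 1: x[3] x[4] x[5] x[6]   -> predicts x[4] x[5] x[6]
#
# The last element of each segment is not processed as an input; it is fed
# again as the first input of the next segment.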
def testMisc(self):
  ns = NS()
  ns.w = 0
  ns["x"] = 3
  ns.x = 1
  ns.y = NS(z=2)
  self.assertEqual(list(ns.Keys()), [("w",), ("x",), ("y", "z")])
  self.assertEqual(list(ns.Values()), [0, 1, 2])
  self.assertEqual(list(ns.Items()),
                   [(("w",), 0), (("x",), 1), (("y", "z"), 2)])
  self.assertEqual(ns.AsDict(),
                   OrderedDict([("w", 0), ("x", 1), ("y", NS(z=2))]))
  ns.Update(ns.y)
  self.assertEqual(list(ns), [("w",), ("x",), ("y", "z"), ("z",)])
  self.assertEqual(list(ns.Keys()), [("w",), ("x",), ("y", "z"), ("z",)])
  self.assertEqual(list(ns.Values()), [0, 1, 2, 2])
  self.assertEqual(list(ns.Items()),
                   [(("w",), 0), (("x",), 1), (("y", "z"), 2), (("z",), 2)])
  self.assertEqual(ns.AsDict(),
                   OrderedDict([("w", 0), ("x", 1), ("y", NS(z=2)),
                                ("z", 2)]))

  ns = NS(v=2, w=NS(x=1, y=[3, NS(z=0)]))
  self.assertItemsEqual([("v",), ("w", "x"), ("w", "y", 0),
                         ("w", "y", 1, "z")],
                        list(ns.Keys()))
def _make_sequence_graph_with_unroll(self, model_state=None, x=None,
                                     initial_xelt=None, context=None,
                                     length=None, temperature=1.0, hp=None,
                                     back_prop=False):
  """Create a sequence graph by unrolling upper layers.

  This method is similar to `_make_sequence_graph`, except that `length`
  must be provided.  The resulting graph behaves in the same way as that
  constructed by `_make_sequence_graph`, except that the upper layers are
  outside of the while loop and so the gradient can actually be truncated
  between runs of lower layers.

  If `x` is given, the graph processes the sequence `x` one element at a
  time.  At step `i`, the model receives the `i`th element as input, and its
  output is used to predict the `i + 1`th element.

  The last element is not processed, as there would be no further element
  available to compare against and compute loss.  To ensure all data is
  processed during TBPTT, segments `x` fed into successive computations of
  the graph should overlap by 1.

  If `x` is not given, `initial_xelt` must be given as the first input to
  the model.  Further elements are constructed from the model's predictions.

  Args:
    model_state: initial state of the model.
    x: Sequence of integer (categorical) inputs. Not needed if sampling.
        Axes [time, batch].
    initial_xelt: When sampling, x is not given; initial_xelt specifies
        the input x[0] to the first timestep.
    context: a `Tensor` denoting context, e.g. for conditioning.
        Axes [batch, features].
    length: Optional length of sequence. Inferred from `x` if possible.
    temperature: Softmax temperature to use for sampling.
    hp: Model hyperparameters.
    back_prop: Whether the graph will be backpropagated through.

  Raises:
    ValueError: if `length` is not an int.

  Returns:
    Namespace containing relevant symbolic variables.
  """
  if length is None or not isinstance(length, int):
    raise ValueError("For partial unrolling, length must be known at graph"
                     " construction time.")

  if model_state is None:
    model_state = self.state_placeholders()

  state = NS(model=model_state, inner_initial_xelt=initial_xelt, xhats=[],
             losses=[], errors=[])

  # i suspect ugly gradient biases may occur if gradients are truncated
  # somewhere halfway through the cycle. ensure we start at a cycle boundary.
  state.model.time = tfutil.assertion(state.model.time,
                                      tf.equal(state.model.time, 0),
                                      [state.model.time],
                                      name="outer_alignment_assertion")
  # ensure we end at a cycle boundary too.
  assert (length - LEFTOVER) % self.period == 0

  inner_period = int(np.prod(hp.periods[:self.outer_indices[0] + 1]))

  # hp.boundaries specifies truncation boundaries relative to the end of the
  # sequence and in terms of each layer's own steps; translate this to be
  # relative to the beginning of the sequence and in terms of sequence
  # elements. note that due to the dynamic unrolling of the inner graph, the
  # inner layers necessarily get truncated at the topmost inner layer's
  # boundary.
  boundaries = [
      length - 1 - hp.boundaries[i] * int(np.prod(hp.periods[:i + 1]))
      for i in range(len(hp.periods))
  ]
  assert all(0 <= boundary and boundary < length - LEFTOVER
             for boundary in boundaries)
  assert boundaries == list(reversed(sorted(boundaries)))

  print("length %s periods %s boundaries %s %s inner period %s"
        % (length, hp.periods, hp.boundaries, boundaries, inner_period))

  outer_step_count = length // inner_period
  for outer_time in range(outer_step_count):
    if outer_time > 0:
      tf.get_variable_scope().reuse_variables()

    # update outer layers (wrap in seq scope to be consistent with the fully
    # symbolic version of this graph)
    with tf.variable_scope("seq"):
      # truncate gradient (only effective on outer layers)
      for i in range(len(self.cells)):
        if outer_time * inner_period <= boundaries[i]:
          state.model.cells[i] = list(map(tf.stop_gradient,
                                          state.model.cells[i]))

      state.model.cells = Wayback.transition(
          outer_time * inner_period, state.model.cells, self.cells,
          below=None, above=context, subset=self.outer_indices, hp=hp,
          symbolic=False)

    # run inner layers on subsequence
    if x is None:
      inner_x = None
    else:
      start = inner_period * outer_time
      stop = inner_period * (outer_time + 1) + LEFTOVER
      inner_x = x[start:stop, :]

    # grab a copy of the outer states. they will not be updated in the inner
    # loop, so we can put back the copy after the inner loop completes.
    # this avoids the gradient truncation due to calling `while_loop` with
    # `back_prop=False`.
    outer_cell_states = NS.Copy(state.model.cells[self.outer_slice])

    def _inner_transition(input_, state, context=None):
      assert not context
      state.cells = Wayback.transition(state.time, state.cells, self.cells,
                                       below=input_, above=None,
                                       subset=self.inner_indices, hp=hp,
                                       symbolic=True)
      state.time += 1
      state.time %= self.period
      h = self.get_output(state)
      return h, state

    inner_back_prop = (back_prop and outer_time * inner_period
                       >= boundaries[self.inner_indices[-1]])
    inner_ts = _make_sequence_graph(
        transition=_inner_transition, model_state=state.model, x=inner_x,
        initial_xelt=state.inner_initial_xelt, temperature=temperature,
        hp=hp, back_prop=inner_back_prop)

    state.model = inner_ts.final_state.model
    state.inner_initial_xelt = (inner_ts.final_xelt if x is not None
                                else inner_ts.final_xhatelt)
    state.final_xhatelt = inner_ts.final_xhatelt
    if x is not None:
      state.final_x = inner_x
      state.final_xelt = inner_ts.final_xelt

    # track only losses and errors after the boundary to avoid bypassing the
    # truncation boundary.
    if inner_back_prop:
      state.losses.append(inner_ts.loss)
      state.errors.append(inner_ts.error)

    state.xhats.append(inner_ts.xhat)

    # restore static outer states
    state.model.cells[self.outer_slice] = outer_cell_states

    # double check alignment to be safe
    state.model.time = tfutil.assertion(
        state.model.time,
        tf.equal(state.model.time % inner_period, 0),
        [state.model.time, tf.shape(inner_x)],
        name="inner_alignment_assertion")

  ts = NS()
  ts.xhat = tf.concat(0, state.xhats)
  ts.final_xhatelt = state.final_xhatelt
  ts.final_state = state
  if x is not None:
    ts.final_x = state.final_x
    ts.final_xelt = state.final_xelt
    # inner means are all on the same sample size, so taking their mean is
    # valid
    ts.loss = tf.reduce_mean(state.losses)
    ts.error = tf.reduce_mean(state.errors)
  return ts
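# Hedged worked example (not part of the original source): with illustrative
# hyperparameters hp.periods = [4, 2], hp.boundaries = [2, 1] and
# length = 33, the translation above yields
#   boundaries[0] = 33 - 1 - 2 * 4       = 24
#   boundaries[1] = 33 - 1 - 1 * (4 * 2) = 24
# i.e. both truncation boundaries fall 24 sequence elements from the start of
# the segment, and gradients for a layer are stopped until the outer step has
# passed that layer's boundary.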
def testCopy(self):
  before = NS(v=2, w=NS(x=1, y=NS(z=0)))
  after = NS.Copy(before)
  self.assertEqual(before, after)
  self.assertTrue(all(a is b for a, b in zip(NS.Flatten(after),
                                             NS.Flatten(before))))
def testFlattenUnflatten(self):
  before = NS(v=2, w=NS(x=1, y=NS(z=0)))
  flat = NS.Flatten(before)
  after = NS.UnflattenLike(before, flat)
  self.assertEqual(before, after)