# Assumed imports for these snippets, inferred from usage (the originals
# live inside a larger module):
import functools as ft
import sys

import numpy as np
import tensorflow as tf

import util                # project-local helpers: batches, segments, ...
from namespace import NS   # project-local nested namespace (import path assumed)


def run(self, session, primers, length, temperature, hp=None):
  batch_size = len(primers)
  # process in segments to avoid tensorflow eating all the memory
  max_segment_length = min(10000, hp.segment_length)

  print "conditioning..."
  segment_length = min(max_segment_length,
                       max(len(primer[0]) for primer in primers))
  # ensure segment_length is a multiple of chunk_size
  segment_length -= segment_length % hp.chunk_size
  state = NS(model=self.model.initial_state(batch_size))
  for segment in util.segments(primers, segment_length, overlap=hp.chunk_size):
    x, = util.examples_as_arrays(segment)
    feed_dict = {self.tensors.x: x.T}
    feed_dict.update(self.model.feed_dict(state.model))
    values = NS.FlatCall(
        ft.partial(session.run, feed_dict=feed_dict),
        self.tensors.cond.Extract("final_state.model final_xchunk"))
    state.model = values.final_state.model
    sys.stderr.write(".")
  sys.stderr.write("\n")

  cond_values = values

  # round length up to a multiple of chunk_size; note that an exact
  # multiple still gains one full extra chunk
  chunky_length = length + hp.chunk_size - length % hp.chunk_size

  print "sampling..."
  length_left = chunky_length
  xhats = []
  state = NS(model=cond_values.final_state.model,
             initial_xchunk=cond_values.final_xchunk)
  while length_left > 0:
    segment_length = min(max_segment_length, length_left)
    length_left -= segment_length

    feed_dict = {self.tensors.initial_xchunk: state.initial_xchunk,
                 self.tensors.length: segment_length,
                 self.tensors.temperature: temperature}
    feed_dict.update(self.model.feed_dict(state.model))
    sample_values = NS.FlatCall(
        ft.partial(session.run, feed_dict=feed_dict),
        self.tensors.sample.Extract("final_state.model xhat final_xhatchunk"))
    state.model = sample_values.final_state.model
    state.initial_xchunk = sample_values.final_xhatchunk

    xhats.append(sample_values.xhat)
    sys.stderr.write(".")
  sys.stderr.write("\n")

  xhat = np.concatenate(xhats, axis=0)
  # truncate from chunky_length back down to the requested sample length
  xhat = xhat[:length]
  return xhat.T
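# A minimal standalone sketch (not part of the module above) of the length
# rounding used in run(): chunky_length rounds the requested length up to a
# multiple of chunk_size, and a length that is already an exact multiple
# still gains one full extra chunk. The surplus is cut off afterwards by
# xhat[:length].
def _round_up_to_chunk(length, chunk_size):
  return length + chunk_size - length % chunk_size

assert _round_up_to_chunk(7, 4) == 8
assert _round_up_to_chunk(8, 4) == 12  # exact multiple still gains a chunk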
def testFlatCallFlatZip(self):
  before = NS(v=2, w=NS(x=1, y=NS(z=0)))
  after = NS.FlatCall(lambda xs: [2 * x for x in xs], before)
  self.assertEqual(NS(v=4, w=NS(x=2, y=NS(z=0))), after)
  self.assertItemsEqual([(2, 4), (1, 2), (0, 0)],
                        list(NS.FlatZip([before, after])))
  after.w.y.a = 6
  self.assertRaises(ValueError, lambda: NS.FlatZip([before, after]))
  self.assertItemsEqual([(2, 4), (0, 0)],
                        list(NS.FlatZip([before, after], "v w.y.z")))
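# A rough sketch (not the actual implementation) of the FlatCall contract
# the test above exercises, assuming NS.Keys() yields flattened key paths
# in a stable depth-first order: collect the leaves, hand fn the flat list,
# and rebuild an NS of the same shape from fn's results.
def flat_call_sketch(fn, ns):
  keys = list(ns.Keys())                   # flattened key paths, depth-first
  results = fn([ns[key] for key in keys])  # fn sees one flat leaf sequence
  out = NS()
  for key, result in zip(keys, results):
    out[key] = result
  return out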
def run(self, session, examples, max_step_count=None, hp=None, aggregates=None):
  aggregates = NS(aggregates or {})
  for key in "loss error".split():
    if key not in aggregates:
      aggregates[key] = util.MeanAggregate()
  tensors = self.tensors.Extract(*[key for key in aggregates.Keys()])
  tensors.Update(self.tensors.Extract("final_state.model"))
  state = NS(step=0, model=self.model.initial_state(hp.batch_size))
  try:
    for batch in util.batches(examples, hp.batch_size):
      for segment in util.segments(batch, hp.segment_length,
                                   overlap=hp.chunk_size):
        if max_step_count is not None and state.step >= max_step_count:
          raise StopIteration()
        x, = util.examples_as_arrays(segment)
        feed_dict = {self.tensors.x: x.T}
        feed_dict.update(self.model.feed_dict(state.model))
        values = NS.FlatCall(
            ft.partial(session.run, feed_dict=feed_dict), tensors)
        for key in aggregates.Keys():
          aggregates[key].add(values[key])
        sys.stderr.write(".")
        state.model = values.final_state.model
        state.step += 1
  except StopIteration:
    pass
  sys.stderr.write("\n")
  values = NS((key, aggregate.value) for key, aggregate in aggregates.Items())
  values.summaries = [
      tf.Summary.Value(tag="%s_valid" % key, simple_value=values[key])
      for key in "loss error".split()]
  print "### evaluation loss %6.5f error %6.5f" % (values.loss, values.error)
  if np.isnan(values.loss):
    raise ValueError("loss has become NaN")
  return values
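# A hypothetical sketch of the util.MeanAggregate interface the evaluator
# above relies on: add() accumulates one value per step and .value exposes
# the running mean. The real implementation may differ.
class MeanAggregateSketch(object):
  def __init__(self):
    self.total = 0.
    self.count = 0

  def add(self, value):
    self.total += value
    self.count += 1

  @property
  def value(self):
    return self.total / max(self.count, 1)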
def run(session, tensors, **run_kwargs):
  # tracing is off by default: the timeline dumps get too damn big
  trace = False
  if trace:
    run_metadata = tf.RunMetadata()
    run_kwargs["options"] = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    run_kwargs["run_metadata"] = run_metadata
  values = NS.FlatCall(ft.partial(session.run, **run_kwargs), tensors)
  if trace:
    from tensorflow.python.client import timeline
    trace = timeline.Timeline(step_stats=run_metadata.step_stats)
    with open("timeline_%s.json" % "_".join(".".join(map(str, key))
                                            for key in tensors.Keys()),
              "w") as trace_file:
      trace_file.write(trace.generate_chrome_trace_format())
  return values
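# A standalone sketch (with assumed example keys) of how the trace filename
# above is assembled: NS.Keys() is taken to yield tuples of path components,
# each joined with "." and the results joined with "_". The resulting JSON
# file loads into Chrome's chrome://tracing viewer.
def _timeline_filename(keys):
  return "timeline_%s.json" % "_".join(".".join(map(str, key)) for key in keys)

assert (_timeline_filename([("final_state", "model"), ("xhat",)])
        == "timeline_final_state.model_xhat.json")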
def run(self, session, examples, max_step_count=None, hooks=None, hp=None):
  hooks = NS(hooks or {})  # tolerate hooks=None, as the evaluator does for aggregates
  state = NS(global_step=tf.train.global_step(session, self.tensors.global_step),
             model=self.model.initial_state(hp.batch_size))
  while True:
    for batch in util.batches(examples, hp.batch_size):
      for segment in util.segments(
          batch,
          # the last chunk is not processed, so grab one more to ensure we
          # backpropagate through at least one full model cycle.
          # TODO(cotim): rename segment_length to backprop_length?
          hp.segment_length + hp.chunk_size,
          overlap=hp.chunk_size):
        if max_step_count is not None and state.global_step >= max_step_count:
          return

        hooks.Get("step.before", util.noop)(state)
        x, = util.examples_as_arrays(segment)
        feed_dict = {self.tensors.x: x.T}
        feed_dict.update(self.model.feed_dict(state.model))
        values = NS.FlatCall(
            ft.partial(session.run, feed_dict=feed_dict),
            self.tensors.Extract("loss error summaries global_step"
                                 " training_op learning_rate final_state.model"))
        state.model = values.final_state.model
        state.global_step = values.global_step
        hooks.Get("step.after", util.noop)(state, values)

        print("step #%d loss %f error %f learning rate %e"
              % (values.global_step, values.loss, values.error,
                 values.learning_rate))

        if np.isnan(values.loss):
          raise ValueError("loss has become NaN")
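# A sketch of the hooks contract used above, with hypothetical hook bodies:
# "step.before" receives the mutable training state, "step.after" also
# receives the fetched values, and util.noop is the fallback when a hook is
# absent. The nested-NS layout matching Get("step.before") is an assumption.
def _announce_step(state):
  sys.stderr.write("starting step %d\n" % state.global_step)

def _watch_loss(state, values):
  if values.loss > 1e3:
    sys.stderr.write("loss blowing up at step %d\n" % values.global_step)

hooks = NS(step=NS(before=_announce_step, after=_watch_loss))
# trainer.run(session, examples, hooks=hooks, hp=hp)  # hypothetical caller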