# standard imports used throughout these snippets; NS, util, tfutil and
# LEFTOVER are project-local (the namespace helper, utility module, TF
# helpers and the leftover-overlap constant) and are assumed importable
# from the surrounding codebase.
import functools as ft
import sys

import numpy as np
import tensorflow as tf


def run(self, session, primers, length, temperature, hp=None):
  batch_size = len(primers)
  # process in segments to avoid tensorflow eating all the memory
  max_segment_length = min(10000, hp.segment_length)

  print("conditioning...")
  segment_length = min(max_segment_length,
                       max(len(primer[0]) for primer in primers))
  # ensure segment_length is a multiple of chunk_size
  segment_length -= segment_length % hp.chunk_size
  state = NS(model=self.model.initial_state(batch_size))
  for segment in util.segments(primers, segment_length,
                               overlap=hp.chunk_size):
    x, = util.examples_as_arrays(segment)
    feed_dict = {self.tensors.x: x.T}
    feed_dict.update(self.model.feed_dict(state.model))
    values = NS.FlatCall(
        ft.partial(session.run, feed_dict=feed_dict),
        self.tensors.cond.Extract("final_state.model final_xchunk"))
    state.model = values.final_state.model
    sys.stderr.write(".")
  sys.stderr.write("\n")
  cond_values = values

  # round length up to a multiple of chunk_size; the excess is truncated
  # from the sample below
  chunky_length = length + hp.chunk_size - length % hp.chunk_size

  print("sampling...")
  length_left = chunky_length
  xhats = []
  state = NS(model=cond_values.final_state.model,
             initial_xchunk=cond_values.final_xchunk)
  while length_left > 0:
    segment_length = min(max_segment_length, length_left)
    length_left -= segment_length

    feed_dict = {self.tensors.initial_xchunk: state.initial_xchunk,
                 self.tensors.length: segment_length,
                 self.tensors.temperature: temperature}
    feed_dict.update(self.model.feed_dict(state.model))
    sample_values = NS.FlatCall(
        ft.partial(session.run, feed_dict=feed_dict),
        self.tensors.sample.Extract("final_state.model xhat final_xhatchunk"))
    state.model = sample_values.final_state.model
    state.initial_xchunk = sample_values.final_xhatchunk

    xhats.append(sample_values.xhat)
    sys.stderr.write(".")
  sys.stderr.write("\n")

  xhat = np.concatenate(xhats, axis=0)
  # truncate from chunky_length to the desired sample length
  xhat = xhat[:length]
  return xhat.T
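# A standalone sketch of the rounding arithmetic behind chunky_length above,
# assuming only that length and chunk_size are positive ints. Note that when
# length is already a multiple of chunk_size, a whole extra chunk is added;
# the xhat[:length] truncation above discards the excess either way.
def _round_up_to_chunk(length, chunk_size):  # hypothetical helper, for illustration
  return length + chunk_size - length % chunk_size

assert _round_up_to_chunk(10, 4) == 12
assert _round_up_to_chunk(12, 4) == 16  # already aligned: one full extra chunk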
def test_segments_regression1(self):
  examples = [[np.arange(9)]]
  segment_length = 4
  overlap = 2
  for segment in util.segments(examples, segment_length, overlap=overlap):
    # with stride segment_length - overlap = 2, the trailing partial window
    # [6:9] is dropped, so every yielded segment is full-length
    self.assertEqual(4, len(segment[0][0]))
def run(self, session, examples, max_step_count=None, hooks=None, hp=None):
  tensors = self.tensors.Extract(
      "loss error summaries global_step training_op learning_rate"
      " final_state.model")
  state = NS(
      global_step=tf.train.global_step(session, self.tensors.global_step),
      model=self.model.initial_state(hp.batch_size))
  while True:
    for batch in util.batches(examples, hp.batch_size):
      for segment in util.segments(batch, self.segment_length,
                                   overlap=LEFTOVER):
        if max_step_count is not None and state.global_step >= max_step_count:
          return

        hooks.Get("step.before", util.noop)(state)

        x, = util.examples_as_arrays(segment)
        feed_dict = {self.tensors.x: x.T}
        feed_dict.update(self.model.feed_dict(state.model))
        values = tfutil.run(session, tensors, feed_dict=feed_dict)
        state.model = values.final_state.model
        state.global_step = values.global_step

        hooks.Get("step.after", util.noop)(state, values)

        print("step #%d loss %f error %f learning rate %e"
              % (values.global_step, values.loss, values.error,
                 values.learning_rate))
        if np.isnan(values.loss):
          raise ValueError("loss has become NaN")
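# The hooks argument above is only ever read with hooks.Get(name, util.noop).
# A hypothetical wiring of a logging hook; this assumes NS() constructs an
# empty namespace and that NS item assignment accepts the dotted key, as NS
# elsewhere handles dotted paths like "final_state.model":
hooks = NS()
hooks["step.after"] = lambda state, values: sys.stderr.write(
    "global step %d done\n" % state.global_step)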
def preprocess_primers(examples, hp):
  # maybe augment the number of examples to ensure batch norm will work
  min_batch_size = 16
  if len(examples) < min_batch_size:
    k = min_batch_size // len(examples)
    examples.extend(
        example.with_offset(offset)
        for example in examples
        for offset in util.random_choice(len(example), size=[k],
                                         replace=False))

  # maybe augment the number of time steps to ensure util.segments doesn't
  # discard anything at the ends of the examples. this is done by
  # left-padding the shorter examples with repetitions.
  max_len = max(map(len, examples))
  examples = [
      example.map(lambda feature: np.pad(
          feature, [(max_len - len(feature), 0)], mode="wrap"))
      for example in examples
  ]

  # time is tight; condition on 3 seconds of the wav files only
  examples_segments = list(util.segments(examples, 3 * hp.sampling_frequency))
  if len(examples_segments) > 2:
    # don't use the first and last segments to avoid silence
    examples_segments = examples_segments[1:-1]
  examples = examples_segments[util.random_choice(len(examples_segments))]
  return examples
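# A quick illustration of the left-padding above (a sketch assuming a 1-D
# feature): np.pad with mode="wrap" treats the signal as periodic, so a
# short example is padded on the left with copies drawn from its own end.
feature = np.array([1, 2, 3])
padded = np.pad(feature, [(2, 0)], mode="wrap")
# padded == array([2, 3, 1, 2, 3])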
def test_segments1(self):
  length = 100
  xs = np.random.randint(2, 100, size=[length])
  examples = [[xs]]
  segment_length = 10

  ys = []
  for segment in util.segments(examples, segment_length):
    ys.extend(segment[0][0])
  self.assertEqual(list(xs), list(ys))

  k = np.random.randint(1, segment_length - 1)
  for i, segment in enumerate(
      util.segments(examples, segment_length, overlap=k)):
    if i != 0:
      # each segment starts with the last k elements of its predecessor
      self.assertTrue(np.array_equal(overlap, segment[0][0][:k]))
    overlap = segment[0][0][-k:]
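# A worked example of the overlap relation the test above checks, assuming
# the stride of util.segments is segment_length - overlap (as the test
# implies): length-10 windows that advance by 10 - k share exactly k elements.
xs = np.arange(100)
k = 3
stride = 10 - k
for s in range(0, 100 - 10 + 1 - stride, stride):
  assert np.array_equal(xs[s:s + 10][-k:],
                        xs[s + stride:s + stride + 10][:k])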
def run(self, session, examples, max_step_count=None, hp=None,
        aggregates=None):
  aggregates = NS(aggregates or {})
  for key in "loss error".split():
    if key not in aggregates:
      aggregates[key] = util.MeanAggregate()
  tensors = self.tensors.Extract(*[key for key in aggregates.Keys()])
  tensors.Update(self.tensors.Extract("final_state.model"))

  state = NS(step=0, model=self.model.initial_state(hp.batch_size))
  try:
    for batch in util.batches(examples, hp.batch_size):
      for segment in util.segments(batch, hp.segment_length,
                                   overlap=hp.chunk_size):
        if max_step_count is not None and state.step >= max_step_count:
          raise StopIteration()

        x, = util.examples_as_arrays(segment)
        feed_dict = {self.tensors.x: x.T}
        feed_dict.update(self.model.feed_dict(state.model))
        values = NS.FlatCall(
            ft.partial(session.run, feed_dict=feed_dict), tensors)

        for key in aggregates.Keys():
          aggregates[key].add(values[key])

        sys.stderr.write(".")
        state.model = values.final_state.model
        state.step += 1
  except StopIteration:
    pass
  sys.stderr.write("\n")

  values = NS((key, aggregate.value) for key, aggregate in aggregates.Items())
  values.summaries = [
      tf.Summary.Value(tag="%s_valid" % key, simple_value=values[key])
      for key in "loss error".split()
  ]
  print("### evaluation loss %6.5f error %6.5f" % (values.loss, values.error))
  if np.isnan(values.loss):
    raise ValueError("loss has become NaN")
  return values
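# util.MeanAggregate is used above only through .add(value) and .value; a
# minimal sketch of such an accumulator (hypothetical; the real class may
# differ):
class MeanAggregate(object):
  def __init__(self):
    self.total, self.count = 0., 0

  def add(self, value):
    self.total += value
    self.count += 1

  @property
  def value(self):
    # mean of everything added so far
    return self.total / self.count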
def run(self, session, primers, length, temperature, hp=None):
  batch_size = len(primers)
  # process in segments to avoid tensorflow eating all the memory
  max_segment_length = min(10000, hp.segment_length)

  print("conditioning...")
  segment_length = min(max_segment_length,
                       max(len(primer[0]) for primer in primers))
  state = NS(model=self.model.initial_state(batch_size))
  for segment in util.segments(primers, segment_length, overlap=LEFTOVER):
    x, = util.examples_as_arrays(segment)
    feed_dict = {self.tensors.x: x.T}
    feed_dict.update(self.model.feed_dict(state.model))
    values = tfutil.run(session,
                        tensors=self.tensors.cond.Extract(
                            "final_state.model final_xelt"),
                        feed_dict=feed_dict)
    state.model = values.final_state.model
    sys.stderr.write(".")
  sys.stderr.write("\n")
  cond_values = values

  print("sampling...")
  length_left = length + LEFTOVER
  xhats = []
  state = NS(model=cond_values.final_state.model,
             initial_xelt=cond_values.final_xelt)
  while length_left > 0:
    segment_length = min(max_segment_length, length_left)
    length_left -= segment_length

    feed_dict = {self.tensors.initial_xelt: state.initial_xelt,
                 self.tensors.length: segment_length,
                 self.tensors.temperature: temperature}
    feed_dict.update(self.model.feed_dict(state.model))
    sample_values = tfutil.run(session,
                               tensors=self.tensors.sample.Extract(
                                   "final_state.model xhat final_xhatelt"),
                               feed_dict=feed_dict)
    state.model = sample_values.final_state.model
    state.initial_xelt = sample_values.final_xhatelt

    xhats.append(sample_values.xhat)
    sys.stderr.write(".")
  sys.stderr.write("\n")

  xhat = np.concatenate(xhats, axis=0)
  return xhat.T
def test_segmented_batches(self):
  length = np.random.randint(2, 100)
  segment_length = np.random.randint(1, length)
  example_count = np.random.randint(2, 100)
  batch_size = np.random.randint(1, example_count)
  feature_shapes = [
      np.random.randint(1, 10, size=np.random.randint(1, 4))
      for _ in range(np.random.randint(1, 4))
  ]
  examples = [[np.random.rand(length, *shape) for shape in feature_shapes]
              for _ in range(example_count)]
  for batch in util.batches(examples, batch_size, augment=False):
    # segment the batch, not the full example set
    for segment in util.segments(batch, segment_length):
      self.assertEqual(batch_size, len(batch))
      for features in segment:
        self.assertEqual(len(feature_shapes), len(features))
        for feature, feature_shape in util.equizip(features, feature_shapes):
          self.assertLessEqual(len(feature), segment_length)
          self.assertEqual(tuple(feature_shape), feature.shape[1:])
def run(self, session, examples, max_step_count=None, hooks=None, hp=None):
  state = NS(
      global_step=tf.train.global_step(session, self.tensors.global_step),
      model=self.model.initial_state(hp.batch_size))
  while True:
    for batch in util.batches(examples, hp.batch_size):
      # the last chunk is not processed, so grab one more to ensure we
      # backpropagate through at least one full model cycle.
      # TODO(cotim): rename segment_length to backprop_length?
      for segment in util.segments(batch,
                                   hp.segment_length + hp.chunk_size,
                                   overlap=hp.chunk_size):
        if max_step_count is not None and state.global_step >= max_step_count:
          return

        hooks.Get("step.before", util.noop)(state)

        x, = util.examples_as_arrays(segment)
        feed_dict = {self.tensors.x: x.T}
        feed_dict.update(self.model.feed_dict(state.model))
        values = NS.FlatCall(
            ft.partial(session.run, feed_dict=feed_dict),
            self.tensors.Extract(
                "loss error summaries global_step training_op learning_rate"
                " final_state.model"))
        state.model = values.final_state.model
        state.global_step = values.global_step

        hooks.Get("step.after", util.noop)(state, values)

        print("step #%d loss %f error %f learning rate %e"
              % (values.global_step, values.loss, values.error,
                 values.learning_rate))
        if np.isnan(values.loss):
          raise ValueError("loss has become NaN")
def test_segments_truncate(self):
  class ComparableNdarray(np.ndarray):
    """A numpy ndarray that doesn't break equality.

    numpy's ndarray violates the __eq__ contract (it returns an array
    rather than a bool), which breaks deep comparisons. Work around it by
    wrapping the arrays.
    """

    def __eq__(self, other):
      return np.array_equal(self, other)

  def comparablearray(*args, **kwargs):
    array = np.array(*args, **kwargs)
    return ComparableNdarray(array.shape, buffer=array, dtype=array.dtype)

  def to_examples(segmented_examples):
    # each string is one example; "|" separates segments and trailing
    # spaces mark where a segment runs past the end of the example
    segments_by_example = [[[comparablearray(list(s.strip()), dtype="|S1")]
                            for s in e.split("|")]
                           for e in segmented_examples]
    examples_by_segment = list(map(list, util.equizip(*segments_by_example)))
    return examples_by_segment

  examples = to_examples(["abcdefg", "mno", "vwxyz"])[0]

  self.assertEqual(to_examples(["ab|cd|ef|g ", "mn|o | | ", "vw|xy|z | "]),
                   list(util.segments(examples, 2, truncate=False)))
  self.assertEqual(to_examples(["ab", "mn", "vw"]),
                   list(util.segments(examples, 2, truncate=True)))
  self.assertEqual(to_examples(["abc|def|g ", "mno| | ", "vwx|yz | "]),
                   list(util.segments(examples, 3, truncate=False)))
  self.assertEqual(to_examples(["abc", "mno", "vwx"]),
                   list(util.segments(examples, 3, truncate=True)))
  self.assertEqual(to_examples(["abcd|efg ", "mno | ", "vwxy|z "]),
                   list(util.segments(examples, 4, truncate=False)))
  self.assertEqual([], list(util.segments(examples, 4, truncate=True)))

  overlap = 1
  self.assertEqual(
      to_examples(["ab|bc|cd|de|ef|fg|g ",
                   "mn|no|o | | | | ",
                   "vw|wx|xy|yz|z | | "]),
      list(util.segments(examples, 2, overlap, truncate=False)))
  self.assertEqual(to_examples(["ab|bc", "mn|no", "vw|wx"]),
                   list(util.segments(examples, 2, overlap, truncate=True)))
  self.assertEqual(
      to_examples(["abc|cde|efg|g ", "mno|o | | ", "vwx|xyz|z | "]),
      list(util.segments(examples, 3, overlap, truncate=False)))
  self.assertEqual(to_examples(["abc", "mno", "vwx"]),
                   list(util.segments(examples, 3, overlap, truncate=True)))
  self.assertEqual(
      to_examples(["abcd|defg|g ", "mno | | ", "vwxy|yz | "]),
      list(util.segments(examples, 4, overlap, truncate=False)))
  self.assertEqual([],
                   list(util.segments(examples, 4, overlap, truncate=True)))
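# A hypothetical reference implementation of util.segments, consistent with
# the behavior pinned down by the tests above (the real implementation may
# differ): windows of segment_length advance by segment_length - overlap;
# with truncate=True, iteration stops at the first window that any example
# cannot fill completely; with truncate=False, it continues until the
# longest example is exhausted, yielding short or empty trailing windows.
def segments_sketch(examples, segment_length, overlap=0, truncate=True):
  stride = segment_length - overlap
  max_len = max(len(feature) for example in examples for feature in example)
  start = 0
  while start < max_len:
    segment = [[feature[start:start + segment_length] for feature in example]
               for example in examples]
    if truncate and any(len(feature) < segment_length
                        for example in segment for feature in example):
      return
    yield segment
    start += stride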