def run(self, session, examples, max_step_count=None, hooks=None, hp=None):
    """Run the training loop until max_step_count (or forever).

    Args:
      session: TensorFlow session in which to run the training ops.
      examples: sequence of training examples, batched by util.batches.
      max_step_count: stop once the global step reaches this count; None
        means loop indefinitely.
      hooks: optional NS of callbacks keyed "step.before" / "step.after";
        may be None, in which case no hooks are invoked.
      hp: hyperparameters; batch_size is read here.  NOTE(review): this
        variant reads self.segment_length and a module-level LEFTOVER
        overlap rather than hp.segment_length / hp.chunk_size — confirm
        which is intended.

    Raises:
      ValueError: if the loss becomes NaN.
    """
    tensors = self.tensors.Extract(
        "loss error summaries global_step training_op learning_rate final_state.model")
    state = NS(global_step=tf.train.global_step(session, self.tensors.global_step),
               model=self.model.initial_state(hp.batch_size))
    while True:
        for batch in util.batches(examples, hp.batch_size):
            for segment in util.segments(batch, self.segment_length, overlap=LEFTOVER):
                if max_step_count is not None and state.global_step >= max_step_count:
                    return
                # hooks defaults to None; guard so training works without any hooks.
                if hooks is not None:
                    hooks.Get("step.before", util.noop)(state)
                x, = util.examples_as_arrays(segment)
                # Model consumes time-major input, hence the transpose.
                feed_dict = {self.tensors.x: x.T}
                feed_dict.update(self.model.feed_dict(state.model))
                values = tfutil.run(session, tensors, feed_dict=feed_dict)
                # Carry recurrent state across segments of the same batch.
                state.model = values.final_state.model
                state.global_step = values.global_step
                if hooks is not None:
                    hooks.Get("step.after", util.noop)(state, values)
                print("step #%d loss %f error %f learning rate %e"
                      % (values.global_step, values.loss, values.error,
                         values.learning_rate))
                if np.isnan(values.loss):
                    raise ValueError("loss has become NaN")
def run(self, session, examples, max_step_count=None, hp=None, aggregates=None):
    """Evaluate the model on `examples`, aggregating loss and error.

    Args:
      session: TensorFlow session in which to run the eval tensors.
      examples: sequence of evaluation examples, batched by util.batches.
      max_step_count: stop after this many segments; None means run over
        all examples once.
      hp: hyperparameters; batch_size, segment_length and chunk_size are
        read here.
      aggregates: optional mapping of tensor key -> aggregator; "loss"
        and "error" MeanAggregates are added if absent.

    Returns:
      NS with one entry per aggregate (e.g. values.loss, values.error)
      plus a `summaries` list of tf.Summary.Value protos.

    Raises:
      ValueError: if the aggregated loss is NaN.
    """
    aggregates = NS(aggregates or {})
    for key in "loss error".split():
        if key not in aggregates:
            aggregates[key] = util.MeanAggregate()
    tensors = self.tensors.Extract(*aggregates.Keys())
    tensors.Update(self.tensors.Extract("final_state.model"))
    state = NS(step=0, model=self.model.initial_state(hp.batch_size))
    try:
        for batch in util.batches(examples, hp.batch_size):
            for segment in util.segments(batch, hp.segment_length,
                                         overlap=hp.chunk_size):
                if max_step_count is not None and state.step >= max_step_count:
                    # Break out of both loops.
                    raise StopIteration()
                x, = util.examples_as_arrays(segment)
                # Model consumes time-major input, hence the transpose.
                feed_dict = {self.tensors.x: x.T}
                feed_dict.update(self.model.feed_dict(state.model))
                values = NS.FlatCall(
                    ft.partial(session.run, feed_dict=feed_dict), tensors)
                for key in aggregates.Keys():
                    aggregates[key].add(values[key])
                sys.stderr.write(".")
                # Carry recurrent state across segments of the same batch.
                state.model = values.final_state.model
                state.step += 1
    except StopIteration:
        pass
    sys.stderr.write("\n")
    values = NS((key, aggregate.value) for key, aggregate in aggregates.Items())
    values.summaries = [
        tf.Summary.Value(tag="%s_valid" % key, simple_value=values[key])
        for key in "loss error".split()]
    # was a Python-2 print statement; the rest of the file uses print().
    print("### evaluation loss %6.5f error %6.5f" % (values.loss, values.error))
    if np.isnan(values.loss):
        raise ValueError("loss has become NaN")
    return values
def test_segmented_batches(self):
    """Segments of each batch keep batch size, per-example feature count,
    bounded segment length, and the trailing feature shape."""
    length = np.random.randint(2, 100)
    segment_length = np.random.randint(1, length)
    example_count = np.random.randint(2, 100)
    batch_size = np.random.randint(1, example_count)
    # Each example is a list of features; each feature is (time, *shape).
    feature_shapes = [np.random.randint(1, 10, size=np.random.randint(1, 4))
                      for _ in range(np.random.randint(1, 4))]
    examples = [[np.random.rand(length, *shape) for shape in feature_shapes]
                for _ in range(example_count)]
    for batch in util.batches(examples, batch_size, augment=False):
        # BUG FIX: segment the batch under test, not the full example set;
        # otherwise the assertions never inspect the batch being produced.
        for segment in util.segments(batch, segment_length):
            self.assertEqual(batch_size, len(batch))
            for features in segment:
                self.assertEqual(len(feature_shapes), len(features))
                for feature, feature_shape in util.equizip(features, feature_shapes):
                    self.assertLessEqual(len(feature), segment_length)
                    self.assertEqual(tuple(feature_shape), feature.shape[1:])
def run(self, session, examples, max_step_count=None, hooks=None, hp=None):
    """Run the training loop until max_step_count (or forever).

    Args:
      session: TensorFlow session in which to run the training ops.
      examples: sequence of training examples, batched by util.batches.
      max_step_count: stop once the global step reaches this count; None
        means loop indefinitely.
      hooks: optional NS of callbacks keyed "step.before" / "step.after";
        may be None, in which case no hooks are invoked.
      hp: hyperparameters; batch_size, segment_length and chunk_size are
        read here.

    Raises:
      ValueError: if the loss becomes NaN.
    """
    state = NS(global_step=tf.train.global_step(session, self.tensors.global_step),
               model=self.model.initial_state(hp.batch_size))
    while True:
        for batch in util.batches(examples, hp.batch_size):
            for segment in util.segments(
                    batch,
                    # the last chunk is not processed, so grab one more to
                    # ensure we backpropagate through at least one full
                    # model cycle.
                    # TODO(cotim): rename segment_length to backprop_length?
                    hp.segment_length + hp.chunk_size,
                    overlap=hp.chunk_size):
                if max_step_count is not None and state.global_step >= max_step_count:
                    return
                # hooks defaults to None; guard so training works without any hooks.
                if hooks is not None:
                    hooks.Get("step.before", util.noop)(state)
                x, = util.examples_as_arrays(segment)
                # Model consumes time-major input, hence the transpose.
                feed_dict = {self.tensors.x: x.T}
                feed_dict.update(self.model.feed_dict(state.model))
                values = NS.FlatCall(
                    ft.partial(session.run, feed_dict=feed_dict),
                    self.tensors.Extract(
                        "loss error summaries global_step training_op "
                        "learning_rate final_state.model"))
                # Carry recurrent state across segments of the same batch.
                state.model = values.final_state.model
                state.global_step = values.global_step
                if hooks is not None:
                    hooks.Get("step.after", util.noop)(state, values)
                print("step #%d loss %f error %f learning rate %e"
                      % (values.global_step, values.loss, values.error,
                         values.learning_rate))
                if np.isnan(values.loss):
                    raise ValueError("loss has become NaN")