def testMisc(self):
    """Check NS key/value/item views, AsDict, Update and nested-path keys."""
    ns = NS()
    ns.w = 0
    # Item and attribute assignment refer to the same entry, so only one
    # "x" key survives and it holds the later value.
    ns["x"] = 3
    ns.x = 1
    ns.y = NS(z=2)

    # Keys are tuples of path components; a nested NS contributes its own
    # components after the parent's.
    keys_before = [("w", ), ("x", ), ("y", "z")]
    values_before = [0, 1, 2]
    items_before = list(zip(keys_before, values_before))
    self.assertEqual(list(ns.Keys()), keys_before)
    self.assertEqual(list(ns.Values()), values_before)
    self.assertEqual(list(ns.Items()), items_before)
    self.assertEqual(ns.AsDict(),
                     OrderedDict([("w", 0), ("x", 1), ("y", NS(z=2))]))

    # Merging the nested namespace adds its leaf at the top level while the
    # original nested entry is kept.
    ns.Update(ns.y)
    keys_after = keys_before + [("z", )]
    values_after = values_before + [2]
    items_after = list(zip(keys_after, values_after))
    self.assertEqual(list(ns), keys_after)
    self.assertEqual(list(ns.Keys()), keys_after)
    self.assertEqual(list(ns.Values()), values_after)
    self.assertEqual(list(ns.Items()), items_after)
    self.assertEqual(
        ns.AsDict(),
        OrderedDict([("w", 0), ("x", 1), ("y", NS(z=2)), ("z", 2)]))

    # Key paths also descend into lists, using element indices as components.
    nested = NS(v=2, w=NS(x=1, y=[3, NS(z=0)]))
    self.assertItemsEqual(
        [("v", ), ("w", "x"), ("w", "y", 0), ("w", "y", 1, "z")],
        list(nested.Keys()))
def run(self, session, examples, max_step_count=None, hp=None, aggregates=None):
    """Evaluate the model over `examples` and return aggregated metrics.

    Examples are grouped into batches of `hp.batch_size`, and each batch is
    cut into segments by `util.segments`. The model state fetched at
    "final_state.model" after each segment is fed into the next one.
    "loss" and "error" are always aggregated, using `util.MeanAggregate`
    unless the caller supplies an aggregate under that key. A "." is
    written to stderr per processed segment.

    Args:
      session: object whose `run` method evaluates the fetched tensors
        (presumably a TensorFlow session -- confirm against callers).
      examples: examples to evaluate, consumed via `util.batches`.
      max_step_count: if not None, stop after this many segments.
      hp: hyperparameters providing `batch_size`, `segment_length` and
        `chunk_size`. Required despite the `None` default.
      aggregates: optional mapping from tensor key to an aggregate object
        exposing `add(value)` and `value`. Each key must name a tensor that
        `self.tensors.Extract` can find.

    Returns:
      An NS mapping each aggregate key to its final `value`, plus a
      `summaries` list of `tf.Summary.Value`s tagged "<key>_valid" for
      "loss" and "error".

    Raises:
      TypeError: if `hp` is not given.
      ValueError: if the aggregated loss is NaN.
    """
    if hp is None:
        # Fail with a clear message instead of AttributeError on None.
        raise TypeError("run() requires hyperparameters `hp`")

    aggregates = NS(aggregates or {})
    for key in "loss error".split():
        if key not in aggregates:
            aggregates[key] = util.MeanAggregate()

    # Fetch every aggregated tensor plus the final model state, which is
    # carried over into the next segment.
    tensors = self.tensors.Extract(*aggregates.Keys())
    tensors.Update(self.tensors.Extract("final_state.model"))

    state = NS(step=0, model=self.model.initial_state(hp.batch_size))
    # Flatten batches -> segments into one lazy stream so that an early stop
    # is a plain `break`. Raising and catching StopIteration around the loop
    # would also swallow a StopIteration escaping from any callee below and
    # silently truncate the evaluation.
    segments = (
        segment
        for batch in util.batches(examples, hp.batch_size)
        for segment in util.segments(
            batch, hp.segment_length, overlap=hp.chunk_size))
    for segment in segments:
        if max_step_count is not None and state.step >= max_step_count:
            break
        # Unpacking asserts that exactly one array is produced per segment.
        x, = util.examples_as_arrays(segment)
        feed_dict = {self.tensors.x: x.T}
        feed_dict.update(self.model.feed_dict(state.model))
        values = NS.FlatCall(
            ft.partial(session.run, feed_dict=feed_dict), tensors)
        for key in aggregates.Keys():
            aggregates[key].add(values[key])
        sys.stderr.write(".")
        state.model = values.final_state.model
        state.step += 1
    sys.stderr.write("\n")

    results = NS(
        (key, aggregate.value) for key, aggregate in aggregates.Items())
    results.summaries = [
        tf.Summary.Value(tag="%s_valid" % key, simple_value=results[key])
        for key in "loss error".split()
    ]
    # Single-argument print() prints identically under Python 2 and 3.
    print("### evaluation loss %6.5f error %6.5f"
          % (results.loss, results.error))
    if np.isnan(results.loss):
        raise ValueError("loss has become NaN")
    return results