def test_tf_traceability(self):
    """Verify that Traceability writes its report artifacts for a TF LeNet model.

    Builds an estimator around a LeNet model named 'tfLeNet', drives the trace's
    ``on_begin``/``on_end`` hooks directly, and then checks the files written
    under ``self.tf_dir`` and its ``resources`` subdirectory.
    """
    # Start from a clean output directory so leftovers from a previous run cannot
    # satisfy the assertions below. isdir() already implies existence.
    if os.path.isdir(self.tf_dir):
        shutil.rmtree(self.tf_dir)

    trace = Traceability(save_path=self.tf_dir)
    est = _build_estimator(
        fe.build(model_fn=LeNet, optimizer_fn="adam", model_name='tfLeNet'), trace)
    # Wire the trace to the estimator's system state by hand, since no training loop runs here.
    trace.system = est.system
    trace.system.epoch_idx = 1
    trace.system.summary.name = "TF Test"
    trace.on_begin(Data())
    trace.on_end(Data())

    # Only the top level of the output directory is needed here.
    _, root_dirs, root_files = next(os.walk(self.tf_dir))
    self.assertIn('resources', root_dirs, "A resources subdirectory should have been generated")
    self.assertIn('tf_test.tex', root_files, "The tex file should have been generated")
    # Might be a pdf and/or a .ds_store file depending on system, but shouldn't be more than that
    self.assertLessEqual(len(root_files), 3, "Extra files should not have been generated")

    # Inspect the resources directory by name instead of relying on os.walk's
    # subdirectory order, which follows the filesystem listing and is not guaranteed.
    resource_files = os.listdir(os.path.join(self.tf_dir, 'resources'))
    self.assertIn('tf_test_tfLeNet.pdf', resource_files, "A figure for the model should have been generated")
    self.assertIn('tf_test_logs.png', resource_files, "A log image should have been generated")
    self.assertIn('tf_test.txt', resource_files, "A raw log file should have been generated")
def get_estimator(epochs=2, batch_size=32, save_dir=None):
    """Build a FastEstimator Estimator that trains LeNet on MNIST.

    Args:
        epochs: Number of training epochs passed to ``fe.Estimator``.
        batch_size: Batch size used by the data pipeline.
        save_dir: Directory that receives the best-model checkpoint and the
            Traceability report. If None (the default), a fresh temporary
            directory is created on every call.

    Returns:
        The configured ``fe.Estimator``.
    """
    # The default used to be ``tempfile.mkdtemp()`` in the signature. That
    # expression runs once, when the function is defined. It created a
    # directory at import time and made every default call share it.
    if save_dir is None:
        save_dir = tempfile.mkdtemp()

    # step 1: data pipeline; half of eval_data is split off to form test_data
    train_data, eval_data = mnist.load_data()
    test_data = eval_data.split(0.5)
    pipeline = fe.Pipeline(train_data=train_data,
                           eval_data=eval_data,
                           test_data=test_data,
                           batch_size=batch_size,
                           ops=[
                               ExpandDims(inputs="x", outputs="x"),
                               Minmax(inputs="x", outputs="x")
                           ])

    # step 2: model and network (forward pass, loss, weight update)
    model = fe.build(model_fn=LeNet, optimizer_fn="adam")
    network = fe.Network(ops=[
        ModelOp(model=model, inputs="x", outputs="y_pred"),
        CrossEntropy(inputs=("y_pred", "y"), outputs="ce"),
        UpdateOp(model=model, loss_name="ce")
    ])

    # step 3: traces
    traces = [
        Accuracy(true_key="y", pred_key="y_pred"),
        BestModelSaver(model=model, save_dir=save_dir, metric="accuracy", save_best_mode="max"),
        # NOTE(review): cycle_length is hard-coded and does not scale with
        # epochs/batch_size; presumably tuned for the defaults — confirm.
        LRScheduler(model=model, lr_fn=lambda step: cosine_decay(
            step, cycle_length=3750, init_lr=1e-3)),
        Traceability(save_path=save_dir)
    ]
    estimator = fe.Estimator(pipeline=pipeline, network=network, epochs=epochs, traces=traces)
    return estimator