示例#1
0
文件: xor.py 项目: burgerdev/hostload
def main():
    """Train the XOR regression demo and display its training animation.

    Registers ``WeightKeeper``, ``ProgressMonitor`` and ``PlotExtension`` as
    training extensions. This mutates the module-level ``config`` in place.
    It then builds a ``RegressionWorkflow`` in a fresh ``xor_``-prefixed
    temporary directory, runs it, and shows the animation obtained from the
    ``PlotExtension`` instance used during training.

    The temporary working directory is not removed afterwards.

    Raises:
        RuntimeError: if no ``PlotExtension`` instance appears in the
            trained workflow's ``extensions_used``.
    """
    ext1 = {"class": WeightKeeper}
    ext2 = {"class": ProgressMonitor}
    ext3 = {"class": PlotExtension}
    config["train"]["extensions"] = (ext1, ext2, ext3)
    wd = tempfile.mkdtemp(prefix="xor_")
    w = RegressionWorkflow.build(config, workingdir=wd)
    w.run()

    # A list comprehension works on both Python 2 and 3. Indexing the result
    # of ``filter`` fails on Python 3, where ``filter`` returns an iterator.
    plot_exts = [ext for ext in w._train.extensions_used
                 if isinstance(ext, PlotExtension)]
    if not plot_exts:
        raise RuntimeError("no PlotExtension found in extensions_used")
    plt_ext = plot_exts[0]

    # Keep a reference to the animation until plt.show() returns.
    # NOTE(review): presumably a matplotlib animation, which stops running if
    # it is garbage-collected -- confirm against PlotExtension.get_animation.
    # To save it to disk instead:
    #     ani.save('im.mp4', metadata={'artist': 'Guido'})
    ani = plt_ext.get_animation()

    plt.show()
示例#2
0
 def testMLPReg(self):
     """Smoke-test RegressionWorkflow with an MLP train/predict pipeline.

     Builds a shallow copy of the module-level ``config`` and configures it to
     use ``OpMLPTrain`` (one ``mlp.Sigmoid`` layer of 5 units),
     ``OpMLPPredict`` and ``OpRegTarget``. It drops any ``"report"`` entry and
     runs the workflow in a temporary directory with all warnings suppressed.
     The directory is removed whether or not the run succeeds.
     """
     d = tempfile.mkdtemp()
     try:
         # Shallow copy: this test only replaces top-level keys, so the shared
         # ``config`` keeps its own top-level entries. Whether build() mutates
         # nested values is not visible here.
         c = config.copy()
         c["class"] = RegressionWorkflow
         c["train"] = {"class": OpMLPTrain,
                       "layer_classes": (mlp.Sigmoid,),
                       "layer_sizes": (5,)}
         c["predict"] = {"class": OpMLPPredict}
         c["target"] = {"class": OpRegTarget}
         # Reporting is not wanted here; tolerate configs that lack it.
         c.pop("report", None)
         w = RegressionWorkflow.build(c, workingdir=d)
         with warnings.catch_warnings():
             warnings.simplefilter("ignore")
             w.run()
     finally:
         # Exceptions propagate unchanged; cleanup happens either way.
         shutil.rmtree(d)