"""Train (or load) a FeedforwardNetwork on MNIST data and report misclassifications.

Usage:
    python script.py              # train, save to ./data/net.pkl, report errors
    python script.py -t false     # load ./data/net.pkl instead of training
"""
import argparse
import os

from metanet.datasets import mnist
from metanet.networks.artificial_networks.feedforward_network import FeedforwardNetwork

DEFAULT_NET_PATH = "./data/net.pkl"
DEFAULT_EPOCHS = 50
# Argument passed straight to mnist.get_mnist; presumably the number of
# samples to load -- NOTE(review): confirm against metanet.datasets.mnist.
MNIST_ARG = 10


def _argmax(values):
    """Return the index of the first maximum element of *values*."""
    return values.index(max(values))


def count_errors(net, inputs, targets):
    """Return how many samples *net* classifies differently from *targets*.

    A sample counts as an error when the index of the largest value in its
    target differs from the index of the largest value in ``net.test(sample)``.
    Assumes targets and network outputs are lists (they are queried with
    ``.index``) -- NOTE(review): confirm the return type of ``net.test``.
    """
    return sum(
        1
        for sample, target in zip(inputs, targets)
        if _argmax(target) != _argmax(net.test(sample))
    )


def main():
    """Parse CLI arguments, train or load the network, and print the error count."""
    parser = argparse.ArgumentParser(description=__doc__.splitlines()[0])
    parser.add_argument(
        "-t", "--train", default="true",
        help='"true" (case-insensitive) to train and save the network; '
             "any other value loads it from --net-path (default: %(default)s)",
    )
    parser.add_argument(
        "--epochs", type=int, default=DEFAULT_EPOCHS,
        help="number of training epochs (default: %(default)s)",
    )
    parser.add_argument(
        "--net-path", default=DEFAULT_NET_PATH,
        help="where the network is saved/loaded (default: %(default)s)",
    )
    args = parser.parse_args()
    # Case-insensitive so that "-t True" trains instead of silently loading.
    train = args.train.lower() == "true"

    inp, tgt = mnist.get_mnist(MNIST_ARG)
    # Presumably (input size, hidden layer sizes, output size) -- one hidden
    # layer as wide as the input. NOTE(review): confirm constructor signature.
    net = FeedforwardNetwork(len(inp[0]), [len(inp[0])], len(tgt[0]))

    if train:
        for epoch in range(args.epochs):
            print(" [*] Epoch:", epoch, "error:", net.train(inp, tgt, 1))
        # Create the target directory and save right after training so a
        # missing ./data (or a later failure) does not lose the trained net.
        os.makedirs(os.path.dirname(args.net_path) or ".", exist_ok=True)
        net.save_net(args.net_path)
    else:
        # NOTE(review): a .pkl file is presumably unpickled here; only load
        # files from a trusted source.
        net.load_net(args.net_path)

    # Evaluation runs on the same samples used for training (no held-out split).
    err = count_errors(net, inp, tgt)
    print(" [*]", err, "errors")


if __name__ == "__main__":
    main()