def test_export(self, topology, wbits, abits, QONNX_export):
    """Export the trained network for (topology, wbits, abits) to ONNX.

    Skips configurations with ``wbits > abits`` and every ``lfc`` configuration
    other than w1a1. When ``QONNX_export`` is truthy, the network is exported
    with ``BrevitasONNXManager``, cleaned in place with ``qonnx_cleanup``,
    converted with ``ConvertQONNXtoFINN`` and saved back over the same file;
    otherwise it is exported directly with ``bo.export_finn_onnx``.

    Afterwards the network name, a local (naive) timestamp and the git HEAD of
    ``/workspace/finn`` are recorded via ``update_dashboard_data``, and the
    checkpoint file is asserted to exist. ``subprocess.check_output`` raises
    if ``git rev-parse`` fails, which fails the test.
    """
    if wbits > abits:
        pytest.skip("No wbits > abits end2end network configs for now")
    if topology == "lfc" and (wbits != 1 or abits != 1):
        pytest.skip("Skipping certain lfc configs")

    trained_net, input_shape = get_trained_network_and_ishape(topology, wbits, abits)
    export_path = get_checkpoint_name(topology, wbits, abits, QONNX_export, "export")
    if not QONNX_export:
        bo.export_finn_onnx(trained_net, input_shape, export_path)
    else:
        BrevitasONNXManager.export(trained_net, input_shape, export_path)
        qonnx_cleanup(export_path, out_file=export_path)
        # Keep the converted graph under its own name rather than reusing the
        # torch model variable.
        finn_model = ModelWrapper(export_path).transform(ConvertQONNXtoFINN())
        finn_model.save(export_path)

    def record(key, value):
        # All dashboard entries for this test share the same config key.
        update_dashboard_data(topology, wbits, abits, key, value)

    # Order matters: each value is computed right before it is recorded.
    record("network", "%s_w%da%d" % (topology, wbits, abits))
    record("datetime", datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
    git_head = subprocess.check_output(
        ["git", "rev-parse", "HEAD"], cwd="/workspace/finn"
    )
    record("finn-commit", git_head.decode("utf-8").strip())

    assert os.path.isfile(export_path)
def get_golden_io_pair(topology, wbits, abits, preproc=ToTensor(), return_topk=None):
    """Return the (input, output) pair of the trained reference network.

    Args:
        topology, wbits, abits: select the trained network via
            ``get_trained_network_and_ishape``.
        preproc: object whose ``forward`` is applied to the input tensor before
            inference; ``None`` skips preprocessing. The default ``ToTensor()``
            instance is created once, when the function is defined, and is
            shared by all calls.
        return_topk: if not ``None``, the raw network output is replaced by
            ``get_topk(output, k=return_topk)`` (presumably the top-k class
            indices; confirm against ``get_topk``).

    Returns:
        Tuple ``(input_npy, output_npy)``. ``input_npy`` is the unmodified
        example input from ``get_example_input``, a NumPy array as required by
        ``torch.from_numpy``. ``output_npy`` is the network output converted
        back to NumPy.
    """
    net, _ = get_trained_network_and_ishape(topology, wbits, abits)
    golden_in = get_example_input(topology)

    x = torch.from_numpy(golden_in).float()
    if preproc is not None:
        x = preproc.forward(x).detach()
    golden_out = net.forward(x).detach().numpy()

    if return_topk is None:
        return (golden_in, golden_out)
    return (golden_in, get_topk(golden_out, k=return_topk))
def test_export(self, topology, wbits, abits):
    """Export the trained network for (topology, wbits, abits) to FINN-ONNX.

    Skips configurations with ``wbits > abits``. The network is exported with
    ``bo.export_finn_onnx`` to the "export" checkpoint path. Afterwards the
    network name, a local (naive) timestamp and the git HEAD of
    ``/workspace/finn`` are recorded via ``update_dashboard_data``, and the
    checkpoint file is asserted to exist. ``subprocess.check_output`` raises
    if ``git rev-parse`` fails, which fails the test.
    """
    if wbits > abits:
        pytest.skip("No wbits > abits end2end network configs for now")

    trained_net, input_shape = get_trained_network_and_ishape(topology, wbits, abits)
    export_path = get_checkpoint_name(topology, wbits, abits, "export")
    bo.export_finn_onnx(trained_net, input_shape, export_path)

    def record(key, value):
        # All dashboard entries for this test share the same config key.
        update_dashboard_data(topology, wbits, abits, key, value)

    # Order matters: each value is computed right before it is recorded.
    record("network", "%s_w%da%d" % (topology, wbits, abits))
    record("datetime", datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
    git_head = subprocess.check_output(
        ["git", "rev-parse", "HEAD"], cwd="/workspace/finn"
    )
    record("finn-commit", git_head.decode("utf-8").strip())

    assert os.path.isfile(export_path)