def setUpClass(cls):
    """Build the train/validation/test data loaders shared by the class's tests.

    Loads the optimised ``train`` configuration and resolves the image
    directory. ``default_image_path`` is used when it exists; otherwise the
    code falls back to ``image_path``, calling ``download_data()`` first if
    that directory is missing. It then wraps each dummy split in an
    ``RSNADataSet`` and a ``DataLoader``.
    """
    cls.config = get_config(action='train', optimised=True)
    cfg = cls.config

    if os.path.exists(cfg["default_image_path"]):
        images_dir = cfg["default_image_path"]
    else:
        # Fall back to the configured image directory, fetching the data on first use.
        if not os.path.exists(cfg["image_path"]):
            download_data()
        images_dir = cfg["image_path"]

    train_set = RSNADataSet(
        cfg['dummy_train_list'], cfg['dummy_labels'], images_dir, transform=True)
    valid_set = RSNADataSet(
        cfg['dummy_valid_list'], cfg['dummy_labels'], images_dir, transform=True)
    test_set = RSNADataSet(
        cfg['dummy_test_list'], cfg['dummy_labels'], images_dir, transform=True)

    # Only the training split is shuffled; the evaluation splits keep their order.
    cls.data_loader_train = DataLoader(
        dataset=train_set, shuffle=True, pin_memory=False)
    cls.data_loader_valid = DataLoader(
        dataset=valid_set, shuffle=False, pin_memory=False)
    cls.data_loader_test = DataLoader(
        dataset=test_set, shuffle=False, pin_memory=False)
def setUpClass(cls):
    """Build the shared inference fixtures for the optimised-model tests.

    Loads the optimised ``test`` configuration and downloads the checkpoint
    and image data if they are not present locally. It then creates:

    - a batch-size-1 test loader over the dummy validation list,
    - a ``DenseNet121Eff`` model whose first two arguments are
      ``alpha ** phi`` and ``beta ** phi``,
    - an ``RSNAInference`` runner and an optimised ``Exporter``.
    """
    cls.config = get_config(action='test', optimised=True)
    cfg = cls.config

    if not os.path.exists(cfg['checkpoint']):
        download_checkpoint()

    if os.path.exists(cfg["default_image_path"]):
        cls.image_path = cfg["default_image_path"]
    else:
        # Fall back to the configured image directory, fetching the data on first use.
        if not os.path.exists(cfg["image_path"]):
            download_data()
        cls.image_path = cfg["image_path"]

    test_set = RSNADataSet(
        cfg['dummy_valid_list'], cfg['dummy_labels'], cls.image_path, transform=True)
    cls.data_loader_test = DataLoader(
        dataset=test_set, batch_size=1, shuffle=False, num_workers=4,
        pin_memory=False)

    # Raise both base coefficients to the exponent phi before building the model.
    scaled_alpha = cfg['alpha'] ** cfg['phi']
    scaled_beta = cfg['beta'] ** cfg['phi']
    cls.model = DenseNet121Eff(scaled_alpha, scaled_beta, cfg['class_count'])

    cls.class_names = cfg['class_names']
    use_gpu = torch.cuda.is_available()
    cls.device = torch.device("cuda" if use_gpu else "cpu")

    cls.inference = RSNAInference(
        cls.model,
        cls.data_loader_test,
        cfg['class_count'],
        cfg['checkpoint'],
        cfg['class_names'],
        cls.device,
    )
    cls.exporter = Exporter(cls.config, optimised=True)
def export(args):
    """Export the model in each format requested by the parsed CLI arguments.

    Args:
        args: Parsed arguments. ``args.optimised`` selects which ``export``
            configuration is loaded and is also passed to ``Exporter``.
            ``args.onnx`` and ``args.ir`` are flags choosing the exports to
            run. When both are set, ONNX export runs before IR export.
    """
    exporter = Exporter(
        get_config(action='export', optimised=args.optimised), args.optimised)

    # (flag, action) pairs processed in a fixed order: ONNX first, then IR.
    requested_exports = (
        (args.onnx, exporter.export_model_onnx),
        (args.ir, exporter.export_model_ir),
    )
    for wanted, run_export in requested_exports:
        if wanted:
            run_export()
def test_config(self):
    """Check the optimised export config's checkpoint path and input shape.

    The checkpoint path must be non-empty and must contain a directory
    component. The model input shape must be exactly ``[1, 3, 1024, 1024]``.
    """
    self.config = get_config(action='export', optimised=True)
    cfg = self.config
    self.model_path = cfg['checkpoint']
    self.input_shape = cfg['input_shape']
    # Keep only the directory part of the checkpoint path.
    self.output_dir, _ = os.path.split(self.model_path)

    self.assertTrue(self.output_dir)
    self.assertTrue(self.model_path)
    self.assertListEqual(self.input_shape, [1, 3, 1024, 1024])
def test_export_onnx(self):
    """Export the optimised model to ONNX and verify the file was written.

    Downloads the checkpoint first if it is not present locally. The test
    expects the ONNX file in the checkpoint's directory, named by the
    config's ``model_name_onnx`` entry.
    """
    self.config = get_config(action='export', optimised=True)
    if not os.path.exists(self.config['checkpoint']):
        download_checkpoint()
    self.exporter = Exporter(self.config, optimised=True)
    self.exporter.export_model_onnx()

    checkpoint_dir = os.path.split(self.config['checkpoint'])[0]
    onnx_path = os.path.join(checkpoint_dir, self.config.get('model_name_onnx'))
    # os.path.join always returns a non-empty (truthy) string, so asserting on
    # the path itself could never fail. Check that the exported file exists.
    self.assertTrue(
        os.path.exists(onnx_path), f"ONNX model not found at {onnx_path}")
def test_config(self):
    """Sanity-check the optimised training hyper-parameters.

    Requires:

    - a learning rate of at least 1e-8,
    - exactly three classes,
    - ``alpha`` in [0, 2],
    - ``phi`` in [-1, 1].
    """
    cfg = get_config(action='train', optimised=True)
    self.config = cfg
    self.learn_rate = cfg["lr"]
    self.class_count = cfg["class_count"]

    self.assertGreaterEqual(self.learn_rate, 1e-8)
    self.assertEqual(self.class_count, 3)

    # Lower bounds first, then upper bounds, for alpha and phi.
    self.assertGreaterEqual(cfg['alpha'], 0)
    self.assertGreaterEqual(cfg['phi'], -1)
    self.assertLessEqual(cfg['alpha'], 2)
    self.assertLessEqual(cfg['phi'], 1)
def test_export_ir(self):
    """Export the optimised model to IR and verify the ``.xml``/``.bin`` pair exists.

    Downloads the checkpoint if needed. Runs ONNX export first only when the
    ONNX file is not already in the checkpoint directory, then runs IR
    export. Both ``<model_name>.xml`` and ``<model_name>.bin`` must then
    exist in that directory.
    """
    self.config = get_config(action='export', optimised=True)
    cfg = self.config
    if not os.path.exists(cfg['checkpoint']):
        download_checkpoint()
    self.exporter = Exporter(cfg, optimised=True)
    self.model_path = os.path.dirname(cfg['checkpoint'])

    onnx_file = os.path.join(self.model_path, cfg.get('model_name_onnx'))
    if not os.path.exists(onnx_file):
        self.exporter.export_model_onnx()
    self.exporter.export_model_ir()

    stem = cfg.get('model_name')
    xml_status, bin_status = (
        os.path.exists(os.path.join(self.model_path, stem + suffix))
        for suffix in ('.xml', '.bin')
    )
    self.assertTrue(xml_status)
    self.assertTrue(bin_status)
def test_config(self):
    """Check that the ``test`` configuration (loaded without ``optimised``) has three classes."""
    # NOTE(review): this config is read with key 'clscount', while the
    # optimised configs use 'class_count'. Confirm this key is intended
    # for the non-optimised config.
    cfg = get_config(action='test')
    self.config = cfg
    self.assertEqual(cfg['clscount'], 3)
def setUpClass(cls):
    """Load the export config and make sure the model checkpoint is available.

    Calls ``download_checkpoint()`` when the configured checkpoint path does
    not exist, then stores that path on ``cls.model_path``.
    """
    cls.config = get_config(action='export')
    checkpoint = cls.config['checkpoint']
    if not os.path.exists(checkpoint):
        # Fetch the checkpoint on first run so the tests have a model to load.
        download_checkpoint()
    cls.model_path = checkpoint