def setUpClass(cls):
    """Build the class-level fixtures used by the optimised-model tests.

    Loads the ``test`` configuration with ``optimised=True``, fetches the
    checkpoint and image data when they are not on disk, then creates the
    data loader, the ``DenseNet121Eff`` model, the ``RSNAInference``
    helper and an ``Exporter``, storing each on ``cls``.
    """
    # NOTE(review): expected to run as a unittest class fixture; the
    # decorator is outside the visible source - confirm it is present.
    settings = get_config(action='test', optimised=True)
    cls.config = settings

    if not os.path.exists(cls.config['checkpoint']):
        download_checkpoint()

    # Use the default image directory when it exists; otherwise fall back
    # to the configured one, downloading the data if that is missing too.
    image_dir = settings["default_image_path"]
    if not os.path.exists(image_dir):
        if not os.path.exists(settings["image_path"]):
            download_data()
        image_dir = settings["image_path"]
    cls.image_path = image_dir

    test_set = RSNADataSet(
        cls.config['dummy_valid_list'],
        cls.config['dummy_labels'],
        cls.image_path,
        transform=True,
    )
    cls.data_loader_test = DataLoader(
        dataset=test_set,
        batch_size=1,
        shuffle=False,
        num_workers=4,
        pin_memory=False,
    )

    # Both coefficients are raised to the configured ``phi`` exponent
    # before being handed to the model constructor.
    phi = cls.config['phi']
    alpha = cls.config['alpha'] ** phi
    beta = cls.config['beta'] ** phi
    cls.model = DenseNet121Eff(alpha, beta, cls.config['class_count'])
    cls.class_names = cls.config['class_names']

    if torch.cuda.is_available():
        device_name = "cuda"
    else:
        device_name = "cpu"
    cls.device = torch.device(device_name)

    cls.inference = RSNAInference(
        cls.model,
        cls.data_loader_test,
        cls.config['class_count'],
        cls.config['checkpoint'],
        cls.config['class_names'],
        cls.device,
    )
    cls.exporter = Exporter(cls.config, optimised=True)
def test_export_onnx(self):
    """Export the optimised model to ONNX and check the file was written.

    Reloads the ``export`` configuration, downloads the checkpoint when it
    is missing, runs ``Exporter.export_model_onnx`` and asserts that the
    file named by ``model_name_onnx`` exists in the checkpoint's directory.
    """
    self.config = get_config(action='export', optimised=True)
    if not os.path.exists(self.config['checkpoint']):
        download_checkpoint()

    self.exporter = Exporter(self.config, optimised=True)
    self.exporter.export_model_onnx()

    # The ONNX file is looked up next to the checkpoint file.
    checkpoint_dir = os.path.split(self.config['checkpoint'])[0]
    onnx_path = os.path.join(
        checkpoint_dir, self.config.get('model_name_onnx'))

    # Assert on the file's existence, not on the path string itself: a
    # non-empty string is always truthy, so the old check could never fail.
    self.assertTrue(
        os.path.exists(onnx_path),
        msg="ONNX export not found at {}".format(onnx_path))
def test_trainer(self):
    """Train twice and check that the training loss does not increase.

    Builds a ``DenseNet121`` and an ``RSNATrainer`` from ``self.config``
    and the data loaders already stored on the instance, runs ``train``
    once, records ``current_train_loss``, runs ``train`` again and asserts
    the new loss is less than or equal to the recorded one.
    """
    cfg = self.config
    class_count = cfg["clscount"]

    self.model = DenseNet121(class_count)
    if not os.path.exists(cfg["checkpoint"]):
        download_checkpoint()
    self.device = cfg["device"]

    self.trainer = RSNATrainer(
        self.model,
        self.data_loader_train,
        self.data_loader_valid,
        self.data_loader_test,
        cfg["clscount"],
        cfg["checkpoint"],
        self.device,
        cfg["class_names"],
        cfg["lr"],
    )

    # First run establishes the baseline loss.
    self.trainer.train(cfg["max_epoch"], cfg["savepath"])
    baseline_loss = self.trainer.current_train_loss

    # Second run must not end with a higher loss than the first.
    self.trainer.train(cfg["max_epoch"], cfg["savepath"])
    self.assertLessEqual(self.trainer.current_train_loss, baseline_loss)
def test_export_ir(self):
    """Export the optimised model to IR and check both output files exist.

    Reloads the ``export`` configuration, downloads the checkpoint when it
    is missing, runs the ONNX export first if its file is not in the
    checkpoint's directory, then calls ``Exporter.export_model_ir`` and
    asserts that ``<model_name>.xml`` and ``<model_name>.bin`` exist there.
    """
    self.config = get_config(action='export', optimised=True)
    if not os.path.exists(self.config['checkpoint']):
        download_checkpoint()

    self.exporter = Exporter(self.config, optimised=True)
    self.model_path = os.path.split(self.config['checkpoint'])[0]

    # The ONNX export is only run when its file is not already on disk.
    onnx_file = os.path.join(
        self.model_path, self.config.get('model_name_onnx'))
    if not os.path.exists(onnx_file):
        self.exporter.export_model_onnx()
    self.exporter.export_model_ir()

    # Collect both existence flags first, then assert them in order
    # (.xml before .bin).
    artefact_names = [
        self.config.get('model_name') + suffix
        for suffix in ('.xml', '.bin')
    ]
    artefact_found = [
        os.path.exists(os.path.join(self.model_path, file_name))
        for file_name in artefact_names
    ]
    for found in artefact_found:
        self.assertTrue(found)
def setUpClass(cls):
    """Load the export configuration and make sure the checkpoint exists.

    Stores the ``export`` configuration on ``cls.config``, downloads the
    checkpoint when its path is not on disk, and records the configured
    checkpoint path as ``cls.model_path``.
    """
    # NOTE(review): expected to run as a unittest class fixture; the
    # decorator is outside the visible source - confirm it is present.
    export_settings = get_config(action='export')
    cls.config = export_settings

    checkpoint_present = os.path.exists(export_settings['checkpoint'])
    if not checkpoint_present:
        download_checkpoint()

    cls.model_path = export_settings['checkpoint']