Пример #1
0
        def setUpClass(cls):
            """Build the train, validation and test data loaders for the tests.

            Loads the optimised training configuration, resolves the image
            directory (calling ``download_data`` when neither configured
            location exists) and stores one ``DataLoader`` per split on the
            class as ``data_loader_train``, ``data_loader_valid`` and
            ``data_loader_test``.

            NOTE(review): no ``@classmethod`` decorator is visible above this
            hook in this excerpt -- confirm it is applied in the full file.
            """
            train_cfg = get_config(action='train', optimised=True)
            cls.config = train_cfg

            # Prefer the default image location; otherwise fall back to the
            # configured path, fetching the data first when it is missing.
            default_dir = train_cfg["default_image_path"]
            if os.path.exists(default_dir):
                image_path = default_dir
            else:
                if not os.path.exists(train_cfg["image_path"]):
                    download_data()
                image_path = train_cfg["image_path"]

            # (file-list config key, loader attribute name, shuffle flag)
            splits = (
                ('dummy_train_list', 'data_loader_train', True),
                ('dummy_valid_list', 'data_loader_valid', False),
                ('dummy_test_list', 'data_loader_test', False),
            )

            # Every dataset is created before any of them is wrapped in a
            # loader, in train / valid / test order.
            datasets = [
                RSNADataSet(train_cfg[list_key],
                            train_cfg['dummy_labels'],
                            image_path,
                            transform=True)
                for list_key, _, _ in splits
            ]
            for (_, loader_attr, shuffle), dataset in zip(splits, datasets):
                loader = DataLoader(dataset=dataset,
                                    shuffle=shuffle,
                                    pin_memory=False)
                setattr(cls, loader_attr, loader)
        def setUpClass(cls):
            """Prepare the inference fixtures shared by the tests.

            Loads the optimised test configuration, makes sure the checkpoint
            and image data are available (downloading them when the configured
            paths do not exist), and stores on the class: the test data
            loader, the ``DenseNet121Eff`` model, the class names, the torch
            device, an ``RSNAInference`` instance and an ``Exporter``.

            NOTE(review): the test loader is built from ``dummy_valid_list``
            rather than a test list -- confirm this is intended.
            NOTE(review): no ``@classmethod`` decorator is visible above this
            hook in this excerpt -- confirm it is applied in the full file.
            """
            test_cfg = get_config(action='test', optimised=True)
            cls.config = test_cfg

            checkpoint = test_cfg['checkpoint']
            if not os.path.exists(checkpoint):
                download_checkpoint()

            # Prefer the default image location; otherwise fall back to the
            # configured path, fetching the data first when it is missing.
            default_dir = test_cfg["default_image_path"]
            if os.path.exists(default_dir):
                cls.image_path = default_dir
            else:
                if not os.path.exists(test_cfg["image_path"]):
                    download_data()
                cls.image_path = test_cfg["image_path"]

            cls.data_loader_test = DataLoader(
                dataset=RSNADataSet(test_cfg['dummy_valid_list'],
                                    test_cfg['dummy_labels'],
                                    cls.image_path,
                                    transform=True),
                batch_size=1,
                shuffle=False,
                num_workers=4,
                pin_memory=False,
            )

            # Both model coefficients are raised to the configured ``phi``.
            alpha_scaled = test_cfg['alpha'] ** test_cfg['phi']
            beta_scaled = test_cfg['beta'] ** test_cfg['phi']
            class_count = test_cfg['class_count']
            cls.model = DenseNet121Eff(alpha_scaled, beta_scaled, class_count)
            cls.class_names = test_cfg['class_names']

            # Use the GPU when one is available, otherwise run on the CPU.
            if torch.cuda.is_available():
                device_name = "cuda"
            else:
                device_name = "cpu"
            cls.device = torch.device(device_name)

            cls.inference = RSNAInference(cls.model,
                                          cls.data_loader_test,
                                          class_count,
                                          checkpoint,
                                          cls.class_names,
                                          cls.device)
            cls.exporter = Exporter(test_cfg, optimised=True)
Пример #3
0
def export(args):
    """Export the model in the formats requested on the command line.

    Args:
        args: Parsed arguments providing ``optimised`` (forwarded to both
            ``get_config`` and ``Exporter``), ``onnx`` and ``ir``. A truthy
            ``onnx`` triggers ``export_model_onnx``; a truthy ``ir`` triggers
            ``export_model_ir``. ONNX export runs first when both are set.
    """
    exporter = Exporter(
        get_config(action='export', optimised=args.optimised),
        args.optimised,
    )

    wants_onnx = args.onnx
    if wants_onnx:
        exporter.export_model_onnx()

    wants_ir = args.ir
    if wants_ir:
        exporter.export_model_ir()
Пример #4
0
 def test_config(self):
     """Check the optimised export configuration.

     Verifies that the checkpoint path and its parent directory are both
     non-empty and that the configured input shape is
     ``[1, 3, 1024, 1024]``.
     """
     export_cfg = get_config(action='export', optimised=True)
     self.config = export_cfg

     checkpoint = export_cfg['checkpoint']
     self.model_path = checkpoint
     self.input_shape = export_cfg['input_shape']
     # Directory part of the checkpoint path.
     self.output_dir = os.path.dirname(checkpoint)

     for required in (self.output_dir, self.model_path):
         self.assertTrue(required)
     self.assertListEqual(self.input_shape, [1, 3, 1024, 1024])
Пример #5
0
 def test_export_onnx(self):
     """Export the optimised model to ONNX and verify the file exists.

     Downloads the checkpoint when the configured path is missing, runs
     ``Exporter.export_model_onnx`` and then checks that the file named by
     the ``model_name_onnx`` config entry is present in the checkpoint's
     directory.
     """
     self.config = get_config(action='export', optimised=True)
     if not os.path.exists(self.config['checkpoint']):
         download_checkpoint()
     self.exporter = Exporter(self.config, optimised=True)
     self.exporter.export_model_onnx()

     # The test looks for the ONNX file next to the checkpoint.
     checkpoint_dir = os.path.dirname(self.config['checkpoint'])
     onnx_path = os.path.join(checkpoint_dir,
                              self.config.get('model_name_onnx'))
     # Assert on the file's existence: the joined path itself is a
     # non-empty string, so asserting on it directly could never fail.
     self.assertTrue(os.path.exists(onnx_path))
Пример #6
0
 def test_config(self):
     """Sanity-check values of the optimised training configuration.

     Asserts that ``lr`` is at least ``1e-8``, ``class_count`` equals 3,
     ``alpha`` lies in ``[0, 2]`` and ``phi`` lies in ``[-1, 1]``.
     """
     train_cfg = get_config(action='train', optimised=True)
     self.config = train_cfg
     self.learn_rate = train_cfg["lr"]
     self.class_count = train_cfg["class_count"]

     self.assertGreaterEqual(self.learn_rate, 1e-8)
     self.assertEqual(self.class_count, 3)

     # Lower bounds for both keys are checked before any upper bound.
     for key, lower in (('alpha', 0), ('phi', -1)):
         self.assertGreaterEqual(train_cfg[key], lower)
     for key, upper in (('alpha', 2), ('phi', 1)):
         self.assertLessEqual(train_cfg[key], upper)
Пример #7
0
 def test_export_ir(self):
     """Run ``export_model_ir`` and verify its ``.xml``/``.bin`` outputs.

     Downloads the checkpoint when the configured path is missing and
     exports the ONNX model first when it is not already present in the
     checkpoint's directory. After ``export_model_ir`` runs, both
     ``<model_name>.xml`` and ``<model_name>.bin`` must exist in that
     directory.
     """
     export_cfg = get_config(action='export', optimised=True)
     self.config = export_cfg
     if not os.path.exists(export_cfg['checkpoint']):
         download_checkpoint()
     self.exporter = Exporter(export_cfg, optimised=True)

     # Directory part of the checkpoint path; all outputs are looked up here.
     self.model_path = os.path.dirname(export_cfg['checkpoint'])

     onnx_file = os.path.join(self.model_path,
                              export_cfg.get('model_name_onnx'))
     if not os.path.exists(onnx_file):
         self.exporter.export_model_onnx()
     self.exporter.export_model_ir()

     # Probe both output files first, then assert on each in turn.
     produced = [
         os.path.exists(
             os.path.join(self.model_path,
                          export_cfg.get('model_name') + suffix))
         for suffix in ('.xml', '.bin')
     ]
     for status in produced:
         self.assertTrue(status)
 def test_config(self):
     """Check that the test configuration's ``clscount`` entry equals 3."""
     test_cfg = get_config(action='test')
     self.config = test_cfg
     expected_classes = 3
     self.assertEqual(test_cfg['clscount'], expected_classes)
Пример #9
0
 def setUpClass(cls):
     """Load the export configuration and make sure the checkpoint exists.

     Stores the configuration as ``cls.config`` and the checkpoint path as
     ``cls.model_path``, calling ``download_checkpoint`` first when that
     path is missing.

     NOTE(review): no ``@classmethod`` decorator is visible above this
     hook in this excerpt -- confirm it is applied in the full file.
     """
     export_cfg = get_config(action='export')
     cls.config = export_cfg

     checkpoint_missing = not os.path.exists(export_cfg['checkpoint'])
     if checkpoint_missing:
         download_checkpoint()
     cls.model_path = export_cfg['checkpoint']