Пример #1
0
 def test_process_2D(self):
     """Smoke-test ``ModelHandler`` with a stack of four 2D convolutions.

     The handler is built for a 3-channel model on ``cuda:0``. The value
     returned by ``binary_dry_run([3096, 2048])`` is used as the spatial part
     of an all-zero float32 input (batch size 1, 3 channels) that is then
     passed to ``forward``. No assertions are made on the output.
     """
     conv_stack = [
         nn.Conv2d(3, 512, 3),
         nn.Conv2d(512, 512, 3),
         nn.Conv2d(512, 512, 3),
         nn.Conv2d(512, 3, 3),
     ]
     handler = ModelHandler(
         model=nn.Sequential(*conv_stack),
         channels=3,
         device_names="cuda:0",
         dynamic_shape_code="(32 * (nH + 1), 32 * (nW + 1))",
     )
     spatial_shape = handler.binary_dry_run([3096, 2048])
     # Prepend the batch and channel dimensions to the dry-run result.
     input_dims = [1, 3] + spatial_shape
     zeros = torch.zeros(*input_dims, dtype=torch.float32)
     out = handler.forward(zeros)
Пример #2
0
 def test_process_3D(self):
     """Smoke-test ``ModelHandler`` with a stack of four 3D convolutions.

     The handler is built for a 1-channel model on ``cuda:0``. The value
     returned by ``binary_dry_run([128, 512, 512])`` is used as the spatial
     part of an all-zero float32 input (batch size 1, 1 channel) that is then
     passed to ``forward``. No assertions are made on the output.
     """
     conv_stack = [
         nn.Conv3d(1, 12, 3),
         nn.Conv3d(12, 12, 3),
         nn.Conv3d(12, 12, 3),
         nn.Conv3d(12, 1, 3),
     ]
     handler = ModelHandler(
         model=nn.Sequential(*conv_stack),
         channels=1,
         device_names="cuda:0",
         dynamic_shape_code="(32 * (nD + 1), 32 * (nH + 1), 32 * (nW + 1))",
     )
     spatial_shape = handler.binary_dry_run([128, 512, 512])
     # Prepend the batch and channel dimensions to the dry-run result.
     input_dims = [1, 1] + spatial_shape
     zeros = torch.zeros(*input_dims, dtype=torch.float32)
     out = handler.forward(zeros)
Пример #3
0
class TestUNet(unittest.TestCase):
    """Run a ``UNet2dGN`` model loaded from disk through ``ModelHandler``.

    The model definition, its weights and the sample volume are all read from
    hard-coded absolute paths, and the handler targets ``cuda:0``, so this
    test only runs on a machine that provides those files and a CUDA device.
    """

    def setUp(self):
        """Import the model class by file path, load its weights, build the handler.

        Raises:
            FileNotFoundError: If the weights file does not exist. Any other
                loading error (e.g. a state-dict mismatch) propagates
                unchanged so its real cause stays visible.
        """
        path = "/export/home/jhugger/sfb1129/pretrained_net_constantin/ISBI2012_UNet_pretrained/"
        model_file_name = path + "model.py"

        # The model class lives in a standalone file, so import it by path.
        module_spec = imputils.spec_from_file_location("model", model_file_name)
        module = imputils.module_from_spec(module_spec)
        module_spec.loader.exec_module(module)
        model: torch.nn.Module = module.UNet2dGN(
            in_channels=1, initial_features=64, out_channels=1
        )

        state_path = path + "state.nn"
        try:
            # Returning the storage unchanged from map_location keeps the
            # loaded tensors on the CPU.
            state_dict = torch.load(
                state_path, map_location=lambda storage, loc: storage
            )
        except FileNotFoundError as err:
            raise FileNotFoundError(
                f"Model weights could not be found at location '{state_path}'!"
            ) from err
        model.load_state_dict(state_dict)

        self.handler = ModelHandler(
            model=model,
            channels=1,
            device_names="cuda:0",
            dynamic_shape_code="(32 * (nH + 1), 32 * (nW + 1))",
        )

    def test_model(self):
        """Forward a random ``(1, 1, 572, 572)`` tensor and save channel 0 as JPEG."""
        # The unittest runner calls setUp() before every test method, so it is
        # not invoked again here (doing so would load the model twice).
        transform = Compose(Normalize(), Cast("float32"))

        # Open read-only: without an explicit mode, h5py < 3.0 opens the file
        # read/write and creates it if it is missing.
        with h5py.File(
            "/export/home/jhugger/sfb1129/sample_C_20160501.hdf", "r"
        ) as f:
            cremi_raw = f["volumes"]["raw"][0:1, 0:1248, 0:1248]

        # NOTE(review): the tensor built from the sample volume is replaced by
        # random input on the next line, so the volume never reaches the
        # model. Kept as-is; confirm which input is intended.
        input_tensor = torch.from_numpy(transform(cremi_raw[0:1]))
        input_tensor = torch.rand(1, 572, 572)

        batch = torch.unsqueeze(input_tensor, 0)  # add the batch dimension
        print(batch.shape)
        out = self.handler.forward(batch)

        import scipy

        # NOTE(review): scipy.misc.imsave was removed in SciPy 1.2; on newer
        # SciPy these calls raise AttributeError. The documented replacement
        # is imageio.imwrite, which would add a new dependency.
        image = out[0, 0].data.cpu().numpy()
        scipy.misc.imsave("/export/home/jhugger/sfb1129/tiktorch/out.jpg", image)
        scipy.misc.imsave("/home/jo/server/tiktorch/out.jpg", image)
Пример #4
0
class TestDenseUNet(unittest.TestCase):
    """Run a ``DUNet2D`` model loaded from disk through ``ModelHandler``.

    The model definition, its weights and the sample volume are all read from
    hard-coded absolute paths, and the handler targets ``cuda:0``, so this
    test only runs on a machine that provides those files and a CUDA device.
    """

    def setUp(self):
        """Import the model class by file path, load its weights, build the handler.

        Raises:
            FileNotFoundError: If the weights file does not exist. Any other
                loading error (e.g. a state-dict mismatch) propagates
                unchanged so its real cause stays visible.
        """
        model_file_name = "/export/home/jhugger/sfb1129/CREMI_DUNet_pretrained/model.py"

        # The model class lives in a standalone file, so import it by path.
        module_spec = imputils.spec_from_file_location("model", model_file_name)
        module = imputils.module_from_spec(module_spec)
        module_spec.loader.exec_module(module)
        model: torch.nn.Module = module.DUNet2D(in_channels=1, out_channels=1)

        state_path = "/export/home/jhugger/sfb1129/CREMI_DUNet_pretrained/state.nn"
        try:
            # Returning the storage unchanged from map_location keeps the
            # loaded tensors on the CPU.
            state_dict = torch.load(
                state_path, map_location=lambda storage, loc: storage
            )
        except FileNotFoundError as err:
            raise FileNotFoundError(
                f"Model weights could not be found at location '{state_path}'!"
            ) from err
        model.load_state_dict(state_dict)

        self.handler = ModelHandler(
            model=model,
            channels=1,
            device_names="cuda:0",
            dynamic_shape_code="(32 * (nH + 1), 32 * (nW + 1))",
        )

    def test_model(self):
        """Forward a crop of the sample volume and save channel 0 as JPEG.

        The crop size is whatever ``binary_dry_run([1250, 1250])`` returns.
        """
        # The unittest runner calls setUp() before every test method, so it is
        # not invoked again here (doing so would load the model twice).
        shape = self.handler.binary_dry_run([1250, 1250])
        transform = Compose(Normalize(), Cast("float32"))

        # Open read-only: without an explicit mode, h5py < 3.0 opens the file
        # read/write and creates it if it is missing.
        with h5py.File(
            "/export/home/jhugger/sfb1129/sample_C_20160501.hdf", "r"
        ) as f:
            cremi_raw = f["volumes"]["raw"][0:1, 0:shape[0], 0:shape[1]]

        input_tensor = torch.from_numpy(transform(cremi_raw[0:1]))
        batch = torch.unsqueeze(input_tensor, 0)  # add the batch dimension
        out = self.handler.forward(batch)

        import scipy

        # NOTE(review): scipy.misc.imsave was removed in SciPy 1.2; on newer
        # SciPy this call raises AttributeError. The documented replacement
        # is imageio.imwrite, which would add a new dependency.
        scipy.misc.imsave(
            "/export/home/jhugger/sfb1129/tiktorch/out.jpg",
            out[0, 0].data.cpu().numpy(),
        )