def __init__(self, tiktorch_net, filename=None, HALO_SIZE=32, BATCH_SIZE=3):
    """
    Args:
        tiktorch_net (TikTorch or None): tiktorch object to be loaded into this
            classifier object. If None, the network is unserialized from
            ``filename``.
        filename (str, optional): Save file name for future reference; also the
            file the network is loaded from when ``tiktorch_net`` is None.
            Stored as "" when omitted.
        HALO_SIZE (int, optional): stored as ``self.HALO_SIZE``; not used by
            this method (presumably the halo added around each block at
            prediction time -- TODO confirm).
        BATCH_SIZE (int, optional): stored as ``self.BATCH_SIZE``; not used by
            this method (presumably the number of blocks per forward pass --
            TODO confirm).

    Raises:
        ValueError: if ``tiktorch_net`` is None and no ``filename`` was given,
            because there is then nothing to unserialize the network from.
    """
    self._filename = filename if filename is not None else ""
    self.HALO_SIZE = HALO_SIZE
    self.BATCH_SIZE = BATCH_SIZE

    if tiktorch_net is None:
        # Fail fast instead of asking TikTorch to unserialize from "".
        if not self._filename:
            raise ValueError(
                "tiktorch_net is None and no filename was given to load it from"
            )
        logger.debug("Loading tiktorch network from %s", self._filename)
        tiktorch_net = TikTorch.unserialize(self._filename)

    self._tiktorch_net = tiktorch_net

    # Axis-reordering operator on its own Graph, with AxisOrder set to "zcyx".
    self._opReorderAxes = OpReorderAxes(graph=Graph())
    self._opReorderAxes.AxisOrder.setValue("zcyx")
def _test_real_serialization(self):
    """Round-trip a TikTorch model (not moved to the GPU) through a file on disk.

    The model is serialized into a temporary directory and read back. The test
    checks that the result is again a TikTorch instance.

    NOTE(review): the leading underscore means unittest's default loader (which
    only collects names starting with "test") skips this method; presumably
    disabled on purpose -- confirm.
    """
    original = TikTorch(model=TinyConvNet3D())
    with tempdir() as directory:
        model_file = '{}/testfile.nn'.format(directory)
        original.serialize(to_path=model_file)
        restored = TikTorch.unserialize(from_path=model_file)
        self.assertIsInstance(restored, TikTorch)
def _test_gpu_serialization(self):
    """Round-trip a TikTorch model after calling ``cuda()`` on it.

    The model is serialized into a temporary directory and read back. The test
    checks that the result is a TikTorch instance and reports ``is_cuda``.

    NOTE(review): the leading underscore means unittest's default loader (which
    only collects names starting with "test") skips this method; presumably
    disabled on purpose, e.g. for machines without a GPU -- confirm.
    """
    gpu_torch = TikTorch(model=TinyConvNet3D())
    # cuda()'s return value is ignored, i.e. it is relied on to act in place.
    gpu_torch.cuda()
    with tempdir() as directory:
        model_file = "{}/testfile.nn".format(directory)
        gpu_torch.serialize(to_path=model_file)
        restored = TikTorch.unserialize(from_path=model_file)
        self.assertIsInstance(restored, TikTorch)
        self.assertTrue(restored.is_cuda)
def create_and_train_pixelwise(self, feature_images, label_images, axistags=None, feature_names=None):
    """Load the network from PYTORCH_MODEL_FILE_PATH and wrap it in a classifier.

    Despite the name, no training happens here: ``feature_images``,
    ``label_images``, ``axistags`` and ``feature_names`` are not used.

    Side effects:
        Sets ``self._filename`` to PYTORCH_MODEL_FILE_PATH and
        ``self._loaded_pytorch_net`` to the unserialized network, then logs
        ``self.description`` at INFO level.

    Returns:
        TikTorchLazyflowClassifier built from the loaded network and its path.
    """
    model_path = PYTORCH_MODEL_FILE_PATH
    # Remember where the network came from.
    self._filename = model_path
    logger.debug("Loading pytorch network from {}".format(model_path))

    # TODO: check whether loaded network has the same number of classes as specified in ilastik!
    net = TikTorch.unserialize(model_path)
    self._loaded_pytorch_net = net

    logger.info(self.description)
    return TikTorchLazyflowClassifier(net, model_path)
def deserialize_hdf5(cls, h5py_group):
    """Rebuild a TikTorchLazyflowClassifier from an HDF5 group.

    Reads the filename from ``h5py_group[cls.HDF5_GROUP_FILENAME]`` and the
    serialized network from ``h5py_group["classifier"]``. The network bytes are
    written to a temporary file so that ``TikTorch.unserialize`` can read them
    from a file object.

    Args:
        h5py_group: h5py group holding the datasets named above (assumed to
            have been written by this classifier's HDF5 serializer -- the writer
            is not visible here).

    Returns:
        TikTorchLazyflowClassifier wrapping the unserialized network, with the
        stored filename as a ``str``.
    """
    logger.debug("Deserializing")

    # Indexing the group alone yields an h5py Dataset object; ``[()]`` reads the
    # stored value itself so the classifier receives the actual filename.
    filename = h5py_group[cls.HDF5_GROUP_FILENAME][()]
    if isinstance(filename, bytes):
        # h5py >= 3 (and fixed-length string datasets) return bytes.
        filename = filename.decode("utf-8")
    logger.debug("Deserializing from {}".format(filename))

    # ``[()]`` reads the whole scalar dataset; it is what ``Dataset.value``
    # aliased, and ``.value`` was removed in h5py 3.0.
    with tempfile.TemporaryFile() as f:
        f.write(h5py_group["classifier"][()])
        f.seek(0)
        loaded_pytorch_net = TikTorch.unserialize(f)

    return TikTorchLazyflowClassifier(loaded_pytorch_net, filename)