def __init__(self, path_list_train, path_list_val, n_classes, threadNum=24, queueCapacity=64, kwargs_dat=None):
    """Read data/augmentation options and set up the train/validation queues.

    Args:
        path_list_train: Path to an image-list file for training, or ``None``
            to skip creating the training queue.
        path_list_val: Path to an image-list file for validation, or ``None``
            to skip creating the validation queue.
        n_classes: Number of classes; ``dominating_channel`` is capped at
            ``n_classes - 1``.
        threadNum: Thread count passed to ``_get_list_queue`` for training.
        queueCapacity: Queue capacity passed to ``_get_list_queue`` for training.
        kwargs_dat: Optional dict of options (``mvn``, ``batchsize_tr``,
            ``affine_tr``, ``elastic_value_x``, ``scale_min``, ...). Missing
            keys fall back to the defaults below; ``None`` means all defaults.
    """
    # A fresh dict per call instead of a shared mutable default argument.
    if kwargs_dat is None:
        kwargs_dat = {}
    opt = kwargs_dat.get

    self.n_classes = n_classes

    # Queue state; filled in below only for the lists that were provided.
    self.list_train = None
    self.size_train = 0
    self.q_train = None
    self.threads_tr = None
    self.list_val = None
    self.size_val = 0
    self.q_val = None
    self.threads_val = None

    self.mvn = opt("mvn", False)
    self.batchsize_tr = opt("batchsize_tr", 1)
    self.batchsize_val = opt("batchsize_val", 1)
    self.affine_tr = opt("affine_tr", False)
    self.affine_val = opt("affine_val", False)
    self.affine_value = opt("affine_value", 0.025)
    self.elastic_tr = opt("elastic_tr", False)
    self.elastic_val = opt("elastic_val", False)
    # The x key used to be read as "elastic_val_x" (inconsistent with
    # "elastic_value_y"); prefer the consistent name, keep the legacy one
    # as a fallback so existing configs still work.
    self.elastic_value_x = opt("elastic_value_x", opt("elastic_val_x", 0.0002))
    self.elastic_value_y = opt("elastic_value_y", 0.0002)
    self.rotate_tr = opt("rotate_tr", False)
    self.rotate_val = opt("rotate_val", False)
    self.rotateMod90_tr = opt("rotateMod90_tr", False)
    self.rotateMod90_val = opt("rotateMod90_val", False)
    self.skelet = opt("skelet", True)
    self.dilate_num = opt("dilate_num", 1)
    self.scale_min = opt("scale_min", 1.0)
    self.scale_max = opt("scale_max", 1.0)
    self.scale_val = opt("scale_val", 1.0)
    self.one_hot_encoding = opt("one_hot_encoding", True)
    # Clamp so the dominating channel is always a valid class index.
    self.dominating_channel = min(opt("dominating_channel", 0), n_classes - 1)
    self.shuffle = opt("shuffle", True)

    self.threadNum = threadNum
    self.queueCapacity = queueCapacity
    self.stopTrain = threading.Event()
    self.stopVal = threading.Event()

    if path_list_train is not None:
        self.list_train = read_image_list(path_list_train)
        self.size_train = len(self.list_train)
        self.q_train, self.threads_tr = self._get_list_queue(
            self.list_train, self.threadNum, self.queueCapacity, self.stopTrain,
            self.batchsize_tr, self.scale_min, self.scale_max, self.affine_tr,
            self.elastic_tr, self.rotate_tr, self.rotateMod90_tr)

    if path_list_val is not None:
        self.list_val = read_image_list(path_list_val)
        self.size_val = len(self.list_val)
        # Validation uses a fixed single thread and capacity 100 (not
        # threadNum/queueCapacity), and a fixed scale (scale_val for min and max).
        self.q_val, self.threads_val = self._get_list_queue(
            self.list_val, 1, 100, self.stopVal, self.batchsize_val,
            self.scale_val, self.scale_val, self.affine_val, self.elastic_val,
            self.rotate_val, self.rotateMod90_val)
def main():
    """Run ``inference`` on the images listed in ``./epapyrus_images.lst``.

    The ARU-Net checkout directory is picked by ``os.name``: a fixed home
    directory path on POSIX, a fixed ``D:`` path otherwise. The model file is
    ``demo_nets/model100_ema.pb`` inside that directory. ``inference`` is
    called with ``scale=1.0``, ``mode='L'``, ``print_result=True`` and
    ``gpu_device='0'``.
    """
    # Machine-specific checkout locations of the ARU-Net repository.
    if os.name != 'posix':
        repo_root = 'D:/lib_repo/python/rnd/ARU-Net_github'
    else:
        repo_root = '/home/sangwook/lib_repo/python/ARU-Net_github'
    model_pb = os.path.join(repo_root, 'demo_nets/model100_ema.pb')

    # Alternative image lists that can be swapped in here:
    #   os.path.join(repo_root, 'demo_images/imgs.lst')  (repository demo images)
    #   './keit_images.lst'
    image_list_file = './epapyrus_images.lst'
    images = read_image_list(image_list_file)
    inference(model_pb, images, scale=1.0, mode='L', print_result=True, gpu_device='0')
def run(path_list_imgs, path_net_pb):
    """Load the image list and run ``Inference_pb(...).inference()`` on it.

    Args:
        path_list_imgs: Path to the image-list file, read with ``read_image_list``.
        path_net_pb: Path to the network file handed to ``Inference_pb``.

    NOTE(review): another ``run`` is defined later in this file and rebinds the
    name, so this version is not reachable as ``run`` after module load —
    confirm which definition is intended.
    """
    image_list = read_image_list(path_list_imgs)
    # Named so it does not shadow the module-level ``inference`` function.
    runner = Inference_pb(path_net_pb, image_list, mode='L')
    runner.inference()
def run(path_list_imgs, path_net_pb):
    """Predict on every image in the list file using ``Predict_pb``.

    Args:
        path_list_imgs: Path to the image-list file, read with ``read_image_list``.
        path_net_pb: Path to the network file handed to ``Predict_pb``.

    Returns:
        The value returned by ``Predict_pb(...).predict`` for the image list.
        It used to be computed and dropped; callers that ignore the return
        value are unaffected.
    """
    list_inf = read_image_list(path_list_imgs)
    # Named so it does not shadow the module-level ``inference`` function.
    predictor = Predict_pb(path_net_pb)
    outputs = predictor.predict(list_inf)
    print("Done")
    return outputs