예제 #1
0
    def __init__(self,
                 path_list_train,
                 path_list_val,
                 n_classes,
                 threadNum=24,
                 queueCapacity=64,
                 kwargs_dat=None):
        """Store the configuration and start the train/validation queues.

        Args:
            path_list_train: Argument for ``read_image_list`` yielding the
                training list. ``None`` skips creating the training queue.
            path_list_val: Argument for ``read_image_list`` yielding the
                validation list. ``None`` skips creating the validation
                queue.
            n_classes: Number of classes; also used to cap
                ``dominating_channel`` at ``n_classes - 1``.
            threadNum: Forwarded to ``_get_list_queue`` for the training
                queue (presumably the number of loader threads - confirm).
            queueCapacity: Forwarded to ``_get_list_queue`` for the training
                queue (presumably the queue size - confirm).
            kwargs_dat: Optional dict of settings. Every key is optional and
                falls back to the default shown below. ``None`` (the
                default) is treated as an empty dict.
        """
        # Avoid a shared mutable default: build a fresh dict per call.
        if kwargs_dat is None:
            kwargs_dat = {}

        self.n_classes = n_classes
        self.list_train = None
        self.size_train = 0
        self.q_train = None
        # Initialised up front so the attribute exists even when no training
        # list is given (previously it was only set inside the branch below).
        self.threads_tr = None
        self.list_val = None
        self.size_val = 0
        self.q_val = None
        self.threads_val = None

        self.mvn = kwargs_dat.get("mvn", False)
        self.batchsize_tr = kwargs_dat.get("batchsize_tr", 1)
        self.batchsize_val = kwargs_dat.get("batchsize_val", 1)
        self.affine_tr = kwargs_dat.get("affine_tr", False)
        self.affine_val = kwargs_dat.get("affine_val", False)
        self.affine_value = kwargs_dat.get("affine_value", 0.025)
        self.elastic_tr = kwargs_dat.get("elastic_tr", False)
        self.elastic_val = kwargs_dat.get("elastic_val", False)
        # The x value used to be read only from the misspelled key
        # "elastic_val_x" while y used "elastic_value_y". Prefer the
        # consistently named key, but keep honouring the legacy spelling so
        # existing configurations behave as before.
        self.elastic_value_x = kwargs_dat.get(
            "elastic_value_x", kwargs_dat.get("elastic_val_x", 0.0002))
        self.elastic_value_y = kwargs_dat.get("elastic_value_y", 0.0002)
        self.rotate_tr = kwargs_dat.get("rotate_tr", False)
        self.rotate_val = kwargs_dat.get("rotate_val", False)
        self.rotateMod90_tr = kwargs_dat.get("rotateMod90_tr", False)
        self.rotateMod90_val = kwargs_dat.get("rotateMod90_val", False)
        self.skelet = kwargs_dat.get("skelet", True)
        self.dilate_num = kwargs_dat.get("dilate_num", 1)
        self.scale_min = kwargs_dat.get("scale_min", 1.0)
        self.scale_max = kwargs_dat.get("scale_max", 1.0)
        self.scale_val = kwargs_dat.get("scale_val", 1.0)
        self.one_hot_encoding = kwargs_dat.get("one_hot_encoding", True)
        # Clamp the requested channel so it never exceeds the last class index.
        self.dominating_channel = min(
            kwargs_dat.get("dominating_channel", 0), n_classes - 1)
        self.shuffle = kwargs_dat.get("shuffle", True)

        self.threadNum = threadNum
        self.queueCapacity = queueCapacity
        # One stop event per queue; each is handed to _get_list_queue below.
        self.stopTrain = threading.Event()
        self.stopVal = threading.Event()
        if path_list_train is not None:
            self.list_train = read_image_list(path_list_train)
            self.size_train = len(self.list_train)
            self.q_train, self.threads_tr = self._get_list_queue(
                self.list_train, self.threadNum, self.queueCapacity,
                self.stopTrain, self.batchsize_tr, self.scale_min,
                self.scale_max, self.affine_tr, self.elastic_tr,
                self.rotate_tr, self.rotateMod90_tr)
        if path_list_val is not None:
            self.list_val = read_image_list(path_list_val)
            self.size_val = len(self.list_val)
            # Validation uses fixed values 1 and 100 in place of
            # threadNum/queueCapacity, and scale_val for both scale arguments.
            self.q_val, self.threads_val = self._get_list_queue(
                self.list_val, 1, 100, self.stopVal, self.batchsize_val,
                self.scale_val, self.scale_val, self.affine_val,
                self.elastic_val, self.rotate_val, self.rotateMod90_val)
예제 #2
0
파일: aru_net_test.py 프로젝트: pengge/SWDT
def main():
    """Run ``inference`` on the images listed in ``./epapyrus_images.lst``.

    The ARU-Net checkout directory is chosen by ``os.name`` and the network
    file ``demo_nets/model100_ema.pb`` is resolved inside it.
    """
    # Location of the local ARU-Net checkout, depending on the platform.
    aru_net_dir_path = (
        '/home/sangwook/lib_repo/python/ARU-Net_github'
        if os.name == 'posix'
        else 'D:/lib_repo/python/rnd/ARU-Net_github'
    )

    # ${ARU-Net_HOME}/demo_nets/model100_ema.pb
    path_to_pb = os.path.join(aru_net_dir_path, 'demo_nets/model100_ema.pb')

    # Other image lists that have been used with this script:
    #   os.path.join(aru_net_dir_path, 'demo_images/imgs.lst')  # ${ARU-Net_HOME}/demo_images/imgs.lst
    #   './keit_images.lst'
    path_list_imgs = './epapyrus_images.lst'

    image_paths = read_image_list(path_list_imgs)
    inference(
        path_to_pb,
        image_paths,
        scale=1.0,
        mode='L',
        print_result=True,
        gpu_device='0',
    )
예제 #3
0
def run(path_list_imgs, path_net_pb):
    """Read an image list and run ``Inference_pb`` over it in mode ``'L'``.

    Args:
        path_list_imgs: Argument passed to ``read_image_list``.
        path_net_pb: First argument passed to ``Inference_pb``
            (presumably the path of a frozen network file - confirm).
    """
    image_paths = read_image_list(path_list_imgs)
    # Build the runner from the network path and the loaded list, then run it.
    Inference_pb(path_net_pb, image_paths, mode='L').inference()
def run(path_list_imgs, path_net_pb):
    """Read an image list, run ``Predict_pb.predict`` on it and print "Done".

    Args:
        path_list_imgs: Argument passed to ``read_image_list``.
        path_net_pb: Argument passed to ``Predict_pb``
            (presumably the path of a frozen network file - confirm).
    """
    image_paths = read_image_list(path_list_imgs)
    predictor = Predict_pb(path_net_pb)
    # The prediction result is kept in a local only; nothing is returned.
    predictions = predictor.predict(image_paths)
    print("Done")