def main(pb_file, img_file):
    """
    Run a frozen TensorFlow graph over a validation set and visualize the result.

    The ``GraphDef`` stored in ``pb_file`` is imported into a fresh graph and
    every operation name is printed. Each (image, mask) pair listed by
    ``get_train_val`` for the hard-coded ``img_root`` is then transformed with
    ``valAug(size=(256, 256))``, fed to the ``input_1:0`` tensor, and the
    ``proba/Sigmoid:0`` output is resized back to the original image size and
    handed to ``vis_segmentation`` together with the image and mask.

    :param pb_file: path of the serialized ``GraphDef`` file to load.
    :param img_file: not used by this function; kept so existing callers
        keep working.
    :return: None
    """
    # Local import, as in the original; only needed for timing below.
    import time

    with tf.gfile.GFile(pb_file, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    with tf.Graph().as_default() as graph:
        # NOTE(review): `prefix` is not defined inside this function, so it has
        # to come from module scope. The tensor names looked up below carry no
        # prefix, so it is presumably an empty string -- TODO confirm.
        tf.import_graph_def(graph_def, name=prefix)

    for op in graph.get_operations():
        print(op.name)
    y = graph.get_tensor_by_name('proba/Sigmoid:0')
    x = graph.get_tensor_by_name('input_1:0')

    # The evaluation data location is fixed here rather than taken from
    # `img_file`.
    img_root = '/media/hszc/data1/seg_data/diy_seg'
    _, val_pd = get_train_val(img_root, test_size=1.0, random_state=42)
    img_paths = val_pd['image_paths'].tolist()
    mask_paths = val_pd['mask_paths'].tolist()

    transform = valAug(size=(256, 256))
    deNormalizer = deNormalize(mean=None, std=None)

    with tf.Session(graph=graph) as sess:
        for img_path, mask_path in zip(img_paths, mask_paths):
            bgr = cv2.imread(img_path)
            if bgr is None:
                # cv2.imread signals a missing/unreadable file by returning
                # None; report it and move on instead of failing in cvtColor.
                print('skip unreadable image: %s' % img_path)
                continue
            img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
            img_h, img_w = img.shape[:2]

            mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
            if mask is None:
                # No mask on disk: fall back to an all-zero mask of image size.
                mask = np.zeros((img_h, img_w))

            # Single-argument print calls behave the same on Python 2 and 3.
            print('==' * 20)
            print(img_path)
            print(img.shape)
            print(mask.shape)

            # Validation augmentation, then add a leading batch axis.
            img, _ = transform(img, mask)
            img_batched = img[np.newaxis, :, :, :]

            start_time = time.time()
            pred = sess.run(y, feed_dict={x: img_batched})
            end_time = time.time()
            # Formatted as one string so Python 2 does not print a tuple repr.
            print('time: %s' % (end_time - start_time))
            # Reference timings left by the original author ("jian's"):
            # 224: 0.018, 448: 0.05 (presumably seconds per image at that
            # input size -- TODO confirm).

            # Bring both the de-normalized input and the prediction back to
            # the original resolution before visualizing.
            img = cv2.resize(deNormalizer(img), dsize=(img_w, img_h))
            pred = cv2.resize(pred[0, :, :, 0], dsize=(img_w, img_h))
            print('%s %s' % (img.shape, pred.shape))

            vis_segmentation(img, mask, pred)
# Script fragment: configuration and data preparation.
#   * defines the data roots, the output directory `save_dir`, the input
#     shape, the batch size `bs`, the `do_para` flag, the checkpoint path
#     `resume` and `T`;
#   * creates `save_dir` with os.mkdir when it does not exist yet;
#   * loads the training frame from `train_root` (test_size=0.0) and the
#     validation frame from `val_root` (test_size=1.0) via get_train_val;
#   * builds the data set / loader with valAug() for BOTH splits and with
#     shuffling disabled for both;
#   * starts constructing `trained_model` with MobileUNet_s2.
# NOTE(review): the MobileUNet_s2(...) call is cut off in this view (its
# argument list is never closed), so the rest of the script is unknown here.
# NOTE(review): `print train_pd.info()` is Python 2 syntax; if these are pandas
# DataFrames, info() prints its report itself and returns None, so an extra
# "None" line is printed -- TODO confirm.
# NOTE(review): os.mkdir creates only the last path component; presumably the
# parent directory already exists -- verify.
train_root = '/media/hszc/data1/seg_data' val_root = '/media/hszc/data1/seg_data/diy_seg' save_dir = '/media/hszc/data1/seg_data/diy_seg/pred_mask(T20)' img_shape = (256, 256) bs = 8 do_para = False resume = './saved_models/MUs2(4-2 0p5)_256/bestmodel-[0.8534].h5' T = 20 if not os.path.exists(save_dir): os.mkdir(save_dir) # prepare data train_pd, _ = get_train_val(train_root, test_size=0.0) _, val_pd = get_train_val(val_root, test_size=1.0) print train_pd.info() print val_pd.info() data_set, data_loader = gen_dataloader(train_pd, val_pd, valAug(), valAug(), train_bs=bs, val_bs=2, train_shuffle=False, val_shuffle=False) trained_model = MobileUNet_s2(input_shape=(256, 256, 3),
# Script fragment: run configuration, logging setup and data preparation.
# `train_root` and `val_root` are not assigned here; they must already exist
# in module scope when this code runs.

# --- run configuration -----------------------------------------------------
img_shape = (256, 256)
save_dir = './saved_models/MUs2(4-2 0p5)_256-dis/'
# `alpha` and `T` are not used in the visible part of the script.
# NOTE(review): presumably loss weight / temperature for the `dis` mode --
# TODO confirm against the training call.
alpha = 0.95
T = 20
bs = 8
do_para = False
resume = './saved_models/MUs2(4-2 0p5)_256-dis/weights-[0-0]-[0.4009].h5'

# --- output directory and log file -----------------------------------------
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

# `save_dir` already ends with '/', so os.path.join is used to avoid the
# doubled separator that '%s/trainlog.log' % save_dir produced.
logfile = os.path.join(save_dir, 'trainlog.log')
trainlog(logfile)

# --- prepare data ------------------------------------------------------------
# Training frame: everything under train_root (test_size=0.0), with the
# `dis` option set to 'pred_mask(T20)'.
train_pd, _ = get_train_val(train_root, test_size=0.0, dis='pred_mask(T20)')
# Validation frame: everything under val_root (test_size=1.0).
_, val_pd = get_train_val(val_root, test_size=1.0)

# info() writes its report to stdout itself and returns None (assuming these
# are pandas DataFrames), so it is called directly instead of being printed;
# this also avoids Python-2-only print-statement syntax.
train_pd.info()
val_pd.info()

data_set, data_loader = gen_dataloader(
    train_pd,
    val_pd,
    trainAug(),
    valAug(),
    train_bs=bs,
    val_bs=2,
    dis=True,
)

# logging info of dataset
logging.info(train_pd.shape)