import tensorflow as tf
import cv2
from models.unet import Unet
from data_augmentation.data_augmentation import DataAugmentation
import numpy as np

# Allow TensorFlow to grow GPU memory on demand instead of reserving it all.
# Guard against CPU-only hosts: the original indexed gpus[0] unconditionally,
# which raises IndexError when no GPU is visible.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)

# Initialize
# Initialize
# Paths to the test images and their segmentation masks.
IMAGE_PATH = "dataset/Original/Testing/"
MASK_PATH = "dataset/MASKS/Testing/"
IMAGE_FILE = "Frame00314-org"  # file stem to predict on (no extension)

# Build the U-Net for single-channel 224x224 input and restore trained weights.
# (Removed a leftover debug print("yeah") that followed model.summary().)
model = Unet(input_shape=(224, 224, 1)).build()
model.load_weights("models/model_weight.h5")
model.summary()


def convert_to_tensor(numpy_image):
    """Wrap a numpy image as a batched TF tensor.

    Inserts a trailing channel axis and a leading batch axis, so a 2-D
    (H, W) array becomes a (1, H, W, 1) tensor ready for model input.
    (Input is presumably a 2-D grayscale image — confirm at call sites.)
    """
    batched = np.expand_dims(np.expand_dims(numpy_image, axis=2), axis=0)
    return tf.convert_to_tensor(batched)


def predict(image):
    # NOTE(review): this function appears truncated — the source is cut off
    # mid-body right after this block (no inference call or return is
    # visible), so only the preprocessing steps are documented here.
    # Run the project's test-time preprocessing pipeline on the raw image,
    # then wrap the result as a (1, H, W, 1) tensor for the model.
    process_obj = DataAugmentation(input_size=224, output_size=224)
    image_processed = process_obj.data_process_test(image)
    tensor_image = convert_to_tensor(image_processed)
# --- Scraped-page artifact: "Esempio n. 2" ("Example no. 2") marks the start
# of a second, unrelated snippet below (Keras training callbacks / checkpoint
# restore). The two fragments do not form one runnable script. ---
    #     EarlyStopping(monitor=f'val_{monitors}', patience =10, verbose =1 , mode ='min'),
    # NOTE(review): fragment — the opening `callbacks = [` of this list is
    # missing from the snippet; only the tail of the callback list is visible.
    # Save a weights checkpoint whenever the monitored validation metric
    # improves (mode='min', so lower is better).
    ModelCheckpoint(os.path.join(SEGMENT_RESULT_PATH,
                                 "checkpoint-{epoch:03d}.h5"),
                    monitor=f'val_{monitors}',
                    save_best_only=True,
                    mode='min'),
    # Adjust the learning rate each epoch via the project's lr_scheduler.
    LearningRateScheduler(lr_scheduler, verbose=1),
    # Project-local callback — presumably appends per-epoch history rows to
    # the CSV; verify against its definition.
    HistoryCheckpoint(os.path.join(SEGMENT_RESULT_PATH, "checkpoint_hist.csv"),
                      monitors),
    #     SlackMessage(MY_SLACK_TOKEN,monitors)
]

# Resume from the newest checkpoint in SEGMENT_RESULT_PATH if one exists;
# otherwise start training from epoch 0.
try:
    weight = last_cheackpoint(SEGMENT_RESULT_PATH)
    # "dir/checkpoint-003.h5" -> 3: take the basename first, then parse the
    # zero-padded epoch number between the last '-' and the extension.
    # (Original applied basename() after splitting — same result, less clear.)
    init_epoch = int(os.path.basename(weight).split("-")[-1].split(".")[0])
    unet.load_weights(weight)
    print(
        f"*******************\ncheckpoint restored {weight}\n*******************"
    )
except Exception as err:
    # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit still
    # propagate, and report why the restore failed instead of swallowing it.
    init_epoch = 0
    print(
        "*******************\nfailed to load checkpoint\n*******************")
    print(f"reason: {err}")

# Summary of the run configuration — presumably saved/logged alongside the
# results; the consumer is not visible in this snippet.
train_options = {
    "optimizer": get_config(optim),
    "batchsize": BATCH_SIZE,
    "loss_function": loss_func,
    "input_shape": IMAGE_SHAPE,
    # NOTE(review): key is misspelled ("augmemtation" vs "augmentation") —
    # left as-is because downstream readers may depend on the exact key.
    "augmemtation": augm
}