Exemplo n.º 1
0
def select_model(args):
    """Build the segmentation model named by ``args.model``.

    Supported names are ``unet``, ``unet_mini``, ``resunet`` and ``segnet``.
    When ``args.weights`` is truthy it is passed to ``model.load_weights``
    before the model is returned.

    Args:
        args: parsed options; ``model`` and ``weights`` are read here and the
            whole namespace is forwarded to the chosen model constructor.

    Returns:
        The constructed (and optionally weight-loaded) model.

    Raises:
        SystemExit: with status 1 when ``args.model`` is not a supported name.
    """
    # Lambdas defer the builder lookup so only the selected module is touched.
    # NOTE(review): "unet" calls unet_model.unet while other functions in this
    # codebase call unet_model.UNet — confirm which constructor is intended.
    builders = {
        "unet": lambda: unet_model.unet(args),
        "resunet": lambda: resunet_model.build_res_unet(args),
        "segnet": lambda: segnet_model.create_segnet(args),
        "unet_mini": lambda: unet_mini.UNet(args),
    }
    build = builders.get(args.model)
    if build is None:
        print(str(args.model) + ": Model does not exist, select model from"
              " unet, unet_mini, resunet and segnet", file=sys.stderr)
        # Non-zero status so shells/CI see the failure; bare sys.exit() means success.
        sys.exit(1)

    model = build()
    if args.weights:
        model.load_weights(args.weights)
    return model
Exemplo n.º 2
0
def summary(args):
    """Build the model named by ``args.model`` and call its ``summary()``.

    Accepted names: ``unet``, ``resunet``, ``segnet``, ``linknet`` and
    ``DLinkNet``. Any other name prints the list of accepted names and
    returns without building anything.

    Args:
        args: parsed options; ``model`` is read here and the namespace is
            forwarded to the model constructors that take it.
    """
    builders = {
        "unet": lambda: unet_model.UNet(args),
        "resunet": lambda: resunet_model.build_res_unet(args),
        "segnet": lambda: segnet_model.create_segnet(args),
        # Optional LinkNet arguments kept from the original code:
        #   pretrained_encoder='True',
        #   weights_path='./checkpoints/linknet_encoder_weights.h5'
        # NOTE(review): the input shape is hard-coded rather than read from args.
        "linknet": lambda: LinkNet(1, input_shape=(256, 256, 3)).get_model(),
        # NOTE(review): "DLinkNet" builds a SegNet, which looks like a
        # copy-paste placeholder — confirm the intended builder.
        "DLinkNet": lambda: segnet_model.create_segnet(args),
    }
    build = builders.get(args.model)
    if build is None:
        print("The model name should be from the unet, resunet, linknet, "
              "DLinkNet or segnet")
        return
    build().summary()
Exemplo n.º 3
0
def _read_path_pairs(csv_path):
    """Read a CSV whose first two columns are image and label paths.

    Blank lines are skipped; columns after the second are ignored.

    Returns:
        A tuple ``(image_paths, label_paths)`` of equal-length lists.
    """
    image_paths = []
    label_paths = []
    # The csv module documentation requires newline='' so the reader handles
    # line endings itself.
    with open(csv_path, 'r', newline='') as csvfile:
        for row in csv.reader(csvfile, delimiter=','):
            if not row:  # csv.reader yields [] for an empty line
                continue
            image_paths.append(row[0])
            label_paths.append(row[1])
    return image_paths, label_paths


def train(args):
    """Train the model named by ``args.model`` on CSV-listed image/label pairs.

    Args:
        args: parsed options. Reads ``train_csv`` and ``valid_csv`` (CSV files
            whose first two columns are image and label paths), ``model``
            (``unet``, ``resunet`` or ``segnet``), ``input_shape``,
            ``batch_size``, ``epochs`` and ``model_name``.

    Side effects:
        ``ModelCheckpoint`` writes the best model (lowest ``val_loss``) to
        ``<model_name><epochs>_<dd_mm_yy>.hdf5``, and ``CSVLogger`` writes
        the training log to the same stem with a ``.log`` suffix.

    Raises:
        ValueError: if ``args.model`` is not a supported name.
    """
    image_paths, label_paths = _read_path_pairs(args.train_csv)
    valid_image_paths, valid_label_paths = _read_path_pairs(args.valid_csv)

    if args.model == "unet":
        model = unet_model.UNet(args)
    elif args.model == "resunet":
        model = resunet_model.build_res_unet(args)
    elif args.model == "segnet":
        model = segnet_model.create_segnet(args)
    else:
        # Fail fast: without a model there is nothing to compile or fit.
        raise ValueError(
            "The model name should be from the unet, resunet or segnet")

    model.compile(optimizer="adam",
                  loss="binary_crossentropy",
                  metrics=["acc"])

    input_shape = args.input_shape
    # NOTE(review): input_shape[1] is used as the patch size and input_shape[2]
    # as the channel count, i.e. presumably (height, width, channels).
    train_gen = datagen.DataGenerator(image_paths,
                                      label_paths,
                                      batch_size=args.batch_size,
                                      n_channels=input_shape[2],
                                      patch_size=input_shape[1],
                                      shuffle=True)
    valid_gen = datagen.DataGenerator(valid_image_paths,
                                      valid_label_paths,
                                      batch_size=args.batch_size,
                                      n_channels=input_shape[2],
                                      patch_size=input_shape[1],
                                      shuffle=True)
    train_steps = len(image_paths) // args.batch_size
    valid_steps = len(valid_image_paths) // args.batch_size

    # Compute the date stamp once so the checkpoint and log always share it.
    stem = (args.model_name + str(args.epochs) +
            datetime.datetime.today().strftime("_%d_%m_%y"))
    model_file = stem + ".hdf5"
    log_file = stem + ".log"

    model_checkpoint = ModelCheckpoint(model_file,
                                       monitor='val_loss',
                                       verbose=1,
                                       save_best_only=True)
    csv_logger = CSVLogger(log_file, separator=',', append=False)
    model.fit_generator(train_gen,
                        validation_data=valid_gen,
                        steps_per_epoch=train_steps,
                        validation_steps=valid_steps,
                        epochs=args.epochs,
                        callbacks=[model_checkpoint, csv_logger])

    print("Model successfully trained")
Exemplo n.º 4
0
def _build_binary_eval_model(args):
    """Construct the model selected by ``args`` and load ``args.weights``.

    Raises:
        ValueError: for a model/onehot combination with no builder.
    """
    if args.onehot == "yes":
        if args.model != "resunet":
            raise ValueError(
                "Only the resunet model is supported when onehot is 'yes'")
        model = resunet_novel2.build_res_unet(args)
    elif args.model == "unet":
        model = unet_model.UNet(args)
    elif args.model == "resunet":
        model = resunet_model.build_res_unet(args)
    elif args.model == "segnet":
        model = segnet_model.create_segnet(args)
    else:
        raise ValueError(
            "The model name should be from the unet, resunet or segnet")
    model.load_weights(args.weights)
    return model


def _read_eval_paths(paths_file):
    """Read image, label and optional third-column paths from a CSV file.

    Blank lines are skipped. The third list only receives entries from rows
    that have more than two columns.
    """
    image_paths, label_paths, pred_paths = [], [], []
    with open(paths_file, 'r', newline='') as csvfile:
        for row in csv.reader(csvfile, delimiter=','):
            if not row:  # csv.reader yields [] for an empty line
                continue
            image_paths.append(row[0])
            label_paths.append(row[1])
            if len(row) > 2:
                pred_paths.append(row[2])
    return image_paths, label_paths, pred_paths


def accuracy(args):
    """Evaluate a binary segmentation model and print IoU and F-score.

    Images and labels are read with GDAL and divided by 255. Both the label
    and the model output are rounded to 0/1 before per-pixel counts are
    accumulated over every listed image.

    Args:
        args: parsed options. Reads ``onehot``, ``model``, ``weights`` and
            ``csv_paths`` (a CSV of image path, label path and an optional
            third path whose raster is stacked onto the image as an extra
            channel).

    Raises:
        ValueError: if no model can be built for ``args.model``/``args.onehot``.
    """
    model = _build_binary_eval_model(args)
    test_image_paths, test_label_paths, test_pred_paths = _read_eval_paths(
        args.csv_paths)
    print(len(test_image_paths), len(test_label_paths), len(test_pred_paths))

    fp, fn, tp = 0, 0, 0
    for i, (image_path, label_path) in enumerate(
            zip(test_image_paths, test_label_paths)):
        image_array = np.array(gdal.Open(image_path).ReadAsArray()) / 255
        # NOTE(review): assumes the raster is read as (bands, rows, cols).
        image_array = image_array.transpose(1, 2, 0)
        label_array = np.array(gdal.Open(label_path).ReadAsArray()) / 255
        if test_pred_paths:
            pred_array = np.array(gdal.Open(test_pred_paths[i]).ReadAsArray())
            pred_array = np.expand_dims(pred_array, axis=-1)
            image_array = np.concatenate((image_array, pred_array), axis=2)
        result_array = np.squeeze(
            model.predict(np.expand_dims(image_array, axis=0)))

        y_true = np.around(label_array.flatten())
        y_pred = np.around(result_array.flatten())
        # Fixing labels=[0, 1] always yields a 2x2 matrix; otherwise a tile
        # containing only foreground would be mis-read as all true negatives.
        (_, img_fp), (img_fn, img_tp) = confusion_matrix(
            y_true, y_pred, labels=[0, 1])
        fp += img_fp
        fn += img_fn
        tp += img_tp
        print("Predicted " + str(i + 1) + " Images")

    # Both denominators are zero only when no positives exist anywhere;
    # report 0.0 instead of dividing by zero.
    iou_denominator = tp + fp + fn
    f_denominator = 2 * tp + fp + fn
    iou = tp / iou_denominator if iou_denominator else 0.0
    f_score = (2 * tp) / f_denominator if f_denominator else 0.0

    print("IOU Score: " + str(iou))
    print("F-Score: " + str(f_score))
Exemplo n.º 5
0
def accuracy(args):
    """Evaluate a multi-class segmentation model and print per-class metrics.

    Each image is read with GDAL, divided by 255 and passed to the model.
    The prediction is the argmax over the last axis of the output. Pixels
    from all images are pooled, and the function prints the confusion matrix,
    overall accuracy, per-class IoU/F1/precision/recall and their means.

    Args:
        args: parsed options. Reads ``model`` (``unet``, ``resunet`` or
            ``segnet``), ``weights`` and ``csv_paths`` (a CSV whose first two
            columns are image and label paths).

    Raises:
        ValueError: if ``args.model`` is unsupported or the CSV lists no images.
    """
    if args.model == "unet":
        model = unet_model.UNet(args)
    elif args.model == "resunet":
        model = resunet_model.build_res_unet(args)
    elif args.model == "segnet":
        model = segnet_model.create_segnet(args)
    else:
        raise ValueError(
            "The model name should be from the unet, resunet or segnet")
    model.load_weights(args.weights)

    test_image_paths = []
    test_label_paths = []
    with open(args.csv_paths, 'r', newline='') as csvfile:
        for row in csv.reader(csvfile, delimiter=','):
            if not row:  # csv.reader yields [] for an empty line
                continue
            test_image_paths.append(row[0])
            test_label_paths.append(row[1])

    y_true_parts = []
    y_pred_parts = []
    for i, (image_path, label_path) in enumerate(
            zip(test_image_paths, test_label_paths), start=1):
        image_array = np.array(gdal.Open(image_path).ReadAsArray()) / 255
        # NOTE(review): assumes the raster is read as (bands, rows, cols).
        image_array = image_array.transpose(1, 2, 0)
        label_array = np.array(gdal.Open(label_path).ReadAsArray())
        prediction = model.predict(np.expand_dims(image_array, axis=0))
        result_array = np.argmax(prediction[0], axis=2)
        # Flatten per image so tiles of different sizes can still be pooled.
        y_true_parts.append(np.around(label_array).ravel())
        y_pred_parts.append(result_array.ravel())
        print("Predicted " + str(i) + " Images")

    if not y_true_parts:
        raise ValueError("No image/label pairs found in " + str(args.csv_paths))
    y_true = np.concatenate(y_true_parts)
    y_pred = np.concatenate(y_pred_parts)

    print("\n")
    cm = confusion_matrix(y_true, y_pred)
    cm_multi = multilabel_confusion_matrix(y_true, y_pred)
    print("Confusion Matrix " + "\n")
    print(cm, "\n")
    overall_accuracy = np.trace(cm / np.sum(cm))
    print("Overall Accuracy: ", round(overall_accuracy, 3), "\n")

    def ratio(numerator, denominator):
        # A class absent from both truth and prediction has zero denominators.
        return numerator / denominator if denominator else 0.0

    mean_iou = 0
    mean_f1 = 0
    # Each per-class matrix is [[tn, fp], [fn, tp]].
    for j, ((_, fp), (fn, tp)) in enumerate(cm_multi):
        print("Class: " + str(j))
        iou = ratio(tp, tp + fp + fn)
        f1 = ratio(2 * tp, 2 * tp + fp + fn)
        precision = ratio(tp, tp + fp)
        recall = ratio(tp, tp + fn)
        mean_iou += iou
        mean_f1 += f1
        print("IoU Score: ", round(iou, 3))
        print("F1-Measure: ", round(f1, 3))
        print("Precision: ", round(precision, 3))
        print("Recall: ", round(recall, 3), "\n")
    print("Mean IoU Score: ", round(mean_iou / len(cm_multi), 3))
    print("Mean F1-Measure: ", round(mean_f1 / len(cm_multi), 3))