Example #1
def train_model(X_train, X_test, y_train, y_test, save_dir):
    """
    Run train model process.

    Args:
        X_train (list): List of strings X_train files patches.
        X_test (list): List of strings X_test files patches.
        y_train (list): List of strings y_train files patches.
        y_test (list): List of strings y_test files patches.
        save_dir (pathlib.PosixPath): Path for save results.

    """
    len_train, len_test = len(X_train), len(X_test)

    # Create datasets from data.
    train_dataset = tf.data.Dataset.from_tensor_slices(
        (X_train, y_train)).map(load_data, num_parallel_calls=AUTOTUNE)

    test_dataset = tf.data.Dataset.from_tensor_slices(
        (X_test, y_test)).map(load_data, num_parallel_calls=AUTOTUNE)

    # Prepare data for training.
    train_dataset = train_dataset.map(normalize, num_parallel_calls=AUTOTUNE)
    train_dataset = train_dataset.shuffle(len_train * 2)
    train_dataset = train_dataset.repeat()
    train_dataset = train_dataset.batch(BATCH_SIZE)
    train_dataset = train_dataset.prefetch(buffer_size=AUTOTUNE)

    test_dataset = test_dataset.map(normalize, num_parallel_calls=AUTOTUNE)
    test_dataset = test_dataset.batch(BATCH_SIZE)

    # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    # TRAIN MODEL.

    # Init paths.
    save_dir.mkdir(parents=True, exist_ok=True)

    path_weights = (save_dir / "weights.hdf5").as_posix()
    path_history = (save_dir / "training.csv").as_posix()
    path_model = (save_dir / "model").as_posix()

    path_predict_fulls = (save_dir / "predict_fulls.csv").as_posix()
    path_predict_means = (save_dir / "predict_means.csv").as_posix()

    # Init callbacks.
    callback_save_best_weights = tf.keras.callbacks.ModelCheckpoint(
        filepath=path_weights,
        monitor="val_loss",
        mode="min",
        save_best_only=True,
        verbose=0,
    )

    callback_csv_logger = tf.keras.callbacks.CSVLogger(
        filename=path_history,
        separator=",",
        append=False,
    )

    # Init optimizer.
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

    # Init model.
    model = UNet(
        input_shape=(IMAGE_NET_W, IMAGE_NET_H, CLASSES_COUNT),
        out_channels=CLASSES_COUNT,
        filters=[16, 32, 64, 128, 256],
    )

    model.compile(
        optimizer=optimizer,
        loss=tf.keras.losses.MSE,
        metrics=[tf.keras.metrics.Accuracy()],
    )

    # Dummy predict to build the model so that summary() can be printed.
    model.predict(next(iter(test_dataset))[0])
    model.summary()

    # Load weights.
    if MODEL_LOAD_WEIGHTS:
        model.load_weights(path_weights)

    # Train model.
    if MODEL_RUN_TRAIN:
        model.fit(
            train_dataset,
            validation_data=test_dataset,
            epochs=EPOCHS,
            steps_per_epoch=len_train // BATCH_SIZE,
            validation_steps=len_test // BATCH_SIZE,
            callbacks=[callback_save_best_weights, callback_csv_logger],
        )

    # Evaluate model.
    evaluate = model.evaluate(test_dataset)
    print(tabulate(zip(model.metrics_names, evaluate), tablefmt="fancy_grid"))

    # Make prediction.
    predict_dataset = model.predict(test_dataset)

    # Save model.
    model.save(path_model)

    # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    # COMPARE WITH REAL COORDINATES.

    X_test_names = [pathlib.Path(x).stem for x in X_test]
    result_array = np.zeros((len_test, CLASSES_COUNT))

    for k, predict_array in enumerate(predict_dataset):
        for channel_type, channel_number in zip(CLASSES, range(CLASSES_COUNT)):
            # Get channel.
            predict_channel = predict_array[:, :, channel_number]

            # Get predicted coordinates with max value.
            maxes = np.argwhere(predict_channel.max() == predict_channel)
            x_p, y_p = np.mean(maxes, axis=0).astype(int)

            # Rescale predicted coordinates.
            x_p, y_p = rescale_coords(
                (x_p, y_p),
                src_size=(IMAGE_NET_W, IMAGE_NET_H),
                out_size=(IMAGE_SRC_W, IMAGE_SRC_H),
            )

            # Load text coordinates.
            filename = PWD_COORDS / channel_type / f"{X_test_names[k]}.txt"
            x_t, y_t = np.fromfile(filename, sep=",")

            # Compute distance.
            # ! NOTE: in this dataset, 1 cm == 11 pix.
            distance = np.sqrt((x_t - x_p)**2 + (y_t - y_p)**2) / 11.0

            # Save result.
            result_array[k][channel_number] = distance

    # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    # DUMP RESULTS.

    # Dump full results to CSV.
    results = pd.DataFrame(columns=["filename", *CLASSES])

    for n, filename in enumerate(X_test_names):

        # Write rows of matrix to CSV.
        named_results = dict(zip(CLASSES, list(result_array[n])))
        named_results["filename"] = filename

        results = results.append(named_results, ignore_index=True)

    results.to_csv(path_predict_fulls, index=False, sep=";")

    # Dump mean results to CSV.
    results = pd.DataFrame(columns=["class", "mean", "max"])

    means = np.mean(result_array, axis=0)
    maxes = np.max(result_array, axis=0)

    for class_label, val_mean, val_max in zip(CLASSES, means, maxes):
        results = results.append(
            {
                "class": class_label,
                "mean": val_mean,
                "max": val_max
            },
            ignore_index=True,
        )

    results.to_csv(path_predict_means, index=False, sep=";")
    print(results)
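
This example leans on module-level helpers that the excerpt does not show, notably load_data, normalize, and rescale_coords, plus size constants such as IMAGE_NET_W/IMAGE_NET_H and IMAGE_SRC_W/IMAGE_SRC_H. The following is only a minimal sketch of what those helpers might look like, assuming single-channel PNG images and heatmap-style targets; the file format, channel count, and resolutions are assumptions, not the original implementation.

import tensorflow as tf

IMAGE_NET_W, IMAGE_NET_H = 256, 256      # assumed network resolution
IMAGE_SRC_W, IMAGE_SRC_H = 1024, 1024    # assumed source resolution
AUTOTUNE = tf.data.AUTOTUNE


def load_data(image_path, heatmap_path):
    # Read one image/heatmap pair from disk and resize it to the network size.
    image = tf.io.decode_png(tf.io.read_file(image_path), channels=1)
    heatmap = tf.io.decode_png(tf.io.read_file(heatmap_path), channels=1)
    image = tf.image.resize(image, (IMAGE_NET_H, IMAGE_NET_W))
    heatmap = tf.image.resize(heatmap, (IMAGE_NET_H, IMAGE_NET_W))
    return image, heatmap


def normalize(image, heatmap):
    # Scale pixel values to [0, 1] so the MSE loss stays well behaved.
    return tf.cast(image, tf.float32) / 255.0, tf.cast(heatmap, tf.float32) / 255.0


def rescale_coords(coords, src_size, out_size):
    # Map an (x, y) point from the network resolution back to the source image.
    x, y = coords
    return (
        int(round(x * out_size[0] / src_size[0])),
        int(round(y * out_size[1] / src_size[1])),
    )

With helpers along these lines, the tf.data pipeline above maps each pair of file paths to a resized, normalized (image, heatmap) tensor pair before batching.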
Example #2
os.makedirs(all_configs['checkpoints'], exist_ok=True)
all_configs['ckpt_dir'] = ckpt_dir

processWandb(all_configs)

def bakeGenerator(annot):
	# Each annotation line holds an image path and its label/mask path.
	with open(annot, 'r') as f:
		lines = f.read().strip().split('\n')
	x_set = [line.split()[0] for line in lines]
	y_set = [line.split()[1] for line in lines]
	# x_set = glob.glob(os.path.join(root, 'raw', '*.jpg'))
	# y_set = glob.glob(os.path.join(root, 'mask', '*.jpg'))
	gen = DataGenerator(x_set, y_set, batch_size=batch_size, shuffle=True)
	return gen

train_gen = bakeGenerator(train_annot)
val_gen = bakeGenerator(val_annot)

model = UNet()

hist = model.fit(
	train_gen,
	validation_data=val_gen,
	epochs=N_EPOCHS,
	# batch_size is handled inside DataGenerator; Keras rejects it here for Sequence inputs.
	verbose=1,
	use_multiprocessing=True,
	workers=8,
	callbacks=[CustomCallback(ckpt_dir),],
)
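
DataGenerator, CustomCallback, and processWandb are defined elsewhere in this project. As a rough sketch only, a DataGenerator compatible with this fit() call could be a keras.utils.Sequence along the following lines; the image backend, target size, and normalization are assumptions rather than the project's actual code.

import numpy as np
import cv2  # assumed image backend
from tensorflow import keras


class DataGenerator(keras.utils.Sequence):
    # Hypothetical Sequence that yields (image, mask) batches from file paths.

    def __init__(self, x_set, y_set, batch_size=8, shuffle=True, size=(256, 256)):
        self.x_set, self.y_set = list(x_set), list(y_set)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.size = size
        self.indices = np.arange(len(self.x_set))
        self.on_epoch_end()

    def __len__(self):
        # Number of batches per epoch.
        return int(np.ceil(len(self.x_set) / self.batch_size))

    def __getitem__(self, idx):
        batch = self.indices[idx * self.batch_size:(idx + 1) * self.batch_size]
        images, masks = [], []
        for i in batch:
            img = cv2.imread(self.x_set[i], cv2.IMREAD_COLOR)
            msk = cv2.imread(self.y_set[i], cv2.IMREAD_GRAYSCALE)
            images.append(cv2.resize(img, self.size) / 255.0)
            masks.append(cv2.resize(msk, self.size)[..., np.newaxis] / 255.0)
        return np.asarray(images, np.float32), np.asarray(masks, np.float32)

    def on_epoch_end(self):
        # Reshuffle the sample order between epochs when requested.
        if self.shuffle:
            np.random.shuffle(self.indices)

Because fit() consumes the Sequence one batch at a time, use_multiprocessing=True with several workers can prepare batches in parallel, which is also why the generator, not fit(), owns the batch size.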