Ejemplo n.º 1
0
def test_path_from_model_name():
    """Check how model names are mapped onto ``.h5`` paths.

    A bare network name and the explicit ``:latest`` tag both resolve to
    ``<PROJECT_DIR>/models/<network>/latest.h5``; any other tag becomes the
    file stem. A name containing a newline must raise ``ValueError``.
    """
    models_root = os.path.join(PROJECT_DIR, "models")
    cases = (
        ("example_net", os.path.join(models_root, "example_net", "latest.h5")),
        ("example_net:latest", os.path.join(models_root, "example_net", "latest.h5")),
        ("example_net:epoch1", os.path.join(models_root, "example_net", "epoch1.h5")),
    )
    for name, expected in cases:
        assert path_from_model_name(name) == expected

    # Invalid names are rejected outright rather than producing a path.
    with pytest.raises(ValueError):
        path_from_model_name("A\nB")
Ejemplo n.º 2
0
def train_wrapper(args: Namespace) -> None:
    """Train a network, either freshly created or resumed from disk.

    Fresh run (``args.cont`` false): returns without training when a file for
    ``args.model`` already exists and ``args.overwrite`` is false; otherwise
    builds a model with ``create_cdl_model_masked`` and starts an empty
    history.

    Resumed run: loads the saved model and its stored history, then
    recompiles it with a Jaccard-distance loss, an Adam optimizer on an
    exponentially decaying learning rate, and accuracy / MeanIoU metrics,
    restoring the weights captured just before recompiling.
    """
    name = args.model

    if not args.cont:
        target = path_from_model_name(name)
        if not args.overwrite and os.path.isfile(target):
            print(f"File {name} already exists!")
            return
        model = create_cdl_model_masked(name)
        history = {'loss': [], 'accuracy': [], "mean_io_u": []}
    else:
        model = load_model(name)
        history = model.__asf_model_history
        saved_weights = model.get_weights()
        schedule = ExponentialDecay(
            9.2e-4,
            decay_steps=2000,
            decay_rate=0.96,
            staircase=True,
        )
        model.compile(
            loss=jaccard_distance_loss,
            optimizer=Adam(learning_rate=schedule),
            metrics=['accuracy', MeanIoU(num_classes=2)],
        )
        # Put back the weights snapshotted before the recompile.
        model.set_weights(saved_weights)

    train_model(model, history, args.dataset, args.epochs)
Ejemplo n.º 3
0
def train_wrapper(args: Namespace) -> None:
    """Train a masked network, resuming a saved one when ``args.cont`` is set.

    For a fresh run the function returns without training if a file for
    ``args.model`` already exists and ``args.overwrite`` is false; otherwise
    it builds the model with ``create_model_masked`` and starts an empty
    history.
    """
    name = args.model

    if not args.cont:
        target = path_from_model_name(name)
        if not args.overwrite and os.path.isfile(target):
            print(f"File {name} already exists!")
            return
        model = create_model_masked(name)
        history = {"loss": [], "acc": [], "val_loss": [], "val_acc": []}
    else:
        model = load_model(name)
        history = model.__asf_model_history

    train_model(model, history, args.dataset, args.epochs)
Ejemplo n.º 4
0
def test_wrapper(args: Namespace) -> None:
    """Evaluate the saved model ``args.model`` on ``args.dataset``.

    Prints an error and returns if the model's type does not match the
    dataset's type. For masked datasets the predictions are passed to
    ``plot_masked_predictions``. Otherwise the binary test details are
    written to ``results.csv`` in the model's directory, then passed to
    ``plot_confusion_chart`` and ``plot_predictions``.
    """
    model_name = args.model
    model = load_model(model_name)

    # Resolve the dataset type once; both the compatibility check and the
    # masked/binary dispatch below need it.
    data_type = dataset_type(args.dataset)
    if model_type(model) != data_type:
        print("ERROR: This dataset is not compatible with your model")
        return

    if data_type == ModelType.MASKED:
        predictions, test_iter = test_model_masked(model, args.dataset)
        plot_masked_predictions(predictions, test_iter, args.dataset)
        return

    details, confusion_matrix = test_model_binary(model, args.dataset)

    model_dir = os.path.dirname(path_from_model_name(model_name))
    # newline='' is what the csv module requires for file objects it writes
    # to (otherwise an extra '\r' appears on \r\n platforms). Presumably
    # write_dict_to_csv uses csv.writer -- confirm.
    with open(os.path.join(model_dir, 'results.csv'), 'w', newline='') as f:
        write_dict_to_csv(details, f)

    plot_confusion_chart(confusion_matrix)
    plot_predictions(details['Percent'], args.dataset)
Ejemplo n.º 5
0
def train_wrapper(args: Namespace) -> None:
    """Create or resume a network and train it on ``args.dataset``.

    The dataset's type decides which model ``create_model`` builds for a
    fresh run. A fresh run returns early (with a message) when a model file
    already exists and ``args.overwrite`` is false. Training is skipped,
    with an error message, if the model's type differs from the dataset's.
    """
    kind = dataset_type(args.dataset)
    name = args.model

    if not args.cont:
        target = path_from_model_name(name)
        if not args.overwrite and os.path.isfile(target):
            print(f"File {name} already exists!")
            return
        model = create_model(name, kind)
        history = {"loss": [], "acc": [], "val_loss": [], "val_acc": []}
    else:
        model = load_model(name)
        history = model.__asf_model_history

    # A resumed model may have been trained for a different dataset type.
    if model_type(model) != kind:
        print("ERROR: This dataset is not compatible with your model")
        return

    train_model(model, history, args.dataset, args.epochs)
Ejemplo n.º 6
0
def test_fuzz_path_from_model_name(network_name, tag):
    """``<network>:<tag>`` resolves to ``<PROJECT_DIR>/models/<network>/<tag>.h5``."""
    qualified_name = f"{network_name}:{tag}"
    expected = os.path.join(PROJECT_DIR, "models", network_name, f"{tag}.h5")
    assert path_from_model_name(qualified_name) == expected
Ejemplo n.º 7
0
def test_fuzz_path_from_model_name(network_name, tag):
    """``<network>:<tag>`` resolves to ``<MODEL_DIR>/<network>/<tag>.h5``."""
    qualified_name = f"{network_name}:{tag}"
    expected = os.path.join(MODEL_DIR, network_name, f"{tag}.h5")
    assert path_from_model_name(qualified_name) == expected