def test_path_from_model_name():
    """Check that model-name strings resolve to the expected ``.h5`` paths.

    A bare network name and an explicit ``:latest`` tag both map to
    ``latest.h5``; any other tag maps to ``<tag>.h5``.  A name containing a
    newline must be rejected with ``ValueError``.
    """
    cases = (
        ("example_net", "latest.h5"),
        ("example_net:latest", "latest.h5"),
        ("example_net:epoch1", "epoch1.h5"),
    )
    for requested, filename in cases:
        resolved = path_from_model_name(requested)
        assert resolved == os.path.join(
            PROJECT_DIR, "models", "example_net", filename)

    # Names with embedded newlines are invalid.
    with pytest.raises(ValueError):
        path_from_model_name("A\nB")
def train_wrapper(args: Namespace) -> None:
    """Train a network, either resuming a saved model or starting a new one.

    If ``args.cont`` is truthy, the model named ``args.model`` is loaded and
    its stored history reused. It is then recompiled with a Jaccard-distance
    loss, an Adam optimizer on an exponentially decaying learning-rate
    schedule, and accuracy / MeanIoU(2 classes) metrics. Otherwise a new
    masked CDL model is created, unless a saved file already exists and
    ``args.overwrite`` is falsy, in which case nothing is trained.
    """
    name = args.model

    if not args.cont:
        target_path = path_from_model_name(name)
        if not args.overwrite and os.path.isfile(target_path):
            print(f"File {name} already exists!")
            return
        model = create_cdl_model_masked(name)
        history = {key: [] for key in ('loss', 'accuracy', "mean_io_u")}
    else:
        model = load_model(name)
        # History attached to the loaded model (presumably by load_model).
        history = model.__asf_model_history
        # Snapshot the weights so they can be restored after recompiling.
        saved_weights = model.get_weights()
        schedule = ExponentialDecay(
            9.2e-4, decay_steps=2000, decay_rate=0.96, staircase=True)
        optimizer = Adam(learning_rate=schedule)
        model.compile(
            loss=jaccard_distance_loss,
            optimizer=optimizer,
            metrics=['accuracy', MeanIoU(num_classes=2)],
        )
        model.set_weights(saved_weights)

    train_model(model, history, args.dataset, args.epochs)
def train_wrapper(args: Namespace) -> None:
    """Train a network, resuming from a saved model when requested.

    With ``args.cont`` set, the model named ``args.model`` is loaded together
    with its stored history. Otherwise a fresh masked model is created with an
    empty loss/accuracy history. If the model file already exists and
    ``args.overwrite`` is not set, the function prints a message and returns.
    """
    name = args.model

    if not args.cont:
        target_path = path_from_model_name(name)
        if not args.overwrite and os.path.isfile(target_path):
            print(f"File {name} already exists!")
            return
        model = create_model_masked(name)
        history = {key: [] for key in ("loss", "acc", "val_loss", "val_acc")}
    else:
        model = load_model(name)
        # History attached to the loaded model (presumably by load_model).
        history = model.__asf_model_history

    train_model(model, history, args.dataset, args.epochs)
def test_wrapper(args: Namespace) -> None:
    """Evaluate a saved model on a dataset and present the results.

    Aborts with an error message if the model's type does not match the
    dataset's type. For masked datasets the predictions are plotted. For
    everything else, the per-sample details are written to ``results.csv``
    next to the model file, and the confusion matrix and prediction
    percentages are plotted.
    """
    name = args.model
    dataset = args.dataset
    model = load_model(name)

    if model_type(model) != dataset_type(dataset):
        print("ERROR: This dataset is not compatible with your model")
        return

    if dataset_type(dataset) == ModelType.MASKED:
        predictions, test_iter = test_model_masked(model, dataset)
        plot_masked_predictions(predictions, test_iter, dataset)
        return

    details, confusion_matrix = test_model_binary(model, dataset)
    model_dir = os.path.dirname(path_from_model_name(name))
    results_path = os.path.join(model_dir, 'results.csv')
    with open(results_path, 'w') as results_file:
        write_dict_to_csv(details, results_file)
    plot_confusion_chart(confusion_matrix)
    plot_predictions(details['Percent'], dataset)
def train_wrapper(args: Namespace) -> None:
    """Train a network whose type matches the dataset's type.

    The dataset type is determined first. With ``args.cont`` set, the model
    named ``args.model`` is loaded with its stored history. Otherwise a new
    model of the dataset's type is created, unless a file for it already
    exists and ``args.overwrite`` is not set. Training is skipped with an
    error message if the model type and the dataset type disagree.
    """
    data_type = dataset_type(args.dataset)
    name = args.model

    if not args.cont:
        target_path = path_from_model_name(name)
        if not args.overwrite and os.path.isfile(target_path):
            print(f"File {name} already exists!")
            return
        model = create_model(name, data_type)
        history = {key: [] for key in ("loss", "acc", "val_loss", "val_acc")}
    else:
        model = load_model(name)
        # History attached to the loaded model (presumably by load_model).
        history = model.__asf_model_history

    if model_type(model) != data_type:
        print("ERROR: This dataset is not compatible with your model")
        return

    train_model(model, history, args.dataset, args.epochs)
def test_fuzz_path_from_model_name(network_name, tag):
    """``<network>:<tag>`` resolves to ``PROJECT_DIR/models/<network>/<tag>.h5``."""
    resolved = path_from_model_name(f"{network_name}:{tag}")
    expected = os.path.join(PROJECT_DIR, "models", network_name, f"{tag}.h5")
    assert resolved == expected
def test_fuzz_path_from_model_name(network_name, tag):
    """``<network>:<tag>`` resolves to ``MODEL_DIR/<network>/<tag>.h5``."""
    resolved = path_from_model_name(f"{network_name}:{tag}")
    expected = os.path.join(MODEL_DIR, network_name, f"{tag}.h5")
    assert resolved == expected