import os
from argparse import Namespace

# Project-internal helpers (load_model, create_model, model_type, dataset_type,
# path_from_model_name, write_dict_to_csv, the test_model_* and plot_* functions,
# and ModelType) are imported from elsewhere in the package.


def test_wrapper(args: Namespace) -> None:
    """Function for testing a trained network"""
    model_name = args.model
    model = load_model(model_name)

    if model_type(model) != dataset_type(args.dataset):
        print("ERROR: This dataset is not compatible with your model")
        return

    if dataset_type(args.dataset) == ModelType.MASKED:
        predictions, test_iter = test_model_masked(model, args.dataset)
        plot_masked_predictions(predictions, test_iter, args.dataset)
    else:
        details, confusion_matrix = test_model_binary(model, args.dataset)

        # Write the test details to a CSV alongside the saved model
        model_dir = os.path.dirname(path_from_model_name(model_name))
        with open(os.path.join(model_dir, 'results.csv'), 'w') as f:
            write_dict_to_csv(details, f)

        plot_confusion_chart(confusion_matrix)
        plot_predictions(details['Percent'], args.dataset)
def train_wrapper(args: Namespace) -> None:
    """Function for training a network"""
    data_type = dataset_type(args.dataset)
    model_name = args.model

    if args.cont:
        # Continue training an existing model, keeping its stored history
        model = load_model(model_name)
        history = model.__asf_model_history
    else:
        model_path = path_from_model_name(model_name)
        if not args.overwrite and os.path.isfile(model_path):
            print(f"File {model_name} already exists!")
            return

        model = create_model(model_name, data_type)
        history = {"loss": [], "acc": [], "val_loss": [], "val_acc": []}

    if model_type(model) != data_type:
        print("ERROR: This dataset is not compatible with your model")
        return

    train_model(model, history, args.dataset, args.epochs)
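# A minimal sketch (an assumption, not shown in the source) of how these wrappers
# might be dispatched from an argparse-based CLI. The sub-command names and flag
# spellings below are hypothetical; only the Namespace attributes the wrappers
# actually read (model, dataset, epochs, cont, overwrite) come from the code above.
from argparse import ArgumentParser


def main() -> None:
    parser = ArgumentParser(description="Train or test a model against a dataset")
    subparsers = parser.add_subparsers(dest='command', required=True)

    train = subparsers.add_parser('train', help="Train a new or existing model")
    train.add_argument('model')
    train.add_argument('dataset')
    train.add_argument('epochs', type=int)
    train.add_argument('--overwrite', action='store_true')
    train.add_argument('--cont', action='store_true')
    train.set_defaults(func=train_wrapper)

    test = subparsers.add_parser('test', help="Evaluate a trained model")
    test.add_argument('model')
    test.add_argument('dataset')
    test.set_defaults(func=test_wrapper)

    args = parser.parse_args()
    args.func(args)


if __name__ == '__main__':
    main()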
def test_model_type(fake_model: Model, fake_model_masked: Model, fake_model_other: Model):
    assert model_type(fake_model_masked) == ModelType.MASKED
    assert model_type(fake_model_other) is None
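# Hedged sketch (an assumption, not from the source): the fake_model* arguments above
# are presumably pytest fixtures defined in a conftest.py, and Model is presumably
# tf.keras.models.Model. How model_type() distinguishes a MASKED model from anything
# else is not shown in this excerpt, so the bodies below only illustrate the fixture
# mechanics with minimal stand-in Keras models; real fixtures would need whatever
# property model_type() actually inspects.
import pytest
from tensorflow.keras import layers, models


@pytest.fixture
def fake_model_masked() -> models.Model:
    # Stand-in model; replace with one that model_type() classifies as MASKED.
    inputs = layers.Input(shape=(512, 512, 2))
    outputs = layers.Conv2D(1, 1, activation='sigmoid')(inputs)
    return models.Model(inputs=inputs, outputs=outputs)


@pytest.fixture
def fake_model_other() -> models.Model:
    # Stand-in model; intended to be one that model_type() does not recognize.
    inputs = layers.Input(shape=(32,))
    outputs = layers.Dense(1)(inputs)
    return models.Model(inputs=inputs, outputs=outputs)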