def test_cfg_to_algorithm_pass_2(write_yaml):
    """ ``yatsm.classifiers.cfg_to_algorithm`` should not raise for a
    RandomForest config whose ``init`` and ``fit`` sections are empty.

    NOTE(review): another function in this file reuses this exact name; the
    later definition replaces this one at import time, so pytest collects
    only one of them.
    """
    rf_section = dict(init={}, fit={})
    config = dict(algorithm='RandomForest', RandomForest=rf_section)
    classifiers.cfg_to_algorithm(write_yaml(config))
def test_cfg_to_algorithm_fail_2(write_yaml):
    """ An unknown keyword in the ``init`` section must surface as a
    ``TypeError`` from ``cfg_to_algorithm``.

    NOTE(review): another function in this file reuses this exact name; the
    later definition replaces this one at import time, so pytest collects
    only one of them.
    """
    bogus_init = dict(not_a_param=42)
    config = dict(
        algorithm='RandomForest',
        RandomForest=dict(init=bogus_init, fit={}),
    )
    with pytest.raises(TypeError):
        classifiers.cfg_to_algorithm(write_yaml(config))
def test_cfg_to_algorithm_pass_2(write_yaml):
    """ ``yatsm.classifiers.cfg_to_algorithm`` should accept a RandomForest
    config with empty ``init`` and ``fit`` sections without raising.

    NOTE(review): this name is also used by an earlier function in this
    file; this later definition is the one that survives import.
    """
    empty_sections = {'init': {}, 'fit': {}}
    config = {'algorithm': 'RandomForest'}
    config['RandomForest'] = empty_sections
    yaml_path = write_yaml(config)
    classifiers.cfg_to_algorithm(yaml_path)
def test_cfg_to_algorithm_fail_2(write_yaml):
    """ Passing a non-existent estimator parameter through ``init`` should
    make ``cfg_to_algorithm`` raise ``TypeError``.

    NOTE(review): this name is also used by an earlier function in this
    file; this later definition is the one that survives import.
    """
    config = {'algorithm': 'RandomForest'}
    config['RandomForest'] = {
        'init': {'not_a_param': 42},
        'fit': {},
    }
    with pytest.raises(TypeError):
        yaml_path = write_yaml(config)
        classifiers.cfg_to_algorithm(yaml_path)
def train(ctx, config, classifier_config, model, n_fold, seed, plot,
          diagnostics, overwrite):
    """ Train a classifier from `scikit-learn` on YATSM output and save result to file <model>. Dataset configuration is specified by <yatsm_config> and classifier and classifier parameters are specified by <classifier_config>. """
    # NOTE: the docstring above is kept verbatim because it may double as CLI
    # help text. Overview of what this function does:
    #   1. Normalize ``model`` to end in ".pkl"; abort if it already exists
    #      and ``overwrite`` is not set.
    #   2. Seed NumPy's global RNG when ``seed`` is truthy.
    #   3. Parse the dataset config and classifier config; abort if the
    #      training image is unset or not a file.
    #   4. Obtain X, y, row, col, labels either from the optional ``.npz``
    #      cache (``classification.cache_training``) or via
    #      ``get_training_inputs``. The cache is (re)written when it is
    #      configured and was missing or older than the training image.
    #   5. Fit the classifier, dump it with joblib (compress=3), and
    #      optionally run ``algo_diagnostics``.
    # Raises click.Abort on the two abort conditions above.

    # Setup
    if not model.endswith(".pkl"):
        model += ".pkl"
    if os.path.isfile(model) and not overwrite:
        logger.error("<model> exists and --overwrite was not specified")
        raise click.Abort()

    if seed:
        np.random.seed(seed)

    # Parse config & algorithm config
    cfg = parse_config_file(config)
    algo, algo_cfg = classifiers.cfg_to_algorithm(classifier_config)

    training_image = cfg["classification"]["training_image"]
    if not training_image or not os.path.isfile(training_image):
        logger.error("Training data image %s does not exist", training_image)
        raise click.Abort()

    # Find information from results -- e.g., design info
    attrs = find_result_attributes(cfg)
    cfg["YATSM"].update(attrs)

    # Cache file for training data
    has_cache = False
    training_cache = cfg["classification"]["cache_training"]
    if training_cache:
        if not os.path.isfile(training_cache):
            logger.info("Could not retrieve cache file for Xy")
            logger.info(" file: %s", training_cache)
        else:
            logger.info("Restoring X/y from cache file")
            has_cache = True

    # Only an existing cache can be stale. Skipping the check otherwise means
    # ``is_cache_old`` is never handed an unset or missing cache path.
    regenerate_cache = has_cache and is_cache_old(training_cache,
                                                  training_image)
    if regenerate_cache:
        logger.warning("Existing cache file older than training data ROI")
        logger.warning("Regenerating cache file")

    if not has_cache or regenerate_cache:
        logger.debug("Reading in X/y")
        X, y, row, col, labels = get_training_inputs(cfg)
        logger.debug("Done reading in X/y")
    else:
        logger.debug("Reading in X/y from cache file %s", training_cache)
        with np.load(training_cache) as f:
            X = f["X"]
            y = f["y"]
            row = f["row"]
            col = f["col"]
            labels = f["labels"]
        logger.debug("Read in X/y from cache file %s", training_cache)

    # Write the cache when it is configured and was either missing or stale.
    # Without the ``regenerate_cache`` term a stale cache is never refreshed,
    # and X/y are re-extracted on every subsequent run.
    if training_cache and (not has_cache or regenerate_cache):
        logger.info("Saving X/y to cache file %s", training_cache)
        try:
            np.savez(training_cache,
                     X=X, y=y, row=row, col=col, labels=labels)
        except Exception:
            logger.error("Could not save X/y to cache file")
            raise

    # Do modeling
    logger.info("Training classifier")
    algo.fit(X, y, **algo_cfg.get("fit", {}))

    # Serialize algorithm to file
    logger.info("Pickling classifier with sklearn.externals.joblib")
    joblib.dump(algo, model, compress=3)

    # Diagnostics
    if diagnostics:
        algo_diagnostics(cfg, X, y, row, col, algo, n_fold, plot)
def test_cfg_to_algorithm_fail_3(tmpdir):
    """ ``cfg_to_algorithm`` should raise ``IOError`` when given a path
    (inside a freshly created directory) to a file that was never written,
    i.e. no readable YAML config exists there.
    """
    clf_dir = tmpdir.mkdir('clf')
    unwritten_path = clf_dir.join('test').strpath
    with pytest.raises(IOError):
        classifiers.cfg_to_algorithm(unwritten_path)
def test_cfg_to_algorithm_fail_1(write_yaml):
    """ An ``algorithm`` name that does not match any known classifier should
    raise ``AlgorithmNotFoundException``.
    """
    unknown_algo_cfg = dict(algorithm='hopefully_not_an_algo')
    with pytest.raises(AlgorithmNotFoundException):
        classifiers.cfg_to_algorithm(write_yaml(unknown_algo_cfg))