コード例 #1
0
ファイル: test_classifiers.py プロジェクト: ceholden/yatsm
def test_cfg_to_algorithm_pass_2(write_yaml):
    """ ``yatsm.classifiers.cfg_to_algorithm`` should accept a RandomForest
    config whose ``init`` and ``fit`` sections are both empty
    """
    rf_params = {'init': {}, 'fit': {}}
    config_path = write_yaml({'algorithm': 'RandomForest',
                              'RandomForest': rf_params})
    classifiers.cfg_to_algorithm(config_path)
コード例 #2
0
ファイル: test_classifiers.py プロジェクト: ceholden/yatsm
def test_cfg_to_algorithm_fail_2(write_yaml):
    """ Fail because algorithm parameters don't exist

    An ``init`` section containing a keyword the algorithm does not accept
    must make ``classifiers.cfg_to_algorithm`` raise ``TypeError``.
    """
    cfg = {
        'algorithm': 'RandomForest',
        'RandomForest': {'init': {'not_a_param': 42}, 'fit': {}}
    }
    # Write the YAML outside ``pytest.raises`` so that only the call under
    # test can satisfy the expected TypeError (not the fixture itself).
    config_path = write_yaml(cfg)
    with pytest.raises(TypeError):
        classifiers.cfg_to_algorithm(config_path)
コード例 #3
0
ファイル: test_classifiers.py プロジェクト: sumesh1/yatsm
def test_cfg_to_algorithm_pass_2(write_yaml):
    """ An empty ``init``/``fit`` configuration for RandomForest must be
    accepted by ``yatsm.classifiers.cfg_to_algorithm``
    """
    empty_section = {
        'init': {},
        'fit': {},
    }
    yaml_file = write_yaml({
        'algorithm': 'RandomForest',
        'RandomForest': empty_section,
    })
    classifiers.cfg_to_algorithm(yaml_file)
コード例 #4
0
ファイル: test_classifiers.py プロジェクト: sumesh1/yatsm
def test_cfg_to_algorithm_fail_2(write_yaml):
    """ Fail because algorithm parameters don't exist

    An ``init`` section containing a keyword the algorithm does not accept
    must make ``classifiers.cfg_to_algorithm`` raise ``TypeError``.
    """
    cfg = {
        'algorithm': 'RandomForest',
        'RandomForest': {
            'init': {
                'not_a_param': 42
            },
            'fit': {}
        }
    }
    # Keep the fixture call outside ``pytest.raises`` so a TypeError coming
    # from writing the YAML cannot make this test pass spuriously.
    config_path = write_yaml(cfg)
    with pytest.raises(TypeError):
        classifiers.cfg_to_algorithm(config_path)
コード例 #5
0
ファイル: train.py プロジェクト: johanez/yatsm
def train(ctx, config, classifier_config, model, n_fold, seed, plot, diagnostics, overwrite):
    """
    Train a classifier from `scikit-learn` on YATSM output and save result to
    file <model>. Dataset configuration is specified by <yatsm_config> and
    classifier and classifier parameters are specified by <classifier_config>.
    """
    # NOTE: the docstring above is left as-is because, for a click command,
    # it is presumably the user-facing --help text.
    #
    # Flow: validate output/input paths -> load X/y (from the cache file when
    # it exists and is not older than the training image, otherwise from
    # ``get_training_inputs``) -> (re)write the cache when X/y were generated
    # and a cache path is configured -> fit -> pickle -> optional diagnostics.

    # Setup
    if not model.endswith(".pkl"):
        model += ".pkl"
    if os.path.isfile(model) and not overwrite:
        logger.error("<model> exists and --overwrite was not specified")
        raise click.Abort()

    if seed:
        np.random.seed(seed)

    # Parse config & algorithm config
    cfg = parse_config_file(config)
    algo, algo_cfg = classifiers.cfg_to_algorithm(classifier_config)

    training_image = cfg["classification"]["training_image"]
    if not training_image or not os.path.isfile(training_image):
        logger.error("Training data image %s does not exist", training_image)
        raise click.Abort()

    # Find information from results -- e.g., design info
    attrs = find_result_attributes(cfg)
    cfg["YATSM"].update(attrs)

    # Cache file for training data: it is only usable if it exists and is
    # not older than the training image.
    has_cache = False
    regenerate_cache = False
    training_cache = cfg["classification"]["cache_training"]
    if training_cache:
        if not os.path.isfile(training_cache):
            logger.info("Could not retrieve cache file for Xy")
            logger.info("    file: %s", training_cache)
        else:
            has_cache = True
            # Only an existing cache can be stale; checking otherwise would
            # emit a misleading "older than training data" warning.
            regenerate_cache = is_cache_old(training_cache, training_image)
            if regenerate_cache:
                logger.warning("Existing cache file older than training data ROI")
                logger.warning("Regenerating cache file")
            else:
                logger.info("Restoring X/y from cache file")

    if has_cache and not regenerate_cache:
        logger.debug("Reading in X/y from cache file %s", training_cache)
        with np.load(training_cache) as f:
            X = f["X"]
            y = f["y"]
            row = f["row"]
            col = f["col"]
            labels = f["labels"]
        logger.debug("Read in X/y from cache file %s", training_cache)
    else:
        logger.debug("Reading in X/y")
        X, y, row, col, labels = get_training_inputs(cfg)
        logger.debug("Done reading in X/y")

        # Create the cache for the first time, or refresh a stale one, so the
        # next run can reuse these freshly generated X/y.
        if training_cache:
            logger.info("Saving X/y to cache file %s", training_cache)
            try:
                # Write through a file handle: given a path, ``np.savez``
                # appends ".npz" when missing, and the ``os.path.isfile``
                # check above would then never find the cache again.
                with open(training_cache, "wb") as cache_fh:
                    np.savez(cache_fh, X=X, y=y, row=row, col=col,
                             labels=labels)
            except Exception:
                logger.error("Could not save X/y to cache file")
                raise

    # Do modeling
    logger.info("Training classifier")
    algo.fit(X, y, **algo_cfg.get("fit", {}))

    # Serialize algorithm to file
    logger.info("Pickling classifier with sklearn.externals.joblib")
    joblib.dump(algo, model, compress=3)

    # Diagnostics
    if diagnostics:
        algo_diagnostics(cfg, X, y, row, col, algo, n_fold, plot)
コード例 #6
0
ファイル: test_classifiers.py プロジェクト: sumesh1/yatsm
def test_cfg_to_algorithm_fail_3(tmpdir):
    """ Pointing ``cfg_to_algorithm`` at a path that was never written (so it
    is not a YAML config at all) should raise ``IOError``
    """
    clf_dir = tmpdir.mkdir('clf')
    missing_cfg = clf_dir.join('test').strpath
    with pytest.raises(IOError):
        classifiers.cfg_to_algorithm(missing_cfg)
コード例 #7
0
ファイル: test_classifiers.py プロジェクト: sumesh1/yatsm
def test_cfg_to_algorithm_fail_1(write_yaml):
    """ An ``algorithm`` name that no classifier is registered under should
    raise ``AlgorithmNotFoundException``
    """
    unknown_algo = 'hopefully_not_an_algo'
    with pytest.raises(AlgorithmNotFoundException):
        classifiers.cfg_to_algorithm(write_yaml({'algorithm': unknown_algo}))
コード例 #8
0
ファイル: test_classifiers.py プロジェクト: ceholden/yatsm
def test_cfg_to_algorithm_fail_3(tmpdir):
    """ A path inside a fresh temporary directory that was never written to
    is not a usable YAML config, so ``IOError`` is expected
    """
    scratch = tmpdir.mkdir('clf')
    bogus_path = scratch.join('test').strpath
    with pytest.raises(IOError):
        classifiers.cfg_to_algorithm(bogus_path)
コード例 #9
0
ファイル: test_classifiers.py プロジェクト: ceholden/yatsm
def test_cfg_to_algorithm_fail_1(write_yaml):
    """ Requesting an algorithm that does not exist must raise
    ``AlgorithmNotFoundException``
    """
    bad_config = dict(algorithm='hopefully_not_an_algo')
    with pytest.raises(AlgorithmNotFoundException):
        classifiers.cfg_to_algorithm(write_yaml(bad_config))