Exemplo n.º 1
0
def get_params(test_config):
    """Load the test parameters and save them to the root dir.

    ``test_config`` is applied to a fresh ``Parameters`` object via
    ``Parameters.override`` (presumably a path to a test ini file -- confirm
    against callers).  The result is written to
    ``<ROOT_DIR>/test_parameters.ini``; if that file already exists the user
    is warned and asked (``query_yes_no``) whether to overwrite it.  Logging
    is then configured to write to ``<ROOT_DIR>/test.log``.

    Args:
        test_config: configuration passed to ``Parameters.override``.

    Returns:
        The ``Parameters`` object built from ``test_config``.
    """
    prm = Parameters()

    # get files paths
    prm.override(test_config)
    root_dir = prm.train.train_control.ROOT_DIR
    test_parameter_file = os.path.join(root_dir, 'test_parameters.ini')
    log_file = os.path.join(root_dir, 'test.log')

    should_save = True
    if os.path.isfile(test_parameter_file):
        warnings.warn('Test parameter file {} already exists'.format(
            test_parameter_file))
        should_save = query_yes_no('Overwrite parameter file?')

    if should_save:
        parent_dir = os.path.dirname(test_parameter_file)
        # dirname is '' when ROOT_DIR is empty (file lands in the CWD) and
        # os.makedirs('') raises; exist_ok avoids a check-then-create race.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        prm.save(test_parameter_file)

    # logging_config presumably returns the stdlib logging module (it exposes
    # .disable and .DEBUG) -- bound to a local name that does not shadow it.
    logging_module = logging_config(log_file)
    logging_module.disable(logging_module.DEBUG)

    return prm
Exemplo n.º 2
0
def _confirm_and_save(prm, path, description):
    """Save ``prm`` to ``path``, asking before overwriting an existing file.

    Args:
        prm: object exposing ``save(path)`` (a ``Parameters`` instance here).
        path: destination ini file.
        description: file kind used in the overwrite warning, e.g.
            ``'Test parameter file'``.

    Returns:
        True if the file was written, False if the user declined to
        overwrite an existing file.
    """
    if os.path.isfile(path):
        warnings.warn('{} {} already exists'.format(description, path))
        if not query_yes_no('Overwrite parameter file?'):
            return False

    parent_dir = os.path.dirname(path)
    # dirname is '' for a bare file name and os.makedirs('') raises;
    # exist_ok avoids a race between an existence check and the mkdir.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    prm.save(path)
    return True


def get_params(test_config):
    """Save the test parameters, then merge them with the training parameters.

    Steps:
      1. Apply ``test_config`` to a fresh ``Parameters`` (via ``override``)
         and save it to ``<ROOT_DIR>/test_parameters.ini``.
      2. Configure logging to ``<ROOT_DIR>/test.log``.
      3. Build a new ``Parameters`` from ``<LOG_DIR_LIST[0]>/parameters.ini``
         overridden by the test parameter file, and save it to
         ``<ROOT_DIR>/all_parameters.ini``.
    Each save asks the user before overwriting an existing file.

    Args:
        test_config: configuration passed to ``Parameters.override``.

    Returns:
        The merged ``Parameters`` object.

    Raises:
        AssertionError: if ``parameters.ini`` is missing from the first
            directory of ``prm.test.ensemble.LOG_DIR_LIST``.
    """
    prm = Parameters()

    # get files paths
    prm.override(test_config)  # just to get the LOG_DIR_LIST[0]
    train_log_dir = prm.test.ensemble.LOG_DIR_LIST[0]
    root_dir = prm.train.train_control.ROOT_DIR

    parameter_file = os.path.join(train_log_dir, 'parameters.ini')
    test_parameter_file = os.path.join(root_dir, 'test_parameters.ini')
    all_parameter_file = os.path.join(root_dir, 'all_parameters.ini')
    log_file = os.path.join(root_dir, 'test.log')

    if not os.path.isfile(parameter_file):
        raise AssertionError('Can not find file: {}'.format(parameter_file))

    _confirm_and_save(prm, test_parameter_file, 'Test parameter file')

    # logging_config presumably returns the stdlib logging module (it exposes
    # .disable and .DEBUG) -- bound to a local name that does not shadow it.
    logging_module = logging_config(log_file)
    logging_module.disable(logging_module.DEBUG)

    # Done saving test parameters. Now doing the integration: training
    # parameters first, test parameters layered on top of them.
    # NOTE(review): if the user declined to overwrite test_parameters.ini
    # above, the pre-existing file (not test_config) is merged here.
    prm = Parameters()
    prm.override(parameter_file)
    prm.override(test_parameter_file)

    _confirm_and_save(prm, all_parameter_file, 'All parameter file')

    return prm
Exemplo n.º 3
0
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from lib.active_kmean import KMeansWrapper
from sklearn.datasets import make_blobs


# Script: load a saved training run's parameters and build its model,
# preprocessor and datasets.
# NOTE(review): logging_config, logger, Parameters, Factories and
# DatasetWrapper are not imported in the visible lines; presumably they come
# from project modules imported elsewhere -- confirm.

# logging_config() presumably returns the stdlib `logging` module (it is used
# via .disable and .DEBUG below) -- TODO confirm.
logging = logging_config()

# If `logging` is the stdlib module, this suppresses records at DEBUG level
# and below.
logging.disable(logging.DEBUG)
log = logger.get_logger('main')

# NOTE(review): hard-coded absolute path to one specific training run's
# parameter file; the script only works where this path exists.
prm_file = '/data/gilad/logs/log_2210_220817_wrn-fc2_kmeans_SGD_init_200_clusters_4_cap_204/parameters.ini'

# Start from default parameters and override them with the run's ini file.
prm = Parameters()
prm.override(prm_file)

# Not used in the visible code; presumably consumed further down the script.
dev = prm.network.DEVICE

# Factories builds the model, preprocessor and datasets from `prm`.
factories = Factories(prm)

model = factories.get_model()
model.print_stats()  # debug

preprocessor = factories.get_preprocessor()
preprocessor.print_stats()  # debug

# Both datasets are built with the same preprocessor instance.
train_dataset = factories.get_train_dataset(preprocessor)
validation_dataset = factories.get_validation_dataset(preprocessor)

# Wrap the train/validation datasets under the name '<DATASET_NAME>_wrapper'.
dataset_wrapper = DatasetWrapper(prm.dataset.DATASET_NAME + '_wrapper', prm, train_dataset, validation_dataset)
Exemplo n.º 4
0
def _flag_is_true(value):
    """Interpret a command-line flag value as a boolean.

    The parser historically delivers the string ``'True'``; a real bool
    ``True`` (e.g. from ``action='store_true'``) is accepted as well, since
    ``str(True) == 'True'``.  Anything else is False.
    """
    return str(value) == 'True'


def _apply_parser_overrides(prm, parser_args):
    """Override selected fields of ``prm`` with values from ``parser_args``.

    Sets the root/test/prediction/checkpoint dirs from ``parser_args.ROOT_DIR``
    and copies the KNN / PCA / dump / load-from-disk options, converting the
    numeric ones with ``int`` and the flags with ``_flag_is_true``.
    """
    root_dir = parser_args.ROOT_DIR
    train_control = prm.train.train_control
    test_control = prm.test.test_control

    train_control.ROOT_DIR = root_dir
    train_control.TEST_DIR = os.path.join(root_dir, 'test')
    train_control.PREDICTION_DIR = os.path.join(root_dir, 'prediction')
    train_control.CHECKPOINT_DIR = os.path.join(root_dir, 'checkpoint')
    train_control.PCA_REDUCTION = _flag_is_true(parser_args.PCA_REDUCTION)
    train_control.PCA_EMBEDDING_DIMS = int(parser_args.PCA_EMBEDDING_DIMS)

    test_control.KNN_WEIGHTS = parser_args.KNN_WEIGHTS
    test_control.KNN_NORM = parser_args.KNN_NORM
    test_control.KNN_NEIGHBORS = int(parser_args.KNN_NEIGHBORS)
    test_control.DUMP_NET = _flag_is_true(parser_args.DUMP_NET)
    test_control.LOAD_FROM_DISK = _flag_is_true(parser_args.LOAD_FROM_DISK)


def _save_params(prm, path):
    """Save ``prm`` to ``path``, creating the parent directory if needed."""
    parent_dir = os.path.dirname(path)
    # dirname is '' for a bare file name and os.makedirs('') raises;
    # exist_ok avoids a race between an existence check and the mkdir.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    prm.save(path)


def get_params(test_config, parser_args=None):
    """Save time-stamped test parameters and merge them with the training ones.

    Steps:
      1. Apply ``test_config`` (and, if given, ``parser_args``) to a fresh
         ``Parameters`` to learn ROOT_DIR.
      2. Require ``<ROOT_DIR>/parameters.ini`` to exist.
      3. Configure logging to ``<ROOT_DIR>/test_<ts>.log``.
      4. Save the test parameters to ``<ROOT_DIR>/test_parameters_<ts>.ini``.
      5. Build the merged parameters (``parameters.ini`` overridden by the test
         file, then ``parser_args`` again) and save them to
         ``<ROOT_DIR>/all_parameters_<ts>.ini``.
    ``<ts>`` comes from ``get_timestamp()``.

    Args:
        test_config: configuration passed to ``Parameters.override``.
        parser_args: optional parsed command-line arguments whose fields
            override ROOT_DIR and the KNN / PCA / dump / load options.

    Returns:
        The merged ``Parameters`` object.

    Raises:
        AssertionError: if ``<ROOT_DIR>/parameters.ini`` does not exist.
    """
    # Just to get the ROOT_DIR and save prm test_config
    prm = Parameters()
    prm.override(test_config)
    if parser_args is not None:
        _apply_parser_overrides(prm, parser_args)

    root_dir = prm.train.train_control.ROOT_DIR
    ts = get_timestamp()

    # get files paths
    parameter_file = os.path.join(root_dir, 'parameters.ini')
    test_parameter_file = os.path.join(root_dir,
                                       'test_parameters_' + ts + '.ini')
    all_parameter_file = os.path.join(root_dir,
                                      'all_parameters_' + ts + '.ini')
    log_file = os.path.join(root_dir, 'test_' + ts + '.log')

    # Check before configuring logging, so that a wrong ROOT_DIR yields this
    # clear error instead of a failure (or stray file) from the log handler.
    if not os.path.isfile(parameter_file):
        raise AssertionError('Can not find file: {}'.format(parameter_file))

    # logging_config presumably returns the stdlib logging module (it exposes
    # .disable and .DEBUG) -- bound to a local name that does not shadow it.
    logging_module = logging_config(log_file)
    logging_module.disable(logging_module.DEBUG)

    _save_params(prm, test_parameter_file)

    # Done saving test parameters. Now doing the integration.
    prm = Parameters()
    prm.override(parameter_file)
    prm.override(test_parameter_file)
    # Re-applied so the command-line values win even if the ini round-trip
    # does not preserve them exactly (e.g. types) -- kept from the original.
    if parser_args is not None:
        _apply_parser_overrides(prm, parser_args)

    _save_params(prm, all_parameter_file)

    return prm