Example #1
    def __init__(self, prm):
        self.log = logger.get_logger('factories')
        self.prm = prm

        self.dataset_name = self.prm.dataset.DATASET_NAME
        self.trainer = self.prm.train.train_control.TRAINER
        self.architecture = self.prm.network.ARCHITECTURE
        self.learning_rate_setter = self.prm.train.train_control.learning_rate_setter.LEARNING_RATE_SETTER
        self.tester = self.prm.test.test_control.TESTER
        self.active_selection_criterion = self.prm.train.train_control.ACTIVE_SELECTION_CRITERION

    def __init__(self, name, prm, fixed_centers, *args, **kwargs):
        super(AllCentersKMeans, self).__init__(*args, **kwargs)
        self.name = name
        self.prm = prm
        self.log = logger.get_logger(name)

        self.fixed_centers = fixed_centers
        self.n_fixed = fixed_centers.shape[0]
        self.init = 'random'
        self.n_init = 10
        self.verbose = False
        self.assert_config()

    def __init__(self, name, prm, model, retention):
        self.name = name
        self.prm = prm
        self.model = model
        self.retention = retention
        self.log = logger.get_logger(name)

        self._init_mm = self.prm.network.optimization.DML_MARGIN_MULTIPLIER
        self.decay_refractory_steps = self.prm.train.train_control.margin_multiplier_setter.MM_DECAY_REFRACTORY_STEPS
        self.global_step_of_last_decay = 0

        self._mm = self._init_mm
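
The decay_refractory_steps and global_step_of_last_decay fields above suggest a decay that is rate-limited by a refractory window. The setter's actual update logic is not part of this excerpt; the following is only a hypothetical sketch of how such a gate could work (maybe_decay and decay_factor are illustrative names, not from the source):

    def maybe_decay(self, global_step, decay_factor=0.9):
        # Hypothetical sketch: decay the margin multiplier only if enough
        # steps have passed since the previous decay (the refractory window).
        if global_step - self.global_step_of_last_decay >= self.decay_refractory_steps:
            self._mm *= decay_factor
            self.global_step_of_last_decay = global_step
        return self._mm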
    def __init__(self, name, prm, model, steps_to_save, checkpoint_dir, saver, checkpoint_basename='model_schedule.ckpt'):
        self.name = name
        self.prm = prm
        self.log = logger.get_logger(name)
        self.model = model  # the model might change between runs, so we cannot use a global train step; we must use the model step
        self._saver = saver
        self._checkpoint_dir = checkpoint_dir
        self._save_path = os.path.join(checkpoint_dir, checkpoint_basename)

        if steps_to_save is None:
            steps_to_save = []
        self._steps_to_save = steps_to_save
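
The hook above stores a saver, a checkpoint path, and an explicit list of model steps at which to save. Its run-time behavior is not shown in this excerpt; assuming _saver is a TensorFlow 1.x tf.train.Saver, a minimal sketch of saving at those steps could look like this (maybe_save is a hypothetical name):

    def maybe_save(self, sess, model_step):
        # Hypothetical sketch: checkpoint only at the explicitly requested
        # model steps, using the saver and path stored in __init__.
        if model_step in self._steps_to_save:
            self._saver.save(sess, self._save_path, global_step=model_step)
            self.log.info('saved checkpoint at model step {}'.format(model_step))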
Example #5
    def __init__(self, name, prm, fixed_centers, random_state, *args,
                 **kwargs):
        super(KMeansWrapper, self).__init__(*args, **kwargs)
        self.name = name
        self.prm = prm
        self.log = logger.get_logger(name)

        self.fixed_centers = fixed_centers
        self.n_fixed = fixed_centers.shape[0]
        self.random_state = random_state
        self.init = 'random'
        self.n_init = 1
        self.verbose = True
        self.assert_config()
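
If the wrapper's parent class is sklearn.cluster.KMeans, which the init, n_init and verbose attributes suggest, then constructing it could look like the sketch below. This is only an assumption: the keyword arguments, cluster count and centers are illustrative, and prm is taken to be the Parameters object loaded in the driver script of Example #8.

import numpy as np
from lib.active_kmean import KMeansWrapper

# Hypothetical usage: **kwargs (here n_clusters) would be forwarded to the
# sklearn KMeans base class by the super() call shown above.
fixed = np.array([[0.0, 0.0], [5.0, 5.0]])
km = KMeansWrapper(name='active_kmeans',
                   prm=prm,
                   fixed_centers=fixed,
                   random_state=np.random.RandomState(0),
                   n_clusters=10)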
    def __init__(self, name, prm, model, retention):
        self.name = name
        self.prm = prm
        self.model = model
        self.retention = retention  # used in children
        self.log = logger.get_logger(name)

        self.learning_rate_setter = self.prm.train.train_control.learning_rate_setter.LEARNING_RATE_SETTER
        self._init_lrn_rate = self.prm.network.optimization.LEARNING_RATE
        self._reset_lrn_rate = self.prm.train.train_control.learning_rate_setter.LEARNING_RATE_RESET
        if self._reset_lrn_rate is None:
            self.log.warning(
                'LEARNING_RATE_RESET is None. Setting LEARNING_RATE_RESET=LEARNING_RATE'
            )
            self._reset_lrn_rate = self.prm.network.optimization.LEARNING_RATE
        self._lrn_rate = self._init_lrn_rate
Example #7
    def __init__(self, name, prm):
        super(DatasetWrapper, self).__init__(name)
        self.prm = prm
        self.log = logger.get_logger(name)
        self.dataset_name             = self.prm.dataset.DATASET_NAME
        self.train_set_size           = self.prm.dataset.TRAIN_SET_SIZE
        self.validation_set_size      = self.prm.dataset.VALIDATION_SET_SIZE
        self.test_set_size            = self.prm.dataset.TEST_SET_SIZE
        self.train_validation_map_ref = self.prm.dataset.TRAIN_VALIDATION_MAP_REF
        self.H                        = self.prm.network.IMAGE_HEIGHT
        self.W                        = self.prm.network.IMAGE_WIDTH
        self.train_batch_size         = self.prm.train.train_control.TRAIN_BATCH_SIZE
        self.eval_batch_size          = self.prm.train.train_control.EVAL_BATCH_SIZE
        self.rand_gen                 = np.random.RandomState(prm.SUPERSEED)

        self.train_validation_info    = []

        self.train_dataset            = None
        self.train_eval_dataset       = None
        self.validation_dataset       = None
        self.test_dataset             = None

        self.iterator                 = None
        self.train_iterator           = None  # static iterator for train only
        self.train_eval_iterator      = None  # dynamic iterator for train evaluation; needs reinitialization
        self.validation_iterator      = None  # dynamic iterator for validation; needs reinitialization
        self.test_iterator            = None  # dynamic iterator for test; needs reinitialization

        self.handle                   = None
        self.train_handle             = None
        self.train_eval_handle        = None
        self.validation_handle        = None
        self.test_handle              = None

        self.next_minibatch           = None  # this is the output of iterator.get_next()

        if self.validation_set_size is None:
            self.log.warning('Validation set size is None. Setting its size to 0')
            self.validation_set_size = 0
        self.train_validation_size  = self.train_set_size + self.validation_set_size
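
The iterator and handle members above follow the TensorFlow 1.x feedable-iterator pattern: a single string-handle placeholder selects, per session.run call, which concrete iterator feeds next_minibatch. The DatasetWrapper's own build code is not shown here; the sketch below only illustrates that general pattern on toy data (all names, shapes and values are illustrative):

import numpy as np
import tensorflow as tf  # TensorFlow 1.x API

train_data = np.arange(10, dtype=np.float32)
val_data = np.arange(100, 104, dtype=np.float32)

train_dataset = tf.data.Dataset.from_tensor_slices(train_data).repeat().batch(2)
validation_dataset = tf.data.Dataset.from_tensor_slices(val_data).batch(2)

# One placeholder handle selects which iterator feeds get_next().
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, train_dataset.output_types, train_dataset.output_shapes)
next_minibatch = iterator.get_next()

train_iterator = train_dataset.make_one_shot_iterator()                  # static
validation_iterator = validation_dataset.make_initializable_iterator()   # needs reinitialization

with tf.Session() as sess:
    train_handle = sess.run(train_iterator.string_handle())
    validation_handle = sess.run(validation_iterator.string_handle())
    print(sess.run(next_minibatch, feed_dict={handle: train_handle}))
    sess.run(validation_iterator.initializer)
    print(sess.run(next_minibatch, feed_dict={handle: validation_handle}))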
Example #8
import lib.logger.logger as logger
from lib.logger.logging_config import logging_config
from utils.parameters import Parameters
from utils.factories import Factories
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from lib.active_kmean import KMeansWrapper
from sklearn.datasets import make_blobs


logging = logging_config()

logging.disable(logging.DEBUG)
log = logger.get_logger('main')

prm_file = '/data/gilad/logs/log_2210_220817_wrn-fc2_kmeans_SGD_init_200_clusters_4_cap_204/parameters.ini'

prm = Parameters()
prm.override(prm_file)

dev = prm.network.DEVICE

factories = Factories(prm)

model = factories.get_model()
model.print_stats()  # debug

preprocessor = factories.get_preprocessor()
preprocessor.print_stats()  # debug
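
The excerpt stops after the preprocessor, even though make_blobs, KMeans and KMeansWrapper are imported above. Purely as an illustrative continuation (the sample and cluster counts are made up, and KMeansWrapper's fit behavior is not shown in this excerpt), the clustering imports could be exercised on synthetic data like this:

# Illustrative continuation (not from the source): cluster synthetic blobs
# with plain sklearn KMeans and inspect the result.
X, y = make_blobs(n_samples=500, centers=4, n_features=2, random_state=0)
kmeans = KMeans(n_clusters=4, n_init=10, random_state=0).fit(X)
print('inertia: {}'.format(kmeans.inertia_))

plt.scatter(X[:, 0], X[:, 1], c=kmeans.labels_, s=10)
plt.scatter(kmeans.cluster_centers_[:, 0], kmeans.cluster_centers_[:, 1],
            c='red', marker='x')
plt.show()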
Example #9
    def __init__(self, name):
        self.name = name
        self.log = logger.get_logger(name)
Example #10
    def __init__(self):
        self.log = logger.get_logger('FrozenIniParser')
        super(FrozenClass, self).__init__()
Example #11
    def __init__(self):
        self.log = logger.get_logger('IniParser')

    def __init__(self, name, prm, model, *args, **kwargs):
        super(TrainSummarySaverHook, self).__init__(*args, **kwargs)
        self.name = name
        self.prm = prm
        self.log = logger.get_logger(name)
        self.model = model
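
The parent class of TrainSummarySaverHook is not shown; its name and the forwarded *args/**kwargs suggest something like tf.train.SummarySaverHook. Assuming the hook ultimately implements the tf.train.SessionRunHook interface, attaching it to a TensorFlow 1.x training loop could look like the sketch below (save_steps, output_dir and train_op are illustrative, and prm and model are assumed to be in scope):

import tensorflow as tf  # TensorFlow 1.x API

# Hypothetical usage: kwargs are assumed to reach tf.train.SummarySaverHook.
hook = TrainSummarySaverHook('train_summary_saver', prm, model,
                             save_steps=100,
                             summary_op=tf.summary.merge_all(),
                             output_dir='/tmp/summaries')
with tf.train.MonitoredTrainingSession(hooks=[hook]) as mon_sess:
    while not mon_sess.should_stop():
        mon_sess.run(train_op)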