def _parameters_sanity_check(self):
    """ Parameters sanity check. """

    # Adversarial and perceptual components must operate on HR-sized patches.
    if self.discriminator:
        assert self.lr_patch_size * self.scale == self.discriminator.patch_size
    if self.feature_extractor:
        assert self.lr_patch_size * self.scale == self.feature_extractor.patch_size
    check_parameter_keys(
        self.learning_rate,
        needed_keys=['initial_value'],
        optional_keys=['decay_factor', 'decay_frequency'],
        default_value=None,
    )
    check_parameter_keys(
        self.flatness,
        needed_keys=[],
        optional_keys=['min', 'increase_frequency', 'increase', 'max'],
        default_value=0.0,
    )
    check_parameter_keys(
        self.adam_optimizer,
        needed_keys=['beta1', 'beta2'],
        optional_keys=['epsilon'],
        default_value=None,
    )
    check_parameter_keys(self.log_dirs, needed_keys=['logs', 'weights'])
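# For illustration, a minimal sketch of parameter dicts that would pass the
# checks above. The values are hypothetical; only the key names come from the
# check_parameter_keys calls in _parameters_sanity_check.
learning_rate = {'initial_value': 0.0004, 'decay_factor': 0.5, 'decay_frequency': 30}
flatness = {'min': 0.0, 'increase_frequency': 5, 'increase': 0.01, 'max': 0.15}
adam_optimizer = {'beta1': 0.9, 'beta2': 0.999, 'epsilon': None}
log_dirs = {'logs': './logs', 'weights': './weights'}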
def test_check_parameter_keys(self):
    par = {'a': 0}
    # A needed key that is present passes silently.
    utils.check_parameter_keys(parameter=par, needed_keys=['a'])
    # A missing optional key is filled in with the default value.
    utils.check_parameter_keys(
        parameter=par, needed_keys=None, optional_keys=['b'], default_value=-1
    )
    self.assertEqual(par['b'], -1)
    # A missing needed key must raise.
    with self.assertRaises(KeyError):
        utils.check_parameter_keys(parameter=par, needed_keys=['c'])


import logging

logger = logging.getLogger(__name__)


def check_parameter_keys(parameter, needed_keys, optional_keys=None, default_value=None):
    if needed_keys:
        for key in needed_keys:
            if key not in parameter:
                logger.error('{p} is missing key {k}'.format(p=parameter, k=key))
                # A bare `raise` here has no active exception to re-raise,
                # so raise an explicit KeyError instead.
                raise KeyError(key)
    if optional_keys:
        for key in optional_keys:
            if key not in parameter:
                logger.info(
                    'Setting {k} in {p} to {d}'.format(k=key, p=parameter, d=default_value)
                )
                parameter[key] = default_value
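# Minimal usage sketch (hypothetical values): check_parameter_keys mutates the
# dict in place, filling absent optional keys with default_value and raising
# KeyError for absent needed keys.
settings = {'initial_value': 0.0004}
check_parameter_keys(
    settings,
    needed_keys=['initial_value'],
    optional_keys=['decay_factor', 'decay_frequency'],
    default_value=None,
)
# settings is now:
# {'initial_value': 0.0004, 'decay_factor': None, 'decay_frequency': None}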