def get_default_specific_model_parameters(network_class, network_type):
    """Load the default hyper-parameters of a network-specific model class.

    The model class is looked up in module
    ``src.models.<network_class>.model_<network_type>`` under the name
    ``underline_to_camel('model_<network_type>')``. Its
    ``get_default_model_parameters()`` is called, and the class itself is
    added to the result as the ``model_class`` entry.

    Args:
        network_class: sub-package of ``src.models`` that holds the model module.
        network_type: suffix of the model module name, or ``None`` when no
            specific model is requested.

    Returns:
        The object returned by the class's ``get_default_model_parameters()``,
        with ``model_class`` added. If ``network_type`` is ``None``, returns
        ``HParams(model_class=None)``. The same fallback is returned, after an
        error message is printed, when the module, the class or its
        ``get_default_model_parameters`` attribute cannot be found.
    """
    import importlib  # local import: only needed on this code path

    model_params = HParams(model_class=None)
    if network_type is None:
        return model_params

    model_module_name = 'model_%s' % network_type
    model_class_name = underline_to_camel(model_module_name)

    # import_module + getattr replace the former __import__ + eval walk, so no
    # string derived from command-line input is ever evaluated as code.
    try:
        model_module = importlib.import_module(
            'src.models.%s.%s' % (network_class, model_module_name))
    except ImportError:
        print('Fatal Error: no model module: \"src.models.%s.%s\"' % (network_class, model_module_name))
        return model_params

    # Guard only the lookups. An AttributeError raised *inside* the model's own
    # get_default_model_parameters() now surfaces instead of being misreported
    # as a missing class/method and silently replaced by the fallback.
    try:
        model_class = getattr(model_module, model_class_name)
        get_defaults = model_class.get_default_model_parameters
    except AttributeError:
        print(
            'Fatal Error: probably (1) no model class named as %s.%s, '
            'or (2) the class no \"get_default_model_parameters()\"' % (network_class, model_module_name))
        return model_params

    model_params = get_defaults()
    model_params.add_hparam('model_class', model_class)  # add model class
    return model_params
class ParamsCenter(object):
    """Central registry of the hyper-parameter groups used by a run.

    ``__init__`` builds and registers four groups, in this order:
    ``parsed_params`` (the raw command-line arguments), ``preprocessed_params``,
    ``model_params`` and ``training_params``. Each group after the first
    starts from its defaults and is then updated with ``.parse()`` using the
    matching ``--preprocessing_params`` / ``--model_params`` /
    ``--training_params`` command-line string.
    """

    def __init__(self, argv=None):
        """Parse command-line arguments and build every hyper-parameter group.

        Args:
            argv: optional list of argument strings forwarded to
                ``ArgumentParser.parse_args``. ``None`` (the default) keeps the
                original behaviour of parsing ``sys.argv[1:]``. Passing a list
                lets callers and tests build a ParamsCenter without touching
                the process command line.
        """
        self.hparam_name_list = []

        # ------ parsing input arguments ------
        parser = argparse.ArgumentParser()
        # Custom 'bool' type: "yes"/"true"/"t"/"1" (case-insensitive) -> True,
        # any other string -> False.
        parser.register('type', 'bool', (lambda x: x.lower() in ("yes", "true", "t", "1")))
        parser.add_argument('--mode', type=str, default='train', help='train_tasks')
        parser.add_argument('--dataset', type=str, default='snli', help='[snli|multinli_m|multinli_mm]')
        parser.add_argument('--network_class', type=str, default='transformer', help='None')
        parser.add_argument('--network_type', type=str, default=None, help='None')
        parser.add_argument('--gpu', type=str, default='3', help='selected gpu index')
        # NOTE(review): this help text duplicates --gpu's; the option presumably
        # limits GPU memory usage -- confirm and correct the help string.
        parser.add_argument('--gpu_mem', type=float, default=None, help='selected gpu index')
        parser.add_argument('--model_dir_prefix', type=str, default='prefix', help='model dir name prefix')
        parser.add_argument('--aws', type='bool', default=False, help='using aws')
        # Per-group override strings, each passed to the group's .parse() below.
        parser.add_argument('--preprocessing_params', type=str, default='', help='')
        parser.add_argument('--model_params', type=str, default='', help='')
        parser.add_argument('--training_params', type=str, default='', help='')
        # 'shuffle' has no flag of its own; it is always True in the namespace.
        parser.set_defaults(shuffle=True)
        args = parser.parse_args(argv)

        self.parsed_params = HParams()
        for key, val in vars(args).items():
            self.parsed_params.add_hparam(key, val)
        self.register_hparams(self.parsed_params, 'parsed_params')

        # pre-processing
        self.preprocessed_params = self.get_default_preprocessing_params()
        self.preprocessed_params.parse(self.parsed_params.preprocessing_params)
        self.register_hparams(self.preprocessed_params, 'preprocessed_params')

        # model: generic defaults merged with the network-specific defaults
        self.model_params = merge_params(
            self.get_default_model_parameters(),
            self.get_default_specific_model_parameters(
                self.parsed_params.network_class, self.parsed_params.network_type))
        self.model_params.parse(self.parsed_params.model_params)
        self.register_hparams(self.model_params, 'model_params')

        # training
        self.training_params = self.get_default_training_params()
        self.training_params.parse(self.parsed_params.training_params)
        self.register_hparams(self.training_params, 'training_params')

    @staticmethod
    def get_default_preprocessing_params():
        """Return the default pre-processing hyper-parameters."""
        params = HParams(
            max_sent_len=50,
            load_preproc=True,
        )
        return params

    @staticmethod
    def get_default_model_parameters():
        """Return the generic model hyper-parameters (empty by default)."""
        return HParams()

    @staticmethod
    def get_default_training_params():
        """Return the default training hyper-parameters."""
        hparams = HParams(
            optimizer='openai_adam',
            grad_norm=1.,
            n_steps=90000,
            lr=6.25e-5,
            # control
            save_model=False,
            save_num=3,
            load_model=False,
            load_path='',
            summary_period=1000,
            eval_period=500,
            train_batch_size=20,
            test_batch_size=24,
        )
        return hparams

    @staticmethod
    def get_default_specific_model_parameters(network_class, network_type):
        """Load the default hyper-parameters of a network-specific model class.

        The class named ``underline_to_camel('model_<network_type>')`` is looked
        up in module ``src.models.<network_class>.model_<network_type>``. Its
        ``get_default_model_parameters()`` result is returned with the class
        itself added as the ``model_class`` entry.

        Returns ``HParams(model_class=None)`` when ``network_type`` is ``None``.
        The same fallback is returned, after an error message is printed, when
        the module, the class or its ``get_default_model_parameters`` attribute
        cannot be found.
        """
        import importlib  # local import: only needed on this code path

        model_params = HParams(model_class=None)
        if network_type is None:
            return model_params

        model_module_name = 'model_%s' % network_type
        model_class_name = underline_to_camel(model_module_name)

        # import_module + getattr instead of __import__ + eval: no string built
        # from command-line input is evaluated as code.
        try:
            model_module = importlib.import_module(
                'src.models.%s.%s' % (network_class, model_module_name))
        except ImportError:
            print('Fatal Error: no model module: \"src.models.%s.%s\"' % (network_class, model_module_name))
            return model_params

        # Guard only the lookups, so an AttributeError raised inside the
        # model's own get_default_model_parameters() is not misreported.
        try:
            model_class = getattr(model_module, model_class_name)
            get_defaults = model_class.get_default_model_parameters
        except AttributeError:
            print(
                'Fatal Error: probably (1) no model class named as %s.%s, '
                'or (2) the class no \"get_default_model_parameters()\"' % (network_class, model_module_name))
            return model_params

        model_params = get_defaults()
        model_params.add_hparam('model_class', model_class)  # add model class
        return model_params

    # ============== Utils =============
    def register_hparams(self, hparams, name):
        """Store ``hparams`` as attribute ``name`` and record it for lookups.

        Raises:
            AssertionError: if ``hparams`` is not an ``HParams``, ``name`` is
                not a ``str``, or ``name`` was already registered.
        """
        assert isinstance(hparams, HParams)
        assert isinstance(name, str)
        assert name not in self.hparam_name_list
        self.hparam_name_list.append(name)
        setattr(self, name, hparams)

    @property
    def all_params(self):
        """Return every registered group folded into one ``HParams``.

        Groups are folded with ``merge_params`` from the last registered to the
        first. NOTE(review): which group wins on a key clash depends on
        ``merge_params``' semantics; confirm it matches ``__getitem__``.
        """
        all_params = HParams()
        for hparam_name in reversed(self.hparam_name_list):
            cur_params = getattr(self, hparam_name)
            all_params = merge_params(all_params, cur_params)
        return all_params

    def __getitem__(self, item):
        """Return hyper-parameter ``item`` from the latest group defining it.

        Groups are searched newest-registered first.

        Raises:
            AttributeError: if no registered group defines ``item``.
        """
        assert isinstance(item, str)
        for hparam_name in reversed(self.hparam_name_list):
            try:
                return getattr(getattr(self, hparam_name), item)
            except AttributeError:
                pass
        raise AttributeError('no item named as \'%s\'' % item)