示例#1
0
 def get_default_specific_model_parameters(network_class, network_type):
     """Load the default hyper-parameters of a network-specific model class.

     The model class is located by naming convention: the module
     ``src.models.<network_class>.model_<network_type>`` must define a class
     named ``underline_to_camel('model_<network_type>')`` which provides a
     ``get_default_model_parameters()`` method.

     Args:
         network_class: package name under ``src.models``.
         network_type: model type suffix, or ``None`` to skip loading.

     Returns:
         The class's default HParams plus an extra ``model_class`` entry that
         holds the resolved class. When ``network_type`` is ``None``, or the
         module/class cannot be resolved, ``HParams(model_class=None)`` is
         returned instead (an error message is printed in the latter case).
     """
     import importlib  # local import keeps this helper self-contained

     model_params = HParams(model_class=None)
     if network_type is None:
         return model_params

     model_module_name = 'model_%s' % network_type
     model_class_name = underline_to_camel(model_module_name)
     module_path = 'src.models.%s.%s' % (network_class, model_module_name)
     try:
         # import_module returns the leaf module itself, so the class is
         # fetched with getattr instead of eval()-ing a dotted expression
         # assembled from command-line input.
         model_module = importlib.import_module(module_path)
         model_class = getattr(model_module, model_class_name)
         model_params = model_class.get_default_model_parameters()
         model_params.add_hparam('model_class', model_class)  # add model class
     except ImportError as err:
         # Also raised when the module exists but one of ITS imports fails;
         # the original error text is included to tell the cases apart.
         print('Fatal Error: no model module: "%s" (%s)' % (module_path, err))
     except AttributeError as err:
         print('Fatal Error: probably (1) no model class named as %s in %s, '
               'or (2) the class has no "get_default_model_parameters()" (%s)'
               % (model_class_name, module_path, err))
     return model_params
示例#2
0
class ParamsCenter(object):
    """Build, register and look up the hyper-parameter groups of a run.

    On construction four ``HParams`` groups are created and registered as
    attributes, in this order:

    * ``parsed_params``       -- every parsed command-line argument;
    * ``preprocessed_params`` -- preprocessing defaults, overridden by
      ``--preprocessing_params``;
    * ``model_params``        -- generic model defaults merged with the
      defaults of the selected network-specific model class, overridden by
      ``--model_params``;
    * ``training_params``     -- training defaults, overridden by
      ``--training_params``.

    ``center['name']`` looks a value up across the groups, searching the
    most recently registered group first.
    """

    def __init__(self, args=None):
        """Parse the command line and build every parameter group.

        Args:
            args: optional list of argument strings forwarded to
                ``ArgumentParser.parse_args``. The default ``None`` makes
                argparse read ``sys.argv[1:]`` (the previous behavior).
        """
        self.hparam_name_list = []
        # ------ parsing input arguments --------
        parser = argparse.ArgumentParser()
        # custom 'bool' type: "yes"/"true"/"t"/"1" (any case) -> True,
        # any other string -> False
        parser.register('type', 'bool', (lambda x: x.lower() in
                                         ("yes", "true", "t", "1")))
        parser.add_argument('--mode',
                            type=str,
                            default='train',
                            help='train_tasks')
        parser.add_argument('--dataset',
                            type=str,
                            default='snli',
                            help='[snli|multinli_m|multinli_mm]')
        parser.add_argument('--network_class',
                            type=str,
                            default='transformer',
                            help='None')
        parser.add_argument('--network_type',
                            type=str,
                            default=None,
                            help='None')
        parser.add_argument('--gpu',
                            type=str,
                            default='3',
                            help='selected gpu index')
        parser.add_argument('--gpu_mem',
                            type=float,
                            default=None,
                            help='gpu memory setting')
        parser.add_argument('--model_dir_prefix',
                            type=str,
                            default='prefix',
                            help='model dir name prefix')
        parser.add_argument('--aws',
                            type='bool',
                            default=False,
                            help='using aws')

        # per-group overrides, passed verbatim to HParams.parse below
        # (presumably "name=value,..." strings -- TODO confirm)
        parser.add_argument('--preprocessing_params',
                            type=str,
                            default='',
                            help='')
        parser.add_argument('--model_params', type=str, default='', help='')
        parser.add_argument('--training_params', type=str, default='', help='')

        # 'shuffle' has no command-line flag; it only enters parsed_params
        # through this default
        parser.set_defaults(shuffle=True)
        parsed_args = parser.parse_args(args)
        self.parsed_params = HParams()
        for key, val in vars(parsed_args).items():
            self.parsed_params.add_hparam(key, val)
        self.register_hparams(self.parsed_params, 'parsed_params')

        # pre-processing
        self.preprocessed_params = self.get_default_preprocessing_params()
        self.preprocessed_params.parse(self.parsed_params.preprocessing_params)
        self.register_hparams(self.preprocessed_params, 'preprocessed_params')

        # model: generic defaults merged with the network-specific ones
        self.model_params = merge_params(
            self.get_default_model_parameters(),
            self.get_default_specific_model_parameters(
                self.parsed_params.network_class,
                self.parsed_params.network_type))
        self.model_params.parse(self.parsed_params.model_params)
        self.register_hparams(self.model_params, 'model_params')

        # training
        self.training_params = self.get_default_training_params()
        self.training_params.parse(self.parsed_params.training_params)
        self.register_hparams(self.training_params, 'training_params')

    @staticmethod
    def get_default_preprocessing_params():
        """Return a fresh HParams with the default preprocessing settings."""
        params = HParams(
            max_sent_len=50,
            load_preproc=True,
        )
        return params

    @staticmethod
    def get_default_model_parameters():
        """Return the generic model defaults (currently an empty HParams)."""
        return HParams()

    @staticmethod
    def get_default_training_params():
        """Return a fresh HParams with the default training settings."""
        hparams = HParams(
            optimizer='openai_adam',
            grad_norm=1.,
            n_steps=90000,
            lr=6.25e-5,

            # control
            save_model=False,
            save_num=3,
            load_model=False,
            load_path='',
            summary_period=1000,
            eval_period=500,
            train_batch_size=20,
            test_batch_size=24,
        )
        return hparams

    @staticmethod
    def get_default_specific_model_parameters(network_class, network_type):
        """Load the default hyper-parameters of a network-specific model class.

        The model class is located by naming convention: the module
        ``src.models.<network_class>.model_<network_type>`` must define a
        class named ``underline_to_camel('model_<network_type>')`` which
        provides a ``get_default_model_parameters()`` method.

        Args:
            network_class: package name under ``src.models``.
            network_type: model type suffix, or ``None`` to skip loading.

        Returns:
            The class's default HParams plus an extra ``model_class`` entry
            that holds the resolved class. When ``network_type`` is ``None``,
            or the module/class cannot be resolved,
            ``HParams(model_class=None)`` is returned instead (an error
            message is printed in the latter case).
        """
        import importlib  # local import keeps this helper self-contained

        model_params = HParams(model_class=None)
        if network_type is None:
            return model_params

        model_module_name = 'model_%s' % network_type
        model_class_name = underline_to_camel(model_module_name)
        module_path = 'src.models.%s.%s' % (network_class, model_module_name)
        try:
            # import_module returns the leaf module itself, so the class is
            # fetched with getattr instead of eval()-ing a dotted expression
            # assembled from command-line input.
            model_module = importlib.import_module(module_path)
            model_class = getattr(model_module, model_class_name)
            model_params = model_class.get_default_model_parameters()
            model_params.add_hparam('model_class', model_class)  # add model class
        except ImportError as err:
            # Also raised when the module exists but one of ITS imports
            # fails; the original error text is included to tell them apart.
            print('Fatal Error: no model module: "%s" (%s)' % (module_path, err))
        except AttributeError as err:
            print('Fatal Error: probably (1) no model class named as %s in %s, '
                  'or (2) the class has no "get_default_model_parameters()" (%s)'
                  % (model_class_name, module_path, err))
        return model_params

    # ============== Utils =============
    def register_hparams(self, hparams, name):
        """Store *hparams* as attribute *name* and record it for lookups.

        Args:
            hparams: the ``HParams`` group to register.
            name: attribute name for the group; must be unique.

        Raises:
            TypeError: if *hparams* is not an ``HParams`` or *name* is not
                a ``str``.
            ValueError: if a group named *name* is already registered.
        """
        # explicit checks rather than ``assert``: asserts are stripped under
        # ``python -O``, which would allow a group to be silently overwritten
        if not isinstance(hparams, HParams):
            raise TypeError('hparams must be an HParams instance, got %s'
                            % type(hparams).__name__)
        if not isinstance(name, str):
            raise TypeError('name must be a str, got %s' % type(name).__name__)
        if name in self.hparam_name_list:
            raise ValueError('hparams \'%s\' is already registered' % name)

        self.hparam_name_list.append(name)
        setattr(self, name, hparams)

    @property
    def all_params(self):
        """Return one HParams combining every registered group.

        Groups are folded into the result from the most recently registered
        to the earliest via ``merge_params(accumulated, group)``; which side
        wins on a name clash is decided by ``merge_params`` (defined
        elsewhere -- confirm its precedence before relying on it).
        """
        all_params = HParams()
        for hparam_name in reversed(self.hparam_name_list):
            all_params = merge_params(all_params, getattr(self, hparam_name))
        return all_params

    def __getitem__(self, item):
        """Look *item* up across all registered groups.

        The most recently registered group is searched first, so a name in
        ``training_params`` shadows the same name in ``parsed_params``.

        Raises:
            TypeError: if *item* is not a ``str``.
            AttributeError: if no registered group has *item* (kept instead
                of KeyError for compatibility with existing callers).
        """
        if not isinstance(item, str):
            raise TypeError('item must be a str, got %s' % type(item).__name__)

        for hparam_name in reversed(self.hparam_name_list):
            try:
                return getattr(getattr(self, hparam_name), item)
            except AttributeError:
                pass
        raise AttributeError('no item named as \'%s\'' % item)