Example #1
def __init__(self, cfg: Namespace):
    """
    Args:
        cfg:  configuration namespace
    """
    self.cfg = cfg
    setattr(cfg, 'model_id', self.model_id(cfg))
    setattr(cfg, 'out_dir', '{}/{}'.format(cfg.logdir, cfg.model_id))
    setattr(cfg, 'context_len', 2 * cfg.window + 1)
    self.rsc = Resource(cfg)
    self.model = CnnModel(cfg, self.rsc)
    self.optimizer = torch.optim.Adam(self.model.parameters(),
                                      cfg.learning_rate)
    self.criterion = nn.CrossEntropyLoss()
    self.evaler = Evaluator()
    self._load_dataset()
    if not hasattr(cfg, 'epoch'):    # fresh run: start epoch counters at zero
        setattr(cfg, 'epoch', 0)
        setattr(cfg, 'best_epoch', 0)
    self.log_file = None    # tab-separated log file
    self.sum_wrt = None     # TensorBoard summary writer
    self.loss_trains = []
    self.loss_devs = []
    self.acc_chars = []
    self.acc_words = []
    self.f_scores = []
    self.learning_rates = []
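A minimal usage sketch follows, assuming the `__init__` above belongs to a trainer class, here called `Trainer` (a hypothetical name; the snippet does not show the class), and that `cfg` carries the attributes the constructor reads:

from argparse import Namespace

# Hypothetical instantiation; Trainer stands in for the class that owns
# the __init__ above, and the attribute values are illustrative.
cfg = Namespace(logdir='logs', window=3, learning_rate=1e-3)
trainer = Trainer(cfg)
print(cfg.context_len)    # 2 * cfg.window + 1 == 7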
Example #2
def __init__(self, model_dir: str):
    """
    Args:
        model_dir:  model directory
    """
    with open('{}/config.json'.format(model_dir), 'r', encoding='UTF-8') as fin:
        cfg_dict = json.load(fin)
    self.cfg = Namespace()
    for key, val in cfg_dict.items():
        setattr(self.cfg, key, val)
    self.rsc = Resource(self.cfg)
    self.model = CnnModel(self.cfg, self.rsc)
    self.model.load('{}/model.state'.format(model_dir))
    self.model.eval()    # inference mode: disables dropout etc.
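Assuming this constructor belongs to a tagger-style wrapper class (hypothetically named `PosTagger` here), loading a trained model is a one-liner; the directory must contain config.json and model.state:

# Hypothetical class name; the snippet shows only the constructor.
tagger = PosTagger('models/cnn_base')
# the wrapped CnnModel is now loaded and switched to eval mode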
Example #3
@classmethod
def load_train(cls, rsc_src: str) -> List['Sentence']:
    """
    load sentences from the khaiii training set (read from standard input)
    Args:
        rsc_src:  resource source directory containing restore.dic
    Returns:
        list of sentences
    """
    restore_dic = Resource.load_restore_dic(f'{rsc_src}/restore.dic')
    sentences = []
    for sent in PosDataset(None, restore_dic, sys.stdin):
        sentence = Sentence()
        for word in sent.pos_tagged_words:
            sentence.words.append(word.raw)
            sentence.morphs.append(' + '.join(
                [str(m) for m in word.pos_tagged_morphs]))
        sentences.append(sentence)
    return sentences
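Because `load_train` reads the corpus from `sys.stdin` through `PosDataset`, a typical call pipes the training data into the script. A sketch, assuming `Sentence` is the class defining this classmethod and the paths are illustrative:

# run as: python load_corpus.py < corpus.train
sentences = Sentence.load_train('rsc/src')    # needs rsc/src/restore.dic
for sent in sentences[:3]:
    print(sent.words)
    print(sent.morphs)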
Example #4
def __init__(self, model_dir: str, gpu_num: int = -1):
    """
    Args:
        model_dir:  model directory
        gpu_num:  GPU number to override
    """
    with open('{}/config.json'.format(model_dir), 'r', encoding='UTF-8') as fin:
        cfg_dict = json.load(fin)
    self.cfg = Namespace()
    for key, val in cfg_dict.items():
        setattr(self.cfg, key, val)
    setattr(self.cfg, 'gpu_num', gpu_num)    # override the saved GPU setting
    self.rsc = Resource(self.cfg)
    self.model = Model(self.cfg, self.rsc)
    self.model.load('{}/model.state'.format(model_dir))
    self.model.eval()
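This variant adds a `gpu_num` override on top of the saved config. A hedged sketch, assuming the owning class is named `Tagger` (hypothetical) and that the default of -1 means no GPU:

# Hypothetical class name; -1 is assumed to keep the model on CPU.
tagger_cpu = Tagger('models/cnn_base')              # gpu_num defaults to -1
tagger_gpu = Tagger('models/cnn_base', gpu_num=0)   # select the first GPU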
Example #5
def _load_cfg_rsc(rsc_src: str, model_size: str) -> Tuple[Namespace, Resource]:
    """
    load config and resource from the source directory
    Args:
        rsc_src:  source directory
        model_size:  model size (base|large)
    Returns:
        config
        resource
    """
    file_path = '{}/{}.config.json'.format(rsc_src, model_size)
    with open(file_path, 'r', encoding='UTF-8') as fin:
        cfg_dic = json.load(fin)
    logging.info('config: %s', json.dumps(cfg_dic, indent=4, sort_keys=True))
    cfg = Namespace()
    for key, val in cfg_dic.items():
        setattr(cfg, key, val)
    setattr(cfg, 'rsc_src', rsc_src)
    rsc = Resource(cfg)
    return cfg, rsc
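Usage sketch: the loader expects `{rsc_src}/{model_size}.config.json` on disk and injects `rsc_src` into the returned config (paths here are illustrative):

cfg, rsc = _load_cfg_rsc('rsc/src', 'base')    # reads rsc/src/base.config.json
print(cfg.rsc_src)    # 'rsc/src', set by the loader itself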
Example #6
def run(args: Namespace):
    """
    run function which is the entry point of the program
    Args:
        args:  program arguments
    """
    cfg = _load_config('{}/config.json'.format(args.in_dir))
    setattr(cfg, 'rsc_src', args.rsc_src)
    rsc = Resource(cfg)
    # load the state dict onto CPU regardless of where it was saved
    state_dict = torch.load('{}/model.state'.format(args.in_dir),
                            map_location=lambda storage, loc: storage)
    _validate_state_dict(cfg, rsc, state_dict)
    data = _get_data(rsc, state_dict)

    config_path = '{}/{}.config.json'.format(args.rsc_src, args.model_size)
    with open(config_path, 'w', encoding='UTF-8') as fout:
        json.dump(vars(cfg), fout, indent=4, sort_keys=True)

    pickle_path = '{}/{}.model.pickle'.format(args.rsc_src, args.model_size)
    with open(pickle_path, 'wb') as fout:
        pickle.dump(data, fout)
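`run` expects an argparse-style namespace with `in_dir`, `rsc_src`, and `model_size`; a sketch with illustrative values:

from argparse import Namespace

# values are illustrative; in_dir must hold config.json and model.state
args = Namespace(in_dir='models/cnn_base', rsc_src='rsc/src', model_size='base')
run(args)    # writes base.config.json and base.model.pickle into rsc/src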