Ejemplo n.º 1
0
    def _check_encoding(self, conf):
        """Check whether a valid encoding cache exists for ``conf``.

        The cache directory is derived from ``conf.train_data_md5`` and the
        MD5 of the problem file.  The cache is accepted only if the index
        file matches its recorded MD5 and every cached file listed in the
        index matches the MD5 recorded for it.

        Side effects: sets ``self.encoding_invalid`` (``True`` on every
        failure path, ``False`` on success) and fills in ``conf.problem_md5``
        plus the ``conf.encoding_cache_*`` path attributes as far as the
        check progresses.  The legal/illegal line counts are copied to
        ``conf`` only when the index provides both of them.

        Args:
            conf: configuration object; read for the problem/cache paths and
                updated in place as described above.

        Returns:
            bool: ``True`` if the encoding cache can be reused, else ``False``.
        """
        # Pessimistic default: every early return below leaves the cache
        # flagged as invalid.
        self.encoding_invalid = True
        if not conf.pretrained_model_path and self.dictionary_invalid:
            return False

        # Calculate the MD5 of problem
        problem_path = conf.problem_path if not conf.pretrained_model_path else conf.saved_problem_path
        try:
            conf.problem_md5 = md5([problem_path])
        except Exception:
            conf.problem_md5 = None
            logging.info('Can not calculate md5 of problem.pkl from %s' %
                         (problem_path))
            return False

        # check the valid of encoding cache
        ## encoding cache dir
        conf.encoding_cache_dir = os.path.join(
            conf.cache_dir, conf.train_data_md5 + conf.problem_md5)
        logging.debug('[Cache] conf.encoding_cache_dir %s' %
                      (conf.encoding_cache_dir))
        if not os.path.exists(conf.encoding_cache_dir):
            return False

        ## encoding cache index
        conf.encoding_cache_index_file_path = os.path.join(
            conf.encoding_cache_dir, st.cencodig_index_file_name)
        conf.encoding_cache_index_file_md5_path = os.path.join(
            conf.encoding_cache_dir, st.cencoding_index_md5_file_name)
        if not os.path.exists(
                conf.encoding_cache_index_file_path) or not os.path.exists(
                    conf.encoding_cache_index_file_md5_path):
            return False
        if md5([conf.encoding_cache_index_file_path]) != load_from_json(
                conf.encoding_cache_index_file_md5_path):
            return False
        cache_index = load_from_json(conf.encoding_cache_index_file_path)

        ## encoding cache content
        for index in cache_index[st.cencoding_key_index]:
            file_name, file_md5 = index[0], index[1]
            cache_file_path = os.path.join(conf.encoding_cache_dir, file_name)
            # A cached file that has disappeared invalidates the cache, the
            # same way a missing index file does above.
            if not os.path.exists(cache_file_path):
                return False
            if file_md5 != md5([cache_file_path]):
                return False

        has_line_counts = (st.cencoding_key_legal_cnt in cache_index
                           and st.cencoding_key_illegal_cnt in cache_index)
        if has_line_counts:
            conf.encoding_cache_legal_line_cnt = cache_index[
                st.cencoding_key_legal_cnt]
            conf.encoding_cache_illegal_line_cnt = cache_index[
                st.cencoding_key_illegal_cnt]

        self.encoding_invalid = False
        logging.info('[Cache] encoding found')
        # Only report the counts when the index actually supplied them;
        # otherwise the attributes may be unset and the '%d' formatting
        # would fail after the cache has already been accepted.
        if has_line_counts:
            logging.info('%s: %d legal samples, %d illegal samples' %
                         (conf.train_data_path,
                          conf.encoding_cache_legal_line_cnt,
                          conf.encoding_cache_illegal_line_cnt))
        return True
Ejemplo n.º 2
0
    def _prepare_encoding_cache(self, conf, problem, build=False):
        """Set up the encoding cache locations on ``conf`` and optionally build it.

        Computes ``conf.problem_md5`` and ``conf.encoding_cache_dir``, makes
        sure the directory exists, and records the index / index-MD5 file
        paths and the cache loader on ``conf``.  When the cache is valid
        (either freshly built or already flagged valid on ``self``), the
        file index is loaded into ``conf.encoding_file_index``.

        Args:
            conf: configuration object, updated in place.
            problem: object providing ``build_encode_cache(conf)``; only used
                when ``build`` is true.
            build: if true, reset the cache directory via ``prepare_dir``,
                rebuild the cache and mark it valid.
        """
        # encoding cache dir
        problem_path = conf.problem_path if not conf.pretrained_model_path else conf.saved_problem_path
        conf.problem_md5 = md5([problem_path])
        conf.encoding_cache_dir = os.path.join(
            conf.cache_dir, conf.train_data_md5 + conf.problem_md5)
        if not os.path.exists(conf.encoding_cache_dir):
            # exist_ok avoids a crash if the directory is created by someone
            # else between the existence check and this call.
            os.makedirs(conf.encoding_cache_dir, exist_ok=True)

        # encoding cache files
        conf.encoding_cache_index_file_path = os.path.join(
            conf.encoding_cache_dir, st.cencodig_index_file_name)
        conf.encoding_cache_index_file_md5_path = os.path.join(
            conf.encoding_cache_dir, st.cencoding_index_md5_file_name)
        conf.load_encoding_cache_generator = self._load_encoding_cache_generator

        if build:
            prepare_dir(conf.encoding_cache_dir,
                        True,
                        allow_overwrite=True,
                        clear_dir_if_exist=True)
            problem.build_encode_cache(conf)
            self.encoding_invalid = False

        # The index is only read when the cache is valid, i.e. it was just
        # built above or self.encoding_invalid was already False.
        if not self.encoding_invalid:
            cache_index = load_from_json(conf.encoding_cache_index_file_path)
            conf.encoding_file_index = cache_index[st.cencoding_key_index]
Ejemplo n.º 3
0
 def load_from_file(self, conf_path):
     """Load the JSON configuration at ``conf_path`` and run the configuration steps.

     The raw content is stored on ``self.conf``, passed through
     ``self.Conf.load_data`` under the ``'Conf'`` key, and then each
     ``configurate_*`` method is called in the order listed below.
     """
     # load file
     self.conf = load_from_json(conf_path, debug=False)
     # NOTE(review): this rebinds only the local name ``self``; the calls below
     # run on whatever load_data returns. Presumably it returns the same
     # instance after populating it -- confirm against Conf.load_data.
     self = self.Conf.load_data(self, {'Conf' : self.conf}, key_prefix_desc='Conf')
     # Normalize the language name to lower case before the steps below run.
     self.language = self.language.lower()
     self.configurate_outputs()
     self.configurate_inputs()
     self.configurate_training_params()
     self.configurate_architecture()
     self.configurate_loss()
     self.configurate_cache()