Example #1
    def run(self):
        """
        Run the benchmark on all the files.
        """
        logger.log_report('starting', self.fixer_repr)
        score_rows = {}
        mean_dict = {}
        files = self.reader.read_test_pairs()
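        # despite the flag's name, this keeps the first 10 pairs rather than
        # a random sample; enough for a quick smoke test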
        if self.random_sample_files:
            files = take_first_n(files, 10)

        logger.start()
        num_files = 0
        csv_path = os.path.join(self.get_timestamp_folder_name(),
                                'results.csv')
        open_or_create_write_file(csv_path).close()  # create the folder and an empty CSV up front
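        # process the files in chunks of 200 so results are flushed to disk
        # after every chunk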
        for chunk_id, files_collection in enumerate(gen_chunker(files, 200)):
            logger.start()
            if NUM_THREADS == 1:
                rows = list(map(self.run_benchmark, files_collection))
            else:
                with multiprocessing.Pool(NUM_THREADS) as pool:
                    rows = pool.map(self.run_benchmark, files_collection)
            num_files += len(rows)
            for row in rows:
                for fil, values in row.items():
                    score_rows[fil] = values
            df = pd.DataFrame.from_dict(score_rows, orient='index',
                                        columns=self.metrics)
            df.to_csv(csv_path)

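            # track a macro mean (average of the per-file rows) alongside the
            # micro scores that get_scores computes over the whole dataframe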
            micro_scores = self.get_scores(df)
            mean_dict[' %d files, macro' % num_files] = np.array(
                list(score_rows.values())).mean(axis=0)
            mean_dict[' %d files, micro' % num_files] = np.array(micro_scores)

            logger.log_report('%d files, seq accuracy %.5f, avg. duration %.2fs' % (
                num_files, df['acc'].mean(), df['duration'].mean()))
            self.summarize(mean_dict)

            logger.log_full_report_into_file(os.path.join(
                self.get_timestamp_folder_name(), 'chunk'), keep_log=True)

        logger.log_seperator()
        mean_dict[' %d files' % num_files] = np.array(
            list(score_rows.values())).mean(axis=0)
        self.summarize(mean_dict)
        logger.log_full_report_into_file(os.path.join(
            self.get_timestamp_folder_name(),
            'all_'), keep_log=True)
        logger.log_full_report_into_file('%s-all_' % self.dump_dir)
Example #2
    def create_network(self, X, Y, combiner_val):
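        # default to all-ones weights and zero bias; returned unchanged when
        # there is no data to train on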
        bias = np.zeros(X.shape[1:])
        weights = np.ones(X.shape[1:])
        if X.shape[0] == 0:
            return weights, bias

        import keras.backend as K
        from keras.layers import Input, Conv1D, Dense, Reshape, Activation, Lambda, Add
        from keras.models import Model
        from keras.optimizers import Adam
        from .custom_layers import Sparse

        logger.log_debug("\n", X.shape, Y.shape,
                         np.unique(Y, return_counts=True), '\n',
                         X.mean(axis=0), '\n\n', X.std(axis=0))
        inp = Input(X.shape[1:])
        sparse = Sparse(use_kernel=self.use_kernel, use_bias=self.use_bias)
        if X.ndim > 2:
            sum_layer = Lambda(lambda x: K.sum(x, axis=-1))
        else:
            sum_layer = Activation('linear')
        logger.log_debug(combiner_val.shape)
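        # a frozen Dense softmax layer whose weights are fixed to combiner_val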
        combiner = Dense(combiner_val.shape[-1],
                         activation='softmax',
                         use_bias=False)
        combiner.trainable = False
        out = combiner(sum_layer(sparse(inp)))
        combiner.set_weights([combiner_val])
        model = Model(inp, out)
        model.summary()

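        # inverse-time decay: each epoch trains at lr / (1 + decay * epoch)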
        lr = 0.1
        decay = 5e-2
        model.compile(Adam(lr),
                      loss='sparse_categorical_crossentropy',
                      metrics=['sparse_categorical_accuracy'])
        bar = tqdm(range(10000), ncols=160)
        for epoch in bar:
            K.set_value(model.optimizer.lr, lr / (1 + epoch * decay))
            vals = model.train_on_batch(X, Y)
            # vals = model.fit(X, Y, verbose=0, batch_size=min(X.shape[0], 1 << 17))
            names = [
                name.replace('sparse_categorical_accuracy', 'acc')
                for name in model.metrics_names
            ]
            dicts = dict(zip(names, vals))
            if epoch % 400 == 1:
                for arr in sparse.get_weights():
                    logger.log_debug("\n" + str(np.round(arr, 3)))
            bar.set_postfix(**dicts)
            bar.refresh()

        loss, acc = model.evaluate(X, Y, verbose=0)
        logger.log_report('loss:', loss, 'acc:', acc, highlight=2)

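        # return the trained sparse kernel/bias, keeping the defaults for any
        # component that is unused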
        if self.use_kernel:
            weights = sparse.get_weights()[0]
        if self.use_bias:
            bias = sparse.get_weights()[-1]
        return weights, bias
Example #3
    def summarize(self, mean_dict):
        """
        Summarize the evaluated results.

        :param dict mean_dict:
            Evaluation dictionary of the metric evaluations of the fixed files
        """
        viewer = MultiViewer()
        logger.log_report(viewer.metric_comparison(mean_dict, self.metrics))
Example #4
 def summarize_html(self, mean_dict, viewer):
     html_comparison = viewer.metric_comparison(mean_dict, self.metrics,
                                                modes=HTML)
     result = [html_comparison]
     path = os.path.join(self.get_timestamp_folder_name(),
                         'comparisons', 'htmls')
     # compare base names, since get_all_files presumably yields full paths
     for fil in get_all_files(path, ext='.html'):
         if os.path.basename(fil) != 'summary.html':
             with open(fil, 'r') as page:
                 result.append(page.read())
     path = os.path.join(path, 'summary.html')
     with open(path, 'w') as summary_file:
         summary_file.write(viewer.merge_wrapped_pages(*result, mode=HTML,
                                                       seperate=True))
         logger.log_report('summarized files in %s' % path)
Example #5
 def sample_analysis(self):
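     # generate 500 characters at several softmax temperatures to inspect
     # how sampling sharpness affects the text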
     for temp in [0.3, 0.5, 0.6, 1.2]:
         text = 'Hello America, this is the new capitalism'
         state = None
         char = self.str_codes(text)
         for _ in range(500):
             preds, state = self.predict(char,
                                         state=state,
                                         return_state=True,
                                         encode=False)
             char = self.sample_char(preds, temperature=temp)
             if self.direction == BACKWARD:
                 text = char + text
             elif self.direction == FORWARD:
                 text = text + char
             char = self.str_codes(char)
         logger.log_report('Temperature %.2f\n' % temp, text)
         logger.log_seperator()
Example #6
 def load_model(self, model_load_path):
     if os.path.isdir(os.path.dirname(model_load_path)):
         path, fil, ext = extract_file_name(model_load_path)
         all_files = [
             os.path.join(path, f) for f in os.listdir(path)
             if f.endswith(fil + '.' + ext)
         ]
         if all_files:
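             # file names are assumed to sort by epoch, so the lexicographic
             # maximum is the most recent checkpoint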
             load_path = max(all_files)
             from keras.models import load_model
             from models.custom_layers import all_custom_layers
             self.model = load_model(load_path,
                                     custom_objects=all_custom_layers)
             logger.log_report(load_path, 'loaded..', highlight=4)
             try:
                 # the epoch number occupies characters 2-7 of the file name
                 self.initial_epoch = int(
                     extract_file_name(load_path)[1][2:7])
             except ValueError:
                 pass
             return True
     logger.log_error(model_load_path, 'not found..', highlight=6)
     return False
Example #7
    def run_benchmark(self, key):
        """
        Run the benchmark on a single test pair.

        :param tuple key:
            A (file_name, correct_text, corrupt_text) triple
        :rtype: dict
        :returns:
            Evaluation dictionary of the metric evaluations of the fixed file.
        """

        fixer = self.fixer
        if NUM_THREADS > 1:
            # each pooled worker constructs its own fixer in this process
            fixer = construct_and_load_fixer(self.config)
        file_name, correct_text, corrupt_text = key
        correct_text = re.sub(r' +', ' ', correct_text).strip()
        corrupt_text = re.sub(r' +', ' ', corrupt_text).strip()
        corrupt_path = file_name

        fixed_path = os.path.join(self.get_timestamp_folder_name(),
                                  'fixed', file_name + '_fixed.txt')
        pklfixed_path = os.path.join(self.get_timestamp_folder_name(),
                                     'fixed', file_name + '_fixed.pkl')
        row = {}
        comparisons = []
        evaluator = MultiViewer()
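        # very long files are skipped and recorded with all-zero scores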
        if len(correct_text) >= 12000:
            logger.log_info("%s is too big, won't fix.." % corrupt_path)
            row[file_name] = [0] * 20  # all-zero placeholder for every metric column
            return row

        # Fixing
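        # results are cached per file as a pickle, so finished files are
        # skipped on re-runs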
        if os.path.isfile(pklfixed_path):
            logger.log_info("already done", pklfixed_path)
            with open(pklfixed_path, 'rb') as fl:
                row = pickle.load(fl)
            return row
        else:
            logger.log_info("with %s Fixing.. " % str(fixer) + corrupt_path + '\n')
            t0 = datetime.datetime.now()
            fixed_text = fixer.fix(corrupt_text)
            duration = datetime.datetime.now() - t0
            caption = 'fixing %s with %s' % (file_name, str(fixer))

            # comparing
            metrics_vals, comparison = evaluator.evaluate(
                correct_text, corrupt_text, fixed_text,
                modes=TERMINAL, caption=caption)
            row[file_name] = metrics_vals + (duration.total_seconds(),)
            comparisons.append(comparison)

            # Fixed file
            logger.log_info('Fixed %s with %s\n' % (corrupt_path, fixer),
                            comparison, 'with a duration of',
                            int(round(duration.total_seconds())), 's')
            with open_or_create_write_file(fixed_path) as fixed_file:
                fixed_file.write(fixed_text)
            logger.log_report('dumped fixed file into:', fixed_path)

            # metric results
            metric_comparison = evaluator.metric_comparison(
                row, self.metrics, modes=TERMINAL)
            comparisons.append(metric_comparison)
            logger.log_report(metric_comparison)

            # Terminal dumps (disabled)
            # comparisons = evaluator.merge_wrapped_pages(*comparisons,
            #                                             mode=TERMINAL)
            # comparison_path = os.path.join(self.get_timestamp_folder_name(),
            #                                'comparisons', file_name + '.fix')
            # with open_or_create_write_file(comparison_path, 'w') as output:
            #     output.write(cleanstr('Results of fixing: %s' % corrupt_path))
            #     output.write(cleanstr(comparisons))
            #     logger.log_report('dumped comparison into:', comparison_path)

            with open(pklfixed_path, 'wb') as fl:
                pickle.dump(row, fl)
            return row
Example #8
import multiprocessing

from tqdm import tqdm

from configs import get_dataset_config
from constants import NUM_THREADS
from handlers.reader import Reader
from utils.logger import logger
from utils.utils import get_vocab

if __name__ == '__main__':
    config = get_dataset_config()
    reader = Reader(config)
    total, gen = reader.read_train_lines()

    vocab = {}
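    # count words per training line in worker processes, then merge the
    # partial counts into one global vocabulary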
    with multiprocessing.Pool(NUM_THREADS) as pool:
        for cnt_dict in tqdm(pool.imap(get_vocab, gen), total=total):
            for word, cnt in cnt_dict.items():
                vocab[word] = vocab.get(word, 0) + cnt
    # keep only words seen more than twice, sorted alphabetically
    vocab = sorted((word, cnt) for word, cnt in vocab.items() if cnt > 2)

    logger.log_report("writing into:", config.vocab_path)
    with open(config.vocab_path, 'w') as fl:
        for word, cnt in tqdm(vocab):
            fl.write("%s\t%d\n" % (word, cnt))
    logger.log_report("done.. wrote all into:", config.vocab_path)