Example #1
0
class TaskCompactReports(object):

    description = 'Compacting reports ...'

    def setup(self, options):
        self.__printer = Printer(options)
        # sanitizer.category.count
        self.__data = {}

    def process(self, report):
        sanitizer = report.sanitizer.name_short
        category = report.category_name
        number_orig = report.number
        if sanitizer not in self.__data:
            self.__data[sanitizer] = {}
        if category not in self.__data[sanitizer]:
            self.__data[sanitizer][category] = 0
        self.__data[sanitizer][category] += 1
        number_new = self.__data[sanitizer][category]
        if number_new != number_orig:
            orig_report_str = str(report)
            new_file_path = utils.files.report_file_path(os.path.dirname(report.file_path), number_new)
            os.rename(report.file_path, new_file_path)
            report.number = number_new
            report.file_path = new_file_path
            self.__printer.task_info('renaming: ' + orig_report_str + ' -> ' + str(report))
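
The nested dict-of-dicts bookkeeping in process() can also be expressed with collections.defaultdict; a minimal, standalone sketch of the same counting idea (the names below are illustrative, not taken from the project):

from collections import defaultdict

# Per-sanitizer, per-category counters; missing keys default to 0.
counters = defaultdict(lambda: defaultdict(int))

def count(sanitizer, category):
    counters[sanitizer][category] += 1
    return counters[sanitizer][category]  # the report's new sequential number

# The third 'data race' report for 'tsan' gets number 3.
count('tsan', 'data race')
count('tsan', 'data race')
assert count('tsan', 'data race') == 3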
Example #2
0
def compute_top_n_tokens_for_each_doc(top_n, first_id, last_id):

    models.connect_to_db(conf.DATABASE_FILENAME)
    cleaner = Cleaner()
    top_n_tokens_per_paper = {}
    increments = 10  # number of paper ids fetched per batch

    for i in range(first_id, last_id + 1, increments):
        papers_to_process = ids_to_query(i, increments, last_id)
        for paper_id in papers_to_process:
            paper_query = models.Papers_NR.select().where(
                models.Papers_NR.id == paper_id)
            if DEBUG:
                print(paper_query)
                print(len(paper_query))

            if len(paper_query) > 0:
                paper_content = paper_query[0].paper_text
                pdf_name = paper_query[0].paper_name
                tokens = cleaner.tokenize(paper_content)
                token_frequencies = {}
                for token in tokens:
                    if token not in token_frequencies:
                        token_frequencies[token] = 1
                    else:
                        token_frequencies[token] = token_frequencies[token] + 1

                sorted_tokens = [(k, token_frequencies[k])
                                 for k in sorted(token_frequencies,
                                                 key=token_frequencies.get,
                                                 reverse=True)]
                top_n_tokens_per_paper[pdf_name] = sorted_tokens[:top_n]

    models.close_connection()
    printer = Printer()
    printer.print_dict(top_n_tokens_per_paper)
Example #3
0
class TSanReportExtractor(ReportExtractor):
    """Extract ThreadSanitizer reports"""

    __category_names = ['data race', 'thread leak']

    __start_line_pattern = re.compile(
        r'^warning: threadsanitizer: (?P<category>[a-z ]+) \(', re.IGNORECASE)
    __start_last_line_pattern = re.compile('^={18}$', re.MULTILINE)
    __end_line_pattern = __start_last_line_pattern

    def __init__(self, options):
        super(TSanReportExtractor,
              self).__init__(options, Sanitizer('ThreadSanitizer', 'tsan'))
        self.__printer = Printer(options)

    def collect(self):
        for category in self.__category_names:
            self._collect_reports(category)

    def extract(self, last_line, line):
        if self._extract_continue(line):
            if self.__end_line_pattern.search(line):
                self._extract_end()
        else:
            search = self.__start_line_pattern.search(line)
            if search and self.__start_last_line_pattern.search(last_line):
                category = search.group('category').lower()
                if category not in self.__category_names:
                    self.__printer.bailout('unknown category ' + repr(category))
                else:
                    self._extract_start(
                        self._make_and_add_report(True, category))
                    self._extract_continue(last_line + line)
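
For reference, the start pattern above is meant to match the first line of a ThreadSanitizer warning; a small standalone check (the sample line is an assumption about TSan's output format, not taken from these examples):

import re

start_line_pattern = re.compile(
    r'^warning: threadsanitizer: (?P<category>[a-z ]+) \(', re.IGNORECASE)

# A typical first report line; the pid value is made up.
sample = 'WARNING: ThreadSanitizer: data race (pid=12345)'
match = start_line_pattern.search(sample)
assert match is not None
assert match.group('category').lower() == 'data race'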
Example #4
0
 def _setup(self, options, sanitizer_name_short):
     self.__blacklist_file_path = os.path.join(
         options.output_root_path, self.__blacklists_dir_name,
         'clang-' + sanitizer_name_short + self.__blacklist_file_ending)
     # dir_name.file_name.func_name
     self.__printer = Printer(options)
     self.__data = OrderedDict()
Example #5
0
class TaskAnalyseReports(object):
    """Analyse reports and add (special) data"""

    description = 'Analysing reports ...'

    __tsan_data_race_global_location_pattern = re.compile(
        r'^  location is global \'(?P<global_location>.+)\' of size \d',
        re.IGNORECASE)

    def setup(self, options):
        self.__printer = Printer(options)
        self.__analysing_funcs = {
            'tsan': {
                'data race': self.__tsan_analyse_data_race
            }
        }

    def process(self, report):
        if self.__analysing_funcs.get(report.sanitizer.name_short,
                                      {}).get(report.category_name):
            self.__analysing_funcs[report.sanitizer.name_short][
                report.category_name](report)

    def __tsan_analyse_data_race(self, report):
        with open(report.file_path, 'r') as report_file:
            for line in report_file:
                search = self.__tsan_data_race_global_location_pattern.search(
                    line)
                if search:
                    self.__printer.task_info('found global location of ' +
                                             str(report))
                    report.special[
                        'tsan_data_race_global_location'] = search.group(
                            'global_location')
Example #6
0
class TaskSummary(object):
    """Print various stats about the reports"""
    def setup(self, options):
        self.__printer = Printer(options)
        # sanitizer_name.category_name.(new|old)
        self.__data = OrderedDict()

    def process(self, report):
        sanitizer_name = report.sanitizer.name
        category_name = report.category_name
        if sanitizer_name not in self.__data:
            self.__data[sanitizer_name] = OrderedDict()
        if category_name not in self.__data[sanitizer_name]:
            self.__data[sanitizer_name][category_name] = {'new': 0, 'old': 0}
        key = 'new' if report.is_new else 'old'
        self.__data[sanitizer_name][category_name][key] += 1

    def teardown(self):
        self.__printer.task_description('Summary:')
        if len(self.__data) < 1:
            self.__printer.task_info('nothing found')
        else:
            for sanitizer_name, categories in sorted(self.__data.items(),
                                                     key=lambda s: s[0]):
                self.__printer.just_print('  ' + sanitizer_name + ':')
                for category_name, count in sorted(categories.items(),
                                                   key=lambda c: c[0]):
                    new = count['new']
                    self.__printer.just_print('    ' + category_name + ': ' +
                                              str(count['old'] + new) + ' (' +
                                              str(new) + ' new)')
        self.__printer.nl()
Example #7
0
 def setup(self, options):
     self.__printer = Printer(options)
     self.__analysing_funcs = {
         'tsan': {
             'data race': self.__tsan_analyse_data_race
         }
     }
Example #8
0
 def __init__(self):
     # Init related objects
     self.app = App(self)
     self.installer = Installer(self)
     self.local_op = LocalOperations()
     self.remote_op = RemoteOperations(self)
     self.printer = Printer()
     self.connect()
     self.setup()
Example #9
0
 def setup(self, options):
     self.__root_dir_path = os.path.join(options.output_root_path, self.__root_dir_name)
     utils.files.makedirs(self.__root_dir_path, True)
     self.__printer = Printer(options)
     self.__skeletons = {}
     self.__add_funcs = {
         'tsan': {
             'data race': self.__add_tsan_data_race
         }
     }
Example #10
0
    def fit(self, train_x, train_y, test_x, test_y):
        if self.model is None:
            Printer.warning("Model was automatically built when fitting.")
            self.build()

        history = self.model.fit(x=train_x,
                                 y=train_y,
                                 batch_size=self.batch_size,
                                 epochs=self.n_epochs,
                                 validation_data=(test_x, test_y))
        return history
Example #11
0
 def setup(self, options):
     self.__printer = Printer(options)
     self.__duplicate_reports = []
     self.__identifiers_funcs = {
         'tsan': {
             'data race': self.__tsan_data_race_identifiers,
             'thread leak': self.__tsan_thread_leak_identifiers
         }
     }
     # TODO: split into separate lists for sanitizers and categories for better performance
     self.__known_identifiers = []
Example #12
0
 def __init__(self, options, sanitizer):
     self.__options = options
     self.__sanitizer = sanitizer
     self.__printer = Printer(options)
     # category.number
     self.__counters = {}
     self.__reports_dir_base_path = os.path.join(options.output_root_path,
                                                 self.__reports_dir_name,
                                                 sanitizer.name_short)
     self.__report_file = None
     self.__reports = []
Example #13
0
 def setup(self, options):
     self.__printer = Printer(options)
     self.__csv_base_dir_path = os.path.join(options.output_root_path,
                                             self.__csv_base_dir_name)
     utils.files.makedirs(self.__csv_base_dir_path)
     self.__controls = {
         'tsan': {
             'data race': {
                 'header_func': self.__header_tsan_data_race,
                 'process_func': self.__process_tsan_data_race
             }
         }
     }
Example #14
0
class Enhancitizer(object):
    def __init__(self, options):
        self.__options = options
        self.__printer = Printer(options)
        self.__tasks = []

    def add_tasks(self, task):
        """Add additional tasks that will be executed after the main tasks"""
        self.__tasks.append(task)

    def run(self):
        """Run the enhancitizer"""
        bank = ReportsBank(self.__options)
        self.__printer.welcome() \
                      .settings() \
                      .task_description('Collecting existing reports ...')
        watch = StopWatch().start()
        bank.collect_reports()
        self.__printer.task_info_debug('execution time: ' + str(watch)) \
                      .nl() \
                      .task_description('Extracting new reports ...')
        watch.start()
        for path in self.__options.logfiles_paths:
            bank.extract_reports(path)
        self.__printer.task_info_debug('execution time: ' + str(watch)).nl()
        tasks = [
            TaskEliminateDuplicateReports(bank),  # should be the first thing
            TaskCompactReports(),  # after the elimination, before "real" tasks
            TaskAnalyseReports(),  # after the elimination, before "real" tasks
            TaskCreateTSanBlacklist(),
            TaskBuildSkeleton(),
            TaskCreateCsvSummaries(),
            TaskAddTSanContext(),  # should run late to speed up stack parsing of the previous tasks
            TaskSummary()  # should be the last thing
        ]
        tasks.extend(self.__tasks)
        for task in tasks:
            watch.start()
            if hasattr(task, 'description'):
                self.__printer.task_description(task.description)
            if hasattr(task, 'setup'):
                task.setup(self.__options)
            if hasattr(task, 'process'):
                for report in bank:
                    task.process(report)
            if hasattr(task, 'teardown'):
                task.teardown()
            if hasattr(task, 'description'):
                self.__printer.task_info_debug('execution time: ' +
                                               str(watch)).nl()
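
A hedged usage sketch for the class above: run() only needs an options object exposing the attributes that it and its collaborators read; the Options fields and the extra task below are illustrative assumptions, not project API:

# Hypothetical options object and extra task, illustrating the
# setup/process/teardown protocol that run() drives via hasattr().
class Options(object):
    output_root_path = 'out'
    logfiles_paths = ['build/tests.log']

class TaskCountReports(object):

    description = 'Counting reports ...'

    def setup(self, options):
        self.count = 0

    def process(self, report):
        self.count += 1

    def teardown(self):
        print(str(self.count) + ' reports seen')

enhancitizer = Enhancitizer(Options())
enhancitizer.add_tasks(TaskCountReports())  # runs after the built-in tasks
enhancitizer.run()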
Example #15
0
    def __init__(self, options: ModelOptions, n_classes: int):
        self.name = 'simple_nn'
        self.width = options.image_width
        self.height = options.image_height
        self.learning_rate = options.learning_rate
        self.n_epochs = options.n_epochs
        self.batch_size = options.batch_size
        self.n_classes = n_classes
        self.model = None

        if not options.flatten:
            Printer.error(
                "Images have to be flattened before being sent to Small VGGNet")
            exit()
Example #16
0
    def build(self):
        """https://arxiv.org/abs/1602.07261"""

        if self.height != 299 or self.width != 299:
            Printer.error("ImageNet requires image sizes to be 299x299")
            exit()

        input_shape = (self.height, self.width, self.depth)
        model = InceptionV4(num_classes=self.n_classes, dropout_prob=0.2)

        optimizer = SGD(lr=self.learning_rate,
                        decay=self.learning_rate / self.n_epochs)
        model.compile(loss='categorical_crossentropy',
                      optimizer=optimizer,
                      metrics=['accuracy'])
        self.model = model
Example #17
0
    def evaluation_report(self):
        """Generate report on test data"""

        predictions = self.neural_net.predict(self.sample_x)

        y_true = self.sample_y.argmax(axis=1)
        y_pred = predictions.argmax(axis=1)

        report = classification_report(y_true, y_pred, target_names=self.classes)

        report_output_path = os.path.join(self.output_dir,
                                          f'[{self.neural_net.name}]evaluation.txt')

        with open(report_output_path, 'w') as fw:
            fw.write(str(report))

        Printer.default(report)
Example #18
0
class TaskBuildSkeleton(object):

    description = 'Building the skeleton ...'

    __root_dir_name = 'skeleton'
    __tsan_data_race_max_stack_depth = 3

    def setup(self, options):
        self.__root_dir_path = os.path.join(options.output_root_path, self.__root_dir_name)
        utils.files.makedirs(self.__root_dir_path, True)
        self.__printer = Printer(options)
        self.__skeletons = {}
        self.__add_funcs = {
            'tsan': {
                'data race': self.__add_tsan_data_race
            }
        }

    def process(self, report):
        if report.category_name in self.__add_funcs.get(report.sanitizer.name_short, {}):
            self.__add_funcs[report.sanitizer.name_short][report.category_name](report)

    def __add(self, report, stack_frame_id, stack_frame):
        if stack_frame.complete:
            file_rel_path = stack_frame.src_file_rel_path
            line_num = stack_frame.line_num
            if file_rel_path not in self.__skeletons:
                self.__skeletons[file_rel_path] = Skeleton(stack_frame.src_file_path)
            self.__skeletons[file_rel_path].mark(
                stack_frame.line_num, stack_frame.char_pos, str(report) + ' - frame #' + str(stack_frame_id))

    def __add_tsan_data_race(self, report):
        for stack in report.call_stacks:
            if 'tsan_data_race_type' in stack.special:
                for i in range(min(len(stack.frames), self.__tsan_data_race_max_stack_depth)):
                    self.__add(report, i, stack.frames[i])

    def teardown(self):
        for src_file_path, skeleton in self.__skeletons.items():
            skeleton_file_path = os.path.join(self.__root_dir_path, src_file_path + '.skeleton')
            self.__printer.task_info(
                'creating ' + skeleton_file_path + ' (' + str(skeleton.marked_lines_count) + ' lines)')
            utils.files.makedirs(os.path.dirname(skeleton_file_path))
            with open(skeleton_file_path, 'w') as skeleton_file:
                skeleton.write(skeleton_file)
Example #19
0
    def main_menu(self):
        print("Main menu:")
        while True:
            print('Please select one of the following options:\n'
                  '(R -> register, E -> exit, P -> print, F -> find)')
            user_input = input()

            if user_input in ('r', 'R'):
                std = self.register_student()
                self._students.append(std)
                print("Registering a new student...")
                time.sleep(1)
                FileManager.write_file(r'files\students.txt', std)
                print("Done.")
            elif user_input in ('f', 'F'):
                self.find_student()
            elif user_input in ('p', 'P'):
                printer = Printer(self._students)
                # self.print_all_students()
                printer.show_printer_menu()
                printer.print_sorted_list(printer.get_user_input())
            else:
                print("Exiting program...")
                time.sleep(1)
                exit()
Example #20
0
class TaskAddTSanContext(object):

    description = 'Adding TSan context ...'

    __supported_sanitizer_name_short = 'tsan'
    __supported_category_names = ['data race']
    __max_stack_frames = 3 # max amount of lookups from the top of the call stack
    
    def setup(self, options):
        self.__printer = Printer(options)

    def process(self, report):
        if report.is_new and \
           report.sanitizer.name_short == self.__supported_sanitizer_name_short and \
           report.category_name in self.__supported_category_names:
            report_file_path = report.file_path
            buffer_file_path = report_file_path + '.buffer'
            self.__printer.task_info('adding context to ' + str(report))
            with open(buffer_file_path, 'w') as buffer_file:
                buffer_file.write('\n')
                for stack in report.call_stacks:
                    if 'tsan_data_race_type' in stack.special:
                        buffer_file.write(stack.title + '\n\n')
                        for i in range(min(len(stack.frames), self.__max_stack_frames)):
                            if stack.frames[i].complete:
                                # TODO: find full function signature
                                func_signature = stack.frames[i].func_name + '(...)'
                                line = SourceCodeLine(stack.frames[i].src_file_path, stack.frames[i].line_num)
                                if line.line:
                                    buffer_file.write(
                                        func_signature + ' {\n' +
                                        '  // ...\n' +
                                        '! ' + line.line + '\n' +
                                        (' ' * (stack.frames[i].char_pos - line.indent + 1)) + '^\n' +
                                        '  // ...\n' +
                                        '}\n\n')
                with open(report_file_path, 'r') as report_file:
                    for line in report_file:
                        buffer_file.write(line)
            os.remove(report_file_path)
            os.rename(buffer_file_path, report_file_path)
Example #21
0
class TaskCreateBlacklist(object):
    """Base class for creating blacklists; can be inherited to make more sense"""

    __blacklists_dir_name = 'blacklists'
    __blacklist_file_ending = '.blacklist'

    def _setup(self, options, sanitizer_name_short):
        self.__blacklist_file_path = os.path.join(
            options.output_root_path, self.__blacklists_dir_name,
            'clang-' + sanitizer_name_short + self.__blacklist_file_ending)
        # dir_name.file_name.func_name
        self.__printer = Printer(options)
        self.__data = OrderedDict()

    def _add_stack_frame(self, frame):
        if frame.complete:
            dir_name = frame.src_file_dir_rel_path
            file_name = frame.src_file_name
            func_name = frame.func_name
            if dir_name not in self.__data:
                self.__data[dir_name] = OrderedDict()
            if file_name not in self.__data[dir_name]:
                self.__data[dir_name][file_name] = []
            if func_name not in self.__data[dir_name][file_name]:
                self.__data[dir_name][file_name].append(func_name)
            self.__printer.task_info('adding ' + func_name + ' (' + frame.src_file_rel_path + ')')

    def teardown(self):
        self.__printer.task_info('creating ' + self.__blacklist_file_path)
        utils.files.makedirs(os.path.dirname(self.__blacklist_file_path))
        with open(self.__blacklist_file_path, 'w') as blacklist_file:
            for dir_name, files in sorted(self.__data.items(), key=lambda d: d[0]):
                blacklist_file.write(
                    '# --------------------------------------------------------------------------- #\n' +
                    '#   ' + dir_name + (' ' * (74 - len(dir_name))) + '#\n' +
                    '# --------------------------------------------------------------------------- #\n\n')
                for file_name, func_names in sorted(files.items(), key=lambda f: f[0]):
                    blacklist_file.write('# ' + file_name + ' #\n\n')
                    for func_name in sorted(func_names):
                        blacklist_file.write('fun:' + func_name + '\n')
                    blacklist_file.write('\n')
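
The base class above is meant to be subclassed; a minimal, hypothetical subclass (the real TaskCreateTSanBlacklist is not shown in these examples) could wire it up like this:

# Hypothetical subclass; which stack frames get blacklisted is an assumption here.
class TaskCreateMyTSanBlacklist(TaskCreateBlacklist):

    description = 'Creating a TSan blacklist ...'

    def setup(self, options):
        self._setup(options, 'tsan')

    def process(self, report):
        # e.g. blacklist the top frame of every call stack of the report
        for stack in report.call_stacks:
            if stack.frames:
                self._add_stack_frame(stack.frames[0])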
Example #22
0
    def fit(self, train_x, train_y, test_x, test_y):
        if self.model is None:
            Printer.warning("Model was automatically built when fitting.")
            self.build()

        data_augmenter = ImageDataGenerator(rotation_range=30,
                                            width_shift_range=0.1,
                                            height_shift_range=0.1,
                                            shear_range=0.2,
                                            zoom_range=0.2,
                                            horizontal_flip=True,
                                            fill_mode='nearest')

        steps_per_epoch = len(train_x) // self.batch_size
        history = self.model.fit_generator(data_augmenter.flow(
            train_x, train_y, batch_size=self.batch_size),
                                           validation_data=(test_x, test_y),
                                           steps_per_epoch=steps_per_epoch,
                                           epochs=self.n_epochs)
        return history
Example #23
0
    @staticmethod
    def from_json(json):
        options = ModelOptions()

        try:
            options.model_name = json['model_name']
            options.random_seed = int(json['random_seed'])
            options.dataset_directory = json['dataset_directory']
            options.image_extensions = set(json['image_extensions'].split())
            options.image_width = int(json['image_width'])
            options.image_height = int(json['image_height']) if json['image_height'] \
                else options.image_width
            options.flatten = (str(json['flatten']) == 'True')
            options.test_percentage = float(json['test_percentage'])
            options.learning_rate = float(json['learning_rate'])
            options.n_epochs = int(json['n_epochs'])
            options.batch_size = int(json['batch_size'])
            options.output_dir = json['output_dir']
        except ValueError:
            Printer.error("Input data parsing failed.")

        return options
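
A small usage sketch, assuming the options live in a JSON file on disk; the file name is illustrative and Python's standard json module is used for loading:

import json

# Hypothetical config file; its keys must match the ones read in from_json above.
with open('model_options.json') as config_file:
    options = ModelOptions.from_json(json.load(config_file))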
Example #24
0
    def fit(self):
        printer = Printer()

        while self.epoch < self.epochs:
            self.new_epoch()  # increments self.epoch

            self.logger.set_mode("train")
            stats = self.train_epoch(self.epoch)
            self.logger.log(stats, self.epoch)
            printer.print(stats, self.epoch, prefix="\ntrain: ")

            if (self.epoch % self.test_every_n_epochs == 0
                    or self.phase1_will_end() or self.phase2_will_end()):
                self.logger.set_mode("test")
                stats = self.test_epoch(self.validdataloader)
                self.logger.log(stats, self.epoch)
                printer.print(stats, self.epoch, prefix="\nvalid: ")
                if hasattr(self, "visdom"): self.visdom_log_test_run(stats)

            if hasattr(self, "visdom"):
                self.visdom.plot_epochs(self.logger.get_data())

        self.check_events()

        # store all logged values in the logger's root path
        #if self.save_logger:
        self.logger.save()

        return self.logger.data
Example #25
0
def compute_document_frequencies():

    models.connect_to_db(conf.DATABASE_FILENAME)
    first_id = 1
    last_id_query = papers.select().order_by(papers.id.desc()).limit(1)
    last_id = last_id_query[0].id
    increments = 10

    token_frequencies = {}

    for i in range(first_id, last_id + 1, increments):
        papers_to_process = ids_to_query(i, increments, last_id)
        for paper_id in papers_to_process:
            paper_query = papers.select().where(papers.id == paper_id)

            unique_tokens = set()
            
            if DEBUG:
                print(paper_query)
                print(len(paper_query))

            if len(paper_query) > 0:
                paper_content = paper_query[0].paper_text
                paper_pdf_name = paper_query[0].pdf_name
                tokens = paper_content.strip().split()
                for token in tokens:
                    #print(token)
                    unique_tokens.add(token.lower())

                for token in unique_tokens:
                    #print(token)
                    if token not in token_frequencies:
                        token_frequencies[token] = 1
                    else:
                        token_frequencies[token] = token_frequencies[token] + 1
                
    models.close_connection()
    sorted_tokens = [(k, token_frequencies[k]) for k in sorted(token_frequencies, key=token_frequencies.get)]
    printer = Printer()
    printer.print_token_frequency(sorted_tokens)
Example #26
0
def compute_top_n_tokens_for_collection(top_n):

    models.connect_to_db(conf.DATABASE_FILENAME)
    first_id = 1
    last_id_query = models.Papers_NR.select().order_by(
        models.Papers_NR.id.desc()).limit(1)
    last_id = last_id_query[0].id
    increments = 10

    cleaner = Cleaner()
    token_frequencies = {}

    for i in range(first_id, last_id + 1, increments):
        papers_to_process = ids_to_query(i, increments, last_id)
        for paper_id in papers_to_process:
            paper_query = models.Papers.select().where(
                models.Papers.id == paper_id)

            if DEBUG:
                print(paper_query)
                print(len(paper_query))

            if len(paper_query) > 0:
                paper_content = paper_query[0].paper_text
                paper_pdf_name = paper_query[0].pdf_name
                tokens = cleaner.tokenize(paper_content)
                for token in tokens:
                    if token not in token_frequencies:
                        token_frequencies[token] = 1
                    else:
                        token_frequencies[token] = token_frequencies[token] + 1

    models.close_connection()
    sorted_tokens = [(k, token_frequencies[k]) for k in sorted(
        token_frequencies, key=token_frequencies.get, reverse=True)]
    top_n_tokens = sorted_tokens[:top_n]
    printer = Printer()
    printer.print_token_frequency(top_n_tokens)
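
The manual frequency dictionary used in these functions can be written more compactly with collections.Counter; a standalone sketch of the same counting and top-n selection, without the database loop:

from collections import Counter

def top_n_from_tokens(tokens, top_n):
    # most_common returns (token, count) pairs sorted by count, descending.
    return Counter(tokens).most_common(top_n)

# e.g. top_n_from_tokens(cleaner.tokenize(paper_content), 10)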
Example #27
0
    def fit(self):
        printer = Printer()

        while self.epoch < self.epochs:
            self.new_epoch()  # increments self.epoch

            self.logger.set_mode("train")
            stats = self.train_epoch(self.epoch)
            self.logger.log(stats, self.epoch)
            printer.print(stats,
                          self.epoch,
                          prefix="\n" +
                          self.traindataloader.dataset.partition + ": ")

            if self.epoch % self.test_every_n_epochs == 0 or self.epoch == 1:
                self.logger.set_mode("test")
                stats = self.test_epoch(self.validdataloader)
                self.logger.log(stats, self.epoch)
                printer.print(stats,
                              self.epoch,
                              prefix="\n" +
                              self.validdataloader.dataset.partition + ": ")
                if self.visdom is not None:
                    self.visdom_log_test_run(stats)

            if self.visdom is not None:
                self.visdom.plot_epochs(self.logger.get_data())

            if self.epoch % self.checkpoint_every_n_epochs == 0:
                print("Saving model to {}".format(self.get_model_name()))
                self.snapshot(self.get_model_name())
                print("Saving log to {}".format(self.get_log_name()))
                self.logger.get_data().to_csv(self.get_log_name())

            if self.epoch > self.early_stopping_smooth_period and self.check_for_early_stopping(
                    smooth_period=self.early_stopping_smooth_period):
                print()
                print(
                    f"Model did not improve in the last {self.early_stopping_smooth_period} epochs. Stopping training..."
                )
                print("Saving model to {}".format(self.get_model_name()))
                self.snapshot(self.get_model_name())
                print("Saving log to {}".format(self.get_log_name()))
                self.logger.get_data().to_csv(self.get_log_name())
                return self.logger

        return self.logger
Example #28
0
    def _get_image_paths(self):
        """Return a list of image files inside the `base_path` """

        image_paths = []
        for (dir_path, _, file_names) in os.walk(self.base_path):
            for file_name in file_names:
                if os.extsep not in file_name:
                    Printer.warning(
                        "Files without extension found: {}".format(file_name))
                    continue
                extension = file_name.split(os.extsep)[-1]
                if extension not in self.image_extensions:
                    Printer.warning(
                        "Non-image files found: {}".format(file_name))
                    continue
                if '.ipynb' in dir_path:
                    Printer.warning("IPyNb caches found: {}".format(file_name))
                    continue
                image_paths.append(os.path.join(dir_path, file_name))

        random.shuffle(image_paths)

        Printer.information(f"Found {len(image_paths)} images")
        return image_paths
Example #29
0
    def fit(self):
        printer = Printer()

        while self.epoch < self.epochs:
            self.new_epoch()  # increments self.epoch

            self.logger.set_mode("train")
            stats = self.train_epoch(self.epoch)
            self.logger.log(stats, self.epoch)
            printer.print(stats, self.epoch, prefix="\ntrain: ")

            if (self.epoch % self.test_every_n_epochs == 0
                    or self.phase1_will_end() or self.phase2_will_end()):
                self.logger.set_mode("test")
                stats = self.test_epoch(self.epoch)
                self.logger.log(stats, self.epoch)
                printer.print(stats, self.epoch, prefix="\nvalid: ")
                self.visdom_log_test_run(stats)

            self.visdom.plot_epochs(self.logger.get_data())

        self.check_events()
        return self.logger.data
Example #30
0
    def fit(self, epochs):
        printer = Printer()

        while self.epoch < epochs:
            self.new_epoch()  # increments self.epoch

            self.logger.set_mode("train")
            stats = self.train_epoch(self.epoch)
            self.logger.log(stats, self.epoch)
            printer.print(stats, self.epoch, prefix="\ntrain: ")

            if self.epoch % self.test_every_n_epochs == 0:
                self.logger.set_mode("test")
                stats = self.test_epoch(self.testdataloader)
                self.logger.log(stats, self.epoch)
                printer.print(stats, self.epoch, prefix="\ntest: ")
                if hasattr(self, "visdom"): self.visdom_log_test_run(stats)

            if hasattr(self, "visdom"):
                self.visdom.plot_epochs(self.logger.get_data())

        self.logger.save()

        return self.logger.data