Example #1
0
    def learn(self, input_dir):
        """Learn languages features from source files.

        The source files found in ``input_dir`` are split into batches by
        ``_pop_many``. The first batch is held out as evaluation data. Each
        following batch is used for one fitting pass, after which the
        classifier is re-evaluated on the held-out data.

        :raise GuesslangError: when the default model is used for learning,
            or when ``input_dir`` does not contain enough matching source
            files to build an evaluation batch
        :param str input_dir: source code files directory.
        :return: accuracy measured on the evaluation data after the last
            fitting pass (``0.0`` when no fitting pass was run)
        :rtype: float
        """
        if self.is_default:
            LOGGER.error("Cannot learn using default model")
            raise GuesslangError('Cannot learn using default "readonly" model')

        languages = self.languages

        LOGGER.info("Extract training data")
        # Flatten every extension list held in the ``languages`` mapping.
        extensions = [ext for exts in languages.values() for ext in exts]
        files = search_files(input_dir, extensions)
        nb_files = len(files)
        chunk_size = min(int(_CHUNK_PROPORTION * nb_files), _CHUNK_SIZE)

        # With no files (or too few of them) the proportion truncates to 0:
        # there would be no evaluation data, and ``nb_files / chunk_size``
        # below would raise ZeroDivisionError. Fail with a clear error instead.
        if chunk_size < 1:
            LOGGER.error(
                "Not enough source files to learn from: %d found", nb_files)
            raise GuesslangError(
                'Not enough source files to learn from in {}: {} found'.format(
                    input_dir, nb_files))

        LOGGER.debug("Evaluation files count: %d", chunk_size)
        LOGGER.debug("Training files count: %d", nb_files - chunk_size)
        batches = _pop_many(files, chunk_size)

        LOGGER.debug("Prepare evaluation data")
        # The first batch is reserved for evaluation; the rest train the model.
        evaluation_data = extract_from_files(next(batches), languages)
        LOGGER.debug("Evaluation data count: %d", len(evaluation_data[0]))

        accuracy = 0.0
        # Number of training batches: all batches minus the evaluation one.
        total = ceil(nb_files / chunk_size) - 1
        LOGGER.info("Start learning")
        for pos, training_files in enumerate(batches, 1):
            LOGGER.info("Step %.2f%%", 100 * pos / total)

            LOGGER.debug("Training data extraction")
            training_data = extract_from_files(training_files, languages)
            LOGGER.debug("Training data count: %d", len(training_data[0]))

            # Fitting steps scale with the amount of training data in the batch.
            steps = int(_FITTING_FACTOR * len(training_data[0]) / 100)
            LOGGER.debug("Fitting, steps count: %d", steps)
            self._classifier.fit(input_fn=_to_func(training_data), steps=steps)

            LOGGER.debug("Evaluation")
            accuracy = self._classifier.evaluate(
                input_fn=_to_func(evaluation_data), steps=1)['accuracy']
            _comment(accuracy)

        return accuracy
Example #2
0
def test_extract_from_files():
    """Check extract_from_files output shape on fixtures copied 10 times.

    The result must hold exactly two arrays, each containing 20 entries.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        copy_fixtures(tmp_dir, nb_times=10)

        found_files = utils.search_files(tmp_dir, EXTENSIONS)
        extracted = utils.extract_from_files(found_files, LANGUAGES)

        assert len(extracted) == 2
        for column in extracted:
            assert len(column) == 20