Exemplo n.º 1
0
class EncryptedOKTest(models.Test):
    """Stand-in for an OK test whose contents are still encrypted.

    Every operation first reports that the test is encrypted and prompts the
    user for decryption keys (see ``warn``).
    """

    name = core.String()
    points = core.Float()
    partner = core.String(optional=True)

    def warn(self, method):
        """Report that *method* cannot proceed and ask for decryption keys.

        The pasted text is split on whitespace. When at least one key was
        entered, ``ex.ForceDecryptionException`` is raised carrying the key
        list; when nothing was entered, this simply returns.
        """
        print_error("Cannot {} {}: test is encrypted".format(method, self.name))
        # str.split() with no separator already ignores surrounding whitespace.
        entered_keys = input("Please paste the key to decrypt this test: ").split()
        if not entered_keys:
            return
        raise ex.ForceDecryptionException(entered_keys)

    def run(self, env):
        """Prompt for keys; otherwise report the test as one failure."""
        self.warn('run')
        return {'failed': 1, 'locked': 0, 'passed': 0}

    def score(self):
        """Prompt for keys; an encrypted test scores nothing."""
        self.warn('score')
        return 0

    def unlock(self, interact):
        """Prompt for keys; an encrypted test cannot be unlocked."""
        self.warn('unlock')

    def lock(self, hash_fn):
        """Prompt for keys; an encrypted test cannot be locked."""
        self.warn('lock')

    def dump(self):
        """Prompt for keys; an encrypted test cannot be saved."""
        self.warn('save the test')
Exemplo n.º 2
0
class Test(core.Serializable):
    """Abstract base for a single gradable test.

    Every method below must be overridden by a subclass; the base versions
    only raise NotImplementedError.
    """

    name = core.String()
    points = core.Float()
    partner = core.String(optional=True)

    def run(self, env):
        """Subclasses should override this method to run tests.

        NOTE: env is intended only for use with the programmatic API for
        Python OK tests.
        """
        raise NotImplementedError

    def score(self):
        """Subclasses should override this method to score the test."""
        raise NotImplementedError

    def unlock(self, interact):
        """Subclasses should override this method to unlock the test."""
        raise NotImplementedError

    def lock(self, hash_fn):
        """Subclasses should override this method to lock the test."""
        raise NotImplementedError

    def dump(self):
        """Subclasses should override this method for serialization."""
        raise NotImplementedError
Exemplo n.º 3
0
class DoctestSuite(models.Suite):
    """Suite whose cases are code cases run in a shared Python console.

    Each raw case dict is converted into an ``interpreter.CodeCase`` that uses
    this suite's console together with its ``setup``/``teardown`` code.
    """

    setup = core.String(default='')
    teardown = core.String(default='')

    console_type = pyconsole.PythonConsole

    def __init__(self, verbose, interactive, timeout=None, **fields):
        super().__init__(verbose, interactive, timeout, **fields)
        # One console instance, handed to every case built below.
        self.console = self.console_type(verbose, interactive, timeout)

    def post_instantiation(self):
        """Replace each raw case dict in ``self.cases`` with a CodeCase.

        Raises:
            ex.SerializeException: if a case is not a dict.
        """
        for position, raw_case in enumerate(self.cases):
            if isinstance(raw_case, dict):
                self.cases[position] = interpreter.CodeCase(
                    self.console, self.setup, self.teardown, **raw_case)
                continue
            raise ex.SerializeException('Test cases must be dictionaries')
Exemplo n.º 4
0
class MockSerializable(core.Serializable):
    """Serializable with a mix of plain, defaulted, and optional fields.

    Presumably a fixture for exercising core.Serializable -- confirm against
    the tests that use it.
    """
    # Plain int class constant (not declared via core); var2's default.
    TEST_INT = 2

    var1 = core.Boolean()  # no default and not marked optional
    var2 = core.Int(default=TEST_INT)
    var3 = core.String(optional=True)
    var4 = core.List(optional=True)
Exemplo n.º 5
0
class DoctestSuite(models.Suite):
    """Suite whose cases are code cases run in a shared Python console.

    Locked cases are skipped by default (``skip_locked_cases`` starts True).
    """

    setup = core.String(default='')
    teardown = core.String(default='')

    console_type = pyconsole.PythonConsole

    # A hack that allows OkTest to identify DoctestSuites without circular
    # imports.
    doctest_suite_flag = True

    def __init__(self, test, verbose, interactive, timeout=None, **fields):
        super().__init__(test, verbose, interactive, timeout, **fields)
        self.skip_locked_cases = True
        # One console instance, handed to every case built in
        # post_instantiation.
        self.console = self.console_type(verbose, interactive, timeout)

    def post_instantiation(self):
        """Replace each raw case dict in ``self.cases`` with a CodeCase.

        Raises:
            ex.SerializeException: if a case is not a dict.
        """
        for position, raw_case in enumerate(self.cases):
            if isinstance(raw_case, dict):
                self.cases[position] = interpreter.CodeCase(
                    self.console, self.setup, self.teardown, **raw_case)
                continue
            raise ex.SerializeException('Test cases must be dictionaries')
Exemplo n.º 6
0
class Suite(core.Serializable):
    """Base class for a group of cases belonging to one test."""

    type = core.String()
    scored = core.Boolean(default=True)
    cases = core.List()

    def __init__(self, verbose, interactive, timeout=None, **fields):
        super().__init__(**fields)
        self.verbose = verbose
        self.interactive = interactive
        self.timeout = timeout

    def run(self, test_name, suite_number):
        """Subclasses should override this method to run tests.

        PARAMETERS:
        test_name    -- str; name of the parent test.
        suite_number -- int; suite number, assumed to be 1-indexed.

        RETURNS:
        dict; results of the following form:
        {
            'passed': int,
            'failed': int,
            'locked': int,
        }
        """
        raise NotImplementedError

    def _run_case(self, test_name, suite_number, case, case_number):
        """A wrapper for case.run().

        Prints informative output and also captures output of the test case
        and returns it as a log. The output is suppressed -- it is up to the
        calling function to decide whether or not to print the log.

        RETURNS:
        tuple; (success, output_log) where success is the value returned by
        case.run() and output_log is the captured output.
        """
        output.off()    # Delay printing until case status is determined.
        log_id = output.new_log()
        try:
            format.print_line('-')
            print('{} > Suite {} > Case {}'.format(test_name, suite_number,
                                                   case_number))
            print()

            success = case.run()
            if success:
                print('-- OK! --')
        finally:
            # Restore output and discard the temporary log even when
            # case.run() raises, so an exception does not leave output
            # switched off for everything that follows.
            output.on()
            output_log = output.get_log(log_id)
            output.remove_log(log_id)

        return success, output_log
Exemplo n.º 7
0
class ConceptCase(common_models.Case):
    """A conceptual question/answer case; no code is executed."""

    question = core.String()
    answer = core.String()
    choices = core.List(type=str, optional=True)

    def post_instantiation(self):
        """Dedent and strip the question, the answer, and any choices."""
        def tidy(text):
            return textwrap.dedent(text).strip()

        self.question = tidy(self.question)
        self.answer = tidy(self.answer)

        if self.choices != core.NoValue:
            # Normalize in place so the same list object is kept.
            for idx, raw_choice in enumerate(self.choices):
                self.choices[idx] = tidy(raw_choice)

    def run(self):
        """Runs the conceptual test case by printing its question and answer.

        RETURNS:
        bool; always True -- printing the question and answer cannot fail.
        """
        for label, text in (('Q: ', self.question), ('A: ', self.answer)):
            print(label + text)
        return True

    def lock(self, hash_fn):
        """Replace the answer with ``hash_fn(answer)`` and mark the case locked."""
        self.answer = hash_fn(self.answer)
        self.locked = True

    def unlock(self, unique_id_prefix, case_id, interact):
        """Unlocks the conceptual test case.

        ``interact`` is called with the prompt, the case id, the question, the
        (hashed) answer list and the choices, and must return exactly one
        answer. When that answer differs from the stored one, it replaces the
        stored answer and the case is marked unlocked.
        """
        print('Q: ' + self.question)
        prompt = unique_id_prefix + '\n' + self.question
        responses = interact(prompt, case_id, self.question, [self.answer],
                             self.choices)
        assert len(responses) == 1
        response = responses[0]
        if response == self.answer:
            return
        # Answer was presumably unlocked
        self.locked = False
        self.answer = response
Exemplo n.º 8
0
class Test(core.Serializable):
    """Abstract base for a single gradable test.

    Every method below must be overridden by a subclass; the base versions
    only raise NotImplementedError.
    """

    name = core.String()
    points = core.Float()
    partner = core.String(optional=True)

    def run(self):
        """Subclasses should override this method to run tests."""
        raise NotImplementedError

    def score(self):
        """Subclasses should override this method to score the test."""
        raise NotImplementedError

    def unlock(self, interact):
        """Subclasses should override this method to unlock the test."""
        raise NotImplementedError

    def lock(self, hash_fn):
        """Subclasses should override this method to lock the test."""
        raise NotImplementedError

    def dump(self):
        """Subclasses should override this method for serialization."""
        raise NotImplementedError
Exemplo n.º 9
0
class Assignment(core.Serializable):
    """An OK assignment: its tests, its protocols, and the tests to run."""

    name = core.String()
    endpoint = core.String()
    src = core.List(type=str, optional=True)
    tests = core.Dict(keys=str, values=str, ordered=True)
    default_tests = core.List(type=str, optional=True)
    protocols = core.List(type=str)

    ####################
    # Programmatic API #
    ####################

    def grade(self, question, env=None, skip_locked_cases=False):
        """Runs tests for a particular question. The setup and teardown will
        always be executed.

        question -- str; a question name (as would be entered at the command
                    line)
        env      -- dict; an environment in which to execute the tests. If
                    None, uses the environment of __main__. The original
                    dictionary is never modified; each test is given a
                    duplicate of env.
        skip_locked_cases -- bool; if False, locked cases will be tested

        Returns: dict; the results for the requested test, read from
        messages['grading'][test_name] after the module-level grade()
        function has filled it in. It contains the following fields:
        - "passed": int (number of test cases passed)
        - "failed": int (number of test cases failed)
        - "locked": int (number of test cases locked)
        """
        if env is None:
            import __main__
            env = __main__.__dict__
        messages = {}
        tests = self._resolve_specified_tests([question], all_tests=False)
        for test in tests:
            # Only some tests have suites and only some suites have a console.
            # Configure whatever is present on each suite instead of giving
            # up on the remaining suites at the first one without a console.
            for suite in getattr(test, 'suites', ()):
                suite.skip_locked_cases = skip_locked_cases
                console = getattr(suite, 'console', None)
                if console is not None:
                    console.skip_locked_cases = skip_locked_cases
                    console.hash_key = self.name
        test_name = tests[0].name
        # Module-level grade() function, not this method.
        grade(tests, messages, env)
        return messages['grading'][test_name]

    ############
    # Internal #
    ############

    _TESTS_PACKAGE = 'client.sources'
    _PROTOCOL_PACKAGE = 'client.protocols'

    def __init__(self, args, **fields):
        # NOTE(review): fields are not forwarded to super().__init__ here;
        # presumably core.Serializable populates them itself -- confirm.
        self.cmd_args = args
        self.test_map = collections.OrderedDict()
        self.protocol_map = collections.OrderedDict()

    def post_instantiation(self):
        """Print the header, load tests and protocols, and pick the tests."""
        self._print_header()
        self._load_tests()
        self._load_protocols()
        self.specified_tests = self._resolve_specified_tests(
            self.cmd_args.question, self.cmd_args.all)

    def _load_tests(self):
        """Loads all tests specified by ``self.tests`` into ``self.test_map``.

        Each key of ``self.tests`` is a glob pattern, optionally followed by
        ``:parameter``; its value names a module under ``client.sources``
        whose ``load(file, parameter, assignment)`` result is merged into
        ``self.test_map``.

        Raises:
            ex.LoadingException: if a source module cannot be imported, or if
            no tests were loaded at all.
        """
        log.info('Loading tests')
        for file_pattern, source in self.tests.items():
            # Separate filepath and parameter
            if ':' in file_pattern:
                file_pattern, parameter = file_pattern.split(':', 1)
            else:
                parameter = ''

            for file in sorted(glob.glob(file_pattern)):
                try:
                    module = importlib.import_module(self._TESTS_PACKAGE +
                                                     '.' + source)
                except ImportError:
                    raise ex.LoadingException(
                        'Invalid test source: {}'.format(source))

                test_name = file
                if parameter:
                    test_name += ':' + parameter

                self.test_map.update(module.load(file, parameter, self))

                log.info('Loaded {}'.format(test_name))

        if not self.test_map:
            raise ex.LoadingException('No tests loaded')

    def dump_tests(self):
        """Dumps all loaded tests, as determined by their .dump() method.

        Each test in ``self.test_map`` is dumped via ``test.dump()``. A
        ``ex.SerializeException`` from one test is logged as a warning and
        does not stop the remaining tests from being dumped.
        """
        log.info('Dumping tests')
        for test in self.test_map.values():
            try:
                test.dump()
            except ex.SerializeException as e:
                log.warning('Unable to dump {}: {}'.format(test.name, str(e)))
            else:
                log.info('Dumped {}'.format(test.name))

    def _resolve_specified_tests(self, questions, all_tests=False):
        """For each of the questions specified on the command line,
        find the test corresponding to that question.

        Questions are preserved in the order that they are specified on the
        command line; a question given more than once is included once. If
        no questions are specified, the default tests are used (unless
        all_tests is set or there are none), otherwise the entire set of
        tests.

        Raises:
            ex.LoadingException: if a question does not name a loaded test.
        """
        if not questions and not all_tests \
                and self.default_tests != core.NoValue \
                and len(self.default_tests) > 0:
            log.info('Using default tests (no questions specified): '
                     '{}'.format(self.default_tests))
            return [self.test_map[test] for test in self.default_tests]
        elif not questions:
            log.info(
                'Using all tests (no questions specified and no default tests)'
            )
            return list(self.test_map.values())
        elif not self.test_map:
            log.info('No tests loaded')
            return []

        specified_tests = []
        for question in questions:
            if question not in self.test_map:
                print('Test "{}" not found.'.format(question))
                print('Did you mean one of the following? '
                      '(Names are case sensitive)')
                for test in self.test_map:
                    print('    {}'.format(test))
                raise ex.LoadingException(
                    'Invalid test specified: {}'.format(question))

            log.info('Adding {} to specified tests'.format(question))
            selected = self.test_map[question]
            # Compare Test objects (not the question string) so a question
            # listed twice is only run once.
            if selected not in specified_tests:
                specified_tests.append(selected)
        return specified_tests

    def _load_protocols(self):
        """Import and instantiate each protocol from ``client.protocols``.

        A protocol whose import (or construction) raises ImportError is
        skipped with a debug log message.
        """
        log.info('Loading protocols')
        for proto in self.protocols:
            try:
                module = importlib.import_module(self._PROTOCOL_PACKAGE + '.' +
                                                 proto)
                self.protocol_map[proto] = module.protocol(self.cmd_args, self)
                log.info('Loaded protocol "{}"'.format(proto))
            except ImportError:
                log.debug('Skipping unknown protocol "{}"'.format(proto))

    def _print_header(self):
        """Print the assignment name and client version banner."""
        format.print_line('=')
        print('Assignment: {}'.format(self.name))
        print('OK, version {}'.format(client.__version__))
        format.print_line('=')
        print()
Exemplo n.º 10
0
class Doctest(models.Test):
    """A test whose cases are the doctest examples in a docstring."""

    docstring = core.String()

    PS1 = '>>> '
    PS2 = '... '

    IMPORT_STRING = 'from {} import *'
    SETUP = PS1 + IMPORT_STRING
    # Raw string: '\.' inside a normal literal is an invalid escape sequence
    # (DeprecationWarning; SyntaxWarning since Python 3.12). The resulting
    # pattern is unchanged.
    prompt_re = re.compile(r'(\s*)({}|{})'.format(PS1, r'\.\.\. '))

    def __init__(self, file, verbose, interactive, timeout=None, **fields):
        super().__init__(**fields)
        self.file = file
        self.verbose = verbose
        self.interactive = interactive
        self.timeout = timeout

        self.console = pyconsole.PythonConsole(self.verbose, self.interactive,
                                               self.timeout)

    def post_instantiation(self):
        """Extract the interactive examples from the docstring into a CodeCase.

        Prompt lines (``>>> `` / ``... ``) are kept with their leading
        whitespace stripped. Output lines that follow a prompt keep any
        indentation beyond the prompt's own leading whitespace. A blank line
        ends the current example. Non-prompt lines outside an example are
        ignored.

        Raises:
            ex.SerializeException: if lines within one example are not
            indented consistently.
        """
        # TODO(albert): rewrite test validation. Inconsistent leading space is
        # currently not validated correctly (see tests).
        self.docstring = textwrap.dedent(self.docstring)
        code = []
        prompt_on = False
        leading_space = ''
        for line in self.docstring.split('\n'):
            prompt_match = self.prompt_re.match(line)
            if prompt_match:
                if prompt_on and not line.startswith(leading_space):
                    raise ex.SerializeException('Inconsistent tabs for doctest')
                elif not prompt_on:
                    prompt_on = True
                    leading_space = prompt_match.group(1)
                code.append(line.lstrip())
            elif line.endswith('...'):
                # A line consisting only of ... is treated as a noop (in
                # fact any non-prompt line ending in '...' is skipped). See
                # issue #46
                continue
            elif not line.strip():
                prompt_on = False
                leading_space = ''
            elif prompt_on:
                if not line.startswith(leading_space):
                    raise ex.SerializeException('Inconsistent tabs for doctest')
                code.append(line[len(leading_space):])
        module = self.SETUP.format(importing.path_to_module_string(self.file))
        self.case = interpreter.CodeCase(self.console, module,
                                         code='\n'.join(code))

    def run(self, env):
        """Runs the suites associated with this doctest.

        NOTE: env is intended only for use with the programmatic API to support
        Python OK tests. It is not used here.

        A docstring without doctests counts as a failure.

        RETURNS:
        dict; {'passed': int, 'failed': int, 'locked': int}, with exactly one
        passed or one failed entry for the whole doctest.
        """
        output.off()
        log_id = output.new_log()
        try:
            format.print_line('-')
            print('Doctests for {}'.format(self.name))
            print()

            if not self.docstring:
                print('-- No doctests found for {} --'.format(self.name))
                success = False
            else:
                success = self.case.run()
                if success:
                    print('-- OK! --')
        finally:
            # Restore output even if the case raises, so an exception does
            # not leave output switched off.
            output.on()
            output_log = output.get_log(log_id)
            output.remove_log(log_id)

        if not success or self.verbose:
            print(''.join(output_log))

        if not success and self.interactive:
            self.console.interact()

        if success:
            return {'passed': 1, 'failed': 0, 'locked': 0}
        return {'passed': 0, 'failed': 1, 'locked': 0}

    def score(self):
        """Run the doctests and return 1.0 if they all pass, else 0.0.

        As in run(), a docstring without doctests scores 0.0.
        """
        format.print_line('-')
        print('Doctests for {}'.format(self.name))
        print()
        if not self.docstring:
            print('-- No doctests found for {} --'.format(self.name))
            success = False
        else:
            success = self.case.run()
        score = 1.0 if success else 0.0

        print('Score: {}/1'.format(score))
        print()
        return score

    def unlock(self, interact):
        """Doctests cannot be unlocked."""

    def lock(self, hash_fn):
        """Doctests cannot be locked."""

    def dump(self):
        """Doctests do not need to be dumped, since no state changes."""

    def get_code(self):
        """Render code for tracing.

        RETURNS:
        dict; {test name: {'setup': str, 'code': str, 'teardown': ''}} where
        setup is the import statement for this test's module.
        """
        setup = self.IMPORT_STRING.format(
            importing.path_to_module_string(self.file))
        data = {
            self.name: {
                'setup': setup + '\n',
                'code': self.case.formatted_code(),
                'teardown': '',
            }
        }
        return data
Exemplo n.º 11
0
class Doctest(models.Test):
    """A test whose cases are the doctest examples in a docstring."""

    docstring = core.String()

    PS1 = '>>> '
    PS2 = '... '

    SETUP = PS1 + 'from {} import *'
    # Raw string: '\.' inside a normal literal is an invalid escape sequence
    # (DeprecationWarning; SyntaxWarning since Python 3.12). The resulting
    # pattern is unchanged.
    prompt_re = re.compile(r'(\s*)({}|{})'.format(PS1, r'\.\.\. '))

    def __init__(self, file, verbose, interactive, timeout=None, **fields):
        super().__init__(**fields)
        self.file = file
        self.verbose = verbose
        self.interactive = interactive
        self.timeout = timeout

        self.console = doctest_case.PythonConsole(self.verbose,
                                                  self.interactive,
                                                  self.timeout)

    def post_instantiation(self):
        """Extract the interactive examples from the docstring into a case.

        Prompt lines and the output lines that follow them are collected
        with leading whitespace stripped. A blank line ends the current
        example. Non-prompt lines outside an example are ignored.

        Raises:
            ex.SerializeException: if lines within one example are not
            indented consistently.
        """
        # TODO(albert): rewrite test validation. Inconsistent leading space is
        # currently not validated correctly (see tests).
        self.docstring = textwrap.dedent(self.docstring)
        code = []
        prompt_on = False
        leading_space = ''
        for line in self.docstring.split('\n'):
            prompt_match = self.prompt_re.match(line)
            if prompt_match:
                if prompt_on and not line.startswith(leading_space):
                    raise ex.SerializeException(
                        'Inconsistent tabs for doctest')
                elif not prompt_on:
                    prompt_on = True
                    leading_space = prompt_match.group(1)
                code.append(line.lstrip())
            elif not line.strip():
                prompt_on = False
                leading_space = ''
            elif prompt_on:
                if not line.startswith(leading_space):
                    raise ex.SerializeException(
                        'Inconsistent tabs for doctest')
                code.append(line.lstrip())
        module = self.SETUP.format(importing.path_to_module_string(self.file))
        self.case = doctest_case.DoctestCase(self.console,
                                             module,
                                             code='\n'.join(code))

    def run(self):
        """Runs the suites associated with this doctest.

        RETURNS:
        dict; {'passed': int, 'failed': int, 'locked': int}, with exactly one
        passed or one failed entry for the whole doctest.
        """
        output.off()
        log_id = output.new_log()
        try:
            format.print_line('-')
            print('Doctests for {}'.format(self.name))
            print()

            success = self.case.run()
            if success:
                print('-- OK! --')
        finally:
            # Restore output even if the case raises, so an exception does
            # not leave output switched off.
            output.on()
            output_log = output.get_log(log_id)
            output.remove_log(log_id)

        if not success or self.verbose:
            print(''.join(output_log))

        if not success and self.interactive:
            self.console.interact()

        if success:
            return {'passed': 1, 'failed': 0, 'locked': 0}
        return {'passed': 0, 'failed': 1, 'locked': 0}

    def score(self):
        """Run the doctests and return 1.0 if they all pass, else 0.0."""
        format.print_line('-')
        print('Doctests for {}'.format(self.name))
        print()
        success = self.case.run()
        score = 1.0 if success else 0.0

        print('Score: {}/1'.format(score))
        print()
        return score

    def unlock(self, interact):
        """Doctests cannot be unlocked."""

    def lock(self, hash_fn):
        """Doctests cannot be locked."""

    def dump(self, file=None):
        """Doctests do not need to be dumped, since no state changes.

        ``file`` is optional so this matches the base ``Test.dump(self)``
        signature (callers invoke ``test.dump()`` with no arguments) while
        still accepting callers that pass a file.
        """
Exemplo n.º 12
0
class Suite(core.Serializable):
    """Base class for a group of cases belonging to one test."""

    type = core.String()
    scored = core.Boolean(default=True)
    cases = core.List()

    def __init__(self, test, verbose, interactive, timeout=None, **fields):
        super().__init__(**fields)
        self.test = test
        self.verbose = verbose
        self.interactive = interactive
        self.timeout = timeout
        # 1-indexed case numbers to restrict running to; empty means all.
        self.run_only = []

    def run(self, test_name, suite_number, env=None):
        """Subclasses should override this method to run tests.

        PARAMETERS:
        test_name    -- str; name of the parent test.
        suite_number -- int; suite number, assumed to be 1-indexed.
        env          -- dict; used by programmatic API to provide a
                        custom environment to run tests with.

        RETURNS:
        dict; results of the following form:
        {
            'passed': int,
            'failed': int,
            'locked': int,
        }
        """
        raise NotImplementedError

    def enumerate_cases(self):
        """Return (0-based index, case) pairs for the cases to run.

        When ``run_only`` is non-empty, only cases whose 1-based number is in
        it are returned (as a list); otherwise every case is returned (as an
        ``enumerate`` iterator).
        """
        enumerated = enumerate(self.cases)
        if self.run_only:
            return [x for x in enumerated if x[0] + 1 in self.run_only]
        return enumerated

    def _run_case(self, test_name, suite_number, case, case_number):
        """A wrapper for case.run().

        Prints informative output and also captures output of the test case
        and returns it as a log. The output is printed only if the case fails,
        or if self.verbose is True. On failure, also prints the command that
        reruns just this case.

        RETURNS:
        The value returned by case.run().
        """
        output.off()  # Delay printing until case status is determined.
        log_id = output.new_log()
        try:
            format.print_line('-')
            print('{} > Suite {} > Case {}'.format(test_name, suite_number,
                                                   case_number))
            print()

            success = case.run()
            if success:
                print('-- OK! --')
        finally:
            # Restore output and discard the temporary log even when
            # case.run() raises, so an exception does not leave output
            # switched off for everything that follows.
            output.on()
            output_log = output.get_log(log_id)
            output.remove_log(log_id)

        if not success or self.verbose:
            print(''.join(output_log))
        if not success:
            short_name = self.test.get_short_name()
            # TODO: Change when in notebook mode
            print('Run only this test case with '
                  '"python3 ok -q {} --suite {} --case {}"'.format(
                      short_name, suite_number, case_number))
        return success
Exemplo n.º 13
0
class OkTest(models.Test):
    """A test made of suites of cases, built from raw suite dicts."""

    suites = core.List()
    description = core.String(optional=True)

    def __init__(self, file, suite_map, assign_name, verbose, interactive,
                 timeout=None, **fields):
        super().__init__(**fields)
        self.file = file
        self.suite_map = suite_map
        self.verbose = verbose
        self.interactive = interactive
        self.timeout = timeout
        self.assignment_name = assign_name

    def post_instantiation(self):
        """Replace each raw suite dict with an instance of its suite class.

        The class is looked up in ``self.suite_map`` by the dict's ``type``.

        Raises:
            ex.SerializeException: if a suite is not a dict, has no ``type``
            field, or names a type missing from ``self.suite_map``.
        """
        for i, suite in enumerate(self.suites):
            if not isinstance(suite, dict):
                raise ex.SerializeException('Test cases must be dictionaries')
            elif 'type' not in suite:
                raise ex.SerializeException('Suites must have field "type"')
            elif suite['type'] not in self.suite_map:
                raise ex.SerializeException('Invalid suite type: '
                                            '{}'.format(suite['type']))
            self.suites[i] = self.suite_map[suite['type']](
                    self.verbose, self.interactive, self.timeout, **suite)

    def run(self):
        """Runs the suites associated with this OK test.

        Unless verbose, stops after the first suite that leaves any failed or
        locked cases.

        RETURNS:
        dict; the results for this test, in the form
        {
            'passed': int,
            'failed': int,
            'locked': int,
        }
        """
        passed, failed, locked = 0, 0, 0
        for i, suite in enumerate(self.suites):
            results = suite.run(self.name, i + 1)

            passed += results['passed']
            failed += results['failed']
            locked += results['locked']

            if not self.verbose and (failed > 0 or locked > 0):
                # Stop at the first failed test
                break

        if locked > 0:
            print()
            print('There are still locked tests! '
                  'Use the -u option to unlock them')

        if isinstance(self.description, str) and self.description:
            print()
            print(self.description)
            print()
        return {
            'passed': passed,
            'failed': failed,
            'locked': locked,
        }

    def score(self):
        """Runs test cases and computes the score for this particular test.

        Scores are determined by aggregating results from suite.run() for each
        suite. A suite is considered passed only if it results in no locked
        nor failed results.

        The points available for this test are distributed evenly across
        suites with 'scored' = True; a scored suite that still has locked
        cases counts toward the total but not as passed. With no scored
        suites the score is 0.0.
        """
        passed, total = 0, 0
        for i, suite in enumerate(self.suites):
            if not suite.scored:
                continue

            total += 1
            results = suite.run(self.name, i + 1)

            if results['locked'] == 0 and results['failed'] == 0:
                passed += 1
        if total > 0:
            score = passed * self.points / total
        else:
            score = 0.0

        format.print_progress_bar(self.name, passed, total - passed, 0)
        print()
        return score

    def unlock(self, interact):
        """Interactively unlock every locked case, suite by suite.

        Cases whose ``locked`` is not exactly True are reported as already
        unlocked and skipped.
        """
        total_cases = sum(len(suite.cases) for suite in self.suites)
        for suite_num, suite in enumerate(self.suites):
            for case_num, case in enumerate(suite.cases):
                case_id = '{} > Suite {} > Case {}'.format(
                            self.name, suite_num + 1, case_num + 1)

                format.print_line('-')
                print(case_id)
                print('(cases remaining: {})'.format(total_cases))
                print()
                total_cases -= 1

                # Deliberately `!= True` rather than `not case.locked`: any
                # value other than True is treated as unlocked.
                if case.locked != True:
                    print('-- Already unlocked --')
                    print()
                    continue

                # NOTE(review): unique_id_prefix is not set in this class;
                # presumably assigned by whoever loads the test -- confirm.
                case.unlock(self.unique_id_prefix, case_id, interact)

        assert total_cases == 0, 'Number of cases is incorrect'
        format.print_line('-')
        print('OK! All cases for {} unlocked.'.format(self.name))
        print()
Exemplo n.º 14
0
class Assignment(core.Serializable):
    """An OK assignment: its tests, its protocols, and the tests to run."""

    name = core.String()
    endpoint = core.String()
    src = core.List(type=str, optional=True)
    tests = core.Dict(keys=str, values=str, ordered=True)
    default_tests = core.List(type=str, optional=True)
    protocols = core.List(type=str)

    _TESTS_PACKAGE = 'client.sources'
    _PROTOCOL_PACKAGE = 'client.protocols'

    def __init__(self, cmd_args, **fields):
        # NOTE(review): fields are not forwarded to super().__init__ here;
        # presumably core.Serializable populates them itself -- confirm.
        self.cmd_args = cmd_args
        self.test_map = collections.OrderedDict()
        self.protocol_map = collections.OrderedDict()
        self.specified_tests = []

    def post_instantiation(self):
        """Print the assignment header."""
        self._print_header()

    def load(self):
        """Load tests and protocols."""
        self._load_tests()
        self._load_protocols()
        self._resolve_specified_tests()

    def _load_tests(self):
        """Loads all tests specified by ``self.tests`` into ``self.test_map``.

        Each key of ``self.tests`` is a glob pattern, optionally followed by
        ``:parameter``; its value names a module under ``client.sources``
        whose ``load(file, parameter, assignment)`` result is merged into
        ``self.test_map``.

        Raises:
            ex.LoadingException: if a source module cannot be imported, or if
            no tests were loaded at all.
        """
        log.info('Loading tests')
        for file_pattern, source in self.tests.items():
            # Separate filepath and parameter
            if ':' in file_pattern:
                file_pattern, parameter = file_pattern.split(':', 1)
            else:
                parameter = ''

            for file in sorted(glob.glob(file_pattern)):
                try:
                    module = importlib.import_module(self._TESTS_PACKAGE + '.' + source)
                except ImportError:
                    raise ex.LoadingException('Invalid test source: {}'.format(source))

                test_name = file
                if parameter:
                    test_name += ':' + parameter
                self.test_map.update(module.load(file, parameter, self))
                log.info('Loaded {}'.format(test_name))

        if not self.test_map:
            raise ex.LoadingException('No tests loaded')

    def dump_tests(self):
        """Dumps all loaded tests, as determined by their .dump() method.

        Each test in ``self.test_map`` is dumped via ``test.dump()``. A
        ``ex.SerializeException`` from one test is logged as a warning and
        does not stop the remaining tests from being dumped.
        """
        log.info('Dumping tests')
        for test in self.test_map.values():
            try:
                test.dump()
            except ex.SerializeException as e:
                log.warning('Unable to dump {}: {}'.format(test.name, str(e)))
            else:
                log.info('Dumped {}'.format(test.name))

    def _resolve_specified_tests(self):
        """For each of the questions specified on the command line,
        find the test corresponding to that question.

        Questions are preserved in the order that they are specified on the
        command line; a question given more than once is included once. If
        no questions are specified, the default tests are used (unless --all
        was given or there are none), otherwise the entire set of tests. The
        result is stored in ``self.specified_tests``.

        Raises:
            ex.LoadingException: if a question does not name a loaded test.
        """
        if not self.cmd_args.question and not self.cmd_args.all \
                and self.default_tests != core.NoValue \
                and len(self.default_tests) > 0:
            log.info('Using default tests (no questions specified): '
                     '{}'.format(self.default_tests))
            self.specified_tests = [self.test_map[test]
                                    for test in self.default_tests]
            return
        elif not self.cmd_args.question:
            log.info('Using all tests (no questions specified and no default tests)')
            self.specified_tests = list(self.test_map.values())
            return
        elif not self.test_map:
            log.info('No tests loaded')
            return
        for question in self.cmd_args.question:
            if question not in self.test_map:
                print('Test "{}" not found.'.format(question))
                print('Did you mean one of the following? '
                      '(Names are case sensitive)')
                for test in self.test_map:
                    print('    {}'.format(test))
                raise ex.LoadingException('Invalid test specified: {}'.format(question))

            log.info('Adding {} to specified tests'.format(question))
            selected = self.test_map[question]
            # Compare Test objects (not the question string) so a question
            # listed twice is only run once.
            if selected not in self.specified_tests:
                self.specified_tests.append(selected)

    def _load_protocols(self):
        """Import and instantiate each protocol from ``client.protocols``.

        Raises:
            ex.LoadingException: if a protocol module cannot be imported.
        """
        log.info('Loading protocols')
        for proto in self.protocols:
            try:
                module = importlib.import_module(self._PROTOCOL_PACKAGE + '.' + proto)
            except ImportError:
                raise ex.LoadingException('Invalid protocol: {}'.format(proto))

            self.protocol_map[proto] = module.protocol(self.cmd_args, self)
            log.info('Loaded protocol "{}"'.format(proto))

    def _print_header(self):
        """Print the assignment name and client version banner."""
        format.print_line('=')
        print('Assignment: {}'.format(self.name))
        print('OK, version {}'.format(client.__version__))
        format.print_line('=')
        print()
Exemplo n.º 15
0
class Assignment(core.Serializable):
    """An OK assignment: configuration fields plus loaded tests and protocols.

    The configuration is declared as ``core`` fields below.
    ``post_instantiation`` prints a header, loads the tests and protocols,
    and resolves which tests were requested on the command line.
    """
    name = core.String()
    endpoint = core.String(optional=True, default='')
    # URL of a page from which attempt_decryption fetches additional keys.
    decryption_keypage = core.String(optional=True, default='')
    src = core.List(type=str, optional=True)
    # Maps a glob pattern (optionally "pattern:parameter") to a comma-separated
    # list of test source module names under _TESTS_PACKAGE.
    tests = core.Dict(keys=str, values=str, ordered=True)
    default_tests = core.List(type=str, optional=True)
    # ignored, for backwards-compatibility only
    protocols = core.List(type=str, optional=True)

    ####################
    # Programmatic API #
    ####################

    def grade(self, question, env=None, skip_locked_cases=False):
        """Runs tests for a particular question. The setup and teardown will
        always be executed.

        question -- str; a question name (as would be entered at the command
                    line)
        env      -- dict; an environment in which to execute the tests. If
                    None, uses the environment of __main__. The original
                    dictionary is never modified; each test is given a
                    duplicate of env.
        skip_locked_cases -- bool; if False, locked cases will be tested

        Returns: dict; the grading results recorded for this question (the
        entry messages['grading'][<test name>]). It contains the following
        fields:
        - "passed": int (number of test cases passed)
        - "failed": int (number of test cases failed)
        - "locked": int (number of test cases locked)

        Raises ex.LoadingException if no tests are loaded. An unknown question
        raises whatever _resolve_specified_tests raises.
        """
        if env is None:
            import __main__
            env = __main__.__dict__
        messages = {}
        tests = self._resolve_specified_tests([question], all_tests=False)
        if not tests:
            # _resolve_specified_tests returns [] only when no tests were
            # loaded; report that clearly instead of an IndexError below.
            raise ex.LoadingException(
                'No tests loaded; cannot grade {}'.format(question))
        for test in tests:
            try:
                for suite in test.suites:
                    suite.skip_locked_cases = skip_locked_cases
                    suite.console.skip_locked_cases = skip_locked_cases
                    suite.console.hash_key = self.name
            except AttributeError:
                # Tests without suites/consoles are graded unchanged.
                pass
        test_name = tests[0].name
        # Module-level grade() function (method names are not in local scope).
        grade(tests, messages, env)
        return messages['grading'][test_name]

    ##############
    # Encryption #
    ##############

    def generate_encryption_key(self, keys_file):
        """Create a fresh key for every assignment file and save them.

        keys_file -- str; path of the JSON file to write. It receives a list
                     of [filename, key] pairs, one per file from _get_files().
        """
        data = [(filename, encryption.generate_key())
                for filename in self._get_files()]
        with open(keys_file, "w") as f:
            json.dump(data, f)

    def encrypt(self, keys_file, padding):
        """Encrypt each question and test file in place.

        keys_file -- str; JSON file of [filename, key] pairs (the format
                     written by generate_encryption_key). Files that have no
                     entry are left untouched.
        padding   -- passed through to encryption.encrypt.
        """
        with open(keys_file) as f:
            keys = dict(json.load(f))
        for file in self._get_files():
            if file in keys:
                self._encrypt_file(file, keys[file], padding)

    def decrypt(self, keys):
        """Decrypt assignment files with ``keys`` and report the outcome.

        Prints a warning if nothing was encrypted, or an error listing the
        files that none of the keys could decrypt.
        """
        decrypted_files, undecrypted_files = self.attempt_decryption(keys)
        if not undecrypted_files + decrypted_files:
            print_warning("All files are already decrypted")
        elif undecrypted_files:
            if keys:
                print_error("Unable to decrypt some files with the keys",
                            ", ".join(keys))
            else:
                print_error("No keys found, could not decrypt any files")
            print_error("    Non-decrypted files:", *undecrypted_files)

    def attempt_decryption(self, keys):
        """Try every key on every encrypted assignment file.

        If decryption_keypage is set, keys parsed from that page are appended
        to ``keys``; a failure to fetch it is reported but not fatal.

        Returns: (decrypted_files, undecrypted_files), two lists of paths.
        Files that were not encrypted appear in neither list.
        """
        if self.decryption_keypage:
            try:
                response = requests.get(self.decryption_keypage)
                response.raise_for_status()
                keys_data = response.content.decode('utf-8')
                keys = keys + encryption.get_keys(keys_data)
            except Exception as e:
                # Best effort: the user can still supply keys directly.
                print_error("Could not load decryption page {}: {}.".format(
                    self.decryption_keypage, e))
                print_error(
                    "You can pass in a key directly by running python3 ok --decrypt [KEY]"
                )

        decrypted_files = []
        undecrypted_files = []
        for file in self._get_files():
            with open(file) as f:
                if not encryption.is_encrypted(f.read()):
                    continue
            for key in keys:
                success = self._decrypt_file(file, key)
                if success:
                    decrypted_files.append(file)
                    break
            else:
                # No key worked for this file.
                undecrypted_files.append(file)
        return decrypted_files, undecrypted_files

    def _decrypt_file(self, path, key):
        """
        Decrypt the given file in place with the given key.
        If the key does not match, do not change the file contents.

        Returns: bool; True if the file was decrypted with ``key``.
        """
        success = False

        def decrypt(ciphertext):
            if not encryption.is_encrypted(ciphertext):
                return ciphertext
            try:
                plaintext = encryption.decrypt(ciphertext, key)
                nonlocal success
                success = True
                print_success("decrypted", path, "with", key)
                return plaintext
            except encryption.InvalidKeyException:
                # Wrong key: write the ciphertext back unchanged.
                return ciphertext

        self._in_place_edit(path, decrypt)
        return success

    def _encrypt_file(self, path, key, padding):
        """
        Encrypt the given file in place with the given key.
        This is idempotent, but trying to encrypt the same file with a
        different key raises ValueError.
        """
        def encrypt(data):
            if encryption.is_encrypted(data):
                # Re-encrypting: recover the plaintext first, which only
                # succeeds with the same key.
                try:
                    data = encryption.decrypt(data, key)
                except encryption.InvalidKeyException:
                    raise ValueError(
                        "Attempt to re-encrypt file with an invalid key")
            return encryption.encrypt(data, key, padding)

        self._in_place_edit(path, encrypt)

    @staticmethod
    def _in_place_edit(path, func):
        """
        Edit the given file in place, atomically. `func` maps the file's
        current text to its new text.

        The new text is written to a hidden temporary file in the same
        directory as `path` and then renamed over it. The rename is done with
        os.replace, which needs source and destination on one filesystem.
        The temporary file is removed if writing or renaming fails.
        """
        with open(path) as f:
            data = f.read()
        new_data = func(data)
        directory = os.path.dirname(os.path.abspath(path))
        temporary_file = os.path.join(directory, "." + uuid.uuid4().hex)
        try:
            with open(temporary_file, "w") as f:
                f.write(new_data)
            # atomic rename
            os.replace(temporary_file, path)
        except BaseException:
            try:
                os.remove(temporary_file)
            except OSError:
                pass
            raise

    def _get_files(self):
        """
        Get all the test and submission source files associated with this
        assignment, deduplicated and sorted.

        Test files are those matched by a ``tests`` entry whose sources
        include 'ok_test' or 'scheme_test'. Keys and values are parsed the
        same way _load_tests parses them: an optional ':parameter' suffix is
        dropped from the pattern, and sources are comma-separated.
        """
        encryptable_sources = {'ok_test', 'scheme_test'}
        tests = []
        for pattern_spec, sources in self.tests.items():
            if encryptable_sources.isdisjoint(sources.split(',')):
                continue
            file_pattern = pattern_spec.split(':', 1)[0]
            tests.extend(glob.glob(file_pattern))
        # src is optional; when it was not provided it holds core.NoValue.
        src = list(self.src) if self.src != core.NoValue else []
        return sorted(set(tests + src))

    @property
    def server_url(self):
        """Base URL of the OK server, honoring the --insecure flag."""
        scheme = 'http' if self.cmd_args.insecure else 'https'
        return '{}://{}'.format(scheme, self.cmd_args.server)

    ############
    # Internal #
    ############

    _TESTS_PACKAGE = 'client.sources'
    _PROTOCOL_PACKAGE = 'client.protocols'

    # A list of all protocols that should be loaded. Order is important.
    # Dependencies:
    # analytics     -> grading
    # autostyle     -> analytics, grading
    # backup        -> all other protocols
    # collaborate   -> file_contents, analytics
    # file_contents -> none
    # grading       -> rate_limit
    # hinting       -> file_contents, analytics
    # lock          -> none
    # rate_limit    -> none
    # scoring       -> none
    # trace         -> file_contents
    # unlock        -> none
    # testing       -> none
    _PROTOCOLS = [
        "testing",
        # "rate_limit", uncomment to turn rate limiting back on!
        "file_contents",
        "grading",
        "analytics",
        "autostyle",
        "collaborate",
        "hinting",
        "lock",
        "scoring",
        "unlock",
        "trace",
        "backup",
    ]

    def __init__(self, args, **fields):
        self.cmd_args = args
        self.test_map = collections.OrderedDict()
        self.protocol_map = collections.OrderedDict()

    def post_instantiation(self):
        """Print the header, load tests and protocols, resolve requested tests."""
        self._print_header()
        self._load_tests()
        self._load_protocols()
        self.specified_tests = self._resolve_specified_tests(
            self.cmd_args.question, self.cmd_args.all)

    def set_args(self, **kwargs):
        """Set command-line arguments programmatically. For example:

            assignment.set_args(
                server='http://localhost:5000',
                no_browser=True,
                backup=True,
                timeout=60,
            )
        """
        self.cmd_args.update(**kwargs)

    def authenticate(self, force=False, inline=False):
        """Authenticate via auth.authenticate, or the notebook flow if inline."""
        if not inline:
            return auth.authenticate(self.cmd_args,
                                     endpoint=self.endpoint,
                                     force=force)
        else:
            return auth.notebook_authenticate(self.cmd_args, force=force)

    def get_student_email(self):
        return auth.get_student_email(self.cmd_args, endpoint=self.endpoint)

    def get_identifier(self):
        return auth.get_identifier(self.cmd_args, endpoint=self.endpoint)

    def is_empty_init(self, path):
        """Return True if ``path`` is an ``__init__.py`` with only whitespace."""
        if os.path.basename(path) != '__init__.py':
            return False

        with open(path) as f:
            contents = f.read()

        return contents.strip() == ""

    def _load_tests(self):
        """Load every test described by ``self.tests`` into ``self.test_map``.

        Each key of ``self.tests`` is a glob pattern, optionally followed by
        ``:parameter``; each value is a comma-separated list of module names
        under _TESTS_PACKAGE. Every matching file (except empty
        ``__init__.py`` files) is passed to each listed module's
        ``load(file, parameter, assignment)``, and the result is merged into
        test_map.

        Raises ex.LoadingException if a source module cannot be imported.
        """
        log.info('Loading tests')
        for pattern_spec, sources in self.tests.items():
            # Separate filepath and parameter once per entry, so every source
            # listed for this pattern receives the same parameter.
            if ':' in pattern_spec:
                file_pattern, parameter = pattern_spec.split(':', 1)
            else:
                file_pattern, parameter = pattern_spec, ''

            for source in sources.split(","):
                for file in sorted(glob.glob(file_pattern)):
                    if self.is_empty_init(file):
                        continue
                    try:
                        module = importlib.import_module(self._TESTS_PACKAGE +
                                                         '.' + source)
                    except ImportError as err:
                        raise ex.LoadingException(
                            'Invalid test source: {}'.format(source)) from err

                    test_name = file
                    if parameter:
                        test_name += ':' + parameter

                    self.test_map.update(module.load(file, parameter, self))

                    log.info('Loaded {}'.format(test_name))

    def dump_tests(self):
        """Dumps all tests in self.test_map via their .dump() method.

        Each Test's .dump() is called with no arguments. A test that raises
        ex.SerializeException is logged as a warning and skipped; the other
        tests are still dumped.
        """
        log.info('Dumping tests')
        for test in self.test_map.values():
            try:
                test.dump()
            except ex.SerializeException as e:
                log.warning('Unable to dump {}: {}'.format(test.name, str(e)))
            else:
                log.info('Dumped {}'.format(test.name))

    def _resolve_specified_tests(self, questions, all_tests=False):
        """For each of the questions specified on the command line,
        find the test corresponding to that question.

        Questions are preserved in the order that they are specified on the
        command line; a question named more than once is included only once.
        If no questions are specified, use the default tests when defined
        (unless all_tests is set), otherwise the entire set of tests.

        Raises ex.LoadingException if a default test is missing, and
        ex.InvalidTestInQuestionListException for an unknown question.
        """
        if not questions and not all_tests \
                and self.default_tests != core.NoValue \
                and len(self.default_tests) > 0:
            log.info('Using default tests (no questions specified): '
                     '{}'.format(self.default_tests))
            bad_tests = sorted(test for test in self.default_tests
                               if test not in self.test_map)
            if bad_tests:
                error_message = (
                    "Required question(s) missing: {}. "
                    "This often is the result of accidentally deleting the question's doctests or the entire function."
                )
                raise ex.LoadingException(
                    error_message.format(", ".join(bad_tests)))
            return [self.test_map[test] for test in self.default_tests]
        elif not questions:
            log.info(
                'Using all tests (no questions specified and no default tests)'
            )
            return list(self.test_map.values())
        elif not self.test_map:
            log.info('No tests loaded')
            return []

        specified_tests = []
        # Track question names already added; comparing the name against the
        # list of Test objects would never match and allowed duplicates.
        seen_questions = set()
        for question in questions:
            if question not in self.test_map:
                raise ex.InvalidTestInQuestionListException(
                    list(self.test_map), question)
            if question in seen_questions:
                continue
            seen_questions.add(question)
            log.info('Adding {} to specified tests'.format(question))
            specified_tests.append(self.test_map[question])
        return specified_tests

    def _load_protocols(self):
        """Instantiate every protocol in _PROTOCOLS, in order, into protocol_map."""
        log.info('Loading protocols')
        for proto in self._PROTOCOLS:
            module = importlib.import_module(self._PROTOCOL_PACKAGE + '.' +
                                             proto)
            self.protocol_map[proto] = module.protocol(self.cmd_args, self)
            log.info('Loaded protocol "{}"'.format(proto))

    def _print_header(self):
        """Print the assignment banner: name and OK version between '=' rules."""
        format.print_line('=')
        print('Assignment: {}'.format(self.name))
        print('OK, version {}'.format(client.__version__))
        format.print_line('=')
        print()
Exemplo n.º 16
0
class CodeCase(models.Case):
    """TestCase for doctest-style Python tests.

    ``code`` holds prompt lines (beginning with the console's PS1/PS2) mixed
    with expected-output lines. post_instantiation splits it into
    ``self.lines``, where each run of output lines becomes one CodeAnswer.
    """

    code = core.String()

    def __init__(self, console, setup='', teardown='', **fields):
        """Constructor.

        PARAMETERS:
        console  -- the console that loads and interprets this case; it
                    provides PS1/PS2, parsons, normalize, load() and
                    interpret().
        setup    -- str; setup code passed to console.load before the case.
        teardown -- str; teardown code passed to console.load (presumably
                    executed by the console regardless of errors).
        fields   -- keyword arguments forwarded to models.Case (e.g. code).
        """
        super().__init__(**fields)
        self.console = console
        self.setup = setup
        self.teardown = teardown

        # must reload for parsons problems: the student's module is
        # re-imported and reloaded in the setup. The module name is taken to
        # be the third whitespace-separated token of the setup
        # (e.g. ">>> from mod import *").
        # NOTE(review): a setup with fewer than three tokens raises IndexError.
        if self.setup and self.console.parsons:
            assignment_name = self.setup.split()[2]
            self.setup = textwrap.dedent(self.setup)
            self.setup += f"\n>>> import {assignment_name}"
            self.setup += "\n>>> from importlib import reload"
            self.setup += f"\n>>> {assignment_name} = reload({assignment_name})"
            self.setup += f"\n>>> from {assignment_name} import *"

    def post_instantiation(self):
        """Dedent code/setup/teardown and split code into self.lines."""
        self.code = textwrap.dedent(self.code)
        self.setup = textwrap.dedent(self.setup)
        self.teardown = textwrap.dedent(self.teardown)

        self.lines = self.split_code(self.code, self.console.PS1,
                                     self.console.PS2)

    def run(self):
        """Implements the GradedTestCase interface."""
        self.console.load(self.lines, setup=self.setup, teardown=self.teardown)
        return self.console.interpret()

    def lock(self, hash_fn):
        """Hash every still-unlocked answer with hash_fn and mark the case locked."""
        assert self.locked != False, 'called lock when self.lock = False'
        for line in self.lines:
            if isinstance(line, CodeAnswer) and not line.locked:
                line.output = [hash_fn(output) for output in line.output]
                line.locked = True
        self.locked = True
        self._sync_code()

    def unlock(self, unique_id_prefix, case_id, interact):
        """Unlocks the CodeCase.

        PARAMETERS:
        unique_id_prefix -- string; a prefix of a unique identifier for this
                            Case, for purposes of analytics.
        case_id          -- string; an identifier for this Case, for purposes of
                            analytics.
        interact         -- function; handles user interaction during the unlocking
                            phase.

        self.code is re-synced from self.lines even if unlocking is interrupted.
        """
        print(self.setup.strip())
        prompt_num = 0
        current_prompt = []
        try:
            for line in self.lines:
                if isinstance(line, str) and line:
                    print(line)
                    current_prompt.append(line)
                elif isinstance(line, CodeAnswer):
                    prompt_num += 1
                    if not line.locked:
                        print('\n'.join(line.output))
                        continue

                    unique_id = self._construct_unique_id(
                        unique_id_prefix, self.lines)
                    line.output = interact(unique_id,
                                           case_id +
                                           ' >  Prompt {}'.format(prompt_num),
                                           '\n'.join(current_prompt),
                                           line.output,
                                           normalizer=self.console.normalize,
                                           choices=line.choices,
                                           multiline=self.multiline)
                    line.locked = False
                    current_prompt = []
            self.locked = False
        finally:
            self._sync_code()

    @classmethod
    def split_code(cls, code, PS1, PS2):
        """Splits the given string of code based on the provided PS1 and PS2
        symbols.

        PARAMETERS:
        code -- str; lines of interpretable code, using PS1 and PS2 prompts
        PS1  -- str; first-level prompt symbol
        PS2  -- str; second-level prompt symbol

        RETURN:
        list; a processed sequence of lines corresponding to the input code.
        Blank and prompt lines are kept as strings; each consecutive run of
        other (output) lines is collected into one CodeAnswer.
        """
        processed_lines = []
        for line in textwrap.dedent(code).splitlines():
            if not line or line.startswith(PS1) or line.startswith(PS2):
                processed_lines.append(line)
                continue

            # An output line must follow at least one prompt/blank line.
            assert len(processed_lines
                       ) > 0, 'code improperly formatted: {}'.format(code)
            if not isinstance(processed_lines[-1], CodeAnswer):
                processed_lines.append(CodeAnswer())
            processed_lines[-1].update(line)
        return processed_lines

    def _sync_code(self):
        """Syncs the current state of self.lines with self.code, the
        serializable string representing the set of code.
        """
        new_code = []
        for line in self.lines:
            if isinstance(line, CodeAnswer):
                new_code.append(line.dump())
            else:
                new_code.append(line)
        self.code = '\n'.join(new_code)

    def _format_code_line(self, line):
        """Remove a leading PS1/PS2 prompt from a code line.

        Only the prefix is removed; prompt-like text elsewhere in the line
        (e.g. inside a string literal) is preserved.
        """
        if line.startswith(self.console.PS1):
            line = line[len(self.console.PS1):]
        elif line.startswith(self.console.PS2):
            line = line[len(self.console.PS2):]
        return line

    def formatted_code(self):
        """Provides an interpretable version of the code in the case,
        with formatting for external users (Tracing or Exporting).

        Returns a list of strings: prompt-stripped code lines, with each
        answer rendered as a '# Expected: ...' comment.
        """
        code_lines = []
        for line in self.lines:
            if isinstance(line, CodeAnswer):
                if line.locked:
                    text = '# Expected: ? (test case is locked)'
                else:
                    split_lines = line.dump().splitlines()
                    # Handle case when we expect multiline outputs
                    text = '# Expected: ' + '\n# '.join(split_lines)
            else:
                text = self._format_code_line(line)
            code_lines.append(text)
        return code_lines

    def formatted_setup(self):
        """Return the setup with prompts stripped and blank lines dropped."""
        return '\n'.join(
            [self._format_code_line(l) for l in self.setup.splitlines() if l])

    def formatted_teardown(self):
        """Return the teardown with prompts stripped and blank lines dropped."""
        return '\n'.join([
            self._format_code_line(l) for l in self.teardown.splitlines() if l
        ])

    def _construct_unique_id(self, id_prefix, lines):
        """Constructs a unique ID from id_prefix and the given lines
        (string lines verbatim, answers via their dump()), joined by newlines.
        """
        text = []
        for line in lines:
            if isinstance(line, str):
                text.append(line)
            elif isinstance(line, CodeAnswer):
                text.append(line.dump())
        return id_prefix + '\n' + '\n'.join(text)
Exemplo n.º 17
0
class DoctestCase(interpreter.InterpreterCase):
    """TestCase for doctest-style Python tests.

    The case's ``code`` uses '>>> ' and '... ' prompts; expected outputs are
    held as _Answer objects in ``self.lines`` after post_instantiation.
    """

    code = core.String()

    PS1 = '>>> '
    PS2 = '... '

    def __init__(self, console, setup='', teardown='', **fields):
        """Build a doctest case bound to a Python console.

        PARAMETERS:
        console  -- PythonConsole; asserted, then forwarded to
                    InterpreterCase along with the remaining fields.
        setup    -- str; setup code handed to console.load.
        teardown -- str; teardown code handed to console.load.
        fields   -- keyword arguments forwarded to InterpreterCase.
        """
        assert isinstance(console, PythonConsole), 'Improper console: {}'.format(console)
        super().__init__(console, **fields)
        self.setup = setup
        self.teardown = teardown

    def post_instantiation(self):
        # Normalize indentation of code, setup and teardown (in that order),
        # then break the code into prompt lines and answers.
        for attribute in ('code', 'setup', 'teardown'):
            setattr(self, attribute, textwrap.dedent(getattr(self, attribute)))
        self.lines = _split_code(self.code, self.PS1, self.PS2)

    def preprocess(self):
        """Load this case's code, setup and teardown into the console."""
        self.console.load(self.code, setup=self.setup, teardown=self.teardown)

    def lock(self, hash_fn):
        """Replace each unlocked answer's outputs with their hashes."""
        assert self.locked != False, 'called lock when self.lock = False'
        pending = (entry for entry in self.lines
                   if isinstance(entry, _Answer) and not entry.locked)
        for answer in pending:
            answer.output = [hash_fn(expected) for expected in answer.output]
            answer.locked = True
        self.locked = True
        self._sync_code()

    def unlock(self, interact):
        """Walk the case, asking the student to supply each locked answer.

        PARAMETERS:
        interact -- function; called as interact(output, choices) for every
                    locked answer and returns the unlocked output.

        self.code is re-synced from self.lines even if unlocking is interrupted.
        """
        try:
            for entry in self.lines:
                if isinstance(entry, str) and entry:
                    print(entry)
                    continue
                if not isinstance(entry, _Answer):
                    continue
                if entry.locked:
                    entry.output = interact(entry.output, entry.choices)
                    entry.locked = False
                else:
                    print('\n'.join(entry.output))
            self.locked = False
        finally:
            self._sync_code()

    def _sync_code(self):
        # Rebuild the serialized code string from the current lines.
        self.code = '\n'.join(
            entry.dump() if isinstance(entry, _Answer) else entry
            for entry in self.lines)
Exemplo n.º 18
0
class MockSerializable2(MockSerializable):
    # Appears to be a test fixture: a subclass of MockSerializable that
    # declares var2 with a default and adds an optional var5.
    # NOTE(review): var2 presumably overrides a parent field -- confirm
    # against MockSerializable.
    TEST_INT = 1  # default value used for var2

    var2 = core.Int(default=TEST_INT)
    var5 = core.String(optional=True)
Exemplo n.º 19
0
class Assignment(core.Serializable):
    """An OK assignment: configuration fields plus loaded tests and protocols.

    The configuration is declared as ``core`` fields below.
    ``post_instantiation`` prints a header, loads the tests and protocols,
    and resolves which tests were requested on the command line.
    """
    name = core.String()
    endpoint = core.String(optional=True, default='')
    src = core.List(type=str, optional=True)
    # Maps a glob pattern (optionally "pattern:parameter") to a comma-separated
    # list of test source module names under _TESTS_PACKAGE.
    tests = core.Dict(keys=str, values=str, ordered=True)
    default_tests = core.List(type=str, optional=True)
    # ignored, for backwards-compatibility only
    protocols = core.List(type=str, optional=True)

    ####################
    # Programmatic API #
    ####################

    def grade(self, question, env=None, skip_locked_cases=False):
        """Runs tests for a particular question. The setup and teardown will
        always be executed.

        question -- str; a question name (as would be entered at the command
                    line)
        env      -- dict; an environment in which to execute the tests. If
                    None, uses the environment of __main__. The original
                    dictionary is never modified; each test is given a
                    duplicate of env.
        skip_locked_cases -- bool; if False, locked cases will be tested

        Returns: dict; the grading results recorded for this question (the
        entry messages['grading'][<test name>]). It contains the following
        fields:
        - "passed": int (number of test cases passed)
        - "failed": int (number of test cases failed)
        - "locked": int (number of test cases locked)

        Raises ex.LoadingException if the question is unknown or no tests
        are loaded.
        """
        if env is None:
            import __main__
            env = __main__.__dict__
        messages = {}
        tests = self._resolve_specified_tests([question], all_tests=False)
        if not tests:
            # _resolve_specified_tests returns [] only when no tests were
            # loaded; report that clearly instead of an IndexError below.
            raise ex.LoadingException(
                'No tests loaded; cannot grade {}'.format(question))
        for test in tests:
            try:
                for suite in test.suites:
                    suite.skip_locked_cases = skip_locked_cases
                    suite.console.skip_locked_cases = skip_locked_cases
                    suite.console.hash_key = self.name
            except AttributeError:
                # Tests without suites/consoles are graded unchanged.
                pass
        test_name = tests[0].name
        # Module-level grade() function (method names are not in local scope).
        grade(tests, messages, env)
        return messages['grading'][test_name]

    @property
    def server_url(self):
        """Base URL of the OK server, honoring the --insecure flag."""
        scheme = 'http' if self.cmd_args.insecure else 'https'
        return '{}://{}'.format(scheme, self.cmd_args.server)

    ############
    # Internal #
    ############

    _TESTS_PACKAGE = 'client.sources'
    _PROTOCOL_PACKAGE = 'client.protocols'

    # A list of all protocols that should be loaded. Order is important.
    # Dependencies:
    # analytics     -> grading
    # autostyle     -> analytics, grading
    # backup        -> all other protocols
    # collaborate   -> file_contents, analytics
    # file_contents -> none
    # grading       -> rate_limit
    # hinting       -> file_contents, analytics
    # lock          -> none
    # rate_limit    -> none
    # scoring       -> none
    # trace         -> file_contents
    # unlock        -> none
    # testing       -> none
    _PROTOCOLS = [
        "testing",
        # "rate_limit", uncomment to turn rate limiting back on!
        "file_contents",
        "grading",
        "analytics",
        "autostyle",
        "collaborate",
        "hinting",
        "lock",
        "scoring",
        "unlock",
        "trace",
        "backup",
    ]

    def __init__(self, args, **fields):
        self.cmd_args = args
        self.test_map = collections.OrderedDict()
        self.protocol_map = collections.OrderedDict()

    def post_instantiation(self):
        """Print the header, load tests and protocols, resolve requested tests."""
        self._print_header()
        self._load_tests()
        self._load_protocols()
        self.specified_tests = self._resolve_specified_tests(
            self.cmd_args.question, self.cmd_args.all)

    def set_args(self, **kwargs):
        """Set command-line arguments programmatically. For example:

            assignment.set_args(
                server='http://localhost:5000',
                no_browser=True,
                backup=True,
                timeout=60,
            )
        """
        self.cmd_args.update(**kwargs)

    def authenticate(self, force=False, inline=False):
        """Authenticate via auth.authenticate, or the notebook flow if inline."""
        if not inline:
            return auth.authenticate(self.cmd_args,
                                     endpoint=self.endpoint,
                                     force=force)
        else:
            return auth.notebook_authenticate(self.cmd_args, force=force)

    def get_student_email(self):
        return auth.get_student_email(self.cmd_args, endpoint=self.endpoint)

    def get_identifier(self):
        return auth.get_identifier(self.cmd_args, endpoint=self.endpoint)

    def _load_tests(self):
        """Load every test described by ``self.tests`` into ``self.test_map``.

        Each key of ``self.tests`` is a glob pattern, optionally followed by
        ``:parameter``; each value is a comma-separated list of module names
        under _TESTS_PACKAGE. Every matching file is passed to each listed
        module's ``load(file, parameter, assignment)``, and the result is
        merged into test_map.

        Raises ex.LoadingException if a source module cannot be imported.
        """
        log.info('Loading tests')
        for pattern_spec, sources in self.tests.items():
            # Separate filepath and parameter once per entry, so every source
            # listed for this pattern receives the same parameter.
            if ':' in pattern_spec:
                file_pattern, parameter = pattern_spec.split(':', 1)
            else:
                file_pattern, parameter = pattern_spec, ''

            for source in sources.split(","):
                for file in sorted(glob.glob(file_pattern)):
                    try:
                        module = importlib.import_module(self._TESTS_PACKAGE +
                                                         '.' + source)
                    except ImportError as err:
                        raise ex.LoadingException(
                            'Invalid test source: {}'.format(source)) from err

                    test_name = file
                    if parameter:
                        test_name += ':' + parameter

                    self.test_map.update(module.load(file, parameter, self))

                    log.info('Loaded {}'.format(test_name))

    def dump_tests(self):
        """Dumps all tests in self.test_map via their .dump() method.

        Each Test's .dump() is called with no arguments. A test that raises
        ex.SerializeException is logged as a warning and skipped; the other
        tests are still dumped.
        """
        log.info('Dumping tests')
        for test in self.test_map.values():
            try:
                test.dump()
            except ex.SerializeException as e:
                log.warning('Unable to dump {}: {}'.format(test.name, str(e)))
            else:
                log.info('Dumped {}'.format(test.name))

    def _resolve_specified_tests(self, questions, all_tests=False):
        """For each of the questions specified on the command line,
        find the test corresponding to that question.

        Questions are preserved in the order that they are specified on the
        command line; a question named more than once is included only once.
        If no questions are specified, use the default tests when defined
        (unless all_tests is set), otherwise the entire set of tests.

        Raises ex.LoadingException if a default test is missing or a question
        is unknown (after printing the available test names).
        """
        if not questions and not all_tests \
                and self.default_tests != core.NoValue \
                and len(self.default_tests) > 0:
            log.info('Using default tests (no questions specified): '
                     '{}'.format(self.default_tests))
            bad_tests = sorted(test for test in self.default_tests
                               if test not in self.test_map)
            if bad_tests:
                error_message = (
                    "Required question(s) missing: {}. "
                    "This often is the result of accidentally deleting the question's doctests or the entire function."
                )
                raise ex.LoadingException(
                    error_message.format(", ".join(bad_tests)))
            return [self.test_map[test] for test in self.default_tests]
        elif not questions:
            log.info(
                'Using all tests (no questions specified and no default tests)'
            )
            return list(self.test_map.values())
        elif not self.test_map:
            log.info('No tests loaded')
            return []

        specified_tests = []
        # Track question names already added; comparing the name against the
        # list of Test objects would never match and allowed duplicates.
        seen_questions = set()
        for question in questions:
            if question not in self.test_map:
                print('Test "{}" not found.'.format(question))
                print('Did you mean one of the following? '
                      '(Names are case sensitive)')
                for test in self.test_map:
                    print('    {}'.format(test))
                raise ex.LoadingException(
                    'Invalid test specified: {}'.format(question))
            if question in seen_questions:
                continue
            seen_questions.add(question)
            log.info('Adding {} to specified tests'.format(question))
            specified_tests.append(self.test_map[question])
        return specified_tests

    def _load_protocols(self):
        """Instantiate every protocol in _PROTOCOLS, in order, into protocol_map."""
        log.info('Loading protocols')
        for proto in self._PROTOCOLS:
            module = importlib.import_module(self._PROTOCOL_PACKAGE + '.' +
                                             proto)
            self.protocol_map[proto] = module.protocol(self.cmd_args, self)
            log.info('Loaded protocol "{}"'.format(proto))

    def _print_header(self):
        """Print the assignment banner: name and OK version between '=' rules."""
        format.print_line('=')
        print('Assignment: {}'.format(self.name))
        print('OK, version {}'.format(client.__version__))
        format.print_line('=')
        print()
Exemplo n.º 20
0
class Assignment(core.Serializable):
    """An assignment: its tests, protocols and the tests selected to run.

    Declared fields:
    name -- str; shown in the header banner.
    endpoint -- str.
    src -- optional list of str.
    tests -- ordered dict; maps "file_pattern[:parameter]" to the name of a
             module inside ``_TESTS_PACKAGE`` that loads matching files.
    protocols -- list of str; module names inside ``_PROTOCOL_PACKAGE``.
    """
    name = core.String()
    endpoint = core.String()
    src = core.List(type=str, optional=True)
    tests = core.Dict(keys=str, values=str, ordered=True)
    protocols = core.List(type=str)

    _TESTS_PACKAGE = 'client.sources'
    _PROTOCOL_PACKAGE = 'client.protocols'

    def __init__(self, cmd_args, **fields):
        # NOTE(review): **fields is not stored here; presumably the
        # core.Serializable machinery assigns the declared fields -- confirm.
        self.cmd_args = cmd_args
        self.test_map = collections.OrderedDict()  # test name -> test object
        self.protocol_map = collections.OrderedDict()  # name -> protocol
        self.specified_tests = []  # test objects selected to run

    def post_instantiation(self):
        """Print the header, then load tests and protocols and pick tests."""
        self._print_header()
        self._load_tests()
        self._load_protocols()
        self._resolve_specified_tests()

    def _load_tests(self):
        """Load every test matched by the patterns in ``self.tests``.

        Each key of ``self.tests`` is a UNIX-style file pattern, optionally
        followed by ``:parameter``. Its value names a module in
        ``_TESTS_PACKAGE`` whose ``load(file, parameter, cmd_args)`` builds
        the test for one matching file. Results are stored in
        ``self.test_map`` keyed by the file path, with ``:parameter``
        appended when a parameter was given.

        RAISES:
        ex.LoadingException -- if a pattern matches no files, or its source
                               module cannot be imported.
        """
        log.info('Loading tests')
        for file_pattern, source in self.tests.items():
            # Separate filepath and parameter
            if ':' in file_pattern:
                file_pattern, parameter = file_pattern.split(':', maxsplit=1)
            else:
                parameter = ''

            # glob's order depends on the filesystem; sort so tests load (and
            # run by default) in a stable, predictable order.
            files = sorted(glob.glob(file_pattern))
            if not files:
                error_msg = 'No tests found for pattern: {}'.format(
                    file_pattern)
                print(error_msg)
                raise ex.LoadingException(error_msg)

            # The source module is the same for every file matched by this
            # pattern, so import it once rather than once per file.
            try:
                module = importlib.import_module(self._TESTS_PACKAGE +
                                                 '.' + source)
            except ImportError as e:
                raise ex.LoadingException(
                    'Invalid test source: {}'.format(source)) from e

            for file in files:
                test_name = file
                if parameter:
                    test_name += ':' + parameter
                self.test_map[test_name] = module.load(file, parameter,
                                                       self.cmd_args)
                log.info('Loaded {}'.format(test_name))

    def dump_tests(self):
        """Serialize every loaded test through its ``.dump()`` method.

        Each test in ``self.test_map`` is passed the key it was loaded under.
        An ``ex.SerializeException`` from one test is logged and skipped, so
        the remaining tests are still dumped.
        """
        log.info('Dumping tests')
        # NOTE(review): keys carry a ":parameter" suffix when the config gave
        # one, so ``file`` may not be a plain path -- confirm dump() copes.
        for file, test in self.test_map.items():
            try:
                test.dump(file)
            except ex.SerializeException as e:
                log.info('Unable to dump {} to {}: {}'.format(
                    test.name, file, str(e)))
            else:
                log.info('Dumped {} to {}'.format(test.name, file))

    def _resolve_specified_tests(self):
        """Fill ``self.specified_tests`` from the command-line questions.

        If no questions were given, every loaded test is used. Otherwise
        each question is resolved, in command-line order, to exactly one
        test. An exact test name is used directly. Failing that, the question
        must be a case-insensitive subsequence of exactly one test name. Each
        test is added at most once.

        RAISES:
        ex.LoadingException -- if a question matches no test, or matches
                               more than one test as a subsequence.
        """
        if not self.cmd_args.question:
            log.info('Using all tests (no questions specified)')
            self.specified_tests = list(self.test_map.values())
            return
        elif not self.test_map:
            log.info('No tests loaded')
            return
        for question in self.cmd_args.question:
            if question in self.test_map:
                # An exact name is unambiguous even when it is also a
                # subsequence of other names (e.g. "tests/q1.py" inside
                # "tests/q10.py").
                matches = [question]
            else:
                lowered = question.lower()
                matches = [test for test in self.test_map
                           if _has_subsequence(test.lower(), lowered)]

            if len(matches) > 1:
                print('Did you mean one of the following?')
                for test in matches:
                    print('    {}'.format(test))
                raise ex.LoadingException(
                    'Ambiguous test specified: {}'.format(question))

            elif not matches:
                print('Did you mean one of the following?')
                for test in self.test_map:
                    print('    {}'.format(test))
                raise ex.LoadingException(
                    'Invalid test specified: {}'.format(question))

            match = matches[0]
            log.info('Matched {} to {}'.format(question, match))
            selected = self.test_map[match]
            # specified_tests holds test objects, not names, so compare the
            # object itself to skip duplicates.
            if selected not in self.specified_tests:
                self.specified_tests.append(selected)

    def _load_protocols(self):
        """Import and instantiate every protocol named in ``self.protocols``.

        RAISES:
        ex.LoadingException -- if a protocol module cannot be imported.
        """
        log.info('Loading protocols')
        for proto in self.protocols:
            try:
                module = importlib.import_module(self._PROTOCOL_PACKAGE + '.' +
                                                 proto)
            except ImportError as e:
                raise ex.LoadingException(
                    'Invalid protocol: {}'.format(proto)) from e

            self.protocol_map[proto] = module.protocol(self.cmd_args, self)
            log.info('Loaded protocol "{}"'.format(proto))

    def _print_header(self):
        """Print a banner with the assignment name and OK client version."""
        format.print_line('=')
        print('Assignment: {}'.format(self.name))
        print('OK, version {}'.format(client.__version__))
        format.print_line('=')
        print()