예제 #1
0
    def testConstructor_homogeneousSubclass(self):
        """A List declared with type=int must accept int-subclass elements."""
        class IntSubclass(int):
            def __init__(self):
                pass

        int_list = core.List(type=int)
        values = [1, IntSubclass()]
        self.assertTrue(int_list.is_valid(values))
예제 #2
0
class MockSerializable(core.Serializable):
    """Serializable fixture declaring one field of several core field kinds."""

    # Default value given to var2 below.
    TEST_INT = 2

    # Boolean field declared without a default and without optional=True.
    var1 = core.Boolean()
    # Int field whose default is TEST_INT.
    var2 = core.Int(default=TEST_INT)
    # String field declared with optional=True.
    var3 = core.String(optional=True)
    # Untyped List field declared with optional=True.
    var4 = core.List(optional=True)
예제 #3
0
    def testToJson_recursive(self):
        """List.to_json is expected to use an element's own to_json()."""
        class Recursive(object):
            def to_json(self):
                return ListFieldTest.TEST_INT

        field = core.List()
        elements = [1, Recursive(), True]
        expected = [1, self.TEST_INT, True]

        self.assertEqual(expected, field.to_json(elements))
예제 #4
0
class Suite(core.Serializable):
    """Base class for a suite of test cases.

    Subclasses implement run(); _run_case() is the shared helper that runs a
    single case while capturing its output.
    """
    type = core.String()
    scored = core.Boolean(default=True)
    cases = core.List()

    def __init__(self, verbose, interactive, timeout=None, **fields):
        """Store the runtime options; remaining keywords are passed on as
        serializable fields.

        PARAMETERS:
        verbose     -- stored as self.verbose.
        interactive -- stored as self.interactive.
        timeout     -- stored as self.timeout; defaults to None.
        fields      -- forwarded to core.Serializable.__init__.
        """
        super().__init__(**fields)
        self.verbose = verbose
        self.interactive = interactive
        self.timeout = timeout

    def run(self, test_name, suite_number):
        """Subclasses should override this method to run tests.

        PARAMETERS:
        test_name    -- str; name of the parent test.
        suite_number -- int; suite number, assumed to be 1-indexed.

        RETURNS:
        dict; results of the following form:
        {
            'passed': int,
            'failed': int,
            'locked': int,
        }
        """
        raise NotImplementedError

    def _run_case(self, test_name, suite_number, case, case_number):
        """A wrapper for case.run().

        Prints informative output and also captures output of the test case
        and returns it as a log. The output is suppressed -- it is up to the
        calling function to decide whether or not to print the log.

        If case.run() (or any of the printing) raises, printing is switched
        back on and the log is discarded before the exception propagates, so
        a failing or interrupted case cannot leave output disabled.

        RETURNS:
        (success, output_log) -- the value returned by case.run() and the
        log captured while it ran.
        """
        output.off()    # Delay printing until case status is determined.
        log_id = output.new_log()

        try:
            format.print_line('-')
            print('{} > Suite {} > Case {}'.format(test_name, suite_number,
                                                   case_number))
            print()

            success = case.run()
            if success:
                print('-- OK! --')
        except BaseException:
            # Includes KeyboardInterrupt: restore the output state first,
            # then let the caller see the original exception.
            output.on()
            output.remove_log(log_id)
            raise

        output.on()
        output_log = output.get_log(log_id)
        output.remove_log(log_id)

        return success, output_log
예제 #5
0
class ConceptCase(common_models.Case):
    """A conceptual (question / answer) test case."""
    question = core.String()
    answer = core.String()
    choices = core.List(type=str, optional=True)

    def post_instantiation(self):
        """Dedent and strip the question, the answer and every choice."""
        self.question = textwrap.dedent(self.question).strip()
        self.answer = textwrap.dedent(self.answer).strip()

        if self.choices != core.NoValue:
            # Rewrite the list in place so the same list object is kept.
            for index, text in enumerate(self.choices):
                self.choices[index] = textwrap.dedent(text).strip()

    def run(self):
        """Runs the conceptual test case.

        Prints the question and the stored answer.

        RETURNS:
        bool; always True.
        """
        for label, text in (('Q: ', self.question), ('A: ', self.answer)):
            print(label + text)
        return True

    def lock(self, hash_fn):
        """Replace the answer with hash_fn(answer) and mark the case locked."""
        hashed = hash_fn(self.answer)
        self.answer = hashed
        self.locked = True

    def unlock(self, unique_id_prefix, case_id, interact):
        """Unlocks the conceptual test case.

        Prints the question, asks `interact` for the answer, and, if the
        returned answer differs from the stored one, stores it and marks the
        case as unlocked.
        """
        print('Q: ' + self.question)
        prompt_id = unique_id_prefix + '\n' + self.question
        responses = interact(prompt_id, case_id, self.question,
                             [self.answer], self.choices)
        assert len(responses) == 1
        response = responses[0]
        if response != self.answer:
            # Answer was presumably unlocked
            self.locked = False
            self.answer = response
예제 #6
0
class Assignment(core.Serializable):
    name = core.String()
    endpoint = core.String(optional=True, default='')
    decryption_keypage = core.String(optional=True, default='')
    src = core.List(type=str, optional=True)
    tests = core.Dict(keys=str, values=str, ordered=True)
    default_tests = core.List(type=str, optional=True)
    # ignored, for backwards-compatibility only
    protocols = core.List(type=str, optional=True)

    ####################
    # Programmatic API #
    ####################

    def grade(self, question, env=None, skip_locked_cases=False):
        """Runs tests for a particular question. The setup and teardown will
        always be executed.

        question -- str; a question name (as would be entered at the command
                    line)
        env      -- dict; an environment in which to execute the tests. If
                    None, uses the environment of __main__. The original
                    dictionary is never modified; each test is given a
                    duplicate of env.
        skip_locked_cases -- bool; if False, locked cases will be tested

        Returns: the entry stored under messages['grading'] for the first
        resolved test. Presumably a results dictionary with the following
        fields (TODO confirm against the grading code):
        - "passed": int (number of test cases passed)
        - "failed": int (number of test cases failed)
        - "locked": int (number of test cases locked)
        """
        if env is None:
            import __main__
            env = __main__.__dict__
        messages = {}
        tests = self._resolve_specified_tests([question], all_tests=False)
        for test in tests:
            try:
                for suite in test.suites:
                    suite.skip_locked_cases = skip_locked_cases
                    suite.console.skip_locked_cases = skip_locked_cases
                    suite.console.hash_key = self.name
            except AttributeError:
                # Best effort: a test (or suite) lacking one of these
                # attributes is left as it is.
                pass
        # NOTE(review): raises IndexError if no test was resolved -- confirm
        # callers never reach this with an empty test map.
        test_name = tests[0].name
        # This is the module-level grade() function, not this method.
        grade(tests, messages, env)
        return messages['grading'][test_name]

    ##############
    # Encryption #
    ##############

    def generate_encryption_key(self, keys_file):
        """Generate one key per assignment file and write them to keys_file.

        The file receives a JSON list of (filename, key) pairs, one pair for
        each path returned by _get_files().
        """
        pairs = []
        for filename in self._get_files():
            pairs.append((filename, encryption.generate_key()))
        with open(keys_file, "w") as handle:
            json.dump(pairs, handle)

    def encrypt(self, keys_file, padding):
        """
        Encrypt each question and test, with the given keys file, which contains (file, key) pairs
        """
        with open(keys_file) as f:
            keys = dict(json.load(f))
        for file in self._get_files():
            if file in keys:
                self._encrypt_file(file, keys[file], padding)

    def decrypt(self, keys):
        """Attempt decryption with the given keys and report the outcome.

        Warns when there was nothing to decrypt and prints an error listing
        the files that could not be decrypted; prints nothing when every
        encrypted file was decrypted.
        """
        decrypted, remaining = self.attempt_decryption(keys)
        if not remaining + decrypted:
            print_warning("All files are already decrypted")
            return
        if not remaining:
            return
        if keys:
            print_error("Unable to decrypt some files with the keys",
                        ", ".join(keys))
        else:
            print_error("No keys found, could not decrypt any files")
        print_error("    Non-decrypted files:", *remaining)

    def attempt_decryption(self, keys):
        """Try to decrypt every encrypted assignment file.

        When a decryption keypage is configured, the keys parsed from it are
        appended to `keys`; any failure while fetching or parsing the page is
        reported and decryption continues with the keys passed in.

        RETURNS:
        (decrypted_files, undecrypted_files) -- two lists of paths. Files
        that were not encrypted appear in neither list.
        """
        if self.decryption_keypage:
            try:
                page = requests.get(self.decryption_keypage)
                page.raise_for_status()
                page_text = page.content.decode('utf-8')
                keys = keys + encryption.get_keys(page_text)
            except Exception as e:
                print_error("Could not load decryption page {}: {}.".format(
                    self.decryption_keypage, e))
                print_error(
                    "You can pass in a key directly by running python3 ok --decrypt [KEY]"
                )

        decrypted, pending = [], []
        for path in self._get_files():
            with open(path) as handle:
                contents = handle.read()
            if not encryption.is_encrypted(contents):
                continue
            # any() stops at the first key that works, like the break did.
            if any(self._decrypt_file(path, key) for key in keys):
                decrypted.append(path)
            else:
                pending.append(path)
        return decrypted, pending

    def _decrypt_file(self, path, key):
        """
        Decrypt the file at `path` in place with `key`.

        The file contents are left unchanged when they are not encrypted or
        when the key does not match.

        RETURNS:
        bool; True only if the contents were decrypted with this key.
        """
        decrypted = False

        def transform(text):
            nonlocal decrypted
            if not encryption.is_encrypted(text):
                return text
            try:
                plaintext = encryption.decrypt(text, key)
                decrypted = True
                print_success("decrypted", path, "with", key)
                return plaintext
            except encryption.InvalidKeyException:
                return text

        self._in_place_edit(path, transform)
        return decrypted

    def _encrypt_file(self, path, key, padding):
        """
        Encrypt the file at `path` in place with `key`.

        Contents that are already encrypted are first decrypted with the same
        key and then encrypted again.

        RAISES:
        ValueError -- if the file is already encrypted and `key` cannot
                      decrypt it.
        """
        def transform(data):
            if not encryption.is_encrypted(data):
                return encryption.encrypt(data, key, padding)
            try:
                plaintext = encryption.decrypt(data, key)
            except encryption.InvalidKeyException:
                raise ValueError(
                    "Attempt to re-encrypt file with an invalid key")
            return encryption.encrypt(plaintext, key, padding)

        self._in_place_edit(path, transform)

    @staticmethod
    def _in_place_edit(path, func):
        """
        Edit the given file in place, atomically. `func` is a function that modifies the data in the file.
        """
        with open(path) as f:
            data = f.read()
        ciphertext = func(data)
        temporary_file = "." + uuid.uuid4().hex
        with open(temporary_file, "w") as f:
            f.write(ciphertext)
        # atomic rename
        os.replace(temporary_file, path)

    def _get_files(self):
        """
        Get all the test and submission source files associated with this assignment, deduplicated
        """
        tests = [
            file for k, v in self.tests.items() for file in glob.glob(k)
            if v == 'ok_test' or v == 'scheme_test'
        ]
        src = list(self.src)
        return sorted(set(tests + src))

    @property
    def server_url(self):
        scheme = 'http' if self.cmd_args.insecure else 'https'
        return '{}://{}'.format(scheme, self.cmd_args.server)

    ############
    # Internal #
    ############

    _TESTS_PACKAGE = 'client.sources'
    _PROTOCOL_PACKAGE = 'client.protocols'

    # A list of all protocols that should be loaded. Order is important.
    # Dependencies:
    # analytics     -> grading
    # autostyle     -> analytics, grading
    # backup        -> all other protocols
    # collaborate   -> file_contents, analytics
    # file_contents -> none
    # grading       -> rate_limit
    # hinting       -> file_contents, analytics
    # lock          -> none
    # rate_limit    -> none
    # scoring       -> none
    # trace         -> file_contents
    # unlock        -> none
    # testing       -> none
    _PROTOCOLS = [
        "testing",
        # "rate_limit", uncomment to turn rate limiting back on!
        "file_contents",
        "grading",
        "analytics",
        "autostyle",
        "collaborate",
        "hinting",
        "lock",
        "scoring",
        "unlock",
        "trace",
        "backup",
    ]

    def __init__(self, args, **fields):
        """Store the command-line arguments and create empty test/protocol maps.

        args   -- command-line arguments object; kept as self.cmd_args.
        fields -- accepted but not used in this method.
                  NOTE(review): presumably the serializable fields are handled
                  by the core.Serializable machinery -- confirm.
        """
        self.cmd_args = args
        # Ordered so tests and protocols keep the order in which they load.
        self.test_map = collections.OrderedDict()
        self.protocol_map = collections.OrderedDict()

    def post_instantiation(self):
        self._print_header()
        self._load_tests()
        self._load_protocols()
        self.specified_tests = self._resolve_specified_tests(
            self.cmd_args.question, self.cmd_args.all)

    def set_args(self, **kwargs):
        """Set command-line arguments programmatically. For example:

            assignment.set_args(
                server='http://localhost:5000',
                no_browser=True,
                backup=True,
                timeout=60,
            )
        """
        self.cmd_args.update(**kwargs)

    def authenticate(self, force=False, inline=False):
        """Authenticate through the auth module and return its result.

        force  -- passed through to the auth call.
        inline -- if true, use auth.notebook_authenticate (which is not given
                  the endpoint); otherwise use auth.authenticate.
        """
        if inline:
            return auth.notebook_authenticate(self.cmd_args, force=force)
        return auth.authenticate(self.cmd_args,
                                 endpoint=self.endpoint,
                                 force=force)

    def get_student_email(self):
        """Return auth.get_student_email() for this assignment's endpoint."""
        endpoint = self.endpoint
        return auth.get_student_email(self.cmd_args, endpoint=endpoint)

    def get_identifier(self):
        """Return auth.get_identifier() for this assignment's endpoint."""
        endpoint = self.endpoint
        return auth.get_identifier(self.cmd_args, endpoint=endpoint)

    def is_empty_init(self, path):
        if os.path.basename(path) != '__init__.py':
            return False

        with open(path) as f:
            contents = f.read()

        return contents.strip() == ""

    def _load_tests(self):
        """Loads all tests described by the `tests` field into test_map.

        Each key of self.tests is a file pattern, optionally followed by ':'
        and a parameter; each value is a comma-separated list of source
        module names inside _TESTS_PACKAGE. Every matching file (except an
        empty __init__.py) is loaded with every listed source, and the
        parameter is passed to each of them.

        RAISES:
        ex.LoadingException -- if a source module cannot be imported.
        """
        log.info('Loading tests')
        for pattern_spec, sources in self.tests.items():
            # Separate filepath and parameter once per entry. Doing this
            # inside the source loop would strip the parameter on the first
            # pass and lose it for every later source.
            file_pattern, _, parameter = pattern_spec.partition(':')

            for source in sources.split(","):
                for file in sorted(glob.glob(file_pattern)):
                    if self.is_empty_init(file):
                        continue
                    try:
                        module = importlib.import_module(self._TESTS_PACKAGE +
                                                         '.' + source)
                    except ImportError:
                        raise ex.LoadingException(
                            'Invalid test source: {}'.format(source))

                    test_name = file
                    if parameter:
                        test_name += ':' + parameter

                    self.test_map.update(module.load(file, parameter, self))

                    log.info('Loaded {}'.format(test_name))

    def dump_tests(self):
        """Dumps all tests in test_map by calling each test's .dump() method.

        A test whose dump raises ex.SerializeException is logged as a warning
        and skipped; the remaining tests are still dumped.
        """
        log.info('Dumping tests')
        for test in self.test_map.values():
            try:
                test.dump()
            except ex.SerializeException as e:
                log.warning('Unable to dump {}: {}'.format(test.name, str(e)))
            else:
                log.info('Dumped {}'.format(test.name))

    def _resolve_specified_tests(self, questions, all_tests=False):
        """For each of the questions specified on the command line,
        find the test corresponding that question.

        Questions are preserved in the order that they are specified on the
        command line. If no questions are specified, use the entire set of
        tests.
        """
        if not questions and not all_tests \
                and self.default_tests != core.NoValue \
                and len(self.default_tests) > 0:
            log.info('Using default tests (no questions specified): '
                     '{}'.format(self.default_tests))
            bad_tests = sorted(test for test in self.default_tests
                               if test not in self.test_map)
            if bad_tests:
                error_message = (
                    "Required question(s) missing: {}. "
                    "This often is the result of accidentally deleting the question's doctests or the entire function."
                )
                raise ex.LoadingException(
                    error_message.format(", ".join(bad_tests)))
            return [self.test_map[test] for test in self.default_tests]
        elif not questions:
            log.info(
                'Using all tests (no questions specified and no default tests)'
            )
            return list(self.test_map.values())
        elif not self.test_map:
            log.info('No tests loaded')
            return []

        specified_tests = []
        for question in questions:
            if question not in self.test_map:
                raise ex.InvalidTestInQuestionListException(
                    list(self.test_map), question)

            log.info('Adding {} to specified tests'.format(question))
            if question not in specified_tests:
                specified_tests.append(self.test_map[question])
        return specified_tests

    def _load_protocols(self):
        """Import every module named in _PROTOCOLS, in order, and store the
        object built by its protocol(cmd_args, assignment) in protocol_map."""
        log.info('Loading protocols')
        for name in self._PROTOCOLS:
            module_path = self._PROTOCOL_PACKAGE + '.' + name
            protocol_module = importlib.import_module(module_path)
            self.protocol_map[name] = protocol_module.protocol(
                self.cmd_args, self)
            log.info('Loaded protocol "{}"'.format(name))

    def _print_header(self):
        """Print the banner with the assignment name and client version,
        framed by '=' lines and followed by a blank line."""
        format.print_line('=')
        banner = (
            'Assignment: {}'.format(self.name),
            'OK, version {}'.format(client.__version__),
        )
        for line in banner:
            print(line)
        format.print_line('=')
        print()
예제 #7
0
class Assignment(core.Serializable):
    name = core.String()
    endpoint = core.String(optional=True, default='')
    src = core.List(type=str, optional=True)
    tests = core.Dict(keys=str, values=str, ordered=True)
    default_tests = core.List(type=str, optional=True)
    # ignored, for backwards-compatibility only
    protocols = core.List(type=str, optional=True)

    ####################
    # Programmatic API #
    ####################

    def grade(self, question, env=None, skip_locked_cases=False):
        """Runs tests for a particular question. The setup and teardown will
        always be executed.

        question -- str; a question name (as would be entered at the command
                    line)
        env      -- dict; an environment in which to execute the tests. If
                    None, uses the environment of __main__. The original
                    dictionary is never modified; each test is given a
                    duplicate of env.
        skip_locked_cases -- bool; if False, locked cases will be tested

        Returns: the entry stored under messages['grading'] for the first
        resolved test. Presumably a results dictionary with the following
        fields (TODO confirm against the grading code):
        - "passed": int (number of test cases passed)
        - "failed": int (number of test cases failed)
        - "locked": int (number of test cases locked)
        """
        if env is None:
            import __main__
            env = __main__.__dict__
        messages = {}
        tests = self._resolve_specified_tests([question], all_tests=False)
        for test in tests:
            try:
                for suite in test.suites:
                    suite.skip_locked_cases = skip_locked_cases
                    suite.console.skip_locked_cases = skip_locked_cases
                    suite.console.hash_key = self.name
            except AttributeError:
                # Best effort: a test (or suite) lacking one of these
                # attributes is left as it is.
                pass
        # NOTE(review): raises IndexError if no test was resolved -- confirm
        # callers never reach this with an empty test map.
        test_name = tests[0].name
        # This is the module-level grade() function, not this method.
        grade(tests, messages, env)
        return messages['grading'][test_name]

    @property
    def server_url(self):
        scheme = 'http' if self.cmd_args.insecure else 'https'
        return '{}://{}'.format(scheme, self.cmd_args.server)

    ############
    # Internal #
    ############

    _TESTS_PACKAGE = 'client.sources'
    _PROTOCOL_PACKAGE = 'client.protocols'

    # A list of all protocols that should be loaded. Order is important.
    # Dependencies:
    # analytics     -> grading
    # autostyle     -> analytics, grading
    # backup        -> all other protocols
    # collaborate   -> file_contents, analytics
    # file_contents -> none
    # grading       -> rate_limit
    # hinting       -> file_contents, analytics
    # lock          -> none
    # rate_limit    -> none
    # scoring       -> none
    # trace         -> file_contents
    # unlock        -> none
    # testing       -> none
    _PROTOCOLS = [
        "testing",
        # "rate_limit", uncomment to turn rate limiting back on!
        "file_contents",
        "grading",
        "analytics",
        "autostyle",
        "collaborate",
        "hinting",
        "lock",
        "scoring",
        "unlock",
        "trace",
        "backup",
    ]

    def __init__(self, args, **fields):
        """Store the command-line arguments and create empty test/protocol maps.

        args   -- command-line arguments object; kept as self.cmd_args.
        fields -- accepted but not used in this method.
                  NOTE(review): presumably the serializable fields are handled
                  by the core.Serializable machinery -- confirm.
        """
        self.cmd_args = args
        # Ordered so tests and protocols keep the order in which they load.
        self.test_map = collections.OrderedDict()
        self.protocol_map = collections.OrderedDict()

    def post_instantiation(self):
        self._print_header()
        self._load_tests()
        self._load_protocols()
        self.specified_tests = self._resolve_specified_tests(
            self.cmd_args.question, self.cmd_args.all)

    def set_args(self, **kwargs):
        """Set command-line arguments programmatically. For example:

            assignment.set_args(
                server='http://localhost:5000',
                no_browser=True,
                backup=True,
                timeout=60,
            )
        """
        self.cmd_args.update(**kwargs)

    def authenticate(self, force=False, inline=False):
        """Authenticate through the auth module and return its result.

        force  -- passed through to the auth call.
        inline -- if true, use auth.notebook_authenticate (which is not given
                  the endpoint); otherwise use auth.authenticate.
        """
        if inline:
            return auth.notebook_authenticate(self.cmd_args, force=force)
        return auth.authenticate(self.cmd_args,
                                 endpoint=self.endpoint,
                                 force=force)

    def get_student_email(self):
        """Return auth.get_student_email() for this assignment's endpoint."""
        endpoint = self.endpoint
        return auth.get_student_email(self.cmd_args, endpoint=endpoint)

    def get_identifier(self):
        """Return auth.get_identifier() for this assignment's endpoint."""
        endpoint = self.endpoint
        return auth.get_identifier(self.cmd_args, endpoint=endpoint)

    def _load_tests(self):
        """Loads all tests described by the `tests` field into test_map.

        Each key of self.tests is a file pattern, optionally followed by ':'
        and a parameter; each value is a comma-separated list of source
        module names inside _TESTS_PACKAGE. Every matching file is loaded
        with every listed source, and the parameter is passed to each of
        them.

        RAISES:
        ex.LoadingException -- if a source module cannot be imported.
        """
        log.info('Loading tests')
        for pattern_spec, sources in self.tests.items():
            # Separate filepath and parameter once per entry. Doing this
            # inside the source loop would strip the parameter on the first
            # pass and lose it for every later source.
            file_pattern, _, parameter = pattern_spec.partition(':')

            for source in sources.split(","):
                for file in sorted(glob.glob(file_pattern)):
                    try:
                        module = importlib.import_module(self._TESTS_PACKAGE +
                                                         '.' + source)
                    except ImportError:
                        raise ex.LoadingException(
                            'Invalid test source: {}'.format(source))

                    test_name = file
                    if parameter:
                        test_name += ':' + parameter

                    self.test_map.update(module.load(file, parameter, self))

                    log.info('Loaded {}'.format(test_name))

    def dump_tests(self):
        """Dumps all tests in test_map by calling each test's .dump() method.

        A test whose dump raises ex.SerializeException is logged as a warning
        and skipped; the remaining tests are still dumped.
        """
        log.info('Dumping tests')
        for test in self.test_map.values():
            try:
                test.dump()
            except ex.SerializeException as e:
                log.warning('Unable to dump {}: {}'.format(test.name, str(e)))
            else:
                log.info('Dumped {}'.format(test.name))

    def _resolve_specified_tests(self, questions, all_tests=False):
        """For each of the questions specified on the command line,
        find the test corresponding that question.

        Questions are preserved in the order that they are specified on the
        command line. If no questions are specified, use the entire set of
        tests.
        """
        if not questions and not all_tests \
                and self.default_tests != core.NoValue \
                and len(self.default_tests) > 0:
            log.info('Using default tests (no questions specified): '
                     '{}'.format(self.default_tests))
            bad_tests = sorted(test for test in self.default_tests
                               if test not in self.test_map)
            if bad_tests:
                error_message = (
                    "Required question(s) missing: {}. "
                    "This often is the result of accidentally deleting the question's doctests or the entire function."
                )
                raise ex.LoadingException(
                    error_message.format(", ".join(bad_tests)))
            return [self.test_map[test] for test in self.default_tests]
        elif not questions:
            log.info(
                'Using all tests (no questions specified and no default tests)'
            )
            return list(self.test_map.values())
        elif not self.test_map:
            log.info('No tests loaded')
            return []

        specified_tests = []
        for question in questions:
            if question not in self.test_map:
                print('Test "{}" not found.'.format(question))
                print('Did you mean one of the following? '
                      '(Names are case sensitive)')
                for test in self.test_map:
                    print('    {}'.format(test))
                raise ex.LoadingException(
                    'Invalid test specified: {}'.format(question))

            log.info('Adding {} to specified tests'.format(question))
            if question not in specified_tests:
                specified_tests.append(self.test_map[question])
        return specified_tests

    def _load_protocols(self):
        """Import every module named in _PROTOCOLS, in order, and store the
        object built by its protocol(cmd_args, assignment) in protocol_map."""
        log.info('Loading protocols')
        for name in self._PROTOCOLS:
            module_path = self._PROTOCOL_PACKAGE + '.' + name
            protocol_module = importlib.import_module(module_path)
            self.protocol_map[name] = protocol_module.protocol(
                self.cmd_args, self)
            log.info('Loaded protocol "{}"'.format(name))

    def _print_header(self):
        """Print the banner with the assignment name and client version,
        framed by '=' lines and followed by a blank line."""
        format.print_line('=')
        banner = (
            'Assignment: {}'.format(self.name),
            'OK, version {}'.format(client.__version__),
        )
        for line in banner:
            print(line)
        format.print_line('=')
        print()
예제 #8
0
class Assignment(core.Serializable):
    name = core.String()
    endpoint = core.String()
    src = core.List(type=str, optional=True)
    tests = core.Dict(keys=str, values=str, ordered=True)
    protocols = core.List(type=str)

    _TESTS_PACKAGE = 'client.sources'
    _PROTOCOL_PACKAGE = 'client.protocols'

    def __init__(self, cmd_args, **fields):
        """Store the command-line arguments and create empty containers.

        cmd_args -- command-line arguments object; kept as self.cmd_args.
        fields   -- accepted but not used in this method.
                    NOTE(review): presumably the serializable fields are
                    handled by the core.Serializable machinery -- confirm.
        """
        self.cmd_args = cmd_args
        # Ordered so tests and protocols keep the order in which they load.
        self.test_map = collections.OrderedDict()
        self.protocol_map = collections.OrderedDict()
        # Filled in by _resolve_specified_tests().
        self.specified_tests = []

    def post_instantiation(self):
        self._print_header()
        self._load_tests()
        self._load_protocols()
        self._resolve_specified_tests()

    def _load_tests(self):
        """Loads all tests described by the `tests` field into test_map.

        Each key of self.tests is a UNIX-style file pattern, optionally
        followed by ':' and a parameter; each value names a source module
        inside _TESTS_PACKAGE. Every file matching the pattern is loaded with
        that module's load() function. Files are loaded in sorted order so
        the order of test_map does not depend on the platform.

        RAISES:
        ex.LoadingException -- if a pattern matches no file or a source
        module cannot be imported.
        """
        log.info('Loading tests')
        for pattern_spec, source in self.tests.items():
            # Separate filepath and parameter
            file_pattern, _, parameter = pattern_spec.partition(':')

            files = sorted(glob.glob(file_pattern))
            if not files:
                error_msg = 'No tests found for pattern: {}'.format(
                    file_pattern)
                print(error_msg)
                raise ex.LoadingException(error_msg)

            # At least one file matched, so the module is needed; import it
            # once instead of once per file.
            try:
                module = importlib.import_module(self._TESTS_PACKAGE +
                                                 '.' + source)
            except ImportError:
                raise ex.LoadingException(
                    'Invalid test source: {}'.format(source))

            for file in files:
                test_name = file
                if parameter:
                    test_name += ':' + parameter
                self.test_map[test_name] = module.load(file, parameter,
                                                       self.cmd_args)
                log.info('Loaded {}'.format(test_name))

    def dump_tests(self):
        """Dumps all tests in test_map by calling test.dump(file) with the
        test_map key under which each test is stored.

        A test whose dump raises ex.SerializeException is logged and skipped;
        the remaining tests are still dumped.
        """
        log.info('Dumping tests')
        for file, test in self.test_map.items():
            try:
                test.dump(file)
            except ex.SerializeException as e:
                log.info('Unable to dump {} to {}: {}'.format(
                    test.name, file, str(e)))
            else:
                log.info('Dumped {} to {}'.format(test.name, file))

    def _resolve_specified_tests(self):
        """For each of the questions specified on the command line,
        find the test corresponding to that question.

        A test is a candidate for a question when _has_subsequence() accepts
        the lower-cased test name and the lower-cased question. Exactly one
        candidate is required per question.

        Questions are preserved in the order that they are specified on the
        command line, and each matched test is added to specified_tests at
        most once. If no questions are specified, use the entire set of
        tests.

        RAISES:
        ex.LoadingException -- if a question matches several tests or none
        (the candidate names are printed first).
        """
        if not self.cmd_args.question:
            log.info('Using all tests (no questions specified)')
            self.specified_tests = list(self.test_map.values())
            return
        elif not self.test_map:
            log.info('No tests loaded')
            return

        # Names already added. specified_tests holds test objects, so
        # checking a test name against it would never detect a repeat.
        selected = set()
        for question in self.cmd_args.question:
            lowered = question.lower()
            matches = [test for test in self.test_map
                       if _has_subsequence(test.lower(), lowered)]

            if len(matches) > 1:
                print('Did you mean one of the following?')
                for test in matches:
                    print('    {}'.format(test))
                raise ex.LoadingException(
                    'Ambiguous test specified: {}'.format(question))

            elif not matches:
                print('Did you mean one of the following?')
                for test in self.test_map:
                    print('    {}'.format(test))
                raise ex.LoadingException(
                    'Invalid test specified: {}'.format(question))

            match = matches[0]
            log.info('Matched {} to {}'.format(question, match))
            if match not in selected:
                selected.add(match)
                self.specified_tests.append(self.test_map[match])

    def _load_protocols(self):
        """Import every module named in the `protocols` field and store the
        object built by its protocol(cmd_args, assignment) in protocol_map.

        RAISES:
        ex.LoadingException -- if a protocol module cannot be imported.
        """
        log.info('Loading protocols')
        for name in self.protocols:
            module_path = self._PROTOCOL_PACKAGE + '.' + name
            try:
                protocol_module = importlib.import_module(module_path)
            except ImportError:
                raise ex.LoadingException('Invalid protocol: {}'.format(name))

            self.protocol_map[name] = protocol_module.protocol(
                self.cmd_args, self)
            log.info('Loaded protocol "{}"'.format(name))

    def _print_header(self):
        """Print the banner with the assignment name and client version,
        framed by '=' lines and followed by a blank line."""
        format.print_line('=')
        banner = (
            'Assignment: {}'.format(self.name),
            'OK, version {}'.format(client.__version__),
        )
        for line in banner:
            print(line)
        format.print_line('=')
        print()
예제 #9
0
class Assignment(core.Serializable):
    """Serializable assignment configuration plus its loaded tests/protocols.

    The class attributes below declare the serialized fields. The runtime
    state (cmd_args, test_map, protocol_map, specified_tests) is created in
    __init__ and filled in by load().
    """
    name = core.String()
    endpoint = core.String()
    src = core.List(type=str, optional=True)
    tests = core.Dict(keys=str, values=str, ordered=True)
    default_tests = core.List(type=str, optional=True)
    protocols = core.List(type=str)

    # Packages from which test-source and protocol modules are imported.
    _TESTS_PACKAGE = 'client.sources'
    _PROTOCOL_PACKAGE = 'client.protocols'

    def __init__(self, cmd_args, **fields):
        """Store the command-line arguments and create empty runtime state.

        PARAMETERS:
        cmd_args -- parsed command-line arguments; this class reads its
                    .question and .all attributes.
        fields   -- serialized field values (not used directly here).
        """
        self.cmd_args = cmd_args
        self.test_map = collections.OrderedDict()
        self.protocol_map = collections.OrderedDict()
        self.specified_tests = []

    def post_instantiation(self):
        """Print the assignment banner."""
        self._print_header()

    def load(self):
        """Load tests and protocols, then resolve the tests to run."""
        self._load_tests()
        self._load_protocols()
        self._resolve_specified_tests()

    def _load_tests(self):
        """Load every test described by the `tests` field into test_map.

        Each key of self.tests is a glob pattern, optionally followed by
        ':<parameter>'; each value names a module inside _TESTS_PACKAGE
        whose load(file, parameter, assignment) result is merged into
        self.test_map.

        RAISES:
        ex.LoadingException -- if a source module cannot be imported or if
                               no tests were loaded at all.
        """
        log.info('Loading tests')
        for file_pattern, source in self.tests.items():
            # Separate filepath and parameter
            if ':' in file_pattern:
                file_pattern, parameter = file_pattern.split(':', 1)
            else:
                parameter = ''

            for file in sorted(glob.glob(file_pattern)):
                # The import stays inside the loop on purpose: a source is
                # only imported (and can only fail) when its pattern
                # actually matches a file.
                try:
                    module = importlib.import_module(
                        self._TESTS_PACKAGE + '.' + source)
                except ImportError:
                    raise ex.LoadingException(
                        'Invalid test source: {}'.format(source))

                test_name = file
                if parameter:
                    test_name += ':' + parameter
                self.test_map.update(module.load(file, parameter, self))
                log.info('Loaded {}'.format(test_name))

        if not self.test_map:
            raise ex.LoadingException('No tests loaded')

    def dump_tests(self):
        """Dump every loaded test by calling its .dump() method.

        A test whose dump raises ex.SerializeException is logged as a
        warning and skipped; the remaining tests are still dumped.
        """
        log.info('Dumping tests')
        for test in self.test_map.values():
            try:
                test.dump()
            except ex.SerializeException as e:
                log.warning('Unable to dump {}: {}'.format(test.name, str(e)))
            else:
                log.info('Dumped {}'.format(test.name))

    def _resolve_specified_tests(self):
        """For each of the questions specified on the command line,
        find the test corresponding to that question.

        Questions are preserved in the order that they are specified on the
        command line, and a question given more than once is only added
        once. If no questions are specified, the default tests are used
        (unless the `all` flag is set or there are no default tests), and
        otherwise the entire set of tests.

        RAISES:
        ex.LoadingException -- if a question does not name a loaded test.
        """
        if not self.cmd_args.question and not self.cmd_args.all \
                and self.default_tests != core.NoValue \
                and len(self.default_tests) > 0:
            log.info('Using default tests (no questions specified): '
                     '{}'.format(self.default_tests))
            self.specified_tests = [self.test_map[test]
                                    for test in self.default_tests]
            return
        elif not self.cmd_args.question:
            log.info('Using all tests (no questions specified and no default tests)')
            self.specified_tests = list(self.test_map.values())
            return
        elif not self.test_map:
            log.info('No tests loaded')
            return
        for question in self.cmd_args.question:
            if question not in self.test_map:
                print('Test "{}" not found.'.format(question))
                print('Did you mean one of the following? '
                      '(Names are case sensitive)')
                for test in self.test_map:
                    print('    {}'.format(test))
                raise ex.LoadingException('Invalid test specified: {}'.format(question))

            log.info('Adding {} to specified tests'.format(question))
            # specified_tests holds test objects, not names, so the
            # duplicate check has to compare the object that would be added.
            test = self.test_map[question]
            if test not in self.specified_tests:
                self.specified_tests.append(test)

    def _load_protocols(self):
        """Import and instantiate every protocol named in self.protocols.

        The result of each module's protocol(cmd_args, assignment) call is
        stored in self.protocol_map under the protocol's name.

        RAISES:
        ex.LoadingException -- if a protocol module cannot be imported.
        """
        log.info('Loading protocols')
        for proto in self.protocols:
            try:
                module = importlib.import_module(self._PROTOCOL_PACKAGE + '.' + proto)
            except ImportError:
                raise ex.LoadingException('Invalid protocol: {}'.format(proto))

            self.protocol_map[proto] = module.protocol(self.cmd_args, self)
            log.info('Loaded protocol "{}"'.format(proto))

    def _print_header(self):
        """Print a banner with the assignment name and the client version."""
        format.print_line('=')
        print('Assignment: {}'.format(self.name))
        print('OK, version {}'.format(client.__version__))
        format.print_line('=')
        print()
예제 #10
0
 def testConstructor_homogeneousEmptyList(self):
     """An empty list is valid for a List field typed as str."""
     str_list_field = core.List(type=str)
     outcome = str_list_field.is_valid([])
     self.assertTrue(outcome)
예제 #11
0
 def testConstructor_heterogeneousEmptyList(self):
     """An empty list is valid for an untyped List field."""
     untyped_field = core.List()
     outcome = untyped_field.is_valid([])
     self.assertTrue(outcome)
예제 #12
0
 def testConstructor_homogeneous(self):
     """An int-typed List field rejects mixed lists and accepts all-int lists."""
     int_field = core.List(type=int)
     mixed_values = [1, 'hi', 6]
     int_values = [1, 2, 3, 4]
     self.assertFalse(int_field.is_valid(mixed_values))
     self.assertTrue(int_field.is_valid(int_values))
예제 #13
0
 def testConstructor_heterogeneous(self):
     """An untyped List field accepts a list of mixed element types."""
     untyped_field = core.List()
     mixed_values = [1, 'hi', 6]
     self.assertTrue(untyped_field.is_valid(mixed_values))
예제 #14
0
파일: models.py 프로젝트: xyfLily/UCB61A
class Suite(core.Serializable):
    """Base class for a suite of test cases attached to a test.

    Serialized fields: type, scored and cases. Subclasses are expected to
    implement run().
    """
    type = core.String()
    scored = core.Boolean(default=True)
    cases = core.List()

    def __init__(self, test, verbose, interactive, timeout=None, **fields):
        """Record the owning test and the run options.

        run_only starts empty; when it is non-empty, enumerate_cases() only
        yields the cases whose 1-indexed number appears in it.
        """
        super().__init__(**fields)
        self.test = test
        self.verbose = verbose
        self.interactive = interactive
        self.timeout = timeout
        self.run_only = []

    def run(self, test_name, suite_number, env=None):
        """Run the suite; must be overridden by subclasses.

        PARAMETERS:
        test_name    -- str; name of the parent test.
        suite_number -- int; suite number, assumed to be 1-indexed.
        env          -- dict; used by programmatic API to provide a
                        custom environment to run tests with.

        RETURNS:
        dict; results of the following form:
        {
            'passed': int,
            'failed': int,
            'locked': int,
        }
        """
        raise NotImplementedError

    def enumerate_cases(self):
        """Return (0-based index, case) pairs for the cases to run.

        Without a run_only filter this is the plain enumerate() object over
        self.cases; with a filter it is a list restricted to the cases whose
        1-indexed number is in run_only.
        """
        if not self.run_only:
            return enumerate(self.cases)
        return [(index, case) for index, case in enumerate(self.cases)
                if index + 1 in self.run_only]

    def _run_case(self, test_name, suite_number, case, case_number):
        """Run a single case via case.run() and report on it.

        Output produced while the case runs is captured in a log rather
        than printed directly. The log is printed afterwards only if the
        case failed or self.verbose is set; a failure additionally prints
        the command line that re-runs just this case.

        RETURNS:
        the value returned by case.run().
        """
        output.off()  # Hold back printing until the outcome is known.
        log_id = output.new_log()
        format.print_line('-')
        header = '{} > Suite {} > Case {}'.format(
            test_name, suite_number, case_number)
        print(header)
        print()

        passed = case.run()
        if passed:
            print('-- OK! --')

        output.on()
        captured = output.get_log(log_id)
        output.remove_log(log_id)

        if not passed:
            print(''.join(captured))
            short_name = self.test.get_short_name()
            # TODO: Change when in notebook mode
            rerun_hint = ('Run only this test case with '
                          '"python3 ok -q {} --suite {} --case {}"')
            print(rerun_hint.format(short_name, suite_number, case_number))
        elif self.verbose:
            print(''.join(captured))
        return passed
예제 #15
0
 def assertCoerce_pass(self, expect, value, **fields):
     """Assert that a List field built from **fields coerces value to expect."""
     coerced = core.List(**fields).coerce(value)
     self.assertEqual(expect, coerced)
예제 #16
0
class OkTest(models.Test):
    """A test made of suites, each built from a serialized dictionary.

    Serialized fields: suites (list of suite dictionaries, replaced by
    suite objects in post_instantiation) and an optional description.
    """
    suites = core.List()
    description = core.String(optional=True)

    def __init__(self, file, suite_map, assign_name, verbose, interactive,
                 timeout=None, **fields):
        """Store the run options and the suite-type lookup table.

        PARAMETERS:
        file        -- stored as self.file.
        suite_map   -- mapping from a suite's 'type' value to the callable
                       used to construct that suite.
        assign_name -- stored as self.assignment_name.
        verbose, interactive, timeout -- forwarded to every suite
                       constructor in post_instantiation.
        """
        super().__init__(**fields)
        self.file = file
        self.suite_map = suite_map
        self.verbose = verbose
        self.interactive = interactive
        self.timeout = timeout
        self.assignment_name = assign_name

    def post_instantiation(self):
        """Replace every suite dictionary in self.suites by a suite object.

        RAISES:
        ex.SerializeException -- if a suite is not a dict, has no 'type'
                                 key, or names a type missing from
                                 suite_map.
        """
        for i, suite in enumerate(self.suites):
            if not isinstance(suite, dict):
                # NOTE(review): the message says "Test cases" although the
                # value being checked is a suite entry; left unchanged.
                raise ex.SerializeException('Test cases must be dictionaries')
            elif 'type' not in suite:
                raise ex.SerializeException('Suites must have field "type"')
            elif suite['type'] not in self.suite_map:
                raise ex.SerializeException('Invalid suite type: '
                                            '{}'.format(suite['type']))
            self.suites[i] = self.suite_map[suite['type']](
                    self.verbose, self.interactive, self.timeout, **suite)

    def run(self):
        """Runs the suites associated with this OK test.

        Unless self.verbose is set, running stops after the first suite
        that leaves a failed or locked case in the running totals.

        RETURNS:
        dict; the results for this test, in the form
        {
            'passed': int,
            'failed': int,
            'locked': int,
        }
        """
        passed, failed, locked = 0, 0, 0
        for i, suite in enumerate(self.suites):
            results = suite.run(self.name, i + 1)

            passed += results['passed']
            failed += results['failed']
            locked += results['locked']

            if not self.verbose and (failed > 0 or locked > 0):
                # Stop at the first failed test
                break

        if locked > 0:
            print()
            print('There are still locked tests! '
                  'Use the -u option to unlock them')

        # description is optional, so it may hold a non-string placeholder;
        # isinstance also accepts str subclasses, unlike an exact type match.
        if isinstance(self.description, str) and self.description:
            print()
            print(self.description)
            print()
        return {
            'passed': passed,
            'failed': failed,
            'locked': locked,
        }

    def score(self):
        """Runs test cases and computes the score for this particular test.

        Scores are determined by aggregating results from suite.run() for each
        suite. A suite is considered passed only if it results in no locked
        nor failed results.

        The points available for this test are distributed evenly across
        the suites whose 'scored' field is true; suites with a false
        'scored' field are skipped entirely.

        RETURNS:
        the score: passed * self.points / total, or 0.0 when no suite is
        scored.
        """
        passed, total = 0, 0
        for i, suite in enumerate(self.suites):
            if not suite.scored:
                continue

            total += 1
            results = suite.run(self.name, i + 1)

            if results['locked'] == 0 and results['failed'] == 0:
                passed += 1
        if total > 0:
            score = passed * self.points / total
        else:
            score = 0.0

        format.print_progress_bar(self.name, passed, total - passed, 0)
        print()
        return score

    def unlock(self, interact):
        """Walk through every case of every suite and unlock the locked ones.

        PARAMETERS:
        interact -- passed through unchanged to case.unlock().
        """
        # Count the cases without materializing a flattened list.
        total_cases = sum(len(suite.cases) for suite in self.suites)
        for suite_num, suite in enumerate(self.suites):
            for case_num, case in enumerate(suite.cases):
                case_id = '{} > Suite {} > Case {}'.format(
                            self.name, suite_num + 1, case_num + 1)

                format.print_line('-')
                print(case_id)
                print('(cases remaining: {})'.format(total_cases))
                print()
                total_cases -= 1

                # NOTE(review): kept as an explicit comparison with True;
                # presumably `locked` can hold a non-bool placeholder that
                # must count as "not locked" -- confirm before simplifying.
                if case.locked != True:
                    print('-- Already unlocked --')
                    print()
                    continue

                case.unlock(self.unique_id_prefix, case_id, interact)

        assert total_cases == 0, 'Number of cases is incorrect'
        format.print_line('-')
        print('OK! All cases for {} unlocked.'.format(self.name))
        print()
예제 #17
0
 def assertCoerce_errors(self, value, **fields):
     """Assert that coercing value with a List field built from **fields
     raises ex.SerializeException."""
     list_field = core.List(**fields)
     with self.assertRaises(ex.SerializeException):
         list_field.coerce(value)
예제 #18
0
class Assignment(core.Serializable):
    """Serializable assignment configuration plus its loaded tests/protocols.

    The class attributes below declare the serialized fields. Tests and
    protocols are loaded in post_instantiation().
    """
    name = core.String()
    endpoint = core.String()
    src = core.List(type=str, optional=True)
    tests = core.Dict(keys=str, values=str, ordered=True)
    default_tests = core.List(type=str, optional=True)
    protocols = core.List(type=str)

    ####################
    # Programmatic API #
    ####################

    def grade(self, question, env=None, skip_locked_cases=False):
        """Runs tests for a particular question. The setup and teardown will
        always be executed.

        question -- str; a question name (as would be entered at the command
                    line)
        env      -- dict; an environment in which to execute the tests. If
                    None, uses the environment of __main__. The original
                    dictionary is never modified; each test is given a
                    duplicate of env.
        skip_locked_cases -- bool; if False, locked cases will be tested

        Returns: dict; the results for the question, i.e. the entry of the
        grading messages stored under the resolved test's name. The
        results dictionary contains the following fields:
        - "passed": int (number of test cases passed)
        - "failed": int (number of test cases failed)
        - "locked": int (number of test cases locked)
        """
        if env is None:
            import __main__
            env = __main__.__dict__
        messages = {}
        tests = self._resolve_specified_tests([question], all_tests=False)
        for test in tests:
            # Best effort: a test or suite lacking one of these attributes
            # stops the configuration of that test's remaining suites, but
            # grading still proceeds.
            try:
                for suite in test.suites:
                    suite.skip_locked_cases = skip_locked_cases
                    suite.console.skip_locked_cases = skip_locked_cases
                    suite.console.hash_key = self.name
            except AttributeError:
                pass
        test_name = tests[0].name
        # Refers to the module-level grade function, not this method.
        grade(tests, messages, env)
        return messages['grading'][test_name]

    ############
    # Internal #
    ############

    # Packages from which test-source and protocol modules are imported.
    _TESTS_PACKAGE = 'client.sources'
    _PROTOCOL_PACKAGE = 'client.protocols'

    def __init__(self, args, **fields):
        """Store the command-line arguments and create empty lookup tables.

        PARAMETERS:
        args   -- parsed command-line arguments; this class reads its
                  .question and .all attributes.
        fields -- serialized field values (not used directly here).
        """
        self.cmd_args = args
        self.test_map = collections.OrderedDict()
        self.protocol_map = collections.OrderedDict()

    def post_instantiation(self):
        """Print the banner, load tests and protocols, and resolve the
        tests selected on the command line into self.specified_tests."""
        self._print_header()
        self._load_tests()
        self._load_protocols()
        self.specified_tests = self._resolve_specified_tests(
            self.cmd_args.question, self.cmd_args.all)

    def _load_tests(self):
        """Load every test described by the `tests` field into test_map.

        Each key of self.tests is a glob pattern, optionally followed by
        ':<parameter>'; each value names a module inside _TESTS_PACKAGE
        whose load(file, parameter, assignment) result is merged into
        self.test_map.

        RAISES:
        ex.LoadingException -- if a source module cannot be imported or if
                               no tests were loaded at all.
        """
        log.info('Loading tests')
        for file_pattern, source in self.tests.items():
            # Separate filepath and parameter
            if ':' in file_pattern:
                file_pattern, parameter = file_pattern.split(':', 1)
            else:
                parameter = ''

            for file in sorted(glob.glob(file_pattern)):
                # The import stays inside the loop on purpose: a source is
                # only imported (and can only fail) when its pattern
                # actually matches a file.
                try:
                    module = importlib.import_module(self._TESTS_PACKAGE +
                                                     '.' + source)
                except ImportError:
                    raise ex.LoadingException(
                        'Invalid test source: {}'.format(source))

                test_name = file
                if parameter:
                    test_name += ':' + parameter

                self.test_map.update(module.load(file, parameter, self))

                log.info('Loaded {}'.format(test_name))

        if not self.test_map:
            raise ex.LoadingException('No tests loaded')

    def dump_tests(self):
        """Dump every loaded test by calling its .dump() method.

        A test whose dump raises ex.SerializeException is logged as a
        warning and skipped; the remaining tests are still dumped.
        """
        log.info('Dumping tests')
        for test in self.test_map.values():
            try:
                test.dump()
            except ex.SerializeException as e:
                log.warning('Unable to dump {}: {}'.format(test.name, str(e)))
            else:
                log.info('Dumped {}'.format(test.name))

    def _resolve_specified_tests(self, questions, all_tests=False):
        """For each of the given questions, find the corresponding test.

        Questions are preserved in the order that they are given, and a
        question given more than once is only added once. If no questions
        are given, the default tests are used (unless all_tests is set or
        there are no default tests), and otherwise the entire set of tests.

        PARAMETERS:
        questions -- iterable of str; test names, case sensitive.
        all_tests -- bool; if True, skip the default tests and fall back to
                     every loaded test when no questions are given.

        RETURNS:
        list; the selected test objects.

        RAISES:
        ex.LoadingException -- if a question does not name a loaded test.
        """
        if not questions and not all_tests \
                and self.default_tests != core.NoValue \
                and len(self.default_tests) > 0:
            log.info('Using default tests (no questions specified): '
                     '{}'.format(self.default_tests))
            return [self.test_map[test] for test in self.default_tests]
        elif not questions:
            log.info(
                'Using all tests (no questions specified and no default tests)'
            )
            return list(self.test_map.values())
        elif not self.test_map:
            log.info('No tests loaded')
            return []

        specified_tests = []
        for question in questions:
            if question not in self.test_map:
                print('Test "{}" not found.'.format(question))
                print('Did you mean one of the following? '
                      '(Names are case sensitive)')
                for test in self.test_map:
                    print('    {}'.format(test))
                raise ex.LoadingException(
                    'Invalid test specified: {}'.format(question))

            log.info('Adding {} to specified tests'.format(question))
            # specified_tests holds test objects, not names, so the
            # duplicate check has to compare the object that would be added.
            test = self.test_map[question]
            if test not in specified_tests:
                specified_tests.append(test)
        return specified_tests

    def _load_protocols(self):
        """Import and instantiate every protocol named in self.protocols.

        The result of each module's protocol(cmd_args, assignment) call is
        stored in self.protocol_map under the protocol's name. An
        ImportError raised while importing or instantiating a protocol is
        logged at debug level and that protocol is skipped.
        """
        log.info('Loading protocols')
        for proto in self.protocols:
            try:
                module = importlib.import_module(self._PROTOCOL_PACKAGE + '.' +
                                                 proto)
                self.protocol_map[proto] = module.protocol(self.cmd_args, self)
                log.info('Loaded protocol "{}"'.format(proto))
            except ImportError:
                log.debug('Skipping unknown protocol "{}"'.format(proto))

    def _print_header(self):
        """Print a banner with the assignment name and the client version."""
        format.print_line('=')
        print('Assignment: {}'.format(self.name))
        print('OK, version {}'.format(client.__version__))
        format.print_line('=')
        print()
예제 #19
0
 def testToJson_shallow(self):
     """to_json on a flat list of JSON-compatible values returns an equal list."""
     list_field = core.List()
     values = [1, 'hi', True]
     serialized = list_field.to_json(values)
     self.assertEqual(values, serialized)