Example #1
0
    def test_multivaluedict(self):
        """MultiValueDict: lookups, iteration, missing keys and setlist()."""
        mvd = MultiValueDict({
            'name': ['Adrian', 'Simon'],
            'position': ['Developer'],
        })

        # Single-value access yields the last value stored for the key ...
        self.assertEqual(mvd['name'], 'Simon')
        self.assertEqual(mvd.get('name'), 'Simon')
        # ... while getlist() yields every value.
        self.assertEqual(mvd.getlist('name'), ['Adrian', 'Simon'])

        expected_items = [('name', 'Simon'), ('position', 'Developer')]
        self.assertEqual(sorted(mvd.items()), expected_items)

        expected_lists = [
            ('name', ['Adrian', 'Simon']),
            ('position', ['Developer']),
        ]
        self.assertEqual(sorted(mvd.lists()), expected_lists)

        # Missing keys: item access raises, get()/getlist() fall back.
        with self.assertRaisesMessage(MultiValueDictKeyError, 'lastname'):
            mvd.__getitem__('lastname')
        self.assertIsNone(mvd.get('lastname'))
        self.assertEqual(mvd.get('lastname', 'nonexistent'), 'nonexistent')
        self.assertEqual(mvd.getlist('lastname'), [])
        self.assertEqual(
            mvd.getlist('doesnotexist', ['Adrian', 'Simon']),
            ['Adrian', 'Simon'],
        )

        # setlist() stores a whole list of values for a key.
        mvd.setlist('lastname', ['Holovaty', 'Willison'])
        self.assertEqual(mvd.getlist('lastname'), ['Holovaty', 'Willison'])
        self.assertEqual(
            sorted(mvd.values()), ['Developer', 'Simon', 'Willison'])
Example #2
0
    def test_multivaluedict(self):
        """Exercise MultiValueDict lookups, iteration, missing keys and setlist()."""
        d = MultiValueDict({"name": ["Adrian", "Simon"], "position": ["Developer"]})

        # Item access and get() return the last value; getlist() returns them all.
        self.assertEqual(d["name"], "Simon")
        self.assertEqual(d.get("name"), "Simon")
        self.assertEqual(d.getlist("name"), ["Adrian", "Simon"])
        # sorted() accepts any iterable, so no intermediate list() is needed.
        self.assertEqual(sorted(six.iteritems(d)), [("name", "Simon"), ("position", "Developer")])

        self.assertEqual(sorted(six.iterlists(d)), [("name", ["Adrian", "Simon"]), ("position", ["Developer"])])

        # MultiValueDictKeyError: "Key 'lastname' not found in
        # <MultiValueDict: {'position': ['Developer'],
        #                   'name': ['Adrian', 'Simon']}>"
        # Only the prefix is matched because the embedded dict repr has no
        # guaranteed key order.
        six.assertRaisesRegex(
            self, MultiValueDictKeyError, r'"Key \'lastname\' not found in <MultiValueDict', d.__getitem__, "lastname"
        )

        self.assertIsNone(d.get("lastname"))
        self.assertEqual(d.get("lastname", "nonexistent"), "nonexistent")
        self.assertEqual(d.getlist("lastname"), [])
        self.assertEqual(d.getlist("doesnotexist", ["Adrian", "Simon"]), ["Adrian", "Simon"])

        d.setlist("lastname", ["Holovaty", "Willison"])
        self.assertEqual(d.getlist("lastname"), ["Holovaty", "Willison"])
        self.assertEqual(sorted(six.itervalues(d)), ["Developer", "Simon", "Willison"])
Example #3
0
    def test_multivaluedict(self):
        """Exercise MultiValueDict lookups, iteration, missing keys and setlist()."""
        d = MultiValueDict({'name': ['Adrian', 'Simon'],
                            'position': ['Developer']})

        # Item access and get() return the last value; getlist() returns them all.
        self.assertEqual(d['name'], 'Simon')
        self.assertEqual(d.get('name'), 'Simon')
        self.assertEqual(d.getlist('name'), ['Adrian', 'Simon'])
        # sorted() accepts any iterable; wrapping in list() is redundant.
        self.assertEqual(sorted(six.iteritems(d)),
                         [('name', 'Simon'), ('position', 'Developer')])

        self.assertEqual(sorted(six.iterlists(d)),
                         [('name', ['Adrian', 'Simon']),
                          ('position', ['Developer'])])

        six.assertRaisesRegex(self, MultiValueDictKeyError, 'lastname',
                              d.__getitem__, 'lastname')

        self.assertIsNone(d.get('lastname'))
        self.assertEqual(d.get('lastname', 'nonexistent'), 'nonexistent')
        self.assertEqual(d.getlist('lastname'), [])
        self.assertEqual(d.getlist('doesnotexist', ['Adrian', 'Simon']),
                         ['Adrian', 'Simon'])

        d.setlist('lastname', ['Holovaty', 'Willison'])
        self.assertEqual(d.getlist('lastname'), ['Holovaty', 'Willison'])
        self.assertEqual(sorted(six.itervalues(d)),
                         ['Developer', 'Simon', 'Willison'])
Example #4
0
    def test_multivaluedict(self):
        """Exercise MultiValueDict lookups, iteration, missing keys and setlist().

        Dict iteration order is not guaranteed, so every comparison over
        multiple keys is made on sorted data.
        """
        d = MultiValueDict({'name': ['Adrian', 'Simon'],
                            'position': ['Developer']})

        self.assertEqual(d['name'], 'Simon')
        self.assertEqual(d.get('name'), 'Simon')
        self.assertEqual(d.getlist('name'), ['Adrian', 'Simon'])
        self.assertEqual(sorted(d.iteritems()),
                         [('name', 'Simon'), ('position', 'Developer')])

        self.assertEqual(sorted(d.iterlists()),
                         [('name', ['Adrian', 'Simon']),
                          ('position', ['Developer'])])

        # MultiValueDictKeyError: "Key 'lastname' not found in
        # <MultiValueDict: {'position': ['Developer'],
        #                   'name': ['Adrian', 'Simon']}>"
        # The embedded dict repr has no guaranteed key order, so only the
        # stable prefix of the message is checked.
        self.assertRaisesMessage(MultiValueDictKeyError,
            '"Key \'lastname\' not found in <MultiValueDict',
            d.__getitem__, 'lastname')

        self.assertEqual(d.get('lastname'), None)
        self.assertEqual(d.get('lastname', 'nonexistent'), 'nonexistent')
        self.assertEqual(d.getlist('lastname'), [])
        self.assertEqual(d.getlist('doesnotexist', ['Adrian', 'Simon']),
                         ['Adrian', 'Simon'])

        d.setlist('lastname', ['Holovaty', 'Willison'])
        self.assertEqual(d.getlist('lastname'), ['Holovaty', 'Willison'])
        self.assertEqual(sorted(d.values()), ['Developer', 'Simon', 'Willison'])
        self.assertEqual(sorted(d.itervalues()),
                         ['Developer', 'Simon', 'Willison'])
Example #5
0
    def get_form_kwargs(self):
        """
        Get the kwargs to pass to the form for this script. By default, returns the task
        arguments.

        - 'data' list entries converted into a querydict.
        - 'files' list entries converted to File objects

        ``self.task.arguments`` itself is never modified: 'data' and 'files'
        are rebuilt as new objects on a shallow copy of the arguments.
        """
        kwargs = self.task.arguments.copy()  # don't modify self.task.arguments
        if 'data' in kwargs:
            d = QueryDict('').copy()  # copy() gives a mutable QueryDict
            for k, v in kwargs['data'].items():
                if isinstance(v, list):
                    d.setlist(k, v)
                else:
                    d[k] = v
            kwargs['data'] = d

        # Convert file dictionaries (as supplied by get_temporary_file_dict) to
        # SimpleUploadedFile objects which Django understands.
        if 'files' in kwargs:
            files = MultiValueDict(kwargs['files'])
            for key in list(files.keys()):
                # Build a fresh list: the one returned by getlist() may be the
                # very list stored in self.task.arguments['files'].
                converted = []
                for fdict in files.getlist(key):
                    if isinstance(fdict, dict):
                        fdict = dict(fdict)
                        # Close the temporary file deterministically.
                        with open(fdict["path"], "rb") as fh:
                            fdict["content"] = fh.read()
                        fdict = SimpleUploadedFile.from_dict(fdict)
                    converted.append(fdict)
                files.setlist(key, converted)
            kwargs['files'] = files
        return kwargs
    def render(self, context):
        """Render ``self.url`` with its query string adjusted by the tag arguments.

        Positional args add a bare query flag, or (when prefixed with ``-``)
        remove that flag and any key of the same name. Keyword args replace a
        key's values; a ``+`` prefix appends values and a ``-`` prefix removes
        specific values. When ``self.asvar`` is set the URL is stored in the
        context and ``""`` is returned; otherwise the URL is returned.
        """
        url = self.url.resolve(context)
        parts = url.split("?")
        url = parts[0]
        original_query = QueryDict("&".join(parts[1:])).copy()
        query = MultiValueDict(dict((k, v.resolve(context)) for k, v in six.iteritems(self.kwargs)))

        args_to_set = []
        args_to_remove = []
        for arg in self.args:
            k = arg.resolve(context)
            if k.startswith("-"):
                args_to_remove.append(k[1:])
            else:
                args_to_set.append(k)

        kwargs_to_set = MultiValueDict()
        kwargs_to_add = MultiValueDict()
        kwargs_to_remove = MultiValueDict()
        for (k, v) in query.lists():
            # Build a real list of strings: on Python 3 ``map`` returns a
            # one-shot iterator, which must not be stored via setlist().
            val = [str(x) for x in (v if isinstance(v, list) else [v])]
            if k.startswith("+"):
                kwargs_to_add.setlistdefault(k[1:]).extend(val)
            elif k.startswith("-"):
                kwargs_to_remove.setlistdefault(k[1:]).extend(val)
            else:
                kwargs_to_set.setlist(k, val)

        def unique(items):
            # De-duplicate while keeping first-occurrence order so the
            # rendered URL is stable (iterating a set is not).
            seen = set()
            result = []
            for item in items:
                if item not in seen:
                    seen.add(item)
                    result.append(item)
            return result

        def filter_excluded_kwargs(values):
            # Drop keys removed by a "-key" positional arg and values removed
            # by a "-key=value" kwarg. six.iterlists works on Python 2 and 3
            # (MultiValueDict has no .iterlists() on Python 3).
            filtered = []
            for k, v in six.iterlists(values):
                if k in args_to_remove:
                    continue
                removed = set(kwargs_to_remove.getlist(k))
                filtered.append((k, [x for x in unique(v) if x not in removed]))
            return filtered

        args = [arg for arg in args_to_set if arg not in args_to_remove]
        kwargs = MultiValueDict()
        for k, v in filter_excluded_kwargs(original_query):
            kwargs.setlistdefault(k).extend(v)
        for k, v in filter_excluded_kwargs(kwargs_to_set):
            kwargs.setlist(k, v)
        for k, v in filter_excluded_kwargs(kwargs_to_add):
            kwargs.setlistdefault(k).extend(v)

        # NOTE(review): keys and values are interpolated without URL quoting,
        # so values containing "&", "=" or spaces are emitted verbatim.
        output = list(args)
        for k, values in kwargs.lists():
            output.extend("%s=%s" % (k, v) if v else six.text_type(k) for v in values)
        output = unique(output)

        if output:
            url = "?".join([url, "&".join(output)])

        if self.asvar:
            context[self.asvar] = url
            return ""
        return url
Example #7
0
    def value_from_datadict(self, data, files, name):
        """Return the de-duplicated, non-empty values submitted for ``name``.

        Values are read from ``name`` or, failing that, ``name[]``. When
        ``data`` is a plain mapping, every string value is treated as a
        comma-separated list (``?foo=1,2``). A normalized copy is built for
        this, so the caller's mapping is left untouched.
        """
        if not isinstance(data, MultiValueDict):
            normalized = {}
            for key, value in data.items():
                # treat value as csv string: ?foo=1,2
                if isinstance(value, str):
                    value = [x.strip() for x in value.rstrip(',').split(',') if x]
                normalized[key] = value
            data = MultiValueDict(normalized)

        values_list = data.getlist(name, data.getlist('%s[]' % name)) or []

        # apparently its an array, so no need to process its values as csv
        # ?foo=1&foo=2 -> data.getlist(foo) -> foo = [1, 2]
        # ?foo[]=1&foo[]=2 -> data.getlist(foo[]) -> foo = [1, 2]
        ret = [x for x in values_list if x]

        return list(set(ret))
Example #8
0
    def test_multivaluedict(self):
        """Exercise MultiValueDict lookups, iteration, missing keys and setlist()."""
        d = MultiValueDict({"name": ["Adrian", "Simon"], "position": ["Developer"]})

        self.assertEqual(d["name"], "Simon")
        self.assertEqual(d.get("name"), "Simon")
        self.assertEqual(d.getlist("name"), ["Adrian", "Simon"])
        # sorted() accepts any iterable; wrapping in list() is redundant.
        self.assertEqual(sorted(six.iteritems(d)), [("name", "Simon"), ("position", "Developer")])

        self.assertEqual(sorted(six.iterlists(d)), [("name", ["Adrian", "Simon"]), ("position", ["Developer"])])

        six.assertRaisesRegex(self, MultiValueDictKeyError, "lastname", d.__getitem__, "lastname")

        self.assertIsNone(d.get("lastname"))
        self.assertEqual(d.get("lastname", "nonexistent"), "nonexistent")
        self.assertEqual(d.getlist("lastname"), [])
        self.assertEqual(d.getlist("doesnotexist", ["Adrian", "Simon"]), ["Adrian", "Simon"])

        d.setlist("lastname", ["Holovaty", "Willison"])
        self.assertEqual(d.getlist("lastname"), ["Holovaty", "Willison"])
        self.assertEqual(sorted(six.itervalues(d)), ["Developer", "Simon", "Willison"])
Example #9
0
    def value_from_datadict(self, data, files, name):
        """Return the de-duplicated, non-empty values submitted for ``name``.

        Values are read from ``name`` or, failing that, ``name[]``. Several
        submitted values are used as-is. A single string value is split as a
        comma-separated list (``?foo=1,2``), and each part is stripped before
        empty parts are discarded.
        """
        if not isinstance(data, MultiValueDict):
            data = MultiValueDict(data)

        values_list = data.getlist(name, data.getlist('%s[]' % name)) or []

        # A plain-dict source may hand back the raw string instead of a list.
        if isinstance(values_list, string_types):
            values_list = [values_list]

        # apparently its an array, so no need to process its values as csv
        # ?foo=1&foo=2 -> data.getlist(foo) -> foo = [1, 2]
        # ?foo[]=1&foo[]=2 -> data.getlist(foo[]) -> foo = [1, 2]
        if len(values_list) > 1:
            ret = [x for x in values_list if x]
        elif len(values_list) == 1:
            first = values_list[0]
            if isinstance(first, string_types):
                # treat first element as csv string
                # ?foo=1,2 -> data.getlist(foo) -> foo = ['1,2']
                # Strip before filtering so "1, ,2" does not yield ''.
                parts = (part.strip() for part in first.rstrip(',').split(','))
                ret = [x for x in parts if x]
            else:
                # Non-string scalars (e.g. ints from a dict) can't be split.
                ret = [first] if first else []
        else:
            ret = []

        return list(set(ret))
 def render(self, context):
     """Return a urlencoded query string carrying the updated comparison list.

     The current book identifiers come from the template variable named by
     ``settings.BOOKREADER_COMPARISON_TEMPLATE_VARIABLE`` when it is in the
     context, otherwise from the GET argument named by
     ``settings.BOOKREADER_COMPARISON_GET_ARGUMENT``. The first occurrence of
     each book in ``self.remove`` is then dropped and each book in
     ``self.add`` is appended if not already present. Errors while collecting
     the existing identifiers are logged rather than raised.
     """
     if 'request' in context and not self.ignore:
         args = MultiValueDict(context['request'].GET.iterlists())
     else:
         args = MultiValueDict()

     books = []
     if settings.BOOKREADER_COMPARISON_TEMPLATE_VARIABLE in context:
         try:
             for book in context[settings.BOOKREADER_COMPARISON_TEMPLATE_VARIABLE]:
                 try:
                     books.append(book.identifier)
                 except Exception:
                     log.exception('Error getting a book identifier')
         except Exception:
             log.exception('Error iterating over already compared books')
     elif settings.BOOKREADER_COMPARISON_GET_ARGUMENT in args:
         try:
             for book in args.getlist(settings.BOOKREADER_COMPARISON_GET_ARGUMENT):
                 try:
                     books.append(unquote(book))
                 except Exception:
                     log.exception('Error unquoting a book id from GET')
         except Exception:
             log.exception('Error accessing GET variables')

     def resolve(book):
         # Plain strings are identifiers already; otherwise resolve the
         # template variable and accept either a string or a Book.
         if isinstance(book, basestring):
             return book
         book = book.resolve(context)
         if isinstance(book, basestring):
             return book
         if isinstance(book, Book):
             return book.identifier
         return None

     for book in (resolve(b) for b in self.remove):
         if book is None:
             continue
         try:
             books.remove(book)
         except ValueError:
             pass  # not in the list: nothing to remove

     for book in (resolve(b) for b in self.add):
         if book is not None and book not in books:
             books.append(book)

     args.setlist(settings.BOOKREADER_COMPARISON_GET_ARGUMENT, books)
     return urlencode(dict(args.iterlists()), doseq=True)
Example #11
0
 def testTwoFile(self):
     """
     Tests the binding of a Form with two files attached.
     """
     file_data = MultiValueDict({'files[]': [
         {'filename': 'face.jpg', 'content': 'www'},
         {'filename': 'lah.jpg', 'content': 'woop'},
     ]})
     f = self.RequiredForm({}, file_data)
     self.assertTrue(f.is_bound)
     self.assertTrue(f.is_valid())

     # assertEquals is a deprecated alias (removed in Python 3.12).
     self.assertEqual(len(f.cleaned_data['files']), 2)
     for input_file in file_data.getlist('files[]'):
         found = any(
             output_file.filename == input_file['filename']
             and output_file.content == input_file['content']
             for output_file in f.cleaned_data['files']
         )
         self.assertTrue(found)
Example #12
0
    def formfield_for_dbfield(self, db_field, **kwargs):
        """Use a grouped select of management commands for the ``command`` field.

        The choices start with an empty "--- None ---" entry, followed by one
        group per app with that app's commands sorted by name. Every other
        field is delegated to the parent admin with ``request`` passed back in.
        """
        request = kwargs.pop("request", None)

        if db_field.name != 'command':
            kwargs['request'] = request
            return super(JobAdmin, self).formfield_for_dbfield(db_field, **kwargs)

        # Group command names by the app that provides them.
        by_app = MultiValueDict()
        for command, app in get_commands().items():
            by_app.appendlist(app, command)

        choices = [('', '--- None ---')]
        choices.extend(
            [app, [[name, name] for name in sorted(by_app.getlist(app))]]
            for app in by_app.keys()
        )

        kwargs['widget'] = forms.widgets.Select(choices=choices)
        return db_field.formfield(**kwargs)
    def formfield_for_dbfield(self, db_field, **kwargs):
        """Use a grouped select of management commands for the ``command`` field.

        There is one choice group per app, holding that app's commands sorted
        by name. Every other field is delegated to the parent admin with
        ``request`` passed back in.
        """
        request = kwargs.pop("request", None)

        if db_field.name != 'command':
            kwargs['request'] = request
            return super(JobAdmin, self).formfield_for_dbfield(db_field, **kwargs)

        by_app = MultiValueDict()
        for command, app in get_commands().items():
            by_app.appendlist(app, command)

        choices = [
            [app, [[name, name] for name in sorted(by_app.getlist(app))]]
            for app in by_app.keys()
        ]

        kwargs['widget'] = forms.widgets.Select(choices=choices)
        return db_field.formfield(**kwargs)
Example #14
0
    def formfield_for_dbfield(self, db_field, **kwargs):
        """Use a grouped select of kitsune checks for the ``command`` field.

        ``get_kitsune_checks()`` yields ``(command, app)`` pairs. They are
        grouped per app with names sorted. Every other field is delegated to
        the parent admin with ``request`` passed back in.
        """
        request = kwargs.pop("request", None)

        if db_field.name != 'command':
            kwargs['request'] = request
            return super(JobAdmin, self).formfield_for_dbfield(db_field, **kwargs)

        by_app = MultiValueDict()
        for check_name, app in get_kitsune_checks():
            by_app.appendlist(app, check_name)

        choices = [
            [app, [[name, name] for name in sorted(by_app.getlist(app))]]
            for app in by_app.keys()
        ]

        kwargs['widget'] = forms.widgets.Select(choices=choices)
        return db_field.formfield(**kwargs)
Example #15
0
 def test_appendlist(self):
     """appendlist() accumulates values under one key in call order."""
     mvd = MultiValueDict()
     for who in ('Adrian', 'Simon'):
         mvd.appendlist('name', who)
     self.assertEqual(mvd.getlist('name'), ['Adrian', 'Simon'])
Example #16
0
 def test_getlist_none_empty_values(self):
     """A stored None or [] is returned by getlist() as-is."""
     mvd = MultiValueDict({'a': None, 'b': []})
     none_value, empty_value = mvd.getlist('a'), mvd.getlist('b')
     self.assertIsNone(none_value)
     self.assertEqual(empty_value, [])
Example #17
0
 def test_getlist_default(self):
     """getlist() returns the very default object for a missing key."""
     sentinel = object()
     mvd = MultiValueDict({'a': [1]})
     self.assertIs(mvd.getlist('b', default=sentinel), sentinel)
Example #18
0
 def setlistdefault(self, key, default_list=()):
     """Return the value list for *key*, first storing *default_list* if absent.

     ``_assert_mutable()`` runs first, so this presumably refuses to work on
     an immutable instance.
     """
     self._assert_mutable()
     missing = key not in self
     if missing:
         self.setlist(key, default_list)
     return MultiValueDict.getlist(self, key)
Example #19
0
 def test_getlist_doesnt_mutate(self):
     """Extending the list returned by getlist() must not change the dict."""
     mvd = MultiValueDict({'a': ['1', '2'], 'b': ['3']})
     returned = mvd.getlist('a')
     returned.extend(mvd.getlist('b'))
     self.assertEqual(mvd.getlist('a'), ['1', '2'])
Example #20
0
 def test_getlist_none_empty_values(self):
     """getlist() passes stored None / [] values through unchanged."""
     stored = MultiValueDict({'a': None, 'b': []})
     self.assertIsNone(stored.getlist('a'))
     self.assertEqual(stored.getlist('b'), [])
 def test_getlist_doesnt_mutate(self):
     """Growing the list returned by getlist() leaves the stored list intact."""
     stored = MultiValueDict({"a": ["1", "2"], "b": ["3"]})
     copy_of_a = stored.getlist("a")
     copy_of_a.extend(stored.getlist("b"))
     self.assertEqual(stored.getlist("a"), ["1", "2"])
Example #22
0
    def render(self, context):
        """Assemble a URL from the tag's keyword arguments and render it.

        Options are applied in this order: ``base``, ``secure``, ``query``,
        ``add_query``, ``scheme``, ``host``, ``path``, ``add_path``,
        ``fragment``, ``port``. The result is HTML-escaped according to the
        ``autoescape`` option (default: ``context.autoescape``). It is then
        stored in ``context[self.asvar]`` (returning ``''``) or returned.
        """
        # Resolve every (possibly repeated) keyword argument.
        kwargs = MultiValueDict()
        for raw_key in self.kwargs:
            key = smart_str(raw_key, 'ascii')
            kwargs.setlist(key, [node.resolve(context) for node in self.kwargs.getlist(key)])

        url = URLObject.parse(kwargs['base']) if 'base' in kwargs else URLObject(scheme='http')

        if 'secure' in kwargs:
            is_secure = convert_to_boolean(kwargs['secure'])
            url = url.with_scheme('https' if is_secure else 'http')

        if 'query' in kwargs:
            query = kwargs['query']
            if isinstance(query, basestring):
                query = render_template_from_string_without_autoescape(query, context)
            url = url.with_query(query)

        if 'add_query' in kwargs:
            for extra in kwargs.getlist('add_query'):
                if isinstance(extra, basestring):
                    extra = render_template_from_string_without_autoescape(extra, context)
                    extra = dict(decode_query(extra))
                for param, param_value in extra.items():
                    url = url.add_query_param(param, param_value)

        # Single-value setters, applied in the same order as before add_path.
        for option, setter in (('scheme', 'with_scheme'),
                               ('host', 'with_host'),
                               ('path', 'with_path')):
            if option in kwargs:
                url = getattr(url, setter)(kwargs[option])

        if 'add_path' in kwargs:
            for component in kwargs.getlist('add_path'):
                url = url.add_path_component(component)

        for option, setter in (('fragment', 'with_fragment'),
                               ('port', 'with_port')):
            if option in kwargs:
                url = getattr(url, setter)(kwargs[option])

        # Sensible default: with no host, clear the scheme.
        if not url.host:
            url = url.with_scheme('')

        url = unicode(url)

        # An explicit "autoescape" option overrides context.autoescape.
        if 'autoescape' in kwargs:
            autoescape = convert_to_boolean(kwargs['autoescape'])
        else:
            autoescape = context.autoescape
        if autoescape:
            url = escape(url)

        if not self.asvar:
            return url
        context[self.asvar] = url
        return ''
Example #23
0
 def test_appendlist(self):
     """Successive appendlist() calls build up the key's value list."""
     container = MultiValueDict()
     container.appendlist('name', 'Adrian')
     container.appendlist('name', 'Simon')
     expected = ['Adrian', 'Simon']
     self.assertEqual(container.getlist('name'), expected)
Example #24
0
class ModelTree(object):
    """A class to handle building and parsing a tree structure given a model.

        `root_model` - The root or "reference" model for the tree. Everything
        is relative to the root model.

        `exclude` - A list of models that are to be excluded from this tree.
        This is typically used to exclude models not intended to be exposed
        through this API.

        `routes` - Explicitly defines a join path between two models. Each
        route is made up of four components. Assuming some model hierarchy
        exists as shown below..

                                ModelA
                                /    \
                            ModelB  ModelC
                               |      |
                               \    ModelD
                                \    /
                                ModelE

        ..the traversal path from ModelA to ModelE is ambiguous. It could
        go from A -> B -> E or A -> C -> D -> E. By default, the shortest
        path is always chosen to reduce the number of joins necessary, but
        if ModelD did not exist..

                                ModelA
                                 /  \
                            ModelB  ModelC
                                 \  /
                                ModelE

        ..both paths require only two joins, so whichever path is traversed
        first will be the one chosen.

        To explicitly choose a path, a route can be defined. Taking the form::

            {
                'source': 'app1.model1',
                'target': 'app1.model2',
                'field': None,
                'symmetrical': None,
            }

        The `source` model defines the model where the join is being created
        from (the left side of the join). The `target` model defines the
        target model (the right side of the join). `field` is optional,
        but explicitly defines the model field that will be used for the join.
        This is useful if there is more than one foreign key relationship
        between target and source. Finally, `symmetrical` is an optional
        boolean that ensures when the target and source models switch sides,
        the same join occurs on the same field.

        Routes are typically used for defining explicit join paths, but
        sometimes it is necessary to exclude join paths. For example if there
        are three possible paths and one should never occur.

        A modeltree config can have `required_routes` and `excluded_routes`
        entries, which are lists of routes in the above format.

        A required route is defined as follows: a join to the specified target
        model is only allowed from the specified source model.  A model can
        only be specified as a target once in the list of required routes.
        Note that the use of the `symmetrical` property of a route
        implicitly adds another route with target and source models swapped,
        so a model can be a target either directly or indirectly.  A single
        source model can participate in multiple required routes.

        An excluded route is more obvious: joining from the specified source
        model to the specified target model is not allowed.

    """                                                           # noqa: W605

    def __init__(self, model=None, **kwargs):
        if model is None and 'root_model' in kwargs:
            warnings.warn('The "root_model" key has been renamed to "model"',
                          DeprecationWarning)
            model = kwargs.get('root_model')

        if not model:
            raise TypeError('No "model" defined')

        excluded_models = kwargs.get('excluded_models', ())
        required_routes = kwargs.get('required_routes')

        if not excluded_models and 'exclude' in kwargs:
            warnings.warn(
                'The "exclude" key has been renamed to '
                '"excluded_models"', DeprecationWarning)

            excluded_models = kwargs.get('exclude', ())

        if not required_routes and 'routes' in kwargs:
            warnings.warn(
                'The "routes" key has been renamed to '
                '"required_routes"', DeprecationWarning)

            required_routes = kwargs.get('routes')

        excluded_routes = kwargs.get('excluded_routes')

        self.root_model = self.get_model(model, local=False)
        self.alias = kwargs.get('alias', None)

        # Models completely excluded from the tree
        self.excluded_models = [
            self.get_model(label, local=False) for label in excluded_models
        ]

        # Build the routes that are allowed/preferred
        self._required_joins = self._build_routes(
            required_routes, allow_redundant_targets=False)

        # Build the routes that are excluded
        self._excluded_joins = self._build_routes(excluded_routes)

        # cache each node relative their models
        self._nodes = {}

        # cache all app names relative to their model names i.e. supporting
        # multiple apps with models of the same name
        self._model_apps = MultiValueDict({})

        # cache (app, model) pairs with the respective model class
        self._models = {}

        self._build()

    def __repr__(self):
        return '<ModelTree for {0}>'.format(self.root_model.__name__)

    def _get_local_model(self, model_name, app_name=None):
        "Attempts to get a model from local cache."
        if not app_name:
            app_names = self._model_apps.getlist(model_name)
            # No apps found with this model
            if not app_names:
                return

            # Multiple apps found for this model
            if len(app_names) > 1:
                raise ModelNotUnique(
                    'The model "{0}" is not unique. '
                    'Specify the app name as well.'.format(model_name))

            app_name = app_names[0]

        try:
            return self._models[(app_name, model_name)]
        except KeyError:
            pass

    def _get_model(self, model_name, app_name=None):
        """Look a model up in Django's application registry.

        With ``app_name``, the lookup is delegated to ``apps.get_model``.
        Otherwise every app is scanned. The result is None if no app defines
        ``model_name``, and ``ModelNotUnique`` is raised if more than one does.
        """
        if app_name:
            return apps.get_model(app_name, model_name)

        found = None
        for app_models in list(apps.app_models.values()):
            if model_name not in app_models:
                continue
            if found is not None:
                raise ModelNotUnique(
                    'The model "{0}" is not unique. '
                    'Specify the app name as well.'.format(model_name))
            found = app_models[model_name]
        return found

    def get_model(self, model_name=None, app_name=None, local=True):
        """A few variations are handled here for increased flexibility:

            - if a model class is given, simply echo the model back

            - if a app-model label e.g. 'library.book', is passed, the
            standard app_models cache is used

            - if `app_name` and `model_name` is provided, the standard
            app_models cache is used

            - if only `model_name` is supplied, attempt to find the model
            across all apps. if the model is found more than once, an error
            is thrown

            - if `local` is true, only models related to this `ModelTree`
            instance are searched through
        """
        model = None

        if not (app_name or model_name):
            return self.root_model

        # model class
        if inspect.isclass(model_name) and \
                issubclass(model_name, models.Model):
            # set it initially for either local and non-local
            model = model_name

            # additional check to ensure the model exists locally, reset to
            # None if it does not
            if local and model not in self._nodes:
                model = None

        # handle string-based arguments
        else:
            # handle the syntax 'library.book'
            if model_name:
                if '.' in model_name:
                    app_name, model_name = model_name.split('.', 1)
                model_name = model_name.lower()

            if local:
                model = self._get_local_model(model_name, app_name)
            else:
                model = self._get_model(model_name, app_name)

        # both mechanisms above may result in no model being found
        if model is None:
            if local:
                raise ModelNotRelated(
                    'No model found named "{0}"'.format(model_name))
            else:
                raise ModelDoesNotExist(
                    'No model found named "{0}"'.format(model_name))

        return model

    def get_field(self, name, model=None):
        if model is None:
            model = self.root_model
        return model._meta.get_field(name)

    def _build_routes(self, routes, allow_redundant_targets=True):
        """Routes provide a means of specifying JOINs between two tables.

        routes - a collection of dicts defining source->target mappings
                 with optional `field` specifier and `symmetrical` attribute.

        allow_redundant_targets - whether two routes in this collection
                 are allowed to have the same target - this should NOT
                 be allowed for required routes.
        """
        routes = routes or ()
        joins = {}
        targets_seen = set()

        for route in routes:
            if isinstance(route, dict):
                source_label = route.get('source')
                target_label = route.get('target')
                field_label = route.get('field')
                symmetrical = route.get('symmetrical')
            else:
                warnings.warn('Routes are now defined as dicts',
                              DeprecationWarning)
                source_label, target_label, field_label, symmetrical = route

            # get models
            source = self.get_model(source_label, local=False)
            target = self.get_model(target_label, local=False)

            field = None

            # get field
            if field_label:
                model_name, field_name = field_label.split('.', 1)
                model_name = model_name.lower()

                # determine which model the join field specified exists on
                if model_name == source.__name__.lower():
                    field = self.get_field(field_name, source)
                elif model_name == target.__name__.lower():
                    field = self.get_field(field_name, target)
                else:
                    raise TypeError('model for join field, "{0}", '
                                    'does not exist'.format(field_name))

                if isinstance(field, (ManyToOneRel, ManyToManyRel)):
                    field = field.field

            if not allow_redundant_targets:
                if target in targets_seen:
                    raise ValueError('Model {0} cannot be the target of '
                                     'more than one route in this list'.format(
                                         target_label))
                else:
                    targets_seen.add(target)

            # The `joins` hash defines pairs which are explicitly joined
            # via the specified field.  If no field is defined, then the
            # join field is implied or does not matter; the route is reduced
            #  to a straight lookup.
            joins[(source, target)] = field

            if symmetrical:
                if not allow_redundant_targets:
                    if source in targets_seen:
                        raise ValueError(
                            'Model {0} cannot be the target of '
                            'more than one route in this list'.format(
                                source_label))
                    else:
                        targets_seen.add(source)

                joins[(target, source)] = field

        return joins

    def _join_allowed(self, source, target, field=None):
        """Checks if the join between `source` and `target` via `field`
        is allowed.
        """
        join = (source, target)

        # No circles
        if target == source:
            return False

        # Prevent join to excluded models
        if target in self.excluded_models:
            return False

        # Never go back through the root
        if target == self.root_model:
            return False

        # Apply excluded joins if any
        if join in self._excluded_joins:
            _field = self._excluded_joins[join]
            if not _field:
                return False
            elif _field and _field == field:
                return False

        # Check if the join is allowed by a required rule
        for (_source, _target), _field in list(self._required_joins.items()):
            if _target == target:
                if _source != source:
                    return False

                # If a field is supplied, check to see if the field is allowed
                # for this join.
                if field and _field and _field != field:
                    return False

        return True

    def _add_node(self, parent, model, relation, reverse, related_name,
                  accessor_name, nullable, depth):
        """Adds a node to the tree only if a node of the same `model' does not
        already exist in the tree with smaller depth. If the node is added, the
        tree traversal continues finding the node's relations.

        Conditions in which the node will fail to be added:

            - a reverse relationship is blocked via the '+'
            - the model is excluded completely
            - the model is going back the same path it came from
            - the model is circling back to the root_model
            - the model does not come from an explicitly declared parent model


        This is running in a recursive way with _find_relations
        They keep calling each other as they go through all the models
        """
        # Reverse relationships
        if reverse and '+' in related_name:
            return

        node_hash = self._nodes.get(model, None)

        # don't add node if a path with a shorter depth exists. this is applied
        # after the correct join has been determined. generally if a route is
        # defined for relation, this will never be an issue since there would
        # only be one path available. if a route is not defined, the shorter
        # path will be found
        if not node_hash or node_hash['depth'] > depth:
            if node_hash:
                node_hash['parent'].remove_child(model)

            node = ModelTreeNode(model, parent, relation, reverse,
                                 related_name, accessor_name, nullable, depth)

            self._nodes[model] = {
                'parent': parent,
                'depth': depth,
                'node': node,
            }

            node = self._find_relations(node, depth)
            parent.children.append(node)

    def _find_relations(self, node, depth=0):
        """
        Finds all relations given a node.
        
        This runs in a recursive way with _add_node. They keep calling each
        other based on depth.
        """
        depth += 1

        model = node.model

        # NOTE: the many-to-many relations are evaluated first to prevent
        # 'through' models being bound as a ForeignKey relationship.
        fields = sorted(model._meta.get_fields(),
                        reverse=True,
                        key=lambda f: bool(f.many_to_many))

        # determine relational fields to determine paths
        # f.rel changed to f.remote_field for django2
        # f.rel.to changed to f.remote_field.model for django2
        forward_fields = [
            f for f in fields
            if (f.one_to_one or f.many_to_many or f.many_to_one) and (
                f.concrete or not f.auto_created) and f.remote_field is
            not None  # Generic foreign keys do not define rel.
            and self._join_allowed(f.model, f.remote_field.model, f)
        ]
        reverse_fields = [
            f for f in fields
            if (f.one_to_many or f.one_to_one or f.many_to_many) and (
                not f.concrete and f.auto_created)
            and self._join_allowed(f.model, f.related_model, f.field)
        ]

        def get_relation_type(f):
            if f.one_to_one:
                return 'onetone'
            elif f.many_to_many:
                return 'manytomany'
            elif f.one_to_many or f.many_to_one:
                return 'foreignkey'

        # Iterate over forward relations
        # changed f.rel.to to f.remote_field.model for django2
        for f in forward_fields:
            null = f.many_to_many or f.null
            kwargs = {
                'parent': node,
                'model': f.remote_field.model,
                'relation': get_relation_type(f),
                'reverse': False,
                'related_name': f.name,
                'accessor_name': f.name,
                'nullable': null,
                'depth': depth,
            }
            self._add_node(**kwargs)

        # Iterate over reverse relations.
        for r in reverse_fields:
            kwargs = {
                'parent': node,
                'model': r.related_model,
                'relation': get_relation_type(r),
                'reverse': True,
                'related_name': r.field.related_query_name(),
                'accessor_name': r.get_accessor_name(),
                'nullable': True,
                'depth': depth,
            }
            self._add_node(**kwargs)

        return node

    def _build(self):
        """Build the tree from ``root_model`` and index every model in it.

        Relations are discovered starting from a fresh root node. The root
        model is then registered in ``self._nodes`` at depth 0 with no
        parent. Finally each model's app label is recorded under its
        lower-cased object name in ``self._model_apps``, and the model is
        stored under ``(app_label, name)`` in ``self._models``.
        """
        root = ModelTreeNode(self.root_model)
        self._root_node = self._find_relations(root)

        self._nodes[self.root_model] = {
            'parent': None,
            'depth': 0,
            'node': self._root_node,
        }

        # Local name -> app and (app, name) -> model caches.
        for model in self._nodes:
            opts = model._meta
            name = opts.object_name.lower()
            label = opts.app_label

            self._model_apps.appendlist(name, label)
            self._models[(label, name)] = model

    @property
    def root_node(self):
        "Return the root node, building the tree on first access."
        try:
            return self._root_node
        except AttributeError:
            self._build()
            return self._root_node

    def _node_path_to_model(self, model, node, path=[]):
        "Returns a list representing the path of nodes to the model."
        if node.model == model:
            return path

        for child in node.children:
            mpath = self._node_path_to_model(model, child, path + [child])
            # TODO why is this condition here?
            if mpath:
                return mpath

    def _node_path(self, model):
        "Returns a list of nodes thats defines the path of traversal."
        model = self.get_model(model)
        return self._node_path_to_model(model, self.root_node)

    def get_joins(self, model):
        """Return the JOIN connections needed to reach ``model``.

        The list can be applied manually to a QuerySet (see ``add_joins``),
        which lets the ORM set up JOINs appropriate to the QuerySet being
        altered. Only the first node's base table is included; every node's
        own joins are appended in path order.
        """
        joins = []
        include_table = True

        for node in self._node_path(model):
            table, path_joins = node.get_joins()
            if include_table:
                joins.append(table)
                include_table = False
            joins.extend(path_joins)

        return joins

    def query_string(self, model):
        "Return the ``__``-joined lookup path from the root to ``model``."
        names = [node.related_name for node in self._node_path(model)]
        return str('__'.join(names))

    def query_string_for_field(self, field, operator=None, model=None):
        """Return a lookup string for ``field`` relative to the root model.

        field - a model field, or a reverse relation object
                (ManyToManyRel / ManyToOneRel), whose related query name is
                used directly.
        operator - optional lookup suffix appended with ``__``.
        model - model to build the path to; defaults to ``field.model``. A
                proxy model is only accepted when it proxies ``field.model``.

        Raises ModelTreeError for a proxy model that does not proxy the
        field's model.
        """
        if not model:
            model = field.model
        elif model._meta.proxy and \
                model._meta.proxy_for_model is not field.model:
            raise ModelTreeError('proxied model must be the field model')

        if isinstance(field, (ManyToManyRel, ManyToOneRel)):
            # Explicit reverse relations are addressed by their query name.
            toks = [field.field.related_query_name()]
        else:
            prefix = self.query_string(model)
            toks = [prefix, field.name] if prefix else [field.name]

        if operator is not None:
            toks.append(operator)

        return str('__'.join(toks))

    def query_condition(self, field, operator, value, model=None):
        "Convenience method for constructing a `Q` object for a given field."
        lookup = self.query_string_for_field(
            field, operator=operator, model=model)
        condition = {lookup: value}
        return Q(**condition)

    def add_joins(self, model, queryset=None):
        """Apply every join needed to reach ``model`` to a QuerySet copy.

        queryset - QuerySet to start from; it is cloned. Defaults to
                   ``get_queryset()``.

        Base tables already present in the query's alias map are skipped.
        Returns ``(queryset, alias)``, where ``alias`` is the alias of the
        last join added, or the query's initial alias when no join was
        needed (the model lives on the root model's table).
        """
        clone = self.get_queryset() if queryset is None else queryset._clone()

        alias = None
        for join in self.get_joins(model):
            if isinstance(join, BaseTable):
                alias_map = clone.query.alias_map
                if join.table_alias in alias_map or \
                        join.table_name in alias_map:
                    continue
            alias = clone.query.join(join)

        # No join was applied: fall back to the root model's table.
        if alias is None:
            alias = clone.query.get_initial_alias()

        return clone, alias

    def add_select(self, *fields, **kwargs):
        """Replace the ``SELECT`` columns of a QuerySet with ``fields``.

        Each item of ``fields`` is either a model field (reached through its
        own ``field.model``) or a ``(model, field)`` pair. Joins needed to
        reach each model are added via ``add_joins``.

        Keyword arguments:
            queryset - QuerySet to start from; defaults to ``get_queryset()``
                       (also when None). A supplied QuerySet is cloned and
                       left unmodified.
            include_pk - when true (the default), the root model's primary
                         key is selected first.

        Returns the resulting QuerySet.
        """
        queryset = kwargs.pop('queryset', None)
        include_pk = kwargs.pop('include_pk', True)

        if queryset is None:
            queryset = self.get_queryset()
        else:
            # Work on a copy: setting ``default_cols`` on the caller's query
            # would silently alter their QuerySet.
            queryset = queryset._clone()

        queryset.query.default_cols = False

        if include_pk:
            fields = [self.root_model._meta.pk] + list(fields)

        columns = []
        for item in fields:
            if isinstance(item, (list, tuple)):
                model, field = item
            else:
                field = item
                model = field.model

            queryset, alias = self.add_joins(model, queryset)
            columns.append(Col(alias, field, field))

        if columns:
            queryset.query.select = columns

        return queryset

    def get_queryset(self):
        "Return the default manager's QuerySet for ``root_model``."
        manager = self.root_model._default_manager
        return manager.get_queryset()
 def test_appendlist(self):
     """appendlist() collects successive values for one key, in order."""
     mvd = MultiValueDict()
     for value in ("Adrian", "Simon"):
         mvd.appendlist("name", value)
     self.assertEqual(mvd.getlist("name"), ["Adrian", "Simon"])
Example #26
0
class ModelTree(object):
    """A class to handle building and parsing a tree structure given a model.

        `root_model` - The root or "reference" model for the tree. Everything
        is relative to the root model.

        `exclude` - A list of models that are to be excluded from this tree.
        This is typically used to exclude models not intended to be exposed
        through this API.

        `routes` - Explicitly defines a join path between two models. Each
        route is made up of four components. Assuming some model hierarchy
        exists as shown below..

                                ModelA
                                /    \
                            ModelB  ModelC
                               |      |
                               \    ModelD
                                \    /
                                ModelE

        ..the traversal path from ModelA to ModelE is ambiguous. It could
        go from A -> B -> E or A -> C -> D -> E. By default, the shortest
        path is always chosen to reduce the number of joins necessary, but
        if ModelD did not exist..

                                ModelA
                                 /  \
                            ModelB  ModelC
                                 \  /
                                ModelE

        ..both paths only require two joins, so whichever path is traversed
        first is the one chosen.

        To explicitly choose a path, a route can be defined. Taking the form::

            {
                'source': 'app1.model1',
                'target': 'app1.model2',
                'field': None,
                'symmetrical': None,
            }

        The `source` model defines the model where the join is being created
        from (the left side of the join). The `target` model defines the
        target model (the right side of the join). `field` is optional,
        but explicitly defines the model field that will be used for the join.
        This is useful if there is more than one foreign key relationship
        between target and source. Finally, `symmetrical` is an optional
        boolean that ensures when the target and source models switch sides,
        the same join occurs on the same field.

        Routes are typically used for defining explicit join paths, but
        sometimes it is necessary to exclude join paths. For example if there
        are three possible paths and one should never occur.

        A modeltree config takes the `required_routes` and `excluded_routes`
        which is a list of routes in the above format.

    """
    def __init__(self, model=None, **kwargs):
        """Configure the tree for ``model`` and build it immediately.

        model - the root model, as a model class or a label resolved by
                ``get_model``. ``root_model`` is accepted as a legacy name.

        Keyword options:
            excluded_models - models left out of the tree entirely
                              (legacy name: ``exclude``).
            required_routes - routes joins must follow
                              (legacy name: ``routes``).
            excluded_routes - routes joins must never follow.
            alias - optional value stored as ``self.alias``.

        Each legacy name emits a ``DeprecationWarning``. Raises TypeError
        when no model is given.
        """
        # Legacy keyword for the root model.
        if model is None and 'root_model' in kwargs:
            warnings.warn('The "root_model" key has been renamed to "model"',
                          DeprecationWarning)
            model = kwargs.get('root_model')

        if not model:
            raise TypeError('No "model" defined')

        excluded_models = kwargs.get('excluded_models', ())
        required_routes = kwargs.get('required_routes')

        # Legacy option names are consulted only when the new ones are
        # absent or empty.
        if not excluded_models and 'exclude' in kwargs:
            warnings.warn('The "exclude" key has been renamed to '
                          '"excluded_models"', DeprecationWarning)
            excluded_models = kwargs.get('exclude', ())

        if not required_routes and 'routes' in kwargs:
            warnings.warn('The "routes" key has been renamed to '
                          '"required_routes"', DeprecationWarning)
            required_routes = kwargs.get('routes')

        excluded_routes = kwargs.get('excluded_routes')

        self.root_model = self.get_model(model, local=False)
        self.alias = kwargs.get('alias', None)

        # Models completely excluded from the tree.
        self.excluded_models = [
            self.get_model(label, local=False) for label in excluded_models
        ]

        # Routes joins are required to follow...
        required = self._build_routes(required_routes)
        self._required_joins, self._required_join_fields = required

        # ...and routes they must avoid.
        excluded = self._build_routes(excluded_routes)
        self._excluded_joins, self._excluded_join_fields = excluded

        # Per-model bookkeeping: {model: {'parent', 'depth', 'node'}}.
        self._nodes = {}

        # App labels per lower-cased model name, allowing several apps to
        # define models with the same name.
        self._model_apps = MultiValueDict({})

        # (app_label, model_name) -> model class.
        self._models = {}

        self._build()

    def __repr__(self):
        return u'<ModelTree for {0}>'.format(self.root_model.__name__)

    def _get_local_model(self, model_name, app_name=None):
        "Attempts to get a model from local cache."
        if not app_name:
            app_names = self._model_apps.getlist(model_name)
            # No apps found with this model
            if not app_names:
                return

            # Multiple apps found for this model
            if len(app_names) > 1:
                raise ModelNotUnique('The model "{0}" is not unique. '
                                     'Specify the app name as well.'
                                     .format(model_name))

            app_name = app_names[0]

        try:
            return self._models[(app_name, model_name)]
        except KeyError:
            pass

    def _get_model(self, model_name, app_name=None):
        """Look ``model_name`` up in Django's application cache.

        With ``app_name`` the search is limited to that app. Without it,
        every app is scanned and ModelNotUnique is raised if more than one
        defines ``model_name``. Returns None when nothing matches.
        """
        if app_name:
            return models.get_model(app_name, model_name)

        found = None
        for app_models in loading.cache.app_models.values():
            if model_name not in app_models:
                continue
            if found is not None:
                raise ModelNotUnique('The model "{0}" is not unique. '
                                     'Specify the app name as well.'
                                     .format(model_name))
            found = app_models[model_name]

        return found

    def get_model(self, model_name=None, app_name=None, local=True):
        """Resolve a model from several kinds of input.

            - no arguments: the tree's ``root_model`` is returned

            - a model class: it is echoed back (when ``local`` is true it
            must also be part of this tree)

            - an 'app.model' label such as 'library.book', or separate
            ``app_name`` and ``model_name``: that app's model is looked up

            - only ``model_name``: every app is searched, and an error is
            raised if the name is found more than once

            - ``local`` true restricts the search to models in this
            ``ModelTree``; false searches the application cache

        Raises ModelNotRelated (local) or ModelDoesNotExist (non-local) when
        no model is found.
        """
        if not (app_name or model_name):
            return self.root_model

        is_model_class = (inspect.isclass(model_name) and
                          issubclass(model_name, models.Model))

        if is_model_class:
            model = model_name
            # A local lookup only accepts classes that are in this tree.
            if local and model not in self._nodes:
                model = None
        else:
            if model_name:
                # Split 'library.book' style labels.
                if '.' in model_name:
                    app_name, model_name = model_name.split('.', 1)
                model_name = model_name.lower()

            lookup = self._get_local_model if local else self._get_model
            model = lookup(model_name, app_name)

        if model is None:
            if local:
                raise ModelNotRelated('No model found named "{0}"'
                                      .format(model_name))
            raise ModelDoesNotExist('No model found named "{0}"'
                                    .format(model_name))

        return model

    def get_field(self, name, model=None):
        """Return the field named ``name`` on ``model`` (default: root model).

        The field is the first element of ``_meta.get_field_by_name(name)``.
        """
        owner = model if model is not None else self.root_model
        result = owner._meta.get_field_by_name(name)
        return result[0]

    def _build_routes(self, routes):
        "Routes provide a means of specifying JOINs between two tables."
        routes = routes or ()
        joins = {}
        join_fields = {}

        for route in routes:
            if isinstance(route, dict):
                source_label = route.get('source')
                target_label = route.get('target')
                field_label = route.get('field')
                symmetrical = route.get('symmetrical')
            else:
                warnings.warn('Routes are now defined as dicts',
                              DeprecationWarning)
                source_label, target_label, field_label, symmetrical = route

            # get models
            source = self.get_model(source_label, local=False)
            target = self.get_model(target_label, local=False)

            field = None

            # get field
            if field_label:
                model_name, field_name = field_label.split('.', 1)
                model_name = model_name.lower()

                # determine which model the join field specified exists on
                if model_name == source.__name__.lower():
                    field = self.get_field(field_name, source)
                elif model_name == target.__name__.lower():
                    field = self.get_field(field_name, target)
                else:
                    raise TypeError('model for join field, "{0}", '
                                    'does not exist'.format(field_name))

                if isinstance(field, RelatedObject):
                    field = field.field

            # the `joins` hash defines pairs which are explicitly joined
            # via the specified field
            # if no field is defined, then the join field is implied or
            # does not matter. the route is reduced to a straight lookup
            joins[target] = source
            if symmetrical:
                joins[source] = target

            if field is not None:
                join_fields[(source, target)] = field
                if symmetrical:
                    join_fields[(target, source)] = field

        return joins, join_fields

    def _join_allowed(self, source, target, field=None):
        """Checks if the join between `source` and `target` via `field`
        is allowed.
        """
        join = (source, target)

        # No circles
        if target == source:
            return False

        # Prevent join to excluded models
        if target in self.excluded_models:
            return False

        # Never go back through the root
        if target == self.root_model:
            return False

        # Check if the join is excluded via a specific field
        if field and join in self._excluded_join_fields:
            _field = self._excluded_join_fields[join]
            if _field == field:
                return False

        # Model level..
        elif source == self._excluded_joins.get(target):
            return False

        # Check if the join is allowed
        if target in self._required_joins:
            _source = self._required_joins[target]
            if _source != source:
                return False

            # If a field is supplied, check to see if the field is allowed
            # for this join.
            if field:
                _field = self._required_join_fields.get(join)
                if _field and _field != field:
                    return False

        return True

    def _filter_one2one(self, field):
        """Return ``field`` if it is a OneToOneField whose join is allowed.

        The join from ``field.model`` to ``field.rel.to`` via this field is
        checked with ``_join_allowed``, so an explicit route can pick which
        field joins the two tables. Returns None otherwise.
        """
        if not isinstance(field, models.OneToOneField):
            return None
        if self._join_allowed(field.model, field.rel.to, field):
            return field
        return None

    def _filter_related_one2one(self, rel):
        """Return ``rel`` if it is a reverse OneToOneField whose join is allowed.

        The join from ``rel.parent_model`` to ``rel.model`` via ``rel.field``
        is checked with ``_join_allowed``, so an explicit route can pick
        which field joins the two tables. Returns None otherwise.
        """
        one2one = rel.field
        if not isinstance(one2one, models.OneToOneField):
            return None
        if self._join_allowed(rel.parent_model, rel.model, one2one):
            return rel
        return None

    def _filter_fk(self, field):
        """Return ``field`` if it is a ForeignKey whose join is allowed.

        The join from ``field.model`` to ``field.rel.to`` via this field is
        checked with ``_join_allowed``, so an explicit route can pick which
        field joins the two tables. Returns None otherwise.
        """
        if not isinstance(field, models.ForeignKey):
            return None
        if self._join_allowed(field.model, field.rel.to, field):
            return field
        return None

    def _filter_related_fk(self, rel):
        """Return ``rel`` if it is a reverse ForeignKey whose join is allowed.

        The join from ``rel.parent_model`` to ``rel.model`` via ``rel.field``
        is checked with ``_join_allowed``, so an explicit route can pick
        which field joins the two tables. Returns None otherwise.
        """
        fk = rel.field
        if not isinstance(fk, models.ForeignKey):
            return None
        if self._join_allowed(rel.parent_model, rel.model, fk):
            return rel
        return None

    def _filter_m2m(self, field):
        """Return ``field`` if it is a ManyToManyField whose join is allowed.

        The join from ``field.model`` to ``field.rel.to`` via this field is
        checked with ``_join_allowed``, so an explicit route can pick which
        field joins the two tables. Returns None otherwise.
        """
        if not isinstance(field, models.ManyToManyField):
            return None
        if self._join_allowed(field.model, field.rel.to, field):
            return field
        return None

    def _filter_related_m2m(self, rel):
        """Return ``rel`` if it is a reverse ManyToManyField whose join is allowed.

        The join from ``rel.parent_model`` to ``rel.model`` via ``rel.field``
        is checked with ``_join_allowed``, so an explicit route can pick
        which field joins the two tables. Returns None otherwise.
        """
        m2m = rel.field
        if not isinstance(m2m, models.ManyToManyField):
            return None
        if self._join_allowed(rel.parent_model, rel.model, m2m):
            return rel
        return None

    def _add_node(self, parent, model, relation, reverse, related_name,
                  accessor_name, nullable, depth):
        """Adds a node to the tree only if a node of the same `model' does not
        already exist in the tree with smaller depth. If the node is added, the
        tree traversal continues finding the node's relations.

        Conditions in which the node will fail to be added:

            - a reverse relationship is blocked via the '+'
            - the model is excluded completely
            - the model is going back the same path it came from
            - the model is circling back to the root_model
            - the model does not come from an explicitly declared parent model
        """
        # Reverse relationships
        if reverse and '+' in related_name:
            return

        node_hash = self._nodes.get(model, None)

        # don't add node if a path with a shorter depth exists. this is applied
        # after the correct join has been determined. generally if a route is
        # defined for relation, this will never be an issue since there would
        # only be one path available. if a route is not defined, the shorter
        # path will be found
        if not node_hash or node_hash['depth'] > depth:
            if node_hash:
                node_hash['parent'].remove_child(model)

            node = ModelTreeNode(model, parent, relation, reverse,
                                 related_name, accessor_name, nullable, depth)

            self._nodes[model] = {
                'parent': parent,
                'depth': depth,
                'node': node,
            }

            node = self._find_relations(node, depth)
            parent.children.append(node)

    def _find_relations(self, node, depth=0):
        """Finds all relations given a node.

        NOTE: the many-to-many relations are evaluated first to prevent
        'through' models being bound as a ForeignKey relationship.
        """
        depth += 1

        model = node.model
        opts = model._meta

        # determine relational fields to determine paths
        forward_fields = opts.fields
        reverse_fields = opts.get_all_related_objects()

        forward_o2o = filter(self._filter_one2one, forward_fields)
        reverse_o2o = filter(self._filter_related_one2one, reverse_fields)

        forward_fk = filter(self._filter_fk, forward_fields)
        reverse_fk = filter(self._filter_related_fk, reverse_fields)

        forward_m2m = filter(self._filter_m2m, opts.many_to_many)
        reverse_m2m = filter(self._filter_related_m2m,
                             opts.get_all_related_many_to_many_objects())

        # iterate m2m relations
        for f in forward_m2m:
            kwargs = {
                'parent': node,
                'model': f.rel.to,
                'relation': 'manytomany',
                'reverse': False,
                'related_name': f.name,
                'accessor_name': f.name,
                'nullable': True,
                'depth': depth,
            }
            self._add_node(**kwargs)

        # iterate over related m2m fields
        for r in reverse_m2m:
            kwargs = {
                'parent': node,
                'model': r.model,
                'relation': 'manytomany',
                'reverse': True,
                'related_name': r.field.related_query_name(),
                'accessor_name': r.get_accessor_name(),
                'nullable': True,
                'depth': depth,
            }
            self._add_node(**kwargs)

        # iterate over one2one fields
        for f in forward_o2o:
            kwargs = {
                'parent': node,
                'model': f.rel.to,
                'relation': 'onetoone',
                'reverse': False,
                'related_name': f.name,
                'accessor_name': f.name,
                'nullable': False,
                'depth': depth,
            }
            self._add_node(**kwargs)

        # iterate over related one2one fields
        for r in reverse_o2o:
            kwargs = {
                'parent': node,
                'model': r.model,
                'relation': 'onetoone',
                'reverse': True,
                'related_name': r.field.related_query_name(),
                'accessor_name': r.get_accessor_name(),
                'nullable': False,
                'depth': depth,
            }
            self._add_node(**kwargs)

        # iterate over fk fields
        for f in forward_fk:
            kwargs = {
                'parent': node,
                'model': f.rel.to,
                'relation': 'foreignkey',
                'reverse': False,
                'related_name': f.name,
                'accessor_name': f.name,
                'nullable': f.null,
                'depth': depth,
            }
            self._add_node(**kwargs)

        # iterate over related foreign keys
        for r in reverse_fk:
            kwargs = {
                'parent': node,
                'model': r.model,
                'relation': 'foreignkey',
                'reverse': True,
                'related_name': r.field.related_query_name(),
                'accessor_name': r.get_accessor_name(),
                'nullable': True,
                'depth': depth,
            }
            self._add_node(**kwargs)

        return node

    def _build(self):
        """Construct the model tree rooted at `root_model`.

        Relations are discovered starting from a fresh root node. The root
        model is then registered in `_nodes` with no parent and depth 0.
        Finally, every model in `_nodes` is indexed:

        - `_model_apps` maps the lower-cased object name to app labels
          (via `appendlist`);
        - `_models` maps ``(app_label, lower_cased_name)`` to the model.
        """
        root = self._find_relations(ModelTreeNode(self.root_model))
        self._root_node = root

        self._nodes[self.root_model] = dict(parent=None, depth=0, node=root)

        # Local lookup caches for every model reachable in this tree.
        for mdl in self._nodes:
            meta = mdl._meta
            lowered = meta.object_name.lower()
            label = meta.app_label
            self._model_apps.appendlist(lowered, label)
            self._models[(label, lowered)] = mdl

    @property
    def root_node(self):
        """Return the tree's root node, building the tree on first access."""
        try:
            return self._root_node
        except AttributeError:
            # Not built yet: `_build` is responsible for setting `_root_node`.
            pass
        self._build()
        return self._root_node

    def _node_path_to_model(self, model, node, path=[]):
        "Returns a list representing the path of nodes to the model."
        if node.model == model:
            return path

        for child in node.children:
            mpath = self._node_path_to_model(model, child, path + [child])
            # TODO why is this condition here?
            if mpath:
                return mpath

    def _node_path(self, model):
        "Returns a list of nodes thats defines the path of traversal."
        model = self.get_model(model)
        return self._node_path_to_model(model, self.root_node)

    def get_joins(self, model, **kwargs):
        """Return the JOIN connections needed to reach `model` from the root.

        The list can be applied manually to a QuerySet (see `.add_joins()`).
        This lets the ORM set up the JOINs, which may differ depending on the
        QuerySet being altered. Keyword arguments are forwarded to each
        node's `get_joins`.
        """
        collected = []
        is_first = True
        for node in self._node_path(model):
            node_joins = node.get_joins(**kwargs)
            # Every node after the first has its leading join dropped
            # (presumably it repeats a join already contributed by the
            # previous node -- confirm against ModelTreeNode.get_joins).
            collected.extend(node_joins if is_first else node_joins[1:])
            is_first = False
        return collected

    def query_string(self, model):
        """Return the `related_name` of every node on the path to `model`,
        joined with ``'__'`` (an empty string for the root model's path)."""
        names = [step.related_name for step in self._node_path(model)]
        return str('__'.join(names))

    def query_string_for_field(self, field, operator=None, model=None):
        """Take a `models.Field` instance and return a query string relative
        to the root model.

        When `model` is given and is a proxy, it must proxy `field.model`;
        otherwise `ModelTreeError` is raised. When `model` is omitted,
        `field.model` is used. A `RelatedObject` (an explicit reverse field)
        contributes only its own related query name. `operator`, when given,
        is appended as the final ``'__'``-separated component.
        """
        if not model:
            model = field.model
        elif (model._meta.proxy
                and model._meta.proxy_for_model is not field.model):
            raise ModelTreeError('proxied model must be the field model')

        if isinstance(field, RelatedObject):
            parts = [field.field.related_query_name()]
        else:
            prefix = self.query_string(model)
            parts = [prefix, field.name] if prefix else [field.name]

        if operator is not None:
            parts.append(operator)

        return str('__'.join(parts))

    def query_condition(self, field, operator, value, model=None):
        """Convenience method for constructing a `Q` object that applies
        `operator` and `value` to `field` (see `query_string_for_field`)."""
        conditions = {
            self.query_string_for_field(field, model=model,
                                        operator=operator): value,
        }
        return Q(**conditions)

    def add_joins(self, model, queryset=None, **kwargs):
        """Set up all joins needed to reach `model` on a copy of a queryset.

        `queryset` defaults to `get_queryset()`. Otherwise it is cloned, so
        the caller's queryset is not modified. Extra keyword arguments are
        forwarded to `get_joins`.

        Returns a ``(queryset, alias)`` tuple, where `alias` is the alias of
        `model`'s database table in the returned queryset's query.
        """
        if queryset is None:
            clone = self.get_queryset()
        else:
            clone = queryset._clone()

        alias = None

        # The alias from the last join applied is the alias for `model`.
        for join in self.get_joins(model, **kwargs):
            alias = clone.query.join(**join)

        # No alias implies the join is redundant and occurs on the root
        # model's table, so fall back to the query's initial alias.
        if alias is None:
            alias = clone.query.get_initial_alias()

        return clone, alias

    def add_select(self, *fields, **kwargs):
        """Replace the `SELECT` columns of a queryset with the ones provided.

        Accepted keyword arguments:

        - `queryset`: queryset to modify. When omitted or None,
          `get_queryset()` is used, matching the `queryset=None` convention
          of `add_joins`.
        - `include_pk`: when true (the default), the root model's primary
          key is selected before `fields`.

        Any other keyword arguments are forwarded to `add_joins`. Each
        field's model is joined via `add_joins`, and the field's column is
        selected from the returned table alias. When there is nothing to
        select, the queryset's `SELECT` columns are left unchanged.

        Returns the resulting queryset.
        """
        queryset = kwargs.pop('queryset', None)
        if queryset is None:
            queryset = self.get_queryset()

        include_pk = kwargs.pop('include_pk', True)

        if include_pk:
            fields = [self.root_model._meta.pk] + list(fields)

        # On Django >= 1.6 select entries are wrapped in `SelectInfo`; older
        # versions use bare (alias, column) tuples. Decide this once, and only
        # when something will actually be selected.
        select_info = None
        if fields and django.VERSION >= (1, 6):
            from django.db.models.sql.constants import (
                SelectInfo as select_info)

        aliases = []

        for field in fields:
            queryset, alias = self.add_joins(field.model, queryset, **kwargs)
            col = (alias, field.column)
            if select_info is not None:
                aliases.append(select_info(col, field))
            else:
                aliases.append(col)

        if aliases:
            queryset.query.select = aliases

        return queryset

    def get_queryset(self):
        """Return a QuerySet relative to the `root_model`.

        The QuerySet comes from the model's default manager. Django 1.6
        renamed `Manager.get_query_set()` to `get_queryset()`, and Django 1.8
        removed the old name. The new spelling is therefore preferred, with a
        fallback to the old one for older Django versions.
        """
        manager = self.root_model._default_manager
        getter = getattr(manager, 'get_queryset', None)
        if getter is None:
            getter = manager.get_query_set
        return getter()
Example #27
0
 def test_getlist_doesnt_mutate(self):
     """Extending the list returned by `getlist` leaves stored values intact."""
     mvd = MultiValueDict({'a': ['1', '2'], 'b': ['3']})
     returned = mvd.getlist('a')
     returned.extend(mvd.getlist('b'))
     self.assertEqual(mvd.getlist('a'), ['1', '2'])
Example #28
0
 def test_getlist_default(self):
     """`getlist` returns the supplied default object itself for a missing key."""
     sentinel = object()
     mvd = MultiValueDict({'a': [1]})
     self.assertIs(mvd.getlist('b', default=sentinel), sentinel)

 def test_getlist_none_empty_values(self):
     """A stored None or empty list is returned unchanged by `getlist`."""
     mvd = MultiValueDict({"a": None, "b": []})
     self.assertIsNone(mvd.getlist("a"))
     self.assertEqual(mvd.getlist("b"), [])
Example #30
0
 def setlistdefault(self, key, default_list=()):
     """Store `default_list` under `key` if absent; return the list for `key`.

     Mutability is checked first via `_assert_mutable`. The result is read
     through `MultiValueDict.getlist` explicitly, bypassing any `getlist`
     override on this class.
     """
     self._assert_mutable()
     if key in self:
         return MultiValueDict.getlist(self, key)
     self.setlist(key, default_list)
     return MultiValueDict.getlist(self, key)
Example #31
0
 def test_appendlist(self):
     d = MultiValueDict()
     d.appendlist("name", "Adrian")
     d.appendlist("name", "Simon")
     self.assertEqual(d.getlist("name"), ["Adrian", "Simon"])