def test_create_node_class_adds_to_cache(self):
    """
    Looking up the node class for ``Book`` before and after explicitly
    declaring a ``ModelNode`` for it must yield the same class.
    """
    cached_before = get_node_class_for_model(Book)

    @six.add_metaclass(ModelNodeMeta)
    class ModelNode(ModelNodeMixin, StructuredNode):
        class Meta:
            model = Book

    cached_after = get_node_class_for_model(Book)
    self.assertEqual(cached_before, cached_after)
def test_m2m_changed_post_clear(self):
    """
    With the m2m handler connected, clearing ``book.authors`` leaves no
    ``Book`` node with ``authors`` and no ``Author`` node with ``book_set``.
    """
    default_uid = 'chemtrails.signals.handlers.m2m_changed_handler'
    test_uid = 'm2m_changed_handler.test'

    # Swap the project-wide registration for a test-scoped one.
    m2m_changed.disconnect(m2m_changed_handler, dispatch_uid=default_uid)
    m2m_changed.connect(m2m_changed_handler, dispatch_uid=test_uid)
    try:
        book = BookFixture(Book, generate_m2m={'authors': (1, 1)}).create_one()
        self.assertEqual(1, len(get_node_class_for_model(Book).nodes.has(authors=True)))

        book.authors.clear()

        self.assertEqual(0, len(get_node_class_for_model(Book).nodes.has(authors=True)))
        self.assertEqual(0, len(get_node_class_for_model(Author).nodes.has(book_set=True)))
    finally:
        # Put the project-wide registration back.
        m2m_changed.connect(m2m_changed_handler, dispatch_uid=default_uid)
        m2m_changed.disconnect(m2m_changed_handler, dispatch_uid=test_uid)
def test_m2m_changed_post_add_reverse(self):
    """
    Adding a book through the reverse accessor ``author.book_set`` creates
    the ``book_set`` relationship on the ``Author`` node.
    """
    default_uid = 'chemtrails.signals.handlers.m2m_changed_handler'
    test_uid = 'm2m_changed_handler.test'

    # Swap the project-wide registration for a test-scoped one.
    m2m_changed.disconnect(m2m_changed_handler, dispatch_uid=default_uid)
    m2m_changed.connect(m2m_changed_handler, dispatch_uid=test_uid)
    try:
        author = AuthorFixture(Author).create_one()
        self.assertEqual(0, len(get_node_class_for_model(Author).nodes.has(book_set=True)))

        book = BookFixture(Book, follow_m2m=False, field_values={'authors': []}).create_one()
        author.book_set.add(book)

        self.assertEqual(1, len(get_node_class_for_model(Author).nodes.has(book_set=True)))
    finally:
        # Put the project-wide registration back.
        m2m_changed.connect(m2m_changed_handler, dispatch_uid=default_uid)
        m2m_changed.disconnect(m2m_changed_handler, dispatch_uid=test_uid)
def create_nodes():
    """Create one ``Book`` and check that exactly one ``Book`` node exists."""
    created_book = BookFixture(Book).create_one()
    assert isinstance(get_node_for_object(created_book), StructuredNode)
    book_node_class = get_node_class_for_model(Book)
    assert len(book_node_class.nodes.all()) == 1
def test_get_nodeset_for_queryset(self):
    """
    ``get_nodeset_for_queryset`` returns a ``NodeSet`` whose members are all
    instances of the node class for the queryset's model.
    """
    created_stores = StoreFixture(Store).create(count=2, commit=True)
    queryset = Store.objects.filter(pk__in=[store.pk for store in created_stores])

    nodeset = get_nodeset_for_queryset(queryset)
    self.assertIsInstance(nodeset, NodeSet)
    for member in nodeset:
        self.assertIsInstance(member, get_node_class_for_model(queryset.model))
def pre_delete_handler(sender, instance, **kwargs):
    """
    Delete the node from the graph before it is removed from the database.

    Looks up the node with the same ``pk`` as ``instance`` and deletes it if
    one exists. Does nothing when ``settings.ENABLED`` is false.
    """
    if not settings.ENABLED:
        return
    node_class = get_node_class_for_model(instance._meta.model)
    existing = node_class.nodes.get_or_none(pk=instance.pk)
    if existing:
        existing.delete()
def test_null_foreignkey_is_disconnected(self):
    """
    Setting the ``bestseller`` foreign key to ``None`` and saving removes the
    ``bestseller`` relationship from the ``Store`` node.
    """
    default_uid = 'chemtrails.signals.handlers.post_save_handler'
    test_uid = 'post_save_handler.test'

    # Swap the project-wide registration for a test-scoped one.
    post_save.disconnect(post_save_handler, dispatch_uid=default_uid)
    post_save.connect(post_save_handler, dispatch_uid=test_uid)
    try:
        store = StoreFixture(Store, generate_m2m=False).create_one()
        store_class = get_node_class_for_model(Store)

        bestseller_node = get_node_class_for_model(Book).nodes.get(pk=store.bestseller.pk)
        self.assertEqual(store.bestseller.pk, bestseller_node.pk)
        self.assertEqual(1, len(store_class.nodes.has(bestseller=True)))

        store.bestseller = None
        store.save()

        self.assertEqual(0, len(store_class.nodes.has(bestseller=True)))
    finally:
        # Put the project-wide registration back.
        post_save.connect(post_save_handler, dispatch_uid=default_uid)
        post_save.disconnect(post_save_handler, dispatch_uid=test_uid)
def test_ignored_model_sync(self):
    """Syncing a ``Group`` (an ignored model) must not create a ``Group`` node."""
    group = Group.objects.create(name='group')
    get_node_for_object(group).sync()
    group_class = get_node_class_for_model(Group)

    try:
        group_class.nodes.get(pk=group.pk)
    except group_class.DoesNotExist as error:
        self.assertEqual(str(error), "{'pk': %d}" % group.pk)
    else:
        self.fail(
            'Did not fail when trying to look up an ignored node class.')
def test_inflate_node_class_raises_inflate_error(self):
    """
    Inflating a ``UserNode`` record with the ``Author`` node class must raise
    ``InflateError``.
    """
    author = AuthorFixture(Author).create_one()
    query = 'MATCH (u: UserNode {pk: %d}) RETURN u;' % author.user.pk
    # Flattening (results, meta) yields exactly two items; the first is the node.
    deflated, _ = list(flatten(db.cypher_query(query)))

    author_class = get_node_class_for_model(Author)
    self.assertRaises(InflateError, author_class.inflate, deflated)
def test_create_new_object_is_synced(self):
    """
    Saving a new ``Book`` with the post_save handler connected creates its
    node together with its ``authors`` and ``publisher`` relationships.
    """
    default_uid = 'chemtrails.signals.handlers.post_save_handler'
    test_uid = 'post_save_handler.test'

    # Swap the project-wide registration for a test-scoped one.
    post_save.disconnect(post_save_handler, dispatch_uid=default_uid)
    post_save.connect(post_save_handler, dispatch_uid=test_uid)
    try:
        book = BookFixture(Book).create_one()
        book_class = get_node_class_for_model(Book)

        self.assertEqual(book.pk, book_class.nodes.get(pk=book.pk).pk)
        self.assertEqual(1, len(book_class.nodes.has(authors=True, publisher=True)))
    finally:
        # Put the project-wide registration back.
        post_save.connect(post_save_handler, dispatch_uid=default_uid)
        post_save.disconnect(post_save_handler, dispatch_uid=test_uid)
def test_flush_nodes_context_manager(self):
    """Nodes created inside ``with flush_nodes():`` are gone once the block exits."""
    with flush_nodes():
        created_book = BookFixture(Book).create_one()
        self.assertIsInstance(get_node_for_object(created_book), StructuredNode)
        book_class = get_node_class_for_model(Book)
        self.assertEqual(len(book_class.nodes.all()), 1)

    self.assertEqual(len(book_class.nodes.all()), 0)
def test_related_ignore_model_sync(self):
    """
    Syncing a user that belongs to a ``Group`` must not create a node for the
    (ignored) related ``Group``.
    """
    user = User.objects.create_user(username='******', password='******')
    group = Group.objects.create(name='group')
    user.groups.add(group)
    get_node_for_object(user).sync()
    group_class = get_node_class_for_model(Group)

    try:
        group_class.nodes.get(pk=group.pk)
    except group_class.DoesNotExist as error:
        self.assertEqual(str(error), "{'pk': %d}" % group.pk)
    else:
        self.fail(
            'Did not fail when trying to look up an ignored node class.')
def test_flush_nodes_context_decorator(self):
    """Nodes created inside a ``@flush_nodes()``-decorated function are gone after it returns."""

    @flush_nodes()
    def create_nodes():
        created_book = BookFixture(Book).create_one()
        assert isinstance(get_node_for_object(created_book), StructuredNode)
        assert len(get_node_class_for_model(Book).nodes.all()) == 1

    create_nodes()

    book_class = get_node_class_for_model(Book)
    self.assertEqual(len(book_class.nodes.all()), 0)
def post_migrate_handler(sender, **kwargs):
    """
    Create the meta graph after migrations have been installed.

    For every model in ``sender.models`` (Django sends the app config as the
    ``post_migrate`` sender), install the labels of the model's node class and
    of its meta node class, then sync the model's meta node. Does nothing when
    ``settings.ENABLED`` is false.
    """
    if not settings.ENABLED:
        return
    for model in sender.models.values():
        # Node class first, then meta node class -- same order as before.
        for class_factory in (get_node_class_for_model, get_meta_node_class_for_model):
            install_labels(class_factory(model), quiet=False, stdout=sys.stdout)
        meta_node = get_meta_node_for_model(model)
        meta_node.sync(max_depth=settings.MAX_CONNECTION_DEPTH, update_existing=True)
def test_delete_object_is_deleted(self):
    """Deleting a ``Book`` with the pre_delete handler connected removes its node."""
    default_uid = 'chemtrails.signals.handlers.pre_delete_handler'
    test_uid = 'pre_delete_handler.test'

    # Swap the project-wide registration for a test-scoped one.
    pre_delete.disconnect(pre_delete_handler, dispatch_uid=default_uid)
    pre_delete.connect(pre_delete_handler, dispatch_uid=test_uid)
    try:
        book = BookFixture(Book).create_one()
        book_class = get_node_class_for_model(Book)
        deleted_pk = book.pk
        try:
            book.delete()
            book_class.nodes.get(pk=deleted_pk)
        except book_class.DoesNotExist as error:
            self.assertEqual(str(error), "{'pk': %d}" % deleted_pk)
        else:
            self.fail('Did not raise when trying to get non-existent book node.')
    finally:
        # Put the project-wide registration back.
        pre_delete.connect(pre_delete_handler, dispatch_uid=default_uid)
        pre_delete.disconnect(pre_delete_handler, dispatch_uid=test_uid)
def get_related_node_property_for_field(cls, field, meta_node=False):
    """
    Get the relationship definition for the related node based on field.

    The returned relationship property points at the node class (or meta node
    class) of ``field.related_model`` and uses a dynamically built
    ``StructuredRel`` subclass that records metadata about ``field``.

    :param field: Field to inspect. May be a forward field or a reverse
                  relation object (``ManyToManyRel``, ``ManyToOneRel``,
                  ``OneToOneRel``).
    :param meta_node: If True, return the meta node for the related model,
                      else return the model node.
    :returns: A ``RelationshipDefinition`` instance, built by the property
              class that ``cls.get_property_class_for_field`` picks for the
              field's class.
    """
    # Function-level import; presumably avoids a circular import with
    # chemtrails.neoutils -- confirm before hoisting it.
    from chemtrails.neoutils import get_node_class_for_model, get_meta_node_class_for_model

    # True when ``field`` is the reverse side of a relation (a ``*Rel`` object)
    # rather than a concrete field declared on the model.
    reverse_field = True if isinstance(field, (models.ManyToManyRel, models.ManyToOneRel, models.OneToOneRel)) else False

    # Relationship model whose property defaults describe the Django field
    # this relationship was generated from.
    class DynamicRelation(StructuredRel):
        # Class name of the Django field, e.g. "ForeignKey" or "ManyToOneRel".
        type = StringProperty(default=field.__class__.__name__)
        is_meta = BooleanProperty(default=meta_node)
        # Lower-cased. For reverse fields: "<get_model_string(field.model)>.<name>",
        # where <name> is ``field.related_name`` or "<field.name>_set" (for
        # OneToOneRel it is always ``field.name``). For forward fields:
        # ``str(field.remote_field.field)``.
        remote_field = StringProperty(default=str('{model}.{field}'.format(
            model=get_model_string(field.model),
            field=(
                field.related_name or '%s_set' % field.name
                if not isinstance(field, models.OneToOneRel) else field.name
            )) if reverse_field else field.remote_field.field).lower())
        target_field = StringProperty(
            default=str(field.target_field).lower())

    prop = cls.get_property_class_for_field(field.__class__)
    relationship_type = cls.get_relationship_type(field)

    if meta_node:
        # Reuse an already-built meta node class from ``__meta_cache__`` when
        # available; otherwise let get_meta_node_class_for_model provide it.
        klass = (__meta_cache__[field.related_model] if field.related_model in __meta_cache__
                 else get_meta_node_class_for_model(field.related_model))
        return prop(cls_name=klass, rel_type=relationship_type, model=DynamicRelation)
    else:
        # Only reverse fields may reuse ``__node_cache__``; forward fields always
        # go through get_node_class_for_model.
        # NOTE(review): the meta branch above has no ``reverse_field`` condition;
        # confirm this asymmetry is intentional.
        klass = (__node_cache__[field.related_model] if reverse_field and field.related_model in __node_cache__
                 else get_node_class_for_model(field.related_model))
        return prop(cls_name=klass, rel_type=relationship_type, model=DynamicRelation)
def is_authorized(self, perm, obj):
    """
    Checks if user/group is authorized to access given object.

    :param perm: Permission string passed on to the object lookup.
    :param obj: Model instance to check access to.
    :returns: ``True`` if ``obj`` is among the objects returned for
        ``self.user`` (or, failing that, ``self.group``) and ``perm``.
        ``False`` when ``obj`` has no node in the graph, when the lookup does
        not return it, or when neither ``self.user`` nor ``self.group`` is set.
    """
    target_node = get_node_class_for_model(obj).nodes.get_or_none(pk=obj.pk)
    if not target_node:
        return False

    # Narrow the candidate queryset to ``obj`` itself; the lookup then only has
    # to decide about this single object.
    candidates = obj._meta.default_manager.filter(pk=obj.pk)

    if self.user:
        queryset = get_objects_for_user(self.user, perm, klass=candidates)
        return obj in queryset
    if self.group:
        # TODO: Implement `get_objects_for_group`!
        queryset = get_objects_for_group(self.group, perm, klass=candidates)
        return obj in queryset

    # Nobody to check permissions for: deny explicitly rather than falling
    # off the end and returning None.
    return False
def handle(self, *args, **options):
    """
    Export all non-ignored model instances into the graph.

    Every instance is first serialized through its node class' ``to_csv``
    into a temporary ``;``-separated file. That file is then read twice:
    rows tagged ``'n'`` carry node Cypher, executed in batches of
    ``BUFFER_COUNT`` statements per query; afterwards rows tagged ``'r'``
    carry relationship Cypher, executed one statement at a time.
    """
    BUFFER_COUNT = 100

    def run_node_batch(statements):
        """Run a batch of node statements as one query; return how many were skipped."""
        query = ''.join(statements)
        if not query:
            return 0
        try:
            db.cypher_query(query)
        except exception.UniqueProperty:
            # The batch is a single query, so one clashing node rejects the
            # whole batch. Retry statement by statement so that only the
            # clashing nodes are skipped and counted.
            # NOTE(review): assumes every 'n' row is also valid as a query on
            # its own -- TODO confirm against ``to_csv``.
            skipped = 0
            for statement in statements:
                if not statement:
                    continue
                try:
                    db.cypher_query(statement)
                except exception.UniqueProperty:
                    skipped += 1
            return skipped
        return 0

    with tempfile.NamedTemporaryFile(delete=True) as target_file:
        # Step 1: serialize every instance of every non-ignored model.
        for app in apps.all_models:
            self.stdout.write(self.style.SUCCESS('Looking at {}'.format(app)))
            cntr = 0
            for model in apps.get_app_config(app_label=app).get_models():
                cls = get_node_class_for_model(model)
                if cls._is_ignored:
                    continue
                for item in model.objects.all():
                    node = cls(instance=item, bind=False)
                    node.to_csv(cntr=cntr, target_file=target_file)
                    cntr += 1
        # The rows are read back through separate file handles below; push
        # anything still buffered in ``target_file`` out to the file first.
        target_file.flush()

        # Step 2: create the nodes, BUFFER_COUNT statements per query.
        with open(target_file.name, newline='') as csvfile:
            spamreader = csv.reader(csvfile, delimiter=';', quotechar='|')
            self.stdout.write(self.style.SUCCESS('Crucifixion party! Wait for it...'))
            node_count = 0
            skip_count = 0
            batch = []
            for row in spamreader:
                if not row or row[0] != 'n':
                    continue
                batch.append(row[1])
                node_count += 1
                if node_count % BUFFER_COUNT == 0:
                    skip_count += run_node_batch(batch)
                    batch = []
                    self.stdout.write('{} nodes processed'.format(node_count), ending='\r')
                    self.stdout.flush()
            if batch:
                skip_count += run_node_batch(batch)
            self.stdout.write(self.style.SUCCESS(
                '{} nodes processed and delivered to Neo... '
                'Why oh why, didn\'t I take the blue pill??'.format(node_count)))
            if skip_count > 0:
                self.stdout.write(self.style.SUCCESS(
                    '{} nodes already existed, and was not updated.... Were you listening to me Neo? Or were you '
                    'looking at the woman in the red dress?'.format(skip_count)))

        # Step 3: create the relationships, one statement per query.
        with open(target_file.name, newline='') as csvfile:
            spamreader = csv.reader(csvfile, delimiter=';', quotechar='|')
            relation_count = 0
            for row in spamreader:
                if not row or row[0] != 'r':
                    continue
                db.cypher_query(row[1])
                relation_count += 1
                if relation_count % BUFFER_COUNT == 0:
                    self.stdout.write('{} relations processed'.format(relation_count), ending='\r')
                    self.stdout.flush()
            self.stdout.write(self.style.SUCCESS(
                '{} relations processed and delivered to Neo... '
                'Why oh why, didn\'t I take the blue pill??'.format(relation_count)))
        self.stdout.write(self.style.SUCCESS('Crucifixion party! By the left! Forward!'))
def get_users_with_perms(obj, permissions, with_superusers=False, with_group_users=True):
    """
    Returns a queryset of all ``User`` objects for which a path to the given
    ``obj`` can be calculated.

    :param obj: model instance.
    :param permissions: Single permission string, or sequence of permissions strings
                        that user requires to have.
    :param with_superusers: Default: ``False``. If set to ``True`` result would include
                            all superusers.
    :param with_group_users: Default: ``True``. If set to ``False`` result would **not**
                             include users who only have group permissions for given ``obj``.
    :raises MixedContentTypeError: If computed content type for ``permissions`` and/or
                                   ``obj`` clashes.
    :returns: Queryset containing ``User`` objects which have ``permissions`` for ``obj``.
    """
    ctype, codenames = check_permissions_app_label(permissions)
    if ctype is None:
        # No content type could be derived from ``permissions``; use ``obj``'s
        # and check that the codenames exist for it.
        ctype = get_content_type(obj)
        if codenames:
            # Make sure permissions are valid.
            _codenames = set(
                ctype.permission_set.filter(
                    codename__in=codenames).values_list('codename', flat=True))
            if not codenames == _codenames:
                message = ngettext_lazy(
                    'Calculated content type from permission "%s" does not match %r.' % (
                        next(iter(codenames)), ctype),
                    'One or more permissions "%s" from calculated content type does not match %r.' % (
                        ', '.join(sorted(codenames)), ctype),
                    len(codenames))
                raise MixedContentTypeError(message)
    elif not ctype == get_content_type(obj):
        raise MixedContentTypeError(
            'Calculated content type %r does not match %r.' % (ctype, get_content_type(obj)))

    queryset = _get_queryset(User)

    # If there is no node in the graph for ``obj``, return empty queryset.
    target_node = get_node_class_for_model(obj).nodes.get_or_none(pk=obj.pk)
    if not codenames or not target_node:
        if with_superusers is True:
            return queryset.filter(is_superuser=True)
        return queryset.none()

    ctype_source = get_content_type(User)

    # We need a fake source content type model to use as origin.
    fake_model = ctype_source.model_class()()
    source_node = get_node_for_object(fake_model, bind=False)

    queries = []
    for access_rule in get_access_rules(ctype_source, ctype, codenames):
        manager = source_node.paths
        if access_rule.direction is not None:
            manager.direction = access_rule.direction
        last_index = len(access_rule.relation_types_obj) - 1
        for n, rule_definition in enumerate(access_rule.relation_types_obj):
            # Only the first (relation type, target props) pair of each step is used.
            relation_type, target_props = zip(*rule_definition.items())
            relation_type, target_props = relation_type[0], target_props[0]
            source_props = {}
            # Work on a copy so the access rule's own definition is never mutated
            # (the ``pk`` below must not leak into later lookups).
            target_props = dict(target_props or {})

            if n == 0 and access_rule.requires_staff:
                source_props['is_staff'] = True

            # Make sure the last object in the query is matched to ``obj``.
            if n == last_index:
                target_props['pk'] = target_node.pk

            # FIXME: Workaround for https://github.com/inonit/django-chemtrails/issues/46
            # If using "{source}.<attr>" filters, ignore them! Every other
            # filter -- including the ``pk`` set above and non-string
            # values -- is kept.
            target_props = {
                key: value for key, value in target_props.items()
                if not (isinstance(value, str) and value.startswith('{source}.'))
            }
            manager = manager.add(relation_type, source_props=source_props, target_props=target_props)

        if manager.statement:
            queries.append(manager.get_path())

    q_values = Q()
    if with_superusers is True:
        q_values |= Q(is_superuser=True)

    start_node_class = get_node_class_for_model(queryset.model)
    end_node_class = get_node_class_for_model(obj)

    for query in queries:
        # FIXME: https://github.com/inonit/libcypher-parser-python/issues/1
        # validate_cypher(query, raise_exception=True)
        result, _ = db.cypher_query(query)
        if not result:
            continue
        values = set()
        for item in flatten(result):
            if not isinstance(item, Path):  # pragma: no cover
                continue
            elif (start_node_class.__label__ not in item.start.labels or
                  end_node_class.__label__ not in item.end.labels):
                continue
            try:
                start, end = (start_node_class.inflate(item.start),
                              end_node_class.inflate(item.end))
                if isinstance(start, start_node_class) and end == target_node:
                    # Make sure the user object has correct permissions
                    instance = start.get_object()
                    global_perms = set(
                        get_perms(instance, obj) if with_group_users
                        else get_user_perms(instance, obj))
                    if all(code in global_perms for code in codenames):
                        values.add(item.start.properties['pk'])
            except (KeyError, InflateError, ObjectDoesNotExist):
                continue
        q_values |= Q(pk__in=values)

    if not q_values:
        return queryset.none()
    return queryset.filter(q_values)
def get_objects_for_user(user, permissions, klass=None, use_groups=True, extra_perms=None,
                         any_perm=False, with_superuser=True):
    """
    Returns a queryset of objects for which a path can be calculated from the
    ``user`` using one or more access rules with *all* permissions present at
    ``permissions``.

    :param user: ``User`` instance for which objects should be returned.
    :param permissions: Single permission string, or sequence of permission strings
                        that should be checked. If ``klass`` parameter is not given,
                        those should be full permission strings rather than only
                        codenames (ie. ``auth.change_user``). If more than one
                        permission is present in the sequence, their content type
                        **must** be the same or ``MixedContentTypeError`` would be raised.
    :param klass: May be a ``Model``, ``Manager`` or ``QuerySet`` object. If not given,
                  this will be calculated based on passed ``permissions`` strings.
    :param use_groups: If ``True``, include users groups permissions. Defaults to ``True``.
    :param extra_perms: Single permission string, or sequence of permission strings that
                        should be used as ``global_perms`` base. These permissions will
                        be treated as if the user possesses them.
    :param any_perm: If ``True``, any permission in sequence is accepted. Defaults to
                     ``False``. The user must still hold at least one of them.
    :param with_superuser: If ``True`` and ``user.is_superuser`` is set, returns the
                           entire queryset. Otherwise will only return the objects the
                           user has explicit permissions to. Defaults to ``True``.
    :raises MixedContentTypeError: If computed content type for ``permissions`` and/or
                                   ``klass`` clashes.
    :raises ValueError: If unable to compute content type for ``permissions``.
    :returns: QuerySet containing objects ``user`` has ``permissions`` to.
    """
    # Make sure all permissions checks out!
    ctype, codenames = check_permissions_app_label(permissions)
    if extra_perms:
        extra_ctype, extra_perms = check_permissions_app_label(extra_perms)
        if extra_ctype != ctype:
            raise MixedContentTypeError(
                'Calculated content type from keyword argument `extra_perms` '
                '%s does not match %r.' % (extra_ctype, ctype))
    extra_perms = extra_perms or set()

    if ctype is None and klass is not None:
        queryset = _get_queryset(klass)
        ctype = get_content_type(queryset.model)
    elif ctype is not None and klass is None:
        queryset = _get_queryset(ctype.model_class())
    elif klass is None:
        raise ValueError('Could not determine the content type.')
    else:
        queryset = _get_queryset(klass)
        if ctype.model_class() != queryset.model:
            raise MixedContentTypeError(
                'ContentType for given permissions and klass differs.')

    # Superusers have access to all objects.
    if with_superuser and user.is_superuser:
        return queryset

    # We don't support anonymous users.
    if user.is_anonymous:
        return queryset.none()

    # If there is no node in the graph for the user object, return empty queryset.
    source_node = get_node_class_for_model(user).nodes.get_or_none(pk=user.pk)
    if not source_node:
        return queryset.none()

    # Next, get all permissions the user has, either directly set through user permissions
    # or if ``use_groups`` are set, derived from a group membership.
    global_perms = extra_perms | set(
        get_perms(user, queryset.model) if use_groups
        else get_user_perms(user, queryset.model))

    # Check if we requires the user to have *all* permissions or if it is
    # sufficient with any provided.
    if not any_perm and not all(code in global_perms for code in codenames):
        return queryset.none()
    elif any_perm:
        # Only evaluate access rules for the permissions the user actually holds.
        codenames = {code for code in codenames if code in global_perms}
        if not codenames:
            # ``any_perm`` still requires at least one permission; do not go on
            # to look up access rules for an empty permission set.
            return queryset.none()

    # Calculate a PATH query for each rule
    # NOTE(review): get_users_with_perms applies ``access_rule.direction`` to the
    # path manager; this function does not -- confirm whether it should.
    queries = []
    for access_rule in get_access_rules(get_content_type(user), ctype, codenames):
        manager = source_node.paths
        for n, rule_definition in enumerate(access_rule.relation_types_obj):
            # Only the first (relation type, target props) pair of each step is used.
            relation_type, target_props = zip(*rule_definition.items())
            relation_type, target_props = relation_type[0], target_props[0]
            source_props = {}
            if n == 0 and access_rule.requires_staff:
                source_props['is_staff'] = True
            manager = manager.add(relation_type, source_props=source_props, target_props=target_props)
        if manager.statement:
            queries.append(manager.get_path())

    q_values = Q()
    start_node_class = get_node_class_for_model(user)
    end_node_class = get_node_class_for_model(queryset.model)

    for query in queries:
        # FIXME: https://github.com/inonit/libcypher-parser-python/issues/1
        # validate_cypher(query, raise_exception=True)
        result, _ = db.cypher_query(query)
        if not result:
            continue
        values = set()
        for item in flatten(result):
            if not isinstance(item, Path):  # pragma: no cover
                continue
            elif (start_node_class.__label__ not in item.start.labels or
                  end_node_class.__label__ not in item.end.labels):
                continue
            try:
                # ``inflate`` is called on the class itself, as in
                # get_users_with_perms; no throwaway node instance per path.
                start, end = (start_node_class.inflate(item.start),
                              end_node_class.inflate(item.end))
                if start == source_node and isinstance(end, end_node_class):
                    values.add(item.end.properties['pk'])
            except (KeyError, InflateError):  # pragma: no cover
                continue
        q_values |= Q(pk__in=values)

    # If no values in the Q filter, it means we couldn't get a path from the
    # user node to given object in queryset by any evaluated rule.
    # Return an empty queryset.
    if not q_values:
        return queryset.none()
    return queryset.filter(q_values)
def test_get_node_for_object(self):
    """``get_node_for_object`` returns an instance of the model's node class."""
    saved_store = StoreFixture(Store).create_one(commit=True)
    self.assertIsInstance(get_node_for_object(saved_store), get_node_class_for_model(Store))
def test_recursive_connect(self):
    """
    Check which relationships ``recursive_connect`` creates from a ``Book``
    node for connection depths 0, 1 and 2, with the sync signals disconnected.
    """
    # (expected count, model, relationship name) per depth, checked in order.
    expectations = {
        1: (
            (0, Book, 'store_set'),
            (0, Store, 'books'),
            (0, Book, 'bestseller_stores'),
            (0, Store, 'bestseller'),
            (1, Book, 'publisher'),
            (1, Publisher, 'book_set'),
            (1, Book, 'authors'),
            (1, Author, 'book_set'),
            (0, Author, 'user'),
            (0, User, 'author'),
            (1, Book, 'tags'),
            (0, Tag, 'content_type'),
        ),
        2: (
            (1, Author, 'user'),
            (1, User, 'author'),
            (1, Tag, 'content_type'),
            (1, ContentType, 'content_type_set_for_tag'),
        ),
    }

    post_save.disconnect(
        post_save_handler, dispatch_uid='chemtrails.signals.handlers.post_save_handler')
    m2m_changed.disconnect(
        m2m_changed_handler, dispatch_uid='chemtrails.signals.handlers.m2m_changed_handler')
    try:
        book = BookFixture(Book, generate_m2m={'authors': (1, 1)}).create_one()
        for depth in range(3):
            # Delete all relationships
            db.cypher_query('MATCH (n)-[r]-() WHERE n.type = "ModelNode" DELETE r')
            book_node = get_node_for_object(book).save()
            book_node.recursive_connect(depth)

            if depth == 0:
                # Max depth 0 means that no recursion should occur, and no connections
                # can be made, because the connected objects might not exist.
                relation_names = book_node.defined_properties(aliases=False, properties=False)
                for relation_name in relation_names:
                    relation = getattr(book_node, relation_name)
                    try:
                        self.assertEqual(len(relation.all()), 0)
                    except CardinalityViolation:
                        # Will raise CardinalityViolation for nodes which has a single
                        # required relationship
                        pass
            else:
                for expected, model, relation_name in expectations[depth]:
                    matching = get_node_class_for_model(model).nodes.has(**{relation_name: True})
                    self.assertEqual(expected, len(matching))
    finally:
        post_save.connect(
            post_save_handler, dispatch_uid='chemtrails.signals.handlers.post_save_handler')
        m2m_changed.connect(
            m2m_changed_handler, dispatch_uid='chemtrails.signals.handlers.m2m_changed_handler')
def test_ignore_models_case_insensitive(self):
    """The node class generated for ``Group`` is flagged as ignored."""
    group_class = get_node_class_for_model(Group)
    self.assertTrue(group_class._is_ignored)
def test_deflate_empty_node_raise_required_property_error(self):
    """Deflating the properties of an empty ``UserNode`` fails on the required ``pk``."""
    user_class = get_node_class_for_model(User)
    empty_properties = user_class().__properties__
    with self.assertRaisesMessage(RequiredProperty, "property 'pk' on objects of class UserNode"):
        user_class.deflate(empty_properties)
def test_ignored_models_app_label_model_name(self):
    """``Group`` is ignored while ``Permission`` is not."""
    self.assertTrue(get_node_class_for_model(Group)._is_ignored)
    self.assertFalse(get_node_class_for_model(Permission)._is_ignored)
def test_ignored_models_app_label_wildcard(self):
    """``Group`` is ignored while ``Book`` is not."""
    self.assertTrue(get_node_class_for_model(Group)._is_ignored)
    self.assertFalse(get_node_class_for_model(Book)._is_ignored)
def test_get_node_class_for_model(self):
    """
    Passing a model instance with ``for_concrete_model=True`` yields a
    ``StructuredNode`` subclass.
    """
    saved_book = BookFixture(Book).create_one()
    node_class = get_node_class_for_model(saved_book, for_concrete_model=True)
    self.assertTrue(issubclass(node_class, StructuredNode))
def test_get_node_class_for_model(self):
    """Passing a model class yields a ``StructuredNode`` subclass."""
    self.assertTrue(issubclass(get_node_class_for_model(Book), StructuredNode))
def test_get_property_class_for_field(self):
    """
    ``get_property_class_for_field`` maps each Django field class to the
    expected neomodel property/relationship class, and rejects unknown ones.
    """
    from django.contrib.contenttypes.fields import GenericRelation
    from django.contrib.postgres.fields import (
        ArrayField, HStoreField, JSONField, IntegerRangeField, BigIntegerRangeField,
        FloatRangeField, DateTimeRangeField, DateRangeField)
    from django.db import models

    class CustomField(object):
        pass

    klass = get_node_class_for_model(Book)

    # (field class, expected property class), checked in this order.
    expected_mapping = (
        # Relations
        (models.ForeignKey, RelationshipTo),
        (models.OneToOneField, RelationshipTo),
        (models.ManyToManyField, RelationshipTo),
        (models.ManyToOneRel, RelationshipTo),
        (models.OneToOneRel, RelationshipTo),
        (models.ManyToManyRel, RelationshipTo),
        (GenericRelation, RelationshipTo),
        # Regular fields
        (models.AutoField, IntegerProperty),
        (models.BigAutoField, IntegerProperty),
        (models.BooleanField, BooleanProperty),
        (models.CharField, StringProperty),
        (models.CommaSeparatedIntegerField, ArrayProperty),
        (models.DateField, DateProperty),
        (models.DateTimeField, DateTimeProperty),
        (models.DecimalField, FloatProperty),
        (models.DurationField, StringProperty),
        (models.EmailField, StringProperty),
        (models.FilePathField, StringProperty),
        (models.FileField, StringProperty),
        (models.FloatField, FloatProperty),
        (models.GenericIPAddressField, StringProperty),
        (models.IntegerField, IntegerProperty),
        (models.IPAddressField, StringProperty),
        (models.NullBooleanField, BooleanProperty),
        (models.PositiveIntegerField, IntegerProperty),
        (models.PositiveSmallIntegerField, IntegerProperty),
        (models.SlugField, StringProperty),
        (models.SmallIntegerField, IntegerProperty),
        (models.TextField, StringProperty),
        (models.TimeField, IntegerProperty),
        (models.URLField, StringProperty),
        (models.UUIDField, StringProperty),
        # Special fields
        # (ArrayField, ArrayProperty),  -- disabled, as it was before
        (HStoreField, JSONProperty),
        (JSONField, JSONProperty),
        # Undefined fields, resolved by inspecting their base classes.
        (models.BigIntegerField, IntegerProperty),
        (IntegerRangeField, IntegerProperty),
        (BigIntegerRangeField, IntegerProperty),
        (FloatRangeField, FloatProperty),
        (DateTimeRangeField, DateTimeProperty),
        (DateRangeField, DateProperty),
    )
    for field_class, property_class in expected_mapping:
        self.assertEqual(klass.get_property_class_for_field(field_class), property_class)

    # Test unsupported field
    self.assertRaisesMessage(
        NotImplementedError,
        'Unsupported field. Field CustomField is currently not supported.',
        klass.get_property_class_for_field, CustomField)