def test_save_tenant_set_different_than_object(self):
    # Save an object that belongs to one account while a different account
    # is the current tenant. The current tenant must be left as it was, and
    # the new name must be readable once the tenant has been cleared.
    unset_current_tenant()
    from .models import Project

    owner = self.account_fr
    other_account = self.account_us

    fr_project = Project(account=owner, name='test save fr')
    fr_project.save()
    us_project = Project(account=other_account, name='test save us')
    us_project.save()
    self.assertEqual(Project.objects.count(), 2)

    # The tenant is deliberately not the account that owns fr_project.
    set_current_tenant(other_account)
    fr_project.name = 'test update name'
    fr_project.save()
    self.assertEqual(get_current_tenant(), other_account)

    unset_current_tenant()
    reloaded = Project.objects.filter(account=owner).first()
    self.assertEqual(reloaded.name, 'test update name')
def test_subquery(self):
    # We want all the projects annotated with the name of their first task,
    # fetched in a single query in which both the outer query and the
    # subquery are filtered on the current tenant.
    import django

    if django.VERSION[0] == 1 and django.VERSION[1] < 11:
        # Subqueries were only introduced in Django 1.11. Report the test as
        # skipped instead of returning, which would count it as a pass.
        self.skipTest("Subquery expressions require Django >= 1.11")

    from django.db.models import OuterRef, Subquery
    from .models import Project, Task

    # These attributes are read for their effect only; presumably the
    # access sets up the fixture rows counted below -- confirm.
    self.projects
    account = self.account_fr
    self.tasks

    self.assertEqual(Project.objects.count(), 30)

    set_current_tenant(account)
    subquery_filter = 'U0."account_id" = %d' % account.id
    outer_filter = 'WHERE "tests_project"."account_id" = %d' % account.id

    with self.assertNumQueries(1) as captured_queries:
        task_qs = Task.objects.filter(project=OuterRef("pk")).order_by('-name')
        projects = Project.objects.all().annotate(
            first_task_name=Subquery(task_qs.values('name')[:1])
        )

        # Iterating is what actually runs the (single) query.
        for p in projects:
            self.assertIsNotNone(p.first_task_name)

        # Check that the tenant filter appears in the subquery and in the
        # outer query.
        for query in captured_queries.captured_queries:
            self.assertIn(subquery_filter, query['sql'])
            self.assertIn(outer_filter, query['sql'])

    unset_current_tenant()
def test_save_tenant_set(self):
    # Save an object while its own account is the current tenant. The
    # current tenant must be left as it was, and the new name must be
    # readable once the tenant has been cleared.
    unset_current_tenant()
    from .models import Project

    owner = self.account_fr
    other_account = self.account_us

    fr_project = Project(account=owner, name="test save fr")
    fr_project.save()
    us_project = Project(account=other_account, name="test save us")
    us_project.save()
    self.assertEqual(Project.objects.count(), 2)

    set_current_tenant(owner)
    fr_project.name = "test update name"
    fr_project.save()
    self.assertEqual(get_current_tenant(), owner)

    unset_current_tenant()
    reloaded = Project.objects.filter(account=owner).first()
    self.assertEqual(reloaded.name, "test update name")
def test_delete_cascade_reference_to_distributed(self):
    # Delete a Country that two Account rows point at. Both tables must end
    # up empty, and the captured SQL must contain the statements switching
    # citus.multi_shard_modify_mode to 'sequential' and back to 'parallel'.
    from .models import Country, Account

    unset_current_tenant()
    country = self.france
    for name, domain in (("Account FR", "citusdata.com"),
                         ("Account FR 2", "msft.com")):
        Account.objects.create(name=name, country=country,
                               subdomain="fr.", domain=domain)

    self.assertEqual(Account.objects.count(), 2)

    with self.assertNumQueries(16) as captured_queries:
        country.delete()
        # These count() calls run inside the block, so they are part of the
        # 16 expected queries.
        self.assertEqual(Account.objects.count(), 0)
        self.assertEqual(Country.objects.count(), 0)

        executed_sql = [query["sql"]
                        for query in captured_queries.captured_queries]
        sequential_seen = (
            "SET LOCAL citus.multi_shard_modify_mode TO 'sequential';"
            in executed_sql)
        self.assertTrue(sequential_seen)
        parallel_seen = (
            "SET LOCAL citus.multi_shard_modify_mode TO 'parallel';"
            in executed_sql)
        self.assertTrue(parallel_seen)
def test_prefetch_related(self):
    # With a list of two accounts as the current tenant, fetching a project
    # with its managers prefetched must take two queries. The second query
    # must filter managers on both account ids and match the
    # projectmanager/manager account_id join condition.
    from .models import Project, Account

    unset_current_tenant()
    first_link = self.project_managers[0]
    project_id = first_link.project_id
    accounts = [first_link.account, Account.objects.last()]

    set_current_tenant(accounts)

    with self.assertNumQueries(2) as captured_queries:
        Project.objects.filter(pk=project_id).prefetch_related(
            "managers").first()

        prefetch_sql = captured_queries.captured_queries[1]["sql"]
        id_list = ", ".join(str(account.id) for account in accounts)
        tenant_clause = 'WHERE ("tests_manager"."account_id" IN (%s)' % id_list
        self.assertTrue(tenant_clause in prefetch_sql)

        # `.?` tolerates an optional character (e.g. an opening parenthesis)
        # before the right-hand column.
        pattern = 'AND \\("tests_projectmanager"."account_id" = .?"tests_manager"."account_id"\\)'
        join_match = re.search(pattern, prefetch_sql)
        self.assertTrue(bool(join_match))

    unset_current_tenant()
def test_delete_cascade_distributed(self):
    # Delete a single project under a tenant and check the Project / Task /
    # SubTask counts drop from 10/50/250 to 9/45/225. The captured SQL must
    # contain no citus.multi_shard_modify_mode statement.
    from .models import Task, Project, SubTask

    # Read for its effect only; presumably the access sets up the fixture
    # rows counted below -- confirm.
    self.subtasks
    account = self.account_fr

    unset_current_tenant()
    models = (Project, Task, SubTask)

    # Without a tenant every row is counted.
    for model, expected in zip(models, (30, 150, 750)):
        self.assertEqual(model.objects.count(), expected)

    set_current_tenant(account)

    for model, expected in zip(models, (10, 50, 250)):
        self.assertEqual(model.objects.count(), expected)

    project = Project.objects.first()

    with self.assertNumQueries(12) as captured_queries:
        project.delete()
        # These count() calls run inside the block, so they are part of the
        # 12 expected queries.
        for model, expected in zip(models, (9, 45, 225)):
            self.assertEqual(model.objects.count(), expected)

        executed_sql = [query["sql"]
                        for query in captured_queries.captured_queries]
        sequential_seen = (
            "SET LOCAL citus.multi_shard_modify_mode TO 'sequential';"
            in executed_sql)
        self.assertFalse(sequential_seen)
        parallel_seen = (
            "SET LOCAL citus.multi_shard_modify_mode TO 'parallel';"
            in executed_sql)
        self.assertFalse(parallel_seen)
def test_subquery_joins(self):
    # We want all the projects annotated with the name of their first
    # subtask whose task is opened. This must take a single query in which
    # the subquery, its task join and the outer query are all
    # tenant-filtered.
    from django.db.models import OuterRef, Subquery
    from .models import Project, SubTask

    # These attributes are read for their effect only; presumably the
    # access sets up the fixture rows counted below -- confirm.
    self.projects
    account = self.account_fr
    self.subtasks

    self.assertEqual(Project.objects.count(), 30)

    set_current_tenant(account)

    with self.assertNumQueries(1) as captured_queries:
        opened_subtasks = SubTask.objects.filter(
            project=OuterRef("pk"), task__opened=True)
        first_name = opened_subtasks.order_by("-name").values("name")[:1]
        annotated = Project.objects.all().annotate(
            first_subtask_name=Subquery(first_name))

        # Iterating is what actually runs the (single) query.
        for p in annotated:
            self.assertTrue(p.first_subtask_name is not None)

        # Check that the tenant filter is in the subquery, in its join and
        # in the outer query.
        for query in captured_queries.captured_queries:
            sql = query["sql"]
            self.assertTrue('U0."account_id" = %d' % account.id in sql)

            pattern = '\\(U0."task_id" = U\\d."id" AND \\(U0."account_id" = .?U\\d."account_id"\\)'
            join_match = re.search(pattern, sql)
            self.assertTrue(bool(join_match))

            self.assertTrue(
                'WHERE "tests_project"."account_id" = %d' % account.id in sql)

    unset_current_tenant()
def test_delete_cascade_distributed_to_reference(self):
    # Deleting an account must also remove the Employee, ModelConfig and
    # Project rows attached to it, and leave the other accounts' projects.
    from .models import Account, Employee, ModelConfig, Project

    unset_current_tenant()
    account = self.account_fr

    employee = Employee.objects.create(account=account, name='Louise')
    ModelConfig.objects.create(account=account, employee=employee, name='test')

    # Point every project of this account at the new employee.
    own_projects = (p for p in self.projects if p.account == account)
    for project in own_projects:
        project.employee = employee
        project.save(update_fields=['employee'])

    account.employee = employee
    account.save()

    counts_before = ((Account, 3), (Employee, 1), (ModelConfig, 1), (Project, 30))
    for model, expected in counts_before:
        self.assertEqual(model.objects.count(), expected)

    set_current_tenant(account)
    account.delete()

    # Once deleted, we don't have a current tenant
    counts_after = ((Account, 2), (Employee, 0), (ModelConfig, 0), (Project, 20))
    for model, expected in counts_after:
        self.assertEqual(model.objects.count(), expected)

    unset_current_tenant()
def test_set_current_tenant(self):
    # get_current_tenant() must hand back the object given to
    # set_current_tenant().
    from .models import Project

    first_project = self.projects[0]
    account = first_project.account

    set_current_tenant(account)
    current = get_current_tenant()
    self.assertEqual(current, account)

    unset_current_tenant()
def test_filter_without_joins_on_tenant_id_not_pk(self):
    # SomeRelatedModel rows are counted with no tenant (30 expected), then
    # with the first tenant_not_id fixture as the current tenant (10).
    from .models import TenantNotIdModel, SomeRelatedModel

    tenants = self.tenant_not_id

    total = SomeRelatedModel.objects.count()
    self.assertEqual(total, 30)

    set_current_tenant(tenants[0])
    visible = SomeRelatedModel.objects.count()
    self.assertEqual(visible, 10)

    unset_current_tenant()
def delete(self, request, pk):
    """Delete item ``pk`` with the requester's tenant set as current.

    The tenant token is the last space-separated part of the
    ``Authorization`` header; it is resolved to a tenant with
    ``UserTenant.objects.get(token=...)``.

    Args:
        request: Incoming request; its ``headers`` mapping is read.
        pk: Identifier of the ``Item`` to delete.

    Returns:
        ``Response({"result": True})`` once the item has been deleted.

    NOTE(review): a request without an ``Authorization`` header still
    fails on ``.split`` (``None`` has no such method), and lookup errors
    from ``.get()`` propagate unchanged, as before.
    """
    tenant_token = request.headers.get("Authorization").split(' ')[-1]
    tenant = UserTenant.objects.get(token=tenant_token)

    set_current_tenant(tenant)
    try:
        item = Item.objects.get(id=pk)
        item.delete()
    finally:
        # Always clear the tenant, including when the lookup or the delete
        # raises, so a failed request cannot leave it set for whatever
        # runs next.
        unset_current_tenant()

    return Response({"result": True})
def test_aggregate(self):
    # Smoke test: annotating ProjectManager rows with Count("project_id")
    # and evaluating the queryset with no tenant set must not raise.
    from .models import ProjectManager

    # These attributes are read for their effect only; presumably the
    # access sets up fixture rows -- confirm.
    self.projects
    self.project_managers

    unset_current_tenant()

    # list() forces the queryset to be evaluated.
    list(ProjectManager.objects.annotate(Count("project_id")))
def get(self, request):
    """Return the serialized items with the requester's tenant set as current.

    The tenant token is the last space-separated part of the
    ``Authorization`` header; it is resolved to a tenant with
    ``UserTenant.objects.get(token=...)``.

    Args:
        request: Incoming request; its ``headers`` mapping is read.

    Returns:
        ``Response`` wrapping ``ItemSerializer(items, many=True).data``.

    NOTE(review): a request without an ``Authorization`` header still
    fails on ``.split`` (``None`` has no such method), and lookup errors
    from ``.get()`` propagate unchanged, as before.
    """
    tenant_token = request.headers.get("Authorization").split(' ')[-1]
    tenant = UserTenant.objects.get(token=tenant_token)

    set_current_tenant(tenant)
    try:
        items = Item.objects.all()
        serializer = ItemSerializer(items, many=True)
        # Read .data while the tenant is still set. Presumably the queryset
        # is only evaluated at this point, so doing it after the tenant is
        # cleared risks running outside the tenant scope -- confirm.
        data = serializer.data
    finally:
        # Always clear the tenant, including when serialization raises, so
        # a failed request cannot leave it set for whatever runs next.
        unset_current_tenant()

    return Response(data)
def test_filter_without_joins(self):
    # With no tenant set the Project count equals the number of fixture
    # projects. With account_fr as tenant it equals that account's own
    # project count.
    from .models import Project

    projects = self.projects
    account = self.account_fr

    unfiltered = Project.objects.count()
    self.assertEqual(unfiltered, len(projects))

    set_current_tenant(account)
    visible = Project.objects.count()
    owned = account.projects.count()
    self.assertEqual(visible, owned)

    unset_current_tenant()
def test_tenant_filters_single_tenant(self):
    # With one account as the current tenant, the filters computed for
    # Project must be an exact match on account_id.
    from .models import Project, Account

    account = self.projects[0].account
    set_current_tenant(account)

    filters = get_tenant_filters(Project)
    expected = {'account_id': account.pk}
    self.assertEqual(filters, expected)

    unset_current_tenant()
def test_current_tenant_value_single(self):
    # With one account as the current tenant, the tenant value must be
    # that account's id.
    from .models import Project, Account

    account = self.projects[0].account
    set_current_tenant(account)

    tenant_value = get_current_tenant_value()
    self.assertEqual(tenant_value, account.id)

    unset_current_tenant()
def test_filter_without_joins(self):
    # Multi-tenant variant: the current tenant is a queryset holding every
    # account except the first one by id. The Project count is expected to
    # drop from 30 to 20.
    from .models import Project, Account

    unset_current_tenant()
    # Read for its effect only; presumably the access sets up the fixture
    # rows counted below -- confirm.
    self.projects
    all_but_first = Account.objects.all().order_by('id')[1:]

    total = Project.objects.count()
    self.assertEqual(total, 30)

    set_current_tenant(all_but_first)
    visible = Project.objects.count()
    self.assertEqual(visible, 20)

    unset_current_tenant()
def test_tenant_filters_multi_tenant(self):
    # With a list of two accounts as the current tenant, the filters
    # computed for Project must be an `__in` lookup on both ids, in order.
    from .models import Project, Account

    projects = self.projects
    first_account = projects[0].account
    second_account = projects[1].account
    accounts = [first_account, second_account]

    set_current_tenant(accounts)

    filters = get_tenant_filters(Project)
    expected = {'account_id__in': [accounts[0].id, accounts[1].id]}
    self.assertEqual(filters, expected)

    unset_current_tenant()
def test_delete_tenant_set(self):
    # A bulk delete issued under a tenant is expected to take the Project
    # count from 30 to 20 (counted with no tenant set).
    from .models import Project

    # Read for its effect only; presumably the access sets up the fixture
    # rows counted below -- confirm.
    self.projects
    account = self.account_fr

    before = Project.objects.count()
    self.assertEqual(before, 30)

    set_current_tenant(account)
    Project.objects.all().delete()
    unset_current_tenant()

    after = Project.objects.count()
    self.assertEqual(after, 20)
def test_str_model_tenant_set(self):
    # Smoke test: printing the first Task while a tenant is set must not
    # raise.
    from .models import Task

    # These attributes are read for their effect only; presumably the
    # access sets up fixture rows -- confirm.
    self.projects
    account = self.account_fr
    self.tasks

    set_current_tenant(account)
    first_task = Task.objects.first()
    print(first_task)
    unset_current_tenant()
def test_current_tenant_value_queryset(self):
    # With a queryset of all accounts as the current tenant, the filters
    # computed for Project must be an `__in` lookup on the accounts' ids.
    from .models import Project, Account

    # Read for its effect only; presumably the access sets up fixture
    # rows -- confirm.
    self.projects
    accounts = Account.objects.all().order_by('id')
    set_current_tenant(accounts)

    # Return value is not inspected here; the call is kept as is.
    get_current_tenant_value()

    filters = get_tenant_filters(Project)
    account_ids = list(accounts.values_list('id', flat=True))
    self.assertEqual(filters, {'account_id__in': account_ids})

    unset_current_tenant()
def test_current_tenant_value_list(self):
    # With a list of two accounts as the current tenant, the tenant value
    # must be a list holding their ids in the same order.
    from .models import Project, Account

    projects = self.projects
    first_account = projects[0].account
    second_account = projects[1].account
    accounts = [first_account, second_account]

    set_current_tenant(accounts)
    value = get_current_tenant_value()

    is_list = isinstance(value, list)
    self.assertTrue(is_list)
    expected_ids = [accounts[0].id, accounts[1].id]
    self.assertEqual(value, expected_ids)

    unset_current_tenant()
def test_current_tenant_value_queryset(self):
    # With a queryset of all accounts as the current tenant,
    # get_current_tenant_value() must return a plain list of their ids.
    from .models import Project, Account

    # Read for its effect only; presumably the access sets up fixture
    # rows -- confirm.
    self.projects
    accounts = Account.objects.all().order_by("id")
    set_current_tenant(accounts)

    value = get_current_tenant_value()

    self.assertIsInstance(value, list)
    # values_list() returns a QuerySet, not a list, so it is materialised
    # before being compared with the list asserted just above.
    expected_ids = list(accounts.values_list("id", flat=True))
    self.assertEqual(value, expected_ids)

    unset_current_tenant()
def test_exclude_tenant_set(self):
    # exclude() under a tenant: tasks that have a project are expected to
    # number 50 when account_fr is the current tenant.
    from .models import Task

    # These attributes are read for their effect only; presumably the
    # access sets up the fixture rows counted below -- confirm.
    self.projects
    account_fr = self.account_fr
    self.tasks

    unset_current_tenant()
    set_current_tenant(account_fr)

    with_project = Task.objects.exclude(project__isnull=True)
    self.assertEqual(with_project.count(), 50)

    unset_current_tenant()
def test_select_tenant_foreign_key_different_tenant_id(self):
    # Following revenue.project must take one query, and that query must
    # filter tests_project on account_id.
    from .models import Revenue, Account

    self.revenues
    revenue = Revenue.objects.first()
    set_current_tenant(revenue.acc)

    # Selecting revenue.project, project.account is tenant (revenue.acc is tenant)
    # To push down, account_id should be in query (not acc_id)
    with self.assertNumQueries(1) as captured_queries:
        revenue.project
        tenant_clause = 'AND "tests_project"."account_id" = %d' % revenue.acc_id
        executed_sql = captured_queries.captured_queries[0]['sql']
        self.assertTrue(tenant_clause in executed_sql)

    unset_current_tenant()
def test_select_tenant_foreign_key(self):
    # Following task.project must take one query, and that query must
    # filter tests_project on the task's account_id.
    from .models import Task

    self.tasks
    task = Task.objects.first()
    set_current_tenant(task.account)

    # Selecting task.project, account is tenant
    # To push down, account_id should be in query
    with self.assertNumQueries(1) as captured_queries:
        task.project
        tenant_clause = 'AND "tests_project"."account_id" = %d' % task.account_id
        executed_sql = captured_queries.captured_queries[0]['sql']
        self.assertTrue(tenant_clause in executed_sql)

    unset_current_tenant()
def test_bulk_create_tenant_set(self):
    # Ten projects are bulk-created without an explicit account while a
    # tenant is set. Every stored row must carry that tenant's account id.
    from .models import Project

    account = self.account_fr
    set_current_tenant(account)

    new_projects = [Project(name='project %d' % i) for i in range(10)]
    Project.objects.bulk_create(new_projects)

    unset_current_tenant()

    self.assertEqual(Project.objects.count(), 10)
    for stored in Project.objects.all():
        self.assertEqual(stored.account_id, account.id)
def test_create_project_tenant_set(self):
    # Projects created without an explicit account while a tenant is set:
    # once through save(), once through the manager's create().
    from .models import Project

    account = self.account_fr
    set_current_tenant(account)

    # Using save()
    saved = Project()
    saved.name = 'test save()'
    saved.save()
    count_after_save = Project.objects.count()
    self.assertEqual(count_after_save, 1)

    # Using create()
    Project.objects.create(name='test create')
    count_after_create = Project.objects.count()
    self.assertEqual(count_after_create, 2)

    unset_current_tenant()
def test_delete_tenant_set(self):
    # Multi-tenant bulk delete: the current tenant is a queryset holding
    # every account except the first one by id. The delete must take 7
    # queries, be filtered on those account ids, and leave 10 of the 30
    # projects.
    from .models import Project, Account

    unset_current_tenant()
    # Read for its effect only; presumably the access sets up the fixture
    # rows counted below -- confirm.
    self.projects
    accounts = Account.objects.all().order_by('id')[1:]

    self.assertEqual(Project.objects.count(), 30)

    set_current_tenant(accounts)
    with self.assertNumQueries(7) as captured_queries:
        Project.objects.all().delete()

        tenant_clause = '"account_id" IN (%s)' % ', '.join(
            [str(account.id) for account in accounts])
        # Look the clause up in the SQL text. Asserting the bare formatted
        # string is always true because a non-empty string is truthy.
        # NOTE(review): only one matching statement is required, because
        # the 7 captured queries may include ones that touch no tenant
        # column (e.g. evaluating the `accounts` queryset). Tighten to
        # every query if that turns out not to be the case.
        self.assertTrue(
            any(tenant_clause in query['sql']
                for query in captured_queries.captured_queries),
            'no captured query contains %s' % tenant_clause)

    unset_current_tenant()

    self.assertEqual(Project.objects.count(), 10)
def test_delete_tenant_set(self):
    # Single-tenant bulk delete: it must take 6 queries, each containing
    # the tenant's account_id filter, and leave 20 of the 30 projects.
    from .models import Project

    # Read for its effect only; presumably the access sets up the fixture
    # rows counted below -- confirm.
    self.projects
    account = self.account_fr

    before = Project.objects.count()
    self.assertEqual(before, 30)

    set_current_tenant(account)
    with self.assertNumQueries(6) as captured_queries:
        Project.objects.all().delete()
    unset_current_tenant()

    for query in captured_queries.captured_queries:
        tenant_clause = '"account_id" = %d' % account.id
        self.assertTrue(tenant_clause in query['sql'])

    after = Project.objects.count()
    self.assertEqual(after, 20)