def test_allow_relation(self):
        """allow_relation accepts two objects with the same partner_id.

        Afterwards the router is expected to be in write mode.
        """
        sharded_router = ShardedPerTenantRouter()
        first = TShardedModel(partner_id=PID)
        # Give the second object the same tenant key as the first one.
        second = TShardedModel(partner_id=first.partner_id)

        allowed = sharded_router.allow_relation(first, second)

        self.assertTrue(allowed)
        self.assertTrue(sharded_router._write_mode)
    def test_get_shard_for_instance_read_mod(self):
        """In read mode, an instance resolves to ``test2__<partner_id>``."""
        model_obj = TShardedModel(partner_id=PID)
        sharded_router = ShardedPerTenantRouter()
        # Force the router into read mode before resolving the shard.
        sharded_router._write_mode = False

        resolved_alias = sharded_router._get_shard_for_instance(model_obj)

        expected_alias = 'test2__{}'.format(model_obj.partner_id)
        self.assertEqual(resolved_alias, expected_alias)
    def test_db_for_read(self):
        """db_for_read routes an instance to ``test2__<partner_id>``.

        After the call the router must not be in write mode.
        """
        sharded_router = ShardedPerTenantRouter()
        model_obj = TShardedModel(partner_id=PID)

        resolved_alias = sharded_router.db_for_read(
            TShardedModel, instance=model_obj)

        expected_alias = 'test2__{}'.format(model_obj.partner_id)
        self.assertEqual(resolved_alias, expected_alias)
        self.assertFalse(sharded_router._write_mode)
    def test_get_shard_with_instance_read_mode(self):
        """_get_shard uses the ``instance`` hint to select the shard in read mode."""
        sharded_router = ShardedPerTenantRouter()
        # Force the router into read mode before resolving the shard.
        sharded_router._write_mode = False
        model_obj = TShardedModel(partner_id=PID)

        # Pass the hint directly as a keyword; this is the same as unpacking
        # {'instance': model_obj}.
        resolved_alias = sharded_router._get_shard(
            TShardedModel, instance=model_obj)

        expected_alias = 'test2__{}'.format(model_obj.partner_id)
        self.assertEqual(resolved_alias, expected_alias)
    def test__extract_shared_value_with_instance_should_logging_with_empty_sharded_value(
            self):
        """An instance without a sharded value yields None and logs once.

        The router is expected to build a message with ``_stack`` describing
        the missing sharded field. It should then pass the result of
        ``_stack`` to ``logger.info`` exactly once.
        """
        obj = TShardedModel()
        hints = {'instance': obj}

        # Use patch.object context managers rather than plain attribute
        # assignment. The patches are undone when the block exits, so the
        # mocked ``logger.info`` cannot leak into other tests that use the
        # same logger object.
        with mock.patch.object(
                self.router, '_stack',
                return_value='test_log_string') as stack_mock:
            with mock.patch.object(self.router.logger, 'info') as info_mock:
                val = self.router._extract_shared_value(
                    TShardedModel, **hints)

        self.assertIsNone(val)

        info_mock.assert_called_once_with('test_log_string')
        stack_mock.assert_called_once_with(
            'Instance of {} should have {} ({})\n'.format(
                TShardedModel.__name__, TShardedModel.sharded_field,
                model_to_dict(obj)))
    def test__extract_shared_value_with_instance(self):
        """The sharded value is read from the ``instance`` hint."""
        model_obj = TShardedModel(partner_id=self.PID)

        extracted = self.router._extract_shared_value(
            TShardedModel, instance=model_obj)

        self.assertEqual(extracted, self.PID)