def check_restored_database(self, bkp_lp, expect_secrets=True):
        """Validate a restored backup database and return a SamDB for it.

        Checks made against the restored DB described by ``bkp_lp``:

        * the secrets DB holds exactly one primary-domain record whose
          samAccountName is ``<self.new_server>$`` and which has a secret;
        * none of the attributes in ``self.backup_markers`` remain on the
          @SAMBA_DSDB record;
        * repsFrom/repsTo are absent from the domain and configuration NC
          heads;
        * if ``self.backend`` is set, @PARTITION's backendStore matches it;
        * the partition/DC/FSMO/secrets/uptodateness-vector assert_* helpers
          pass.

        :param bkp_lp: loadparm context of the restored database.
        :param expect_secrets: forwarded to ``assert_secrets``.
        :return: the SamDB opened on the restored database.
        """
        paths = provision.provision_paths_from_lp(bkp_lp, bkp_lp.get("realm"))

        bkp_pd = get_prim_dom(paths.secrets, bkp_lp)
        self.assertEqual(len(bkp_pd), 1)
        acn = bkp_pd[0].get('samAccountName')
        self.assertIsNotNone(acn)
        self.assertEqual(str(acn[0]), self.new_server + '$')
        self.assertIsNotNone(bkp_pd[0].get('secret'))

        samdb = SamDB(url=paths.samdb,
                      session_info=system_session(),
                      lp=bkp_lp,
                      credentials=self.get_credentials())

        # check that the backup markers have been removed from the restored DB
        res = samdb.search(base=ldb.Dn(samdb, "@SAMBA_DSDB"),
                           scope=ldb.SCOPE_BASE,
                           attrs=self.backup_markers)
        self.assertEqual(len(res), 1)
        for marker in self.backup_markers:
            self.assertIsNone(res[0].get(marker),
                              "%s backup-marker left behind" % marker)

        # check that the repsFrom and repsTo values have been removed
        # from both the domain and the configuration NC heads
        for nc_dn in (samdb.get_default_basedn(), samdb.get_config_basedn()):
            res = samdb.search(base=nc_dn,
                               scope=ldb.SCOPE_BASE,
                               attrs=['repsFrom', 'repsTo'])
            self.assertEqual(len(res), 1)
            self.assertIsNone(res[0].get('repsFrom'))
            self.assertIsNone(res[0].get('repsTo'))

        # check the DB is using the backend we supplied
        if self.backend:
            res = samdb.search(base="@PARTITION",
                               scope=ldb.SCOPE_BASE,
                               attrs=["backendStore"])
            # Report a missing @PARTITION record as an assertion failure
            # rather than an IndexError on res[0].
            self.assertEqual(len(res), 1)
            backend = str(res[0].get("backendStore"))
            self.assertEqual(backend, self.backend)

        # check the restored DB has the expected partitions/DC/FSMO roles
        self.assert_partitions_present(samdb)
        self.assert_dcs_present(samdb, self.new_server, expected_count=1)
        self.assert_fsmo_roles(samdb, self.new_server, self.server)
        self.assert_secrets(samdb, expect_secrets=expect_secrets)

        # check we still have an uptodateness vector for the original DC
        self.assert_repl_uptodate_vector(samdb)
        return samdb
    def check_restored_database(self, bkp_lp, expect_secrets=True):
        """Validate a restored backup database and return a SamDB for it.

        Checks made against the restored DB described by ``bkp_lp``:

        * the secrets DB holds exactly one primary-domain record whose
          samAccountName is ``<self.new_server>$`` and which has a secret;
        * none of the attributes in ``self.backup_markers`` remain on the
          @SAMBA_DSDB record;
        * repsFrom/repsTo are absent from the domain and configuration NC
          heads;
        * the partition/DC/FSMO/secrets assert_* helpers pass.

        :param bkp_lp: loadparm context of the restored database.
        :param expect_secrets: forwarded to ``assert_secrets``.
        :return: the SamDB opened on the restored database.
        """
        paths = provision.provision_paths_from_lp(bkp_lp, bkp_lp.get("realm"))

        bkp_pd = get_prim_dom(paths.secrets, bkp_lp)
        self.assertEqual(len(bkp_pd), 1)
        acn = bkp_pd[0].get('samAccountName')
        self.assertIsNotNone(acn)
        # Convert the ldb value to text before comparing: on Python 3 it is a
        # bytes value, so str-based .replace() would raise TypeError.  Also
        # require exactly one trailing '$' rather than stripping every '$'.
        self.assertEqual(str(acn[0]), self.new_server + '$')
        self.assertIsNotNone(bkp_pd[0].get('secret'))

        samdb = SamDB(url=paths.samdb,
                      session_info=system_session(),
                      lp=bkp_lp,
                      credentials=self.get_credentials())

        # check that the backup markers have been removed from the restored DB
        res = samdb.search(base=ldb.Dn(samdb, "@SAMBA_DSDB"),
                           scope=ldb.SCOPE_BASE,
                           attrs=self.backup_markers)
        self.assertEqual(len(res), 1)
        for marker in self.backup_markers:
            self.assertIsNone(res[0].get(marker),
                              "%s backup-marker left behind" % marker)

        # check that the repsFrom and repsTo values have been removed
        # from both the domain and the configuration NC heads
        for nc_dn in (samdb.get_default_basedn(), samdb.get_config_basedn()):
            res = samdb.search(base=nc_dn,
                               scope=ldb.SCOPE_BASE,
                               attrs=['repsFrom', 'repsTo'])
            self.assertEqual(len(res), 1)
            self.assertIsNone(res[0].get('repsFrom'))
            self.assertIsNone(res[0].get('repsTo'))

        # check the restored DB has the expected partitions/DC/FSMO roles
        self.assert_partitions_present(samdb)
        self.assert_dcs_present(samdb, self.new_server, expected_count=1)
        self.assert_fsmo_roles(samdb, self.new_server, self.server)
        self.assert_secrets(samdb, expect_secrets=expect_secrets)
        return samdb
Example #3
0
class DsdbTests(TestCase):
    """Tests for SamDB/dsdb helpers against the local sam.ldb.

    Covers replPropertyMetaData modification with modify controls,
    attid -> OID/attribute lookups, replication metadata versions, and how
    transactions and search iterators lock each other across a forked
    child process (the two processes synchronise over pipes).
    """

    def setUp(self):
        super(DsdbTests, self).setUp()
        self.lp = samba.tests.env_loadparm()
        self.creds = Credentials()
        self.creds.guess(self.lp)
        self.session = system_session()
        self.samdb = SamDB(session_info=self.session,
                           credentials=self.creds,
                           lp=self.lp)

    def _pipe(self):
        """Create a pipe whose two ends are closed when the test finishes.

        The cleanups only run in the process that returns from the test
        method; forked children leave through ``os._exit`` and skip them.
        """
        (r, w) = os.pipe()
        self.addCleanup(os.close, w)
        self.addCleanup(os.close, r)
        return (r, w)

    def _search_administrator(self, attrs):
        """Return the first search result for ``cn=Administrator``."""
        res = self.samdb.search(expression="cn=Administrator",
                                scope=ldb.SCOPE_SUBTREE,
                                attrs=attrs)
        return res[0]

    @staticmethod
    def _unpack_repl(msg):
        """Unpack the replPropertyMetaData blob of a search result."""
        # ndr_unpack() needs the raw value; str() of the element would turn
        # the binary blob into text on Python 3.
        return ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          msg["replPropertyMetaData"][0])

    @staticmethod
    def _repl_replace_msg(dn, repl):
        """Build a modify message replacing replPropertyMetaData on ``dn``."""
        msg = ldb.Message()
        msg.dn = dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            ndr_pack(repl), ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        return msg

    def _backend_path(self, basedn):
        """Return the sam.ldb.d backend file path named after ``basedn``."""
        backend_filename = "%s.ldb" % basedn.get_casefold()
        backend_subpath = os.path.join("sam.ldb.d", backend_filename)
        return self.lp.private_path(backend_subpath)

    def test_get_oid_from_attrid(self):
        oid = self.samdb.get_oid_from_attid(591614)
        self.assertEqual(oid, "1.2.840.113556.1.4.1790")

    def test_error_replpropertymetadata(self):
        msg = self._search_administrator(["replPropertyMetaData"])
        repl = self._unpack_repl(msg)
        for o in repl.ctr.array:
            # Search for Description
            if o.attid == 13:
                o.version = o.version + 1
        self.assertRaises(ldb.LdbError, self.samdb.modify,
                          self._repl_replace_msg(msg.dn, repl),
                          ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_error_replpropertymetadata_nochange(self):
        msg = self._search_administrator(["replPropertyMetaData"])
        repl = self._unpack_repl(msg)
        self.assertRaises(ldb.LdbError, self.samdb.modify,
                          self._repl_replace_msg(msg.dn, repl),
                          ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_error_replpropertymetadata_allow_sort(self):
        msg = self._search_administrator(["replPropertyMetaData"])
        repl = self._unpack_repl(msg)
        self.samdb.modify(self._repl_replace_msg(msg.dn, repl), [
            "local_oid:1.3.6.1.4.1.7165.4.3.14:0",
            "local_oid:1.3.6.1.4.1.7165.4.3.25:0"
        ])

    def test_twoatt_replpropertymetadata(self):
        msg = self._search_administrator(["replPropertyMetaData",
                                          "uSNChanged"])
        repl = self._unpack_repl(msg)
        for o in repl.ctr.array:
            # Search for Description
            if o.attid == 13:
                o.version = o.version + 1
                # int() rather than the Python-2-only long()
                o.local_usn = int(str(msg["uSNChanged"])) + 1
        mod = self._repl_replace_msg(msg.dn, repl)
        mod["description"] = ldb.MessageElement("new val",
                                                ldb.FLAG_MOD_REPLACE,
                                                "description")
        self.assertRaises(ldb.LdbError, self.samdb.modify, mod,
                          ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_set_replpropertymetadata(self):
        msg = self._search_administrator(["replPropertyMetaData",
                                          "uSNChanged"])
        repl = self._unpack_repl(msg)
        for o in repl.ctr.array:
            # Search for Description
            if o.attid == 13:
                o.version = o.version + 1
                o.local_usn = int(str(msg["uSNChanged"])) + 1
                o.originating_usn = int(str(msg["uSNChanged"])) + 1
        self.samdb.modify(self._repl_replace_msg(msg.dn, repl),
                          ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_ok_get_attribute_from_attid(self):
        self.assertEqual(self.samdb.get_attribute_from_attid(13),
                         "description")

    def test_ko_get_attribute_from_attid(self):
        self.assertIsNone(self.samdb.get_attribute_from_attid(11979))

    def test_get_attribute_replmetadata_version(self):
        res = self.samdb.search(expression="cn=Administrator",
                                scope=ldb.SCOPE_SUBTREE,
                                attrs=["dn"])
        self.assertEqual(len(res), 1)
        dn = str(res[0].dn)
        self.assertEqual(
            self.samdb.get_attribute_replmetadata_version(dn, "unicodePwd"), 1)

    def test_set_attribute_replmetadata_version(self):
        res = self.samdb.search(expression="cn=Administrator",
                                scope=ldb.SCOPE_SUBTREE,
                                attrs=["dn"])
        self.assertEqual(len(res), 1)
        dn = str(res[0].dn)
        version = self.samdb.get_attribute_replmetadata_version(
            dn, "description")
        self.samdb.set_attribute_replmetadata_version(dn, "description",
                                                      version + 2)
        self.assertEqual(
            self.samdb.get_attribute_replmetadata_version(dn, "description"),
            version + 2)

    def test_db_lock1(self):
        basedn = self.samdb.get_default_basedn()
        (r1, w1) = self._pipe()

        pid = os.fork()
        if pid == 0:
            # In the child, close the main DB, re-open just one DB
            del self.samdb
            gc.collect()
            self.samdb = SamDB(session_info=self.session,
                               credentials=self.creds,
                               lp=self.lp)

            self.samdb.transaction_start()

            dn = "cn=test_db_lock_user,cn=users," + str(basedn)
            self.samdb.add({
                "dn": dn,
                "objectclass": "user",
            })
            self.samdb.delete(dn)

            # Obtain a write lock
            self.samdb.transaction_prepare_commit()
            os.write(w1, b"prepared")
            time.sleep(2)

            # Drop the write lock
            self.samdb.transaction_cancel()
            os._exit(0)

        self.assertEqual(os.read(r1, 8), b"prepared")

        start = time.time()

        # We need to hold this iterator open to hold the all-record lock.
        res = self.samdb.search_iterator()

        # This should take at least 2 seconds because the transaction
        # has a write lock on one backend db open

        # Release the locks
        for _ in res:
            pass

        end = time.time()
        self.assertGreater(end - start, 1.9)

        (got_pid, status) = os.waitpid(pid, 0)
        self.assertEqual(got_pid, pid)
        self.assertTrue(os.WIFEXITED(status))
        self.assertEqual(os.WEXITSTATUS(status), 0)

    def test_db_lock2(self):
        basedn = self.samdb.get_default_basedn()
        (r1, w1) = self._pipe()
        (r2, w2) = self._pipe()

        pid = os.fork()
        if pid == 0:
            # In the child, close the main DB, re-open
            del self.samdb
            gc.collect()
            self.samdb = SamDB(session_info=self.session,
                               credentials=self.creds,
                               lp=self.lp)

            # We need to hold this iterator open to hold the all-record lock.
            res = self.samdb.search_iterator()

            os.write(w2, b"start")
            if (os.read(r1, 7) != b"started"):
                os._exit(1)

            os.write(w2, b"add")
            if (os.read(r1, 5) != b"added"):
                os._exit(2)

            # Keep the read lock for 2 seconds so that the parent's
            # prepare_commit() blocks.
            os.write(w2, b"prepare")
            time.sleep(2)

            # Release the locks
            for _ in res:
                pass

            if (os.read(r1, 8) != b"prepared"):
                os._exit(3)

            os._exit(0)

        # We can start the transaction during the search
        # because both just grab the all-record read lock.
        self.assertEqual(os.read(r2, 5), b"start")
        self.samdb.transaction_start()
        os.write(w1, b"started")

        self.assertEqual(os.read(r2, 3), b"add")
        dn = "cn=test_db_lock_user,cn=users," + str(basedn)
        self.samdb.add({
            "dn": dn,
            "objectclass": "user",
        })
        self.samdb.delete(dn)
        os.write(w1, b"added")

        # Obtain a write lock, this will block until
        # the child releases the read lock.
        self.assertEqual(os.read(r2, 7), b"prepare")
        start = time.time()
        self.samdb.transaction_prepare_commit()
        end = time.time()
        try:
            self.assertGreater(end - start, 1.9)
        finally:
            os.write(w1, b"prepared")

            # Drop the write lock
            self.samdb.transaction_cancel()

            (got_pid, status) = os.waitpid(pid, 0)
            self.assertEqual(got_pid, pid)
            self.assertTrue(os.WIFEXITED(status))
            self.assertEqual(os.WEXITSTATUS(status), 0)

    def test_db_lock3(self):
        (r1, w1) = self._pipe()
        (r2, w2) = self._pipe()

        pid = os.fork()
        if pid == 0:
            # In the child, close the main DB, re-open
            del self.samdb
            gc.collect()
            self.samdb = SamDB(session_info=self.session,
                               credentials=self.creds,
                               lp=self.lp)

            # We need to hold this iterator open to hold the all-record lock.
            res = self.samdb.search_iterator()

            os.write(w2, b"start")
            if (os.read(r1, 7) != b"started"):
                os._exit(1)

            os.write(w2, b"add")
            if (os.read(r1, 5) != b"added"):
                os._exit(2)

            # Keep the read lock for 2 seconds so that the parent's
            # prepare_commit() blocks.
            os.write(w2, b"prepare")
            time.sleep(2)

            # Release the locks
            for _ in res:
                pass

            if (os.read(r1, 8) != b"prepared"):
                os._exit(3)

            os._exit(0)

        # We can start the transaction during the search
        # because both just grab the all-record read lock.
        self.assertEqual(os.read(r2, 5), b"start")
        self.samdb.transaction_start()
        os.write(w1, b"started")

        self.assertEqual(os.read(r2, 3), b"add")

        # This will end up in the top level db
        dn = "@DSDB_LOCK_TEST"
        self.samdb.add({"dn": dn})
        self.samdb.delete(dn)
        os.write(w1, b"added")

        # Obtain a write lock, this will block until
        # the child releases the read lock.
        self.assertEqual(os.read(r2, 7), b"prepare")
        start = time.time()
        self.samdb.transaction_prepare_commit()
        end = time.time()
        self.assertGreater(end - start, 1.9)
        os.write(w1, b"prepared")

        # Drop the write lock
        self.samdb.transaction_cancel()

        (got_pid, status) = os.waitpid(pid, 0)
        self.assertTrue(os.WIFEXITED(status))
        self.assertEqual(os.WEXITSTATUS(status), 0)
        self.assertEqual(got_pid, pid)

    def _test_full_db_lock1(self, backend_path):
        (r1, w1) = self._pipe()

        pid = os.fork()
        if pid == 0:
            # In the child, close the main DB, re-open just one DB
            del self.samdb
            gc.collect()

            backenddb = ldb.Ldb(backend_path)

            backenddb.transaction_start()

            backenddb.add({"dn": "@DSDB_LOCK_TEST"})
            backenddb.delete("@DSDB_LOCK_TEST")

            # Obtain a write lock
            backenddb.transaction_prepare_commit()
            os.write(w1, b"prepared")
            time.sleep(2)

            # Drop the write lock
            backenddb.transaction_cancel()
            os._exit(0)

        self.assertEqual(os.read(r1, 8), b"prepared")

        start = time.time()

        # We need to hold this iterator open to hold the all-record lock.
        res = self.samdb.search_iterator()

        # This should take at least 2 seconds because the transaction
        # has a write lock on one backend db open

        end = time.time()
        self.assertGreater(end - start, 1.9)

        # Release the locks
        for _ in res:
            pass

        (got_pid, status) = os.waitpid(pid, 0)
        self.assertEqual(got_pid, pid)
        self.assertTrue(os.WIFEXITED(status))
        self.assertEqual(os.WEXITSTATUS(status), 0)

    def test_full_db_lock1(self):
        self._test_full_db_lock1(
            self._backend_path(self.samdb.get_default_basedn()))

    def test_full_db_lock1_config(self):
        self._test_full_db_lock1(
            self._backend_path(self.samdb.get_config_basedn()))

    def _test_full_db_lock2(self, backend_path):
        (r1, w1) = self._pipe()
        (r2, w2) = self._pipe()

        pid = os.fork()
        if pid == 0:

            # In the child, close the main DB, re-open
            del self.samdb
            gc.collect()
            self.samdb = SamDB(session_info=self.session,
                               credentials=self.creds,
                               lp=self.lp)

            # We need to hold this iterator open to hold the all-record lock.
            res = self.samdb.search_iterator()

            os.write(w2, b"start")
            if (os.read(r1, 7) != b"started"):
                os._exit(1)
            os.write(w2, b"add")
            if (os.read(r1, 5) != b"added"):
                os._exit(2)

            # Keep the read lock for 2 seconds so that the parent's
            # prepare_commit() blocks.
            os.write(w2, b"prepare")
            time.sleep(2)

            # Release the locks
            for _ in res:
                pass

            if (os.read(r1, 8) != b"prepared"):
                os._exit(3)

            os._exit(0)

        # In the parent, close the main DB, re-open just one DB
        del self.samdb
        gc.collect()
        backenddb = ldb.Ldb(backend_path)

        # We can start the transaction during the search
        # because both just grab the all-record read lock.
        self.assertEqual(os.read(r2, 5), b"start")
        backenddb.transaction_start()
        os.write(w1, b"started")

        self.assertEqual(os.read(r2, 3), b"add")
        backenddb.add({"dn": "@DSDB_LOCK_TEST"})
        backenddb.delete("@DSDB_LOCK_TEST")
        os.write(w1, b"added")

        # Obtain a write lock, this will block until
        # the child releases the read lock.
        self.assertEqual(os.read(r2, 7), b"prepare")
        start = time.time()
        backenddb.transaction_prepare_commit()
        end = time.time()

        try:
            self.assertGreater(end - start, 1.9)
        finally:
            os.write(w1, b"prepared")

            # Drop the write lock
            backenddb.transaction_cancel()

            (got_pid, status) = os.waitpid(pid, 0)
            self.assertEqual(got_pid, pid)
            self.assertTrue(os.WIFEXITED(status))
            self.assertEqual(os.WEXITSTATUS(status), 0)

    def test_full_db_lock2(self):
        self._test_full_db_lock2(
            self._backend_path(self.samdb.get_default_basedn()))

    def test_full_db_lock2_config(self):
        self._test_full_db_lock2(
            self._backend_path(self.samdb.get_config_basedn()))

    def test_no_error_on_invalid_control(self):
        try:
            self.samdb.search(
                expression="cn=Administrator",
                scope=ldb.SCOPE_SUBTREE,
                attrs=["replPropertyMetaData"],
                controls=[
                    "local_oid:%s:0" %
                    dsdb.DSDB_CONTROL_INVALID_NOT_IMPLEMENTED
                ])
        except ldb.LdbError:
            self.fail("Should have not raised an exception")

    def test_error_on_invalid_critical_control(self):
        try:
            self.samdb.search(
                expression="cn=Administrator",
                scope=ldb.SCOPE_SUBTREE,
                attrs=["replPropertyMetaData"],
                controls=[
                    "local_oid:%s:1" %
                    dsdb.DSDB_CONTROL_INVALID_NOT_IMPLEMENTED
                ])
        except ldb.LdbError as e:
            # Exceptions are not subscriptable on Python 3; use .args.
            (num, estr) = e.args
            if num != ldb.ERR_UNSUPPORTED_CRITICAL_EXTENSION:
                self.fail(
                    "Got %s should have got ERR_UNSUPPORTED_CRITICAL_EXTENSION"
                    % estr)
        else:
            # Without this the test would pass when no error is raised.
            self.fail("Should have raised ERR_UNSUPPORTED_CRITICAL_EXTENSION")
Example #4
0
class PassWordHashTests(TestCase):
    def setUp(self):
        self.lp = samba.tests.env_loadparm()
        super(PassWordHashTests, self).setUp()

    def set_store_cleartext(self, cleartext):
        # get the current pwdProperties
        pwdProperties = self.ldb.get_pwdProperties()
        # update the clear-text properties flag
        props = int(pwdProperties)
        if cleartext:
            props |= DOMAIN_PASSWORD_STORE_CLEARTEXT
        else:
            props &= ~DOMAIN_PASSWORD_STORE_CLEARTEXT
        self.ldb.set_pwdProperties(str(props))

    # Add a user to ldb, this will exercise the password_hash code
    # and calculate the appropriate supplemental credentials
    def add_user(self, options=None, clear_text=False, ldb=None):
        # set any needed options
        if options is not None:
            for (option, value) in options:
                self.lp.set(option, value)

        if ldb is None:
            self.creds = Credentials()
            self.session = system_session()
            self.creds.guess(self.lp)
            self.session = system_session()
            self.ldb = SamDB(session_info=self.session,
                             credentials=self.creds,
                             lp=self.lp)
        else:
            self.ldb = ldb

        res = self.ldb.search(base=self.ldb.get_config_basedn(),
                              expression="ncName=%s" %
                              self.ldb.get_default_basedn(),
                              attrs=["nETBIOSName"])
        self.netbios_domain = str(res[0]["nETBIOSName"][0])
        self.dns_domain = self.ldb.domain_dns_name()

        # Gets back the basedn
        base_dn = self.ldb.domain_dn()

        # Gets back the configuration basedn
        configuration_dn = self.ldb.get_config_basedn().get_linearized()

        # permit password changes during this test
        PasswordCommon.allow_password_changes(self, self.ldb)

        self.base_dn = self.ldb.domain_dn()

        account_control = 0
        if clear_text:
            # Restore the current domain setting on exit.
            pwdProperties = self.ldb.get_pwdProperties()
            self.addCleanup(self.ldb.set_pwdProperties, pwdProperties)
            # Update the domain setting
            self.set_store_cleartext(clear_text)
            account_control |= UF_ENCRYPTED_TEXT_PASSWORD_ALLOWED

        # (Re)adds the test user USER_NAME with password USER_PASS
        # and userPrincipalName UPN
        delete_force(self.ldb, "cn=" + USER_NAME + ",cn=users," + self.base_dn)
        self.ldb.add({
            "dn": "cn=" + USER_NAME + ",cn=users," + self.base_dn,
            "objectclass": "user",
            "sAMAccountName": USER_NAME,
            "userPassword": USER_PASS,
            "userPrincipalName": UPN,
            "userAccountControl": str(account_control)
        })

    # Get the supplemental credentials for the user under test
    def get_supplemental_creds(self):
        base = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        res = self.ldb.search(scope=ldb.SCOPE_BASE,
                              base=base,
                              attrs=["supplementalCredentials"])
        self.assertIs(True, len(res) > 0)
        obj = res[0]
        sc_blob = obj["supplementalCredentials"][0]
        sc = ndr_unpack(drsblobs.supplementalCredentialsBlob, sc_blob)
        return sc

    # Calculate and validate a Wdigest value
    def check_digest(self, user, realm, password, digest):
        expected = calc_digest(user, realm, password)
        actual = binascii.hexlify(bytearray(digest)).decode('utf8')
        error = "Digest expected[%s], actual[%s], " \
                "user[%s], realm[%s], pass[%s]" % \
                (expected, actual, user, realm, password)
        self.assertEquals(expected, actual, error)

    # Check all of the 29 expected WDigest values
    #
    def check_wdigests(self, digests):

        self.assertEquals(29, digests.num_hashes)

        # Using the n-1 pattern in the array indexes to make it easier
        # to check the tests against the spec and the samba-tool user tests.
        self.check_digest(USER_NAME, self.netbios_domain, USER_PASS,
                          digests.hashes[1 - 1].hash)
        self.check_digest(USER_NAME.lower(), self.netbios_domain.lower(),
                          USER_PASS, digests.hashes[2 - 1].hash)
        self.check_digest(USER_NAME.upper(), self.netbios_domain.upper(),
                          USER_PASS, digests.hashes[3 - 1].hash)
        self.check_digest(USER_NAME, self.netbios_domain.upper(), USER_PASS,
                          digests.hashes[4 - 1].hash)
        self.check_digest(USER_NAME, self.netbios_domain.lower(), USER_PASS,
                          digests.hashes[5 - 1].hash)
        self.check_digest(USER_NAME.upper(), self.netbios_domain.lower(),
                          USER_PASS, digests.hashes[6 - 1].hash)
        self.check_digest(USER_NAME.lower(), self.netbios_domain.upper(),
                          USER_PASS, digests.hashes[7 - 1].hash)
        self.check_digest(USER_NAME, self.dns_domain, USER_PASS,
                          digests.hashes[8 - 1].hash)
        self.check_digest(USER_NAME.lower(), self.dns_domain.lower(),
                          USER_PASS, digests.hashes[9 - 1].hash)
        self.check_digest(USER_NAME.upper(), self.dns_domain.upper(),
                          USER_PASS, digests.hashes[10 - 1].hash)
        self.check_digest(USER_NAME, self.dns_domain.upper(), USER_PASS,
                          digests.hashes[11 - 1].hash)
        self.check_digest(USER_NAME, self.dns_domain.lower(), USER_PASS,
                          digests.hashes[12 - 1].hash)
        self.check_digest(USER_NAME.upper(), self.dns_domain.lower(),
                          USER_PASS, digests.hashes[13 - 1].hash)
        self.check_digest(USER_NAME.lower(), self.dns_domain.upper(),
                          USER_PASS, digests.hashes[14 - 1].hash)
        self.check_digest(UPN, "", USER_PASS, digests.hashes[15 - 1].hash)
        self.check_digest(UPN.lower(), "", USER_PASS,
                          digests.hashes[16 - 1].hash)
        self.check_digest(UPN.upper(), "", USER_PASS,
                          digests.hashes[17 - 1].hash)

        name = "%s\\%s" % (self.netbios_domain, USER_NAME)
        self.check_digest(name, "", USER_PASS, digests.hashes[18 - 1].hash)

        name = "%s\\%s" % (self.netbios_domain.lower(), USER_NAME.lower())
        self.check_digest(name, "", USER_PASS, digests.hashes[19 - 1].hash)

        name = "%s\\%s" % (self.netbios_domain.upper(), USER_NAME.upper())
        self.check_digest(name, "", USER_PASS, digests.hashes[20 - 1].hash)
        self.check_digest(USER_NAME, "Digest", USER_PASS,
                          digests.hashes[21 - 1].hash)
        self.check_digest(USER_NAME.lower(), "Digest", USER_PASS,
                          digests.hashes[22 - 1].hash)
        self.check_digest(USER_NAME.upper(), "Digest", USER_PASS,
                          digests.hashes[23 - 1].hash)
        self.check_digest(UPN, "Digest", USER_PASS,
                          digests.hashes[24 - 1].hash)
        self.check_digest(UPN.lower(), "Digest", USER_PASS,
                          digests.hashes[25 - 1].hash)
        self.check_digest(UPN.upper(), "Digest", USER_PASS,
                          digests.hashes[26 - 1].hash)
        name = "%s\\%s" % (self.netbios_domain, USER_NAME)
        self.check_digest(name, "Digest", USER_PASS,
                          digests.hashes[27 - 1].hash)

        name = "%s\\%s" % (self.netbios_domain.lower(), USER_NAME.lower())
        self.check_digest(name, "Digest", USER_PASS,
                          digests.hashes[28 - 1].hash)

        name = "%s\\%s" % (self.netbios_domain.upper(), USER_NAME.upper())
        self.check_digest(name, "Digest", USER_PASS,
                          digests.hashes[29 - 1].hash)

    def checkUserPassword(self, up, expected):

        # Check we've received the correct number of hashes
        self.assertEquals(len(expected), up.num_hashes)

        i = 0
        for (tag, alg, rounds) in expected:
            self.assertEquals(tag, up.hashes[i].scheme)

            data = up.hashes[i].value.decode('utf8').split("$")
            # Check we got the expected crypt algorithm
            self.assertEquals(alg, data[1])

            if rounds is None:
                cmd = "$%s$%s" % (alg, data[2])
            else:
                cmd = "$%s$rounds=%d$%s" % (alg, rounds, data[3])

            # Calculate the expected hash value
            expected = crypt.crypt(USER_PASS, cmd)
            self.assertEquals(expected, up.hashes[i].value.decode('utf8'))
            i += 1

    # Check that the correct nt_hash was stored for userPassword
    def checkNtHash(self, password, nt_hash):
        """Assert that *nt_hash* equals the NT hash of *password*.

        The reference value comes from a Credentials object set to anonymous
        with *password* applied; *nt_hash* is compared as a bytearray.
        """
        creds = Credentials()
        creds.set_anonymous()
        creds.set_password(password)
        expected = creds.get_nt_hash()
        actual = bytearray(nt_hash)
        self.assertEqual(expected, actual)
Exemple #5
0
    def test_rid_set_dbcheck_after_seize(self):
        """Perform a join against the RID manager and assert we have a RID Set.
        We seize the RID master role, then using dbcheck, we assert that we can
        detect out of range users (and then bump the RID set as required)."""

        fsmo_dn = ldb.Dn(
            self.ldb_dc1,
            "CN=RID Manager$,CN=System," + self.ldb_dc1.domain_dn())
        (fsmo_owner, _fsmo_not_owner) = self._determine_fSMORoleOwner(fsmo_dn)

        targetdir = self._test_join(fsmo_owner['dns_name'], "RIDALLOCTEST7")
        try:
            # Connect to the database of the joined DC
            ldb_url = "tdb://%s" % os.path.join(targetdir, "private/sam.ldb")
            smbconf = os.path.join(targetdir, "etc/smb.conf")

            lp = self.get_loadparm()
            new_ldb = SamDB(ldb_url,
                            credentials=self.get_credentials(),
                            session_info=system_session(lp),
                            lp=lp)

            # 1. Get server name
            res = new_ldb.search(base=ldb.Dn(new_ldb,
                                             new_ldb.get_serverName()),
                                 scope=ldb.SCOPE_BASE,
                                 attrs=["serverReference"])
            # 2. Get server reference
            server_ref_dn = ldb.Dn(new_ldb, res[0]['serverReference'][0])

            # 3. Assert we get the RID Set
            res = new_ldb.search(base=server_ref_dn,
                                 scope=ldb.SCOPE_BASE,
                                 attrs=['rIDSetReferences'])

            self.assertIn("rIDSetReferences", res[0])
            rid_set_dn = ldb.Dn(new_ldb, res[0]["rIDSetReferences"][0])

            # 4. Seize the RID Manager role
            (result, out, err) = self.runsubcmd("fsmo", "seize", "--role",
                                                "rid", "-H", ldb_url, "-s",
                                                smbconf, "--force")
            self.assertCmdSuccess(result, out, err)
            self.assertEqual(err, "", "Shouldn't be any error messages")

            # 5. Add a new user (triggers RID set work)
            new_ldb.newuser("ridalloctestuser", "P@ssword!")

            # 6. Now fetch the RID SET
            rid_set_res = new_ldb.search(
                base=rid_set_dn,
                scope=ldb.SCOPE_BASE,
                attrs=['rIDNextRid', 'rIDAllocationPool'])
            next_pool = int(rid_set_res[0]["rIDAllocationPool"][0])
            # Take the upper 32 bits of rIDAllocationPool as the last RID of
            # the current pool.
            last_rid = (0xFFFFFFFF00000000 & next_pool) >> 32

            # 7. Add user above the ridNextRid and at almost the end of the range.
            #
            m = ldb.Message()
            m.dn = ldb.Dn(new_ldb, "CN=ridsettestuser2,CN=Users")
            m.dn.add_base(new_ldb.get_default_basedn())
            m['objectClass'] = ldb.MessageElement('user', ldb.FLAG_MOD_ADD,
                                                  'objectClass')
            m['objectSid'] = ldb.MessageElement(
                ndr_pack(
                    security.dom_sid(
                        str(new_ldb.get_domain_sid()) + "-%d" %
                        (last_rid - 3))), ldb.FLAG_MOD_ADD, 'objectSid')
            new_ldb.add(m, controls=["relax:0"])

            # 8. Add user above the ridNextRid and at the end of the range
            m = ldb.Message()
            m.dn = ldb.Dn(new_ldb, "CN=ridsettestuser3,CN=Users")
            m.dn.add_base(new_ldb.get_default_basedn())
            m['objectClass'] = ldb.MessageElement('user', ldb.FLAG_MOD_ADD,
                                                  'objectClass')
            m['objectSid'] = ldb.MessageElement(
                ndr_pack(
                    security.dom_sid(
                        str(new_ldb.get_domain_sid()) + "-%d" % last_rid)),
                ldb.FLAG_MOD_ADD, 'objectSid')
            new_ldb.add(m, controls=["relax:0"])

            chk = dbcheck(new_ldb,
                          verbose=False,
                          fix=True,
                          yes=True,
                          quiet=True)

            # Should have fixed two errors (wrong ridNextRid)
            self.assertEqual(
                chk.check_database(DN=rid_set_dn, scope=ldb.SCOPE_BASE), 2)

            # 9. Assert that a read-only re-check finds no remaining errors.
            # The checker must actually be run, otherwise this step would
            # assert nothing.
            chk = dbcheck(new_ldb, verbose=False, fix=False, quiet=True)
            self.assertEqual(
                chk.check_database(DN=rid_set_dn, scope=ldb.SCOPE_BASE), 0)

            # 10. Add another user (checks RID rollover)
            # We have seized the role, so we can do that.
            new_ldb.newuser("ridalloctestuser3", "P@ssword!")

            rid_set_res = new_ldb.search(
                base=rid_set_dn,
                scope=ldb.SCOPE_BASE,
                attrs=['rIDNextRid', 'rIDAllocationPool'])
            next_pool = int(rid_set_res[0]["rIDAllocationPool"][0])
            self.assertNotEqual(last_rid,
                                (0xFFFFFFFF00000000 & next_pool) >> 32,
                                "rid pool should have changed")
        finally:
            self._test_force_demote(fsmo_owner['dns_name'], "RIDALLOCTEST7")
            shutil.rmtree(targetdir, ignore_errors=True)
Exemple #6
0
    def test_rid_set_dbcheck(self):
        """Perform a join against the RID manager and assert we have a RID Set.
        Using dbcheck, we assert that we can detect out of range users."""

        fsmo_dn = ldb.Dn(
            self.ldb_dc1,
            "CN=RID Manager$,CN=System," + self.ldb_dc1.domain_dn())
        (fsmo_owner, _fsmo_not_owner) = self._determine_fSMORoleOwner(fsmo_dn)

        targetdir = self._test_join(fsmo_owner['dns_name'], "RIDALLOCTEST6")
        try:
            # Connect to the database of the joined DC
            ldb_url = "tdb://%s" % os.path.join(targetdir, "private/sam.ldb")
            smbconf = os.path.join(targetdir, "etc/smb.conf")

            lp = self.get_loadparm()
            new_ldb = SamDB(ldb_url,
                            credentials=self.get_credentials(),
                            session_info=system_session(lp),
                            lp=lp)

            # 1. Get server name
            res = new_ldb.search(base=ldb.Dn(new_ldb,
                                             new_ldb.get_serverName()),
                                 scope=ldb.SCOPE_BASE,
                                 attrs=["serverReference"])
            # 2. Get server reference
            server_ref_dn = ldb.Dn(new_ldb, res[0]['serverReference'][0])

            # 3. Assert we get the RID Set
            res = new_ldb.search(base=server_ref_dn,
                                 scope=ldb.SCOPE_BASE,
                                 attrs=['rIDSetReferences'])

            self.assertIn("rIDSetReferences", res[0])
            rid_set_dn = ldb.Dn(new_ldb, res[0]["rIDSetReferences"][0])

            # 4. Add a new user (triggers RID set work)
            new_ldb.newuser("ridalloctestuser", "P@ssword!")

            # 5. Now fetch the RID SET
            rid_set_res = new_ldb.search(
                base=rid_set_dn,
                scope=ldb.SCOPE_BASE,
                attrs=['rIDNextRid', 'rIDAllocationPool'])
            next_pool = int(rid_set_res[0]["rIDAllocationPool"][0])
            # Take the upper 32 bits of rIDAllocationPool as the last RID of
            # the current pool.
            last_rid = (0xFFFFFFFF00000000 & next_pool) >> 32

            # 6. Add user above the ridNextRid and at mid-range.
            #
            # We can do this with safety because this is an offline DB that will be
            # destroyed.
            m = ldb.Message()
            m.dn = ldb.Dn(new_ldb, "CN=ridsettestuser1,CN=Users")
            m.dn.add_base(new_ldb.get_default_basedn())
            m['objectClass'] = ldb.MessageElement('user', ldb.FLAG_MOD_ADD,
                                                  'objectClass')
            m['objectSid'] = ldb.MessageElement(
                ndr_pack(
                    security.dom_sid(
                        str(new_ldb.get_domain_sid()) + "-%d" %
                        (last_rid - 10))), ldb.FLAG_MOD_ADD, 'objectSid')
            new_ldb.add(m, controls=["relax:0"])

            # 7. Check the RID Set
            chk = dbcheck(new_ldb,
                          verbose=False,
                          fix=True,
                          yes=True,
                          quiet=True)

            # Should have one error (wrong rIDNextRID)
            self.assertEqual(
                chk.check_database(DN=rid_set_dn, scope=ldb.SCOPE_BASE), 1)

            # 8. Assert that a read-only re-check finds no remaining errors.
            # The checker must actually be run, otherwise this step would
            # assert nothing.
            chk = dbcheck(new_ldb, verbose=False, fix=False, quiet=True)
            self.assertEqual(
                chk.check_database(DN=rid_set_dn, scope=ldb.SCOPE_BASE), 0)

            rid_set_res = new_ldb.search(
                base=rid_set_dn,
                scope=ldb.SCOPE_BASE,
                attrs=['rIDNextRid', 'rIDAllocationPool'])
            last_allocated_rid = int(rid_set_res[0]["rIDNextRid"][0])
            self.assertEqual(last_allocated_rid, last_rid - 10)

            # 9. Assert that the range wasn't thrown away

            next_pool = int(rid_set_res[0]["rIDAllocationPool"][0])
            self.assertEqual(last_rid, (0xFFFFFFFF00000000 & next_pool) >> 32,
                             "rid pool should not have changed")
        finally:
            self._test_force_demote(fsmo_owner['dns_name'], "RIDALLOCTEST6")
            shutil.rmtree(targetdir, ignore_errors=True)
Exemple #7
0
class DsdbTests(TestCase):
    def setUp(self):
        """Open the local SamDB and create a throwaway user for the test.

        The user's DN is kept in ``self.account_dn``; a cleanup is registered
        that force-deletes it once the test finishes.
        """
        super(DsdbTests, self).setUp()
        self.lp = samba.tests.env_loadparm()
        self.creds = Credentials()
        self.creds.guess(self.lp)
        self.session = system_session()
        self.samdb = SamDB(session_info=self.session,
                           credentials=self.creds,
                           lp=self.lp)

        # A random 6-hex-digit suffix gives each run its own account name.
        username = "dsdb-user-" + uuid.uuid4().hex[:6]
        password = samba.generate_random_password(32, 32)
        description = "Test user for dsdb test"

        users_container = ",cn=Users," + self.samdb.domain_dn()
        self.account_dn = "cn=" + username + users_container
        self.samdb.newuser(username=username,
                           password=password,
                           description=description)
        # Teardown: remove the account even if the test fails.
        self.addCleanup(delete_force, self.samdb, self.account_dn)

    def test_get_oid_from_attrid(self):
        """get_oid_from_attid maps attid 591614 to its OID string."""
        oid = self.samdb.get_oid_from_attid(591614)
        self.assertEqual(oid, "1.2.840.113556.1.4.1790")

    def test_error_replpropertymetadata(self):
        """Writing replPropertyMetaData with only the description entry's
        version bumped raises LdbError."""
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        for o in repl.ctr.array:
            # attid 13 is the description attribute
            if o.attid == 13:
                o.version = o.version + 1
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        self.assertRaises(ldb.LdbError, self.samdb.modify, msg,
                          ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_error_replpropertymetadata_nochange(self):
        """Re-writing an unmodified replPropertyMetaData blob with only the
        ...7165.4.3.14 control still raises LdbError."""
        account = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                    base=self.account_dn,
                                    attrs=["replPropertyMetaData"])[0]
        metadata = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                              account["replPropertyMetaData"][0])

        change = ldb.Message()
        change.dn = account.dn
        change["replPropertyMetaData"] = ldb.MessageElement(
            ndr_pack(metadata), ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")

        with self.assertRaises(ldb.LdbError):
            self.samdb.modify(change, ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_error_replpropertymetadata_allow_sort(self):
        """Re-writing an unmodified replPropertyMetaData blob succeeds when
        the ...7165.4.3.25 control is passed alongside ...7165.4.3.14."""
        account = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                    base=self.account_dn,
                                    attrs=["replPropertyMetaData"])[0]
        metadata = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                              account["replPropertyMetaData"][0])

        change = ldb.Message()
        change.dn = account.dn
        change["replPropertyMetaData"] = ldb.MessageElement(
            ndr_pack(metadata), ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")

        controls = ["local_oid:1.3.6.1.4.1.7165.4.3.14:0",
                    "local_oid:1.3.6.1.4.1.7165.4.3.25:0"]
        self.samdb.modify(change, controls)

    def test_twoatt_replpropertymetadata(self):
        """Replacing replPropertyMetaData (description entry's version and
        local_usn bumped) together with description in one modify raises
        LdbError."""
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData", "uSNChanged"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        # Loop-invariant: the USN one past the object's current uSNChanged.
        next_usn = int(str(res[0]["uSNChanged"])) + 1
        for o in repl.ctr.array:
            # attid 13 is the description attribute
            if o.attid == 13:
                o.version = o.version + 1
                o.local_usn = next_usn
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        msg["description"] = ldb.MessageElement("new val",
                                                ldb.FLAG_MOD_REPLACE,
                                                "description")
        self.assertRaises(ldb.LdbError, self.samdb.modify, msg,
                          ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_set_replpropertymetadata(self):
        """Replacing replPropertyMetaData succeeds when the description
        entry's version, local_usn and originating_usn are all bumped."""
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData", "uSNChanged"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        # Loop-invariant: the USN one past the object's current uSNChanged.
        next_usn = int(str(res[0]["uSNChanged"])) + 1
        for o in repl.ctr.array:
            # attid 13 is the description attribute
            if o.attid == 13:
                o.version = o.version + 1
                o.local_usn = next_usn
                o.originating_usn = next_usn
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        self.samdb.modify(msg, ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_ok_get_attribute_from_attid(self):
        """attid 13 resolves to the description attribute."""
        self.assertEqual(self.samdb.get_attribute_from_attid(13),
                         "description")

    def test_ko_get_attribute_from_attid(self):
        """An attid with no matching attribute (11979) resolves to None."""
        self.assertIsNone(self.samdb.get_attribute_from_attid(11979))

    def test_get_attribute_replmetadata_version(self):
        """The test user's unicodePwd metadata version is 2."""
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["dn"])
        self.assertEqual(len(res), 1)
        dn = str(res[0].dn)
        self.assertEqual(
            self.samdb.get_attribute_replmetadata_version(dn, "unicodePwd"), 2)

    def test_set_attribute_replmetadata_version(self):
        """Setting the description metadata version is reflected by a
        subsequent get."""
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["dn"])
        self.assertEqual(len(res), 1)
        dn = str(res[0].dn)
        version = self.samdb.get_attribute_replmetadata_version(
            dn, "description")
        self.samdb.set_attribute_replmetadata_version(dn, "description",
                                                      version + 2)
        self.assertEqual(
            self.samdb.get_attribute_replmetadata_version(dn, "description"),
            version + 2)

    def test_no_error_on_invalid_control(self):
        """An unsupported control marked non-critical (":0") is ignored and
        the search succeeds."""
        try:
            self.samdb.search(
                scope=ldb.SCOPE_SUBTREE,
                base=self.account_dn,
                attrs=["replPropertyMetaData"],
                controls=[
                    "local_oid:%s:0" %
                    dsdb.DSDB_CONTROL_INVALID_NOT_IMPLEMENTED
                ])
        except ldb.LdbError as e:
            # Include the LDB error so a failure is diagnosable.
            self.fail("Should have not raised an exception: %s" % e)

    def test_error_on_invalid_critical_control(self):
        """An unsupported control marked critical (":1") must make the
        search fail with ERR_UNSUPPORTED_CRITICAL_EXTENSION."""
        try:
            self.samdb.search(
                scope=ldb.SCOPE_SUBTREE,
                base=self.account_dn,
                attrs=["replPropertyMetaData"],
                controls=[
                    "local_oid:%s:1" %
                    dsdb.DSDB_CONTROL_INVALID_NOT_IMPLEMENTED
                ])
            # Without this, a search that wrongly succeeds passes silently.
            self.fail("Search should have failed with "
                      "ERR_UNSUPPORTED_CRITICAL_EXTENSION")
        except ldb.LdbError as e:
            # Python 3 exceptions are not subscriptable; unpack .args instead.
            (code, estr) = e.args
            if code != ldb.ERR_UNSUPPORTED_CRITICAL_EXTENSION:
                self.fail(
                    "Got %s should have got ERR_UNSUPPORTED_CRITICAL_EXTENSION"
                    % estr)

    # Allocate a unique RID for use in the objectSID tests.
    #
    def allocate_rid(self):
        """Allocate a RID inside its own transaction and return it as str.

        The transaction is cancelled (and the error re-raised) if the
        allocation fails; otherwise it is committed.
        """
        self.samdb.transaction_start()
        try:
            new_rid = self.samdb.allocate_rid()
        except BaseException:
            self.samdb.transaction_cancel()
            raise
        else:
            self.samdb.transaction_commit()
        return str(new_rid)

    # Ensure that duplicate objectSIDs are permitted for foreign security
    # principals.
    #
    def test_duplicate_objectSIDs_allowed_on_foreign_security_principals(self):
        """A foreignSecurityPrincipal for a SID outside the local domain can
        be added with the provision control, deleted and added again.

        Without the provision control the add is refused (missing objectSid
        gives ERR_OBJECT_CLASS_VIOLATION; supplying objectSid gives
        ERR_UNWILLING_TO_PERFORM).
        """

        #
        # We need to build a foreign security principal SID
        # i.e a SID not in the current domain.
        #
        dom_sid = self.samdb.get_domain_sid()
        # Replace the last digit of the domain SID with a different digit so
        # the resulting SID cannot belong to this domain.
        if str(dom_sid).endswith("0"):
            c = "9"
        else:
            c = "0"
        sid_str = str(dom_sid)[:-1] + c + "-1000"
        sid = ndr_pack(security.dom_sid(sid_str))
        basedn = self.samdb.get_default_basedn()
        dn = "CN=%s,CN=ForeignSecurityPrincipals,%s" % (sid_str, basedn)

        # Don't leave the object behind if an assertion fails part-way.
        self.addCleanup(delete_force, self.samdb, dn)

        #
        # First without control
        #

        try:
            self.samdb.add({
                "dn": dn,
                "objectClass": "foreignSecurityPrincipal"
            })
            self.fail("No exception should get ERR_OBJECT_CLASS_VIOLATION")
        except ldb.LdbError as e:
            (code, msg) = e.args
            self.assertEqual(code, ldb.ERR_OBJECT_CLASS_VIOLATION, str(e))
            werr = "%08X" % werror.WERR_DS_MISSING_REQUIRED_ATT
            self.assertIn(werr, msg, msg)

        try:
            self.samdb.add({
                "dn": dn,
                "objectClass": "foreignSecurityPrincipal",
                "objectSid": sid
            })
            self.fail("No exception should get ERR_UNWILLING_TO_PERFORM")
        except ldb.LdbError as e:
            (code, msg) = e.args
            self.assertEqual(code, ldb.ERR_UNWILLING_TO_PERFORM, str(e))
            werr = "%08X" % werror.WERR_DS_ILLEGAL_MOD_OPERATION
            self.assertIn(werr, msg, msg)

        #
        # We need to use the provision control
        # in order to add foreignSecurityPrincipal
        # objects
        #

        controls = ["provision:0"]
        fsp = {
            "dn": dn,
            "objectClass": "foreignSecurityPrincipal"
        }
        self.samdb.add(fsp, controls=controls)

        self.samdb.delete(dn)

        try:
            self.samdb.add(fsp, controls=controls)
        except ldb.LdbError as e:
            (code, msg) = e.args
            self.fail("Got unexpected exception %d - %s " % (code, msg))

        # cleanup
        self.samdb.delete(dn)

    def _test_foreignSecurityPrincipal(self, obj_class, fpo_attr):
        """Add ``<SID=...>`` references to *fpo_attr* on a new *obj_class*.

        None of the three test SIDs may exist beforehand.
        - A local-domain SID must fail with
          ERR_UNWILLING_TO_PERFORM / WERR_DS_INVALID_GROUP_TYPE.
        - A builtin (S-1-5-32) SID must fail with
          ERR_NO_SUCH_OBJECT / WERR_NO_SUCH_MEMBER.
        - An S-1-5 SID outside the domain is accepted and results in exactly
          one object carrying that SID, which is then deleted along with
          the test object.
        """
        domain_sid = self.samdb.get_domain_sid()
        local_sid = str(domain_sid) + "-4294967294"
        builtin_sid = "S-1-5-32-4294967294"
        foreign_sid = "S-1-5-4294967294"
        basedn = self.samdb.get_default_basedn()
        target_str = "cn=%s,cn=Users,%s" % ("dsdb_test_fpo", basedn)
        target = ldb.Dn(self.samdb, target_str)

        def objects_with_sid(sid_text):
            # Subtree search under the domain for any object with this SID.
            return self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                     base=basedn,
                                     expression="(objectSid=%s)" % sid_text,
                                     attrs=[])

        def add_sid_reference(sid_text):
            # Add a SID-only reference to fpo_attr on the test object.
            change = ldb.Message()
            change.dn = target
            change[fpo_attr] = ldb.MessageElement("<SID=%s>" % sid_text,
                                                  ldb.FLAG_MOD_ADD, fpo_attr)
            self.samdb.modify(change)

        for sid_text in (local_sid, builtin_sid, foreign_sid):
            self.assertEqual(len(objects_with_sid(sid_text)), 0)

        self.addCleanup(delete_force, self.samdb, target_str)

        self.samdb.add({"dn": target_str, "objectClass": obj_class})

        refused = (
            (local_sid,
             "No exception should get LDB_ERR_UNWILLING_TO_PERFORM",
             ldb.ERR_UNWILLING_TO_PERFORM,
             werror.WERR_DS_INVALID_GROUP_TYPE),
            (builtin_sid,
             "No exception should get LDB_ERR_NO_SUCH_OBJECT",
             ldb.ERR_NO_SUCH_OBJECT,
             werror.WERR_NO_SUCH_MEMBER),
        )
        for sid_text, fail_text, want_code, want_werr in refused:
            try:
                add_sid_reference(sid_text)
                self.fail(fail_text)
            except ldb.LdbError as e:
                (code, errmsg) = e.args
                self.assertEqual(code, want_code, str(e))
                werr = "%08X" % want_werr
                self.assertTrue(werr in errmsg, errmsg)

        try:
            add_sid_reference(foreign_sid)
        except ldb.LdbError:
            self.fail("Should have not raised an exception")

        found = objects_with_sid(foreign_sid)
        self.assertEqual(len(found), 1)
        self.samdb.delete(found[0].dn)
        self.samdb.delete(target)
        self.assertEqual(len(objects_with_sid(foreign_sid)), 0)

    def test_foreignSecurityPrincipal_member(self):
        # SID-reference checks for "member" on a group object.
        self._test_foreignSecurityPrincipal(obj_class="group",
                                            fpo_attr="member")

    def test_foreignSecurityPrincipal_MembersForAzRole(self):
        # SID-reference checks for "msDS-MembersForAzRole" on an msDS-AzRole.
        self._test_foreignSecurityPrincipal(obj_class="msDS-AzRole",
                                            fpo_attr="msDS-MembersForAzRole")

    def test_foreignSecurityPrincipal_NeverRevealGroup(self):
        # SID-reference checks for "msDS-NeverRevealGroup" on a computer.
        self._test_foreignSecurityPrincipal(obj_class="computer",
                                            fpo_attr="msDS-NeverRevealGroup")

    def test_foreignSecurityPrincipal_RevealOnDemandGroup(self):
        # SID-reference checks for "msDS-RevealOnDemandGroup" on a computer.
        self._test_foreignSecurityPrincipal(
            obj_class="computer",
            fpo_attr="msDS-RevealOnDemandGroup")

    def _test_fail_foreignSecurityPrincipal(self,
                                            obj_class,
                                            fpo_attr,
                                            msg_exp,
                                            lerr_exp,
                                            werr_exp,
                                            allow_reference=True):
        """Check that *fpo_attr* refuses SID-only references.

        Two *obj_class* objects are created.  Adding a ``<SID=...>``
        reference to *fpo_attr* of the first one must fail with *lerr_exp*
        and an error string containing *werr_exp* (as %08X).  This is checked
        for an unknown local-domain SID, an unknown builtin SID and an
        unknown S-1-5 SID.  A plain DN reference to the second object must
        succeed when *allow_reference* is true and fail the same way
        otherwise.

        :param msg_exp: human-readable name of the expected error, used only
            in failure messages.
        """
        dom_sid = self.samdb.get_domain_sid()
        lsid_str = str(dom_sid) + "-4294967294"
        bsid_str = "S-1-5-32-4294967294"
        fsid_str = "S-1-5-4294967294"
        basedn = self.samdb.get_default_basedn()
        cn1 = "dsdb_test_fpo1"
        dn1_str = "cn=%s,cn=Users,%s" % (cn1, basedn)
        dn1 = ldb.Dn(self.samdb, dn1_str)
        cn2 = "dsdb_test_fpo2"
        dn2_str = "cn=%s,cn=Users,%s" % (cn2, basedn)
        dn2 = ldb.Dn(self.samdb, dn2_str)

        # None of the test SIDs may exist before we start.
        for sid_str in (lsid_str, bsid_str, fsid_str):
            res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                    base=basedn,
                                    expression="(objectSid=%s)" % sid_str,
                                    attrs=[])
            self.assertEqual(len(res), 0)

        self.addCleanup(delete_force, self.samdb, dn1_str)
        self.addCleanup(delete_force, self.samdb, dn2_str)

        self.samdb.add({"dn": dn1_str, "objectClass": obj_class})

        self.samdb.add({"dn": dn2_str, "objectClass": obj_class})

        def add_reference(value):
            # Add *value* to fpo_attr on the first test object.
            msg = ldb.Message()
            msg.dn = dn1
            msg[fpo_attr] = ldb.MessageElement(value, ldb.FLAG_MOD_ADD,
                                               fpo_attr)
            self.samdb.modify(msg)

        def assert_expected_error(e):
            # The LDB code and embedded WERROR must match the expectation.
            (code, estr) = e.args
            self.assertEqual(code, lerr_exp, str(e))
            werr = "%08X" % werr_exp
            self.assertIn(werr, estr, estr)

        # SID-only references are refused for every kind of SID.  The
        # failure message always names the expected error (msg_exp).
        for sid_str in (lsid_str, bsid_str, fsid_str):
            try:
                add_reference("<SID=%s>" % sid_str)
                self.fail("No exception should get %s" % msg_exp)
            except ldb.LdbError as e:
                assert_expected_error(e)

        # A normal DN reference to the second object.
        try:
            add_reference("%s" % dn2)
            if not allow_reference:
                self.fail("No exception should get %s" % msg_exp)
        except ldb.LdbError as e:
            if allow_reference:
                self.fail("Should have not raised an exception: %s" % e)
            assert_expected_error(e)

        self.samdb.delete(dn2)
        self.samdb.delete(dn1)

    def test_foreignSecurityPrincipal_NonMembers(self):
        # msDS-NonMembers on a group refuses SID and DN references alike.
        self._test_fail_foreignSecurityPrincipal(
            obj_class="group",
            fpo_attr="msDS-NonMembers",
            msg_exp="LDB_ERR_UNWILLING_TO_PERFORM/WERR_NOT_SUPPORTED",
            lerr_exp=ldb.ERR_UNWILLING_TO_PERFORM,
            werr_exp=werror.WERR_NOT_SUPPORTED,
            allow_reference=False)

    def test_foreignSecurityPrincipal_HostServiceAccount(self):
        # msDS-HostServiceAccount on a computer refuses SID-only references.
        self._test_fail_foreignSecurityPrincipal(
            obj_class="computer",
            fpo_attr="msDS-HostServiceAccount",
            msg_exp="LDB_ERR_CONSTRAINT_VIOLATION/WERR_DS_NAME_REFERENCE_INVALID",
            lerr_exp=ldb.ERR_CONSTRAINT_VIOLATION,
            werr_exp=werror.WERR_DS_NAME_REFERENCE_INVALID)

    def test_foreignSecurityPrincipal_manager(self):
        # manager on a user refuses SID-only references.
        self._test_fail_foreignSecurityPrincipal(
            obj_class="user",
            fpo_attr="manager",
            msg_exp="LDB_ERR_CONSTRAINT_VIOLATION/WERR_DS_NAME_REFERENCE_INVALID",
            lerr_exp=ldb.ERR_CONSTRAINT_VIOLATION,
            werr_exp=werror.WERR_DS_NAME_REFERENCE_INVALID)

    #
    # Duplicate objectSIDs should not be permitted for SIDs in the local
    # domain. The test sequence is: add an object, delete it, then attempt to
    # re-add it; the re-add should fail with a constraint violation.
    #
    def test_duplicate_objectSIDs_not_allowed_on_local_objects(self):
        """Re-adding an object with the objectSID of a deleted local-domain
        object must fail with ERR_CONSTRAINT_VIOLATION."""

        dom_sid = self.samdb.get_domain_sid()
        rid = self.allocate_rid()
        sid_str = str(dom_sid) + "-" + rid
        sid = ndr_pack(security.dom_sid(sid_str))
        basedn = self.samdb.get_default_basedn()
        cn = "dsdb_test_01"
        dn = "cn=%s,cn=Users,%s" % (cn, basedn)

        # If the re-add below wrongly succeeds, don't leave the object behind.
        self.addCleanup(delete_force, self.samdb, dn)

        self.samdb.add({"dn": dn, "objectClass": "user", "objectSID": sid})
        self.samdb.delete(dn)

        try:
            self.samdb.add({"dn": dn, "objectClass": "user", "objectSID": sid})
            self.fail("No exception should get LDB_ERR_CONSTRAINT_VIOLATION")
        except ldb.LdbError as e:
            (code, msg) = e.args
            if code != ldb.ERR_CONSTRAINT_VIOLATION:
                self.fail("Got %d - %s should have got "
                          "LDB_ERR_CONSTRAINT_VIOLATION" % (code, msg))

    def test_linked_vs_non_linked_reference(self):
        """Check validation of references to deleted or non-existent objects.

        A deleted object referenced by SID or GUID must be rejected by the
        linked attribute 'manager', but accepted (and removable again) on the
        non-linked attribute 'assistant'.  References to objects that never
        existed (by DN, SID or GUID) must be rejected on 'assistant' too.
        Rejections are expected as LDB_ERR_CONSTRAINT_VIOLATION carrying
        WERR_DS_NAME_REFERENCE_INVALID in the error string.
        """
        basedn = self.samdb.get_default_basedn()
        kept_dn_str = "cn=reference_kept,cn=Users,%s" % (basedn)
        removed_dn_str = "cn=reference_removed,cn=Users,%s" % (basedn)
        dom_sid = self.samdb.get_domain_sid()
        # SID and GUID used as references to objects that do not exist.
        none_sid_str = str(dom_sid) + "-4294967294"
        none_guid_str = "afafafaf-fafa-afaf-fafa-afafafafafaf"

        self.addCleanup(delete_force, self.samdb, kept_dn_str)
        self.addCleanup(delete_force, self.samdb, removed_dn_str)

        def add_user_and_get_ids(dn_str):
            # Create a user at dn_str and return (GUID, SID, DN) read back
            # from the database.
            self.samdb.add({"dn": dn_str, "objectClass": "user"})
            res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                    base=dn_str,
                                    attrs=["objectGUID", "objectSID"])
            self.assertEqual(len(res), 1)
            guid = ndr_unpack(misc.GUID, res[0]["objectGUID"][0])
            sid = ndr_unpack(security.dom_sid, res[0]["objectSid"][0])
            return guid, sid, res[0].dn

        _, _, kept_dn = add_user_and_get_ids(kept_dn_str)
        removed_guid, removed_sid, _ = add_user_and_get_ids(removed_dn_str)
        self.samdb.delete(removed_dn_str)

        def modify_reference(attr, value, flags):
            # Apply a single-value modification of attr on the kept object.
            msg = ldb.Message()
            msg.dn = kept_dn
            msg[attr] = ldb.MessageElement(value, flags, attr)
            self.samdb.modify(msg)

        def assert_invalid_reference(attr, value):
            # Adding value to attr must fail with a constraint violation
            # whose message contains WERR_DS_NAME_REFERENCE_INVALID.
            try:
                modify_reference(attr, value, ldb.FLAG_MOD_ADD)
                self.fail("No exception should get LDB_ERR_CONSTRAINT_VIOLATION")
            except ldb.LdbError as e:
                (code, errstr) = e.args
                self.assertEqual(code, ldb.ERR_CONSTRAINT_VIOLATION, str(e))
                werr = "%08X" % werror.WERR_DS_NAME_REFERENCE_INVALID
                self.assertIn(werr, errstr)

        #
        # First try the linked attribute 'manager'
        # by SID and GUID: both must be rejected.
        #
        assert_invalid_reference("manager", "<SID=%s>" % removed_sid)
        assert_invalid_reference("manager", "<GUID=%s>" % removed_guid)

        #
        # Try the non-linked attribute 'assistant'
        # by SID and GUID, which should work; each value is removed again.
        #
        for ref in ("<SID=%s>" % removed_sid, "<GUID=%s>" % removed_guid):
            modify_reference("assistant", ref, ldb.FLAG_MOD_ADD)
            modify_reference("assistant", ref, ldb.FLAG_MOD_DELETE)

        #
        # Finally try the non-linked attribute 'assistant'
        # but with non existing DN, SID, GUID: all must be rejected.
        #
        assert_invalid_reference("assistant", "CN=NoneNone,%s" % (basedn))
        assert_invalid_reference("assistant", "<SID=%s>" % none_sid_str)
        assert_invalid_reference("assistant", "<GUID=%s>" % none_guid_str)

        self.samdb.delete(kept_dn)

    def test_normalize_dn_in_domain_full(self):
        """A full DN string inside the domain is returned unchanged."""
        domain = self.samdb.domain_dn()

        expected = ldb.Dn(self.samdb, "CN=Users")
        expected.add_base(domain)

        # Normalizing the string form must yield the very same DN.
        normalized = self.samdb.normalize_dn_in_domain(str(expected))
        self.assertEqual(expected, normalized)

    def test_normalize_dn_in_domain_part(self):
        """A relative DN string gets the domain DN appended as its base."""
        domain = self.samdb.domain_dn()
        relative = "CN=Users"

        expected = ldb.Dn(self.samdb, relative)
        expected.add_base(domain)

        self.assertEqual(expected,
                         self.samdb.normalize_dn_in_domain(relative))

    def test_normalize_dn_in_domain_full_dn(self):
        """A full ldb.Dn object inside the domain is returned unchanged."""
        domain = self.samdb.domain_dn()

        absolute = ldb.Dn(self.samdb, "CN=Users")
        absolute.add_base(domain)

        normalized = self.samdb.normalize_dn_in_domain(absolute)
        self.assertEqual(absolute, normalized)

    def test_normalize_dn_in_domain_part_dn(self):
        """A relative ldb.Dn object gets the domain DN appended as its base."""
        domain = self.samdb.domain_dn()
        relative = ldb.Dn(self.samdb, "CN=Users")

        expected = ldb.Dn(self.samdb, "%s,%s" % (relative, domain))
        normalized = self.samdb.normalize_dn_in_domain(relative)
        self.assertEqual(expected, normalized)
Example #8
0
    def test_rid_set_dbcheck_after_seize(self):
        """Perform a join against the RID manager and assert we have a RID Set.
        We seize the RID master role, then using dbcheck, we assert that we can
        detect out of range users (and then bump the RID set as required)."""

        fsmo_dn = ldb.Dn(self.ldb_dc1, "CN=RID Manager$,CN=System," + self.ldb_dc1.domain_dn())
        (fsmo_owner, _) = self._determine_fSMORoleOwner(fsmo_dn)

        targetdir = self._test_join(fsmo_owner['dns_name'], "RIDALLOCTEST7")
        try:
            # Connect to the joined DC's local database
            ldb_url = "tdb://%s" % os.path.join(targetdir, "private/sam.ldb")
            smbconf = os.path.join(targetdir, "etc/smb.conf")

            lp = self.get_loadparm()
            new_ldb = SamDB(ldb_url, credentials=self.get_credentials(),
                            session_info=system_session(lp), lp=lp)

            def add_user_with_rid(cn, rid):
                # Add CN=<cn>,CN=Users with an explicit objectSid of
                # <domain SID>-<rid>, using the relax control.
                m = ldb.Message()
                m.dn = ldb.Dn(new_ldb, "CN=%s,CN=Users" % cn)
                m.dn.add_base(new_ldb.get_default_basedn())
                m['objectClass'] = ldb.MessageElement('user', ldb.FLAG_MOD_ADD, 'objectClass')
                sid = security.dom_sid(str(new_ldb.get_domain_sid()) + "-%d" % rid)
                m['objectSid'] = ldb.MessageElement(ndr_pack(sid),
                                                    ldb.FLAG_MOD_ADD,
                                                    'objectSid')
                new_ldb.add(m, controls=["relax:0"])

            # 1. Get server name
            res = new_ldb.search(base=ldb.Dn(new_ldb, new_ldb.get_serverName()),
                                 scope=ldb.SCOPE_BASE, attrs=["serverReference"])
            # 2. Get server reference (attribute values are bytes; ldb.Dn
            # wants text, as in DsdbTests.setUp)
            server_ref_dn = ldb.Dn(new_ldb, res[0]['serverReference'][0].decode('utf8'))

            # 3. Assert we get the RID Set
            res = new_ldb.search(base=server_ref_dn,
                                 scope=ldb.SCOPE_BASE, attrs=['rIDSetReferences'])

            self.assertIn("rIDSetReferences", res[0])
            rid_set_dn = ldb.Dn(new_ldb, res[0]["rIDSetReferences"][0].decode('utf8'))

            # 4. Seize the RID Manager role
            (result, out, err) = self.runsubcmd("fsmo", "seize", "--role", "rid", "-H", ldb_url, "-s", smbconf, "--force")
            self.assertCmdSuccess(result, out, err)
            self.assertEqual(err, "", "Shouldn't be any error messages")

            # 5. Add a new user (triggers RID set work)
            new_ldb.newuser("ridalloctestuser", "P@ssword!")

            # 6. Now fetch the RID SET
            rid_set_res = new_ldb.search(base=rid_set_dn,
                                         scope=ldb.SCOPE_BASE, attrs=['rIDNextRid',
                                                                      'rIDAllocationPool'])
            next_pool = int(rid_set_res[0]["rIDAllocationPool"][0])
            # The upper 32 bits of the pool value hold the pool's last RID.
            last_rid = (0xFFFFFFFF00000000 & next_pool) >> 32

            # 7. Add user above the ridNextRid and at almost the end of the range.
            add_user_with_rid("ridsettestuser2", last_rid - 3)

            # 8. Add user above the ridNextRid and at the end of the range
            add_user_with_rid("ridsettestuser3", last_rid)

            chk = dbcheck(new_ldb, verbose=False, fix=True, yes=True, quiet=True)

            # Should have fixed two errors (wrong ridNextRid)
            self.assertEqual(chk.check_database(DN=rid_set_dn, scope=ldb.SCOPE_BASE), 2)

            # 9. NOTE(review): the intent is to assert that dbcheck shows no
            # other errors, but check_database() is never called on this
            # instance, so nothing is verified here.
            chk = dbcheck(new_ldb, verbose=False, fix=False, quiet=True)

            # 10. Add another user (checks RID rollover)
            # We have seized the role, so we can do that.
            new_ldb.newuser("ridalloctestuser3", "P@ssword!")

            rid_set_res = new_ldb.search(base=rid_set_dn,
                                         scope=ldb.SCOPE_BASE, attrs=['rIDNextRid',
                                                                      'rIDAllocationPool'])
            next_pool = int(rid_set_res[0]["rIDAllocationPool"][0])
            self.assertNotEqual(last_rid, (0xFFFFFFFF00000000 & next_pool) >> 32, "rid pool should have changed")
        finally:
            self._test_force_demote(fsmo_owner['dns_name'], "RIDALLOCTEST7")
            shutil.rmtree(targetdir, ignore_errors=True)
Example #9
0
    def test_rid_set_dbcheck(self):
        """Perform a join against the RID manager and assert we have a RID Set.
        Using dbcheck, we assert that we can detect out of range users."""

        fsmo_dn = ldb.Dn(self.ldb_dc1, "CN=RID Manager$,CN=System," + self.ldb_dc1.domain_dn())
        (fsmo_owner, _) = self._determine_fSMORoleOwner(fsmo_dn)

        targetdir = self._test_join(fsmo_owner['dns_name'], "RIDALLOCTEST6")
        try:
            # Connect to the joined DC's local database
            ldb_url = "tdb://%s" % os.path.join(targetdir, "private/sam.ldb")

            lp = self.get_loadparm()
            new_ldb = SamDB(ldb_url, credentials=self.get_credentials(),
                            session_info=system_session(lp), lp=lp)

            # 1. Get server name
            res = new_ldb.search(base=ldb.Dn(new_ldb, new_ldb.get_serverName()),
                                 scope=ldb.SCOPE_BASE, attrs=["serverReference"])
            # 2. Get server reference (attribute values are bytes; ldb.Dn
            # wants text, as in DsdbTests.setUp)
            server_ref_dn = ldb.Dn(new_ldb, res[0]['serverReference'][0].decode('utf8'))

            # 3. Assert we get the RID Set
            res = new_ldb.search(base=server_ref_dn,
                                 scope=ldb.SCOPE_BASE, attrs=['rIDSetReferences'])

            self.assertIn("rIDSetReferences", res[0])
            rid_set_dn = ldb.Dn(new_ldb, res[0]["rIDSetReferences"][0].decode('utf8'))

            # 4. Add a new user (triggers RID set work)
            new_ldb.newuser("ridalloctestuser", "P@ssword!")

            # 5. Now fetch the RID SET
            rid_set_res = new_ldb.search(base=rid_set_dn,
                                         scope=ldb.SCOPE_BASE, attrs=['rIDNextRid',
                                                                      'rIDAllocationPool'])
            next_pool = int(rid_set_res[0]["rIDAllocationPool"][0])
            # The upper 32 bits of the pool value hold the pool's last RID.
            last_rid = (0xFFFFFFFF00000000 & next_pool) >> 32

            # 6. Add user above the ridNextRid and at mid-range.
            #
            # We can do this with safety because this is an offline DB that will be
            # destroyed.
            m = ldb.Message()
            m.dn = ldb.Dn(new_ldb, "CN=ridsettestuser1,CN=Users")
            m.dn.add_base(new_ldb.get_default_basedn())
            m['objectClass'] = ldb.MessageElement('user', ldb.FLAG_MOD_ADD, 'objectClass')
            m['objectSid'] = ldb.MessageElement(ndr_pack(security.dom_sid(str(new_ldb.get_domain_sid()) + "-%d" % (last_rid - 10))),
                                                ldb.FLAG_MOD_ADD,
                                                'objectSid')
            new_ldb.add(m, controls=["relax:0"])

            # 7. Check the RID Set
            chk = dbcheck(new_ldb, verbose=False, fix=True, yes=True, quiet=True)

            # Should have one error (wrong rIDNextRID)
            self.assertEqual(chk.check_database(DN=rid_set_dn, scope=ldb.SCOPE_BASE), 1)

            # 8. NOTE(review): the intent is to assert that dbcheck shows no
            # other errors, but check_database() is never called on this
            # instance, so nothing is verified here.
            chk = dbcheck(new_ldb, verbose=False, fix=False, quiet=True)

            rid_set_res = new_ldb.search(base=rid_set_dn,
                                         scope=ldb.SCOPE_BASE, attrs=['rIDNextRid',
                                                                      'rIDAllocationPool'])
            last_allocated_rid = int(rid_set_res[0]["rIDNextRid"][0])
            self.assertEqual(last_allocated_rid, last_rid - 10)

            # 9. Assert that the range wasn't thrown away

            next_pool = int(rid_set_res[0]["rIDAllocationPool"][0])
            self.assertEqual(last_rid, (0xFFFFFFFF00000000 & next_pool) >> 32, "rid pool should not have changed")
        finally:
            self._test_force_demote(fsmo_owner['dns_name'], "RIDALLOCTEST6")
            shutil.rmtree(targetdir, ignore_errors=True)
Example #10
0
class DsdbTests(TestCase):
    def setUp(self):
        super(DsdbTests, self).setUp()
        # Open the local SamDB as the system session with guessed credentials.
        self.lp = samba.tests.env_loadparm()
        self.creds = Credentials()
        self.creds.guess(self.lp)
        self.session = system_session()
        self.samdb = SamDB(session_info=self.session,
                           credentials=self.creds,
                           lp=self.lp)

        # Create a uniquely named test user; it is force-deleted on cleanup.
        suffix = uuid.uuid4().hex[:6]
        user_name = "dsdb-user-" + suffix
        user_pass = samba.generate_random_password(32, 32)

        self.account_dn = "CN=%s,CN=Users,%s" % (user_name,
                                                 self.samdb.domain_dn())
        self.samdb.newuser(username=user_name,
                           password=user_pass,
                           description="Test user for dsdb test")
        self.addCleanup(delete_force, self.samdb, self.account_dn)

        # Follow this server's serverReference to its account object ...
        server_dn = ldb.Dn(self.samdb, self.samdb.get_serverName())
        res = self.samdb.search(base=server_dn,
                                scope=ldb.SCOPE_BASE,
                                attrs=["serverReference"])
        server_ref = res[0]["serverReference"][0].decode("utf-8")
        self.server_ref_dn = ldb.Dn(self.samdb, server_ref)

        # ... and from there to the RID Set it references.
        res = self.samdb.search(base=self.server_ref_dn,
                                scope=ldb.SCOPE_BASE,
                                attrs=["rIDSetReferences"])
        ref_msg = res[0]
        self.assertIn("rIDSetReferences", ref_msg)
        self.rid_set_dn = ldb.Dn(
            self.samdb, ref_msg["rIDSetReferences"][0].decode("utf-8"))

    def get_rid_set(self, rid_set_dn):
        """Return the message for *rid_set_dn* with its RID pool attributes.

        The attributes requested are rIDAllocationPool,
        rIDPreviousAllocationPool, rIDUsedPool and rIDNextRID.
        """
        wanted = ["rIDAllocationPool",
                  "rIDPreviousAllocationPool",
                  "rIDUsedPool",
                  "rIDNextRID"]
        found = self.samdb.search(base=rid_set_dn,
                                  scope=ldb.SCOPE_BASE,
                                  attrs=wanted)
        return found[0]

    def test_ridalloc_next_free_rid(self):
        """next_free_rid() is rIDNextRID + 1 and leaves the RID Set untouched.

        We assume that RID pools allocated to us are contiguous.  All changes
        happen inside a transaction that is cancelled at the end.
        """
        self.samdb.transaction_start()
        try:
            before = self.get_rid_set(self.rid_set_dn)
            for attr in ("rIDAllocationPool", "rIDPreviousAllocationPool",
                         "rIDUsedPool", "rIDNextRID"):
                self.assertIn(attr, before)

            current = int(before["rIDNextRID"][0])

            first = self.samdb.next_free_rid()
            self.assertEqual(current + 1, first)

            # Asking again must give the same answer (nothing is consumed).
            second = self.samdb.next_free_rid()
            self.assertEqual(first, second)

            # And the stored RID Set must be unchanged.
            after = self.get_rid_set(self.rid_set_dn)
            self.assertEqual(before, after)
        finally:
            self.samdb.transaction_cancel()

    def test_ridalloc_no_ridnextrid(self):
        """Without rIDNextRID, RIDs come from the start of rIDAllocationPool."""
        self.samdb.transaction_start()
        try:
            prev_lo, prev_hi = 1000, 1999
            next_lo, next_hi = 3000, 3999

            # Remove rIDNextRID and install known previous/current pools,
            # each encoded as (high << 32) | low.
            change = ldb.Message()
            change.dn = self.rid_set_dn
            change["rIDNextRID"] = ldb.MessageElement(
                [], ldb.FLAG_MOD_DELETE, "rIDNextRID")
            change["rIDPreviousAllocationPool"] = ldb.MessageElement(
                str((prev_hi << 32) | prev_lo), ldb.FLAG_MOD_REPLACE,
                "rIDPreviousAllocationPool")
            change["rIDAllocationPool"] = ldb.MessageElement(
                str((next_hi << 32) | next_lo), ldb.FLAG_MOD_REPLACE,
                "rIDAllocationPool")
            self.samdb.modify(change)

            first_free = self.samdb.next_free_rid()
            self.assertEqual(next_lo, first_free)

            # allocate_rid() hands out exactly that RID ...
            allocated = self.samdb.allocate_rid()
            self.assertEqual(first_free, allocated)

            # ... after which the next free RID moves on by one.
            self.assertEqual(allocated + 1, self.samdb.next_free_rid())

            low, high = self.samdb.free_rid_bounds()
            self.assertEqual(allocated + 1, low)
            self.assertEqual(next_hi, high)
        finally:
            self.samdb.transaction_cancel()

    def test_ridalloc_no_free_rids(self):
        """With the pools exhausted next_free_rid() fails; allocate_rid() works."""
        self.samdb.transaction_start()
        try:
            lo, hi = 2000, 2999
            pool = str((hi << 32) | lo)

            # Use the same range for both pools and point rIDNextRID at its end.
            exhaust = ldb.Message()
            exhaust.dn = self.rid_set_dn
            exhaust["rIDPreviousAllocationPool"] = ldb.MessageElement(
                pool, ldb.FLAG_MOD_REPLACE, "rIDPreviousAllocationPool")
            exhaust["rIDAllocationPool"] = ldb.MessageElement(
                pool, ldb.FLAG_MOD_REPLACE, "rIDAllocationPool")
            exhaust["rIDNextRID"] = ldb.MessageElement(
                str(hi), ldb.FLAG_MOD_REPLACE, "rIDNextRID")
            self.samdb.modify(exhaust)

            with self.assertRaises(ldb.LdbError) as ctx:
                self.samdb.next_free_rid()
            self.assertEqual("RID pools out of RIDs", ctx.exception.args[1])

            # Allocating a new RID is still expected to succeed.
            self.samdb.allocate_rid()
        finally:
            self.samdb.transaction_cancel()

    def test_ridalloc_new_ridset(self):
        """A zeroed RID Set has no free RID until rIDAllocationPool is set."""
        self.samdb.transaction_start()
        try:
            # Zero all three values (similar to a freshly created RID Set,
            # except that rIDAllocationPool is zeroed as well).
            zeroed = ldb.Message()
            zeroed.dn = self.rid_set_dn
            for attr in ("rIDPreviousAllocationPool", "rIDAllocationPool",
                         "rIDNextRID"):
                zeroed[attr] = ldb.MessageElement("0", ldb.FLAG_MOD_REPLACE,
                                                  attr)
            self.samdb.modify(zeroed)

            with self.assertRaises(ldb.LdbError) as ctx:
                self.samdb.next_free_rid()
            self.assertEqual("RID pools out of RIDs", ctx.exception.args[1])

            # Now hand out a pool.
            lo, hi = 2000, 2999
            assign = ldb.Message()
            assign.dn = self.rid_set_dn
            assign["rIDAllocationPool"] = ldb.MessageElement(
                str((hi << 32) | lo), ldb.FLAG_MOD_REPLACE,
                "rIDAllocationPool")
            self.samdb.modify(assign)

            # The next free RID is the pool's lower bound ...
            self.assertEqual(lo, self.samdb.next_free_rid())

            # ... and the whole pool is available.
            free_lo, free_hi = self.samdb.free_rid_bounds()
            self.assertEqual(lo, free_lo)
            self.assertEqual(hi, free_hi)
        finally:
            self.samdb.transaction_cancel()

    def test_ridalloc_move_to_new_pool(self):
        """Allocation uses up the previous pool before moving to the next."""
        self.samdb.transaction_start()
        try:
            old_lo, old_hi = 2000, 2999
            new_lo, new_hi = 4500, 4599

            # Previous pool with rIDNextRID one below its end, plus a
            # different current pool.
            setup = ldb.Message()
            setup.dn = self.rid_set_dn
            setup["rIDPreviousAllocationPool"] = ldb.MessageElement(
                str((old_hi << 32) | old_lo), ldb.FLAG_MOD_REPLACE,
                "rIDPreviousAllocationPool")
            setup["rIDAllocationPool"] = ldb.MessageElement(
                str((new_hi << 32) | new_lo), ldb.FLAG_MOD_REPLACE,
                "rIDAllocationPool")
            setup["rIDNextRID"] = ldb.MessageElement(
                str(old_hi - 1), ldb.FLAG_MOD_REPLACE, "rIDNextRID")
            self.samdb.modify(setup)

            # Still inside the previous pool: only old_hi is left.
            last_in_old = self.samdb.next_free_rid()
            self.assertEqual(old_hi, last_in_old)

            low, high = self.samdb.free_rid_bounds()
            self.assertEqual(old_hi, low)
            self.assertEqual(old_hi, high)

            taken = self.samdb.allocate_rid()
            self.assertEqual(last_in_old, taken)

            # With the previous pool used up, we move to the current one.
            first_in_new = self.samdb.next_free_rid()
            self.assertEqual(new_lo, first_in_new)

            low2, high2 = self.samdb.free_rid_bounds()
            self.assertEqual(new_lo, low2)
            self.assertEqual(new_hi, high2)

            taken2 = self.samdb.allocate_rid()
            self.assertEqual(first_in_new, taken2)
        finally:
            self.samdb.transaction_cancel()

    def test_ridalloc_no_ridsetreferences(self):
        """Without rIDSetReferences both RID calls fail with specific errors."""
        self.samdb.transaction_start()
        try:
            # Strip rIDSetReferences from the server reference object.
            drop = ldb.Message()
            drop.dn = self.server_ref_dn
            drop["rIDSetReferences"] = ldb.MessageElement(
                [], ldb.FLAG_MOD_DELETE, "rIDSetReferences")
            self.samdb.modify(drop)

            # next_free_rid() cannot find the RID Set DN.
            with self.assertRaises(ldb.LdbError) as ctx:
                self.samdb.next_free_rid()
            code, text = ctx.exception.args
            self.assertEqual(ldb.ERR_NO_SUCH_ATTRIBUTE, code)
            self.assertIn("No RID Set DN - "
                          "Cannot find attribute rIDSetReferences of %s "
                          "to calculate reference dn" % self.server_ref_dn,
                          text)

            # allocate_rid() fails because the RID Set entry already exists.
            with self.assertRaises(ldb.LdbError) as ctx:
                self.samdb.allocate_rid()
            code, text = ctx.exception.args
            self.assertEqual(ldb.ERR_ENTRY_ALREADY_EXISTS, code)
            self.assertIn("No RID Set DN - "
                          "Failed to add RID Set %s - "
                          "Entry %s already exists" % (self.rid_set_dn,
                                                       self.rid_set_dn),
                          text)
        finally:
            self.samdb.transaction_cancel()

    def test_ridalloc_no_rid_set(self):
        """A rIDSetReferences value that is not a RID Set makes both calls fail."""
        self.samdb.transaction_start()
        try:
            # Point rIDSetReferences at the test user instead of a RID Set.
            fake_rid_set_str = self.account_dn
            repoint = ldb.Message()
            repoint.dn = self.server_ref_dn
            repoint["rIDSetReferences"] = ldb.MessageElement(
                fake_rid_set_str, ldb.FLAG_MOD_REPLACE, "rIDSetReferences")
            self.samdb.modify(repoint)

            # next_free_rid() and then allocate_rid() must fail identically.
            for call in (self.samdb.next_free_rid, self.samdb.allocate_rid):
                with self.assertRaises(ldb.LdbError) as ctx:
                    call()
                code, text = ctx.exception.args
                self.assertEqual(ldb.ERR_OPERATIONS_ERROR, code)
                self.assertIn("Bad RID Set " + fake_rid_set_str, text)
        finally:
            self.samdb.transaction_cancel()

    def test_get_oid_from_attrid(self):
        """ATTID 591614 maps to OID 1.2.840.113556.1.4.1790."""
        expected = "1.2.840.113556.1.4.1790"
        self.assertEqual(self.samdb.get_oid_from_attid(591614), expected)

    def test_error_replpropertymetadata(self):
        """Bumping only the description version in replPropertyMetaData and
        writing the blob back with the local_oid ...4.3.14 control must fail.
        """
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        found = False
        for o in repl.ctr.array:
            # attid 13 is 'description' (set on the user created in setUp)
            if o.attid == 13:
                o.version = o.version + 1
                found = True
        # Without this the test could pass vacuously on an unmodified blob.
        self.assertTrue(found, "no replPropertyMetaData entry for description")
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        self.assertRaises(ldb.LdbError, self.samdb.modify, msg,
                          ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_error_replpropertymetadata_nochange(self):
        """Writing back an unmodified replPropertyMetaData blob with only the
        local_oid ...4.3.14 control must fail."""
        found = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                  base=self.account_dn,
                                  attrs=["replPropertyMetaData"])
        entry = found[0]
        # Round-trip the blob through NDR without touching it.
        decoded = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                             entry["replPropertyMetaData"][0])
        update = ldb.Message()
        update.dn = entry.dn
        update["replPropertyMetaData"] = ldb.MessageElement(
            ndr_pack(decoded), ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        with self.assertRaises(ldb.LdbError):
            self.samdb.modify(update, ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_error_replpropertymetadata_allow_sort(self):
        """An unmodified replPropertyMetaData blob is accepted when the
        local_oid ...4.3.25 control is passed alongside ...4.3.14."""
        found = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                  base=self.account_dn,
                                  attrs=["replPropertyMetaData"])
        entry = found[0]
        decoded = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                             entry["replPropertyMetaData"][0])
        update = ldb.Message()
        update.dn = entry.dn
        update["replPropertyMetaData"] = ldb.MessageElement(
            ndr_pack(decoded), ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        controls = ["local_oid:1.3.6.1.4.1.7165.4.3.14:0",
                    "local_oid:1.3.6.1.4.1.7165.4.3.25:0"]
        self.samdb.modify(update, controls)

    def test_twoatt_replpropertymetadata(self):
        """Replacing replPropertyMetaData together with another attribute
        (description) under the local_oid ...4.3.14 control must fail."""
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData", "uSNChanged"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        next_usn = int(str(res[0]["uSNChanged"])) + 1
        found = False
        for o in repl.ctr.array:
            # attid 13 is 'description' (set on the user created in setUp)
            if o.attid == 13:
                o.version = o.version + 1
                o.local_usn = next_usn
                found = True
        # Without this the test could pass for an unrelated reason.
        self.assertTrue(found, "no replPropertyMetaData entry for description")
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        msg["description"] = ldb.MessageElement("new val",
                                                ldb.FLAG_MOD_REPLACE,
                                                "description")
        self.assertRaises(ldb.LdbError, self.samdb.modify, msg,
                          ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_set_replpropertymetadata(self):
        """A replPropertyMetaData update is accepted under the local_oid
        ...4.3.14 control when the description entry has its version bumped
        and both local and originating USN set to uSNChanged + 1."""
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData", "uSNChanged"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        next_usn = int(str(res[0]["uSNChanged"])) + 1
        found = False
        for o in repl.ctr.array:
            # attid 13 is 'description' (set on the user created in setUp)
            if o.attid == 13:
                o.version = o.version + 1
                o.local_usn = next_usn
                o.originating_usn = next_usn
                found = True
        # Ensure we actually changed something before writing it back.
        self.assertTrue(found, "no replPropertyMetaData entry for description")
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        self.samdb.modify(msg, ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_ok_get_attribute_from_attid(self):
        """ATTID 13 resolves to the 'description' attribute."""
        name = self.samdb.get_attribute_from_attid(13)
        self.assertEqual(name, "description")

    def test_ko_get_attribute_from_attid(self):
        """ATTID 11979 is expected to resolve to no attribute (None)."""
        # assertIsNone checks identity with None and reports the actual value.
        self.assertIsNone(self.samdb.get_attribute_from_attid(11979))

    def test_get_attribute_replmetadata_version(self):
        # The test account's unicodePwd metadata entry is expected at version 2.
        entries = self.samdb.search(base=self.account_dn,
                                    scope=ldb.SCOPE_SUBTREE,
                                    attrs=["dn"])
        self.assertEqual(len(entries), 1)
        account_dn = str(entries[0].dn)
        pwd_version = self.samdb.get_attribute_replmetadata_version(
            account_dn, "unicodePwd")
        self.assertEqual(pwd_version, 2)

    def test_set_attribute_replmetadata_version(self):
        # Raising the description metadata version by two must be readable
        # back through get_attribute_replmetadata_version().
        entries = self.samdb.search(base=self.account_dn,
                                    scope=ldb.SCOPE_SUBTREE,
                                    attrs=["dn"])
        self.assertEqual(len(entries), 1)
        account_dn = str(entries[0].dn)
        wanted = self.samdb.get_attribute_replmetadata_version(
            account_dn, "description") + 2
        self.samdb.set_attribute_replmetadata_version(account_dn,
                                                      "description",
                                                      wanted)
        current = self.samdb.get_attribute_replmetadata_version(
            account_dn, "description")
        self.assertEqual(current, wanted)

    def test_no_error_on_invalid_control(self):
        # An unimplemented control passed with a trailing ":0" (the
        # non-critical counterpart of the ":1" case tested below) must not
        # make the search fail.
        control = "local_oid:%s:0" % dsdb.DSDB_CONTROL_INVALID_NOT_IMPLEMENTED
        try:
            self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                              base=self.account_dn,
                              attrs=["replPropertyMetaData"],
                              controls=[control])
        except ldb.LdbError as e:
            self.fail("Should have not raised an exception: %s" % e)

    def test_error_on_invalid_critical_control(self):
        # The same unimplemented control passed with a trailing ":1" must be
        # rejected with ERR_UNSUPPORTED_CRITICAL_EXTENSION.
        control = "local_oid:%s:1" % dsdb.DSDB_CONTROL_INVALID_NOT_IMPLEMENTED
        try:
            self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                              base=self.account_dn,
                              attrs=["replPropertyMetaData"],
                              controls=[control])
        except ldb.LdbError as e:
            (code, estr) = e.args
            if code != ldb.ERR_UNSUPPORTED_CRITICAL_EXTENSION:
                # e.args must be used here: LdbError is not subscriptable
                # on Python 3, so the old e[1] raised TypeError instead.
                self.fail(
                    "Got %s should have got ERR_UNSUPPORTED_CRITICAL_EXTENSION"
                    % estr)
        else:
            # Previously a search that succeeded let this test pass.
            self.fail("No exception should get "
                      "ERR_UNSUPPORTED_CRITICAL_EXTENSION")

    # Allocate a unique RID for the objectSID tests, inside its own
    # transaction, and return it as a string.
    def allocate_rid(self):
        samdb = self.samdb
        samdb.transaction_start()
        try:
            new_rid = samdb.allocate_rid()
        except BaseException:
            # Abandon the transaction on any failure, then propagate it.
            samdb.transaction_cancel()
            raise
        else:
            samdb.transaction_commit()
        return str(new_rid)

    # Ensure that duplicate objectSID's are permitted for foreign security
    # principals.
    #
    def test_duplicate_objectSIDs_allowed_on_foreign_security_principals(self):

        #
        # We need to build a foreign security principal SID
        # i.e. a SID not in the current domain.
        #
        dom_sid = self.samdb.get_domain_sid()
        # Replace the last digit of the domain SID with a different digit so
        # the resulting SID lies outside the current domain.
        if str(dom_sid).endswith("0"):
            c = "9"
        else:
            c = "0"
        sid_str = str(dom_sid)[:-1] + c + "-1000"
        sid = ndr_pack(security.dom_sid(sid_str))
        basedn = self.samdb.get_default_basedn()
        dn = "CN=%s,CN=ForeignSecurityPrincipals,%s" % (sid_str, basedn)

        # Remove the principal even if an assertion below fails part-way,
        # so a rerun does not trip over a leftover object.
        self.addCleanup(delete_force, self.samdb, dn)

        #
        # First without control
        #

        try:
            self.samdb.add({
                "dn": dn,
                "objectClass": "foreignSecurityPrincipal"
            })
            self.fail("No exception should get ERR_OBJECT_CLASS_VIOLATION")
        except ldb.LdbError as e:
            (code, msg) = e.args
            self.assertEqual(code, ldb.ERR_OBJECT_CLASS_VIOLATION, str(e))
            werr = "%08X" % werror.WERR_DS_MISSING_REQUIRED_ATT
            self.assertIn(werr, msg, msg)

        try:
            self.samdb.add({
                "dn": dn,
                "objectClass": "foreignSecurityPrincipal",
                "objectSid": sid
            })
            self.fail("No exception should get ERR_UNWILLING_TO_PERFORM")
        except ldb.LdbError as e:
            (code, msg) = e.args
            self.assertEqual(code, ldb.ERR_UNWILLING_TO_PERFORM, str(e))
            werr = "%08X" % werror.WERR_DS_ILLEGAL_MOD_OPERATION
            self.assertIn(werr, msg, msg)

        #
        # We need to use the provision control
        # in order to add foreignSecurityPrincipal
        # objects
        #

        controls = ["provision:0"]
        self.samdb.add({
            "dn": dn,
            "objectClass": "foreignSecurityPrincipal"
        }, controls=controls)

        self.samdb.delete(dn)

        # Re-adding the same principal after deleting it must succeed.
        try:
            self.samdb.add(
                {
                    "dn": dn,
                    "objectClass": "foreignSecurityPrincipal"
                },
                controls=controls)
        except ldb.LdbError as e:
            (code, msg) = e.args
            self.fail("Got unexpected exception %d - %s " % (code, msg))

        # cleanup
        self.samdb.delete(dn)

    def _test_foreignSecurityPrincipal(self, obj_class, fpo_attr):
        # Add "<SID=...>" references to fpo_attr on a fresh obj_class object
        # for three SIDs that match no existing object:
        # - a local-domain SID, which is rejected with WERR_DS_INVALID_GROUP_TYPE;
        # - an S-1-5-32 SID, which is rejected with WERR_NO_SUCH_MEMBER;
        # - an S-1-5 SID, which is accepted, after which exactly one object
        #   carries that SID.
        # Both objects are deleted at the end.

        dom_sid = self.samdb.get_domain_sid()
        lsid_str = str(dom_sid) + "-4294967294"
        bsid_str = "S-1-5-32-4294967294"
        fsid_str = "S-1-5-4294967294"
        basedn = self.samdb.get_default_basedn()
        cn = "dsdb_test_fpo"
        dn_str = "cn=%s,cn=Users,%s" % (cn, basedn)
        dn = ldb.Dn(self.samdb, dn_str)

        def search_by_sid(sid_str):
            # Return all (non-deleted) objects under basedn with this SID.
            return self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                     base=basedn,
                                     expression="(objectSid=%s)" % sid_str,
                                     attrs=[])

        # None of the SIDs may be in use before the test starts.
        for sid_str in (lsid_str, bsid_str, fsid_str):
            self.assertEqual(len(search_by_sid(sid_str)), 0, sid_str)

        self.addCleanup(delete_force, self.samdb, dn_str)

        def remove_foreign_sid_object():
            # If the test stops before its own deletes, remove the object
            # carrying the foreign SID so reruns pass the check above.
            for leftover in search_by_sid(fsid_str):
                delete_force(self.samdb, str(leftover.dn))

        # Registered last, so it runs first (cleanups are LIFO), matching
        # the order of the explicit deletes at the end of the test.
        self.addCleanup(remove_foreign_sid_object)

        self.samdb.add({"dn": dn_str, "objectClass": obj_class})

        msg = ldb.Message()
        msg.dn = dn
        msg[fpo_attr] = ldb.MessageElement("<SID=%s>" % lsid_str,
                                           ldb.FLAG_MOD_ADD, fpo_attr)
        try:
            self.samdb.modify(msg)
            self.fail("No exception should get LDB_ERR_UNWILLING_TO_PERFORM")
        except ldb.LdbError as e:
            (code, estr) = e.args
            self.assertEqual(code, ldb.ERR_UNWILLING_TO_PERFORM, str(e))
            werr = "%08X" % werror.WERR_DS_INVALID_GROUP_TYPE
            self.assertIn(werr, estr, estr)

        msg = ldb.Message()
        msg.dn = dn
        msg[fpo_attr] = ldb.MessageElement("<SID=%s>" % bsid_str,
                                           ldb.FLAG_MOD_ADD, fpo_attr)
        try:
            self.samdb.modify(msg)
            self.fail("No exception should get LDB_ERR_NO_SUCH_OBJECT")
        except ldb.LdbError as e:
            (code, estr) = e.args
            self.assertEqual(code, ldb.ERR_NO_SUCH_OBJECT, str(e))
            werr = "%08X" % werror.WERR_NO_SUCH_MEMBER
            self.assertIn(werr, estr, estr)

        msg = ldb.Message()
        msg.dn = dn
        msg[fpo_attr] = ldb.MessageElement("<SID=%s>" % fsid_str,
                                           ldb.FLAG_MOD_ADD, fpo_attr)
        try:
            self.samdb.modify(msg)
        except ldb.LdbError as e:
            self.fail("Should have not raised an exception: %s" % e)

        res = search_by_sid(fsid_str)
        self.assertEqual(len(res), 1)
        self.samdb.delete(res[0].dn)
        self.samdb.delete(dn)
        self.assertEqual(len(search_by_sid(fsid_str)), 0)

    def test_foreignSecurityPrincipal_member(self):
        # SID-reference checks for 'member' on a group object.
        self._test_foreignSecurityPrincipal(obj_class="group",
                                            fpo_attr="member")

    def test_foreignSecurityPrincipal_MembersForAzRole(self):
        # SID-reference checks for 'msDS-MembersForAzRole' on an msDS-AzRole.
        self._test_foreignSecurityPrincipal(obj_class="msDS-AzRole",
                                            fpo_attr="msDS-MembersForAzRole")

    def test_foreignSecurityPrincipal_NeverRevealGroup(self):
        # SID-reference checks for 'msDS-NeverRevealGroup' on a computer.
        self._test_foreignSecurityPrincipal(obj_class="computer",
                                            fpo_attr="msDS-NeverRevealGroup")

    def test_foreignSecurityPrincipal_RevealOnDemandGroup(self):
        # SID-reference checks for 'msDS-RevealOnDemandGroup' on a computer.
        self._test_foreignSecurityPrincipal(
            obj_class="computer",
            fpo_attr="msDS-RevealOnDemandGroup")

    def _test_fail_foreignSecurityPrincipal(self,
                                            obj_class,
                                            fpo_attr,
                                            msg_exp,
                                            lerr_exp,
                                            werr_exp,
                                            allow_reference=True):
        # Check that fpo_attr on an obj_class object refuses "<SID=...>"
        # references to a local-domain, an S-1-5-32 and an S-1-5 SID (none of
        # which exist), each with LDB error lerr_exp and WERROR werr_exp.
        # msg_exp names the expected error in failure messages.
        # A plain DN reference to a second obj_class object must then be
        # accepted if allow_reference is true, or refused the same way if not.

        dom_sid = self.samdb.get_domain_sid()
        lsid_str = str(dom_sid) + "-4294967294"
        bsid_str = "S-1-5-32-4294967294"
        fsid_str = "S-1-5-4294967294"
        basedn = self.samdb.get_default_basedn()
        cn1 = "dsdb_test_fpo1"
        dn1_str = "cn=%s,cn=Users,%s" % (cn1, basedn)
        dn1 = ldb.Dn(self.samdb, dn1_str)
        cn2 = "dsdb_test_fpo2"
        dn2_str = "cn=%s,cn=Users,%s" % (cn2, basedn)
        dn2 = ldb.Dn(self.samdb, dn2_str)
        sid_strs = (lsid_str, bsid_str, fsid_str)
        werr = "%08X" % werr_exp

        # None of the SIDs may be in use before the test starts.
        for sid_str in sid_strs:
            res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                    base=basedn,
                                    expression="(objectSid=%s)" % sid_str,
                                    attrs=[])
            self.assertEqual(len(res), 0, sid_str)

        self.addCleanup(delete_force, self.samdb, dn1_str)
        self.addCleanup(delete_force, self.samdb, dn2_str)

        self.samdb.add({"dn": dn1_str, "objectClass": obj_class})

        self.samdb.add({"dn": dn2_str, "objectClass": obj_class})

        for sid_str in sid_strs:
            msg = ldb.Message()
            msg.dn = dn1
            msg[fpo_attr] = ldb.MessageElement("<SID=%s>" % sid_str,
                                               ldb.FLAG_MOD_ADD, fpo_attr)
            try:
                self.samdb.modify(msg)
                # Report the expected error (msg_exp), not the Message object.
                self.fail("No exception should get %s" % msg_exp)
            except ldb.LdbError as e:
                (code, estr) = e.args
                self.assertEqual(code, lerr_exp, str(e))
                self.assertIn(werr, estr, estr)

        msg = ldb.Message()
        msg.dn = dn1
        msg[fpo_attr] = ldb.MessageElement("%s" % dn2, ldb.FLAG_MOD_ADD,
                                           fpo_attr)
        try:
            self.samdb.modify(msg)
            if not allow_reference:
                self.fail("No exception should get %s" % msg_exp)
        except ldb.LdbError as e:
            if allow_reference:
                self.fail("Should have not raised an exception: %s" % e)
            (code, estr) = e.args
            self.assertEqual(code, lerr_exp, str(e))
            self.assertIn(werr, estr, estr)

        self.samdb.delete(dn2)
        self.samdb.delete(dn1)

    def test_foreignSecurityPrincipal_NonMembers(self):
        # 'msDS-NonMembers' on a group refuses SID and DN references alike.
        self._test_fail_foreignSecurityPrincipal(
            obj_class="group",
            fpo_attr="msDS-NonMembers",
            msg_exp="LDB_ERR_UNWILLING_TO_PERFORM/WERR_NOT_SUPPORTED",
            lerr_exp=ldb.ERR_UNWILLING_TO_PERFORM,
            werr_exp=werror.WERR_NOT_SUPPORTED,
            allow_reference=False)

    def test_foreignSecurityPrincipal_HostServiceAccount(self):
        # 'msDS-HostServiceAccount' on a computer refuses SID references.
        self._test_fail_foreignSecurityPrincipal(
            obj_class="computer",
            fpo_attr="msDS-HostServiceAccount",
            msg_exp="LDB_ERR_CONSTRAINT_VIOLATION/"
                    "WERR_DS_NAME_REFERENCE_INVALID",
            lerr_exp=ldb.ERR_CONSTRAINT_VIOLATION,
            werr_exp=werror.WERR_DS_NAME_REFERENCE_INVALID)

    def test_foreignSecurityPrincipal_manager(self):
        # 'manager' on a user refuses SID references.
        self._test_fail_foreignSecurityPrincipal(
            obj_class="user",
            fpo_attr="manager",
            msg_exp="LDB_ERR_CONSTRAINT_VIOLATION/"
                    "WERR_DS_NAME_REFERENCE_INVALID",
            lerr_exp=ldb.ERR_CONSTRAINT_VIOLATION,
            werr_exp=werror.WERR_DS_NAME_REFERENCE_INVALID)

    #
    # Duplicate objectSID's should not be permitted for sids in the local
    # domain. The test sequence is add an object, delete it, then attempt to
    # re-add it, this should fail with a constraint violation
    #
    def test_duplicate_objectSIDs_not_allowed_on_local_objects(self):

        dom_sid = self.samdb.get_domain_sid()
        rid = self.allocate_rid()
        sid_str = str(dom_sid) + "-" + rid
        sid = ndr_pack(security.dom_sid(sid_str))
        basedn = self.samdb.get_default_basedn()
        cn = "dsdb_test_01"
        dn = "cn=%s,cn=Users,%s" % (cn, basedn)

        # Remove the user even if the re-add below unexpectedly succeeds,
        # so a rerun can create it again.
        self.addCleanup(delete_force, self.samdb, dn)

        self.samdb.add({"dn": dn, "objectClass": "user", "objectSID": sid})
        self.samdb.delete(dn)

        try:
            self.samdb.add({"dn": dn, "objectClass": "user", "objectSID": sid})
            self.fail("No exception should get LDB_ERR_CONSTRAINT_VIOLATION")
        except ldb.LdbError as e:
            (code, msg) = e.args
            if code != ldb.ERR_CONSTRAINT_VIOLATION:
                self.fail("Got %d - %s should have got "
                          "LDB_ERR_CONSTRAINT_VIOLATION" % (code, msg))

    def test_linked_vs_non_linked_reference(self):
        # References to a deleted object:
        # - the linked attribute 'manager' refuses them by SID and by GUID;
        # - the non-linked attribute 'assistant' accepts them.
        # 'assistant' refuses a DN, SID or GUID that never existed.
        basedn = self.samdb.get_default_basedn()
        kept_dn_str = "cn=reference_kept,cn=Users,%s" % (basedn)
        removed_dn_str = "cn=reference_removed,cn=Users,%s" % (basedn)
        dom_sid = self.samdb.get_domain_sid()
        none_sid_str = str(dom_sid) + "-4294967294"
        none_guid_str = "afafafaf-fafa-afaf-fafa-afafafafafaf"

        self.addCleanup(delete_force, self.samdb, kept_dn_str)
        self.addCleanup(delete_force, self.samdb, removed_dn_str)

        self.samdb.add({"dn": kept_dn_str, "objectClass": "user"})
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=kept_dn_str,
                                attrs=["objectGUID", "objectSID"])
        self.assertEqual(len(res), 1)
        kept_dn = res[0].dn

        self.samdb.add({"dn": removed_dn_str, "objectClass": "user"})
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=removed_dn_str,
                                attrs=["objectGUID", "objectSID"])
        self.assertEqual(len(res), 1)
        removed_guid = ndr_unpack(misc.GUID, res[0]["objectGUID"][0])
        removed_sid = ndr_unpack(security.dom_sid, res[0]["objectSid"][0])
        self.samdb.delete(removed_dn_str)

        def modify_kept(attr, value, flag):
            # Apply a single-value change of attr on the kept object.
            msg = ldb.Message()
            msg.dn = kept_dn
            msg[attr] = ldb.MessageElement(value, flag, attr)
            self.samdb.modify(msg)

        def assert_name_reference_invalid(attr, value):
            # Adding value to attr must fail with a constraint violation
            # carrying WERR_DS_NAME_REFERENCE_INVALID.
            try:
                modify_kept(attr, value, ldb.FLAG_MOD_ADD)
                self.fail(
                    "No exception should get LDB_ERR_CONSTRAINT_VIOLATION")
            except ldb.LdbError as e:
                (code, estr) = e.args
                self.assertEqual(code, ldb.ERR_CONSTRAINT_VIOLATION, str(e))
                werr = "%08X" % werror.WERR_DS_NAME_REFERENCE_INVALID
                self.assertIn(werr, estr, estr)

        #
        # First try the linked attribute 'manager'
        # by GUID and SID
        #
        assert_name_reference_invalid("manager", "<SID=%s>" % removed_sid)
        assert_name_reference_invalid("manager", "<GUID=%s>" % removed_guid)

        #
        # Try the non-linked attribute 'assistant'
        # by GUID and SID, which should work.
        # Each value is removed again straight after it is added.
        #
        for ref in ("<SID=%s>" % removed_sid, "<GUID=%s>" % removed_guid):
            modify_kept("assistant", ref, ldb.FLAG_MOD_ADD)
            modify_kept("assistant", ref, ldb.FLAG_MOD_DELETE)

        #
        # Finally try the non-linked attribute 'assistant'
        # but with non existing GUID, SID, DN
        #
        assert_name_reference_invalid("assistant",
                                      "CN=NoneNone,%s" % (basedn))
        assert_name_reference_invalid("assistant", "<SID=%s>" % none_sid_str)
        assert_name_reference_invalid("assistant",
                                      "<GUID=%s>" % none_guid_str)

        self.samdb.delete(kept_dn)

    def test_normalize_dn_in_domain_full(self):
        # A DN string that already includes the domain DN comes back as the
        # same DN.
        domain_dn = self.samdb.domain_dn()

        # add_base() extends the Dn in place; build the full DN directly
        # rather than through an alias that suggests a separate copy.
        full_dn = ldb.Dn(self.samdb, "CN=Users")
        full_dn.add_base(domain_dn)

        full_str = str(full_dn)

        # That is, no change
        self.assertEqual(full_dn, self.samdb.normalize_dn_in_domain(full_str))

    def test_normalize_dn_in_domain_part(self):
        # A relative DN string gets the domain DN appended.
        domain_dn = self.samdb.domain_dn()
        relative_str = "CN=Users"

        expected = ldb.Dn(self.samdb, relative_str)
        expected.add_base(domain_dn)

        normalized = self.samdb.normalize_dn_in_domain(relative_str)
        self.assertEqual(expected, normalized)

    def test_normalize_dn_in_domain_full_dn(self):
        # An ldb.Dn that already includes the domain DN comes back as the
        # same DN.
        domain_dn = self.samdb.domain_dn()

        # add_base() extends the Dn in place; build the full DN directly
        # rather than through an alias that suggests a separate copy.
        full_dn = ldb.Dn(self.samdb, "CN=Users")
        full_dn.add_base(domain_dn)

        # That is, no change
        self.assertEqual(full_dn, self.samdb.normalize_dn_in_domain(full_dn))

    def test_normalize_dn_in_domain_part_dn(self):
        # A relative ldb.Dn gets the domain DN appended.
        domain_dn = self.samdb.domain_dn()
        relative_dn = ldb.Dn(self.samdb, "CN=Users")

        expected = ldb.Dn(self.samdb, "%s,%s" % (relative_dn, domain_dn))
        normalized = self.samdb.normalize_dn_in_domain(relative_dn)
        self.assertEqual(expected, normalized)
# Example #11
# 0
class DsdbLockTestCase(SamDBTestCase):
    # Locking tests for sam.ldb and its per-partition backend databases.
    # Each test forks a child. One process holds a lock while the other
    # measures how long a competing operation blocks. The two processes
    # synchronise by exchanging short fixed messages over os.pipe() pairs.

    @staticmethod
    def _run_in_child(body):
        # Run body() in a forked child and always terminate the child here.
        # Exit status is 0 if body() returns, or 99 if it raises (after
        # printing the traceback). body() may still call os._exit() itself
        # to report a specific status. Without this guard, an exception in
        # the child would propagate into the forked copy of the test runner.
        status = 99
        try:
            body()
            status = 0
        except BaseException:
            import sys
            import traceback
            traceback.print_exc()
            # os._exit() does not flush Python's buffers.
            sys.stderr.flush()
        finally:
            os._exit(status)

    def _backend_path(self, partition_dn):
        # Return the private-dir path "sam.ldb.d/<casefolded DN>.ldb".
        backend_filename = "%s.ldb" % partition_dn.get_casefold()
        backend_subpath = os.path.join("sam.ldb.d", backend_filename)
        return self.lp.private_path(backend_subpath)

    def test_db_lock1(self):
        basedn = self.samdb.get_default_basedn()
        (r1, w1) = os.pipe()

        def child():
            # In the child, close the main DB, re-open just one DB
            del self.samdb
            gc.collect()
            self.samdb = SamDB(session_info=self.session, lp=self.lp)

            self.samdb.transaction_start()

            dn = "cn=test_db_lock_user,cn=users," + str(basedn)
            self.samdb.add({
                "dn": dn,
                "objectclass": "user",
            })
            self.samdb.delete(dn)

            # Obtain a write lock
            self.samdb.transaction_prepare_commit()
            os.write(w1, b"prepared")
            time.sleep(2)

            # Drop the write lock
            self.samdb.transaction_cancel()

        pid = os.fork()
        if pid == 0:
            self._run_in_child(child)

        # The parent never writes to w1. Closing our copy means a child that
        # dies early yields EOF below instead of blocking forever.
        os.close(w1)

        self.assertEqual(os.read(r1, 8), b"prepared")

        start = time.time()

        # We need to hold this iterator open to hold the all-record lock.
        res = self.samdb.search_iterator()

        # This should take at least 2 seconds because the transaction
        # has a write lock on one backend db open

        # Release the locks
        for _ in res:
            pass

        end = time.time()
        self.assertGreater(end - start, 1.9)

        (got_pid, status) = os.waitpid(pid, 0)
        self.assertEqual(got_pid, pid)
        self.assertTrue(os.WIFEXITED(status))
        self.assertEqual(os.WEXITSTATUS(status), 0)

    def test_db_lock2(self):
        basedn = self.samdb.get_default_basedn()
        (r1, w1) = os.pipe()
        (r2, w2) = os.pipe()

        def child():
            # In the child, close the main DB, re-open
            del self.samdb
            gc.collect()
            self.samdb = SamDB(session_info=self.session, lp=self.lp)

            # We need to hold this iterator open to hold the all-record lock.
            res = self.samdb.search_iterator()

            os.write(w2, b"start")
            if os.read(r1, 7) != b"started":
                os._exit(1)

            os.write(w2, b"add")
            if os.read(r1, 5) != b"added":
                os._exit(2)

            # Wait 2 seconds to block prepare_commit() in the parent.
            os.write(w2, b"prepare")
            time.sleep(2)

            # Release the locks
            for _ in res:
                pass

            if os.read(r1, 8) != b"prepared":
                os._exit(3)

        pid = os.fork()
        if pid == 0:
            self._run_in_child(child)

        # Only the child writes to w2. Closing our copy turns a dead child
        # into EOF on r2 instead of a hang.
        os.close(w2)

        # We can start the transaction during the search
        # because both just grab the all-record read lock.
        self.assertEqual(os.read(r2, 5), b"start")
        self.samdb.transaction_start()
        os.write(w1, b"started")

        self.assertEqual(os.read(r2, 3), b"add")
        dn = "cn=test_db_lock_user,cn=users," + str(basedn)
        self.samdb.add({
            "dn": dn,
            "objectclass": "user",
        })
        self.samdb.delete(dn)
        os.write(w1, b"added")

        # Obtain a write lock, this will block until
        # the child releases the read lock.
        self.assertEqual(os.read(r2, 7), b"prepare")
        start = time.time()
        self.samdb.transaction_prepare_commit()
        end = time.time()
        try:
            self.assertGreater(end - start, 1.9)
        finally:
            # Always let the child finish and reap it, even when the
            # timing assertion fails.
            os.write(w1, b"prepared")

            # Drop the write lock
            self.samdb.transaction_cancel()

            (got_pid, status) = os.waitpid(pid, 0)
            self.assertEqual(got_pid, pid)
            self.assertTrue(os.WIFEXITED(status))
            self.assertEqual(os.WEXITSTATUS(status), 0)

    def test_db_lock3(self):
        (r1, w1) = os.pipe()
        (r2, w2) = os.pipe()

        def child():
            # In the child, close the main DB, re-open
            del self.samdb
            gc.collect()
            self.samdb = SamDB(session_info=self.session, lp=self.lp)

            # We need to hold this iterator open to hold the all-record lock.
            res = self.samdb.search_iterator()

            os.write(w2, b"start")
            if os.read(r1, 7) != b"started":
                os._exit(1)

            os.write(w2, b"add")
            if os.read(r1, 5) != b"added":
                os._exit(2)

            # Wait 2 seconds to block prepare_commit() in the parent.
            os.write(w2, b"prepare")
            time.sleep(2)

            # Release the locks
            for _ in res:
                pass

            if os.read(r1, 8) != b"prepared":
                os._exit(3)

        pid = os.fork()
        if pid == 0:
            self._run_in_child(child)

        # Only the child writes to w2. Closing our copy turns a dead child
        # into EOF on r2 instead of a hang.
        os.close(w2)

        # We can start the transaction during the search
        # because both just grab the all-record read lock.
        self.assertEqual(os.read(r2, 5), b"start")
        self.samdb.transaction_start()
        os.write(w1, b"started")

        self.assertEqual(os.read(r2, 3), b"add")

        # This will end up in the top level db
        dn = "@DSDB_LOCK_TEST"
        self.samdb.add({"dn": dn})
        self.samdb.delete(dn)
        os.write(w1, b"added")

        # Obtain a write lock, this will block until
        # the child releases the read lock.
        self.assertEqual(os.read(r2, 7), b"prepare")
        start = time.time()
        self.samdb.transaction_prepare_commit()
        end = time.time()
        try:
            self.assertGreater(end - start, 1.9)
        finally:
            # Always let the child finish and reap it, even when the
            # timing assertion fails.
            os.write(w1, b"prepared")

            # Drop the write lock
            self.samdb.transaction_cancel()

            (got_pid, status) = os.waitpid(pid, 0)
            self.assertEqual(got_pid, pid)
            self.assertTrue(os.WIFEXITED(status))
            self.assertEqual(os.WEXITSTATUS(status), 0)

    def _test_full_db_lock1(self, backend_path):
        (r1, w1) = os.pipe()

        def child():
            # In the child, close the main DB, re-open just one DB
            del self.samdb
            gc.collect()

            backenddb = ldb.Ldb(backend_path)

            backenddb.transaction_start()

            backenddb.add({"dn": "@DSDB_LOCK_TEST"})
            backenddb.delete("@DSDB_LOCK_TEST")

            # Obtain a write lock
            backenddb.transaction_prepare_commit()
            os.write(w1, b"prepared")
            time.sleep(2)

            # Drop the write lock
            backenddb.transaction_cancel()

        pid = os.fork()
        if pid == 0:
            self._run_in_child(child)

        # The parent never writes to w1. Closing our copy means a child that
        # dies early yields EOF below instead of blocking forever.
        os.close(w1)

        self.assertEqual(os.read(r1, 8), b"prepared")

        start = time.time()

        # We need to hold this iterator open to hold the all-record lock.
        res = self.samdb.search_iterator()

        # This should take at least 2 seconds because the transaction
        # has a write lock on one backend db open

        end = time.time()
        self.assertGreater(end - start, 1.9)

        # Release the locks
        for _ in res:
            pass

        (got_pid, status) = os.waitpid(pid, 0)
        self.assertEqual(got_pid, pid)
        self.assertTrue(os.WIFEXITED(status))
        self.assertEqual(os.WEXITSTATUS(status), 0)

    def test_full_db_lock1(self):
        self._test_full_db_lock1(
            self._backend_path(self.samdb.get_default_basedn()))

    def test_full_db_lock1_config(self):
        self._test_full_db_lock1(
            self._backend_path(self.samdb.get_config_basedn()))

    def _test_full_db_lock2(self, backend_path):
        (r1, w1) = os.pipe()
        (r2, w2) = os.pipe()

        def child():
            # In the child, close the main DB, re-open
            del self.samdb
            gc.collect()
            self.samdb = SamDB(session_info=self.session, lp=self.lp)

            # We need to hold this iterator open to hold the all-record lock.
            res = self.samdb.search_iterator()

            os.write(w2, b"start")
            if os.read(r1, 7) != b"started":
                os._exit(1)
            os.write(w2, b"add")
            if os.read(r1, 5) != b"added":
                os._exit(2)

            # Wait 2 seconds to block prepare_commit() in the parent.
            os.write(w2, b"prepare")
            time.sleep(2)

            # Release the locks
            for _ in res:
                pass

            if os.read(r1, 8) != b"prepared":
                os._exit(3)

        pid = os.fork()
        if pid == 0:
            self._run_in_child(child)

        # Only the child writes to w2. Closing our copy turns a dead child
        # into EOF on r2 instead of a hang.
        os.close(w2)

        # In the parent, close the main DB, re-open just one DB
        del self.samdb
        gc.collect()
        backenddb = ldb.Ldb(backend_path)

        # We can start the transaction during the search
        # because both just grab the all-record read lock.
        self.assertEqual(os.read(r2, 5), b"start")
        backenddb.transaction_start()
        os.write(w1, b"started")

        self.assertEqual(os.read(r2, 3), b"add")
        backenddb.add({"dn": "@DSDB_LOCK_TEST"})
        backenddb.delete("@DSDB_LOCK_TEST")
        os.write(w1, b"added")

        # Obtain a write lock, this will block until
        # the child releases the read lock.
        self.assertEqual(os.read(r2, 7), b"prepare")
        start = time.time()
        backenddb.transaction_prepare_commit()
        end = time.time()

        try:
            self.assertGreater(end - start, 1.9)
        finally:
            # Always let the child finish and reap it, even when the
            # timing assertion fails.
            os.write(w1, b"prepared")

            # Drop the write lock
            backenddb.transaction_cancel()

            (got_pid, status) = os.waitpid(pid, 0)
            self.assertEqual(got_pid, pid)
            self.assertTrue(os.WIFEXITED(status))
            self.assertEqual(os.WEXITSTATUS(status), 0)

    def test_full_db_lock2(self):
        self._test_full_db_lock2(
            self._backend_path(self.samdb.get_default_basedn()))

    def test_full_db_lock2_config(self):
        self._test_full_db_lock2(
            self._backend_path(self.samdb.get_config_basedn()))
# Example #12
# 0
class DsdbTests(TestCase):
    """Tests for the dsdb helper methods exposed through SamDB.

    setUp() opens SamDB with the system session and creates a fresh test
    user whose DN is kept in self.account_dn; a registered cleanup deletes
    that user again when the test finishes.
    """

    def setUp(self):
        super(DsdbTests, self).setUp()
        self.lp = samba.tests.env_loadparm()
        self.creds = Credentials()
        self.creds.guess(self.lp)
        self.session = system_session()
        self.samdb = SamDB(session_info=self.session,
                           credentials=self.creds,
                           lp=self.lp)

        # Create a test user
        user_name = "dsdb-user-" + str(uuid.uuid4().hex[0:6])
        user_pass = samba.generate_random_password(32, 32)
        user_description = "Test user for dsdb test"

        base_dn = self.samdb.domain_dn()

        self.account_dn = "cn=" + user_name + ",cn=Users," + base_dn
        self.samdb.newuser(username=user_name,
                           password=user_pass,
                           description=user_description)
        # Cleanup (teardown)
        self.addCleanup(delete_force, self.samdb, self.account_dn)

    def test_get_oid_from_attrid(self):
        oid = self.samdb.get_oid_from_attid(591614)
        self.assertEqual(oid, "1.2.840.113556.1.4.1790")

    def test_error_replpropertymetadata(self):
        # Bump the version of the description (attid 13) metadata entry and
        # write the blob back; the modify is expected to raise LdbError.
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        for o in repl.ctr.array:
            # Search for Description
            if o.attid == 13:
                o.version = o.version + 1
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        self.assertRaises(ldb.LdbError, self.samdb.modify, msg,
                          ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_error_replpropertymetadata_nochange(self):
        # Writing back the unmodified blob with only the .14 control is
        # also expected to raise LdbError.
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        self.assertRaises(ldb.LdbError, self.samdb.modify, msg,
                          ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_error_replpropertymetadata_allow_sort(self):
        # The same unmodified write succeeds when the additional
        # 1.3.6.1.4.1.7165.4.3.25 control is also supplied.
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        self.samdb.modify(msg, [
            "local_oid:1.3.6.1.4.1.7165.4.3.14:0",
            "local_oid:1.3.6.1.4.1.7165.4.3.25:0"
        ])

    def test_twoatt_replpropertymetadata(self):
        # Bump version and local_usn of the description entry while also
        # replacing description in the same modify; expected to fail.
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData", "uSNChanged"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        next_usn = int(str(res[0]["uSNChanged"])) + 1
        for o in repl.ctr.array:
            # Search for Description
            if o.attid == 13:
                o.version = o.version + 1
                o.local_usn = next_usn
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        msg["description"] = ldb.MessageElement("new val",
                                                ldb.FLAG_MOD_REPLACE,
                                                "description")
        self.assertRaises(ldb.LdbError, self.samdb.modify, msg,
                          ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_set_replpropertymetadata(self):
        # Bump version, local_usn and originating_usn of the description
        # entry; this modify is expected to succeed.
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData", "uSNChanged"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        next_usn = int(str(res[0]["uSNChanged"])) + 1
        for o in repl.ctr.array:
            # Search for Description
            if o.attid == 13:
                o.version = o.version + 1
                o.local_usn = next_usn
                o.originating_usn = next_usn
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        self.samdb.modify(msg, ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_ok_get_attribute_from_attid(self):
        self.assertEqual(self.samdb.get_attribute_from_attid(13),
                         "description")

    def test_ko_get_attribute_from_attid(self):
        self.assertEqual(self.samdb.get_attribute_from_attid(11979), None)

    def test_get_attribute_replmetadata_version(self):
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["dn"])
        self.assertEqual(len(res), 1)
        dn = str(res[0].dn)
        self.assertEqual(
            self.samdb.get_attribute_replmetadata_version(dn, "unicodePwd"), 2)

    def test_set_attribute_replmetadata_version(self):
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["dn"])
        self.assertEqual(len(res), 1)
        dn = str(res[0].dn)
        version = self.samdb.get_attribute_replmetadata_version(
            dn, "description")
        self.samdb.set_attribute_replmetadata_version(dn, "description",
                                                      version + 2)
        self.assertEqual(
            self.samdb.get_attribute_replmetadata_version(dn, "description"),
            version + 2)

    def test_no_error_on_invalid_control(self):
        # A non-critical (":0") unknown control must simply be ignored.
        try:
            self.samdb.search(
                scope=ldb.SCOPE_SUBTREE,
                base=self.account_dn,
                attrs=["replPropertyMetaData"],
                controls=[
                    "local_oid:%s:0" %
                    dsdb.DSDB_CONTROL_INVALID_NOT_IMPLEMENTED
                ])
        except ldb.LdbError:
            self.fail("Should have not raised an exception")

    def test_error_on_invalid_critical_control(self):
        # A critical (":1") unknown control must be rejected with
        # ERR_UNSUPPORTED_CRITICAL_EXTENSION.
        try:
            self.samdb.search(
                scope=ldb.SCOPE_SUBTREE,
                base=self.account_dn,
                attrs=["replPropertyMetaData"],
                controls=[
                    "local_oid:%s:1" %
                    dsdb.DSDB_CONTROL_INVALID_NOT_IMPLEMENTED
                ])
        except ldb.LdbError as e:
            (code, estr) = e.args
            if code != ldb.ERR_UNSUPPORTED_CRITICAL_EXTENSION:
                # Use the unpacked message: exceptions are not indexable
                # on Python 3, so e[1] would raise TypeError here.
                self.fail(
                    "Got %s should have got ERR_UNSUPPORTED_CRITICAL_EXTENSION"
                    % estr)
        else:
            # Previously a search that succeeded made this test pass.
            self.fail("Should have raised ERR_UNSUPPORTED_CRITICAL_EXTENSION")

    # Allocate a unique RID for use in the objectSID tests.
    #
    def allocate_rid(self):
        """Allocate a RID inside a transaction and return it as a string."""
        self.samdb.transaction_start()
        try:
            rid = self.samdb.allocate_rid()
        except BaseException:
            self.samdb.transaction_cancel()
            raise
        self.samdb.transaction_commit()
        return str(rid)

    # Ensure that duplicate objectSID's are permitted for foreign security
    # principals.
    #
    def test_duplicate_objectSIDs_allowed_on_foreign_security_principals(self):

        #
        # We need to build a foreign security principal SID
        # i.e a  SID not in the current domain.  Changing the last digit
        # of the domain SID string guarantees it differs from ours.
        #
        dom_sid = self.samdb.get_domain_sid()
        last_digit = "9" if str(dom_sid).endswith("0") else "0"
        sid = str(dom_sid)[:-1] + last_digit + "-1000"
        basedn = self.samdb.get_default_basedn()
        dn = "CN=%s,CN=ForeignSecurityPrincipals,%s" % (sid, basedn)
        self.samdb.add({"dn": dn, "objectClass": "foreignSecurityPrincipal"})
        # Safety net so a failing assertion does not leave the object behind.
        self.addCleanup(delete_force, self.samdb, dn)

        self.samdb.delete(dn)

        try:
            self.samdb.add({
                "dn": dn,
                "objectClass": "foreignSecurityPrincipal"
            })
        except ldb.LdbError as e:
            (code, msg) = e.args
            self.fail("Got unexpected exception %d - %s " % (code, msg))

        # cleanup
        self.samdb.delete(dn)

    #
    # Duplicate objectSID's should not be permitted for sids in the local
    # domain. The test sequence is add an object, delete it, then attempt to
    # re-add it, this should fail with a constraint violation
    #
    def test_duplicate_objectSIDs_not_allowed_on_local_objects(self):

        dom_sid = self.samdb.get_domain_sid()
        rid = self.allocate_rid()
        sid_str = str(dom_sid) + "-" + rid
        sid = ndr_pack(security.dom_sid(sid_str))
        basedn = self.samdb.get_default_basedn()
        cn = "dsdb_test_01"
        dn = "cn=%s,cn=Users,%s" % (cn, basedn)

        self.samdb.add({"dn": dn, "objectClass": "user", "objectSID": sid})
        # Removes the re-added object if the second add wrongly succeeds.
        self.addCleanup(delete_force, self.samdb, dn)
        self.samdb.delete(dn)

        try:
            self.samdb.add({"dn": dn, "objectClass": "user", "objectSID": sid})
        except ldb.LdbError as e:
            (code, msg) = e.args
            if code != ldb.ERR_CONSTRAINT_VIOLATION:
                self.fail("Got %d - %s should have got "
                          "LDB_ERR_CONSTRAINT_VIOLATION" % (code, msg))
        else:
            self.fail("No exception should get LDB_ERR_CONSTRAINT_VIOLATION")

    def test_normalize_dn_in_domain_full(self):
        domain_dn = self.samdb.domain_dn()

        full_dn = ldb.Dn(self.samdb, "CN=Users")
        full_dn.add_base(domain_dn)

        full_str = str(full_dn)

        # That is, no change
        self.assertEqual(full_dn, self.samdb.normalize_dn_in_domain(full_str))

    def test_normalize_dn_in_domain_part(self):
        domain_dn = self.samdb.domain_dn()

        part_str = "CN=Users"

        full_dn = ldb.Dn(self.samdb, part_str)
        full_dn.add_base(domain_dn)

        # That is, the domain DN appended
        self.assertEqual(full_dn, self.samdb.normalize_dn_in_domain(part_str))

    def test_normalize_dn_in_domain_full_dn(self):
        domain_dn = self.samdb.domain_dn()

        full_dn = ldb.Dn(self.samdb, "CN=Users")
        full_dn.add_base(domain_dn)

        # That is, no change
        self.assertEqual(full_dn, self.samdb.normalize_dn_in_domain(full_dn))

    def test_normalize_dn_in_domain_part_dn(self):
        domain_dn = self.samdb.domain_dn()

        part_dn = ldb.Dn(self.samdb, "CN=Users")

        # That is, the domain DN appended
        self.assertEqual(
            ldb.Dn(self.samdb,
                   str(part_dn) + "," + str(domain_dn)),
            self.samdb.normalize_dn_in_domain(part_dn))
Exemple #13
0
class PassWordHashTests(TestCase):
    """Helpers for checking the credentials password_hash stores for a user.

    add_user() (re)creates the account USER_NAME; get_supplemental_creds()
    and the check_* helpers then inspect what was stored.  No test_*
    methods are defined here — presumably concrete tests subclass this
    (TODO confirm).
    """

    def setUp(self):
        self.lp = samba.tests.env_loadparm()
        super(PassWordHashTests, self).setUp()

    def set_store_cleartext(self, cleartext):
        """Set or clear DOMAIN_PASSWORD_STORE_CLEARTEXT in pwdProperties."""
        # get the current pwdProperties
        pwdProperties = self.ldb.get_pwdProperties()
        # update the clear-text properties flag
        props = int(pwdProperties)
        if cleartext:
            props |= DOMAIN_PASSWORD_STORE_CLEARTEXT
        else:
            props &= ~DOMAIN_PASSWORD_STORE_CLEARTEXT
        self.ldb.set_pwdProperties(str(props))

    # Add a user to ldb, this will exercise the password_hash code
    # and calculate the appropriate supplemental credentials
    def add_user(self, options=None, clear_text=False, ldb=None):
        """(Re)create USER_NAME with USER_PASS and UPN.

        :param options: iterable of (option, value) pairs applied with
            self.lp.set() before connecting, or None.
        :param clear_text: when true, enable clear-text password storage
            for the domain (restored by a cleanup) and create the account
            with UF_ENCRYPTED_TEXT_PASSWORD_ALLOWED.
        :param ldb: an existing SamDB connection to use; when None a new
            connection is opened with the system session.  This parameter
            shadows the ``ldb`` module inside this method; only self.ldb
            is used below.

        Sets self.ldb, self.netbios_domain, self.dns_domain and
        self.base_dn as side effects.
        """
        # set any needed options
        if options is not None:
            for (option, value) in options:
                self.lp.set(option, value)

        if ldb is None:
            self.creds = Credentials()
            self.session = system_session()
            self.creds.guess(self.lp)
            self.ldb = SamDB(session_info=self.session,
                             credentials=self.creds,
                             lp=self.lp)
        else:
            self.ldb = ldb

        res = self.ldb.search(
            base=self.ldb.get_config_basedn(),
            expression="ncName=%s" % self.ldb.get_default_basedn(),
            attrs=["nETBIOSName"])
        # str() of the (single-valued) element yields text, as done for
        # uSNChanged elsewhere in this file.  The raw [0] value is likely
        # bytes on Python 3, which would corrupt the "%s\\%s" names built
        # from it in check_wdigests().
        self.netbios_domain = str(res[0]["nETBIOSName"])
        self.dns_domain = self.ldb.domain_dns_name()

        # permit password changes during this test
        PasswordCommon.allow_password_changes(self, self.ldb)

        self.base_dn = self.ldb.domain_dn()

        account_control = 0
        if clear_text:
            # Restore the current domain setting on exit.
            pwdProperties = self.ldb.get_pwdProperties()
            self.addCleanup(self.ldb.set_pwdProperties, pwdProperties)
            # Update the domain setting
            self.set_store_cleartext(clear_text)
            account_control |= UF_ENCRYPTED_TEXT_PASSWORD_ALLOWED

        # (Re)adds the test user USER_NAME with password USER_PASS
        # and userPrincipalName UPN
        user_dn = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        delete_force(self.ldb, user_dn)
        self.ldb.add({
            "dn": user_dn,
            "objectclass": "user",
            "sAMAccountName": USER_NAME,
            "userPassword": USER_PASS,
            "userPrincipalName": UPN,
            "userAccountControl": str(account_control)
        })

    # Get the supplemental credentials for the user under test
    def get_supplemental_creds(self):
        """Return the unpacked supplementalCredentials blob of USER_NAME."""
        base = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        res = self.ldb.search(scope=ldb.SCOPE_BASE,
                              base=base,
                              attrs=["supplementalCredentials"])
        self.assertGreater(len(res), 0)
        obj = res[0]
        sc_blob = obj["supplementalCredentials"][0]
        sc = ndr_unpack(drsblobs.supplementalCredentialsBlob, sc_blob)
        return sc

    # Calculate and validate a Wdigest value
    def check_digest(self, user, realm, password, digest):
        """Assert that *digest* equals calc_digest(user, realm, password)."""
        expected = calc_digest(user, realm, password)
        actual = binascii.hexlify(bytearray(digest))
        error = "Digest expected[%s], actual[%s], " \
                "user[%s], realm[%s], pass[%s]" % \
                (expected, actual, user, realm, password)
        self.assertEqual(expected, actual, error)

    # Check all of the 29 expected WDigest values
    #
    def check_wdigests(self, digests):
        """Check all 29 WDigest hashes in *digests*.

        The (name, realm) pairs below are listed in the order of the spec
        and of the samba-tool user tests: hash number n (1-based) is
        compared against entry n-1 of ``expected``.
        """
        self.assertEqual(29, digests.num_hashes)

        def case_variants(name, realm):
            # Hashes 1-7 (NetBIOS realm) and 8-14 (DNS realm).
            return [(name, realm),
                    (name.lower(), realm.lower()),
                    (name.upper(), realm.upper()),
                    (name, realm.upper()),
                    (name, realm.lower()),
                    (name.upper(), realm.lower()),
                    (name.lower(), realm.upper())]

        netbios = self.netbios_domain
        users = [USER_NAME, USER_NAME.lower(), USER_NAME.upper()]
        upns = [UPN, UPN.lower(), UPN.upper()]
        # DOMAIN\user names as-is, lower-cased and upper-cased.
        flat_names = ["%s\\%s" % (netbios, USER_NAME),
                      "%s\\%s" % (netbios.lower(), USER_NAME.lower()),
                      "%s\\%s" % (netbios.upper(), USER_NAME.upper())]

        expected = (case_variants(USER_NAME, netbios) +
                    case_variants(USER_NAME, self.dns_domain) +
                    # Hashes 15-20: UPN and DOMAIN\user, empty realm.
                    [(name, "") for name in upns + flat_names] +
                    # Hashes 21-29: user, UPN and DOMAIN\user, "Digest".
                    [(name, "Digest") for name in users + upns + flat_names])

        for (index, (name, realm)) in enumerate(expected):
            self.check_digest(name, realm, USER_PASS,
                              digests.hashes[index].hash)

    def checkUserPassword(self, up, expected):
        """Check the crypt() hashes stored in *up* against *expected*.

        :param expected: sequence of (scheme tag, crypt algorithm id,
            rounds or None) tuples, one per stored hash, in order.
        """

        # Check we've received the correct number of hashes
        self.assertEqual(len(expected), up.num_hashes)

        for (i, (tag, alg, rounds)) in enumerate(expected):
            stored = up.hashes[i]
            self.assertEqual(tag, stored.scheme)

            data = stored.value.split("$")
            # Check we got the expected crypt algorithm
            self.assertEqual(alg, data[1])

            # Rebuild the crypt() salt/settings string from the stored hash.
            if rounds is None:
                cmd = "$%s$%s" % (alg, data[2])
            else:
                cmd = "$%s$rounds=%d$%s" % (alg, rounds, data[3])

            # Calculate the expected hash value
            expected_hash = crypt.crypt(USER_PASS, cmd)
            self.assertEqual(expected_hash, stored.value)

    # Check that the correct nt_hash was stored for userPassword
    def checkNtHash(self, password, nt_hash):
        """Assert that *nt_hash* is the NT hash of *password*."""
        creds = Credentials()
        creds.set_anonymous()
        creds.set_password(password)
        expected = creds.get_nt_hash()
        actual = bytearray(nt_hash)
        self.assertEqual(expected, actual)
Exemple #14
0
class DsdbLockTestCase(SamDBTestCase):
    """Cross-process lock tests for SamDB and its sam.ldb.d backend files.

    Each test forks a child process.  Parent and child step through the
    scenario by exchanging short byte strings over os.pipe() pairs, and the
    timing assertions check that one side had to wait about two seconds
    (the other side's time.sleep(2)) before it got its lock.

    NOTE(review): the child code relies on reaching os._exit(); an
    exception raised in a child before that point would propagate into the
    forked copy of the test runner.
    """

    def _pipe(self):
        """Return a new (read_fd, write_fd) pipe, closed when the test ends.

        The cleanups only run in the parent: the forked children leave via
        os._exit(), which does not run them.  Closing the write ends also
        gives a stranded child EOF instead of blocking it forever.
        """
        (read_fd, write_fd) = os.pipe()
        self.addCleanup(os.close, read_fd)
        self.addCleanup(os.close, write_fd)
        return (read_fd, write_fd)

    def _wait_for_child(self, pid):
        """Reap the forked child *pid* and assert it exited with status 0."""
        (got_pid, status) = os.waitpid(pid, 0)
        self.assertEqual(got_pid, pid)
        self.assertTrue(os.WIFEXITED(status))
        self.assertEqual(os.WEXITSTATUS(status), 0)

    def _backend_path(self, basedn):
        """Return the private-dir path of the sam.ldb.d file for *basedn*."""
        backend_filename = "%s.ldb" % basedn.get_casefold()
        backend_subpath = os.path.join("sam.ldb.d", backend_filename)
        return self.lp.private_path(backend_subpath)

    def _run_read_lock_child(self, r1, w2):
        """Child side of the read-lock tests; never returns.

        Re-opens SamDB and holds the all-record read lock through an open
        search iterator while the parent starts a transaction and writes.
        After announcing "prepare" it keeps the lock for 2 more seconds so
        the parent's prepare_commit() has to wait for it.

        Exit status: 0 on success; 1, 2 or 3 when the parent's "started",
        "added" or "prepared" message did not arrive as expected.
        """
        # In the child, close the main DB, re-open
        del(self.samdb)
        gc.collect()
        self.samdb = SamDB(session_info=self.session,
                           lp=self.lp)

        # We need to hold this iterator open to hold the all-record lock.
        res = self.samdb.search_iterator()

        os.write(w2, b"start")
        if os.read(r1, 7) != b"started":
            os._exit(1)

        os.write(w2, b"add")
        if os.read(r1, 5) != b"added":
            os._exit(2)

        # Wait 2 seconds to block prepare_commit() in the parent.
        os.write(w2, b"prepare")
        time.sleep(2)

        # Release the locks
        for _ in res:
            pass

        if os.read(r1, 8) != b"prepared":
            os._exit(3)

        os._exit(0)

    def test_db_lock1(self):
        basedn = self.samdb.get_default_basedn()
        (r1, w1) = self._pipe()

        pid = os.fork()
        if pid == 0:
            # In the child, close the main DB, re-open just one DB
            del(self.samdb)
            gc.collect()
            self.samdb = SamDB(session_info=self.session,
                               lp=self.lp)

            self.samdb.transaction_start()

            dn = "cn=test_db_lock_user,cn=users," + str(basedn)
            self.samdb.add({
                "dn": dn,
                "objectclass": "user",
            })
            self.samdb.delete(dn)

            # Obtain a write lock
            self.samdb.transaction_prepare_commit()
            os.write(w1, b"prepared")
            time.sleep(2)

            # Drop the write lock
            self.samdb.transaction_cancel()
            os._exit(0)

        try:
            self.assertEqual(os.read(r1, 8), b"prepared")

            start = time.time()

            # We need to hold this iterator open to hold the all-record lock.
            res = self.samdb.search_iterator()

            # This should take at least 2 seconds because the transaction
            # has a write lock on one backend db open

            # Release the locks
            for _ in res:
                pass

            end = time.time()
            self.assertGreater(end - start, 1.9)
        finally:
            # Reap the child even when an assertion above failed.
            self._wait_for_child(pid)

    def test_db_lock2(self):
        basedn = self.samdb.get_default_basedn()
        (r1, w1) = self._pipe()
        (r2, w2) = self._pipe()

        pid = os.fork()
        if pid == 0:
            self._run_read_lock_child(r1, w2)

        # We can start the transaction during the search
        # because both just grab the all-record read lock.
        self.assertEqual(os.read(r2, 5), b"start")
        self.samdb.transaction_start()
        os.write(w1, b"started")

        self.assertEqual(os.read(r2, 3), b"add")
        dn = "cn=test_db_lock_user,cn=users," + str(basedn)
        self.samdb.add({
            "dn": dn,
            "objectclass": "user",
        })
        self.samdb.delete(dn)
        os.write(w1, b"added")

        # Obtain a write lock, this will block until
        # the child releases the read lock.
        self.assertEqual(os.read(r2, 7), b"prepare")
        start = time.time()
        self.samdb.transaction_prepare_commit()
        end = time.time()
        try:
            self.assertGreater(end - start, 1.9)
        finally:
            # Always let the child finish and reap it.
            os.write(w1, b"prepared")

            # Drop the write lock
            self.samdb.transaction_cancel()

            self._wait_for_child(pid)

    def test_db_lock3(self):
        (r1, w1) = self._pipe()
        (r2, w2) = self._pipe()

        pid = os.fork()
        if pid == 0:
            self._run_read_lock_child(r1, w2)

        # We can start the transaction during the search
        # because both just grab the all-record read lock.
        self.assertEqual(os.read(r2, 5), b"start")
        self.samdb.transaction_start()
        os.write(w1, b"started")

        self.assertEqual(os.read(r2, 3), b"add")

        # This will end up in the top level db
        dn = "@DSDB_LOCK_TEST"
        self.samdb.add({
            "dn": dn})
        self.samdb.delete(dn)
        os.write(w1, b"added")

        # Obtain a write lock, this will block until
        # the child releases the read lock.
        self.assertEqual(os.read(r2, 7), b"prepare")
        start = time.time()
        self.samdb.transaction_prepare_commit()
        end = time.time()
        try:
            self.assertGreater(end - start, 1.9)
        finally:
            # Previously a failed timing check skipped these steps, leaving
            # the child blocked on its final read and never reaped.
            os.write(w1, b"prepared")

            # Drop the write lock
            self.samdb.transaction_cancel()

            self._wait_for_child(pid)

    def _test_full_db_lock1(self, backend_path):
        """Child holds a write lock on the backend file at *backend_path*;
        the parent's SamDB search must wait for it."""
        (r1, w1) = self._pipe()

        pid = os.fork()
        if pid == 0:
            # In the child, close the main DB, re-open just one DB
            del(self.samdb)
            gc.collect()

            backenddb = ldb.Ldb(backend_path)

            backenddb.transaction_start()

            backenddb.add({"dn": "@DSDB_LOCK_TEST"})
            backenddb.delete("@DSDB_LOCK_TEST")

            # Obtain a write lock
            backenddb.transaction_prepare_commit()
            os.write(w1, b"prepared")
            time.sleep(2)

            # Drop the write lock
            backenddb.transaction_cancel()
            os._exit(0)

        try:
            self.assertEqual(os.read(r1, 8), b"prepared")

            start = time.time()

            # We need to hold this iterator open to hold the all-record lock.
            res = self.samdb.search_iterator()

            # This should take at least 2 seconds because the transaction
            # has a write lock on one backend db open

            end = time.time()
            self.assertGreater(end - start, 1.9)

            # Release the locks
            for _ in res:
                pass
        finally:
            # Reap the child even when an assertion above failed.
            self._wait_for_child(pid)

    def test_full_db_lock1(self):
        self._test_full_db_lock1(
            self._backend_path(self.samdb.get_default_basedn()))

    def test_full_db_lock1_config(self):
        self._test_full_db_lock1(
            self._backend_path(self.samdb.get_config_basedn()))

    def _test_full_db_lock2(self, backend_path):
        """Child holds the SamDB read lock; the parent's write transaction
        on the backend file at *backend_path* must wait for it."""
        (r1, w1) = self._pipe()
        (r2, w2) = self._pipe()

        pid = os.fork()
        if pid == 0:
            self._run_read_lock_child(r1, w2)

        # In the parent, close the main DB, re-open just one DB
        del(self.samdb)
        gc.collect()
        backenddb = ldb.Ldb(backend_path)

        # We can start the transaction during the search
        # because both just grab the all-record read lock.
        self.assertEqual(os.read(r2, 5), b"start")
        backenddb.transaction_start()
        os.write(w1, b"started")

        self.assertEqual(os.read(r2, 3), b"add")
        backenddb.add({"dn": "@DSDB_LOCK_TEST"})
        backenddb.delete("@DSDB_LOCK_TEST")
        os.write(w1, b"added")

        # Obtain a write lock, this will block until
        # the child releases the read lock.
        self.assertEqual(os.read(r2, 7), b"prepare")
        start = time.time()
        backenddb.transaction_prepare_commit()
        end = time.time()

        try:
            self.assertGreater(end - start, 1.9)
        finally:
            # Always let the child finish and reap it.
            os.write(w1, b"prepared")

            # Drop the write lock
            backenddb.transaction_cancel()

            self._wait_for_child(pid)

    def test_full_db_lock2(self):
        self._test_full_db_lock2(
            self._backend_path(self.samdb.get_default_basedn()))

    def test_full_db_lock2_config(self):
        self._test_full_db_lock2(
            self._backend_path(self.samdb.get_config_basedn()))
Exemple #15
0
class DsdbTests(TestCase):

    def setUp(self):
        """Open SamDB as the system session and create a throwaway test user.

        Sets self.lp, self.creds, self.session, self.samdb and
        self.account_dn; the user is force-deleted as a test cleanup.
        """
        super(DsdbTests, self).setUp()
        self.lp = samba.tests.env_loadparm()
        self.creds = Credentials()
        self.creds.guess(self.lp)
        self.session = system_session()
        self.samdb = SamDB(session_info=self.session,
                           credentials=self.creds,
                           lp=self.lp)

        # A random 6-hex-digit suffix gives each run its own account name.
        name = "dsdb-user-" + str(uuid.uuid4().hex[0:6])
        password = samba.generate_random_password(32, 32)
        description = "Test user for dsdb test"

        domain = self.samdb.domain_dn()
        self.account_dn = "cn=" + name + ",cn=Users," + domain

        self.samdb.newuser(username=name,
                           password=password,
                           description=description)
        self.addCleanup(delete_force, self.samdb, self.account_dn)

    def test_get_oid_from_attrid(self):
        # Attribute ID 591614 must resolve to OID 1.2.840.113556.1.4.1790.
        # assertEqual replaces the deprecated assertEquals alias (removed in
        # Python 3.12).
        oid = self.samdb.get_oid_from_attid(591614)
        self.assertEqual(oid, "1.2.840.113556.1.4.1790")

    def test_error_replpropertymetadata(self):
        # Bump only the version of the 'description' entry (attid 13) in
        # replPropertyMetaData and write the blob back with a replace; the
        # modify is expected to be rejected with an LdbError.
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        for o in repl.ctr.array:
            # attid 13 is 'description'
            if o.attid == 13:
                o.version = o.version + 1
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        # NOTE(review): OID ...4.3.14 is presumably the dsdb control that
        # permits changing replPropertyMetaData directly - confirm.
        self.assertRaises(ldb.LdbError, self.samdb.modify, msg, ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_error_replpropertymetadata_nochange(self):
        # Unpack and re-pack the user's replPropertyMetaData without touching
        # it, then replace it using only the ...4.3.14 control: the modify is
        # expected to raise LdbError.
        found = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                  base=self.account_dn,
                                  attrs=["replPropertyMetaData"])
        entry = found[0]
        unpacked = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                              entry["replPropertyMetaData"][0])
        blob = ndr_pack(unpacked)

        mod = ldb.Message()
        mod.dn = entry.dn
        mod["replPropertyMetaData"] = ldb.MessageElement(
            blob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        self.assertRaises(ldb.LdbError, self.samdb.modify, mod,
                          ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_error_replpropertymetadata_allow_sort(self):
        # Same unchanged write-back as the _nochange test, but with the extra
        # ...4.3.25 control the modify must succeed.
        # NOTE(review): ...4.3.25 presumably allows re-sorting the metadata
        # array - confirm against the dsdb control definitions.
        found = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                  base=self.account_dn,
                                  attrs=["replPropertyMetaData"])
        entry = found[0]
        unpacked = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                              entry["replPropertyMetaData"][0])
        blob = ndr_pack(unpacked)

        mod = ldb.Message()
        mod.dn = entry.dn
        mod["replPropertyMetaData"] = ldb.MessageElement(
            blob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        wanted_controls = ["local_oid:1.3.6.1.4.1.7165.4.3.14:0",
                           "local_oid:1.3.6.1.4.1.7165.4.3.25:0"]
        self.samdb.modify(mod, wanted_controls)

    def test_twoatt_replpropertymetadata(self):
        # Bump the 'description' (attid 13) metadata version and local_usn
        # while also replacing the description value in the same modify;
        # the combined change is expected to be rejected with an LdbError.
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData", "uSNChanged"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        for o in repl.ctr.array:
            # attid 13 is 'description'
            if o.attid == 13:
                o.version = o.version + 1
                o.local_usn = int(str(res[0]["uSNChanged"])) + 1
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        msg["description"] = ldb.MessageElement("new val", ldb.FLAG_MOD_REPLACE, "description")
        self.assertRaises(ldb.LdbError, self.samdb.modify, msg, ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_set_replpropertymetadata(self):
        # Bump the 'description' (attid 13) metadata version and set both its
        # local and originating USN to uSNChanged + 1; with the ...4.3.14
        # control this direct metadata write must succeed.
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData", "uSNChanged"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        next_usn = int(str(res[0]["uSNChanged"])) + 1
        for o in repl.ctr.array:
            # attid 13 is 'description'
            if o.attid == 13:
                o.version = o.version + 1
                o.local_usn = next_usn
                o.originating_usn = next_usn
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        self.samdb.modify(msg, ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_ok_get_attribute_from_attid(self):
        # attid 13 must map to the 'description' attribute.
        self.assertEqual(self.samdb.get_attribute_from_attid(13), "description")

    def test_ko_get_attribute_from_attid(self):
        # An attid with no attribute behind it (11979) must map to None.
        self.assertIsNone(self.samdb.get_attribute_from_attid(11979))

    def test_get_attribute_replmetadata_version(self):
        # The freshly created test user's unicodePwd metadata version is
        # expected to be 2.
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["dn"])
        self.assertEqual(len(res), 1)
        dn = str(res[0].dn)
        self.assertEqual(self.samdb.get_attribute_replmetadata_version(dn, "unicodePwd"), 2)

    def test_set_attribute_replmetadata_version(self):
        # Setting the 'description' metadata version to current + 2 must be
        # reflected when the version is read back.
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["dn"])
        self.assertEqual(len(res), 1)
        dn = str(res[0].dn)
        version = self.samdb.get_attribute_replmetadata_version(dn, "description")
        self.samdb.set_attribute_replmetadata_version(dn, "description", version + 2)
        self.assertEqual(self.samdb.get_attribute_replmetadata_version(dn, "description"), version + 2)

    def test_no_error_on_invalid_control(self):
        # An unimplemented control sent as non-critical (trailing ":0" in the
        # "local_oid:<oid>:<critical>" string) must not make the search fail.
        try:
            self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                              base=self.account_dn,
                              attrs=["replPropertyMetaData"],
                              controls=["local_oid:%s:0"
                                        % dsdb.DSDB_CONTROL_INVALID_NOT_IMPLEMENTED])
        except ldb.LdbError as e:
            self.fail("Should have not raised an exception: %s" % e)

    def test_error_on_invalid_critical_control(self):
        # The same unimplemented control sent as critical (trailing ":1")
        # must make the search fail with ERR_UNSUPPORTED_CRITICAL_EXTENSION.
        try:
            self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                              base=self.account_dn,
                              attrs=["replPropertyMetaData"],
                              controls=["local_oid:%s:1"
                                        % dsdb.DSDB_CONTROL_INVALID_NOT_IMPLEMENTED])
        except ldb.LdbError as e:
            # Unpack via e.args: Python 3 exceptions are not subscriptable,
            # so the old "e[1]" raised TypeError instead of failing cleanly.
            (code, estr) = e.args
            if code != ldb.ERR_UNSUPPORTED_CRITICAL_EXTENSION:
                self.fail("Got %s should have got ERR_UNSUPPORTED_CRITICAL_EXTENSION"
                          % estr)
        else:
            # Previously a search that succeeded let this test pass silently.
            self.fail("No exception should get ERR_UNSUPPORTED_CRITICAL_EXTENSION")

    # Allocate a unique RID for use in the objectSID tests.
    #
    def allocate_rid(self):
        """Allocate one RID in its own transaction and return it as a str.

        If allocation raises anything, the transaction is cancelled and the
        exception re-raised; otherwise the transaction is committed.
        """
        self.samdb.transaction_start()
        try:
            new_rid = self.samdb.allocate_rid()
        except BaseException:
            self.samdb.transaction_cancel()
            raise
        else:
            self.samdb.transaction_commit()
        return str(new_rid)

    # Ensure that duplicate objectSID's are permitted for foreign security
    # principals.
    #
    def test_duplicate_objectSIDs_allowed_on_foreign_security_principals(self):
        # Without the "provision" control a foreignSecurityPrincipal cannot be
        # added: it fails with an object-class violation when objectSid is
        # missing, and with unwilling-to-perform when objectSid is supplied.
        # With "provision:0" it can be added, deleted and added again under
        # the same CN=<SID> DN.

        # Build a SID outside the current domain by changing the final digit
        # of the domain SID.
        dom_sid = self.samdb.get_domain_sid()
        c = "9" if str(dom_sid).endswith("0") else "0"
        sid_str = str(dom_sid)[:-1] + c + "-1000"
        sid = ndr_pack(security.dom_sid(sid_str))
        basedn = self.samdb.get_default_basedn()
        dn = "CN=%s,CN=ForeignSecurityPrincipals,%s" % (sid_str, basedn)

        # Remove the object even if an assertion below fails part-way.
        self.addCleanup(delete_force, self.samdb, dn)

        #
        # First without control
        #
        try:
            self.samdb.add({
                "dn": dn,
                "objectClass": "foreignSecurityPrincipal"})
            self.fail("No exception should get ERR_OBJECT_CLASS_VIOLATION")
        except ldb.LdbError as e:
            (code, msg) = e.args
            self.assertEqual(code, ldb.ERR_OBJECT_CLASS_VIOLATION, str(e))
            werr = "%08X" % werror.WERR_DS_MISSING_REQUIRED_ATT
            self.assertIn(werr, msg, msg)

        try:
            self.samdb.add({
                "dn": dn,
                "objectClass": "foreignSecurityPrincipal",
                "objectSid": sid})
            self.fail("No exception should get ERR_UNWILLING_TO_PERFORM")
        except ldb.LdbError as e:
            (code, msg) = e.args
            self.assertEqual(code, ldb.ERR_UNWILLING_TO_PERFORM, str(e))
            werr = "%08X" % werror.WERR_DS_ILLEGAL_MOD_OPERATION
            self.assertIn(werr, msg, msg)

        #
        # The provision control is needed to add foreignSecurityPrincipal
        # objects.
        #
        controls = ["provision:0"]
        self.samdb.add({
            "dn": dn,
            "objectClass": "foreignSecurityPrincipal"},
            controls=controls)

        self.samdb.delete(dn)

        try:
            self.samdb.add({
                "dn": dn,
                "objectClass": "foreignSecurityPrincipal"},
                controls=controls)
        except ldb.LdbError as e:
            (code, msg) = e.args
            self.fail("Got unexpected exception %d - %s "
                      % (code, msg))

        # cleanup
        self.samdb.delete(dn)

    def _test_foreignSecurityPrincipal(self, obj_class, fpo_attr):
        """Check which <SID=...> references *fpo_attr* accepts.

        An *obj_class* object is created under CN=Users, then three
        not-yet-existing SIDs are added to its *fpo_attr*:

        - a SID in the local domain: must fail with
          ERR_UNWILLING_TO_PERFORM / WERR_DS_INVALID_GROUP_TYPE;
        - a SID under S-1-5-32: must fail with
          ERR_NO_SUCH_OBJECT / WERR_NO_SUCH_MEMBER;
        - a SID under S-1-5 (outside both): must succeed, after which an
          object with that objectSid is found (presumably an auto-created
          foreignSecurityPrincipal - confirm).

        Both the found object and the test object are deleted at the end.
        """
        dom_sid = self.samdb.get_domain_sid()
        lsid_str = str(dom_sid) + "-4294967294"
        bsid_str = "S-1-5-32-4294967294"
        fsid_str = "S-1-5-4294967294"
        basedn = self.samdb.get_default_basedn()
        cn = "dsdb_test_fpo"
        dn_str = "cn=%s,cn=Users,%s" % (cn, basedn)
        dn = ldb.Dn(self.samdb, dn_str)

        def count_sid(sid_str):
            # Number of objects under basedn carrying objectSid == sid_str.
            res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                    base=basedn,
                                    expression="(objectSid=%s)" % sid_str,
                                    attrs=[])
            return res

        def add_sid_reference(sid_str):
            # Add "<SID=sid_str>" to fpo_attr on the test object.
            m = ldb.Message()
            m.dn = dn
            m[fpo_attr] = ldb.MessageElement("<SID=%s>" % sid_str,
                                             ldb.FLAG_MOD_ADD,
                                             fpo_attr)
            self.samdb.modify(m)

        # None of the test SIDs may exist beforehand.
        for sid_str in (lsid_str, bsid_str, fsid_str):
            self.assertEqual(len(count_sid(sid_str)), 0)

        self.addCleanup(delete_force, self.samdb, dn_str)

        self.samdb.add({
            "dn": dn_str,
            "objectClass": obj_class})

        try:
            add_sid_reference(lsid_str)
            self.fail("No exception should get LDB_ERR_UNWILLING_TO_PERFORM")
        except ldb.LdbError as e:
            (code, errstr) = e.args
            self.assertEqual(code, ldb.ERR_UNWILLING_TO_PERFORM, str(e))
            werr = "%08X" % werror.WERR_DS_INVALID_GROUP_TYPE
            self.assertIn(werr, errstr, errstr)

        try:
            add_sid_reference(bsid_str)
            self.fail("No exception should get LDB_ERR_NO_SUCH_OBJECT")
        except ldb.LdbError as e:
            (code, errstr) = e.args
            self.assertEqual(code, ldb.ERR_NO_SUCH_OBJECT, str(e))
            werr = "%08X" % werror.WERR_NO_SUCH_MEMBER
            self.assertIn(werr, errstr, errstr)

        try:
            add_sid_reference(fsid_str)
        except ldb.LdbError as e:
            self.fail("Should have not raised an exception: %s" % e)

        res = count_sid(fsid_str)
        self.assertEqual(len(res), 1)
        self.samdb.delete(res[0].dn)
        self.samdb.delete(dn)
        self.assertEqual(len(count_sid(fsid_str)), 0)

    def test_foreignSecurityPrincipal_member(self):
        # Foreign-SID reference checks for 'member' on a group.
        self._test_foreignSecurityPrincipal(obj_class="group",
                                            fpo_attr="member")

    def test_foreignSecurityPrincipal_MembersForAzRole(self):
        # Foreign-SID reference checks for 'msDS-MembersForAzRole' on an
        # msDS-AzRole object.
        self._test_foreignSecurityPrincipal(obj_class="msDS-AzRole",
                                            fpo_attr="msDS-MembersForAzRole")

    def test_foreignSecurityPrincipal_NeverRevealGroup(self):
        # Foreign-SID reference checks for 'msDS-NeverRevealGroup' on a
        # computer object.
        self._test_foreignSecurityPrincipal(obj_class="computer",
                                            fpo_attr="msDS-NeverRevealGroup")

    def test_foreignSecurityPrincipal_RevealOnDemandGroup(self):
        # Foreign-SID reference checks for 'msDS-RevealOnDemandGroup' on a
        # computer object.
        self._test_foreignSecurityPrincipal(obj_class="computer",
                                            fpo_attr="msDS-RevealOnDemandGroup")

    def _test_fail_foreignSecurityPrincipal(self, obj_class, fpo_attr,
                                            msg_exp, lerr_exp, werr_exp,
                                            allow_reference=True):
        """Check that *fpo_attr* rejects every <SID=...> reference.

        Two *obj_class* objects are created under CN=Users.  Adding a
        local-domain, an S-1-5-32 and an S-1-5 SID (none of which exist) to
        *fpo_attr* on the first must each fail with LDB error *lerr_exp*
        whose message contains the hex code of *werr_exp*; *msg_exp* is the
        human-readable name of that expected error.  Finally the plain DN of
        the second object is added: it must succeed when *allow_reference*
        is true and fail the same way otherwise.
        """
        dom_sid = self.samdb.get_domain_sid()
        lsid_str = str(dom_sid) + "-4294967294"
        bsid_str = "S-1-5-32-4294967294"
        fsid_str = "S-1-5-4294967294"
        basedn = self.samdb.get_default_basedn()
        cn1 = "dsdb_test_fpo1"
        dn1_str = "cn=%s,cn=Users,%s" % (cn1, basedn)
        dn1 = ldb.Dn(self.samdb, dn1_str)
        cn2 = "dsdb_test_fpo2"
        dn2_str = "cn=%s,cn=Users,%s" % (cn2, basedn)
        dn2 = ldb.Dn(self.samdb, dn2_str)
        werr = "%08X" % werr_exp

        def add_reference(value):
            # Add *value* to fpo_attr on the first test object.
            m = ldb.Message()
            m.dn = dn1
            m[fpo_attr] = ldb.MessageElement(value, ldb.FLAG_MOD_ADD,
                                             fpo_attr)
            self.samdb.modify(m)

        def check_expected_error(e):
            # The LdbError must carry lerr_exp and mention werr_exp.
            (code, errstr) = e.args
            self.assertEqual(code, lerr_exp, str(e))
            self.assertIn(werr, errstr, errstr)

        # None of the test SIDs may exist beforehand.
        for sid_str in (lsid_str, bsid_str, fsid_str):
            res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                    base=basedn,
                                    expression="(objectSid=%s)" % sid_str,
                                    attrs=[])
            self.assertEqual(len(res), 0)

        self.addCleanup(delete_force, self.samdb, dn1_str)
        self.addCleanup(delete_force, self.samdb, dn2_str)

        self.samdb.add({
            "dn": dn1_str,
            "objectClass": obj_class})

        self.samdb.add({
            "dn": dn2_str,
            "objectClass": obj_class})

        # Every SID reference must be rejected.  (The fsid case previously
        # formatted the ldb.Message into its failure text instead of msg_exp.)
        for sid_str in (lsid_str, bsid_str, fsid_str):
            try:
                add_reference("<SID=%s>" % sid_str)
                self.fail("No exception should get %s" % msg_exp)
            except ldb.LdbError as e:
                check_expected_error(e)

        # A plain DN reference to an existing object.
        try:
            add_reference("%s" % dn2)
        except ldb.LdbError as e:
            if allow_reference:
                self.fail("Should have not raised an exception: %s" % e)
            check_expected_error(e)
        else:
            if not allow_reference:
                # Previously misspelled "sel.fail", which raised NameError.
                self.fail("No exception should get %s" % msg_exp)

        self.samdb.delete(dn2)
        self.samdb.delete(dn1)

    def test_foreignSecurityPrincipal_NonMembers(self):
        # 'msDS-NonMembers' on a group must reject SID references and also a
        # plain DN reference.
        self._test_fail_foreignSecurityPrincipal(
            obj_class="group",
            fpo_attr="msDS-NonMembers",
            msg_exp="LDB_ERR_UNWILLING_TO_PERFORM/WERR_NOT_SUPPORTED",
            lerr_exp=ldb.ERR_UNWILLING_TO_PERFORM,
            werr_exp=werror.WERR_NOT_SUPPORTED,
            allow_reference=False)

    def test_foreignSecurityPrincipal_HostServiceAccount(self):
        # 'msDS-HostServiceAccount' on a computer must reject SID references
        # but accept a plain DN reference.
        self._test_fail_foreignSecurityPrincipal(
            obj_class="computer",
            fpo_attr="msDS-HostServiceAccount",
            msg_exp="LDB_ERR_CONSTRAINT_VIOLATION/WERR_DS_NAME_REFERENCE_INVALID",
            lerr_exp=ldb.ERR_CONSTRAINT_VIOLATION,
            werr_exp=werror.WERR_DS_NAME_REFERENCE_INVALID)

    def test_foreignSecurityPrincipal_manager(self):
        # 'manager' on a user must reject SID references but accept a plain
        # DN reference.
        self._test_fail_foreignSecurityPrincipal(
            obj_class="user",
            fpo_attr="manager",
            msg_exp="LDB_ERR_CONSTRAINT_VIOLATION/WERR_DS_NAME_REFERENCE_INVALID",
            lerr_exp=ldb.ERR_CONSTRAINT_VIOLATION,
            werr_exp=werror.WERR_DS_NAME_REFERENCE_INVALID)

    #
    # Duplicate objectSID's should not be permitted for sids in the local
    # domain. The test sequence is add an object, delete it, then attempt to
    # re-add it, this should fail with a constraint violation
    #
    def test_duplicate_objectSIDs_not_allowed_on_local_objects(self):
        # Add a user with a freshly allocated local-domain SID, delete it,
        # then re-add it with the same SID: the re-add must fail with
        # ERR_CONSTRAINT_VIOLATION.
        dom_sid = self.samdb.get_domain_sid()
        rid = self.allocate_rid()
        sid_str = str(dom_sid) + "-" + rid
        sid = ndr_pack(security.dom_sid(sid_str))
        basedn = self.samdb.get_default_basedn()
        cn = "dsdb_test_01"
        dn = "cn=%s,cn=Users,%s" % (cn, basedn)

        # If the re-add unexpectedly succeeds, self.fail() raises below and
        # the object would otherwise be left behind for later tests.
        self.addCleanup(delete_force, self.samdb, dn)

        self.samdb.add({
            "dn": dn,
            "objectClass": "user",
            "objectSID": sid})
        self.samdb.delete(dn)

        try:
            self.samdb.add({
                "dn": dn,
                "objectClass": "user",
                "objectSID": sid})
            self.fail("No exception should get LDB_ERR_CONSTRAINT_VIOLATION")
        except ldb.LdbError as e:
            (code, msg) = e.args
            if code != ldb.ERR_CONSTRAINT_VIOLATION:
                self.fail("Got %d - %s should have got "
                          "LDB_ERR_CONSTRAINT_VIOLATION"
                          % (code, msg))

    def test_linked_vs_non_linked_reference(self):
        # Compare reference checking on a linked ('manager') and a non-linked
        # ('assistant') DN-valued attribute of a user:
        #  - 'manager' rejects <SID=...>/<GUID=...> of a deleted object;
        #  - 'assistant' accepts (and can remove) those same references;
        #  - 'assistant' still rejects a DN, SID or GUID that never existed.
        basedn = self.samdb.get_default_basedn()
        kept_dn_str = "cn=reference_kept,cn=Users,%s" % (basedn)
        removed_dn_str = "cn=reference_removed,cn=Users,%s" % (basedn)
        dom_sid = self.samdb.get_domain_sid()
        none_sid_str = str(dom_sid) + "-4294967294"
        none_guid_str = "afafafaf-fafa-afaf-fafa-afafafafafaf"

        self.addCleanup(delete_force, self.samdb, kept_dn_str)
        self.addCleanup(delete_force, self.samdb, removed_dn_str)

        self.samdb.add({
            "dn": kept_dn_str,
            "objectClass": "user"})
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=kept_dn_str,
                                attrs=["objectGUID", "objectSID"])
        self.assertEqual(len(res), 1)
        kept_dn = res[0].dn

        self.samdb.add({
            "dn": removed_dn_str,
            "objectClass": "user"})
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=removed_dn_str,
                                attrs=["objectGUID", "objectSID"])
        self.assertEqual(len(res), 1)
        removed_guid = ndr_unpack(misc.GUID, res[0]["objectGUID"][0])
        removed_sid = ndr_unpack(security.dom_sid, res[0]["objectSid"][0])
        self.samdb.delete(removed_dn_str)

        def modify_kept(attr, value, flag):
            # Apply a single-value modification of attr to the kept user.
            m = ldb.Message()
            m.dn = kept_dn
            m[attr] = ldb.MessageElement(value, flag, attr)
            self.samdb.modify(m)

        def assert_reference_invalid(attr, value):
            # Adding value to attr must fail with ERR_CONSTRAINT_VIOLATION
            # and mention WERR_DS_NAME_REFERENCE_INVALID.
            try:
                modify_kept(attr, value, ldb.FLAG_MOD_ADD)
                self.fail("No exception should get LDB_ERR_CONSTRAINT_VIOLATION")
            except ldb.LdbError as e:
                (code, errstr) = e.args
                self.assertEqual(code, ldb.ERR_CONSTRAINT_VIOLATION, str(e))
                werr = "%08X" % werror.WERR_DS_NAME_REFERENCE_INVALID
                self.assertIn(werr, errstr, errstr)

        #
        # First try the linked attribute 'manager'
        # by GUID and SID
        #
        assert_reference_invalid("manager", "<SID=%s>" % removed_sid)
        assert_reference_invalid("manager", "<GUID=%s>" % removed_guid)

        #
        # Try the non-linked attribute 'assistant'
        # by GUID and SID, which should work.
        #
        for ref in ("<SID=%s>" % removed_sid, "<GUID=%s>" % removed_guid):
            modify_kept("assistant", ref, ldb.FLAG_MOD_ADD)
            modify_kept("assistant", ref, ldb.FLAG_MOD_DELETE)

        #
        # Finally try the non-linked attribute 'assistant'
        # but with non existing GUID, SID, DN
        #
        assert_reference_invalid("assistant", "CN=NoneNone,%s" % (basedn))
        assert_reference_invalid("assistant", "<SID=%s>" % none_sid_str)
        assert_reference_invalid("assistant", "<GUID=%s>" % none_guid_str)

        self.samdb.delete(kept_dn)

    def test_normalize_dn_in_domain_full(self):
        # A full DN string inside the domain normalizes to the same DN.
        domain_dn = self.samdb.domain_dn()

        # Build the full DN directly; the old "full_dn = part_dn" alias read
        # like a copy but mutated part_dn in place.
        full_dn = ldb.Dn(self.samdb, "CN=Users")
        full_dn.add_base(domain_dn)

        full_str = str(full_dn)

        # That is, no change
        self.assertEqual(full_dn,
                         self.samdb.normalize_dn_in_domain(full_str))

    def test_normalize_dn_in_domain_part(self):
        # A relative DN string gets the domain DN appended as its base.
        domain = self.samdb.domain_dn()
        relative = "CN=Users"

        expected = ldb.Dn(self.samdb, relative)
        expected.add_base(domain)

        normalized = self.samdb.normalize_dn_in_domain(relative)
        self.assertEqual(expected, normalized)

    def test_normalize_dn_in_domain_full_dn(self):
        # A full ldb.Dn object inside the domain normalizes to the same DN.
        domain_dn = self.samdb.domain_dn()

        # Build the full DN directly; the old "full_dn = part_dn" alias read
        # like a copy but mutated part_dn in place.
        full_dn = ldb.Dn(self.samdb, "CN=Users")
        full_dn.add_base(domain_dn)

        # That is, no change
        self.assertEqual(full_dn,
                         self.samdb.normalize_dn_in_domain(full_dn))

    def test_normalize_dn_in_domain_part_dn(self):
        # A relative ldb.Dn object gets the domain DN appended as its base.
        domain = self.samdb.domain_dn()
        relative = ldb.Dn(self.samdb, "CN=Users")

        expected = ldb.Dn(self.samdb, "%s,%s" % (relative, domain))
        self.assertEqual(expected,
                         self.samdb.normalize_dn_in_domain(relative))