Example #1
0
class DirsyncBaseTests(samba.tests.TestCase):
    """Shared fixture for the DirSync tests.

    setUp() opens an administrative SamDB connection to ``ldapshost`` and
    caches the domain DN, domain SID, configuration DN, an SDUtils helper and
    a freshly generated random password (``self.user_pass``).
    """

    def setUp(self):
        super(DirsyncBaseTests, self).setUp()
        admin_session = system_session(lp)
        self.ldb_admin = SamDB(ldapshost,
                               credentials=creds,
                               session_info=admin_session,
                               lp=lp)
        admin = self.ldb_admin
        self.base_dn = admin.domain_dn()
        self.domain_sid = security.dom_sid(admin.get_domain_sid())
        self.user_pass = samba.generate_random_password(12, 16)
        config_dn = admin.get_config_basedn()
        self.configuration_dn = config_dn.get_linearized()
        self.sd_utils = sd_utils.SDUtils(admin)
        print("baseDN: %s" % self.base_dn)

    def get_user_dn(self, name):
        """Return the DN of *name* inside the domain's CN=Users container."""
        parts = (name, self.base_dn)
        return "CN=%s,CN=Users,%s" % parts

    def get_ldb_connection(self, target_username, target_password):
        """Open a SamDB connection to ``ldaphost`` as the given user.

        Domain, realm and workstation are copied from the module-level
        ``creds``; sealing is requested and Kerberos is disabled.
        """
        user_creds = Credentials()
        user_creds.set_username(target_username)
        user_creds.set_password(target_password)
        # Inherit the domain/realm/workstation of the admin credentials.
        user_creds.set_domain(creds.get_domain())
        user_creds.set_realm(creds.get_realm())
        user_creds.set_workstation(creds.get_workstation())
        sealed = user_creds.get_gensec_features() | gensec.FEATURE_SEAL
        user_creds.set_gensec_features(sealed)
        # kinit is too expensive to use in a tight loop
        user_creds.set_kerberos_state(DONT_USE_KERBEROS)
        return SamDB(url=ldaphost, credentials=user_creds, lp=lp)
Example #2
0
class DirsyncBaseTests(samba.tests.TestCase):
    """Base class for DirSync tests: admin connection plus cached DNs/SID."""

    def setUp(self):
        super(DirsyncBaseTests, self).setUp()
        session = system_session(lp)
        admin_db = SamDB(ldapshost, credentials=creds,
                         session_info=session, lp=lp)
        self.ldb_admin = admin_db
        self.base_dn = admin_db.domain_dn()
        self.domain_sid = security.dom_sid(admin_db.get_domain_sid())
        self.user_pass = samba.generate_random_password(12, 16)
        self.configuration_dn = (
            admin_db.get_config_basedn().get_linearized())
        self.sd_utils = sd_utils.SDUtils(admin_db)
        print("baseDN: %s" % self.base_dn)

    def get_user_dn(self, name):
        """Build the CN=Users DN for the account called *name*."""
        user_dn = "CN=%s,CN=Users,%s" % (name, self.base_dn)
        return user_dn

    def get_ldb_connection(self, target_username, target_password):
        """Return a SamDB bound to ``ldaphost`` with the given user's creds.

        The temporary credentials reuse domain, realm and workstation from
        the module-level ``creds``, add FEATURE_SEAL and disable Kerberos.
        """
        tmp = Credentials()
        tmp.set_username(target_username)
        tmp.set_password(target_password)
        tmp.set_domain(creds.get_domain())
        tmp.set_realm(creds.get_realm())
        tmp.set_workstation(creds.get_workstation())
        features = tmp.get_gensec_features()
        tmp.set_gensec_features(features | gensec.FEATURE_SEAL)
        # kinit is too expensive to use in a tight loop
        tmp.set_kerberos_state(DONT_USE_KERBEROS)
        connection = SamDB(url=ldaphost, credentials=tmp, lp=lp)
        return connection
Example #3
0
    def run(self, sambaopts=None, credopts=None, server=None, targetdir=None,
            no_secrets=False, backend_store=None):
        """Create an online backup by clone-joining the remote DC *server*.

        The clone is made into a temporary directory created under
        *targetdir*. The remote sysvol is fetched into ``sysvol.tar.gz``,
        the cloned sam.ldb is tagged with backup markers, and everything is
        packed into the backup tar file. The temporary directory is removed
        afterwards, including when any step fails, so a partial clone is
        never left behind in *targetdir*.

        :param server: remote DC to back up (required).
        :param targetdir: directory that receives the backup tar file.
        :param no_secrets: if true, clone without secrets and make sure the
            admin user has a password set in the backup (same as provision).
        :param backend_store: passed through to join_clone().
        :raises CommandError: if *server* is not given.
        """
        logger = self.get_logger()
        logger.setLevel(logging.DEBUG)

        lp = sambaopts.get_loadparm()
        creds = credopts.get_credentials(lp)

        # Make sure we have all the required args.
        if server is None:
            raise CommandError('Server required')

        check_targetdir(logger, targetdir)

        tmpdir = tempfile.mkdtemp(dir=targetdir)
        try:
            # Run a clone join on the remote
            include_secrets = not no_secrets
            ctx = join_clone(logger=logger, creds=creds, lp=lp,
                             include_secrets=include_secrets, server=server,
                             dns_backend='SAMBA_INTERNAL', targetdir=tmpdir,
                             backend_store=backend_store)

            # get the paths used for the clone, then drop the old samdb
            # connection
            paths = ctx.paths
            del ctx

            # Get a free RID to use as the new DC's SID (when it gets
            # restored)
            remote_sam = SamDB(url='ldap://' + server, credentials=creds,
                               session_info=system_session(), lp=lp)
            new_sid = get_sid_for_restore(remote_sam, logger)
            realm = remote_sam.domain_dns_name()

            # Grab the remote DC's sysvol files and bundle them into a tar
            # file
            sysvol_tar = os.path.join(tmpdir, 'sysvol.tar.gz')
            smb_conn = smb_sysvol_conn(server, lp, creds)
            backup_online(smb_conn, sysvol_tar, remote_sam.get_domain_sid())

            # remove the default sysvol files created by the clone (we want
            # to make sure we restore the sysvol.tar.gz files instead)
            shutil.rmtree(paths.sysvol)

            # Edit the downloaded sam.ldb to mark it as a backup
            samdb = SamDB(url=paths.samdb, session_info=system_session(),
                          lp=lp)
            time_str = get_timestamp()
            add_backup_marker(samdb, "backupDate", time_str)
            add_backup_marker(samdb, "sidForRestore", new_sid)
            add_backup_marker(samdb, "backupType", "online")

            # ensure the admin user always has a password set (same as
            # provision)
            if no_secrets:
                set_admin_password(logger, samdb)

            # Add everything in the tmpdir to the backup tar file
            backup_file = backup_filepath(targetdir, realm, time_str)
            create_log_file(tmpdir, lp, "online", server, include_secrets)
            create_backup_tar(logger, tmpdir, backup_file)
        except BaseException:
            # Don't leave a half-built clone (which may hold secrets when
            # include_secrets is true) lying around in targetdir. Cleanup
            # errors are ignored so they cannot mask the original failure.
            shutil.rmtree(tmpdir, ignore_errors=True)
            raise

        shutil.rmtree(tmpdir)
Example #4
0
    def run(self, sambaopts=None, credopts=None, server=None, targetdir=None):
        """Create an online backup by clone-joining the remote DC *server*.

        The clone (with secrets) is made into a temporary directory under
        *targetdir*, which is created if missing. The remote sysvol is
        fetched into ``sysvol.tar.gz``, the cloned sam.ldb is tagged with
        backup markers, and everything is packed into the backup tar file.
        The temporary directory is removed afterwards, including when any
        step fails.

        :param server: remote DC to back up.
        :param targetdir: directory that receives the backup tar file.
        """
        logger = self.get_logger()
        logger.setLevel(logging.DEBUG)

        # Make sure we have all the required args.
        check_online_backup_args(logger, credopts, server, targetdir)

        lp = sambaopts.get_loadparm()
        creds = credopts.get_credentials(lp)

        if not os.path.exists(targetdir):
            # Lazy %-args: the message is only formatted if it is emitted.
            logger.info('Creating targetdir %s...', targetdir)
            os.makedirs(targetdir)

        tmpdir = tempfile.mkdtemp(dir=targetdir)
        try:
            # Run a clone join on the remote
            ctx = join_clone(logger=logger,
                             creds=creds,
                             lp=lp,
                             include_secrets=True,
                             dns_backend='SAMBA_INTERNAL',
                             server=server,
                             targetdir=tmpdir)

            # get the paths used for the clone, then drop the old samdb
            # connection
            paths = ctx.paths
            del ctx

            # Get a free RID to use as the new DC's SID (when it gets
            # restored)
            remote_sam = SamDB(url='ldap://' + server,
                               credentials=creds,
                               session_info=system_session(),
                               lp=lp)
            new_sid = get_sid_for_restore(remote_sam)
            realm = remote_sam.domain_dns_name()

            # Grab the remote DC's sysvol files and bundle them into a tar
            # file
            sysvol_tar = os.path.join(tmpdir, 'sysvol.tar.gz')
            smb_conn = smb.SMB(server, "sysvol", lp=lp, creds=creds)
            backup_online(smb_conn, sysvol_tar, remote_sam.get_domain_sid())

            # remove the default sysvol files created by the clone (we want
            # to make sure we restore the sysvol.tar.gz files instead)
            shutil.rmtree(paths.sysvol)

            # Edit the downloaded sam.ldb to mark it as a backup
            samdb = SamDB(url=paths.samdb, session_info=system_session(),
                          lp=lp)
            time_str = get_timestamp()
            add_backup_marker(samdb, "backupDate", time_str)
            add_backup_marker(samdb, "sidForRestore", new_sid)

            # Add everything in the tmpdir to the backup tar file
            backup_file = backup_filepath(targetdir, realm, time_str)
            create_backup_tar(logger, tmpdir, backup_file)
        except BaseException:
            # The clone was made with include_secrets=True: never leave a
            # partial copy behind in targetdir. Cleanup errors are ignored
            # so they cannot mask the original failure.
            shutil.rmtree(tmpdir, ignore_errors=True)
            raise

        shutil.rmtree(tmpdir)
Example #5
0
class SitesBaseTests(samba.tests.TestCase):
    """Common fixture for the sites tests: an admin SamDB plus cached DNs."""

    def setUp(self):
        super(SitesBaseTests, self).setUp()
        session = system_session(lp)
        samdb = SamDB(ldaphost, credentials=creds,
                      session_info=session, lp=lp)
        self.ldb = samdb
        self.base_dn = samdb.domain_dn()
        self.domain_sid = security.dom_sid(samdb.get_domain_sid())
        config_dn = samdb.get_config_basedn()
        self.configuration_dn = config_dn.get_linearized()

    def get_user_dn(self, name):
        """Return the DN for user *name* under CN=Users of the domain."""
        template = "CN={0!s},CN=Users,{1!s}"
        return template.format(name, self.base_dn)
Example #6
0
class SitesBaseTests(samba.tests.TestCase):
    """Fixture for sites tests: connects as admin and caches domain info."""

    def setUp(self):
        super(SitesBaseTests, self).setUp()
        db = SamDB(ldaphost,
                   credentials=creds,
                   session_info=system_session(lp),
                   lp=lp)
        self.ldb = db
        self.base_dn = db.domain_dn()
        sid_str = db.get_domain_sid()
        self.domain_sid = security.dom_sid(sid_str)
        self.configuration_dn = db.get_config_basedn().get_linearized()

    def get_user_dn(self, name):
        """Build the CN=Users DN for the account called *name*."""
        values = (name, self.base_dn)
        return "CN=%s,CN=Users,%s" % values
Example #7
0
class DsdbTests(TestCase):
    def setUp(self):
        """Open a SamDB, create a throw-away user and locate the RID Set.

        Sets self.lp, self.creds, self.session, self.samdb, self.account_dn
        (the new user, deleted again via addCleanup), self.server_ref_dn
        (the serverReference of this server) and self.rid_set_dn (the
        rIDSetReferences of that object).
        """
        super(DsdbTests, self).setUp()
        self.lp = samba.tests.env_loadparm()
        self.creds = Credentials()
        self.creds.guess(self.lp)
        self.session = system_session()
        self.samdb = SamDB(session_info=self.session,
                           credentials=self.creds,
                           lp=self.lp)

        # Create a uniquely named test user and schedule its removal.
        suffix = str(uuid.uuid4().hex[0:6])
        user_name = "dsdb-user-" + suffix
        user_pass = samba.generate_random_password(32, 32)
        base_dn = self.samdb.domain_dn()
        self.account_dn = "CN=" + user_name + ",CN=Users," + base_dn
        self.samdb.newuser(username=user_name,
                           password=user_pass,
                           description="Test user for dsdb test")
        self.addCleanup(delete_force, self.samdb, self.account_dn)

        # serverName -> serverReference
        server_name_dn = ldb.Dn(self.samdb, self.samdb.get_serverName())
        found = self.samdb.search(base=server_name_dn,
                                  scope=ldb.SCOPE_BASE,
                                  attrs=["serverReference"])
        server_ref = found[0]["serverReference"][0].decode("utf-8")
        self.server_ref_dn = ldb.Dn(self.samdb, server_ref)

        # serverReference -> rIDSetReferences
        found = self.samdb.search(base=self.server_ref_dn,
                                  scope=ldb.SCOPE_BASE,
                                  attrs=["rIDSetReferences"])
        ref_msg = found[0]
        self.assertIn("rIDSetReferences", ref_msg)
        rid_set_str = ref_msg["rIDSetReferences"][0].decode("utf-8")
        self.rid_set_dn = ldb.Dn(self.samdb, rid_set_str)

    def get_rid_set(self, rid_set_dn):
        """Return the RID Set entry at *rid_set_dn* with its pool attributes.

        The returned message carries rIDAllocationPool,
        rIDPreviousAllocationPool, rIDUsedPool and rIDNextRID (when set).
        """
        wanted = [
            "rIDAllocationPool",
            "rIDPreviousAllocationPool",
            "rIDUsedPool",
            "rIDNextRID",
        ]
        found = self.samdb.search(base=rid_set_dn, scope=ldb.SCOPE_BASE,
                                  attrs=wanted)
        return found[0]

    def test_ridalloc_next_free_rid(self):
        # next_free_rid() must report rIDNextRID + 1 without modifying the
        # RID Set. We assume that RID pools allocated to us are contiguous.
        self.samdb.transaction_start()
        try:
            before = self.get_rid_set(self.rid_set_dn)
            for attr in ("rIDAllocationPool", "rIDPreviousAllocationPool",
                         "rIDUsedPool", "rIDNextRID"):
                self.assertIn(attr, before)

            expected = int(before["rIDNextRID"][0]) + 1
            first = self.samdb.next_free_rid()
            self.assertEqual(expected, first)

            # Asking again returns the same value ...
            second = self.samdb.next_free_rid()
            self.assertEqual(first, second)

            # ... and the RID Set attributes are unchanged.
            after = self.get_rid_set(self.rid_set_dn)
            self.assertEqual(before, after)
        finally:
            self.samdb.transaction_cancel()

    def test_ridalloc_no_ridnextrid(self):
        # With rIDNextRID deleted, RIDs should come from the low end of
        # rIDAllocationPool.
        self.samdb.transaction_start()
        try:
            prev_lo, prev_hi = 1000, 1999
            next_lo, next_hi = 3000, 3999
            # Pool values are written as (high << 32) | low.
            prev_pool = str((prev_hi << 32) | prev_lo)
            next_pool = str((next_hi << 32) | next_lo)

            msg = ldb.Message()
            msg.dn = self.rid_set_dn
            msg["rIDNextRID"] = ldb.MessageElement(
                [], ldb.FLAG_MOD_DELETE, "rIDNextRID")
            msg["rIDPreviousAllocationPool"] = ldb.MessageElement(
                prev_pool, ldb.FLAG_MOD_REPLACE, "rIDPreviousAllocationPool")
            msg["rIDAllocationPool"] = ldb.MessageElement(
                next_pool, ldb.FLAG_MOD_REPLACE, "rIDAllocationPool")
            self.samdb.modify(msg)

            # next_free_rid() reports the start of the next pool ...
            first_free = self.samdb.next_free_rid()
            self.assertEqual(next_lo, first_free)

            # ... allocate_rid() hands out exactly that RID ...
            rid = self.samdb.allocate_rid()
            self.assertEqual(first_free, rid)

            # ... and afterwards the next free RID has moved on by one.
            following = self.samdb.next_free_rid()
            self.assertEqual(rid + 1, following)

            free_lo, free_hi = self.samdb.free_rid_bounds()
            self.assertEqual(rid + 1, free_lo)
            self.assertEqual(next_hi, free_hi)
        finally:
            self.samdb.transaction_cancel()

    def test_ridalloc_no_free_rids(self):
        # Exhaust our current pool: both pools cover the same range and
        # rIDNextRID already sits at its top.
        self.samdb.transaction_start()
        try:
            pool_lo, pool_hi = 2000, 2999
            packed = str((pool_hi << 32) | pool_lo)
            msg = ldb.Message()
            msg.dn = self.rid_set_dn
            for attr in ("rIDPreviousAllocationPool", "rIDAllocationPool"):
                msg[attr] = ldb.MessageElement(packed, ldb.FLAG_MOD_REPLACE,
                                               attr)
            msg["rIDNextRID"] = ldb.MessageElement(
                str(pool_hi), ldb.FLAG_MOD_REPLACE, "rIDNextRID")
            self.samdb.modify(msg)

            # Computing the next free RID must fail ...
            with self.assertRaises(ldb.LdbError) as err:
                self.samdb.next_free_rid()
            self.assertEqual("RID pools out of RIDs", err.exception.args[1])

            # ... yet allocating a new RID must still work.
            self.samdb.allocate_rid()
        finally:
            self.samdb.transaction_cancel()

    def test_ridalloc_new_ridset(self):
        # Zero the RID Set (similar to a freshly created one, except that
        # rIDAllocationPool is zeroed as well), then install a new pool.
        self.samdb.transaction_start()
        try:
            zeroed = ldb.Message()
            zeroed.dn = self.rid_set_dn
            for attr in ("rIDPreviousAllocationPool", "rIDAllocationPool",
                         "rIDNextRID"):
                zeroed[attr] = ldb.MessageElement("0", ldb.FLAG_MOD_REPLACE,
                                                  attr)
            self.samdb.modify(zeroed)

            with self.assertRaises(ldb.LdbError) as err:
                self.samdb.next_free_rid()
            self.assertEqual("RID pools out of RIDs", err.exception.args[1])

            pool_lo, pool_hi = 2000, 2999
            fresh = ldb.Message()
            fresh.dn = self.rid_set_dn
            fresh["rIDAllocationPool"] = ldb.MessageElement(
                str((pool_hi << 32) | pool_lo), ldb.FLAG_MOD_REPLACE,
                "rIDAllocationPool")
            self.samdb.modify(fresh)

            # The next free RID is now the new pool's lower bound ...
            first_free = self.samdb.next_free_rid()
            self.assertEqual(pool_lo, first_free)

            # ... and the whole new pool is reported as available.
            free_lo, free_hi = self.samdb.free_rid_bounds()
            self.assertEqual(pool_lo, free_lo)
            self.assertEqual(pool_hi, free_hi)
        finally:
            self.samdb.transaction_cancel()

    def test_ridalloc_move_to_new_pool(self):
        # Leave a single RID in the previous pool and check the allocator
        # uses it before switching over to the new pool.
        self.samdb.transaction_start()
        try:
            old_lo, old_hi = 2000, 2999
            new_lo, new_hi = 4500, 4599
            msg = ldb.Message()
            msg.dn = self.rid_set_dn
            msg["rIDPreviousAllocationPool"] = ldb.MessageElement(
                str((old_hi << 32) | old_lo), ldb.FLAG_MOD_REPLACE,
                "rIDPreviousAllocationPool")
            msg["rIDAllocationPool"] = ldb.MessageElement(
                str((new_hi << 32) | new_lo), ldb.FLAG_MOD_REPLACE,
                "rIDAllocationPool")
            msg["rIDNextRID"] = ldb.MessageElement(
                str(old_hi - 1), ldb.FLAG_MOD_REPLACE, "rIDNextRID")
            self.samdb.modify(msg)

            # Still in the previous pool: only old_hi remains.
            last_old = self.samdb.next_free_rid()
            self.assertEqual(old_hi, last_old)

            free_lo, free_hi = self.samdb.free_rid_bounds()
            self.assertEqual(old_hi, free_lo)
            self.assertEqual(old_hi, free_hi)

            # Allocating it uses up the previous pool ...
            allocated = self.samdb.allocate_rid()
            self.assertEqual(last_old, allocated)

            # ... so the allocator moves to the start of the new pool.
            first_new = self.samdb.next_free_rid()
            self.assertEqual(new_lo, first_new)

            free_lo, free_hi = self.samdb.free_rid_bounds()
            self.assertEqual(new_lo, free_lo)
            self.assertEqual(new_hi, free_hi)

            allocated = self.samdb.allocate_rid()
            self.assertEqual(first_new, allocated)
        finally:
            self.samdb.transaction_cancel()

    def test_ridalloc_no_ridsetreferences(self):
        # Remove rIDSetReferences from the server reference object and check
        # the errors from next_free_rid() and allocate_rid().
        self.samdb.transaction_start()
        try:
            msg = ldb.Message()
            msg.dn = self.server_ref_dn
            msg["rIDSetReferences"] = ldb.MessageElement(
                [], ldb.FLAG_MOD_DELETE, "rIDSetReferences")
            self.samdb.modify(msg)

            with self.assertRaises(ldb.LdbError) as err:
                self.samdb.next_free_rid()
            code, text = err.exception.args
            self.assertEqual(ldb.ERR_NO_SUCH_ATTRIBUTE, code)
            expected = ("No RID Set DN - "
                        "Cannot find attribute rIDSetReferences of %s "
                        "to calculate reference dn" % self.server_ref_dn)
            self.assertIn(expected, text)

            # The expected error shows allocate_rid() trying to add a RID
            # Set at the DN that already exists.
            with self.assertRaises(ldb.LdbError) as err:
                self.samdb.allocate_rid()
            code, text = err.exception.args
            self.assertEqual(ldb.ERR_ENTRY_ALREADY_EXISTS, code)
            expected = ("No RID Set DN - "
                        "Failed to add RID Set %s - "
                        "Entry %s already exists"
                        % (self.rid_set_dn, self.rid_set_dn))
            self.assertIn(expected, text)
        finally:
            self.samdb.transaction_cancel()

    def test_ridalloc_no_rid_set(self):
        # Point rIDSetReferences at an object that is not a RID Set (the
        # test user) and check both RID operations reject it.
        self.samdb.transaction_start()
        try:
            bogus_ref = self.account_dn
            msg = ldb.Message()
            msg.dn = self.server_ref_dn
            msg["rIDSetReferences"] = ldb.MessageElement(
                bogus_ref, ldb.FLAG_MOD_REPLACE, "rIDSetReferences")
            self.samdb.modify(msg)

            expected = "Bad RID Set " + bogus_ref
            for operation in (self.samdb.next_free_rid,
                              self.samdb.allocate_rid):
                with self.assertRaises(ldb.LdbError) as err:
                    operation()
                code, text = err.exception.args
                self.assertEqual(ldb.ERR_OPERATIONS_ERROR, code)
                self.assertIn(expected, text)
        finally:
            self.samdb.transaction_cancel()

    def test_get_oid_from_attrid(self):
        # attid 591614 is expected to map to OID 1.2.840.113556.1.4.1790.
        self.assertEqual(self.samdb.get_oid_from_attid(591614),
                         "1.2.840.113556.1.4.1790")

    def test_error_replpropertymetadata(self):
        """Bumping description's metadata version alone must be rejected.

        The user's replPropertyMetaData is rewritten with the version of
        attid 13 (description) incremented but no matching attribute change.
        Even with the local_oid 1.3.6.1.4.1.7165.4.3.14 control, the modify
        must raise LdbError.
        """
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        for o in repl.ctr.array:
            # attid 13 is description
            if o.attid == 13:
                o.version += 1
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        self.assertRaises(ldb.LdbError, self.samdb.modify, msg,
                          ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_error_replpropertymetadata_nochange(self):
        # Writing back the unmodified replPropertyMetaData blob with only
        # the ...4.3.14 control must still raise LdbError.
        found = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                  base=self.account_dn,
                                  attrs=["replPropertyMetaData"])
        entry = found[0]
        decoded = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                             entry["replPropertyMetaData"][0])
        blob = ndr_pack(decoded)
        change = ldb.Message()
        change.dn = entry.dn
        change["replPropertyMetaData"] = ldb.MessageElement(
            blob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        self.assertRaises(ldb.LdbError, self.samdb.modify, change,
                          ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_error_replpropertymetadata_allow_sort(self):
        # The same unmodified write-back succeeds once the ...4.3.25 control
        # is passed alongside ...4.3.14.
        found = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                  base=self.account_dn,
                                  attrs=["replPropertyMetaData"])
        entry = found[0]
        decoded = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                             entry["replPropertyMetaData"][0])
        blob = ndr_pack(decoded)
        change = ldb.Message()
        change.dn = entry.dn
        change["replPropertyMetaData"] = ldb.MessageElement(
            blob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        controls = [
            "local_oid:1.3.6.1.4.1.7165.4.3.14:0",
            "local_oid:1.3.6.1.4.1.7165.4.3.25:0",
        ]
        self.samdb.modify(change, controls)

    def test_twoatt_replpropertymetadata(self):
        """Replacing metadata and another attribute together must fail.

        The description (attid 13) metadata entry gets a bumped version and
        local_usn = uSNChanged + 1, and the same modify also replaces
        description. With the ...4.3.14 control, this is expected to raise
        LdbError.
        """
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData", "uSNChanged"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        # Loop-invariant: parse uSNChanged once.
        next_usn = int(str(res[0]["uSNChanged"])) + 1
        for o in repl.ctr.array:
            # attid 13 is description
            if o.attid == 13:
                o.version += 1
                o.local_usn = next_usn
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        msg["description"] = ldb.MessageElement("new val",
                                                ldb.FLAG_MOD_REPLACE,
                                                "description")
        self.assertRaises(ldb.LdbError, self.samdb.modify, msg,
                          ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_set_replpropertymetadata(self):
        """A consistent replPropertyMetaData rewrite must be accepted.

        For description (attid 13) the version is bumped and both local_usn
        and originating_usn are set to uSNChanged + 1. The modify with the
        ...4.3.14 control is expected to succeed.
        """
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData", "uSNChanged"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        # Loop-invariant: parse uSNChanged once.
        next_usn = int(str(res[0]["uSNChanged"])) + 1
        for o in repl.ctr.array:
            # attid 13 is description
            if o.attid == 13:
                o.version += 1
                o.local_usn = next_usn
                o.originating_usn = next_usn
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        self.samdb.modify(msg, ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_ok_get_attribute_from_attid(self):
        # attid 13 resolves to the "description" attribute.
        attr_name = self.samdb.get_attribute_from_attid(13)
        self.assertEqual(attr_name, "description")

    def test_ko_get_attribute_from_attid(self):
        """An unknown attid (11979) must resolve to None."""
        self.assertIsNone(self.samdb.get_attribute_from_attid(11979))

    def test_get_attribute_replmetadata_version(self):
        # The test user's unicodePwd metadata is expected at version 2.
        found = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                  base=self.account_dn,
                                  attrs=["dn"])
        self.assertEqual(len(found), 1)
        user_dn = str(found[0].dn)
        version = self.samdb.get_attribute_replmetadata_version(
            user_dn, "unicodePwd")
        self.assertEqual(version, 2)

    def test_set_attribute_replmetadata_version(self):
        # Raise description's metadata version by two and read it back.
        found = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                  base=self.account_dn,
                                  attrs=["dn"])
        self.assertEqual(len(found), 1)
        user_dn = str(found[0].dn)
        read_version = self.samdb.get_attribute_replmetadata_version
        target = read_version(user_dn, "description") + 2
        self.samdb.set_attribute_replmetadata_version(user_dn, "description",
                                                      target)
        self.assertEqual(read_version(user_dn, "description"), target)

    def test_no_error_on_invalid_control(self):
        """A non-critical (":0") unimplemented control must be ignored."""
        try:
            self.samdb.search(
                scope=ldb.SCOPE_SUBTREE,
                base=self.account_dn,
                attrs=["replPropertyMetaData"],
                controls=[
                    "local_oid:%s:0" %
                    dsdb.DSDB_CONTROL_INVALID_NOT_IMPLEMENTED
                ])
        except ldb.LdbError:
            self.fail("Should have not raised an exception")

    def test_error_on_invalid_critical_control(self):
        """A critical (":1") unimplemented control must make search fail.

        The search is expected to raise LdbError with
        ERR_UNSUPPORTED_CRITICAL_EXTENSION. Any other error code, or no
        error at all, fails the test.
        """
        try:
            self.samdb.search(
                scope=ldb.SCOPE_SUBTREE,
                base=self.account_dn,
                attrs=["replPropertyMetaData"],
                controls=[
                    "local_oid:%s:1" %
                    dsdb.DSDB_CONTROL_INVALID_NOT_IMPLEMENTED
                ])
        except ldb.LdbError as e:
            # Python 3 exceptions are not subscriptable; unpack e.args.
            (code, estr) = e.args
            if code != ldb.ERR_UNSUPPORTED_CRITICAL_EXTENSION:
                self.fail(
                    "Got %s should have got ERR_UNSUPPORTED_CRITICAL_EXTENSION"
                    % estr)
        else:
            self.fail("Search succeeded, should have got "
                      "ERR_UNSUPPORTED_CRITICAL_EXTENSION")

    def allocate_rid(self):
        """Allocate a unique RID for use in the objectSID tests.

        The allocation runs in its own transaction, which is committed on
        success. If allocate_rid() raises, the transaction is cancelled and
        the exception propagates.

        :returns: the allocated RID as a string.
        """
        self.samdb.transaction_start()
        try:
            rid = self.samdb.allocate_rid()
        except BaseException:
            # Explicit BaseException (same as the former bare except) so
            # that even KeyboardInterrupt cancels the open transaction.
            self.samdb.transaction_cancel()
            raise
        self.samdb.transaction_commit()
        return str(rid)

    def test_duplicate_objectSIDs_allowed_on_foreign_security_principals(self):
        """Exercise the add rules for foreignSecurityPrincipal objects.

        A SID outside the current domain is built by changing the last
        character of the domain SID and appending RID 1000. Then:

        * a plain add without objectSid must fail with
          ERR_OBJECT_CLASS_VIOLATION (WERR_DS_MISSING_REQUIRED_ATT);
        * a plain add with objectSid must fail with ERR_UNWILLING_TO_PERFORM
          (WERR_DS_ILLEGAL_MOD_OPERATION);
        * with the "provision:0" control the object can be added, deleted
          and added again at the same DN.

        NOTE(review): despite the test name, the adds that succeed carry no
        objectSid; confirm this still covers the intended case.
        """
        dom_sid = str(self.samdb.get_domain_sid())
        # Change the final character so the SID is not in the current domain.
        last = "9" if dom_sid.endswith("0") else "0"
        sid_str = dom_sid[:-1] + last + "-1000"
        sid = ndr_pack(security.dom_sid(sid_str))
        basedn = self.samdb.get_default_basedn()
        dn = "CN=%s,CN=ForeignSecurityPrincipals,%s" % (sid_str, basedn)

        # Remove the object even if an assertion fails part-way through.
        # NOTE(review): assumes delete_force ignores an already-deleted
        # object, since the test also deletes dn explicitly at the end.
        self.addCleanup(delete_force, self.samdb, dn)

        # First without the provision control: both adds must be refused.
        try:
            self.samdb.add({
                "dn": dn,
                "objectClass": "foreignSecurityPrincipal"
            })
            self.fail("No exception should get ERR_OBJECT_CLASS_VIOLATION")
        except ldb.LdbError as e:
            (code, msg) = e.args
            self.assertEqual(code, ldb.ERR_OBJECT_CLASS_VIOLATION, str(e))
            werr = "%08X" % werror.WERR_DS_MISSING_REQUIRED_ATT
            self.assertIn(werr, msg)

        try:
            self.samdb.add({
                "dn": dn,
                "objectClass": "foreignSecurityPrincipal",
                "objectSid": sid
            })
            self.fail("No exception should get ERR_UNWILLING_TO_PERFORM")
        except ldb.LdbError as e:
            (code, msg) = e.args
            self.assertEqual(code, ldb.ERR_UNWILLING_TO_PERFORM, str(e))
            werr = "%08X" % werror.WERR_DS_ILLEGAL_MOD_OPERATION
            self.assertIn(werr, msg)

        # The provision control is needed to add foreignSecurityPrincipal
        # objects.
        controls = ["provision:0"]
        self.samdb.add({
            "dn": dn,
            "objectClass": "foreignSecurityPrincipal"
        }, controls=controls)

        self.samdb.delete(dn)

        # Re-adding at the same DN after the delete must succeed.
        try:
            self.samdb.add({
                "dn": dn,
                "objectClass": "foreignSecurityPrincipal"
            }, controls=controls)
        except ldb.LdbError as e:
            (code, msg) = e.args
            self.fail("Got unexpected exception %d - %s " % (code, msg))

        # cleanup
        self.samdb.delete(dn)

    def _test_foreignSecurityPrincipal(self, obj_class, fpo_attr):
        """Check how ``fpo_attr`` handles <SID=...> references to unknown SIDs.

        An object of ``obj_class`` is created and three SIDs that do not
        exist in the directory are added to its ``fpo_attr`` attribute:

        * a SID in the local domain must fail with ERR_UNWILLING_TO_PERFORM
          and WERR_DS_INVALID_GROUP_TYPE in the error string;
        * a SID under S-1-5-32 must fail with ERR_NO_SUCH_OBJECT and
          WERR_NO_SUCH_MEMBER;
        * a SID under S-1-5 must be accepted, after which exactly one object
          with that objectSid must exist below the base DN.

        :param obj_class: objectClass of the object to create.
        :param fpo_attr: attribute on that object the references are added to.
        """
        dom_sid = self.samdb.get_domain_sid()
        lsid_str = str(dom_sid) + "-4294967294"
        bsid_str = "S-1-5-32-4294967294"
        fsid_str = "S-1-5-4294967294"
        basedn = self.samdb.get_default_basedn()
        cn = "dsdb_test_fpo"
        dn_str = "cn=%s,cn=Users,%s" % (cn, basedn)
        dn = ldb.Dn(self.samdb, dn_str)

        def search_by_sid(sid_str):
            # All objects below basedn whose objectSid equals sid_str.
            return self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                     base=basedn,
                                     expression="(objectSid=%s)" % sid_str,
                                     attrs=[])

        def sid_reference(sid_str):
            # Modify message adding <SID=sid_str> to fpo_attr of the test object.
            m = ldb.Message()
            m.dn = dn
            m[fpo_attr] = ldb.MessageElement("<SID=%s>" % sid_str,
                                             ldb.FLAG_MOD_ADD, fpo_attr)
            return m

        # None of the test SIDs may be present before we start.
        for sid_str in (lsid_str, bsid_str, fsid_str):
            self.assertEqual(len(search_by_sid(sid_str)), 0)

        self.addCleanup(delete_force, self.samdb, dn_str)

        self.samdb.add({"dn": dn_str, "objectClass": obj_class})

        m = sid_reference(lsid_str)
        try:
            self.samdb.modify(m)
            self.fail("No exception should get LDB_ERR_UNWILLING_TO_PERFORM")
        except ldb.LdbError as e:
            (code, errmsg) = e.args
            self.assertEqual(code, ldb.ERR_UNWILLING_TO_PERFORM, str(e))
            werr = "%08X" % werror.WERR_DS_INVALID_GROUP_TYPE
            self.assertTrue(werr in errmsg, errmsg)

        m = sid_reference(bsid_str)
        try:
            self.samdb.modify(m)
            self.fail("No exception should get LDB_ERR_NO_SUCH_OBJECT")
        except ldb.LdbError as e:
            (code, errmsg) = e.args
            self.assertEqual(code, ldb.ERR_NO_SUCH_OBJECT, str(e))
            werr = "%08X" % werror.WERR_NO_SUCH_MEMBER
            self.assertTrue(werr in errmsg, errmsg)

        m = sid_reference(fsid_str)
        try:
            self.samdb.modify(m)
        except ldb.LdbError as e:
            # Report the server's error so an unexpected rejection is diagnosable.
            self.fail("Should have not raised an exception: %s" % e)

        res = search_by_sid(fsid_str)
        self.assertEqual(len(res), 1)
        self.samdb.delete(res[0].dn)
        self.samdb.delete(dn)
        self.assertEqual(len(search_by_sid(fsid_str)), 0)

    def test_foreignSecurityPrincipal_member(self):
        """Run _test_foreignSecurityPrincipal on the member attribute of a group."""
        self._test_foreignSecurityPrincipal(obj_class="group",
                                            fpo_attr="member")

    def test_foreignSecurityPrincipal_MembersForAzRole(self):
        """Run _test_foreignSecurityPrincipal on msDS-MembersForAzRole of an msDS-AzRole."""
        self._test_foreignSecurityPrincipal(obj_class="msDS-AzRole",
                                            fpo_attr="msDS-MembersForAzRole")

    def test_foreignSecurityPrincipal_NeverRevealGroup(self):
        """Run _test_foreignSecurityPrincipal on msDS-NeverRevealGroup of a computer."""
        self._test_foreignSecurityPrincipal(obj_class="computer",
                                            fpo_attr="msDS-NeverRevealGroup")

    def test_foreignSecurityPrincipal_RevealOnDemandGroup(self):
        """Run _test_foreignSecurityPrincipal on msDS-RevealOnDemandGroup of a computer."""
        self._test_foreignSecurityPrincipal(obj_class="computer",
                                            fpo_attr="msDS-RevealOnDemandGroup")

    def _test_fail_foreignSecurityPrincipal(self,
                                            obj_class,
                                            fpo_attr,
                                            msg_exp,
                                            lerr_exp,
                                            werr_exp,
                                            allow_reference=True):
        """Check that ``fpo_attr`` rejects <SID=...> references to unknown SIDs.

        Two objects of ``obj_class`` are created. Adding a local-domain, an
        S-1-5-32 or an S-1-5 SID (none of which exist) to ``fpo_attr`` on the
        first object must fail with ``lerr_exp`` and an error string that
        contains ``werr_exp`` as 8 hex digits. Adding the second object's DN
        must then succeed when ``allow_reference`` is true and fail in the
        same way otherwise.

        :param obj_class: objectClass of both test objects.
        :param fpo_attr: attribute the references are added to.
        :param msg_exp: human-readable name of the expected error; used only
            in failure messages.
        :param lerr_exp: expected LDB error code.
        :param werr_exp: expected WERROR value, searched for in the error text.
        :param allow_reference: whether a plain DN reference must be accepted.
        """
        dom_sid = self.samdb.get_domain_sid()
        lsid_str = str(dom_sid) + "-4294967294"
        bsid_str = "S-1-5-32-4294967294"
        fsid_str = "S-1-5-4294967294"
        basedn = self.samdb.get_default_basedn()
        cn1 = "dsdb_test_fpo1"
        dn1_str = "cn=%s,cn=Users,%s" % (cn1, basedn)
        dn1 = ldb.Dn(self.samdb, dn1_str)
        cn2 = "dsdb_test_fpo2"
        dn2_str = "cn=%s,cn=Users,%s" % (cn2, basedn)
        dn2 = ldb.Dn(self.samdb, dn2_str)

        # None of the test SIDs may be present before we start.
        for sid_str in (lsid_str, bsid_str, fsid_str):
            res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                    base=basedn,
                                    expression="(objectSid=%s)" % sid_str,
                                    attrs=[])
            self.assertEqual(len(res), 0)

        self.addCleanup(delete_force, self.samdb, dn1_str)
        self.addCleanup(delete_force, self.samdb, dn2_str)

        self.samdb.add({"dn": dn1_str, "objectClass": obj_class})

        self.samdb.add({"dn": dn2_str, "objectClass": obj_class})

        def add_value_msg(value):
            # Modify message adding `value` to fpo_attr of the first object.
            m = ldb.Message()
            m.dn = dn1
            m[fpo_attr] = ldb.MessageElement(value, ldb.FLAG_MOD_ADD, fpo_attr)
            return m

        def assert_rejected(m):
            # The modify must fail with lerr_exp/werr_exp. msg_exp (not the
            # Message object) names the expected error in the failure text.
            try:
                self.samdb.modify(m)
                self.fail("No exception should get %s" % msg_exp)
            except ldb.LdbError as e:
                (code, errmsg) = e.args
                self.assertEqual(code, lerr_exp, str(e))
                werr = "%08X" % werr_exp
                self.assertTrue(werr in errmsg, errmsg)

        for sid_str in (lsid_str, bsid_str, fsid_str):
            assert_rejected(add_value_msg("<SID=%s>" % sid_str))

        m = add_value_msg("%s" % dn2)
        if allow_reference:
            try:
                self.samdb.modify(m)
            except ldb.LdbError as e:
                self.fail("Should have not raised an exception: %s" % e)
        else:
            assert_rejected(m)

        self.samdb.delete(dn2)
        self.samdb.delete(dn1)

    def test_foreignSecurityPrincipal_NonMembers(self):
        """msDS-NonMembers on a group must reject SID and DN references alike."""
        self._test_fail_foreignSecurityPrincipal(
            obj_class="group",
            fpo_attr="msDS-NonMembers",
            msg_exp="LDB_ERR_UNWILLING_TO_PERFORM/WERR_NOT_SUPPORTED",
            lerr_exp=ldb.ERR_UNWILLING_TO_PERFORM,
            werr_exp=werror.WERR_NOT_SUPPORTED,
            allow_reference=False)

    def test_foreignSecurityPrincipal_HostServiceAccount(self):
        """msDS-HostServiceAccount on a computer must reject unknown SID references."""
        self._test_fail_foreignSecurityPrincipal(
            obj_class="computer",
            fpo_attr="msDS-HostServiceAccount",
            msg_exp="LDB_ERR_CONSTRAINT_VIOLATION/WERR_DS_NAME_REFERENCE_INVALID",
            lerr_exp=ldb.ERR_CONSTRAINT_VIOLATION,
            werr_exp=werror.WERR_DS_NAME_REFERENCE_INVALID)

    def test_foreignSecurityPrincipal_manager(self):
        """manager on a user must reject unknown SID references."""
        self._test_fail_foreignSecurityPrincipal(
            obj_class="user",
            fpo_attr="manager",
            msg_exp="LDB_ERR_CONSTRAINT_VIOLATION/WERR_DS_NAME_REFERENCE_INVALID",
            lerr_exp=ldb.ERR_CONSTRAINT_VIOLATION,
            werr_exp=werror.WERR_DS_NAME_REFERENCE_INVALID)

    def test_duplicate_objectSIDs_not_allowed_on_local_objects(self):
        """Duplicate objectSIDs must not be permitted for local-domain SIDs.

        The test sequence is: add an object with an explicit local-domain
        objectSID, delete it, then attempt to re-add it with the same
        objectSID. The re-add must fail with a constraint violation.
        """
        dom_sid = self.samdb.get_domain_sid()
        rid = self.allocate_rid()
        # str() keeps this working whether allocate_rid() yields str or int.
        sid_str = str(dom_sid) + "-" + str(rid)
        sid = ndr_pack(security.dom_sid(sid_str))
        basedn = self.samdb.get_default_basedn()
        cn = "dsdb_test_01"
        dn = "cn=%s,cn=Users,%s" % (cn, basedn)

        # Make sure the object is removed even if the re-add below
        # unexpectedly succeeds and the test fails.
        self.addCleanup(delete_force, self.samdb, dn)

        self.samdb.add({"dn": dn, "objectClass": "user", "objectSID": sid})
        self.samdb.delete(dn)

        try:
            self.samdb.add({"dn": dn, "objectClass": "user", "objectSID": sid})
            self.fail("No exception should get LDB_ERR_CONSTRAINT_VIOLATION")
        except ldb.LdbError as e:
            (code, msg) = e.args
            if code != ldb.ERR_CONSTRAINT_VIOLATION:
                self.fail("Got %d - %s should have got "
                          "LDB_ERR_CONSTRAINT_VIOLATION" % (code, msg))

    def test_linked_vs_non_linked_reference(self):
        """Compare reference checks on a linked vs. a non-linked attribute.

        Setup:

        * One user is kept.
        * A second user is created and then deleted.

        Expected behaviour:

        * The linked attribute 'manager' rejects <SID=...> and <GUID=...>
          references to the deleted user.
        * The non-linked attribute 'assistant' accepts those references, and
          they can be removed again.
        * 'assistant' rejects a DN, a SID and a GUID that never existed.

        Every rejection must be ERR_CONSTRAINT_VIOLATION with
        WERR_DS_NAME_REFERENCE_INVALID in the error string.
        """
        basedn = self.samdb.get_default_basedn()
        kept_dn_str = "cn=reference_kept,cn=Users,%s" % (basedn)
        removed_dn_str = "cn=reference_removed,cn=Users,%s" % (basedn)
        dom_sid = self.samdb.get_domain_sid()
        none_sid_str = str(dom_sid) + "-4294967294"
        none_guid_str = "afafafaf-fafa-afaf-fafa-afafafafafaf"

        self.addCleanup(delete_force, self.samdb, kept_dn_str)
        self.addCleanup(delete_force, self.samdb, removed_dn_str)

        self.samdb.add({"dn": kept_dn_str, "objectClass": "user"})
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=kept_dn_str,
                                attrs=["objectGUID", "objectSID"])
        self.assertEqual(len(res), 1)
        kept_dn = res[0].dn

        self.samdb.add({"dn": removed_dn_str, "objectClass": "user"})
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=removed_dn_str,
                                attrs=["objectGUID", "objectSID"])
        self.assertEqual(len(res), 1)
        removed_guid = ndr_unpack(misc.GUID, res[0]["objectGUID"][0])
        removed_sid = ndr_unpack(security.dom_sid, res[0]["objectSid"][0])
        self.samdb.delete(removed_dn_str)

        def ref_msg(attr, value, flag):
            # Modify message applying `flag` with `value` to `attr` of kept_dn.
            m = ldb.Message()
            m.dn = kept_dn
            m[attr] = ldb.MessageElement(value, flag, attr)
            return m

        def assert_invalid_reference(attr, value):
            # Adding `value` to `attr` must fail with a name-reference error.
            m = ref_msg(attr, value, ldb.FLAG_MOD_ADD)
            try:
                self.samdb.modify(m)
                self.fail("No exception should get LDB_ERR_CONSTRAINT_VIOLATION")
            except ldb.LdbError as e:
                (code, errmsg) = e.args
                self.assertEqual(code, ldb.ERR_CONSTRAINT_VIOLATION, str(e))
                werr = "%08X" % werror.WERR_DS_NAME_REFERENCE_INVALID
                self.assertTrue(werr in errmsg, errmsg)

        removed_sid_ref = "<SID=%s>" % removed_sid
        removed_guid_ref = "<GUID=%s>" % removed_guid

        #
        # First try the linked attribute 'manager'
        # by GUID and SID
        #
        assert_invalid_reference("manager", removed_sid_ref)
        assert_invalid_reference("manager", removed_guid_ref)

        #
        # Try the non-linked attribute 'assistant'
        # by GUID and SID, which should work.
        #
        for ref in (removed_sid_ref, removed_guid_ref):
            self.samdb.modify(ref_msg("assistant", ref, ldb.FLAG_MOD_ADD))
            self.samdb.modify(ref_msg("assistant", ref, ldb.FLAG_MOD_DELETE))

        #
        # Finally try the non-linked attribute 'assistant'
        # but with non existing GUID, SID, DN
        #
        assert_invalid_reference("assistant", "CN=NoneNone,%s" % (basedn))
        assert_invalid_reference("assistant", "<SID=%s>" % none_sid_str)
        assert_invalid_reference("assistant", "<GUID=%s>" % none_guid_str)

        self.samdb.delete(kept_dn)

    def test_normalize_dn_in_domain_full(self):
        """A fully-qualified DN string normalizes to the equivalent Dn."""
        domain_dn = self.samdb.domain_dn()

        expected = ldb.Dn(self.samdb, "CN=Users")
        expected.add_base(domain_dn)

        normalized = self.samdb.normalize_dn_in_domain(str(expected))

        # That is, no change
        self.assertEqual(expected, normalized)

    def test_normalize_dn_in_domain_part(self):
        """A relative DN string gets the domain DN appended."""
        relative = "CN=Users"
        domain_dn = self.samdb.domain_dn()

        expected = ldb.Dn(self.samdb, relative)
        expected.add_base(domain_dn)

        # That is, the domain DN appended
        self.assertEqual(expected,
                         self.samdb.normalize_dn_in_domain(relative))

    def test_normalize_dn_in_domain_full_dn(self):
        """A fully-qualified Dn object normalizes to an equal Dn."""
        domain_dn = self.samdb.domain_dn()

        qualified = ldb.Dn(self.samdb, "CN=Users")
        qualified.add_base(domain_dn)

        # That is, no change
        self.assertEqual(qualified,
                         self.samdb.normalize_dn_in_domain(qualified))

    def test_normalize_dn_in_domain_part_dn(self):
        """A relative Dn object gets the domain DN appended."""
        domain_dn = self.samdb.domain_dn()
        relative_dn = ldb.Dn(self.samdb, "CN=Users")

        expected = ldb.Dn(self.samdb, "%s,%s" % (relative_dn, domain_dn))

        # That is, the domain DN appended
        self.assertEqual(expected,
                         self.samdb.normalize_dn_in_domain(relative_dn))
# Example #8
# 0
class DynamicTokenTest(samba.tests.TestCase):
    """Compare server-provided token groups with locally computed ones.

    setUp performs these steps:

    * creates a test user;
    * makes the user a member of three new groups (domain local, global and
      universal);
    * connects as that user;
    * stores the string SIDs of the user's session security token in
      ``self.user_sids``.
    """

    def get_creds(self, target_username, target_password):
        """Return Credentials for ``target_username`` with sealing requested.

        Domain, realm and workstation are copied from the module-level
        ``creds`` object.
        """
        creds_tmp = Credentials()
        creds_tmp.set_username(target_username)
        creds_tmp.set_password(target_password)
        creds_tmp.set_domain(creds.get_domain())
        creds_tmp.set_realm(creds.get_realm())
        creds_tmp.set_workstation(creds.get_workstation())
        creds_tmp.set_gensec_features(creds_tmp.get_gensec_features() | gensec.FEATURE_SEAL)
        return creds_tmp

    def get_ldb_connection(self, target_username, target_password):
        """Return a SamDB connection to ``url`` authenticated as the given user."""
        creds_tmp = self.get_creds(target_username, target_password)
        ldb_target = SamDB(url=url, credentials=creds_tmp, lp=lp)
        return ldb_target

    def _create_group_with_test_user(self, group_name, grouptype):
        """Create ``group_name`` with ``grouptype`` and add the test user to it.

        The new group is looked up as CN=<group_name>,CN=Users,<base DN>.

        :return: the group's objectSid as a security.dom_sid.
        """
        self.admin_ldb.newgroup(group_name, grouptype=grouptype)
        res = self.admin_ldb.search(
            base="cn={0!s},cn=users,{1!s}".format(group_name, self.base_dn),
            attrs=["objectSid"],
            scope=ldb.SCOPE_BASE,
        )
        group_sid = ndr_unpack(samba.dcerpc.security.dom_sid, res[0]["objectSid"][0])

        self.admin_ldb.add_remove_group_members(group_name, [self.test_user], add_members_operation=True)
        return group_sid

    def setUp(self):
        super(DynamicTokenTest, self).setUp()
        self.admin_ldb = SamDB(url, credentials=creds, session_info=system_session(lp), lp=lp)

        self.base_dn = self.admin_ldb.domain_dn()

        self.test_user = "******"
        self.test_user_pass = "******"
        self.admin_ldb.newuser(self.test_user, self.test_user_pass)

        self.test_group0 = "tokengroups_group0"
        self.test_group0_sid = self._create_group_with_test_user(
            self.test_group0, dsdb.GTYPE_SECURITY_DOMAIN_LOCAL_GROUP)

        self.test_group1 = "tokengroups_group1"
        self.test_group1_sid = self._create_group_with_test_user(
            self.test_group1, dsdb.GTYPE_SECURITY_GLOBAL_GROUP)

        self.test_group2 = "tokengroups_group2"
        self.test_group2_sid = self._create_group_with_test_user(
            self.test_group2, dsdb.GTYPE_SECURITY_UNIVERSAL_GROUP)

        self.ldb = self.get_ldb_connection(self.test_user, self.test_user_pass)

        res = self.ldb.search("", scope=ldb.SCOPE_BASE, attrs=["tokenGroups"])
        self.assertEqual(len(res), 1)

        # The first rootDSE tokenGroups value is used as the user's own SID.
        # NOTE(review): relies on the server listing that SID first — confirm.
        self.user_sid_dn = "<SID={0!s}>".format(
            str(ndr_unpack(samba.dcerpc.security.dom_sid, res[0]["tokenGroups"][0]))
        )

        res = self.ldb.search(self.user_sid_dn, scope=ldb.SCOPE_BASE, attrs=[])
        self.assertEqual(len(res), 1)

        self.test_user_dn = res[0].dn

        session_info_flags = (
            AUTH_SESSION_INFO_DEFAULT_GROUPS | AUTH_SESSION_INFO_AUTHENTICATED | AUTH_SESSION_INFO_SIMPLE_PRIVILEGES
        )
        session = samba.auth.user_session(
            self.ldb, lp_ctx=lp, dn=self.user_sid_dn, session_info_flags=session_info_flags
        )

        token = session.security_token
        self.user_sids = [str(s) for s in token.sids]

    def tearDown(self):
        super(DynamicTokenTest, self).tearDown()
        for name in (self.test_user, self.test_group0, self.test_group1, self.test_group2):
            delete_force(self.admin_ldb, "CN={0!s},{1!s},{2!s}".format(name, "cn=users", self.base_dn))

    def test_rootDSE_tokenGroups(self):
        """Testing rootDSE tokengroups against internal calculation"""
        if not url.startswith("ldap"):
            self.fail(msg="This test is only valid on ldap")

        res = self.ldb.search("", scope=ldb.SCOPE_BASE, attrs=["tokenGroups"])
        self.assertEqual(len(res), 1)

        print("Getting tokenGroups from rootDSE")
        tokengroups = [str(ndr_unpack(samba.dcerpc.security.dom_sid, sid))
                       for sid in res[0]["tokenGroups"]]

        # Every rootDSE tokenGroups SID must also be in the session token.
        extra = set(tokengroups).difference(set(self.user_sids))
        if extra:
            print("token sids don't match")
            print("tokengroups: {0!s}".format(tokengroups))
            print("calculated : {0!s}".format(self.user_sids))
            print("difference : {0!s}".format(extra))
            self.fail(msg="calculated groups don't match against rootDSE tokenGroups")

    def test_dn_tokenGroups(self):
        """Every tokenGroups SID on the user object must be in the session token."""
        print("Getting tokenGroups from user DN")
        res = self.ldb.search(self.user_sid_dn, scope=ldb.SCOPE_BASE, attrs=["tokenGroups"])
        self.assertEqual(len(res), 1)

        dn_tokengroups = [str(ndr_unpack(samba.dcerpc.security.dom_sid, sid))
                          for sid in res[0]["tokenGroups"]]

        extra = set(dn_tokengroups).difference(set(self.user_sids))
        if extra:
            print("token sids don't match")
            print("difference : {0!s}".format(extra))
            self.fail(msg="calculated groups don't match against user DN tokenGroups")

    def test_pac_groups(self):
        """Compare SIDs from a GSSAPI session with ``self.user_sids``.

        The exchange works as follows:

        * A gensec client (test user credentials) and a gensec server
          (machine account credentials) are set up.
        * They exchange tokens directly in this process.
        * Every SID in the server's resulting session token must also be in
          ``self.user_sids``.
        """
        settings = {}
        settings["lp_ctx"] = lp
        settings["target_hostname"] = lp.get("netbios name")

        gensec_client = gensec.Security.start_client(settings)
        gensec_client.set_credentials(self.get_creds(self.test_user, self.test_user_pass))
        gensec_client.want_feature(gensec.FEATURE_SEAL)
        gensec_client.start_mech_by_sasl_name("GSSAPI")

        auth_context = AuthContext(lp_ctx=lp, ldb=self.ldb, methods=[])

        gensec_server = gensec.Security.start_server(settings, auth_context)
        machine_creds = Credentials()
        machine_creds.guess(lp)
        machine_creds.set_machine_account(lp)
        gensec_server.set_credentials(machine_creds)

        gensec_server.want_feature(gensec.FEATURE_SEAL)
        gensec_server.start_mech_by_sasl_name("GSSAPI")

        client_finished = False
        server_finished = False
        # gensec tokens are binary; b"" is the same as "" on Python 2 and is
        # the bytes value Python 3 requires.
        server_to_client = b""

        # Run the actual call loop.
        while not client_finished and not server_finished:
            if not client_finished:
                print("running client gensec_update")
                (client_finished, client_to_server) = gensec_client.update(server_to_client)
            if not server_finished:
                print("running server gensec_update")
                (server_finished, server_to_client) = gensec_server.update(client_to_server)

        session = gensec_server.session_info()

        token = session.security_token
        pac_sids = [str(s) for s in token.sids]

        extra = set(pac_sids).difference(set(self.user_sids))
        if extra:
            print("token sids don't match")
            print("difference : {0!s}".format(extra))
            self.fail(msg="calculated groups don't match against user PAC tokenGroups")

    def test_tokenGroups_manual(self):
        """Compare tokenGroups with a manual computation.

        Manually run the tokenGroups algorithm from MS-ADTS 3.1.1.4.5.19 and
        MS-DRSR 4.1.8.3 and compare the result.

        The membership graph has an edge from each user/group to every
        memberOf target and to its primaryGroupID group. The test user's
        reachable set, with the user itself removed, must equal the
        user's tokenGroups, compared as casefolded DNs.
        """
        res = self.admin_ldb.search(
            base=self.base_dn,
            scope=ldb.SCOPE_SUBTREE,
            expression="(|(objectclass=user)(objectclass=group))",
            attrs=["memberOf"],
        )
        aSet = set()
        vSet = set()
        for obj in res:
            if "memberOf" in obj:
                for dn in obj["memberOf"]:
                    first = obj.dn.get_casefold()
                    second = ldb.Dn(self.admin_ldb, dn).get_casefold()
                    aSet.add((first, second))
                    vSet.add(first)
                    vSet.add(second)

        res = self.admin_ldb.search(
            base=self.base_dn, scope=ldb.SCOPE_SUBTREE, expression="(objectclass=user)", attrs=["primaryGroupID"]
        )
        for obj in res:
            if "primaryGroupID" in obj:
                sid = "{0!s}-{1:d}".format(self.admin_ldb.get_domain_sid(), int(obj["primaryGroupID"][0]))
                res2 = self.admin_ldb.search(base="<SID={0!s}>".format(sid), scope=ldb.SCOPE_BASE, attrs=[])
                first = obj.dn.get_casefold()
                second = res2[0].dn.get_casefold()

                aSet.add((first, second))
                vSet.add(first)
                vSet.add(second)

        wSet = set()
        wSet.add(self.test_user_dn.get_casefold())
        # closure() is a module-level helper; the code below relies on it
        # extending wSet in place.
        closure(vSet, wSet, aSet)
        wSet.remove(self.test_user_dn.get_casefold())

        tokenGroupsSet = set()

        res = self.ldb.search(self.user_sid_dn, scope=ldb.SCOPE_BASE, attrs=["tokenGroups"])
        self.assertEqual(len(res), 1)

        for sid in res[0]["tokenGroups"]:
            sid = ndr_unpack(samba.dcerpc.security.dom_sid, sid)
            res3 = self.admin_ldb.search(base="<SID={0!s}>".format(sid), scope=ldb.SCOPE_BASE, attrs=[])
            tokenGroupsSet.add(res3[0].dn.get_casefold())

        if wSet.difference(tokenGroupsSet):
            self.fail(msg="additional calculated: {0!s}".format(wSet.difference(tokenGroupsSet)))

        if tokenGroupsSet.difference(wSet):
            self.fail(msg="additional tokenGroups: {0!s}".format(tokenGroupsSet.difference(wSet)))

    def filtered_closure(self, wSet, filter_grouptype):
        """Extend ``wSet`` over the membership graph, restricted by group type.

        The graph is built from memberOf and primaryGroupID, as in
        test_tokenGroups_manual. Every user/group object is a vertex. Vertices
        that are groups are kept only if their groupType equals
        ``filter_grouptype``. ``wSet`` is passed to the module-level
        ``closure()`` helper; callers rely on it being extended in place.
        """
        res = self.admin_ldb.search(
            base=self.base_dn,
            scope=ldb.SCOPE_SUBTREE,
            expression="(|(objectclass=user)(objectclass=group))",
            attrs=["memberOf"],
        )
        aSet = set()
        vSet = set()
        for obj in res:
            vSet.add(obj.dn.get_casefold())
            if "memberOf" in obj:
                for dn in obj["memberOf"]:
                    first = obj.dn.get_casefold()
                    second = ldb.Dn(self.admin_ldb, dn).get_casefold()
                    aSet.add((first, second))
                    vSet.add(first)
                    vSet.add(second)

        res = self.admin_ldb.search(
            base=self.base_dn, scope=ldb.SCOPE_SUBTREE, expression="(objectclass=user)", attrs=["primaryGroupID"]
        )
        for obj in res:
            if "primaryGroupID" in obj:
                sid = "{0!s}-{1:d}".format(self.admin_ldb.get_domain_sid(), int(obj["primaryGroupID"][0]))
                res2 = self.admin_ldb.search(base="<SID={0!s}>".format(sid), scope=ldb.SCOPE_BASE, attrs=[])
                first = obj.dn.get_casefold()
                second = res2[0].dn.get_casefold()

                aSet.add((first, second))
                vSet.add(first)
                vSet.add(second)

        # Compare group types as unsigned 32-bit integers on both sides: the
        # previous hex()-string comparison never matched when
        # filter_grouptype was negative (signed) or a Python 2 long ("...L").
        wanted_grouptype = filter_grouptype & 0xFFFFFFFF
        uSet = set()
        for v in vSet:
            res_group = self.admin_ldb.search(
                base=v, scope=ldb.SCOPE_BASE, attrs=["groupType"], expression="objectClass=group"
            )
            if len(res_group) == 1:
                if int(res_group[0]["groupType"][0]) & 0xFFFFFFFF == wanted_grouptype:
                    uSet.add(v)
            else:
                uSet.add(v)

        closure(uSet, wSet, aSet)

    def test_tokenGroupsGlobalAndUniversal_manual(self):
        """Compare tokenGroupsGlobalAndUniversal with a manual computation.

        Manually run the tokenGroups algorithm from MS-ADTS 3.1.1.4.5.19 and
        MS-DRSR 4.1.8.3 and compare the result.
        """
        # The variable names come from MS-ADTS May 15, 2014

        S = set()
        S.add(self.test_user_dn.get_casefold())

        self.filtered_closure(S, GTYPE_SECURITY_GLOBAL_GROUP)

        T = set()
        # Not really a SID, we do this on DNs...
        for sid in S:
            X = set()
            X.add(sid)
            self.filtered_closure(X, GTYPE_SECURITY_UNIVERSAL_GROUP)

            T |= X

        T.remove(self.test_user_dn.get_casefold())

        tokenGroupsSet = set()

        res = self.ldb.search(self.user_sid_dn, scope=ldb.SCOPE_BASE, attrs=["tokenGroupsGlobalAndUniversal"])
        self.assertEqual(len(res), 1)

        for sid in res[0]["tokenGroupsGlobalAndUniversal"]:
            sid = ndr_unpack(samba.dcerpc.security.dom_sid, sid)
            res3 = self.admin_ldb.search(base="<SID={0!s}>".format(sid), scope=ldb.SCOPE_BASE, attrs=[])
            tokenGroupsSet.add(res3[0].dn.get_casefold())

        if T.difference(tokenGroupsSet):
            self.fail(msg="additional calculated: {0!s}".format(T.difference(tokenGroupsSet)))

        if tokenGroupsSet.difference(T):
            self.fail(msg="additional tokenGroupsGlobalAndUniversal: {0!s}".format(tokenGroupsSet.difference(T)))
# Example #9
# 0
    def test_rid_set_dbcheck_after_seize(self):
        """Perform a join against the RID manager and assert we have a RID Set.
        We seize the RID master role, then using dbcheck, we assert that we can
        detect out of range users (and then bump the RID set as required)."""

        fsmo_dn = ldb.Dn(
            self.ldb_dc1,
            "CN=RID Manager$,CN=System," + self.ldb_dc1.domain_dn())
        (fsmo_owner, fsmo_not_owner) = self._determine_fSMORoleOwner(fsmo_dn)

        targetdir = self._test_join(fsmo_owner['dns_name'], "RIDALLOCTEST7")
        try:
            # Connect to the database
            ldb_url = "tdb://%s" % os.path.join(targetdir, "private/sam.ldb")
            smbconf = os.path.join(targetdir, "etc/smb.conf")

            lp = self.get_loadparm()
            new_ldb = SamDB(ldb_url,
                            credentials=self.get_credentials(),
                            session_info=system_session(lp),
                            lp=lp)

            # 1. Get server name
            res = new_ldb.search(base=ldb.Dn(new_ldb,
                                             new_ldb.get_serverName()),
                                 scope=ldb.SCOPE_BASE,
                                 attrs=["serverReference"])
            # 2. Get server reference
            server_ref_dn = ldb.Dn(new_ldb, res[0]['serverReference'][0])

            # 3. Assert we get the RID Set
            res = new_ldb.search(base=server_ref_dn,
                                 scope=ldb.SCOPE_BASE,
                                 attrs=['rIDSetReferences'])

            self.assertIn("rIDSetReferences", res[0])
            rid_set_dn = ldb.Dn(new_ldb, res[0]["rIDSetReferences"][0])

            def rid_pool_upper_bound():
                # High 32 bits of the RID Set's current rIDAllocationPool.
                rid_set_res = new_ldb.search(
                    base=rid_set_dn,
                    scope=ldb.SCOPE_BASE,
                    attrs=['rIDNextRid', 'rIDAllocationPool'])
                next_pool = int(rid_set_res[0]["rIDAllocationPool"][0])
                return (0xFFFFFFFF00000000 & next_pool) >> 32

            def add_user_with_rid(cn, rid):
                # Add CN=<cn>,CN=Users,<base DN> with objectSid
                # <domain SID>-<rid>. The relax control is passed;
                # presumably needed to supply an explicit objectSid — confirm.
                m = ldb.Message()
                m.dn = ldb.Dn(new_ldb, "CN=%s,CN=Users" % cn)
                m.dn.add_base(new_ldb.get_default_basedn())
                m['objectClass'] = ldb.MessageElement('user', ldb.FLAG_MOD_ADD,
                                                      'objectClass')
                m['objectSid'] = ldb.MessageElement(
                    ndr_pack(
                        security.dom_sid(
                            str(new_ldb.get_domain_sid()) + "-%d" % rid)),
                    ldb.FLAG_MOD_ADD, 'objectSid')
                new_ldb.add(m, controls=["relax:0"])

            # 4. Seize the RID Manager role
            (result, out, err) = self.runsubcmd("fsmo", "seize", "--role",
                                                "rid", "-H", ldb_url, "-s",
                                                smbconf, "--force")
            self.assertCmdSuccess(result, out, err)
            self.assertEqual(err, "", "Shouldn't be any error messages")

            # 5. Add a new user (triggers RID set work)
            new_ldb.newuser("ridalloctestuser", "P@ssword!")

            # 6. Now fetch the RID SET
            last_rid = rid_pool_upper_bound()

            # 7. Add user above the ridNextRid and at almost the end of the range.
            add_user_with_rid("ridsettestuser2", last_rid - 3)

            # 8. Add user above the ridNextRid and at the end of the range
            add_user_with_rid("ridsettestuser3", last_rid)

            chk = dbcheck(new_ldb,
                          verbose=False,
                          fix=True,
                          yes=True,
                          quiet=True)

            # Should have fixed two errors (wrong ridNextRid)
            self.assertEqual(
                chk.check_database(DN=rid_set_dn, scope=ldb.SCOPE_BASE), 2)

            # 9. Assert that dbcheck no longer reports errors on the RID Set.
            # (Previously the checker was constructed but never run.)
            chk = dbcheck(new_ldb, verbose=False, fix=False, quiet=True)
            self.assertEqual(
                chk.check_database(DN=rid_set_dn, scope=ldb.SCOPE_BASE), 0)

            # 10. Add another user (checks RID rollover)
            # We have seized the role, so we can do that.
            new_ldb.newuser("ridalloctestuser3", "P@ssword!")

            self.assertNotEqual(last_rid,
                                rid_pool_upper_bound(),
                                "rid pool should have changed")
        finally:
            self._test_force_demote(fsmo_owner['dns_name'], "RIDALLOCTEST7")
            shutil.rmtree(targetdir, ignore_errors=True)
# Example #10
# 0
class UserAccountControlTests(samba.tests.TestCase):
    """Check which userAccountControl changes an unprivileged user may make.

    setUp creates an unprivileged user and an OU (test_computer_ou1) in
    which that user is temporarily granted Create-Child on computer objects;
    the test then checks that the user cannot turn a computer it created
    into a DC or RODC account.
    """

    def add_computer_ldap(self, computername, others=None, samdb=None):
        """Add computer *computername* under OU=test_computer_ou1.

        :param computername: CN of the new object; its sAMAccountName is
            ``computername + "$"``.
        :param others: optional dict of extra attributes for the add message;
            its entries override ``dn``/``objectclass`` on key collision.
        :param samdb: connection to add through; defaults to ``self.samdb``.
        """
        if samdb is None:
            samdb = self.samdb
        dn = "CN=%s,OU=test_computer_ou1,%s" % (computername, self.base_dn)
        msg_dict = {
            "dn": dn,
            "objectclass": "computer"}
        if others is not None:
            # dict.items() returns a view on Python 3 and views cannot be
            # concatenated with "+", so convert to lists first.
            msg_dict = dict(list(msg_dict.items()) + list(others.items()))

        msg = ldb.Message.from_dict(samdb, msg_dict)
        msg["sAMAccountName"] = "%s$" % computername

        print("Adding computer account %s" % computername)
        samdb.add(msg)

    def get_creds(self, target_username, target_password):
        """Build NTLM (non-Kerberos), sealed credentials for the given user.

        Domain, realm and workstation are copied from the module-level
        ``creds``.
        """
        creds_tmp = Credentials()
        creds_tmp.set_username(target_username)
        creds_tmp.set_password(target_password)
        creds_tmp.set_domain(creds.get_domain())
        creds_tmp.set_realm(creds.get_realm())
        creds_tmp.set_workstation(creds.get_workstation())
        creds_tmp.set_gensec_features(creds_tmp.get_gensec_features()
                                      | gensec.FEATURE_SEAL)
        creds_tmp.set_kerberos_state(DONT_USE_KERBEROS)  # kinit is too expensive to use in a tight loop
        return creds_tmp

    def setUp(self):
        super(UserAccountControlTests, self).setUp()
        self.admin_creds = creds
        self.admin_samdb = SamDB(url=ldaphost,
                                 session_info=system_session(),
                                 credentials=self.admin_creds, lp=lp)

        self.unpriv_user = "******"
        self.unpriv_user_pw = "samba123@"
        self.unpriv_creds = self.get_creds(self.unpriv_user, self.unpriv_user_pw)

        self.admin_samdb.newuser(self.unpriv_user, self.unpriv_user_pw)
        res = self.admin_samdb.search("CN=%s,CN=Users,%s" % (self.unpriv_user, self.admin_samdb.domain_dn()),
                                      scope=SCOPE_BASE,
                                      attrs=["objectSid"])
        self.assertEqual(1, len(res))

        self.unpriv_user_sid = ndr_unpack(security.dom_sid, res[0]["objectSid"][0])
        self.unpriv_user_dn = res[0].dn

        self.samdb = SamDB(url=ldaphost, credentials=self.unpriv_creds, lp=lp)
        self.domain_sid = security.dom_sid(self.samdb.get_domain_sid())
        self.base_dn = self.samdb.domain_dn()

        self.samr = samr.samr("ncacn_ip_tcp:%s[sign]" % host, lp, self.unpriv_creds)
        self.samr_handle = self.samr.Connect2(None, security.SEC_FLAG_MAXIMUM_ALLOWED)
        self.samr_domain = self.samr.OpenDomain(self.samr_handle, security.SEC_FLAG_MAXIMUM_ALLOWED, self.domain_sid)

        self.sd_utils = sd_utils.SDUtils(self.admin_samdb)

        # Temporarily grant the unprivileged user Create-Child on the test OU
        # so it can create the template computer account, then put the OU's
        # original security descriptor back.
        self.admin_samdb.create_ou("OU=test_computer_ou1," + self.base_dn)
        self.unpriv_user_sid = self.sd_utils.get_object_sid(self.unpriv_user_dn)
        mod = "(OA;;CC;bf967a86-0de6-11d0-a285-00aa003049e2;;%s)" % str(self.unpriv_user_sid)

        old_sd = self.sd_utils.read_sd_on_dn("OU=test_computer_ou1," + self.base_dn)

        self.sd_utils.dacl_add_ace("OU=test_computer_ou1," + self.base_dn, mod)

        self.add_computer_ldap("testcomputer-t")

        self.sd_utils.modify_sd_on_dn("OU=test_computer_ou1," + self.base_dn, old_sd)

        self.computernames = ["testcomputer-0"]

        # Get the SD of the template account, then force it to match
        # what we expect for SeMachineAccountPrivilege accounts, so we
        # can confirm we created the accounts correctly
        self.sd_reference_cc = self.sd_utils.read_sd_on_dn("CN=testcomputer-t,OU=test_computer_ou1,%s" % (self.base_dn))

        self.sd_reference_modify = self.sd_utils.read_sd_on_dn("CN=testcomputer-t,OU=test_computer_ou1,%s" % (self.base_dn))
        for ace in self.sd_reference_modify.dacl.aces:
            if ace.type == security.SEC_ACE_TYPE_ACCESS_ALLOWED and ace.trustee == self.unpriv_user_sid:
                ace.access_mask = ace.access_mask | security.SEC_ADS_SELF_WRITE | security.SEC_ADS_WRITE_PROP

        # Now reconnect without domain admin rights
        self.samdb = SamDB(url=ldaphost, credentials=self.unpriv_creds, lp=lp)

    def tearDown(self):
        super(UserAccountControlTests, self).tearDown()
        for computername in self.computernames:
            delete_force(self.admin_samdb, "CN=%s,OU=test_computer_ou1,%s" % (computername, self.base_dn))
        delete_force(self.admin_samdb, "CN=testcomputer-t,OU=test_computer_ou1,%s" % (self.base_dn))
        delete_force(self.admin_samdb, "OU=test_computer_ou1,%s" % (self.base_dn))
        delete_force(self.admin_samdb, "CN=%s,CN=Users,%s" % (self.unpriv_user, self.base_dn))

    def test_add_computer_sd_cc(self):
        """A creator may edit its computer but not make it a DC or RODC."""
        user_sid = self.sd_utils.get_object_sid(self.unpriv_user_dn)
        mod = "(OA;;CC;bf967a86-0de6-11d0-a285-00aa003049e2;;%s)" % str(user_sid)

        # setUp restored the OU's original SD, so grant Create-Child again.
        self.sd_utils.dacl_add_ace("OU=test_computer_ou1," + self.base_dn, mod)

        computername = self.computernames[0]
        sd = ldb.MessageElement((ndr_pack(self.sd_reference_modify)),
                                ldb.FLAG_MOD_ADD,
                                "nTSecurityDescriptor")
        self.add_computer_ldap(computername,
                               others={"nTSecurityDescriptor": sd})

        res = self.admin_samdb.search("%s" % self.base_dn,
                                      expression="(&(objectClass=computer)(samAccountName=%s$))" % computername,
                                      scope=SCOPE_SUBTREE,
                                      attrs=["ntSecurityDescriptor"])

        desc = res[0]["nTSecurityDescriptor"][0]
        desc = ndr_unpack(security.descriptor, desc, allow_remaining=True)

        sddl = desc.as_sddl(self.domain_sid)
        self.assertEqual(self.sd_reference_modify.as_sddl(self.domain_sid), sddl)

        # An ordinary attribute change by the creator must succeed.
        m = ldb.Message()
        m.dn = res[0].dn
        m["description"] = ldb.MessageElement(
            ("A description"), ldb.FLAG_MOD_REPLACE,
            "description")
        self.samdb.modify(m)

        m = ldb.Message()
        m.dn = res[0].dn
        m["userAccountControl"] = ldb.MessageElement(str(samba.dsdb.UF_SERVER_TRUST_ACCOUNT),
                                                     ldb.FLAG_MOD_REPLACE, "userAccountControl")
        try:
            self.samdb.modify(m)
            self.fail("Unexpectedly able to set userAccountControl to be a DC on %s" % m.dn)
        except LdbError as e:
            (enum, estr) = e.args
            self.assertEqual(ldb.ERR_INSUFFICIENT_ACCESS_RIGHTS, enum)

        m = ldb.Message()
        m.dn = res[0].dn
        m["userAccountControl"] = ldb.MessageElement(str(samba.dsdb.UF_WORKSTATION_TRUST_ACCOUNT|samba.dsdb.UF_PARTIAL_SECRETS_ACCOUNT),
                                                     ldb.FLAG_MOD_REPLACE, "userAccountControl")
        try:
            self.samdb.modify(m)
            self.fail("Unexpectedly able to set userAccountControl to be an RODC on %s" % m.dn)
        except LdbError as e:
            (enum, estr) = e.args
            self.assertEqual(ldb.ERR_INSUFFICIENT_ACCESS_RIGHTS, enum)
Example #11
0
    def run(self,
            new_domain_name,
            new_dns_realm,
            sambaopts=None,
            credopts=None,
            server=None,
            targetdir=None,
            keep_dns_realm=False,
            no_secrets=False):
        """Back up a remote domain while renaming it.

        Clones the domain from *server* into a temporary directory created
        under *targetdir*. The clone is renamed to *new_domain_name* /
        *new_dns_realm* and marked as a backup. Everything is then bundled,
        together with the remote DC's sysvol, into a backup tar file in
        *targetdir*.

        :raises CommandError: if the new NetBIOS name or DNS realm equals the
            current one.
        """
        logger = self.get_logger()
        logger.setLevel(logging.INFO)

        # Make sure we have all the required args.
        check_online_backup_args(logger, credopts, server, targetdir)
        delete_old_dns = not keep_dns_realm

        new_dns_realm = new_dns_realm.lower()
        new_domain_name = new_domain_name.upper()

        new_base_dn = samba.dn_from_dns_name(new_dns_realm)
        logger.info("New realm for backed up domain: %s" % new_dns_realm)
        logger.info("New base DN for backed up domain: %s" % new_base_dn)
        logger.info("New domain NetBIOS name: %s" % new_domain_name)

        tmpdir = tempfile.mkdtemp(dir=targetdir)
        # The working directory holds the cloned database (including secrets
        # unless no_secrets is set). Remove it on every exit path, including
        # when the clone or any later step raises, not only on success.
        try:
            # setup a join-context for cloning the remote server
            lp = sambaopts.get_loadparm()
            creds = credopts.get_credentials(lp)
            include_secrets = not no_secrets
            ctx = DCCloneAndRenameContext(new_base_dn,
                                          new_domain_name,
                                          new_dns_realm,
                                          logger=logger,
                                          creds=creds,
                                          lp=lp,
                                          include_secrets=include_secrets,
                                          dns_backend='SAMBA_INTERNAL',
                                          server=server,
                                          targetdir=tmpdir)

            # sanity-check we're not "renaming" the domain to the same values
            old_domain = ctx.domain_name
            if old_domain == new_domain_name:
                raise CommandError("Cannot use the current domain NetBIOS name.")

            old_realm = ctx.realm
            if old_realm == new_dns_realm:
                raise CommandError("Cannot use the current domain DNS realm.")

            # do the clone/rename
            ctx.do_join()

            # get the paths used for the clone, then drop the old samdb
            # connection
            del ctx.local_samdb
            paths = ctx.paths

            # get a free RID to use as the new DC's SID (when it gets restored)
            remote_sam = SamDB(url='ldap://' + server,
                               credentials=creds,
                               session_info=system_session(),
                               lp=lp)
            new_sid = get_sid_for_restore(remote_sam)

            # Grab the remote DC's sysvol files and bundle them into a tar
            # file. Note we end up with 2 sysvol dirs - the original domain's
            # files (that use the old realm) backed up here, as well as
            # default files generated for the new realm as part of the
            # clone/join.
            sysvol_tar = os.path.join(tmpdir, 'sysvol.tar.gz')
            smb_conn = smb.SMB(server, "sysvol", lp=lp, creds=creds)
            backup_online(smb_conn, sysvol_tar, remote_sam.get_domain_sid())

            # connect to the local DB (making sure we use the new/renamed
            # config)
            lp.load(paths.smbconf)
            samdb = SamDB(url=paths.samdb, session_info=system_session(), lp=lp)

            # Edit the cloned sam.ldb to mark it as a backup
            time_str = get_timestamp()
            add_backup_marker(samdb, "backupDate", time_str)
            add_backup_marker(samdb, "sidForRestore", new_sid)
            add_backup_marker(samdb, "backupRename", old_realm)

            # fix up the DNS objects that are using the old dnsRoot value
            self.update_dns_root(logger, samdb, old_realm, delete_old_dns)

            # update the netBIOS name and the Partition object for the domain
            self.rename_domain_partition(logger, samdb, new_domain_name)

            if delete_old_dns:
                self.delete_old_dns_zones(logger, samdb, old_realm)

            logger.info("Fixing DN attributes after rename...")
            self.fix_old_dn_attributes(samdb)

            # ensure the admin user always has a password set (same as
            # provision)
            if no_secrets:
                set_admin_password(logger, samdb, creds.get_username())

            # Add everything in the tmpdir to the backup tar file
            backup_file = backup_filepath(targetdir, new_dns_realm, time_str)
            create_log_file(
                tmpdir, lp, "rename", server, include_secrets,
                "Original domain %s (NetBIOS), %s (DNS realm)" %
                (old_domain, old_realm))
            create_backup_tar(logger, tmpdir, backup_file)
        finally:
            shutil.rmtree(tmpdir)
Example #12
0
class DsdbTests(TestCase):
    """Tests of SamDB/dsdb helpers run against a freshly created user."""

    def setUp(self):
        super(DsdbTests, self).setUp()
        self.lp = samba.tests.env_loadparm()
        self.creds = Credentials()
        self.creds.guess(self.lp)
        self.session = system_session()
        self.samdb = SamDB(session_info=self.session,
                           credentials=self.creds,
                           lp=self.lp)

        # Create a test user
        user_name = "dsdb-user-" + str(uuid.uuid4().hex[0:6])
        user_pass = samba.generate_random_password(32, 32)
        user_description = "Test user for dsdb test"

        base_dn = self.samdb.domain_dn()

        self.account_dn = "cn=" + user_name + ",cn=Users," + base_dn
        self.samdb.newuser(username=user_name,
                           password=user_pass,
                           description=user_description)
        # Cleanup (teardown)
        self.addCleanup(delete_force, self.samdb, self.account_dn)

    def test_get_oid_from_attrid(self):
        oid = self.samdb.get_oid_from_attid(591614)
        self.assertEqual(oid, "1.2.840.113556.1.4.1790")

    def test_error_replpropertymetadata(self):
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        for o in repl.ctr.array:
            # Search for Description
            if o.attid == 13:
                o.version = o.version + 1
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        self.assertRaises(ldb.LdbError, self.samdb.modify, msg,
                          ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_error_replpropertymetadata_nochange(self):
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        self.assertRaises(ldb.LdbError, self.samdb.modify, msg,
                          ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_error_replpropertymetadata_allow_sort(self):
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        self.samdb.modify(msg, [
            "local_oid:1.3.6.1.4.1.7165.4.3.14:0",
            "local_oid:1.3.6.1.4.1.7165.4.3.25:0"
        ])

    def test_twoatt_replpropertymetadata(self):
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData", "uSNChanged"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        for o in repl.ctr.array:
            # Search for Description
            if o.attid == 13:
                o.version = o.version + 1
                o.local_usn = int(str(res[0]["uSNChanged"])) + 1
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        msg["description"] = ldb.MessageElement("new val",
                                                ldb.FLAG_MOD_REPLACE,
                                                "description")
        self.assertRaises(ldb.LdbError, self.samdb.modify, msg,
                          ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_set_replpropertymetadata(self):
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData", "uSNChanged"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        for o in repl.ctr.array:
            # Search for Description
            if o.attid == 13:
                o.version = o.version + 1
                o.local_usn = int(str(res[0]["uSNChanged"])) + 1
                o.originating_usn = int(str(res[0]["uSNChanged"])) + 1
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        self.samdb.modify(msg, ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_ok_get_attribute_from_attid(self):
        self.assertEqual(self.samdb.get_attribute_from_attid(13),
                         "description")

    def test_ko_get_attribute_from_attid(self):
        self.assertEqual(self.samdb.get_attribute_from_attid(11979), None)

    def test_get_attribute_replmetadata_version(self):
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["dn"])
        self.assertEqual(len(res), 1)
        dn = str(res[0].dn)
        self.assertEqual(
            self.samdb.get_attribute_replmetadata_version(dn, "unicodePwd"), 2)

    def test_set_attribute_replmetadata_version(self):
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["dn"])
        self.assertEqual(len(res), 1)
        dn = str(res[0].dn)
        version = self.samdb.get_attribute_replmetadata_version(
            dn, "description")
        self.samdb.set_attribute_replmetadata_version(dn, "description",
                                                      version + 2)
        self.assertEqual(
            self.samdb.get_attribute_replmetadata_version(dn, "description"),
            version + 2)

    def test_no_error_on_invalid_control(self):
        # A non-critical (":0") unimplemented control must be ignored.
        try:
            self.samdb.search(
                scope=ldb.SCOPE_SUBTREE,
                base=self.account_dn,
                attrs=["replPropertyMetaData"],
                controls=[
                    "local_oid:%s:0" %
                    dsdb.DSDB_CONTROL_INVALID_NOT_IMPLEMENTED
                ])
        except ldb.LdbError:
            self.fail("Should have not raised an exception")

    def test_error_on_invalid_critical_control(self):
        # A critical (":1") unimplemented control must be rejected with
        # ERR_UNSUPPORTED_CRITICAL_EXTENSION; the search succeeding is a
        # failure too.
        try:
            self.samdb.search(
                scope=ldb.SCOPE_SUBTREE,
                base=self.account_dn,
                attrs=["replPropertyMetaData"],
                controls=[
                    "local_oid:%s:1" %
                    dsdb.DSDB_CONTROL_INVALID_NOT_IMPLEMENTED
                ])
            self.fail("Should have got ERR_UNSUPPORTED_CRITICAL_EXTENSION")
        except ldb.LdbError as e:
            (code, estr) = e.args
            if code != ldb.ERR_UNSUPPORTED_CRITICAL_EXTENSION:
                # Exceptions are not indexable on Python 3, so use the
                # unpacked message rather than e[1].
                self.fail(
                    "Got %s should have got ERR_UNSUPPORTED_CRITICAL_EXTENSION"
                    % estr)

    # Allocate a unique RID for use in the objectSID tests.
    #
    def allocate_rid(self):
        """Allocate a RID inside a transaction and return it as a string."""
        self.samdb.transaction_start()
        try:
            rid = self.samdb.allocate_rid()
        except BaseException:
            # Cancel on any failure (including KeyboardInterrupt), then
            # propagate it unchanged.
            self.samdb.transaction_cancel()
            raise
        self.samdb.transaction_commit()
        return str(rid)

    # Ensure that duplicate objectSID's are permitted for foreign security
    # principals.
    #
    def test_duplicate_objectSIDs_allowed_on_foreign_security_principals(self):

        #
        # We need to build a foreign security principal SID
        # i.e. a SID not in the current domain.
        #
        dom_sid = self.samdb.get_domain_sid()
        if str(dom_sid).endswith("0"):
            c = "9"
        else:
            c = "0"
        sid = str(dom_sid)[:-1] + c + "-1000"
        basedn = self.samdb.get_default_basedn()
        dn = "CN=%s,CN=ForeignSecurityPrincipals,%s" % (sid, basedn)
        self.samdb.add({"dn": dn, "objectClass": "foreignSecurityPrincipal"})

        self.samdb.delete(dn)

        try:
            self.samdb.add({
                "dn": dn,
                "objectClass": "foreignSecurityPrincipal"
            })
        except ldb.LdbError as e:
            (code, msg) = e.args
            self.fail("Got unexpected exception %d - %s " % (code, msg))

        # cleanup
        self.samdb.delete(dn)

    #
    # Duplicate objectSID's should not be permitted for sids in the local
    # domain. The test sequence is add an object, delete it, then attempt to
    # re-add it, this should fail with a constraint violation
    #
    def test_duplicate_objectSIDs_not_allowed_on_local_objects(self):

        dom_sid = self.samdb.get_domain_sid()
        rid = self.allocate_rid()
        sid_str = str(dom_sid) + "-" + rid
        sid = ndr_pack(security.dom_sid(sid_str))
        basedn = self.samdb.get_default_basedn()
        cn = "dsdb_test_01"
        dn = "cn=%s,cn=Users,%s" % (cn, basedn)

        self.samdb.add({"dn": dn, "objectClass": "user", "objectSID": sid})
        self.samdb.delete(dn)

        try:
            self.samdb.add({"dn": dn, "objectClass": "user", "objectSID": sid})
            self.fail("No exception should get LDB_ERR_CONSTRAINT_VIOLATION")
        except ldb.LdbError as e:
            (code, msg) = e.args
            if code != ldb.ERR_CONSTRAINT_VIOLATION:
                self.fail("Got %d - %s should have got "
                          "LDB_ERR_CONSTRAINT_VIOLATION" % (code, msg))

    def test_normalize_dn_in_domain_full(self):
        domain_dn = self.samdb.domain_dn()

        full_dn = ldb.Dn(self.samdb, "CN=Users")
        full_dn.add_base(domain_dn)

        full_str = str(full_dn)

        # That is, no change
        self.assertEqual(full_dn, self.samdb.normalize_dn_in_domain(full_str))

    def test_normalize_dn_in_domain_part(self):
        domain_dn = self.samdb.domain_dn()

        part_str = "CN=Users"

        full_dn = ldb.Dn(self.samdb, part_str)
        full_dn.add_base(domain_dn)

        # That is, the domain DN appended
        self.assertEqual(full_dn, self.samdb.normalize_dn_in_domain(part_str))

    def test_normalize_dn_in_domain_full_dn(self):
        domain_dn = self.samdb.domain_dn()

        full_dn = ldb.Dn(self.samdb, "CN=Users")
        full_dn.add_base(domain_dn)

        # That is, no change
        self.assertEqual(full_dn, self.samdb.normalize_dn_in_domain(full_dn))

    def test_normalize_dn_in_domain_part_dn(self):
        domain_dn = self.samdb.domain_dn()

        part_dn = ldb.Dn(self.samdb, "CN=Users")

        # That is, the domain DN appended
        self.assertEqual(
            ldb.Dn(self.samdb,
                   str(part_dn) + "," + str(domain_dn)),
            self.samdb.normalize_dn_in_domain(part_dn))
Example #13
0
class PrivAttrsTests(samba.tests.TestCase):
    """Check which privileged attributes can be set on add or modify.

    Test cases are generated by setUpDynamicTestCases from the module-level
    ``attrs`` table.
    """

    def get_creds(self, target_username, target_password):
        """Build NTLM (non-Kerberos), sealed credentials for the given user.

        Domain, realm and workstation are copied from the module-level
        ``creds``.
        """
        creds_tmp = Credentials()
        creds_tmp.set_username(target_username)
        creds_tmp.set_password(target_password)
        creds_tmp.set_domain(creds.get_domain())
        creds_tmp.set_realm(creds.get_realm())
        creds_tmp.set_workstation(creds.get_workstation())
        creds_tmp.set_gensec_features(creds_tmp.get_gensec_features()
                                      | gensec.FEATURE_SEAL)
        creds_tmp.set_kerberos_state(
            DONT_USE_KERBEROS)  # kinit is too expensive to use in a tight loop
        return creds_tmp

    def assertGotLdbError(self, wanted, got):
        """Assert *got* is *wanted*; without strict checking, any failure."""
        if not self.strict_checking:
            self.assertNotEqual(got, ldb.SUCCESS)
        else:
            self.assertEqual(got, wanted)

    def setUp(self):
        super().setUp()

        strict_checking = samba.tests.env_get_var_value('STRICT_CHECKING',
                                                        allow_missing=True)
        if strict_checking is None:
            strict_checking = '1'
        self.strict_checking = bool(int(strict_checking))

        self.admin_creds = creds
        self.admin_samdb = SamDB(url=ldaphost,
                                 credentials=self.admin_creds,
                                 lp=lp)
        self.domain_sid = security.dom_sid(self.admin_samdb.get_domain_sid())
        self.base_dn = self.admin_samdb.domain_dn()

        self.unpriv_user = "******"
        self.unpriv_user_pw = "samba123@"
        self.unpriv_creds = self.get_creds(self.unpriv_user,
                                           self.unpriv_user_pw)

        self.admin_sd_utils = sd_utils.SDUtils(self.admin_samdb)

        self.test_ou_name = "OU=test_priv_attrs"
        self.test_ou = self.test_ou_name + "," + self.base_dn

        # Remove any OU left over from an earlier (possibly aborted) run.
        delete_force(self.admin_samdb,
                     self.test_ou,
                     controls=["tree_delete:0"])

        self.admin_samdb.create_ou(self.test_ou)
        # Also remove the OU (and anything created in it) after this test.
        self.addCleanup(delete_force,
                        self.admin_samdb,
                        self.test_ou,
                        controls=["tree_delete:0"])

        expected_user_dn = f"CN={self.unpriv_user},{self.test_ou_name},{self.base_dn}"

        self.admin_samdb.newuser(self.unpriv_user,
                                 self.unpriv_user_pw,
                                 userou=self.test_ou_name)
        res = self.admin_samdb.search(expected_user_dn,
                                      scope=SCOPE_BASE,
                                      attrs=["objectSid"])

        self.assertEqual(1, len(res))

        self.unpriv_user_dn = res[0].dn
        self.addCleanup(delete_force,
                        self.admin_samdb,
                        self.unpriv_user_dn,
                        controls=["tree_delete:0"])

        self.unpriv_user_sid = self.admin_sd_utils.get_object_sid(
            self.unpriv_user_dn)

        self.unpriv_samdb = SamDB(url=ldaphost,
                                  credentials=self.unpriv_creds,
                                  lp=lp)

    @classmethod
    def setUpDynamicTestCases(cls):
        for test_name in attrs:
            for add_or_mod in ["add", "mod-del-add", "mod-replace"]:
                for permission in ["admin-add", "CC"]:
                    for sd in ["default", "WP"]:
                        for objectclass in ["computer", "user"]:
                            tname = f"{test_name}_{add_or_mod}_{permission}_{sd}_{objectclass}"
                            targs = (test_name, add_or_mod, permission, sd,
                                     objectclass)
                            cls.generate_dynamic_test("test_priv_attr", tname,
                                                      *targs)

    def _add_account_ldap(self, name, objectclass, others, samdb):
        """Add an *objectclass* object CN=*name* in the test OU via *samdb*.

        sAMAccountName is ``name + "$"``; entries in *others* (if given) are
        added to, and override, the base ``dn``/``objectclass`` keys.
        Returns the new object's DN. The message is printed on failure.
        """
        msg_dict = {"dn": "CN=%s,%s" % (name, self.test_ou),
                    "objectclass": objectclass}
        if others is not None:
            msg_dict.update(others)

        msg = ldb.Message.from_dict(samdb, msg_dict)
        msg["sAMAccountName"] = "%s$" % name

        print("Adding %s account %s" % (objectclass, name))
        try:
            samdb.add(msg)
        except ldb.LdbError:
            print(msg)
            raise
        return msg.dn

    def add_computer_ldap(self, computername, others=None, samdb=None):
        """Add computer *computername* in the test OU; return its DN."""
        return self._add_account_ldap(computername, "computer", others, samdb)

    def add_user_ldap(self, username, others=None, samdb=None):
        """Add user *username* in the test OU; return its DN."""
        return self._add_account_ldap(username, "user", others, samdb)

    def add_thing_ldap(self, user, others, samdb, objectclass):
        """Add a "user" or "computer" object named *user*; return its DN.

        :raises ValueError: for any other *objectclass*.
        """
        if objectclass == "user":
            return self.add_user_ldap(user, others, samdb=samdb)
        if objectclass == "computer":
            return self.add_computer_ldap(user, others, samdb=samdb)
        raise ValueError("unsupported objectclass %r" % (objectclass,))

    def _assert_add_fails(self, user, others, samdb, objectclass, wanted,
                          failure_msg):
        """Assert that adding *user* fails with LDB error *wanted*.

        Fails with *failure_msg* if the add unexpectedly succeeds.
        """
        try:
            self.add_thing_ldap(user, others, samdb, objectclass)
            self.fail(failure_msg)
        except LdbError as e:
            (enum, estr) = e.args
            self.assertGotLdbError(wanted, enum)

    def _test_priv_attr_with_args(self, test_name, add_or_mod, permission, sd,
                                  objectclass):
        """Run one generated privileged-attribute test case.

        :param test_name: key into ``attrs``.
        :param add_or_mod: "add" sets the attribute on add; "mod-del-add" and
            "mod-replace" set it by a later modify as the unprivileged user.
        :param permission: "CC" adds as the unprivileged user after granting
            it Create-Child on the OU; otherwise the admin adds.
        :param sd: "WP" supplies an SD granting the unprivileged user RP/WP.
        :param objectclass: "user" or "computer".
        """
        user = "******"
        attr_info = attrs[test_name]
        if "attr" in attr_info:
            attr = attr_info["attr"]
        else:
            attr = test_name
        if add_or_mod == "add":
            others = {attr: attr_info["value"]}
        else:
            others = {}

        if permission == "CC":
            samdb = self.unpriv_samdb
            # Set CC on container to allow user add (the two GUIDs are the
            # user and computer class schemaIDGUIDs).
            for class_guid in ("bf967aba-0de6-11d0-a285-00aa003049e2",
                               "bf967a86-0de6-11d0-a285-00aa003049e2"):
                mod = "(OA;CI;CC;%s;;%s)" % (class_guid,
                                             str(self.unpriv_user_sid))
                self.admin_sd_utils.dacl_add_ace(self.test_ou, mod)
        else:
            samdb = self.admin_samdb

        if sd == "WP":
            # Set SD to WP to the target user as part of add
            sddl = "O:%sG:DUD:(OA;CIID;RPWP;;;%s)(OA;;CR;00299570-246d-11d0-a768-00aa006e0529;;%s)" % (
                self.unpriv_user_sid, self.unpriv_user_sid,
                self.unpriv_user_sid)
            tmp_desc = security.descriptor.from_sddl(sddl, self.domain_sid)
            others["ntSecurityDescriptor"] = ndr_pack(tmp_desc)

        add_failed_msg = f"Failed to add account {user} as objectclass {objectclass}"

        if add_or_mod == "add":
            # only-1 and only-2 are due to windows behaviour
            if ("only-1" in attr_info
                    and attr_info["only-1"] != objectclass):
                self._assert_add_fails(
                    user, others, samdb, objectclass,
                    ldb.ERR_OBJECT_CLASS_VIOLATION,
                    f"{test_name}: Unexpectedly able to set {attr} on new {objectclass} as ADMIN (should fail LDAP_OBJECT_CLASS_VIOLATION)")
            elif permission == "CC":
                if "unpriv-add-error" in attr_info:
                    wanted = attr_info["unpriv-add-error"]
                else:
                    wanted = attr_info["unpriv-error"]
                self._assert_add_fails(
                    user, others, samdb, objectclass, wanted,
                    f"{test_name}: Unexpectedly able to set {attr} on new {objectclass}")
            elif ("only-2" in attr_info
                    and attr_info["only-2"] != objectclass):
                self._assert_add_fails(
                    user, others, samdb, objectclass,
                    ldb.ERR_OBJECT_CLASS_VIOLATION,
                    f"{test_name}: Unexpectedly able to set {attr} on new {objectclass} as ADMIN (should fail LDAP_OBJECT_CLASS_VIOLATION)")
            elif "priv-error" in attr_info:
                self._assert_add_fails(
                    user, others, samdb, objectclass,
                    attr_info["priv-error"],
                    f"{test_name}: Unexpectedly able to set {attr} on new {objectclass} as ADMIN")
            else:
                try:
                    self.add_thing_ldap(user, others, samdb, objectclass)
                except LdbError:
                    self.fail(add_failed_msg)
            return

        try:
            dn = self.add_thing_ldap(user, others, samdb, objectclass)
        except LdbError:
            self.fail(add_failed_msg)

        m = ldb.Message()
        m.dn = dn

        # Do modify
        if add_or_mod == "mod-del-add":
            m["0"] = ldb.MessageElement([], ldb.FLAG_MOD_DELETE, attr)
            m["1"] = ldb.MessageElement(attr_info["value"],
                                        ldb.FLAG_MOD_ADD, attr)
        else:
            m["0"] = ldb.MessageElement(attr_info["value"],
                                        ldb.FLAG_MOD_REPLACE, attr)

        try:
            self.unpriv_samdb.modify(m)
            self.fail(
                f"{test_name}: Unexpectedly able to set {attr} on {m.dn}")
        except LdbError as e5:
            (enum, estr) = e5.args
            self.assertGotLdbError(attr_info["unpriv-error"], enum)
Example #14
0
class TrafficEmulatorPacketTests(samba.tests.TestCase):
    """Exercise the individual packet generators of the traffic emulator.

    Each test parses one tab-separated trace line with ``Packet.from_line``
    and hands it to the matching ``p.packet_<protocol>_<opcode>`` function,
    asserting the boolean that generator is expected to return.
    """

    def setUp(self):
        """Connect to the test DC and build a ReplayContext for one
        conversation with freshly created machine and user accounts."""
        super(TrafficEmulatorPacketTests, self).setUp()
        self.server = os.environ["SERVER"]
        self.domain = os.environ["DOMAIN"]
        self.host = os.environ["SERVER_IP"]
        self.lp = self.get_loadparm()
        self.session = system_session()
        self.credentials = self.get_credentials()

        self.ldb = SamDB(url="ldap://%s" % self.host,
                         session_info=self.session,
                         credentials=self.credentials,
                         lp=self.lp)
        self.domain_sid = self.ldb.get_domain_sid()

        # Remove any accounts left behind by a previous (aborted) run.
        traffic.clean_up_accounts(self.ldb, 1)
        self.tempdir = tempfile.mkdtemp(prefix="traffic_packet_test_")
        self.context = traffic.ReplayContext(server=self.server,
                                             lp=self.lp,
                                             creds=self.credentials,
                                             tempdir=self.tempdir,
                                             ou=traffic.ou_name(self.ldb, 1),
                                             domain_sid=self.domain_sid)

        self.conversation = traffic.Conversation()
        self.conversation.conversation_id = 1
        self.machinename = "STGM-1-1"
        self.machinepass = samba.generate_random_password(32, 32)
        self.username = "******"
        self.userpass = samba.generate_random_password(32, 32)
        account = traffic.ConversationAccounts(self.machinename,
                                               self.machinepass, self.username,
                                               self.userpass)

        traffic.create_ou(self.ldb, 1)
        traffic.create_machine_account(self.ldb, 1, self.machinename,
                                       self.machinepass)
        traffic.create_user_account(self.ldb, 1, self.username, self.userpass)

        self.context.generate_process_local_config(account, self.conversation)

        # grant user write permission to do things like write account SPN
        sdutils = sd_utils.SDUtils(self.ldb)
        mod = "(A;;WP;;;PS)"
        sdutils.dacl_add_ace(self.context.user_dn, mod)

    def tearDown(self):
        """Remove the emulator accounts and the per-test scratch directory."""
        super(TrafficEmulatorPacketTests, self).tearDown()
        try:
            traffic.clean_up_accounts(self.ldb, 1)
            del self.ldb
        finally:
            # Remove the scratch directory even when account clean-up fails,
            # so repeated failures do not accumulate temp directories.
            shutil.rmtree(self.tempdir)

    def test_packet_cldap_03(self):
        packet = Packet.from_line(
            "0.0\t11\t1\t2\t1\tcldap\t3\tsearchRequest\t")
        self.assertTrue(
            p.packet_cldap_3(packet, self.conversation, self.context))

    def test_packet_cldap_05(self):
        packet = Packet.from_line(
            "0.0\t11\t1\t1\t2\tcldap\t5\tsearchResDone\t")
        self.assertFalse(
            p.packet_cldap_5(packet, self.conversation, self.context))

    def test_packet_dcerpc_00(self):
        packet = Packet.from_line("0.0\t11\t1\t2\t1\tdcerpc\t0\tRequest\t")
        self.assertFalse(
            p.packet_dcerpc_0(packet, self.conversation, self.context))

    def test_packet_dcerpc_02(self):
        packet = Packet.from_line("0.0\t11\t1\t1\t2\tdcerpc\t2\tResponse\t")
        self.assertFalse(
            p.packet_dcerpc_2(packet, self.conversation, self.context))

    def test_packet_dcerpc_03(self):
        packet = Packet.from_line("0.0\t11\t1\t1\t2\tdcerpc\t3\t\t")
        self.assertFalse(
            p.packet_dcerpc_3(packet, self.conversation, self.context))

    def test_packet_dcerpc_11(self):
        packet = Packet.from_line("0.0\t11\t1\t2\t1\tdcerpc\t11\tBind\t")
        self.assertFalse(
            p.packet_dcerpc_11(packet, self.conversation, self.context))

    def test_packet_dcerpc_13(self):
        packet = Packet.from_line("0.0\t11\t1\t2\t1\tdcerpc\t13\t\t")
        self.assertFalse(
            p.packet_dcerpc_13(packet, self.conversation, self.context))

    def test_packet_dcerpc_14(self):
        packet = Packet.from_line(
            "0.0\t11\t1\t2\t1\tdcerpc\t14\tAlter_context\t")
        self.assertFalse(
            p.packet_dcerpc_14(packet, self.conversation, self.context))

    def test_packet_dcerpc_15(self):
        packet = Packet.from_line(
            "0.0\t11\t1\t1\t2\tdcerpc\t15\tAlter_context_resp\t")
        # Set user_creds MUST_USE_KERBEROS to suppress the warning message.
        self.context.user_creds.set_kerberos_state(MUST_USE_KERBEROS)
        self.assertFalse(
            p.packet_dcerpc_15(packet, self.conversation, self.context))

    def test_packet_dcerpc_16(self):
        packet = Packet.from_line("0.0\t11\t1\t1\t2\tdcerpc\t16\tAUTH3\t")
        self.assertFalse(
            p.packet_dcerpc_16(packet, self.conversation, self.context))

    def test_packet_dns_01(self):
        packet = Packet.from_line("0.0\t11\t1\t1\t2\tdns\t1\tresponse\t")
        self.assertFalse(
            p.packet_dns_1(packet, self.conversation, self.context))

    def test_packet_drsuapi_00(self):
        packet = Packet.from_line("0.0\t06\t1\t1\t2\tdrsuapi\t0\tDsBind\t")
        self.assertTrue(
            p.packet_drsuapi_0(packet, self.conversation, self.context))

    def test_packet_drsuapi_01(self):
        packet = Packet.from_line("0.0\t06\t1\t1\t2\tdrsuapi\t1\tDsUnBind\t")
        self.assertTrue(
            p.packet_drsuapi_1(packet, self.conversation, self.context))

    def test_packet_drsuapi_02(self):
        packet = Packet.from_line(
            "0.0\t06\t1\t1\t2\tdrsuapi\t2\tDsReplicaSync\t")
        self.assertFalse(
            p.packet_drsuapi_2(packet, self.conversation, self.context))

    def test_packet_drsuapi_03(self):
        packet = Packet.from_line(
            "0.0\t06\t1\t1\t2\tdrsuapi\t3\tDsGetNCChanges\t")
        self.assertFalse(
            p.packet_drsuapi_3(packet, self.conversation, self.context))

    def test_packet_drsuapi_04(self):
        packet = Packet.from_line(
            "0.0\t06\t1\t1\t2\tdrsuapi\t4\tDsReplicaUpdateRefs\t")
        self.assertFalse(
            p.packet_drsuapi_4(packet, self.conversation, self.context))

    def test_packet_drsuapi_12(self):
        packet = Packet.from_line(
            "0.0\t06\t1\t1\t2\tdrsuapi\t12\tDsCrackNames\t")
        self.assertTrue(
            p.packet_drsuapi_12(packet, self.conversation, self.context))

    def test_packet_drsuapi_13(self):
        packet = Packet.from_line(
            "0.0\t06\t1\t1\t2\tdrsuapi\t13\tDsWriteAccountSpn\t")
        self.assertTrue(
            p.packet_drsuapi_13(packet, self.conversation, self.context))

    def test_packet_epm_03(self):
        packet = Packet.from_line("0.0\t06\t1\t1\t2\tepm\t3\tMap\t")
        self.assertFalse(
            p.packet_epm_3(packet, self.conversation, self.context))

    def test_packet_kerberos(self):
        """Kerberos packets are not generated, but are used as a hint to
        favour kerberos.
        """
        # The generator switches every credential on the context to
        # MUST_USE_KERBEROS, including context.creds, which was built from
        # self.credentials (the admin creds). Subsequent tests fail unless
        # the admin creds are restored, so restore the state they had before
        # this test even if one of the assertions below fails.
        saved_kerberos_state = self.credentials.get_kerberos_state()
        try:
            packet = Packet.from_line("0.0\t11\t1\t1\t2\tkerberos\t\t\t")
            self.assertFalse(
                p.packet_kerberos_(packet, self.conversation, self.context))
            self.assertEqual(MUST_USE_KERBEROS,
                             self.context.user_creds.get_kerberos_state())
            self.assertEqual(MUST_USE_KERBEROS,
                             self.context.user_creds_bad.get_kerberos_state())
            self.assertEqual(MUST_USE_KERBEROS,
                             self.context.machine_creds.get_kerberos_state())
            self.assertEqual(
                MUST_USE_KERBEROS,
                self.context.machine_creds_bad.get_kerberos_state())
            self.assertEqual(MUST_USE_KERBEROS,
                             self.context.creds.get_kerberos_state())
        finally:
            self.credentials.set_kerberos_state(saved_kerberos_state)

    def test_packet_ldap(self):
        packet = Packet.from_line(
            "0.0\t06\t1\t1\t2\tldap\t\t*** Unknown ***\t")
        self.assertFalse(
            p.packet_ldap_(packet, self.conversation, self.context))

    def test_packet_ldap_00_sasl(self):
        packet = Packet.from_line("0.0\t06\t1\t2\t1\tldap\t0\tbindRequest"
                                  "\t\t\t\t\t3\tsasl\t1.3.6.1.5.5.2")
        self.assertTrue(
            p.packet_ldap_0(packet, self.conversation, self.context))

    def test_packet_ldap_00_simple(self):
        packet = Packet.from_line("0.0\t06\t1\t2\t1\tldap\t0\tbindRequest"
                                  "\t\t\t\t\t0\tsimple\t")
        self.assertTrue(
            p.packet_ldap_0(packet, self.conversation, self.context))

    def test_packet_ldap_01(self):
        packet = Packet.from_line("0.0\t06\t1\t1\t2\tldap\t1\tbindResponse\t")
        self.assertFalse(
            p.packet_ldap_1(packet, self.conversation, self.context))

    def test_packet_ldap_02(self):
        packet = Packet.from_line("0.0\t06\t1\t2\t1\tldap\t2\tunbindRequest\t")
        self.assertFalse(
            p.packet_ldap_2(packet, self.conversation, self.context))

    def test_packet_ldap_03(self):
        packet = Packet.from_line("0.0\t06\t1\t2\t1\tldap\t3\tsearchRequest"
                                  "\t2\tDC,DC\t\tcn\t\t\t")
        self.assertTrue(
            p.packet_ldap_3(packet, self.conversation, self.context))

    def test_packet_ldap_04(self):
        packet = Packet.from_line(
            "0.0\t06\t1\t1\t2\tldap\t4\tsearchResEntry\t")
        self.assertFalse(
            p.packet_ldap_4(packet, self.conversation, self.context))

    def test_packet_ldap_05(self):
        packet = Packet.from_line("0.0\t06\t1\t1\t2\tldap\t5\tsearchResDone\t")
        self.assertFalse(
            p.packet_ldap_5(packet, self.conversation, self.context))

    def test_packet_ldap_06(self):
        packet = Packet.from_line("0.0\t06\t1\t2\t1\tldap\t6\tmodifyRequest\t"
                                  "\t\t\t\t0\tadd")
        self.assertFalse(
            p.packet_ldap_6(packet, self.conversation, self.context))

    def test_packet_ldap_07(self):
        packet = Packet.from_line(
            "0.0\t06\t1\t1\t2\tldap\t7\tmodifyResponse\t")
        self.assertFalse(
            p.packet_ldap_7(packet, self.conversation, self.context))

    def test_packet_ldap_08(self):
        packet = Packet.from_line("0.0\t06\t1\t2\t1\tldap\t8\taddRequest\t")
        self.assertFalse(
            p.packet_ldap_8(packet, self.conversation, self.context))

    def test_packet_ldap_09(self):
        packet = Packet.from_line("0.0\t06\t1\t1\t2\tldap\t9\taddResponse\t")
        self.assertFalse(
            p.packet_ldap_9(packet, self.conversation, self.context))

    def test_packet_ldap_16(self):
        packet = Packet.from_line(
            "0.0\t06\t1\t2\t1\tldap\t16\tabandonRequest\t")
        self.assertFalse(
            p.packet_ldap_16(packet, self.conversation, self.context))

    def test_packet_lsarpc_00(self):
        packet = Packet.from_line("0.0\t06\t1\t2\t1\tlsarpc\t0\tlsa_Close\t")
        # NOTE(review): this opcode-0 packet is passed to packet_lsarpc_1
        # rather than a packet_lsarpc_0 generator; confirm that is intended.
        self.assertFalse(
            p.packet_lsarpc_1(packet, self.conversation, self.context))

    def test_packet_lsarpc_01(self):
        packet = Packet.from_line("0.0\t06\t1\t2\t1\tlsarpc\t1\tlsa_Delete\t")
        self.assertFalse(
            p.packet_lsarpc_1(packet, self.conversation, self.context))

    def test_packet_lsarpc_02(self):
        packet = Packet.from_line(
            "0.0\t06\t1\t2\t1\tlsarpc\t2\tlsa_EnumeratePrivileges\t")
        self.assertFalse(
            p.packet_lsarpc_2(packet, self.conversation, self.context))

    def test_packet_lsarpc_03(self):
        packet = Packet.from_line(
            "0.0\t06\t1\t2\t1\tlsarpc\t3\tlsa_QuerySecurityObject\t")
        self.assertFalse(
            p.packet_lsarpc_3(packet, self.conversation, self.context))

    def test_packet_lsarpc_04(self):
        packet = Packet.from_line(
            "0.0\t06\t1\t2\t1\tlsarpc\t4\tlsa_SetSecurityObject\t")
        self.assertFalse(
            p.packet_lsarpc_4(packet, self.conversation, self.context))

    def test_packet_lsarpc_05(self):
        packet = Packet.from_line(
            "0.0\t06\t1\t2\t1\tlsarpc\t5\tlsa_ChangePassword\t")
        self.assertFalse(
            p.packet_lsarpc_5(packet, self.conversation, self.context))

    def test_packet_lsarpc_06(self):
        packet = Packet.from_line(
            "0.0\t06\t1\t2\t1\tlsarpc\t6\tlsa_OpenPolicy\t")
        self.assertFalse(
            p.packet_lsarpc_6(packet, self.conversation, self.context))

    def test_packet_lsarpc_14(self):
        packet = Packet.from_line(
            "0.0\t06\t1\t2\t1\tlsarpc\t14\tlsa_LookupNames\t")
        self.assertTrue(
            p.packet_lsarpc_14(packet, self.conversation, self.context))

    def test_packet_lsarpc_15(self):
        packet = Packet.from_line(
            "0.0\t06\t1\t2\t1\tlsarpc\t15\tlsa_LookupSids\t")
        self.assertTrue(
            p.packet_lsarpc_15(packet, self.conversation, self.context))

    def test_packet_lsarpc_39(self):
        packet = Packet.from_line(
            "0.0\t06\t1\t2\t1\tlsarpc\t39\tlsa_QueryTrustedDomainInfoBySid\t")
        self.assertTrue(
            p.packet_lsarpc_39(packet, self.conversation, self.context))

    def test_packet_lsarpc_40(self):
        packet = Packet.from_line(
            "0.0\t06\t1\t2\t1\tlsarpc\t40\tlsa_SetTrustedDomainInfo\t")
        self.assertFalse(
            p.packet_lsarpc_40(packet, self.conversation, self.context))

    def test_packet_lsarpc_43(self):
        packet = Packet.from_line(
            "0.0\t06\t1\t2\t1\tlsarpc\t43\tlsa_StorePrivateData\t")
        self.assertFalse(
            p.packet_lsarpc_43(packet, self.conversation, self.context))

    def test_packet_lsarpc_44(self):
        packet = Packet.from_line(
            "0.0\t06\t1\t2\t1\tlsarpc\t44\tlsa_RetrievePrivateData\t")
        self.assertFalse(
            p.packet_lsarpc_44(packet, self.conversation, self.context))

    def test_packet_lsarpc_68(self):
        packet = Packet.from_line(
            "0.0\t06\t1\t2\t1\tlsarpc\t68\tlsa_LookupNames3\t")
        self.assertFalse(
            p.packet_lsarpc_68(packet, self.conversation, self.context))

    def test_packet_lsarpc_76(self):
        packet = Packet.from_line(
            "0.0\t06\t1\t2\t1\tlsarpc\t76\tlsa_LookupSids3\t")
        self.assertTrue(
            p.packet_lsarpc_76(packet, self.conversation, self.context))

    def test_packet_lsarpc_77(self):
        packet = Packet.from_line(
            "0.0\t06\t1\t2\t1\tlsarpc\t77\tlsa_LookupNames4\t")
        self.assertTrue(
            p.packet_lsarpc_77(packet, self.conversation, self.context))

    def test_packet_nbns_00(self):
        packet = Packet.from_line("0.0\t06\t1\t2\t1\tnbns\t0\tquery\t")
        self.assertTrue(
            p.packet_nbns_0(packet, self.conversation, self.context))

    def test_packet_nbns_01(self):
        packet = Packet.from_line("0.0\t06\t1\t1\t2\tnbns\t1\tresponse\t")
        # NOTE(review): this opcode-1 packet is passed to packet_nbns_0
        # rather than packet_nbns_1; confirm that is intended.
        self.assertTrue(
            p.packet_nbns_0(packet, self.conversation, self.context))

    def test_packet_rpc_netlogon_00(self):
        packet = Packet.from_line(
            "0.0\t06\t1\t2\t1\trpc_netlogon\t0\tNetrLogonUasLogon\t")
        self.assertFalse(
            p.packet_rpc_netlogon_0(packet, self.conversation, self.context))

    def test_packet_rpc_netlogon_01(self):
        packet = Packet.from_line(
            "0.0\t06\t1\t2\t1\trpc_netlogon\t1\tNetrLogonUasLogoff\t")
        self.assertFalse(
            p.packet_rpc_netlogon_1(packet, self.conversation, self.context))

    def test_packet_rpc_netlogon_04(self):
        packet = Packet.from_line(
            "0.0\t06\t1\t2\t1\trpc_netlogon\t4\tNetrServerReqChallenge\t")
        self.assertFalse(
            p.packet_rpc_netlogon_4(packet, self.conversation, self.context))

    def test_packet_rpc_netlogon_14(self):
        packet = Packet.from_line(
            "0.0\t06\t1\t2\t1\trpc_netlogon\t14\tNetrLogonControl2\t")
        self.assertFalse(
            p.packet_rpc_netlogon_14(packet, self.conversation, self.context))

    def test_packet_rpc_netlogon_15(self):
        packet = Packet.from_line(
            "0.0\t06\t1\t2\t1\trpc_netlogon\t15\tNetrServerAuthenticate2\t")
        self.assertFalse(
            p.packet_rpc_netlogon_15(packet, self.conversation, self.context))

    def test_packet_rpc_netlogon_21(self):
        packet = Packet.from_line(
            "0.0\t06\t1\t2\t1\trpc_netlogon\t21\tNetrLogonDummyRoutine1\t")
        self.assertFalse(
            p.packet_rpc_netlogon_21(packet, self.conversation, self.context))

    def test_packet_rpc_netlogon_26(self):
        packet = Packet.from_line(
            "0.0\t06\t1\t2\t1\trpc_netlogon\t26\tNetrServerAuthenticate3\t")
        self.assertFalse(
            p.packet_rpc_netlogon_26(packet, self.conversation, self.context))

    def test_packet_rpc_netlogon_29(self):
        packet = Packet.from_line(
            "0.0\t06\t1\t2\t1\trpc_netlogon\t29\tNetrLogonGetDomainInfo\t")
        self.assertTrue(
            p.packet_rpc_netlogon_29(packet, self.conversation, self.context))

    def test_packet_rpc_netlogon_30(self):
        packet = Packet.from_line(
            "0.0\t06\t1\t2\t1\trpc_netlogon\t30\tNetrServerPasswordSet2\t")
        self.assertTrue(
            p.packet_rpc_netlogon_30(packet, self.conversation, self.context))

    def test_packet_rpc_netlogon_34(self):
        packet = Packet.from_line(
            "0.0\t06\t1\t2\t1\trpc_netlogon\t34\tDsrGetDcNameEx2\t")
        self.assertFalse(
            p.packet_rpc_netlogon_34(packet, self.conversation, self.context))

    def test_packet_rpc_netlogon_39(self):
        packet = Packet.from_line(
            "0.0\t06\t1\t2\t1\trpc_netlogon\t39\tNetrLogonSamLogonEx\t")
        self.assertTrue(
            p.packet_rpc_netlogon_39(packet, self.conversation, self.context))

    def test_packet_rpc_netlogon_40(self):
        packet = Packet.from_line(
            "0.0\t06\t1\t2\t1\trpc_netlogon\t40\tDsrEnumerateDomainTrusts\t")
        self.assertTrue(
            p.packet_rpc_netlogon_40(packet, self.conversation, self.context))

    def test_packet_rpc_netlogon_45(self):
        packet = Packet.from_line(
            "0.0\t06\t1\t2\t1\trpc_netlogon\t45\tNetrLogonSamLogonWithFlags\t")
        self.assertTrue(
            p.packet_rpc_netlogon_45(packet, self.conversation, self.context))

    def test_packet_samr_00(self):
        packet = Packet.from_line("0.0\t06\t1\t2\t1\tsamr\t0\tConnect\t")
        self.assertTrue(
            p.packet_samr_0(packet, self.conversation, self.context))

    def test_packet_samr_01(self):
        packet = Packet.from_line("0.0\t06\t1\t2\t1\tsamr\t1\tClose\t")
        self.assertTrue(
            p.packet_samr_1(packet, self.conversation, self.context))

    def test_packet_samr_03(self):
        packet = Packet.from_line("0.0\t06\t1\t2\t1\tsamr\t3\tQuerySecurity\t")
        self.assertTrue(
            p.packet_samr_3(packet, self.conversation, self.context))

    def test_packet_samr_05(self):
        packet = Packet.from_line("0.0\t06\t1\t2\t1\tsamr\t5\tLookupDomain\t")
        self.assertTrue(
            p.packet_samr_5(packet, self.conversation, self.context))

    def test_packet_samr_06(self):
        packet = Packet.from_line("0.0\t06\t1\t2\t1\tsamr\t6\tEnumDomains\t")
        self.assertTrue(
            p.packet_samr_6(packet, self.conversation, self.context))

    def test_packet_samr_07(self):
        packet = Packet.from_line("0.0\t06\t1\t2\t1\tsamr\t7\tOpenDomain\t")
        self.assertTrue(
            p.packet_samr_7(packet, self.conversation, self.context))

    def test_packet_samr_08(self):
        packet = Packet.from_line(
            "0.0\t06\t1\t2\t1\tsamr\t8\tQueryDomainInfo'\t")
        self.assertTrue(
            p.packet_samr_8(packet, self.conversation, self.context))

    def test_packet_samr_14(self):
        packet = Packet.from_line(
            "0.0\t06\t1\t2\t1\tsamr\t14\tCreateDomAlias\t")
        self.assertFalse(
            p.packet_samr_14(packet, self.conversation, self.context))

    def test_packet_samr_15(self):
        packet = Packet.from_line(
            "0.0\t06\t1\t2\t1\tsamr\t15\tEnumDomainAliases\t")
        self.assertTrue(
            p.packet_samr_15(packet, self.conversation, self.context))

    def test_packet_samr_16(self):
        packet = Packet.from_line(
            "0.0\t06\t1\t2\t1\tsamr\t16\tGetAliasMembership\t")
        self.assertTrue(
            p.packet_samr_16(packet, self.conversation, self.context))

    def test_packet_samr_17(self):
        packet = Packet.from_line("0.0\t06\t1\t2\t1\tsamr\t17\tLookupNames\t")
        self.assertTrue(
            p.packet_samr_17(packet, self.conversation, self.context))

    def test_packet_samr_18(self):
        packet = Packet.from_line("0.0\t06\t1\t2\t1\tsamr\t18\tLookupRids\t")
        self.assertTrue(
            p.packet_samr_18(packet, self.conversation, self.context))

    def test_packet_samr_19(self):
        packet = Packet.from_line("0.0\t06\t1\t2\t1\tsamr\t19\tOpenGroup\t")
        self.assertTrue(
            p.packet_samr_19(packet, self.conversation, self.context))

    def test_packet_samr_25(self):
        packet = Packet.from_line(
            "0.0\t06\t1\t2\t1\tsamr\t25\tQueryGroupMember\t")
        self.assertTrue(
            p.packet_samr_25(packet, self.conversation, self.context))

    def test_packet_samr_34(self):
        packet = Packet.from_line("0.0\t06\t1\t2\t1\tsamr\t34\tOpenUser\t")
        self.assertTrue(
            p.packet_samr_34(packet, self.conversation, self.context))

    def test_packet_samr_36(self):
        packet = Packet.from_line(
            "0.0\t06\t1\t2\t1\tsamr\t36\tQueryUserInfo\t")
        self.assertTrue(
            p.packet_samr_36(packet, self.conversation, self.context))

    def test_packet_samr_37(self):
        packet = Packet.from_line("0.0\t06\t1\t2\t1\tsamr\t37\tSetUserInfo\t")
        self.assertFalse(
            p.packet_samr_37(packet, self.conversation, self.context))

    def test_packet_samr_39(self):
        packet = Packet.from_line(
            "0.0\t06\t1\t2\t1\tsamr\t39\tGetGroupsForUser\t")
        self.assertTrue(
            p.packet_samr_39(packet, self.conversation, self.context))

    def test_packet_samr_40(self):
        packet = Packet.from_line(
            "0.0\t06\t1\t2\t1\tsamr\t40\tQueryDisplayInfo\t")
        self.assertFalse(
            p.packet_samr_40(packet, self.conversation, self.context))

    def test_packet_samr_44(self):
        packet = Packet.from_line(
            "0.0\t06\t1\t2\t1\tsamr\t44\tGetUserPwInfo\t")
        self.assertFalse(
            p.packet_samr_44(packet, self.conversation, self.context))

    def test_packet_samr_57(self):
        packet = Packet.from_line("0.0\t06\t1\t2\t1\tsamr\t57\tConnect2\t")
        self.assertTrue(
            p.packet_samr_57(packet, self.conversation, self.context))

    def test_packet_samr_64(self):
        packet = Packet.from_line("0.0\t06\t1\t2\t1\tsamr\t64\tConnect5\t")
        self.assertTrue(
            p.packet_samr_64(packet, self.conversation, self.context))

    def test_packet_samr_68(self):
        packet = Packet.from_line("0.0\t06\t1\t2\t1\tsamr\t68\t\t")
        self.assertFalse(
            p.packet_samr_68(packet, self.conversation, self.context))

    def test_packet_srvsvc_16(self):
        packet = Packet.from_line(
            "0.0\t06\t1\t2\t1\tsrvsvc\t16\tNetShareGetInfo\t")
        self.assertTrue(
            p.packet_srvsvc_16(packet, self.conversation, self.context))

    def test_packet_srvsvc_21(self):
        packet = Packet.from_line(
            "0.0\t06\t1\t2\t1\tsrvsvc\t21\tNetSrvGetInfo\t")
        self.assertTrue(
            p.packet_srvsvc_21(packet, self.conversation, self.context))
Example #15
0
    def run(self, computername, credopts=None, sambaopts=None, versionopts=None,
            H=None, computerou=None, description=None, prepare_oldjoin=False,
            ip_address_list=None, service_principal_name_list=None):
        """Create a computer account and, when IP addresses are given, DNS
        records for it.

        :param computername: account name; one trailing ``$`` is stripped
            when deriving the DNS host name.
        :param H: URL of the SamDB to connect to.
        :param computerou: OU in which to create the account.
        :param description: optional account description.
        :param prepare_oldjoin: passed through to ``SamDB.newcomputer``.
        :param ip_address_list: IP addresses to register in DNS (validated
            up front).
        :param service_principal_name_list: SPNs to set on the account.
        :raises CommandError: if an IP address is invalid, or if creating the
            account or its DNS records fails (the underlying exception is
            passed as the inner exception).
        """
        if ip_address_list is None:
            ip_address_list = []

        if service_principal_name_list is None:
            service_principal_name_list = []

        # check each IP address if provided
        for ip_address in ip_address_list:
            if not _is_valid_ip(ip_address):
                raise CommandError('Invalid IP address {}'.format(ip_address))

        lp = sambaopts.get_loadparm()
        creds = credopts.get_credentials(lp)

        try:
            if ip_address_list:
                # Validate the DNS host name *before* creating the account,
                # so that an illegal name does not leave a half-created
                # computer object behind.
                hostname = re.sub(r"\$$", "", computername)
                if '$' in hostname:
                    raise CommandError('Illegal computername "%s"' % computername)

            samdb = SamDB(url=H, session_info=system_session(),
                          credentials=creds, lp=lp)
            samdb.newcomputer(computername, computerou=computerou,
                              description=description,
                              prepare_oldjoin=prepare_oldjoin,
                              ip_address_list=ip_address_list,
                              service_principal_name_list=service_principal_name_list,
                              )

            if ip_address_list:
                # if ip_address_list provided, then we need to create DNS
                # records for this computer.
                filters = '(&(sAMAccountName={}$)(objectclass=computer))'.format(
                    ldb.binary_encode(hostname))

                recs = samdb.search(
                    base=samdb.domain_dn(),
                    scope=ldb.SCOPE_SUBTREE,
                    expression=filters,
                    attrs=['primaryGroupID', 'objectSid'])

                # The attribute value may come back as bytes; normalise it to
                # an int so the group SID string below is well formed.
                group = int(recs[0]['primaryGroupID'][0])
                owner = ndr_unpack(security.dom_sid, recs[0]["objectSid"][0])

                dns_conn = dnsserver.dnsserver(
                    "ncacn_ip_tcp:{}[sign]".format(samdb.host_dns_name()),
                    lp, creds)

                # The DNS records are owned by the new computer account, with
                # its primary group as the group SID.
                change_owner_sd = security.descriptor()
                change_owner_sd.owner_sid = owner
                change_owner_sd.group_sid = security.dom_sid(
                    "{}-{}".format(samdb.get_domain_sid(), group),
                )

                add_dns_records(
                    samdb, hostname, dns_conn,
                    change_owner_sd, samdb.host_dns_name(),
                    ip_address_list, self.get_logger())
        except Exception as e:
            raise CommandError("Failed to create computer '%s': " %
                               computername, e)

        self.outf.write("Computer '%s' created successfully\n" % computername)
Example #16
0
    def test_rid_set_dbcheck(self):
        """Perform a join against the RID manager and assert we have a RID Set.
        Using dbcheck, we assert that we can detect out of range users."""

        fsmo_dn = ldb.Dn(self.ldb_dc1, "CN=RID Manager$,CN=System," + self.ldb_dc1.domain_dn())
        (fsmo_owner, fsmo_not_owner) = self._determine_fSMORoleOwner(fsmo_dn)

        targetdir = self._test_join(fsmo_owner['dns_name'], "RIDALLOCTEST6")
        try:
            # Connect to the joined DC's local database
            ldb_url = "tdb://%s" % os.path.join(targetdir, "private/sam.ldb")

            lp = self.get_loadparm()
            new_ldb = SamDB(ldb_url, credentials=self.get_credentials(),
                            session_info=system_session(lp), lp=lp)

            # 1. Get server name
            res = new_ldb.search(base=ldb.Dn(new_ldb, new_ldb.get_serverName()),
                                 scope=ldb.SCOPE_BASE, attrs=["serverReference"])
            # 2. Get server reference
            server_ref_dn = ldb.Dn(new_ldb, res[0]['serverReference'][0])

            # 3. Assert we get the RID Set
            res = new_ldb.search(base=server_ref_dn,
                                 scope=ldb.SCOPE_BASE, attrs=['rIDSetReferences'])

            self.assertTrue("rIDSetReferences" in res[0])
            rid_set_dn = ldb.Dn(new_ldb, res[0]["rIDSetReferences"][0])

            # 4. Add a new user (triggers RID set work)
            new_ldb.newuser("ridalloctestuser", "P@ssword!")

            # 5. Now fetch the RID SET
            rid_set_res = new_ldb.search(base=rid_set_dn,
                                         scope=ldb.SCOPE_BASE, attrs=['rIDNextRid',
                                                                      'rIDAllocationPool'])
            next_pool = int(rid_set_res[0]["rIDAllocationPool"][0])
            # The upper 32 bits of rIDAllocationPool are the last RID of the pool
            last_rid = (0xFFFFFFFF00000000 & next_pool) >> 32

            # 6. Add user above the ridNextRid and at mid-range.
            #
            # We can do this with safety because this is an offline DB that will be
            # destroyed.
            m = ldb.Message()
            m.dn = ldb.Dn(new_ldb, "CN=ridsettestuser1,CN=Users")
            m.dn.add_base(new_ldb.get_default_basedn())
            m['objectClass'] = ldb.MessageElement('user', ldb.FLAG_MOD_ADD, 'objectClass')
            m['objectSid'] = ldb.MessageElement(ndr_pack(security.dom_sid(str(new_ldb.get_domain_sid()) + "-%d" % (last_rid - 10))),
                                                ldb.FLAG_MOD_ADD,
                                                'objectSid')
            new_ldb.add(m, controls=["relax:0"])

            # 7. Check the RID Set
            chk = dbcheck(new_ldb, verbose=False, fix=True, yes=True, quiet=True)

            # Should have one error (wrong rIDNextRID)
            self.assertEqual(chk.check_database(DN=rid_set_dn, scope=ldb.SCOPE_BASE), 1)

            # 8. Assert that a second, read-only dbcheck finds no remaining errors
            chk = dbcheck(new_ldb, verbose=False, fix=False, quiet=True)
            self.assertEqual(chk.check_database(DN=rid_set_dn, scope=ldb.SCOPE_BASE), 0)

            rid_set_res = new_ldb.search(base=rid_set_dn,
                                         scope=ldb.SCOPE_BASE, attrs=['rIDNextRid',
                                                                      'rIDAllocationPool'])
            last_allocated_rid = int(rid_set_res[0]["rIDNextRid"][0])
            self.assertEqual(last_allocated_rid, last_rid - 10)

            # 9. Assert that the range wasn't thrown away

            next_pool = int(rid_set_res[0]["rIDAllocationPool"][0])
            self.assertEqual(last_rid, (0xFFFFFFFF00000000 & next_pool) >> 32, "rid pool should not have changed")
        finally:
            self._test_force_demote(fsmo_owner['dns_name'], "RIDALLOCTEST6")
            shutil.rmtree(targetdir, ignore_errors=True)
Example #17
0
    def test_rid_set_dbcheck_after_seize(self):
        """Perform a join against the RID manager and assert we have a RID Set.
        We seize the RID master role, then using dbcheck, we assert that we can
        detect out of range users (and then bump the RID set as required)."""

        fsmo_dn = ldb.Dn(self.ldb_dc1, "CN=RID Manager$,CN=System," + self.ldb_dc1.domain_dn())
        (fsmo_owner, fsmo_not_owner) = self._determine_fSMORoleOwner(fsmo_dn)

        targetdir = self._test_join(fsmo_owner['dns_name'], "RIDALLOCTEST7")
        try:
            # Connect to the joined DC's local database
            ldb_url = "tdb://%s" % os.path.join(targetdir, "private/sam.ldb")
            smbconf = os.path.join(targetdir, "etc/smb.conf")

            lp = self.get_loadparm()
            new_ldb = SamDB(ldb_url, credentials=self.get_credentials(),
                            session_info=system_session(lp), lp=lp)

            # 1. Get server name
            res = new_ldb.search(base=ldb.Dn(new_ldb, new_ldb.get_serverName()),
                                 scope=ldb.SCOPE_BASE, attrs=["serverReference"])
            # 2. Get server reference
            server_ref_dn = ldb.Dn(new_ldb, res[0]['serverReference'][0])

            # 3. Assert we get the RID Set
            res = new_ldb.search(base=server_ref_dn,
                                 scope=ldb.SCOPE_BASE, attrs=['rIDSetReferences'])

            self.assertTrue("rIDSetReferences" in res[0])
            rid_set_dn = ldb.Dn(new_ldb, res[0]["rIDSetReferences"][0])

            # 4. Seize the RID Manager role
            (result, out, err) = self.runsubcmd("fsmo", "seize", "--role", "rid", "-H", ldb_url, "-s", smbconf, "--force")
            self.assertCmdSuccess(result, out, err)
            self.assertEqual(err, "", "Shouldn't be any error messages")

            # 5. Add a new user (triggers RID set work)
            new_ldb.newuser("ridalloctestuser", "P@ssword!")

            # 6. Now fetch the RID SET
            rid_set_res = new_ldb.search(base=rid_set_dn,
                                         scope=ldb.SCOPE_BASE, attrs=['rIDNextRid',
                                                                      'rIDAllocationPool'])
            next_pool = int(rid_set_res[0]["rIDAllocationPool"][0])
            # The upper 32 bits of rIDAllocationPool are the last RID of the pool
            last_rid = (0xFFFFFFFF00000000 & next_pool) >> 32

            # 7. Add user above the ridNextRid and at almost the end of the range.
            #
            m = ldb.Message()
            m.dn = ldb.Dn(new_ldb, "CN=ridsettestuser2,CN=Users")
            m.dn.add_base(new_ldb.get_default_basedn())
            m['objectClass'] = ldb.MessageElement('user', ldb.FLAG_MOD_ADD, 'objectClass')
            m['objectSid'] = ldb.MessageElement(ndr_pack(security.dom_sid(str(new_ldb.get_domain_sid()) + "-%d" % (last_rid - 3))),
                                                ldb.FLAG_MOD_ADD,
                                                'objectSid')
            new_ldb.add(m, controls=["relax:0"])

            # 8. Add user above the ridNextRid and at the end of the range
            m = ldb.Message()
            m.dn = ldb.Dn(new_ldb, "CN=ridsettestuser3,CN=Users")
            m.dn.add_base(new_ldb.get_default_basedn())
            m['objectClass'] = ldb.MessageElement('user', ldb.FLAG_MOD_ADD, 'objectClass')
            m['objectSid'] = ldb.MessageElement(ndr_pack(security.dom_sid(str(new_ldb.get_domain_sid()) + "-%d" % last_rid)),
                                                ldb.FLAG_MOD_ADD,
                                                'objectSid')
            new_ldb.add(m, controls=["relax:0"])

            chk = dbcheck(new_ldb, verbose=False, fix=True, yes=True, quiet=True)

            # Should have fixed two errors (wrong ridNextRid)
            self.assertEqual(chk.check_database(DN=rid_set_dn, scope=ldb.SCOPE_BASE), 2)

            # 9. Assert that a second, read-only dbcheck finds no remaining errors
            chk = dbcheck(new_ldb, verbose=False, fix=False, quiet=True)
            self.assertEqual(chk.check_database(DN=rid_set_dn, scope=ldb.SCOPE_BASE), 0)

            # 10. Add another user (checks RID rollover)
            # We have seized the role, so we can do that.
            new_ldb.newuser("ridalloctestuser3", "P@ssword!")

            rid_set_res = new_ldb.search(base=rid_set_dn,
                                         scope=ldb.SCOPE_BASE, attrs=['rIDNextRid',
                                                                      'rIDAllocationPool'])
            next_pool = int(rid_set_res[0]["rIDAllocationPool"][0])
            self.assertNotEqual(last_rid, (0xFFFFFFFF00000000 & next_pool) >> 32, "rid pool should have changed")
        finally:
            self._test_force_demote(fsmo_owner['dns_name'], "RIDALLOCTEST7")
            shutil.rmtree(targetdir, ignore_errors=True)
Example #18
0
class UserAccountControlTests(samba.tests.TestCase):
    """Check which userAccountControl / primaryGroupID changes are allowed.

    Computer accounts are created in the scratch container
    OU=test_computer_ou1, either by an unprivileged user (who is granted
    Create-Child on computer objects in that OU) or by the administrator.
    Each test then checks which modifications the server accepts and which
    LDAP error code it returns for the rejected ones.

    NOTE(review): ``bits`` and ``account_types`` are module-level names not
    visible in this chunk; presumably they are defined at the top of the
    test module.
    """

    def add_computer_ldap(self, computername, others=None, samdb=None):
        """Create computer object CN=<computername> in the test OU.

        :param computername: CN of the new object; its sAMAccountName is set
            to ``<computername>$``.
        :param others: optional dict of extra attributes for the add request.
            On a key clash these override the defaults.
        :param samdb: connection used for the add; defaults to ``self.samdb``
            (the unprivileged connection).
        :raises LdbError: propagated from ``samdb.add`` when the server
            rejects the request.
        """
        if samdb is None:
            samdb = self.samdb
        dn = "CN=%s,OU=test_computer_ou1,%s" % (computername, self.base_dn)
        msg_dict = {"dn": dn, "objectclass": "computer"}
        if others is not None:
            # dict views cannot be concatenated with "+" on Python 3;
            # update() gives the same "later keys win" merge.
            msg_dict.update(others)

        msg = ldb.Message.from_dict(samdb, msg_dict)
        msg["sAMAccountName"] = "%s$" % computername

        print("Adding computer account %s" % computername)
        samdb.add(msg)

    def get_creds(self, target_username, target_password):
        """Return NTLM (non-Kerberos), sealing credentials for the given user.

        Domain, realm and workstation are copied from the global ``creds``.
        """
        creds_tmp = Credentials()
        creds_tmp.set_username(target_username)
        creds_tmp.set_password(target_password)
        creds_tmp.set_domain(creds.get_domain())
        creds_tmp.set_realm(creds.get_realm())
        creds_tmp.set_workstation(creds.get_workstation())
        creds_tmp.set_gensec_features(creds_tmp.get_gensec_features()
                                      | gensec.FEATURE_SEAL)
        creds_tmp.set_kerberos_state(DONT_USE_KERBEROS) # kinit is too expensive to use in a tight loop
        return creds_tmp

    def _grant_computer_create_child(self):
        """Let the unprivileged user create computer objects in the test OU.

        Adds an object ACE granting CC (Create-Child) restricted to the
        computer class (schemaIDGUID bf967a86-...) to the OU's DACL.
        """
        mod = "(OA;;CC;bf967a86-0de6-11d0-a285-00aa003049e2;;%s)" % str(self.unpriv_user_sid)
        self.sd_utils.dacl_add_ace("OU=test_computer_ou1," + self.base_dn, mod)

    def _search_computer(self, computername, attrs):
        """Look up computer *computername* as admin; assert exactly one hit."""
        res = self.admin_samdb.search("%s" % self.base_dn,
                                      expression="(&(objectClass=computer)(samAccountName=%s$))" % computername,
                                      scope=SCOPE_SUBTREE,
                                      attrs=attrs)
        self.assertEqual(1, len(res))
        return res

    @staticmethod
    def _uac_message(dn, uac):
        """Build a message that replaces userAccountControl on *dn* with *uac*."""
        m = ldb.Message()
        m.dn = dn
        m["userAccountControl"] = ldb.MessageElement(str(uac),
                                                     ldb.FLAG_MOD_REPLACE,
                                                     "userAccountControl")
        return m

    @staticmethod
    def _description_message(dn):
        """Build a message that replaces the description attribute of *dn*."""
        m = ldb.Message()
        m.dn = dn
        m["description"] = ldb.MessageElement(
            ("A description"), ldb.FLAG_MOD_REPLACE,
            "description")
        return m

    @staticmethod
    def _primary_group_message(dn, rid):
        """Build a message that replaces primaryGroupID of *dn* with *rid*."""
        m = ldb.Message()
        m.dn = dn
        m["primaryGroupID"] = ldb.MessageElement(
            [str(rid)], ldb.FLAG_MOD_REPLACE,
            "primaryGroupID")
        return m

    def _add_group_member(self, group_rid, member_dn):
        """As admin, add *member_dn* to the domain group with RID *group_rid*."""
        m = ldb.Message()
        m.dn = ldb.Dn(self.admin_samdb, "<SID=%s-%d>" % (str(self.domain_sid),
                                                         group_rid))
        m["member"] = ldb.MessageElement(
            [str(member_dn)], ldb.FLAG_MOD_ADD,
            "member")
        self.admin_samdb.modify(m)

    def _assert_modify_fails(self, samdb, m, expected_err, fail_msg=None):
        """Assert that ``samdb.modify(m)`` raises LdbError with *expected_err*.

        If the modify unexpectedly succeeds, the test fails with *fail_msg*.
        """
        try:
            samdb.modify(m)
        except LdbError as e:
            (enum, estr) = e.args
            self.assertEqual(expected_err, enum)
        else:
            self.fail(fail_msg)

    def setUp(self):
        """Create the unprivileged user, the test OU and a template computer."""
        super(UserAccountControlTests, self).setUp()
        self.admin_creds = creds
        self.admin_samdb = SamDB(url=ldaphost,
                                 session_info=system_session(),
                                 credentials=self.admin_creds, lp=lp)
        self.domain_sid = security.dom_sid(self.admin_samdb.get_domain_sid())
        self.base_dn = self.admin_samdb.domain_dn()

        self.unpriv_user = "******"
        self.unpriv_user_pw = "samba123@"
        self.unpriv_creds = self.get_creds(self.unpriv_user, self.unpriv_user_pw)

        # Remove anything left behind by an earlier run.
        delete_force(self.admin_samdb, "CN=testcomputer-t,OU=test_computer_ou1,%s" % (self.base_dn))
        delete_force(self.admin_samdb, "OU=test_computer_ou1,%s" % (self.base_dn))
        delete_force(self.admin_samdb, "CN=%s,CN=Users,%s" % (self.unpriv_user, self.base_dn))

        self.admin_samdb.newuser(self.unpriv_user, self.unpriv_user_pw)
        res = self.admin_samdb.search("CN=%s,CN=Users,%s" % (self.unpriv_user, self.admin_samdb.domain_dn()),
                                      scope=SCOPE_BASE,
                                      attrs=["objectSid"])
        self.assertEqual(1, len(res))

        self.unpriv_user_sid = ndr_unpack(security.dom_sid, res[0]["objectSid"][0])
        self.unpriv_user_dn = res[0].dn

        # Unprivileged connection, needed below to create the template.
        self.samdb = SamDB(url=ldaphost, credentials=self.unpriv_creds, lp=lp)

        self.samr = samr.samr("ncacn_ip_tcp:%s[seal]" % host, lp, self.unpriv_creds)
        self.samr_handle = self.samr.Connect2(None, security.SEC_FLAG_MAXIMUM_ALLOWED)
        self.samr_domain = self.samr.OpenDomain(self.samr_handle, security.SEC_FLAG_MAXIMUM_ALLOWED, self.domain_sid)

        self.sd_utils = sd_utils.SDUtils(self.admin_samdb)

        ou_dn = "OU=test_computer_ou1," + self.base_dn
        self.admin_samdb.create_ou(ou_dn)

        # Create the template account as the unprivileged user, using a
        # temporary Create-Child grant. Then restore the OU's original SD so
        # each test starts without that grant.
        old_sd = self.sd_utils.read_sd_on_dn(ou_dn)
        self._grant_computer_create_child()
        self.add_computer_ldap("testcomputer-t")
        self.sd_utils.modify_sd_on_dn(ou_dn, old_sd)

        self.computernames = ["testcomputer-0"]

        template_dn = "CN=testcomputer-t,OU=test_computer_ou1,%s" % (self.base_dn)
        # Get the SD of the template account, then force it to match
        # what we expect for SeMachineAccountPrivilege accounts, so we
        # can confirm we created the accounts correctly
        self.sd_reference_cc = self.sd_utils.read_sd_on_dn(template_dn)

        self.sd_reference_modify = self.sd_utils.read_sd_on_dn(template_dn)
        for ace in self.sd_reference_modify.dacl.aces:
            if ace.type == security.SEC_ACE_TYPE_ACCESS_ALLOWED and ace.trustee == self.unpriv_user_sid:
                ace.access_mask = ace.access_mask | security.SEC_ADS_SELF_WRITE | security.SEC_ADS_WRITE_PROP

        # Now reconnect without domain admin rights
        self.samdb = SamDB(url=ldaphost, credentials=self.unpriv_creds, lp=lp)

    def tearDown(self):
        """Delete the test computers, the template, the OU and the user."""
        super(UserAccountControlTests, self).tearDown()
        for computername in self.computernames:
            delete_force(self.admin_samdb, "CN=%s,OU=test_computer_ou1,%s" % (computername, self.base_dn))
        delete_force(self.admin_samdb, "CN=testcomputer-t,OU=test_computer_ou1,%s" % (self.base_dn))
        delete_force(self.admin_samdb, "OU=test_computer_ou1,%s" % (self.base_dn))
        delete_force(self.admin_samdb, "CN=%s,CN=Users,%s" % (self.unpriv_user, self.base_dn))

    def test_add_computer_sd_cc(self):
        """Create a computer with an explicit SD as the unprivileged user.

        Checks that the SD was stored as supplied. Then checks which
        userAccountControl / primaryGroupID changes that user may make.
        """
        self._grant_computer_create_child()

        computername = self.computernames[0]
        sd = ldb.MessageElement((ndr_pack(self.sd_reference_modify)),
                                ldb.FLAG_MOD_ADD,
                                "nTSecurityDescriptor")
        self.add_computer_ldap(computername,
                               others={"nTSecurityDescriptor": sd})

        res = self._search_computer(computername, ["ntSecurityDescriptor"])
        dn = res[0].dn

        desc = ndr_unpack(security.descriptor,
                          res[0]["nTSecurityDescriptor"][0],
                          allow_remaining=True)
        sddl = desc.as_sddl(self.domain_sid)
        self.assertEqual(self.sd_reference_modify.as_sddl(self.domain_sid), sddl)

        self.samdb.modify(self._description_message(dn))

        m = self._uac_message(dn, samba.dsdb.UF_SERVER_TRUST_ACCOUNT)
        self._assert_modify_fails(
            self.samdb, m, ldb.ERR_INSUFFICIENT_ACCESS_RIGHTS,
            "Unexpectedly able to set userAccountControl to be a DC on %s" % m.dn)

        m = self._uac_message(dn, samba.dsdb.UF_WORKSTATION_TRUST_ACCOUNT | samba.dsdb.UF_PARTIAL_SECRETS_ACCOUNT)
        self._assert_modify_fails(
            self.samdb, m, ldb.ERR_INSUFFICIENT_ACCESS_RIGHTS,
            "Unexpectedly able to set userAccountControl to be an RODC on %s" % m.dn)

        m = self._uac_message(dn, samba.dsdb.UF_WORKSTATION_TRUST_ACCOUNT)
        self._assert_modify_fails(
            self.samdb, m, ldb.ERR_INSUFFICIENT_ACCESS_RIGHTS,
            "Unexpectedly able to set userAccountControl to be an Workstation on %s" % m.dn)

        self.samdb.modify(self._uac_message(dn, samba.dsdb.UF_NORMAL_ACCOUNT))

        m = self._primary_group_message(dn, security.DOMAIN_RID_ADMINS)
        self._assert_modify_fails(self.samdb, m, ldb.ERR_UNWILLING_TO_PERFORM)

    def test_mod_computer_cc(self):
        """Check UAC changes the unprivileged creator may make to a computer."""
        self._grant_computer_create_child()

        computername = self.computernames[0]
        self.add_computer_ldap(computername)

        res = self._search_computer(computername, [])
        dn = res[0].dn

        self.samdb.modify(self._description_message(dn))

        m = self._uac_message(dn, samba.dsdb.UF_WORKSTATION_TRUST_ACCOUNT | samba.dsdb.UF_PARTIAL_SECRETS_ACCOUNT)
        self._assert_modify_fails(
            self.samdb, m, ldb.ERR_INSUFFICIENT_ACCESS_RIGHTS,
            "Unexpectedly able to set userAccountControl on %s" % m.dn)

        m = self._uac_message(dn, samba.dsdb.UF_SERVER_TRUST_ACCOUNT)
        self._assert_modify_fails(self.samdb, m, ldb.ERR_INSUFFICIENT_ACCESS_RIGHTS)

        self.samdb.modify(self._uac_message(dn, samba.dsdb.UF_NORMAL_ACCOUNT))

        m = self._uac_message(dn, samba.dsdb.UF_WORKSTATION_TRUST_ACCOUNT)
        self._assert_modify_fails(
            self.samdb, m, ldb.ERR_INSUFFICIENT_ACCESS_RIGHTS,
            "Unexpectedly able to set userAccountControl to be an Workstation on %s" % m.dn)

    def test_admin_mod_uac(self):
        """Check the UAC values the administrator can and cannot set."""
        computername = self.computernames[0]
        self.add_computer_ldap(computername, samdb=self.admin_samdb)

        res = self._search_computer(computername, ["userAccountControl"])
        self.assertEqual(int(res[0]["userAccountControl"][0]),
                         UF_NORMAL_ACCOUNT | UF_ACCOUNTDISABLE | UF_PASSWD_NOTREQD)

        m = self._uac_message(res[0].dn,
                              UF_WORKSTATION_TRUST_ACCOUNT | UF_PARTIAL_SECRETS_ACCOUNT | UF_TRUSTED_FOR_DELEGATION)
        self._assert_modify_fails(
            self.admin_samdb, m, ldb.ERR_OTHER,
            "Unexpectedly able to set userAccountControl to UF_WORKSTATION_TRUST_ACCOUNT|UF_PARTIAL_SECRETS_ACCOUNT|UF_TRUSTED_FOR_DELEGATION on %s" % m.dn)

        self.admin_samdb.modify(self._uac_message(res[0].dn,
                                                  UF_WORKSTATION_TRUST_ACCOUNT | UF_PARTIAL_SECRETS_ACCOUNT))

        res = self._search_computer(computername, ["userAccountControl"])
        self.assertEqual(int(res[0]["userAccountControl"][0]),
                         UF_WORKSTATION_TRUST_ACCOUNT | UF_PARTIAL_SECRETS_ACCOUNT)

        self.admin_samdb.modify(self._uac_message(res[0].dn, UF_ACCOUNTDISABLE))

        res = self._search_computer(computername, ["userAccountControl"])
        # Supplying no account-type bit leaves the account as UF_NORMAL_ACCOUNT.
        self.assertEqual(int(res[0]["userAccountControl"][0]),
                         UF_NORMAL_ACCOUNT | UF_ACCOUNTDISABLE)

    def test_uac_bits_set(self):
        """As the unprivileged creator, try setting each UAC bit in ``bits``."""
        self._grant_computer_create_child()

        computername = self.computernames[0]
        self.add_computer_ldap(computername)

        res = self._search_computer(computername, [])

        self.samdb.modify(self._description_message(res[0].dn))

        # UF_PASSWD_NOTREQD, UF_ENCRYPTED_TEXT_PASSWORD_ALLOWED and
        # UF_DONT_EXPIRE_PASSWD are privileged too, but authenticated users
        # have that CAR by default, so this is a pain to test.

        # These bits really are privileged, or can't be changed from UF_NORMAL as a non-admin
        priv_bits = {UF_INTERDOMAIN_TRUST_ACCOUNT, UF_SERVER_TRUST_ACCOUNT,
                     UF_TRUSTED_FOR_DELEGATION, UF_TRUSTED_TO_AUTHENTICATE_FOR_DELEGATION,
                     UF_WORKSTATION_TRUST_ACCOUNT}

        invalid_bits = {UF_TEMP_DUPLICATE_ACCOUNT, UF_PARTIAL_SECRETS_ACCOUNT}

        for bit in bits:
            m = self._uac_message(res[0].dn, bit | UF_PASSWD_NOTREQD)
            try:
                self.samdb.modify(m)
                if bit in priv_bits:
                    self.fail("Unexpectedly able to set userAccountControl bit 0x%08X on %s" % (bit, m.dn))
            except LdbError as e:
                (enum, estr) = e.args
                if bit in invalid_bits:
                    self.assertEqual(enum, ldb.ERR_OTHER, "was not able to set 0x%08X on %s" % (bit, m.dn))
                elif bit in priv_bits:
                    self.assertEqual(ldb.ERR_INSUFFICIENT_ACCESS_RIGHTS, enum)
                else:
                    self.fail("Unable to set userAccountControl bit 0x%08X on %s: %s" % (bit, m.dn, estr))

    def uac_bits_unrelated_modify_helper(self, account_type):
        """Set each UAC bit as admin, then modify and clear it unprivileged.

        For every value in ``bits``, the helper runs these steps:

        1. reset the account to *account_type*;
        2. as admin, set the bit plus UF_PASSWD_NOTREQD;
        3. as the unprivileged user, also set UF_ACCOUNTDISABLE;
        4. as the unprivileged user, try to clear the bit again.

        The stored userAccountControl is checked after each step.

        :param account_type: initial userAccountControl of the test account.
        """
        self._grant_computer_create_child()

        computername = self.computernames[0]
        self.add_computer_ldap(computername, others={"userAccountControl": [str(account_type)]})

        res = self._search_computer(computername, ["userAccountControl"])
        self.assertEqual(int(res[0]["userAccountControl"][0]), account_type)

        self.samdb.modify(self._description_message(res[0].dn))

        invalid_bits = {UF_TEMP_DUPLICATE_ACCOUNT, UF_PARTIAL_SECRETS_ACCOUNT}

        # UF_LOCKOUT isn't actually ignored, it changes other
        # attributes but does not stick here.  See MS-SAMR 2.2.1.13
        # UF_FLAG Codes clarification that UF_SCRIPT and
        # UF_PASSWD_CANT_CHANGE are simply ignored by both clients and
        # servers.  Other bits are ignored as they are undefined, or
        # are not set into the attribute (instead triggering other
        # events).
        ignored_bits = {UF_SCRIPT, UF_00000004, UF_LOCKOUT, UF_PASSWD_CANT_CHANGE,
                        UF_00000400, UF_00004000, UF_00008000, UF_PASSWORD_EXPIRED,
                        0x10000000, 0x20000000, 0x40000000, 0x80000000}
        super_priv_bits = {UF_INTERDOMAIN_TRUST_ACCOUNT}

        priv_to_remove_bits = {UF_TRUSTED_FOR_DELEGATION,
                               UF_TRUSTED_TO_AUTHENTICATE_FOR_DELEGATION,
                               UF_WORKSTATION_TRUST_ACCOUNT}

        for bit in bits:
            # Reset this to the initial position, just to be sure
            self.admin_samdb.modify(self._uac_message(res[0].dn, account_type))

            res = self._search_computer(computername, ["userAccountControl"])
            self.assertEqual(int(res[0]["userAccountControl"][0]), account_type)

            # Step 2: the administrator sets the bit.
            m = self._uac_message(res[0].dn, bit | UF_PASSWD_NOTREQD)
            try:
                self.admin_samdb.modify(m)
                if bit in invalid_bits:
                    self.fail("Should have been unable to set userAccountControl bit 0x%08X on %s" % (bit, m.dn))
            except LdbError as e:
                (enum, estr) = e.args
                if bit in invalid_bits:
                    self.assertEqual(enum, ldb.ERR_OTHER)
                    # No point going on, try the next bit
                    continue
                elif bit in super_priv_bits:
                    self.assertEqual(enum, ldb.ERR_INSUFFICIENT_ACCESS_RIGHTS)
                    # No point going on, try the next bit
                    continue
                else:
                    self.fail("Unable to set userAccountControl bit 0x%08X on %s: %s" % (bit, m.dn, estr))

            res = self._search_computer(computername, ["userAccountControl"])
            uac = int(res[0]["userAccountControl"][0])

            if bit in ignored_bits:
                self.assertEqual(uac, UF_NORMAL_ACCOUNT | UF_PASSWD_NOTREQD, "Bit 0x%08x shouldn't stick" % bit)
            elif bit in account_types:
                self.assertEqual(uac, bit | UF_PASSWD_NOTREQD, "Bit 0x%08x didn't stick" % bit)
            else:
                self.assertEqual(uac, bit | UF_NORMAL_ACCOUNT | UF_PASSWD_NOTREQD, "Bit 0x%08x didn't stick" % bit)

            # Step 3: an unrelated change (UF_ACCOUNTDISABLE) by the
            # unprivileged user must succeed while leaving the bit in place.
            m = self._uac_message(res[0].dn, bit | UF_PASSWD_NOTREQD | UF_ACCOUNTDISABLE)
            try:
                self.samdb.modify(m)
            except LdbError as e:
                (enum, estr) = e.args
                self.fail("Unable to set userAccountControl bit 0x%08X on %s: %s" % (bit, m.dn, estr))

            res = self._search_computer(computername, ["userAccountControl"])
            uac = int(res[0]["userAccountControl"][0])

            if bit in account_types:
                expected = bit | UF_ACCOUNTDISABLE | UF_PASSWD_NOTREQD
            elif bit in ignored_bits:
                expected = UF_NORMAL_ACCOUNT | UF_ACCOUNTDISABLE | UF_PASSWD_NOTREQD
            else:
                expected = bit | UF_NORMAL_ACCOUNT | UF_ACCOUNTDISABLE | UF_PASSWD_NOTREQD
            self.assertEqual(uac, expected,
                             "bit 0X%08x should have been added (0X%08x vs 0X%08x)"
                             % (bit, uac, expected))

            # Step 4: the unprivileged user tries to clear the bit again.
            m = self._uac_message(res[0].dn, UF_PASSWD_NOTREQD | UF_ACCOUNTDISABLE)
            try:
                self.samdb.modify(m)
                if bit in priv_to_remove_bits:
                    self.fail("Should have been unable to remove userAccountControl bit 0x%08X on %s" % (bit, m.dn))
            except LdbError as e:
                (enum, estr) = e.args
                if bit in priv_to_remove_bits:
                    self.assertEqual(enum, ldb.ERR_INSUFFICIENT_ACCESS_RIGHTS)
                else:
                    self.fail("Unexpectedly unable to remove userAccountControl bit 0x%08X on %s: %s" % (bit, m.dn, estr))

            res = self._search_computer(computername, ["userAccountControl"])
            uac = int(res[0]["userAccountControl"][0])

            if bit in priv_to_remove_bits:
                if bit in account_types:
                    self.assertEqual(uac,
                                     bit | UF_ACCOUNTDISABLE | UF_PASSWD_NOTREQD,
                                     "bit 0X%08x should not have been removed" % bit)
                else:
                    self.assertEqual(uac,
                                     bit | UF_NORMAL_ACCOUNT | UF_ACCOUNTDISABLE | UF_PASSWD_NOTREQD,
                                     "bit 0X%08x should not have been removed" % bit)
            else:
                self.assertEqual(uac,
                                 UF_NORMAL_ACCOUNT | UF_ACCOUNTDISABLE | UF_PASSWD_NOTREQD,
                                 "bit 0X%08x should have been removed" % bit)

    def test_uac_bits_unrelated_modify_normal(self):
        self.uac_bits_unrelated_modify_helper(UF_NORMAL_ACCOUNT)

    def test_uac_bits_unrelated_modify_workstation(self):
        self.uac_bits_unrelated_modify_helper(UF_WORKSTATION_TRUST_ACCOUNT)

    def test_uac_bits_add(self):
        """As the unprivileged creator, try creating computers with each UAC bit."""
        computername = self.computernames[0]

        self._grant_computer_create_child()

        invalid_bits = {UF_TEMP_DUPLICATE_ACCOUNT, UF_PARTIAL_SECRETS_ACCOUNT}
        # UF_PASSWD_NOTREQD, UF_ENCRYPTED_TEXT_PASSWORD_ALLOWED and
        # UF_DONT_EXPIRE_PASSWD are privileged too, but authenticated users
        # have that CAR by default, so this is a pain to test.

        # These bits really are privileged
        priv_bits = {UF_INTERDOMAIN_TRUST_ACCOUNT, UF_SERVER_TRUST_ACCOUNT,
                     UF_TRUSTED_FOR_DELEGATION, UF_TRUSTED_TO_AUTHENTICATE_FOR_DELEGATION}

        for bit in bits:
            try:
                self.add_computer_ldap(computername, others={"userAccountControl": [str(bit)]})
                delete_force(self.admin_samdb, "CN=%s,OU=test_computer_ou1,%s" % (computername, self.base_dn))
                if bit in priv_bits:
                    self.fail("Unexpectdly able to set userAccountControl bit 0x%08X on %s" % (bit, computername))
            except LdbError as e:
                (enum, estr) = e.args
                if bit in invalid_bits:
                    self.assertEqual(enum, ldb.ERR_OTHER, "Invalid bit 0x%08X was able to be set on %s" % (bit, computername))
                elif bit in priv_bits:
                    self.assertEqual(enum, ldb.ERR_INSUFFICIENT_ACCESS_RIGHTS)
                else:
                    self.fail("Unable to set userAccountControl bit 0x%08X on %s: %s" % (bit, computername, estr))

    def test_primarygroupID_cc_add(self):
        """primaryGroupID cannot be supplied when creating a computer."""
        computername = self.computernames[0]

        self._grant_computer_create_child()
        try:
            # When creating a new object, you can not ever set the primaryGroupID
            self.add_computer_ldap(computername, others={"primaryGroupID": [str(security.DOMAIN_RID_ADMINS)]})
            self.fail("Unexpectedly able to set primaryGruopID to be an admin on %s" % computername)
        except LdbError as e:
            (enum, estr) = e.args
            self.assertEqual(enum, ldb.ERR_UNWILLING_TO_PERFORM)

    def test_primarygroupID_priv_DC_modify(self):
        """Even admin cannot move a DC account's primaryGroupID to Domain Users."""
        computername = self.computernames[0]

        self.add_computer_ldap(computername,
                               others={"userAccountControl": [str(UF_SERVER_TRUST_ACCOUNT)]},
                               samdb=self.admin_samdb)
        res = self._search_computer(computername, [""])

        # Make the account a member first, so only the account type can be
        # the reason for the refusal below.
        self._add_group_member(security.DOMAIN_RID_USERS, res[0].dn)

        m = self._primary_group_message(res[0].dn, security.DOMAIN_RID_USERS)
        self._assert_modify_fails(
            self.admin_samdb, m, ldb.ERR_UNWILLING_TO_PERFORM,
            "Unexpectedly able to set primaryGroupID to be other than DCS on %s" % computername)

    def test_primarygroupID_priv_member_modify(self):
        """Admin cannot move a WORKSTATION|PARTIAL_SECRETS account to Domain Users."""
        computername = self.computernames[0]

        self.add_computer_ldap(computername,
                               others={"userAccountControl": [str(UF_WORKSTATION_TRUST_ACCOUNT | UF_PARTIAL_SECRETS_ACCOUNT)]},
                               samdb=self.admin_samdb)
        res = self._search_computer(computername, [""])

        # Make the account a member first, so only the account type can be
        # the reason for the refusal below.
        self._add_group_member(security.DOMAIN_RID_USERS, res[0].dn)

        m = self._primary_group_message(res[0].dn, security.DOMAIN_RID_USERS)
        self._assert_modify_fails(
            self.admin_samdb, m, ldb.ERR_UNWILLING_TO_PERFORM,
            "Unexpectedly able to set primaryGroupID to be other than DCS on %s" % computername)

    def test_primarygroupID_priv_user_modify(self):
        """Admin can set a workstation account's primaryGroupID to Domain Admins."""
        computername = self.computernames[0]

        self.add_computer_ldap(computername,
                               others={"userAccountControl": [str(UF_WORKSTATION_TRUST_ACCOUNT)]},
                               samdb=self.admin_samdb)
        res = self._search_computer(computername, [""])

        self._add_group_member(security.DOMAIN_RID_ADMINS, res[0].dn)

        self.admin_samdb.modify(self._primary_group_message(res[0].dn,
                                                            security.DOMAIN_RID_ADMINS))
class UserAccountControlTests(samba.tests.TestCase):
    """Check which userAccountControl / primaryGroupID changes are allowed
    on computer accounts.

    Most tests act as an unprivileged user (self.samdb) that is granted
    create-child rights on OU=test_computer_ou1; verification and resets
    are done through an administrative connection (self.admin_samdb).
    """

    def add_computer_ldap(self, computername, others=None, samdb=None):
        """Add computer account *computername* under OU=test_computer_ou1.

        :param computername: account name without the trailing '$'; the
            sAMAccountName is set to "<computername>$".
        :param others: optional mapping of extra attributes; its entries
            override the default "dn" and "objectclass" keys.
        :param samdb: connection used to perform the add; defaults to the
            unprivileged connection self.samdb.
        :raises LdbError: if the server rejects the add.
        """
        if samdb is None:
            samdb = self.samdb
        msg_dict = {"dn": self._computer_dn(computername),
                    "objectclass": "computer"}
        if others is not None:
            msg_dict.update(others)

        # Build the message with the same connection that performs the add.
        msg = ldb.Message.from_dict(samdb, msg_dict)
        msg["sAMAccountName"] = "%s$" % computername

        print("Adding computer account %s" % computername)
        samdb.add(msg)

    def get_creds(self, target_username, target_password):
        """Return sealing, non-Kerberos Credentials for the given user.

        Domain, realm and workstation are copied from the global admin
        ``creds``.
        """
        creds_tmp = Credentials()
        creds_tmp.set_username(target_username)
        creds_tmp.set_password(target_password)
        creds_tmp.set_domain(creds.get_domain())
        creds_tmp.set_realm(creds.get_realm())
        creds_tmp.set_workstation(creds.get_workstation())
        creds_tmp.set_gensec_features(creds_tmp.get_gensec_features()
                                      | gensec.FEATURE_SEAL)
        creds_tmp.set_kerberos_state(
            DONT_USE_KERBEROS)  # kinit is too expensive to use in a tight loop
        return creds_tmp

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _ou_dn(self):
        """Return the DN of the OU that holds the test computer accounts."""
        return "OU=test_computer_ou1,%s" % self.base_dn

    def _computer_dn(self, computername):
        """Return the DN of computer *computername* inside the test OU."""
        return "CN=%s,%s" % (computername, self._ou_dn())

    def _allow_unpriv_create_computer(self):
        """Add an ACE granting the unprivileged user CC on the test OU.

        The ACE is object-specific for GUID
        bf967a86-0de6-11d0-a285-00aa003049e2 (presumably the computer class
        schemaIDGUID; setUp relies on it to let self.samdb create computers).

        :returns: the OU's security descriptor as read before the ACE was
            added, so callers can restore it.
        """
        user_sid = self.sd_utils.get_object_sid(self.unpriv_user_dn)
        mod = "(OA;;CC;bf967a86-0de6-11d0-a285-00aa003049e2;;%s)" % str(
            user_sid)
        old_sd = self.sd_utils.read_sd_on_dn(self._ou_dn())
        self.sd_utils.dacl_add_ace(self._ou_dn(), mod)
        return old_sd

    def _search_computer(self, computername, attrs):
        """Find *computername*'s account via the admin connection."""
        return self.admin_samdb.search(
            "%s" % self.base_dn,
            expression="(&(objectClass=computer)(samAccountName=%s$))" %
            computername,
            scope=SCOPE_SUBTREE,
            attrs=attrs)

    def _get_uac(self, computername):
        """Return *computername*'s userAccountControl as an int."""
        res = self._search_computer(computername, ["userAccountControl"])
        return int(res[0]["userAccountControl"][0])

    @staticmethod
    def _replace_msg(dn, attr, values):
        """Build a message that replaces *attr* on *dn* with *values*."""
        m = ldb.Message()
        m.dn = dn
        m[attr] = ldb.MessageElement(values, ldb.FLAG_MOD_REPLACE, attr)
        return m

    def _add_group_member(self, rid, member_dn):
        """As admin, add *member_dn* to the domain group with RID *rid*."""
        m = ldb.Message()
        m.dn = ldb.Dn(self.admin_samdb,
                      "<SID=%s-%d>" % (str(self.domain_sid), rid))
        m["member"] = ldb.MessageElement([str(member_dn)], ldb.FLAG_MOD_ADD,
                                         "member")
        self.admin_samdb.modify(m)

    def _assert_modify_fails(self, samdb, m, expected_err, fail_msg):
        """Assert that ``samdb.modify(m)`` raises LdbError *expected_err*.

        Fails the test with *fail_msg* if the modify succeeds.
        """
        try:
            samdb.modify(m)
        except LdbError as e:
            (enum, estr) = e.args
            self.assertEqual(expected_err, enum)
        else:
            self.fail(fail_msg)

    # ------------------------------------------------------------------
    # Fixture
    # ------------------------------------------------------------------

    def setUp(self):
        super(UserAccountControlTests, self).setUp()
        self.admin_creds = creds
        self.admin_samdb = SamDB(url=ldaphost,
                                 session_info=system_session(),
                                 credentials=self.admin_creds,
                                 lp=lp)
        self.domain_sid = security.dom_sid(self.admin_samdb.get_domain_sid())
        self.base_dn = self.admin_samdb.domain_dn()

        self.unpriv_user = "******"
        self.unpriv_user_pw = "samba123@"
        self.unpriv_creds = self.get_creds(self.unpriv_user,
                                           self.unpriv_user_pw)
        unpriv_user_dn = "CN=%s,CN=Users,%s" % (self.unpriv_user,
                                                self.base_dn)

        # Clear out anything a previous run may have left behind.
        delete_force(self.admin_samdb, self._computer_dn("testcomputer-t"))
        delete_force(self.admin_samdb, self._ou_dn())
        delete_force(self.admin_samdb, unpriv_user_dn)

        self.admin_samdb.newuser(self.unpriv_user, self.unpriv_user_pw)
        res = self.admin_samdb.search(unpriv_user_dn,
                                      scope=SCOPE_BASE,
                                      attrs=["objectSid"])
        self.assertEqual(1, len(res))

        self.unpriv_user_sid = ndr_unpack(security.dom_sid,
                                          res[0]["objectSid"][0])
        self.unpriv_user_dn = res[0].dn

        self.samdb = SamDB(url=ldaphost, credentials=self.unpriv_creds, lp=lp)

        self.samr = samr.samr("ncacn_ip_tcp:%s[seal]" % host, lp,
                              self.unpriv_creds)
        self.samr_handle = self.samr.Connect2(
            None, security.SEC_FLAG_MAXIMUM_ALLOWED)
        self.samr_domain = self.samr.OpenDomain(
            self.samr_handle, security.SEC_FLAG_MAXIMUM_ALLOWED,
            self.domain_sid)

        self.sd_utils = sd_utils.SDUtils(self.admin_samdb)

        self.admin_samdb.create_ou(self._ou_dn())

        # Temporarily let the unprivileged user create the template computer
        # account, then put the OU's original security descriptor back.
        old_sd = self._allow_unpriv_create_computer()
        self.add_computer_ldap("testcomputer-t")
        self.sd_utils.modify_sd_on_dn(self._ou_dn(), old_sd)

        self.computernames = ["testcomputer-0"]

        # Get the SD of the template account, then force it to match
        # what we expect for SeMachineAccountPrivilege accounts, so we
        # can confirm we created the accounts correctly
        self.sd_reference_cc = self.sd_utils.read_sd_on_dn(
            self._computer_dn("testcomputer-t"))

        self.sd_reference_modify = self.sd_utils.read_sd_on_dn(
            self._computer_dn("testcomputer-t"))
        for ace in self.sd_reference_modify.dacl.aces:
            if (ace.type == security.SEC_ACE_TYPE_ACCESS_ALLOWED
                    and ace.trustee == self.unpriv_user_sid):
                ace.access_mask = (ace.access_mask
                                   | security.SEC_ADS_SELF_WRITE
                                   | security.SEC_ADS_WRITE_PROP)

        # Open a fresh connection without domain admin rights
        self.samdb = SamDB(url=ldaphost, credentials=self.unpriv_creds, lp=lp)

    def tearDown(self):
        super(UserAccountControlTests, self).tearDown()
        for computername in self.computernames:
            delete_force(self.admin_samdb, self._computer_dn(computername))
        delete_force(self.admin_samdb, self._computer_dn("testcomputer-t"))
        delete_force(self.admin_samdb, self._ou_dn())
        delete_force(self.admin_samdb,
                     "CN=%s,CN=Users,%s" % (self.unpriv_user, self.base_dn))

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    def test_add_computer_sd_cc(self):
        """An unprivileged creator can supply an SD, but cannot make the
        account a DC, RODC, workstation or Domain Admins member."""
        self._allow_unpriv_create_computer()

        computername = self.computernames[0]
        sd = ldb.MessageElement((ndr_pack(self.sd_reference_modify)),
                                ldb.FLAG_MOD_ADD, "nTSecurityDescriptor")
        self.add_computer_ldap(computername,
                               others={"nTSecurityDescriptor": sd})

        res = self._search_computer(computername, ["ntSecurityDescriptor"])

        desc = res[0]["nTSecurityDescriptor"][0]
        desc = ndr_unpack(security.descriptor, desc, allow_remaining=True)

        sddl = desc.as_sddl(self.domain_sid)
        self.assertEqual(self.sd_reference_modify.as_sddl(self.domain_sid),
                         sddl)

        dn = res[0].dn
        self.samdb.modify(self._replace_msg(dn, "description",
                                            "A description"))

        self._assert_modify_fails(
            self.samdb,
            self._replace_msg(dn, "userAccountControl",
                              str(samba.dsdb.UF_SERVER_TRUST_ACCOUNT)),
            ldb.ERR_INSUFFICIENT_ACCESS_RIGHTS,
            "Unexpectedly able to set userAccountControl to be a DC on %s"
            % dn)

        self._assert_modify_fails(
            self.samdb,
            self._replace_msg(dn, "userAccountControl",
                              str(samba.dsdb.UF_WORKSTATION_TRUST_ACCOUNT
                                  | samba.dsdb.UF_PARTIAL_SECRETS_ACCOUNT)),
            ldb.ERR_INSUFFICIENT_ACCESS_RIGHTS,
            "Unexpectedly able to set userAccountControl to be an RODC on %s"
            % dn)

        self._assert_modify_fails(
            self.samdb,
            self._replace_msg(dn, "userAccountControl",
                              str(samba.dsdb.UF_WORKSTATION_TRUST_ACCOUNT)),
            ldb.ERR_INSUFFICIENT_ACCESS_RIGHTS,
            "Unexpectedly able to set userAccountControl to be a Workstation on %s"
            % dn)

        self.samdb.modify(self._replace_msg(
            dn, "userAccountControl", str(samba.dsdb.UF_NORMAL_ACCOUNT)))

        self._assert_modify_fails(
            self.samdb,
            self._replace_msg(dn, "primaryGroupID",
                              str(security.DOMAIN_RID_ADMINS)),
            ldb.ERR_UNWILLING_TO_PERFORM,
            "Unexpectedly able to set primaryGroupID to be an admin on %s"
            % dn)

    def test_mod_computer_cc(self):
        """An unprivileged creator cannot turn a normal computer account
        into an RODC, DC or workstation account."""
        self._allow_unpriv_create_computer()

        computername = self.computernames[0]
        self.add_computer_ldap(computername)

        dn = self._search_computer(computername, [])[0].dn

        self.samdb.modify(self._replace_msg(dn, "description",
                                            "A description"))

        self._assert_modify_fails(
            self.samdb,
            self._replace_msg(dn, "userAccountControl",
                              str(samba.dsdb.UF_WORKSTATION_TRUST_ACCOUNT
                                  | samba.dsdb.UF_PARTIAL_SECRETS_ACCOUNT)),
            ldb.ERR_INSUFFICIENT_ACCESS_RIGHTS,
            "Unexpectedly able to set userAccountControl on %s" % dn)

        self._assert_modify_fails(
            self.samdb,
            self._replace_msg(dn, "userAccountControl",
                              str(samba.dsdb.UF_SERVER_TRUST_ACCOUNT)),
            ldb.ERR_INSUFFICIENT_ACCESS_RIGHTS,
            "Unexpectedly able to set userAccountControl to be a DC on %s"
            % dn)

        self.samdb.modify(self._replace_msg(
            dn, "userAccountControl", str(samba.dsdb.UF_NORMAL_ACCOUNT)))

        self._assert_modify_fails(
            self.samdb,
            self._replace_msg(dn, "userAccountControl",
                              str(samba.dsdb.UF_WORKSTATION_TRUST_ACCOUNT)),
            ldb.ERR_INSUFFICIENT_ACCESS_RIGHTS,
            "Unexpectedly able to set userAccountControl to be a Workstation on %s"
            % dn)

    def test_admin_mod_uac(self):
        """Check how userAccountControl values set by an admin are
        accepted, rejected and normalised."""
        computername = self.computernames[0]
        self.add_computer_ldap(computername, samdb=self.admin_samdb)

        res = self._search_computer(computername, ["userAccountControl"])
        dn = res[0].dn
        self.assertEqual(
            int(res[0]["userAccountControl"][0]),
            (UF_NORMAL_ACCOUNT | UF_ACCOUNTDISABLE | UF_PASSWD_NOTREQD))

        self._assert_modify_fails(
            self.admin_samdb,
            self._replace_msg(dn, "userAccountControl",
                              str(UF_WORKSTATION_TRUST_ACCOUNT
                                  | UF_PARTIAL_SECRETS_ACCOUNT
                                  | UF_TRUSTED_FOR_DELEGATION)),
            ldb.ERR_OTHER,
            "Unexpectedly able to set userAccountControl to UF_WORKSTATION_TRUST_ACCOUNT|UF_PARTIAL_SECRETS_ACCOUNT|UF_TRUSTED_FOR_DELEGATION on %s"
            % dn)

        self.admin_samdb.modify(self._replace_msg(
            dn, "userAccountControl",
            str(UF_WORKSTATION_TRUST_ACCOUNT | UF_PARTIAL_SECRETS_ACCOUNT)))
        self.assertEqual(
            self._get_uac(computername),
            (UF_WORKSTATION_TRUST_ACCOUNT | UF_PARTIAL_SECRETS_ACCOUNT))

        # Setting only UF_ACCOUNTDISABLE is expected to come back with
        # UF_NORMAL_ACCOUNT added.
        self.admin_samdb.modify(self._replace_msg(
            dn, "userAccountControl", str(UF_ACCOUNTDISABLE)))
        self.assertEqual(self._get_uac(computername),
                         UF_NORMAL_ACCOUNT | UF_ACCOUNTDISABLE)

    def test_uac_bits_set(self):
        """Try each userAccountControl bit as an unprivileged creator."""
        self._allow_unpriv_create_computer()

        computername = self.computernames[0]
        self.add_computer_ldap(computername)

        dn = self._search_computer(computername, [])[0].dn

        self.samdb.modify(self._replace_msg(dn, "description",
                                            "A description"))

        # UF_PASSWD_NOTREQD, UF_ENCRYPTED_TEXT_PASSWORD_ALLOWED and
        # UF_DONT_EXPIRE_PASSWD are privileged too, but authenticated users
        # have that CAR by default, so this is a pain to test.

        # These bits really are privileged, or can't be changed from UF_NORMAL as a non-admin
        priv_bits = set([
            UF_INTERDOMAIN_TRUST_ACCOUNT, UF_SERVER_TRUST_ACCOUNT,
            UF_TRUSTED_FOR_DELEGATION,
            UF_TRUSTED_TO_AUTHENTICATE_FOR_DELEGATION,
            UF_WORKSTATION_TRUST_ACCOUNT
        ])

        invalid_bits = set(
            [UF_TEMP_DUPLICATE_ACCOUNT, UF_PARTIAL_SECRETS_ACCOUNT])

        for bit in bits:
            m = self._replace_msg(dn, "userAccountControl",
                                  str(bit | UF_PASSWD_NOTREQD))
            try:
                self.samdb.modify(m)
            except LdbError as e:
                (enum, estr) = e.args
                if bit in invalid_bits:
                    self.assertEqual(
                        enum, ldb.ERR_OTHER,
                        "was not able to set 0x%08X on %s" % (bit, m.dn))
                elif bit in priv_bits:
                    self.assertEqual(ldb.ERR_INSUFFICIENT_ACCESS_RIGHTS, enum)
                else:
                    self.fail(
                        "Unable to set userAccountControl bit 0x%08X on %s: %s"
                        % (bit, m.dn, estr))
            else:
                if bit in priv_bits:
                    self.fail(
                        "Unexpectedly able to set userAccountControl bit 0x%08X on %s"
                        % (bit, m.dn))

    def uac_bits_unrelated_modify_helper(self, account_type):
        """For every bit: reset to *account_type*, set the bit as admin, add
        UF_ACCOUNTDISABLE as the unprivileged user, then try to clear the
        bit as the unprivileged user, checking the stored value each time.

        :param account_type: userAccountControl value the computer account
            is created with and reset to before each bit.
        """
        self._allow_unpriv_create_computer()

        computername = self.computernames[0]
        self.add_computer_ldap(
            computername, others={"userAccountControl": [str(account_type)]})

        res = self._search_computer(computername, ["userAccountControl"])
        self.assertEqual(int(res[0]["userAccountControl"][0]), account_type)
        dn = res[0].dn

        self.samdb.modify(self._replace_msg(dn, "description",
                                            "A description"))

        invalid_bits = set(
            [UF_TEMP_DUPLICATE_ACCOUNT, UF_PARTIAL_SECRETS_ACCOUNT])

        # UF_LOCKOUT isn't actually ignored, it changes other
        # attributes but does not stick here.  See MS-SAMR 2.2.1.13
        # UF_FLAG Codes clarification that UF_SCRIPT and
        # UF_PASSWD_CANT_CHANGE are simply ignored by both clients and
        # servers.  Other bits are ignored as they are undefined, or
        # are not set into the attribute (instead triggering other
        # events).
        ignored_bits = set([
            UF_SCRIPT, UF_00000004, UF_LOCKOUT, UF_PASSWD_CANT_CHANGE,
            UF_00000400, UF_00004000, UF_00008000, UF_PASSWORD_EXPIRED,
            int("0x10000000", 16),
            int("0x20000000", 16),
            int("0x40000000", 16),
            int("0x80000000", 16)
        ])
        super_priv_bits = set([UF_INTERDOMAIN_TRUST_ACCOUNT])

        priv_to_remove_bits = set([
            UF_TRUSTED_FOR_DELEGATION,
            UF_TRUSTED_TO_AUTHENTICATE_FOR_DELEGATION,
            UF_WORKSTATION_TRUST_ACCOUNT
        ])

        for bit in bits:
            # Reset this to the initial position, just to be sure
            self.admin_samdb.modify(self._replace_msg(
                dn, "userAccountControl", str(account_type)))
            self.assertEqual(self._get_uac(computername), account_type)

            # Step 1: the admin sets the bit (plus UF_PASSWD_NOTREQD).
            m = self._replace_msg(dn, "userAccountControl",
                                  str(bit | UF_PASSWD_NOTREQD))
            try:
                self.admin_samdb.modify(m)
            except LdbError as e1:
                (enum, estr) = e1.args
                if bit in invalid_bits:
                    self.assertEqual(enum, ldb.ERR_OTHER)
                    # No point going on, try the next bit
                    continue
                elif bit in super_priv_bits:
                    self.assertEqual(enum, ldb.ERR_INSUFFICIENT_ACCESS_RIGHTS)
                    # No point going on, try the next bit
                    continue
                else:
                    self.fail(
                        "Unable to set userAccountControl bit 0x%08X on %s: %s"
                        % (bit, m.dn, estr))
            else:
                if bit in invalid_bits:
                    self.fail(
                        "Should have been unable to set userAccountControl bit 0x%08X on %s"
                        % (bit, m.dn))

            uac = self._get_uac(computername)
            if bit in ignored_bits:
                self.assertEqual(uac, UF_NORMAL_ACCOUNT | UF_PASSWD_NOTREQD,
                                 "Bit 0x%08x shouldn't stick" % bit)
            elif bit in account_types:
                self.assertEqual(uac, bit | UF_PASSWD_NOTREQD,
                                 "Bit 0x%08x didn't stick" % bit)
            else:
                self.assertEqual(uac,
                                 bit | UF_NORMAL_ACCOUNT | UF_PASSWD_NOTREQD,
                                 "Bit 0x%08x didn't stick" % bit)

            # Step 2: the unprivileged user adds UF_ACCOUNTDISABLE.
            m = self._replace_msg(
                dn, "userAccountControl",
                str(bit | UF_PASSWD_NOTREQD | UF_ACCOUNTDISABLE))
            try:
                self.samdb.modify(m)
            except LdbError as e2:
                (enum, estr) = e2.args
                self.fail(
                    "Unable to set userAccountControl bit 0x%08X on %s: %s" %
                    (bit, m.dn, estr))

            uac = self._get_uac(computername)
            if bit in account_types:
                expected = bit | UF_ACCOUNTDISABLE | UF_PASSWD_NOTREQD
                outcome = "added"
            elif bit in ignored_bits:
                expected = (UF_NORMAL_ACCOUNT | UF_ACCOUNTDISABLE
                            | UF_PASSWD_NOTREQD)
                outcome = "ignored"
            else:
                expected = (bit | UF_NORMAL_ACCOUNT | UF_ACCOUNTDISABLE
                            | UF_PASSWD_NOTREQD)
                outcome = "added"
            self.assertEqual(
                uac, expected,
                "bit 0X%08x should have been %s (0X%08x vs 0X%08x)" %
                (bit, outcome, uac, expected))

            # Step 3: the unprivileged user tries to clear the bit again.
            m = self._replace_msg(dn, "userAccountControl",
                                  str(UF_PASSWD_NOTREQD | UF_ACCOUNTDISABLE))
            try:
                self.samdb.modify(m)
            except LdbError as e3:
                (enum, estr) = e3.args
                if bit in priv_to_remove_bits:
                    self.assertEqual(enum, ldb.ERR_INSUFFICIENT_ACCESS_RIGHTS)
                else:
                    self.fail(
                        "Unexpectedly unable to remove userAccountControl bit 0x%08X on %s: %s"
                        % (bit, m.dn, estr))
            else:
                if bit in priv_to_remove_bits:
                    self.fail(
                        "Should have been unable to remove userAccountControl bit 0x%08X on %s"
                        % (bit, m.dn))

            uac = self._get_uac(computername)
            if bit in priv_to_remove_bits:
                if bit in account_types:
                    expected = bit | UF_ACCOUNTDISABLE | UF_PASSWD_NOTREQD
                else:
                    expected = (bit | UF_NORMAL_ACCOUNT | UF_ACCOUNTDISABLE
                                | UF_PASSWD_NOTREQD)
                self.assertEqual(
                    uac, expected,
                    "bit 0X%08x should not have been removed" % bit)
            else:
                self.assertEqual(
                    uac,
                    UF_NORMAL_ACCOUNT | UF_ACCOUNTDISABLE | UF_PASSWD_NOTREQD,
                    "bit 0X%08x should have been removed" % bit)

    def test_uac_bits_unrelated_modify_normal(self):
        self.uac_bits_unrelated_modify_helper(UF_NORMAL_ACCOUNT)

    def test_uac_bits_unrelated_modify_workstation(self):
        self.uac_bits_unrelated_modify_helper(UF_WORKSTATION_TRUST_ACCOUNT)

    def test_uac_bits_add(self):
        """Try creating a computer with each userAccountControl bit as an
        unprivileged creator."""
        computername = self.computernames[0]

        self._allow_unpriv_create_computer()

        invalid_bits = set(
            [UF_TEMP_DUPLICATE_ACCOUNT, UF_PARTIAL_SECRETS_ACCOUNT])

        # UF_PASSWD_NOTREQD, UF_ENCRYPTED_TEXT_PASSWORD_ALLOWED and
        # UF_DONT_EXPIRE_PASSWD are privileged too, but authenticated users
        # have that CAR by default, so this is a pain to test.

        # These bits really are privileged
        priv_bits = set([
            UF_INTERDOMAIN_TRUST_ACCOUNT, UF_SERVER_TRUST_ACCOUNT,
            UF_TRUSTED_FOR_DELEGATION,
            UF_TRUSTED_TO_AUTHENTICATE_FOR_DELEGATION
        ])

        for bit in bits:
            try:
                self.add_computer_ldap(
                    computername, others={"userAccountControl": [str(bit)]})
            except LdbError as e4:
                (enum, estr) = e4.args
                if bit in invalid_bits:
                    self.assertEqual(
                        enum, ldb.ERR_OTHER,
                        "Invalid bit 0x%08X was able to be set on %s" %
                        (bit, computername))
                elif bit in priv_bits:
                    self.assertEqual(enum, ldb.ERR_INSUFFICIENT_ACCESS_RIGHTS)
                else:
                    self.fail(
                        "Unable to set userAccountControl bit 0x%08X on %s: %s"
                        % (bit, computername, estr))
            else:
                # Remove it so the next iteration can create it again.
                delete_force(self.admin_samdb,
                             self._computer_dn(computername))
                if bit in priv_bits:
                    self.fail(
                        "Unexpectedly able to set userAccountControl bit 0x%08X on %s"
                        % (bit, computername))

    def test_primarygroupID_cc_add(self):
        """primaryGroupID cannot be supplied when creating a computer."""
        computername = self.computernames[0]

        self._allow_unpriv_create_computer()
        try:
            # When creating a new object, you can not ever set the primaryGroupID
            self.add_computer_ldap(
                computername,
                others={"primaryGroupID": [str(security.DOMAIN_RID_ADMINS)]})
        except LdbError as e13:
            (enum, estr) = e13.args
            self.assertEqual(enum, ldb.ERR_UNWILLING_TO_PERFORM)
        else:
            self.fail(
                "Unexpectedly able to set primaryGroupID to be an admin on %s"
                % computername)

    def test_primarygroupID_priv_DC_modify(self):
        """Even an admin may not move a UF_SERVER_TRUST_ACCOUNT's
        primaryGroupID to DOMAIN_RID_USERS."""
        computername = self.computernames[0]

        self.add_computer_ldap(
            computername,
            others={"userAccountControl": [str(UF_SERVER_TRUST_ACCOUNT)]},
            samdb=self.admin_samdb)
        dn = self._search_computer(computername, [""])[0].dn

        self._add_group_member(security.DOMAIN_RID_USERS, dn)

        self._assert_modify_fails(
            self.admin_samdb,
            self._replace_msg(dn, "primaryGroupID",
                              [str(security.DOMAIN_RID_USERS)]),
            ldb.ERR_UNWILLING_TO_PERFORM,
            "Unexpectedly able to set primaryGroupID to be other than DCS on %s"
            % computername)

    def test_primarygroupID_priv_member_modify(self):
        """Even an admin may not move a workstation + partial-secrets
        account's primaryGroupID to DOMAIN_RID_USERS."""
        computername = self.computernames[0]

        self.add_computer_ldap(computername,
                               others={
                                   "userAccountControl": [
                                       str(UF_WORKSTATION_TRUST_ACCOUNT
                                           | UF_PARTIAL_SECRETS_ACCOUNT)
                                   ]
                               },
                               samdb=self.admin_samdb)
        dn = self._search_computer(computername, [""])[0].dn

        self._add_group_member(security.DOMAIN_RID_USERS, dn)

        self._assert_modify_fails(
            self.admin_samdb,
            self._replace_msg(dn, "primaryGroupID",
                              [str(security.DOMAIN_RID_USERS)]),
            ldb.ERR_UNWILLING_TO_PERFORM,
            "Unexpectedly able to set primaryGroupID to be other than DCS on %s"
            % computername)

    def test_primarygroupID_priv_user_modify(self):
        """An admin may set a workstation account's primaryGroupID to
        DOMAIN_RID_ADMINS after adding it to that group."""
        computername = self.computernames[0]

        self.add_computer_ldap(
            computername,
            others={"userAccountControl": [str(UF_WORKSTATION_TRUST_ACCOUNT)]},
            samdb=self.admin_samdb)
        dn = self._search_computer(computername, [""])[0].dn

        self._add_group_member(security.DOMAIN_RID_ADMINS, dn)

        # Expected to succeed: any exception fails the test.
        self.admin_samdb.modify(self._replace_msg(
            dn, "primaryGroupID", [str(security.DOMAIN_RID_ADMINS)]))
Example #20
0
class DynamicTokenTest(samba.tests.TestCase):
    """Cross-check a test user's token groups computed in several ways.

    setUp() creates a test user and a small hierarchy of security groups. It
    then opens an LDAP connection as that user and stores the string SIDs of
    the security token built by samba.auth.user_session() in
    ``self.user_sids``. The tests compare that set against:

    - the rootDSE and user-object tokenGroups attributes;
    - the session info from a GSSAPI exchange;
    - SAMR GetGroupsForUser;
    - a manual closure over memberOf/primaryGroupID.
    """

    def get_creds(self, target_username, target_password):
        """Return Credentials for *target_username* with sealing enabled.

        Domain, realm and workstation are copied from the global ``creds``.
        """
        creds_tmp = Credentials()
        creds_tmp.set_username(target_username)
        creds_tmp.set_password(target_password)
        creds_tmp.set_domain(creds.get_domain())
        creds_tmp.set_realm(creds.get_realm())
        creds_tmp.set_workstation(creds.get_workstation())
        creds_tmp.set_gensec_features(creds_tmp.get_gensec_features()
                                      | gensec.FEATURE_SEAL)
        return creds_tmp

    def get_ldb_connection(self, target_username, target_password):
        """Open a SamDB on ``url`` authenticated as *target_username*."""
        creds_tmp = self.get_creds(target_username, target_password)
        ldb_target = SamDB(url=url, credentials=creds_tmp, lp=lp)
        return ldb_target

    def _add_group(self, name, grouptype, members):
        """Create group CN=*name* under CN=Users, add *members*, return its SID.

        :param name: CN of the new group.
        :param grouptype: groupType value passed to newgroup().
        :param members: account names added as members of the new group.
        :return: the group's objectSid unpacked as a dom_sid.
        """
        self.admin_ldb.newgroup(name, grouptype=grouptype)
        res = self.admin_ldb.search(base="cn=%s,cn=users,%s" %
                                    (name, self.base_dn),
                                    attrs=["objectSid"],
                                    scope=ldb.SCOPE_BASE)
        group_sid = ndr_unpack(samba.dcerpc.security.dom_sid,
                               res[0]["objectSid"][0])
        self.admin_ldb.add_remove_group_members(name,
                                                members,
                                                add_members_operation=True)
        return group_sid

    def setUp(self):
        super(DynamicTokenTest, self).setUp()
        self.admin_ldb = SamDB(url,
                               credentials=creds,
                               session_info=system_session(lp),
                               lp=lp)

        self.base_dn = self.admin_ldb.domain_dn()

        self.test_user = "******"
        self.test_user_pass = "******"
        self.admin_ldb.newuser(self.test_user, self.test_user_pass)

        # Resulting membership (member -> group):
        #   test_user -> group0 (domain local), group1 (global),
        #                group2 (universal), group6 (domain local)
        #   group1 -> group3 (universal) -> group4 (universal)
        #          -> group5 (domain local)
        self.test_group0 = "tokengroups_group0"
        self.test_group0_sid = self._add_group(
            self.test_group0, dsdb.GTYPE_SECURITY_DOMAIN_LOCAL_GROUP,
            [self.test_user])

        self.test_group1 = "tokengroups_group1"
        self.test_group1_sid = self._add_group(
            self.test_group1, dsdb.GTYPE_SECURITY_GLOBAL_GROUP,
            [self.test_user])

        self.test_group2 = "tokengroups_group2"
        self.test_group2_sid = self._add_group(
            self.test_group2, dsdb.GTYPE_SECURITY_UNIVERSAL_GROUP,
            [self.test_user])

        self.test_group3 = "tokengroups_group3"
        self.test_group3_sid = self._add_group(
            self.test_group3, dsdb.GTYPE_SECURITY_UNIVERSAL_GROUP,
            [self.test_group1])

        self.test_group4 = "tokengroups_group4"
        self.test_group4_sid = self._add_group(
            self.test_group4, dsdb.GTYPE_SECURITY_UNIVERSAL_GROUP,
            [self.test_group3])

        self.test_group5 = "tokengroups_group5"
        self.test_group5_sid = self._add_group(
            self.test_group5, dsdb.GTYPE_SECURITY_DOMAIN_LOCAL_GROUP,
            [self.test_group4])

        self.test_group6 = "tokengroups_group6"
        self.test_group6_sid = self._add_group(
            self.test_group6, dsdb.GTYPE_SECURITY_DOMAIN_LOCAL_GROUP,
            [self.test_user])

        self.ldb = self.get_ldb_connection(self.test_user, self.test_user_pass)

        res = self.ldb.search("", scope=ldb.SCOPE_BASE, attrs=["tokenGroups"])
        self.assertEqual(len(res), 1)

        # The first rootDSE tokenGroups value is used as the user's own SID;
        # it is resolved back to the user's DN just below.
        self.user_sid = ndr_unpack(samba.dcerpc.security.dom_sid,
                                   res[0]["tokenGroups"][0])
        self.user_sid_dn = "<SID=%s>" % str(self.user_sid)

        res = self.ldb.search(self.user_sid_dn, scope=ldb.SCOPE_BASE, attrs=[])
        self.assertEqual(len(res), 1)

        self.test_user_dn = res[0].dn

        session_info_flags = (AUTH_SESSION_INFO_DEFAULT_GROUPS
                              | AUTH_SESSION_INFO_AUTHENTICATED
                              | AUTH_SESSION_INFO_SIMPLE_PRIVILEGES)

        # Request the NTLM session flag when Kerberos is disabled on the
        # global creds.
        if creds.get_kerberos_state() == DONT_USE_KERBEROS:
            session_info_flags |= AUTH_SESSION_INFO_NTLM

        session = samba.auth.user_session(
            self.ldb,
            lp_ctx=lp,
            dn=self.user_sid_dn,
            session_info_flags=session_info_flags)

        self.user_sids = [str(s) for s in session.security_token.sids]

    def tearDown(self):
        super(DynamicTokenTest, self).tearDown()
        for name in (self.test_user, self.test_group0, self.test_group1,
                     self.test_group2, self.test_group3, self.test_group4,
                     self.test_group5, self.test_group6):
            delete_force(self.admin_ldb,
                         "CN=%s,%s,%s" % (name, "cn=users", self.base_dn))

    def _sids_to_casefold_dns(self, packed_sids):
        """Resolve NDR-packed SIDs to the casefolded DNs of their objects.

        :param packed_sids: iterable of NDR-packed dom_sid values, e.g. the
            values of a tokenGroups attribute.
        :return: set of casefolded DN strings, looked up via admin_ldb.
        """
        dns = set()
        for packed in packed_sids:
            sid = ndr_unpack(samba.dcerpc.security.dom_sid, packed)
            res = self.admin_ldb.search(base="<SID=%s>" % sid,
                                        scope=ldb.SCOPE_BASE,
                                        attrs=[])
            dns.add(res[0].dn.get_casefold())
        return dns

    def _membership_graph(self, include_all_objects=False):
        """Return ``(vertices, edges)`` of the domain's membership graph.

        An edge ``(member, group)`` of casefolded DN strings is recorded for
        every memberOf value of a user or group. An edge is also recorded
        from every user to the group named by its primaryGroupID. Vertices
        are the endpoints of those edges. When *include_all_objects* is
        true, every user and group DN is a vertex even if it has no memberOf
        values.
        """
        vertices = set()
        edges = set()

        res = self.admin_ldb.search(
            base=self.base_dn,
            scope=ldb.SCOPE_SUBTREE,
            expression="(|(objectclass=user)(objectclass=group))",
            attrs=["memberOf"])
        for obj in res:
            member = obj.dn.get_casefold()
            if include_all_objects:
                vertices.add(member)
            if "memberOf" in obj:
                for dn in obj["memberOf"]:
                    group = ldb.Dn(self.admin_ldb, dn).get_casefold()
                    edges.add((member, group))
                    vertices.add(member)
                    vertices.add(group)

        # primaryGroupID is a RID relative to the domain SID.
        domain_sid = self.admin_ldb.get_domain_sid()
        res = self.admin_ldb.search(base=self.base_dn,
                                    scope=ldb.SCOPE_SUBTREE,
                                    expression="(objectclass=user)",
                                    attrs=["primaryGroupID"])
        for obj in res:
            if "primaryGroupID" in obj:
                sid = "%s-%d" % (domain_sid, int(obj["primaryGroupID"][0]))
                res2 = self.admin_ldb.search(base="<SID=%s>" % sid,
                                             scope=ldb.SCOPE_BASE,
                                             attrs=[])
                member = obj.dn.get_casefold()
                group = res2[0].dn.get_casefold()

                edges.add((member, group))
                vertices.add(member)
                vertices.add(group)

        return vertices, edges

    def test_rootDSE_tokenGroups(self):
        """Testing rootDSE tokengroups against internal calculation"""
        if not url.startswith("ldap"):
            self.fail(msg="This test is only valid on ldap")

        res = self.ldb.search("", scope=ldb.SCOPE_BASE, attrs=["tokenGroups"])
        self.assertEqual(len(res), 1)

        print("Getting tokenGroups from rootDSE")
        tokengroups = [str(ndr_unpack(samba.dcerpc.security.dom_sid, sid))
                       for sid in res[0]['tokenGroups']]

        sidset1 = set(tokengroups)
        sidset2 = set(self.user_sids)
        if sidset1.difference(sidset2):
            print("token sids don't match")
            print("tokengroups: %s" % tokengroups)
            print("calculated : %s" % self.user_sids)
            print("difference : %s" % sidset1.difference(sidset2))
            self.fail(
                msg="calculated groups don't match against rootDSE tokenGroups"
            )

    def test_dn_tokenGroups(self):
        """Testing user-object tokenGroups against internal calculation"""
        print("Getting tokenGroups from user DN")
        res = self.ldb.search(self.user_sid_dn,
                              scope=ldb.SCOPE_BASE,
                              attrs=["tokenGroups"])
        self.assertEqual(len(res), 1)

        sidset1 = {str(ndr_unpack(samba.dcerpc.security.dom_sid, sid))
                   for sid in res[0]['tokenGroups']}
        sidset2 = set(self.user_sids)

        # The SIDs on the DN do not include the NTLM authentication SID
        sidset2.discard(samba.dcerpc.security.SID_NT_NTLM_AUTHENTICATION)

        if sidset1.difference(sidset2):
            print("token sids don't match")
            print("difference : %s" % sidset1.difference(sidset2))
            self.fail(
                msg="calculated groups don't match against user DN tokenGroups"
            )

    def test_pac_groups(self):
        """Compare the token from a local GSSAPI exchange with user_sids."""
        settings = {}
        settings["lp_ctx"] = lp
        settings["target_hostname"] = lp.get("netbios name")

        gensec_client = gensec.Security.start_client(settings)
        gensec_client.set_credentials(
            self.get_creds(self.test_user, self.test_user_pass))
        gensec_client.want_feature(gensec.FEATURE_SEAL)
        gensec_client.start_mech_by_sasl_name("GSSAPI")

        auth_context = AuthContext(lp_ctx=lp, ldb=self.ldb, methods=[])

        gensec_server = gensec.Security.start_server(settings, auth_context)
        machine_creds = Credentials()
        machine_creds.guess(lp)
        machine_creds.set_machine_account(lp)
        gensec_server.set_credentials(machine_creds)

        gensec_server.want_feature(gensec.FEATURE_SEAL)
        gensec_server.start_mech_by_sasl_name("GSSAPI")

        client_finished = False
        server_finished = False
        # Tokens are binary: bytes on Python 3 (b"" is plain str on Python 2).
        server_to_client = b""

        # Run the actual call loop; it stops as soon as either side reports
        # completion, after which the session info is read from the server.
        while not client_finished and not server_finished:
            if not client_finished:
                print("running client gensec_update")
                (client_finished,
                 client_to_server) = gensec_client.update(server_to_client)
            if not server_finished:
                print("running server gensec_update")
                (server_finished,
                 server_to_client) = gensec_server.update(client_to_server)

        session = gensec_server.session_info()

        sidset1 = {str(s) for s in session.security_token.sids}
        sidset2 = set(self.user_sids)
        if sidset1.difference(sidset2):
            print("token sids don't match")
            print("difference : %s" % sidset1.difference(sidset2))
            self.fail(
                msg="calculated groups don't match against user PAC tokenGroups"
            )

    def test_tokenGroups_manual(self):
        # Manually run the tokenGroups algorithm from MS-ADTS 3.1.1.4.5.19 and MS-DRSR 4.1.8.3
        # and compare the result
        vSet, aSet = self._membership_graph(include_all_objects=False)

        user_dn = self.test_user_dn.get_casefold()
        wSet = {user_dn}
        closure(vSet, wSet, aSet)
        wSet.remove(user_dn)

        res = self.ldb.search(self.user_sid_dn,
                              scope=ldb.SCOPE_BASE,
                              attrs=["tokenGroups"])
        self.assertEqual(len(res), 1)

        tokenGroupsSet = self._sids_to_casefold_dns(res[0]['tokenGroups'])

        if wSet.difference(tokenGroupsSet):
            self.fail(msg="additional calculated: %s" %
                      wSet.difference(tokenGroupsSet))

        if tokenGroupsSet.difference(wSet):
            self.fail(msg="additional tokenGroups: %s" %
                      tokenGroupsSet.difference(wSet))

    def filtered_closure(self, wSet, filter_grouptype):
        """Extend *wSet* in place with groups reachable via membership edges.

        Only vertices that are either not groups, or groups whose 32-bit
        groupType equals *filter_grouptype*, may be added. Membership edges
        come from memberOf and primaryGroupID (see _membership_graph).
        """
        vSet, aSet = self._membership_graph(include_all_objects=True)

        # groupType is stored as a signed 32-bit value; compare both sides
        # as unsigned 32-bit so sign-extended constants also match.
        wanted_grouptype = filter_grouptype & 0xFFFFFFFF

        uSet = set()
        for v in vSet:
            res_group = self.admin_ldb.search(base=v,
                                              scope=ldb.SCOPE_BASE,
                                              attrs=["groupType"],
                                              expression="objectClass=group")
            if len(res_group) == 1:
                grouptype = int(res_group[0]["groupType"][0]) & 0xFFFFFFFF
                if grouptype == wanted_grouptype:
                    uSet.add(v)
            else:
                # Not a group (e.g. the user itself): always allowed.
                uSet.add(v)

        closure(uSet, wSet, aSet)

    def test_tokenGroupsGlobalAndUniversal_manual(self):
        # Manually run the tokenGroups algorithm from MS-ADTS 3.1.1.4.5.19 and MS-DRSR 4.1.8.3
        # and compare the result

        # The variable names come from MS-ADTS May 15, 2014

        user_dn = self.test_user_dn.get_casefold()
        S = {user_dn}

        self.filtered_closure(S, GTYPE_SECURITY_GLOBAL_GROUP)

        T = set()
        # Not really a SID, we do this on DNs...
        for sid in S:
            X = {sid}
            self.filtered_closure(X, GTYPE_SECURITY_UNIVERSAL_GROUP)
            T |= X

        T.remove(user_dn)

        res = self.ldb.search(self.user_sid_dn,
                              scope=ldb.SCOPE_BASE,
                              attrs=["tokenGroupsGlobalAndUniversal"])
        self.assertEqual(len(res), 1)

        tokenGroupsSet = self._sids_to_casefold_dns(
            res[0]['tokenGroupsGlobalAndUniversal'])

        if T.difference(tokenGroupsSet):
            self.fail(msg="additional calculated: %s" %
                      T.difference(tokenGroupsSet))

        if tokenGroupsSet.difference(T):
            self.fail(msg="additional tokenGroupsGlobalAndUniversal: %s" %
                      tokenGroupsSet.difference(T))

    def test_samr_GetGroupsForUser(self):
        # Confirm that we get the correct results against SAMR also
        if not url.startswith("ldap://"):
            self.fail(
                msg=
                "This test is only valid on ldap (so we an find the hostname and use SAMR)"
            )
        host = url.split("://")[1]
        (domain_sid, user_rid) = self.user_sid.split()
        samr_conn = samba.dcerpc.samr.samr("ncacn_ip_tcp:%s[seal]" % host, lp,
                                           creds)
        samr_handle = samr_conn.Connect2(None,
                                         security.SEC_FLAG_MAXIMUM_ALLOWED)
        samr_domain = samr_conn.OpenDomain(samr_handle,
                                           security.SEC_FLAG_MAXIMUM_ALLOWED,
                                           domain_sid)
        user_handle = samr_conn.OpenUser(samr_domain,
                                         security.SEC_FLAG_MAXIMUM_ALLOWED,
                                         user_rid)
        rids = samr_conn.GetGroupsForUser(user_handle)
        expected_attributes = (samr.SE_GROUP_MANDATORY
                               | samr.SE_GROUP_ENABLED_BY_DEFAULT
                               | samr.SE_GROUP_ENABLED)
        samr_dns = set()
        for rid in rids.rids:
            self.assertEqual(rid.attributes, expected_attributes)
            sid = "%s-%d" % (domain_sid, rid.rid)
            res = self.admin_ldb.search(base="<SID=%s>" % sid,
                                        scope=ldb.SCOPE_BASE,
                                        attrs=[])
            samr_dns.add(res[0].dn.get_casefold())

        user_info = samr_conn.QueryUserInfo(user_handle, 1)
        self.assertEqual(rids.rids[0].rid, user_info.primary_gid)

        # Restrict LDAP results to global and universal security groups.
        group_filter = (
            "(&(|(grouptype=%d)(grouptype=%d))(objectclass=group))" %
            (GTYPE_SECURITY_GLOBAL_GROUP, GTYPE_SECURITY_UNIVERSAL_GROUP))

        tokenGroupsSet = set()
        res = self.ldb.search(self.user_sid_dn,
                              scope=ldb.SCOPE_BASE,
                              attrs=["tokenGroupsGlobalAndUniversal"])
        for packed_sid in res[0]['tokenGroupsGlobalAndUniversal']:
            sid = ndr_unpack(samba.dcerpc.security.dom_sid, packed_sid)
            res3 = self.admin_ldb.search(base="<SID=%s>" % sid,
                                         scope=ldb.SCOPE_BASE,
                                         attrs=[],
                                         expression=group_filter)
            # Check the per-SID lookup (not the user search) so SIDs the
            # filter rejects are skipped instead of raising IndexError.
            if len(res3) == 1:
                tokenGroupsSet.add(res3[0].dn.get_casefold())

        if samr_dns.difference(tokenGroupsSet):
            self.fail(
                msg="additional samr_GetUserGroups over tokenGroups: %s" %
                samr_dns.difference(tokenGroupsSet))

        memberOf = set()
        # Add the primary group (looked up by its own SID, not a leftover
        # loop variable).
        primary_group_sid = "%s-%d" % (domain_sid, user_info.primary_gid)
        res2 = self.admin_ldb.search(base="<SID=%s>" % primary_group_sid,
                                     scope=ldb.SCOPE_BASE,
                                     attrs=[])

        memberOf.add(res2[0].dn.get_casefold())
        res = self.ldb.search(self.user_sid_dn,
                              scope=ldb.SCOPE_BASE,
                              attrs=["memberOf"])
        for dn in res[0]['memberOf']:
            res3 = self.admin_ldb.search(base=dn,
                                         scope=ldb.SCOPE_BASE,
                                         attrs=[],
                                         expression=group_filter)
            if len(res3) == 1:
                memberOf.add(res3[0].dn.get_casefold())

        if memberOf.difference(samr_dns):
            self.fail(msg="additional memberOf over samr_GetUserGroups: %s" %
                      memberOf.difference(samr_dns))

        if samr_dns.difference(memberOf):
            self.fail(msg="additional samr_GetUserGroups over memberOf: %s" %
                      samr_dns.difference(memberOf))

        user_dn = self.test_user_dn.get_casefold()
        S = {user_dn}

        self.filtered_closure(S, GTYPE_SECURITY_GLOBAL_GROUP)
        self.filtered_closure(S, GTYPE_SECURITY_UNIVERSAL_GROUP)

        # Remove the user's own DN; the primary group stays, because SAMR
        # reports it as well (asserted above).
        S.remove(user_dn)

        if samr_dns.difference(S):
            self.fail(
                msg="additional samr_GetUserGroups over filtered_closure: %s" %
                samr_dns.difference(S))

    def test_samr_GetGroupsForUser_nomember(self):
        # Confirm that we get the correct results against SAMR also
        if not url.startswith("ldap://"):
            self.fail(
                msg=
                "This test is only valid on ldap (so we an find the hostname and use SAMR)"
            )
        host = url.split("://")[1]

        test_user = "******"
        self.admin_ldb.newuser(test_user, self.test_user_pass)
        try:
            res = self.admin_ldb.search(base="cn=%s,cn=users,%s" %
                                        (test_user, self.base_dn),
                                        attrs=["objectSid"],
                                        scope=ldb.SCOPE_BASE)
            user_sid = ndr_unpack(samba.dcerpc.security.dom_sid,
                                  res[0]["objectSid"][0])

            (domain_sid, user_rid) = user_sid.split()
            samr_conn = samba.dcerpc.samr.samr(
                "ncacn_ip_tcp:%s[seal]" % host, lp, creds)
            samr_handle = samr_conn.Connect2(
                None, security.SEC_FLAG_MAXIMUM_ALLOWED)
            samr_domain = samr_conn.OpenDomain(
                samr_handle, security.SEC_FLAG_MAXIMUM_ALLOWED, domain_sid)
            user_handle = samr_conn.OpenUser(
                samr_domain, security.SEC_FLAG_MAXIMUM_ALLOWED, user_rid)
            rids = samr_conn.GetGroupsForUser(user_handle)
            user_info = samr_conn.QueryUserInfo(user_handle, 1)
        finally:
            # Remove the extra user even if a lookup or SAMR call fails.
            delete_force(self.admin_ldb,
                         "CN=%s,%s,%s" % (test_user, "cn=users", self.base_dn))

        # A user with no explicit memberships reports only its primary group.
        self.assertEqual(len(rids.rids), 1)
        self.assertEqual(rids.rids[0].rid, user_info.primary_gid)
Example #21
0
class DsdbTests(TestCase):
    """Tests for dsdb-specific SamDB behaviour.

    Each test runs with a freshly created user at ``self.account_dn``,
    which is removed again during cleanup.  Covered areas:
    replPropertyMetaData manipulation through controls, attid/OID/
    attribute-name lookups, critical vs non-critical unknown controls,
    objectSID uniqueness, foreignSecurityPrincipal handling for DN-valued
    attributes, linked vs non-linked references and DN normalisation.
    """

    # attid of the "description" attribute (test_ok_get_attribute_from_attid
    # asserts this mapping).
    _DESCRIPTION_ATTID = 13

    def setUp(self):
        super(DsdbTests, self).setUp()
        self.lp = samba.tests.env_loadparm()
        self.creds = Credentials()
        self.creds.guess(self.lp)
        self.session = system_session()
        self.samdb = SamDB(session_info=self.session,
                           credentials=self.creds,
                           lp=self.lp)

        # Create a test user with a randomised name
        user_name = "dsdb-user-" + uuid.uuid4().hex[0:6]
        user_pass = samba.generate_random_password(32, 32)
        user_description = "Test user for dsdb test"

        base_dn = self.samdb.domain_dn()

        self.account_dn = "cn=" + user_name + ",cn=Users," + base_dn
        self.samdb.newuser(username=user_name,
                           password=user_pass,
                           description=user_description)
        # Cleanup (teardown)
        self.addCleanup(delete_force, self.samdb, self.account_dn)

    # Helpers.  They deliberately lack a "test" prefix so the unittest
    # loader does not collect them as tests.

    def _mod_msg(self, dn, attr, value, flag):
        """Return an ldb.Message applying *flag* with *value* to *attr* of *dn*."""
        msg = ldb.Message()
        msg.dn = dn
        msg[attr] = ldb.MessageElement(value, flag, attr)
        return msg

    def _search_replmd(self, extra_attrs=()):
        """Fetch the test account and unpack its replPropertyMetaData.

        :param extra_attrs: additional attribute names to request.
        :return: ``(res, repl)`` - the search result and the unpacked
            ``drsblobs.replPropertyMetaDataBlob``.
        """
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData"] +
                                list(extra_attrs))
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        return res, repl

    def _search_sid(self, basedn, sid_str):
        """Return all objects below *basedn* whose objectSid is *sid_str*."""
        return self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                 base=basedn,
                                 expression="(objectSid=%s)" % sid_str,
                                 attrs=[])

    def _assert_ldb_error(self, fail_msg, lerr_exp, werr_exp, func,
                          *args, **kwargs):
        """Assert that ``func(*args, **kwargs)`` raises a specific LdbError.

        :param fail_msg: failure message used if no exception is raised.
        :param lerr_exp: expected LDB error code.
        :param werr_exp: expected WERROR value; its "%08X" form must occur
            in the error string.  ``None`` skips the WERROR check.
        """
        try:
            func(*args, **kwargs)
        except ldb.LdbError as e:
            (code, estr) = e.args
            self.assertEqual(code, lerr_exp, str(e))
            if werr_exp is not None:
                werr = "%08X" % werr_exp
                self.assertIn(werr, estr, estr)
        else:
            self.fail(fail_msg)

    def test_get_oid_from_attrid(self):
        oid = self.samdb.get_oid_from_attid(591614)
        self.assertEqual(oid, "1.2.840.113556.1.4.1790")

    # NOTE(review): 1.3.6.1.4.1.7165.4.3.14 / .25 below appear to be the
    # dsdb "change replPropertyMetaData" (and "resort") controls - confirm
    # against dsdb.h.

    def test_error_replpropertymetadata(self):
        # Bump only the version of the description entry; the modify is
        # expected to be rejected.
        res, repl = self._search_replmd()
        for o in repl.ctr.array:
            if o.attid == self._DESCRIPTION_ATTID:
                o.version = o.version + 1
        msg = self._mod_msg(res[0].dn, "replPropertyMetaData",
                            ndr_pack(repl), ldb.FLAG_MOD_REPLACE)
        self.assertRaises(ldb.LdbError, self.samdb.modify, msg,
                          ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_error_replpropertymetadata_nochange(self):
        # Writing back the unmodified blob is expected to be rejected too.
        res, repl = self._search_replmd()
        msg = self._mod_msg(res[0].dn, "replPropertyMetaData",
                            ndr_pack(repl), ldb.FLAG_MOD_REPLACE)
        self.assertRaises(ldb.LdbError, self.samdb.modify, msg,
                          ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_error_replpropertymetadata_allow_sort(self):
        # Same unmodified blob, but with the additional .25 control the
        # modify is expected to succeed.
        res, repl = self._search_replmd()
        msg = self._mod_msg(res[0].dn, "replPropertyMetaData",
                            ndr_pack(repl), ldb.FLAG_MOD_REPLACE)
        self.samdb.modify(msg, [
            "local_oid:1.3.6.1.4.1.7165.4.3.14:0",
            "local_oid:1.3.6.1.4.1.7165.4.3.25:0"
        ])

    def test_twoatt_replpropertymetadata(self):
        # Changing the metadata together with the description value in
        # one modify is expected to be rejected.
        res, repl = self._search_replmd(["uSNChanged"])
        next_usn = int(str(res[0]["uSNChanged"])) + 1
        for o in repl.ctr.array:
            if o.attid == self._DESCRIPTION_ATTID:
                o.version = o.version + 1
                o.local_usn = next_usn
        msg = self._mod_msg(res[0].dn, "replPropertyMetaData",
                            ndr_pack(repl), ldb.FLAG_MOD_REPLACE)
        msg["description"] = ldb.MessageElement("new val",
                                                ldb.FLAG_MOD_REPLACE,
                                                "description")
        self.assertRaises(ldb.LdbError, self.samdb.modify, msg,
                          ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_set_replpropertymetadata(self):
        # Bumping version, local and originating USN of the description
        # entry is expected to be accepted.
        res, repl = self._search_replmd(["uSNChanged"])
        next_usn = int(str(res[0]["uSNChanged"])) + 1
        for o in repl.ctr.array:
            if o.attid == self._DESCRIPTION_ATTID:
                o.version = o.version + 1
                o.local_usn = next_usn
                o.originating_usn = next_usn
        msg = self._mod_msg(res[0].dn, "replPropertyMetaData",
                            ndr_pack(repl), ldb.FLAG_MOD_REPLACE)
        self.samdb.modify(msg, ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_ok_get_attribute_from_attid(self):
        self.assertEqual(
            self.samdb.get_attribute_from_attid(self._DESCRIPTION_ATTID),
            "description")

    def test_ko_get_attribute_from_attid(self):
        self.assertIsNone(self.samdb.get_attribute_from_attid(11979))

    def test_get_attribute_replmetadata_version(self):
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["dn"])
        self.assertEqual(len(res), 1)
        dn = str(res[0].dn)
        self.assertEqual(
            self.samdb.get_attribute_replmetadata_version(dn, "unicodePwd"), 2)

    def test_set_attribute_replmetadata_version(self):
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["dn"])
        self.assertEqual(len(res), 1)
        dn = str(res[0].dn)
        version = self.samdb.get_attribute_replmetadata_version(
            dn, "description")
        self.samdb.set_attribute_replmetadata_version(dn, "description",
                                                      version + 2)
        self.assertEqual(
            self.samdb.get_attribute_replmetadata_version(dn, "description"),
            version + 2)

    def test_no_error_on_invalid_control(self):
        # The unimplemented control is passed as non-critical (":0"), so
        # the search must succeed.
        try:
            self.samdb.search(
                scope=ldb.SCOPE_SUBTREE,
                base=self.account_dn,
                attrs=["replPropertyMetaData"],
                controls=[
                    "local_oid:%s:0" %
                    dsdb.DSDB_CONTROL_INVALID_NOT_IMPLEMENTED
                ])
        except ldb.LdbError:
            self.fail("Should have not raised an exception")

    def test_error_on_invalid_critical_control(self):
        # The same control passed as critical (":1") must be rejected with
        # ERR_UNSUPPORTED_CRITICAL_EXTENSION.
        try:
            self.samdb.search(
                scope=ldb.SCOPE_SUBTREE,
                base=self.account_dn,
                attrs=["replPropertyMetaData"],
                controls=[
                    "local_oid:%s:1" %
                    dsdb.DSDB_CONTROL_INVALID_NOT_IMPLEMENTED
                ])
        except ldb.LdbError as e:
            (code, estr) = e.args
            if code != ldb.ERR_UNSUPPORTED_CRITICAL_EXTENSION:
                self.fail(
                    "Got %s should have got ERR_UNSUPPORTED_CRITICAL_EXTENSION"
                    % estr)
        else:
            self.fail("No exception should get "
                      "ERR_UNSUPPORTED_CRITICAL_EXTENSION")

    # Allocate a unique RID for use in the objectSID tests.
    #
    def allocate_rid(self):
        """Allocate a RID inside a transaction and return it as a string.

        The transaction is cancelled (and the error re-raised) if the
        allocation fails; otherwise it is committed.
        """
        self.samdb.transaction_start()
        try:
            rid = self.samdb.allocate_rid()
        except BaseException:
            self.samdb.transaction_cancel()
            raise
        else:
            self.samdb.transaction_commit()
        return str(rid)

    # Ensure that duplicate objectSID's are permitted for foreign security
    # principals.
    #
    def test_duplicate_objectSIDs_allowed_on_foreign_security_principals(self):

        #
        # We need to build a foreign security principal SID
        # i.e a  SID not in the current domain: change the last digit of
        # the domain SID so it can never equal it.
        #
        dom_sid = self.samdb.get_domain_sid()
        c = "9" if str(dom_sid).endswith("0") else "0"
        sid_str = str(dom_sid)[:-1] + c + "-1000"
        sid = ndr_pack(security.dom_sid(sid_str))
        basedn = self.samdb.get_default_basedn()
        dn = "CN=%s,CN=ForeignSecurityPrincipals,%s" % (sid_str, basedn)
        # Ensure the object is removed even if an assertion fails midway.
        self.addCleanup(delete_force, self.samdb, dn)

        #
        # First without control
        #
        self._assert_ldb_error(
            "No exception should get ERR_OBJECT_CLASS_VIOLATION",
            ldb.ERR_OBJECT_CLASS_VIOLATION,
            werror.WERR_DS_MISSING_REQUIRED_ATT,
            self.samdb.add,
            {"dn": dn, "objectClass": "foreignSecurityPrincipal"})

        self._assert_ldb_error(
            "No exception should get ERR_UNWILLING_TO_PERFORM",
            ldb.ERR_UNWILLING_TO_PERFORM,
            werror.WERR_DS_ILLEGAL_MOD_OPERATION,
            self.samdb.add,
            {"dn": dn,
             "objectClass": "foreignSecurityPrincipal",
             "objectSid": sid})

        #
        # We need to use the provision control
        # in order to add foreignSecurityPrincipal
        # objects
        #
        controls = ["provision:0"]
        self.samdb.add({
            "dn": dn,
            "objectClass": "foreignSecurityPrincipal"
        }, controls=controls)

        self.samdb.delete(dn)

        # Re-adding after the delete must succeed.
        try:
            self.samdb.add({
                "dn": dn,
                "objectClass": "foreignSecurityPrincipal"
            }, controls=controls)
        except ldb.LdbError as e:
            (code, msg) = e.args
            self.fail("Got unexpected exception %d - %s " % (code, msg))

        # cleanup
        self.samdb.delete(dn)

    def _test_foreignSecurityPrincipal(self, obj_class, fpo_attr):
        """Check which SID references may be added to *fpo_attr*.

        On a new *obj_class* object, adding ``<SID=...>`` values that match
        no existing object:
          * a SID in the local domain is rejected
            (ERR_UNWILLING_TO_PERFORM / WERR_DS_INVALID_GROUP_TYPE);
          * an S-1-5-32 SID is rejected
            (ERR_NO_SUCH_OBJECT / WERR_NO_SUCH_MEMBER);
          * an S-1-5 SID is accepted, after which exactly one object with
            that objectSid exists; it is gone once it and the test object
            are deleted.
        """
        dom_sid = self.samdb.get_domain_sid()
        lsid_str = str(dom_sid) + "-4294967294"
        bsid_str = "S-1-5-32-4294967294"
        fsid_str = "S-1-5-4294967294"
        basedn = self.samdb.get_default_basedn()
        cn = "dsdb_test_fpo"
        dn_str = "cn=%s,cn=Users,%s" % (cn, basedn)
        dn = ldb.Dn(self.samdb, dn_str)

        # None of the SIDs may be in use before we start.
        for sid_str in (lsid_str, bsid_str, fsid_str):
            self.assertEqual(len(self._search_sid(basedn, sid_str)), 0)

        self.addCleanup(delete_force, self.samdb, dn_str)

        self.samdb.add({"dn": dn_str, "objectClass": obj_class})

        self._assert_ldb_error(
            "No exception should get LDB_ERR_UNWILLING_TO_PERFORM",
            ldb.ERR_UNWILLING_TO_PERFORM,
            werror.WERR_DS_INVALID_GROUP_TYPE,
            self.samdb.modify,
            self._mod_msg(dn, fpo_attr, "<SID=%s>" % lsid_str,
                          ldb.FLAG_MOD_ADD))

        self._assert_ldb_error(
            "No exception should get LDB_ERR_NO_SUCH_OBJECT",
            ldb.ERR_NO_SUCH_OBJECT,
            werror.WERR_NO_SUCH_MEMBER,
            self.samdb.modify,
            self._mod_msg(dn, fpo_attr, "<SID=%s>" % bsid_str,
                          ldb.FLAG_MOD_ADD))

        msg = self._mod_msg(dn, fpo_attr, "<SID=%s>" % fsid_str,
                            ldb.FLAG_MOD_ADD)
        try:
            self.samdb.modify(msg)
        except ldb.LdbError:
            self.fail("Should have not raised an exception")

        res = self._search_sid(basedn, fsid_str)
        self.assertEqual(len(res), 1)
        self.samdb.delete(res[0].dn)
        self.samdb.delete(dn)
        self.assertEqual(len(self._search_sid(basedn, fsid_str)), 0)

    def test_foreignSecurityPrincipal_member(self):
        return self._test_foreignSecurityPrincipal("group", "member")

    def test_foreignSecurityPrincipal_MembersForAzRole(self):
        return self._test_foreignSecurityPrincipal("msDS-AzRole",
                                                   "msDS-MembersForAzRole")

    def test_foreignSecurityPrincipal_NeverRevealGroup(self):
        return self._test_foreignSecurityPrincipal("computer",
                                                   "msDS-NeverRevealGroup")

    def test_foreignSecurityPrincipal_RevealOnDemandGroup(self):
        return self._test_foreignSecurityPrincipal("computer",
                                                   "msDS-RevealOnDemandGroup")

    def _test_fail_foreignSecurityPrincipal(self,
                                            obj_class,
                                            fpo_attr,
                                            msg_exp,
                                            lerr_exp,
                                            werr_exp,
                                            allow_reference=True):
        """Check that *fpo_attr* rejects every non-existent SID reference.

        Adding a local-domain, S-1-5-32 or S-1-5 ``<SID=...>`` value to
        *fpo_attr* of a new *obj_class* object must fail with *lerr_exp*
        and *werr_exp*.  A plain DN reference to a second *obj_class*
        object must succeed when *allow_reference* is true and fail the
        same way otherwise.

        :param msg_exp: human-readable name of the expected error, used
            in failure messages.
        """
        dom_sid = self.samdb.get_domain_sid()
        lsid_str = str(dom_sid) + "-4294967294"
        bsid_str = "S-1-5-32-4294967294"
        fsid_str = "S-1-5-4294967294"
        basedn = self.samdb.get_default_basedn()
        cn1 = "dsdb_test_fpo1"
        dn1_str = "cn=%s,cn=Users,%s" % (cn1, basedn)
        dn1 = ldb.Dn(self.samdb, dn1_str)
        cn2 = "dsdb_test_fpo2"
        dn2_str = "cn=%s,cn=Users,%s" % (cn2, basedn)
        dn2 = ldb.Dn(self.samdb, dn2_str)

        sid_strs = (lsid_str, bsid_str, fsid_str)
        # None of the SIDs may be in use before we start.
        for sid_str in sid_strs:
            self.assertEqual(len(self._search_sid(basedn, sid_str)), 0)

        self.addCleanup(delete_force, self.samdb, dn1_str)
        self.addCleanup(delete_force, self.samdb, dn2_str)

        self.samdb.add({"dn": dn1_str, "objectClass": obj_class})

        self.samdb.add({"dn": dn2_str, "objectClass": obj_class})

        fail_msg = "No exception should get %s" % msg_exp
        for sid_str in sid_strs:
            self._assert_ldb_error(
                fail_msg, lerr_exp, werr_exp,
                self.samdb.modify,
                self._mod_msg(dn1, fpo_attr, "<SID=%s>" % sid_str,
                              ldb.FLAG_MOD_ADD))

        msg = self._mod_msg(dn1, fpo_attr, str(dn2), ldb.FLAG_MOD_ADD)
        if allow_reference:
            try:
                self.samdb.modify(msg)
            except ldb.LdbError as e:
                self.fail("Should have not raised an exception: %s" % e)
        else:
            self._assert_ldb_error(fail_msg, lerr_exp, werr_exp,
                                   self.samdb.modify, msg)

        self.samdb.delete(dn2)
        self.samdb.delete(dn1)

    def test_foreignSecurityPrincipal_NonMembers(self):
        return self._test_fail_foreignSecurityPrincipal(
            "group",
            "msDS-NonMembers",
            "LDB_ERR_UNWILLING_TO_PERFORM/WERR_NOT_SUPPORTED",
            ldb.ERR_UNWILLING_TO_PERFORM,
            werror.WERR_NOT_SUPPORTED,
            allow_reference=False)

    def test_foreignSecurityPrincipal_HostServiceAccount(self):
        return self._test_fail_foreignSecurityPrincipal(
            "computer", "msDS-HostServiceAccount",
            "LDB_ERR_CONSTRAINT_VIOLATION/WERR_DS_NAME_REFERENCE_INVALID",
            ldb.ERR_CONSTRAINT_VIOLATION,
            werror.WERR_DS_NAME_REFERENCE_INVALID)

    def test_foreignSecurityPrincipal_manager(self):
        return self._test_fail_foreignSecurityPrincipal(
            "user", "manager",
            "LDB_ERR_CONSTRAINT_VIOLATION/WERR_DS_NAME_REFERENCE_INVALID",
            ldb.ERR_CONSTRAINT_VIOLATION,
            werror.WERR_DS_NAME_REFERENCE_INVALID)

    #
    # Duplicate objectSID's should not be permitted for sids in the local
    # domain. The test sequence is add an object, delete it, then attempt to
    # re-add it, this should fail with a constraint violation
    #
    def test_duplicate_objectSIDs_not_allowed_on_local_objects(self):

        dom_sid = self.samdb.get_domain_sid()
        rid = self.allocate_rid()
        sid_str = str(dom_sid) + "-" + rid
        sid = ndr_pack(security.dom_sid(sid_str))
        basedn = self.samdb.get_default_basedn()
        cn = "dsdb_test_01"
        dn = "cn=%s,cn=Users,%s" % (cn, basedn)
        # If the re-add unexpectedly succeeds the object would otherwise
        # be left behind.
        self.addCleanup(delete_force, self.samdb, dn)

        self.samdb.add({"dn": dn, "objectClass": "user", "objectSID": sid})
        self.samdb.delete(dn)

        self._assert_ldb_error(
            "No exception should get LDB_ERR_CONSTRAINT_VIOLATION",
            ldb.ERR_CONSTRAINT_VIOLATION,
            None,
            self.samdb.add,
            {"dn": dn, "objectClass": "user", "objectSID": sid})

    def test_linked_vs_non_linked_reference(self):
        """References to deleted/non-existent objects: linked vs non-linked.

        The linked attribute 'manager' rejects SID/GUID references to a
        deleted object, while the non-linked 'assistant' accepts them.
        'assistant' still rejects references (DN, SID, GUID) matching no
        object at all.
        """
        basedn = self.samdb.get_default_basedn()
        kept_dn_str = "cn=reference_kept,cn=Users,%s" % (basedn)
        removed_dn_str = "cn=reference_removed,cn=Users,%s" % (basedn)
        dom_sid = self.samdb.get_domain_sid()
        none_sid_str = str(dom_sid) + "-4294967294"
        none_guid_str = "afafafaf-fafa-afaf-fafa-afafafafafaf"

        self.addCleanup(delete_force, self.samdb, kept_dn_str)
        self.addCleanup(delete_force, self.samdb, removed_dn_str)

        self.samdb.add({"dn": kept_dn_str, "objectClass": "user"})
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=kept_dn_str,
                                attrs=["objectGUID", "objectSID"])
        self.assertEqual(len(res), 1)
        kept_dn = res[0].dn

        self.samdb.add({"dn": removed_dn_str, "objectClass": "user"})
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=removed_dn_str,
                                attrs=["objectGUID", "objectSID"])
        self.assertEqual(len(res), 1)
        removed_guid = ndr_unpack(misc.GUID, res[0]["objectGUID"][0])
        removed_sid = ndr_unpack(security.dom_sid, res[0]["objectSid"][0])
        self.samdb.delete(removed_dn_str)

        removed_refs = ("<SID=%s>" % removed_sid,
                        "<GUID=%s>" % removed_guid)

        #
        # First try the linked attribute 'manager'
        # by GUID and SID
        #
        for ref in removed_refs:
            self._assert_ldb_error(
                "No exception should get LDB_ERR_CONSTRAINT_VIOLATION",
                ldb.ERR_CONSTRAINT_VIOLATION,
                werror.WERR_DS_NAME_REFERENCE_INVALID,
                self.samdb.modify,
                self._mod_msg(kept_dn, "manager", ref, ldb.FLAG_MOD_ADD))

        #
        # Try the non-linked attribute 'assistant'
        # by GUID and SID, which should work.
        #
        for ref in removed_refs:
            self.samdb.modify(
                self._mod_msg(kept_dn, "assistant", ref, ldb.FLAG_MOD_ADD))
            self.samdb.modify(
                self._mod_msg(kept_dn, "assistant", ref,
                              ldb.FLAG_MOD_DELETE))

        #
        # Finally try the non-linked attribute 'assistant'
        # but with non existing GUID, SID, DN
        #
        none_refs = ("CN=NoneNone,%s" % (basedn),
                     "<SID=%s>" % none_sid_str,
                     "<GUID=%s>" % none_guid_str)
        for ref in none_refs:
            self._assert_ldb_error(
                "No exception should get LDB_ERR_CONSTRAINT_VIOLATION",
                ldb.ERR_CONSTRAINT_VIOLATION,
                werror.WERR_DS_NAME_REFERENCE_INVALID,
                self.samdb.modify,
                self._mod_msg(kept_dn, "assistant", ref, ldb.FLAG_MOD_ADD))

        self.samdb.delete(kept_dn)

    def test_normalize_dn_in_domain_full(self):
        full_dn = ldb.Dn(self.samdb, "CN=Users")
        full_dn.add_base(self.samdb.domain_dn())

        # That is, no change
        self.assertEqual(full_dn,
                         self.samdb.normalize_dn_in_domain(str(full_dn)))

    def test_normalize_dn_in_domain_part(self):
        part_str = "CN=Users"

        full_dn = ldb.Dn(self.samdb, part_str)
        full_dn.add_base(self.samdb.domain_dn())

        # That is, the domain DN appended
        self.assertEqual(full_dn, self.samdb.normalize_dn_in_domain(part_str))

    def test_normalize_dn_in_domain_full_dn(self):
        full_dn = ldb.Dn(self.samdb, "CN=Users")
        full_dn.add_base(self.samdb.domain_dn())

        # That is, no change
        self.assertEqual(full_dn, self.samdb.normalize_dn_in_domain(full_dn))

    def test_normalize_dn_in_domain_part_dn(self):
        domain_dn = self.samdb.domain_dn()

        part_dn = ldb.Dn(self.samdb, "CN=Users")

        # That is, the domain DN appended
        self.assertEqual(
            ldb.Dn(self.samdb,
                   str(part_dn) + "," + str(domain_dn)),
            self.samdb.normalize_dn_in_domain(part_dn))
class UserAccountControlTests(samba.tests.TestCase):
    """Check what an unprivileged user may do to computer accounts it creates.

    setUp() creates an unprivileged user, temporarily grants it an
    object-specific CC (create child) ACE on a test OU, creates a template
    computer account with it, and records the template's security
    descriptor for the tests to compare against.
    """

    def add_computer_ldap(self, computername, others=None, samdb=None):
        """Add a computer object CN=<computername> under the test OU.

        :param computername: CN of the new object; its sAMAccountName is
            set to ``computername + "$"``.
        :param others: optional dict of extra attributes (name -> value or
            ldb.MessageElement); on key clashes these win over the defaults.
        :param samdb: connection used for the add (defaults to self.samdb).
        """
        if samdb is None:
            samdb = self.samdb
        msg_dict = {
            "dn": "CN=%s,OU=test_computer_ou1,%s" % (computername, self.base_dn),
            "objectclass": "computer",
        }
        if others is not None:
            # dict views cannot be concatenated with "+" on Python 3;
            # update() gives the same "others win" merge on Python 2 and 3.
            msg_dict.update(others)

        # Build the message with the same connection that will add it.
        msg = ldb.Message.from_dict(samdb, msg_dict)
        msg["sAMAccountName"] = "%s$" % computername

        print("Adding computer account %s" % computername)
        samdb.add(msg)

    def get_creds(self, target_username, target_password):
        """Return sealing, non-Kerberos Credentials for the given user,
        copying domain, realm and workstation from the global ``creds``."""
        creds_tmp = Credentials()
        creds_tmp.set_username(target_username)
        creds_tmp.set_password(target_password)
        creds_tmp.set_domain(creds.get_domain())
        creds_tmp.set_realm(creds.get_realm())
        creds_tmp.set_workstation(creds.get_workstation())
        creds_tmp.set_gensec_features(creds_tmp.get_gensec_features()
                                      | gensec.FEATURE_SEAL)
        creds_tmp.set_kerberos_state(DONT_USE_KERBEROS)  # kinit is too expensive to use in a tight loop
        return creds_tmp

    def setUp(self):
        super(UserAccountControlTests, self).setUp()
        self.admin_creds = creds
        self.admin_samdb = SamDB(url=ldaphost,
                                 session_info=system_session(),
                                 credentials=self.admin_creds, lp=lp)

        self.unpriv_user = "******"
        self.unpriv_user_pw = "samba123@"
        self.unpriv_creds = self.get_creds(self.unpriv_user, self.unpriv_user_pw)

        self.admin_samdb.newuser(self.unpriv_user, self.unpriv_user_pw)
        res = self.admin_samdb.search("CN=%s,CN=Users,%s" % (self.unpriv_user, self.admin_samdb.domain_dn()),
                                      scope=SCOPE_BASE,
                                      attrs=["objectSid"])
        self.assertEqual(1, len(res))

        self.unpriv_user_sid = ndr_unpack(security.dom_sid, res[0]["objectSid"][0])
        self.unpriv_user_dn = res[0].dn

        self.samdb = SamDB(url=ldaphost, credentials=self.unpriv_creds, lp=lp)
        self.domain_sid = security.dom_sid(self.samdb.get_domain_sid())
        self.base_dn = self.samdb.domain_dn()

        self.samr = samr.samr("ncacn_ip_tcp:%s[sign]" % host, lp, self.unpriv_creds)
        self.samr_handle = self.samr.Connect2(None, security.SEC_FLAG_MAXIMUM_ALLOWED)
        self.samr_domain = self.samr.OpenDomain(self.samr_handle, security.SEC_FLAG_MAXIMUM_ALLOWED, self.domain_sid)

        self.sd_utils = sd_utils.SDUtils(self.admin_samdb)

        ou_dn = "OU=test_computer_ou1," + self.base_dn
        self.admin_samdb.create_ou(ou_dn)

        # Temporarily give the unprivileged user an object-specific CC
        # (create child) ACE on the OU so that self.samdb, which is
        # connected as that user, can create the template account.
        mod = "(OA;;CC;bf967a86-0de6-11d0-a285-00aa003049e2;;%s)" % str(self.unpriv_user_sid)
        old_sd = self.sd_utils.read_sd_on_dn(ou_dn)
        self.sd_utils.dacl_add_ace(ou_dn, mod)
        try:
            self.add_computer_ldap("testcomputer-t")
        finally:
            # Put the OU's original SD back even if the add failed.
            self.sd_utils.modify_sd_on_dn(ou_dn, old_sd)

        self.computernames = ["testcomputer-0"]

        # Get the SD of the template account, then force it to match
        # what we expect for SeMachineAccountPrivilege accounts, so we
        # can confirm we created the accounts correctly
        template_dn = "CN=testcomputer-t,%s" % ou_dn
        self.sd_reference_cc = self.sd_utils.read_sd_on_dn(template_dn)

        self.sd_reference_modify = self.sd_utils.read_sd_on_dn(template_dn)
        for ace in self.sd_reference_modify.dacl.aces:
            if ace.type == security.SEC_ACE_TYPE_ACCESS_ALLOWED and ace.trustee == self.unpriv_user_sid:
                ace.access_mask = ace.access_mask | security.SEC_ADS_SELF_WRITE | security.SEC_ADS_WRITE_PROP

        # Now reconnect without domain admin rights
        self.samdb = SamDB(url=ldaphost, credentials=self.unpriv_creds, lp=lp)

    def tearDown(self):
        super(UserAccountControlTests, self).tearDown()
        for computername in self.computernames:
            delete_force(self.admin_samdb, "CN=%s,OU=test_computer_ou1,%s" % (computername, self.base_dn))
        delete_force(self.admin_samdb, "CN=testcomputer-t,OU=test_computer_ou1,%s" % (self.base_dn))
        delete_force(self.admin_samdb, "OU=test_computer_ou1,%s" % (self.base_dn))
        delete_force(self.admin_samdb, "CN=%s,CN=Users,%s" % (self.unpriv_user, self.base_dn))

    def test_add_computer_sd_cc(self):
        """Create a computer with an explicit nTSecurityDescriptor using the
        delegated CC right, then check that the creator can change its
        description but cannot make it a DC or an RODC account."""
        ou_dn = "OU=test_computer_ou1," + self.base_dn
        user_sid = self.sd_utils.get_object_sid(self.unpriv_user_dn)
        mod = "(OA;;CC;bf967a86-0de6-11d0-a285-00aa003049e2;;%s)" % str(user_sid)

        # The OU's SD is not restored here: tearDown() deletes the whole OU.
        self.sd_utils.dacl_add_ace(ou_dn, mod)

        computername = self.computernames[0]
        sd = ldb.MessageElement(ndr_pack(self.sd_reference_modify),
                                ldb.FLAG_MOD_ADD,
                                "nTSecurityDescriptor")
        self.add_computer_ldap(computername,
                               others={"nTSecurityDescriptor": sd})

        res = self.admin_samdb.search(self.base_dn,
                                      expression="(&(objectClass=computer)(samAccountName=%s$))" % computername,
                                      scope=SCOPE_SUBTREE,
                                      attrs=["ntSecurityDescriptor"])
        self.assertEqual(1, len(res))

        desc = ndr_unpack(security.descriptor,
                          res[0]["nTSecurityDescriptor"][0],
                          allow_remaining=True)
        sddl = desc.as_sddl(self.domain_sid)
        self.assertEqual(self.sd_reference_modify.as_sddl(self.domain_sid), sddl)

        # An ordinary attribute change by the creator must succeed.
        m = ldb.Message()
        m.dn = res[0].dn
        m["description"] = ldb.MessageElement("A description",
                                              ldb.FLAG_MOD_REPLACE,
                                              "description")
        self.samdb.modify(m)

        # Turning the account into a DC must be refused.
        m = ldb.Message()
        m.dn = res[0].dn
        m["userAccountControl"] = ldb.MessageElement(
            str(samba.dsdb.UF_SERVER_TRUST_ACCOUNT),
            ldb.FLAG_MOD_REPLACE, "userAccountControl")
        try:
            self.samdb.modify(m)
            self.fail("Unexpectedly able to set userAccountControl to be a DC on %s" % m.dn)
        except LdbError as e:
            (enum, estr) = e.args
            self.assertEqual(ldb.ERR_INSUFFICIENT_ACCESS_RIGHTS, enum)

        # Turning the account into an RODC must be refused as well.
        m = ldb.Message()
        m.dn = res[0].dn
        m["userAccountControl"] = ldb.MessageElement(
            str(samba.dsdb.UF_WORKSTATION_TRUST_ACCOUNT | samba.dsdb.UF_PARTIAL_SECRETS_ACCOUNT),
            ldb.FLAG_MOD_REPLACE, "userAccountControl")
        try:
            self.samdb.modify(m)
            self.fail("Unexpectedly able to set userAccountControl to be an RODC on %s" % m.dn)
        except LdbError as e:
            (enum, estr) = e.args
            self.assertEqual(ldb.ERR_INSUFFICIENT_ACCESS_RIGHTS, enum)
Example #23
0
    def test_rid_set_dbcheck(self):
        """Perform a join against the RID manager and assert we have a RID Set.
        Using dbcheck, we assert that we can detect out of range users."""

        fsmo_dn = ldb.Dn(
            self.ldb_dc1,
            "CN=RID Manager$,CN=System," + self.ldb_dc1.domain_dn())
        (fsmo_owner, _) = self._determine_fSMORoleOwner(fsmo_dn)

        targetdir = self._test_join(fsmo_owner['dns_name'], "RIDALLOCTEST6")
        try:
            # Open the joined DC's local sam.ldb directly.
            ldb_url = "tdb://%s" % os.path.join(targetdir, "private/sam.ldb")

            lp = self.get_loadparm()
            new_ldb = SamDB(ldb_url,
                            credentials=self.get_credentials(),
                            session_info=system_session(lp),
                            lp=lp)

            # 1. Get server name
            res = new_ldb.search(base=ldb.Dn(new_ldb,
                                             new_ldb.get_serverName()),
                                 scope=ldb.SCOPE_BASE,
                                 attrs=["serverReference"])
            # 2. Get server reference
            server_ref_dn = ldb.Dn(new_ldb, res[0]['serverReference'][0])

            # 3. Assert we get the RID Set
            res = new_ldb.search(base=server_ref_dn,
                                 scope=ldb.SCOPE_BASE,
                                 attrs=['rIDSetReferences'])

            self.assertIn("rIDSetReferences", res[0])
            rid_set_dn = ldb.Dn(new_ldb, res[0]["rIDSetReferences"][0])

            # 4. Add a new user (triggers RID set work)
            new_ldb.newuser("ridalloctestuser", "P@ssword!")

            # 5. Now fetch the RID SET
            rid_set_res = new_ldb.search(
                base=rid_set_dn,
                scope=ldb.SCOPE_BASE,
                attrs=['rIDNextRid', 'rIDAllocationPool'])
            next_pool = int(rid_set_res[0]["rIDAllocationPool"][0])
            # Upper 32 bits of rIDAllocationPool, used below as the pool's last RID.
            last_rid = (0xFFFFFFFF00000000 & next_pool) >> 32

            # 6. Add user above the ridNextRid and at mid-range.
            #
            # We can do this with safety because this is an offline DB that will be
            # destroyed.
            m = ldb.Message()
            m.dn = ldb.Dn(new_ldb, "CN=ridsettestuser1,CN=Users")
            m.dn.add_base(new_ldb.get_default_basedn())
            m['objectClass'] = ldb.MessageElement('user', ldb.FLAG_MOD_ADD,
                                                  'objectClass')
            m['objectSid'] = ldb.MessageElement(
                ndr_pack(
                    security.dom_sid(
                        str(new_ldb.get_domain_sid()) + "-%d" %
                        (last_rid - 10))), ldb.FLAG_MOD_ADD, 'objectSid')
            new_ldb.add(m, controls=["relax:0"])

            # 7. Check (and fix) the RID Set
            chk = dbcheck(new_ldb,
                          verbose=False,
                          fix=True,
                          yes=True,
                          quiet=True)

            # Should have one error (wrong rIDNextRID)
            self.assertEqual(
                chk.check_database(DN=rid_set_dn, scope=ldb.SCOPE_BASE), 1)

            # 8. A read-only dbcheck pass must now report no remaining errors.
            chk = dbcheck(new_ldb, verbose=False, fix=False, quiet=True)
            self.assertEqual(
                chk.check_database(DN=rid_set_dn, scope=ldb.SCOPE_BASE), 0)

            rid_set_res = new_ldb.search(
                base=rid_set_dn,
                scope=ldb.SCOPE_BASE,
                attrs=['rIDNextRid', 'rIDAllocationPool'])
            last_allocated_rid = int(rid_set_res[0]["rIDNextRid"][0])
            self.assertEqual(last_allocated_rid, last_rid - 10)

            # 9. Assert that the range wasn't thrown away

            next_pool = int(rid_set_res[0]["rIDAllocationPool"][0])
            self.assertEqual(last_rid, (0xFFFFFFFF00000000 & next_pool) >> 32,
                             "rid pool should not have changed")
        finally:
            try:
                self._test_force_demote(fsmo_owner['dns_name'], "RIDALLOCTEST6")
            finally:
                # Remove the join directory even if the demote fails.
                shutil.rmtree(targetdir, ignore_errors=True)
Example #24
0
class DynamicTokenTest(samba.tests.TestCase):
    """Check constructed tokenGroups attributes against independent results.

    setUp() creates a test user that is a member of a domain-local, a global
    and a universal security group, and records the SIDs of the security
    token that samba.auth.user_session() builds for that user.  The tests
    compare this with the rootDSE and user-DN tokenGroups, the token of a
    GSSAPI server session, and a manual walk of memberOf/primaryGroupID links.
    """

    def get_creds(self, target_username, target_password):
        """Return sealing Credentials for the given user, copying domain,
        realm and workstation from the global ``creds``."""
        creds_tmp = Credentials()
        creds_tmp.set_username(target_username)
        creds_tmp.set_password(target_password)
        creds_tmp.set_domain(creds.get_domain())
        creds_tmp.set_realm(creds.get_realm())
        creds_tmp.set_workstation(creds.get_workstation())
        creds_tmp.set_gensec_features(creds_tmp.get_gensec_features()
                                      | gensec.FEATURE_SEAL)
        return creds_tmp

    def get_ldb_connection(self, target_username, target_password):
        """Open a SamDB connection to ``url`` as the given user."""
        creds_tmp = self.get_creds(target_username, target_password)
        ldb_target = SamDB(url=url, credentials=creds_tmp, lp=lp)
        return ldb_target

    def _add_group_with_user(self, name, grouptype):
        """Create group *name* of *grouptype* in CN=Users, add the test user
        to it, and return the group's objectSid."""
        self.admin_ldb.newgroup(name, grouptype=grouptype)
        res = self.admin_ldb.search(base="cn=%s,cn=users,%s" %
                                    (name, self.base_dn),
                                    attrs=["objectSid"],
                                    scope=ldb.SCOPE_BASE)
        group_sid = ndr_unpack(samba.dcerpc.security.dom_sid,
                               res[0]["objectSid"][0])

        self.admin_ldb.add_remove_group_members(name,
                                                [self.test_user],
                                                add_members_operation=True)
        return group_sid

    def setUp(self):
        super(DynamicTokenTest, self).setUp()
        self.admin_ldb = SamDB(url,
                               credentials=creds,
                               session_info=system_session(lp),
                               lp=lp)

        self.base_dn = self.admin_ldb.domain_dn()

        self.test_user = "******"
        self.test_user_pass = "******"
        self.admin_ldb.newuser(self.test_user, self.test_user_pass)

        self.test_group0 = "tokengroups_group0"
        self.test_group0_sid = self._add_group_with_user(
            self.test_group0, dsdb.GTYPE_SECURITY_DOMAIN_LOCAL_GROUP)

        self.test_group1 = "tokengroups_group1"
        self.test_group1_sid = self._add_group_with_user(
            self.test_group1, dsdb.GTYPE_SECURITY_GLOBAL_GROUP)

        self.test_group2 = "tokengroups_group2"
        self.test_group2_sid = self._add_group_with_user(
            self.test_group2, dsdb.GTYPE_SECURITY_UNIVERSAL_GROUP)

        self.ldb = self.get_ldb_connection(self.test_user, self.test_user_pass)

        # The first rootDSE tokenGroups value is taken as the user's own SID.
        res = self.ldb.search("", scope=ldb.SCOPE_BASE, attrs=["tokenGroups"])
        self.assertEqual(len(res), 1)

        self.user_sid_dn = "<SID=%s>" % str(
            ndr_unpack(samba.dcerpc.security.dom_sid,
                       res[0]["tokenGroups"][0]))

        res = self.ldb.search(self.user_sid_dn, scope=ldb.SCOPE_BASE, attrs=[])
        self.assertEqual(len(res), 1)

        self.test_user_dn = res[0].dn

        # Reference result: the SIDs of a locally built session token.
        session_info_flags = (AUTH_SESSION_INFO_DEFAULT_GROUPS
                              | AUTH_SESSION_INFO_AUTHENTICATED
                              | AUTH_SESSION_INFO_SIMPLE_PRIVILEGES)
        session = samba.auth.user_session(
            self.ldb,
            lp_ctx=lp,
            dn=self.user_sid_dn,
            session_info_flags=session_info_flags)

        token = session.security_token
        self.user_sids = [str(s) for s in token.sids]

    def tearDown(self):
        super(DynamicTokenTest, self).tearDown()
        for name in (self.test_user, self.test_group0,
                     self.test_group1, self.test_group2):
            delete_force(
                self.admin_ldb,
                "CN=%s,%s,%s" % (name, "cn=users", self.base_dn))

    def test_rootDSE_tokenGroups(self):
        """Testing rootDSE tokengroups against internal calculation"""
        if not url.startswith("ldap"):
            self.fail(msg="This test is only valid on ldap")

        res = self.ldb.search("", scope=ldb.SCOPE_BASE, attrs=["tokenGroups"])
        self.assertEqual(len(res), 1)

        print("Getting tokenGroups from rootDSE")
        tokengroups = [str(ndr_unpack(samba.dcerpc.security.dom_sid, sid))
                       for sid in res[0]['tokenGroups']]

        # Only SIDs the server reports but we did not calculate fail the test.
        extra = set(tokengroups).difference(self.user_sids)
        if extra:
            print("token sids don't match")
            print("tokengroups: %s" % tokengroups)
            print("calculated : %s" % self.user_sids)
            print("difference : %s" % extra)
            self.fail(
                msg="calculated groups don't match against rootDSE tokenGroups"
            )

    def test_dn_tokenGroups(self):
        print("Getting tokenGroups from user DN")
        res = self.ldb.search(self.user_sid_dn,
                              scope=ldb.SCOPE_BASE,
                              attrs=["tokenGroups"])
        self.assertEqual(len(res), 1)

        dn_tokengroups = set(
            str(ndr_unpack(samba.dcerpc.security.dom_sid, sid))
            for sid in res[0]['tokenGroups'])

        # Only SIDs the server reports but we did not calculate fail the test.
        extra = dn_tokengroups.difference(self.user_sids)
        if extra:
            print("token sids don't match")
            print("difference : %s" % extra)
            self.fail(
                msg="calculated groups don't match against user DN tokenGroups"
            )

    def test_pac_groups(self):
        """Compare the token of a GSSAPI server session for the test user
        with the locally calculated SIDs."""
        settings = {}
        settings["lp_ctx"] = lp
        settings["target_hostname"] = lp.get("netbios name")

        gensec_client = gensec.Security.start_client(settings)
        gensec_client.set_credentials(
            self.get_creds(self.test_user, self.test_user_pass))
        gensec_client.want_feature(gensec.FEATURE_SEAL)
        gensec_client.start_mech_by_sasl_name("GSSAPI")

        auth_context = AuthContext(lp_ctx=lp, ldb=self.ldb, methods=[])

        gensec_server = gensec.Security.start_server(settings, auth_context)
        machine_creds = Credentials()
        machine_creds.guess(lp)
        machine_creds.set_machine_account(lp)
        gensec_server.set_credentials(machine_creds)

        gensec_server.want_feature(gensec.FEATURE_SEAL)
        gensec_server.start_mech_by_sasl_name("GSSAPI")

        client_finished = False
        server_finished = False
        # Bytes, presumably required by update() on Python 3; b"" == "" on Python 2.
        server_to_client = b""

        # Run the actual call loop until either side reports completion;
        # only the server's session_info is used afterwards.
        while not client_finished and not server_finished:
            if not client_finished:
                print("running client gensec_update")
                (client_finished,
                 client_to_server) = gensec_client.update(server_to_client)
            if not server_finished:
                print("running server gensec_update")
                (server_finished,
                 server_to_client) = gensec_server.update(client_to_server)

        session = gensec_server.session_info()

        pac_sids = set(str(s) for s in session.security_token.sids)

        # Only SIDs in the server session but not calculated fail the test.
        extra = pac_sids.difference(self.user_sids)
        if extra:
            print("token sids don't match")
            print("difference : %s" % extra)
            self.fail(
                msg="calculated groups don't match against user PAC tokenGroups"
            )

    def _membership_graph(self, include_all_objects=False):
        """Build the domain's group-membership graph on casefolded DNs.

        Edges ``(member, group)`` come from every memberOf value of users and
        groups, and from each user's primaryGroupID.  Returns
        ``(vertices, edges)``.  With *include_all_objects*, every user/group
        object is a vertex even if it has no edges.
        """
        res = self.admin_ldb.search(
            base=self.base_dn,
            scope=ldb.SCOPE_SUBTREE,
            expression="(|(objectclass=user)(objectclass=group))",
            attrs=["memberOf"])
        vertices = set()
        edges = set()
        for obj in res:
            if include_all_objects:
                vertices.add(obj.dn.get_casefold())
            if "memberOf" in obj:
                for dn in obj["memberOf"]:
                    first = obj.dn.get_casefold()
                    second = ldb.Dn(self.admin_ldb, dn).get_casefold()
                    edges.add((first, second))
                    vertices.add(first)
                    vertices.add(second)

        res = self.admin_ldb.search(base=self.base_dn,
                                    scope=ldb.SCOPE_SUBTREE,
                                    expression="(objectclass=user)",
                                    attrs=["primaryGroupID"])
        domain_sid = self.admin_ldb.get_domain_sid()
        for obj in res:
            if "primaryGroupID" in obj:
                sid = "%s-%d" % (domain_sid, int(obj["primaryGroupID"][0]))
                res2 = self.admin_ldb.search(base="<SID=%s>" % sid,
                                             scope=ldb.SCOPE_BASE,
                                             attrs=[])
                first = obj.dn.get_casefold()
                second = res2[0].dn.get_casefold()

                edges.add((first, second))
                vertices.add(first)
                vertices.add(second)

        return vertices, edges

    def _casefolded_dns_for_sids(self, packed_sids):
        """Resolve NDR-packed SIDs to the casefolded DNs of their objects."""
        dns = set()
        for packed in packed_sids:
            sid = ndr_unpack(samba.dcerpc.security.dom_sid, packed)
            res = self.admin_ldb.search(base="<SID=%s>" % sid,
                                        scope=ldb.SCOPE_BASE,
                                        attrs=[])
            dns.add(res[0].dn.get_casefold())
        return dns

    def test_tokenGroups_manual(self):
        # Manually run the tokenGroups algorithm from MS-ADTS 3.1.1.4.5.19 and MS-DRSR 4.1.8.3
        # and compare the result
        vSet, aSet = self._membership_graph()

        user_key = self.test_user_dn.get_casefold()
        wSet = set([user_key])
        closure(vSet, wSet, aSet)
        wSet.remove(user_key)

        res = self.ldb.search(self.user_sid_dn,
                              scope=ldb.SCOPE_BASE,
                              attrs=["tokenGroups"])
        self.assertEqual(len(res), 1)

        tokenGroupsSet = self._casefolded_dns_for_sids(res[0]['tokenGroups'])

        if wSet.difference(tokenGroupsSet):
            self.fail(msg="additional calculated: %s" %
                      wSet.difference(tokenGroupsSet))

        if tokenGroupsSet.difference(wSet):
            self.fail(msg="additional tokenGroups: %s" %
                      tokenGroupsSet.difference(wSet))

    def filtered_closure(self, wSet, filter_grouptype):
        """Run closure() over *wSet* on a filtered membership graph.

        The vertex set passed to closure() keeps every non-group object and
        only those groups whose groupType equals *filter_grouptype*.  Callers
        rely on *wSet* being extended in place (closure() is defined
        elsewhere in this file).
        """
        vSet, aSet = self._membership_graph(include_all_objects=True)

        uSet = set()
        for v in vSet:
            res_group = self.admin_ldb.search(base=v,
                                              scope=ldb.SCOPE_BASE,
                                              attrs=["groupType"],
                                              expression="objectClass=group")
            if len(res_group) == 1:
                # Compare the low 32 bits as integers; comparing hex()
                # strings breaks on Python 2 when one side is a long ("...L").
                group_type = int(res_group[0]["groupType"][0]) & 0xFFFFFFFF
                if group_type == filter_grouptype:
                    uSet.add(v)
            else:
                uSet.add(v)

        closure(uSet, wSet, aSet)

    def test_tokenGroupsGlobalAndUniversal_manual(self):
        # Manually run the tokenGroups algorithm from MS-ADTS 3.1.1.4.5.19 and MS-DRSR 4.1.8.3
        # and compare the result

        # The variable names come from MS-ADTS May 15, 2014

        user_key = self.test_user_dn.get_casefold()
        S = set([user_key])

        self.filtered_closure(S, GTYPE_SECURITY_GLOBAL_GROUP)

        T = set()
        # Not really a SID, we do this on DNs...
        for sid in S:
            X = set([sid])
            self.filtered_closure(X, GTYPE_SECURITY_UNIVERSAL_GROUP)
            T |= X

        T.remove(user_key)

        res = self.ldb.search(self.user_sid_dn,
                              scope=ldb.SCOPE_BASE,
                              attrs=["tokenGroupsGlobalAndUniversal"])
        self.assertEqual(len(res), 1)

        tokenGroupsSet = self._casefolded_dns_for_sids(
            res[0]['tokenGroupsGlobalAndUniversal'])

        if T.difference(tokenGroupsSet):
            self.fail(msg="additional calculated: %s" %
                      T.difference(tokenGroupsSet))

        if tokenGroupsSet.difference(T):
            self.fail(msg="additional tokenGroupsGlobalAndUniversal: %s" %
                      tokenGroupsSet.difference(T))
Example #25
0
class DynamicTokenTest(samba.tests.TestCase):
    """Cross-check a user's token groups computed in several different ways.

    setUp creates a test user and a small nested group hierarchy around it.
    The tests compare the locally built security token against rootDSE and
    user-DN tokenGroups, the groups seen through a GSSAPI exchange, a manual
    implementation of the MS-ADTS tokenGroups algorithm, and SAMR
    GetGroupsForUser.
    """

    def get_creds(self, target_username, target_password):
        """Return sealing-enabled Credentials for the given user.

        Domain, realm and workstation are copied from the module-level
        ``creds`` object.
        """
        creds_tmp = Credentials()
        creds_tmp.set_username(target_username)
        creds_tmp.set_password(target_password)
        creds_tmp.set_domain(creds.get_domain())
        creds_tmp.set_realm(creds.get_realm())
        creds_tmp.set_workstation(creds.get_workstation())
        creds_tmp.set_gensec_features(creds_tmp.get_gensec_features()
                                      | gensec.FEATURE_SEAL)
        return creds_tmp

    def get_ldb_connection(self, target_username, target_password):
        """Open a SamDB connection to ``url`` as the given user."""
        creds_tmp = self.get_creds(target_username, target_password)
        ldb_target = SamDB(url=url, credentials=creds_tmp, lp=lp)
        return ldb_target

    def _create_group(self, name, grouptype, member):
        """Create group ``name`` in CN=Users, add ``member`` to it.

        :param name: sAMAccountName / CN of the new group.
        :param grouptype: one of the dsdb.GTYPE_* constants.
        :param member: name of the user or group to add as a member.
        :return: the new group's objectSid as a dom_sid.
        """
        self.admin_ldb.newgroup(name, grouptype=grouptype)
        res = self.admin_ldb.search(base="cn=%s,cn=users,%s" % (name, self.base_dn),
                                    attrs=["objectSid"], scope=ldb.SCOPE_BASE)
        group_sid = ndr_unpack(samba.dcerpc.security.dom_sid, res[0]["objectSid"][0])

        self.admin_ldb.add_remove_group_members(name, [member],
                                                add_members_operation=True)
        return group_sid

    def setUp(self):
        super(DynamicTokenTest, self).setUp()
        self.admin_ldb = SamDB(url, credentials=creds, session_info=system_session(lp), lp=lp)

        self.base_dn = self.admin_ldb.domain_dn()

        self.test_user = "******"
        self.test_user_pass = "******"
        self.admin_ldb.newuser(self.test_user, self.test_user_pass)

        # Membership layout (member -> group):
        #   user   -> group0 (domain local), group1 (global),
        #             group2 (universal), group6 (domain local)
        #   group1 -> group3 (universal) -> group4 (universal)
        #          -> group5 (domain local)
        self.test_group0 = "tokengroups_group0"
        self.test_group0_sid = self._create_group(
            self.test_group0, dsdb.GTYPE_SECURITY_DOMAIN_LOCAL_GROUP, self.test_user)

        self.test_group1 = "tokengroups_group1"
        self.test_group1_sid = self._create_group(
            self.test_group1, dsdb.GTYPE_SECURITY_GLOBAL_GROUP, self.test_user)

        self.test_group2 = "tokengroups_group2"
        self.test_group2_sid = self._create_group(
            self.test_group2, dsdb.GTYPE_SECURITY_UNIVERSAL_GROUP, self.test_user)

        self.test_group3 = "tokengroups_group3"
        self.test_group3_sid = self._create_group(
            self.test_group3, dsdb.GTYPE_SECURITY_UNIVERSAL_GROUP, self.test_group1)

        self.test_group4 = "tokengroups_group4"
        self.test_group4_sid = self._create_group(
            self.test_group4, dsdb.GTYPE_SECURITY_UNIVERSAL_GROUP, self.test_group3)

        self.test_group5 = "tokengroups_group5"
        self.test_group5_sid = self._create_group(
            self.test_group5, dsdb.GTYPE_SECURITY_DOMAIN_LOCAL_GROUP, self.test_group4)

        self.test_group6 = "tokengroups_group6"
        self.test_group6_sid = self._create_group(
            self.test_group6, dsdb.GTYPE_SECURITY_DOMAIN_LOCAL_GROUP, self.test_user)

        self.ldb = self.get_ldb_connection(self.test_user, self.test_user_pass)

        # The first rootDSE tokenGroups SID is taken as the user's own SID.
        res = self.ldb.search("", scope=ldb.SCOPE_BASE, attrs=["tokenGroups"])
        self.assertEqual(len(res), 1)

        self.user_sid = ndr_unpack(samba.dcerpc.security.dom_sid, res[0]["tokenGroups"][0])
        self.user_sid_dn = "<SID=%s>" % str(self.user_sid)

        res = self.ldb.search(self.user_sid_dn, scope=ldb.SCOPE_BASE, attrs=[])
        self.assertEqual(len(res), 1)

        self.test_user_dn = res[0].dn

        # Build the reference token locally from the user's DN.
        session_info_flags = (AUTH_SESSION_INFO_DEFAULT_GROUPS |
                              AUTH_SESSION_INFO_AUTHENTICATED |
                              AUTH_SESSION_INFO_SIMPLE_PRIVILEGES)
        session = samba.auth.user_session(self.ldb, lp_ctx=lp, dn=self.user_sid_dn,
                                          session_info_flags=session_info_flags)

        token = session.security_token
        self.user_sids = [str(s) for s in token.sids]

    def tearDown(self):
        super(DynamicTokenTest, self).tearDown()
        for name in (self.test_user,
                     self.test_group0,
                     self.test_group1,
                     self.test_group2,
                     self.test_group3,
                     self.test_group4,
                     self.test_group5,
                     self.test_group6):
            delete_force(self.admin_ldb, "CN=%s,%s,%s" %
                         (name, "cn=users", self.base_dn))

    def test_rootDSE_tokenGroups(self):
        """Testing rootDSE tokengroups against internal calculation"""
        if not url.startswith("ldap"):
            self.fail(msg="This test is only valid on ldap")

        res = self.ldb.search("", scope=ldb.SCOPE_BASE, attrs=["tokenGroups"])
        self.assertEqual(len(res), 1)

        print("Getting tokenGroups from rootDSE")
        tokengroups = [str(ndr_unpack(samba.dcerpc.security.dom_sid, sid))
                       for sid in res[0]['tokenGroups']]

        sidset1 = set(tokengroups)
        sidset2 = set(self.user_sids)
        if sidset1.difference(sidset2):
            print("token sids don't match")
            print("tokengroups: %s" % tokengroups)
            print("calculated : %s" % self.user_sids)
            print("difference : %s" % sidset1.difference(sidset2))
            self.fail(msg="calculated groups don't match against rootDSE tokenGroups")

    def test_dn_tokenGroups(self):
        """Check the user DN's tokenGroups against the local token."""
        print("Getting tokenGroups from user DN")
        res = self.ldb.search(self.user_sid_dn, scope=ldb.SCOPE_BASE, attrs=["tokenGroups"])
        self.assertEqual(len(res), 1)

        dn_tokengroups = [str(ndr_unpack(samba.dcerpc.security.dom_sid, sid))
                          for sid in res[0]['tokenGroups']]

        sidset1 = set(dn_tokengroups)
        sidset2 = set(self.user_sids)
        if sidset1.difference(sidset2):
            print("token sids don't match")
            print("difference : %s" % sidset1.difference(sidset2))
            self.fail(msg="calculated groups don't match against user DN tokenGroups")

    def test_pac_groups(self):
        """Check the SIDs from a GSSAPI client/server exchange against the local token."""
        settings = {}
        settings["lp_ctx"] = lp
        settings["target_hostname"] = lp.get("netbios name")

        gensec_client = gensec.Security.start_client(settings)
        gensec_client.set_credentials(self.get_creds(self.test_user, self.test_user_pass))
        gensec_client.want_feature(gensec.FEATURE_SEAL)
        gensec_client.start_mech_by_sasl_name("GSSAPI")

        auth_context = AuthContext(lp_ctx=lp, ldb=self.ldb, methods=[])

        gensec_server = gensec.Security.start_server(settings, auth_context)
        machine_creds = Credentials()
        machine_creds.guess(lp)
        machine_creds.set_machine_account(lp)
        gensec_server.set_credentials(machine_creds)

        gensec_server.want_feature(gensec.FEATURE_SEAL)
        gensec_server.start_mech_by_sasl_name("GSSAPI")

        client_finished = False
        server_finished = False
        # NOTE(review): update() presumably takes a bytes token on Python 3;
        # b"" is identical to "" on Python 2.
        server_to_client = b""

        # Run the actual call loop until either side reports completion.
        while not client_finished and not server_finished:
            if not client_finished:
                print("running client gensec_update")
                (client_finished, client_to_server) = gensec_client.update(server_to_client)
            if not server_finished:
                print("running server gensec_update")
                (server_finished, server_to_client) = gensec_server.update(client_to_server)

        session = gensec_server.session_info()

        token = session.security_token
        pac_sids = [str(s) for s in token.sids]

        sidset1 = set(pac_sids)
        sidset2 = set(self.user_sids)
        if sidset1.difference(sidset2):
            print("token sids don't match")
            print("difference : %s" % sidset1.difference(sidset2))
            self.fail(msg="calculated groups don't match against user PAC tokenGroups")

    def _membership_graph(self, include_all_objects):
        """Collect the membership graph of all users and groups in the domain.

        Edges come from every object's ``memberOf`` values and from every
        user's ``primaryGroupID``.

        :param include_all_objects: if true, every user/group found becomes a
            vertex; otherwise only objects that appear in some edge do.
        :return: ``(vertices, edges)``; vertices is a set of casefolded DNs,
            edges a set of ``(member_dn, group_dn)`` casefolded pairs.
        """
        vertices = set()
        edges = set()

        def add_edge(member, group):
            edges.add((member, group))
            vertices.add(member)
            vertices.add(group)

        res = self.admin_ldb.search(base=self.base_dn, scope=ldb.SCOPE_SUBTREE,
                                    expression="(|(objectclass=user)(objectclass=group))",
                                    attrs=["memberOf"])
        for obj in res:
            member = obj.dn.get_casefold()
            if include_all_objects:
                vertices.add(member)
            if "memberOf" in obj:
                for dn in obj["memberOf"]:
                    add_edge(member, ldb.Dn(self.admin_ldb, dn).get_casefold())

        domain_sid = self.admin_ldb.get_domain_sid()
        res = self.admin_ldb.search(base=self.base_dn, scope=ldb.SCOPE_SUBTREE,
                                    expression="(objectclass=user)",
                                    attrs=["primaryGroupID"])
        for obj in res:
            if "primaryGroupID" in obj:
                sid = "%s-%d" % (domain_sid, int(obj["primaryGroupID"][0]))
                res2 = self.admin_ldb.search(base="<SID=%s>" % sid, scope=ldb.SCOPE_BASE,
                                             attrs=[])
                add_edge(obj.dn.get_casefold(), res2[0].dn.get_casefold())

        return vertices, edges

    def _token_group_dns(self, attr):
        """Return the casefolded DNs of the SIDs in the test user's ``attr``.

        ``attr`` is read through the user's own connection; each SID is then
        resolved to a DN through the admin connection.
        """
        res = self.ldb.search(self.user_sid_dn, scope=ldb.SCOPE_BASE, attrs=[attr])
        self.assertEqual(len(res), 1)

        dns = set()
        for packed_sid in res[0][attr]:
            sid = ndr_unpack(samba.dcerpc.security.dom_sid, packed_sid)
            res3 = self.admin_ldb.search(base="<SID=%s>" % sid, scope=ldb.SCOPE_BASE,
                                         attrs=[])
            dns.add(res3[0].dn.get_casefold())
        return dns

    def test_tokenGroups_manual(self):
        # Manually run the tokenGroups algorithm from MS-ADTS 3.1.1.4.5.19 and MS-DRSR 4.1.8.3
        # and compare the result
        vSet, aSet = self._membership_graph(include_all_objects=False)

        user_dn = self.test_user_dn.get_casefold()
        wSet = {user_dn}
        # closure() (defined elsewhere in this file) is relied on to extend
        # wSet in place with everything reachable from the user.
        closure(vSet, wSet, aSet)
        wSet.remove(user_dn)

        tokenGroupsSet = self._token_group_dns("tokenGroups")

        only_calculated = wSet.difference(tokenGroupsSet)
        if only_calculated:
            self.fail(msg="additional calculated: %s" % only_calculated)

        only_returned = tokenGroupsSet.difference(wSet)
        if only_returned:
            self.fail(msg="additional tokenGroups: %s" % only_returned)

    def filtered_closure(self, wSet, filter_grouptype):
        """Extend ``wSet`` through groups of a single group type.

        Non-group objects are always eligible vertices; groups are eligible
        only when their groupType equals ``filter_grouptype`` (compared as
        unsigned 32-bit values). Nothing is returned: callers rely on
        closure() updating ``wSet`` in place.
        """
        vSet, aSet = self._membership_graph(include_all_objects=True)
        # groupType may be read back as a signed 32-bit value, so compare
        # both sides as unsigned 32-bit integers.
        wanted_type = filter_grouptype & 0xFFFFFFFF

        uSet = set()
        for v in vSet:
            res_group = self.admin_ldb.search(base=v, scope=ldb.SCOPE_BASE,
                                              attrs=["groupType"],
                                              expression="objectClass=group")
            if len(res_group) == 1:
                if (int(res_group[0]["groupType"][0]) & 0xFFFFFFFF) == wanted_type:
                    uSet.add(v)
            else:
                uSet.add(v)

        closure(uSet, wSet, aSet)

    def test_tokenGroupsGlobalAndUniversal_manual(self):
        # Manually run the tokenGroups algorithm from MS-ADTS 3.1.1.4.5.19 and MS-DRSR 4.1.8.3
        # and compare the result

        # The variable names come from MS-ADTS May 15, 2014
        user_dn = self.test_user_dn.get_casefold()

        S = {user_dn}
        self.filtered_closure(S, GTYPE_SECURITY_GLOBAL_GROUP)

        T = set()
        # Not really a SID, we do this on DNs...
        for sid in S:
            X = {sid}
            self.filtered_closure(X, GTYPE_SECURITY_UNIVERSAL_GROUP)
            T |= X

        T.remove(user_dn)

        tokenGroupsSet = self._token_group_dns("tokenGroupsGlobalAndUniversal")

        only_calculated = T.difference(tokenGroupsSet)
        if only_calculated:
            self.fail(msg="additional calculated: %s" % only_calculated)

        only_returned = tokenGroupsSet.difference(T)
        if only_returned:
            self.fail(msg="additional tokenGroupsGlobalAndUniversal: %s" % only_returned)

    def test_samr_GetGroupsForUser(self):
        # Confirm that we get the correct results against SAMR also
        if not url.startswith("ldap://"):
            self.fail(msg="This test is only valid on ldap (so we an find the hostname and use SAMR)")
        host = url.split("://")[1]
        (domain_sid, user_rid) = self.user_sid.split()
        samr_conn = samba.dcerpc.samr.samr("ncacn_ip_tcp:%s[seal]" % host, lp, creds)
        samr_handle = samr_conn.Connect2(None, security.SEC_FLAG_MAXIMUM_ALLOWED)
        samr_domain = samr_conn.OpenDomain(samr_handle, security.SEC_FLAG_MAXIMUM_ALLOWED,
                                           domain_sid)
        user_handle = samr_conn.OpenUser(samr_domain, security.SEC_FLAG_MAXIMUM_ALLOWED, user_rid)
        rids = samr_conn.GetGroupsForUser(user_handle)

        expected_attributes = (samr.SE_GROUP_MANDATORY |
                               samr.SE_GROUP_ENABLED_BY_DEFAULT |
                               samr.SE_GROUP_ENABLED)
        samr_dns = set()
        for rid in rids.rids:
            self.assertEqual(rid.attributes, expected_attributes)
            sid = "%s-%d" % (domain_sid, rid.rid)
            res = self.admin_ldb.search(base="<SID=%s>" % sid, scope=ldb.SCOPE_BASE,
                                        attrs=[])
            samr_dns.add(res[0].dn.get_casefold())

        # The test expects SAMR to list the primary group first.
        user_info = samr_conn.QueryUserInfo(user_handle, 1)
        self.assertEqual(rids.rids[0].rid, user_info.primary_gid)

        # Restricts a base search to global or universal security groups.
        global_or_universal = ("(&(|(grouptype=%d)(grouptype=%d))(objectclass=group))"
                               % (GTYPE_SECURITY_GLOBAL_GROUP, GTYPE_SECURITY_UNIVERSAL_GROUP))

        tokenGroupsSet = set()
        res = self.ldb.search(self.user_sid_dn, scope=ldb.SCOPE_BASE, attrs=["tokenGroupsGlobalAndUniversal"])
        for packed_sid in res[0]['tokenGroupsGlobalAndUniversal']:
            sid = ndr_unpack(samba.dcerpc.security.dom_sid, packed_sid)
            res3 = self.admin_ldb.search(base="<SID=%s>" % sid, scope=ldb.SCOPE_BASE,
                                         attrs=[],
                                         expression=global_or_universal)
            # Check the result of *this* lookup: SIDs that are not global or
            # universal groups match nothing and are skipped.
            if len(res3) == 1:
                tokenGroupsSet.add(res3[0].dn.get_casefold())

        samr_only = samr_dns.difference(tokenGroupsSet)
        if samr_only:
            self.fail(msg="additional samr_GetUserGroups over tokenGroups: %s" % samr_only)

        memberOf = set()
        # Add the primary group explicitly, looked up by its own SID.
        primary_group_sid = "%s-%d" % (domain_sid, user_info.primary_gid)
        res2 = self.admin_ldb.search(base="<SID=%s>" % primary_group_sid, scope=ldb.SCOPE_BASE,
                                     attrs=[])

        memberOf.add(res2[0].dn.get_casefold())
        res = self.ldb.search(self.user_sid_dn, scope=ldb.SCOPE_BASE, attrs=["memberOf"])
        for dn in res[0]['memberOf']:
            res3 = self.admin_ldb.search(base=dn, scope=ldb.SCOPE_BASE,
                                         attrs=[],
                                         expression=global_or_universal)
            if len(res3) == 1:
                memberOf.add(res3[0].dn.get_casefold())

        memberof_only = memberOf.difference(samr_dns)
        if memberof_only:
            self.fail(msg="additional memberOf over samr_GetUserGroups: %s" % memberof_only)

        samr_only = samr_dns.difference(memberOf)
        if samr_only:
            self.fail(msg="additional samr_GetUserGroups over memberOf: %s" % samr_only)

        S = {self.test_user_dn.get_casefold()}

        self.filtered_closure(S, GTYPE_SECURITY_GLOBAL_GROUP)
        self.filtered_closure(S, GTYPE_SECURITY_UNIVERSAL_GROUP)

        # Now remove the user DN itself; the primary group stays because
        # samr_dns (built from every SAMR rid) includes it as well.
        S.remove(self.test_user_dn.get_casefold())

        samr_only = samr_dns.difference(S)
        if samr_only:
            self.fail(msg="additional samr_GetUserGroups over filtered_closure: %s" % samr_only)

    def test_samr_GetGroupsForUser_nomember(self):
        # Confirm that we get the correct results against SAMR also
        if not url.startswith("ldap://"):
            self.fail(msg="This test is only valid on ldap (so we an find the hostname and use SAMR)")
        host = url.split("://")[1]

        test_user = "******"
        self.admin_ldb.newuser(test_user, self.test_user_pass)
        # Always remove the temporary user, even if a lookup below fails.
        try:
            res = self.admin_ldb.search(base="cn=%s,cn=users,%s" % (test_user, self.base_dn),
                                        attrs=["objectSid"], scope=ldb.SCOPE_BASE)
            user_sid = ndr_unpack(samba.dcerpc.security.dom_sid, res[0]["objectSid"][0])

            (domain_sid, user_rid) = user_sid.split()
            samr_conn = samba.dcerpc.samr.samr("ncacn_ip_tcp:%s[seal]" % host, lp, creds)
            samr_handle = samr_conn.Connect2(None, security.SEC_FLAG_MAXIMUM_ALLOWED)
            samr_domain = samr_conn.OpenDomain(samr_handle, security.SEC_FLAG_MAXIMUM_ALLOWED,
                                               domain_sid)
            user_handle = samr_conn.OpenUser(samr_domain, security.SEC_FLAG_MAXIMUM_ALLOWED, user_rid)
            rids = samr_conn.GetGroupsForUser(user_handle)
            user_info = samr_conn.QueryUserInfo(user_handle, 1)
        finally:
            delete_force(self.admin_ldb, "CN=%s,%s,%s" %
                         (test_user, "cn=users", self.base_dn))
        # A user with no explicit memberships should report only its primary group.
        self.assertEqual(len(rids.rids), 1)
        self.assertEqual(rids.rids[0].rid, user_info.primary_gid)
Example #26
0
    def run(self,
            computername,
            credopts=None,
            sambaopts=None,
            versionopts=None,
            H=None,
            computerou=None,
            description=None,
            prepare_oldjoin=False,
            ip_address_list=None,
            service_principal_name_list=None):
        """Create the computer account ``computername``.

        When ``ip_address_list`` is non-empty, DNS records for the computer
        are also added through the dnsserver RPC pipe. Their security
        descriptor is owned by the new account's objectSid, with the
        account's primary group as group.

        :raises CommandError: for an invalid IP address, or wrapping any
            failure while creating the account or its DNS records.
        """
        if ip_address_list is None:
            ip_address_list = []

        if service_principal_name_list is None:
            service_principal_name_list = []

        # Validate every IP address before touching the directory.
        for ip_address in ip_address_list:
            if not _is_valid_ip(ip_address):
                raise CommandError('Invalid IP address {}'.format(ip_address))

        lp = sambaopts.get_loadparm()
        creds = credopts.get_credentials(lp)

        try:
            samdb = SamDB(url=H,
                          session_info=system_session(),
                          credentials=creds,
                          lp=lp)
            samdb.newcomputer(
                computername,
                computerou=computerou,
                description=description,
                prepare_oldjoin=prepare_oldjoin,
                ip_address_list=ip_address_list,
                service_principal_name_list=service_principal_name_list,
            )

            if ip_address_list:
                # if ip_address_list provided, then we need to create DNS
                # records for this computer.

                # Strip a single trailing '$'; any other '$' is illegal.
                hostname = re.sub(r"\$$", "", computername)
                if '$' in hostname:
                    raise CommandError('Illegal computername "%s"' %
                                       computername)

                filters = '(&(sAMAccountName={}$)(objectclass=computer))'.format(
                    ldb.binary_encode(hostname))

                recs = samdb.search(base=samdb.domain_dn(),
                                    scope=ldb.SCOPE_SUBTREE,
                                    expression=filters,
                                    attrs=['primaryGroupID', 'objectSid'])

                # The raw attribute value may be bytes (pyldb on Python 3);
                # int() accepts bytes or str digits, so the SID built below
                # is well-formed either way.
                group = int(recs[0]['primaryGroupID'][0])
                owner = ndr_unpack(security.dom_sid, recs[0]["objectSid"][0])

                dns_conn = dnsserver.dnsserver(
                    "ncacn_ip_tcp:{}[sign]".format(samdb.host_dns_name()), lp,
                    creds)

                change_owner_sd = security.descriptor()
                change_owner_sd.owner_sid = owner
                change_owner_sd.group_sid = security.dom_sid(
                    "{}-{}".format(samdb.get_domain_sid(), group))

                add_dns_records(samdb, hostname, dns_conn, change_owner_sd,
                                samdb.host_dns_name(), ip_address_list,
                                self.get_logger())
        except Exception as e:
            raise CommandError(
                "Failed to create computer '%s': " % computername, e)

        self.outf.write("Computer '%s' created successfully\n" % computername)
Example #27
0
class SamrTests(RpcInterfaceTestCase):

    def setUp(self):
        super(SamrTests, self).setUp()
        self.conn = samr.samr("ncalrpc:", self.get_loadparm())
        self.open_samdb()
        self.open_domain_handle()

    #
    # Open the samba database
    #
    def open_samdb(self):
        self.lp = env_loadparm()
        self.domain = os.environ["DOMAIN"]
        self.creds = Credentials()
        self.creds.guess(self.lp)
        self.session = system_session()
        self.samdb = SamDB(
            session_info=self.session, credentials=self.creds, lp=self.lp)

    #
    # Open a SAMR Domain handle
    def open_domain_handle(self):
        self.handle = self.conn.Connect2(
            None, security.SEC_FLAG_MAXIMUM_ALLOWED)

        self.domain_sid = self.conn.LookupDomain(
            self.handle, lsa.String(self.domain))

        self.domain_handle = self.conn.OpenDomain(
            self.handle, security.SEC_FLAG_MAXIMUM_ALLOWED, self.domain_sid)

    # Filter a list of records, removing those that are not part of the
    # current domain.
    #
    def filter_domain(self, unfiltered):
        def sid(msg):
            sid = ndr_unpack(security.dom_sid, msg["objectSID"][0])
            (x, _) = sid.split()
            return x

        dom_sid = security.dom_sid(self.samdb.get_domain_sid())
        return [x for x in unfiltered if sid(x) == dom_sid]

    def test_connect5(self):
        (level, info, handle) =\
            self.conn.Connect5(None, 0, 1, samr.ConnectInfo1())

    def test_connect2(self):
        handle = self.conn.Connect2(None, security.SEC_FLAG_MAXIMUM_ALLOWED)
        self.assertTrue(handle is not None)

    def test_EnumDomains(self):
        handle = self.conn.Connect2(None, security.SEC_FLAG_MAXIMUM_ALLOWED)
        toArray(*self.conn.EnumDomains(handle, 0, 4294967295))
        self.conn.Close(handle)

    # Create groups based on the id list supplied, the id is used to
    # form a unique name and description.
    #
    # returns a list of the created dn's, which can be passed to delete_dns
    # to clean up after the test has run.
    def create_groups(self, ids):
        dns = []
        for i in ids:
            name = "SAMR_GRP%d" % i
            dn = "cn=%s,cn=Users,%s" % (name, self.samdb.domain_dn())
            delete_force(self.samdb, dn)

            self.samdb.newgroup(name)
            dns.append(dn)
        return dns

    # Create user accounts based on the id list supplied, the id is used to
    # form a unique name and description.
    #
    # returns a list of the created dn's, which can be passed to delete_dns
    # to clean up after the test has run.
    def create_users(self, ids):
        dns = []
        for i in ids:
            name = "SAMR_USER%d" % i
            dn = "cn=%s,CN=USERS,%s" % (name, self.samdb.domain_dn())
            delete_force(self.samdb, dn)

            # We only need the user to exist, we don't need a password
            self.samdb.newuser(
                name,
                password=None,
                setpassword=False,
                description="Description for " + name,
                givenname="given%dname" % i,
                surname="surname%d" % i)
            dns.append(dn)
        return dns

    # Create computer accounts based on the id list supplied, the id is used to
    # form a unique name and description.
    #
    # returns a list of the created dn's, which can be passed to delete_dns
    # to clean up after the test has run.
    def create_computers(self, ids):
        dns = []
        for i in ids:
            name = "SAMR_CMP%d" % i
            dn = "cn=%s,cn=COMPUTERS,%s" % (name, self.samdb.domain_dn())
            delete_force(self.samdb, dn)

            self.samdb.newcomputer(name, description="Description of " + name)
            dns.append(dn)
        return dns

    # Delete the specified dn's.
    #
    # Used to clean up entries created by individual tests.
    #
    def delete_dns(self, dns):
        for dn in dns:
            delete_force(self.samdb, dn)

    # Common tests for QueryDisplayInfo
    #
    def _test_QueryDisplayInfo(
            self, level, check_results, select, attributes, add_elements):
        #
        # Get the expected results by querying the samdb database directly.
        # We do this rather than use a list of expected results as this runs
        # with other tests so we do not have a known fixed list of elements
        expected = self.samdb.search(expression=select, attrs=attributes)
        self.assertTrue(len(expected) > 0)

        #
        # Perform QueryDisplayInfo with max results greater than the expected
        # number of results.
        (ts, rs, actual) = self.conn.QueryDisplayInfo(
            self.domain_handle, level, 0, 1024, 4294967295)

        self.assertEquals(len(expected), ts)
        self.assertEquals(len(expected), rs)
        check_results(expected, actual.entries)

        #
        # Perform QueryDisplayInfo with max results set to the number of
        # results returned from the first query, should return the same results
        (ts1, rs1, actual1) = self.conn.QueryDisplayInfo(
            self.domain_handle, level, 0, rs, 4294967295)
        self.assertEquals(ts, ts1)
        self.assertEquals(rs, rs1)
        check_results(expected, actual1.entries)

        #
        # Perform QueryDisplayInfo and get the last two results.
        # Note: We are assuming there are at least three entries
        self.assertTrue(ts > 2)
        (ts2, rs2, actual2) = self.conn.QueryDisplayInfo(
            self.domain_handle, level, (ts - 2), 2, 4294967295)
        self.assertEquals(ts, ts2)
        self.assertEquals(2, rs2)
        check_results(list(expected)[-2:], actual2.entries)

        #
        # Perform QueryDisplayInfo and get the first two results.
        # Note: We are assuming there are at least three entries
        self.assertTrue(ts > 2)
        (ts2, rs2, actual2) = self.conn.QueryDisplayInfo(
            self.domain_handle, level, 0, 2, 4294967295)
        self.assertEquals(ts, ts2)
        self.assertEquals(2, rs2)
        check_results(list(expected)[:2], actual2.entries)

        #
        # Perform QueryDisplayInfo and get two results in the middle of the
        # list i.e. not the first or the last entry.
        # Note: We are assuming there are at least four entries
        self.assertTrue(ts > 3)
        (ts2, rs2, actual2) = self.conn.QueryDisplayInfo(
            self.domain_handle, level, 1, 2, 4294967295)
        self.assertEquals(ts, ts2)
        self.assertEquals(2, rs2)
        check_results(list(expected)[1:2], actual2.entries)

        #
        # To check that cached values are being returned rather than the
        # results being re-read from disk we add elements, and request all
        # but the first result.
        #
        dns = add_elements([1000, 1002, 1003, 1004])

        #
        # Perform QueryDisplayInfo and get all but the first result.
        # We should be using the cached results so the entries we just added
        # should not be present
        (ts3, rs3, actual3) = self.conn.QueryDisplayInfo(
            self.domain_handle, level, 1, 1024, 4294967295)
        self.assertEquals(ts, ts3)
        self.assertEquals(len(expected) - 1, rs3)
        check_results(list(expected)[1:], actual3.entries)

        #
        # Perform QueryDisplayInfo and get all the results.
        # As the start index is zero we should reread the data from disk and
        # the added entries should be there
        new = self.samdb.search(expression=select, attrs=attributes)
        (ts4, rs4, actual4) = self.conn.QueryDisplayInfo(
            self.domain_handle, level, 0, 1024, 4294967295)
        self.assertEquals(len(expected) + len(dns), ts4)
        self.assertEquals(len(expected) + len(dns), rs4)
        check_results(new, actual4.entries)

        # Delete the added DN's and query all but the first entry.
        # This should ensure the cached results are used and that the
        # missing entry code is triggered.
        self.delete_dns(dns)
        (ts5, rs5, actual5) = self.conn.QueryDisplayInfo(
            self.domain_handle, level, 1, 1024, 4294967295)
        self.assertEquals(len(expected) + len(dns), ts5)
        # The deleted results will be filtered from the result set so should
        # be missing from the returned results.
        # Note: depending on the GUID order, the first result in the cache may
        #       be a deleted entry, in which case the results will contain all
        #       the expected elements, otherwise the first expected result will
        #       be missing.
        if rs5 == len(expected):
            check_results(expected, actual5.entries)
        elif rs5 == (len(expected) - 1):
            check_results(list(expected)[1:], actual5.entries)
        else:
            self.fail("Incorrect number of entries {0}".format(rs5))

        #
        # Perform QueryDisplayInfo specifying an index past the end of the
        # available data.
        # Should return no data.
        (ts6, rs6, actual6) = self.conn.QueryDisplayInfo(
            self.domain_handle, level, ts5, 1, 4294967295)
        self.assertEquals(ts5, ts6)
        self.assertEquals(0, rs6)

        self.conn.Close(self.handle)

    # Test for QueryDisplayInfo, Level 1
    # Returns the sAMAccountName, displayName and description for all
    # the user accounts.
    #
    def test_QueryDisplayInfo_level_1(self):
        """QueryDisplayInfo level 1 returns user account details.

        Level 1 entries (samr.DispEntryGeneral) carry the sAMAccountName,
        displayName and description of the normal user accounts.
        """
        def check_results(expected, actual):
            # Assume the QueryDisplayInfo and ldb.search return their results
            # in the same order
            for (e, a) in zip(expected, actual):
                self.assertIsInstance(a, samr.DispEntryGeneral)
                self.assertEqual(str(e["sAMAccountName"]),
                                 str(a.account_name))

                # The displayName and description are optional.
                # In the expected results they will be missing, in
                # samr.DispEntryGeneral the corresponding attribute will have a
                # length of zero.
                #
                if a.full_name.length == 0:
                    self.assertNotIn("displayName", e)
                else:
                    self.assertEqual(str(e["displayName"]), str(a.full_name))

                if a.description.length == 0:
                    self.assertNotIn("description", e)
                else:
                    self.assertEqual(str(e["description"]),
                                     str(a.description))

        # Create four user accounts
        # to ensure that we have the minimum needed for the tests.
        dns = self.create_users([1, 2, 3, 4])
        try:
            select = "(&(objectclass=user)(sAMAccountType={0}))".format(
                ATYPE_NORMAL_ACCOUNT)
            attributes = ["sAMAccountName", "displayName", "description"]
            self._test_QueryDisplayInfo(
                1, check_results, select, attributes, self.create_users)
        finally:
            # Remove the accounts even when a check above fails, so they do
            # not linger for later tests.
            self.delete_dns(dns)

    # Test for QueryDisplayInfo, Level 2
    # Returns the sAMAccountName and description for all
    # the computer accounts.
    #
    def test_QueryDisplayInfo_level_2(self):
        """QueryDisplayInfo level 2 returns workstation account details.

        Level 2 entries (samr.DispEntryFull) carry the sAMAccountName and
        description of the workstation trust (computer) accounts.
        """
        def check_results(expected, actual):
            # Assume the QueryDisplayInfo and ldb.search return their results
            # in the same order
            for (e, a) in zip(expected, actual):
                self.assertIsInstance(a, samr.DispEntryFull)
                self.assertEqual(str(e["sAMAccountName"]),
                                 str(a.account_name))

                # The description is optional.
                # In the expected results it will be missing, in
                # samr.DispEntryFull the corresponding attribute will have a
                # length of zero.
                #
                if a.description.length == 0:
                    self.assertNotIn("description", e)
                else:
                    self.assertEqual(str(e["description"]),
                                     str(a.description))

        # Create four computer accounts
        # to ensure that we have the minimum needed for the tests.
        dns = self.create_computers([1, 2, 3, 4])
        try:
            select = "(&(objectclass=user)(sAMAccountType={0}))".format(
                ATYPE_WORKSTATION_TRUST)
            attributes = ["sAMAccountName", "description"]
            self._test_QueryDisplayInfo(
                2, check_results, select, attributes, self.create_computers)
        finally:
            # Remove the accounts even when a check above fails, so they do
            # not linger for later tests.
            self.delete_dns(dns)

    # Test for QueryDisplayInfo, Level 3
    # Returns the sAMAccountName and description for all
    # the groups.
    #
    def test_QueryDisplayInfo_level_3(self):
        """QueryDisplayInfo level 3 returns group details.

        Level 3 entries (samr.DispEntryFullGroup) carry the sAMAccountName
        and description of the universal and global security groups.
        """
        def check_results(expected, actual):
            # Assume the QueryDisplayInfo and ldb.search return their results
            # in the same order
            for (e, a) in zip(expected, actual):
                self.assertIsInstance(a, samr.DispEntryFullGroup)
                self.assertEqual(str(e["sAMAccountName"]),
                                 str(a.account_name))

                # The description is optional.
                # In the expected results it will be missing, in
                # samr.DispEntryFullGroup the corresponding attribute will
                # have a length of zero.
                #
                if a.description.length == 0:
                    self.assertNotIn("description", e)
                else:
                    self.assertEqual(str(e["description"]),
                                     str(a.description))

        # Create four groups
        # to ensure that we have the minimum needed for the tests.
        dns = self.create_groups([1, 2, 3, 4])
        try:
            select = ("(&(|(groupType=%d)(groupType=%d))(objectClass=group))"
                      % (GTYPE_SECURITY_UNIVERSAL_GROUP,
                         GTYPE_SECURITY_GLOBAL_GROUP))
            attributes = ["sAMAccountName", "description"]
            self._test_QueryDisplayInfo(
                3, check_results, select, attributes, self.create_groups)
        finally:
            # Remove the groups even when a check above fails, so they do
            # not linger for later tests.
            self.delete_dns(dns)

    # Test for QueryDisplayInfo, Level 4
    # Returns the sAMAccountName (as an ASCII string)
    # for all the user accounts.
    #
    def test_QueryDisplayInfo_level_4(self):
        """QueryDisplayInfo level 4 returns user names as ASCII strings.

        Level 4 entries (samr.DispEntryAscii) carry the sAMAccountName of the
        normal user accounts as an lsa.AsciiStringLarge.
        """
        def check_results(expected, actual):
            # Assume the QueryDisplayInfo and ldb.search return their results
            # in the same order
            for (e, a) in zip(expected, actual):
                self.assertIsInstance(a, samr.DispEntryAscii)
                self.assertIsInstance(a.account_name, lsa.AsciiStringLarge)
                self.assertEqual(
                    str(e["sAMAccountName"]), str(a.account_name.string))

        # Create four user accounts
        # to ensure that we have the minimum needed for the tests.
        dns = self.create_users([1, 2, 3, 4])
        try:
            select = "(&(objectclass=user)(sAMAccountType={0}))".format(
                ATYPE_NORMAL_ACCOUNT)
            attributes = ["sAMAccountName", "displayName", "description"]
            self._test_QueryDisplayInfo(
                4, check_results, select, attributes, self.create_users)
        finally:
            # Remove the accounts even when a check above fails, so they do
            # not linger for later tests.
            self.delete_dns(dns)

    # Test for QueryDisplayInfo, Level 5
    # Returns the sAMAccountName (as an ASCII string)
    # for all the groups.
    #
    def test_QueryDisplayInfo_level_5(self):
        """QueryDisplayInfo level 5 returns group names as ASCII strings.

        Level 5 entries (samr.DispEntryAscii) carry the sAMAccountName of the
        universal and global security groups as an lsa.AsciiStringLarge.
        """
        def check_results(expected, actual):
            # Assume the QueryDisplayInfo and ldb.search return their results
            # in the same order
            for (e, a) in zip(expected, actual):
                self.assertIsInstance(a, samr.DispEntryAscii)
                self.assertIsInstance(a.account_name, lsa.AsciiStringLarge)
                self.assertEqual(
                    str(e["sAMAccountName"]), str(a.account_name.string))

        # Create four groups
        # to ensure that we have the minimum needed for the tests.
        dns = self.create_groups([1, 2, 3, 4])
        try:
            select = ("(&(|(groupType=%d)(groupType=%d))(objectClass=group))"
                      % (GTYPE_SECURITY_UNIVERSAL_GROUP,
                         GTYPE_SECURITY_GLOBAL_GROUP))
            attributes = ["sAMAccountName", "description"]
            self._test_QueryDisplayInfo(
                5, check_results, select, attributes, self.create_groups)
        finally:
            # Remove the groups even when a check above fails, so they do
            # not linger for later tests.
            self.delete_dns(dns)

    def test_EnumDomainGroups(self):
        """EnumDomainGroups returns, pages and caches the domain groups.

        The expected groups come from a direct samdb search (universal and
        global security groups), passed through filter_domain() and sorted
        by RID. Besides the plain and paged enumerations, the test checks
        that continuing an enumeration with a non-zero resume handle uses
        the result set cached by the first call: groups added afterwards do
        not appear and groups deleted afterwards are skipped, while a fresh
        enumeration (resume handle 0) does see newly added groups.
        """
        def check_results(expected, actual):
            for (e, a) in zip(expected, actual):
                self.assertIsInstance(a, samr.SamEntry)
                self.assertEqual(
                    str(e["sAMAccountName"]), str(a.name.string))

        # Create four groups
        # to ensure that we have the minimum needed for the tests.
        dns = self.create_groups([1, 2, 3, 4])
        try:
            #
            # Get the expected results by querying the samdb database
            # directly. We do this rather than use a list of expected results
            # as this runs with other tests so we do not have a known fixed
            # list of elements
            select = ("(&(|(groupType=%d)(groupType=%d))(objectClass=group))"
                      % (GTYPE_SECURITY_UNIVERSAL_GROUP,
                         GTYPE_SECURITY_GLOBAL_GROUP))
            attributes = ["sAMAccountName", "objectSID"]
            unfiltered = self.samdb.search(expression=select, attrs=attributes)
            filtered = self.filter_domain(unfiltered)
            self.assertGreater(len(filtered), 4)

            # Sort the expected results by rid
            expected = sorted(filtered, key=rid)

            #
            # Perform EnumDomainGroups with max size greater than the
            # expected number of results. Allow for an extra 10 entries
            #
            max_size = calc_max_size(len(expected) + 10)
            (resume_handle, actual, num_entries) = self.conn.EnumDomainGroups(
                self.domain_handle, 0, max_size)
            self.assertEqual(len(expected), num_entries)
            check_results(expected, actual.entries)

            #
            # Perform EnumDomainGroups with max size set so that the result
            # contains 4 entries.
            #
            max_size = calc_max_size(4)
            (resume_handle, actual, num_entries) = self.conn.EnumDomainGroups(
                self.domain_handle, 0, max_size)
            self.assertEqual(4, num_entries)
            check_results(expected[:4], actual.entries)

            #
            # Call with a resume_handle just past the last entry.
            # Should return no results and a resume handle of 0
            max_size = calc_max_size(1)
            rh = len(expected)
            # NOTE(review): self.handle is closed here while domain_handle is
            # still used below -- confirm this is intentional.
            self.conn.Close(self.handle)
            (resume_handle, a, num_entries) = self.conn.EnumDomainGroups(
                self.domain_handle, rh, max_size)

            self.assertEqual(0, num_entries)
            self.assertEqual(0, resume_handle)

            #
            # Enumerate through the domain groups one element at a time.
            # We should get all the results.
            #
            max_size = calc_max_size(1)
            actual = []
            (resume_handle, a, num_entries) = self.conn.EnumDomainGroups(
                self.domain_handle, 0, max_size)
            while resume_handle:
                self.assertEqual(1, num_entries)
                actual.append(a.entries[0])
                (resume_handle, a, num_entries) = self.conn.EnumDomainGroups(
                    self.domain_handle, resume_handle, max_size)
            if num_entries:
                actual.append(a.entries[0])

            self.assertEqual(len(expected), len(actual))
            check_results(expected, actual)

            #
            # Check that the cached results are being returned.
            # Obtain a new resume_handle and insert new entries into the DB.
            # As the entries were added after the results were cached they
            # should not show up in the returned results.
            #
            actual = []
            max_size = calc_max_size(1)
            (resume_handle, a, num_entries) = self.conn.EnumDomainGroups(
                self.domain_handle, 0, max_size)
            extra_dns = self.create_groups([1000, 1002, 1003, 1004])
            while resume_handle:
                self.assertEqual(1, num_entries)
                actual.append(a.entries[0])
                (resume_handle, a, num_entries) = self.conn.EnumDomainGroups(
                    self.domain_handle, resume_handle, max_size)
            if num_entries:
                actual.append(a.entries[0])

            self.assertEqual(len(expected), len(actual))
            check_results(expected, actual)

            #
            # Perform EnumDomainGroups, we should read the newly added groups.
            # As resume_handle is zero, the results will be read from disk.
            #
            max_size = calc_max_size(len(expected) + len(extra_dns) + 10)
            (resume_handle, actual, num_entries) = self.conn.EnumDomainGroups(
                self.domain_handle, 0, max_size)
            self.assertEqual(len(expected) + len(extra_dns), num_entries)

            #
            # Get a new expected result set by querying the database directly
            unfiltered01 = self.samdb.search(expression=select, attrs=attributes)
            filtered01 = self.filter_domain(unfiltered01)
            self.assertGreater(len(filtered01), len(expected))

            # Sort the expected results by rid
            expected01 = sorted(filtered01, key=rid)

            #
            # Now check that we read the new entries.
            #
            self.assertEqual(len(expected01), num_entries)
            check_results(expected01, actual.entries)

            #
            # Check that deleted results are handled correctly.
            # Obtain a new resume_handle and delete entries from the DB.
            #
            actual = []
            max_size = calc_max_size(1)
            (resume_handle, a, num_entries) = self.conn.EnumDomainGroups(
                self.domain_handle, 0, max_size)
            self.delete_dns(extra_dns)
            while resume_handle and num_entries:
                self.assertEqual(1, num_entries)
                actual.append(a.entries[0])
                (resume_handle, a, num_entries) = self.conn.EnumDomainGroups(
                    self.domain_handle, resume_handle, max_size)
            if num_entries:
                actual.append(a.entries[0])

            self.assertEqual(len(expected), len(actual))
            check_results(expected, actual)
        finally:
            # Remove the groups created above even when a check fails.
            self.delete_dns(dns)

    def test_EnumDomainUsers(self):
        """EnumDomainUsers returns, pages and caches the domain users.

        The expected users come from a direct samdb search for
        objectClass=user, passed through filter_domain() and sorted by RID.
        Besides the plain and paged enumerations, the test checks that
        continuing an enumeration with a non-zero resume handle uses the
        result set cached by the first call: users added afterwards do not
        appear and users deleted afterwards are skipped, while a fresh
        enumeration (resume handle 0) does see newly added users.
        """
        def check_results(expected, actual):
            for (e, a) in zip(expected, actual):
                self.assertIsInstance(a, samr.SamEntry)
                self.assertEqual(
                    str(e["sAMAccountName"]), str(a.name.string))

        # Create four users
        # to ensure that we have the minimum needed for the tests.
        dns = self.create_users([1, 2, 3, 4])
        try:
            #
            # Get the expected results by querying the samdb database
            # directly. We do this rather than use a list of expected results
            # as this runs with other tests so we do not have a known fixed
            # list of elements
            select = "(objectClass=user)"
            attributes = ["sAMAccountName", "objectSID", "userAccountControl"]
            unfiltered = self.samdb.search(expression=select, attrs=attributes)
            filtered = self.filter_domain(unfiltered)
            self.assertGreater(len(filtered), 4)

            # Sort the expected results by rid
            expected = sorted(filtered, key=rid)

            #
            # Perform EnumDomainUsers with max_size greater than required for
            # the expected number of results. We should get all the results.
            #
            max_size = calc_max_size(len(expected) + 10)
            (resume_handle, actual, num_entries) = self.conn.EnumDomainUsers(
                self.domain_handle, 0, 0, max_size)
            self.assertEqual(len(expected), num_entries)
            check_results(expected, actual.entries)

            #
            # Perform EnumDomainUsers with max size set so that the result
            # contains 4 entries.
            max_size = calc_max_size(4)
            (resume_handle, actual, num_entries) = self.conn.EnumDomainUsers(
                self.domain_handle, 0, 0, max_size)
            self.assertEqual(4, num_entries)
            check_results(expected[:4], actual.entries)

            #
            # Call with a resume_handle just past the last entry.
            # Should return no results and a resume handle of 0
            rh = len(expected)
            max_size = calc_max_size(1)
            # NOTE(review): self.handle is closed here while domain_handle is
            # still used below -- confirm this is intentional.
            self.conn.Close(self.handle)
            (resume_handle, a, num_entries) = self.conn.EnumDomainUsers(
                self.domain_handle, rh, 0, max_size)

            self.assertEqual(0, num_entries)
            self.assertEqual(0, resume_handle)

            #
            # Enumerate through the domain users one element at a time.
            # We should get all the results.
            #
            actual = []
            max_size = calc_max_size(1)
            (resume_handle, a, num_entries) = self.conn.EnumDomainUsers(
                self.domain_handle, 0, 0, max_size)
            while resume_handle:
                self.assertEqual(1, num_entries)
                actual.append(a.entries[0])
                (resume_handle, a, num_entries) = self.conn.EnumDomainUsers(
                    self.domain_handle, resume_handle, 0, max_size)
            if num_entries:
                actual.append(a.entries[0])

            self.assertEqual(len(expected), len(actual))
            check_results(expected, actual)

            #
            # Check that the cached results are being returned.
            # Obtain a new resume_handle and insert new entries into the DB.
            # As the entries were added after the results were cached they
            # should not show up in the returned results.
            #
            actual = []
            max_size = calc_max_size(1)
            (resume_handle, a, num_entries) = self.conn.EnumDomainUsers(
                self.domain_handle, 0, 0, max_size)
            extra_dns = self.create_users([1000, 1002, 1003, 1004])
            while resume_handle:
                self.assertEqual(1, num_entries)
                actual.append(a.entries[0])
                (resume_handle, a, num_entries) = self.conn.EnumDomainUsers(
                    self.domain_handle, resume_handle, 0, max_size)
            if num_entries:
                actual.append(a.entries[0])

            self.assertEqual(len(expected), len(actual))
            check_results(expected, actual)

            #
            # Perform EnumDomainUsers, we should read the newly added users.
            # As resume_handle is zero, the results will be read from disk.
            #
            max_size = calc_max_size(len(expected) + len(extra_dns) + 10)
            (resume_handle, actual, num_entries) = self.conn.EnumDomainUsers(
                self.domain_handle, 0, 0, max_size)
            self.assertEqual(len(expected) + len(extra_dns), num_entries)

            #
            # Get a new expected result set by querying the database directly
            unfiltered01 = self.samdb.search(expression=select, attrs=attributes)
            filtered01 = self.filter_domain(unfiltered01)
            self.assertGreater(len(filtered01), len(expected))

            # Sort the expected results by rid
            expected01 = sorted(filtered01, key=rid)

            #
            # Now check that we read the new entries.
            #
            self.assertEqual(len(expected01), num_entries)
            check_results(expected01, actual.entries)

            #
            # Check that deleted results are handled correctly.
            # Obtain a new resume_handle and delete entries from the DB.
            # We will not see the deleted entries in the result set, as
            # details need to be read from disk. Only the object GUID's are
            # cached.
            #
            actual = []
            max_size = calc_max_size(1)
            (resume_handle, a, num_entries) = self.conn.EnumDomainUsers(
                self.domain_handle, 0, 0, max_size)
            self.delete_dns(extra_dns)
            while resume_handle and num_entries:
                self.assertEqual(1, num_entries)
                actual.append(a.entries[0])
                (resume_handle, a, num_entries) = self.conn.EnumDomainUsers(
                    self.domain_handle, resume_handle, 0, max_size)
            if num_entries:
                actual.append(a.entries[0])

            self.assertEqual(len(expected), len(actual))
            check_results(expected, actual)
        finally:
            # Remove the users created above even when a check fails.
            self.delete_dns(dns)

    def test_DomGeneralInformation_num_users(self):
        """QueryDomainInfo's num_users matches an EnumDomainUsers count.

        All domain users are enumerated one page at a time (max_size from
        calc_max_size(1)) and the running total is compared with num_users
        from DomainGeneralInformation.
        """
        info = self.conn.QueryDomainInfo(
            self.domain_handle, DomainGeneralInformation)
        #
        # Enumerate through all the domain users and compare the number
        # returned against QueryDomainInfo they should be the same
        max_size = calc_max_size(1)
        (resume_handle, _, num_entries) = self.conn.EnumDomainUsers(
            self.domain_handle, 0, 0, max_size)
        count = num_entries
        while resume_handle:
            # Every non-final page must hold exactly one entry.
            self.assertEqual(1, num_entries)
            (resume_handle, _, num_entries) = self.conn.EnumDomainUsers(
                self.domain_handle, resume_handle, 0, max_size)
            count += num_entries

        self.assertEqual(count, info.num_users)

    def test_DomGeneralInformation_num_groups(self):
        """QueryDomainInfo's num_groups matches an EnumDomainGroups count.

        All domain groups are enumerated one page at a time (max_size from
        calc_max_size(1)) and the running total is compared with num_groups
        from DomainGeneralInformation.
        """
        info = self.conn.QueryDomainInfo(
            self.domain_handle, DomainGeneralInformation)
        #
        # Enumerate through all the domain groups and compare the number
        # returned against QueryDomainInfo they should be the same
        max_size = calc_max_size(1)
        (resume_handle, _, num_entries) = self.conn.EnumDomainGroups(
            self.domain_handle, 0, max_size)
        count = num_entries
        while resume_handle:
            # Every non-final page must hold exactly one entry.
            self.assertEqual(1, num_entries)
            (resume_handle, _, num_entries) = self.conn.EnumDomainGroups(
                self.domain_handle, resume_handle, max_size)
            count += num_entries

        self.assertEqual(count, info.num_groups)

    def test_DomGeneralInformation_num_aliases(self):
        """QueryDomainInfo's num_aliases matches an EnumDomainAliases count.

        All domain aliases are enumerated one page at a time (max_size from
        calc_max_size(1)) and the running total is compared with num_aliases
        from DomainGeneralInformation.
        """
        info = self.conn.QueryDomainInfo(
            self.domain_handle, DomainGeneralInformation)
        #
        # Enumerate through all the domain aliases and compare the number
        # returned against QueryDomainInfo they should be the same
        max_size = calc_max_size(1)
        (resume_handle, _, num_entries) = self.conn.EnumDomainAliases(
            self.domain_handle, 0, max_size)
        count = num_entries
        while resume_handle:
            # Every non-final page must hold exactly one entry.
            self.assertEqual(1, num_entries)
            (resume_handle, _, num_entries) = self.conn.EnumDomainAliases(
                self.domain_handle, resume_handle, max_size)
            count += num_entries

        self.assertEqual(count, info.num_aliases)
Example #28
0
class DsdbTests(TestCase):

    def setUp(self):
        """Open the test SamDB and create a throw-away user account.

        Sets self.lp, self.creds, self.session and self.samdb, creates a
        user under CN=Users whose name carries a random hex suffix, stores
        its DN in self.account_dn and registers delete_force() to remove it
        when the test is done.
        """
        super(DsdbTests, self).setUp()
        lp = samba.tests.env_loadparm()
        self.lp = lp
        creds = Credentials()
        self.creds = creds
        creds.guess(lp)
        session = system_session()
        self.session = session
        self.samdb = SamDB(session_info=session, credentials=creds, lp=lp)

        # Test user: a random 6 hex-digit suffix avoids name clashes.
        suffix = uuid.uuid4().hex[:6]
        user_name = "dsdb-user-" + suffix
        password = samba.generate_random_password(32, 32)
        description = "Test user for dsdb test"

        self.account_dn = ("cn=" + user_name + ",cn=Users,"
                           + self.samdb.domain_dn())
        self.samdb.newuser(username=user_name,
                           password=password,
                           description=description)
        # Cleanup (teardown): force-delete the user once the test is done.
        self.addCleanup(delete_force, self.samdb, self.account_dn)

    def test_get_oid_from_attrid(self):
        """Attribute ID 591614 maps to OID 1.2.840.113556.1.4.1790."""
        oid = self.samdb.get_oid_from_attid(591614)
        self.assertEqual(oid, "1.2.840.113556.1.4.1790")

    def test_error_replpropertymetadata(self):
        """Writing replPropertyMetaData with only a bumped version fails.

        The version of the description (attid 13) metadata entry is
        incremented and the blob is written back under control
        1.3.6.1.4.1.7165.4.3.14; ldb.LdbError is expected.
        """
        found = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                  base=self.account_dn,
                                  attrs=["replPropertyMetaData"])
        target = found[0]
        metadata = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                              target["replPropertyMetaData"][0])
        for meta in metadata.ctr.array:
            # attid 13 is the description attribute
            if meta.attid == 13:
                meta.version += 1
        blob = ndr_pack(metadata)

        msg = ldb.Message()
        msg.dn = target.dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            blob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        with self.assertRaises(ldb.LdbError):
            self.samdb.modify(msg, ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_error_replpropertymetadata_nochange(self):
        """Writing back an unmodified replPropertyMetaData blob fails.

        The blob is unpacked and re-packed without changes, then written
        under control 1.3.6.1.4.1.7165.4.3.14; ldb.LdbError is expected.
        """
        found = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                  base=self.account_dn,
                                  attrs=["replPropertyMetaData"])
        target = found[0]
        metadata = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                              target["replPropertyMetaData"][0])
        blob = ndr_pack(metadata)

        msg = ldb.Message()
        msg.dn = target.dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            blob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        with self.assertRaises(ldb.LdbError):
            self.samdb.modify(msg, ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_error_replpropertymetadata_allow_sort(self):
        """The unchanged-blob write succeeds with the extra .25 control.

        Same write as the "nochange" case, but with control
        1.3.6.1.4.1.7165.4.3.25 added alongside .14; the modify must not
        raise.
        """
        found = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                  base=self.account_dn,
                                  attrs=["replPropertyMetaData"])
        target = found[0]
        metadata = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                              target["replPropertyMetaData"][0])
        blob = ndr_pack(metadata)

        msg = ldb.Message()
        msg.dn = target.dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            blob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        controls = ["local_oid:1.3.6.1.4.1.7165.4.3.14:0",
                    "local_oid:1.3.6.1.4.1.7165.4.3.25:0"]
        self.samdb.modify(msg, controls)

    def test_twoatt_replpropertymetadata(self):
        """Replacing replPropertyMetaData plus another attribute fails.

        The description (attid 13) metadata entry gets a bumped version and
        a local_usn one past the current uSNChanged. The same modify also
        replaces description itself. Under control 1.3.6.1.4.1.7165.4.3.14,
        ldb.LdbError is expected.
        """
        found = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                  base=self.account_dn,
                                  attrs=["replPropertyMetaData", "uSNChanged"])
        target = found[0]
        metadata = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                              target["replPropertyMetaData"][0])
        for meta in metadata.ctr.array:
            # attid 13 is the description attribute
            if meta.attid == 13:
                meta.version += 1
                meta.local_usn = int(str(target["uSNChanged"])) + 1
        blob = ndr_pack(metadata)

        msg = ldb.Message()
        msg.dn = target.dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            blob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        msg["description"] = ldb.MessageElement(
            "new val", ldb.FLAG_MOD_REPLACE, "description")
        with self.assertRaises(ldb.LdbError):
            self.samdb.modify(msg, ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_set_replpropertymetadata(self):
        """A consistent replPropertyMetaData update is accepted.

        The description (attid 13) metadata entry gets a bumped version, and
        both local_usn and originating_usn are set one past the current
        uSNChanged. The blob is written under control
        1.3.6.1.4.1.7165.4.3.14 and the modify must not raise.
        """
        found = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                  base=self.account_dn,
                                  attrs=["replPropertyMetaData", "uSNChanged"])
        target = found[0]
        metadata = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                              target["replPropertyMetaData"][0])
        for meta in metadata.ctr.array:
            # attid 13 is the description attribute
            if meta.attid == 13:
                meta.version += 1
                meta.local_usn = int(str(target["uSNChanged"])) + 1
                meta.originating_usn = int(str(target["uSNChanged"])) + 1
        blob = ndr_pack(metadata)

        msg = ldb.Message()
        msg.dn = target.dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            blob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        self.samdb.modify(msg, ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_ok_get_attribute_from_attid(self):
        """Attribute ID 13 resolves to the "description" attribute."""
        self.assertEqual(self.samdb.get_attribute_from_attid(13), "description")

    def test_ko_get_attribute_from_attid(self):
        """An unknown attribute ID (11979) resolves to None."""
        self.assertIsNone(self.samdb.get_attribute_from_attid(11979))

    def test_get_attribute_replmetadata_version(self):
        """unicodePwd of the freshly created test user is at version 2."""
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["dn"])
        self.assertEqual(len(res), 1)
        dn = str(res[0].dn)
        self.assertEqual(self.samdb.get_attribute_replmetadata_version(dn, "unicodePwd"), 2)

    def test_set_attribute_replmetadata_version(self):
        """set_attribute_replmetadata_version() stores the given version.

        The description metadata version is raised by 2 and read back.
        """
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["dn"])
        self.assertEqual(len(res), 1)
        dn = str(res[0].dn)
        version = self.samdb.get_attribute_replmetadata_version(dn, "description")
        self.samdb.set_attribute_replmetadata_version(dn, "description", version + 2)
        self.assertEqual(self.samdb.get_attribute_replmetadata_version(dn, "description"), version + 2)

    def test_no_error_on_invalid_control(self):
        """A non-critical unimplemented control must not make search fail.

        The control DSDB_CONTROL_INVALID_NOT_IMPLEMENTED is passed with a
        trailing ":0"; any ldb.LdbError is reported as a test failure.
        """
        controls = ["local_oid:%s:0"
                    % dsdb.DSDB_CONTROL_INVALID_NOT_IMPLEMENTED]
        try:
            self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                              base=self.account_dn,
                              attrs=["replPropertyMetaData"],
                              controls=controls)
        except ldb.LdbError:
            self.fail("Should have not raised an exception")

    def test_error_on_invalid_critical_control(self):
        """A critical unimplemented control must make the search fail.

        The control DSDB_CONTROL_INVALID_NOT_IMPLEMENTED is passed with a
        trailing ":1". The search must raise ldb.LdbError with code
        ldb.ERR_UNSUPPORTED_CRITICAL_EXTENSION; any other code, or no error
        at all, is a failure.
        """
        controls = ["local_oid:%s:1"
                    % dsdb.DSDB_CONTROL_INVALID_NOT_IMPLEMENTED]
        try:
            self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                              base=self.account_dn,
                              attrs=["replPropertyMetaData"],
                              controls=controls)
        except ldb.LdbError as e:
            (code, estr) = e.args
            if code != ldb.ERR_UNSUPPORTED_CRITICAL_EXTENSION:
                # Exceptions are not subscriptable on Python 3 (e[1] would
                # raise TypeError), so use the message unpacked from e.args.
                self.fail("Got %s should have got ERR_UNSUPPORTED_CRITICAL_EXTENSION"
                          % estr)
        else:
            # Previously a search that silently ignored the critical control
            # let this test pass.
            self.fail("Search with an unsupported critical control should "
                      "have raised ERR_UNSUPPORTED_CRITICAL_EXTENSION")

    # Allocate a unique RID for use in the objectSID tests.
    #
    def allocate_rid(self):
        """Allocate a fresh RID from samdb inside a transaction.

        If allocation raises, the transaction is cancelled and the exception
        re-raised; otherwise the transaction is committed.

        :return: the allocated RID converted to str.
        """
        self.samdb.transaction_start()
        try:
            new_rid = self.samdb.allocate_rid()
        except BaseException:
            self.samdb.transaction_cancel()
            raise
        else:
            self.samdb.transaction_commit()
        return str(new_rid)

    # Ensure that duplicate objectSID's are permitted for foreign security
    # principals.
    #
    def test_duplicate_objectSIDs_allowed_on_foreign_security_principals(self):
        """A foreignSecurityPrincipal can be re-added under the same SID DN.

        Without the provision control, adding the object is rejected: with
        ERR_OBJECT_CLASS_VIOLATION (missing required attribute) when no
        objectSid is given, and with ERR_UNWILLING_TO_PERFORM when one is.
        With the "provision:0" control the object can be added, deleted and
        added again.
        """
        #
        # We need to build a foreign security principal SID
        # i.e a SID not in the current domain.
        #
        dom_sid = self.samdb.get_domain_sid()
        # Swap the last digit of the domain SID so the prefix differs from
        # the local domain, then append a RID.
        if str(dom_sid).endswith("0"):
            c = "9"
        else:
            c = "0"
        sid_str = str(dom_sid)[:-1] + c + "-1000"
        sid = ndr_pack(security.dom_sid(sid_str))
        basedn = self.samdb.get_default_basedn()
        dn = "CN=%s,CN=ForeignSecurityPrincipals,%s" % (sid_str, basedn)

        # Ensure the object does not outlive the test if a check below fails
        # part-way through (delete_force tolerates an already-deleted DN).
        self.addCleanup(delete_force, self.samdb, dn)

        #
        # First without control
        #

        try:
            self.samdb.add({
                "dn": dn,
                "objectClass": "foreignSecurityPrincipal"})
            self.fail("No exception should get ERR_OBJECT_CLASS_VIOLATION")
        except ldb.LdbError as e:
            (code, msg) = e.args
            self.assertEqual(code, ldb.ERR_OBJECT_CLASS_VIOLATION, str(e))
            werr = "%08X" % werror.WERR_DS_MISSING_REQUIRED_ATT
            self.assertIn(werr, msg)

        try:
            self.samdb.add({
                "dn": dn,
                "objectClass": "foreignSecurityPrincipal",
                "objectSid": sid})
            self.fail("No exception should get ERR_UNWILLING_TO_PERFORM")
        except ldb.LdbError as e:
            (code, msg) = e.args
            self.assertEqual(code, ldb.ERR_UNWILLING_TO_PERFORM, str(e))
            werr = "%08X" % werror.WERR_DS_ILLEGAL_MOD_OPERATION
            self.assertIn(werr, msg)

        #
        # We need to use the provision control
        # in order to add foreignSecurityPrincipal
        # objects
        #

        controls = ["provision:0"]
        self.samdb.add({
            "dn": dn,
            "objectClass": "foreignSecurityPrincipal"},
            controls=controls)

        self.samdb.delete(dn)

        try:
            self.samdb.add({
                "dn": dn,
                "objectClass": "foreignSecurityPrincipal"},
                controls=controls)
        except ldb.LdbError as e:
            (code, msg) = e.args
            self.fail("Got unexpected exception %d - %s "
                      % (code, msg))

        # cleanup
        self.samdb.delete(dn)

    def _test_foreignSecurityPrincipal(self, obj_class, fpo_attr):
        """Check how <SID=...> references in fpo_attr are handled.

        An object of obj_class is created, and references to SIDs that do
        not exist yet are added to its fpo_attr attribute:

        * a SID in the local domain -> ERR_UNWILLING_TO_PERFORM with
          WERR_DS_INVALID_GROUP_TYPE in the message;
        * a SID under S-1-5-32 -> ERR_NO_SUCH_OBJECT with
          WERR_NO_SUCH_MEMBER in the message;
        * a SID under S-1-5 -> accepted. Afterwards exactly one object
          carries that objectSid (presumably an auto-created
          foreignSecurityPrincipal). It is gone once it and the test object
          have been deleted.

        :param obj_class: objectClass of the object to create.
        :param fpo_attr: attribute that receives the <SID=...> values.
        """
        dom_sid = self.samdb.get_domain_sid()
        lsid_str = str(dom_sid) + "-4294967294"
        bsid_str = "S-1-5-32-4294967294"
        fsid_str = "S-1-5-4294967294"
        basedn = self.samdb.get_default_basedn()
        cn = "dsdb_test_fpo"
        dn_str = "cn=%s,cn=Users,%s" % (cn, basedn)
        dn = ldb.Dn(self.samdb, dn_str)

        def find_sid(sid_str):
            # All objects below basedn whose objectSid equals sid_str.
            return self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                     base=basedn,
                                     expression="(objectSid=%s)" % sid_str,
                                     attrs=[])

        def add_sid_ref(sid_str):
            # Add "<SID=sid_str>" to fpo_attr of the test object.
            msg = ldb.Message()
            msg.dn = dn
            msg[fpo_attr] = ldb.MessageElement("<SID=%s>" % sid_str,
                                               ldb.FLAG_MOD_ADD,
                                               fpo_attr)
            self.samdb.modify(msg)

        # None of the test SIDs may exist before we start.
        for sid_str in (lsid_str, bsid_str, fsid_str):
            self.assertEqual(len(find_sid(sid_str)), 0)

        self.addCleanup(delete_force, self.samdb, dn_str)

        self.samdb.add({
            "dn": dn_str,
            "objectClass": obj_class})

        try:
            add_sid_ref(lsid_str)
            self.fail("No exception should get LDB_ERR_UNWILLING_TO_PERFORM")
        except ldb.LdbError as e:
            (code, errmsg) = e.args
            self.assertEqual(code, ldb.ERR_UNWILLING_TO_PERFORM, str(e))
            werr = "%08X" % werror.WERR_DS_INVALID_GROUP_TYPE
            self.assertTrue(werr in errmsg, errmsg)

        try:
            add_sid_ref(bsid_str)
            self.fail("No exception should get LDB_ERR_NO_SUCH_OBJECT")
        except ldb.LdbError as e:
            (code, errmsg) = e.args
            self.assertEqual(code, ldb.ERR_NO_SUCH_OBJECT, str(e))
            werr = "%08X" % werror.WERR_NO_SUCH_MEMBER
            self.assertTrue(werr in errmsg, errmsg)

        try:
            add_sid_ref(fsid_str)
        except ldb.LdbError as e:
            # Report the actual error instead of discarding it.
            self.fail("Should have not raised an exception: %s" % e)

        res = find_sid(fsid_str)
        self.assertEqual(len(res), 1)
        self.samdb.delete(res[0].dn)
        self.samdb.delete(dn)
        self.assertEqual(len(find_sid(fsid_str)), 0)

    def test_foreignSecurityPrincipal_member(self):
        """Run the SID-reference checks for 'member' on a group."""
        self._test_foreignSecurityPrincipal(obj_class="group",
                                            fpo_attr="member")

    def test_foreignSecurityPrincipal_MembersForAzRole(self):
        """Run the SID-reference checks for 'msDS-MembersForAzRole'."""
        self._test_foreignSecurityPrincipal(obj_class="msDS-AzRole",
                                            fpo_attr="msDS-MembersForAzRole")

    def test_foreignSecurityPrincipal_NeverRevealGroup(self):
        """Run the SID-reference checks for 'msDS-NeverRevealGroup'."""
        self._test_foreignSecurityPrincipal(obj_class="computer",
                                            fpo_attr="msDS-NeverRevealGroup")

    def test_foreignSecurityPrincipal_RevealOnDemandGroup(self):
        """Run the SID-reference checks for 'msDS-RevealOnDemandGroup'."""
        self._test_foreignSecurityPrincipal(obj_class="computer",
                                            fpo_attr="msDS-RevealOnDemandGroup")

    def _test_fail_foreignSecurityPrincipal(self, obj_class, fpo_attr,
                                            msg_exp, lerr_exp, werr_exp,
                                            allow_reference=True):
        """Check that fpo_attr refuses <SID=...> references to unknown SIDs.

        Two objects of obj_class are created. Adding a <SID=...> value to
        fpo_attr of the first object must fail for a non-existent SID in
        the local domain, under S-1-5-32 and under S-1-5. Each failure must
        carry error code lerr_exp, and its message must contain werr_exp
        formatted as %08X. Then a plain DN reference to the second object
        is added. It must succeed when allow_reference is true and fail in
        the same way otherwise.

        :param obj_class: objectClass of the two objects to create.
        :param fpo_attr: attribute under test.
        :param msg_exp: readable name of the expected error, used only in
            failure messages.
        :param lerr_exp: expected LDB error code.
        :param werr_exp: expected WERROR value, searched for in the message.
        :param allow_reference: whether a DN reference to an existing object
            is expected to be accepted.
        """
        dom_sid = self.samdb.get_domain_sid()
        lsid_str = str(dom_sid) + "-4294967294"
        bsid_str = "S-1-5-32-4294967294"
        fsid_str = "S-1-5-4294967294"
        basedn = self.samdb.get_default_basedn()
        cn1 = "dsdb_test_fpo1"
        dn1_str = "cn=%s,cn=Users,%s" % (cn1, basedn)
        dn1 = ldb.Dn(self.samdb, dn1_str)
        cn2 = "dsdb_test_fpo2"
        dn2_str = "cn=%s,cn=Users,%s" % (cn2, basedn)
        dn2 = ldb.Dn(self.samdb, dn2_str)
        werr = "%08X" % werr_exp
        test_sids = (lsid_str, bsid_str, fsid_str)

        # None of the test SIDs may exist before we start.
        for sid_str in test_sids:
            res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                    base=basedn,
                                    expression="(objectSid=%s)" % sid_str,
                                    attrs=[])
            self.assertEqual(len(res), 0)

        self.addCleanup(delete_force, self.samdb, dn1_str)
        self.addCleanup(delete_force, self.samdb, dn2_str)

        self.samdb.add({
            "dn": dn1_str,
            "objectClass": obj_class})

        self.samdb.add({
            "dn": dn2_str,
            "objectClass": obj_class})

        def add_ref(value):
            # Add value to fpo_attr of the first test object.
            msg = ldb.Message()
            msg.dn = dn1
            msg[fpo_attr] = ldb.MessageElement(value,
                                               ldb.FLAG_MOD_ADD,
                                               fpo_attr)
            self.samdb.modify(msg)

        def check_error(e):
            # The failure must carry the expected LDB code and WERROR text.
            (code, errmsg) = e.args
            self.assertEqual(code, lerr_exp, str(e))
            self.assertTrue(werr in errmsg, errmsg)

        # Every SID reference must be refused, whichever domain it is in.
        for sid_str in test_sids:
            try:
                add_ref("<SID=%s>" % sid_str)
                self.fail("No exception should get %s" % msg_exp)
            except ldb.LdbError as e:
                check_error(e)

        # A reference to an existing object, by DN.
        try:
            add_ref("%s" % dn2)
        except ldb.LdbError as e:
            if allow_reference:
                self.fail("Should have not raised an exception: %s" % e)
            check_error(e)
        else:
            if not allow_reference:
                self.fail("No exception should get %s" % msg_exp)

        self.samdb.delete(dn2)
        self.samdb.delete(dn1)

    def test_foreignSecurityPrincipal_NonMembers(self):
        """'msDS-NonMembers' on a group must refuse SID and DN references."""
        self._test_fail_foreignSecurityPrincipal(
            obj_class="group",
            fpo_attr="msDS-NonMembers",
            msg_exp="LDB_ERR_UNWILLING_TO_PERFORM/WERR_NOT_SUPPORTED",
            lerr_exp=ldb.ERR_UNWILLING_TO_PERFORM,
            werr_exp=werror.WERR_NOT_SUPPORTED,
            allow_reference=False)

    def test_foreignSecurityPrincipal_HostServiceAccount(self):
        """'msDS-HostServiceAccount' on a computer refuses SID references."""
        self._test_fail_foreignSecurityPrincipal(
            obj_class="computer",
            fpo_attr="msDS-HostServiceAccount",
            msg_exp="LDB_ERR_CONSTRAINT_VIOLATION/WERR_DS_NAME_REFERENCE_INVALID",
            lerr_exp=ldb.ERR_CONSTRAINT_VIOLATION,
            werr_exp=werror.WERR_DS_NAME_REFERENCE_INVALID)

    def test_foreignSecurityPrincipal_manager(self):
        """'manager' on a user refuses SID references to unknown SIDs."""
        self._test_fail_foreignSecurityPrincipal(
            obj_class="user",
            fpo_attr="manager",
            msg_exp="LDB_ERR_CONSTRAINT_VIOLATION/WERR_DS_NAME_REFERENCE_INVALID",
            lerr_exp=ldb.ERR_CONSTRAINT_VIOLATION,
            werr_exp=werror.WERR_DS_NAME_REFERENCE_INVALID)

    #
    # Duplicate objectSIDs should not be permitted for SIDs in the local
    # domain. The test sequence is: add an object, delete it, then attempt to
    # re-add it; this should fail with a constraint violation.
    #
    def test_duplicate_objectSIDs_not_allowed_on_local_objects(self):
        """Re-adding a deleted local object with the same objectSID must fail.

        A user with a freshly allocated local-domain SID is added and then
        deleted. Adding it again with the same SID must give
        ERR_CONSTRAINT_VIOLATION.
        """
        dom_sid = self.samdb.get_domain_sid()
        rid = self.allocate_rid()
        sid_str = str(dom_sid) + "-" + rid
        sid = ndr_pack(security.dom_sid(sid_str))
        basedn = self.samdb.get_default_basedn()
        cn = "dsdb_test_01"
        dn = "cn=%s,cn=Users,%s" % (cn, basedn)

        # Remove the object even if the re-add unexpectedly succeeds or an
        # error occurs before the explicit delete below.
        self.addCleanup(delete_force, self.samdb, dn)

        self.samdb.add({
            "dn": dn,
            "objectClass": "user",
            "objectSID": sid})
        self.samdb.delete(dn)

        try:
            self.samdb.add({
                "dn": dn,
                "objectClass": "user",
                "objectSID": sid})
            self.fail("No exception should get LDB_ERR_CONSTRAINT_VIOLATION")
        except ldb.LdbError as e:
            (code, errmsg) = e.args
            if code != ldb.ERR_CONSTRAINT_VIOLATION:
                self.fail("Got %d - %s should have got "
                          "LDB_ERR_CONSTRAINT_VIOLATION"
                          % (code, errmsg))

    def test_linked_vs_non_linked_reference(self):
        """Compare reference checks on the 'manager' and 'assistant' attributes.

        A user 'reference_removed' is created and then deleted. Referencing
        it by SID or GUID through the linked attribute 'manager' must fail
        with ERR_CONSTRAINT_VIOLATION / WERR_DS_NAME_REFERENCE_INVALID. The
        non-linked attribute 'assistant' accepts the same references and
        allows them to be deleted again. 'assistant' must still refuse a
        DN, SID or GUID that never existed.
        """
        basedn = self.samdb.get_default_basedn()
        kept_dn_str = "cn=reference_kept,cn=Users,%s" % (basedn)
        removed_dn_str = "cn=reference_removed,cn=Users,%s" % (basedn)
        dom_sid = self.samdb.get_domain_sid()
        none_sid_str = str(dom_sid) + "-4294967294"
        none_guid_str = "afafafaf-fafa-afaf-fafa-afafafafafaf"

        self.addCleanup(delete_force, self.samdb, kept_dn_str)
        self.addCleanup(delete_force, self.samdb, removed_dn_str)

        self.samdb.add({
            "dn": kept_dn_str,
            "objectClass": "user"})
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=kept_dn_str,
                                attrs=["objectGUID", "objectSID"])
        self.assertEqual(len(res), 1)
        kept_dn = res[0].dn

        self.samdb.add({
            "dn": removed_dn_str,
            "objectClass": "user"})
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=removed_dn_str,
                                attrs=["objectGUID", "objectSID"])
        self.assertEqual(len(res), 1)
        removed_guid = ndr_unpack(misc.GUID, res[0]["objectGUID"][0])
        removed_sid = ndr_unpack(security.dom_sid, res[0]["objectSid"][0])
        self.samdb.delete(removed_dn_str)

        def modify_kept(attr, value, flag):
            # Apply a single-value modification of attr on the kept user.
            msg = ldb.Message()
            msg.dn = kept_dn
            msg[attr] = ldb.MessageElement(value, flag, attr)
            self.samdb.modify(msg)

        def assert_reference_invalid(attr, value):
            # Adding value to attr must be refused as an invalid reference.
            try:
                modify_kept(attr, value, ldb.FLAG_MOD_ADD)
                self.fail("No exception should get LDB_ERR_CONSTRAINT_VIOLATION")
            except ldb.LdbError as e:
                (code, errmsg) = e.args
                self.assertEqual(code, ldb.ERR_CONSTRAINT_VIOLATION, str(e))
                werr = "%08X" % werror.WERR_DS_NAME_REFERENCE_INVALID
                self.assertTrue(werr in errmsg, errmsg)

        #
        # First try the linked attribute 'manager'
        # by GUID and SID
        #
        assert_reference_invalid("manager", "<SID=%s>" % removed_sid)
        assert_reference_invalid("manager", "<GUID=%s>" % removed_guid)

        #
        # Try the non-linked attribute 'assistant'
        # by GUID and SID, which should work.
        #
        for value in ("<SID=%s>" % removed_sid,
                      "<GUID=%s>" % removed_guid):
            modify_kept("assistant", value, ldb.FLAG_MOD_ADD)
            modify_kept("assistant", value, ldb.FLAG_MOD_DELETE)

        #
        # Finally try the non-linked attribute 'assistant'
        # but with non existing GUID, SID, DN
        #
        assert_reference_invalid("assistant", "CN=NoneNone,%s" % (basedn))
        assert_reference_invalid("assistant", "<SID=%s>" % none_sid_str)
        assert_reference_invalid("assistant", "<GUID=%s>" % none_guid_str)

        self.samdb.delete(kept_dn)

    def test_normalize_dn_in_domain_full(self):
        """A fully-qualified DN given as a string comes back unchanged."""
        domain = self.samdb.domain_dn()
        expected = ldb.Dn(self.samdb, "CN=Users")
        expected.add_base(domain)
        qualified_str = str(expected)

        # That is, no change
        self.assertEqual(expected,
                         self.samdb.normalize_dn_in_domain(qualified_str))

    def test_normalize_dn_in_domain_part(self):
        """A relative DN string gets the domain DN appended."""
        domain = self.samdb.domain_dn()
        relative = "CN=Users"

        expected = ldb.Dn(self.samdb, relative)
        expected.add_base(domain)

        # That is, the domain DN appended
        normalized = self.samdb.normalize_dn_in_domain(relative)
        self.assertEqual(expected, normalized)

    def test_normalize_dn_in_domain_full_dn(self):
        """A fully-qualified ldb.Dn comes back unchanged."""
        domain = self.samdb.domain_dn()
        qualified = ldb.Dn(self.samdb, "CN=Users")
        qualified.add_base(domain)

        # That is, no change
        self.assertEqual(qualified,
                         self.samdb.normalize_dn_in_domain(qualified))

    def test_normalize_dn_in_domain_part_dn(self):
        """A relative ldb.Dn gets the domain DN appended."""
        domain = self.samdb.domain_dn()
        relative = ldb.Dn(self.samdb, "CN=Users")

        # That is, the domain DN appended
        expected = ldb.Dn(self.samdb, "%s,%s" % (relative, domain))
        self.assertEqual(expected,
                         self.samdb.normalize_dn_in_domain(relative))