Exemple #1
0
class ManyLDAPTest(samba.tests.TestCase):
    """Search a subtree holding a large number of OUs.

    setUp() creates a randomly named parent OU with 2000 child OUs below
    the domain DN; tearDown() removes the whole subtree again.
    """

    def setUp(self):
        """Connect to the directory and create the parent OU and its children."""
        super(ManyLDAPTest, self).setUp()
        self.ldb = SamDB(url,
                         credentials=creds,
                         session_info=system_session(lp),
                         lp=lp)
        self.base_dn = self.ldb.domain_dn()
        # Zero-padded random suffix so the OU name varies between runs.
        self.OU_NAME_MANY = "many_ou" + format(random.randint(0, 99999), "05")
        self.ou_dn = ldb.Dn(
            self.ldb, "ou=" + self.OU_NAME_MANY + "," + str(self.base_dn))

        # Remove anything already present under the same DN.
        samba.tests.delete_force(self.ldb,
                                 self.ou_dn,
                                 controls=['tree_delete:1'])

        self.ldb.add({
            "dn": self.ou_dn,
            "objectclass": "organizationalUnit",
            "ou": self.OU_NAME_MANY
        })

        for x in range(2000):
            ou_name = self.OU_NAME_MANY + str(x)
            self.ldb.add({
                "dn": "ou=" + ou_name + "," + str(self.ou_dn),
                "objectclass": "organizationalUnit",
                "ou": ou_name
            })

    def tearDown(self):
        """Delete the parent OU together with everything below it."""
        samba.tests.delete_force(self.ldb,
                                 self.ou_dn,
                                 controls=['tree_delete:1'])
        super(ManyLDAPTest, self).tearDown()

    def test_unindexed_iterator_search(self):
        """Testing a search for all the OUs.

        Needed to test that more than IOV_MAX responses can be returned
        """
        if not url.startswith("ldap"):
            self.fail(msg="This test is only valid on ldap")

        count = 0
        search1 = self.ldb.search_iterator(
            base=self.ou_dn,
            expression="(ou=" + self.OU_NAME_MANY + "*)",
            scope=ldb.SCOPE_SUBTREE,
            attrs=["objectGUID", "samAccountName"])

        for reply in search1:
            self.assertIsInstance(reply, ldb.Message)
            count += 1

        # The return value is not inspected; the call itself is kept.
        # NOTE(review): presumably result() raises if the search finished
        # with an error - confirm against the ldb bindings.
        search1.result()

        # Check we got everything: the parent OU plus its 2000 children
        self.assertEqual(count, 2001)
Exemple #2
0
    def test_dont_create_db_existing_tdb_file(self):
        """Open an already populated tdb file again and read the record back.

        A record is added through a first SamDB handle opened with flags=0;
        a second handle opened on the same file with default flags must
        return the same "cn" value.
        """
        existing_name = self.tempdir + "/existing.db"
        initial = SamDB(url="tdb://" + existing_name, flags=0)
        dn = "dn=,cn=test_dont_create_db_existing_tdb_file"
        initial.add({"dn": dn, "cn": "test_dont_create_db_existing_tdb_file"})

        # assertEqual replaces the deprecated assertEquals alias, which was
        # removed from unittest in Python 3.12.
        cn = initial.searchone("cn", dn)
        self.assertEqual(b"test_dont_create_db_existing_tdb_file", cn)

        second = SamDB(url="tdb://" + existing_name)
        cn = second.searchone("cn", dn)
        self.assertEqual(b"test_dont_create_db_existing_tdb_file", cn)
Exemple #3
0
    def test_add_replicated_objects(self):
        """Adds on self.samdb must be answered with a usable referral.

        For each sample object the add on self.samdb is expected to raise
        an LdbError with code ldb.ERR_REFERRAL.  The ldap:// URL found in
        the error message is then opened and the same object is added and
        deleted there.
        """
        for o in (
            {
                'dn': "ou=%s1,%s" % (self.tag, self.base_dn),
                "objectclass": "organizationalUnit"
            },
            {
                'dn': "cn=%s2,%s" % (self.tag, self.base_dn),
                "objectclass": "user"
            },
            {
                'dn': "cn=%s3,%s" % (self.tag, self.base_dn),
                "objectclass": "group"
            },
            {
                'dn': "cn=%s4,%s" % (self.tag, self.service),
                "objectclass": "NTDSConnection",
                'enabledConnection': 'TRUE',
                'fromServer': self.base_dn,
                'options': '0'
            },
        ):
            try:
                self.samdb.add(o)
                self.fail("Failed to fail to add %s" % o['dn'])
            except ldb.LdbError as e:
                (ecode, emsg) = e.args
                if ecode != ldb.ERR_REFERRAL:
                    print(emsg)
                    self.fail("Adding %s: ldb error: %s %s, wanted referral" %
                              (o['dn'], ecode, emsg))

                # self.fail() raises, so from here on we hold a referral.
                m = re.search(r'(ldap://[^>]+)>', emsg)
                if m is None:
                    self.fail("referral seems not to refer to anything")
                address = m.group(1)

                try:
                    tmpdb = SamDB(address,
                                  credentials=CREDS,
                                  session_info=system_session(LP),
                                  lp=LP)
                    tmpdb.add(o)
                    tmpdb.delete(o['dn'])
                except ldb.LdbError as referred_err:
                    # Use a distinct name so the outer exception is not
                    # shadowed, and report why the referred server refused.
                    self.fail("couldn't modify referred location %s: %s" %
                              (address, referred_err))

                # NOTE(review): address still begins with the "ldap://"
                # scheme captured by the regex above, so this prefix test
                # against the DNS domain name looks like it can never
                # match - confirm what comparison was intended.
                if address.lower().startswith(
                        self.samdb.domain_dns_name()):
                    self.fail(
                        "referral address did not give a specific DC")
Exemple #4
0
    def test_dont_create_db_existing_tdb_file(self):
        """Check that reopening an existing tdb file keeps its contents.

        One record is written through a handle opened with flags=0, then
        read back both through that handle and through a second handle
        opened on the same file with default flags.
        """
        existing_name = self.tempdir + "/existing.db"
        initial = SamDB(url="tdb://" + existing_name, flags=0)
        dn = "dn=,cn=test_dont_create_db_existing_tdb_file"
        initial.add({
            "dn": dn,
            "cn": "test_dont_create_db_existing_tdb_file"
        })

        # assertEquals is a deprecated alias that no longer exists in
        # Python 3.12's unittest; assertEqual is the supported spelling.
        cn = initial.searchone("cn", dn)
        self.assertEqual(b"test_dont_create_db_existing_tdb_file", cn)

        second = SamDB(url="tdb://" + existing_name)
        cn = second.searchone("cn", dn)
        self.assertEqual(b"test_dont_create_db_existing_tdb_file", cn)
Exemple #5
0
    def test_add_replicated_objects(self):
        """Expect a referral for each add, then follow it.

        Every sample object is added through self.samdb, which must fail
        with ldb.ERR_REFERRAL.  The ldap:// URL extracted from the error
        text is opened as a new SamDB, where the object is added and then
        deleted again.
        """
        for o in (
                {
                    'dn': "ou=%s1,%s" % (self.tag, self.base_dn),
                    "objectclass": "organizationalUnit"
                },
                {
                    'dn': "cn=%s2,%s" % (self.tag, self.base_dn),
                    "objectclass": "user"
                },
                {
                    'dn': "cn=%s3,%s" % (self.tag, self.base_dn),
                    "objectclass": "group"
                },
                {
                    'dn': "cn=%s4,%s" % (self.tag, self.service),
                    "objectclass": "NTDSConnection",
                    'enabledConnection': 'TRUE',
                    'fromServer': self.base_dn,
                    'options': '0'
                },
        ):
            try:
                self.samdb.add(o)
                self.fail("Failed to fail to add %s" % o['dn'])
            except ldb.LdbError as e:
                (ecode, emsg) = e.args
                if ecode != ldb.ERR_REFERRAL:
                    print(emsg)
                    self.fail("Adding %s: ldb error: %s %s, wanted referral" %
                              (o['dn'], ecode, emsg))

                # Only referrals reach this point; self.fail() raised above
                # for any other error code.
                m = re.search(r'(ldap://[^>]+)>', emsg)
                if m is None:
                    self.fail("referral seems not to refer to anything")
                address = m.group(1)

                try:
                    tmpdb = SamDB(address, credentials=CREDS,
                                  session_info=system_session(LP), lp=LP)
                    tmpdb.add(o)
                    tmpdb.delete(o['dn'])
                except ldb.LdbError as referred_err:
                    # Keep the outer exception name intact and include the
                    # reason reported by the referred server.
                    self.fail("couldn't modify referred location %s: %s" %
                              (address, referred_err))

                # NOTE(review): the regex capture keeps the "ldap://"
                # scheme, so this startswith() against the DNS domain name
                # appears unable to match - confirm the intended check.
                if address.lower().startswith(self.samdb.domain_dns_name()):
                    self.fail("referral address did not give a specific DC")
Exemple #6
0
    def netlogon(self):
        """Open a schannel netlogon connection with a throw-away machine account.

        A computer account named MACHINE_NAME is created under the domain
        DN with a random password, a sealed schannel connection to
        $SERVER is opened with it, and the account is deleted afterwards.
        The LDAP connection goes to $SERVER_IP.
        """
        server = os.environ["SERVER"]
        host = os.environ["SERVER_IP"]
        lp = self.get_loadparm()

        credentials = self.get_credentials()

        session = system_session()
        samdb = SamDB(url="ldap://%s" % host,
                      session_info=session,
                      credentials=credentials,
                      lp=lp)
        machine_pass = samba.generate_random_password(32, 32)
        machine_name = MACHINE_NAME
        machine_dn = "cn=%s,%s" % (machine_name, samdb.domain_dn())

        # Start from a clean slate in case an earlier run left the account.
        delete_force(samdb, machine_dn)

        # unicodePwd is given as the quoted password in UTF-16-LE.
        utf16pw = ('"%s"' % get_string(machine_pass)).encode('utf-16-le')
        samdb.add({
            "dn": machine_dn,
            "objectclass": "computer",
            "sAMAccountName": "%s$" % machine_name,
            "userAccountControl":
                str(UF_WORKSTATION_TRUST_ACCOUNT | UF_PASSWD_NOTREQD),
            "unicodePwd": utf16pw
        })

        try:
            machine_creds = Credentials()
            machine_creds.guess(lp)
            machine_creds.set_secure_channel_type(SEC_CHAN_WKSTA)
            machine_creds.set_kerberos_state(DONT_USE_KERBEROS)
            machine_creds.set_password(machine_pass)
            machine_creds.set_username(machine_name + "$")
            machine_creds.set_workstation(machine_name)

            netlogon.netlogon("ncacn_ip_tcp:%s[schannel,seal]" % server, lp,
                              machine_creds)
        finally:
            # Remove the account even when the connection attempt raises.
            delete_force(samdb, machine_dn)
Exemple #7
0
class AuthLogTestsSamLogon(samba.tests.auth_log_base.AuthLogTestBase):
    """Check the messages logged for an NTLMv2 network SamLogon.

    A workstation account is created, a schannel netlogon connection is
    opened over ncalrpc with it, and a logon is submitted through
    netr_LogonSamLogonEx.  The collected messages are then handed to a
    check function.
    """

    def setUp(self):
        """Open the SamDB and record the names used for the machine account."""
        super(AuthLogTestsSamLogon, self).setUp()
        self.lp = samba.tests.env_loadparm()
        self.creds = Credentials()

        self.session = system_session()
        self.ldb = SamDB(
            session_info=self.session,
            credentials=self.creds,
            lp=self.lp)

        self.domain = os.environ["DOMAIN"]
        self.netbios_name = "SamLogonTest"
        self.machinepass = "******"
        self.remoteAddress = AS_SYSTEM_MAGIC_PATH_TOKEN
        self.base_dn = self.ldb.domain_dn()
        self.samlogon_dn = ("cn=%s,cn=users,%s" %
                            (self.netbios_name, self.base_dn))

    def tearDown(self):
        """Run the base-class teardown, then delete the machine account."""
        super(AuthLogTestsSamLogon, self).tearDown()
        delete_force(self.ldb, self.samlogon_dn)

    def _test_samlogon(self, binding, creds, checkFunction):
        """Perform a SamLogon and pass the resulting messages to checkFunction.

        :param binding: extra binding option appended after "schannel";
            a false value selects plain "[schannel]"
        :param creds: credentials used to compute the NTLMv2 response and
            to fill in the logon identity
        :param checkFunction: callable invoked with the messages returned
            by waitForMessages()
        """

        def isLastExpectedMessage(msg):
            # The "type" test comes first so that the "Authentication"
            # key is only read from Authentication messages.
            return (
                msg["type"] == "Authentication" and
                msg["Authentication"]["serviceDescription"] == "SamLogon" and
                msg["Authentication"]["authDescription"] == "network" and
                msg["Authentication"]["passwordType"] == "NTLMv2" and
                (msg["Authentication"]["eventId"] ==
                    EVT_ID_SUCCESSFUL_LOGON) and
                (msg["Authentication"]["logonType"] == EVT_LOGON_NETWORK))

        if binding:
            binding = "[schannel,%s]" % binding
        else:
            binding = "[schannel]"

        # unicodePwd is given as the quoted password in UTF-16-LE.
        utf16pw = text_type('"' + self.machinepass + '"').encode('utf-16-le')
        self.ldb.add({
            "dn": self.samlogon_dn,
            "objectclass": "computer",
            "sAMAccountName": "%s$" % self.netbios_name,
            "userAccountControl":
                str(UF_WORKSTATION_TRUST_ACCOUNT | UF_PASSWD_NOTREQD),
            "unicodePwd": utf16pw})

        machine_creds = Credentials()
        machine_creds.guess(self.get_loadparm())
        machine_creds.set_secure_channel_type(SEC_CHAN_WKSTA)
        machine_creds.set_password(self.machinepass)
        machine_creds.set_username(self.netbios_name + "$")

        netlogon_conn = netlogon.netlogon("ncalrpc:%s" % binding,
                                          self.get_loadparm(),
                                          machine_creds)
        challenge = b"abcdefgh"

        # Target info: domain name, computer name and the end-of-list pair.
        target_info = ntlmssp.AV_PAIR_LIST()
        target_info.count = 3

        domainname = ntlmssp.AV_PAIR()
        domainname.AvId = ntlmssp.MsvAvNbDomainName
        domainname.Value = self.domain

        computername = ntlmssp.AV_PAIR()
        computername.AvId = ntlmssp.MsvAvNbComputerName
        computername.Value = self.netbios_name

        eol = ntlmssp.AV_PAIR()
        eol.AvId = ntlmssp.MsvAvEOL
        target_info.pair = [domainname, computername, eol]

        target_info_blob = ndr_pack(target_info)

        response = creds.get_ntlm_response(flags=CLI_CRED_NTLMv2_AUTH,
                                           challenge=challenge,
                                           target_info=target_info_blob)

        netr_flags = 0

        logon_level = netlogon.NetlogonNetworkTransitiveInformation
        logon = samba.dcerpc.netlogon.netr_NetworkInfo()

        # Convert to lists of ints whether iteration yields ints or
        # one-character strings.
        logon.challenge = [
            x if isinstance(x, int) else ord(x) for x in challenge]
        logon.nt = netlogon.netr_ChallengeResponse()
        logon.nt.length = len(response["nt_response"])
        logon.nt.data = [
            x if isinstance(x, int) else ord(x)
            for x in response["nt_response"]]
        logon.identity_info = samba.dcerpc.netlogon.netr_IdentityInfo()
        (username, domain) = creds.get_ntlm_username_domain()

        logon.identity_info.domain_name.string = domain
        logon.identity_info.account_name.string = username
        logon.identity_info.workstation.string = creds.get_workstation()

        validation_level = samba.dcerpc.netlogon.NetlogonValidationSamInfo4

        result = netlogon_conn.netr_LogonSamLogonEx(
            os.environ["SERVER"],
            machine_creds.get_workstation(),
            logon_level, logon,
            validation_level, netr_flags)

        # The values are unused; the unpacking checks that the call
        # returned exactly three items.
        (validation, authoritative, netr_flags_out) = result

        messages = self.waitForMessages(isLastExpectedMessage, netlogon_conn)
        checkFunction(messages)

    def samlogon_check(self, messages):
        """Verify the message count and the leading Authorization message."""

        messages = self.remove_netlogon_messages(messages)
        expected_messages = 5
        self.assertEqual(expected_messages,
                         len(messages),
                         "Did not receive the expected number of messages")

        # Check the first message it should be an Authorization
        msg = messages[0]
        self.assertEqual("Authorization", msg["type"])
        self.assertEqual("DCE/RPC",
                         msg["Authorization"]["serviceDescription"])
        self.assertEqual("ncalrpc", msg["Authorization"]["authType"])
        self.assertEqual("NONE", msg["Authorization"]["transportProtection"])
        self.assertTrue(self.is_guid(msg["Authorization"]["sessionId"]))

    def test_ncalrpc_samlogon(self):
        """Run the SamLogon scenario with the "SEAL" binding option."""

        creds = self.insta_creds(template=self.get_credentials(),
                                 kerberos_state=DONT_USE_KERBEROS)
        try:
            self._test_samlogon("SEAL", creds, self.samlogon_check)
        except Exception as e:
            self.fail("Unexpected exception: " + str(e))
Exemple #8
0
class PassWordHashTests(TestCase):
    """Helpers for tests that inspect a user's supplemental credentials."""

    def setUp(self):
        super(PassWordHashTests, self).setUp()

    # Add a user to ldb, this will exercise the password_hash code
    # and calculate the appropriate supplemental credentials
    def add_user(self, options=None, clear_text=False):
        """Create the test user USER_NAME with password USER_PASS.

        :param options: optional iterable of (option, value) pairs that are
            set on the loadparm context before the SamDB is opened
        :param clear_text: when true, DOMAIN_PASSWORD_STORE_CLEARTEXT is
            added to pwdProperties and the account is created with
            UF_ENCRYPTED_TEXT_PASSWORD_ALLOWED

        The dSHeuristics, minPwdAge and (if changed) pwdProperties values
        are restored through addCleanup().
        """
        self.lp = samba.tests.env_loadparm()
        # set any needed options
        if options is not None:
            for (option, value) in options:
                self.lp.set(option, value)

        self.creds = Credentials()
        self.session = system_session()
        self.ldb = SamDB(
            session_info=self.session,
            credentials=self.creds,
            lp=self.lp)

        # Get the old "dSHeuristics" if it was set
        dsheuristics = self.ldb.get_dsheuristics()

        # Set the "dSHeuristics" to activate the correct "userPassword"
        # behaviour
        self.ldb.set_dsheuristics("000000001")

        # Reset the "dSHeuristics" as they were before
        self.addCleanup(self.ldb.set_dsheuristics, dsheuristics)

        # Get the old "minPwdAge"
        minPwdAge = self.ldb.get_minPwdAge()

        # Set it temporarily to "0"
        self.ldb.set_minPwdAge("0")
        self.base_dn = self.ldb.domain_dn()

        # Reset the "minPwdAge" as it was before
        self.addCleanup(self.ldb.set_minPwdAge, minPwdAge)

        account_control = 0
        if clear_text:
            # get the current pwdProperties
            pwdProperties = self.ldb.get_pwdProperties()
            # enable clear text properties
            props = int(pwdProperties)
            props |= DOMAIN_PASSWORD_STORE_CLEARTEXT
            self.ldb.set_pwdProperties(str(props))
            # Restore the value on exit.
            self.addCleanup(self.ldb.set_pwdProperties, pwdProperties)
            account_control |= UF_ENCRYPTED_TEXT_PASSWORD_ALLOWED

        # (Re)adds the test user USER_NAME with password USER_PASS
        # and userPrincipalName UPN
        user_dn = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        delete_force(self.ldb, user_dn)
        self.ldb.add({
            "dn": user_dn,
            "objectclass": "user",
            "sAMAccountName": USER_NAME,
            "userPassword": USER_PASS,
            "userPrincipalName": UPN,
            "userAccountControl": str(account_control)
        })

    # Get the supplemental credentials for the user under test
    def get_supplemental_creds(self):
        """Return the unpacked supplementalCredentials blob of the test user."""
        base = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        res = self.ldb.search(scope=ldb.SCOPE_BASE,
                              base=base,
                              attrs=["supplementalCredentials"])
        self.assertTrue(len(res) > 0)
        obj = res[0]
        sc_blob = obj["supplementalCredentials"][0]
        sc = ndr_unpack(drsblobs.supplementalCredentialsBlob, sc_blob)
        return sc

    # Calculate and validate a Wdigest value
    def check_digest(self, user, realm, password, digest):
        """Assert that digest equals calc_digest(user, realm, password).

        :param digest: the stored hash; it is hex-encoded before comparing
        """
        expected = calc_digest(user, realm, password)
        actual = binascii.hexlify(bytearray(digest))
        error = "Digest expected[%s], actual[%s], " \
                "user[%s], realm[%s], pass[%s]" % \
                (expected, actual, user, realm, password)
        self.assertEqual(expected, actual, error)

    # Check all of the 29 expected WDigest values
    #
    def check_wdigests(self, digests):
        """Validate each of the 29 WDigest hashes held in digests.

        :param digests: object exposing num_hashes and an indexable hashes
            sequence whose entries carry a .hash attribute
        """

        self.assertEqual(29, digests.num_hashes)

        workgroup = self.lp.get("workgroup")
        realm = self.lp.get("realm")
        user = USER_NAME

        # "workgroup\user" forms used by hashes 17-19 and 26-28.
        qualified = "%s\\%s" % (workgroup, user)
        qualified_lower = "%s\\%s" % (workgroup.lower(), user.lower())
        qualified_upper = "%s\\%s" % (workgroup.upper(), user.upper())

        # (user, realm) inputs; the list position is the hash index.
        inputs = [
            (user, workgroup),                  # 0
            (user.lower(), workgroup.lower()),  # 1
            (user.upper(), workgroup.upper()),  # 2
            (user, workgroup.upper()),          # 3
            (user, workgroup.lower()),          # 4
            (user.upper(), workgroup.lower()),  # 5
            (user.lower(), workgroup.upper()),  # 6
            (user, realm.lower()),              # 7
            (user.lower(), realm.lower()),      # 8
            (user.upper(), realm),              # 9
            (user, realm),                      # 10
            (user, realm.lower()),              # 11 (same inputs as 7)
            (user.upper(), realm.lower()),      # 12
            (user.lower(), realm),              # 13
            (UPN, ""),                          # 14
            (UPN.lower(), ""),                  # 15
            (UPN.upper(), ""),                  # 16
            (qualified, ""),                    # 17
            (qualified_lower, ""),              # 18
            (qualified_upper, ""),              # 19
            (user, "Digest"),                   # 20
            (user.lower(), "Digest"),           # 21
            (user.upper(), "Digest"),           # 22
            (UPN, "Digest"),                    # 23
            (UPN.lower(), "Digest"),            # 24
            (UPN.upper(), "Digest"),            # 25
            (qualified, "Digest"),              # 26
            (qualified_lower, "Digest"),        # 27
            (qualified_upper, "Digest"),        # 28
        ]

        for index, (name, digest_realm) in enumerate(inputs):
            self.check_digest(name,
                              digest_realm,
                              USER_PASS,
                              digests.hashes[index].hash)
Exemple #9
0
class SiteCoverageTests(samba.tests.TestCase):
    def setUp(self):
        """Open the SamDB and reset the bookkeeping of created objects.

        self.sites and self.site_links map the DNs created by the helper
        methods to their names so that tearDown() can delete them.
        """
        # Chain to the base class so its per-test setup is not skipped.
        super(SiteCoverageTests, self).setUp()

        self.prefix = "kcc_"
        self.lp = samba.tests.env_loadparm()

        self.sites = {}
        self.site_links = {}

        self.creds = Credentials()
        self.creds.guess(self.lp)
        self.session = system_session()

        self.samdb = SamDB(session_info=self.session,
                           credentials=self.creds,
                           lp=self.lp)

    def tearDown(self):
        """Delete every recorded site and site link in a single transaction.

        If a delete raises, the transaction is cancelled and the exception
        is re-raised, so no transaction stays open.
        """
        self.samdb.transaction_start()
        try:
            for site in self.sites:
                delete_force(self.samdb, site, controls=['tree_delete:1'])

            for site_link in self.site_links:
                delete_force(self.samdb, site_link)
        except Exception:
            self.samdb.transaction_cancel()
            raise

        self.samdb.transaction_commit()

    def _add_server(self, name, site):
        """Add a server object called name below site's Servers container.

        :param name: CN of the new server object
        :param site: DN of the site that holds the Servers container
        :return: the DN of the new server object
        """
        server_dn = "CN={},CN=Servers,{}".format(name, site)
        record = {
            "dn": server_dn,
            "objectClass": "server",
            "serverReference": self.samdb.domain_dn()
        }
        self.samdb.add(record)
        return server_dn

    def _add_site(self, name):
        """Create a site with an empty Servers container and remember it.

        :param name: CN of the new site
        :return: tuple of (site DN, lower-cased site name)
        """
        config_dn = self.samdb.get_config_basedn()
        site_dn = "CN={},CN=Sites,{}".format(name, config_dn)

        self.samdb.add({"dn": site_dn, "objectClass": "site"})

        servers_container = {
            "dn": "CN=Servers," + site_dn,
            "objectClass": ["serversContainer"]
        }
        self.samdb.add(servers_container)

        # Recorded so that tearDown() can delete the site again.
        self.sites[site_dn] = name
        return site_dn, name.lower()

    def _add_site_link(self, name, links=None, cost=100):
        """Create a siteLink object under the IP inter-site transport.

        :param name: CN of the new site link
        :param links: list of site DNs stored in "siteList"; defaults to
            an empty list
        :param cost: link cost, stored as a string
        :return: the DN of the new site link
        """
        # None replaces the former mutable [] default, which was one list
        # object shared by every call.
        if links is None:
            links = []

        dn = "CN={},CN=IP,CN=Inter-Site Transports,CN=Sites,{}".format(
            name, self.samdb.get_config_basedn())
        self.samdb.add({
            "dn": dn,
            "objectClass": "siteLink",
            "cost": str(cost),
            "siteList": links
        })
        # Recorded so that tearDown() can delete the link again.
        self.site_links[dn] = name
        return dn

    def test_single_site_link_same_dc_count(self):
        """One link joins all three sites and both DC sites hold one server.

        Expects the ABCD site to be asked to cover the uncovered site and
        the BCDE site to cover nothing.
        """
        abcd = self.prefix + "ABCD"
        bcde = self.prefix + "BCDE"

        self.samdb.transaction_start()
        site1, name1 = self._add_site(abcd)
        site2, name2 = self._add_site(bcde)

        uncovered_dn, uncovered = self._add_site(self.prefix + "uncovered")

        self._add_server(abcd + '1', site1)
        self._add_server(bcde + '1', site2)

        self._add_site_link(self.prefix + "link", [site1, site2, uncovered_dn])
        self.samdb.transaction_commit()

        covered_by_site1 = sorted(uncovered_sites_to_cover(self.samdb, name1))
        self.assertEqual([uncovered], covered_by_site1)

        covered_by_site2 = sorted(uncovered_sites_to_cover(self.samdb, name2))
        self.assertEqual([], covered_by_site2)

    def test_single_site_link_different_dc_count(self):
        """One link joins all three sites; ABCD has two servers, BCDE three.

        Expects the BCDE site to cover the uncovered site and the ABCD
        site to cover nothing.
        """
        abcd = self.prefix + "ABCD"
        bcde = self.prefix + "BCDE"

        self.samdb.transaction_start()
        site1, name1 = self._add_site(abcd)
        site2, name2 = self._add_site(bcde)

        uncovered_dn, uncovered = self._add_site(self.prefix + "uncovered")

        for suffix in ('1', '2'):
            self._add_server(abcd + suffix, site1)
        for suffix in ('1', '2', '3'):
            self._add_server(bcde + suffix, site2)

        self._add_site_link(self.prefix + "link", [site1, site2, uncovered_dn])
        self.samdb.transaction_commit()

        covered_by_site1 = sorted(uncovered_sites_to_cover(self.samdb, name1))
        self.assertEqual([], covered_by_site1)

        covered_by_site2 = sorted(uncovered_sites_to_cover(self.samdb, name2))
        self.assertEqual([uncovered], covered_by_site2)

    def test_two_site_links_same_cost(self):
        """Each DC site has its own default-cost link to the uncovered site.

        Expects both the ABCD and the BCDE site to cover the uncovered
        site.
        """
        abcd = self.prefix + "ABCD"
        bcde = self.prefix + "BCDE"

        self.samdb.transaction_start()
        site1, name1 = self._add_site(abcd)
        site2, name2 = self._add_site(bcde)

        uncovered_dn, uncovered = self._add_site(self.prefix + "uncovered")

        for suffix in ('1', '2'):
            self._add_server(abcd + suffix, site1)
        for suffix in ('1', '2', '3'):
            self._add_server(bcde + suffix, site2)

        self._add_site_link(self.prefix + "link1", [site1, uncovered_dn])
        self._add_site_link(self.prefix + "link2", [site2, uncovered_dn])
        self.samdb.transaction_commit()

        covered_by_site1 = sorted(uncovered_sites_to_cover(self.samdb, name1))
        self.assertEqual([uncovered], covered_by_site1)

        covered_by_site2 = sorted(uncovered_sites_to_cover(self.samdb, name2))
        self.assertEqual([uncovered], covered_by_site2)

    def test_two_site_links_different_costs(self):
        """Two links reach the uncovered site, with costs 50 and 75.

        Expects only the ABCD site (cost 50) to cover the uncovered site.
        """
        abcd = self.prefix + "ABCD"
        bcde = self.prefix + "BCDE"

        self.samdb.transaction_start()
        site1, name1 = self._add_site(abcd)
        site2, name2 = self._add_site(bcde)

        uncovered_dn, uncovered = self._add_site(self.prefix + "uncovered")

        self._add_server(abcd + '1', site1)
        for suffix in ('1', '2'):
            self._add_server(bcde + suffix, site2)

        link_costs = (
            ("link1", site1, 50),
            ("link2", site2, 75),
        )
        for link_name, site_dn, link_cost in link_costs:
            self._add_site_link(self.prefix + link_name,
                                [site_dn, uncovered_dn],
                                cost=link_cost)
        self.samdb.transaction_commit()

        covered_by_site1 = sorted(uncovered_sites_to_cover(self.samdb, name1))
        self.assertEqual([uncovered], covered_by_site1)

        covered_by_site2 = sorted(uncovered_sites_to_cover(self.samdb, name2))
        self.assertEqual([], covered_by_site2)

    def test_three_site_links_different_costs(self):
        """Three links reach the uncovered site, with costs 50, 75 and 60.

        Expects only the ABCD site (cost 50) to cover the uncovered site.
        """
        abcd = self.prefix + "ABCD"
        bcde = self.prefix + "BCDE"
        cdef = self.prefix + "CDEF"

        self.samdb.transaction_start()
        site1, name1 = self._add_site(abcd)
        site2, name2 = self._add_site(bcde)
        site3, name3 = self._add_site(cdef)

        uncovered_dn, uncovered = self._add_site(self.prefix + "uncovered")

        self._add_server(abcd + '1', site1)
        self._add_server(bcde + '1', site2)
        for suffix in ('1', '2'):
            self._add_server(cdef + suffix, site3)

        link_costs = (
            ("link1", site1, 50),
            ("link2", site2, 75),
            ("link3", site3, 60),
        )
        for link_name, site_dn, link_cost in link_costs:
            self._add_site_link(self.prefix + link_name,
                                [site_dn, uncovered_dn],
                                cost=link_cost)
        self.samdb.transaction_commit()

        covered_by_site1 = sorted(uncovered_sites_to_cover(self.samdb, name1))
        self.assertEqual([uncovered], covered_by_site1)

        covered_by_site2 = sorted(uncovered_sites_to_cover(self.samdb, name2))
        self.assertEqual([], covered_by_site2)

        covered_by_site3 = sorted(uncovered_sites_to_cover(self.samdb, name3))
        self.assertEqual([], covered_by_site3)

    def test_three_site_links_duplicate_costs(self):
        """Three links reach the uncovered site; two share the lowest cost.

        link1 and link3 both cost 50 while link2 costs 75.  Expects the
        ABCD and CDEF sites to cover the uncovered site and the BCDE site
        to cover nothing.

        This method was previously also named
        test_three_site_links_different_costs, which replaced the earlier
        definition of that name in the class body so only one of the two
        tests ran.
        """
        self.samdb.transaction_start()
        site1, name1 = self._add_site(self.prefix + "ABCD")
        site2, name2 = self._add_site(self.prefix + "BCDE")
        site3, name3 = self._add_site(self.prefix + "CDEF")

        uncovered_dn, uncovered = self._add_site(self.prefix + "uncovered")

        self._add_server(self.prefix + "ABCD" + '1', site1)
        self._add_server(self.prefix + "BCDE" + '1', site2)
        self._add_server(self.prefix + "CDEF" + '1', site3)
        self._add_server(self.prefix + "CDEF" + '2', site3)

        self._add_site_link(self.prefix + "link1", [site1, uncovered_dn],
                            cost=50)
        self._add_site_link(self.prefix + "link2", [site2, uncovered_dn],
                            cost=75)
        self._add_site_link(self.prefix + "link3", [site3, uncovered_dn],
                            cost=50)
        self.samdb.transaction_commit()

        to_cover = uncovered_sites_to_cover(self.samdb, name1)
        to_cover.sort()

        self.assertEqual([uncovered], to_cover)

        to_cover = uncovered_sites_to_cover(self.samdb, name2)
        to_cover.sort()

        self.assertEqual([], to_cover)

        to_cover = uncovered_sites_to_cover(self.samdb, name3)
        to_cover.sort()

        self.assertEqual([uncovered], to_cover)

    def test_complex_setup_with_multiple_uncovered_sites(self):
        """Check coverage of three server-less sites by two site clusters.

        Six sites with servers are split into two clusters of three.  Each
        of three server-less sites is joined to both clusters by one site
        link per cluster, with link costs chosen so that the first prefers
        cluster 1, the second has no preference and the third prefers
        cluster 2.  The expected covering sites are asserted at the end.
        """
        prefix = self.prefix
        self.samdb.transaction_start()

        # (site label, number of servers).  The first three entries form
        # site link cluster 1, the last three form site link cluster 2.
        layout = (("ABCD", 1), ("BCDE", 2), ("CDEF", 3),
                  ("1234", 2), ("2345", 2), ("3456", 1))
        sites = [self._add_site(prefix + label) for label, _ in layout]
        orphans = [self._add_site(prefix + label)
                   for label in ("uncovered1", "uncovered2", "uncovered3")]

        # Server lists: servers are numbered from 1 within each site.
        for (label, server_count), (site_dn, _) in zip(layout, sites):
            for index in range(1, server_count + 1):
                self._add_server(prefix + label + str(index), site_dn)

        cluster1 = [site_dn for site_dn, _ in sites[:3]]
        cluster2 = [site_dn for site_dn, _ in sites[3:]]

        # Per server-less site: link-name suffix, cost of the link through
        # cluster 1, cost of the link through cluster 2.
        #   A: preference to site link cluster 1
        #   B: no preference on site links
        #   C: preference to site link cluster 2
        link_plan = (("A", 49, 50), ("B", 50, 50), ("C", 50, 49))
        for (suffix, cost1, cost2), (orphan_dn, _) in zip(link_plan, orphans):
            self._add_site_link(prefix + "link1" + suffix,
                                cluster1 + [orphan_dn],
                                cost=cost1)
            self._add_site_link(prefix + "link2" + suffix,
                                cluster2 + [orphan_dn],
                                cost=cost2)

        self.samdb.transaction_commit()

        names = [site_name for _, site_name in sites]
        orphan1, orphan2, orphan3 = [site_name for _, site_name in orphans]

        expectations = (
            (names[0], []),
            (names[1], []),
            (names[2], [orphan1, orphan2]),
            (names[3], [orphan2, orphan3]),
            (names[4], []),
            (names[5], []),
            # The server-less sites themselves must not cover anything.
            (orphan1, []),
            (orphan2, []),
            (orphan3, []),
        )
        for site_name, expected in expectations:
            found = sorted(uncovered_sites_to_cover(self.samdb, site_name))
            self.assertEqual(expected, found)
Exemple #10
0
    def test_rid_set_dbcheck(self):
        """Perform a join against the RID manager and assert we have a RID Set.
        Using dbcheck, we assert that we can detect out of range users.

        The joined database is opened offline, a user is added with a SID
        whose RID lies ten below the top of the current allocation pool, and
        dbcheck (with fixes enabled) is expected to report exactly one error
        and to leave rIDNextRid at that RID while keeping the pool intact.
        The joined DC is force-demoted and its directory removed afterwards.
        """

        fsmo_dn = ldb.Dn(self.ldb_dc1, "CN=RID Manager$,CN=System," + self.ldb_dc1.domain_dn())
        (fsmo_owner, fsmo_not_owner) = self._determine_fSMORoleOwner(fsmo_dn)

        targetdir = self._test_join(fsmo_owner['dns_name'], "RIDALLOCTEST6")
        try:
            # Connect to the database
            ldb_url = "tdb://%s" % os.path.join(targetdir, "private/sam.ldb")

            lp = self.get_loadparm()
            new_ldb = SamDB(ldb_url, credentials=self.get_credentials(),
                            session_info=system_session(lp), lp=lp)

            # 1. Get server name
            res = new_ldb.search(base=ldb.Dn(new_ldb, new_ldb.get_serverName()),
                                 scope=ldb.SCOPE_BASE, attrs=["serverReference"])
            # 2. Get server reference
            server_ref_dn = ldb.Dn(new_ldb, res[0]['serverReference'][0])

            # 3. Assert we get the RID Set
            res = new_ldb.search(base=server_ref_dn,
                                 scope=ldb.SCOPE_BASE, attrs=['rIDSetReferences'])

            self.assertIn("rIDSetReferences", res[0])
            rid_set_dn = ldb.Dn(new_ldb, res[0]["rIDSetReferences"][0])

            # 4. Add a new user (triggers RID set work)
            new_ldb.newuser("ridalloctestuser", "P@ssword!")

            # 5. Now fetch the RID SET
            rid_set_res = new_ldb.search(base=rid_set_dn,
                                         scope=ldb.SCOPE_BASE, attrs=['rIDNextRid',
                                                                      'rIDAllocationPool'])
            next_pool = int(rid_set_res[0]["rIDAllocationPool"][0])
            # The upper 32 bits of rIDAllocationPool are used as the last RID
            # of the pool.
            last_rid = (0xFFFFFFFF00000000 & next_pool) >> 32

            # 6. Add user above the ridNextRid and at mid-range.
            #
            # We can do this with safety because this is an offline DB that will be
            # destroyed.
            m = ldb.Message()
            m.dn = ldb.Dn(new_ldb, "CN=ridsettestuser1,CN=Users")
            m.dn.add_base(new_ldb.get_default_basedn())
            m['objectClass'] = ldb.MessageElement('user', ldb.FLAG_MOD_ADD, 'objectClass')
            m['objectSid'] = ldb.MessageElement(ndr_pack(security.dom_sid(str(new_ldb.get_domain_sid()) + "-%d" % (last_rid - 10))),
                                                ldb.FLAG_MOD_ADD,
                                                'objectSid')
            new_ldb.add(m, controls=["relax:0"])

            # 7. Check the RID Set
            chk = dbcheck(new_ldb, verbose=False, fix=True, yes=True, quiet=True)

            # Should have one error (wrong rIDNextRID)
            self.assertEqual(chk.check_database(DN=rid_set_dn, scope=ldb.SCOPE_BASE), 1)

            # 8. Assert that dbcheck didn't show any other errors
            #
            # NOTE(review): this second checker is constructed but
            # check_database() is never called on it, so this step does not
            # currently verify anything.  Presumably a re-check returning 0
            # was intended -- confirm against dbcheck before adding it.
            chk = dbcheck(new_ldb, verbose=False, fix=False, quiet=True)

            rid_set_res = new_ldb.search(base=rid_set_dn,
                                         scope=ldb.SCOPE_BASE, attrs=['rIDNextRid',
                                                                      'rIDAllocationPool'])
            last_allocated_rid = int(rid_set_res[0]["rIDNextRid"][0])
            self.assertEqual(last_allocated_rid, last_rid - 10)

            # 9. Assert that the range wasn't thrown away

            next_pool = int(rid_set_res[0]["rIDAllocationPool"][0])
            self.assertEqual(last_rid, (0xFFFFFFFF00000000 & next_pool) >> 32,
                             "rid pool should not have changed")
        finally:
            self._test_force_demote(fsmo_owner['dns_name'], "RIDALLOCTEST6")
            shutil.rmtree(targetdir, ignore_errors=True)
Exemple #11
0
class GroupAuditTests(AuditLogTestBase):
    """Tests of the group-change audit messages emitted for LDAP operations.

    Each test creates a user and two groups through an LDAP connection to
    the server named by the SERVER environment variable, performs group
    membership changes, and checks the "groupChange" messages received
    through the AuditLogTestBase helpers.
    """

    def setUp(self):
        """Connect to the server and create the test user and groups.

        Reads CLIENT_IP, SERVER_IP and SERVER from the environment.  The
        previous dSHeuristics and minPwdAge values are restored through
        registered cleanups.
        """
        self.message_type = MSG_GROUP_LOG
        self.event_type = DSDB_GROUP_EVENT_NAME
        super(GroupAuditTests, self).setUp()

        self.remoteAddress = os.environ["CLIENT_IP"]
        self.server_ip = os.environ["SERVER_IP"]

        host = "ldap://%s" % os.environ["SERVER"]
        self.ldb = SamDB(url=host,
                         session_info=system_session(),
                         credentials=self.get_credentials(),
                         lp=self.get_loadparm())
        self.server = os.environ["SERVER"]

        # Gets back the basedn
        self.base_dn = self.ldb.domain_dn()

        # Get the old "dSHeuristics" if it was set
        dsheuristics = self.ldb.get_dsheuristics()

        # Set the "dSHeuristics" to activate the correct "userPassword"
        # behaviour
        self.ldb.set_dsheuristics("000000001")

        # Reset the "dSHeuristics" as they were before
        self.addCleanup(self.ldb.set_dsheuristics, dsheuristics)

        # Get the old "minPwdAge"
        minPwdAge = self.ldb.get_minPwdAge()

        # Set it temporarily to "0"
        self.ldb.set_minPwdAge("0")

        # Reset the "minPwdAge" as it was before
        self.addCleanup(self.ldb.set_minPwdAge, minPwdAge)

        # Adds the test user USER_NAME with password USER_PASS
        self.ldb.add({
            "dn": self._user_dn(),
            "objectclass": "user",
            "sAMAccountName": USER_NAME,
            "userPassword": USER_PASS
        })
        self.ldb.newgroup(GROUP_NAME_01)
        self.ldb.newgroup(GROUP_NAME_02)

    def tearDown(self):
        """Run the base teardown, then remove the test user and groups."""
        super(GroupAuditTests, self).tearDown()
        delete_force(self.ldb, self._user_dn())
        self.ldb.deletegroup(GROUP_NAME_01)
        self.ldb.deletegroup(GROUP_NAME_02)

    def _user_dn(self):
        """Return the DN of the test user as a string."""
        return "cn=" + USER_NAME + ",cn=users," + self.base_dn

    def _group_dn(self, group_cn):
        """Return the DN string of group *group_cn* under cn=users."""
        return "cn=" + group_cn + ",cn=users," + self.base_dn

    def _check_group_change(self, audit, action, group_cn):
        """Verify one "groupChange" audit record.

        :param audit: the "groupChange" dictionary of a received message
        :param action: expected value of audit["action"]
        :param group_cn: common name of the group the record should refer to
        """
        self.assertEqual(action, audit["action"])
        # The DNs are compared case-insensitively.  These used to be
        # assertTrue(expected, actual), which treats the second argument as
        # the failure message and therefore never compared anything.
        self.assertEqual(self._user_dn().lower(), audit["user"].lower())
        self.assertEqual(self._group_dn(group_cn).lower(),
                         audit["group"].lower())
        self.assertRegexpMatches(audit["remoteAddress"],
                                 self.remoteAddress)
        self.assertTrue(self.is_guid(audit["sessionId"]))
        session_id = self.get_session()
        self.assertEqual(session_id, audit["sessionId"])
        service_description = self.get_service_description()
        self.assertEqual(service_description, "LDAP")

    def _expect_group_changes(self, expected):
        """Wait for audit messages and verify them in order.

        :param expected: sequence of (action, group common name) pairs, one
            per message that should arrive, in arrival order
        """
        messages = self.waitForMessages(len(expected))
        print("Received %d messages" % len(messages))
        self.assertEqual(len(expected),
                         len(messages),
                         "Did not receive the expected number of messages")
        for message, (action, group_cn) in zip(messages, expected):
            self._check_group_change(message["groupChange"], action, group_cn)

    def test_add_and_remove_users_from_group(self):
        """Adding and removing group members produces matching audit events."""

        #
        # Wait for the primary group change for the created user.
        #
        self._expect_group_changes([("PrimaryGroup", "domain users")])

        #
        # Add the user to a group
        #
        self.discardMessages()
        self.ldb.add_remove_group_members(GROUP_NAME_01, [USER_NAME])
        self._expect_group_changes([("Added", GROUP_NAME_01)])

        #
        # Add the user to another group
        #
        self.discardMessages()
        self.ldb.add_remove_group_members(GROUP_NAME_02, [USER_NAME])
        self._expect_group_changes([("Added", GROUP_NAME_02)])

        #
        # Remove the user from a group
        #
        self.discardMessages()
        self.ldb.add_remove_group_members(
            GROUP_NAME_01,
            [USER_NAME],
            add_members_operation=False)
        self._expect_group_changes([("Removed", GROUP_NAME_01)])

        #
        # Re-add the user to a group
        #
        self.discardMessages()
        self.ldb.add_remove_group_members(GROUP_NAME_01, [USER_NAME])
        self._expect_group_changes([("Added", GROUP_NAME_01)])

    def test_change_primary_group(self):
        """Changing primaryGroupID produces the three expected audit events."""

        #
        # Wait for the primary group change for the created user.
        #
        self._expect_group_changes([("PrimaryGroup", "domain users")])

        #
        # Add the user to a group, the user needs to be a member of a group
        # before their primary group can be set to that group.
        #
        self.discardMessages()
        self.ldb.add_remove_group_members(GROUP_NAME_01, [USER_NAME])
        self._expect_group_changes([("Added", GROUP_NAME_01)])

        #
        # Change the primary group of a user
        #
        # get the primaryGroupToken of the group
        res = self.ldb.search(base=self._group_dn(GROUP_NAME_01),
                              attrs=["primaryGroupToken"],
                              scope=ldb.SCOPE_BASE)
        group_id = res[0]["primaryGroupToken"]

        # set primaryGroupID attribute of the user to that group
        m = ldb.Message()
        m.dn = ldb.Dn(self.ldb, self._user_dn())
        m["primaryGroupID"] = ldb.MessageElement(
            group_id,
            FLAG_MOD_REPLACE,
            "primaryGroupID")
        self.discardMessages()
        self.ldb.modify(m)

        #
        # Wait for the primary group change.
        # Will see the user removed from the new group
        #          the user added to their old primary group
        #          and a new primary group event.
        #
        self._expect_group_changes([
            ("Removed", GROUP_NAME_01),
            ("Added", "domain users"),
            ("PrimaryGroup", GROUP_NAME_01),
        ])
Exemple #12
0
class RodcTests(samba.tests.TestCase):
    """Tests that write operations against HOST are refused with a referral.

    Each test attempts a write through the connection opened in setUp and
    expects an LdbError: ERR_REFERRAL for adds, modifies and the delete of
    an existing object, ERR_NO_SUCH_OBJECT for the delete of a missing one.
    """

    # Pattern used to pull the referral URL out of the LDB error message.
    _REFERRAL_RE = r'(ldap://[^>]+)>'

    def setUp(self):
        """Connect to HOST and record the base DN, dsServiceName and a tag."""
        super(RodcTests, self).setUp()
        self.samdb = SamDB(HOST, credentials=CREDS,
                           session_info=system_session(LP), lp=LP)

        self.base_dn = self.samdb.domain_dn()

        root = self.samdb.search(base='', scope=ldb.SCOPE_BASE,
                                 attrs=['dsServiceName'])
        self.service = root[0]['dsServiceName'][0]
        # Random tag so that object names do not clash between runs.
        self.tag = uuid.uuid4().hex

    def _referral_address(self, emsg):
        """Return the ldap:// URL found in *emsg*, failing the test if none."""
        found = re.search(self._REFERRAL_RE, emsg)
        if found is None:
            self.fail("referral seems not to refer to anything")
        return found.group(1)

    def _assert_specific_dc(self, address):
        """Fail the test if *address* starts with the domain DNS name.

        NOTE(review): *address* comes from _referral_address() and so always
        begins with "ldap://", which means this startswith() comparison
        against a bare DNS name looks unable to match.  Presumably the host
        part of the URL should be compared instead -- confirm the referral
        format the server emits before tightening this check.
        """
        if address.lower().startswith(self.samdb.domain_dns_name()):
            self.fail("referral address did not give a specific DC")

    def test_add_replicated_objects(self):
        """Adds are referred, and the referred location accepts the add."""
        for o in (
                {
                    'dn': "ou=%s1,%s" % (self.tag, self.base_dn),
                    "objectclass": "organizationalUnit"
                },
                {
                    'dn': "cn=%s2,%s" % (self.tag, self.base_dn),
                    "objectclass": "user"
                },
                {
                    'dn': "cn=%s3,%s" % (self.tag, self.base_dn),
                    "objectclass": "group"
                },
                {
                    'dn': "cn=%s4,%s" % (self.tag, self.service),
                    "objectclass": "NTDSConnection",
                    'enabledConnection': 'TRUE',
                    'fromServer': self.base_dn,
                    'options': '0'
                },
        ):
            try:
                self.samdb.add(o)
            except ldb.LdbError as e:
                (ecode, emsg) = e.args
            else:
                self.fail("Failed to fail to add %s" % o['dn'])

            if ecode != ldb.ERR_REFERRAL:
                print(emsg)
                self.fail("Adding %s: ldb error: %s %s, wanted referral" %
                          (o['dn'], ecode, emsg))

            address = self._referral_address(emsg)

            # The referred server must accept the same add; clean up
            # straight away so nothing is left behind.
            try:
                tmpdb = SamDB(address, credentials=CREDS,
                              session_info=system_session(LP), lp=LP)
                tmpdb.add(o)
                tmpdb.delete(o['dn'])
            except ldb.LdbError:
                self.fail("couldn't modify referred location %s" %
                          address)

            self._assert_specific_dc(address)

    def test_modify_replicated_attributes(self):
        """Modifies of carLicense/middleName on Guest are referred."""
        dn = 'CN=Guest,CN=Users,' + self.base_dn
        value = 'hallooo'
        for attr in ['carLicense', 'middleName']:
            msg = ldb.Message()
            msg.dn = ldb.Dn(self.samdb, dn)
            msg[attr] = ldb.MessageElement(value,
                                           ldb.FLAG_MOD_REPLACE,
                                           attr)
            try:
                self.samdb.modify(msg)
            except ldb.LdbError as e1:
                (ecode, emsg) = e1.args
            else:
                self.fail("Failed to fail to modify %s %s" % (dn, attr))

            if ecode != ldb.ERR_REFERRAL:
                self.fail("Failed to REFER when trying to modify %s %s" %
                          (dn, attr))

            address = self._referral_address(emsg)

            # The referred server must accept the same modification.
            try:
                tmpdb = SamDB(address, credentials=CREDS,
                              session_info=system_session(LP), lp=LP)
                tmpdb.modify(msg)
            except ldb.LdbError:
                self.fail("couldn't modify referred location %s" %
                          address)

            self._assert_specific_dc(address)

    def test_modify_nonreplicated_attributes(self):
        """Modifies of badPwdCount/lastLogon/lastLogoff on Guest are referred."""
        # some timestamp ones
        dn = 'CN=Guest,CN=Users,' + self.base_dn
        value = '123456789'
        for attr in ['badPwdCount', 'lastLogon', 'lastLogoff']:
            msg = ldb.Message()
            msg.dn = ldb.Dn(self.samdb, dn)
            msg[attr] = ldb.MessageElement(value,
                                           ldb.FLAG_MOD_REPLACE,
                                           attr)
            # Windows refers these ones even though they are non-replicated
            try:
                self.samdb.modify(msg)
            except ldb.LdbError as e2:
                (ecode, emsg) = e2.args
            else:
                self.fail("Failed to fail to modify %s %s" % (dn, attr))

            if ecode != ldb.ERR_REFERRAL:
                self.fail("Failed to REFER when trying to modify %s %s" %
                          (dn, attr))

            self._assert_specific_dc(self._referral_address(emsg))

    def test_modify_nonreplicated_reps_attributes(self):
        """A modify of repsFrom on the base DN is referred."""
        dn = self.base_dn

        msg = ldb.Message()
        msg.dn = ldb.Dn(self.samdb, dn)
        attr = 'repsFrom'

        # Take the first existing repsFrom value, change one field and try
        # to write the re-packed blob back.
        res = self.samdb.search(dn, scope=ldb.SCOPE_BASE,
                                attrs=['repsFrom'])
        rep = ndr_unpack(drsblobs.repsFromToBlob, res[0]['repsFrom'][0],
                         allow_remaining=True)
        rep.ctr.result_last_attempt = -1
        value = ndr_pack(rep)

        msg[attr] = ldb.MessageElement(value,
                                       ldb.FLAG_MOD_REPLACE,
                                       attr)
        try:
            self.samdb.modify(msg)
        except ldb.LdbError as e3:
            (ecode, emsg) = e3.args
        else:
            self.fail("Failed to fail to modify %s %s" % (dn, attr))

        if ecode != ldb.ERR_REFERRAL:
            self.fail("Failed to REFER when trying to modify %s %s" %
                      (dn, attr))

        self._assert_specific_dc(self._referral_address(emsg))

    def test_delete_special_objects(self):
        """Deleting the Guest user is referred."""
        dn = 'CN=Guest,CN=Users,' + self.base_dn
        try:
            self.samdb.delete(dn)
        except ldb.LdbError as e4:
            (ecode, emsg) = e4.args
        else:
            self.fail("Failed to fail to delete %s" % (dn))

        if ecode != ldb.ERR_REFERRAL:
            print(ecode, emsg)
            self.fail("Failed to REFER when trying to delete %s" % dn)

        self._assert_specific_dc(self._referral_address(emsg))

    def test_no_delete_nonexistent_objects(self):
        """Deleting a DN that does not exist yields ERR_NO_SUCH_OBJECT."""
        dn = 'CN=does-not-exist-%s,CN=Users,%s' % (self.tag, self.base_dn)
        try:
            self.samdb.delete(dn)
        except ldb.LdbError as e5:
            (ecode, emsg) = e5.args
        else:
            self.fail("Failed to fail to delete %s" % (dn))

        if ecode != ldb.ERR_NO_SUCH_OBJECT:
            print(ecode, emsg)
            self.fail("Failed to NO_SUCH_OBJECT when trying to delete "
                      "%s (which does not exist)" % dn)
Exemple #13
0
class PyCredentialsTests(TestCase):
    def setUp(self):
        """Open an LDAP connection to SERVER_IP and create both accounts.

        SERVER, DOMAIN and SERVER_IP are read from the environment; a missing
        variable raises KeyError.
        """
        super(PyCredentialsTests, self).setUp()

        env = os.environ
        self.server = env["SERVER"]
        self.domain = env["DOMAIN"]
        self.host = env["SERVER_IP"]

        self.lp = self.get_loadparm()
        self.credentials = self.get_credentials()
        self.session = system_session()

        ldap_url = "ldap://%s" % self.host
        self.ldb = SamDB(url=ldap_url,
                         session_info=self.session,
                         credentials=self.credentials,
                         lp=self.lp)

        # The accounts are created last because both helpers use self.ldb.
        self.create_machine_account()
        self.create_user_account()

    def tearDown(self):
        """Delete the machine and user accounts created in setUp."""
        super(PyCredentialsTests, self).tearDown()
        for account_dn in (self.machine_dn, self.user_dn):
            delete_force(self.ldb, account_dn)

    # Until a successful netlogon connection has been established there will
    # not be a valid authenticator associated with the credentials
    # and new_client_authenticator should throw a ValueError
    def test_no_netlogon_connection(self):
        make_authenticator = self.machine_creds.new_client_authenticator
        with self.assertRaises(ValueError):
            make_authenticator()

    # Once a netlogon connection has been established,
    # new_client_authenticator should return a value
    #
    def test_have_netlogon_connection(self):
        # The connection is kept referenced for the duration of the test even
        # though only the machine credentials are queried afterwards.
        connection = self.get_netlogon_connection()
        authenticator = self.machine_creds.new_client_authenticator()
        self.assertIsNotNone(authenticator)

    # Get an authenticator and use it on a sequence of operations requiring
    # an authenticator
    def test_client_authenticator(self):
        conn = self.get_netlogon_connection()

        # A fresh authenticator pair is obtained before every call.
        current, following = self.get_authenticator(conn)
        self.do_NetrLogonSamLogonWithFlags(conn, current, following)

        for _ in range(3):
            current, following = self.get_authenticator(conn)
            self.do_NetrLogonGetDomainInfo(conn, current, following)

    # Test Credentials.encrypt_netr_crypt_password
    # By performing a NetrServerPasswordSet2
    # And then logging on using the new password.
    def test_encrypt_netr_password(self):
        # Change the password; the helper also stores the new password in
        # self.machine_creds.
        self.do_Netr_ServerPasswordSet2()
        # Now use the new password to perform an operation over a newly
        # opened netlogon connection.
        self.do_DsrEnumerateDomainTrusts()

    # Change the current machine account password with a
    # netr_ServerPasswordSet2 call.

    def do_Netr_ServerPasswordSet2(self):
        """Change the machine account password via netr_ServerPasswordSet2.

        A new random password is placed at the end of a DATA_LEN-item buffer
        whose leading items are random filler, the buffer is encrypted with
        the machine credentials and sent to the server.  Once the call has
        returned, the new password is stored in self.machine_pass and in
        self.machine_creds.
        """
        c = self.get_netlogon_connection()
        (authenticator, subsequent) = self.get_authenticator(c)
        PWD_LEN = 32
        DATA_LEN = 512
        newpass = samba.generate_random_password(PWD_LEN, PWD_LEN)
        # Iterating a bytearray yields integers on both Python 2 and
        # Python 3; calling ord() on the items of a Python 3 bytes object
        # raises TypeError because they are already integers.
        filler = list(bytearray(os.urandom(DATA_LEN - PWD_LEN)))
        pwd = netlogon.netr_CryptPassword()
        pwd.length = PWD_LEN
        pwd.data = filler + [ord(x) for x in newpass]
        self.machine_creds.encrypt_netr_crypt_password(pwd)
        c.netr_ServerPasswordSet2(self.server,
                                  self.machine_creds.get_workstation(),
                                  SEC_CHAN_WKSTA, self.machine_name,
                                  authenticator, pwd)

        self.machine_pass = newpass
        self.machine_creds.set_password(newpass)

    # Perform a DsrEnumerateDomainTrusts, this provides confirmation that
    # a netlogon connection has been correctly established
    def do_DsrEnumerateDomainTrusts(self):
        """Call netr_DsrEnumerateDomainTrusts over a new netlogon connection.

        The returned trust list is not inspected; the call only has to
        complete without raising.
        """
        conn = self.get_netlogon_connection()
        trust_flags = (netlogon.NETR_TRUST_FLAG_IN_FOREST
                       | netlogon.NETR_TRUST_FLAG_OUTBOUND
                       | netlogon.NETR_TRUST_FLAG_INBOUND)
        conn.netr_DsrEnumerateDomainTrusts(self.server, trust_flags)

    # Establish sealed schannel netlogon connection over TCP/IP
    #
    def get_netlogon_connection(self):
        """Return a netlogon connection to self.server using machine_creds.

        The binding string requests ncacn_ip_tcp with schannel and seal.
        """
        binding = "ncacn_ip_tcp:%s[schannel,seal]" % self.server
        return netlogon.netlogon(binding, self.lp, self.machine_creds)

    #
    # Create the machine account
    def create_machine_account(self):
        """Create the machine account and the credentials that go with it.

        Sets self.machine_pass, self.machine_name, self.machine_dn and
        self.machine_creds.  Any account already present at the target DN is
        deleted first.
        """
        self.machine_pass = samba.generate_random_password(32, 32)
        self.machine_name = MACHINE_NAME
        self.machine_dn = "cn=%s,%s" % (self.machine_name,
                                        self.ldb.domain_dn())

        # remove the account if it exists, this will happen if a previous test
        # run failed
        delete_force(self.ldb, self.machine_dn)

        # The unicodePwd value is the password wrapped in double quotes and
        # encoded as UTF-16-LE.  Formatting into a unicode literal works on
        # both Python 2 and Python 3, unlike the unicode() builtin and the
        # str + bytes concatenation used previously.
        utf16pw = (u'"%s"' % self.machine_pass).encode('utf-16-le')
        self.ldb.add({
            "dn": self.machine_dn,
            "objectclass": "computer",
            "sAMAccountName": "%s$" % self.machine_name,
            "userAccountControl":
                str(UF_WORKSTATION_TRUST_ACCOUNT | UF_PASSWD_NOTREQD),
            "unicodePwd": utf16pw
        })

        self.machine_creds = Credentials()
        self.machine_creds.guess(self.get_loadparm())
        self.machine_creds.set_secure_channel_type(SEC_CHAN_WKSTA)
        self.machine_creds.set_password(self.machine_pass)
        self.machine_creds.set_username(self.machine_name + "$")
        self.machine_creds.set_workstation(self.machine_name)

    #
    # Create a test user account
    def create_user_account(self):
        """Create the test user account and the credentials that go with it.

        Sets self.user_pass, self.user_name, self.user_dn and
        self.user_creds.  Any account already present at the target DN is
        deleted first.  The credentials' workstation is self.machine_name, so
        create_machine_account() has to have run before this method.
        """
        self.user_pass = samba.generate_random_password(32, 32)
        self.user_name = USER_NAME
        self.user_dn = "cn=%s,%s" % (self.user_name, self.ldb.domain_dn())

        # remove the account if it exists, this will happen if a previous test
        # run failed
        delete_force(self.ldb, self.user_dn)

        # The unicodePwd value is the password wrapped in double quotes and
        # encoded as UTF-16-LE.  Formatting into a unicode literal works on
        # both Python 2 and Python 3, unlike the unicode() builtin and the
        # str + bytes concatenation used previously.
        utf16pw = (u'"%s"' % self.user_pass).encode('utf-16-le')
        self.ldb.add({
            "dn": self.user_dn,
            "objectclass": "user",
            "sAMAccountName": "%s" % self.user_name,
            "userAccountControl": str(UF_NORMAL_ACCOUNT),
            "unicodePwd": utf16pw
        })

        self.user_creds = Credentials()
        self.user_creds.guess(self.get_loadparm())
        self.user_creds.set_password(self.user_pass)
        self.user_creds.set_username(self.user_name)
        self.user_creds.set_workstation(self.machine_name)

    #
    # Get the authenticator from the machine creds.
    def get_authenticator(self, c):
        """Build the authenticator pair for the next netlogon call.

        `c` is accepted so callers can pass their connection, but it is not
        used here.

        Returns a tuple (current, subsequent): `current` is a
        netr_Authenticator filled from a new client authenticator of
        self.machine_creds, `subsequent` is a default-constructed
        netr_Authenticator.
        """
        auth = self.machine_creds.new_client_authenticator()
        current = netr_Authenticator()
        # Items of a Python 3 bytes object are already integers; ord() is
        # only needed when the credential iterates as one-character strings.
        current.cred.data = [x if isinstance(x, int) else ord(x)
                             for x in auth["credential"]]
        current.timestamp = auth["timestamp"]

        subsequent = netr_Authenticator()
        return (current, subsequent)

    def do_NetrLogonSamLogonWithFlags(self, c, current, subsequent):
        """Issue netr_LogonSamLogonWithFlags for the test user over `c`.

        `current` and `subsequent` are the authenticator pair passed through
        to the call unchanged.
        """
        logon_info = samlogon_logon_info(self.domain, self.machine_name,
                                         self.user_creds)
        workstation = self.user_creds.get_workstation()
        no_flags = 0

        c.netr_LogonSamLogonWithFlags(
            self.server,
            workstation,
            current,
            subsequent,
            netlogon.NetlogonNetworkTransitiveInformation,
            logon_info,
            netlogon.NetlogonValidationSamInfo4,
            no_flags)

    def do_NetrLogonGetDomainInfo(self, c, current, subsequent):
        """Issue netr_LogonGetDomainInfo over `c` with an empty query.

        `current` and `subsequent` are the authenticator pair passed through
        to the call unchanged.
        """
        workstation_info = netr_WorkstationInformation()
        workstation = self.user_creds.get_workstation()
        c.netr_LogonGetDomainInfo(self.server, workstation, current,
                                  subsequent, 2, workstation_info)
Exemple #14
0
class SiteCoverageTests(samba.tests.TestCase):

    def setUp(self):
        """Open the SamDB connection and reset the site bookkeeping.

        self.sites and self.site_links map the DNs created by the _add_*
        helpers to their names so that tearDown can delete them again.
        """
        # Run the base-class fixture setup first, as the other test classes
        # in this file do.
        super(SiteCoverageTests, self).setUp()

        self.prefix = "kcc_"
        self.lp = samba.tests.env_loadparm()

        self.sites = {}
        self.site_links = {}

        self.creds = Credentials()
        self.creds.guess(self.lp)
        self.session = system_session()

        self.samdb = SamDB(session_info=self.session,
                           credentials=self.creds,
                           lp=self.lp)

    def tearDown(self):
        """Delete every site and site link recorded during the test.

        Sites are removed with the tree_delete control; all deletions run
        inside a single transaction.
        """
        db = self.samdb
        db.transaction_start()

        for site_dn in self.sites:
            delete_force(db, site_dn, controls=['tree_delete:1'])
        for link_dn in self.site_links:
            delete_force(db, link_dn)

        db.transaction_commit()

    def _add_server(self, name, site):
        """Add a server object `name` below `site`'s Servers container.

        `site` is the site DN.  Returns the DN of the new server object.
        """
        server_dn = "CN={},CN=Servers,{}".format(name, site)
        record = {
            "dn": server_dn,
            "objectClass": "server",
            "serverReference": self.samdb.domain_dn()
        }
        self.samdb.add(record)
        return server_dn

    def _add_site(self, name):
        """Add a site `name` together with its Servers container.

        The site is recorded in self.sites for cleanup.  Returns a tuple of
        the site DN and the lower-cased site name.
        """
        config_dn = self.samdb.get_config_basedn()
        site_dn = "CN={},CN=Sites,{}".format(name, config_dn)

        self.samdb.add({"dn": site_dn, "objectClass": "site"})
        self.samdb.add({
            "dn": "CN=Servers," + site_dn,
            "objectClass": ["serversContainer"]
        })

        self.sites[site_dn] = name
        return site_dn, name.lower()

    def _add_site_link(self, name, links=None, cost=100):
        """Add a siteLink object `name` under the IP inter-site transport.

        `links` is the list of site DNs stored in siteList (defaults to an
        empty list) and `cost` is stored as the link's cost.  The link is
        recorded in self.site_links for cleanup.  Returns the link DN.
        """
        # A fresh list per call: a mutable default would be shared between
        # all calls that omit `links`.
        if links is None:
            links = []

        dn = "CN={},CN=IP,CN=Inter-Site Transports,CN=Sites,{}".format(
            name, self.samdb.get_config_basedn()
        )
        self.samdb.add({
            "dn": dn,
            "objectClass": "siteLink",
            "cost": str(cost),
            "siteList": links
        })
        self.site_links[dn] = name
        return dn

    def test_single_site_link_same_dc_count(self):
        prefix = self.prefix

        self.samdb.transaction_start()
        site1, name1 = self._add_site(prefix + "ABCD")
        site2, name2 = self._add_site(prefix + "BCDE")
        uncovered_dn, uncovered = self._add_site(prefix + "uncovered")

        # One server in each of the two populated sites.
        self._add_server(prefix + "ABCD1", site1)
        self._add_server(prefix + "BCDE1", site2)

        self._add_site_link(prefix + "link", [site1, site2, uncovered_dn])
        self.samdb.transaction_commit()

        # (site name, sites it is expected to cover), checked in this order.
        expectations = [(name1, [uncovered]), (name2, [])]
        for site_name, expected in expectations:
            covering = sorted(uncovered_sites_to_cover(self.samdb, site_name))
            self.assertEqual(expected, covering)

    def test_single_site_link_different_dc_count(self):
        prefix = self.prefix

        self.samdb.transaction_start()
        site1, name1 = self._add_site(prefix + "ABCD")
        site2, name2 = self._add_site(prefix + "BCDE")
        uncovered_dn, uncovered = self._add_site(prefix + "uncovered")

        # Two servers in the first site, three in the second.
        servers = [("ABCD1", site1), ("ABCD2", site1),
                   ("BCDE1", site2), ("BCDE2", site2), ("BCDE3", site2)]
        for server_name, site_dn in servers:
            self._add_server(prefix + server_name, site_dn)

        self._add_site_link(prefix + "link", [site1, site2, uncovered_dn])
        self.samdb.transaction_commit()

        # (site name, sites it is expected to cover), checked in this order.
        expectations = [(name1, []), (name2, [uncovered])]
        for site_name, expected in expectations:
            covering = sorted(uncovered_sites_to_cover(self.samdb, site_name))
            self.assertEqual(expected, covering)

    def test_two_site_links_same_cost(self):
        prefix = self.prefix

        self.samdb.transaction_start()
        site1, name1 = self._add_site(prefix + "ABCD")
        site2, name2 = self._add_site(prefix + "BCDE")
        uncovered_dn, uncovered = self._add_site(prefix + "uncovered")

        # Two servers in the first site, three in the second.
        servers = [("ABCD1", site1), ("ABCD2", site1),
                   ("BCDE1", site2), ("BCDE2", site2), ("BCDE3", site2)]
        for server_name, site_dn in servers:
            self._add_server(prefix + server_name, site_dn)

        # Each populated site gets its own link to the uncovered site; both
        # links use the default cost.
        self._add_site_link(prefix + "link1", [site1, uncovered_dn])
        self._add_site_link(prefix + "link2", [site2, uncovered_dn])
        self.samdb.transaction_commit()

        # Both sites are expected to cover the uncovered site.
        for site_name in (name1, name2):
            covering = sorted(uncovered_sites_to_cover(self.samdb, site_name))
            self.assertEqual([uncovered], covering)

    def test_two_site_links_different_costs(self):
        prefix = self.prefix

        self.samdb.transaction_start()
        site1, name1 = self._add_site(prefix + "ABCD")
        site2, name2 = self._add_site(prefix + "BCDE")
        uncovered_dn, uncovered = self._add_site(prefix + "uncovered")

        # One server in the first site, two in the second.
        servers = [("ABCD1", site1), ("BCDE1", site2), ("BCDE2", site2)]
        for server_name, site_dn in servers:
            self._add_server(prefix + server_name, site_dn)

        # The first site's link to the uncovered site has the lower cost.
        self._add_site_link(prefix + "link1", [site1, uncovered_dn], cost=50)
        self._add_site_link(prefix + "link2", [site2, uncovered_dn], cost=75)
        self.samdb.transaction_commit()

        # (site name, sites it is expected to cover), checked in this order.
        expectations = [(name1, [uncovered]), (name2, [])]
        for site_name, expected in expectations:
            covering = sorted(uncovered_sites_to_cover(self.samdb, site_name))
            self.assertEqual(expected, covering)

    def test_three_site_links_different_costs(self):
        prefix = self.prefix

        self.samdb.transaction_start()
        site1, name1 = self._add_site(prefix + "ABCD")
        site2, name2 = self._add_site(prefix + "BCDE")
        site3, name3 = self._add_site(prefix + "CDEF")
        uncovered_dn, uncovered = self._add_site(prefix + "uncovered")

        # One server in each of the first two sites, two in the third.
        servers = [("ABCD1", site1), ("BCDE1", site2),
                   ("CDEF1", site3), ("CDEF2", site3)]
        for server_name, site_dn in servers:
            self._add_server(prefix + server_name, site_dn)

        # Three links to the uncovered site, each with a distinct cost; the
        # first site's link is the cheapest.
        self._add_site_link(prefix + "link1", [site1, uncovered_dn], cost=50)
        self._add_site_link(prefix + "link2", [site2, uncovered_dn], cost=75)
        self._add_site_link(prefix + "link3", [site3, uncovered_dn], cost=60)
        self.samdb.transaction_commit()

        # (site name, sites it is expected to cover), checked in this order.
        expectations = [(name1, [uncovered]), (name2, []), (name3, [])]
        for site_name, expected in expectations:
            covering = sorted(uncovered_sites_to_cover(self.samdb, site_name))
            self.assertEqual(expected, covering)

    # This test used to be a second definition of
    # test_three_site_links_different_costs, which replaced the earlier
    # method of that name in the class so that the earlier test never ran.
    # Two of its three links share the lowest cost, hence the new name.
    def test_three_site_links_duplicate_costs(self):
        self.samdb.transaction_start()
        site1, name1 = self._add_site(self.prefix + "ABCD")
        site2, name2 = self._add_site(self.prefix + "BCDE")
        site3, name3 = self._add_site(self.prefix + "CDEF")

        uncovered_dn, uncovered = self._add_site(self.prefix + "uncovered")

        # One server in each of the first two sites, two in the third.
        self._add_server(self.prefix + "ABCD" + '1', site1)
        self._add_server(self.prefix + "BCDE" + '1', site2)
        self._add_server(self.prefix + "CDEF" + '1', site3)
        self._add_server(self.prefix + "CDEF" + '2', site3)

        # link1 and link3 tie on the lowest cost.
        self._add_site_link(self.prefix + "link1",
                            [site1, uncovered_dn],
                            cost=50)
        self._add_site_link(self.prefix + "link2",
                            [site2, uncovered_dn],
                            cost=75)
        self._add_site_link(self.prefix + "link3",
                            [site3, uncovered_dn],
                            cost=50)
        self.samdb.transaction_commit()

        to_cover = uncovered_sites_to_cover(self.samdb, name1)
        to_cover.sort()

        self.assertEqual([uncovered], to_cover)

        to_cover = uncovered_sites_to_cover(self.samdb, name2)
        to_cover.sort()

        self.assertEqual([], to_cover)

        to_cover = uncovered_sites_to_cover(self.samdb, name3)
        to_cover.sort()

        self.assertEqual([uncovered], to_cover)

    def test_complex_setup_with_multiple_uncovered_sites(self):
        prefix = self.prefix

        self.samdb.transaction_start()
        site1, name1 = self._add_site(prefix + "ABCD")
        site2, name2 = self._add_site(prefix + "BCDE")
        site3, name3 = self._add_site(prefix + "CDEF")

        site4, name4 = self._add_site(prefix + "1234")
        site5, name5 = self._add_site(prefix + "2345")
        site6, name6 = self._add_site(prefix + "3456")

        uncovered_dn1, uncovered1 = self._add_site(prefix + "uncovered1")
        uncovered_dn2, uncovered2 = self._add_site(prefix + "uncovered2")
        uncovered_dn3, uncovered3 = self._add_site(prefix + "uncovered3")

        servers = [
            # Site link cluster 1: one, two and three servers.
            ("ABCD1", site1),
            ("BCDE1", site2), ("BCDE2", site2),
            ("CDEF1", site3), ("CDEF2", site3), ("CDEF3", site3),
            # Site link cluster 2: two, two and one servers.
            ("12341", site4), ("12342", site4),
            ("23451", site5), ("23452", site5),
            ("34561", site6),
        ]
        for server_name, site_dn in servers:
            self._add_server(prefix + server_name, site_dn)

        cluster1 = [site1, site2, site3]
        cluster2 = [site4, site5, site6]
        site_links = [
            # uncovered1: cluster 1 has the cheaper link
            ("link1A", cluster1 + [uncovered_dn1], 49),
            ("link2A", cluster2 + [uncovered_dn1], 50),
            # uncovered2: both links cost the same
            ("link1B", cluster1 + [uncovered_dn2], 50),
            ("link2B", cluster2 + [uncovered_dn2], 50),
            # uncovered3: cluster 2 has the cheaper link
            ("link1C", cluster1 + [uncovered_dn3], 50),
            ("link2C", cluster2 + [uncovered_dn3], 49),
        ]
        for link_name, members, link_cost in site_links:
            self._add_site_link(prefix + link_name, members, cost=link_cost)

        self.samdb.transaction_commit()

        # (site name, sites it is expected to cover), checked in this order.
        # The uncovered sites themselves are expected to cover nothing.
        expectations = [
            (name1, []),
            (name2, []),
            (name3, [uncovered1, uncovered2]),
            (name4, [uncovered2, uncovered3]),
            (name5, []),
            (name6, []),
            (uncovered1, []),
            (uncovered2, []),
            (uncovered3, []),
        ]
        for site_name, expected in expectations:
            covering = sorted(uncovered_sites_to_cover(self.samdb, site_name))
            self.assertEqual(expected, covering)
Exemple #15
0
class PasswordTests(PasswordTestCase):

    def setUp(self):
        super(PasswordTests, self).setUp()
        self.ldb = SamDB(url=host, session_info=system_session(lp), credentials=creds, lp=lp)

        # Gets back the basedn
        base_dn = self.ldb.domain_dn()

        # Gets back the configuration basedn
        configuration_dn = self.ldb.get_config_basedn().get_linearized()

        # permit password changes during this test
        self.allow_password_changes()

        self.base_dn = self.ldb.domain_dn()

        # (Re)adds the test user "testuser" with no password atm
        delete_force(self.ldb, "cn=testuser,cn=users," + self.base_dn)
        self.ldb.add({
             "dn": "cn=testuser,cn=users," + self.base_dn,
             "objectclass": "user",
             "sAMAccountName": "testuser"})

        # Tests a password change when we don't have any password yet with a
        # wrong old password
        try:
            self.ldb.modify_ldif("""
dn: cn=testuser,cn=users,""" + self.base_dn + """
changetype: modify
delete: userPassword
userPassword: noPassword
add: userPassword
userPassword: thatsAcomplPASS2
""")
            self.fail()
        except LdbError as e:
            (num, msg) = e.args
            self.assertEquals(num, ERR_CONSTRAINT_VIOLATION)
            # Windows (2008 at least) seems to have some small bug here: it
            # returns "0000056A" on longer (always wrong) previous passwords.
            self.assertTrue('00000056' in msg)

        # Sets the initial user password with a "special" password change
        # I think that this internally is a password set operation and it can
        # only be performed by someone which has password set privileges on the
        # account (at least in s4 we do handle it like that).
        self.ldb.modify_ldif("""
dn: cn=testuser,cn=users,""" + self.base_dn + """
changetype: modify
delete: userPassword
add: userPassword
userPassword: thatsAcomplPASS1
""")

        # But in the other way around this special syntax doesn't work
        try:
            self.ldb.modify_ldif("""
dn: cn=testuser,cn=users,""" + self.base_dn + """
changetype: modify
delete: userPassword
userPassword: thatsAcomplPASS1
add: userPassword
""")
            self.fail()
        except LdbError as e1:
            (num, _) = e1.args
            self.assertEquals(num, ERR_CONSTRAINT_VIOLATION)

        # Enables the user account
        self.ldb.enable_account("(sAMAccountName=testuser)")

        # Open a second LDB connection with the user credentials. Use the
        # command line credentials for informations like the domain, the realm
        # and the workstation.
        creds2 = Credentials()
        creds2.set_username("testuser")
        creds2.set_password("thatsAcomplPASS1")
        creds2.set_domain(creds.get_domain())
        creds2.set_realm(creds.get_realm())
        creds2.set_workstation(creds.get_workstation())
        creds2.set_gensec_features(creds2.get_gensec_features()
                                                          | gensec.FEATURE_SEAL)
        self.ldb2 = SamDB(url=host, credentials=creds2, lp=lp)

    def test_unicodePwd_hash_set(self):
        """Performs a password hash set operation on 'unicodePwd' which should be prevented"""
        # Notice: Direct hash password sets should never work

        m = Message()
        m.dn = Dn(self.ldb, "cn=testuser,cn=users," + self.base_dn)
        m["unicodePwd"] = MessageElement("XXXXXXXXXXXXXXXX", FLAG_MOD_REPLACE,
                                         "unicodePwd")
        try:
            self.ldb.modify(m)
            self.fail()
        except LdbError as e2:
            (num, _) = e2.args
            # assertEqual replaces the deprecated assertEquals alias.
            self.assertEqual(num, ERR_UNWILLING_TO_PERFORM)

    def test_unicodePwd_hash_change(self):
        """Performs a password hash change operation on 'unicodePwd' which should be prevented"""
        # Notice: Direct hash password changes should never work

        # Hash password changes should never work
        try:
            self.ldb2.modify_ldif("""
dn: cn=testuser,cn=users,""" + self.base_dn + """
changetype: modify
delete: unicodePwd
unicodePwd: XXXXXXXXXXXXXXXX
add: unicodePwd
unicodePwd: YYYYYYYYYYYYYYYY
""")
            self.fail()
        except LdbError as e3:
            (num, _) = e3.args
            self.assertEquals(num, ERR_CONSTRAINT_VIOLATION)

    def test_unicodePwd_clear_set(self):
        """Performs a password cleartext set operation on 'unicodePwd'"""

        # The value is the new password in double quotes, UTF-16-LE encoded.
        user_dn = Dn(self.ldb, "cn=testuser,cn=users," + self.base_dn)
        new_value = "\"thatsAcomplPASS2\"".encode('utf-16-le')

        request = Message()
        request.dn = user_dn
        request["unicodePwd"] = MessageElement(new_value, FLAG_MOD_REPLACE,
                                               "unicodePwd")
        self.ldb.modify(request)

    def test_unicodePwd_clear_change(self):
        """Performs a password cleartext change operation on 'unicodePwd'"""

        self.ldb2.modify_ldif("""
dn: cn=testuser,cn=users,""" + self.base_dn + """
changetype: modify
delete: unicodePwd
unicodePwd:: """ + base64.b64encode("\"thatsAcomplPASS1\"".encode('utf-16-le')).decode('utf8') + """
add: unicodePwd
unicodePwd:: """ + base64.b64encode("\"thatsAcomplPASS2\"".encode('utf-16-le')).decode('utf8') + """
""")

        # Wrong old password
        try:
            self.ldb2.modify_ldif("""
dn: cn=testuser,cn=users,""" + self.base_dn + """
changetype: modify
delete: unicodePwd
unicodePwd:: """ + base64.b64encode("\"thatsAcomplPASS3\"".encode('utf-16-le')).decode('utf8') + """
add: unicodePwd
unicodePwd:: """ + base64.b64encode("\"thatsAcomplPASS4\"".encode('utf-16-le')).decode('utf8') + """
""")
            self.fail()
        except LdbError as e4:
            (num, msg) = e4.args
            self.assertEquals(num, ERR_CONSTRAINT_VIOLATION)
            self.assertTrue('00000056' in msg)

        # A change to the same password again will not work (password history)
        try:
            self.ldb2.modify_ldif("""
dn: cn=testuser,cn=users,""" + self.base_dn + """
changetype: modify
delete: unicodePwd
unicodePwd:: """ + base64.b64encode("\"thatsAcomplPASS2\"".encode('utf-16-le')).decode('utf8') + """
add: unicodePwd
unicodePwd:: """ + base64.b64encode("\"thatsAcomplPASS2\"".encode('utf-16-le')).decode('utf8') + """
""")
            self.fail()
        except LdbError as e5:
            (num, msg) = e5.args
            self.assertEquals(num, ERR_CONSTRAINT_VIOLATION)
            self.assertTrue('0000052D' in msg)

    def test_dBCSPwd_hash_set(self):
        """Performs a password hash set operation on 'dBCSPwd' which should be prevented"""
        # Notice: Direct hash password sets should never work

        m = Message()
        m.dn = Dn(self.ldb, "cn=testuser,cn=users," + self.base_dn)
        m["dBCSPwd"] = MessageElement("XXXXXXXXXXXXXXXX", FLAG_MOD_REPLACE,
                                      "dBCSPwd")
        try:
            self.ldb.modify(m)
            self.fail()
        except LdbError as e6:
            (num, _) = e6.args
            # assertEqual replaces the deprecated assertEquals alias.
            self.assertEqual(num, ERR_UNWILLING_TO_PERFORM)

    def test_dBCSPwd_hash_change(self):
        """Performs a password hash change operation on 'dBCSPwd' which should be prevented"""
        # Notice: Direct hash password changes should never work

        try:
            self.ldb2.modify_ldif("""
dn: cn=testuser,cn=users,""" + self.base_dn + """
changetype: modify
delete: dBCSPwd
dBCSPwd: XXXXXXXXXXXXXXXX
add: dBCSPwd
dBCSPwd: YYYYYYYYYYYYYYYY
""")
            self.fail()
        except LdbError as e7:
            (num, _) = e7.args
            self.assertEquals(num, ERR_UNWILLING_TO_PERFORM)

    def test_userPassword_clear_set(self):
        """Performs a password cleartext set operation on 'userPassword'"""
        # Notice: This works only against Windows if "dSHeuristics" has been set
        # properly

        user_dn = Dn(self.ldb, "cn=testuser,cn=users," + self.base_dn)

        request = Message()
        request.dn = user_dn
        request["userPassword"] = MessageElement("thatsAcomplPASS2",
                                                 FLAG_MOD_REPLACE,
                                                 "userPassword")
        self.ldb.modify(request)

    def test_userPassword_clear_change(self):
        """Performs a password cleartext change operation on 'userPassword'"""
        # Notice: This works only against Windows if "dSHeuristics" has been set
        # properly

        self.ldb2.modify_ldif("""
dn: cn=testuser,cn=users,""" + self.base_dn + """
changetype: modify
delete: userPassword
userPassword: thatsAcomplPASS1
add: userPassword
userPassword: thatsAcomplPASS2
""")

        # Wrong old password
        try:
            self.ldb2.modify_ldif("""
dn: cn=testuser,cn=users,""" + self.base_dn + """
changetype: modify
delete: userPassword
userPassword: thatsAcomplPASS3
add: userPassword
userPassword: thatsAcomplPASS4
""")
            self.fail()
        except LdbError as e8:
            (num, msg) = e8.args
            self.assertEquals(num, ERR_CONSTRAINT_VIOLATION)
            self.assertTrue('00000056' in msg)

        # A change to the same password again will not work (password history)
        try:
            self.ldb2.modify_ldif("""
dn: cn=testuser,cn=users,""" + self.base_dn + """
changetype: modify
delete: userPassword
userPassword: thatsAcomplPASS2
add: userPassword
userPassword: thatsAcomplPASS2
""")
            self.fail()
        except LdbError as e9:
            (num, msg) = e9.args
            self.assertEquals(num, ERR_CONSTRAINT_VIOLATION)
            self.assertTrue('0000052D' in msg)

    def test_clearTextPassword_clear_set(self):
        """Performs a password cleartext set operation on 'clearTextPassword'"""
        # Notice: This never works against Windows - only supported by us

        try:
            m = Message()
            m.dn = Dn(self.ldb, "cn=testuser,cn=users," + self.base_dn)
            m["clearTextPassword"] = MessageElement(
                "thatsAcomplPASS2".encode('utf-16-le'),
                FLAG_MOD_REPLACE, "clearTextPassword")
            self.ldb.modify(m)
            # this passes against s4
        except LdbError as e10:
            (num, msg) = e10.args
            # "NO_SUCH_ATTRIBUTE" is returned by Windows -> ignore it
            if num != ERR_NO_SUCH_ATTRIBUTE:
                # Bare raise keeps the original exception and its traceback
                # instead of constructing a new LdbError here.
                raise

    def test_clearTextPassword_clear_change(self):
        """Performs a password cleartext change operation on 'clearTextPassword'"""
        # Notice: This never works against Windows - only supported by us

        def encoded(password):
            # LDIF "::" values are base64; the password is encoded as
            # UTF-16-LE before being base64 encoded.
            return base64.b64encode(password.encode('utf-16-le')).decode('utf8')

        def change_ldif(old_password, new_password):
            # A change is expressed as "delete the old value, add the new one".
            return ("\ndn: cn=testuser,cn=users," + self.base_dn + "\n"
                    "changetype: modify\n"
                    "delete: clearTextPassword\n"
                    "clearTextPassword:: " + encoded(old_password) + "\n"
                    "add: clearTextPassword\n"
                    "clearTextPassword:: " + encoded(new_password) + "\n")

        def expect_rejected(old_password, new_password, marker):
            # The change must fail with a constraint violation whose message
            # contains `marker`.
            try:
                self.ldb2.modify_ldif(change_ldif(old_password, new_password))
            except LdbError as err:
                (num, msg) = err.args
                # "NO_SUCH_ATTRIBUTE" is returned by Windows -> ignore it
                if num != ERR_NO_SUCH_ATTRIBUTE:
                    self.assertEqual(num, ERR_CONSTRAINT_VIOLATION)
                    self.assertIn(marker, msg)
            else:
                self.fail()

        try:
            self.ldb2.modify_ldif(
                change_ldif("thatsAcomplPASS1", "thatsAcomplPASS2"))
            # this passes against s4
        except LdbError as err:
            (num, _) = err.args
            # "NO_SUCH_ATTRIBUTE" is returned by Windows -> ignore it
            if num != ERR_NO_SUCH_ATTRIBUTE:
                # Bare raise keeps the original exception and its traceback.
                raise

        # Wrong old password
        expect_rejected("thatsAcomplPASS3", "thatsAcomplPASS4", '00000056')

        # A change to the same password again will not work (password history)
        expect_rejected("thatsAcomplPASS2", "thatsAcomplPASS2", '0000052D')

    def test_failures(self):
        """Performs some failure testing

        Sends a series of malformed or disallowed 'userPassword' modify
        requests over both self.ldb and self.ldb2 and checks the LDAP error
        code each connection reports, followed by a few requests that are
        expected to succeed.
        """

        # Every request below targets the same user with "changetype: modify".
        header = ("\ndn: cn=testuser,cn=users," + self.base_dn + "\n"
                  "changetype: modify\n")

        def expect_error(conn, changes, expected):
            # The modify must be rejected with exactly the expected code.
            try:
                conn.modify_ldif(header + changes)
            except LdbError as err:
                (num, _) = err.args
                self.assertEqual(num, expected)
            else:
                self.fail()

        # (modification body, error expected via self.ldb,
        #  error expected via self.ldb2)
        rejected = [
            ("delete: userPassword\n"
             "userPassword: thatsAcomplPASS1\n",
             ERR_CONSTRAINT_VIOLATION, ERR_CONSTRAINT_VIOLATION),

            ("delete: userPassword\n",
             ERR_CONSTRAINT_VIOLATION, ERR_CONSTRAINT_VIOLATION),

            ("add: userPassword\n"
             "userPassword: thatsAcomplPASS1\n",
             ERR_UNWILLING_TO_PERFORM, ERR_INSUFFICIENT_ACCESS_RIGHTS),

            ("delete: userPassword\n"
             "userPassword: thatsAcomplPASS1\n"
             "add: userPassword\n"
             "userPassword: thatsAcomplPASS2\n"
             "userPassword: thatsAcomplPASS2\n",
             ERR_CONSTRAINT_VIOLATION, ERR_CONSTRAINT_VIOLATION),

            ("delete: userPassword\n"
             "userPassword: thatsAcomplPASS1\n"
             "userPassword: thatsAcomplPASS1\n"
             "add: userPassword\n"
             "userPassword: thatsAcomplPASS2\n",
             ERR_CONSTRAINT_VIOLATION, ERR_CONSTRAINT_VIOLATION),

            ("delete: userPassword\n"
             "userPassword: thatsAcomplPASS1\n"
             "add: userPassword\n"
             "userPassword: thatsAcomplPASS2\n"
             "add: userPassword\n"
             "userPassword: thatsAcomplPASS2\n",
             ERR_UNWILLING_TO_PERFORM, ERR_INSUFFICIENT_ACCESS_RIGHTS),

            ("delete: userPassword\n"
             "userPassword: thatsAcomplPASS1\n"
             "delete: userPassword\n"
             "userPassword: thatsAcomplPASS1\n"
             "add: userPassword\n"
             "userPassword: thatsAcomplPASS2\n",
             ERR_UNWILLING_TO_PERFORM, ERR_INSUFFICIENT_ACCESS_RIGHTS),

            ("delete: userPassword\n"
             "userPassword: thatsAcomplPASS1\n"
             "add: userPassword\n"
             "userPassword: thatsAcomplPASS2\n"
             "replace: userPassword\n"
             "userPassword: thatsAcomplPASS3\n",
             ERR_UNWILLING_TO_PERFORM, ERR_INSUFFICIENT_ACCESS_RIGHTS),
        ]

        # Order matters: each body is tried on self.ldb first, then self.ldb2.
        for changes, ldb_error, ldb2_error in rejected:
            expect_error(self.ldb, changes, ldb_error)
            expect_error(self.ldb2, changes, ldb2_error)

        # Reverse order does work
        self.ldb2.modify_ldif(
            header +
            "add: userPassword\n"
            "userPassword: thatsAcomplPASS2\n"
            "delete: userPassword\n"
            "userPassword: thatsAcomplPASS1\n")

        # unicodePwd values are the password wrapped in double quotes,
        # UTF-16-LE encoded and then base64 encoded for LDIF.
        quoted_pass3 = base64.b64encode(
            "\"thatsAcomplPASS3\"".encode('utf-16-le')).decode('utf8')

        try:
            self.ldb2.modify_ldif(
                header +
                "delete: userPassword\n"
                "userPassword: thatsAcomplPASS2\n"
                "add: unicodePwd\n"
                "unicodePwd:: " + quoted_pass3 + "\n")
            # this passes against s4
        except LdbError as err:
            (num, _) = err.args
            self.assertEqual(num, ERR_ATTRIBUTE_OR_VALUE_EXISTS)

        try:
            self.ldb2.modify_ldif(
                header +
                "delete: unicodePwd\n"
                "unicodePwd:: " + quoted_pass3 + "\n"
                "add: userPassword\n"
                "userPassword: thatsAcomplPASS4\n")
            # this passes against s4
        except LdbError as err:
            (num, _) = err.args
            self.assertEqual(num, ERR_NO_SUCH_ATTRIBUTE)

        # Several password changes at once are allowed
        self.ldb.modify_ldif(
            header +
            "replace: userPassword\n"
            "userPassword: thatsAcomplPASS1\n"
            "userPassword: thatsAcomplPASS2\n")

        # Several password changes at once are allowed
        self.ldb.modify_ldif(
            header +
            "replace: userPassword\n"
            "userPassword: thatsAcomplPASS1\n"
            "userPassword: thatsAcomplPASS2\n"
            "replace: userPassword\n"
            "userPassword: thatsAcomplPASS3\n"
            "replace: userPassword\n"
            "userPassword: thatsAcomplPASS4\n")

        # These surprisingly should work: adding a user with several
        # (different, then identical) userPassword values.
        user2_dn = "cn=testuser2,cn=users," + self.base_dn
        for passwords in (["thatsAcomplPASS1", "thatsAcomplPASS2"],
                          ["thatsAcomplPASS1", "thatsAcomplPASS1"]):
            delete_force(self.ldb, user2_dn)
            self.ldb.add({
                "dn": user2_dn,
                "objectclass": "user",
                "userPassword": passwords})

    def test_empty_passwords(self):
        """Check that empty password attribute values are rejected.

        Each password attribute is tried with an empty value list, first
        while adding a new user and then in add/replace/delete modifies of
        the existing test user; every request must fail with one of the
        listed LDAP error codes.
        """
        print("Performs some empty passwords testing")

        user_dn = "cn=testuser,cn=users," + self.base_dn
        user2_dn = "cn=testuser2,cn=users," + self.base_dn

        def expect_add_rejected(attr, allowed):
            # Adding a user carrying an empty `attr` must fail.
            try:
                self.ldb.add({
                    "dn": user2_dn,
                    "objectclass": "user",
                    attr: []})
            except LdbError as err:
                (num, _) = err.args
                self.assertIn(num, allowed,
                              "unexpected error %s adding empty %s" %
                              (num, attr))
            else:
                self.fail("add with empty %s unexpectedly succeeded" % attr)

        def expect_modify_rejected(flag, attr, allowed):
            # Modifying the test user with an empty `attr` must fail.
            m = Message()
            m.dn = Dn(self.ldb, user_dn)
            m[attr] = MessageElement([], flag, attr)
            try:
                self.ldb.modify(m)
            except LdbError as err:
                (num, _) = err.args
                self.assertIn(num, allowed,
                              "unexpected error %s modifying empty %s" %
                              (num, attr))
            else:
                self.fail("modify with empty %s unexpectedly succeeded" % attr)

        # ERR_NO_SUCH_ATTRIBUTE is additionally accepted for
        # clearTextPassword (for Windows).
        add_cases = [
            ("unicodePwd", (ERR_CONSTRAINT_VIOLATION,)),
            ("dBCSPwd", (ERR_CONSTRAINT_VIOLATION,)),
            ("userPassword", (ERR_CONSTRAINT_VIOLATION,)),
            ("clearTextPassword", (ERR_CONSTRAINT_VIOLATION,
                                   ERR_NO_SUCH_ATTRIBUTE)),
        ]
        for attr, allowed in add_cases:
            expect_add_rejected(attr, allowed)

        delete_force(self.ldb, user2_dn)

        modify_cases = [
            (FLAG_MOD_ADD, "unicodePwd", (ERR_CONSTRAINT_VIOLATION,)),
            (FLAG_MOD_ADD, "dBCSPwd", (ERR_CONSTRAINT_VIOLATION,)),
            (FLAG_MOD_ADD, "userPassword", (ERR_CONSTRAINT_VIOLATION,)),
            (FLAG_MOD_ADD, "clearTextPassword", (ERR_CONSTRAINT_VIOLATION,
                                                 ERR_NO_SUCH_ATTRIBUTE)),

            (FLAG_MOD_REPLACE, "unicodePwd", (ERR_UNWILLING_TO_PERFORM,)),
            (FLAG_MOD_REPLACE, "dBCSPwd", (ERR_UNWILLING_TO_PERFORM,)),
            (FLAG_MOD_REPLACE, "userPassword", (ERR_UNWILLING_TO_PERFORM,)),
            (FLAG_MOD_REPLACE, "clearTextPassword", (ERR_UNWILLING_TO_PERFORM,
                                                     ERR_NO_SUCH_ATTRIBUTE)),

            (FLAG_MOD_DELETE, "unicodePwd", (ERR_UNWILLING_TO_PERFORM,)),
            (FLAG_MOD_DELETE, "dBCSPwd", (ERR_UNWILLING_TO_PERFORM,)),
            (FLAG_MOD_DELETE, "userPassword", (ERR_CONSTRAINT_VIOLATION,)),
            (FLAG_MOD_DELETE, "clearTextPassword", (ERR_CONSTRAINT_VIOLATION,
                                                    ERR_NO_SUCH_ATTRIBUTE)),
        ]
        for flag, attr, allowed in modify_cases:
            expect_modify_rejected(flag, attr, allowed)

    def test_plain_userPassword(self):
        """Check that 'userPassword' can be written and read back as an
        ordinary attribute under several "dSHeuristics" settings.

        The "dSHeuristics" value is always set back to "000000001" at the
        end, even when an assertion fails part-way through.
        """
        print("Performs testing about the standard 'userPassword' behaviour")

        user_dn = "cn=testuser,cn=users," + self.base_dn

        def modify_user_password(values, flag):
            # Apply one userPassword modification to the test user.
            m = Message()
            m.dn = Dn(self.ldb, user_dn)
            m["userPassword"] = MessageElement(values, flag, "userPassword")
            self.ldb.modify(m)

        def fetch_user():
            # Read the test user back, requesting only userPassword.
            res = self.ldb.search(user_dn, scope=SCOPE_BASE,
                                  attrs=["userPassword"])
            self.assertEqual(len(res), 1)
            return res[0]

        def assert_stored(expected):
            entry = fetch_user()
            self.assertIn("userPassword", entry)
            self.assertEqual(entry["userPassword"][0], expected)

        # Delete the "dSHeuristics"
        self.ldb.set_dsheuristics(None)

        try:
            time.sleep(1)  # This switching time is strictly needed!

            modify_user_password("myPassword", FLAG_MOD_ADD)
            assert_stored("myPassword")

            modify_user_password("myPassword2", FLAG_MOD_REPLACE)
            assert_stored("myPassword2")

            modify_user_password([], FLAG_MOD_DELETE)
            self.assertNotIn("userPassword", fetch_user())

            # Set the test "dSHeuristics" to deactivate "userPassword" pwd changes
            self.ldb.set_dsheuristics("000000000")

            modify_user_password("myPassword3", FLAG_MOD_REPLACE)
            assert_stored("myPassword3")

            # Set the test "dSHeuristics" to deactivate "userPassword" pwd changes
            self.ldb.set_dsheuristics("000000002")

            modify_user_password("myPassword4", FLAG_MOD_REPLACE)
            assert_stored("myPassword4")
        finally:
            # Reset the test "dSHeuristics" (reactivate "userPassword" pwd
            # changes); done in "finally" so a failing assertion does not
            # leave the modified setting behind.
            self.ldb.set_dsheuristics("000000001")

    def test_modify_dsheuristics_userPassword(self):
        """Check reading 'userPassword' across connections that were opened
        before and after "dSHeuristics" modifications."""
        print("Performs testing about reading userPassword between dsHeuristic modifies")

        user_dn = "cn=testuser,cn=users," + self.base_dn

        def connect():
            # Open a fresh connection to the server under test.
            return SamDB(url=host, session_info=system_session(lp),
                         credentials=creds, lp=lp)

        def fetch_user(conn):
            # Read the test user over `conn`, requesting only userPassword.
            res = conn.search(user_dn, scope=SCOPE_BASE,
                              attrs=["userPassword"])
            self.assertEqual(len(res), 1)
            return res[0]

        def replace_user_password(conn, password):
            m = Message()
            m.dn = Dn(conn, user_dn)
            m["userPassword"] = MessageElement(password, FLAG_MOD_REPLACE,
                                               "userPassword")
            conn.modify(m)

        # Make sure userPassword cannot be read
        self.ldb.set_dsheuristics("000000000")

        # Open a new connection (with dsHeuristic=000000000)
        ldb1 = connect()

        # Set userPassword to be read
        # This setting only affects newer connections (ldb2)
        ldb1.set_dsheuristics("000000001")
        time.sleep(1)

        replace_user_password(ldb1, "thatsAcomplPASS1")

        # userPassword cannot be read, it wasn't set, instead the
        # password was
        self.assertNotIn("userPassword", fetch_user(ldb1))

        # Open another new connection (with dsHeuristic=000000001)
        ldb2 = connect()

        # Check on the new connection that userPassword was not stored
        # from ldb1 or is not readable
        self.assertNotIn("userPassword", fetch_user(ldb2))

        # Set userPassword to be readable
        # This setting does not affect this connection
        ldb2.set_dsheuristics("000000000")
        time.sleep(1)

        # Check that userPassword was not stored from ldb1
        self.assertNotIn("userPassword", fetch_user(ldb2))

        replace_user_password(ldb2, "thatsAcomplPASS2")

        # Check despite setting it with userPassword support disabled
        # on this connection it should still not be readable
        self.assertNotIn("userPassword", fetch_user(ldb2))

        # Only password from ldb1 is the user's password
        creds2 = Credentials()
        creds2.set_username("testuser")
        creds2.set_password("thatsAcomplPASS1")
        creds2.set_domain(creds.get_domain())
        creds2.set_realm(creds.get_realm())
        creds2.set_workstation(creds.get_workstation())
        creds2.set_gensec_features(creds2.get_gensec_features()
                                   | gensec.FEATURE_SEAL)

        try:
            SamDB(url=host, credentials=creds2, lp=lp)
        except Exception:
            # Not a bare "except:", so KeyboardInterrupt/SystemExit are not
            # converted into a test failure.
            self.fail("testuser used the wrong password")

        ldb3 = connect()

        # Check that userPassword was stored from ldb2:
        # userPassword can be read
        entry = fetch_user(ldb3)
        self.assertIn("userPassword", entry)
        self.assertEqual(entry["userPassword"][0], "thatsAcomplPASS2")

        # Reset the test "dSHeuristics" (reactivate "userPassword" pwd changes)
        self.ldb.set_dsheuristics("000000001")

        ldb4 = connect()

        # Check the userPassword that was stored from ldb2:
        # userPassword cannot be read
        self.assertNotIn("userPassword", fetch_user(ldb4))

    def test_zero_length(self):
        """Set an empty password with "minPwdLength" and "pwdProperties"
        temporarily set to "0".

        Both settings are restored to their previous values even when
        setting the password fails.
        """
        # Get the old "minPwdLength"
        minPwdLength = self.ldb.get_minPwdLength()
        # Set it temporarily to "0"
        self.ldb.set_minPwdLength("0")
        try:
            # Get the old "pwdProperties"
            pwdProperties = self.ldb.get_pwdProperties()
            # Set them temporarily to "0" (to deactivate eventually the
            # complexity)
            self.ldb.set_pwdProperties("0")
            try:
                self.ldb.setpassword("(sAMAccountName=testuser)", "")
            finally:
                # Reset the "pwdProperties" as they were before
                self.ldb.set_pwdProperties(pwdProperties)
        finally:
            # Reset the "minPwdLength" as it was before
            self.ldb.set_minPwdLength(minPwdLength)

    def test_pw_change_delete_no_value_userPassword(self):
        """Test password change with userPassword where the delete attribute doesn't have a value"""

        try:
            self.ldb2.modify_ldif("""
dn: cn=testuser,cn=users,""" + self.base_dn + """
changetype: modify
delete: userPassword
add: userPassword
userPassword: thatsAcomplPASS1
""")
        except LdbError, (num, msg):
            self.assertEquals(num, ERR_CONSTRAINT_VIOLATION)
        else:
Exemple #16
0
    def run(self, psoname, precedence, H=None, min_pwd_age=None,
            max_pwd_age=None, complexity=None, store_plaintext=None,
            history_length=None, min_pwd_length=None,
            account_lockout_duration=None, account_lockout_threshold=None,
            reset_account_lockout_after=None, credopts=None, sambaopts=None,
            versionopts=None):
        """Create a new Password Settings Object (PSO).

        Any password-policy argument left as None is filled in from the
        current domain password settings before the PSO is added.

        :param psoname: CN of the new PSO, created under pso_container(samdb)
        :param precedence: precedence of the PSO; must convert to int
        :param H: URL passed to SamDB
        :param min_pwd_age: defaults to the domain "minPwdAge" converted
            with timestamp_to_days
        :param max_pwd_age: defaults to the domain "maxPwdAge" converted
            with timestamp_to_days
        :param complexity: "on"/"off"; defaults from the domain
            "pwdProperties" DOMAIN_PASSWORD_COMPLEX bit
        :param store_plaintext: "on"/"off"; defaults from the domain
            "pwdProperties" DOMAIN_PASSWORD_STORE_CLEARTEXT bit
        :param history_length: defaults to the domain "pwdHistoryLength"
        :param min_pwd_length: defaults to the domain "minPwdLength"
        :param account_lockout_duration: defaults to the domain
            "lockoutDuration" converted with timestamp_to_mins
        :param account_lockout_threshold: defaults to the domain
            "lockoutThreshold"
        :param reset_account_lockout_after: defaults to the domain
            "lockOutObservationWindow" converted with timestamp_to_mins
        :param credopts: provides get_credentials(lp)
        :param sambaopts: provides get_loadparm()
        :param versionopts: not used by this method
        :raises CommandError: if the precedence is not numeric, the PSO
            already exists (or cannot be looked up), no password-policy
            option was given on the command line, or the PSO cannot be added
        """
        lp = sambaopts.get_loadparm()
        creds = credopts.get_credentials(lp)

        samdb = SamDB(url=H, session_info=system_session(),
                      credentials=creds, lp=lp)

        try:
            precedence = int(precedence)
        except ValueError:
            raise CommandError("The PSO's precedence should be a numerical value. Try --help")

        # sanity-check that the PSO doesn't already exist
        pso_dn = "CN=%s,%s" % (psoname, pso_container(samdb))
        try:
            samdb.search(pso_dn, scope=ldb.SCOPE_BASE)
        except ldb.LdbError as e:
            (num, msg) = e.args
            # Only "no such object" means the name is free; any other LDAP
            # error is a real failure and must not be silently ignored.
            if num != ldb.ERR_NO_SUCH_OBJECT:
                raise CommandError("Failed to look up PSO '%s': %s" %
                                   (pso_dn, msg))
        else:
            raise CommandError("PSO '%s' already exists" % psoname)

        # we expect the user to specify at least one password-policy setting,
        # otherwise there's no point in creating a PSO
        num_pwd_args = num_options_in_args(pwd_settings_options, self.raw_argv)
        if num_pwd_args == 0:
            raise CommandError("Please specify at least one password policy setting. Try --help")

        # it's unlikely that the user will specify all 9 password policy
        # settings on the CLI - current domain password-settings as the default
        # values for unspecified arguments
        if num_pwd_args < len(pwd_settings_options):
            self.message("Not all password policy options have been specified.")
            self.message("For unspecified options, the current domain password settings will be used as the default values.")

        # lookup the current domain password-settings
        res = samdb.search(samdb.domain_dn(), scope=ldb.SCOPE_BASE,
                           attrs=["pwdProperties", "pwdHistoryLength",
                                  "minPwdLength", "minPwdAge", "maxPwdAge",
                                  "lockoutDuration", "lockoutThreshold",
                                  "lockOutObservationWindow"])
        assert(len(res) == 1)
        domain = res[0]

        # use the domain settings for any missing arguments
        pwd_props = int(domain["pwdProperties"][0])
        if complexity is None:
            prop_flag = DOMAIN_PASSWORD_COMPLEX
            complexity = "on" if pwd_props & prop_flag else "off"

        if store_plaintext is None:
            prop_flag = DOMAIN_PASSWORD_STORE_CLEARTEXT
            store_plaintext = "on" if pwd_props & prop_flag else "off"

        if history_length is None:
            history_length = int(domain["pwdHistoryLength"][0])

        if min_pwd_length is None:
            min_pwd_length = int(domain["minPwdLength"][0])

        if min_pwd_age is None:
            min_pwd_age = timestamp_to_days(domain["minPwdAge"][0])

        if max_pwd_age is None:
            max_pwd_age = timestamp_to_days(domain["maxPwdAge"][0])

        if account_lockout_duration is None:
            account_lockout_duration = \
                timestamp_to_mins(domain["lockoutDuration"][0])

        if account_lockout_threshold is None:
            account_lockout_threshold = int(domain["lockoutThreshold"][0])

        if reset_account_lockout_after is None:
            reset_account_lockout_after = \
                timestamp_to_mins(domain["lockOutObservationWindow"][0])

        check_pso_constraints(max_pwd_age=max_pwd_age, min_pwd_age=min_pwd_age,
                              history_length=history_length,
                              min_pwd_length=min_pwd_length)

        # pack the settings into an LDB message
        m = make_pso_ldb_msg(self.outf, samdb, pso_dn, create=True,
                             complexity=complexity, precedence=precedence,
                             store_plaintext=store_plaintext,
                             history_length=history_length,
                             min_pwd_length=min_pwd_length,
                             min_pwd_age=min_pwd_age, max_pwd_age=max_pwd_age,
                             lockout_duration=account_lockout_duration,
                             lockout_threshold=account_lockout_threshold,
                             reset_account_lockout_after=reset_account_lockout_after)

        # create the new PSO
        try:
            samdb.add(m)
            self.message("PSO successfully created: %s" % pso_dn)
            # display the new PSO's settings
            show_pso_by_dn(self.outf, samdb, pso_dn, show_applies_to=False)
        except ldb.LdbError as e:
            (num, msg) = e.args
            if num == ldb.ERR_INSUFFICIENT_ACCESS_RIGHTS:
                raise CommandError("Administrator permissions are needed to create a PSO.")
            else:
                raise CommandError("Failed to create PSO '%s': %s" %(pso_dn, msg))
Exemple #17
0
class PassWordHashTests(TestCase):

    def setUp(self):
        """Store the result of samba.tests.env_loadparm() and run the base setUp."""
        # self.lp is assigned before the parent class's setUp() is invoked.
        self.lp = samba.tests.env_loadparm()
        super(PassWordHashTests, self).setUp()

    def set_store_cleartext(self, cleartext):
        """Set or clear DOMAIN_PASSWORD_STORE_CLEARTEXT in pwdProperties.

        :param cleartext: truthy to set the flag, falsy to clear it
        """
        # Read the current value as an integer bit mask
        current = int(self.ldb.get_pwdProperties())
        if cleartext:
            updated = current | DOMAIN_PASSWORD_STORE_CLEARTEXT
        else:
            updated = current & ~DOMAIN_PASSWORD_STORE_CLEARTEXT
        # The new mask is written back in its decimal string form
        self.ldb.set_pwdProperties(str(updated))

    # Add a user to ldb, this will exercise the password_hash code
    # and calculate the appropriate supplemental credentials
    def add_user(self, options=None, clear_text=False, ldb=None):
        """(Re)create the test user so that the password_hash code runs.

        Also records self.netbios_domain, self.dns_domain and self.base_dn
        for use by the checking helpers.

        :param options: optional iterable of (option, value) pairs that are
            set on self.lp before anything else happens
        :param clear_text: when true, switch on the domain's store-cleartext
            password property (the previous value is restored by a cleanup)
            and add UF_ENCRYPTED_TEXT_PASSWORD_ALLOWED to the new account
        :param ldb: database connection to use; when None a new SamDB
            connection is opened.  Inside this method the name shadows the
            ``ldb`` module.
        """
        # set any needed options
        if options is not None:
            for (option, value) in options:
                self.lp.set(option, value)

        if ldb is None:
            self.creds = Credentials()
            self.creds.guess(self.lp)
            self.session = system_session()
            self.ldb = SamDB(session_info=self.session,
                             credentials=self.creds,
                             lp=self.lp)
        else:
            self.ldb = ldb

        res = self.ldb.search(base=self.ldb.get_config_basedn(),
                              expression="ncName=%s" % self.ldb.get_default_basedn(),
                              attrs=["nETBIOSName"])
        self.netbios_domain = res[0]["nETBIOSName"][0]
        self.dns_domain = self.ldb.domain_dns_name()

        # permit password changes during this test
        PasswordCommon.allow_password_changes(self, self.ldb)

        self.base_dn = self.ldb.domain_dn()

        account_control = 0
        if clear_text:
            # Restore the current domain setting on exit.
            pwdProperties = self.ldb.get_pwdProperties()
            self.addCleanup(self.ldb.set_pwdProperties, pwdProperties)
            # Update the domain setting
            self.set_store_cleartext(clear_text)
            account_control |= UF_ENCRYPTED_TEXT_PASSWORD_ALLOWED

        # (Re)adds the test user USER_NAME with password USER_PASS
        # and userPrincipalName UPN
        user_dn = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        delete_force(self.ldb, user_dn)
        self.ldb.add({
            "dn": user_dn,
            "objectclass": "user",
            "sAMAccountName": USER_NAME,
            "userPassword": USER_PASS,
            "userPrincipalName": UPN,
            "userAccountControl": str(account_control)
        })

    # Get the supplemental credentials for the user under test
    def get_supplemental_creds(self):
        """Fetch and unpack the test user's supplementalCredentials.

        :return: the first value of the attribute, unpacked with ndr_unpack
            as a drsblobs.supplementalCredentialsBlob
        """
        user_dn = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        found = self.ldb.search(base=user_dn,
                                scope=ldb.SCOPE_BASE,
                                attrs=["supplementalCredentials"])
        # At least one entry has to come back for the user
        self.assertIs(True, len(found) > 0)
        blob = found[0]["supplementalCredentials"][0]
        return ndr_unpack(drsblobs.supplementalCredentialsBlob, blob)

    # Calculate and validate a Wdigest value
    def check_digest(self, user, realm, password, digest):
        """Assert that a stored WDigest value matches the calculated one.

        :param user: user name handed to calc_digest
        :param realm: realm handed to calc_digest
        :param password: password handed to calc_digest
        :param digest: the stored digest; it is converted with bytearray()
            and hex-encoded before the comparison
        """
        expected = calc_digest(user, realm, password)
        actual = binascii.hexlify(bytearray(digest))
        error = "Digest expected[%s], actual[%s], " \
                "user[%s], realm[%s], pass[%s]" % \
                (expected, actual, user, realm, password)
        # assertEqual rather than the deprecated assertEquals alias, which
        # Python 3.12 removed.
        self.assertEqual(expected, actual, error)

    # Check all of the 29 expected WDigest values
    #
    def check_wdigests(self, digests):
        """Validate all 29 WDigest hashes stored for the test user.

        :param digests: object exposing ``num_hashes`` and a ``hashes``
            sequence whose entries carry the stored digest in ``.hash``
        """
        self.assertEqual(29, digests.num_hashes)

        user = USER_NAME
        netbios = self.netbios_domain
        dns = self.dns_domain

        # DOMAIN\user style names used by hashes 18-20 and 27-29
        qualified = "%s\\%s" % (netbios, user)
        qualified_lower = "%s\\%s" % (netbios.lower(), user.lower())
        qualified_upper = "%s\\%s" % (netbios.upper(), user.upper())

        # The (user, realm) inputs for every hash, in hash order.  The
        # trailing number on each row is the 1-based hash number, to make
        # it easier to check the table against the spec and the samba-tool
        # user tests; hash n lives at index n-1 of digests.hashes.
        cases = [
            (user, netbios),                        # 1
            (user.lower(), netbios.lower()),        # 2
            (user.upper(), netbios.upper()),        # 3
            (user, netbios.upper()),                # 4
            (user, netbios.lower()),                # 5
            (user.upper(), netbios.lower()),        # 6
            (user.lower(), netbios.upper()),        # 7
            (user, dns),                            # 8
            (user.lower(), dns.lower()),            # 9
            (user.upper(), dns.upper()),            # 10
            (user, dns.upper()),                    # 11
            (user, dns.lower()),                    # 12
            (user.upper(), dns.lower()),            # 13
            (user.lower(), dns.upper()),            # 14
            (UPN, ""),                              # 15
            (UPN.lower(), ""),                      # 16
            (UPN.upper(), ""),                      # 17
            (qualified, ""),                        # 18
            (qualified_lower, ""),                  # 19
            (qualified_upper, ""),                  # 20
            (user, "Digest"),                       # 21
            (user.lower(), "Digest"),               # 22
            (user.upper(), "Digest"),               # 23
            (UPN, "Digest"),                        # 24
            (UPN.lower(), "Digest"),                # 25
            (UPN.upper(), "Digest"),                # 26
            (qualified, "Digest"),                  # 27
            (qualified_lower, "Digest"),            # 28
            (qualified_upper, "Digest"),            # 29
        ]

        # check_digest's failure message names the user and realm, which
        # identifies the row that did not match.
        for number, (name, realm) in enumerate(cases, 1):
            self.check_digest(name,
                              realm,
                              USER_PASS,
                              digests.hashes[number - 1].hash)

    def checkUserPassword(self, up, expected):
        """Validate stored userPassword hashes against the expected schemes.

        :param up: object exposing ``num_hashes`` and a ``hashes`` sequence;
            each entry has a ``scheme`` and a ``value`` made of
            "$"-separated fields
        :param expected: sequence of (tag, alg, rounds) tuples, one per
            stored hash and in the same order; ``rounds`` is None when the
            stored value carries no ``rounds=`` field
        """
        # Check we've received the correct number of hashes
        self.assertEqual(len(expected), up.num_hashes)

        for i, (tag, alg, rounds) in enumerate(expected):
            stored = up.hashes[i]
            self.assertEqual(tag, stored.scheme)

            data = stored.value.split("$")
            # Check we got the expected crypt algorithm
            self.assertEqual(alg, data[1])

            # Rebuild the second argument for crypt() from the fields of
            # the stored value; with a rounds field present the following
            # field moves one position to the right.
            if rounds is None:
                setting = "$%s$%s" % (alg, data[2])
            else:
                setting = "$%s$rounds=%d$%s" % (alg, rounds, data[3])

            # Calculate the expected hash value and compare it
            self.assertEqual(crypt.crypt(USER_PASS, setting), stored.value)

    # Check that the correct nt_hash was stored for userPassword
    def checkNtHash(self, password, nt_hash):
        """Assert that nt_hash equals the NT hash Credentials derives for password.

        :param password: the clear-text password to derive the hash from
        :param nt_hash: the stored hash; converted with bytearray() before
            the comparison
        """
        creds = Credentials()
        creds.set_anonymous()
        creds.set_password(password)
        expected = creds.get_nt_hash()
        actual = bytearray(nt_hash)
        # assertEqual rather than the deprecated assertEquals alias, which
        # Python 3.12 removed.
        self.assertEqual(expected, actual)
Exemple #18
0
class DsdbLockTestCase(SamDBTestCase):
    def test_db_lock1(self):
        """Check that a full read waits for a child's prepared transaction.

        A forked child opens its own SamDB, prepares a commit, reports
        "prepared" over a pipe, sleeps two seconds and then cancels the
        transaction.  The parent times a complete search_iterator() pass
        and expects it to take longer than 1.9 seconds.
        """
        basedn = self.samdb.get_default_basedn()
        (r1, w1) = os.pipe()

        pid = os.fork()
        if pid == 0:
            # In the child, close the main DB, re-open just one DB
            del(self.samdb)
            gc.collect()
            self.samdb = SamDB(session_info=self.session,
                               lp=self.lp)

            self.samdb.transaction_start()

            # Add and delete a user inside the transaction
            dn = "cn=test_db_lock_user,cn=users," + str(basedn)
            self.samdb.add({
                 "dn": dn,
                 "objectclass": "user",
            })
            self.samdb.delete(dn)

            # Obtain a write lock
            self.samdb.transaction_prepare_commit()
            os.write(w1, b"prepared")
            time.sleep(2)

            # Drop the write lock
            self.samdb.transaction_cancel()
            # Leave the child without returning into the test runner
            os._exit(0)

        # Parent: wait until the child reports the prepared commit
        self.assertEqual(os.read(r1, 8), b"prepared")

        start = time.time()

        # We need to hold this iterator open to hold the all-record lock.
        res = self.samdb.search_iterator()

        # This should take at least 2 seconds because the transaction
        # has a write lock on one backend db open

        # Release the locks
        for l in res:
            pass

        end = time.time()
        self.assertGreater(end - start, 1.9)

        # Reap the child and check that it exited normally with status 0
        (got_pid, status) = os.waitpid(pid, 0)
        self.assertEqual(got_pid, pid)
        self.assertTrue(os.WIFEXITED(status))
        self.assertEqual(os.WEXITSTATUS(status), 0)

    def test_db_lock2(self):
        """Check that prepare_commit waits for a child's open search.

        A forked child opens its own SamDB and keeps a search iterator
        open.  Parent and child step through a handshake over two pipes
        (start/started, add/added, prepare/prepared).  After announcing
        "prepare" the child sleeps two seconds before draining the
        iterator, and the parent expects its own
        transaction_prepare_commit() to take longer than 1.9 seconds.
        """
        basedn = self.samdb.get_default_basedn()
        (r1, w1) = os.pipe()
        (r2, w2) = os.pipe()

        pid = os.fork()
        if pid == 0:
            # In the child, close the main DB, re-open
            del(self.samdb)
            gc.collect()
            self.samdb = SamDB(session_info=self.session,
                               lp=self.lp)

            # We need to hold this iterator open to hold the all-record lock.
            res = self.samdb.search_iterator()

            os.write(w2, b"start")
            if (os.read(r1, 7) != b"started"):
                os._exit(1)

            os.write(w2, b"add")
            if (os.read(r1, 5) != b"added"):
                os._exit(2)

            # Wait 2 seconds to block prepare_commit() in the parent.
            os.write(w2, b"prepare")
            time.sleep(2)

            # Release the locks
            for _ in res:
                pass

            if (os.read(r1, 8) != b"prepared"):
                os._exit(3)

            os._exit(0)

        # We can start the transaction during the search
        # because both just grab the all-record read lock.
        self.assertEqual(os.read(r2, 5), b"start")
        self.samdb.transaction_start()
        os.write(w1, b"started")

        self.assertEqual(os.read(r2, 3), b"add")
        dn = "cn=test_db_lock_user,cn=users," + str(basedn)
        self.samdb.add({
            "dn": dn,
            "objectclass": "user",
        })
        self.samdb.delete(dn)
        os.write(w1, b"added")

        # Obtain a write lock, this will block until
        # the child releases the read lock.
        self.assertEqual(os.read(r2, 7), b"prepare")
        start = time.time()
        self.samdb.transaction_prepare_commit()
        end = time.time()
        try:
            self.assertGreater(end - start, 1.9)
        finally:
            # Whatever the timing check says, let the child finish and
            # tidy up, so it is not left waiting on the pipe.
            os.write(w1, b"prepared")

            # Drop the write lock
            self.samdb.transaction_cancel()

            (got_pid, status) = os.waitpid(pid, 0)
            self.assertEqual(got_pid, pid)
            self.assertTrue(os.WIFEXITED(status))
            self.assertEqual(os.WEXITSTATUS(status), 0)

    def test_db_lock3(self):
        """Check that prepare_commit waits for a child's open search.

        Same handshake as the user-object variant of this test, except that
        the record added and deleted in the parent's transaction is the
        special DN "@DSDB_LOCK_TEST".
        """
        # NOTE(review): basedn is not used anywhere below.
        basedn = self.samdb.get_default_basedn()
        (r1, w1) = os.pipe()
        (r2, w2) = os.pipe()

        pid = os.fork()
        if pid == 0:
            # In the child, close the main DB, re-open
            del(self.samdb)
            gc.collect()
            self.samdb = SamDB(session_info=self.session,
                               lp=self.lp)

            # We need to hold this iterator open to hold the all-record lock.
            res = self.samdb.search_iterator()

            os.write(w2, b"start")
            if (os.read(r1, 7) != b"started"):
                os._exit(1)

            os.write(w2, b"add")
            if (os.read(r1, 5) != b"added"):
                os._exit(2)

            # Wait 2 seconds to block prepare_commit() in the parent.
            os.write(w2, b"prepare")
            time.sleep(2)

            # Release the locks
            for l in res:
                pass

            if (os.read(r1, 8) != b"prepared"):
                os._exit(3)

            os._exit(0)

        # We can start the transaction during the search
        # because both just grab the all-record read lock.
        self.assertEqual(os.read(r2, 5), b"start")
        self.samdb.transaction_start()
        os.write(w1, b"started")

        self.assertEqual(os.read(r2, 3), b"add")

        # This will end up in the top level db
        dn = "@DSDB_LOCK_TEST"
        self.samdb.add({
             "dn": dn})
        self.samdb.delete(dn)
        os.write(w1, b"added")

        # Obtain a write lock, this will block until
        # the child releases the read lock.
        self.assertEqual(os.read(r2, 7), b"prepare")
        start = time.time()
        self.samdb.transaction_prepare_commit()
        end = time.time()
        self.assertGreater(end - start, 1.9)
        # Tell the child it may finish
        os.write(w1, b"prepared")

        # Drop the write lock
        self.samdb.transaction_cancel()

        # Reap the child and check that it exited normally with status 0
        (got_pid, status) = os.waitpid(pid, 0)
        self.assertTrue(os.WIFEXITED(status))
        self.assertEqual(os.WEXITSTATUS(status), 0)
        self.assertEqual(got_pid, pid)


    def _test_full_db_lock1(self, backend_path):
        """Check that opening a search waits for a backend write lock.

        A forked child opens only the file at backend_path with ldb.Ldb,
        prepares a commit on it, reports "prepared" over a pipe, sleeps two
        seconds and then cancels.  The parent times the search_iterator()
        call on its SamDB and expects it to take longer than 1.9 seconds.

        :param backend_path: path of the backend database file the child
            opens directly
        """
        (r1, w1) = os.pipe()

        pid = os.fork()
        if pid == 0:
            # In the child, close the main DB, re-open just one DB
            del(self.samdb)
            gc.collect()

            backenddb = ldb.Ldb(backend_path)


            backenddb.transaction_start()

            backenddb.add({"dn":"@DSDB_LOCK_TEST"})
            backenddb.delete("@DSDB_LOCK_TEST")

            # Obtain a write lock
            backenddb.transaction_prepare_commit()
            os.write(w1, b"prepared")
            time.sleep(2)

            # Drop the write lock
            backenddb.transaction_cancel()
            # Leave the child without returning into the test runner
            os._exit(0)

        # Parent: wait until the child reports the prepared commit
        self.assertEqual(os.read(r1, 8), b"prepared")

        start = time.time()

        # We need to hold this iterator open to hold the all-record lock.
        res = self.samdb.search_iterator()

        # This should take at least 2 seconds because the transaction
        # has a write lock on one backend db open

        # Only the search_iterator() call itself is inside the timed region
        end = time.time()
        self.assertGreater(end - start, 1.9)

        # Release the locks
        for l in res:
            pass

        # Reap the child and check that it exited normally with status 0
        (got_pid, status) = os.waitpid(pid, 0)
        self.assertEqual(got_pid, pid)
        self.assertTrue(os.WIFEXITED(status))
        self.assertEqual(os.WEXITSTATUS(status), 0)

    def test_full_db_lock1(self):
        """Run _test_full_db_lock1 on the file named after the default base DN."""
        casefolded_dn = self.samdb.get_default_basedn().get_casefold()
        # The backend file sits under sam.ldb.d in the private directory
        relative_path = os.path.join("sam.ldb.d", "%s.ldb" % casefolded_dn)
        self._test_full_db_lock1(self.lp.private_path(relative_path))


    def test_full_db_lock1_config(self):
        """Run _test_full_db_lock1 on the file named after the config base DN."""
        casefolded_dn = self.samdb.get_config_basedn().get_casefold()
        # The backend file sits under sam.ldb.d in the private directory
        relative_path = os.path.join("sam.ldb.d", "%s.ldb" % casefolded_dn)
        self._test_full_db_lock1(self.lp.private_path(relative_path))


    def _test_full_db_lock2(self, backend_path):
        """Check that a backend prepare_commit waits for a child's search.

        A forked child opens its own SamDB and keeps a search iterator
        open.  The parent drops its SamDB, opens only the file at
        backend_path with ldb.Ldb and steps through a handshake with the
        child over two pipes (start/started, add/added, prepare/prepared).
        After announcing "prepare" the child sleeps two seconds before
        draining the iterator, and the parent expects its
        transaction_prepare_commit() to take longer than 1.9 seconds.

        :param backend_path: path of the backend database file the parent
            opens directly
        """
        (r1, w1) = os.pipe()
        (r2, w2) = os.pipe()

        pid = os.fork()
        if pid == 0:

            # In the child, close the main DB, re-open
            del(self.samdb)
            gc.collect()
            self.samdb = SamDB(session_info=self.session,
                               lp=self.lp)

            # We need to hold this iterator open to hold the all-record lock.
            res = self.samdb.search_iterator()

            os.write(w2, b"start")
            if (os.read(r1, 7) != b"started"):
                os._exit(1)
            os.write(w2, b"add")
            if (os.read(r1, 5) != b"added"):
                os._exit(2)

            # Wait 2 seconds to block prepare_commit() in the parent.
            os.write(w2, b"prepare")
            time.sleep(2)

            # Release the locks
            for _ in res:
                pass

            if (os.read(r1, 8) != b"prepared"):
                os._exit(3)

            os._exit(0)

        # In the parent, close the main DB, re-open just one DB
        del(self.samdb)
        gc.collect()
        backenddb = ldb.Ldb(backend_path)

        # We can start the transaction during the search
        # because both just grab the all-record read lock.
        self.assertEqual(os.read(r2, 5), b"start")
        backenddb.transaction_start()
        os.write(w1, b"started")

        self.assertEqual(os.read(r2, 3), b"add")
        backenddb.add({"dn": "@DSDB_LOCK_TEST"})
        backenddb.delete("@DSDB_LOCK_TEST")
        os.write(w1, b"added")

        # Obtain a write lock, this will block until
        # the child releases the read lock.
        self.assertEqual(os.read(r2, 7), b"prepare")
        start = time.time()
        backenddb.transaction_prepare_commit()
        end = time.time()

        try:
            self.assertGreater(end - start, 1.9)
        finally:
            # Whatever the timing check says, let the child finish and
            # tidy up, so it is not left waiting on the pipe.
            os.write(w1, b"prepared")

            # Drop the write lock
            backenddb.transaction_cancel()

            (got_pid, status) = os.waitpid(pid, 0)
            self.assertEqual(got_pid, pid)
            self.assertTrue(os.WIFEXITED(status))
            self.assertEqual(os.WEXITSTATUS(status), 0)

    def test_full_db_lock2(self):
        """Run _test_full_db_lock2 on the file named after the default base DN."""
        casefolded_dn = self.samdb.get_default_basedn().get_casefold()
        # The backend file sits under sam.ldb.d in the private directory
        relative_path = os.path.join("sam.ldb.d", "%s.ldb" % casefolded_dn)
        self._test_full_db_lock2(self.lp.private_path(relative_path))

    def test_full_db_lock2_config(self):
        """Run _test_full_db_lock2 on the file named after the config base DN."""
        casefolded_dn = self.samdb.get_config_basedn().get_casefold()
        # The backend file sits under sam.ldb.d in the private directory
        relative_path = os.path.join("sam.ldb.d", "%s.ldb" % casefolded_dn)
        self._test_full_db_lock2(self.lp.private_path(relative_path))
Exemple #19
0
class DsdbTests(TestCase):

    def setUp(self):
        """Connect to the SamDB and create a randomly named test user.

        The new user's DN is kept in self.account_dn and its removal is
        registered as a cleanup.
        """
        super(DsdbTests, self).setUp()
        self.lp = samba.tests.env_loadparm()
        self.creds = Credentials()
        self.creds.guess(self.lp)
        self.session = system_session()
        self.samdb = SamDB(session_info=self.session,
                           credentials=self.creds,
                           lp=self.lp)

        # The user name ends in six hex digits taken from a fresh UUID
        name = "dsdb-user-" + str(uuid.uuid4().hex[0:6])
        self.account_dn = "cn=" + name + ",cn=Users," + self.samdb.domain_dn()
        self.samdb.newuser(username=name,
                           password=samba.generate_random_password(32, 32),
                           description="Test user for dsdb test")
        # Remove the user again once the test is over
        self.addCleanup(delete_force, self.samdb, self.account_dn)

    def test_get_oid_from_attrid(self):
        """get_oid_from_attid(591614) must return 1.2.840.113556.1.4.1790."""
        oid = self.samdb.get_oid_from_attid(591614)
        # assertEqual rather than the deprecated assertEquals alias, which
        # Python 3.12 removed.
        self.assertEqual(oid, "1.2.840.113556.1.4.1790")

    def test_error_replpropertymetadata(self):
        """Replacing replPropertyMetaData with only a bumped version must fail.

        The version of the entry with attid 13 (description) is incremented
        and the repacked blob is written back; the modify is expected to
        raise ldb.LdbError.
        """
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        for o in repl.ctr.array:
            # Search for Description
            if o.attid == 13:
                o.version = o.version + 1
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        self.assertRaises(ldb.LdbError, self.samdb.modify, msg, ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_error_replpropertymetadata_nochange(self):
        """Writing back an unmodified replPropertyMetaData must raise LdbError."""
        found = self.samdb.search(base=self.account_dn,
                                  scope=ldb.SCOPE_SUBTREE,
                                  attrs=["replPropertyMetaData"])
        # Round-trip the blob through unpack and pack without altering it
        unpacked = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                              found[0]["replPropertyMetaData"][0])
        repacked = ndr_pack(unpacked)

        request = ldb.Message()
        request.dn = found[0].dn
        request["replPropertyMetaData"] = ldb.MessageElement(
            repacked, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        self.assertRaises(ldb.LdbError,
                          self.samdb.modify,
                          request,
                          ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_error_replpropertymetadata_allow_sort(self):
        """Writing back replPropertyMetaData works when both controls are sent."""
        found = self.samdb.search(base=self.account_dn,
                                  scope=ldb.SCOPE_SUBTREE,
                                  attrs=["replPropertyMetaData"])
        # Round-trip the blob through unpack and pack without altering it
        unpacked = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                              found[0]["replPropertyMetaData"][0])
        repacked = ndr_pack(unpacked)

        request = ldb.Message()
        request.dn = found[0].dn
        request["replPropertyMetaData"] = ldb.MessageElement(
            repacked, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        # No exception is expected from this modify
        self.samdb.modify(request,
                          ["local_oid:1.3.6.1.4.1.7165.4.3.14:0",
                           "local_oid:1.3.6.1.4.1.7165.4.3.25:0"])

    def test_twoatt_replpropertymetadata(self):
        """Replacing replPropertyMetaData together with description must fail.

        The entry with attid 13 (description) gets its version incremented
        and its local_usn set to uSNChanged + 1; the message also replaces
        description itself.  The modify is expected to raise ldb.LdbError.
        """
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData", "uSNChanged"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        for o in repl.ctr.array:
            # Search for Description
            if o.attid == 13:
                o.version = o.version + 1
                o.local_usn = int(str(res[0]["uSNChanged"])) + 1
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        msg["description"] = ldb.MessageElement("new val", ldb.FLAG_MOD_REPLACE, "description")
        self.assertRaises(ldb.LdbError, self.samdb.modify, msg, ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_set_replpropertymetadata(self):
        """Replacing replPropertyMetaData with consistent values must succeed.

        The entry with attid 13 (description) gets its version incremented
        and both local_usn and originating_usn set to uSNChanged + 1; the
        modify is expected not to raise.
        """
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData", "uSNChanged"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        for o in repl.ctr.array:
            # Search for Description
            if o.attid == 13:
                next_usn = int(str(res[0]["uSNChanged"])) + 1
                o.version = o.version + 1
                o.local_usn = next_usn
                o.originating_usn = next_usn
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        self.samdb.modify(msg, ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_ok_get_attribute_from_attid(self):
        """get_attribute_from_attid(13) must return "description"."""
        # assertEqual rather than the deprecated assertEquals alias, which
        # Python 3.12 removed.
        self.assertEqual(self.samdb.get_attribute_from_attid(13), "description")

    def test_ko_get_attribute_from_attid(self):
        """get_attribute_from_attid(11979) must return None."""
        # assertIsNone replaces the deprecated assertEquals(..., None),
        # an alias that Python 3.12 removed.
        self.assertIsNone(self.samdb.get_attribute_from_attid(11979))

    def test_get_attribute_replmetadata_version(self):
        """The test user's unicodePwd replmetadata version must be 2."""
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["dn"])
        # assertEqual rather than the deprecated assertEquals alias, which
        # Python 3.12 removed.
        self.assertEqual(len(res), 1)
        dn = str(res[0].dn)
        self.assertEqual(self.samdb.get_attribute_replmetadata_version(dn, "unicodePwd"), 2)

    def test_set_attribute_replmetadata_version(self):
        """A replmetadata version that was set must be read back unchanged.

        The description version of the test user is raised by two and then
        fetched again.
        """
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["dn"])
        # assertEqual rather than the deprecated assertEquals alias, which
        # Python 3.12 removed.
        self.assertEqual(len(res), 1)
        dn = str(res[0].dn)
        version = self.samdb.get_attribute_replmetadata_version(dn, "description")
        self.samdb.set_attribute_replmetadata_version(dn, "description", version + 2)
        self.assertEqual(self.samdb.get_attribute_replmetadata_version(dn, "description"), version + 2)

    def test_no_error_on_invalid_control(self):
        """A search carrying the not-implemented control with flag 0 must not raise."""
        control = "local_oid:%s:0" % dsdb.DSDB_CONTROL_INVALID_NOT_IMPLEMENTED
        try:
            # Only the absence of an LdbError matters; the result is unused
            self.samdb.search(base=self.account_dn,
                              scope=ldb.SCOPE_SUBTREE,
                              attrs=["replPropertyMetaData"],
                              controls=[control])
        except ldb.LdbError:
            self.fail("Should have not raised an exception")

    def test_error_on_invalid_critical_control(self):
        """An LdbError from the flag-1 control must be the expected one.

        The search carries the not-implemented control with flag 1.  If it
        raises ldb.LdbError the error code has to be
        ERR_UNSUPPORTED_CRITICAL_EXTENSION.
        """
        # NOTE(review): a search that raises nothing also passes this test;
        # confirm whether that is intended.
        try:
            self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                              base=self.account_dn,
                              attrs=["replPropertyMetaData"],
                              controls=["local_oid:%s:1"
                                        % dsdb.DSDB_CONTROL_INVALID_NOT_IMPLEMENTED])
        except ldb.LdbError as e:
            (code, estr) = e.args
            if code != ldb.ERR_UNSUPPORTED_CRITICAL_EXTENSION:
                # Use the unpacked message: exceptions cannot be subscripted
                # on Python 3, so e[1] would raise TypeError here and hide
                # the real failure.
                self.fail("Got %s should have got ERR_UNSUPPORTED_CRITICAL_EXTENSION"
                          % estr)

    # Allocate a unique RID for use in the objectSID tests.
    #
    def allocate_rid(self):
        """Allocate a RID inside its own transaction.

        The transaction is cancelled and the error re-raised when the
        allocation fails, and committed otherwise.

        :return: the allocated RID converted with str()
        """
        db = self.samdb
        db.transaction_start()
        try:
            new_rid = db.allocate_rid()
        except:
            # Deliberately bare: the transaction is cancelled whatever was
            # raised, and the exception then continues on its way.
            db.transaction_cancel()
            raise
        else:
            db.transaction_commit()
        return str(new_rid)

    # Ensure that duplicate objectSID's are permitted for foreign security
    # principals.
    #
    def test_duplicate_objectSIDs_allowed_on_foreign_security_principals(self):
        """Add the same foreignSecurityPrincipal twice via the provision control.

        Without the control, an add lacking objectSid and an add carrying an
        explicit objectSid are both expected to be rejected.  With the
        control the object is added, deleted, added again and deleted.
        """
        #
        # Build a SID that is not in the current domain by swapping the
        # last character of the domain SID.
        #
        domain_sid_str = str(self.samdb.get_domain_sid())
        last_char = "9" if domain_sid_str.endswith("0") else "0"
        sid_str = domain_sid_str[:-1] + last_char + "-1000"
        packed_sid = ndr_pack(security.dom_sid(sid_str))
        fsp_dn = "CN=%s,CN=ForeignSecurityPrincipals,%s" % (
            sid_str, self.samdb.get_default_basedn())

        #
        # First without control: no objectSid supplied
        #
        try:
            self.samdb.add({"dn": fsp_dn,
                            "objectClass": "foreignSecurityPrincipal"})
            self.fail("No exception should get ERR_OBJECT_CLASS_VIOLATION")
        except ldb.LdbError as err:
            (err_code, err_msg) = err.args
            self.assertEqual(err_code, ldb.ERR_OBJECT_CLASS_VIOLATION, str(err))
            expected_werr = "%08X" % werror.WERR_DS_MISSING_REQUIRED_ATT
            self.assertTrue(expected_werr in err_msg, err_msg)

        #
        # Still without control: objectSid supplied explicitly
        #
        try:
            self.samdb.add({"dn": fsp_dn,
                            "objectClass": "foreignSecurityPrincipal",
                            "objectSid": packed_sid})
            self.fail("No exception should get ERR_UNWILLING_TO_PERFORM")
        except ldb.LdbError as err:
            (err_code, err_msg) = err.args
            self.assertEqual(err_code, ldb.ERR_UNWILLING_TO_PERFORM, str(err))
            expected_werr = "%08X" % werror.WERR_DS_ILLEGAL_MOD_OPERATION
            self.assertTrue(expected_werr in err_msg, err_msg)

        #
        # We need to use the provision control
        # in order to add foreignSecurityPrincipal
        # objects
        #
        provision_controls = ["provision:0"]
        self.samdb.add({"dn": fsp_dn,
                        "objectClass": "foreignSecurityPrincipal"},
                       controls=provision_controls)

        self.samdb.delete(fsp_dn)

        # A second add of the same object must work as well
        try:
            self.samdb.add({"dn": fsp_dn,
                            "objectClass": "foreignSecurityPrincipal"},
                           controls=provision_controls)
        except ldb.LdbError as err:
            (err_code, err_msg) = err.args
            self.fail("Got unexpected exception %d - %s "
                      % (err_code, err_msg))

        # cleanup
        self.samdb.delete(fsp_dn)

    def _test_foreignSecurityPrincipal(self, obj_class, fpo_attr):
        """Check how <SID=...> values are handled when added to ``fpo_attr``.

        A temporary object of class ``obj_class`` is created under cn=Users
        and three SID-form values are added to ``fpo_attr`` in turn:

        * a SID below the domain SID: must fail with
          ERR_UNWILLING_TO_PERFORM / WERR_DS_INVALID_GROUP_TYPE;
        * a SID below S-1-5-32: must fail with
          ERR_NO_SUCH_OBJECT / WERR_NO_SUCH_MEMBER;
        * S-1-5-4294967294: must succeed, after which exactly one object
          with that objectSid is found; it is deleted again.

        :param obj_class: objectClass of the temporary object
        :param fpo_attr: attribute the SID values are added to
        """
        domain_sid = self.samdb.get_domain_sid()
        lsid_str = str(domain_sid) + "-4294967294"
        bsid_str = "S-1-5-32-4294967294"
        fsid_str = "S-1-5-4294967294"
        basedn = self.samdb.get_default_basedn()
        dn_str = "cn=%s,cn=Users,%s" % ("dsdb_test_fpo", basedn)
        dn = ldb.Dn(self.samdb, dn_str)

        def objects_with_sid(sid_str):
            # Every lookup in this test is the same subtree search by SID.
            return self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                     base=basedn,
                                     expression="(objectSid=%s)" % sid_str,
                                     attrs=[])

        def sid_add_request(sid_str):
            # Build a modify request adding <SID=...> to the test object.
            request = ldb.Message()
            request.dn = dn
            request[fpo_attr] = ldb.MessageElement("<SID=%s>" % sid_str,
                                                   ldb.FLAG_MOD_ADD,
                                                   fpo_attr)
            return request

        # None of the three SIDs may exist before the test starts.
        for sid_str in (lsid_str, bsid_str, fsid_str):
            self.assertEqual(len(objects_with_sid(sid_str)), 0)

        self.addCleanup(delete_force, self.samdb, dn_str)

        self.samdb.add({"dn": dn_str, "objectClass": obj_class})

        request = sid_add_request(lsid_str)
        try:
            self.samdb.modify(request)
        except ldb.LdbError as err:
            (code, text) = err.args
            self.assertEqual(code, ldb.ERR_UNWILLING_TO_PERFORM, str(err))
            werr = "%08X" % werror.WERR_DS_INVALID_GROUP_TYPE
            self.assertTrue(werr in text, text)
        else:
            self.fail("No exception should get LDB_ERR_UNWILLING_TO_PERFORM")

        request = sid_add_request(bsid_str)
        try:
            self.samdb.modify(request)
        except ldb.LdbError as err:
            (code, text) = err.args
            self.assertEqual(code, ldb.ERR_NO_SUCH_OBJECT, str(err))
            werr = "%08X" % werror.WERR_NO_SUCH_MEMBER
            self.assertTrue(werr in text, text)
        else:
            self.fail("No exception should get LDB_ERR_NO_SUCH_OBJECT")

        request = sid_add_request(fsid_str)
        try:
            self.samdb.modify(request)
        except ldb.LdbError:
            self.fail("Should have not raised an exception")

        # The successful modify must have left exactly one object with the
        # SID behind; remove it together with the test object.
        found = objects_with_sid(fsid_str)
        self.assertEqual(len(found), 1)
        self.samdb.delete(found[0].dn)
        self.samdb.delete(dn)
        self.assertEqual(len(objects_with_sid(fsid_str)), 0)

    def test_foreignSecurityPrincipal_member(self):
        return self._test_foreignSecurityPrincipal(
                "group", "member")

    def test_foreignSecurityPrincipal_MembersForAzRole(self):
        return self._test_foreignSecurityPrincipal(
                "msDS-AzRole", "msDS-MembersForAzRole")

    def test_foreignSecurityPrincipal_NeverRevealGroup(self):
        return self._test_foreignSecurityPrincipal(
                "computer", "msDS-NeverRevealGroup")

    def test_foreignSecurityPrincipal_RevealOnDemandGroup(self):
        return self._test_foreignSecurityPrincipal(
                "computer", "msDS-RevealOnDemandGroup")

    def _test_fail_foreignSecurityPrincipal(self, obj_class, fpo_attr,
                                            msg_exp, lerr_exp, werr_exp,
                                            allow_reference=True):
        """Check that SID-form values are rejected for ``fpo_attr``.

        Two temporary objects of class ``obj_class`` are created under
        cn=Users.  Adding a ``<SID=...>`` value to ``fpo_attr`` on the first
        object must fail with ``lerr_exp`` and an error string containing
        ``werr_exp`` for each of the three test SIDs.  Finally the DN of the
        second object is added as a value: this must succeed when
        ``allow_reference`` is true, and fail with the same expected error
        otherwise.

        :param obj_class: objectClass of the two temporary objects
        :param fpo_attr: attribute that is modified
        :param msg_exp: text naming the expected error; only used in the
            failure message when no exception was raised
        :param lerr_exp: expected LDB error code
        :param werr_exp: expected WERROR value; its "%08X" rendering must
            appear in the error string
        :param allow_reference: whether a plain DN reference to the second
            object is expected to be accepted
        """
        dom_sid = self.samdb.get_domain_sid()
        lsid_str = str(dom_sid) + "-4294967294"
        bsid_str = "S-1-5-32-4294967294"
        fsid_str = "S-1-5-4294967294"
        basedn = self.samdb.get_default_basedn()
        cn1 = "dsdb_test_fpo1"
        dn1_str = "cn=%s,cn=Users,%s" % (cn1, basedn)
        dn1 = ldb.Dn(self.samdb, dn1_str)
        cn2 = "dsdb_test_fpo2"
        dn2_str = "cn=%s,cn=Users,%s" % (cn2, basedn)
        dn2 = ldb.Dn(self.samdb, dn2_str)

        def add_request(value):
            # Build a modify request adding ``value`` to fpo_attr on dn1.
            request = ldb.Message()
            request.dn = dn1
            request[fpo_attr] = ldb.MessageElement(value,
                                                   ldb.FLAG_MOD_ADD,
                                                   fpo_attr)
            return request

        def check_expected_error(e):
            # The LDB code and the WERROR embedded in the text must match.
            (code, err_str) = e.args
            self.assertEqual(code, lerr_exp, str(e))
            werr = "%08X" % werr_exp
            self.assertTrue(werr in err_str, err_str)

        # None of the three SIDs may exist before the test starts.
        for sid_str in (lsid_str, bsid_str, fsid_str):
            res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                    base=basedn,
                                    expression="(objectSid=%s)" % sid_str,
                                    attrs=[])
            self.assertEqual(len(res), 0)

        self.addCleanup(delete_force, self.samdb, dn1_str)
        self.addCleanup(delete_force, self.samdb, dn2_str)

        self.samdb.add({
            "dn": dn1_str,
            "objectClass": obj_class})

        self.samdb.add({
            "dn": dn2_str,
            "objectClass": obj_class})

        # Every SID-form value must be rejected with the expected error.
        # The failure text always names the expected error (msg_exp); the
        # original formatted the request object into it for the third SID.
        for sid_str in (lsid_str, bsid_str, fsid_str):
            request = add_request("<SID=%s>" % sid_str)
            try:
                self.samdb.modify(request)
            except ldb.LdbError as e:
                check_expected_error(e)
            else:
                self.fail("No exception should get %s" % msg_exp)

        # A plain DN reference to the second object.
        request = add_request("%s" % dn2)
        try:
            self.samdb.modify(request)
        except ldb.LdbError as e:
            if allow_reference:
                self.fail("Should have not raised an exception: %s" % e)
            check_expected_error(e)
        else:
            if not allow_reference:
                self.fail("No exception should get %s" % msg_exp)

        self.samdb.delete(dn2)
        self.samdb.delete(dn1)

    def test_foreignSecurityPrincipal_NonMembers(self):
        """msDS-NonMembers on a group must reject SID values and DN values.

        Expected error: ERR_UNWILLING_TO_PERFORM with WERR_NOT_SUPPORTED.
        """
        return self._test_fail_foreignSecurityPrincipal(
            obj_class="group",
            fpo_attr="msDS-NonMembers",
            msg_exp="LDB_ERR_UNWILLING_TO_PERFORM/WERR_NOT_SUPPORTED",
            lerr_exp=ldb.ERR_UNWILLING_TO_PERFORM,
            werr_exp=werror.WERR_NOT_SUPPORTED,
            allow_reference=False)

    def test_foreignSecurityPrincipal_HostServiceAccount(self):
        """msDS-HostServiceAccount on a computer must reject SID values.

        Expected error: ERR_CONSTRAINT_VIOLATION with
        WERR_DS_NAME_REFERENCE_INVALID; a plain DN reference is accepted.
        """
        return self._test_fail_foreignSecurityPrincipal(
            obj_class="computer",
            fpo_attr="msDS-HostServiceAccount",
            msg_exp="LDB_ERR_CONSTRAINT_VIOLATION/WERR_DS_NAME_REFERENCE_INVALID",
            lerr_exp=ldb.ERR_CONSTRAINT_VIOLATION,
            werr_exp=werror.WERR_DS_NAME_REFERENCE_INVALID)

    def test_foreignSecurityPrincipal_manager(self):
        """'manager' on a user must reject SID values.

        Expected error: ERR_CONSTRAINT_VIOLATION with
        WERR_DS_NAME_REFERENCE_INVALID; a plain DN reference is accepted.
        """
        return self._test_fail_foreignSecurityPrincipal(
            obj_class="user",
            fpo_attr="manager",
            msg_exp="LDB_ERR_CONSTRAINT_VIOLATION/WERR_DS_NAME_REFERENCE_INVALID",
            lerr_exp=ldb.ERR_CONSTRAINT_VIOLATION,
            werr_exp=werror.WERR_DS_NAME_REFERENCE_INVALID)

    #
    # Duplicate objectSIDs should not be permitted for SIDs in the local
    # domain. The test sequence is: add an object, delete it, then attempt
    # to re-add it; the re-add should fail with a constraint violation.
    #
    def test_duplicate_objectSIDs_not_allowed_on_local_objects(self):
        """Re-adding a deleted user with the same objectSID must be refused.

        A user with an explicit objectSID below the domain SID is added and
        deleted; adding it a second time with the same SID has to raise an
        LdbError carrying ERR_CONSTRAINT_VIOLATION.
        """
        domain_sid = self.samdb.get_domain_sid()
        new_rid = self.allocate_rid()
        packed_sid = ndr_pack(
            security.dom_sid("-".join([str(domain_sid), new_rid])))
        basedn = self.samdb.get_default_basedn()
        dn = "cn=%s,cn=Users,%s" % ("dsdb_test_01", basedn)

        user_record = {
            "dn": dn,
            "objectClass": "user",
            "objectSID": packed_sid,
        }

        # First add succeeds; the object is removed again straight away.
        self.samdb.add(dict(user_record))
        self.samdb.delete(dn)

        try:
            self.samdb.add(dict(user_record))
        except ldb.LdbError as err:
            (code, text) = err.args
            if code == ldb.ERR_CONSTRAINT_VIOLATION:
                return
            self.fail("Got %d - %s should have got "
                      "LDB_ERR_CONSTRAINT_VIOLATION"
                      % (code, text))
        else:
            self.fail("No exception should get LDB_ERR_CONSTRAINT_VIOLATION")

    def test_linked_vs_non_linked_reference(self):
        """Compare references to a deleted object via 'manager' and 'assistant'.

        Two users are created and one of them is deleted again.  On the
        remaining user:

        * adding the deleted user to the linked attribute 'manager' by SID or
          GUID must fail with ERR_CONSTRAINT_VIOLATION /
          WERR_DS_NAME_REFERENCE_INVALID;
        * adding and removing it on the non-linked attribute 'assistant' by
          SID or GUID must work;
        * adding a DN, SID or GUID that never existed to 'assistant' must
          fail with the same constraint violation.
        """
        basedn = self.samdb.get_default_basedn()
        kept_dn_str = "cn=reference_kept,cn=Users,%s" % (basedn)
        removed_dn_str = "cn=reference_removed,cn=Users,%s" % (basedn)
        domain_sid = self.samdb.get_domain_sid()
        none_sid_str = str(domain_sid) + "-4294967294"
        none_guid_str = "afafafaf-fafa-afaf-fafa-afafafafafaf"

        self.addCleanup(delete_force, self.samdb, kept_dn_str)
        self.addCleanup(delete_force, self.samdb, removed_dn_str)

        self.samdb.add({"dn": kept_dn_str, "objectClass": "user"})
        kept = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                 base=kept_dn_str,
                                 attrs=["objectGUID", "objectSID"])
        self.assertEqual(len(kept), 1)
        # The unpacked values of the kept object are not used further.
        ndr_unpack(misc.GUID, kept[0]["objectGUID"][0])
        ndr_unpack(security.dom_sid, kept[0]["objectSid"][0])
        kept_dn = kept[0].dn

        self.samdb.add({"dn": removed_dn_str, "objectClass": "user"})
        removed = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                    base=removed_dn_str,
                                    attrs=["objectGUID", "objectSID"])
        self.assertEqual(len(removed), 1)
        removed_guid = ndr_unpack(misc.GUID, removed[0]["objectGUID"][0])
        removed_sid = ndr_unpack(security.dom_sid,
                                 removed[0]["objectSid"][0])
        self.samdb.delete(removed_dn_str)

        def change_request(attr, value, flag):
            # Single-attribute modify request against the kept object.
            request = ldb.Message()
            request.dn = kept_dn
            request[attr] = ldb.MessageElement(value, flag, attr)
            return request

        def expect_invalid_reference(attr, value):
            # Adding ``value`` must be refused as an invalid name reference.
            request = change_request(attr, value, ldb.FLAG_MOD_ADD)
            try:
                self.samdb.modify(request)
            except ldb.LdbError as err:
                (code, text) = err.args
                self.assertEqual(code, ldb.ERR_CONSTRAINT_VIOLATION, str(err))
                werr = "%08X" % werror.WERR_DS_NAME_REFERENCE_INVALID
                self.assertTrue(werr in text, text)
            else:
                self.fail("No exception should get "
                          "LDB_ERR_CONSTRAINT_VIOLATION")

        def add_then_remove(attr, value):
            # Both the add and the matching delete are expected to succeed.
            self.samdb.modify(change_request(attr, value, ldb.FLAG_MOD_ADD))
            self.samdb.modify(change_request(attr, value,
                                             ldb.FLAG_MOD_DELETE))

        #
        # First try the linked attribute 'manager'
        # by GUID and SID
        #
        expect_invalid_reference("manager", "<SID=%s>" % removed_sid)
        expect_invalid_reference("manager", "<GUID=%s>" % removed_guid)

        #
        # Try the non-linked attribute 'assistant'
        # by GUID and SID, which should work.
        #
        add_then_remove("assistant", "<SID=%s>" % removed_sid)
        add_then_remove("assistant", "<GUID=%s>" % removed_guid)

        #
        # Finally try the non-linked attribute 'assistant'
        # but with a non-existing DN, SID and GUID
        #
        expect_invalid_reference("assistant", "CN=NoneNone,%s" % (basedn))
        expect_invalid_reference("assistant", "<SID=%s>" % none_sid_str)
        expect_invalid_reference("assistant", "<GUID=%s>" % none_guid_str)

        self.samdb.delete(kept_dn)

    def test_normalize_dn_in_domain_full(self):
        """A full DN given as a string normalizes to the same DN."""
        base = self.samdb.domain_dn()

        expected = ldb.Dn(self.samdb, "CN=Users")
        expected.add_base(base)

        # The input already ends in the domain DN, so no change is expected.
        as_text = str(expected)
        self.assertEqual(expected,
                         self.samdb.normalize_dn_in_domain(as_text))

    def test_normalize_dn_in_domain_part(self):
        """A partial DN given as a string gets the domain DN appended."""
        base = self.samdb.domain_dn()
        relative = "CN=Users"

        expected = ldb.Dn(self.samdb, relative)
        expected.add_base(base)

        # Only the relative part is passed in; the result must be complete.
        self.assertEqual(expected,
                         self.samdb.normalize_dn_in_domain(relative))

    def test_normalize_dn_in_domain_full_dn(self):
        """A full DN given as an ldb.Dn normalizes to the same DN."""
        base = self.samdb.domain_dn()

        expected = ldb.Dn(self.samdb, "CN=Users")
        expected.add_base(base)

        # The Dn object already ends in the domain DN: no change expected.
        self.assertEqual(expected,
                         self.samdb.normalize_dn_in_domain(expected))

    def test_normalize_dn_in_domain_part_dn(self):
        """A partial DN given as an ldb.Dn gets the domain DN appended."""
        base = self.samdb.domain_dn()
        relative = ldb.Dn(self.samdb, "CN=Users")

        # Build the expected full DN from the two string forms.
        expected = ldb.Dn(self.samdb, str(relative) + "," + str(base))
        self.assertEqual(expected,
                         self.samdb.normalize_dn_in_domain(relative))
class UserTests(samba.tests.TestCase):
    """Bulk add / link / unlink / delete tests for users and groups.

    Progress counters (next_user_id, next_linked_user, n_groups, ...) are
    kept on ``GlobalState`` itself rather than on the test instance, so
    every test method continues where the previous one stopped.  The
    numeric prefixes in the method names appear intended to fix the
    execution order -- NOTE(review): confirm against the test runner.
    """

    def add_if_possible(self, *args, **kwargs):
        """In these tests sometimes things are left in the database
        deliberately, so we don't worry if we fail to add them a second
        time."""
        try:
            self.ldb.add(*args, **kwargs)
        except LdbError:
            pass

    def setUp(self):
        """Connect to the server and make sure the per-process OUs exist."""
        super(UserTests, self).setUp()
        self.state = GlobalState  # the class itself, not an instance
        self.lp = lp
        self.ldb = SamDB(host, credentials=creds,
                         session_info=system_session(lp), lp=lp)
        self.base_dn = self.ldb.domain_dn()
        # The OU name contains the pid of the running test process.
        self.ou = "OU=pid%s,%s" % (os.getpid(), self.base_dn)
        self.ou_users = "OU=users,%s" % self.ou
        self.ou_groups = "OU=groups,%s" % self.ou
        self.ou_computers = "OU=computers,%s" % self.ou

        for dn in (self.ou, self.ou_users, self.ou_groups,
                   self.ou_computers):
            self.add_if_possible({
                "dn": dn,
                "objectclass": "organizationalUnit"})

    def tearDown(self):
        """Nothing is removed here; objects persist between tests."""
        super(UserTests, self).tearDown()

    def test_00_00_do_nothing(self):
        # this gives us an idea of the overhead
        pass

    def _prepare_n_groups(self, n):
        """Record ``n`` as the group count and add groups g0..g(n-1)."""
        self.state.n_groups = n
        for i in range(n):
            self.add_if_possible({
                "dn": "cn=g%d,%s" % (i, self.ou_groups),
                "objectclass": "group"})

    def _add_users(self, start, end):
        """Add users u<start>..u<end - 1> to the users OU."""
        for i in range(start, end):
            self.ldb.add({
                "dn": "cn=u%d,%s" % (i, self.ou_users),
                "objectclass": "user"})

    def _test_join(self):
        """Run samba-tool's domain join as a DC into a scratch directory.

        The scratch directory is removed afterwards, also when the join
        raises, so a failed join does not leave it behind.
        """
        tmpdir = tempfile.mkdtemp()
        try:
            # Strip a URL scheme such as "ldap://" to get the server name.
            if '://' in host:
                server = host.split('://', 1)[1]
            else:
                server = host
            cmd = cmd_sambatool.subcommands['domain'].subcommands['join']
            cmd._run("samba-tool domain join",
                     creds.get_realm(),
                     "dc", "-U%s%%%s" % (creds.get_username(),
                                         creds.get_password()),
                     '--targetdir=%s' % tmpdir,
                     '--server=%s' % server)
        finally:
            shutil.rmtree(tmpdir)

    def _test_unindexed_search(self):
        """Time 10 repetitions of each search expression in the list."""
        expressions = [
            ('(&(objectclass=user)(description='
             'Built-in account for adminstering the computer/domain))'),
            '(description=Built-in account for adminstering the computer/domain)',
            '(objectCategory=*)',
            '(samaccountname=Administrator*)'
        ]
        for expression in expressions:
            t = time.time()
            for i in range(10):
                self.ldb.search(self.ou,
                                expression=expression,
                                scope=SCOPE_SUBTREE,
                                attrs=['cn'])
            # Note: i is the last loop index, not the number of runs.
            print('%d %s took %s' % (i, expression, time.time() - t),
                  file=sys.stderr)

    def _test_indexed_search(self):
        """Time 100 repetitions of each search expression in the list."""
        expressions = ['(objectclass=group)',
                       '(samaccountname=Administrator)'
                       ]
        for expression in expressions:
            t = time.time()
            for i in range(100):
                self.ldb.search(self.ou,
                                expression=expression,
                                scope=SCOPE_SUBTREE,
                                attrs=['cn'])
            # Note: i is the last loop index, not the number of runs.
            print('%d runs %s took %s' % (i, expression, time.time() - t),
                  file=sys.stderr)

    def _test_add_many_users(self, n=BATCH_SIZE):
        """Add the next ``n`` users and advance the shared user counter."""
        s = self.state.next_user_id
        e = s + n
        self._add_users(s, e)
        self.state.next_user_id = e

    test_00_00_join_empty_dc = _test_join

    test_00_01_adding_users_1000 = _test_add_many_users
    test_00_02_adding_users_2000 = _test_add_many_users
    test_00_03_adding_users_3000 = _test_add_many_users

    test_00_10_join_unlinked_dc = _test_join
    test_00_11_unindexed_search_3k_users = _test_unindexed_search
    test_00_12_indexed_search_3k_users = _test_indexed_search

    def _link_user_and_group(self, u, g):
        """Add user number ``u`` as a member of group number ``g``."""
        m = Message()
        m.dn = Dn(self.ldb, "CN=g%d,%s" % (g, self.ou_groups))
        m["member"] = MessageElement("cn=u%d,%s" % (u, self.ou_users),
                                     FLAG_MOD_ADD, "member")
        self.ldb.modify(m)

    def _unlink_user_and_group(self, u, g):
        """Remove user number ``u`` from the members of group number ``g``."""
        # Same DN format as _add_users and _link_user_and_group.  The
        # previous literal had no conversion specifiers, so applying "%"
        # with two arguments raised TypeError before any request was sent.
        user = "cn=u%d,%s" % (u, self.ou_users)
        group = "CN=g%d,%s" % (g, self.ou_groups)
        m = Message()
        m.dn = Dn(self.ldb, group)
        m["member"] = MessageElement(user, FLAG_MOD_DELETE, "member")
        self.ldb.modify(m)

    def _test_link_many_users(self, n=BATCH_SIZE):
        """Link the next ``n`` users, user i going into group i % N_GROUPS."""
        self._prepare_n_groups(N_GROUPS)
        s = self.state.next_linked_user
        e = s + n
        for i in range(s, e):
            g = i % N_GROUPS
            self._link_user_and_group(i, g)
        self.state.next_linked_user = e

    test_01_01_link_users_1000 = _test_link_many_users
    test_01_02_link_users_2000 = _test_link_many_users
    test_01_03_link_users_3000 = _test_link_many_users

    def _test_link_many_users_offset_1(self, n=BATCH_SIZE):
        """Link the next ``n`` users into group (i + 1) % N_GROUPS."""
        s = self.state.next_relinked_user
        e = s + n
        for i in range(s, e):
            g = (i + 1) % N_GROUPS
            self._link_user_and_group(i, g)
        self.state.next_relinked_user = e

    test_02_01_link_users_again_1000 = _test_link_many_users_offset_1
    test_02_02_link_users_again_2000 = _test_link_many_users_offset_1
    test_02_03_link_users_again_3000 = _test_link_many_users_offset_1

    test_02_10_join_partially_linked_dc = _test_join
    test_02_11_unindexed_search_partially_linked_dc = _test_unindexed_search
    test_02_12_indexed_search_partially_linked_dc = _test_indexed_search

    def _test_link_many_users_3_groups(self, n=BATCH_SIZE, groups=3):
        """Link the next ``n`` users into one of the first ``groups`` groups.

        A user is skipped when the chosen group equals one of the two
        groups the earlier linking passes use for that user.
        """
        s = self.state.next_linked_user_3
        e = s + n
        self.state.next_linked_user_3 = e
        for i in range(s, e):
            g = (i + 2) % groups
            if g not in (i % N_GROUPS, (i + 1) % N_GROUPS):
                self._link_user_and_group(i, g)

    test_03_01_link_users_again_1000_few_groups = _test_link_many_users_3_groups
    test_03_02_link_users_again_2000_few_groups = _test_link_many_users_3_groups
    test_03_03_link_users_again_3000_few_groups = _test_link_many_users_3_groups

    def _test_remove_links_0(self, n=BATCH_SIZE):
        """Remove the i % N_GROUPS membership of the next ``n`` users."""
        s = self.state.next_removed_link_0
        e = s + n
        self.state.next_removed_link_0 = e
        for i in range(s, e):
            g = i % N_GROUPS
            self._unlink_user_and_group(i, g)

    test_04_01_remove_some_links_1000 = _test_remove_links_0
    test_04_02_remove_some_links_2000 = _test_remove_links_0
    test_04_03_remove_some_links_3000 = _test_remove_links_0

    # back to using _test_add_many_users
    test_05_01_adding_users_after_links_4000 = _test_add_many_users

    # reset the link count, to replace the original links
    def test_06_01_relink_users_1000(self):
        self.state.next_linked_user = 0
        self._test_link_many_users()

    test_06_02_link_users_2000 = _test_link_many_users
    test_06_03_link_users_3000 = _test_link_many_users
    test_06_04_link_users_4000 = _test_link_many_users
    test_06_05_link_users_again_4000 = _test_link_many_users_offset_1
    test_06_06_link_users_again_4000_few_groups = _test_link_many_users_3_groups

    test_07_01_adding_users_after_links_5000 = _test_add_many_users

    def _test_link_random_users_and_groups(self, n=BATCH_SIZE, groups=100):
        """Link ``n`` randomly chosen (user, group) pairs.

        An LdbError from a single link attempt is ignored.
        """
        self._prepare_n_groups(groups)
        for i in range(n):
            u = random.randrange(self.state.next_user_id)
            g = random.randrange(groups)
            try:
                self._link_user_and_group(u, g)
            except LdbError:
                pass

    test_08_01_link_random_users_100_groups = _test_link_random_users_and_groups
    test_08_02_link_random_users_100_groups = _test_link_random_users_and_groups

    test_10_01_unindexed_search_full_dc = _test_unindexed_search
    test_10_02_indexed_search_full_dc = _test_indexed_search
    test_11_02_join_full_dc = _test_join

    def test_20_01_delete_50_groups(self):
        """Delete the 50 highest-numbered groups."""
        for i in range(self.state.n_groups - 50, self.state.n_groups):
            self.ldb.delete("cn=g%d,%s" % (i, self.ou_groups))
        self.state.n_groups -= 50

    def _test_delete_many_users(self, n=BATCH_SIZE):
        """Delete the ``n`` most recently added users (fewer if not enough)."""
        e = self.state.next_user_id
        s = max(0, e - n)
        self.state.next_user_id = s
        for i in range(s, e):
            self.ldb.delete("cn=u%d,%s" % (i, self.ou_users))

    test_21_01_delete_users_5000_lightly_linked = _test_delete_many_users
    test_21_02_delete_users_4000_lightly_linked = _test_delete_many_users
    test_21_03_delete_users_3000 = _test_delete_many_users

    def test_22_01_delete_all_groups(self):
        """Delete every remaining group and reset the group counter."""
        for i in range(self.state.n_groups):
            self.ldb.delete("cn=g%d,%s" % (i, self.ou_groups))
        self.state.n_groups = 0

    test_23_01_delete_users_after_groups_2000 = _test_delete_many_users
    test_23_00_delete_users_after_groups_1000 = _test_delete_many_users

    test_24_02_join_after_cleanup = _test_join
Exemple #21
0
class VLVTests(samba.tests.TestCase):

    def create_user(self, i, n, prefix='vlvtest', suffix='', attrs=None):
        """Add one test user to the directory and record it in self.users.

        Most attribute values are derived from i and n, so different
        attributes put the same users in different orders.

        :param i: index of this user; drives most attribute values
        :param n: number combined with i to derive values (setUp passes
            the total number of users)
        :param prefix: text placed before i in the cn
        :param suffix: text placed after i in the cn
        :param attrs: optional dict of attributes that override or extend
            the generated ones
        :return: the dict that was handed to self.ldb.add()
        """
        name = "%s%d%s" % (prefix, i, suffix)
        user = {
            'cn': name,
            "objectclass": "user",
            'givenName': "abcdefghijklmnopqrstuvwxyz"[i % 26],
            "roomNumber": "%sbc" % (n - i),
            "carLicense": "后来经",
            "employeeNumber": "%s%sx" % (abs(i * (99 - i)), '\n' * (i & 255)),
            "accountExpires": "%s" % (10 ** 9 + 1000000 * i),
            "msTSExpireDate4": "19%02d0101010000.0Z" % (i % 100),
            "flags": str(i * (n - i)),
            "serialNumber": "abc %s%s%s" % ('AaBb |-/'[i & 7],
                                            ' 3z}'[i & 3],
                                            '"@'[i & 1],),
        }

        # _user_broken_attrs tests are broken due to problems outside
        # of VLV.  The dict is kept as a record of those attributes and
        # is deliberately not merged into user.
        _user_broken_attrs = {
            # Sort doesn't look past a NUL byte.
            "photo": "\x00%d" % (n - i),
            "audio": "%sn octet string %s%s ♫♬\x00lalala" % ('Aa'[i & 1],
                                                             chr(i & 255), i),
            "displayNamePrintable": "%d\x00%c" % (i, i & 255),
            "adminDisplayName": "%d\x00b" % (n-i),
            "title": "%d%sb" % (n - i, '\x00' * i),
            "comment": "Favourite colour is %d" % (n % (i + 1)),

            # Names that vary only in case. Windows returns
            # equivalent addresses in the order they were put
            # in ('a st', 'A st',...).
            "street": "%s st" % (chr(65 | (i & 14) | ((i & 1) * 32))),
        }

        if attrs is not None:
            user.update(attrs)

        user['dn'] = "cn=%s,%s" % (user['cn'], self.ou)

        if opts.skip_attr_regex:
            match = re.compile(opts.skip_attr_regex).search
            # Walk a snapshot of the keys: deleting from a dict while
            # iterating over it raises RuntimeError on Python 3.
            for k in list(user):
                if match(k):
                    del user[k]

        self.users.append(user)
        self.ldb.add(user)
        return user

    def setUp(self):
        """Connect, (re)create the test OU, add N_ELEMENTS users and sort
        the user attributes into groups by how they are expected to sort."""
        super(VLVTests, self).setUp()
        self.ldb = SamDB(host, credentials=creds,
                         session_info=system_session(lp), lp=lp)

        self.base_dn = self.ldb.domain_dn()
        self.ou = "ou=vlv,%s" % self.base_dn
        if opts.delete_in_setup:
            # Best effort: a failed delete is only reported.
            try:
                self.ldb.delete(self.ou, ['tree_delete:1'])
            except ldb.LdbError as err:
                print("tried deleting %s, got error %s" % (self.ou, err))

        ou_record = {
            "dn": self.ou,
            "objectclass": "organizationalUnit",
        }
        self.ldb.add(ou_record)

        self.users = []
        for index in range(N_ELEMENTS):
            self.create_user(index, N_ELEMENTS)

        self.binary_sorted_keys = [
            'audio',
            'photo',
            "msTSExpireDate4",
            'serialNumber',
            "displayNamePrintable",
        ]
        self.numeric_sorted_keys = ['flags', 'accountExpires']
        self.timestamp_keys = ['msTSExpireDate4']
        self.int64_keys = {'accountExpires'}

        # Whatever the first user has that is neither binary nor numeric
        # is treated as locale sorted.
        not_locale = self.binary_sorted_keys + self.numeric_sorted_keys
        self.locale_sorted_keys = [key for key in self.users[0].keys()
                                   if key not in not_locale]

        # don't try spaces, etc in cn
        self.delicate_keys = ['cn']

    def tearDown(self):
        """Run the base tearDown, then remove the test OU unless the
        delete_in_setup option leaves that to the next setUp."""
        super(VLVTests, self).tearDown()
        if opts.delete_in_setup:
            return
        self.ldb.delete(self.ou, ['tree_delete:1'])

    def get_full_list(self, attr, include_cn=False):
        """Read the whole list sorted on attr through a VLV search, so
        that the response controls carry a VLV cookie.

        Returns (values, response controls, sort control string).  With
        include_cn the values are (attr value, cn) tuples; otherwise they
        are the lower-cased attr values.
        """
        sort_control = "server_sort:1:0:%s" % attr
        middle = len(self.users) // 2
        # Ask for `middle` entries either side of position middle + 1.
        vlv_control = "vlv:1:%d:%d:%d:0" % (middle, middle, middle + 1)
        wanted = [attr, 'cn'] if include_cn else [attr]
        reply = self.ldb.search(self.ou,
                                scope=ldb.SCOPE_ONELEVEL,
                                attrs=wanted,
                                controls=[sort_control,
                                          vlv_control])
        rows = []
        for entry in reply:
            if include_cn:
                rows.append((entry[attr][0], entry['cn'][0]))
            else:
                rows.append(entry[attr][0].lower())
        return rows, reply.controls, sort_control

    def get_expected_order(self, attr, expression=None):
        """Return the first attr value of every entry one level below the
        test OU, as ordered by a server-side sort on attr (no VLV)."""
        reply = self.ldb.search(self.ou,
                                scope=ldb.SCOPE_ONELEVEL,
                                expression=expression,
                                attrs=[attr],
                                controls=["server_sort:1:0:%s" % attr])
        ordered = []
        for entry in reply:
            ordered.append(entry[attr][0])
        return ordered

    def delete_user(self, user):
        """Delete the user's entry from the directory, then drop the
        first equal record from self.users."""
        self.ldb.delete(user['dn'])
        self.users.remove(user)

    def get_gte_tests_and_order(self, attr, expression=None):
        """Build greater-than-or-equal probe values for attr and work out
        where each should land in the sorted list.

        Probe values chosen by the attribute's sort type are written to
        the directory on temporary users, the sort order including them
        is read back, and the temporary users are deleted again.

        :param attr: the attribute being sorted on
        :param expression: optional filter applied when fetching the
            expected order
        :return: (gte_order, expected_order, gte_map).  gte_order is the
            unfiltered sort order read while the probe users existed;
            expected_order is the order without them, restricted by
            expression; gte_map maps each value in gte_order to the index
            in expected_order of the first value at or after it, or to
            len(expected_order) when there is none.
        """
        expected_order = self.get_expected_order(attr, expression=expression)
        if attr in self.delicate_keys:
            gte_keys = [
                '3',
                'abc',
                '¹',
                'ŋđ¼³ŧ“«đð',
                '桑巴',
            ]
        elif attr in self.timestamp_keys:
            gte_keys = [
                '18560101010000.0Z',
                '19140103010000.0Z',
                '19560101010010.0Z',
                '19700101000000.0Z',
                '19991231211234.3Z',
                '20061111211234.0Z',
                '20390901041234.0Z',
                '25560101010000.0Z',
            ]
        elif attr not in self.numeric_sorted_keys:
            gte_keys = [
                '3',
                'abc',
                ' ',
                '!@#!@#!',
                'kōkako',
                '¹',
                'ŋđ¼³ŧ“«đð',
                '\n\t\t',
                '桑巴',
                'zzzz',
            ]
            if expected_order:
                # A value that sorts just after an existing middle one.
                gte_keys.append(expected_order[len(expected_order) // 2] + ' tail')

        else:
            # "numeric" means positive integers
            # doesn't work with -1, 3.14, ' 3', '9' * 20
            gte_keys = ['3',
                        '1' * 10,
                        '1',
                        '9' * 7,
                        '0']

            if attr in self.int64_keys:
                gte_keys += ['3' * 12, '71' * 8]

        gte_users = [self.create_user(i, N_ELEMENTS,
                                      prefix='gte',
                                      attrs={attr: x})
                     for i, x in enumerate(gte_keys)]

        gte_order = self.get_expected_order(attr)
        for user in gte_users:
            self.delete_user(user)

        # for sanity's sake
        expected_order_2 = self.get_expected_order(attr, expression=expression)
        self.assertEqual(expected_order, expected_order_2)

        # index to the first one with each value
        index_map = {}
        for i, k in enumerate(expected_order):
            index_map.setdefault(k, i)

        # Map gte tests to indexes in expected order. This will break
        # if gte_order and expected_order are differently ordered (as
        # it should).
        gte_map = {}

        # Values seen since the last one that exists in expected_order;
        # they all resolve to the next existing value's index.
        pending = []
        for k in gte_order:
            if k in index_map:
                i = index_map[k]
                gte_map[k] = i
                for earlier in pending:
                    gte_map[earlier] = i
                pending = []
            else:
                pending.append(k)

        # Anything left sorts after every expected value.
        for k in pending:
            gte_map[k] = len(expected_order)

        return gte_order, expected_order, gte_map

    def assertCorrectResults(self, results, expected_order,
                             offset, before, after):
        """Check that results is the VLV window expected around offset,
        and say as much as possible when it is not.

        :param results: the values actually returned by the search
        :param expected_order: the full sorted list, either of values or
            of tuples whose first element is the value
        :param offset: 1-based target position in expected_order
        :param before: number of entries requested before the target
        :param after: number of entries requested after the target
        """
        # offset is 1-based; the window is clamped at the list start.
        start = max(offset - before - 1, 0)
        end = offset + after
        expected_results = expected_order[start: end]

        # if it is a tuple with the cn, drop the cn
        if expected_results and isinstance(expected_results[0], tuple):
            expected_results = [x[0] for x in expected_results]

        if expected_results == results:
            return

        print("expected order: %s" % expected_order[:20])
        if len(expected_order) > 20:
            print("... and %d more not shown" % (len(expected_order) - 20))

        print("offset %d before %d after %d" % (offset, before, after))
        print("start %d end %d" % (start, end))
        print("expected: %s" % expected_results)
        print("got     : %s" % results)
        # assertEqual rather than the deprecated assertEquals alias.
        self.assertEqual(expected_results, results)

    def test_server_vlv_with_cookie(self):
        """Walk each attribute's sorted list by offset, passing the cookie
        from the previous response on every search after the first one
        for that attribute."""
        sortable = [key for key in self.users[0].keys()
                    if key not in ('dn', 'objectclass')]
        before_values = [10, 0, 3, 1, 4, 5, 2]
        after_values = [0, 3, 1, 4, 5, 2, 7]
        for attr in sortable:
            expected_order = self.get_expected_order(attr)
            sort_control = "server_sort:1:0:%s" % attr
            n = len(self.users)
            reply = None
            for before in before_values:
                for after in after_values:
                    first = max(1, before - 2)
                    stop = min(n - after + 2, n)
                    for offset in range(first, stop):
                        if reply is not None:
                            cookie = get_cookie(reply.controls, n)
                            vlv_search = ("vlv:1:%d:%d:%d:%s:%s" %
                                          (before, after, offset, n,
                                           cookie))
                        else:
                            # Very first search for this attribute.
                            vlv_search = "vlv:1:%d:%d:%d:0" % (before, after,
                                                               offset)

                        reply = self.ldb.search(self.ou,
                                                scope=ldb.SCOPE_ONELEVEL,
                                                attrs=[attr],
                                                controls=[sort_control,
                                                          vlv_search])

                        found = [entry[attr][0] for entry in reply]
                        self.assertCorrectResults(found, expected_order,
                                                  offset, before, after)

    def run_index_tests_with_expressions(self, expressions):
        """Run offset-based VLV searches filtered by each expression.

        For every attribute and expression the expected order comes from
        a plain sorted search.  Windows with before == after (0 to 10) are
        then requested over a range of offsets and compared with the
        matching slice.  Not every before/after combination is tried.
        """
        sortable = [key for key in self.users[0].keys()
                    if key not in ('dn', 'objectclass')]
        for attr in sortable:
            sort_control = "server_sort:1:0:%s" % attr
            for expression in expressions:
                expected_order = self.get_expected_order(attr, expression)
                n = len(expected_order)
                reply = None
                for width in range(0, 11):
                    before = after = width
                    first = max(1, before - 2)
                    stop = min(n - after + 2, n)
                    for offset in range(first, stop):
                        if reply is not None:
                            cookie = get_cookie(reply.controls)
                            vlv_search = ("vlv:1:%d:%d:%d:%s:%s" %
                                          (before, after, offset, n,
                                           cookie))
                        else:
                            vlv_search = "vlv:1:%d:%d:%d:0" % (before, after,
                                                               offset)

                        reply = self.ldb.search(self.ou,
                                                expression=expression,
                                                scope=ldb.SCOPE_ONELEVEL,
                                                attrs=[attr],
                                                controls=[sort_control,
                                                          vlv_search])

                        found = [entry[attr][0] for entry in reply]
                        self.assertCorrectResults(found, expected_order,
                                                  offset, before, after)

    def test_server_vlv_with_expression(self):
        """Offset-based VLV combined with search expressions: one that is
        a presence test, one on the last user's cn and one on the first
        user's roomNumber."""
        last_cn = self.users[-1]['cn']
        first_room = self.users[0]['roomNumber']
        self.run_index_tests_with_expressions([
            "(objectClass=*)",
            "(cn=%s)" % last_cn,
            "(roomNumber=%s)" % first_room,
        ])

    def test_server_vlv_with_failing_expression(self):
        """Offset-based VLV combined with expressions that are meant to
        match nothing."""
        no_match = [
            "(samaccountname=testferf)",
            "(cn=hefalump)",
        ]
        self.run_index_tests_with_expressions(no_match)

    def run_gte_tests_with_expressions(self, expressions):
        """Run greater-than-or-equal VLV searches filtered by each
        expression.

        For every expression and attribute, probe values from
        get_gte_tests_and_order() are tried in a shuffled but reproducible
        order with before == after (0 to 10).  Each result is compared
        with the slice of the expected order around the index that
        gte_map gives for the probe.

        :param expressions: list of LDAP filter strings to test with
        """
        # Here we don't test every before/after combination.
        attrs = [x for x in self.users[0].keys() if x not in
                 ('dn', 'objectclass')]
        for expression in expressions:
            for attr in attrs:
                gte_order, expected_order, gte_map = \
                    self.get_gte_tests_and_order(attr, expression)
                # In case there is some order dependency, disorder tests
                gte_tests = gte_order[:]
                random.seed(2)
                random.shuffle(gte_tests)
                sort_control = "server_sort:1:0:%s" % attr

                # Read the order again now that the temporary gte users
                # have been deleted.
                expected_order = self.get_expected_order(attr, expression)
                res = None
                for before in range(0, 11):
                    after = before
                    for gte in gte_tests:
                        # The first search has no previous response to
                        # take a cookie from.
                        if res is not None:
                            cookie = get_cookie(res.controls)
                        else:
                            cookie = None
                        vlv_search = encode_vlv_control(before=before,
                                                        after=after,
                                                        gte=gte,
                                                        cookie=cookie)

                        res = self.ldb.search(self.ou,
                                              scope=ldb.SCOPE_ONELEVEL,
                                              expression=expression,
                                              attrs=[attr],
                                              controls=[sort_control,
                                                        vlv_search])

                        results = [x[attr][0] for x in res]
                        offset = gte_map.get(gte, len(expected_order))

                        # here offset is 0-based
                        start = max(offset - before, 0)
                        end = offset + 1 + after

                        expected_results = expected_order[start: end]

                        self.assertEqual(expected_results, results)

    def test_vlv_gte_with_expression(self):
        """Greater-than-or-equal VLV combined with search expressions: a
        presence test, the last user's cn and the first user's
        roomNumber."""
        last_cn = self.users[-1]['cn']
        first_room = self.users[0]['roomNumber']
        self.run_gte_tests_with_expressions([
            "(objectClass=*)",
            "(cn=%s)" % last_cn,
            "(roomNumber=%s)" % first_room,
        ])

    def test_vlv_gte_with_failing_expression(self):
        """Greater-than-or-equal VLV combined with expressions that are
        meant to match nothing."""
        no_match = [
            "(samaccountname=testferf)",
            "(cn=hefalump)",
        ]
        self.run_gte_tests_with_expressions(no_match)

    def test_server_vlv_with_cookie_while_adding_and_deleting(self):
        """What happens if we add or remove items in the middle of the VLV?

        Nothing. The search and the sort is not repeated, and we only
        deal with the objects originally found.

        Between searches users are randomly added and deleted (with a
        fixed seed per attribute).  Deleted users are blanked out of the
        expected order; added ones are never expected to show up.
        """
        attrs = ['cn'] + [x for x in self.users[0].keys() if x not in
                          ('dn', 'objectclass')]
        user_number = 0
        # list() is required: a range object cannot be concatenated with
        # a list on Python 3.
        window_sizes = list(range(0, 3)) + [6, 11, 19]
        for attr in attrs:
            full_results, controls, sort_control = \
                self.get_full_list(attr, True)
            original_n = len(self.users)

            expected_order = full_results
            random.seed(1)

            for before in window_sizes:
                for after in window_sizes:
                    start = max(before - 1, 1)
                    end = max(start + 4, original_n - after + 2)
                    for offset in range(start, end):
                        cookie = get_cookie(controls, original_n)
                        vlv_search = encode_vlv_control(before=before,
                                                        after=after,
                                                        offset=offset,
                                                        n=original_n,
                                                        cookie=cookie)

                        res = self.ldb.search(self.ou,
                                              scope=ldb.SCOPE_ONELEVEL,
                                              attrs=[attr],
                                              controls=[sort_control,
                                                        vlv_search])

                        controls = res.controls
                        results = [x[attr][0] for x in res]
                        # Clamp the requested offset into the list.
                        real_offset = max(1, min(offset, len(expected_order)))

                        begin_offset = max(real_offset - before - 1, 0)
                        real_before = min(before, real_offset - 1)
                        real_after = min(after,
                                         len(expected_order) - real_offset)
                        wanted = real_before + real_after + 1

                        # Entries set to None stand for deleted users;
                        # they are skipped and do not count towards the
                        # window size.
                        expected_results = []
                        for x in expected_order[begin_offset:]:
                            if x is not None:
                                expected_results.append(x[0])
                                if len(expected_results) == wanted:
                                    break

                        if expected_results != results:
                            print("attr %s before %d after %d offset %d" %
                                  (attr, before, after, offset))
                        self.assertEqual(expected_results, results)

                        # n is sampled once, before any add, and used for
                        # both random decisions below.
                        n = len(self.users)
                        if random.random() < 0.1 + (n < 5) * 0.05:
                            if n == 0:
                                i = 0
                            else:
                                i = random.randrange(n)
                            self.create_user(i, n, suffix='-%s' %
                                             user_number)
                            user_number += 1
                        if random.random() < 0.1 + (n > 50) * 0.02 and n:
                            index = random.randrange(n)
                            user = self.users.pop(index)

                            self.ldb.delete(user['dn'])

                            replaced = (user[attr], user['cn'])
                            if replaced in expected_order:
                                i = expected_order.index(replaced)
                                expected_order[i] = None

    def test_server_vlv_with_cookie_while_changing(self):
        """What happens if we modify items in the middle of the VLV?

        The expected behaviour (as found on Windows) is the sort is
        not repeated, but the changes in attributes are reflected.

        After every third search one entry's attribute is replaced with
        another entry's value (for locale and binary sorted attributes
        only) and the local copy of the values is updated to match.
        """
        attrs = [x for x in self.users[0].keys() if x not in
                 ('dn', 'objectclass', 'cn')]
        for attr in attrs:
            n_users = len(self.users)
            # Fetched for reference only; the comparison with it below
            # is disabled.
            expected_order = [x.upper() for x in self.get_expected_order(attr)]
            sort_control = "server_sort:1:0:%s" % attr
            i = 0

            # First we'll fetch the whole list so we know the original
            # sort order. This is necessary because we don't know how
            # the server will order equivalent items. We are using the
            # dn as a key.
            half_n = n_users // 2
            vlv_search = "vlv:1:%d:%d:%d:0" % (half_n, half_n, half_n + 1)
            res = self.ldb.search(self.ou,
                                  scope=ldb.SCOPE_ONELEVEL,
                                  attrs=['dn', attr],
                                  controls=[sort_control, vlv_search])

            results = [x[attr][0].upper() for x in res]
            # Disabled check: self.assertEqual(expected_order, results)

            dn_order = [str(x['dn']) for x in res]
            values = results[:]

            for before in range(0, 3):
                for after in range(0, 3):
                    for offset in range(1 + before, n_users - after):
                        cookie = get_cookie(res.controls, len(self.users))
                        vlv_search = ("vlv:1:%d:%d:%d:%s:%s" %
                                      (before, after, offset, len(self.users),
                                       cookie))

                        res = self.ldb.search(self.ou,
                                              scope=ldb.SCOPE_ONELEVEL,
                                              attrs=['dn', attr],
                                              controls=[sort_control,
                                                        vlv_search])

                        # The entries must stay in their original places
                        # even though values have been rewritten.
                        dn_results = [str(x['dn']) for x in res]
                        dn_expected = dn_order[offset - before - 1:
                                               offset + after]

                        self.assertEqual(dn_expected, dn_results)

                        results = [x[attr][0].upper() for x in res]

                        self.assertCorrectResults(results, values,
                                                  offset, before, after)

                        i += 1
                        if i % 3 != 2:
                            continue
                        if (attr not in self.locale_sorted_keys and
                                attr not in self.binary_sorted_keys):
                            continue

                        # Copy the value at position i2 onto the entry
                        # at position i1.
                        i1 = i % n_users
                        i2 = (i ^ 255) % n_users
                        dn1 = dn_order[i1]
                        v2 = values[i2]

                        # NOTE(review): this tests a value against a list
                        # of attribute names; possibly `attr` was meant.
                        # Left unchanged -- confirm.
                        if v2 in self.locale_sorted_keys:
                            v2 += '-%d' % i

                        values[i1] = v2

                        m = ldb.Message()
                        m.dn = ldb.Dn(self.ldb, dn1)
                        m[attr] = ldb.MessageElement(v2,
                                                     ldb.FLAG_MOD_REPLACE,
                                                     attr)

                        self.ldb.modify(m)

    def test_server_vlv_fractions_with_cookie(self):
        """What happens when the count is set to an arbitrary number?

        In that case the offset and the count form a fraction, and the
        VLV should be centred at a point offset/count of the way
        through. For example, if offset is 3 and count is 6, the VLV
        should be looking around halfway. The actual algorithm is a
        bit fiddlier than that, because of the one-basedness of VLV.

        Offsets and denominators both start at 1 here.  An offset of
        zero is illegal unless the content count is zero, so it is not
        tried, and any LdbError from a search fails the test.
        """
        attrs = [x for x in self.users[0].keys() if x not in
                 ('dn', 'objectclass')]

        n_users = len(self.users)

        random.seed(4)

        for attr in attrs:
            full_results, controls, sort_control = self.get_full_list(attr)
            self.assertEqual(len(full_results), n_users)
            for before in range(0, 2):
                for after in range(0, 2):
                    for denominator in range(1, 20):
                        # Go a little past the denominator to exercise
                        # offsets larger than the claimed count.
                        for offset in range(1, denominator + 3):
                            cookie = get_cookie(controls, len(self.users))
                            vlv_search = ("vlv:1:%d:%d:%d:%s:%s" %
                                          (before, after, offset,
                                           denominator,
                                           cookie))
                            res = self.ldb.search(self.ou,
                                                  scope=ldb.SCOPE_ONELEVEL,
                                                  attrs=[attr],
                                                  controls=[sort_control,
                                                            vlv_search])

                            results = [x[attr][0].lower() for x in res]

                            if denominator == 1:
                                # the offset can only be 1, but the 1/1 case
                                # means something special
                                if offset == 1:
                                    real_offset = n_users
                                else:
                                    real_offset = 1
                            else:
                                # An offset beyond the denominator counts
                                # as the denominator itself.
                                clamped = min(offset, denominator)
                                real_offset = (1 +
                                               int(round((n_users - 1) *
                                                         (clamped - 1) /
                                                         (denominator - 1.0))))

                            self.assertCorrectResults(results, full_results,
                                                      real_offset, before,
                                                      after)

                            controls = res.controls

    def test_server_vlv_no_cookie(self):
        """Offset-based VLV searches that never pass a cookie: every
        search sends a fresh sort and VLV control."""
        sortable = [key for key in self.users[0].keys()
                    if key not in ('dn', 'objectclass')]

        for attr in sortable:
            expected_order = self.get_expected_order(attr)
            sort_control = "server_sort:1:0:%s" % attr
            for before in range(0, 5):
                for after in range(0, 7):
                    # Only offsets whose whole window fits in the list.
                    for offset in range(1 + before, len(self.users) - after):
                        vlv_search = ("vlv:1:%d:%d:%d:0" %
                                      (before, after, offset))
                        reply = self.ldb.search(self.ou,
                                                scope=ldb.SCOPE_ONELEVEL,
                                                attrs=[attr],
                                                controls=[sort_control,
                                                          vlv_search])
                        found = [entry[attr][0] for entry in reply]
                        self.assertCorrectResults(found, expected_order,
                                                  offset, before, after)

    def get_expected_order_showing_deleted(self, attr,
                                           expression="(|(cn=vlvtest*)(cn=vlv-deleted*))",
                                           base=None,
                                           scope=ldb.SCOPE_SUBTREE
                                           ):
        """Return the first attr value of every matching entry, sorted by
        the server on attr, with the show_deleted control set.

        The search starts at base, or at the domain base DN when base is
        None, so that it is not limited to the test OU.  This is the way
        to find deleted objects.
        """
        search_base = self.base_dn if base is None else base
        reply = self.ldb.search(search_base,
                                scope=scope,
                                expression=expression,
                                attrs=[attr],
                                controls=["server_sort:1:0:%s" % attr,
                                          "show_deleted:1"])
        ordered = []
        for entry in reply:
            ordered.append(entry[attr][0])
        return ordered

    def add_deleted_users(self, n):
        """Create n users with the 'vlv-deleted' prefix, then delete them
        all again."""
        created = []
        for index in range(n):
            created.append(self.create_user(index, n, prefix='vlv-deleted'))

        # Deleting only starts once every user has been created.
        for doomed in created:
            self.delete_user(doomed)

    def test_server_vlv_no_cookie_show_deleted(self):
        """Cookie-less VLV searches over the whole domain with the
        show_deleted control, at randomly chosen offsets."""
        attrs = [
            'objectGUID',
            'cn',
            'sAMAccountName',
            'objectSid',
            'name',
            'whenChanged',
            'usnChanged',
        ]

        # add some deleted users first, just in case there are none
        self.add_deleted_users(6)
        random.seed(22)
        expression = "(|(cn=vlvtest*)(cn=vlv-deleted*))"
        show_deleted_control = "show_deleted:1"

        for attr in attrs:
            expected_order = self.get_expected_order_showing_deleted(
                attr, expression)
            n = len(expected_order)
            sort_control = "server_sort:1:0:%s" % attr
            for before in [3, 1, 0]:
                for after in [0, 2]:
                    # don't test every position, because there could be hundreds.
                    # jump back and forth instead
                    lowest = max(1, before - 2)
                    limit = min(n - after + 2, n)
                    for _round in range(20):
                        offset = random.randrange(lowest, limit)
                        vlv_search = ("vlv:1:%d:%d:%d:0" %
                                      (before, after, offset))
                        reply = self.ldb.search(self.base_dn,
                                                expression=expression,
                                                scope=ldb.SCOPE_SUBTREE,
                                                attrs=[attr],
                                                controls=[sort_control,
                                                          show_deleted_control,
                                                          vlv_search])
                        found = [entry[attr][0] for entry in reply]
                        self.assertCorrectResults(found, expected_order,
                                                  offset, before, after)

    def test_server_vlv_no_cookie_show_deleted_only(self):
        """Check cookie-less VLV windows with show_deleted when the search
        only covers deleted objects."""
        # Make sure there are at least a few deleted objects to find.
        self.add_deleted_users(4)
        base = 'CN=Deleted Objects,%s' % self.base_dn
        expression = "(cn=vlv-deleted*)"
        show_deleted = "show_deleted:1"

        for attr in ('objectGUID',
                     'cn',
                     'sAMAccountName',
                     'objectSid',
                     'whenChanged'):
            expected = self.get_expected_order_showing_deleted(
                attr,
                expression=expression,
                base=base,
                scope=ldb.SCOPE_ONELEVEL)
            total = len(expected)
            print("searching for attr %s amongst %d deleted objects" %
                  (attr, total))
            sort_by_attr = "server_sort:1:0:%s" % attr
            # Walk the list in roughly ten strides, never less than one
            # entry at a time.
            stride = max(total // 10, 1)
            for before in (3, 0):
                for after in (0, 2):
                    for offset in range(1 + before, total - after, stride):
                        vlv = "vlv:1:%d:%d:%d:0" % (before, after, offset)
                        reply = self.ldb.search(base,
                                                expression=expression,
                                                scope=ldb.SCOPE_ONELEVEL,
                                                attrs=[attr],
                                                controls=[sort_by_attr,
                                                          show_deleted,
                                                          vlv])
                        found = [msg[attr][0] for msg in reply]
                        self.assertCorrectResults(found, expected,
                                                  offset, before, after)



    def test_server_vlv_with_cookie_show_deleted(self):
        """Check VLV windows with show_deleted when every search after the
        first passes back the cookie taken from the previous reply."""
        attrs = ['objectGUID',
                 'cn',
                 'sAMAccountName',
                 'objectSid',
                 'name',
                 'whenChanged',
                 'usnChanged'
        ]
        # add some deleted users first, just in case there are none
        self.add_deleted_users(6)
        random.seed(23)
        show_deleted_control = "show_deleted:1"
        expression = "(|(cn=vlvtest*)(cn=vlv-deleted*))"
        for attr in attrs:
            sort_control = "server_sort:1:0:%s" % attr
            # Each attribute starts a fresh VLV conversation: the first
            # search has no previous reply to take a cookie from.
            res = None
            expected_order = self.get_expected_order_showing_deleted(attr)
            n = len(expected_order)
            for before in [3, 2, 1, 0]:
                after = before
                # don't test every position, jump back and forth instead
                for _ in range(20):
                    offset = random.randrange(max(1, before - 2),
                                              min(n - after + 2, n))
                    if res is None:
                        vlv_search = "vlv:1:%d:%d:%d:0" % (before, after,
                                                           offset)
                    else:
                        cookie = get_cookie(res.controls, n)
                        vlv_search = ("vlv:1:%d:%d:%d:%s:%s" %
                                      (before, after, offset, n,
                                       cookie))

                    res = self.ldb.search(self.base_dn,
                                          expression=expression,
                                          scope=ldb.SCOPE_SUBTREE,
                                          attrs=[attr],
                                          controls=[sort_control,
                                                    vlv_search,
                                                    show_deleted_control])

                    results = [x[attr][0] for x in res]

                    self.assertCorrectResults(results, expected_order,
                                              offset, before, after)


    def test_server_vlv_gte_with_cookie(self):
        """Check greater-than-or-equal VLV searches that pass back the
        cookie taken from the previous reply.

        For every user attribute other than 'dn' and 'objectclass', each
        gte test value is searched for with several before/after window
        sizes, and the reply is compared with the matching slice of the
        expected order.
        """
        attrs = [x for x in self.users[0].keys() if x not in
                 ('dn', 'objectclass')]
        for attr in attrs:
            gte_order, expected_order, gte_map = \
                self.get_gte_tests_and_order(attr)
            # In case there is some order dependency, disorder tests
            gte_tests = gte_order[:]
            random.seed(1)
            random.shuffle(gte_tests)
            res = None
            sort_control = "server_sort:1:0:%s" % attr
            for before in [0, 1, 2, 4]:
                for after in [0, 1, 3, 6]:
                    for gte in gte_tests:
                        # Only the very first search for an attribute has
                        # no earlier reply to take a cookie from.
                        if res is not None:
                            cookie = get_cookie(res.controls, len(self.users))
                        else:
                            cookie = None
                        vlv_search = encode_vlv_control(before=before,
                                                        after=after,
                                                        gte=gte,
                                                        cookie=cookie)

                        res = self.ldb.search(self.ou,
                                              scope=ldb.SCOPE_ONELEVEL,
                                              attrs=[attr],
                                              controls=[sort_control,
                                                        vlv_search])

                        results = [x[attr][0] for x in res]
                        # A value missing from gte_map is treated as
                        # sitting just past the end of the list.
                        offset = gte_map.get(gte, len(expected_order))

                        # here offset is 0-based
                        start = max(offset - before, 0)
                        end = offset + 1 + after

                        expected_results = expected_order[start: end]

                        # assertEqual, not the deprecated assertEquals
                        # alias (removed in Python 3.12).
                        self.assertEqual(expected_results, results)

    def test_server_vlv_gte_no_cookie(self):
        """Check greater-than-or-equal VLV searches sent without a cookie.

        For every user attribute other than 'dn' and 'objectclass', each
        gte test value is searched for with several before/after window
        sizes, and the reply is compared with the matching slice of the
        expected order. Diagnostics are printed before a mismatch fails
        the test.
        """
        attrs = [x for x in self.users[0].keys() if x not in
                 ('dn', 'objectclass')]
        for attr in attrs:
            gte_order, expected_order, gte_map = \
                self.get_gte_tests_and_order(attr)
            # In case there is some order dependency, disorder tests
            gte_tests = gte_order[:]
            random.seed(1)
            random.shuffle(gte_tests)

            sort_control = "server_sort:1:0:%s" % attr
            for before in [0, 1, 3]:
                for after in [0, 4]:
                    for gte in gte_tests:
                        vlv_search = encode_vlv_control(before=before,
                                                        after=after,
                                                        gte=gte)

                        res = self.ldb.search(self.ou,
                                              scope=ldb.SCOPE_ONELEVEL,
                                              attrs=[attr],
                                              controls=[sort_control,
                                                        vlv_search])
                        results = [x[attr][0] for x in res]

                        # here offset is 0-based; a value missing from
                        # gte_map is treated as just past the end.
                        offset = gte_map.get(gte, len(expected_order))
                        start = max(offset - before, 0)
                        end = offset + after + 1
                        expected_results = expected_order[start: end]
                        if expected_results != results:
                            # Dump some context before the assertion
                            # below reports the failure.
                            middle = expected_order[len(expected_order) // 2]
                            print(expected_results, results)
                            print(middle)
                            print(expected_order)
                            print()
                            print ("\nattr %s offset %d before %d "
                                   "after %d gte %s" %
                                   (attr, offset, before, after, gte))
                        # assertEqual, not the deprecated assertEquals
                        # alias (removed in Python 3.12).
                        self.assertEqual(expected_results, results)

    def test_multiple_searches(self):
        """Open a series of VLV searches on one connection, then check
        which of the returned cookies are still accepted.

        The first cookie is expected to be rejected and the last one to be
        accepted, but only on the connection that created it. The original
        notes put the per-connection limit at 3 concurrent VLV searches,
        and say Windows allows 10 when each search covers few objects.
        """
        user_attrs = [key for key in self.users[0].keys()
                      if key not in ('dn', 'objectclass')]
        # Go through the attributes twice, capped at 12 searches.
        attrs = (user_attrs * 2)[:12]

        vlv_cookies = []
        for attr in attrs:
            sort_control = "server_sort:1:0:%s" % attr

            reply = self.ldb.search(self.ou,
                                    scope=ldb.SCOPE_ONELEVEL,
                                    attrs=[attr],
                                    controls=[sort_control,
                                              "vlv:1:1:1:1:0"])

            vlv_cookies.append(get_cookie(reply.controls, len(self.users)))
            time.sleep(0.2)

        # From here on, attr and sort_control keep the values from the
        # last loop iteration.

        # The first cookie should no longer be accepted.
        with self.assertRaises(ldb.LdbError):
            self.ldb.search(self.ou,
                            scope=ldb.SCOPE_ONELEVEL,
                            attrs=[attr],
                            controls=[sort_control,
                                      "vlv:1:1:1:1:0:%s" % vlv_cookies[0]])

        # The most recent cookie should still work.
        self.ldb.search(self.ou,
                        scope=ldb.SCOPE_ONELEVEL,
                        attrs=[attr],
                        controls=[sort_control,
                                  "vlv:1:1:1:1:0:%s" % vlv_cookies[-1]])

        # A new connection doesn't share cookies, so the same cookie
        # should fail there.
        new_ldb = SamDB(host, credentials=creds,
                        session_info=system_session(lp), lp=lp)

        with self.assertRaises(ldb.LdbError):
            new_ldb.search(self.ou,
                           scope=ldb.SCOPE_ONELEVEL,
                           attrs=[attr],
                           controls=[sort_control,
                                     "vlv:1:1:1:1:0:%s" % vlv_cookies[-1]])

        # Without the critical flag the search just does no VLV.
        new_ldb.search(self.ou,
                       scope=ldb.SCOPE_ONELEVEL,
                       attrs=[attr],
                       controls=[sort_control,
                                 "vlv:0:1:1:1:0:%s" % vlv_cookies[-1]])
class AuthLogPassChangeTests(samba.tests.auth_log_base.AuthLogTestBase):
    """Check the authentication log messages emitted for password changes.

    Every test triggers one password change (over SAMR, RAP or LDAP),
    waits for the message it expects to be logged last, and then checks
    how many messages had been collected by that point.

    NOTE(review): several string literals in this class ("******") look
    like masked placeholders rather than real test values. They are kept
    as found -- confirm the intended values.
    """

    def setUp(self):
        """Connect to the server over LDAP and (re)create the test user."""
        super(AuthLogPassChangeTests, self).setUp()

        self.remoteAddress = os.environ["CLIENT_IP"]
        self.server_ip = os.environ["SERVER_IP"]

        host = "ldap://%s" % os.environ["SERVER"]
        self.ldb = SamDB(url=host,
                         session_info=system_session(),
                         credentials=self.get_credentials(),
                         lp=self.get_loadparm())

        print("ldb %s" % type(self.ldb))
        # Gets back the basedn
        self.base_dn = self.ldb.domain_dn()
        print("base_dn %s" % self.base_dn)

        # permit password changes during this test
        PasswordCommon.allow_password_changes(self, self.ldb)

        # (Re)adds the test user USER_NAME with password USER_PASS
        user_dn = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        delete_force(self.ldb, user_dn)
        self.ldb.add({
            "dn": user_dn,
            "objectclass": "user",
            "sAMAccountName": USER_NAME,
            "userPassword": USER_PASS
        })

        # discard any auth log messages for the password setup
        self.discardMessages()

    def tearDown(self):
        super(AuthLogPassChangeTests, self).tearDown()

    def _new_net(self):
        """Return a Net connection to the server using fresh credentials."""
        creds = self.insta_creds(template=self.get_credentials())
        lp = self.get_loadparm()
        return Net(creds, lp, server=self.server_ip)

    def _password_change_ldif(self, cn, old_password, new_password):
        """Return LDIF that changes the userPassword of cn=<cn>,cn=users
        by deleting the old value and adding the new one.

        NOTE(review): the password operands of these LDIF lines were
        masked (and syntactically invalid) in the reviewed copy; the
        callers' arguments were restored from context -- confirm.
        """
        return ("dn: cn=" + cn + ",cn=users," + self.base_dn + "\n" +
                "changetype: modify\n" +
                "delete: userPassword\n" +
                "userPassword: " + old_password + "\n" +
                "add: userPassword\n" +
                "userPassword: " + new_password + "\n")

    def test_admin_change_password(self):
        """A successful SAMR password change is logged with NT_STATUS_OK."""
        def isLastExpectedMessage(msg):
            return ((msg["type"] == "Authentication") and
                    (msg["Authentication"]["status"] == "NT_STATUS_OK") and
                    (msg["Authentication"]["serviceDescription"] ==
                        "SAMR Password Change") and
                    (msg["Authentication"]["authDescription"] ==
                        "samr_ChangePasswordUser3"))

        net = self._new_net()
        password = "******"

        net.change_password(newpassword=password.encode('utf-8'),
                            username=USER_NAME,
                            oldpassword=USER_PASS)

        messages = self.waitForMessages(isLastExpectedMessage)
        print("Received %d messages" % len(messages))
        self.assertEqual(8,
                         len(messages),
                         "Did not receive the expected number of messages")

    def test_admin_change_password_new_password_fails_restriction(self):
        """A rejected new password is logged with
        NT_STATUS_PASSWORD_RESTRICTION."""
        def isLastExpectedMessage(msg):
            return ((msg["type"] == "Authentication") and
                    (msg["Authentication"]["status"] ==
                        "NT_STATUS_PASSWORD_RESTRICTION") and
                    (msg["Authentication"]["serviceDescription"] ==
                        "SAMR Password Change") and
                    (msg["Authentication"]["authDescription"] ==
                        "samr_ChangePasswordUser3"))

        net = self._new_net()
        password = "******"

        # The change itself must fail; only the logging is checked below.
        with self.assertRaises(Exception):
            net.change_password(newpassword=password.encode('utf-8'),
                                oldpassword=USER_PASS,
                                username=USER_NAME)

        messages = self.waitForMessages(isLastExpectedMessage)
        self.assertEqual(8,
                         len(messages),
                         "Did not receive the expected number of messages")

    def test_admin_change_password_unknown_user(self):
        """A change for an unknown user is logged with
        NT_STATUS_NO_SUCH_USER."""
        def isLastExpectedMessage(msg):
            return ((msg["type"] == "Authentication") and
                    (msg["Authentication"]["status"] ==
                        "NT_STATUS_NO_SUCH_USER") and
                    (msg["Authentication"]["serviceDescription"] ==
                        "SAMR Password Change") and
                    (msg["Authentication"]["authDescription"] ==
                        "samr_ChangePasswordUser3"))

        net = self._new_net()
        password = "******"

        with self.assertRaises(Exception):
            net.change_password(newpassword=password.encode('utf-8'),
                                oldpassword=USER_PASS,
                                username="******")

        messages = self.waitForMessages(isLastExpectedMessage)
        self.assertEqual(8,
                         len(messages),
                         "Did not receive the expected number of messages")

    def test_admin_change_password_bad_original_password(self):
        """A change with the wrong old password is logged with
        NT_STATUS_WRONG_PASSWORD."""
        def isLastExpectedMessage(msg):
            return ((msg["type"] == "Authentication") and
                    (msg["Authentication"]["status"] ==
                        "NT_STATUS_WRONG_PASSWORD") and
                    (msg["Authentication"]["serviceDescription"] ==
                        "SAMR Password Change") and
                    (msg["Authentication"]["authDescription"] ==
                        "samr_ChangePasswordUser3"))

        net = self._new_net()
        password = "******"

        with self.assertRaises(Exception):
            net.change_password(newpassword=password.encode('utf-8'),
                                oldpassword="******",
                                username=USER_NAME)

        messages = self.waitForMessages(isLastExpectedMessage)
        self.assertEqual(8,
                         len(messages),
                         "Did not receive the expected number of messages")

    # net rap password changes are broken, but they trigger enough of the
    # server side behaviour to exercise the code paths of interest.
    # if we used the real password it would be too long and does not hash
    # correctly, so we just check it triggers the wrong password path.
    def test_rap_change_password(self):
        """A RAP password change attempt is logged as a wrong-password
        OemChangePasswordUser2 authentication."""
        def isLastExpectedMessage(msg):
            return ((msg["type"] == "Authentication") and
                    (msg["Authentication"]["serviceDescription"] ==
                        "SAMR Password Change") and
                    (msg["Authentication"]["status"] ==
                        "NT_STATUS_WRONG_PASSWORD") and
                    (msg["Authentication"]["authDescription"] ==
                        "OemChangePasswordUser2"))

        username = os.environ["USERNAME"]
        server = os.environ["SERVER"]
        password = os.environ["PASSWORD"]
        server_param = "--server=%s" % server
        creds = "-U%s%%%s" % (username, password)
        call(["bin/net", "rap", server_param,
              "password", USER_NAME, "notMyPassword", "notGoingToBeMyPassword",
              server, creds, "--option=client ipc max protocol=nt1"])

        messages = self.waitForMessages(isLastExpectedMessage)
        self.assertEqual(7,
                         len(messages),
                         "Did not receive the expected number of messages")

    def test_ldap_change_password(self):
        """A successful LDAP password change is logged with NT_STATUS_OK."""
        def isLastExpectedMessage(msg):
            return ((msg["type"] == "Authentication") and
                    (msg["Authentication"]["status"] == "NT_STATUS_OK") and
                    (msg["Authentication"]["serviceDescription"] ==
                        "LDAP Password Change") and
                    (msg["Authentication"]["authDescription"] ==
                        "LDAP Modify"))

        new_password = samba.generate_random_password(32, 32)
        self.ldb.modify_ldif(
            self._password_change_ldif(USER_NAME, USER_PASS, new_password))

        messages = self.waitForMessages(isLastExpectedMessage)
        print("Received %d messages" % len(messages))
        self.assertEqual(4,
                         len(messages),
                         "Did not receive the expected number of messages")

    #
    # Currently this does not get logged, so we expect to only see the log
    # entries for the underlying ldap bind.
    #
    def test_ldap_change_password_bad_user(self):
        """An LDAP password change for a missing user fails and only the
        messages up to the LDAP authorization are seen."""
        def isLastExpectedMessage(msg):
            return (msg["type"] == "Authorization" and
                    msg["Authorization"]["serviceDescription"] == "LDAP" and
                    msg["Authorization"]["authType"] == "krb5")

        new_password = samba.generate_random_password(32, 32)
        self.assertRaises(
            LdbError,
            self.ldb.modify_ldif,
            self._password_change_ldif("badUser", USER_PASS, new_password))

        messages = self.waitForMessages(isLastExpectedMessage)
        print("Received %d messages" % len(messages))
        self.assertEqual(3,
                         len(messages),
                         "Did not receive the expected number of messages")

    def test_ldap_change_password_bad_original_password(self):
        """An LDAP password change with the wrong old password fails and
        is logged with NT_STATUS_WRONG_PASSWORD."""
        def isLastExpectedMessage(msg):
            return ((msg["type"] == "Authentication") and
                    (msg["Authentication"]["status"] ==
                        "NT_STATUS_WRONG_PASSWORD") and
                    (msg["Authentication"]["serviceDescription"] ==
                        "LDAP Password Change") and
                    (msg["Authentication"]["authDescription"] ==
                        "LDAP Modify"))

        new_password = samba.generate_random_password(32, 32)
        self.assertRaises(
            LdbError,
            self.ldb.modify_ldif,
            self._password_change_ldif(USER_NAME, "badPassword",
                                       new_password))

        messages = self.waitForMessages(isLastExpectedMessage)
        print("Received %d messages" % len(messages))
        self.assertEqual(4,
                         len(messages),
                         "Did not receive the expected number of messages")
class UserTests(samba.tests.TestCase):
    """Timing-oriented tests that add users and groups and run searches.

    Progress (next user id, next linked user, number of groups) is kept
    on the GlobalState class rather than on the instance, so it survives
    from one test method to the next. The numbered test names at the
    bottom of the class reuse the same ``_test_*`` helpers at different
    database sizes.
    """

    def add_if_possible(self, *args, **kwargs):
        """In these tests sometimes things are left in the database
        deliberately, so we don't worry if we fail to add them a second
        time."""
        try:
            self.ldb.add(*args, **kwargs)
        except LdbError:
            pass

    def setUp(self):
        """Connect to the server and make sure the per-process OUs exist."""
        super(UserTests, self).setUp()
        self.state = GlobalState  # the class itself, not an instance
        self.lp = lp
        self.ldb = SamDB(host, credentials=creds,
                         session_info=system_session(lp), lp=lp)
        self.base_dn = self.ldb.domain_dn()
        self.ou = "OU=pid%s,%s" % (os.getpid(), self.base_dn)
        self.ou_users = "OU=users,%s" % self.ou
        self.ou_groups = "OU=groups,%s" % self.ou
        self.ou_computers = "OU=computers,%s" % self.ou

        for dn in (self.ou, self.ou_users, self.ou_groups,
                   self.ou_computers):
            self.add_if_possible({
                "dn": dn,
                "objectclass": "organizationalUnit"})

    def tearDown(self):
        super(UserTests, self).tearDown()

    def test_00_00_do_nothing(self):
        # this gives us an idea of the overhead
        pass

    def _prepare_n_groups(self, n):
        """Record ``n`` as the group count and add groups g0..g(n-1),
        ignoring any that already exist."""
        self.state.n_groups = n
        for i in range(n):
            self.add_if_possible({
                "dn": "cn=g%d,%s" % (i, self.ou_groups),
                "objectclass": "group"})

    def _add_users(self, start, end):
        """Add users u<start>..u<end - 1>, one add request per user."""
        for i in range(start, end):
            self.ldb.add({
                "dn": "cn=u%d,%s" % (i, self.ou_users),
                "objectclass": "user"})

    def _add_users_ldif(self, start, end):
        """Add users u<start>..u<end - 1> with a single LDIF document."""
        lines = []
        for i in range(start, end):
            lines.append("dn: cn=u%d,%s" % (i, self.ou_users))
            lines.append("objectclass: user")
            lines.append("")
        self.ldb.add_ldif('\n'.join(lines))

    def _time_searches(self, expression, rounds, report_format):
        """Run ``expression`` ``rounds`` times under self.ou and write one
        timing line to stderr.

        ``report_format`` is a %-format taking (number of runs,
        expression, elapsed seconds). sys.stderr.write is used rather
        than the Python 2 ``print >>`` statement so the report works on
        both Python 2 and Python 3.
        """
        started = time.time()
        for _ in range(rounds):
            self.ldb.search(self.ou,
                            expression=expression,
                            scope=SCOPE_SUBTREE,
                            attrs=['cn'])
        elapsed = time.time() - started
        sys.stderr.write(report_format % (rounds, expression, elapsed) +
                         '\n')

    def _test_unindexed_search(self):
        """Time 50 runs of each of a few unindexed search expressions."""
        expressions = [
            ('(&(objectclass=user)(description='
             'Built-in account for adminstering the computer/domain))'),
            '(description=Built-in account for adminstering the computer/domain)',
            '(objectCategory=*)',
            '(samaccountname=Administrator*)'
        ]
        for expression in expressions:
            self._time_searches(expression, 50, '%d %s took %s')

    def _test_indexed_search(self):
        """Time 10000 runs of each of a few indexed search expressions."""
        expressions = ['(objectclass=group)',
                       '(samaccountname=Administrator)'
        ]
        for expression in expressions:
            self._time_searches(expression, 10000, '%d runs %s took %s')

    def _test_complex_search(self):
        """Run a fixed pseudo-random sample of two-clause filters."""
        classes = ['samaccountname', 'objectCategory', 'dn', 'member']
        values = ['*', '*t*', 'g*', 'user']
        comparators = ['=', '<=', '>=']  # '~=' causes error
        maybe_not = ['!(', '']
        joiners = ['&', '|']

        # The number of permutations is 18432, which is not huge but
        # would take hours to search. So we take a sample.
        all_permutations = list(itertools.product(joiners,
                                                  classes, classes,
                                                  values, values,
                                                  comparators, comparators,
                                                  maybe_not, maybe_not))
        # Seeded so that every run picks the same sample.
        random.seed(1)

        for (j, c1, c2, v1, v2,
             o1, o2, n1, n2) in random.sample(all_permutations, 100):
            # A negated clause opens an extra bracket with '!(' and so
            # needs an extra closing one.
            expression = ''.join(['(', j,
                                  '(', n1, c1, o1, v1,
                                  '))' if n1 else ')',
                                  '(', n2, c2, o2, v2,
                                  '))' if n2 else ')',
                                  ')'])
            print(expression)
            self.ldb.search(self.ou,
                            expression=expression,
                            scope=SCOPE_SUBTREE,
                            attrs=['cn'])

    def _test_member_search(self, rounds=10):
        """Time ``rounds`` runs of each member/memberOf expression."""
        expressions = []
        for d in range(50):
            expressions.append('(member=cn=u%d,%s)' % (d + 500, self.ou_users))
            expressions.append('(member=u%d*)' % (d + 700,))
        for i in range(N_GROUPS):
            expressions.append('(memberOf=cn=g%d,%s)' % (i, self.ou_groups))
            expressions.append('(memberOf=cn=g%d*)' % (i,))
            expressions.append('(memberOf=cn=*%s*)' % self.ou_groups)

        for expression in expressions:
            self._time_searches(expression, rounds, '%d runs %s took %s')

    def _test_add_many_users(self, n=BATCH_SIZE):
        """Add the next ``n`` users and advance the shared user counter."""
        s = self.state.next_user_id
        e = s + n
        self._add_users(s, e)
        self.state.next_user_id = e

    def _test_add_many_users_ldif(self, n=BATCH_SIZE):
        """Add the next ``n`` users via LDIF and advance the shared user
        counter."""
        s = self.state.next_user_id
        e = s + n
        self._add_users_ldif(s, e)
        self.state.next_user_id = e

    def _link_user_and_group(self, u, g):
        """Add user u<u> as a member of group g<g>."""
        m = Message()
        m.dn = Dn(self.ldb, "CN=g%d,%s" % (g, self.ou_groups))
        m["member"] = MessageElement("cn=u%d,%s" % (u, self.ou_users),
                                     FLAG_MOD_ADD, "member")
        self.ldb.modify(m)

    def _test_link_many_users(self, n=BATCH_SIZE):
        """Link the next ``n`` users to groups and advance the shared
        linked-user counter."""
        self._prepare_n_groups(N_GROUPS)
        s = self.state.next_linked_user
        e = s + n
        for i in range(s, e):
            # put everyone in group 0, and one other group
            g = i % (N_GROUPS - 1) + 1
            self._link_user_and_group(i, g)
            self._link_user_and_group(i, 0)
        self.state.next_linked_user = e

    test_00_01_adding_users_1000 = _test_add_many_users

    test_00_10_complex_search_1k_users = _test_complex_search
    test_00_11_unindexed_search_1k_users = _test_unindexed_search
    test_00_12_indexed_search_1k_users = _test_indexed_search
    test_00_13_member_search_1k_users = _test_member_search

    test_01_02_adding_users_2000_ldif = _test_add_many_users_ldif
    test_01_03_adding_users_3000 = _test_add_many_users

    test_01_10_complex_search_3k_users = _test_complex_search
    test_01_11_unindexed_search_3k_users = _test_unindexed_search
    test_01_12_indexed_search_3k_users = _test_indexed_search

    def test_01_13_member_search_3k_users(self):
        self._test_member_search(rounds=5)

    test_02_01_link_users_1000 = _test_link_many_users
    test_02_02_link_users_2000 = _test_link_many_users
    test_02_03_link_users_3000 = _test_link_many_users

    test_03_10_complex_search_linked_users = _test_complex_search
    test_03_11_unindexed_search_linked_users = _test_unindexed_search
    test_03_12_indexed_search_linked_users = _test_indexed_search

    def test_03_13_member_search_linked_users(self):
        self._test_member_search(rounds=2)
Exemple #24
0
class PyKrb5CredentialsTests(TestCase):
    """Credential-cache tests run with a freshly created machine account."""

    def setUp(self):
        """Read the server settings from the environment, connect over
        LDAP and create the machine account used by the tests."""
        super(PyKrb5CredentialsTests, self).setUp()

        env = os.environ
        self.server = env["SERVER"]
        self.domain = env["DOMAIN"]
        self.host = env["SERVER_IP"]
        self.lp = self.get_loadparm()

        self.credentials = self.get_credentials()

        self.session = system_session()
        ldap_url = "ldap://%s" % self.host
        self.ldb = SamDB(url=ldap_url,
                         session_info=self.session,
                         credentials=self.credentials,
                         lp=self.lp)

        self.create_machine_account()

    def tearDown(self):
        """Run the base tearDown, then remove the machine account."""
        super(PyKrb5CredentialsTests, self).tearDown()
        delete_force(self.ldb, self.machine_dn)

    def create_machine_account(self):
        """Create the machine account and credentials that log in as it."""
        self.machine_pass = samba.generate_random_password(32, 32)
        self.machine_name = MACHINE_NAME
        self.machine_dn = "cn=%s,%s" % (self.machine_name,
                                        self.ldb.domain_dn())

        # A previous failed run may have left the account behind.
        delete_force(self.ldb, self.machine_dn)

        # Round-trip through UTF-8 to get a unicode str on both py2 and
        # py3, then wrap it in double quotes and encode as UTF-16-LE for
        # the unicodePwd attribute.
        pass_unicode = self.machine_pass.encode('utf-8').decode('utf-8')
        quoted_password = u'"{}"'.format(pass_unicode)
        utf16pw = quoted_password.encode('utf-16-le')

        account_flags = UF_WORKSTATION_TRUST_ACCOUNT | UF_PASSWD_NOTREQD
        machine_entry = {
            "dn": self.machine_dn,
            "objectclass": "computer",
            "sAMAccountName": "%s$" % self.machine_name,
            "userAccountControl": str(account_flags),
            "unicodePwd": utf16pw,
        }
        self.ldb.add(machine_entry)

        creds = self.machine_creds = Credentials()
        creds.guess(self.get_loadparm())
        creds.set_password(self.machine_pass)
        creds.set_username(self.machine_name + "$")
        creds.set_workstation(self.machine_name)

    def test_get_named_ccache(self):
        """A ccache requested by name reports that name."""
        wanted_name = "MEMORY:py_creds_machine"
        ccache = self.machine_creds.get_named_ccache(self.lp, wanted_name)
        self.assertEqual(ccache.get_name(), wanted_name)

    def test_get_unnamed_ccache(self):
        """A ccache requested without a name still has a name."""
        ccache = self.machine_creds.get_named_ccache(self.lp)
        self.assertIsNotNone(ccache.get_name())

    def test_set_named_ccache(self):
        """Credentials pointed at an existing ccache by name return a
        ccache with that same name."""
        original = self.machine_creds.get_named_ccache(self.lp)

        other_creds = Credentials()
        other_creds.set_named_ccache(original.get_name())

        reopened = other_creds.get_named_ccache(self.lp)
        self.assertEqual(original.get_name(), reopened.get_name())
Exemple #25
0
class GroupAuditTests(AuditLogTestBase):
    """Tests for the "groupChange" audit messages produced over LDAP.

    Each test starts with a freshly created user USER_NAME and the groups
    GROUP_NAME_01 and GROUP_NAME_02.
    """

    def setUp(self):
        """Connect to the server and create the test user and groups."""
        self.message_type = MSG_GROUP_LOG
        self.event_type = DSDB_GROUP_EVENT_NAME
        super(GroupAuditTests, self).setUp()

        # Every test matches the logged remote address against this value.
        self.remoteAddress = os.environ["CLIENT_IP"]
        self.server_ip = os.environ["SERVER_IP"]

        host = "ldap://%s" % os.environ["SERVER"]
        self.ldb = SamDB(url=host,
                         session_info=system_session(),
                         credentials=self.get_credentials(),
                         lp=self.get_loadparm())
        self.server = os.environ["SERVER"]

        # Gets back the basedn
        self.base_dn = self.ldb.domain_dn()

        # Get the old "dSHeuristics" if it was set
        dsheuristics = self.ldb.get_dsheuristics()

        # Set the "dSHeuristics" to activate the correct "userPassword"
        # behaviour
        self.ldb.set_dsheuristics("000000001")

        # Reset the "dSHeuristics" as they were before
        self.addCleanup(self.ldb.set_dsheuristics, dsheuristics)

        # Get the old "minPwdAge"
        minPwdAge = self.ldb.get_minPwdAge()

        # Set it temporarily to "0"
        self.ldb.set_minPwdAge("0")

        # Reset the "minPwdAge" as it was before
        self.addCleanup(self.ldb.set_minPwdAge, minPwdAge)

        # (Re)adds the test user USER_NAME with password USER_PASS.
        # Remove any copy left behind by an earlier failed run first.
        user_dn = self._user_dn()
        delete_force(self.ldb, user_dn)
        self.ldb.add({
            "dn": user_dn,
            "objectclass": "user",
            "sAMAccountName": USER_NAME,
            "userPassword": USER_PASS
        })
        self.ldb.newgroup(GROUP_NAME_01)
        self.ldb.newgroup(GROUP_NAME_02)

    def tearDown(self):
        """Run the base teardown, then remove the test user and groups."""
        super(GroupAuditTests, self).tearDown()
        delete_force(self.ldb, self._user_dn())
        self.ldb.deletegroup(GROUP_NAME_01)
        self.ldb.deletegroup(GROUP_NAME_02)

    def _user_dn(self):
        """Return the DN of the test user under cn=users."""
        return "cn=" + USER_NAME + ",cn=users," + self.base_dn

    def _group_dn(self, group_name):
        """Return the DN of the group *group_name* under cn=users."""
        return "cn=" + group_name + ",cn=users," + self.base_dn

    def _wait_for_group_messages(self, count):
        """Wait for *count* messages, assert that many arrived, return them."""
        messages = self.waitForMessages(count)
        print("Received %d messages" % len(messages))
        self.assertEqual(count, len(messages),
                         "Did not receive the expected number of messages")
        return messages

    def _assert_group_change(self, message, action, group_name,
                             check_service=True, event_id=None):
        """Check the "groupChange" payload of one audit message.

        message:       a message returned by waitForMessages.
        action:        expected value of the "action" field.
        group_name:    cn of the expected group; the user is always the
                       test user.
        check_service: also require the service description to be "LDAP".
        event_id:      expected "eventId", or None to skip that check.
        """
        audit = message["groupChange"]

        self.assertEqual(action, audit["action"])
        # DNs are compared case-insensitively.
        self.assertEqual(self._user_dn().lower(), audit["user"].lower())
        self.assertEqual(self._group_dn(group_name).lower(),
                         audit["group"].lower())
        self.assertRegexpMatches(audit["remoteAddress"], self.remoteAddress)
        self.assertTrue(self.is_guid(audit["sessionId"]))
        session_id = self.get_session()
        self.assertEqual(session_id, audit["sessionId"])
        if check_service:
            service_description = self.get_service_description()
            self.assertEqual(service_description, "LDAP")
        if event_id is not None:
            self.assertEqual(event_id, audit["eventId"])

    def test_add_and_remove_users_from_group(self):
        """Adding and removing a group member is audited."""
        #
        # Wait for the primary group change for the created user.
        #
        messages = self._wait_for_group_messages(2)
        self._assert_group_change(messages[0], "PrimaryGroup",
                                  "domain users")

        # Check the Add message for the new users primary group
        self._assert_group_change(
            messages[1], "Added", "domain users",
            check_service=False,
            event_id=EVT_ID_USER_ADDED_TO_GLOBAL_SEC_GROUP)

        #
        # Add the user to a group
        #
        self.discardMessages()
        self.ldb.add_remove_group_members(GROUP_NAME_01, [USER_NAME])
        messages = self._wait_for_group_messages(1)
        self._assert_group_change(messages[0], "Added", GROUP_NAME_01)

        #
        # Add the user to another group
        #
        self.discardMessages()
        self.ldb.add_remove_group_members(GROUP_NAME_02, [USER_NAME])
        messages = self._wait_for_group_messages(1)
        self._assert_group_change(messages[0], "Added", GROUP_NAME_02)

        #
        # Remove the user from a group
        #
        self.discardMessages()
        self.ldb.add_remove_group_members(GROUP_NAME_01, [USER_NAME],
                                          add_members_operation=False)
        messages = self._wait_for_group_messages(1)
        self._assert_group_change(messages[0], "Removed", GROUP_NAME_01)

        #
        # Re-add the user to a group
        #
        self.discardMessages()
        self.ldb.add_remove_group_members(GROUP_NAME_01, [USER_NAME])
        messages = self._wait_for_group_messages(1)
        self._assert_group_change(messages[0], "Added", GROUP_NAME_01)

    def test_change_primary_group(self):
        """Changing a user's primary group is audited."""
        #
        # Wait for the primary group change for the created user.
        #
        messages = self._wait_for_group_messages(2)

        # Check the PrimaryGroup message
        self._assert_group_change(messages[0], "PrimaryGroup",
                                  "domain users")

        # Check the Add message for the new users primary group
        self._assert_group_change(
            messages[1], "Added", "domain users",
            check_service=False,
            event_id=EVT_ID_USER_ADDED_TO_GLOBAL_SEC_GROUP)

        #
        # Add the user to a group, the user needs to be a member of a group
        # before their primary group can be set to that group.
        #
        self.discardMessages()
        self.ldb.add_remove_group_members(GROUP_NAME_01, [USER_NAME])
        messages = self._wait_for_group_messages(1)
        self._assert_group_change(
            messages[0], "Added", GROUP_NAME_01,
            event_id=EVT_ID_USER_ADDED_TO_GLOBAL_SEC_GROUP)

        #
        # Change the primary group of a user
        #
        # get the primaryGroupToken of the group
        res = self.ldb.search(base=self._group_dn(GROUP_NAME_01),
                              attrs=["primaryGroupToken"],
                              scope=ldb.SCOPE_BASE)
        group_id = res[0]["primaryGroupToken"]

        # set primaryGroupID attribute of the user to that group
        m = ldb.Message()
        m.dn = ldb.Dn(self.ldb, self._user_dn())
        m["primaryGroupID"] = ldb.MessageElement(group_id, FLAG_MOD_REPLACE,
                                                 "primaryGroupID")
        self.discardMessages()
        self.ldb.modify(m)

        #
        # Wait for the primary group change.
        # Will see the user removed from the new group
        #          the user added to their old primary group
        #          and a new primary group event.
        #
        messages = self._wait_for_group_messages(3)
        self._assert_group_change(
            messages[0], "Removed", GROUP_NAME_01,
            event_id=EVT_ID_USER_REMOVED_FROM_GLOBAL_SEC_GROUP)
        self._assert_group_change(
            messages[1], "Added", "domain users",
            event_id=EVT_ID_USER_ADDED_TO_GLOBAL_SEC_GROUP)
        self._assert_group_change(messages[2], "PrimaryGroup",
                                  GROUP_NAME_01)
class UserTests(samba.tests.TestCase):
    """Tests that populate a per-process OU and time LDAP operations.

    Methods named ``_test_*`` are not collected as tests on their own.
    """

    def add_if_possible(self, *args, **kwargs):
        """In these tests sometimes things are left in the database
        deliberately, so we don't worry if we fail to add them a second
        time."""
        try:
            self.ldb.add(*args, **kwargs)
        except LdbError:
            pass

    def setUp(self):
        """Connect to the database and derive the per-process OU names."""
        super(UserTests, self).setUp()
        self.state = GlobalState  # the class itself, not an instance
        self.lp = lp
        self.ldb = SamDB(host, credentials=creds,
                         session_info=system_session(lp), lp=lp)
        self.base_dn = self.ldb.domain_dn()
        self.ou = "OU=pid%s,%s" % (os.getpid(), self.base_dn)
        self.ou_users = "OU=users,%s" % self.ou
        self.ou_groups = "OU=groups,%s" % self.ou
        self.ou_computers = "OU=computers,%s" % self.ou

        # Seed the RNG from a counter shared through GlobalState, so each
        # test gets its own seed.
        self.state.test_number += 1
        random.seed(self.state.test_number)

    def tearDown(self):
        super(UserTests, self).tearDown()

    def test_00_00_do_nothing(self):
        # this gives us an idea of the overhead
        pass

    def test_00_01_do_nothing_relevant(self):
        # takes around 1 second on i7-4770
        j = 0
        for i in range(30000000):
            j += i

    def test_00_02_do_nothing_sleepily(self):
        time.sleep(1)

    def test_00_03_add_ous_and_groups(self):
        """Create the OUs and N_GROUPS groups used by later tests."""
        # initialise the database
        for dn in (self.ou,
                   self.ou_users,
                   self.ou_groups,
                   self.ou_computers):
            self.ldb.add({
                "dn": dn,
                "objectclass": "organizationalUnit"
            })

        for i in range(N_GROUPS):
            self.ldb.add({
                "dn": "cn=g%d,%s" % (i, self.ou_groups),
                "objectclass": "group"
            })

        self.state.n_groups = N_GROUPS

    def _add_users(self, start, end):
        """Add users u<start> .. u<end - 1> one at a time."""
        for i in range(start, end):
            self.ldb.add({
                "dn": "cn=u%d,%s" % (i, self.ou_users),
                "objectclass": "user"
            })

    def _add_users_ldif(self, start, end):
        """Add users u<start> .. u<end - 1> with a single LDIF document."""
        lines = []
        for i in range(start, end):
            lines.append("dn: cn=u%d,%s" % (i, self.ou_users))
            lines.append("objectclass: user")
            lines.append("")
        self.ldb.add_ldif('\n'.join(lines))

    def _test_join(self):
        """Run "samba-tool domain join <realm> dc" into a scratch directory.

        The scratch directory is removed afterwards, even if the join
        raises.
        """
        tmpdir = tempfile.mkdtemp()
        try:
            # Strip any URL scheme: --server takes the bare host part.
            if '://' in host:
                server = host.split('://', 1)[1]
            else:
                server = host
            cmd = cmd_sambatool.subcommands['domain'].subcommands['join']
            cmd._run("samba-tool domain join",
                     creds.get_realm(),
                     "dc", "-U%s%%%s" % (creds.get_username(),
                                         creds.get_password()),
                     '--targetdir=%s' % tmpdir,
                     '--server=%s' % server)
        finally:
            shutil.rmtree(tmpdir)

    def _test_unindexed_search(self):
        """Time 25 runs of each of several search expressions.

        The elapsed time per expression is printed to stderr.
        """
        expressions = [
            ('(&(objectclass=user)(description='
             'Built-in account for adminstering the computer/domain))'),
            '(description=Built-in account for adminstering the computer/domain)',
            '(objectCategory=*)',
            '(samaccountname=Administrator*)'
        ]
        n_runs = 25
        for expression in expressions:
            t = time.time()
            for _ in range(n_runs):
                self.ldb.search(self.ou,
                                expression=expression,
                                scope=SCOPE_SUBTREE,
                                attrs=['cn'])
            print('%d %s took %s' % (n_runs, expression,
                                     time.time() - t),
                  file=sys.stderr)

    def _test_indexed_search(self):
        """Time 4000 runs of each of two search expressions.

        The elapsed time per expression is printed to stderr.
        """
        expressions = ['(objectclass=group)',
                       '(samaccountname=Administrator)'
                       ]
        n_runs = 4000
        for expression in expressions:
            t = time.time()
            for _ in range(n_runs):
                self.ldb.search(self.ou,
                                expression=expression,
                                scope=SCOPE_SUBTREE,
                                attrs=['cn'])
            print('%d runs %s took %s' % (n_runs, expression,
                                          time.time() - t),
                  file=sys.stderr)

    def _test_base_search(self):
        """Run 4000 base-scope searches on each of the test DNs.

        A missing object is tolerated; any other LdbError is re-raised.
        """
        for dn in [self.base_dn, self.ou, self.ou_users,
                   self.ou_groups, self.ou_computers]:
            for _ in range(4000):
                try:
                    self.ldb.search(dn,
                                    scope=SCOPE_BASE,
                                    attrs=['cn'])
                except LdbError as e:
                    (num, _msg) = e.args
                    if num != ERR_NO_SUCH_OBJECT:
                        raise

    def _test_base_search_failing(self):
        """Run 4000 base-scope searches on names built from 'missing<i>'.

        ERR_NO_SUCH_OBJECT is tolerated; any other LdbError is re-raised.
        """
        pattern = 'missing%d' + self.ou
        for i in range(4000):
            try:
                self.ldb.search(pattern % i,
                                scope=SCOPE_BASE,
                                attrs=['cn'])
            except LdbError as e:
                # "except ... as (num, msg)" is Python 2 only.
                (num, _msg) = e.args
                if num != ERR_NO_SUCH_OBJECT:
                    raise
Exemple #27
0
class AuditLogDsdbTests(AuditLogTestBase):

    def setUp(self):
        """Connect to the server and (re)create the test user.

        Reads CLIENT_IP, SERVER_IP and SERVER from the environment. The
        previous "dSHeuristics" and "minPwdAge" values are restored by
        registered cleanups.
        """
        self.message_type = MSG_DSDB_LOG
        self.event_type = DSDB_EVENT_NAME
        super(AuditLogDsdbTests, self).setUp()

        self.remoteAddress = os.environ["CLIENT_IP"]
        self.server_ip = os.environ["SERVER_IP"]

        host = "ldap://%s" % os.environ["SERVER"]
        self.ldb = SamDB(url=host,
                         session_info=system_session(),
                         credentials=self.get_credentials(),
                         lp=self.get_loadparm())
        self.server = os.environ["SERVER"]

        # Gets back the basedn
        self.base_dn = self.ldb.domain_dn()

        # Get the old "dSHeuristics" if it was set
        dsheuristics = self.ldb.get_dsheuristics()

        # Set the "dSHeuristics" to activate the correct "userPassword"
        # behaviour
        self.ldb.set_dsheuristics("000000001")

        # Reset the "dSHeuristics" as they were before
        self.addCleanup(self.ldb.set_dsheuristics, dsheuristics)

        # Get the old "minPwdAge"
        minPwdAge = self.ldb.get_minPwdAge()

        # Set it temporarily to "0"
        self.ldb.set_minPwdAge("0")

        # Reset the "minPwdAge" as it was before
        self.addCleanup(self.ldb.set_minPwdAge, minPwdAge)

        # (Re)adds the test user USER_NAME with password USER_PASS
        user_dn = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        delete_force(self.ldb, user_dn)
        self.ldb.add({
            "dn": user_dn,
            "objectclass": "user",
            "sAMAccountName": USER_NAME,
            "userPassword": USER_PASS
        })

    #
    # Discard the messages from the setup code
    #
    def discardSetupMessages(self, dn):
        """Wait for the two messages the setup code generates for *dn*,
        then discard them."""
        self.waitForMessages(2, dn=dn)
        self.discardMessages()

    def tearDown(self):
        """Discard pending messages, then run the base class teardown."""
        self.discardMessages()
        super(AuditLogDsdbTests, self).tearDown()

    def haveExpectedTxn(self, expected):
        """Return True if the stored transaction message has the id
        *expected*, False if it does not or no message is stored."""
        if self.context["txnMessage"] is None:
            return False
        transaction = self.context["txnMessage"]["dsdbTransaction"]
        return True if transaction["transactionId"] == expected else False

    def waitForTransaction(self, expected, connection=None):
        """Wait for a transaction message to arrive

        The connection is passed through to keep the connection alive
        until all the logging messages have been received.

        Returns the stored transaction message once its transactionId
        equals *expected*, or "" if it has not arrived after roughly one
        second of polling.
        """

        self.connection = connection
        try:
            start_time = time.time()
            while not self.haveExpectedTxn(expected):
                self.msg_ctx.loop_once(0.1)
                if time.time() - start_time > 1:
                    return ""
            return self.context["txnMessage"]
        finally:
            # Drop the reference on every exit path, including when
            # polling raises.
            self.connection = None

    def test_net_change_password(self):
        """A password change through Net is logged as a redacted Modify."""

        dn = "CN=" + USER_NAME + ",CN=Users," + self.base_dn
        self.discardSetupMessages(dn)

        creds = self.insta_creds(template=self.get_credentials())

        lp = self.get_loadparm()
        net = Net(creds, lp, server=self.server)
        # Generated the same way as in the LDAP password tests of this
        # class.
        password = samba.generate_random_password(32, 32)

        net.change_password(newpassword=password.encode('utf-8'),
                            username=USER_NAME,
                            oldpassword=USER_PASS)

        messages = self.waitForMessages(1, net, dn=dn)
        print("Received %d messages" % len(messages))
        self.assertEqual(1,
                         len(messages),
                         "Did not receive the expected number of messages")

        audit = messages[0]["dsdbChange"]
        self.assertEqual("Modify", audit["operation"])
        self.assertFalse(audit["performedAsSystem"])
        # DNs are compared case-insensitively.
        self.assertEqual(dn.lower(), audit["dn"].lower())
        self.assertRegexpMatches(audit["remoteAddress"],
                                 self.remoteAddress)
        session_id = self.get_session()
        self.assertEqual(session_id, audit["sessionId"])
        # We skip the check for self.get_service_description() as this
        # is subject to a race between smbd and the s4 rpc_server code
        # as to which will set the description as it is DCE/RPC over SMB

        self.assertTrue(self.is_guid(audit["transactionId"]))

        attributes = audit["attributes"]
        self.assertEqual(1, len(attributes))
        actions = attributes["clearTextPassword"]["actions"]
        self.assertEqual(1, len(actions))
        self.assertTrue(actions[0]["redacted"])
        self.assertEqual("replace", actions[0]["action"])

    def test_net_set_password(self):
        """A password set through Net is logged as a redacted Modify."""

        dn = "CN=" + USER_NAME + ",CN=Users," + self.base_dn
        self.discardSetupMessages(dn)

        creds = self.insta_creds(template=self.get_credentials())

        lp = self.get_loadparm()
        net = Net(creds, lp, server=self.server)
        # Generated the same way as in the LDAP password tests of this
        # class.
        password = samba.generate_random_password(32, 32)
        domain = lp.get("workgroup")

        net.set_password(newpassword=password.encode('utf-8'),
                         account_name=USER_NAME,
                         domain_name=domain)
        messages = self.waitForMessages(1, net, dn=dn)
        print("Received %d messages" % len(messages))
        self.assertEqual(1,
                         len(messages),
                         "Did not receive the expected number of messages")
        audit = messages[0]["dsdbChange"]
        self.assertEqual("Modify", audit["operation"])
        self.assertFalse(audit["performedAsSystem"])
        self.assertEqual(dn, audit["dn"])
        self.assertRegexpMatches(audit["remoteAddress"],
                                 self.remoteAddress)
        session_id = self.get_session()
        self.assertEqual(session_id, audit["sessionId"])
        # We skip the check for self.get_service_description() as this
        # is subject to a race between smbd and the s4 rpc_server code
        # as to which will set the description as it is DCE/RPC over SMB

        self.assertTrue(self.is_guid(audit["transactionId"]))

        attributes = audit["attributes"]
        self.assertEqual(1, len(attributes))
        actions = attributes["clearTextPassword"]["actions"]
        self.assertEqual(1, len(actions))
        self.assertTrue(actions[0]["redacted"])
        self.assertEqual("replace", actions[0]["action"])

    def test_ldap_change_password(self):
        """An LDAP delete+add of userPassword is logged with both actions
        redacted."""

        dn = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        self.discardSetupMessages(dn)

        new_password = samba.generate_random_password(32, 32)
        # Delete the password the user was created with and add the new
        # one in a single modify.
        self.ldb.modify_ldif(
            "dn: " + dn + "\n" +
            "changetype: modify\n" +
            "delete: userPassword\n" +
            "userPassword: " + USER_PASS + "\n" +
            "add: userPassword\n" +
            "userPassword: " + new_password + "\n")

        messages = self.waitForMessages(1)
        print("Received %d messages" % len(messages))
        self.assertEqual(1,
                         len(messages),
                         "Did not receive the expected number of messages")

        audit = messages[0]["dsdbChange"]
        self.assertEqual("Modify", audit["operation"])
        self.assertFalse(audit["performedAsSystem"])
        self.assertEqual(dn, audit["dn"])
        self.assertRegexpMatches(audit["remoteAddress"],
                                 self.remoteAddress)
        self.assertTrue(self.is_guid(audit["sessionId"]))
        session_id = self.get_session()
        self.assertEqual(session_id, audit["sessionId"])
        service_description = self.get_service_description()
        self.assertEqual(service_description, "LDAP")

        attributes = audit["attributes"]
        self.assertEqual(1, len(attributes))
        actions = attributes["userPassword"]["actions"]
        self.assertEqual(2, len(actions))
        self.assertTrue(actions[0]["redacted"])
        self.assertEqual("delete", actions[0]["action"])
        self.assertTrue(actions[1]["redacted"])
        self.assertEqual("add", actions[1]["action"])

    def test_ldap_replace_password(self):
        """An LDAP replace of userPassword is logged as a redacted
        replace action."""

        dn = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        self.discardSetupMessages(dn)

        new_password = samba.generate_random_password(32, 32)
        self.ldb.modify_ldif(
            "dn: " + dn + "\n" +
            "changetype: modify\n" +
            "replace: userPassword\n" +
            "userPassword: " + new_password + "\n")

        messages = self.waitForMessages(1, dn=dn)
        print("Received %d messages" % len(messages))
        self.assertEqual(1,
                         len(messages),
                         "Did not receive the expected number of messages")

        audit = messages[0]["dsdbChange"]
        self.assertEqual("Modify", audit["operation"])
        self.assertFalse(audit["performedAsSystem"])
        # DNs are compared case-insensitively.
        self.assertEqual(dn.lower(), audit["dn"].lower())
        self.assertRegexpMatches(audit["remoteAddress"],
                                 self.remoteAddress)
        self.assertTrue(self.is_guid(audit["sessionId"]))
        session_id = self.get_session()
        self.assertEqual(session_id, audit["sessionId"])
        service_description = self.get_service_description()
        self.assertEqual(service_description, "LDAP")
        self.assertTrue(self.is_guid(audit["transactionId"]))

        attributes = audit["attributes"]
        self.assertEqual(1, len(attributes))
        actions = attributes["userPassword"]["actions"]
        self.assertEqual(1, len(actions))
        self.assertTrue(actions[0]["redacted"])
        self.assertEqual("replace", actions[0]["action"])

    def test_ldap_add_user(self):
        """The user added by setUp is logged as an Add with three
        attributes, the password one redacted."""

        # The setup code adds a user, so we check for the dsdb events
        # generated by it.
        dn = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        messages = self.waitForMessages(2, dn=dn)
        print("Received %d messages" % len(messages))
        self.assertEqual(2,
                         len(messages),
                         "Did not receive the expected number of messages")

        audit = messages[1]["dsdbChange"]
        self.assertEqual("Add", audit["operation"])
        self.assertFalse(audit["performedAsSystem"])
        self.assertEqual(dn, audit["dn"])
        self.assertRegexpMatches(audit["remoteAddress"],
                                 self.remoteAddress)
        session_id = self.get_session()
        self.assertEqual(session_id, audit["sessionId"])
        service_description = self.get_service_description()
        self.assertEqual(service_description, "LDAP")
        self.assertTrue(self.is_guid(audit["sessionId"]))
        self.assertTrue(self.is_guid(audit["transactionId"]))

        attributes = audit["attributes"]
        self.assertEqual(3, len(attributes))

        actions = attributes["objectclass"]["actions"]
        self.assertEqual(1, len(actions))
        self.assertEqual("add", actions[0]["action"])
        self.assertEqual(1, len(actions[0]["values"]))
        self.assertEqual("user", actions[0]["values"][0]["value"])

        actions = attributes["sAMAccountName"]["actions"]
        self.assertEqual(1, len(actions))
        self.assertEqual("add", actions[0]["action"])
        self.assertEqual(1, len(actions[0]["values"]))
        self.assertEqual(USER_NAME, actions[0]["values"][0]["value"])

        actions = attributes["userPassword"]["actions"]
        self.assertEqual(1, len(actions))
        self.assertEqual("add", actions[0]["action"])
        self.assertTrue(actions[0]["redacted"])

    def test_samdb_delete_user(self):
        """Deleting the user is logged as a successful Delete, followed by
        a commit message for the same transaction."""

        dn = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        self.discardSetupMessages(dn)

        self.ldb.deleteuser(USER_NAME)

        messages = self.waitForMessages(1, dn=dn)
        print("Received %d messages" % len(messages))
        self.assertEqual(1,
                         len(messages),
                         "Did not receive the expected number of messages")

        audit = messages[0]["dsdbChange"]
        self.assertEqual("Delete", audit["operation"])
        self.assertFalse(audit["performedAsSystem"])
        # DNs are compared case-insensitively.
        self.assertEqual(dn.lower(), audit["dn"].lower())
        self.assertRegexpMatches(audit["remoteAddress"],
                                 self.remoteAddress)
        self.assertTrue(self.is_guid(audit["sessionId"]))
        self.assertEqual(0, audit["statusCode"])
        self.assertEqual("Success", audit["status"])
        session_id = self.get_session()
        self.assertEqual(session_id, audit["sessionId"])
        service_description = self.get_service_description()
        self.assertEqual(service_description, "LDAP")

        transactionId = audit["transactionId"]
        message = self.waitForTransaction(transactionId)
        # waitForTransaction returns "" when it times out.
        self.assertNotEqual("", message,
                            "Did not receive the transaction message")
        audit = message["dsdbTransaction"]
        self.assertEqual("commit", audit["action"])
        self.assertTrue(self.is_guid(audit["transactionId"]))
        self.assertGreater(audit["duration"], 0)

    def test_samdb_delete_non_existent_dn(self):
        """A failed delete is audited as "No such object" and rolled back.

        Deletes a DN that does not exist, then checks the dsdbChange
        record and the dsdbTransaction record with the same
        transactionId.
        """

        DOES_NOT_EXIST = "doesNotExist"
        dn = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        self.discardSetupMessages(dn)

        dn = "cn=" + DOES_NOT_EXIST + ",cn=users," + self.base_dn
        # self.fail() raises an AssertionError, which "except Exception"
        # swallowed while the call sat inside the try body; report the
        # missing error from the else branch instead.
        try:
            self.ldb.delete(dn)
        except Exception:
            pass
        else:
            self.fail("Exception not thrown")

        messages = self.waitForMessages(1)
        print("Received %d messages" % len(messages))
        self.assertEqual(1,
                         len(messages),
                         "Did not receive the expected number of messages")

        audit = messages[0]["dsdbChange"]
        self.assertEqual("Delete", audit["operation"])
        self.assertFalse(audit["performedAsSystem"])
        # Case-insensitive DN comparison.  The previous assertTrue(a, b)
        # form used b as the failure message and so could never fail.
        self.assertEqual(dn.lower(), audit["dn"].lower())
        self.assertRegexpMatches(audit["remoteAddress"],
                                 self.remoteAddress)
        self.assertEqual(ERR_NO_SUCH_OBJECT, audit["statusCode"])
        self.assertEqual("No such object", audit["status"])
        self.assertTrue(self.is_guid(audit["sessionId"]))
        session_id = self.get_session()
        self.assertEqual(session_id, audit["sessionId"])
        service_description = self.get_service_description()
        self.assertEqual(service_description, "LDAP")

        # The failed change must belong to a rolled-back transaction.
        transactionId = audit["transactionId"]
        message = self.waitForTransaction(transactionId)
        audit = message["dsdbTransaction"]
        self.assertEqual("rollback", audit["action"])
        self.assertTrue(self.is_guid(audit["transactionId"]))
        self.assertGreater(audit["duration"], 0)

    def test_create_and_delete_secret_over_lsa(self):
        """Creating and deleting an LSA secret is audited as Add then Delete.

        The secret is created and removed through the lsarpc pipe; both
        audit records are expected to be flagged performedAsSystem.
        """

        dn = "cn=Test Secret,CN=System," + self.base_dn
        self.discardSetupMessages(dn)

        creds = self.insta_creds(template=self.get_credentials())
        lsa_conn = lsa.lsarpc(
            "ncacn_np:%s" % self.server,
            self.get_loadparm(),
            creds)
        lsa_handle = lsa_conn.OpenPolicy2(
            system_name="\\",
            attr=lsa.ObjectAttribute(),
            access_mask=security.SEC_FLAG_MAXIMUM_ALLOWED)
        secret_name = lsa.String()
        secret_name.string = "G$Test"
        lsa_conn.CreateSecret(
            handle=lsa_handle,
            name=secret_name,
            access_mask=security.SEC_FLAG_MAXIMUM_ALLOWED)

        messages = self.waitForMessages(1, dn=dn)
        print("Received %d messages" % len(messages))
        self.assertEqual(1,
                         len(messages),
                         "Did not receive the expected number of messages")

        audit = messages[0]["dsdbChange"]
        self.assertEqual("Add", audit["operation"])
        self.assertTrue(audit["performedAsSystem"])
        # Case-insensitive DN comparison.  The previous assertTrue(a, b)
        # form used b as the failure message and so could never fail.
        self.assertEqual(dn.lower(), audit["dn"].lower())
        self.assertRegexpMatches(audit["remoteAddress"],
                                 self.remoteAddress)
        self.assertTrue(self.is_guid(audit["sessionId"]))
        session_id = self.get_session()
        self.assertEqual(session_id, audit["sessionId"])

        # We skip the check for self.get_service_description() as this
        # is subject to a race between smbd and the s4 rpc_server code
        # as to which will set the description as it is DCE/RPC over SMB

        attributes = audit["attributes"]
        self.assertEqual(2, len(attributes))

        object_class = attributes["objectClass"]
        self.assertEqual(1, len(object_class["actions"]))
        action = object_class["actions"][0]
        self.assertEqual("add", action["action"])
        values = action["values"]
        self.assertEqual(1, len(values))
        self.assertEqual("secret", values[0]["value"])

        cn = attributes["cn"]
        self.assertEqual(1, len(cn["actions"]))
        action = cn["actions"][0]
        self.assertEqual("add", action["action"])
        values = action["values"]
        self.assertEqual(1, len(values))
        self.assertEqual("Test Secret", values[0]["value"])

        #
        # Now delete the secret.
        self.discardMessages()
        h = lsa_conn.OpenSecret(
            handle=lsa_handle,
            name=secret_name,
            access_mask=security.SEC_FLAG_MAXIMUM_ALLOWED)

        lsa_conn.DeleteObject(h)
        messages = self.waitForMessages(1, dn=dn)
        print("Received %d messages" % len(messages))
        self.assertEqual(1,
                         len(messages),
                         "Did not receive the expected number of messages")

        # dn still names the secret; it has not been rebound since the
        # top of the test.
        audit = messages[0]["dsdbChange"]
        self.assertEqual("Delete", audit["operation"])
        self.assertTrue(audit["performedAsSystem"])
        self.assertEqual(dn.lower(), audit["dn"].lower())
        self.assertRegexpMatches(audit["remoteAddress"],
                                 self.remoteAddress)
        self.assertTrue(self.is_guid(audit["sessionId"]))
        session_id = self.get_session()
        self.assertEqual(session_id, audit["sessionId"])

        # We skip the check for self.get_service_description() as this
        # is subject to a race between smbd and the s4 rpc_server code
        # as to which will set the description as it is DCE/RPC over SMB


    def test_modify(self):
        """Each carLicense modification is audited with the right action.

        Covers adding one value, another value and two values at once,
        then deleting two values and replacing with two values, checking
        the single dsdbChange message logged after every step.
        """

        dn = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        self.discardSetupMessages(dn)

        def wait_for_single_change():
            # Wait for exactly one audit message for dn and return its
            # dsdbChange payload.
            messages = self.waitForMessages(1, dn=dn)
            print("Received %d messages" % len(messages))
            self.assertEqual(1,
                             len(messages),
                             "Did not receive the expected number of messages")
            return messages[0]["dsdbChange"]

        def check_car_license(audit, action, expected_values):
            # The change must touch only carLicense, with one action of
            # the given kind carrying expected_values in order.
            attributes = audit["attributes"]
            self.assertEqual(1, len(attributes))
            actions = attributes["carLicense"]["actions"]
            self.assertEqual(1, len(actions))
            self.assertEqual(action, actions[0]["action"])
            values = actions[0]["values"]
            self.assertEqual(len(expected_values), len(values))
            for expected, logged in zip(expected_values, values):
                self.assertEqual(expected, logged["value"])

        #
        # Add an attribute value
        #
        self.ldb.modify_ldif(
            "dn: " + dn + "\n" +
            "changetype: modify\n" +
            "add: carLicense\n" +
            "carLicense: license-01\n")

        audit = wait_for_single_change()
        self.assertEqual("Modify", audit["operation"])
        self.assertFalse(audit["performedAsSystem"])
        self.assertEqual(dn, audit["dn"])
        self.assertRegexpMatches(audit["remoteAddress"],
                                 self.remoteAddress)
        self.assertTrue(self.is_guid(audit["sessionId"]))
        session_id = self.get_session()
        self.assertEqual(session_id, audit["sessionId"])
        service_description = self.get_service_description()
        self.assertEqual(service_description, "LDAP")
        check_car_license(audit, "add", ["license-01"])

        #
        # Add another value to the attribute
        #
        self.discardMessages()
        self.ldb.modify_ldif(
            "dn: " + dn + "\n" +
            "changetype: modify\n" +
            "add: carLicense\n" +
            "carLicense: license-02\n")

        check_car_license(wait_for_single_change(), "add", ["license-02"])

        #
        # Add another two values to the attribute
        #
        self.discardMessages()
        self.ldb.modify_ldif(
            "dn: " + dn + "\n" +
            "changetype: modify\n" +
            "add: carLicense\n" +
            "carLicense: license-03\n" +
            "carLicense: license-04\n")

        check_car_license(wait_for_single_change(), "add",
                          ["license-03", "license-04"])

        #
        # Delete two values from the attribute
        #
        # NOTE(review): this LDIF and the next one say "changetype: delete"
        # although they describe a modification of carLicense; left
        # unchanged here - confirm whether "changetype: modify" was meant.
        self.discardMessages()
        self.ldb.modify_ldif(
            "dn: " + dn + "\n" +
            "changetype: delete\n" +
            "delete: carLicense\n" +
            "carLicense: license-03\n" +
            "carLicense: license-04\n")

        check_car_license(wait_for_single_change(), "delete",
                          ["license-03", "license-04"])

        #
        # Replace the attribute with two values
        #
        self.discardMessages()
        self.ldb.modify_ldif(
            "dn: " + dn + "\n" +
            "changetype: delete\n" +
            "replace: carLicense\n" +
            "carLicense: license-05\n" +
            "carLicense: license-06\n")

        check_car_license(wait_for_single_change(), "replace",
                          ["license-05", "license-06"])
class LATests(samba.tests.TestCase):

    def setUp(self):
        """Connect to the server and create the OU the tests work in.

        When the delete_in_setup option is set, the OU is tree-deleted
        first; an LdbError from that delete is only printed.
        """
        super(LATests, self).setUp()
        session = system_session(lp)
        self.samdb = SamDB(host,
                           credentials=creds,
                           session_info=session,
                           lp=lp)

        domain_dn = self.samdb.domain_dn()
        self.base_dn = domain_dn
        self.ou = "OU=la,%s" % domain_dn

        if opts.delete_in_setup:
            try:
                self.samdb.delete(self.ou, ['tree_delete:1'])
            except ldb.LdbError as err:
                print("tried deleting %s, got error %s" % (self.ou, err))

        ou_record = {'objectclass': 'organizationalUnit', 'dn': self.ou}
        self.samdb.add(ou_record)

    def tearDown(self):
        """Run the base teardown, then tree-delete the test OU.

        The OU is kept when the no_cleanup option is set.
        """
        super(LATests, self).tearDown()
        if opts.no_cleanup:
            return
        self.samdb.delete(self.ou, ['tree_delete:1'])

    def add_object(self, cn, objectclass, more_attrs=None):
        """Add an object named cn directly under the test OU.

        :param cn: common name; the DN becomes "CN=<cn>,<self.ou>"
        :param objectclass: value for the objectclass attribute
        :param more_attrs: optional mapping of extra attributes; on a key
            clash its entries replace the generated cn/objectclass/dn
        :return: the DN string of the new object
        """
        dn = "CN=%s,%s" % (cn, self.ou)
        attrs = {'cn': cn,
                 'objectclass': objectclass,
                 'dn': dn}
        # None (instead of a shared mutable {} default) means "no extras".
        if more_attrs:
            attrs.update(more_attrs)
        self.samdb.add(attrs)

        return dn

    def add_objects(self, n, objectclass, prefix=None, more_attrs=None):
        """Add n objects named <prefix>1 .. <prefix>n under the test OU.

        :param n: how many objects to create
        :param objectclass: objectclass of every object
        :param prefix: name prefix; defaults to objectclass
        :param more_attrs: optional mapping of extra attributes given to
            every object
        :return: list of the new DNs, in creation order
        """
        if prefix is None:
            prefix = objectclass
        # Normalise here so add_object() always receives a mapping, and
        # no mutable default is shared between calls.
        if more_attrs is None:
            more_attrs = {}
        return [self.add_object("%s%d" % (prefix, i + 1),
                                objectclass,
                                more_attrs=more_attrs)
                for i in range(n)]

    def add_linked_attribute(self, src, dest, attr='member',
                             controls=None):
        """Add dest to the attr attribute of src with one modify request.

        dest is handed to ldb.MessageElement unchanged; callers in this
        class pass either a single DN or a sequence of DNs.
        """
        change = ldb.Message()
        change.dn = ldb.Dn(self.samdb, src)
        element = ldb.MessageElement(dest, ldb.FLAG_MOD_ADD, attr)
        change[attr] = element
        self.samdb.modify(change, controls=controls)

    def remove_linked_attribute(self, src, dest, attr='member',
                                controls=None):
        """Delete dest from the attr attribute of src with one modify.

        dest is handed to ldb.MessageElement unchanged; callers in this
        class pass a single DN, a sequence of DNs or an empty list.
        """
        change = ldb.Message()
        change.dn = ldb.Dn(self.samdb, src)
        element = ldb.MessageElement(dest, ldb.FLAG_MOD_DELETE, attr)
        change[attr] = element
        self.samdb.modify(change, controls=controls)

    def replace_linked_attribute(self, src, dest, attr='member',
                                 controls=None):
        """Replace the attr attribute of src with dest in one modify.

        dest is handed to ldb.MessageElement unchanged; callers in this
        class pass a sequence of DNs, possibly empty.
        """
        change = ldb.Message()
        change.dn = ldb.Dn(self.samdb, src)
        element = ldb.MessageElement(dest, ldb.FLAG_MOD_REPLACE, attr)
        change[attr] = element
        self.samdb.modify(change, controls=controls)

    def attr_search(self, obj, attr, scope=ldb.SCOPE_BASE, **controls):
        """Search obj for attr, passing each extra keyword as a control.

        Every keyword name=value becomes the control string
        "name:<int(value)>".  With the no_reveal_internals option set, a
        reveal_internals keyword is dropped before the search.

        :return: whatever self.samdb.search() returns
        """
        if opts.no_reveal_internals:
            controls.pop('reveal_internals', None)

        control_strings = []
        for name, value in controls.items():
            control_strings.append('%s:%d' % (name, int(value)))

        return self.samdb.search(obj,
                                 scope=scope,
                                 attrs=[attr],
                                 controls=control_strings)

    def assert_links(self, obj, expected, attr, msg='', **kwargs):
        """Check that attr on obj holds exactly the expected values.

        Order is ignored.  An empty expected means attr must be absent
        from the first search result.  Extra keywords are forwarded to
        attr_search().  On a mismatch, msg and both sorted lists are
        printed before the equality assertion fails.
        """
        res = self.attr_search(obj, attr, **kwargs)

        expecting_nothing = len(expected) == 0
        if expecting_nothing:
            entry = res[0]
            if attr in entry:
                self.fail("found attr '%s' in %s" % (attr, entry))
            return

        try:
            values_per_entry = [entry[attr] for entry in res]
            results = list(values_per_entry[0])
        except KeyError:
            self.fail("missing attr '%s' on %s" % (attr, obj))

        wanted = sorted(expected)
        found = sorted(results)

        if wanted != found:
            print(msg)
            print("expected %s" % wanted)
            print("received %s" % found)

        self.assertEqual(found, wanted)

    def assert_back_links(self, obj, expected, attr='memberOf', **kwargs):
        """Check the back-link attribute (memberOf by default) of obj.

        Delegates to assert_links(), supplying the mismatch text.
        """
        mismatch_text = 'back links do not match'
        self.assert_links(obj, expected, attr, msg=mismatch_text, **kwargs)

    def assert_forward_links(self, obj, expected, attr='member', **kwargs):
        """Check the forward-link attribute (member by default) of obj.

        Delegates to assert_links(), supplying the mismatch text.
        """
        mismatch_text = 'forward links do not match'
        self.assert_links(obj, expected, attr, msg=mismatch_text, **kwargs)

    def get_object_guid(self, dn):
        """Return the objectGUID of dn as a string (base-scope search)."""
        found = self.samdb.search(dn,
                                  scope=ldb.SCOPE_BASE,
                                  attrs=['objectGUID'])
        raw_guid = found[0]['objectGUID'][0]
        return str(misc.GUID(raw_guid))

    def assertRaisesLdbError(self, errcode, msg, f, *args, **kwargs):
        """Assert a function raises a particular LdbError.

        :param errcode: the expected LdbError number
        :param msg: caller's description, used as the start of the
            failure text
        :param f: the callable, invoked as f(*args, **kwargs)

        The test fails if f returns normally or raises an LdbError with
        a different number; exceptions of other types propagate.
        """
        def err_name(code):
            # Reverse lookup of the ERR_* constant name in the ldb module;
            # None when no integer constant has this value.
            lut = {v: k for k, v in vars(ldb).items()
                   if k.startswith('ERR_') and isinstance(v, int)}
            return lut.get(code)

        try:
            f(*args, **kwargs)
        except ldb.LdbError as e:
            # Unpack into estr, not msg: rebinding msg here used to
            # replace the caller's description in the failure text.
            (num, estr) = e.args
            if num != errcode:
                self.fail("%s, expected "
                          "LdbError %s, (%d) "
                          "got %s (%d): %s" % (msg,
                                               err_name(errcode), errcode,
                                               err_name(num), num,
                                               estr))
        else:
            self.fail("%s, expected "
                      "LdbError %s, (%d) "
                      "but we got success" % (msg,
                                              err_name(errcode), errcode))

    def _test_la_backlinks(self, reveal=False):
        """Link two users into two groups and check each user's back links.

        With reveal set, the object names get a "_reveal" suffix and the
        checks pass reveal_internals=0 to the search.
        """
        if reveal:
            tag = 'backlinks_reveal'
            search_controls = {'reveal_internals': 0}
        else:
            tag = 'backlinks'
            search_controls = {}

        user1, user2 = self.add_objects(2, 'user', 'u_' + tag)
        group1, group2 = self.add_objects(2, 'group', 'g_' + tag)

        memberships = ((group1, user1), (group2, user1), (group2, user2))
        for group, user in memberships:
            self.add_linked_attribute(group, user)

        self.assert_back_links(user1, [group1, group2], **search_controls)
        self.assert_back_links(user2, [group2], **search_controls)

    def test_la_backlinks(self):
        """Run the back-link check without reveal_internals."""
        self._test_la_backlinks(reveal=False)

    def test_la_backlinks_reveal(self):
        """Run the back-link check with reveal_internals.

        Only prints a notice when the no_reveal_internals option is set.
        """
        if not opts.no_reveal_internals:
            self._test_la_backlinks(True)
        else:
            print('skipping because --no-reveal-internals')

    def _test_la_backlinks_delete_group(self, reveal=False):
        """Check users' back links after one of their groups is deleted.

        With reveal set, the object names get a "_reveal" suffix and the
        checks pass reveal_internals=0 to the search.
        """
        if reveal:
            tag = 'del_group_reveal'
            search_controls = {'reveal_internals': 0}
        else:
            tag = 'del_group'
            search_controls = {}

        user1, user2 = self.add_objects(2, 'user', 'u_%s' % tag)
        group1, group2 = self.add_objects(2, 'group', 'g_%s' % tag)

        memberships = ((group1, user1), (group2, user1), (group2, user2))
        for group, user in memberships:
            self.add_linked_attribute(group, user)

        self.samdb.delete(group2, ['tree_delete:1'])

        # Only the surviving group may still show up.
        self.assert_back_links(user1, [group1], **search_controls)
        self.assert_back_links(user2, set(), **search_controls)

    def test_la_backlinks_delete_group(self):
        """Run the delete-group back-link check without reveal_internals."""
        self._test_la_backlinks_delete_group(reveal=False)

    def test_la_backlinks_delete_group_reveal(self):
        """Run the delete-group back-link check with reveal_internals.

        Only prints a notice when the no_reveal_internals option is set.
        """
        if not opts.no_reveal_internals:
            self._test_la_backlinks_delete_group(True)
        else:
            print('skipping because --no-reveal-internals')

    def test_links_all_delete_group(self):
        """Check links after deleting a group, searching with the
        show_deleted, show_recycled and show_deactivated_link controls.
        """
        user1, user2 = self.add_objects(2, 'user', 'u_all_del_group')
        group1, group2 = self.add_objects(2, 'group', 'g_all_del_group')
        group2_guid = self.get_object_guid(group2)

        memberships = ((group1, user1), (group2, user1), (group2, user2))
        for group, user in memberships:
            self.add_linked_attribute(group, user)

        self.samdb.delete(group2)

        # Every check below searches with the same controls.
        controls = dict(show_deleted=1,
                        show_recycled=1,
                        show_deactivated_link=0)
        self.assert_back_links(user1, [group1], **controls)
        self.assert_back_links(user2, set(), **controls)
        self.assert_forward_links(group1, [user1], **controls)
        # The deleted group is addressed by its GUID.
        self.assert_forward_links('<GUID=%s>' % group2_guid, [], **controls)

    def test_links_all_delete_group_reveal(self):
        """Check links after deleting a group, searching with the
        show_deleted, show_recycled, show_deactivated_link and
        reveal_internals controls.
        """
        user1, user2 = self.add_objects(2, 'user', 'u_all_del_group_reveal')
        group1, group2 = self.add_objects(2, 'group',
                                          'g_all_del_group_reveal')
        group2_guid = self.get_object_guid(group2)

        memberships = ((group1, user1), (group2, user1), (group2, user2))
        for group, user in memberships:
            self.add_linked_attribute(group, user)

        self.samdb.delete(group2)

        # Every check below searches with the same controls.
        controls = dict(show_deleted=1,
                        show_recycled=1,
                        show_deactivated_link=0,
                        reveal_internals=0)
        self.assert_back_links(user1, [group1], **controls)
        self.assert_back_links(user2, set(), **controls)
        self.assert_forward_links(group1, [user1], **controls)
        # The deleted group is addressed by its GUID.
        self.assert_forward_links('<GUID=%s>' % group2_guid, [], **controls)

    def test_la_links_delete_link(self):
        """Adding and removing links changes the group's uSNChanged and
        leaves exactly the remaining members as forward links.
        """
        def usn_changed(dn):
            # Current uSNChanged of dn as an int (base-scope search).
            found = self.samdb.search(dn, scope=ldb.SCOPE_BASE,
                                      attrs=['uSNChanged'])
            return int(found[0]['uSNChanged'][0])

        user1, user2 = self.add_objects(2, 'user', 'u_del_link')
        group1, group2 = self.add_objects(2, 'group', 'g_del_link')

        # Adding a link must change the group's USN ...
        usn_before_add = usn_changed(group1)
        self.add_linked_attribute(group1, user1)
        usn_after_add = usn_changed(group1)
        self.assertNotEqual(usn_before_add, usn_after_add,
                            "USN should have incremented")

        self.add_linked_attribute(group2, user1)
        self.add_linked_attribute(group2, user2)

        # ... and so must removing one.
        usn_before_remove = usn_changed(group2)
        self.remove_linked_attribute(group2, user1)
        usn_after_remove = usn_changed(group2)
        self.assertNotEqual(usn_before_remove, usn_after_remove,
                            "USN should have incremented")

        self.assert_forward_links(group1, [user1])
        self.assert_forward_links(group2, [user2])

        self.add_linked_attribute(group2, user1)
        self.assert_forward_links(group2, [user1, user2])
        self.remove_linked_attribute(group2, user2)
        self.assert_forward_links(group2, [user1])
        self.remove_linked_attribute(group2, user1)
        self.assert_forward_links(group2, [])
        # A delete with an empty value list is expected to leave group1
        # without members.
        self.remove_linked_attribute(group1, [])
        self.assert_forward_links(group1, [])

        # removing a duplicate link in the same message should fail
        self.add_linked_attribute(group2, [user1, user2])
        with self.assertRaises(ldb.LdbError):
            self.remove_linked_attribute(group2, [user1, user1])

    def _test_la_links_delete_link_reveal(self):
        """After a link is removed, a search with reveal_internals and
        the show_* controls is expected to list both original members.
        """
        user1, user2 = self.add_objects(2, 'user', 'u_del_link_reveal')
        group1, group2 = self.add_objects(2, 'group', 'g_del_link_reveal')

        memberships = ((group1, user1), (group2, user1), (group2, user2))
        for group, user in memberships:
            self.add_linked_attribute(group, user)

        self.remove_linked_attribute(group2, user1)

        controls = dict(show_deleted=1,
                        show_recycled=1,
                        show_deactivated_link=0,
                        reveal_internals=0)
        self.assert_forward_links(group2, [user1, user2], **controls)

    def test_la_links_delete_link_reveal(self):
        """Run the removed-link check that needs reveal_internals.

        Only prints a notice when the no_reveal_internals option is set.
        """
        if not opts.no_reveal_internals:
            self._test_la_links_delete_link_reveal()
        else:
            print('skipping because --no-reveal-internals')

    def test_la_links_delete_user(self):
        """Deleting a user drops it from its groups' member lists
        without changing the groups' uSNChanged.
        """
        def usn_changed(dn):
            # Current uSNChanged of dn as an int (base-scope search).
            found = self.samdb.search(dn, scope=ldb.SCOPE_BASE,
                                      attrs=['uSNChanged'])
            return int(found[0]['uSNChanged'][0])

        user1, user2 = self.add_objects(2, 'user', 'u_del_user')
        group1, group2 = self.add_objects(2, 'group', 'g_del_user')

        memberships = ((group1, user1), (group2, user1), (group2, user2))
        for group, user in memberships:
            self.add_linked_attribute(group, user)

        group1_usn_before = usn_changed(group1)
        group2_usn_before = usn_changed(group2)

        self.samdb.delete(user1)

        self.assert_forward_links(group1, [])
        self.assert_forward_links(group2, [user2])

        group1_usn_after = usn_changed(group1)
        group2_usn_after = usn_changed(group2)

        # Assert the USN on the alternate object is unchanged
        self.assertEqual(group1_usn_before, group1_usn_after)
        self.assertEqual(group2_usn_before, group2_usn_after)

    def test_la_links_delete_user_reveal(self):
        """Check the groups' forward links after a member is deleted,
        searching with reveal_internals and the show_* controls.
        """
        user1, user2 = self.add_objects(2, 'user', 'u_del_user_reveal')
        group1, group2 = self.add_objects(2, 'group', 'g_del_user_reveal')

        memberships = ((group1, user1), (group2, user1), (group2, user2))
        for group, user in memberships:
            self.add_linked_attribute(group, user)

        self.samdb.delete(user1)

        controls = dict(show_deleted=1,
                        show_recycled=1,
                        show_deactivated_link=0,
                        reveal_internals=0)
        self.assert_forward_links(group2, [user2], **controls)
        self.assert_forward_links(group1, [], **controls)

    def test_multiple_links(self):
        """Add and remove several links per request and check both the
        forward and the back links after every stage.
        """
        u1, u2, u3, u4 = self.add_objects(4, 'user', 'u_multiple_links')
        # Four groups are created; only the first three get members.
        g1, g2, g3, _spare = self.add_objects(4, 'group', 'g_multiple_links')

        def check(forward, back):
            # forward: (group, members) pairs; back: (user, groups) pairs.
            for group, members in forward:
                self.assert_forward_links(group, members)
            for user, groups in back:
                self.assert_back_links(user, groups)

        self.add_linked_attribute(g1, [u1, u2, u3, u4])
        self.add_linked_attribute(g2, [u3, u1])
        self.add_linked_attribute(g3, u2)

        # A request naming the same user twice must be rejected.
        self.assertRaisesLdbError(ldb.ERR_ENTRY_ALREADY_EXISTS,
                                  "adding duplicate values",
                                  self.add_linked_attribute, g2,
                                  [u1, u2, u3, u2])

        check(forward=[(g1, [u1, u2, u3, u4]),
                       (g2, [u3, u1]),
                       (g3, [u2])],
              back=[(u1, [g2, g1]),
                    (u2, [g3, g1]),
                    (u3, [g2, g1]),
                    (u4, [g1])])

        self.remove_linked_attribute(g2, [u1, u3])
        self.remove_linked_attribute(g1, [u1, u3])

        check(forward=[(g1, [u2, u4]),
                       (g2, []),
                       (g3, [u2])],
              back=[(u1, []),
                    (u2, [g3, g1]),
                    (u3, []),
                    (u4, [g1])])

        self.add_linked_attribute(g1, [u1, u3])
        self.add_linked_attribute(g2, [u3, u1])
        self.add_linked_attribute(g3, [u1, u3])

        check(forward=[(g1, [u1, u2, u3, u4]),
                       (g2, [u1, u3]),
                       (g3, [u1, u2, u3])],
              back=[(u1, [g1, g2, g3]),
                    (u2, [g3, g1]),
                    (u3, [g3, g2, g1]),
                    (u4, [g1])])

    def test_la_links_replace(self):
        """Replace the member lists of four groups twice and check the
        forward and back links after each round.
        """
        u1, u2, u3, u4 = self.add_objects(4, 'user', 'u_replace')
        g1, g2, g3, g4 = self.add_objects(4, 'group', 'g_replace')

        def replace_all(new_members):
            # new_members: (group, members) pairs, applied in order.
            for group, members in new_members:
                self.replace_linked_attribute(group, members)

        def check(forward, back):
            # forward: (group, members) pairs; back: (user, groups) pairs.
            for group, members in forward:
                self.assert_forward_links(group, members)
            for user, groups in back:
                self.assert_back_links(user, groups)

        self.add_linked_attribute(g1, [u1, u2])
        self.add_linked_attribute(g2, [u1, u3])
        self.add_linked_attribute(g3, u1)

        replace_all([(g1, [u2]),
                     (g2, [u2, u3]),
                     (g3, [u1, u3]),
                     (g4, [u4])])

        check(forward=[(g1, [u2]),
                       (g2, [u3, u2]),
                       (g3, [u3, u1]),
                       (g4, [u4])],
              back=[(u1, [g3]),
                    (u2, [g1, g2]),
                    (u3, [g2, g3]),
                    (u4, [g4])])

        replace_all([(g1, [u1, u2, u3]),
                     (g2, [u1]),
                     (g3, [u2]),
                     (g4, [])])

        check(forward=[(g1, [u1, u2, u3]),
                       (g2, [u1]),
                       (g3, [u2]),
                       (g4, [])],
              back=[(u1, [g1, g2]),
                    (u2, [g1, g3]),
                    (u3, [g1]),
                    (u4, [])])

        # A replacement naming the same user twice must be rejected.
        self.assertRaisesLdbError(ldb.ERR_ENTRY_ALREADY_EXISTS,
                                  "replacing duplicate values",
                                  self.replace_linked_attribute, g2,
                                  [u1, u2, u3, u2])


    def test_la_links_replace2(self):
        """Grow, shrink and empty one group's member list with add,
        replace and remove, checking the forward links at every step.
        """
        users = self.add_objects(12, 'user', 'u_replace2')
        [group] = self.add_objects(1, 'group', 'g_replace2')

        first_half = users[:6]
        second_half = users[6:]

        self.add_linked_attribute(group, first_half)
        self.assert_forward_links(group, first_half)

        self.replace_linked_attribute(group, users)
        self.assert_forward_links(group, users)

        self.replace_linked_attribute(group, second_half)
        self.assert_forward_links(group, second_half)

        # Remove the second half in two batches.
        self.remove_linked_attribute(group, users[6:9])
        self.assert_forward_links(group, users[9:])

        self.remove_linked_attribute(group, users[9:])
        self.assert_forward_links(group, [])

    def test_la_links_permutations(self):
        """Make sure the order in which we add links doesn't matter."""
        users = self.add_objects(3, 'user', 'u_permutations')
        groups = self.add_objects(6, 'group', 'g_permutations')

        # One ordering of the three users per group.
        orderings = list(itertools.permutations(users))

        def check_everywhere(members, memberships):
            # Every group must list members; every user memberships.
            for group in groups:
                self.assert_forward_links(group, members)
            for user in users:
                self.assert_back_links(user, memberships)

        for group, ordering in zip(groups, orderings):
            self.add_linked_attribute(group, ordering)

        # everyone should be in every group
        check_everywhere(users, groups)

        # Replace with the orderings assigned to the groups in reverse.
        for group, ordering in zip(reversed(groups), orderings):
            self.replace_linked_attribute(group, ordering)

        check_everywhere(users, groups)

        for group, ordering in zip(groups, orderings):
            self.remove_linked_attribute(group, ordering)

        check_everywhere([], [])

    def test_la_links_relaxed(self):
        """Check that the relax control doesn't mess with linked attributes."""
        import random

        relax = ['relax:0']

        users = self.add_objects(10, 'user', 'u_relax')
        initial_members = {'member': users[:2]}
        groups = self.add_objects(3, 'group', 'g_relax',
                                  more_attrs=initial_members)
        relaxed_bulk, relaxed_single, strict = groups

        # relaxed_bulk gets the next three users in one relaxed request,
        # relaxed_single gets them one at a time in reverse order,
        # strict gets them the same way but never uses the relax control.
        batch = users[2:5]
        self.add_linked_attribute(relaxed_bulk, batch, controls=relax)

        for user in reversed(batch):
            self.add_linked_attribute(relaxed_single, user, controls=relax)
            self.add_linked_attribute(strict, user)

        for group in groups:
            self.assert_forward_links(group, users[:5])

            self.add_linked_attribute(group, users[5:7])
            self.assert_forward_links(group, users[:7])

            for user in users[7:]:
                self.add_linked_attribute(group, user)

            self.assert_forward_links(group, users)

        for user in users:
            self.assert_back_links(user, groups)

        # try some replacement permutations
        random.seed(1)
        shuffled = list(users)
        for _ in range(5):
            random.shuffle(shuffled)
            self.replace_linked_attribute(relaxed_bulk, shuffled,
                                          controls=relax)

            self.assert_forward_links(relaxed_bulk, users)

        for _ in range(5):
            # Empty all three groups ...
            random.shuffle(shuffled)
            self.remove_linked_attribute(relaxed_single, shuffled,
                                         controls=relax)
            self.remove_linked_attribute(strict, shuffled)

            self.replace_linked_attribute(relaxed_bulk, [], controls=relax)

            # ... and refill them in a fresh order.
            random.shuffle(shuffled)
            self.add_linked_attribute(relaxed_single, shuffled,
                                      controls=relax)
            self.add_linked_attribute(strict, shuffled)
            self.replace_linked_attribute(relaxed_bulk, shuffled,
                                          controls=relax)

            for group in (relaxed_bulk, relaxed_single, strict):
                self.assert_forward_links(group, users)

        for user in users:
            self.assert_back_links(user, groups)

    def test_add_all_at_once(self):
        """Create groups whose member links are supplied at creation time.

        The other tests add linked attributes to objects that already
        exist; here the 'member' values are passed to add_objects together
        with the group itself.
        """
        members = self.add_objects(7, 'user', 'u_all_at_once')
        [full_a, full_b] = self.add_objects(2, 'group', 'g_all_at_once',
                                            more_attrs={'member': members})
        [partial] = self.add_objects(1, 'group', 'g_all_at_once2',
                                     more_attrs={'member': members[:5]})

        # members[1] appears twice in this list, which must be rejected
        duplicated = members[:5] + members[1:2]
        self.assertRaisesLdbError(ldb.ERR_ENTRY_ALREADY_EXISTS,
                                  "adding multiple duplicate values",
                                  self.add_objects, 1, 'group',
                                  'g_with_duplicate_links',
                                  more_attrs={'member': duplicated})

        for group, expected in ((full_a, members),
                                (partial, members[:5]),
                                (full_b, members)):
            self.assert_forward_links(group, expected)

        # the first five users are in all three groups, the rest in two
        for position, user in enumerate(members):
            if position < 5:
                expected_groups = [full_a, partial, full_b]
            else:
                expected_groups = [full_a, full_b]
            self.assert_back_links(user, expected_groups)

        # drop users 0 and 1 from the partial group, then put back user 1
        # and add the two users it never had
        for position in (0, 1):
            self.remove_linked_attribute(partial, members[position])
        for position in (1, 5, 6):
            self.add_linked_attribute(partial, members[position])

        self.assert_forward_links(full_a, members)
        self.assert_forward_links(partial, members[1:])

        for user in members[1:]:
            self.remove_linked_attribute(partial, user)
        self.remove_linked_attribute(full_a, members)

        for user in members:
            self.samdb.delete(user)

        for group in (full_a, partial, full_b):
            self.assert_forward_links(group, [])

    def test_one_way_attributes(self):
        """Link two objects via addressBookRoots, then delete the target.

        After the deletion the forward link on the source is expected to
        name the DN found by searching for the target's GUID with the
        show_deleted and show_recycled controls.
        """
        link_attr = 'addressBookRoots'
        source, target = self.add_objects(2, 'msExchConfigurationContainer',
                                          'e_one_way')
        # fetch the GUID before the delete so the object can be found later
        target_guid = self.get_object_guid(target)

        self.add_linked_attribute(source, target, attr=link_attr)
        self.assert_forward_links(source, [target], attr=link_attr)

        self.samdb.delete(target)

        found = self.samdb.search("<GUID=%s>" % target_guid,
                                  scope=ldb.SCOPE_BASE,
                                  controls=['show_deleted:1',
                                            'show_recycled:1'])
        new_dn = str(found[0].dn)

        # once with the default arguments, once with show_deactivated_link=0
        for extra_args in ({}, {'show_deactivated_link': 0}):
            self.assert_forward_links(source, [new_dn],
                                      attr=link_attr,
                                      **extra_args)

    def test_one_way_attributes_delete_link(self):
        """Add an addressBookRoots link, remove it, and expect no links."""
        link_attr = 'addressBookRoots'
        source, target = self.add_objects(2, 'msExchConfigurationContainer',
                                          'e_one_way')
        # The GUID is not used afterwards; the lookup is kept so the
        # sequence of calls made by this test stays the same.
        self.get_object_guid(target)

        self.add_linked_attribute(source, target, attr=link_attr)
        self.assert_forward_links(source, [target], attr=link_attr)

        self.remove_linked_attribute(source, target, attr=link_attr)

        # once with the default arguments, once with show_deactivated_link=0
        for extra_args in ({}, {'show_deactivated_link': 0}):
            self.assert_forward_links(source, [], attr=link_attr,
                                      **extra_args)

    def test_pretend_one_way_attributes(self):
        """Link two objects via addressBookRoots2, then delete the target.

        Unlike the addressBookRoots case, the source is expected to have
        no forward links left once the target has been deleted.
        """
        link_attr = 'addressBookRoots2'
        source, target = self.add_objects(2, 'msExchConfigurationContainer',
                                          'e_one_way')
        target_guid = self.get_object_guid(target)

        self.add_linked_attribute(source, target, attr=link_attr)
        self.assert_forward_links(source, [target], attr=link_attr)

        self.samdb.delete(target)
        found = self.samdb.search("<GUID=%s>" % target_guid,
                                  scope=ldb.SCOPE_BASE,
                                  controls=['show_deleted:1',
                                            'show_recycled:1'])

        # The DN itself is not used below; indexing the result still
        # requires that the search returned an entry.
        str(found[0].dn)

        # once with the default arguments, once with show_deactivated_link=0
        for extra_args in ({}, {'show_deactivated_link': 0}):
            self.assert_forward_links(source, [], attr=link_attr,
                                      **extra_args)

    def test_pretend_one_way_attributes_delete_link(self):
        """Add an addressBookRoots2 link, remove it, and expect no links."""
        link_attr = 'addressBookRoots2'
        source, target = self.add_objects(2, 'msExchConfigurationContainer',
                                          'e_one_way')
        # The GUID is not used afterwards; the lookup is kept so the
        # sequence of calls made by this test stays the same.
        self.get_object_guid(target)

        self.add_linked_attribute(source, target, attr=link_attr)
        self.assert_forward_links(source, [target], attr=link_attr)

        self.remove_linked_attribute(source, target, attr=link_attr)

        # once with the default arguments, once with show_deactivated_link=0
        for extra_args in ({}, {'show_deactivated_link': 0}):
            self.assert_forward_links(source, [], attr=link_attr,
                                      **extra_args)
Exemple #29
0
    def test_rid_set_dbcheck_after_seize(self):
        """Perform a join against the RID manager and assert we have a RID Set.

        We seize the RID master role, then using dbcheck, we assert that we can
        detect out of range users (and then bump the RID set as required).
        """

        def pool_end(samdb, rid_set_dn):
            # Upper 32 bits of the rIDAllocationPool value stored on the
            # RID Set object.
            rid_set_res = samdb.search(base=rid_set_dn,
                                       scope=ldb.SCOPE_BASE,
                                       attrs=['rIDNextRid',
                                              'rIDAllocationPool'])
            next_pool = int(rid_set_res[0]["rIDAllocationPool"][0])
            return (0xFFFFFFFF00000000 & next_pool) >> 32

        def add_user_with_rid(samdb, cn, rid):
            # Add a user under CN=Users whose objectSid is the domain SID
            # followed by the given RID; the add carries the relax control.
            m = ldb.Message()
            m.dn = ldb.Dn(samdb, "CN=%s,CN=Users" % cn)
            m.dn.add_base(samdb.get_default_basedn())
            m['objectClass'] = ldb.MessageElement('user', ldb.FLAG_MOD_ADD,
                                                  'objectClass')
            sid = security.dom_sid("%s-%d" % (samdb.get_domain_sid(), rid))
            m['objectSid'] = ldb.MessageElement(ndr_pack(sid),
                                                ldb.FLAG_MOD_ADD,
                                                'objectSid')
            samdb.add(m, controls=["relax:0"])

        fsmo_dn = ldb.Dn(self.ldb_dc1,
                         "CN=RID Manager$,CN=System," +
                         self.ldb_dc1.domain_dn())
        (fsmo_owner, fsmo_not_owner) = self._determine_fSMORoleOwner(fsmo_dn)

        targetdir = self._test_join(fsmo_owner['dns_name'], "RIDALLOCTEST7")
        try:
            # Connect to the database
            ldb_url = "tdb://%s" % os.path.join(targetdir, "private/sam.ldb")
            smbconf = os.path.join(targetdir, "etc/smb.conf")

            lp = self.get_loadparm()
            new_ldb = SamDB(ldb_url, credentials=self.get_credentials(),
                            session_info=system_session(lp), lp=lp)

            # 1. Get server name
            res = new_ldb.search(base=ldb.Dn(new_ldb,
                                             new_ldb.get_serverName()),
                                 scope=ldb.SCOPE_BASE,
                                 attrs=["serverReference"])
            # 2. Get server reference
            server_ref_dn = ldb.Dn(new_ldb, res[0]['serverReference'][0])

            # 3. Assert we get the RID Set
            res = new_ldb.search(base=server_ref_dn,
                                 scope=ldb.SCOPE_BASE,
                                 attrs=['rIDSetReferences'])

            self.assertTrue("rIDSetReferences" in res[0])
            rid_set_dn = ldb.Dn(new_ldb, res[0]["rIDSetReferences"][0])

            # 4. Seize the RID Manager role
            (result, out, err) = self.runsubcmd("fsmo", "seize",
                                                "--role", "rid",
                                                "-H", ldb_url,
                                                "-s", smbconf,
                                                "--force")
            self.assertCmdSuccess(result, out, err)
            self.assertEqual(err, "", "Shouldn't be any error messages")

            # 5. Add a new user (triggers RID set work)
            new_ldb.newuser("ridalloctestuser", "P@ssword!")

            # 6. Now fetch the RID SET
            last_rid = pool_end(new_ldb, rid_set_dn)

            # 7. Add user above the ridNextRid and at almost the end of
            #    the range.
            add_user_with_rid(new_ldb, "ridsettestuser2", last_rid - 3)

            # 8. Add user above the ridNextRid and at the end of the range
            add_user_with_rid(new_ldb, "ridsettestuser3", last_rid)

            chk = dbcheck(new_ldb, verbose=False, fix=True, yes=True,
                          quiet=True)

            # Should have fixed two errors (wrong ridNextRid)
            self.assertEqual(chk.check_database(DN=rid_set_dn,
                                                scope=ldb.SCOPE_BASE), 2)

            # 9. Create a second, non-fixing dbcheck instance.
            # NOTE(review): check_database() is never called on it, so
            # nothing is actually asserted at this step.
            chk = dbcheck(new_ldb, verbose=False, fix=False, quiet=True)

            # 10. Add another user (checks RID rollover)
            # We have seized the role, so we can do that.
            new_ldb.newuser("ridalloctestuser3", "P@ssword!")

            self.assertNotEqual(last_rid, pool_end(new_ldb, rid_set_dn),
                                "rid pool should have changed")
        finally:
            self._test_force_demote(fsmo_owner['dns_name'], "RIDALLOCTEST7")
            shutil.rmtree(targetdir, ignore_errors=True)
Exemple #30
0
class PyCredentialsTests(TestCase):
    """Tests for the netlogon-related methods of the Credentials bindings.

    setUp connects to the server named by the SERVER, DOMAIN and SERVER_IP
    environment variables and creates a machine account and a user account
    over LDAP; tearDown removes both accounts again.
    """

    def setUp(self):
        super(PyCredentialsTests, self).setUp()

        self.server = os.environ["SERVER"]
        self.domain = os.environ["DOMAIN"]
        self.host = os.environ["SERVER_IP"]
        self.lp = self.get_loadparm()

        self.credentials = self.get_credentials()

        self.session = system_session()
        self.ldb = SamDB(url="ldap://%s" % self.host,
                         session_info=self.session,
                         credentials=self.credentials,
                         lp=self.lp)

        self.create_machine_account()
        self.create_user_account()

    def tearDown(self):
        super(PyCredentialsTests, self).tearDown()
        delete_force(self.ldb, self.machine_dn)
        delete_force(self.ldb, self.user_dn)

    @staticmethod
    def _byte_values(data):
        """Return the bytes of *data* as a list of integers.

        Iterating a byte string yields one-character strings on Python 2
        and integers on Python 3; both are handled here.
        """
        return [b if isinstance(b, int) else ord(b) for b in data]

    @staticmethod
    def _quoted_utf16(password):
        """Return *password* wrapped in double quotes, as UTF-16-LE bytes.

        This is the value stored in the unicodePwd attribute when the test
        accounts are created.
        """
        return ('"' + password + '"').encode('utf-16-le')

    def _samlogon_ex(self, c, logon, fail_on_unexpected=False):
        """Call netr_LogonSamLogonEx for *logon* over connection *c*.

        An NT_STATUS_WRONG_PASSWORD error always fails the test.  Any other
        NTSTATUSError is re-raised, unless *fail_on_unexpected* is true, in
        which case it is turned into a test failure as well.
        """
        logon_level = netlogon.NetlogonNetworkTransitiveInformation
        validation_level = netlogon.NetlogonValidationSamInfo4
        netr_flags = 0

        try:
            c.netr_LogonSamLogonEx(self.server,
                                   self.user_creds.get_workstation(),
                                   logon_level,
                                   logon,
                                   validation_level,
                                   netr_flags)
        except NTSTATUSError as e:
            # the status arrives as a signed value; compare it unsigned
            status = ctypes.c_uint32(e.args[0]).value
            if status == ntstatus.NT_STATUS_WRONG_PASSWORD:
                self.fail("got wrong password error")
            if fail_on_unexpected:
                self.fail("got unexpected error" + str(e))
            raise

    # Until a successful netlogon connection has been established there will
    # not be a valid authenticator associated with the credentials
    # and new_client_authenticator should throw a ValueError
    def test_no_netlogon_connection(self):
        self.assertRaises(ValueError,
                          self.machine_creds.new_client_authenticator)

    # Once a netlogon connection has been established,
    # new_client_authenticator should return a value
    #
    def test_have_netlogon_connection(self):
        # the connection object itself is not used, it only has to exist
        c = self.get_netlogon_connection()
        a = self.machine_creds.new_client_authenticator()
        self.assertIsNotNone(a)

    # Get an authenticator and use it on a sequence of operations requiring
    # an authenticator
    def test_client_authenticator(self):
        c = self.get_netlogon_connection()
        (authenticator, subsequent) = self.get_authenticator(c)
        self.do_NetrLogonSamLogonWithFlags(c, authenticator, subsequent)
        (authenticator, subsequent) = self.get_authenticator(c)
        self.do_NetrLogonGetDomainInfo(c, authenticator, subsequent)
        (authenticator, subsequent) = self.get_authenticator(c)
        self.do_NetrLogonGetDomainInfo(c, authenticator, subsequent)
        (authenticator, subsequent) = self.get_authenticator(c)
        self.do_NetrLogonGetDomainInfo(c, authenticator, subsequent)

    def test_SamLogonEx(self):
        c = self.get_netlogon_connection()

        logon = samlogon_logon_info(self.domain,
                                    self.machine_name,
                                    self.user_creds)

        self._samlogon_ex(c, logon)

    def test_SamLogonEx_no_domain(self):
        c = self.get_netlogon_connection()

        self.user_creds.set_domain('')

        logon = samlogon_logon_info(self.domain,
                                    self.machine_name,
                                    self.user_creds)

        # here every NTSTATUSError counts as a test failure
        self._samlogon_ex(c, logon, fail_on_unexpected=True)

    def test_SamLogonExNTLM(self):
        c = self.get_netlogon_connection()

        logon = samlogon_logon_info(self.domain,
                                    self.machine_name,
                                    self.user_creds,
                                    flags=CLI_CRED_NTLM_AUTH)

        self._samlogon_ex(c, logon)

    def test_SamLogonExMSCHAPv2(self):
        c = self.get_netlogon_connection()

        logon = samlogon_logon_info(self.domain,
                                    self.machine_name,
                                    self.user_creds,
                                    flags=CLI_CRED_NTLM_AUTH)

        logon.identity_info.parameter_control = MSV1_0_ALLOW_MSVCHAPV2

        self._samlogon_ex(c, logon)

    # Test Credentials.encrypt_netr_crypt_password
    # By performing a NetrServerPasswordSet2
    # And the logging on using the new password.
    def test_encrypt_netr_password(self):
        # Change the password
        self.do_Netr_ServerPasswordSet2()
        # Now use the new password to perform an operation
        srvsvc.srvsvc("ncacn_np:%s" % (self.server),
                      self.lp,
                      self.machine_creds)

    # Change the current machine account password with a
    # netr_ServerPasswordSet2 call.
    def do_Netr_ServerPasswordSet2(self):
        c = self.get_netlogon_connection()
        (authenticator, subsequent) = self.get_authenticator(c)
        PWD_LEN = 32
        DATA_LEN = 512
        newpass = samba.generate_random_password(PWD_LEN, PWD_LEN)
        encoded = newpass.encode('utf-16-le')
        pwd_len = len(encoded)
        # random filler first, the encoded password in the trailing bytes,
        # DATA_LEN values in total
        filler = self._byte_values(os.urandom(DATA_LEN - pwd_len))
        pwd = netlogon.netr_CryptPassword()
        pwd.length = pwd_len
        pwd.data = filler + self._byte_values(encoded)
        self.machine_creds.encrypt_netr_crypt_password(pwd)
        c.netr_ServerPasswordSet2(self.server,
                                  self.machine_creds.get_workstation(),
                                  SEC_CHAN_WKSTA,
                                  self.machine_name,
                                  authenticator,
                                  pwd)

        # keep the local copies in step with the password just set
        self.machine_pass = newpass
        self.machine_creds.set_password(newpass)

    # Establish sealed schannel netlogon connection over TCP/IP
    #
    def get_netlogon_connection(self):
        return netlogon.netlogon("ncacn_ip_tcp:%s[schannel,seal]" % self.server,
                                 self.lp,
                                 self.machine_creds)

    #
    # Create the machine account
    def create_machine_account(self):
        self.machine_pass = samba.generate_random_password(32, 32)
        self.machine_name = MACHINE_NAME
        self.machine_dn = "cn=%s,%s" % (self.machine_name, self.ldb.domain_dn())

        # remove the account if it exists, this will happen if a previous test
        # run failed
        delete_force(self.ldb, self.machine_dn)

        utf16pw = self._quoted_utf16(self.machine_pass)
        self.ldb.add({
            "dn": self.machine_dn,
            "objectclass": "computer",
            "sAMAccountName": "%s$" % self.machine_name,
            "userAccountControl":
                str(UF_WORKSTATION_TRUST_ACCOUNT | UF_PASSWD_NOTREQD),
            "unicodePwd": utf16pw})

        self.machine_creds = Credentials()
        self.machine_creds.guess(self.get_loadparm())
        self.machine_creds.set_secure_channel_type(SEC_CHAN_WKSTA)
        self.machine_creds.set_kerberos_state(DONT_USE_KERBEROS)
        self.machine_creds.set_password(self.machine_pass)
        self.machine_creds.set_username(self.machine_name + "$")
        self.machine_creds.set_workstation(self.machine_name)

    #
    # Create a test user account
    def create_user_account(self):
        self.user_pass = samba.generate_random_password(32, 32)
        self.user_name = USER_NAME
        self.user_dn = "cn=%s,%s" % (self.user_name, self.ldb.domain_dn())

        # remove the account if it exists, this will happen if a previous test
        # run failed
        delete_force(self.ldb, self.user_dn)

        utf16pw = self._quoted_utf16(self.user_pass)
        self.ldb.add({
            "dn": self.user_dn,
            "objectclass": "user",
            "sAMAccountName": "%s" % self.user_name,
            "userAccountControl": str(UF_NORMAL_ACCOUNT),
            "unicodePwd": utf16pw})

        self.user_creds = Credentials()
        self.user_creds.guess(self.get_loadparm())
        self.user_creds.set_password(self.user_pass)
        self.user_creds.set_username(self.user_name)
        self.user_creds.set_workstation(self.machine_name)

    #
    # Get the authenticator from the machine creds.
    def get_authenticator(self, c):
        auth = self.machine_creds.new_client_authenticator()
        current = netr_Authenticator()
        current.cred.data = self._byte_values(auth["credential"])
        current.timestamp = auth["timestamp"]

        # left at its default values; passed alongside `current` by callers
        subsequent = netr_Authenticator()
        return (current, subsequent)

    def do_NetrLogonSamLogonWithFlags(self, c, current, subsequent):
        logon = samlogon_logon_info(self.domain,
                                    self.machine_name,
                                    self.user_creds)

        logon_level = netlogon.NetlogonNetworkTransitiveInformation
        validation_level = netlogon.NetlogonValidationSamInfo4
        netr_flags = 0
        c.netr_LogonSamLogonWithFlags(self.server,
                                      self.user_creds.get_workstation(),
                                      current,
                                      subsequent,
                                      logon_level,
                                      logon,
                                      validation_level,
                                      netr_flags)

    def do_NetrLogonGetDomainInfo(self, c, current, subsequent):
        query = netr_WorkstationInformation()

        c.netr_LogonGetDomainInfo(self.server,
                                  self.user_creds.get_workstation(),
                                  current,
                                  subsequent,
                                  2,
                                  query)
Exemple #31
0
    def run(self,
            psoname,
            precedence,
            H=None,
            min_pwd_age=None,
            max_pwd_age=None,
            complexity=None,
            store_plaintext=None,
            history_length=None,
            min_pwd_length=None,
            account_lockout_duration=None,
            account_lockout_threshold=None,
            reset_account_lockout_after=None,
            credopts=None,
            sambaopts=None,
            versionopts=None):
        """Create a new PSO named *psoname* with the given precedence.

        Any password setting left as None is filled in from the attributes
        currently stored on the domain object.  CommandError is raised if
        the precedence is not an integer, the PSO already exists, no
        password setting was given on the command line, or the add fails.
        """
        lp = sambaopts.get_loadparm()
        creds = credopts.get_credentials(lp)

        samdb = SamDB(url=H,
                      session_info=system_session(),
                      credentials=creds,
                      lp=lp)

        try:
            precedence = int(precedence)
        except ValueError:
            raise CommandError("The PSO's precedence should be "
                               "a numerical value. Try --help")

        # sanity-check that the PSO doesn't already exist: the search is
        # expected to fail with "no such object"
        pso_dn = "CN=%s,%s" % (psoname, pso_container(samdb))
        try:
            samdb.search(pso_dn, scope=ldb.SCOPE_BASE)
        except ldb.LdbError as search_err:
            if search_err.args[0] != ldb.ERR_NO_SUCH_OBJECT:
                raise
        else:
            raise CommandError("PSO '%s' already exists" % psoname)

        # we expect the user to specify at least one password-policy setting,
        # otherwise there's no point in creating a PSO
        num_pwd_args = num_options_in_args(pwd_settings_options, self.raw_argv)
        if num_pwd_args == 0:
            raise CommandError("Please specify at least one password policy "
                               "setting. Try --help")

        # it's unlikely that the user will specify all 9 password policy
        # settings on the CLI - current domain password-settings as the default
        # values for unspecified arguments
        if num_pwd_args < len(pwd_settings_options):
            for line in ("Not all password policy options "
                         "have been specified.",
                         "For unspecified options, the current domain password"
                         " settings will be used as the default values."):
                self.message(line)

        # lookup the current domain password-settings
        domain_attrs = [
            "pwdProperties", "pwdHistoryLength",
            "minPwdLength", "minPwdAge", "maxPwdAge",
            "lockoutDuration", "lockoutThreshold",
            "lockOutObservationWindow"
        ]
        res = samdb.search(samdb.domain_dn(),
                           scope=ldb.SCOPE_BASE,
                           attrs=domain_attrs)
        assert (len(res) == 1)
        domain = res[0]

        # use the domain settings for any missing arguments
        pwd_props = int(domain["pwdProperties"][0])
        if complexity is None:
            if pwd_props & DOMAIN_PASSWORD_COMPLEX:
                complexity = "on"
            else:
                complexity = "off"

        if store_plaintext is None:
            if pwd_props & DOMAIN_PASSWORD_STORE_CLEARTEXT:
                store_plaintext = "on"
            else:
                store_plaintext = "off"

        if history_length is None:
            history_length = int(domain["pwdHistoryLength"][0])

        if min_pwd_length is None:
            min_pwd_length = int(domain["minPwdLength"][0])

        if min_pwd_age is None:
            min_pwd_age = timestamp_to_days(domain["minPwdAge"][0])

        if max_pwd_age is None:
            max_pwd_age = timestamp_to_days(domain["maxPwdAge"][0])

        if account_lockout_duration is None:
            lockout_raw = domain["lockoutDuration"][0]
            account_lockout_duration = timestamp_to_mins(lockout_raw)

        if account_lockout_threshold is None:
            account_lockout_threshold = int(domain["lockoutThreshold"][0])

        if reset_account_lockout_after is None:
            window_raw = domain["lockOutObservationWindow"][0]
            reset_account_lockout_after = timestamp_to_mins(window_raw)

        check_pso_constraints(max_pwd_age=max_pwd_age,
                              min_pwd_age=min_pwd_age,
                              history_length=history_length,
                              min_pwd_length=min_pwd_length)

        # pack the settings into an LDB message
        pso_msg = make_pso_ldb_msg(
            self.outf,
            samdb,
            pso_dn,
            create=True,
            complexity=complexity,
            precedence=precedence,
            store_plaintext=store_plaintext,
            history_length=history_length,
            min_pwd_length=min_pwd_length,
            min_pwd_age=min_pwd_age,
            max_pwd_age=max_pwd_age,
            lockout_duration=account_lockout_duration,
            lockout_threshold=account_lockout_threshold,
            reset_lockout_after=reset_account_lockout_after)

        # create the new PSO
        try:
            samdb.add(pso_msg)
            self.message("PSO successfully created: %s" % pso_dn)
            # display the new PSO's settings
            show_pso_by_dn(self.outf, samdb, pso_dn, show_applies_to=False)
        except ldb.LdbError as add_err:
            (num, msg) = add_err.args
            if num == ldb.ERR_INSUFFICIENT_ACCESS_RIGHTS:
                raise CommandError("Administrator permissions are needed "
                                   "to create a PSO.")
            raise CommandError("Failed to create PSO '%s': %s" %
                               (pso_dn, msg))
Exemple #32
0
    def test_rid_set_dbcheck_after_seize(self):
        """Perform a join against the RID manager and assert we have a RID Set.

        We seize the RID master role, then using dbcheck, we assert that we can
        detect out of range users (and then bump the RID set as required).
        """

        fsmo_dn = ldb.Dn(
            self.ldb_dc1,
            "CN=RID Manager$,CN=System," + self.ldb_dc1.domain_dn())
        (fsmo_owner, fsmo_not_owner) = self._determine_fSMORoleOwner(fsmo_dn)

        targetdir = self._test_join(fsmo_owner['dns_name'], "RIDALLOCTEST7")
        try:
            # Connect to the database
            ldb_url = "tdb://%s" % os.path.join(targetdir, "private/sam.ldb")
            smbconf = os.path.join(targetdir, "etc/smb.conf")

            lp = self.get_loadparm()
            new_ldb = SamDB(ldb_url,
                            credentials=self.get_credentials(),
                            session_info=system_session(lp),
                            lp=lp)

            # 1. Get server name
            res = new_ldb.search(base=ldb.Dn(new_ldb,
                                             new_ldb.get_serverName()),
                                 scope=ldb.SCOPE_BASE,
                                 attrs=["serverReference"])
            # 2. Get server reference
            server_ref_dn = ldb.Dn(new_ldb, res[0]['serverReference'][0])

            # 3. Assert we get the RID Set
            res = new_ldb.search(base=server_ref_dn,
                                 scope=ldb.SCOPE_BASE,
                                 attrs=['rIDSetReferences'])

            self.assertTrue("rIDSetReferences" in res[0])
            rid_set_dn = ldb.Dn(new_ldb, res[0]["rIDSetReferences"][0])

            # 4. Seize the RID Manager role
            (result, out, err) = self.runsubcmd("fsmo", "seize", "--role",
                                                "rid", "-H", ldb_url, "-s",
                                                smbconf, "--force")
            self.assertCmdSuccess(result, out, err)
            self.assertEqual(err, "", "Shouldn't be any error messages")

            # 5. Add a new user (triggers RID set work)
            new_ldb.newuser("ridalloctestuser", "P@ssword!")

            # 6. Now fetch the RID SET
            rid_set_res = new_ldb.search(
                base=rid_set_dn,
                scope=ldb.SCOPE_BASE,
                attrs=['rIDNextRid', 'rIDAllocationPool'])
            next_pool = int(rid_set_res[0]["rIDAllocationPool"][0])
            # the end of the pool is held in the upper 32 bits
            last_rid = (0xFFFFFFFF00000000 & next_pool) >> 32

            # 7. Add user above the ridNextRid and at almost the end of the
            #    range.
            # 8. Add user above the ridNextRid and at the end of the range
            out_of_range_users = (("ridsettestuser2", last_rid - 3),
                                  ("ridsettestuser3", last_rid))
            for cn, rid in out_of_range_users:
                m = ldb.Message()
                m.dn = ldb.Dn(new_ldb, "CN=%s,CN=Users" % cn)
                m.dn.add_base(new_ldb.get_default_basedn())
                m['objectClass'] = ldb.MessageElement('user',
                                                      ldb.FLAG_MOD_ADD,
                                                      'objectClass')
                # objectSid is the domain SID followed by the chosen RID
                sid = security.dom_sid(
                    "%s-%d" % (new_ldb.get_domain_sid(), rid))
                m['objectSid'] = ldb.MessageElement(ndr_pack(sid),
                                                    ldb.FLAG_MOD_ADD,
                                                    'objectSid')
                new_ldb.add(m, controls=["relax:0"])

            chk = dbcheck(new_ldb,
                          verbose=False,
                          fix=True,
                          yes=True,
                          quiet=True)

            # Should have fixed two errors (wrong ridNextRid)
            self.assertEqual(
                chk.check_database(DN=rid_set_dn, scope=ldb.SCOPE_BASE), 2)

            # 9. Create a second, non-fixing dbcheck instance.
            # NOTE(review): check_database() is never called on it, so
            # nothing is actually asserted at this step.
            chk = dbcheck(new_ldb, verbose=False, fix=False, quiet=True)

            # 10. Add another user (checks RID rollover)
            # We have seized the role, so we can do that.
            new_ldb.newuser("ridalloctestuser3", "P@ssword!")

            rid_set_res = new_ldb.search(
                base=rid_set_dn,
                scope=ldb.SCOPE_BASE,
                attrs=['rIDNextRid', 'rIDAllocationPool'])
            next_pool = int(rid_set_res[0]["rIDAllocationPool"][0])
            self.assertNotEqual(last_rid,
                                (0xFFFFFFFF00000000 & next_pool) >> 32,
                                "rid pool should have changed")
        finally:
            self._test_force_demote(fsmo_owner['dns_name'], "RIDALLOCTEST7")
            shutil.rmtree(targetdir, ignore_errors=True)
Exemple #33
0
class DsdbTests(TestCase):
    def setUp(self):
        """Open a SamDB connection and create a temporary test user.

        The new account's DN is stored in ``self.account_dn`` and the
        account is deleted again by a registered cleanup.
        """
        super(DsdbTests, self).setUp()

        self.lp = samba.tests.env_loadparm()
        self.creds = Credentials()
        self.creds.guess(self.lp)
        self.session = system_session()
        self.samdb = SamDB(lp=self.lp,
                           credentials=self.creds,
                           session_info=self.session)

        # The account name carries six random hex digits
        account = "dsdb-user-%s" % uuid.uuid4().hex[:6]
        secret = samba.generate_random_password(32, 32)
        domain = self.samdb.domain_dn()

        self.account_dn = "cn=%s,cn=Users,%s" % (account, domain)
        self.samdb.newuser(username=account,
                           password=secret,
                           description="Test user for dsdb test")

        # Remove the account once the test has finished
        self.addCleanup(delete_force, self.samdb, self.account_dn)

    def test_get_oid_from_attrid(self):
        oid = self.samdb.get_oid_from_attid(591614)
        self.assertEquals(oid, "1.2.840.113556.1.4.1790")

    def test_error_replpropertymetadata(self):
        """Reject a replPropertyMetaData write that only bumps a version.

        The metadata entry for description (attid 13) has its version
        incremented, the blob is repacked and written back with the
        1.3.6.1.4.1.7165.4.3.14 control; the modify must raise LdbError.
        """
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        for o in repl.ctr.array:
            # attid 13 is the description attribute
            if o.attid == 13:
                o.version = o.version + 1
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        self.assertRaises(ldb.LdbError, self.samdb.modify, msg,
                          ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_error_replpropertymetadata_nochange(self):
        """Writing back an unmodified replPropertyMetaData blob must fail.

        The blob is unpacked and repacked without changes, then sent as
        a replace with the 1.3.6.1.4.1.7165.4.3.14 control.
        """
        found = self.samdb.search(base=self.account_dn,
                                  scope=ldb.SCOPE_SUBTREE,
                                  attrs=["replPropertyMetaData"])
        entry = found[0]
        metadata = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                              entry["replPropertyMetaData"][0])
        packed = ndr_pack(metadata)

        change = ldb.Message()
        change.dn = entry.dn
        change["replPropertyMetaData"] = ldb.MessageElement(
            packed, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")

        with self.assertRaises(ldb.LdbError):
            self.samdb.modify(change,
                              ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_error_replpropertymetadata_allow_sort(self):
        """An unmodified blob is accepted when both controls are sent.

        Same round-trip as the "nochange" case, but the modify carries
        the ...4.3.14 and ...4.3.25 controls and is expected to succeed.
        """
        found = self.samdb.search(base=self.account_dn,
                                  scope=ldb.SCOPE_SUBTREE,
                                  attrs=["replPropertyMetaData"])
        entry = found[0]
        metadata = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                              entry["replPropertyMetaData"][0])
        packed = ndr_pack(metadata)

        change = ldb.Message()
        change.dn = entry.dn
        change["replPropertyMetaData"] = ldb.MessageElement(
            packed, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")

        controls = [
            "local_oid:1.3.6.1.4.1.7165.4.3.14:0",
            "local_oid:1.3.6.1.4.1.7165.4.3.25:0",
        ]
        self.samdb.modify(change, controls)

    def test_twoatt_replpropertymetadata(self):
        """Reject a metadata write combined with a second attribute change.

        The description metadata entry (attid 13) gets a new version and
        local USN, and the same modify also replaces ``description``.
        The request, sent with the 1.3.6.1.4.1.7165.4.3.14 control, must
        raise LdbError.
        """
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData", "uSNChanged"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        for o in repl.ctr.array:
            # attid 13 is the description attribute
            if o.attid == 13:
                o.version = o.version + 1
                o.local_usn = int(str(res[0]["uSNChanged"])) + 1
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        msg["description"] = ldb.MessageElement("new val",
                                                ldb.FLAG_MOD_REPLACE,
                                                "description")
        self.assertRaises(ldb.LdbError, self.samdb.modify, msg,
                          ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_set_replpropertymetadata(self):
        """Accept a metadata write that updates version and both USNs.

        For the description entry (attid 13) the version is incremented
        and local_usn / originating_usn are set to uSNChanged + 1.  The
        modify, sent with the 1.3.6.1.4.1.7165.4.3.14 control, is
        expected to succeed.
        """
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData", "uSNChanged"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        for o in repl.ctr.array:
            # attid 13 is the description attribute
            if o.attid == 13:
                o.version = o.version + 1
                next_usn = int(str(res[0]["uSNChanged"])) + 1
                o.local_usn = next_usn
                o.originating_usn = next_usn
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        self.samdb.modify(msg, ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_ok_get_attribute_from_attid(self):
        self.assertEquals(self.samdb.get_attribute_from_attid(13),
                          "description")

    def test_ko_get_attribute_from_attid(self):
        self.assertEquals(self.samdb.get_attribute_from_attid(11979), None)

    def test_get_attribute_replmetadata_version(self):
        """The test user's unicodePwd metadata version must be 2."""
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["dn"])
        # assertEqual: the assertEquals alias was removed in Python 3.12
        self.assertEqual(len(res), 1)
        dn = str(res[0].dn)
        self.assertEqual(
            self.samdb.get_attribute_replmetadata_version(dn, "unicodePwd"), 2)

    def test_set_attribute_replmetadata_version(self):
        """Raise the description metadata version by two and read it back."""
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["dn"])
        # assertEqual: the assertEquals alias was removed in Python 3.12
        self.assertEqual(len(res), 1)
        dn = str(res[0].dn)
        version = self.samdb.get_attribute_replmetadata_version(
            dn, "description")
        self.samdb.set_attribute_replmetadata_version(dn, "description",
                                                      version + 2)
        self.assertEqual(
            self.samdb.get_attribute_replmetadata_version(dn, "description"),
            version + 2)

    def test_no_error_on_invalid_control(self):
        """A non-critical unimplemented control must not make a search fail.

        The control is sent with criticality 0 (the trailing ``:0``); any
        LdbError from the search fails the test.
        """
        try:
            # Only the absence of an exception matters, not the result
            self.samdb.search(
                scope=ldb.SCOPE_SUBTREE,
                base=self.account_dn,
                attrs=["replPropertyMetaData"],
                controls=[
                    "local_oid:%s:0" %
                    dsdb.DSDB_CONTROL_INVALID_NOT_IMPLEMENTED
                ])
        except ldb.LdbError:
            self.fail("Should have not raised an exception")

    def test_error_on_invalid_critical_control(self):
        """A critical unimplemented control may only fail in one way.

        The control is sent with criticality 1 (the trailing ``:1``).  If
        the search raises LdbError, its code must be
        ERR_UNSUPPORTED_CRITICAL_EXTENSION.
        """
        try:
            self.samdb.search(
                scope=ldb.SCOPE_SUBTREE,
                base=self.account_dn,
                attrs=["replPropertyMetaData"],
                controls=[
                    "local_oid:%s:1" %
                    dsdb.DSDB_CONTROL_INVALID_NOT_IMPLEMENTED
                ])
            # NOTE(review): a search that raises nothing currently passes;
            # confirm whether that is intended.
        except ldb.LdbError as e:
            (code, estr) = e.args
            if code != ldb.ERR_UNSUPPORTED_CRITICAL_EXTENSION:
                # Report the unpacked message text; exception objects
                # cannot be indexed on Python 3.
                self.fail(
                    "Got %s should have got ERR_UNSUPPORTED_CRITICAL_EXTENSION"
                    % estr)

    # Allocate a unique RID for use in the objectSID tests.
    #
    def allocate_rid(self):
        self.samdb.transaction_start()
        try:
            rid = self.samdb.allocate_rid()
        except:
            self.samdb.transaction_cancel()
            raise
        self.samdb.transaction_commit()
        return str(rid)

    # Ensure that duplicate objectSID's are permitted for foreign security
    # principals.
    #
    def test_duplicate_objectSIDs_allowed_on_foreign_security_principals(self):
        """Add, delete and re-add a foreignSecurityPrincipal object.

        Without a control the add is rejected, both with and without an
        explicit objectSid.  With the ``provision:0`` control the object
        can be created, deleted and created again under the same DN.
        """
        # Build a SID that is not in the current domain by swapping the
        # last character of the domain SID.
        domain_sid_str = str(self.samdb.get_domain_sid())
        replacement = "9" if domain_sid_str.endswith("0") else "0"
        foreign_sid_str = domain_sid_str[:-1] + replacement + "-1000"
        packed_sid = ndr_pack(security.dom_sid(foreign_sid_str))
        fsp_dn = "CN=%s,CN=ForeignSecurityPrincipals,%s" % (
            foreign_sid_str, self.samdb.get_default_basedn())

        # Without any control: no objectSid supplied
        try:
            self.samdb.add({
                "dn": fsp_dn,
                "objectClass": "foreignSecurityPrincipal",
            })
        except ldb.LdbError as err:
            (code, text) = err.args
            self.assertEqual(code, ldb.ERR_OBJECT_CLASS_VIOLATION, str(err))
            expected = "%08X" % werror.WERR_DS_MISSING_REQUIRED_ATT
            self.assertTrue(expected in text, text)
        else:
            self.fail("No exception should get ERR_OBJECT_CLASS_VIOLATION")

        # Without any control: explicit objectSid supplied
        try:
            self.samdb.add({
                "dn": fsp_dn,
                "objectClass": "foreignSecurityPrincipal",
                "objectSid": packed_sid,
            })
        except ldb.LdbError as err:
            (code, text) = err.args
            self.assertEqual(code, ldb.ERR_UNWILLING_TO_PERFORM, str(err))
            expected = "%08X" % werror.WERR_DS_ILLEGAL_MOD_OPERATION
            self.assertTrue(expected in text, text)
        else:
            self.fail("No exception should get ERR_UNWILLING_TO_PERFORM")

        # The provision control is needed to add foreignSecurityPrincipal
        # objects.
        provision = ["provision:0"]
        self.samdb.add(
            {
                "dn": fsp_dn,
                "objectClass": "foreignSecurityPrincipal",
            },
            controls=provision)
        self.samdb.delete(fsp_dn)

        # Adding the same DN again after the delete must work
        try:
            self.samdb.add(
                {
                    "dn": fsp_dn,
                    "objectClass": "foreignSecurityPrincipal",
                },
                controls=provision)
        except ldb.LdbError as err:
            (code, text) = err.args
            self.fail("Got unexpected exception %d - %s " % (code, text))

        # cleanup
        self.samdb.delete(fsp_dn)

    def _test_foreignSecurityPrincipal(self, obj_class, fpo_attr):
        """Check how ``fpo_attr`` handles references given as ``<SID=...>``.

        A new ``obj_class`` object is created and three SID references
        are added to its ``fpo_attr`` in turn:

        * domain SID + "-4294967294": must fail with
          ERR_UNWILLING_TO_PERFORM / WERR_DS_INVALID_GROUP_TYPE
        * "S-1-5-32-4294967294": must fail with
          ERR_NO_SUCH_OBJECT / WERR_NO_SUCH_MEMBER
        * "S-1-5-4294967294": must succeed, after which exactly one
          object with that objectSid exists; it is deleted again.

        :param obj_class: objectClass of the object to create
        :param fpo_attr: attribute that receives the SID references
        """
        local_sid = str(self.samdb.get_domain_sid()) + "-4294967294"
        builtin_sid = "S-1-5-32-4294967294"
        foreign_sid = "S-1-5-4294967294"
        basedn = self.samdb.get_default_basedn()
        holder_dn_str = "cn=%s,cn=Users,%s" % ("dsdb_test_fpo", basedn)
        holder_dn = ldb.Dn(self.samdb, holder_dn_str)

        def find_by_sid(sid_str):
            # Subtree search below the base DN for one objectSid value
            return self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                     base=basedn,
                                     expression="(objectSid=%s)" % sid_str,
                                     attrs=[])

        def reference_to(sid_str):
            # Message adding <SID=...> to fpo_attr on the test object
            change = ldb.Message()
            change.dn = holder_dn
            change[fpo_attr] = ldb.MessageElement("<SID=%s>" % sid_str,
                                                  ldb.FLAG_MOD_ADD, fpo_attr)
            return change

        # None of the three SIDs may be present before the test starts
        for sid_str in (local_sid, builtin_sid, foreign_sid):
            self.assertEqual(len(find_by_sid(sid_str)), 0)

        self.addCleanup(delete_force, self.samdb, holder_dn_str)

        self.samdb.add({"dn": holder_dn_str, "objectClass": obj_class})

        request = reference_to(local_sid)
        try:
            self.samdb.modify(request)
        except ldb.LdbError as err:
            (code, text) = err.args
            self.assertEqual(code, ldb.ERR_UNWILLING_TO_PERFORM, str(err))
            expected = "%08X" % werror.WERR_DS_INVALID_GROUP_TYPE
            self.assertTrue(expected in text, text)
        else:
            self.fail("No exception should get LDB_ERR_UNWILLING_TO_PERFORM")

        request = reference_to(builtin_sid)
        try:
            self.samdb.modify(request)
        except ldb.LdbError as err:
            (code, text) = err.args
            self.assertEqual(code, ldb.ERR_NO_SUCH_OBJECT, str(err))
            expected = "%08X" % werror.WERR_NO_SUCH_MEMBER
            self.assertTrue(expected in text, text)
        else:
            self.fail("No exception should get LDB_ERR_NO_SUCH_OBJECT")

        request = reference_to(foreign_sid)
        try:
            self.samdb.modify(request)
        except ldb.LdbError:
            self.fail("Should have not raised an exception")

        # The successful add must have left exactly one matching object
        matches = find_by_sid(foreign_sid)
        self.assertEqual(len(matches), 1)
        self.samdb.delete(matches[0].dn)
        self.samdb.delete(holder_dn)
        self.assertEqual(len(find_by_sid(foreign_sid)), 0)

    def test_foreignSecurityPrincipal_member(self):
        return self._test_foreignSecurityPrincipal("group", "member")

    def test_foreignSecurityPrincipal_MembersForAzRole(self):
        return self._test_foreignSecurityPrincipal("msDS-AzRole",
                                                   "msDS-MembersForAzRole")

    def test_foreignSecurityPrincipal_NeverRevealGroup(self):
        return self._test_foreignSecurityPrincipal("computer",
                                                   "msDS-NeverRevealGroup")

    def test_foreignSecurityPrincipal_RevealOnDemandGroup(self):
        return self._test_foreignSecurityPrincipal("computer",
                                                   "msDS-RevealOnDemandGroup")

    def _test_fail_foreignSecurityPrincipal(self,
                                            obj_class,
                                            fpo_attr,
                                            msg_exp,
                                            lerr_exp,
                                            werr_exp,
                                            allow_reference=True):
        """Check that ``fpo_attr`` rejects every ``<SID=...>`` reference.

        Two ``obj_class`` objects are created.  Adding each of three SID
        references (domain SID + "-4294967294", "S-1-5-32-4294967294",
        "S-1-5-4294967294") to ``fpo_attr`` of the first object must
        fail with ``lerr_exp`` and an error string containing
        ``werr_exp``.  Finally the DN of the second object is added:
        this must succeed when ``allow_reference`` is true and fail with
        the same error otherwise.

        :param obj_class: objectClass of the two objects to create
        :param fpo_attr: attribute that receives the references
        :param msg_exp: name of the expected error, used in failure text
        :param lerr_exp: expected LDB error code
        :param werr_exp: expected WERROR value, matched as "%08X" in the
            error string
        :param allow_reference: whether a plain DN reference is expected
            to be accepted
        """
        dom_sid = self.samdb.get_domain_sid()
        lsid_str = str(dom_sid) + "-4294967294"
        bsid_str = "S-1-5-32-4294967294"
        fsid_str = "S-1-5-4294967294"
        all_sids = (lsid_str, bsid_str, fsid_str)
        basedn = self.samdb.get_default_basedn()
        cn1 = "dsdb_test_fpo1"
        dn1_str = "cn=%s,cn=Users,%s" % (cn1, basedn)
        dn1 = ldb.Dn(self.samdb, dn1_str)
        cn2 = "dsdb_test_fpo2"
        dn2_str = "cn=%s,cn=Users,%s" % (cn2, basedn)
        dn2 = ldb.Dn(self.samdb, dn2_str)

        # None of the three SIDs may be present before the test starts
        for sid_str in all_sids:
            res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                    base=basedn,
                                    expression="(objectSid=%s)" % sid_str,
                                    attrs=[])
            self.assertEqual(len(res), 0)

        self.addCleanup(delete_force, self.samdb, dn1_str)
        self.addCleanup(delete_force, self.samdb, dn2_str)

        self.samdb.add({"dn": dn1_str, "objectClass": obj_class})

        self.samdb.add({"dn": dn2_str, "objectClass": obj_class})

        def check_expected_error(err):
            # The failure must carry the expected LDB code and WERROR
            (code, estr) = err.args
            self.assertEqual(code, lerr_exp, str(err))
            werr = "%08X" % werr_exp
            self.assertTrue(werr in estr, estr)

        for sid_str in all_sids:
            msg = ldb.Message()
            msg.dn = dn1
            msg[fpo_attr] = ldb.MessageElement("<SID=%s>" % sid_str,
                                               ldb.FLAG_MOD_ADD, fpo_attr)
            try:
                self.samdb.modify(msg)
            except ldb.LdbError as e:
                check_expected_error(e)
            else:
                # Always name the expected error (msg_exp), never the
                # request message object.
                self.fail("No exception should get %s" % msg_exp)

        msg = ldb.Message()
        msg.dn = dn1
        msg[fpo_attr] = ldb.MessageElement("%s" % dn2, ldb.FLAG_MOD_ADD,
                                           fpo_attr)
        try:
            self.samdb.modify(msg)
        except ldb.LdbError as e:
            if allow_reference:
                self.fail("Should have not raised an exception: %s" % e)
            check_expected_error(e)
        else:
            if not allow_reference:
                self.fail("No exception should get %s" % msg_exp)

        self.samdb.delete(dn2)
        self.samdb.delete(dn1)

    def test_foreignSecurityPrincipal_NonMembers(self):
        """msDS-NonMembers on a group must reject SID and DN references.

        Every attempt is expected to fail with ERR_UNWILLING_TO_PERFORM
        and WERR_NOT_SUPPORTED.
        """
        expected_name = "LDB_ERR_UNWILLING_TO_PERFORM/WERR_NOT_SUPPORTED"
        return self._test_fail_foreignSecurityPrincipal(
            obj_class="group",
            fpo_attr="msDS-NonMembers",
            msg_exp=expected_name,
            lerr_exp=ldb.ERR_UNWILLING_TO_PERFORM,
            werr_exp=werror.WERR_NOT_SUPPORTED,
            allow_reference=False)

    def test_foreignSecurityPrincipal_HostServiceAccount(self):
        """msDS-HostServiceAccount on a computer must reject SID references.

        The SID attempts are expected to fail with
        ERR_CONSTRAINT_VIOLATION and WERR_DS_NAME_REFERENCE_INVALID; a
        plain DN reference is expected to be accepted.
        """
        expected_name = (
            "LDB_ERR_CONSTRAINT_VIOLATION/WERR_DS_NAME_REFERENCE_INVALID")
        return self._test_fail_foreignSecurityPrincipal(
            obj_class="computer",
            fpo_attr="msDS-HostServiceAccount",
            msg_exp=expected_name,
            lerr_exp=ldb.ERR_CONSTRAINT_VIOLATION,
            werr_exp=werror.WERR_DS_NAME_REFERENCE_INVALID)

    def test_foreignSecurityPrincipal_manager(self):
        """``manager`` on a user must reject SID references.

        The SID attempts are expected to fail with
        ERR_CONSTRAINT_VIOLATION and WERR_DS_NAME_REFERENCE_INVALID; a
        plain DN reference is expected to be accepted.
        """
        expected_name = (
            "LDB_ERR_CONSTRAINT_VIOLATION/WERR_DS_NAME_REFERENCE_INVALID")
        return self._test_fail_foreignSecurityPrincipal(
            obj_class="user",
            fpo_attr="manager",
            msg_exp=expected_name,
            lerr_exp=ldb.ERR_CONSTRAINT_VIOLATION,
            werr_exp=werror.WERR_DS_NAME_REFERENCE_INVALID)

    #
    # Duplicate objectSID's should not be permitted for sids in the local
    # domain. The test sequence is add an object, delete it, then attempt to
    # re-add it, this should fail with a constraint violation
    #
    def test_duplicate_objectSIDs_not_allowed_on_local_objects(self):
        """Re-adding a deleted user with the same local SID must fail.

        A user is added with an explicit objectSID built from a freshly
        allocated RID, then deleted.  A second add with the same DN and
        SID must raise LdbError with ERR_CONSTRAINT_VIOLATION.
        """
        domain_sid = self.samdb.get_domain_sid()
        new_rid = self.allocate_rid()
        packed_sid = ndr_pack(
            security.dom_sid(str(domain_sid) + "-" + new_rid))
        user_dn = "cn=%s,cn=Users,%s" % ("dsdb_test_01",
                                         self.samdb.get_default_basedn())

        def add_user():
            self.samdb.add({
                "dn": user_dn,
                "objectClass": "user",
                "objectSID": packed_sid,
            })

        add_user()
        self.samdb.delete(user_dn)

        try:
            add_user()
        except ldb.LdbError as err:
            (code, text) = err.args
            if code != ldb.ERR_CONSTRAINT_VIOLATION:
                self.fail("Got %d - %s should have got "
                          "LDB_ERR_CONSTRAINT_VIOLATION" % (code, text))
        else:
            self.fail("No exception should get LDB_ERR_CONSTRAINT_VIOLATION")

    def test_linked_vs_non_linked_reference(self):
        """Compare ``manager`` and ``assistant`` references to a deleted user.

        Two users are created and one of them is deleted again.  On the
        remaining user:

        * adding the deleted user to ``manager`` by SID or GUID must
          fail with ERR_CONSTRAINT_VIOLATION and
          WERR_DS_NAME_REFERENCE_INVALID;
        * adding and then removing the deleted user on ``assistant`` by
          SID and by GUID must succeed;
        * adding a DN, SID or GUID that never existed to ``assistant``
          must fail with the same error as above.
        """
        basedn = self.samdb.get_default_basedn()
        kept_dn_str = "cn=reference_kept,cn=Users,%s" % (basedn)
        removed_dn_str = "cn=reference_removed,cn=Users,%s" % (basedn)
        dom_sid = self.samdb.get_domain_sid()
        none_sid_str = str(dom_sid) + "-4294967294"
        none_guid_str = "afafafaf-fafa-afaf-fafa-afafafafafaf"

        self.addCleanup(delete_force, self.samdb, kept_dn_str)
        self.addCleanup(delete_force, self.samdb, removed_dn_str)

        self.samdb.add({"dn": kept_dn_str, "objectClass": "user"})
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=kept_dn_str,
                                attrs=["objectGUID", "objectSID"])
        self.assertEqual(len(res), 1)
        kept_dn = res[0].dn

        self.samdb.add({"dn": removed_dn_str, "objectClass": "user"})
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=removed_dn_str,
                                attrs=["objectGUID", "objectSID"])
        self.assertEqual(len(res), 1)
        removed_guid = ndr_unpack(misc.GUID, res[0]["objectGUID"][0])
        removed_sid = ndr_unpack(security.dom_sid, res[0]["objectSid"][0])
        self.samdb.delete(removed_dn_str)

        def change_on_kept(attr, value, flag):
            # Single-attribute modification targeting the kept user
            m = ldb.Message()
            m.dn = kept_dn
            m[attr] = ldb.MessageElement(value, flag, attr)
            return m

        def expect_invalid_reference(attr, value):
            # The add must be refused as an invalid name reference
            request = change_on_kept(attr, value, ldb.FLAG_MOD_ADD)
            try:
                self.samdb.modify(request)
            except ldb.LdbError as e:
                (code, estr) = e.args
                self.assertEqual(code, ldb.ERR_CONSTRAINT_VIOLATION, str(e))
                werr = "%08X" % werror.WERR_DS_NAME_REFERENCE_INVALID
                self.assertTrue(werr in estr, estr)
            else:
                self.fail(
                    "No exception should get LDB_ERR_CONSTRAINT_VIOLATION")

        #
        # First try the linked attribute 'manager'
        # by GUID and SID
        #
        expect_invalid_reference("manager", "<SID=%s>" % removed_sid)
        expect_invalid_reference("manager", "<GUID=%s>" % removed_guid)

        #
        # Try the non-linked attribute 'assistant'
        # by GUID and SID, which should work.
        #
        for value in ("<SID=%s>" % removed_sid, "<GUID=%s>" % removed_guid):
            self.samdb.modify(
                change_on_kept("assistant", value, ldb.FLAG_MOD_ADD))
            self.samdb.modify(
                change_on_kept("assistant", value, ldb.FLAG_MOD_DELETE))

        #
        # Finally try the non-linked attribute 'assistant'
        # but with non existing GUID, SID, DN
        #
        expect_invalid_reference("assistant", "CN=NoneNone,%s" % (basedn))
        expect_invalid_reference("assistant", "<SID=%s>" % none_sid_str)
        expect_invalid_reference("assistant", "<GUID=%s>" % none_guid_str)

        self.samdb.delete(kept_dn)

    def test_normalize_dn_in_domain_full(self):
        """A full DN passed as a string must come back unchanged."""
        domain = self.samdb.domain_dn()

        expected = ldb.Dn(self.samdb, "CN=Users")
        expected.add_base(domain)

        # The string already contains the domain DN, so nothing is added
        normalized = self.samdb.normalize_dn_in_domain(str(expected))
        self.assertEqual(expected, normalized)

    def test_normalize_dn_in_domain_part(self):
        """A partial DN string must come back with the domain DN appended."""
        domain = self.samdb.domain_dn()
        partial = "CN=Users"

        expected = ldb.Dn(self.samdb, partial)
        expected.add_base(domain)

        normalized = self.samdb.normalize_dn_in_domain(partial)
        self.assertEqual(expected, normalized)

    def test_normalize_dn_in_domain_full_dn(self):
        """A full DN passed as an ldb.Dn must come back unchanged."""
        domain = self.samdb.domain_dn()

        expected = ldb.Dn(self.samdb, "CN=Users")
        expected.add_base(domain)

        # The Dn object already ends in the domain DN
        normalized = self.samdb.normalize_dn_in_domain(expected)
        self.assertEqual(expected, normalized)

    def test_normalize_dn_in_domain_part_dn(self):
        """A partial ldb.Dn must come back with the domain DN appended."""
        domain = self.samdb.domain_dn()
        partial = ldb.Dn(self.samdb, "CN=Users")

        # Build the expected value from the string forms of both parts
        expected = ldb.Dn(self.samdb, str(partial) + "," + str(domain))
        normalized = self.samdb.normalize_dn_in_domain(partial)
        self.assertEqual(expected, normalized)
Exemple #34
0
class MatchRulesTests(samba.tests.TestCase):
    """Tests for the transitive extended match rule 1.2.840.113556.1.4.1941.

    setUp() builds a small tree below ``ou=matchrulestest,<domain DN>``
    containing nested OUs, nested groups, users, computers and two
    msExchConfigurationContainer objects.  Each test then compares a plain
    equality search with the same search using the transitive match rule.
    """

    @staticmethod
    def _binary_dn(index, dn):
        """Return the DN-Binary value "B:32:<index as 32 digits>:<dn>"."""
        return "B:32:%032d:%s" % (index, dn)

    def _add_objects(self, objectclass, dns):
        """Add one object of class `objectclass` for every DN in `dns`."""
        for dn in dns:
            self.ldb.add({
                "dn": dn,
                "objectclass": objectclass})

    def _modify(self, dn, attr, value, flags=None):
        """Apply a single-attribute modification to the object `dn`.

        `flags` defaults to FLAG_MOD_ADD when not given.
        """
        if flags is None:
            flags = FLAG_MOD_ADD
        m = Message()
        m.dn = Dn(self.ldb, dn)
        m[attr] = MessageElement(value, flags, attr)
        self.ldb.modify(m)

    def _assert_search_count(self, base, scope, expression, expected):
        """Search below `base` and assert that `expected` entries match."""
        res = self.ldb.search(base, scope=scope, expression=expression)
        self.assertEqual(len(res), expected,
                         "%r returned %d entries, expected %d" %
                         (expression, len(res), expected))

    def setUp(self):
        """Connect to the directory and create all fixture objects."""
        super(MatchRulesTests, self).setUp()
        self.lp = lp
        self.ldb = SamDB(host, credentials=creds,
                         session_info=system_session(lp), lp=lp)
        self.base_dn = self.ldb.domain_dn()
        self.ou = "ou=matchrulestest,%s" % self.base_dn
        self.ou_users = "ou=users,%s" % self.ou
        self.ou_groups = "ou=groups,%s" % self.ou
        self.ou_computers = "ou=computers,%s" % self.ou

        # Add a organizational unit to create objects
        self._add_objects("organizationalUnit", [self.ou])

        # Add the following OU hierarchy and set otherWellKnownObjects,
        # which has BinaryDN syntax:
        #
        # o1
        # |--> o2
        # |    |--> o3
        # |    |    |-->o4
        o1 = "OU=o1,%s" % self.ou
        o2 = "OU=o2,%s" % o1
        o3 = "OU=o3,%s" % o2
        o4 = "OU=o4,%s" % o3
        self._add_objects("organizationalUnit", [o1, o2, o3, o4])

        # Every level points at its child through otherWellKnownObjects.
        self._modify(self.ou, "otherWellKnownObjects", self._binary_dn(1, o1))
        self._modify(o1, "otherWellKnownObjects", self._binary_dn(2, o2))
        self._modify(o2, "otherWellKnownObjects", self._binary_dn(3, o3))
        self._modify(o3, "otherWellKnownObjects", self._binary_dn(4, o4))

        # Create OU for users, groups and computers
        self._add_objects("organizationalUnit",
                          [self.ou_users, self.ou_groups, self.ou_computers])

        # Add four groups
        self._add_objects(
            "group",
            ["cn=g%d,%s" % (i, self.ou_groups) for i in range(1, 5)])

        # Add four users
        self._add_objects(
            "user",
            ["cn=u%d,%s" % (i, self.ou_users) for i in range(1, 5)])

        # Add computers to test Object(DN-Binary) syntax
        realm = self.lp.get("realm").lower()
        for name in ("c1", "c2", "c3"):
            self.ldb.add({
                "dn": "cn=%s,%s" % (name, self.ou_computers),
                "objectclass": "computer",
                "dNSHostName": "%s.%s" % (name, realm),
                "servicePrincipalName": ["HOST/%s" % name],
                "sAMAccountName": "%s$" % name,
                "userAccountControl": "83890178"})

        # Create the following hierarchy:
        # g4
        # |--> u4
        # |--> g3
        # |    |--> u3
        # |    |--> g2
        # |    |    |--> u2
        # |    |    |--> g1
        # |    |    |    |--> u1
        memberships = [
            # (group, member DN); the order of the modifications is kept
            ("g1", "cn=u1,%s" % self.ou_users),
            ("g2", "cn=u2,%s" % self.ou_users),
            ("g3", "cn=u3,%s" % self.ou_users),
            ("g4", "cn=u4,%s" % self.ou_users),
            ("g4", "cn=g3,%s" % self.ou_groups),
            ("g3", "cn=g2,%s" % self.ou_groups),
            ("g2", "cn=g1,%s" % self.ou_groups),
        ]
        for group, member_dn in memberships:
            self._modify("cn=%s,%s" % (group, self.ou_groups),
                         "member", member_dn)

        # The msDS-RevealedUsers is owned by system and cannot be modified
        # directly. Set the schemaUpgradeInProgress flag as workaround
        # and create this hierarchy:
        # ou=computers
        # |-> c1
        # |   |->c2
        # |   |  |->c3

        #
        # While appropriate for this test, this is NOT a good practice
        # in general.  This is only done here because the alternative
        # is to make a schema modification.
        #
        # IF/WHEN Samba protects this attribute better, this
        # particular part of the test can be removed, as the same code
        # is covered by the addressBookRoots2 case well enough.
        #
        self._modify("", "schemaUpgradeInProgress", "1", FLAG_MOD_REPLACE)
        try:
            self._modify("cn=c2,%s" % self.ou_computers,
                         "msDS-RevealedUsers",
                         "B:8:01010101:cn=c3,%s" % self.ou_computers)
            self._modify("cn=c1,%s" % self.ou_computers,
                         "msDS-RevealedUsers",
                         "B:8:01010101:cn=c2,%s" % self.ou_computers)
        finally:
            # Always clear the flag again, even when a modify above failed.
            self._modify("", "schemaUpgradeInProgress", "0",
                         FLAG_MOD_REPLACE)

        # Add a couple of ms-Exch-Configuration-Container to test forward-link
        # attributes without backward link (addressBookRoots2)
        # e1
        # |--> e2
        # |    |--> c1
        self._add_objects("msExchConfigurationContainer",
                          ["cn=e1,%s" % self.ou, "cn=e2,%s" % self.ou])
        self._modify("cn=e2,%s" % self.ou, "addressBookRoots2",
                     "cn=c1,%s" % self.ou_computers)
        self._modify("cn=e1,%s" % self.ou, "addressBookRoots2",
                     "cn=e2,%s" % self.ou)

    def tearDown(self):
        """Delete every fixture object, leaves first and containers last."""
        super(MatchRulesTests, self).tearDown()
        fixture_dns = [
            "cn=u4,%s" % self.ou_users,
            "cn=u3,%s" % self.ou_users,
            "cn=u2,%s" % self.ou_users,
            "cn=u1,%s" % self.ou_users,
            "cn=g4,%s" % self.ou_groups,
            "cn=g3,%s" % self.ou_groups,
            "cn=g2,%s" % self.ou_groups,
            "cn=g1,%s" % self.ou_groups,
            "cn=c1,%s" % self.ou_computers,
            "cn=c2,%s" % self.ou_computers,
            "cn=c3,%s" % self.ou_computers,
            self.ou_users,
            self.ou_groups,
            self.ou_computers,
            "OU=o4,OU=o3,OU=o2,OU=o1,%s" % self.ou,
            "OU=o3,OU=o2,OU=o1,%s" % self.ou,
            "OU=o2,OU=o1,%s" % self.ou,
            "OU=o1,%s" % self.ou,
            "CN=e2,%s" % self.ou,
            "CN=e1,%s" % self.ou,
            self.ou,
        ]
        for dn in fixture_dns:
            delete_force(self.ldb, dn)

    def test_u1_member_of_g4(self):
        """u1 is in g4 only transitively, in both link directions."""
        g4 = "cn=g4,%s" % self.ou_groups
        u1 = "cn=u1,%s" % self.ou_users

        # Search without transitive match must return 0 results
        self._assert_search_count(g4, SCOPE_BASE, "member=%s" % u1, 0)
        self._assert_search_count(u1, SCOPE_BASE, "memberOf=%s" % g4, 0)

        # Search with transitive match must return 1 results
        self._assert_search_count(
            g4, SCOPE_BASE,
            "member:1.2.840.113556.1.4.1941:=%s" % u1, 1)
        self._assert_search_count(
            u1, SCOPE_BASE,
            "memberOf:1.2.840.113556.1.4.1941:=%s" % g4, 1)

    def test_g1_member_of_g4(self):
        """g1 is in g4 only transitively, in both link directions."""
        g4 = "cn=g4,%s" % self.ou_groups
        g1 = "cn=g1,%s" % self.ou_groups

        # Search without transitive match must return 0 results
        self._assert_search_count(g4, SCOPE_BASE, "member=%s" % g1, 0)
        self._assert_search_count(g1, SCOPE_BASE, "memberOf=%s" % g4, 0)

        # Search with transitive match must return 1 results
        self._assert_search_count(
            g4, SCOPE_BASE,
            "member:1.2.840.113556.1.4.1941:=%s" % g1, 1)
        self._assert_search_count(
            g1, SCOPE_BASE,
            "memberOf:1.2.840.113556.1.4.1941:=%s" % g4, 1)

    def test_u1_groups(self):
        """u1 has one direct group and four transitive ones."""
        u1 = "cn=u1,%s" % self.ou_users
        self._assert_search_count(self.ou_groups, SCOPE_SUBTREE,
                                  "member=%s" % u1, 1)
        self._assert_search_count(
            self.ou_groups, SCOPE_SUBTREE,
            "member:1.2.840.113556.1.4.1941:=%s" % u1, 4)

    def test_u2_groups(self):
        """u2 has one direct group and three transitive ones."""
        u2 = "cn=u2,%s" % self.ou_users
        self._assert_search_count(self.ou_groups, SCOPE_SUBTREE,
                                  "member=%s" % u2, 1)
        self._assert_search_count(
            self.ou_groups, SCOPE_SUBTREE,
            "member:1.2.840.113556.1.4.1941:=%s" % u2, 3)

    def test_u3_groups(self):
        """u3 has one direct group and two transitive ones."""
        u3 = "cn=u3,%s" % self.ou_users
        self._assert_search_count(self.ou_groups, SCOPE_SUBTREE,
                                  "member=%s" % u3, 1)
        self._assert_search_count(
            self.ou_groups, SCOPE_SUBTREE,
            "member:1.2.840.113556.1.4.1941:=%s" % u3, 2)

    def test_u4_groups(self):
        """u4 has one direct group, which is also its only transitive one."""
        u4 = "cn=u4,%s" % self.ou_users
        self._assert_search_count(self.ou_groups, SCOPE_SUBTREE,
                                  "member=%s" % u4, 1)
        self._assert_search_count(
            self.ou_groups, SCOPE_SUBTREE,
            "member:1.2.840.113556.1.4.1941:=%s" % u4, 1)

    def test_extended_dn(self):
        """<SID=...> and <GUID=...> forms of u1 match like its plain DN."""
        res1 = self.ldb.search("cn=u1,%s" % self.ou_users,
                               scope=SCOPE_BASE,
                               expression="objectClass=*",
                               attrs=['objectSid', 'objectGUID'])
        self.assertEqual(len(res1), 1)

        # NOTE(review): if schema_format_value returns bytes on Python 3,
        # these values need decoding before going into a filter - confirm.
        sid = self.ldb.schema_format_value("objectSid",
                                           res1[0]["objectSid"][0])
        guid = self.ldb.schema_format_value("objectGUID",
                                            res1[0]['objectGUID'][0])

        self._assert_search_count(self.ou_groups, SCOPE_SUBTREE,
                                  "member=<SID=%s>" % sid, 1)
        self._assert_search_count(self.ou_groups, SCOPE_SUBTREE,
                                  "member=<GUID=%s>" % guid, 1)
        self._assert_search_count(
            self.ou_groups, SCOPE_SUBTREE,
            "member:1.2.840.113556.1.4.1941:=<SID=%s>" % sid, 4)
        self._assert_search_count(
            self.ou_groups, SCOPE_SUBTREE,
            "member:1.2.840.113556.1.4.1941:=<GUID=%s>" % guid, 4)

    def test_object_dn_binary(self):
        """msDS-RevealedUsers (DN-Binary) links are followed transitively."""
        value = "B:8:01010101:cn=c3,%s" % self.ou_computers
        self._assert_search_count(self.ou_computers, SCOPE_SUBTREE,
                                  "msDS-RevealedUsers=%s" % value, 1)
        self._assert_search_count(
            self.ou_computers, SCOPE_SUBTREE,
            "msDS-RevealedUsers:1.2.840.113556.1.4.1941:=%s" % value, 2)

    def test_one_way_links(self):
        """addressBookRoots2 (no backward link) is followed transitively."""
        c1 = "cn=c1,%s" % self.ou_computers
        self._assert_search_count(self.ou, SCOPE_SUBTREE,
                                  "addressBookRoots2=%s" % c1, 1)
        self._assert_search_count(
            self.ou, SCOPE_SUBTREE,
            "addressBookRoots2:1.2.840.113556.1.4.1941:=%s" % c1, 2)

    def test_not_linked_attrs(self):
        """Attributes that are not links never match transitively."""
        well_known = ("B:32:aa312825768811d1aded00c04fd8d5cd:CN=computers,%s"
                      % self.base_dn)
        self._assert_search_count(self.base_dn, SCOPE_BASE,
                                  "wellKnownObjects=%s" % well_known, 1)
        self._assert_search_count(
            self.base_dn, SCOPE_BASE,
            "wellKnownObjects:1.2.840.113556.1.4.1941:=%s" % well_known, 0)

        other_well_known = self._binary_dn(
            4, "OU=o4,OU=o3,OU=o2,OU=o1,%s" % self.ou)
        self._assert_search_count(
            self.ou, SCOPE_SUBTREE,
            "otherWellKnownObjects=%s" % other_well_known, 1)
        self._assert_search_count(
            self.ou, SCOPE_SUBTREE,
            "otherWellKnownObjects:1.2.840.113556.1.4.1941:=%s"
            % other_well_known, 0)
class AuthLogPassChangeTests(samba.tests.auth_log_base.AuthLogTestBase):
    """Check the auth log messages produced by password change attempts.

    Every test triggers a password change (through ``Net``, through
    ``bin/net rap`` or through an LDAP modify), waits until the expected
    final log message arrives and asserts how many messages were received.
    """

    @staticmethod
    def _auth_matcher(status, service, auth_description, event_id):
        """Return a predicate matching one specific Authentication message.

        The predicate is true for a message of type "Authentication" whose
        status, serviceDescription, authDescription and eventId equal the
        given values and whose logonType is EVT_LOGON_NETWORK.
        """
        def matches(msg):
            if msg["type"] != "Authentication":
                return False
            auth = msg["Authentication"]
            return (auth["status"] == status
                    and auth["serviceDescription"] == service
                    and auth["authDescription"] == auth_description
                    and auth["eventId"] == event_id
                    and auth["logonType"] == EVT_LOGON_NETWORK)
        return matches

    def _net(self):
        """Return a Net object for the test server using fresh credentials."""
        creds = self.insta_creds(template=self.get_credentials())
        lp = self.get_loadparm()
        return Net(creds, lp, server=self.server_ip)

    def _assert_change_password_fails(self, **change_args):
        """Call Net.change_password(**change_args) and require an exception."""
        net = self._net()
        exception_thrown = False
        try:
            net.change_password(**change_args)
        except Exception:
            exception_thrown = True
        self.assertTrue(exception_thrown, "Expected exception not thrown")

    def _password_change_ldif(self, cn, old_password, new_password):
        """Return LDIF replacing userPassword of cn=<cn>,cn=users,<base DN>.

        The LDIF deletes `old_password` and adds `new_password`.
        """
        return ("dn: cn=%s,cn=users,%s\n"
                "changetype: modify\n"
                "delete: userPassword\n"
                "userPassword: %s\n"
                "add: userPassword\n"
                "userPassword: %s\n"
                % (cn, self.base_dn, old_password, new_password))

    def _assert_message_count(self, is_last_expected, expected):
        """Wait for the final message and assert the number received."""
        messages = self.waitForMessages(is_last_expected)
        print("Received %d messages" % len(messages))
        self.assertEqual(expected, len(messages),
                         "Did not receive the expected number of messages")

    def setUp(self):
        """Connect over LDAP and (re)create the test user USER_NAME."""
        super(AuthLogPassChangeTests, self).setUp()

        self.remoteAddress = os.environ["CLIENT_IP"]
        self.server_ip = os.environ["SERVER_IP"]

        host = "ldap://%s" % os.environ["SERVER"]
        self.ldb = SamDB(url=host,
                         session_info=system_session(),
                         credentials=self.get_credentials(),
                         lp=self.get_loadparm())

        print("ldb %s" % type(self.ldb))
        # Gets back the basedn
        base_dn = self.ldb.domain_dn()
        print("base_dn %s" % base_dn)

        # permit password changes during this test
        PasswordCommon.allow_password_changes(self, self.ldb)

        self.base_dn = base_dn

        # (Re)adds the test user USER_NAME with password USER_PASS
        user_dn = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        delete_force(self.ldb, user_dn)
        self.ldb.add({
            "dn": user_dn,
            "objectclass": "user",
            "sAMAccountName": USER_NAME,
            "userPassword": USER_PASS
        })

        # discard any auth log messages for the password setup
        self.discardMessages()

    def tearDown(self):
        """Run the base class clean-up only."""
        super(AuthLogPassChangeTests, self).tearDown()

    def test_admin_change_password(self):
        """A successful change through Net is expected to log 8 messages."""
        is_last_expected = self._auth_matcher(
            "NT_STATUS_OK", "SAMR Password Change",
            "samr_ChangePasswordUser3", EVT_ID_SUCCESSFUL_LOGON)

        # NOTE(review): "******" looks like a masked placeholder rather than
        # a real password; confirm it satisfies the domain password policy,
        # otherwise this change cannot succeed.
        password = "******"

        self._net().change_password(newpassword=password,
                                    username=USER_NAME,
                                    oldpassword=USER_PASS)

        self._assert_message_count(is_last_expected, 8)

    def test_admin_change_password_new_password_fails_restriction(self):
        """A new password rejected by policy is expected to log 8 messages."""
        is_last_expected = self._auth_matcher(
            "NT_STATUS_PASSWORD_RESTRICTION", "SAMR Password Change",
            "samr_ChangePasswordUser3", EVT_ID_UNSUCCESSFUL_LOGON)

        password = "******"

        self._assert_change_password_fails(newpassword=password,
                                           oldpassword=USER_PASS,
                                           username=USER_NAME)

        self._assert_message_count(is_last_expected, 8)

    def test_admin_change_password_unknown_user(self):
        """A change for an unknown user is expected to log 8 messages."""
        is_last_expected = self._auth_matcher(
            "NT_STATUS_NO_SUCH_USER", "SAMR Password Change",
            "samr_ChangePasswordUser3", EVT_ID_UNSUCCESSFUL_LOGON)

        password = "******"

        self._assert_change_password_fails(newpassword=password,
                                           oldpassword=USER_PASS,
                                           username="******")

        self._assert_message_count(is_last_expected, 8)

    def test_admin_change_password_bad_original_password(self):
        """A wrong old password is expected to log 8 messages."""
        is_last_expected = self._auth_matcher(
            "NT_STATUS_WRONG_PASSWORD", "SAMR Password Change",
            "samr_ChangePasswordUser3", EVT_ID_UNSUCCESSFUL_LOGON)

        password = "******"

        self._assert_change_password_fails(newpassword=password,
                                           oldpassword="******",
                                           username=USER_NAME)

        self._assert_message_count(is_last_expected, 8)

    # net rap password changes are broken, but they trigger enough of the
    # server side behaviour to exercise the code paths of interest.
    # if we used the real password it would be too long and does not hash
    # correctly, so we just check it triggers the wrong password path.
    def test_rap_change_password(self):
        """A ``net rap password`` attempt is expected to log 7 messages."""
        is_last_expected = self._auth_matcher(
            "NT_STATUS_WRONG_PASSWORD", "SAMR Password Change",
            "OemChangePasswordUser2", EVT_ID_UNSUCCESSFUL_LOGON)

        username = os.environ["USERNAME"]
        server = os.environ["SERVER"]
        password = os.environ["PASSWORD"]
        server_param = "--server=%s" % server
        creds = "-U%s%%%s" % (username, password)
        # The exit status is deliberately ignored (see the comment above).
        call([
            "bin/net", "rap", server_param, "password", USER_NAME,
            "notMyPassword", "notGoingToBeMyPassword", server, creds,
            "--option=client ipc max protocol=nt1"
        ])

        self._assert_message_count(is_last_expected, 7)

    def test_ldap_change_password(self):
        """A successful LDAP password change is expected to log 4 messages."""
        is_last_expected = self._auth_matcher(
            "NT_STATUS_OK", "LDAP Password Change",
            "LDAP Modify", EVT_ID_SUCCESSFUL_LOGON)

        new_password = samba.generate_random_password(32, 32)
        self.ldb.modify_ldif(
            self._password_change_ldif(USER_NAME, USER_PASS, new_password))

        self._assert_message_count(is_last_expected, 4)

    #
    # Currently this does not get logged, so we expect to only see the log
    # entries for the underlying ldap bind.
    #
    def test_ldap_change_password_bad_user(self):
        """An LDAP change for an unknown user is expected to log 3 messages."""
        def is_last_expected(msg):
            return (msg["type"] == "Authorization"
                    and msg["Authorization"]["serviceDescription"] == "LDAP"
                    and msg["Authorization"]["authType"] == "krb5")

        new_password = samba.generate_random_password(32, 32)
        with self.assertRaises(LdbError):
            self.ldb.modify_ldif(
                self._password_change_ldif("badUser", USER_PASS,
                                           new_password))

        self._assert_message_count(is_last_expected, 3)

    def test_ldap_change_password_bad_original_password(self):
        """An LDAP change with a wrong old password should log 4 messages."""
        is_last_expected = self._auth_matcher(
            "NT_STATUS_WRONG_PASSWORD", "LDAP Password Change",
            "LDAP Modify", EVT_ID_UNSUCCESSFUL_LOGON)

        new_password = samba.generate_random_password(32, 32)
        with self.assertRaises(LdbError):
            self.ldb.modify_ldif(
                self._password_change_ldif(USER_NAME, "badPassword",
                                           new_password))

        self._assert_message_count(is_last_expected, 4)
Exemple #36
0
class PasswordTests(samba.tests.TestCase):

    def setUp(self):
        """Connect to the server and provision a fresh, enabled 'testuser'.

        Temporarily sets "dSHeuristics" to "000000001" and "minPwdAge" to
        "0" (both restored through cleanups), (re)creates
        cn=testuser,cn=users,<base dn>, gives it the initial password
        "thatsAcomplPASS1", enables the account and opens a second
        connection (self.ldb2) authenticated as that user.
        """
        super(PasswordTests, self).setUp()
        self.ldb = SamDB(url=host, session_info=system_session(lp),
                         credentials=creds, lp=lp)

        # Gets back the basedn (fetched once and reused for every DN below)
        self.base_dn = self.ldb.domain_dn()
        user_dn = "cn=testuser,cn=users," + self.base_dn
        modify_header = "\ndn: " + user_dn + "\nchangetype: modify\n"

        # Get the old "dSHeuristics" if it was set
        dsheuristics = self.ldb.get_dsheuristics()

        # Set the "dSHeuristics" to activate the correct "userPassword" behaviour
        self.ldb.set_dsheuristics("000000001")

        # Reset the "dSHeuristics" as they were before
        self.addCleanup(self.ldb.set_dsheuristics, dsheuristics)

        # Get the old "minPwdAge"
        minPwdAge = self.ldb.get_minPwdAge()

        # Set it temporarily to "0"
        self.ldb.set_minPwdAge("0")

        # Reset the "minPwdAge" as it was before
        self.addCleanup(self.ldb.set_minPwdAge, minPwdAge)

        # (Re)adds the test user "testuser" with no password atm
        delete_force(self.ldb, user_dn)
        self.ldb.add({
            "dn": user_dn,
            "objectclass": "user",
            "sAMAccountName": "testuser"})

        # Tests a password change when we don't have any password yet with a
        # wrong old password
        with self.assertRaises(LdbError) as caught:
            self.ldb.modify_ldif(
                modify_header +
                "delete: userPassword\n"
                "userPassword: noPassword\n"
                "add: userPassword\n"
                "userPassword: thatsAcomplPASS2\n")
        (num, msg) = caught.exception.args
        self.assertEqual(num, ERR_CONSTRAINT_VIOLATION)
        # Windows (2008 at least) seems to have some small bug here: it
        # returns "0000056A" on longer (always wrong) previous passwords.
        self.assertIn('00000056', msg)

        # Sets the initial user password with a "special" password change
        # I think that this internally is a password set operation and it can
        # only be performed by someone which has password set privileges on the
        # account (at least in s4 we do handle it like that).
        self.ldb.modify_ldif(
            modify_header +
            "delete: userPassword\n"
            "add: userPassword\n"
            "userPassword: thatsAcomplPASS1\n")

        # But in the other way around this special syntax doesn't work
        with self.assertRaises(LdbError) as caught:
            self.ldb.modify_ldif(
                modify_header +
                "delete: userPassword\n"
                "userPassword: thatsAcomplPASS1\n"
                "add: userPassword\n")
        self.assertEqual(caught.exception.args[0], ERR_CONSTRAINT_VIOLATION)

        # Enables the user account
        self.ldb.enable_account("(sAMAccountName=testuser)")

        # Open a second LDB connection with the user credentials. Use the
        # command line credentials for informations like the domain, the realm
        # and the workstation.
        creds2 = Credentials()
        creds2.set_username("testuser")
        creds2.set_password("thatsAcomplPASS1")
        creds2.set_domain(creds.get_domain())
        creds2.set_realm(creds.get_realm())
        creds2.set_workstation(creds.get_workstation())
        creds2.set_gensec_features(creds2.get_gensec_features()
                                   | gensec.FEATURE_SEAL)
        self.ldb2 = SamDB(url=host, credentials=creds2, lp=lp)

    def test_unicodePwd_hash_set(self):
        """Performs a password hash set operation on 'unicodePwd' which should be prevented"""
        # Notice: Direct hash password sets should never work

        m = Message()
        m.dn = Dn(self.ldb, "cn=testuser,cn=users," + self.base_dn)
        m["unicodePwd"] = MessageElement("XXXXXXXXXXXXXXXX", FLAG_MOD_REPLACE,
                                         "unicodePwd")

        # The replace must be rejected, and with this exact error code.
        with self.assertRaises(LdbError) as caught:
            self.ldb.modify(m)
        self.assertEqual(caught.exception.args[0], ERR_UNWILLING_TO_PERFORM)

    def test_unicodePwd_hash_change(self):
        """Performs a password hash change operation on 'unicodePwd' which should be prevented"""
        # Notice: Direct hash password changes should never work

        ldif = (
            "\ndn: cn=testuser,cn=users," + self.base_dn + "\n"
            "changetype: modify\n"
            "delete: unicodePwd\n"
            "unicodePwd: XXXXXXXXXXXXXXXX\n"
            "add: unicodePwd\n"
            "unicodePwd: YYYYYYYYYYYYYYYY\n")

        # Issued over the user's own connection (self.ldb2).
        with self.assertRaises(LdbError) as caught:
            self.ldb2.modify_ldif(ldif)
        self.assertEqual(caught.exception.args[0], ERR_CONSTRAINT_VIOLATION)

    def test_unicodePwd_clear_set(self):
        """Replace 'unicodePwd' with a quoted, UTF-16-LE encoded cleartext value.

        The modify is sent over the privileged connection (self.ldb) and is
        expected to succeed, i.e. not to raise.
        """
        # The value sent is the password wrapped in double quotes, encoded
        # as UTF-16-LE.
        quoted_password = "\"thatsAcomplPASS2\"".encode('utf-16-le')

        request = Message()
        request.dn = Dn(self.ldb, "cn=testuser,cn=users," + self.base_dn)
        request["unicodePwd"] = MessageElement(quoted_password,
                                               FLAG_MOD_REPLACE,
                                               "unicodePwd")
        self.ldb.modify(request)

    def test_unicodePwd_clear_change(self):
        """Performs a password cleartext change operation on 'unicodePwd'

        Covers three cases over the user's own connection (self.ldb2): a
        valid change, a change with a wrong old password, and a change to
        the current password again.
        """
        modify_header = ("\ndn: cn=testuser,cn=users," + self.base_dn +
                         "\nchangetype: modify\n")

        def encoded(password):
            # Quote, UTF-16-LE encode and base64 the password. b64encode()
            # returns bytes on Python 3, so decode before splicing into the
            # textual LDIF.
            quoted = ('"' + password + '"').encode('utf-16-le')
            return base64.b64encode(quoted).decode('utf8')

        def change_ldif(old_password, new_password):
            return (modify_header +
                    "delete: unicodePwd\n"
                    "unicodePwd:: " + encoded(old_password) + "\n"
                    "add: unicodePwd\n"
                    "unicodePwd:: " + encoded(new_password) + "\n")

        self.ldb2.modify_ldif(change_ldif("thatsAcomplPASS1",
                                          "thatsAcomplPASS2"))

        # Wrong old password
        with self.assertRaises(LdbError) as caught:
            self.ldb2.modify_ldif(change_ldif("thatsAcomplPASS3",
                                              "thatsAcomplPASS4"))
        (num, msg) = caught.exception.args
        self.assertEqual(num, ERR_CONSTRAINT_VIOLATION)
        self.assertIn('00000056', msg)

        # A change to the same password again will not work (password history)
        with self.assertRaises(LdbError) as caught:
            self.ldb2.modify_ldif(change_ldif("thatsAcomplPASS2",
                                              "thatsAcomplPASS2"))
        (num, msg) = caught.exception.args
        self.assertEqual(num, ERR_CONSTRAINT_VIOLATION)
        self.assertIn('0000052D', msg)

    def test_dBCSPwd_hash_set(self):
        """Performs a password hash set operation on 'dBCSPwd' which should be prevented"""
        # Notice: Direct hash password sets should never work

        m = Message()
        m.dn = Dn(self.ldb, "cn=testuser,cn=users," + self.base_dn)
        m["dBCSPwd"] = MessageElement("XXXXXXXXXXXXXXXX", FLAG_MOD_REPLACE,
                                      "dBCSPwd")

        # The replace must be rejected, and with this exact error code.
        with self.assertRaises(LdbError) as caught:
            self.ldb.modify(m)
        self.assertEqual(caught.exception.args[0], ERR_UNWILLING_TO_PERFORM)

    def test_dBCSPwd_hash_change(self):
        """Performs a password hash change operation on 'dBCSPwd' which should be prevented"""
        # Notice: Direct hash password changes should never work

        ldif = (
            "\ndn: cn=testuser,cn=users," + self.base_dn + "\n"
            "changetype: modify\n"
            "delete: dBCSPwd\n"
            "dBCSPwd: XXXXXXXXXXXXXXXX\n"
            "add: dBCSPwd\n"
            "dBCSPwd: YYYYYYYYYYYYYYYY\n")

        # Issued over the user's own connection (self.ldb2).
        with self.assertRaises(LdbError) as caught:
            self.ldb2.modify_ldif(ldif)
        self.assertEqual(caught.exception.args[0], ERR_UNWILLING_TO_PERFORM)

    def test_userPassword_clear_set(self):
        """Replace 'userPassword' with a cleartext value over the privileged connection.

        The modify is expected to succeed, i.e. not to raise.
        """
        # Notice: This works only against Windows if "dSHeuristics" has been set
        # properly

        target = Dn(self.ldb, "cn=testuser,cn=users," + self.base_dn)
        new_value = MessageElement("thatsAcomplPASS2", FLAG_MOD_REPLACE,
                                   "userPassword")

        request = Message()
        request.dn = target
        request["userPassword"] = new_value
        self.ldb.modify(request)

    def test_userPassword_clear_change(self):
        """Performs a password cleartext change operation on 'userPassword'

        Covers three cases over the user's own connection (self.ldb2): a
        valid change, a change with a wrong old password, and a change to
        the current password again.
        """
        # Notice: This works only against Windows if "dSHeuristics" has been set
        # properly

        modify_header = ("\ndn: cn=testuser,cn=users," + self.base_dn +
                         "\nchangetype: modify\n")

        def change_ldif(old_password, new_password):
            return (modify_header +
                    "delete: userPassword\n"
                    "userPassword: " + old_password + "\n"
                    "add: userPassword\n"
                    "userPassword: " + new_password + "\n")

        self.ldb2.modify_ldif(change_ldif("thatsAcomplPASS1",
                                          "thatsAcomplPASS2"))

        # Wrong old password
        with self.assertRaises(LdbError) as caught:
            self.ldb2.modify_ldif(change_ldif("thatsAcomplPASS3",
                                              "thatsAcomplPASS4"))
        (num, msg) = caught.exception.args
        self.assertEqual(num, ERR_CONSTRAINT_VIOLATION)
        self.assertIn('00000056', msg)

        # A change to the same password again will not work (password history)
        with self.assertRaises(LdbError) as caught:
            self.ldb2.modify_ldif(change_ldif("thatsAcomplPASS2",
                                              "thatsAcomplPASS2"))
        (num, msg) = caught.exception.args
        self.assertEqual(num, ERR_CONSTRAINT_VIOLATION)
        self.assertIn('0000052D', msg)

    def test_clearTextPassword_clear_set(self):
        """Performs a password cleartext set operation on 'clearTextPassword'

        The replace is expected either to succeed or to fail with
        ERR_NO_SUCH_ATTRIBUTE; any other LdbError is propagated.
        """
        # Notice: This never works against Windows - only supported by us

        m = Message()
        m.dn = Dn(self.ldb, "cn=testuser,cn=users," + self.base_dn)
        m["clearTextPassword"] = MessageElement(
            "thatsAcomplPASS2".encode('utf-16-le'),
            FLAG_MOD_REPLACE, "clearTextPassword")

        try:
            self.ldb.modify(m)
            # this passes against s4
        except LdbError as err:
            # "NO_SUCH_ATTRIBUTE" is returned by Windows -> ignore it
            if err.args[0] != ERR_NO_SUCH_ATTRIBUTE:
                # Re-raise the original error so its traceback is kept.
                raise

    def test_clearTextPassword_clear_change(self):
        """Performs a password cleartext change operation on 'clearTextPassword'

        Covers three cases over the user's own connection (self.ldb2): a
        valid change, a change with a wrong old password, and a change to
        the current password again. ERR_NO_SUCH_ATTRIBUTE is tolerated in
        every case.
        """
        # Notice: This never works against Windows - only supported by us

        modify_header = ("\ndn: cn=testuser,cn=users," + self.base_dn +
                         "\nchangetype: modify\n")

        def encoded(password):
            # UTF-16-LE encode and base64 the (unquoted) password.
            # b64encode() returns bytes on Python 3, so decode before
            # splicing into the textual LDIF.
            return base64.b64encode(
                password.encode('utf-16-le')).decode('utf8')

        def change_ldif(old_password, new_password):
            return (modify_header +
                    "delete: clearTextPassword\n"
                    "clearTextPassword:: " + encoded(old_password) + "\n"
                    "add: clearTextPassword\n"
                    "clearTextPassword:: " + encoded(new_password) + "\n")

        try:
            self.ldb2.modify_ldif(change_ldif("thatsAcomplPASS1",
                                              "thatsAcomplPASS2"))
            # this passes against s4
        except LdbError as err:
            # "NO_SUCH_ATTRIBUTE" is returned by Windows -> ignore it
            if err.args[0] != ERR_NO_SUCH_ATTRIBUTE:
                # Re-raise the original error so its traceback is kept.
                raise

        # Wrong old password
        with self.assertRaises(LdbError) as caught:
            self.ldb2.modify_ldif(change_ldif("thatsAcomplPASS3",
                                              "thatsAcomplPASS4"))
        (num, msg) = caught.exception.args
        # "NO_SUCH_ATTRIBUTE" is returned by Windows -> ignore it
        if num != ERR_NO_SUCH_ATTRIBUTE:
            self.assertEqual(num, ERR_CONSTRAINT_VIOLATION)
            self.assertIn('00000056', msg)

        # A change to the same password again will not work (password history)
        with self.assertRaises(LdbError) as caught:
            self.ldb2.modify_ldif(change_ldif("thatsAcomplPASS2",
                                              "thatsAcomplPASS2"))
        (num, msg) = caught.exception.args
        # "NO_SUCH_ATTRIBUTE" is returned by Windows -> ignore it
        if num != ERR_NO_SUCH_ATTRIBUTE:
            self.assertEqual(num, ERR_CONSTRAINT_VIOLATION)
            self.assertIn('0000052D', msg)

    def test_failures(self):
        """Performs some failure testing

        Sends a series of malformed or unusual 'userPassword' modify
        requests, each first over the privileged connection (self.ldb) and
        then over the user's own connection (self.ldb2), and checks the
        error code of every rejection. Afterwards exercises a few
        combinations that are expected to be accepted.
        """
        user_dn = "cn=testuser,cn=users," + self.base_dn
        user2_dn = "cn=testuser2,cn=users," + self.base_dn
        modify_header = "\ndn: " + user_dn + "\nchangetype: modify\n"

        # Reusable LDIF fragments.
        del_pass1 = "delete: userPassword\nuserPassword: thatsAcomplPASS1\n"
        add_pass1 = "add: userPassword\nuserPassword: thatsAcomplPASS1\n"
        add_pass2 = "add: userPassword\nuserPassword: thatsAcomplPASS2\n"

        def expect_failure(conn, changes, expected):
            # The modify must raise an LdbError carrying exactly `expected`.
            with self.assertRaises(LdbError) as caught:
                conn.modify_ldif(modify_header + changes)
            self.assertEqual(caught.exception.args[0], expected)

        def encoded_pass3():
            # Quoted, UTF-16-LE encoded, base64'd "thatsAcomplPASS3".
            # b64encode() returns bytes on Python 3, so decode before
            # splicing into the textual LDIF.
            return base64.b64encode(
                "\"thatsAcomplPASS3\"".encode('utf-16-le')).decode('utf8')

        # (changes, error expected on self.ldb, error expected on self.ldb2)
        rejected = [
            # delete with a value but no add
            (del_pass1,
             ERR_CONSTRAINT_VIOLATION, ERR_CONSTRAINT_VIOLATION),
            # delete without a value and no add
            ("delete: userPassword\n",
             ERR_CONSTRAINT_VIOLATION, ERR_CONSTRAINT_VIOLATION),
            # add only
            (add_pass1,
             ERR_UNWILLING_TO_PERFORM, ERR_INSUFFICIENT_ACCESS_RIGHTS),
            # two values in the add
            (del_pass1 + add_pass2 + "userPassword: thatsAcomplPASS2\n",
             ERR_CONSTRAINT_VIOLATION, ERR_CONSTRAINT_VIOLATION),
            # two values in the delete
            (del_pass1 + "userPassword: thatsAcomplPASS1\n" + add_pass2,
             ERR_CONSTRAINT_VIOLATION, ERR_CONSTRAINT_VIOLATION),
            # two add operations
            (del_pass1 + add_pass2 + add_pass2,
             ERR_UNWILLING_TO_PERFORM, ERR_INSUFFICIENT_ACCESS_RIGHTS),
            # two delete operations
            (del_pass1 + del_pass1 + add_pass2,
             ERR_UNWILLING_TO_PERFORM, ERR_INSUFFICIENT_ACCESS_RIGHTS),
            # delete/add followed by a replace
            (del_pass1 + add_pass2 +
             "replace: userPassword\nuserPassword: thatsAcomplPASS3\n",
             ERR_UNWILLING_TO_PERFORM, ERR_INSUFFICIENT_ACCESS_RIGHTS),
        ]
        for changes, admin_error, user_error in rejected:
            expect_failure(self.ldb, changes, admin_error)
            expect_failure(self.ldb2, changes, user_error)

        # Reverse order does work
        self.ldb2.modify_ldif(modify_header + add_pass2 + del_pass1)

        try:
            self.ldb2.modify_ldif(
                modify_header +
                "delete: userPassword\n"
                "userPassword: thatsAcomplPASS2\n"
                "add: unicodePwd\n"
                "unicodePwd:: " + encoded_pass3() + "\n")
            # this passes against s4
        except LdbError as err:
            self.assertEqual(err.args[0], ERR_ATTRIBUTE_OR_VALUE_EXISTS)

        try:
            self.ldb2.modify_ldif(
                modify_header +
                "delete: unicodePwd\n"
                "unicodePwd:: " + encoded_pass3() + "\n"
                "add: userPassword\n"
                "userPassword: thatsAcomplPASS4\n")
            # this passes against s4
        except LdbError as err:
            self.assertEqual(err.args[0], ERR_NO_SUCH_ATTRIBUTE)

        # Several password changes at once are allowed
        self.ldb.modify_ldif(
            modify_header +
            "replace: userPassword\n"
            "userPassword: thatsAcomplPASS1\n"
            "userPassword: thatsAcomplPASS2\n")

        # Several password changes at once are allowed
        self.ldb.modify_ldif(
            modify_header +
            "replace: userPassword\n"
            "userPassword: thatsAcomplPASS1\n"
            "userPassword: thatsAcomplPASS2\n"
            "replace: userPassword\n"
            "userPassword: thatsAcomplPASS3\n"
            "replace: userPassword\n"
            "userPassword: thatsAcomplPASS4\n")

        # Make sure the extra account created below does not outlive the test.
        self.addCleanup(delete_force, self.ldb, user2_dn)

        # This surprisingly should work
        delete_force(self.ldb, user2_dn)
        self.ldb.add({
            "dn": user2_dn,
            "objectclass": "user",
            "userPassword": ["thatsAcomplPASS1", "thatsAcomplPASS2"]})

        # This surprisingly should work
        delete_force(self.ldb, user2_dn)
        self.ldb.add({
            "dn": user2_dn,
            "objectclass": "user",
            "userPassword": ["thatsAcomplPASS1", "thatsAcomplPASS1"]})

    def test_empty_passwords(self):
        """Check that empty password attribute values are rejected.

        First tries to add a 'testuser2' carrying an empty value list for
        each password attribute, then tries add / replace / delete
        modifications with an empty value list on the existing 'testuser'.
        Every request must fail with the error code listed for it; for
        'clearTextPassword' ERR_NO_SUCH_ATTRIBUTE is accepted as well (for
        Windows).
        """
        print("Performs some empty passwords testing")

        user_dn = "cn=testuser,cn=users," + self.base_dn
        user2_dn = "cn=testuser2,cn=users," + self.base_dn

        def expect_add_rejected(attr, allowed):
            with self.assertRaises(LdbError) as caught:
                self.ldb.add({
                    "dn": user2_dn,
                    "objectclass": "user",
                    attr: []})
            self.assertIn(caught.exception.args[0], allowed)

        def expect_modify_rejected(attr, flags, allowed):
            m = Message()
            m.dn = Dn(self.ldb, user_dn)
            m[attr] = MessageElement([], flags, attr)
            with self.assertRaises(LdbError) as caught:
                self.ldb.modify(m)
            self.assertIn(caught.exception.args[0], allowed)

        constraint = (ERR_CONSTRAINT_VIOLATION,)
        unwilling = (ERR_UNWILLING_TO_PERFORM,)
        # "NO_SUCH_ATTRIBUTE" is what Windows returns for clearTextPassword
        constraint_or_missing = (ERR_CONSTRAINT_VIOLATION,
                                 ERR_NO_SUCH_ATTRIBUTE)
        unwilling_or_missing = (ERR_UNWILLING_TO_PERFORM,
                                ERR_NO_SUCH_ATTRIBUTE)

        expect_add_rejected("unicodePwd", constraint)
        expect_add_rejected("dBCSPwd", constraint)
        expect_add_rejected("userPassword", constraint)
        expect_add_rejected("clearTextPassword", constraint_or_missing)

        delete_force(self.ldb, user2_dn)

        # (attribute, modification flag, accepted error codes)
        modifications = [
            ("unicodePwd", FLAG_MOD_ADD, constraint),
            ("dBCSPwd", FLAG_MOD_ADD, constraint),
            ("userPassword", FLAG_MOD_ADD, constraint),
            ("clearTextPassword", FLAG_MOD_ADD, constraint_or_missing),

            ("unicodePwd", FLAG_MOD_REPLACE, unwilling),
            ("dBCSPwd", FLAG_MOD_REPLACE, unwilling),
            ("userPassword", FLAG_MOD_REPLACE, unwilling),
            ("clearTextPassword", FLAG_MOD_REPLACE, unwilling_or_missing),

            ("unicodePwd", FLAG_MOD_DELETE, unwilling),
            ("dBCSPwd", FLAG_MOD_DELETE, unwilling),
            ("userPassword", FLAG_MOD_DELETE, constraint),
            ("clearTextPassword", FLAG_MOD_DELETE, constraint_or_missing),
        ]
        for attr, flags, allowed in modifications:
            expect_modify_rejected(attr, flags, allowed)

    def test_plain_userPassword(self):
        """Check 'userPassword' as an ordinary attribute.

        With "dSHeuristics" removed, and then set to "000000000" and
        "000000002", values written to 'userPassword' must be readable back
        unchanged through a search. "dSHeuristics" is set back to
        "000000001" at the end.
        """
        print("Performs testing about the standard 'userPassword' behaviour")

        user_dn = "cn=testuser,cn=users," + self.base_dn

        def modify_user_password(values, flags):
            m = Message()
            m.dn = Dn(self.ldb, user_dn)
            m["userPassword"] = MessageElement(values, flags, "userPassword")
            self.ldb.modify(m)

        def assert_stored_password(expected):
            # expected=None means the attribute must be absent.
            res = self.ldb.search(user_dn, scope=SCOPE_BASE,
                                  attrs=["userPassword"])
            self.assertEqual(len(res), 1)
            if expected is None:
                self.assertNotIn("userPassword", res[0])
            else:
                self.assertIn("userPassword", res[0])
                self.assertEqual(res[0]["userPassword"][0], expected)

        # Delete the "dSHeuristics"
        self.ldb.set_dsheuristics(None)

        time.sleep(1) # This switching time is strictly needed!

        modify_user_password("myPassword", FLAG_MOD_ADD)
        assert_stored_password("myPassword")

        modify_user_password("myPassword2", FLAG_MOD_REPLACE)
        assert_stored_password("myPassword2")

        modify_user_password([], FLAG_MOD_DELETE)
        assert_stored_password(None)

        # Set the test "dSHeuristics" to deactivate "userPassword" pwd changes
        self.ldb.set_dsheuristics("000000000")

        modify_user_password("myPassword3", FLAG_MOD_REPLACE)
        assert_stored_password("myPassword3")

        # Set the test "dSHeuristics" to deactivate "userPassword" pwd changes
        self.ldb.set_dsheuristics("000000002")

        modify_user_password("myPassword4", FLAG_MOD_REPLACE)
        assert_stored_password("myPassword4")

        # Reset the test "dSHeuristics" (reactivate "userPassword" pwd changes)
        self.ldb.set_dsheuristics("000000001")

    def test_modify_dsheuristics_userPassword(self):
        """Check 'userPassword' handling on connections opened around "dSHeuristics" changes.

        Opens connections before and after switching "dSHeuristics" between
        "000000000" and "000000001", writes 'userPassword' through them and
        checks on which connections the value can be read back, and which
        value is the account's actual login password.
        """
        print("Performs testing about reading userPassword between dsHeuristic modifies")

        user_dn = "cn=testuser,cn=users," + self.base_dn

        def connect():
            return SamDB(url=host, session_info=system_session(lp),
                         credentials=creds, lp=lp)

        def replace_user_password(conn, value):
            m = Message()
            m.dn = Dn(conn, user_dn)
            m["userPassword"] = MessageElement(value, FLAG_MOD_REPLACE,
                                               "userPassword")
            conn.modify(m)

        def fetch_user(conn):
            res = conn.search(user_dn, scope=SCOPE_BASE,
                              attrs=["userPassword"])
            self.assertEqual(len(res), 1)
            return res[0]

        # Make sure userPassword cannot be read
        self.ldb.set_dsheuristics("000000000")

        # Open a new connection (with dsHeuristic=000000000)
        ldb1 = connect()

        # Set userPassword to be read
        # This setting only affects newer connections (ldb2)
        ldb1.set_dsheuristics("000000001")
        time.sleep(1)

        replace_user_password(ldb1, "thatsAcomplPASS1")

        # userPassword cannot be read, despite the dsHeuristic setting
        self.assertNotIn("userPassword", fetch_user(ldb1))

        # Open another new connection (with dsHeuristic=000000001)
        ldb2 = connect()

        # Set userPassword to be unreadable
        # This setting does not affect this connection
        ldb2.set_dsheuristics("000000000")
        time.sleep(1)

        # Check that userPassword was not stored from ldb1
        self.assertNotIn("userPassword", fetch_user(ldb2))

        replace_user_password(ldb2, "thatsAcomplPASS2")

        # userPassword can be read in this connection
        # This is regardless of the current dsHeuristics setting
        entry = fetch_user(ldb2)
        self.assertIn("userPassword", entry)
        self.assertEqual(entry["userPassword"][0], "thatsAcomplPASS2")

        # Only password from ldb1 is the user's password
        creds2 = Credentials()
        creds2.set_username("testuser")
        creds2.set_password("thatsAcomplPASS1")
        creds2.set_domain(creds.get_domain())
        creds2.set_realm(creds.get_realm())
        creds2.set_workstation(creds.get_workstation())
        creds2.set_gensec_features(creds2.get_gensec_features()
                                   | gensec.FEATURE_SEAL)

        try:
            SamDB(url=host, credentials=creds2, lp=lp)
        except Exception:
            # Not a bare "except:" so that KeyboardInterrupt / SystemExit
            # are not turned into a test failure.
            self.fail("testuser used the wrong password")

        ldb3 = connect()

        # Check that userPassword was stored from ldb2
        # userPassword can be read
        entry = fetch_user(ldb3)
        self.assertIn("userPassword", entry)
        self.assertEqual(entry["userPassword"][0], "thatsAcomplPASS2")

        # Reset the test "dSHeuristics" (reactivate "userPassword" pwd changes)
        self.ldb.set_dsheuristics("000000001")

    def test_zero_length(self):
        """Set an empty password on 'testuser' with the password policy relaxed.

        "minPwdLength" and "pwdProperties" are both set to "0" for the
        duration of the test. The previous values are restored through
        cleanups, so they are put back even when setpassword() raises.
        """
        # Get the old "minPwdLength"
        minPwdLength = self.ldb.get_minPwdLength()
        # Set it temporarily to "0"
        self.ldb.set_minPwdLength("0")
        # Reset the "minPwdLength" as it was before
        self.addCleanup(self.ldb.set_minPwdLength, minPwdLength)

        # Get the old "pwdProperties"
        pwdProperties = self.ldb.get_pwdProperties()
        # Set them temporarily to "0" (to deactivate eventually the complexity)
        self.ldb.set_pwdProperties("0")
        # Reset the "pwdProperties" as they were before. Cleanups run in
        # reverse order, so this one runs before the "minPwdLength" reset.
        self.addCleanup(self.ldb.set_pwdProperties, pwdProperties)

        self.ldb.setpassword("(sAMAccountName=testuser)", "")

    def test_pw_change_delete_no_value_userPassword(self):
        """Test password change with userPassword where the delete attribute doesn't have a value"""

        try:
            self.ldb2.modify_ldif("""
dn: cn=testuser,cn=users,""" + self.base_dn + """
changetype: modify
delete: userPassword
add: userPassword
userPassword: thatsAcomplPASS1
""")
        except LdbError, (num, msg):
            self.assertEquals(num, ERR_CONSTRAINT_VIOLATION)
        else:
Exemple #37
0
class AuthLogTestsNetLogonBadCreds(samba.tests.auth_log_base.AuthLogTestBase):
    """Authentication log tests for NETLOGON logons that are rejected.

    setUp creates a workstation trust account.  Each test then attempts a
    NETLOGON ServerAuthenticate that is expected to fail and inspects the
    messages returned by ``waitForMessages``.
    """

    def setUp(self):
        """Connect to the SamDB and create the machine account under test."""
        super(AuthLogTestsNetLogonBadCreds, self).setUp()
        self.lp = samba.tests.env_loadparm()
        self.creds = Credentials()

        self.session = system_session()
        self.ldb = SamDB(
            session_info=self.session,
            credentials=self.creds,
            lp=self.lp)

        self.domain = os.environ["DOMAIN"]
        self.netbios_name = "NetLogonBad"
        self.machinepass = "******"
        self.remoteAddress = AS_SYSTEM_MAGIC_PATH_TOKEN
        self.base_dn = self.ldb.domain_dn()
        self.dn = ("cn=%s,cn=users,%s" %
                   (self.netbios_name, self.base_dn))

        # "unicodePwd" is given the double-quoted password encoded as
        # UTF-16-LE.  It is built with plain string formatting so that it
        # also works on Python 3, which has no "unicode" builtin and does
        # not allow str and bytes to be concatenated.
        utf16pw = ('"%s"' % self.machinepass).encode('utf-16-le')
        self.ldb.add({
            "dn": self.dn,
            "objectclass": "computer",
            "sAMAccountName": "%s$" % self.netbios_name,
            "userAccountControl":
                str(UF_WORKSTATION_TRUST_ACCOUNT | UF_PASSWD_NOTREQD),
            "unicodePwd": utf16pw})

    def tearDown(self):
        """Remove the machine account created by setUp."""
        super(AuthLogTestsNetLogonBadCreds, self).tearDown()
        delete_force(self.ldb, self.dn)

    def _test_netlogon(self, name, pwd, status, checkFunction):
        """Attempt a schannel logon that must fail and check the log.

        name          -- machine account name, without the trailing "$"
        pwd           -- password presented for that account
        status        -- status string expected on the final NETLOGON
                         ServerAuthenticate message
        checkFunction -- called with the list of collected messages
        """

        def isLastExpectedMessage(msg):
            return (
                msg["type"] == "Authentication" and
                msg["Authentication"]["serviceDescription"] == "NETLOGON" and
                msg["Authentication"]["authDescription"] ==
                "ServerAuthenticate" and
                msg["Authentication"]["status"] == status)

        machine_creds = Credentials()
        machine_creds.guess(self.get_loadparm())
        machine_creds.set_secure_channel_type(SEC_CHAN_WKSTA)
        machine_creds.set_password(pwd)
        machine_creds.set_username(name + "$")

        try:
            netlogon.netlogon("ncalrpc:[schannel]",
                              self.get_loadparm(),
                              machine_creds)
            # Reaching this line means the connection was accepted.
            self.fail("NTSTATUSError not raised")
        except NTSTATUSError:
            pass

        messages = self.waitForMessages(isLastExpectedMessage)
        checkFunction(messages)

    def netlogon_check(self, messages):
        """Check the message count and the leading Authorization message."""

        expected_messages = 4
        self.assertEqual(expected_messages,
                         len(messages),
                         "Did not receive the expected number of messages")

        # Check the first message it should be an Authorization
        msg = messages[0]
        self.assertEqual("Authorization", msg["type"])
        self.assertEqual("DCE/RPC",
                         msg["Authorization"]["serviceDescription"])
        self.assertEqual("ncalrpc", msg["Authorization"]["authType"])
        self.assertEqual("NONE", msg["Authorization"]["transportProtection"])
        self.assertTrue(self.is_guid(msg["Authorization"]["sessionId"]))

    def test_netlogon_bad_machine_name(self):
        """A logon for an unknown machine account is logged as rejected."""
        self._test_netlogon("bad_name",
                            self.machinepass,
                            "NT_STATUS_NO_TRUST_SAM_ACCOUNT",
                            self.netlogon_check)

    def test_netlogon_bad_password(self):
        """A logon with the wrong machine password is logged as rejected."""
        self._test_netlogon(self.netbios_name,
                            "badpass",
                            "NT_STATUS_ACCESS_DENIED",
                            self.netlogon_check)

    def _test_netlogon_password_type(self, negotiate_flags, password_type):
        """Wait for a ServerAuthenticate message with the given passwordType.

        negotiate_flags -- negotiate flags passed to netr_ServerAuthenticate3
        password_type   -- "passwordType" expected on the logged NETLOGON
                           ServerAuthenticate message
        """
        def isLastExpectedMessage(msg):
            return (
                msg["type"] == "Authentication" and
                msg["Authentication"]["serviceDescription"] == "NETLOGON" and
                msg["Authentication"]["authDescription"] ==
                "ServerAuthenticate" and
                msg["Authentication"]["passwordType"] == password_type)

        c = netlogon.netlogon("ncalrpc:[schannel]", self.get_loadparm())
        creds = netlogon.netr_Credential()
        c.netr_ServerReqChallenge(self.server, self.netbios_name, creds)
        try:
            c.netr_ServerAuthenticate3(self.server,
                                       self.netbios_name,
                                       SEC_CHAN_WKSTA,
                                       self.netbios_name,
                                       creds,
                                       negotiate_flags)
        except NTSTATUSError:
            # An NTSTATUSError from the call is tolerated; only the logged
            # message is checked.
            pass
        self.waitForMessages(isLastExpectedMessage)

    def test_netlogon_password_DES(self):
        """Logon failure that exercises the "DES" passwordType path.
        """
        self._test_netlogon_password_type(0, "DES")

    def test_netlogon_password_HMAC_MD5(self):
        """Logon failure that exercises the "HMAC-MD5" passwordType path.
        """
        self._test_netlogon_password_type(NETLOGON_NEG_STRONG_KEYS,
                                          "HMAC-MD5")
Exemple #38
0
class AuthLogTestsWinbind(AuthLogTestBase, BlackboxTestCase):
    """Authentication log tests for logons made through winbind.

    Each test runs a command line tool (ntlm_auth or wbinfo), checks the
    winbind messages returned by ``waitForMessages`` and then checks that
    the watcher process started by ``dc_watcher`` saw a matching SamLogon.
    """

    #
    # Helper function to watch for authentication messages on the
    # Domain Controller.
    #
    def dc_watcher(self):
        """Fork a child process that collects SamLogon messages.

        The child listens for authentication events for about one second
        and writes every SamLogon message, one JSON document per line, to
        a pipe.  When there were none it writes the single line "None".

        Returns the read end of that pipe (a file descriptor) in the
        parent process.  The child never returns from this method.
        """

        (r1, w1) = os.pipe()
        pid = os.fork()
        if pid != 0:
            # Parent process return the result socket to the caller.
            # The write end is only used by the child, so it is not kept
            # open here.
            os.close(w1)
            return r1

        # Child process.  It always ends in os._exit() so that a failure
        # can never fall back into the test runner it was forked from.
        exit_code = 1
        try:
            os.close(r1)
            self._watch_dc_sam_logons(w1)
            exit_code = 0
        except Exception:
            import traceback
            traceback.print_exc()
        finally:
            os._exit(exit_code)

    def _watch_dc_sam_logons(self, write_fd):
        """Collect SamLogon messages for one second and write them out.

        write_fd -- file descriptor the messages are written to; it is
                    closed once they have been written
        """
        # Load the lp context for the Domain Controller, rather than the
        # member server.
        config_file = os.environ["DC_SERVERCONFFILE"]
        lp_ctx = LoadParm()
        lp_ctx.load(config_file)

        #
        # Is the message a SamLogon authentication?
        def is_sam_logon(m):
            if m is None:
                return False
            msg = json.loads(m)
            return (msg["type"] == "Authentication" and
                    msg["Authentication"]["serviceDescription"] == "SamLogon")

        #
        # Handler function for received authentication messages.
        def message_handler(context, msgType, src, message):
            # Print the message to help debugging the tests.
            # as it's a JSON message it does not look like a sub-unit message.
            print(message)
            self.dc_msgs.append(message)

        # The handler appends to this list, so create it before the
        # handler is registered.
        self.dc_msgs = []

        # Set up a messaging context to listen for authentication events on
        # the domain controller.
        msg_ctx = Messaging((1, ), lp_ctx=lp_ctx)
        msg_ctx.irpc_add_name(AUTH_EVENT_NAME)
        msg_handler_and_context = (message_handler, None)
        msg_ctx.register(msg_handler_and_context, msg_type=MSG_AUTH_LOG)

        # Wait for the SamLogon message.
        # As there could be other SamLogon's in progress we need to collect
        # all the SamLogons and let the caller match them to the session.
        start_time = time.time()
        while (time.time() - start_time < 1):
            msg_ctx.loop_once(0.1)

        # Only interested in SamLogon messages, filter out the rest
        msgs = list(filter(is_sam_logon, self.dc_msgs))
        if msgs:
            for m in msgs:
                os.write(write_fd, get_bytes(m + "\n"))
        else:
            os.write(write_fd, get_bytes("None\n"))
        os.close(write_fd)

        msg_ctx.deregister(msg_handler_and_context, msg_type=MSG_AUTH_LOG)
        msg_ctx.irpc_remove_name(AUTH_EVENT_NAME)

    # Remove any DCE/RPC ncacn_np messages
    # these only get triggered once per session, and stripping them out
    # avoids ordering dependencies in the tests
    #
    def filter_messages(self, messages):
        """Return messages without the DCE/RPC ncacn_np Authorizations."""
        def is_ncacn_np_authorization(msg):
            return (msg["type"] == "Authorization"
                    and msg["Authorization"]["serviceDescription"] == "DCE/RPC"
                    and msg["Authorization"]["authType"] == "ncacn_np")

        return [msg for msg in messages
                if not is_ncacn_np_authorization(msg)]

    def setUp(self):
        """Connect to the DC over LDAP and create the test user account."""
        super(AuthLogTestsWinbind, self).setUp()
        self.domain = os.environ["DOMAIN"]
        self.host = os.environ["SERVER"]
        self.dc = os.environ["DC_SERVER"]
        self.lp = self.get_loadparm()
        self.credentials = self.get_credentials()
        self.session = system_session()

        self.ldb = SamDB(url="ldap://{0}".format(self.dc),
                         session_info=self.session,
                         credentials=self.credentials,
                         lp=self.lp)
        self.create_user_account()

    def tearDown(self):
        """Remove the test user account."""
        super(AuthLogTestsWinbind, self).tearDown()
        delete_force(self.ldb, self.user_dn)

    #
    # Create a test user account
    def create_user_account(self):
        """Create the USER_NAME account and credentials for it.

        Sets ``user_pass``, ``user_name``, ``user_dn`` and ``user_creds``
        on the test case.
        """
        self.user_pass = self.random_password()
        self.user_name = USER_NAME
        self.user_dn = "cn=%s,%s" % (self.user_name, self.ldb.domain_dn())

        # remove the account if it exists, this will happen if a previous test
        # run failed
        delete_force(self.ldb, self.user_dn)

        utf16pw = ('"%s"' % get_string(self.user_pass)).encode('utf-16-le')
        self.ldb.add({
            "dn": self.user_dn,
            "objectclass": "user",
            "sAMAccountName": "%s" % self.user_name,
            "userAccountControl": str(UF_NORMAL_ACCOUNT),
            "unicodePwd": utf16pw
        })

        self.user_creds = Credentials()
        self.user_creds.guess(self.get_loadparm())
        self.user_creds.set_password(self.user_pass)
        self.user_creds.set_username(self.user_name)
        self.user_creds.set_workstation(self.server)

    #
    # Check that the domain server received a SamLogon request for the
    # current logon.
    #
    def check_domain_server_authentication(self, pipe, logon_id, description):
        """Check the SamLogon message that belongs to logon_id.

        pipe        -- file descriptor returned by ``dc_watcher``
        logon_id    -- "logonId" of the winbind logon to look for
        description -- expected "authDescription" of the SamLogon message

        Fails the test when the watcher reported no SamLogon messages or
        when none of them carries logon_id.
        """

        messages = get_string(os.read(pipe, 8192))

        # The watcher ends every line, including the "None" placeholder,
        # with a newline.  Drop the empty trailing entry that split()
        # leaves behind so it is neither compared against "None" nor
        # handed to the JSON parser.
        lines = [line.strip() for line in messages.split("\n")
                 if line.strip()]
        if not lines or lines == ["None"]:
            self.fail("No Domain server authentication message")

        #
        # Look for the SamLogon request matching logon_id
        msg = None
        for line in lines:
            candidate = json.loads(line)
            if logon_id == candidate["Authentication"]["logonId"]:
                msg = candidate
                break

        if msg is None:
            self.fail("No Domain server authentication message")

        #
        # Validate that message contains the expected data
        #
        self.assertEqual("Authentication", msg["type"])
        self.assertEqual(logon_id, msg["Authentication"]["logonId"])
        self.assertEqual("SamLogon",
                         msg["Authentication"]["serviceDescription"])
        self.assertEqual(description,
                         msg["Authentication"]["authDescription"])

    @staticmethod
    def _winbind_auth_matcher(description):
        """Build a predicate for winbind Authentication messages.

        The returned function is true for an Authentication message from
        the "winbind" service whose authDescription starts with
        description.
        """
        def matches(msg):
            return (
                msg["type"] == "Authentication"
                and msg["Authentication"]["serviceDescription"] == "winbind"
                and msg["Authentication"]["authDescription"] is not None
                and msg["Authentication"]["authDescription"].startswith(
                    description))
        return matches

    def test_ntlm_auth(self):
        """ntlm_auth with a plaintext password.

        Expects one winbind PAM_AUTH message and an "interactive" SamLogon
        on the domain server.
        """
        pipe = self.dc_watcher()
        COMMAND = "bin/ntlm_auth"
        self.check_run("{0} --username={1} --password={2}".format(
            COMMAND, self.credentials.get_username(),
            self.credentials.get_password()),
                       msg="ntlm_auth failed")

        messages = self.waitForMessages(
            self._winbind_auth_matcher("PAM_AUTH, ntlm_auth"))
        messages = self.filter_messages(messages)
        expected_messages = 1
        self.assertEqual(expected_messages, len(messages),
                         "Did not receive the expected number of messages")

        # Check the first message it should be an Authentication
        msg = messages[0]
        self.assertEqual("Authentication", msg["type"])
        self.assertTrue(msg["Authentication"]["authDescription"].startswith(
            "PAM_AUTH, ntlm_auth,"))
        self.assertEqual("winbind",
                         msg["Authentication"]["serviceDescription"])
        self.assertEqual("Plaintext", msg["Authentication"]["passwordType"])
        # Logon type should be NetworkCleartext
        self.assertEqual(8, msg["Authentication"]["logonType"])
        # Event code should be Successful logon
        self.assertEqual(4624, msg["Authentication"]["eventId"])
        self.assertEqual("unix:", msg["Authentication"]["remoteAddress"])
        self.assertEqual("unix:", msg["Authentication"]["localAddress"])
        self.assertEqual(self.domain, msg["Authentication"]["clientDomain"])
        self.assertEqual("NT_STATUS_OK", msg["Authentication"]["status"])
        self.assertEqual(self.credentials.get_username(),
                         msg["Authentication"]["clientAccount"])
        self.assertEqual(self.credentials.get_domain(),
                         msg["Authentication"]["clientDomain"])
        self.assertTrue(msg["Authentication"]["workstation"] is None)

        logon_id = msg["Authentication"]["logonId"]

        #
        # Now check the Domain server authentication message
        #
        self.check_domain_server_authentication(pipe, logon_id, "interactive")

    def _check_wbinfo_logon(self, options, ntlm_password_type,
                            pam_password_type=None):
        """Run "wbinfo <options> user%password" and check the logged events.

        options            -- command line options placed before the
                              "user%password" argument, e.g. "-a"
        ntlm_password_type -- passwordType expected on the final NTLM_AUTH
                              message
        pam_password_type  -- passwordType expected on the PAM_AUTH
                              message; not checked when None
        """
        username = self.credentials.get_username()

        pipe = self.dc_watcher()
        COMMAND = "bin/wbinfo"
        try:
            self.check_run("{0} {1} {2}%{3}".format(
                COMMAND, options, username,
                self.credentials.get_password()),
                           msg="wbinfo failed")
        except BlackboxProcessError:
            # A failing wbinfo run is tolerated; the checks below are made
            # against the logged messages only.
            pass

        messages = self.waitForMessages(
            self._winbind_auth_matcher("NTLM_AUTH, wbinfo"))
        messages = self.filter_messages(messages)
        expected_messages = 3
        self.assertEqual(expected_messages, len(messages),
                         "Did not receive the expected number of messages")

        # The 1st message should be an Authentication against the local
        # password database
        msg = messages[0]
        self.assertEqual("Authentication", msg["type"])
        self.assertTrue(msg["Authentication"]["authDescription"].startswith(
            "PASSDB, wbinfo,"))
        self.assertEqual("winbind",
                         msg["Authentication"]["serviceDescription"])
        # Logon type should be Interactive
        self.assertEqual(2, msg["Authentication"]["logonType"])
        # Event code should be Unsuccessful logon
        self.assertEqual(4625, msg["Authentication"]["eventId"])
        self.assertEqual("unix:", msg["Authentication"]["remoteAddress"])
        self.assertEqual("unix:", msg["Authentication"]["localAddress"])
        self.assertEqual('', msg["Authentication"]["clientDomain"])
        # This is what the existing winbind implementation returns.
        self.assertEqual("NT_STATUS_NO_SUCH_USER",
                         msg["Authentication"]["status"])
        self.assertEqual("NTLMv2", msg["Authentication"]["passwordType"])
        self.assertEqual(username, msg["Authentication"]["clientAccount"])

        logon_id = msg["Authentication"]["logonId"]

        # The 2nd message should be a PAM_AUTH with the same logon id as the
        # 1st message
        msg = messages[1]
        self.assertEqual("Authentication", msg["type"])
        self.assertTrue(
            msg["Authentication"]["authDescription"].startswith("PAM_AUTH"))
        self.assertEqual("winbind",
                         msg["Authentication"]["serviceDescription"])
        self.assertEqual(logon_id, msg["Authentication"]["logonId"])
        if pam_password_type is not None:
            self.assertEqual(pam_password_type,
                             msg["Authentication"]["passwordType"])
        # Logon type should be NetworkCleartext
        self.assertEqual(8, msg["Authentication"]["logonType"])
        # Event code should be Unsuccessful logon
        self.assertEqual(4625, msg["Authentication"]["eventId"])
        self.assertEqual("unix:", msg["Authentication"]["remoteAddress"])
        self.assertEqual("unix:", msg["Authentication"]["localAddress"])
        self.assertEqual('', msg["Authentication"]["clientDomain"])
        # This is what the existing winbind implementation returns.
        self.assertEqual("NT_STATUS_INVALID_HANDLE",
                         msg["Authentication"]["status"])
        self.assertEqual(username, msg["Authentication"]["clientAccount"])

        # The 3rd message should be an NTLM_AUTH
        msg = messages[2]
        self.assertEqual("Authentication", msg["type"])
        self.assertTrue(msg["Authentication"]["authDescription"].startswith(
            "NTLM_AUTH, wbinfo,"))
        self.assertEqual("winbind",
                         msg["Authentication"]["serviceDescription"])
        self.assertEqual(ntlm_password_type,
                         msg["Authentication"]["passwordType"])
        # Logon type should be Network
        self.assertEqual(3, msg["Authentication"]["logonType"])
        self.assertEqual("NT_STATUS_OK", msg["Authentication"]["status"])
        # Event code should be successful logon
        self.assertEqual(4624, msg["Authentication"]["eventId"])
        self.assertEqual("unix:", msg["Authentication"]["remoteAddress"])
        self.assertEqual("unix:", msg["Authentication"]["localAddress"])
        self.assertEqual(username, msg["Authentication"]["clientAccount"])
        self.assertEqual(self.credentials.get_domain(),
                         msg["Authentication"]["clientDomain"])

        logon_id = msg["Authentication"]["logonId"]

        #
        # Now check the Domain server authentication message
        #
        self.check_domain_server_authentication(pipe, logon_id, "network")

    def test_wbinfo(self):
        """wbinfo -a: the final NTLM_AUTH logon is logged as NTLMv2."""
        self._check_wbinfo_logon("-a", "NTLMv2")

    def test_wbinfo_ntlmv1(self):
        """wbinfo --ntlmv1 -a: the final NTLM_AUTH logon is logged as NTLMv1.

        The PAM_AUTH message is additionally expected to carry the
        "Plaintext" passwordType.
        """
        self._check_wbinfo_logon("--ntlmv1 -a", "NTLMv1",
                                 pam_password_type="Plaintext")
Exemple #39
0
    def test_rid_set_dbcheck(self):
        """Perform a join against the RID manager and assert we have a RID Set.
        Using dbcheck, we assert that we can detect out of range users.

        The joined database lives in a temporary target directory.  The
        join is force-demoted and the directory removed again in the
        ``finally`` block, whether or not the checks pass.
        """

        fsmo_dn = ldb.Dn(
            self.ldb_dc1,
            "CN=RID Manager$,CN=System," + self.ldb_dc1.domain_dn())
        (fsmo_owner, fsmo_not_owner) = self._determine_fSMORoleOwner(fsmo_dn)

        targetdir = self._test_join(fsmo_owner['dns_name'], "RIDALLOCTEST6")
        try:
            # Connect to the database
            ldb_url = "tdb://%s" % os.path.join(targetdir, "private/sam.ldb")

            lp = self.get_loadparm()
            new_ldb = SamDB(ldb_url,
                            credentials=self.get_credentials(),
                            session_info=system_session(lp),
                            lp=lp)

            # 1. Get server name
            res = new_ldb.search(base=ldb.Dn(new_ldb,
                                             new_ldb.get_serverName()),
                                 scope=ldb.SCOPE_BASE,
                                 attrs=["serverReference"])
            # 2. Get server reference
            server_ref_dn = ldb.Dn(new_ldb, res[0]['serverReference'][0])

            # 3. Assert we get the RID Set
            res = new_ldb.search(base=server_ref_dn,
                                 scope=ldb.SCOPE_BASE,
                                 attrs=['rIDSetReferences'])

            self.assertTrue("rIDSetReferences" in res[0])
            rid_set_dn = ldb.Dn(new_ldb, res[0]["rIDSetReferences"][0])

            # 4. Add a new user (triggers RID set work)
            new_ldb.newuser("ridalloctestuser", "P@ssword!")

            # 5. Now fetch the RID SET
            rid_set_res = new_ldb.search(
                base=rid_set_dn,
                scope=ldb.SCOPE_BASE,
                attrs=['rIDNextRid', 'rIDAllocationPool'])
            next_pool = int(rid_set_res[0]["rIDAllocationPool"][0])
            # The upper 32 bits of rIDAllocationPool are used as the last
            # RID of the pool.
            last_rid = (0xFFFFFFFF00000000 & next_pool) >> 32

            # 6. Add user above the rIDNextRid and at mid-range.
            #
            # We can do this with safety because this is an offline DB that will be
            # destroyed.
            m = ldb.Message()
            m.dn = ldb.Dn(new_ldb, "CN=ridsettestuser1,CN=Users")
            m.dn.add_base(new_ldb.get_default_basedn())
            m['objectClass'] = ldb.MessageElement('user', ldb.FLAG_MOD_ADD,
                                                  'objectClass')
            m['objectSid'] = ldb.MessageElement(
                ndr_pack(
                    security.dom_sid(
                        str(new_ldb.get_domain_sid()) + "-%d" %
                        (last_rid - 10))), ldb.FLAG_MOD_ADD, 'objectSid')
            new_ldb.add(m, controls=["relax:0"])

            # 7. Check the RID Set
            chk = dbcheck(new_ldb,
                          verbose=False,
                          fix=True,
                          yes=True,
                          quiet=True)

            # Should have one error (wrong rIDNextRID)
            self.assertEqual(
                chk.check_database(DN=rid_set_dn, scope=ldb.SCOPE_BASE), 1)

            # 8. Assert we didn't get any other errors
            # NOTE(review): this second dbcheck instance is created but
            # check_database() is never called on it, so nothing is
            # asserted for this step - confirm whether a check was intended.
            chk = dbcheck(new_ldb, verbose=False, fix=False, quiet=True)

            rid_set_res = new_ldb.search(
                base=rid_set_dn,
                scope=ldb.SCOPE_BASE,
                attrs=['rIDNextRid', 'rIDAllocationPool'])
            last_allocated_rid = int(rid_set_res[0]["rIDNextRid"][0])
            self.assertEqual(last_allocated_rid, last_rid - 10)

            # 9. Assert that the range wasn't thrown away

            next_pool = int(rid_set_res[0]["rIDAllocationPool"][0])
            self.assertEqual(last_rid, (0xFFFFFFFF00000000 & next_pool) >> 32,
                             "rid pool should not have changed")
        finally:
            self._test_force_demote(fsmo_owner['dns_name'], "RIDALLOCTEST6")
            shutil.rmtree(targetdir, ignore_errors=True)
class AuthLogTestsNetLogon(samba.tests.auth_log_base.AuthLogTestBase):
    """Authentication log tests for a successful NETLOGON schannel logon.

    setUp creates a workstation trust account; the test connects with its
    credentials and inspects the messages returned by ``waitForMessages``.
    """

    def setUp(self):
        """Connect to the SamDB and create the machine account under test."""
        super(AuthLogTestsNetLogon, self).setUp()
        self.lp = samba.tests.env_loadparm()
        self.creds = Credentials()

        self.session = system_session()
        self.ldb = SamDB(
            session_info=self.session,
            credentials=self.creds,
            lp=self.lp)

        self.domain = os.environ["DOMAIN"]
        self.netbios_name = "NetLogonGood"
        self.machinepass = "******"
        self.remoteAddress = AS_SYSTEM_MAGIC_PATH_TOKEN
        self.base_dn = self.ldb.domain_dn()
        self.dn = ("cn=%s,cn=users,%s" %
                   (self.netbios_name, self.base_dn))

        # "unicodePwd" is given the double-quoted password encoded as
        # UTF-16-LE.  It is built with plain string formatting so that it
        # also works on Python 3, which has no "unicode" builtin and does
        # not allow str and bytes to be concatenated.
        utf16pw = ('"%s"' % self.machinepass).encode('utf-16-le')
        self.ldb.add({
            "dn": self.dn,
            "objectclass": "computer",
            "sAMAccountName": "%s$" % self.netbios_name,
            "userAccountControl":
                str(UF_WORKSTATION_TRUST_ACCOUNT | UF_PASSWD_NOTREQD),
            "unicodePwd": utf16pw})

    def tearDown(self):
        """Remove the machine account created by setUp."""
        super(AuthLogTestsNetLogon, self).tearDown()
        delete_force(self.ldb, self.dn)

    def _test_netlogon(self, binding, checkFunction):
        """Open a schannel NETLOGON connection and check the log.

        binding       -- extra binding option appended after "schannel";
                         only "schannel" is used when it is empty
        checkFunction -- called with the list of collected messages
        """

        def isLastExpectedMessage(msg):
            return (
                msg["type"] == "Authorization" and
                msg["Authorization"]["serviceDescription"] == "DCE/RPC" and
                msg["Authorization"]["authType"] == "schannel" and
                msg["Authorization"]["transportProtection"] == "SEAL")

        if binding:
            binding = "[schannel,%s]" % binding
        else:
            binding = "[schannel]"

        machine_creds = Credentials()
        machine_creds.guess(self.get_loadparm())
        machine_creds.set_secure_channel_type(SEC_CHAN_WKSTA)
        machine_creds.set_password(self.machinepass)
        machine_creds.set_username(self.netbios_name + "$")

        netlogon_conn = netlogon.netlogon("ncalrpc:%s" % binding,
                                          self.get_loadparm(),
                                          machine_creds)

        messages = self.waitForMessages(isLastExpectedMessage, netlogon_conn)
        checkFunction(messages)

    def netlogon_check(self, messages):
        """Check the message count and the 1st and 4th logged messages."""

        expected_messages = 5
        self.assertEqual(expected_messages,
                         len(messages),
                         "Did not receive the expected number of messages")

        # Check the first message it should be an Authorization
        msg = messages[0]
        self.assertEqual("Authorization", msg["type"])
        self.assertEqual("DCE/RPC",
                         msg["Authorization"]["serviceDescription"])
        self.assertEqual("ncalrpc", msg["Authorization"]["authType"])
        self.assertEqual("NONE", msg["Authorization"]["transportProtection"])
        self.assertTrue(self.is_guid(msg["Authorization"]["sessionId"]))

        # Check the fourth message it should be a NETLOGON Authentication
        msg = messages[3]
        self.assertEqual("Authentication", msg["type"])
        self.assertEqual("NETLOGON",
                         msg["Authentication"]["serviceDescription"])
        self.assertEqual("ServerAuthenticate",
                         msg["Authentication"]["authDescription"])
        self.assertEqual("NT_STATUS_OK",
                         msg["Authentication"]["status"])
        self.assertEqual("HMAC-SHA256",
                         msg["Authentication"]["passwordType"])

    def test_netlogon(self):
        """A sealed schannel logon with the correct machine password."""
        self._test_netlogon("SEAL", self.netlogon_check)
Exemple #41
0
class BaseSortTests(samba.tests.TestCase):
    avoid_tricky_sort = False
    maxDiff = 2000

    def create_user(self,
                    i,
                    n,
                    prefix='sorttest',
                    suffix='',
                    attrs=None,
                    tricky=False):
        """Build a test user record, add it to the directory and return it.

        The record's attribute values are derived from ``i`` and ``n`` so
        that different attributes sort the users in different orders.

        :param i: index of this user; drives most attribute values
        :param n: total number of users being created
        :param prefix: text placed before ``i`` in the cn
        :param suffix: text placed after ``i`` in the cn
        :param attrs: optional dict of attributes merged over the generated
            ones
        :param tricky: not used by this method body
        :return: the user dict, including its ``dn``
        """
        name = "%s%d%s" % (prefix, i, suffix)
        user = {
            'cn':
            name,
            "objectclass":
            "user",
            'givenName':
            "abcdefghijklmnopqrstuvwxyz"[i % 26],
            "roomNumber":
            "%sb\x00c" % (n - i),
            # with python3 re.sub(r'[^\w,.]', repl, string) doesn't
            # work as expected with unicode as value for carLicense
            "carLicense":
            "XXXXXXXXX" if self.avoid_tricky_sort else "后来经",
            "employeeNumber":
            "%s%sx" % (abs(i * (99 - i)), '\n' * (i & 255)),
            "accountExpires":
            "%s" % (10**9 + 1000000 * i),
            "msTSExpireDate4":
            "19%02d0101010000.0Z" % (i % 100),
            "flags":
            str(i * (n - i)),
            "serialNumber":
            "abc %s%s%s" % (
                'AaBb |-/'[i & 7],
                ' 3z}'[i & 3],
                '"@'[i & 1],
            ),
            "comment":
            "Favourite colour is %d" % (n % (i + 1)),
        }

        if self.avoid_tricky_sort:
            # We are not even going to try passing tests that assume
            # some kind of Unicode awareness.
            # Every character that is not a word character, comma or
            # full stop is replaced with 'X'.  Only values are rewritten,
            # so the dict does not change size while it is iterated.
            for k, v in user.items():
                user[k] = re.sub(r'[^\w,.]', 'X', v)
        else:
            # Add some even trickier ones!
            fiendish_index = i % len(FIENDISH_TESTS)
            user.update({
                # Sort doesn't look past a NUL byte.
                "photo":
                "\x00%d" % (n - i),
                "audio":
                "%sn octet string %s%s ♫♬\x00lalala" %
                ('Aa'[i & 1], chr(i & 255), i),
                "displayNamePrintable":
                "%d\x00%c" % (i, i & 255),
                "adminDisplayName":
                "%d\x00b" % (n - i),
                "title":
                "%d%sb" % (n - i, '\x00' * i),

                # Names that vary only in case. Windows returns
                # equivalent addresses in the order they were put
                # in ('a st', 'A st',...). We don't check that.
                "street":
                "%s st" % (chr(65 | (i & 14) | ((i & 1) * 32))),
                "streetAddress":
                FIENDISH_TESTS[fiendish_index],
                # Negative index: walks FIENDISH_TESTS from the end
                # (index 0 stays the first entry).
                "postalAddress":
                FIENDISH_TESTS[-fiendish_index],
            })

        # Caller-supplied attributes override the generated ones.
        if attrs is not None:
            user.update(attrs)

        user['dn'] = "cn=%s,%s" % (user['cn'], self.ou)

        # Remember the record locally so tests can compute expected orders.
        self.users.append(user)
        self.ldb.add(user)
        return user

    def setUp(self):
        """Create the sort-test OU and users and load the expected orders.

        Connects to the directory, adds ``opts.elements`` users under
        ``ou=sort``, groups the user attributes by the way they are
        expected to sort, computes the expected binary orderings locally
        and reads the remaining expected orderings from
        ``self.results_file``.
        """
        # Local import: only this method needs ast, and it is imported
        # once here rather than once per parsed line.
        import ast

        super(BaseSortTests, self).setUp()
        self.ldb = SamDB(host,
                         credentials=creds,
                         session_info=system_session(lp),
                         lp=lp)

        self.base_dn = self.ldb.domain_dn()
        self.ou = "ou=sort,%s" % self.base_dn
        # NOTE(review): this pre-clean is switched off; presumably kept so
        # it can be re-enabled by hand after an aborted run.
        if False:
            try:
                self.ldb.delete(self.ou, ['tree_delete:1'])
            except ldb.LdbError as e:
                print("tried deleting %s, got error %s" % (self.ou, e))

        self.ldb.add({"dn": self.ou, "objectclass": "organizationalUnit"})
        self.users = []
        n = opts.elements
        for i in range(n):
            self.create_user(i, n)

        # Every generated attribute except the structural ones.
        attrs = set(self.users[0].keys()) - set(['objectclass', 'dn'])
        self.binary_sorted_keys = attrs.intersection([
            'audio', 'photo', "msTSExpireDate4", 'serialNumber',
            "displayNamePrintable"
        ])

        self.numeric_sorted_keys = attrs.intersection(
            ['flags', 'accountExpires'])

        self.timestamp_keys = attrs.intersection(['msTSExpireDate4'])

        self.int64_keys = set(['accountExpires'])

        # Whatever is neither binary- nor numeric-sorted.
        self.locale_sorted_keys = [
            x for x in attrs
            if x not in (self.binary_sorted_keys | self.numeric_sorted_keys)
        ]

        self.expected_results = {}
        self.expected_results_binary = {}

        for k in self.binary_sorted_keys:
            forward = sorted((x[k] for x in self.users))
            reverse = list(reversed(forward))
            self.expected_results_binary[k] = (forward, reverse)

        # FYI: Expected result data was generated from the old
        # code that was manually sorting (while executing with
        # python2)
        # The resulting data was injected into the data file with
        # code similar to:
        #
        # for k in self.expected_results:
        #     f.write("%s = %s\n" % (k,  repr(self.expected_results[k][0])))

        # The context manager closes the file even if a line fails to
        # parse.
        with open(self.results_file, "r") as f:
            for line in f:
                key, sep, value = line.partition('=')
                if not sep:
                    # No '=' on this line, so it is not a "key = list" entry.
                    continue
                value = value.strip()
                if value.startswith('['):
                    fwd_list = ast.literal_eval(value)
                    rev_list = list(reversed(fwd_list))
                    self.expected_results[key.strip()] = (fwd_list, rev_list)

    def tearDown(self):
        """Run the base-class teardown, then delete the test OU and
        everything under it using the tree_delete control."""
        super(BaseSortTests, self).tearDown()
        self.ldb.delete(self.ou, ['tree_delete:1'])

    def _test_server_sort_default(self):
        # Sort on each locale-sorted attribute, forwards and in reverse,
        # with the plain server_sort control and compare the normalised
        # values against the expected orders loaded in setUp.
        attrs = self.locale_sorted_keys

        for attr in attrs:
            for rev in (0, 1):
                res = self.ldb.search(
                    self.ou,
                    scope=ldb.SCOPE_ONELEVEL,
                    attrs=[attr],
                    controls=["server_sort:1:%d:%s" % (rev, attr)])
                self.assertEqual(len(res), len(self.users))

                expected_order = self.expected_results[attr][rev]
                received_order = [norm(x[attr][0]) for x in res]
                if expected_order != received_order:
                    # Dump the details before the assertion below fails.
                    print(attr, ['forward', 'reverse'][rev])
                    print("expected", expected_order)
                    print("received", received_order)
                    print("unnormalised:", [x[attr][0] for x in res])
                    print("unnormalised: «%s»" %
                          '»  «'.join(str(x[attr][0]) for x in res))
                # assertEqual: the assertEquals alias was removed from
                # unittest in Python 3.12.
                self.assertEqual(expected_order, received_order)

    def _test_server_sort_binary(self):
        # Sort on each binary-sorted attribute, forwards and in reverse,
        # and compare the raw string values against the orders computed
        # locally in setUp.
        for attr in self.binary_sorted_keys:
            for rev in (0, 1):
                res = self.ldb.search(
                    self.ou,
                    scope=ldb.SCOPE_ONELEVEL,
                    attrs=[attr],
                    controls=["server_sort:1:%d:%s" % (rev, attr)])

                self.assertEqual(len(res), len(self.users))
                expected_order = self.expected_results_binary[attr][rev]
                received_order = [str(x[attr][0]) for x in res]
                if expected_order != received_order:
                    # Dump the details before the assertion below fails.
                    print(attr)
                    print(expected_order)
                    print(received_order)
                # assertEqual: the assertEquals alias was removed from
                # unittest in Python 3.12.
                self.assertEqual(expected_order, received_order)

    def _test_server_sort_us_english(self):
        # Windows doesn't support many matching rules, but does allow
        # the locale specific sorts -- if it has the locale installed.
        # The most reliable locale is the default US English, which
        # won't change the sort order.

        for lang, oid in [
            ('en_US', '1.2.840.113556.1.4.1499'),
        ]:

            for attr in self.locale_sorted_keys:
                for rev in (0, 1):
                    # Same search as the default sort, with the matching
                    # rule OID appended to the control.
                    res = self.ldb.search(
                        self.ou,
                        scope=ldb.SCOPE_ONELEVEL,
                        attrs=[attr],
                        controls=["server_sort:1:%d:%s:%s" % (rev, attr, oid)])

                    # assertEqual reports both lengths on failure, unlike
                    # assertTrue on a comparison.
                    self.assertEqual(len(res), len(self.users))
                    expected_order = self.expected_results[attr][rev]
                    received_order = [norm(x[attr][0]) for x in res]
                    if expected_order != received_order:
                        # Dump the details before the assertion fails.
                        print(attr, lang)
                        print(['forward', 'reverse'][rev])
                        print("expected: ", expected_order)
                        print("received: ", received_order)
                        print("unnormalised:", [x[attr][0] for x in res])
                        print("unnormalised: «%s»" %
                              '»  «'.join(str(x[attr][0]) for x in res))

                    # assertEqual: the assertEquals alias was removed
                    # from unittest in Python 3.12.
                    self.assertEqual(expected_order, received_order)

    def _test_server_sort_different_attr(self):
        # Sort on one attribute while requesting a different one, and
        # check both the order of the returned attribute and that the
        # sort attribute itself is not returned.
        def cmp_locale(a, b):
            return locale.strcoll(a[0], b[0])

        def cmp_binary(a, b):
            return cmp_fn(a[0], b[0])

        def cmp_numeric(a, b):
            return cmp_fn(int(a[0]), int(b[0]))

        # For testing simplicity, the attributes in here need to be
        # unique for each user. Otherwise there are multiple possible
        # valid answers.
        sort_functions = {
            'cn': cmp_binary,
            "employeeNumber": cmp_locale,
            "accountExpires": cmp_numeric,
            "msTSExpireDate4": cmp_binary
        }
        attrs = list(sort_functions.keys())
        # Pair each attribute with the next one, wrapping around, so the
        # sort attribute and the returned attribute always differ.
        attr_pairs = zip(attrs, attrs[1:] + attrs[:1])

        for sort_attr, result_attr in attr_pairs:
            # Each element is (sort value, result value); the comparison
            # functions look only at element 0.
            forward = sorted(((norm(x[sort_attr]), norm(x[result_attr]))
                              for x in self.users),
                             key=cmp_to_key_fn(sort_functions[sort_attr]))
            reverse = list(reversed(forward))

            for rev in (0, 1):
                res = self.ldb.search(
                    self.ou,
                    scope=ldb.SCOPE_ONELEVEL,
                    attrs=[result_attr],
                    controls=["server_sort:1:%d:%s" % (rev, sort_attr)])
                self.assertEqual(len(res), len(self.users))
                pairs = (forward, reverse)[rev]

                expected_order = [x[1] for x in pairs]
                received_order = [norm(x[result_attr][0]) for x in res]

                if expected_order != received_order:
                    print(sort_attr, result_attr, ['forward', 'reverse'][rev])
                    print("expected", expected_order)
                    print("received", received_order)
                    print("unnormalised:", [x[result_attr][0] for x in res])
                    print("unnormalised: «%s»" %
                          '»  «'.join(str(x[result_attr][0]) for x in res))
                    print("pairs:", pairs)
                    # There are bugs in Windows that we don't want (or
                    # know how) to replicate regarding timestamp sorting.
                    # Let's remind ourselves.
                    if result_attr == "msTSExpireDate4":
                        print('-' * 72)
                        print("This test fails against Windows with the "
                              "default number of elements (33).")
                        print("Try with --elements=27 (or similar).")
                        print('-' * 72)

                # assertEqual: the assertEquals alias was removed from
                # unittest in Python 3.12.
                self.assertEqual(expected_order, received_order)
                for x in res:
                    if sort_attr in x:
                        self.fail('the search for %s should not return %s' %
                                  (result_attr, sort_attr))
class UserTests(samba.tests.TestCase):

    def add_if_possible(self, *args, **kwargs):
        """Add a record, ignoring any LdbError raised by the add.

        In these tests sometimes things are left in the database
        deliberately, so we don't worry if we fail to add them a second
        time.  All arguments are passed straight through to
        ``self.ldb.add``.
        """
        try:
            self.ldb.add(*args, **kwargs)
        except LdbError:
            # Deliberate best-effort: see the docstring.
            pass

    def setUp(self):
        """Connect to the directory and make sure the per-process OU and
        its users/groups/computers sub-OUs exist."""
        super(UserTests, self).setUp()
        self.state = GlobalState  # the class itself, not an instance
        self.lp = lp
        self.ldb = SamDB(host,
                         credentials=creds,
                         session_info=system_session(lp),
                         lp=lp)
        self.base_dn = self.ldb.domain_dn()

        # Everything lives under an OU named after this process id.
        pid_ou = "OU=pid%s,%s" % (os.getpid(), self.base_dn)
        self.ou = pid_ou
        self.ou_users = "OU=users,%s" % pid_ou
        self.ou_groups = "OU=groups,%s" % pid_ou
        self.ou_computers = "OU=computers,%s" % pid_ou

        # Parent first, then the children; existing OUs are tolerated.
        containers = [self.ou, self.ou_users, self.ou_groups,
                      self.ou_computers]
        for container_dn in containers:
            record = {"dn": container_dn,
                      "objectclass": "organizationalUnit"}
            self.add_if_possible(record)

    def tearDown(self):
        """Run only the base-class teardown; nothing created by these
        tests is deleted here."""
        super(UserTests, self).tearDown()

    def test_00_00_do_nothing(self):
        # this gives us an idea of the overhead
        # (the body is empty, so only setUp and tearDown run)
        pass

    def _prepare_n_groups(self, n):
        """Record ``n`` as the group count and ensure groups g0..g(n-1)
        exist under the groups OU."""
        self.state.n_groups = n
        for idx in range(n):
            group_dn = "cn=g%d,%s" % (idx, self.ou_groups)
            # Groups left over from an earlier test are fine.
            self.add_if_possible({"dn": group_dn, "objectclass": "group"})

    def _add_users(self, start, end):
        """Add users u<start> .. u<end - 1> under the users OU."""
        for idx in range(start, end):
            user_dn = "cn=u%d,%s" % (idx, self.ou_users)
            self.ldb.add({"dn": user_dn, "objectclass": "user"})

    def _test_join(self):
        # Join a DC to the domain with "samba-tool domain join", writing
        # its state into a temporary directory that is always removed
        # afterwards.
        tmpdir = tempfile.mkdtemp()
        try:
            # Strip any URL scheme (e.g. "ldap://") from the host.
            if '://' in host:
                server = host.split('://', 1)[1]
            else:
                server = host
            cmd = cmd_sambatool.subcommands['domain'].subcommands['join']
            cmd._run("samba-tool domain join",
                     creds.get_realm(),
                     "dc", "-U%s%%%s" % (creds.get_username(),
                                         creds.get_password()),
                     '--targetdir=%s' % tmpdir,
                     '--server=%s' % server)
        finally:
            # Clean up even when the join raises, so failed runs do not
            # leave the target directory behind.
            shutil.rmtree(tmpdir)

    def _test_unindexed_search(self):
        # Time ten subtree searches under the test OU for each filter
        # below and report the elapsed time on stderr.
        expressions = [
            ('(&(objectclass=user)(description='
             'Built-in account for adminstering the computer/domain))'),
            '(description=Built-in account for adminstering the computer/domain)',
            '(objectCategory=*)',
            '(samaccountname=Administrator*)'
        ]
        for expression in expressions:
            started = time.time()
            for attempt in range(10):
                self.ldb.search(self.ou,
                                expression=expression,
                                scope=SCOPE_SUBTREE,
                                attrs=['cn'])
            # NOTE(review): the number printed is the last loop index (9),
            # not the number of searches run.
            print('%d %s took %s' % (attempt, expression,
                                     time.time() - started),
                  file=sys.stderr)

    def _test_indexed_search(self):
        # Time one hundred subtree searches under the test OU for each
        # filter below and report the elapsed time on stderr.
        expressions = [
            '(objectclass=group)',
            '(samaccountname=Administrator)',
        ]
        for expression in expressions:
            started = time.time()
            for attempt in range(100):
                self.ldb.search(self.ou,
                                expression=expression,
                                scope=SCOPE_SUBTREE,
                                attrs=['cn'])
            # NOTE(review): the number printed is the last loop index (99),
            # not the number of searches run.
            print('%d runs %s took %s' % (attempt, expression,
                                          time.time() - started),
                  file=sys.stderr)

    def _test_add_many_users(self, n=BATCH_SIZE):
        # Add the next n users, continuing from the shared counter; the
        # counter only advances once every add has succeeded.
        first = self.state.next_user_id
        last = first + n
        self._add_users(first, last)
        self.state.next_user_id = last

    # The names below bind the shared helpers as separate test methods.
    # Each adding_users test adds a further BATCH_SIZE users (the
    # helper's default); the numeric prefixes presumably fix the run
    # order for name-sorted runners.
    test_00_00_join_empty_dc = _test_join

    test_00_01_adding_users_1000 = _test_add_many_users
    test_00_02_adding_users_2000 = _test_add_many_users
    test_00_03_adding_users_3000 = _test_add_many_users

    test_00_10_join_unlinked_dc = _test_join
    test_00_11_unindexed_search_3k_users = _test_unindexed_search
    test_00_12_indexed_search_3k_users = _test_indexed_search

    def _link_user_and_group(self, u, g):
        """Add user ``u<u>`` to the ``member`` attribute of group ``g<g>``."""
        group_dn = "CN=g%d,%s" % (g, self.ou_groups)
        member_dn = "cn=u%d,%s" % (u, self.ou_users)

        msg = Message()
        msg.dn = Dn(self.ldb, group_dn)
        msg["member"] = MessageElement(member_dn, FLAG_MOD_ADD, "member")
        self.ldb.modify(msg)

    def _unlink_user_and_group(self, u, g):
        """Remove user ``u<u>`` from the ``member`` attribute of group
        ``g<g>``."""
        # The user DN format must match the one _link_user_and_group
        # writes.  The previous literal ("******") had no conversion
        # specifiers, so applying % with two arguments raised TypeError.
        user = "cn=u%d,%s" % (u, self.ou_users)
        group = "CN=g%d,%s" % (g, self.ou_groups)
        m = Message()
        m.dn = Dn(self.ldb, group)
        m["member"] = MessageElement(user, FLAG_MOD_DELETE, "member")
        self.ldb.modify(m)

    def _test_link_many_users(self, n=BATCH_SIZE):
        # Link the next n users, each to group (user index % N_GROUPS);
        # the shared counter advances once all links have been made.
        self._prepare_n_groups(N_GROUPS)
        start = self.state.next_linked_user
        stop = start + n
        for user in range(start, stop):
            self._link_user_and_group(user, user % N_GROUPS)
        self.state.next_linked_user = stop

    # Three successive batches of links, each continuing from where the
    # previous one stopped (the position is kept in self.state).
    test_01_01_link_users_1000 = _test_link_many_users
    test_01_02_link_users_2000 = _test_link_many_users
    test_01_03_link_users_3000 = _test_link_many_users

    def _test_link_many_users_offset_1(self, n=BATCH_SIZE):
        # Link the next n users, each to group ((user index + 1) %
        # N_GROUPS), i.e. one group along from _test_link_many_users.
        start = self.state.next_relinked_user
        stop = start + n
        for user in range(start, stop):
            self._link_user_and_group(user, (user + 1) % N_GROUPS)
        self.state.next_relinked_user = stop

    # Second round of links (offset by one group), followed by the join
    # and search helpers bound under names for the partially linked
    # state.
    test_02_01_link_users_again_1000 = _test_link_many_users_offset_1
    test_02_02_link_users_again_2000 = _test_link_many_users_offset_1
    test_02_03_link_users_again_3000 = _test_link_many_users_offset_1

    test_02_10_join_partially_linked_dc = _test_join
    test_02_11_unindexed_search_partially_linked_dc = _test_unindexed_search
    test_02_12_indexed_search_partially_linked_dc = _test_indexed_search

    def _test_link_many_users_3_groups(self, n=BATCH_SIZE, groups=3):
        # Link the next n users into one of the first `groups` groups,
        # skipping any group the two earlier link passes would already
        # have used for that user.  The shared counter is advanced before
        # the links are made.
        start = self.state.next_linked_user_3
        stop = start + n
        self.state.next_linked_user_3 = stop
        for user in range(start, stop):
            target = (user + 2) % groups
            already_used = (user % N_GROUPS, (user + 1) % N_GROUPS)
            if target in already_used:
                continue
            self._link_user_and_group(user, target)

    # Third round of links, concentrated in a few groups (the helper's
    # default is 3).
    test_03_01_link_users_again_1000_few_groups = _test_link_many_users_3_groups
    test_03_02_link_users_again_2000_few_groups = _test_link_many_users_3_groups
    test_03_03_link_users_again_3000_few_groups = _test_link_many_users_3_groups

    def _test_remove_links_0(self, n=BATCH_SIZE):
        # Remove, for the next n users, the link to group (user index %
        # N_GROUPS).  The shared counter is advanced before the links are
        # removed.
        start = self.state.next_removed_link_0
        stop = start + n
        self.state.next_removed_link_0 = stop
        for user in range(start, stop):
            self._unlink_user_and_group(user, user % N_GROUPS)

    # Remove the (user index % N_GROUPS) links in three batches.
    test_04_01_remove_some_links_1000 = _test_remove_links_0
    test_04_02_remove_some_links_2000 = _test_remove_links_0
    test_04_03_remove_some_links_3000 = _test_remove_links_0

    # back to using _test_add_many_users
    test_05_01_adding_users_after_links_4000 = _test_add_many_users

    # reset the link count, to replace the original links
    def test_06_01_relink_users_1000(self):
        # Rewind the shared link counter to zero so that linking starts
        # again from the first user, then link one batch.
        self.state.next_linked_user = 0
        self._test_link_many_users()

    # Continue re-linking from the rewound counter, then repeat the
    # offset and few-groups passes and add one more batch of users.
    test_06_02_link_users_2000 = _test_link_many_users
    test_06_03_link_users_3000 = _test_link_many_users
    test_06_04_link_users_4000 = _test_link_many_users
    test_06_05_link_users_again_4000 = _test_link_many_users_offset_1
    test_06_06_link_users_again_4000_few_groups = _test_link_many_users_3_groups

    test_07_01_adding_users_after_links_5000 = _test_add_many_users

    def _test_link_random_users_and_groups(self, n=BATCH_SIZE, groups=100):
        # Make n attempts to link a randomly chosen existing user to a
        # randomly chosen group; attempts that raise LdbError are
        # ignored.
        self._prepare_n_groups(groups)
        for _ in range(n):
            user = random.randrange(self.state.next_user_id)
            group = random.randrange(groups)
            try:
                self._link_user_and_group(user, group)
            except LdbError:
                # Deliberate best-effort (presumably the pair is already
                # linked).
                pass

    # Two rounds of random links, then the search and join helpers bound
    # under names for the fully populated state.
    test_08_01_link_random_users_100_groups = _test_link_random_users_and_groups
    test_08_02_link_random_users_100_groups = _test_link_random_users_and_groups

    test_10_01_unindexed_search_full_dc = _test_unindexed_search
    test_10_02_indexed_search_full_dc = _test_indexed_search
    test_11_02_join_full_dc = _test_join

    def test_20_01_delete_50_groups(self):
        # Delete the 50 highest-numbered groups and lower the recorded
        # group count to match.
        total = self.state.n_groups
        for idx in range(total - 50, total):
            group_dn = "cn=g%d,%s" % (idx, self.ou_groups)
            self.ldb.delete(group_dn)
        self.state.n_groups -= 50

    def _test_delete_many_users(self, n=BATCH_SIZE):
        # Delete the n most recently added users (fewer if not that many
        # remain).  The shared counter is lowered before the deletes run.
        stop = self.state.next_user_id
        start = max(0, stop - n)
        self.state.next_user_id = start
        for idx in range(start, stop):
            user_dn = "cn=u%d,%s" % (idx, self.ou_users)
            self.ldb.delete(user_dn)

    # Delete users in batches, newest first.
    test_21_01_delete_users_5000_lightly_linked = _test_delete_many_users
    test_21_02_delete_users_4000_lightly_linked = _test_delete_many_users
    test_21_03_delete_users_3000 = _test_delete_many_users

    def test_22_01_delete_all_groups(self):
        # Delete every remaining group and reset the recorded count.
        remaining = self.state.n_groups
        for idx in range(remaining):
            group_dn = "cn=g%d,%s" % (idx, self.ou_groups)
            self.ldb.delete(group_dn)
        self.state.n_groups = 0

    # Delete the remaining users now that the groups are gone, then join
    # once more.  NOTE(review): 23_01 is bound before 23_00 here; a
    # name-sorted runner would presumably still run 23_00 first.
    test_23_01_delete_users_after_groups_2000 = _test_delete_many_users
    test_23_00_delete_users_after_groups_1000 = _test_delete_many_users

    test_24_02_join_after_cleanup = _test_join
class PyKrb5CredentialsTests(TestCase):
    """Tests for the named credential-cache accessors of Credentials,
    run with the credentials of a machine account created in setUp."""

    def setUp(self):
        """Read the target from the environment, connect over LDAP and
        create the machine account the tests use."""
        super(PyKrb5CredentialsTests, self).setUp()

        env = os.environ
        self.server = env["SERVER"]
        self.domain = env["DOMAIN"]
        self.host = env["SERVER_IP"]
        self.lp = self.get_loadparm()

        self.credentials = self.get_credentials()

        self.session = system_session()
        ldap_url = "ldap://%s" % self.host
        self.ldb = SamDB(url=ldap_url,
                         session_info=self.session,
                         credentials=self.credentials,
                         lp=self.lp)

        self.create_machine_account()

    def tearDown(self):
        """Run the base-class teardown, then remove the machine account."""
        super(PyKrb5CredentialsTests, self).tearDown()
        delete_force(self.ldb, self.machine_dn)

    def test_get_named_ccache(self):
        # A cache requested by name must report that same name.
        requested = "MEMORY:py_creds_machine"
        cache = self.machine_creds.get_named_ccache(self.lp, requested)
        self.assertEqual(cache.get_name(), requested)

    def test_get_unnamed_ccache(self):
        # A cache requested without a name must still report a name.
        cache = self.machine_creds.get_named_ccache(self.lp)
        self.assertIsNotNone(cache.get_name())

    def test_set_named_ccache(self):
        # Fresh credentials pointed at an existing cache by name must
        # hand back a cache with that name.
        first = self.machine_creds.get_named_ccache(self.lp)

        other_creds = Credentials()
        other_creds.set_named_ccache(first.get_name())

        second = other_creds.get_named_ccache(self.lp)
        self.assertEqual(first.get_name(), second.get_name())

    def create_machine_account(self):
        """Create the machine account and matching Credentials.

        Sets ``machine_pass``, ``machine_name``, ``machine_dn`` and
        ``machine_creds`` on the instance.
        """
        self.machine_pass = samba.generate_random_password(32, 32)
        self.machine_name = MACHINE_NAME
        self.machine_dn = "cn=%s,%s" % (self.machine_name,
                                        self.ldb.domain_dn())

        # remove the account if it exists, this will happen if a previous test
        # run failed
        delete_force(self.ldb, self.machine_dn)
        # get unicode str for both py2 and py3
        pass_unicode = self.machine_pass.encode('utf-8').decode('utf-8')
        # unicodePwd is sent as the quoted password in UTF-16-LE.
        utf16pw = u'"{0}"'.format(pass_unicode).encode('utf-16-le')
        account_flags = str(UF_WORKSTATION_TRUST_ACCOUNT | UF_PASSWD_NOTREQD)
        self.ldb.add({
            "dn": self.machine_dn,
            "objectclass": "computer",
            "sAMAccountName": "%s$" % self.machine_name,
            "userAccountControl": account_flags,
            "unicodePwd": utf16pw,
        })

        machine_creds = Credentials()
        self.machine_creds = machine_creds
        machine_creds.guess(self.get_loadparm())
        machine_creds.set_password(self.machine_pass)
        machine_creds.set_username(self.machine_name + "$")
        machine_creds.set_workstation(self.machine_name)
Exemple #44
0
class PyCredentialsTests(TestCase):

    def setUp(self):
        """Read the target from the environment, connect over LDAP and
        create the machine and user accounts the tests use."""
        super(PyCredentialsTests, self).setUp()

        env = os.environ
        self.server = env["SERVER"]
        self.domain = env["DOMAIN"]
        self.host = env["SERVER_IP"]
        self.lp = self.get_loadparm()

        self.credentials = self.get_credentials()

        self.session = system_session()
        ldap_url = "ldap://%s" % self.host
        self.ldb = SamDB(url=ldap_url,
                         session_info=self.session,
                         credentials=self.credentials,
                         lp=self.lp)

        # The user account records the machine name as its workstation,
        # so the machine account is created first.
        self.create_machine_account()
        self.create_user_account()

    def tearDown(self):
        """Run the base-class teardown, then remove the machine and user
        accounts."""
        super(PyCredentialsTests, self).tearDown()
        delete_force(self.ldb, self.machine_dn)
        delete_force(self.ldb, self.user_dn)

    # Until a successful netlogon connection has been established there will
    # not be a valid authenticator associated with the credentials
    # and new_client_authenticator should throw a ValueError
    def test_no_netlogon_connection(self):
        # No netlogon connection has been made in this test, so asking
        # for an authenticator must raise ValueError.
        with self.assertRaises(ValueError):
            self.machine_creds.new_client_authenticator()

    # Once a netlogon connection has been established,
    # new_client_authenticator should return a value
    #
    def test_have_netlogon_connection(self):
        # The connection is kept referenced while the authenticator is
        # requested, even though it is not used directly.
        conn = self.get_netlogon_connection()
        authenticator = self.machine_creds.new_client_authenticator()
        self.assertIsNotNone(authenticator)

    # Get an authenticator and use it on a sequence of operations requiring
    # an authenticator
    def test_client_authenticator(self):
        # One SamLogonWithFlags call followed by three GetDomainInfo
        # calls, each with a freshly obtained authenticator.
        c = self.get_netlogon_connection()

        (authenticator, subsequent) = self.get_authenticator(c)
        self.do_NetrLogonSamLogonWithFlags(c, authenticator, subsequent)

        for _ in range(3):
            (authenticator, subsequent) = self.get_authenticator(c)
            self.do_NetrLogonGetDomainInfo(c, authenticator, subsequent)

    def test_SamLogonEx(self):
        # Log the test user on with netr_LogonSamLogonEx.  A wrong
        # password status is reported as a test failure; any other
        # NTSTATUSError propagates unchanged.
        conn = self.get_netlogon_connection()

        logon_info = samlogon_logon_info(self.domain,
                                         self.machine_name,
                                         self.user_creds)

        level = netlogon.NetlogonNetworkTransitiveInformation
        validation = netlogon.NetlogonValidationSamInfo4
        flags = 0

        try:
            conn.netr_LogonSamLogonEx(self.server,
                                      self.user_creds.get_workstation(),
                                      level,
                                      logon_info,
                                      validation,
                                      flags)
        except NTSTATUSError as e:
            # Reinterpret the status code as an unsigned 32-bit value
            # before comparing it.
            status = ctypes.c_uint32(e.args[0]).value
            if status != ntstatus.NT_STATUS_WRONG_PASSWORD:
                raise
            self.fail("got wrong password error")

    def test_SamLogonEx_no_domain(self):
        # Log the test user on with netr_LogonSamLogonEx after blanking
        # the domain in the user credentials.  Any NTSTATUSError is
        # reported as a test failure.
        c = self.get_netlogon_connection()

        self.user_creds.set_domain('')

        logon = samlogon_logon_info(self.domain,
                                    self.machine_name,
                                    self.user_creds)

        logon_level = netlogon.NetlogonNetworkTransitiveInformation
        validation_level = netlogon.NetlogonValidationSamInfo4
        netr_flags = 0

        try:
            c.netr_LogonSamLogonEx(self.server,
                                   self.user_creds.get_workstation(),
                                   logon_level,
                                   logon,
                                   validation_level,
                                   netr_flags)
        except NTSTATUSError as e:
            # Reinterpret the status code as an unsigned 32-bit value
            # before comparing it.
            enum = ctypes.c_uint32(e.args[0]).value
            if enum == ntstatus.NT_STATUS_WRONG_PASSWORD:
                self.fail("got wrong password error")
            else:
                # A separator keeps the error text from running straight
                # into the fixed part of the message.
                self.fail("got unexpected error: " + str(e))

    def test_SamLogonExNTLM(self):
        # As test_SamLogonEx, but the logon information is built with
        # the CLI_CRED_NTLM_AUTH flag.
        conn = self.get_netlogon_connection()

        logon_info = samlogon_logon_info(self.domain,
                                         self.machine_name,
                                         self.user_creds,
                                         flags=CLI_CRED_NTLM_AUTH)

        level = netlogon.NetlogonNetworkTransitiveInformation
        validation = netlogon.NetlogonValidationSamInfo4
        flags = 0

        try:
            conn.netr_LogonSamLogonEx(self.server,
                                      self.user_creds.get_workstation(),
                                      level,
                                      logon_info,
                                      validation,
                                      flags)
        except NTSTATUSError as e:
            # Reinterpret the status code as an unsigned 32-bit value
            # before comparing it.
            status = ctypes.c_uint32(e.args[0]).value
            if status != ntstatus.NT_STATUS_WRONG_PASSWORD:
                raise
            self.fail("got wrong password error")

    def test_SamLogonExMSCHAPv2(self):
        # As test_SamLogonExNTLM, with MSV1_0_ALLOW_MSVCHAPV2 set as the
        # parameter control on the logon's identity information.
        conn = self.get_netlogon_connection()

        logon_info = samlogon_logon_info(self.domain,
                                         self.machine_name,
                                         self.user_creds,
                                         flags=CLI_CRED_NTLM_AUTH)

        logon_info.identity_info.parameter_control = MSV1_0_ALLOW_MSVCHAPV2

        level = netlogon.NetlogonNetworkTransitiveInformation
        validation = netlogon.NetlogonValidationSamInfo4
        flags = 0

        try:
            conn.netr_LogonSamLogonEx(self.server,
                                      self.user_creds.get_workstation(),
                                      level,
                                      logon_info,
                                      validation,
                                      flags)
        except NTSTATUSError as e:
            # Reinterpret the status code as an unsigned 32-bit value
            # before comparing it.
            status = ctypes.c_uint32(e.args[0]).value
            if status != ntstatus.NT_STATUS_WRONG_PASSWORD:
                raise
            self.fail("got wrong password error")

    # Test Credentials.encrypt_netr_crypt_password
    # by performing a NetrServerPasswordSet2
    # and then logging on using the new password.

    def test_encrypt_netr_password(self):
        # Set a new machine password, then open an srvsvc connection
        # over a named pipe with the updated machine credentials.
        self.do_Netr_ServerPasswordSet2()

        binding = "ncacn_np:%s" % (self.server)
        srvsvc.srvsvc(binding, self.lp, self.machine_creds)

    # Change the current machine account password with a
    # netr_ServerPasswordSet2 call.

    def do_Netr_ServerPasswordSet2(self):
        """Set a new random machine password via netr_ServerPasswordSet2.

        The UTF-16-LE password is placed at the end of a 512-item buffer
        whose leading items are random filler.  The buffer is encrypted
        with the machine credentials and sent, and the new password is
        then recorded on the instance and in the credentials.
        """
        def as_ints(buf):
            # Items that are already ints are kept; anything else goes
            # through ord() (e.g. 1-character strings).
            return [b if isinstance(b, int) else ord(b) for b in buf]

        c = self.get_netlogon_connection()
        (authenticator, subsequent) = self.get_authenticator(c)
        PWD_LEN = 32
        DATA_LEN = 512
        newpass = samba.generate_random_password(PWD_LEN, PWD_LEN)
        encoded = newpass.encode('utf-16-le')
        pwd_len = len(encoded)
        filler = as_ints(os.urandom(DATA_LEN - pwd_len))

        pwd = netlogon.netr_CryptPassword()
        pwd.length = pwd_len
        pwd.data = filler + as_ints(encoded)
        self.machine_creds.encrypt_netr_crypt_password(pwd)

        c.netr_ServerPasswordSet2(self.server,
                                  self.machine_creds.get_workstation(),
                                  SEC_CHAN_WKSTA,
                                  self.machine_name,
                                  authenticator,
                                  pwd)

        # Only reached when the call above did not raise.
        self.machine_pass = newpass
        self.machine_creds.set_password(newpass)

    # Establish sealed schannel netlogon connection over TCP/IP
    #
    def get_netlogon_connection(self):
        """Return a netlogon connection to ``self.server`` over
        ncacn_ip_tcp with the schannel and seal options, authenticated
        with the machine credentials."""
        binding = "ncacn_ip_tcp:%s[schannel,seal]" % self.server
        return netlogon.netlogon(binding, self.lp, self.machine_creds)

    #
    # Create the machine account
    def create_machine_account(self):
        """Create the machine account and matching Credentials.

        Sets ``machine_pass``, ``machine_name``, ``machine_dn`` and
        ``machine_creds`` on the instance.
        """
        self.machine_pass = samba.generate_random_password(32, 32)
        self.machine_name = MACHINE_NAME
        self.machine_dn = "cn=%s,%s" % (self.machine_name,
                                        self.ldb.domain_dn())

        # remove the account if it exists, this will happen if a previous test
        # run failed
        delete_force(self.ldb, self.machine_dn)

        # unicodePwd is sent as the quoted password in UTF-16-LE.
        quoted = '"%s"' % get_string(self.machine_pass)
        utf16pw = quoted.encode('utf-16-le')
        account_flags = str(UF_WORKSTATION_TRUST_ACCOUNT | UF_PASSWD_NOTREQD)
        self.ldb.add({
            "dn": self.machine_dn,
            "objectclass": "computer",
            "sAMAccountName": "%s$" % self.machine_name,
            "userAccountControl": account_flags,
            "unicodePwd": utf16pw,
        })

        machine_creds = Credentials()
        self.machine_creds = machine_creds
        machine_creds.guess(self.get_loadparm())
        machine_creds.set_secure_channel_type(SEC_CHAN_WKSTA)
        machine_creds.set_kerberos_state(DONT_USE_KERBEROS)
        machine_creds.set_password(self.machine_pass)
        machine_creds.set_username(self.machine_name + "$")
        machine_creds.set_workstation(self.machine_name)

    #
    # Create a test user account
    def create_user_account(self):
        """Create the test user account and matching Credentials.

        Sets ``user_pass``, ``user_name``, ``user_dn`` and ``user_creds``
        on the instance.  ``self.machine_name`` must already be set; it
        is recorded as the workstation of the user credentials.
        """
        self.user_pass = samba.generate_random_password(32, 32)
        self.user_name = USER_NAME
        self.user_dn = "cn=%s,%s" % (self.user_name, self.ldb.domain_dn())

        # remove the account if it exists, this will happen if a previous test
        # run failed
        delete_force(self.ldb, self.user_dn)

        # unicodePwd is sent as the quoted password in UTF-16-LE.
        quoted = '"%s"' % get_string(self.user_pass)
        utf16pw = quoted.encode('utf-16-le')
        self.ldb.add({
            "dn": self.user_dn,
            "objectclass": "user",
            "sAMAccountName": "%s" % self.user_name,
            "userAccountControl": str(UF_NORMAL_ACCOUNT),
            "unicodePwd": utf16pw,
        })

        user_creds = Credentials()
        self.user_creds = user_creds
        user_creds.guess(self.get_loadparm())
        user_creds.set_password(self.user_pass)
        user_creds.set_username(self.user_name)
        user_creds.set_workstation(self.machine_name)

    #
    # Get the authenticator from the machine creds.
    def get_authenticator(self, c):
        """Return a ``(current, subsequent)`` pair of netr_Authenticator
        objects.

        ``current`` is filled from a new client authenticator taken from
        the machine credentials; ``subsequent`` is left empty.

        :param c: netlogon connection (not used by this method body)
        """
        fresh = self.machine_creds.new_client_authenticator()

        current = netr_Authenticator()
        # Items that are already ints are kept; anything else goes
        # through ord().
        current.cred.data = [
            item if isinstance(item, int) else ord(item)
            for item in fresh["credential"]
        ]
        current.timestamp = fresh["timestamp"]

        subsequent = netr_Authenticator()
        return (current, subsequent)

    def do_NetrLogonSamLogonWithFlags(self, c, current, subsequent):
        """Log the test user on with netr_LogonSamLogonWithFlags.

        :param c: netlogon connection to issue the call on
        :param current: authenticator passed for this call
        :param subsequent: authenticator object passed as the return
            authenticator
        """
        logon_info = samlogon_logon_info(self.domain,
                                         self.machine_name,
                                         self.user_creds)

        level = netlogon.NetlogonNetworkTransitiveInformation
        validation = netlogon.NetlogonValidationSamInfo4
        flags = 0
        c.netr_LogonSamLogonWithFlags(self.server,
                                      self.user_creds.get_workstation(),
                                      current,
                                      subsequent,
                                      level,
                                      logon_info,
                                      validation,
                                      flags)

    def do_NetrLogonGetDomainInfo(self, c, current, subsequent):
        """Issue netr_LogonGetDomainInfo at level 2 with an empty
        workstation-information query.

        :param c: netlogon connection to issue the call on
        :param current: authenticator passed for this call
        :param subsequent: authenticator object passed as the return
            authenticator
        """
        level = 2
        workstation_info = netr_WorkstationInformation()

        c.netr_LogonGetDomainInfo(self.server,
                                  self.user_creds.get_workstation(),
                                  current,
                                  subsequent,
                                  level,
                                  workstation_info)
Exemple #45
0
class LATests(samba.tests.TestCase):
    def setUp(self):
        """Connect to the directory and create the OU the tests work in.

        When ``opts.delete_in_setup`` is set, any existing OU of the same
        name is tree-deleted first; a failing delete is only printed.
        """
        super(LATests, self).setUp()
        self.samdb = SamDB(host, credentials=creds,
                           session_info=system_session(lp), lp=lp)

        self.base_dn = self.samdb.domain_dn()
        self.ou = "OU=la,%s" % self.base_dn

        if opts.delete_in_setup:
            try:
                self.samdb.delete(self.ou, ['tree_delete:1'])
            except ldb.LdbError as e:
                print("tried deleting %s, got error %s" % (self.ou, e))

        ou_record = {'objectclass': 'organizationalUnit', 'dn': self.ou}
        self.samdb.add(ou_record)

    def tearDown(self):
        """Run the base tearDown, then tree-delete the test OU.

        The delete is skipped when ``opts.no_cleanup`` is set.
        """
        super(LATests, self).tearDown()
        if opts.no_cleanup:
            return
        self.samdb.delete(self.ou, ['tree_delete:1'])

    def add_object(self, cn, objectclass, more_attrs={}):
        """Add one object named *cn* under the test OU and return its DN.

        :param cn: common name; the DN is ``CN=<cn>,<self.ou>``.
        :param objectclass: value for the ``objectclass`` attribute.
        :param more_attrs: extra attributes merged over the defaults
            (only read here, never modified).
        :return: the DN string computed from *cn* and ``self.ou``.
        """
        object_dn = "CN=%s,%s" % (cn, self.ou)

        record = dict(cn=cn, objectclass=objectclass, dn=object_dn)
        record.update(more_attrs)
        self.samdb.add(record)

        return object_dn

    def add_objects(self, n, objectclass, prefix=None, more_attrs={}):
        """Add *n* objects named ``<prefix>1`` .. ``<prefix>n``.

        :param n: number of objects to create.
        :param objectclass: object class for every object; also used as the
            name prefix when *prefix* is None.
        :param prefix: optional name prefix.
        :param more_attrs: extra attributes handed to ``add_object``.
        :return: list of the created DNs, in creation order.
        """
        name_prefix = objectclass if prefix is None else prefix
        return [
            self.add_object("%s%d" % (name_prefix, index),
                            objectclass,
                            more_attrs=more_attrs)
            for index in range(1, n + 1)
        ]

    def add_linked_attribute(self, src, dest, attr='member', controls=None):
        """Modify *src*, adding *dest* to *attr* with ``FLAG_MOD_ADD``.

        :param src: DN string of the object to modify.
        :param dest: value (or values) handed to ``ldb.MessageElement``.
        :param attr: attribute name to change.
        :param controls: controls forwarded to ``samdb.modify``.
        """
        change = ldb.Message()
        change.dn = ldb.Dn(self.samdb, src)
        element = ldb.MessageElement(dest, ldb.FLAG_MOD_ADD, attr)
        change[attr] = element
        self.samdb.modify(change, controls=controls)

    def remove_linked_attribute(self, src, dest, attr='member', controls=None):
        """Modify *src*, deleting *dest* from *attr* with ``FLAG_MOD_DELETE``.

        :param src: DN string of the object to modify.
        :param dest: value (or values) handed to ``ldb.MessageElement``.
        :param attr: attribute name to change.
        :param controls: controls forwarded to ``samdb.modify``.
        """
        change = ldb.Message()
        change.dn = ldb.Dn(self.samdb, src)
        element = ldb.MessageElement(dest, ldb.FLAG_MOD_DELETE, attr)
        change[attr] = element
        self.samdb.modify(change, controls=controls)

    def replace_linked_attribute(self, src, dest, attr='member',
                                 controls=None):
        """Modify *src*, replacing *attr* with *dest* (``FLAG_MOD_REPLACE``).

        :param src: DN string of the object to modify.
        :param dest: value (or values) handed to ``ldb.MessageElement``.
        :param attr: attribute name to change.
        :param controls: controls forwarded to ``samdb.modify``.
        """
        change = ldb.Message()
        change.dn = ldb.Dn(self.samdb, src)
        element = ldb.MessageElement(dest, ldb.FLAG_MOD_REPLACE, attr)
        change[attr] = element
        self.samdb.modify(change, controls=controls)

    def attr_search(self, obj, attr, scope=ldb.SCOPE_BASE, **controls):
        """Search *obj* for the single attribute *attr*.

        Each keyword in *controls* becomes a ``"name:int(value)"`` control
        string. The ``reveal_internals`` control is dropped when
        ``opts.no_reveal_internals`` is set.

        :return: whatever ``self.samdb.search`` returns.
        """
        if opts.no_reveal_internals:
            controls.pop('reveal_internals', None)

        control_strings = []
        for name, value in controls.items():
            control_strings.append('%s:%d' % (name, int(value)))

        return self.samdb.search(obj,
                                 scope=scope,
                                 attrs=[attr],
                                 controls=control_strings)

    def assert_links(self, obj, expected, attr, msg='', **kwargs):
        """Assert that *attr* on *obj* holds exactly the *expected* values.

        Order is ignored: both sides are sorted before comparing. An empty
        *expected* means the attribute must be absent from the first result.

        :param obj: search base handed to ``attr_search``.
        :param expected: collection of expected values.
        :param attr: attribute to fetch and compare.
        :param msg: text printed (not passed to the assertion) on mismatch.
        :param kwargs: controls forwarded to ``attr_search``.
        """
        found = self.attr_search(obj, attr, **kwargs)

        if len(expected) == 0:
            first = found[0]
            if attr in first:
                self.fail("found attr '%s' in %s" % (attr, first))
            return

        try:
            per_message = [entry[attr] for entry in found]
            results = list(per_message[0])
        except KeyError:
            self.fail("missing attr '%s' on %s" % (attr, obj))

        wanted = sorted(expected)
        results.sort()

        if wanted != results:
            # Print the details before the assertion below fails.
            print(msg)
            print("expected %s" % wanted)
            print("received %s" % results)

        self.assertEqual(results, wanted)

    def assert_back_links(self, obj, expected, attr='memberOf', **kwargs):
        """Check *attr* (default ``memberOf``) on *obj* via ``assert_links``."""
        mismatch_note = 'back links do not match'
        self.assert_links(obj, expected, attr, mismatch_note, **kwargs)

    def assert_forward_links(self, obj, expected, attr='member', **kwargs):
        """Check *attr* (default ``member``) on *obj* via ``assert_links``."""
        mismatch_note = 'forward links do not match'
        self.assert_links(obj, expected, attr, mismatch_note, **kwargs)

    def get_object_guid(self, dn):
        """Return the ``objectGUID`` of *dn* as a string.

        The first value of the first base-scope search result is wrapped in
        ``misc.GUID`` and converted with ``str``.
        """
        found = self.samdb.search(dn,
                                  scope=ldb.SCOPE_BASE,
                                  attrs=['objectGUID'])
        raw_guid = found[0]['objectGUID'][0]
        return str(misc.GUID(raw_guid))

    def assertRaisesLdbError(self, errcode, msg, f, *args, **kwargs):
        """Assert that ``f(*args, **kwargs)`` raises a particular LdbError.

        :param errcode: expected LDB error number (an ``ldb.ERR_*`` value).
        :param msg: description of the operation; prefixes the failure text.
        :param f: callable to invoke.
        :param args: positional arguments for *f*.
        :param kwargs: keyword arguments for *f*.

        The test fails if *f* returns normally or raises an ``LdbError``
        carrying a different error number. Exceptions of other types
        propagate unchanged.
        """
        raised = False
        num = None
        errstr = None
        try:
            f(*args, **kwargs)
        except ldb.LdbError as e:
            # Use a separate name for the error text so the caller's
            # description in `msg` is not overwritten.
            (num, errstr) = e.args
            if num == errcode:
                return
            raised = True

        # Map error numbers back to their ERR_* names for the report.
        # dict.items() works on Python 2 and 3; iteritems() is Python 2 only.
        lut = {
            v: k
            for k, v in vars(ldb).items()
            if k.startswith('ERR_') and isinstance(v, int)
        }

        if not raised:
            self.fail("%s, expected "
                      "LdbError %s, (%d) "
                      "but we got success" % (msg, lut.get(errcode), errcode))

        self.fail("%s, expected "
                  "LdbError %s, (%d) "
                  "got %s (%d): %s" %
                  (msg, lut.get(errcode), errcode, lut.get(num), num, errstr))

    def _test_la_backlinks(self, reveal=False):
        """Create two users and two groups, link them, check back links.

        u1 is added to g1 and g2, u2 to g2 only.

        :param reveal: when true, object names get a ``_reveal`` suffix and
            the searches pass ``reveal_internals=0``.
        """
        suffix = '_reveal' if reveal else ''
        tag = 'backlinks' + suffix
        search_controls = {'reveal_internals': 0} if reveal else {}

        u1, u2 = self.add_objects(2, 'user', 'u_%s' % tag)
        g1, g2 = self.add_objects(2, 'group', 'g_%s' % tag)

        for group, user in ((g1, u1), (g2, u1), (g2, u2)):
            self.add_linked_attribute(group, user)

        self.assert_back_links(u1, [g1, g2], **search_controls)
        self.assert_back_links(u2, [g2], **search_controls)

    def test_la_backlinks(self):
        """Run the back-link scenario without the reveal_internals control."""
        self._test_la_backlinks(reveal=False)

    def test_la_backlinks_reveal(self):
        """Run the back-link scenario with reveal_internals.

        Prints a note and returns without testing when
        ``opts.no_reveal_internals`` is set.
        """
        if not opts.no_reveal_internals:
            self._test_la_backlinks(True)
            return
        print('skipping because --no-reveal-internals')

    def _test_la_backlinks_delete_group(self, reveal=False):
        """Check back links on users after one of their groups is deleted.

        u1 is in g1 and g2, u2 only in g2. After g2 is tree-deleted, u1
        must list only g1 and u2 must have no back links.

        :param reveal: when true, object names get a ``_reveal`` suffix and
            the searches pass ``reveal_internals=0``.
        """
        suffix = '_reveal' if reveal else ''
        tag = 'del_group' + suffix
        search_controls = {'reveal_internals': 0} if reveal else {}

        u1, u2 = self.add_objects(2, 'user', 'u_' + tag)
        g1, g2 = self.add_objects(2, 'group', 'g_' + tag)

        for group, user in ((g1, u1), (g2, u1), (g2, u2)):
            self.add_linked_attribute(group, user)

        self.samdb.delete(g2, ['tree_delete:1'])

        self.assert_back_links(u1, [g1], **search_controls)
        self.assert_back_links(u2, set(), **search_controls)

    def test_la_backlinks_delete_group(self):
        """Run the delete-group scenario without reveal_internals."""
        self._test_la_backlinks_delete_group(reveal=False)

    def test_la_backlinks_delete_group_reveal(self):
        """Run the delete-group scenario with reveal_internals.

        Prints a note and returns without testing when
        ``opts.no_reveal_internals`` is set.
        """
        if not opts.no_reveal_internals:
            self._test_la_backlinks_delete_group(True)
            return
        print('skipping because --no-reveal-internals')

    def test_links_all_delete_group(self):
        """Check both link directions after a group is deleted.

        Every search passes the show_deleted, show_recycled and
        show_deactivated_link controls. The deleted group g2 is looked up
        by GUID.
        """
        u1, u2 = self.add_objects(2, 'user', 'u_all_del_group')
        g1, g2 = self.add_objects(2, 'group', 'g_all_del_group')
        g2guid = self.get_object_guid(g2)

        for group, user in ((g1, u1), (g2, u1), (g2, u2)):
            self.add_linked_attribute(group, user)

        self.samdb.delete(g2)

        # The same controls, in the same order, go with every check.
        search_controls = dict(show_deleted=1,
                               show_recycled=1,
                               show_deactivated_link=0)

        self.assert_back_links(u1, [g1], **search_controls)
        self.assert_back_links(u2, set(), **search_controls)
        self.assert_forward_links(g1, [u1], **search_controls)
        self.assert_forward_links('<GUID=%s>' % g2guid, [],
                                  **search_controls)

    def test_links_all_delete_group_reveal(self):
        """Like ``test_links_all_delete_group`` plus ``reveal_internals=0``.

        This test does not check ``opts.no_reveal_internals`` itself;
        ``attr_search`` drops the control when that option is set.
        """
        u1, u2 = self.add_objects(2, 'user', 'u_all_del_group_reveal')
        g1, g2 = self.add_objects(2, 'group', 'g_all_del_group_reveal')
        g2guid = self.get_object_guid(g2)

        for group, user in ((g1, u1), (g2, u1), (g2, u2)):
            self.add_linked_attribute(group, user)

        self.samdb.delete(g2)

        # The same controls, in the same order, go with every check.
        search_controls = dict(show_deleted=1,
                               show_recycled=1,
                               show_deactivated_link=0,
                               reveal_internals=0)

        self.assert_back_links(u1, [g1], **search_controls)
        self.assert_back_links(u2, set(), **search_controls)
        self.assert_forward_links(g1, [u1], **search_controls)
        self.assert_forward_links('<GUID=%s>' % g2guid, [],
                                  **search_controls)

    def test_la_links_delete_link(self):
        """Check link removal and that it changes the group's uSNChanged.

        Also covers removing all links with an empty value list, and that
        removing the same link twice in one message raises ``LdbError``.
        """
        def usn_changed(dn):
            # Read the current uSNChanged of dn as an int.
            found = self.samdb.search(dn,
                                      scope=ldb.SCOPE_BASE,
                                      attrs=['uSNChanged'])
            return int(found[0]['uSNChanged'][0])

        u1, u2 = self.add_objects(2, 'user', 'u_del_link')
        g1, g2 = self.add_objects(2, 'group', 'g_del_link')

        # Adding a link must change the group's USN.
        old_usn1 = usn_changed(g1)
        self.add_linked_attribute(g1, u1)
        new_usn1 = usn_changed(g1)
        self.assertNotEqual(old_usn1, new_usn1, "USN should have incremented")

        self.add_linked_attribute(g2, u1)
        self.add_linked_attribute(g2, u2)

        # Removing a link must change it as well.
        old_usn2 = usn_changed(g2)
        self.remove_linked_attribute(g2, u1)
        new_usn2 = usn_changed(g2)
        self.assertNotEqual(old_usn2, new_usn2, "USN should have incremented")

        self.assert_forward_links(g1, [u1])
        self.assert_forward_links(g2, [u2])

        self.add_linked_attribute(g2, u1)
        self.assert_forward_links(g2, [u1, u2])
        self.remove_linked_attribute(g2, u2)
        self.assert_forward_links(g2, [u1])
        self.remove_linked_attribute(g2, u1)
        self.assert_forward_links(g2, [])
        self.remove_linked_attribute(g1, [])
        self.assert_forward_links(g1, [])

        # removing a duplicate link in the same message should fail
        self.add_linked_attribute(g2, [u1, u2])
        with self.assertRaises(ldb.LdbError):
            self.remove_linked_attribute(g2, [u1, u1])

    def _test_la_links_delete_link_reveal(self):
        """Check that a removed link is still returned with reveal_internals.

        After u1 is removed from g2, a search on g2 with the show_deleted,
        show_recycled, show_deactivated_link and reveal_internals controls
        is expected to list both u1 and u2.
        """
        u1, u2 = self.add_objects(2, 'user', 'u_del_link_reveal')
        g1, g2 = self.add_objects(2, 'group', 'g_del_link_reveal')

        for group, user in ((g1, u1), (g2, u1), (g2, u2)):
            self.add_linked_attribute(group, user)

        self.remove_linked_attribute(g2, u1)

        search_controls = dict(show_deleted=1,
                               show_recycled=1,
                               show_deactivated_link=0,
                               reveal_internals=0)
        self.assert_forward_links(g2, [u1, u2], **search_controls)

    def test_la_links_delete_link_reveal(self):
        """Run the delete-link reveal scenario.

        Prints a note and returns without testing when
        ``opts.no_reveal_internals`` is set.
        """
        if not opts.no_reveal_internals:
            self._test_la_links_delete_link_reveal()
            return
        print('skipping because --no-reveal-internals')

    def test_la_links_delete_user(self):
        """Check forward links and group USNs after a member is deleted.

        Deleting u1 must drop it from g1 and g2, while uSNChanged of both
        groups stays the same.
        """
        def usn_changed(dn):
            # Read the current uSNChanged of dn as an int.
            found = self.samdb.search(dn,
                                      scope=ldb.SCOPE_BASE,
                                      attrs=['uSNChanged'])
            return int(found[0]['uSNChanged'][0])

        u1, u2 = self.add_objects(2, 'user', 'u_del_user')
        g1, g2 = self.add_objects(2, 'group', 'g_del_user')

        for group, user in ((g1, u1), (g2, u1), (g2, u2)):
            self.add_linked_attribute(group, user)

        old_usn1 = usn_changed(g1)
        old_usn2 = usn_changed(g2)

        self.samdb.delete(u1)

        self.assert_forward_links(g1, [])
        self.assert_forward_links(g2, [u2])

        new_usn1 = usn_changed(g1)
        new_usn2 = usn_changed(g2)

        # Assert the USN on the alternate object is unchanged
        self.assertEqual(old_usn1, new_usn1)
        self.assertEqual(old_usn2, new_usn2)

    def test_la_links_delete_user_reveal(self):
        """Check forward links after a member is deleted, with extra controls.

        The searches pass show_deleted, show_recycled, show_deactivated_link
        and reveal_internals. The deleted u1 must not appear on g1 or g2.
        """
        u1, u2 = self.add_objects(2, 'user', 'u_del_user_reveal')
        g1, g2 = self.add_objects(2, 'group', 'g_del_user_reveal')

        for group, user in ((g1, u1), (g2, u1), (g2, u2)):
            self.add_linked_attribute(group, user)

        self.samdb.delete(u1)

        search_controls = dict(show_deleted=1,
                               show_recycled=1,
                               show_deactivated_link=0,
                               reveal_internals=0)
        self.assert_forward_links(g2, [u2], **search_controls)
        self.assert_forward_links(g1, [], **search_controls)

    def test_multiple_links(self):
        """Add, remove and re-add several links per modify request.

        After each step the forward links on the groups and the back links
        on the users are checked. Adding a value list that contains a
        duplicate must fail with ERR_ENTRY_ALREADY_EXISTS.
        """
        # g4 is created but not referenced again in this test.
        u1, u2, u3, u4 = self.add_objects(4, 'user', 'u_multiple_links')
        g1, g2, g3, g4 = self.add_objects(4, 'group', 'g_multiple_links')

        # Step 1: initial membership, several values per request.
        self.add_linked_attribute(g1, [u1, u2, u3, u4])
        self.add_linked_attribute(g2, [u3, u1])
        self.add_linked_attribute(g3, u2)

        # u2 appears twice in the list, so the add must be rejected; the
        # checks below expect g2 to be unchanged.
        self.assertRaisesLdbError(ldb.ERR_ENTRY_ALREADY_EXISTS,
                                  "adding duplicate values",
                                  self.add_linked_attribute, g2,
                                  [u1, u2, u3, u2])

        self.assert_forward_links(g1, [u1, u2, u3, u4])
        self.assert_forward_links(g2, [u3, u1])
        self.assert_forward_links(g3, [u2])
        self.assert_back_links(u1, [g2, g1])
        self.assert_back_links(u2, [g3, g1])
        self.assert_back_links(u3, [g2, g1])
        self.assert_back_links(u4, [g1])

        # Step 2: take u1 and u3 out of g1 and g2.
        self.remove_linked_attribute(g2, [u1, u3])
        self.remove_linked_attribute(g1, [u1, u3])

        self.assert_forward_links(g1, [u2, u4])
        self.assert_forward_links(g2, [])
        self.assert_forward_links(g3, [u2])
        self.assert_back_links(u1, [])
        self.assert_back_links(u2, [g3, g1])
        self.assert_back_links(u3, [])
        self.assert_back_links(u4, [g1])

        # Step 3: put u1 and u3 back, this time into g3 as well.
        self.add_linked_attribute(g1, [u1, u3])
        self.add_linked_attribute(g2, [u3, u1])
        self.add_linked_attribute(g3, [u1, u3])

        self.assert_forward_links(g1, [u1, u2, u3, u4])
        self.assert_forward_links(g2, [u1, u3])
        self.assert_forward_links(g3, [u1, u2, u3])
        self.assert_back_links(u1, [g1, g2, g3])
        self.assert_back_links(u2, [g3, g1])
        self.assert_back_links(u3, [g3, g2, g1])
        self.assert_back_links(u4, [g1])

    def test_la_links_replace(self):
        """Replace whole member lists and check both link directions.

        Covers replacing on a group that has no members yet (g4), replacing
        with an empty list, and a replacement list containing a duplicate,
        which must fail with ERR_ENTRY_ALREADY_EXISTS.
        """
        u1, u2, u3, u4 = self.add_objects(4, 'user', 'u_replace')
        g1, g2, g3, g4 = self.add_objects(4, 'group', 'g_replace')

        # Initial membership set with plain adds; g4 starts empty.
        self.add_linked_attribute(g1, [u1, u2])
        self.add_linked_attribute(g2, [u1, u3])
        self.add_linked_attribute(g3, u1)

        # First round of replacements.
        self.replace_linked_attribute(g1, [u2])
        self.replace_linked_attribute(g2, [u2, u3])
        self.replace_linked_attribute(g3, [u1, u3])
        self.replace_linked_attribute(g4, [u4])

        self.assert_forward_links(g1, [u2])
        self.assert_forward_links(g2, [u3, u2])
        self.assert_forward_links(g3, [u3, u1])
        self.assert_forward_links(g4, [u4])
        self.assert_back_links(u1, [g3])
        self.assert_back_links(u2, [g1, g2])
        self.assert_back_links(u3, [g2, g3])
        self.assert_back_links(u4, [g4])

        # Second round; g4 is replaced with an empty list.
        self.replace_linked_attribute(g1, [u1, u2, u3])
        self.replace_linked_attribute(g2, [u1])
        self.replace_linked_attribute(g3, [u2])
        self.replace_linked_attribute(g4, [])

        self.assert_forward_links(g1, [u1, u2, u3])
        self.assert_forward_links(g2, [u1])
        self.assert_forward_links(g3, [u2])
        self.assert_forward_links(g4, [])
        self.assert_back_links(u1, [g1, g2])
        self.assert_back_links(u2, [g1, g3])
        self.assert_back_links(u3, [g1])
        self.assert_back_links(u4, [])

        # u2 appears twice in the replacement list.
        self.assertRaisesLdbError(ldb.ERR_ENTRY_ALREADY_EXISTS,
                                  "replacing duplicate values",
                                  self.replace_linked_attribute, g2,
                                  [u1, u2, u3, u2])

    def test_la_links_replace2(self):
        """Replace the members of one group with overlapping slices of users.

        The forward links are checked after every add, replace and remove.
        """
        users = self.add_objects(12, 'user', 'u_replace2')
        g1, = self.add_objects(1, 'group', 'g_replace2')

        # Start with the first six users.
        self.add_linked_attribute(g1, users[:6])
        self.assert_forward_links(g1, users[:6])
        # Replace with a superset: all twelve users.
        self.replace_linked_attribute(g1, users)
        self.assert_forward_links(g1, users)
        # Replace with a subset: the last six users.
        self.replace_linked_attribute(g1, users[6:])
        self.assert_forward_links(g1, users[6:])
        # Remove the remaining members in two batches.
        self.remove_linked_attribute(g1, users[6:9])
        self.assert_forward_links(g1, users[9:])
        self.remove_linked_attribute(g1, users[9:])
        self.assert_forward_links(g1, [])

    def test_la_links_permutations(self):
        """Make sure the order in which we add links doesn't matter.

        Three users have six orderings, one for each of the six groups.
        Each group is given the same users in a different order for add,
        replace and remove.
        """
        users = self.add_objects(3, 'user', 'u_permutations')
        groups = self.add_objects(6, 'group', 'g_permutations')

        # Add: each group gets one permutation of the users.
        for g, p in zip(groups, itertools.permutations(users)):
            self.add_linked_attribute(g, p)

        # everyone should be in every group
        for g in groups:
            self.assert_forward_links(g, users)

        for u in users:
            self.assert_back_links(u, groups)

        # Replace: groups are walked in reverse, so each group gets a
        # different permutation than it was given by the add above.
        for g, p in zip(groups[::-1], itertools.permutations(users)):
            self.replace_linked_attribute(g, p)

        for g in groups:
            self.assert_forward_links(g, users)

        for u in users:
            self.assert_back_links(u, groups)

        # Remove: every group loses all of its members.
        for g, p in zip(groups, itertools.permutations(users)):
            self.remove_linked_attribute(g, p)

        for g in groups:
            self.assert_forward_links(g, [])

        for u in users:
            self.assert_back_links(u, [])

    def test_la_links_relaxed(self):
        """Check that the relax control doesn't mess with linked attributes.

        Three groups are taken through the same membership changes. Two of
        them are modified with the ``relax:0`` control and one without, and
        all three are expected to end up with identical links.
        """
        relax_control = ['relax:0']

        users = self.add_objects(10, 'user', 'u_relax')
        # All three groups are created with the first two users as members.
        groups = self.add_objects(3,
                                  'group',
                                  'g_relax',
                                  more_attrs={'member': users[:2]})
        g_relax1, g_relax2, g_uptight = groups

        # g_relax1 has all users added at once
        # g_relax2 gets them one at a time in reverse order
        # g_uptight never relaxes

        self.add_linked_attribute(g_relax1, users[2:5], controls=relax_control)

        for u in reversed(users[2:5]):
            self.add_linked_attribute(g_relax2, u, controls=relax_control)
            self.add_linked_attribute(g_uptight, u)

        # From here on every group is filled up without the relax control.
        for g in groups:
            self.assert_forward_links(g, users[:5])

            self.add_linked_attribute(g, users[5:7])
            self.assert_forward_links(g, users[:7])

            for u in users[7:]:
                self.add_linked_attribute(g, u)

            self.assert_forward_links(g, users)

        for u in users:
            self.assert_back_links(u, groups)

        # try some replacement permutations
        # The fixed seed makes the shuffles identical on every run.
        import random
        random.seed(1)
        users2 = users[:]
        for i in range(5):
            random.shuffle(users2)
            self.replace_linked_attribute(g_relax1,
                                          users2,
                                          controls=relax_control)

            self.assert_forward_links(g_relax1, users)

        # Empty all three groups and refill them, five times over, each
        # time with a freshly shuffled member order.
        for i in range(5):
            random.shuffle(users2)
            self.remove_linked_attribute(g_relax2,
                                         users2,
                                         controls=relax_control)
            self.remove_linked_attribute(g_uptight, users2)

            self.replace_linked_attribute(g_relax1, [], controls=relax_control)

            random.shuffle(users2)
            self.add_linked_attribute(g_relax2, users2, controls=relax_control)
            self.add_linked_attribute(g_uptight, users2)
            self.replace_linked_attribute(g_relax1,
                                          users2,
                                          controls=relax_control)

            self.assert_forward_links(g_relax1, users)
            self.assert_forward_links(g_relax2, users)
            self.assert_forward_links(g_uptight, users)

        for u in users:
            self.assert_back_links(u, groups)

    def test_add_all_at_once(self):
        """Create groups whose links are supplied in the add request itself.

        All these other tests are creating linked attributes after the
        objects are there. Here the ``member`` values are passed as
        ``more_attrs`` when the groups are added.
        """
        users = self.add_objects(7, 'user', 'u_all_at_once')
        # add_objects numbers the names from 1, so these two are
        # g_all_at_once1 and g_all_at_once2; the second is bound to g3.
        g1, g3 = self.add_objects(2,
                                  'group',
                                  'g_all_at_once',
                                  more_attrs={'member': users})
        (g2, ) = self.add_objects(1,
                                  'group',
                                  'g_all_at_once2',
                                  more_attrs={'member': users[:5]})

        # users[1] is listed twice, so creating this group must fail.
        self.assertRaisesLdbError(
            ldb.ERR_ENTRY_ALREADY_EXISTS,
            "adding multiple duplicate values",
            self.add_objects,
            1,
            'group',
            'g_with_duplicate_links',
            more_attrs={'member': users[:5] + users[1:2]})

        self.assert_forward_links(g1, users)
        self.assert_forward_links(g2, users[:5])
        self.assert_forward_links(g3, users)
        for u in users[:5]:
            self.assert_back_links(u, [g1, g2, g3])
        for u in users[5:]:
            self.assert_back_links(u, [g1, g3])

        # Reshape g2 so that it ends up holding users[1:].
        self.remove_linked_attribute(g2, users[0])
        self.remove_linked_attribute(g2, users[1])
        self.add_linked_attribute(g2, users[1])
        self.add_linked_attribute(g2, users[5])
        self.add_linked_attribute(g2, users[6])

        self.assert_forward_links(g1, users)
        self.assert_forward_links(g2, users[1:])

        # Empty g2 one member at a time and g1 in a single request.
        for u in users[1:]:
            self.remove_linked_attribute(g2, u)
        self.remove_linked_attribute(g1, users)

        # g3 still lists every user at this point; deleting the users is
        # expected to clear its links too.
        for u in users:
            self.samdb.delete(u)

        self.assert_forward_links(g1, [])
        self.assert_forward_links(g2, [])
        self.assert_forward_links(g3, [])

    def test_one_way_attributes(self):
        """Check an ``addressBookRoots`` link after its target is deleted.

        The link on e1 is expected to remain and to hold the DN that the
        deleted e2 is found under when searched by GUID.
        """
        e1, e2 = self.add_objects(2, 'msExchConfigurationContainer',
                                  'e_one_way')
        # Record the GUID first so e2 can still be located after the delete.
        guid = self.get_object_guid(e2)

        self.add_linked_attribute(e1, e2, attr="addressBookRoots")
        self.assert_forward_links(e1, [e2], attr='addressBookRoots')

        self.samdb.delete(e2)

        # Look the deleted object up by GUID to learn its current DN.
        res = self.samdb.search("<GUID=%s>" % guid,
                                scope=ldb.SCOPE_BASE,
                                controls=['show_deleted:1', 'show_recycled:1'])

        new_dn = str(res[0].dn)
        # Same expectation without and with the show_deactivated_link control.
        self.assert_forward_links(e1, [new_dn], attr='addressBookRoots')
        self.assert_forward_links(e1, [new_dn],
                                  attr='addressBookRoots',
                                  show_deactivated_link=0)

    def test_one_way_attributes_delete_link(self):
        """Check that an ``addressBookRoots`` link can be removed explicitly.

        After the remove, e1 must have no ``addressBookRoots`` values, with
        or without the show_deactivated_link control.
        """
        e1, e2 = self.add_objects(2, 'msExchConfigurationContainer',
                                  'e_one_way')
        # NOTE(review): guid is not used below; the lookup still runs a
        # search against e2.
        guid = self.get_object_guid(e2)

        self.add_linked_attribute(e1, e2, attr="addressBookRoots")
        self.assert_forward_links(e1, [e2], attr='addressBookRoots')

        self.remove_linked_attribute(e1, e2, attr="addressBookRoots")

        self.assert_forward_links(e1, [], attr='addressBookRoots')
        self.assert_forward_links(e1, [],
                                  attr='addressBookRoots',
                                  show_deactivated_link=0)

    def test_pretend_one_way_attributes(self):
        """Check an ``addressBookRoots2`` link after its target is deleted.

        Unlike the ``addressBookRoots`` case, the link on e1 is expected
        to be gone once e2 has been deleted.
        """
        e1, e2 = self.add_objects(2, 'msExchConfigurationContainer',
                                  'e_one_way')
        # Record the GUID first so e2 can still be located after the delete.
        guid = self.get_object_guid(e2)

        self.add_linked_attribute(e1, e2, attr="addressBookRoots2")
        self.assert_forward_links(e1, [e2], attr='addressBookRoots2')

        self.samdb.delete(e2)
        res = self.samdb.search("<GUID=%s>" % guid,
                                scope=ldb.SCOPE_BASE,
                                controls=['show_deleted:1', 'show_recycled:1'])

        # NOTE(review): new_dn is not used below; res[0] still requires the
        # deleted object to be found by the search above.
        new_dn = str(res[0].dn)

        self.assert_forward_links(e1, [], attr='addressBookRoots2')
        self.assert_forward_links(e1, [],
                                  attr='addressBookRoots2',
                                  show_deactivated_link=0)

    def test_pretend_one_way_attributes_delete_link(self):
        """Check that an ``addressBookRoots2`` link can be removed explicitly.

        After the remove, e1 must have no ``addressBookRoots2`` values, with
        or without the show_deactivated_link control.
        """
        e1, e2 = self.add_objects(2, 'msExchConfigurationContainer',
                                  'e_one_way')
        # NOTE(review): guid is not used below; the lookup still runs a
        # search against e2.
        guid = self.get_object_guid(e2)

        self.add_linked_attribute(e1, e2, attr="addressBookRoots2")
        self.assert_forward_links(e1, [e2], attr='addressBookRoots2')

        self.remove_linked_attribute(e1, e2, attr="addressBookRoots2")

        self.assert_forward_links(e1, [], attr='addressBookRoots2')
        self.assert_forward_links(e1, [],
                                  attr='addressBookRoots2',
                                  show_deactivated_link=0)
Exemple #46
0
class DsdbTests(TestCase):
    def setUp(self):
        """Connect to the SamDB, create a scratch user and locate the RID Set.

        Populates self.lp, self.creds, self.session, self.samdb,
        self.account_dn, self.server_ref_dn and self.rid_set_dn.  The scratch
        user is removed again through a registered cleanup.
        """
        super(DsdbTests, self).setUp()
        self.lp = samba.tests.env_loadparm()
        self.creds = Credentials()
        self.creds.guess(self.lp)
        self.session = system_session()
        self.samdb = SamDB(session_info=self.session,
                           credentials=self.creds,
                           lp=self.lp)

        # Scratch user with a random six hex digit suffix.
        username = "dsdb-user-" + str(uuid.uuid4().hex[0:6])
        password = samba.generate_random_password(32, 32)
        description = "Test user for dsdb test"

        domain_dn = self.samdb.domain_dn()

        self.account_dn = "CN=" + username + ",CN=Users," + domain_dn
        self.samdb.newuser(username=username,
                           password=password,
                           description=description)
        # Remove the scratch user when the test finishes.
        self.addCleanup(delete_force, self.samdb, self.account_dn)

        # Follow serverReference from our server object.
        server_dn = ldb.Dn(self.samdb, self.samdb.get_serverName())
        server_res = self.samdb.search(base=server_dn,
                                       scope=ldb.SCOPE_BASE,
                                       attrs=["serverReference"])
        server_ref = server_res[0]["serverReference"][0].decode("utf-8")
        self.server_ref_dn = ldb.Dn(self.samdb, server_ref)

        # Follow rIDSetReferences from the referenced object to the RID Set.
        ref_res = self.samdb.search(base=self.server_ref_dn,
                                    scope=ldb.SCOPE_BASE,
                                    attrs=["rIDSetReferences"])
        ref_msg = ref_res[0]
        self.assertIn("rIDSetReferences", ref_msg)
        rid_set_str = ref_msg["rIDSetReferences"][0].decode("utf-8")
        self.rid_set_dn = ldb.Dn(self.samdb, rid_set_str)

    def get_rid_set(self, rid_set_dn):
        """Fetch the RID pool attributes of a RID Set object.

        :param rid_set_dn: DN used as the base of a SCOPE_BASE search.
        :return: the first message of the search result, restricted to
            rIDAllocationPool, rIDPreviousAllocationPool, rIDUsedPool and
            rIDNextRID.
        """
        wanted_attrs = [
            "rIDAllocationPool",
            "rIDPreviousAllocationPool",
            "rIDUsedPool",
            "rIDNextRID",
        ]
        matches = self.samdb.search(base=rid_set_dn,
                                    scope=ldb.SCOPE_BASE,
                                    attrs=wanted_attrs)
        return matches[0]

    def test_ridalloc_next_free_rid(self):
        """next_free_rid() must equal rIDNextRID + 1 and change nothing.

        Assumes the RID pools allocated to us are contiguous.  Everything
        runs inside a transaction that is cancelled at the end.
        """
        self.samdb.transaction_start()
        try:
            rid_set_before = self.get_rid_set(self.rid_set_dn)
            for attr_name in ("rIDAllocationPool",
                              "rIDPreviousAllocationPool",
                              "rIDUsedPool",
                              "rIDNextRID"):
                self.assertIn(attr_name, rid_set_before)

            # The RID most recently handed out, according to the RID Set.
            current_rid = int(rid_set_before["rIDNextRID"][0])

            first_answer = self.samdb.next_free_rid()
            self.assertEqual(current_rid + 1, first_answer)

            # Asking again without allocating must give the same answer.
            second_answer = self.samdb.next_free_rid()
            self.assertEqual(first_answer, second_answer)

            # Peeking must not have modified the RID Set.
            rid_set_after = self.get_rid_set(self.rid_set_dn)
            self.assertEqual(rid_set_before, rid_set_after)
        finally:
            self.samdb.transaction_cancel()

    def test_ridalloc_no_ridnextrid(self):
        """With rIDNextRID deleted, allocation starts at rIDAllocationPool.

        Pools are written as a single integer, (high << 32) | low.  All
        changes happen in a transaction that is cancelled at the end.
        """
        self.samdb.transaction_start()
        try:
            old_first, old_last = 1000, 1999
            new_first, new_last = 3000, 3999

            # Drop rIDNextRID and install the previous and next pools.
            update = ldb.Message()
            update.dn = self.rid_set_dn
            update["rIDNextRID"] = ldb.MessageElement(
                [], ldb.FLAG_MOD_DELETE, "rIDNextRID")
            update["rIDPreviousAllocationPool"] = ldb.MessageElement(
                str((old_last << 32) | old_first),
                ldb.FLAG_MOD_REPLACE,
                "rIDPreviousAllocationPool")
            update["rIDAllocationPool"] = ldb.MessageElement(
                str((new_last << 32) | new_first),
                ldb.FLAG_MOD_REPLACE,
                "rIDAllocationPool")
            self.samdb.modify(update)

            # The next free RID is the first RID of the next pool.
            peeked = self.samdb.next_free_rid()
            self.assertEqual(new_first, peeked)

            # Allocating hands out exactly that RID.
            allocated = self.samdb.allocate_rid()
            self.assertEqual(peeked, allocated)

            # After the allocation the next free RID has advanced by one.
            peeked_again = self.samdb.next_free_rid()
            self.assertEqual(allocated + 1, peeked_again)

            # The free range runs from there to the end of the next pool.
            lowest_free, highest_free = self.samdb.free_rid_bounds()
            self.assertEqual(allocated + 1, lowest_free)
            self.assertEqual(new_last, highest_free)
        finally:
            self.samdb.transaction_cancel()

    def test_ridalloc_no_free_rids(self):
        """next_free_rid() fails on an exhausted pool; allocate_rid() works.

        Both pools are set to the same range and rIDNextRID to its upper
        bound.  The transaction is cancelled at the end.
        """
        self.samdb.transaction_start()
        try:
            first_rid, last_rid = 2000, 2999
            packed_pool = str((last_rid << 32) | first_rid)

            # Previous and current pool identical, last RID already used.
            update = ldb.Message()
            update.dn = self.rid_set_dn
            update["rIDPreviousAllocationPool"] = ldb.MessageElement(
                packed_pool, ldb.FLAG_MOD_REPLACE,
                "rIDPreviousAllocationPool")
            update["rIDAllocationPool"] = ldb.MessageElement(
                packed_pool, ldb.FLAG_MOD_REPLACE, "rIDAllocationPool")
            update["rIDNextRID"] = ldb.MessageElement(
                str(last_rid), ldb.FLAG_MOD_REPLACE, "rIDNextRID")
            self.samdb.modify(update)

            # Peeking at the next free RID must fail.
            with self.assertRaises(ldb.LdbError) as caught:
                self.samdb.next_free_rid()

            self.assertEqual("RID pools out of RIDs",
                             caught.exception.args[1])

            # A real allocation must still be possible.
            self.samdb.allocate_rid()
        finally:
            self.samdb.transaction_cancel()

    def test_ridalloc_new_ridset(self):
        """Exercise a RID Set whose pool attributes are all zero.

        This resembles a freshly created RID Set, except that
        rIDAllocationPool is zeroed as well.  next_free_rid() must fail until
        a real allocation pool is written.  The transaction is cancelled at
        the end.
        """
        self.samdb.transaction_start()
        try:
            zeroed = ldb.Message()
            zeroed.dn = self.rid_set_dn
            for attr_name in ("rIDPreviousAllocationPool",
                              "rIDAllocationPool",
                              "rIDNextRID"):
                zeroed[attr_name] = ldb.MessageElement(
                    "0", ldb.FLAG_MOD_REPLACE, attr_name)
            self.samdb.modify(zeroed)

            # With everything zeroed there is no free RID to report.
            with self.assertRaises(ldb.LdbError) as caught:
                self.samdb.next_free_rid()

            self.assertEqual("RID pools out of RIDs",
                             caught.exception.args[1])

            # Provide a next pool.
            first_rid, last_rid = 2000, 2999
            refill = ldb.Message()
            refill.dn = self.rid_set_dn
            refill["rIDAllocationPool"] = ldb.MessageElement(
                str((last_rid << 32) | first_rid),
                ldb.FLAG_MOD_REPLACE,
                "rIDAllocationPool")
            self.samdb.modify(refill)

            # The next free RID is now the lower bound of that pool.
            peeked = self.samdb.next_free_rid()
            self.assertEqual(first_rid, peeked)

            # And the whole pool is reported as free.
            lowest_free, highest_free = self.samdb.free_rid_bounds()
            self.assertEqual(first_rid, lowest_free)
            self.assertEqual(last_rid, highest_free)
        finally:
            self.samdb.transaction_cancel()

    def test_ridalloc_move_to_new_pool(self):
        """Allocation uses up the previous pool, then moves to the new one.

        rIDNextRID is set one below the end of the previous pool, so exactly
        one RID is left there.  The transaction is cancelled at the end.
        """
        self.samdb.transaction_start()
        try:
            old_first, old_last = 2000, 2999
            fresh_first, fresh_last = 4500, 4599

            update = ldb.Message()
            update.dn = self.rid_set_dn
            update["rIDPreviousAllocationPool"] = ldb.MessageElement(
                str((old_last << 32) | old_first),
                ldb.FLAG_MOD_REPLACE,
                "rIDPreviousAllocationPool")
            update["rIDAllocationPool"] = ldb.MessageElement(
                str((fresh_last << 32) | fresh_first),
                ldb.FLAG_MOD_REPLACE,
                "rIDAllocationPool")
            update["rIDNextRID"] = ldb.MessageElement(
                str(old_last - 1), ldb.FLAG_MOD_REPLACE, "rIDNextRID")
            self.samdb.modify(update)

            # One RID is left in the previous pool, so we stay there.
            peek_old = self.samdb.next_free_rid()
            self.assertEqual(old_last, peek_old)

            # The free range is that single RID.
            lowest_free, highest_free = self.samdb.free_rid_bounds()
            self.assertEqual(old_last, lowest_free)
            self.assertEqual(old_last, highest_free)

            # Use it up.
            last_of_old = self.samdb.allocate_rid()
            self.assertEqual(peek_old, last_of_old)

            # Now the next free RID is the start of the new pool.
            peek_fresh = self.samdb.next_free_rid()
            self.assertEqual(fresh_first, peek_fresh)

            # The free range is the whole new pool.
            lowest_fresh, highest_fresh = self.samdb.free_rid_bounds()
            self.assertEqual(fresh_first, lowest_fresh)
            self.assertEqual(fresh_last, highest_fresh)

            # allocate_rid() agrees with next_free_rid().
            first_of_fresh = self.samdb.allocate_rid()
            self.assertEqual(peek_fresh, first_of_fresh)
        finally:
            self.samdb.transaction_cancel()

    def test_ridalloc_no_ridsetreferences(self):
        """Both RID calls must fail once rIDSetReferences is deleted.

        next_free_rid() is expected to fail with ERR_NO_SUCH_ATTRIBUTE and
        allocate_rid() with ERR_ENTRY_ALREADY_EXISTS.  The transaction is
        cancelled at the end.
        """
        self.samdb.transaction_start()
        try:
            # Remove the pointer from the server reference to the RID Set.
            update = ldb.Message()
            update.dn = self.server_ref_dn
            update["rIDSetReferences"] = ldb.MessageElement(
                [], ldb.FLAG_MOD_DELETE, "rIDSetReferences")
            self.samdb.modify(update)

            # Peeking at the next free RID must fail.
            with self.assertRaises(ldb.LdbError) as caught:
                self.samdb.next_free_rid()

            code, text = caught.exception.args
            self.assertEqual(ldb.ERR_NO_SUCH_ATTRIBUTE, code)
            missing_attr_text = ("No RID Set DN - "
                                 "Cannot find attribute rIDSetReferences of %s "
                                 "to calculate reference dn"
                                 % self.server_ref_dn)
            self.assertIn(missing_attr_text, text)

            # Allocating must fail as well.
            with self.assertRaises(ldb.LdbError) as caught:
                self.samdb.allocate_rid()

            code, text = caught.exception.args
            self.assertEqual(ldb.ERR_ENTRY_ALREADY_EXISTS, code)
            already_exists_text = ("No RID Set DN - "
                                   "Failed to add RID Set %s - "
                                   "Entry %s already exists"
                                   % (self.rid_set_dn, self.rid_set_dn))
            self.assertIn(already_exists_text, text)
        finally:
            self.samdb.transaction_cancel()

    def test_ridalloc_no_rid_set(self):
        """Both RID calls must fail when rIDSetReferences is not a RID Set.

        rIDSetReferences is pointed at the scratch user account instead.
        Both calls are expected to fail with ERR_OPERATIONS_ERROR.  The
        transaction is cancelled at the end.
        """
        self.samdb.transaction_start()
        try:
            bogus_target = self.account_dn
            expected_text = "Bad RID Set " + bogus_target

            update = ldb.Message()
            update.dn = self.server_ref_dn
            update["rIDSetReferences"] = ldb.MessageElement(
                bogus_target, ldb.FLAG_MOD_REPLACE, "rIDSetReferences")
            self.samdb.modify(update)

            # Peeking at the next free RID must fail.
            with self.assertRaises(ldb.LdbError) as caught:
                self.samdb.next_free_rid()

            code, text = caught.exception.args
            self.assertEqual(ldb.ERR_OPERATIONS_ERROR, code)
            self.assertIn(expected_text, text)

            # Allocating must fail in the same way.
            with self.assertRaises(ldb.LdbError) as caught:
                self.samdb.allocate_rid()

            code, text = caught.exception.args
            self.assertEqual(ldb.ERR_OPERATIONS_ERROR, code)
            self.assertIn(expected_text, text)
        finally:
            self.samdb.transaction_cancel()

    def test_get_oid_from_attrid(self):
        """get_oid_from_attid(591614) must return 1.2.840.113556.1.4.1790."""
        expected_oid = "1.2.840.113556.1.4.1790"
        actual_oid = self.samdb.get_oid_from_attid(591614)
        self.assertEqual(actual_oid, expected_oid)

    def test_error_replpropertymetadata(self):
        """Replacing replPropertyMetaData with only a bumped version fails.

        The version of the description entry (attid 13) is incremented, the
        blob is written back with the
        local_oid:1.3.6.1.4.1.7165.4.3.14:0 control, and the modify is
        expected to raise ldb.LdbError.
        """
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        for entry in repl.ctr.array:
            # attid 13 is the description attribute
            if entry.attid == 13:
                entry.version = entry.version + 1
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        self.assertRaises(ldb.LdbError, self.samdb.modify, msg,
                          ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_error_replpropertymetadata_nochange(self):
        """Writing back an unmodified replPropertyMetaData blob must fail.

        The blob is unpacked and repacked without changes and then replaced
        using the local_oid:1.3.6.1.4.1.7165.4.3.14:0 control; the modify is
        expected to raise ldb.LdbError.
        """
        found = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                  base=self.account_dn,
                                  attrs=["replPropertyMetaData"])
        unpacked = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                              found[0]["replPropertyMetaData"][0])
        repacked = ndr_pack(unpacked)

        update = ldb.Message()
        update.dn = found[0].dn
        update["replPropertyMetaData"] = ldb.MessageElement(
            repacked, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")

        with self.assertRaises(ldb.LdbError):
            self.samdb.modify(update,
                              ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_error_replpropertymetadata_allow_sort(self):
        """An unmodified blob is accepted when both controls are supplied.

        Same write-back as the no-change case, but with the additional
        local_oid:1.3.6.1.4.1.7165.4.3.25:0 control; here the modify must
        not raise.
        """
        found = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                  base=self.account_dn,
                                  attrs=["replPropertyMetaData"])
        unpacked = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                              found[0]["replPropertyMetaData"][0])
        repacked = ndr_pack(unpacked)

        update = ldb.Message()
        update.dn = found[0].dn
        update["replPropertyMetaData"] = ldb.MessageElement(
            repacked, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")

        controls = [
            "local_oid:1.3.6.1.4.1.7165.4.3.14:0",
            "local_oid:1.3.6.1.4.1.7165.4.3.25:0",
        ]
        self.samdb.modify(update, controls)

    def test_twoatt_replpropertymetadata(self):
        """Replacing replPropertyMetaData together with description fails.

        The description entry (attid 13) gets its version incremented and
        its local_usn set to uSNChanged + 1.  The message replaces both
        replPropertyMetaData and description, and the modify with the
        local_oid:1.3.6.1.4.1.7165.4.3.14:0 control is expected to raise
        ldb.LdbError.
        """
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData", "uSNChanged"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        for entry in repl.ctr.array:
            # attid 13 is the description attribute
            if entry.attid == 13:
                entry.version = entry.version + 1
                entry.local_usn = int(str(res[0]["uSNChanged"])) + 1
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        msg["description"] = ldb.MessageElement("new val",
                                                ldb.FLAG_MOD_REPLACE,
                                                "description")
        self.assertRaises(ldb.LdbError, self.samdb.modify, msg,
                          ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_set_replpropertymetadata(self):
        """A consistently updated replPropertyMetaData blob is accepted.

        The description entry (attid 13) gets its version incremented and
        both local_usn and originating_usn set to uSNChanged + 1.  The
        modify with the local_oid:1.3.6.1.4.1.7165.4.3.14:0 control must not
        raise.
        """
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData", "uSNChanged"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        for entry in repl.ctr.array:
            # attid 13 is the description attribute
            if entry.attid == 13:
                next_usn = int(str(res[0]["uSNChanged"])) + 1
                entry.version = entry.version + 1
                entry.local_usn = next_usn
                entry.originating_usn = next_usn
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        self.samdb.modify(msg, ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_ok_get_attribute_from_attid(self):
        """get_attribute_from_attid(13) must return "description"."""
        attr_name = self.samdb.get_attribute_from_attid(13)
        self.assertEqual(attr_name, "description")

    def test_ko_get_attribute_from_attid(self):
        """get_attribute_from_attid(11979) must return None."""
        attr_name = self.samdb.get_attribute_from_attid(11979)
        self.assertEqual(attr_name, None)

    def test_get_attribute_replmetadata_version(self):
        """unicodePwd of the scratch user must be at replmetadata version 2."""
        found = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                  base=self.account_dn,
                                  attrs=["dn"])
        self.assertEqual(len(found), 1)
        user_dn = str(found[0].dn)
        pwd_version = self.samdb.get_attribute_replmetadata_version(
            user_dn, "unicodePwd")
        self.assertEqual(pwd_version, 2)

    def test_set_attribute_replmetadata_version(self):
        """Setting the description version to current + 2 must be readable back."""
        found = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                  base=self.account_dn,
                                  attrs=["dn"])
        self.assertEqual(len(found), 1)
        user_dn = str(found[0].dn)

        current = self.samdb.get_attribute_replmetadata_version(
            user_dn, "description")
        self.samdb.set_attribute_replmetadata_version(
            user_dn, "description", current + 2)

        reread = self.samdb.get_attribute_replmetadata_version(
            user_dn, "description")
        self.assertEqual(reread, current + 2)

    def test_no_error_on_invalid_control(self):
        """A non-critical unimplemented control must not make a search fail.

        The control is passed with criticality 0 ("local_oid:<oid>:0").
        Any ldb.LdbError fails the test, and the error is included in the
        failure message.
        """
        try:
            self.samdb.search(
                scope=ldb.SCOPE_SUBTREE,
                base=self.account_dn,
                attrs=["replPropertyMetaData"],
                controls=[
                    "local_oid:%s:0" %
                    dsdb.DSDB_CONTROL_INVALID_NOT_IMPLEMENTED
                ])
        except ldb.LdbError as e:
            self.fail("Should have not raised an exception: %s" % e)

    def test_error_on_invalid_critical_control(self):
        """A critical unimplemented control may only fail in one specific way.

        The control is passed with criticality 1 ("local_oid:<oid>:1").  If
        the search raises ldb.LdbError, the error code must be
        ERR_UNSUPPORTED_CRITICAL_EXTENSION.
        """
        # NOTE(review): if the search raises nothing at all the test passes
        # silently - confirm whether a missing error should be a failure.
        try:
            self.samdb.search(
                scope=ldb.SCOPE_SUBTREE,
                base=self.account_dn,
                attrs=["replPropertyMetaData"],
                controls=[
                    "local_oid:%s:1" %
                    dsdb.DSDB_CONTROL_INVALID_NOT_IMPLEMENTED
                ])
        except ldb.LdbError as e:
            (errno, estr) = e.args
            if errno != ldb.ERR_UNSUPPORTED_CRITICAL_EXTENSION:
                # Use the unpacked error string: exception objects cannot
                # be indexed on Python 3.
                self.fail(
                    "Got %s should have got ERR_UNSUPPORTED_CRITICAL_EXTENSION"
                    % estr)

    # Allocate a unique RID for use in the objectSID tests.
    #
    def allocate_rid(self):
        """Allocate one RID inside its own transaction and return it as str.

        The transaction is cancelled and the error re-raised if the
        allocation raises; otherwise it is committed.
        """
        self.samdb.transaction_start()
        try:
            new_rid = self.samdb.allocate_rid()
        except BaseException:
            # Roll back on any failure, then let the error propagate.
            self.samdb.transaction_cancel()
            raise
        else:
            self.samdb.transaction_commit()
        return str(new_rid)

    # Ensure that duplicate objectSID's are permitted for foreign security
    # principals.
    #
    def test_duplicate_objectSIDs_allowed_on_foreign_security_principals(self):
        """A foreignSecurityPrincipal can be re-added after being deleted.

        Without the provision control, adding the object is expected to fail
        (ERR_OBJECT_CLASS_VIOLATION without objectSid,
        ERR_UNWILLING_TO_PERFORM with one).  With "provision:0" the object
        is added, deleted, and added again without error.
        """
        #
        # Build a SID that is not in the current domain by changing the
        # last character of the domain SID and appending "-1000".
        #
        domain_sid_str = str(self.samdb.get_domain_sid())
        last_char = "9" if domain_sid_str.endswith("0") else "0"
        sid_str = domain_sid_str[:-1] + last_char + "-1000"
        sid = ndr_pack(security.dom_sid(sid_str))
        basedn = self.samdb.get_default_basedn()
        dn = "CN=%s,CN=ForeignSecurityPrincipals,%s" % (sid_str, basedn)

        #
        # First without the control and without an objectSid.
        #
        try:
            self.samdb.add({
                "dn": dn,
                "objectClass": "foreignSecurityPrincipal"
            })
            self.fail("No exception should get ERR_OBJECT_CLASS_VIOLATION")
        except ldb.LdbError as err:
            (code, text) = err.args
            self.assertEqual(code, ldb.ERR_OBJECT_CLASS_VIOLATION, str(err))
            werr = "%08X" % werror.WERR_DS_MISSING_REQUIRED_ATT
            self.assertTrue(werr in text, text)

        #
        # Still without the control, but supplying the objectSid.
        #
        try:
            self.samdb.add({
                "dn": dn,
                "objectClass": "foreignSecurityPrincipal",
                "objectSid": sid
            })
            self.fail("No exception should get ERR_UNWILLING_TO_PERFORM")
        except ldb.LdbError as err:
            (code, text) = err.args
            self.assertEqual(code, ldb.ERR_UNWILLING_TO_PERFORM, str(err))
            werr = "%08X" % werror.WERR_DS_ILLEGAL_MOD_OPERATION
            self.assertTrue(werr in text, text)

        #
        # The provision control is needed in order to add
        # foreignSecurityPrincipal objects.
        #
        controls = ["provision:0"]
        self.samdb.add(
            {
                "dn": dn,
                "objectClass": "foreignSecurityPrincipal"
            },
            controls=controls)

        self.samdb.delete(dn)

        # Adding the same object again must succeed.
        try:
            self.samdb.add(
                {
                    "dn": dn,
                    "objectClass": "foreignSecurityPrincipal"
                },
                controls=controls)
        except ldb.LdbError as err:
            (code, text) = err.args
            self.fail("Got unexpected exception %d - %s " % (code, text))

        # cleanup
        self.samdb.delete(dn)

    def _test_foreignSecurityPrincipal(self, obj_class, fpo_attr):
        """Check how <SID=...> values are handled when added to fpo_attr.

        An object of class obj_class is created as cn=dsdb_test_fpo under
        cn=Users, and three SID references are added to fpo_attr in turn:

        * a SID below the local domain SID, expected to fail with
          ERR_UNWILLING_TO_PERFORM / WERR_DS_INVALID_GROUP_TYPE;
        * a SID below S-1-5-32, expected to fail with
          ERR_NO_SUCH_OBJECT / WERR_NO_SUCH_MEMBER;
        * S-1-5-4294967294, expected to succeed and to leave exactly one
          object findable by that objectSid, which is then deleted.

        :param obj_class: objectClass of the test object.
        :param fpo_attr: attribute the SID references are added to.
        """
        dom_sid = self.samdb.get_domain_sid()
        lsid_str = str(dom_sid) + "-4294967294"
        bsid_str = "S-1-5-32-4294967294"
        fsid_str = "S-1-5-4294967294"
        basedn = self.samdb.get_default_basedn()
        cn = "dsdb_test_fpo"
        dn_str = "cn=%s,cn=Users,%s" % (cn, basedn)
        dn = ldb.Dn(self.samdb, dn_str)

        def search_by_sid(sid_str):
            # Subtree search below basedn for objects carrying sid_str.
            return self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                     base=basedn,
                                     expression="(objectSid=%s)" % sid_str,
                                     attrs=[])

        def add_sid_request(sid_str):
            # Build a modify request adding <SID=sid_str> to fpo_attr.
            request = ldb.Message()
            request.dn = dn
            request[fpo_attr] = ldb.MessageElement("<SID=%s>" % sid_str,
                                                   ldb.FLAG_MOD_ADD, fpo_attr)
            return request

        # None of the three SIDs may exist before the test starts.
        for sid_str in (lsid_str, bsid_str, fsid_str):
            self.assertEqual(len(search_by_sid(sid_str)), 0)

        self.addCleanup(delete_force, self.samdb, dn_str)

        self.samdb.add({"dn": dn_str, "objectClass": obj_class})

        try:
            self.samdb.modify(add_sid_request(lsid_str))
            self.fail("No exception should get LDB_ERR_UNWILLING_TO_PERFORM")
        except ldb.LdbError as e:
            (code, estr) = e.args
            self.assertEqual(code, ldb.ERR_UNWILLING_TO_PERFORM, str(e))
            werr = "%08X" % werror.WERR_DS_INVALID_GROUP_TYPE
            self.assertTrue(werr in estr, estr)

        try:
            self.samdb.modify(add_sid_request(bsid_str))
            self.fail("No exception should get LDB_ERR_NO_SUCH_OBJECT")
        except ldb.LdbError as e:
            (code, estr) = e.args
            self.assertEqual(code, ldb.ERR_NO_SUCH_OBJECT, str(e))
            werr = "%08X" % werror.WERR_NO_SUCH_MEMBER
            self.assertTrue(werr in estr, estr)

        try:
            self.samdb.modify(add_sid_request(fsid_str))
        except ldb.LdbError as e:
            self.fail("Should have not raised an exception: %s" % e)

        # The successful add must have produced exactly one object with
        # that SID; remove it and the test object again.
        res = search_by_sid(fsid_str)
        self.assertEqual(len(res), 1)
        self.samdb.delete(res[0].dn)
        self.samdb.delete(dn)
        self.assertEqual(len(search_by_sid(fsid_str)), 0)

    def test_foreignSecurityPrincipal_member(self):
        """Run the SID reference checks for group objects and member."""
        return self._test_foreignSecurityPrincipal(obj_class="group",
                                                   fpo_attr="member")

    def test_foreignSecurityPrincipal_MembersForAzRole(self):
        """Run the SID reference checks for msDS-AzRole / msDS-MembersForAzRole."""
        return self._test_foreignSecurityPrincipal(
            obj_class="msDS-AzRole",
            fpo_attr="msDS-MembersForAzRole")

    def test_foreignSecurityPrincipal_NeverRevealGroup(self):
        """Run the SID reference checks for computer / msDS-NeverRevealGroup."""
        return self._test_foreignSecurityPrincipal(
            obj_class="computer",
            fpo_attr="msDS-NeverRevealGroup")

    def test_foreignSecurityPrincipal_RevealOnDemandGroup(self):
        """Run the SID reference checks for computer / msDS-RevealOnDemandGroup."""
        return self._test_foreignSecurityPrincipal(
            obj_class="computer",
            fpo_attr="msDS-RevealOnDemandGroup")

    def _test_fail_foreignSecurityPrincipal(self,
                                            obj_class,
                                            fpo_attr,
                                            msg_exp,
                                            lerr_exp,
                                            werr_exp,
                                            allow_reference=True):
        """Check that fpo_attr rejects every <SID=...> reference.

        Two objects of class obj_class are created under cn=Users.  Adding a
        SID reference to fpo_attr on the first one must fail with lerr_exp
        and a message containing werr_exp, for a SID below the local domain
        SID, a SID below S-1-5-32 and S-1-5-4294967294 alike.  Finally the
        DN of the second object is added to fpo_attr, which must succeed or
        fail depending on allow_reference.

        :param obj_class: objectClass of the two test objects.
        :param fpo_attr: attribute the references are added to.
        :param msg_exp: description of the expected error, used only in
            failure messages.
        :param lerr_exp: expected ldb error code.
        :param werr_exp: expected WERROR value; its "%08X" rendering must
            appear in the error string.
        :param allow_reference: whether adding a plain DN reference to the
            second object is expected to succeed.
        """
        dom_sid = self.samdb.get_domain_sid()
        lsid_str = str(dom_sid) + "-4294967294"
        bsid_str = "S-1-5-32-4294967294"
        fsid_str = "S-1-5-4294967294"
        basedn = self.samdb.get_default_basedn()
        cn1 = "dsdb_test_fpo1"
        dn1_str = "cn=%s,cn=Users,%s" % (cn1, basedn)
        dn1 = ldb.Dn(self.samdb, dn1_str)
        cn2 = "dsdb_test_fpo2"
        dn2_str = "cn=%s,cn=Users,%s" % (cn2, basedn)
        dn2 = ldb.Dn(self.samdb, dn2_str)

        # None of the three SIDs may exist before the test starts.
        for sid_str in (lsid_str, bsid_str, fsid_str):
            res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                    base=basedn,
                                    expression="(objectSid=%s)" % sid_str,
                                    attrs=[])
            self.assertEqual(len(res), 0)

        self.addCleanup(delete_force, self.samdb, dn1_str)
        self.addCleanup(delete_force, self.samdb, dn2_str)

        self.samdb.add({"dn": dn1_str, "objectClass": obj_class})

        self.samdb.add({"dn": dn2_str, "objectClass": obj_class})

        # Every SID reference must be rejected with the expected error.
        for sid_str in (lsid_str, bsid_str, fsid_str):
            request = ldb.Message()
            request.dn = dn1
            request[fpo_attr] = ldb.MessageElement("<SID=%s>" % sid_str,
                                                   ldb.FLAG_MOD_ADD, fpo_attr)
            try:
                self.samdb.modify(request)
                self.fail("No exception should get %s" % msg_exp)
            except ldb.LdbError as e:
                (code, estr) = e.args
                self.assertEqual(code, lerr_exp, str(e))
                werr = "%08X" % werr_exp
                self.assertTrue(werr in estr, estr)

        # A plain DN reference to the second object.
        request = ldb.Message()
        request.dn = dn1
        request[fpo_attr] = ldb.MessageElement("%s" % dn2, ldb.FLAG_MOD_ADD,
                                               fpo_attr)
        try:
            self.samdb.modify(request)
            if not allow_reference:
                self.fail("No exception should get %s" % msg_exp)
        except ldb.LdbError as e:
            if allow_reference:
                self.fail("Should have not raised an exception: %s" % e)
            (code, estr) = e.args
            self.assertEqual(code, lerr_exp, str(e))
            werr = "%08X" % werr_exp
            self.assertTrue(werr in estr, estr)

        self.samdb.delete(dn2)
        self.samdb.delete(dn1)

    def test_foreignSecurityPrincipal_NonMembers(self):
        """msDS-NonMembers on a group must reject SID and DN references."""
        expected = "LDB_ERR_UNWILLING_TO_PERFORM/WERR_NOT_SUPPORTED"
        return self._test_fail_foreignSecurityPrincipal(
            obj_class="group",
            fpo_attr="msDS-NonMembers",
            msg_exp=expected,
            lerr_exp=ldb.ERR_UNWILLING_TO_PERFORM,
            werr_exp=werror.WERR_NOT_SUPPORTED,
            allow_reference=False)

    def test_foreignSecurityPrincipal_HostServiceAccount(self):
        """msDS-HostServiceAccount on a computer must reject SID references."""
        expected = "LDB_ERR_CONSTRAINT_VIOLATION/WERR_DS_NAME_REFERENCE_INVALID"
        return self._test_fail_foreignSecurityPrincipal(
            obj_class="computer",
            fpo_attr="msDS-HostServiceAccount",
            msg_exp=expected,
            lerr_exp=ldb.ERR_CONSTRAINT_VIOLATION,
            werr_exp=werror.WERR_DS_NAME_REFERENCE_INVALID)

    def test_foreignSecurityPrincipal_manager(self):
        """manager on a user must reject SID references."""
        expected = "LDB_ERR_CONSTRAINT_VIOLATION/WERR_DS_NAME_REFERENCE_INVALID"
        return self._test_fail_foreignSecurityPrincipal(
            obj_class="user",
            fpo_attr="manager",
            msg_exp=expected,
            lerr_exp=ldb.ERR_CONSTRAINT_VIOLATION,
            werr_exp=werror.WERR_DS_NAME_REFERENCE_INVALID)

    #
    # Duplicate objectSIDs must not be permitted for SIDs in the local
    # domain. The test sequence is: add an object, delete it, then attempt
    # to re-add it; the re-add should fail with a constraint violation.
    #
    def test_duplicate_objectSIDs_not_allowed_on_local_objects(self):
        """Re-adding a deleted object with the same local objectSID must fail.

        A user is added with an explicit objectSID built from the domain SID
        and a freshly allocated RID, then deleted.  Adding the same DN with
        the same objectSID again is expected to raise an LdbError with
        LDB_ERR_CONSTRAINT_VIOLATION.
        """
        dom_sid = self.samdb.get_domain_sid()
        rid = self.allocate_rid()
        sid_str = str(dom_sid) + "-" + rid
        sid = ndr_pack(security.dom_sid(sid_str))
        basedn = self.samdb.get_default_basedn()
        cn = "dsdb_test_01"
        dn = "cn=%s,cn=Users,%s" % (cn, basedn)

        # Do not leave the object behind if the delete below fails or if the
        # second add unexpectedly succeeds; delete_force is used the same
        # way by other tests in this class.
        self.addCleanup(delete_force, self.samdb, dn)

        self.samdb.add({"dn": dn, "objectClass": "user", "objectSID": sid})
        self.samdb.delete(dn)

        try:
            self.samdb.add({"dn": dn, "objectClass": "user", "objectSID": sid})
            # self.fail() raises the test failure exception, which the
            # LdbError handler below does not catch.
            self.fail("No exception should get LDB_ERR_CONSTRAINT_VIOLATION")
        except ldb.LdbError as e:
            (code, msg) = e.args
            if code != ldb.ERR_CONSTRAINT_VIOLATION:
                self.fail("Got %d - %s should have got "
                          "LDB_ERR_CONSTRAINT_VIOLATION" % (code, msg))

    def test_linked_vs_non_linked_reference(self):
        """Compare dangling references on a linked and a non-linked attribute.

        Two users are created and one of them is deleted again.  Then:

        * adding a reference to the deleted user (by <SID=...> or
          <GUID=...>) to the linked attribute 'manager' must fail with
          LDB_ERR_CONSTRAINT_VIOLATION / WERR_DS_NAME_REFERENCE_INVALID;
        * the same references on the non-linked attribute 'assistant' must
          be accepted (and can be removed again);
        * references on 'assistant' to a DN, SID or GUID that never existed
          must be rejected with the same error.
        """
        basedn = self.samdb.get_default_basedn()
        kept_dn_str = "cn=reference_kept,cn=Users,%s" % (basedn)
        removed_dn_str = "cn=reference_removed,cn=Users,%s" % (basedn)
        dom_sid = self.samdb.get_domain_sid()
        none_sid_str = str(dom_sid) + "-4294967294"
        none_guid_str = "afafafaf-fafa-afaf-fafa-afafafafafaf"

        self.addCleanup(delete_force, self.samdb, kept_dn_str)
        self.addCleanup(delete_force, self.samdb, removed_dn_str)

        self.samdb.add({"dn": kept_dn_str, "objectClass": "user"})
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=kept_dn_str,
                                attrs=["objectGUID", "objectSID"])
        self.assertEqual(len(res), 1)
        kept_dn = res[0].dn

        # Remember the identifiers of the second user, then delete it so
        # that they refer to a deleted object.
        self.samdb.add({"dn": removed_dn_str, "objectClass": "user"})
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=removed_dn_str,
                                attrs=["objectGUID", "objectSID"])
        self.assertEqual(len(res), 1)
        removed_guid = ndr_unpack(misc.GUID, res[0]["objectGUID"][0])
        removed_sid = ndr_unpack(security.dom_sid, res[0]["objectSid"][0])
        self.samdb.delete(removed_dn_str)

        removed_sid_ref = "<SID=%s>" % removed_sid
        removed_guid_ref = "<GUID=%s>" % removed_guid

        def modify_reference(attr, value, flags):
            # Apply a single-value modification of `attr` on the kept user.
            m = ldb.Message()
            m.dn = kept_dn
            m[attr] = ldb.MessageElement(value, flags, attr)
            self.samdb.modify(m)

        def assert_reference_rejected(attr, value):
            # Adding `value` to `attr` must fail with a constraint violation
            # whose message contains WERR_DS_NAME_REFERENCE_INVALID.
            try:
                modify_reference(attr, value, ldb.FLAG_MOD_ADD)
                self.fail("No exception should get "
                          "LDB_ERR_CONSTRAINT_VIOLATION")
            except ldb.LdbError as e:
                (code, errstr) = e.args
                self.assertEqual(code, ldb.ERR_CONSTRAINT_VIOLATION, str(e))
                werr = "%08X" % werror.WERR_DS_NAME_REFERENCE_INVALID
                self.assertIn(werr, errstr, errstr)

        #
        # First try the linked attribute 'manager'
        # by SID and GUID
        #
        assert_reference_rejected("manager", removed_sid_ref)
        assert_reference_rejected("manager", removed_guid_ref)

        #
        # Try the non-linked attribute 'assistant'
        # by SID and GUID, which should work.
        #
        modify_reference("assistant", removed_sid_ref, ldb.FLAG_MOD_ADD)
        modify_reference("assistant", removed_sid_ref, ldb.FLAG_MOD_DELETE)

        modify_reference("assistant", removed_guid_ref, ldb.FLAG_MOD_ADD)
        modify_reference("assistant", removed_guid_ref, ldb.FLAG_MOD_DELETE)

        #
        # Finally try the non-linked attribute 'assistant'
        # but with non-existing DN, SID and GUID
        #
        assert_reference_rejected("assistant", "CN=NoneNone,%s" % (basedn))
        assert_reference_rejected("assistant", "<SID=%s>" % none_sid_str)
        assert_reference_rejected("assistant", "<GUID=%s>" % none_guid_str)

        self.samdb.delete(kept_dn)

    def test_normalize_dn_in_domain_full(self):
        """Check normalize_dn_in_domain() on a full DN passed as a string.

        The result must compare equal to the Dn the string was built from.
        """
        domain_dn = self.samdb.domain_dn()

        # Build CN=Users,<domain dn> as a Dn object.
        expected_dn = ldb.Dn(self.samdb, "CN=Users")
        expected_dn.add_base(domain_dn)

        # That is, no change
        normalized = self.samdb.normalize_dn_in_domain(str(expected_dn))
        self.assertEqual(expected_dn, normalized)

    def test_normalize_dn_in_domain_part(self):
        """Check normalize_dn_in_domain() on a partial DN passed as a string.

        The result must compare equal to the partial DN with the domain DN
        added as its base.
        """
        domain_dn = self.samdb.domain_dn()
        relative_str = "CN=Users"

        expected_dn = ldb.Dn(self.samdb, relative_str)
        expected_dn.add_base(domain_dn)

        # That is, the domain DN appended
        normalized = self.samdb.normalize_dn_in_domain(relative_str)
        self.assertEqual(expected_dn, normalized)

    def test_normalize_dn_in_domain_full_dn(self):
        """Check normalize_dn_in_domain() on a full DN passed as a Dn object.

        The result must compare equal to the Dn that was passed in.
        """
        domain_dn = self.samdb.domain_dn()

        # Build CN=Users,<domain dn> as a Dn object.
        expected_dn = ldb.Dn(self.samdb, "CN=Users")
        expected_dn.add_base(domain_dn)

        # That is, no change
        normalized = self.samdb.normalize_dn_in_domain(expected_dn)
        self.assertEqual(expected_dn, normalized)

    def test_normalize_dn_in_domain_part_dn(self):
        """Check normalize_dn_in_domain() on a partial DN passed as a Dn.

        The result must compare equal to a Dn built from the partial DN
        followed by the domain DN.
        """
        domain_dn = self.samdb.domain_dn()
        relative_dn = ldb.Dn(self.samdb, "CN=Users")

        # That is, the domain DN appended
        expected_dn = ldb.Dn(self.samdb,
                             "%s,%s" % (relative_dn, domain_dn))
        normalized = self.samdb.normalize_dn_in_domain(relative_dn)
        self.assertEqual(expected_dn, normalized)
Exemple #47
0
class LargeLDAPTest(samba.tests.TestCase):
    """Iterator searches over LDAP whose replies exceed the size limit.

    setUp creates a randomly named OU holding 200 users (numbered 000-199),
    each with a 2 MiB jpegPhoto value, so that requesting jpegPhoto for all
    of them produces a very large result.
    """

    def setUp(self):
        """Connect to the server and create the OU with its 200 users."""
        super(LargeLDAPTest, self).setUp()
        self.ldb = SamDB(url,
                         credentials=creds,
                         session_info=system_session(lp),
                         lp=lp)
        self.base_dn = self.ldb.domain_dn()
        self.USER_NAME = "large_user" + format(random.randint(0, 99999),
                                               "05") + "-"
        self.OU_NAME = "large_user_ou" + format(random.randint(0, 99999), "05")
        self.ou_dn = ldb.Dn(self.ldb,
                            "ou=" + self.OU_NAME + "," + str(self.base_dn))

        # Start from a clean slate in case a previous run left the OU behind.
        samba.tests.delete_force(self.ldb,
                                 self.ou_dn,
                                 controls=['tree_delete:1'])

        self.ldb.add({
            "dn": self.ou_dn,
            "objectclass": "organizationalUnit",
            "ou": self.OU_NAME
        })

        for x in range(200):
            user_name = self.USER_NAME + format(x, "03")
            self.ldb.add({
                "dn": "cn=" + user_name + "," + str(self.ou_dn),
                "objectclass": "user",
                "sAMAccountName": user_name,
                "jpegPhoto": b'a' * (2 * 1024 * 1024)
            })

    def tearDown(self):
        """Reconnect and delete the OU together with everything below it."""
        # Remake the connection for tear-down (old Samba drops the socket)
        self.ldb = SamDB(url,
                         credentials=creds,
                         session_info=system_session(lp),
                         lp=lp)
        samba.tests.delete_force(self.ldb,
                                 self.ou_dn,
                                 controls=['tree_delete:1'])

    def test_unindexed_iterator_search(self):
        """Testing an unindexed search that will break the result size limit"""
        if not url.startswith("ldap"):
            self.fail(msg="This test is only valid on ldap")

        # Without jpegPhoto all 200 users must be returned.
        count = 0
        search1 = self.ldb.search_iterator(
            base=self.ou_dn,
            expression="(sAMAccountName=" + self.USER_NAME + "*)",
            scope=ldb.SCOPE_SUBTREE,
            attrs=["objectGUID", "samAccountName"])

        for reply in search1:
            self.assertIsInstance(reply, ldb.Message)
            count += 1

        search1.result()

        self.assertEqual(count, 200)

        # Now try breaking the 256MB limit

        count_jpeg = 0
        search1 = self.ldb.search_iterator(
            base=self.ou_dn,
            expression="(sAMAccountName=" + self.USER_NAME + "*)",
            scope=ldb.SCOPE_SUBTREE,
            attrs=["objectGUID", "samAccountName", "jpegPhoto"])
        try:
            for reply in search1:
                self.assertIsInstance(reply, ldb.Message)
                count_jpeg += 1
        except ldb.LdbError as e:
            # If the search is cut short it must be because of the size
            # limit and nothing else.
            self.assertEqual(e.args[0], ldb.ERR_SIZE_LIMIT_EXCEEDED)

        # Assert we don't get all the entries but still the error
        self.assertGreater(count, count_jpeg)

        # Now fetch only the users whose number starts with "1", i.e. the
        # 100 users 100-199 (server will do some chunking for this)

        count_jpeg2 = 0
        try:
            search1 = self.ldb.search_iterator(
                base=self.ou_dn,
                expression="(sAMAccountName=" + self.USER_NAME + "1*)",
                scope=ldb.SCOPE_SUBTREE,
                attrs=["objectGUID", "samAccountName", "jpegPhoto"])
        except ldb.LdbError as e:
            self.fail(e.args[1])

        for reply in search1:
            self.assertIsInstance(reply, ldb.Message)
            count_jpeg2 += 1

        # Assert we got some entries
        self.assertEqual(count_jpeg2, 100)

    def test_iterator_search(self):
        """Testing an indexed search that will break the result size limit"""
        if not url.startswith("ldap"):
            self.fail(msg="This test is only valid on ldap")

        # Baseline: count the entries returned without jpegPhoto.
        count = 0
        search1 = self.ldb.search_iterator(
            base=self.ou_dn,
            expression="(&(objectClass=user)(sAMAccountName=" +
            self.USER_NAME + "*))",
            scope=ldb.SCOPE_SUBTREE,
            attrs=["objectGUID", "samAccountName"])

        for reply in search1:
            self.assertIsInstance(reply, ldb.Message)
            count += 1
        search1.result()

        # Now try breaking the 256MB limit

        count_jpeg = 0
        search1 = self.ldb.search_iterator(
            base=self.ou_dn,
            expression="(&(objectClass=user)(sAMAccountName=" +
            self.USER_NAME + "*))",
            scope=ldb.SCOPE_SUBTREE,
            attrs=["objectGUID", "samAccountName", "jpegPhoto"])
        try:
            for reply in search1:
                self.assertIsInstance(reply, ldb.Message)
                # Must accumulate: the comparison below relies on the real
                # number of entries received before the error.
                count_jpeg += 1
        except ldb.LdbError as e:
            self.assertEqual(e.args[0], ldb.ERR_SIZE_LIMIT_EXCEEDED)

        # Assert we don't get all the entries but still the error
        self.assertGreater(count, count_jpeg)
class SubtreeRenameTests(samba.tests.TestCase):
    """Linked-attribute checks around renaming a whole OU subtree.

    setUp creates OU=subtree1 and OU=subtree2 below the domain DN.  The
    tests populate them with users, groups and computers, link the objects
    together, rename OU=subtree1 to OU=subtree3 and (in most tests) verify
    the forward and back links afterwards.
    """

    def delete_ous(self):
        """Best-effort removal of the three test OUs and their contents."""
        for ou in (self.ou1, self.ou2, self.ou3):
            try:
                self.samdb.delete(ou, ['tree_delete:1'])
            except ldb.LdbError:
                # Deliberately ignored: not every OU exists at every point
                # (setUp never creates ou3; it only appears via a rename).
                pass

    def setUp(self):
        """Connect to the server and create the two starting OUs."""
        super(SubtreeRenameTests, self).setUp()
        self.samdb = SamDB(host, credentials=creds,
                           session_info=system_session(lp), lp=lp)

        self.base_dn = self.samdb.domain_dn()
        self.ou1 = "OU=subtree1,%s" % self.base_dn
        self.ou2 = "OU=subtree2,%s" % self.base_dn
        self.ou3 = "OU=subtree3,%s" % self.base_dn
        if opts.delete_in_setup:
            self.delete_ous()
        self.samdb.add({'objectclass': 'organizationalUnit',
                        'dn': self.ou1})
        self.samdb.add({'objectclass': 'organizationalUnit',
                        'dn': self.ou2})

        debug(colour.c_REV_RED(self.id()))

    def tearDown(self):
        """Remove the test OUs unless cleanup was disabled via options."""
        super(SubtreeRenameTests, self).tearDown()
        if not opts.no_cleanup:
            self.delete_ous()

    def add_object(self, cn, objectclass, ou=None, more_attrs=None):
        """Add one object named CN=<cn> below `ou` and return its DN string.

        `more_attrs`, if given, is a mapping of extra attributes merged
        into the add request.
        """
        dn = "CN=%s,%s" % (cn, ou)
        attrs = {'cn': cn,
                 'objectclass': objectclass,
                 'dn': dn}
        # None rather than {} as default avoids sharing one dict between
        # calls.
        if more_attrs is not None:
            attrs.update(more_attrs)
        self.samdb.add(attrs)

        return dn

    def add_objects(self, n, objectclass, prefix=None, ou=None,
                    more_attrs=None):
        """Add `n` objects named <prefix>1 .. <prefix>n; return their DNs.

        `prefix` defaults to the object class name.
        """
        if prefix is None:
            prefix = objectclass
        return [self.add_object("%s%d" % (prefix, i + 1),
                                objectclass,
                                more_attrs=more_attrs,
                                ou=ou)
                for i in range(n)]

    def add_linked_attribute(self, src, dest, attr='member',
                             controls=None):
        """Add the value `dest` to attribute `attr` of object `src`."""
        m = ldb.Message()
        m.dn = ldb.Dn(self.samdb, src)
        m[attr] = ldb.MessageElement(dest, ldb.FLAG_MOD_ADD, attr)
        self.samdb.modify(m, controls=controls)

    def remove_linked_attribute(self, src, dest, attr='member',
                                controls=None):
        """Delete the value `dest` from attribute `attr` of object `src`."""
        m = ldb.Message()
        m.dn = ldb.Dn(self.samdb, src)
        m[attr] = ldb.MessageElement(dest, ldb.FLAG_MOD_DELETE, attr)
        self.samdb.modify(m, controls=controls)

    @staticmethod
    def _binary_link_value(dest, binary):
        """Return the 'B:<len>:<HEX>:<dn>' string for a binary link.

        <HEX> is the upper-case hex encoding of str(binary) in UTF-8 and
        <len> is the number of hex characters.
        """
        b = hexlify(str(binary).encode('utf-8')).decode('utf-8').upper()
        return 'B:%d:%s:%s' % (len(b), b, dest)

    def add_binary_link(self, src, dest, binary,
                        attr='msDS-RevealedUsers',
                        controls=None):
        """Add a binary link from `src` to `dest`; return the stored value."""
        dest = self._binary_link_value(dest, binary)
        self.add_linked_attribute(src, dest, attr, controls)
        return dest

    def remove_binary_link(self, src, dest, binary,
                           attr='msDS-RevealedUsers',
                           controls=None):
        """Remove a binary link previously added with add_binary_link()."""
        # Must build exactly the value add_binary_link() stored, including
        # the length field and a text (not bytes) hex string.
        dest = self._binary_link_value(dest, binary)
        self.remove_linked_attribute(src, dest, attr, controls)

    def replace_linked_attribute(self, src, dest, attr='member',
                                 controls=None):
        """Replace attribute `attr` of object `src` with `dest`."""
        m = ldb.Message()
        m.dn = ldb.Dn(self.samdb, src)
        m[attr] = ldb.MessageElement(dest, ldb.FLAG_MOD_REPLACE, attr)
        self.samdb.modify(m, controls=controls)

    def attr_search(self, obj, attr, scope=ldb.SCOPE_BASE, **controls):
        """Search `obj` for `attr`; keyword args become 'name:int' controls."""
        controls = ['%s:%d' % (k, int(v)) for k, v in controls.items()]

        res = self.samdb.search(obj,
                                scope=scope,
                                attrs=[attr],
                                controls=controls)
        return res

    def assert_links(self, obj, expected, attr, msg='', **kwargs):
        """Assert that `attr` on `obj` holds exactly the `expected` values.

        An empty `expected` means the attribute must be absent.  Order is
        ignored; duplicates are significant.
        """
        res = self.attr_search(obj, attr, **kwargs)

        if len(expected) == 0:
            if attr in res[0]:
                self.fail("found attr '%s' in %s" % (attr, res[0]))
            return

        try:
            results = [str(x) for x in res[0][attr]]
        except KeyError:
            self.fail("missing attr '%s' on %s" % (attr, obj))

        expected = sorted(expected)
        results = sorted(results)

        if expected != results:
            # Emit a readable diff before the assertion below fails.
            debug(msg)
            debug("expected %s" % expected)
            debug("received %s" % results)
            debug("missing    %s" % (sorted(set(expected) - set(results))))
            debug("unexpected %s" % (sorted(set(results) - set(expected))))

        self.assertEqual(results, expected)

    def assert_back_links(self, obj, expected, attr='memberOf', **kwargs):
        """assert_links() with a back-link message; `attr` defaults to
        memberOf."""
        self.assert_links(obj, expected, attr=attr,
                          msg='%s back links do not match for %s' %
                          (attr, obj),
                          **kwargs)

    def assert_forward_links(self, obj, expected, attr='member', **kwargs):
        """assert_links() with a forward-link message; `attr` defaults to
        member."""
        self.assert_links(obj, expected, attr=attr,
                          msg='%s forward links do not match for %s' %
                          (attr, obj),
                          **kwargs)

    def get_object_guid(self, dn):
        """Return the objectGUID of `dn` as a string."""
        res = self.samdb.search(dn,
                                scope=ldb.SCOPE_BASE,
                                attrs=['objectGUID'])
        return str(misc.GUID(res[0]['objectGUID'][0]))

    def assertRaisesLdbError(self, errcode, message, f, *args, **kwargs):
        """Assert a function raises a particular LdbError."""
        # Reverse lookup from numeric error code to its ldb.ERR_* name, used
        # only to make failure messages readable.
        lut = {v: k for k, v in vars(ldb).items()
               if k.startswith('ERR_') and isinstance(v, int)}
        try:
            f(*args, **kwargs)
        except ldb.LdbError as e:
            (num, msg) = e.args
            if num != errcode:
                self.fail("%s, expected "
                          "LdbError %s, (%d) "
                          "got %s (%d) "
                          "%s" % (message,
                                  lut.get(errcode), errcode,
                                  lut.get(num), num,
                                  msg))
        else:
            self.fail("%s, expected "
                      "LdbError %s, (%d) "
                      "but we got success" % (message,
                                              lut.get(errcode),
                                              errcode))

    def test_la_move_ou_tree(self):
        """Rename an OU that holds users, groups and computers alike."""
        tag = 'move_tree'

        u1, u2 = self.add_objects(2, 'user', '%s_u_' % tag, ou=self.ou1)
        g1, g2 = self.add_objects(2, 'group', '%s_g_' % tag, ou=self.ou1)
        c1, c2, c3 = self.add_objects(3, 'computer',
                                      '%s_c_' % tag,
                                      ou=self.ou1)

        self.add_linked_attribute(g1, u1)
        self.add_linked_attribute(g1, g2)
        self.add_linked_attribute(g2, u1)
        self.add_linked_attribute(g2, u2)
        # The returned link values are rewritten to the post-rename OU so
        # they can be compared with what the server reports afterwards.
        c1u1 = self.add_binary_link(c1, u1, 'a').replace(self.ou1, self.ou3)
        c2u1 = self.add_binary_link(c2, u1, 'b').replace(self.ou1, self.ou3)
        c3u1 = self.add_binary_link(c3, u1, 124.543).replace(self.ou1, self.ou3)
        c1g1 = self.add_binary_link(c1, g1, 'd').replace(self.ou1, self.ou3)
        c2g2 = self.add_binary_link(c2, g2, 'd').replace(self.ou1, self.ou3)
        c2c1 = self.add_binary_link(c2, c1, 'd').replace(self.ou1, self.ou3)
        c1u2 = self.add_binary_link(c1, u2, 'd').replace(self.ou1, self.ou3)
        c1u1_2 = self.add_binary_link(c1, u1, 'b').replace(self.ou1, self.ou3)

        # Adding the identical link a second time must be refused.
        self.assertRaisesLdbError(ldb.ERR_ATTRIBUTE_OR_VALUE_EXISTS,
                                  "Attribute msDS-RevealedUsers already exists",
                                  self.add_binary_link, c1, u2, 'd')

        self.samdb.rename(self.ou1, self.ou3)
        debug(colour.c_CYAN("rename FINISHED"))
        u1, u2, g1, g2, c1, c2, c3 = [x.replace(self.ou1, self.ou3)
                                      for x in (u1, u2, g1, g2, c1, c2, c3)]

        self.samdb.delete(g2, ['tree_delete:1'])

        self.assert_forward_links(g1, [u1])
        self.assert_back_links(u1, [g1])
        self.assert_back_links(u2, set())
        self.assert_forward_links(c1, [c1u1, c1u1_2, c1u2, c1g1],
                                  attr='msDS-RevealedUsers')
        self.assert_forward_links(c2, [c2u1, c2c1], attr='msDS-RevealedUsers')
        self.assert_forward_links(c3, [c3u1], attr='msDS-RevealedUsers')
        self.assert_back_links(u1, [c1, c1, c2, c3], attr='msDS-RevealedDSAs')
        self.assert_back_links(u2, [c1], attr='msDS-RevealedDSAs')
        self.assert_back_links(g1, [c1], attr='msDS-RevealedDSAs')
        self.assert_back_links(c1, [c2], attr='msDS-RevealedDSAs')

    def test_la_move_ou_groups(self):
        """Rename the OU holding groups and computers; users live in ou2."""
        tag = 'move_groups'

        u1, u2 = self.add_objects(2, 'user', '%s_u_' % tag, ou=self.ou2)
        g1, g2 = self.add_objects(2, 'group', '%s_g_' % tag, ou=self.ou1)
        c1, c2, c3 = self.add_objects(3, 'computer',
                                      '%s_c_' % tag,
                                      ou=self.ou1)

        self.add_linked_attribute(g1, u1)
        self.add_linked_attribute(g1, g2)
        self.add_linked_attribute(g2, u1)
        self.add_linked_attribute(g2, u2)
        c1u1 = self.add_binary_link(c1, u1, 'a').replace(self.ou1, self.ou3)
        c2u1 = self.add_binary_link(c2, u1, 'b').replace(self.ou1, self.ou3)
        c3u1 = self.add_binary_link(c3, u1, 124.543).replace(self.ou1, self.ou3)
        c1g1 = self.add_binary_link(c1, g1, 'd').replace(self.ou1, self.ou3)
        c2g2 = self.add_binary_link(c2, g2, 'd').replace(self.ou1, self.ou3)
        c2c1 = self.add_binary_link(c2, c1, 'd').replace(self.ou1, self.ou3)
        c1u2 = self.add_binary_link(c1, u2, 'd').replace(self.ou1, self.ou3)
        c1u1_2 = self.add_binary_link(c1, u1, 'b').replace(self.ou1, self.ou3)

        self.samdb.rename(self.ou1, self.ou3)
        debug(colour.c_CYAN("rename FINISHED"))
        u1, u2, g1, g2, c1, c2, c3 = [x.replace(self.ou1, self.ou3)
                                      for x in (u1, u2, g1, g2, c1, c2, c3)]

        self.samdb.delete(g2, ['tree_delete:1'])

        self.assert_forward_links(g1, [u1])
        self.assert_back_links(u1, [g1])
        self.assert_back_links(u2, set())
        self.assert_forward_links(c1, [c1u1, c1u1_2, c1u2, c1g1],
                                  attr='msDS-RevealedUsers')
        self.assert_forward_links(c2, [c2u1, c2c1], attr='msDS-RevealedUsers')
        self.assert_forward_links(c3, [c3u1], attr='msDS-RevealedUsers')
        self.assert_back_links(u1, [c1, c1, c2, c3], attr='msDS-RevealedDSAs')
        self.assert_back_links(u2, [c1], attr='msDS-RevealedDSAs')
        self.assert_back_links(g1, [c1], attr='msDS-RevealedDSAs')
        self.assert_back_links(c1, [c2], attr='msDS-RevealedDSAs')

    def test_la_move_ou_users(self):
        """Rename the OU holding users and computers; groups live in ou2."""
        tag = 'move_users'

        u1, u2 = self.add_objects(2, 'user', '%s_u_' % tag, ou=self.ou1)
        g1, g2 = self.add_objects(2, 'group', '%s_g_' % tag, ou=self.ou2)
        c1, c2 = self.add_objects(2, 'computer', '%s_c_' % tag, ou=self.ou1)

        self.add_linked_attribute(g1, u1)
        self.add_linked_attribute(g1, g2)
        self.add_linked_attribute(g2, u1)
        self.add_linked_attribute(g2, u2)
        c1u1 = self.add_binary_link(c1, u1, 'a').replace(self.ou1, self.ou3)
        c2u1 = self.add_binary_link(c2, u1, 'b').replace(self.ou1, self.ou3)
        c1g1 = self.add_binary_link(c1, g1, 'd').replace(self.ou1, self.ou3)
        c2g2 = self.add_binary_link(c2, g2, 'd').replace(self.ou1, self.ou3)
        c2c1 = self.add_binary_link(c2, c1, 'd').replace(self.ou1, self.ou3)
        c1u2 = self.add_binary_link(c1, u2, 'd').replace(self.ou1, self.ou3)
        c1u1_2 = self.add_binary_link(c1, u1, 'b').replace(self.ou1, self.ou3)

        self.samdb.rename(self.ou1, self.ou3)
        debug(colour.c_CYAN("rename FINISHED"))
        u1, u2, g1, g2, c1, c2 = [x.replace(self.ou1, self.ou3)
                                  for x in (u1, u2, g1, g2, c1, c2)]

        self.samdb.delete(g2, ['tree_delete:1'])

        self.assert_forward_links(g1, [u1])
        self.assert_back_links(u1, [g1])
        self.assert_back_links(u2, set())
        self.assert_forward_links(c1, [c1u1, c1u1_2, c1u2, c1g1],
                                  attr='msDS-RevealedUsers')
        self.assert_forward_links(c2, [c2u1, c2c1], attr='msDS-RevealedUsers')
        self.assert_back_links(u1, [c1, c1, c2], attr='msDS-RevealedDSAs')
        self.assert_back_links(u2, [c1], attr='msDS-RevealedDSAs')
        self.assert_back_links(g1, [c1], attr='msDS-RevealedDSAs')
        self.assert_back_links(c1, [c2], attr='msDS-RevealedDSAs')

    def test_la_move_ou_noncomputers(self):
        """Here we are especially testing the msDS-RevealedDSAs links"""
        tag = 'move_noncomputers'

        u1, u2 = self.add_objects(2, 'user', '%s_u_' % tag, ou=self.ou1)
        g1, g2 = self.add_objects(2, 'group', '%s_g_' % tag, ou=self.ou1)
        c1, c2, c3 = self.add_objects(3, 'computer', '%s_c_' % tag, ou=self.ou2)

        self.add_linked_attribute(g1, u1)
        self.add_linked_attribute(g1, g2)
        c1u1 = self.add_binary_link(c1, u1, 'a').replace(self.ou1, self.ou3)
        c2u1 = self.add_binary_link(c2, u1, 'b').replace(self.ou1, self.ou3)
        c2u1_2 = self.add_binary_link(c2, u1, 'c').replace(self.ou1, self.ou3)
        c3u1 = self.add_binary_link(c3, g1, 'b').replace(self.ou1, self.ou3)
        c1g1 = self.add_binary_link(c1, g1, 'd').replace(self.ou1, self.ou3)
        c2g2 = self.add_binary_link(c2, g2, 'd').replace(self.ou1, self.ou3)
        c2c1 = self.add_binary_link(c2, c1, 'd').replace(self.ou1, self.ou3)
        c1u2 = self.add_binary_link(c1, u2, 'd').replace(self.ou1, self.ou3)
        c1u1_2 = self.add_binary_link(c1, u1, 'b').replace(self.ou1, self.ou3)
        c1u1_3 = self.add_binary_link(c1, u1, 'c').replace(self.ou1, self.ou3)
        c2u1_3 = self.add_binary_link(c2, u1, 'e').replace(self.ou1, self.ou3)
        c3u2 = self.add_binary_link(c3, u2, 'b').replace(self.ou1, self.ou3)

        self.samdb.rename(self.ou1, self.ou3)
        debug(colour.c_CYAN("rename FINISHED"))
        u1, u2, g1, g2, c1, c2, c3 = [x.replace(self.ou1, self.ou3)
                                      for x in (u1, u2, g1, g2, c1, c2, c3)]

        self.samdb.delete(c3, ['tree_delete:1'])

        self.assert_forward_links(g1, [g2, u1])
        self.assert_back_links(u1, [g1])
        self.assert_back_links(u2, [])
        self.assert_forward_links(c1, [c1u1, c1u1_2, c1u1_3, c1u2, c1g1],
                                  attr='msDS-RevealedUsers')
        self.assert_forward_links(c2, [c2u1, c2u1_2, c2u1_3, c2c1, c2g2],
                                  attr='msDS-RevealedUsers')
        self.assert_back_links(u1, [c1, c1, c1, c2, c2, c2],
                               attr='msDS-RevealedDSAs')
        self.assert_back_links(u2, [c1], attr='msDS-RevealedDSAs')
        self.assert_back_links(g1, [c1], attr='msDS-RevealedDSAs')
        self.assert_back_links(c1, [c2], attr='msDS-RevealedDSAs')

    def test_la_move_ou_tree_big(self):
        """Rename a larger, heavily linked OU and report how long it takes.

        Only timings are printed via debug(); no links are asserted.
        """
        tag = 'move_ou_big'
        USERS, GROUPS, COMPUTERS = 50, 10, 7

        users = self.add_objects(USERS, 'user', '%s_u_' % tag, ou=self.ou1)
        groups = self.add_objects(GROUPS, 'group', '%s_g_' % tag, ou=self.ou1)
        computers = self.add_objects(COMPUTERS, 'computer', '%s_c_' % tag,
                                     ou=self.ou1)

        start = time()
        for i, u in enumerate(users):
            # User i is linked to the first (i % GROUPS) groups and the
            # first (i % COMPUTERS) computers.
            for g in groups[:i % GROUPS]:
                self.add_linked_attribute(g, u)
            for c in computers[:i % COMPUTERS]:
                self.add_binary_link(c, u, 'a')

        debug("linking took %.3fs" % (time() - start))
        start = time()
        self.samdb.rename(self.ou1, self.ou3)
        debug("rename ou took %.3fs" % (time() - start))

        g1 = groups[0].replace(self.ou1, self.ou3)
        start = time()
        self.samdb.rename(g1, g1.replace(self.ou3, self.ou2))
        debug("rename group took %.3fs" % (time() - start))

        u1 = users[0].replace(self.ou1, self.ou3)
        start = time()
        self.samdb.rename(u1, u1.replace(self.ou3, self.ou2))
        debug("rename user took %.3fs" % (time() - start))

        c1 = computers[0].replace(self.ou1, self.ou3)
        start = time()
        self.samdb.rename(c1, c1.replace(self.ou3, self.ou2))
        debug("rename computer took %.3fs" % (time() - start))
Exemple #49
0
class PassWordHashTests(TestCase):
    def setUp(self):
        # Load the test environment's loadparm context first; add_user()
        # later applies per-test options to it and opens the database
        # with it.
        self.lp = samba.tests.env_loadparm()
        super(PassWordHashTests, self).setUp()

    def set_store_cleartext(self, cleartext):
        """Switch DOMAIN_PASSWORD_STORE_CLEARTEXT in pwdProperties.

        The current pwdProperties value is read from self.ldb, the flag is
        set when cleartext is true and cleared otherwise, and the result
        is written back as a string.
        """
        current = int(self.ldb.get_pwdProperties())
        if cleartext:
            updated = current | DOMAIN_PASSWORD_STORE_CLEARTEXT
        else:
            updated = current & ~DOMAIN_PASSWORD_STORE_CLEARTEXT
        self.ldb.set_pwdProperties(str(updated))

    # Add a user to ldb, this will exercise the password_hash code
    # and calculate the appropriate supplemental credentials
    def add_user(self, options=None, clear_text=False, ldb=None):
        """(Re)create the test user, exercising the password_hash code.

        options    -- optional iterable of (name, value) pairs that are
                      set on self.lp before the database is opened
        clear_text -- when true, clear-text password storage is enabled
                      for the domain (the previous pwdProperties value is
                      restored on cleanup) and allowed for the account
        ldb        -- optional database to use; when None a new SamDB is
                      opened with guessed credentials and a system session

        Sets self.ldb, self.netbios_domain, self.dns_domain and
        self.base_dn (and, when a database is opened here, self.creds and
        self.session) as side effects.
        """
        # set any needed options
        if options is not None:
            for (option, value) in options:
                self.lp.set(option, value)

        if ldb is None:
            self.creds = Credentials()
            self.session = system_session()
            self.creds.guess(self.lp)
            self.ldb = SamDB(session_info=self.session,
                             credentials=self.creds,
                             lp=self.lp)
        else:
            self.ldb = ldb

        res = self.ldb.search(base=self.ldb.get_config_basedn(),
                              expression="ncName=%s" %
                              self.ldb.get_default_basedn(),
                              attrs=["nETBIOSName"])
        self.netbios_domain = str(res[0]["nETBIOSName"][0])
        self.dns_domain = self.ldb.domain_dns_name()

        # permit password changes during this test
        PasswordCommon.allow_password_changes(self, self.ldb)

        self.base_dn = self.ldb.domain_dn()

        account_control = 0
        if clear_text:
            # Restore the current domain setting on exit.
            pwdProperties = self.ldb.get_pwdProperties()
            self.addCleanup(self.ldb.set_pwdProperties, pwdProperties)
            # Update the domain setting
            self.set_store_cleartext(clear_text)
            account_control |= UF_ENCRYPTED_TEXT_PASSWORD_ALLOWED

        # (Re)adds the test user USER_NAME with password USER_PASS
        # and userPrincipalName UPN
        user_dn = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        delete_force(self.ldb, user_dn)
        self.ldb.add({
            "dn": user_dn,
            "objectclass": "user",
            "sAMAccountName": USER_NAME,
            "userPassword": USER_PASS,
            "userPrincipalName": UPN,
            "userAccountControl": str(account_control)
        })

    # Get the supplemental credentials for the user under test
    def get_supplemental_creds(self):
        """Return the unpacked supplementalCredentials of the test user.

        The user entry is looked up by DN with a base-scope search; the
        first supplementalCredentials value is unpacked as a
        drsblobs.supplementalCredentialsBlob.
        """
        user_dn = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        found = self.ldb.search(base=user_dn,
                                scope=ldb.SCOPE_BASE,
                                attrs=["supplementalCredentials"])
        self.assertIs(True, len(found) > 0)
        blob = found[0]["supplementalCredentials"][0]
        return ndr_unpack(drsblobs.supplementalCredentialsBlob, blob)

    # Calculate and validate a Wdigest value
    def check_digest(self, user, realm, password, digest):
        """Assert that digest is the WDigest value for the given inputs.

        user, realm, password -- inputs handed to calc_digest() to compute
                                 the expected value
        digest                -- stored hash as a sequence of byte values;
                                 it is hex-encoded before the comparison
        """
        expected = calc_digest(user, realm, password)
        actual = binascii.hexlify(bytearray(digest)).decode('utf8')
        error = "Digest expected[%s], actual[%s], " \
                "user[%s], realm[%s], pass[%s]" % \
                (expected, actual, user, realm, password)
        # assertEquals is a deprecated alias (removed in Python 3.12).
        self.assertEqual(expected, actual, error)

    # Check all of the 29 expected WDigest values
    #
    def check_wdigests(self, digests):
        """Check all 29 WDigest hashes stored for the test user.

        digests -- unpacked WDigest package; num_hashes must be 29 and
                   hashes[N - 1].hash is compared with the digest
                   computed for entry N of the table below.
        """
        self.assertEqual(29, digests.num_hashes)

        user = USER_NAME
        nb = self.netbios_domain
        dns = self.dns_domain
        nb_user = "%s\\%s" % (nb, user)
        nb_user_lower = "%s\\%s" % (nb.lower(), user.lower())
        nb_user_upper = "%s\\%s" % (nb.upper(), user.upper())

        # (user, realm) inputs in hash order.  The trailing numbers are
        # the 1-based hash numbers, which makes it easier to check the
        # table against the spec and the samba-tool user tests.
        expected = [
            (user, nb),                          # 1
            (user.lower(), nb.lower()),          # 2
            (user.upper(), nb.upper()),          # 3
            (user, nb.upper()),                  # 4
            (user, nb.lower()),                  # 5
            (user.upper(), nb.lower()),          # 6
            (user.lower(), nb.upper()),          # 7
            (user, dns),                         # 8
            (user.lower(), dns.lower()),         # 9
            (user.upper(), dns.upper()),         # 10
            (user, dns.upper()),                 # 11
            (user, dns.lower()),                 # 12
            (user.upper(), dns.lower()),         # 13
            (user.lower(), dns.upper()),         # 14
            (UPN, ""),                           # 15
            (UPN.lower(), ""),                   # 16
            (UPN.upper(), ""),                   # 17
            (nb_user, ""),                       # 18
            (nb_user_lower, ""),                 # 19
            (nb_user_upper, ""),                 # 20
            (user, "Digest"),                    # 21
            (user.lower(), "Digest"),            # 22
            (user.upper(), "Digest"),            # 23
            (UPN, "Digest"),                     # 24
            (UPN.lower(), "Digest"),             # 25
            (UPN.upper(), "Digest"),             # 26
            (nb_user, "Digest"),                 # 27
            (nb_user_lower, "Digest"),           # 28
            (nb_user_upper, "Digest"),           # 29
        ]

        for number, (name, realm) in enumerate(expected, 1):
            self.check_digest(name, realm, USER_PASS,
                              digests.hashes[number - 1].hash)

    def checkUserPassword(self, up, expected):
        """Check the crypt hashes stored in a userPassword package.

        up       -- unpacked package; num_hashes and hashes (each with a
                    scheme and a value) are read from it
        expected -- sequence of (scheme tag, crypt algorithm id, rounds)
                    tuples in stored order; rounds is None when the hash
                    carries no explicit rounds field

        Each stored hash is split on "$" to recover its salt, USER_PASS
        is hashed again with the same settings and the result has to
        match the stored value.
        """
        # Check we've received the correct number of hashes
        self.assertEqual(len(expected), up.num_hashes)

        for i, (tag, alg, rounds) in enumerate(expected):
            entry = up.hashes[i]
            self.assertEqual(tag, entry.scheme)

            stored = entry.value.decode('utf8')
            data = stored.split("$")
            # Check we got the expected crypt algorithm
            self.assertEqual(alg, data[1])

            # With a rounds field the salt moves one position to the right
            if rounds is None:
                cmd = "$%s$%s" % (alg, data[2])
            else:
                cmd = "$%s$rounds=%d$%s" % (alg, rounds, data[3])

            # Calculate the expected hash value; kept in its own name so
            # the "expected" parameter is not rebound inside the loop.
            recomputed = crypt.crypt(USER_PASS, cmd)
            self.assertEqual(recomputed, stored)

    # Check that the correct nt_hash was stored for userPassword
    def checkNtHash(self, password, nt_hash):
        """Assert that nt_hash is the NT hash of password.

        The expected value is obtained from an anonymous Credentials
        object that has password set on it; nt_hash is converted to a
        bytearray before the comparison.
        """
        creds = Credentials()
        creds.set_anonymous()
        creds.set_password(password)
        expected = creds.get_nt_hash()
        actual = bytearray(nt_hash)
        # assertEquals is a deprecated alias (removed in Python 3.12).
        self.assertEqual(expected, actual)
Exemple #50
0
class RodcTests(samba.tests.TestCase):
    def setUp(self):
        """Connect to HOST and record the values the tests rely on.

        Stores the domain DN as self.base_dn, the rootDSE dsServiceName
        value as self.service and a random hex string as self.tag.
        """
        super(RodcTests, self).setUp()
        self.samdb = SamDB(HOST,
                           lp=LP,
                           credentials=CREDS,
                           session_info=system_session(LP))

        self.base_dn = self.samdb.domain_dn()

        rootdse = self.samdb.search(base='',
                                    scope=ldb.SCOPE_BASE,
                                    attrs=['dsServiceName'])
        self.service = rootdse[0]['dsServiceName'][0]
        self.tag = uuid.uuid4().hex

    def test_add_replicated_objects(self):
        """Adding objects must fail with a referral that can be followed.

        An OU, a user, a group and an NTDSConnection are added in turn.
        Every add has to raise ERR_REFERRAL whose message contains an
        ldap:// URL; the same object is then added to and deleted from
        the database at that URL.
        """
        candidates = (
            {
                'dn': "ou=%s1,%s" % (self.tag, self.base_dn),
                "objectclass": "organizationalUnit"
            },
            {
                'dn': "cn=%s2,%s" % (self.tag, self.base_dn),
                "objectclass": "user"
            },
            {
                'dn': "cn=%s3,%s" % (self.tag, self.base_dn),
                "objectclass": "group"
            },
            {
                'dn': "cn=%s4,%s" % (self.tag, self.service),
                "objectclass": "NTDSConnection",
                'enabledConnection': 'TRUE',
                'fromServer': self.base_dn,
                'options': '0'
            },
        )
        for o in candidates:
            try:
                self.samdb.add(o)
                self.fail("Failed to fail to add %s" % o['dn'])
            except ldb.LdbError as err:
                (ecode, emsg) = err.args
                if ecode != ldb.ERR_REFERRAL:
                    print(emsg)
                    self.fail("Adding %s: ldb error: %s %s, wanted referral" %
                              (o['dn'], ecode, emsg))

                found = re.search(r'(ldap://[^>]+)>', emsg)
                if found is None:
                    self.fail("referral seems not to refer to anything")
                address = found.group(1)

                # The referred location has to accept the same change.
                try:
                    tmpdb = SamDB(address,
                                  credentials=CREDS,
                                  session_info=system_session(LP),
                                  lp=LP)
                    tmpdb.add(o)
                    tmpdb.delete(o['dn'])
                except ldb.LdbError:
                    self.fail("couldn't modify referred location %s" %
                              address)

                # NOTE(review): address still begins with "ldap://", so
                # this prefix test can only match a domain name that
                # itself starts with that text -- confirm whether the
                # scheme was meant to be stripped first.
                if address.lower().startswith(
                        self.samdb.domain_dns_name()):
                    self.fail(
                        "referral address did not give a specific DC")

    def test_modify_replicated_attributes(self):
        """Modifying Guest attributes must yield a followable referral.

        carLicense and middleName of the Guest account are replaced in
        turn.  Each attempt has to raise ERR_REFERRAL whose message
        contains an ldap:// URL, and the same modification is then
        applied to the database at that URL.
        """
        dn = 'CN=Guest,CN=Users,' + self.base_dn
        new_value = 'hallooo'
        for attr in ('carLicense', 'middleName'):
            msg = ldb.Message()
            msg.dn = ldb.Dn(self.samdb, dn)
            msg[attr] = ldb.MessageElement(new_value,
                                           ldb.FLAG_MOD_REPLACE,
                                           attr)
            try:
                self.samdb.modify(msg)
                self.fail("Failed to fail to modify %s %s" % (dn, attr))
            except ldb.LdbError as err:
                (ecode, emsg) = err.args
                if ecode != ldb.ERR_REFERRAL:
                    self.fail("Failed to REFER when trying to modify %s %s" %
                              (dn, attr))

                found = re.search(r'(ldap://[^>]+)>', emsg)
                if found is None:
                    self.fail("referral seems not to refer to anything")
                address = found.group(1)

                # The referred location has to accept the same change.
                try:
                    tmpdb = SamDB(address,
                                  credentials=CREDS,
                                  session_info=system_session(LP),
                                  lp=LP)
                    tmpdb.modify(msg)
                except ldb.LdbError:
                    self.fail("couldn't modify referred location %s" %
                              address)

                # NOTE(review): address still begins with "ldap://", so
                # this prefix test can only match a domain name that
                # itself starts with that text -- confirm whether the
                # scheme was meant to be stripped first.
                if address.lower().startswith(
                        self.samdb.domain_dns_name()):
                    self.fail(
                        "referral address did not give a specific DC")

    def test_modify_nonreplicated_attributes(self):
        """Modifying badPwdCount/lastLogon/lastLogoff must be referred.

        Each of the three Guest attributes is replaced in turn; every
        attempt has to raise ERR_REFERRAL whose message contains an
        ldap:// URL.
        """
        dn = 'CN=Guest,CN=Users,' + self.base_dn
        new_value = '123456789'
        for attr in ('badPwdCount', 'lastLogon', 'lastLogoff'):
            msg = ldb.Message()
            msg.dn = ldb.Dn(self.samdb, dn)
            msg[attr] = ldb.MessageElement(new_value,
                                           ldb.FLAG_MOD_REPLACE,
                                           attr)
            # Windows refers these ones even though they are non-replicated
            try:
                self.samdb.modify(msg)
                self.fail("Failed to fail to modify %s %s" % (dn, attr))
            except ldb.LdbError as err:
                (ecode, emsg) = err.args
                if ecode != ldb.ERR_REFERRAL:
                    self.fail("Failed to REFER when trying to modify %s %s" %
                              (dn, attr))

                found = re.search(r'(ldap://[^>]+)>', emsg)
                if found is None:
                    self.fail("referral seems not to refer to anything")
                address = found.group(1)

                # NOTE(review): address still begins with "ldap://", so
                # this prefix test can only match a domain name that
                # itself starts with that text -- confirm whether the
                # scheme was meant to be stripped first.
                if address.lower().startswith(
                        self.samdb.domain_dns_name()):
                    self.fail(
                        "referral address did not give a specific DC")

    def test_modify_nonreplicated_reps_attributes(self):
        """Replacing repsFrom on the domain DN must be referred.

        The first repsFrom value is unpacked, its result_last_attempt is
        set to -1 and the repacked blob is written back; the modify has
        to raise ERR_REFERRAL whose message contains an ldap:// URL.
        """
        dn = self.base_dn

        msg = ldb.Message()
        msg.dn = ldb.Dn(self.samdb, dn)
        attr = 'repsFrom'

        current = self.samdb.search(dn,
                                    scope=ldb.SCOPE_BASE,
                                    attrs=['repsFrom'])
        rep = ndr_unpack(drsblobs.repsFromToBlob,
                         current[0]['repsFrom'][0],
                         allow_remaining=True)
        rep.ctr.result_last_attempt = -1
        packed = ndr_pack(rep)

        msg[attr] = ldb.MessageElement(packed, ldb.FLAG_MOD_REPLACE, attr)
        try:
            self.samdb.modify(msg)
            self.fail("Failed to fail to modify %s %s" % (dn, attr))
        except ldb.LdbError as err:
            (ecode, emsg) = err.args
            if ecode != ldb.ERR_REFERRAL:
                self.fail("Failed to REFER when trying to modify %s %s" %
                          (dn, attr))

            found = re.search(r'(ldap://[^>]+)>', emsg)
            if found is None:
                self.fail("referral seems not to refer to anything")
            address = found.group(1)

            # NOTE(review): address still begins with "ldap://", so this
            # prefix test can only match a domain name that itself starts
            # with that text -- confirm whether the scheme was meant to
            # be stripped first.
            if address.lower().startswith(self.samdb.domain_dns_name()):
                self.fail("referral address did not give a specific DC")

    def test_delete_special_objects(self):
        """Deleting the Guest account must be answered with a referral.

        The delete has to raise ERR_REFERRAL whose message contains an
        ldap:// URL.
        """
        dn = 'CN=Guest,CN=Users,' + self.base_dn
        try:
            self.samdb.delete(dn)
            self.fail("Failed to fail to delete %s" % (dn))
        except ldb.LdbError as err:
            (ecode, emsg) = err.args
            if ecode != ldb.ERR_REFERRAL:
                print(ecode, emsg)
                self.fail("Failed to REFER when trying to delete %s" % dn)

            found = re.search(r'(ldap://[^>]+)>', emsg)
            if found is None:
                self.fail("referral seems not to refer to anything")
            address = found.group(1)

            # NOTE(review): address still begins with "ldap://", so this
            # prefix test can only match a domain name that itself starts
            # with that text -- confirm whether the scheme was meant to
            # be stripped first.
            if address.lower().startswith(self.samdb.domain_dns_name()):
                self.fail("referral address did not give a specific DC")

    def test_no_delete_nonexistent_objects(self):
        """Deleting a DN that does not exist must give ERR_NO_SUCH_OBJECT.

        The DN embeds self.tag so that it is unique to this test run.
        """
        dn = 'CN=does-not-exist-%s,CN=Users,%s' % (self.tag, self.base_dn)
        try:
            self.samdb.delete(dn)
            self.fail("Failed to fail to delete %s" % (dn))
        except ldb.LdbError as err:
            (ecode, emsg) = err.args
            if ecode == ldb.ERR_NO_SUCH_OBJECT:
                return
            print(ecode, emsg)
            self.fail("Failed to NO_SUCH_OBJECT when trying to delete "
                      "%s (which does not exist)" % dn)
Exemple #51
0
class DsdbTests(TestCase):
    def setUp(self):
        """Open the local SamDB with guessed credentials.

        Keeps the loadparm context, credentials and system session on the
        instance so that forked children can re-open the database with
        the same settings.
        """
        super(DsdbTests, self).setUp()
        loadparm = samba.tests.env_loadparm()
        credentials = Credentials()
        credentials.guess(loadparm)
        self.lp = loadparm
        self.creds = credentials
        self.session = system_session()
        self.samdb = SamDB(lp=self.lp,
                           credentials=self.creds,
                           session_info=self.session)

    def test_get_oid_from_attrid(self):
        oid = self.samdb.get_oid_from_attid(591614)
        self.assertEquals(oid, "1.2.840.113556.1.4.1790")

    def test_error_replpropertymetadata(self):
        """Bumping only the description metadata version must be refused.

        The Administrator's replPropertyMetaData is unpacked, the version
        of every entry with attid 13 is incremented and the repacked blob
        is written back with control 1.3.6.1.4.1.7165.4.3.14; the modify
        has to raise ldb.LdbError.
        """
        res = self.samdb.search(expression="cn=Administrator",
                                scope=ldb.SCOPE_SUBTREE,
                                attrs=["replPropertyMetaData"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          str(res[0]["replPropertyMetaData"]))
        for o in repl.ctr.array:
            # Search for Description
            if o.attid == 13:
                o.version = o.version + 1
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        self.assertRaises(ldb.LdbError, self.samdb.modify, msg,
                          ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_error_replpropertymetadata_nochange(self):
        """Writing back unchanged replPropertyMetaData must be refused.

        The Administrator's blob is unpacked and repacked without any
        change and replaced with control 1.3.6.1.4.1.7165.4.3.14; the
        modify has to raise ldb.LdbError.
        """
        found = self.samdb.search(expression="cn=Administrator",
                                  scope=ldb.SCOPE_SUBTREE,
                                  attrs=["replPropertyMetaData"])
        entry = found[0]
        unpacked = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                              str(entry["replPropertyMetaData"]))
        repacked = ndr_pack(unpacked)

        msg = ldb.Message()
        msg.dn = entry.dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            repacked, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")

        controls = ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"]
        self.assertRaises(ldb.LdbError, self.samdb.modify, msg, controls)

    def test_error_replpropertymetadata_allow_sort(self):
        """Unchanged metadata is accepted when both controls are given.

        The Administrator's blob is unpacked and repacked without any
        change and replaced with controls 1.3.6.1.4.1.7165.4.3.14 and
        1.3.6.1.4.1.7165.4.3.25; the modify must not raise.
        """
        found = self.samdb.search(expression="cn=Administrator",
                                  scope=ldb.SCOPE_SUBTREE,
                                  attrs=["replPropertyMetaData"])
        entry = found[0]
        unpacked = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                              str(entry["replPropertyMetaData"]))
        repacked = ndr_pack(unpacked)

        msg = ldb.Message()
        msg.dn = entry.dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            repacked, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")

        controls = [
            "local_oid:1.3.6.1.4.1.7165.4.3.14:0",
            "local_oid:1.3.6.1.4.1.7165.4.3.25:0"
        ]
        self.samdb.modify(msg, controls)

    def test_twoatt_replpropertymetadata(self):
        """Metadata plus a second attribute in one modify must be refused.

        For every metadata entry with attid 13 the version is incremented
        and local_usn is set to uSNChanged + 1.  The repacked blob is
        sent together with a replacement of description, using control
        1.3.6.1.4.1.7165.4.3.14; the modify has to raise ldb.LdbError.
        """
        res = self.samdb.search(expression="cn=Administrator",
                                scope=ldb.SCOPE_SUBTREE,
                                attrs=["replPropertyMetaData", "uSNChanged"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          str(res[0]["replPropertyMetaData"]))
        # int() instead of the Python 2 only long(): on Python 2 it
        # promotes to long automatically when the value needs it.
        next_usn = int(str(res[0]["uSNChanged"])) + 1
        for o in repl.ctr.array:
            # Search for Description
            if o.attid == 13:
                o.version = o.version + 1
                o.local_usn = next_usn
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        msg["description"] = ldb.MessageElement("new val",
                                                ldb.FLAG_MOD_REPLACE,
                                                "description")
        self.assertRaises(ldb.LdbError, self.samdb.modify, msg,
                          ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_set_replpropertymetadata(self):
        """A consistent metadata update for description is accepted.

        For every metadata entry with attid 13 the version is incremented
        and both local_usn and originating_usn are set to uSNChanged + 1.
        The repacked blob is written with control
        1.3.6.1.4.1.7165.4.3.14; the modify must not raise.
        """
        res = self.samdb.search(expression="cn=Administrator",
                                scope=ldb.SCOPE_SUBTREE,
                                attrs=["replPropertyMetaData", "uSNChanged"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          str(res[0]["replPropertyMetaData"]))
        # int() instead of the Python 2 only long(): on Python 2 it
        # promotes to long automatically when the value needs it.
        next_usn = int(str(res[0]["uSNChanged"])) + 1
        for o in repl.ctr.array:
            # Search for Description
            if o.attid == 13:
                o.version = o.version + 1
                o.local_usn = next_usn
                o.originating_usn = next_usn
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        self.samdb.modify(msg, ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_ok_get_attribute_from_attid(self):
        self.assertEquals(self.samdb.get_attribute_from_attid(13),
                          "description")

    def test_ko_get_attribute_from_attid(self):
        self.assertEquals(self.samdb.get_attribute_from_attid(11979), None)

    def test_get_attribute_replmetadata_version(self):
        """The Administrator's unicodePwd metadata version must be 1."""
        res = self.samdb.search(expression="cn=Administrator",
                                scope=ldb.SCOPE_SUBTREE,
                                attrs=["dn"])
        # assertEquals is a deprecated alias (removed in Python 3.12).
        self.assertEqual(len(res), 1)
        dn = str(res[0].dn)
        self.assertEqual(
            self.samdb.get_attribute_replmetadata_version(dn, "unicodePwd"), 1)

    def test_set_attribute_replmetadata_version(self):
        """Setting the description metadata version must be readable back.

        The current version for the Administrator's description is read,
        raised by 2 and then read again to confirm the new value.
        """
        res = self.samdb.search(expression="cn=Administrator",
                                scope=ldb.SCOPE_SUBTREE,
                                attrs=["dn"])
        # assertEquals is a deprecated alias (removed in Python 3.12).
        self.assertEqual(len(res), 1)
        dn = str(res[0].dn)
        version = self.samdb.get_attribute_replmetadata_version(
            dn, "description")
        self.samdb.set_attribute_replmetadata_version(dn, "description",
                                                      version + 2)
        self.assertEqual(
            self.samdb.get_attribute_replmetadata_version(dn, "description"),
            version + 2)

    def test_db_lock1(self):
        """Check that a full search waits for a prepared commit elsewhere.

        A forked child opens its own SamDB, adds and deletes a user inside
        a transaction, prepares the commit, tells the parent through a
        pipe, sleeps two seconds and cancels the transaction.  The parent
        then runs search_iterator() to completion and asserts that this
        took more than 1.9 seconds and that the child exited with 0.
        """
        basedn = self.samdb.get_default_basedn()
        # Pipe used by the child to tell the parent it holds the lock
        (r1, w1) = os.pipe()

        pid = os.fork()
        if pid == 0:
            # In the child, close the main DB and open a fresh SamDB
            del (self.samdb)
            gc.collect()
            self.samdb = SamDB(session_info=self.session,
                               credentials=self.creds,
                               lp=self.lp)

            self.samdb.transaction_start()

            dn = "cn=test_db_lock_user,cn=users," + str(basedn)
            self.samdb.add({
                "dn": dn,
                "objectclass": "user",
            })
            self.samdb.delete(dn)

            # Obtain a write lock
            self.samdb.transaction_prepare_commit()
            os.write(w1, b"prepared")
            time.sleep(2)

            # Drop the write lock
            self.samdb.transaction_cancel()
            # Leave without running the parent's test machinery
            os._exit(0)

        # Parent: wait until the child has prepared its commit
        self.assertEqual(os.read(r1, 8), b"prepared")

        start = time.time()

        # We need to hold this iterator open to hold the all-record lock.
        res = self.samdb.search_iterator()

        # This should take at least 2 seconds because the transaction
        # has a write lock on one backend db open

        # Release the locks
        for l in res:
            pass

        end = time.time()
        self.assertGreater(end - start, 1.9)

        # The child must have exited normally with status 0
        (got_pid, status) = os.waitpid(pid, 0)
        self.assertEqual(got_pid, pid)
        self.assertTrue(os.WIFEXITED(status))
        self.assertEqual(os.WEXITSTATUS(status), 0)

    def test_db_lock2(self):
        """Check that prepare_commit waits for a search held elsewhere.

        A forked child opens its own SamDB and keeps a search_iterator()
        open while the parent starts a transaction and adds and deletes a
        user.  The child then sleeps two seconds before draining the
        iterator; the parent asserts that transaction_prepare_commit()
        took more than 1.9 seconds and that the child exited with 0.

        The two pipes carry the handshake: r2/w2 from child to parent,
        r1/w1 from parent to child.  The child exits with 1, 2 or 3 when
        it reads something other than the expected message.
        """
        basedn = self.samdb.get_default_basedn()
        (r1, w1) = os.pipe()
        (r2, w2) = os.pipe()

        pid = os.fork()
        if pid == 0:
            # In the child, close the main DB, re-open
            del (self.samdb)
            gc.collect()
            self.samdb = SamDB(session_info=self.session,
                               credentials=self.creds,
                               lp=self.lp)

            # We need to hold this iterator open to hold the all-record lock.
            res = self.samdb.search_iterator()

            os.write(w2, b"start")
            if (os.read(r1, 7) != b"started"):
                os._exit(1)

            os.write(w2, b"add")
            if (os.read(r1, 5) != b"added"):
                os._exit(2)

            # Wait 2 seconds to block prepare_commit() in the parent.
            os.write(w2, b"prepare")
            time.sleep(2)

            # Release the locks
            for l in res:
                pass

            if (os.read(r1, 8) != b"prepared"):
                os._exit(3)

            os._exit(0)

        # We can start the transaction during the search
        # because both just grab the all-record read lock.
        self.assertEqual(os.read(r2, 5), b"start")
        self.samdb.transaction_start()
        os.write(w1, b"started")

        self.assertEqual(os.read(r2, 3), b"add")
        dn = "cn=test_db_lock_user,cn=users," + str(basedn)
        self.samdb.add({
            "dn": dn,
            "objectclass": "user",
        })
        self.samdb.delete(dn)
        os.write(w1, b"added")

        # Obtain a write lock, this will block until
        # the child releases the read lock.
        self.assertEqual(os.read(r2, 7), b"prepare")
        start = time.time()
        self.samdb.transaction_prepare_commit()
        end = time.time()
        try:
            self.assertGreater(end - start, 1.9)
        finally:
            # Always finish the handshake and reap the child, even when
            # the timing assertion above failed.
            os.write(w1, b"prepared")

            # Drop the write lock
            self.samdb.transaction_cancel()

            (got_pid, status) = os.waitpid(pid, 0)
            self.assertEqual(got_pid, pid)
            self.assertTrue(os.WIFEXITED(status))
            self.assertEqual(os.WEXITSTATUS(status), 0)

    def test_db_lock3(self):
        """Like test_db_lock2, but the write goes to the top level db.

        A forked child keeps a search_iterator() open on its own SamDB
        while the parent starts a transaction and adds and deletes the
        special record @DSDB_LOCK_TEST.  The child sleeps two seconds
        before draining the iterator; the parent asserts that
        transaction_prepare_commit() took more than 1.9 seconds and that
        the child exited with 0.

        The two pipes carry the handshake: r2/w2 from child to parent,
        r1/w1 from parent to child.  The child exits with 1, 2 or 3 when
        it reads something other than the expected message.
        """
        basedn = self.samdb.get_default_basedn()
        (r1, w1) = os.pipe()
        (r2, w2) = os.pipe()

        pid = os.fork()
        if pid == 0:
            # In the child, close the main DB, re-open
            del (self.samdb)
            gc.collect()
            self.samdb = SamDB(session_info=self.session,
                               credentials=self.creds,
                               lp=self.lp)

            # We need to hold this iterator open to hold the all-record lock.
            res = self.samdb.search_iterator()

            os.write(w2, b"start")
            if (os.read(r1, 7) != b"started"):
                os._exit(1)

            os.write(w2, b"add")
            if (os.read(r1, 5) != b"added"):
                os._exit(2)

            # Wait 2 seconds to block prepare_commit() in the parent.
            os.write(w2, b"prepare")
            time.sleep(2)

            # Release the locks
            for l in res:
                pass

            if (os.read(r1, 8) != b"prepared"):
                os._exit(3)

            os._exit(0)

        # We can start the transaction during the search
        # because both just grab the all-record read lock.
        self.assertEqual(os.read(r2, 5), b"start")
        self.samdb.transaction_start()
        os.write(w1, b"started")

        self.assertEqual(os.read(r2, 3), b"add")

        # This will end up in the top level db
        dn = "@DSDB_LOCK_TEST"
        self.samdb.add({"dn": dn})
        self.samdb.delete(dn)
        os.write(w1, b"added")

        # Obtain a write lock, this will block until
        # the child releases the read lock.
        self.assertEqual(os.read(r2, 7), b"prepare")
        start = time.time()
        self.samdb.transaction_prepare_commit()
        end = time.time()
        self.assertGreater(end - start, 1.9)
        os.write(w1, b"prepared")

        # Drop the write lock
        self.samdb.transaction_cancel()

        # The child must have exited normally with status 0
        (got_pid, status) = os.waitpid(pid, 0)
        self.assertTrue(os.WIFEXITED(status))
        self.assertEqual(os.WEXITSTATUS(status), 0)
        self.assertEqual(got_pid, pid)

    def _test_full_db_lock1(self, backend_path):
        """Check that opening a full search waits for a backend write lock.

        backend_path -- path of the single backend ldb file the child
                        opens directly with ldb.Ldb()

        A forked child adds and deletes @DSDB_LOCK_TEST in that backend
        inside a transaction, prepares the commit, tells the parent
        through a pipe, sleeps two seconds and cancels.  The parent
        asserts that the search_iterator() call itself took more than
        1.9 seconds and that the child exited with 0.
        """
        # Pipe used by the child to tell the parent it holds the lock
        (r1, w1) = os.pipe()

        pid = os.fork()
        if pid == 0:
            # In the child, close the main DB, re-open just one DB
            del (self.samdb)
            gc.collect()

            backenddb = ldb.Ldb(backend_path)

            backenddb.transaction_start()

            backenddb.add({"dn": "@DSDB_LOCK_TEST"})
            backenddb.delete("@DSDB_LOCK_TEST")

            # Obtain a write lock
            backenddb.transaction_prepare_commit()
            os.write(w1, b"prepared")
            time.sleep(2)

            # Drop the write lock
            backenddb.transaction_cancel()
            os._exit(0)

        # Parent: wait until the child has prepared its commit
        self.assertEqual(os.read(r1, 8), b"prepared")

        start = time.time()

        # We need to hold this iterator open to hold the all-record lock.
        res = self.samdb.search_iterator()

        # This should take at least 2 seconds because the transaction
        # has a write lock on one backend db open

        # Unlike test_db_lock1, the time is taken before the results are
        # consumed.
        end = time.time()
        self.assertGreater(end - start, 1.9)

        # Release the locks
        for l in res:
            pass

        # The child must have exited normally with status 0
        (got_pid, status) = os.waitpid(pid, 0)
        self.assertEqual(got_pid, pid)
        self.assertTrue(os.WIFEXITED(status))
        self.assertEqual(os.WEXITSTATUS(status), 0)

    def test_full_db_lock1(self):
        basedn = self.samdb.get_default_basedn()
        backend_filename = "%s.ldb" % basedn.get_casefold()
        backend_subpath = os.path.join("sam.ldb.d", backend_filename)
        backend_path = self.lp.private_path(backend_subpath)
        self._test_full_db_lock1(backend_path)

    def test_full_db_lock1_config(self):
        basedn = self.samdb.get_config_basedn()
        backend_filename = "%s.ldb" % basedn.get_casefold()
        backend_subpath = os.path.join("sam.ldb.d", backend_filename)
        backend_path = self.lp.private_path(backend_subpath)
        self._test_full_db_lock1(backend_path)

    def _test_full_db_lock2(self, backend_path):
        """Check that a backend prepare_commit waits for a full search.

        backend_path -- path of the single backend ldb file the parent
                        opens directly with ldb.Ldb()

        A forked child opens its own SamDB and keeps a search_iterator()
        open while the parent starts a transaction on the backend and
        adds and deletes @DSDB_LOCK_TEST.  The child sleeps two seconds
        before draining the iterator; the parent asserts that
        transaction_prepare_commit() took more than 1.9 seconds and that
        the child exited with 0.

        The two pipes carry the handshake: r2/w2 from child to parent,
        r1/w1 from parent to child.  The child exits with 1, 2 or 3 when
        it reads something other than the expected message.
        """
        (r1, w1) = os.pipe()
        (r2, w2) = os.pipe()

        pid = os.fork()
        if pid == 0:

            # In the child, close the main DB, re-open
            del (self.samdb)
            gc.collect()
            self.samdb = SamDB(session_info=self.session,
                               credentials=self.creds,
                               lp=self.lp)

            # We need to hold this iterator open to hold the all-record lock.
            res = self.samdb.search_iterator()

            os.write(w2, b"start")
            if (os.read(r1, 7) != b"started"):
                os._exit(1)
            os.write(w2, b"add")
            if (os.read(r1, 5) != b"added"):
                os._exit(2)

            # Wait 2 seconds to block prepare_commit() in the parent.
            os.write(w2, b"prepare")
            time.sleep(2)

            # Release the locks
            for l in res:
                pass

            if (os.read(r1, 8) != b"prepared"):
                os._exit(3)

            os._exit(0)

        # In the parent, close the main DB, re-open just one DB
        del (self.samdb)
        gc.collect()
        backenddb = ldb.Ldb(backend_path)

        # We can start the transaction during the search
        # because both just grab the all-record read lock.
        self.assertEqual(os.read(r2, 5), b"start")
        backenddb.transaction_start()
        os.write(w1, b"started")

        self.assertEqual(os.read(r2, 3), b"add")
        backenddb.add({"dn": "@DSDB_LOCK_TEST"})
        backenddb.delete("@DSDB_LOCK_TEST")
        os.write(w1, b"added")

        # Obtain a write lock, this will block until
        # the child releases the read lock.
        self.assertEqual(os.read(r2, 7), b"prepare")
        start = time.time()
        backenddb.transaction_prepare_commit()
        end = time.time()

        try:
            self.assertGreater(end - start, 1.9)
        finally:
            # Always finish the handshake and reap the child, even when
            # the timing assertion above failed.
            os.write(w1, b"prepared")

            # Drop the write lock
            backenddb.transaction_cancel()

            (got_pid, status) = os.waitpid(pid, 0)
            self.assertEqual(got_pid, pid)
            self.assertTrue(os.WIFEXITED(status))
            self.assertEqual(os.WEXITSTATUS(status), 0)

    def test_full_db_lock2(self):
        basedn = self.samdb.get_default_basedn()
        backend_filename = "%s.ldb" % basedn.get_casefold()
        backend_subpath = os.path.join("sam.ldb.d", backend_filename)
        backend_path = self.lp.private_path(backend_subpath)
        self._test_full_db_lock2(backend_path)

    def test_full_db_lock2_config(self):
        basedn = self.samdb.get_config_basedn()
        backend_filename = "%s.ldb" % basedn.get_casefold()
        backend_subpath = os.path.join("sam.ldb.d", backend_filename)
        backend_path = self.lp.private_path(backend_subpath)
        self._test_full_db_lock2(backend_path)

    def test_no_error_on_invalid_control(self):
        """A non-critical unimplemented control must not break a search.

        The search is sent with DSDB_CONTROL_INVALID_NOT_IMPLEMENTED
        marked with a trailing ":0"; any ldb.LdbError fails the test.
        """
        control = ("local_oid:%s:0" %
                   dsdb.DSDB_CONTROL_INVALID_NOT_IMPLEMENTED)
        try:
            self.samdb.search(expression="cn=Administrator",
                              scope=ldb.SCOPE_SUBTREE,
                              attrs=["replPropertyMetaData"],
                              controls=[control])
        except ldb.LdbError:
            self.fail("Should have not raised an exception")

    def test_error_on_invalid_critical_control(self):
        """A critical unimplemented control must give the right error.

        The search is sent with DSDB_CONTROL_INVALID_NOT_IMPLEMENTED
        marked with a trailing ":1".  When it raises ldb.LdbError, the
        error code has to be ERR_UNSUPPORTED_CRITICAL_EXTENSION.
        """
        try:
            self.samdb.search(
                expression="cn=Administrator",
                scope=ldb.SCOPE_SUBTREE,
                attrs=["replPropertyMetaData"],
                controls=[
                    "local_oid:%s:1" %
                    dsdb.DSDB_CONTROL_INVALID_NOT_IMPLEMENTED
                ])
            # NOTE(review): a search that raises nothing currently passes
            # this test -- confirm whether that case should fail.
        except ldb.LdbError as e:
            # Unpack via args: exceptions cannot be indexed on Python 3.
            (ecode, emsg) = e.args
            if ecode != ldb.ERR_UNSUPPORTED_CRITICAL_EXTENSION:
                self.fail(
                    "Got %s should have got ERR_UNSUPPORTED_CRITICAL_EXTENSION"
                    % emsg)
Exemple #52
0
class VLVTests(samba.tests.TestCase):
    def create_user(self, i, n, prefix='vlvtest', suffix='', attrs=None):
        """Add one test user under self.ou and record it in self.users.

        Most attribute values are computed from i and n.

        i      -- index of this user; also part of the cn
        n      -- second number used in the generated values (e.g. n - i)
        prefix -- text placed before i in the cn
        suffix -- text placed after i in the cn
        attrs  -- optional dict of attributes that override the
                  generated ones

        Attributes whose names match opts.skip_attr_regex (when set) are
        removed before the entry is added. Returns the dict that was
        passed to self.ldb.add().
        """
        name = "%s%d%s" % (prefix, i, suffix)
        user = {
            'cn':
            name,
            "objectclass":
            "user",
            'givenName':
            "abcdefghijklmnopqrstuvwxyz"[i % 26],
            "roomNumber":
            "%sbc" % (n - i),
            "carLicense":
            "后来经",
            "employeeNumber":
            "%s%sx" % (abs(i * (99 - i)), '\n' * (i & 255)),
            "accountExpires":
            "%s" % (10**9 + 1000000 * i),
            "msTSExpireDate4":
            "19%02d0101010000.0Z" % (i % 100),
            "flags":
            str(i * (n - i)),
            "serialNumber":
            "abc %s%s%s" % (
                'AaBb |-/'[i & 7],
                ' 3z}'[i & 3],
                '"@'[i & 1],
            ),
        }

        # _user_broken_attrs tests are broken due to problems outside
        # of VLV. The dict is kept as a record of those attributes; it
        # is not merged into the user.
        _user_broken_attrs = {
            # Sort doesn't look past a NUL byte.
            "photo":
            "\x00%d" % (n - i),
            "audio":
            "%sn octet string %s%s ♫♬\x00lalala" %
            ('Aa'[i & 1], chr(i & 255), i),
            "displayNamePrintable":
            "%d\x00%c" % (i, i & 255),
            "adminDisplayName":
            "%d\x00b" % (n - i),
            "title":
            "%d%sb" % (n - i, '\x00' * i),
            "comment":
            "Favourite colour is %d" % (n % (i + 1)),

            # Names that vary only in case. Windows returns
            # equivalent addresses in the order they were put
            # in ('a st', 'A st',...).
            "street":
            "%s st" % (chr(65 | (i & 14) | ((i & 1) * 32))),
        }

        if attrs is not None:
            user.update(attrs)

        user['dn'] = "cn=%s,%s" % (user['cn'], self.ou)

        if opts.skip_attr_regex:
            match = re.compile(opts.skip_attr_regex).search
            # Iterate over a snapshot of the keys: deleting from a dict
            # while iterating its live key view raises RuntimeError on
            # Python 3.
            for k in list(user):
                if match(k):
                    del user[k]

        self.users.append(user)
        self.ldb.add(user)
        return user

    def setUp(self):
        """Connect to the server, create the test OU, fill it with
        N_ELEMENTS users and sort the attribute names into groups.

        When opts.delete_in_setup is set, any existing OU is deleted
        first; a failure of that delete is reported and otherwise
        ignored.
        """
        super(VLVTests, self).setUp()
        self.ldb = SamDB(host,
                         credentials=creds,
                         session_info=system_session(lp),
                         lp=lp)

        self.base_dn = self.ldb.domain_dn()
        self.ou = "ou=vlv,%s" % self.base_dn
        if opts.delete_in_setup:
            try:
                self.ldb.delete(self.ou, ['tree_delete:1'])
            except ldb.LdbError as e:
                # print() with one parenthesised argument behaves the
                # same on Python 2 and Python 3.
                print("tried deleting %s, got error %s" % (self.ou, e))
        self.ldb.add({"dn": self.ou, "objectclass": "organizationalUnit"})

        self.users = []
        for i in range(N_ELEMENTS):
            self.create_user(i, N_ELEMENTS)

        attrs = self.users[0].keys()
        self.binary_sorted_keys = [
            'audio', 'photo', "msTSExpireDate4", 'serialNumber',
            "displayNamePrintable"
        ]

        self.numeric_sorted_keys = ['flags', 'accountExpires']

        self.timestamp_keys = ['msTSExpireDate4']

        self.int64_keys = set(['accountExpires'])

        # Everything that is neither binary nor numeric is treated as
        # locale sorted. Build the exclusion list once rather than once
        # per attribute.
        not_locale_sorted = self.binary_sorted_keys + self.numeric_sorted_keys
        self.locale_sorted_keys = [
            x for x in attrs if x not in not_locale_sorted
        ]

        # don't try spaces, etc in cn
        self.delicate_keys = ['cn']

    def tearDown(self):
        """Run the base-class teardown, then tree-delete the test OU
        unless opts.delete_in_setup is set."""
        super(VLVTests, self).tearDown()
        if opts.delete_in_setup:
            return
        self.ldb.delete(self.ou, ['tree_delete:1'])

    def get_full_list(self, attr, include_cn=False):
        """Fetch the whole list sorted on attr, going through the VLV so
        that the response carries a VLV cookie.

        Returns a (values, response controls, sort control string)
        tuple. The values are (attr value, cn) pairs when include_cn is
        true, and lower-cased attr values otherwise.
        """
        middle = len(self.users) // 2
        sort_control = "server_sort:1:0:%s" % attr
        # A window of `middle` entries either side of entry middle + 1.
        vlv_search = "vlv:1:%d:%d:%d:0" % (middle, middle, middle + 1)
        wanted = [attr, 'cn'] if include_cn else [attr]
        res = self.ldb.search(self.ou,
                              scope=ldb.SCOPE_ONELEVEL,
                              attrs=wanted,
                              controls=[sort_control, vlv_search])
        if include_cn:
            full_results = [(msg[attr][0], msg['cn'][0]) for msg in res]
        else:
            full_results = [msg[attr][0].lower() for msg in res]
        return full_results, res.controls, sort_control

    def get_expected_order(self, attr, expression=None):
        """Return the first value of attr from every entry directly under
        self.ou that matches expression, in server-sorted order.

        Only the sort control is used here, not the VLV.
        """
        res = self.ldb.search(self.ou,
                              scope=ldb.SCOPE_ONELEVEL,
                              expression=expression,
                              attrs=[attr],
                              controls=["server_sort:1:0:%s" % attr])
        return [msg[attr][0] for msg in res]

    def delete_user(self, user):
        """Delete the user's entry from the directory, then drop the
        first matching record from self.users."""
        self.ldb.delete(user['dn'])
        self.users.remove(user)

    def get_gte_tests_and_order(self, attr, expression=None):
        """Build the data needed for greater-than-or-equal VLV tests on
        attr.

        Temporary 'gte' users carrying probe values for attr are added,
        the sorted list of all values under the OU is read back, and the
        temporary users are deleted again.

        Returns (gte_order, expected_order, gte_map) where
        gte_order      -- sorted values while the probe users existed
                          (fetched without the expression filter)
        expected_order -- sorted values matching expression, without
                          the probe users
        gte_map        -- maps each value in gte_order to the index in
                          expected_order of the next value that is
                          present there, or len(expected_order) when
                          there is none
        """
        expected_order = self.get_expected_order(attr, expression=expression)
        gte_users = []
        if attr in self.delicate_keys:
            gte_keys = [
                '3',
                'abc',
                '¹',
                'ŋđ¼³ŧ“«đð',
                '桑巴',
            ]
        elif attr in self.timestamp_keys:
            gte_keys = [
                '18560101010000.0Z',
                '19140103010000.0Z',
                '19560101010010.0Z',
                '19700101000000.0Z',
                '19991231211234.3Z',
                '20061111211234.0Z',
                '20390901041234.0Z',
                '25560101010000.0Z',
            ]
        elif attr not in self.numeric_sorted_keys:
            gte_keys = [
                '3',
                'abc',
                ' ',
                '!@#!@#!',
                'kōkako',
                '¹',
                'ŋđ¼³ŧ“«đð',
                '\n\t\t',
                '桑巴',
                'zzzz',
            ]
            if expected_order:
                gte_keys.append(expected_order[len(expected_order) // 2] +
                                ' tail')

        else:
            # "numeric" means positive integers
            # doesn't work with -1, 3.14, ' 3', '9' * 20
            gte_keys = ['3', '1' * 10, '1', '9' * 7, '0']

            if attr in self.int64_keys:
                gte_keys += ['3' * 12, '71' * 8]

        for i, x in enumerate(gte_keys):
            user = self.create_user(i,
                                    N_ELEMENTS,
                                    prefix='gte',
                                    attrs={attr: x})
            gte_users.append(user)

        gte_order = self.get_expected_order(attr)
        for user in gte_users:
            self.delete_user(user)

        # for sanity's sake
        expected_order_2 = self.get_expected_order(attr, expression=expression)
        self.assertEqual(expected_order, expected_order_2)

        # Map gte tests to indexes in expected order. This will break
        # if gte_order and expected_order are differently ordered (as
        # it should).
        gte_map = {}

        # index to the first one with each value
        index_map = {}
        for i, k in enumerate(expected_order):
            if k not in index_map:
                index_map[k] = i

        # Values absent from expected_order wait in `pending` until the
        # next value that is present; they all share that value's index.
        pending = []
        for k in gte_order:
            if k in index_map:
                i = index_map[k]
                gte_map[k] = i
                for waiting in pending:
                    gte_map[waiting] = i
                pending = []
            else:
                pending.append(k)

        # Anything still pending sorts after every expected value.
        for waiting in pending:
            gte_map[waiting] = len(expected_order)

        # Debugging aid, disabled by default.
        if False:
            print("gte_map:")
            for k in gte_order:
                print("   %10s => %10s" % (k, gte_map[k]))

        return gte_order, expected_order, gte_map

    def assertCorrectResults(self, results, expected_order, offset, before,
                             after):
        """A helper to calculate offsets correctly and say as much as possible
        when something goes wrong.

        results        -- values returned by the search
        expected_order -- the full sorted list; items may be tuples, in
                          which case only their first element is compared
        offset         -- 1-based position of the target entry
        before, after  -- number of entries requested either side of it

        Returns silently when the expected window equals results;
        otherwise prints the details and fails via assertEqual.
        """
        # offset is 1-based, the slice is 0-based and clamped at the start.
        start = max(offset - before - 1, 0)
        end = offset + after
        expected_results = expected_order[start:end]

        # if it is a tuple with the cn, drop the cn
        if expected_results and isinstance(expected_results[0], tuple):
            expected_results = [x[0] for x in expected_results]

        if expected_results == results:
            return

        # expected_order was sliced above, so it cannot be None here.
        print("expected order: %s" % expected_order[:20])
        if len(expected_order) > 20:
            print("... and %d more not shown" % (len(expected_order) - 20))

        print("offset %d before %d after %d" % (offset, before, after))
        print("start %d end %d" % (start, end))
        print("expected: %s" % expected_results)
        print("got     : %s" % results)
        self.assertEqual(expected_results, results)

    def test_server_vlv_with_cookie(self):
        """For each attribute, walk VLV windows of various sizes across
        the sorted list by offset, passing back the cookie from the
        previous response, and check each window's contents."""
        ignored = ('dn', 'objectclass')
        attrs = [k for k in self.users[0].keys() if k not in ignored]
        for attr in attrs:
            expected_order = self.get_expected_order(attr)
            sort_control = "server_sort:1:0:%s" % attr
            n = len(self.users)
            # No response yet for this attribute, so the first search
            # is sent without a cookie.
            res = None
            for before in [10, 0, 3, 1, 4, 5, 2]:
                for after in [0, 3, 1, 4, 5, 2, 7]:
                    first = max(1, before - 2)
                    stop = min(n - after + 2, n)
                    for offset in range(first, stop):
                        if res is not None:
                            cookie = get_cookie(res.controls, n)
                            vlv_search = ("vlv:1:%d:%d:%d:%s:%s" %
                                          (before, after, offset, n, cookie))
                        else:
                            vlv_search = "vlv:1:%d:%d:%d:0" % (before, after,
                                                               offset)

                        res = self.ldb.search(
                            self.ou,
                            scope=ldb.SCOPE_ONELEVEL,
                            attrs=[attr],
                            controls=[sort_control, vlv_search])

                        self.assertCorrectResults(
                            [msg[attr][0] for msg in res], expected_order,
                            offset, before, after)

    def run_index_tests_with_expressions(self, expressions):
        """Run offset-based VLV searches restricted by each expression,
        for every attribute, and check each window's contents.

        Only symmetric windows (before == after, 0 to 10) are tried.
        """
        ignored = ('dn', 'objectclass')
        attrs = [k for k in self.users[0].keys() if k not in ignored]
        for attr in attrs:
            for expression in expressions:
                expected_order = self.get_expected_order(attr, expression)
                sort_control = "server_sort:1:0:%s" % attr
                n = len(expected_order)
                # The first search for each attribute/expression pair
                # is sent without a cookie.
                res = None
                for width in range(0, 11):
                    before = after = width
                    first = max(1, before - 2)
                    stop = min(n - after + 2, n)
                    for offset in range(first, stop):
                        if res is not None:
                            cookie = get_cookie(res.controls)
                            vlv_search = ("vlv:1:%d:%d:%d:%s:%s" %
                                          (before, after, offset, n, cookie))
                        else:
                            vlv_search = "vlv:1:%d:%d:%d:0" % (before, after,
                                                               offset)

                        res = self.ldb.search(
                            self.ou,
                            expression=expression,
                            scope=ldb.SCOPE_ONELEVEL,
                            attrs=[attr],
                            controls=[sort_control, vlv_search])

                        self.assertCorrectResults(
                            [msg[attr][0] for msg in res], expected_order,
                            offset, before, after)

    def test_server_vlv_with_expression(self):
        """What happens when we run the VLV with an expression?"""
        expressions = [
            "(objectClass=*)",
            "(cn=%s)" % self.users[-1]['cn'],
            "(roomNumber=%s)" % self.users[0]['roomNumber'],
        ]
        self.run_index_tests_with_expressions(expressions)

    def test_server_vlv_with_failing_expression(self):
        """What happens when we run the VLV on an expression that matches
        nothing?"""
        expressions = [
            "(samaccountname=testferf)",
            "(cn=hefalump)",
        ]
        self.run_index_tests_with_expressions(expressions)

    def run_gte_tests_with_expressions(self, expressions):
        """Run greater-than-or-equal VLV searches restricted by each
        expression, for every attribute, and check each window.

        Only symmetric windows (before == after, 0 to 10) are tried.
        Every search after the first for an attribute/expression pair
        passes back the cookie from the previous response.
        """
        # Here we don't test every before/after combination.
        attrs = [
            x for x in self.users[0].keys() if x not in ('dn', 'objectclass')
        ]
        for expression in expressions:
            for attr in attrs:
                # expected_order comes back already fetched with this
                # expression and verified stable, so it is not fetched
                # a second time here.
                gte_order, expected_order, gte_map = \
                    self.get_gte_tests_and_order(attr, expression)
                # In case there is some order dependency, disorder tests
                gte_tests = gte_order[:]
                random.seed(2)
                random.shuffle(gte_tests)
                sort_control = "server_sort:1:0:%s" % attr
                res = None

                for before in range(0, 11):
                    after = before
                    for gte in gte_tests:
                        if res is not None:
                            cookie = get_cookie(res.controls)
                        else:
                            cookie = None
                        vlv_search = encode_vlv_control(before=before,
                                                        after=after,
                                                        gte=gte,
                                                        cookie=cookie)

                        res = self.ldb.search(
                            self.ou,
                            scope=ldb.SCOPE_ONELEVEL,
                            expression=expression,
                            attrs=[attr],
                            controls=[sort_control, vlv_search])

                        results = [x[attr][0] for x in res]
                        offset = gte_map.get(gte, len(expected_order))

                        # here offset is 0-based
                        start = max(offset - before, 0)
                        end = offset + 1 + after

                        expected_results = expected_order[start:end]

                        self.assertEqual(expected_results, results)

    def test_vlv_gte_with_expression(self):
        """What happens when we run the VLV with an expression?"""
        expressions = [
            "(objectClass=*)",
            "(cn=%s)" % self.users[-1]['cn'],
            "(roomNumber=%s)" % self.users[0]['roomNumber'],
        ]
        self.run_gte_tests_with_expressions(expressions)

    def test_vlv_gte_with_failing_expression(self):
        """What happens when we run the VLV on an expression that matches
        nothing?"""
        expressions = [
            "(samaccountname=testferf)",
            "(cn=hefalump)",
        ]
        self.run_gte_tests_with_expressions(expressions)

    def test_server_vlv_with_cookie_while_adding_and_deleting(self):
        """What happens if we add or remove items in the middle of the VLV?

        Nothing. The search and the sort is not repeated, and we only
        deal with the objects originally found.
        """
        attrs = ['cn'] + [
            x for x in self.users[0].keys() if x not in ('dn', 'objectclass')
        ]
        # Spelled out as a list: range(0, 3) + [...] is a TypeError on
        # Python 3, where range() does not return a list.
        window_sizes = [0, 1, 2, 6, 11, 19]
        user_number = 0
        for attr in attrs:
            full_results, controls, sort_control = \
                self.get_full_list(attr, True)
            original_n = len(self.users)

            # Entries deleted below are replaced by None in this list
            # rather than removed, so positions stay stable.
            expected_order = full_results
            random.seed(1)

            for before in window_sizes:
                for after in window_sizes:
                    start = max(before - 1, 1)
                    end = max(start + 4, original_n - after + 2)
                    for offset in range(start, end):
                        cookie = get_cookie(controls, original_n)
                        vlv_search = encode_vlv_control(before=before,
                                                        after=after,
                                                        offset=offset,
                                                        n=original_n,
                                                        cookie=cookie)

                        res = self.ldb.search(
                            self.ou,
                            scope=ldb.SCOPE_ONELEVEL,
                            attrs=[attr],
                            controls=[sort_control, vlv_search])

                        controls = res.controls
                        results = [x[attr][0] for x in res]
                        real_offset = max(1, min(offset, len(expected_order)))

                        begin_offset = max(real_offset - before - 1, 0)
                        real_before = min(before, real_offset - 1)
                        real_after = min(after,
                                         len(expected_order) - real_offset)
                        window = real_before + real_after + 1

                        # Collect the window, skipping deleted (None)
                        # entries.
                        expected_results = []
                        for x in expected_order[begin_offset:]:
                            if x is None:
                                continue
                            expected_results.append(x[0])
                            if len(expected_results) == window:
                                break

                        if expected_results != results:
                            print("attr %s before %d after %d offset %d" %
                                  (attr, before, after, offset))
                        self.assertEqual(expected_results, results)

                        # Randomly add and/or delete a user. n is taken
                        # once, before either change.
                        n = len(self.users)
                        if random.random() < 0.1 + (n < 5) * 0.05:
                            if n == 0:
                                i = 0
                            else:
                                i = random.randrange(n)
                            user = self.create_user(i,
                                                    n,
                                                    suffix='-%s' % user_number)
                            user_number += 1
                        if random.random() < 0.1 + (n > 50) * 0.02 and n:
                            index = random.randrange(n)
                            user = self.users.pop(index)

                            self.ldb.delete(user['dn'])

                            replaced = (user[attr], user['cn'])
                            if replaced in expected_order:
                                i = expected_order.index(replaced)
                                expected_order[i] = None

    def test_server_vlv_with_cookie_while_changing(self):
        """What happens if we modify items in the middle of the VLV?

        The expected behaviour (as found on Windows) is the sort is
        not repeated, but the changes in attributes are reflected.
        """
        attrs = [
            x for x in self.users[0].keys()
            if x not in ('dn', 'objectclass', 'cn')
        ]
        for attr in attrs:
            n_users = len(self.users)
            sort_control = "server_sort:1:0:%s" % attr
            i = 0

            # First we'll fetch the whole list so we know the original
            # sort order. This is necessary because we don't know how
            # the server will order equivalent items. We are using the
            # dn as a key.
            half_n = n_users // 2
            vlv_search = "vlv:1:%d:%d:%d:0" % (half_n, half_n, half_n + 1)
            res = self.ldb.search(self.ou,
                                  scope=ldb.SCOPE_ONELEVEL,
                                  attrs=['dn', attr],
                                  controls=[sort_control, vlv_search])

            # values[k] is the value entry k of dn_order should
            # currently hold; it is updated alongside each modify below.
            values = [x[attr][0].upper() for x in res]
            dn_order = [str(x['dn']) for x in res]

            # Only attributes in these two groups get modified.
            modifiable = (attr in self.locale_sorted_keys
                          or attr in self.binary_sorted_keys)

            for before in range(0, 3):
                for after in range(0, 3):
                    for offset in range(1 + before, n_users - after):
                        cookie = get_cookie(res.controls, len(self.users))
                        vlv_search = (
                            "vlv:1:%d:%d:%d:%s:%s" %
                            (before, after, offset, len(self.users), cookie))

                        res = self.ldb.search(
                            self.ou,
                            scope=ldb.SCOPE_ONELEVEL,
                            attrs=['dn', attr],
                            controls=[sort_control, vlv_search])

                        # The entries must keep their original positions...
                        dn_results = [str(x['dn']) for x in res]
                        dn_expected = dn_order[offset - before - 1:offset +
                                               after]

                        self.assertEqual(dn_expected, dn_results)

                        # ...while showing their current values.
                        results = [x[attr][0].upper() for x in res]

                        self.assertCorrectResults(results, values, offset,
                                                  before, after)

                        i += 1
                        if i % 3 == 2 and modifiable:
                            # Copy the value of entry i2 onto entry i1.
                            i1 = i % n_users
                            i2 = (i ^ 255) % n_users
                            dn1 = dn_order[i1]
                            v2 = values[i2]

                            # NOTE(review): this tests a *value* against
                            # a list of attribute names; possibly `attr`
                            # was intended. Left as found -- confirm.
                            if v2 in self.locale_sorted_keys:
                                v2 += '-%d' % i

                            values[i1] = v2

                            m = ldb.Message()
                            m.dn = ldb.Dn(self.ldb, dn1)
                            m[attr] = ldb.MessageElement(
                                v2, ldb.FLAG_MOD_REPLACE, attr)

                            self.ldb.modify(m)

    def test_server_vlv_fractions_with_cookie(self):
        """What happens when the count is set to an arbitrary number?

        In that case the offset and the count form a fraction, and the
        VLV should be centred at a point offset/count of the way
        through. For example, if offset is 3 and count is 6, the VLV
        should be looking around halfway. The actual algorithm is a
        bit fiddlier than that, because of the one-basedness of VLV.
        """
        attrs = [
            x for x in self.users[0].keys() if x not in ('dn', 'objectclass')
        ]

        n_users = len(self.users)

        # NOTE(review): no random draws are visible in this method after
        # the seed is set -- confirm whether it is still needed.
        random.seed(4)

        for attr in attrs:
            # The full list gives both the reference order and the first
            # set of response controls to take a cookie from.
            full_results, controls, sort_control = self.get_full_list(attr)
            self.assertEqual(len(full_results), n_users)
            for before in range(0, 2):
                for after in range(0, 2):
                    for denominator in range(1, 20):
                        for offset in range(1, denominator + 3):
                            cookie = get_cookie(controls, len(self.users))
                            # The denominator is sent where the other
                            # tests send the content count.
                            vlv_search = (
                                "vlv:1:%d:%d:%d:%s:%s" %
                                (before, after, offset, denominator, cookie))
                            try:
                                res = self.ldb.search(
                                    self.ou,
                                    scope=ldb.SCOPE_ONELEVEL,
                                    attrs=[attr],
                                    controls=[sort_control, vlv_search])
                            except ldb.LdbError as e:
                                # offset starts at 1 in the loop above, so
                                # as written this always re-raises and the
                                # print/continue below is not reached.
                                if offset != 0:
                                    raise
                                print(
                                    "offset %d denominator %d raised error "
                                    "expected error %s\n"
                                    "(offset zero is illegal unless "
                                    "content count is zero)" %
                                    (offset, denominator, e))
                                continue

                            results = [x[attr][0].lower() for x in res]

                            # Work out the 1-based position the server
                            # should have centred the window on.
                            if denominator == 0:
                                # Not reached with range(1, 20) above.
                                # Note this branch does not assign
                                # real_offset.
                                denominator = n_users
                                if offset == 0:
                                    offset = denominator
                            elif denominator == 1:
                                # the offset can only be 1, but the 1/1 case
                                # means something special
                                if offset == 1:
                                    real_offset = n_users
                                else:
                                    real_offset = 1
                            else:
                                # Clamp, then scale (offset - 1) out of
                                # (denominator - 1) onto 1..n_users.
                                if offset > denominator:
                                    offset = denominator
                                real_offset = (1 + int(
                                    round((n_users - 1) * (offset - 1) /
                                          (denominator - 1.0))))

                            self.assertCorrectResults(results, full_results,
                                                      real_offset, before,
                                                      after)

                            # Keep the latest controls so the next search
                            # uses the newest cookie.
                            controls = res.controls
                            # Debugging aid, disabled by default.
                            if False:
                                for c in list(controls):
                                    cstr = str(c)
                                    if cstr.startswith('vlv_resp'):
                                        bits = cstr.rsplit(':')
                                        print("the answer is %s; we said %d" %
                                              (bits[2], real_offset))
                                        break

    def test_server_vlv_no_cookie(self):
        """For each attribute, run offset-based VLV searches that never
        send a cookie and check each window's contents."""
        ignored = ('dn', 'objectclass')
        attrs = [k for k in self.users[0].keys() if k not in ignored]

        for attr in attrs:
            expected_order = self.get_expected_order(attr)
            sort_control = "server_sort:1:0:%s" % attr
            for before in range(0, 5):
                for after in range(0, 7):
                    for offset in range(1 + before, len(self.users) - after):
                        vlv_search = "vlv:1:%d:%d:%d:0" % (before, after,
                                                           offset)
                        res = self.ldb.search(
                            self.ou,
                            scope=ldb.SCOPE_ONELEVEL,
                            attrs=[attr],
                            controls=[sort_control, vlv_search])
                        self.assertCorrectResults(
                            [msg[attr][0] for msg in res], expected_order,
                            offset, before, after)

    def get_expected_order_showing_deleted(
            self,
            attr,
            expression="(|(cn=vlvtest*)(cn=vlv-deleted*))",
            base=None,
            scope=ldb.SCOPE_SUBTREE):
        """Return the first value of attr from every matching entry, in
        server-sorted order, with the show_deleted control switched on.

        By default the search covers the whole tree below self.base_dn
        rather than just our OU. This is the way to find deleted
        objects.
        """
        search_base = self.base_dn if base is None else base
        res = self.ldb.search(search_base,
                              scope=scope,
                              expression=expression,
                              attrs=[attr],
                              controls=["server_sort:1:0:%s" % attr,
                                        "show_deleted:1"])
        return [msg[attr][0] for msg in res]

    def add_deleted_users(self, n):
        """Create n users with the 'vlv-deleted' prefix, then delete
        each of them again."""
        created = []
        for i in range(n):
            created.append(self.create_user(i, n, prefix='vlv-deleted'))

        # Deletion starts only after every user has been created.
        for doomed in created:
            self.delete_user(doomed)

    def test_server_vlv_no_cookie_show_deleted(self):
        """What do we see with the show_deleted control?"""
        attrs = [
            'objectGUID', 'cn', 'sAMAccountName', 'objectSid', 'name',
            'whenChanged', 'usnChanged'
        ]

        # add some deleted users first, just in case there are none
        self.add_deleted_users(6)
        random.seed(22)
        expression = "(|(cn=vlvtest*)(cn=vlv-deleted*))"
        show_deleted_control = "show_deleted:1"

        for attr in attrs:
            expected_order = self.get_expected_order_showing_deleted(
                attr, expression)
            n = len(expected_order)
            sort_control = "server_sort:1:0:%s" % attr
            for before in [3, 1, 0]:
                for after in [0, 2]:
                    # don't test every position, because there could be
                    # hundreds. Probe 20 random offsets instead.
                    for _ in range(20):
                        offset = random.randrange(max(1, before - 2),
                                                  min(n - after + 2, n))
                        vlv_search = "vlv:1:%d:%d:%d:0" % (before, after,
                                                           offset)
                        res = self.ldb.search(
                            self.base_dn,
                            expression=expression,
                            scope=ldb.SCOPE_SUBTREE,
                            attrs=[attr],
                            controls=[
                                sort_control, show_deleted_control,
                                vlv_search
                            ])
                        self.assertCorrectResults(
                            [msg[attr][0] for msg in res], expected_order,
                            offset, before, after)

    def test_server_vlv_no_cookie_show_deleted_only(self):
        """What do we see with the show_deleted control when we're not looking
        at any non-deleted things"""
        attrs = [
            'objectGUID',
            'cn',
            'sAMAccountName',
            'objectSid',
            'whenChanged',
        ]

        # add some deleted users first, just in case there are none
        self.add_deleted_users(4)
        base = 'CN=Deleted Objects,%s' % self.base_dn
        expression = "(cn=vlv-deleted*)"
        show_deleted_control = "show_deleted:1"
        for attr in attrs:
            expected_order = self.get_expected_order_showing_deleted(
                attr,
                expression=expression,
                base=base,
                scope=ldb.SCOPE_ONELEVEL)
            total = len(expected_order)
            print("searching for attr %s amongst %d deleted objects" %
                  (attr, total))
            sort_control = "server_sort:1:0:%s" % attr
            # Probe roughly ten evenly spaced offsets, not every one.
            step = max(total // 10, 1)
            for before in [3, 0]:
                for after in [0, 2]:
                    for offset in range(1 + before, total - after, step):
                        vlv_search = "vlv:1:%d:%d:%d:0" % (before, after,
                                                           offset)
                        res = self.ldb.search(
                            base,
                            expression=expression,
                            scope=ldb.SCOPE_ONELEVEL,
                            attrs=[attr],
                            controls=[
                                sort_control, show_deleted_control,
                                vlv_search
                            ])
                        self.assertCorrectResults(
                            [msg[attr][0] for msg in res], expected_order,
                            offset, before, after)

    def test_server_vlv_with_cookie_show_deleted(self):
        """Check cookie-chained VLV searches combined with show_deleted.

        For every sort attribute, and for each window size (before == after
        in 3, 2, 1, 0), twenty randomly chosen offsets are requested.  The
        first search for an attribute carries no cookie; every later one
        reuses the cookie extracted from the previous response.  Each page is
        compared against the order returned by
        get_expected_order_showing_deleted().
        """
        attrs = [
            'objectGUID', 'cn', 'sAMAccountName', 'objectSid', 'name',
            'whenChanged', 'usnChanged'
        ]
        self.add_deleted_users(6)
        # Fixed seed so the "random" offsets are reproducible between runs.
        random.seed(23)

        show_deleted_control = "show_deleted:1"
        expression = "(|(cn=vlvtest*)(cn=vlv-deleted*))"

        for attr in attrs:
            sort_control = "server_sort:1:0:%s" % attr
            # Reset per attribute: the first search for a new sort order
            # must not reuse a cookie from the previous attribute.
            res = None
            expected_order = self.get_expected_order_showing_deleted(attr)
            n = len(expected_order)
            for before in [3, 2, 1, 0]:
                after = before
                for _ in range(20):
                    offset = random.randrange(max(1, before - 2),
                                              min(n - after + 2, n))
                    if res is None:
                        vlv_search = "vlv:1:%d:%d:%d:0" % (before, after,
                                                           offset)
                    else:
                        cookie = get_cookie(res.controls, n)
                        vlv_search = ("vlv:1:%d:%d:%d:%s:%s" %
                                      (before, after, offset, n, cookie))

                    res = self.ldb.search(self.base_dn,
                                          expression=expression,
                                          scope=ldb.SCOPE_SUBTREE,
                                          attrs=[attr],
                                          controls=[
                                              sort_control, vlv_search,
                                              show_deleted_control
                                          ])

                    results = [x[attr][0] for x in res]

                    self.assertCorrectResults(results, expected_order, offset,
                                              before, after)

    def test_server_vlv_gte_with_cookie(self):
        """Check greater-than-or-equal VLV searches that chain cookies.

        For each user attribute (other than 'dn' and 'objectclass') the
        gte probe values from get_gte_tests_and_order() are shuffled and
        searched for with a range of before/after window sizes.  Every
        search after the first passes the cookie taken from the previous
        response.  The returned page must equal the matching slice of the
        expected order.
        """
        attrs = [
            x for x in self.users[0].keys() if x not in ('dn', 'objectclass')
        ]
        for attr in attrs:
            gte_order, expected_order, gte_map = \
                self.get_gte_tests_and_order(attr)
            # In case there is some order dependency, disorder tests
            gte_tests = gte_order[:]
            random.seed(1)
            random.shuffle(gte_tests)
            res = None
            sort_control = "server_sort:1:0:%s" % attr
            for before in [0, 1, 2, 4]:
                for after in [0, 1, 3, 6]:
                    for gte in gte_tests:
                        if res is not None:
                            cookie = get_cookie(res.controls, len(self.users))
                        else:
                            cookie = None
                        vlv_search = encode_vlv_control(before=before,
                                                        after=after,
                                                        gte=gte,
                                                        cookie=cookie)

                        res = self.ldb.search(
                            self.ou,
                            scope=ldb.SCOPE_ONELEVEL,
                            attrs=[attr],
                            controls=[sort_control, vlv_search])

                        results = [x[attr][0] for x in res]
                        # A probe value missing from the map is treated as
                        # pointing one past the end of the expected order.
                        offset = gte_map.get(gte, len(expected_order))

                        # here offset is 0-based
                        start = max(offset - before, 0)
                        end = offset + 1 + after

                        expected_results = expected_order[start:end]

                        self.assertEqual(expected_results, results)

    def test_server_vlv_gte_no_cookie(self):
        """Check greater-than-or-equal VLV searches issued without a cookie.

        For each user attribute (other than 'dn' and 'objectclass') the
        gte probe values from get_gte_tests_and_order() are shuffled and
        searched for with several before/after window sizes.  The returned
        page must equal the matching slice of the expected order; on a
        mismatch some diagnostics are printed before the assertion fails.
        """
        attrs = [
            x for x in self.users[0].keys() if x not in ('dn', 'objectclass')
        ]
        for attr in attrs:
            gte_order, expected_order, gte_map = \
                self.get_gte_tests_and_order(attr)
            # In case there is some order dependency, disorder tests
            gte_tests = gte_order[:]
            random.seed(1)
            random.shuffle(gte_tests)

            sort_control = "server_sort:1:0:%s" % attr
            for before in [0, 1, 3]:
                for after in [0, 4]:
                    for gte in gte_tests:
                        vlv_search = encode_vlv_control(before=before,
                                                        after=after,
                                                        gte=gte)

                        res = self.ldb.search(
                            self.ou,
                            scope=ldb.SCOPE_ONELEVEL,
                            attrs=[attr],
                            controls=[sort_control, vlv_search])
                        results = [x[attr][0] for x in res]

                        # here offset is 0-based; a probe value missing from
                        # the map points one past the end of the list.
                        offset = gte_map.get(gte, len(expected_order))
                        start = max(offset - before, 0)
                        end = offset + after + 1
                        expected_results = expected_order[start:end]
                        if expected_results != results:
                            # Single-argument print() calls behave the same
                            # as statements on Python 2 and as function
                            # calls on Python 3.
                            middle = expected_order[len(expected_order) // 2]
                            print("%s %s" % (expected_results, results))
                            print(middle)
                            print(expected_order)
                            print("")
                            print(
                                "\nattr %s offset %d before %d "
                                "after %d gte %s" %
                                (attr, offset, before, after, gte))
                        self.assertEqual(expected_results, results)

    def test_multiple_searches(self):
        """Check how VLV cookies age out and that they are per-connection.

        The maximum number of concurrent vlv searches per connection is
        currently set at 3, so once 4 VLV searches have been opened the
        cookie from the first one should be rejected.  The most recent
        cookie must still work on the same connection, must be rejected on
        a fresh connection, and must be ignored (no error) when the VLV
        control is sent as non-critical.
        """
        # Windows has a limit of 10 VLVs where there are low numbers
        # of objects in each search.
        usable_keys = [
            key for key in self.users[0].keys()
            if key not in ('dn', 'objectclass')
        ]
        attrs = (usable_keys * 2)[:12]

        cookies = []
        for attr in attrs:
            sort_control = "server_sort:1:0:%s" % attr

            first_page = self.ldb.search(
                self.ou,
                scope=ldb.SCOPE_ONELEVEL,
                attrs=[attr],
                controls=[sort_control, "vlv:1:1:1:1:0"])

            cookies.append(get_cookie(first_page.controls, len(self.users)))
            time.sleep(0.2)

        # From here on `attr` and `sort_control` keep the values left by the
        # final loop iteration.

        # now this one should fail
        with self.assertRaises(ldb.LdbError):
            self.ldb.search(
                self.ou,
                scope=ldb.SCOPE_ONELEVEL,
                attrs=[attr],
                controls=[sort_control, "vlv:1:1:1:1:0:%s" % cookies[0]])

        # and this one should succeed
        self.ldb.search(
            self.ou,
            scope=ldb.SCOPE_ONELEVEL,
            attrs=[attr],
            controls=[sort_control, "vlv:1:1:1:1:0:%s" % cookies[-1]])

        # this one should fail because it is a new connection and
        # doesn't share cookies
        new_ldb = SamDB(host,
                        credentials=creds,
                        session_info=system_session(lp),
                        lp=lp)

        with self.assertRaises(ldb.LdbError):
            new_ldb.search(
                self.ou,
                scope=ldb.SCOPE_ONELEVEL,
                attrs=[attr],
                controls=[sort_control, "vlv:1:1:1:1:0:%s" % cookies[-1]])

        # but now without the critical flag it just does no VLV.
        new_ldb.search(
            self.ou,
            scope=ldb.SCOPE_ONELEVEL,
            attrs=[attr],
            controls=[sort_control, "vlv:0:1:1:1:0:%s" % cookies[-1]])
class AuditLogDsdbTests(AuditLogTestBase):
    """Tests of the DSDB audit messages produced by directory changes.

    Each test performs one operation against the server (password change,
    add, modify, delete, LSA secret create/delete) and then inspects the
    "dsdbChange" and "dsdbTransaction" messages that were collected.
    Message collection helpers such as waitForMessages() and
    discardMessages() are not defined here; they are expected to come from
    the base class.
    """

    def setUp(self):
        """Connect to the server and (re)create the test user.

        Reads SERVER and SERVER_IP from the environment, temporarily sets
        dSHeuristics and minPwdAge (both restored through addCleanup) and
        adds USER_NAME with password USER_PASS.
        """
        self.message_type = MSG_DSDB_LOG
        self.event_type = DSDB_EVENT_NAME
        super(AuditLogDsdbTests, self).setUp()

        self.server_ip = os.environ["SERVER_IP"]
        self.server = os.environ["SERVER"]

        host = "ldap://%s" % self.server
        self.ldb = SamDB(url=host,
                         session_info=system_session(),
                         credentials=self.get_credentials(),
                         lp=self.get_loadparm())

        # Gets back the basedn
        self.base_dn = self.ldb.domain_dn()

        # Get the old "dSHeuristics" if it was set
        dsheuristics = self.ldb.get_dsheuristics()

        # Set the "dSHeuristics" to activate the correct "userPassword"
        # behaviour
        self.ldb.set_dsheuristics("000000001")

        # Reset the "dSHeuristics" as they were before
        self.addCleanup(self.ldb.set_dsheuristics, dsheuristics)

        # Get the old "minPwdAge"
        minPwdAge = self.ldb.get_minPwdAge()

        # Set it temporarily to "0"
        self.ldb.set_minPwdAge("0")

        # Reset the "minPwdAge" as it was before
        self.addCleanup(self.ldb.set_minPwdAge, minPwdAge)

        # (Re)adds the test user USER_NAME with password USER_PASS
        user_dn = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        delete_force(self.ldb, user_dn)
        self.ldb.add({
            "dn": user_dn,
            "objectclass": "user",
            "sAMAccountName": USER_NAME,
            "userPassword": USER_PASS
        })

    #
    # Discard the messages from the setup code
    #
    def discardSetupMessages(self, dn):
        """Wait for the two messages setUp generates for dn, then drop them."""
        self.waitForMessages(2, dn=dn)
        self.discardMessages()

    def tearDown(self):
        """Drop any outstanding messages before the base class tears down."""
        self.discardMessages()
        super(AuditLogDsdbTests, self).tearDown()

    def haveExpectedTxn(self, expected):
        """Return True if the stored transaction message has id `expected`.

        Looks at self.context["txnMessage"]; returns False while no
        transaction message has been stored or its id differs.
        """
        txn_message = self.context["txnMessage"]
        if txn_message is None:
            return False
        return txn_message["dsdbTransaction"]["transactionId"] == expected

    def waitForTransaction(self, expected, connection=None):
        """Wait for a transaction message to arrive
        The connection is passed through to keep the connection alive
        until all the logging messages have been received.

        Returns the transaction message, or "" if it has not arrived after
        roughly one second of polling.
        """

        self.connection = connection

        start_time = time.time()
        while not self.haveExpectedTxn(expected):
            self.msg_ctx.loop_once(0.1)
            if time.time() - start_time > 1:
                self.connection = None
                return ""

        self.connection = None
        return self.context["txnMessage"]

    def test_net_change_password(self):
        """A password change through Net logs one redacted replace."""

        dn = "CN=" + USER_NAME + ",CN=Users," + self.base_dn
        self.discardSetupMessages(dn)

        creds = self.insta_creds(template=self.get_credentials())

        lp = self.get_loadparm()
        net = Net(creds, lp, server=self.server)
        # NOTE(review): this literal looks like a redacted placeholder;
        # confirm it satisfies the server's password policy.
        password = "******"

        net.change_password(newpassword=password,
                            username=USER_NAME,
                            oldpassword=USER_PASS)

        messages = self.waitForMessages(1, net, dn=dn)
        print("Received %d messages" % len(messages))
        self.assertEqual(1, len(messages),
                         "Did not receive the expected number of messages")

        audit = messages[0]["dsdbChange"]
        self.assertEqual("Modify", audit["operation"])
        self.assertFalse(audit["performedAsSystem"])
        # Compare case-insensitively; assertTrue(a, b) would only check
        # that `a` is truthy and use `b` as the failure message.
        self.assertEqual(dn.lower(), audit["dn"].lower())
        self.assertRegexpMatches(audit["remoteAddress"], self.remoteAddress)
        session_id = self.get_session()
        self.assertEqual(session_id, audit["sessionId"])
        # We skip the check for self.get_service_description() as this
        # is subject to a race between smbd and the s4 rpc_server code
        # as to which will set the description as it is DCE/RPC over SMB

        self.assertTrue(self.is_guid(audit["transactionId"]))

        attributes = audit["attributes"]
        self.assertEqual(1, len(attributes))
        actions = attributes["clearTextPassword"]["actions"]
        self.assertEqual(1, len(actions))
        self.assertTrue(actions[0]["redacted"])
        self.assertEqual("replace", actions[0]["action"])

    def test_net_set_password(self):
        """A password set through Net logs one redacted replace."""

        dn = "CN=" + USER_NAME + ",CN=Users," + self.base_dn
        self.discardSetupMessages(dn)

        creds = self.insta_creds(template=self.get_credentials())

        lp = self.get_loadparm()
        net = Net(creds, lp, server=self.server)
        # NOTE(review): this literal looks like a redacted placeholder;
        # confirm it satisfies the server's password policy.
        password = "******"
        domain = lp.get("workgroup")

        net.set_password(newpassword=password,
                         account_name=USER_NAME,
                         domain_name=domain)
        messages = self.waitForMessages(1, net, dn=dn)
        print("Received %d messages" % len(messages))
        self.assertEqual(1, len(messages),
                         "Did not receive the expected number of messages")
        audit = messages[0]["dsdbChange"]
        self.assertEqual("Modify", audit["operation"])
        self.assertFalse(audit["performedAsSystem"])
        self.assertEqual(dn, audit["dn"])
        self.assertRegexpMatches(audit["remoteAddress"], self.remoteAddress)
        session_id = self.get_session()
        self.assertEqual(session_id, audit["sessionId"])
        # We skip the check for self.get_service_description() as this
        # is subject to a race between smbd and the s4 rpc_server code
        # as to which will set the description as it is DCE/RPC over SMB

        self.assertTrue(self.is_guid(audit["transactionId"]))

        attributes = audit["attributes"]
        self.assertEqual(1, len(attributes))
        actions = attributes["clearTextPassword"]["actions"]
        self.assertEqual(1, len(actions))
        self.assertTrue(actions[0]["redacted"])
        self.assertEqual("replace", actions[0]["action"])

    def test_ldap_change_password(self):
        """An LDAP delete+add of userPassword logs two redacted actions."""

        dn = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        self.discardSetupMessages(dn)

        new_password = samba.generate_random_password(32, 32)
        # Delete the password installed by setUp and add the new one in a
        # single modify request.
        self.ldb.modify_ldif("dn: " + dn + "\n" + "changetype: modify\n" +
                             "delete: userPassword\n" +
                             "userPassword: " + USER_PASS + "\n" +
                             "add: userPassword\n" +
                             "userPassword: " + new_password + "\n")

        messages = self.waitForMessages(1)
        print("Received %d messages" % len(messages))
        self.assertEqual(1, len(messages),
                         "Did not receive the expected number of messages")

        audit = messages[0]["dsdbChange"]
        self.assertEqual("Modify", audit["operation"])
        self.assertFalse(audit["performedAsSystem"])
        self.assertEqual(dn, audit["dn"])
        self.assertRegexpMatches(audit["remoteAddress"], self.remoteAddress)
        self.assertTrue(self.is_guid(audit["sessionId"]))
        session_id = self.get_session()
        self.assertEqual(session_id, audit["sessionId"])
        service_description = self.get_service_description()
        self.assertEqual(service_description, "LDAP")

        attributes = audit["attributes"]
        self.assertEqual(1, len(attributes))
        actions = attributes["userPassword"]["actions"]
        self.assertEqual(2, len(actions))
        self.assertTrue(actions[0]["redacted"])
        self.assertEqual("delete", actions[0]["action"])
        self.assertTrue(actions[1]["redacted"])
        self.assertEqual("add", actions[1]["action"])

    def test_ldap_replace_password(self):
        """An LDAP replace of userPassword logs one redacted replace."""

        dn = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        self.discardSetupMessages(dn)

        new_password = samba.generate_random_password(32, 32)
        self.ldb.modify_ldif("dn: " + dn + "\n" + "changetype: modify\n" +
                             "replace: userPassword\n" +
                             "userPassword: " + new_password + "\n")

        messages = self.waitForMessages(1, dn=dn)
        print("Received %d messages" % len(messages))
        self.assertEqual(1, len(messages),
                         "Did not receive the expected number of messages")

        audit = messages[0]["dsdbChange"]
        self.assertEqual("Modify", audit["operation"])
        self.assertFalse(audit["performedAsSystem"])
        self.assertEqual(dn.lower(), audit["dn"].lower())
        self.assertRegexpMatches(audit["remoteAddress"], self.remoteAddress)
        self.assertTrue(self.is_guid(audit["sessionId"]))
        session_id = self.get_session()
        self.assertEqual(session_id, audit["sessionId"])
        service_description = self.get_service_description()
        self.assertEqual(service_description, "LDAP")
        self.assertTrue(self.is_guid(audit["transactionId"]))

        attributes = audit["attributes"]
        self.assertEqual(1, len(attributes))
        actions = attributes["userPassword"]["actions"]
        self.assertEqual(1, len(actions))
        self.assertTrue(actions[0]["redacted"])
        self.assertEqual("replace", actions[0]["action"])

    def test_ldap_add_user(self):
        """The user add done by setUp is logged with its three attributes."""

        # The setup code adds a user, so we check for the dsdb events
        # generated by it.
        dn = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        messages = self.waitForMessages(2, dn=dn)
        print("Received %d messages" % len(messages))
        self.assertEqual(2, len(messages),
                         "Did not receive the expected number of messages")

        audit = messages[1]["dsdbChange"]
        self.assertEqual("Add", audit["operation"])
        self.assertFalse(audit["performedAsSystem"])
        self.assertEqual(dn, audit["dn"])
        self.assertRegexpMatches(audit["remoteAddress"], self.remoteAddress)
        session_id = self.get_session()
        self.assertEqual(session_id, audit["sessionId"])
        service_description = self.get_service_description()
        self.assertEqual(service_description, "LDAP")
        self.assertTrue(self.is_guid(audit["sessionId"]))
        self.assertTrue(self.is_guid(audit["transactionId"]))

        attributes = audit["attributes"]
        self.assertEqual(3, len(attributes))

        actions = attributes["objectclass"]["actions"]
        self.assertEqual(1, len(actions))
        self.assertEqual("add", actions[0]["action"])
        self.assertEqual(1, len(actions[0]["values"]))
        self.assertEqual("user", actions[0]["values"][0]["value"])

        actions = attributes["sAMAccountName"]["actions"]
        self.assertEqual(1, len(actions))
        self.assertEqual("add", actions[0]["action"])
        self.assertEqual(1, len(actions[0]["values"]))
        self.assertEqual(USER_NAME, actions[0]["values"][0]["value"])

        actions = attributes["userPassword"]["actions"]
        self.assertEqual(1, len(actions))
        self.assertEqual("add", actions[0]["action"])
        self.assertTrue(actions[0]["redacted"])

    def test_samdb_delete_user(self):
        """Deleting the user logs a successful Delete and a commit."""

        dn = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        self.discardSetupMessages(dn)

        self.ldb.deleteuser(USER_NAME)

        messages = self.waitForMessages(1, dn=dn)
        print("Received %d messages" % len(messages))
        self.assertEqual(1, len(messages),
                         "Did not receive the expected number of messages")

        audit = messages[0]["dsdbChange"]
        self.assertEqual("Delete", audit["operation"])
        self.assertFalse(audit["performedAsSystem"])
        self.assertEqual(dn.lower(), audit["dn"].lower())
        self.assertRegexpMatches(audit["remoteAddress"], self.remoteAddress)
        self.assertTrue(self.is_guid(audit["sessionId"]))
        self.assertEqual(0, audit["statusCode"])
        self.assertEqual("Success", audit["status"])
        session_id = self.get_session()
        self.assertEqual(session_id, audit["sessionId"])
        service_description = self.get_service_description()
        self.assertEqual(service_description, "LDAP")

        transactionId = audit["transactionId"]
        message = self.waitForTransaction(transactionId)
        # waitForTransaction() returns "" on timeout; fail clearly instead
        # of raising TypeError when indexing the empty string below.
        self.assertNotEqual("", message,
                            "Did not receive the transaction message")
        audit = message["dsdbTransaction"]
        self.assertEqual("commit", audit["action"])
        self.assertTrue(self.is_guid(audit["transactionId"]))
        self.assertGreater(audit["duration"], 0)

    def test_samdb_delete_non_existent_dn(self):
        """Deleting a missing DN logs a failed Delete and a rollback."""

        DOES_NOT_EXIST = "doesNotExist"
        user_dn = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        self.discardSetupMessages(user_dn)

        dn = "cn=" + DOES_NOT_EXIST + ",cn=users," + self.base_dn
        try:
            self.ldb.delete(dn)
        except Exception:
            # Expected: the object does not exist.
            pass
        else:
            # Kept outside the try body so that the AssertionError raised
            # by fail() is not swallowed by the except clause.
            self.fail("Exception not thrown")

        messages = self.waitForMessages(1)
        print("Received %d messages" % len(messages))
        self.assertEqual(1, len(messages),
                         "Did not receive the expected number of messages")

        audit = messages[0]["dsdbChange"]
        self.assertEqual("Delete", audit["operation"])
        self.assertFalse(audit["performedAsSystem"])
        self.assertEqual(dn.lower(), audit["dn"].lower())
        self.assertRegexpMatches(audit["remoteAddress"], self.remoteAddress)
        self.assertEqual(ERR_NO_SUCH_OBJECT, audit["statusCode"])
        self.assertEqual("No such object", audit["status"])
        self.assertTrue(self.is_guid(audit["sessionId"]))
        session_id = self.get_session()
        self.assertEqual(session_id, audit["sessionId"])
        service_description = self.get_service_description()
        self.assertEqual(service_description, "LDAP")

        transactionId = audit["transactionId"]
        message = self.waitForTransaction(transactionId)
        # waitForTransaction() returns "" on timeout; fail clearly instead
        # of raising TypeError when indexing the empty string below.
        self.assertNotEqual("", message,
                            "Did not receive the transaction message")
        audit = message["dsdbTransaction"]
        self.assertEqual("rollback", audit["action"])
        self.assertTrue(self.is_guid(audit["transactionId"]))
        self.assertGreater(audit["duration"], 0)

    def test_create_and_delete_secret_over_lsa(self):
        """Creating and deleting an LSA secret logs an Add then a Delete."""

        dn = "cn=Test Secret,CN=System," + self.base_dn
        self.discardSetupMessages(dn)

        creds = self.insta_creds(template=self.get_credentials())
        lsa_conn = lsa.lsarpc("ncacn_np:%s" % self.server, self.get_loadparm(),
                              creds)
        lsa_handle = lsa_conn.OpenPolicy2(
            system_name="\\",
            attr=lsa.ObjectAttribute(),
            access_mask=security.SEC_FLAG_MAXIMUM_ALLOWED)
        secret_name = lsa.String()
        secret_name.string = "G$Test"
        lsa_conn.CreateSecret(handle=lsa_handle,
                              name=secret_name,
                              access_mask=security.SEC_FLAG_MAXIMUM_ALLOWED)

        messages = self.waitForMessages(1, dn=dn)
        print("Received %d messages" % len(messages))
        self.assertEqual(1, len(messages),
                         "Did not receive the expected number of messages")

        audit = messages[0]["dsdbChange"]
        self.assertEqual("Add", audit["operation"])
        self.assertTrue(audit["performedAsSystem"])
        self.assertEqual(dn.lower(), audit["dn"].lower())
        self.assertRegexpMatches(audit["remoteAddress"], self.remoteAddress)
        self.assertTrue(self.is_guid(audit["sessionId"]))
        session_id = self.get_session()
        self.assertEqual(session_id, audit["sessionId"])

        # We skip the check for self.get_service_description() as this
        # is subject to a race between smbd and the s4 rpc_server code
        # as to which will set the description as it is DCE/RPC over SMB

        attributes = audit["attributes"]
        self.assertEqual(2, len(attributes))

        object_class = attributes["objectClass"]
        self.assertEqual(1, len(object_class["actions"]))
        action = object_class["actions"][0]
        self.assertEqual("add", action["action"])
        values = action["values"]
        self.assertEqual(1, len(values))
        self.assertEqual("secret", values[0]["value"])

        cn = attributes["cn"]
        self.assertEqual(1, len(cn["actions"]))
        action = cn["actions"][0]
        self.assertEqual("add", action["action"])
        values = action["values"]
        self.assertEqual(1, len(values))
        self.assertEqual("Test Secret", values[0]["value"])

        #
        # Now delete the secret.
        self.discardMessages()
        h = lsa_conn.OpenSecret(handle=lsa_handle,
                                name=secret_name,
                                access_mask=security.SEC_FLAG_MAXIMUM_ALLOWED)

        lsa_conn.DeleteObject(h)
        messages = self.waitForMessages(1, dn=dn)
        print("Received %d messages" % len(messages))
        self.assertEqual(1, len(messages),
                         "Did not receive the expected number of messages")

        audit = messages[0]["dsdbChange"]
        self.assertEqual("Delete", audit["operation"])
        self.assertTrue(audit["performedAsSystem"])
        self.assertEqual(dn.lower(), audit["dn"].lower())
        self.assertRegexpMatches(audit["remoteAddress"], self.remoteAddress)
        self.assertTrue(self.is_guid(audit["sessionId"]))
        session_id = self.get_session()
        self.assertEqual(session_id, audit["sessionId"])

        # We skip the check for self.get_service_description() as this
        # is subject to a race between smbd and the s4 rpc_server code
        # as to which will set the description as it is DCE/RPC over SMB

    def test_modify(self):
        """Adding, deleting and replacing carLicense values is logged."""

        dn = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        self.discardSetupMessages(dn)

        #
        # Add an attribute value
        #
        self.ldb.modify_ldif("dn: " + dn + "\n" + "changetype: modify\n" +
                             "add: carLicense\n" + "carLicense: license-01\n")

        messages = self.waitForMessages(1, dn=dn)
        print("Received %d messages" % len(messages))
        self.assertEqual(1, len(messages),
                         "Did not receive the expected number of messages")

        audit = messages[0]["dsdbChange"]
        self.assertEqual("Modify", audit["operation"])
        self.assertFalse(audit["performedAsSystem"])
        self.assertEqual(dn, audit["dn"])
        self.assertRegexpMatches(audit["remoteAddress"], self.remoteAddress)
        self.assertTrue(self.is_guid(audit["sessionId"]))
        session_id = self.get_session()
        self.assertEqual(session_id, audit["sessionId"])
        service_description = self.get_service_description()
        self.assertEqual(service_description, "LDAP")

        attributes = audit["attributes"]
        self.assertEqual(1, len(attributes))
        actions = attributes["carLicense"]["actions"]
        self.assertEqual(1, len(actions))
        self.assertEqual("add", actions[0]["action"])
        values = actions[0]["values"]
        self.assertEqual(1, len(values))
        self.assertEqual("license-01", values[0]["value"])

        #
        # Add another value to the attribute
        #
        self.discardMessages()
        self.ldb.modify_ldif("dn: " + dn + "\n" + "changetype: modify\n" +
                             "add: carLicense\n" + "carLicense: license-02\n")

        messages = self.waitForMessages(1, dn=dn)
        print("Received %d messages" % len(messages))
        self.assertEqual(1, len(messages),
                         "Did not receive the expected number of messages")
        attributes = messages[0]["dsdbChange"]["attributes"]
        self.assertEqual(1, len(attributes))
        actions = attributes["carLicense"]["actions"]
        self.assertEqual(1, len(actions))
        self.assertEqual("add", actions[0]["action"])
        values = actions[0]["values"]
        self.assertEqual(1, len(values))
        self.assertEqual("license-02", values[0]["value"])

        #
        # Add another two values to the attribute
        #
        self.discardMessages()
        self.ldb.modify_ldif("dn: " + dn + "\n" + "changetype: modify\n" +
                             "add: carLicense\n" + "carLicense: license-03\n" +
                             "carLicense: license-04\n")

        messages = self.waitForMessages(1, dn=dn)
        print("Received %d messages" % len(messages))
        self.assertEqual(1, len(messages),
                         "Did not receive the expected number of messages")
        attributes = messages[0]["dsdbChange"]["attributes"]
        self.assertEqual(1, len(attributes))
        actions = attributes["carLicense"]["actions"]
        self.assertEqual(1, len(actions))
        self.assertEqual("add", actions[0]["action"])
        values = actions[0]["values"]
        self.assertEqual(2, len(values))
        self.assertEqual("license-03", values[0]["value"])
        self.assertEqual("license-04", values[1]["value"])

        #
        # Delete two values from the attribute
        #
        # NOTE(review): the record says "changetype: delete" although it
        # describes an attribute modification -- confirm whether
        # "changetype: modify" was intended.
        self.discardMessages()
        self.ldb.modify_ldif("dn: " + dn + "\n" + "changetype: delete\n" +
                             "delete: carLicense\n" +
                             "carLicense: license-03\n" +
                             "carLicense: license-04\n")

        messages = self.waitForMessages(1, dn=dn)
        print("Received %d messages" % len(messages))
        self.assertEqual(1, len(messages),
                         "Did not receive the expected number of messages")
        attributes = messages[0]["dsdbChange"]["attributes"]
        self.assertEqual(1, len(attributes))
        actions = attributes["carLicense"]["actions"]
        self.assertEqual(1, len(actions))
        self.assertEqual("delete", actions[0]["action"])
        values = actions[0]["values"]
        self.assertEqual(2, len(values))
        self.assertEqual("license-03", values[0]["value"])
        self.assertEqual("license-04", values[1]["value"])

        #
        # Replace the attribute with two new values
        #
        # NOTE(review): same "changetype: delete" question as above.
        self.discardMessages()
        self.ldb.modify_ldif("dn: " + dn + "\n" + "changetype: delete\n" +
                             "replace: carLicense\n" +
                             "carLicense: license-05\n" +
                             "carLicense: license-06\n")

        messages = self.waitForMessages(1, dn=dn)
        print("Received %d messages" % len(messages))
        self.assertEqual(1, len(messages),
                         "Did not receive the expected number of messages")
        attributes = messages[0]["dsdbChange"]["attributes"]
        self.assertEqual(1, len(attributes))
        actions = attributes["carLicense"]["actions"]
        self.assertEqual(1, len(actions))
        self.assertEqual("replace", actions[0]["action"])
        values = actions[0]["values"]
        self.assertEqual(2, len(values))
        self.assertEqual("license-05", values[0]["value"])
        self.assertEqual("license-06", values[1]["value"])
Exemple #54
0
class AuthLogPassChangeTests(samba.tests.auth_log_base.AuthLogTestBase):

    def setUp(self):
        """Connect to the test server and (re)create the test user.

        Reads CLIENT_IP, SERVER_IP and SERVER from the environment (a
        missing variable raises KeyError), opens a SamDB connection over
        LDAP, temporarily sets "dSHeuristics" to "000000001" and
        "minPwdAge" to "0" (both are restored through addCleanup),
        recreates USER_NAME with password USER_PASS and finally discards
        the messages produced by that setup.
        """
        super(AuthLogPassChangeTests, self).setUp()

        self.remoteAddress = os.environ["CLIENT_IP"]
        self.server_ip = os.environ["SERVER_IP"]

        host = "ldap://%s" % os.environ["SERVER"]
        self.ldb = SamDB(url=host,
                         session_info=system_session(),
                         credentials=self.get_credentials(),
                         lp=self.get_loadparm())

        # Single-argument print(...) behaves the same on Python 2 and 3.
        print("ldb %s" % type(self.ldb))

        # Gets back the basedn (fetched once and reused below)
        self.base_dn = self.ldb.domain_dn()
        print("base_dn %s" % self.base_dn)

        # Get the old "dSHeuristics" if it was set
        dsheuristics = self.ldb.get_dsheuristics()

        # Set the "dSHeuristics" to activate the correct "userPassword"
        # behaviour
        self.ldb.set_dsheuristics("000000001")

        # Reset the "dSHeuristics" as they were before
        self.addCleanup(self.ldb.set_dsheuristics, dsheuristics)

        # Get the old "minPwdAge"
        minPwdAge = self.ldb.get_minPwdAge()

        # Set it temporarily to "0"
        self.ldb.set_minPwdAge("0")

        # Reset the "minPwdAge" as it was before
        self.addCleanup(self.ldb.set_minPwdAge, minPwdAge)

        # (Re)adds the test user USER_NAME with password USER_PASS
        user_dn = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        delete_force(self.ldb, user_dn)
        self.ldb.add({
            "dn": user_dn,
            "objectclass": "user",
            "sAMAccountName": USER_NAME,
            "userPassword": USER_PASS
        })

        # discard any auth log messages for the password setup
        self.discardMessages()

    def tearDown(self):
        """Run the parent class tearDown; nothing extra is done here."""
        super(AuthLogPassChangeTests, self).tearDown()


    def test_admin_change_password(self):
        """Change USER_NAME's password and check the messages received.

        Waits for an "Authentication" message with status NT_STATUS_OK,
        service description "SAMR Password Change" and auth description
        "samr_ChangePasswordUser3", then asserts that waitForMessages
        returned exactly 8 messages.
        """
        def isLastExpectedMessage(msg):
            # Predicate handed to waitForMessages; the "Authentication"
            # payload is only inspected for messages of that type.
            if msg["type"] != "Authentication":
                return False
            auth = msg["Authentication"]
            return (auth["status"] == "NT_STATUS_OK" and
                    auth["serviceDescription"] == "SAMR Password Change" and
                    auth["authDescription"] == "samr_ChangePasswordUser3")

        creds = self.insta_creds(template=self.get_credentials())

        lp = self.get_loadparm()
        net = Net(creds, lp, server=self.server_ip)
        # NOTE(review): this literal looks redacted; confirm the real value
        # satisfies the domain password policy, since NT_STATUS_OK is
        # expected below.
        password = "******"

        net.change_password(newpassword=password.encode('utf-8'),
                            username=USER_NAME,
                            oldpassword=USER_PASS)

        messages = self.waitForMessages(isLastExpectedMessage)
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print("Received %d messages" % len(messages))
        self.assertEqual(8,
                         len(messages),
                         "Did not receive the expected number of messages")

    def test_admin_change_password_new_password_fails_restriction(self):
        """Attempt a password change that must be rejected and check the log.

        The change_password call is expected to raise. Afterwards the test
        waits for an "Authentication" message with status
        NT_STATUS_PASSWORD_RESTRICTION, service description
        "SAMR Password Change" and auth description
        "samr_ChangePasswordUser3", then asserts that waitForMessages
        returned exactly 8 messages.
        """
        def isLastExpectedMessage(msg):
            # Predicate handed to waitForMessages; the "Authentication"
            # payload is only inspected for messages of that type.
            if msg["type"] != "Authentication":
                return False
            auth = msg["Authentication"]
            return (auth["status"] == "NT_STATUS_PASSWORD_RESTRICTION" and
                    auth["serviceDescription"] == "SAMR Password Change" and
                    auth["authDescription"] == "samr_ChangePasswordUser3")

        creds = self.insta_creds(template=self.get_credentials())

        lp = self.get_loadparm()
        net = Net(creds, lp, server=self.server_ip)
        # NOTE(review): this literal looks redacted; confirm the real value
        # actually violates the password policy, since a restriction
        # failure is expected below.
        password = "******"

        # The concrete exception type raised by Net.change_password is not
        # visible from here, so any Exception counts as the expected failure.
        try:
            net.change_password(newpassword=password.encode('utf-8'),
                                oldpassword=USER_PASS,
                                username=USER_NAME)
        except Exception:
            pass
        else:
            self.fail("Expected exception not thrown")

        messages = self.waitForMessages(isLastExpectedMessage)
        self.assertEqual(8,
                         len(messages),
                         "Did not receive the expected number of messages")