def run(self, accountname, credopts=None, sambaopts=None, versionopts=None):
        """Show the delegation settings of an account.

        Looks ``accountname`` up by sAMAccountName in the local sam.ldb and
        writes its DN, the UF_TRUSTED_FOR_DELEGATION and
        UF_TRUSTED_TO_AUTHENTICATE_FOR_DELEGATION flags, and every
        msDS-AllowedToDelegateTo value to ``self.outf``.

        :param accountname: account name; normalised through
            ``_get_user_realm_domain`` before the lookup.
        :raises CommandError: if the account is not found, or is found
            more than once.
        """
        lp = sambaopts.get_loadparm()
        creds = credopts.get_credentials(lp)
        paths = provision.provision_paths_from_lp(lp, lp.get("realm"))
        sam = SamDB(paths.samdb, session_info=system_session(),
                    credentials=creds, lp=lp)
        # TODO once I understand how, use the domain info to naildown
        # to the correct domain
        (cleanedaccount, realm, domain) = _get_user_realm_domain(accountname)

        res = sam.search(expression="sAMAccountName=%s" %
                         ldb.binary_encode(cleanedaccount),
                         scope=ldb.SCOPE_SUBTREE,
                         attrs=["userAccountControl", "msDS-AllowedToDelegateTo"])
        if len(res) == 0:
            raise CommandError("Unable to find account name '%s'" % accountname)
        # Explicit check rather than assert: asserts are stripped under
        # ``python -O``, which would silently report only the first match.
        if len(res) != 1:
            raise CommandError("Account %s found %d times" %
                               (accountname, len(res)))

        uac = int(res[0].get("userAccountControl")[0])
        allowed = res[0].get("msDS-AllowedToDelegateTo")

        self.outf.write("Account-DN: %s\n" % str(res[0].dn))
        self.outf.write("UF_TRUSTED_FOR_DELEGATION: %s\n"
                        % bool(uac & dsdb.UF_TRUSTED_FOR_DELEGATION))
        self.outf.write("UF_TRUSTED_TO_AUTHENTICATE_FOR_DELEGATION: %s\n" %
                        bool(uac & dsdb.UF_TRUSTED_TO_AUTHENTICATE_FOR_DELEGATION))

        if allowed is not None:
            for a in allowed:
                self.outf.write("msDS-AllowedToDelegateTo: %s\n" % a)
Example #2
0
class DirsyncBaseTests(samba.tests.TestCase):
    """Shared fixture for the DirSync tests.

    ``setUp`` opens an administrative SamDB connection to ``ldapshost`` and
    caches frequently used DNs and the domain SID. ``get_ldb_connection``
    opens an additional connection to ``ldaphost`` as an arbitrary user.
    """

    def setUp(self):
        super(DirsyncBaseTests, self).setUp()
        admin = SamDB(ldapshost, credentials=creds,
                      session_info=system_session(lp), lp=lp)
        self.ldb_admin = admin
        self.base_dn = admin.domain_dn()
        self.domain_sid = security.dom_sid(admin.get_domain_sid())
        self.user_pass = samba.generate_random_password(12, 16)
        self.configuration_dn = admin.get_config_basedn().get_linearized()
        self.sd_utils = sd_utils.SDUtils(admin)
        print("baseDN: %s" % self.base_dn)

    def get_user_dn(self, name):
        """Return the DN of user *name* inside CN=Users of the base DN."""
        return "CN=%s,CN=Users,%s" % (name, self.base_dn)

    def get_ldb_connection(self, target_username, target_password):
        """Return a SamDB bound as *target_username* / *target_password*.

        Domain, realm and workstation are copied from the module-level
        ``creds``; sealing is requested and Kerberos is disabled.
        """
        user_creds = Credentials()
        user_creds.set_username(target_username)
        user_creds.set_password(target_password)
        user_creds.set_domain(creds.get_domain())
        user_creds.set_realm(creds.get_realm())
        user_creds.set_workstation(creds.get_workstation())
        features = user_creds.get_gensec_features() | gensec.FEATURE_SEAL
        user_creds.set_gensec_features(features)
        # kinit is too expensive to use in a tight loop
        user_creds.set_kerberos_state(DONT_USE_KERBEROS)
        return SamDB(url=ldaphost, credentials=user_creds, lp=lp)
Example #3
0
    def run(self, H=None, credopts=None, sambaopts=None, versionopts=None):
        """Report the current owner of each of the seven FSMO roles.

        Every owner is looked up before anything is printed. Owners are
        formatted with ``%s`` rather than concatenated with ``+``, so a
        non-string value returned by ``get_fsmo_roleowner`` (e.g. a DN
        object) does not raise TypeError.
        """
        lp = sambaopts.get_loadparm()
        creds = credopts.get_credentials(lp, fallback_machine=True)

        samdb = SamDB(url=H, session_info=system_session(),
                      credentials=creds, lp=lp)

        domain_dn = samdb.domain_dn()
        forest_dn = samba.dn_from_dns_name(samdb.forest_dns_name())
        infrastructure_dn = "CN=Infrastructure," + domain_dn
        naming_dn = "CN=Partitions,%s" % samdb.get_config_basedn()
        schema_dn = samdb.get_schema_basedn()
        rid_dn = "CN=RID Manager$,CN=System," + domain_dn
        domaindns_dn = "CN=Infrastructure,DC=DomainDnsZones," + domain_dn
        forestdns_dn = "CN=Infrastructure,DC=ForestDnsZones," + forest_dn

        infrastructureMaster = get_fsmo_roleowner(samdb, infrastructure_dn)
        pdcEmulator = get_fsmo_roleowner(samdb, domain_dn)
        namingMaster = get_fsmo_roleowner(samdb, naming_dn)
        schemaMaster = get_fsmo_roleowner(samdb, schema_dn)
        ridMaster = get_fsmo_roleowner(samdb, rid_dn)
        domaindnszonesMaster = get_fsmo_roleowner(samdb, domaindns_dn)
        forestdnszonesMaster = get_fsmo_roleowner(samdb, forestdns_dn)

        self.message("SchemaMasterRole owner: %s" % schemaMaster)
        self.message("InfrastructureMasterRole owner: %s" % infrastructureMaster)
        self.message("RidAllocationMasterRole owner: %s" % ridMaster)
        self.message("PdcEmulationMasterRole owner: %s" % pdcEmulator)
        self.message("DomainNamingMasterRole owner: %s" % namingMaster)
        self.message("DomainDnsZonesMasterRole owner: %s" % domaindnszonesMaster)
        self.message("ForestDnsZonesMasterRole owner: %s" % forestdnszonesMaster)
Example #4
0
    def test_1000_binds(self):
        """Open 1000 separate SamDB connections, each doing one base search."""

        # range(1000) yields exactly 1000 iterations, matching the test name.
        for _ in range(1000):
            samdb = SamDB(host, credentials=creds,
                          session_info=system_session(self.lp), lp=self.lp)
            samdb.search(base=samdb.domain_dn(),
                         scope=SCOPE_BASE, attrs=["*"])
Example #5
0
    def run(self, attribute, H=None, credopts=None, sambaopts=None, versionopts=None):
        """List the object classes that may or must contain ``attribute``.

        Searches the schema partition for classSchema objects that name the
        attribute in mayContain/systemMayContain, and separately in
        mustContain/systemMustContain, and writes each class's cn.
        """
        lp = sambaopts.get_loadparm()
        creds = credopts.get_credentials(lp)

        samdb = SamDB(url=H, session_info=system_session(),
                      credentials=creds, lp=lp)

        schema_dn = samdb.schema_dn()

        # Escape the user-supplied name so characters such as '*', '(' or
        # ')' cannot change the structure of the LDAP filter.
        escaped = ldb.binary_encode(attribute)
        may_filt = ('(&(objectClass=classSchema)'
                    '(|(mayContain={0})(systemMayContain={0})))').format(escaped)
        must_filt = ('(&(objectClass=classSchema)'
                     '(|(mustContain={0})(systemMustContain={0})))').format(escaped)

        may_res = samdb.search(base=schema_dn, scope=ldb.SCOPE_SUBTREE,
                               expression=may_filt, attrs=['cn'])
        must_res = samdb.search(base=schema_dn, scope=ldb.SCOPE_SUBTREE,
                                expression=must_filt, attrs=['cn'])

        self.outf.write('--- MAY contain ---\n')
        for msg in may_res:
            self.outf.write('%s\n' % msg['cn'][0])

        self.outf.write('--- MUST contain ---\n')
        for msg in must_res:
            self.outf.write('%s\n' % msg['cn'][0])
Example #6
0
    def run(self, computername, new_ou_dn, credopts=None, sambaopts=None,
            versionopts=None, H=None):
        """Move a computer account into another container.

        :param computername: computer name, with or without trailing '$'.
        :param new_ou_dn: destination DN; if it is not already below the
            domain DN, the domain DN is appended to it.
        :raises CommandError: if the computer is not found or the rename
            fails.
        """
        lp = sambaopts.get_loadparm()
        creds = credopts.get_credentials(lp, fallback_machine=True)
        samdb = SamDB(url=H, session_info=system_session(),
                      credentials=creds, lp=lp)
        domain_dn = ldb.Dn(samdb, samdb.domain_dn())

        # The account is searched for with a trailing '$'.
        if computername.endswith('$'):
            samaccountname = computername
        else:
            samaccountname = "%s$" % computername

        search_filter = ("(&(sAMAccountName=%s)(sAMAccountType=%u))" %
                         (ldb.binary_encode(samaccountname),
                          dsdb.ATYPE_WORKSTATION_TRUST))
        try:
            matches = samdb.search(base=domain_dn,
                                   expression=search_filter,
                                   scope=ldb.SCOPE_SUBTREE)
            computer_dn = matches[0].dn
        except IndexError:
            raise CommandError('Unable to find computer "%s"' % (computername))

        target_container = ldb.Dn(samdb, new_ou_dn)
        if not target_container.is_child_of(domain_dn):
            target_container.add_base(domain_dn)

        # Keep only the computer's RDN and re-root it under the target.
        new_computer_dn = ldb.Dn(samdb, str(computer_dn))
        new_computer_dn.remove_base_components(len(computer_dn) - 1)
        new_computer_dn.add_base(target_container)
        try:
            samdb.rename(computer_dn, new_computer_dn)
        except Exception as e:
            raise CommandError('Failed to move computer "%s"' % computername, e)
        self.outf.write('Moved computer "%s" to "%s"\n' %
                        (computername, new_ou_dn))
Example #7
0
    def run(self, computername, credopts=None, sambaopts=None, versionopts=None,
            H=None, computer_attrs=None):
        """Write the LDIF of a computer account to ``self.outf``.

        :param computername: computer name, with or without trailing '$'.
        :param computer_attrs: optional comma-separated attribute list;
            when omitted, ``attrs=None`` is passed to the search.
        :raises CommandError: if no matching computer account exists.
        """
        lp = sambaopts.get_loadparm()
        creds = credopts.get_credentials(lp, fallback_machine=True)
        samdb = SamDB(url=H, session_info=system_session(),
                      credentials=creds, lp=lp)

        attrs = computer_attrs.split(",") if computer_attrs else None

        samaccountname = computername
        if not computername.endswith('$'):
            samaccountname = "%s$" % computername

        search_filter = ("(&(sAMAccountType=%d)(sAMAccountName=%s))" %
                         (dsdb.ATYPE_WORKSTATION_TRUST,
                          ldb.binary_encode(samaccountname)))

        res = samdb.search(base=samdb.domain_dn(), expression=search_filter,
                           scope=ldb.SCOPE_SUBTREE, attrs=attrs)
        # Check the result explicitly instead of indexing into it and
        # catching IndexError.
        if len(res) == 0:
            raise CommandError('Unable to find computer "%s"' %
                               samaccountname)

        for msg in res:
            computer_ldif = samdb.write_ldif(msg, ldb.CHANGETYPE_NONE)
            self.outf.write(computer_ldif)
Example #8
0
    def run(self, H=None, credopts=None, sambaopts=None, versionopts=None):
        """List the Password Settings Objects ordered by ``pso_cmp``.

        Writes a "Precedence | PSO name" table to ``self.outf``. If no PSOs
        are returned, or the first one lacks its precedence attribute
        (insufficient permissions), writes a notice instead.
        """
        import functools

        lp = sambaopts.get_loadparm()
        creds = credopts.get_credentials(lp)

        samdb = SamDB(url=H, session_info=system_session(),
                      credentials=creds, lp=lp)

        res = samdb.search(pso_container(samdb), scope=ldb.SCOPE_SUBTREE,
                           attrs=['name', 'msDS-PasswordSettingsPrecedence'],
                           expression="(objectClass=msDS-PasswordSettings)")

        # an unprivileged search against Windows returns nothing here. On Samba
        # we get the PSO names, but not their attributes
        if len(res) == 0 or 'msDS-PasswordSettingsPrecedence' not in res[0]:
            self.outf.write("No PSOs are present, or you don't have permission to view them.\n")
            return

        # sort the PSOs so they're displayed in order of precedence.
        # sorted() has no ``cmp`` argument on Python 3; cmp_to_key adapts the
        # comparison function and also works on Python 2.7.
        pso_list = sorted(res, key=functools.cmp_to_key(pso_cmp))

        self.outf.write("Precedence | PSO name\n")
        self.outf.write("--------------------------------------------------\n")

        for pso in pso_list:
            precedence = pso['msDS-PasswordSettingsPrecedence']
            self.outf.write("%-10s | %s\n" % (precedence, pso['name']))
Example #9
0
    def run(self, H=None, credopts=None, sambaopts=None, versionopts=None):
        """Display the attributes listed in ``self.attributes``.

        The object is ``self.objectdn`` joined to the domain DN. Each
        attribute is printed with its first value, or ``<NO VALUE>`` when
        the object does not carry it.
        """
        lp = sambaopts.get_loadparm()
        creds = credopts.get_credentials(lp)

        samdb = SamDB(url=H, session_info=system_session(),
                      credentials=creds, lp=lp)

        object_dn = "%s,%s" % (self.objectdn, samdb.domain_dn())

        # Base-scope search on the object itself: only one entry is expected.
        entry = samdb.search(base=object_dn, scope=ldb.SCOPE_BASE,
                             attrs=self.attributes)[0]

        self.outf.write("Settings for %s\n" % object_dn)
        for name in self.attributes:
            try:
                value = entry[name][0]
            except KeyError:
                self.outf.write("%s: <NO VALUE>\n" % name)
            else:
                self.outf.write("%s: %s\n" % (name, value))
Example #10
0
 def run(self, user, credopts=None, sambaopts=None, versionopts=None):
     """List the servicePrincipalName values of an account.

     Writes the cleaned account name, then either every SPN of the first
     matching entry or a notice that it has none.

     :param user: account name; normalised through
         ``_get_user_realm_domain`` before the lookup.
     :raises CommandError: if no account matches.
     """
     lp = sambaopts.get_loadparm()
     creds = credopts.get_credentials(lp)
     paths = provision.provision_paths_from_lp(lp, lp.get("realm"))
     sam = SamDB(paths.samdb, session_info=system_session(),
                 credentials=creds, lp=lp)
     # TODO once I understand how, use the domain info to naildown
     # to the correct domain
     (cleaneduser, realm, domain) = _get_user_realm_domain(user)
     self.outf.write(cleaneduser + "\n")
     res = sam.search(expression="samaccountname=%s" % ldb.binary_encode(cleaneduser),
                      scope=ldb.SCOPE_SUBTREE,
                      attrs=["servicePrincipalName"])
     if len(res) == 0:
         raise CommandError("User %s not found" % user)

     spns = res[0].get("servicePrincipalName")
     if spns is not None:
         self.outf.write(
             "User %s has the following servicePrincipalName: \n" %
             res[0].dn)
         for e in spns:
             self.outf.write("\t %s\n" % e)
     else:
         # Terminate the line like every other message this command writes.
         self.outf.write("User %s has no servicePrincipalName\n" %
                         res[0].dn)
Example #11
0
File: user.py Project: sYnfo/samba
    def run(self, username=None, sambaopts=None, credopts=None,
            versionopts=None, H=None, filter=None, days=None, noexpiry=None):
        """Set or disable account expiry for users selected by name or filter.

        :param username: sAMAccountName used to build the filter when
            ``filter`` is not given.
        :param filter: explicit LDAP filter selecting the account(s).
        :param days: expiry in days (converted to seconds for ``setexpiry``).
        :param noexpiry: forwarded as ``no_expiry_req``.
        :raises CommandError: if neither selector is given, or the update
            fails.
        """
        if username is None and filter is None:
            raise CommandError("Either the username or '--filter' must be specified!")

        if filter is None:
            filter = "(&(objectClass=user)(sAMAccountName=%s))" % (ldb.binary_encode(username))

        lp = sambaopts.get_loadparm()
        creds = credopts.get_credentials(lp)

        samdb = SamDB(url=H, session_info=system_session(),
                      credentials=creds, lp=lp)

        # Name shown in messages: the username if given, else the filter.
        target = username or filter

        try:
            samdb.setexpiry(filter, days * 24 * 3600, no_expiry_req=noexpiry)
        except Exception as msg:
            # FIXME: Catch more specific exception
            raise CommandError("Failed to set expiry for user '%s': %s" % (
                target, msg))

        if noexpiry:
            message = "Expiry for user '%s' disabled.\n" % target
        else:
            message = "Expiry for user '%s' set to %u days.\n" % (target, days)
        self.outf.write(message)
Example #12
0
    def run(self, accountname, onoff, H=None, credopts=None, sambaopts=None,
            versionopts=None):
        """Turn UF_TRUSTED_TO_AUTHENTICATE_FOR_DELEGATION on or off.

        :param accountname: account whose userAccountControl is changed.
        :param onoff: ``"on"`` or ``"off"``.
        :param H: database URL; defaults to the provisioned sam.ldb path.
        :raises CommandError: for an invalid ``onoff`` value or when the
            flag cannot be toggled.
        """
        if onoff not in ("on", "off"):
            raise CommandError("invalid argument: '%s' (choose from 'on', 'off')" % onoff)
        on = onoff == "on"

        lp = sambaopts.get_loadparm()
        creds = credopts.get_credentials(lp, fallback_machine=True)
        paths = provision.provision_paths_from_lp(lp, lp.get("realm"))
        path = paths.samdb if H is None else H

        sam = SamDB(path, session_info=system_session(),
                    credentials=creds, lp=lp)
        # TODO once I understand how, use the domain info to naildown
        # to the correct domain
        (cleanedaccount, _realm, _domain) = _get_user_realm_domain(accountname)

        try:
            sam.toggle_userAccountFlags(
                "sAMAccountName=%s" % ldb.binary_encode(cleanedaccount),
                dsdb.UF_TRUSTED_TO_AUTHENTICATE_FOR_DELEGATION,
                flags_str="Trusted-to-Authenticate-for-Delegation",
                on=on, strict=True)
        except Exception as err:
            raise CommandError(err)
Example #13
0
File: user.py Project: runt18/samba
    def run(self, username=None, filter=None, credopts=None, sambaopts=None,
            versionopts=None, H=None, newpassword=None,
            must_change_at_next_login=False, random_password=False):
        """Set the password of a user selected by name or by LDAP filter.

        :param username: sAMAccountName used to build the filter when
            ``filter`` is not given.
        :param filter: explicit LDAP filter selecting the user.
        :param newpassword: the new password; prompted for while empty.
        :param must_change_at_next_login: forwarded to ``setpassword`` as
            ``force_change_at_next_login``.
        :param random_password: use ``generate_random_password(128, 255)``
            instead of ``newpassword``.
        :raises CommandError: if neither username nor filter is given, or
            setting the password fails.
        """
        if filter is None and username is None:
            raise CommandError("Either the username or '--filter' must be specified!")

        if random_password:
            password = generate_random_password(128, 255)
        else:
            password = newpassword

        # Prompt until a non-empty password is given. Compare with ==:
        # ``is not ''`` tests object identity, which is not guaranteed for
        # equal strings.
        while password is None or password == '':
            password = getpass("New Password: ")

        if filter is None:
            filter = "(&(objectClass=user)(sAMAccountName={0!s}))".format(
                ldb.binary_encode(username))

        lp = sambaopts.get_loadparm()
        creds = credopts.get_credentials(lp)

        # Request a sealed (encrypted) connection for the password change.
        creds.set_gensec_features(creds.get_gensec_features() | gensec.FEATURE_SEAL)

        samdb = SamDB(url=H, session_info=system_session(),
                      credentials=creds, lp=lp)

        try:
            samdb.setpassword(filter, password,
                              force_change_at_next_login=must_change_at_next_login,
                              username=username)
        except Exception as msg:
            # FIXME: catch more specific exception
            raise CommandError("Failed to set password for user '{0!s}': {1!s}".format(username or filter, msg))
Example #14
0
    def run(self, accountname, principal, H=None, credopts=None, sambaopts=None,
            versionopts=None):
        """Remove ``principal`` from an account's msDS-AllowedToDelegateTo.

        :param accountname: account to modify.
        :param principal: value to delete from msDS-AllowedToDelegateTo.
        :param H: database URL; defaults to the provisioned sam.ldb path.
        :raises CommandError: if the account is not found exactly once, or
            the modify fails.
        """
        lp = sambaopts.get_loadparm()
        creds = credopts.get_credentials(lp)
        paths = provision.provision_paths_from_lp(lp, lp.get("realm"))
        path = paths.samdb if H is None else H

        sam = SamDB(path, session_info=system_session(),
                    credentials=creds, lp=lp)
        # TODO once I understand how, use the domain info to naildown
        # to the correct domain
        (cleanedaccount, realm, domain) = _get_user_realm_domain(accountname)

        res = sam.search(expression="sAMAccountName=%s" %
                         ldb.binary_encode(cleanedaccount),
                         scope=ldb.SCOPE_SUBTREE,
                         attrs=["msDS-AllowedToDelegateTo"])
        if len(res) == 0:
            raise CommandError("Unable to find account name '%s'" % accountname)
        # Explicit check instead of assert, which ``python -O`` strips.
        if len(res) != 1:
            raise CommandError("Account %s found %d times" %
                               (accountname, len(res)))

        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["msDS-AllowedToDelegateTo"] = ldb.MessageElement([principal],
                                          ldb.FLAG_MOD_DELETE,
                                          "msDS-AllowedToDelegateTo")
        try:
            sam.modify(msg)
        except Exception as err:
            raise CommandError(err)
Example #15
0
    def run(self, accountname, onoff, credopts=None, sambaopts=None, versionopts=None):
        """Turn UF_TRUSTED_FOR_DELEGATION on or off for an account.

        :param accountname: account whose userAccountControl is changed.
        :param onoff: ``"on"`` or ``"off"``.
        :raises CommandError: for an invalid ``onoff`` value or when the
            flag cannot be toggled.
        """
        if onoff == "on":
            on = True
        elif onoff == "off":
            on = False
        else:
            raise CommandError("Invalid argument [%s]" % onoff)

        lp = sambaopts.get_loadparm()
        creds = credopts.get_credentials(lp)
        paths = provision.provision_paths_from_lp(lp, lp.get("realm"))
        sam = SamDB(paths.samdb, session_info=system_session(),
                    credentials=creds, lp=lp)
        # TODO once I understand how, use the domain info to naildown
        # to the correct domain
        (cleanedaccount, realm, domain) = _get_user_realm_domain(accountname)

        search_filter = "sAMAccountName=%s" % ldb.binary_encode(cleanedaccount)
        flag = dsdb.UF_TRUSTED_FOR_DELEGATION
        try:
            sam.toggle_userAccountFlags(search_filter, flag, on=on, strict=True)
        # ``except X as e`` is valid on Python 2.6+ and 3; the old comma form
        # is a syntax error on Python 3.
        except Exception as err:
            raise CommandError(err)
Example #16
0
    def run(self, username=None, filter=None, credopts=None, sambaopts=None,
            versionopts=None, H=None, newpassword=None,
            must_change_at_next_login=None):
        """Set the password of a user selected by name or by LDAP filter.

        :param username: sAMAccountName used to build the filter when
            ``filter`` is not given.
        :param filter: explicit LDAP filter selecting the user.
        :param newpassword: the new password; prompted for when None.
        :param must_change_at_next_login: forwarded to ``setpassword`` as
            ``force_change_at_next_login``.
        :raises CommandError: if neither username nor filter is given, or
            setting the password fails.
        """
        if filter is None and username is None:
            raise CommandError("Either the username or '--filter' must be specified!")

        password = newpassword
        if password is None:
            password = getpass("New Password: ")

        if filter is None:
            # Escape the username so it cannot inject LDAP filter syntax.
            filter = "(&(objectClass=user)(sAMAccountName=%s))" % (
                ldb.binary_encode(username))

        lp = sambaopts.get_loadparm()
        creds = credopts.get_credentials(lp)

        # Request a sealed (encrypted) connection for the password change.
        creds.set_gensec_features(creds.get_gensec_features() | gensec.FEATURE_SEAL)

        samdb = SamDB(url=H, session_info=system_session(),
                      credentials=creds, lp=lp)

        try:
            samdb.setpassword(filter, password,
                              force_change_at_next_login=must_change_at_next_login,
                              username=username)
        except Exception as e:
            raise CommandError('Failed to set password for user "%s"' % username, e)
Example #17
0
    def test_join_time_ridalloc(self):
        """Perform a join against the RID manager and assert we have a RID Set"""

        fsmo_dn = ldb.Dn(self.ldb_dc1, "CN=RID Manager$,CN=System," + self.ldb_dc1.domain_dn())
        (fsmo_owner, fsmo_not_owner) = self._determine_fSMORoleOwner(fsmo_dn)

        targetdir = self._test_join(fsmo_owner['dns_name'], "RIDALLOCTEST5")
        try:
            # Connect to the joined DC's local database
            ldb_url = "tdb://%s" % os.path.join(targetdir, "private/sam.ldb")

            lp = self.get_loadparm()
            new_ldb = SamDB(ldb_url, credentials=self.get_credentials(),
                            session_info=system_session(lp), lp=lp)

            # 1. Get server name
            res = new_ldb.search(base=ldb.Dn(new_ldb, new_ldb.get_serverName()),
                                 scope=ldb.SCOPE_BASE, attrs=["serverReference"])
            # 2. Get server reference
            server_ref_dn = ldb.Dn(new_ldb, res[0]['serverReference'][0])

            # 3. Assert we get the RID Set
            res = new_ldb.search(base=server_ref_dn,
                                 scope=ldb.SCOPE_BASE, attrs=['rIDSetReferences'])

            self.assertIn("rIDSetReferences", res[0])
        finally:
            # Always demote and remove the joined DC, even if the test failed.
            self._test_force_demote(fsmo_owner['dns_name'], "RIDALLOCTEST5")
            shutil.rmtree(targetdir, ignore_errors=True)
Example #18
0
    def run(self, H=None, credopts=None, sambaopts=None, versionopts=None):
        """Report the owner of each FSMO role, or why it could not be read.

        For each role, prints its owner, a "has no current owner" notice
        when ``get_fsmo_roleowner`` returns None, or the message of the
        CommandError it raised.
        """
        lp = sambaopts.get_loadparm()
        creds = credopts.get_credentials(lp, fallback_machine=True)

        samdb = SamDB(url=H, session_info=system_session(),
                      credentials=creds, lp=lp)

        domain_dn = samdb.domain_dn()
        forest_dn = samba.dn_from_dns_name(samdb.forest_dns_name())
        infrastructure_dn = "CN=Infrastructure," + domain_dn
        naming_dn = "CN=Partitions,%s" % samdb.get_config_basedn()
        schema_dn = samdb.get_schema_basedn()
        rid_dn = "CN=RID Manager$,CN=System," + domain_dn
        domaindns_dn = "CN=Infrastructure,DC=DomainDnsZones," + domain_dn
        forestdns_dn = "CN=Infrastructure,DC=ForestDnsZones," + forest_dn

        masters = [(schema_dn, "schema", "SchemaMasterRole"),
                   (infrastructure_dn, "infrastructure", "InfrastructureMasterRole"),
                   (rid_dn, "rid", "RidAllocationMasterRole"),
                   (domain_dn, "pdc", "PdcEmulationMasterRole"),
                   (naming_dn, "naming", "DomainNamingMasterRole"),
                   (domaindns_dn, "domaindns", "DomainDnsZonesMasterRole"),
                   (forestdns_dn, "forestdns", "ForestDnsZonesMasterRole"),
                   ]

        for (dn, short_name, long_name) in masters:
            try:
                owner = get_fsmo_roleowner(samdb, dn, short_name)
            # ``as e`` works on Python 2.6+ and 3 (the comma form is Py2-only).
            except CommandError as e:
                self.message("%s: * %s" % (long_name, e.message))
                continue
            if owner is not None:
                self.message("%s owner: %s" % (long_name, str(owner)))
            else:
                self.message("%s has no current owner" % (long_name))
Example #19
0
class SchemaTests_msDS_isRODC(samba.tests.TestCase):
    """Check that msDS-isRODC is returned for nTDSDSA and server objects."""

    def setUp(self):
        super(SchemaTests_msDS_isRODC, self).setUp()
        self.ldb = SamDB(host, credentials=creds,
                         session_info=system_session(lp), lp=lp, options=ldb_options)
        res = self.ldb.search(base="", expression="", scope=SCOPE_BASE,
                              attrs=["defaultNamingContext"])
        self.assertEqual(len(res), 1)
        self.base_dn = res[0]["defaultNamingContext"][0]

    def test_objectClass_ntdsdsa(self):
        """Every nTDSDSA entry found must carry msDS-isRODC."""
        res = self.ldb.search(self.base_dn, expression="objectClass=nTDSDSA",
                              attrs=["msDS-isRODC"], controls=["search_options:1:2"])
        for ldb_msg in res:
            self.assertIn("msDS-isRODC", ldb_msg)

    def test_objectClass_server(self):
        """Server entries that have an NTDS Settings child must carry msDS-isRODC."""
        res = self.ldb.search(self.base_dn, expression="objectClass=server",
                              attrs=["msDS-isRODC"], controls=["search_options:1:2"])
        for ldb_msg in res:
            ntds_search_dn = "CN=NTDS Settings,%s" % ldb_msg['dn']
            try:
                res_check = self.ldb.search(ntds_search_dn, attrs=["objectCategory"])
            except LdbError as e:
                (num, _) = e.args
                # A server without an NTDS Settings child is tolerated.
                self.assertEqual(num, ERR_NO_SUCH_OBJECT)
                # Report the entry actually being checked, not always res[0].
                print("Server entry %s doesn't have a NTDS settings object" % ldb_msg['dn'])
            else:
                self.assertIn("objectCategory", res_check[0])
                self.assertIn("msDS-isRODC", ldb_msg)
Example #20
0
    def run(self, accountname, credopts=None, sambaopts=None, versionopts=None):
        """Print the delegation settings of an account to stdout.

        Prints the account DN, 1/0 for UF_TRUSTED_FOR_DELEGATION and
        UF_TRUSTED_TO_AUTHENTICATE_FOR_DELEGATION, and each
        msDS-AllowedToDelegateTo value.

        :raises CommandError: unless the account is found exactly once.
        """
        lp = sambaopts.get_loadparm()
        creds = credopts.get_credentials(lp)
        paths = provision.provision_paths_from_lp(lp, lp.get("realm"))
        sam = SamDB(paths.samdb, session_info=system_session(),
                    credentials=creds, lp=lp)
        # TODO once I understand how, use the domain info to naildown
        # to the correct domain
        (cleanedaccount, realm, domain) = _get_user_realm_domain(accountname)
        # print(<single expression>) produces the same output on Python 2
        # and 3, unlike the Python-2-only print statement.
        print("Searching for: %s" % (cleanedaccount))
        res = sam.search(expression="sAMAccountName=%s" % ldb.binary_encode(cleanedaccount),
                         scope=ldb.SCOPE_SUBTREE,
                         attrs=["userAccountControl", "msDS-AllowedToDelegateTo"])
        if len(res) != 1:
            raise CommandError("Account %s found %d times" % (accountname, len(res)))

        uac = int(res[0].get("userAccountControl")[0])
        allowed = res[0].get("msDS-AllowedToDelegateTo")

        print("Account-DN: %s" % str(res[0].dn))

        print("UF_TRUSTED_FOR_DELEGATION: %d" %
              int(bool(uac & dsdb.UF_TRUSTED_FOR_DELEGATION)))
        print("UF_TRUSTED_TO_AUTHENTICATE_FOR_DELEGATION: %d" %
              int(bool(uac & dsdb.UF_TRUSTED_TO_AUTHENTICATE_FOR_DELEGATION)))

        if allowed is not None:
            for a in allowed:
                print("msDS-AllowedToDelegateTo: %s" % (str(a)))
Example #21
0
    def run(self, subcommand, H=None, min_pwd_age=None, max_pwd_age=None,
            quiet=False, complexity=None, store_plaintext=None, history_length=None,
            min_pwd_length=None, credopts=None, sambaopts=None,
            versionopts=None):
        """Read the domain's current password settings.

        Reads pwdProperties, pwdHistoryLength, minPwdLength, minPwdAge and
        maxPwdAge from the domain object and converts the two ages to days.

        NOTE(review): within this block the computed values and most
        parameters are never used; the rest of the command appears to be
        missing from this copy — confirm.

        :raises CommandError: if any of the attributes cannot be read or
            converted.
        """
        lp = sambaopts.get_loadparm()
        creds = credopts.get_credentials(lp)

        samdb = SamDB(url=H, session_info=system_session(),
                      credentials=creds, lp=lp)

        domain_dn = samdb.domain_dn()
        res = samdb.search(domain_dn, scope=ldb.SCOPE_BASE,
                           attrs=["pwdProperties", "pwdHistoryLength", "minPwdLength",
                                  "minPwdAge", "maxPwdAge"])
        assert(len(res) == 1)
        try:
            pwd_props = int(res[0]["pwdProperties"][0])
            pwd_hist_len = int(res[0]["pwdHistoryLength"][0])
            cur_min_pwd_len = int(res[0]["minPwdLength"][0])
            # ticks -> days
            cur_min_pwd_age = int(abs(int(res[0]["minPwdAge"][0])) / (1e7 * 60 * 60 * 24))
            # The most negative 64-bit value is mapped to 0; presumably it
            # means "never expires" — confirm.
            if int(res[0]["maxPwdAge"][0]) == -0x8000000000000000:
                cur_max_pwd_age = 0
            else:
                cur_max_pwd_age = int(abs(int(res[0]["maxPwdAge"][0])) / (1e7 * 60 * 60 * 24))
        # ``as e`` works on Python 2.6+ and 3 (the comma form is Py2-only).
        except Exception as e:
            raise CommandError("Could not retrieve password properties!", e)
Example #22
0
    def run(self, groupname, credopts=None, sambaopts=None, versionopts=None,
            H=None, group_attrs=None):
        """Write the LDIF of a global security group to ``self.outf``.

        :param groupname: sAMAccountName of the group.
        :param group_attrs: optional comma-separated attribute list; when
            omitted, ``attrs=None`` is passed to the search.
        :raises CommandError: if no matching group exists.
        """
        lp = sambaopts.get_loadparm()
        creds = credopts.get_credentials(lp, fallback_machine=True)
        samdb = SamDB(url=H, session_info=system_session(),
                      credentials=creds, lp=lp)

        attrs = group_attrs.split(",") if group_attrs else None

        search_filter = ("(&(sAMAccountType=%d)(sAMAccountName=%s))" %
                         (ATYPE_SECURITY_GLOBAL_GROUP,
                          ldb.binary_encode(groupname)))

        res = samdb.search(base=samdb.domain_dn(), expression=search_filter,
                           scope=ldb.SCOPE_SUBTREE, attrs=attrs)
        # Check the result explicitly instead of indexing into it and
        # catching IndexError.
        if len(res) == 0:
            raise CommandError('Unable to find group "%s"' % (groupname))

        for msg in res:
            group_ldif = samdb.write_ldif(msg, ldb.CHANGETYPE_NONE)
            self.outf.write(group_ldif)
Example #23
0
    def run(self, groupname, credopts=None, sambaopts=None, versionopts=None, H=None):
        """Delete the group whose sAMAccountName is ``groupname``.

        :raises CommandError: if the group is not found or the delete fails.
        """
        lp = sambaopts.get_loadparm()
        creds = credopts.get_credentials(lp, fallback_machine=True)
        samdb = SamDB(url=H, session_info=system_session(),
                      credentials=creds, lp=lp)

        # Escape the name so it cannot inject LDAP filter syntax (e.g. a '*'
        # matching, and then deleting, an unintended group).
        filter = ("(&(sAMAccountName=%s)(objectClass=group))" %
                  ldb.binary_encode(groupname))

        try:
            res = samdb.search(base=samdb.domain_dn(),
                               scope=ldb.SCOPE_SUBTREE,
                               expression=filter,
                               attrs=["dn"])
            group_dn = res[0].dn
        except IndexError:
            raise CommandError('Unable to find group "%s"' % (groupname))

        try:
            samdb.delete(group_dn)
        except Exception as e:
            # FIXME: catch more specific exception
            raise CommandError('Failed to remove group "%s"' % groupname, e)
        self.outf.write("Deleted group %s\n" % groupname)
Example #24
0
    def run(self, groupname, credopts=None, sambaopts=None,
            versionopts=None, H=None, groupou=None, group_scope=None,
            group_type=None, description=None, mail_address=None, notes=None, gid_number=None, nis_domain=None):
        """Create a new group.

        An empty/None ``group_type`` is treated as "Security". The group type
        value is looked up by ``group_scope``, falling back to the global
        security/distribution type.

        :raises CommandError: if only one of ``gid_number`` / ``nis_domain``
            is given, or group creation fails.
        """
        is_security = not group_type or group_type == "Security"
        if is_security:
            gtype = security_group.get(group_scope, GTYPE_SECURITY_GLOBAL_GROUP)
        else:
            gtype = distribution_group.get(group_scope, GTYPE_DISTRIBUTION_GLOBAL_GROUP)

        # The RFC2307 attributes must be supplied together or not at all.
        if (gid_number is None) != (nis_domain is None):
            raise CommandError('Both --gid-number and --nis-domain have to be set for a RFC2307-enabled group. Operation cancelled.')

        lp = sambaopts.get_loadparm()
        creds = credopts.get_credentials(lp, fallback_machine=True)

        try:
            samdb = SamDB(url=H, session_info=system_session(),
                          credentials=creds, lp=lp)
            samdb.newgroup(groupname, groupou=groupou, grouptype=gtype,
                           description=description, mailaddress=mail_address,
                           notes=notes, gidnumber=gid_number,
                           nisdomain=nis_domain)
        except Exception as e:
            # FIXME: catch more specific exception
            raise CommandError('Failed to create group "%s"' % groupname, e)
        self.outf.write("Added group %s\n" % groupname)
Example #25
0
    def run(self, accountname, principal, credopts=None, sambaopts=None, versionopts=None):
        """Add ``principal`` to an account's msDS-AllowedToDelegateTo.

        :param accountname: account to modify.
        :param principal: value to add to msDS-AllowedToDelegateTo.
        :raises CommandError: if the account is not found exactly once, or
            the modify fails.
        """
        lp = sambaopts.get_loadparm()
        creds = credopts.get_credentials(lp)
        paths = provision.provision_paths_from_lp(lp, lp.get("realm"))
        sam = SamDB(paths.samdb, session_info=system_session(),
                    credentials=creds, lp=lp)
        # TODO once I understand how, use the domain info to naildown
        # to the correct domain
        (cleanedaccount, realm, domain) = _get_user_realm_domain(accountname)

        # Escape the account name so it cannot inject LDAP filter syntax.
        res = sam.search(expression="sAMAccountName=%s" %
                         ldb.binary_encode(cleanedaccount),
                         scope=ldb.SCOPE_SUBTREE,
                         attrs=["msDS-AllowedToDelegateTo"])
        if len(res) != 1:
            raise CommandError("Account %s found %d times" % (accountname, len(res)))

        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["msDS-AllowedToDelegateTo"] = ldb.MessageElement([principal],
                                          ldb.FLAG_MOD_ADD,
                                          "msDS-AllowedToDelegateTo")
        try:
            sam.modify(msg)
        # ``as err`` works on Python 2.6+ and 3 (the comma form is Py2-only).
        except Exception as err:
            raise CommandError(err)
Example #26
0
    def run(self, username=None, filter=None, credopts=None, sambaopts=None,
            versionopts=None, H=None, newpassword=None,
            must_change_at_next_login=False, random_password=False,
            smartcard_required=False, clear_smartcard_required=False):
        """Set a user's password, or switch the account to smartcard-only logon.

        The account is selected by *username* or by an explicit LDAP *filter*.
        With *smartcard_required* the UF_SMARTCARD_REQUIRED flag is set and the
        account is enabled; no password options may be combined with it.
        Otherwise the password is taken from *newpassword*, generated
        (*random_password*), or prompted for interactively, and then written
        to the account (optionally clearing UF_SMARTCARD_REQUIRED first).

        :raises CommandError: on conflicting options or if any update fails.
        """
        if filter is None and username is None:
            raise CommandError("Either the username or '--filter' must be specified!")

        password = newpassword

        if smartcard_required:
            # Password-related options conflict with smartcard-only mode.
            if password is not None and password != '':
                raise CommandError('It is not allowed to specifiy '
                                   '--newpassword '
                                   'together with --smartcard-required.')
            if must_change_at_next_login:
                raise CommandError('It is not allowed to specifiy '
                                   '--must-change-at-next-login '
                                   'together with --smartcard-required.')
            if clear_smartcard_required:
                raise CommandError('It is not allowed to specifiy '
                                   '--clear-smartcard-required '
                                   'together with --smartcard-required.')

        if random_password and not smartcard_required:
            password = generate_random_password(128, 255)

        # Prompt until the same non-empty password is typed twice; no password
        # is needed in smartcard mode.
        while True:
            if smartcard_required:
                break
            if password is not None and password != '':
                break
            password = getpass("New Password: ")
            passwordverify = getpass("Retype Password: ")
            if not password == passwordverify:
                password = None
                self.outf.write("Sorry, passwords do not match.\n")

        if filter is None:
            filter = "(&(objectClass=user)(sAMAccountName=%s))" % (ldb.binary_encode(username))

        lp = sambaopts.get_loadparm()
        creds = credopts.get_credentials(lp)

        # Request a sealed (encrypted) connection; presumably required because
        # the new password travels over it.
        creds.set_gensec_features(creds.get_gensec_features() | gensec.FEATURE_SEAL)

        samdb = SamDB(url=H, session_info=system_session(),
                      credentials=creds, lp=lp)

        # `command` always describes the step in progress so the error below
        # names the operation that failed.
        if smartcard_required:
            command = ""
            try:
                command = "Failed to set UF_SMARTCARD_REQUIRED for user '%s'" % (username or filter)
                flags = dsdb.UF_SMARTCARD_REQUIRED
                samdb.toggle_userAccountFlags(filter, flags, on=True)
                command = "Failed to enable account for user '%s'" % (username or filter)
                samdb.enable_account(filter)
            except Exception as msg:
                # FIXME: catch more specific exception
                raise CommandError("%s: %s" % (command, msg))
            self.outf.write("Added UF_SMARTCARD_REQUIRED OK\n")
        else:
            # NOTE(review): this branch was missing from the copy reviewed
            # (the prompted password was otherwise discarded); confirm it
            # matches the upstream command.
            command = ""
            try:
                if clear_smartcard_required:
                    command = "Failed to remove UF_SMARTCARD_REQUIRED for user '%s'" % (username or filter)
                    flags = dsdb.UF_SMARTCARD_REQUIRED
                    samdb.toggle_userAccountFlags(filter, flags, on=False)
                command = "Failed to set password for user '%s'" % (username or filter)
                samdb.setpassword(filter, password,
                                  force_change_at_next_login=must_change_at_next_login,
                                  username=username)
            except Exception as msg:
                # FIXME: catch more specific exception
                raise CommandError("%s: %s" % (command, msg))
            self.outf.write("Changed password OK\n")
Example #27
0
    def setUp(self):
        """Create the fixtures shared by the userAccountControl tests.

        Sets up an admin connection, an unprivileged user with its own LDAP
        and SAMR connections, and an OU holding a template computer account
        ("testcomputer-t") created via ``add_computer_ldap``.  The template's
        security descriptor is read twice: unchanged (``sd_reference_cc``)
        and with SELF_WRITE/WRITE_PROP OR-ed into the unprivileged user's
        access-allowed ACEs (``sd_reference_modify``).
        """
        super(UserAccountControlTests, self).setUp()
        self.admin_creds = creds
        self.admin_samdb = SamDB(url=ldaphost,
                                 session_info=system_session(),
                                 credentials=self.admin_creds, lp=lp)
        self.domain_sid = security.dom_sid(self.admin_samdb.get_domain_sid())
        self.base_dn = self.admin_samdb.domain_dn()

        # NOTE(review): "******" looks like redaction residue from the copy
        # this was taken from; confirm the intended test user name.
        self.unpriv_user = "******"
        self.unpriv_user_pw = "samba123@"
        self.unpriv_creds = self.get_creds(self.unpriv_user, self.unpriv_user_pw)

        ou_dn = "OU=test_computer_ou1," + self.base_dn
        template_dn = "CN=testcomputer-t,OU=test_computer_ou1,%s" % (self.base_dn)

        # Remove leftovers from an earlier run before recreating them.
        delete_force(self.admin_samdb, template_dn)
        delete_force(self.admin_samdb, ou_dn)
        delete_force(self.admin_samdb, "CN=%s,CN=Users,%s" % (self.unpriv_user, self.base_dn))

        self.admin_samdb.newuser(self.unpriv_user, self.unpriv_user_pw)
        res = self.admin_samdb.search("CN=%s,CN=Users,%s" % (self.unpriv_user, self.admin_samdb.domain_dn()),
                                      scope=SCOPE_BASE,
                                      attrs=["objectSid"])
        self.assertEqual(1, len(res))

        self.unpriv_user_sid = ndr_unpack(security.dom_sid, res[0]["objectSid"][0])
        self.unpriv_user_dn = res[0].dn

        # Unprivileged LDAP connection; presumably what add_computer_ldap()
        # below uses, so the template is created as the unprivileged user.
        self.samdb = SamDB(url=ldaphost, credentials=self.unpriv_creds, lp=lp)

        self.samr = samr.samr("ncacn_ip_tcp:%s[seal]" % host, lp, self.unpriv_creds)
        self.samr_handle = self.samr.Connect2(None, security.SEC_FLAG_MAXIMUM_ALLOWED)
        self.samr_domain = self.samr.OpenDomain(self.samr_handle, security.SEC_FLAG_MAXIMUM_ALLOWED, self.domain_sid)

        self.sd_utils = sd_utils.SDUtils(self.admin_samdb)

        self.admin_samdb.create_ou(ou_dn)
        self.unpriv_user_sid = self.sd_utils.get_object_sid(self.unpriv_user_dn)
        # Object ACE granting the unprivileged user Create-Child (CC) for the
        # class with this GUID (presumably the computer class -- confirm).
        mod = "(OA;;CC;bf967a86-0de6-11d0-a285-00aa003049e2;;%s)" % str(self.unpriv_user_sid)

        old_sd = self.sd_utils.read_sd_on_dn(ou_dn)

        self.sd_utils.dacl_add_ace(ou_dn, mod)
        try:
            self.add_computer_ldap("testcomputer-t")
        finally:
            # Put the OU's original descriptor back even if creating the
            # template computer fails, so the temporary ACE never lingers.
            self.sd_utils.modify_sd_on_dn(ou_dn, old_sd)

        self.computernames = ["testcomputer-0"]

        # Get the SD of the template account, then force it to match
        # what we expect for SeMachineAccountPrivilege accounts, so we
        # can confirm we created the accounts correctly
        self.sd_reference_cc = self.sd_utils.read_sd_on_dn(template_dn)

        self.sd_reference_modify = self.sd_utils.read_sd_on_dn(template_dn)
        for ace in self.sd_reference_modify.dacl.aces:
            if ace.type == security.SEC_ACE_TYPE_ACCESS_ALLOWED and ace.trustee == self.unpriv_user_sid:
                ace.access_mask = ace.access_mask | security.SEC_ADS_SELF_WRITE | security.SEC_ADS_WRITE_PROP

        # Now reconnect without domain admin rights
        self.samdb = SamDB(url=ldaphost, credentials=self.unpriv_creds, lp=lp)
Example #28
0
    def test_modify_replicated_attributes(self):
        """Modifying replicated attributes must be referred to a specific DC.

        For each attribute, a modify through ``self.samdb`` must fail with
        ERR_REFERRAL.  The referred ldap:// URL must accept the same modify,
        and its host must not be just the domain DNS name.
        """
        # some timestamp ones
        dn = 'CN=Guest,CN=Users,' + self.base_dn
        value = 'hallooo'
        for attr in ['carLicense', 'middleName']:
            msg = ldb.Message()
            msg.dn = ldb.Dn(self.samdb, dn)
            msg[attr] = ldb.MessageElement(value,
                                           ldb.FLAG_MOD_REPLACE,
                                           attr)
            try:
                self.samdb.modify(msg)
                self.fail("Failed to fail to modify %s %s" % (dn, attr))
            except ldb.LdbError as e1:
                (ecode, emsg) = e1.args
                if ecode != ldb.ERR_REFERRAL:
                    self.fail("Failed to REFER when trying to modify %s %s" %
                              (dn, attr))
                else:
                    m = re.search(r'(ldap://[^>]+)>', emsg)
                    if m is None:
                        self.fail("referral seems not to refer to anything")
                    address = m.group(1)

                    try:
                        tmpdb = SamDB(address, credentials=CREDS,
                                      session_info=system_session(LP), lp=LP)
                        tmpdb.modify(msg)
                    except ldb.LdbError as e:
                        self.fail("couldn't modify referred location %s: %s" %
                                  (address, e))

                    # `address` begins with "ldap://", so comparing it directly
                    # with the domain name could never match; isolate the host
                    # (drop any path and port) before comparing.
                    referred_host = address[len('ldap://'):]
                    referred_host = referred_host.split('/', 1)[0].split(':', 1)[0]
                    if referred_host.lower() == self.samdb.domain_dns_name().lower():
                        self.fail("referral address did not give a specific DC")
Example #29
0
class BaseDeleteTests(samba.tests.TestCase):
    """Shared fixture and lookup helpers for object-deletion tests.

    ``setUp`` opens a SamDB connection to ``host``; the ``search_*`` helpers
    fetch exactly one object (deleted objects included, via the
    ``show_deleted`` control) by GUID or by DN.
    """

    def GUID_string(self, guid):
        """Return the schema-formatted string form of a binary objectGUID."""
        return self.ldb.schema_format_value("objectGUID", guid)

    def setUp(self):
        super(BaseDeleteTests, self).setUp()
        self.ldb = SamDB(host, credentials=creds, session_info=system_session(lp), lp=lp)

        self.base_dn = self.ldb.domain_dn()
        self.configuration_dn = self.ldb.get_config_basedn().get_linearized()

    def search_guid(self, guid):
        """Return the single object whose objectGUID is *guid*.

        Asserts that exactly one object (possibly deleted) matches.
        """
        guid_str = self.GUID_string(guid)
        print("SEARCH by GUID {0!s}".format(guid_str))

        res = self.ldb.search(base="<GUID={0!s}>".format(guid_str),
                              scope=SCOPE_BASE, controls=["show_deleted:1"])
        self.assertEqual(len(res), 1)
        return res[0]

    def search_dn(self, dn):
        """Return the single object at *dn* (possibly deleted).

        Asserts that exactly one object matches.
        """
        print("SEARCH by DN {0!s}".format(dn))

        res = self.ldb.search(expression="(objectClass=*)",
                              base=dn,
                              scope=SCOPE_BASE,
                              controls=["show_deleted:1"])
        self.assertEqual(len(res), 1)
        return res[0]
Example #30
0
    def run(self, username, password=None, credopts=None, sambaopts=None,
            versionopts=None, H=None, must_change_at_next_login=False, random_password=False,
            use_username_as_cn=False, userou=None, surname=None, given_name=None, initials=None,
            profile_path=None, script_path=None, home_drive=None, home_directory=None,
            job_title=None, department=None, company=None, description=None,
            mail_address=None, internet_address=None, telephone_number=None, physical_delivery_office=None):
        """Create a new user account.

        The password is taken from *password*, generated when
        *random_password* is set, or prompted for until it is entered twice
        identically.  The remaining options are passed through to
        ``SamDB.newuser``.

        :raises CommandError: if connecting or creating the user fails.
        """
        # NOTE(review): everything after the first prompt was corrupted in the
        # copy reviewed and has been reconstructed; verify against upstream.
        if random_password:
            password = generate_random_password(128, 255)

        while True:
            if password is not None and password != '':
                break
            password = getpass("New Password: ")
            passwordverify = getpass("Retype Password: ")
            if not password == passwordverify:
                password = None
                self.outf.write("Sorry, passwords do not match.\n")

        lp = sambaopts.get_loadparm()
        creds = credopts.get_credentials(lp)

        try:
            samdb = SamDB(url=H, session_info=system_session(),
                          credentials=creds, lp=lp)
            samdb.newuser(username, password,
                          force_password_change_at_next_login_req=must_change_at_next_login,
                          useusernameascn=use_username_as_cn, userou=userou,
                          surname=surname, givenname=given_name, initials=initials,
                          profilepath=profile_path, homedrive=home_drive,
                          scriptpath=script_path, homedirectory=home_directory,
                          jobtitle=job_title, department=department,
                          company=company, description=description,
                          mailaddress=mail_address, internetaddress=internet_address,
                          telephonenumber=telephone_number,
                          physicaldeliveryoffice=physical_delivery_office)
        except Exception as e:
            # FIXME: catch more specific exception
            raise CommandError("Failed to add user '%s': " % username, e)

        self.outf.write("User '%s' created successfully\n" % username)
Example #31
0
class inf_to_ldb(gp_ext_setter):
    '''Apply "System Access" settings from a GPO .inf file to the SamDB.

    ``mapper`` ties each supported .inf key to a (setter, value-converter)
    pair.  Each setter logs the change, stores the previous value in
    ``gp_db`` and writes the new value through an ldb connection; no
    registry access is involved.
    '''

    def __init__(self, logger, gp_db, lp, creds, key, value):
        super(inf_to_ldb, self).__init__(logger, gp_db, lp, creds, key, value)
        try:
            self.ldb = SamDB(self.lp.samdb_url(),
                             session_info=system_session(),
                             credentials=self.creds,
                             lp=self.lp)
        except (NameError, LdbError):
            raise Exception('Failed to load SamDB for assigning Group Policy')

    def _apply_policy_change(self, message, read_current, write_new, val):
        '''Log, record and apply one policy value.

        *message* is a two-slot format string filled with (previous, new).
        The previous value is saved in gp_db before *write_new* is called.
        '''
        previous = read_current()
        self.logger.info(message % (previous, val))
        self.gp_db.store(str(self), self.attribute, str(previous))
        write_new(val)

    def ch_minPwdAge(self, val):
        self._apply_policy_change(
            'KDC Minimum Password age was changed from %s to %s',
            self.ldb.get_minPwdAge, self.ldb.set_minPwdAge, val)

    def ch_maxPwdAge(self, val):
        self._apply_policy_change(
            'KDC Maximum Password age was changed from %s to %s',
            self.ldb.get_maxPwdAge, self.ldb.set_maxPwdAge, val)

    def ch_minPwdLength(self, val):
        self._apply_policy_change(
            'KDC Minimum Password length was changed from %s to %s',
            self.ldb.get_minPwdLength, self.ldb.set_minPwdLength, val)

    def ch_pwdProperties(self, val):
        self._apply_policy_change(
            'KDC Password Properties were changed from %s to %s',
            self.ldb.get_pwdProperties, self.ldb.set_pwdProperties, val)

    def days2rel_nttime(self):
        '''Turn ``self.val`` (a day count) into a negative NT time string.

        NT time counts 100-nanosecond ticks, i.e. 10**7 per second.
        '''
        day_count = int(self.val)
        ticks_per_day = 24 * 60 * 60 * 10000000
        return str(-(day_count * ticks_per_day))

    def mapper(self):
        '''ldap value : samba setter'''
        as_nttime = self.days2rel_nttime
        # Could be none, but I like the method assignment in update_samba
        unchanged = self.explicit
        return dict([
            ("minPwdAge", (self.ch_minPwdAge, as_nttime)),
            ("maxPwdAge", (self.ch_maxPwdAge, as_nttime)),
            ("minPwdLength", (self.ch_minPwdLength, unchanged)),
            ("pwdProperties", (self.ch_pwdProperties, unchanged)),
        ])

    def __str__(self):
        return 'System Access'
Example #32
0
 def get_samdb(self):
     """Return a SamDB connected over LDAP to the server from get_lp_et_al()."""
     loadparm, credentials, target = self.get_lp_et_al()
     return SamDB('ldap://' + target, credentials=credentials, lp=loadparm)
Example #33
0
    def __init__(self, domain_sid, invocationid=None, schemadn=None,
                 files=None, override_prefixmap=None, additional_prefixmap=None,
                 base_schema=None):
        """Load schema for the SamDB from the AD schema files and
        samba4_schema.ldif

        :param domain_sid: domain SID used to build the schema partition's
            security descriptor.
        :param invocationid: optional invocation ID set on the in-memory ldb.
        :param schemadn: DN of the schema
        :param files: optional paths whose contents are appended to the
            schema data.
        :param override_prefixmap: prefixMap text used instead of
            prefixMap.txt.
        :param additional_prefixmap: optional iterable of extra prefixMap
            lines.
        :param base_schema: key of ``Schema.base_schemas``; defaults to
            ``Schema.default_base_schema()``.

        The combined schema text (``schema_data``), the LDIF that adds and
        modifies the schema DN (``schema_dn_add``, ``schema_dn_modify``) and
        the base64 prefixMap are kept on the instance, so they can be added
        to a database later without re-parsing the source files.
        """
        from samba.provision import setup_path

        if base_schema is None:
            base_schema = Schema.default_base_schema()

        self.base_schema = base_schema

        self.schemadn = schemadn
        # We need to have the am_rodc=False just to keep some warnings quiet -
        # this isn't a real SAM, so it's meaningless.
        self.ldb = SamDB(global_schema=False, am_rodc=False)
        if invocationid is not None:
            self.ldb.set_invocation_id(invocationid)

        self.schema_data = read_ms_schema(
            setup_path('ad-schema/%s' % Schema.base_schemas[base_schema][0]),
            setup_path('ad-schema/%s' % Schema.base_schemas[base_schema][1]))

        if files is not None:
            for path in files:
                with open(path, 'r') as extra_file:
                    self.schema_data += extra_file.read()

        self.schema_data = substitute_var(self.schema_data,
            {"SCHEMADN": schemadn})
        check_all_substituted(self.schema_data)

        schema_version = str(Schema.get_version(base_schema))
        self.schema_dn_modify = read_and_sub_file(
            setup_path("provision_schema_basedn_modify.ldif"),
            {"SCHEMADN": schemadn, "OBJVERSION" : schema_version})

        descr = b64encode(get_schema_descriptor(domain_sid)).decode('utf8')
        self.schema_dn_add = read_and_sub_file(
            setup_path("provision_schema_basedn.ldif"),
            {"SCHEMADN": schemadn, "DESCRIPTOR": descr})

        if override_prefixmap is not None:
            self.prefixmap_data = override_prefixmap
        else:
            with open(setup_path("prefixMap.txt"), 'r') as prefixmap_file:
                self.prefixmap_data = prefixmap_file.read()

        if additional_prefixmap is not None:
            for entry in additional_prefixmap:
                self.prefixmap_data += "%s\n" % entry

        # b64encode() requires bytes on Python 3, but the prefixMap text is
        # read in text mode; encode it first if it is not already bytes.
        prefixmap_bytes = self.prefixmap_data
        if not isinstance(prefixmap_bytes, bytes):
            prefixmap_bytes = prefixmap_bytes.encode('utf8')
        self.prefixmap_data = b64encode(prefixmap_bytes).decode('utf8')

        # We don't actually add this ldif, just parse it
        prefixmap_ldif = "dn: %s\nprefixMap:: %s\n\n" % (self.schemadn, self.prefixmap_data)
        self.set_from_ldif(prefixmap_ldif, self.schema_data, self.schemadn)
Example #34
0
class Schema(object):
    """In-memory AD schema built from Microsoft schema files plus extra LDIF.

    The schema is parsed into a private SamDB (``self.ldb``) and can be
    written to a temporary ldb or converted for OpenLDAP.
    """

    # the schema files (and corresponding object version) that we know about
    base_schemas = {
       "2008_R2_old" : ("MS-AD_Schema_2K8_R2_Attributes.txt",
                        "MS-AD_Schema_2K8_R2_Classes.txt",
                        47),
       "2008_R2" : ("Attributes_for_AD_DS__Windows_Server_2008_R2.ldf",
                    "Classes_for_AD_DS__Windows_Server_2008_R2.ldf",
                    47),
       "2012"    : ("AD_DS_Attributes__Windows_Server_2012.ldf",
                    "AD_DS_Classes__Windows_Server_2012.ldf",
                    56),
       "2012_R2" : ("AD_DS_Attributes__Windows_Server_2012_R2.ldf",
                    "AD_DS_Classes__Windows_Server_2012_R2.ldf",
                    69),
    }

    def __init__(self, domain_sid, invocationid=None, schemadn=None,
                 files=None, override_prefixmap=None, additional_prefixmap=None,
                 base_schema=None):
        """Load schema for the SamDB from the AD schema files and
        samba4_schema.ldif

        :param domain_sid: domain SID used to build the schema partition's
            security descriptor.
        :param invocationid: optional invocation ID set on the in-memory ldb.
        :param schemadn: DN of the schema
        :param files: optional paths whose contents are appended to the
            schema data.
        :param override_prefixmap: prefixMap text used instead of
            prefixMap.txt.
        :param additional_prefixmap: optional iterable of extra prefixMap
            lines.
        :param base_schema: key of ``base_schemas``; defaults to
            ``default_base_schema()``.

        The combined schema text (``schema_data``), the LDIF that adds and
        modifies the schema DN (``schema_dn_add``, ``schema_dn_modify``) and
        the base64 prefixMap are kept on the instance, so they can be added
        to a database later without re-parsing the source files.
        """
        from samba.provision import setup_path

        if base_schema is None:
            base_schema = Schema.default_base_schema()

        self.base_schema = base_schema

        self.schemadn = schemadn
        # We need to have the am_rodc=False just to keep some warnings quiet -
        # this isn't a real SAM, so it's meaningless.
        self.ldb = SamDB(global_schema=False, am_rodc=False)
        if invocationid is not None:
            self.ldb.set_invocation_id(invocationid)

        self.schema_data = read_ms_schema(
            setup_path('ad-schema/%s' % Schema.base_schemas[base_schema][0]),
            setup_path('ad-schema/%s' % Schema.base_schemas[base_schema][1]))

        if files is not None:
            for path in files:
                with open(path, 'r') as extra_file:
                    self.schema_data += extra_file.read()

        self.schema_data = substitute_var(self.schema_data,
            {"SCHEMADN": schemadn})
        check_all_substituted(self.schema_data)

        schema_version = str(Schema.get_version(base_schema))
        self.schema_dn_modify = read_and_sub_file(
            setup_path("provision_schema_basedn_modify.ldif"),
            {"SCHEMADN": schemadn, "OBJVERSION" : schema_version})

        descr = b64encode(get_schema_descriptor(domain_sid)).decode('utf8')
        self.schema_dn_add = read_and_sub_file(
            setup_path("provision_schema_basedn.ldif"),
            {"SCHEMADN": schemadn, "DESCRIPTOR": descr})

        if override_prefixmap is not None:
            self.prefixmap_data = override_prefixmap
        else:
            with open(setup_path("prefixMap.txt"), 'r') as prefixmap_file:
                self.prefixmap_data = prefixmap_file.read()

        if additional_prefixmap is not None:
            for entry in additional_prefixmap:
                self.prefixmap_data += "%s\n" % entry

        # b64encode() requires bytes on Python 3, but the prefixMap text is
        # read in text mode; encode it first if it is not already bytes.
        prefixmap_bytes = self.prefixmap_data
        if not isinstance(prefixmap_bytes, bytes):
            prefixmap_bytes = prefixmap_bytes.encode('utf8')
        self.prefixmap_data = b64encode(prefixmap_bytes).decode('utf8')

        # We don't actually add this ldif, just parse it
        prefixmap_ldif = "dn: %s\nprefixMap:: %s\n\n" % (self.schemadn, self.prefixmap_data)
        self.set_from_ldif(prefixmap_ldif, self.schema_data, self.schemadn)

    @staticmethod
    def default_base_schema():
        """Returns the default base schema to use"""
        return "2008_R2"

    @staticmethod
    def get_version(base_schema):
        """Returns the base schema's object version, e.g. 47 for 2008_R2"""
        return Schema.base_schemas[base_schema][2]

    def set_from_ldif(self, pf, df, dn):
        """Load prefixMap LDIF *pf* and schema LDIF *df* for *dn* into self.ldb."""
        dsdb._dsdb_set_schema_from_ldif(self.ldb, pf, df, dn)

    def write_to_tmp_ldb(self, schemadb_path):
        """Write the loaded schema into the ldb at *schemadb_path*.

        All additions happen in one transaction, which is cancelled on any
        error.
        """
        self.ldb.connect(url=schemadb_path)
        self.ldb.transaction_start()
        try:
            # These are actually ignored, as the schema has been forced
            # when the ldb object was created, and that overrides this
            self.ldb.add_ldif("""dn: @ATTRIBUTES
linkID: INTEGER

dn: @INDEXLIST
@IDXATTR: linkID
@IDXATTR: attributeSyntax
@IDXGUID: objectGUID
""")

            schema_dn_add = self.schema_dn_add \
                            + "objectGUID: 24e2ca70-b093-4ae8-84c0-2d7ac652a1b8\n"

            # These bits of LDIF are supplied when the Schema object is created
            self.ldb.add_ldif(schema_dn_add)
            self.ldb.modify_ldif(self.schema_dn_modify)
            self.ldb.add_ldif(self.schema_data)
        except:
            # Deliberately bare: cancel the transaction on *any* exception
            # (including KeyboardInterrupt), then re-raise it unchanged.
            self.ldb.transaction_cancel()
            raise
        else:
            self.ldb.transaction_commit()

    # Return a hash with the forward attribute as a key and the back as the
    # value
    def linked_attributes(self):
        return get_linked_attributes(self.schemadn, self.ldb)

    def dnsyntax_attributes(self):
        return get_dnsyntax_attributes(self.schemadn, self.ldb)

    def convert_to_openldap(self, target, mapping):
        return dsdb._dsdb_convert_schema_to_openldap(self.ldb, target, mapping)
Example #35
0
class Schema(object):
    """In-memory AD schema built from the 2008 R2 schema files plus extra LDIF.

    The schema is parsed into a private SamDB (``self.ldb``) and can be
    written to a temporary ldb or converted for OpenLDAP.
    """

    def __init__(self,
                 domain_sid,
                 invocationid=None,
                 schemadn=None,
                 files=None,
                 override_prefixmap=None,
                 additional_prefixmap=None):
        """Load schema for the SamDB from the AD schema files and
        samba4_schema.ldif

        :param domain_sid: domain SID used to build the schema partition's
            security descriptor.
        :param invocationid: optional invocation ID set on the in-memory ldb.
        :param schemadn: DN of the schema
        :param files: optional paths whose contents are appended to the
            schema data.
        :param override_prefixmap: prefixMap text used instead of
            prefixMap.txt.
        :param additional_prefixmap: optional iterable of extra prefixMap
            lines.

        The combined schema text (``schema_data``), the LDIF that adds and
        modifies the schema DN (``schema_dn_add``, ``schema_dn_modify``) and
        the base64 prefixMap are kept on the instance, so they can be added
        to a database later without re-parsing the source files.
        """
        from samba.provision import setup_path

        self.schemadn = schemadn
        # We need to have the am_rodc=False just to keep some warnings quiet -
        # this isn't a real SAM, so it's meaningless.
        self.ldb = SamDB(global_schema=False, am_rodc=False)
        if invocationid is not None:
            self.ldb.set_invocation_id(invocationid)

        self.schema_data = read_ms_schema(
            setup_path('ad-schema/MS-AD_Schema_2K8_R2_Attributes.txt'),
            setup_path('ad-schema/MS-AD_Schema_2K8_R2_Classes.txt'))

        if files is not None:
            for path in files:
                with open(path, 'r') as extra_file:
                    self.schema_data += extra_file.read()

        self.schema_data = substitute_var(self.schema_data,
                                          {"SCHEMADN": schemadn})
        check_all_substituted(self.schema_data)

        self.schema_dn_modify = read_and_sub_file(
            setup_path("provision_schema_basedn_modify.ldif"),
            {"SCHEMADN": schemadn})

        # NOTE(review): these b64encode() results are used directly as text,
        # which assumes Python 2 str semantics; confirm before porting.
        descr = b64encode(get_schema_descriptor(domain_sid))
        self.schema_dn_add = read_and_sub_file(
            setup_path("provision_schema_basedn.ldif"), {
                "SCHEMADN": schemadn,
                "DESCRIPTOR": descr
            })

        if override_prefixmap is not None:
            self.prefixmap_data = override_prefixmap
        else:
            with open(setup_path("prefixMap.txt"), 'r') as prefixmap_file:
                self.prefixmap_data = prefixmap_file.read()

        if additional_prefixmap is not None:
            for entry in additional_prefixmap:
                self.prefixmap_data += "%s\n" % entry

        self.prefixmap_data = b64encode(self.prefixmap_data)

        # We don't actually add this ldif, just parse it
        prefixmap_ldif = "dn: cn=schema\nprefixMap:: %s\n\n" % self.prefixmap_data
        self.set_from_ldif(prefixmap_ldif, self.schema_data)

    def set_from_ldif(self, pf, df):
        """Load prefixMap LDIF *pf* and schema LDIF *df* into self.ldb."""
        dsdb._dsdb_set_schema_from_ldif(self.ldb, pf, df)

    def write_to_tmp_ldb(self, schemadb_path):
        """Write the loaded schema into the ldb at *schemadb_path*.

        All additions happen in one transaction, which is cancelled if any
        step raises.
        """
        self.ldb.connect(url=schemadb_path)
        self.ldb.transaction_start()
        try:
            self.ldb.add_ldif("""dn: @ATTRIBUTES
linkID: INTEGER

dn: @INDEXLIST
@IDXATTR: linkID
@IDXATTR: attributeSyntax
""")
            # These bits of LDIF are supplied when the Schema object is created
            self.ldb.add_ldif(self.schema_dn_add)
            self.ldb.modify_ldif(self.schema_dn_modify)
            self.ldb.add_ldif(self.schema_data)
        except Exception:
            self.ldb.transaction_cancel()
            raise
        else:
            self.ldb.transaction_commit()

    # Return a hash with the forward attribute as a key and the back as the
    # value
    def linked_attributes(self):
        return get_linked_attributes(self.schemadn, self.ldb)

    def dnsyntax_attributes(self):
        return get_dnsyntax_attributes(self.schemadn, self.ldb)

    def convert_to_openldap(self, target, mapping):
        return dsdb._dsdb_convert_schema_to_openldap(self.ldb, target, mapping)
Example #36
0
    def run(self,
            computername,
            credopts=None,
            sambaopts=None,
            versionopts=None,
            H=None,
            editor=None):
        """Edit a computer account's attributes in a text editor.

        The account's LDIF is written to a temporary file and opened with
        *editor* (default: ``$EDITOR``, otherwise ``vi``).  The edited LDIF
        is diffed against the original and only the difference is applied.

        :raises CommandError: if the computer is not found exactly once, the
            editor exits unsuccessfully, or the modify fails.
        """
        lp = sambaopts.get_loadparm()
        creds = credopts.get_credentials(lp, fallback_machine=True)
        samdb = SamDB(url=H,
                      session_info=system_session(),
                      credentials=creds,
                      lp=lp)

        # Computer accounts' sAMAccountName ends in '$'; accept either form.
        samaccountname = computername
        if not computername.endswith('$'):
            samaccountname = "%s$" % computername

        filter = (
            "(&(sAMAccountType=%d)(sAMAccountName=%s))" %
            (dsdb.ATYPE_WORKSTATION_TRUST, ldb.binary_encode(samaccountname)))

        domaindn = samdb.domain_dn()

        try:
            res = samdb.search(base=domaindn,
                               expression=filter,
                               scope=ldb.SCOPE_SUBTREE)
            computer_dn = res[0].dn
        except IndexError:
            raise CommandError('Unable to find computer "%s"' % (computername))

        if len(res) != 1:
            raise CommandError('Invalid number of results: for "%s": %d' %
                               ((computername), len(res)))

        msg = res[0]
        result_ldif = common.get_ldif_for_editor(samdb, msg)

        if editor is None:
            editor = os.environ.get('EDITOR')
            if editor is None:
                editor = 'vi'

        with tempfile.NamedTemporaryFile(suffix=".tmp") as t_file:
            t_file.write(get_bytes(result_ldif))
            t_file.flush()
            try:
                check_call([editor, t_file.name])
            except CalledProcessError as e:
                # CalledProcessError takes (returncode, cmd), not a message;
                # report the editor failure as an ordinary command error.
                raise CommandError("ERROR: ", e)
            with open(t_file.name) as edited_file:
                edited_message = edited_file.read()

        msgs_edited = samdb.parse_ldif(edited_message)
        msg_edited = next(msgs_edited)[1]

        res_msg_diff = samdb.msg_diff(msg, msg_edited)
        if len(res_msg_diff) == 0:
            self.outf.write("Nothing to do\n")
            return

        try:
            samdb.modify(res_msg_diff)
        except Exception as e:
            # The format string has a single placeholder; the exception is
            # passed to CommandError separately.
            raise CommandError("Failed to modify computer '%s': " %
                               computername, e)

        self.outf.write("Modified computer '%s' successfully\n" % computername)
Example #37
0
    def run(self,
            groupname,
            credopts=None,
            sambaopts=None,
            versionopts=None,
            H=None,
            hide_expired=False,
            hide_disabled=False,
            full_dn=False):
        """Print the members of *groupname*, one per line.

        A member is any object whose primaryGroupID equals the group's RID or
        whose memberOf contains the group.  Each member is printed by
        sAMAccountName (falling back to cn), or by DN when *full_dn* is set.
        Expired and/or disabled accounts can be excluded.

        :raises CommandError: on any failure, including an unknown group.
        """
        lp = sambaopts.get_loadparm()
        creds = credopts.get_credentials(lp, fallback_machine=True)

        try:
            samdb = SamDB(url=H,
                          session_info=system_session(),
                          credentials=creds,
                          lp=lp)

            group_lookup = ("(&(objectClass=group)(sAMAccountName=%s))" %
                            ldb.binary_encode(groupname))
            try:
                found = samdb.search(samdb.domain_dn(),
                                     scope=ldb.SCOPE_SUBTREE,
                                     expression=(group_lookup),
                                     attrs=["objectSid"])
                sid_blob = found[0].get('objectSid', idx=0)
            except IndexError:
                raise CommandError('Unable to find group "%s"' % (groupname))

            group_sid = ndr_unpack(security.dom_sid, sid_blob)
            _, rid = group_sid.split()
            sid_reference = "<SID=%s>" % (group_sid)

            # Optional clauses appended to the member filter.
            expiry_clause = ""
            if hide_expired is True:
                now_nttime = samdb.get_nttime()
                expiry_clause = ("(|"
                                 "(!(accountExpires=*))"
                                 "(accountExpires=0)"
                                 "(accountExpires>=%u)"
                                 ")" % (now_nttime))

            disabled_clause = ""
            if hide_disabled is True:
                disabled_clause = "(!(userAccountControl:%s:=%u))" % (
                    ldb.OID_COMPARATOR_AND, UF_ACCOUNTDISABLE)

            member_filter = "(&(|(primaryGroupID=%s)(memberOf=%s))%s%s)" % (
                rid, sid_reference, disabled_clause, expiry_clause)

            members = samdb.search(samdb.domain_dn(),
                                   scope=ldb.SCOPE_SUBTREE,
                                   expression=member_filter,
                                   attrs=["samAccountName", "cn"])

            for entry in members:
                if full_dn:
                    self.outf.write("%s\n" % entry.get("dn"))
                else:
                    display_name = entry.get("samAccountName", idx=0)
                    if display_name is None:
                        display_name = entry.get("cn", idx=0)
                    self.outf.write("%s\n" % display_name)

        except Exception as e:
            raise CommandError('Failed to list members of "%s" group - %s' %
                               (groupname, e))
Example #38
0
    def run(self,
            groupname,
            credopts=None,
            sambaopts=None,
            versionopts=None,
            H=None,
            mail_address=None,
            samaccountname=None,
            force_new_cn=None,
            reset_cn=None):
        """Rename a group and/or change its sAMAccountName or mail address.

        The CN is changed only if it differs from the target name, and
        either the old CN equalled the old sAMAccountName or the change was
        requested with --reset-cn / --force-new-cn.  Attribute changes and
        the rename run in one transaction.

        :raises CommandError: on conflicting/destructive options, an unknown
            group, or a failed update.
        """
        # Reject contradictory or destructive option values up front.
        if force_new_cn and reset_cn:
            raise CommandError("It is not allowed to specify --force-new-cn "
                               "together with --reset-cn.")
        if force_new_cn == "":
            raise CommandError("Failed to rename group - delete protected "
                               "attribute 'CN'")
        if samaccountname == "":
            raise CommandError("Failed to rename group - delete protected "
                               "attribute 'sAMAccountName'")

        lp = sambaopts.get_loadparm()
        creds = credopts.get_credentials(lp, fallback_machine=True)
        samdb = SamDB(url=H,
                      session_info=system_session(),
                      credentials=creds,
                      lp=lp)
        search_base = ldb.Dn(samdb, samdb.domain_dn())

        group_lookup = ("(&(objectClass=group)(samaccountname=%s))" %
                        ldb.binary_encode(groupname))
        try:
            found = samdb.search(base=search_base,
                                 scope=ldb.SCOPE_SUBTREE,
                                 expression=group_lookup,
                                 attrs=["sAMAccountName", "cn", "mail"])
            old_group = found[0]
            group_dn = old_group.dn
        except IndexError:
            raise CommandError('Unable to find group "%s"' % (groupname))

        parent_dn = group_dn.parent()
        old_cn = old_group["cn"][0]

        # Target CN: explicit CN, else the new account name, else unchanged.
        if force_new_cn is not None:
            new_cn = force_new_cn
        elif samaccountname is not None:
            new_cn = samaccountname
        else:
            new_cn = old_group["sAMAccountName"]

        expected_cn = old_group["sAMAccountName"]
        cn_differs = str(old_cn) != str(new_cn)
        must_change_cn = cn_differs and \
                         (str(old_cn) == str(expected_cn) or
                          reset_cn or bool(force_new_cn))

        new_group_dn = ldb.Dn(samdb, "CN=%s" % new_cn)
        new_group_dn.add_base(parent_dn)

        changes = ldb.Message()
        changes.dn = group_dn
        samdb.prepare_attr_replace(changes, old_group, "sAMAccountName",
                                   samaccountname)
        samdb.prepare_attr_replace(changes, old_group, "mail",
                                   mail_address)

        group_attributes_changed = len(changes) > 0

        samdb.transaction_start()
        try:
            if group_attributes_changed:
                samdb.modify(changes)
            if must_change_cn:
                samdb.rename(group_dn, new_group_dn)
        except Exception as e:
            samdb.transaction_cancel()
            raise CommandError('Failed to rename group "%s"' % groupname, e)
        else:
            samdb.transaction_commit()

        if must_change_cn:
            self.outf.write('Renamed CN of group "%s" from "%s" to "%s" '
                            'successfully\n' % (groupname, old_cn, new_cn))

        if not group_attributes_changed:
            return

        self.outf.write('Following attributes of group "%s" have been '
                        'changed successfully:\n' % (groupname))
        for attr_name in group_attrs_keys(changes):
            element = changes[attr_name]
            self.outf.write(
                '%s: %s\n' %
                (attr_name, element if element else '[removed]'))
Example #39
0
    def run(self, sambaopts=None, credopts=None, backup_file=None,
            targetdir=None, newservername=None, host_ip=None, host_ip6=None,
            site=None):
        """Restore a domain backup tarball into targetdir as a new DC.

        Extracts backup_file into targetdir, then turns the restored
        databases into a working single-DC domain:

        - creates an account for ``newservername``;
        - seizes the DNS and FSMO roles;
        - removes every other DC and clears repsFrom/repsTo;
        - resets the krbtgt password twice;
        - restores sysvol;
        - strips the markers the backup process added.

        :param sambaopts: Samba options; if they name an smb.conf it replaces
            the backed-up one (the backed-up copy is kept as smb.conf.orig).
        :param credopts: credential options used for the new DC account.
        :param backup_file: path to the backup tarball (must exist).
        :param targetdir: directory to restore into (must be empty unless
            SAMBA_SELFTEST=1).
        :param newservername: NetBIOS name of the restored DC (upper-cased).
        :param host_ip: IPv4 address, only used for DNS registration of
            "rename" backups.
        :param host_ip6: IPv6 address, only used for DNS registration of
            "rename" backups.
        :param site: site to add the DC to; defaults to the value returned
            by create_default_site().
        :raises CommandError: on missing backup file, target directory or
            server name, or a non-empty target directory.
        """
        if not (backup_file and os.path.exists(backup_file)):
            raise CommandError('Backup file not found.')
        if targetdir is None:
            raise CommandError('Please specify a target directory')
        # allow restoredc to install into a directory prepopulated by selftest
        if (os.path.exists(targetdir) and os.listdir(targetdir) and
            os.environ.get('SAMBA_SELFTEST') != '1'):
            raise CommandError('Target directory is not empty')
        if not newservername:
            raise CommandError('Server name required')

        logger = logging.getLogger()
        logger.setLevel(logging.DEBUG)
        logger.addHandler(logging.StreamHandler(sys.stdout))

        # ldapcmp prefers the server's netBIOS name in upper-case
        newservername = newservername.upper()

        # extract the backup .tar to a temp directory
        # NOTE(review): tf is not closed if extractall() raises; a `with`
        # block would guarantee cleanup.
        targetdir = os.path.abspath(targetdir)
        tf = tarfile.open(backup_file)
        tf.extractall(targetdir)
        tf.close()

        # use the smb.conf that got backed up, by default (save what was
        # actually backed up, before we mess with it)
        smbconf = os.path.join(targetdir, 'etc', 'smb.conf')
        shutil.copyfile(smbconf, smbconf + ".orig")

        # if a smb.conf was specified on the cmd line, then use that instead
        cli_smbconf = sambaopts.get_loadparm_path()
        if cli_smbconf:
            logger.info("Using %s as restored domain's smb.conf" % cli_smbconf)
            shutil.copyfile(cli_smbconf, smbconf)

        lp = samba.param.LoadParm()
        lp.load(smbconf)

        # open a DB connection to the restored DB
        private_dir = os.path.join(targetdir, 'private')
        samdb_path = os.path.join(private_dir, 'sam.ldb')
        samdb = SamDB(url=samdb_path, session_info=system_session(), lp=lp)
        # the backup type ("offline", "rename", ...) selects the extra
        # fix-up steps performed below
        backup_type = self.get_backup_type(samdb)

        if site is None:
            # There's no great way to work out the correct site to add the
            # restored DC to. By default, add it to Default-First-Site-Name,
            # creating the site if it doesn't already exist
            site = self.create_default_site(samdb, logger)
            logger.info("Adding new DC to site '{0}'".format(site))

        # read the naming contexts out of the DB
        res = samdb.search(base="", scope=ldb.SCOPE_BASE,
                           attrs=['namingContexts'])
        ncs = [str(r) for r in res[0].get('namingContexts')]

        # for offline backups we need to make sure the upToDateness info
        # contains the invocation-ID and highest-USN of the DC we backed up.
        # Otherwise replication propagation dampening won't correctly filter
        # objects created by that DC
        if backup_type == "offline":
            self.save_uptodate_vectors(samdb, ncs)

        # Create account using the join_add_objects function in the join object
        # We need namingContexts, account control flags, and the sid saved by
        # the backup process.
        creds = credopts.get_credentials(lp)
        ctx = DCJoinContext(logger, creds=creds, lp=lp, site=site,
                            forced_local_samdb=samdb,
                            netbios_name=newservername)
        ctx.nc_list = ncs
        ctx.full_nc_list = ncs
        ctx.userAccountControl = (samba.dsdb.UF_SERVER_TRUST_ACCOUNT |
                                  samba.dsdb.UF_TRUSTED_FOR_DELEGATION)

        # rewrite the smb.conf to make sure it uses the new targetdir settings.
        # (This doesn't update all filepaths in a customized config, but it
        # corrects the same paths that get set by a new provision)
        logger.info('Updating basic smb.conf settings...')
        make_smbconf(smbconf, newservername, ctx.domain_name,
                     ctx.realm, targetdir, lp=lp,
                     serverrole="active directory domain controller")

        # Get the SID saved by the backup process and create account
        res = samdb.search(base=ldb.Dn(samdb, "@SAMBA_DSDB"),
                           scope=ldb.SCOPE_BASE,
                           attrs=['sidForRestore'])
        sid = res[0].get('sidForRestore')[0]
        logger.info('Creating account with SID: ' + str(sid))
        ctx.join_add_objects(specified_sid=dom_sid(str(sid)))

        # point the rootDSE's dsServiceName at the new DC's NTDS settings
        # object (referenced by GUID)
        m = ldb.Message()
        m.dn = ldb.Dn(samdb, '@ROOTDSE')
        ntds_guid = str(ctx.ntds_guid)
        m["dsServiceName"] = ldb.MessageElement("<GUID=%s>" % ntds_guid,
                                                ldb.FLAG_MOD_REPLACE,
                                                "dsServiceName")
        samdb.modify(m)

        # if we renamed the backed-up domain, then we need to add the DNS
        # objects for the new realm (we do this in the restore, now that we
        # know the new DC's IP address)
        if backup_type == "rename":
            self.register_dns_zone(logger, samdb, lp, ctx.ntds_guid,
                                   host_ip, host_ip6, site)

        # write the new DC's machine account secrets into secrets.ldb
        secrets_path = os.path.join(private_dir, 'secrets.ldb')
        secrets_ldb = Ldb(secrets_path, session_info=system_session(), lp=lp)
        secretsdb_self_join(secrets_ldb, domain=ctx.domain_name,
                            realm=ctx.realm, dnsdomain=ctx.dnsdomain,
                            netbiosname=ctx.myname, domainsid=ctx.domsid,
                            machinepass=ctx.acct_pass,
                            key_version_number=ctx.key_version_number,
                            secure_channel_type=misc.SEC_CHAN_BDC)

        # Seize DNS roles
        # Each entry is (infrastructure-object prefix, base DN); the two are
        # concatenated to form the DN whose fSMORoleOwner is overwritten.
        # NOTE(review): the guard below tests whether the bare domain/forest
        # DN is a naming context, not the DC=DomainDnsZones/ForestDnsZones
        # partition itself. The domain DN is presumably always in ncs, so
        # the guard likely never skips the DomainDnsZones case -- confirm
        # whether the partition DN was intended.
        domain_dn = samdb.domain_dn()
        forest_dn = samba.dn_from_dns_name(samdb.forest_dns_name())
        domaindns_dn = ("CN=Infrastructure,DC=DomainDnsZones,", domain_dn)
        forestdns_dn = ("CN=Infrastructure,DC=ForestDnsZones,", forest_dn)
        for dn_prefix, dns_dn in [forestdns_dn, domaindns_dn]:
            if dns_dn not in ncs:
                continue
            full_dn = dn_prefix + dns_dn
            m = ldb.Message()
            m.dn = ldb.Dn(samdb, full_dn)
            m["fSMORoleOwner"] = ldb.MessageElement(samdb.get_dsServiceName(),
                                                    ldb.FLAG_MOD_REPLACE,
                                                    "fSMORoleOwner")
            samdb.modify(m)

        # Seize other roles
        for role in ['rid', 'pdc', 'naming', 'infrastructure', 'schema']:
            self.seize_role(role, samdb, force=True)

        # Get all DCs and remove them (this ensures these DCs cannot
        # replicate because they will not have a password)
        search_expr = "(&(objectClass=Server)(serverReference=*))"
        res = samdb.search(samdb.get_config_basedn(), scope=ldb.SCOPE_SUBTREE,
                           expression=search_expr)
        for m in res:
            cn = str(m.get('cn')[0])
            # keep only the DC we just created
            if cn != newservername:
                remove_dc(samdb, logger, cn)

        # Remove the repsFrom and repsTo from each NC to ensure we do
        # not try (and fail) to talk to the old DCs
        for nc in ncs:
            msg = ldb.Message()
            msg.dn = ldb.Dn(samdb, nc)

            msg["repsFrom"] = ldb.MessageElement([],
                                                 ldb.FLAG_MOD_REPLACE,
                                                 "repsFrom")
            msg["repsTo"] = ldb.MessageElement([],
                                               ldb.FLAG_MOD_REPLACE,
                                               "repsTo")
            samdb.modify(msg)

        # Update the krbtgt passwords twice, ensuring no tickets from
        # the old domain are valid
        update_krbtgt_account_password(samdb)
        update_krbtgt_account_password(samdb)

        # restore the sysvol directory from the backup tar file, including the
        # original NTACLs. Note that the backup_restore() will fail if not root
        sysvol_tar = os.path.join(targetdir, 'sysvol.tar.gz')
        dest_sysvol_dir = lp.get('path', 'sysvol')
        if not os.path.exists(dest_sysvol_dir):
            os.makedirs(dest_sysvol_dir)
        backup_restore(sysvol_tar, dest_sysvol_dir, samdb, smbconf)
        os.remove(sysvol_tar)

        # fix up any stale links to the old DCs we just removed
        logger.info("Fixing up any remaining references to the old DCs...")
        self.fix_old_dc_references(samdb)

        # Remove DB markers added by the backup process
        self.remove_backup_markers(samdb)

        logger.info("Backup file successfully restored to %s" % targetdir)
        logger.info("Please check the smb.conf settings are correct before "
                    "starting samba.")
Example #40
0
    def run(self, sambaopts=None, targetdir=None):
        """Create an offline backup of the local DC in targetdir.

        Collects every file under the private dir, the state dir and the
        directory holding smb.conf. It skips sysvol, ``.sock`` files and
        directories, and files with ``self.backup_ext``; those stale backup
        files are deleted. It then:

        - takes copies of the secrets/sam databases and of any other
          ``.ldb``/``.tdb`` file that was not already copied;
        - tags the copied sam.ldb with backupDate, sidForRestore and
          backupType=offline;
        - writes everything, plus an offline NTACL backup of sysvol and a
          backup.txt log, to ``targetdir/samba-backup-<timestamp>.tar.bz2``.

        :param sambaopts: Samba options supplying the loadparm context.
        :param targetdir: directory the backup tarball is written to.
        :raises CommandError: if no sam.ldb exists at the configured path.
        """

        logger = logging.getLogger()
        logger.setLevel(logging.DEBUG)
        logger.addHandler(logging.StreamHandler(sys.stdout))

        # Get the absolute paths of all the directories we're going to backup
        lp = sambaopts.get_loadparm()

        paths = samba.provision.provision_paths_from_lp(lp, lp.get('realm'))
        if not (paths.samdb and os.path.exists(paths.samdb)):
            logger.error("No database found at {0}".format(paths.samdb))
            raise CommandError('Please check you are root, and ' +
                               'are running this command on an AD DC')

        check_targetdir(logger, targetdir)

        samdb = SamDB(url=paths.samdb, session_info=system_session(), lp=lp)
        sid = get_sid_for_restore(samdb, logger)

        backup_dirs = [paths.private_dir, paths.state_dir,
                       os.path.dirname(paths.smbconf)]  # etc dir
        logger.info('running backup on dirs: {0}'.format(' '.join(backup_dirs)))

        # Recursively get all file paths in the backup directories.
        # The backup dirs can overlap (e.g. the private dir may live inside
        # the state dir), so track every path already seen; previously the
        # check compared a bare filename against full paths and never
        # matched, so overlapping files were archived twice.
        all_files = []
        seen_paths = set()
        for backup_dir in backup_dirs:
            for (working_dir, _, filenames) in os.walk(backup_dir):
                if working_dir.startswith(paths.sysvol):
                    continue
                if working_dir.endswith('.sock') or '.sock/' in working_dir:
                    continue

                for filename in filenames:
                    full_path = os.path.join(working_dir, filename)
                    path_key = os.path.normpath(full_path)
                    if path_key in seen_paths:
                        continue
                    seen_paths.add(path_key)

                    # Assume existing backup files are from a previous backup.
                    # Delete and ignore.
                    if filename.endswith(self.backup_ext):
                        os.remove(full_path)
                        continue

                    # Sock files are autogenerated at runtime, ignore.
                    if filename.endswith('.sock'):
                        continue

                    all_files.append(full_path)

        # Backup secrets, sam.ldb and their downstream files
        self.backup_secrets(paths.private_dir, lp, logger)
        self.backup_smb_dbs(paths.private_dir, samdb, lp, logger)

        # Open the new backed up samdb, flag it as backed up, and write
        # the next SID so the restore tool can add objects.
        # WARNING: Don't change this code unless you know what you're doing.
        #          Writing to a .bak file only works because the DN being
        #          written to happens to be top level.
        samdb = SamDB(url=paths.samdb + self.backup_ext,
                      session_info=system_session(), lp=lp)
        time_str = get_timestamp()
        add_backup_marker(samdb, "backupDate", time_str)
        add_backup_marker(samdb, "sidForRestore", sid)
        add_backup_marker(samdb, "backupType", "offline")

        # Now handle all the LDB and TDB files that are not linked to
        # anything else.  Use transactions for LDBs.
        for path in all_files:
            if not os.path.exists(path + self.backup_ext):
                if path.endswith('.ldb'):
                    logger.info('Starting transaction on solo db: ' + path)
                    ldb_obj = Ldb(path, lp=lp)
                    ldb_obj.transaction_start()
                    try:
                        logger.info('   running tdbbackup on the same file')
                        self.offline_tdb_copy(path)
                    finally:
                        # The transaction is only held while copying; nothing
                        # is modified through ldb_obj, so always cancel it,
                        # even if the copy fails.
                        ldb_obj.transaction_cancel()
                elif path.endswith('.tdb'):
                    logger.info('running tdbbackup on lone tdb file ' + path)
                    self.offline_tdb_copy(path)

        # Now make the backup tar file and add all
        # backed up files and any other files to it.
        # The temp dir keeps its INCOMPLETE prefix until the tarball has been
        # moved into place, so a failed run leaves an identifiable leftover.
        temp_tar_dir = tempfile.mkdtemp(dir=targetdir,
                                        prefix='INCOMPLETEsambabackupfile')
        temp_tar_name = os.path.join(temp_tar_dir, "samba-backup.tar.bz2")
        with tarfile.open(temp_tar_name, 'w:bz2') as tar:
            logger.info('running offline ntacl backup of sysvol')
            sysvol_tar_fn = 'sysvol.tar.gz'
            sysvol_tar = os.path.join(temp_tar_dir, sysvol_tar_fn)
            backup_offline(paths.sysvol, sysvol_tar, samdb, paths.smbconf)
            tar.add(sysvol_tar, sysvol_tar_fn)
            os.remove(sysvol_tar)

            create_log_file(temp_tar_dir, lp, "offline", "localhost", True)
            backup_fn = os.path.join(temp_tar_dir, "backup.txt")
            tar.add(backup_fn, os.path.basename(backup_fn))
            os.remove(backup_fn)

            logger.info('building backup tar')
            for path in all_files:
                arc_path = self.get_arc_path(path, paths)

                if os.path.exists(path + self.backup_ext):
                    # archive the consistent copy under the original name
                    logger.info('   adding backup ' + arc_path +
                                self.backup_ext + ' to tar and deleting file')
                    tar.add(path + self.backup_ext, arcname=arc_path)
                    os.remove(path + self.backup_ext)
                elif path.endswith('.ldb') or path.endswith('.tdb'):
                    # never archive a live database without a backup copy
                    logger.info('   skipping ' + arc_path)
                else:
                    logger.info('   adding misc file ' + arc_path)
                    tar.add(path, arcname=arc_path)

        os.rename(temp_tar_name,
                  os.path.join(targetdir,
                               'samba-backup-{0}.tar.bz2'.format(time_str)))
        os.rmdir(temp_tar_dir)
        logger.info('Backup succeeded.')
Example #41
0
    def run(self, sambaopts=None, credopts=None, versionopts=None, H=None):
        """Print direct group-membership statistics for the domain.

        Searches every ``objectClass=group`` object under the domain DN and
        prints:

        - the total number of groups and memberships;
        - the average, maximum and median members per group;
        - a table of how many groups fall into each membership-size band.

        Only each group's own ``member`` values are counted (no nesting).
        If the search returns no groups, only the totals are printed.
        """
        lp = sambaopts.get_loadparm()
        creds = credopts.get_credentials(lp, fallback_machine=True)

        samdb = SamDB(url=H,
                      session_info=system_session(),
                      credentials=creds,
                      lp=lp)

        domain_dn = samdb.domain_dn()
        res = samdb.search(domain_dn,
                           scope=ldb.SCOPE_SUBTREE,
                           expression=("(objectClass=group)"),
                           attrs=["samaccountname", "member"])

        # first count up how many members each group has
        group_assignments = {}
        total_memberships = 0

        for msg in res:
            name = str(msg.get("samaccountname"))
            num_members = len(msg.get("member", default=[]))
            group_assignments[name] = num_members
            total_memberships += num_members

        num_groups = len(res)
        self.outf.write("Group membership statistics*\n")
        self.outf.write("-------------------------------------------------\n")
        self.outf.write("Total groups: {0}\n".format(num_groups))
        self.outf.write("Total memberships: {0}\n".format(total_memberships))

        if not group_assignments:
            # Nothing to average or rank: avoid a division by zero and max()
            # of an empty sequence, which previously crashed the command.
            self.outf.write(
                "\n* Note this does not include nested group memberships\n")
            return

        average = total_memberships / float(num_groups)
        self.outf.write("Average members per group: %.2f\n" % average)

        # find the max and median memberships (note that some default groups
        # always have zero members, so displaying the min is not very helpful)
        # max() returns the first group with the largest count.
        max_group, max_members = max(group_assignments.items(),
                                     key=lambda item: item[1])
        self.outf.write("Max members: {0} ({1})\n".format(
            max_members, max_group))
        group_members = sorted(group_assignments.values())
        member_count = len(group_members)
        midpoint = member_count // 2
        median = group_members[midpoint]
        if member_count % 2 == 0:
            median = (median + group_members[midpoint - 1]) / 2
        self.outf.write("Median members per group: {0}\n\n".format(median))

        # convert this to the frequency of group membership, i.e. how many
        # groups have 5 members, how many have 6 members, etc
        group_freqs = defaultdict(int)
        for num_members in group_assignments.values():
            group_freqs[num_members] += 1

        # now squash this down even further, so that we just display the number
        # of groups that fall into one of the following membership bands
        bands = [(0, 1), (2, 4), (5, 9), (10, 14), (15, 19), (20, 24),
                 (25, 29), (30, 39), (40, 49), (50, 59), (60, 69), (70, 79),
                 (80, 89), (90, 99), (100, 149), (150, 199), (200, 249),
                 (250, 299), (300, 399), (400, 499), (500, 999), (1000, 1999),
                 (2000, 2999), (3000, 3999), (4000, 4999), (5000, 9999),
                 (10000, max_members)]

        self.outf.write("Members        Number of Groups\n")
        self.outf.write("-------------------------------------------------\n")

        for band_start, band_end in bands:
            # bands are ascending, so nothing past max_members can match
            if band_start > max_members:
                break

            groups_in_band = self.num_in_range(band_start, band_end,
                                               group_freqs)

            if groups_in_band != 0:
                band_str = "{0}-{1}".format(band_start, band_end)
                self.outf.write("%13s  %u\n" % (band_str, groups_in_band))

        self.outf.write(
            "\n* Note this does not include nested group memberships\n")
Example #42
0
    def setUp(self):
        """Populate ``ou=sort`` with test users and precompute sort orders.

        Creates ``opts.elements`` users via create_user() and fills
        ``self.expected_results`` / ``self.expected_results_binary`` with
        ``(forward, reverse)`` value lists per attribute, for comparison
        against server-side sorted search results.
        """
        # cmp_to_key works on Python 2.7 and 3; sorted(cmp=...) is 2-only.
        from functools import cmp_to_key

        super(BaseSortTests, self).setUp()
        self.ldb = SamDB(host,
                         credentials=creds,
                         session_info=system_session(lp),
                         lp=lp)

        self.base_dn = self.ldb.domain_dn()
        self.ou = "ou=sort,%s" % self.base_dn
        # Disabled: optional removal of a leftover OU from an earlier run.
        if False:
            try:
                self.ldb.delete(self.ou, ['tree_delete:1'])
            except ldb.LdbError as e:
                print("tried deleting %s, got error %s" % (self.ou, e))

        self.ldb.add({"dn": self.ou, "objectclass": "organizationalUnit"})
        self.users = []
        n = opts.elements
        for i in range(n):
            self.create_user(i, n)

        attrs = set(self.users[0].keys()) - set(['objectclass', 'dn'])
        self.binary_sorted_keys = attrs.intersection([
            'audio', 'photo', "msTSExpireDate4", 'serialNumber',
            "displayNamePrintable"
        ])

        self.numeric_sorted_keys = attrs.intersection(
            ['flags', 'accountExpires'])

        self.timestamp_keys = attrs.intersection(['msTSExpireDate4'])

        self.int64_keys = set(['accountExpires'])

        self.locale_sorted_keys = [
            x for x in attrs
            if x not in (self.binary_sorted_keys | self.numeric_sorted_keys)
        ]

        self.expected_results = {}
        self.expected_results_binary = {}

        for k in self.locale_sorted_keys:
            # Using key=locale.strxfrm fails on \x00, so compare with
            # locale.strcoll instead.
            forward = sorted((norm(x[k]) for x in self.users),
                             key=cmp_to_key(locale.strcoll))
            reverse = list(reversed(forward))
            self.expected_results[k] = (forward, reverse)

        for k in self.binary_sorted_keys:
            forward = sorted((x[k] for x in self.users))
            reverse = list(reversed(forward))
            self.expected_results_binary[k] = (forward, reverse)
            self.expected_results[k] = (forward, reverse)

        # Fix up some because Python gets it wrong, using Schwartzian transform
        # (sort on the value with NUL bytes removed, ties broken by the
        # original value).
        for k in ('adminDisplayName', 'title', 'streetAddress',
                  'employeeNumber'):
            if k in self.expected_results:
                broken = self.expected_results[k][0]
                tmp = [(x.replace('\x00', ''), x) for x in broken]
                tmp.sort()
                fixed = [x[1] for x in tmp]
                self.expected_results[k] = (fixed, list(reversed(fixed)))
        # These attributes take their values from FIENDISH_TESTS, whose list
        # order is used as the expected sort order; each value is repeated
        # once per user that has it.
        for k in ('streetAddress', 'postalAddress'):
            if k in self.expected_results:
                counts = {}
                for u in self.users:
                    counts[u[k]] = counts.get(u[k], 0) + 1
                fixed = []
                for x in FIENDISH_TESTS:
                    fixed += [norm(x)] * counts.get(x, 0)

                rev = list(reversed(fixed))
                self.expected_results[k] = (fixed, rev)
Example #43
0
class BaseSortTests(samba.tests.TestCase):
    """Tests for the LDAP ``server_sort`` control.

    setUp creates ``opts.elements`` users under ``ou=sort,<domain DN>``
    whose attribute values stress locale, binary and numeric ordering. It
    precomputes the expected forward/reverse orders, and tearDown removes
    the OU. The ``_test_*`` methods lack the ``test`` prefix, so unittest
    does not collect them here; presumably subclasses expose them (confirm).
    """
    # When True, create_user replaces characters outside [\w,.] with 'X'
    # and does not add the NUL/Unicode-heavy attributes.
    avoid_tricky_sort = False
    maxDiff = 2000

    def create_user(self,
                    i,
                    n,
                    prefix='sorttest',
                    suffix='',
                    attrs=None,
                    tricky=False):
        """Build user number ``i`` of ``n``, add it to the DB and return it.

        The user dict is also appended to ``self.users``. ``attrs``, when
        given, overrides/extends the generated attributes. ``tricky`` is
        accepted but not used by this implementation.
        """
        name = "%s%d%s" % (prefix, i, suffix)
        user = {
            'cn':
            name,
            "objectclass":
            "user",
            'givenName':
            "abcdefghijklmnopqrstuvwxyz"[i % 26],
            "roomNumber":
            "%sb\x00c" % (n - i),
            "carLicense":
            "后来经",
            "employeeNumber":
            "%s%sx" % (abs(i * (99 - i)), '\n' * (i & 255)),
            "accountExpires":
            "%s" % (10**9 + 1000000 * i),
            "msTSExpireDate4":
            "19%02d0101010000.0Z" % (i % 100),
            "flags":
            str(i * (n - i)),
            "serialNumber":
            "abc %s%s%s" % (
                'AaBb |-/'[i & 7],
                ' 3z}'[i & 3],
                '"@'[i & 1],
            ),
            "comment":
            "Favourite colour is %d" % (n % (i + 1)),
        }

        if self.avoid_tricky_sort:
            # We are not even going to try passing tests that assume
            # some kind of Unicode awareness.
            for k, v in user.items():
                user[k] = re.sub(r'[^\w,.]', 'X', v)
        else:
            # Add some even trickier ones!
            fiendish_index = i % len(FIENDISH_TESTS)
            user.update({
                # Sort doesn't look past a NUL byte.
                "photo":
                "\x00%d" % (n - i),
                "audio":
                "%sn octet string %s%s ♫♬\x00lalala" %
                ('Aa'[i & 1], chr(i & 255), i),
                "displayNamePrintable":
                "%d\x00%c" % (i, i & 255),
                "adminDisplayName":
                "%d\x00b" % (n - i),
                "title":
                "%d%sb" % (n - i, '\x00' * i),

                # Names that vary only in case. Windows returns
                # equivalent addresses in the order they were put
                # in ('a st', 'A st',...). We don't check that.
                "street":
                "%s st" % (chr(65 | (i & 14) | ((i & 1) * 32))),
                "streetAddress":
                FIENDISH_TESTS[fiendish_index],
                "postalAddress":
                FIENDISH_TESTS[-fiendish_index],
            })

        if attrs is not None:
            user.update(attrs)

        user['dn'] = "cn=%s,%s" % (user['cn'], self.ou)

        self.users.append(user)
        self.ldb.add(user)
        return user

    def setUp(self):
        """Populate ``ou=sort`` with test users and precompute sort orders.

        Fills ``self.expected_results`` / ``self.expected_results_binary``
        with ``(forward, reverse)`` value lists per attribute.
        """
        # cmp_to_key works on Python 2.7 and 3; sorted(cmp=...) is 2-only.
        from functools import cmp_to_key

        super(BaseSortTests, self).setUp()
        self.ldb = SamDB(host,
                         credentials=creds,
                         session_info=system_session(lp),
                         lp=lp)

        self.base_dn = self.ldb.domain_dn()
        self.ou = "ou=sort,%s" % self.base_dn
        # Disabled: optional removal of a leftover OU from an earlier run.
        if False:
            try:
                self.ldb.delete(self.ou, ['tree_delete:1'])
            except ldb.LdbError as e:
                print("tried deleting %s, got error %s" % (self.ou, e))

        self.ldb.add({"dn": self.ou, "objectclass": "organizationalUnit"})
        self.users = []
        n = opts.elements
        for i in range(n):
            self.create_user(i, n)

        attrs = set(self.users[0].keys()) - set(['objectclass', 'dn'])
        self.binary_sorted_keys = attrs.intersection([
            'audio', 'photo', "msTSExpireDate4", 'serialNumber',
            "displayNamePrintable"
        ])

        self.numeric_sorted_keys = attrs.intersection(
            ['flags', 'accountExpires'])

        self.timestamp_keys = attrs.intersection(['msTSExpireDate4'])

        self.int64_keys = set(['accountExpires'])

        self.locale_sorted_keys = [
            x for x in attrs
            if x not in (self.binary_sorted_keys | self.numeric_sorted_keys)
        ]

        self.expected_results = {}
        self.expected_results_binary = {}

        for k in self.locale_sorted_keys:
            # Using key=locale.strxfrm fails on \x00, so compare with
            # locale.strcoll instead.
            forward = sorted((norm(x[k]) for x in self.users),
                             key=cmp_to_key(locale.strcoll))
            reverse = list(reversed(forward))
            self.expected_results[k] = (forward, reverse)

        for k in self.binary_sorted_keys:
            forward = sorted((x[k] for x in self.users))
            reverse = list(reversed(forward))
            self.expected_results_binary[k] = (forward, reverse)
            self.expected_results[k] = (forward, reverse)

        # Fix up some because Python gets it wrong, using Schwartzian transform
        # (sort on the value with NUL bytes removed, ties broken by the
        # original value).
        for k in ('adminDisplayName', 'title', 'streetAddress',
                  'employeeNumber'):
            if k in self.expected_results:
                broken = self.expected_results[k][0]
                tmp = [(x.replace('\x00', ''), x) for x in broken]
                tmp.sort()
                fixed = [x[1] for x in tmp]
                self.expected_results[k] = (fixed, list(reversed(fixed)))
        # These attributes take their values from FIENDISH_TESTS, whose list
        # order is used as the expected sort order; each value is repeated
        # once per user that has it.
        for k in ('streetAddress', 'postalAddress'):
            if k in self.expected_results:
                counts = {}
                for u in self.users:
                    counts[u[k]] = counts.get(u[k], 0) + 1
                fixed = []
                for x in FIENDISH_TESTS:
                    fixed += [norm(x)] * counts.get(x, 0)

                rev = list(reversed(fixed))
                self.expected_results[k] = (fixed, rev)

    def tearDown(self):
        """Remove the test OU and everything under it."""
        super(BaseSortTests, self).tearDown()
        self.ldb.delete(self.ou, ['tree_delete:1'])

    def _test_server_sort_default(self):
        """Sort each locale-sorted attribute with the default rule."""
        attrs = self.locale_sorted_keys

        for attr in attrs:
            for rev in (0, 1):
                res = self.ldb.search(
                    self.ou,
                    scope=ldb.SCOPE_ONELEVEL,
                    attrs=[attr],
                    controls=["server_sort:1:%d:%s" % (rev, attr)])
                self.assertEqual(len(res), len(self.users))

                expected_order = self.expected_results[attr][rev]
                received_order = [norm(x[attr][0]) for x in res]
                if expected_order != received_order:
                    print(attr, ['forward', 'reverse'][rev])
                    print("expected", expected_order)
                    print("recieved", received_order)
                    print("unnormalised:", [x[attr][0] for x in res])
                    print("unnormalised: «%s»" % '»  «'.join(x[attr][0]
                                                             for x in res))
                self.assertEqual(expected_order, received_order)

    def _test_server_sort_binary(self):
        """Sort each binary-sorted attribute and compare raw values."""
        for attr in self.binary_sorted_keys:
            for rev in (0, 1):
                res = self.ldb.search(
                    self.ou,
                    scope=ldb.SCOPE_ONELEVEL,
                    attrs=[attr],
                    controls=["server_sort:1:%d:%s" % (rev, attr)])

                self.assertEqual(len(res), len(self.users))
                expected_order = self.expected_results_binary[attr][rev]
                received_order = [x[attr][0] for x in res]
                if expected_order != received_order:
                    print(attr)
                    print(expected_order)
                    print(received_order)
                self.assertEqual(expected_order, received_order)

    def _test_server_sort_us_english(self):
        """Sort locale-sorted attributes with an explicit en_US rule OID."""
        # Windows doesn't support many matching rules, but does allow
        # the locale specific sorts -- if it has the locale installed.
        # The most reliable locale is the default US English, which
        # won't change the sort order.

        for lang, oid in [
            ('en_US', '1.2.840.113556.1.4.1499'),
        ]:

            for attr in self.locale_sorted_keys:
                for rev in (0, 1):
                    res = self.ldb.search(
                        self.ou,
                        scope=ldb.SCOPE_ONELEVEL,
                        attrs=[attr],
                        controls=["server_sort:1:%d:%s:%s" % (rev, attr, oid)])

                    self.assertEqual(len(res), len(self.users))
                    expected_order = self.expected_results[attr][rev]
                    received_order = [norm(x[attr][0]) for x in res]
                    if expected_order != received_order:
                        print(attr, lang)
                        print(['forward', 'reverse'][rev])
                        print("expected: ", expected_order)
                        print("recieved: ", received_order)
                        print("unnormalised:", [x[attr][0] for x in res])
                        print("unnormalised: «%s»" % '»  «'.join(x[attr][0]
                                                                 for x in res))

                    self.assertEqual(expected_order, received_order)

    def _test_server_sort_different_attr(self):
        """Sort by one attribute while returning only a different one."""
        # cmp_to_key works on Python 2.7 and 3; sorted(cmp=...) is 2-only.
        from functools import cmp_to_key

        def cmp_locale(a, b):
            return locale.strcoll(a[0], b[0])

        def cmp_binary(a, b):
            return cmp_fn(a[0], b[0])

        def cmp_numeric(a, b):
            return cmp_fn(int(a[0]), int(b[0]))

        # For testing simplicity, the attributes in here need to be
        # unique for each user. Otherwise there are multiple possible
        # valid answers.
        sort_functions = {
            'cn': cmp_binary,
            "employeeNumber": cmp_locale,
            "accountExpires": cmp_numeric,
            "msTSExpireDate4": cmp_binary
        }
        # Pair each attribute with the next one (wrapping round); a list is
        # needed because dict key views cannot be sliced on Python 3.
        attrs = list(sort_functions)
        attr_pairs = zip(attrs, attrs[1:] + attrs[:1])

        for sort_attr, result_attr in attr_pairs:
            forward = sorted(((norm(x[sort_attr]), norm(x[result_attr]))
                              for x in self.users),
                             key=cmp_to_key(sort_functions[sort_attr]))
            reverse = list(reversed(forward))

            for rev in (0, 1):
                res = self.ldb.search(
                    self.ou,
                    scope=ldb.SCOPE_ONELEVEL,
                    attrs=[result_attr],
                    controls=["server_sort:1:%d:%s" % (rev, sort_attr)])
                self.assertEqual(len(res), len(self.users))
                pairs = (forward, reverse)[rev]

                expected_order = [x[1] for x in pairs]
                received_order = [norm(x[result_attr][0]) for x in res]

                if expected_order != received_order:
                    print(sort_attr, result_attr, ['forward', 'reverse'][rev])
                    print("expected", expected_order)
                    print("recieved", received_order)
                    print("unnormalised:", [x[result_attr][0] for x in res])
                    print("unnormalised: «%s»" % '»  «'.join(x[result_attr][0]
                                                             for x in res))
                    print("pairs:", pairs)
                    # There are bugs in Windows that we don't want (or
                    # know how) to replicate regarding timestamp sorting.
                    # Let's remind ourselves.
                    if result_attr == "msTSExpireDate4":
                        print('-' * 72)
                        print("This test fails against Windows with the "
                              "default number of elements (33).")
                        print("Try with --elements=27 (or similar).")
                        print('-' * 72)

                self.assertEqual(expected_order, received_order)
                # the sort attribute was not requested, so it must not
                # appear in the results
                for x in res:
                    if sort_attr in x:
                        self.fail('the search for %s should not return %s' %
                                  (result_attr, sort_attr))
Example #44
0
    def run(self,
            sambaopts=None,
            credopts=None,
            versionopts=None,
            H=None,
            verbose=False,
            base_dn=None,
            full_dn=False):
        """List the groups found below the domain DN (or ``base_dn``).

        Without ``verbose`` one line is written per group: its
        sAMAccountName, or the object's DN when ``full_dn`` is set.  With
        ``verbose`` a table is written instead, holding the name, the group
        type/scope decoded from groupType and the number of ``member``
        values; ``full_dn`` has no effect in that mode.  Nothing is written
        when the search finds no groups.
        """
        loadparm = sambaopts.get_loadparm()
        credentials = credopts.get_credentials(loadparm, fallback_machine=True)

        samdb = SamDB(url=H,
                      session_info=system_session(),
                      credentials=credentials,
                      lp=loadparm)

        wanted_attrs = ["samaccountname"]
        if verbose:
            wanted_attrs += ["grouptype", "member"]

        search_base = samdb.domain_dn()
        if base_dn:
            search_base = samdb.normalize_dn_in_domain(base_dn)

        groups = samdb.search(search_base,
                              scope=ldb.SCOPE_SUBTREE,
                              expression=("(objectClass=group)"),
                              attrs=wanted_attrs)
        if len(groups) == 0:
            return

        if not verbose:
            for msg in groups:
                if full_dn:
                    shown = msg.get("dn")
                else:
                    shown = msg.get("samaccountname", idx=0)
                self.outf.write("%s\n" % shown)
            return

        self.outf.write(
            "Group Name                                  Group Type      Group Scope  Members\n"
        )
        self.outf.write(
            "--------------------------------------------------------------------------------\n"
        )

        # (lookup dict, key, label), tried in this order; first match wins.
        type_labels = (
            (security_group, "Builtin", "Security         Builtin  "),
            (security_group, "Domain", "Security         Domain   "),
            (security_group, "Global", "Security         Global   "),
            (security_group, "Universal", "Security         Universal"),
            (distribution_group, "Global", "Distribution     Global   "),
            (distribution_group, "Domain", "Distribution     Domain   "),
            (distribution_group, "Universal", "Distribution     Universal"),
        )

        for msg in groups:
            self.outf.write("%-44s" % msg.get("samaccountname", idx=0))
            # Keep only the low 32 bits, so a negative groupType is compared
            # in its unsigned 32-bit form.
            masked_type = hex(int("%s" % msg["grouptype"]) & 0x00000000FFFFFFFF)
            for table, key, label in type_labels:
                if masked_type == hex(int(table.get(key))):
                    self.outf.write(label)
                    break
            else:
                self.outf.write("                          ")
            member_count = len(msg.get("member", default=[]))
            self.outf.write("    %6u\n" % member_count)
    def setUp(self):
        """Create the test OU and its users, and sanity-check the schema.

        Four users are created under OU=conf-attr-test; three of them get a
        homePostalAddress value.  Finally the attribute's searchFlags are
        checked to make sure the confidential bit is not already set.

        unittest does not run tearDown() when setUp() raises, so if anything
        fails after the OU has been created, the OU is removed here before
        the error is re-raised.  Otherwise the leftover OU would make
        create_ou() fail on every later run.
        """
        super(ConfidentialAttrCommon, self).setUp()

        self.ldb_admin = SamDB(ldaphost,
                               credentials=creds,
                               session_info=system_session(lp),
                               lp=lp)
        self.user_pass = "******"
        self.base_dn = self.ldb_admin.domain_dn()
        self.schema_dn = self.ldb_admin.get_schema_basedn()
        self.sd_utils = sd_utils.SDUtils(self.ldb_admin)

        # the tests work by setting the 'Confidential' bit in the searchFlags
        # for an existing schema attribute. This only works against Windows if
        # the systemFlags does not have FLAG_SCHEMA_BASE_OBJECT set for the
        # schema attribute being modified. There are only a few attributes that
        # meet this criteria (most of which only apply to 'user' objects)
        self.conf_attr = "homePostalAddress"
        attr_cn = "CN=Address-Home"
        # schemaIdGuid for homePostalAddress (used for ACE tests)
        self.conf_attr_guid = "16775781-47f3-11d1-a9c3-0000f80367c1"
        self.conf_attr_sec_guid = "77b5b886-944a-11d1-aebd-0000f80367c1"
        self.attr_dn = "{0},{1}".format(attr_cn, self.schema_dn)

        userou = "OU=conf-attr-test"
        self.ou = "{0},{1}".format(userou, self.base_dn)
        self.ldb_admin.create_ou(self.ou)

        try:
            # use a common username prefix, so we can use sAMAccountName=CATC-*
            # as a search filter to only return the users we're interested in
            self.user_prefix = "catc-"

            # add a test object with this attribute set
            self.conf_value = "abcdef"
            self.conf_user = "******".format(self.user_prefix)
            self.ldb_admin.newuser(self.conf_user, self.user_pass,
                                   userou=userou)
            self.conf_dn = self.get_user_dn(self.conf_user)
            self.add_attr(self.conf_dn, self.conf_attr, self.conf_value)

            # add a sneaky user that will try to steal our secrets
            self.user = "******".format(self.user_prefix)
            self.ldb_admin.newuser(self.user, self.user_pass, userou=userou)
            self.ldb_user = self.get_ldb_connection(self.user, self.user_pass)

            self.all_users = [self.user, self.conf_user]

            # add some other users that also have confidential attributes, so
            # we check we don't disclose their details, particularly in '!'
            # searches
            for i in range(1, 3):
                username = "******".format(self.user_prefix, i)
                self.ldb_admin.newuser(username, self.user_pass, userou=userou)
                userdn = self.get_user_dn(username)
                self.add_attr(userdn, self.conf_attr, "xyz{0}".format(i))
                self.all_users.append(username)

            # there are 4 users in the OU, plus the OU itself
            self.test_dn = self.ou
            self.total_objects = len(self.all_users) + 1
            self.objects_with_attr = 3

            # sanity-check the flag is not already set (this'll cause problems
            # if previous test run didn't clean up properly)
            search_flags = self.get_attr_search_flags(self.attr_dn)
            self.assertTrue(
                (int(search_flags) & SEARCH_FLAG_CONFIDENTIAL) == 0,
                "{0} searchFlags already {1}".format(self.conf_attr,
                                                     search_flags))
        except Exception:
            # tearDown() will not run for a failed setUp(); don't leak the OU
            self.ldb_admin.delete(self.ou, ["tree_delete:1"])
            raise
Example #46
0
    def test_db_lock2(self):
        """A transaction may start while another process holds a search
        iterator open, but prepare_commit() must wait for that iterator.

        The forked child opens a search iterator and keeps it open for about
        2 seconds after the parent announces it is about to prepare.  The
        parent asserts that its transaction_prepare_commit() blocked for
        most of that time.

        Pipes: the parent writes to the child on (r1, w1); the child writes
        to the parent on (r2, w2).
        """
        basedn = self.samdb.get_default_basedn()
        (r1, w1) = os.pipe()
        (r2, w2) = os.pipe()

        pid = os.fork()
        if pid == 0:
            # The child must never return into the test runner, so any
            # unexpected exception ends the process with a distinct code.
            try:
                # Drop the pipe ends the child does not use.  Then, if the
                # child dies, the parent's reads see EOF instead of hanging.
                os.close(w1)
                os.close(r2)

                # In the child, close the main DB, re-open
                del self.samdb
                gc.collect()
                self.samdb = SamDB(session_info=self.session, lp=self.lp)

                # We need to hold this iterator open to hold the all-record
                # lock.
                res = self.samdb.search_iterator()

                os.write(w2, b"start")
                if (os.read(r1, 7) != b"started"):
                    os._exit(1)

                os.write(w2, b"add")
                if (os.read(r1, 5) != b"added"):
                    os._exit(2)

                # Wait 2 seconds to block prepare_commit() in the parent.
                os.write(w2, b"prepare")
                time.sleep(2)

                # Release the locks
                for _ in res:
                    pass

                if (os.read(r1, 8) != b"prepared"):
                    os._exit(3)

                os._exit(0)
            except BaseException:
                os._exit(4)

        # Parent: keep only w1 (to child) and r2 (from child).
        os.close(r1)
        os.close(w2)
        self.addCleanup(os.close, w1)
        self.addCleanup(os.close, r2)

        # We can start the transaction during the search
        # because both just grab the all-record read lock.
        self.assertEqual(os.read(r2, 5), b"start")
        self.samdb.transaction_start()
        os.write(w1, b"started")

        self.assertEqual(os.read(r2, 3), b"add")
        dn = "cn=test_db_lock_user,cn=users," + str(basedn)
        self.samdb.add({
            "dn": dn,
            "objectclass": "user",
        })
        self.samdb.delete(dn)
        os.write(w1, b"added")

        # Obtain a write lock, this will block until
        # the child releases the read lock.
        self.assertEqual(os.read(r2, 7), b"prepare")
        start = time.time()
        self.samdb.transaction_prepare_commit()
        end = time.time()
        try:
            self.assertGreater(end - start, 1.9)
        finally:
            os.write(w1, b"prepared")

            # Drop the write lock
            self.samdb.transaction_cancel()

            (got_pid, status) = os.waitpid(pid, 0)
            self.assertEqual(got_pid, pid)
            self.assertTrue(os.WIFEXITED(status))
            self.assertEqual(os.WEXITSTATUS(status), 0)
class ConfidentialAttrCommon(samba.tests.TestCase):
    """Shared fixture and assertion helpers for confidential-attribute tests.

    setUp() populates an OU with users, three of which have the attribute
    under test (homePostalAddress) set.  Tests can mark that attribute
    confidential with make_attr_confidential(), then use the assert_*
    helpers to check which objects and attribute values a given connection
    can see or match in searches.
    """

    def setUp(self):
        """Create the test OU and its users, and sanity-check the schema.

        unittest does not run tearDown() when setUp() raises, so if anything
        fails after the OU has been created, the OU is removed here before
        the error is re-raised.  Otherwise the leftover OU would make
        create_ou() fail on every later run.
        """
        super(ConfidentialAttrCommon, self).setUp()

        self.ldb_admin = SamDB(ldaphost,
                               credentials=creds,
                               session_info=system_session(lp),
                               lp=lp)
        self.user_pass = "******"
        self.base_dn = self.ldb_admin.domain_dn()
        self.schema_dn = self.ldb_admin.get_schema_basedn()
        self.sd_utils = sd_utils.SDUtils(self.ldb_admin)

        # the tests work by setting the 'Confidential' bit in the searchFlags
        # for an existing schema attribute. This only works against Windows if
        # the systemFlags does not have FLAG_SCHEMA_BASE_OBJECT set for the
        # schema attribute being modified. There are only a few attributes that
        # meet this criteria (most of which only apply to 'user' objects)
        self.conf_attr = "homePostalAddress"
        attr_cn = "CN=Address-Home"
        # schemaIdGuid for homePostalAddress (used for ACE tests)
        self.conf_attr_guid = "16775781-47f3-11d1-a9c3-0000f80367c1"
        self.conf_attr_sec_guid = "77b5b886-944a-11d1-aebd-0000f80367c1"
        self.attr_dn = "{0},{1}".format(attr_cn, self.schema_dn)

        userou = "OU=conf-attr-test"
        self.ou = "{0},{1}".format(userou, self.base_dn)
        self.ldb_admin.create_ou(self.ou)

        try:
            # use a common username prefix, so we can use sAMAccountName=CATC-*
            # as a search filter to only return the users we're interested in
            self.user_prefix = "catc-"

            # add a test object with this attribute set
            self.conf_value = "abcdef"
            self.conf_user = "******".format(self.user_prefix)
            self.ldb_admin.newuser(self.conf_user, self.user_pass,
                                   userou=userou)
            self.conf_dn = self.get_user_dn(self.conf_user)
            self.add_attr(self.conf_dn, self.conf_attr, self.conf_value)

            # add a sneaky user that will try to steal our secrets
            self.user = "******".format(self.user_prefix)
            self.ldb_admin.newuser(self.user, self.user_pass, userou=userou)
            self.ldb_user = self.get_ldb_connection(self.user, self.user_pass)

            self.all_users = [self.user, self.conf_user]

            # add some other users that also have confidential attributes, so
            # we check we don't disclose their details, particularly in '!'
            # searches
            for i in range(1, 3):
                username = "******".format(self.user_prefix, i)
                self.ldb_admin.newuser(username, self.user_pass, userou=userou)
                userdn = self.get_user_dn(username)
                self.add_attr(userdn, self.conf_attr, "xyz{0}".format(i))
                self.all_users.append(username)

            # there are 4 users in the OU, plus the OU itself
            self.test_dn = self.ou
            self.total_objects = len(self.all_users) + 1
            self.objects_with_attr = 3

            # sanity-check the flag is not already set (this'll cause problems
            # if previous test run didn't clean up properly)
            search_flags = self.get_attr_search_flags(self.attr_dn)
            self.assertTrue(
                (int(search_flags) & SEARCH_FLAG_CONFIDENTIAL) == 0,
                "{0} searchFlags already {1}".format(self.conf_attr,
                                                     search_flags))
        except Exception:
            # tearDown() will not run for a failed setUp(); don't leak the OU
            self.ldb_admin.delete(self.ou, ["tree_delete:1"])
            raise

    def tearDown(self):
        """Delete the test OU and everything below it."""
        super(ConfidentialAttrCommon, self).tearDown()
        self.ldb_admin.delete(self.ou, ["tree_delete:1"])

    def add_attr(self, dn, attr, value):
        """Add ``value`` to attribute ``attr`` of object ``dn`` (as admin)."""
        m = Message()
        m.dn = Dn(self.ldb_admin, dn)
        m[attr] = MessageElement(value, FLAG_MOD_ADD, attr)
        self.ldb_admin.modify(m)

    def set_attr_search_flags(self, attr_dn, flags):
        """Modifies the searchFlags for an object in the schema"""
        m = Message()
        m.dn = Dn(self.ldb_admin, attr_dn)
        m['searchFlags'] = MessageElement(flags, FLAG_MOD_REPLACE,
                                          'searchFlags')
        self.ldb_admin.modify(m)

        # note we have to update the schema for this change to take effect (on
        # Windows, at least)
        self.ldb_admin.set_schema_update_now()

    def get_attr_search_flags(self, attr_dn):
        """Return the (first) searchFlags value of schema object ``attr_dn``."""
        res = self.ldb_admin.search(attr_dn,
                                    scope=SCOPE_BASE,
                                    attrs=['searchFlags'])
        return res[0]['searchFlags'][0]

    def make_attr_confidential(self):
        """Marks the attribute under test as being confidential"""

        # work out the original 'searchFlags' value before we overwrite it
        old_value = self.get_attr_search_flags(self.attr_dn)

        self.set_attr_search_flags(self.attr_dn, str(SEARCH_FLAG_CONFIDENTIAL))

        # reset the value after the test completes
        self.addCleanup(self.set_attr_search_flags, self.attr_dn, old_value)

    # The behaviour of the DC can differ in some cases, depending on whether
    # we're talking to a Windows DC or a Samba DC
    def guess_dc_mode(self):
        """Return DC_MODE_RETURN_ALL or DC_MODE_RETURN_NONE for this DC.

        Under selftest (SAMBA_SELFTEST=1) the answer is DC_MODE_RETURN_NONE.
        Otherwise the first negative match-all search is run as the
        unprivileged user; if it returns every object, the DC is treated as
        DC_MODE_RETURN_ALL.
        """
        # if we're in selftest, we can be pretty sure it's a Samba DC
        if os.environ.get('SAMBA_SELFTEST') == '1':
            return DC_MODE_RETURN_NONE

        searches = self.get_negative_match_all_searches()
        res = self.ldb_user.search(self.test_dn,
                                   expression=searches[0],
                                   scope=SCOPE_SUBTREE)

        # we default to DC_MODE_RETURN_NONE (samba). Update this if it
        # looks like we're talking to a Windows DC
        if len(res) == self.total_objects:
            return DC_MODE_RETURN_ALL

        # otherwise assume samba DC behaviour
        return DC_MODE_RETURN_NONE

    def get_user_dn(self, name):
        """Return the DN of user ``name`` inside the test OU."""
        return "CN={0},{1}".format(name, self.ou)

    def get_user_sid_string(self, username):
        """Return the objectSid of ``username`` (in the test OU) as a string."""
        user_dn = self.get_user_dn(username)
        user_sid = self.sd_utils.get_object_sid(user_dn)
        return str(user_sid)

    def get_ldb_connection(self, target_username, target_password):
        """Open a SamDB connection to ldaphost as the given user.

        Domain, realm and workstation are copied from the global ``creds``;
        Kerberos is disabled and the SEAL gensec feature is requested.
        """
        creds_tmp = Credentials()
        creds_tmp.set_username(target_username)
        creds_tmp.set_password(target_password)
        creds_tmp.set_domain(creds.get_domain())
        creds_tmp.set_realm(creds.get_realm())
        creds_tmp.set_workstation(creds.get_workstation())
        features = creds_tmp.get_gensec_features() | gensec.FEATURE_SEAL
        creds_tmp.set_gensec_features(features)
        creds_tmp.set_kerberos_state(DONT_USE_KERBEROS)
        ldb_target = SamDB(url=ldaphost, credentials=creds_tmp, lp=lp)
        return ldb_target

    def assert_search_result(self, expected_num, expr, samdb):
        """Assert a subtree search for ``expr`` returns ``expected_num`` hits.

        The search is repeated with several attribute selections, and each
        must yield the same number of results.
        """

        # try asking for different attributes back: None/all, the confidential
        # attribute itself, and a random unrelated attribute
        attr_filters = [None, ["*"], [self.conf_attr], ['name']]
        for attr in attr_filters:
            res = samdb.search(self.test_dn,
                               expression=expr,
                               scope=SCOPE_SUBTREE,
                               attrs=attr)
            self.assertEqual(
                len(res), expected_num,
                "%u results, not %u for search %s, attr %s" %
                (len(res), expected_num, expr, str(attr)))

    # return a selection of searches that match exactly against the test object
    def get_exact_match_searches(self):
        """Return filters that match only the object holding conf_value."""
        first_char = self.conf_value[:1]
        last_char = self.conf_value[-1:]
        test_attr = self.conf_attr

        searches = [
            # search for the attribute using a sub-string wildcard
            # (which could reveal the attribute's actual value)
            "({0}={1}*)".format(test_attr, first_char),
            "({0}=*{1})".format(test_attr, last_char),

            # sanity-check equality against an exact match on value
            "({0}={1})".format(test_attr, self.conf_value),

            # '~=' searches don't work against Samba
            # sanity-check an approx search against an exact match on value
            # "({0}~={1})".format(test_attr, self.conf_value),

            # check wildcard in an AND search...
            "(&({0}={1}*)(objectclass=*))".format(test_attr, first_char),

            # ...an OR search (against another term that will never match)
            "(|({0}={1}*)(objectclass=banana))".format(test_attr, first_char)
        ]

        return searches

    # return searches that match any object with the attribute under test
    def get_match_all_searches(self):
        """Return filters matching every object that has conf_attr set."""
        searches = [
            # check a full wildcard against the confidential attribute
            # (which could reveal the attribute's presence/absence)
            "({0}=*)".format(self.conf_attr),

            # check wildcard in an AND search...
            "(&(objectclass=*)({0}=*))".format(self.conf_attr),

            # ...an OR search (against another term that will never match)
            "(|(objectclass=banana)({0}=*))".format(self.conf_attr),

            # check <=, and >= expressions that would normally find a match
            "({0}>=0)".format(self.conf_attr),
            "({0}<=ZZZZZZZZZZZ)".format(self.conf_attr)
        ]

        return searches

    def assert_conf_attr_searches(self, has_rights_to=0, samdb=None):
        """Check searches against the attribute under test work as expected

        ``has_rights_to`` is the number of attribute-holding objects the
        connection may see, or "all".  ``samdb`` defaults to the
        unprivileged user's connection.
        """

        if samdb is None:
            samdb = self.ldb_user

        if has_rights_to == "all":
            has_rights_to = self.objects_with_attr

        # these first few searches we just expect to match against the one
        # object under test that we're trying to guess the value of
        expected_num = 1 if has_rights_to > 0 else 0
        for search in self.get_exact_match_searches():
            self.assert_search_result(expected_num, search, samdb)

        # these next searches will match any objects we have rights to see
        expected_num = has_rights_to
        for search in self.get_match_all_searches():
            self.assert_search_result(expected_num, search, samdb)

    # The following are double negative searches (i.e. NOT non-matching-
    # condition) which will therefore match ALL objects, including the test
    # object(s).
    def get_negative_match_all_searches(self):
        """Return '!' filters built from characters conf_value doesn't use."""
        first_char = self.conf_value[:1]
        last_char = self.conf_value[-1:]
        not_first_char = chr(ord(first_char) + 1)
        not_last_char = chr(ord(last_char) + 1)

        searches = [
            "(!({0}={1}*))".format(self.conf_attr, not_first_char),
            "(!({0}=*{1}))".format(self.conf_attr, not_last_char)
        ]
        return searches

    # the following searches will not match against the test object(s). So
    # a user with sufficient rights will see an inverse sub-set of objects.
    # (An unprivileged user would either see all objects on Windows, or no
    # objects on Samba)
    def get_inverse_match_searches(self):
        """Return '!' filters that exclude the object holding conf_value."""
        first_char = self.conf_value[:1]
        last_char = self.conf_value[-1:]
        searches = [
            "(!({0}={1}*))".format(self.conf_attr, first_char),
            "(!({0}=*{1}))".format(self.conf_attr, last_char)
        ]
        return searches

    def negative_searches_all_rights(self, total_objects=None):
        """Return {filter: expected count} for a user who can see everything.

        ``total_objects`` defaults to self.total_objects.
        """
        expected_results = {}

        if total_objects is None:
            total_objects = self.total_objects

        # these searches should match ALL objects (including the OU)
        for search in self.get_negative_match_all_searches():
            expected_results[search] = total_objects

        # a ! wildcard should only match the objects without the attribute
        search = "(!({0}=*))".format(self.conf_attr)
        expected_results[search] = total_objects - self.objects_with_attr

        # whereas the inverse searches should match all objects *except* the
        # one under test
        for search in self.get_inverse_match_searches():
            expected_results[search] = total_objects - 1

        return expected_results

    # Returns the expected negative (i.e. '!') search behaviour when talking to
    # a DC with DC_MODE_RETURN_ALL behaviour, i.e. we assert that users
    # without rights always see ALL objects in '!' searches
    def negative_searches_return_all(self,
                                     has_rights_to=0,
                                     total_objects=None):
        """Return {filter: expected count} for '!' searches on a
        DC_MODE_RETURN_ALL DC, for a user with ``has_rights_to`` rights."""
        expected_results = {}

        if total_objects is None:
            total_objects = self.total_objects

        # Windows 'hides' objects by always returning all of them, so negative
        # searches that match all objects will simply return all objects
        for search in self.get_negative_match_all_searches():
            expected_results[search] = total_objects

        # if we're matching on everything except the one object under test
        # (i.e. the inverse subset), we'll still see all objects if
        # has_rights_to == 0. Or we'll see all bar one if has_rights_to == 1.
        inverse_searches = self.get_inverse_match_searches()
        inverse_searches += ["(!({0}=*))".format(self.conf_attr)]

        for search in inverse_searches:
            expected_results[search] = total_objects - has_rights_to

        return expected_results

    # Returns the expected negative (i.e. '!') search behaviour when talking to
    # a DC with DC_MODE_RETURN_NONE behaviour, i.e. we assert that users
    # without rights cannot see objects in '!' searches at all
    def negative_searches_return_none(self, has_rights_to=0):
        """Return {filter: expected count} for '!' searches on a
        DC_MODE_RETURN_NONE DC, for a user with ``has_rights_to`` rights."""
        expected_results = {}

        # the 'match-all' searches should only return the objects we have
        # access rights to (if any)
        for search in self.get_negative_match_all_searches():
            expected_results[search] = has_rights_to

        # for inverse matches, we should NOT be told about any objects at all
        inverse_searches = self.get_inverse_match_searches()
        inverse_searches += ["(!({0}=*))".format(self.conf_attr)]
        for search in inverse_searches:
            expected_results[search] = 0

        return expected_results

    # Returns the expected negative (i.e. '!') search behaviour. This varies
    # depending on what type of DC we're talking to (i.e. Windows or Samba)
    # and what access rights the user has.
    # Note we only handle has_rights_to="all", 1 (the test object), or 0 (i.e.
    # we don't have rights to any objects)
    def negative_search_expected_results(self,
                                         has_rights_to,
                                         dc_mode,
                                         total_objects=None):
        """Return the {filter: expected count} dict for '!' searches."""

        if has_rights_to == "all":
            expect_results = self.negative_searches_all_rights(total_objects)

        # if it's a Samba DC, we only expect the 'match-all' searches to return
        # the objects that we have access rights to (all others are hidden).
        # Whereas Windows 'hides' the objects by always returning all of them
        elif dc_mode == DC_MODE_RETURN_NONE:
            expect_results = self.negative_searches_return_none(has_rights_to)
        else:
            expect_results = self.negative_searches_return_all(
                has_rights_to, total_objects)
        return expect_results

    def assert_negative_searches(self,
                                 has_rights_to=0,
                                 dc_mode=DC_MODE_RETURN_NONE,
                                 samdb=None):
        """Asserts user without rights cannot see objects in '!' searches"""

        if samdb is None:
            samdb = self.ldb_user

        # build a dictionary of key=search-expr, value=expected_num assertions
        expected_results = self.negative_search_expected_results(
            has_rights_to, dc_mode)

        for search, expected_num in expected_results.items():
            self.assert_search_result(expected_num, search, samdb)

    def assert_attr_returned(self, expect_attr, samdb, attrs):
        """Assert whether conf_attr comes back when reading conf_dn.

        ``attrs`` is the attribute list requested; ``expect_attr`` says
        whether conf_attr should be present in the single result.
        """
        # does a query that should always return a successful result, and
        # checks whether the confidential attribute is present
        res = samdb.search(self.conf_dn,
                           expression="(objectClass=*)",
                           scope=SCOPE_SUBTREE,
                           attrs=attrs)
        self.assertEqual(len(res), 1)

        attr_returned = any(self.conf_attr in msg for msg in res)
        self.assertEqual(expect_attr, attr_returned)

    def assert_attr_visible(self, expect_attr, samdb=None):
        """Assert conf_attr is (or isn't) returned for various attr lists."""
        if samdb is None:
            samdb = self.ldb_user

        # sanity-check confidential attribute is/isn't returned as expected
        # based on the filter attributes we ask for
        self.assert_attr_returned(expect_attr, samdb, attrs=None)
        self.assert_attr_returned(expect_attr, samdb, attrs=["*"])
        self.assert_attr_returned(expect_attr, samdb, attrs=[self.conf_attr])

        # filtering on a different attribute should never return the conf_attr
        self.assert_attr_returned(expect_attr=False,
                                  samdb=samdb,
                                  attrs=['name'])

    def assert_attr_visible_to_admin(self):
        """Assert the admin connection can always see/match conf_attr."""
        # sanity-check the admin user can always see the confidential attribute
        self.assert_conf_attr_searches(has_rights_to="all",
                                       samdb=self.ldb_admin)
        self.assert_negative_searches(has_rights_to="all",
                                      samdb=self.ldb_admin)
        self.assert_attr_visible(expect_attr=True, samdb=self.ldb_admin)
Example #48
0
                self.assertTrue("msDS-isRODC" in ldb_msg)


# Turn a bare host/path argument into an ldb URL: an existing file is opened
# as a local tdb, anything else is treated as an LDAP server name.
if "://" not in host:
    if os.path.isfile(host):
        host = "tdb://%s" % host
    else:
        host = "ldap://%s" % host

ldb_options = []
if host.startswith("ldap://"):
    # use the 'paged_searches' module when connecting remotely
    ldb_options = ["modules:paged_searches"]

ldb = SamDB(host,
            credentials=creds,
            session_info=system_session(lp),
            lp=lp,
            options=ldb_options)

runner = SubunitTestRunner()
rc = 0
# unittest.makeSuite() was deprecated in Python 3.11 and removed in 3.13;
# TestLoader.loadTestsFromTestCase() is the supported equivalent.
loader = unittest.TestLoader()
for test_case in (SchemaTests, SchemaTests_msDS_IntId, SchemaTests_msDS_isRODC):
    if not runner.run(loader.loadTestsFromTestCase(test_case)).wasSuccessful():
        rc = 1

sys.exit(rc)
Example #49
0
class DsdbLockTestCase(SamDBTestCase):
    """Cross-process locking tests for the partitioned sam.ldb database.

    Each test forks a child.  The child re-opens either the whole SamDB or a
    single backend partition file.  Parent and child then hand-shake over
    pipes so that one side holds a lock while the other times how long it
    is blocked.  The lock holder sleeps for 2 seconds, so a wait of more
    than 1.9 seconds shows the lock was honoured.
    """

    def _make_pipe(self):
        """Return (read_fd, write_fd) of a new pipe closed at test cleanup."""
        (r, w) = os.pipe()
        # Cleanups only ever run in the parent: the forked child always
        # leaves via os._exit(), which skips them.
        self.addCleanup(os.close, r)
        self.addCleanup(os.close, w)
        return (r, w)

    def _backend_path(self, basedn):
        """Return the path of the sam.ldb.d backend file for partition basedn."""
        backend_filename = "%s.ldb" % basedn.get_casefold()
        backend_subpath = os.path.join("sam.ldb.d", backend_filename)
        return self.lp.private_path(backend_subpath)

    def test_db_lock1(self):
        # A SamDB transaction holding a write lock (in the child) must block
        # a full search iterator in the parent.
        basedn = self.samdb.get_default_basedn()
        (r1, w1) = self._make_pipe()

        pid = os.fork()
        if pid == 0:
            try:
                # In the child, close the main DB and re-open it
                del self.samdb
                gc.collect()
                self.samdb = SamDB(session_info=self.session, lp=self.lp)

                self.samdb.transaction_start()

                dn = "cn=test_db_lock_user,cn=users," + str(basedn)
                self.samdb.add({
                    "dn": dn,
                    "objectclass": "user",
                })
                self.samdb.delete(dn)

                # Obtain a write lock
                self.samdb.transaction_prepare_commit()
                os.write(w1, b"prepared")
                time.sleep(2)

                # Drop the write lock
                self.samdb.transaction_cancel()
                os._exit(0)
            finally:
                # Never let a failing child fall back into the test runner.
                os._exit(1)

        self.assertEqual(os.read(r1, 8), b"prepared")

        start = time.time()

        # We need to hold this iterator open to hold the all-record lock.
        res = self.samdb.search_iterator()

        # This should take at least 2 seconds because the transaction
        # has a write lock on one backend db open

        # Release the locks
        for _msg in res:
            pass

        end = time.time()
        self.assertGreater(end - start, 1.9)

        (got_pid, status) = os.waitpid(pid, 0)
        self.assertEqual(got_pid, pid)
        self.assertTrue(os.WIFEXITED(status))
        self.assertEqual(os.WEXITSTATUS(status), 0)

    def test_db_lock2(self):
        # A search iterator held open in the child must block the parent's
        # prepare_commit() on a partition (user) write.
        basedn = self.samdb.get_default_basedn()
        (r1, w1) = self._make_pipe()
        (r2, w2) = self._make_pipe()

        pid = os.fork()
        if pid == 0:
            try:
                # In the child, close the main DB, re-open
                del self.samdb
                gc.collect()
                self.samdb = SamDB(session_info=self.session, lp=self.lp)

                # We need to hold this iterator open to hold the all-record lock.
                res = self.samdb.search_iterator()

                os.write(w2, b"start")
                if (os.read(r1, 7) != b"started"):
                    os._exit(1)

                os.write(w2, b"add")
                if (os.read(r1, 5) != b"added"):
                    os._exit(2)

                # Wait 2 seconds to block prepare_commit() in the parent.
                os.write(w2, b"prepare")
                time.sleep(2)

                # Release the locks
                for _msg in res:
                    pass

                if (os.read(r1, 8) != b"prepared"):
                    os._exit(3)

                os._exit(0)
            finally:
                # Never let a failing child fall back into the test runner.
                os._exit(1)

        # We can start the transaction during the search
        # because both just grab the all-record read lock.
        self.assertEqual(os.read(r2, 5), b"start")
        self.samdb.transaction_start()
        os.write(w1, b"started")

        self.assertEqual(os.read(r2, 3), b"add")
        dn = "cn=test_db_lock_user,cn=users," + str(basedn)
        self.samdb.add({
            "dn": dn,
            "objectclass": "user",
        })
        self.samdb.delete(dn)
        os.write(w1, b"added")

        # Obtain a write lock, this will block until
        # the child releases the read lock.
        self.assertEqual(os.read(r2, 7), b"prepare")
        start = time.time()
        self.samdb.transaction_prepare_commit()
        end = time.time()
        try:
            self.assertGreater(end - start, 1.9)
        finally:
            # Always let the child finish and reap it, even on failure.
            os.write(w1, b"prepared")

            # Drop the write lock
            self.samdb.transaction_cancel()

            (got_pid, status) = os.waitpid(pid, 0)
            self.assertEqual(got_pid, pid)
            self.assertTrue(os.WIFEXITED(status))
            self.assertEqual(os.WEXITSTATUS(status), 0)

    def test_db_lock3(self):
        # Same as test_db_lock2 but the parent writes a special (@) record,
        # which ends up in the top level db rather than a partition.
        (r1, w1) = self._make_pipe()
        (r2, w2) = self._make_pipe()

        pid = os.fork()
        if pid == 0:
            try:
                # In the child, close the main DB, re-open
                del self.samdb
                gc.collect()
                self.samdb = SamDB(session_info=self.session, lp=self.lp)

                # We need to hold this iterator open to hold the all-record lock.
                res = self.samdb.search_iterator()

                os.write(w2, b"start")
                if (os.read(r1, 7) != b"started"):
                    os._exit(1)

                os.write(w2, b"add")
                if (os.read(r1, 5) != b"added"):
                    os._exit(2)

                # Wait 2 seconds to block prepare_commit() in the parent.
                os.write(w2, b"prepare")
                time.sleep(2)

                # Release the locks
                for _msg in res:
                    pass

                if (os.read(r1, 8) != b"prepared"):
                    os._exit(3)

                os._exit(0)
            finally:
                # Never let a failing child fall back into the test runner.
                os._exit(1)

        # We can start the transaction during the search
        # because both just grab the all-record read lock.
        self.assertEqual(os.read(r2, 5), b"start")
        self.samdb.transaction_start()
        os.write(w1, b"started")

        self.assertEqual(os.read(r2, 3), b"add")

        # This will end up in the top level db
        dn = "@DSDB_LOCK_TEST"
        self.samdb.add({"dn": dn})
        self.samdb.delete(dn)
        os.write(w1, b"added")

        # Obtain a write lock, this will block until
        # the child releases the read lock.
        self.assertEqual(os.read(r2, 7), b"prepare")
        start = time.time()
        self.samdb.transaction_prepare_commit()
        end = time.time()
        try:
            self.assertGreater(end - start, 1.9)
        finally:
            # Without this the child would block forever waiting for
            # "prepared" whenever the timing assertion fails.
            os.write(w1, b"prepared")

            # Drop the write lock
            self.samdb.transaction_cancel()

            (got_pid, status) = os.waitpid(pid, 0)
            self.assertEqual(got_pid, pid)
            self.assertTrue(os.WIFEXITED(status))
            self.assertEqual(os.WEXITSTATUS(status), 0)

    def _test_full_db_lock1(self, backend_path):
        """Check a write lock on one backend file blocks a full SamDB search.

        :param backend_path: path of the backend .ldb file the child locks.
        """
        (r1, w1) = self._make_pipe()

        pid = os.fork()
        if pid == 0:
            try:
                # In the child, close the main DB, re-open just one DB
                del self.samdb
                gc.collect()

                backenddb = ldb.Ldb(backend_path)

                backenddb.transaction_start()

                backenddb.add({"dn": "@DSDB_LOCK_TEST"})
                backenddb.delete("@DSDB_LOCK_TEST")

                # Obtain a write lock
                backenddb.transaction_prepare_commit()
                os.write(w1, b"prepared")
                time.sleep(2)

                # Drop the write lock
                backenddb.transaction_cancel()
                os._exit(0)
            finally:
                # Never let a failing child fall back into the test runner.
                os._exit(1)

        self.assertEqual(os.read(r1, 8), b"prepared")

        start = time.time()

        # We need to hold this iterator open to hold the all-record lock.
        res = self.samdb.search_iterator()

        # This should take at least 2 seconds because the transaction
        # has a write lock on one backend db open

        end = time.time()
        self.assertGreater(end - start, 1.9)

        # Release the locks
        for _msg in res:
            pass

        (got_pid, status) = os.waitpid(pid, 0)
        self.assertEqual(got_pid, pid)
        self.assertTrue(os.WIFEXITED(status))
        self.assertEqual(os.WEXITSTATUS(status), 0)

    def test_full_db_lock1(self):
        self._test_full_db_lock1(
            self._backend_path(self.samdb.get_default_basedn()))

    def test_full_db_lock1_config(self):
        self._test_full_db_lock1(
            self._backend_path(self.samdb.get_config_basedn()))

    def _test_full_db_lock2(self, backend_path):
        """Check a SamDB read lock in the child blocks a backend prepare_commit.

        :param backend_path: path of the backend .ldb file the parent writes.
        """
        (r1, w1) = self._make_pipe()
        (r2, w2) = self._make_pipe()

        pid = os.fork()
        if pid == 0:
            try:
                # In the child, close the main DB, re-open
                del self.samdb
                gc.collect()
                self.samdb = SamDB(session_info=self.session, lp=self.lp)

                # We need to hold this iterator open to hold the all-record lock.
                res = self.samdb.search_iterator()

                os.write(w2, b"start")
                if (os.read(r1, 7) != b"started"):
                    os._exit(1)
                os.write(w2, b"add")
                if (os.read(r1, 5) != b"added"):
                    os._exit(2)

                # Wait 2 seconds to block prepare_commit() in the parent.
                os.write(w2, b"prepare")
                time.sleep(2)

                # Release the locks
                for _msg in res:
                    pass

                if (os.read(r1, 8) != b"prepared"):
                    os._exit(3)

                os._exit(0)
            finally:
                # Never let a failing child fall back into the test runner.
                os._exit(1)

        # In the parent, close the main DB, re-open just one DB
        del self.samdb
        gc.collect()
        backenddb = ldb.Ldb(backend_path)

        # We can start the transaction during the search
        # because both just grab the all-record read lock.
        self.assertEqual(os.read(r2, 5), b"start")
        backenddb.transaction_start()
        os.write(w1, b"started")

        self.assertEqual(os.read(r2, 3), b"add")
        backenddb.add({"dn": "@DSDB_LOCK_TEST"})
        backenddb.delete("@DSDB_LOCK_TEST")
        os.write(w1, b"added")

        # Obtain a write lock, this will block until
        # the child releases the read lock.
        self.assertEqual(os.read(r2, 7), b"prepare")
        start = time.time()
        backenddb.transaction_prepare_commit()
        end = time.time()

        try:
            self.assertGreater(end - start, 1.9)
        finally:
            # Always let the child finish and reap it, even on failure.
            os.write(w1, b"prepared")

            # Drop the write lock
            backenddb.transaction_cancel()

            (got_pid, status) = os.waitpid(pid, 0)
            self.assertEqual(got_pid, pid)
            self.assertTrue(os.WIFEXITED(status))
            self.assertEqual(os.WEXITSTATUS(status), 0)

    def test_full_db_lock2(self):
        self._test_full_db_lock2(
            self._backend_path(self.samdb.get_default_basedn()))

    def test_full_db_lock2_config(self):
        self._test_full_db_lock2(
            self._backend_path(self.samdb.get_config_basedn()))
Example #50
0
File: drs.py Project: szaydel/samba
class cmd_drs_replicate(Command):
    """Replicate a naming context between two DCs."""

    synopsis = "%prog <destinationDC> <sourceDC> <NC> [options]"

    takes_optiongroups = {
        "sambaopts": options.SambaOptions,
        "versionopts": options.VersionOptions,
        "credopts": options.CredentialsOptions,
    }

    takes_args = ["DEST_DC", "SOURCE_DC", "NC"]

    takes_options = [
        Option("--add-ref", help="use ADD_REF to add to repsTo on source", action="store_true"),
        Option("--sync-forced", help="use SYNC_FORCED to force inbound replication", action="store_true"),
        Option("--sync-all", help="use SYNC_ALL to replicate from all DCs", action="store_true"),
        Option("--full-sync", help="resync all objects", action="store_true"),
        Option("--local", help="pull changes directly into the local database (destination DC is ignored)", action="store_true"),
        Option("--local-online", help="pull changes into the local database (destination DC is ignored) as a normal online replication", action="store_true"),
        Option("--async-op", help="use ASYNC_OP for the replication", action="store_true"),
        Option("--single-object", help="Replicate only the object specified, instead of the whole Naming Context (only with --local)", action="store_true"),
    ]

    def drs_local_replicate(self, SOURCE_DC, NC, full_sync=False,
                            single_object=False,
                            sync_forced=False):
        '''replicate from a source DC to the local SAM

        Opens the local SamDB and an LDAP connection to SOURCE_DC, then
        replicates NC into the local database.  With single_object only
        that object is replicated (DRSUAPI_EXOP_REPL_OBJ, which also forces
        full_sync).  Raises CommandError if the replication fails.
        '''

        self.server = SOURCE_DC
        drsuapi_connect(self)

        # Override the default flag LDB_FLG_DONT_CREATE_DB
        self.local_samdb = SamDB(session_info=system_session(), url=None,
                                 credentials=self.creds, lp=self.lp,
                                 flags=0)

        self.samdb = SamDB(url="ldap://%s" % self.server,
                           session_info=system_session(),
                           credentials=self.creds, lp=self.lp)

        # work out the source and destination GUIDs
        res = self.local_samdb.search(base="", scope=ldb.SCOPE_BASE,
                                      attrs=["dsServiceName"])
        self.ntds_dn = res[0]["dsServiceName"][0]

        res = self.local_samdb.search(base=self.ntds_dn, scope=ldb.SCOPE_BASE,
                                      attrs=["objectGUID"])
        self.ntds_guid = misc.GUID(
            self.samdb.schema_format_value("objectGUID",
                                           res[0]["objectGUID"][0]))

        source_dsa_invocation_id = misc.GUID(self.samdb.get_invocation_id())
        dest_dsa_invocation_id = misc.GUID(self.local_samdb.get_invocation_id())
        destination_dsa_guid = self.ntds_guid

        exop = drsuapi.DRSUAPI_EXOP_NONE

        if single_object:
            exop = drsuapi.DRSUAPI_EXOP_REPL_OBJ
            full_sync = True

        self.samdb.transaction_start()
        try:
            repl = drs_utils.drs_Replicate("ncacn_ip_tcp:%s[seal]" % self.server,
                                           self.lp,
                                           self.creds, self.local_samdb,
                                           dest_dsa_invocation_id)

            # Work out if we are an RODC, so that a forced local replicate
            # with the admin pw does not sync passwords
            rodc = self.local_samdb.am_rodc()
            try:
                (num_objects, num_links) = repl.replicate(NC,
                                                          source_dsa_invocation_id,
                                                          destination_dsa_guid,
                                                          rodc=rodc,
                                                          full_sync=full_sync,
                                                          exop=exop,
                                                          sync_forced=sync_forced)
            except Exception as e:
                raise CommandError("Error replicating DN %s" % NC, e)
        except BaseException:
            # Do not leave the transaction open when replication fails.
            self.samdb.transaction_cancel()
            raise
        self.samdb.transaction_commit()

        if full_sync:
            self.message("Full Replication of all %d objects and %d links "
                         "from %s to %s was successful." %
                         (num_objects, num_links, SOURCE_DC,
                          self.local_samdb.url))
        else:
            self.message("Incremental replication of %d objects and %d links "
                         "from %s to %s was successful." %
                         (num_objects, num_links, SOURCE_DC,
                          self.local_samdb.url))

    def run(self, DEST_DC, SOURCE_DC, NC,
            add_ref=False, sync_forced=False, sync_all=False, full_sync=False,
            local=False, local_online=False, async_op=False, single_object=False,
            sambaopts=None, credopts=None, versionopts=None):
        '''Ask DEST_DC (or the local dreplsrv) to replicate NC from SOURCE_DC.

        With --local the work is delegated to drs_local_replicate().
        Otherwise a DsReplicaSync request is sent with flags built from the
        options; CommandError is raised if the source DC cannot be found or
        the request fails.
        '''

        self.server = DEST_DC
        self.lp = sambaopts.get_loadparm()

        self.creds = credopts.get_credentials(self.lp, fallback_machine=True)

        if local:
            self.drs_local_replicate(SOURCE_DC, NC, full_sync=full_sync,
                                     single_object=single_object,
                                     sync_forced=sync_forced)
            return

        if local_online:
            server_bind = drsuapi.drsuapi("irpc:dreplsrv", lp_ctx=self.lp)
            server_bind_handle = misc.policy_handle()
        else:
            drsuapi_connect(self)
            server_bind = self.drsuapi
            server_bind_handle = self.drsuapi_handle

        if not async_op:
            # Give the sync replication 5 minutes time
            server_bind.request_timeout = 5 * 60

        samdb_connect(self)

        # we need to find the NTDS GUID of the source DC
        encoded_source_dc = ldb.binary_encode(SOURCE_DC)
        msg = self.samdb.search(
            base=self.samdb.get_config_basedn(),
            expression="(&(objectCategory=server)(|(name=%s)(dNSHostName=%s)))" % (
                encoded_source_dc, encoded_source_dc),
            attrs=[])
        if len(msg) == 0:
            raise CommandError("Failed to find source DC %s" % SOURCE_DC)
        server_dn = msg[0]['dn']

        msg = self.samdb.search(base=server_dn, scope=ldb.SCOPE_ONELEVEL,
                                expression="(|(objectCategory=nTDSDSA)(objectCategory=nTDSDSARO))",
                                attrs=['objectGUID', 'options'])
        if len(msg) == 0:
            raise CommandError("Failed to find source NTDS DN %s" % SOURCE_DC)
        source_dsa_guid = msg[0]['objectGUID'][0]
        # 'options' is an attribute of the nTDSDSA message, so it must be
        # looked up on msg[0]; looking it up on the result list never matched
        # and silently fell back to 0.  str() yields the single value as
        # text (and "0" for the default), which int() then parses.
        dsa_options = int(str(attr_default(msg[0], 'options', 0)))

        req_options = 0
        if not (dsa_options & dsdb.DS_NTDSDSA_OPT_DISABLE_OUTBOUND_REPL):
            req_options |= drsuapi.DRSUAPI_DRS_WRIT_REP
        if add_ref:
            req_options |= drsuapi.DRSUAPI_DRS_ADD_REF
        if sync_forced:
            req_options |= drsuapi.DRSUAPI_DRS_SYNC_FORCED
        if sync_all:
            req_options |= drsuapi.DRSUAPI_DRS_SYNC_ALL
        if full_sync:
            req_options |= drsuapi.DRSUAPI_DRS_FULL_SYNC_NOW
        if async_op:
            req_options |= drsuapi.DRSUAPI_DRS_ASYNC_OP

        try:
            drs_utils.sendDsReplicaSync(server_bind, server_bind_handle, source_dsa_guid, NC, req_options)
        except drs_utils.drsException as estr:
            raise CommandError("DsReplicaSync failed", estr)
        if async_op:
            self.message("Replicate from %s to %s was started." % (SOURCE_DC, DEST_DC))
        else:
            self.message("Replicate from %s to %s was successful." % (SOURCE_DC, DEST_DC))
Example #51
0
class DsdbTests(TestCase):
    """Tests for SamDB helpers, replPropertyMetaData handling and DN logic.

    setUp() creates a throw-away user whose DN is self.account_dn; it is
    removed again via delete_force at cleanup.
    """

    def setUp(self):
        super(DsdbTests, self).setUp()
        self.lp = samba.tests.env_loadparm()
        self.creds = Credentials()
        self.creds.guess(self.lp)
        self.session = system_session()
        self.samdb = SamDB(session_info=self.session,
                           credentials=self.creds,
                           lp=self.lp)

        # Create a test user
        user_name = "dsdb-user-" + str(uuid.uuid4().hex[0:6])
        user_pass = samba.generate_random_password(32, 32)
        user_description = "Test user for dsdb test"

        base_dn = self.samdb.domain_dn()

        self.account_dn = "cn=" + user_name + ",cn=Users," + base_dn
        self.samdb.newuser(username=user_name,
                           password=user_pass,
                           description=user_description)
        # Cleanup (teardown)
        self.addCleanup(delete_force, self.samdb, self.account_dn)

    def test_get_oid_from_attrid(self):
        oid = self.samdb.get_oid_from_attid(591614)
        # assertEquals was removed in Python 3.12; use assertEqual.
        self.assertEqual(oid, "1.2.840.113556.1.4.1790")

    def test_error_replpropertymetadata(self):
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        ctr = repl.ctr
        for o in ctr.array:
            # Search for Description
            if o.attid == 13:
                o.version = o.version + 1
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        self.assertRaises(ldb.LdbError, self.samdb.modify, msg,
                          ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_error_replpropertymetadata_nochange(self):
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        self.assertRaises(ldb.LdbError, self.samdb.modify, msg,
                          ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_error_replpropertymetadata_allow_sort(self):
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        self.samdb.modify(msg, [
            "local_oid:1.3.6.1.4.1.7165.4.3.14:0",
            "local_oid:1.3.6.1.4.1.7165.4.3.25:0"
        ])

    def test_twoatt_replpropertymetadata(self):
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData", "uSNChanged"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        ctr = repl.ctr
        for o in ctr.array:
            # Search for Description
            if o.attid == 13:
                o.version = o.version + 1
                o.local_usn = int(str(res[0]["uSNChanged"])) + 1
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        msg["description"] = ldb.MessageElement("new val",
                                                ldb.FLAG_MOD_REPLACE,
                                                "description")
        self.assertRaises(ldb.LdbError, self.samdb.modify, msg,
                          ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_set_replpropertymetadata(self):
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData", "uSNChanged"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        ctr = repl.ctr
        for o in ctr.array:
            # Search for Description
            if o.attid == 13:
                o.version = o.version + 1
                o.local_usn = int(str(res[0]["uSNChanged"])) + 1
                o.originating_usn = int(str(res[0]["uSNChanged"])) + 1
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        self.samdb.modify(msg, ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_ok_get_attribute_from_attid(self):
        self.assertEqual(self.samdb.get_attribute_from_attid(13),
                         "description")

    def test_ko_get_attribute_from_attid(self):
        self.assertEqual(self.samdb.get_attribute_from_attid(11979), None)

    def test_get_attribute_replmetadata_version(self):
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["dn"])
        self.assertEqual(len(res), 1)
        dn = str(res[0].dn)
        self.assertEqual(
            self.samdb.get_attribute_replmetadata_version(dn, "unicodePwd"), 2)

    def test_set_attribute_replmetadata_version(self):
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["dn"])
        self.assertEqual(len(res), 1)
        dn = str(res[0].dn)
        version = self.samdb.get_attribute_replmetadata_version(
            dn, "description")
        self.samdb.set_attribute_replmetadata_version(dn, "description",
                                                      version + 2)
        self.assertEqual(
            self.samdb.get_attribute_replmetadata_version(dn, "description"),
            version + 2)

    def test_no_error_on_invalid_control(self):
        try:
            self.samdb.search(
                scope=ldb.SCOPE_SUBTREE,
                base=self.account_dn,
                attrs=["replPropertyMetaData"],
                controls=[
                    "local_oid:%s:0" %
                    dsdb.DSDB_CONTROL_INVALID_NOT_IMPLEMENTED
                ])
        except ldb.LdbError:
            self.fail("Should have not raised an exception")

    def test_error_on_invalid_critical_control(self):
        try:
            self.samdb.search(
                scope=ldb.SCOPE_SUBTREE,
                base=self.account_dn,
                attrs=["replPropertyMetaData"],
                controls=[
                    "local_oid:%s:1" %
                    dsdb.DSDB_CONTROL_INVALID_NOT_IMPLEMENTED
                ])
        except ldb.LdbError as e:
            (code, estr) = e.args
            if code != ldb.ERR_UNSUPPORTED_CRITICAL_EXTENSION:
                # Exceptions are not subscriptable in Python 3 (e[1] would
                # raise TypeError), so report the unpacked message instead.
                self.fail(
                    "Got %s should have got ERR_UNSUPPORTED_CRITICAL_EXTENSION"
                    % estr)
        else:
            # A critical, unimplemented control must not be silently ignored.
            self.fail("Should have raised ERR_UNSUPPORTED_CRITICAL_EXTENSION")

    # Allocate a unique RID for use in the objectSID tests.
    #
    def allocate_rid(self):
        self.samdb.transaction_start()
        try:
            rid = self.samdb.allocate_rid()
        except:
            self.samdb.transaction_cancel()
            raise
        self.samdb.transaction_commit()
        return str(rid)

    # Ensure that duplicate objectSID's are permitted for foreign security
    # principals.
    #
    def test_duplicate_objectSIDs_allowed_on_foreign_security_principals(self):

        #
        # We need to build a foreign security principal SID
        # i.e a  SID not in the current domain.
        #
        dom_sid = self.samdb.get_domain_sid()
        if str(dom_sid).endswith("0"):
            c = "9"
        else:
            c = "0"
        sid = str(dom_sid)[:-1] + c + "-1000"
        basedn = self.samdb.get_default_basedn()
        dn = "CN=%s,CN=ForeignSecurityPrincipals,%s" % (sid, basedn)
        # Remove the object even if an assertion below fails.
        self.addCleanup(delete_force, self.samdb, dn)
        self.samdb.add({"dn": dn, "objectClass": "foreignSecurityPrincipal"})

        self.samdb.delete(dn)

        try:
            self.samdb.add({
                "dn": dn,
                "objectClass": "foreignSecurityPrincipal"
            })
        except ldb.LdbError as e:
            (code, msg) = e.args
            self.fail("Got unexpected exception %d - %s " % (code, msg))

        # cleanup
        self.samdb.delete(dn)

    #
    # Duplicate objectSID's should not be permitted for sids in the local
    # domain. The test sequence is add an object, delete it, then attempt to
    # re-add it, this should fail with a constraint violation
    #
    def test_duplicate_objectSIDs_not_allowed_on_local_objects(self):

        dom_sid = self.samdb.get_domain_sid()
        rid = self.allocate_rid()
        sid_str = str(dom_sid) + "-" + rid
        sid = ndr_pack(security.dom_sid(sid_str))
        basedn = self.samdb.get_default_basedn()
        cn = "dsdb_test_01"
        dn = "cn=%s,cn=Users,%s" % (cn, basedn)
        # If the re-add unexpectedly succeeds, do not leak the object.
        self.addCleanup(delete_force, self.samdb, dn)

        self.samdb.add({"dn": dn, "objectClass": "user", "objectSID": sid})
        self.samdb.delete(dn)

        try:
            self.samdb.add({"dn": dn, "objectClass": "user", "objectSID": sid})
            self.fail("No exception should get LDB_ERR_CONSTRAINT_VIOLATION")
        except ldb.LdbError as e:
            (code, msg) = e.args
            if code != ldb.ERR_CONSTRAINT_VIOLATION:
                self.fail("Got %d - %s should have got "
                          "LDB_ERR_CONSTRAINT_VIOLATION" % (code, msg))

    def test_normalize_dn_in_domain_full(self):
        domain_dn = self.samdb.domain_dn()

        full_dn = ldb.Dn(self.samdb, "CN=Users")
        full_dn.add_base(domain_dn)

        full_str = str(full_dn)

        # That is, no change
        self.assertEqual(full_dn, self.samdb.normalize_dn_in_domain(full_str))

    def test_normalize_dn_in_domain_part(self):
        domain_dn = self.samdb.domain_dn()

        part_str = "CN=Users"

        full_dn = ldb.Dn(self.samdb, part_str)
        full_dn.add_base(domain_dn)

        # That is, the domain DN appended
        self.assertEqual(full_dn, self.samdb.normalize_dn_in_domain(part_str))

    def test_normalize_dn_in_domain_full_dn(self):
        domain_dn = self.samdb.domain_dn()

        full_dn = ldb.Dn(self.samdb, "CN=Users")
        full_dn.add_base(domain_dn)

        # That is, no change
        self.assertEqual(full_dn, self.samdb.normalize_dn_in_domain(full_dn))

    def test_normalize_dn_in_domain_part_dn(self):
        domain_dn = self.samdb.domain_dn()

        part_dn = ldb.Dn(self.samdb, "CN=Users")

        # That is, the domain DN appended
        self.assertEqual(
            ldb.Dn(self.samdb,
                   str(part_dn) + "," + str(domain_dn)),
            self.samdb.normalize_dn_in_domain(part_dn))
Example #52
0
    def _test_full_db_lock2(self, backend_path):
        """Check a SamDB read lock in the child blocks a backend prepare_commit.

        The forked child opens the full SamDB and holds a search iterator
        (the all-record read lock) for about 2 seconds.  The parent opens
        only the backend file at backend_path, writes a record and times
        transaction_prepare_commit(), which must wait for the child.

        :param backend_path: path of the backend .ldb file the parent writes.
        """
        (r1, w1) = os.pipe()
        (r2, w2) = os.pipe()
        # Close the pipe ends once the test is over.  The forked child
        # leaves via os._exit(), so these cleanups only run in the parent.
        for fd in (r1, w1, r2, w2):
            self.addCleanup(os.close, fd)

        pid = os.fork()
        if pid == 0:
            try:
                # In the child, close the main DB, re-open
                del self.samdb
                gc.collect()
                self.samdb = SamDB(session_info=self.session, lp=self.lp)

                # We need to hold this iterator open to hold the all-record lock.
                res = self.samdb.search_iterator()

                os.write(w2, b"start")
                if (os.read(r1, 7) != b"started"):
                    os._exit(1)
                os.write(w2, b"add")
                if (os.read(r1, 5) != b"added"):
                    os._exit(2)

                # Wait 2 seconds to block prepare_commit() in the parent.
                os.write(w2, b"prepare")
                time.sleep(2)

                # Release the locks
                for _msg in res:
                    pass

                if (os.read(r1, 8) != b"prepared"):
                    os._exit(3)

                os._exit(0)
            finally:
                # Never let a failing child fall back into the test runner.
                os._exit(1)

        # In the parent, close the main DB, re-open just one DB
        del self.samdb
        gc.collect()
        backenddb = ldb.Ldb(backend_path)

        # We can start the transaction during the search
        # because both just grab the all-record read lock.
        self.assertEqual(os.read(r2, 5), b"start")
        backenddb.transaction_start()
        os.write(w1, b"started")

        self.assertEqual(os.read(r2, 3), b"add")
        backenddb.add({"dn": "@DSDB_LOCK_TEST"})
        backenddb.delete("@DSDB_LOCK_TEST")
        os.write(w1, b"added")

        # Obtain a write lock, this will block until
        # the child releases the read lock.
        self.assertEqual(os.read(r2, 7), b"prepare")
        start = time.time()
        backenddb.transaction_prepare_commit()
        end = time.time()

        try:
            self.assertGreater(end - start, 1.9)
        finally:
            # Always let the child finish and reap it, even on failure.
            os.write(w1, b"prepared")

            # Drop the write lock
            backenddb.transaction_cancel()

            (got_pid, status) = os.waitpid(pid, 0)
            self.assertEqual(got_pid, pid)
            self.assertTrue(os.WIFEXITED(status))
            self.assertEqual(os.WEXITSTATUS(status), 0)
Example #53
0
class LDAPNotificationTest(samba.tests.TestCase):
    """Tests for LDAP change notifications (the ``notification`` control).

    Every test requires ``url`` to be an ``ldap`` URL; otherwise the test
    fails straight away.
    """

    def setUp(self):
        super(LDAPNotificationTest, self).setUp()
        self.ldb = SamDB(url, credentials=creds, session_info=system_session(lp), lp=lp)
        self.base_dn = self.ldb.domain_dn()

        res = self.ldb.search("", scope=ldb.SCOPE_BASE, attrs=["tokenGroups"])
        self.assertEqual(len(res), 1)

        # Build a <SID=...> DN from the first tokenGroups SID of the rootDSE.
        # The tests modify otherLoginWorkstations on this object, so it is
        # presumably the bound account itself -- TODO confirm.
        self.user_sid_dn = "<SID=%s>" % str(ndr_unpack(samba.dcerpc.security.dom_sid, res[0]["tokenGroups"][0]))

    def _assert_notification_error(self, expression, timeout, expected):
        """Run a notification search that must yield no entries and fail.

        :param expression: LDAP filter for the subtree search under base_dn
        :param timeout: time limit handed to search_iterator
        :param expected: the ldb error number the search must end with
        """
        try:
            hnd = self.ldb.search_iterator(base=self.base_dn,
                                           expression=expression,
                                           scope=ldb.SCOPE_SUBTREE,
                                           attrs=["name"],
                                           controls=["notification:1"],
                                           timeout=timeout)
            for _ in hnd:
                self.fail("unexpected reply for filter %s" % expression)
            hnd.result()
        except LdbError as e:
            num = e.args[0]
        else:
            self.fail("notification search with filter %s did not fail" %
                      expression)
        self.assertEqual(num, expected,
                         "filter %s: got error %s, expected %s" %
                         (expression, num, expected))

    def test_simple_search(self):
        """Testing a notification with a modify and a timeout"""
        if not url.startswith("ldap"):
            self.fail(msg="This test is only valid on ldap")

        # Plain base search on the user's own object: exactly one entry.
        msg1 = None
        search1 = self.ldb.search_iterator(base=self.user_sid_dn,
                                           expression="(objectClass=*)",
                                           scope=ldb.SCOPE_SUBTREE,
                                           attrs=["name", "objectGUID", "displayName"])
        for reply in search1:
            self.assertIsInstance(reply, ldb.Message)
            self.assertIsNone(msg1)
            msg1 = reply
        search1.result()
        self.assertIsNotNone(msg1, "search on %s returned no entry" %
                             self.user_sid_dn)

        # The same object must show up once in a subtree search of the domain.
        search2 = self.ldb.search_iterator(base=self.base_dn,
                                           expression="(objectClass=*)",
                                           scope=ldb.SCOPE_SUBTREE,
                                           attrs=["name", "objectGUID", "displayName"])
        refs2 = 0
        msg2 = None
        for reply in search2:
            if isinstance(reply, str):
                # referrals are returned as plain strings
                refs2 += 1
                continue
            self.assertIsInstance(reply, ldb.Message)
            if reply["objectGUID"][0] == msg1["objectGUID"][0]:
                self.assertIsNone(msg2)
                msg2 = reply
                self.assertEqual(msg1.dn, msg2.dn)
                self.assertEqual(len(msg1), len(msg2))
                self.assertEqual(msg1["name"], msg2["name"])
                #self.assertEqual(msg1["displayName"], msg2["displayName"])
        search2.result()

        self.ldb.modify_ldif("""
dn: """ + self.user_sid_dn + """
changetype: modify
replace: otherLoginWorkstations
otherLoginWorkstations: BEFORE"
""")
        notify1 = self.ldb.search_iterator(base=self.base_dn,
                                           expression="(objectClass=*)",
                                           scope=ldb.SCOPE_SUBTREE,
                                           attrs=["name", "objectGUID", "displayName"],
                                           controls=["notification:1"],
                                           timeout=1)

        # This modification happens after the notification search started,
        # so it must be reported through notify1.
        self.ldb.modify_ldif("""
dn: """ + self.user_sid_dn + """
changetype: modify
replace: otherLoginWorkstations
otherLoginWorkstations: AFTER"
""")

        msg3 = None
        for reply in notify1:
            self.assertIsInstance(reply, ldb.Message)
            if reply["objectGUID"][0] == msg1["objectGUID"][0]:
                self.assertIsNone(msg3)
                msg3 = reply
                self.assertEqual(msg1.dn, msg3.dn)
                self.assertEqual(len(msg1), len(msg3))
                self.assertEqual(msg1["name"], msg3["name"])
                #self.assertEqual(msg1["displayName"], msg3["displayName"])
        # A notification search only ends when its time limit expires.
        try:
            notify1.result()
        except LdbError as e10:
            self.assertEqual(e10.args[0], ERR_TIME_LIMIT_EXCEEDED)
        else:
            self.fail("notification search did not hit its time limit")
        self.assertIsNotNone(msg3)

        self.ldb.modify_ldif("""
dn: """ + self.user_sid_dn + """
changetype: delete
delete: otherLoginWorkstations
""")

    def test_max_search(self):
        """Testing the max allowed notifications"""
        if not url.startswith("ldap"):
            self.fail(msg="This test is only valid on ldap")

        max_notifications = 5

        # Open one notification search more than the test expects the server
        # to accept (see the assertions at the end).
        notifies = [self.ldb.search_iterator(base=self.base_dn,
                                             expression="(objectClass=*)",
                                             scope=ldb.SCOPE_SUBTREE,
                                             attrs=["name"],
                                             controls=["notification:1"],
                                             timeout=1)
                    for _ in range(max_notifications + 1)]
        num_admin_limit = 0
        num_time_limit = 0
        for notify in notifies:
            try:
                for _ in notify:
                    pass
                notify.result()
            except LdbError as e:
                num = e.args[0]
                if num == ERR_ADMIN_LIMIT_EXCEEDED:
                    num_admin_limit += 1
                elif num == ERR_TIME_LIMIT_EXCEEDED:
                    num_time_limit += 1
                else:
                    raise
            else:
                self.fail("notification search did not fail")
        self.assertEqual(num_admin_limit, 1)
        self.assertEqual(num_time_limit, max_notifications)

    def test_invalid_filter(self):
        """Testing invalid filters for notifications"""
        if not url.startswith("ldap"):
            self.fail(msg="This test is only valid on ldap")

        valid_attrs = ["objectClass", "objectGUID", "distinguishedName", "name"]

        for va in valid_attrs:
            # Presence filters (alone or OR-ed) on these attributes are
            # expected to be accepted, so the search runs until its
            # 1-second time limit expires.
            self._assert_notification_error("(%s=*)" % va, 1,
                                            ERR_TIME_LIMIT_EXCEEDED)
            self._assert_notification_error("(|(%s=*)(%s=value))" % (va, va), 1,
                                            ERR_TIME_LIMIT_EXCEEDED)

            # Every other filter shape is expected to be refused outright.
            for expression in ("(&(%s=*)(%s=value))" % (va, va),
                               "(%s=value)" % va,
                               "(%s>=value)" % va,
                               "(%s<=value)" % va,
                               "(%s=*value*)" % va,
                               "(!(%s=*))" % va):
                self._assert_notification_error(expression, 0,
                                                ERR_UNWILLING_TO_PERFORM)

        # A presence filter on any other schema attribute must be refused.
        res = self.ldb.search(base=self.ldb.get_schema_basedn(),
                              expression="(objectClass=attributeSchema)",
                              scope=ldb.SCOPE_ONELEVEL,
                              attrs=["lDAPDisplayName"],
                              controls=["paged_results:1:2500"])
        for msg in res:
            va = str(msg["lDAPDisplayName"][0])
            if va in valid_attrs:
                continue
            self._assert_notification_error("(%s=*)" % va, 0,
                                            ERR_UNWILLING_TO_PERFORM)

        # ... and so must one on an attribute that does not exist at all.
        self._assert_notification_error("(%s=*)" % "noneAttributeName", 0,
                                        ERR_UNWILLING_TO_PERFORM)
Example #54
0
                                            samdb.domain_dns_name())
        new_dns_name = '%s._msdcs.%s' % (samdb.get_ntds_GUID(),
                                         samdb.domain_dns_name())
    elif role == "forestdns":
        master_dns_name = '%s._msdcs.%s' % (master_guid,
                                            samdb.forest_dns_name())
        new_dns_name = '%s._msdcs.%s' % (samdb.get_ntds_GUID(),
                                         samdb.forest_dns_name())

    new_owner = samdb.get_dsServiceName()

    if master_dns_name != new_dns_name:
        lp = sambaopts.get_loadparm()
        creds = credopts.get_credentials(lp, fallback_machine=True)
        samdb = SamDB(url="ldap://%s" % (master_dns_name),
                      session_info=system_session(),
                      credentials=creds, lp=lp)

        m = ldb.Message()
        m.dn = ldb.Dn(samdb, role_object)
        m["fSMORoleOwner"] = ldb.MessageElement(master_owner,
                                                ldb.FLAG_MOD_DELETE,
                                                "fSMORoleOwner")

        try:
            samdb.modify(m)
        except LdbError, (num, msg):
            raise CommandError("Failed to delete role '%s': %s" %
                               (role, msg))

        m = ldb.Message()
Example #55
0
    def run(self, new_domain_name, new_dns_realm, sambaopts=None,
            credopts=None, server=None, targetdir=None, keep_dns_realm=False,
            no_secrets=False, backend_store=None):
        """Create a backup of the domain renamed to a new NetBIOS name/realm.

        The remote DC ``server`` is cloned (with the new names) into a
        temporary directory under ``targetdir``. The clone is then marked as a
        "rename" backup and fixed up, and the directory is packed into a
        backup tar file in ``targetdir``.

        :param new_domain_name: new NetBIOS domain name (upper-cased here)
        :param new_dns_realm: new DNS realm (lower-cased here)
        :param server: DC to clone from; required
        :param targetdir: directory receiving the backup file
        :param keep_dns_realm: if True, DNS zones of the old realm are kept
        :param no_secrets: if True, secrets are not cloned and the admin
            password is set afterwards
        :param backend_store: backend store type passed to the clone
        :raises CommandError: on missing server or when the new names equal
            the current ones
        """
        logger = self.get_logger()
        logger.setLevel(logging.INFO)

        lp = sambaopts.get_loadparm()
        creds = credopts.get_credentials(lp)

        # Make sure we have all the required args.
        if server is None:
            raise CommandError('Server required')

        check_targetdir(logger, targetdir)

        delete_old_dns = not keep_dns_realm

        new_dns_realm = new_dns_realm.lower()
        new_domain_name = new_domain_name.upper()

        new_base_dn = samba.dn_from_dns_name(new_dns_realm)
        logger.info("New realm for backed up domain: %s" % new_dns_realm)
        logger.info("New base DN for backed up domain: %s" % new_base_dn)
        logger.info("New domain NetBIOS name: %s" % new_domain_name)

        tmpdir = tempfile.mkdtemp(dir=targetdir)
        try:
            # setup a join-context for cloning the remote server
            include_secrets = not no_secrets
            ctx = DCCloneAndRenameContext(new_base_dn, new_domain_name,
                                          new_dns_realm, logger=logger,
                                          creds=creds, lp=lp,
                                          include_secrets=include_secrets,
                                          dns_backend='SAMBA_INTERNAL',
                                          server=server, targetdir=tmpdir,
                                          backend_store=backend_store)

            # sanity-check we're not "renaming" the domain to the same values
            old_domain = ctx.domain_name
            if old_domain == new_domain_name:
                raise CommandError("Cannot use the current domain NetBIOS name.")

            old_realm = ctx.realm
            if old_realm == new_dns_realm:
                raise CommandError("Cannot use the current domain DNS realm.")

            # do the clone/rename
            ctx.do_join()

            # get the paths used for the clone, then drop the old samdb connection
            del ctx.local_samdb
            paths = ctx.paths

            # get a free RID to use as the new DC's SID (when it gets restored)
            remote_sam = SamDB(url='ldap://' + server, credentials=creds,
                               session_info=system_session(), lp=lp)
            new_sid = get_sid_for_restore(remote_sam, logger)

            # Grab the remote DC's sysvol files and bundle them into a tar file.
            # Note we end up with 2 sysvol dirs - the original domain's files (that
            # use the old realm) backed here, as well as default files generated
            # for the new realm as part of the clone/join.
            sysvol_tar = os.path.join(tmpdir, 'sysvol.tar.gz')
            smb_conn = smb_sysvol_conn(server, lp, creds)
            backup_online(smb_conn, sysvol_tar, remote_sam.get_domain_sid())

            # connect to the local DB (making sure we use the new/renamed config)
            lp.load(paths.smbconf)
            samdb = SamDB(url=paths.samdb, session_info=system_session(), lp=lp)

            # Edit the cloned sam.ldb to mark it as a backup
            time_str = get_timestamp()
            add_backup_marker(samdb, "backupDate", time_str)
            add_backup_marker(samdb, "sidForRestore", new_sid)
            add_backup_marker(samdb, "backupRename", old_realm)
            add_backup_marker(samdb, "backupType", "rename")

            # fix up the DNS objects that are using the old dnsRoot value
            self.update_dns_root(logger, samdb, old_realm, delete_old_dns)

            # update the netBIOS name and the Partition object for the domain
            self.rename_domain_partition(logger, samdb, new_domain_name)

            if delete_old_dns:
                self.delete_old_dns_zones(logger, samdb, old_realm)

            logger.info("Fixing DN attributes after rename...")
            self.fix_old_dn_attributes(samdb)

            # ensure the admin user always has a password set (same as provision)
            if no_secrets:
                set_admin_password(logger, samdb)

            # Add everything in the tmpdir to the backup tar file
            backup_file = backup_filepath(targetdir, new_dns_realm, time_str)
            create_log_file(tmpdir, lp, "rename", server, include_secrets,
                            "Original domain %s (NetBIOS), %s (DNS realm)" %
                            (old_domain, old_realm))
            create_backup_tar(logger, tmpdir, backup_file)
        except BaseException:
            # Never leave the partial clone behind, whichever step failed, but
            # don't let a cleanup problem hide the original error.
            shutil.rmtree(tmpdir, ignore_errors=True)
            raise

        shutil.rmtree(tmpdir)
Example #56
0
File: drs.py Project: szaydel/samba
    def drs_local_replicate(self, SOURCE_DC, NC, full_sync=False,
                            single_object=False,
                            sync_forced=False):
        '''replicate from a source DC to the local SAM

        :param SOURCE_DC: the DC to replicate from (also used as the LDAP
            and DRS RPC host name)
        :param NC: DN handed to the replicator -- a naming context, or
            presumably the single object when ``single_object`` is set
        :param full_sync: request a full rather than incremental replication
        :param single_object: use DRSUAPI_EXOP_REPL_OBJ; implies full_sync
        :param sync_forced: passed through to ``replicate``
        :raises CommandError: if the replication itself fails
        '''

        self.server = SOURCE_DC
        drsuapi_connect(self)

        # Override the default flag LDB_FLG_DONT_CREATE_DB
        self.local_samdb = SamDB(session_info=system_session(), url=None,
                                 credentials=self.creds, lp=self.lp,
                                 flags=0)

        self.samdb = SamDB(url="ldap://%s" % self.server,
                           session_info=system_session(),
                           credentials=self.creds, lp=self.lp)

        # work out the source and destination GUIDs
        res = self.local_samdb.search(base="", scope=ldb.SCOPE_BASE,
                                      attrs=["dsServiceName"])
        self.ntds_dn = res[0]["dsServiceName"][0]

        res = self.local_samdb.search(base=self.ntds_dn, scope=ldb.SCOPE_BASE,
                                      attrs=["objectGUID"])
        self.ntds_guid = misc.GUID(
            self.samdb.schema_format_value("objectGUID",
                                           res[0]["objectGUID"][0]))

        source_dsa_invocation_id = misc.GUID(self.samdb.get_invocation_id())
        dest_dsa_invocation_id = misc.GUID(self.local_samdb.get_invocation_id())
        destination_dsa_guid = self.ntds_guid

        exop = drsuapi.DRSUAPI_EXOP_NONE

        if single_object:
            exop = drsuapi.DRSUAPI_EXOP_REPL_OBJ
            full_sync = True

        self.samdb.transaction_start()
        try:
            repl = drs_utils.drs_Replicate("ncacn_ip_tcp:%s[seal]" % self.server,
                                           self.lp,
                                           self.creds, self.local_samdb,
                                           dest_dsa_invocation_id)

            # Work out if we are an RODC, so that a forced local replicate
            # with the admin pw does not sync passwords
            rodc = self.local_samdb.am_rodc()
            try:
                (num_objects, num_links) = repl.replicate(NC,
                                                          source_dsa_invocation_id,
                                                          destination_dsa_guid,
                                                          rodc=rodc,
                                                          full_sync=full_sync,
                                                          exop=exop,
                                                          sync_forced=sync_forced)
            except Exception as e:
                raise CommandError("Error replicating DN %s" % NC, e)
        except BaseException:
            # Don't leave the transaction opened above dangling on failure.
            self.samdb.transaction_cancel()
            raise
        self.samdb.transaction_commit()

        if full_sync:
            self.message("Full Replication of all %d objects and %d links "
                         "from %s to %s was successful." %
                         (num_objects, num_links, SOURCE_DC,
                          self.local_samdb.url))
        else:
            self.message("Incremental replication of %d objects and %d links "
                         "from %s to %s was successful." %
                         (num_objects, num_links, SOURCE_DC,
                          self.local_samdb.url))
Example #57
0
    def run(self,
            username,
            password=None,
            credopts=None,
            sambaopts=None,
            versionopts=None,
            H=None,
            must_change_at_next_login=False,
            random_password=False,
            use_username_as_cn=False,
            userou=None,
            surname=None,
            given_name=None,
            initials=None,
            profile_path=None,
            script_path=None,
            home_drive=None,
            home_directory=None,
            job_title=None,
            department=None,
            company=None,
            description=None,
            mail_address=None,
            internet_address=None,
            telephone_number=None,
            physical_delivery_office=None,
            rfc2307_from_nss=False,
            uid=None,
            uid_number=None,
            gid_number=None,
            gecos=None,
            login_shell=None):
        """Create a new user account in the SAM database at ``H``.

        :param username: account name of the new user
        :param password: initial password; if it is empty and
            ``random_password`` is not set, the user is prompted for it twice
        :param random_password: generate a random 128-255 character password
        :param rfc2307_from_nss: fill unset RFC2307 attributes (uid,
            uidNumber, gidNumber, gecos, loginShell) from the local passwd entry
        :raises CommandError: if the account cannot be created
        """
        if random_password:
            password = generate_random_password(128, 255)

        # Prompt until a non-empty password has been typed identically twice.
        while not password:
            password = getpass("New Password: ")
            passwordverify = getpass("Retype Password: ")
            if password != passwordverify:
                password = None
                self.outf.write("Sorry, passwords do not match.\n")

        if rfc2307_from_nss:
            pwent = pwd.getpwnam(username)
            if uid is None:
                uid = username
            if uid_number is None:
                uid_number = pwent[2]
            if gid_number is None:
                gid_number = pwent[3]
            if gecos is None:
                gecos = pwent[4]
            if login_shell is None:
                login_shell = pwent[6]

        lp = sambaopts.get_loadparm()
        creds = credopts.get_credentials(lp)

        if uid_number or gid_number:
            if not lp.get("idmap_ldb:use rfc2307"):
                self.outf.write(
                    "You are setting a Unix/RFC2307 UID or GID. You may want to set 'idmap_ldb:use rfc2307 = Yes' to use those attributes for XID/SID-mapping.\n"
                )

        try:
            samdb = SamDB(url=H,
                          session_info=system_session(),
                          credentials=creds,
                          lp=lp)
            samdb.newuser(username,
                          password,
                          force_password_change_at_next_login_req=
                          must_change_at_next_login,
                          useusernameascn=use_username_as_cn,
                          userou=userou,
                          surname=surname,
                          givenname=given_name,
                          initials=initials,
                          profilepath=profile_path,
                          homedrive=home_drive,
                          scriptpath=script_path,
                          homedirectory=home_directory,
                          jobtitle=job_title,
                          department=department,
                          company=company,
                          description=description,
                          mailaddress=mail_address,
                          internetaddress=internet_address,
                          telephonenumber=telephone_number,
                          physicaldeliveryoffice=physical_delivery_office,
                          uid=uid,
                          uidnumber=uid_number,
                          gidnumber=gid_number,
                          gecos=gecos,
                          loginshell=login_shell)
        except Exception as e:
            raise CommandError("Failed to add user '%s': " % username, e)
Example #58
0
    def run(self,
            computername,
            credopts=None,
            sambaopts=None,
            versionopts=None,
            H=None,
            computerou=None,
            description=None,
            prepare_oldjoin=False,
            ip_address_list=None,
            service_principal_name_list=None):
        """Add a computer account, plus DNS records when IPs are supplied.

        :param computername: account name of the computer (a trailing '$'
            is stripped when deriving the DNS host name)
        :param ip_address_list: addresses to validate and register in DNS
        :param service_principal_name_list: SPNs to set on the account
        :raises CommandError: for an invalid IP address, or wrapping any
            failure while creating the account or its DNS records
        """
        addresses = [] if ip_address_list is None else ip_address_list
        spns = ([] if service_principal_name_list is None
                else service_principal_name_list)

        # Refuse malformed addresses before contacting the database.
        for addr in addresses:
            if _is_valid_ip(addr):
                continue
            raise CommandError('Invalid IP address {}'.format(addr))

        lp = sambaopts.get_loadparm()
        creds = credopts.get_credentials(lp)

        try:
            samdb = SamDB(url=H, session_info=system_session(),
                          credentials=creds, lp=lp)
            samdb.newcomputer(computername,
                              computerou=computerou,
                              description=description,
                              prepare_oldjoin=prepare_oldjoin,
                              ip_address_list=addresses,
                              service_principal_name_list=spns)

            if addresses:
                # DNS records are only created when addresses were given.
                host = re.sub(r"\$$", "", computername)
                if '$' in host:
                    raise CommandError('Illegal computername "%s"' %
                                       computername)

                search_filter = '(&(sAMAccountName={}$)(objectclass=computer))'.format(
                    ldb.binary_encode(host))
                found = samdb.search(base=samdb.domain_dn(),
                                     scope=ldb.SCOPE_SUBTREE,
                                     expression=search_filter,
                                     attrs=['primaryGroupID', 'objectSid'])
                account = found[0]
                primary_gid = account['primaryGroupID'][0]
                owner_sid = ndr_unpack(security.dom_sid, account["objectSid"][0])

                binding = "ncacn_ip_tcp:{}[sign]".format(samdb.host_dns_name())
                dns_conn = dnsserver.dnsserver(binding, lp, creds)

                # Records are owned by the computer account and its primary
                # group.
                owner_sd = security.descriptor()
                owner_sd.owner_sid = owner_sid
                owner_sd.group_sid = security.dom_sid(
                    "{}-{}".format(samdb.get_domain_sid(), primary_gid))

                add_dns_records(samdb, host, dns_conn, owner_sd,
                                samdb.host_dns_name(), addresses,
                                self.get_logger())
        except Exception as e:
            raise CommandError("Failed to add computer '%s': " % computername,
                               e)

        self.outf.write("Computer '%s' added successfully\n" % computername)
Example #59
0
class RodcTests(samba.tests.TestCase):
    """Check that writes against the DC at HOST are answered with referrals.

    Each test attempts a write (add, modify or delete) through ``self.samdb``
    and expects ``ldb.ERR_REFERRAL``.  The referral message must contain an
    ``ldap://`` URL that names a specific DC rather than just the domain;
    some tests additionally repeat the write against the referred-to DC and
    require it to succeed there.
    """

    def setUp(self):
        super(RodcTests, self).setUp()
        self.samdb = SamDB(HOST, credentials=CREDS,
                           session_info=system_session(LP), lp=LP)

        self.base_dn = self.samdb.domain_dn()

        root = self.samdb.search(base='', scope=ldb.SCOPE_BASE,
                                 attrs=['dsServiceName'])
        self.service = root[0]['dsServiceName'][0]
        # Random tag so the DNs of objects added by this run are unique.
        self.tag = uuid.uuid4().hex

    def _referral_address(self, emsg):
        """Return the ldap:// URL embedded in a referral error message.

        Fails the test if the message does not contain ``ldap://...>``.
        """
        match = re.search(r'(ldap://[^>]+)>', emsg)
        if match is None:
            self.fail("referral seems not to refer to anything")
        return match.group(1)

    def _assert_specific_dc(self, address):
        """Fail if the referral URL points at the bare domain DNS name.

        ``address`` always starts with "ldap://" (see _referral_address), so
        the host part must be compared: matching the whole URL against the
        domain name could never succeed and would make this check a no-op.
        """
        host = address[len("ldap://"):].lower()
        if host.startswith(self.samdb.domain_dns_name().lower()):
            self.fail("referral address did not give a specific DC")

    def _connect_to(self, address):
        """Open a SamDB connection to a referred-to URL with the test creds."""
        return SamDB(address, credentials=CREDS,
                     session_info=system_session(LP), lp=LP)

    def test_add_replicated_objects(self):
        objects = (
            {
                'dn': "ou=%s1,%s" % (self.tag, self.base_dn),
                "objectclass": "organizationalUnit"
            },
            {
                'dn': "cn=%s2,%s" % (self.tag, self.base_dn),
                "objectclass": "user"
            },
            {
                'dn': "cn=%s3,%s" % (self.tag, self.base_dn),
                "objectclass": "group"
            },
            {
                'dn': "cn=%s4,%s" % (self.tag, self.service),
                "objectclass": "NTDSConnection",
                'enabledConnection': 'TRUE',
                'fromServer': self.base_dn,
                'options': '0'
            },
        )
        for o in objects:
            try:
                self.samdb.add(o)
            except ldb.LdbError as e:
                (ecode, emsg) = e.args
            else:
                self.fail("Failed to fail to add %s" % o['dn'])

            if ecode != ldb.ERR_REFERRAL:
                print(emsg)
                self.fail("Adding %s: ldb error: %s %s, wanted referral" %
                          (o['dn'], ecode, emsg))

            address = self._referral_address(emsg)

            # The referred-to DC must accept the add; clean up afterwards.
            try:
                tmpdb = self._connect_to(address)
                tmpdb.add(o)
                tmpdb.delete(o['dn'])
            except ldb.LdbError:
                self.fail("couldn't modify referred location %s" %
                          address)

            self._assert_specific_dc(address)

    def test_modify_replicated_attributes(self):
        # ordinary replicated attributes of the Guest account
        dn = 'CN=Guest,CN=Users,' + self.base_dn
        value = 'hallooo'
        for attr in ['carLicense', 'middleName']:
            msg = ldb.Message()
            msg.dn = ldb.Dn(self.samdb, dn)
            msg[attr] = ldb.MessageElement(value,
                                           ldb.FLAG_MOD_REPLACE,
                                           attr)
            try:
                self.samdb.modify(msg)
            except ldb.LdbError as e:
                (ecode, emsg) = e.args
            else:
                self.fail("Failed to fail to modify %s %s" % (dn, attr))

            if ecode != ldb.ERR_REFERRAL:
                self.fail("Failed to REFER when trying to modify %s %s" %
                          (dn, attr))

            address = self._referral_address(emsg)

            # The referred-to DC must accept the same modification.
            try:
                self._connect_to(address).modify(msg)
            except ldb.LdbError:
                self.fail("couldn't modify referred location %s" %
                          address)

            self._assert_specific_dc(address)

    def test_modify_nonreplicated_attributes(self):
        # some non-replicated attributes of the Guest account
        dn = 'CN=Guest,CN=Users,' + self.base_dn
        value = '123456789'
        for attr in ['badPwdCount', 'lastLogon', 'lastLogoff']:
            msg = ldb.Message()
            msg.dn = ldb.Dn(self.samdb, dn)
            msg[attr] = ldb.MessageElement(value,
                                           ldb.FLAG_MOD_REPLACE,
                                           attr)
            # Windows refers these ones even though they are non-replicated
            try:
                self.samdb.modify(msg)
            except ldb.LdbError as e:
                (ecode, emsg) = e.args
            else:
                self.fail("Failed to fail to modify %s %s" % (dn, attr))

            if ecode != ldb.ERR_REFERRAL:
                self.fail("Failed to REFER when trying to modify %s %s" %
                          (dn, attr))

            address = self._referral_address(emsg)
            self._assert_specific_dc(address)

    def test_modify_nonreplicated_reps_attributes(self):
        # repsFrom on the domain head is non-replicated
        dn = self.base_dn

        msg = ldb.Message()
        msg.dn = ldb.Dn(self.samdb, dn)
        attr = 'repsFrom'

        res = self.samdb.search(dn, scope=ldb.SCOPE_BASE,
                                attrs=['repsFrom'])
        rep = ndr_unpack(drsblobs.repsFromToBlob, res[0]['repsFrom'][0],
                         allow_remaining=True)
        rep.ctr.result_last_attempt = -1
        value = ndr_pack(rep)

        msg[attr] = ldb.MessageElement(value,
                                       ldb.FLAG_MOD_REPLACE,
                                       attr)
        try:
            self.samdb.modify(msg)
        except ldb.LdbError as e:
            (ecode, emsg) = e.args
        else:
            self.fail("Failed to fail to modify %s %s" % (dn, attr))

        if ecode != ldb.ERR_REFERRAL:
            self.fail("Failed to REFER when trying to modify %s %s" %
                      (dn, attr))

        address = self._referral_address(emsg)
        self._assert_specific_dc(address)

    def test_delete_special_objects(self):
        dn = 'CN=Guest,CN=Users,' + self.base_dn
        try:
            self.samdb.delete(dn)
        except ldb.LdbError as e:
            (ecode, emsg) = e.args
        else:
            self.fail("Failed to fail to delete %s" % (dn))

        if ecode != ldb.ERR_REFERRAL:
            print(ecode, emsg)
            self.fail("Failed to REFER when trying to delete %s" % dn)

        address = self._referral_address(emsg)
        self._assert_specific_dc(address)

    def test_no_delete_nonexistent_objects(self):
        dn = 'CN=does-not-exist-%s,CN=Users,%s' % (self.tag, self.base_dn)
        try:
            self.samdb.delete(dn)
        except ldb.LdbError as e:
            (ecode, emsg) = e.args
        else:
            self.fail("Failed to fail to delete %s" % (dn))

        # A missing object must be reported as such, not referred.
        if ecode != ldb.ERR_NO_SUCH_OBJECT:
            print(ecode, emsg)
            self.fail("Failed to NO_SUCH_OBJECT when trying to delete "
                      "%s (which does not exist)" % dn)
Example #60
0
class GroupAuditTests(AuditLogTestBase):

    def setUp(self):
        """Connect to the server over LDAP and create the test user and groups.

        Reads CLIENT_IP, SERVER_IP and SERVER from the environment.  It
        temporarily sets dSHeuristics to "000000001" and minPwdAge to "0",
        and registers cleanups that restore both previous values.  It then
        adds USER_NAME (with USER_PASS) and the groups GROUP_NAME_01 and
        GROUP_NAME_02.
        """
        # The base class reads these during its own setUp.
        self.message_type = MSG_GROUP_LOG
        self.event_type = DSDB_GROUP_EVENT_NAME
        super(GroupAuditTests, self).setUp()

        self.remoteAddress = os.environ["CLIENT_IP"]
        self.server_ip = os.environ["SERVER_IP"]
        self.server = os.environ["SERVER"]

        host = "ldap://%s" % self.server
        self.ldb = SamDB(url=host,
                         session_info=system_session(),
                         credentials=self.get_credentials(),
                         lp=self.get_loadparm())

        # Gets back the basedn (fetched once; it does not change below)
        self.base_dn = self.ldb.domain_dn()

        # Get the old "dSHeuristics" if it was set
        dsheuristics = self.ldb.get_dsheuristics()

        # Set the "dSHeuristics" to activate the correct "userPassword"
        # behaviour
        self.ldb.set_dsheuristics("000000001")

        # Reset the "dSHeuristics" as they were before
        self.addCleanup(self.ldb.set_dsheuristics, dsheuristics)

        # Get the old "minPwdAge"
        minPwdAge = self.ldb.get_minPwdAge()

        # Set it temporarily to "0"
        self.ldb.set_minPwdAge("0")

        # Reset the "minPwdAge" as it was before
        self.addCleanup(self.ldb.set_minPwdAge, minPwdAge)

        # (Re)adds the test user USER_NAME with password USER_PASS
        self.ldb.add({
            "dn": "cn=" + USER_NAME + ",cn=users," + self.base_dn,
            "objectclass": "user",
            "sAMAccountName": USER_NAME,
            "userPassword": USER_PASS
        })
        self.ldb.newgroup(GROUP_NAME_01)
        self.ldb.newgroup(GROUP_NAME_02)

    def tearDown(self):
        """Run base-class teardown, then delete the user and both test groups.

        The user is removed with delete_force; the two groups created in
        setUp are removed with deletegroup, in the same order they were made.
        """
        super(GroupAuditTests, self).tearDown()
        test_user_dn = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        delete_force(self.ldb, test_user_dn)
        for group_name in (GROUP_NAME_01, GROUP_NAME_02):
            self.ldb.deletegroup(group_name)

    def test_add_and_remove_users_from_group(self):
        """Group-membership audit records are emitted for each change.

        First checks the two records produced when setUp created the user:
        a PrimaryGroup record followed by an Added record, both for Domain
        Users.  Then it adds the user to GROUP_NAME_01 and GROUP_NAME_02,
        removes it from GROUP_NAME_01, and re-adds it.  It expects exactly one
        record per operation, with the right action, user, group, remote
        address and session.
        """
        user_dn = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        domain_users_dn = "cn=domain users,cn=users," + self.base_dn
        group_01_dn = "cn=" + GROUP_NAME_01 + ",cn=users," + self.base_dn
        group_02_dn = "cn=" + GROUP_NAME_02 + ",cn=users," + self.base_dn

        def wait_for(count):
            """Wait for audit messages and require exactly `count` of them."""
            messages = self.waitForMessages(count)
            print("Received %d messages" % len(messages))
            self.assertEqual(count,
                             len(messages),
                             "Did not receive the expected number of messages")
            return messages

        def check_change(message, action, group_dn, check_service=True):
            """Verify the common fields of one groupChange record.

            Returns the groupChange dict so callers can check extra fields.
            """
            audit = message["groupChange"]
            self.assertEqual(action, audit["action"])
            # assertEqual, not assertTrue(a, b): assertTrue treats b as the
            # failure message, so it passes for any non-empty a.
            self.assertEqual(user_dn.lower(), audit["user"].lower())
            self.assertEqual(group_dn.lower(), audit["group"].lower())
            self.assertRegex(audit["remoteAddress"], self.remoteAddress)
            self.assertTrue(self.is_guid(audit["sessionId"]))
            self.assertEqual(self.get_session(), audit["sessionId"])
            if check_service:
                self.assertEqual(self.get_service_description(), "LDAP")
            return audit

        #
        # Wait for the primary group change for the created user.
        #
        messages = wait_for(2)
        check_change(messages[0], "PrimaryGroup", domain_users_dn)

        # Check the Add message for the new user's primary group
        audit = check_change(messages[1], "Added", domain_users_dn,
                             check_service=False)
        self.assertEqual(EVT_ID_USER_ADDED_TO_GLOBAL_SEC_GROUP,
                         audit["eventId"])

        #
        # Add the user to a group
        #
        self.discardMessages()
        self.ldb.add_remove_group_members(GROUP_NAME_01, [USER_NAME])
        check_change(wait_for(1)[0], "Added", group_01_dn)

        #
        # Add the user to another group
        #
        self.discardMessages()
        self.ldb.add_remove_group_members(GROUP_NAME_02, [USER_NAME])
        check_change(wait_for(1)[0], "Added", group_02_dn)

        #
        # Remove the user from a group
        #
        self.discardMessages()
        self.ldb.add_remove_group_members(
            GROUP_NAME_01,
            [USER_NAME],
            add_members_operation=False)
        check_change(wait_for(1)[0], "Removed", group_01_dn)

        #
        # Re-add the user to a group
        #
        self.discardMessages()
        self.ldb.add_remove_group_members(GROUP_NAME_01, [USER_NAME])
        check_change(wait_for(1)[0], "Added", group_01_dn)

    def test_change_primary_group(self):

        #
        # Wait for the primary group change for the created user.
        #
        messages = self.waitForMessages(2)
        print("Received %d messages" % len(messages))
        self.assertEquals(2,
                          len(messages),
                          "Did not receive the expected number of messages")

        # Check the PrimaryGroup message
        audit = messages[0]["groupChange"]

        self.assertEqual("PrimaryGroup", audit["action"])
        user_dn = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        group_dn = "cn=domain users,cn=users," + self.base_dn
        self.assertTrue(user_dn.lower(), audit["user"].lower())
        self.assertTrue(group_dn.lower(), audit["group"].lower())
        self.assertRegexpMatches(audit["remoteAddress"],
                                 self.remoteAddress)
        self.assertTrue(self.is_guid(audit["sessionId"]))
        session_id = self.get_session()
        self.assertEquals(session_id, audit["sessionId"])
        service_description = self.get_service_description()
        self.assertEquals(service_description, "LDAP")

        # Check the Add message for the new users primary group
        audit = messages[1]["groupChange"]

        self.assertEqual("Added", audit["action"])
        user_dn = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        group_dn = "cn=domain users,cn=users," + self.base_dn
        self.assertTrue(user_dn.lower(), audit["user"].lower())
        self.assertTrue(group_dn.lower(), audit["group"].lower())
        self.assertRegexpMatches(audit["remoteAddress"],
                                 self.remoteAddress)
        self.assertTrue(self.is_guid(audit["sessionId"]))
        session_id = self.get_session()
        self.assertEquals(session_id, audit["sessionId"])
        self.assertEquals(EVT_ID_USER_ADDED_TO_GLOBAL_SEC_GROUP,
                          audit["eventId"])

        #
        # Add the user to a group, the user needs to be a member of a group
        # before there primary group can be set to that group.
        #
        self.discardMessages()

        self.ldb.add_remove_group_members(GROUP_NAME_01, [USER_NAME])
        messages = self.waitForMessages(1)
        print("Received %d messages" % len(messages))
        self.assertEquals(1,
                          len(messages),
                          "Did not receive the expected number of messages")
        audit = messages[0]["groupChange"]

        self.assertEqual("Added", audit["action"])
        user_dn = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        group_dn = "cn=" + GROUP_NAME_01 + ",cn=users," + self.base_dn
        self.assertTrue(user_dn.lower(), audit["user"].lower())
        self.assertTrue(group_dn.lower(), audit["group"].lower())
        self.assertRegexpMatches(audit["remoteAddress"],
                                 self.remoteAddress)
        self.assertTrue(self.is_guid(audit["sessionId"]))
        session_id = self.get_session()
        self.assertEquals(session_id, audit["sessionId"])
        service_description = self.get_service_description()
        self.assertEquals(service_description, "LDAP")
        self.assertEquals(EVT_ID_USER_ADDED_TO_GLOBAL_SEC_GROUP,
                          audit["eventId"])

        #
        # Change the primary group of a user
        #
        user_dn = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        group_dn = "cn=" + GROUP_NAME_01 + ",cn=users," + self.base_dn
        # get the primaryGroupToken of the group
        res = self.ldb.search(base=group_dn, attrs=["primaryGroupToken"],
                              scope=ldb.SCOPE_BASE)
        group_id = res[0]["primaryGroupToken"]

        # set primaryGroupID attribute of the user to that group
        m = ldb.Message()
        m.dn = ldb.Dn(self.ldb, user_dn)
        m["primaryGroupID"] = ldb.MessageElement(
            group_id,
            FLAG_MOD_REPLACE,
            "primaryGroupID")
        self.discardMessages()
        self.ldb.modify(m)

        #
        # Wait for the primary group change.
        # Will see the user removed from the new group
        #          the user added to their old primary group
        #          and a new primary group event.
        #
        messages = self.waitForMessages(3)
        print("Received %d messages" % len(messages))
        self.assertEquals(3,
                          len(messages),
                          "Did not receive the expected number of messages")

        audit = messages[0]["groupChange"]
        self.assertEqual("Removed", audit["action"])
        user_dn = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        group_dn = "cn=" + GROUP_NAME_01 + ",cn=users," + self.base_dn
        self.assertTrue(user_dn.lower(), audit["user"].lower())
        self.assertTrue(group_dn.lower(), audit["group"].lower())
        self.assertRegexpMatches(audit["remoteAddress"],
                                 self.remoteAddress)
        self.assertTrue(self.is_guid(audit["sessionId"]))
        session_id = self.get_session()
        self.assertEquals(session_id, audit["sessionId"])
        service_description = self.get_service_description()
        self.assertEquals(service_description, "LDAP")
        self.assertEquals(EVT_ID_USER_REMOVED_FROM_GLOBAL_SEC_GROUP,
                          audit["eventId"])

        audit = messages[1]["groupChange"]

        self.assertEqual("Added", audit["action"])
        user_dn = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        group_dn = "cn=domain users,cn=users," + self.base_dn
        self.assertTrue(user_dn.lower(), audit["user"].lower())
        self.assertTrue(group_dn.lower(), audit["group"].lower())
        self.assertRegexpMatches(audit["remoteAddress"],
                                 self.remoteAddress)
        self.assertTrue(self.is_guid(audit["sessionId"]))
        session_id = self.get_session()
        self.assertEquals(session_id, audit["sessionId"])
        service_description = self.get_service_description()
        self.assertEquals(service_description, "LDAP")
        self.assertEquals(EVT_ID_USER_ADDED_TO_GLOBAL_SEC_GROUP,
                          audit["eventId"])

        audit = messages[2]["groupChange"]

        self.assertEqual("PrimaryGroup", audit["action"])
        user_dn = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        group_dn = "cn=" + GROUP_NAME_01 + ",cn=users," + self.base_dn
        self.assertTrue(user_dn.lower(), audit["user"].lower())
        self.assertTrue(group_dn.lower(), audit["group"].lower())
        self.assertRegexpMatches(audit["remoteAddress"],
                                 self.remoteAddress)
        self.assertTrue(self.is_guid(audit["sessionId"]))
        session_id = self.get_session()
        self.assertEquals(session_id, audit["sessionId"])
        service_description = self.get_service_description()
        self.assertEquals(service_description, "LDAP")