def test_purge_deleted_files(self):
        files = set(
            File(f)
            for f in ('dns.2014.Y.mtbl', 'dns.201501.M.mtbl',
                      'dns.20150201.W.mtbl', 'dns.20150208.D.mtbl',
                      'dns.20150209.0000.H.mtbl', 'dns.20150209.0100.X.mtbl',
                      'dns.20150209.0110.m.mtbl'))

        class Fail(Exception):
            pass

        to_delete = set(os.path.join(self.td, fn.name) for fn in files)

        for fn in set(to_delete):
            for extension in DIGEST_EXTENSIONS:
                to_delete.add('{}.{}'.format(fn, extension))

        def my_unlink(fn):
            self.assertIn(fn, to_delete)
            to_delete.remove(fn)

        os.unlink = my_unlink

        fs = Fileset(None, self.td)
        fs.pending_deletions = files
        fs.purge_deleted_files()

        self.assertItemsEqual(fs.pending_deletions, [])
        self.assertItemsEqual(to_delete, [])
    def test_load_remote_fileset_apikey(self):
        """Fileset must send its API key as an X-API-Key request header."""
        fileset_uri = 'http://example.com/dns.fileset'
        files = ('dns.2014.Y.mtbl', 'dns.201501.M.mtbl', 'dns.20150201.W.mtbl',
                 'dns.20150208.D.mtbl', 'dns.20150209.0000.H.mtbl',
                 'dns.20150209.0100.X.mtbl', 'dns.20150209.0110.m.mtbl')
        apikey = 'TEST APIKEY'

        headers = []

        def my_urlopen(obj, timeout=None):
            # Record the outgoing request headers so we can assert on them
            # after load_remote_fileset() runs.
            headers.extend(obj.header_items())
            uri = get_uri(obj)
            self.assertEqual(uri, fileset_uri)
            # Serve the fileset body with matching Content-Length and a
            # valid SHA-256 Digest header.
            fp = StringIO('\n'.join(files + ('', )))
            digest = base64.b64encode(hashlib.sha256(fp.getvalue()).digest())
            msg = httplib.HTTPMessage(fp=StringIO(
                'Content-Length: {}\r\nDigest: SHA-256={}'.format(
                    len(fp.getvalue()), digest)),
                                      seekable=True)
            return urllib.addinfourl(fp, msg, uri)

        # Patch urlopen and restore it afterwards so later tests are not
        # affected (the original code leaked this patch).
        real_urlopen = urllib2.urlopen
        self.addCleanup(setattr, urllib2, 'urlopen', real_urlopen)
        urllib2.urlopen = my_urlopen

        fs = Fileset(fileset_uri, self.td, apikey=apikey)
        self.assertEqual(fs.apikey, apikey)
        fs.load_remote_fileset()

        for k, v in headers:
            if k.lower() == 'x-api-key':
                self.assertEqual(v, apikey)
                break
        else:
            self.fail('X-API-Key header missing')

        self.assertItemsEqual(fs.remote_files, (File(f) for f in files))
    def __init__(self,
                 fileset_uri,
                 destination,
                 base=None,
                 extension='mtbl',
                 frequency=1800,
                 download_timeout=None,
                 retry_timeout=60,
                 apikey=None,
                 validator=None,
                 digest_required=True,
                 minimal=True,
                 download_manager=None):
        """Set up state for mirroring the fileset at *fileset_uri* into the
        local directory *destination*.

        Raises OSError (ENOENT) if *destination* is not a directory.
        """
        self.fileset_uri = fileset_uri

        if not os.path.isdir(destination):
            raise OSError(errno.ENOENT,
                          'Not a directory: \'{}\''.format(destination))
        self.destination = destination

        if base:
            self.base = base
        else:
            # No explicit base: derive it from the URI's path component,
            # e.g. 'http://host/dns.fileset' -> 'dns'.
            uri_path = urlparse.urlsplit(fileset_uri)[2]
            self.base = os.path.splitext(os.path.basename(uri_path))[0]

        self.extension = extension
        self.frequency = frequency
        self.download_timeout = download_timeout
        self.retry_timeout = retry_timeout
        self.minimal = minimal

        # The Fileset tracks remote/local file state for this manager.
        self.fileset = Fileset(uri=self.fileset_uri,
                               dname=self.destination,
                               base=self.base,
                               extension=self.extension,
                               apikey=apikey,
                               validator=validator,
                               timeout=download_timeout,
                               digest_required=digest_required)

        if download_manager:
            self.download_manager = download_manager
        else:
            # No shared manager supplied: create and start a private one.
            self.download_manager = DownloadManager(
                download_timeout=download_timeout, retry_timeout=retry_timeout)
            self.download_manager.start()

        self.thread = None
    def test_list_tempfiles(self):
        """list_temporary_files() must report exactly the '.<name>.*'
        mkstemp-style files, not data files or their '.<name>' siblings."""
        fs = Fileset(None, self.td)
        files = set(
            ('dns.2014.Y.mtbl', 'dns.201501.M.mtbl', 'dns.20150201.W.mtbl',
             'dns.20150208.D.mtbl', 'dns.20150209.0000.H.mtbl',
             'dns.20150209.0100.X.mtbl', 'dns.20150209.0110.m.mtbl'))
        for fn in files:
            # Close the handles right away; the originals were leaked.
            open(os.path.join(self.td, fn), 'w').close()
            open(os.path.join(self.td, '.{}'.format(fn)), 'w').close()

        # mkstemp returns an open OS-level fd; close it to avoid leaking
        # one descriptor per file (the original never closed them).
        tempfiles = set()
        for fn in files:
            fd, path = tempfile.mkstemp(dir=self.td, prefix='.{}.'.format(fn))
            os.close(fd)
            tempfiles.add(path)

        self.assertItemsEqual(fs.list_temporary_files(), tempfiles)
    def test_load_local_fileset(self):
        """load_local_fileset() must pick up every data file on disk into
        both all_local_files and minimal_local_files."""
        fileset = ('dns.2014.Y.mtbl', 'dns.201501.M.mtbl',
                   'dns.20150201.W.mtbl', 'dns.20150208.D.mtbl',
                   'dns.20150209.0000.H.mtbl', 'dns.20150209.0100.X.mtbl',
                   'dns.20150209.0110.m.mtbl')

        for fn in fileset:
            # Close the handle right away; the original leaked it.
            open(os.path.join(self.td, fn), 'w').close()

        fs = Fileset(None, self.td)
        fs.load_local_fileset()

        self.assertItemsEqual(fs.all_local_files, (File(fn) for fn in fileset))
        self.assertItemsEqual(fs.minimal_local_files,
                              (File(fn) for fn in fileset))
    def test_prune_redundant_files(self):
        """prune_redundant_files() moves files whose period is covered by a
        larger one from minimal_local_files into pending_deletions."""
        keep = set(
            File(name)
            for name in ('dns.2014.Y.mtbl', 'dns.201501.M.mtbl',
                         'dns.20150201.W.mtbl', 'dns.20150208.D.mtbl',
                         'dns.20150209.0000.H.mtbl',
                         'dns.20150209.0100.X.mtbl',
                         'dns.20150209.0110.m.mtbl'))
        drop = set(
            File(name)
            for name in ('dns.201401.M.mtbl', 'dns.20150108.W.mtbl',
                         'dns.20150202.D.mtbl', 'dns.20150208.0100.H.mtbl',
                         'dns.20150209.0020.X.mtbl',
                         'dns.20150209.0109.m.mtbl'))

        fs = Fileset(None, self.td)
        fs.minimal_local_files = keep | drop
        fs.prune_redundant_files()

        # Redundant entries were pruned and queued for deletion.
        self.assertItemsEqual(fs.minimal_local_files, keep)
        self.assertItemsEqual(fs.pending_deletions, drop)
    def test_load_remote_fileset_bad_content_length(self):
        """A Content-Length that disagrees with the body must raise
        FilesetError matching 'content length mismatch'."""
        fileset_uri = 'http://example.com/dns.fileset'
        files = ('dns.2014.Y.mtbl', 'dns.201501.M.mtbl', 'dns.20150201.W.mtbl',
                 'dns.20150208.D.mtbl', 'dns.20150209.0000.H.mtbl',
                 'dns.20150209.0100.X.mtbl', 'dns.20150209.0110.m.mtbl')

        def my_urlopen(obj, timeout=None):
            uri = get_uri(obj)
            self.assertEqual(uri, fileset_uri)
            fp = StringIO('\n'.join(files + ('', )))
            # Advertise one byte more than the body actually contains.
            msg = httplib.HTTPMessage(fp=StringIO(
                'Content-Length: {}'.format(len(fp.getvalue()) + 1)),
                                      seekable=True)
            return urllib.addinfourl(fp, msg, uri)

        # Patch urlopen and restore it afterwards so later tests are not
        # affected (the original code leaked this patch).
        real_urlopen = urllib2.urlopen
        self.addCleanup(setattr, urllib2, 'urlopen', real_urlopen)
        urllib2.urlopen = my_urlopen

        fs = Fileset(fileset_uri, self.td)
        with self.assertRaisesRegexp(FilesetError, r'content length mismatch'):
            fs.load_remote_fileset()
    def test_write_local_fileset_full(self):
        """write_local_fileset(minimal=False) must write the full fileset
        (including redundant files) to '<base>-full.fileset' and must not
        touch the minimal '<base>.fileset'."""
        files = set(
            File(f)
            for f in ('dns.2014.Y.mtbl', 'dns.201501.M.mtbl',
                      'dns.20150201.W.mtbl', 'dns.20150208.D.mtbl',
                      'dns.20150209.0000.H.mtbl', 'dns.20150209.0100.X.mtbl',
                      'dns.20150209.0110.m.mtbl'))
        redundant = set(
            File(f)
            for f in ('dns.201401.M.mtbl', 'dns.20150108.W.mtbl',
                      'dns.20150202.D.mtbl', 'dns.20150208.0100.H.mtbl',
                      'dns.20150209.0020.X.mtbl', 'dns.20150209.0109.m.mtbl'))

        fs = Fileset(None, self.td)
        fs.all_local_files = files.union(redundant)
        fs.minimal_local_files = files

        fs.write_local_fileset(minimal=False)

        fileset_path = os.path.join(self.td, 'dns.fileset')
        full_fileset_path = os.path.join(self.td, 'dns-full.fileset')

        self.assertFalse(os.path.exists(fileset_path))
        self.assertTrue(os.path.exists(full_fileset_path))

        # Close the handle deterministically (the original leaked it).
        with open(full_fileset_path) as fileset_fp:
            fileset = set(File(f.strip()) for f in fileset_fp)

        self.assertItemsEqual(files.union(redundant), fileset)
    def test_load_remote_fileset(self):
        """load_remote_fileset() must parse the served fileset body into
        remote_files when length and digest headers are consistent."""
        fileset_uri = 'http://example.com/dns.fileset'
        files = ('dns.2014.Y.mtbl', 'dns.201501.M.mtbl', 'dns.20150201.W.mtbl',
                 'dns.20150208.D.mtbl', 'dns.20150209.0000.H.mtbl',
                 'dns.20150209.0100.X.mtbl', 'dns.20150209.0110.m.mtbl')

        def my_urlopen(obj, timeout=None):
            uri = get_uri(obj)
            self.assertEqual(uri, fileset_uri)
            # Serve the fileset body with matching Content-Length and a
            # valid SHA-256 Digest header.
            fp = StringIO('\n'.join(files + ('', )))
            digest = base64.b64encode(hashlib.sha256(fp.getvalue()).digest())
            msg = httplib.HTTPMessage(fp=StringIO(
                'Content-Length: {}\r\nDigest: SHA-256={}'.format(
                    len(fp.getvalue()), digest)),
                                      seekable=True)
            return urllib.addinfourl(fp, msg, uri)

        # Patch urlopen and restore it afterwards so later tests are not
        # affected (the original code leaked this patch).
        real_urlopen = urllib2.urlopen
        self.addCleanup(setattr, urllib2, 'urlopen', real_urlopen)
        urllib2.urlopen = my_urlopen

        fs = Fileset(fileset_uri, self.td)
        fs.load_remote_fileset()

        self.assertItemsEqual(fs.remote_files, (File(f) for f in files))
    def test_prune_obsolete_files_full(self):
        """prune_obsolete_files(minimal=False) queues for deletion the local
        files absent from the remote fileset, while keeping local files
        still covered by remote coarser-period entries."""
        remote = {
            File(name) for name in (
                'dns.2014.Y.mtbl',
                'dns.201401.M.mtbl',
                'dns.20140201.D.mtbl',
                'dns.20140201.0000.H.mtbl',
                'dns.20140201.0100.X.mtbl',
                'dns.20140201.0110.m.mtbl',
                'dns.2015.Y.mtbl',
            )}
        survivors = {
            File(name) for name in (
                'dns.2014.Y.mtbl',
                'dns.201401.M.mtbl',
                'dns.20140201.D.mtbl',
                'dns.20140201.0000.H.mtbl',
                'dns.20140201.0100.X.mtbl',
                'dns.20140201.0110.m.mtbl',
                'dns.201501.M.mtbl',
                'dns.20150201.W.mtbl',
                'dns.20150208.D.mtbl',
                'dns.20150209.0000.H.mtbl',
                'dns.20150209.0100.X.mtbl',
                'dns.20150209.0110.m.mtbl',
            )}
        stale = {
            File(name) for name in (
                'dns.2012.Y.mtbl',
                'dns.20130108.W.mtbl',
                'dns.20130202.D.mtbl',
                'dns.20130208.0100.H.mtbl',
                'dns.20130209.0020.X.mtbl',
                'dns.20130209.0109.m.mtbl',
                'dns.20150101.D.mtbl',
                'dns.20150101.0000.H.mtbl',
                'dns.20150101.0100.X.mtbl',
                'dns.20150101.0110.m.mtbl',
            )}

        fs = Fileset(None, self.td)
        fs.all_local_files = survivors | stale
        fs.minimal_local_files = survivors | stale
        fs.remote_files = remote
        fs.prune_obsolete_files(minimal=False)

        # Only the stale files were pruned; remote_files is untouched.
        self.assertItemsEqual(fs.all_local_files, survivors)
        self.assertItemsEqual(fs.minimal_local_files, survivors)
        self.assertItemsEqual(fs.remote_files, remote)
        self.assertItemsEqual(fs.pending_deletions, stale)
    def test_missing_files(self):
        """missing_files() reports remote entries that have no local copy."""
        present = set(
            File(name)
            for name in ('dns.2014.Y.mtbl', 'dns.201501.M.mtbl',
                         'dns.20150201.W.mtbl', 'dns.20150208.D.mtbl',
                         'dns.20150209.0000.H.mtbl',
                         'dns.20150209.0100.X.mtbl',
                         'dns.20150209.0110.m.mtbl'))
        absent = set(
            File(name)
            for name in ('dns.2012.Y.mtbl', 'dns.20130108.W.mtbl',
                         'dns.20130202.D.mtbl', 'dns.20130208.0100.H.mtbl',
                         'dns.20130209.0020.X.mtbl',
                         'dns.20130209.0109.m.mtbl'))

        fs = Fileset(None, self.td)
        fs.all_local_files = set(present)
        fs.minimal_local_files = set(present)
        fs.remote_files = present | absent

        self.assertItemsEqual(fs.missing_files(), absent)
class DNSTableManager:
    def __init__(self,
                 fileset_uri,
                 destination,
                 base=None,
                 extension='mtbl',
                 frequency=1800,
                 download_timeout=None,
                 retry_timeout=60,
                 apikey=None,
                 validator=None,
                 digest_required=True,
                 minimal=True,
                 download_manager=None):
        self.fileset_uri = fileset_uri

        if not os.path.isdir(destination):
            raise OSError(errno.ENOENT,
                          'Not a directory: \'{}\''.format(destination))

        self.destination = destination

        if base:
            self.base = base
        else:
            self.base = os.path.splitext(
                os.path.basename(urlparse.urlsplit(fileset_uri)[2]))[0]

        self.extension = extension
        self.frequency = frequency
        self.download_timeout = download_timeout
        self.retry_timeout = retry_timeout
        self.minimal = minimal

        self.fileset = Fileset(uri=self.fileset_uri,
                               dname=self.destination,
                               base=self.base,
                               extension=self.extension,
                               apikey=apikey,
                               validator=validator,
                               timeout=download_timeout,
                               digest_required=digest_required)

        if download_manager:
            self.download_manager = download_manager
        else:
            self.download_manager = DownloadManager(
                download_timeout=download_timeout, retry_timeout=retry_timeout)
            self.download_manager.start()

        self.thread = None

    def start(self):
        if self.thread:
            raise Exception

        self.thread = threading.Thread(target=self.run)
        self.thread.setDaemon(True)
        self.thread.start()

    def join(self):
        if not self.thread:
            raise Exception

        self.thread.join()
        self.thread = None

    def run(self):
        next_remote_load = 0
        while True:
            now = time.time()
            self.fileset.load_local_fileset()

            try:
                if now >= next_remote_load:
                    self.fileset.load_remote_fileset()
                    next_remote_load = now + self.frequency
            except (FilesetError, urllib2.URLError, urllib2.HTTPError,
                    httplib.HTTPException, socket.error) as e:
                logger.error('Failed to load remote fileset {}: {}'.format(
                    self.fileset_uri, str(e)))
                logger.debug(traceback.format_exc())
                next_remote_load = now + self.retry_timeout

            for f in sorted(self.fileset.missing_files(), reverse=True):
                if f not in self.download_manager:
                    self.download_manager.enqueue(f)

            self.fileset.prune_obsolete_files(minimal=self.minimal)
            self.fileset.prune_redundant_files(minimal=self.minimal)

            try:
                self.fileset.write_local_fileset()
                if not self.minimal:
                    self.fileset.write_local_fileset(minimal=False)
            except (IOError, OSError) as e:
                logger.error('Failed to write fileset {}: {}'.format(
                    self.fileset.get_fileset_name(), str(e)))
                logger.debug(traceback.format_exc())

            try:
                self.fileset.purge_deleted_files()
            except OSError as e:
                logger.error('Failed to purge deleted files in {}: {}'.format(
                    self.destination, str(e)))
                logger.debug(traceback.format_exc())

            time.sleep(1)

    def clean_tempfiles(self):
        open_files = set()
        for p in psutil.process_iter():
            try:
                try:
                    func = p.get_open_files
                except AttributeError:
                    func = p.open_files

                for f in func():
                    open_files.add(f.path)
            except psutil.AccessDenied:
                pass

        for filename in self.fileset.list_temporary_files():
            if filename in open_files:
                logger.debug(
                    'Not unlinking tempfile {!r}: In use.'.format(filename))
                continue

            logger.debug('Unlinking tempfile: {!r}'.format(filename))
            os.unlink(filename)