예제 #1
0
 def __init__(self,
              dataRemote: bool = False,
              allowedDirs: List[str] = None,
              allowedFileTypes: List[str] = None):
     """
     Args:
         dataRemote (bool): whether data will be served from the local instance or requests forwarded
             to a remote instance for handling.
         allowedDirs (list): list of directories from which files are allowed to be read/written. File
             operations will not be permitted unless the file path is a child of an allowed directory.
             Trailing '/' characters are stripped from each entry.
         allowedFileTypes (list): list of file extensions, such as '.dcm', '.txt', for which file
             operations are permitted. No file operations will be done unless the file extension matches
             one on the list. If the first entry is '*' the list is kept as-is; otherwise entries given
             without a leading '.' get one prepended. The caller's list is not modified.
     """
     super().__init__(isRemote=dataRemote)
     if dataRemote is True:
         # Remote mode: requests are forwarded elsewhere, so no local state is set up.
         return
     self.initWatchSet = False
     self.watchDir = None
     self.currentStreamId = 0
     self.streamInfo = None
     # Remove trailing slash from dir names
     self.allowedDirs = allowedDirs
     if allowedDirs is not None:
         self.allowedDirs = [allowedDir.rstrip('/') for allowedDir in allowedDirs]
     # Build a new normalized list instead of editing the caller's list in place (it may be a
     # shared default); an empty list is kept empty rather than raising IndexError.
     self.allowedFileTypes = None
     if allowedFileTypes is not None:
         fileTypes = list(allowedFileTypes)
         if fileTypes and fileTypes[0] != '*':
             # make sure allowed file extensions start with '.'
             fileTypes = [ft if ft.startswith('.') else '.' + ft for ft in fileTypes]
         self.allowedFileTypes = fileTypes
     self.fileWatchLock = threading.Lock()
     # instantiate local FileWatcher
     self.fileWatcher = FileWatcher()
예제 #2
0
 def __init__(self, filesremote=False, commPipes=None):
     """Prepare file access that is either local or forwarded over command pipes.

     Args:
         filesremote: when truthy, no local FileWatcher is created and
             self.local is False.
         commPipes: stored as-is for use by remote requests.
     """
     self.local = not filesremote
     self.commPipes = commPipes
     self.initWatchSet = False
     # Default to None so the attribute is always defined.
     self.fileWatcher = None
     if not self.local:
         return
     self.fileWatcher = FileWatcher()
예제 #3
0
class DataInterface(RemoteableExtensible):
    """
    Provides functions for accessing remote or local files depending on whether the dataRemote
    flag is set true or false.

    If dataRemote=True, then the RemoteableExtensible parent class takes over and forwards all
    requests to a remote server via a callback function registered with the RemoteableExtensible
    object. In that case *none* of the methods below will be locally invoked.

    If dataRemote=False, then the methods below will be invoked locally and the
    RemoteableExtensible parent class is inoperable (i.e. does nothing).

    Local file operations are restricted to paths inside ``allowedDirs`` whose extension is in
    ``allowedFileTypes``; violations raise ValidationError.
    """
    def __init__(self,
                 dataRemote: bool = False,
                 allowedDirs: List[str] = None,
                 allowedFileTypes: List[str] = None):
        """
        Args:
            dataRemote (bool): whether data will be served from the local instance or requests forwarded
                to a remote instance for handling.
            allowedDirs (list): list of directories from which files are allowed to be read/written. File
                operations will not be permitted unless the file path is a child of an allowed directory.
                Trailing '/' characters are stripped from each entry.
            allowedFileTypes (list): list of file extensions, such as '.dcm', '.txt', for which file
                operations are permitted. No file operations will be done unless the file extension matches
                one on the list. If the first entry is '*' the list is kept as-is; otherwise entries given
                without a leading '.' get one prepended. The caller's list is not modified.
        """
        super().__init__(isRemote=dataRemote)
        if dataRemote is True:
            # Remote mode: requests are forwarded elsewhere, so no local state is set up.
            return
        self.initWatchSet = False
        self.watchDir = None
        self.currentStreamId = 0
        self.streamInfo = None
        # Remove trailing slash from dir names
        self.allowedDirs = allowedDirs
        if allowedDirs is not None:
            self.allowedDirs = [allowedDir.rstrip('/') for allowedDir in allowedDirs]
        # Build a new normalized list instead of editing the caller's list in place (it may be a
        # shared default); an empty list is kept empty rather than raising IndexError.
        self.allowedFileTypes = None
        if allowedFileTypes is not None:
            fileTypes = list(allowedFileTypes)
            if fileTypes and fileTypes[0] != '*':
                # make sure allowed file extensions start with '.'
                fileTypes = [ft if ft.startswith('.') else '.' + ft for ft in fileTypes]
            self.allowedFileTypes = fileTypes
        self.fileWatchLock = threading.Lock()
        # instantiate local FileWatcher
        self.fileWatcher = FileWatcher()

    def __del__(self):
        # fileWatcher is never set in remote mode (or if __init__ failed early), so look it up
        # defensively.
        fileWatcher = getattr(self, "fileWatcher", None)
        if fileWatcher is not None:
            fileWatcher.__del__()
            self.fileWatcher = None

    def initScannerStream(self,
                          imgDir: str,
                          filePattern: str,
                          minFileSize: int,
                          demoStep: int = 0) -> int:
        """
        Initialize a data stream context with image directory and filepattern.
        Once the stream is initialized call getImageData() to retrieve image data.
        NOTE: currently only one stream at a time is supported.

        Args:
            imgDir: the directory where the images are or will be written from the MRI scanner.
            filePattern: a pattern of the image file names that has a TR tag which will be used
                to index the images, for example 'scan01_{TR:03d}.dcm'. In this example a call to
                getImageData(streamId, imageIndex=6) would look for dicom file 'scan01_006.dcm'.
            minFileSize: passed to initWatch() as the minimum size of files to return.
            demoStep: passed to initWatch() as the minimum interval between returned files.

        Returns:
            streamId: An identifier used when calling getImageData()

        Raises:
            ValidationError: imgDir or the pattern's file type is not allowed.
            InvocationError: filePattern does not contain a {TR...} field.
        """
        self._checkAllowedDirs(imgDir)
        self._checkAllowedFileTypes(filePattern)

        # check that filePattern has {TR} in it
        if not re.match(r'.*{TR.*', filePattern):
            raise InvocationError(
                r"initScannerStream filePattern must have a {TR} pattern")
        self.currentStreamId = self.currentStreamId + 1
        self.streamInfo = StructDict({
            'streamId': self.currentStreamId,
            'type': 'scanner',
            'imgDir': imgDir,
            'filePattern': filePattern,
            'minFileSize': minFileSize,
            'demoStep': demoStep,
            'imgIndex': 0,
        })
        # Watch for any file with the same extension as the pattern, e.g. '*.dcm'
        _, file_ext = os.path.splitext(filePattern)
        self.initWatch(imgDir, '*' + file_ext, minFileSize, demoStep)
        return self.currentStreamId

    def getImageData(self,
                     streamId: int,
                     imageIndex: int = None,
                     timeout: int = 5) -> pydicom.dataset.FileDataset:
        """
        Get data from a stream initialized with initScannerStream or initOpenNeuroStream

        Args:
            streamId: Id of a previously opened stream.
            imageIndex: Which image from the stream to retrieve. If left blank it will
                retrieve the next image in the stream (next after either the last request or
                starting from 0 if no previous requests)
            timeout: seconds passed to watchFile() for each of up to 5 attempts.
        Returns:
            The image as a pydicom.dataset.FileDataset, or None if every attempt timed out
            or any other error occurred while watching for or parsing the file.
        Raises:
            ValidationError: streamId does not match the currently open stream.
        """
        if self.currentStreamId == 0 or self.currentStreamId != streamId or self.streamInfo.streamId != streamId:
            raise ValidationError(
                f"StreamID mismatch {self.currentStreamId} : {streamId}")

        if imageIndex is None:
            imageIndex = self.streamInfo.imgIndex
        filename = self.streamInfo.filePattern.format(TR=imageIndex)

        for _ in range(5):
            try:
                data = self.watchFile(filename, timeout)
                dicomImg = readDicomFromBuffer(data)
                # Convert pixel data to a numpy.ndarray internally.
                # Note: the conversion cause error in pickle encoding
                # dicomImg.convert_pixel_data()
                self.streamInfo.imgIndex = imageIndex + 1
                return dicomImg
            except TimeoutError:
                logging.warning(
                    f"Timeout waiting for {filename}. Retry in 100 ms")
                time.sleep(0.1)
            except Exception as err:
                logging.error(
                    f"getImageData Error, filename {filename} err: {err}")
                return None
        return None

    def getFile(self, filename: str) -> bytes:
        """Returns a file's data immediately or fails if the file doesn't exist.

        Raises:
            ValidationError: the file's directory or type is not allowed.
            FileNotFoundError: the file does not exist.
        """
        fileDir, fileCheck = os.path.split(filename)
        self._checkAllowedDirs(fileDir)
        self._checkAllowedFileTypes(fileCheck)

        if not os.path.exists(filename):
            raise FileNotFoundError(f'File not found {filename}')
        with open(filename, 'rb') as fp:
            data = fp.read()
        # Data is returned as raw bytes; detecting a text encoding (e.g. with chardet) was
        # considered but could be computationally expensive on large data.
        return data

    def getNewestFile(self, filepattern: str) -> bytes:
        """Searches for files matching filePattern and returns the data from the newest one.

        Raises:
            ValidationError: the directory or the matched file's type is not allowed.
            FileNotFoundError: no file matched, or the match disappeared before reading.
        """
        baseDir, filePattern = os.path.split(filepattern)
        self._checkAllowedDirs(baseDir)
        # TODO - handle relative paths (baseDir is currently used as-is)
        filename = utils.findNewestFile(baseDir, filePattern)
        self._checkAllowedFileTypes(filename)
        if filename is None:
            # No file matching pattern
            raise FileNotFoundError(
                'No file found matching pattern {}'.format(filePattern))
        if not os.path.exists(filename):
            raise FileNotFoundError(
                'File missing after match {}'.format(filePattern))
        with open(filename, 'rb') as fp:
            data = fp.read()
        return data

    def initWatch(self,
                  dir: str,
                  filePattern: str,
                  minFileSize: int,
                  demoStep: int = 0) -> None:
        """Initialize a watch directory for files matching filePattern.

        No data is returned by this function, but a filesystem watch is established.
        After calling initWatch, use watchFile() to watch for a specific file's arrival.

        Args:
            dir: Directory to watch for arrival (creation) of new files
            filePattern: Regex style filename pattern of files to watch for (i.e. *.dcm)
            minFileSize: Minimum size of the file to return (continue waiting if below this size)
            demoStep: Minimum interval (in seconds) to wait before returning files.
                Useful for demos replaying existing files while mimicking original timing.
        """
        self._checkAllowedDirs(dir)
        self._checkAllowedFileTypes(filePattern)
        with self.fileWatchLock:
            self.watchDir = dir
            self.fileWatcher.initFileNotifier(dir, filePattern, minFileSize,
                                              demoStep)
        # Only mark the watch as set once initFileNotifier succeeded
        self.initWatchSet = True

    def watchFile(self, filename: str, timeout: int = 5) -> bytes:
        """Watches for a specific file to be created and returns the file data.

        InitWatch() must be called first, before watching for specific files.
        If filename includes the full path, the path must match that used in initWatch().

        Raises:
            StateError: initWatch() has not been called.
            RequestError: filename's directory differs from the watch directory.
            ValidationError: the directory or file type is not allowed.
            TimeoutError: the file did not appear within timeout seconds.
        """
        if not self.initWatchSet:
            raise StateError(
                "DataInterface: watchFile() called without an initWatch()")

        # check filename dir matches initWatch dir
        fileDir, fileCheck = os.path.split(filename)
        if fileDir not in ('', None):
            if fileDir != self.watchDir:
                raise RequestError(
                    "DataInterface: watchFile: filepath doesn't match "
                    f"watch directory: {fileDir}, {self.watchDir}")
            self._checkAllowedDirs(fileDir)
        self._checkAllowedFileTypes(fileCheck)

        with self.fileWatchLock:
            foundFilename = self.fileWatcher.waitForFile(
                filename, timeout=timeout, timeCheckIncrement=0.2)
        if foundFilename is None:
            raise TimeoutError("WatchFile: Timeout {}s: {}".format(
                timeout, filename))
        with open(foundFilename, 'rb') as fp:
            data = fp.read()
        return data

    def putFile(self,
                filename: str,
                data: Union[str, bytes],
                compress: bool = False) -> None:
        """
        Create a file (filename) and write the bytes or text to it.
        In remote mode the file is written at the remote.

        Args:
            filename: Name of file to create; missing parent directories are created.
            data: data to write to the file (str is UTF-8 encoded first)
            compress: Whether to compress the data in transit (not within the file),
                only has effect in remote mode.
        """
        fileDir, fileCheck = os.path.split(filename)
        self._checkAllowedDirs(fileDir)
        self._checkAllowedFileTypes(fileCheck)

        if isinstance(data, str):
            data = data.encode()

        outputDir = os.path.dirname(filename)
        # A bare filename has no directory part; os.makedirs('') would raise. exist_ok also
        # avoids a race with another writer creating the directory concurrently.
        if outputDir:
            os.makedirs(outputDir, exist_ok=True)
        with open(filename, 'wb+') as binFile:
            binFile.write(data)

    def listFiles(self, filepattern: str) -> List[str]:
        """Lists files (not directories) matching the glob filePattern, filtered to allowed types.

        Raises:
            ValidationError: the directory or file type is not allowed.
            RequestError: filepattern is not an absolute path.
        """
        fileDir, fileCheck = os.path.split(filepattern)
        self._checkAllowedDirs(fileDir)
        self._checkAllowedFileTypes(fileCheck)
        if not os.path.isabs(filepattern):
            errStr = "listFiles must have an absolute path: {}".format(
                filepattern)
            raise RequestError(errStr)
        fileList = [filename for filename in glob.iglob(filepattern, recursive=True)
                    if not os.path.isdir(filename)]
        return self._filterFileList(fileList)

    def getAllowedFileTypes(self) -> List[str]:
        """Returns the list of file extensions which are allowed for read and write"""
        return self.allowedFileTypes

    @staticmethod
    def _isWithinDir(path: str, parentDir: str) -> bool:
        """Return True if path is parentDir itself or lies beneath it.

        Both paths are normalized first so '..' components cannot escape parentDir, and the
        prefix comparison includes a separator so '/data2' is not treated as inside '/data'.
        """
        normPath = os.path.normpath(path)
        # __init__ strips trailing '/', turning '/' into ''; treat that as the filesystem root.
        normParent = os.path.normpath(parentDir) if parentDir else os.sep
        if normPath == normParent:
            return True
        return normPath.startswith(normParent.rstrip(os.sep) + os.sep)

    def _checkAllowedDirs(self, dir: str) -> bool:
        """Class-private check that dir is inside one of self.allowedDirs.

        Returns True when allowed (including dir None or a first allowedDirs entry of '*');
        raises ValidationError otherwise or when no allowed directories are configured.
        """
        if self.allowedDirs is None or len(self.allowedDirs) == 0:
            raise ValidationError(
                'DataInterface: no allowed directories are set')
        if dir is None:
            return True
        if self.allowedDirs[0] == '*':
            return True
        if not any(self._isWithinDir(dir, allowedDir) for allowedDir in self.allowedDirs):
            raise ValidationError(
                f'Path {dir} not within list of allowed directories {self.allowedDirs}. '
                'Make sure you specified a full (absolute) path. '
                'Specify allowed directories with FileServer -d parameter.')
        return True

    def _checkAllowedFileTypes(self, filename: str) -> bool:
        """Class-private check that filename's extension is in self.allowedFileTypes.

        Returns True when allowed (including an empty/None filename, a first allowedFileTypes
        entry of '*', or a pattern ending in '*' which is filtered later); raises
        ValidationError otherwise or when no allowed file types are configured.
        """
        if self.allowedFileTypes is None or len(self.allowedFileTypes) == 0:
            raise ValidationError(
                'DataInterface: no allowed file types are set')
        if filename is None or filename == '':
            return True
        if self.allowedFileTypes[0] == '*':
            return True
        if filename.endswith('*'):
            # wildcards will be filtered later
            return True
        fileExtension = Path(filename).suffix
        if fileExtension not in self.allowedFileTypes:
            raise ValidationError(
                f"File type {fileExtension} not in list of allowed file types {self.allowedFileTypes}. "
                "Specify allowed filetypes with FileServer -f parameter.")
        return True

    def _filterFileList(self, fileList: List[str]) -> List[str]:
        """Class-private function to filter a list of files to include only allowed ones.
            Args: fileList - list of files to filter
            Returns: filtered fileList - containing only non-directory files with an allowed
                extension (or the list unchanged if all types are allowed via '*')
        """
        if self.allowedFileTypes is None or len(self.allowedFileTypes) == 0:
            raise ValidationError(
                'DataInterface: no allowed file types are set')
        if self.allowedFileTypes[0] == '*':
            return fileList
        return [filename for filename in fileList
                if not os.path.isdir(filename)
                and Path(filename).suffix in self.allowedFileTypes]
예제 #4
0
class WsFileWatcher:
    ''' A server that watches for files on the scanner computer and replies to
        cloud service requests with the file data. The communication connection
        is made with webSockets (ws)
    '''
    fileWatcher = FileWatcher()
    allowedDirs = None
    allowedTypes = None
    serverAddr = None
    sessionCookie = None
    needLogin = True
    shouldExit = False
    validationError = None
    # Synchronizing across threads
    clientLock = threading.Lock()
    fileWatchLock = threading.Lock()

    @staticmethod
    def runFileWatcher(serverAddr,
                       retryInterval=10,
                       allowedDirs=defaultAllowedDirs,
                       allowedTypes=defaultAllowedTypes,
                       username=None,
                       password=None,
                       testMode=False):
        """Connect to the server's wsData websocket and serve file requests until stop() is called.

        On any exception the connection is retried after retryInterval seconds.

        Args:
            serverAddr: host[:port] of the server (no scheme).
            retryInterval: seconds to sleep between connection attempts.
            allowedDirs: directories requests may access.
            allowedTypes: file extensions requests may access; a leading '.' is added if missing.
            username, password: credentials passed to login().
            testMode: use an unencrypted ws:// connection instead of wss://.
        """
        WsFileWatcher.serverAddr = serverAddr
        WsFileWatcher.allowedDirs = allowedDirs
        # Build a new list: normalizing in place would mutate the caller's list, or the shared
        # module-level default defaultAllowedTypes.
        WsFileWatcher.allowedTypes = [
            fileType if fileType.startswith('.') else '.' + fileType
            for fileType in allowedTypes]
        # Build the URL with string formatting; os.path.join would use '\\' on Windows.
        hostPart = serverAddr.rstrip('/')
        # go into loop trying to do webSocket connection periodically
        WsFileWatcher.shouldExit = False
        while not WsFileWatcher.shouldExit:
            try:
                # NOTE(review): needLogin is never reset to False after a successful login,
                # so this logs in again on every reconnect — confirm that is intended.
                if WsFileWatcher.needLogin or WsFileWatcher.sessionCookie is None:
                    WsFileWatcher.sessionCookie = login(serverAddr,
                                                        username,
                                                        password,
                                                        testMode=testMode)
                wsAddr = 'wss://{}/wsData'.format(hostPart)
                if testMode:
                    print(
                        "Warning: using non-encrypted connection for test mode"
                    )
                    wsAddr = 'ws://{}/wsData'.format(hostPart)
                logging.log(DebugLevels.L6, "Trying connection: %s", wsAddr)
                ws = websocket.WebSocketApp(
                    wsAddr,
                    on_message=WsFileWatcher.on_message,
                    on_close=WsFileWatcher.on_close,
                    on_error=WsFileWatcher.on_error,
                    cookie="login=" + WsFileWatcher.sessionCookie)
                logging.log(logging.INFO, "Connected to: %s", wsAddr)
                print("Connected to: {}".format(wsAddr))
                ws.run_forever(sslopt={"ca_certs": certFile})
            except Exception as err:
                logging.log(
                    logging.INFO, "WSFileWatcher Exception {}: {}".format(
                        type(err).__name__, str(err)))
                print('sleep {}'.format(retryInterval))
                time.sleep(retryInterval)

    @staticmethod
    def stop():
        """Ask runFileWatcher's reconnect loop to exit."""
        WsFileWatcher.shouldExit = True

    @staticmethod
    def on_message(client, message):
        """Handle one JSON request message and send the response back over client.

        The request's 'cmd' selects the operation; 'dir'/'filename' (or a full path in
        'filename') are validated against allowedDirs/allowedTypes first for file commands.
        """
        fileWatcher = WsFileWatcher.fileWatcher
        response = {'status': 400, 'error': 'unhandled request'}
        # Bind cmd before the try so the except handler can use it even if json.loads fails
        # (otherwise the handler itself raises UnboundLocalError).
        cmd = None
        try:
            request = json.loads(message)
            response = request.copy()
            # Don't echo a (possibly large) data payload back in the response
            response.pop('data', None)
            cmd = request.get('cmd')
            dir = request.get('dir')
            filename = request.get('filename')
            timeout = request.get('timeout', 0)
            compress = request.get('compress', False)
            logging.log(logging.INFO, "{}: {} {}".format(cmd, dir, filename))
            # Do Validation Checks
            if cmd not in ['getAllowedFileTypes', 'ping', 'error']:
                # All other commands must have a filename or directory parameter
                if dir is None and filename is not None:
                    dir, filename = os.path.split(filename)
                if filename is None:
                    errStr = "{}: Missing filename param".format(cmd)
                    return send_error_response(client, response, errStr)
                if dir is None:
                    errStr = "{}: Missing dir param".format(cmd)
                    return send_error_response(client, response, errStr)
                if cmd in ('watchFile', 'getFile', 'getNewestFile'):
                    if not os.path.isabs(dir):
                        # make path relative to the watch dir
                        dir = os.path.join(fileWatcher.watchDir, dir)
                if WsFileWatcher.validateRequestedFile(dir, filename,
                                                       cmd) is False:
                    errStr = '{}: {}'.format(cmd,
                                             WsFileWatcher.validationError)
                    return send_error_response(client, response, errStr)
                if cmd in ('putTextFile', 'putBinaryFile', 'dataLog'):
                    if not os.path.exists(dir):
                        os.makedirs(dir)
                if not os.path.exists(dir):
                    errStr = '{}: No such directory: {}'.format(cmd, dir)
                    return send_error_response(client, response, errStr)
            # Now handle requests
            if cmd == 'initWatch':
                minFileSize = request.get('minFileSize')
                demoStep = request.get('demoStep')
                if minFileSize is None:
                    errStr = "InitWatch: Missing minFileSize param"
                    return send_error_response(client, response, errStr)
                with WsFileWatcher.fileWatchLock:
                    fileWatcher.initFileNotifier(dir, filename, minFileSize,
                                                 demoStep)
                response.update({'status': 200})
                return send_response(client, response)
            elif cmd == 'watchFile':
                # Join before taking the lock so a failure here cannot leave it held
                filename = os.path.join(dir, filename)
                with WsFileWatcher.fileWatchLock:
                    retVal = fileWatcher.waitForFile(filename, timeout=timeout)
                if retVal is None:
                    errStr = "WatchFile: 408 Timeout {}s: {}".format(
                        timeout, filename)
                    response.update({'status': 408, 'error': errStr})
                    logging.log(logging.WARNING, errStr)
                    return send_response(client, response)
                else:
                    response.update({'status': 200, 'filename': filename})
                    return send_data_response(client, response, compress)
            elif cmd == 'getFile':
                filename = os.path.join(dir, filename)
                if not os.path.exists(filename):
                    errStr = "GetFile: File not found {}".format(filename)
                    return send_error_response(client, response, errStr)
                response.update({'status': 200, 'filename': filename})
                return send_data_response(client, response, compress)
            elif cmd == 'getNewestFile':
                resultFilename = findNewestFile(dir, filename)
                if resultFilename is None or not os.path.exists(
                        resultFilename):
                    errStr = 'GetNewestFile: file not found: {}'.format(
                        os.path.join(dir, filename))
                    return send_error_response(client, response, errStr)
                response.update({'status': 200, 'filename': resultFilename})
                return send_data_response(client, response, compress)
            elif cmd == 'listFiles':
                if not os.path.isabs(dir):
                    errStr = "listFiles must have an absolute path: {}".format(
                        dir)
                    return send_error_response(client, response, errStr)
                filePattern = os.path.join(dir, filename)
                fileList = list(glob.iglob(filePattern, recursive=True))
                fileList = WsFileWatcher.filterFileList(fileList)
                response.update({
                    'status': 200,
                    'filePattern': filePattern,
                    'fileList': fileList
                })
                return send_response(client, response)
            elif cmd == 'getAllowedFileTypes':
                response.update({
                    'status': 200,
                    'fileTypes': WsFileWatcher.allowedTypes
                })
                return send_response(client, response)
            elif cmd == 'putTextFile':
                text = request.get('text')
                if text is None:
                    errStr = 'PutTextFile: Missing text field'
                    return send_error_response(client, response, errStr)
                elif type(text) is not str:
                    errStr = "PutTextFile: Only text data allowed"
                    return send_error_response(client, response, errStr)
                fullPath = os.path.join(dir, filename)
                with open(fullPath, 'w') as volFile:
                    volFile.write(text)
                response.update({'status': 200})
                return send_response(client, response)
            elif cmd == 'putBinaryFile':
                try:
                    data = unpackDataMessage(request)
                except Exception as err:
                    errStr = 'putBinaryFile: {}'.format(err)
                    return send_error_response(client, response, errStr)
                # If data is None - Incomplete multipart data, more will follow
                if data is not None:
                    fullPath = os.path.join(dir, filename)
                    with open(fullPath, 'wb') as binFile:
                        binFile.write(data)
                response.update({'status': 200})
                return send_response(client, response)
            elif cmd == 'dataLog':
                logLine = request.get('logLine')
                if logLine is None:
                    errStr = 'DataLog: Missing logLine field'
                    return send_error_response(client, response, errStr)
                fullPath = os.path.join(dir, filename)
                with open(fullPath, 'a') as logFile:
                    logFile.write(logLine + '\n')
                response.update({'status': 200})
                return send_response(client, response)
            elif cmd == 'ping':
                response.update({'status': 200})
                return send_response(client, response)
            elif cmd == 'error':
                errorCode = request.get('status', 400)
                errorMsg = request.get('error', 'missing error msg')
                if errorCode == 401:
                    # Force a fresh login on the next reconnect
                    WsFileWatcher.needLogin = True
                    WsFileWatcher.sessionCookie = None
                errStr = 'Error {}: {}'.format(errorCode, errorMsg)
                logging.log(logging.ERROR, errStr)
                return
            else:
                errStr = 'OnMessage: Unrecognized command {}'.format(cmd)
                return send_error_response(client, response, errStr)
        except Exception as err:
            errStr = "OnMessage Exception: {}: {}".format(cmd, err)
            send_error_response(client, response, errStr)
            if cmd == 'error':
                sys.exit()
            return
        errStr = 'unhandled request'
        send_error_response(client, response, errStr)
        return

    @staticmethod
    def on_close(client):
        logging.info('connection closed')

    @staticmethod
    def on_error(client, error):
        """Stop the reconnect loop on KeyboardInterrupt; otherwise log the error."""
        if type(error) is KeyboardInterrupt:
            WsFileWatcher.shouldExit = True
        else:
            logging.log(
                logging.WARNING, "on_error: WSFileWatcher: {} {}".format(
                    type(error), str(error)))

    @staticmethod
    def _isWithinDir(path, parentDir):
        """Return True if path is parentDir itself or lies beneath it.

        Both paths are normalized first so '..' components cannot escape parentDir, and the
        prefix comparison includes a separator so '/data2' is not treated as inside '/data'.
        """
        normPath = os.path.normpath(path)
        # An empty parentDir would otherwise match everything; treat it as the filesystem root.
        normParent = os.path.normpath(parentDir) if parentDir else os.sep
        if normPath == normParent:
            return True
        return normPath.startswith(normParent.rstrip(os.sep) + os.sep)

    @staticmethod
    def validateRequestedFile(dir, file, cmd):
        """Check a request's dir/file against allowedDirs and allowedTypes.

        Returns True if allowed; otherwise sets WsFileWatcher.validationError and returns False.
        Raises StateError if allowedDirs or allowedTypes has not been configured.
        """
        # putTextFile/dataLog may only write .txt files
        textFileTypeOnly = cmd in ('putTextFile', 'dataLog')
        # A tuple, not a bare string: ('listFiles') would be a substring test
        wildcardAllowed = cmd in ('listFiles',)
        # Restrict requests to certain directories and file types
        WsFileWatcher.validationError = None
        if WsFileWatcher.allowedDirs is None or WsFileWatcher.allowedTypes is None:
            raise StateError(
                'FileServer: Allowed Directories or File Types is not set')
        if file is not None and file != '':
            fileDir, filename = os.path.split(file)
            fileExtension = Path(filename).suffix
            if textFileTypeOnly:
                if fileExtension != '.txt':
                    WsFileWatcher.validationError = \
                        'Only .txt files allowed with command putTextFile() or dataLog()'
                    return False
            if wildcardAllowed:
                pass  # wildcard searches will be filtered for filetype later
            elif fileExtension not in WsFileWatcher.allowedTypes:
                WsFileWatcher.validationError = \
                    "File type {} not in list of allowed file types {}. " \
                    "Specify allowed filetypes with FileServer -f parameter.". \
                    format(fileExtension, WsFileWatcher.allowedTypes)
                return False
            if fileDir:
                if not any(WsFileWatcher._isWithinDir(fileDir, allowedDir)
                           for allowedDir in WsFileWatcher.allowedDirs):
                    WsFileWatcher.validationError = \
                        "Path {} not within list of allowed directories {}. " \
                        "Make sure you specified a full (absolute) path. " \
                        "Specify allowed directories with FileServer -d parameter.". \
                        format(fileDir, WsFileWatcher.allowedDirs)
                    return False
        if dir is not None and dir != '':
            if any(WsFileWatcher._isWithinDir(dir, allowedDir)
                   for allowedDir in WsFileWatcher.allowedDirs):
                return True
            WsFileWatcher.validationError = \
                "Path {} not within list of allowed directories {}. " \
                "Make sure you specified a full (absolute) path. " \
                "Specify allowed directories with FileServer -d parameter.". \
                format(dir, WsFileWatcher.allowedDirs)
            return False
        # default case
        return True

    @staticmethod
    def filterFileList(fileList):
        """Return only the non-directory entries of fileList whose extension is allowed."""
        return [filename for filename in fileList
                if not os.path.isdir(filename)
                and Path(filename).suffix in WsFileWatcher.allowedTypes]
예제 #5
0
class FileInterface:
    """
    Read, write, list and watch files, either locally or through a remote server.

    With filesremote=False (the default) every operation acts directly on the
    local filesystem. Otherwise each operation is packaged with a
    projUtils.*ReqStruct helper and sent via projUtils.clientSendCmd over
    commPipes, and the result is taken from the returned reply.
    """

    def __init__(self, filesremote=False, commPipes=None):
        """
        Args:
            filesremote (bool): if True, forward requests over commPipes instead
                of accessing the local filesystem.
            commPipes: channel passed to projUtils.clientSendCmd for remote
                requests; unused in local mode.
        """
        self.local = not filesremote
        self.commPipes = commPipes
        self.fileWatcher = None
        self.initWatchSet = False
        if self.local:
            # Only local mode watches for files in-process.
            self.fileWatcher = FileWatcher()

    def __del__(self):
        # getattr: __del__ may run on an instance whose __init__ never
        # completed, in which case the attribute may not exist.
        fileWatcher = getattr(self, 'fileWatcher', None)
        if fileWatcher is not None:
            fileWatcher.__del__()
            self.fileWatcher = None

    @staticmethod
    def _makeParentDir(filename):
        """
        Create the directory portion of filename (with parents) if needed.

        A bare filename has no directory portion and nothing is created;
        calling os.makedirs('') would raise FileNotFoundError.
        """
        outputDir = os.path.dirname(filename)
        if outputDir:
            # exist_ok avoids a race between an existence check and creation.
            os.makedirs(outputDir, exist_ok=True)

    def getFile(self, filename):
        """
        Return the contents of filename.

        Local mode reads the file in binary mode and returns bytes. Remote mode
        returns the ``data`` attribute of the server reply.
        """
        if self.local:
            with open(filename, 'rb') as fp:
                return fp.read()
        getFileCmd = projUtils.getFileReqStruct(filename)
        retVals = projUtils.clientSendCmd(self.commPipes, getFileCmd)
        return retVals.data

    def getNewestFile(self, filePattern):
        """
        Return the contents of the newest file matching filePattern.

        Locally, filePattern is split into a base directory and a pattern that
        are passed to findNewestFile.

        Raises:
            FileNotFoundError: (local mode) if no file matches, or the matched
                file no longer exists.
        """
        if not self.local:
            getNewestFileCmd = projUtils.getNewestFileReqStruct(filePattern)
            retVals = projUtils.clientSendCmd(self.commPipes, getNewestFileCmd)
            return retVals.data
        baseDir, filePattern = os.path.split(filePattern)
        if not os.path.isabs(baseDir):
            # TODO - handle relative paths
            pass
        filename = findNewestFile(baseDir, filePattern)
        if filename is None:
            # No file matching pattern
            raise FileNotFoundError('No file found matching pattern {}'.format(filePattern))
        if not os.path.exists(filename):
            raise FileNotFoundError('File missing after match {}'.format(filePattern))
        with open(filename, 'rb') as fp:
            return fp.read()

    def initWatch(self, dir, filePattern, minFileSize, demoStep=0):
        """
        Configure file watching; must be called before watchFile().

        The arguments are passed unchanged to FileWatcher.initFileNotifier
        (local) or projUtils.initWatchReqStruct (remote).
        """
        if self.local:
            self.fileWatcher.initFileNotifier(dir, filePattern, minFileSize, demoStep)
        else:
            initWatchCmd = projUtils.initWatchReqStruct(dir, filePattern, minFileSize, demoStep)
            projUtils.clientSendCmd(self.commPipes, initWatchCmd)
        self.initWatchSet = True
        return

    def watchFile(self, filename, timeout=5):
        """
        Wait up to timeout seconds for filename to appear, then return its contents.

        Raises:
            StateError: if initWatch() has not been called.
            FileNotFoundError: (local mode) if the watcher reports a timeout.
        """
        if not self.initWatchSet:
            raise StateError("FileInterface: watchFile() called without an initWatch()")
        if not self.local:
            watchCmd = projUtils.watchFileReqStruct(filename, timeout=timeout)
            retVals = projUtils.clientSendCmd(self.commPipes, watchCmd)
            return retVals.data
        retVal = self.fileWatcher.waitForFile(filename, timeout=timeout)
        if retVal is None:
            raise FileNotFoundError("WatchFile: Timeout {}s: {}".format(timeout, filename))
        with open(filename, 'rb') as fp:
            return fp.read()

    def putTextFile(self, filename, text):
        """
        Write text to filename, creating the parent directory locally if needed.

        Any existing file is overwritten.
        """
        if self.local:
            self._makeParentDir(filename)
            with open(filename, 'w+') as textFile:
                textFile.write(text)
        else:
            putFileCmd = projUtils.putTextFileReqStruct(filename, text)
            projUtils.clientSendCmd(self.commPipes, putFileCmd)
        return

    def putBinaryFile(self, filename, data, compress=False):
        """
        Write binary data to filename.

        Locally, the parent directory is created if needed and the file is
        overwritten. Remotely, data is sent in parts produced by
        projUtils.generateDataParts (compress is passed through). If sending
        fails after a part carried a fileHash, a status-400 notice with that
        hash is sent so the server can drop partially cached data, and the
        original exception is re-raised.
        """
        if self.local:
            self._makeParentDir(filename)
            with open(filename, 'wb+') as binFile:
                binFile.write(data)
            return
        fileHash = None
        putFileCmd = projUtils.putBinaryFileReqStruct(filename)
        try:
            for putFilePart in projUtils.generateDataParts(data, putFileCmd, compress):
                fileHash = putFilePart.get('fileHash')
                projUtils.clientSendCmd(self.commPipes, putFilePart)
        except Exception:
            # Send error notice to clear any partially cached data on the server side
            # Add fileHash to message and send status=400 to notify
            if fileHash:
                putFileCmd['fileHash'] = fileHash
                putFileCmd['status'] = 400
                projUtils.clientSendCmd(self.commPipes, putFileCmd)
            raise
        return

    def listFiles(self, filePattern):
        """
        Return the non-directory paths matching filePattern.

        Locally, filePattern must be absolute and is expanded with
        glob.iglob(recursive=True), so '**' matches across subdirectories.

        Raises:
            RequestError: (local mode) if filePattern is not absolute.
            StateError: (remote mode) if the reply's 'fileList' is not a list.
        """
        if self.local:
            if not os.path.isabs(filePattern):
                errStr = "listFiles must have an absolute path: {}".format(filePattern)
                raise RequestError(errStr)
            return [filename for filename in glob.iglob(filePattern, recursive=True)
                    if not os.path.isdir(filename)]
        listCmd = projUtils.listFilesReqStruct(filePattern)
        retVals = projUtils.clientSendCmd(self.commPipes, listCmd)
        fileList = retVals.get('fileList')
        if not isinstance(fileList, list):
            errStr = "Invalid fileList reponse type {}: expecting list".format(type(fileList))
            raise StateError(errStr)
        return fileList

    def allowedFileTypes(self):
        """
        Return the list of file types permitted for file operations.

        Local mode always returns ['*'] (no restriction at this layer).

        Raises:
            StateError: (remote mode) if the reply's 'fileTypes' is not a list.
        """
        if self.local:
            return ['*']
        cmd = projUtils.allowedFileTypesReqStruct()
        retVals = projUtils.clientSendCmd(self.commPipes, cmd)
        fileTypes = retVals.get('fileTypes')
        if not isinstance(fileTypes, list):
            errStr = "Invalid fileTypes reponse type {}: expecting list".format(type(fileTypes))
            raise StateError(errStr)
        return fileTypes