Example #1
    def buildWorkflow(self, request: TaskRequest, wnode: WorkflowNode, inputs: EDASDatasetCollection ) -> EDASDatasetCollection:
        op: OpNode = wnode
        self.logger.info("  ~~~~~~~~~~~~~~~~~~~~~~~~~~ Build Workflow, inputs: " + str( [ str(w) for w in op.inputs ] ) + ", op metadata = " + str(op.metadata) + ", axes = " + str(op.axes) + ", runargs = " + str(request.runargs) )
        results = EDASDatasetCollection("OpKernel.build-" + wnode.name )
#        if (len(inputs) < self._minInputs) or (len(inputs) > self._maxInputs): raise Exception( "Wrong number of inputs for kernel {}: {}".format( self._spec.name, len(inputs)))
        self.testOptions( wnode )
        for connector in wnode.connectors:
            # Group the arrays wired to this connector into positional cross sections
            inputDatasets: Dict[int,EDASDataset] = self.getInputCrossSections( inputs.filterByConnector(connector) )
            for key, dset in inputDatasets.items():
                processedInputs = self.preprocessInputs(request, op, dset )
                # Collapse the ensemble dimension first when the node declares one
                if wnode.ensDim is not None: processedInputs = self.mergeEnsembles( op, processedInputs )
                results[connector.output] = self.processInputCrossSection( request, op, processedInputs )
        return results
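This buildWorkflow fans each connector's inputs out into cross sections, preprocesses them, optionally merges ensembles, and stores the processed result under the connector's output key. A minimal sketch of that dispatch pattern, using plain dicts and made-up helper names in place of the EDAS classes:

from typing import Callable, Dict, List

def build_results(connectors: List[dict],
                  inputs: Dict[str, dict],
                  process: Callable[[dict], dict]) -> Dict[str, dict]:
    results: Dict[str, dict] = {}
    for connector in connectors:
        # keep only the inputs wired to this connector
        selected = {key: dset for key, dset in inputs.items()
                    if key in connector["inputs"]}
        for key, dset in selected.items():
            # like buildWorkflow above, each processed cross section is stored
            # under the connector's output key
            results[connector["output"]] = process(dset)
    return results

print(build_results([{"inputs": {"a", "b"}, "output": "out1"}],
                    {"a": {"v": 1}, "b": {"v": 2}},
                    lambda dset: {"v": dset["v"] * 10}))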
Example #2
    def getInputCrossSections(self, inputs: EDASDatasetCollection ) -> Dict[int,EDASDataset]:
        # Regroup the arrays of every input dataset by their position in the
        # dataset's array map, so corresponding arrays share a cross section.
        inputCrossSections = {}
        for dsKey, dset in inputs.items():
            for index, (akey, array) in enumerate(dset.arrayMap.items()):
                merge_set: EDASDataset = inputCrossSections.setdefault( index, EDASDataset(OrderedDict(), inputs.attrs ) )
                merge_set[dsKey + "-" + akey] = array
        return inputCrossSections
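getInputCrossSections regroups the arrays of every input dataset by their position in each dataset's array map, so the n-th array of each dataset lands in the same cross section. A stand-alone illustration of that regrouping with plain OrderedDicts and made-up data:

from collections import OrderedDict

inputs = {
    "dsA": OrderedDict([("tas", "tas-array-A"), ("pr", "pr-array-A")]),
    "dsB": OrderedDict([("tas", "tas-array-B"), ("pr", "pr-array-B")]),
}

cross_sections = {}
for ds_key, array_map in inputs.items():
    for index, (a_key, array) in enumerate(array_map.items()):
        # arrays sharing the same position in each dataset land in the same group
        merge_set = cross_sections.setdefault(index, OrderedDict())
        merge_set[ds_key + "-" + a_key] = array

# cross_sections[0] now holds 'dsA-tas' and 'dsB-tas',
# cross_sections[1] holds 'dsA-pr' and 'dsB-pr'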
Example #3
    def buildWorkflow(self, request: TaskRequest, node: WorkflowNode, inputs: EDASDatasetCollection )  -> EDASDatasetCollection:
        snode: SourceNode = node
        results = EDASDatasetCollection( "InputKernel.build-" + node.name )
        t0 = time.time()
        # Reuse a cached copy of this source's data when one is available
        dset = self.getCachedDataset( snode )
        if dset is not None:
            self.importToDatasetCollection(results, request, snode, dset.xr )
            self.logger.info( "Access input data from cache: " + dset.id )
        else:
            dataSource: DataSource = snode.varSource.dataSource
            if dataSource.type == SourceType.collection:
                from edas.collection.agg import Axis as AggAxis, File as AggFile
                collection = Collection.new( dataSource.address )
                self.logger.info("Input collection: " + dataSource.address )
                aggs = collection.sortVarsByAgg( snode.varSource.vids )
                domain = request.operationManager.domains.getDomain( snode.domain )
                if domain is not None:
                    timeBounds = domain.findAxisBounds(Axis.T)
                    startDate = None if timeBounds is None else TimeConversions.parseDate(timeBounds.start)
                    endDate   = None if timeBounds is None else TimeConversions.parseDate(timeBounds.end)
                else:
                    startDate = endDate = None
                for ( aggId, vars ) in aggs.items():
                    use_chunks = True
                    pathList = collection.pathList(aggId) if startDate is None else collection.periodPathList(aggId,startDate,endDate)
                    assert len(pathList) > 0, f"No files found in aggregation {aggId} for date range {startDate} - {endDate} "
                    nFiles = len(pathList)
                    if use_chunks:
                        # Derive a time chunk size from the configured number of read partitions
                        nReadPartitions = int( EdasEnv.get( "mfdataset.npartitions", 250 ) )
                        agg = collection.getAggregation(aggId)
                        nchunks, fileSize = agg.getChunkSize( nReadPartitions, nFiles )
                        chunk_kwargs = {} if nchunks is None else dict(chunks={"time": nchunks})
                        self.logger.info( f"Open mfdataset: vars={vars}, NFILES={nFiles}, FileSize={fileSize}, FILES[0]={pathList[0]}, chunk_kwargs={chunk_kwargs}, startDate={startDate}, endDate={endDate}, domain={domain}" )
                    else:
                        chunk_kwargs = {}
                        self.logger.info( f"Open mfdataset: vars={vars},  NFILES={nFiles}, FILES[0]={pathList[0]}" )
                    dset = xr.open_mfdataset( pathList, engine='netcdf4', data_vars=vars, parallel=True, **chunk_kwargs )
                    self.logger.info(f"Import to collection")
                    self.importToDatasetCollection( results, request, snode, dset )
                    self.logger.info(f"Collection import complete.")
            elif dataSource.type == SourceType.file:
                self.logger.info( "Reading data from address: " + dataSource.address )
                files = glob.glob( dataSource.address )
                # Parallel open only pays off when the glob matched multiple files
                parallel = len(files) > 1
                assert len(files) > 0, f"No files matching path {dataSource.address}"
                dset = xr.open_mfdataset(dataSource.address, engine='netcdf4', data_vars=snode.varSource.ids, parallel=parallel )
                self.importToDatasetCollection(results, request, snode, dset)
            elif dataSource.type == SourceType.archive:
                self.logger.info( "Reading data from archive: " + dataSource.address )
                dataPath =  request.archivePath( dataSource.address )
                dset = xr.open_mfdataset( [dataPath] )
                self.importToDatasetCollection(results, request, snode, dset)
            elif dataSource.type == SourceType.dap:
                nchunks = request.runargs.get( "ncores", 8 )
                self.logger.info( f" --------------->>> Reading data from address: {dataSource.address}, nchunks = {nchunks}" )
#                dset = xr.open_mfdataset( [dataSource.address], engine="netcdf4", data_vars=snode.varSource.ids, chunks={"time":nchunks} )
                dset = xr.open_dataset( dataSource.address, engine="netcdf4", chunks={"time":nchunks} )
                self.importToDatasetCollection( results, request, snode, dset )
            self.logger.info( f"Access input data source {dataSource.address}, time = {time.time() - t0} sec" )
            self.logger.info( "@L: LOCATION=> host: {}, thread: {}, proc: {}".format( socket.gethostname(), threading.get_ident(), os.getpid() ) )
        return results
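The collection, file, and dap branches all rely on xarray's open_mfdataset / open_dataset with dask chunking along time. A stand-alone sketch of that call pattern; the path, variable list, and chunk size below are placeholders, not values taken from EDAS:

import glob
import xarray as xr

paths = sorted(glob.glob("/data/collection/tas_*.nc"))   # hypothetical aggregation files
dset = xr.open_mfdataset(
    paths,
    engine="netcdf4",
    data_vars=["tas"],        # limit concatenation to the requested variables
    parallel=True,            # open the files concurrently via dask
    chunks={"time": 100},     # dask chunk size along the time dimension
)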