Exemplo n.º 1
0
 def preprocessInputs(self, request: TaskRequest, op: OpNode,
                      inputDataset: EDASDataset) -> EDASDataset:
     """Prepare *inputDataset* so that operation *op* can run on it.

     When ``op.isSimple`` is truthy and ``self.requiresAlignment`` is falsy the
     dataset is used unchanged. Otherwise every array is subset to the domains
     reported by its ``unapplied_domains`` call, the arrays are re-wrapped in a
     new EDASDataset with the input's attrs, aligned to the variable chosen by
     ``getAlignmentVariable(op.getParm("align", "lowest"))``, then grouped by
     ``op.grouping`` and resampled by ``op.resampling``.

     In both cases the ids of the prepared dataset are printed and its
     ``purge()`` result is returned.
     """
     # NOTE(review): an optional NaN-interpolation step along "t" (op parm
     # "interp_na") was once drafted here as commented-out code; it is not active.
     if not op.isSimple or self.requiresAlignment:
         peers = list(inputDataset.arrayMap.values())
         subsetArrays: OrderedDict[str, EDASArray] = OrderedDict()
         for arrayId, edasArray in inputDataset.arrayMap.items():
             pending: Set[str] = edasArray.unapplied_domains(peers, op.domain)
             if len(pending) == 0:
                 # Nothing left to apply: keep the array as-is.
                 subsetArrays[arrayId] = edasArray
                 continue
             mergedDomainId: str = request.intersectDomains(pending, False)
             cropped: Domain = request.cropDomain(mergedDomainId, peers)
             subsetArrays[arrayId] = edasArray.subset(cropped, pending)
         staged = EDASDataset(subsetArrays, inputDataset.attrs)
         alignTo = staged.getAlignmentVariable(op.getParm("align", "lowest"))
         aligned = staged.align(alignTo)
         result: EDASDataset = aligned.groupby(op.grouping).resample(op.resampling)
     else:
         result = inputDataset
     summary = " $$$$ processInputCrossSection: " + op.name + " -> " + str(result.ids)
     print(summary)
     return result.purge()
Exemplo n.º 2
0
 def plotPrediction(self, results: EDASDataset, title, **kwargs):
     """Plot the "prediction" array of *results* against its "target" array.

     Both arrays are drawn against ``range(prediction.shape[0])``: the
     prediction as a red solid line, the target as a blue dashed line. A
     legend is added and ``plt.show()`` is called. Extra keyword arguments
     are accepted but ignored.
     """
     plt.title(title)
     print("Plotting: " + ",".join(results.ids))
     predicted: np.ndarray = results.getArray("prediction").xr.values
     expected: np.ndarray = results.getArray("target").xr.values
     xAxis = range(predicted.shape[0])
     curves = ((predicted, "r-", "prediction"),
               (expected, "b--", "target"))
     for values, fmt, label in curves:
         plt.plot(xAxis, values, fmt, label=label)
     plt.legend()
     plt.show()
Exemplo n.º 3
0
 def plotPerformance(self, results: EDASDataset, title, **kwargs):
     """Plot the "val_loss" and "loss" arrays of *results* and show the figure.

     Both arrays are drawn against ``range(val_loss.shape[0])``: validation
     loss as a red solid line, training loss as a blue dashed line. A legend
     is added and ``plt.show()`` is called. Extra keyword arguments are
     accepted but ignored.
     """
     plt.title(title)
     print("Plotting: " + ",".join(results.ids))
     validation: np.ndarray = results.getArray("val_loss").xr.values
     training: np.ndarray = results.getArray("loss").xr.values
     xAxis = range(validation.shape[0])
     curves = ((validation, "r-", "Validation Loss"),
               (training, "b--", "Training Loss"))
     for values, fmt, label in curves:
         plt.plot(xAxis, values, fmt, label=label)
     plt.legend()
     plt.show()
Exemplo n.º 4
0
 def signResult(self, result: EDASDataset, request: TaskRequest, node: WorkflowNode, **kwargs ) -> EDASDataset:
     """Stamp provenance entries onto *result* and return the same object.

     Sets "proj", "exp" and "uid" (stringified) from *request*, then every
     extra keyword argument as an entry of its own. Adds "archive" only when
     the node's "archive" parm is truthy, and calls ``result.persist()`` when
     ``node.isBranch`` is truthy.
     """
     result["proj"] = request.project
     result["exp"] = request.experiment
     result["uid"] = str(request.uid)
     for extraKey in kwargs:
         result[extraKey] = kwargs[extraKey]
     archiveParm = node.getParm("archive")
     if archiveParm:
         result["archive"] = archiveParm
     if node.isBranch:
         result.persist()
     return result
Exemplo n.º 5
0
 def getInputCrossSections(self, inputs: EDASDatasetCollection) -> Dict[int, EDASDataset]:
     """Regroup the arrays of *inputs* into cross sections by position.

     The i-th array of every dataset in *inputs* is placed in cross section i
     under the key "<datasetKey>-<arrayKey>". Each cross section is a new
     EDASDataset created with ``inputs.attrs``.

     Returns:
         Dict mapping an array's positional index (int) within its dataset
         to the cross-section dataset collecting all arrays at that position.
     """
     inputCrossSections: Dict[int, EDASDataset] = {}
     for dsKey, dset in inputs.items():
         for index, (akey, array) in enumerate(dset.arrayMap.items()):
             merge_set = inputCrossSections.get(index)
             if merge_set is None:
                 # Create the container only on first use: dict.setdefault would
                 # evaluate (and discard) a new EDASDataset on every iteration.
                 merge_set = EDASDataset(OrderedDict(), inputs.attrs)
                 inputCrossSections[index] = merge_set
             merge_set[dsKey + "-" + akey] = array
     return inputCrossSections
Exemplo n.º 6
0
 def buildProduct(self, dsid: str, request: TaskRequest, node: OpNode,
                  result_arrays: List[EDASArray], attrs: Dict[str, str]):
     """Package *result_arrays* as a named, signed EDASDataset.

     The arrays are passed through ``self.renameResults`` and wrapped with
     ``EDASDataset.init`` using *attrs*. The node's "product" and "archive"
     parms (default "") are recorded as entries, the dataset is named
     ``node.getResultId(dsid)``, and the ``signResult`` result is returned.
     """
     renamed = self.renameResults(result_arrays, node)
     product = EDASDataset.init(renamed, attrs)
     for parmName in ("product", "archive"):
         product[parmName] = node.getParm(parmName, "")
     product.name = node.getResultId(dsid)
     return self.signResult(product, request, node)
Exemplo n.º 7
0
 def mergeResults(self) -> List[EDASDataset]:
     """Combine the datasets in ``self.results`` into the final result.

     - If the first result's class is "METADATA", ``self.results`` is
       returned unchanged.
     - Otherwise the first result's "merge" entry selects the strategy. When
       it is None, ``EDASDataset.merge(self.results)`` is returned. Otherwise
       it must look like "<method>:<arg>", and
       ``self.getBestResult(method, arg)`` is returned. Both tokens are
       stripped and lower-cased; text after a second ':' is ignored.

     Assumes ``self.results`` is non-empty, because it indexes
     ``self.results[0]``.
     NOTE(review): the non-METADATA branches return whatever
     EDASDataset.merge / getBestResult return, which may not be a List as
     annotated. Confirm.

     Raises:
         ValueError: if the "merge" entry is not None and contains no ':'.
     """
     first = self.results[0]
     if first.getResultClass() == "METADATA":
         return self.results
     mergeMethod: str = first["merge"]
     if mergeMethod is None:
         return EDASDataset.merge(self.results)
     mergeToks = mergeMethod.split(":")
     if len(mergeToks) < 2:
         # Without this check a malformed parm surfaced as a bare IndexError.
         raise ValueError(
             f"Malformed merge parameter {mergeMethod!r}: expected '<method>:<arg>'")
     return self.getBestResult(mergeToks[0].strip().lower(), mergeToks[1].strip().lower())
Exemplo n.º 8
0
 def mergeEnsembles(self, op: OpNode, dset: EDASDataset) -> EDASDataset:
     """Concatenate every xarray of *dset* along ``op.ensDim`` into one array.

     The name, dims and coordinate keys of each input variable are logged
     first. The concatenated array is wrapped in an EDASArray with id
     ``dset.id`` and the first entry of ``dset.domains``, and is returned in
     a new EDASDataset carrying ``dset.attrs``.
     """
     self.logger.info(" ---> Merge Ensembles: ")
     for member in dset.xarrays:
         memberMsg = f" Variable {member.name}: dims: {member.dims}, coords: {member.coords.keys()} "
         self.logger.info(memberMsg)
     stacked: xr.DataArray = xr.concat(dset.xarrays, dim=op.ensDim)
     firstDomain = list(dset.domains)[0]
     merged = EDASArray(dset.id, firstDomain, stacked)
     entries = OrderedDict([(dset.id, merged)])
     return EDASDataset.init(entries, dset.attrs)
Exemplo n.º 9
0
 def getCachedDataset(self, snode: SourceNode) -> Optional[EDASDataset]:
     """Return the cached input for *snode* wrapped in an EDASDataset, if any.

     The cache key is ``snode.varSource.getId()`` and the lookup is
     ``EDASKCacheMgr[key]``.

     Returns:
         An EDASDataset holding the cached variable under its key. Returns
         None when the node's cache status is ``CacheStatus.Ignore``, or when
         the variable is absent and the status is ``CacheStatus.Option``.

     Raises:
         AssertionError: if the variable is absent and the status is neither
             Ignore nor Option, meaning the cached input was required.
     """
     cache_status = self.getCacheStatus(snode)
     if cache_status == CacheStatus.Ignore:
         return None
     cid = snode.varSource.getId()
     variable = EDASKCacheMgr[cid]
     if variable is not None:
         return EDASDataset.init(OrderedDict([(cid, variable)]), {})
     # Raised explicitly rather than via `assert`: asserts are stripped under
     # `python -O`, which would silently return None for a required input.
     if cache_status != CacheStatus.Option:
         raise AssertionError("Missing cached input: " + cid)
     return None
Exemplo n.º 10
0
 def processDataset(self, request: TaskRequest, dset: xr.Dataset,
                    snode: SourceNode) -> EDASDataset:
     """Wrap *dset* for source node *snode*, subset it, log, and sign it.

     The dataset's coordinate map is translated via
     ``snode.varSource.name2id``, and an EDASDataset is built with every
     variable id mapped to ``snode.domain``. The domain is cropped with
     ``request.cropDomain``, and that cropped domain is applied only when
     ``snode.domain`` is truthy. The result is signed with ``sources`` set to
     the var source id.
     """
     varSource = snode.varSource
     coordMap = Axis.getDatasetCoordMap(dset)
     filteredCoordMap = varSource.name2id(coordMap)
     domainByVar = {varId: snode.domain for varId in varSource.ids}
     edset: EDASDataset = EDASDataset.new(dset, domainByVar, filteredCoordMap)
     cropped: Domain = request.cropDomain(snode.domain, edset.inputs, snode.offset)
     if snode.domain:
         result = edset.subset(cropped)
     else:
         result = edset
     self.logger.info(
         f"###### ProcessDataset, coordMap = {filteredCoordMap}, dset coords = {list(edset.xr[0].coords.keys())}"
     )
     return self.signResult(result, request, snode, sources=varSource.getId())
Exemplo n.º 11
0
 def eof_plot(self, mtype: int, dset: EDASDataset):
     """Call ``dset.plotMaps`` with view "mol" and the given *mtype*; returns None."""
     plotOptions = dict(view="mol", mtype=mtype)
     dset.plotMaps(**plotOptions)
Exemplo n.º 12
0
 def subset(self, domId: str, dset: EDASDataset) -> EDASDataset:
     """Return *dset* restricted to domain *domId*, or *dset* itself.

     ``dset.requiresSubset(domId)`` decides whether subsetting happens; the
     domain object is obtained from ``self.domain(domId)``.
     """
     if not dset.requiresSubset(domId):
         return dset
     return dset.subset(self.domain(domId))
Exemplo n.º 13
0
 def processDataset(self, request: TaskRequest, dset: xr.Dataset, snode: SourceNode) -> EDASDataset:
     """Wrap *dset* for source node *snode*, subset it, and sign it.

     Every variable id of ``snode.varSource`` is mapped to ``snode.domain``,
     and the coordinate map is translated via ``varSource.name2id``. The
     domain cropped by ``request.cropDomain`` is applied only when
     ``snode.domain`` is truthy.
     """
     varSource = snode.varSource
     translatedCoords = varSource.name2id(Axis.getDatasetCoordMap(dset))
     domainByVar = {varId: snode.domain for varId in varSource.ids}
     edset: EDASDataset = EDASDataset.new(dset, domainByVar, translatedCoords)
     cropped: Domain = request.cropDomain(snode.domain, edset.inputs, snode.offset)
     if snode.domain:
         result = edset.subset(cropped)
     else:
         result = edset
     return self.signResult(result, request, snode, sources=varSource.getId())