def run_schema_test(schema):
    """Round-trip random data through a freshly created SplitFileManager.

    Files for ``schema`` are created under the module-level ``test_path``
    (clobbering existing ones) for the module-level ``m_tvar``. The test then
    checks that:

    * a new file set reads back as all-NaN,
    * a full write over the manager's own coordinate set reads back unchanged,
    * two sub-region writes read back unchanged, both when read directly and
      when sliced out of a full read.
    """
    sfm = dm.SplitFileManager(test_path, m_tvar)
    sfm.create_files(schema, clobber=True, leave_open=True)

    def random_block(shape):
        # Random float32 payload of the requested shape.
        return np.random.normal(size=shape).astype(np.float32)

    def write_and_verify_subset(subcs):
        payload = random_block(subcs.shape)
        sfm.set_by_coords(subcs, payload)

        # Reading the sub-region directly returns what was written...
        assert (sfm.get_padded_by_coords(subcs) == payload).all()

        # ...and so does slicing the same region out of a full read.
        full_read = sfm.get_padded_by_coords(sfm.cs)
        region = full_read[sfm.cs.get_index(subcs)]
        expected = payload.reshape(dm.simple_shape(payload.shape))
        assert (region == expected).all()

    # Nothing has been written yet, so everything must be NaN.
    initial = sfm.get_padded_by_coords(sfm.cs)
    assert np.isnan(initial).all()

    # Full-extent write and read back.
    payload = random_block(initial.shape)
    sfm.set_by_coords(sfm.cs, payload)
    assert (sfm.get_padded_by_coords(sfm.cs) == payload).all()

    # Single cell over a date range.
    write_and_verify_subset(
        mt.gen_coordset(dt.dates('dec 9 2000 - jan 15 2001'),
                        extent.ioffset[5, 2]))

    # Two cells on a single day.
    write_and_verify_subset(
        mt.gen_coordset(dt.dates('dec 12 2000'), extent.ioffset[5, 2:4]))
def run(self, period, extent, return_inputs=False, expanded=True):
    """Run the input graph and the model over ``period`` and ``extent``.

    Args:
        period: time index handed to ``mt.gen_coordset``.
        extent: spatial extent; its ``mask`` and ``cell_count`` are used.
        return_inputs: when true, the input data is returned as well.
        expanded: when true, model results (and, if requested, the inputs
            whose dataspec has a ``'cell'`` dimension) are passed through
            ``nodes.expand_dict`` with ``extent.mask``.

    Returns:
        ``mresults``, or ``(mresults, iresults)`` when ``return_inputs`` is
        true. With ``expanded=False`` the inputs are returned exactly as
        produced by ``input_runner.get_data_flat``.
    """
    coords = mt.gen_coordset(period, extent)

    # Use one identity test for both output steps so that an outputs object
    # which happens to be falsy cannot receive set_data() without having
    # been initialised first.
    has_outputs = self.outputs is not None

    if has_outputs:
        # Initialise output files if necessary.
        self.outputs.initialise(period, extent)

    iresults = self.input_runner.get_data_flat(coords, extent.mask)
    mresults = self.model_runner.run_from_mapping(
        iresults, coords.shape[0], extent.cell_count, recycle_states=False)

    if has_outputs:
        self.outputs.set_data(coords, mresults, extent.mask)

    if expanded:
        mresults = nodes.expand_dict(mresults, extent.mask)

        if return_inputs:
            # Split the inputs by whether their dataspec has a 'cell'
            # dimension; only those are expanded.
            spatial = {}
            other = {}
            for key, spec in self._dspecs.items():
                if 'cell' in spec.dims:
                    spatial[key] = iresults[key]
                else:
                    other[key] = iresults[key]

            iresults = nodes.expand_dict(spatial, extent.mask)

            # Non-spatial inputs are passed through unchanged, except for
            # constants, which are left out.
            const_inputs = self.input_runner.const_inputs
            for key, value in other.items():
                if key not in const_inputs:
                    iresults[key] = value

    if return_inputs:
        return mresults, iresults
    return mresults
def handle_current_chunk(self):
    """Write the current chunk via the execution graph and report progress.

    The buffer named in ``self.cur_chunk`` is mapped with the chunk's
    dimensions, written with ``exe_graph.set_data`` (together with the
    chunk's mask) and then reclaimed. Afterwards the chunk and period
    counters are advanced: a progress message is logged whenever more than
    5 percentage points have been completed since the last report, the graph
    is synced at the end of every period, and the processor terminates after
    the last period.
    """
    chunk_msg = self.cur_chunk
    period_idx = chunk_msg['period_idx']
    chunk_idx = chunk_msg['chunk_idx']
    buffer_id = chunk_msg['buffer']

    mapped = self.map_buffer(buffer_id).map_dims(chunk_msg['dims'])

    period = self.periods[period_idx]
    chunk_extent = self.chunks[chunk_idx]
    coords = gen_coordset(period, chunk_extent)
    self.exe_graph.set_data(coords, mapped, chunk_extent.mask)

    # The data has been handed to the graph; release the buffer.
    self.reclaim_buffer(buffer_id)

    self.cur_chunk = None
    self.cur_chunk_count += 1

    total_chunks = len(self.chunks)
    percent_done = self.cur_chunk_count / total_chunks * 100.
    if percent_done - self.completed > 5:
        self._send_log(message("completed %.2f%%" % percent_done))
        self.completed = percent_done

    if self.cur_chunk_count != total_chunks:
        return

    # Every chunk of the current period has been handled.
    finished = dt.pretty_print_period(self.periods[self.cur_period_count])
    summary = "Completed period %s - %d of %d" % (
        finished, self.cur_period_count + 1, len(self.periods))
    self._send_log(message(summary))

    self.completed = 0
    self.cur_chunk_count = 0
    self.cur_period_count += 1
    self.exe_graph.sync_all()

    if self.cur_period_count == len(self.periods):
        self._send_log(message("terminate"))
        self.terminate()
def init_files(self, period, extent):
    """Create the output files for ``period`` and ``extent``.

    Builds a ``SplitFileManager`` (stored on ``self.fm``) for this node's
    variable and asks it to create its files. Existing files are clobbered
    only when the file mode is ``'w'``. Marks the node as open afterwards.
    """
    variable = mt.Variable(self.nc_var, 'mm')
    coordset = mt.gen_coordset(period, extent)
    mapped_var = mt.MappedVariable(variable, coordset, self.dtype)

    self.fm = SplitFileManager(self.path, mapped_var)

    overwrite = self._file_mode == 'w'
    self.fm.create_files(
        self.schema,
        False,
        overwrite,
        chunksize=self.chunksizes,
        ncparams=self.ncparams,
    )
    self._open = True
def build_io(self, node_settings):
    """Build the per-catchment shared-memory input and output dictionaries.

    Assumes the inputs can be loaded through the execution graph described
    by ``node_settings['inputs'].mapping``.

    Side effects:
        * ``node_settings['input_dataspecs']`` is set from the input graph.
        * ``self.igraph`` is set to the input graph.
        * ``self.shm_outputs`` / ``self.outputs`` and ``self.shm_inputs`` are
          rebuilt with one entry per catchment id, and the input arrays are
          filled with the data read for that catchment over
          ``self.run_period``.

    Args:
        node_settings: settings mapping; ``'inputs'`` and
            ``'output_variables'`` are read from it.
    """
    input_settings = node_settings['inputs']

    self.shm_inputs = {}
    self.shm_outputs = {}
    self.outputs = {}

    igraph = graph.ExecutionGraph(input_settings.mapping)
    node_settings['input_dataspecs'] = igraph.get_dataspecs(True)

    # Publish the graph once, outside the loop, so the attribute exists even
    # when there are no catchments to iterate over.
    self.igraph = igraph

    for cid in self.catchments:
        catchment = self.catchments[cid]

        ovs = node_settings['output_variables']
        self.shm_outputs[cid] = create_shm_dict(
            ovs, (self.timesteps, catchment.cell_count))
        self.outputs[cid] = shm_to_nd_dict(**self.shm_outputs[cid])

        coords = gen_coordset(self.run_period, catchment)
        input_build = igraph.get_data_flat(coords, catchment.mask)

        # Only non-constant inputs get a buffer. This is an exact type
        # test, as in the original code (subclasses of ConstNode are not
        # treated as constants).
        dynamic_names = [
            n for n in igraph.input_graph
            if type(igraph.input_graph[n]['exe']) is not ConstNode
        ]

        # Values without a ``shape`` attribute are recorded as None.
        shapes = {
            n: getattr(input_build[n], 'shape', None) for n in dynamic_names
        }

        self.shm_inputs[cid] = create_shm_dict_inputs(shapes)
        inputs_np = shm_to_nd_dict_inputs(**self.shm_inputs[cid])

        for n in dynamic_names:
            shape = shapes[n]
            if shape is None or len(shape) == 0:
                # Shapeless / zero-dimensional value: store in slot 0.
                inputs_np[n][0] = input_build[n]
            else:
                inputs_np[n][...] = input_build[n][...]
def test_get_padded_by_coords():
    """A padded read over a 30x30 window for 2011 has the coordset's shape."""
    from awrams.utils.io.data_mapping import SplitFileManager
    from awrams.utils.mapping_types import gen_coordset
    import awrams.utils.datetools as dt

    here = os.path.dirname(__file__)
    path = os.path.join(here, '..', '..', 'test_data', 'simulation')

    sfm = SplitFileManager.open_existing(
        path, 'temp_min_day_*.nc', 'temp_min_day')

    window = sfm.get_extent().ioffset[200:230, 200:230]
    coords = gen_coordset(dt.dates('2011'), window)

    assert sfm.get_padded_by_coords(coords).shape == coords.shape
def get_data(self, coords):
    """Fetch data for ``coords`` from the child node and downscale it.

    A target extent is built from the first latitude/longitude of ``coords``,
    the number of latitudes and longitudes, and this node's cell size and
    latitude orientation. ``get_extent_scale_indices`` maps that target onto
    the child's extent; the child's data for that sub-extent is then passed
    through ``downscale`` with the returned scale factor.
    """
    lats = coords['latitude']
    lons = coords['longitude']

    origin = geo.GeoPoint.from_degrees(lats[0], lons[0])
    georef = geo.GeoReference(
        origin, len(lats), len(lons), self.cell_size,
        lat_orient=self.lat_orient)
    target = extents.Extent(georef)

    # The middle (offset) element of the result is not needed here.
    child_extent, _offset, scale = get_extent_scale_indices(
        self.cextent, target)

    child_coords = gen_coordset(coords['time'].index, child_extent)
    return downscale(self.child.get_data(child_coords), scale)
def setup_var_coords():
    """Populate the module-level test fixtures.

    Sets the globals ``extent`` (a 10x10 window of a 1000x1000 grid with
    0.05 cell size), ``m_tvar`` (the float32 ``test_var`` mapped over that
    extent for 'dec 2000 - jan 25 2001') and ``test_path`` (a ``file_tests``
    directory next to this file), then removes ``test_path`` if present.
    """
    global m_tvar, extent, test_path

    grid = geo.GeoReference((0, 0), 1000, 1000, 0.05)
    extent = extents.Extent(grid).ioffset[0:10, 0:10]

    period = dt.dates('dec 2000 - jan 25 2001')
    coordset = mt.gen_coordset(period, extent)
    m_tvar = mt.MappedVariable(
        mt.Variable('test_var', 'mm'), coordset, np.float32)

    test_path = os.path.join(os.path.dirname(__file__), 'file_tests')
    # Second positional argument is ignore_errors: a missing directory is fine.
    shutil.rmtree(test_path, True)
def run(self, period, extent, return_inputs=False):
    """Run the input graph and the model over ``period`` and ``extent``.

    Args:
        period: time index handed to ``mt.gen_coordset``.
        extent: spatial extent; its ``mask`` and ``cell_count`` are used.
        return_inputs: when true, the flat input data is returned as well.

    Returns:
        ``mresults``, or ``(mresults, iresults)`` when ``return_inputs`` is
        true.
    """
    coords = mt.gen_coordset(period, extent)

    # Use one identity test for both output steps so that an outputs object
    # which happens to be falsy cannot receive set_data() without having
    # been initialised first.
    has_outputs = self.outputs is not None

    if has_outputs:
        # Initialise output files if necessary.
        self.outputs.initialise(coords[0])

    iresults = self.input_runner.get_data_flat(coords, extent.mask)
    mresults = self.model_runner.run_from_mapping(
        iresults, coords.shape[0], extent.cell_count, True)

    if has_outputs:
        self.outputs.set_data(coords, mresults, extent.mask)

    if return_inputs:
        return mresults, iresults
    return mresults
def build_multi_coords(alloc_info, extent_map, period):
    """Build coordinate sets and masks for every allocated catchment slice.

    Each entry of ``alloc_info['catchments']`` names a catchment (``'cid'``),
    a cell count (``'ncells'``) and a starting cell (``'start_cell'``); the
    matching extent from ``extent_map`` is split accordingly.

    Returns:
        ``(coords, masks)``: two lists in allocation order, holding the
        coordinate set over ``period`` and the mask of each sub-extent.
    """
    sub_extents = [
        extents.split_extent(
            extent_map[alloc['cid']], alloc['ncells'], alloc['start_cell'])
        for alloc in alloc_info['catchments']
    ]

    pairs = [(gen_coordset(period, sub), sub.mask) for sub in sub_extents]

    coords = [coordset for coordset, _ in pairs]
    masks = [mask for _, mask in pairs]
    return coords, masks
def run_prepack(self, iresults, period, extent):
    """Run the model with pre-packaged inputs (used for calibration).

    Unlike ``run``, the input graph is not evaluated; ``iresults`` is passed
    straight to the model runner.

    Args:
        iresults: pre-built input data for the model runner.
        period: time index handed to ``mt.gen_coordset``.
        extent: spatial extent; its ``mask`` and ``cell_count`` are used.

    Returns:
        The model results from ``model_runner.run_from_mapping``.
    """
    coords = mt.gen_coordset(period, extent)

    # Use one identity test for both output steps so that an outputs object
    # which happens to be falsy cannot receive set_data() without having
    # been initialised first.
    has_outputs = self.outputs is not None

    if has_outputs:
        # Initialise output files if necessary.
        self.outputs.initialise(coords[0])

    mresults = self.model_runner.run_from_mapping(
        iresults, coords.shape[0], extent.cell_count, True)

    if has_outputs:
        self.outputs.set_data(coords, mresults, extent.mask)

    return mresults
def read_active_chunk(self):
    """Read the active chunk's inputs into a buffer and describe it.

    The execution graph is evaluated for the current chunk and period, and
    the results are copied into a buffer obtained with
    ``get_buffer_safe('inputs')``. When ``self.recycle_states`` is set, the
    stored state buffers for this chunk are copied in as ``init_<key>``.

    Returns:
        A ``'chunk'`` message whose content carries the chunk and period
        indices, the buffer id, the mapped dimensions and the list of keys
        that were filled.
    """
    extent = self.chunks[self.cur_chunk_idx]
    period = self.periods[self.cur_period_idx]

    coords = mt.gen_coordset(period, extent)
    results = self.exe_graph.get_data_flat(coords, extent.mask)

    buf_id, out_buf = self.get_buffer_safe('inputs')

    #+++ Ensure all dimensions filled out; static dimensions will always be max_dims
    # Can we guarantee this?
    dims = out_buf.max_dims.copy()
    dims.update(cell=extent.cell_count, time=len(period))

    target_data = out_buf.map_dims(dims)

    for key, value in results.items():
        try:
            target_data[key][...] = value
        except BaseException:
            # Report the offending key, then let the error propagate.
            print("Failed setting on", key, value)
            raise

    valid_keys = list(results.keys())

    if self.recycle_states:
        for state_key in self.state_keys:
            init_key = 'init_' + state_key
            target_data[init_key][...] = \
                self.state_buffers[self.cur_chunk_idx][state_key]
            if init_key not in valid_keys:
                valid_keys.append(init_key)

    chunk_msg = message('chunk')
    chunk_msg['content'].update(
        chunk_idx=self.cur_chunk_idx,
        period_idx=self.cur_period_idx,
        buffer=buf_id,
        dims=dims,
        valid_keys=valid_keys,
    )
    return chunk_msg
def test_get_padded_by_coords():
    """A padded read over a 30x30 window for 2011 has the coordset's shape.

    The climate test data is located through the system profile's
    ``DATA_PATHS['BASE_DATA']`` setting.
    """
    from awrams.utils.io.data_mapping import SplitFileManager
    from awrams.utils.mapping_types import gen_coordset
    import awrams.utils.datetools as dt

    settings = config_manager.get_system_profile().get_settings()
    base_data = settings['DATA_PATHS']['BASE_DATA']
    path = os.path.join(
        base_data, 'test_data', 'simulation', 'climate', 'temp_min_day')

    sfm = SplitFileManager.open_existing(
        path, 'temp_min_day_*.nc', 'temp_min_day')

    window = sfm.get_extent().ioffset[200:230, 200:230]
    coords = gen_coordset(dt.dates('2011'), window)

    assert sfm.get_padded_by_coords(coords).shape == coords.shape
def resample_data(in_path,
                  in_pattern,
                  variable,
                  period,
                  out_path,
                  to_freq,
                  method,
                  mode='w',
                  enforce_mask=True,
                  extent=None,
                  use_weights=False):
    """Resample a file-backed variable in time and write the result.

    The input files are read in spatial blocks by a streaming reader, fanned
    out to worker processes that resample each block to ``to_freq``, and
    written to ``out_path`` by a chunk writer.

    Args:
        in_path: directory holding the input files.
        in_pattern: glob pattern (relative to ``in_path``) of the input files.
        variable: name of the variable to open.
        period: period to resample.
        out_path: destination handed to ``FlatFileManager``.
        to_freq: target frequency.
        method: ``'sum'`` or ``'mean'``.
        mode: ``'w'`` clobbers existing output files; any other value does
            not.
        enforce_mask: forwarded to the workers and the writer.
        extent: extent to process. If omitted, the full (unmasked) extent of
            the input is used.
        use_weights: intended for unequally binned conversions (monthly to
            annual means, for example). NOTE(review): the value passed in is
            currently overridden - weights are switched on automatically
            when ``method == 'mean'`` and the input timeframe validates as
            ``'MONTHLY'``, and off otherwise. Confirm whether the argument
            should be honoured.
    """
    from glob import glob
    import time
    import numpy as np
    from awrams.utils.messaging import reader as nr
    from awrams.utils.messaging import writer as nw
    from awrams.utils.messaging.brokers import OrderedFanInChunkBroker, FanOutChunkBroker
    from awrams.utils.messaging.general import message
    from awrams.utils.messaging.buffers import create_managed_buffers
    from awrams.utils.processing.chunk_resampler import ChunkedTimeResampler
    from awrams.utils.catchments import subdivide_extent
    from awrams.utils import datetools as dt
    from awrams.utils import mapping_types as mt
    from awrams.utils.io import data_mapping as dm

    start = time.time()

    NWORKERS = 2
    read_ahead = 3
    writemax = 3
    BLOCKSIZE = 128
    nbuffers = (NWORKERS * 2) + read_ahead + writemax

    # --- Build the 'standard queues' -------------------------------------
    # This should be wrapped up somewhere else for various topologies...

    # Receives all messages from clients
    control_master = mp.Queue()

    worker_q = mp.Queue()
    for i in range(NWORKERS):
        worker_q.put(i)

    # Reader queues
    chunk_out_r = mp.Queue(read_ahead)
    reader_in = dict(control=mp.Queue())
    reader_out = dict(control=control_master, chunks=chunk_out_r)

    # Writer queues
    chunk_in_w = mp.Queue(writemax)
    writer_in = dict(control=mp.Queue(), chunks=chunk_in_w)
    writer_out = dict(control=control_master)

    # Fan-out / fan-in queues
    fanout_in = dict(control=mp.Queue(), chunks=chunk_out_r, workers=worker_q)
    fanout_out = dict(control=control_master)
    fanin_in = dict(control=mp.Queue())
    fanin_out = dict(control=control_master, out=chunk_in_w, workers=worker_q)

    # Worker queues: one in/out pair per worker, registered with the brokers
    # under the worker's index.
    work_inq = []
    work_outq = []
    for i in range(NWORKERS):
        work_inq.append(mp.Queue())
        fanout_out[i] = work_inq[-1]
        work_outq.append(mp.Queue())
        fanin_in[i] = work_outq[-1]

    # --- End standard queues ---------------------------------------------

    # Only filter by year when the input is split over several files.
    infiles = glob(in_path + '/' + in_pattern)
    ff = dm.filter_years(period) if len(infiles) > 1 else None

    sfm = dm.SplitFileManager.open_existing(in_path,
                                            in_pattern,
                                            variable,
                                            ff=ff)
    in_freq = sfm.get_frequency()

    split_periods = [period]
    if hasattr(in_freq, 'freqstr') and in_freq.freqstr == 'D':
        # Force splitting so that flat files don't end up getting loaded
        # entirely into memory!
        # Also a bit of a hack to deal with PeriodIndex/DTI issues...
        split_periods = dt.split_period(
            dt.resample_dti(period, 'd', as_period=False), 'a')

    in_periods = [dt.resample_dti(p, in_freq) for p in split_periods]
    in_pmap = sfm.get_period_map_multi(in_periods)
    out_periods = [dt.resample_dti(p, to_freq) for p in in_periods]

    if extent is None:
        extent = sfm.ref_ds.get_extent(True)
        if extent.mask.size == 1:
            # Broadcast a scalar mask to the full extent shape. The builtin
            # ``bool`` is used because the ``np.bool`` alias no longer exists
            # in current NumPy releases.
            extent.mask = (np.ones(extent.shape) * extent.mask).astype(bool)

    sub_extents = subdivide_extent(extent, BLOCKSIZE)
    chunks = [nr.Chunk(*s.indices()) for s in sub_extents]

    out_period = dt.resample_dti(period, to_freq)
    out_cs = mt.gen_coordset(out_period, extent)

    v = mt.Variable.from_ncvar(sfm.ref_ds.awra_var)
    in_dtype = sfm.ref_ds.awra_var.dtype
    sfm.close_all()

    # Overrides the caller's value (see the docstring note).
    use_weights = (method == 'mean'
                   and dt.validate_timeframe(in_freq) == 'MONTHLY')

    # Need a way of formalising multiple buffer pools for different classes
    # of work..
    max_inplen = max(len(p) for p in in_periods)
    bufshape = (max_inplen, BLOCKSIZE, BLOCKSIZE)

    shared_buffers = {}
    shared_buffers['main'] = create_managed_buffers(nbuffers,
                                                    bufshape,
                                                    build=False)

    mvar = mt.MappedVariable(v, out_cs, in_dtype)
    sfm = dm.FlatFileManager(out_path, mvar)

    CLOBBER = mode == 'w'
    sfm.create_files(False, CLOBBER, chunksize=(1, BLOCKSIZE, BLOCKSIZE))

    outfile_maps = {
        v.name:
        dict(nc_var=v.name,
             period_map=sfm.get_period_map_multi(out_periods))
    }
    infile_maps = {v.name: dict(nc_var=v.name, period_map=in_pmap)}

    reader = nr.StreamingReader(reader_in, reader_out, shared_buffers,
                                infile_maps, chunks, in_periods)
    writer = nw.MultifileChunkWriter(writer_in,
                                     writer_out,
                                     shared_buffers,
                                     outfile_maps,
                                     sub_extents,
                                     out_periods,
                                     enforce_mask=enforce_mask)

    fanout = FanOutChunkBroker(fanout_in, fanout_out)
    fanin = OrderedFanInChunkBroker(fanin_in, fanin_out, NWORKERS,
                                    len(chunks))

    fanout.start()
    fanin.start()

    workers = []
    w_control = []
    for inq, outq in zip(work_inq, work_outq):
        w_in = dict(control=mp.Queue(), chunks=inq)
        w_out = dict(control=control_master, chunks=outq)
        w = ChunkedTimeResampler(w_in,
                                 w_out,
                                 shared_buffers,
                                 sub_extents,
                                 in_periods,
                                 to_freq,
                                 method,
                                 enforce_mask=enforce_mask,
                                 use_weights=use_weights)
        workers.append(w)
        w_control.append(w_in['control'])
        w.start()

    writer.start()
    reader.start()

    # The writer finishes once everything has been written; after that the
    # remaining processes are told to terminate.
    writer.join()

    fanout_in['control'].put(message('terminate'))
    fanin_in['control'].put(message('terminate'))
    for control in w_control:
        control.put(message('terminate'))

    # Drain the control queue: four messages here, then one more per worker
    # after joining it.
    for _ in range(4):
        control_master.get()

    for worker in workers:
        worker.join()
        control_master.get()

    reader.join()
    fanout.join()
    fanin.join()

    end = time.time()
    logger.info("elapsed time: %ss", end - start)
def build_fixed_mapping(mapping, model_keys, dyn_keys, period, extent):
    """Pre-compute the static part of a mapping and keep the dynamic part.

    The mapping is split into the nodes reachable from ``dyn_keys`` (the
    dynamic tree) and everything else (static). The static graph is evaluated
    once over ``period`` and ``extent``; the static outputs that are needed
    are replaced by ``static`` nodes holding the computed data.

    Args:
        mapping (dict): imap as produced by model.get_default_mapping()
        model_keys (list): keys that must be available as outputs of this graph
        dyn_keys (list): keys for dynamic nodes (ie runtime computed)
        period (DatetimeIndex): period over which to generate the mapping
        extent (extents.Extent): extent (or list of extents) over which to
            generate the mapping

    Returns:
        dict: mapping (containing static data as well as node mappings)
    """
    mapping = expand_const(mapping)

    # Heads of the dynamic part: mapping keys that were named as dynamic.
    dynamic_heads = [key for key, _node in mapping.items() if key in dyn_keys]
    dynamic_mapping = get_output_tree(dynamic_heads, mapping)
    dynamic_keys = set(dynamic_mapping)

    static_keys = dynamic_keys.symmetric_difference(set(mapping.keys()))
    static_mapping = {key: mapping[key] for key in static_keys}

    # Set of (static) keys required by the dynamic part of the graph.
    dynamic_inputs = {
        key for node in dynamic_mapping.values() for key in node.inputs
    }
    extra_static_keys = dynamic_inputs.difference(dynamic_keys)

    req_static_outputs = set(model_keys).intersection(static_keys).union(
        extra_static_keys)

    static_egraph = ExecutionGraph(static_mapping)
    dspecs = static_egraph.get_dataspecs()

    # Accept a single extent as well as a list of extents.
    if not isinstance(extent, list):
        extent = [extent]

    coords = [gen_coordset(period, sub) for sub in extent]
    masks = [sub.mask for sub in extent]

    res = static_egraph.get_data_flat(coords, masks, multi=True)
    static_data = {key: res[key] for key in req_static_outputs}
    # Drop the full result set; only the required outputs are kept.
    del res

    out_mapping = {
        key: static(static_data[key], dspecs[key],
                    static_mapping[key].out_type)
        for key in req_static_outputs
    }
    out_mapping.update(dynamic_mapping)
    return out_mapping
def coordset_from_domain(in_domain):
    """Return ``(coordset, mask)`` for a domain.

    The coordinate set is generated from the value objects of the domain's
    ``'time'`` and ``'latlon'`` coordinate mappings; the mask is taken from
    the extent of the ``'latlon'`` mapping.
    """
    domain_coords = in_domain.coords
    time_values = domain_coords['time'].mapping.value_obj
    latlon_mapping = domain_coords['latlon'].mapping

    cs = gen_coordset(time_values, latlon_mapping.value_obj)
    return cs, latlon_mapping.extent.mask
def handle_current_chunk(self):
    '''
    Send back any buffers we've received, and generate some output
    Obligatory for ChunkedProcessors

    Reads the chunk description from self.cur_chunk['content'], feeds the
    buffered values into the execution graph, runs the model over the
    chunk's dimensions, then:
      - reclaims the incoming buffer,
      - copies the final states into a 'states' buffer and puts a 'states'
        message on the 'state_return' queue,
      - copies all other results into an 'outputs' buffer and sends it on
        as the outgoing 'chunk' message.
    '''
    c = self.cur_chunk['content']
    period_idx = c['period_idx']
    chunk_idx = c['chunk_idx']
    buf = c['buffer']
    dims = c['dims']
    valid = c['valid_keys']
    bgroup = self.map_buffer(buf)
    bdata = bgroup.map_dims(dims)
    # Basically need to ensure that we're getting states from the sender rather
    # than trying to generate them locally...
    if period_idx > 0:
        if not self.recycling:
            # One-off rewiring, done the first time a chunk with
            # period_idx > 0 arrives: every 'init_<state>' node is removed
            # from the process graph (if present) and replaced by a constant
            # input node whose value is filled in below.
            for k in self.state_keys:
                init_k = 'init_' + k
                if init_k in self.exe_graph.process_graph:
                    del self.exe_graph.process_graph[init_k]
                self.exe_graph.input_graph[init_k] = dict(
                    exe=nodes.ConstNode(None))
            self.recycling = True
    # Point the graph's input nodes at the values held in the buffer.
    for k in valid:
        self.exe_graph.input_graph[k]['exe'].value = bdata[k]
    extent = self.chunks[chunk_idx]
    period = self.periods[period_idx]
    coords = mt.gen_coordset(period, extent)
    graph_results = self.exe_graph.get_data_flat(coords, extent.mask)
    state_buf_id, state_buf = self.get_buffer_safe('states')
    output_buf_id, output_buf = self.get_buffer_safe('outputs')
    #Run the model!
    #data = self.process_data(in_data,period_idx,chunk_idx)
    model_results = self.runner.run_over_dimensions(graph_results, dims)
    # The incoming buffer is only handed back after the model has run.
    self.reclaim_buffer(buf)
    states_np = state_buf.map_dims(dims)
    for k in self.state_keys:
        states_np[k][...] = model_results['final_states'][k]
    # Note: 'c' is rebound from here on to the content of each new message.
    state_msg = message('states')
    c = state_msg['content']
    c['chunk_idx'] = chunk_idx
    c['period_idx'] = period_idx
    c['buffer'] = state_buf_id
    #self.send('state_return',state_msg)
    self.qout['state_return'].put(state_msg)
    output_np = output_buf.map_dims(dims)
    # Everything except the final states goes into the output buffer.
    for k in model_results.keys():
        if k != 'final_states':
            output_np[k][...] = model_results[k]
    self.out_chunk = message('chunk')
    c = self.out_chunk['content']
    c['period_idx'] = period_idx
    c['chunk_idx'] = chunk_idx
    c['dims'] = dims
    c['buffer'] = output_buf_id
    self.send_out_chunk()