Code example #1
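# Round-trip test for dm.SplitFileManager: create files from a schema, check they start
# out all-NaN, then write random data over the full coordinate set and over two
# sub-coordinate sets built with mt.gen_coordset, reading each back for comparison.
# Relies on the module-level test_path, m_tvar and extent set up in setup_var_coords()
# (code example #8).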
def run_schema_test(schema):
    new_sfm = dm.SplitFileManager(test_path, m_tvar)
    new_sfm.create_files(schema, clobber=True, leave_open=True)

    data = new_sfm.get_padded_by_coords(new_sfm.cs)
    assert (np.isnan(data).all())

    newdata = np.random.normal(size=data.shape).astype(np.float32)
    new_sfm.set_by_coords(new_sfm.cs, newdata)

    data = new_sfm.get_padded_by_coords(new_sfm.cs)
    assert ((data == newdata).all())

    subcs = mt.gen_coordset(dt.dates('dec 9 2000 - jan 15 2001'),
                            extent.ioffset[5, 2])

    newdata = np.random.normal(size=subcs.shape).astype(np.float32)
    new_sfm.set_by_coords(subcs, newdata)
    assert ((new_sfm.get_padded_by_coords(subcs) == newdata).all())
    assert ((new_sfm.get_padded_by_coords(
        new_sfm.cs)[new_sfm.cs.get_index(subcs)] == newdata.reshape(
            dm.simple_shape(newdata.shape))).all())

    subcs = mt.gen_coordset(dt.dates('dec 12 2000'), extent.ioffset[5, 2:4])

    newdata = np.random.normal(size=subcs.shape).astype(np.float32)
    new_sfm.set_by_coords(subcs, newdata)
    assert ((new_sfm.get_padded_by_coords(subcs) == newdata).all())
    assert ((new_sfm.get_padded_by_coords(
        new_sfm.cs)[new_sfm.cs.get_index(subcs)] == newdata.reshape(
            dm.simple_shape(newdata.shape))).all())
Code example #2
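    # Run the model over a period/extent: build the coordinate set, initialise any
    # configured outputs, pull flattened input data, run the model from that mapping,
    # write the results, and optionally expand masked results back onto the full
    # spatial grid (returning the inputs as well when requested).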
    def run(self, period, extent, return_inputs=False, expanded=True):
        coords = mt.gen_coordset(period, extent)
        if self.outputs:
            ### initialise output files if necessary
            self.outputs.initialise(period, extent)

        iresults = self.input_runner.get_data_flat(coords, extent.mask)
        mresults = self.model_runner.run_from_mapping(iresults,
                                                      coords.shape[0],
                                                      extent.cell_count,
                                                      recycle_states=False)

        if self.outputs is not None:
            self.outputs.set_data(coords, mresults, extent.mask)

        if expanded:
            mresults = nodes.expand_dict(mresults, extent.mask)
            if return_inputs:
                ires_spatial = {k: iresults[k]
                                for k, v in self._dspecs.items()
                                if 'cell' in v.dims}
                ires_other = {k: iresults[k]
                              for k, v in self._dspecs.items()
                              if 'cell' not in v.dims}
                iresults = nodes.expand_dict(ires_spatial, extent.mask)
                for k, v in ires_other.items():
                    if k not in self.input_runner.const_inputs:
                        iresults[k] = v

        if return_inputs:
            return mresults, iresults
        else:
            return mresults
Code example #3
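    # Worker-side handling of an incoming data chunk: map the shared-memory buffer named
    # in the chunk message, rebuild the chunk's coordinate set, push the buffered data
    # into the execution graph, reclaim the buffer, and emit progress and
    # period-completion log messages.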
    def handle_current_chunk(self):
        period_idx = self.cur_chunk['period_idx']

        # if period_idx != self.cur_period_idx:
        #     self._set_cur_period(period_idx)

        chunk_idx = self.cur_chunk['chunk_idx']

        bgroup = self.map_buffer(self.cur_chunk['buffer'])
        bdata = bgroup.map_dims(self.cur_chunk['dims'])

        coords = gen_coordset(self.periods[period_idx], self.chunks[chunk_idx])
        self.exe_graph.set_data(coords, bdata, self.chunks[chunk_idx].mask)

        self.reclaim_buffer(self.cur_chunk['buffer'])
        #     in_data = self.map_buffer(shm_buf,data['shape'])
        #
        #     pwrite_idx = self.time_indices[v]
        #     chunk = self.chunks[chunk_idx]
        #
        #     write_idx = (pwrite_idx,chunk.x,chunk.y)
        #
        #     # ENFORCE_MASK +++
        #     if self.enforce_mask:
        #         subex = self.extents[chunk_idx]
        #         in_data[:,subex.mask==True] = FILL_VALUE
        #
        #     self._write_slice(write_var,in_data,write_idx)

        self.cur_chunk = None
        self.cur_chunk_count += 1

        completed = self.cur_chunk_count / len(self.chunks) * 100.
        if completed - self.completed > 5:
            self._send_log(message("completed %.2f%%" % completed))
            # logger.info("completed %.2f%%",completed)
            self.completed = completed

        if self.cur_chunk_count == len(self.chunks):
            # logger.info("Completed period %s - %d of %d",dt.pretty_print_period(self.periods[self.cur_period_count]),self.cur_period_count+1,len(self.periods))
            self._send_log(
                message("Completed period %s - %d of %d" %
                        (dt.pretty_print_period(
                            self.periods[self.cur_period_count]),
                         self.cur_period_count + 1, len(self.periods))))
            self.completed = 0
            self.cur_chunk_count = 0
            self.cur_period_count += 1
            self.exe_graph.sync_all()
            if self.cur_period_count == len(self.periods):
                self._send_log(message("terminate"))
                self.terminate()
Code example #4
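    # Create the output netCDF files for this writer: build a MappedVariable from the
    # period/extent coordinate set and hand it to a SplitFileManager, clobbering any
    # existing files only when the file mode is 'w'.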
    def init_files(self, period, extent):
        v = mt.Variable(self.nc_var, 'mm')
        cs = mt.gen_coordset(period, extent)

        mvar = mt.MappedVariable(v, cs, self.dtype)

        self.fm = SplitFileManager(self.path, mvar)

        clobber = self._file_mode == 'w'
        self.fm.create_files(self.schema,
                             False,
                             clobber,
                             chunksize=self.chunksizes,
                             ncparams=self.ncparams)

        self._open = True
Code example #5
File: node.py Project: nemochina2008/awra_cms
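    # Build per-catchment shared-memory I/O for this node's cell workers: construct an
    # input ExecutionGraph from the node settings, then, for each catchment, allocate
    # shared-memory output arrays, evaluate the input graph over the run period and copy
    # the non-constant inputs into shared-memory input arrays.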
    def build_io(self, node_settings):
        '''
        Assume that we have NCD files we can load from - probably there are other sources...
        Build shared memory dictionaries for each of our cell_workers
        '''
        # print("Building inputs...")
        input_settings = node_settings['inputs']

        self.shm_inputs = {}
        self.shm_outputs = {}
        self.outputs = {}
        # inputs = {}

        igraph = graph.ExecutionGraph(input_settings.mapping)
        node_settings['input_dataspecs'] = igraph.get_dataspecs(True)

        for cid in self.catchments:
            ovs = node_settings['output_variables']
            self.shm_outputs[cid] = create_shm_dict(
                ovs, (self.timesteps, self.catchments[cid].cell_count))
            self.outputs[cid] = shm_to_nd_dict(**self.shm_outputs[cid])

            coords = gen_coordset(self.run_period, self.catchments[cid])
            input_build = igraph.get_data_flat(coords,
                                               self.catchments[cid].mask)

            # self.shm_inputs[cid] = {}

            shapes = {}
            for n in igraph.input_graph:
                if type(igraph.input_graph[n]['exe']) is not ConstNode:
                    try:
                        shapes[n] = input_build[n].shape
                    except AttributeError:
                        shapes[n] = None

            self.igraph = igraph

            self.shm_inputs[cid] = create_shm_dict_inputs(shapes)
            _inputs_np = shm_to_nd_dict_inputs(**self.shm_inputs[cid])

            for n in igraph.input_graph:
                if type(igraph.input_graph[n]['exe']) is not ConstNode:

                    # inputs[cid][n] = input_build[n]

                    if shapes[n] is None or len(shapes[n]) == 0:
                        _inputs_np[n][0] = input_build[n]
                    else:
                        _inputs_np[n][...] = input_build[n][...]
Code example #6
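# Read test for SplitFileManager.open_existing: open split daily temp_min files from the
# bundled test data, take a 30x30 cell window of the dataset extent, and check that a
# padded read over the 2011 coordinate set has the expected shape.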
def test_get_padded_by_coords():
    from awrams.utils.io.data_mapping import SplitFileManager
    from awrams.utils.mapping_types import gen_coordset
    import awrams.utils.datetools as dt

    path = os.path.join(os.path.dirname(__file__),'..','..','test_data','simulation')

    sfm = SplitFileManager.open_existing(path,'temp_min_day_*.nc','temp_min_day')
    # return sfm
    extent = sfm.get_extent().ioffset[200:230,200:230]
    period = dt.dates('2011')
    coords = gen_coordset(period,extent)

    data = sfm.get_padded_by_coords(coords)
    assert data.shape == coords.shape
Code example #7
File: rescaling.py Project: awracms/awra_cms_old
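    # Rescaling node: build a target extent from the requested coordinates, derive the
    # offset/scale indices against the source extent, fetch the child node's data for the
    # corresponding source coordinate set, and downscale it onto the target grid.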
    def get_data(self,coords):
        origin = geo.GeoPoint.from_degrees(coords['latitude'][0],coords['longitude'][0])
        nlats,nlons = len(coords['latitude']), len(coords['longitude'])

        georef = geo.GeoReference(origin,nlats,nlons,self.cell_size,lat_orient=self.lat_orient)
        
        target = extents.Extent(georef)
        
        ss,offset,scalefac = get_extent_scale_indices(self.cextent,target)
        
        data = self.child.get_data(gen_coordset(coords['time'].index,ss))
        
        dres = downscale(data,scalefac)
        
        return dres
Code example #8
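# Module-level test fixture: defines a 10x10 cell extent on a 0.05 degree grid, a
# Dec 2000 to 25 Jan 2001 run period and the MappedVariable (m_tvar) used by the file
# tests, and removes any output directory left over from a previous run.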
def setup_var_coords():
    global m_tvar
    global extent
    global test_path

    georef = geo.GeoReference((0,0),1000,1000,0.05)
    extent = extents.Extent(georef).ioffset[0:10,0:10]

    period = dt.dates('dec 2000 - jan 25 2001')
    tvar = mt.Variable('test_var','mm')
    
    m_tvar = mt.MappedVariable(tvar,mt.gen_coordset(period,extent),np.float32)

    test_path = os.path.join(os.path.dirname(__file__),'file_tests')

    shutil.rmtree(test_path, ignore_errors=True)
Code example #9
File: ondemand.py Project: nemochina2008/awra_cms
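    # On-demand simulator run(): generate the coordinate set, initialise outputs if
    # configured, fetch flattened inputs and run the model (the final positional flag
    # appears to correspond to recycle_states, cf. code example #2), write any outputs,
    # and optionally return the raw inputs alongside the results.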
    def run(self,period,extent,return_inputs=False):
        coords = mt.gen_coordset(period,extent)
        if self.outputs:
            ### initialise output files if necessary
            self.outputs.initialise(coords[0])

        iresults = self.input_runner.get_data_flat(coords,extent.mask)
        mresults = self.model_runner.run_from_mapping(iresults,coords.shape[0],extent.cell_count,True)

        if self.outputs is not None:
            self.outputs.set_data(coords,mresults,extent.mask)

        if return_inputs:
            return mresults,iresults
        else:
            return mresults
Code example #10
File: worker.py Project: awracms/awra_cms_old
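# Build coordinate sets and masks for a multi-catchment work allocation: split each
# catchment's extent down to its allocated cell range, then pair the gen_coordset output
# with the corresponding sub-extent mask.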
def build_multi_coords(alloc_info, extent_map, period):

    esubs = []

    for c in alloc_info['catchments']:
        esubs.append(
            extents.split_extent(extent_map[c['cid']], c['ncells'],
                                 c['start_cell']))

    coords = []
    masks = []

    for e in esubs:
        coords.append(gen_coordset(period, e))
        masks.append(e.mask)

    return coords, masks
Code example #11
File: ondemand.py Project: nemochina2008/awra_cms
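    # Calibration entry point: like run(), but takes pre-packaged inputs, so the input
    # graph is bypassed and the mapping is passed straight to the model runner.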
    def run_prepack(self,iresults,period,extent):
        '''
        run with pre-packaged inputs for calibration
        :param iresults: pre-packaged model inputs
        :param period:
        :param extent:
        :return:
        '''
        coords = mt.gen_coordset(period,extent)
        if self.outputs:
            ### initialise output files if necessary
            self.outputs.initialise(coords[0])

        mresults = self.model_runner.run_from_mapping(iresults,coords.shape[0],extent.cell_count,True)

        if self.outputs is not None:
            self.outputs.set_data(coords,mresults,extent.mask)

        return mresults
Code example #12
File: input_reader.py Project: awracms/awra_cms_old
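    # Reader-side chunk production: evaluate the input graph for the active chunk and
    # period, copy each result into a shared buffer shaped by the chunk's dims, inject
    # recycled state variables when enabled, and return a 'chunk' message describing the
    # buffer, its dims and the keys it carries.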
    def read_active_chunk(self):
        extent = self.chunks[self.cur_chunk_idx]
        period = self.periods[self.cur_period_idx]
        coords = mt.gen_coordset(period, extent)
        results = self.exe_graph.get_data_flat(coords, extent.mask)

        buf_id, out_buf = self.get_buffer_safe('inputs')

        #+++ Ensure all dimensions filled out; static dimensions will always be max_dims
        # Can we guarantee this?
        dims = out_buf.max_dims.copy()
        dims.update(dict(cell=extent.cell_count, time=len(period)))

        target_data = out_buf.map_dims(dims)

        for k, v in results.items():
            try:
                target_data[k][...] = v
            except:
                print("Failed setting on", k, v)
                raise

        valid_keys = list(results.keys())

        if self.recycle_states:
            for k in self.state_keys:
                init_k = 'init_' + k

                target_data[init_k][...] = self.state_buffers[
                    self.cur_chunk_idx][k]
                if init_k not in valid_keys:
                    valid_keys.append(init_k)

        chunk_msg = message('chunk')
        content = chunk_msg['content']
        content['chunk_idx'] = self.cur_chunk_idx
        content['period_idx'] = self.cur_period_idx
        content['buffer'] = buf_id
        content['dims'] = dims
        content['valid_keys'] = valid_keys

        return chunk_msg
Code example #13
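# Same padded-read test as code example #6, but resolving the simulation test data
# location through the system profile's DATA_PATHS settings rather than a path relative
# to the test module.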
def test_get_padded_by_coords():
    from awrams.utils.io.data_mapping import SplitFileManager
    from awrams.utils.mapping_types import gen_coordset
    import awrams.utils.datetools as dt

    data_paths = config_manager.get_system_profile().get_settings(
    )['DATA_PATHS']

    path = os.path.join(data_paths['BASE_DATA'], 'test_data', 'simulation',
                        'climate', 'temp_min_day')

    sfm = SplitFileManager.open_existing(path, 'temp_min_day_*.nc',
                                         'temp_min_day')
    # return sfm
    extent = sfm.get_extent().ioffset[200:230, 200:230]
    period = dt.dates('2011')
    coords = gen_coordset(period, extent)

    data = sfm.get_padded_by_coords(coords)
    assert data.shape == coords.shape
Code example #14
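# End-to-end temporal resampling pipeline: a StreamingReader, a pool of
# ChunkedTimeResampler workers and a MultifileChunkWriter are wired together with
# multiprocessing queues and shared buffers to resample a netCDF variable from its native
# frequency to to_freq and write the result via a FlatFileManager. Note that `mp`
# (multiprocessing) and `logger` appear to be module-level names in the source file; they
# are not defined within this excerpt.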
def resample_data(in_path,
                  in_pattern,
                  variable,
                  period,
                  out_path,
                  to_freq,
                  method,
                  mode='w',
                  enforce_mask=True,
                  extent=None,
                  use_weights=False):
    '''
    method is 'sum' or 'mean'
    if no extent is supplied then the full (unmasked) input will be used
    'use_weights' should be set for unequally binned conversions (monthly->annual means, for example)
    '''
    from glob import glob
    import time
    import numpy as np

    from awrams.utils.messaging import reader as nr
    from awrams.utils.messaging import writer as nw
    from awrams.utils.messaging.brokers import OrderedFanInChunkBroker, FanOutChunkBroker
    from awrams.utils.messaging.general import message
    from awrams.utils.messaging.buffers import create_managed_buffers
    from awrams.utils.processing.chunk_resampler import ChunkedTimeResampler
    from awrams.utils.catchments import subdivide_extent
    from awrams.utils import datetools as dt
    from awrams.utils import mapping_types as mt
    from awrams.utils.io import data_mapping as dm

    start = time.time()

    NWORKERS = 2
    read_ahead = 3
    writemax = 3
    BLOCKSIZE = 128
    nbuffers = (NWORKERS * 2) + read_ahead + writemax

    # Receives all messages from clients
    '''
    Build the 'standard queues'
    This should be wrapped up somewhere else for 
    various topologies...
    '''

    control_master = mp.Queue()

    worker_q = mp.Queue()
    for i in range(NWORKERS):
        worker_q.put(i)

    #Reader Queues
    chunk_out_r = mp.Queue(read_ahead)
    reader_in = dict(control=mp.Queue())
    reader_out = dict(control=control_master, chunks=chunk_out_r)

    #Writer Queues
    chunk_in_w = mp.Queue(writemax)
    writer_in = dict(control=mp.Queue(), chunks=chunk_in_w)
    writer_out = dict(control=control_master)

    #FanIn queues
    fanout_in = dict(control=mp.Queue(), chunks=chunk_out_r, workers=worker_q)
    fanout_out = dict(control=control_master)

    fanin_in = dict(control=mp.Queue())
    fanin_out = dict(control=control_master, out=chunk_in_w, workers=worker_q)

    #Worker Queues
    work_inq = []
    work_outq = []

    for i in range(NWORKERS):
        work_inq.append(mp.Queue())
        fanout_out[i] = work_inq[-1]

        work_outq.append(mp.Queue())
        fanin_in[i] = work_outq[-1]
    '''
    End standard queues...
    '''

    infiles = glob(in_path + '/' + in_pattern)
    if len(infiles) > 1:
        ff = dm.filter_years(period)
    else:
        ff = None

    sfm = dm.SplitFileManager.open_existing(in_path,
                                            in_pattern,
                                            variable,
                                            ff=ff)
    in_freq = sfm.get_frequency()

    split_periods = [period]
    if hasattr(in_freq, 'freqstr'):
        if in_freq.freqstr == 'D':
            #Force splitting so that flat files don't end up getting loaded entirely into memory!
            #Also a bit of a hack to deal with PeriodIndex/DTI issues...
            split_periods = dt.split_period(
                dt.resample_dti(period, 'd', as_period=False), 'a')

    in_periods = [dt.resample_dti(p, in_freq) for p in split_periods]
    in_pmap = sfm.get_period_map_multi(in_periods)

    out_periods = []
    for p in in_periods:
        out_periods.append(dt.resample_dti(p, to_freq))

    if extent is None:
        extent = sfm.ref_ds.get_extent(True)
        if extent.mask.size == 1:
            extent.mask = (np.ones(extent.shape) * extent.mask).astype(bool)

    sub_extents = subdivide_extent(extent, BLOCKSIZE)
    chunks = [nr.Chunk(*s.indices()) for s in sub_extents]

    out_period = dt.resample_dti(period, to_freq)
    out_cs = mt.gen_coordset(out_period, extent)

    v = mt.Variable.from_ncvar(sfm.ref_ds.awra_var)
    in_dtype = sfm.ref_ds.awra_var.dtype

    sfm.close_all()

    use_weights = False

    if method == 'mean':
        if dt.validate_timeframe(in_freq) == 'MONTHLY':
            use_weights = True
    '''
    Need a way of formalising multiple buffer pools for different classes of
    work..
    '''

    max_inplen = max([len(p) for p in in_periods])
    bufshape = (max_inplen, BLOCKSIZE, BLOCKSIZE)

    shared_buffers = {}
    shared_buffers['main'] = create_managed_buffers(nbuffers,
                                                    bufshape,
                                                    build=False)

    mvar = mt.MappedVariable(v, out_cs, in_dtype)
    sfm = dm.FlatFileManager(out_path, mvar)

    CLOBBER = mode == 'w'

    sfm.create_files(False, CLOBBER, chunksize=(1, BLOCKSIZE, BLOCKSIZE))

    outfile_maps = {
        v.name:
        dict(nc_var=v.name, period_map=sfm.get_period_map_multi(out_periods))
    }
    infile_maps = {v.name: dict(nc_var=v.name, period_map=in_pmap)}

    reader = nr.StreamingReader(reader_in, reader_out, shared_buffers,
                                infile_maps, chunks, in_periods)
    writer = nw.MultifileChunkWriter(writer_in,
                                     writer_out,
                                     shared_buffers,
                                     outfile_maps,
                                     sub_extents,
                                     out_periods,
                                     enforce_mask=enforce_mask)

    fanout = FanOutChunkBroker(fanout_in, fanout_out)
    fanin = OrderedFanInChunkBroker(fanin_in, fanin_out, NWORKERS, len(chunks))

    fanout.start()
    fanin.start()

    workers = []
    w_control = []
    for i in range(NWORKERS):
        w_in = dict(control=mp.Queue(), chunks=work_inq[i])
        w_out = dict(control=control_master, chunks=work_outq[i])
        w = ChunkedTimeResampler(w_in,
                                 w_out,
                                 shared_buffers,
                                 sub_extents,
                                 in_periods,
                                 to_freq,
                                 method,
                                 enforce_mask=enforce_mask,
                                 use_weights=use_weights)
        workers.append(w)
        w_control.append(w_in['control'])
        w.start()

    writer.start()
    reader.start()

    writer.join()

    fanout_in['control'].put(message('terminate'))
    fanin_in['control'].put(message('terminate'))

    for i in range(NWORKERS):
        w_control[i].put(message('terminate'))

    for x in range(4):
        control_master.get()

    for i in range(NWORKERS):
        workers[i].join()
        control_master.get()

    reader.join()
    fanout.join()
    fanin.join()

    end = time.time()
    logger.info("elapsed time: %ss", end - start)
Code example #15
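# Split a model input mapping into dynamic and static subgraphs: the static part is
# evaluated once over the supplied period and extent(s), and its results are folded back
# into the returned mapping as static data nodes alongside the untouched dynamic mapping.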
def build_fixed_mapping(mapping, model_keys, dyn_keys, period, extent):
    """
    Args:
        mapping (dict): imap as produced by model.get_default_mapping()
        model_keys (list): keys that must be available as outputs of this graph
        dyn_keys (list): keys for dynamic nodes (ie runtime computed)
        period (DatetimeIndex): period over which to generate the mapping
        extent (extents.Extent): extent (or list of extents) over which to generate the mapping
    
    Returns:
        dict: mapping (containing static data as well as node mappings)
    
    """
    mapping = expand_const(mapping)

    dyn_head_mapping = [k for k in mapping if k in dyn_keys]
    dynamic_mapping = get_output_tree(dyn_head_mapping, mapping)
    dynamic_keys = set(dynamic_mapping)

    static_keys = dynamic_keys.symmetric_difference(set(mapping.keys()))
    static_mapping = dict([k, mapping[k]] for k in static_keys)

    def coalesce_keys(key_list):
        out_list = []
        for keys in key_list:
            out_list += keys
        return set(out_list)

    kl = [v.inputs for k, v in dynamic_mapping.items()]
    extra_static_keys = coalesce_keys(kl).difference(
        dynamic_keys
    )  # set of (static) keys required by the dynamic part of the graph

    req_static_outputs = (
        set(model_keys).intersection(static_keys)).union(extra_static_keys)

    #mapping = mapping.copy()

    #parameterised, fixed_mapping,fixed_endpoints = graph.split_parameterised_mapping(mapping,model_keys)
    static_egraph = ExecutionGraph(static_mapping)
    dspecs = static_egraph.get_dataspecs()

    if not isinstance(extent, list):
        extent = [extent]

    coords = [gen_coordset(period, e) for e in extent]
    masks = [e.mask for e in extent]

    res = static_egraph.get_data_flat(coords, masks, multi=True)
    static_data = dict((k, res[k]) for k in req_static_outputs)

    del (res)

    #static_mapping = graph.expand_const(fixed_mapping)

    out_mapping = dict(
        [k, static(static_data[k], dspecs[k], static_mapping[k].out_type)]
        for k in req_static_outputs)
    out_mapping.update(dynamic_mapping)

    return out_mapping
Code example #16
File: domain.py Project: adivicco47/awra_cms-1
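# Convenience adapter: derive a (coordinate set, mask) pair from an existing domain
# object's 'time' and 'latlon' coordinate mappings.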
def coordset_from_domain(in_domain):
    cs = gen_coordset(in_domain.coords['time'].mapping.value_obj,
                      in_domain.coords['latlon'].mapping.value_obj)
    mask = in_domain.coords['latlon'].mapping.extent.mask
    return cs, mask
Code example #17
File: modelgraph.py Project: nemochina2008/awra_cms
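    # Model-worker side of the chunked pipeline: map the incoming input buffer, switch
    # the initial-state inputs over to ConstNodes after the first period (so states come
    # from the sender rather than being generated locally), evaluate the input graph, run
    # the model over the chunk, then ship the final states and model outputs back out on
    # separate shared buffers.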
    def handle_current_chunk(self):
        '''
        Send back any buffers we've received, and generate some output
        Obligatory for ChunkedProcessors
        '''

        c = self.cur_chunk['content']
        period_idx = c['period_idx']
        chunk_idx = c['chunk_idx']
        buf = c['buffer']
        dims = c['dims']
        valid = c['valid_keys']

        bgroup = self.map_buffer(buf)
        bdata = bgroup.map_dims(dims)

        # Basically need to ensure that we're getting states from the sender rather
        # than trying to generate them locally...
        if period_idx > 0:
            if not self.recycling:
                for k in self.state_keys:
                    init_k = 'init_' + k
                    if init_k in self.exe_graph.process_graph:
                        del self.exe_graph.process_graph[init_k]
                    self.exe_graph.input_graph[init_k] = dict(
                        exe=nodes.ConstNode(None))
                self.recycling = True

        for k in valid:
            self.exe_graph.input_graph[k]['exe'].value = bdata[k]

        extent = self.chunks[chunk_idx]
        period = self.periods[period_idx]
        coords = mt.gen_coordset(period, extent)
        graph_results = self.exe_graph.get_data_flat(coords, extent.mask)

        state_buf_id, state_buf = self.get_buffer_safe('states')
        output_buf_id, output_buf = self.get_buffer_safe('outputs')

        #Run the model!
        #data = self.process_data(in_data,period_idx,chunk_idx)
        model_results = self.runner.run_over_dimensions(graph_results, dims)

        self.reclaim_buffer(buf)

        states_np = state_buf.map_dims(dims)

        for k in self.state_keys:
            states_np[k][...] = model_results['final_states'][k]

        state_msg = message('states')
        c = state_msg['content']
        c['chunk_idx'] = chunk_idx
        c['period_idx'] = period_idx
        c['buffer'] = state_buf_id

        #self.send('state_return',state_msg)
        self.qout['state_return'].put(state_msg)

        output_np = output_buf.map_dims(dims)

        for k in model_results.keys():
            if k != 'final_states':
                output_np[k][...] = model_results[k]

        self.out_chunk = message('chunk')
        c = self.out_chunk['content']
        c['period_idx'] = period_idx
        c['chunk_idx'] = chunk_idx
        c['dims'] = dims
        c['buffer'] = output_buf_id

        self.send_out_chunk()