Example #1
    def __init__(self,
                 inputs,
                 processors,
                 period,
                 extent,
                 existing_inputs=False):
        '''
        Inputs: {var_name: input_reader}
        Processors: {var_name: [p0, p1...]}
        '''
        self.inputs = inputs
        self.processor_map = processors
        self.period = period
        self.split_periods = dt.split_period(self.period, 'a')

        self.control = mp.Queue()
        self.input_mapper = ClimateInputBridge(inputs, self.split_periods,
                                               extent, existing_inputs)

        #self.extent = extent.translate_to_origin(self.input_mapper.input_bridge.geo_ref)

        self.extent = extent
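The 'a' argument above asks for an annual split of the run period. A minimal sketch of what that is assumed to look like in practice; the sample period is built with pandas, and the exact return type of dt.split_period (a list of per-year sub-indices) is an assumption:

import pandas as pd

from awrams.utils import datetools as dt

# Three-year daily run period (illustrative dates only)
period = pd.date_range('2000-01-01', '2002-12-31', freq='D')

# Assumed to return one sub-index per calendar year covered by the period
annual = dt.split_period(period, 'a')
for p in annual:
    print(p[0].date(), '->', p[-1].date())   # e.g. 2000-01-01 -> 2000-12-31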
def resample_data(in_path,
                  in_pattern,
                  variable,
                  period,
                  out_path,
                  to_freq,
                  method,
                  mode='w',
                  enforce_mask=True,
                  extent=None,
                  use_weights=False):
    '''
    'method' is either 'sum' or 'mean'.
    If no extent is supplied, the full (unmasked) input extent is used.
    'use_weights' should be set for unequally binned conversions
    (monthly -> annual means, for example).
    '''
    from glob import glob
    import time
    import numpy as np

    from awrams.utils.messaging import reader as nr
    from awrams.utils.messaging import writer as nw
    from awrams.utils.messaging.brokers import OrderedFanInChunkBroker, FanOutChunkBroker
    from awrams.utils.messaging.general import message
    from awrams.utils.messaging.buffers import create_managed_buffers
    from awrams.utils.processing.chunk_resampler import ChunkedTimeResampler
    from awrams.utils.catchments import subdivide_extent
    from awrams.utils import datetools as dt
    from awrams.utils import mapping_types as mt
    from awrams.utils.io import data_mapping as dm

    start = time.time()

    NWORKERS = 2
    read_ahead = 3
    writemax = 3
    BLOCKSIZE = 128
    nbuffers = (NWORKERS * 2) + read_ahead + writemax

    '''
    Build the 'standard queues'
    This should be wrapped up somewhere else for
    various topologies...
    '''

    # Receives all messages from clients
    control_master = mp.Queue()

    worker_q = mp.Queue()
    for i in range(NWORKERS):
        worker_q.put(i)

    #Reader Queues
    chunk_out_r = mp.Queue(read_ahead)
    reader_in = dict(control=mp.Queue())
    reader_out = dict(control=control_master, chunks=chunk_out_r)

    #Writer Queues
    chunk_in_w = mp.Queue(writemax)
    writer_in = dict(control=mp.Queue(), chunks=chunk_in_w)
    writer_out = dict(control=control_master)

    #FanIn queues
    fanout_in = dict(control=mp.Queue(), chunks=chunk_out_r, workers=worker_q)
    fanout_out = dict(control=control_master)

    fanin_in = dict(control=mp.Queue())
    fanin_out = dict(control=control_master, out=chunk_in_w, workers=worker_q)

    #Worker Queues
    work_inq = []
    work_outq = []

    for i in range(NWORKERS):
        work_inq.append(mp.Queue())
        fanout_out[i] = work_inq[-1]

        work_outq.append(mp.Queue())
        fanin_in[i] = work_outq[-1]
    '''
    End standard queues...
    '''
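    # Topology wired up by these queues: StreamingReader -> FanOutChunkBroker ->
    # NWORKERS ChunkedTimeResampler workers -> OrderedFanInChunkBroker ->
    # MultifileChunkWriter, with control_master collecting control/termination
    # messages from every process and worker_q handing worker ids to the brokers.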

    infiles = glob(in_path + '/' + in_pattern)
    if len(infiles) > 1:
        ff = dm.filter_years(period)
    else:
        ff = None

    sfm = dm.SplitFileManager.open_existing(in_path,
                                            in_pattern,
                                            variable,
                                            ff=ff)
    in_freq = sfm.get_frequency()

    split_periods = [period]
    if hasattr(in_freq, 'freqstr'):
        if in_freq.freqstr == 'D':
            #Force splitting so that flat files don't end up getting loaded entirely into memory!
            #Also a bit of a hack to deal with PeriodIndex/DTI issues...
            split_periods = dt.split_period(
                dt.resample_dti(period, 'd', as_period=False), 'a')

    in_periods = [dt.resample_dti(p, in_freq) for p in split_periods]
    in_pmap = sfm.get_period_map_multi(in_periods)

    out_periods = []
    for p in in_periods:
        out_periods.append(dt.resample_dti(p, to_freq))

    if extent is None:
        extent = sfm.ref_ds.get_extent(True)
        if extent.mask.size == 1:
            # expand a scalar mask to a full boolean array
            extent.mask = (np.ones(extent.shape) * extent.mask).astype(bool)

    sub_extents = subdivide_extent(extent, BLOCKSIZE)
    chunks = [nr.Chunk(*s.indices()) for s in sub_extents]

    out_period = dt.resample_dti(period, to_freq)
    out_cs = mt.gen_coordset(out_period, extent)

    v = mt.Variable.from_ncvar(sfm.ref_ds.awra_var)
    in_dtype = sfm.ref_ds.awra_var.dtype

    sfm.close_all()

    # Weighted averaging is needed for unequally binned conversions
    # (e.g. monthly -> annual means); enable it automatically for monthly means
    if method == 'mean' and dt.validate_timeframe(in_freq) == 'MONTHLY':
        use_weights = True
    '''
    Need a way of formalising multiple buffer pools for different classes of
    work..
    '''

    max_inplen = max([len(p) for p in in_periods])
    bufshape = (max_inplen, BLOCKSIZE, BLOCKSIZE)

    shared_buffers = {}
    shared_buffers['main'] = create_managed_buffers(nbuffers,
                                                    bufshape,
                                                    build=False)

    mvar = mt.MappedVariable(v, out_cs, in_dtype)
    sfm = dm.FlatFileManager(out_path, mvar)

    CLOBBER = mode == 'w'

    sfm.create_files(False, CLOBBER, chunksize=(1, BLOCKSIZE, BLOCKSIZE))

    outfile_maps = {
        v.name:
        dict(nc_var=v.name, period_map=sfm.get_period_map_multi(out_periods))
    }
    infile_maps = {v.name: dict(nc_var=v.name, period_map=in_pmap)}

    reader = nr.StreamingReader(reader_in, reader_out, shared_buffers,
                                infile_maps, chunks, in_periods)
    writer = nw.MultifileChunkWriter(writer_in,
                                     writer_out,
                                     shared_buffers,
                                     outfile_maps,
                                     sub_extents,
                                     out_periods,
                                     enforce_mask=enforce_mask)

    fanout = FanOutChunkBroker(fanout_in, fanout_out)
    fanin = OrderedFanInChunkBroker(fanin_in, fanin_out, NWORKERS, len(chunks))

    fanout.start()
    fanin.start()

    workers = []
    w_control = []
    for i in range(NWORKERS):
        w_in = dict(control=mp.Queue(), chunks=work_inq[i])
        w_out = dict(control=control_master, chunks=work_outq[i])
        w = ChunkedTimeResampler(w_in,
                                 w_out,
                                 shared_buffers,
                                 sub_extents,
                                 in_periods,
                                 to_freq,
                                 method,
                                 enforce_mask=enforce_mask,
                                 use_weights=use_weights)
        workers.append(w)
        w_control.append(w_in['control'])
        w.start()

    writer.start()
    reader.start()

    writer.join()

    fanout_in['control'].put(message('terminate'))
    fanin_in['control'].put(message('terminate'))

    for i in range(NWORKERS):
        w_control[i].put(message('terminate'))

    for x in range(4):
        control_master.get()

    for i in range(NWORKERS):
        workers[i].join()
        control_master.get()

    reader.join()
    fanout.join()
    fanin.join()

    end = time.time()
    logger.info("elapsed time: %.2fs", end - start)
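A sketch of how resample_data might be invoked to aggregate a daily gridded variable to monthly totals. The paths, file pattern and variable name are placeholders, building the period as a pandas DatetimeIndex is an assumption, and the 'm' frequency code for monthly output follows the usage elsewhere in these examples:

import pandas as pd

# Hypothetical dataset layout -- substitute your own paths and variable name
period = pd.date_range('1990-01-01', '1999-12-31', freq='D')

resample_data(in_path='/data/climate/rain_day',
              in_pattern='rain_day_*.nc',
              variable='rain_day',
              period=period,
              out_path='/data/derived/rain_month',
              to_freq='m',
              method='sum')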
    def split_periods(self, period):
        return dt.split_period(period, 'm')
Example #4
    def split_dti(self, dti):
        return dt.split_period(dti, 'm')
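For contrast with the annual split shown under Example #1, a sketch of the monthly variant these two wrappers delegate to (same assumption about dt.split_period returning a list of sub-indices):

import pandas as pd

from awrams.utils import datetools as dt

dti = pd.date_range('2009-11-15', '2010-02-10', freq='D')
monthly = dt.split_period(dti, 'm')   # assumed: one sub-index per calendar month touched
print(len(monthly))                   # expected: 4 (Nov, Dec, Jan, Feb)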
Example #5
    def run(self, input_map, output_map, period, extent):
        '''
        Should be the basis for new-style sim server
        Currently no file output, but runs inputgraph/model quite happily...
        '''
        import time
        start = time.time()

        chunks = extents.subdivide_extent(extent,self.spatial_chunk)
        periods = dt.split_period(period,'a')

        self.logger.info("Getting I/O dataspecs...")
        #+++ Document rescaling separately, don't just change the graph behind the scenes...
        #mapping = graph.map_rescaling_nodes(input_map.mapping,extent)
        mapping = input_map
        filtered = graph.get_input_tree(self.model.get_input_keys(),mapping)

        input_nodes = {}
        worker_nodes = {}
        output_nodes = {}

        for k,v in filtered.items():
            if 'io' in v.properties:
                input_nodes[k] = v
                worker_nodes[k] = nodes.const(None)
            else:
                worker_nodes[k] = v

        for k, v in output_map.items():
            try:
                if v.properties['io'] == 'from_model':
                    output_nodes[k] = v
            except (AttributeError, KeyError):
                # node has no 'io' property; not a model output
                pass

        igraph = graph.ExecutionGraph(input_nodes)

        self._set_max_dims(igraph)

        input_dspecs = igraph.get_dataspecs(True)

        #+++ No guarantee this will close files. Put in separate function?
        del igraph

        model_dspecs = graph.ExecutionGraph(mapping).get_dataspecs(True)
        output_dspecs = graph.OutputGraph(output_nodes).get_dataspecs(True)

        self.model.init_shared(model_dspecs)
        ### initialise output ncfiles
        self.logger.info("Initialising output files...")
        outgraph = graph.OutputGraph(output_map)
        outgraph.initialise(period,extent)

        #+++ Can we guarantee that statespecs will be 64bit for recycling?

        # NWORKERS = 2
        # READ_AHEAD = 1

        sspec = DataSpec('array',['cell'],np.float64)

        state_specs = {}
        for k in self.model.get_state_keys():
            init_k = 'init_' + k

            input_dspecs[init_k] = sspec
            state_specs[k] = sspec

        self.logger.info("Building buffers...")
        input_bufs = create_managed_buffergroups(input_dspecs,self.max_dims,self.num_workers+self.read_ahead)
        state_bufs = create_managed_buffergroups(state_specs,self.max_dims,self.num_workers*2+self.read_ahead)
        output_bufs = create_managed_buffergroups(output_dspecs,self.max_dims,self.num_workers+self.read_ahead)

        all_buffers = dict(inputs=input_bufs,states=state_bufs,outputs=output_bufs)

        smc = SharedMemClient(all_buffers,False)

        control_master = mp.Queue()
        control_status = mp.Queue()

        state_returnq = mp.Queue()

        chunkq = mp.Queue()

        chunkoutq = mp.Queue()

        reader_inq = dict(control=mp.Queue(),state_return=state_returnq)
        reader_outq = dict(control=control_master,chunks=chunkq)

        writer_inq = dict(control=mp.Queue(),chunks=chunkoutq)
        writer_outq = dict(control=control_master,log=mp.Queue()) #,chunks=chunkq)

        child_control_qs = [reader_inq['control'],writer_inq['control'],writer_outq['log']]

        self.logger.info("Running simulation...")
        workers = []
        for w in range(self.num_workers):
            worker_inq = dict(control=mp.Queue(),chunks=chunkq)
            worker_outq = dict(control=control_master,state_return=state_returnq,chunks=chunkoutq)
            worker_p = mg.ModelGraphRunner(worker_inq,worker_outq,all_buffers,chunks,periods,worker_nodes,self.model)
            worker_p.start()
            workers.append(worker_p)
            child_control_qs.append(worker_inq['control'])

        control = ControlMaster(control_master, control_status, child_control_qs)
        control.start()

        reader_p = input_reader.InputGraphRunner(reader_inq,reader_outq,all_buffers,chunks,periods,input_nodes,self.model.get_state_keys())
        reader_p.start()

        writer_p = writer.OutputGraphRunner(writer_inq,writer_outq,all_buffers,chunks,periods,output_map)
        writer_p.start()

        log = True
        while log:
            msg = writer_outq['log'].get()
            if msg['subject'] == 'terminate':
                log = False
            else:
                self.logger.info(msg['subject'])

        writer_p.join()

        for w in workers:
            w.qin['control'].put(message('terminate'))
            # control_master.get_nowait()
            w.join()

        reader_inq['control'].put(message('terminate'))
        control_master.put(message('finished'))

        problem = False
        msg = control_status.get()
        if msg['subject'] == 'exception_raised':
            problem = True
        control.join()

        reader_p.join()

        if problem:
            raise Exception("Problem detected")
        self.logger.info("elapsed time: %.2fs", time.time() - start)
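The shutdown sequence above drains a dedicated 'log' queue from the writer until a 'terminate' message arrives. A stripped-down sketch of that pattern using plain multiprocessing; the message() helper's dict shape is inferred from its use in the example, not taken from the awrams implementation:

import multiprocessing as mp


def message(subject, **kwargs):
    # inferred shape: a dict keyed by 'subject', as used in the example above
    return dict(subject=subject, **kwargs)


def writer_proc(logq):
    logq.put(message('writing chunk 0'))
    logq.put(message('terminate'))


if __name__ == '__main__':
    logq = mp.Queue()
    p = mp.Process(target=writer_proc, args=(logq,))
    p.start()
    while True:
        msg = logq.get()
        if msg['subject'] == 'terminate':
            break
        print(msg['subject'])
    p.join()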
Example #6
def split_period_annual_chunked(in_dti, chunksize):
    out_a = dt.split_period(in_dti, 'a')
    all_out = []
    for dti in out_a:
        all_out.extend(DomainMappedDatetimeIndex(d)
                       for d in dt.split_period_chunks(dti, chunksize))
    return all_out
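A hypothetical call showing what the helper above is assumed to produce, given that dt.split_period_chunks limits each piece to at most chunksize timesteps and that DomainMappedDatetimeIndex behaves like a DatetimeIndex for indexing:

import pandas as pd

dti = pd.date_range('2010-01-01', '2011-12-31', freq='D')
chunks = split_period_annual_chunked(dti, chunksize=100)

print(len(chunks))                    # pieces of <= 100 days, split within each year
print(chunks[0][0], chunks[0][-1])    # start/end of the first piece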