Example #1
File: data.py  Project: TaxIPP-Life/liam2
def append_table(input_table, output_table, chunksize=10000, condition=None,
                 stop=None, show_progress=False):
    if input_table.dtype != output_table.dtype:
        output_fields = get_fields(output_table)
    else:
        output_fields = None

    if stop is None:
        numrows = len(input_table)
    else:
        numrows = stop

    if not chunksize:
        chunksize = numrows

    num_chunks, remainder = divmod(numrows, chunksize)
    if remainder > 0:
        num_chunks += 1

    if output_fields is not None:
        expanded_data = np.empty(chunksize, dtype=np.dtype(output_fields))
        expanded_data[:] = get_missing_record(expanded_data)

    #noinspection PyUnusedLocal
    def copy_chunk(chunk_idx, chunk_num):
        chunk_start = chunk_num * chunksize
        chunk_stop = min(chunk_start + chunksize, numrows)
        if condition is not None:
            input_data = input_table.readWhere(condition, start=chunk_start,
                                               stop=chunk_stop)
        else:
            input_data = input_table.read(chunk_start, chunk_stop)

        if output_fields is not None:
            # use our pre-allocated buffer (except for the last chunk)
            if len(input_data) == len(expanded_data):
                default_values = {}
                output_data = add_and_drop_fields(input_data, output_fields,
                                                  default_values, expanded_data)
            else:
                default_values = {}
                output_data = add_and_drop_fields(input_data, output_fields, default_values)
        else:
            output_data = input_data

        output_table.append(output_data)
        output_table.flush()

    if show_progress:
        loop_wh_progress(copy_chunk, range(num_chunks))
    else:
        for chunk in range(num_chunks):
            copy_chunk(chunk, chunk)

    return output_table
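This excerpt assumes the module-level imports and helpers of LIAM2's data.py (np is NumPy; get_fields, get_missing_record, add_and_drop_fields and loop_wh_progress are defined in the same module), and that both tables are PyTables Table nodes. A hypothetical usage sketch, assuming the function is importable and both HDF5 files already contain an /entities/person table:

import tables
from liam2.data import append_table  # hypothetical import path

# copy the first 50000 rows, 10000 at a time, with a progress bar
with tables.open_file("input.h5", mode="r") as fin, \
     tables.open_file("output.h5", mode="a") as fout:
    append_table(fin.root.entities.person,
                 fout.root.entities.person,
                 chunksize=10000,
                 stop=50000,
                 show_progress=True)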
Example #2
File: data.py  Project: gvk489/liam2
def append_table(input_table, output_table, chunksize=10000, condition=None,
                 stop=None, show_progress=False, default_values=None):

    if input_table.dtype != output_table.dtype:
        output_fields = get_fields(output_table)
    else:
        output_fields = None

    if stop is None:
        numrows = len(input_table)
    else:
        numrows = stop

    if not chunksize:
        chunksize = numrows

    num_chunks, remainder = divmod(numrows, chunksize)
    if remainder > 0:
        num_chunks += 1

    if output_fields is not None:
        expanded_data = get_default_array(chunksize, np.dtype(output_fields),
                                          default_values)

    # noinspection PyUnusedLocal
    def copy_chunk(chunk_idx, chunk_num):
        chunk_start = chunk_num * chunksize
        chunk_stop = min(chunk_start + chunksize, numrows)
        if condition is not None:
            input_data = input_table.readWhere(condition, start=chunk_start,
                                               stop=chunk_stop)
        else:
            input_data = input_table.read(chunk_start, chunk_stop)

        if output_fields is not None:
            # use our pre-allocated buffer (except for the last chunk)
            if len(input_data) == len(expanded_data):
                output_data = add_and_drop_fields(input_data, output_fields,
                                                  default_values, expanded_data)
            else:
                output_data = add_and_drop_fields(input_data, output_fields,
                                                  default_values)
        else:
            output_data = input_data

        output_table.append(output_data)
        output_table.flush()

    if show_progress:
        loop_wh_progress(copy_chunk, range(num_chunks))
    else:
        for chunk in range(num_chunks):
            copy_chunk(chunk, chunk)

    return output_table
Example #3
    def evaluate(self, context):
        global local_ctx

        ctx_filter = context.get('__filter__')

        id_to_rownum = context.id_to_rownum

        # at some point ctx_filter will be cached automatically, so we don't
        # need to take care of it manually here
        if ctx_filter is not None:
            set1filter = expr_eval(ctx_filter & self.set1filter, context)
            set2filter = expr_eval(ctx_filter & self.set2filter, context)
        else:
            set1filter = expr_eval(self.set1filter, context)
            set2filter = expr_eval(self.set2filter, context)

        score_expr = self.score_expr

        used_variables = score_expr.collect_variables(context)
        used_variables1 = [v for v in used_variables
                                    if not v.startswith('__other_')]
        used_variables2 = [v[8:] for v in used_variables
                                    if v.startswith('__other_')]

        set1 = context_subset(context, set1filter, ['id'] + used_variables1)
        set2 = context_subset(context, set2filter, ['id'] + used_variables2)
        set1len = set1filter.sum()
        set2len = set2filter.sum()
        tomatch = min(set1len, set2len)
        
        orderby = self.orderby
        if not isinstance(orderby, str):
            order = expr_eval(orderby, context)
        else: 
            order = np.zeros(context_length(context), dtype=float)
            if orderby == 'EDtM':
                for var in used_variables1:
                    order[set1filter] += (set1[var] - set1[var].mean()) ** 2 / set1[var].var()
            if orderby == 'SDtOM':
                order_ctx = dict((k if k in used_variables1 else k, v)
                             for k, v in set1.iteritems())
                order_ctx.update(('__other_' + k, set2[k].mean()) for k in used_variables2)
                order[set1filter] = expr_eval(score_expr, order_ctx)               
        
        sorted_set1_indices = order[set1filter].argsort()[::-1]
        set1tomatch = sorted_set1_indices[:tomatch]
        print("matching with %d/%d individuals" % (set1len, set2len))

        #TODO: compute pk_names automatically: variables which are either
        # boolean, or have very few possible values and which are used more
        # than once in the expression and/or which are used in boolean
        # expressions
#        pk_names = ('eduach', 'work')
#        optimized_exprs = {}

        result = np.empty(context_length(context), dtype=int)
        result.fill(-1)

        local_ctx = dict(('__other_' + k if k in ['id'] + used_variables2 else k, v)
                         for k, v in set2.iteritems())

        if self.pool_size is None:
            #noinspection PyUnusedLocal
            def match_one_set1_individual(idx, sorted_idx):
                global local_ctx
    
                if not context_length(local_ctx):
                    raise StopIteration
    
                local_ctx.update((k, set1[k][sorted_idx]) for k in ['id'] + used_variables1)
    
    #            pk = tuple(individual1[fname] for fname in pk_names)
    #            optimized_expr = optimized_exprs.get(pk)
    #            if optimized_expr is None:
    #                for name in pk_names:
    #                    fake_set1['__f_%s' % name].value = individual1[name]
    #                optimized_expr = str(symbolic_expr.simplify())
    #                optimized_exprs[pk] = optimized_expr
    #            set2_scores = evaluate(optimized_expr, mm_dict, set2)
    
                set2_scores = expr_eval(score_expr, local_ctx)
    
                individual2_idx = np.argmax(set2_scores)
    
                id1 = local_ctx['id']
                id2 = local_ctx['__other_id'][individual2_idx]
    
                local_ctx = context_delete(local_ctx, individual2_idx)
    
                result[id_to_rownum[id1]] = id2
                result[id_to_rownum[id2]] = id1            
            
            loop_wh_progress(match_one_set1_individual, set1tomatch)
        else:
            pool_size = self.pool_size
            #noinspection PyUnusedLocal
            def match_one_set1_individual_pool(idx, sorted_idx, pool_size):
                global local_ctx
                
                set2_size = context_length(local_ctx)
                if not set2_size:
                    raise StopIteration
                
                if set2_size > pool_size:
                    pool = random.sample(xrange(context_length(local_ctx)), pool_size)
                else:
                    pool = range(set2_size)

                sub_local_ctx = context_subset(local_ctx, pool, None)
                sub_local_ctx.update((k, set1[k][sorted_idx]) for k in ['id'] + used_variables1)
                
                set2_scores = expr_eval(score_expr, sub_local_ctx)
    
                individual2_pool_idx = np.argmax(set2_scores)
                individual2_idx = pool[individual2_pool_idx]
                
                id1 = sub_local_ctx['id']
                id2 = local_ctx['__other_id'][individual2_idx]
    
                local_ctx = context_delete(local_ctx, individual2_idx)
    
                result[id_to_rownum[id1]] = id2
                result[id_to_rownum[id2]] = id1
                
            loop_wh_progress(match_one_set1_individual_pool, set1tomatch, pool_size=10)
            
        return result
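Examples #3 and #5 depend on LIAM2 internals (expr_eval, context_subset, context_delete, loop_wh_progress), so they cannot run on their own. The core algorithm is a greedy one-by-one match: walk set 1 in descending orderby order and, for each individual, take the remaining set-2 individual with the highest score. A self-contained NumPy sketch of that idea, using a made-up score (minus the age difference):

import numpy as np

def greedy_match(order1, score_matrix):
    """Greedy one-by-one matching sketch.

    order1: 1d array used to rank set-1 rows (highest first).
    score_matrix: (len(set1), len(set2)) array of scores.
    Returns `partner` where partner[i] is the set-2 index matched to
    set-1 row i, or -1 if set 2 ran out of candidates.
    """
    n1, n2 = score_matrix.shape
    partner = np.full(n1, -1, dtype=int)
    available = np.ones(n2, dtype=bool)
    for i in np.argsort(order1)[::-1]:          # best-ranked first
        if not available.any():
            break
        scores = np.where(available, score_matrix[i], -np.inf)
        j = int(scores.argmax())                # best remaining candidate
        partner[i] = j
        available[j] = False                    # each set-2 row is used once
    return partner

# toy example: match individuals with the closest ages
age1 = np.array([25, 40, 33])
age2 = np.array([31, 26, 41, 58])
scores = -np.abs(age1[:, None] - age2[None, :])
print(greedy_match(order1=age1, score_matrix=scores))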
Example #4
def merge_h5(input1_path, input2_path, output_path):
    input1_file = tables.openFile(input1_path, mode="r")
    input2_file = tables.openFile(input2_path, mode="r")

    output_file = tables.openFile(output_path, mode="w")
    output_globals = output_file.createGroup("/", "globals", "Globals")

    print "copying globals from", input1_path,
    copyTable(input1_file.root.globals.periodic, output_file, output_globals)
    print "done."

    input1_entities = input1_file.root.entities
    input2_entities = input2_file.root.entities

    fields1 = get_h5_fields(input1_file)
    fields2 = get_h5_fields(input2_file)

    ent_names1 = set(fields1.keys())
    ent_names2 = set(fields2.keys())

    output_entities = output_file.createGroup("/", "entities", "Entities")
    for ent_name in sorted(ent_names1 | ent_names2):
        print
        print ent_name
        ent_fields1 = fields1.get(ent_name, [])
        ent_fields2 = fields2.get(ent_name, [])
        output_fields = merge_items(ent_fields1, ent_fields2)
        output_table = output_file.createTable(output_entities, ent_name,
                                               np.dtype(output_fields))

        if ent_name in ent_names1:
            table1 = getattr(input1_entities, ent_name)
            print " * indexing table from %s ..." % input1_path,
            input1_rows = index_table_light(table1)
            print "done."
        else:
            table1 = None
            input1_rows = {}

        if ent_name in ent_names2:
            table2 = getattr(input2_entities, ent_name)
            print " * indexing table from %s ..." % input2_path,
            input2_rows = index_table_light(table2)
            print "done."
        else:
            table2 = None
            input2_rows = {}

        print " * merging: ",
        input1_periods = input1_rows.keys()
        input2_periods = input2_rows.keys()
        output_periods = sorted(set(input1_periods) | set(input2_periods))

        def merge_period(period_idx, period):
            if ent_name in ent_names1:
                start, stop = input1_rows.get(period, (0, 0))
                input1_array = table1.read(start, stop)
            else:
                input1_array = None

            if ent_name in ent_names2:
                start, stop = input2_rows.get(period, (0, 0))
                input2_array = table2.read(start, stop)
            else:
                input2_array = None

            if ent_name in ent_names1 and ent_name in ent_names2:
                output_array, _ = mergeArrays(input1_array, input2_array)
            elif ent_name in ent_names1:
                output_array = input1_array
            elif ent_name in ent_names2:
                output_array = input2_array
            else:
                raise Exception("this shouldn't have happened")
            output_table.append(output_array)
            output_table.flush()

        loop_wh_progress(merge_period, output_periods)
        print " done."

    input1_file.close()
    input2_file.close()
    output_file.close()
Example #5
File: matching.py  Project: abozio/Myliam2
    def evaluate(self, context):
        global local_ctx
        global cost

        ctx_filter = context.get('__filter__')

        id_to_rownum = context.id_to_rownum

        # at some point ctx_filter will be cached automatically, so we don't
        # need to take care of it manually here
        if ctx_filter is not None:
            set1filter = expr_eval(ctx_filter & self.set1filter, context)
            set2filter = expr_eval(ctx_filter & self.set2filter, context)
        else:
            set1filter = expr_eval(self.set1filter, context)
            set2filter = expr_eval(self.set2filter, context)

        score_expr = self.score_expr

        used_variables = score_expr.collect_variables(context)
        used_variables1 = ['id'] + [v for v in used_variables
                                    if not v.startswith('__other_')]
        used_variables2 = ['id'] + [v[8:] for v in used_variables
                                    if v.startswith('__other_')]

        set1 = context_subset(context, set1filter, used_variables1)
        set2 = context_subset(context, set2filter, used_variables2)
        orderby = expr_eval(self.orderby, context)
        sorted_set1_indices = orderby[set1filter].argsort()[::-1]
        print "matching with %d/%d individuals" % (set1filter.sum(),
                                                   set2filter.sum())

        #TODO: compute pk_names automatically: variables which are either
        # boolean, or have very few possible values and which are used more
        # than once in the expression and/or which are used in boolean
        # expressions
#        pk_names = ('eduach', 'work')
#        optimized_exprs = {}

        result = np.empty(context_length(context), dtype=int)
        result.fill(-1)

        local_ctx = dict(('__other_' + k if k in used_variables2 else k, v)
                         for k, v in set2.iteritems())
#        print local_ctx
#        test=local_ctx.copy()
#        test.update((k, set1[k]) for k in used_variables1)
#
#

        
######## Attempt at optimal matching with Munkres
        
        if self.option == "optimal": 
            cost = []
            def create_cost(idx, sorted_idx):
    
                global cost
                if not context_length(local_ctx):
                    raise StopIteration
                local_ctx.update((k, set1[k][sorted_idx]) for k in used_variables1)
    
                set2_scores = expr_eval(score_expr, local_ctx)
                cost.append(set2_scores[:].tolist())
                
            loop_wh_progress(create_cost, sorted_set1_indices)       
            resultat = MunkresX.maxWeightMatching(cost)
            for id1,id2 in resultat.items(): 
                result[id_to_rownum[id1]] = id2
                result[id_to_rownum[id2]] = id1    
            return result
        
        else:
            def match_one_set1_individual(idx, sorted_idx):
                global local_ctx   
                if not context_length(local_ctx):
                    raise StopIteration    
                local_ctx.update((k, set1[k][sorted_idx]) for k in used_variables1)
                set2_scores = expr_eval(score_expr, local_ctx)
    #            print set2_scores
                individual2_idx = np.argmax(set2_scores)   
                id1 = local_ctx['id']
                id2 = local_ctx['__other_id'][individual2_idx]    
                local_ctx = context_delete(local_ctx, individual2_idx)
    
                result[id_to_rownum[id1]] = id2
                result[id_to_rownum[id2]] = id1
    
            loop_wh_progress(match_one_set1_individual, sorted_set1_indices)       
            return result
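The "optimal" branch above builds a full set1 x set2 score matrix and hands it to MunkresX.maxWeightMatching, a project-local assignment helper. As an independent illustration (not a drop-in replacement), the same kind of global one-to-one assignment can be computed with SciPy's Hungarian-algorithm implementation:

import numpy as np
from scipy.optimize import linear_sum_assignment

# toy score matrix: score[i, j] = score of matching set1[i] with set2[j]
score = np.array([[4.0, 1.0, 3.0],
                  [2.0, 0.0, 5.0],
                  [3.0, 2.0, 2.0]])

# find the one-to-one pairing that maximizes the total score
rows, cols = linear_sum_assignment(score, maximize=True)
print(list(zip(rows, cols)), score[rows, cols].sum())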
Example #6
    def compute(self, context, set1filter, set2filter, score, orderby,
                pool_size=None, algo='onebyone'):
        global matching_ctx

        if pool_size is not None:
            assert isinstance(pool_size, int)
            assert pool_size > 0

        set1filterexpr = self._getfilter(context, set1filter)
        set1filtervalue = expr_eval(set1filterexpr, context)
        set2filterexpr = self._getfilter(context, set2filter)
        set2filtervalue = expr_eval(set2filterexpr, context)
        set1len = set1filtervalue.sum()
        set2len = set2filtervalue.sum()
        print("matching with %d/%d individuals" % (set1len, set2len), end='')

        varnames = {v.name for v in score.collect_variables()}
        used_variables1 = {n for n in varnames if not n.startswith('__other_')}
        used_variables2 = {n[8:] for n in varnames if n.startswith('__other_')}

        if isinstance(orderby, str):
            assert orderby == 'EDtM'
            orderby_vars = used_variables1
        else:
            orderby_vars = {v.name for v in orderby.collect_variables()}

        if algo == 'onebyone':
            all_vars = {'id'} | used_variables1 | orderby_vars
            set1 = context.subset(set1filtervalue, all_vars, set1filterexpr)
            set2 = context.subset(set2filtervalue, {'id'} | used_variables2,
                                  set2filterexpr)

            # subset creates a dict for the current entity, so .entity_data is a
            # dict
            set1 = set1.entity_data
            set2 = set2.entity_data

            set1['__ids__'] = set1['id'].reshape(set1len, 1)
            set2['__ids__'] = set2['id'].reshape(set2len, 1)

            print()
        else:
            # optimized matching by grouping sets by values, which usually
            # means smaller sets and improved running time.
            assert algo == 'byvalue'

            # if orderby contains variables that are not used in the score
            # expression, this will effectively add variables in the
            # matching context AND group by those variables. This is correct
            # because otherwise (if we did not group by them), we could have
            # groups containing individuals with different values of the
            # ordering variables (ie the ordering would not be respected).
            set1 = group_context(used_variables1 | orderby_vars,
                                 set1filtervalue, context)
            set2 = group_context(used_variables2, set2filtervalue, context)

            # we cannot simply take the [:min(set1len, set2len)] indices like in
            # the non-optimized case and iterate over that because we don't know
            # how many groups we will need to match.
            print(" (%d/%d groups)"
                  % (context_length(set1), context_length(set2)))

        if isinstance(orderby, str):
            orderbyvalue = np.zeros(context_length(set1))
            for name in used_variables1:
                column = set1[name]
                orderbyvalue += (column - column.mean()) ** 2 / column.var()
        else:
            orderbyvalue = expr_eval(orderby, context.clone(entity_data=set1))

        # Delete variables which are not in the score expression (but in the
        # orderby expr or possibly "id") because they are no longer needed and
        # would slow things down.
        context_keep(set1, used_variables1)
        context_keep(set2, used_variables2)

        sorted_set1_indices = orderbyvalue.argsort()[::-1]

        result = np.full(context_length(context), -1, dtype=int)
        id_to_rownum = context.id_to_rownum

        # prefix all keys except __len__
        matching_ctx = {'__other_' + k if k != '__len__' else k: v
                        for k, v in set2.iteritems()}

        def match_cell(idx, sorted_idx, pool_size):
            global matching_ctx

            set2_size = context_length(matching_ctx)
            if not set2_size:
                raise StopIteration

            if pool_size is not None and set2_size > pool_size:
                pool = random.sample(xrange(set2_size), pool_size)
                local_ctx = context_subset(matching_ctx, pool)
            else:
                local_ctx = matching_ctx.copy()

            local_ctx.update((k, set1[k][sorted_idx])
                             for k in {'__ids__'} | used_variables1)

            eval_ctx = context.clone(entity_data=local_ctx)
            set2_scores = expr_eval(score, eval_ctx)
            cell2_idx = set2_scores.argmax()

            cell1ids = local_ctx['__ids__']
            cell2ids = local_ctx['__other___ids__'][cell2_idx]

            if pool_size is not None and set2_size > pool_size:
                # transform pool-local index to set/matching_ctx index
                cell2_idx = pool[cell2_idx]

            cell1size = len(cell1ids)
            cell2size = len(cell2ids)
            nb_match = min(cell1size, cell2size)

            # we could introduce a random choice here but it is not really
            # necessary. In that case, it should be done in group_context
            ids1 = cell1ids[:nb_match]
            ids2 = cell2ids[:nb_match]

            result[id_to_rownum[ids1]] = ids2
            result[id_to_rownum[ids2]] = ids1
            
            if nb_match == cell2size:
                matching_ctx = context_delete(matching_ctx, cell2_idx)
            else:
                # other variables do not need to be modified since the cell
                # only got smaller and was not deleted
                matching_ctx['__other___ids__'][cell2_idx] = cell2ids[nb_match:]

            # FIXME: the expr gets cached for the full matching_ctx at the
            # beginning and then when another woman with the same values is
            # found, it thinks it can reuse the expr but it breaks because it
            # does not have the correct length.

            # the current workaround is to invalidate the whole cache for the
            # current entity but this is not the right way to go.
            # * disable the cache for matching?
            # * use a local cache so that methods after matching() can use
            # what was in the cache before matching(). Shouldn't the cache be
            # stored inside the context anyway?
            expr_cache.invalidate(context.period, context.entity_name)

            if nb_match < cell1size:
                set1['__ids__'][sorted_idx] = cell1ids[nb_match:]
                match_cell(idx, sorted_idx, pool_size)
        loop_wh_progress(match_cell, sorted_set1_indices, pool_size)
        return result
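The 'byvalue' branch relies on group_context() (not shown here) to collapse each set to one cell per distinct combination of the matching variables, carrying the member ids along in __ids__ so that a whole cell can be matched at once. A rough, self-contained sketch of that grouping step, assuming a single integer variable:

import numpy as np

def group_by_value(values, ids):
    """Collapse rows to one cell per distinct value.

    Returns (unique_values, ids_per_cell) where ids_per_cell[k] is the
    array of ids sharing unique_values[k]; a stand-in for the kind of
    structure group_context() is assumed to build.
    """
    unique_values, inverse = np.unique(values, return_inverse=True)
    ids_per_cell = [ids[inverse == k] for k in range(len(unique_values))]
    return unique_values, ids_per_cell

ages = np.array([30, 41, 30, 30, 41])
ids = np.array([10, 11, 12, 13, 14])
print(group_by_value(ages, ids))
# -> (array([30, 41]), [array([10, 12, 13]), array([11, 14])])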
Example #7
File: matching.py  Project: chenyuyou/liam2
    def evaluate(self, context):
        global matching_ctx

        ctx_filter = context.get('__filter__')

        id_to_rownum = context.id_to_rownum

        # at some point ctx_filter will be cached automatically, so we don't
        # need to take care of it manually here
        if ctx_filter is not None:
            set1filter = expr_eval(ctx_filter & self.set1filter, context)
            set2filter = expr_eval(ctx_filter & self.set2filter, context)
        else:
            set1filter = expr_eval(self.set1filter, context)
            set2filter = expr_eval(self.set2filter, context)

        score_expr = self.score_expr

        used_variables = score_expr.collect_variables(context)
        used_variables1 = ['id'] + [v for v in used_variables
                                    if not v.startswith('__other_')]
        used_variables2 = ['id'] + [v[8:] for v in used_variables
                                    if v.startswith('__other_')]

        #TODO: we should detect whether or not we are using non-simple
        # expressions (EvaluableExpression children) and pre-evaluate them,
        # because otherwise they are re-evaluated on all of set2 for each
        # individual in set1. See https://github.com/liam2/liam2/issues/128
        set1 = context_subset(context, set1filter, used_variables1)
        set2 = context_subset(context, set2filter, used_variables2)
        orderby = expr_eval(self.orderby, context)
        set1len = set1filter.sum()
        set2len = set2filter.sum()
        tomatch = min(set1len, set2len)
        sorted_set1_indices = orderby[set1filter].argsort()[::-1]
        set1tomatch = sorted_set1_indices[:tomatch]
        print("matching with %d/%d individuals" % (set1len, set2len))

        #TODO: compute pk_names automatically: variables which are either
        # boolean, or have very few possible values and which are used more
        # than once in the expression and/or which are used in boolean
        # expressions
#        pk_names = ('eduach', 'work')
#        optimized_exprs = {}

        result = np.empty(context_length(context), dtype=int)
        result.fill(-1)

        matching_ctx = dict(('__other_' + k if k in used_variables2 else k, v)
                            for k, v in set2.iteritems())

        #noinspection PyUnusedLocal
        def match_one_set1_individual(idx, sorted_idx):
            global matching_ctx

            if not context_length(matching_ctx):
                raise StopIteration

            local_ctx = matching_ctx.copy()
            local_ctx.update((k, set1[k][sorted_idx]) for k in used_variables1)
#            pk = tuple(individual1[fname] for fname in pk_names)
#            optimized_expr = optimized_exprs.get(pk)
#            if optimized_expr is None:
#                for name in pk_names:
#                    fake_set1['__f_%s' % name].value = individual1[name]
#                optimized_expr = str(symbolic_expr.simplify())
#                optimized_exprs[pk] = optimized_expr
#            set2_scores = evaluate(optimized_expr, mm_dict, set2)
            set2_scores = expr_eval(score_expr, local_ctx)

            individual2_idx = np.argmax(set2_scores)

            id1 = local_ctx['id']
            id2 = matching_ctx['__other_id'][individual2_idx]
            matching_ctx = context_delete(matching_ctx, individual2_idx)

            result[id_to_rownum[id1]] = id2
            result[id_to_rownum[id2]] = id1

        loop_wh_progress(match_one_set1_individual, set1tomatch,
                         title="Matching...")
        return result
Example #8
def merge_group(parent1, parent2, name, output_file, index_col):
    print()
    print(name)
    print('=' * len(name))

    group1 = getattr(parent1, name, None)
    group2 = getattr(parent2, name, None)
    if group1 is None and group2 is None:
        print("node not found in either input files, skipped")
        return

    output_group = output_file.create_group("/", name)
    fields1 = get_group_fields(group1)
    fields2 = get_group_fields(group2)
    ent_names1 = set(fields1.keys())
    ent_names2 = set(fields2.keys())
    for ent_name in sorted(ent_names1 | ent_names2):
        print()
        print(ent_name)
        ent_fields1 = fields1.get(ent_name, [])
        ent_fields2 = fields2.get(ent_name, [])
        output_fields = merge_items(ent_fields1, ent_fields2)
        output_table = output_file.create_table(output_group, ent_name,
                                                np.dtype(output_fields))

        if ent_name in ent_names1:
            table1 = getattr(group1, ent_name)
            # noinspection PyProtectedMember
            print(" * indexing table from %s ..." % group1._v_file.filename,
                  end=' ')
            input1_rows = index_table_light(table1, index_col)
            print("done.")
        else:
            table1 = None
            input1_rows = {}

        if ent_name in ent_names2:
            table2 = getattr(group2, ent_name)
            # noinspection PyProtectedMember
            print(" * indexing table from %s ..." % group2._v_file.filename,
                  end=' ')
            input2_rows = index_table_light(table2, index_col)
            print("done.")
        else:
            table2 = None
            input2_rows = {}

        print(" * merging: ", end=' ')
        input1_periods = input1_rows.keys()
        input2_periods = input2_rows.keys()
        output_periods = sorted(set(input1_periods) | set(input2_periods))

        # noinspection PyUnusedLocal
        def merge_period(period_idx, period):
            if ent_name in ent_names1:
                start, stop = input1_rows.get(period, (0, 0))
                input1_array = table1.read(start, stop)
            else:
                input1_array = None

            if ent_name in ent_names2:
                start, stop = input2_rows.get(period, (0, 0))
                input2_array = table2.read(start, stop)
            else:
                input2_array = None

            if ent_name in ent_names1 and ent_name in ent_names2:
                if 'id' in input1_array.dtype.names:
                    assert 'id' in input2_array.dtype.names
                    output_array, _ = merge_arrays(input1_array, input2_array)
                else:
                    output_array = merge_array_records(input1_array,
                                                       input2_array)

            elif ent_name in ent_names1:
                output_array = input1_array
            elif ent_name in ent_names2:
                output_array = input2_array
            else:
                raise Exception("this shouldn't have happened")
            output_table.append(output_array)
            output_table.flush()

        loop_wh_progress(merge_period, output_periods)
        print(" done.")
Example #9
def merge_h5(input1_path, input2_path, output_path):
    input1_file = tables.open_file(input1_path, mode="r")
    input2_file = tables.open_file(input2_path, mode="r")

    output_file = tables.open_file(output_path, mode="w")

    print("copying globals from", input1_path, end=' ')
    #noinspection PyProtectedMember
    input1_file.root.globals._f_copy(output_file.root, recursive=True)
    print("done.")

    input1_entities = input1_file.root.entities
    input2_entities = input2_file.root.entities

    fields1 = get_h5_fields(input1_file)
    fields2 = get_h5_fields(input2_file)

    ent_names1 = set(fields1.keys())
    ent_names2 = set(fields2.keys())

    output_entities = output_file.create_group("/", "entities", "Entities")
    for ent_name in sorted(ent_names1 | ent_names2):
        print()
        print(ent_name)
        ent_fields1 = fields1.get(ent_name, [])
        ent_fields2 = fields2.get(ent_name, [])
        output_fields = merge_items(ent_fields1, ent_fields2)
        output_table = output_file.create_table(output_entities, ent_name,
                                               np.dtype(output_fields))

        if ent_name in ent_names1:
            table1 = getattr(input1_entities, ent_name)
            print(" * indexing table from %s ..." % input1_path, end=' ')
            input1_rows = index_table_light(table1)
            print("done.")
        else:
            table1 = None
            input1_rows = {}

        if ent_name in ent_names2:
            table2 = getattr(input2_entities, ent_name)
            print(" * indexing table from %s ..." % input2_path, end=' ')
            input2_rows = index_table_light(table2)
            print("done.")
        else:
            table2 = None
            input2_rows = {}

        print(" * merging: ", end=' ')
        input1_periods = input1_rows.keys()
        input2_periods = input2_rows.keys()
        output_periods = sorted(set(input1_periods) | set(input2_periods))

        #noinspection PyUnusedLocal
        def merge_period(period_idx, period):
            if ent_name in ent_names1:
                start, stop = input1_rows.get(period, (0, 0))
                input1_array = table1.read(start, stop)
            else:
                input1_array = None

            if ent_name in ent_names2:
                start, stop = input2_rows.get(period, (0, 0))
                input2_array = table2.read(start, stop)
            else:
                input2_array = None

            if ent_name in ent_names1 and ent_name in ent_names2:
                output_array, _ = merge_arrays(input1_array, input2_array)
            elif ent_name in ent_names1:
                output_array = input1_array
            elif ent_name in ent_names2:
                output_array = input2_array
            else:
                raise Exception("this shouldn't have happened")
            output_table.append(output_array)
            output_table.flush()

        loop_wh_progress(merge_period, output_periods)
        print(" done.")

    input1_file.close()
    input2_file.close()
    output_file.close()
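merge_h5 needs only the two input paths and an output path; the field union (merge_items), per-period row ranges (index_table_light), and row-level merging (merge_arrays) are helpers defined elsewhere in the same script. A hypothetical call, assuming the script is importable as a module:

from merge_h5 import merge_h5  # hypothetical module name

# merge two LIAM2 simulation files entity by entity, period by period
merge_h5("simulation_part1.h5", "simulation_part2.h5", "merged.h5")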