def load_glocalx(rules, is_glocalx_run=False):
    """Create a GLocalX instance from `rules`.

    Rules are considered as this instance's output, i.e. its
    `self.fine_boundary`.

    Args:
        rules (Union(str, set, list)): Path to rules, or directly a
            set/list of rules.
        is_glocalx_run (bool): Whether the given rule file is the output
            of a GLocalX run or not. GLocalX stores its output file in a
            different format than the input rules.

    Returns:
        GLocalX: A GLocalX instance as if it were trained and returned
            `rules`.

    Raises:
        ValueError: If `rules` is a path to a non-existing file, or is
            neither a str, a set, nor a list.
    """
    if isinstance(rules, str) and not os.path.isfile(rules):
        raise ValueError('Not a valid file')

    if isinstance(rules, str):
        if is_glocalx_run:
            # A GLocalX run bundles both the rules and the oracle.
            run = load_run(rules)
            fine_boundary = run['rules']
            glocalx = GLocalX(oracle=run['oracle'])
        else:
            # Plain input-format rules: no oracle is available.
            fine_boundary = Rule.from_json(rules)
            glocalx = GLocalX(oracle=None)
    elif isinstance(rules, (set, list)):
        # Rules handed over directly in memory; no oracle available.
        fine_boundary = rules
        glocalx = GLocalX(oracle=None)
    else:
        raise ValueError('Not a str or set or list')

    # Load rules in the boundary
    glocalx.fine_boundary = fine_boundary

    return glocalx
def run(rules, scoring='rrs', name=None, oracle=None, tr=None, vl=None,
        coverage_weight=1., sparsity_weight=1., alpha=0, beta=0, gamma=-1,
        max_len=-1, debug=20):
    """Run the Scoring framework on a set of rules.

    Scores `rules` against `tr` (and optionally `vl`), stores the scores
    in a CSV file, validates the pruned ruleset, and appends the results
    to a `<name>.results.json` file.

    Arguments:
        rules (str): JSON file with the rules to score.
        scoring (str): Type of scoring to perform. Available names are
            'rrs', which includes fidelity scoring, and 'coverage'.
            Defaults to 'rrs'.
        name (str): Name for the output logs. Defaults to `tr` plus a
            timestamp when not given.
        oracle (Union(str, None)): Oracle to score against. Path to a
            '.h5' model or a '.pickle' pickled model, if any.
        tr (str): Training set.
        vl (str): Validation set, if any.
        coverage_weight (float): Coverage weight vector.
        sparsity_weight (float): Sparsity weight vector.
        alpha (float): Pruning hyperparameter, rules with score less than
            `alpha` are removed from the result.
        beta (float): Pruning hyperparameter, rules with score less than
            the `beta`-percentile are removed from the result.
        gamma (int): Maximum number of rules to use. Negative values mean
            "use all rules".
        max_len (int): Pruning hyperparameter, rules with length more
            than `max_len` are removed from the result. Non-positive
            values disable this pruning.
        debug (int): Minimum debug level (10/20/30/40/50); any other
            value logs everything.
    """
    # Set-up debug: map the numeric level onto logzero's constants;
    # unrecognized values fall back to 0 (log everything).
    log_levels = {
        10: logzero.logging.DEBUG,
        20: logzero.logging.INFO,
        30: logzero.logging.WARNING,
        40: logzero.logging.ERROR,
        50: logzero.logging.CRITICAL,
    }
    logzero.loglevel(log_levels.get(debug, 0))

    if name is None:
        # NOTE(review): crashes if `tr` is also None — presumably `tr` is
        # always provided; confirm against callers.
        name = tr + str(time.time())

    # Info LOG
    logger.info('Rules: ' + str(rules))
    logger.info('name: ' + str(name))
    logger.info('score: ' + str(scoring))
    logger.info('oracle: ' + str(oracle))
    logger.info('vl: ' + str(vl))
    logger.info('coverage weight: ' + str(coverage_weight))
    logger.info('sparsity weight: ' + str(sparsity_weight))
    logger.info('alpha: ' + str(alpha))
    logger.info('beta: ' + str(beta))
    logger.info('max len: ' + str(max_len))

    logger.info('Loading validation...')
    # Load the training set as a plain 2D float array; the last column is
    # assumed to be the label.
    data = genfromtxt(tr, delimiter=',', names=True)
    names = data.dtype.names
    training_set = data.view(np_float).reshape(data.shape + (-1,))
    if vl is not None:
        data = genfromtxt(vl, delimiter=',', names=True)
        validation_set = data.view(np_float).reshape(data.shape + (-1,))
    else:
        validation_set = None

    # Run Scoring
    logger.info('Loading ruleset...')
    rules = Rule.from_json(rules, names)
    # Drop empty rules: they carry no premises to score.
    rules = [r for r in rules if len(r) > 0]

    logger.info('Loading oracle...')
    if oracle is not None:
        if oracle.endswith('.h5'):
            oracle = load_model(oracle)
        elif oracle.endswith('.pickle'):
            with open(oracle, 'rb') as log:
                oracle = pickle.load(log)
        else:
            # Unsupported oracle format: report it instead of bailing out
            # silently.
            logger.error('Unknown oracle format: ' + str(oracle))
            return
        # Replace the label column with the oracle's predictions so rules
        # are scored against the oracle rather than the ground truth.
        if validation_set is not None:
            validation_set = hstack(
                (validation_set[:, :-1],
                 oracle.predict(validation_set[:, :-1].round()).reshape(
                     (validation_set.shape[0], 1))))
        training_set = hstack(
            (training_set[:, :-1],
             oracle.predict(training_set[:, :-1].round()).reshape(
                 (training_set.shape[0], 1))))

    logger.info('Creating scorer...')
    evaluator = MemEvaluator(oracle=oracle)
    scorer = Scorer(score=scoring, evaluator=evaluator, oracle=oracle)

    logger.info('Scoring...')
    scores = scorer.score(rules, training_set, coverage_weight, sparsity_weight)
    logger.info('Storing scores...')
    storage_file = name + '.csv'
    scorer.save(scores=scores, path=storage_file)

    # Validation: fall back to the training set when no validation set is
    # available; negative gamma/max_len mean "no limit".
    logger.info('Validating...')
    validation_dic = validate(
        rules, scores, oracle=oracle,
        vl=validation_set if validation_set is not None else training_set,
        scoring=scoring, alpha=alpha, beta=beta,
        gamma=len(rules) if gamma < 0 else int(gamma),
        max_len=inf if max_len <= 0 else max_len)
    validation_dic['coverage'] = coverage_weight
    validation_dic['sparsity'] = sparsity_weight
    validation_dic['scoring'] = scoring

    # Store info on JSON: append this run to an existing results file, or
    # create a fresh one. The run entry is built once for both branches.
    out_str = name + '.results.json'
    run_entry = {
        'scoring': scoring,
        'coverage': coverage_weight,
        'sparsity': sparsity_weight,
        'alpha': alpha,
        'beta': beta,
        'gamma': gamma,
        'max_len': max_len,
        'results': validation_dic
    }
    if os.path.isfile(out_str):
        with open(out_str, 'r') as log:
            out_dic = json.load(log)
        out_dic['runs'].append(run_entry)
    else:
        out_dic = {
            'name': name,
            'runs': [run_entry]
        }
    with open(out_str, 'w') as log:
        json.dump(out_dic, log)

    pretty_print_results(validation_dic, out_str)