def get_metric_report(info_dicts, args, keys=['-'], LR=None):
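    """Build aggregate and per-class metric report strings.

    Each entry of `info_dicts` should be a confusion-count dict consumable by
    `get_metric`. For every requested metric (all metrics if `args.all_metrics`,
    otherwise just `args.report_metric`), per-class scores are formatted as
    percentages against `keys`, joined with a metric-specific delimiter (or '-'
    when `LR` is given). Returns (total_metrics, class_metric_strs).
    """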
    class_metric_strs, total_metrics = [], []
    report_metrics = [
        'jacc', 'acc', 'mcc', 'f1', 'recall', 'precision', 'var'
    ] if args.all_metrics else [args.report_metric]
    for m in report_metrics:
        for d in info_dicts:
            d.update({'metric': m})
        class_metrics = [get_metric(d) for d in info_dicts]
        total_metrics.append(get_metric(info_dicts))

        if LR is not None:
            delim = '-'
        else:
            delim = {
                'mcc': '#',
                'f1': '+',
                'jacc': '=',
                'acc': '>',
                'var': '%',
                'recall': '<',
                'precision': '~'
            }[m]
        class_metric_strs.append(", ".join(
            '{} {} {:5.2f}'.format(k, delim, f * 100)
            for k, f in zip(keys, class_metrics)))

    return total_metrics, class_metric_strs
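
# Example (sketch): report a single metric for two classes, assuming `info_dicts`
# holds confusion-count dicts built with `update_info_dict` and `args` is an
# argparse-style namespace:
#
#   args = argparse.Namespace(all_metrics=False, report_metric='f1')
#   totals, per_class = get_metric_report(info_dicts, args, keys=['neg', 'pos'])
#   print(totals[0], per_class[0])
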
def _binary_threshold(preds,
                      labels,
                      metric,
                      micro,
                      global_tweaks=1000,
                      debug=False,
                      heads_per_class=1,
                      class_single_threshold=False):
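    """Search for per-class decision thresholds that maximize `metric`.

    For each class, 200 thresholds in [0.005, 1] are grid-searched. With two or
    more classes, `global_tweaks` rounds of random per-class perturbations are
    then applied, keeping changes that improve the micro-averaged metric. When
    `heads_per_class > 1` and `class_single_threshold` is set, all heads of a
    class share one threshold, which is expanded per head in the returned array.
    Returns (global_metric, best_thresholds, category_metrics, info_dicts).
    """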
    avg_metric = 0
    best_thresholds = []
    info_dicts = []
    category_metrics = []
    # Compute threshold per class... *unless* multiple heads per class and one threshold required.
    num_categories = labels.shape[1]
    for category in range(num_categories):
        category_best_threshold = category_best_metric = 0
        for threshold in np.linspace(0.005, 1, 200):
            if heads_per_class > 1 and class_single_threshold:
                info_dict = update_info_dict(
                    defaultdict(int),
                    labels[:, (category * heads_per_class):(category + 1) *
                           heads_per_class],
                    preds[:, (category * heads_per_class):(category + 1) *
                          heads_per_class],
                    threshold=threshold)
            else:
                info_dict = update_info_dict(defaultdict(int),
                                             labels[:, category],
                                             preds[:, category],
                                             threshold=threshold)
            metric_score = get_metric(info_dict, metric, micro)
            if metric_score > category_best_metric or category_best_metric == 0:
                category_best_metric, category_best_threshold, category_best_info_dict = metric_score, threshold, info_dict
        info_dicts.append(category_best_info_dict)
        category_metrics.append(category_best_metric)
        best_thresholds.append(category_best_threshold)

    # HACK -- use micro average here, even if not elsewhere
    micro = True
    best_metric = get_metric(info_dicts, metric, micro)
    # HACK: Attempt to tune thresholds simultaneously... for overall micro average
    if num_categories < 2:
        global_tweaks = 0
    if debug and global_tweaks > 0:
        print('best after individual thresholds (micro %s)' % micro)
        print(best_thresholds)
        print(get_metric(info_dicts, metric, micro))
    for i in range(global_tweaks):
        # Choose random category
        category = np.random.randint(num_categories)
        curr_threshold = best_thresholds[category]
        # tweak randomly
        new_threshold = curr_threshold + (0.08 * (np.random.random() - 0.5))
        if heads_per_class > 1 and class_single_threshold:
            info_dict = update_info_dict(
                defaultdict(int),
                labels[:, (category * heads_per_class):(category + 1) *
                       heads_per_class],
                preds[:, (category * heads_per_class):(category + 1) *
                      heads_per_class],
                threshold=new_threshold)
        else:
            info_dict = update_info_dict(defaultdict(int),
                                         labels[:, category],
                                         preds[:, category],
                                         threshold=new_threshold)
        old_dict = info_dicts[category]
        info_dicts[category] = info_dict
        # compute *global* metrics
        metric_score = get_metric(info_dicts, metric, micro)
        # save new threshold if global metrics improve
        if metric_score > best_metric:
            # print('Better threshold %.3f for category %d' % (new_threshold, category))
            best_thresholds[category] = round(new_threshold, 3)
            best_metric = metric_score
        else:
            info_dicts[category] = old_dict
    if debug and global_tweaks > 0:
        print('final thresholds')
        print(best_thresholds)
        print(get_metric(info_dicts, metric, micro))

    # OK, now *if* we used multiple heads per class (same threshold) copy these back out to final answers
    if heads_per_class > 1 and class_single_threshold:
        best_thresholds = np.concatenate([[best_thresholds[i]] *
                                          heads_per_class
                                          for i in range(num_categories)])
    else:
        best_thresholds = np.array(best_thresholds)
    # print(best_thresholds)
    return (get_metric(info_dicts, metric, micro), best_thresholds,
            category_metrics, info_dicts)
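
# Example (sketch): tune thresholds for a 3-class problem on random data
# (shapes, `'f1'`, and `micro=True` are illustrative choices only):
#
#   labels = (np.random.rand(256, 3) > 0.5).astype(int)
#   preds = np.random.rand(256, 3)
#   score, thresholds, per_class, dicts = _binary_threshold(
#       preds, labels, 'f1', True, global_tweaks=100)
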
def train_logreg(trX,
                 trY,
                 vaX=None,
                 vaY=None,
                 teX=None,
                 teY=None,
                 penalty='l1',
                 max_iter=100,
                 C=2**np.arange(-8, 1).astype(float),
                 seed=42,
                 model=None,
                 eval_test=True,
                 neurons=None,
                 drop_neurons=False,
                 report_metric='acc',
                 automatic_thresholding=False,
                 threshold_metric='acc',
                 micro=False):
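    """Train logistic-regression probe(s) over extracted feature matrices.

    Cross-validates over the regularization values in `C` (scored on the
    validation split when provided, otherwise on train), refits with the best
    value, and evaluates `report_metric` on train, validation, and (optionally)
    test. `neurons` restricts the probe to specific feature columns (or drops
    them if `drop_neurons`), and `automatic_thresholding` tunes decision
    thresholds via `_binary_threshold` before the final evaluation.

    Returns (model, scores, preds, c, nnotzero): `model` is a single
    LogisticRegression for 1-D targets or a list of per-class models, `scores`
    holds the train/val/eval metrics as percentages, `preds` are probabilities
    for the last evaluated split, `c` is the selected regularization value, and
    `nnotzero` is the (average) count of non-zero regression weights.
    """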

    # If only a single value is provided for C, wrap it so we can iterate below.
    # `collections.Iterable` was removed in Python 3.10; use collections.abc instead.
    if not isinstance(C, collections.abc.Iterable):
        C = [C]
    # extract features for given neuron indices
    if neurons is not None:
        if drop_neurons:
            all_neurons = set(range(trX.shape[-1]))
            neurons = list(all_neurons - set(neurons))
        trX = trX[:, neurons]
        if vaX is not None:
            vaX = vaX[:, neurons]
        if teX is not None:
            teX = teX[:, neurons]
    # Cross validation over C
    n_classes = 1
    if len(trY.shape) > 1:
        n_classes = trY.shape[-1]
    scores = []
    if model is None:
        for i, c in enumerate(C):
            if n_classes <= 1:
                model = LogisticRegression(C=c,
                                           penalty=penalty,
                                           max_iter=max_iter,
                                           random_state=seed)
                model.fit(trX, trY)
                blank_info_dict = {
                    'fp': 0,
                    'tp': 0,
                    'fn': 0,
                    'tn': 0,
                    'std': 0.,
                    'metric': threshold_metric,
                    'micro': micro
                }
                if vaX is not None:
                    info_dict = update_info_dict(
                        blank_info_dict.copy(), vaY,
                        model.predict_proba(vaX)[:, -1])
                else:
                    info_dict = update_info_dict(
                        blank_info_dict.copy(), trY,
                        model.predict_proba(trX)[:, -1])
                scores.append(get_metric(info_dict))
                print(scores[-1])
                del model
            else:
                info_dicts = []
                model = []
                for cls in range(n_classes):
                    _model = LogisticRegression(C=c,
                                                penalty=penalty,
                                                max_iter=max_iter,
                                                random_state=seed)
                    _model.fit(trX, trY[:, cls])
                    blank_info_dict = {
                        'fp': 0,
                        'tp': 0,
                        'fn': 0,
                        'tn': 0,
                        'std': 0.,
                        'metric': threshold_metric,
                        'micro': micro
                    }
                    if vaX is not None:
                        info_dict = update_info_dict(
                            blank_info_dict.copy(), vaY[:, cls],
                            _model.predict_proba(vaX)[:, -1])
                    else:
                        info_dict = update_info_dict(
                            blank_info_dict.copy(), trY[:, cls],
                            _model.predict_proba(trX)[:, -1])
                    info_dicts.append(info_dict)
                    model.append(_model)
                scores.append(get_metric(info_dicts))
                print(scores[-1])
                del model
        c = C[np.argmax(scores)]
        if n_classes <= 1:
            model = LogisticRegression(C=c,
                                       penalty=penalty,
                                       max_iter=max_iter,
                                       random_state=seed)
            model.fit(trX, trY)
        else:
            model = []
            for cls in range(n_classes):
                _model = LogisticRegression(C=c,
                                            penalty=penalty,
                                            max_iter=max_iter,
                                            random_state=seed)
                _model.fit(trX, trY[:, cls])
                model.append(_model)
    else:
        c = model.C
    # Predict probabilities and score the regression model on train, val, and test as appropriate.
    # Also count the non-zero regression weights (the number of features the model actually uses).
    # Define blank_info_dict here as well: when a pre-trained model is passed in, it was never
    # created inside the cross-validation branch above and would otherwise be undefined.
    blank_info_dict = {'fp': 0, 'tp': 0, 'fn': 0, 'tn': 0, 'std': 0.,
                       'metric': threshold_metric, 'micro': micro}
    scores = []
    if n_classes == 1:
        nnotzero = np.sum(model.coef_ != 0)
        preds = model.predict_proba(trX)[:, -1]
        train_score = get_metric(
            update_info_dict(blank_info_dict.copy(), trY, preds),
            report_metric)
    else:
        nnotzero = 0
        preds = []
        info_dicts = []
        for cls in range(n_classes):
            nnotzero += np.sum(model[cls].coef_ != 0)
            _preds = model[cls].predict_proba(trX)[:, -1]
            info_dicts.append(
                update_info_dict(blank_info_dict.copy(), trY[:, cls], _preds))
            preds.append(_preds)
        nnotzero /= n_classes
        train_score = get_metric(info_dicts, report_metric)
        preds = np.concatenate([p.reshape((-1, 1)) for p in preds], axis=1)
    scores.append(train_score * 100)
    if vaX is None:
        eval_data = trX
        eval_labels = trY
        val_score = train_score
    else:
        eval_data = vaX
        eval_labels = vaY
        if n_classes == 1:
            preds = model.predict_proba(vaX)[:, -1]
            val_score = get_metric(
                update_info_dict(blank_info_dict.copy(), vaY, preds),
                report_metric)
        else:
            preds = []
            info_dicts = []
            for cls in range(n_classes):
                _preds = model[cls].predict_proba(vaX)[:, -1]
                info_dicts.append(
                    update_info_dict(blank_info_dict.copy(), vaY[:, cls],
                                     _preds))
                preds.append(_preds)
            val_score = get_metric(info_dicts, report_metric)
            preds = np.concatenate([p.reshape((-1, 1)) for p in preds], axis=1)
    val_preds = preds
    val_labels = eval_labels
    scores.append(val_score * 100)
    eval_score = val_score
    threshold = np.array([.5] * n_classes)
    if automatic_thresholding:
        # Tune decision thresholds on the evaluation split. Keep the result as an array:
        # the binary case is converted to a scalar below, and the multi-class case is
        # indexed per class, so converting to float here would break both paths.
        _, threshold, _, _ = _binary_threshold(
            preds.reshape(-1, n_classes), eval_labels.reshape(-1, n_classes),
            threshold_metric, micro)
    if teX is not None and teY is not None and eval_test:
        eval_data = teX
        eval_labels = teY
        if n_classes == 1:
            preds = model.predict_proba(eval_data)[:, -1]
        else:
            preds = []
            for cls in range(n_classes):
                _preds = model[cls].predict_proba(eval_data)[:, -1]
                preds.append(_preds)
            preds = np.concatenate([p.reshape((-1, 1)) for p in preds], axis=1)
    if n_classes == 1:
        threshold = float(threshold.squeeze())
        eval_score = get_metric(
            update_info_dict(blank_info_dict.copy(),
                             eval_labels,
                             preds,
                             threshold=threshold), report_metric)
    else:
        info_dicts = []
        for cls in range(n_classes):
            info_dicts.append(
                update_info_dict(blank_info_dict.copy(),
                                 eval_labels[:, cls],
                                 preds[:, cls],
                                 threshold=threshold[cls]))
        eval_score = get_metric(info_dicts, report_metric)

    scores.append(eval_score * 100)
    return model, scores, preds, c, nnotzero
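
# Example (sketch): probe randomly generated 64-d features for a binary label
# (data and hyperparameters here are illustrative only):
#
#   trX, vaX = np.random.randn(512, 64), np.random.randn(128, 64)
#   trY, vaY = np.random.randint(0, 2, 512), np.random.randint(0, 2, 128)
#   model, scores, preds, c, nnotzero = train_logreg(
#       trX, trY, vaX, vaY, C=[0.5, 1.0], report_metric='acc')
#   print('train/val/eval scores (%):', scores)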