def plot_queries(stats: TrialStats):
    """Plot the ``queries`` series for a=5, a=7 and a=9 on one titled figure.

    Each key is handed to ``stats.plot``. Presumably it selects the trials
    tagged with that ``a`` value (confirm in ``TrialStats.plot``). The figure
    is shown with a legend before returning.
    """
    for series_key in ('a=5', 'a=7', 'a=9'):
        stats.plot('queries', series_key)
    plt.title('Queries')
    plt.legend()
    plt.show()
def plot_accuracies(stats: TrialStats):
    """Plot the ``accuracies`` series for a=5, a=7 and a=9 on one titled figure.

    Each key is handed to ``stats.plot``. Presumably it selects the trials
    tagged with that ``a`` value (confirm in ``TrialStats.plot``). The figure
    is shown with a legend before returning.
    """
    for series_key in ('a=5', 'a=7', 'a=9'):
        stats.plot('accuracies', series_key)
    plt.title('Accuracy')
    plt.legend()
    plt.show()
def accuracy_vs_time_vs_appliances_func(seed: int,
                                        house_trials: int,
                                        duration: int,
                                        a: int,
                                        house_pool: Sequence[int],
                                        matcher_spec: MatcherSpec,
                                        keys=None):
    """Run ``house_trials`` disaggregation trials per house and collect stats.

    Args:
        seed: Base seed. Trial ``i`` of every house uses ``seed + i``, so
            the same trial index gets the same seed across houses.
        house_trials: Number of trials to run for each house.
        duration: Passed as the first argument of
            ``data_combiner.ReconstructionSettings``.
        a: Passed as the third argument of ``ReconstructionSettings``; also
            embedded in each trial's tag.
        house_pool: House numbers to run trials for.
        matcher_spec: Forwarded to ``realtime_disaggregator.run_trial``.
        keys: Forwarded to ``TrialStats(keys=...)``. Defaults to a fresh
            empty dict on every call.

    Returns:
        A ``TrialStats`` holding each trial's stats, tagged
        ``'a={a},h={house},i={i}'``.
    """
    # A literal ``{}`` default would be a single dict shared by every call
    # (and handed to every TrialStats built here); create a fresh one instead.
    if keys is None:
        keys = {}
    trial_stats = TrialStats(keys=keys)

    # NOTE(review): the -1 middle argument is passed through unchanged; its
    # meaning is defined by data_combiner.ReconstructionSettings — confirm.
    reconstruct_settings = data_combiner.ReconstructionSettings(
        duration, -1, a)
    for house_num in house_pool:
        for i in range(house_trials):
            tag = 'a={},h={},i={}'.format(a, house_num, i)
            print('\t' + tag)
            data = data_combiner.new_series(house_num, seed + i,
                                            reconstruct_settings)
            # new_series' result is unpacked positionally into run_trial.
            stats = realtime_disaggregator.run_trial(house_num, matcher_spec,
                                                     *data)

            trial_stats.push(stats, tag)

    return trial_stats
def plot_app_queries_func(stats: TrialStats):
    """Bar-plot how often each appliance label was queried across all trials.

    Every question in every trial's ``questions`` record adds one count to
    each label in its ``labels``. Bars are ordered alphabetically by label,
    and the label count and list are printed. The bars are drawn onto the
    current matplotlib figure; ``plt.show()`` is not called here.

    If no labels were queried at all, an empty plot is produced instead of
    raising.
    """
    questions_sets = stats.select('questions', lambda x: True)
    query_counts = collections.Counter(
        label for qs in questions_sets for q in qs for label in q.labels)

    # ``zip(*[])`` yields nothing, so unpacking it into two names raised
    # ValueError when there were no labels; build the lists explicitly.
    ordered = sorted(query_counts.items(), key=lambda x: x[0])
    labels = [label for label, _ in ordered]
    values = [count for _, count in ordered]

    print(f'There are {len(labels)} labels')
    print(labels)

    plt.xticks(range(len(values)), labels, rotation=45)
    plt.bar(range(len(values)), values, label='Queries')
    plt.ylabel('Queries')
def plot_app_votes_func(stats: TrialStats):
    """Bar-plot the summed question weight ("votes") per appliance label.

    For every question in every trial's ``questions`` record, the question's
    ``weight`` is added to each label in its ``labels``. Bars are ordered
    alphabetically by label. They are drawn onto the current matplotlib
    figure; ``plt.show()`` is not called here.

    If no labels exist, an empty plot is produced instead of raising.
    """
    questions_sets = stats.select('questions', lambda x: True)

    vote_totals = collections.Counter()
    for qs in questions_sets:
        for q in qs:
            for label in q.labels:
                vote_totals[label] += q.weight

    # ``zip(*[])`` yields nothing, so unpacking it into two names raised
    # ValueError when there were no labels; build the lists explicitly.
    ordered = sorted(vote_totals.items(), key=lambda x: x[0])
    labels = [label for label, _ in ordered]
    values = [total for _, total in ordered]

    plt.xticks(range(len(values)), labels, rotation=30)
    plt.bar(range(len(values)), values, label='Votes')
    plt.ylabel('Votes')
def plot_app_accuracies(stats: TrialStats):
    """Write per-appliance accuracy values as text onto the current figure.

    Only the final entry (``[-1]``) of each trial's ``mismatches`` record is
    used. For every label found in those entries (except ``'refrigerator'``),
    the accuracy is ``(n_trials - occurrences) / n_trials``. That is the
    fraction of trials not mismatching the label, assuming a label appears
    at most once per entry. Labels never mismatched get no value.

    Values are written in alphabetical label order at ``(index, 1)``. No
    bars are drawn and ``plt.show()`` is not called.
    """
    per_trial = stats.select('mismatches', lambda t: True)

    # Flatten the last-frame mismatch of every trial into one label list.
    failed = [label for trial in per_trial for label in trial[-1]]
    failures = collections.Counter(failed)

    # 'refrigerator' is left out of the accuracy values (original note:
    # "Remove always on from pairing").
    accuracies = []
    for label in failures:
        if label == 'refrigerator':
            continue
        passed = len(per_trial) - failures[label]
        accuracies.append((label, passed / len(per_trial)))
    accuracies.sort(key=lambda pair: pair[0])

    # NOTE(review): this list still includes 'refrigerator', so the printed
    # count can exceed the number of annotations drawn — confirm intent.
    labels = sorted(failures)
    print(f'There are {len(labels)} labels')
    print(labels)

    for position, (_, accuracy) in enumerate(accuracies):
        plt.text(position, 1, f'{accuracy:.2f}', ha='center')
def plot_confidences(stats: TrialStats):
    """Plot the ``confidences`` series for a=5, a=7 and a=9 and show it.

    Each key is handed to ``stats.plot``. Presumably it selects the trials
    tagged with that ``a`` value (confirm in ``TrialStats.plot``). The figure
    is shown with a legend before returning.
    """
    # The body previously continued past plt.show() with an accuracy
    # computation that used undefined names (``mismatches``, ``failures``)
    # and raised NameError on every call; that dead code has been dropped.
    stats.plot('confidences', 'a=5')
    stats.plot('confidences', 'a=7')
    stats.plot('confidences', 'a=9')
    plt.legend()
    plt.show()


SEED = 434353456
# Results directory for this seed, located next to this script.
path = Path(__file__).parent / f'{SEED}/accuracy_vs_time_vs_appliances'

stats = TrialStats.load(path)

# Aggregate over time (uncomment to view; each call shows its own figure)
#plot_accuracies(stats)
#plot_queries(stats)
#plot_confidences(stats)

# Investigate: How often an appliance is 'on' vs how many times it's queried.
# Goal is to see if some appliances are disproportionately represented,
# i.e. if an appliance is on for 1 hour and another is on for 6,
# the second should be queried 6x as much.
# Maybe weight the first one since it's on 1/6th as much?
# Appliance breakdown
plot_app_queries_func(stats)
plot_app_votes_func(stats)
plot_app_accuracies(stats)
# The plot_app_* helpers only draw onto the current figure and never call
# show(); without this, nothing is displayed when run non-interactively.
plt.show()