class NormSourceMaker(object):
    """ An experiment, not really useful, keeping the code around for basically no reason.

    Loads a gzipped pickle of pre-normalized samples (a mapping of key -> list of
    samples) and streams them through the signal-block machinery, optionally
    holding one key out for cross-validation.
    """
    def __init__(self, datafolder = None, phases = None, cross_val = False):
        """
        :param datafolder: folder containing ``normalized.pkl.gz`` (with or without a
            trailing slash; ``~`` is expanded). When omitted, the default location
            ``~/PycharmProjects/MatchBrain/ml/normalized.pkl.gz`` is tried.
        :param phases: phase names to use; defaults to the module-level ``phase_names``.
        :param cross_val: when true, start with the first key (index 0) held out.

        If no data can be loaded, ``self.data`` stays ``None`` and the object is left
        with an empty key list instead of missing attributes.
        """
        import os.path
        self.data = None
        # Initialise every attribute up front so a missing data file does not leave
        # a half-built object that fails later with AttributeError.
        self.phases = phases or phase_names
        self.cross_val_keys = []
        self.cross_val_index = 0 if cross_val else None
        if datafolder:
            # join() works whether or not the folder ends with a separator.
            path = os.path.join(os.path.expanduser(datafolder), "normalized.pkl.gz")
            with gzip.open(path, 'rb') as f:
                self.data = cPickle.load(f)
        else:
            # glob() does not expand "~" itself; without this it never matches.
            default = os.path.expanduser("~/PycharmProjects/MatchBrain/ml/normalized.pkl.gz")
            paths = glob.glob(default)
            if paths:
                with gzip.open(paths[0], 'rb') as f:
                    self.data = cPickle.load(f)
        if not self.data:
            return
        self.cross_val_keys = list(self.data.keys())

    def cross_val_next(self):
        """Move the held-out key to the next one, wrapping around at the end.

        Requires cross-validation to be enabled (``cross_val_index`` is not None)
        and at least one loaded key.
        """
        self.cross_val_index = (self.cross_val_index + 1) % len(self.cross_val_keys)

    def get_block(self):
        """Build a SignalBlock streaming every sample except those of the held-out key.

        Each sample ``d`` is read as a mapping with ``'raw'`` and ``'phase'`` entries
        by the two transformers below. ``block.stop`` is registered as the source's
        callback (presumably invoked once the generator runs dry).
        """
        # The held-out index is compared lazily, as the generator is consumed.
        self.source = GenSource(sample
                                for i, key in enumerate(self.cross_val_keys)
                                if i != self.cross_val_index
                                for sample in self.data[key])
        self.ph_source = Transformer([self.source], {self.source.getName(): 'd'}, lambda d: d['phase'])
        self.bw_source = Transformer([self.source], {self.source.getName(): 'd'}, lambda d: d['raw'])
        block = SignalBlock(
            [self.source],
            [self.bw_source, self.ph_source]
        )
        self.source.callback = block.stop
        return block
class LogSourceMaker(object):
    """ The way the system used to work, this hooks into the whole signal system thing I built for the game code, not useful for learning.

    Turns game logs into ``(phase, raw_brainwave)`` pairs and, when cross-validation
    is enabled, holds one tester's log out as validation data.
    """
    def __init__(self, clean_seconds = 3, logfolder = None, phases = None, cross_val = False):
        """
        :param clean_seconds: number of brainwave readings dropped at the start of
            each log and after every accepted phase change.
        :param logfolder: passed to ``loadall``; its own default is used when omitted.
        :param phases: phase names to keep; defaults to the module-level ``phase_names``.
        :param cross_val: when true, keep only the latest log per tester and hold
            one randomly chosen tester out as validation data.
        """
        if logfolder:
            all_dict = loadall(logfolder=logfolder)
        else:
            all_dict = loadall()
        self.seen_val_names = []
        if cross_val:
            def tname_for(fname):
                # Tester name = file name up to the "2015" date stamp.
                just_f = fname.split('/')[-1]
                return just_f.split('2015')[0]
            # groupby only merges adjacent items, so sort by the grouping key itself;
            # keep the lexicographically largest (presumably latest) log per tester.
            all_dict = {tname: all_dict[max(fnames)]
                        for tname, fnames in groupby(sorted(all_dict, key=tname_for), tname_for)}
            val_name = random.choice(list(all_dict))
            self.seen_val_names.append(val_name)
            self.val_dict = {val_name: all_dict[val_name]}
            self.all_dict = {k: v for k, v in all_dict.items() if k != val_name}
        else:
            self.val_dict = {}
            self.all_dict = all_dict
        self.phases = phases or phase_names
        self.clean_seconds = clean_seconds
        # Train on the non-held-out logs only; using `all_dict` here leaked the
        # validation tester into the training data.
        self.raws_per_phase = self.to_raws_per_phase(self.all_dict)
        self.rpp_val = self.to_raws_per_phase(self.val_dict)

    def cross_val_next(self):
        """Hold out a different, not-yet-used tester and rebuild both datasets.

        Every chosen tester is recorded in ``seen_val_names``. Once all testers have
        been used, a new round starts instead of looping forever. The new round
        avoids repeating the current validation tester when another exists.
        """
        pool = dict(self.all_dict)
        pool.update(self.val_dict)
        candidates = [k for k in pool if k not in self.seen_val_names]
        if not candidates:
            self.seen_val_names = []
            candidates = [k for k in pool if k not in self.val_dict] or list(pool)
        next_val = random.choice(candidates)
        self.seen_val_names.append(next_val)
        self.val_dict = {next_val: pool[next_val]}
        self.all_dict = {k: v for k, v in pool.items() if k != next_val}
        self.raws_per_phase = self.to_raws_per_phase(self.all_dict)
        self.rpp_val = self.to_raws_per_phase(self.val_dict)

    def to_raws_per_phase(self, a_dict):
        """Flatten logs into a list of ``(phase, raw)`` tuples in log order.

        :param a_dict: mapping of log name -> list of log entries (dicts).

        Only entries with a ``"data"`` key are inspected:

        * a ``"train"`` entry whose phase is in ``self.phases`` switches the current
          phase and restarts the clean-up counter;
        * a ``"brainwave"`` entry yields ``(phase, raw)`` once ``clean_seconds``
          readings have been skipped. Readings before any accepted phase in a log
          are labelled ``"DISTRACT"``.

        Train entries for unlisted phases are ignored. Readings after them therefore
        keep the previous accepted phase's label.
        """
        result = []
        for entries in a_dict.values():
            cur_phase = None  # only read after seen_phase is set, which also sets this
            seen_phase = False
            clean_counter = self.clean_seconds
            for li in entries:
                if "data" not in li:
                    continue
                data = li["data"]
                if "train" in data and data["train"]["phase"] in self.phases:
                    cur_phase = data["train"]["phase"]
                    seen_phase = True
                    clean_counter = self.clean_seconds
                elif "brainwave" in data:
                    if clean_counter > 0:
                        clean_counter -= 1
                    else:
                        label = cur_phase if seen_phase else "DISTRACT"
                        result.append((label, data["brainwave"]["raw"]))
        return result

    def get_block(self, shift = 1):
        """Build a SignalBlock pairing interpolated brainwave windows with phase labels.

        For every offset in ``range(0, len(first_raw), shift)`` and every pair of
        consecutive ``(phase, raw)`` entries, ``bw_source`` emits
        ``interp_ls(cur_raw, next_raw, offset)`` and ``ph_source`` emits one label.
        Both streams therefore have the same length and order.
        Raises IndexError when there are no samples.
        """
        a_meas = self.raws_per_phase[0][1]
        pairs = list(zip(self.raws_per_phase[:-1], self.raws_per_phase[1:]))
        offsets = list(range(0, len(a_meas), shift))
        # The old phase stream cycled all N entries once per sample index, ignoring
        # `shift`. The brainwave stream had N-1 windows per offset, so labels
        # drifted by one per offset. Emit exactly one label per window instead.
        # NOTE(review): the label is the earlier entry's phase (what the old code
        # produced at offset 0); confirm this is intended for larger offsets.
        self.ph_source = GenSource(cur[0]
                                   for _offset in offsets
                                   for cur, _nxt in pairs)
        self.bw_source = GenSource(interp_ls(cur[1], nxt[1], offset)
                                   for offset in offsets
                                   for cur, nxt in pairs)
        block = SignalBlock(
            [self.bw_source, self.ph_source],
            [Transformer(
                [self.bw_source, self.ph_source],
                {self.bw_source.getName(): 'b', self.ph_source.getName(): 'd'},
                lambda b,d: (b,d))
            ]
        )
        self.bw_source.callback = block.stop
        self.ph_source.callback = block.stop
        return block