class Automator(object):
    """Collect the ``Importer`` jobs for one protocol from its JSON index.

    ``populate_importers`` fills ``self.importers``; ``run_all_imports`` runs
    them and ``describe`` renders a text summary.

    NOTE(review): the exact meaning of ``Importer.check()`` is not visible in
    this file; the code only relies on its truthiness and on the ``errored``
    attribute of the importer after the call.
    """

    # Experiments for which "future" (next-session) importers are generated,
    # keyed by protocol name.
    EXPERIMENTS = {'r1': ('FR1', 'FR2', 'FR3',
                          'catFR1', 'catFR2', 'catFR3',
                          'PAL1', 'PAL2', 'PAL3',
                          'TH1', 'TH2', 'TH3', 'THR',
                          'PS1', 'PS2', 'PS3')}
    # Experiments whose event importers are created with do_math=True.
    MATH_TASKS = ('FR1', 'FR2', 'FR3', 'catFR1', 'catFR2', 'catFR3', 'PAL1', 'PAL2', 'PAL3', 'ltpFR', 'ltpFR2')

    # When True, montage importers are queued even if check() is falsy and
    # nothing errored.
    INCLUDE_TRANSFERRED = False

    def __init__(self, protocol):
        """Open ``<db_root>/protocols/<protocol>.json`` and start with no importers.

        :param protocol: protocol name, e.g. ``'r1'``
        """
        self.protocol = protocol
        self.index = JsonIndexReader(os.path.join(paths.db_root, 'protocols', '{}.json'.format(protocol)))
        self.importers = []

    def populate_importers(self):
        """Queue importers for existing events, existing montages and future sessions."""
        self.add_existing_events_importers()
        self.add_existing_montage_importers()
        self.add_future_events_importers()

    def add_existing_montage_importers(self):
        """Queue a CONVERT_MONTAGE importer for every (subject, montage) in the index.

        An importer is kept when ``check()`` is truthy, when it errored, or
        when ``INCLUDE_TRANSFERRED`` is set.
        """
        for subject in self.index.subjects():
            for montage in self.index.montages(subject=subject):
                # The alias lookup is filtered by the part of the montage
                # string after the first '.'.
                code = self.index.get_value('subject_alias', subject=subject, montage=montage.split('.')[1])
                importer = Importer(Importer.CONVERT_MONTAGE,
                                    subject=subject, montage=montage, protocol=self.protocol, code=code)
                if importer.check() or importer.errored or self.INCLUDE_TRANSFERRED:
                    self.importers.append(importer)

    def session_indexes(self):
        """Yield ``(subject, experiment, session, session_index)`` for every indexed session.

        ``session_index`` is the index filtered down to that experiment,
        subject and session.
        """
        for experiment in self.index.experiments():
            exp_index = self.index.filtered(experiment=experiment)
            for subject in exp_index.subjects():
                subj_index = exp_index.filtered(subject=subject)
                for session in subj_index.sessions():
                    yield subject, experiment, session, subj_index.filtered(session=session)

    def build_existing_event_importer_kwargs(self, subject, experiment, session, index, do_compare=True):
        """Build the keyword arguments for an events ``Importer`` of an indexed session.

        :param subject: subject identifier
        :param experiment: experiment name as stored in the index
        :param session: session as stored in the index (converted with ``int``)
        :param index: index filtered to this single session
        :param do_compare: forwarded unchanged as ``do_compare``
        :return: dict of keyword arguments for ``Importer``
        """
        montage = index.montages()[0]
        code = index.get_value('subject_alias')
        # Fall back to the indexed session / experiment when the index holds
        # no "original" value for them.
        try:
            original_session = int(index.get_value('original_session'))
        except KeyError:
            original_session = int(session)
        try:
            original_experiment = index.get_value('original_experiment')
        except KeyError:
            original_experiment = experiment

        return dict(subject=subject, montage=montage, experiment=original_experiment, session=int(session),
                    new_experiment=experiment, original_session=original_session,
                    do_math=experiment in self.MATH_TASKS, protocol=self.protocol, code=code,
                    do_compare=do_compare)

    def add_existing_events_importers(self):
        """Queue an events importer for each indexed session that needs one.

        BUILD_EVENTS is tried first.  If it did not error it is queued only
        when its ``check()`` is truthy.  If it errored, CONVERT_EVENTS is tried
        the same way; when both errored, the BUILD_EVENTS importer is queued
        so that its error is reported.
        """
        for subject, experiment, session, index in self.session_indexes():
            kwargs = self.build_existing_event_importer_kwargs(subject, experiment, session, index)

            builder = Importer(Importer.BUILD_EVENTS, **kwargs)
            # check() must run before `errored` is read.
            builder_wanted = builder.check()
            if not builder.errored:
                if builder_wanted:
                    self.importers.append(builder)
                continue

            converter = Importer(Importer.CONVERT_EVENTS, **kwargs)
            converter_wanted = converter.check()
            if not converter.errored:
                if converter_wanted:
                    self.importers.append(converter)
                continue

            # Both attempts errored: keep the builder.
            self.importers.append(builder)

    def _future_experiments(self):
        """Return the experiment names configured for this protocol (empty if none)."""
        # EXPERIMENTS is keyed by protocol; iterating the dict itself would
        # yield protocol names such as 'r1' instead of experiments.
        return self.EXPERIMENTS.get(self.protocol, ())

    @staticmethod
    def _latest_session(sessions):
        """Return ``(key, number)`` for the numerically largest entry of ``sessions``.

        ``key`` is the entry exactly as stored (usable for index lookups) and
        ``number`` is its ``int`` value, so string sessions such as ``'10'``
        rank above ``'9'``.
        """
        latest_key = max(sessions, key=int)
        return latest_key, int(latest_key)

    def add_future_events_importers(self):
        """Queue BUILD_EVENTS importers for the next, not yet indexed, session.

        For every subject and every experiment configured for this protocol
        the session following the latest indexed one (or session 0) on the
        subject's highest montage is proposed; the importer is queued only if
        its ``check()`` is truthy.
        """
        experiments = self._future_experiments()
        if not experiments:
            return
        for subject in self.index.subjects():
            subj_index = self.index.filtered(subject=subject)
            max_montage = max(subj_index.montages())
            code = subj_index.get_value('subject_alias', montage=max_montage.split('.')[1])
            for experiment in experiments:
                exp_index = subj_index.filtered(montage=max_montage, experiment=experiment)
                sessions = exp_index.sessions()
                if sessions:
                    session_key, max_session = self._latest_session(sessions)
                    session = max_session + 1
                    try:
                        original_session = int(exp_index.get_value('original_session', session=session_key)) + 1
                    except KeyError:
                        original_session = session
                    try:
                        original_experiment = exp_index.get_value('original_experiment', session=session_key)
                    except KeyError:
                        original_experiment = experiment
                else:
                    session = 0
                    original_session = 0
                    original_experiment = experiment
                kwargs = dict(subject=subject, montage=max_montage, experiment=original_experiment, session=session,
                              new_experiment=experiment, original_session=original_session,
                              do_math=experiment in self.MATH_TASKS, protocol=self.protocol, code=code,
                              do_compare=False)
                importer = Importer(Importer.BUILD_EVENTS, **kwargs)
                if importer.check():
                    self.importers.append(importer)

    def run_all_imports(self):
        """Call ``run()`` on every queued importer, in queue order."""
        for importer in self.importers:
            importer.run()

    def sorted_importers(self):
        """Return the queued importers sorted by their status attributes, then subject.

        The attributes are read from each importer's instance ``__dict__``; a
        missing one raises ``KeyError``.
        """
        order = 'initialized', 'errored', '_should_transfer', 'transferred', 'processed', 'subject'
        return sorted(self.importers, key=lambda imp: [vars(imp)[attr] for attr in order])

    def describe(self):
        """Return the importers' descriptions joined by a separator line.

        :return: ``'No Importers'`` when nothing is queued
        """
        if not self.importers:
            return 'No Importers'
        return '\n---------------\n'.join(importer.describe() for importer in self.sorted_importers())
def get_subject_sessions_by_experiment(experiment, protocol='r1', include_montage_changes=False):
    """
    Yield one tuple per session of ``experiment``.

    If the experiment is listed in ``<rhino_root>/protocols/<protocol>.json``
    the sessions come from that index and ``version`` is the string ``'0'``.
    Otherwise the matlab events files under
    ``<data_root>/events/RAM_<experiment>`` are read; there ``session`` is the
    running count of sessions already yielded for that subject and experiment,
    ``original_session`` is the session stored in the events file, and
    ``version`` is the float parsed from ``expVersion`` (``-1`` if it cannot
    be parsed).

    :param experiment: experiment name, e.g. ``'FR1'``
    :param protocol: protocol name; selects the index file and, upper-cased,
        the events file prefix
    :param include_montage_changes: only used when reading matlab events;
        when False, events files whose subject contains ``'_'`` are skipped
    :return: subject, subject_code,  session, original_session, experiment, version
    """
    json_reader = JsonIndexReader(os.path.join(paths.rhino_root, 'protocols', '%s.json' % protocol))
    if experiment in json_reader.experiments():
        for subject_no_montage in json_reader.subjects(experiment=experiment):
            for montage in json_reader.montages(subject=subject_no_montage, experiment=experiment):
                subject = subject_no_montage if montage == '0' else '%s_%s' % (subject_no_montage, montage)
                sessions = json_reader.sessions(subject=subject_no_montage, montage=montage, experiment=experiment)
                for session in sessions:
                    try:
                        original_session = json_reader.get_value('original_session',
                                                                 subject=subject_no_montage, experiment=experiment,
                                                                 session=session)
                    except (KeyError, ValueError):
                        # The exception the reader raises for a missing value
                        # is not visible here, so both plausible types fall
                        # back to the indexed session (not necessarily robust).
                        original_session = session
                    yield subject_no_montage, subject, session, original_session, experiment, '0'
        return

    if re.match('catFR[0-4]', experiment):
        # Events directories for catFR tasks start with a capital: RAM_CatFR1.
        ram_exp = 'RAM_{}'.format(experiment[0].capitalize() + experiment[1:])
    else:
        ram_exp = 'RAM_{}'.format(experiment)
    events_dir = os.path.join(paths.data_root, 'events', ram_exp)
    events_files = sorted(glob.glob(os.path.join(events_dir, '{}*_events.mat'.format(protocol.upper()))),
                          key=lambda f: f.split('_')[:-1])
    # Experiments already yielded per base subject; used to number sessions.
    seen_experiments = defaultdict(list)
    for events_file in events_files:
        subject = '_'.join(os.path.basename(events_file).split('_')[:-1])
        subject_no_montage = subject.split('_')[0]
        if '_' in subject and not include_montage_changes:
            continue
        mat_events_reader = BaseEventReader(filename=events_file, common_root=paths.data_root)
        logger.debug('Loading matlab events {exp}: {subj}'.format(exp=experiment, subj=subject))
        try:
            mat_events = mat_events_reader.read()
            sessions = np.unique(mat_events['session'])
            if 'expVersion' in mat_events.dtype.names and len(mat_events):
                # The version is read from the fifth-to-last event; clamp the
                # index so files with fewer than five events do not raise
                # IndexError.
                version_str = mat_events[max(-5, -len(mat_events))]['expVersion']
            else:
                version_str = '0'

            # Try "<prefix>_<number>" first, then "<prefix>v<number>".
            version = -1
            for separator in ('_', 'v'):
                try:
                    version = float(version_str.split(separator)[-1])
                    break
                except (AttributeError, TypeError, ValueError):
                    continue

            for session in sessions:
                if 'experiment' in mat_events.dtype.names:
                    experiments = np.unique(mat_events[mat_events['session'] == session]['experiment'])
                else:
                    experiments = [experiment]
                for this_experiment in experiments:
                    n_sessions = seen_experiments[subject_no_montage].count(this_experiment)
                    yield subject_no_montage, subject, n_sessions, session, this_experiment, version
                    seen_experiments[subject_no_montage].append(this_experiment)
        except AttributeError:
            traceback.print_exc()
            logger.error('Could not get session from {}'.format(events_file))