class DataGatherer(object):
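    """Collects video comparison data: reads training pairs from disk,
    keeps the video catalog in the DB up to date, and computes per-video
    and per-word rank statistics."""
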
    def __init__(self, cfg):
        self.cfg = cfg
        self.downloader = Downloader(cfg)
        self.logger = Logger(cfg)
        self.parser = AnotherHTMLParser(self.logger)
        self.pairs = set()
        self.db_handler = DBHandler(cfg)
        self._word_dict = None

    def read_raw_pairs(self, delimiter=',', limit=0):
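        """Read raw training pairs from cfg['train_path'].

        Each line is '<id_a><delimiter><id_b><delimiter><winner>'; a winner
        of 'left' means id_a outranks id_b. Returns a set of (above, below)
        tuples, stopping after `limit` unique pairs when limit is non-zero.
        """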
        path = self.cfg['train_path']
        try:
            f = open(path)
        except IOError:
            self.logger.critical("Can't open file '{}'!".format(path))
            sys.exit()

        lines = f.read().split('\n')
        f.close()

        pairs = set()
        i = 0
        for line in lines:
            if not line:
                continue
            if limit and i >= limit:
                break

            i += 1
            elements = line.split(delimiter)
            try:
                if elements[2] == 'left':
                    pair = (elements[0], elements[1])
                else:
                    pair = (elements[1], elements[0])
                if pair in pairs:
                    self.logger.warning('pair {} is duplicate!'.format(pair))
                    i -= 1
                pairs.add(pair)
            except IndexError:
                raise AssertionError('line {} is incorrect!'.format(line))
        return pairs

    def read_pairs(self, delimiter=',', limit=0):
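        """Read cleaned pairs from cfg['train_fixed_path'], one
        '<above><delimiter><below>' line each, as a set of tuples."""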
        path = self.cfg['train_fixed_path']
        try:
            f = open(path)
        except IOError:
            self.logger.critical("Can't open file '{}'!".format(path))
            sys.exit()

        lines = f.read().split('\n')
        f.close()

        pairs = set()
        i = 0
        for line in lines:
            if not line:
                continue
            if limit and i >= limit:
                break

            i += 1
            elements = line.split(delimiter)
            # tuple() never raises IndexError, so validate the shape explicitly
            if len(elements) != 2:
                raise AssertionError('line {} is incorrect!'.format(line))
            pair = tuple(elements)
            if pair in pairs:
                self.logger.warning('pair {} is duplicate!'.format(pair))
                i -= 1
            pairs.add(pair)
        return pairs

    def exclude_untracked_videos(self, pairs):
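        """Return only the pairs whose video ids all exist in the DB."""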
        ids = set(self.db_handler.get_all_video_ids())
        pairs_set = set(pairs)
        for pair in pairs:
            for youtube_id in pair:
                if youtube_id not in ids:
                    pairs_set.remove(pair)
                    break
        return pairs_set

    def rewrite_pairs(self, pairs):
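        """Drop pairs that reference untracked videos and write the rest to
        cfg['train_fixed_path'], one comma-separated pair per line."""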
        pairs_fixed = self.exclude_untracked_videos(pairs)
        with open(self.cfg['train_fixed_path'], 'wb') as f:
            for pair in pairs_fixed:
                f.write(','.join(pair) + '\n')

    def fill_video_catalog(self, pairs, force=False):
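        """Create catalog entries for every video id mentioned in `pairs`;
        unless `force` is set, ids already in the DB are skipped."""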
        lefts_and_rights = zip(*pairs)
        ids = set(lefts_and_rights[0] + lefts_and_rights[1])

        if not force:
            ids_cache = set(self.db_handler.get_all_video_ids())
            ids.difference_update(ids_cache)

        for i, youtube_id in enumerate(ids):
            if i % 100 == 0:
                self.logger.info('processed {} videos.'.format(i))
            self.add_video_by_id(youtube_id)

    def update_video_catalog(self, limit=None):
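        """Re-fetch metadata for videos already in the catalog, up to
        `limit` of them when a limit is given."""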
        ids_cache = set(self.db_handler.get_all_video_ids())
        for i, youtube_id in enumerate(ids_cache):
            if limit and i >= limit:
                break
            self.update_video_by_id(youtube_id)

    def add_video_by_id(self, youtube_id):
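        """Download the video page and store a new catalog entry with its
        title; unavailable videos are skipped silently."""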
        html = self.downloader.get_html(youtube_id)
        if not self.parser._check_video_availability(html):
            return

        video_item = Video(youtube_id)
        video_item.update(title=self.parser.get_video_title(html))
        self.db_handler.add_entry(video_item)

    def update_video_by_id(self, youtube_id):
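        """Refresh title, views, likes and dislikes for an existing entry;
        entries that no longer parse are marked invalid."""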
        html = self.downloader.get_html(youtube_id)
        if not self.parser._check_video_availability(html):
            return

        video_item = self.db_handler.get_video_by_youtube_id(youtube_id)
        try:
            video_item.update(
                title=self.parser.get_video_title(html),
                views=self.parser.get_view_count(html),
                likes=self.parser.get_likes_count(html),
                dislikes=self.parser.get_dislikes_count(html),
            )
        except ParseError:
            video_item.mark_invalid()
        self.db_handler.commit()

    def update_rank1s(self, pairs):
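        """Store the rank1 value from get_rank1_map() on each video row;
        videos absent from the map are logged with a warning."""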
        videos = self.db_handler.get_all_videos()
        rank1_map = self.get_rank1_map(pairs)

        for video in videos:
            if video.youtube_id in rank1_map:
                video.rank1 = rank1_map[video.youtube_id]
            else:
                self.logger.warning('video {} has no rank calculated!'.format(video.youtube_id))

        self.db_handler.db_session.commit()

    def update_rank2s(self, catalog, pairs):
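        """Store the rank2 value from get_rank2_map() on each video row;
        videos absent from the map are logged with a warning."""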
        videos = self.db_handler.get_all_videos()
        rank2_map = self.get_rank2_map(catalog, pairs)

        for video in videos:
            if video.youtube_id in rank2_map:
                video.rank2 = rank2_map[video.youtube_id]
            else:
                self.logger.warning('video {} has no rank calculated!'.format(video.youtube_id))

        self.db_handler.db_session.commit()

    def update_views(self, force=False):
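        """Fill in view counts, by default only for videos with no views
        recorded yet; force=True re-fetches every video."""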
        if force:
            videos = self.db_handler.get_all_videos()
        else:
            videos = self.db_handler.db_session.query(Video).filter(Video.views == None).all()

        for video in videos:
            try:
                video.views = self.parser.get_view_count(self.downloader.get_html(video.youtube_id))
            except ParseError:
                pass

        self.db_handler.commit()

    def get_video_catalog(self):
        return self.db_handler.get_all_video_data()

    def get_rank1_map(self, pairs):
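        """Naive rank: +1 every time a video appears on the winning side
        of a pair, -1 every time it appears on the losing side."""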
        ids_above, ids_below = zip(*pairs)
        rank_map = defaultdict(int)

        for youtube_id in ids_above:
            rank_map[youtube_id] += 1

        for youtube_id in ids_below:
            rank_map[youtube_id] -= 1

        return rank_map

    def get_rank2_map(self, catalog, pairs):
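        """Partially sort the catalog using the pairwise comparisons, then
        assign every video the aggregated rank of its chunk."""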
        chunks = partial_sort(catalog, pairs)
        aggregated_ranks = calculate_aggregated_ranks(chunks)
        assert len(aggregated_ranks) == len(chunks)
        ranked_chunks = zip(aggregated_ranks, chunks)

        r_map = {}
        for rank, chunk in ranked_chunks:
            for youtube_id in chunk:
                r_map[youtube_id] = rank
        return r_map

    def get_char_stat(self):
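        """Return a sorted list of all characters used in video titles."""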
        characters = set()
        videos = self.db_handler.get_all_videos()
        for video in videos:
            if video.title:
                characters.update(video.title)
        return sorted(list(characters))

    def update_lang_stat(self):
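        """Detect and store the language of every video title."""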
        videos = self.db_handler.get_all_videos()
        for video in videos:
            if video.title:
                video.lang = get_lang(video.title)

        self.db_handler.commit()

    def get_all_words(self):
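        """Count occurrences of every prepared word across all titles."""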
        words = defaultdict(int)
        self.logger.info('delimiters: {}'.format(TITLE_DELIMITER))
        videos = self.db_handler.get_all_videos()
        for video in videos:
            if not video.title:
                continue
            for word in extract_words(video.title):
                words[prepare_word(word)] += 1

        return words

    def fill_word_db(self, words):
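        """Persist one Word row, with its occurrence count, per unique word."""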
        for w, count in words.iteritems():
            word = Word(w, None, count)
            self.db_handler.db_session.add(word)
        self.db_handler.commit()

    def fill_words_for_videos(self):
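        """Store on each video the serialized ids of the known words that
        occur in its title."""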
        words = self.db_handler.db_session.query(Word).all()
        word_dict = {}
        for word in words:
            word_dict[word.word] = word

        videos = self.db_handler.get_all_videos()
        for video in videos:
            wordids = set()
            if video.title:  # titles can be missing, as elsewhere in this class
                for word in extract_words(video.title):
                    w = prepare_word(word)
                    if w in word_dict:
                        wordids.add(word_dict[w].id)
            video.wordids = serialize_ids(wordids)

        self.db_handler.commit()

    def calculate_rank1_for_words(self):
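        """Set each frequent word's rank1 to the mean rank1 of the videos
        whose titles contain it; words with count < 10 are ignored."""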
        words = self.db_handler.db_session.query(Word).filter(Word.count >= 10).all()
        word_dict = {}
        for word in words:
            word_dict[word.id] = word

        rank_dict = defaultdict(list)

        videos = self.db_handler.get_all_videos()
        for video in videos:
            word_ids = deserialize_ids(video.wordids)
            for word_id in word_ids:
                if word_id not in word_dict:
                    continue
                rank_dict[word_id].append(video.rank1)

        for word_id in rank_dict:
            if word_id not in word_dict:
                continue
            word_dict[word_id].rank1 = mean(rank_dict[word_id])

        # workaround: force rank 0 for the empty word ''
        null_word = self.db_handler.db_session.query(Word).filter(Word.word == '').one()
        null_word.rank1 = 0

        self.db_handler.commit()

    def get_word_dict_by_word(self):
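        """Lazily build and cache a word -> Word row lookup table."""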
        if not self._word_dict:
            words = self.db_handler.db_session.query(Word).all()
            self._word_dict = {}
            for word in words:
                self._word_dict[word.word] = word

        return self._word_dict

    def calculate_title_rank(self, title, f):
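        """Sum f(word) over the known words of `title`, where f maps a
        Word row to a numeric score."""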
        word_dict = self.get_word_dict_by_word()
        # word_dict keys are prepared words, so normalize before lookup
        title_words = [prepare_word(w) for w in extract_words(title)]

        title_rank = sum(f(word_dict[w]) for w in title_words if w in word_dict)
        return title_rank
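

# Minimal usage sketch (hypothetical; the cfg keys beyond 'train_path' and
# 'train_fixed_path' depend on Downloader/Logger/DBHandler and are not shown
# in this module):
#
#     gatherer = DataGatherer(cfg)
#     raw_pairs = gatherer.read_raw_pairs()
#     gatherer.fill_video_catalog(raw_pairs)   # download metadata for new ids
#     gatherer.rewrite_pairs(raw_pairs)        # drop pairs with missing videos
#     pairs = gatherer.read_pairs()
#     gatherer.update_rank1s(pairs)            # store naive win/loss ranks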