Beispiel #1
0
    def create_histogram(self, normalize_histogram=False):
        """Build one visual-word histogram per image in ``self.im_list``.

        Every image is loaded, run through ``ExtractFeatures`` and its
        descriptors are assigned to clusters with ``self.km.predict``; the
        resulting cluster ids are counted with ``np.bincount``.

        Images that cannot be loaded are logged and skipped, so the returned
        list can be shorter than ``self.im_list``.

        Args:
            normalize_histogram: when true, each histogram is passed through
                ``self.normalize_histogram`` before it is stored.

        Returns:
            list: the histograms of the successfully loaded images, in
            ``self.im_list`` order.
        """
        hists = []

        self.logger.info('Start histogram creation...')
        fe = ExtractFeatures(self.feature_type)
        for idx, image_name in enumerate(self.im_list):
            self.logger.debug('Processing image %s, %s out of %s', image_name,
                              idx, len(self.im_list))
            try:
                img = np.array(Image.open(image_name))
            except Exception as e:
                # Nothing can be extracted without an image. Falling through
                # here would reuse the previous image (or hit an unbound
                # name on the very first one), so skip this entry instead.
                self.logger.error('Failed to load image %s: %s', image_name, e)
                continue

            features = fe.extractFeature(img)
            # The descriptors are reshaped into rows of 200 values before
            # they are handed to the clustering model.
            words = self.km.predict(features.reshape(-1, 200))

            histogram = np.bincount(words, minlength=self.vocab_size)

            if normalize_histogram:
                histogram = self.normalize_histogram(histogram)

            hists.append(histogram)
        self.logger.debug('')
        self.logger.info('Finished histogram creation')
        return hists
Beispiel #2
0
    def __init__(self, g2p_dir, nlpnet_model_dir='', dic_file=''):
        """Load the G2P models and the optional exception dictionary.

        Args:
            g2p_dir: directory holding the ``*_delaf.pkl`` and
                ``*_simple.pkl`` classifier, vectorizer and label-encoder
                files loaded below.
            nlpnet_model_dir: value forwarded to ``ExtractFeatures``.
            dic_file: optional path to a UTF-8 text file with one
                ``entry;transcription`` pair per line. An empty string
                (the default) disables the exception dictionary.

        Raises:
            ValueError: if a non-blank dictionary line contains no ``;``.
        """
        # Local import keeps this block self-contained; io.open accepts an
        # explicit encoding on both Python 2 and Python 3.
        import io

        # NOTE(review): joblib.load unpickles arbitrary objects -- only point
        # g2p_dir at trusted model files.
        self.clf_delaf = joblib.load('%s/g2p_clf_delaf.pkl' % g2p_dir)
        self.vectorizer_delaf = joblib.load('%s/vectorizer_delaf.pkl' % g2p_dir)
        self.lab_encoder_delaf = joblib.load('%s/lab_encoder_delaf.pkl' % g2p_dir)

        self.clf_simple = joblib.load('%s/g2p_clf_simple.pkl' % g2p_dir)
        self.vectorizer_simple = joblib.load('%s/vectorizer_simple.pkl' % g2p_dir)
        self.lab_encoder_simple = joblib.load('%s/lab_encoder_simple.pkl' % g2p_dir)

        self.feat_extractor = ExtractFeatures(nlpnet_model_dir)

        # Load and process the exception dictionary.
        # Transcriptions stored in this dictionary are assumed to be right;
        # they are returned instead of being transcribed online.
        self.dic = {}
        if dic_file != '':
            # Decode while reading instead of calling .decode() on each line,
            # which fails on Python 3 where text-mode lines are already str.
            with io.open(dic_file, encoding='utf-8') as dic:
                for line in dic:
                    line = line.strip()
                    if not line:
                        # Tolerate blank lines (e.g. a trailing newline).
                        continue

                    # Separate the entry and the transcription; split on the
                    # last ';' only so the unpacking cannot receive more
                    # than two parts.
                    entry, trans = line.rsplit(';', 1)

                    entry = entry.strip()
                    trans = trans.strip()

                    # Create a DelafEntry object, in order to be able to
                    # retrieve the word and the grammatical info.
                    delaf_entry = DelafEntry(entry)
                    word = delaf_entry.getWord()
                    pos = delaf_entry.getPos()
                    # Keep transcriptions already stored for other POS tags
                    # of the same word instead of resetting the inner dict.
                    self.dic.setdefault(word, {})[pos] = trans
    def feature_extraction(self):
        """Extract the feature matrix of every image in ``self.im_list``.

        Each image is loaded, passed to ``ExtractFeatures.extractFeature``
        and the result is reshaped into rows of 200 values.

        Images that cannot be loaded are logged and skipped, so the returned
        list can be shorter than ``self.im_list``.

        Returns:
            list: one reshaped feature array per successfully loaded image,
            in ``self.im_list`` order.
        """
        self.logger.info('Start feature extraction...')
        feature_extractor = ExtractFeatures(self.feature_type)

        # Feature list
        all_features = []
        for idx, image_name in enumerate(self.im_list):
            self.logger.debug('Processing image %s, %s out of %s', image_name,
                              idx, len(self.im_list))
            try:
                img = np.array(Image.open(image_name))
            except Exception as e:
                # The logger object itself is not callable; report through
                # .error(). Then skip the image: falling through would reuse
                # the previous image or hit an unbound name.
                self.logger.error('Failed to load image %s: %s', image_name, e)
                continue

            features = feature_extractor.extractFeature(img)
            self.logger.debug('Feature shape %s', features.shape)

            all_features.append(features.reshape(-1, 200))
        self.logger.info('Finished feature extraction')
        return all_features
    def __init__(self, mjlog_file=None, log_url=None, stop_tag=None):
        """Load a game log and select the round to work on.

        Args:
            mjlog_file: path to a local log file, read as UTF-8. Used only
                when ``log_url`` is not given; seat 0 and round index 1 are
                selected in that case.
            log_url: log URL; the log id, seat and round are taken from it
                by ``self._parse_url`` and the content is fetched with
                ``self._download_log_content``. Takes precedence over
                ``mjlog_file``.
            stop_tag: stored as ``self.stop_tag``.

        Raises:
            ValueError: if neither ``log_url`` nor ``mjlog_file`` is given.
        """
        if log_url:
            log_id, player_position, needed_round = self._parse_url(log_url)
            log_content = self._download_log_content(log_id)
        elif mjlog_file:
            with open(mjlog_file, encoding="utf8") as f:
                log_content = f.read()
            player_position = 0  # tw: seat
            needed_round = 1  # ts: round
        else:
            # Without a source the names below would be unbound and the
            # failure would surface as a confusing error further down.
            raise ValueError("Either log_url or mjlog_file must be provided")
        rounds = self._parse_rounds(log_content)

        self.player_position = player_position
        self.round_content = rounds[needed_round]
        self.stop_tag = stop_tag
        self.decoder = TenhouDecoder()

        # ADD: to get results of all rounds
        self.rounds = rounds
        # ADD: to extract features to be saved
        self.extract_features = ExtractFeatures()
Beispiel #5
0
    def generate_features(self, file_path, label):
        """Cut an audio file into segments and record features for each one.

        The file is loaded with ``librosa.load`` at the configured sample
        rate and split into consecutive chunks of
        ``self.samples_per_segment`` samples; a trailing partial chunk is
        dropped. For every chunk the spectral centroid, zero-crossing rate,
        mel spectrogram, onset strength and RMS are computed through
        ``ExtractFeatures`` and appended, together with ``label``, to
        ``self.data["data"]``.

        Args:
            file_path (str): path of the sample file.
            label (int): 0 for music, 1 for speech.
        """
        pre_process_cfg = self.config["pre_process"]
        audio, _sample_rate = librosa.load(file_path,
                                           sr=pre_process_cfg["sample_rate"])
        segment_count = floor(len(audio) / self.samples_per_segment)
        extractor = ExtractFeatures()
        # Kept from the original implementation even though the value is
        # not consumed below.
        num_mfcc_vectors_per_segment = math.ceil(
            self.samples_per_segment / pre_process_cfg["hop_length"])

        # Walk over every complete segment of the audio file.
        for segment_idx in range(segment_count):
            # Sample window covered by the current segment.
            begin = self.samples_per_segment * segment_idx
            end = begin + self.samples_per_segment
            chunk = audio[begin:end]

            # The dict entries are evaluated top to bottom, i.e. in the same
            # order the extractor was called originally. MFCC extraction is
            # intentionally left out.
            record = {
                "centroid": extractor.generate_spectral_centroid(
                    self.config, chunk),
                "zcr": extractor.generate_zero_crossing_rate(
                    self.config, chunk),
                "mel_spec": extractor.generate_mel_spectogram(
                    self.config, chunk),
                "onset": extractor.geneate_onset_strength(
                    self.config, chunk),
                "rms": extractor.generate_rms(self.config, chunk),
                "label": label,
            }
            self.data["data"].append(record)
class TenhouLogReproducer(object):
    """
    The way to debug bot decisions that it made in real tenhou.net games.

    A log is loaded either from a URL or from a local file, split into
    rounds, and replayed tag by tag against a ``Table`` while features are
    extracted through ``ExtractFeatures``.
    """

    def __init__(self, mjlog_file=None, log_url=None, stop_tag=None):
        """Load a game log and select the round to work on.

        Args:
            mjlog_file: path to a local log file, read as UTF-8. Used only
                when ``log_url`` is not given; seat 0 and round index 1 are
                selected in that case.
            log_url: log URL; log id, seat (``tw``) and round (``ts``) are
                parsed from its query string. Takes precedence over
                ``mjlog_file``.
            stop_tag: tag at which ``reproduce`` stops replaying.

        Raises:
            ValueError: if neither ``log_url`` nor ``mjlog_file`` is given.
        """
        if log_url:
            log_id, player_position, needed_round = self._parse_url(log_url)
            log_content = self._download_log_content(log_id)
        elif mjlog_file:
            with open(mjlog_file, encoding="utf8") as f:
                log_content = f.read()
            player_position = 0  # tw: seat
            needed_round = 1  # ts: round
        else:
            # Without a source the names below would be unbound and the
            # failure would surface as a confusing error further down.
            raise ValueError("Either log_url or mjlog_file must be provided")
        rounds = self._parse_rounds(log_content)

        self.player_position = player_position
        self.round_content = rounds[needed_round]
        self.stop_tag = stop_tag
        self.decoder = TenhouDecoder()

        # ADD: to get results of all rounds
        self.rounds = rounds
        # ADD: to extract features to be saved
        self.extract_features = ExtractFeatures()

    def reproduce(self, dry_run=False):
        """Replay ``self.round_content`` on a fresh ``Table``.

        Every tag updates the table (round init, draws, discards, melds,
        riichi calls) for all four players and is passed to
        ``execute_extraction``. When the following tag is an ``<AGARI`` tag
        whose ``who`` equals ``fromWho`` and the corresponding score entry is
        positive, the extracted features are written via
        ``feature_to_logger``.

        Args:
            dry_run: when true, every tag is printed and the whole round is
                replayed. When false, replay stops at ``self.stop_tag`` and,
                if a stop tag was configured, the hand and the resulting
                discard are printed afterwards.
        """
        draw_tags = ['T', 'U', 'V', 'W']
        discard_tags = ['D', 'E', 'F', 'G']

        # Raw strings: '\d' must reach the regex engine untouched.
        draw_regex = re.compile(r'^<[{}]+\d*'.format(''.join(draw_tags)))
        discard_regex = re.compile(r'^<[{}]+\d*'.format(''.join(discard_tags)))

        # Attributes read from an <AGARI> tag; compiled once per call rather
        # than once per tag.
        who_regex = re.compile(r'who="(\d+)"')
        from_who_regex = re.compile(r'fromWho="(\d+)"')
        # sc holds eight comma separated, optionally signed integers.
        score_regex = re.compile(r'sc="((?:[+-]?\d+,){7}[+-]?\d+)"')

        table = Table()
        tag_count = len(self.round_content)
        for n, tag in enumerate(self.round_content):
            if dry_run:
                print(tag)

            if not dry_run and tag == self.stop_tag:
                break

            if 'INIT' in tag:
                values = self.decoder.parse_initial_values(tag)

                shifted_scores = [
                    values['scores'][self._normalize_position(
                        x, self.player_position)] for x in range(0, 4)
                ]

                table.init_round(
                    values['round_number'],
                    values['count_of_honba_sticks'],
                    values['count_of_riichi_sticks'],
                    values['dora_indicator'],
                    self._normalize_position(self.player_position,
                                             values['dealer']),
                    shifted_scores,
                )

                hands = [[
                    int(x) for x in self.decoder.get_attribute_content(
                        tag, 'hai{}'.format(seat)).split(',')
                ] for seat in range(4)]

                # Initialize all players on the table (not only the main
                # player), rotated so that table seat 0 is the main player.
                for offset in range(4):
                    table.players[offset].init_hand(
                        hands[(self.player_position + offset) % 4])

                # When a new round starts, the extractor is reinitialized.
                self.extract_features.__init__()

            # We must deal with ALL players.
            if draw_regex.match(tag) and 'UN' not in tag:
                tile = self.decoder.parse_tile(tag)

                # The first draw letter found in the tag selects the seat.
                for seat, sign in enumerate(draw_tags):
                    if sign in tag:
                        table.players[seat].draw_tile(tile)
                        break

            if discard_regex.match(tag) and 'DORA' not in tag:
                tile = self.decoder.parse_tile(tag)
                player_sign = tag.upper()[1]

                # TODO: the seat is computed directly instead of through
                # self._normalize_position, which did not give a usable
                # seat here. This is a temporary solution to be reviewed.
                player_seat = (discard_tags.index(player_sign) +
                               self.player_position) % 4

                if player_seat == 0:
                    table.players[player_seat].discard_tile(
                        DiscardOption(table.players[player_seat], tile // 4, 0,
                                      [], 0))
                else:
                    is_tsumogiri = tile == table.players[player_seat].last_draw
                    # it is important to use table method,
                    # to recalculate revealed tiles and etc.
                    table.add_discarded_tile(player_seat, tile, is_tsumogiri)
                    table.players[player_seat].tiles.remove(tile)

            if '<N who=' in tag:
                meld = self.decoder.parse_meld(tag)
                # Again, the seat is computed directly here.
                player_seat = (meld.who + self.player_position) % 4
                table.add_called_meld(player_seat, meld)

                # The called tile is added to the caller's hand for every
                # player, to keep the tile count of the hand correct.
                if meld.type != Meld.KAN and meld.type != Meld.CHANKAN:
                    table.players[player_seat].draw_tile(meld.called_tile)

            if '<REACH' in tag and 'step="1"' in tag:
                who_called_riichi = self._normalize_position(
                    self.player_position,
                    self.decoder.parse_who_called_riichi(tag))
                table.add_called_riichi(who_called_riichi)

            # This part is to extract the features that will be used to train
            # our model.
            next_tag = self.round_content[n + 1] if n + 1 < tag_count else ""
            if '<AGARI' in next_tag:
                who = int(who_regex.search(next_tag).group(1))
                from_who = int(from_who_regex.search(next_tag).group(1))
                scores = [
                    float(s)
                    for s in score_regex.search(next_tag).group(1).split(",")
                ]
                score = scores[from_who * 2 + 1]
                player_seat, features = self.execute_extraction(tag, table)

                # tsumo (who == fromWho, score > 0) is a valid sample for
                # our training. The first field of the features is the table
                # info, so player info starts at index 1, hence who + 1.
                if (who == from_who and features is not None
                        and player_seat is not None and score > 0):
                    self.feature_to_logger(features, who + 1, score)
            else:
                # The result is not used, but the extractor is still invoked
                # for every tag.
                self.execute_extraction(tag, table)

        # The final step needs a stop tag to parse a tile from; without one
        # (e.g. when replaying whole rounds) there is nothing to print.
        if not dry_run and self.stop_tag:
            tile = self.decoder.parse_tile(self.stop_tag)
            print('Hand: {}'.format(table.player.format_hand_for_print(tile)))

            # to rebuild all caches
            table.player.draw_tile(tile)
            tile = table.player.discard_tile()

            # real run, you can stop debugger here
            table.player.draw_tile(tile)
            tile = table.player.discard_tile()

            print('Discard: {}'.format(
                TilesConverter.to_one_line_string([tile])))

    def feature_to_logger(self, features, player_seat, score):
        """Write the table info, one player's info and the score to logger2.

        Args:
            features: ``;``-separated string with exactly six fields; field 0
                is the table info.
            player_seat: index (into the split fields) of the player info to
                log.
            score: value appended as the last ``;``-separated field.
        """
        features_list = features.split(";")
        assert len(features_list) == 6, "<D> Features format incorrect!"
        table_info = features_list[0]
        player_info = features_list[player_seat]
        logger2.info(table_info + ";" + player_info + ";" + str(score))

    def execute_extraction(self, tag, table):
        """
        Extract score features for a draw or discard tag.

        D/E/F/G are for discards
        T/U/V/W are for draws

        :param tag: log tag to inspect
        :param table: table passed to the feature extractor
        :return: ``(seat, features)`` with seat in 1..4 for the first
            matching letter pair, or ``(None, None)`` when the tag matches
            none of them (the extractor is not called in that case)
        """
        seat_markers = (
            (1, ('<T', '<D')),
            (2, ('<U', '<E')),
            (3, ('<V', '<F')),
            (4, ('<W', '<G')),
        )
        for seat, markers in seat_markers:
            if any(marker in tag for marker in markers):
                return seat, self.extract_features.get_scores_features(table)
        return None, None

    def reproduce_all(self, dry_run=False):
        """Run ``reproduce`` for every parsed round, printing a separator."""
        for r in self.rounds:
            self.round_content = r
            self.reproduce(dry_run=dry_run)
            print("--------------------------------------\n")

    def _normalize_position(self, who, from_who):
        """Return the seat of ``who`` relative to ``from_who`` (0..3)."""
        positions = [0, 1, 2, 3]
        # A negative difference indexes from the end of the list, which
        # wraps the seat around the table.
        return positions[who - from_who]

    def _parse_url(self, log_url):
        """
        Read the log id, seat and round from the query string of a log URL
        :param log_url: URL with ``log``, ``tw`` and ``ts`` query parameters
        :return: ``(log_id, player, round_number)``; missing parameters
            default to ``''``, ``0`` and ``0``
        """
        log_id, player, round_number = '', 0, 0
        for item in log_url.split('?')[1].split('&'):
            item = item.split('=')
            if item[0] == 'log':
                log_id = item[1]
            elif item[0] == 'tw':
                player = int(item[1])
            elif item[0] == 'ts':
                round_number = int(item[1])
        return log_id, player, round_number

    def _download_log_content(self, log_id):
        """
        Check the log file, and if it is not there download it from tenhou.net
        :param log_id: id of the log, also used as the cache file name
        :return: log content as text
        """
        temp_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                   'logs')
        # exist_ok avoids the check-then-create race of exists() + mkdir().
        os.makedirs(temp_folder, exist_ok=True)

        log_file = os.path.join(temp_folder, log_id)
        if os.path.exists(log_file):
            with open(log_file, 'r', encoding="utf8") as f:
                return f.read()

        url = 'http://e.mjv.jp/0/log/?{0}'.format(log_id)
        response = requests.get(url, timeout=30)
        # Do not cache an HTTP error page as if it were a game log.
        response.raise_for_status()

        with open(log_file, 'w', encoding="utf8") as f:
            f.write(response.text)

        return response.text

    def _parse_rounds(self, log_content):
        """
        Build list of round tags
        :param log_content: raw log text
        :return: list of rounds, each a list of tag strings starting with
            its INIT tag; tags seen before the first INIT are dropped
        """
        # we dont need seed information
        shuffle_attr = re.compile(r'shuffle="[^"]*"')

        rounds = []
        game_round = []
        # A tag is everything up to and including the next '>'; text after
        # the last '>' is ignored.
        for tag in re.findall(r'[^>]*>', log_content):
            # not useful tags
            if 'mjloggm' in tag or 'TAIKYOKU' in tag:
                continue

            # new round was started
            if 'INIT' in tag:
                rounds.append(game_round)
                game_round = []
                # to save some memory we can remove not needed information
                # from logs
                tag = shuffle_attr.sub('', tag)

            # the end of the game; the list is shared, so the tag appended
            # below still ends up in the stored round
            if 'owari' in tag:
                rounds.append(game_round)

            # add processed tag to the round
            game_round.append(tag)

        return rounds[1:]
Beispiel #7
0
class TenhouLogReproducer(object):
    """
    The way to debug bot decisions that it made in real tenhou.net games
    """

    def __init__(self, mjlog_file=None, log_url=None, stop_tag=None):
        if log_url:
            log_id, player_position, needed_round = self._parse_url(log_url)
            log_content = self._download_log_content(log_id)
        elif mjlog_file:
            with open(mjlog_file, encoding="utf8") as f:
                log_id = mjlog_file.split("/")[-1].split(".")[0]
                player_position = 0 # tw: seat
                needed_round = 1 # ts: round
                log_content = f.read()
        rounds = self._parse_rounds(log_content)
        
        self.player_position = player_position
        self.round_content = rounds[needed_round]
        self.stop_tag = stop_tag
        self.decoder = TenhouDecoder()
        
        
        # ADD: to get results of all rounds
        self.rounds = rounds
        # ADD: to extract features to be saved
        self.extract_features = ExtractFeatures()

    def reproduce(self, dry_run=False):
        draw_tags = ['T', 'U', 'V', 'W']
        discard_tags = ['D', 'E', 'F', 'G']


        player_draw = draw_tags[self.player_position]
        player_draw_regex = re.compile('^<[{}]+\d*'.format(''.join(player_draw)))
        
        draw_regex = re.compile('^<[{}]+\d*'.format(''.join(draw_tags)))
        discard_regex = re.compile('^<[{}]+\d*'.format(''.join(discard_tags)))

        table = Table()
        score = 1
        
        skip = False # We neglect those unwanted records
        clean_records = [] # We use a list to store the clean records
        
        for n, tag in enumerate(self.round_content):
            if dry_run:
                if (draw_regex.match(tag) and 'UN' not in tag) or (discard_regex.match(tag) and 'DORA' not in tag):
                    tile = self.decoder.parse_tile(tag)
                    print("%s %s"%(tag, TilesConverter.to_one_line_string([tile])))
                else:
                    print(tag)

            if not dry_run and tag == self.stop_tag:
                break

            if 'INIT' in tag:
                values = self.decoder.parse_initial_values(tag)

                shifted_scores = []
                for x in range(0, 4):
                    shifted_scores.append(values['scores'][self._normalize_position(x, self.player_position)])

                table.init_round(
                    values['round_number'],
                    values['count_of_honba_sticks'],
                    values['count_of_riichi_sticks'],
                    values['dora_indicator'],
                    self._normalize_position(self.player_position, values['dealer']),
                    shifted_scores,
                )

                hands = [
                    [int(x) for x in self.decoder.get_attribute_content(tag, 'hai0').split(',')],
                    [int(x) for x in self.decoder.get_attribute_content(tag, 'hai1').split(',')],
                    [int(x) for x in self.decoder.get_attribute_content(tag, 'hai2').split(',')],
                    [int(x) for x in self.decoder.get_attribute_content(tag, 'hai3').split(',')],
                ]
                 
                # DEL: we can't only initialize the main player, we must initialize
                # other players as well.
                #table.player.init_hand(hands[self.player_position])
                
                # ADD: initialize all players on the table
                table.players[0].init_hand(hands[self.player_position])
                table.players[1].init_hand(hands[(self.player_position+1)%4])
                table.players[2].init_hand(hands[(self.player_position+2)%4])
                table.players[3].init_hand(hands[(self.player_position+3)%4])
                
                # ADD: when restart a new game, we need to reinitialize the config
                self.extract_features.__init__()
                # raw records of a new game
                raw_records = []

            # Trigger skip condition after <INIT>, b/c <INIT> has higher priority than skip
            if skip:
                continue

            # We must deal with ALL players.
            #if player_draw_regex.match(tag) and 'UN' not in tag:
            if draw_regex.match(tag) and 'UN' not in tag:
                tile = self.decoder.parse_tile(tag)
                
                # CHG: we must deal with ALL players
                #table.player.draw_tile(tile)
                if "T" in tag:
                    table.players[0].draw_tile(tile)
                elif "U" in tag:
                    table.players[1].draw_tile(tile)
                elif "V" in tag:
                    table.players[2].draw_tile(tile)
                elif "W" in tag:
                    table.players[3].draw_tile(tile)
                    #print("After draw `W`:", table.players[3].tiles)
                
            if discard_regex.match(tag) and 'DORA' not in tag:
                tile = self.decoder.parse_tile(tag)
                
                player_sign = tag.upper()[1]
                
                # TODO: I don't know why the author wrote the code as below, the 
                # player_seat won't work if we use self._normalize_position. This 
                # might be a tricky part, and we need to review it later.
                #player_seat = self._normalize_position(self.player_position, discard_tags.index(player_sign))
                
                # Temporally solution to modify the player_seat
                player_seat = (discard_tags.index(player_sign) + self.player_position)%4

                # Whenever a tile is discarded, the other players will check their
                # hands to see if they could call meld (steal the tile).
                try:
                    next_tag = self.round_content[n+1]
                except:
                    next_tag = ""
                for i in range(4):
                    if i!=player_seat:
                        is_kamicha_discard = (i-1==player_seat)
                        comb = table.players[i].get_possible_melds(tile, is_kamicha_discard)
                        table.players[i].possible_melds = comb   
                        if len(comb)>0: # he can call melds now
                              #print("\nplayer %s closed hand: %s"%(i,[t//4 for t in table.players[i].closed_hand]))
                              print("player %s closed hand: %s"%(i,TilesConverter.to_one_line_string(table.players[i].closed_hand)))
                              print("Tag: %s\nNext tag: %s"%(tag, next_tag))
                              if '<N who=' in next_tag:
                                  meld = self.decoder.parse_meld(next_tag)
                                  # TODO: Obviously the player seat is confusing again. I can't
                                  # get this part fixed. So let's just manually fix it for the moment.
                                  meld.from_who = player_seat
                                  print("%s meld from %s: get %s to form %s(%s): %s\n"%(meld.who, meld.from_who, meld.called_tile, meld.type, meld.opened, meld.tiles))
                                  assert meld.called_tile==tile, "Called tile NOT equal to discarded tile!"
                                  meld_tiles = [t//4 for t in meld.tiles]
                              else:
                                  meld_tiles = []
                              # This part is to extract the features for stealing
                              features = self.extract_features.get_stealing_features(table)
                              # write data to log file
                              # Note: since the first info is table info, player info starts
                              # from 1, so we use (i+1) here.
                              self.feature_to_logger(features, i+1, meld_tiles, tile, comb)
                
                if player_seat == 0:
                    table.players[player_seat].discard_tile(DiscardOption(table.players[player_seat], tile // 4, 0, [], 0))
                else:
                    # ADD: we must take care of ALL players
                    tile_to_discard = tile
            
                    is_tsumogiri = tile_to_discard == table.players[player_seat].last_draw
                    # it is important to use table method,
                    # to recalculate revealed tiles and etc.
                    table.add_discarded_tile(player_seat, tile_to_discard, is_tsumogiri)
                    
                    table.players[player_seat].tiles.remove(tile_to_discard)
            
                # This part is to extract the features we need.    
#                player_seat, features = self.execute_extraction(tag, table)
#                raw_records.append((player_seat, features))

            if '<N who=' in tag:
                meld = self.decoder.parse_meld(tag)
                #player_seat = self._normalize_position(self.player_position, meld.who)
                # Again, we change the player_seat here
                player_seat = (meld.who + self.player_position) % 4
                table.add_called_meld(player_seat, meld)

                #if player_seat == 0:
                # CHG: we need to handle ALL players here    
                if True:
                    # we had to delete called tile from hand
                    # to have correct tiles count in the hand
                    if meld.type != Meld.KAN and meld.type != Meld.CHANKAN:
                        table.players[player_seat].draw_tile(meld.called_tile)

            # [Joseph]: For old records, there is no 'step' in the tag. We need to take care of it
            if ('<REACH' in tag and 'step="1"' in tag) or ('<REACH' in tag and 'step' not in tag):
                
                # [Joseph] I don't know why he used _normalize_position here, from which I got incorrect
                # positions. Therefore I simply use the parsed result instead.
#                who_called_riichi = self._normalize_position(self.player_position,
#                                                             self.decoder.parse_who_called_riichi(tag))
                who_called_riichi = self.decoder.parse_who_called_riichi(tag)
                
                table.add_called_riichi(who_called_riichi)
                
#                print("\n")
#                print("{} Riichi!".format(who_called_riichi))
#                print("self.player_position: {}".format(self.player_position))
#                print("parse: {}".format(self.decoder.parse_who_called_riichi(tag)))
#                print("\n")
                # We need to delete those unnecessary records for one player mahjong
#                for (player_seat, features) in raw_records:
#                    if player_seat==(who_called_riichi+1):
#                        clean_records.append((player_seat,features))
                
                skip = True # skip all remaining tags untill <INIT>
        
        # Write the records to log        
#        for player_seat, features in clean_records:
#            self.feature_to_logger(features, player_seat)
              
            # This part is to extract the features that will be used to train
            # our model.
#            try:
#                next_tag = self.round_content[n+1]
#            except IndexError:
#                next_tag = ""
#            if '<AGARI' in next_tag:           
#                who_regex = re.compile("who=\"\d+\"")
#                fromWho_regex = re.compile("fromWho=\"\d+\"")           
#                sc_regex = "sc=\"[+-]?\d+,[+-]?\d+,[+-]?\d+,[+-]?\d+,[+-]?\d+,[+-]?\d+,[+-]?\d+,[+-]?\d+\""
#                score_regex = re.compile(sc_regex)
#                machi_regex = re.compile("machi=\"\d+\"")
#                
#                who = int(who_regex.search(next_tag).group(0).replace('"','').split("=")[1])
#                fromWho = int(fromWho_regex.search(next_tag).group(0).replace('"','').split("=")[1])
#                scores = [float(s) for s in score_regex.search(next_tag).group(0).replace('"','').split("=")[1].split(",")]              
#                machi = int(machi_regex.search(next_tag).group(0).replace('"','').split("=")[1])
#                score = scores[fromWho*2+1] 
#                player_seat, features = self.execute_extraction(tag, table)
#                
#                if (who!=fromWho): # tsumo is not a valid sample for our training.                    
#                    if (features is not None) and (player_seat is not None) and (score<0):
#                        # The first element before ";" is table_info, therefor player_info starts
#                        # from index 1, and we put who+1 here.
#                        self.feature_to_logger(features, who+1, machi//4, score)
#                        score = 1
#                    #print("\n{}\n{}\n".format(tag,table.players[who].tiles))
#            else:
#                player_seat, features = self.execute_extraction(tag, table)
         

                   
        if not dry_run:
            tile = self.decoder.parse_tile(self.stop_tag)
            print('Hand: {}'.format(table.player.format_hand_for_print(tile)))

            # to rebuild all caches
            table.player.draw_tile(tile)
            tile = table.player.discard_tile()

            # real run, you can stop debugger here
            table.player.draw_tile(tile)
            tile = table.player.discard_tile()

            print('Discard: {}'.format(TilesConverter.to_one_line_string([tile])))
            
    def feature_to_logger(self, features, player_seat, meld_tiles, tile, comb):
        """
        Write one sample to ``logger2`` as a single ";"-separated record.

        param features: string made of six ";"-separated sections, each a
            Python literal; section 0 holds the table fields and
            section ``player_seat`` holds the fields of the player to log
        param player_seat: index of the player's section inside ``features``
        param meld_tiles: logged as the first field of the record
        param tile: logged unchanged, kept for checking purposes
        param comb: logged unchanged, kept for checking purposes
        """
        sections = features.split(";")
        assert len(sections)==6, "<D> Features format incorrect!"
        table_section = sections[0]
        player_section = sections[player_seat]

        # Unpacking every field (even the unused ones) doubles as a check
        # that each section carries exactly the expected number of values.
        (_honba_sticks,
         _remaining_tiles,
         _riichi_sticks,
         _round_number,
         _round_wind,
         turns,
         _table_dealer_seat,
         _dora_indicators,
         _dora_tiles,
         revealed_tiles) = ast.literal_eval(table_section)

        (_winning_tiles,
         _discarded_tiles,
         closed_hand,
         _player_dealer_seat,
         _in_riichi,
         _is_dealer,
         _is_open_hand,
         _last_draw,
         melds,
         _name,
         _position,
         _rank,
         _scores,
         _seat,
         _uma) = ast.literal_eval(player_section)

        # tile and comb might not be used downstream, but they are kept in
        # the record for checking purposes.
        logger2.info(";".join(str(field) for field in (
            meld_tiles,
            tile,
            comb,
            closed_hand,
            melds,
            revealed_tiles,
            turns,
        )))
        
    def execute_extraction(self, table):
        """Return the stealing features that ``self.extract_features`` produces for ``table``."""
        return self.extract_features.get_stealing_features(table)
     
    
#    def execute_extraction(self, tag, score, table, to_logger):
#        if '<D' in tag:
#            #features = self.extract_features.get_is_waiting_features(table)
#            #features = self.extract_features.get_waiting_tiles_features(table)
#            features = self.extract_features.get_scores_features(score, table)
#            if (features is not None) and to_logger:
#                features_list = features.split(";")
#                assert len(features_list)==6, "<D> Features format incorrect!"
#                score_info = features_list[0]
#                player_info = features_list[1]
#                logger2.info(score_info + ";" + player_info)
#                
#        if '<E' in tag:
#            #features = self.extract_features.get_is_waiting_features(table)
#            #features = self.extract_features.get_waiting_tiles_features(table)
#            features = self.extract_features.get_scores_features(score, table)
#            if (features is not None) and to_logger:
#                features_list = features.split(";")
#                assert len(features_list)==6, "<E> Features format incorrect!"
#                score_info = features_list[0]
#                player_info = features_list[2]
#                logger2.info(score_info + ";" + player_info)
#                
#        if '<F' in tag:
#            #features = self.extract_features.get_is_waiting_features(table)
#            #features = self.extract_features.get_waiting_tiles_features(table)
#            features = self.extract_features.get_scores_features(score, table)
#            if (features is not None) and to_logger:
#                features_list = features.split(";")
#                assert len(features_list)==6, "<F> Features format incorrect!"
#                score_info = features_list[0]
#                player_info = features_list[3]
#                logger2.info(score_info + ";" + player_info)
#               
#        if '<G' in tag:
#            #features = self.extract_features.get_is_waiting_features(table)
#            #features = self.extract_features.get_waiting_tiles_features(table)
#            features = self.extract_features.get_scores_features(score, table)
#            if (features is not None) and to_logger:
#                features_list = features.split(";")
#                assert len(features_list)==6, "<G> Features format incorrect!"
#                score_info = features_list[0]
#                player_info = features_list[4]
#                logger2.info(score_info + ";" + player_info)
                
       
            
    def reproduce_all(self, dry_run=False):
        """
        Replay every round in ``self.rounds`` in order.

        Each round is stored in ``self.round_content`` before
        ``self.reproduce`` is called with the given ``dry_run`` flag, and a
        separator line is printed after every round.
        """
        separator = "--------------------------------------\n"
        for round_tags in self.rounds:
            self.round_content = round_tags
            self.reproduce(dry_run=dry_run)
            print(separator)

    def _normalize_position(self, who, from_who):
        positions = [0, 1, 2, 3]
        return positions[who - from_who]

    def _parse_url(self, log_url):
        temp = log_url.split('?')[1].split('&')
        log_id, player, round_number = '', 0, 0
        for item in temp:
            item = item.split('=')
            if 'log' == item[0]:
                log_id = item[1]
            if 'tw' == item[0]:
                player = int(item[1])
            if 'ts' == item[0]:
                round_number = int(item[1])
        return log_id, player, round_number

    def _download_log_content(self, log_id):
        """
        Check the log file, and if it is not there download it from tenhou.net

        Logs are cached in a ``logs`` folder next to this source file, one
        file per log id.

        :param log_id: id of the log; used as the cache file name and as the
            query string of the download URL
        :return: the log content as text
        :raises requests.HTTPError: when the server answers with an error
            status; nothing is written to the cache in that case
        """
        temp_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'logs')
        # exist_ok avoids the check-then-create race of exists() + mkdir().
        os.makedirs(temp_folder, exist_ok=True)

        log_file = os.path.join(temp_folder, log_id)
        if os.path.exists(log_file):
            with open(log_file, 'r') as f:
                return f.read()

        url = 'http://e.mjv.jp/0/log/?{0}'.format(log_id)
        response = requests.get(url)
        # Fail loudly on HTTP errors: otherwise the error page would be
        # cached under this log id and returned as a "log" on every later call.
        response.raise_for_status()
        content = response.text

        with open(log_file, 'w') as f:
            f.write(content)

        return content

    def _parse_rounds(self, log_content):
        """
        Build list of round tags
        :param log_content:
        :return:
        """
        rounds = []

        game_round = []
        tag_start = 0
        tag = None
        for x in range(0, len(log_content)):
            if log_content[x] == '>':
                tag = log_content[tag_start:x + 1]
                tag_start = x + 1

            # not useful tags
            if tag and ('mjloggm' in tag or 'TAIKYOKU' in tag):
                tag = None

            # new round was started
            if tag and 'INIT' in tag:
                rounds.append(game_round)
                game_round = []

            # the end of the game
            if tag and 'owari' in tag:
                rounds.append(game_round)

            if tag:
                # to save some memory we can remove not needed information from logs
                if 'INIT' in tag:
                    # we dont need seed information
                    find = re.compile(r'shuffle="[^"]*"')
                    tag = find.sub('', tag)

                # add processed tag to the round
                game_round.append(tag)
                tag = None

        return rounds[1:]
Beispiel #8
0
from kmeans_classifier import KMeansClassifier
from data_preparation import DataPreparation
from extract_features import ExtractFeatures
import pandas as pd
import statistics as s
import math as m

# Optional step: cut the images. It is currently disabled by keeping the code
# inside a string literal. Before enabling it, create the folders
# "images_to_process" and "cropped_images".
'''
data_prep = DataPreparation("images_to_process/*","cropped_images/",250)
data_prep.cut_images()
'''

# Extract features, with normalization switched on.
feature_extractor = ExtractFeatures(normalize=True)
feature_extractor.extract()
'''
#Read extracted features
df = pd.read_csv('normalized_data')
usecols = ["fft", "blobs", "corners", "b", "g", "r"]

#Cluster
classifier = KMeansClassifier(clusters_count=5,
                              dimensions_count=5,
                              data=df,
                              columns=usecols,
                              epochs=10)
classifier.fit()


Beispiel #9
0
class AeiouadoG2P(object):
    """Grapheme-to-phoneme transcriber driven by two sets of trained models.

    The "simple" models transcribe from the word alone; the "delaf" models
    additionally receive the grammatical information of a Delaf entry (part
    of speech, gender, tense, person, number). Transcriptions are produced
    in Aeiouado's phone convention and can be converted to IPA, X-SAMPA or
    an HTK-style phone string.
    """

    def __init__(self, g2p_dir, nlpnet_model_dir='', dic_file=''):
        """Load the models and, optionally, the exception dictionary.

        :param g2p_dir: directory holding the ``*.pkl`` classifier,
            vectorizer and label-encoder files
        :param nlpnet_model_dir: passed to ``ExtractFeatures``
        :param dic_file: optional UTF-8 file with one ``entry;transcription``
            pair per line; ``''`` disables the dictionary
        """
        self.clf_delaf = joblib.load('%s/g2p_clf_delaf.pkl' % g2p_dir)
        self.vectorizer_delaf = joblib.load('%s/vectorizer_delaf.pkl' % g2p_dir)
        self.lab_encoder_delaf = joblib.load('%s/lab_encoder_delaf.pkl' % g2p_dir)

        self.clf_simple = joblib.load('%s/g2p_clf_simple.pkl' % g2p_dir)
        self.vectorizer_simple = joblib.load('%s/vectorizer_simple.pkl' % g2p_dir)
        self.lab_encoder_simple = joblib.load('%s/lab_encoder_simple.pkl' % g2p_dir)

        self.feat_extractor = ExtractFeatures(nlpnet_model_dir)

        # Load and process the exception dictionary
        # Transcriptions stored in this dictionary are assumed to be right
        # They are returned instead of being transcribed online
        self.dic = {}
        if dic_file != '':
            # Binary mode plus an explicit decode gives the same text on
            # Python 2 and 3 (a text-mode line has no .decode() on Python 3).
            with open(dic_file, 'rb') as dic:
                for raw_line in dic:
                    line = raw_line.decode('utf-8').strip()
                    if not line:
                        # Tolerate blank lines (e.g. a trailing newline).
                        continue

                    # Separate the entry and the transcription at the LAST
                    # ';' so that a ';' inside the entry does not break it.
                    entry, trans = line.rsplit(';', 1)

                    entry = entry.strip()
                    trans = trans.strip()

                    # Create a DelafEntry object, in order to be able to retrieve the word and the gramm info
                    delaf_entry = DelafEntry(entry)
                    word = delaf_entry.getWord()
                    pos = delaf_entry.getPos()
                    # NOTE(review): a later line for the same word replaces
                    # the earlier one, so only one part of speech is kept
                    # per word -- confirm this is intended.
                    self.dic[word] = {}
                    self.dic[word][pos] = trans

    def _predict_phones(self, feature_seq, vectorizer, classifier, label_encoder):
        """Classify each feature set in ``feature_seq`` and join the predicted phones."""
        phones = []
        for ch_feats in feature_seq:
            num_feats = vectorizer.transform(ch_feats)
            predicted_class_num = classifier.predict(num_feats)[0]
            # inverse_transform works on 1-d array-likes; recent scikit-learn
            # releases reject a bare scalar, so wrap and unwrap the label.
            phones.append(label_encoder.inverse_transform([predicted_class_num])[0])
        return ''.join(phones)

    def _conv_simple(self, word):
        '''
        Return the transcription of a word (simple format) in Aeiouado's phone convention.
        '''
        word = word.strip()
        return self._predict_phones(
            self.feat_extractor.getWordFeaturesSimple(word),
            self.vectorizer_simple,
            self.clf_simple,
            self.lab_encoder_simple)

    def _conv_delaf(self, word, pos, gender, tense, person, number):
        '''
        Return the transcription of a word (Delaf format) in Aeiouado's phone convention.
        '''
        word = word.strip()
        # Delaf features must go through the delaf models, not the simple ones.
        return self._predict_phones(
            self.feat_extractor.getWordFeaturesDelafEntry(word, pos, gender, tense, person, number),
            self.vectorizer_delaf,
            self.clf_delaf,
            self.lab_encoder_delaf)

    def _conv2ipa(self, trans):
        """Convert a transcription from Aeiouado's phone convention to IPA.

        The replacements are applied in sequence, so their order matters.
        """
        trans = trans.replace('P', 'p')
        trans = trans.replace('B', 'b')
        trans = trans.replace('T', 't')
        trans = trans.replace('D', 'd')
        trans = trans.replace('K', 'k')
        trans = trans.replace('G', 'g')
        trans = trans.replace('7', 'tʃ')
        trans = trans.replace('8', 'dʒ')
        trans = trans.replace('M', 'm')
        trans = trans.replace('N', 'n')
        trans = trans.replace('Ñ', 'ɲ')
        trans = trans.replace('F', 'f')
        trans = trans.replace('V', 'v')
        trans = trans.replace('S', 's')
        trans = trans.replace('Z', 'z')
        trans = trans.replace('X', 'ʃ')
        trans = trans.replace('J', 'ʒ')
        trans = trans.replace('L', 'l')
        trans = trans.replace('Ĺ', 'ʎ')
        trans = trans.replace('Ŕ', 'ɾ')
        trans = trans.replace('H', 'x')
        trans = trans.replace('R', 'ɣ')
        trans = trans.replace('-Q', 'k-s')
        trans = trans.replace('Q', 'ks')
        trans = trans.replace('W', 'w')
        trans = trans.replace('Y', 'y')
        trans = trans.replace('Ŵ', 'ʊ̃')
        trans = trans.replace('Ỹ', 'ỹ')
        trans = trans.replace('Á', 'a')
        trans = trans.replace('6', 'ə')
        trans = trans.replace('É', 'ɛ')
        trans = trans.replace('Ê', 'e')
        trans = trans.replace('Í', 'i')
        trans = trans.replace('I', 'ɪ')
        trans = trans.replace('Ó', 'ɔ')
        trans = trans.replace('Ô', 'o')
        trans = trans.replace('Ú', 'u')
        trans = trans.replace('U', 'ʊ')
        trans = trans.replace('Ã', 'ã')
        trans = trans.replace('Ẽ', 'ẽ')
        trans = trans.replace('Ĩ', 'ĩ')
        trans = trans.replace('Õ', 'õ')
        trans = trans.replace('Ũ', 'ũ')
        trans = trans.replace('@', "'")
        trans = trans.replace("-'", "'")
        return trans

    def _ipa2sampa(self, trans):
        """Convert an IPA transcription to space-separated X-SAMPA symbols.

        The substitutions are applied in sequence, so their order matters.
        Runs of spaces are collapsed and the result is stripped.
        """
        trans = re.sub('p', ' p ', trans)
        trans = re.sub('b', ' b ', trans)
        trans = re.sub(r't([^ʃ])', r' t \1 ', trans)
        trans = re.sub(r'd([^ʒ])', r' d \1 ', trans)
        trans = re.sub('k', ' k ', trans)
        trans = re.sub('g', ' g ', trans)
        trans = re.sub('tʃ', ' tS ', trans)
        trans = re.sub('dʒ', ' dZ ', trans)
        trans = re.sub('ʃ', ' S ', trans)
        trans = re.sub('ʒ', ' Z ', trans)
        trans = re.sub('m', ' m ', trans)
        trans = re.sub('n', ' n ', trans)
        trans = re.sub('f', ' f ', trans)
        trans = re.sub('v', ' v ', trans)
        trans = re.sub('s', ' s ', trans)
        trans = re.sub('z', ' z ', trans)
        trans = re.sub(r'([^t])ʃ', r' \1 S ', trans)
        trans = re.sub(r'([^d])ʒ', r' \1 Z ', trans)
        trans = re.sub('Ñ', ' J ', trans)
        trans = re.sub('ɲ', ' J ', trans)
        trans = re.sub('L', ' l ', trans)
        trans = re.sub('ʎ', ' L ', trans)
        trans = re.sub('ɾ', ' 4 ', trans)
        trans = re.sub('x', ' x ', trans)
        trans = re.sub('ɣ', ' G ', trans)
        trans = re.sub('ks', ' k s ', trans)
        trans = re.sub('w', ' w ', trans)
        trans = re.sub('y', ' j ', trans)
        trans = re.sub('ʊ̃', ' w~ ', trans)
        trans = re.sub('ỹ', ' j~ ', trans)
        trans = re.sub('a', ' a ', trans)
        trans = re.sub('@', '', trans)
        trans = re.sub('ə', ' @ ', trans)
        trans = re.sub('ɛ', ' E ', trans)
        trans = re.sub('e', ' e ', trans)
        trans = re.sub('i', ' i ', trans)
        trans = re.sub('ɪ', ' I ', trans)
        trans = re.sub('ɔ', ' O ', trans)
        trans = re.sub('o', ' o ', trans)
        trans = re.sub('u', ' u ', trans)
        trans = re.sub('ʊ', ' U ', trans)
        # Raw string: '\-' in a plain literal is an invalid escape sequence.
        trans = re.sub(r'\-', ' . ', trans)
        trans = re.sub("'", ' " ', trans)

        trans = re.sub(r'ã', r' a~ ', trans)
        trans = re.sub(r'ẽ', r' e~ ', trans)
        trans = re.sub(r'ĩ', r' i~ ', trans)
        trans = re.sub(r'õ', r' o~ ', trans)
        trans = re.sub(r'ũ', r' u~ ', trans)

        trans = re.sub(r' ~', r'~', trans)

        trans = re.sub(r'[ ]{2,}', r' ', trans.strip())
        return trans

    def _sampa2htk(self, trans):
        """Drop the ``.`` and ``"`` marks from a spaced X-SAMPA string and collapse spaces."""
        trans = trans.replace('.', '')
        trans = trans.replace('"', '')

        trans = re.sub(r'[ ]{2,}', r' ', trans.strip())
        return trans

    def _lookup_dic(self, word):
        """Return the dictionary transcription of ``word`` if it has exactly one entry, else None."""
        entries = self.dic.get(word)
        if entries is not None and len(entries) == 1:
            return next(iter(entries.values()))
        return None

    def _format_transcription(self, trans, c, space):
        """Convert an Aeiouado-convention transcription to the output format ``c``."""
        if c == 'aeiouado':
            return trans
        ipa = self._conv2ipa(trans)
        if c == 'xsampa':
            spaced_xsampa = self._ipa2sampa(ipa)
            if space:
                return spaced_xsampa
            return ''.join(spaced_xsampa.split())
        if c == 'htk':
            return self._sampa2htk(self._ipa2sampa(ipa))
        # 'ipa' and any unknown format name both yield IPA.
        return ipa

    def transcribe_word_simple(self, word, c = 'ipa', space = False, dic = True):
        """Transcribe ``word`` using the simple models.

        :param word: word to transcribe; it is lower-cased first
        :param c: output format: 'ipa' (default), 'xsampa', 'htk' or
            'aeiouado'; any other value yields IPA
        :param space: for 'xsampa', keep the symbols separated by spaces
        :param dic: when true, a word with a single entry in the exception
            dictionary uses that transcription instead of the models
        :return: the transcription as a string
        """
        word = word.lower()

        trans = self._lookup_dic(word) if dic else None
        if trans is None:
            # Not in the dictionary (or ambiguous there): use the models.
            trans = self._conv_simple(word)

        return self._format_transcription(trans, c, space)

    def transcribe_word_delaf(self, word, pos, gender, tense, person, number, c = 'ipa', space = False, dic = True):
        """Transcribe ``word`` using the delaf models and its grammatical information.

        :param word: word to transcribe; it is lower-cased first
        :param pos: part of speech, forwarded to the feature extractor
        :param gender: forwarded to the feature extractor
        :param tense: forwarded to the feature extractor
        :param person: forwarded to the feature extractor
        :param number: forwarded to the feature extractor
        :param c: output format: 'ipa' (default), 'xsampa', 'htk' or
            'aeiouado'; any other value yields IPA
        :param space: for 'xsampa', keep the symbols separated by spaces
        :param dic: when true, a word with a single entry in the exception
            dictionary uses that transcription instead of the models
        :return: the transcription as a string
        """
        word = word.lower()

        # Check if word is in dictionary; a dictionary hit goes through the
        # same format conversion as a model prediction.
        trans = self._lookup_dic(word) if dic else None
        if trans is None:
            trans = self._conv_delaf(word, pos, gender, tense, person, number)

        return self._format_transcription(trans, c, space)
Beispiel #10
0
def Db_test(label_file, **kwargs):
    """Match query face images against an ID database and report the results.

    Two feature vectors (one per face model) are extracted for every image
    listed in ``db_file``, merged with ``featuresProcess`` and added to an
    ``Annoy_DB`` index. Every image listed in ``data_file`` is then looked
    up in that index. A match is accepted when the top-1 distance and the
    top-1/top-2 gap satisfy the thresholds in ``config``; accepted query
    images are copied into a per-identity folder below ``saved_dir``.
    Distances are written to ``./output/distance_top1.txt`` and
    ``./output/distance_top2.txt``, and a summary is printed at the end.

    If neither ``config.caffe_use`` nor ``config.mx_`` is set, a hint is
    printed and the function returns without processing anything.

    Args:
        label_file: passed to ``get_label`` to obtain the label dictionary.

    Keyword Args:
        db_file: text file listing one database image per line.
        data_file: text file listing one query image per line.
        base_dir: directory the query image paths are joined to.
        id_dir: directory the database image paths are joined to.
        saved_dir: directory that receives the recognised images.
        base_id: value added to every label taken from the label dictionary.
        failed_dir: directory that receives the failed images when
            ``config.save_failedimg`` is set.
    """
    #db_file,data_file,base_dir,id_dir,dis_dir,base_id
    db_file = kwargs.get('db_file', None)
    data_file = kwargs.get('data_file', None)
    base_dir = kwargs.get('base_dir', None)
    id_dir = kwargs.get('id_dir', None)
    dis_dir = kwargs.get('saved_dir', None)
    base_id = kwargs.get('base_id', None)
    failed_dir = kwargs.get('failed_dir', None)
    # The list files are only needed for their lines, so read and close
    # them right away.
    with open(db_file, 'r') as db_in:
        db_lines = db_in.readlines()
    with open(data_file, 'r') as data_in:
        data_lines = data_in.readlines()

    tpr = 0
    fpr = 0
    idx_ = 0
    db_names = []
    save_reg_dirs = []
    db_cnt_dict = dict()
    db_paths_dict = dict()
    dest_id_dict = dict()
    org_id_dict = dict()
    org_wrong_dict = dict()
    test_label_dict = dict()
    reg_frame_dict = dict()

    # The context manager closes both distance logs even when model loading
    # or feature extraction raises.
    with open("./output/distance_top1.txt", 'w') as dis_record, \
            open("./output/distance_top2.txt", 'w') as dis_top2:
        if config.caffe_use:
            face_p1 = "../models/sphere/sph_2.prototxt"
            face_m1 = "../models/sphere/sph20_ms_4v5.caffemodel"
            face_p2 = "../models/mx_models/mobile_model/face-mobile.prototxt"
            face_m2 = "../models/mx_models/mobile_model/face-mobile.caffemodel"
            if config.feature_1024:
                out_layer = 'fc5_n'
            elif config.insight:
                out_layer = 'fc1'
            else:
                out_layer = 'fc5'
            FaceModel_1 = FaceReg(face_p1, face_m1, 'fc5')
            FaceModel_2 = FaceReg(face_p2, face_m2, out_layer)
        elif config.mx_:
            model_path1 = "../models/mx_models/model-100/model-org/modelresave"
            model_path2 = "../models/mx_models/model-100/model-v3/modelresave"
            epoch1_num = 0  #9#2
            epoch2_num = 0
            img_size = [112, 112]
            FaceModel_1 = mx_Face(model_path1, epoch1_num, img_size)
            FaceModel_2 = mx_Face(model_path2, epoch2_num, img_size)
        else:
            print("please select a frame: caffe mxnet tensorflow")
            # No model is available; stop here instead of failing below
            # with a NameError on FaceModel_1.
            return
        #mkdir save image dir
        Face1_features = ExtractFeatures(FaceModel_1)
        Face2_features = ExtractFeatures(FaceModel_2)
        make_dirs(dis_dir)
        make_dirs(failed_dir)
        # NOTE(review): the result is unused; the call is kept in case
        # args() has side effects.
        parm = args()
        DB_FT = Annoy_DB(config.feature_lenth)
        for line_one in db_lines:
            line_one = line_one.strip()
            base_feat1 = Face1_features.extract_f(line_one, id_dir)
            base_feat2 = Face2_features.extract_f(line_one, id_dir)
            if base_feat1 is None or base_feat2 is None:
                print("feature is None")
                continue
            base_feat = featuresProcess([base_feat1, base_feat2])
            one_name = line_one[:-4]
            # The id stored in the index must be the position of this entry
            # in db_names / save_reg_dirs, because search results are used
            # to index those lists. Using the raw line number would drift
            # after every skipped image. (presumably findNeatest returns the
            # ids given to add_data -- verify against Annoy_DB)
            db_idx = len(db_names)
            db_names.append(one_name)
            face_db_dir = os.path.join(dis_dir, one_name)
            save_reg_dirs.append(face_db_dir)
            DB_FT.add_data(db_idx, base_feat)
            db_path = os.path.join(id_dir, line_one)
            db_paths_dict[one_name] = db_path
        print("begin do build db")
        DB_FT.build_db()
        print("db over")
        label_dict = get_label(label_file)
        for querry_line in data_lines:
            querry_line = querry_line.strip()
            if config.use_framenum:
                frame_num = get_frame_num(querry_line)
            if config.reg_fromlabel:
                que_spl = querry_line.split("/")
                if config.dir_label:
                    key_que = que_spl[0]
                else:
                    key_que = que_spl[0][2:]
                test_label_dict[key_que] = 1
                # 300 is the label used for names missing from label_dict.
                real_label = label_dict.setdefault(key_que, 300) + base_id
            idx_ += 1
            sys.stdout.write('\r>> deal with %d/%d' % (idx_, len(data_lines)))
            sys.stdout.flush()
            if config.face_detect:
                querry_feat1 = Face1_features.extract_f3(querry_line, base_dir)
                querry_feat2 = Face2_features.extract_f3(querry_line, base_dir)
            else:
                querry_feat1 = Face1_features.extract_f2(querry_line, base_dir)
                querry_feat2 = Face2_features.extract_f2(querry_line, base_dir)
            if querry_feat1 is None or querry_feat2 is None:
                print("feature is None")
                continue

            querry_feat = featuresProcess([querry_feat1, querry_feat2])
            idx_list, distance_list = DB_FT.findNeatest(querry_feat, 2)
            idx = idx_list[0]
            img_name = db_names[idx]
            img_dir = save_reg_dirs[idx]
            pred_label = label_dict.setdefault(img_name, 300) + base_id
            img_cnt = db_cnt_dict.setdefault(img_name, 0)
            org_path = os.path.join(base_dir, querry_line)
            Threshold_Value = config.confidence
            # Accept only a close top-1 match that is clearly separated
            # from the runner-up.
            reg_condition = 1 if (distance_list[0] <= config.top1_distance and
                                  (distance_list[1] -
                                   distance_list[0]) >= Threshold_Value) else 0
            if config.use_framenum:
                if reg_condition:
                    # The first accepted frame of an identity becomes its
                    # reference; later frames must stay within the interval.
                    reg_frame_cnt = reg_frame_dict.setdefault(
                        img_name, frame_num)
                    frame_interval = np.abs(frame_num - reg_frame_cnt)
                    save_condition = 1 if frame_interval < config.frame_interval else 0
                else:
                    save_condition = 0
            else:
                save_condition = reg_condition
            print("distance ", distance_list)
            dis_record.write("%.3f\n" % (distance_list[0]))
            dis_top2.write("%.3f\n" % (distance_list[1] - distance_list[0]))
            if save_condition:
                print("distance: ", distance_list[0])
                make_dirs(img_dir)
                if config.reg_fromlabel:
                    dist_path = os.path.join(img_dir, que_spl[1])
                else:
                    dist_path = os.path.join(img_dir, querry_line)
                shutil.copyfile(org_path, dist_path)
                db_cnt_dict[img_name] = img_cnt + 1
                # On the first hit of an identity, also copy its ID image.
                if db_cnt_dict[img_name] == 1 and config.save_idimg:
                    org_id_path = db_paths_dict[img_name]
                    dist_id_path = os.path.join(img_dir, img_name + ".jpg")
                    shutil.copyfile(org_id_path, dist_id_path)
                if config.reg_fromlabel:
                    print("real and pred ", real_label, pred_label,
                          querry_line)
                    if int(pred_label) == int(real_label):
                        tpr += 1
                        print("*************")
                    else:
                        fpr += 1
                        if config.save_failedimg:
                            failed_img_dir = os.path.join(failed_dir, key_que)
                            make_dirs(failed_img_dir)
                            failed_img_path = os.path.join(
                                failed_img_dir, que_spl[1])
                            shutil.copyfile(org_path, failed_img_path)
                        cnt_org = org_id_dict.setdefault(key_que, 0)
                        org_id_dict[key_que] = cnt_org + 1
                        cnt_dest = dest_id_dict.setdefault(img_name, 0)
                        dest_id_dict[img_name] = cnt_dest + 1
                        ori_dest_id = org_wrong_dict.setdefault(key_que, [])
                        if img_name not in ori_dest_id:
                            ori_dest_id.append(img_name)
                else:
                    print("pred label: ", pred_label, querry_line)
                    tpr += 1
                    print("*************")
            elif config.save_failedimg and not config.reg_fromlabel:
                failed_img_dir = os.path.join(failed_dir, img_name)
                make_dirs(failed_img_dir)
                failed_img_path = os.path.join(failed_img_dir, querry_line)
                shutil.copyfile(org_path, failed_img_path)

    right_id = 0
    not_reg_id = []
    for key_name in db_cnt_dict:
        if db_cnt_dict[key_name] != 0:
            right_id += 1
        elif config.reg_fromlabel:
            if key_name in test_label_dict:
                not_reg_id.append(key_name)
        else:
            not_reg_id.append(key_name)
    print("not reg keys: ", not_reg_id)
    print("right id ", right_id)
    if config.reg_fromlabel:
        org_id = 0
        for key_name in org_id_dict:
            if org_id_dict[key_name] != 0:
                org_id += 1
        dest_id = 0
        for key_name in dest_id_dict:
            if dest_id_dict[key_name] != 0:
                dest_id += 1
        with open("./output/result_record.txt", 'w') as f_result:
            for key_name in org_wrong_dict:
                f_result.write("{} : {}".format(key_name,
                                                org_wrong_dict[key_name]))
                f_result.write("\n")
        print("TPR and FPR is ", tpr, fpr)
        print("wrong orignal: ", org_id_dict.values())
        print("who will be wrong: ", dest_id_dict.values())
        print("the org and dest: ", org_id, dest_id)
    else:
        print("Reg img num: ", tpr)