コード例 #1
0
ファイル: extract_keywords.py プロジェクト: XL2248/GRADE
def _obtain_candidate_keywords(list_all_dialogs, candi_kw_path, min_kw_freq=1, load_file_if_exists=True):
    r"""Obtain and save the candidate keywords used for extracting keywords.

    Inputs: list_all_dialogs, candi_kw_path, load_file_if_exists
        # TODO
        - **list_all_dialogs**:
        - **candi_kw_path**:
        - **load_file_if_exists**:

    Outputs: candi_keywords
        - **candi_keywords**:  a 'list' containing all the candidate keywords
    """
    if load_file_if_exists:
        if os.path.isfile(candi_kw_path):
            with open(candi_kw_path,'r') as f:
                candi_keywords = [kw.strip() for kw in f.readlines()]
            print('Loading candidate keywords from {}'.format(candi_kw_path))
            print('Total candidate keywords count: ', len(candi_keywords))
            return candi_keywords

    if not list_all_dialogs:
        raise Exception('no dialogs provided for obtaining candidate keywords')

    candi_kw_dir = os.path.dirname(candi_kw_path)
    if not os.path.exists(candi_kw_dir):
        os.makedirs(candi_kw_dir)

    print('Obtaining candidate keywords...')

    # initialization
    candi_keywords = []
    kw_counter = collections.Counter()
    kw_extractor = KeywordExtractor()

    # extract possible keywords
    for dialog in tqdm(list_all_dialogs):
        for utterance in dialog:
            cur_keywords = kw_extractor.candi_extract(utterance)
            kw_counter.update(cur_keywords)
            candi_keywords.extend(cur_keywords)

    # delete the keywords occurring less than specified times (indicated by 'min_kw_freq').
    rare_keywords = [kw for kw, freq in kw_counter.most_common() if freq < min_kw_freq]
    candi_keywords = [kw for kw, freq in kw_counter.most_common() if freq >= min_kw_freq]
    # delete keywords containing only one single letter
    single_letter_keywords = [kw for kw in candi_keywords if len(kw) < 2]
    candi_keywords = [kw for kw in candi_keywords if len(kw) >= 2]

    # print the information of candidate keywords
    print('rare keywords count: ', len(rare_keywords))
    print('single letter keywords count: ', len(single_letter_keywords))
    print('total candidate keywords count(before cleaning): ', len(kw_counter.items()))
    print('total candidate keywords count(after cleaning):  ', len(candi_keywords))

    print('Saving candi_keywords into {}...'.format(candi_kw_path))
    with open(candi_kw_path,'w') as f:
        for keyword in candi_keywords:
            f.write(keyword + '\n')

    return candi_keywords
コード例 #2
0
ファイル: nfl_tweets.py プロジェクト: maroy/TSTA
def main():
    """Copy NFL-related tweets from ../idx.db into ../nfl_tweets.db.

    A row of idx.db's `tweets` table is inserted into nfl_tweets.db's
    `nfl_tweets` table (idx_id, created, keywords) when:
      * its rowid is not already stored in `nfl_tweets.idx_id`,
      * its content matches the whole word 'nfl' but neither 'nba' nor 'nhl'
        (case-insensitive), and
      * 'nfl' is contained in the value returned by KeywordExtractor.extract.
    """
    import contextlib

    include_regex = re.compile(r'\bnfl\b', re.IGNORECASE)
    exclude_regex = re.compile(r'\bnba\b|\bnhl\b', re.IGNORECASE)

    extractor = KeywordExtractor()

    # closing() closes both databases even if processing fails part-way;
    # sqlite3 discards uncommitted inserts on close.
    with contextlib.closing(sqlite3.connect('../idx.db')) as idx_db, \
            contextlib.closing(sqlite3.connect('../nfl_tweets.db')) as nfl_db:
        idx_db_cur = idx_db.cursor()
        nfl_db_cur = nfl_db.cursor()

        print("Getting already added ids")
        nfl_db_cur.execute('SELECT idx_id FROM nfl_tweets')
        already_added_ids = set(row[0] for row in nfl_db_cur)
        print("Found {0} ids".format(len(already_added_ids)))

        count = 0

        idx_db_cur.execute("SELECT rowid, utc_timestamp, content FROM tweets")
        for rowid, created, content in idx_db_cur:
            if rowid in already_added_ids:
                continue
            if include_regex.search(content) is None or exclude_regex.search(content) is not None:
                continue

            keywords = extractor.extract(content)
            if 'nfl' not in keywords:
                continue

            count += 1
            nfl_db_cur.execute(
                'INSERT INTO nfl_tweets(idx_id,created,keywords) VALUES(?,?,?)',
                (rowid, created, keywords)
            )

            # Progress indicator: one dot per 10,000 inserted rows.
            if count % 10000 == 0:
                sys.stdout.write('.')
                sys.stdout.flush()

        # Any inserted row counts as "something new".
        if count > 0:
            print('')  # terminate the line of progress dots
            print('added {0} records'.format(count))
            print('committing')
        else:
            print('nothing new found')

        nfl_db.commit()
        print('closing')

        nfl_db_cur.close()
        idx_db_cur.close()
コード例 #3
0
ファイル: dialog_data_processor.py プロジェクト: XL2248/GRADE
    def _obtain_candidate_keywords(self, load_file_if_exists=True):
        r"""Obtains and saves the candidate keywords used for extracting keywords.

        Args:
            load_file_if_exists: A 'bool' indicating whether load candi_keywords file if it exists.

        Returns:
            candi_keywords: A 'list' containing all the candidate keywords.
        """
        if load_file_if_exists:
            candi_keywords_name = '../data/{}/candi_keywords.txt'.format(self.output_data_dir)
            if os.path.isfile(candi_keywords_name):
                with open(candi_keywords_name,'r') as f:
                    candi_keywords = [kw.strip() for kw in f.readlines()]
                print('Loading candidate keywords from {}'.format(candi_keywords_name))
                print('Total candidate keywords count: ', len(candi_keywords))
                return candi_keywords

        print('Obtaining candidate keywords...')

        # Initialization
        candi_keywords = []
        kw_counter = collections.Counter()
        kw_extractor = KeywordExtractor()

        # Extracts possible keywords.
        for dialog in tqdm(self.list_all_dialogs):
            for utterance in dialog:
                cur_keywords = kw_extractor.candi_extract(utterance)
                kw_counter.update(cur_keywords)
                candi_keywords.extend(cur_keywords)

        # Deletes the keywords occurring less than specified times
        rare_keywords = [kw for kw, freq in kw_counter.most_common()
            if freq < self.min_kw_freq]
        candi_keywords = [kw for kw, freq in kw_counter.most_common()
            if freq >= self.min_kw_freq]
        # Deletes keywords containing only one single letter
        single_letter_keywords = [kw for kw in candi_keywords if len(kw) < 2]
        candi_keywords = [kw for kw in candi_keywords if len(kw) >= 2]

        # Writes candidate keywords into file
        candidate_keywords_output_path = '../data/{}/candi_keywords.txt'.format(
            self.output_data_dir)
        with open(candidate_keywords_output_path,'w') as f:
            for keyword in candi_keywords:
                f.write(keyword + '\n')

        return candi_keywords
コード例 #4
0
 def handle(self, *args, **options):
     """Recompute and relink the keywords of every question dated July 2016.

     For each such question, its existing keyword links are cleared, keywords
     are recomputed with KeywordExtractor.get_keywords, printed comma-separated,
     and linked back through Keyword rows created on demand.
     """
     questions = Question.objects.all()
     ke = KeywordExtractor()
     for question in questions:
         # Only questions from July 2016 are reprocessed.
         if question.date.year != 2016 or question.date.month != 7:
             continue
         # Direct assignment to a many-to-many manager (`= []`) is rejected by
         # modern Django; clear() removes the old links explicitly.
         # NOTE(review): assumes `keywords` is a ManyToManyField - confirm.
         question.keywords.clear()
         keywords = ke.get_keywords(question.question)
         print(",".join(keywords))
         for keyword in keywords:
             # get_or_create already persists new rows with this keyword value,
             # so no extra assignment/save of the Keyword is needed.
             m, _created = Keyword.objects.get_or_create(keyword=keyword)
             question.keywords.add(m)
         question.save()
コード例 #5
0
ファイル: main.py プロジェクト: maroy/TSTA
def main():
    """Store extracted keywords for every new tweet of ../idx.db in ../kwd.db.

    Each row of idx.db's `tweets` table whose rowid is not yet present in
    kwd.db's `keywords.idx_id` gets a `keywords` row holding its rowid,
    timestamp and the result of KeywordExtractor.extract on its content.
    """
    import contextlib

    extractor = KeywordExtractor()

    # closing() closes both databases even if processing fails part-way;
    # sqlite3 discards uncommitted inserts on close.
    with contextlib.closing(sqlite3.connect('../idx.db')) as idx_db, \
            contextlib.closing(sqlite3.connect('../kwd.db')) as kwd_db:
        idx_db_cur = idx_db.cursor()
        kwd_db_cur = kwd_db.cursor()

        print("Getting already added ids")
        kwd_db_cur.execute('SELECT idx_id FROM keywords')
        already_added_ids = set(row[0] for row in kwd_db_cur)
        print("Getting found {0} ids".format(len(already_added_ids)))

        count = 0

        idx_db_cur.execute("SELECT rowid, utc_timestamp, content FROM tweets")
        for rowid, created, content in idx_db_cur:
            if rowid in already_added_ids:
                continue
            count += 1
            keywords = extractor.extract(content)

            kwd_db_cur.execute(
                'INSERT INTO keywords(idx_id,created,content) VALUES(?,?,?)',
                (rowid, created, keywords)
            )

            # Progress indicator: one dot per 10,000 inserted rows.
            if count % 10000 == 0:
                sys.stdout.write('.')
                sys.stdout.flush()

        # Any inserted row counts as "something new".
        if count > 0:
            print('')  # terminate the line of progress dots
            print('committing')
        else:
            print('nothing new found')

        kwd_db.commit()
        print('closing')

        kwd_db_cur.close()
        idx_db_cur.close()
コード例 #6
0
 def run(self):
     """
     Upload one article per keyword.

     The keywords come from KeywordExtractor.extract applied to the file at
     config.KEYWORDS_FILE_PATH; each is handed to self._upload_article in order.
     """
     # NOTE(review): extract is called on the class, not an instance; this only
     # works if it is a static/class method - confirm.
     for kw in KeywordExtractor.extract(config.KEYWORDS_FILE_PATH):
         self._upload_article(kw)
コード例 #7
0
def create_tf_idf(file_path, article_index=10, useless_words_path='useless.txt'):
    """Build the TF-IDF matrix of one article from a training text file.

    Pipeline: TrainingTextReader -> KeywordExtractor -> Vectorizer
    (frequencyMatrix) -> VectorNormalizer (l2_norm_matrice) ->
    InverseDocumentFrequency (tf_idf_matrice).

    Args:
        file_path: Path passed to TrainingTextReader.
        article_index: Index into ``reader.articles`` of the article to process
            (defaults to 10, the previously hard-coded value).
        useless_words_path: Second argument passed to KeywordExtractor
            (defaults to 'useless.txt'; presumably a word list to discard -
            confirm against KeywordExtractor).

    Returns:
        The ``tf_idf_matrice`` attribute of the InverseDocumentFrequency object.
    """
    reader = TrainingTextReader(file_path)
    extractor = KeywordExtractor(reader.articles[article_index], useless_words_path)
    freq_mat = Vectorizer(extractor.article_sents_tokened).frequencyMatrix
    norm_mat = VectorNormalizer(freq_mat).l2_norm_matrice
    return InverseDocumentFrequency(norm_mat).tf_idf_matrice
コード例 #8
0
ファイル: igdb.py プロジェクト: photown/videogame-oracle
    def __init__(self):
        # Extractor used for game text (summaries/storylines presumably - confirm).
        self.keyword_extractor = KeywordExtractor()

        # id -> name lookup tables, one per attribute kind; presumably filled
        # by the fetchers created below - confirm in __add_attr_to_game_data.
        self.publisher_id_to_name = {}
        self.platform_id_to_name = {}
        self.theme_id_to_name = {}
        self.genre_id_to_name = {}
        self.game_mode_id_to_name = {}
        self.game_keyword_id_to_name = {}

        # One fetcher per (game-data field, IGDB endpoint, lookup table).
        make_fetcher = self.__add_attr_to_game_data
        self.fetch_publishers = make_fetcher(
            'publishers', 'companies', self.publisher_id_to_name)
        self.fetch_platforms = make_fetcher(
            'platform', 'platforms', self.platform_id_to_name)
        self.fetch_themes = make_fetcher(
            'themes', 'themes', self.theme_id_to_name)
        self.fetch_genres = make_fetcher(
            'genres', 'genres', self.genre_id_to_name)
        self.fetch_game_modes = make_fetcher(
            'game_modes', 'game_modes', self.game_mode_id_to_name)
コード例 #9
0
ファイル: dialog_data_processor.py プロジェクト: XL2248/GRADE
    def __init__(self, dataset_name, output_data_dir,
                 separator, min_kw_freq,
                 context_turns, set_names):
        """Store the configuration and run the pre-processing steps in order.

        Creates the data directories, loads the raw dialogs, builds the keyword
        extractor from the candidate keywords and IDF values, then calls the
        helpers producing the utterance->keywords mapping and the vocabulary.
        """
        # Configuration; tuple targets are assigned left to right.
        (self.dataset_name, self.output_data_dir, self.separator,
         self.min_kw_freq, self.context_turns, self.set_names) = (
            dataset_name, output_data_dir, separator,
            min_kw_freq, context_turns, set_names)

        self._make_data_dir_if_not_exists()
        self._load_raw_dialog_data()

        # Call arguments are evaluated left to right, so the candidate keywords
        # are still obtained before the IDF values are calculated.
        self.kw_extractor = KeywordExtractor(self._obtain_candidate_keywords(),
                                             self._calculate_idf())

        # uttr_kw_mapping: (utterances -> keywords) mapping
        self._obtain_and_save_uttr_kw_mapping()
        self._obtain_and_save_vocab()
コード例 #10
0
class KeywordExtractorTests(unittest.TestCase):
    """Sanity tests for KeywordExtractor.extract on the CSV fixture at CSV_PATH."""

    def setUp(self):
        self.extractor = KeywordExtractor()

    def tearDown(self):
        pass

    def test_extract_sanity(self):
        """extract() returns a list with the same items as the CSV's first column (header skipped)."""
        keyword_list = self.extractor.extract(CSV_PATH)
        # types.ListType only exists on Python 2; assertIsInstance works on 2.7 and 3.
        self.assertIsInstance(keyword_list, list)

        with open(CSV_PATH) as f:
            reader = csv.reader(f)
            next(reader)  # skip the header row (reader.next() is Python 2 only)
            our_list = [row[0] for row in reader]

        # assertItemsEqual was renamed to assertCountEqual in Python 3.
        assert_count_equal = getattr(self, 'assertCountEqual', None) or self.assertItemsEqual
        assert_count_equal(keyword_list, our_list)
コード例 #11
0
ファイル: extract_keywords.py プロジェクト: XL2248/GRADE
    parser.add_argument('--candi_kw_path', type=str, help='path of candidate keywords file')
    parser.add_argument('--input_text_path', type=str, help='path of dialog text that need extracting keywords')
    parser.add_argument('--kw_output_path', type=str, help='path of the output file for the extracted keywords')
    args = parser.parse_args()

    output_info = 'Start keyword extraction [dataset: {}, file: {}]'.format(
        args.dataset_name, args.input_text_path)
    print('-' * len(output_info))
    print(output_info)
    print('-' * len(output_info))

    # Initialize the keyword extractor. First try the cached candidate-keyword
    # and IDF files without dialogs; _obtain_candidate_keywords raises when its
    # file is missing and no dialogs are given (presumably _calculate_idf does
    # likewise), in which case the dataset is loaded and both are recomputed.
    try:
        candi_keywords = _obtain_candidate_keywords(None, args.candi_kw_path)
        idf_dict = _calculate_idf(None, args.idf_path)
        kw_extractor = KeywordExtractor(candi_keywords, idf_dict)
    except Exception as err:
        print('Exception: ', err)
        # load all dialogs of the specific dataset
        dataset = load_dataset(args.dataset_name, args.dataset_dir)
        candi_keywords = _obtain_candidate_keywords(dataset, args.candi_kw_path)
        idf_dict = _calculate_idf(dataset, args.idf_path)
        kw_extractor = KeywordExtractor(candi_keywords, idf_dict)

    # load texts that need extracting keywords
    texts = load_texts(args.input_text_path)
    # extract keywords
    extract_keywords(texts, kw_extractor, args.kw_output_path)
    print('Done.')
コード例 #12
0
ファイル: main.py プロジェクト: yuyichen09/TextRank
from keyword_extractor import KeywordExtractor
import argparse

# Command-line interface: optional word2vec embeddings, required input file.
ap = argparse.ArgumentParser()
ap.add_argument(
    "--word2vec",
    default=None,
    help="path to word2vec pre-trained embeddings",
)
ap.add_argument(
    "--data",
    required=True,
    help="path to file from which keywords are to be extracted",
)
args = ap.parse_args()

# Every line of the input file is processed as a separate text.
with open(args.data, 'r') as data_file:
    lines = data_file.readlines()

extractor = KeywordExtractor(word2vec=args.word2vec)

# Print each item returned by extract (called with scores=True), one per line.
for text in lines:
    for keyword in extractor.extract(text, ratio=0.2, split=True, scores=True):
        print(keyword)
コード例 #13
0
ファイル: igdb.py プロジェクト: photown/videogame-oracle
class IgdbFetcher:
    """Fetches game information from the IGDB API (via Mashape) into game data objects.

    Ids returned by IGDB for publishers, platforms, themes, genres and game
    modes are resolved to names with one endpoint request per unseen id; the
    resolved names are cached in the instance's ``*_id_to_name`` dicts.
    """

    GET_GAME_INFO_TEMPLATE = \
        "https://igdbcom-internet-game-database-v1.p.mashape.com/games/" \
        "?fields=name,summary,storyline,publishers,themes,keywords," \
        "game_modes,genres,first_release_date,release_dates" \
        "&limit=20&offset=0&search=%s"
    ENDPOINT_API_TEMPLATE = \
        "https://igdbcom-internet-game-database-v1.p.mashape.com/" \
        "%s/%s?fields=name"
    # Extracted keywords that are discarded as uninformative.
    STOP_WORDS = {'game', 'player', 'gameplay'}

    def __init__(self):
        self.keyword_extractor = KeywordExtractor()

        # id -> name caches used by the fetchers created below.
        self.publisher_id_to_name = {}
        self.platform_id_to_name = {}
        self.theme_id_to_name = {}
        self.genre_id_to_name = {}
        self.game_mode_id_to_name = {}
        # NOTE(review): not used by any method of this class.
        self.game_keyword_id_to_name = {}

        # Each fetcher is f(source_dict, add_func): it resolves the id(s) stored
        # under its key and passes every resolved name to add_func.
        self.fetch_publishers = self.__add_attr_to_game_data(
            'publishers', 'companies', self.publisher_id_to_name)
        self.fetch_platforms = self.__add_attr_to_game_data(
            'platform', 'platforms', self.platform_id_to_name)
        self.fetch_themes = self.__add_attr_to_game_data(
            'themes', 'themes', self.theme_id_to_name)
        self.fetch_genres = self.__add_attr_to_game_data(
            'genres', 'genres', self.genre_id_to_name)
        self.fetch_game_modes = self.__add_attr_to_game_data(
            'game_modes', 'game_modes', self.game_mode_id_to_name)

    def get_game_info(self, game_data):
        """Fill `game_data` from the IGDB search result whose name matches it.

        Names are compared case-insensitively after stripping whitespace. The
        matching entry must have release dates, publishers, themes or genres,
        game modes, a first release date and a summary or storyline; these are
        all checked before anything is written into `game_data`.

        Returns:
            True if a complete matching entry was found and copied into
            `game_data`, False otherwise.
        """
        # NOTE(review): game_data.name is interpolated without URL-encoding;
        # names containing spaces or '&' may produce a malformed query.
        response = unirest.get(self.GET_GAME_INFO_TEMPLATE % game_data.name,
                               headers={
                                   "X-Mashape-Key": os.environ['IGDB_KEY'],
                                   "Accept": "application/json"
                               }
                               )
        game_info = None
        game_name_lower = game_data.name.lower().strip()
        for response_game in response.body:
            if 'name' not in response_game:
                continue
            if game_name_lower == response_game['name'].lower().strip():
                game_info = response_game
                break

        # Reject incomplete entries before mutating game_data.
        if not game_info:
            return False
        if not self.__validate_field(game_info, 'release_dates'):
            return False
        if not self.__validate_field(game_info, 'publishers'):
            return False
        if not self.__validate_field(game_info, 'themes') and \
                not self.__validate_field(game_info, 'genres'):
            return False
        if not self.__validate_field(game_info, 'game_modes'):
            return False
        if 'first_release_date' not in game_info:
            return False
        if 'summary' not in game_info and 'storyline' not in game_info:
            return False

        for release_date in game_info['release_dates']:
            self.fetch_platforms(release_date, game_data.add_platform)

        # IGDB themes and genres are both recorded as genres.
        if 'themes' in game_info:
            self.fetch_themes(game_info, game_data.add_genre)
        if 'genres' in game_info:
            self.fetch_genres(game_info, game_data.add_genre)

        self.fetch_publishers(game_info, game_data.add_publisher)
        self.fetch_game_modes(game_info, game_data.add_game_mode)

        # first_release_date is divided by 1000, i.e. treated as milliseconds.
        release_date_timestamp = game_info['first_release_date']
        release_date = datetime.datetime.fromtimestamp(
            release_date_timestamp / 1000)
        game_data.release_date = release_date
        # 0-based quarter index (0-3) derived from the day of the year.
        release_day_of_year = release_date.timetuple().tm_yday
        game_data.release_quarter = int(release_day_of_year / (367 / 4.0))

        for text_field in ('summary', 'storyline'):
            if text_field in game_info:
                game_data.add_keywords(self.__extract_keywords(game_info[text_field]))

        print("response body = " + str(response.body))
        return True

    def __validate_field(self, game_info, field_name):
        """Return True if `field_name` is present in `game_info` and non-empty."""
        return field_name in game_info and len(game_info[field_name]) > 0

    def __is_valid_keyword(self, keyword):
        """Truthy if `keyword` is not a stop word and consists only of ASCII letters."""
        return keyword not in self.STOP_WORDS and \
            re.match("^[A-Za-z]+$", keyword)

    def __extract_keywords(self, text):
        """Return the valid keywords extracted from `text`, in extraction order.

        Assumes KeywordExtractor.extract yields 3-tuples whose first item is
        the keyword (the other two are ignored).
        """
        return [keyword
                for keyword, _, _ in self.keyword_extractor.extract(text)
                if self.__is_valid_keyword(keyword)]

    def __add_attr_to_game_data(self, attr_name, endpoint_name, attr_map):
        """Build a fetcher resolving the id(s) under ``game_info[attr_name]``.

        The returned ``f(game_info, add_func)`` accepts a single id or a list
        of ids, resolves each one through `attr_map` (querying `endpoint_name`
        on a cache miss) and passes every resolved name to `add_func`. Ids
        whose name cannot be fetched are skipped and not cached, so a later
        call retries them.
        """
        def resolve(attr_id):
            # Cached or freshly fetched name, or None when unavailable.
            if attr_id not in attr_map:
                fetched_name = self.__fetch_endpoint(endpoint_name, attr_id)
                if not fetched_name:
                    return None
                attr_map[attr_id] = fetched_name
            return attr_map[attr_id]

        def f(game_info, add_func):
            if attr_name not in game_info:
                print("Attribute %s is empty, skipping." % attr_name)
                return
            attr_value = game_info[attr_name]
            attr_ids = attr_value if isinstance(attr_value, list) else [attr_value]
            for attr_id in attr_ids:
                name = resolve(attr_id)
                if name is not None:
                    add_func(name)
        return f

    def __fetch_endpoint(self, endpoint_name, attr_id):
        """Return the 'name' of entity `attr_id` from endpoint `endpoint_name`, or None."""
        url = self.ENDPOINT_API_TEMPLATE % (endpoint_name, attr_id)
        response = unirest.get(url, headers={
            "X-Mashape-Key": os.environ['IGDB_KEY']
        })
        body = response.body
        if not isinstance(body, list) or len(body) == 0 or \
                'name' not in body[0]:
            return None
        return body[0]['name']
コード例 #14
0
from bs4 import BeautifulSoup
import urllib3
import random

# Custom Libs
from article_lister import ArticleLister
from keyword_extractor import KeywordExtractor
from news_db_storer import NewsDBStorer

# Directory holding the per-paper pickle caches read by GenericNewsScraper.load_articles.
cache_dir = "./pkl_cache/"

# Module-level database storer. set_up_connection() runs at import time
# (presumably opening the DB connection - confirm in NewsDBStorer).
dbstore = NewsDBStorer(db_name="newsarticlesdb",
                       table_name="politician_based_newsarticlestable")
dbstore.set_up_connection()

# Keyword extractor instance created once when the module is imported.
keyword_xtractor = KeywordExtractor()


class GenericNewsScraper:
    """Scraper state for one news site, with a pickle-backed article cache."""

    def __init__(self, paper_name="cnn", base_url="https://www.cnn.com/"):
        self.articles = []
        self.base_url = base_url
        self.paper_name = paper_name
        self.art_obj = set()

    # Loads article cache
    def load_articles(self):
        """Load and return the object pickled in ``<cache_dir>/<paper_name>.pkl``.

        Raises:
            IOError/OSError: if the cache file cannot be opened.
        """
        import os
        import pickle  # pickle is not among this module's visible imports

        cache_path = os.path.join(cache_dir, self.paper_name + ".pkl")
        # The context manager closes the file even if unpickling fails.
        # NOTE: pickle can execute arbitrary code on load; only read caches
        # this program wrote itself.
        with open(cache_path, 'rb') as f:
            return pickle.load(f)
コード例 #15
0
ファイル: dialog_data_processor.py プロジェクト: XL2248/GRADE
class DialogDataProcessor:
    r"""Loads and pre-processes original dialog data.

    Attributes:
        dataset_name: A 'str' indicating the dataset name.
        output_data_dir: A 'str' indicating the output data directory's name.
        separator: A 'str' used to separate two utterances.
        min_kw_freq: An 'int' indicating the minimum of keyword occurrence frequency.
        context_turns: An 'int' indicating the number of turns of each dialog context.
        set_names: A 'list' of 'str' containing the set names,
            e.g., ['train', 'validation', 'test'].
    """

    def __init__(self, dataset_name, output_data_dir,
                 separator, min_kw_freq,
                 context_turns, set_names):
        """Store the configuration and run the pre-processing steps in order.

        Creates the data directories, loads the raw dialogs, builds the keyword
        extractor from the candidate keywords and IDF values, then obtains the
        utterance->keywords mapping and the vocabulary.
        """
        # Configuration; tuple targets are assigned left to right.
        (self.dataset_name, self.output_data_dir, self.separator,
         self.min_kw_freq, self.context_turns, self.set_names) = (
            dataset_name, output_data_dir, separator,
            min_kw_freq, context_turns, set_names)

        self._make_data_dir_if_not_exists()
        self._load_raw_dialog_data()

        # Call arguments are evaluated left to right, so the candidate keywords
        # are still obtained before the IDF values are calculated.
        self.kw_extractor = KeywordExtractor(self._obtain_candidate_keywords(),
                                             self._calculate_idf())

        # uttr_kw_mapping: (utterances -> keywords) mapping
        self._obtain_and_save_uttr_kw_mapping()
        self._obtain_and_save_vocab()

    def process_original_data(self):
        for name in self.set_names:
            self.current_set_name = name
            print('\nStart processing {} set...'.format(name))
            print('-' * 50)
            self._obtain_original_dialogs()
            self._extract_original_dialogs_keywords()
            self._save_processed_original_dialogs()
            print('-' * 50)

    def _make_data_dir_if_not_exists(self):
        output_data_path = '../data/{}'.format(self.output_data_dir)
        if not os.path.exists(output_data_path):
            os.makedirs(output_data_path)
        for set_name in self.set_names:
            pair1_path = os.path.join(output_data_path, set_name, 'pair-1')
            if not os.path.exists(pair1_path):
                os.makedirs(pair1_path)

    def _calculate_idf(self, load_file_if_exists=True):
        r"""Calculates and saves the IDF values for extracting keywords.

        Args:
            load_file_if_exists: A 'bool' indicating whether load IDF file if it exists.

        Returns:
            idf_dict: A 'Dict' containing all the IDF values of keywords.
        """
        if load_file_if_exists:
            idf_dict_name = '../data/{}/idf.dict'.format(self.output_data_dir)
            if os.path.isfile(idf_dict_name):
                with open(idf_dict_name, 'rb') as f:
                    idf_dict = pickle.load(f)
                print('Loading idf dict from {}'.format(idf_dict_name))
                print('idf dict size: ', len(idf_dict))
                return idf_dict

        print('Calculating idf...')

        # Calculates IDF
        counter = collections.Counter()
        total = 0.
        for dialog in tqdm(self.list_all_dialogs):
            for utterance in dialog:
                total += 1
                counter.update(set(kw_tokenize(utterance)))
        idf_dict = {}
        for k,v in counter.items():
            idf_dict[k] = np.log10(total / (v+1.))

        # Writes idf dict into file
        with open('../data/{}/idf.dict'.format(self.output_data_dir), 'wb') as f:
            pickle.dump(idf_dict, f)

        return idf_dict

    def _obtain_candidate_keywords(self, load_file_if_exists=True):
        r"""Obtains and saves the candidate keywords used for extracting keywords.

        Args:
            load_file_if_exists: A 'bool' indicating whether load candi_keywords file if it exists.

        Returns:
            candi_keywords: A 'list' containing all the candidate keywords.
        """
        if load_file_if_exists:
            candi_keywords_name = '../data/{}/candi_keywords.txt'.format(self.output_data_dir)
            if os.path.isfile(candi_keywords_name):
                with open(candi_keywords_name,'r') as f:
                    candi_keywords = [kw.strip() for kw in f.readlines()]
                print('Loading candidate keywords from {}'.format(candi_keywords_name))
                print('Total candidate keywords count: ', len(candi_keywords))
                return candi_keywords

        print('Obtaining candidate keywords...')

        # Initialization
        candi_keywords = []
        kw_counter = collections.Counter()
        kw_extractor = KeywordExtractor()

        # Extracts possible keywords.
        for dialog in tqdm(self.list_all_dialogs):
            for utterance in dialog:
                cur_keywords = kw_extractor.candi_extract(utterance)
                kw_counter.update(cur_keywords)
                candi_keywords.extend(cur_keywords)

        # Deletes the keywords occurring less than specified times
        rare_keywords = [kw for kw, freq in kw_counter.most_common()
            if freq < self.min_kw_freq]
        candi_keywords = [kw for kw, freq in kw_counter.most_common()
            if freq >= self.min_kw_freq]
        # Deletes keywords containing only one single letter
        single_letter_keywords = [kw for kw in candi_keywords if len(kw) < 2]
        candi_keywords = [kw for kw in candi_keywords if len(kw) >= 2]

        # Writes candidate keywords into file
        candidate_keywords_output_path = '../data/{}/candi_keywords.txt'.format(
            self.output_data_dir)
        with open(candidate_keywords_output_path,'w') as f:
            for keyword in candi_keywords:
                f.write(keyword + '\n')

        return candi_keywords

    def _obtain_and_save_uttr_kw_mapping(self, load_file_if_exists=True):
        r"""Obtains and saves the mapping that maps utterances into keywords they contain.

        Args:
            load_file_if_exists: A 'bool' indicating whether load mapping file if it exists.

        Returns:
            uttr_kw_mapping: A 'dict' containing utterances->keywords mapping.
        """
        if load_file_if_exists:
            uttr_kw_mapping_name = '../data/{}/uttr_kw.dict'.format(
                self.output_data_dir)
            if os.path.isfile(uttr_kw_mapping_name):
                with open(uttr_kw_mapping_name, 'rb') as f:
                    self.uttr_kw_mapping = pickle.load(f)
                print('Loading utterances->keywords mapping from {}'.format(
                    uttr_kw_mapping_name))
                print('(utterances -> keyword) mapping size: ',
                    len(self.uttr_kw_mapping))
                return

        print('Obtaining mapping from utterances to keywords...')

        # Extracts keywords to construct mapping
        self.uttr_kw_mapping = {}
        for dialog in tqdm(self.list_all_dialogs):
            for utterance in dialog:
                cur_keywords = self.kw_extractor.idf_extract(utterance)
                self.uttr_kw_mapping[utterance] = cur_keywords
        print('(utterances -> keyword) mapping size: ', len(self.uttr_kw_mapping))

        # Writes uttr_kw_mapping into file
        with open('../data/{}/uttr_kw.dict'.format(self.output_data_dir), 'wb') as f:
            pickle.dump(self.uttr_kw_mapping, f)

    def _obtain_and_save_vocab(self, load_file_if_exists=True):
        r"""Obtains and saves the vocabulary of data.
        Args:
            load_file_if_exists: A 'bool' indicating whether load vocab file if it exists.

        Returns:
            vocab: A 'list' containing all the words occurring in the data.
        """
        if load_file_if_exists:
            vocab_name = '../data/{}/vocab.txt'.format(self.output_data_dir)
            if os.path.isfile(vocab_name):
                with open(vocab_name,'r') as f:
                    self.vocab = [word.strip() for word in f.readlines()]
                print('Loading vocab from {}'.format(vocab_name))
                print('Total vocab count: ', len(self.vocab))
                return

        print('Obtain and save vocab...')

        counter = collections.Counter()
        for dialog in tqdm(self.list_all_dialogs):
            for utterance in dialog:
                counter.update(simp_tokenize(utterance))
        print('Total vocab count: ', len(counter.items()))

        # Vocab sorted by occurrence frequency (descending order)
        self.vocab = [token for token, _ in
            sorted(list(counter.items()), key=lambda x: (-x[1], x[0]))]

        # Writes vocab into file
        with open('../data/{}/vocab.txt'.format(self.output_data_dir),'w') as f:
            for word in self.vocab:
                f.write(word + '\n')

    def _load_raw_dialog_data(self):
        r"""Loads raw dialog data from files.

        Reads '<raw_data_dir>/<set_name>/dialogues_<set_name>.txt' for every
        name in `self.set_names`; each line is one dialog, turned into a list
        of utterances by `self._process_dialog_str`.

        Sets:
            self.list_all_dialogs: A 'list' with the dialogs of all sets, where
                each dialog is a 'list' containing its utterances.
            self.dict_categorized_dialogs: A 'dict' mapping each set name to
                the 'list' of its dialogs (empty when the file has no lines).
        """
        print('Loading raw dialog data...')
        self.list_all_dialogs = []
        self.dict_categorized_dialogs = {}
        for set_name in self.set_names:
            current_dialog_path = os.path.join(self.raw_data_dir,
                                               set_name,
                                               'dialogues_{}.txt'.format(set_name))
            with open(current_dialog_path, 'r') as f:
                raw_dialog_data = f.readlines()
            # setdefault registers every set (even an empty one, so that
            # list_current_dialogs finds it) and keeps appending if a set name
            # is repeated; it replaces the former bare try/except KeyError.
            set_dialogs = self.dict_categorized_dialogs.setdefault(set_name, [])
            for dialog_str in tqdm(raw_dialog_data):
                dialog = self._process_dialog_str(dialog_str)
                self.list_all_dialogs.append(dialog)
                set_dialogs.append(dialog)

    def _obtain_original_dialogs(self):
        """Split every dialog of the current set into sub-dialogs.

        Each dialog of `self.list_current_dialogs` is passed with
        `self.context_turns` to `self.split_dialog`, and all returned
        sub-dialogs are collected, in order, into `self.original_dialogs`.
        """
        print('Obtaining original dialogs...')
        self.original_dialogs = []
        collect = self.original_dialogs.extend
        for dialog in tqdm(self.list_current_dialogs):
            collect(self.split_dialog(dialog, self.context_turns))

    def _extract_original_dialogs_keywords(self):
        """Look up the keywords of every utterance in `self.original_dialogs`.

        Fills `self.original_dialogs_keywords` with one list per dialog holding,
        for each utterance, its keywords from `self.uttr_kw_mapping` joined by
        single spaces.
        """
        self.original_dialogs_keywords = []
        print('Extracting keywords in original dialogs...')
        for dialog in tqdm(self.original_dialogs):
            self.original_dialogs_keywords.append(
                [' '.join(self.uttr_kw_mapping[uttr]) for uttr in dialog])

    def _save_processed_original_dialogs(self):
        r"""Write the processed original sub-dialogs and their keywords to files.

        Four files are written into '<current_set_output_dir>/pair-1/', one entry per line:
            - **original_dialog.text**:              each sub-dialog, utterances joined by '|||'
            - **original_dialog.keyword**:           each sub-dialog's keyword strings, joined by '|||'
            - **original_dialog_response.text**:     the last utterance (response) of every sub-dialog
            - **original_dialog_response_uni.text**: the response of every *distinct* sub-dialog,
                                                     in order of first occurrence

        The 'pair-1' directory is created if it does not exist yet.
        """
        print('Writing original dialog data into files...')
        pair_dir = os.path.join(self.current_set_output_dir, 'pair-1')
        # save() opens each target with mode 'w', which raises if the directory is missing.
        os.makedirs(pair_dir, exist_ok=True)
        o_text_path = os.path.join(pair_dir, 'original_dialog.text')
        o_kw_path = os.path.join(pair_dir, 'original_dialog.keyword')
        o_res_text_path = os.path.join(pair_dir, 'original_dialog_response.text')
        o_uni_res_text_path = os.path.join(pair_dir, 'original_dialog_response_uni.text')

        str_original_dialogs = self.element_to_str(self.original_dialogs, '|||')
        str_original_dialogs_keywords = self.element_to_str(
            self.original_dialogs_keywords, '|||')
        str_original_responses = [dialog[-1] for dialog in self.original_dialogs]
        # Deduplicate whole sub-dialogs while keeping first-occurrence order.
        # Iterating a set() here would make the output line order depend on the
        # per-process string hash seed (PYTHONHASHSEED), so runs would not be reproducible.
        unique_dialog_strs = dict.fromkeys(str_original_dialogs)
        uni_str_original_responses = [
            dialog_str.split('|||')[-1] for dialog_str in unique_dialog_strs]

        self.save(str_original_dialogs, o_text_path)
        self.save(str_original_dialogs_keywords, o_kw_path)
        self.save(str_original_responses, o_res_text_path)
        self.save(uni_str_original_responses, o_uni_res_text_path)

    def _process_dialog_str(self, dialog_str):
        r"""Turn one raw dialog line into a list of normalized utterances.

        Inputs: dialog_str
            - **dialog_str**: a 'str' with utterances separated by `self.separator`

        Outputs: dialog
            - **dialog**: a 'list' of stripped utterances in which '.', '?' and ','
                          are padded with spaces and ' ’ ' is collapsed to "'"

        Note:
            The piece after the last separator is dropped (presumably the
            trailing newline / empty remainder of the raw line; confirm
            against the dataset format).
        """
        utterances = dialog_str.split(self.separator)[:-1]
        # applied in this exact order; each step uses the idempotent replacement helper
        replacement_table = (
            ('.', ' . '),
            ('?', ' ? '),
            (',', ' , '),
            (' ’ ', "'"),
        )
        for old, new in replacement_table:
            utterances = self.replace_content_in_dialog(
                utterances, old_content=old, new_content=new)
        return [uttr.strip() for uttr in utterances]
# Private Methods - End
# -----------------------------------------------------------------------------

    @property
    def raw_data_dir(self):
        r"""Directory holding the raw dataset, relative to the working directory."""
        return './dataset/{name}'.format(name=self.dataset_name)

    @property
    def list_current_dialogs(self):
        r"""Dialogs of the set named by `self.current_set_name` (KeyError if absent)."""
        dialogs_by_set = self.dict_categorized_dialogs
        return dialogs_by_set[self.current_set_name]

    @property
    def current_set_output_dir(self):
        r"""Output directory (with trailing '/') for the current set, relative to the working directory."""
        return '../data/{out}/{split}/'.format(out=self.output_data_dir,
                                               split=self.current_set_name)

    @staticmethod
    def split_dialog(dialog, context_turns=1):
        r"""Split dialog into several sub-dialogs.
        Inputs: dialog, context_turns
            - **dialog**:        a 'list' containing utterances in the dialog
            - **context_turns**: how many turns of a dialogue containing in a context
        Outputs: sub_dialogs
            - **sub_dialogs**: a 'list' containing sub-dialogs
                            with respect to the current dialog

        Example:
            dialog: ['Hello!', 'Hi!', 'What's your name?', 'James.']

            assume context_turns = 1
        => (split dialog into contexts(previous utterance) and responses)

            contexts: [
                ['Hello!', 'Hi!'],
                ['Hi!', 'What's your name?'],
            ]
            responses: [
                ['What's your name?']
                ['James.']
            ]

        => (merge contexts and responses one by one)

            sub_dialogs: [
                ['Hello!', 'Hi!', 'What's your name?'],
                ['Hi!', 'What's your name?', 'James.']
            ]
        """
        num_uttr_in_context = context_turns * 2
        contexts = [
            dialog[i:i+num_uttr_in_context]
            for i in range(0, len(dialog) - num_uttr_in_context)
        ]
        responses = [[dialog[i]] for i in range(num_uttr_in_context, len(dialog))]
        sub_dialogs = [context + response for context, response in zip(contexts, responses)]
        return sub_dialogs

    @staticmethod
    def save(contents, output_path):
        r"""Write each string of `contents` to `output_path`, one per line.

        The file is opened in text mode 'w' (existing content is overwritten),
        and progress is shown with tqdm.
        """
        with open(output_path, 'w') as out_file:
            for line in tqdm(contents):
                out_file.write(line + '\n')

    @staticmethod
    def element_to_str(contents, seperator):
        # each element in 'contents' is also a list
        return [seperator.join(element) for element in contents]

    @staticmethod
    def replace_content_in_dialog(dialog, old_content, new_content):
        r"""Replace specified content in the dialog with given new content.

        Inputs: dialog, separator
            - **dialog**:      a 'list' containing utterances in the dialog
            - **old_content**: a 'str' indicating the content needed to be replaced in the dialog
            - **new_content**: a 'str' indicating the content used to replace the old content
        Outputs: replaced_dialog
            - **replaced_dialog**: a 'list' containing all replaced utterances in the dialog

        Example:
            For an utterance ['Hello.My name is James . '],
            We wanna replace the '.' with ' . ', the procedure is as follow:
                1. first replace ' . ' with '.' obtained ['Hello.My name is James.']
                2. then replace '.' with ' . '  obtained ['Hello . My name is James . ']
        Note:
            if we replace 'old_content' with 'new_content' directly, in this example, we would get:
            ['Hello . My name is James  .  ']
        """
        # first replace the 'new_content' with 'old_content'
        # to ensure there're no utterances containing the specified 'new_content'
        replaced_dialog = [utterance.replace(new_content, old_content) for utterance in dialog]
        replaced_dialog = [utterance.replace(old_content, new_content) for utterance in replaced_dialog]
        return replaced_dialog
コード例 #16
0
 def setUp(self):
     self.extractor = KeywordExtractor()