def __init__(
            self,
            dataset_name: str = config["data"]["dataset_name"],
            dataset_type=config["xc_datasets"][config["data"]["dataset_name"]],
            dataset_dir: str = config["paths"]["dataset_dir"][plat][user]):
        """
        Records the dataset identity and resets every cached split to None.

        No data is read here; only attributes are initialised.

        Args:
            dataset_name : Name of the dataset; also used as the sub-directory name.
            dataset_type : Type of the dataset, stored unchanged on the instance.
            dataset_dir : Parent directory; the dataset name is joined onto it.
        """
        super(Common_Data_Handler, self).__init__()
        self.dataset_name, self.dataset_type = dataset_name, dataset_type
        self.dataset_dir = join(dataset_dir, dataset_name)

        self.clean = Text_Process()

        # Category lookup tables, filled in later.
        self.cattext2catid_map = self.catid2cattxt_map = None

        # Cached data, one group per split.
        self.txts_sel = self.sample2cats_sel = self.cats_sel = None
        # NOTE(review): the train group has no "cats_train" attribute; the
        # original code re-assigned cats_sel on this row instead - confirm intent.
        self.txts_train = self.sample2cats_train = None
        self.txts_test = self.sample2cats_test = self.cats_test = None
        self.txts_val = self.sample2cats_val = self.cats_val = None
Example #2
0
    def __init__(self,dataset_name=config["data"]["dataset_name"],
                 dataset_dir: str = config["paths"]["dataset_dir"][plat][user]):
        """
        Sets up the directory layout and immediately builds the dataset dicts.

        Args:
            dataset_name : Name of the dataset; prefixes the raw html folder name.
            dataset_dir : Directory holding the raw html and the txt_files folders.
        """
        super(HTMLLoader,self).__init__()
        self.dataset_name = dataset_name
        self.dataset_dir = dataset_dir

        # Both source folders live directly under the dataset directory.
        raw_html_folder = dataset_name + "_RawData"
        self.raw_html_dir = join(dataset_dir,raw_html_folder)
        self.raw_txt_dir = join(dataset_dir,"txt_files")

        logger.debug("Dataset name: [{}], Directory: [{}]".format(dataset_name,dataset_dir))

        self.clean = Text_Process()
        generated = self.gen_dicts()
        self.txts,self.classes,self.cats = generated
Example #3
0
    def __init__(self,
                 dataset_name: str = config["data"]["dataset_name"],
                 graph_format: str = "graphml",
                 top_k: int = 10,
                 graph_dir: str = config["paths"]["dataset_dir"][plat][user]):
        """
        Stores the graph settings and creates a text processor.

        :param dataset_name: Name of the dataset, stored unchanged.
        :param graph_format: Graph format identifier, stored unchanged.
        :param top_k: Integer setting, stored unchanged.
        :param graph_dir: Directory path, stored unchanged.
        """
        super(Neighborhood_Graph, self).__init__()
        self.graph_dir, self.dataset_name = graph_dir, dataset_name
        self.graph_format, self.top_k = graph_format, top_k
        self.txt_process = Text_Process()
Example #4
0
    def __init__(self,dataset_loader,dataset_name: str = config["data"]["dataset_name"],
                 dataset_dir: str = config["paths"]["dataset_dir"][plat][user]) -> None:
        """
        Stores the loader and dataset location, and creates the helper objects.

        :param dataset_loader: Loader object; only stored here, not called.
        :param dataset_name: Name of the dataset.
        :param dataset_dir: Directory of the dataset.
        """
        self.dataset_name,self.dataset_dir = dataset_name,dataset_dir
        self.dataset_loader = dataset_loader

        # Lazily populated state; nothing is computed at construction time.
        self.txts2vec_map = self.doc2vec_model = None
        self.cats_all = self.txt_encoder_model = None
        self.txts_sel = self.sample2cats_sel = self.cats_sel = None

        # Helper objects, created in the same order as before.
        self.graph = Neighborhood_Graph()
        self.txt_process = Text_Process()
        self.txt_encoder = Text_Encoder()
        self.mlb = MultiLabelBinarizer()

        self.oov_words_dict = OrderedDict()
Example #5
0
class HTMLLoader(torch.utils.data.Dataset):
    """
    Class to process and load html files from a directory.

    Datasets: Wiki10-31K

    txts : Wikipedia english texts after parsing and cleaning.
    txts = {"id1": "wiki_text_1", "id2": "wiki_text_2"}

    classes : OrderedDict of document id to list of category ids.
    classes = {"id1": [class_id_1,class_id_2],"id2": [class_id_2,class_id_10]}

    cats : OrderedDict of category text to category id.
    cats = {"Computer Science":class_id_1, "Machine Learning":class_id_2}

    samples : OrderedDict of document id to the list of text lines read for it.
    """

    def __init__(self,dataset_name=config["data"]["dataset_name"],
                 dataset_dir: str = config["paths"]["dataset_dir"][plat][user]):
        """
        Initializes the html loader and builds the txts, classes and cats dicts.

        Args:
            dataset_dir : Path to directory of the dataset.
            dataset_name : Name of the dataset.
        """
        super(HTMLLoader,self).__init__()
        self.dataset_name = dataset_name
        self.dataset_dir = dataset_dir
        self.raw_html_dir = join(self.dataset_dir,dataset_name + "_RawData")
        self.raw_txt_dir = join(self.dataset_dir,"txt_files")
        logger.debug("Dataset name: [{}], Directory: [{}]".format(self.dataset_name,self.dataset_dir))
        self.clean = Text_Process()
        self.txts,self.classes,self.cats = self.gen_dicts()

    @staticmethod
    def _index_labels(doc_id,labels,label2id,doc2ids):
        """Gives each unseen label the next free id and appends every label id to the document's list."""
        for lbl in labels:
            if lbl not in label2id:
                ## Ids are handed out consecutively from 0, so the next id is the current size.
                label2id[lbl] = len(label2id)
            doc2ids.setdefault(doc_id,[]).append(label2id[lbl])

    def gen_dicts(self):
        """Builds txts, classes and cats from the raw samples.

        Samples are read from the txt directory when it exists, otherwise from the raw html files.
        Hidden categories and the ids of documents without any category are saved as json files
        in the dataset directory.

        :return: Tuple (txts, classes, cats) of OrderedDicts.
        """

        if isdir(self.raw_txt_dir):
            logger.info("Loading data from TXT files.")
            self.samples = self.read_txt_dir(self.raw_txt_dir)
        else:
            logger.info("Could not find TXT files: [{}]".format(self.raw_txt_dir))
            logger.info("Loading data from HTML files.")
            html_parser = self.get_html_parser()
            self.samples = self.read_html_dir(html_parser)

        classes = OrderedDict()
        hid_classes = OrderedDict()
        cats = OrderedDict()
        hid_cats = OrderedDict()
        txts = OrderedDict()
        no_cat_ids = []  # List to store failed parsing cases.
        for doc_id,txt in self.samples.items():
            txt = list(filter(None,txt))  # Removing empty items
            doc,filtered_cats,filtered_hid_cats = self.clean.filter_html_cats_reverse(txt)
            if filtered_cats:  ## Check at least one category was successfully filtered from html file.
                txts[doc_id] = clean_wiki(doc)  ## Removing category information and other texts from html pages.
                self._index_labels(doc_id,filtered_cats,cats,classes)
            else:  ## If no category was found, store the doc_id in a separate place for later inspection.
                logger.warn("No categories found in document: [{}].".format(doc_id))
                no_cat_ids.append(doc_id)

            ## Hidden categories are indexed independently of the visible ones.
            if filtered_hid_cats:  ## Check at least one hidden category was successfully filtered from html file.
                self._index_labels(doc_id,filtered_hid_cats,hid_cats,hid_classes)
        if no_cat_ids:  ## Only warn when there is something to report.
            logger.warn("No cattext2catid_map found for: [{}] documents. Storing ids for reference in file '_no_cat_ids'."
                        .format(len(no_cat_ids)))
        File_Util.save_json(hid_classes,self.dataset_name + "_hid_classes",filepath=self.dataset_dir)
        File_Util.save_json(hid_cats,self.dataset_name + "_hid_cats",filepath=self.dataset_dir)
        File_Util.save_json(no_cat_ids,self.dataset_name + "_no_cat_ids",filepath=self.dataset_dir)
        logger.info("Number of txts: [{}], sample2cats: [{}] and cattext2catid_map: [{}]."
                    .format(len(txts),len(classes),len(cats)))
        return txts,classes,cats

    def read_txt_dir(self,raw_txt_dir: str = None,encoding: str = "iso-8859-1") -> OrderedDict:
        """
        Reads every ".txt" file of a folder and returns OrderedDict[filename without ".txt"] = list of lines.

        :param raw_txt_dir: Folder to read; falls back to [self.raw_txt_dir] when None.
        :param encoding: Encoding used to open the files.
        :return: OrderedDict of file stem to list of lines, or False when the folder does not exist.
        """
        data = OrderedDict()
        if raw_txt_dir is None: raw_txt_dir = self.raw_txt_dir
        logger.info("Raw TXT path: {}".format(raw_txt_dir))
        if not isdir(raw_txt_dir):
            return False  ## Kept for backward compatibility with callers that test the result.
        for i in listdir(raw_txt_dir):
            txt_path = join(raw_txt_dir,i)
            if isfile(txt_path) and i.endswith(".txt"):
                with open(txt_path,encoding=encoding) as txt_ptr:
                    data[str(i[:-4])] = str(txt_ptr.read()).splitlines()  ## [:-4] removes the ".txt" from filename.
        return data

    @staticmethod
    def get_html_parser(alt_text: bool = True,ignore_table: bool = True,decode_errors: str = "ignore",
                        default_alt: str = "",
                        ignore_link: bool = True,reference_links: bool = True,bypass_tables: bool = True,
                        ignore_emphasis: bool = True,
                        unicode_snob: bool = False,no_automatic_links: bool = True,no_skip_internal_links: bool = True,
                        single_line_break: bool = True,
                        escape_all: bool = True,ignore_images: object = True):
        """
        Returns a html parser with config, based on: https://github.com/Alir3z4/html2text.

        Usage: https://github.com/Alir3z4/html2text/blob/master/docs/usage.md
        logger.debug(html_parser.handle("<p>Hello, <a href='http://earth.google.com/'>world</a>!"))

        Each argument is copied onto the parser attribute set below. [alt_text] and
        [no_skip_internal_links] are accepted but currently unused; [skip_internal_links]
        is always set to True.

        :return: html2text parser.
        """
        logger.info("Getting HTML parser.")
        import html2text  ## https://github.com/Alir3z4/html2text

        html_parser = html2text.HTML2Text()
        html_parser.ignore_tables = ignore_table  ## Ignore table-related tags (table, th, td, tr) while keeping rows.
        html_parser.ignore_images = ignore_images  ## Do not include any formatting for images.
        html_parser.decode_errors = decode_errors  ## Handling decoding error: "ignore", "strict", "replace" etc.
        html_parser.default_image_alt = default_alt  ## Inserts the given alt text whenever images are missing alt values.
        html_parser.ignore_links = ignore_link  ## Ignore converting links from HTML.
        html_parser.reference_links = reference_links  ## Use reference links instead of inline links to create markdown.
        html_parser.bypass_tables = bypass_tables  ## Format tables in HTML rather than Markdown syntax.
        html_parser.ignore_emphasis = ignore_emphasis  ## Ignore all emphasis formatting in the html.
        html_parser.unicode_snob = unicode_snob  ## Use unicode throughout instead of ASCII.
        html_parser.no_automatic_links = no_automatic_links  ## Do not use automatic links like http://google.com.
        html_parser.skip_internal_links = True
        html_parser.single_line_break = single_line_break  ## Use a single line break after a block element rather than two.
        html_parser.escape_all = escape_all  ## Escape all special characters.
        return html_parser

    def read_html_dir(self,html_parser,encoding="iso-8859-1",specials="""_-@*#'"/\\""",replace=' '):
        """
        Converts every file of [self.raw_html_dir] to cleaned text.

        Each converted text is written to the "txt_files" folder of the dataset directory and
        returned as a list of lines, the same value shape that read_txt_dir produces.

        :param html_parser: Object whose handle(str) converts html to text.
        :param encoding: Encoding used to open the html files.
        :param specials: Characters passed to make_trans_table for replacement.
        :param replace: Replacement passed to make_trans_table.
        :return: OrderedDict[filename] = list of cleaned lines; empty when the html folder is missing.
        """
        from unidecode import unidecode

        data = OrderedDict()
        txt_dir = join(self.dataset_dir,"txt_files")
        makedirs(txt_dir,exist_ok=True)
        if isdir(self.raw_html_dir):
            trans_table = self.clean.make_trans_table(specials=specials,
                                                      replace=replace)  ## Creating mapping to clean txts.
            for i in listdir(self.raw_html_dir):
                html_path = join(self.raw_html_dir,i)
                if isfile(html_path):
                    with open(html_path,encoding=encoding) as html_ptr:
                        h_content = html_parser.handle(html_ptr.read())
                    ## unidecode works on a str, so transliterate first and split into lines afterwards.
                    clean_text = unidecode(str(h_content)).translate(trans_table)
                    File_Util.write_file(clean_text,i,filepath=txt_dir)
                    data[str(i)] = clean_text.splitlines()
        return data

    def get_data(self):
        """
        Function to get the entire dataset as (txts, classes, cats).
        """
        return self.txts,self.classes,self.cats

    def get_txts(self):
        """
        Function to get the texts dict.
        """
        return self.txts

    def get_classes(self):
        """
        Function to get the document id to category ids dict.
        """
        return self.classes

    def get_cats(self) -> dict:
        """
        Function to get the category text to category id dict.
        """
        return self.cats
class Common_Data_Handler:
    """
    Class to load and prepare and split pre-built json files.

    txts : Wikipedia english texts after parsing and cleaning.
    txts = {"id1": "wiki_text_1", "id2": "wiki_text_2"}

    sample2cats   : OrderedDict of id to sample2cats.
    sample2cats = {"id1": [class_id_1,class_id_2],"id2": [class_id_2,class_id_10]}

    categories : Dict of class texts.
    categories = {"Computer Science":class_id_1, "Machine Learning":class_id_2}
    """
    def __init__(
            self,
            dataset_name: str = config["data"]["dataset_name"],
            dataset_type=config["xc_datasets"][config["data"]["dataset_name"]],
            dataset_dir: str = config["paths"]["dataset_dir"][plat][user]):
        """
        Records the dataset identity and resets every cached split to None.

        No data is read here; only attributes are initialised.

        Args:
            dataset_name : Name of the dataset; also used as the sub-directory name.
            dataset_type : Type of the dataset, stored unchanged on the instance.
            dataset_dir : Parent directory; the dataset name is joined onto it.
        """
        super(Common_Data_Handler, self).__init__()
        self.dataset_name, self.dataset_type = dataset_name, dataset_type
        self.dataset_dir = join(dataset_dir, dataset_name)

        self.clean = Text_Process()

        # Category lookup tables, filled in later.
        self.cattext2catid_map = self.catid2cattxt_map = None

        # Cached data, one group per split.
        self.txts_sel = self.sample2cats_sel = self.cats_sel = None
        # NOTE(review): the train group has no "cats_train" attribute; the
        # original code re-assigned cats_sel on this row instead - confirm intent.
        self.txts_train = self.sample2cats_train = None
        self.txts_test = self.sample2cats_test = self.cats_test = None
        self.txts_val = self.sample2cats_val = self.cats_val = None

    def multilabel2multiclass_df(self, df: pd.DataFrame):
        """Expands multi-label rows into one (idx, cat) row per label.

        Column 1 of each input row is used as the sample id and column 3 as the
        stringified label list (the surrounding brackets are cut off before
        splitting on commas). The result is also written to
        "<dataset_name>_multiclass_df.csv" in the dataset directory.

        :param df: Dataframe whose rows carry a sample id and a stringified list of categories.
        :returns: DataFrame with columns "idx" and "cat", one row per label.
        """
        if self.catid2cattxt_map is None:
            map_name = self.dataset_name + "_catid2cattxt_map"
            self.catid2cattxt_map = File_Util.load_json(
                filename=map_name, filepath=self.dataset_dir)

        # When DataFrame is saved as csv, list is converted to str; [1:-1]
        # drops the enclosing brackets.
        pairs = [(row[1], raw_lbl.strip())
                 for row in df.values
                 for raw_lbl in row[3][1:-1].split(',')]
        idxs = [sample_id for sample_id, _ in pairs]
        cat = [label for _, label in pairs]

        multiclass_df = pd.DataFrame.from_dict({"idx": idxs, "cat": cat})
        has_cat = ~multiclass_df['cat'].isna()
        multiclass_df = multiclass_df[has_cat]

        csv_path = join(self.dataset_dir,
                        self.dataset_name + "_multiclass_df.csv")
        multiclass_df.to_csv(path_or_buf=csv_path)

        logger.info("Data shape = {} ".format(multiclass_df.shape))

        return multiclass_df

    def gen_data_stats(self,
                       txts: dict = None,
                       sample2cats: dict = None,
                       cats: dict = None):
        """ Logs the category frequency distribution of the data.

        Counts how often each category id occurs over all samples, orders the
        counts ascending by frequency and logs both the number of distinct
        categories and the ordered counts. Nothing is returned.

        When sample2cats is None, the full dataset is loaded through
        load_full_json(return_values=True).

        :param txts: Sample id to text mapping (not used for the statistics).
        :param sample2cats: Sample id to list of category ids.
        :param cats: Category mapping (not used for the statistics).
        """
        if sample2cats is None:
            txts, sample2cats, cats = self.load_full_json(return_values=True)

        occurrences = OrderedDict()
        for cat_ids in sample2cats.values():
            for cat_id in cat_ids:
                occurrences[cat_id] = occurrences.get(cat_id, 0) + 1

        # Ascending by count; ties keep first-seen order.
        by_count = sorted(occurrences.items(), key=lambda item: item[1])
        cat_freq_sorted = OrderedDict(by_count)
        logger.info("Category Length: {}".format(len(cat_freq_sorted)))
        logger.info("Category frequencies: {}".format(cat_freq_sorted))

    def load_full_json(self, return_values: bool = False):
        """
        Loads the full dataset and optionally splits it into train, val and test.

        If the "_txts", "_sample2cats" and "_cats" json files all exist in the
        dataset directory they are loaded. Otherwise the raw data is loaded via
        load_raw_data, the categories are cleaned and de-duplicated, and the
        three json files are written.

        :param return_values: When True, return the full (txts, classes,
            categories) without splitting. When False, split the data, save
            every split as json and return the nine split dicts.
        :return: (txts, classes, categories) or (txts_train, sample2cats_train,
            cats_train, txts_val, sample2cats_val, cats_val, txts_test,
            sample2cats_test, cats_test).
        """
        # Computed once so the existence check and the log messages agree.
        txts_path = join(self.dataset_dir, self.dataset_name + "_txts.json")
        sample2cats_path = join(self.dataset_dir,
                                self.dataset_name + "_sample2cats.json")
        cats_path = join(self.dataset_dir, self.dataset_name + "_cats.json")

        if isfile(txts_path) and isfile(sample2cats_path) and isfile(cats_path):
            logger.info("Loading pre-processed json files from: [{}]".format(
                txts_path))
            txts = File_Util.load_json(self.dataset_name + "_txts",
                                       filepath=self.dataset_dir,
                                       show_path=True)
            classes = File_Util.load_json(self.dataset_name + "_sample2cats",
                                          filepath=self.dataset_dir,
                                          show_path=True)
            categories = File_Util.load_json(self.dataset_name + "_cats",
                                             filepath=self.dataset_dir,
                                             show_path=True)
            assert len(txts) == len(classes),\
                "Count of txts [{0}] and sample2cats [{1}] should match.".format(
                    len(txts),len(classes))
        else:
            logger.warn("Pre-processed json files not found at: [{}]".format(
                txts_path))
            logger.info(
                "Loading raw data and creating 3 separate dicts of txts [id->texts], sample2cats [id->class_ids]"
                " and categories [class_name : class_id].")
            txts, classes, categories = self.load_raw_data(self.dataset_type)
            File_Util.save_json(categories,
                                self.dataset_name + "_cats",
                                filepath=self.dataset_dir)
            File_Util.save_json(txts,
                                self.dataset_name + "_txts",
                                filepath=self.dataset_dir)
            File_Util.save_json(classes,
                                self.dataset_name + "_sample2cats",
                                filepath=self.dataset_dir)
            logger.info("Cleaning categories.")
            categories, categories_dup_dict, dup_cat_text_map = self.clean.clean_categories(
                categories)
            File_Util.save_json(dup_cat_text_map,
                                self.dataset_name + "_dup_cat_text_map",
                                filepath=self.dataset_dir,
                                overwrite=True)
            File_Util.save_json(categories,
                                self.dataset_name + "_cats",
                                filepath=self.dataset_dir,
                                overwrite=True)
            if categories_dup_dict:  # Replace old category ids with new ids if duplicate categories found.
                File_Util.save_json(
                    categories_dup_dict,
                    self.dataset_name + "_categories_dup_dict",
                    filepath=self.dataset_dir,
                    overwrite=True
                )  # Storing the duplicate categories for future dedup removal.
                classes = self.clean.dedup_data(classes, categories_dup_dict)
            assert len(txts) == len(classes),\
                "Count of txts [{0}] and sample2cats [{1}] should match.".format(
                    len(txts),len(classes))
            File_Util.save_json(txts,
                                self.dataset_name + "_txts",
                                filepath=self.dataset_dir,
                                overwrite=True)
            File_Util.save_json(classes,
                                self.dataset_name + "_sample2cats",
                                filepath=self.dataset_dir,
                                overwrite=True)
            # Report the real file paths (directory joined with the
            # dataset-prefixed file name).
            logger.info(
                "Saved txts [{0}], sample2cats [{1}] and categories [{2}] as json files."
                .format(txts_path, sample2cats_path, cats_path))
        if return_values:
            return txts, classes, categories
        else:
            # Splitting data into train, validation and test sets.
            # NOTE(review): the train categories are kept in self.cats_sel and
            # saved under "_cats_train" - confirm this naming is intended.
            self.txts_train,self.sample2cats_train,self.cats_sel,self.txts_val,self.sample2cats_val,\
            self.cats_val,self.txts_test,self.sample2cats_test,self.cats_test,catid2cattxt_map =\
                self.split_data(txts=txts,classes=classes,categories=categories)
            txts, classes, categories = None, None, None  # Remove large dicts and free up memory.
            collect()

            File_Util.save_json(self.txts_test,
                                self.dataset_name + "_txts_test",
                                filepath=self.dataset_dir)
            File_Util.save_json(self.sample2cats_test,
                                self.dataset_name + "_sample2cats_test",
                                filepath=self.dataset_dir)
            File_Util.save_json(self.txts_val,
                                self.dataset_name + "_txts_val",
                                filepath=self.dataset_dir)
            File_Util.save_json(self.sample2cats_val,
                                self.dataset_name + "_sample2cats_val",
                                filepath=self.dataset_dir)
            File_Util.save_json(self.txts_train,
                                self.dataset_name + "_txts_train",
                                filepath=self.dataset_dir)
            File_Util.save_json(self.sample2cats_train,
                                self.dataset_name + "_sample2cats_train",
                                filepath=self.dataset_dir)
            File_Util.save_json(self.cats_sel,
                                self.dataset_name + "_cats_train",
                                filepath=self.dataset_dir)
            File_Util.save_json(self.cats_val,
                                self.dataset_name + "_cats_val",
                                filepath=self.dataset_dir)
            File_Util.save_json(self.cats_test,
                                self.dataset_name + "_cats_test",
                                filepath=self.dataset_dir)
            File_Util.save_json(catid2cattxt_map,
                                self.dataset_name + "_catid2cattxt_map",
                                filepath=self.dataset_dir)
            return self.txts_train,self.sample2cats_train,self.cats_sel,self.txts_val,self.sample2cats_val,\
                   self.cats_val,self.txts_test,self.sample2cats_test,self.cats_test

    def load_raw_data(self, dataset_type: str = None):
        """
        Builds the loader matching the dataset type and returns its data.

        The created loader is kept on self.dataset.

        :param dataset_type: One of "html", "json" or "txt"; falls back to
            self.dataset_type when None.
        :return: Tuple (txts, classes, categories) from the loader's get_data().
        :raises Exception: When the dataset type is none of the supported values.
        """
        if dataset_type is None:
            dataset_type = self.dataset_type

        # Pick the loader class first, instantiate once below.
        if dataset_type == "html":
            loader_cls = html.HTMLLoader
        elif dataset_type == "json":
            loader_cls = json.JSONLoader
        elif dataset_type == "txt":
            loader_cls = txt.TXTLoader
        else:
            raise Exception(
                "Dataset type for dataset [{}] not found. \n"
                "Possible reasons: Dataset not added to the config file.".
                format(self.dataset_name))

        self.dataset = loader_cls(dataset_name=self.dataset_name,
                                  dataset_dir=self.dataset_dir)
        txts, classes, categories = self.dataset.get_data()
        return txts, classes, categories

    def split_data(self,
                   txts: OrderedDict,
                   classes: OrderedDict,
                   categories: OrderedDict,
                   test_split: int = config["data"]["test_split"],
                   val_split: int = config["data"]["val_split"]):
        """ Splits input data into train, val and test and derives per-split categories.

        The test part is cut off first, then the validation part is cut from
        what remains. Each part's size is int(count * split) and is passed to
        File_Util.split_dict as batch_size.

        :param txts: Sample id to text mapping.
        :param classes: Sample id to list of category ids.
        :param categories: Category mapping; inverted when no saved
            "_catid2cattxt_map" json file exists.
        :param test_split: Multiplier applied to the total sample count.
        :param val_split: Multiplier applied to the remaining sample count.
        :return: txts_train, sample2cats_train, cats_train, txts_val,
            sample2cats_val, cats_val, txts_test, sample2cats_test, cats_test,
            catid2cattxt_map
        """
        total = len(classes)
        logger.info("Total number of samples: [{}]".format(total))

        test_size = int(total * test_split)
        sample2cats_train, sample2cats_test, txts_train, txts_test = File_Util.split_dict(
            classes, txts, batch_size=test_size)
        logger.info("Test count: [{}]. Remaining count: [{}]".format(
            len(sample2cats_test), len(sample2cats_train)))

        val_size = int(len(txts_train) * val_split)
        sample2cats_train, sample2cats_val, txts_train, txts_val = File_Util.split_dict(
            sample2cats_train, txts_train, batch_size=val_size)
        logger.info("Validation count: [{}]. Train count: [{}]".format(
            len(sample2cats_val), len(sample2cats_train)))

        map_name = self.dataset_name + "_catid2cattxt_map"
        if isfile(join(self.dataset_dir, map_name + ".json")):
            stored_map = File_Util.load_json(map_name,
                                             filepath=self.dataset_dir)
            # Integer keys are converted to string when saving as JSON. Converting back to integer.
            catid2cattxt_map = OrderedDict(
                (int(cat_id), cat_txt) for cat_id, cat_txt in stored_map.items())
        else:
            logger.info("Generating inverted categories.")
            catid2cattxt_map = File_Util.inverse_dict_elm(categories)

        def cats_used_by(sample2cats):
            """Maps every category id occurring in the split to its text, in first-seen order."""
            used = OrderedDict()
            for cat_ids in sample2cats.values():
                for cat_id in cat_ids:
                    if cat_id not in used:
                        used[cat_id] = catid2cattxt_map[cat_id]
            return used

        logger.info("Creating train categories.")
        cats_train = cats_used_by(sample2cats_train)

        logger.info("Creating validation categories.")
        cats_val = cats_used_by(sample2cats_val)

        logger.info("Creating test categories.")
        cats_test = cats_used_by(sample2cats_test)

        return (txts_train, sample2cats_train, cats_train,
                txts_val, sample2cats_val, cats_val,
                txts_test, sample2cats_test, cats_test,
                catid2cattxt_map)

    def gen_cat2samples_map(self, classes_dict: dict = None):
        """ Inverts a sample -> categories mapping into category -> samples.

        :param classes_dict: Sample id to list of categories; defaults to
            self.sample2cats_sel when None.
        :returns: OrderedDict of category to the list of sample ids carrying it,
            in first-seen order.
        """
        if classes_dict is None:
            classes_dict = self.sample2cats_sel

        inverted = OrderedDict()
        for sample_id in classes_dict:
            for cat in classes_dict[sample_id]:
                inverted.setdefault(cat, []).append(sample_id)
        return inverted

    def find_cats_with_few_samples(self,
                                   sample2cats_map: dict = None,
                                   cat2samples_map: dict = None,
                                   remove_count=20):
        """ Partitions categories by whether they have more than [remove_count] samples.

        When cat2samples_map is None it is built with gen_cat2samples_map
        from sample2cats_map.

        NOTE(review): categories with MORE than remove_count samples are the
        ones collected as "tail" here, which looks inverted relative to the
        names - confirm against callers before changing.

        :returns: Tuple, in this order:
            tail_cats: List of categories having more than remove_count samples.
            samples_with_tail_cats: Set of sample ids belonging to those categories.
            cat2samples_few: OrderedDict of every other category to its sample count.
        """
        if cat2samples_map is None:
            cat2samples_map = self.gen_cat2samples_map(sample2cats_map)

        cat2samples_few = OrderedDict()
        tail_cats, samples_with_tail_cats = [], set()
        for cat_id, members in cat2samples_map.items():
            member_count = len(members)
            if member_count <= remove_count:
                cat2samples_few[cat_id] = member_count
                continue
            tail_cats.append(cat_id)
            samples_with_tail_cats.update(members)

        return tail_cats, samples_with_tail_cats, cat2samples_few

    def get_data(
            self,
            load_type: str = "all",
            calculate_idf=False) -> (OrderedDict, OrderedDict, OrderedDict):
        """Loads the requested split and returns it as a DataFrame plus the category map.

        The categories are loaded first, then the split chosen by load_type is
        stored on txts_sel / sample2cats_sel / cats_sel and converted with
        json2csv.

        :param load_type: One of 'all' (default), 'train', 'val' or 'test'.
        :param calculate_idf: When truthy, the result of
            self.clean.calculate_idf_per_token over the selected texts is
            returned as a third element.
        :returns: (df, cattext2catid_map) or (df, cattext2catid_map, idf_dict).
        :raises Exception: For an unsupported load_type.
        """
        self.cattext2catid_map = self.load_categories()

        # Resolve the loader lazily so only the requested one is touched.
        if load_type == "train":
            loader = self.load_train
        elif load_type == "val":
            loader = self.load_val
        elif load_type == "test":
            loader = self.load_test
        elif load_type == "all":
            loader = self.load_all
        else:
            raise Exception(
                "Unsupported load_type: [{}]. \n"
                "Available options: ['all (Default)','train','val','test']".
                format(load_type))
        self.txts_sel, self.sample2cats_sel, self.cats_sel = loader()

        df = self.json2csv(self.txts_sel, self.sample2cats_sel)
        if not calculate_idf:
            return df, self.cattext2catid_map

        selected_txts = list(self.txts_sel.values())
        idf_dict = self.clean.calculate_idf_per_token(txts=selected_txts)
        return df, self.cattext2catid_map, idf_dict

    def load_categories(self) -> OrderedDict:
        """Returns the category-text to category-id map, loading it on first use.

        The map is read from "<dataset_name>_cattext2catid_map.json" inside
        self.dataset_dir when that file exists; otherwise the third value returned
        by load_full_json(return_values=True) is used. The result is cached on
        self.cattext2catid_map.
        """
        if self.cattext2catid_map is not None:
            return self.cattext2catid_map

        map_name = self.dataset_name + "_cattext2catid_map"
        map_path = join(self.dataset_dir, map_name + ".json")
        logger.debug(map_path)
        if isfile(map_path):
            self.cattext2catid_map = File_Util.load_json(
                map_name, filepath=self.dataset_dir)
        else:
            full_data = self.load_full_json(return_values=True)
            _, _, self.cattext2catid_map = full_data
        return self.cattext2catid_map

    def load_all(self) -> (OrderedDict, OrderedDict, OrderedDict):
        """Loads and returns the whole data as (txts, sample2cats, cats).

        Each part that is still None is read from its own json file in
        self.dataset_dir when that file exists; otherwise all three parts are
        replaced by the result of load_full_json(return_values=True). The parts
        are cached on self.txts_sel / self.sample2cats_sel / self.cats_sel.
        """
        logger.debug(join(self.dataset_dir, self.dataset_name + "_txts.json"))

        # (attribute caching the part, suffix of the json file stem)
        parts = (("txts_sel", "_txts"),
                 ("sample2cats_sel", "_sample2cats"),
                 ("cats_sel", "_cats"))
        for attr, suffix in parts:
            if getattr(self, attr) is not None:
                continue
            stem = self.dataset_name + suffix
            if isfile(join(self.dataset_dir, stem + ".json")):
                setattr(self, attr,
                        File_Util.load_json(stem, filepath=self.dataset_dir))
            else:
                # The full load overwrites all three parts, not just the missing one.
                self.txts_sel, self.sample2cats_sel, self.cats_sel = self.load_full_json(
                    return_values=True)
        collect()

        counts = (len(self.txts_sel), len(self.sample2cats_sel),
                  len(self.cats_sel))
        logger.info(
            "Total data counts:\n\ttxts = [{}],\n\tsample2cats = [{}],\n\tcattext2catid_map = [{}]"
            .format(*counts))
        return self.txts_sel, self.sample2cats_sel, self.cats_sel

    def load_train(self) -> (OrderedDict, OrderedDict, OrderedDict):
        """Loads and returns training set as (txts, sample2cats, cats).

        Each part that is still None is read from its "<dataset_name>_*_train.json"
        file in self.dataset_dir when that file exists; otherwise load_full_json()
        is called and its return value is ignored.
        """
        # NOTE(review): the training categories are cached on self.cats_sel, the same
        # attribute load_all() fills, so a previously loaded full category set is
        # returned here unchanged -- confirm this sharing is intended.
        logger.debug(
            join(self.dataset_dir, self.dataset_name + "_txts_train.json"))

        # (attribute caching the part, suffix of the json file stem)
        parts = (("txts_train", "_txts_train"),
                 ("sample2cats_train", "_sample2cats_train"),
                 ("cats_sel", "_cats_train"))
        for attr, suffix in parts:
            if getattr(self, attr) is not None:
                continue
            stem = self.dataset_name + suffix
            if isfile(join(self.dataset_dir, stem + ".json")):
                setattr(self, attr,
                        File_Util.load_json(stem, filepath=self.dataset_dir))
            else:
                self.load_full_json()
        collect()

        return self.txts_train, self.sample2cats_train, self.cats_sel

    def load_val(self) -> (OrderedDict, OrderedDict, OrderedDict):
        """Loads and returns validation set as (txts, sample2cats, cats).

        Each part that is still None is read from its "<dataset_name>_*_val.json"
        file in self.dataset_dir when that file exists; otherwise load_full_json()
        is called and its return value is ignored.
        """
        # (attribute caching the part, suffix of the json file stem)
        parts = (("txts_val", "_txts_val"),
                 ("sample2cats_val", "_sample2cats_val"),
                 ("cats_val", "_cats_val"))
        for attr, suffix in parts:
            if getattr(self, attr) is not None:
                continue
            stem = self.dataset_name + suffix
            if isfile(join(self.dataset_dir, stem + ".json")):
                setattr(self, attr,
                        File_Util.load_json(stem, filepath=self.dataset_dir))
            else:
                self.load_full_json()
        collect()

        return self.txts_val, self.sample2cats_val, self.cats_val

    def load_test(self) -> (OrderedDict, OrderedDict, OrderedDict):
        """Loads and returns test set as (txts, sample2cats, cats).

        Each part that is still None is read from its "<dataset_name>_*_test.json"
        file in self.dataset_dir when that file exists; otherwise load_full_json()
        is called and its return value is ignored.
        """
        # (attribute caching the part, suffix of the json file stem)
        parts = (("txts_test", "_txts_test"),
                 ("sample2cats_test", "_sample2cats_test"),
                 ("cats_test", "_cats_test"))
        for attr, suffix in parts:
            if getattr(self, attr) is not None:
                continue
            stem = self.dataset_name + suffix
            if isfile(join(self.dataset_dir, stem + ".json")):
                setattr(self, attr,
                        File_Util.load_json(stem, filepath=self.dataset_dir))
            else:
                self.load_full_json()
        collect()

        return self.txts_test, self.sample2cats_test, self.cats_test

    def create_new_data(self,
                        new_data_name: str = "_pointer",
                        save_files: bool = True,
                        save_dir: str = None,
                        catid2cattxt_map: OrderedDict = None):
        """Creates (or reloads) a derived dataset selected by new_data_name.

        Supported values of new_data_name:
            _fixed5: Single-label samples whose label is one of 5 fixed categories.
            _onehot: Samples which belong to a single class only.
            _pointer: Built by _create_pointer_data.
            _fewshot: Built by _create_fewshot_data.
            _firstsent: Built by _create_firstsent_data.

        If the three json files of the derived dataset already exist inside
        save_dir they are loaded and returned instead of being regenerated.

        NOTE: This method is used only for sanity testing using fixed multi-class scenario.

        :param new_data_name: Name (suffix) of the derived dataset, see above.
        :param save_files: Whether newly generated data is written to save_dir.
        :param save_dir: Target directory; defaults to
            <dataset_dir>/<dataset_name><new_data_name>.
        :param catid2cattxt_map: Category id -> category text map; loaded from
            "<dataset_name>_catid2cattxt_map" when not given.
        :returns: (txts_new, sample2cats_new, cats_new)
        :raises ValueError: If new_data_name is unknown and no stored files exist.
        """
        if save_dir is None:
            save_dir = join(self.dataset_dir,
                            self.dataset_name + new_data_name)
        prefix = self.dataset_name + new_data_name
        parts = ("_txts", "_sample2cats", "_cats")

        if all(isfile(join(save_dir, prefix + part + ".json"))
               for part in parts):
            logger.info("Loading files from: [{}]".format(save_dir))
            txts_new, sample2cats_new, cats_new = (
                File_Util.load_json(prefix + part, filepath=save_dir)
                for part in parts)
            return txts_new, sample2cats_new, cats_new

        logger.info(
            "No existing files found at [{}]. Generating {} files.".format(
                save_dir, new_data_name))

        # Variants produced by a single builder call; "_fixed5" needs two steps.
        single_step_builders = {
            "_onehot": "_create_oneclass_data",
            "_pointer": "_create_pointer_data",
            "_fewshot": "_create_fewshot_data",
            "_firstsent": "_create_firstsent_data",
        }
        # Names are compared by value (== / dict lookup), never with `is`: identity
        # only matches interned strings, so names built at runtime were rejected.
        if new_data_name != "_fixed5" and new_data_name not in single_step_builders:
            raise ValueError(
                "Unknown 'new_data_name': [{}]. \n Available options: "
                "['_fixed5','_onehot','_pointer','_fewshot','_firstsent']"
                .format(new_data_name))

        if catid2cattxt_map is None:
            catid2cattxt_map = File_Util.load_json(
                self.dataset_name + "_catid2cattxt_map",
                filepath=self.dataset_dir)

        txts, classes, _ = self.load_full_json(return_values=True)
        if new_data_name == "_fixed5":
            txts_one, classes_one, _ = self._create_oneclass_data(
                txts, classes, catid2cattxt_map=catid2cattxt_map)
            txts_new, sample2cats_new, cats_new = self._create_fixed_cat_data(
                txts_one, classes_one, catid2cattxt_map=catid2cattxt_map)
        else:
            builder = getattr(self, single_step_builders[new_data_name])
            txts_new, sample2cats_new, cats_new = builder(
                txts, classes, catid2cattxt_map=catid2cattxt_map)

        if save_files:  # Storing new data
            logger.info(
                "New dataset will be stored inside original dataset directory at: [{}]"
                .format(save_dir))
            makedirs(save_dir, exist_ok=True)
            File_Util.save_json(txts_new, prefix + "_txts", filepath=save_dir)
            File_Util.save_json(sample2cats_new,
                                prefix + "_sample2cats",
                                filepath=save_dir)
            File_Util.save_json(cats_new, prefix + "_cats", filepath=save_dir)

        return txts_new, sample2cats_new, cats_new

    def _create_fixed_cat_data(
            self,
            txts: OrderedDict,
            classes: OrderedDict,
            fixed5_cats: list = None,
            catid2cattxt_map=None) -> (OrderedDict, OrderedDict, OrderedDict):
        """Creates a dataset of samples whose FIRST label is one of the fixed categories.

        Default categories: [114, 3178, 3488, 1922, 3142].
        NOTE: This method is used only for sanity testing using fixed multi-class scenario.

        :param txts: Sample id -> text map.
        :param classes: Sample id -> list of category ids map.
        :param fixed5_cats: Category ids to keep; the default list above when None.
        :param catid2cattxt_map: Category id (as str) -> category text map; loaded
            from "<dataset_name>_catid2cattxt_map" when not given.
        :returns: (txts, classes, categories) restricted to the kept samples, where
            categories maps category text -> category id.
        """
        if fixed5_cats is None: fixed5_cats = [114, 3178, 3488, 1922, 3142]
        if catid2cattxt_map is None:
            catid2cattxt_map = File_Util.load_json(self.dataset_name +
                                                   "_catid2cattxt_map",
                                                   filepath=self.dataset_dir)
        txts_one_fixed5 = OrderedDict()
        classes_one_fixed5 = OrderedDict()
        categories_one_fixed5 = OrderedDict()
        for doc_id, lbls in classes.items():
            # Samples without any label cannot match; skip them instead of
            # failing on lbls[0].
            if not lbls or lbls[0] not in fixed5_cats:
                continue
            classes_one_fixed5[doc_id] = lbls
            txts_one_fixed5[doc_id] = txts[doc_id]
            for lbl in lbls:
                # Keys are category texts, so a membership test on the id never
                # matched; the assignment is idempotent and needs no guard.
                categories_one_fixed5[catid2cattxt_map[str(lbl)]] = lbl

        return txts_one_fixed5, classes_one_fixed5, categories_one_fixed5

    def _create_oneclass_data(
        self,
        txts: OrderedDict,
        classes: OrderedDict,
        catid2cattxt_map: OrderedDict = None
    ) -> (OrderedDict, OrderedDict, OrderedDict):
        """Keeps only the samples that carry exactly one label.

        NOTE: This method is used only for sanity testing using multi-class scenario.

        :param txts: Sample id -> text map.
        :param classes: Sample id -> list of category ids map.
        :param catid2cattxt_map: Category id (as str) -> category text map; loaded
            from "<dataset_name>_catid2cattxt_map" when not given.
        :returns: (txts, classes, categories) of the kept samples, where categories
            maps category text -> category id.
        """
        if catid2cattxt_map is None:
            map_name = self.dataset_name + "_catid2cattxt_map"
            catid2cattxt_map = File_Util.load_json(map_name,
                                                   filepath=self.dataset_dir)

        single_txts = OrderedDict()
        single_classes = OrderedDict()
        single_cats = OrderedDict()
        for sample_id, labels in classes.items():
            if len(labels) != 1:
                continue
            single_classes[sample_id] = labels
            single_txts[sample_id] = txts[sample_id]
            for label in labels:
                if label in single_cats:
                    continue
                single_cats[catid2cattxt_map[str(label)]] = label

        return single_txts, single_classes, single_cats

    def cat_token_counts(self, catid2cattxt_map=None):
        """ Counts the number of tokens in the text of each category.

        :param catid2cattxt_map: Category id -> category text map; loaded from
            "<dataset_name>_catid2cattxt_map" when not given.
        :return: Dict of category id -> number of tokens that
            self.clean.tokenizer_spacy yields for that category's text.
        """
        if catid2cattxt_map is None:
            catid2cattxt_map = File_Util.load_json(self.dataset_name +
                                                   "_catid2cattxt_map",
                                                   filepath=self.dataset_dir)
        # Tokenize the category TEXT (the value); tokenizing the id key, as was
        # done before, says nothing about the category name.
        return {
            cat_id: len(self.clean.tokenizer_spacy(cat_txt))
            for cat_id, cat_txt in catid2cattxt_map.items()
        }

    def json2csv(self, txts_all: dict = None, sample2cats_all: dict = None):
        """ Converts existing multiple json files and returns a single pandas dataframe.

        If "<dataset_name>_df.csv" already exists in self.dataset_dir it is read
        and returned (the arguments are then ignored); otherwise the dataframe is
        built from the given maps and written to that file.

        :param txts_all: Sample id -> text map; loaded with load_all() when missing.
        :param sample2cats_all: Sample id -> list of category ids map; loaded with
            load_all() when missing.
        :return: Dataframe without the rows whose 'txts' value is NaN. A freshly
            built frame has the columns idx, txts, cat and cat_txt.
        """
        csv_path = join(self.dataset_dir, self.dataset_name + "_df.csv")
        if exists(csv_path):
            # NOTE(review): the file is written with its index, so a re-read frame
            # probably carries an extra unnamed index column -- confirm callers cope.
            df = pd.read_csv(filepath_or_buffer=csv_path)
            df = df[~df['txts'].isna()]
        else:
            if txts_all is None or sample2cats_all is None:
                # load_all() returns exactly the raw maps needed here. get_data()
                # returns (df, map[, idf]) and calls json2csv itself, so it can
                # neither be unpacked into these names nor be used from here.
                txts_all, sample2cats_all, _ = self.load_all()
            catid2cattxt_map = File_Util.load_json(self.dataset_name +
                                                   "_catid2cattxt_map",
                                                   filepath=self.dataset_dir)
            idxs = list(sample2cats_all)
            df = pd.DataFrame.from_dict({
                "idx": idxs,
                "txts": [txts_all[idx] for idx in idxs],
                "cat": [sample2cats_all[idx] for idx in idxs],
                "cat_txt": [[catid2cattxt_map[str(lbl)]
                             for lbl in sample2cats_all[idx]]
                            for idx in idxs],
            })
            df = df[~df['txts'].isna()]
            df.to_csv(path_or_buf=csv_path)
        logger.info("Data shape = {} ".format(df.shape))
        return df

    def check_cat_present_txt(
            self,
            txts: OrderedDict,
            classes: OrderedDict,
            catid2cattxt_map: OrderedDict = None) -> OrderedDict:
        """Generates a dict of dicts containing the positions of all categories within each text.

        :param txts: Sample id -> text map.
        :param classes: Sample id -> true categories of the sample.
        :param catid2cattxt_map: Category id -> category text map; loaded from
            "<dataset_name>_catid2cattxt_map" when not given.
        :return: OrderedDict of sample id -> OrderedDict holding, per category id,
            the result of self.clean.find_label_occurrences for that category's
            text, plus the key "true" with the sample's entry from [classes].
        """
        if catid2cattxt_map is None:
            map_name = self.dataset_name + "_catid2cattxt_map"
            catid2cattxt_map = File_Util.load_json(map_name,
                                                   filepath=self.dataset_dir)

        occurrences = OrderedDict()
        for sample_id, body in txts.items():
            per_sample = OrderedDict()
            occurrences[sample_id] = per_sample
            for cat_id in catid2cattxt_map:
                cat_txt = catid2cattxt_map[str(cat_id)]
                per_sample[cat_id] = self.clean.find_label_occurrences(
                    body, cat_txt)
                # "true" is (re)assigned on every pass, so it is absent when the
                # category map is empty and sits right after the first category.
                per_sample["true"] = classes[sample_id]

        return occurrences

    def _create_pointer_data(
        self,
        txts: OrderedDict,
        classes: OrderedDict,
        catid2cattxt_map: OrderedDict = None
    ) -> (OrderedDict, OrderedDict, OrderedDict):
        """ Creates pointer network type dataset, i.e. labels are marked within document text.

        Only samples with at least one label whose text is found by
        self.clean.find_label_occurrences are kept.

        :param txts: Sample id -> text map.
        :param classes: Sample id -> list of category ids map.
        :param catid2cattxt_map: Category id (as str) -> category text map; loaded
            from "<dataset_name>_catid2cattxt_map" when not given.
        :returns: (txts_ptr, classes_ptr, categories_ptr) where classes_ptr maps
            sample id -> {category id: occurrences} for EVERY label found in the
            text, and categories_ptr maps category id -> category text.
        """
        if catid2cattxt_map is None:
            catid2cattxt_map = File_Util.load_json(self.dataset_name +
                                                   "_catid2cattxt_map",
                                                   filepath=self.dataset_dir)
        txts_ptr = OrderedDict()
        classes_ptr = OrderedDict()
        categories_ptr = OrderedDict()
        for doc_id, lbl_ids in classes.items():
            for lbl_id in lbl_ids:
                lbl_txt = catid2cattxt_map[str(lbl_id)]
                label_ptrs = self.clean.find_label_occurrences(
                    txts[doc_id], lbl_txt)
                if not label_ptrs:  ## Only if categories exists within the document.
                    continue
                # Accumulate per document: assigning a fresh dict here dropped
                # every label found earlier for the same document.
                classes_ptr.setdefault(doc_id, {})[lbl_id] = label_ptrs
                txts_ptr[doc_id] = txts[doc_id]
                categories_ptr.setdefault(lbl_id, lbl_txt)

        return txts_ptr, classes_ptr, categories_ptr

    def _create_fewshot_data(
        self,
        txts: OrderedDict,
        classes: OrderedDict,
        catid2cattxt_map: OrderedDict = None
    ) -> (OrderedDict, OrderedDict, OrderedDict):
        """Creates few-shot dataset, i.e. categories with <= 20 samples.

        :param txts: Sample id -> text map.
        :param classes: Sample id -> list of category ids map.
        :param catid2cattxt_map: Category id (as str) -> category text map; loaded
            from "<dataset_name>_catid2cattxt_map" when not given.
        :return: (txts, classes, categories) of the kept samples, where categories
            maps category text -> category id.
        """
        # NOTE(review): the split computed by find_cats_with_few_samples is never
        # used below; the samples actually kept are the ones with exactly one
        # label. This looks unfinished -- confirm the intended selection.
        if catid2cattxt_map is None:
            map_name = self.dataset_name + "_catid2cattxt_map"
            catid2cattxt_map = File_Util.load_json(map_name,
                                                   filepath=self.dataset_dir)

        _tail_cats, _tail_samples, _few_cats = self.find_cats_with_few_samples(
            sample2cats_map=classes, cat2samples_map=None)

        kept_txts, kept_classes, kept_cats = (OrderedDict(), OrderedDict(),
                                              OrderedDict())
        for sample_id, labels in classes.items():
            if len(labels) != 1:
                continue
            kept_classes[sample_id] = labels
            kept_txts[sample_id] = txts[sample_id]
            for label in labels:
                if label not in kept_cats:
                    kept_cats[catid2cattxt_map[str(label)]] = label

        return kept_txts, kept_classes, kept_cats

    def _create_firstsent_data(
        self,
        txts: OrderedDict,
        classes: OrderedDict,
        catid2cattxt_map: OrderedDict = None
    ) -> (OrderedDict, OrderedDict, OrderedDict):
        """Creates a version of wikipedia dataset with only first sentences and discarding the text.

        :param txts: Sample id -> text map.
        :param classes: Sample id -> list of category ids map.
        :param catid2cattxt_map: Category id (as str) -> category text map; loaded
            from "<dataset_name>_catid2cattxt_map" when not given.
        :return: (txts, classes, categories) of the kept samples, where categories
            maps category text -> category id.
        """
        # NOTE(review): no sentence splitting happens here; the code keeps the FULL
        # text of every sample with exactly one label -- confirm whether the
        # first-sentence extraction is still to be implemented.
        if catid2cattxt_map is None:
            map_name = self.dataset_name + "_catid2cattxt_map"
            catid2cattxt_map = File_Util.load_json(map_name,
                                                   filepath=self.dataset_dir)

        single_label_ids = [
            doc_id for doc_id, lbls in classes.items() if len(lbls) == 1
        ]
        classes_firstsent = OrderedDict(
            (doc_id, classes[doc_id]) for doc_id in single_label_ids)
        txts_firstsent = OrderedDict(
            (doc_id, txts[doc_id]) for doc_id in single_label_ids)

        categories_firstsent = OrderedDict()
        for doc_id in single_label_ids:
            for lbl in classes_firstsent[doc_id]:
                if lbl in categories_firstsent:
                    continue
                categories_firstsent[catid2cattxt_map[str(lbl)]] = lbl

        return txts_firstsent, classes_firstsent, categories_firstsent
Exemple #7
0
class Prepare_Data:
    """ Prepare data into proper format.

        Converts strings to vectors,
        Converts category ids to multi-hot vectors,
        etc.
    """

    def __init__(self,dataset_loader,dataset_name: str = config["data"]["dataset_name"],
                 dataset_dir: str = config["paths"]["dataset_dir"][plat][user]) -> None:
        """ Stores the loader and creates the helper objects used for preparation.

        :param dataset_loader: Object whose get_data(load_type=...) provides the raw data.
        :param dataset_name: Name of the dataset.
        :param dataset_dir: Directory which contains the dataset directory.
        """
        self.dataset_name = dataset_name
        self.dataset_dir = dataset_dir
        self.dataset_loader = dataset_loader

        self.txts2vec_map = None
        self.doc2vec_model = None
        self.cats_all = None
        self.txt_encoder_model = None
        self.txts_sel,self.sample2cats_sel,self.cats_sel = None,None,None

        self.graph = Neighborhood_Graph()
        self.txt_process = Text_Process()
        self.txt_encoder = Text_Encoder()
        self.mlb = MultiLabelBinarizer()

        self.oov_words_dict = OrderedDict()

    def load_graph_data(self,nodes):
        """ Loads graph data for XC datasets.

        :param nodes: Nodes handed to load_doc_neighborhood_graph.
        :return: Adjacency matrix of the document graph, requested in 'coo' format.
        """
        Docs_G = self.graph.load_doc_neighborhood_graph(nodes=nodes,get_stats=config["graph"]["stats"])
        return self.graph.get_adj_matrix(Docs_G,adj_format='coo')

    @staticmethod
    def create_batch_repeat(X: dict,Y: dict,keys):
        """
        Generates batch from keys.

        NOTE: [keys] is shuffled IN PLACE; get_input_batch relies on this to return
        the keys in the same order as the batch rows.

        :param X: Key -> input vector map.
        :param Y: Key -> target map.
        :param keys: Mutable sequence of keys to put in the batch.
        :return: (np.stack of the inputs, list of the targets), both in the
            shuffled key order.
        """
        shuffle(keys)
        batch_x = [X[k] for k in keys]
        batch_y = [Y[k] for k in keys]
        return np.stack(batch_x),batch_y

    def get_input_batch(self,txts:dict,sample2cats:dict,keys:list=None,return_cat_indices: bool = False,multi_label: bool = True) ->\
            [np.ndarray,np.ndarray,np.ndarray]:
        """Generates feature vectors of input documents.

        :param txts: Sample id -> feature vector map.
        :param sample2cats: Sample id -> categories map.
        :param keys: Sample ids of the batch; when None one batch covering all
            samples is requested from File_Util.get_batch_keys.
        :param return_cat_indices: Whether category indices are returned as 4th value.
        :param multi_label: Selects the form of the category indices: the inverse
            transformed labels wrapped in a list (multi-label) or the argmax
            along axis 1 (multi-class).
        :return: (inputs, multi-hot targets, keys) plus the category indices when
            [return_cat_indices] is set.
        """
        if keys is None:
            sample_ids = list(txts.keys())
            batch_size = int(len(sample_ids))
            _,keys = File_Util.get_batch_keys(sample_ids,batch_size=batch_size, remove_keys=False)
        txt_vecs_keys,sample2cats_keys = self.create_batch_repeat(txts,sample2cats,keys)
        sample2cats_keys_hot = self.mlb.transform(sample2cats_keys)

        if not return_cat_indices:
            return txt_vecs_keys,sample2cats_keys_hot,keys

        if multi_label:
            ## For Multi-Label, multi-label-margin loss
            cats_idx = [self.mlb.inverse_transform(sample2cats_keys_hot)]
        else:
            ## For Multi-Class, cross-entropy loss
            cats_idx = sample2cats_keys_hot.argmax(1)
        return txt_vecs_keys,sample2cats_keys_hot,keys,cats_idx

    def invert_cat2samples(self,classes_dict: dict = None):
        """Converts sample : cats to cats : samples

        :param classes_dict: Sample id -> categories map; self.sample2cats_sel when None.
        :returns: A dictionary of cats to sample mapping.
        """
        if classes_dict is None: classes_dict = self.sample2cats_sel
        cat2id = OrderedDict()
        for sample_id,cats in classes_dict.items():
            for cat in cats:
                cat2id.setdefault(cat,[]).append(sample_id)
        return cat2id

    def create_vec_maps(self,txts:dict=None,cats:dict=None):
        """ Maps text and categories to their vector representation.

        The two maps are loaded from their pickle files when both exist; otherwise
        they are generated and saved.

        :param txts: Sample id -> text map; loaded via load_raw_data when missing.
        :param cats: Categories map, inverted with File_Util.inverse_dict_elm before
            vectorization; loaded via load_raw_data when missing.
        :return: (txts2vec_map, cats2vec_map)
        """
        data_dir = join(self.dataset_dir,self.dataset_name)
        txts_map_name = self.dataset_name + "_txts2vec_map"
        cats_map_name = self.dataset_name + "_cats2vec_map"
        txts_map_path = join(data_dir,txts_map_name + ".pkl")
        # The log messages used to name "_cat2vec_map.pkl", a file that is never
        # read or written; they now show the real path.
        cats_map_path = join(data_dir,cats_map_name + ".pkl")

        logger.debug(txts_map_path)
        if isfile(txts_map_path) and isfile(cats_map_path):
            logger.info("Loading pre-processed mappings from: [{}] and [{}]"
                        .format(txts_map_path,cats_map_path))
            txts2vec_map = File_Util.load_pickle(txts_map_name,filepath=data_dir)
            cats2vec_map = File_Util.load_pickle(cats_map_name,filepath=data_dir)
        else:
            if txts is None or cats is None:
                txts,_,_,cats = self.load_raw_data(load_type='all',return_values=True)
            ## Generate txts2vec_map and cats2vec_map
            logger.info("Generating pre-processed mappings.")
            txts2vec_map = self.txt_process.gen_sample2vec_map(txts=txts)
            catid2cattxt = File_Util.inverse_dict_elm(cats)
            cats2vec_map = self.txt_process.gen_cats2vec_map(cats=catid2cattxt)

            logger.info("Saving pre-processed mappings to: [{}] and [{}]"
                        .format(txts_map_path,cats_map_path))
            File_Util.save_pickle(txts2vec_map,txts_map_name,filepath=data_dir)
            File_Util.save_pickle(cats2vec_map,cats_map_name,filepath=data_dir)
        return txts2vec_map,cats2vec_map

    def load_raw_data(self,load_type: str = 'all', return_values=True):
        """ Loads the json data provided by param "load_type".

        Also fits self.mlb on all category ids and fills self.remain_sample_ids,
        self.cat2sample_map and self.remain_cat_ids.

        :param return_values: Whether the loaded data is returned.
        :param load_type: Which data to load: Options: ['train', 'val', 'test']
        :return: (txts_sel, sample2cats_sel, cats_sel, cats_all) when
            [return_values] is set, otherwise None.
        """
        # NOTE(review): expects dataset_loader.get_data to return 4 values --
        # confirm against the loader that is actually passed in.
        self.txts_sel,self.sample2cats_sel,self.cats_sel,self.cats_all = self.dataset_loader.get_data(load_type=load_type)
        self.remain_sample_ids = list(self.txts_sel.keys())
        self.cat2sample_map = self.invert_cat2samples(self.sample2cats_sel)
        self.remain_cat_ids = list(self.cats_sel.keys())

        ## MultiLabelBinarizer only take list of list as input. Need to convert "list of int" to "list of list".
        cat_ids = [[cat_id] for cat_id in self.cats_all.values()]
        self.mlb.fit(cat_ids)

        if return_values:
            return self.txts_sel,self.sample2cats_sel,self.cats_sel,self.cats_all

    def normalize_inputs(self):
        """
        Normalizes our data, to have a mean of 0 and std of 1.

        Uses the mean and std of self.x_train for all three splits. self.x_train,
        self.x_val and self.x_test are not set anywhere in this class and must be
        assigned before calling.
        """
        self.mean = np.mean(self.x_train)
        self.std = np.std(self.x_train)
        self.max = np.max(self.x_train)
        self.min = np.min(self.x_train)
        logger.debug(
            ("train_shape",self.x_train.shape,"test_shape",self.x_test.shape,"val_shape",self.x_val.shape))
        self.x_train = (self.x_train - self.mean) / self.std
        self.x_val = (self.x_val - self.mean) / self.std
        self.x_test = (self.x_test - self.mean) / self.std

    def create_multihot(self,batch_classes_dict):
        """
        Creates multi-hot vectors for a batch of data.

        :param batch_classes_dict: Sample id -> categories map of the batch.
        :return: Result of fit_transform on the values of [batch_classes_dict].
        """
        # NOTE(review): fit_transform RE-FITS self.mlb on this batch, replacing the
        # fit done in load_raw_data -- confirm that transform() is not what is wanted.
        classes_multihot = self.mlb.fit_transform(batch_classes_dict.values())
        return classes_multihot
 def __init__(self) -> None:
     # NOTE(review): this fragment sits outside any class, is indented by a single
     # space, and repeats the Vector_Visualizations.__init__ that follows the class
     # header below -- presumably an extraction artifact; confirm before relying on it.
     super(Vector_Visualizations, self).__init__()
     self.text_process = Text_Process()
     self.text_encoder = Text_Encoder()
class Vector_Visualizations:
    """ Visualize vectors in 2D.

    Typical flow: `create_vectors(cats)` builds one vector per category, `show_vectors()`
    projects them to 2D and plots them. `use_tsne` / `use_pca` can also be called directly
    with any vectors and matching labels.
    """
    def __init__(self) -> None:
        super(Vector_Visualizations, self).__init__()
        self.text_process = Text_Process()
        self.text_encoder = Text_Encoder()
        ## Filled by create_vectors(); declared here so that a premature show_vectors()
        ## call can fail with a clear message.
        self.cats = None
        self.cats_processed = None
        self.cats_processed_vecs = None

    def create_vectors(self, cats: dict):
        """ Creates vector from cats.

        The raw `cats` dict is kept on the instance because `show_vectors` uses its values
        as plot labels.

        :param cats: dict of categories, passed to `Text_Process.process_cats`.
        :return: First element returned by `Text_Process.gen_lbl2vec` for the processed cats.
        """
        self.cats = cats
        self.cats_processed = self.text_process.process_cats(cats)
        model = self.text_encoder.load_word2vec()
        self.cats_processed_vecs, _ = self.text_process.gen_lbl2vec(
            self.cats_processed, model)
        return self.cats_processed_vecs

    def show_vectors(self, cats_processed_vecs=None):
        """ Projects category vectors to 2D with t-SNE and plots them.

        :param cats_processed_vecs: Vectors to plot; defaults to those stored by `create_vectors`.
        :return: 2D coordinates returned by `use_tsne`.
        :raises ValueError: If no vectors or no categories are available yet.
        """
        if cats_processed_vecs is None:
            cats_processed_vecs = self.cats_processed_vecs
        if cats_processed_vecs is None or self.cats is None:
            raise ValueError(
                "No category vectors to show; call create_vectors() first.")
        ## NOTE(review): labels come from the raw cats while the vectors come from the
        ## processed cats - presumably both have the same length and order; confirm.
        cats_processed_2d = self.use_tsne(cats_processed_vecs,
                                          list(self.cats.values()))
        return cats_processed_2d

    def view_closestwords_tsnescatterplot(
            self,
            model,
            word,
            word_dim=config["prep_vecs"]["input_size"],
            sim_words=config["prep_vecs"]["sim_words"]):
        """ Method to plot the top sim_words in 2D using TSNE.

        :param model: Word-vector model supporting `similar_by_word(word, topn=...)` and
            `model[word]` lookups.
        :param word: Query word; plotted together with its closest words.
        :param word_dim: Dimension of a single word vector.
        :param sim_words: Number of closest words to fetch.
        """
        word_labels = [word]
        vectors = [model[word]]

        ## collect the vector of each of the closest words
        for wrd_score in model.similar_by_word(word, topn=sim_words):
            word_labels.append(wrd_score[0])
            vectors.append(model[wrd_score[0]])

        ## One concatenation instead of np.append per word (which copies the array each
        ## time). The empty float32 seed keeps the dtype promotion and the word_dim shape
        ## check the same as appending row by row.
        arr = np.append(np.empty((0, word_dim), dtype='f'),
                        np.array(vectors),
                        axis=0)

        self.use_tsne(arr, word_labels)

    def use_tsne(self, vecs: np.ndarray, word_labels: list) -> np.ndarray:
        """ Use tsne to project the vectors to 2D and plot them.

        Side effect: switches numpy printing to `suppress=True` globally.

        :param vecs: Vectors to project, one row per label.
        :param word_labels: Label to annotate each projected point with.
        :return: The 2D coordinates produced by t-SNE.
        """
        ## find tsne coords for 2 dimensions
        tsne = TSNE(n_components=2, random_state=0)
        np.set_printoptions(suppress=True)
        Y = tsne.fit_transform(vecs)
        self.plot_vectors(Y, word_labels)

        return Y

    @staticmethod
    def plot_vectors(vecs,
                     labels,
                     test_offset: float = 0.005,
                     plot_title=config["graph"]["plot_name"]):
        """ Scatter-plots 2D vectors, annotates each point and saves the figure.

        :param test_offset: Margin added around the extreme points on both axes.
        :param vecs: Array-like with at least two columns; column 0 is x, column 1 is y.
        :param labels: One label per row of `vecs`.
        :param plot_title: Used both as the plot title and as the path passed to `savefig`.
        """
        vecs = np.asarray(vecs)
        x_coords = vecs[:, 0]
        y_coords = vecs[:, 1]

        ## display scatter plot
        plt.scatter(x_coords, y_coords)

        for label, x, y in zip(labels, x_coords, y_coords):
            plt.annotate(label,
                         xy=(x, y),
                         xytext=(0, 0),
                         textcoords='offset points')
        ## The lower limits subtract the offset; adding it would push the axis start past
        ## the smallest point and clip it.
        plt.xlim(x_coords.min() - test_offset, x_coords.max() + test_offset)
        plt.ylim(y_coords.min() - test_offset, y_coords.max() + test_offset)
        plt.xticks(rotation=35)
        plt.title(plot_title)
        ## Save before show(): once show() returns the figure may already be released,
        ## which leaves savefig() writing an empty canvas.
        plt.savefig(plot_title)
        plt.show()

    def use_pca(self, vecs, word_labels):
        """ Use PCA to project the vectors to 2D and plot them.

        Side effect: switches numpy printing to `suppress=True` globally.

        :param vecs: Vectors to project, one row per label.
        :param word_labels: Label to annotate each projected point with.
        :return: The 2D coordinates produced by PCA.
        """
        from sklearn.decomposition import PCA

        ## find the first 2 principal components
        pca = PCA(n_components=2, random_state=0)
        np.set_printoptions(suppress=True)
        Y = pca.fit_transform(vecs)
        self.plot_vectors(Y, word_labels)

        return Y

    @staticmethod
    def get_cosine_dist(vecs):
        """ Computes the pairwise cosine similarity of the rows of `vecs`.

        Despite the method name, the returned values are similarities, not distances.

        :param vecs: Vectors, one per row.
        :return: Result of `cosine_similarity(vecs, vecs)`.
        """
        from sklearn.metrics.pairwise import cosine_similarity

        pair_cosine = cosine_similarity(vecs, vecs)
        return pair_cosine
# Example #10
# 0
class Neighborhood_Graph:
    """ Class to generate neighborhood graph of categories. """
    def __init__(self,
                 dataset_name: str = config["data"]["dataset_name"],
                 graph_format: str = "graphml",
                 top_k: int = 10,
                 graph_dir: str = config["paths"]["dataset_dir"][plat][user]):
        """ Stores the graph settings and creates the text-processing helper.

        :param dataset_name: Name of the dataset; also used as sub-directory and file prefix.
        :param graph_dir: Base directory under which the dataset directory is looked up.
        :param graph_format: Stored as given on the instance.
        :param top_k: Stored as given on the instance.
        """
        super(Neighborhood_Graph, self).__init__()
        self.dataset_name, self.graph_dir = dataset_name, graph_dir
        self.graph_format, self.top_k = graph_format, top_k
        self.txt_process = Text_Process()

    def get_adj_matrix(self,
                       G: nx.classes.graph.Graph = None,
                       default_load: str = 'val',
                       adj_format: str = "coo",
                       nodes: list = None):
        """ Returns the adjacency matrix in [adj_format].

        When no graph is given, the document neighborhood graph is loaded (or created) for
        `nodes` through `load_doc_neighborhood_graph`, which needs the node list.

        :param default_load: Not used by this method; kept for interface compatibility.
        :param G: Networkx Graph
        :param adj_format: scipy sparse format name, e.g. {‘bsr’, ‘csr’, ‘csc’, ‘coo’, ‘lil’, ‘dia’, ‘dok’}.
        :param nodes: List of node ids, required only when `G` is None.
        :return: Sparse adjacency matrix of the graph.
        :raises ValueError: If neither `G` nor `nodes` is provided.
        """
        if G is None:
            if nodes is None:
                raise ValueError(
                    "Either a graph [G] or a list of [nodes] to load the graph for is required.")
            G = self.load_doc_neighborhood_graph(nodes, get_stats=False)
        return nx.to_scipy_sparse_matrix(G, format=adj_format)

    def invert_classes_dict(self, input_dict=None):
        """ Generates a new dict with categories to document ids map.

        :param input_dict: dict of document id -> iterable of categories. Defaults to
            `self.sample2cats`.
        :return: OrderedDict of category -> list of document ids, with categories in order of
            first appearance and documents in input order.
        """
        if input_dict is None: input_dict = self.sample2cats
        inverted_dict = OrderedDict()
        for doc, cats in input_dict.items():
            for cat in cats:
                ## setdefault creates the list the first time a category is seen.
                inverted_dict.setdefault(cat, []).append(doc)

        return inverted_dict

    def prepare_label_graph(self, cat2docs_map=None):
        """ Generates a dict of categories mapped to document and then creates the label neighborhood graph.

        Two categories get connected when they share enough documents (see
        `create_neighborhood_graph`). Node ids are finally replaced by category texts using
        `self.cat_id2text_map`.

        :param cat2docs_map: dict of category -> documents. Defaults to the inverse of
            `self.sample2cats`.
        :return: Networkx graph over categories.
        """
        if cat2docs_map is None:
            cat2docs_map = self.invert_classes_dict(self.sample2cats)

        ## Must be passed by keyword: the first positional parameter of
        ## create_neighborhood_graph is [nodes], and with the map missing it would silently
        ## fall back to self.sample2cats and build a document graph instead.
        G_cats = self.create_neighborhood_graph(sample2cats_map=cat2docs_map)

        nx.relabel_nodes(G_cats, self.cat_id2text_map, copy=False)
        return G_cats

    def create_neighborhood_graph(self,
                                  nodes: list = None,
                                  sample2cats_map: dict = None,
                                  min_common=config["graph"]["min_common"]):
        """ Generates the neighborhood graph (of type category or document) as key
        and common items as values.

        Default: Generates the document graph; for label graph call 'prepare_label_graph()'.

        Every ordered pair of distinct samples is visited, so each edge is written twice and
        the attributes of the later visit (edge_id "<later>-<earlier>") are the ones kept.

        :param nodes: List of node ids to consider, others are ignored.
        :param min_common: Minimum number of common categories between two documents.
        :param sample2cats_map: dict of sample id -> categories. Defaults to `self.sample2cats`.
        :return: Networkx graph with an edge between samples sharing at least [min_common] items.
        """
        if sample2cats_map is None: sample2cats_map = self.sample2cats
        if nodes is None: nodes = list(sample2cats_map.keys())
        G = nx.Graph()
        G.add_nodes_from(nodes)
        logger.debug(nx.info(G))

        ## Hoisted out of the pair loop: constant-time membership instead of scanning the
        ## [nodes] list for every pair, and one set per sample instead of one per pair.
        allowed = set(nodes)
        selected = [(doc, set(cats)) for doc, cats in sample2cats_map.items()
                    if doc in allowed]  ## Consider samples only within [nodes] list.

        for doc1, cats1 in selected:
            for doc2, cats2 in selected:
                if doc1 == doc2:
                    continue
                cats_common = cats1.intersection(cats2)
                if len(cats_common) >= min_common:
                    G.add_edge(doc1,
                               doc2,
                               edge_id=str(doc1) + '-' + str(doc2),
                               common=repr(cats_common))

        return G

    def create_word_cooccurance_graph_txts(
            self,
            nodes: list = None,
            txts: dict = None,
            min_common_ratio=config["graph"]["min_common_ratio"]):
        """ Generate the graph among documents using common words as edge.

        Two texts are connected when
        (number of distinct common tokens) / (token count of text 1 + token count of text 2)
        reaches [min_common_ratio]. A text is also compared with itself, so self-loops can
        appear. NOTE(review): unlike `create_neighborhood_graph`, [nodes] is only used to
        seed the graph nodes and does not filter the texts - confirm whether that is intended.

        :param nodes: List of node ids added to the graph up front. Defaults to the keys of [txts].
        :param min_common_ratio: Minimum ratio of common tokens between two documents.
        :param txts: dict of text id -> text.
        :return: Networkx graph; each edge stores the threshold as `min_common_tokens`.
        :raises ValueError: If [txts] is not provided.
        """
        if txts is None:
            raise ValueError("[txts] is required to build the word co-occurrence graph.")
        if nodes is None: nodes = list(txts.keys())
        G = nx.Graph()
        G.add_nodes_from(nodes)

        ## Tokenize each text once; doing it inside the pair loop repeats the work for
        ## every pair of texts.
        tokenized = []
        for txt_id, txt in txts.items():
            tokens = self.txt_process.tokenizer_spacy(txt)
            tokenized.append((txt_id, len(tokens), set(tokens)))

        for txt_id1, count1, vocab1 in tokenized:
            for txt_id2, count2, vocab2 in tokenized:
                total_tokens = count1 + count2
                if total_tokens == 0:
                    ## Two texts without tokens share nothing; avoids a division by zero.
                    continue
                cats_common_ratio = len(vocab1.intersection(vocab2)) / total_tokens
                if cats_common_ratio >= min_common_ratio:  ## Normalize min_common_ratio by txt1 and txt2's token count.
                    G.add_edge(txt_id1,
                               txt_id2,
                               edge_id=str(txt_id1) + '-' + str(txt_id2),
                               min_common_tokens=repr(min_common_ratio))

        logger.debug(nx.info(G))
        return G

    def load_doc_neighborhood_graph(
            self,
            nodes,
            graph_path=None,
            get_stats: bool = config["graph"]["stats"]):
        """ Reads the graphml file when it exists, otherwise builds and saves the graph.

        When the file is missing, the sample2cats, cats and catid2cattxt_map json files of the
        dataset are loaded onto the instance first, the graph is created for [nodes] and
        written to [graph_path].

        :param nodes: List of node ids to consider.
        :param get_stats: If true, graph statistics are computed, saved as json and returned too.
        :param graph_path: Full path to the graphml file. Defaults to
            "<graph_dir>/<dataset_name>/<dataset_name>_G_<len(nodes)>.graphml".
        :return: Networkx graph, or (graph, stats) when [get_stats] is true.
        """
        if graph_path is None:
            graph_filename = self.dataset_name + "_G_" + str(len(nodes)) + ".graphml"
            graph_path = join(self.graph_dir, self.dataset_name, graph_filename)

        if not exists(graph_path):
            ## Nothing cached yet: load the label data, build the graph and cache it.
            self.sample2cats = File_Util.load_json(
                join(self.graph_dir, self.dataset_name,
                     self.dataset_name + "_sample2cats"))
            self.categories = File_Util.load_json(
                join(self.graph_dir, self.dataset_name,
                     self.dataset_name + "_cats"))
            self.cat_id2text_map = File_Util.load_json(
                join(self.graph_dir, self.dataset_name,
                     self.dataset_name + "_catid2cattxt_map"))
            Docs_G = self.create_neighborhood_graph(nodes=nodes)
            logger.debug(nx.info(Docs_G))
            logger.info(
                "Saving neighborhood graph at [{0}]".format(graph_path))
            nx.write_graphml(Docs_G, graph_path)
        else:
            logger.info(
                "Loading neighborhood graph from [{0}]".format(graph_path))
            Docs_G = nx.read_graphml(graph_path)

        if not get_stats:
            return Docs_G

        Docs_G_stats = self.graph_stats(Docs_G)
        stats_dir = join(self.graph_dir, self.dataset_name)
        File_Util.save_json(Docs_G_stats,
                            filename=self.dataset_name + "_G_stats",
                            overwrite=True,
                            filepath=stats_dir)
        return Docs_G, Docs_G_stats

    @staticmethod
    def graph_stats(G):
        """ Collects statistics of a graph into an ordered dict.

        Always stored: info, degree_sequence (descending), dmax, dmin, node_count, edge_count
        and density. Radius, diameter, eccentricity, center and periphery are added only for a
        connected graph; otherwise the component sizes are just logged.

        :param G: Graph in Netwokx format.
        :return: dict
        """
        stats = OrderedDict()
        stats["info"] = nx.info(G)
        logger.debug("info: [{0}]".format(stats["info"]))

        degrees = sorted((deg for _, deg in G.degree()), reverse=True)
        stats["degree_sequence"] = degrees
        stats["dmax"] = max(degrees)
        logger.debug("dmax: [{0}]".format(stats["dmax"]))
        stats["dmin"] = min(degrees)
        logger.debug("dmin: [{0}]".format(stats["dmin"]))

        stats["node_count"] = nx.number_of_nodes(G)
        stats["edge_count"] = nx.number_of_edges(G)
        stats["density"] = nx.density(G)
        logger.debug("density: [{0}]".format(stats["density"]))

        if not nx.is_connected(G):
            logger.warning("The graph in not connected.")
            ## Sizes of the connected components, largest first.
            component_sizes = sorted(map(len, nx.connected_components(G)),
                                     reverse=True)
            logger.debug(component_sizes)
            return stats

        ## Distance based measures are computed for a connected graph only.
        stats["radius"] = nx.radius(G)
        logger.debug("radius: [{0}]".format(stats["radius"]))
        stats["diameter"] = nx.diameter(G)
        logger.debug("diameter: [{0}]".format(stats["diameter"]))
        stats["eccentricity"] = nx.eccentricity(G)
        logger.debug("eccentricity: [{0}]".format(stats["eccentricity"]))
        stats["center"] = nx.center(G)
        logger.debug("center: [{0}]".format(stats["center"]))
        stats["periphery"] = nx.periphery(G)
        logger.debug("periphery: [{0}]".format(stats["periphery"]))

        return stats

    def find_single_labels(self):
        """ Finds the number of samples with only single label.

        `self.sample2cats` may be a dict of sample id -> categories (its values are inspected)
        or a plain sequence of category collections.

        :return: Number of samples that have exactly one category.
        """
        sample2cats = self.sample2cats
        ## Iterating a dict directly yields its keys, which would measure the length of the
        ## sample ids instead of the number of categories.
        if isinstance(sample2cats, dict):
            cat_lists = sample2cats.values()
        else:
            cat_lists = sample2cats
        single_count = sum(1 for cats in cat_lists if len(cats) == 1)
        if single_count:
            logger.debug(
                "[%d] samples have only a single category. These categories will not occur "
                "in the co-occurrence graph.", single_count)
        return single_count

    @staticmethod
    def plot_occurance(E,
                       plot_name=config["graph"]["plot_name"],
                       clear=True,
                       log=False):
        """ Plots [E] as a line, saves the figure to [plot_name] and shows it.

        :param E: Values to plot; handed to `plt.plot` as is.
        :param plot_name: Path passed to `plt.savefig`.
        :param clear: If true, the current axes are cleared after showing.
        :param log: If true, the y axis uses a log scale.
        """
        from matplotlib import pyplot as plt

        plt.plot(E)
        if log:
            plt.yscale('log')

        ## axis labels and title
        plt.xlabel("Documents")
        plt.ylabel("Category co-occurance")
        plt.title("Documents degree distribution (sorted)")

        plt.savefig(plot_name)
        plt.show()
        if clear:
            plt.cla()

    def get_subgraph(self,
                     V,
                     E,
                     dataset_name,
                     level=config["graph"]["level"],
                     root_node=config["graph"]["root_node"],
                     subgraph_count=config["graph"]["subgraph_count"],
                     ignore_deg=config["graph"]["ignore_deg"]):
        """ Generates a subgraph of [level] hops starting from [root_node] node.

        For every subgraph an info text file and a graphml file are written under 'samples/'.
        Node ids are replaced by their texts from `self.cat_id2text_map` before saving.

        :param V: list of all categories (nodes). Shuffled in place when no root is given.
        :param E: dict of edge tuple(node_1,node_2) -> weight, eg. {(1, 4): 1, (2, 7): 3}.
            Keys are looked up as sorted tuples.
        :param dataset_name: Prefix of the generated file names.
        :param level: Number of hops to expand from the root.
        :param root_node: Start node; a random node of [V] is picked when None.
        :param subgraph_count: Number of subgraphs to generate when the root is random.
        :param ignore_deg: Nodes with more edges than this are not expanded; None disables it.
        :return: List of generated subgraphs.
        """
        # build a unweighted graph of all edges
        g = nx.Graph()
        g.add_edges_from(E.keys())

        # Unique marker separating BFS levels in the queue. A dedicated object cannot be
        # confused with a real node id (an integer marker would swallow a node with id 0).
        level_marker = object()

        # Below section will try to build a smaller subgraph from the actual graph for visualization
        subgraph_lists = []
        for _ in range(subgraph_count):
            if root_node is None:
                # select a random vertex to be the root
                np.random.shuffle(V)
                root = V[0]
            else:
                root = root_node

            root_text = self.cat_id2text_map[root]
            # two files to write the graph and label information
            # Remove characters like \, /, <, >, :, *, |, ", ? from file names,
            # windows can not have file name with these characters
            file_tag = str(int(root)) + '-' + File_Util.remove_special_chars(root_text)
            label_info_filepath = 'samples/' + str(
                dataset_name) + '_Info[{}].txt'.format(file_tag)
            label_graph_filepath = 'samples/' + str(
                dataset_name) + '_G[{}].graphml'.format(file_tag)

            logger.debug('Label:[' + root_text + ']')

            # build the subgraph using bfs
            bfs_q = Queue()
            bfs_q.put(root)
            bfs_q.put(level_marker)
            visited = set()
            ignored = []

            sub_g = nx.Graph()
            lvl = 0
            while not bfs_q.empty() and lvl <= level:
                node = bfs_q.get()
                if node is level_marker:
                    # one complete level was consumed; mark the end of the next one
                    lvl += 1
                    bfs_q.put(level_marker)
                    continue
                if node in visited:
                    continue
                visited.add(node)
                edges = list(g.edges(node))
                if ignore_deg is not None and len(edges) > ignore_deg:
                    ignored.append("Ignoring: deg [" +
                                   self.cat_id2text_map[node] + "] = [" +
                                   str(len(edges)) + "]\n")
                    continue
                for uv_tuple in edges:
                    edge = tuple(sorted(uv_tuple))
                    sub_g.add_edge(edge[0], edge[1], weight=E[edge])
                    bfs_q.put(uv_tuple[1])

            # relabel the nodes to reflect textual label
            nx.relabel_nodes(sub_g, self.cat_id2text_map, copy=False)
            logger.debug('sub_g: [{0}]'.format(sub_g))

            # The context manager closes the info file even if a write or a statistic fails.
            with open(label_info_filepath, 'w') as label_info_file:
                label_info_file.write('Label:[' + root_text + ']' + "\n")
                label_info_file.write(str('\n'))
                # Writing some statistics about the subgraph
                label_info_file.write(str(nx.info(sub_g)) + '\n')
                label_info_file.write('density: ' + str(nx.density(sub_g)) + '\n')
                label_info_file.write(
                    'list of the frequency of each degree value [degree_histogram]: '
                    + str(nx.degree_histogram(sub_g)) + '\n')
                for ignored_line in ignored:
                    label_info_file.write(str(ignored_line) + '\n')
            nx.write_graphml(sub_g, label_graph_filepath)

            subgraph_lists.append(sub_g)

            logger.info(
                'Sub graph generated at: [{0}]'.format(label_graph_filepath))

            # Same test as the root selection above, so a falsy but valid root id (e.g. 0)
            # also stops after a single graph instead of rewriting the same file.
            if root_node is not None:
                logger.info(
                    "Root node provided, will generate only one graph file.")
                break

        return subgraph_lists