def _create_fixed_cat_data(
            self,
            txts: OrderedDict,
            classes: OrderedDict,
            fixed5_cats: list = None,
            catid2cattxt_map=None) -> (OrderedDict, OrderedDict, OrderedDict):
        """Creates a dataset of samples which belongs to any of the below 5 sample2cats only.

        Selected sample2cats: [114, 3178, 3488, 1922, 517], these sample2cats has max number of samples associated with them.
        NOTE: This method is used only for sanity testing using fixed multi-class scenario.
        """
        if fixed5_cats is None: fixed5_cats = [114, 3178, 3488, 1922, 3142]
        if catid2cattxt_map is None:
            catid2cattxt_map = File_Util.load_json(self.dataset_name +
                                                   "_catid2cattxt_map",
                                                   filepath=self.dataset_dir)
        txts_one_fixed5 = OrderedDict()
        classes_one_fixed5 = OrderedDict()
        categories_one_fixed5 = OrderedDict()
        for doc_id, lbls in classes.items():
            if lbls[0] in fixed5_cats:
                classes_one_fixed5[doc_id] = lbls
                txts_one_fixed5[doc_id] = txts[doc_id]
                for lbl in classes_one_fixed5[doc_id]:
                    if catid2cattxt_map[str(lbl)] not in categories_one_fixed5:
                        categories_one_fixed5[catid2cattxt_map[str(lbl)]] = lbl

        return txts_one_fixed5, classes_one_fixed5, categories_one_fixed5
    def gen_sample2vec_map(self,txts: dict,vectorizer_model=None):
        """
        Generates a dict of sample text to it's vector map.

        :param vectorizer_model: Doc2Vec model object.
        :param txts:
        :return:
        """
        if self.txts2vec_map is not None:
            return self.txts2vec_map
        else:
            if txts is None: txts = File_Util.load_json(filename=self.dataset_name + "_txts",
                                                        filepath=join(self.dataset_dir,self.dataset_name))
            if vectorizer_model is None:  ## If model is not supplied, load model.
                if self.vectorizer_model is None:
                    self.vectorizer_model = self.text_encoder.load_word2vec()
                vectorizer_model = self.vectorizer_model
            txts2vec_dict = OrderedDict()
            for sample_id,txt in txts.items():
                tokens = self.tokenizer_spacy(txt)
                tokens_vec = self.get_vecs_from_tokens(tokens,vectorizer_model)
                txts2vec_dict[sample_id] = tokens_vec  ## Generate vector for a new sample.
                # txts2vec_dict[sample_id] = vectorizer_model.infer_vector(self.tokenizer_spacy(txt))  ## Generate vector for a new sample using Doc2Vec model only.

        self.txts2vec_map = txts2vec_dict
        return self.txts2vec_map
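
## A minimal, self-contained sketch of the sample->vector step above. It is an
## illustration only: the original relies on self.tokenizer_spacy and
## self.get_vecs_from_tokens (not shown here), which may stack rather than mean-pool
## the token vectors; whitespace tokenization and mean-pooling are assumed stand-ins.
import numpy as np

def sample2vec_sketch(txts: dict, token2vec: dict, dim: int = 3) -> dict:
    """Map each sample id to the mean of its known token vectors (zeros if none match)."""
    out = {}
    for sample_id, txt in txts.items():
        vecs = [token2vec[tok] for tok in txt.split() if tok in token2vec]
        out[sample_id] = np.mean(vecs, axis=0) if vecs else np.zeros(dim)
    return out

print(sample2vec_sketch({"doc_1": "cats and dogs"},
                        {"cats": np.ones(3), "dogs": np.zeros(3)}))  ## -> [0.5 0.5 0.5]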
    def check_cat_present_txt(
            self,
            txts: OrderedDict,
            classes: OrderedDict,
            catid2cattxt_map: OrderedDict = None) -> OrderedDict:
        """Generates a dict of dicts containing the positions of all categories within each text.

        :param classes:
        :param txts:
        :param catid2cattxt_map:
        :return:
        """
        if catid2cattxt_map is None:
            catid2cattxt_map = File_Util.load_json(self.dataset_name +
                                                   "_catid2cattxt_map",
                                                   filepath=self.dataset_dir)
        label_ptrs = OrderedDict()
        for doc_id, txt in txts.items():
            label_ptrs[doc_id] = OrderedDict()
            for lbl_id in catid2cattxt_map:
                label_ptrs[doc_id][lbl_id] = self.clean.find_label_occurrences(
                    txt, catid2cattxt_map[str(lbl_id)])
            label_ptrs[doc_id]["true"] = classes[doc_id]

        return label_ptrs
    def _create_pointer_data(
        self,
        txts: OrderedDict,
        classes: OrderedDict,
        catid2cattxt_map: OrderedDict = None
    ) -> (OrderedDict, OrderedDict, OrderedDict):
        """ Creates pointer network type dataset, i.e. labels are marked within document text. """
        if catid2cattxt_map is None:
            catid2cattxt_map = File_Util.load_json(self.dataset_name +
                                                   "_catid2cattxt_map",
                                                   filepath=self.dataset_dir)
        txts_ptr = OrderedDict()
        classes_ptr = OrderedDict()
        categories_ptr = OrderedDict()
        for doc_id, lbl_ids in classes.items():
            for lbl_id in lbl_ids:
                label_ptrs = self.clean.find_label_occurrences(
                    txts[doc_id], catid2cattxt_map[str(lbl_id)])
                if label_ptrs:  ## Only if the category text occurs within the document.
                    ## Accumulate pointers per label so multiple labels of one document do not overwrite each other.
                    classes_ptr.setdefault(doc_id, {})[lbl_id] = label_ptrs
                    txts_ptr[doc_id] = txts[doc_id]

                    if lbl_id not in categories_ptr:
                        categories_ptr[lbl_id] = catid2cattxt_map[str(lbl_id)]

        return txts_ptr, classes_ptr, categories_ptr
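
## Illustration of the pointer-style output built above. The exact pointer format comes
## from self.clean.find_label_occurrences, which is not shown here; character offsets of
## each label occurrence are assumed purely for demonstration.
def find_label_occurrences_sketch(txt: str, label: str) -> list:
    """Return the start offsets of every occurrence of `label` in `txt` (assumed format)."""
    ptrs, start = [], txt.find(label)
    while start != -1:
        ptrs.append(start)
        start = txt.find(label, start + 1)
    return ptrs

demo_txt = "Machine learning is a subfield of AI. Machine learning uses data."
demo_classes_ptr = {"doc_42": {17: find_label_occurrences_sketch(demo_txt, "Machine learning")}}
demo_categories_ptr = {17: "Machine learning"}
print(demo_classes_ptr)  ## label id -> start offsets of each occurrence within the document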
    def _create_fewshot_data(
        self,
        txts: OrderedDict,
        classes: OrderedDict,
        catid2cattxt_map: OrderedDict = None
    ) -> (OrderedDict, OrderedDict, OrderedDict):
        """Creates few-shot dataset, i.e. categories with <= 20 samples.

        :param classes:
        :param txts:
        :param catid2cattxt_map:
        :return:
        """
        if catid2cattxt_map is None:
            catid2cattxt_map = File_Util.load_json(self.dataset_name +
                                                   "_catid2cattxt_map",
                                                   filepath=self.dataset_dir)

        tail_cats, samples_with_tail_cats, cat2samples_filtered = self.find_cats_with_few_samples(
            sample2cats_map=classes, cat2samples_map=None)
        txts_few = OrderedDict()
        classes_few = OrderedDict()
        categories_few = OrderedDict()
        for doc_id, lbls in classes.items():
            ## Keep only samples whose labels all belong to the tail (few-sample) categories found above.
            if all(lbl in tail_cats for lbl in lbls):
                classes_few[doc_id] = lbls
                txts_few[doc_id] = txts[doc_id]
                for lbl in classes_few[doc_id]:
                    if catid2cattxt_map[str(lbl)] not in categories_few:
                        categories_few[catid2cattxt_map[str(lbl)]] = lbl

        return txts_few, classes_few, categories_few
    def _create_firstsent_data(
        self,
        txts: OrderedDict,
        classes: OrderedDict,
        catid2cattxt_map: OrderedDict = None
    ) -> (OrderedDict, OrderedDict, OrderedDict):
        """Creates a version of wikipedia dataset with only first sentences and discarding the text.

        :param classes:
        :param txts:
        :param catid2cattxt_map:
        :return:
        """
        if catid2cattxt_map is None:
            catid2cattxt_map = File_Util.load_json(self.dataset_name +
                                                   "_catid2cattxt_map",
                                                   filepath=self.dataset_dir)

        txts_firstsent = OrderedDict()
        classes_firstsent = OrderedDict()
        categories_firstsent = OrderedDict()
        for doc_id, lbls in classes.items():
            if len(lbls) == 1:
                classes_firstsent[doc_id] = lbls
                ## Keep only the first sentence of the document (naive split; assumes ". " ends a sentence).
                txts_firstsent[doc_id] = txts[doc_id].split(". ")[0]
                for lbl in classes_firstsent[doc_id]:
                    if catid2cattxt_map[str(lbl)] not in categories_firstsent:
                        categories_firstsent[catid2cattxt_map[str(lbl)]] = lbl

        return txts_firstsent, classes_firstsent, categories_firstsent
    def multilabel2multiclass_df(self, df: pd.DataFrame):
        """Converts Multi-Label data in DataFrame format to Multi-Class data by replicating the samples.

        :param df: Dataframe containing repeated sample id and it's associated category.
        :returns: DataFrame with replicated samples.
        """
        if self.catid2cattxt_map is None:
            self.catid2cattxt_map = File_Util.load_json(
                filename=self.dataset_name + "_catid2cattxt_map",
                filepath=self.dataset_dir)
        idxs, cat = [], []
        for row in df.values:
            ## When the DataFrame is saved as csv, the label list is stored as a string, e.g. "[114, 517]".
            lbls = row[3][1:-1].split(',')
            for lbl in lbls:
                lbl = lbl.strip()
                idxs.append(row[1])
                cat.append(lbl)

        df = pd.DataFrame.from_dict({"idx": idxs, "cat": cat})
        df = df[~df['cat'].isna()]
        df.to_csv(path_or_buf=join(self.dataset_dir, self.dataset_name +
                                   "_multiclass_df.csv"))

        logger.info("Data shape = {} ".format(df.shape))

        return df
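
## A self-contained sketch of the replication step above on a toy DataFrame. The "cat"
## column holds the stringified label list that results from saving a DataFrame to CSV
## (the positional indices row[1]/row[3] used in the method correspond to the CSV layout
## written by json2csv further below); names here are illustrative.
import pandas as pd

df_in = pd.DataFrame({"idx": ["doc_1", "doc_2"], "cat": ["[114, 517]", "[3178]"]})
idxs, cats = [], []
for _, row in df_in.iterrows():
    for lbl in row["cat"][1:-1].split(','):  ## strip the brackets, split the label list
        idxs.append(row["idx"])
        cats.append(lbl.strip())
df_out = pd.DataFrame({"idx": idxs, "cat": cats})
print(df_out)  ## doc_1 appears twice (114, 517); doc_2 once (3178)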
    def _create_oneclass_data(
        self,
        txts: OrderedDict,
        classes: OrderedDict,
        catid2cattxt_map: OrderedDict = None
    ) -> (OrderedDict, OrderedDict, OrderedDict):
        """Creates a dataset which belongs to single class only.

        NOTE: This method is used only for sanity testing using multi-class scenario.
        """
        if catid2cattxt_map is None:
            catid2cattxt_map = File_Util.load_json(self.dataset_name +
                                                   "_catid2cattxt_map",
                                                   filepath=self.dataset_dir)
        txts_one = OrderedDict()
        classes_one = OrderedDict()
        categories_one = OrderedDict()
        for doc_id, lbls in classes.items():
            if len(lbls) == 1:
                classes_one[doc_id] = lbls
                txts_one[doc_id] = txts[doc_id]
                for lbl in classes_one[doc_id]:
                    if catid2cattxt_map[str(lbl)] not in categories_one:
                        categories_one[catid2cattxt_map[str(lbl)]] = lbl

        return txts_one, classes_one, categories_one
    def load_doc_neighborhood_graph(
            self,
            nodes,
            graph_path=None,
            get_stats: bool = config["graph"]["stats"]):
        """ Loads the graph file if found else creates neighborhood graph.

        :param nodes: List of node ids to consider.
        :param get_stats:
        :param graph_path: Full path to the graphml file.
        :return: Networkx graph, Adjecency matrix, stats related to the graph.
        """

        if graph_path is None:
            graph_path = join(
                self.graph_dir, self.dataset_name,
                self.dataset_name + "_G_" + str(len(nodes)) + ".graphml")
        if exists(graph_path):
            logger.info(
                "Loading neighborhood graph from [{0}]".format(graph_path))
            Docs_G = nx.read_graphml(graph_path)
        else:
            self.sample2cats = File_Util.load_json(
                join(self.graph_dir, self.dataset_name,
                     self.dataset_name + "_sample2cats"))
            self.categories = File_Util.load_json(
                join(self.graph_dir, self.dataset_name,
                     self.dataset_name + "_cats"))
            self.cat_id2text_map = File_Util.load_json(
                join(self.graph_dir, self.dataset_name,
                     self.dataset_name + "_catid2cattxt_map"))
            Docs_G = self.create_neighborhood_graph(nodes=nodes)
            logger.debug(nx.info(Docs_G))
            logger.info(
                "Saving neighborhood graph at [{0}]".format(graph_path))
            nx.write_graphml(Docs_G, graph_path)
        # Docs_adj = nx.adjacency_matrix(Docs_G)
        if get_stats:
            Docs_G_stats = self.graph_stats(Docs_G)
            File_Util.save_json(Docs_G_stats,
                                filename=self.dataset_name + "_G_stats",
                                overwrite=True,
                                filepath=join(self.graph_dir,
                                              self.dataset_name))
            return Docs_G, Docs_G_stats
        return Docs_G
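
## A self-contained sketch of the load-or-create-then-save flow above, with a toy
## networkx graph standing in for the document neighborhood graph; the path is hypothetical.
import networkx as nx
from os.path import exists

toy_graph_path = "toy_G_3.graphml"
if exists(toy_graph_path):
    Docs_G_demo = nx.read_graphml(toy_graph_path)
else:
    Docs_G_demo = nx.Graph()
    Docs_G_demo.add_edges_from([("doc_0", "doc_1"), ("doc_1", "doc_2")])  ## neighborhood edges
    nx.write_graphml(Docs_G_demo, toy_graph_path)
print(Docs_G_demo.number_of_nodes(), Docs_G_demo.number_of_edges())  ## 3 2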
    def load_all(self) -> (OrderedDict, OrderedDict, OrderedDict):
        """Loads and returns the whole data."""
        logger.debug(join(self.dataset_dir, self.dataset_name + "_txts.json"))
        if self.txts_sel is None:
            if isfile(join(self.dataset_dir,
                           self.dataset_name + "_txts.json")):
                self.txts_sel = File_Util.load_json(self.dataset_name +
                                                    "_txts",
                                                    filepath=self.dataset_dir)
            else:
                self.txts_sel, self.sample2cats_sel, self.cats_sel = self.load_full_json(
                    return_values=True)

        if self.sample2cats_sel is None:
            if isfile(
                    join(self.dataset_dir,
                         self.dataset_name + "_sample2cats.json")):
                self.sample2cats_sel = File_Util.load_json(
                    self.dataset_name + "_sample2cats",
                    filepath=self.dataset_dir)
            else:
                self.txts_sel, self.sample2cats_sel, self.cats_sel = self.load_full_json(
                    return_values=True)

        if self.cats_sel is None:
            if isfile(join(self.dataset_dir,
                           self.dataset_name + "_cats.json")):
                self.cats_sel = File_Util.load_json(self.dataset_name +
                                                    "_cats",
                                                    filepath=self.dataset_dir)
            else:
                self.txts_sel, self.sample2cats_sel, self.cats_sel = self.load_full_json(
                    return_values=True)
        collect()

        logger.info(
            "Total data counts:\n\ttxts = [{}],\n\tsample2cats = [{}],\n\tcats = [{}]"
            .format(len(self.txts_sel), len(self.sample2cats_sel),
                    len(self.cats_sel)))
        return self.txts_sel, self.sample2cats_sel, self.cats_sel
    def load_train(self) -> (OrderedDict, OrderedDict, OrderedDict):
        """Loads and returns training set."""
        logger.debug(
            join(self.dataset_dir, self.dataset_name + "_txts_train.json"))
        if self.txts_train is None:
            if isfile(
                    join(self.dataset_dir,
                         self.dataset_name + "_txts_train.json")):
                self.txts_train = File_Util.load_json(
                    self.dataset_name + "_txts_train",
                    filepath=self.dataset_dir)
            else:
                self.load_full_json()

        if self.sample2cats_train is None:
            if isfile(
                    join(self.dataset_dir,
                         self.dataset_name + "_sample2cats_train.json")):
                self.sample2cats_train = File_Util.load_json(
                    self.dataset_name + "_sample2cats_train",
                    filepath=self.dataset_dir)
            else:
                self.load_full_json()

        if self.cats_sel is None:
            if isfile(
                    join(self.dataset_dir,
                         self.dataset_name + "_cats_train.json")):
                self.cats_sel = File_Util.load_json(self.dataset_name +
                                                    "_cats_train",
                                                    filepath=self.dataset_dir)
            else:
                self.load_full_json()
        collect()

        # logger.info("Training data counts:\n\ttxts = [{}],\n\tClasses = [{}],\n\tCategories = [{}]"
        #             .format(len(self.txts_train), len(self.sample2cats_train), len(self.cats_train)))
        return self.txts_train, self.sample2cats_train, self.cats_sel
    def cat_token_counts(self, catid2cattxt_map=None):
        """ Counts the number of tokens in categories.

        :return:
        :param catid2cattxt_map:
        """
        if catid2cattxt_map is None:
            catid2cattxt_map = File_Util.load_json(self.dataset_name +
                                                   "_catid2cattxt_map",
                                                   filepath=self.dataset_dir)
        cat_word_counts = {}
        for cat_id, cat_txt in catid2cattxt_map.items():
            ## Tokenize the category text (not the id key) to count its tokens.
            cat_word_counts[cat_id] = len(self.clean.tokenizer_spacy(cat_txt))

        return cat_word_counts
    def load_val(self) -> (OrderedDict, OrderedDict, OrderedDict):
        """Loads and returns validation set."""
        if self.txts_val is None:
            if isfile(
                    join(self.dataset_dir,
                         self.dataset_name + "_txts_val.json")):
                self.txts_val = File_Util.load_json(self.dataset_name +
                                                    "_txts_val",
                                                    filepath=self.dataset_dir)
            else:
                self.load_full_json()

        if self.sample2cats_val is None:
            if isfile(
                    join(self.dataset_dir,
                         self.dataset_name + "_sample2cats_val.json")):
                self.sample2cats_val = File_Util.load_json(
                    self.dataset_name + "_sample2cats_val",
                    filepath=self.dataset_dir)
            else:
                self.load_full_json()

        if self.cats_val is None:
            if isfile(
                    join(self.dataset_dir,
                         self.dataset_name + "_cats_val.json")):
                self.cats_val = File_Util.load_json(self.dataset_name +
                                                    "_cats_val",
                                                    filepath=self.dataset_dir)
            else:
                self.load_full_json()
        collect()

        # logger.info("Validation data counts:\n\ttxts = [{}],\n\tClasses = [{}],\n\tCategories = [{}]"
        #             .format(len(self.txts_val), len(self.sample2cats_val), len(self.cats_val)))
        return self.txts_val, self.sample2cats_val, self.cats_val
    def load_test(self) -> (OrderedDict, OrderedDict, OrderedDict):
        """Loads and returns test set."""
        if self.txts_test is None:
            if isfile(
                    join(self.dataset_dir,
                         self.dataset_name + "_txts_test.json")):
                self.txts_test = File_Util.load_json(self.dataset_name +
                                                     "_txts_test",
                                                     filepath=self.dataset_dir)
            else:
                self.load_full_json()

        if self.sample2cats_test is None:
            if isfile(
                    join(self.dataset_dir,
                         self.dataset_name + "_sample2cats_test.json")):
                self.sample2cats_test = File_Util.load_json(
                    self.dataset_name + "_sample2cats_test",
                    filepath=self.dataset_dir)
            else:
                self.load_full_json()

        if self.cats_test is None:
            if isfile(
                    join(self.dataset_dir,
                         self.dataset_name + "_cats_test.json")):
                self.cats_test = File_Util.load_json(self.dataset_name +
                                                     "_cats_test",
                                                     filepath=self.dataset_dir)
            else:
                self.load_full_json()
        collect()

        # logger.info("Testing data counts:\n\ttxts = [{}],\n\tClasses = [{}],\n\tCategories = [{}]"
        #             .format(len(self.txts_test), len(self.sample2cats_test), len(self.cats_test)))
        return self.txts_test, self.sample2cats_test, self.cats_test
    def load_categories(self) -> OrderedDict:
        """Loads and returns the whole categories set."""
        if self.cattext2catid_map is None:
            logger.debug(
                join(self.dataset_dir,
                     self.dataset_name + "_cattext2catid_map.json"))
            if isfile(
                    join(self.dataset_dir,
                         self.dataset_name + "_cattext2catid_map.json")):
                self.cattext2catid_map = File_Util.load_json(
                    self.dataset_name + "_cattext2catid_map",
                    filepath=self.dataset_dir)
            else:
                _, _, self.cattext2catid_map = self.load_full_json(
                    return_values=True)
        return self.cattext2catid_map
    def json2csv(self, txts_all: dict = None, sample2cats_all: dict = None):
        """ Converts existing multiple json files and returns a single pandas dataframe.

        :param txts_all:
        :param sample2cats_all:
        """
        if exists(join(self.dataset_dir, self.dataset_name + "_df.csv")):
            df = pd.read_csv(
                filepath_or_buffer=join(self.dataset_dir, self.dataset_name +
                                        "_df.csv"))
            df = df[~df['txts'].isna()]
        else:
            if txts_all is None or sample2cats_all is None:
                txts_all, sample2cats_all, _, cats_all = self.get_data(
                    load_type="all")
            catid2cattxt_map = File_Util.load_json(self.dataset_name +
                                                   "_catid2cattxt_map",
                                                   filepath=self.dataset_dir)
            txts_all_list,sample2cats_all_list,idxs,sample2catstext_all_list = [],[],[],[]
            for idx in sample2cats_all.keys():
                idxs.append(idx)
                txts_all_list.append(txts_all[idx])
                sample2cats_all_list.append(sample2cats_all[idx])
                sample2catstext = []
                for lbl in sample2cats_all[idx]:
                    sample2catstext.append(catid2cattxt_map[str(lbl)])
                sample2catstext_all_list.append(sample2catstext)

            df = pd.DataFrame.from_dict({
                "idx": idxs,
                "txts": txts_all_list,
                "cat": sample2cats_all_list,
                "cat_txt": sample2catstext_all_list
            })
            df = df[~df['txts'].isna()]
            df.to_csv(path_or_buf=join(self.dataset_dir, self.dataset_name +
                                       "_df.csv"))
        logger.info("Data shape = {} ".format(df.shape))
        return df
    def calculate_idf_per_token(self,txts: list,subtract: int = 1) -> dict:
        """ Calculates tfidf scores for each token in the corpus.

        :param txts:
        :param subtract: Removes this value from idf scores. Sometimes needed to get better scores.
        :return: Dict of token to idf score.
        """
        logger.info("Calculating IDF for each token.")
        if isfile(join(self.dataset_dir, self.dataset_name + "_idf_dict.json")):
            idf_dict = File_Util.load_json(filename=self.dataset_name + "_idf_dict", filepath=self.dataset_dir)
        else:
            from sklearn.feature_extraction.text import TfidfVectorizer
            ## Using TfidfVectorizer with spacy tokenizer; same tokenizer should be used everywhere.
            vectorizer = TfidfVectorizer(decode_error='ignore', lowercase=False,
                                         smooth_idf=False, sublinear_tf=True,
                                         stop_words='english', ngram_range=(1, 1),
                                         max_df=0.7, vocabulary=None,
                                         tokenizer=self.tokenizer_spacy)
            tfidf_matrix = vectorizer.fit_transform(txts)
            idf = vectorizer.idf_
            ## With smooth_idf=False, idf(t) = ln(n/df(t)) + 1; subtracting `subtract` (default 1) recovers plain ln(n/df(t)).
            idf_dict = dict(zip(vectorizer.get_feature_names(), idf - subtract))
            ignored_tokens = vectorizer.stop_words_  ## Terms ignored due to stop words or the max_df cutoff.

            File_Util.save_json(idf_dict,filename=self.dataset_name + "_idf_dict",filepath=self.dataset_dir)

        return idf_dict
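
## A standalone sketch of the idf computation above, using sklearn's default tokenizer
## instead of spaCy and the newer get_feature_names_out() accessor. With smooth_idf=False,
## scikit-learn computes idf(t) = ln(n / df(t)) + 1, so subtracting 1 recovers plain
## ln(n / df(t)); a token appearing in every document (like "the" below) then scores 0.
from sklearn.feature_extraction.text import TfidfVectorizer

docs = ["the cat sat", "the dog sat", "the cat ran"]
vec = TfidfVectorizer(smooth_idf=False, sublinear_tf=True)
vec.fit(docs)
idf_demo = dict(zip(vec.get_feature_names_out(), vec.idf_ - 1))
print(round(idf_demo["the"], 3), round(idf_demo["cat"], 3))  ## 0.0 and ln(3/2) ~= 0.405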
    def load_full_json(self, return_values: bool = False):
        """
        Loads full dataset and splits the data into train, val and test.
        """
        if isfile(join(self.dataset_dir,self.dataset_name + "_txts.json"))\
                and isfile(
            join(self.dataset_dir,self.dataset_name + "_sample2cats.json"))\
                and isfile(
            join(self.dataset_dir,self.dataset_name + "_cats.json")):
            logger.info("Loading pre-processed json files from: [{}]".format(
                join(self.dataset_dir, self.dataset_name + "_txts.json")))
            txts = File_Util.load_json(self.dataset_name + "_txts",
                                       filepath=self.dataset_dir,
                                       show_path=True)
            classes = File_Util.load_json(self.dataset_name + "_sample2cats",
                                          filepath=self.dataset_dir,
                                          show_path=True)
            categories = File_Util.load_json(self.dataset_name + "_cats",
                                             filepath=self.dataset_dir,
                                             show_path=True)
            assert len(txts) == len(classes),\
                "Count of txts [{0}] and sample2cats [{1}] should match.".format(
                    len(txts),len(classes))
        else:
            logger.warn("Pre-processed json files not found at: [{}]".format(
                join(self.dataset_dir, self.dataset_name + "_txts.json")))
            logger.info(
                "Loading raw data and creating 3 separate dicts of txts [id->texts], sample2cats [id->class_ids]"
                " and categories [class_name : class_id].")
            txts, classes, categories = self.load_raw_data(self.dataset_type)
            File_Util.save_json(categories,
                                self.dataset_name + "_cats",
                                filepath=self.dataset_dir)
            File_Util.save_json(txts,
                                self.dataset_name + "_txts",
                                filepath=self.dataset_dir)
            File_Util.save_json(classes,
                                self.dataset_name + "_sample2cats",
                                filepath=self.dataset_dir)
            logger.info("Cleaning categories.")
            categories, categories_dup_dict, dup_cat_text_map = self.clean.clean_categories(
                categories)
            File_Util.save_json(dup_cat_text_map,
                                self.dataset_name + "_dup_cat_text_map",
                                filepath=self.dataset_dir,
                                overwrite=True)
            File_Util.save_json(categories,
                                self.dataset_name + "_cats",
                                filepath=self.dataset_dir,
                                overwrite=True)
            if categories_dup_dict:  # Replace old category ids with new ids if duplicate categories found.
                File_Util.save_json(
                    categories_dup_dict,
                    self.dataset_name + "_categories_dup_dict",
                    filepath=self.dataset_dir,
                    overwrite=True
                )  # Storing the duplicate categories for future dedup removal.
                classes = self.clean.dedup_data(classes, categories_dup_dict)
            assert len(txts) == len(classes),\
                "Count of txts [{0}] and sample2cats [{1}] should match.".format(
                    len(txts),len(classes))
            File_Util.save_json(txts,
                                self.dataset_name + "_txts",
                                filepath=self.dataset_dir,
                                overwrite=True)
            File_Util.save_json(classes,
                                self.dataset_name + "_sample2cats",
                                filepath=self.dataset_dir,
                                overwrite=True)
            logger.info(
                "Saved txts [{0}], sample2cats [{1}] and categories [{2}] as json files."
                .format(join(self.dataset_dir, self.dataset_name + "_txts.json"),
                        join(self.dataset_dir, self.dataset_name + "_sample2cats.json"),
                        join(self.dataset_dir, self.dataset_name + "_cats.json")))
        if return_values:
            return txts, classes, categories
        else:
            # Splitting data into train, validation and test sets.
            self.txts_train,self.sample2cats_train,self.cats_sel,self.txts_val,self.sample2cats_val,\
            self.cats_val,self.txts_test,self.sample2cats_test,self.cats_test,catid2cattxt_map =\
                self.split_data(txts=txts,classes=classes,categories=categories)
            txts, classes, categories = None, None, None  # Remove large dicts and free up memory.
            collect()

            File_Util.save_json(self.txts_test,
                                self.dataset_name + "_txts_test",
                                filepath=self.dataset_dir)
            File_Util.save_json(self.sample2cats_test,
                                self.dataset_name + "_sample2cats_test",
                                filepath=self.dataset_dir)
            File_Util.save_json(self.txts_val,
                                self.dataset_name + "_txts_val",
                                filepath=self.dataset_dir)
            File_Util.save_json(self.sample2cats_val,
                                self.dataset_name + "_sample2cats_val",
                                filepath=self.dataset_dir)
            File_Util.save_json(self.txts_train,
                                self.dataset_name + "_txts_train",
                                filepath=self.dataset_dir)
            File_Util.save_json(self.sample2cats_train,
                                self.dataset_name + "_sample2cats_train",
                                filepath=self.dataset_dir)
            File_Util.save_json(self.cats_sel,
                                self.dataset_name + "_cats_train",
                                filepath=self.dataset_dir)
            File_Util.save_json(self.cats_val,
                                self.dataset_name + "_cats_val",
                                filepath=self.dataset_dir)
            File_Util.save_json(self.cats_test,
                                self.dataset_name + "_cats_test",
                                filepath=self.dataset_dir)
            File_Util.save_json(catid2cattxt_map,
                                self.dataset_name + "_catid2cattxt_map",
                                filepath=self.dataset_dir)
            return self.txts_train,self.sample2cats_train,self.cats_sel,self.txts_val,self.sample2cats_val,\
                   self.cats_val,self.txts_test,self.sample2cats_test,self.cats_test
    def split_data(self,
                   txts: OrderedDict,
                   classes: OrderedDict,
                   categories: OrderedDict,
                   test_split: float = config["data"]["test_split"],
                   val_split: float = config["data"]["val_split"]):
        """ Splits input data into train, val and test sets.

        :param txts: Dict of sample id to text.
        :param classes: Dict of sample id to list of class ids.
        :param categories: Dict of category text to category id.
        :param test_split: Test split fraction.
        :param val_split: Validation split fraction, applied to the data remaining after the test split.
        :return: txts, sample2cats and cats dicts for train, val and test, plus the catid2cattxt_map.
        """
        logger.info("Total number of samples: [{}]".format(len(classes)))
        sample2cats_train,sample2cats_test,txts_train,txts_test =\
            File_Util.split_dict(classes,txts,
                                 batch_size=int(len(classes) * test_split))
        logger.info("Test count: [{}]. Remaining count: [{}]".format(
            len(sample2cats_test), len(sample2cats_train)))

        sample2cats_train,sample2cats_val,txts_train,txts_val =\
            File_Util.split_dict(sample2cats_train,txts_train,
                                 batch_size=int(len(txts_train) * val_split))
        logger.info("Validation count: [{}]. Train count: [{}]".format(
            len(sample2cats_val), len(sample2cats_train)))

        if isfile(
                join(self.dataset_dir,
                     self.dataset_name + "_catid2cattxt_map.json")):
            catid2cattxt_map = File_Util.load_json(self.dataset_name +
                                                   "_catid2cattxt_map",
                                                   filepath=self.dataset_dir)
            # Integer keys are converted to string when saving as JSON. Converting back to integer.
            catid2cattxt_map_int = OrderedDict()
            for k, v in catid2cattxt_map.items():
                catid2cattxt_map_int[int(k)] = v
            catid2cattxt_map = catid2cattxt_map_int
        else:
            logger.info("Generating inverted categories.")
            catid2cattxt_map = File_Util.inverse_dict_elm(categories)

        logger.info("Creating train categories.")
        cats_train = OrderedDict()
        for k, v in sample2cats_train.items():
            for cat_id in v:
                if cat_id not in cats_train:
                    cats_train[cat_id] = catid2cattxt_map[cat_id]

        logger.info("Creating validation categories.")
        cats_val = OrderedDict()
        for k, v in sample2cats_val.items():
            for cat_id in v:
                if cat_id not in cats_val:
                    cats_val[cat_id] = catid2cattxt_map[cat_id]

        logger.info("Creating test categories.")
        cats_test = OrderedDict()
        for k, v in sample2cats_test.items():
            for cat_id in v:
                if cat_id not in cats_test:
                    cats_test[cat_id] = catid2cattxt_map[cat_id]
        return txts_train, sample2cats_train, cats_train, txts_val, sample2cats_val, cats_val, txts_test, sample2cats_test, cats_test, catid2cattxt_map
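
## Worked example of the split arithmetic above (ratios are assumed here; the real values
## come from config["data"]): the test split is taken first, then the validation split is
## applied to what remains.
n_samples = 1000
test_split, val_split = 0.2, 0.1                       ## assumed config values
n_test = int(n_samples * test_split)                   ## 200
n_val = int((n_samples - n_test) * val_split)          ## 80 (10% of the remaining 800)
n_train = n_samples - n_test - n_val                   ## 720
print(n_train, n_val, n_test)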
def main(args):
    """
    Main function to run Matching Networks for Extreme Classification.

    :param args: Dict of all the arguments.
    """
    ## Training Phase
    data_loader = Common_Data_Handler()
    data_formatter = Prepare_Data(dataset_loader=data_loader)
    txts, sample2cats, _, cats = data_formatter.load_raw_data(load_type='all')
    txts2vec_map, cats2vec_map = data_formatter.create_vec_maps()
    logger.debug((len(txts2vec_map), len(cats2vec_map)))

    input_vecs, cats_hot, keys, cats_idx = data_formatter.get_input_batch(
        txts2vec_map, sample2cats, return_cat_indices=True, multi_label=False)
    logger.debug(input_vecs.shape)

    input_adj_coo = data_formatter.load_graph_data(keys)
    logger.debug(input_adj_coo.shape)

    idx_train = torch.LongTensor(range(int(input_vecs.shape[0] * 0.7)))
    idx_val = torch.LongTensor(
        range(int(input_vecs.shape[0] * 0.7), int(input_vecs.shape[0] * 0.8)))
    idx_test = torch.LongTensor(
        range(int(input_vecs.shape[0] * 0.8), int(input_vecs.shape[0])))
    # logger.debug(idx_train)
    # logger.debug(idx_val)
    # logger.debug(idx_test)

    # input_vecs = torch.FloatTensor(input_vecs)
    input_vecs = Variable(torch.from_numpy(input_vecs),
                          requires_grad=True).float()
    cats_idx = Variable(torch.from_numpy(cats_idx),
                        requires_grad=False).float()
    # cats_idx = torch.LongTensor(cats_idx)
    input_adj_coo_t = adj_csr2t_coo(input_adj_coo)
    # input_adj_coo_t = input_adj_coo_t.requires_grad
    logger.debug(input_adj_coo_t.shape)

    # Model and optimizer
    model = GCN(nfeat=input_vecs.shape[1],
                nhid=args.hidden,
                nclass=cats_hot.shape[1],
                dropout=args.dropout)

    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           weight_decay=args.weight_decay)

    filepath = config["paths"]["dataset_dir"][plat][user]
    dataset = config["data"]["dataset_name"]
    samples2cats_map = File_Util.load_json(filename=dataset + "_sample2cats",
                                           filepath=join(filepath, dataset))
    _, label_embs = create_lbl_embs(samples2cats_map, cats2vec_map)

    # label_embs = torch.FloatTensor(label_embs)
    label_embs = Variable(torch.from_numpy(label_embs),
                          requires_grad=True).float()

    # Train model
    train_losses,train_accs,val_losses,val_accs,train_times = [],[],[],[],[]
    t_total = time.time()
    for epoch in range(args.epochs):
        # train_losses.append(train(epoch,model,optimizer,input_vecs,input_adj_coo_t.float(),cats_idx,idx_train,idx_val))
        # loss_train,acc_train,loss_val,acc_val,time_taken =\
        loss_train, acc_train, loss_val, acc_val, time_taken = train_emb(
            epoch=epoch,
            model=model,
            optimizer=optimizer,
            features=input_vecs,
            adj=input_adj_coo_t.float(),
            label_emb=label_embs,
            labels=cats_idx,
            idx_train=idx_train,
            idx_val=idx_val)
        collect()
        # torch.empty_cache()
        train_losses.append(loss_train)
        train_accs.append(acc_train)
        val_losses.append(loss_val)
        val_accs.append(acc_val)
        train_times.append(time_taken)
        logger.info(
            "\nLayer1 weights sum:[{}] \nLayer2 weights sum:[{}]".format(
                torch.sum(model.gc1.weight.data),
                torch.sum(model.gc2.weight.data)))
    logger.info("Optimization Finished!")
    _, train_features = model(input_vecs, input_adj_coo_t.float())
    # W1 = model.gc1.weight.data
    logger.info("Layer 1 weight matrix shape: [{}]".format(
        model.gc1.weight.data.shape))
    logger.info("Layer 2 weight matrix shape: [{}]".format(
        model.gc2.weight.data.shape))
    logger.info("Total time elapsed: {:.4f}s".format(time.time() - t_total))
    plot_occurance(train_losses,
                   plot_name="train_losses_" + str(args.epochs) + ".jpg",
                   title="Train Losses",
                   plot_dir=str(args.epochs))
    plot_occurance(train_accs,
                   plot_name="train_accs_" + str(args.epochs) + ".jpg",
                   ylabel="Accuracy",
                   title="Train Accuracy",
                   plot_dir=str(args.epochs))
    plot_occurance(val_losses,
                   plot_name="val_losses_" + str(args.epochs) + ".jpg",
                   title="Validation Losses",
                   plot_dir=str(args.epochs))
    plot_occurance(val_accs,
                   plot_name="val_accs_" + str(args.epochs) + ".jpg",
                   ylabel="Accuracy",
                   title="Validation Accuracy",
                   plot_dir=str(args.epochs))
    plot_occurance(train_times,
                   plot_name="train_time_" + str(args.epochs) + ".jpg",
                   ylabel="Time",
                   title="Train Time",
                   plot_dir=str(args.epochs))

    # Testing
    # test(model,input_vecs,input_adj_coo_t.float(),cats_idx,idx_test)
    test_emb(model=model,
             train_features=train_features,
             test_features=input_vecs,
             labels=cats_idx,
             idx_train=idx_train,
             idx_test=idx_test)
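
## A hedged sketch of the argument namespace main() expects, based only on the attributes
## read above (args.hidden, args.dropout, args.lr, args.weight_decay, args.epochs); the
## default values below are illustrative assumptions, not the project's actual settings.
import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="GCN for Extreme Classification")
    parser.add_argument("--hidden", type=int, default=16, help="Hidden layer size.")
    parser.add_argument("--dropout", type=float, default=0.5, help="Dropout rate.")
    parser.add_argument("--lr", type=float, default=0.01, help="Learning rate.")
    parser.add_argument("--weight_decay", type=float, default=5e-4, help="L2 penalty.")
    parser.add_argument("--epochs", type=int, default=200, help="Number of training epochs.")
    main(parser.parse_args())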
    def create_new_data(self,
                        new_data_name: str = "_pointer",
                        save_files: bool = True,
                        save_dir: str = None,
                        catid2cattxt_map: OrderedDict = None):
        """Creates new dataset based on new_data_name value, currently supports: "_fixed5" and "_onehot".

        _fixed5: Creates a dataset of samples which belongs to any of the below 5 sample2cats only.
        _onehot: Creates a dataset which belongs to single class only.

        NOTE: This method is used only for sanity testing using fixed multi-class scenario.
        """
        if save_dir is None:
            save_dir = join(self.dataset_dir,
                            self.dataset_name + new_data_name)
        if isfile(
                join(save_dir, self.dataset_name + new_data_name +
                     "_sample2cats.json")) and isfile(
                         join(save_dir, self.dataset_name + new_data_name +
                              "_txts.json")) and isfile(
                                  join(
                                      save_dir, self.dataset_name +
                                      new_data_name + "_cats.json")):
            logger.info("Loading files from: [{}]".format(save_dir))
            txts_new = File_Util.load_json(self.dataset_name + new_data_name +
                                           "_txts",
                                           filepath=save_dir)
            sample2cats_new = File_Util.load_json(
                self.dataset_name + new_data_name + "_sample2cats",
                filepath=save_dir)
            cats_new = File_Util.load_json(self.dataset_name + new_data_name +
                                           "_cats",
                                           filepath=save_dir)
        else:
            logger.info(
                "No existing files found at [{}]. Generating {} files.".format(
                    save_dir, new_data_name))
            if catid2cattxt_map is None:
                catid2cattxt_map = File_Util.load_json(
                    self.dataset_name + "_catid2cattxt_map",
                    filepath=self.dataset_dir)

            txts, classes, _ = self.load_full_json(return_values=True)
            if new_data_name is "_fixed5":
                txts_one, classes_one, _ = self._create_oneclass_data(
                    txts, classes, catid2cattxt_map=catid2cattxt_map)
                txts_new,sample2cats_new,cats_new =\
                    self._create_fixed_cat_data(txts_one,classes_one,
                                                catid2cattxt_map=catid2cattxt_map)
            elif new_data_name is "_onehot":
                txts_new,sample2cats_new,cats_new =\
                    self._create_oneclass_data(txts,classes,
                                               catid2cattxt_map=catid2cattxt_map)
            elif new_data_name is "_pointer":
                txts_new,sample2cats_new,cats_new =\
                    self._create_pointer_data(txts,classes,
                                              catid2cattxt_map=catid2cattxt_map)
            elif new_data_name is "_fewshot":
                txts_new,sample2cats_new,cats_new =\
                    self._create_fewshot_data(txts,classes,
                                              catid2cattxt_map=catid2cattxt_map)
            elif new_data_name is "_firstsent":
                txts_new,sample2cats_new,cats_new =\
                    self._create_firstsent_data(txts,classes,
                                                catid2cattxt_map=catid2cattxt_map)
            else:
                raise Exception(
                    "Unknown 'new_data_name': [{}]. \n Available options: ['_fixed5', '_onehot', '_pointer', '_fewshot', '_firstsent']"
                    .format(new_data_name))
            if save_files:  # Storing new data
                logger.info(
                    "New dataset will be stored inside original dataset directory at: [{}]"
                    .format(save_dir))
                makedirs(save_dir, exist_ok=True)
                File_Util.save_json(txts_new,
                                    self.dataset_name + new_data_name +
                                    "_txts",
                                    filepath=save_dir)
                File_Util.save_json(sample2cats_new,
                                    self.dataset_name + new_data_name +
                                    "_sample2cats",
                                    filepath=save_dir)
                File_Util.save_json(cats_new,
                                    self.dataset_name + new_data_name +
                                    "_cats",
                                    filepath=save_dir)

        return txts_new, sample2cats_new, cats_new
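
## Optional design sketch (not the original API): the string dispatch in create_new_data
## can be written as a lookup table, which keeps the error message in sync with the
## supported variants; the builder names below mirror the private methods defined above.
builders = {
    "_fixed5": "_create_fixed_cat_data",
    "_onehot": "_create_oneclass_data",
    "_pointer": "_create_pointer_data",
    "_fewshot": "_create_fewshot_data",
    "_firstsent": "_create_firstsent_data",
}
new_data_name = "_pointer"
if new_data_name not in builders:
    raise Exception("Unknown 'new_data_name': [{}]. Available options: {}".format(
        new_data_name, sorted(builders)))
print(builders[new_data_name])  ## name of the builder method to call on the handler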