示例#1
0
    def update_scores(self, epoch_num: Optional[List[int]]) -> None:
        """
        Update the VI/purity metrics once per epoch.

        Runs the model in eval mode over ``self.dev`` for each ontology in
        ``self.ontologys``, clusters the produced entity embeddings, and
        accumulates variation-of-information and purity scores into
        ``self.metrics``, averaged over the number of ontologies.

        Parameters
        ----------
        ``epoch_num`` : Optional[List[int]]
            epoch tracker output (containing current epoch number); scores
            are only recomputed when ``epoch_num[0]`` differs from
            ``self._metric_epoch_tracker``.
        """
        # Guard clause: only score when a dev set exists and the epoch advanced.
        # NOTE(review): self._metric_epoch_tracker is not updated here —
        # presumably updated by the caller; verify, otherwise this recomputes
        # every call within the same epoch.
        if not (self.dev and epoch_num
                and epoch_num[0] != self._metric_epoch_tracker):
            return

        was_training = self.training
        self.eval()
        for name, ontology in self.ontologys:
            # BUG FIX: X/y must be reset for every ontology. They were
            # previously initialized once outside this loop and rebound to
            # ndarrays at the end of the first iteration, so any second
            # ontology crashed on list.append (ndarray has no append) and
            # would have mixed embeddings across ontologies.
            X = []
            y = []
            for doc in self.dev:
                docid = doc['docid']
                entities = doc["entities"]
                if len(entities) == 0:
                    continue
                # Map each entity to its gold ontology index (None = unmapped).
                idxs = [
                    ontology(docid, entity["label"]) for entity in entities
                ]
                entities_text = np.asarray(
                    np.stack(
                        [entity["text"].sum(0) for entity in entities]))

                tensor_input = torch.from_numpy(entities_text)
                # turn input into float tensor of size (batch_size=1, num_entity, vocab_size)
                tensor_input = tensor_input.float().unsqueeze(0)
                output_dic = self(tensor_input)
                tensor_output = output_dic["e"][0].detach().numpy()
                # Keep only entities the ontology could map to a gold label.
                for i, idx in enumerate(idxs):
                    if idx is not None:
                        X.append(tensor_output[i])
                        y.append(idx)
            X, y = np.array(X), np.array(y)
            algo_partitions = bamman_clustering(X)
            gold_partitions = gold_clustering(y)
            VI = variation_of_information(algo_partitions, gold_partitions)
            purity_score = purity(algo_partitions, gold_partitions)
            # Accumulate across ontologies; averaged after the loop.
            # NOTE(review): on a second epoch the previous averaged value is
            # still in self.metrics and gets folded into the new sum —
            # confirm whether that carry-over is intentional.
            self.metrics.setdefault("vi", 0)
            self.metrics.setdefault("purity", 0)
            self.metrics["vi"] += VI
            self.metrics["purity"] += purity_score

        self.metrics["vi"] /= len(self.ontologys)
        self.metrics["purity"] /= len(self.ontologys)
        # Restore the mode the module was in when we entered.
        self.train(was_training)
示例#2
0
                        #
                        colors = list(matplotlib.cm.colors.get_named_colors_mapping().values())[:num_class]
                        for k, c in zip(range(num_class), colors):
                            plt.scatter(X_r[y == k, 0], X_r[y == k, 1], c=c, label=str(k), s=5, alpha=.5)

                        plt.xlabel('PCA1'), plt.ylabel('PCA2'), ax.grid('on')
                        # plt.ylim([-4,4])
                        plt.title("PCA"), plt.legend(bbox_to_anchor=(1.05, 1))
                        plt.savefig(f"{dirname}/{figname}"), plt.close()

                        print(f"n_data: {len(y)}, gold_clustering: {name}")
                        algo_partitions = bamman_clustering(X)
                        gold_partitions = gold_clustering(y)
                        max_sizes.append(max(len(lst) for lst in algo_partitions))
                        # print(f"biggest partition size: {max(len(lst) for lst in algo_partitions)}")
                        VI = variation_of_information(algo_partitions, gold_partitions)
                        purity_score = purity(algo_partitions, gold_partitions)
                        VIs.append(VI)
                        purity_scores.append(purity_score)

                    max_sizes_stats = describe(max_sizes)
                    VI_stats = describe(VIs)
                    purity_scores_stats = describe(purity_scores)
                    print(f"Metric {name}")
                    print(f"Variation of Information: {describe_string(VI_stats, percent=False)}")
                    print(f"Purity Score: {describe_string(purity_scores_stats, percent=True)}")
                    print(f"Max Sizes: {describe_string(max_sizes_stats, percent=False)}")
                    print()

                    max_size_mean[name][j, i] = max_sizes_stats.mean
                    max_size_std[name][j, i] = 0 if isnan(max_sizes_stats.variance) else sqrt(max_sizes_stats.variance)