def __init__(self,
                 vocab: Vocabulary,
                 bow_embedder: TokenEmbedder,
                 vae: VAE,
                 apply_batchnorm_on_recon: bool = False,
                 batchnorm_weight_learnable: bool = False,
                 batchnorm_bias_learnable: bool = True,
                 kl_weight_annealing: str = "constant",
                 linear_scaling: float = 1000.0,
                 sigmoid_weight_1: float = 0.25,
                 sigmoid_weight_2: float = 15,
                 reference_counts: Optional[str] = None,
                 reference_vocabulary: Optional[str] = None,
                 use_doc_info: bool = False,
                 use_background: bool = False,
                 background_data_path: Optional[str] = None,
                 update_background_freq: bool = False,
                 track_topics: bool = True,
                 track_npmi: bool = True,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        """Build the model's sub-modules, metrics and initial KL-annealing state.

        Parameters
        ----------
        vocab : Vocabulary
            Passed to the base class. The size of its ``"persona_based"``
            namespace sets the width of the optional reconstruction batchnorm.
        bow_embedder : TokenEmbedder
            Stored as ``self._bag_of_words_embedder``.
        vae : VAE
            Stored as ``self.vae``.
        apply_batchnorm_on_recon : bool
            If True, build ``self.bow_bn`` with ``create_trainable_BatchNorm1d``.
        batchnorm_weight_learnable, batchnorm_bias_learnable : bool
            Forwarded to ``create_trainable_BatchNorm1d``.
        kl_weight_annealing : str
            ``"linear"``, ``"sigmoid"`` or ``"constant"``. Selects the formula for
            the initial ``self._kld_weight``.
        linear_scaling : float
            Linear schedule scale. The initial weight is ``min(1, 1 / linear_scaling)``,
            so it must be positive when ``kl_weight_annealing == "linear"``.
        sigmoid_weight_1, sigmoid_weight_2 : float
            Sigmoid schedule parameters. The initial weight is
            ``1 / (1 + exp(-sigmoid_weight_1 * (1 - sigmoid_weight_2)))``.
        reference_counts : str, optional
            Path to a sparse count matrix read with ``load_sparse``. Its rows are
            treated as documents (``self.n_docs = shape[0]``). Required whenever
            ``reference_vocabulary`` is given.
        reference_vocabulary : str, optional
            Path to a JSON file of words; entries are indexed in file order. When
            given, the NPMI numerator/denominator are precomputed here.
        use_doc_info : bool
            If True, create a learnable 2-element ``self.interpolation`` parameter
            initialised to zeros.
        use_background : bool
            If True, load ``self._background_freq`` from ``background_data_path``.
            Otherwise it is set to ``0``.
        background_data_path : str, optional
            File passed to ``initialize_bg_from_file`` when ``use_background``.
        update_background_freq, track_topics, track_npmi : bool
            Stored as flags for use elsewhere in the model.
        initializer : InitializerApplicator
            Applied to ``self`` as the final step.
        regularizer : RegularizerApplicator, optional
            Passed to the base class.

        Raises
        ------
        ConfigurationError
            For an unknown ``kl_weight_annealing``, for a non-positive
            ``linear_scaling`` under linear annealing, or for
            ``reference_vocabulary`` given without ``reference_counts``.
        """
        super().__init__(vocab, regularizer)

        self.metrics = {'nkld': Average(), 'nll': Average(), 'perp': Average()}

        self.vocab = vocab
        self.vae = vae
        self.track_topics = track_topics
        self.track_npmi = track_npmi
        self.vocab_namespace = "persona_based"
        self._update_background_freq = update_background_freq

        vocab_size = self.vocab.get_vocab_size(self.vocab_namespace)
        self._use_doc_info = use_doc_info
        if use_doc_info:
            # Two learnable interpolation weights, starting at zero.
            self.interpolation = torch.nn.Parameter(torch.zeros(2, requires_grad=True))
        self._background_freq = (self.initialize_bg_from_file(file_=background_data_path)
                                 if use_background else 0)
        # Debug output goes through the logger rather than an unconditional print.
        logger.debug("Background frequency: %s", self._background_freq)
        self._ref_counts = reference_counts

        if reference_vocabulary is not None:
            if reference_counts is None:
                # Without this check cached_path(None) fails with an unrelated error.
                raise ConfigurationError(
                    "reference_counts must be provided together with reference_vocabulary")
            # Compute data necessary to compute NPMI every epoch
            logger.info("Loading reference vocabulary.")
            self._ref_vocab = read_json(cached_path(reference_vocabulary))
            self._ref_vocab_index = dict(zip(self._ref_vocab, range(len(self._ref_vocab))))
            logger.info("Loading reference count matrix.")
            self._ref_count_mat = load_sparse(cached_path(self._ref_counts))
            logger.info("Computing word interaction matrix.")
            # Binarise counts: a word either occurs in a document or it does not.
            self._ref_doc_counts = (self._ref_count_mat > 0).astype(float)
            # Word-by-word co-document counts.
            self._ref_interaction = self._ref_doc_counts.T.dot(self._ref_doc_counts)
            # Per-word document frequency, flattened from the 1 x V sparse sum.
            self._ref_doc_sum = np.array(self._ref_doc_counts.sum(0).tolist()[0])
            logger.info("Generating npmi matrices.")
            (self._npmi_numerator,
             self._npmi_denominator) = self.generate_npmi_vals(self._ref_interaction,
                                                               self._ref_doc_sum)
            self.n_docs = self._ref_count_mat.shape[0]

        self._bag_of_words_embedder = bow_embedder

        self._kl_weight_annealing = kl_weight_annealing

        self._linear_scaling = float(linear_scaling)
        self._sigmoid_weight_1 = float(sigmoid_weight_1)
        self._sigmoid_weight_2 = float(sigmoid_weight_2)
        if kl_weight_annealing == "linear":
            if self._linear_scaling <= 0:
                raise ConfigurationError(
                    "linear_scaling must be positive, got {}".format(linear_scaling))
            self._kld_weight = min(1.0, 1 / self._linear_scaling)
        elif kl_weight_annealing == "sigmoid":
            self._kld_weight = float(1 / (1 + np.exp(-self._sigmoid_weight_1 * (1 - self._sigmoid_weight_2))))
        elif kl_weight_annealing == "constant":
            self._kld_weight = 1.0
        else:
            raise ConfigurationError("anneal type {} not found".format(kl_weight_annealing))

        # setup batchnorm
        self._apply_batchnorm_on_recon = apply_batchnorm_on_recon
        if apply_batchnorm_on_recon:
            self.bow_bn = create_trainable_BatchNorm1d(vocab_size,
                                                       weight_learnable=batchnorm_weight_learnable,
                                                       bias_learnable=batchnorm_bias_learnable,
                                                       eps=0.001, momentum=0.001, affine=True)

        # Maintain these states for periodically printing topics and updating KLD
        self._metric_epoch_tracker = 0
        self._kl_epoch_tracker = 0
        self._cur_epoch = 0
        self._cur_npmi = 0.0
        self.batch_num = 0

        initializer(self)
Ejemplo n.º 2
0
    def __init__(self,
                 vocab: Vocabulary,
                 bow_embedder: TokenEmbedder,
                 vae: VAE,
                 doc_kl_weight_annealing: str = "constant",
                 doc_linear_scaling: float = 1000.0,
                 doc_sigmoid_weight_1: float = 0.25,
                 doc_sigmoid_weight_2: float = 15,
                 doc_saturation_period: int = 2,
                 doc_period: int = 10,
                 entity_kl_weight_annealing: str = "constant",
                 entity_linear_scaling: float = 1000.0,
                 entity_sigmoid_weight_1: float = 0.25,
                 entity_sigmoid_weight_2: float = 15,
                 entity_saturation_period: int = 2,
                 entity_period: int = 10,
                 reference_counts: Optional[str] = None,
                 reference_vocabulary: Optional[str] = None,
                 background_data_path: Optional[str] = None,
                 update_background_freq: bool = False,
                 track_topics: bool = True,
                 track_npmi: bool = True,
                 visual_topic: bool = True,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        """Build sub-modules, metrics and separate doc/entity KL-annealing state.

        Parameters
        ----------
        vocab : Vocabulary
            Passed to the base class. The size of its ``"entity_based"`` namespace
            sets the width of ``self.doc_bow_bn``.
        bow_embedder : TokenEmbedder
            Stored as ``self._bag_of_words_embedder``.
        vae : VAE
            Stored as ``self.vae``.
        doc_kl_weight_annealing, entity_kl_weight_annealing : str
            ``"linear"``, ``"sigmoid"``, ``"constant"`` or ``"cyclic-linear"``.
            Each selects the formula for the matching initial
            ``self._doc_kld_weight`` / ``self._entity_kld_weight``.
        doc_linear_scaling, entity_linear_scaling : float
            Linear schedule scale. The initial weight is ``min(1, 1 / scaling)``,
            so it must be positive under linear annealing.
        doc_sigmoid_weight_1, doc_sigmoid_weight_2, entity_sigmoid_weight_1, entity_sigmoid_weight_2 : float
            Sigmoid schedule parameters. The initial weight is
            ``1 / (1 + exp(-w1 * (1 - w2)))``.
        doc_saturation_period, doc_period, entity_saturation_period, entity_period : int
            Stored only under ``"cyclic-linear"``, where the initial weight is
            ``1 / period``. The period must be positive.
        reference_counts : str, optional
            Path to a sparse count matrix read with ``load_sparse``. Its rows are
            treated as documents (``self.n_docs = shape[0]``). Required whenever
            ``reference_vocabulary`` is given.
        reference_vocabulary : str, optional
            Path to a JSON file of words. When given, the NPMI statistics are
            precomputed here.
        background_data_path : str, optional
            Always passed to ``initialize_bg_from_file``.
        update_background_freq, track_topics, track_npmi, visual_topic : bool
            Stored as flags for use elsewhere in the model.
        initializer : InitializerApplicator
            Applied to ``self`` as the final step.
        regularizer : RegularizerApplicator, optional
            Passed to the base class.

        Raises
        ------
        ConfigurationError
            For an unknown annealing type, a non-positive linear scaling or
            cyclic period, or ``reference_vocabulary`` given without
            ``reference_counts``.
        """
        super().__init__(vocab, regularizer)

        self.metrics = {
            'nkld': Average(),
            'd_nkld': Average(),
            'e_nkld': Average(),
            'nll': Average()
        }

        self.vocab = vocab
        self.vae = vae
        self.track_topics = track_topics
        self.track_npmi = track_npmi
        self.visual_topic = visual_topic
        self.vocab_namespace = "entity_based"
        self._update_background_freq = update_background_freq
        # NOTE(review): this is called even when background_data_path is None;
        # presumably initialize_bg_from_file handles that case — confirm.
        self._background_freq = self.initialize_bg_from_file(
            file_=background_data_path)
        self._ref_counts = reference_counts
        self._npmi_updated = False

        if reference_vocabulary is not None:
            if reference_counts is None:
                # Without this check cached_path(None) fails with an unrelated error.
                raise ConfigurationError(
                    "reference_counts must be provided together with reference_vocabulary")
            # Compute data necessary to compute NPMI every epoch
            logger.info("Loading reference vocabulary.")
            self._ref_vocab = read_json(cached_path(reference_vocabulary))
            self._ref_vocab_index = dict(
                zip(self._ref_vocab, range(len(self._ref_vocab))))
            logger.info("Loading reference count matrix.")
            self._ref_count_mat = load_sparse(cached_path(self._ref_counts))
            logger.info("Computing word interaction matrix.")
            # Binarise counts: a word either occurs in a document or it does not.
            self._ref_doc_counts = (self._ref_count_mat > 0).astype(float)
            # Word-by-word co-document counts.
            self._ref_interaction = self._ref_doc_counts.T.dot(
                self._ref_doc_counts)
            # Per-word document frequency, flattened from the 1 x V sparse sum.
            self._ref_doc_sum = np.array(
                self._ref_doc_counts.sum(0).tolist()[0])
            logger.info("Generating npmi matrices.")
            (self._npmi_numerator,
             self._npmi_denominator) = self.generate_npmi_vals(
                 self._ref_interaction, self._ref_doc_sum)
            self.n_docs = self._ref_count_mat.shape[0]

        vocab_size = self.vocab.get_vocab_size(self.vocab_namespace)
        self._bag_of_words_embedder = bow_embedder

        # Document-level KL annealing.
        self._doc_kl_weight_annealing = doc_kl_weight_annealing
        self._doc_linear_scaling = float(doc_linear_scaling)
        self._doc_sigmoid_weight_1 = float(doc_sigmoid_weight_1)
        self._doc_sigmoid_weight_2 = float(doc_sigmoid_weight_2)
        self._doc_kld_weight = self._initial_kld_weight(
            "doc", doc_kl_weight_annealing, self._doc_linear_scaling,
            self._doc_sigmoid_weight_1, self._doc_sigmoid_weight_2, doc_period)
        if doc_kl_weight_annealing == "cyclic-linear":
            # Cycle bookkeeping exists only for the cyclic schedule.
            self._doc_period = doc_period
            self._doc_saturation_period = doc_saturation_period
            self._doc_cyclic_kl_anneal_tracker = 0

        # Entity-level KL annealing.
        self._entity_kl_weight_annealing = entity_kl_weight_annealing
        self._entity_linear_scaling = float(entity_linear_scaling)
        self._entity_sigmoid_weight_1 = float(entity_sigmoid_weight_1)
        self._entity_sigmoid_weight_2 = float(entity_sigmoid_weight_2)
        self._entity_kld_weight = self._initial_kld_weight(
            "entity", entity_kl_weight_annealing, self._entity_linear_scaling,
            self._entity_sigmoid_weight_1, self._entity_sigmoid_weight_2,
            entity_period)
        if entity_kl_weight_annealing == "cyclic-linear":
            self._entity_period = entity_period
            self._entity_saturation_period = entity_saturation_period
            self._entity_cyclic_kl_anneal_tracker = 0

        # Batchnorm on the document reconstruction. The scale is pinned at 1 and
        # frozen; the bias keeps BatchNorm1d's default requires_grad=True.
        self.doc_bow_bn = torch.nn.BatchNorm1d(vocab_size,
                                               eps=0.001,
                                               momentum=0.001,
                                               affine=True)
        self.doc_bow_bn.weight.data.fill_(1.0)
        self.doc_bow_bn.weight.requires_grad = False

        # Maintain these states for periodically printing topics and updating KLD
        self._metric_epoch_tracker = 0
        self._kl_epoch_tracker = 0
        self._cur_epoch = 0
        self._cur_entity_npmi = 0.0
        self._cur_doc_npmi = 0.0
        self.batch_num = 0

        initializer(self)

    @staticmethod
    def _initial_kld_weight(name: str,
                            annealing: str,
                            linear_scaling: float,
                            sigmoid_weight_1: float,
                            sigmoid_weight_2: float,
                            period: int) -> float:
        """Return the starting KL weight for one annealing schedule.

        ``name`` ("doc" or "entity") is used only in error messages.

        Raises
        ------
        ConfigurationError
            For an unknown ``annealing`` type, or for a non-positive
            ``linear_scaling`` / ``period`` that would divide by zero.
        """
        if annealing == "linear":
            if linear_scaling <= 0:
                raise ConfigurationError(
                    "{}_linear_scaling must be positive, got {}".format(name, linear_scaling))
            return min(1.0, 1 / linear_scaling)
        if annealing == "sigmoid":
            return float(1 / (1 + np.exp(-sigmoid_weight_1 * (1 - sigmoid_weight_2))))
        if annealing == "constant":
            return 1.0
        if annealing == "cyclic-linear":
            if period <= 0:
                raise ConfigurationError(
                    "{}_period must be positive, got {}".format(name, period))
            return 1 / period
        raise ConfigurationError("anneal type({}) {} not found".format(name, annealing))