def __init__(self,
                 vocab: Vocabulary,
                 bow_embedder: TokenEmbedder,
                 vae: VAE,
                 apply_batchnorm_on_recon: bool = False,
                 batchnorm_weight_learnable: bool = False,
                 batchnorm_bias_learnable: bool = True,
                 kl_weight_annealing: str = "constant",
                 linear_scaling: float = 1000.0,
                 sigmoid_weight_1: float = 0.25,
                 sigmoid_weight_2: float = 15,
                 reference_counts: str = None,
                 reference_vocabulary: str = None,
                 use_doc_info: bool = False,
                 use_background: bool = False,
                 background_data_path: str = None,
                 update_background_freq: bool = False,
                 track_topics: bool = True,
                 track_npmi: bool = True,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        """Set up the model around ``vae`` plus optional NPMI reference statistics.

        Parameters
        ----------
        vocab : Vocabulary
            Model vocabulary; the size of its ``"persona_based"`` namespace sizes
            the optional reconstruction batchnorm.
        bow_embedder : TokenEmbedder
            Stored as ``self._bag_of_words_embedder``.
        vae : VAE
            Variational component, stored as ``self.vae``.
        apply_batchnorm_on_recon : bool
            If true, create ``self.bow_bn`` over the vocabulary dimension.
        batchnorm_weight_learnable, batchnorm_bias_learnable : bool
            Forwarded to ``create_trainable_BatchNorm1d``.
        kl_weight_annealing : str
            ``"linear"``, ``"sigmoid"`` or ``"constant"``; selects how the initial
            ``self._kld_weight`` is computed.
        linear_scaling : float
            Linear schedule: initial weight is ``min(1, 1 / linear_scaling)``.
        sigmoid_weight_1, sigmoid_weight_2 : float
            Sigmoid schedule: initial weight is
            ``1 / (1 + exp(-sigmoid_weight_1 * (1 - sigmoid_weight_2)))``.
        reference_counts : str
            Path (resolved through ``cached_path``) to the reference count matrix
            loaded with ``load_sparse``; its rows are counted as documents
            (``self.n_docs = shape[0]``). Required when ``reference_vocabulary``
            is given.
        reference_vocabulary : str
            Path to a JSON vocabulary for the reference counts. When given, the
            NPMI numerator/denominator are precomputed here.
        use_doc_info : bool
            If true, create the 2-element learnable ``self.interpolation``.
        use_background : bool
            If true, load background frequencies from ``background_data_path``
            via ``initialize_bg_from_file``; otherwise the background is ``0``.
        background_data_path : str
            See ``use_background``.
        update_background_freq, track_topics, track_npmi : bool
            Flags stored on the instance for use outside the constructor.
        initializer : InitializerApplicator
            Applied to ``self`` as the last construction step.
        regularizer : RegularizerApplicator, optional
            Passed to the superclass constructor.

        Raises
        ------
        ConfigurationError
            If ``kl_weight_annealing`` is not one of the supported schedules, or
            if ``reference_vocabulary`` is given without ``reference_counts``.
        """
        super().__init__(vocab, regularizer)

        self.metrics = {'nkld': Average(), 'nll': Average(), 'perp': Average()}

        self.vocab = vocab
        self.vae = vae
        self.track_topics = track_topics
        self.track_npmi = track_npmi
        self.vocab_namespace = "persona_based"
        self._update_background_freq = update_background_freq

        vocab_size = self.vocab.get_vocab_size(self.vocab_namespace)
        self._use_doc_info = use_doc_info
        if use_doc_info:
            # Two learnable scalars, zero-initialised; how they are combined is defined outside __init__.
            self.interpolation = torch.nn.Parameter(torch.zeros(2, requires_grad=True))
        self._background_freq = self.initialize_bg_from_file(file_=background_data_path) if use_background else 0
        # Was a bare print(); go through the module logger so the output can be controlled.
        logger.debug("Background frequency: %s", self._background_freq)
        self._ref_counts = reference_counts

        if reference_vocabulary is not None:
            # Fail fast with a clear message rather than handing None to cached_path below.
            if reference_counts is None:
                raise ConfigurationError("reference_counts must be set when reference_vocabulary is given")
            # Compute data necessary to compute NPMI every epoch
            logger.info("Loading reference vocabulary.")
            self._ref_vocab = read_json(cached_path(reference_vocabulary))
            # Word -> position in the reference vocabulary.
            self._ref_vocab_index = dict(zip(self._ref_vocab, range(len(self._ref_vocab))))
            logger.info("Loading reference count matrix.")
            self._ref_count_mat = load_sparse(cached_path(self._ref_counts))
            logger.info("Computing word interaction matrix.")
            # Binarise: 1.0 wherever a word occurs in a row at all.
            self._ref_doc_counts = (self._ref_count_mat > 0).astype(float)
            # Column x column co-occurrence: number of rows in which both words occur.
            self._ref_interaction = self._ref_doc_counts.T.dot(self._ref_doc_counts)
            # Per-column number of rows containing the word, flattened to a 1-D array.
            self._ref_doc_sum = np.array(self._ref_doc_counts.sum(0).tolist()[0])
            logger.info("Generating npmi matrices.")
            (self._npmi_numerator,
             self._npmi_denominator) = self.generate_npmi_vals(self._ref_interaction,
                                                               self._ref_doc_sum)
            self.n_docs = self._ref_count_mat.shape[0]

        self._bag_of_words_embedder = bow_embedder

        self._kl_weight_annealing = kl_weight_annealing

        self._linear_scaling = float(linear_scaling)
        self._sigmoid_weight_1 = float(sigmoid_weight_1)
        self._sigmoid_weight_2 = float(sigmoid_weight_2)
        # Initial KL weight; later updates happen outside __init__ (see _kl_epoch_tracker).
        if kl_weight_annealing == "linear":
            self._kld_weight = min(1.0, 1 / self._linear_scaling)
        elif kl_weight_annealing == "sigmoid":
            self._kld_weight = float(1 / (1 + np.exp(-self._sigmoid_weight_1 * (1 - self._sigmoid_weight_2))))
        elif kl_weight_annealing == "constant":
            self._kld_weight = 1.0
        else:
            raise ConfigurationError("anneal type {} not found".format(kl_weight_annealing))

        # setup batchnorm
        self._apply_batchnorm_on_recon = apply_batchnorm_on_recon
        if apply_batchnorm_on_recon:
            self.bow_bn = create_trainable_BatchNorm1d(vocab_size,
                                                       weight_learnable=batchnorm_weight_learnable,
                                                       bias_learnable=batchnorm_bias_learnable,
                                                       eps=0.001, momentum=0.001, affine=True)

        # Maintain these states for periodically printing topics and updating KLD
        self._metric_epoch_tracker = 0
        self._kl_epoch_tracker = 0
        self._cur_epoch = 0
        self._cur_npmi = 0.0
        self.batch_num = 0

        initializer(self)
예제 #2
0
    def __init__(
            self,
            vocab,
            encoder_topic: FeedForward,
            mean_projection_topic: FeedForward,
            log_variance_projection_topic: FeedForward,
            encoder_entity: FeedForward,
            mean_projection_entity: FeedForward,
            log_variance_projection_entity: FeedForward,
            encoder_entity_approx: FeedForward,
            mean_projection_entity_approx: FeedForward,
            log_variance_projection_entity_approx: FeedForward,
            decoder_topic: FeedForward,  # decode topic input to persona hidden
            decoder_mean_projection_topic: FeedForward,
            decoder_log_variance_projection_topic: FeedForward,
            decoder_persona: FeedForward,
            prior: Dict = {
                "type": "normal",
                "mu": 0,
                "var": 1
            },
            apply_batchnorm_on_normal: bool = False,
            apply_batchnorm_on_decoder: bool = False,
            batchnorm_weight_learnable: bool = False,
            batchnorm_bias_learnable: bool = True,
            stochastic_beta: bool = False,
            z_dropout: float = 0.2) -> None:
        """Build the topic / entity encoders, decoders and optional batchnorms.

        ``decoder_topic`` and ``decoder_persona`` are used only for their input and
        output dimensions: each is replaced by a bias-free ``torch.nn.Linear`` of
        the same shape. ``num_topic`` is ``encoder_topic``'s output dim and
        ``num_persona`` is ``encoder_entity``'s output dim. ``prior`` is stored and
        handed to ``initialize_prior``.

        With ``apply_batchnorm_on_normal``, batchnorms are created for the topic
        mean / log-variance (size ``num_topic``) and the decoder topic mean /
        log-variance (size ``num_persona``). With ``apply_batchnorm_on_decoder``,
        batchnorms are created over each decoder's output dimension. Batchnorm
        attributes that are not created are left as ``None``.
        """
        super(BasicLVAE, self).__init__(vocab)

        def _batchnorm(num_features):
            # Every batchnorm in this model shares the same hyper-parameters.
            return create_trainable_BatchNorm1d(
                num_features,
                weight_learnable=batchnorm_weight_learnable,
                bias_learnable=batchnorm_bias_learnable,
                eps=0.001,
                momentum=0.001,
                affine=True)

        self.encoder_topic = encoder_topic
        self.mean_topic = mean_projection_topic
        self.log_var_topic = log_variance_projection_topic
        self.encoder_entity = encoder_entity
        self.mean_entity = mean_projection_entity
        self.log_var_entity = log_variance_projection_entity
        self.encoder_entity_approx = encoder_entity_approx
        self.mean_entity_approx = mean_projection_entity_approx
        self.log_var_entity_approx = log_variance_projection_entity_approx
        self.decoder_mean_topic = decoder_mean_projection_topic
        self.decoder_log_var_topic = decoder_log_variance_projection_topic
        # Only the FeedForward dimensions are used; the weights live in bias-free Linears.
        self._decoder_topic = torch.nn.Linear(decoder_topic.get_input_dim(),
                                              decoder_topic.get_output_dim(),
                                              bias=False)
        self._decoder_persona = torch.nn.Linear(
            decoder_persona.get_input_dim(),
            decoder_persona.get_output_dim(),
            bias=False)
        self._z_dropout = torch.nn.Dropout(z_dropout)

        self.num_topic = encoder_topic.get_output_dim()
        self.num_persona = encoder_entity.get_output_dim()

        self.prior = prior
        self.p_params = None
        self.initialize_prior(prior)

        # If specified, established batchnorm for both mean and log variance.
        self._apply_batchnorm_on_normal = apply_batchnorm_on_normal
        # NOTE(review): the entity / entity_approx batchnorms are never created below,
        # even when apply_batchnorm_on_normal is true — confirm this is intended.
        self.mean_bn_entity, self.log_var_bn_entity = None, None
        self.mean_bn_topic, self.log_var_bn_topic = None, None
        self.decoder_mean_bn_topic, self.decoder_log_var_bn_topic = None, None
        self.mean_bn_entity_approx, self.log_var_bn_entity_approx = None, None
        if apply_batchnorm_on_normal:
            self.mean_bn_topic = _batchnorm(self.num_topic)
            self.log_var_bn_topic = _batchnorm(self.num_topic)
            self.decoder_mean_bn_topic = _batchnorm(self.num_persona)
            self.decoder_log_var_bn_topic = _batchnorm(self.num_persona)

        # If specified, established batchnorm for reconstruction matrix, applying batch norm across vocabulary
        self._apply_batchnorm_on_decoder = apply_batchnorm_on_decoder
        # Defined up front (as for the normal batchnorms) so the attributes always exist.
        self.decoder_bn_topic, self.decoder_bn_persona = None, None
        if apply_batchnorm_on_decoder:
            self.decoder_bn_topic = _batchnorm(decoder_topic.get_output_dim())
            self.decoder_bn_persona = _batchnorm(decoder_persona.get_output_dim())
        # If specified, constrain each topic to be a distribution over vocabulary
        self._stochastic_beta = stochastic_beta
예제 #3
0
    def __init__(self,
                 vocab,
                 extracter: Seq2VecEncoder,
                 encoder: FeedForward,
                 mean_projection: FeedForward,
                 log_variance_projection: FeedForward,
                 decoder: FeedForward,
                 prior: Dict = {"type": "normal", "mu": 0, "var": 1},
                 apply_batchnorm_on_normal: bool = False,
                 apply_batchnorm_on_decoder: bool = False,
                 batchnorm_weight_learnable: bool = False,
                 batchnorm_bias_learnable: bool = True,
                 stochastic_beta: bool = False,
                 z_dropout: float = 0.2) -> None:
        """Build the CNN-feature VAE and register its fixed prior as buffers.

        ``decoder`` is used only for its dimensions; the weights live in a
        bias-free ``torch.nn.Linear``. ``latent_dim`` is ``mean_projection``'s
        output dim.

        ``prior`` must contain ``"type"``:

        * ``"normal"``: requires ``"mu"`` and ``"var"``; both are broadcast to a
          ``(1, latent_dim)`` tensor.
        * ``"laplace-approx"``: requires ``"alpha"``; mean and variance are
          derived from a constant ``alpha`` vector.

        ``p_mu`` and ``p_log_var`` are registered as buffers, not parameters, so
        they are not trained. Batchnorm attributes that are not requested are
        left as ``None``.

        Raises
        ------
        Exception
            If the prior type is missing or unknown, or its required keys are absent.
        """
        super(CNNVAE, self).__init__(vocab)
        self.extracter = extracter
        self.encoder = encoder
        self.mean_projection = mean_projection
        self.log_variance_projection = log_variance_projection
        self._decoder = torch.nn.Linear(decoder.get_input_dim(), decoder.get_output_dim(),
                                        bias=False)
        self._z_dropout = torch.nn.Dropout(z_dropout)

        self.latent_dim = mean_projection.get_output_dim()
        self.prior = prior

        # .get so a missing "type" reaches the "Invalid/Undefined prior!" branch instead of a KeyError.
        prior_type = prior.get('type')
        if prior_type == "normal":
            if 'mu' not in prior or 'var' not in prior:
                raise Exception("MU, VAR undefined for normal")
            p_mu = torch.zeros(1, self.latent_dim).fill_(prior['mu'])
            p_var = torch.zeros(1, self.latent_dim).fill_(prior['var'])
            p_log_var = p_var.log()
        elif prior_type == "laplace-approx":
            # Mirror the explicit key check done for the normal prior.
            if 'alpha' not in prior:
                raise Exception("ALPHA undefined for laplace-approx")
            a = torch.zeros(1, self.latent_dim).fill_(prior['alpha'])
            # Appears to be the softmax-basis Laplace approximation of a symmetric
            # Dirichlet(alpha) prior (as in ProdLDA); alpha must be > 0 for log().
            p_mu = a.log() - torch.mean(a.log(), 1)
            p_var = 1.0 / a * (1 - 2.0 / self.latent_dim) + 1.0 / self.latent_dim * torch.mean(1 / a)
            p_log_var = p_var.log()
        else:
            raise Exception("Invalid/Undefined prior!")

        # parameters of prior distribution are not trainable
        self.register_buffer("p_mu", p_mu)
        self.register_buffer("p_log_var", p_log_var)

        # If specified, established batchnorm for both mean and log variance.
        self._apply_batchnorm_on_normal = apply_batchnorm_on_normal
        # Defined up front so the attributes exist even when batchnorm is off.
        self.mean_bn, self.log_var_bn = None, None
        if apply_batchnorm_on_normal:
            self.mean_bn = create_trainable_BatchNorm1d(self.latent_dim,
                                                        weight_learnable=batchnorm_weight_learnable,
                                                        bias_learnable=batchnorm_bias_learnable,
                                                        eps=0.001, momentum=0.001, affine=True)

            self.log_var_bn = create_trainable_BatchNorm1d(self.latent_dim,
                                                           weight_learnable=batchnorm_weight_learnable,
                                                           bias_learnable=batchnorm_bias_learnable,
                                                           eps=0.001, momentum=0.001, affine=True)
        # If specified, established batchnorm for reconstruction matrix, applying batch norm across vocabulary
        self._apply_batchnorm_on_decoder = apply_batchnorm_on_decoder
        self.decoder_bn = None
        if apply_batchnorm_on_decoder:
            self.decoder_bn = create_trainable_BatchNorm1d(decoder.get_output_dim(),
                                                           weight_learnable=batchnorm_weight_learnable,
                                                           bias_learnable=batchnorm_bias_learnable,
                                                           eps=0.001, momentum=0.001, affine=True)

        # If specified, constrain each topic to be a distribution over vocabulary
        self._stochastic_beta = stochastic_beta
예제 #4
0
    def __init__(self,
                 vocab,
                 extracter: Seq2VecEncoder,
                 encoder_d1: FeedForward,
                 encoder_d2: FeedForward,
                 mean_projection_d1: FeedForward,
                 log_variance_projection_d1: FeedForward,
                 mean_projection_d2: FeedForward,
                 log_variance_projection_d2: FeedForward,
                 encoder_t1: FeedForward,
                 mean_projection_t1: FeedForward,
                 log_variance_projection_t1: FeedForward,
                 decoder1: FeedForward,
                 decoder2: FeedForward,
                 mean_projection_dec2: FeedForward,
                 log_variance_projection_dec2: FeedForward,
                 prior: Dict = {"type": "normal", "mu": 0, "var": 1},
                 apply_batchnorm_on_normal: bool = False,
                 apply_batchnorm_on_decoder: bool = False,
                 batchnorm_weight_learnable: bool = False,
                 batchnorm_bias_learnable: bool = True,
                 stochastic_beta: bool = False,
                 z_dropout: float = 0.2) -> None:
        """Build the two-layer ladder VAE (d1 / d2 / t1 encoders, two decoders).

        ``decoder1`` and ``decoder2`` are used only for their dimensions; the
        weights live in bias-free ``torch.nn.Linear`` layers. ``num_persona`` is
        ``mean_projection_d1``'s output dim and ``num_topic`` is
        ``mean_projection_d2``'s. ``prior`` is stored and handed to
        ``initialize_prior``.

        With ``apply_batchnorm_on_normal``, mean / log-variance batchnorms are
        created for d1, t1 and dec2 (size ``num_persona``) and d2 (size
        ``num_topic``). With ``apply_batchnorm_on_decoder``, batchnorms are created
        over each decoder's output dimension. Batchnorm attributes that are not
        created are left as ``None``.
        """
        super(LadderVAE, self).__init__(vocab)

        def _batchnorm(num_features):
            # Every batchnorm in this model shares the same hyper-parameters.
            return create_trainable_BatchNorm1d(num_features,
                                                weight_learnable=batchnorm_weight_learnable,
                                                bias_learnable=batchnorm_bias_learnable,
                                                eps=0.001, momentum=0.001, affine=True)

        self.extracter = extracter
        self.encoder_d1 = encoder_d1
        self.mean_projection_d1 = mean_projection_d1
        self.log_variance_projection_d1 = log_variance_projection_d1
        self.encoder_d2 = encoder_d2
        self.mean_projection_d2 = mean_projection_d2
        self.log_variance_projection_d2 = log_variance_projection_d2
        self.encoder_t1 = encoder_t1
        self.mean_projection_t1 = mean_projection_t1
        self.log_variance_projection_t1 = log_variance_projection_t1

        # Only the FeedForward dimensions are used; the weights live in bias-free Linears.
        self._decoder1 = torch.nn.Linear(decoder1.get_input_dim(), decoder1.get_output_dim(),
                                         bias=False)
        self._decoder2 = torch.nn.Linear(decoder2.get_input_dim(), decoder2.get_output_dim(),
                                         bias=False)
        self.mean_projection_dec2 = mean_projection_dec2
        self.log_variance_projection_dec2 = log_variance_projection_dec2

        self._z_dropout = torch.nn.Dropout(z_dropout)

        self.num_persona = mean_projection_d1.get_output_dim()
        self.num_topic = mean_projection_d2.get_output_dim()

        self.prior = prior
        self.p_params = None
        self.initialize_prior(prior)

        # If specified, established batchnorm for both mean and log variance.
        self._apply_batchnorm_on_normal = apply_batchnorm_on_normal
        self.mean_bn_d1, self.log_var_bn_d1 = None, None
        self.mean_bn_d2, self.log_var_bn_d2 = None, None
        self.mean_bn_t1, self.log_var_bn_t1 = None, None
        # dec2 was previously the only pair not pre-set to None.
        self.mean_bn_dec2, self.log_var_bn_dec2 = None, None
        if apply_batchnorm_on_normal:
            self.mean_bn_d1 = _batchnorm(self.num_persona)
            self.log_var_bn_d1 = _batchnorm(self.num_persona)
            self.mean_bn_d2 = _batchnorm(self.num_topic)
            self.log_var_bn_d2 = _batchnorm(self.num_topic)
            self.mean_bn_t1 = _batchnorm(self.num_persona)
            self.log_var_bn_t1 = _batchnorm(self.num_persona)
            # NOTE(review): sized by num_persona; assumes mean_projection_dec2 outputs num_persona features — confirm.
            self.mean_bn_dec2 = _batchnorm(self.num_persona)
            self.log_var_bn_dec2 = _batchnorm(self.num_persona)

        # If specified, established batchnorm for reconstruction matrix, applying batch norm across vocabulary
        self._apply_batchnorm_on_decoder = apply_batchnorm_on_decoder
        self.decoder_bn1, self.decoder_bn2 = None, None
        if apply_batchnorm_on_decoder:
            self.decoder_bn1 = _batchnorm(decoder1.get_output_dim())
            self.decoder_bn2 = _batchnorm(decoder2.get_output_dim())

        # If specified, constrain each topic to be a distribution over vocabulary
        self._stochastic_beta = stochastic_beta
예제 #5
0
    def __init__(
            self,
            vocab,
            encoder_entity: FeedForward,
            encoder_entity_global: FeedForward,
            decoder_type: FeedForward,  # (d_dim -> P)
            mean_projection_type: FeedForward,
            log_var_projection_type: FeedForward,
            decoder_topic: FeedForward,  # (K -> V)
            decoder_persona: FeedForward,  # (P -> K)
            prior: Dict = {
                "type": "normal",
                "mu": 0,
                "var": 1
            },
            pooling_layer: str = "max",
            apply_batchnorm_on_normal: bool = False,
            apply_batchnorm_on_decoder: bool = False,
            batchnorm_weight_learnable: bool = False,
            batchnorm_bias_learnable: bool = True,
            stochastic_weight: bool = False,
            z_dropout: float = 0.2) -> None:
        """Build the entity encoders, type / topic / persona decoders and batchnorms.

        The three ``decoder_*`` FeedForwards are used only for their dimensions;
        each is replaced by a bias-free ``torch.nn.Linear``. ``num_type`` is
        ``decoder_type``'s input dim, ``num_persona`` and ``num_topic`` are
        ``decoder_persona``'s input and output dims. ``prior`` is stored and handed
        to ``initialize_prior``.

        ``pooling_layer`` must be ``"max"``, ``"sum"`` or ``"mean"``; the matching
        ``torch`` function is stored as ``self.pooling_layer``. Batchnorm
        attributes that are not created are left as ``None``.

        Raises
        ------
        Exception
            If ``pooling_layer`` is not one of the supported names.
        """
        super(Bamman, self).__init__(vocab)

        def _batchnorm(num_features):
            # Every batchnorm in this model shares the same hyper-parameters.
            return create_trainable_BatchNorm1d(
                num_features,
                weight_learnable=batchnorm_weight_learnable,
                bias_learnable=batchnorm_bias_learnable,
                eps=0.001,
                momentum=0.001,
                affine=True)

        self.encoder_entity = encoder_entity
        self.encoder_entity_global = encoder_entity_global
        self.mean_projection_type = mean_projection_type
        self.log_var_projection_type = log_var_projection_type

        # Only the FeedForward dimensions are used; the weights live in bias-free Linears.
        self._decoder_type = torch.nn.Linear(decoder_type.get_input_dim(),
                                             decoder_type.get_output_dim(),
                                             bias=False)
        self._decoder_topic = torch.nn.Linear(decoder_topic.get_input_dim(),
                                              decoder_topic.get_output_dim(),
                                              bias=False)
        self._decoder_persona = torch.nn.Linear(
            decoder_persona.get_input_dim(),
            decoder_persona.get_output_dim(),
            bias=False)
        self._z_dropout = torch.nn.Dropout(z_dropout)

        self.num_type = decoder_type.get_input_dim()
        self.num_topic = decoder_persona.get_output_dim()
        self.num_persona = decoder_persona.get_input_dim()

        self.prior = prior
        if pooling_layer not in ["max", "sum", "mean"]:
            raise Exception("Undefined pooling function")
        self.pooling_func = pooling_layer
        # NOTE(review): torch.max(x, dim) returns a (values, indices) tuple while
        # torch.sum / torch.mean return a tensor — callers must handle both.
        self.pooling_layer = getattr(torch, pooling_layer)
        self.p_params = None
        self.initialize_prior(prior)

        # If specified, established batchnorm for both mean and log variance.
        self._apply_batchnorm_on_normal = apply_batchnorm_on_normal
        self.mean_bn_type, self.log_var_bn_type = None, None
        if apply_batchnorm_on_normal:
            self.mean_bn_type = _batchnorm(self.num_type)
            self.log_var_bn_type = _batchnorm(self.num_type)

        # If specified, established batchnorm for reconstruction matrix, applying batch norm across vocabulary
        self._apply_batchnorm_on_decoder = apply_batchnorm_on_decoder
        # Defined up front, like the normal batchnorms, so the attributes always exist.
        self.decoder_bn_type = None
        self.decoder_bn_topic = None
        self.decoder_bn_persona = None
        if apply_batchnorm_on_decoder:
            self.decoder_bn_type = _batchnorm(decoder_type.get_output_dim())
            self.decoder_bn_topic = _batchnorm(decoder_topic.get_output_dim())
            self.decoder_bn_persona = _batchnorm(decoder_persona.get_output_dim())
        # If specified, constrain each topic to be a distribution over vocabulary
        self._stochastic_weight = stochastic_weight
예제 #6
0
    def __init__(
            self,
            vocab,
            encoder: FeedForward,
            mean_projection: FeedForward,
            log_variance_projection: FeedForward,
            decoder: FeedForward,
            apply_batchnorm: bool = False,
            z_dropout: float = 0.2,
            batchnorm_weight_learnable: bool = False,
            batchnorm_bias_learnable: bool = True) -> None:
        """Build the logistic-normal encoder / decoder and optional batchnorms.

        ``decoder`` is used only for its dimensions; the weights live in a
        bias-free ``torch.nn.Linear``. ``latent_dim`` is ``mean_projection``'s
        output dim.

        ``batchnorm_weight_learnable`` / ``batchnorm_bias_learnable`` were
        previously commented out of the signature while still being used below,
        so ``apply_batchnorm=True`` raised ``NameError``. They are appended as
        keyword parameters, keeping existing positional calls valid.

        When ``apply_batchnorm`` is true, mean and log-variance batchnorms of size
        ``latent_dim`` are created; otherwise both attributes are ``None``.
        """
        super(LogisticNormal, self).__init__(vocab)
        self.encoder = encoder
        self.mean_projection = mean_projection
        self.log_variance_projection = log_variance_projection
        self._decoder = torch.nn.Linear(decoder.get_input_dim(),
                                        decoder.get_output_dim(),
                                        bias=False)
        self._z_dropout = torch.nn.Dropout(z_dropout)

        self.latent_dim = mean_projection.get_output_dim()

        # If specified, established batchnorm for both mean and log variance.
        self._apply_batchnorm = apply_batchnorm
        self.mean_bn, self.log_var_bn = None, None
        if apply_batchnorm:

            self.mean_bn = create_trainable_BatchNorm1d(
                self.latent_dim,
                weight_learnable=batchnorm_weight_learnable,
                bias_learnable=batchnorm_bias_learnable,
                eps=0.001,
                momentum=0.001,
                affine=True)
            self.log_var_bn = create_trainable_BatchNorm1d(
                self.latent_dim,
                weight_learnable=batchnorm_weight_learnable,
                bias_learnable=batchnorm_bias_learnable,
                eps=0.001,
                momentum=0.001,
                affine=True)