예제 #1
0
    def test_get_att_over_input(self):
        """Checks that review hiddens of the same group get concatenated."""
        hiddens = T.tensor([[[0.1, 0.2], [10.1, 11.1]],
                            [[1, 2], [10, 11]],
                            [[3, 4], [30, 31]],
                            [[5, 6], [50, 51]],
                            [[7, 8], [70, 71]]])
        mask = T.tensor([[1., 1.], [1., 1.], [1., 0.], [1., 1.], [1., 0.]])
        indxs = T.tensor([[3, 0], [1, 2], [4, 0]])
        indxs_mask = T.tensor([[1., 1.], [1., 1.], [1., 0.]])

        # Expected layout: hiddens of each group's reviews laid out
        # back-to-back, zero-padded where the group indxs mask is 0.
        exp_vals = T.tensor([[[5., 6.], [50, 51], [0.1, 0.2], [10.1, 11.1]],
                             [[1, 2], [10, 11], [3, 4], [30, 31]],
                             [[7, 8], [70, 71], [0., 0.], [0., 0.]]])
        exp_vals_mask = T.tensor([[1., 1., 1., 1.],
                                  [1., 1., 1., 0.],
                                  [1., 0., 0., 0.]])

        act_vals, act_keys, act_vals_mask = group_att_over_input(
            inp_att_vals=hiddens,
            inp_att_keys=hiddens,
            inp_att_mask=mask,
            att_indxs=indxs,
            att_indxs_mask=indxs_mask)

        # Keys and values were fed the same tensor, so both must match.
        self.assertTrue((exp_vals == act_keys).all())
        self.assertTrue((exp_vals == act_vals).all())
        self.assertTrue((exp_vals_mask == act_vals_mask).all())
예제 #2
0
    def generate_summaries(self, batch, **kwargs):
        """Generates only summaries; simplified script for inference.

        :param batch: mapping from ModelF fields to tensors with review
            word ids, lengths, masks, and group-membership indxs.
        :param kwargs: extra options forwarded to the beamer.
        :return: word ids of the generated summaries.
        """
        self.model.eval()
        revs = batch[ModelF.REV].to(self.device)
        rev_lens = batch[ModelF.REV_LEN].to(self.device)
        revs_mask = batch[ModelF.REV_MASK].to(self.device)
        group_rev_indxs = batch[ModelF.GROUP_REV_INDXS].to(self.device)

        # The group indxs mask is optional; None means no padded entries.
        if ModelF.GROUP_REV_INDXS_MASK in batch:
            summ_rev_indxs_mask = batch[ModelF.GROUP_REV_INDXS_MASK].to(
                self.device)
        else:
            summ_rev_indxs_mask = None

        max_rev_len = revs.size(1)
        summs_nr = group_rev_indxs.size(0)

        with T.no_grad():
            rev_embds = self.model._embds(revs)
            # Only the per-step hiddens are needed on this path; the
            # sequence-level encodings are discarded.
            _, rev_hiddens = self.model.encode(rev_embds, rev_lens)

            att_keys = self.model.create_att_keys(rev_hiddens)
            contxt_states = self.model.get_contxt_states(
                rev_hiddens, rev_embds)

            # Attention targets for summaries: each summary attends over
            # the concatenated hiddens of all reviews in its group.
            summ_att_keys, \
            summ_att_vals, \
            summ_att_mask = group_att_over_input(inp_att_keys=att_keys,
                                                 inp_att_vals=rev_hiddens,
                                                 inp_att_mask=revs_mask,
                                                 att_indxs=group_rev_indxs,
                                                 att_indxs_mask=summ_rev_indxs_mask)

            # Word ids aligned with the attention values (used for copying).
            summ_att_word_ids = revs[group_rev_indxs].view(summs_nr, -1)

            # Only the posterior mean of c is needed for generation.
            c_mu_q, _, _ = self.model.get_q_c_mu_sigma(
                contxt_states, revs_mask, group_rev_indxs, summ_rev_indxs_mask)

            if self.min_sen_seq_len is not None:
                min_lens = [self.min_sen_seq_len] * summs_nr
            else:
                min_lens = None

            # Decode from the prior mean of z, conditioned on c's mean.
            z_mu_p, _ = self.model.get_p_z_mu_sigma(c_mu_q)

            init_summ_dec_state = DecState(rec_vals={"hidden": z_mu_p})
            summ_word_ids, _ = self.beamer(init_summ_dec_state,
                                           min_lens=min_lens,
                                           max_steps=max_rev_len,
                                           att_keys=summ_att_keys,
                                           att_values=summ_att_vals,
                                           att_mask=summ_att_mask,
                                           att_word_ids=summ_att_word_ids,
                                           minimum=1, **kwargs)

            return summ_word_ids
예제 #3
0
    def get_q_c_mu_sigma(self, states, states_mask, group_indxs,
                         group_indxs_mask):
        """Computes the parameters (mu, sigma) of c's approximate posterior.

        :param states: [batch_size, seq_len1, dim]
            representations of review steps, e.g. hidden + embd.
        :param states_mask: [batch_size, seq_len1]
        :param group_indxs: [batch_size2, seq_len2]
            indxs of reviews belonging to the same product.
        :param group_indxs_mask: [batch_size2, seq_len2]
        """
        # Collect the step representations of each group's reviews into a
        # single per-group sequence.
        grp_states, grp_mask = group_att_over_input(
            inp_att_vals=states,
            inp_att_mask=states_mask,
            att_indxs=group_indxs,
            att_indxs_mask=group_indxs_mask)
        # Weighted sum over the grouped steps -> one state per group.
        ws_state, score_weights = self._compute_ws_state(
            states=grp_states, states_mask=grp_mask)
        mu, sigma = self._c_inf_network(ws_state)

        return mu, sigma, score_weights
예제 #4
0
    def predict(self, batch, **kwargs):
        """Predicts summaries and reviews.

        :param batch: mapping from ModelF fields to tensors with review
            word ids, lengths, masks, group indxs and leave-one-out indxs.
        :param kwargs: extra options forwarded to the beamer.
        :return: rev_word_ids, rev_coll_vals, summ_word_ids, summ_coll_vals.
        """
        self.model.eval()
        revs = batch[ModelF.REV].to(self.device)
        rev_lens = batch[ModelF.REV_LEN].to(self.device)
        revs_mask = batch[ModelF.REV_MASK].to(self.device)
        summ_rev_indxs = batch[ModelF.GROUP_REV_INDXS].to(self.device)
        summ_rev_indxs_mask = batch[ModelF.GROUP_REV_INDXS_MASK].to(
            self.device)
        other_revs = batch[ModelF.OTHER_REV_INDXS].to(self.device)
        other_revs_mask = batch[ModelF.OTHER_REV_INDXS_MASK].to(self.device)
        rev_to_group_indx = batch[ModelF.REV_TO_GROUP_INDX].to(self.device)

        bs = revs.size(0)
        max_rev_len = revs.size(1)
        summs_nr = summ_rev_indxs.size(0)

        with T.no_grad():
            rev_embds = self.model._embds(revs)
            rev_encs, rev_hiddens = self.model.encode(rev_embds, rev_lens)

            att_keys = self.model.create_att_keys(rev_hiddens)
            contxt_states = self.model.get_contxt_states(
                rev_hiddens, rev_embds)

            # Attention targets for summaries: each summary attends over
            # the concatenated hiddens of all reviews in its group.
            summ_att_keys, \
            summ_att_vals, \
            summ_att_mask = group_att_over_input(inp_att_keys=att_keys,
                                                 inp_att_vals=rev_hiddens,
                                                 inp_att_mask=revs_mask,
                                                 att_indxs=summ_rev_indxs,
                                                 att_indxs_mask=summ_rev_indxs_mask)

            # Word ids aligned with the attention values (used for copying).
            summ_att_word_ids = revs[summ_rev_indxs].view(summs_nr, -1)

            c_mu_q, c_sigma_q, _ = self.model.get_q_c_mu_sigma(
                contxt_states, revs_mask, summ_rev_indxs, summ_rev_indxs_mask)

            # DECODING OF SUMMARIES #
            if self.min_sen_seq_len is not None:
                min_lens = [self.min_sen_seq_len] * summs_nr
            else:
                min_lens = None
            # z's prior is conditioned on the posterior mean of c.
            z_mu_p, z_sigma_p = self.model.get_p_z_mu_sigma(c_mu_q)

            init_summ_dec_state = DecState(rec_vals={"hidden": z_mu_p})
            summ_word_ids, \
            summ_coll_vals = self.beamer(init_summ_dec_state,
                                         min_lens=min_lens,
                                         max_steps=max_rev_len,
                                         att_keys=summ_att_keys,
                                         att_values=summ_att_vals,
                                         att_mask=summ_att_mask,
                                         att_word_ids=summ_att_word_ids,
                                         minimum=1, **kwargs)

            # DECODING OF REVIEWS #

            # Each review's z posterior is conditioned on the c mean of the
            # group it belongs to.
            z_mu_q, \
            z_sigma_q = self.model.get_q_z_mu_sigma(rev_encs,
                                                    c=c_mu_q[rev_to_group_indx])

            # Attention targets for the reviewer: the leave-one-out reviews.
            rev_att_keys, \
            rev_att_vals, \
            rev_att_vals_mask = group_att_over_input(inp_att_keys=att_keys,
                                                     inp_att_vals=rev_hiddens,
                                                     inp_att_mask=revs_mask,
                                                     att_indxs=other_revs,
                                                     att_indxs_mask=other_revs_mask)

            rev_att_word_ids = revs[other_revs].view(bs, -1)
            if self.min_sen_seq_len is not None:
                min_lens = [self.min_sen_seq_len] * bs
            else:
                min_lens = None

            init_rev_dec_state = DecState(rec_vals={"hidden": z_mu_q})
            rev_word_ids, \
            rev_coll_vals = self.beamer(init_rev_dec_state,
                                        min_lens=min_lens,
                                        max_steps=max_rev_len,
                                        att_keys=rev_att_keys,
                                        att_values=rev_att_vals,
                                        att_mask=rev_att_vals_mask,
                                        att_word_ids=rev_att_word_ids,
                                        minimum=1, **kwargs)

            return rev_word_ids, rev_coll_vals, summ_word_ids, summ_coll_vals
예제 #5
0
    def forward(self, rev, rev_len, rev_mask,
                group_rev_indxs, group_rev_indxs_mask,
                rev_to_group_indx, other_rev_indxs, other_rev_indxs_mask,
                other_rev_comp_states, other_rev_comp_states_mask,
                c_lambd=0., z_lambd=0.):
        """
        :param rev: review word ids.
            [batch_size, rev_seq_len]
        :param rev_len: review lengths.
            [batch_size]
        :param rev_mask: float mask where 0. is set to padded words.
            [batch_size, rev_seq_len]
        :param rev_to_group_indx: mapping from reviews to their corresponding
            groups.
            [batch_size]
        :param group_rev_indxs: indxs of reviews that belong to same groups.
            [group_count, max_rev_count]
        :param group_rev_indxs_mask: float mask where 0. is set to padded
            review indxs.
            [group_count, max_rev_count]
        :param other_rev_indxs: indxs of leave-one-out reviews.
            [batch_size, max_rev_count]
        :param other_rev_indxs_mask: float mask for leave-one-out reviews.
        :param other_rev_comp_states: indxs of (hidden) states of leave-one-out
            reviews. Used as an optimization to avoid attending over padded
            positions.
            [batch_size, cat_rev_len]
        :param other_rev_comp_states_mask: masking of states for leave-one-out
            reviews.
            [batch_size, cat_rev_len]
        :param c_lambd: annealing constant for c representations.
        :param z_lambd: annealing constant for z representations.

        :return loss: scalar loss corresponding to the mean ELBO over batches.
        :return metrs: additional statistics that are used for analytics and
            debugging.
        """
        bs = rev.size(0)
        device = rev.device
        group_count = group_rev_indxs.size(0)
        loss = 0.
        metrs = OrderedDict()

        rev_word_embds = self._embds(rev)
        rev_encs, rev_hiddens = self.encode(rev_word_embds, rev_len)

        att_keys = self.create_att_keys(rev_hiddens)
        contxt_states = self.get_contxt_states(rev_hiddens, rev_word_embds)

        # running the c inference network for the whole group
        c_mu_q, \
        c_sigma_q, \
        scor_wts = self.get_q_c_mu_sigma(contxt_states, rev_mask,
                                         group_rev_indxs, group_rev_indxs_mask)
        # c_mu_q: [group_count, context_dim]
        # c_sigma_q: [group_count, context_dim]
        c = re_parameterize(c_mu_q, c_sigma_q)

        # running the z inference network for each review; each review is
        # conditioned on the sampled c of its own group
        z_mu_q, z_sigma_q = self.get_q_z_mu_sigma(rev_encs,
                                                  c=c[rev_to_group_indx])
        z = re_parameterize(z_mu_q, z_sigma_q)

        # PERFORMING REVIEWS RECONSTRUCTION #

        # each review attends over its leave-one-out counterparts
        rev_att_keys, \
        rev_att_vals, \
        rev_att_mask = group_att_over_input(inp_att_keys=att_keys,
                                            inp_att_vals=rev_hiddens,
                                            inp_att_mask=rev_mask,
                                            att_indxs=other_rev_indxs,
                                            att_indxs_mask=other_rev_indxs_mask)

        rev_att_word_ids = rev[other_rev_indxs].view(bs, -1)

        # optimizing the attention targets by making more compact tensors
        # with less padded entries
        sel = T.arange(bs, device=device).unsqueeze(-1)
        rev_att_keys = rev_att_keys[sel, other_rev_comp_states]
        rev_att_vals = rev_att_vals[sel, other_rev_comp_states]
        rev_att_mask = other_rev_comp_states_mask
        rev_att_word_ids = rev_att_word_ids[sel, other_rev_comp_states]

        # creating an extra feature (the sampled z, tiled over time steps)
        # that is passed to every decoding step
        extra_feat = z.unsqueeze(1).repeat(1, rev_word_embds.size(1), 1)

        log_probs, rev_att_wts, \
        hidden, cont, \
        copy_probs = self._decode(embds=rev_word_embds, mask=rev_mask,
                                  extra_feat=extra_feat,
                                  hidden=z, att_keys=rev_att_keys,
                                  att_values=rev_att_vals,
                                  att_mask=rev_att_mask,
                                  att_word_ids=rev_att_word_ids)
        # teacher forcing: step t predicts token t+1, hence the shift
        rec_term = comp_seq_log_prob(log_probs[:, :-1], seqs=rev[:, 1:],
                                     seqs_mask=rev_mask[:, 1:])
        avg_rec_term = rec_term.mean(dim=0)

        loss += - avg_rec_term

        # KULLBACK-LEIBLER TERMS #

        # c kld
        c_kl_term = kld_normal(c_mu_q, c_sigma_q)
        # notice that the below kl term is divided by the number of
        # reviews (data-units) to permit proper scaling of loss components
        summed_c_kl_term = c_kl_term.sum()
        avg_c_kl_term = summed_c_kl_term / bs
        loss += c_lambd * avg_c_kl_term

        # z kld
        # running the prior network
        z_mu_p, z_sigma_p = self.get_p_z_mu_sigma(c=c)
        # computing the actual term; the prior parameters are broadcast from
        # groups to the individual reviews via rev_to_group_indx
        z_kl_term = kld_gauss(z_mu_q, z_sigma_q, z_mu_p[rev_to_group_indx],
                              z_sigma_p[rev_to_group_indx], eps=EPS)
        avg_z_kl_term = z_kl_term.mean(dim=0)
        loss += z_lambd * avg_z_kl_term

        log_var_p_z = T.log(z_sigma_p).sum(-1)

        # masking out attention weights and copy probs of padded positions
        # (shifted by one to align with the predicted tokens)
        rev_att_wts = rev_att_wts[:, :-1] * rev_mask[:, 1:].unsqueeze(
            -1)
        copy_probs = copy_probs[:, :-1] * rev_mask[:, 1:]

        # computing maximums over attention weights
        rev_att_max = rev_att_wts.max(-1)[0]

        # computing maximums over steps of copy probs
        max_copy_probs = copy_probs.max(-1)[0]

        log_var_q_z = T.log(z_sigma_q).sum(-1)
        log_var_q_c = T.log(c_sigma_q).sum(-1)

        # averaging over batches different statistics for logging
        avg_att_max = rev_att_max.mean()
        avg_max_copy_prob = max_copy_probs.mean(0)
        avg_copy_prob = (copy_probs.sum(-1) / rev_len.float()).mean(0)
        avg_first_scor_wts = scor_wts[:, 0].mean(0)

        metrs['avg_lb'] = (avg_rec_term - avg_c_kl_term - avg_z_kl_term).item()
        metrs['avg_rec'] = avg_rec_term.item()
        metrs['avg_c_kl'] = (summed_c_kl_term / group_count).item()
        metrs['avg_z_kl'] = avg_z_kl_term.item()

        metrs['avg_log_p_z_var'] = log_var_p_z.mean(0).item()
        metrs['avg_log_q_z_var'] = log_var_q_z.mean(0).item()
        metrs['avg_log_q_c_var'] = log_var_q_c.mean(0).item()
        metrs['avg_att_max'] = avg_att_max.item()
        metrs['avg_max_copy_prob'] = avg_max_copy_prob.item()
        metrs['avg_copy_prob'] = avg_copy_prob.item()

        metrs['c_lambd'] = c_lambd
        metrs['z_lambd'] = z_lambd

        return loss, metrs