Ejemplo n.º 1
0
    def hidden_dep_dec_func(self, prev_word_ids, hidden):
        """Toy 3-step decoding function whose scores depend on `hidden`.

        Every call increments `self.state` and returns the scores for the
        resulting step; `prev_word_ids` is not used.

        * step 0: fixed scores; `hidden[0, 0]` is incremented by 1 in place
          and `hidden` itself is returned as the recurrent value.
        * step 1: `hidden[0, 0]` is added to the first score row; a new
          2x1 hidden tensor is returned.
        * step 2: `hidden[i, 0]` is added to score row `i` (i = 0, 1); a
          2x1 zero hidden tensor is returned.

        Raises:
            ValueError: if the incremented `self.state` is not 0, 1 or 2.
        """
        first_scores = [
            [1., 1.1, 0., 0.0, 0.0, 0.0],
            [10., 0., 50., 0., 0., 0.]  # this will be ignored
        ]
        second_scores = [
            [0.2, 0., 0.1, 0., 0., 0.],
            [0., 0., 0., 1.1, 0., 2.]
        ]
        third_scores = [
            [4., 99., 331., 10., 133., 53.],  # should be ignored
            [0., 0., 0., -3.0000, 0., 0.001]
        ]

        self.state += 1
        step = self.state

        if step == 0:
            # Mutates the caller's tensor on purpose: the next step reads it.
            hidden[0, 0] += 1.
            return DecState(T.tensor(first_scores),
                            rec_vals={"hidden": hidden})

        if step == 1:
            scores = T.tensor(second_scores)
            scores[0, :] += hidden[0, 0]
            next_hidden = T.tensor([[1.2], [2.]])
            return DecState(scores, rec_vals={"hidden": next_hidden})

        if step == 2:
            scores = T.tensor(third_scores)
            for row in (0, 1):
                scores[row, :] += hidden[row, 0]
            next_hidden = T.tensor([[0.], [0.]])
            return DecState(scores, rec_vals={"hidden": next_hidden})

        raise ValueError("The decoding func supports only 3 steps!")
Ejemplo n.º 2
0
    def decode_beam(self, seqs, hidden, att_word_ids, init_z=None, **kwargs):
        """Run a single decoding step; used by the beam search process.

        :param seqs: [batch_size, 1]
        :param hidden: [batch_size, hidden_dim]
        :param att_word_ids: [batch_size, cat_rev_len]
        :param init_z: [batch_size, z_dim]; when None, `hidden` is used.
        :param kwargs: forwarded unchanged to `self._decode`.
        :return: DecState whose `word_scores` are the log-probabilities from
            `self._decode`, whose `rec_vals` hold the new "hidden", "cont"
            and the "init_z" used, and whose `coll_vals` hold "copy_probs"
            (last dim squeezed), "att_wts" (dim 1 squeezed) and
            "att_word_ids".
        """
        step_embds = self._embds(seqs)
        step_mask = T.ones_like(seqs, dtype=T.float32)
        z = hidden if init_z is None else init_z

        (word_log_probs, att_wts, new_hidden,
         cont, ptr_probs) = self._decode(embds=step_embds,
                                         mask=step_mask,
                                         extra_feat=z.unsqueeze(1),
                                         hidden=hidden,
                                         att_word_ids=att_word_ids,
                                         **kwargs)

        rec_vals = {"hidden": new_hidden, "cont": cont, "init_z": z}
        coll_vals = {'copy_probs': ptr_probs.squeeze(-1),
                     'att_wts': att_wts.squeeze(1),
                     "att_word_ids": att_word_ids}
        return DecState(word_scores=word_log_probs, rec_vals=rec_vals,
                        coll_vals=coll_vals)
Ejemplo n.º 3
0
    def generate_summaries(self, batch, **kwargs):
        """Generate summaries only; a simplified inference path.

        Switches the model to eval mode, encodes the batch's reviews, builds
        attention inputs over the reviews selected for each group and decodes
        one summary per group with `self.beamer`.

        Args:
            batch: mapping indexed by ModelF.REV, ModelF.REV_LEN,
                ModelF.REV_MASK, ModelF.GROUP_REV_INDXS and, optionally,
                ModelF.GROUP_REV_INDXS_MASK.
            **kwargs: forwarded to `self.beamer`.

        Returns:
            The summary word ids produced by `self.beamer`.
        """
        self.model.eval()
        revs = batch[ModelF.REV].to(self.device)
        rev_lens = batch[ModelF.REV_LEN].to(self.device)
        revs_mask = batch[ModelF.REV_MASK].to(self.device)
        group_rev_indxs = batch[ModelF.GROUP_REV_INDXS].to(self.device)
        summ_rev_indxs_mask = (
            batch[ModelF.GROUP_REV_INDXS_MASK].to(self.device)
            if ModelF.GROUP_REV_INDXS_MASK in batch else None)

        max_rev_len = revs.size(1)
        summs_nr = group_rev_indxs.size(0)
        min_lens = (None if self.min_sen_seq_len is None
                    else [self.min_sen_seq_len] * summs_nr)

        with T.no_grad():
            rev_embds = self.model._embds(revs)
            _, rev_hiddens = self.model.encode(rev_embds, rev_lens)
            att_keys = self.model.create_att_keys(rev_hiddens)
            contxt_states = self.model.get_contxt_states(rev_hiddens,
                                                         rev_embds)

            (summ_att_keys, summ_att_vals,
             summ_att_mask) = group_att_over_input(
                inp_att_keys=att_keys,
                inp_att_vals=rev_hiddens,
                inp_att_mask=revs_mask,
                att_indxs=group_rev_indxs,
                att_indxs_mask=summ_rev_indxs_mask)
            summ_att_word_ids = revs[group_rev_indxs].view(summs_nr, -1)

            c_mu_q, _, _ = self.model.get_q_c_mu_sigma(
                contxt_states, revs_mask, group_rev_indxs, summ_rev_indxs_mask)
            z_mu_p, _ = self.model.get_p_z_mu_sigma(c_mu_q)

            # The prior mean of z initialises the summary decoder's hidden.
            summ_word_ids, _ = self.beamer(
                DecState(rec_vals={"hidden": z_mu_p}),
                min_lens=min_lens,
                max_steps=max_rev_len,
                att_keys=summ_att_keys,
                att_values=summ_att_vals,
                att_mask=summ_att_mask,
                att_word_ids=summ_att_word_ids,
                minimum=1, **kwargs)

        return summ_word_ids
Ejemplo n.º 4
0
    def test_hidden_dependent_output(self):
        """Beam search output when the decoding scores depend on `hidden`.

        Token ids used by the expected sequence:
        0: "a", 1: "b", 2: "c", 3: "<pad>", 4: "<s>", 5: "<e>".
        """
        beam_size = 2
        max_steps = 3

        exp_seqs = [[4, 1, 0, 5]]
        init_hidden = T.tensor([[0.]], dtype=T.float32)

        dec = Dec()
        beam_decoder = Beamer(decoding_func=dec.hidden_dep_dec_func,
                              start_id=4,
                              end_id=5,
                              validate_dec_out=False,
                              n_best=beam_size,
                              beam_size=beam_size)

        init_dec_state = DecState(rec_vals={"hidden": init_hidden})

        act_seqs, _ = beam_decoder(init_dec_state, max_steps=max_steps)

        # assertEqual reports both sequences on failure.
        self.assertEqual(exp_seqs, act_seqs)
Ejemplo n.º 5
0
    def test_simple_output(self):
        """Hidden state independent test.

        Token ids used by the expected sequences:
        0: "a", 1: "b", 2: "c", 3: "<pad>", 4: "<s>", 5: "<e>".
        """
        beam_size = 2
        max_steps = 3

        exp_seqs = [[4, 2, 5], [4, 0, 0, 5]]

        init_hidden = T.tensor([[0., 0., 0.], [0., 0., 0.]], dtype=T.float32)
        init_dec_state = DecState(rec_vals={"hidden": init_hidden})

        dec = Dec()
        beam_decoder = Beamer(decoding_func=dec.dummy_dec_func,
                              start_id=4,
                              beam_size=beam_size,
                              end_id=5,
                              validate_dec_out=False)

        act_seqs, _ = beam_decoder(init_dec_state=init_dec_state,
                                   max_steps=max_steps)
        # assertEqual reports both sequences on failure.
        self.assertEqual(exp_seqs, act_seqs)
Ejemplo n.º 6
0
    def decode(self, seq, tr_state=None, dummy=None, **kwargs):
        """One-step decoding for BeamSearch or Sampler, based on `tr_state`.

        The new per-step transformer state from `self._decode` is appended
        along the sequence axis to the previous state (if any).

        Args:
            seq: [batch_size, 1]
            tr_state: [batch_size, num_layers, curr_seq_len, model_dim], or
                None on the first step.
            dummy: unused; it is accepted only for consistency with the beam
                search interface.
            **kwargs: forwarded to `self._decode`.

        Returns:
            DecState with the word scores and the updated "tr_state",
            again laid out as [batch_size, num_layers, seq_len, model_dim].
        """
        if tr_state is None:
            pos_offset = 0
            layers_first = None
        else:
            # positions continue after the already decoded tokens
            pos_offset = tr_state.size(2)
            layers_first = tr_state.transpose(1, 0)

        word_scores, step_state, _ = self._decode(tgt=seq,
                                                  tr_state=layers_first,
                                                  pos_offset=pos_offset,
                                                  **kwargs)

        # step_state: [num_layers, batch_size, curr_seq_len, model_dim]
        if layers_first is None:
            full_state = step_state
        else:
            full_state = T.cat((layers_first, step_state), dim=2)

        return DecState(word_scores=word_scores,
                        rec_vals={"tr_state": full_state.transpose(1, 0)})
Ejemplo n.º 7
0
    def dummy_dec_func(self, prev_word_ids, hidden):
        """Toy 3-step decoding function independent of its inputs' values.

        Every call increments `self.state` and returns a fixed 4x6 score
        table for the resulting step; `hidden` is passed through unchanged
        and `prev_word_ids` is not used.

        Raises:
            ValueError: if the incremented `self.state` is not 0, 1 or 2.
        """
        scores_by_step = {
            0: [
                [1.1, 0., 1., 0., 0., 0.],
                [0., 0., 0., 10., 10., 10.],  # this will be ignored
                [1.1, 1., 0., 0., 0., 0.],
                [0., 0., 0., 10., 0., 0.]  # this will be ignored
            ],
            1: [
                [0., 1., 2., 0., 0., 0.],
                [0., 0., 0., 0., 0., 6.],
                [1.1, 0., 0., 0., 0., 0.],
                [0., 1., 0., 0., 0., 0.]
            ],
            2: [
                [4., 9999., 3., 10., 133., 5.],
                [0., 0., 0., 0., 0., 1.],
                [0., 0., 0., 0., 0., 1.],
                [0., 0., 0., 0., 0., 1.]
            ],
        }
        self.state += 1

        table = scores_by_step.get(self.state)
        if table is None:
            raise ValueError("The decoding func supports only 3 steps!")

        return DecState(word_scores=T.tensor(table),
                        rec_vals={"hidden": hidden})
Ejemplo n.º 8
0
    def generate(self, batch, min_seq_len, max_seq_len, use_true_props=True):
        """Generates/decodes sequences from a conditional pmf.

        Every model computation (memory creation, the optional plug-in
        network and decoding) runs under ``T.no_grad()``. This is inference
        only, so no autograd graph is built.

        Args:
            batch: mapping indexed by ModelF.REV, ModelF.REV_MASK,
                ModelF.GROUP_REV_INDXS, ModelF.GROUP_REV_INDXS_MASK and, when
                `use_true_props` is True, by the four property keys.
            min_seq_len: minimum length of the generated sequence, or None
                for no minimum.
            max_seq_len: maximum allowed length for generated sequences.
            use_true_props: whether to use true property values from `batch`
                or the ones inferred by the plug-in network. If the model has
                no `plugin`, no property values are used.

        Returns:
            word_ids: word ids returned by `self.gen_func`.
            prop_values: dict mapping property keys to Python lists
                (``tolist()`` of the property tensors).
        """
        self.model.eval()
        rev = batch[ModelF.REV].to(self.device)
        rev_mask = batch[ModelF.REV_MASK].to(self.device)

        group_rev_indxs = batch[ModelF.GROUP_REV_INDXS].to(self.device)
        group_rev_indxs_mask = batch[ModelF.GROUP_REV_INDXS_MASK].to(
            self.device)
        bs = group_rev_indxs.size(0)

        min_lens = [min_seq_len] * bs if min_seq_len is not None else None

        # Inference only: keep memory creation and the plug-in inside
        # no_grad as well, otherwise they record an autograd graph for no
        # reason and waste memory.
        with T.no_grad():
            mem, mem_bin_mask = self.model.create_mem(
                rev=rev, rev_mask=rev_mask,
                group_rev_indxs=group_rev_indxs,
                group_rev_indxs_mask=group_rev_indxs_mask)

            if use_true_props:
                prop_vals = {
                    prop: batch[prop].to(self.device)
                    for prop in (ModelF.LEN_PROP, ModelF.RATING_PROP,
                                 ModelF.ROUGE_PROP, ModelF.POV_PROP)
                }
            elif hasattr(self.model, 'plugin'):
                prop_vals, _ = self.model.plugin(mem, mem_bin_mask)
            else:
                prop_vals = {}

            dummy = T.zeros(bs, device=self.device)
            init_dec_state = DecState(rec_vals={"dummy": dummy})
            word_ids, _ = self.gen_func(init_dec_state=init_dec_state,
                                        max_steps=max_seq_len,
                                        log_normalize=True,
                                        min_lens=min_lens,
                                        mem=mem,
                                        mem_bin_mask=mem_bin_mask,
                                        minimum=1,
                                        **prop_vals)

        prop_vals = {k: v.tolist() for k, v in prop_vals.items()}
        return word_ids, prop_vals
Ejemplo n.º 9
0
    def predict(self, batch, **kwargs):
        """Predicts summaries and reviews.

        Switches the model to eval mode, encodes all reviews once, then runs
        `self.beamer` twice. The first run decodes one summary per group,
        initialised from the prior mean of z. The second run decodes one
        review per input review, initialised from the posterior mean of z
        and attending over the reviews indexed by ModelF.OTHER_REV_INDXS.

        Args:
            batch: mapping indexed by ModelF.REV, ModelF.REV_LEN,
                ModelF.REV_MASK, ModelF.GROUP_REV_INDXS,
                ModelF.GROUP_REV_INDXS_MASK, ModelF.OTHER_REV_INDXS,
                ModelF.OTHER_REV_INDXS_MASK and ModelF.REV_TO_GROUP_INDX.
            **kwargs: forwarded to both `self.beamer` calls.

        Returns:
            Tuple ``(rev_word_ids, rev_coll_vals, summ_word_ids,
            summ_coll_vals)`` as returned by the two `self.beamer` calls.
        """
        self.model.eval()
        revs = batch[ModelF.REV].to(self.device)
        rev_lens = batch[ModelF.REV_LEN].to(self.device)
        revs_mask = batch[ModelF.REV_MASK].to(self.device)
        summ_rev_indxs = batch[ModelF.GROUP_REV_INDXS].to(self.device)
        summ_rev_indxs_mask = batch[ModelF.GROUP_REV_INDXS_MASK].to(
            self.device)
        other_revs = batch[ModelF.OTHER_REV_INDXS].to(self.device)
        other_revs_mask = batch[ModelF.OTHER_REV_INDXS_MASK].to(self.device)
        rev_to_group_indx = batch[ModelF.REV_TO_GROUP_INDX].to(self.device)

        bs, max_rev_len = revs.size(0), revs.size(1)
        summs_nr = summ_rev_indxs.size(0)

        min_len = self.min_sen_seq_len
        summ_min_lens = None if min_len is None else [min_len] * summs_nr
        rev_min_lens = None if min_len is None else [min_len] * bs

        with T.no_grad():
            rev_embds = self.model._embds(revs)
            rev_encs, rev_hiddens = self.model.encode(rev_embds, rev_lens)
            att_keys = self.model.create_att_keys(rev_hiddens)
            contxt_states = self.model.get_contxt_states(rev_hiddens,
                                                         rev_embds)

            # ---- summaries ----
            (summ_att_keys, summ_att_vals,
             summ_att_mask) = group_att_over_input(
                inp_att_keys=att_keys,
                inp_att_vals=rev_hiddens,
                inp_att_mask=revs_mask,
                att_indxs=summ_rev_indxs,
                att_indxs_mask=summ_rev_indxs_mask)
            summ_att_word_ids = revs[summ_rev_indxs].view(summs_nr, -1)

            c_mu_q, _, _ = self.model.get_q_c_mu_sigma(
                contxt_states, revs_mask, summ_rev_indxs, summ_rev_indxs_mask)
            z_mu_p, _ = self.model.get_p_z_mu_sigma(c_mu_q)

            summ_word_ids, summ_coll_vals = self.beamer(
                DecState(rec_vals={"hidden": z_mu_p}),
                min_lens=summ_min_lens,
                max_steps=max_rev_len,
                att_keys=summ_att_keys,
                att_values=summ_att_vals,
                att_mask=summ_att_mask,
                att_word_ids=summ_att_word_ids,
                minimum=1, **kwargs)

            # ---- reviews ----
            # each review is conditioned on its group's c mean
            z_mu_q, _ = self.model.get_q_z_mu_sigma(
                rev_encs, c=c_mu_q[rev_to_group_indx])

            (rev_att_keys, rev_att_vals,
             rev_att_mask) = group_att_over_input(
                inp_att_keys=att_keys,
                inp_att_vals=rev_hiddens,
                inp_att_mask=revs_mask,
                att_indxs=other_revs,
                att_indxs_mask=other_revs_mask)
            rev_att_word_ids = revs[other_revs].view(bs, -1)

            rev_word_ids, rev_coll_vals = self.beamer(
                DecState(rec_vals={"hidden": z_mu_q}),
                min_lens=rev_min_lens,
                max_steps=max_rev_len,
                att_keys=rev_att_keys,
                att_values=rev_att_vals,
                att_mask=rev_att_mask,
                att_word_ids=rev_att_word_ids,
                minimum=1, **kwargs)

        return rev_word_ids, rev_coll_vals, summ_word_ids, summ_coll_vals
Ejemplo n.º 10
0
    def gather(self, outputs, output_device):
        """Gathers outputs from replicas, aggregates them generically.

        Each replica output is a single element or a list/tuple of elements.
        Every element must be a PyTorch tensor (1+ dims), a `DecState`, or a
        dictionary of scalars.

        * Tensors, and the tensor attributes and tensor dict values of a
          `DecState`, are moved to the output device and concatenated along
          dim 0.
        * Dictionaries of scalars are summed and divided by the number of
          replicas (mean of means).

        Consistency of the outputs across replicas is not checked.

        Args:
            outputs: replica outputs, one entry per replica. This list is not
                modified.
            output_device: device to gather tensors onto. Falls back to
                `self.output_device` when None.

        Returns:
            The aggregated output. If each replica returned a sequence with
            more than one element, a list with one aggregated entry per
            position is returned.

        Raises:
            NotImplementedError: on an element or `DecState` attribute of an
                unsupported type.
        """
        target_device = (self.output_device if output_device is None
                         else output_device)

        def _coll_tensor(coll_list, out_tensor):
            # Move the replica's tensor to the target device before collecting.
            coll_list.append(gather(outputs=[out_tensor],
                                    target_device=target_device,
                                    dim=0))

        def _coll_dict(coll_dict, out_dict):
            for k, v in out_dict.items():
                coll_dict.setdefault(k, []).append(v)

        def _coll_dec_state(coll_dstate, out_dstate):
            for attr_name, attr_val in out_dstate.__dict__.items():
                if isinstance(attr_val, Tensor):
                    if getattr(coll_dstate, attr_name) is None:
                        setattr(coll_dstate, attr_name, [])
                    _coll_tensor(getattr(coll_dstate, attr_name), attr_val)
                elif isinstance(attr_val, dict):
                    if getattr(coll_dstate, attr_name) is None:
                        setattr(coll_dstate, attr_name, dict())
                    coll_attr = getattr(coll_dstate, attr_name)
                    # These values are T.cat-ed below. Replica tensors live
                    # on different devices, so they must be gathered first.
                    for k, v in attr_val.items():
                        coll_vals = coll_attr.setdefault(k, [])
                        if isinstance(v, Tensor):
                            _coll_tensor(coll_vals, v)
                        else:
                            coll_vals.append(v)
                elif attr_val is None:
                    continue
                else:
                    raise NotImplementedError

        # Normalise every replica output to a sequence, without mutating
        # the caller's list (and without failing if `outputs` is a tuple).
        outputs = [o if isinstance(o, (list, tuple)) else [o]
                   for o in outputs]

        coll_outputs = [None for _ in range(len(outputs[0]))]

        # collecting outputs
        for output in outputs:
            for indx, o in enumerate(output):
                # statistics, such as loss, assumed to be of simple types
                if isinstance(o, dict):
                    if coll_outputs[indx] is None:
                        coll_outputs[indx] = dict()
                    _coll_dict(coll_outputs[indx], o)
                # tensors
                elif isinstance(o, Tensor):
                    if coll_outputs[indx] is None:
                        coll_outputs[indx] = []
                    _coll_tensor(coll_outputs[indx], o)
                # decoder state
                elif isinstance(o, DecState):
                    if coll_outputs[indx] is None:
                        coll_outputs[indx] = DecState()
                    _coll_dec_state(coll_outputs[indx], o)
                else:
                    raise NotImplementedError

        # aggregating outputs
        for indx, o in enumerate(coll_outputs):
            if isinstance(o, list):
                coll_outputs[indx] = T.cat(o)
            elif isinstance(o, dict):
                coll_outputs[indx] = {k: sum(v) / len(v) for k, v in o.items()}
            elif isinstance(o, DecState):
                for attr_name, attr_val in o.__dict__.items():
                    if isinstance(attr_val, list):
                        setattr(o, attr_name, T.cat(attr_val))
                    elif isinstance(attr_val, dict):
                        for k, v in attr_val.items():
                            attr_val[k] = T.cat(v)
                    elif attr_val is None:
                        continue
                    else:
                        raise NotImplementedError
            else:
                raise NotImplementedError

        if len(coll_outputs) == 1:
            coll_outputs = coll_outputs[0]

        return coll_outputs