Example #1
    def _decode_and_generate(
        self,
        decoder_in,
        memory_bank,
        batch,
        src_vocabs,
        memory_lengths,
        src_map=None,
        step=None,
        batch_offset=None,
        cls_bank=None
    ):
        if self.copy_attn:
            # Turn any copied words into UNKs.
            decoder_in = decoder_in.masked_fill(
                decoder_in.gt(self._tgt_vocab_len - 1), self._tgt_unk_idx
            )

        # Decoder forward, takes [tgt_len, batch, nfeats] as input
        # and [src_len, batch, hidden] as memory_bank
        # in case of inference tgt_len = 1, batch = beam times batch_size
        # in case of Gold Scoring tgt_len = actual length, batch = 1 batch

        # cls_bank = None
        # cls_bank = memory_bank[:, 0:1, :]

        # dec_out, dec_attn = self.model.decoder(
        #     decoder_in, memory_bank, memory_lengths=memory_lengths, step=step, adapter=True, cls_bank=cls_bank)
        dec_out, dec_attn = self.model.decoder(
            decoder_in, memory_bank, memory_lengths=memory_lengths, step=step, adapter=True)

        # Generator forward.
        if not self.copy_attn:
            attn = dec_attn["std"]
            log_probs = self.model.generator(dec_out.squeeze(0))
            # returns [(batch_size x beam_size) , vocab ] when 1 step
            # or [ tgt_len, batch_size, vocab ] when full sentence
        else:
            attn = dec_attn["copy"]
            scores = self.model.generator(dec_out.view(-1, dec_out.size(2)),
                                          attn.view(-1, attn.size(2)),
                                          src_map)
            # here we have scores [tgt_lenxbatch, vocab] or [beamxbatch, vocab]
            if batch_offset is None:
                scores = scores.view(batch.batch_size, -1, scores.size(-1))
            else:
                scores = scores.view(-1, self.beam_size, scores.size(-1))
            scores = collapse_copy_scores(
                scores,
                batch,
                self._tgt_vocab,
                src_vocabs,
                batch_dim=0,
                batch_offset=batch_offset
            )
            scores = scores.view(decoder_in.size(0), -1, scores.size(-1))
            log_probs = scores.squeeze(0).log()
            # returns [(batch_size x beam_size) , vocab ] when 1 step
            # or [ tgt_len, batch_size, vocab ] when full sentence
        return log_probs, attn
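
The masked_fill at the top of this example is the standard copy-attention preprocessing step: tokens predicted at the previous step may carry extended-vocabulary ids (copied source words) that the decoder embedding table does not cover, so they are mapped back to UNK. A minimal, self-contained sketch of that mapping; the vocabulary size and UNK index below are illustrative assumptions, not values from the original code.

    import torch

    # Illustrative sizes: a 10-word target vocab whose UNK id is 0 (assumption).
    tgt_vocab_len, unk_idx = 10, 0

    # Previous predictions for one step of a batch of 4; ids 12 and 15 point into
    # the copy-extended vocabulary, which the decoder embeddings cannot look up.
    decoder_in = torch.tensor([[3, 12, 7, 15]])

    # Replace every id beyond the fixed target vocabulary with UNK.
    decoder_in = decoder_in.masked_fill(decoder_in.gt(tgt_vocab_len - 1), unk_idx)
    print(decoder_in)  # tensor([[3, 0, 7, 0]])
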
Example #2
    def _decode_and_generate(self,
                             decoder_in,
                             memory_bank,
                             batch,
                             src_vocabs,
                             memory_lengths,
                             src_map=None,
                             step=None,
                             batch_offset=None):
        if self.copy_attn:
            # Turn any copied words into UNKs.
            decoder_in = decoder_in.masked_fill(
                decoder_in.gt(self._tgt_vocab_len - 1), self._tgt_unk_idx)

        # Decoder forward, takes [tgt_len, batch, nfeats] as input
        # and [src_len, batch, hidden] as memory_bank
        # in case of inference tgt_len = 1, batch = beam times batch_size
        # in case of Gold Scoring tgt_len = actual length, batch = 1 batch
        self.model.decoder.set_copy_info(batch, self._tgt_vocab)
        dec_out, dec_attn = self.model.decoder(decoder_in,
                                               memory_bank,
                                               memory_lengths=memory_lengths,
                                               step=step)

        # Generator forward.
        if not self.copy_attn:
            if "std" in dec_attn:
                attn = dec_attn["std"]
            else:
                attn = None
            log_probs = self.model.generator(dec_out.squeeze(0))
            # returns [(batch_size x beam_size) , vocab ] when 1 step
            # or [ tgt_len, batch_size, vocab ] when full sentence
        else:
            attn = dec_attn["copy"]
            #print("DEC_OUT: ", dec_out.size())
            #print("ATTN: ", attn.size())
            scores = self.model.generator(dec_out.view(-1, dec_out.size(2)),
                                          attn.view(-1, attn.size(2)), src_map)
            # here we have scores [tgt_lenxbatch, vocab] or [beamxbatch, vocab]
            if batch_offset is None:
                scores = scores.view(-1, batch.batch_size, scores.size(-1))
                scores = scores.transpose(0, 1).contiguous()
            else:
                scores = scores.view(-1, self.beam_size, scores.size(-1))

            #print("TGT_VOCAB: ", self._tgt_vocab)
            scores = collapse_copy_scores(scores,
                                          batch,
                                          self._tgt_vocab,
                                          src_vocabs,
                                          batch_dim=0,
                                          batch_offset=batch_offset)
            scores = scores.view(decoder_in.size(0), -1, scores.size(-1))

            log_probs = scores.squeeze(0).log()
            #print(log_probs.size())
            # returns [(batch_size x beam_size) , vocab ] when 1 step
            # or [ tgt_len, batch_size, vocab ] when full sentence
        return log_probs, attn
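
This variant reshapes the flattened copy scores with view(-1, batch_size, vocab) followed by a transpose, where Example #1 uses view(batch_size, -1, vocab) directly. For a single decoding step the two are identical, but for multi-step (gold-scoring) input the flattened rows are step-major, so only the transpose form regroups rows by sentence. A toy illustration with made-up shapes:

    import torch

    tgt_len, batch_size, vocab = 2, 3, 5
    # Flattened generator output: rows ordered step-major (all sentences at step 0,
    # then all sentences at step 1), as produced by dec_out.view(-1, hidden).
    scores = torch.arange(tgt_len * batch_size * vocab, dtype=torch.float)
    scores = scores.view(tgt_len * batch_size, vocab)

    a = scores.view(batch_size, -1, vocab)                  # Example #1 style
    b = scores.view(-1, batch_size, vocab).transpose(0, 1)  # Example #2 style
    print(torch.equal(a, b))  # False when tgt_len > 1: only `b` groups rows by sentence
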
Example #3
    def _decode_and_generate(
        self,
        decoder_in,
        memory_bank,
        batch,
        data,
        memory_lengths,
        src_map=None,
        step=None,
        batch_offset=None
    ):

        tgt_field = self.fields["tgt"][0][1].base_field
        unk_idx = tgt_field.vocab.stoi[tgt_field.unk_token]
        if self.copy_attn:
            # Turn any copied words into UNKs.
            decoder_in = decoder_in.masked_fill(
                decoder_in.gt(len(tgt_field.vocab) - 1), unk_idx
            )

        # Decoder forward, takes [tgt_len, batch, nfeats] as input
        # and [src_len, batch, hidden] as memory_bank
        # in case of inference tgt_len = 1, batch = beam times batch_size
        # in case of Gold Scoring tgt_len = actual length, batch = 1 batch
        dec_out, dec_attn = self.model.decoder(
            decoder_in, memory_bank, memory_lengths=memory_lengths, step=step
        )

        # Generator forward.
        if not self.copy_attn:
            attn = dec_attn["std"]
            log_probs = self.model.generator(dec_out.squeeze(0))
            # returns [(batch_size x beam_size) , vocab ] when 1 step
            # or [ tgt_len, batch_size, vocab ] when full sentence
        else:
            attn = dec_attn["copy"]
            scores = self.model.generator(dec_out.view(-1, dec_out.size(2)),
                                          attn.view(-1, attn.size(2)),
                                          src_map)
            # here we have scores [tgt_lenxbatch, vocab] or [beamxbatch, vocab]
            if batch_offset is None:
                scores = scores.view(batch.batch_size, -1, scores.size(-1))
            else:
                scores = scores.view(-1, self.beam_size, scores.size(-1))
            scores = collapse_copy_scores(
                scores,
                batch,
                tgt_field.vocab,
                data.src_vocabs,
                batch_dim=0,
                batch_offset=batch_offset
            )
            scores = scores.view(decoder_in.size(0), -1, scores.size(-1))
            log_probs = scores.squeeze(0).log()
            # returns [(batch_size x beam_size) , vocab ] when 1 step
            # or [ tgt_len, batch_size, vocab ] when full sentence
        return log_probs, attn
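
Every copy branch in these examples calls collapse_copy_scores before taking the log. Conceptually, it folds the probability mass of an extended-vocabulary (copied) slot into the matching regular target-vocabulary entry whenever the copied source word already exists in the target vocabulary. The snippet below is a hand-rolled, single-sentence stand-in for that idea, not the library function; the slot mapping and epsilon are illustrative.

    import torch

    # Toy setup: 6 target-vocab slots plus 2 copy slots appended for source tokens.
    # Copy slot 6 happens to be the same word as target id 2; slot 7 is a true OOV.
    probs = torch.tensor([0.05, 0.10, 0.20, 0.05, 0.10, 0.10, 0.30, 0.10])
    copy_slot_to_tgt = {6: 2, 7: None}

    collapsed = probs.clone()
    for slot, tgt_id in copy_slot_to_tgt.items():
        if tgt_id is not None:
            collapsed[tgt_id] += collapsed[slot]  # fold the copy mass into the known word
            collapsed[slot] = 1e-20               # tiny epsilon keeps the later log() finite
    print(collapsed)
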
Example #4
    def _decode_and_generate(
            self,
            decoder_in,
            memory_bank,
            batch,
            src_vocabs,
            memory_lengths,
            src_map=None,
            step=None,
            batch_offset=None):
        if self.copy_attn:
            # Turn any copied words into UNKs.
            decoder_in = decoder_in.masked_fill(
                decoder_in.gt(self._tgt_vocab_len - 1), self._tgt_unk_idx
            )

        # Decoder forward, takes [tgt_len, batch, nfeats] as input
        # and [src_len, batch, hidden] as memory_bank
        # in case of inference tgt_len = 1, batch = beam times batch_size
        # in case of Gold Scoring tgt_len = actual length, batch = 1 batch
        dec_out, dec_attn = self.model.decoder(
            decoder_in, memory_bank, memory_lengths=memory_lengths, step=step
        )

        # Generator forward.
        if not self.copy_attn:
            if "std" in dec_attn:
                attn = dec_attn["std"]
            else:
                attn = None
            log_probs = self.model.generator(dec_out.squeeze(0))
            # returns [(batch_size x beam_size) , vocab ] when 1 step
            # or [ tgt_len, batch_size, vocab ] when full sentence
        else:
            attn = dec_attn["copy"]
            scores = self.model.generator(dec_out.view(-1, dec_out.size(2)),
                                          attn.view(-1, attn.size(2)),
                                          src_map)
            # here we have scores [tgt_lenxbatch, vocab] or [beamxbatch, vocab]
            if batch_offset is None:
                scores = scores.view(batch.batch_size, -1, scores.size(-1))
            else:
                scores = scores.view(-1, self.beam_size, scores.size(-1))
            scores = collapse_copy_scores(
                scores,
                batch,
                self._tgt_vocab,
                src_vocabs,
                batch_dim=0,
                batch_offset=batch_offset
            )
            scores = scores.view(decoder_in.size(0), -1, scores.size(-1))
            log_probs = scores.squeeze(0).log()
            # returns [(batch_size x beam_size) , vocab ] when 1 step
            # or [ tgt_len, batch_size, vocab ] when full sentence
        return log_probs, attn
Example #5
    def _decode_and_generate(
            self,
            decoder_in,
            memory_bank,
            batch,
            memory_lengths,
            src_map=None,
            step=None):
        if self.copy_attn:
            # Turn any copied words into UNKs.
            decoder_in = decoder_in.masked_fill(
                decoder_in.gt(self._tgt_vocab_len - 1), self._tgt_unk_idx
            )

        # Decoder forward, takes [tgt_len, batch, nfeats] as input
        # and [src_len, batch, hidden] as memory_bank
        # in case of inference tgt_len = 1, batch = beam times batch_size
        # in case of Gold Scoring tgt_len = actual length, batch = 1 batch
        dec_out, dec_attn = self.model.decoder(
            decoder_in, memory_bank, memory_lengths=memory_lengths, step=step
        )

        # Generator forward.
        if not self.copy_attn:
            if "std" in dec_attn:
                attn_key = 'std'
            else:
                attn_key = None
            log_probs = self.model.generator(dec_out.squeeze(0))
            scores = log_probs.exp()  # torch.distributions.Categorical wants probabilities, not log-probs
        else:
            attn = dec_attn["copy"]
            attn_key = 'copy'
            scores = self.model.generator(dec_out.view(-1, dec_out.size(2)),
                                          attn.view(-1, attn.size(2)),
                                          src_map)

            scores = scores.view(batch.batch_size, -1, scores.size(-1))
            
            scores = collapse_copy_scores(
                scores,
                batch,
                self._tgt_vocab,
                src_vocabs=None,
                batch_dim=0,
                batch_offset=0
            )
            scores = scores.view(decoder_in.size(0), -1, scores.size(-1))
            
        return scores, dec_attn, attn_key
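
Unlike the other variants, this one returns probabilities rather than log-probabilities because the caller feeds them to torch.distributions.Categorical (which also accepts logits= directly, so the exp round-trip is avoidable). A sketch of how such a caller might sample the next token; the shapes are illustrative.

    import torch

    probs = torch.softmax(torch.randn(4, 10), dim=-1)   # toy [batch, vocab] probabilities
    dist = torch.distributions.Categorical(probs=probs)
    next_tokens = dist.sample()                          # one sampled token id per batch element
    print(next_tokens.shape)                             # torch.Size([4])
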
Example #6
    def _run_forward_pass(self, tgt, memory_bank, memory_lengths=None):
        """
        See StdRNNDecoder._run_forward_pass() for description
        of arguments and return values.
        """
        # Additional args check.
        input_feed = self.state["input_feed"].squeeze(0)
        input_feed_batch, _ = input_feed.size()
        _, tgt_batch, _ = tgt.size()

        aeq(tgt_batch, input_feed_batch)
        # END Additional args check.

        dec_outs = []
        attns = {}
        if self.attn is not None:
            attns["std"] = []
        if self.copy_attn is not None or self._reuse_copy_attn:
            attns["copy"] = []
        if self._coverage:
            attns["coverage"] = []

        emb = self.embeddings(tgt)
        assert emb.dim() == 3  # len x batch x embedding_dim

        dec_state = self.state["hidden"]
        coverage = self.state["coverage"].squeeze(0) \
            if self.state["coverage"] is not None else None

        # Input feed concatenates hidden state with
        # input at every time step.
        #GN: This is the loop that needs to be modified
        #for emb_t in emb.split(1):
        #print("TARGET[0]: ", tgt[0])
        #print("LEN TARGET[0]: ", str(len(tgt[0])))
        #temp = torch.ones(1,dtype=int)
        #temp[0] = 5
        #temp = temp.unsqueeze(1)
        #temp = temp.unsqueeze(1)
        #temp2 = self.embeddings(temp)
        #print("TEMP: ", temp)
        #print("TEMP2: ", temp2)
        for t in range(0, len(tgt)):
            if t == 0 or self.eval_status:
                #emb_t = self.embeddings([t])
                emb_t = emb.split(1)[t]  #Start symbol
            elif self.teacher_forcing == "teacher":
                #emb_t = self.embeddings([t])
                emb_t = emb.split(1)[t]  #Use gold output
                #print("DEC: " + str(len(dec_outs)))
                #print("TGT: " + str(len(tgt)))
            else:
                #t_value = top_labels[0] #Use predicted output
                if self.teacher_forcing == "random":
                    rep_t = torch.ones(len(tgt[0]), dtype=int)
                    for batch in range(len(tgt[0])):
                        t_value = random.randint(
                            0, self.vocab_size -
                            1)  #Randomly select a member of the vocab
                        rep_t[batch] = t_value
                    rep_t = rep_t.unsqueeze(1)
                    rep_t = rep_t.unsqueeze(1)

                elif self.teacher_forcing == "student" or self.teacher_forcing == "dist":
                    rep_t = torch.ones(len(tgt[0]), dtype=int)
                    '''if self.copy_attn is None:
                        log_probs = self.generator(decoder_output.squeeze(0))
                    else:
                        attn = attns["copy"]
                        scores = self.generator(decoder_output.view(-1, decoder_output.size(2)), attn.view(-1, attn.size(2)), self.batch.src_map)
                        #if batch_offset is None: #Not a beam search, batch_offset doesn't make sense in this case
                        scores = scores.view(-1, self.batch.batch_size, scores.size(-1))
                        scores = scores.transpose(0,1).contiguous()
                        #else:
                        #    scores = scores.view(-1, self.beam_size, scores.size(-1))
                        src_vocabs = None #If this happens, the collapse function backs off to the batch source, which is fine
                        scores = collapse_copy_scores(scores, self.batch, self.tgt_vocab, src_vocabs, batch_dim=0)
                        scores = scores.view(decoder_input.size(0), -1, scores.size(-1)) #decoder input is still from last t_value, so it should be fine
                        log_probs = scores.squeeze(0).log()
                    '''

                    for batch_id in range(len(tgt[0])):
                        #print(log_probs[batch])
                        top_probs, top_labels = torch.topk(
                            log_probs[batch_id],
                            self.vocab_size)  #"COPY" is also an option
                        #print("PROBS: ", top_probs.size())
                        #print("LABELS: ", top_labels.size())
                        top_probs = top_probs.squeeze(0).tolist()
                        # Convert log-probs back to probabilities and renormalize,
                        # since the truncated top-k slice no longer sums exactly to 1.
                        top_probs = np.exp(top_probs)
                        top_probs /= np.sum(top_probs)
                        top_labels = top_labels.squeeze(0).tolist()
                        #print("PROBS: ", top_probs)
                        #print("LABELS: ", top_labels)

                        #print("TOP PROBS: " , top_probs)
                        #print("TOP LABELS: " , top_labels)

                        #top_probs = top_probs[batch].tolist()
                        #top_labels = top_labels[batch].tolist()
                        if (self.teacher_forcing == "student"):
                            t_value = top_labels[0]
                        elif (self.teacher_forcing == "dist"):
                            rand_val = random.uniform(0, 1)
                            rand_sum = 0.0
                            index = 0
                            while (rand_sum <
                                   rand_val):  # and index < len(top_labels)):
                                #print("INDEX: ", index)
                                #print("VAL: ", rand_val)
                                #print("SUM: ", rand_sum)
                                rand_sum += top_probs[index]
                                index += 1
                            t_value = top_labels[index - 1]
                        rep_t[batch_id] = t_value
                    rep_t = rep_t.unsqueeze(1)
                    rep_t = rep_t.unsqueeze(1)

                #rep_t = dtorch.ones(len(tgt[0]),dtype=int)
                #rep_t[0] = t_value
                #rep_t = rep_t.unsqueeze(1)
                #rep_t = rep_t.unsqueeze(1)
                emb_t = self.embeddings(rep_t).squeeze(1)

                #print("DEC: " + str(dec_outs[-1]))
                #print("TGT: " + str(emb.split(1)[t]))

        #    elif opt.teacher_forcing == "rand":
        #        emb_t = self.embeddings(random)
        #    elif opt.teacher_forcing == "dist":
        #        rand = random_integer
        #
        #print(emb_t.squeeze(0).size())
        #print(input_feed.size())
            if emb_t.dim() > 2:
                decoder_input = torch.cat([emb_t.squeeze(0), input_feed], 1)
            else:
                decoder_input = torch.cat([emb_t, input_feed], 1)

            rnn_output, dec_state = self.rnn(decoder_input, dec_state)
            #print("SIZE: ", dec_state.size())
            if self.attentional:
                decoder_output, p_attn = self.attn(
                    rnn_output,
                    memory_bank.transpose(0, 1),
                    memory_lengths=memory_lengths)
                attns["std"].append(p_attn)
            else:
                decoder_output = rnn_output
            if self.context_gate is not None:
                # TODO: context gate should be employed
                # instead of second RNN transform.
                decoder_output = self.context_gate(decoder_input, rnn_output,
                                                   decoder_output)
            decoder_output = self.dropout(decoder_output)
            #log_probs = self.rnn.generator(decoder_output.squeeze(0))
            #print("PROBS: ", log_probs)
            #top_probs, top_labels = torch.topk(probs,len(probs[0]))
            #top_probs = top_probs[0].tolist()
            #top_labels = top_labels[0].tolist()

            input_feed = decoder_output

            dec_outs += [decoder_output]

            # Update the coverage attention.
            if self._coverage:
                coverage = p_attn if coverage is None else p_attn + coverage
                attns["coverage"] += [coverage]

            if self.copy_attn is not None:
                _, copy_attn = self.copy_attn(decoder_output,
                                              memory_bank.transpose(0, 1))
                attns["copy"] += [copy_attn]
            elif self._reuse_copy_attn:
                attns["copy"] = attns["std"]

            decoder_output = rnn_output
            #print("DEC: ", decoder_output.size())
            #print("ATTN: ", copy_attn.size())
            if not self.eval_status:

                if self.copy_attn is None:
                    log_probs = self.generator(decoder_output.squeeze(0))
                else:
                    attn = attns["copy"]
                    #print("SRC_MAP: ", self.batch.src_map.size())
                    #print("ATTN: ", copy_attn.size())
                    #src_map = torch.zeros(copy_attn.size(1), self.batch.src.size(0), self.batch.src.size(1), dtype=torch.float)
                    #print(self.batch)
                    scores = self.generator(
                        decoder_output.view(-1, decoder_output.size(1)),
                        copy_attn.view(-1, copy_attn.size(1)),
                        self.batch.src_map)

                    #if batch_offset is None: #Not a beam search, batch_offset doesn't make sense in this case
                    scores = scores.view(-1, self.batch.batch_size,
                                         scores.size(-1))
                    scores = scores.transpose(0, 1).contiguous()
                    #else:
                    #    scores = scores.view(-1, self.beam_size, scores.size(-1))
                    src_vocabs = None  #If this happens, the collapse function backs off to the batch source, which is fine
                    scores = collapse_copy_scores(scores,
                                                  self.batch,
                                                  self.tgt_vocab,
                                                  src_vocabs,
                                                  batch_dim=0)

                    scores = scores.view(
                        decoder_input.size(0), -1, scores.size(-1)
                    )  #decoder input is still from last t_value, so it should be fine
                    log_probs = scores.squeeze(1).log()
                    #log_probs = scores.squeeze(0).log()

        return dec_state, dec_outs, attns
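
The 'dist' branch above picks the next decoder input by walking a cumulative sum over the top-k probabilities one batch element at a time. The same scheduled-sampling choice can be written vectorised with torch.multinomial; the helper below is an illustrative sketch (its name, and the assumption that log_probs is [batch, vocab], are not part of the original code).

    import torch

    def sample_next_inputs(log_probs, mode="dist"):
        """Choose the next decoder input ids from the model's own predictions."""
        probs = log_probs.exp()
        probs = probs / probs.sum(dim=-1, keepdim=True)      # renormalise after exp
        if mode == "student":
            return probs.argmax(dim=-1)                       # greedy choice
        return torch.multinomial(probs, num_samples=1).squeeze(-1)  # sample per row

    log_probs = torch.log_softmax(torch.randn(3, 8), dim=-1)  # toy [batch, vocab]
    print(sample_next_inputs(log_probs, mode="dist"))
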
Example #7
    def _decode_and_generate(
            self,
            decoder_in,
            #states,
            memory_bank,
            batch,
            src_vocabs,
            memory_lengths,
            src_map=None,
            step=None,
            batch_offset=None):

        #masks, expand_masks = states.get_mask()

        if self.copy_attn:
            # Turn any copied words into UNKs.
            decoder_in = decoder_in.masked_fill(
                decoder_in.gt(self._tgt_vocab_len - 1), self._tgt_unk_idx)
        # Decoder forward, takes [tgt_len, batch, nfeats] as input
        # and [src_len, batch, hidden] as memory_bank
        # in case of inference tgt_len = 1, batch = beam times batch_size
        # in case of Gold Scoring tgt_len = actual length, batch = 1 batch
        dec_out, dec_attn = self.model.decoder(decoder_in,
                                               memory_bank,
                                               memory_lengths=memory_lengths,
                                               step=step,
                                               generator_id=self.generator_id)

        # Generator forward.
        if not self.copy_attn:
            if "std" in dec_attn:
                attn = dec_attn["std"]
            else:
                attn = None
            log_probs = self.model.generator[self.generator_id](
                dec_out.squeeze(0))
            # returns [(batch_size x beam_size) , vocab ] when 1 step
            # or [ tgt_len, batch_size, vocab ] when full sentence
            #log_probs = log_probs.exp().mul(masks).log()

        else:
            attn = dec_attn["copy"]
            #print(dec_out.size())
            #print(attn.size())
            scores = self.model.generator[self.generator_id](dec_out.view(
                -1, dec_out.size(2)), attn.view(-1, attn.size(2)), src_map)

            #print(scores.size())
            #print(batch_offset)
            # here we have scores [tgt_lenxbatch, vocab] or [beamxbatch, vocab]
            if batch_offset is None:
                scores = scores.view(batch.batch_size, -1, scores.size(-1))
            else:
                scores = scores.view(-1, self.beam_size, scores.size(-1))
            scores = collapse_copy_scores(
                scores,
                batch,
                self._tgt_vocab,
                src_vocabs,
                batch_dim=0,
                batch_offset=batch_offset
            )  # duplicate copied words map to the same extended-vocab entry, so their scores should be summed.

            scores = scores.view(decoder_in.size(0), -1, scores.size(-1))
            #expand_masks = expand_masks.expand(expand_masks.size(0), scores.size(2) - masks.size(1))

            #masks = torch.cat([masks, expand_masks], 1)
            #scores = scores.mul(masks)

            log_probs = scores.squeeze(0).log()

            # returns [(batch_size x beam_size) , vocab ] when 1 step
            # or [ tgt_len, batch_size, vocab ] when full sentence
        return log_probs, attn
Example #8
    def _decode_and_generate(
            self,
            decoder_in,
            memory_bank,
            batch,
            src_vocabs,
            memory_lengths,
            src_map=None,
            step=None,
            batch_offset=None,
            verbose=False):
        if self.copy_attn:
            # Turn any copied words into UNKs.
            decoder_in = decoder_in.masked_fill(
                decoder_in.gt(self._tgt_vocab_len - 1), self._tgt_unk_idx
            )

        # Decoder forward, takes [tgt_len, batch, nfeats] as input [1,20,1]
        # and [src_len, batch, hidden] as memory_bank
        # in case of inference tgt_len = 1, batch = beam times batch_size
        # in case of Gold Scoring tgt_len = actual length, batch = 1 batch
        dec_out, dec_attn = self.model.decoder(
            decoder_in, memory_bank, memory_lengths=memory_lengths, step=step
        )
        # dec_out has shape [1, 20, 512] 512 is wordvec size
        # dec_attn has shape [1, 20, 1311] 1311 is src length

        # Generator forward.
        if not self.copy_attn:
            if "std" in dec_attn:
                attn = dec_attn["std"]
            else:
                attn = None
            log_probs = self.model.generator(dec_out.squeeze(0))
            # returns [(batch_size x beam_size) , vocab ] when 1 step
            # or [ tgt_len, batch_size, vocab ] when full sentence
        else:
            attn = dec_attn["copy"]
            scores = self.model.generator(dec_out.view(-1, dec_out.size(2)),
                                          attn.view(-1, attn.size(2)),
                                          src_map)
            # here we have scores [tgt_len x batch, vocab] or [beam x batch, vocab]
            # when batch_offset is not None,
            # scores: [20, 50511] where 50511 is the size of the extended vocab
            if batch_offset is None:
                scores = scores.view(batch.batch_size, -1, scores.size(-1))
            else:
                scores = scores.view(-1, self.beam_size, scores.size(-1))
            scores = collapse_copy_scores(
                scores,
                batch,
                self._tgt_vocab,
                src_vocabs,
                batch_dim=0,
                batch_offset=batch_offset
            )
            scores = scores.view(decoder_in.size(0), -1, scores.size(-1))
            log_probs = scores.squeeze(0).log()
            # returns [(batch_size x beam_size) , vocab ] when 1 step
            # or [ tgt_len, batch_size, vocab ] when full sentence
        return log_probs, attn
Example #9
    def _compute_loss(self,
                      batch,
                      output,
                      target,
                      copy_attn,
                      align,
                      std_attn=None,
                      coverage_attn=None):
        """Compute the loss.

        The args must match :func:`self._make_shard_state()`.

        Args:
            batch: the current batch.
            output: the predict output from the model.
            target: the validate target to compare output with.
            copy_attn: the copy attention value.
            align: the align info.
        """
        target = target.view(-1)
        align = align.view(-1)
        src_map_list = list()
        for src_type in self.src_types:
            src_map_list.append(getattr(batch, f"src_map.{src_type}"))
        # end for
        scores = self.generator(self._bottle(output), self._bottle(copy_attn),
                                src_map_list)
        loss = self.criterion(scores, align, target)

        if self.lambda_coverage != 0.0:
            coverage_loss = self._compute_coverage_loss(
                std_attn, coverage_attn)
            loss += coverage_loss

        # this block does not depend on the loss value computed above
        # and is used only for stats
        scores_data = collapse_copy_scores(
            self._unbottle(scores.clone(), batch.batch_size), batch,
            self.tgt_vocab, None)
        scores_data = self._bottle(scores_data)

        # this block does not depend on the loss value computed above
        # and is used only for stats
        # Correct target copy token instead of <unk>
        # tgt[i] = align[i] + len(tgt_vocab)
        # for i such that tgt[i] == 0 and align[i] != 0
        target_data = target.clone()
        unk = self.criterion.unk_index
        correct_mask = (target_data == unk) & (align != unk)
        offset_align = align[correct_mask] + len(self.tgt_vocab)
        target_data[correct_mask] += offset_align

        # Compute sum of perplexities for stats
        stats = self._stats(loss.sum().clone(), scores_data, target_data)

        # this part looks like it belongs in CopyGeneratorLoss
        if self.normalize_by_length:
            # Compute Loss as NLL divided by seq length
            tgt_lens = batch.tgt[:, :, 0].ne(self.padding_idx).sum(0).float()
            # Compute Total Loss per sequence in batch
            loss = loss.view(-1, batch.batch_size).sum(0)
            # Divide by length of each sequence and sum
            loss = torch.div(loss, tgt_lens).sum()
        else:
            loss = loss.sum()

        return loss, stats
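
The stats block in this loss rewrites UNK targets that were actually copied: where the gold target is UNK but the alignment points at a source token, the target id is shifted into the copy-extended vocabulary (align id + len(tgt_vocab)) so accuracy statistics credit correct copies. A small worked example with toy ids (the vocabulary size and alignments are made up):

    import torch

    tgt_vocab_len, unk = 6, 0
    target = torch.tensor([3, 0, 5, 0])   # two UNK positions in the gold target
    align = torch.tensor([0, 2, 0, 4])    # copy alignments into the source (0 = no copy)

    target_data = target.clone()
    correct_mask = (target_data == unk) & (align != unk)
    # Shift copied positions into the extended vocabulary: align id + len(tgt_vocab).
    target_data[correct_mask] += align[correct_mask] + tgt_vocab_len
    print(target_data)  # tensor([ 3,  8,  5, 10])
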
Example #10
    def _decode_and_generate(self,
                             decoder_in,
                             memory_bank,
                             batch,
                             src_vocabs,
                             memory_lengths,
                             src_map=None,
                             step=None,
                             batch_offset=None,
                             batch_indices=None):
        if self.copy_attn:
            # Turn any copied words into UNKs.
            decoder_in = decoder_in.masked_fill(
                decoder_in.gt(self._tgt_vocab_len - 1), self._tgt_unk_idx)

        # Decoder forward, takes [tgt_len, batch, nfeats] as input
        # and [src_len, batch, hidden] as memory_bank
        # in case of inference tgt_len = 1, batch = beam times batch_size
        # in case of Gold Scoring tgt_len = actual length, batch = 1 batch
        # pdb.set_trace()
        if self.user_emb:
            if batch_indices is not None:
                uid = batch.uid[batch_indices].unsqueeze(-1).expand_as(
                    decoder_in.view(3, -1, 1)).reshape(-1)
                dec_out, dec_attn = self.model.decoder(
                    decoder_in,
                    memory_bank,
                    memory_lengths=memory_lengths,
                    step=step,
                    uid=uid)
            else:
                uid = batch.uid.unsqueeze(-1).expand_as(
                    decoder_in.view(3, -1, 1)).reshape(-1)
                dec_out, dec_attn = self.model.decoder(
                    decoder_in,
                    memory_bank,
                    memory_lengths=memory_lengths,
                    step=step,
                    uid=uid)
        else:
            dec_out, dec_attn = self.model.decoder(
                decoder_in,
                memory_bank,
                memory_lengths=memory_lengths,
                step=step)

        # Generator forward.
        if not self.copy_attn:
            attn = dec_attn["std"]
            log_probs = self.model.generator(dec_out.squeeze(0))
            if self.user_bias != 'none':
                sfm = torch.nn.LogSoftmax(dim=-1)
                #                 print(log_probs.shape)
                out = log_probs.view(self.beam_size, -1, log_probs.shape[-1])

                if self.user_bias == 'factor_cell':
                    if batch_indices is not None:
                        out = out + torch.matmul(
                            self.model.user_bias(batch.uid[batch_indices]),
                            self.model.user_global)
                    else:
                        out = out + torch.matmul(
                            self.model.user_bias(batch.uid),
                            self.model.user_global)
                else:
                    #                     import pdb
                    #                     pdb.set_trace()
                    if batch_indices is not None:
                        out = out + self.model.user_bias(
                            batch.uid[batch_indices])
                    else:
                        out = out + self.model.user_bias(batch.uid)
                log_probs = out.view(out.shape[0] * out.shape[1], -1)
                log_probs = sfm(log_probs).cuda()
            # returns [(batch_size x beam_size) , vocab ] when 1 step
            # or [ tgt_len, batch_size, vocab ] when full sentence
        else:
            attn = dec_attn["copy"]
            scores = self.model.generator(dec_out.view(-1, dec_out.size(2)),
                                          attn.view(-1, attn.size(2)), src_map)
            # here we have scores [tgt_lenxbatch, vocab] or [beamxbatch, vocab]
            if batch_offset is None:
                scores = scores.view(batch.batch_size, -1, scores.size(-1))
            else:
                scores = scores.view(-1, self.beam_size, scores.size(-1))
            scores = collapse_copy_scores(scores,
                                          batch,
                                          self._tgt_vocab,
                                          src_vocabs,
                                          batch_dim=0,
                                          batch_offset=batch_offset)
            scores = scores.view(decoder_in.size(0), -1, scores.size(-1))
            log_probs = scores.squeeze(0).log()
            # returns [(batch_size x beam_size) , vocab ] when 1 step
            # or [ tgt_len, batch_size, vocab ] when full sentence
        return log_probs, attn
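
This variant personalises generation by adding a per-user bias (a full embedding, or a low-rank "factor_cell" product) to the generator output and re-normalising with LogSoftmax. Below is a stripped-down sketch of the factored form; the module names, sizes, and the [beam, batch, vocab] reshape only loosely mirror the example and are assumptions for illustration.

    import torch
    import torch.nn as nn

    vocab, n_users, rank, beam, batch = 20, 5, 4, 3, 2
    user_factors = nn.Embedding(n_users, rank)             # per-user low-rank factors
    user_global = nn.Parameter(torch.randn(rank, vocab))   # shared projection to the vocab

    log_probs = torch.log_softmax(torch.randn(beam * batch, vocab), dim=-1)
    uid = torch.tensor([1, 4])                              # one user id per batch element

    out = log_probs.view(beam, batch, vocab)
    out = out + torch.matmul(user_factors(uid), user_global)  # [batch, vocab] bias, broadcast over beams
    log_probs = torch.log_softmax(out.view(-1, vocab), dim=-1)
    print(log_probs.shape)  # torch.Size([6, 20])
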
Example #11
    def _decode_and_generate(
            self,
            decoder_in,
            memory_bank,
            batch,
            src_vocabs,
            memory_lengths,
            src_map=None,
            step=None,
            batch_offset=None,
            batch_emotion=None,
            batch_tgt_concept_emb=None,
            batch_tgt_concept_words=None):
        if self.copy_attn:
            # Turn any copied words into UNKs.
            decoder_in = decoder_in.masked_fill(
                decoder_in.gt(self._tgt_vocab_len - 1), self._tgt_unk_idx
            )

        # Decoder forward, takes [tgt_len, batch, nfeats] as input
        # and [src_len, batch, hidden] as memory_bank
        # in case of inference tgt_len = 1, batch = beam times batch_size
        # in case of Gold Scoring tgt_len = actual length, batch = 1 batch
        # print(step)
        # print(decoder_in.shape)
        # print(memory_bank.shape)
        # print(batch.emotion.shape)
        # print(batch_offset)
        # batch_emotion = None
        # if hasattr(batch, "emotion"):
        #     batch_emotion = batch.emotion
        #     src_len, batch_size, hidden = memory_bank.shape
        #     # expand emotion beam_size times
        #     batch_emotion = tile(batch_emotion, batch_size//len(batch_emotion))
        
        if isinstance(self.model.decoder, KGTransformerDecoder):
            dec_out, dec_attn = self.model.decoder(
                decoder_in, memory_bank, memory_lengths=memory_lengths, step=step, 
                emotion=batch_emotion, tgt_concept_emb=batch_tgt_concept_emb,
                tgt_concept_words=batch_tgt_concept_words
            )
        else:
            dec_out, dec_attn = self.model.decoder(
                decoder_in, memory_bank, memory_lengths=memory_lengths, step=step, 
                emotion=batch_emotion
            )

        # Generator forward.
        if not self.copy_attn:
            if "std" in dec_attn:
                attn = dec_attn["std"]
            else:
                attn = None
            if self.model.decoder.__class__.__name__ == "EDSDecoder":
                output = dec_out
                tgt_len, batch_size, hidden_size = output.size()

                # if not hasattr(self.model.decoder, "eds_type"):
                #     self.model.decoder.eds_type = 0
                
                ###########
                ###########
                # method 0: vanilla probs
                if self.model.decoder.eds_type in [0]:
                    # print(output.shape)
                    log_probs = self.model.generator[0](output.squeeze(0))
                    # print(log_probs.shape)
                    # log_probs = self.model.generator(dec_out.squeeze(0))
                
                ###########
                ###########
                # method 1: mask on logits, other emotional words have zero logits
                elif self.model.decoder.eds_type in [1]:
                    logits = F.relu(self.model.generator[0][:2](output)) # [tgt_len, batch, vocab], relu refers to the official code in training
                    # logits = torch.sigmoid(self.model.generator[0][:2](output)) # [tgt_len, batch, vocab], sigmoid refers to the official code in generating response
                    type_controller = torch.sigmoid(self.model.generator[1](output)) # [tgt_len, batch, 1]
                    
                    # emotion = batch.emotion
                    # beam_size = batch_size//len(emotion) # expand emotion beam_size times
                    # emotion = tile(emotion, beam_size)
                    
                    mask = self.model.decoder.generic_mask.unsqueeze(0).unsqueeze(0) * type_controller + \
                        self.get_emotion_mask(batch_emotion).unsqueeze(0) * (1-type_controller)
                    logits = mask * logits # [tgt_len, batch, vocab]
                    log_probs = self.model.generator[0][2](logits.squeeze(0))
                    
                ###########
                ###########
                # method 2: two separate generators for generic and emotional words
                # generic vocab = vocab_size - emotion_vocab_size
                elif self.model.decoder.eds_type in [2]:
                    all_probs = torch.zeros((tgt_len, batch_size, self.model.decoder.vocab_size)).type_as(output)
                    generic_probs = torch.softmax(self.model.generator[0][:2](output), dim=-1) # [tgt_len, batch, generic_vocab]
                    emotion_probs = torch.softmax(self.model.generator[1][:2](output), dim=-1) # [tgt_len, batch, emotion_vocab]
                    type_controller = torch.sigmoid(self.model.generator[2](output)) # [tgt_len, batch, 1]
                    generic_probs = (1-type_controller)*generic_probs
                    emotion_probs = type_controller*emotion_probs
                    for i in range(batch_size):
                        generic_indices = torch.cat([self.model.decoder.generic_vocab_indices, self.model.decoder.other_emotion_indices[batch_emotion[i]]])
                        all_probs[:,i,generic_indices] = generic_probs[:,i,:] # [tgt_len, generic_vocab]
                        all_probs[:,i,self.model.decoder.emotion_vocab_indices[batch_emotion[i]]] = emotion_probs[:,i,:] # [tgt_len, emotion_vocab]
                    log_probs = torch.log(all_probs.squeeze(0))

                ###########
                ###########
                # method 3: two separate generators for generic and emotional words
                # emotion_vocab = vocab_size - generic_vocab
                elif self.model.decoder.eds_type in [3]:
                    all_probs = torch.zeros((tgt_len, batch_size, self.model.decoder.vocab_size)).type_as(output)
                    generic_probs = torch.softmax(self.model.generator[0][:2](output), dim=-1) # [tgt_len, batch, generic_vocab]
                    emotion_probs = torch.softmax(self.model.generator[1][:2](output), dim=-1) # [tgt_len, batch, emotion_vocab]
                    type_controller = torch.sigmoid(self.model.generator[2](output)) # [tgt_len, batch, 1]
                    generic_probs = (1-type_controller)*generic_probs
                    emotion_probs = type_controller*emotion_probs
                    all_probs[:,:,self.model.decoder.generic_vocab_indices] = generic_probs # [tgt_len, batch, generic_vocab]
                    all_probs[:,:,self.model.decoder.all_emotion_indices] = emotion_probs # [tgt_len, batch, emotion_vocab]
                    log_probs = torch.log(all_probs.squeeze(0))
                
            else:
                log_probs = self.model.generator(dec_out.squeeze(0))
            # returns [(batch_size x beam_size) , vocab ] when 1 step
            # or [ tgt_len, batch_size, vocab ] when full sentence
        else:
            attn = dec_attn["copy"]
            scores = self.model.generator(dec_out.view(-1, dec_out.size(2)),
                                          attn.view(-1, attn.size(2)),
                                          src_map)
            # here we have scores [tgt_lenxbatch, vocab] or [beamxbatch, vocab]
            if batch_offset is None:
                scores = scores.view(batch.batch_size, -1, scores.size(-1))
            else:
                scores = scores.view(-1, self.beam_size, scores.size(-1))
            scores = collapse_copy_scores(
                scores,
                batch,
                self._tgt_vocab,
                src_vocabs,
                batch_dim=0,
                batch_offset=batch_offset
            )
            scores = scores.view(decoder_in.size(0), -1, scores.size(-1))
            log_probs = scores.squeeze(0).log()
            # returns [(batch_size x beam_size) , vocab ] when 1 step
            # or [ tgt_len, batch_size, vocab ] when full sentence
        return log_probs, attn
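
Methods 2 and 3 in this example mix two generators: a sigmoid "type controller" splits the probability mass, per position, between a generic vocabulary and an emotion vocabulary, and the two softmax outputs are scattered into disjoint index ranges of the full vocabulary. A toy sketch of that gating; the index sets and sizes are invented for illustration.

    import torch

    tgt_len, batch, vocab = 1, 2, 8
    generic_idx = torch.tensor([0, 1, 2, 3, 4])   # toy generic-word ids
    emotion_idx = torch.tensor([5, 6, 7])         # toy emotion-word ids

    generic_probs = torch.softmax(torch.randn(tgt_len, batch, len(generic_idx)), dim=-1)
    emotion_probs = torch.softmax(torch.randn(tgt_len, batch, len(emotion_idx)), dim=-1)
    gate = torch.sigmoid(torch.randn(tgt_len, batch, 1))   # type controller

    all_probs = torch.zeros(tgt_len, batch, vocab)
    all_probs[:, :, generic_idx] = (1 - gate) * generic_probs
    all_probs[:, :, emotion_idx] = gate * emotion_probs
    print(all_probs.sum(dim=-1))  # each position still sums to 1
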
Example #12
    def _decode_and_generate(
            self,
            decoder_in,
            memory_bank,
            # yida translate
            tag_src,
            tag_decoder_in,
            batch,
            src_vocabs,
            memory_lengths,
            src_map=None,
            step=None,
            batch_offset=None):
        if self.copy_attn:
            # Turn any copied words into UNKs.
            decoder_in = decoder_in.masked_fill(
                decoder_in.gt(self._tgt_vocab_len - 1), self._tgt_unk_idx)

        # Decoder forward, takes [tgt_len, batch, nfeats] as input
        # and [src_len, batch, hidden] as memory_bank
        # in case of inference tgt_len = 1, batch = beam times batch_size
        # in case of Gold Scoring tgt_len = actual length, batch = 1 batch
        tag_gen = self.model.generators[
            "tag"] if "tag" in self.model.generators else None
        dec_out, dec_attn, rnn_outs = self.model.decoder(
            # yida translate
            decoder_in,
            memory_bank,
            tag_gen,
            tag_src,
            tag_decoder_in,
            memory_lengths=memory_lengths,
            step=step)

        # Generator forward.
        if not self.copy_attn:
            if "std" in dec_attn:
                attn = dec_attn["std"]
            else:
                attn = None
            # yida translate
            tag_log_probs = None
            if "tag" in self.model.generators:
                tag_outputs = rnn_outs if not isinstance(rnn_outs,
                                                         list) else dec_out
                tag_log_probs = self.model.generators["tag"](
                    tag_outputs.squeeze(0))
                tag_argmax = tag_log_probs.max(1)[1]
                # dist = torch.distributions.Multinomial(logits=tag_log_probs, total_count=1)
                # tag_argmax = torch.argmax(dist.sample(), dim=1)
                # tag_index = torch.multinomial(tag_log_probs, num_samples=1)
                if "high" in self.model.generators:
                    vocab_size = sum([
                        v._modules["0"].out_features
                        for k, v in self.model.generators.items() if k != "tag"
                    ])
                    log_probs = torch.full(
                        [dec_out.squeeze(0).shape[0], vocab_size],
                        -float('inf'),
                        dtype=torch.float,
                        device=self._dev)
                    for k, gen in self.model.generators.items():
                        if k == "tag":
                            continue
                        indices = tag_argmax.eq(self.tag_vocab[k])
                        if indices.any():
                            k_output = dec_out.squeeze(0)[indices]
                            k_logits = gen(k_output)
                            mask = indices.float().unsqueeze(-1).mm(
                                self.tag_mask[k])
                            log_probs.masked_scatter_(
                                mask.bool(), k_logits.log_softmax(dim=-1))
                else:
                    logits = self.model.generators["generator"](
                        dec_out.squeeze(0))
                    if self.mask_decode:
                        high_num = self.tag_mask["high"].sum().long().item()
                        high_indices = tag_argmax.eq(self.tag_vocab["high"])
                        low_indices = tag_argmax.eq(self.tag_vocab["low"])
                        # logits[high_indices, high_num:] = -float("inf")
                        logits[low_indices, :high_num] = -float("inf")
                    log_probs = torch.log_softmax(logits, dim=-1)
            else:
                logits = self.model.generators["generator"](dec_out.squeeze(0))
                log_probs = torch.log_softmax(logits, dim=-1)

            # returns [(batch_size x beam_size) , vocab ] when 1 step
            # or [ tgt_len, batch_size, vocab ] when full sentence
        else:
            attn = dec_attn["copy"]
            scores = self.model.generator(dec_out.view(-1, dec_out.size(2)),
                                          attn.view(-1, attn.size(2)), src_map)
            # here we have scores [tgt_lenxbatch, vocab] or [beamxbatch, vocab]
            if batch_offset is None:
                scores = scores.view(-1, batch.batch_size, scores.size(-1))
                scores = scores.transpose(0, 1).contiguous()
            else:
                scores = scores.view(-1, self.beam_size, scores.size(-1))
            scores = collapse_copy_scores(scores,
                                          batch,
                                          self._tgt_vocab,
                                          src_vocabs,
                                          batch_dim=0,
                                          batch_offset=batch_offset)
            scores = scores.view(decoder_in.size(0), -1, scores.size(-1))
            log_probs = scores.squeeze(0).log()
            # returns [(batch_size x beam_size) , vocab ] when 1 step
            # or [ tgt_len, batch_size, vocab ] when full sentence
        # yida translate
        return log_probs, attn, tag_log_probs
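
When the tag generator predicts which vocabulary partition each position should draw from, the mask_decode path above forbids the high-frequency region for positions tagged "low" by setting their logits to -inf before the softmax. A minimal sketch of that masking with toy sizes (the partition boundary and tag ids are assumptions):

    import torch

    batch, vocab, high_num = 4, 10, 6              # first 6 ids form the "high-frequency" region
    logits = torch.randn(batch, vocab)
    tag_argmax = torch.tensor([0, 1, 1, 0])        # toy tag predictions: 0 = high, 1 = low

    low_indices = tag_argmax.eq(1)
    logits[low_indices, :high_num] = -float("inf")  # block high-frequency ids for "low" rows
    log_probs = torch.log_softmax(logits, dim=-1)
    print(log_probs[1, :high_num])                  # all -inf: masked out
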
Example #13
    def _decode_and_generate(self,
                             decoder_in,
                             memory_bank,
                             batch,
                             src_vocabs,
                             memory_lengths,
                             src_map=None,
                             step=None,
                             batch_offset=None):
        if self.copy_attn:
            # Turn any copied words into UNKs.
            decoder_in = decoder_in.masked_fill(
                decoder_in.gt(self._tgt_vocab_len - 1), self._tgt_unk_idx)
        if show_debug_detail:
            print("debug in _decode_and_generate, memory_bank",
                  memory_bank.size(), memory_bank.type())
            print("debug in _decode_and_generate, decoder_in",
                  decoder_in.size(), decoder_in.type())
            print('debug in _decode_and_generate, memory_lengths',
                  memory_lengths, memory_lengths.type())

            print('debug max_relative_positions',
                  self.model.decoder.transformer_layers[0].self_attn.
                  max_relative_positions)  # always 0
            print('debug word_vec_size',
                  self.model.decoder.embeddings.word_vec_size)
            print('debug word_padding_idx',
                  self.model.decoder.embeddings.word_padding_idx)
            print('debug step', step, type(step))

        # Decoder forward, takes [tgt_len, batch, nfeats] as input
        # and [src_len, batch, hidden] as memory_bank
        # in case of inference tgt_len = 1, batch = beam times batch_size
        # in case of Gold Scoring tgt_len = actual length, batch = 1 batch

        tic = time.perf_counter()
        dec_out, dec_attn = self.model.decoder(decoder_in,
                                               memory_bank,
                                               memory_lengths=memory_lengths,
                                               step=step)
        toc = time.perf_counter()
        onmt_decoder_time = toc - tic

        tic = time.perf_counter()
        turbo_dec_out, turbo_dec_attn = self.turbo_decoder(
            decoder_in, memory_bank, memory_lengths=memory_lengths, step=step)
        toc = time.perf_counter()
        turbo_decoder_time = toc - tic
        global_timer.turbo_timer += turbo_decoder_time
        global_timer.torch_timer += onmt_decoder_time
        # check correctness
        assert torch.max(torch.abs(dec_out - turbo_dec_out)) < 1e-3
        assert torch.max(
            torch.abs(dec_attn["std"] - turbo_dec_attn["std"])) < 1e-3

        # print("onmt time: ", onmt_decoder_time, " turbo time: ", turbo_decoder_time)
        # print("dec_out diff: ", torch.max(
        #             torch.abs(dec_out -
        #                       turbo_dec_out)))
        # print("attn diff: ", torch.max(
        #             torch.abs(dec_attn["std"] -
        #                       turbo_dec_attn["std"])))

        tic = time.perf_counter()
        # Generator forward.
        if not self.copy_attn:
            if "std" in dec_attn:
                attn = dec_attn["std"]
            else:
                attn = None
            log_probs = self.model.generator(dec_out.squeeze(0))
            # returns [(batch_size x beam_size) , vocab ] when 1 step
            # or [ tgt_len, batch_size, vocab ] when full sentence
        else:
            attn = dec_attn["copy"]
            scores = self.model.generator(dec_out.view(-1, dec_out.size(2)),
                                          attn.view(-1, attn.size(2)), src_map)
            # here we have scores [tgt_lenxbatch, vocab] or [beamxbatch, vocab]
            if batch_offset is None:
                scores = scores.view(-1, batch.batch_size, scores.size(-1))
                scores = scores.transpose(0, 1).contiguous()
            else:
                scores = scores.view(-1, self.beam_size, scores.size(-1))
            scores = collapse_copy_scores(scores,
                                          batch,
                                          self._tgt_vocab,
                                          src_vocabs,
                                          batch_dim=0,
                                          batch_offset=batch_offset)
            scores = scores.view(decoder_in.size(0), -1, scores.size(-1))
            log_probs = scores.squeeze(0).log()
            # returns [(batch_size x beam_size) , vocab ] when 1 step
            # or [ tgt_len, batch_size, vocab ] when full sentence
        toc = time.perf_counter()
        others_time = toc - tic
        if show_profile_detail:
            print(f"others_time {others_time:0.4f}")
        return log_probs, attn
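
Example #13 times the reference decoder against a drop-in replacement with time.perf_counter() and then checks the outputs element-wise. A compact version of that benchmark-and-verify pattern; torch.allclose states the "every element within tolerance" intent explicitly (the stand-in functions below are placeholders, not the real decoders).

    import time
    import torch

    def timed(fn, *args, **kwargs):
        tic = time.perf_counter()
        out = fn(*args, **kwargs)
        return out, time.perf_counter() - tic

    reference = lambda x: torch.tanh(x)          # stand-in for the original decoder
    candidate = lambda x: torch.tanh(x) + 1e-7   # stand-in for the optimised decoder

    x = torch.randn(8, 16)
    ref_out, ref_time = timed(reference, x)
    cand_out, cand_time = timed(candidate, x)
    assert torch.allclose(ref_out, cand_out, atol=1e-3)  # all elements within tolerance
    print(f"reference {ref_time:.6f}s  candidate {cand_time:.6f}s")
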
Example #14
    def _decode_and_generate(
        self,
        decoder_in,
        memory_bank,
        batch,
        src_vocabs,
        memory_lengths,
        src_map=None,
        step=None,
        batch_offset=None,
        # we need src_origin to split the history/current content
        src_origin=None):
        if self.copy_attn:
            # Turn any copied words into UNKs.
            decoder_in = decoder_in.masked_fill(
                decoder_in.gt(self._tgt_vocab_len - 1), self._tgt_unk_idx)

        sep_id = batch.dataset.fields['src'].base_field.vocab.stoi['[SEP]']

        # Decoder forward, takes [tgt_len, batch, nfeats] as input
        # and [src_len, batch, hidden] as memory_bank
        # in case of inference tgt_len = 1, batch = beam times batch_size
        # in case of Gold Scoring tgt_len = actual length, batch = 1 batch
        dec_out, dec_attn = self.model.decoder(decoder_in,
                                               memory_bank,
                                               memory_lengths=memory_lengths,
                                               step=step,
                                               sep_id=sep_id,
                                               src_origin=src_origin,
                                               batch_offset=batch_offset)

        # Generator forward.
        if not self.copy_attn:
            if "std" in dec_attn:
                attn = dec_attn["std"]
            else:
                attn = None
            log_probs = self.model.generator(dec_out.squeeze(0))
            # returns [(batch_size x beam_size) , vocab ] when 1 step
            # or [ tgt_len, batch_size, vocab ] when full sentence
        else:
            if "copy" in dec_attn:
                attn = dec_attn["copy"]
                scores = self.model.generator(
                    dec_out.view(-1, dec_out.size(2)),
                    attn.view(-1, attn.size(2)), src_map)
            elif "his_copy" in dec_attn:
                his_attn = dec_attn["his_copy"]
                cur_attn = dec_attn["cur_copy"]
                his_mid = dec_attn["his_mid"]
                cur_mid = dec_attn["cur_mid"]

                attn_shape = his_attn.shape
                scores, attn = self.model.generator(
                    dec_out.view(-1, dec_out.size(2)),
                    his_attn.view(-1, his_attn.size(2)),
                    cur_attn.view(-1, cur_attn.size(2)),
                    his_mid.view(-1, his_mid.size(2)),
                    cur_mid.view(-1, cur_mid.size(2)), src_map)
                # convert attn as the same shape as normal copy
                attn = attn.view(attn_shape)
            # here we have scores [tgt_lenxbatch, vocab] or [beamxbatch, vocab]
            if batch_offset is None:
                scores = scores.view(-1, batch.batch_size, scores.size(-1))
                scores = scores.transpose(0, 1).contiguous()
            else:
                scores = scores.view(-1, self.beam_size, scores.size(-1))
            scores = collapse_copy_scores(scores,
                                          batch,
                                          self._tgt_vocab,
                                          src_vocabs,
                                          batch_dim=0,
                                          batch_offset=batch_offset)
            scores = scores.view(decoder_in.size(0), -1, scores.size(-1))
            log_probs = scores.squeeze(0).log()
            # returns [(batch_size x beam_size) , vocab ] when 1 step
            # or [ tgt_len, batch_size, vocab ] when full sentence
        return log_probs, attn
Example #15
    def forward_pass(self,
                     decoder_in,
                     step,
                     past=None,
                     input_embeds=None,
                     tags=None,
                     use_copy=True):
        memory_bank = self.memory_bank
        src_vocabs = self.src_vocabs
        memory_lengths = self.memory_lengths
        src_map = self.src_map
        batch = self.batch
        if self.copy_attn:
            # Turn any copied words into UNKs.
            decoder_in = decoder_in.masked_fill(
                decoder_in.gt(self._tgt_vocab_len - 1), self._tgt_unk_idx)

        decoder = self.model.decoder

        dec_out, all_hidden_states, past, dec_attn = decoder(
            decoder_in,
            memory_bank,
            memory_lengths=memory_lengths,
            step=step,
            past=past,
            input_embeds=input_embeds,
            pplm_return=True)

        # Generator forward.
        if not self.copy_attn:
            if "std" in dec_attn:
                attn = dec_attn["std"]
            else:
                attn = None

            if self.simple_fusion:
                lm_dec_out, _ = self.model.lm_decoder(
                    decoder_in, memory_bank.new_zeros(1, 1, 1), step=step)
                probs = self.model.generator(dec_out.squeeze(0),
                                             lm_dec_out.squeeze(0))
            else:
                probs = self.model.generator(dec_out.squeeze(0))
                # print(log_probs)
                # returns [(batch_size x beam_size) , vocab ] when 1 step
                # or [ tgt_len, batch_size, vocab ] when full sentence
        else:
            attn = dec_attn["copy"]

            scores, p_copy = self.model.generator(
                dec_out.view(-1, dec_out.size(2)),
                attn.view(-1, attn.size(2)),
                src_map,
                tags=tags)

            scores = scores.view(batch.batch_size, -1, scores.size(-1))
            scores = collapse_copy_scores(scores,
                                          batch,
                                          self._tgt_vocab,
                                          src_vocabs,
                                          batch_dim=0,
                                          batch_offset=None)
            scores = scores.view(decoder_in.size(0), -1, scores.size(-1))
            # log_probs = scores.squeeze(0).log()
            probs = scores.squeeze(0)
            if use_copy is False:
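                # Assumption: 50257 matches the base GPT-2 vocabulary size, so this
                # drops the dynamic copy-vocab extension when copying is disabled.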
                probs = probs[:, :50257]
                return probs, attn, all_hidden_states, past
            return probs, attn, all_hidden_states, past, p_copy
        return probs, attn, all_hidden_states, past
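
A rough usage sketch of the forward_pass above in a greedy step-by-step loop; the translator object, bos_idx and max_steps names are illustrative, and only the return shapes documented in the snippet are relied on.

import torch


def greedy_decode_sketch(translator, bos_idx, max_steps=50, use_copy=True):
    # decoder_in follows the [tgt_len, batch, nfeats] convention with tgt_len = 1.
    decoder_in = torch.full((1, 1, 1), bos_idx, dtype=torch.long)
    past = None
    outputs = []
    for step in range(max_steps):
        result = translator.forward_pass(decoder_in, step,
                                         past=past, use_copy=use_copy)
        # With copy attention on, a fifth value (p_copy) is also returned.
        probs, attn, all_hidden_states, past = result[:4]
        next_token = probs.argmax(dim=-1)       # greedy choice per batch row
        outputs.append(next_token)
        decoder_in = next_token.view(1, -1, 1)  # feed back as [1, batch, 1]
    return torch.stack(outputs)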
Example #16
0
    def _decode_and_generate(self,
                             decoder_in,
                             memory_bank,
                             batch,
                             src_vocabs,
                             memory_lengths,
                             src_map=None,
                             step=None,
                             batch_offset=None):

        if self.copy_attn:
            # Turn any copied words into UNKs.
            decoder_in = decoder_in.masked_fill(
                decoder_in.gt(self._tgt_vocab_len - 1), self._tgt_unk_idx)

        # Decoder forward, takes [tgt_len, batch, nfeats] as input
        # and [src_len, batch, hidden] as memory_bank
        # in case of inference tgt_len = 1, batch = beam times batch_size
        # in case of Gold Scoring tgt_len = actual length, batch = 1 batch

        #print(">>>>> hidden state of the decoder:    <<<<<")
        #print(self.model.decoder.state["hidden"][1])
        #before_state1, before_state2 = self.model.decoder.state["hidden"][0].clone(), self.model.decoder.state["hidden"][1].clone()
        #print(before_state == self.model.decoder.state["hidden"])

        dec_out, dec_attn = self.model.decoder(
            decoder_in,
            memory_bank,
            memory_lengths=memory_lengths,
            step=step,
            counterfactual_attention_method=self.counterfactual_attention_method)

        hack_dict = {}

        # Generator forward.
        if not self.copy_attn:
            if "std" in dec_attn:
                attn = dec_attn["std"]
            else:
                attn = None

            log_probs = self.model.generator(dec_out.squeeze(0))

            if self.counterfactual_attention_method == 'uniform_or_zero_out_max':
                hack_dict['log_probs_uniform'] = self.model.generator(
                    dec_attn['uniform'][1].squeeze(0))
                hack_dict['attention_matrix_uniform'] = dec_attn['uniform'][0]
                hack_dict['log_probs_zero_out_max'] = self.model.generator(
                    dec_attn['zero_out_max'][1].squeeze(0))
                hack_dict['attention_matrix_zero_out_max'] = dec_attn[
                    'zero_out_max'][0]
                hack_dict['log_probs_permute'] = self.model.generator(
                    dec_attn['permute'][1].squeeze(0))
                hack_dict['attention_matrix_permute'] = dec_attn['permute'][0]

            elif self.counterfactual_attention_method is not None:
                hack_dict['log_probs_counterfactual'] = self.model.generator(
                    dec_attn[self.counterfactual_attention_method][1].squeeze(
                        0))
                hack_dict['attention_matrix'] = dec_attn[
                    self.counterfactual_attention_method][0]

            #if tvd_permute is True:
            #    dec_outs = dec_attn["std_tvd_permute"]
            #    distances = []
            #    for my_dec_out in dec_outs:
            #        my_dec_out = my_dec_out.squeeze(0)
            #        my_log_prob = self.model.generator(my_dec_out)

            #        # distance = tvd(torch.exp(log_probs), torch.exp(my_log_prob))
            #        distance = high_distance(torch.exp(log_probs), torch.exp(my_log_prob))[0]
            #        distances.append(distance)

            #    dist_change_median = torch.median(torch.stack(distances), dim=0).values

            #    if attn.size()[0] != 1:
            #        print(">>>> Shit! Target length in attention is more than 1 <<<<")
            #        assert False

            #    max_attention = attn[0].max(dim=1).values

            #    hack_dict['tvd_dist_change_median'] = dist_change_median
            #    hack_dict['tvd_max_attention'] = max_attention

            # returns [(batch_size x beam_size) , vocab ] when 1 step
            # or [ tgt_len, batch_size, vocab ] when full sentence
        else:
            attn = dec_attn["copy"]
            scores = self.model.generator(dec_out.view(-1, dec_out.size(2)),
                                          attn.view(-1, attn.size(2)), src_map)
            # here we have scores [tgt_lenxbatch, vocab] or [beamxbatch, vocab]
            if batch_offset is None:
                scores = scores.view(batch.batch_size, -1, scores.size(-1))
            else:
                scores = scores.view(-1, self.beam_size, scores.size(-1))
            scores = collapse_copy_scores(scores,
                                          batch,
                                          self._tgt_vocab,
                                          src_vocabs,
                                          batch_dim=0,
                                          batch_offset=batch_offset)
            scores = scores.view(decoder_in.size(0), -1, scores.size(-1))
            log_probs = scores.squeeze(0).log()
            # returns [(batch_size x beam_size) , vocab ] when 1 step
            # or [ tgt_len, batch_size, vocab ] when full sentence

        if self.counterfactual_attention_method is not None:
            return log_probs, attn, hack_dict
        else:
            return log_probs, attn
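
One plausible way a caller could consume the hack_dict returned above is to compare the factual and counterfactual output distributions, e.g. by total variation distance, as the commented-out block hints; the helpers below are a hedged sketch, not the snippet's own evaluation code.

import torch


def tvd(p, q):
    # Total variation distance between two probability distributions.
    return 0.5 * (p - q).abs().sum(dim=-1)


def compare_counterfactuals_sketch(log_probs, hack_dict):
    factual = torch.exp(log_probs)
    report = {}
    for key, value in hack_dict.items():
        if not key.startswith('log_probs'):
            continue
        counterfactual = torch.exp(value)
        report[key] = {
            'tvd': tvd(factual, counterfactual),
            'same_argmax': factual.argmax(-1) == counterfactual.argmax(-1),
        }
    return report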
Example #17
0
    def _decode_and_generate(self,
                             decoder_in,
                             memory_bank,
                             batch,
                             src_vocabs,
                             memory_lengths,
                             src_map=None,
                             step=None,
                             batch_offset=None,
                             alive_seq=None,
                             valid=False):
        if self.copy_attn:
            # Turn any copied words into UNKs.
            decoder_in = decoder_in.masked_fill(
                decoder_in.gt(self._tgt_vocab_len - 1), self._tgt_unk_idx)

        # Decoder forward, takes [tgt_len, batch, nfeats] as input
        # and [src_len, batch, hidden] as memory_bank
        # in case of inference tgt_len = 1, batch = beam times batch_size
        # in case of Gold Scoring tgt_len = actual length, batch = 1 batch

        dec_out, dec_attn = self.model.decoder(decoder_in,
                                               memory_bank,
                                               memory_lengths=memory_lengths,
                                               step=step)

        if not valid:
            dec_out = dec_out[-1, :, :].unsqueeze(0)
            dec_attn['std'] = dec_attn['std'][-1, :, :].unsqueeze(0)

        # Generator forward.
        if not self.copy_attn:
            if "std" in dec_attn:
                attn = dec_attn["std"]
            else:
                attn = None
            log_probs = self.model.generator(dec_out.squeeze(0))
            # returns [(batch_size x beam_size) , vocab ] when 1 step
            # or [ tgt_len, batch_size, vocab ] when full sentence
        else:
            if valid:
                attn = dec_attn["copy"][:decoder_in.size(
                    0), :, :]  #.unsqueeze(0)
            else:
                attn = dec_attn["copy"][-1, :, :].unsqueeze(0)
            scores = self.model.generator(dec_out.view(-1, dec_out.size(2)),
                                          attn.view(-1, attn.size(2)), src_map)
            # here we have scores [tgt_lenxbatch, vocab] or [beamxbatch, vocab]
            if batch_offset is None:
                scores = scores.view(batch.batch_size, -1, scores.size(-1))
            else:
                scores = scores.view(-1, self.beam_size, scores.size(-1))
            scores = collapse_copy_scores(scores,
                                          batch,
                                          self._tgt_vocab,
                                          src_vocabs,
                                          batch_dim=0,
                                          batch_offset=batch_offset)
            scores = scores.view(
                decoder_in.size(0) if valid else 1, -1,
                scores.size(-1))  #Original first argument: decoder_in.size(0)
            log_probs = scores.squeeze(0).log()
            # returns [(batch_size x beam_size) , vocab ] when 1 step
            # or [ tgt_len, batch_size, vocab ] when full sentence

        return log_probs, attn
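
A small shape illustration of the valid=False slicing above: when the decoder is re-run over the whole running prefix, only the newest time step is needed to predict the next token. The concrete sizes here are arbitrary.

import torch

tgt_len, batch, hidden = 5, 2, 8
dec_out = torch.randn(tgt_len, batch, hidden)

# Keep only the last decoder step but preserve a leading time dimension of 1,
# so the downstream generator/view logic sees the usual [1, batch, hidden].
last_step = dec_out[-1, :, :].unsqueeze(0)
assert last_step.shape == (1, batch, hidden)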
Example #18
0
    def _decode_and_generate(
            self,
            decoder_in,
            memory_bank,
            batch,
            src_vocabs,
            memory_lengths,
            src_map=None,
            step=None,
            batch_offset=None):
        if self.copy_attn:
            # Turn any copied words into UNKs.
            decoder_in = decoder_in.masked_fill(
                decoder_in.gt(self._tgt_vocab_len - 1), self._tgt_unk_idx
            )

        # Decoder forward, takes [tgt_len, batch, nfeats] as input
        # and [src_len, batch, hidden] as memory_bank
        # in case of inference tgt_len = 1, batch = beam times batch_size
        # in case of Gold Scoring tgt_len = actual length, batch = 1 batch
        dec_out, dec_attn = self.model.decoder(
            decoder_in, memory_bank, memory_lengths=memory_lengths, step=step
        )

        # Generator forward.
        if not self.copy_attn:
            if "std" in dec_attn:
                attn = dec_attn["std"]
            else:
                attn = None
             
            if self.trans_to is not None:
                dec_out_ = dec_out.view(-1, dec_out.shape[-1])#squeeze(0)
                '''
                lang_ids = torch.empty(dec_out_.shape[0], device=dec_out_.device).fill_(self.trans_to).view(-1)
                self.model.generator.set_lang(lang_ids)
                '''
                log_probs = self.model.generator(dec_out_)
            else:
                dec_out_ = dec_out.view(-1, dec_out.shape[-1])
                log_probs = self.model.generator(dec_out_)#.squeeze(0))
            # returns [(batch_size x beam_size) , vocab ] when 1 step
            # or [ tgt_len, batch_size, vocab ] when full sentence
        else:
            attn = dec_attn["copy"]
            scores = self.model.generator(dec_out.view(-1, dec_out.size(2)),
                                          attn.view(-1, attn.size(2)),
                                          src_map)
            # here we have scores [tgt_lenxbatch, vocab] or [beamxbatch, vocab]
            if batch_offset is None:
                scores = scores.view(batch.batch_size, -1, scores.size(-1))
            else:
                scores = scores.view(-1, self.beam_size, scores.size(-1))
            scores = collapse_copy_scores(
                scores,
                batch,
                self._tgt_vocab,
                src_vocabs,
                batch_dim=0,
                batch_offset=batch_offset
            )
            scores = scores.view(decoder_in.size(0), -1, scores.size(-1))
            log_probs = scores.squeeze(0).log()
            # returns [(batch_size x beam_size) , vocab ] when 1 step
            # or [ tgt_len, batch_size, vocab ] when full sentence
        return log_probs, attn
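
The commented-out set_lang call above suggests a generator that can be conditioned on a target-language id; the class below is a minimal sketch of that idea under assumed names, not the snippet's actual generator.

import torch
import torch.nn as nn


class LangAwareGeneratorSketch(nn.Module):
    def __init__(self, hidden_size, vocab_size, n_langs):
        super().__init__()
        self.lang_emb = nn.Embedding(n_langs, hidden_size)
        self.proj = nn.Linear(hidden_size, vocab_size)
        self._lang_ids = None

    def set_lang(self, lang_ids):
        # lang_ids: [batch] tensor of target-language indices, set per batch.
        self._lang_ids = lang_ids

    def forward(self, dec_out):
        # dec_out: [batch, hidden_size] flattened decoder states.
        if self._lang_ids is not None:
            dec_out = dec_out + self.lang_emb(self._lang_ids)
        return torch.log_softmax(self.proj(dec_out), dim=-1)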
Example #19
0
    def _decode_and_generate(
            self,
            decoder_in,
            memory_bank,
            batch,
            src_vocabs,
            memory_lengths,
            src_map=None,
            step=None,
            batch_offset=None):
        if self.copy_attn:
            # Turn any copied words into UNKs.
            decoder_in = decoder_in.masked_fill(
                decoder_in.gt(self._tgt_vocab_len - 1), self._tgt_unk_idx
            )

        # Decoder forward, takes [tgt_len, batch, nfeats] as input
        # and [src_len, batch, hidden] as memory_bank
        # in case of inference tgt_len = 1, batch = beam times batch_size
        # in case of Gold Scoring tgt_len = actual length, batch = 1 batch
        dec_out, dec_attn = self.model.decoder(
            decoder_in, memory_bank, memory_lengths=memory_lengths, step=step
        )

        # MMM
        if len(self.length_model) > 0:
            # print('Using target length model.')
            if self.length_model == 'oracle':
                # print('Length model: oracle')
                pad = self._tgt_pad_idx
                eos = self._tgt_eos_idx
                #sequence has <s> and </s>
                t_lens = (batch.tgt != pad).sum(dim=0).squeeze(1)
                # # add noise to t_lens for experiments (just for test)
                # noisy_t_lens = torch.tensor([l+randint(-2,2) for l in t_lens])
                # if self.cuda:
                #     noisy_t_lens = noisy_t_lens.to('cuda')
                # self.model.generator[-1].t_lens = noisy_t_lens
                self.device = 'cuda' if self._use_cuda else 'cpu'
                self.model.generator[-1].t_lens = \
                    t_lens if batch_offset is None else t_lens.index_select(0, batch_offset.to(self.device))
                self.model.generator[-1].eos_ind = eos
                self.model.generator[-1].batch_max_len = batch.tgt.size(0)
            elif self.length_model == 'fixed_ratio':
                # print('Length model: fixed_ratio')
                pad = self._tgt_pad_idx
                eos = self._tgt_eos_idx
                #sequence has <s> and </s>
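                # Assumption: the constant below is a corpus-level target/source
                # length ratio estimated offline; source lengths are in batch.src[1].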
                t_lens = torch.ceil(batch.src[1].float().cuda()*1.0699071772279642)
                # # add noise to t_lens for experiments (just for test)
                # noisy_t_lens = torch.tensor([l+randint(-2,2) for l in t_lens])
                # if self.cuda:
                #     noisy_t_lens = noisy_t_lens.to('cuda')
                # self.model.generator[-1].t_lens = noisy_t_lens
                self.model.generator[-1].t_lens = \
                    t_lens if batch_offset is None else t_lens.index_select(0, batch_offset.to(self.device))
                self.model.generator[-1].eos_ind = eos
                self.model.generator[-1].batch_max_len = batch.tgt.size(0)
            elif self.length_model == 'lstm':
                # print('Length model: lstm')
                pad = self._tgt_pad_idx
                eos = self._tgt_eos_idx
                #sequence has <s> and </s>
                #TODO: the code itself must handle ratio and diff lstm length models
                t_lens = []
                src_vocab = dict(self.fields)["src"].base_field.vocab
                ratios = onmt.utils.length_model.predict_length_ratio(self.l_model, self.device,
                                                                      batch.src[0].squeeze().transpose(0, 1), src_vocab)
                # diffs = torch.round(onmt.utils.length_model.predict_length_ratio(self.l_model, self.device,
                #                                                       batch.src[0].squeeze().transpose(0, 1), src_vocab))
                # target sequence has <s> and </s>, but source sequence doesn't have them
                t_lens = ratios * batch.src[1].type(torch.FloatTensor).to(self.device) + 2
                # t_lens = torch.max((diffs + batch.src[1].type(torch.FloatTensor).to(self.device)), torch.zeros(batch.tgt.size(1)).to(self.device)) + 2

                # # add noise to t_lens for experiments (just for test)
                # noisy_t_lens = torch.tensor([l+randint(-2,2) for l in t_lens])
                # if self.cuda:
                #     noisy_t_lens = noisy_t_lens.to('cuda')
                # self.model.generator[-1].t_lens = noisy_t_lens
                self.model.generator[-1].t_lens =\
                    t_lens if batch_offset is None else t_lens.index_select(0, batch_offset.to(self.device))
                self.model.generator[-1].eos_ind = eos
                self.model.generator[-1].batch_max_len = batch.tgt.size(0)
        # /MMM

        # Generator forward.
        if not self.copy_attn:
            if "std" in dec_attn:
                attn = dec_attn["std"]
            else:
                attn = None
            log_probs = self.model.generator(dec_out.squeeze(0))
            # returns [(batch_size x beam_size) , vocab ] when 1 step
            # or [ tgt_len, batch_size, vocab ] when full sentence
        else:
            attn = dec_attn["copy"]
            scores = self.model.generator(dec_out.view(-1, dec_out.size(2)),
                                          attn.view(-1, attn.size(2)),
                                          src_map)
            # here we have scores [tgt_lenxbatch, vocab] or [beamxbatch, vocab]
            if batch_offset is None:
                scores = scores.view(-1, batch.batch_size, scores.size(-1))
                scores = scores.transpose(0, 1).contiguous()
            else:
                scores = scores.view(-1, self.beam_size, scores.size(-1))
            scores = collapse_copy_scores(
                scores,
                batch,
                self._tgt_vocab,
                src_vocabs,
                batch_dim=0,
                batch_offset=batch_offset
            )
            scores = scores.view(decoder_in.size(0), -1, scores.size(-1))
            log_probs = scores.squeeze(0).log()
            # returns [(batch_size x beam_size) , vocab ] when 1 step
            # or [ tgt_len, batch_size, vocab ] when full sentence
        return log_probs, attn
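
Since all three length-model branches above only set t_lens, eos_ind and batch_max_len on the last generator module, here is a hedged sketch of what such a module could look like; the EOS-penalty rule and the step bookkeeping are assumptions, not the snippet's actual implementation.

import torch
import torch.nn as nn


class LengthControlledLogSoftmaxSketch(nn.Module):
    def __init__(self):
        super().__init__()
        self.t_lens = None         # oracle/predicted target lengths, [batch]
        self.eos_ind = None        # index of </s> in the target vocabulary
        self.batch_max_len = None  # longest target sequence in the batch
        self.step = 0              # assumed to be reset/advanced by the caller

    def forward(self, logits):
        # logits: [batch, vocab] pre-softmax generator outputs for one step.
        log_probs = torch.log_softmax(logits, dim=-1)
        if self.t_lens is not None and self.eos_ind is not None:
            reached = (self.step + 1) >= self.t_lens
            # Discourage </s> before the target length is reached, then release it.
            penalty = torch.where(
                reached,
                torch.zeros_like(self.t_lens, dtype=log_probs.dtype),
                torch.full_like(self.t_lens, -10.0, dtype=log_probs.dtype))
            log_probs[:, self.eos_ind] = log_probs[:, self.eos_ind] + penalty
        return log_probs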