Example #1
0
    def forward(self,
                inputs,
                encoder_state,
                prev_decoder_state=None,
                controls=None):
        """Run one decoder pass: input net -> RNN -> attention -> predictor.

        Args:
            inputs: Raw decoder inputs; passed through ``self.input_net``.
            encoder_state: Mapping read under ``"state"`` (initial RNN state
                when ``prev_decoder_state`` is None) and ``"output"``
                (memory attended over by ``self.attention_net``).
            prev_decoder_state: RNN state carried over from a previous call.
                Defaults to ``encoder_state["state"]``.
            controls: Optional extra inputs merged into the RNN input via
                ``self._concat_controls``.

        Returns:
            The object returned by ``self.predictor_net`` (indexed like a
            dict), with the new RNN state stored under ``"decoder_state"``.
        """
        rnn_input = self.input_net(inputs)
        if prev_decoder_state is None:
            prev_decoder_state = encoder_state["state"]

        if controls is not None:
            rnn_input = self._concat_controls(rnn_input, controls)

        if isinstance(rnn_input, Variable):
            # Don't use a length-aware RNN here: this avoids unpacking and
            # the need for length-sorted inputs. The raw tensor is fed to the
            # RNN and the Variable's metadata is re-attached afterwards.
            rnn_output, rnn_state = self.rnn(rnn_input.data,
                                             prev_decoder_state)
            rnn_output = rnn_input.new_with_meta(rnn_output)
        else:
            # Pass the input itself. On a plain tensor ``.data`` returns a
            # detached tensor, which would silently stop gradients from
            # flowing back into ``input_net`` / ``_concat_controls``.
            rnn_output, rnn_state = self.rnn(rnn_input, prev_decoder_state)

        attention = self.attention_net(rnn_output, encoder_state["output"])
        if attention["output"] is not None:
            # dim=2 is presumably the feature axis of a 3-D output -- confirm.
            hidden_state = plum.cat([rnn_output, attention["output"]], dim=2)
        else:
            hidden_state = rnn_output

        pre_output = self.pre_output_net(hidden_state)
        output = self.predictor_net(pre_output)

        output["decoder_state"] = rnn_state
        return output
Example #2
0
    def __call__(self, items):
        """Merge a list of ``Variable`` items into one batched ``Variable``.

        All items must share ``batch_dim``, ``length_dim`` and ``pad_value``
        (enforced with ``assert``). If ``self.pad_batches`` is set, each item
        is first padded along its batch dimension up to the largest
        ``batch_size``. Every item is then padded along its length dimension
        up to the overall maximum length. Finally, the item data is
        concatenated along the shared batch dimension.

        Returns:
            A ``Variable`` holding the concatenated data, the concatenated
            per-item ``lengths`` and the shared dimension/pad settings.
        """
        reference = items[0]
        batch_dim = reference.batch_dim
        length_dim = reference.length_dim
        pad_value = reference.pad_value
        for other in items[1:]:
            assert batch_dim == other.batch_dim
            assert length_dim == other.length_dim
            assert pad_value == other.pad_value

        if self.pad_batches:
            largest_batch = max([item.batch_size for item in items])
            items = [item.pad_batch_dim(largest_batch - item.batch_size)
                     for item in items]

        combined_lengths = plum.cat([item.lengths for item in items])
        longest = combined_lengths.max()
        items = [item.pad_length_dim(longest - item.lengths.max())
                 for item in items]

        combined_data = plum.cat([item.data for item in items], dim=batch_dim)
        return Variable(combined_data, lengths=combined_lengths,
                        batch_dim=batch_dim, length_dim=length_dim,
                        pad_value=pad_value)
Example #3
0
    def __call__(self, batch):
        """Pad each element of ``batch`` along ``self.pad_dim`` to equal size.

        Elements shorter than the largest one along ``self.pad_dim`` are
        extended with a block filled with ``self.pad_value`` and joined via
        ``plum.cat(..., ignore_length=True)``. The list ``batch`` is updated
        in place and also returned.
        """
        dim = self.pad_dim
        extents = [element.size(dim) for element in batch]
        target = max(extents)

        for position, extent in enumerate(extents):
            missing = target - extent
            if missing != 0:
                original = batch[position]
                pad_shape = list(original.size())
                pad_shape[dim] = missing
                filler = original.new(*pad_shape).fill_(self.pad_value)
                batch[position] = plum.cat([original, filler],
                                           dim=dim,
                                           ignore_length=True)

        return batch
Example #4
0
    def _collect_search_states(self, active_items):
        """Assemble the finished beam-search hypotheses into output tensors.

        Steps:
          1. Each batch entry holding fewer than ``self.beam_size`` finished
             hypotheses is topped up with beams alive at the last step.
          2. Each hypothesis' per-step beam positions are traced back through
             the stored states via ``self._collect_beam``.
          3. Hypotheses are re-ordered by descending score within each batch
             entry.
          4. Positions past each hypothesis' length are filled with
             ``self.vocab.pad_index``.

        Sets ``self._beam_scores`` (sorted scores viewed as
        ``(batch_size, beam_size)``), ``self._output`` (viewed as
        ``(batch_size, beam_size, -1)``) and ``self._lengths``.

        Args:
            active_items: Only ``active_items.size(0)`` is read; it is used
                as the batch size.

        Returns:
            ``self``.
        """
        batch_size = active_items.size(0)

        # Top up each batch entry with hypotheses still alive at the final
        # step until it holds ``beam_size`` finished ones.
        last_state = self._states[-1]
        last_step = self.steps - 1
        for batch in range(batch_size):
            beam = 0
            while len(self._beam_scores[batch]) < self.beam_size:
                flat_beam = batch * self.beam_size + beam
                self._beam_scores[batch].append(
                    last_state["beam_score"][0, flat_beam, 0].view(1))
                self._terminal_info[batch].append((last_step, flat_beam))
                beam += 1

        # TODO consider removing beam indices from state
        beam_indices = torch.stack([state["beam_indices"]
                                    for state in self._states])

        self._beam_scores = torch.stack([torch.cat(bs)
                                         for bs in self._beam_scores])

        # Length of each hypothesis = index of its terminal step + 1.
        lengths = self._states[0]["target"].new(
            [[step + 1 for step, _ in self._terminal_info[batch]]
             for batch in range(batch_size)])

        # Zero-initialise: ``.new(*sizes)`` leaves memory uninitialised, and
        # positions after a hypothesis' terminal step may not be written by
        # ``_collect_beam``. Every column is still passed to ``index_select``
        # below, so garbage there could be an out-of-range index. The tokens
        # gathered at those positions are overwritten with padding later, so
        # 0 is a safe placeholder.
        selector = self._states[0]["target"].new(
            batch_size, self.beam_size, lengths.max()).fill_(0)

        for batch in range(batch_size):
            for beam in range(self.beam_size):
                step, real_beam = self._terminal_info[batch][beam]
                self._collect_beam(batch, real_beam, step,
                                   beam_indices,
                                   selector[batch, beam])
        selector = selector.view(batch_size * self.beam_size, -1)

        # Re-order each batch entry's hypotheses by descending score.
        # TODO make sorting optional again (``self.sort_by_score``).
        self._beam_scores, order = torch.sort(self._beam_scores, dim=1,
                                              descending=True)
        # Turn per-entry beam positions into rows of the flattened
        # (batch_size * beam_size) selector.
        offsets = (
            torch.arange(batch_size, device=order.device) * self.beam_size
        ).view(batch_size, 1)
        selector = selector[(order + offsets).view(-1)]
        lengths = lengths.gather(1, order)

        # TODO reimplement staged indexing
        # Gather the target token chosen at every step for all hypotheses.
        step_tokens = []
        for step, sel_step in enumerate(selector.split(1, dim=1)):
            step_tokens.append(
                self._states[step]["target"].index_select(1, sel_step.view(-1))
            )
        self._output = plum.cat([o.data for o in step_tokens], 0).t()\
            .view(batch_size, self.beam_size, -1)

        # Blank out everything after each hypothesis' end.
        pad_index = int(self.vocab.pad_index)
        for i in range(batch_size):
            for j in range(self.beam_size):
                self._output[i, j, lengths[i, j]:].fill_(pad_index)

        self._lengths = lengths

        return self
Example #5
0
 def __call__(self, item):
     """Concatenate the sequence ``item`` along ``self.dim`` with ``plum.cat``."""
     return plum.cat(item, dim=self.dim)
Example #6
0
 def forward(self, inputs):
     """Return ``inputs`` joined along ``self.dim`` via ``plum.cat``."""
     concat_dim = self.dim
     return plum.cat(inputs, dim=concat_dim)