Example #1
        def _compute_attention_sum(q, m, length):
            # q : batch_size x lstm_size
            # m : batch_size x max(length) x embedded_dim
            assert torch.max(length) == m.size()[1]
            max_len = m.size()[1]
            if simple:
                if q.size()[-1] != m.size()[-1]:
                    q = self.attention(q) # batch_size x embedded_dim
                weight_logit = torch.bmm(m, q.unsqueeze(-1)).squeeze(2) # batch_size x n_features
            else:
                linear_m = self.attention[1]
                linear_q = self.attention[0]
                linear_out = self.attention[2]

                packed = pack(m, list(length), batch_first=True)
                proj_m = PackedSequence(linear_m(packed.data), packed.batch_sizes)
                proj_m, _ = pad(proj_m, batch_first=True)  # batch_size x n_features x proj_dim
                proj_q = linear_q(q).unsqueeze(1) # batch_size x 1 x proj_dim
                packed = pack(F.relu(proj_m + proj_q), list(length), batch_first=True)
                weight_logit = PackedSequence(linear_out(packed.data), packed.batch_sizes)
                weight_logit, _ = pad(weight_logit, batch_first=True) # batch_size x n_features x 1
                weight_logit = weight_logit.squeeze(2)

            # mask out padded positions before the softmax
            indices = torch.arange(max_len, device=weight_logit.device).unsqueeze(0)
            mask = indices < length.unsqueeze(1)
            weight_logit[~mask] = -np.inf
            weight = F.softmax(weight_logit, dim=1) # nonzero x max_len
            weighted = torch.bmm(weight.unsqueeze(1), m)
            # batch_size x 1 x max_len
            # batch_size x     max_len x embedded_dim
            # = batch_size x 1 x embedded_dim
            return weighted.squeeze(1), weight  #nonzero x embedded_dim
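In this snippet `pack`/`pad` appear to be `pack_padded_sequence`/`pad_packed_sequence` from `torch.nn.utils.rnn`, while the collate examples below use `pad` the way `pad_sequence` works. The length-masking trick at the end also works in isolation; a minimal sketch with made-up shapes (not taken from the original code):

import torch
import torch.nn.functional as F

# toy padded batch: 2 sequences with lengths 3 and 1
m = torch.randn(2, 3, 4)                  # batch x max_len x dim
length = torch.tensor([3, 1])
scores = torch.randn(2, 3)                # unnormalized attention logits

# positions beyond each length get -inf so softmax gives them zero weight
mask = torch.arange(3).unsqueeze(0) < length.unsqueeze(1)
scores = scores.masked_fill(~mask, float('-inf'))
weight = F.softmax(scores, dim=1)                       # batch x max_len
pooled = torch.bmm(weight.unsqueeze(1), m).squeeze(1)   # batch x dim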
Example #2
 def collate(self, batch):
     ei = [torch.tensor(sent) for sent in batch]
     eo = [torch.tensor([x if x is not None else -100 for x in sent.target]) for sent in batch]
     el = torch.tensor([len(sent) for sent in batch])
     ei = pad(ei, padding_value=self.pad)
     eo = pad(eo, padding_value=-100)
     return Batch(ei, eo, el)
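A collate function like this is normally handed to a DataLoader. A hypothetical, self-contained version, assuming `pad` here plays the role of `torch.nn.utils.rnn.pad_sequence` (ids and padding values made up):

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

def collate(batch):
    # batch: list of 1-D LongTensors of varying length
    lengths = torch.tensor([len(x) for x in batch])
    padded = pad_sequence(batch, batch_first=True, padding_value=0)
    return padded, lengths

data = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
loader = DataLoader(data, batch_size=2, collate_fn=collate)
for padded, lengths in loader:
    print(padded.shape, lengths)   # torch.Size([2, 3]) tensor([3, 2])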
Example #3
 def __call__(self, indices):
     return {
         'src':
         pad([torch.tensor([0] + self.data[index]) for index in indices]),
         'trg':
         pad([torch.tensor(self.data[index] + [0]) for index in indices],
             padding_value=-100),
         'len':
         self.lengths[indices],
     }
Example #4
 def collate(self, items):
     old = [torch.tensor(i["old"]) for i in items]
     new = [torch.tensor(i["new"]) for i in items]
     len_old = [i["len_old"] for i in items]
     len_new = [i["len_new"] for i in items]
     return {
         "old": pad(old, batch_first=True),
         "new": pad(new, batch_first=True),
         "len_old": torch.tensor(len_old),
         "len_new": torch.tensor(len_new)
     }
Example #5
 def collate(self, batch):
     ei = pad([torch.tensor(sent) for sent in batch],
              padding_value=self.pad)
     eo = pad([torch.tensor(sent) for sent in batch],
              padding_value=self.pad)
     el = [len(sent) for sent in batch]
     rand_tensor = torch.rand(ei.shape)
     rand_token = torch.randint(2, len(self.vocab), ei.shape)
     normal_token = ei > 2
     position_to_mask = (rand_tensor < self.mask_th) & normal_token
     position_to_replace = (rand_tensor < self.replace_th) & normal_token
     ei.masked_fill_(position_to_mask, self.msk)
     ei.masked_scatter_(position_to_replace, rand_token)
     eo.masked_fill_(~position_to_mask, self.pad)
     return Batch(ei, eo, el)
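The thresholds, special token ids, and the `> 2` check above are model-specific. The same random mask/replace idea written standalone, with all constants hypothetical:

import torch

PAD, MSK, VOCAB = 0, 1, 100          # hypothetical special ids / vocab size
mask_p, replace_p = 0.15, 0.05

tokens = torch.randint(3, VOCAB, (2, 6))
targets = tokens.clone()

rand = torch.rand(tokens.shape)
maskable = tokens > 2                          # never touch special tokens
to_mask = (rand < mask_p) & maskable
to_replace = (rand < replace_p) & maskable

tokens[to_mask] = MSK
tokens[to_replace] = torch.randint(3, VOCAB, tokens.shape)[to_replace]
targets[~to_mask] = PAD                        # compute loss only on masked positions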
Example #6
    def forward(self, input, encoder_outputs, hidden=None):
        batch_size = input.size()[0]
        embedded = self.embedding(input)
        embedded = self.embedding_dropout(embedded)

        dec_outs, hidden = self.rnn(
            pack(embedded, [1] * batch_size, batch_first=True), hidden)

        # batch_size * 1 * hidden_size
        dec_outs, lengths = pad(dec_outs, batch_first=True)

        # sum the bidirectional outputs
        if self.bidirectional:
            dec_outs = dec_outs[:, :, :self.hidden_size] + \
                dec_outs[:, :, self.hidden_size:]

        # calculate the attention context vector
        attn_weights = self.attn(dec_outs, encoder_outputs)
        context = attn_weights.bmm(encoder_outputs)

        # Finally concatenate and pass through a linear layer
        concated_input = T.cat((dec_outs, context), 2)
        concated_out = self.concat(concated_input.squeeze(1)).unsqueeze(1)
        concat_output = self.output(T.tanh(concated_out))

        return (concat_output, hidden, attn_weights)
Example #7
    def forward(self, inputs):

        if isinstance(inputs, PackedSequence):
            # unpack output
            inputs, lengths = pad(inputs, batch_first=self.batch_first)
        if self.batch_first:
            batch_size, max_len = inputs.size()[:2]
        else:
            max_len, batch_size = inputs.size()[:2]
            inputs = inputs.permute(1, 0, 2)

        # att = torch.mul(inputs, self.att_weights.expand_as(inputs))
        # att = att.sum(-1)
        weights = torch.bmm(
            inputs,
            self.att_weights  # (1, hidden_size)
            .permute(1, 0)  # (hidden_size, 1)
            .unsqueeze(0)  # (1, hidden_size, 1)
            # (batch_size, hidden_size, 1)
            .repeat(batch_size, 1, 1))

        attentions = F.softmax(F.relu(weights.squeeze(-1)), dim=-1)

        # apply weights
        weighted = torch.mul(inputs,
                             attentions.unsqueeze(-1).expand_as(inputs))

        # get the final fixed vector representations of the sentences
        representations = weighted.sum(1)

        return representations, attentions
Example #8
    def forward(self, subword_idxs, subword_masks, token_start, batch_label):
        #self.args.use_crf = True
        bert_outs, _ = self.bert(
            subword_idxs,
            token_type_ids=None,
            attention_mask=subword_masks,
            output_all_encoded_layers=False,
        )
        lens = token_start.sum(dim=1)
        bert_outs = torch.split(bert_outs[token_start], lens.tolist())
        bert_outs = pad_sequence(bert_outs, batch_first=True)
        max_len = bert_outs.size(1)
        mask = torch.arange(max_len, device=lens.device) < lens.unsqueeze(-1)

        # add lstm after bert
        sorted_lens, sorted_idx = torch.sort(lens, dim=0, descending=True)
        reverse_idx = torch.sort(sorted_idx, dim=0)[1]
        bert_outs = bert_outs[sorted_idx]
        bert_outs = pack(bert_outs, sorted_lens, batch_first=True)
        bert_outs, hidden = self.lstm(bert_outs)
        bert_outs, _ = pad(bert_outs, batch_first=True)
        bert_outs = bert_outs[reverse_idx]

        out = self.linear(torch.tanh(bert_outs))
        if self.args.use_crf:
            score, seq = self.crf.viterbi_decode(out, mask)
        else:
            batch_size = out.size(0)
            seq_len = out.size(1)
            out = out.view(-1, out.size(2))
            _, seq = torch.max(out, 1)
            seq = seq.view(batch_size, seq_len)
            seq = mask.long() * seq
        return seq
Example #9
    def forward(self, input_idxs, input_masks, syntax_ids=None):
        bert_outs, _ = self.bert(
            input_idxs,
            token_type_ids=None,
            attention_mask=input_masks,
            output_all_encoded_layers=False,
        )
        lens = torch.sum(input_idxs.gt(0), dim=1)
        # bert_outs = torch.split(bert_outs[token_start], lens.tolist())
        bert_outs = pad_sequence(bert_outs, batch_first=True)
        lstm_input = bert_outs
        if self.use_syntax:
            syntax_vec = self.syntax_embed(syntax_ids)
            lstm_input = torch.cat((lstm_input, syntax_vec),-1)

        max_len = lstm_input.size(1)
        lstm_input = lstm_input[:, :max_len, :]
        # mask = torch.arange(max_len).cuda() < lens.unsqueeze(-1)
        # add lstm after bert
        sorted_lens, sorted_idx = torch.sort(lens, dim=0, descending=True)
        reverse_idx = torch.sort(sorted_idx, dim=0)[1]
        lstm_input = lstm_input[sorted_idx]
        lstm_input = pack(lstm_input, sorted_lens, batch_first=True)
        lstm_output, (h, _) = self.lstm(lstm_input)  # [batch, seq_len, hidden]
        output, _ = pad(lstm_output, batch_first=True)
        output = output.permute(0, 2, 1)  # [batch, hidden, seq_len]

        output = F.max_pool1d(output, output.size(2))  # [batch, hidden, 1]
        output = output.squeeze(2)  # [batch, hidden]
        output = output[reverse_idx]
        out = self.linear(torch.tanh(output))
        return out
Example #10
    def forward(self, hidden, encoder_outputs):
        lengths = None
        if type(encoder_outputs) is PackedSequence:
            encoder_outputs, lengths = pad(encoder_outputs, batch_first=True)
        else:
            lengths = [len(x) for x in encoder_outputs]

        batch_size = encoder_outputs.size()[0]
        max_len = max(lengths)
        attns = cuda(T.zeros(batch_size, max_len), gpu_id=self.gpu_id)
        # zero tensor used as the additive term in baddbmm below
        zeros = cuda(T.zeros(max_len, 1), gpu_id=self.gpu_id)

        if self.method == 'dot':
            attns = T.baddbmm(zeros, encoder_outputs,
                              hidden.transpose(2, 1)).squeeze(2)

        elif self.method == 'general':
            attended = self.attn(encoder_outputs)
            attns = T.baddbmm(zeros, attended,
                              hidden.transpose(2, 1)).squeeze(2)

        elif self.method == 'concat':
            concated = T.cat(
                (hidden.expand_as(encoder_outputs), encoder_outputs), 2)
            energy = self.attn(concated)
            expanded = self.other.unsqueeze(0).expand(batch_size, 1,
                                                      self.hidden_size)
            attns = T.baddbmm(zeros, energy,
                              expanded.transpose(2, 1)).squeeze(2)

        return F.softmax(attns, dim=1).unsqueeze(1)
Example #11
 def forward(self, input_idxs, bert_lens, bert_mask, syntax_embed=None):
     bert_outs = self.bert_embedding(input_idxs, bert_lens, bert_mask)
     lens = torch.sum(bert_lens.gt(0), dim=1)
     # bert_outs = torch.split(bert_outs[token_start], lens.tolist())
     # bert_outs = pad_sequence(bert_outs, batch_first=True)
     lstm_input = bert_outs
     if self.use_syntax:
         syntax_vec = syntax_embed
         lstm_input = torch.cat((lstm_input, syntax_vec), -1)
     # max_len = lstm_input.size(1)
     # lstm_input = lstm_input[:, :max_len, :]
     # mask = torch.arange(max_len).cuda() < lens.unsqueeze(-1)
     # add lstm after bert
     sorted_lens, sorted_idx = torch.sort(lens, dim=0, descending=True)
     reverse_idx = torch.sort(sorted_idx, dim=0)[1]
     lstm_input = lstm_input[sorted_idx]
     lstm_input = pack(lstm_input, sorted_lens, batch_first=True, enforce_sorted=False)
     
     lstm_output, (h, _) = self.lstm(lstm_input)  # [batch, seq_len, hidden]

     output, _ = pad(lstm_output, batch_first=True)
     output = output.permute(0, 2, 1)  # [batch, hidden, seq_len]
     if self.args.maxpooling:
         output = F.max_pool1d(output, output.size()[2])  # [batch, hidden, 1]
     elif self.args.avepooling:
         output = F.avg_pool1d(output, output.size()[2])  # [batch, hidden, 1]
     output = output.squeeze(2)  # [batch, hidden]
     output = output[reverse_idx]
     out = self.linear(torch.tanh(output))
     out = F.softmax(out,dim=1)
     return out
Example #12
    def forward(self, input, hx=(None, None, None)):
        # handle packed data
        is_packed = type(input) is PackedSequence
        if is_packed:
            input, lengths = pad(input)
            max_length = lengths[0]
        else:
            max_length = input.size(1) if self.batch_first else input.size(0)
            lengths = [input.size(1)] * max_length if self.batch_first else [
                input.size(0)
            ] * max_length

        batch_size = input.size(0) if self.batch_first else input.size(1)

        # make the data batch-first
        if not self.batch_first:
            input = input.transpose(0, 1)

        controller_hidden, mem_hidden, last_read = self._init_hidden(
            hx, batch_size)

        # batched forward pass per element / word / etc
        outputs = None
        chxs = []
        read_vectors = [last_read] * max_length
        outs = [
            T.cat([input[:, x, :], last_read], 1) for x in range(max_length)
        ]

        for layer in range(self.num_layers):
            # this layer's hidden states
            chx = [x[layer] for x in controller_hidden
                   ] if self.mode == 'LSTM' else controller_hidden[layer]
            # pass through controller
            outs, read_vectors, (chx, mem_hidden[layer]) = self._layer_forward(
                outs, layer, (chx, mem_hidden[layer]))
            chxs.append(chx)

            if layer == self.num_layers - 1:
                # final outputs
                outputs = T.stack(outs, 1)
            else:
                # the controller output + read vectors go into next layer
                outs = [T.cat([o, r], 1) for o, r in zip(outs, read_vectors)]

        # final hidden values
        if self.mode == 'LSTM':
            h = T.stack([x[0] for x in chxs], 0)
            c = T.stack([x[1] for x in chxs], 0)
            controller_hidden = (h, c)
        else:
            controller_hidden = T.stack(chxs, 0)

        if not self.batch_first:
            outputs = outputs.transpose(0, 1)
        if is_packed:
            outputs = pack(outputs, lengths)

        return outputs, (controller_hidden, mem_hidden, read_vectors[-1])
Example #13
 def collate(self, xs):
     max_seq_len = max([x['lengths'] for x in xs])
     mask = [[1] * int(x['lengths']) + [0] * int(max_seq_len - x['lengths'])
             for x in xs]
     mask = torch.tensor(mask, dtype=torch.long)
     return {
         'src': pad([x['src'] for x in xs]),
         'trg': torch.stack([x['trg'] for x in xs], dim=-1),
         'mask': mask,
         'lengths': torch.stack([x['lengths'] for x in xs], dim=-1)
     }
Example #14
    def forward(self, inputs, lengths):
        lengths = lengths.contiguous().data.view(-1).tolist()
        embs = self.dropout(self.emb(inputs))
        packed_embs = pack(embs, lengths, batch_first=True)
        output, _ = self.rnn(packed_embs)
        output = pad(output, batch_first=True)[0]
        output = self.dropout(output)
        output_flat = self.output(output.contiguous().view(
            output.size(0) * output.size(1), output.size(2)))

        return output_flat
Example #15
    def forward(self, source, source_lengths, hidden=None):
        embedded = self.embedding(source)
        packed = pack(embedded, source_lengths, batch_first=True)
        outputs, hidden = self.rnn(packed, hidden)
        outputs, _ = pad(outputs, batch_first=True)

        # sum the bidirectional outputs
        if self.bidirectional:
            outputs = outputs[:, :, :self.hidden_size] + \
                outputs[:, :, self.hidden_size:]

        return outputs, hidden
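Example #15 is close to the canonical pack -> RNN -> pad round trip. A standalone sketch of that pattern with a throwaway GRU (sizes are arbitrary; `enforce_sorted=False` avoids having to sort the batch by length first):

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence as pack, pad_packed_sequence as pad

rnn = nn.GRU(input_size=8, hidden_size=16, batch_first=True)
x = torch.randn(2, 5, 8)                 # already padded: batch x max_len x dim
lengths = torch.tensor([5, 3])           # CPU tensor, as pack expects

packed = pack(x, lengths, batch_first=True, enforce_sorted=False)
out, h = rnn(packed)                     # out is a PackedSequence
out, out_lengths = pad(out, batch_first=True)   # back to batch x max_len x hidden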
Example #16
    def forward(self, state, length):
        # length should be sorted
        assert len(state.size()) == 3 # batch x n_features x input_dim
                                      # input_dim == n_features + 1
        batch_size = state.size()[0]
        self.weight = np.zeros((int(batch_size), self.n_features))
        nonzero = torch.sum(length > 0).cpu().numpy() # encode only nonzero points
        if nonzero == 0:
            return state.new(int(batch_size), self.lstm_size + self.embedded_dim).fill_(0.)

        length_ = list(length[:nonzero].cpu().numpy())
        packed = pack(state[:nonzero], length_, batch_first=True)

        embedded = self.embedder(packed.data)


        if self.normalize:
            embedded = F.normalize(embedded, dim=1)
        embedded = PackedSequence(embedded, packed.batch_sizes)
        embedded, _ = pad(embedded, batch_first=True) # nonzero x max(length) x embedded_dim

        # define initial state
        qt = embedded.new(embedded.size()[0], self.lstm_size).fill_(0.)
        ct = embedded.new(embedded.size()[0], self.lstm_size).fill_(0.)

        ###########################
        # shuffling (set encoding)
        ###########################

        for i in range(self.n_shuffle):
            attended, weight = self.attending(qt, embedded, length[:nonzero])
            # attended : nonzero x embedded_dim
            qt, ct = self.lstm(attended, (qt, ct))

        # TODO edit here!
        weight = weight.detach().cpu().numpy()
        tmp = state[:, :, 1:]
        val, acq = torch.max(tmp, 2) # batch x n_features
        tmp = (val.long() * acq).cpu().numpy()
        #tmp = tmp.cpu().numpy()
        tmp = tmp[:weight.shape[0], :weight.shape[1]]
        self.weight[np.arange(nonzero).reshape(-1, 1), tmp] = weight

        encoded = torch.cat((attended, qt), dim=1)
        if batch_size > nonzero:
            encoded = torch.cat(
                (encoded,
                 encoded.new(int(batch_size - nonzero),
                     encoded.size()[1]).fill_(0.)),
                dim=0
            )
        return encoded
Example #17
 def featurize(self, x):
     wrd_ix = x[0]
     lengths = x[1]
     x = self.word_vectors(wrd_ix)
     x = pack(x, lengths.tolist(), batch_first=True)
     lstm_out, (hidden_state, cell_state) = self.lstm(x)
     if self.att:
         x, self.attentions = self.att_layer(lstm_out)
     else:
         output, _ = pad(lstm_out, batch_first=True)
         # get the last time step for each sequence
         idx = (lengths - 1).view(-1, 1).expand(output.size(0),
                                                output.size(2)).unsqueeze(1)
         x = output.gather(1, idx).squeeze(1)
     return x
Example #18
    def forward(self, inputs, lengths):
        lengths = lengths.contiguous().data.view(-1).tolist()
        
        word_vecs = self.dropout(self.word_lut(inputs))
        packed_word_vecs = pack(word_vecs, lengths)
        rnn_output, hidden = self.lstm(packed_word_vecs)
        rnn_output = pad(rnn_output)[0]
        rnn_output = self.dropout(rnn_output)

        output_flat = self.linear_output(
                rnn_output.view(
                    rnn_output.size(0) * rnn_output.size(1), 
                    rnn_output.size(2)
                    )
                )
        
        return output_flat 
Example #19
    def neg_log_likehood(self, subword_idxs, subword_masks, token_start,
                         batch_label):
        #self.args.use_crf = False
        bert_outs, _ = self.bert(
            subword_idxs,
            token_type_ids=None,
            attention_mask=subword_masks,
            output_all_encoded_layers=False,
        )
        lens = token_start.sum(dim=1)

        #x = bert_outs[token_start]
        bert_outs = torch.split(bert_outs[token_start], lens.tolist())
        bert_outs = pad_sequence(bert_outs, batch_first=True)
        max_len = bert_outs.size(1)
        mask = torch.arange(max_len, device=lens.device) < lens.unsqueeze(-1)

        # add lstm after bert
        sorted_lens, sorted_idx = torch.sort(lens, dim=0, descending=True)
        reverse_idx = torch.sort(sorted_idx, dim=0)[1]
        bert_outs = bert_outs[sorted_idx]
        bert_outs = pack(bert_outs, sorted_lens, batch_first=True)
        bert_outs, hidden = self.lstm(bert_outs)
        bert_outs, _ = pad(bert_outs, batch_first=True)
        bert_outs = bert_outs[reverse_idx]

        out = self.linear(bert_outs)
        #out = self.dropout(out)
        if self.args.use_crf:
            loss = self.crf(out, mask, batch_label)
            score, seq = self.crf.viterbi_decode(out, mask)
        else:
            batch_size = out.size(0)
            seq_len = out.size(1)
            out = out.view(-1, out.size(2))
            score = torch.nn.functional.log_softmax(out, 1)
            loss_function = nn.NLLLoss(ignore_index=0, reduction="sum")
            loss = loss_function(score, batch_label.view(-1))
            _, seq = torch.max(score, 1)
            seq = seq.view(batch_size, seq_len)
        if self.args.average_loss:
            loss = loss / mask.float().sum()
        return loss, seq
Example #20
    def forward(self, input, source_lengths, hidden=(None, None, None)):
        batch_size = len(source_lengths)
        if not hidden:
            hidden = (None, None, None)
        (encoder_hidden, interface_hidden, mem_hidden) = hidden

        # encode
        encoded, encoder_hidden = self.encoder(input, source_lengths,
                                               encoder_hidden)

        # reset working memory
        mem_hidden = self.memory.reset(batch_size, mem_hidden)

        # nothing read in first time step (b*w*r)
        read_vectors = cuda(T.zeros(batch_size, self.w, self.r),
                            gpu_id=self.gpu_id)
        dnc_encoded = cuda(T.zeros(encoded.size()), gpu_id=self.gpu_id)

        # unroll the rnn for each time step
        for x in range(max(source_lengths)):
            # concat the input and stuff read from memory in last time step
            b = encoded[:, x, :].unsqueeze(2)  # b * w * 1
            input = T.cat((b, read_vectors),
                          2).view(batch_size, -1,
                                  self.input_size)  # b * 1 * ((r+1)*w)

            # pass it through an RNN
            input = pack(input, [1] * batch_size, batch_first=True)

            out, interface_hidden = self.rnn(input, interface_hidden)
            out, _ = pad(out, batch_first=True)
            ξ = out.squeeze(1)

            # forward pass through memory
            read_vectors, mem_hidden = self.memory(ξ, mem_hidden)

            # final output, todo: differs from deepmind's implementation
            # where they concat and then pass through a Linear
            read_vecs = read_vectors.view(-1, self.w * self.r)
            mem_encoded = T.cat([ξ, read_vecs], 1)
            dnc_encoded[:, x, :] = self.mem_out(mem_encoded)

        return dnc_encoded, (encoder_hidden, interface_hidden, mem_hidden)
Example #21
    def run_rnn(self, embedded_input, batch, rnn):
        """
        Run embeddings through RNN and return the output.

        Args:
            embedded_input (torch.FloatTensor): batch x seq x dim
            batch (Batch): batch object containing .lengths tensor
            rnn (torch.nn.LSTM): LSTM to run the embeddings through

        Returns:
            torch.FloatTensor: hidden states output of LSTM, batch x seq x dim
        """
        (sorted_input, sorted_lengths, input_unsort_indices, _) = \
            sort_batch_by_length(embedded_input, batch.lengths)
        packed_input = pack(sorted_input, sorted_lengths.data.tolist(),
                            batch_first=True)
        rnn.flatten_parameters()
        packed_sorted_output, _ = rnn(packed_input)
        sorted_output, _ = pad(packed_sorted_output, batch_first=True)
        return sorted_output[input_unsort_indices]
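`sort_batch_by_length` here is a project helper (AllenNLP has a utility with this name); the sort/unsort bookkeeping it hides can also be written directly. A sketch, with the return convention assumed rather than copied:

import torch

def sort_by_length(inputs, lengths):
    # inputs: batch x seq x dim, lengths: batch
    sorted_lengths, sorted_idx = lengths.sort(descending=True)
    unsort_idx = sorted_idx.argsort()        # inverse permutation
    return inputs[sorted_idx], sorted_lengths, unsort_idx

x = torch.randn(3, 4, 5)
lens = torch.tensor([2, 4, 3])
sorted_x, sorted_lens, unsort = sort_by_length(x, lens)
assert torch.equal(sorted_x[unsort], x)      # original batch order restored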
Example #22
    def forward(self, source, source_lengths, hidden=None):
        '''
        source: nr_batches * max_len
        source_lengths: nr_batches
        hidden: nr_layers * nr_batches * nr_hidden
        '''
        batch_size = source.size()[0]
        embedded = self.embedding(source)
        outputs = T.zeros(batch_size, max(source_lengths), self.nr_hidden)

        for c in range(self.columns):
            e = embedded[:, :, self.λ * c:self.λ * (c + 1)]
            packed = pack(e, source_lengths, batch_first=True)
            h = None if not hidden else hidden[:, :,
                                               self.λ * c:self.λ * (c + 1)]
            o, h = self.gru(packed, h)
            o, _ = pad(o, batch_first=True)
            hidden[:, :, self.λ * c:self.λ * (c + 1)] = h
            outputs[:, :, self.λ * c:self.λ * (c + 1)] = o

        return outputs, hidden
Example #23
    def forward(self, input):
        B, T = input.size()

        lens = input.ne(self.vocabulary.pad_id).sum(1).long()

        x = self.embedding(input)  # B x T x D
        x = F.dropout(x, self.dropout, self.training)
        x = pack(x, lens, batch_first=True)

        x, _ = self.rnn(x)

        x, _ = pad(x, batch_first=True, total_length=T)
        B_, T_, H = x.size()
        assert B_ == B
        assert T_ == T

        x = F.dropout(x, self.dropout, self.training)
        x = self.ffn(x)
        x = F.dropout(x, self.dropout, self.training)

        x = x.sum(1) / lens.unsqueeze(1).float()  # B x H

        logits = self.logits(x)  # B x C, C: num_class
        return logits
Example #24
 def collate(self, xs):
     return {
         'src': pad([x['src'] for x in xs]),
         'trg': torch.stack([x['trg'] for x in xs], dim=-1),
         'lengths': torch.stack([x['lengths'] for x in xs], dim=-1)
     }
Example #25
  def forward(self, input, hx=(None, None, None), reset_experience=False, pass_through_memory=True):
    # handle packed data
    is_packed = type(input) is PackedSequence
    if is_packed:
      input, lengths = pad(input)
      max_length = lengths[0]
    else:
      max_length = input.size(1) if self.batch_first else input.size(0)
      lengths = [input.size(1)] * max_length if self.batch_first else [input.size(0)] * max_length

    batch_size = input.size(0) if self.batch_first else input.size(1)

    if not self.batch_first:
      input = input.transpose(0, 1)
    # make the data time-first

    controller_hidden, mem_hidden, last_read = self._init_hidden(hx, batch_size, reset_experience)

    # concat input with last read (or padding) vectors
    inputs = [T.cat([input[:, x, :], last_read], 1) for x in range(max_length)]

    # batched forward pass per element / word / etc
    if self.debug:
      viz = None

    outs = [None] * max_length
    read_vectors = None

    # pass through time
    for time in range(max_length):
      # pass through layers
      for layer in range(self.num_layers):
        # this layer's hidden states
        chx = controller_hidden[layer]
        m = mem_hidden if self.share_memory else mem_hidden[layer]
        # pass through controller
        outs[time], (chx, m, read_vectors) = \
          self._layer_forward(inputs[time], layer, (chx, m), pass_through_memory)

        # debug memory
        if self.debug:
          viz = self._debug(m, viz)

        # store the memory back (per layer or shared)
        if self.share_memory:
          mem_hidden = m
        else:
          mem_hidden[layer] = m
        controller_hidden[layer] = chx

        if read_vectors is not None:
          # the controller output + read vectors go into next layer
          outs[time] = T.cat([outs[time], read_vectors], 1)
        else:
          outs[time] = T.cat([outs[time], last_read], 1)
        inputs[time] = outs[time]

    if self.debug:
      viz = {k: np.array(v) for k, v in viz.items()}
      viz = {k: v.reshape(v.shape[0], v.shape[1] * v.shape[2]) for k, v in viz.items()}

    # pass through final output layer
    inputs = [self.output(i) for i in inputs]
    outputs = T.stack(inputs, 1 if self.batch_first else 0)

    if is_packed:
      outputs = pack(outputs, lengths)

    if self.debug:
      return outputs, (controller_hidden, mem_hidden, read_vectors), viz
    else:
      return outputs, (controller_hidden, mem_hidden, read_vectors)
Example #26
    def forward(self,
                input,
                hx=(None, None, None),
                reset_experience=False,
                pass_through_memory=True):
        # handle packed data
        is_packed = type(input) is PackedSequence
        if is_packed:
            input, lengths = pad(input)
            max_length = lengths[0]
        else:
            max_length = input.size(1) if self.batch_first else input.size(0)
            lengths = [input.size(1)] * max_length if self.batch_first else [
                input.size(0)
            ] * max_length

        batch_size = input.size(0) if self.batch_first else input.size(1)

        if not self.batch_first:
            input = input.transpose(0, 1)
        # make the data time-first
        controller_hidden, mem_hidden, last_read = self._init_hidden(
            hx, batch_size, reset_experience)

        # concat input with last read (or padding) vectors
        inputs = [
            T.cat([input[:, x, :], last_read], 1) for x in range(max_length)
        ]

        # batched forward pass per element / word / etc
        if self.debug:
            viz = None

        outs = [None] * max_length
        read_vectors = None

        # pass through time
        for time in range(max_length):
            # pass through layers
            for layer in range(self.num_layers):
                # this layer's hidden states
                chx = controller_hidden[layer]
                m = mem_hidden if self.share_memory else mem_hidden[layer]
                # pass through controller
                outs[time], (chx, m, read_vectors) = \
                    self._layer_forward(
                        inputs[time], layer, (chx, m), pass_through_memory)

                # debug memory
                if self.debug:
                    viz = self._debug(m, viz)

                # store the memory back (per layer or shared)
                if self.share_memory:
                    mem_hidden = m
                else:
                    mem_hidden[layer] = m
                controller_hidden[layer] = chx

                if read_vectors is not None:
                    # the controller output + read vectors go into next layer
                    outs[time] = T.cat([outs[time], read_vectors], 1)
                else:
                    outs[time] = T.cat([outs[time], last_read], 1)
                inputs[time] = outs[time]

        if self.debug:
            viz = {k: np.array(v) for k, v in viz.items()}
            reshape_keys = [
                "memory", "link_matrix", "precedence", "read_weights",
                "write_weights", "usage_vector"
            ]
            for key in reshape_keys:
                viz[key] = viz[key].reshape(
                    viz[key].shape[0], viz[key].shape[1] * viz[key].shape[2])
            # mean_keys = ["free_gates", "allocation_gate", "write_gate", "read_modes"]
            # for key in mean_keys:
            #     viz[key] = np.mean(viz[key], axis=0)

            # viz = {k: v.reshape(v.shape[0],  v.shape[1] * v.shape[2])
            #        for k, v in viz.items()}

        # pass through final output layer
        inputs_final = [self.output(i) for i in inputs]
        outputs = T.stack(inputs_final, 1 if self.batch_first else 0)

        if self.debug:
            self._update_controller_memory_contributions()
            c_contrib = [
                self.controller_contribution(i[:, :self.output_size])
                for i in inputs
            ]
            m_contrib = [
                self.memory_contribution(i[:, self.output_size:])
                for i in inputs
            ]

            c_contrib = T.stack(c_contrib, 1 if self.batch_first else 0)
            m_contrib = T.stack(m_contrib, 1 if self.batch_first else 0)
            outputs_ = outputs.clone().detach()
            # average over the batch dim
            c_contrib = c_contrib.mean(0)
            m_contrib = m_contrib.mean(0)
            outputs_ = outputs_.mean(0)

            c_contrib = nn.Softmax(1)(c_contrib)
            m_contrib = nn.Softmax(1)(m_contrib)
            outputs_ = nn.Softmax(1)(outputs_)

            m_influence = torch.abs(outputs_ - c_contrib)
            m_influence = torch.mean(m_influence, 1)
            m_influence_scalar = torch.mean(m_influence, 0)

            c_influence = torch.abs(outputs_ - m_contrib)
            c_influence = torch.mean(c_influence, 1)
            c_influence_scalar = torch.mean(c_influence, 0)

            m_inf_vec = m_influence.div(m_influence + c_influence)
            c_inf_vec = c_influence.div(m_influence + c_influence)

            m_inf = m_influence_scalar / (m_influence_scalar +
                                          c_influence_scalar)
            c_inf = c_influence_scalar / (m_influence_scalar +
                                          c_influence_scalar)

            viz["memory_influence"] = m_inf
            viz["controller_influence"] = c_inf
            viz["memory_influence_vec"] = m_inf_vec.detach().cpu().numpy()
            viz["controller_influence_vec"] = c_inf_vec.detach().cpu().numpy()
        if is_packed:
            outputs = pack(outputs, lengths)

        if self.debug:
            return outputs, (controller_hidden, mem_hidden, read_vectors), viz
        else:
            return outputs, (controller_hidden, mem_hidden, read_vectors)
Example #27
    def forward(self, inputs, lengths=None):
        """
        Parameters
        ----------
        inputs: LongTensor
            The input data. Shape is (seq_len, batch_size) if
            batch_first=False, or (batch_size, seq_len) if
            batch_first=True.

        lengths: LongTensor or List[int], optional
            List of integers with the sequence length for each
            element in the batch.

        Returns
        -------
        output_distribution: FloatTensor
            FloatTensor of shape (batch_size, vocab_size),
            which is a distribution over the vocabulary for each batch.

        """
        # Shape: (batch_size, seq_len, embedding_size) if batch_first=True
        # Shape: (seq_len, batch_size, embedding_size) if batch_first=False
        embedded_seq = self.embedding(inputs)
        embedded_seq = self.dropout(embedded_seq)

        if lengths is not None:
            # Pack the sequence
            embedded_seq = pack(embedded_seq,
                                lengths,
                                batch_first=self.batch_first)

        # encoded_seq shape if batch_first=True: (batch_size, seq_len,
        #                                         hidden_size)
        # encoded_seq shape if batch_first=False: (seq_len, batch_size,
        #                                          hidden_size)
        self.rnn.flatten_parameters()
        encoded_seq, _ = self.rnn(embedded_seq)

        if lengths is not None:
            # Pad the packed sequence
            encoded_seq, _ = pad(encoded_seq, batch_first=self.batch_first)

        # Apply dropout.
        encoded_seq = self.dropout(encoded_seq)

        # Get the output after encoding the entire sequence
        if lengths is not None:
            # Shape: (batch_size, hidden_size)
            idx = (torch.tensor(
                lengths, dtype=torch.long, device=encoded_seq.device) -
                   1).view(-1, 1).expand(len(lengths), encoded_seq.size(2))
            time_dimension = 1 if self.batch_first else 0
            idx = idx.unsqueeze(time_dimension)
            last_output = encoded_seq.gather(time_dimension,
                                             idx).squeeze(time_dimension)
        else:
            last_output = encoded_seq[:, -1]

        # Run this reshaped RNN output through the decoder to get
        # output of shape (batch_size, vocab_size)
        output_distribution = self.decoder(last_output)

        # Return decoded, a distribution over the output vocabulary for
        # each sequence in the batch.
        # Shape: (batch_size, vocab_size)
        return output_distribution
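The gather at the end of Example #27 selects the last valid hidden state per sequence. The same indexing in isolation (hypothetical sizes, batch_first assumed):

import torch

encoded = torch.randn(2, 5, 7)            # batch x seq x hidden
lengths = torch.tensor([5, 3])

idx = (lengths - 1).view(-1, 1).expand(-1, encoded.size(2)).unsqueeze(1)  # batch x 1 x hidden
last = encoded.gather(1, idx).squeeze(1)   # batch x hidden
assert torch.equal(last[1], encoded[1, 2])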
Example #28
    def forward(self,
                x,
                hx=(None, None),
                reset_experience=False,
                pass_through_memory=True):
        # handle packed data
        is_packed = type(x) is PackedSequence
        if is_packed:
            x, lengths = pad(x)
            max_length = lengths[0]
        else:
            max_length = x.size(1) if self.batch_first else x.size(0)
            lengths = [x.size(1)] * max_length if self.batch_first else [
                x.size(0)
            ] * max_length

        batch_size = x.size(0) if self.batch_first else x.size(1)

        controller_hidden, mem_hidden = self._init_hidden(
            hx, batch_size, reset_experience)

        # batched forward pass per element / word / etc
        if self.debug:
            viz = None

        outs = [None] * max_length
        read_vectors = None

        # pass through layers
        for layer in range(self.num_layers):
            # this layer's hidden states
            chx = controller_hidden[layer]
            m = mem_hidden if self.share_memory else mem_hidden[layer]
            # pass through controller
            x, (chx, m,
                read_vectors) = self._layer_forward(x, layer, (chx, m),
                                                    pass_through_memory)

            # debug memory
            if self.debug:
                viz = self._debug(m, viz)

            # store the memory back (per layer or shared)
            if self.share_memory:
                mem_hidden = m
            else:
                mem_hidden[layer] = m
            controller_hidden[layer] = chx

            if read_vectors is not None:
                # the controller output + read vectors go into next layer
                x = T.cat([x[-1, :, :], read_vectors], 1)
            read_vectors = None

        if self.debug:
            viz = {k: np.array(v) for k, v in viz.items()}
            viz = {
                k: v.reshape(v.shape[0], v.shape[1] * v.shape[2])
                for k, v in viz.items()
            }

        # pass through final output layer
        for layer in range(len(self.output)):
            x = self.output[layer](x)

        if is_packed:
            x = pack(x, lengths)

        if self.debug:
            return x, (controller_hidden, mem_hidden), viz
        else:
            return x, (controller_hidden, mem_hidden)
Example #29
    def forward(self, input, hx=None):
        timesteps = input.size(1) if self.batch_first else input.size(0)
        directions = 2 if self.bidirectional else 1
        is_packed = isinstance(input, PackedSequence)

        if is_packed:
            input, batch_sizes = pad(input)
            max_batch_size = batch_sizes[0]
        else:
            batch_sizes = None
            max_batch_size = input.size(0) if self.batch_first else input.size(
                1)

        # layer * direction
        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            hx = var(input.data.new(max_batch_size, self.hidden_size).zero_(),
                     requires_grad=False)
            hx = (hx, hx)
            hx = [[hx for x in range(directions)]
                  for d in range(self.num_layers)]

        # make weights indexable with layer -> direction
        ws = self.all_weights
        if directions == 1:
            ws = [[w] for w in ws]
        else:
            ws = [[ws[l * 2], ws[l * 2 + 1]] for l in range(self.num_layers)]

        # make input batch-first, separate into timeslice wise chunks
        input = input if self.batch_first else input.transpose(0, 1)
        os = [[input[:, i, :] for i in range(timesteps)]
              for d in range(directions)]
        if directions > 1:
            os[1].reverse()

        for time in range(timesteps):
            for layer in range(self.num_layers):
                for direction in range(directions):

                    if self.bias:
                        (w_ih, w_hh, b_ih, b_hh) = ws[layer][direction]
                    else:
                        (w_ih, w_hh) = ws[layer][direction]
                        b_ih = None
                        b_hh = None

                    hy, cy = SubLSTMCellF(os[direction][time],
                                          hx[layer][direction], w_ih, w_hh,
                                          b_ih, b_hh)
                    hx[layer][direction] = (hy, cy)
                    os[direction][time] = hy

                if directions > 1:
                    os[0][time] = T.cat(
                        [os[d][time] for d in range(directions)], -1)
                    os[1][time] = os[0][time]

        output = T.stack([T.stack(o, 1) for o in os])
        output = T.cat(output, -1) if self.bidirectional else output[0]
        output = output if self.batch_first else output.transpose(0, 1)

        if is_packed:
            output = pack(output, batch_sizes)
        return output, hx
Example #30
    def forward(self, word_inputs, word_seq_length, char_inputs,
                char_seq_length, char_recover):
        """

             word_inputs: (batch_size,seq_len)
             word_seq_length:()
        """
        batch_size = word_inputs.size(0)
        seq_len = word_inputs.size(1)
        word_emb = self.word_embedding(word_inputs)
        # word_rep = self.drop(word_emb)
        word_rep = word_emb
        if self.args.use_char:
            size = char_inputs.size(0)
            char_emb = self.char_embedding(char_inputs)
            char_emb = pack(char_emb,
                            char_seq_length.numpy(),
                            batch_first=True)
            char_lstm_out, char_hidden = self.char_feature(char_emb)
            char_lstm_out, _ = pad(char_lstm_out, batch_first=True)
            char_hidden = char_hidden[0].transpose(1, 0).contiguous().view(
                size, -1)
            char_hidden = char_hidden[char_recover]
            char_hidden = char_hidden.view(batch_size, seq_len, -1)
            if self.args.attention:
                word_rep = torch.tanh(
                    self.attn1(word_emb) + self.attn2(char_hidden))
                z = torch.sigmoid(self.attn3(word_rep))
                x = 1 - z
                word_rep = z * word_emb + x * char_hidden
            else:
                word_rep = torch.cat((word_emb, char_hidden), 2)

        word_rep = pack(word_rep,
                        word_seq_length.cpu().numpy(),
                        batch_first=True)
        out, hidden = self.word_feature(word_rep)
        out, _ = pad(out, batch_first=True)
        if self.args.lstm_attention:
            # tanh_out = F.tanh(out)
            # att_vec = self.att1(tanh_out)
            # att_sm_vec = F.softmax(att_vec)
            # out = F.mul(out,att_sm_vec)

            out_list, weight_list = [], []
            for idx in range(seq_len):
                # slice_out = out[:,0:idx+1,:]
                if idx + 2 > seq_len:
                    slice_out = out
                else:
                    slice_out = out[:, 0:idx + 2, :]
                # slice_out = out
                slice_out, weights = self.attention(slice_out)
                # slice_out, weights = SelfAttention(self.args.hidden_dim*2).forward(slice_out)
                out_list.append(slice_out.unsqueeze(1))
                weight_list.append(weights)
            out = torch.cat(out_list, dim=1)

            # pass
            # out = F.tanh(self.att1(out))
            # out = self.softmax(out)
            # out = out*out
            # out = self.att2(out)
        # else:
        out = self.hidden2tag(self.drop(out))
        return out
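The gated word/character mix in Example #30 is a small highway-style combination; a self-contained sketch with made-up dimensions and freshly initialized linear layers:

import torch
import torch.nn as nn

dim = 8
attn1, attn2, attn3 = nn.Linear(dim, dim), nn.Linear(dim, dim), nn.Linear(dim, dim)

word_emb = torch.randn(2, 5, dim)        # fake word embeddings
char_hidden = torch.randn(2, 5, dim)     # fake char-LSTM summaries

mixed = torch.tanh(attn1(word_emb) + attn2(char_hidden))
z = torch.sigmoid(attn3(mixed))              # gate in (0, 1)
word_rep = z * word_emb + (1 - z) * char_hidden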