Example #1
        def _compute_attention_sum(q, m, length):
            # q : batch_size x lstm_size
            # m : batch_size x max(length) x embedded_dim
            # `simple` and `self.attention` are captured from the enclosing scope;
            # `pack`/`pad` alias pack_padded_sequence / pad_packed_sequence.
            assert torch.max(length) == m.size()[1]
            max_len = m.size()[1]
            if simple:
                if q.size()[-1] != m.size()[-1]:
                    q = self.attention(q) # batch_size x embedded_dim
                weight_logit = torch.bmm(m, q.unsqueeze(-1)).squeeze(2) # batch_size x n_features
            else:
                linear_m = self.attention[1]
                linear_q = self.attention[0]
                linear_out = self.attention[2]

                packed = pack(m, list(length), batch_first=True)
                proj_m = PackedSequence(linear_m(packed.data), packed.batch_sizes)
                proj_m, _ = pad(proj_m, batch_first=True)  # batch_size x n_features x proj_dim
                proj_q = linear_q(q).unsqueeze(1) # batch_size x 1 x proj_dim
                packed = pack(F.relu(proj_m + proj_q), list(length), batch_first=True)
                weight_logit = PackedSequence(linear_out(packed.data), packed.batch_sizes)
                weight_logit, _ = pad(weight_logit, batch_first=True) # batch_size x n_features x 1
                weight_logit = weight_logit.squeeze(2)

            # mask out padded positions so they receive zero attention weight
            indices = torch.arange(max_len, device=weight_logit.device).unsqueeze(0)
            mask = indices < length.unsqueeze(1)
            weight_logit[~mask] = float('-inf')
            weight = F.softmax(weight_logit, dim=1) # nonzero x max_len
            weighted = torch.bmm(weight.unsqueeze(1), m)
            # batch_size x 1 x max_len
            # batch_size x     max_len x embedded_dim
            # = batch_size x 1 x embedded_dim
            return weighted.squeeze(1), weight  #nonzero x embedded_dim
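
For reference, the masking trick in the `simple` branch can be reproduced as a standalone function. This is a minimal sketch with made-up shapes, using `masked_fill` instead of in-place indexing; only the pattern (set padded positions to -inf before the softmax, then pool with `bmm`) is taken from the snippet above.

import torch
import torch.nn.functional as F

def masked_attention_pool(memory, query, lengths):
    # memory: (batch, max_len, dim) padded batch, query: (batch, dim), lengths: (batch,)
    max_len = memory.size(1)
    scores = torch.bmm(memory, query.unsqueeze(-1)).squeeze(-1)       # (batch, max_len)
    # hide padded positions so they receive zero attention weight
    mask = torch.arange(max_len, device=memory.device).unsqueeze(0) < lengths.unsqueeze(1)
    scores = scores.masked_fill(~mask, float('-inf'))
    weights = F.softmax(scores, dim=1)                                # (batch, max_len)
    pooled = torch.bmm(weights.unsqueeze(1), memory).squeeze(1)       # (batch, dim)
    return pooled, weights

pooled, weights = masked_attention_pool(torch.randn(2, 5, 8), torch.randn(2, 8),
                                        torch.tensor([5, 3]))
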
Example #2
 def forward(self, ob, state_in=None):
     """Forward."""
     # accept either a PackedSequence of observations or a plain batch
     if isinstance(ob, PackedSequence):
         x = ob.data
     else:
         x = ob
     # scale raw (presumably uint8) observations to roughly [-1, 1)
     x = (x.float() / 128.) - 1.
     x = F.relu(self.conv1(x))
     x = F.relu(self.conv2(x))
     x = F.relu(self.fc(x.view(-1, self.nunits)))
     # re-wrap the per-step features as a PackedSequence (or add a time dimension of 1) for the LSTM
     if isinstance(ob, PackedSequence):
         x = PackedSequence(x,
                            batch_sizes=ob.batch_sizes,
                            sorted_indices=ob.sorted_indices,
                            unsorted_indices=ob.unsorted_indices)
     else:
         x = x.unsqueeze(0)
     if state_in is None:
         x, state_out = self.lstm(x)
     else:
         x, state_out = self.lstm(x, state_in)
     # flatten back to per-step features before the policy and value heads
     if isinstance(x, PackedSequence):
         x = x.data
     else:
         x = x.squeeze(0)
     return self.dist(x), self.vf(x), state_out
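
The re-wrapping idiom above (run a per-timestep module on `PackedSequence.data`, then rebuild the `PackedSequence` with the original metadata before the LSTM) can be exercised in isolation. The encoder and sizes below are invented; only the packing and re-wrapping pattern comes from the snippet.

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, PackedSequence

encoder = nn.Linear(16, 32)   # hypothetical per-timestep module: (N, 16) -> (N, 32)

# pad three variable-length sequences to (batch, max_len, 16), then pack
lengths = torch.tensor([4, 3, 2])
padded = torch.randn(3, 4, 16)
packed = pack_padded_sequence(padded, lengths, batch_first=True)

# apply the encoder to the flattened timesteps and rebuild the PackedSequence,
# carrying over the original batch_sizes / index metadata
out = PackedSequence(encoder(packed.data),
                     batch_sizes=packed.batch_sizes,
                     sorted_indices=packed.sorted_indices,
                     unsorted_indices=packed.unsorted_indices)

# the rewrapped sequence can be fed straight into an nn.LSTM
lstm = nn.LSTM(32, 64)
output, (h_n, c_n) = lstm(out)
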
Example #3
    def forward(self, x, y_adj, input_len):
        seq_len = x.shape[1]
        x = self.input_linear(x)
        # generates node-wise representation h_t
        x = pack_padded_sequence(x, input_len, batch_first=True)
        output_packed, self.h = self.lstm(x, self.h)
        output, _ = pad_packed_sequence(output_packed,
                                        batch_first=True,
                                        total_length=seq_len)

        # generate self-attention masks
        mask_sequence = generate_mask_sequence(size=output.shape[1])
        mask_pad = generate_mask_pad(input_len, output.shape)
        mask = mask_sequence & mask_pad

        # apply self-attention
        output_att, _ = self.attention(output, output, output, mask=mask)

        # skip-connection
        if self.concat:
            output = torch.cat([output, output_att], dim=2)
        else:
            output = output + output_att

        h_adj = self.output_hidden(output)
        h_adj = pack_padded_sequence(h_adj, input_len, batch_first=True).data
        h_adj = h_adj.unsqueeze(0)
        # 4 matches the number of hidden layers expected by adj_lstm;
        # all layers after the first are zero-initialized
        hidden_null = torch.zeros(4 - 1, h_adj.size(1),
                                  h_adj.size(2)).to("cuda:0")
        self.h_adj = torch.cat(
            (h_adj, hidden_null),
            dim=0)  # initializes the edge-wise hidden vectors

        # teacher forcing: shift y_adj right by one step, prepending a zero column
        x_adj = torch.cat((torch.zeros(y_adj.size(0), y_adj.size(1),
                                       1).to("cuda:0"), y_adj[:, :, 0:-1]),
                          dim=2)
        x_adj = pack_padded_sequence(x_adj, input_len, batch_first=True)
        x_adj, pad_container = x_adj.data, x_adj.batch_sizes
        x_adj = x_adj.unsqueeze(-1)

        x_adj = self.input_adj(x_adj)
        output_adj, self.h_adj = self.adj_lstm(
            x_adj, self.h_adj)  # get edge-wise representation
        output_adj = PackedSequence(output_adj, pad_container, None, None)
        output_adj = pad_packed_sequence(output_adj, batch_first=True)[0]

        output_adj = F.relu_(self.output_adj_1(output_adj))
        output_adj = torch.sigmoid(self.output_adj_2(output_adj))
        output_adj = output_adj.squeeze(-1)  # a_t

        output_coord = F.relu_(self.output_coord_1(output))
        output_coord = torch.tanh(self.output_coord_2(output_coord))  # x_t

        return output_adj, output_coord
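
`generate_mask_sequence` and `generate_mask_pad` are helpers defined elsewhere in that code base. One plausible sketch of masks with the shapes implied here is a causal mask combined with a length-based padding mask, broadcast together with `&`; the real helpers, and the exact shape convention expected by `self.attention`, may differ.

import torch

def causal_mask(size, device=None):
    # (1, size, size): position t may attend to positions <= t
    return torch.tril(torch.ones(size, size, device=device)).bool().unsqueeze(0)

def padding_mask(lengths, size, device=None):
    # (batch, 1, size): hide key positions past each sequence's length
    idx = torch.arange(size, device=device).unsqueeze(0)
    return (idx < lengths.unsqueeze(1)).unsqueeze(1)

lengths = torch.tensor([5, 3])
mask = causal_mask(5) & padding_mask(lengths, 5)   # broadcasts to (batch, size, size)
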
Example #4
    def forward(self, x, y_adj, input_len):
        seq_len = x.shape[1]
        x = self.input_linear(x)
        x = pack_padded_sequence(x, input_len, batch_first=True)
        output_packed, self.h = self.lstm(x, self.h)
        output, _ = pad_packed_sequence(output_packed,
                                        batch_first=True,
                                        total_length=seq_len)
        # generate node-level representation

        h_adj = self.output_hidden(output)
        h_adj = pack_padded_sequence(h_adj, input_len, batch_first=True).data
        h_adj = h_adj.unsqueeze(0)
        hidden_null = torch.zeros(4 - 1, h_adj.size(1),
                                  h_adj.size(2)).to("cuda:0")
        # the node-level representation initializes the first hidden state in the edge-wise RNN
        self.h_adj = torch.cat((h_adj, hidden_null), dim=0)

        # creates the input of the edge-wise RNN with teacher-forcing, prepending a zero-vector
        x_adj = torch.cat((torch.zeros(y_adj.size(0), y_adj.size(1),
                                       1).to("cuda:0"), y_adj[:, :, 0:-1]),
                          dim=2)
        x_adj = pack_padded_sequence(x_adj, input_len, batch_first=True)
        x_adj, pad_container = x_adj.data, x_adj.batch_sizes
        x_adj = x_adj.unsqueeze(-1)

        # generate edge-wise representation
        x_adj = self.input_adj(x_adj)
        output_adj, self.h_adj = self.adj_lstm(x_adj, self.h_adj)
        output_adj = PackedSequence(output_adj, pad_container, None, None)
        output_adj = pad_packed_sequence(
            output_adj, batch_first=True)[0]  # edge-wise representation

        output_adj = F.relu_(self.output_adj_1(output_adj))
        output_adj = torch.sigmoid(self.output_adj_2(output_adj))
        output_adj = output_adj.squeeze(-1)  # a_t == output_adj

        output_coord = F.relu_(self.output_coord_1(output))
        output_coord = torch.tanh(
            self.output_coord_2(output_coord))  # x_t == output_coord

        return output_adj, output_coord
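
The hidden-state hand-off above (the flattened node-level outputs become layer 0 of the edge-wise RNN's initial hidden state, with the remaining layers zeroed) can be sketched on its own. The sizes are invented, and an nn.GRU stands in for `adj_lstm` since the snippet passes a single hidden tensor rather than an (h, c) tuple; the real module may differ.

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

num_layers, hidden_dim = 4, 16
node_outputs = torch.randn(2, 5, hidden_dim)          # padded node-level features
lengths = torch.tensor([5, 3])

# flatten the valid timesteps: each node step becomes one batch element of the edge RNN
flat = pack_padded_sequence(node_outputs, lengths, batch_first=True).data

# layer 0 of the initial hidden state comes from the node features; other layers are zeros
h0 = torch.cat([flat.unsqueeze(0),
                torch.zeros(num_layers - 1, flat.size(0), hidden_dim)], dim=0)

edge_rnn = nn.GRU(1, hidden_dim, num_layers=num_layers, batch_first=True)
edge_in = torch.randn(flat.size(0), 7, 1)             # one edge sequence per node step
edge_out, h_n = edge_rnn(edge_in, h0)
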
Example #5
 def forward(self, ob, state_in=None):
     """Forward."""
     if isinstance(ob, PackedSequence):
         x = self.net(ob.data.float())
         x = PackedSequence(x,
                            batch_sizes=ob.batch_sizes,
                            sorted_indices=ob.sorted_indices,
                            unsorted_indices=ob.unsorted_indices)
     else:
         x = self.net(ob.float()).unsqueeze(0)
     if state_in is None:
         x, state_out = self.lstm(x)
     else:
         x, state_out = self.lstm(x, state_in['lstm'])
     if isinstance(x, PackedSequence):
         x = x.data
     else:
         x = x.squeeze(0)
     state_out = {
         'lstm': state_out,
         '1': torch.zeros_like(state_out[0])
     }
     return self.dist(x), self.vf(x), state_out
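
A minimal, self-contained version of the state-threading pattern used here (the recurrent state travels in a dict under the 'lstm' key and is passed back on the next call) might look like the following; the module and sizes are invented.

import torch
import torch.nn as nn

class TinyRecurrentPolicy(nn.Module):
    def __init__(self, obs_dim=8, hidden=16):
        super().__init__()
        self.net = nn.Linear(obs_dim, hidden)
        self.lstm = nn.LSTM(hidden, hidden)

    def forward(self, ob, state_in=None):
        x = torch.relu(self.net(ob.float())).unsqueeze(0)   # (1, batch, hidden)
        if state_in is None:
            x, state_out = self.lstm(x)                      # zero-initialized state
        else:
            x, state_out = self.lstm(x, state_in['lstm'])    # carry the previous state
        return x.squeeze(0), {'lstm': state_out}

policy = TinyRecurrentPolicy()
obs = torch.randn(4, 8)
out1, state = policy(obs)            # first step
out2, state = policy(obs, state)     # later steps reuse the returned state
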