Example #1
 def forward(self, delta):
     out = torch.dropout(F.leaky_relu(self.fc1(delta), 0.2), 0.4,
                         self.training)
     mu = torch.dropout(F.leaky_relu(self.fc2(out), 0.2), 0.4,
                        self.training).reshape(-1, 1, self.mu_dim)
     mu = mu.repeat(1, self.num_imgs, 1)
     return mu
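All of these snippets call the low-level `torch.dropout(input, p, train)` binding rather than `torch.nn.functional.dropout` or the `nn.Dropout` module. A minimal sketch of the three spellings, shown only for orientation (the tensor `x` here is made up):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    x = torch.randn(4, 8)

    # ATen-level function: p is the drop probability, train has no default.
    y1 = torch.dropout(x, 0.5, True)

    # Functional API: training defaults to True.
    y2 = F.dropout(x, p=0.5, training=True)

    # Module API: follows module.train() / module.eval() automatically.
    y3 = nn.Dropout(p=0.5)(x)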
Example #2
 def forward(self, x, state=None, M=None):
     '''
     Chebyshev graph convolution operation
     :param x: (batch_size,N, dim_in)
     :return: (batch_size,N, dim_out)
     '''
     batch_size, num_of_vertices, in_channels = x.shape
     output = torch.zeros(batch_size, num_of_vertices, self.dim_out).to(
         self.DEVICE)  # (batch_size,N, dim_out)
     L_tilde = scaled_Laplacian(self.adj)
     cheb_polynomials = [
         torch.from_numpy(i).type(torch.FloatTensor)
         for i in cheb_polynomial(L_tilde, self.order_K)
     ]
     if state is not None:
         s = torch.einsum('ij,jkm->ikm', M,
                          state.permute(1, 0, 2)).permute(1, 0, 2)
         x = torch.cat((x, s), dim=-1)
     x0 = x
     if self._in_drop != 0:
         x = torch.dropout(x, 1.0 - self._in_drop, train=True)
     # k-order expansion
     for k in range(self.order_K):
         # Chebyshev polynomial term
         output = output + x.permute(0, 2, 1).matmul(cheb_polynomials[k].to(
             self.DEVICE)).permute(0, 2, 1).matmul(self.Theta[k])
     output = torch.matmul(output, self.weights)
     output = output + self.biases
     res = F.relu(output)
     if self._gcn_drop != 0.0:
         res = torch.dropout(res, 1.0 - self._gcn_drop, train=True)
     if self._residual:
         x0 = self.linear(x0)
         res = res + x0
     return res  # (batch_size,N, dim_out)
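The `cheb_polynomial` helper above is not shown. Under the standard definition it returns the Chebyshev polynomials of the scaled Laplacian, T_0 = I, T_1 = L̃, T_k = 2·L̃·T_{k-1} − T_{k-2}; a minimal NumPy sketch of that recurrence follows (the project's actual helper may differ in signature or return type):

    import numpy as np

    def cheb_polynomial(L_tilde, K):
        """Return the list [T_0, ..., T_{K-1}] for an N x N scaled Laplacian."""
        N = L_tilde.shape[0]
        polys = [np.identity(N), L_tilde.copy()]
        for k in range(2, K):
            polys.append(2 * L_tilde @ polys[k - 1] - polys[k - 2])
        return polys[:K]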
Example #3
    def forward(self, e, r, er_e2, direction="tail"):
        assert direction in ("head", "tail"), "Unknown forward direction"

        emb_hr_e = self.ent_embeddings(e)  # [m, k]
        emb_hr_r = self.rel_embeddings(r)  # [m, k]

        if direction == "tail":
            ere2_sigmoid = self.g(
                torch.dropout(self.f1(emb_hr_e, emb_hr_r),
                              p=self.hidden_dropout,
                              train=True), self.ent_embeddings.weight)
        else:
            ere2_sigmoid = self.g(
                torch.dropout(self.f2(emb_hr_e, emb_hr_r),
                              p=self.hidden_dropout,
                              train=True), self.ent_embeddings.weight)

        ere2_loss_left = -torch.sum(
            (torch.log(torch.clamp(ere2_sigmoid, 1e-10, 1.0)) *
             torch.max(torch.FloatTensor([0]).to(self.device), er_e2)))
        ere2_loss_right = -torch.sum(
            (torch.log(torch.clamp(1 - ere2_sigmoid, 1e-10, 1.0)) * torch.max(
                torch.FloatTensor([0]).to(self.device), torch.neg(er_e2))))

        hrt_loss = ere2_loss_left + ere2_loss_right

        return hrt_loss
Example #4
 def forward(self, x, presence=None):
     y = self._mab(x, x, x, presence)
     if self._dropout_rate > 0.:
         x = torch.dropout(x, p=self._dropout_rate, train=self.training)
     y += x
     if presence is not None:
         y *= presence
     y = y if getattr(self, 'ln0', None) is None else self.ln0(y)
     h = self.mlp2(self.relu(self.mlp1(y)))
     if self._dropout_rate > 0.:
         h = torch.dropout(h, p=self._dropout_rate, train=self.training)
     h += y
     h = h if getattr(self, 'ln1', None) is None else self.ln1(h)
     return h
Example #5
    def forward(self, inputs: Tensor, targets: Optional[Tensor] = None) -> Tensor:
        """
        dynamic convolutional recurrent neural network
        :param inputs: [B, n_hist, N, input_dim]
        :param targets: exists for training, tensor, [B, n_pred, N, output_dim]
        :return: tensor, [B, n_pred, N, input_dim]
        """
        supports = self.adaptive_supports(inputs.device)

        h, c = self.encoding(inputs, supports)

        # torch.dropout is not in-place; keep the returned tensors
        h = torch.dropout(h, p=self.dropout, train=self.training)
        c = torch.dropout(c, p=self.dropout, train=self.training)

        outputs = self.decoding((h, c), supports, targets)
        return outputs
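The assignments added above matter because `torch.dropout` is out-of-place and returns a new tensor; as originally written the two calls had no effect. If mutation were actually intended, the in-place variant could be used instead (a sketch, assuming the same `h`, `c`, and attributes):

    # In-place dropout; note the trailing underscore.
    torch.dropout_(h, p=self.dropout, train=self.training)
    torch.dropout_(c, p=self.dropout, train=self.training)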
Example #6
File: modules.py Project: adobe/NLP-Cube
 def forward(self, hidden, encoder_outputs, return_logsoftmax=False):
     # hidden = [batch size, dec hid dim]
     # encoder_outputs = [batch size, src sent len, enc hid dim * 2]
     batch_size = encoder_outputs.shape[0]
     src_len = encoder_outputs.shape[1]
     # repeat encoder hidden state src_len times
     hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)
     encoder_outputs = encoder_outputs
     # hidden = [batch size, src sent len, dec hid dim]
     # encoder_outputs = [batch size, src sent len, enc hid dim * 2]
     energy = torch.dropout(
         torch.tanh(
             self.attn(
                 torch.cat((hidden, encoder_outputs),
                           dim=2).transpose(1, 2)).transpose(1, 2)), 0.1,
         self.training)
     energy = energy.transpose(1, 2)
     # energy = [batch size, src sent len, dec hid dim]
     # energy = [batch size, dec hid dim, src sent len]
     # v = [dec hid dim]
     v = self.v.repeat(batch_size, 1).unsqueeze(1)
     # v = [batch size, 1, dec hid dim]
     attention = torch.bmm(v, energy).squeeze(1)
     # attention= [batch size, src len]
     if return_logsoftmax:
         return F.log_softmax(attention, dim=1)
     else:
         return F.softmax(attention, dim=1)
Example #7
    def forward(self, x, y, dropout=0.0, intermediate_output=False):
        if isinstance(x, list):
            # msg enabled
            last_x_forward = None
            x = list(reversed(x))

            x_forward = self.blocks[0](x[0])

            for data, block, from_rgb in zip(x[1:], self.blocks[1:],
                                             self.from_rgb_combiners):
                last_x_forward = x_forward

                x_forward = from_rgb(data, x_forward)
                x_forward = torch.dropout(x_forward, p=dropout, train=True)
                x_forward = block(x_forward)

            if intermediate_output:
                return x_forward, last_x_forward.mean()
            else:
                return x_forward
        else:
            last_x_forward = None
            x_forward = x

            for block in self.blocks:
                last_x_forward = x_forward
                x_forward = block(x_forward)

            if intermediate_output:
                return x_forward, last_x_forward.mean()
            else:
                return x_forward
Example #8
File: tripod2.py Project: adobe/tripod
    def forward(self, x, return_attentions=False, partition_dropout=False):
        input_emb = self.embedding(x)
        # summary-based embeddings
        output, hidden = self.encoder(input_emb)
        att_sum, cond_sum = self.attn_sum(hidden, output)
        att_gst, cond_gst = self.attn_gst(hidden)
        att_mem, cond_mem = self.attn_mem(hidden)
        cond_sum_ext = cond_sum.unsqueeze(1).repeat(1, x.shape[1], 1)
        cond_gst_ext = cond_gst.unsqueeze(1).repeat(1, x.shape[1], 1)
        cond_mem_ext = cond_mem.unsqueeze(1).repeat(1, x.shape[1], 1)

        cond = torch.relu(
            torch.cat([cond_sum_ext, cond_gst_ext, cond_mem_ext],
                      dim=2))  # cond_sum_ext + cond_gst_ext + cond_mem_ext
        if self.training:
            cond = torch.dropout(cond, 0.5, True)

        out_sum, hidden_sum = self.decoder(x,
                                           cond,
                                           partition_dropout=partition_dropout)

        if not return_attentions:
            return out_sum
        else:
            return out_sum, att_sum, att_gst, att_mem, cond_sum, cond_gst, cond_mem
Example #9
    def forward(self, batch):
        x_emb = batch['x_input']
        x_spa = batch['x_input_spa']
        x_lang = batch['x_lang']
        x_lang = self._lang_emb(x_lang).unsqueeze(1).repeat(
            1, x_emb[0].shape[1], 1)
        x_word_char = batch['x_word_char']
        x_word_case = batch['x_word_case']
        x_word_lang = batch['x_word_lang']
        x_word_masks = batch['x_word_masks']
        x_word_len = batch['x_word_len']
        x_sent_len = batch['x_sent_len']
        char_emb_packed = self._wg(x_word_char, x_word_case, x_word_lang,
                                   x_word_masks, x_word_len)

        sl = x_sent_len.cpu().numpy()

        x_char_emb = unpack(char_emb_packed,
                            sl,
                            x_emb[0].shape[1],
                            device=self._get_device())

        word_emb_ext = None

        for ii in range(len(x_emb)):
            we = x_emb[ii]
            if word_emb_ext is None:
                word_emb_ext = self._ext_proj[ii](we.float().to(
                    self._get_device()))
            else:
                word_emb_ext = word_emb_ext + self._ext_proj[ii](we)

        word_emb_ext = word_emb_ext / len(x_emb)
        word_emb_ext = torch.tanh(word_emb_ext)
        x_emb = word_emb_ext
        x_spa_emb = self._spa_emb(x_spa)
        x_emb = mask_concat([x_emb, x_char_emb], 0.33, self.training,
                            self._get_device())
        x_emb = torch.cat([x_emb, x_spa_emb], dim=-1)
        x = torch.cat([x_emb, x_lang], dim=-1).permute(0, 2, 1)
        x_lang = x_lang.permute(0, 2, 1)
        half = self._config.cnn_filter // 2
        res = None
        cnt = 0
        for conv in self._convs:
            conv_out = conv(x)
            tmp = torch.tanh(conv_out[:, :half, :]) * torch.sigmoid(
                (conv_out[:, half:, :]))
            if res is None:
                res = tmp
            else:
                res = res + tmp
            x = torch.dropout(tmp, 0.2, self.training)
            cnt += 1
            if cnt != self._config.cnn_layers:
                x = torch.cat([x, x_lang], dim=1)
        x = x + res
        x = torch.cat([x, x_lang], dim=1)
        x = x.permute(0, 2, 1)
        return self._output(x)
Example #10
    def forward(self, input):

        output = self.fc(input)
        output = self.nl(output)
        if self.dropout:
            output = torch.dropout(output, p=0.5, train=True)
        return output
Example #11
 def forward(self, input):
     """return the index of the cluster closest to the input"""
     Y = torch.matmul(input, self.w) + self.b  # + is a smart plus
     #Y = self.linear(input)
     if self.act is not None:
         Y = self.act(Y)
     return torch.dropout(Y, p=self.dr, train=self.training)
Example #12
    def forward(self, X):
        out = X
        for layer in self.layers:
            out = layer(out)
            # We apply dropout at test time as well (train=True keeps it active)
            out = torch.dropout(out, p=self.dropout_rate, train=True)

        return out
Example #13
 def __init__(self, hidden_size, dropout=None):
     super(PicoAttention, self).__init__()
     if dropout is None:
         self.dropout = None
     else:
         # store a dropout module here; torch.dropout is a function, not a layer
         self.dropout = nn.Dropout(p=dropout)
     
     self.softmax = nn.Softmax(dim=-1)
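With the constructor storing an `nn.Dropout` module (the original `torch.dropout(hidden_size, dropout, False)` would have tried to drop out an integer), the attention weights can be regularized later in `forward`. A hypothetical fragment of that usage, assuming a raw score tensor named `scores`:

    # Apply the stored softmax and optional dropout to raw attention scores.
    attn = self.softmax(scores)
    if self.dropout is not None:
        attn = self.dropout(attn)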
Example #14
 def forward(self, inputs, mask=None, keep_prob=1.0, is_train=True):
     # torch.dropout takes the drop probability, hence 1.0 - keep_prob
     x = torch.dropout(inputs, 1.0 - keep_prob, is_train)
     x = self.fc1(x)
     x = torch.tanh(x)
     x = self.fc2(x)
     if mask is not None:
         x = self.softmax_mask(x, mask)
     x = F.softmax(x, dim=1)
     x = x.squeeze(-1)
     return x
Example #15
 def forward(self, input, adj, training=True):
     if self.dropout > 0.:
         support = torch.dropout(input, p=self.dropout, train=training)
     else:
         support = input
     support = torch.mm(support, self.weight)
     output = torch.mm(adj, support)
     if self.bias is not None:
         output = output + self.bias
     return self.act(output)
Example #16
    def forward(self, inputs, input_mask):
        #batch_size, max_word_count, embedding_size
        inputs_mean = get_vector_mean(inputs, input_mask)
        inputs_mean = torch.dropout(inputs_mean,
                                    p=self.dropout_,
                                    train=self.training)
        #inputs_mean = self.drop_layer(inputs_mean)

        f_s = torch.tanh(self.f_W(inputs_mean))
        return f_s
Example #17
    def forward(self, domain_list):

        x = torch.tensor(self._make_input(domain_list), dtype=torch.long, device=self._get_device())
        hidden = self._char_emb(x)
        hidden = torch.dropout(hidden, 0.5, self.training)
        output, _ = self._rnn(hidden)
        output = output[:, -1, :]

        hidden = self._hidden(output)

        return self._softmax_type(hidden), self._softmax_subtype(hidden)
Example #18
 def forward(self, inputs: Tensor, graph: dgl.DGLGraph) -> Tensor:
     """
     forward of spatio-temporal convolutional block
     :param inputs: tensor, [N, T, F_in]
     :param graph: DGLGraph, with `N` nodes
     :return: tensor, [N, T, F_out]
     """
     outputs = self.t_conv1(inputs)
     outputs = self.s_conv(outputs, graph)
     outputs = self.t_conv2(outputs)
     # outputs = self.ln(outputs)
     return torch.dropout(outputs, p=self.dropout, train=self.training)
Example #19
 def att_match(self,
               mid,
               pat,
               mid_mask,
               pat_mask,
               keep_prob=1.0,
               is_train=True):
      # torch.dropout takes the drop probability, hence 1.0 - keep_prob
      mid_d = torch.dropout(mid, 1.0 - keep_prob, is_train)
      pat_d = torch.dropout(pat, 1.0 - keep_prob, is_train)
     mid_a = self.attention_model(mid_d,
                                  mask=mid_mask,
                                  keep_prob=keep_prob,
                                  is_train=is_train)
     pat_a = self.attention_model(pat_d,
                                  mask=pat_mask,
                                  keep_prob=keep_prob,
                                  is_train=is_train)
     # import pdb; pdb.set_trace()
     mid_v = torch.sum(mid_a.unsqueeze(-1) * mid, dim=1)
     pat_v = torch.sum(pat_a.unsqueeze(-1) * pat, dim=1)
     pat_v_d = torch.sum(pat_a.unsqueeze(-1) * pat_d, dim=1)
     sur_sim = self.cosine(mid_v, pat_v_d)
     pat_sim = self.cosine(pat_v, pat_v_d)
     return sur_sim, pat_sim
Example #20
    def _vanilla_dropout(self, x, is_training):
        bs, inp_seq_len = x.shape

        if self.draw_dropout_per_col:
            kOnes = self._get_cached_constant_ones((bs, 1, 1), x.device)
            vecs = []
            for _ in range(inp_seq_len):
                vecs.append(
                    torch.dropout(kOnes,
                                  p=np.random.randint(0, inp_seq_len) /
                                  inp_seq_len,
                                  train=is_training))
            dropout_vec = torch.cat(vecs, dim=1)
        else:
            kOnes = self._get_cached_constant_ones((bs, inp_seq_len, 1),
                                                   x.device)
            dropout_vec = torch.dropout(kOnes,
                                        p=np.random.randint(0, inp_seq_len) /
                                        inp_seq_len,
                                        train=is_training)
        # During training, non-dropped 1's are scaled by 1/(1-p), so we
        # clamp back to 1.  Shaped [bs, num cols, 1].
        batch_mask = torch.clamp(dropout_vec, 0, 1)
        return batch_mask
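The clamp above relies on inverted-dropout scaling: at train time the surviving entries of the all-ones tensor become 1/(1−p), so clamping to [0, 1] turns the result into a plain Bernoulli column mask. A standalone illustration:

    import torch

    ones = torch.ones(1, 5, 1)
    dropped = torch.dropout(ones, 0.4, True)  # kept entries are 1 / (1 - 0.4) ≈ 1.667
    mask = torch.clamp(dropped, 0, 1)         # kept entries clamped back to exactly 1
    print(dropped.squeeze(), mask.squeeze())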
Example #21
 def forward(self, x):
     x = self.cnn1(x)
     x = self.conv1_bn(x)
     x = torch.relu(x)
     x = self.maxpool1(x)
     x = self.cnn2(x)
     x = self.conv2_bn(x)
     x = torch.relu(x)
     x = self.maxpool2(x)
     x = x.view(-1, 16 * 62 * 62)
     x = torch.dropout(x, p=0.5, train=True)
     x = self.fc1(x)
     x = self.bn_fc1(x)
     x = torch.relu(x)
     x = self.fc2(x)
     x = self.bn_fc2(x)
     return x
Example #22
File: SPGA.py Project: cleverer123/NGACF
    def forward(self, input, adj):
        h = torch.mm(input, self.W) # ()
        N = h.size()[0]

        a_input = torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)], dim=1).view(N, -1, 2 * self.out_features)
        e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2))

        zero_vec = -9e15*torch.ones_like(e)
        attention = torch.where(adj > 0, e, zero_vec)
        attention = torch.softmax(attention, dim=1)
        attention = torch.dropout(attention, self.dropout, train=self.training)
        h_prime = torch.matmul(attention, h)

        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime
Example #23
    def forward(self, string_list):
        x_char, x_case = self._make_input(string_list)
        x_char = torch.tensor(x_char,
                              dtype=torch.long,
                              device=self._get_device())
        x_case = torch.tensor(x_case,
                              dtype=torch.long,
                              device=self._get_device())
        hidden = torch.cat([self._char_emb(x_char),
                            self._case_emb(x_case)],
                           dim=-1)
        hidden = torch.dropout(hidden, 0.5, self.training)
        output, _ = self._rnn(hidden)

        hidden = self._hidden(output)

        return self._softmax_type(hidden)
Example #24
    def forward(self, x, lengths, atten):
        batch_size = x.size(0)  # x.shape (B, seq, HIDDEN_SIZE)
        seq_size = x.size(1)

        # shape (B, seq, HIDDEN_SIZE)
        k = self.project_k.forward(x)
        q = self.project_q.forward(x)
        v = self.project_v.forward(x)

        # shape (B, seq, Heads, HIDDEN_SIZE/Heads)
        # shape (B, Heads, seq, HIDDEN_SIZE/Heads)
        k = k.view(batch_size, seq_size, TRANSFORMER_HEADS,
                   int(HIDDEN_SIZE / TRANSFORMER_HEADS)).transpose(1, 2)
        q = q.view(batch_size, seq_size, TRANSFORMER_HEADS,
                   int(HIDDEN_SIZE / TRANSFORMER_HEADS)).transpose(1, 2)
        v = v.view(batch_size, seq_size, TRANSFORMER_HEADS,
                   int(HIDDEN_SIZE / TRANSFORMER_HEADS)).transpose(1, 2)

        atten_raw = q @ k.transpose(-1, -2) / np.sqrt(x.size(-1))

        mask = torch.tril(torch.ones(seq_size,
                                     seq_size)).to(DEVICE)  # (Seq, Seq)
        atten_mask = atten_raw.masked_fill(
            mask == 0, value=float('-inf'))  # (B, Seq, Seq)
        for idx, length in enumerate(lengths):  # (B, Seq, Seq)
            atten_mask[idx, :, length:] = float('-inf')
            atten_mask[idx, length:, :] = float('-inf')

        atten = torch.softmax(atten_mask, dim=-1)
        atten = atten.masked_fill(~(atten > 0), value=0.0)
        out = atten @ v

        out = out.transpose(1, 2)
        out = out.contiguous().view(batch_size, seq_size, HIDDEN_SIZE)
        atten = atten.detach().mean(
            dim=1)  # shape (B, Heads, seq, seq) => (B, seq, seq)

        # torch.nn.Module > self.training
        # model.eval() model.train()
        out_1 = x + torch.dropout(out, p=DROPOUT, train=self.training)
        out_1_norm = self.norm_1.forward(out_1)

        out_2 = self.ff.forward(out_1_norm)
        out_3 = out_1_norm + out_2
        y_prim = self.norm_2.forward(out_3)
        return y_prim, lengths, atten
Example #25
    def forward(self, x):
        for index in range(self.num_conv):
            # take information from block on convolution
            x = self.layers[index](x)
            x = self.activation(self.activations[index])(x)
            # dropout of size activation, with 0 denoting shortcut
            x = torch.dropout(x, self.dropouts[index], self.training)

        # reshape the tensor
        x = x.view(self.batch_size, -1)

        # Check to see if there is a output layer predefined.
        if len(self.layers) == self.num_conv:
            self.layers.append(nn.Linear(self.num_flat_features(x), 1))

        x = self.layers[-1](x)
        x = self.activation(self.activations[-1])(x)
        return x
Example #26
File: modules.py Project: adobe/NLP-Cube
    def forward(self, x_char, x_case, x_lang, x_mask, x_word_len):
        x_char = self._tok_emb(x_char)
        x_case = self._case_emb(x_case)
        x_lang = self._lang_emb(x_lang)

        x = torch.cat([x_char, x_case], dim=-1)
        x = x.permute(0, 2, 1)
        x_lang = x_lang.unsqueeze(1).repeat(1, x_case.shape[1],
                                            1).permute(0, 2, 1)
        half = self._num_filters // 2
        count = 0
        res = None
        skip = None
        for conv in self._convolutions_char:
            count += 1
            drop = self.training
            if count >= len(self._convolutions_char):
                drop = False
            if skip is not None:
                x = x + skip

            x = torch.cat([x, x_lang], dim=1)
            conv_out = conv(x)
            tmp = torch.tanh(conv_out[:, :half, :]) * torch.sigmoid(
                (conv_out[:, half:, :]))
            if res is None:
                res = tmp
            else:
                res = res + tmp
            skip = tmp
            x = torch.dropout(tmp, 0.1, drop)
        x = x + res
        x = x.permute(0, 2, 1)
        x = x * x_mask.unsqueeze(2)
        pre = torch.sum(x, dim=1, dtype=torch.float)
        norm = pre / x_word_len.unsqueeze(1)
        # embeds = self._pre_out(norm)
        # norm = embeds.norm(p=2, dim=-1, keepdim=True)
        # embeds_normalized = embeds.div(norm)
        # return embeds_normalized

        return torch.tanh(self._pre_out(norm))
Example #27
    def mean_match(self,
                   mid,
                   pat,
                   mid_mask,
                   pat_mask,
                   keep_prob,
                   is_train=True):
        def mean(emb, mask):
            mask = mask.float()
            length = torch.sum(mask, dim=1)
            emb = torch.sum(emb, dim=1) / length.unsqueeze(1)
            return emb

        # torch.dropout takes the drop probability, hence 1.0 - keep_prob
        pat_d = torch.dropout(pat, 1.0 - keep_prob, is_train)
        mid_v = mean(mid, mid_mask)
        pat_v = mean(pat, pat_mask)
        pat_v_d = mean(pat_d, pat_mask)
        sur_sim = self.cosine(mid_v, pat_v_d)
        pat_sim = self.cosine(pat_v, pat_v_d)
        return sur_sim, pat_sim
Example #28
def attention(query, key, value, mask=None, dropout=None):
    """ Compute 'Scaled Dot Product Attention'

    Reference:
        (http://nlp.seas.harvard.edu/2018/04/03/attention.html)
        (https://github.com/BangLiu/QANet-PyTorch/blob/master/model/QANet.py)
    """
    # q, k, v: (batch_size, h=8, seq_len, d_k=16)
    # mask: (batch_size, seq_len)

    d_k = query.size(-1)
    scores = torch.matmul(query, key.permute(0,1,3,2)) \
            / math.sqrt(d_k)
    # scores: (batch_size, h=8, seq_len, seq_len)
    if mask is not None:
        mask = mask.unsqueeze(1)
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn
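Here `dropout` is a callable, typically an `nn.Dropout` module, rather than a probability. A usage sketch, assuming the `attention` function above together with its `math`/`F` imports; the tensors are made up, and the mask carries an explicit middle dimension so it broadcasts over the head axis after `unsqueeze(1)`:

    import torch
    import torch.nn as nn

    q = k = v = torch.randn(2, 8, 10, 16)   # (batch, heads, seq_len, d_k)
    mask = torch.ones(2, 1, 10)             # 1 = keep, 0 = pad
    attn_dropout = nn.Dropout(p=0.1)

    context, p_attn = attention(q, k, v, mask=mask, dropout=attn_dropout)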
Example #29
 def forward(self, query, key, value, **kwargs):
     # pytorch sparse tensors still under active development, so expect changes soon
     # for example, sparse batch matrix multiplication is not currently supported
     # TODO add support for masks
     m = query.size(0)
     n = key.size(0)
     if key.size(0) != value.size(0):
         raise RuntimeError("key and value must have same length")
     query = self.query_ff(query).view(m, -1, self.head_dim).transpose(0, 1)
     key = self.key_ff(key).view(n, -1, self.head_dim).transpose(0, 1)
     value = self.value_ff(value).view(n, -1, self.head_dim).transpose(0, 1)
     rows = torch.arange(m, device=query.device).repeat(2 * self.attn_span + 1, 1).transpose(0, 1).flatten()
     cols = torch.cat([torch.arange(i - self.attn_span, i + self.attn_span + 1, device=query.device) for i in range(n)])
     bounds = (cols >= 0) & (cols < n)
     cols[~bounds] = 0
     idxs = torch.stack([rows, cols])
     vals = (query[:, rows, :] * key[:, cols, :] * bounds.view(1, -1, 1)).sum(-1) / math.sqrt(n)
     vals[:, ~bounds] = -float("inf")
     vals = torch.dropout(torch.softmax(vals.view(-1, n, 2 * self.attn_span + 1), dim=-1), self.dropout, self.training).view(-1, idxs.size(1))
     attn_matrix = [torch.sparse.FloatTensor(idxs[:, bounds], val[bounds], (m, n)) for val in vals]
     out = self.out_ff(torch.stack([torch.sparse.mm(attn, val) for attn, val in zip(attn_matrix, value)]).transpose(0, 1).contiguous().view(n, -1, self.embed_dim))
     return out, attn_matrix
Example #30
    def ToTwoLevel(self, data):
        bs = data.size()[0]
        y_onehots = []
        data = data.long()
        for i, coli_dom_size in enumerate(self.input_bins):

            y_onehot = torch.zeros(bs, coli_dom_size, device=data.device)
            y_onehot.scatter_(1, data[:, i].view(-1, 1), 1)
            y_onehot = torch.dropout(y_onehot, p=0.3, train=self.training)

            # add on one-hot encoding at coarser second-level
            # e.g., for domain of 35, the 2nd level will have domain size of 4
            second_level_dom_size = 1 + coli_dom_size // 10
            y2_onehot = torch.zeros(bs,
                                    second_level_dom_size,
                                    device=data.device)
            y2_onehot.scatter_(1, data[:, i].view(-1, 1) // 10, 1)

            y_onehots.append(y_onehot)
            y_onehots.append(y2_onehot)

        # [bs, sum(dist size) + sum(2nd_level)]
        return torch.cat(y_onehots, 1)
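As a quick check of the second-level comment above: a column with domain size 35 gets `second_level_dom_size = 1 + 35 // 10 = 4`, and a raw value such as 23 is scattered into coarse bucket `23 // 10 = 2` of that 4-slot one-hot.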