Esempio n. 1
0
def multi_head_attention(x, y=None, num_head=8, dropout=0.1, mask=None, **kw):
    """Multi-head scaled dot-product attention on channel-first tensors.

    Shape letters follow the original inline notes: B batch, C channels,
    N heads, H = C // N per-head channels, Lx / Ly sequence lengths.

    Args:
        x: Query source, (B, C, Lx).
        y: Key/value source, (B, C, Ly). When None, `x` is used
            (self attention).
        num_head: Number of heads N; must divide C.
        dropout: Second argument handed to `W.dropout` for the attention
            weights.
        mask: Optional tensor added in place to the pre-softmax scores,
            which have shape (B, N, Lx, Ly).
        **kw: Accepted and ignored.

    Returns:
        Result of the final `W.linear(..., C)` projection applied to the
        merged heads, (B, C, Lx) per the original notes.

    Raises:
        AssertionError: If C is not a multiple of `num_head`, or the first
            two dims of `x` and `y` differ.
    """
    source = x if y is None else y  # self attention when no y is given
    batch, size = x.shape[:2]
    assert size % num_head == 0, 'num_head must be a divisor of size.'
    assert source.shape[:2] == x.shape[:2], 'The first 2 dims of x, y must match.'
    head_size = size // num_head

    def to_heads(t):
        # (B, C, L) -> (B, N, H, L)
        return t.reshape(batch, num_head, head_size, t.shape[-1])

    # Projections are requested in query, key, value order.
    query = to_heads(W.linear(x, size))  # (B, N, H, Lx)
    key = to_heads(W.linear(source, size))  # (B, N, H, Ly)
    value = to_heads(W.linear(source, size))  # (B, N, H, Ly)

    # Scale the queries by 1/sqrt(H) before taking dot products.
    query.mul_(head_size ** (-0.5))
    scores = query.transpose(2, 3).contiguous().matmul(key)  # (B, N, Lx, Ly)
    if mask is not None:
        scores.add_(mask)
    weights = W.dropout(F.softmax(scores, dim=-1), dropout)

    mixed = value.matmul(weights.transpose(2, 3).contiguous())  # (B, N, H, Lx)
    merged = mixed.reshape(batch, -1, mixed.shape[-1])  # (B, C, Lx)
    return W.linear(merged, size)
Esempio n. 2
0
 def forward(self, x):
     """Run two conv/max-pool stages and two linear layers; return log-softmax.

     The flattened feature width is hard-coded to 800, so the input must
     produce exactly that many values per sample after the second pooling
     (presumably 50 channels x 4 x 4 for a 28x28 image - confirm with the
     caller).
     """
     # Each stage: W.conv(..., <size>, 5, activation='relu') followed by
     # max-pooling with kernel size 2.
     for channels in (20, 50):
         x = F.max_pool2d(W.conv(x, channels, 5, activation='relu'), 2)
     flat = x.view(-1, 800)
     hidden = W.linear(flat, 500, activation='relu')
     logits = W.linear(hidden, 10)
     return F.log_softmax(logits, dim=1)
Esempio n. 3
0
def test_linear():
    """Check that `W.linear` matches `nn.Linear` under an identical seed.

    The RNG is re-seeded with the same value before building each layer,
    so both are expected to draw the same initial parameters; outputs are
    then compared for exact equality on a 3-D and on a 4-D input.
    """
    # 3-D input: last dimension (3) is the feature size fed to nn.Linear.
    m = nn.Module()
    x = torch.randn(1, 2, 3)  # BDC
    torch.manual_seed(100)
    y0 = nn.Linear(3, 4)(x)
    torch.manual_seed(100)
    y1 = W.linear(x, 4, parent=m, in_shape='BDC', out_shape='BDC')
    assert torch.equal(y0, y1), 'linear incorrect output on 1d signal.'

    # 4-D input: a fresh parent module is used for the second layer.
    m = nn.Module()
    x = torch.randn(1, 2, 3, 4)  # BDC
    torch.manual_seed(100)
    y0 = nn.Linear(4, 3)(x)
    torch.manual_seed(100)
    y1 = W.linear(x, 3, parent=m, in_shape='BDC', out_shape='BDC')
    # The message previously named 'batch_norm', which mislabelled failures.
    assert torch.equal(y0, y1), 'linear incorrect output on 2d signal.'
Esempio n. 4
0
 def forward(self, x):  # D
     """Embed `x`, run it through an LSTM and a linear layer, return log-probs.

     Layer sizes are unpacked from `self.arg` as
     (embedding_dim, hidden_dim, vocab_size, tagset_size). Shape letters in
     the comments follow the original notes (D data/sequence, C channel,
     B batch).
     """
     embedding_dim, hidden_dim, vocab_size, tagset_size = self.arg
     embedded = W.embedding(x, embedding_dim, vocab_size)  # D->DC
     # Transpose and prepend a singleton leading dimension for the LSTM.
     batched = embedded.T.unsqueeze(0)  # DC->BCD
     hidden = W.lstm(batched, hidden_dim)  # BCD
     scores = W.linear(hidden, tagset_size)  # BCD
     log_probs = F.log_softmax(scores, dim=1)  # BCD
     # Drop the leading dimension again and transpose back.
     return log_probs[0].T  # DC
Esempio n. 5
0
 def forward(self, x):
     """Stem conv/bn/pool, the configured stacks, global pooling, then 'fc'.

     Every `W.*` layer here is given an explicit name ('conv1', 'bn1',
     'fc'). Each entry of `self.stack_spec` is unpacked into `stack`,
     followed by its index and `block=self.block`.
     """
     stem = W.conv(x, 64, 7, stride=2, padding=3, bias=False, name='conv1')
     stem = W.batch_norm(stem, activation='relu', name='bn1')
     features = F.max_pool2d(stem, 3, stride=2, padding=1)
     for index, spec in enumerate(self.stack_spec):
         features = stack(features, *spec, index, block=self.block)
     # Pool each channel to 1x1, then flatten everything after dim 0.
     pooled = torch.flatten(F.adaptive_avg_pool2d(features, 1), 1)
     return W.linear(pooled, 1000, name='fc')
Esempio n. 6
0
 def forward(self, x):
     """Stem conv/bn/pool, the configured stacks, global pooling, then linear.

     No layer names are passed to the `W.*` calls. Each entry of
     `self.stack_spec` is unpacked into `stack` with `block=self.block`.
     The final linear layer is requested with size 2000.
     """
     stem = W.conv(x, 64, 7, stride=2, padding=3, bias=False)
     stem = W.batch_norm(stem, activation='relu')
     features = F.max_pool2d(stem, 3, stride=2, padding=1)
     for spec in self.stack_spec:
         features = stack(features, *spec, block=self.block)
     # Pool each channel to 1x1, then flatten everything after dim 0.
     pooled = torch.flatten(F.adaptive_avg_pool2d(features, 1), 1)
     return W.linear(pooled, 2000)
Esempio n. 7
0
 def forward(self, x):
     """Head conv, the `spec_b0` block groups, tail conv, pooling, classifier.

     Each row of `spec_b0` is unpacked as (size, expand, kernel, stride,
     repeat, se_ratio, dc_ratio) and yields `repeat` calls to `mb_block`.
     Only the first call of a group receives the row's stride; the
     remaining calls receive stride 1.
     """
     out = conv_bn_act(x, 32, kernel=3, stride=2, name='head')
     for spec in spec_b0:
         size, expand, kernel, first_stride, repeat, se_ratio, dc_ratio = spec
         for index in range(repeat):
             step = 1 if index else first_stride
             out = mb_block(out, size, expand, kernel, step, se_ratio,
                            dc_ratio)
     out = conv_bn_act(out, 1280, name='tail')
     out = F.adaptive_avg_pool2d(out, 1)
     out = W.dropout(out, 0.2)
     # Collapse everything after the first dimension before the classifier.
     flat = out.view(out.shape[0], -1)
     return W.linear(flat, 1000)
Esempio n. 8
0
def classify(x, size, *arg, **kw):
    """Apply dropout (rate 0.2) and then a linear layer of `size` to `x`.

    The two layers are named 'classifier-0' and 'classifier-1'. Extra
    positional and keyword arguments are accepted and ignored.
    """
    dropped = W.dropout(x, rate=0.2, name='classifier-0')
    logits = W.linear(dropped, size, name='classifier-1')
    return logits
Esempio n. 9
0
def feed_forward(x, size_ff=2048, dropout=0.1, **kw):
    """Two-layer feed-forward block: linear+ReLU, dropout, linear.

    Args:
        x: Input tensor; its `shape[1]` is used as the output size.
        size_ff: Size requested for the hidden linear layer.
        dropout: Second argument handed to `W.dropout` after the hidden
            layer.
        **kw: Accepted and ignored.

    Returns:
        Output of the second `W.linear`, requested with size `x.shape[1]`.
    """
    hidden = W.dropout(W.linear(x, size_ff, activation='relu'), dropout)
    # Project back to the size of the input's second dimension.
    return W.linear(hidden, x.shape[1])