Example #1
def squeeze_excitation(x, size_se, name='', **kw):
    """Squeeze-and-excitation: rescale each channel of x by a learned gate."""
    if size_se == 0:
        return x
    size_in = x.shape[1]
    y = F.adaptive_avg_pool2d(x, 1)  # squeeze: global average pool to 1x1
    y = W.conv(y, size_se, 1, activation=swish, name=name + '-conv1')
    y = W.conv(y, size_in, 1, activation=torch.sigmoid, name=name + '-conv2')
    return x * y  # excitation: gate the input channels (original returned y only)
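A gate like this normally sits at the end of a convolutional block, rescaling the channels of the feature map it receives. A minimal usage sketch, assuming W and swish are in scope as above; the block layout and the 4x reduction ratio are assumptions, not part of the example:

def conv_se_block(x, name=''):
    # hypothetical block: conv, then channel-wise recalibration
    y = W.conv(x, 64, 3, padding=1, activation=swish, name=name + '-conv')
    return squeeze_excitation(y, y.shape[1] // 4, name=name + '-se')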
Example #2
def basic(x, size, stride, stack_index, block_index):
    """ The basic block. """
    prefix = f'layer{stack_index+1}-{block_index}-'
    y = W.conv(x,
               size,
               3,
               stride=stride,
               padding=1,
               bias=False,
               name=prefix + 'conv1')
    y = W.batch_norm(y, activation='relu', name=prefix + 'bn1')
    y = W.conv(y,
               size,
               3,
               stride=1,
               padding=1,
               bias=False,
               name=prefix + 'conv2')
    y = W.batch_norm(y, name=prefix + 'bn2')
    if stride != 1 or y.shape[1] != x.shape[1]:  # shape mismatch: project the shortcut
        x = W.conv(x,
                   y.shape[1],
                   1,
                   stride=stride,
                   bias=False,
                   name=prefix + 'downsample-0')
        x = W.batch_norm(x, name=prefix + 'downsample-1')
    return F.relu(y + x)
Example #3
def forward(self, x):
    # shapes assume 1x28x28 inputs (e.g. MNIST): view(-1, 800) implies 50x4x4
    x = W.conv(x, 20, 5, activation='relu')  # -> 20x24x24
    x = F.max_pool2d(x, 2)                   # -> 20x12x12
    x = W.conv(x, 50, 5, activation='relu')  # -> 50x8x8
    x = F.max_pool2d(x, 2)                   # -> 50x4x4
    x = x.view(-1, 800)                      # flatten 50 * 4 * 4
    x = W.linear(x, 500, activation='relu')
    x = W.linear(x, 10)
    return F.log_softmax(x, dim=1)
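Because the network returns log-probabilities, it pairs directly with F.nll_loss during training. A minimal sketch; net, images and labels are assumed to exist:

logp = net(images)               # log-probabilities from log_softmax
loss = F.nll_loss(logp, labels)  # NLL on log-probabilities = cross-entropy
loss.backward()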
Example #4
def basic(x, size, stride):
    y = W.conv(x, size, 3, stride=stride, padding=1, bias=False)
    y = W.batch_norm(y, activation='relu')
    y = W.conv(y, size, 3, stride=1, padding=1, bias=False)
    y = W.batch_norm(y)
    if stride != 1 or y.shape[1] != x.shape[1]:  # shape mismatch: project the shortcut
        x = W.conv(x, y.shape[1], 1, stride=stride, bias=False)
        x = W.batch_norm(x)
    y = y + x  # residual shortcut connection
    return F.relu(y)
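Chaining these blocks yields one ResNet stage: the first block may downsample with stride 2, the rest keep the shape. A minimal sketch; the helper name and layout are assumptions:

def make_stage(x, size, depth, stride):
    x = basic(x, size, stride)  # may downsample and project the shortcut
    for _ in range(depth - 1):
        x = basic(x, size, 1)   # identity-shortcut blocks
    return x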
Example #5
def test_conv():
    """W.conv should match nn.Conv1d/Conv2d exactly under the same seed."""
    m = nn.Module()
    x = torch.randn(1, 2, 8)  # BCD: (batch, channels, data)
    torch.manual_seed(100)
    y0 = nn.Conv1d(2, 3, 3)(x)
    torch.manual_seed(100)
    y1 = W.conv(x, 3, 3, parent=m)
    assert torch.equal(y0, y1), 'conv incorrect output on 1d signal.'
    m = nn.Module()
    x = torch.randn(1, 2, 3, 4)  # BCD: (batch, channels, height, width)
    torch.manual_seed(100)
    y0 = nn.Conv2d(2, 3, 3)(x)
    torch.manual_seed(100)
    y1 = W.conv(x, 3, 3, parent=m)
    assert torch.equal(y0, y1), 'conv incorrect output on 2d signal.'
Example #6
def forward(self, x):
    y = W.conv(x, 64, 7, stride=2, padding=3, bias=False, name='conv1')
    y = W.batch_norm(y, activation='relu', name='bn1')
    y = F.max_pool2d(y, 3, stride=2, padding=1)
    for i, spec in enumerate(self.stack_spec):  # one residual stack per spec
        y = stack(y, *spec, i, block=self.block)
    y = F.adaptive_avg_pool2d(y, 1)
    y = torch.flatten(y, 1)
    y = W.linear(y, 1000, name='fc')
    return y
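The stack helper itself is not shown in these examples. A plausible sketch consistent with this call site and the block signature of Example #2, assuming each spec is (size, depth, stride); the field order inside spec is a guess:

def stack(x, size, depth, stride, stack_index, block=basic):
    for block_index in range(depth):
        # only the first block of a stack applies the stride
        x = block(x, size, stride if block_index == 0 else 1,
                  stack_index, block_index)
    return x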
Example #7
def forward(self, x):
    y = W.conv(x, 64, 7, stride=2, padding=3, bias=False)
    y = W.batch_norm(y, activation='relu')
    y = F.max_pool2d(y, 3, stride=2, padding=1)
    for spec in self.stack_spec:
        y = stack(y, *spec, block=self.block)
    y = F.adaptive_avg_pool2d(y, 1)
    y = torch.flatten(y, 1)
    y = W.linear(y, 2000)
    return y
Example #8
def conv_bn_relu(x, size, stride=1, kernel=3, groups=1, name=''):
    """Conv + batch norm + ReLU6, with 'same'-style padding for odd kernels."""
    x = W.conv(
        x,
        size,
        kernel,
        padding=(kernel - 1) // 2,
        stride=stride,
        groups=groups,
        bias=False,
        name=f'{name}-0',
    )
    return W.batch_norm(x, activation='relu6', name=f'{name}-1')
Example #9
def conv_pad_same(x, size, kernel=1, stride=1, **kw):
    """Conv with TensorFlow-style SAME padding: output size = ceil(input / stride)."""
    pad = 0
    if kernel != 1 or stride != 1:
        in_size, s, k = [
            torch.as_tensor(v) for v in (x.shape[2:], stride, kernel)
        ]
        # total padding per spatial dim so that out = ceil(in / s)
        pad = torch.max(((in_size + s - 1) // s - 1) * s + k - in_size,
                        torch.tensor(0))
        left, right = pad // 2, pad - pad // 2
        if torch.all(left == right):
            pad = tuple(left.tolist())  # symmetric: let the conv do the padding
        else:
            # asymmetric: pad explicitly, last dim first as F.pad expects
            left, right = left.tolist(), right.tolist()
            pad = sum(zip(left[::-1], right[::-1]), ())
            x = F.pad(x, pad)
            pad = 0
    return W.conv(x, size, kernel, stride=stride, padding=pad, **kw)
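With stride 2 on a 224x224 input, the output is ceil(224 / 2) = 112 whatever the kernel size; here the total padding of 1 is asymmetric (0 left, 1 right), so the F.pad branch is taken. A quick check, assuming torch and nn are imported as in Example #5 and that W.conv accepts a parent module:

x = torch.randn(1, 3, 224, 224)
y = conv_pad_same(x, 32, kernel=3, stride=2, parent=nn.Module())
assert y.shape[2:] == (112, 112)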
Example #10
def bottleneck(x, size_out, stride, expand, name=''):
    """MobileNetV2-style inverted residual block."""
    size_in = x.shape[1]
    size_mid = size_in * expand
    # 1x1 pointwise expansion (skipped when expand == 1)
    y = conv_bn_relu(x, size_mid, kernel=1,
                     name=f'{name}-conv-0') if expand > 1 else x
    # 3x3 depthwise conv (groups == channels)
    y = conv_bn_relu(y,
                     size_mid,
                     stride,
                     kernel=3,
                     groups=size_mid,
                     name=f'{name}-conv-{1 if expand > 1 else 0}')
    # 1x1 linear projection: batch norm with no activation
    y = W.conv(y,
               size_out,
               kernel=1,
               bias=False,
               name=f'{name}-conv-{2 if expand > 1 else 1}')
    y = W.batch_norm(y, name=f'{name}-conv-{3 if expand > 1 else 2}')
    if stride == 1 and size_in == size_out:
        y += x  # residual shortcut
    return y
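Stacked with the usual MobileNetV2 schedule, each (expand, size_out, repeats, stride) row repeats the block with the stride applied only to the first repeat. A minimal sketch; the helper and naming scheme are assumptions, and the two sample rows echo the MobileNetV2 paper:

def bottleneck_stack(x, settings=((1, 16, 1, 1), (6, 24, 2, 2))):
    i = 0
    for expand, size_out, repeats, stride in settings:
        for j in range(repeats):
            x = bottleneck(x, size_out, stride if j == 0 else 1, expand,
                           name=f'block-{i}')
            i += 1
    return x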