Example #1
def basic(x, size, stride, stack_index, block_index):
    """ The basic block. """
    prefix = f'layer{stack_index+1}-{block_index}-'
    y = W.conv(x,
               size,
               3,
               stride=stride,
               padding=1,
               bias=False,
               name=prefix + 'conv1')
    y = W.batch_norm(y, activation='relu', name=prefix + 'bn1')
    y = W.conv(y,
               size,
               3,
               stride=1,
               padding=1,
               bias=False,
               name=prefix + 'conv2')
    y = W.batch_norm(y, name=prefix + 'bn2')
    if y.shape[1] != x.shape[1]:  # channel mismatch: project the shortcut
        x = W.conv(x,
                   y.shape[1],
                   1,
                   stride=stride,
                   bias=False,
                   name=prefix + 'downsample-0')
        x = W.batch_norm(x, name=prefix + 'downsample-1')
    return F.relu(y + x)
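Examples #4 and #5 below feed blocks like this one to a stack() helper that is not shown in these results. A minimal sketch of what it could look like, assuming each spec tuple is (size, depth, stride) and only the first block in a stack downsamples; the signature is a guess, not the library's actual code:

def stack(x, size, depth, stride, stack_index, block=basic):
    for block_index in range(depth):
        # only the first block applies the stride; the rest keep stride 1
        x = block(x, size, stride if block_index == 0 else 1,
                  stack_index, block_index)
    return x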
Example #2
def basic(x, size, stride):
    y = W.conv(x, size, 3, stride=stride, padding=1, bias=False)
    y = W.batch_norm(y, activation='relu')
    y = W.conv(y, size, 3, stride=1, padding=1, bias=False)
    y = W.batch_norm(y)
    if y.shape[1] != x.shape[1]: # channel size mismatch, needs projection
        x = W.conv(x, y.shape[1], 1, stride=stride, bias=False)
        x = W.batch_norm(x)
    y = y + x  # residual shortcut connection
    return F.relu(y)
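For illustration only (the shapes are made up): doubling the channels while halving the resolution forces the 1x1 projection onto the shortcut.

x = torch.randn(8, 64, 56, 56)  # NCHW, hypothetical input
y = basic(x, 128, stride=2)     # -> (8, 128, 28, 28); shortcut projected by the 1x1 conv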
Example #3
def test_batch_norm():
    m = nn.Module()
    x = torch.randn(1, 2, 3)  # BCD
    torch.manual_seed(100)
    y0 = nn.BatchNorm1d(2)(x)
    torch.manual_seed(100)
    y1 = W.batch_norm(x, parent=m)
    assert torch.equal(y0, y1), 'batch_norm incorrect output on 1d signal.'
    m = nn.Module()  # fresh parent module for the 2d case
    x = torch.randn(1, 2, 3, 4)  # BCHW
    torch.manual_seed(100)
    y0 = nn.BatchNorm2d(2)(x)
    torch.manual_seed(100)
    y1 = W.batch_norm(x, parent=m)
    assert torch.equal(y0, y1), 'batch_norm incorrect output on 2d signal.'
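Since the test exercises both 3-d and 4-d inputs against BatchNorm1d and BatchNorm2d, W.batch_norm presumably dispatches on input rank. A minimal sketch of that dispatch, assuming plain torch.nn modules underneath (batch_norm_for is a hypothetical name):

import torch.nn as nn

def batch_norm_for(x):
    # pick the norm class from the input rank: 3-d -> 1d, 4-d -> 2d, 5-d -> 3d
    norm_cls = {3: nn.BatchNorm1d, 4: nn.BatchNorm2d, 5: nn.BatchNorm3d}[x.dim()]
    return norm_cls(x.shape[1])  # num_features = channel dimension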
Example #4
def forward(self, x):
    y = W.conv(x, 64, 7, stride=2, padding=3, bias=False, name='conv1')
    y = W.batch_norm(y, activation='relu', name='bn1')
    y = F.max_pool2d(y, 3, stride=2, padding=1)
    for i, spec in enumerate(self.stack_spec):
        y = stack(y, *spec, i, block=self.block)
    y = F.adaptive_avg_pool2d(y, 1)
    y = torch.flatten(y, 1)
    y = W.linear(y, 1000, name='fc')
    return y
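Assuming the (size, depth, stride) spec layout sketched after Example #1, a ResNet-18-like configuration of self.stack_spec would be (illustrative, not taken from the source):

self.stack_spec = [(64, 2, 1), (128, 2, 2), (256, 2, 2), (512, 2, 2)]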
Example #5
def forward(self, x):
    y = W.conv(x, 64, 7, stride=2, padding=3, bias=False)
    y = W.batch_norm(y, activation='relu')
    y = F.max_pool2d(y, 3, stride=2, padding=1)
    for spec in self.stack_spec:
        y = stack(y, *spec, block=self.block)
    y = F.adaptive_avg_pool2d(y, 1)
    y = torch.flatten(y, 1)
    y = W.linear(y, 2000)
    return y
Example #6
def conv_bn_relu(x, size, stride=1, expand=1, kernel=3, groups=1, name=''):
    x = W.conv(
        x,
        size,
        kernel,
        padding=(kernel - 1) // 2,
        stride=stride,
        groups=groups,
        bias=False,
        name=f'{name}-0',
    )
    return W.batch_norm(x, activation='relu6', name=f'{name}-1')
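Setting groups equal to the channel count turns this into a depthwise convolution, which is exactly how Example #7 below calls it. An illustrative call, not from the source:

y = conv_bn_relu(x, x.shape[1], stride=2, groups=x.shape[1], name='dw')  # depthwise 3x3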
Example #7
def bottleneck(x, size_out, stride, expand, name=''):
    size_in = x.shape[1]
    size_mid = size_in * expand
    y = x
    if expand > 1:  # pointwise expansion only when the block widens the channels
        y = conv_bn_relu(x, size_mid, kernel=1, name=f'{name}-conv-0')
    y = conv_bn_relu(y,
                     size_mid,
                     stride,
                     kernel=3,
                     groups=size_mid,
                     name=f'{name}-conv-{1 if expand > 1 else 0}')
    y = W.conv(y,
               size_out,
               kernel=1,
               bias=False,
               name=f'{name}-conv-{2 if expand > 1 else 1}')
    y = W.batch_norm(y, name=f'{name}-conv-{3 if expand > 1 else 2}')
    if stride == 1 and size_in == size_out:
        y += x  # residual shortcut
    return y
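A hypothetical driver loop for these bottlenecks, following the (expand t, channels c, repeats n, stride s) table from the MobileNetV2 paper; only the first rows are shown and the names are illustrative:

specs = [(1, 16, 1, 1), (6, 24, 2, 2), (6, 32, 3, 2)]  # (t, c, n, s)
y = x
for i, (t, c, n, s) in enumerate(specs):
    for j in range(n):
        # only the first repeat in a group may downsample
        y = bottleneck(y, c, s if j == 0 else 1, t, name=f'features-{i}-{j}')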
Example #8
def conv_bn_act(x,
                size,
                kernel=1,
                stride=1,
                groups=1,
                bias=False,
                eps=1e-3,  # TF-style BatchNorm epsilon (EfficientNet default)
                momentum=1e-2,  # TF-style BatchNorm momentum
                act=swish,
                name='',
                **kw):
    x = conv_pad_same(x,
                      size,
                      kernel,
                      stride=stride,
                      groups=groups,
                      bias=bias,
                      name=name + '-conv')
    return W.batch_norm(x,
                        eps=eps,
                        momentum=momentum,
                        activation=act,
                        name=name + '-bn')
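conv_pad_same and swish are used above but not shown in these results. A minimal sketch under the usual EfficientNet conventions, with TF-style 'same' padding (any odd remainder goes to the right/bottom) and swish(x) = x * sigmoid(x); the signatures are guesses:

import math
import torch
import torch.nn.functional as F

def swish(x):
    return x * torch.sigmoid(x)

def conv_pad_same(x, size, kernel, stride=1, groups=1, bias=False, name=''):
    # pad so the output spatial size is ceil(input / stride), TF 'SAME' style
    ih, iw = x.shape[-2:]
    pad_h = max((math.ceil(ih / stride) - 1) * stride + kernel - ih, 0)
    pad_w = max((math.ceil(iw / stride) - 1) * stride + kernel - iw, 0)
    x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2,
                  pad_h // 2, pad_h - pad_h // 2])
    return W.conv(x, size, kernel, stride=stride, groups=groups,
                  bias=bias, name=name)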