def basic(x, size, stride, stack_index, block_index):
    """The basic block."""
    prefix = f'layer{stack_index+1}-{block_index}-'
    y = W.conv(x, size, 3, stride=stride, padding=1, bias=False, name=prefix + 'conv1')
    y = W.batch_norm(y, activation='relu', name=prefix + 'bn1')
    y = W.conv(y, size, 3, stride=1, padding=1, bias=False, name=prefix + 'conv2')
    y = W.batch_norm(y, name=prefix + 'bn2')
    if y.shape[1] != x.shape[1]:  # channel mismatch: project the shortcut
        x = W.conv(x, y.shape[1], 1, stride=stride, bias=False, name=prefix + 'downsample-0')
        x = W.batch_norm(x, name=prefix + 'downsample-1')
    return F.relu(y + x)
def basic(x, size, stride):
    y = W.conv(x, size, 3, stride=stride, padding=1, bias=False)
    y = W.batch_norm(y, activation='relu')
    y = W.conv(y, size, 3, stride=1, padding=1, bias=False)
    y = W.batch_norm(y)
    if y.shape[1] != x.shape[1]:  # channel size mismatch, needs projection
        x = W.conv(x, y.shape[1], 1, stride=stride, bias=False)
        x = W.batch_norm(x)
    y = y + x  # residual shortcut connection
    return F.relu(y)
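For comparison, here is the same block written as a conventional stateful module. This is a pure-PyTorch sketch of what the functional version replaces, close in spirit to torchvision's BasicBlock; it is illustrative, not code from this repo:

import torch
from torch import nn
import torch.nn.functional as F

class BasicBlock(nn.Module):
    def __init__(self, size_in, size, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(size_in, size, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(size)
        self.conv2 = nn.Conv2d(size, size, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(size)
        self.downsample = None
        if size_in != size:  # project the shortcut on channel mismatch
            self.downsample = nn.Sequential(
                nn.Conv2d(size_in, size, 1, stride=stride, bias=False),
                nn.BatchNorm2d(size),
            )

    def forward(self, x):
        y = F.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        if self.downsample is not None:
            x = self.downsample(x)
        return F.relu(y + x)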
def test_batch_norm():
    m = nn.Module()
    x = torch.randn(1, 2, 3)  # BCD
    torch.manual_seed(100)
    y0 = nn.BatchNorm1d(2)(x)
    torch.manual_seed(100)
    y1 = W.batch_norm(x, parent=m)
    assert torch.equal(y0, y1), 'batch_norm incorrect output on 1d signal.'

    m = nn.Module()  # fresh parent so a new BatchNorm2d is created
    x = torch.randn(1, 2, 3, 4)  # BCHW
    torch.manual_seed(100)
    y0 = nn.BatchNorm2d(2)(x)
    torch.manual_seed(100)
    y1 = W.batch_norm(x, parent=m)
    assert torch.equal(y0, y1), 'batch_norm incorrect output on 2d signal.'
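The test relies on W.batch_norm picking the right BatchNorm flavor from the input's rank and registering it on `parent` so later calls reuse the same parameters. A minimal sketch of that lazy-wrapper idea (an illustration of the caching mechanism only, not W's actual implementation; `lazy_batch_norm` is a hypothetical name):

import torch
from torch import nn
import torch.nn.functional as F

def lazy_batch_norm(x, parent, name='bn', activation=None, **kw):
    # On first use, pick the BatchNorm flavor from the input rank, infer the
    # width from the channel dimension, and cache the module on the parent.
    if not hasattr(parent, name):
        cls = {3: nn.BatchNorm1d, 4: nn.BatchNorm2d, 5: nn.BatchNorm3d}[x.dim()]
        setattr(parent, name, cls(x.shape[1], **kw))
    y = getattr(parent, name)(x)
    return F.relu(y) if activation == 'relu' else y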
def forward(self, x):
    y = W.conv(x, 64, 7, stride=2, padding=3, bias=False, name='conv1')
    y = W.batch_norm(y, activation='relu', name='bn1')
    y = F.max_pool2d(y, 3, stride=2, padding=1)
    for i, spec in enumerate(self.stack_spec):
        y = stack(y, *spec, i, block=self.block)
    y = F.adaptive_avg_pool2d(y, 1)
    y = torch.flatten(y, 1)
    y = W.linear(y, 1000, name='fc')
    return y
def forward(self, x):
    y = W.conv(x, 64, 7, stride=2, padding=3, bias=False)
    y = W.batch_norm(y, activation='relu')
    y = F.max_pool2d(y, 3, stride=2, padding=1)
    for spec in self.stack_spec:
        y = stack(y, *spec, block=self.block)
    y = F.adaptive_avg_pool2d(y, 1)
    y = torch.flatten(y, 1)
    y = W.linear(y, 1000)  # 1000 ImageNet classes, matching the named variant
    return y
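Both forward passes lean on a stack helper that is not shown in this section. A plausible sketch, assuming each entry of stack_spec is (size, depth, stride) as in standard ResNet configurations; for the unnamed block variant the two index arguments would simply be dropped:

def stack(x, size, depth, stride, stack_index=0, block=basic):
    # Only the first block of a stack downsamples; the rest keep stride 1.
    for block_index in range(depth):
        x = block(x, size, stride if block_index == 0 else 1,
                  stack_index, block_index)
    return x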
def conv_bn_relu(x, size, stride=1, expand=1, kernel=3, groups=1, name=''):
    x = W.conv(
        x, size, kernel,
        padding=(kernel - 1) // 2,
        stride=stride,
        groups=groups,
        bias=False,
        name=f'{name}-0',
    )
    return W.batch_norm(x, activation='relu6', name=f'{name}-1')
def bottleneck(x, size_out, stride, expand, name=''):
    size_in = x.shape[1]
    size_mid = size_in * expand
    # The 1x1 expansion is skipped when expand == 1, which shifts the layer
    # indices used in the names below.
    y = conv_bn_relu(x, size_mid, kernel=1, name=f'{name}-conv-0') if expand > 1 else x
    y = conv_bn_relu(y, size_mid, stride, kernel=3, groups=size_mid,
                     name=f'{name}-conv-{1 if expand > 1 else 0}')
    y = W.conv(y, size_out, 1, bias=False,
               name=f'{name}-conv-{2 if expand > 1 else 1}')
    y = W.batch_norm(y, name=f'{name}-conv-{3 if expand > 1 else 2}')
    if stride == 1 and size_in == size_out:
        y += x  # residual shortcut
    return y
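For context, MobileNetV2 chains these bottlenecks according to the (t, c, n, s) table from the paper: expansion factor, output channels, repeat count, and first-block stride. The driver loop below is a sketch of that wiring, not necessarily this repo's code; the 'stem' and 'block' names are illustrative:

MOBILENET_V2_SPEC = [  # (expand, size_out, repeats, stride) from the paper
    (1, 16, 1, 1), (6, 24, 2, 2), (6, 32, 3, 2), (6, 64, 4, 2),
    (6, 96, 3, 1), (6, 160, 3, 2), (6, 320, 1, 1),
]

def features(x):
    x = conv_bn_relu(x, 32, stride=2, name='stem')
    for i, (t, c, n, s) in enumerate(MOBILENET_V2_SPEC):
        for j in range(n):  # only the first block of a group downsamples
            x = bottleneck(x, c, s if j == 0 else 1, t, name=f'block{i}-{j}')
    return x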
def conv_bn_act(x, size, kernel=1, stride=1, groups=1, bias=False,
                eps=1e-3, momentum=1e-2, act=swish, name='', **kw):
    x = conv_pad_same(x, size, kernel, stride=stride, groups=groups, bias=bias,
                      name=name + '-conv')
    return W.batch_norm(x, eps=eps, momentum=momentum, activation=act,
                        name=name + '-bn')
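conv_pad_same and swish are used above but not defined in this section. Minimal sketches of what they plausibly do, assuming W.conv as used throughout: swish is x * sigmoid(x), and 'same' padding emulates TensorFlow's SAME convention, which can be asymmetric for even residues:

import math
import torch
import torch.nn.functional as F

def swish(x):
    return x * torch.sigmoid(x)

def conv_pad_same(x, size, kernel, stride=1, groups=1, bias=False, name=''):
    # Pad so the output size is ceil(input / stride), as in TF 'SAME' padding.
    h, w = x.shape[-2:]
    pad_h = max((math.ceil(h / stride) - 1) * stride + kernel - h, 0)
    pad_w = max((math.ceil(w / stride) - 1) * stride + kernel - w, 0)
    x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2,
                  pad_h // 2, pad_h - pad_h // 2])
    return W.conv(x, size, kernel, stride=stride, groups=groups, bias=bias, name=name)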