# Imports assumed by the snippets below; W is taken to be pywarm's
# functional API (an assumption based on the W.conv/W.batch_norm call style).
import torch
import torch.nn as nn
import torch.nn.functional as F
import warm.functional as W


def squeeze_excitation(x, size_se, name='', **kw):
    """ Squeeze-and-Excitation: rescale channels of x by a learned gate. """
    if size_se == 0:
        return x
    size_in = x.shape[1]
    y = F.adaptive_avg_pool2d(x, 1)  # squeeze: global average over space
    y = W.conv(y, size_se, 1, activation=swish, name=name + '-conv1')
    y = W.conv(y, size_in, 1, name=name + '-conv2')
    return x * torch.sigmoid(y)  # excite: per-channel gate, preserves x's shape
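# `swish` is used above but not defined in this excerpt; the standard
# definition (supplied here as an assumption, not from the original source):
def swish(x):
    return x * torch.sigmoid(x)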
def basic(x, size, stride, stack_index, block_index):
    """ The basic block. """
    # names follow torchvision's ResNet layout (layer1.0.conv1, ...)
    prefix = f'layer{stack_index+1}-{block_index}-'
    y = W.conv(x, size, 3, stride=stride, padding=1, bias=False, name=prefix + 'conv1')
    y = W.batch_norm(y, activation='relu', name=prefix + 'bn1')
    y = W.conv(y, size, 3, stride=1, padding=1, bias=False, name=prefix + 'conv2')
    y = W.batch_norm(y, name=prefix + 'bn2')
    if y.shape[1] != x.shape[1]:
        x = W.conv(x, y.shape[1], 1, stride=stride, bias=False, name=prefix + 'downsample-0')
        x = W.batch_norm(x, name=prefix + 'downsample-1')
    return F.relu(y + x)
def forward(self, x):
    x = W.conv(x, 20, 5, activation='relu')
    x = F.max_pool2d(x, 2)
    x = W.conv(x, 50, 5, activation='relu')
    x = F.max_pool2d(x, 2)
    x = x.view(-1, 800)  # 50 channels * 4 * 4 spatial, for 28x28 inputs (e.g. MNIST)
    x = W.linear(x, 500, activation='relu')
    x = W.linear(x, 10)
    return F.log_softmax(x, dim=1)
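# A minimal sketch of turning the forward above into a usable module. It
# assumes the W.* calls come from pywarm, where warm.up(model, input_shape)
# runs one tracing pass to materialize the parameters; warm.up and the
# 28x28 single-channel input shape are assumptions, not shown in this excerpt.
import warm

class ConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        warm.up(self, [1, 1, 28, 28])  # dummy pass creates all W.* parameters

    forward = forward  # bind the free function defined above as the method

# net = ConvNet()
# log_probs = net(torch.randn(1, 1, 28, 28))  # shape (1, 10)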
def basic(x, size, stride):
    y = W.conv(x, size, 3, stride=stride, padding=1, bias=False)
    y = W.batch_norm(y, activation='relu')
    y = W.conv(y, size, 3, stride=1, padding=1, bias=False)
    y = W.batch_norm(y)
    if y.shape[1] != x.shape[1]:  # channel size mismatch, needs projection
        x = W.conv(x, y.shape[1], 1, stride=stride, bias=False)
        x = W.batch_norm(x)
    y = y + x  # residual shortcut connection
    return F.relu(y)
def test_conv():
    m = nn.Module()
    x = torch.randn(1, 2, 8)  # BCD
    torch.manual_seed(100)  # same seed so both convs get identical init
    y0 = nn.Conv1d(2, 3, 3)(x)
    torch.manual_seed(100)
    y1 = W.conv(x, 3, 3, parent=m)
    assert torch.equal(y0, y1), 'conv incorrect output on 1d signal.'
    m = nn.Module()
    x = torch.randn(1, 2, 3, 4)  # BCHW
    torch.manual_seed(100)
    y0 = nn.Conv2d(2, 3, 3)(x)
    torch.manual_seed(100)
    y1 = W.conv(x, 3, 3, parent=m)
    assert torch.equal(y0, y1), 'conv incorrect output on 2d signal.'
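# A hypothetical companion test in the same style (it assumes W.linear
# mirrors nn.Linear's initialization the way W.conv mirrors nn.ConvNd above):
def test_linear():
    m = nn.Module()
    x = torch.randn(1, 4)
    torch.manual_seed(100)
    y0 = nn.Linear(4, 3)(x)
    torch.manual_seed(100)
    y1 = W.linear(x, 3, parent=m)
    assert torch.equal(y0, y1), 'linear incorrect output.'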
def forward(self, x):
    y = W.conv(x, 64, 7, stride=2, padding=3, bias=False, name='conv1')
    y = W.batch_norm(y, activation='relu', name='bn1')
    y = F.max_pool2d(y, 3, stride=2, padding=1)
    for i, spec in enumerate(self.stack_spec):
        y = stack(y, *spec, i, block=self.block)
    y = F.adaptive_avg_pool2d(y, 1)
    y = torch.flatten(y, 1)
    y = W.linear(y, 1000, name='fc')
    return y
def forward(self, x):
    y = W.conv(x, 64, 7, stride=2, padding=3, bias=False)
    y = W.batch_norm(y, activation='relu')
    y = F.max_pool2d(y, 3, stride=2, padding=1)
    for spec in self.stack_spec:
        y = stack(y, *spec, block=self.block)
    y = F.adaptive_avg_pool2d(y, 1)
    y = torch.flatten(y, 1)
    y = W.linear(y, 1000)  # ImageNet's 1000 classes, matching the named variant above
    return y
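# The `stack` helper used by both forwards is not shown in this excerpt.
# A minimal sketch for the unnamed variant, assuming each spec is
# (size, num_blocks, stride) and only the first block of a stage downsamples
# (the named variant would additionally thread the stack/block indices
# through to `basic` for the name prefixes):
def stack(x, size, num_blocks, stride, block=basic):
    for i in range(num_blocks):
        x = block(x, size, stride if i == 0 else 1)
    return x

# E.g. ResNet-18's four stages as a stack_spec (illustrative, not from the source):
# [(64, 2, 1), (128, 2, 2), (256, 2, 2), (512, 2, 2)]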
def conv_bn_relu(x, size, stride=1, expand=1, kernel=3, groups=1, name=''):
    x = W.conv(
        x, size, kernel,
        padding=(kernel - 1) // 2,
        stride=stride,
        groups=groups,
        bias=False,
        name=f'{name}-0',
    )
    return W.batch_norm(x, activation='relu6', name=f'{name}-1')
def conv_pad_same(x, size, kernel=1, stride=1, **kw):
    """ Conv with TensorFlow-style 'same' padding: out size = ceil(in / stride). """
    pad = 0
    if kernel != 1 or stride != 1:
        in_size, s, k = [
            torch.as_tensor(v) for v in (x.shape[2:], stride, kernel)
        ]
        # total padding per spatial dim so that out = ceil(in / s)
        pad = torch.max(((in_size + s - 1) // s - 1) * s + k - in_size, torch.tensor(0))
        left, right = pad // 2, pad - pad // 2
        if torch.all(left == right):
            pad = tuple(left.tolist())  # symmetric: let the conv do the padding
        else:
            # asymmetric: pad explicitly; F.pad expects pairs in reverse dim order
            left, right = left.tolist(), right.tolist()
            pad = sum(zip(left[::-1], right[::-1]), ())
            x = F.pad(x, pad)
            pad = 0
    return W.conv(x, size, kernel, stride=stride, padding=pad, **kw)
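# A quick sanity check of the 'same' padding rule (a hypothetical test,
# reusing the explicit parent= style from test_conv above): the output
# spatial size is ceil(input / stride) in both the symmetric-padding (odd
# input) and asymmetric-padding (even input) cases.
def test_conv_pad_same():
    m = nn.Module()
    for in_size in (15, 16):
        x = torch.randn(1, 3, in_size, in_size)
        y = conv_pad_same(x, 8, kernel=3, stride=2, parent=m, name=f'same-{in_size}')
        assert y.shape[-1] == (in_size + 1) // 2  # ceil(in_size / 2)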
def bottleneck(x, size_out, stride, expand, name=''):
    """ Inverted residual block: expand -> depthwise -> project. """
    size_in = x.shape[1]
    size_mid = size_in * expand
    y = conv_bn_relu(x, size_mid, kernel=1, name=f'{name}-conv-0') if expand > 1 else x
    y = conv_bn_relu(y, size_mid, stride, kernel=3, groups=size_mid,  # depthwise
        name=f'{name}-conv-{1 if expand > 1 else 0}')
    y = W.conv(y, size_out, kernel=1, bias=False, name=f'{name}-conv-{2 if expand > 1 else 1}')
    y = W.batch_norm(y, name=f'{name}-conv-{3 if expand > 1 else 2}')
    if stride == 1 and size_in == size_out:
        y += x  # residual shortcut
    return y
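# A hypothetical stage builder in MobileNetV2's (t, c, n, s) notation: expand
# factor t, output channels c, n repeated blocks, stride s on the first block
# only. The function name and spec format are illustrative, not from the
# original source:
def mb_stage(x, t, c, n, s, name=''):
    for i in range(n):
        x = bottleneck(x, c, s if i == 0 else 1, t, name=f'{name}-{i}')
    return x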