def __init__(self, d_depth, d_emb, d_inp, d_cap, n_parts, n_classes, n_iters=3):
    """Build the depth embedding, the part detector and three stacked Routing layers.

    Args:
        d_depth: number of rows of the learned ``depth_emb`` parameter.
        d_emb: number of columns of ``depth_emb``; also the input feature
            size of the part detector's linear layer.
        d_inp: output feature size of the part detector, passed as
            ``d_inp`` to the first routing layer.
        d_cap: ``d_out`` of every routing layer (and ``d_inp`` of the
            second and third ones).
        n_parts: the first routing layer is built with ``n_out=n_parts * 2``
            and the second with ``n_out=n_parts``.
        n_classes: ``n_out`` of the last routing layer.
        n_iters: forwarded as ``n_iters`` to every ``Routing`` layer
            (presumably the number of routing iterations -- confirm
            against ``Routing``). Defaults to 3.
    """
    super().__init__()

    # Learned embedding table, initialised to zeros.
    self.depth_emb = nn.Parameter(torch.zeros(d_depth, d_emb))

    # Linear projection -> Swish -> LayerNorm.
    self.detect_parts = nn.Sequential(
        nn.Linear(d_emb, d_inp),
        Swish(),
        nn.LayerNorm(d_inp),
    )

    # Forward ``n_iters`` so the constructor argument actually controls
    # every routing layer instead of being silently ignored.
    self.routings = nn.Sequential(
        Routing(d_cov=1, d_inp=d_inp, d_out=d_cap,
                n_out=n_parts * 2, n_iters=n_iters),
        Routing(d_cov=1, d_inp=d_cap, d_out=d_cap,
                n_inp=n_parts * 2, n_out=n_parts, n_iters=n_iters),
        Routing(d_cov=1, d_inp=d_cap, d_out=d_cap,
                n_inp=n_parts, n_out=n_classes, n_iters=n_iters),
    )

    # Re-initialise the linear layer of the part detector.
    nn.init.kaiming_normal_(self.detect_parts[0].weight)
    nn.init.zeros_(self.detect_parts[0].bias)
def __init__(self, n_objs, n_parts, d_chns):
    """Build the convolutional trunk, the two 1x1 heads and two Routing layers.

    Args:
        n_objs: ``n_out`` of the second (last) routing layer.
        n_parts: number of output channels of ``compute_a``; ``compute_mu``
            has ``n_parts * 4 * 4`` output channels; also ``n_out`` of the
            first routing layer and ``n_inp`` of the second.
        d_chns: number of channels produced by every convolution in
            ``self.convolve``.
    """
    super().__init__()

    # Six BatchNorm2d -> Conv2d(kernel 3) -> Swish stages. The first stage
    # takes 4 input channels, the rest take ``d_chns``; strides alternate
    # 1, 2, 1, 2, 1, 2.
    self.convolve = nn.Sequential(*[
        m
        for (inp_ch, out_ch, stride) in zip(
            [4] + [d_chns] * 5, [d_chns] * 6, [1, 2] * 3
        )
        for m in [
            nn.BatchNorm2d(inp_ch),
            nn.Conv2d(inp_ch, out_ch, 3, stride),
            Swish(),
        ]
    ])

    # 1x1 convolution heads on top of the trunk features.
    self.compute_a = nn.Sequential(
        nn.BatchNorm2d(d_chns),
        nn.Conv2d(d_chns, n_parts, 1),
    )
    self.compute_mu = nn.Sequential(
        nn.BatchNorm2d(d_chns),
        nn.Conv2d(d_chns, n_parts * 4 * 4, 1),
    )

    self.routings = nn.Sequential(
        Routing(d_cov=4, d_inp=4, d_out=4, n_out=n_parts),
        Routing(d_cov=4, d_inp=4, d_out=4, n_inp=n_parts, n_out=n_objs),
    )

    # Only the convolutions of the trunk are re-initialised; the 1x1 heads
    # keep their default initialisation. ``isinstance`` is used rather
    # than an exact ``type(...) ==`` comparison.
    for layer in self.convolve:
        if isinstance(layer, nn.Conv2d):
            nn.init.kaiming_normal_(layer.weight)
            nn.init.zeros_(layer.bias)
def __init__(self, n_objs, n_parts, d_chns):
    """Build the convolutional trunk, the two 1x1 heads and two Routing layers.

    Args:
        n_objs: ``n_out`` of the second (last) routing layer.
        n_parts: number of output channels of ``compute_a``; ``compute_mu``
            has ``n_parts * 4 * 4`` output channels; also ``n_out`` of the
            first routing layer and ``n_inp`` of the second.
        d_chns: number of channels produced by every convolution in
            ``self.convolve``.
    """
    super().__init__()

    # Six BatchNorm2d -> Conv2d(kernel 3) -> Swish stages; the first takes
    # 2 + 2 input channels and every second convolution uses stride 2.
    self.convolve = nn.Sequential(
        nn.BatchNorm2d(2 + 2),
        nn.Conv2d(2 + 2, d_chns, kernel_size=3),
        Swish(),
        nn.BatchNorm2d(d_chns),
        nn.Conv2d(d_chns, d_chns, kernel_size=3, stride=2),
        Swish(),
        nn.BatchNorm2d(d_chns),
        nn.Conv2d(d_chns, d_chns, kernel_size=3),
        Swish(),
        nn.BatchNorm2d(d_chns),
        nn.Conv2d(d_chns, d_chns, kernel_size=3, stride=2),
        Swish(),
        nn.BatchNorm2d(d_chns),
        nn.Conv2d(d_chns, d_chns, kernel_size=3),
        Swish(),
        nn.BatchNorm2d(d_chns),
        nn.Conv2d(d_chns, d_chns, kernel_size=3, stride=2),
        Swish(),
    )

    # 1x1 convolution heads on top of the trunk features.
    self.compute_a = nn.Sequential(
        nn.BatchNorm2d(d_chns),
        nn.Conv2d(d_chns, n_parts, 1),
    )
    self.compute_mu = nn.Sequential(
        nn.BatchNorm2d(d_chns),
        nn.Conv2d(d_chns, n_parts * 4 * 4, 1),
    )

    self.routings = nn.Sequential(
        Routing(d_cov=4, d_out=4, n_out=n_parts, d_inp=4, n_iters=3),
        Routing(d_cov=4, d_out=4, n_out=n_objs, d_inp=4, n_inp=n_parts, n_iters=3),
    )

    # Only the convolutions of the trunk are re-initialised; the 1x1 heads
    # keep their default initialisation. ``isinstance`` is used rather
    # than an exact ``type(...) ==`` comparison.
    for layer in self.convolve:
        if isinstance(layer, nn.Conv2d):
            nn.init.kaiming_normal_(layer.weight)
            nn.init.zeros_(layer.bias)