def on_build(self):
    """Instantiate every layer of the network and attach it to ``self``.

    Layout, in creation order:

    * ``conv1`` / ``bn1``: 7x7, stride-2, 'SAME'-padded convolution from
      3 to 64 planes, with a batch-norm layer sized for 64.
    * ``conv2`` .. ``conv4``: three ``ConvBlock`` instances taking the
      width 64 -> 128 -> 128 -> 256.
    * Four stages, stored in parallel lists indexed by stage: ``m``
      (``HourGlass``), ``top_m`` (``ConvBlock``), ``conv_last`` and
      ``bn_end`` (1x1 convolution plus batch norm at 256 planes) and
      ``l`` (1x1 convolution from 256 down to 68 planes).
    * ``bl`` (256 -> 256) and ``al`` (68 -> 256) 1x1 convolutions are
      created for every stage except the last, so those two lists end
      up one element shorter than the others.

    Only construction happens here; how the layers are wired together
    is decided by the forward pass, which lives elsewhere.
    """
    stage_count = 4
    planes = 256
    head_planes = 68

    def pointwise(src_planes, dst_planes):
        # 1x1, stride-1, unpadded convolution shared by all the per-stage heads.
        return nn.Conv2D(src_planes, dst_planes, kernel_size=1, strides=1, padding='VALID')

    # Stem.
    self.conv1 = nn.Conv2D(3, 64, kernel_size=7, strides=2, padding='SAME')
    self.bn1 = nn.BatchNorm2D(64)

    self.conv2 = ConvBlock(64, 128)
    self.conv3 = ConvBlock(128, 128)
    self.conv4 = ConvBlock(128, planes)

    # Per-stage containers. They are created in this exact order before
    # being filled so that the attribute order on ``self`` is unchanged.
    self.m = []
    self.top_m = []
    self.conv_last = []
    self.bn_end = []
    self.l = []
    self.bl = []
    self.al = []

    for stage in range(stage_count):
        # The second HourGlass argument is defined by HourGlass itself
        # (not visible here); it is unrelated to ``stage_count``.
        self.m.append(HourGlass(planes, 4))
        self.top_m.append(ConvBlock(planes, planes))

        self.conv_last.append(pointwise(planes, planes))
        self.bn_end.append(nn.BatchNorm2D(planes))

        self.l.append(pointwise(planes, head_planes))

        # The final stage has no successor, so it gets no bl/al pair.
        if stage + 1 < stage_count:
            self.bl.append(pointwise(planes, planes))
            self.al.append(pointwise(head_planes, planes))
def on_build(self, in_planes, out_planes):
    """Create the layers of one block mapping ``in_planes`` to ``out_planes``.

    Args:
        in_planes: Plane count expected by the first batch-norm and
            convolution of the block.
        out_planes: Target plane count; the three convolutions emit
            ``out_planes // 2``, ``out_planes // 4`` and
            ``out_planes // 4`` planes respectively.

    Three batch-norm / 3x3 convolution pairs (``bn1``/``conv1`` ..
    ``bn3``/``conv3``) are always created. All convolutions are stride 1,
    'SAME' padded and bias-free. When ``in_planes`` differs from
    ``out_planes`` an extra batch-norm plus 1x1 convolution
    (``down_bn1`` / ``down_conv1``) from ``in_planes`` to ``out_planes``
    is created as well; otherwise both attributes are set to ``None``.

    NOTE(review): the half + quarter + quarter split only sums to
    ``out_planes`` when it is divisible by 4 - presumably the forward
    pass concatenates the three outputs; confirm against the caller.
    """
    self.in_planes = in_planes
    self.out_planes = out_planes

    half = out_planes // 2
    quarter = out_planes // 4

    def conv3x3(src_planes, dst_planes):
        # 3x3, stride-1, 'SAME'-padded convolution without a bias term.
        return nn.Conv2D(src_planes, dst_planes, kernel_size=3, strides=1,
                         padding='SAME', use_bias=False)

    self.bn1 = nn.BatchNorm2D(in_planes)
    self.conv1 = conv3x3(in_planes, half)

    self.bn2 = nn.BatchNorm2D(half)
    self.conv2 = conv3x3(half, quarter)

    self.bn3 = nn.BatchNorm2D(quarter)
    self.conv3 = conv3x3(quarter, quarter)

    # The extra 1x1 path is only needed when the block changes the width.
    width_changes = self.in_planes != self.out_planes
    self.down_bn1 = nn.BatchNorm2D(in_planes) if width_changes else None
    self.down_conv1 = (
        nn.Conv2D(in_planes, out_planes, kernel_size=1, strides=1,
                  padding='VALID', use_bias=False)
        if width_changes
        else None
    )