def __init__(self, max_disp, in_planes=64, batch_norm=True):
    """Build the 3D convolution, hourglass, classifier and upsampling layers.

    Args:
        max_disp: Stored as ``self.max_disp``; not otherwise used while
            constructing the layers.
        in_planes: Passed as the input width of the first ``dres0``
            convolution (presumably the channel count of the incoming
            cost volume -- confirm against the caller).
        batch_norm: Flag forwarded to every ``conv3d_bn_relu`` /
            ``conv3d_bn`` helper and to each ``Hourglass``.
    """
    super(AcfAggregator, self).__init__()
    self.max_disp = max_disp
    self.in_planes = in_planes
    self.batch_norm = batch_norm

    # Stem: brings the input from ``in_planes`` to 32 channels.
    self.dres0 = nn.Sequential(
        conv3d_bn_relu(batch_norm, self.in_planes, 32, 3, 1, 1),
        conv3d_bn_relu(batch_norm, 32, 32, 3, 1, 1),
    )
    self.dres1 = nn.Sequential(
        conv3d_bn_relu(batch_norm, 32, 32, 3, 1, 1),
        conv3d_bn(batch_norm, 32, 32, 3, 1, 1),
    )

    # Three stacked hourglass modules. The batch-norm flag is passed
    # positionally so the call does not depend on how the Hourglass
    # constructor spells that parameter (``batchNorm`` vs ``batch_norm``);
    # a mismatching keyword would raise TypeError at construction time.
    self.dres2 = Hourglass(32, batch_norm)
    self.dres3 = Hourglass(32, batch_norm)
    self.dres4 = Hourglass(32, batch_norm)

    # One classifier per hourglass: 32 channels -> 1 channel.
    self.classif1 = nn.Sequential(
        conv3d_bn_relu(batch_norm, 32, 32, 3, 1, 1),
        nn.Conv3d(32, 1, kernel_size=3, stride=1, padding=1, bias=False),
    )
    self.classif2 = nn.Sequential(
        conv3d_bn_relu(batch_norm, 32, 32, 3, 1, 1),
        nn.Conv3d(32, 1, kernel_size=3, stride=1, padding=1, bias=False),
    )
    self.classif3 = nn.Sequential(
        conv3d_bn_relu(batch_norm, 32, 32, 3, 1, 1),
        nn.Conv3d(32, 1, kernel_size=3, stride=1, padding=1, bias=False),
    )

    # Single-channel transposed convolutions with kernel 8, stride 4,
    # padding 2: each enlarges its input 4x along all three spatial dims.
    self.deconv1 = nn.ConvTranspose3d(1, 1, 8, 4, 2, bias=False)
    self.deconv2 = nn.ConvTranspose3d(1, 1, 8, 4, 2, bias=False)
    self.deconv3 = nn.ConvTranspose3d(1, 1, 8, 4, 2, bias=False)
def __init__(self, in_planes, batchNorm=True):
    """Create the six layers of the hourglass.

    Args:
        in_planes: Channel width at the input; also the output width of
            the last layer (``conv6``). Inner layers use twice this width.
        batchNorm: Flag stored as ``self.batchNorm`` and forwarded as the
            first argument to every layer-building helper.
    """
    super(Hourglass, self).__init__()
    self.batchNorm = batchNorm

    use_bn = self.batchNorm
    wide = in_planes * 2

    # Keyword arguments shared by the (non-transposed) convolutions; only
    # the stride differs between them.
    conv_kwargs = dict(kernel_size=3, padding=1, bias=False)
    # Both transposed convolutions use identical settings.
    deconv_kwargs = dict(
        kernel_size=3, padding=1, output_padding=1, stride=2, bias=False
    )

    # Stride-2 / stride-1 pairs on the way in ...
    self.conv1 = conv3d_bn_relu(use_bn, in_planes, wide, stride=2, **conv_kwargs)
    self.conv2 = conv3d_bn(use_bn, wide, wide, stride=1, **conv_kwargs)
    self.conv3 = conv3d_bn_relu(use_bn, wide, wide, stride=2, **conv_kwargs)
    self.conv4 = conv3d_bn_relu(use_bn, wide, wide, stride=1, **conv_kwargs)

    # ... and two stride-2 transposed convolutions on the way out, the
    # last one returning to ``in_planes`` channels.
    self.conv5 = deconv3d_bn(use_bn, wide, wide, **deconv_kwargs)
    self.conv6 = deconv3d_bn(use_bn, wide, in_planes, **deconv_kwargs)