def __init__(self):
    super(SingleBNN, self).__init__()
    # VGG-style binary stack: a full-precision first conv, then binarized 3x3
    # convs, each followed by BatchNorm and a binary hard-tanh activation.
    self.convolutions = nn.Sequential(
        nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(64),
        BinaryHardtanh(),
        BinarizeConv2d(64, 64, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(64),
        BinaryHardtanh(),
        BinarizeConv2d(64, 64, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(64),
        BinaryHardtanh(),
        BinarizeConv2d(64, 64, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(64),
        BinaryHardtanh(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        BinarizeConv2d(64, 64, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(64),
        BinaryHardtanh(),
        BinarizeConv2d(64, 64, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(64),
        BinaryHardtanh(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        BinarizeConv2d(64, 64, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(64),
        BinaryHardtanh(),
        BinarizeConv2d(64, 64, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(64),
        BinaryHardtanh(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        BinarizeConv2d(64, 64, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(64),
        BinaryHardtanh(),
        # Final binarized conv maps to the 10 class channels; average pooling
        # collapses the remaining spatial map to one value per class.
        BinarizeConv2d(64, 10, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(10),
        nn.AvgPool2d(kernel_size=8))
    self.softmax = nn.Sequential(nn.LogSoftmax(dim=1))
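The listing above relies on BinarizeConv2d and BinaryHardtanh, which are not standard torch.nn modules. The sketch below shows how such layers are commonly implemented in binary-network code (a sign function in the forward pass and a straight-through estimator in the backward pass), together with a small shape check; the actual implementations used for these experiments may differ, and the wrapper class SingleBNN(nn.Module) around the __init__ above is assumed.

import torch
import torch.nn as nn
import torch.nn.functional as F


class BinarySign(torch.autograd.Function):
    # sign() in the forward pass; straight-through estimator in the backward pass.
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.sign(x)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        # Gradients flow only where |x| <= 1 (hard-tanh straight-through estimator).
        return grad_output * (x.abs() <= 1).float()


class BinaryHardtanh(nn.Module):
    # Hard-tanh activation whose output is passed through sign().
    def forward(self, x):
        return BinarySign.apply(F.hardtanh(x))


class BinarizeConv2d(nn.Conv2d):
    # Conv2d that binarizes its weights on the fly; the optimizer still
    # updates the underlying real-valued weights.
    def forward(self, x):
        return F.conv2d(x, BinarySign.apply(self.weight), self.bias,
                        self.stride, self.padding, self.dilation, self.groups)


# Shape check (assumes the __init__ above lives in `class SingleBNN(nn.Module)`).
# With 64x64 inputs, the three 2x max-poolings leave an 8x8 map for AvgPool2d(8).
model = SingleBNN()
x = torch.randn(2, 3, 64, 64)
feats = model.convolutions(x)                 # (2, 10, 1, 1)
log_probs = model.softmax(feats.flatten(1))   # (2, 10) log-probabilities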
def __init__(self, *argv, **argn):
    super(sConv2d, self).__init__()
    # Pop the sConv2d-specific options; whatever remains in argn is passed
    # straight through to the underlying convolution.
    self.init_factor = argn.pop('factor', 1)
    self.max_scales = argn.pop('max_scales', 4)
    self.usf = argn.pop('usf', False)
    self.bnorm = argn.pop('bnorm', True)
    self.lf = argn.pop('lf', False)
    self.binarized = argn.pop('binarized', False)
    self.last = argn.pop('last', False)

    # One learnable mixing weight per scale.
    self.alpha = nn.Parameter(torch.ones(self.max_scales))  # * 0.001
    # Scale factor: learnable if lf is set, otherwise a fixed constant.
    if self.lf:
        self.factor = nn.Parameter(torch.ones(1) * self.init_factor)
    else:
        self.factor = self.init_factor
    self.interp = F.interpolate

    if self.binarized:
        # Note: the binarized branch creates a single convolution
        # (ModuleList of length 1), whereas the full-precision branch
        # below creates one convolution per scale.
        self.conv = nn.ModuleList([BinarizeConv2d(*argv, **argn, bias=True)
                                   for _ in range(1)])
        if self.bnorm:
            self.bn = nn.ModuleList([nn.BatchNorm2d(argv[1])
                                     for _ in range(self.max_scales)])
            self.relu = nn.ModuleList([nn.ReLU() for _ in range(self.max_scales)])
    else:
        self.conv = nn.ModuleList([nn.Conv2d(*argv, **argn, bias=True)
                                   for _ in range(self.max_scales)])
        if self.bnorm:
            self.bn = nn.ModuleList([nn.BatchNorm2d(argv[1])
                                     for _ in range(self.max_scales)])
            self.relu = nn.ModuleList([nn.ReLU() for _ in range(self.max_scales)])
    # argv[1] is the number of output channels.
    self.lastbn = nn.BatchNorm2d(argv[1])
    if not self.last:
        self.hardtanh = BinaryHardtanh()
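The forward pass of sConv2d is not reproduced here. For orientation only, the sketch below shows one plausible way the pieces created in this constructor could be combined: the input is resized to progressively coarser scales with self.interp (F.interpolate), run through a per-scale conv/BN/ReLU branch, resized back, and the branches are blended with a softmax over the learnable alpha weights scaled by factor. Everything in this function is an assumption for illustration, not the class's actual forward, and it only covers the non-binarized, bnorm=True, last=False configuration.

def sconv_forward_sketch(layer, x):
    # Hypothetical multi-scale forward for sConv2d (illustrative only).
    mix = F.softmax(layer.alpha * layer.factor, dim=0)  # one mixing weight per scale
    h, w = x.shape[-2:]
    out = 0
    for s in range(layer.max_scales):
        # Downsample the input by 2**s, run the s-th branch, then upsample back.
        xs = x if s == 0 else layer.interp(x, scale_factor=1.0 / 2 ** s,
                                           mode='bilinear', align_corners=False)
        ys = layer.relu[s](layer.bn[s](layer.conv[s](xs)))
        if s > 0:
            ys = layer.interp(ys, size=(h, w), mode='bilinear', align_corners=False)
        out = out + mix[s] * ys
    # Binary activation at the end, present whenever last=False.
    return layer.hardtanh(out)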