def __init__(self):
    """Assemble the SingleBNN feature extractor and its output head.

    Layout of ``self.convolutions`` (every conv is 3x3, stride 1, padding 1,
    so only the pooling layers change the spatial size):

    * stem: regular ``nn.Conv2d(3, 64)`` -> BatchNorm2d -> BinaryHardtanh
    * stage 1: 3 x (BinarizeConv2d 64->64, BatchNorm2d, BinaryHardtanh),
      then a 2x2 max-pool
    * stage 2: 2 x (same triple), then a 2x2 max-pool
    * stage 3: 2 x (same triple), then a 2x2 max-pool
    * 1 x (same triple)
    * head: BinarizeConv2d 64->10 -> BatchNorm2d(10) -> AvgPool2d(8)

    ``self.softmax`` applies ``LogSoftmax`` over dim 1, i.e. it yields
    log-probabilities despite its name.
    """
    super(SingleBNN, self).__init__()

    def _triple(conv_cls, in_channels):
        # conv -> BatchNorm -> binary activation, always producing 64 channels.
        return [
            conv_cls(in_channels, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            BinaryHardtanh(),
        ]

    # NOTE(review): the three 2x2 max-pools shrink H/W by 8, so the final
    # AvgPool2d(kernel_size=8) collapses to 1x1 only for 64x64 inputs; a
    # 32x32 input would leave a 4x4 map, smaller than the pool kernel.
    # Confirm the intended input resolution.
    stack = _triple(nn.Conv2d, 3)
    for depth in (3, 2, 2):
        for _ in range(depth):
            stack += _triple(BinarizeConv2d, 64)
        stack.append(nn.MaxPool2d(kernel_size=2, stride=2))
    stack += _triple(BinarizeConv2d, 64)
    stack += [
        BinarizeConv2d(64, 10, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(10),
        nn.AvgPool2d(kernel_size=8),
    ]

    self.convolutions = nn.Sequential(*stack)
    self.softmax = nn.Sequential(nn.LogSoftmax(dim=1))
Ejemplo n.º 2
0
    def __init__(self, *argv, **argn):
        """Build a multi-scale conv layer from ``nn.Conv2d``-style arguments.

        All positional and keyword arguments are forwarded to the underlying
        conv constructor (``nn.Conv2d``, or ``BinarizeConv2d`` when
        ``binarized=True``) after the wrapper-specific keywords below have been
        removed.

        Keyword Args:
            factor: initial value of ``self.factor`` (default 1).
            max_scales: length of ``self.alpha`` and of the per-scale module
                lists (default 4).
            usf: stored as ``self.usf`` (default False).
            bnorm: if True, build ``max_scales`` ``BatchNorm2d`` modules in
                ``self.bn`` (default True). If False, ``self.bn`` is not set.
            lf: if True, ``self.factor`` is a learnable ``nn.Parameter``
                initialised to ``factor``; otherwise a plain value (default
                False).
            binarized: if True, use ``BinarizeConv2d`` instead of ``nn.Conv2d``
                (default False).
            last: if False, a ``BinaryHardtanh`` is attached as
                ``self.hardtanh`` (default False).

        ``out_channels`` may be given positionally (second argument) or as a
        keyword; it sizes the ``BatchNorm2d`` layers.
        """
        super(sConv2d, self).__init__()
        # Strip wrapper-only options so the remaining kwargs can be forwarded
        # verbatim to the conv constructor. ``argn`` is a fresh dict built by
        # ``**``, so popping from it never mutates the caller's mapping.
        self.init_factor = argn.pop('factor', 1)
        self.max_scales = argn.pop('max_scales', 4)
        self.usf = argn.pop('usf', False)
        self.bnorm = argn.pop('bnorm', True)
        self.lf = argn.pop('lf', False)
        self.binarized = argn.pop('binarized', False)
        self.last = argn.pop('last', False)

        # One weight per scale, initialised to 1.
        self.alpha = nn.Parameter(torch.ones(self.max_scales))
        if self.lf:
            self.factor = nn.Parameter(torch.ones(1) * self.init_factor)
        else:
            self.factor = self.init_factor
        self.interp = F.interpolate

        if self.binarized:
            # NOTE(review): only ONE binarized conv is built, while bn/relu get
            # max_scales entries and the float path builds max_scales convs.
            # Presumably the binarized conv is shared across scales; confirm
            # against forward().
            self.conv = nn.ModuleList([BinarizeConv2d(*argv, **argn, bias=True)])
        else:
            self.conv = nn.ModuleList(
                [nn.Conv2d(*argv, **argn, bias=True) for _ in range(self.max_scales)])

        # Resolved after the conv is built so a completely missing out_channels
        # is still reported by the conv constructor itself. The previous
        # argv[1]-only lookup raised IndexError when out_channels was passed
        # as a keyword, even though the conv accepted it.
        out_channels = argv[1] if len(argv) > 1 else argn['out_channels']

        if self.bnorm:
            self.bn = nn.ModuleList(
                [nn.BatchNorm2d(out_channels) for _ in range(self.max_scales)])
        self.relu = nn.ModuleList([nn.ReLU() for _ in range(self.max_scales)])
        self.lastbn = nn.BatchNorm2d(out_channels)
        if not self.last:
            self.hardtanh = BinaryHardtanh()