def __init__(self, in_planes, out_planes, stride, droprate_init=0.0, weight_decay=0.,
             lamba=0.01, local_rep=False, temperature=2. / 3.):
    super(BasicBlock, self).__init__()
    self.bn1 = nn.BatchNorm2d(in_planes)
    # L0-regularized 3x3 conv; the weight decay is rescaled by 1 / (1 - 0.3),
    # where the hardcoded 0.3 matches the default droprate_init.
    self.conv1 = L0Conv2d(in_planes, out_planes, kernel_size=3, stride=1, padding=1,
                          bias=False, droprate_init=droprate_init,
                          weight_decay=weight_decay / (1 - 0.3),
                          local_rep=local_rep, lamba=lamba, temperature=temperature)
    self.bn2 = nn.BatchNorm2d(out_planes)
    # Plain (MAP) 3x3 conv; this one carries the block's stride.
    self.conv2 = MAPConv2d(out_planes, out_planes, kernel_size=3, stride=stride,
                           padding=1, bias=False, weight_decay=weight_decay)
    self.equalInOut = (in_planes == out_planes)
    # 1x1 projection shortcut, only needed when the channel count changes.
    self.convShortcut = (MAPConv2d(in_planes, out_planes, kernel_size=1, stride=stride,
                                   padding=0, bias=False, weight_decay=weight_decay)
                         if not self.equalInOut else None)
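# Illustrative sketch only, not from this section: how a pre-activation block
# like the one above is typically applied (BN -> ReLU -> conv1 -> BN -> ReLU ->
# conv2, plus the projection shortcut when widths differ). The actual
# BasicBlock.forward is not shown here, so the body below is an assumption.
import torch
import torch.nn.functional as F

def forward(self, x):
    out = F.relu(self.bn1(x))
    # Identity shortcut when widths match; otherwise project the pre-activation.
    shortcut = x if self.equalInOut else self.convShortcut(out)
    out = self.conv2(F.relu(self.bn2(self.conv1(out))))
    return torch.add(shortcut, out)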
def __init__(self, depth, num_classes, widen_factor=1, droprate_init=0.3, N=50000,
             beta_ema=0.99, weight_decay=5e-4, local_rep=False, lamba=0.01,
             temperature=2. / 3.):
    super(L0WideResNet, self).__init__()
    nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
    assert (depth - 4) % 6 == 0
    self.n = (depth - 4) // 6  # blocks per group
    self.N = N                 # training-set size; scales the weight decay below
    self.beta_ema = beta_ema
    block = BasicBlock
    self.weight_decay = N * weight_decay
    self.lamba = lamba
    # 1st conv before any network block
    self.conv1 = MAPConv2d(3, nChannels[0], kernel_size=3, stride=1, padding=1,
                           bias=False, weight_decay=self.weight_decay)
    # 1st block
    self.block1 = NetworkBlock(self.n, nChannels[0], nChannels[1], block, 1,
                               droprate_init, self.weight_decay, self.lamba,
                               local_rep=local_rep, temperature=temperature)
    # 2nd block (stride 2 halves the spatial resolution)
    self.block2 = NetworkBlock(self.n, nChannels[1], nChannels[2], block, 2,
                               droprate_init, self.weight_decay, self.lamba,
                               local_rep=local_rep, temperature=temperature)
    # 3rd block
    self.block3 = NetworkBlock(self.n, nChannels[2], nChannels[3], block, 2,
                               droprate_init, self.weight_decay, self.lamba,
                               local_rep=local_rep, temperature=temperature)
    # bn, relu and classifier
    self.bn = nn.BatchNorm2d(nChannels[3])
    self.fcout = MAPDense(nChannels[3], num_classes, weight_decay=self.weight_decay)

    # Collect the decayed / sparsifiable layers and the BatchNorm parameters.
    self.layers, self.bn_params = [], []
    for m in self.modules():
        if isinstance(m, (MAPDense, MAPConv2d, L0Conv2d)):
            self.layers.append(m)
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()
            self.bn_params += [m.weight, m.bias]

    if beta_ema > 0.:
        # Keep an exponential moving average of all parameters for evaluation.
        print('Using temporal averaging with beta: {}'.format(beta_ema))
        self.avg_param = deepcopy([p.data for p in self.parameters()])
        if torch.cuda.is_available():
            self.avg_param = [a.cuda() for a in self.avg_param]
        self.steps_ema = 0.

    print('Using weight decay: {}'.format(self.weight_decay))
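# Illustrative usage sketch, not from this section: building a WRN-28-10 model
# for CIFAR-10. depth=28 satisfies the (depth - 4) % 6 == 0 assertion (n = 4
# blocks per group) and N=50000 is the CIFAR training-set size that scales the
# weight decay above; the remaining values are the constructor defaults.
model = L0WideResNet(depth=28, num_classes=10, widen_factor=10,
                     droprate_init=0.3, N=50000, beta_ema=0.99,
                     weight_decay=5e-4, lamba=0.01)

# A minimal sketch of the temporal-averaging update that the avg_param /
# steps_ema buffers above support; the actual update method is not shown in
# this section, so this body is an assumption.
def update_ema(self):
    self.steps_ema += 1
    for p, avg_p in zip(self.parameters(), self.avg_param):
        avg_p.mul_(self.beta_ema).add_((1 - self.beta_ema) * p.data)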
def __init__(self, depth, num_classes, widen_factor=1, droprate_init=0.3, N=50000,
             beta_ema=0.99, weight_decay=5e-4, local_rep=False, lamba=0.01,
             temperature=2. / 3., dropout=0.5, dropout_botk=0.5,
             dropout_type="weight"):
    super(TDWideResNet, self).__init__()
    nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
    assert (depth - 4) % 6 == 0
    self.n = (depth - 4) // 6
    self.N = N
    # beta_ema is stored, but unlike L0WideResNet no temporal-averaging
    # buffers are created in this constructor.
    self.beta_ema = beta_ema
    block = TDBasicBlock
    # Note: the weight decay is hardcoded here, so the weight_decay argument
    # (and the N * weight_decay scaling used by L0WideResNet) is ignored.
    self.weight_decay = 0.001
    self.lamba = lamba
    # 1st conv before any network block
    self.conv1 = MAPConv2d(3, nChannels[0], kernel_size=3, stride=1, padding=1,
                           bias=False, weight_decay=self.weight_decay)
    # 1st block
    self.block1 = TDNetworkBlock(self.n, nChannels[0], nChannels[1], block, 1,
                                 droprate_init, self.weight_decay, self.lamba,
                                 local_rep=local_rep, temperature=temperature,
                                 dropout=dropout, dropout_botk=dropout_botk,
                                 dropout_type=dropout_type)
    # 2nd block
    self.block2 = TDNetworkBlock(self.n, nChannels[1], nChannels[2], block, 2,
                                 droprate_init, self.weight_decay, self.lamba,
                                 local_rep=local_rep, temperature=temperature,
                                 dropout=dropout, dropout_botk=dropout_botk,
                                 dropout_type=dropout_type)
    # 3rd block
    self.block3 = TDNetworkBlock(self.n, nChannels[2], nChannels[3], block, 2,
                                 droprate_init, self.weight_decay, self.lamba,
                                 local_rep=local_rep, temperature=temperature,
                                 dropout=dropout, dropout_botk=dropout_botk,
                                 dropout_type=dropout_type)
    # bn, relu and classifier
    self.bn = nn.BatchNorm2d(nChannels[3])
    self.fcout = MAPDense(nChannels[3], num_classes, weight_decay=self.weight_decay)

    # Collect the decayed / sparsifiable layers and the BatchNorm parameters.
    self.layers, self.bn_params = [], []
    for m in self.modules():
        if isinstance(m, (MAPDense, MAPConv2d, L0Conv2d)):
            self.layers.append(m)
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()
            self.bn_params += [m.weight, m.bias]

    print('Using weight decay: {}'.format(self.weight_decay))
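# Illustrative usage sketch, not from this section: the targeted-dropout
# variant takes the same architecture arguments plus the dropout controls.
# The exact semantics of dropout_botk live in TDBasicBlock / TDNetworkBlock
# (not shown); reading it as the bottom-k fraction of weights targeted for
# dropout is an assumption.
model = TDWideResNet(depth=28, num_classes=10, widen_factor=10,
                     dropout=0.5,           # drop probability for targeted weights (assumed)
                     dropout_botk=0.5,      # fraction of lowest-magnitude weights targeted (assumed)
                     dropout_type="weight") # per-weight rather than per-unit dropout (assumed)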