import torch.nn as nn

# ConvBuilder (imported from the repo's builder module) is the layer factory:
# BNReLUConv2d bundles BatchNorm, ReLU and a convolution into one module, in the
# pre-activation order used by Wide ResNet.


class WRNCifarBlock(nn.Module):

    def __init__(self, input_channels, block_channels, stride, projection_shortcut,
                 use_dropout, builder: ConvBuilder):
        super(WRNCifarBlock, self).__init__()
        assert len(block_channels) == 2
        if projection_shortcut:
            # 1x1 BN-ReLU-Conv shortcut to match the changed spatial size / width.
            self.proj = builder.BNReLUConv2d(in_channels=input_channels,
                                             out_channels=block_channels[1],
                                             kernel_size=1, stride=stride, padding=0)
        else:
            self.proj = builder.ResIdentity(num_channels=block_channels[1])
        self.conv1 = builder.BNReLUConv2d(in_channels=input_channels,
                                          out_channels=block_channels[0],
                                          kernel_size=3, stride=stride, padding=1)
        # Optional dropout between the two convs, as in the WRN paper.
        if use_dropout:
            self.dropout = builder.Dropout(keep_prob=0.7)
            print('use dropout for WRN')
        else:
            self.dropout = builder.Identity()
        self.conv2 = builder.BNReLUConv2d(in_channels=block_channels[0],
                                          out_channels=block_channels[1],
                                          kernel_size=3, stride=1, padding=1)
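    # A minimal forward() sketch, NOT in the original source: it assumes the usual
    # pre-activation WRN dataflow implied by the modules built in __init__ above
    # (BN-ReLU-conv, optional dropout, BN-ReLU-conv, then the residual add).
    def forward(self, x):
        out = self.conv1(x)        # BN-ReLU + 3x3 conv, possibly strided
        out = self.dropout(out)    # identity when use_dropout is False
        out = self.conv2(out)      # BN-ReLU + 3x3 conv, stride 1
        return out + self.proj(x)  # add the (projected or identity) shortcut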
class BasicBlock(nn.Module):

    expansion = 1  # standard width multiplier for the basic (non-bottleneck) block

    def __init__(self, builder: ConvBuilder, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = builder.Conv2dBNReLU(in_channels=in_planes, out_channels=planes,
                                          kernel_size=3, stride=stride, padding=1)
        self.conv2 = builder.Conv2dBN(in_channels=planes,
                                      out_channels=self.expansion * planes,
                                      kernel_size=3, stride=1, padding=1)
        # Project the shortcut with a 1x1 Conv-BN whenever the spatial size or the
        # channel count changes; otherwise the shortcut is the identity.
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = builder.Conv2dBN(in_channels=in_planes,
                                             out_channels=self.expansion * planes,
                                             kernel_size=1, stride=stride)
        else:
            self.shortcut = builder.ResIdentity(num_channels=in_planes)
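    # A minimal forward() sketch, NOT in the original source: it assumes the usual
    # post-activation ResNet order implied by the modules above (Conv-BN-ReLU,
    # then Conv-BN, residual add, and a final ReLU after the addition).
    def forward(self, x):
        out = self.conv1(x)             # 3x3 Conv-BN-ReLU, possibly strided
        out = self.conv2(out)           # 3x3 Conv-BN, no activation yet
        out = out + self.shortcut(x)    # residual add (1x1 projection if shapes differ)
        return nn.functional.relu(out)  # activate after the add, as in ResNet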