def __init__(self):
    """Build a ResNet-style feature-extractor trunk: conv stem followed by
    three BasicBlock stages, with anti-aliased (blur-pool) downsampling.
    """
    # Initialize nn.Module FIRST. The original assigned self.inplanes before
    # super().__init__(), which only works because nn.Module.__setattr__
    # falls through for plain (non-Module/Parameter/Tensor) values.
    super(FeatureExtractor, self).__init__()
    self.inplanes = 64

    # Stem: 3x3 conv (stride 1) -> BN -> ReLU.
    self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    self.bn1 = nn.BatchNorm2d(64, eps=1e-05)
    self.relu = nn.ReLU(inplace=True)

    # Stride-1 max-pool followed by a blur-pool that performs the actual
    # stride-2 reduction (anti-aliased downsampling).
    self.maxpool = nn.Sequential(
        nn.MaxPool2d(kernel_size=2, stride=1),
        Downsample(filt_size=3, stride=2, channels=64),
    )

    self.layer1 = self._make_layer(BasicBlock, 64, 2)
    self.layer2 = self._make_layer(BasicBlock, 128, 2, stride=2)
    self.layer3 = self._make_layer(BasicBlock, 256, 2, stride=2)

    for m in self.modules():
        if isinstance(m, nn.Conv2d):
            # Don't reinitialize blur-pool/depthwise convs; code assumes
            # normal conv layers never have in==out==groups with no bias.
            if (m.in_channels != m.out_channels
                    or m.out_channels != m.groups
                    or m.bias is not None):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            else:
                print('Not initializing')
        elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
def init_mid_model(self):
    """Build the mid-scale energy branch: a conv stem, six conditional
    residual blocks, and a scalar linear energy head."""
    args = self.args
    filter_dim = args.filter_dim
    latent_dim = args.filter_dim
    im_size = args.im_size

    # Keyword arguments shared by every CondResBlock in this branch.
    shared = dict(latent_dim=latent_dim, im_size=im_size, classes=1000)

    self.mid_conv1 = nn.Conv2d(3, filter_dim, kernel_size=3, stride=1, padding=1)

    self.mid_res_1a = CondResBlock(args, filters=filter_dim, downsample=True, rescale=False, **shared)
    self.mid_res_1b = CondResBlock(args, filters=filter_dim, rescale=False, **shared)
    self.mid_res_2a = CondResBlock(args, filters=filter_dim, downsample=True, rescale=False, **shared)
    self.mid_res_2b = CondResBlock(args, filters=filter_dim, rescale=True, **shared)
    self.mid_res_3a = CondResBlock(args, filters=2 * filter_dim, downsample=False, **shared)
    self.mid_res_3b = CondResBlock(args, filters=2 * filter_dim, rescale=True, **shared)

    # Energy head over the pooled 4*filter_dim feature vector.
    self.mid_energy_map = nn.Linear(filter_dim * 4, 1)
    self.avg_pool = Downsample(channels=3)
def _make_layer(self, block, planes, blocks, stride=1):
    """Stack `blocks` residual blocks of type `block`.

    The first block may stride and/or expand channels; it receives a
    projection shortcut (`downsample`) when either changes. Remaining
    blocks run at stride 1 with an identity shortcut. Updates
    self.inplanes to the stage's output width.
    """
    downsample = None
    if stride != 1 or self.inplanes != planes * block.expansion:
        shortcut = []
        # Anti-aliased blur is only needed when spatially striding; the
        # 1x1 conv + BN always project to the new channel count.
        if stride != 1:
            shortcut.append(Downsample(filt_size=3, stride=stride, channels=self.inplanes))
        shortcut.append(conv1x1(self.inplanes, planes * block.expansion, 1))
        shortcut.append(nn.BatchNorm2d(planes * block.expansion))
        downsample = nn.Sequential(*shortcut)

    layers = [block(self.inplanes, planes, stride, downsample)]
    self.inplanes = planes * block.expansion
    # Remaining blocks: stride 1, no projection shortcut (was `for i in
    # range(1, blocks)` with an unused index).
    layers.extend(block(self.inplanes, planes, 1, None) for _ in range(1, blocks))
    return nn.Sequential(*layers)
def __init__(self, args):
    """Assemble the ImageNet energy model; in multiscale mode the mid and
    small sub-models are built alongside the main one."""
    super(ImagenetModel, self).__init__()
    self.args = args
    self.cond = args.cond
    self.act = swish
    self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    # The main branch always exists; extra scales only when requested.
    self.init_main_model()
    if args.multiscale:
        self.init_mid_model()
        self.init_small_model()

    self.relu = nn.ReLU(inplace=True)
    self.downsample = Downsample(channels=3)
    # Learnable per-scale mixing weights (main / mid / small).
    self.heir_weight = nn.Parameter(torch.Tensor([1.0, 1.0, 1.0]))
def __init__(self, args):
    """Assemble the ResNet energy model; in multiscale mode the mid and
    small sub-models are built alongside the main one."""
    super(ResNetModel, self).__init__()
    self.args = args
    self.act = swish
    self.spec_norm = args.spec_norm
    self.norm = args.norm
    self.cond = args.cond

    self.init_main_model()
    if args.multiscale:
        self.init_mid_model()
        self.init_small_model()

    self.relu = nn.ReLU(inplace=True)
    self.downsample = Downsample(channels=3)
def __init__(self, args, downsample=True, rescale=True, filters=64, latent_dim=64,
             im_size=64, classes=512, norm=True, spec_norm=False):
    """Conditional residual block: two (optionally spectral-normalized)
    3x3 convs with class-conditional latent maps and optional
    anti-aliased downsampling.

    Args:
        args: config namespace; only `args.alias` is read here.
        downsample: append a conv + pooling stage after the residual body.
        rescale: when downsampling, double the channel count.
        filters: input/working channel count.
        latent_dim, im_size: stored for use by the block's forward pass.
        classes: input width of the latent->scale/shift linear maps.
        norm: disable bn1/bn2 entirely when False.
        spec_norm: use spectral-normalized plain convs instead of WSConv2d.
    """
    super(CondResBlock, self).__init__()
    self.filters = filters
    self.latent_dim = latent_dim
    self.im_size = im_size
    self.downsample = downsample
    self.args = args

    def _make_norm(channels):
        # Narrow blocks use InstanceNorm; wide ones GroupNorm(32).
        # (GroupNorm's affine defaults to True, matching the original.)
        if channels <= 128:
            return nn.InstanceNorm2d(channels, affine=True)
        return nn.GroupNorm(32, channels)

    def _make_conv(cin, cout):
        # Spectral-normalized plain conv vs. weight-standardized conv.
        if spec_norm:
            return spectral_norm(nn.Conv2d(cin, cout, kernel_size=3, stride=1, padding=1))
        return WSConv2d(cin, cout, kernel_size=3, stride=1, padding=1)

    # Original duplicated the norm/conv construction for bn1/conv1 and
    # bn2/conv2; deduplicated here with identical resulting modules.
    self.bn1 = _make_norm(filters) if norm else None
    self.conv1 = _make_conv(filters, filters)
    self.bn2 = _make_norm(filters) if norm else None
    self.conv2 = _make_conv(filters, filters)

    self.dropout = Dropout(0.2)

    # Class-conditional maps producing per-channel scale/shift (2*filters).
    self.latent_map = nn.Linear(classes, 2 * filters)
    self.latent_map_2 = nn.Linear(classes, 2 * filters)

    self.relu = torch.nn.ReLU(inplace=True)
    self.act = swish

    if downsample:
        # Original repeated this whole branch for rescale/non-rescale;
        # the only difference is the output channel count.
        out_ch = 2 * filters if rescale else filters
        self.conv_downsample = nn.Conv2d(filters, out_ch, kernel_size=3, stride=1, padding=1)
        if args.alias:
            self.avg_pool = Downsample(channels=out_ch)
        else:
            self.avg_pool = nn.AvgPool2d(3, stride=2, padding=1)