class Segmentator(nn.Module):
    """Segmentation model: frozen ResNet feature extractor + trainable decoder.

    The encoder is tapped at two depths — ``layer1`` for fine spatial detail
    and ``layer4`` for high-level semantics — and a DeepLab-style decoder
    fuses both into per-pixel class scores.
    """

    def __init__(self, num_classes, encoder, img_size=(512, 512), shallow_decoder=False):
        super().__init__()
        self.low_feat = IntermediateLayerGetter(encoder, {"layer1": "layer1"})
        self.encoder = IntermediateLayerGetter(encoder, {"layer4": "out"})
        # Decoder signature: n_classes, encoder_dim, img_size, low_level_dim, rates.
        # 512 / 64 are the layer4 / layer1 channel widths of a resnet18.
        self.decoder = Decoder(num_classes, 512, img_size,
                               low_level_dim=64, rates=[1, 6, 12, 18])
        self.num_classes = num_classes

    def forward(self, x):
        # The feature extractor stays frozen: eval mode plus no_grad is safe
        # because its weights are never updated during training.
        self.low_feat.eval()
        self.encoder.eval()
        with torch.no_grad():
            low_level_feat = self.low_feat(x)['layer1']
            enc_feat = self.encoder(x)['out']
        # Decoder runs outside no_grad so its parameters receive gradients.
        segmentation = self.decoder(enc_feat, low_level_feat)
        if self.num_classes == 1:
            # Single-class (binary) setup: emit probabilities directly.
            segmentation = torch.sigmoid(segmentation)
        return segmentation
def build_model(args):
    """Build a frozen, pretrained ResNet feature extractor.

    Args:
        args: namespace whose ``model`` attribute names a torchvision ResNet
            variant (e.g. ``'resnet18'``).

    Returns:
        An ``IntermediateLayerGetter`` in eval mode on the available device,
        mapping the outputs of ``layer1``..``layer4`` to keys ``0``..``3``.

    Raises:
        ValueError: if ``args.model`` is not a torchvision model, or is not a
            ResNet (feature extraction here only supports ResNets).
    """
    if not hasattr(torchvision.models, args.model):
        raise ValueError('Invalid model "%s"' % args.model)
    if 'resnet' not in args.model:
        raise ValueError('Feature extraction only supports ResNets')
    # FrozenBatchNorm2d keeps BN statistics fixed, matching the frozen-encoder setup.
    cnn = getattr(torchvision.models, args.model)(pretrained=True,
                                                  norm_layer=FrozenBatchNorm2d)
    layers = OrderedDict([
        ('conv1', cnn.conv1),
        ('bn1', cnn.bn1),
        ('relu', cnn.relu),
        ('maxpool', cnn.maxpool),
    ])
    for i in range(1, 5):
        name = 'layer%d' % i
        layers[name] = getattr(cnn, name)
    model = torch.nn.Sequential(layers)
    return_layers = {'layer1': 0, 'layer2': 1, 'layer3': 2, 'layer4': 3}
    model = IntermediateLayerGetter(model, return_layers=return_layers)
    # Device-aware placement instead of an unconditional .cuda(), so the
    # extractor also works on CPU-only machines (matches the convention used
    # elsewhere in this file). The duplicated .cuda()/.eval() pair collapses
    # to a single placement of the final wrapper, which shares the modules.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()
    return model
class AttSegmentator(nn.Module):
    """Class-conditioned segmentation: attention over frozen encoder features.

    A class vector is embedded into a query that re-weights the encoder's
    ``layer4`` features via an attention layer; a DeepLab-style decoder then
    produces a two-class (target vs. background) segmentation.
    """

    def __init__(self, num_classes, encoder, att_type='additive', img_size=(512, 512)):
        super().__init__()
        # Device-aware placement instead of an unconditional .cuda(): the
        # original crashed on CPU-only hosts.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.low_feat = IntermediateLayerGetter(encoder, {"layer1": "layer1"}).to(device)
        self.encoder = IntermediateLayerGetter(encoder, {"layer4": "out"}).to(device)
        # Channel widths for resnet18.
        encoder_dim = 512
        low_level_dim = 64
        self.num_classes = num_classes
        self.class_encoder = nn.Linear(num_classes, 512)
        self.attention_enc = Attention(encoder_dim, att_type)
        self.decoder = Decoder(2, encoder_dim, img_size,
                               low_level_dim=low_level_dim, rates=[1, 6, 12, 18])

    def forward(self, x, v_class, out_att=False):
        # Frozen feature extractor: eval + no_grad (weights never updated).
        self.low_feat.eval()
        self.encoder.eval()
        with torch.no_grad():
            low_level_feat = self.low_feat(x)['layer1']
            enc_feat = self.encoder(x)['out']
        # Class embedding is the attention query.
        query = self.class_encoder(v_class)
        shape = enc_feat.shape  # (B, C, H, W)
        # Flatten spatial dims: (B, C, H, W) -> (B, H*W, C), one feature
        # vector per spatial location, as the attention layer expects.
        enc_feat = enc_feat.permute(0, 2, 3, 1).contiguous().view(
            shape[0], -1, shape[1])
        x_enc, attention = self.attention_enc(enc_feat, query)
        # BUG FIX: the original did `x_enc.view(shape)` directly on the
        # (B, H*W, C) tensor, reinterpreting the buffer as (B, C, H, W) and
        # scrambling channels vs. pixels. Permute back to (B, C, H*W) first.
        x_enc = x_enc.permute(0, 2, 1).contiguous().view(shape)
        segmentation = self.decoder(x_enc, low_level_feat)
        if out_att:
            return segmentation, attention
        return segmentation
class AttSegmentator(nn.Module):
    """Attention-based segmentation network conditioned on a class embedding.

    A frozen ResNet encoder yields low-level (``layer1``) and high-level
    (``layer4``) features. The one-hot class vector is embedded into an
    attention query that re-weights the high-level features per spatial
    location, and a DeepLab-style decoder outputs a two-class map.
    """

    def __init__(self, num_classes, encoder, att_type='additive', img_size=(512, 512)):
        super().__init__()
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        # Tap the frozen encoder at two depths.
        self.low_feat = IntermediateLayerGetter(encoder, {
            "layer1": "layer1"
        }).to(self.device)
        self.encoder = IntermediateLayerGetter(encoder, {
            "layer4": "out"
        }).to(self.device)
        # Channel widths for resnet18.
        encoder_dim = 512
        low_level_dim = 64
        self.num_classes = num_classes
        self.class_encoder = nn.Linear(num_classes, 512)
        self.attention_enc = Attention(encoder_dim, att_type)
        self.decoder = Decoder(2, encoder_dim, img_size,
                               low_level_dim=low_level_dim,
                               rates=[1, 6, 12, 18])

    def forward(self, x, v_class, out_att=False):
        """Segment ``x`` conditioned on class vector ``v_class``.

        Args:
            x: input image batch.
            v_class: class vector of width ``num_classes`` (the attention query
                is derived from it).
            out_att: when True, also return the attention weights.

        Returns:
            Segmentation logits, or ``(segmentation, attention)`` when
            ``out_att`` is True.
        """
        # Frozen feature extractor: eval mode + no_grad is safe because its
        # weights are never updated.
        self.low_feat.eval()
        self.encoder.eval()
        with torch.no_grad():
            low_level_feat = self.low_feat(x)['layer1']
            # e.g. (B, 512, 16, 16) for a 512x512 input through resnet18.
            enc_feat = self.encoder(x)['out']
        # (B, C, H', W') -> (B, H'*W', C): one feature vector per location,
        # the layout the attention layer expects.
        x_enc = enc_feat.permute(0, 2, 3, 1).contiguous()
        x_enc = x_enc.view(x_enc.shape[0], -1, x_enc.shape[-1])
        # Class embedding acts as the attention query.
        class_vec = self.class_encoder(v_class)
        x_enc, attention = self.attention_enc(x_enc, class_vec)
        # Back to (B, C, H', W') for the decoder.
        x_enc = x_enc.permute(0, 2, 1).contiguous().view(enc_feat.shape)
        segmentation = self.decoder(x_enc, low_level_feat)
        if out_att:
            return segmentation, attention
        return segmentation