import torch
import torch.nn as nn
from torchvision.models._utils import IntermediateLayerGetter

# Decoder (and, in the later examples, Attention) are project-specific modules
# assumed to be importable from the surrounding codebase.


class Segmentator(nn.Module):
    def __init__(self,
                 num_classes,
                 encoder,
                 img_size=(512, 512),
                 shallow_decoder=False):
        super().__init__()
        # Tap the backbone at two depths: layer1 for low-level features and
        # layer4 for the deepest encoder features.
        self.low_feat = IntermediateLayerGetter(encoder,
                                                {"layer1": "layer1"})
        self.encoder = IntermediateLayerGetter(encoder,
                                               {"layer4": "out"})

        # Decoder signature: (n_classes, encoder_dim, img_size, low_level_dim, rates);
        # 512 and 64 match the layer4 and layer1 widths of a resnet18 backbone.
        self.decoder = Decoder(num_classes,
                               512,
                               img_size,
                               low_level_dim=64,
                               rates=[1, 6, 12, 18])
        self.num_classes = num_classes

    def forward(self, x):
        # The pretrained encoder stays frozen: eval mode plus no_grad, since
        # its gradients are never updated.
        self.low_feat.eval()
        self.encoder.eval()
        with torch.no_grad():
            low_level_feat = self.low_feat(x)['layer1']
            enc_feat = self.encoder(x)['out']

        segmentation = self.decoder(enc_feat, low_level_feat)

        if self.num_classes == 1:
            # Binary segmentation: return probabilities rather than logits.
            segmentation = torch.sigmoid(segmentation)
        return segmentation
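
A minimal usage sketch (the resnet18 backbone is an assumption, chosen to match the hard-coded 512/64 channel widths; the Decoder is assumed to upsample to img_size):

import torchvision

encoder = torchvision.models.resnet18(pretrained=True)  # assumed backbone
model = Segmentator(num_classes=1, encoder=encoder, img_size=(512, 512))
mask = model(torch.randn(2, 3, 512, 512))  # expected: (2, 1, 512, 512) probabilities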
Example #2
import torch
import torchvision
from collections import OrderedDict
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.ops import FrozenBatchNorm2d


def build_model(args):
    if not hasattr(torchvision.models, args.model):
        raise ValueError('Invalid model "%s"' % args.model)
    if 'resnet' not in args.model:
        raise ValueError('Feature extraction only supports ResNets')
    # Freeze batch-norm statistics so the extractor behaves deterministically.
    cnn = getattr(torchvision.models, args.model)(pretrained=True,
                                                  norm_layer=FrozenBatchNorm2d)
    layers = OrderedDict([
        ('conv1', cnn.conv1),
        ('bn1', cnn.bn1),
        ('relu', cnn.relu),
        ('maxpool', cnn.maxpool),
    ])
    # Append the four residual stages (layer1..layer4) in order.
    for i in range(4):
        name = 'layer%d' % (i + 1)
        layers[name] = getattr(cnn, name)
    model = torch.nn.Sequential(layers)
    # Expose the output of every stage; the dict values become the keys of
    # the returned feature dict.
    return_layers = {'layer1': 0, 'layer2': 1, 'layer3': 2, 'layer4': 3}
    model = IntermediateLayerGetter(model, return_layers=return_layers)

    model.cuda()
    model.eval()
    return model
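
A hedged usage sketch: `args` is stubbed here with a hypothetical Namespace, and a CUDA device is assumed because build_model hard-codes .cuda().

from argparse import Namespace

extractor = build_model(Namespace(model='resnet18'))
feats = extractor(torch.randn(1, 3, 224, 224).cuda())
# feats is an OrderedDict keyed 0..3; for resnet18 the expected shapes are
# (1, 64, 56, 56), (1, 128, 28, 28), (1, 256, 14, 14), (1, 512, 7, 7).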
Example #3
# Same imports as Example #1; Attention is a further project-specific module.
class AttSegmentator(nn.Module):
    def __init__(self,
                 num_classes,
                 encoder,
                 att_type='additive',
                 img_size=(512, 512)):
        super().__init__()
        self.low_feat = IntermediateLayerGetter(encoder, {
            "layer1": "layer1"
        }).cuda()
        self.encoder = IntermediateLayerGetter(encoder, {
            "layer4": "out"
        }).cuda()
        # For resnet18
        encoder_dim = 512
        low_level_dim = 64
        self.num_classes = num_classes

        self.class_encoder = nn.Linear(num_classes, 512)

        self.attention_enc = Attention(encoder_dim, att_type)

        # Two output channels: presumably object vs. background for the
        # queried class.
        self.decoder = Decoder(2,
                               encoder_dim,
                               img_size,
                               low_level_dim=low_level_dim,
                               rates=[1, 6, 12, 18])

    def forward(self, x, v_class, out_att=False):
        # Frozen encoder: eval mode plus no_grad.
        self.low_feat.eval()
        self.encoder.eval()
        with torch.no_grad():
            low_level_feat = self.low_feat(x)['layer1']
            enc_feat = self.encoder(x)['out']

        # Encode the class vector into a 512-d attention query.
        query = self.class_encoder(v_class)
        shape = enc_feat.shape

        # Flatten spatial dims: (B, C, H, W) -> (B, H*W, C) for attention.
        enc_feat = enc_feat.permute(0, 2, 3, 1).contiguous().view(
            shape[0], -1, shape[1])

        x_enc, attention = self.attention_enc(enc_feat, query)
        # Restore the (B, C, H, W) layout expected by the decoder.
        x_enc = x_enc.permute(0, 2, 1).contiguous().view(shape)

        segmentation = self.decoder(x_enc, low_level_feat)

        if out_att:
            return segmentation, attention
        return segmentation
Example #4
# Same imports as Example #1; this variant resolves the device at construction
# time instead of hard-coding .cuda().
class AttSegmentator(nn.Module):
    def __init__(self,
                 num_classes,
                 encoder,
                 att_type='additive',
                 img_size=(512, 512)):
        super().__init__()
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.low_feat = IntermediateLayerGetter(encoder, {
            "layer1": "layer1"
        }).to(self.device)
        self.encoder = IntermediateLayerGetter(encoder, {
            "layer4": "out"
        }).to(self.device)
        # For resnet18
        encoder_dim = 512
        low_level_dim = 64
        self.num_classes = num_classes

        self.class_encoder = nn.Linear(num_classes, 512)

        self.attention_enc = Attention(encoder_dim, att_type)

        self.decoder = Decoder(2,
                               encoder_dim,
                               img_size,
                               low_level_dim=low_level_dim,
                               rates=[1, 6, 12, 18])

    def forward(self, x, v_class, out_att=False):
        # Frozen encoder: eval mode plus no_grad, since its weights are
        # never updated.
        self.low_feat.eval()
        self.encoder.eval()
        with torch.no_grad():
            low_level_feat = self.low_feat(x)['layer1']
            enc_feat = self.encoder(x)['out']  # e.g. (1, 512, 16, 16) for a 512x512 input

        # Flatten spatial dims: (B, C, H, W) -> (B, H*W, C), e.g. (1, 16*16, 512).
        x_enc = enc_feat.permute(0, 2, 3, 1).contiguous()
        x_enc = x_enc.view(x_enc.shape[0], -1, x_enc.shape[-1])

        # Encode the class vector into the attention query, e.g. (1, 512).
        class_vec = self.class_encoder(v_class)

        # x_enc: (B, H*W, C), attention: (B, H*W)
        x_enc, attention = self.attention_enc(x_enc, class_vec)

        # Restore the (B, C, H, W) layout expected by the decoder.
        x_enc = x_enc.permute(0, 2, 1).contiguous().view(enc_feat.shape)

        segmentation = self.decoder(x_enc, low_level_feat)

        if out_att:
            return segmentation, attention
        return segmentation
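
A hedged usage sketch for the Example #4 variant; the resnet18 backbone and the one-hot class query are assumptions, and the whole model is moved to one device since the constructor only moves the two feature extractors:

import torch
import torchvision

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
encoder = torchvision.models.resnet18(pretrained=True)  # assumed backbone
model = AttSegmentator(num_classes=5, encoder=encoder).to(device)

x = torch.randn(1, 3, 512, 512, device=device)
v_class = torch.eye(5, device=device)[2:3]  # one-hot query for class 2, shape (1, 5)
seg, att = model(x, v_class, out_att=True)
# seg: presumably (1, 2, 512, 512) logits; att: (1, 16*16) attention weights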