def forward(self, x):
    """Encode, decode with additive skip connections, and classify.

    The input is first padded via ``pad_image_tensor(x, 32)``. The padding is
    removed from the final output with ``unpad_image_tensor`` using the
    parameters returned by the pad call.

    Args:
        x: input image tensor (presumably N, C, H, W — TODO confirm).

    Returns:
        The output of ``self.finalconv3`` with the padding stripped.
    """
    padded, pad = pad_image_tensor(x, 32)

    # Stem: conv -> bn -> relu -> maxpool, applied strictly in this order.
    out = padded
    for stem_layer in (self.firstconv, self.firstbn, self.firstrelu, self.firstmaxpool):
        out = stem_layer(out)

    # Encoder stages; every stage output is kept for the decoder skips.
    enc1 = self.encoder1(out)
    enc2 = self.encoder2(enc1)
    enc3 = self.encoder3(enc2)
    enc4 = self.encoder4(enc3)

    # Decoder: each stage's output is summed with the matching encoder stage.
    dec = self.decoder4(enc4) + enc3
    dec = self.decoder3(dec) + enc2
    dec = self.decoder2(dec) + enc1
    dec = self.decoder1(dec)
    dec = self.finaldropout(dec)

    # Classification head: deconv/conv layers interleaved with activations.
    head = dec
    for head_layer in (
        self.finaldeconv1,
        self.finalrelu1,
        self.finalconv2,
        self.finalrelu2,
        self.finalconv3,
    ):
        head = head_layer(head)

    return unpad_image_tensor(head, pad)
def test_pad_unpad_nonsymmetric(shape, padding):
    """Padding followed by unpadding must restore the original tensor exactly.

    Args:
        shape: shape passed to ``torch.randn``. It must be at least 4-D because
            sizes 2 and 3 are inspected below.
        padding: per-axis pad multiples, indexed as ``padding[0]`` for dim 2
            and ``padding[1]`` for dim 3.
    """
    x = torch.randn(shape)
    x_padded, pad_params = pad_image_tensor(x, pad_size=padding)

    # The padded spatial dims must be divisible by their requested multiples.
    assert x_padded.size(2) % padding[0] == 0
    assert x_padded.size(3) % padding[1] == 0

    y = unpad_image_tensor(x_padded, pad_params)

    # Compare shapes explicitly. An elementwise ``x == y`` broadcasts, so a
    # wrongly shaped ``y`` could still pass. For example, a size-0 dim against
    # a size-1 dim gives an empty result, for which ``.all()`` is vacuously True.
    assert y.shape == x.shape
    assert torch.equal(x, y)
def forward(self, x):
    """Produce per-pixel logits from an FPN-style encoder/decoder.

    The input is padded with ``pad_image_tensor(x, 32)``. The logits are
    bilinearly resized (``align_corners=True``) to the spatial size of the
    *padded* input. The padding is then removed with ``unpad_image_tensor``.

    Args:
        x: input image tensor (presumably N, C, H, W — TODO confirm).

    Returns:
        Logits from ``self.logits``, resized and with the padding stripped.
    """
    padded, pad = pad_image_tensor(x, 32)

    encoded = self.encoder(padded)
    decoded = self.decoder(encoded)

    fused = self.fpn_fuse(decoded)
    fused = self.dropout(fused)
    fused = self.final_decoder(fused)

    raw_logits = self.logits(fused)

    # Resize to the padded resolution so the unpad step crops consistently.
    target_size = (padded.size(2), padded.size(3))
    resized = F.interpolate(raw_logits, size=target_size, mode='bilinear', align_corners=True)

    return unpad_image_tensor(resized, pad)
def forward(self, x):
    """Predict a segmentation mask plus the decoder's auxiliary output.

    The input is padded with ``pad_image_tensor(x, 32)`` before encoding.
    When ``self.full_size_mask`` is set, the mask is bilinearly resized
    (``align_corners=False``) to the padded input's spatial size. The mask is
    then passed through ``unpad_image_tensor``.

    Args:
        x: input image tensor (presumably N, C, H, W — TODO confirm).

    Returns:
        dict mapping ``OUTPUT_MASK_KEY`` to the unpadded mask and
        ``OUTPUT_MASK_32_KEY`` to the decoder's second output, unchanged.
    """
    padded, pad = pad_image_tensor(x, 32)
    features = self.encoder(padded)

    # The decoder returns the mask and a second tensor ("dsv"; presumably a
    # deep-supervision map — TODO confirm).
    mask, dsv = self.decoder(features)

    if self.full_size_mask:
        mask = F.interpolate(mask, size=padded.size()[2:], mode="bilinear", align_corners=False)

    # NOTE(review): unpadding also runs when full_size_mask is False, on a mask
    # whose resolution may differ from the padded input's. Confirm that
    # unpad_image_tensor's crop is valid in that case. ``dsv`` is returned
    # without unpadding.
    return {
        OUTPUT_MASK_KEY: unpad_image_tensor(mask, pad),
        OUTPUT_MASK_32_KEY: dsv,
    }