def _get_components(
    self,
    encoder: ResnetEncoder,
    num_classes: int,
    bridge_params: Dict,
    decoder_params: Dict,
    head_params: Dict,
):
    """Assemble the bridge, decoder and head that sit on top of ``encoder``.

    Each stage is built from the ``out_channels`` / ``out_strides`` of the
    stage before it (encoder -> bridge -> decoder -> head).

    Args:
        encoder: Already-constructed encoder; returned unchanged.
        num_classes: Passed to the head as its ``out_channels``.
        bridge_params: Extra keyword arguments forwarded to ``UnetBridge``.
        decoder_params: Extra keyword arguments forwarded to ``UNetDecoder``.
        head_params: Extra keyword arguments forwarded to ``UnetHead``.

    Returns:
        Tuple ``(encoder, bridge, decoder, head)``.
    """
    # Bridge blocks: EncoderUpsampleBlock configured with pool_first=True.
    bridge_block = partial(EncoderUpsampleBlock, pool_first=True)
    # Decoder blocks: DecoderConcatBlock with aggregate_first=True and a
    # fixed upsample_scale of 2.
    decoder_block = partial(
        DecoderConcatBlock, aggregate_first=True, upsample_scale=2
    )

    enc_channels = encoder.out_channels
    bridge_net = UnetBridge(
        in_channels=enc_channels,
        in_strides=encoder.out_strides,
        # The bridge keeps the width of the deepest encoder stage.
        out_channels=enc_channels[-1],
        block_fn=bridge_block,
        **bridge_params,
    )

    decoder_net = UNetDecoder(
        in_channels=bridge_net.out_channels,
        in_strides=bridge_net.out_strides,
        block_fn=decoder_block,
        **decoder_params,
    )

    # Number of head upsampling blocks = log2 of the decoder's last stride.
    # int() truncates, so this presumes a power-of-two stride (TODO confirm).
    upsample_steps = int(np.log2(decoder_net.out_strides[-1]))
    head_net = UnetHead(
        in_channels=decoder_net.out_channels,
        in_strides=decoder_net.out_strides,
        out_channels=num_classes,
        num_upsample_blocks=upsample_steps,
        **head_params,
    )

    return (encoder, bridge_net, decoder_net, head_net)
def _get_components(
    self,
    encoder: UnetEncoder,
    num_classes: int,
    bridge_params: Dict,
    decoder_params: Dict,
    head_params: Dict,
):
    """Assemble the bridge, decoder and head that sit on top of ``encoder``.

    Each stage is built from the ``out_channels`` / ``out_strides`` of the
    stage before it (encoder -> bridge -> decoder -> head).

    Args:
        encoder: Already-constructed encoder; returned unchanged.
        num_classes: Passed to the head as its ``out_channels``.
        bridge_params: Extra keyword arguments forwarded to ``UnetBridge``.
        decoder_params: Extra keyword arguments forwarded to ``UNetDecoder``.
        head_params: Extra keyword arguments forwarded to ``UnetHead``.

    Returns:
        Tuple ``(encoder, bridge, decoder, head)``.
    """
    enc_channels = encoder.out_channels
    bridge_net = UnetBridge(
        in_channels=enc_channels,
        in_strides=encoder.out_strides,
        # The bridge doubles the width of the deepest encoder stage.
        out_channels=enc_channels[-1] * 2,
        block_fn=EncoderDownsampleBlock,
        **bridge_params,
    )

    decoder_net = UNetDecoder(
        in_channels=bridge_net.out_channels,
        in_strides=bridge_net.out_strides,
        block_fn=DecoderConcatBlock,
        **decoder_params,
    )

    # Number of head upsampling blocks = log2 of the decoder's last stride.
    # int() truncates, so this presumes a power-of-two stride (TODO confirm).
    upsample_steps = int(np.log2(decoder_net.out_strides[-1]))
    head_net = UnetHead(
        in_channels=decoder_net.out_channels,
        in_strides=decoder_net.out_strides,
        out_channels=num_classes,
        num_upsample_blocks=upsample_steps,
        **head_params,
    )

    return (encoder, bridge_net, decoder_net, head_net)
def _get_components(
    self,
    encoder: ResnetEncoder,
    num_classes: int,
    bridge_params: Dict,
    decoder_params: Dict,
    head_params: Dict,
):
    """Assemble a bridge-less PSP decoder and head on top of ``encoder``.

    Args:
        encoder: Already-constructed encoder; returned unchanged.
        num_classes: Passed to the head as its ``out_channels``.
        bridge_params: Accepted for interface parity with sibling builders
            but not used here, since no bridge is created.
        decoder_params: Extra keyword arguments forwarded to ``PSPDecoder``.
        head_params: Extra keyword arguments forwarded to ``UnetHead``.

    Returns:
        Tuple ``(encoder, None, decoder, head)``; the bridge slot is ``None``.
    """
    decoder_net = PSPDecoder(
        in_channels=encoder.out_channels,
        in_strides=encoder.out_strides,
        **decoder_params,
    )

    # The head upsamples by the decoder's final stride in a single step,
    # using bilinear interpolation with align_corners=True.
    final_stride = decoder_net.out_strides[-1]
    head_net = UnetHead(
        in_channels=decoder_net.out_channels,
        in_strides=decoder_net.out_strides,
        out_channels=num_classes,
        upsample_scale=final_stride,
        interpolation_mode="bilinear",
        align_corners=True,
        **head_params,
    )

    return (encoder, None, decoder_net, head_net)
def _get_components(
    self,
    encoder: ResnetEncoder,
    num_classes: int,
    bridge_params: Dict,
    decoder_params: Dict,
    head_params: Dict,
):
    """Assemble a bridge-less summing decoder and head on top of ``encoder``.

    Args:
        encoder: Already-constructed encoder; returned unchanged.
        num_classes: Passed to the head as its ``out_channels``.
        bridge_params: Accepted for interface parity with sibling builders
            but not used here, since no bridge is created.
        decoder_params: Extra keyword arguments forwarded to ``UNetDecoder``.
        head_params: Extra keyword arguments forwarded to ``UnetHead``.

    Returns:
        Tuple ``(encoder, None, decoder, head)``; the bridge slot is ``None``.
    """
    # Decoder blocks: DecoderSumBlock with aggregate_first=False and
    # upsample_scale=None.
    decoder_block = partial(
        DecoderSumBlock, aggregate_first=False, upsample_scale=None
    )
    decoder_net = UNetDecoder(
        in_channels=encoder.out_channels,
        in_strides=encoder.out_strides,
        block_fn=decoder_block,
        **decoder_params,
    )

    # Number of head upsampling blocks = log2 of the decoder's last stride.
    # int() truncates, so this presumes a power-of-two stride (TODO confirm).
    upsample_steps = int(np.log2(decoder_net.out_strides[-1]))
    head_net = UnetHead(
        in_channels=decoder_net.out_channels,
        in_strides=decoder_net.out_strides,
        out_channels=num_classes,
        num_upsample_blocks=upsample_steps,
        **head_params,
    )

    return (encoder, None, decoder_net, head_net)