def test_script(self, input_param, input_shape, _):
    """Script a ViT built from ``input_param`` and run the save check on it.

    Args:
        input_param: keyword arguments forwarded to the ``ViT`` constructor.
        input_shape: shape of the random tensor handed to ``test_script_save``.
        _: unused third element of the test case tuple.
    """
    model = ViT(**input_param)
    model.eval()

    # Compilation alone is the check here; the scripted module is discarded.
    with torch.no_grad():
        torch.jit.script(model)

    sample = torch.randn(input_shape)
    test_script_save(model, sample)
def __init__(
    self,
    in_channels: int,
    out_channels: int,
    img_size: Union[Sequence[int], int],
    feature_size: int = 16,
    hidden_size: int = 768,
    mlp_dim: int = 3072,
    num_heads: int = 12,
    pos_embed: str = "conv",
    norm_name: Union[Tuple, str] = "instance",
    conv_block: bool = True,
    res_block: bool = True,
    dropout_rate: float = 0.0,
    spatial_dims: int = 3,
) -> None:
    """
    Build the ViT backbone together with the encoder, decoder and output blocks.

    The transformer depth is fixed to 12 layers and the patch size to 16 along
    every spatial dimension.

    Args:
        in_channels: dimension of input channels.
        out_channels: dimension of output channels.
        img_size: dimension of input image.
        feature_size: dimension of network feature size.
        hidden_size: dimension of hidden layer.
        mlp_dim: dimension of feedforward layer.
        num_heads: number of attention heads.
        pos_embed: position embedding layer type.
        norm_name: feature normalization type and arguments.
        conv_block: bool argument to determine if convolutional block is used.
        res_block: bool argument to determine if residual block is used.
        dropout_rate: fraction of the input units to drop.
        spatial_dims: number of spatial dims.

    Raises:
        ValueError: when ``dropout_rate`` is not within [0, 1].
        ValueError: when ``hidden_size`` is not divisible by ``num_heads``.

    Examples::

        # for single channel input 4-channel output with image size of (96,96,96), feature size of 32 and batch norm
        >>> net = UNETR(in_channels=1, out_channels=4, img_size=(96,96,96), feature_size=32, norm_name='batch')

        # for single channel input 4-channel output with image size of (96,96), feature size of 32 and batch norm
        >>> net = UNETR(in_channels=1, out_channels=4, img_size=96, feature_size=32, norm_name='batch', spatial_dims=2)

        # for 4-channel input 3-channel output with image size of (128,128,128), conv position embedding and instance norm
        >>> net = UNETR(in_channels=4, out_channels=3, img_size=(128,128,128), pos_embed='conv', norm_name='instance')

    """

    super().__init__()

    # Keep the chained comparison: it also rejects NaN, which a split
    # ``< 0 or > 1`` test would let through.
    if not (0 <= dropout_rate <= 1):
        raise ValueError("dropout_rate should be between 0 and 1.")

    if hidden_size % num_heads != 0:
        raise ValueError("hidden_size should be divisible by num_heads.")

    self.num_layers = 12
    img_size = ensure_tuple_rep(img_size, spatial_dims)
    self.patch_size = ensure_tuple_rep(16, spatial_dims)
    # Number of patches along each spatial axis (floor division).
    self.feat_size = tuple(size // patch for size, patch in zip(img_size, self.patch_size))
    self.hidden_size = hidden_size
    self.classification = False

    self.vit = ViT(
        in_channels=in_channels,
        img_size=img_size,
        patch_size=self.patch_size,
        hidden_size=hidden_size,
        mlp_dim=mlp_dim,
        num_layers=self.num_layers,
        num_heads=num_heads,
        pos_embed=pos_embed,
        classification=self.classification,
        dropout_rate=dropout_rate,
        spatial_dims=spatial_dims,
    )

    # Keyword arguments common to every encoder/decoder block.
    shared = dict(spatial_dims=spatial_dims, kernel_size=3, norm_name=norm_name, res_block=res_block)
    # Extras for the blocks that consume transformer hidden states.
    pr_up = dict(shared, in_channels=hidden_size, stride=1, upsample_kernel_size=2, conv_block=conv_block)
    # Extras for the decoder upsampling blocks.
    up = dict(shared, upsample_kernel_size=2)

    # Submodules are created in the same order as before so that parameter
    # registration (and random initialisation order) is unchanged.
    self.encoder1 = UnetrBasicBlock(in_channels=in_channels, out_channels=feature_size, stride=1, **shared)
    self.encoder2 = UnetrPrUpBlock(out_channels=feature_size * 2, num_layer=2, **pr_up)
    self.encoder3 = UnetrPrUpBlock(out_channels=feature_size * 4, num_layer=1, **pr_up)
    self.encoder4 = UnetrPrUpBlock(out_channels=feature_size * 8, num_layer=0, **pr_up)

    self.decoder5 = UnetrUpBlock(in_channels=hidden_size, out_channels=feature_size * 8, **up)
    self.decoder4 = UnetrUpBlock(in_channels=feature_size * 8, out_channels=feature_size * 4, **up)
    self.decoder3 = UnetrUpBlock(in_channels=feature_size * 4, out_channels=feature_size * 2, **up)
    self.decoder2 = UnetrUpBlock(in_channels=feature_size * 2, out_channels=feature_size, **up)

    self.out = UnetOutBlock(spatial_dims=spatial_dims, in_channels=feature_size, out_channels=out_channels)
def test_ill_arg(self):
    """Check that each invalid ``ViT`` configuration below raises ``ValueError``."""
    # Settings shared by every case; each entry below adds what differs.
    common = {"mlp_dim": 3072, "num_layers": 12}
    invalid_configs = (
        # dropout_rate of 5.0
        dict(
            common,
            in_channels=1,
            img_size=(128, 128, 128),
            patch_size=(16, 16, 16),
            hidden_size=128,
            num_heads=12,
            pos_embed="conv",
            classification=False,
            dropout_rate=5.0,
        ),
        # patch size (64) larger than image size (32)
        dict(
            common,
            in_channels=1,
            img_size=(32, 32, 32),
            patch_size=(64, 64, 64),
            hidden_size=512,
            num_heads=8,
            pos_embed="perceptron",
            classification=False,
            dropout_rate=0.3,
        ),
        # hidden_size 512 with 14 heads
        dict(
            common,
            in_channels=1,
            img_size=(96, 96, 96),
            patch_size=(8, 8, 8),
            hidden_size=512,
            num_heads=14,
            pos_embed="conv",
            classification=False,
            dropout_rate=0.3,
        ),
        # image size 97 with patch size 4
        dict(
            common,
            in_channels=1,
            img_size=(97, 97, 97),
            patch_size=(4, 4, 4),
            hidden_size=768,
            num_heads=8,
            pos_embed="perceptron",
            classification=True,
            dropout_rate=0.3,
        ),
        # pos_embed value "perc"
        dict(
            common,
            in_channels=4,
            img_size=(96, 96, 96),
            patch_size=(16, 16, 16),
            hidden_size=768,
            num_heads=12,
            pos_embed="perc",
            classification=False,
            dropout_rate=0.3,
        ),
    )

    for config in invalid_configs:
        with self.assertRaises(ValueError):
            ViT(**config)
def test_shape(self, input_param, input_shape, expected_shape):
    """Run a random input through ``ViT`` and compare the output shape.

    Args:
        input_param: keyword arguments forwarded to the ``ViT`` constructor.
        input_shape: shape of the random input tensor.
        expected_shape: shape the first element of the network output must have.
    """
    model = ViT(**input_param)
    with eval_mode(model):
        sample = torch.randn(input_shape)
        # The network returns a pair; only the first element is checked.
        output, _ = model(sample)
        self.assertEqual(output.shape, expected_shape)