def test_consistency(self, input_param, input_shape, _):
    """Check DynUNet gives matching outputs with ``instance`` and ``instance_nvfuser`` norms.

    For every combination of normalization ``eps``, ``momentum`` and ``affine``
    settings and of memory format (``torch.contiguous_format`` /
    ``torch.channels_last_3d``), two networks are built on ``cuda:0`` -- one
    using the ``"instance"`` norm and one using the ``"instance_nvfuser"`` norm.
    The second network receives the first one's weights, and the eval-mode
    outputs of both on the same random input are compared.

    Args:
        input_param: keyword arguments for ``DynUNet``. It is not modified;
            per-combination copies are made so shared test-case data stays intact.
        input_shape: shape of the random input tensor.
        _: unused here (presumably the expected output shape from the
            parameterized case -- confirm against the test-case definitions).
    """
    import itertools

    memory_formats = [torch.contiguous_format, torch.channels_last_3d]
    # Same iteration order as the nested loops: eps, momentum, affine, memory format.
    for eps, momentum, affine, memory_format in itertools.product(
        [1e-4, 1e-5], [0.1, 0.01], [True, False], memory_formats
    ):
        norm_param = {"eps": eps, "momentum": momentum, "affine": affine}
        # Build fresh dicts rather than writing into ``input_param``: that dict is
        # usually shared test-case data, and mutating it would leak ``norm_name``
        # into any other test that reuses the same case.
        params = {**input_param, "norm_name": ("instance", norm_param)}
        params_fuser = {**input_param, "norm_name": ("instance_nvfuser", norm_param)}

        net = DynUNet(**params).to("cuda:0", memory_format=memory_format)
        net_fuser = DynUNet(**params_fuser).to("cuda:0", memory_format=memory_format)
        # Identical weights, so any output difference comes from the norm implementation.
        net_fuser.load_state_dict(net.state_dict())

        input_tensor = torch.randn(input_shape).to("cuda:0", memory_format=memory_format)
        with eval_mode(net):
            result = net(input_tensor)
        with eval_mode(net_fuser):
            result_fuser = net_fuser(input_tensor)

        # torch.testing.assert_allclose() is deprecated since 1.12 and will be removed in 1.14
        if pytorch_after(1, 12):
            torch.testing.assert_close(result, result_fuser)
        else:
            torch.testing.assert_allclose(result, result_fuser)
def get_network(properties, task_id, pretrain_path, checkpoint=None):
    """Build a 3D DynUNet for a task and optionally load pretrained weights.

    Args:
        properties: dataset properties. ``len(properties["labels"])`` is used as
            the number of output channels and ``len(properties["modality"])`` as
            the number of input channels.
        task_id: task identifier used to look up kernel sizes / strides via
            ``get_kernels_strides`` and the deep-supervision head count via
            ``deep_supr_num``.
        pretrain_path: directory that holds pretrained checkpoints.
        checkpoint: file name of the checkpoint inside ``pretrain_path``. When
            ``None`` (the default), no weights are loaded.

    Returns:
        The constructed ``DynUNet``, with the checkpoint's state dict loaded
        when ``checkpoint`` names an existing file.
    """
    n_class = len(properties["labels"])
    in_channels = len(properties["modality"])
    kernels, strides = get_kernels_strides(task_id)

    net = DynUNet(
        spatial_dims=3,
        in_channels=in_channels,
        out_channels=n_class,
        kernel_size=kernels,
        strides=strides,
        upsample_kernel_size=strides[1:],
        norm_name="instance",
        deep_supervision=True,
        deep_supr_num=deep_supr_num[task_id],
    )

    if checkpoint is not None:
        checkpoint_path = os.path.join(pretrain_path, checkpoint)
        # isfile rather than exists: a directory (e.g. from an empty checkpoint
        # name) would pass an existence check and then make torch.load fail.
        if os.path.isfile(checkpoint_path):
            # map_location="cpu" lets checkpoints saved during GPU training load on
            # machines without that device; load_state_dict then copies the values
            # into the network's own parameters.
            # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
            net.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
            print(f"pretrained checkpoint: {checkpoint_path} loaded")
        else:
            print("no pretrained checkpoint")
    return net