Example #1
    def test_consistency(self, input_param, input_shape, _):
        # compare DynUNet outputs between the "instance" and "instance_nvfuser"
        # normalization backends across several norm configs and memory formats
        for eps in [1e-4, 1e-5]:
            for momentum in [0.1, 0.01]:
                for affine in [True, False]:
                    norm_param = {
                        "eps": eps,
                        "momentum": momentum,
                        "affine": affine
                    }
                    input_param["norm_name"] = ("instance", norm_param)
                    input_param_fuser = input_param.copy()
                    input_param_fuser["norm_name"] = ("instance_nvfuser",
                                                      norm_param)
                    for memory_format in [
                            torch.contiguous_format, torch.channels_last_3d
                    ]:
                        net = DynUNet(**input_param).to(
                            "cuda:0", memory_format=memory_format)
                        net_fuser = DynUNet(**input_param_fuser).to(
                            "cuda:0", memory_format=memory_format)
                        net_fuser.load_state_dict(net.state_dict())

                        input_tensor = torch.randn(input_shape).to(
                            "cuda:0", memory_format=memory_format)
                        with eval_mode(net):
                            result = net(input_tensor)
                        with eval_mode(net_fuser):
                            result_fuser = net_fuser(input_tensor)

                        # torch.testing.assert_allclose() is deprecated since 1.12 and will be removed in 1.14
                        if pytorch_after(1, 12):
                            torch.testing.assert_close(result, result_fuser)
                        else:
                            torch.testing.assert_allclose(result, result_fuser)
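
The method above belongs to a parameterized test class, so it is not runnable on its own. The sketch below condenses the same consistency check into a standalone function, assuming a CUDA device and a MONAI build that provides the "instance_nvfuser" normalization backend; the function name, network hyper-parameters, and input shape are illustrative rather than the original test's parameterization.

import torch
from monai.networks import eval_mode
from monai.networks.nets import DynUNet


def compare_instance_norm_backends(input_shape=(1, 1, 64, 64, 64)):
    # illustrative DynUNet configuration (not the original test's parameters)
    common = dict(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
        kernel_size=[3, 3, 3],
        strides=[1, 2, 2],
        upsample_kernel_size=[2, 2],
    )
    norm_param = {"eps": 1e-5, "momentum": 0.1, "affine": True}
    net = DynUNet(norm_name=("instance", norm_param), **common).to("cuda:0")
    net_fuser = DynUNet(norm_name=("instance_nvfuser", norm_param), **common).to("cuda:0")
    # share weights so any output difference comes from the norm implementation
    net_fuser.load_state_dict(net.state_dict())

    x = torch.randn(input_shape, device="cuda:0")
    with eval_mode(net):
        out = net(x)
    with eval_mode(net_fuser):
        out_fuser = net_fuser(x)
    torch.testing.assert_close(out, out_fuser)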

Example #2
import os

import torch
from monai.networks.nets import DynUNet


def get_network(properties, task_id, pretrain_path, checkpoint=None):
    n_class = len(properties["labels"])
    in_channels = len(properties["modality"])
    # get_kernels_strides is defined elsewhere in this script; it derives the
    # per-task convolution kernel sizes and down-sampling strides
    kernels, strides = get_kernels_strides(task_id)

    net = DynUNet(
        spatial_dims=3,
        in_channels=in_channels,
        out_channels=n_class,
        kernel_size=kernels,
        strides=strides,
        upsample_kernel_size=strides[1:],
        norm_name="instance",
        deep_supervision=True,
        deep_supr_num=deep_supr_num[task_id],  # per-task mapping defined elsewhere in this script
    )

    if checkpoint is not None:
        pretrain_path = os.path.join(pretrain_path, checkpoint)
        if os.path.exists(pretrain_path):
            net.load_state_dict(torch.load(pretrain_path))
            print("pretrained checkpoint: {} loaded".format(pretrain_path))
        else:
            print("no pretrained checkpoint")
    return net
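
get_kernels_strides and deep_supr_num are assumed to be defined at module level in the original script (the per-task kernels/strides and the number of deep-supervision heads), so the function above relies on that context. A hypothetical call could look like the following; the property dictionaries, task id, directory, and checkpoint file name are made up for illustration.

# hypothetical inputs, for illustration only
properties = {
    "labels": {"0": "background", "1": "tumor"},
    "modality": {"0": "CT"},
}
net = get_network(
    properties,
    task_id="04",                    # must be a key known to get_kernels_strides/deep_supr_num
    pretrain_path="./models",        # made-up directory
    checkpoint="net_key_metric.pt",  # made-up checkpoint file name
)
net = net.to("cuda:0")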