예제 #1
0
def predict():
    """Segment the 16-slice LUNA16 test cube ``3_98`` and save preview mosaics.

    Reads grayscale slices ``0.bmp`` .. ``15.bmp`` from the ``Image/3_98`` and
    ``Mask/3_98`` folders under ``ROOT_DIR``, runs the trained V-Net on the
    (16, 96, 96) image stack and writes three 4x4 mosaics with
    ``save_images``: ``test_src.bmp`` and ``test_mask.bmp`` (inputs scaled by
    1/255) and ``test_predict.bmp`` (the prediction as returned).

    :raises FileNotFoundError: if a slice cannot be read by OpenCV.
    """
    # Join file names onto Path objects.  str(Path(".../3_98/")) drops the
    # trailing separator, so "+ str(z)" would produce ".../3_980.bmp".
    src_dir = Path(ROOT_DIR + "/data/LIDC/LUNA16/segmentation/Image/3_98")
    mask_dir = Path(ROOT_DIR + "/data/LIDC/LUNA16/segmentation/Mask/3_98")

    def read_slice(path):
        # cv2.imread returns None (it does not raise) for missing/unreadable files.
        image = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
        if image is None:
            raise FileNotFoundError("cannot read image: {}".format(path))
        return image

    imges = []
    masks = []
    for z in range(16):
        imges.append(read_slice(src_dir / "{}.bmp".format(z)))
        masks.append(read_slice(mask_dir / "{}.bmp".format(z)))

    test_imges = np.reshape(np.array(imges), (16, 96, 96))
    test_masks = np.reshape(np.array(masks), (16, 96, 96))
    Vnet3d = Vnet3dModule(
        96,
        96,
        16,
        channels=1,
        costname=("dice coefficient", ),
        inference=True,
        # Plain string path, matching the other callers of Vnet3dModule.
        model_path=str(
            Path(ROOT_DIR +
                 "/model/trained/segmeation/model/Vnet3d.pd-50000")))
    prediction = Vnet3d.prediction(test_imges)
    test_images = np.multiply(test_imges, 1.0 / 255.0)
    test_masks = np.multiply(test_masks, 1.0 / 255.0)
    save_images(test_images, [4, 4], "test_src.bmp")
    save_images(test_masks, [4, 4], "test_mask.bmp")
    save_images(prediction, [4, 4], "test_predict.bmp")
예제 #2
0
def train():
    """Train the 3D V-Net on the LUNA16 segmentation samples listed in CSV files.

    ``Segmentation3dImage.csv`` and ``Segmentation3dMask.csv`` under
    ``ROOT_DIR/dataprocess/data`` are loaded as arrays (one row per sample).
    Both are shuffled with one shared permutation so every image keeps its
    mask, then they are passed to ``Vnet3dModule.train`` with the log
    directory ``ROOT_DIR/model/trained/segmeation/``.

    :raises ValueError: if the image and mask CSVs have different row counts.
    """
    import os  # local import keeps this block self-contained

    # Read data set (train data from CSV files)
    csvmaskdata = pd.read_csv(
        Path(ROOT_DIR + '/dataprocess/data/Segmentation3dMask.csv'))
    csvimagedata = pd.read_csv(
        Path(ROOT_DIR + '/dataprocess/data/Segmentation3dImage.csv'))
    maskdata = csvmaskdata.iloc[:, :].values
    imagedata = csvimagedata.iloc[:, :].values
    if len(imagedata) != len(maskdata):
        raise ValueError(
            "image CSV has {} rows but mask CSV has {}".format(
                len(imagedata), len(maskdata)))

    # Shuffle imagedata and maskdata together with a single permutation.
    perm = np.arange(len(imagedata))
    np.random.shuffle(perm)
    imagedata = imagedata[perm]
    maskdata = maskdata[perm]

    # str(Path(...)) removes the trailing separator written in the literal.
    # Re-append it because the directory is handed on as a path prefix
    # (presumably concatenated with sub-paths inside Vnet3dModule.train --
    # TODO confirm).
    logs_dir = str(Path(ROOT_DIR + "/model/trained/segmeation")) + os.sep

    Vnet3d = Vnet3dModule(96,
                          96,
                          16,
                          channels=1,
                          costname=("dice coefficient", ))
    Vnet3d.train(imagedata, maskdata, "Vnet3d.pd", logs_dir, 0.001, 0.5, 10,
                 6)
def predict():
    """Evaluate a trained 128x128x32 V-Net on the cases listed in the test CSV.

    Column 0 of ``test.csv`` (under ``dataprocess/data``) holds image ``.npy``
    paths and column 1 the matching mask ``.npy`` paths.  Each image is
    segmented and its Dice score against the mask is printed.  The mean Dice
    is printed at the end (``nan`` when the CSV has no rows).
    """
    # Read data set (test data from CSV file)
    csvdata = pd.read_csv('dataprocess\\data/test.csv')
    maskdata = csvdata.iloc[:, 1].values
    imagedata = csvdata.iloc[:, 0].values

    dice_values = []
    Vnet3d = Vnet3dModule(128,
                          128,
                          32,
                          channels=1,
                          costname=("dice coefficient", ),
                          inference=True,
                          # Raw string: same path without invalid "\s"/"\m"/"\V" escapes.
                          model_path=r"log\segmeation\model\Vnet3d.pd-20000")
    for index, (image_file, mask_file) in enumerate(zip(imagedata, maskdata)):
        image_gt = np.load(image_file)
        mask_pd = Vnet3d.prediction(image_gt)
        mask_gt = np.load(mask_file)
        dice_value = calcu_dice(mask_pd, mask_gt)
        print("index,dice:", (index, dice_value))
        dice_values.append(dice_value)
    # Avoid ZeroDivisionError when the CSV is empty.
    average = (sum(dice_values) / len(dice_values)
               if dice_values else float("nan"))
    print("average dice:", average)
예제 #4
0
def train():
    """Train a 128x128x64 single-channel V-Net on the liver CSV data.

    ``Image.csv`` and ``MaskLiver.csv`` are loaded as row-aligned arrays and
    shuffled with one shared permutation, so image/mask pairs stay together.
    They are then passed to ``Vnet3dModule.train`` with the log directory
    ``log/diceVnet3d/``.
    """
    mask_table = pd.read_csv('MaskLiver.csv')
    image_table = pd.read_csv('Image.csv')

    # np.random.permutation(n) == arange(n) shuffled in place; one order is
    # applied to both arrays.
    order = np.random.permutation(len(image_table))
    shuffled_images = image_table.values[order]
    shuffled_masks = mask_table.values[order]

    # Earlier settings noted inline by the author: 256 --> 128, 16 --> 1.
    net = Vnet3dModule(128, 128, 64, channels=1,
                       costname=("dice coefficient",))
    # NOTE(review): the author reports this log path creates oddly named
    # directories and suspects the slashes; confirm against
    # Vnet3dModule.train before changing the string.
    net.train(shuffled_images, shuffled_masks, "Vnet3d.pd", "log/diceVnet3d/",
              0.001, 0.7, 10, 1)
예제 #5
0
def train():
    """Train a 128x128x32 single-channel V-Net from a two-column CSV.

    Column 0 of ``train.csv`` lists the image samples and column 1 the
    matching masks.  Rows are shuffled with one permutation shared by both
    columns before training; logs go under ``log\\segmeation\\``.
    """
    table = pd.read_csv('dataprocess\\data/train.csv')
    masks_col = table.iloc[:, 1].values
    images_col = table.iloc[:, 0].values

    # Same permutation for both columns keeps each image with its mask.
    shuffle_idx = np.random.permutation(len(images_col))
    images_col, masks_col = images_col[shuffle_idx], masks_col[shuffle_idx]

    net = Vnet3dModule(128, 128, 32, channels=1,
                       costname=("dice coefficient",))
    net.train(images_col, masks_col, "Vnet3d.pd", "log\\segmeation\\",
              0.001, 0.5, 10, 3)
예제 #6
0
def inference():
    """Run the 512x512 kidney V-Net on KiTS19 cases 200-209 and report Dice.

    For each case, the slices ``0.bmp`` .. ``N-1.bmp`` are read from
    ``<test_path>Image/<case>`` and ``<test_path>Mask/<case>``, where N is the
    number of files in the image folder.  The volume is segmented in chunks
    of ``depth_z`` slices.  If N is not a multiple of ``depth_z``, the last
    ``depth_z`` slices are predicted once more to cover the remainder.  Small
    connected components are then removed and the Dice score against the
    mask is printed.  Each predicted slice is written next to the masks as
    ``<i>predict.bmp``.  The mean Dice over all cases is printed at the end.
    """
    depth_z = 48
    Vnet3d = Vnet3dModule(512, 512, depth_z, channels=1, costname=("dice coefficient",), inference=True,
                          model_path=r"log\segmeation\model\Vnet3d.pd-20000")
    # Raw strings spell the same Windows paths without invalid escape sequences.
    test_path = r"E:\junqiangchen\data\kits19\kits19process" + "\\"
    image_path = "Image"
    mask_path = "Mask"
    dice_values = []
    for num in range(200, 210, 1):
        test_image_path = test_path + image_path + "/" + str(num)
        test_mask_path = test_path + mask_path + "/" + str(num)
        # Slices are named by position; the folder's file count gives the depth.
        index = len(os.listdir(test_image_path))
        batch_xs = []
        batch_ys = []
        for slice_idx in range(index):
            image = cv2.imread(test_image_path + "/" + str(slice_idx) + ".bmp", cv2.IMREAD_GRAYSCALE)
            label = cv2.imread(test_mask_path + "/" + str(slice_idx) + ".bmp", cv2.IMREAD_GRAYSCALE)
            batch_xs.append(image)
            batch_ys.append(label)
        xs_array = np.reshape(np.array(batch_xs), (index, 512, 512))
        ys_array = np.reshape(np.array(batch_ys), (index, 512, 512))
        ys_pd_array = np.empty((index, 512, 512), np.uint8)

        for depth in range(index // depth_z):
            lo, hi = depth * depth_z, (depth + 1) * depth_z
            ys_pd_array[lo:hi, :, :] = Vnet3d.prediction(xs_array[lo:hi, :, :])
        # Only when the full chunks left slices uncovered, predict one more
        # window aligned to the last slice.
        # NOTE(review): if index < depth_z this window is shorter than depth_z.
        if index % depth_z != 0:
            pathc_pd = Vnet3d.prediction(xs_array[(index - depth_z):index, :, :])
            ys_pd_array[(index - depth_z):index, :, :] = pathc_pd
        ys_pd_sitk = sitk.GetImageFromArray(ys_pd_array)
        ys_pd_array = removesmallConnectedCompont(ys_pd_sitk, 0.4)
        ys_pd_array = np.clip(ys_pd_array, 0, 255).astype('uint8')
        dice_value = calcu_dice(ys_pd_array, ys_array)
        print("num,dice:", (num, dice_value))
        dice_values.append(dice_value)
        for depth in range(0, index, 1):
            cv2.imwrite(test_mask_path + "/" + str(depth) + "predict.bmp", ys_pd_array[depth])
    average = sum(dice_values) / len(dice_values)
    print("average dice:", average)
예제 #7
0
def train():
    """Train the fine kidney-segmentation V-Net (4 channels, 3 classes).

    Column 0 of ``train.csv`` lists the image samples and column 1 the
    matching masks.  Rows are shuffled with one shared permutation and
    trained with logs under ``log\\segmeation\\VNet\\``.
    """
    samples = pd.read_csv('dataprocess\\data/train.csv')
    label_paths = samples.iloc[:, 1].values
    volume_paths = samples.iloc[:, 0].values

    # One permutation for both columns so image/mask pairs stay aligned.
    shuffle_idx = np.random.permutation(len(volume_paths))
    volume_paths = volume_paths[shuffle_idx]
    label_paths = label_paths[shuffle_idx]

    net = Vnet3dModule(128, 128, 64, channels=4, numclass=3,
                       costname=("dice coefficient",))
    net.train(volume_paths, label_paths, "Vnet3d.pd",
              "log\\segmeation\\VNet\\", 0.001, 0.5, 20, 1, [8, 8])
예제 #8
0
파일: vnet3d_train.py 프로젝트: KDV5/LiTS
def train():
    """Train a 256x256x16 V-Net on the LiTS image list and its paired labels.

    ``train_img.csv`` (images) and ``train_img_SE.csv`` (labels passed as the
    mask argument of ``Vnet3dModule.train``) are read row-aligned.  Both are
    shuffled with the same permutation and then used for training, with logs
    under ``log/diceVnet3d/``.

    :raises ValueError: if the two CSVs have different row counts.
    """
    # Read data set (train data from CSV files)
    csvimagedata = pd.read_csv('./dataprocess/train_img.csv')
    imagedata = csvimagedata.iloc[:, :].values
    csv_pos_data = pd.read_csv('./dataprocess/train_img_SE.csv')
    pos_data = csv_pos_data.iloc[:, :].values
    if len(imagedata) != len(pos_data):
        raise ValueError(
            "image CSV has {} rows but label CSV has {}".format(
                len(imagedata), len(pos_data)))

    # Shuffle images AND labels with the same permutation; permuting only
    # the images would pair each image with the wrong label.
    perm = np.arange(len(imagedata))
    np.random.shuffle(perm)
    imagedata = imagedata[perm]
    pos_data = pos_data[perm]

    Vnet3d = Vnet3dModule(256,
                          256,
                          16,
                          channels=1,
                          costname=("dice coefficient", ))
    Vnet3d.train(imagedata, pos_data, "Vnet3d.pd", "log/diceVnet3d/", 0.001,
                 0.7, 10, 1)
예제 #9
0
def predict():
    """Segment a generated LUNA16 sample stored as ``.npy`` with the 96x96x16 V-Net.

    The array at ``npy_path`` is scaled by 255 and cropped to
    ``[:, :, 5:20, 55:115, 40:100]``.  It is trilinearly resized to
    (16, 96, 96) with ``F.interpolate`` and reshaped to a (16, 96, 96) stack.
    The reshape only succeeds when the two leading axes multiply to 1; this
    is assumed of the .npy producer, TODO confirm.  The prediction shape is
    printed, and the input (scaled by 1/255) and the prediction are saved as
    4x4 mosaics ``test_src.bmp`` / ``test_predict.bmp``.
    """
    npy_path = '/data/shanyx/larry/LUNA16-Lung-Nodule-Analysis-2016-Challenge/gen_stage4_iter4999.npy'
    # Alternative input used previously:
    # /data/shanyx/larry/LUNA16-Lung-Nodule-Analysis-2016-Challenge/inp_stage4_iter249.npy
    test = np.load(npy_path) * 255
    test = test[:, :, 5:20, 55:115, 40:100]
    test = torch.from_numpy(test)
    x_tmp = F.interpolate(test, (16, 96, 96),
                          mode='trilinear', align_corners=True)
    test_imges = np.reshape(x_tmp.numpy(), (16, 96, 96))

    Vnet3d = Vnet3dModule(
        96,
        96,
        16,
        channels=1,
        costname=("dice coefficient", ),
        inference=True,
        model_path=
        "/data/shanyx/larry/LUNA16-Lung-Nodule-Analysis-2016-Challenge/segmeation/model/Vnet3d.pd-50000"
    )
    # Named `prediction` so it does not shadow this function's own name.
    prediction = Vnet3d.prediction(test_imges)
    print(prediction.shape)
    test_images = np.multiply(test_imges, 1.0 / 255.0)
    save_images(test_images, [4, 4], "test_src.bmp")
    save_images(prediction, [4, 4], "test_predict.bmp")
예제 #10
0
파일: vnet3d_predict.py 프로젝트: KDV5/LiTS
def predict():
    """Segment one LiTS test slice stack with a 512x512x32 V-Net.

    Reads ``0.bmp`` .. ``N-1.bmp`` from ``srcimagepath`` (N = number of files).
    It predicts windows of ``dimension`` slices at stride ``dimension // 2``;
    where windows overlap, the later prediction overwrites the earlier one.
    Any remainder is covered by one final window ending at the last slice.
    The mask is binarised to 0/255 and written as ``<i>.bmp`` into
    ``predictpath``, which is created if missing.
    """
    import os  # local import keeps this block self-contained

    height = 512
    width = 512
    dimension = 32
    Vnet3d = Vnet3dModule(height,
                          width,
                          dimension,
                          channels=1,
                          costname=("dice coefficient", ),
                          inference=True,
                          model_path=r"log\diceVnet3d\model\Vnet3d.pd")
    # Raw strings: identical paths without relying on invalid escapes.
    srcimagepath = r"D:\Data\LIST\test\Image\111"
    predictpath = r"D:\Data\LIST\test\PredictMask"
    index = 0
    imagelist = []
    for _ in os.listdir(srcimagepath):
        image = cv2.imread(srcimagepath + "/" + str(index) + ".bmp",
                           cv2.IMREAD_GRAYSCALE)
        imagelist.append(np.reshape(image, (height, width, 1)))
        index += 1

    imagearray = np.reshape(np.array(imagelist), (index, height, width, 1))
    imagemask = np.zeros((index, height, width), np.int32)

    # NOTE(review): if index < dimension the tail window is shorter than
    # `dimension` slices.
    for i in range(0, index + dimension, dimension // 2):
        if (i + dimension) <= index:
            imagedata = imagearray[i:i + dimension, :, :, :]
            imagemask[i:i + dimension, :, :] = Vnet3d.prediction(imagedata)
        elif i < index:
            imagedata = imagearray[index - dimension:index, :, :, :]
            imagemask[index -
                      dimension:index, :, :] = Vnet3d.prediction(imagedata)
            # Every later step would predict this identical tail window again.
            break

    mask = imagemask.copy()
    mask[imagemask > 0] = 255
    result = np.clip(mask, 0, 255).astype('uint8')
    # cv2.imwrite returns False instead of raising when the folder is missing.
    os.makedirs(predictpath, exist_ok=True)
    for i in range(0, index):
        cv2.imwrite(predictpath + "/" + str(i) + ".bmp", result[i])
예제 #11
0
def train():
    """Train a 128x128x64 V-Net on the PROMISE12 CSV data.

    ``promise12Vnet3dImage.csv`` and ``promise12Vnet3dMask.csv`` are loaded
    row-aligned and shuffled with one shared permutation.  They are passed to
    ``Vnet3dModule.train`` with the model file ``model\\Vnet3dModule.pd``
    and the log directory ``log\\``.
    """
    mask_rows = pd.read_csv('promise12Vnet3dMask.csv').values
    image_rows = pd.read_csv('promise12Vnet3dImage.csv').values

    # One permutation for both arrays keeps each image with its mask.
    shuffle_order = np.random.permutation(len(image_rows))
    image_rows = image_rows[shuffle_order]
    mask_rows = mask_rows[shuffle_order]

    # NOTE(review): costname is a bare string here while the other callers
    # pass a one-element tuple; presumably this targets an older
    # Vnet3dModule signature -- confirm before changing.
    vnet = Vnet3dModule(128, 128, 64, channels=1,
                        costname="dice coefficient")
    vnet.train(image_rows, mask_rows, "model\\Vnet3dModule.pd", "log\\",
               0.001, 0.7, 100000, 1)
예제 #12
0
def train():
    """Train a 256x256x16 V-Net on the ``trainX``/``trainY`` CSV data.

    ``trainX25625616.csv`` (images) and ``trainY25625616.csv`` (masks) are
    loaded row-aligned, shuffled with a single permutation and trained with
    logs under ``log\\diceVnet3d\\``.
    """
    label_rows = pd.read_csv('trainY25625616.csv').values
    input_rows = pd.read_csv('trainX25625616.csv').values

    # Shared permutation: images and masks are reordered identically.
    order = np.random.permutation(len(input_rows))
    input_rows, label_rows = input_rows[order], label_rows[order]

    vnet = Vnet3dModule(256, 256, 16, channels=1,
                        costname=("dice coefficient", ))
    vnet.train(input_rows, label_rows, "Vnet3d.pd", "log\\diceVnet3d\\",
               0.001, 0.7, 10, 1)
예제 #13
0
def predict():
    """Segment the 16-slice LUNA16 test cube ``3_98`` and save preview mosaics.

    Reads grayscale slices ``0.bmp`` .. ``15.bmp`` from the local ``Image`` and
    ``Mask`` ``3_98`` folders and runs the V-Net on the (16, 96, 96) stack.
    It writes ``test_src.bmp`` / ``test_mask.bmp`` (inputs scaled by 1/255)
    and ``test_predict.bmp`` as 4x4 mosaics.

    :raises FileNotFoundError: if a slice cannot be read by OpenCV.
    """
    # Raw strings give the same Windows paths without invalid escapes;
    # the trailing backslash is appended separately.
    src_path = r"G:\Data\LIDC\LUNA16\segmentation\Image\3_98" + "\\"
    mask_path = r"G:\Data\LIDC\LUNA16\segmentation\Mask\3_98" + "\\"
    imges = []
    masks = []
    for z in range(16):
        img_file = src_path + str(z) + ".bmp"
        mask_file = mask_path + str(z) + ".bmp"
        img = cv2.imread(img_file, cv2.IMREAD_GRAYSCALE)
        mask = cv2.imread(mask_file, cv2.IMREAD_GRAYSCALE)
        # cv2.imread returns None (no exception) for missing/unreadable files.
        if img is None:
            raise FileNotFoundError("cannot read image: " + img_file)
        if mask is None:
            raise FileNotFoundError("cannot read image: " + mask_file)
        imges.append(img)
        masks.append(mask)

    test_imges = np.reshape(np.array(imges), (16, 96, 96))
    test_masks = np.reshape(np.array(masks), (16, 96, 96))
    Vnet3d = Vnet3dModule(96, 96, 16, channels=1, costname=("dice coefficient",), inference=True,
                          model_path=r"log\segmeation\model\Vnet3d.pd-50000")
    prediction = Vnet3d.prediction(test_imges)
    test_images = np.multiply(test_imges, 1.0 / 255.0)
    test_masks = np.multiply(test_masks, 1.0 / 255.0)
    save_images(test_images, [4, 4], "test_src.bmp")
    save_images(test_masks, [4, 4], "test_mask.bmp")
    save_images(prediction, [4, 4], "test_predict.bmp")
예제 #14
0
def predict0():
    """Segment 30 PROMISE12 test cases with the 256x256x64 V-Net and save masks.

    For each case folder ``0`` .. ``29`` under ``image_root``, the 64 slices
    ``<i>.bmp`` are read as grayscale.  Each slice is cropped to rows/cols
    128:384 (assumes 512x512 slices -- TODO confirm), and the stack is
    predicted as one (64, 256, 256) batch.  Each predicted slice is pasted
    back into a 512x512 canvas, morphologically closed with a 5x5
    rectangular kernel, and written as ``<i>mask.bmp`` in the same folder.

    :raises FileNotFoundError: if a slice cannot be read by OpenCV.
    """
    Vnet3d = Vnet3dModule(256,
                          256,
                          64,
                          inference=True,
                          model_path="model\\Vnet3dModule.pd")
    # Raw string: same Windows path without invalid escape sequences.
    image_root = r"D:\Data\PROMISE2012\Vnet3d_data\test\image"
    # The closing kernel does not depend on the slice; build it once.
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
    for filenumber in range(30):
        case_dir = image_root + "\\" + str(filenumber) + "\\"
        batch_xs = np.zeros(shape=(64, 256, 256))
        for index in range(64):
            slice_file = case_dir + str(index) + ".bmp"
            imgs = cv2.imread(slice_file, 0)
            # cv2.imread returns None (no exception) for missing files.
            if imgs is None:
                raise FileNotFoundError("cannot read image: " + slice_file)
            batch_xs[index, :, :] = imgs[128:384, 128:384]

        predictvalue = Vnet3d.prediction(batch_xs)

        for index in range(64):
            result = np.zeros(shape=(512, 512), dtype=np.uint8)
            result[128:384, 128:384] = predictvalue[index]
            result = cv2.morphologyEx(result, cv2.MORPH_CLOSE, kernel)
            cv2.imwrite(case_dir + str(index) + "mask.bmp", result)
def inference():
    """
    Vnet network segmentation kidney fine segmentation.

    Reads one ``casename,start,end`` line per case from ``kidneyrang.txt``,
    taken in the same order as ``file_name_path(kits_path)``.  Each case's
    ``imaging.nii.gz`` is loaded with intensity truncation (300 / -200) and
    cropped to ``[start, end)`` along the last array axis.  If the first
    spacing component exceeds 1.0, the volume is resampled to 1.0 along that
    axis.  The volume is then segmented in chunks of ``depth_z`` slices, small
    connected components are removed, and source and predicted slices are
    written as bmp files under ``result_path/src/<casename>`` and
    ``result_path/kidney_modify/<casename>``.
    :return: None
    """
    depth_z = 32
    Vnet3d = Vnet3dModule(512,
                          512,
                          depth_z,
                          channels=1,
                          costname=("dice coefficient", ),
                          inference=True,
                          model_path=r"log\segmeation\VNet\model\Vnet3d.pd")
    # Raw strings: identical Windows paths without invalid escapes.
    kits_path = r"D:\Data\kits19\kits19\test"
    image_name = "imaging.nii.gz"
    result_path = r"D:\Data\kits19\kits19test"
    # step2 get all test case folders
    path_list = file_name_path(kits_path)
    # Read the kidney ranges up front so the file is closed promptly.
    with open("kidneyrang.txt", 'r') as read:
        range_lines = read.readlines()
    # step3 process each case
    # NOTE(review): assumes line i of kidneyrang.txt describes path_list[i];
    # casename comes from the file, the image path from path_list.
    for subsetindex in range(len(path_list)):
        line = range_lines[subsetindex].split(',')
        casename = line[0]
        start = int(line[1])
        # int() tolerates the trailing newline; slicing off the last char
        # would drop a digit when the final line has no newline.
        end = int(line[2])
        kits_subset_path = kits_path + "/" + str(path_list[subsetindex]) + "/"
        file_image = kits_subset_path + image_name
        # 1 load itk image, truncate values to [lower, upper] and crop to the kidney range
        src = load_itkfilewithtrucation(file_image, 300, -200)
        originSpacing = src.GetSpacing()
        src_array = sitk.GetArrayFromImage(src)
        sub_src_array = src_array[:, :, start:end]
        sub_src = sitk.GetImageFromArray(sub_src_array)
        sub_src.SetSpacing(originSpacing)
        print(sub_src.GetSize())
        thickspacing, widthspacing = originSpacing[0], originSpacing[1]
        # 2 change z spacing >1.0 to 1.0
        if thickspacing > 1.0:
            _, sub_src = resize_image_itk(
                sub_src,
                newSpacing=(1.0, widthspacing, widthspacing),
                originSpcaing=(thickspacing, widthspacing, widthspacing),
                resamplemethod=sitk.sitkLinear)
        xs_array = sitk.GetArrayFromImage(sub_src)
        xs_array = np.swapaxes(xs_array, 0, 2)
        index = np.shape(xs_array)[0]
        ys_pd_array = np.zeros(np.shape(xs_array), np.uint8)

        for depth in range(index // depth_z):
            lo, hi = depth * depth_z, (depth + 1) * depth_z
            ys_pd_array[lo:hi, :, :] = Vnet3d.prediction(xs_array[lo:hi, :, :])
        # Only when full chunks left slices uncovered, predict one extra
        # window aligned to the last slice.
        if index % depth_z != 0:
            pathc_pd = Vnet3d.prediction(xs_array[(index - depth_z):index, :, :])
            ys_pd_array[(index - depth_z):index, :, :] = pathc_pd

        ys_pd_sitk = sitk.GetImageFromArray(ys_pd_array)
        ys_pd_array = removesmallConnectedCompont(ys_pd_sitk, 0.2)
        ys_pd_array = np.clip(ys_pd_array, 0, 255).astype('uint8')
        sub_src_path = result_path + "/src/" + casename
        sub_pred_path = result_path + "/kidney_modify/" + casename
        os.makedirs(sub_src_path, exist_ok=True)
        os.makedirs(sub_pred_path, exist_ok=True)
        for i in range(np.shape(xs_array)[0]):
            cv2.imwrite(sub_src_path + "/" + str(i) + ".bmp", xs_array[i])
            cv2.imwrite(sub_pred_path + "/" + str(i) + ".bmp", ys_pd_array[i])
예제 #16
0
def _unwrap_network(net):
    """Return the model inside a (Distributed)DataParallel wrapper, else ``net``."""
    parallel_types = (torch.nn.DataParallel,
                      torch.nn.parallel.DistributedDataParallel)
    return net.module if isinstance(net, parallel_types) else net


def _place_networks(networks, args):
    """Move ``networks`` to the configured device(s) and wrap them for parallelism.

    Any existing (Distributed)DataParallel wrapper is stripped first.  This
    lets the function run again after ``progress()``, so the wrappers see the
    new scale's parameters without wrappers becoming nested.  It does not
    modify ``args.batch_size`` / ``args.workers``.

    :return: list of placed (and possibly wrapped) networks, in input order.
    """
    networks = [_unwrap_network(x) for x in networks]
    if args.distributed:
        if args.gpu is not None:
            print('Distributed to', args.gpu)
            torch.cuda.set_device(args.gpu)
            networks = [x.cuda(args.gpu) for x in networks]
            return [
                torch.nn.parallel.DistributedDataParallel(
                    x, device_ids=[args.gpu], output_device=args.gpu)
                for x in networks
            ]
        networks = [x.cuda() for x in networks]
        return [
            torch.nn.parallel.DistributedDataParallel(x) for x in networks
        ]
    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        return [x.cuda(args.gpu) for x in networks]
    return [torch.nn.DataParallel(x).cuda() for x in networks]


def _scale_optimizers(discriminator, generator):
    """Freeze every scale below ``current_scale`` and build Adam for that scale.

    :return: ``(d_opt, g_opt)`` over the parameters of
        ``sub_discriminators[current_scale]`` / ``sub_generators[current_scale]``.
    """
    d_net = _unwrap_network(discriminator)
    g_net = _unwrap_network(generator)
    for net_idx in range(g_net.current_scale):
        for param in g_net.sub_generators[net_idx].parameters():
            param.requires_grad = False
        for param in d_net.sub_discriminators[net_idx].parameters():
            param.requires_grad = False
    d_opt = torch.optim.Adam(
        d_net.sub_discriminators[d_net.current_scale].parameters(), 5e-4,
        (0.5, 0.999))
    g_opt = torch.optim.Adam(
        g_net.sub_generators[g_net.current_scale].parameters(), 5e-4,
        (0.5, 0.999))
    return d_opt, g_opt


def main_worker(gpu, ngpus_per_node, args):
    """Per-process entry point.

    Builds the multi-scale GAN networks and optionally resumes from the last
    checkpoint listed in ``checkpoint.txt``.  It then either validates/tests
    once or trains scale by scale, saving a checkpoint after each stage.

    :param gpu: GPU index assigned to this process (used when ``args.gpu``
        lists more than one device).
    :param ngpus_per_node: GPUs on this node; used for the global rank and to
        split ``batch_size`` / ``workers`` across processes.
    :param args: parsed options.  Mutated in place: ``gpu``, ``rank``,
        ``num_scale``, ``size_list``, ``batch_size``, ``workers``, ``stage``,
        ``img_to_use``.
    :raises ValueError: if ``args.modeltype`` is neither '3D' nor '2D'.
    """
    if len(args.gpu) == 1:
        args.gpu = 0
    else:
        args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend='nccl',
                                init_method='tcp://127.0.0.1:' + args.port,
                                world_size=args.world_size,
                                rank=args.rank)

    ################
    # Define model #
    ################
    scale_factor = 3 / 2
    tmp_scale = args.img_size_max / args.img_size_min
    args.num_scale = int(np.round(np.log(tmp_scale) / np.log(scale_factor)))
    args.size_list = [
        int(args.img_size_min * scale_factor**i)
        for i in range(args.num_scale + 1)
    ]

    if args.modeltype == '3D':
        discriminator = Discriminator_3D(args.img_size_min, args.num_scale,
                                         scale_factor)
        generator = Generator_3D(args.img_size_min, args.num_scale,
                                 scale_factor)
    elif args.modeltype == '2D':
        # NOTE(review): the 2D setup pairs Generator_2D with Discriminator_3D;
        # kept as-is -- confirm this is intended.
        discriminator = Discriminator_3D(args.img_size_min, args.num_scale,
                                         scale_factor)
        generator = Generator_2D(args.img_size_min, args.num_scale,
                                 scale_factor)
    else:
        raise ValueError("unknown modeltype {!r}; expected '3D' or '2D'".format(
            args.modeltype))

    if args.distributed and args.gpu is not None:
        # Split the global batch and loader workers across processes, once.
        args.batch_size = int(args.batch_size / ngpus_per_node)
        args.workers = int(args.workers / ngpus_per_node)

    networks = _place_networks([discriminator, generator], args)
    discriminator, generator = networks

    ######################
    # Loss and Optimizer #
    ######################
    d_opt = torch.optim.Adam(
        _unwrap_network(discriminator).sub_discriminators[0].parameters(),
        5e-4, (0.5, 0.999))
    g_opt = torch.optim.Adam(
        _unwrap_network(generator).sub_generators[0].parameters(), 5e-4,
        (0.5, 0.999))

    ##############
    # Load model #
    ##############
    args.stage = 0
    if args.load_model is not None:
        with open(os.path.join(args.log_dir, "checkpoint.txt"),
                  'r') as check_load:
            to_restore = check_load.readlines()[-1].strip()
        load_file = os.path.join(args.log_dir, to_restore)
        if os.path.isfile(load_file):
            print("=> loading checkpoint '{}'".format(load_file))
            checkpoint = torch.load(load_file, map_location='cpu')
            # Grow the bare models to the saved scale before loading weights.
            for _ in range(int(checkpoint['stage'])):
                _unwrap_network(generator).progress()
                _unwrap_network(discriminator).progress()
            networks = _place_networks([discriminator, generator], args)
            discriminator, generator = networks

            args.stage = checkpoint['stage']
            args.img_to_use = checkpoint['img_to_use']
            discriminator.load_state_dict(checkpoint['D_state_dict'])
            generator.load_state_dict(checkpoint['G_state_dict'])
            # The saved optimizer states were built for the current scale
            # (see the training loop), so rebuild for that scale, then load.
            d_opt, g_opt = _scale_optimizers(discriminator, generator)
            d_opt.load_state_dict(checkpoint['d_optimizer'])
            g_opt.load_state_dict(checkpoint['g_optimizer'])
            print("=> loaded checkpoint '{}' (stage {})".format(
                load_file, checkpoint['stage']))
        else:
            print("=> no checkpoint found at '{}'".format(args.log_dir))

    cudnn.benchmark = True

    ###########
    # Dataset #
    ###########
    train_dataset, _ = get_dataset(args.dataset, args)

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    # Pass the sampler so each distributed process reads its own shard.
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    ######################
    # Validate and Train #
    ######################
    # Fixed reconstruction noise: random at the coarsest scale, zeros above.
    z_fix_list = [
        torch.randn(args.batch_size, 1, int(args.size_list[0] / 6),
                    args.size_list[0], args.size_list[0])
    ]
    zero_list = [
        torch.zeros(args.batch_size, 1, int(args.size_list[zeros_idx] / 6),
                    args.size_list[zeros_idx], args.size_list[zeros_idx])
        for zeros_idx in range(1, args.num_scale + 1)
    ]
    z_fix_list = z_fix_list + zero_list

    if args.validation or args.test:
        validateSinGAN(train_loader, networks, args.stage, args,
                       {"z_rec": z_fix_list})
        return

    is_main_process = (not args.multiprocessing_distributed
                       or args.rank % ngpus_per_node == 0)
    check_list = None
    if is_main_process:
        check_list = open(os.path.join(args.log_dir, "checkpoint.txt"), "a+")
        with open(os.path.join(args.log_dir, "record.txt"),
                  "a+") as record_txt:
            record_txt.write('DATASET\t:\t{}\n'.format(args.dataset))
            record_txt.write('GANTYPE\t:\t{}\n'.format(args.gantype))
            record_txt.write('IMGTOUSE\t:\t{}\n'.format(args.img_to_use))

    ######################
    # Segmentation Model #
    ######################
    Vnet3d = Vnet3dModule(
        96,
        96,
        16,
        channels=1,
        costname=("dice coefficient", ),
        inference=True,
        model_path=
        "/data/shanyx/larry/LUNA16-Lung-Nodule-Analysis-2016-Challenge/segmeation/model/Vnet3d.pd-50000"
    )

    try:
        for stage in range(args.stage, args.num_scale + 1):
            if args.distributed:
                train_sampler.set_epoch(stage)

            trainSinGAN(Vnet3d, train_loader, networks, {
                "d_opt": d_opt,
                "g_opt": g_opt
            }, stage, args, {"z_rec": z_fix_list})

            _unwrap_network(discriminator).progress()
            _unwrap_network(generator).progress()

            # Re-wrap (without nesting) so the wrappers see the new scale.
            networks = _place_networks([discriminator, generator], args)
            discriminator, generator = networks

            # Update the networks at finest scale: freeze coarser scales.
            d_opt, g_opt = _scale_optimizers(discriminator, generator)

            ##############
            # Save model #
            ##############
            if is_main_process:
                save_checkpoint(
                    {
                        'stage': stage + 1,
                        'D_state_dict': discriminator.state_dict(),
                        'G_state_dict': generator.state_dict(),
                        'd_optimizer': d_opt.state_dict(),
                        'g_optimizer': g_opt.state_dict(),
                        'img_to_use': args.img_to_use
                    }, check_list, args.log_dir, stage + 1)
    finally:
        if check_list is not None:
            check_list.close()
def inference():
    """
    Vnet network segmentation kidney coarse segmentation; get range of kidney.

    Each test case's ``imaging.nii.gz`` is loaded with intensity truncation
    (300 / -200) and resampled to ``fixed_size`` (64, 256, 256).  It is
    segmented with the coarse V-Net and the prediction is resampled back to
    the original size (nearest neighbour).  The start/end extent reported by
    ``getRangImageDepth`` is widened by ``expandslice`` and clamped to
    ``[0, shape[2]]``.  It is written as one ``case,start,end`` line to
    ``kidneyrang.txt``.
    :return: None
    """
    depth_z = 64
    height = 256
    Vnet3d = Vnet3dModule(
        height,
        height,
        depth_z,
        channels=1,
        costname=("dice coefficient", ),
        inference=True,
        model_path=r"log\segmeation\CoarseVNet\model\Vnet3d.pd")
    fixed_size = [depth_z, height, height]
    # Raw string: same Windows path without invalid escapes.
    kits_path = r"D:\Data\kits19\kits19\test"
    image_name = "imaging.nii.gz"
    # step2 get all test case folders
    path_list = file_name_path(kits_path)
    file_name = "kidneyrang.txt"
    # `with` guarantees the range file is flushed and closed.
    with open(file_name, 'w') as out:
        # step3 process each case
        for subsetindex in range(len(path_list)):
            kits_subset_path = kits_path + "/" + str(path_list[subsetindex]) + "/"
            file_image = kits_subset_path + image_name
            # 1 load itk image and truncate value with upper and lower
            src = load_itkfilewithtrucation(file_image, 300, -200)
            originSize = src.GetSize()
            originSpacing = src.GetSpacing()
            thickspacing, widthspacing = originSpacing[0], originSpacing[1]
            # 2 change image size to fixed size
            _, src = resize_image_itkwithsize(
                src,
                newSize=fixed_size,
                originSize=originSize,
                originSpcaing=[thickspacing, widthspacing, widthspacing],
                resamplemethod=sitk.sitkLinear)
            # 3 get resampled array and predict
            srcimg = sitk.GetArrayFromImage(src)
            srcimg = np.swapaxes(srcimg, 0, 2)
            ys_pd_array = Vnet3d.prediction(srcimg)
            ys_pd_array = np.clip(ys_pd_array, 0, 255).astype('uint8')

            ys_pd_array = np.swapaxes(ys_pd_array, 0, 2)
            ys_pd_itk = sitk.GetImageFromArray(ys_pd_array)
            ys_pd_itk.SetSpacing(src.GetSpacing())
            ys_pd_itk.SetOrigin(src.GetOrigin())
            ys_pd_itk.SetDirection(src.GetDirection())

            # Nearest neighbour keeps the label values intact when resizing back.
            _, ys_pd_itk = resize_image_itkwithsize(
                ys_pd_itk,
                newSize=originSize,
                originSize=fixed_size,
                originSpcaing=[
                    src.GetSpacing()[0],
                    src.GetSpacing()[1],
                    src.GetSpacing()[2]
                ],
                resamplemethod=sitk.sitkNearestNeighbor)

            pd_array = sitk.GetArrayFromImage(ys_pd_itk)
            print(np.shape(pd_array))

            # 4 get range of coarse kidney, widened and clamped to the volume
            expandslice = 5
            startpostion, endpostion = getRangImageDepth(pd_array)
            if startpostion == endpostion:
                print("corse error")
            imagez = np.shape(pd_array)[2]
            startpostion = max(startpostion - expandslice, 0)
            endpostion = min(endpostion + expandslice, imagez)
            print("casenaem:", path_list[subsetindex])
            print("startposition:", startpostion)
            print("endpostion:", endpostion)
            out.write(path_list[subsetindex] + "," + str(startpostion) + "," +
                      str(endpostion) + "\n")