Exemplo n.º 1
0
    def __init__(self, sess, args):
        """Load the DICOM data and build the CycleGAN-style LDCT <-> NDCT graph.

        Two generators are built, ``generatorX2Y`` (X = 'LDCT' placeholder ->
        Y = 'NDCT' placeholder, called G below) and ``generatorY2X`` (called
        F), plus one discriminator per domain: ``discriminatorY`` scores G
        outputs and real Y, ``discriminatorX`` scores F outputs and real X.

        Args:
            sess: session object, stored on ``self.sess``.
            args: parsed options. Read here: ``dcm_path``, ``test_patient_no``,
                ``result``, ``checkpoint_dir``, ``log_dir``, ``LDCT_path``,
                ``NDCT_path``, ``whole_size``, ``patch_size``, ``img_channel``,
                ``trun_max``, ``trun_min``, ``is_unpair``, ``augument``,
                ``norm``, ``phase``, ``cycle_loss``, ``L1_lambda``,
                ``ident_loss``, ``L1_gamma``, ``resid_loss``, ``L1_delta``,
                ``beta1`` and ``max_size``.

        Side effects:
            Loads images through ``ut.DCMDataLoader`` (train and test sets
            when ``args.phase == 'train'``, the test set only otherwise), adds
            placeholders, networks, losses, Adam optimizers and summaries to
            the default TensorFlow graph, and prints progress messages.
        """
        self.sess = sess

        #### patient folder names
        # Every entry under dcm_path that is not a zip archive and is not a
        # listed test patient is used for training.
        self.train_patient_no = [
            d.split('/')[-1] for d in glob(args.dcm_path + '/*')
            if 'zip' not in d and d.split('/')[-1] not in args.test_patient_no
        ]
        self.test_patient_no = args.test_patient_no

        #### save directories: one sub-directory per test-patient set
        self.p_info = '_'.join(self.test_patient_no)
        self.checkpoint_dir = os.path.join(args.result, args.checkpoint_dir,
                                           self.p_info)
        self.log_dir = os.path.join(args.result, args.log_dir, self.p_info)
        print('directory check!!\ncheckpoint : {}\ntensorboard_logs : {}'.format(
            self.checkpoint_dir, self.log_dir))

        #### network builders
        self.discriminator = md.discriminator
        self.generator = md.generator

        #### load images
        print('data load... dicom -> numpy')
        self.image_loader = ut.DCMDataLoader(
            args.dcm_path, args.LDCT_path, args.NDCT_path,
            image_size=args.whole_size, patch_size=args.patch_size,
            depth=args.img_channel, image_max=args.trun_max,
            image_min=args.trun_min, is_unpair=args.is_unpair,
            augument=args.augument, norm=args.norm)

        self.test_image_loader = ut.DCMDataLoader(
            args.dcm_path, args.LDCT_path, args.NDCT_path,
            image_size=args.whole_size, patch_size=args.patch_size,
            depth=args.img_channel, image_max=args.trun_max,
            image_min=args.trun_min, is_unpair=args.is_unpair,
            augument=args.augument, norm=args.norm)

        t1 = time.time()
        if args.phase == 'train':
            self.image_loader(self.train_patient_no)
            self.test_image_loader(self.test_patient_no)
            print('data load complete !!!, {}\nN_train : {}, N_test : {}'.format(
                time.time() - t1, len(self.image_loader.LDCT_image_name),
                len(self.test_image_loader.LDCT_image_name)))
        else:
            self.test_image_loader(self.test_patient_no)
            print('data load complete !!!, {}, N_test : {}'.format(
                time.time() - t1, len(self.test_image_loader.LDCT_image_name)))

        #### placeholders
        patch_shape = [None, args.patch_size, args.patch_size, args.img_channel]
        whole_shape = [1, args.whole_size, args.whole_size, args.img_channel]
        self.real_X = tf.placeholder(tf.float32, patch_shape, name='LDCT')
        self.real_Y = tf.placeholder(tf.float32, patch_shape, name='NDCT')
        # Fake images fed back in for the discriminator loss; presumably drawn
        # from self.pool (md.ImagePool) by the training loop - confirm there.
        self.sample_GX = tf.placeholder(tf.float32, patch_shape, name='G_LDCT')
        self.sample_FY = tf.placeholder(tf.float32, patch_shape, name='F_NDCT')

        # NOTE(review): these reuse the names 'LDCT'/'NDCT'; TensorFlow
        # uniquifies the op names (e.g. 'LDCT_1'). Kept as-is in case other
        # code looks these tensors up by name.
        self.whole_X = tf.placeholder(tf.float32, whole_shape, name='LDCT')
        self.whole_Y = tf.placeholder(tf.float32, whole_shape, name='NDCT')

        #### Generators
        # The third argument appears to be the variable-reuse flag: each
        # scope's first call creates its variables, later calls share them.
        self.G_X = self.generator(args, self.real_X, False, name="generatorX2Y")
        self.F_GX = self.generator(args, self.G_X, False, name="generatorY2X")
        self.F_Y = self.generator(args, self.real_Y, True, name="generatorY2X")
        self.G_FY = self.generator(args, self.F_Y, True, name="generatorX2Y")

        # Identity mappings: G applied to Y and F applied to X.
        self.G_Y = self.generator(args, self.real_Y, True, name="generatorX2Y")
        self.F_X = self.generator(args, self.real_X, True, name="generatorY2X")

        #### Discriminators
        self.D_GX = self.discriminator(args, self.G_X, reuse=False,
                                       name="discriminatorY")
        self.D_FY = self.discriminator(args, self.F_Y, reuse=False,
                                       name="discriminatorX")
        # Scores of the fed-back fake images, used by the discriminator loss.
        self.D_sample_GX = self.discriminator(args, self.sample_GX, reuse=True,
                                              name="discriminatorY")
        self.D_sample_FY = self.discriminator(args, self.sample_FY, reuse=True,
                                              name="discriminatorX")
        self.D_Y = self.discriminator(args, self.real_Y, reuse=True,
                                      name="discriminatorY")
        self.D_X = self.discriminator(args, self.real_X, reuse=True,
                                      name="discriminatorX")

        #### Generator loss: adversarial terms push D(fake) towards 1
        self.G_loss_X2Y = md.least_square(self.D_GX, tf.ones_like(self.D_GX))
        self.G_loss_Y2X = md.least_square(self.D_FY, tf.ones_like(self.D_FY))
        self.G_loss = self.G_loss_X2Y + self.G_loss_Y2X

        # Optional regularisers, each weighted by its own args coefficient.
        if args.cycle_loss:
            self.cycle_loss = md.cycle_loss(
                self.real_X, self.F_GX, self.real_Y, self.G_FY, args.L1_lambda)
            self.G_loss += self.cycle_loss

        if args.ident_loss:
            self.identity_loss = md.identity_loss(
                self.real_X, self.G_Y, self.real_Y, self.F_X, args.L1_gamma)
            self.G_loss += self.identity_loss

        if args.resid_loss:
            self.residual_loss = md.residual_loss(
                self.real_X, self.G_X, self.F_GX, self.real_Y,
                self.F_Y, self.G_FY, args.L1_delta)
            self.G_loss += self.residual_loss

        #### Discriminator loss: D(real) towards 1, D(fed-back fake) towards 0
        self.D_loss_real_Y = md.least_square(self.D_Y, tf.ones_like(self.D_Y))
        # The fake term must be computed on discriminatorY's OUTPUT for the
        # fed-back G images. Scoring the raw sample_GX image would make the
        # term independent of d_vars, and discriminatorY would never learn to
        # reject fakes.
        self.D_loss_GX = md.least_square(self.D_sample_GX,
                                         tf.zeros_like(self.D_sample_GX))
        self.D_loss_real_X = md.least_square(self.D_X, tf.ones_like(self.D_X))
        self.D_loss_FY = md.least_square(self.D_sample_FY,
                                         tf.zeros_like(self.D_sample_FY))
        self.D_loss_Y = (self.D_loss_real_Y + self.D_loss_GX)
        self.D_loss_X = (self.D_loss_real_X + self.D_loss_FY)
        self.D_loss = (self.D_loss_X + self.D_loss_Y) / 2

        #### variable lists, split by scope name
        t_vars = tf.trainable_variables()
        self.d_vars = [var for var in t_vars if 'discriminator' in var.name]
        self.g_vars = [var for var in t_vars if 'generator' in var.name]

        #### optimizers: one Adam per player, sharing the fed learning rate
        self.lr = tf.placeholder(tf.float32, None, name='learning_rate')
        self.d_optim = tf.train.AdamOptimizer(self.lr, beta1=args.beta1) \
            .minimize(self.D_loss, var_list=self.d_vars)
        self.g_optim = tf.train.AdamOptimizer(self.lr, beta1=args.beta1) \
            .minimize(self.G_loss, var_list=self.g_vars)

        #### loss summaries: generator
        self.G_loss_sum = tf.summary.scalar("1_G_loss", self.G_loss,
                                            family='Generator_loss')
        self.G_loss_X2Y_sum = tf.summary.scalar("2_G_loss_X2Y", self.G_loss_X2Y,
                                                family='Generator_loss')
        self.G_loss_Y2X_sum = tf.summary.scalar("3_G_loss_Y2X", self.G_loss_Y2X,
                                                family='Generator_loss')

        generator_loss_list = [
            self.G_loss_sum, self.G_loss_X2Y_sum, self.G_loss_Y2X_sum
        ]
        if args.cycle_loss:
            self.cycle_loss_sum = tf.summary.scalar(
                "4_cycle_loss", self.cycle_loss, family='Generator_loss')
            generator_loss_list += [self.cycle_loss_sum]
        if args.ident_loss:
            self.identity_loss_sum = tf.summary.scalar(
                "5_identity_loss", self.identity_loss, family='Generator_loss')
            generator_loss_list += [self.identity_loss_sum]
        if args.resid_loss:
            self.residual_loss_sum = tf.summary.scalar(
                "6_residual_loss", self.residual_loss, family='Generator_loss')
            generator_loss_list += [self.residual_loss_sum]

        self.g_sum = tf.summary.merge(generator_loss_list)

        #### loss summaries: discriminator
        self.D_loss_sum = tf.summary.scalar("1_D_loss", self.D_loss,
                                            family='Discriminator_loss')
        self.D_loss_Y_sum = tf.summary.scalar("2_D_loss_Y", self.D_loss_real_Y,
                                              family='Discriminator_loss')
        self.D_loss_GX_sum = tf.summary.scalar("3_D_loss_GX", self.D_loss_GX,
                                               family='Discriminator_loss')
        self.d_sum = tf.summary.merge(
            [self.D_loss_sum, self.D_loss_Y_sum, self.D_loss_GX_sum])

        #### image summaries: input | target | output, side by side (axis 2)
        self.test_G_X = self.generator(args, self.whole_X, True,
                                       name="generatorX2Y")
        self.train_img_summary = tf.concat(
            [self.real_X, self.real_Y, self.G_X], axis=2)
        self.summary_image_1 = tf.summary.image('1_train_image',
                                                self.train_img_summary)
        self.test_img_summary = tf.concat(
            [self.whole_X, self.whole_Y, self.test_G_X], axis=2)
        self.summary_image_2 = tf.summary.image('2_test_image',
                                                self.test_img_summary)

        #### PSNR summaries on whole images
        # The trailing '-1 ~ 1' notes suggest the last argument (2) is the
        # pixel value range - confirm against ut.tf_psnr.
        self.summary_psnr_ldct = tf.summary.scalar(
            "1_psnr_LDCT", ut.tf_psnr(self.whole_X, self.whole_Y, 2),
            family='PSNR')  #-1 ~ 1
        self.summary_psnr_result = tf.summary.scalar(
            "2_psnr_output", ut.tf_psnr(self.whole_Y, self.test_G_X, 2),
            family='PSNR')  #-1 ~ 1
        self.summary_psnr = tf.summary.merge(
            [self.summary_psnr_ldct, self.summary_psnr_result])

        #### model saver (keeps every checkpoint)
        self.saver = tf.train.Saver(max_to_keep=None)

        #### history buffer of generated images
        self.pool = md.ImagePool(args.max_size)

        print('--------------------------------------------\n# of parameters : {} '.format(
            np.sum([np.prod(v.get_shape().as_list())
                    for v in tf.trainable_variables()])))
Exemplo n.º 2
0
    def __init__(self, sess, args):
        """Load the DICOM data and build a CycleGAN-style LDCT <-> NDCT graph.

        In the 'train' phase the patch tensors come from
        ``self.image_loader.input_pipeline``; otherwise they are plain
        placeholders. Whole-image placeholders (``test_X``/``test_Y``) are
        always created for evaluation summaries.

        Args:
            sess: session object, stored on ``self.sess`` and passed to the
                training input pipeline.
            args: parsed options. Read here: ``dcm_path``, ``test_patient_no``,
                ``checkpoint_dir``, ``ngf``, ``nglf``, ``ndf``,
                ``img_channel``, ``phase``, ``LDCT_path``, ``NDCT_path``,
                ``whole_size``, ``patch_size``, ``img_vmax``, ``img_vmin``,
                ``batch_size``, ``unpair``, ``model``, ``end_epoch``,
                ``L1_lambda``, ``L1_gamma`` and ``beta1``.

        Side effects:
            Loads images, adds networks, losses, Adam optimizers and summaries
            to the default TensorFlow graph, and prints progress messages.
        """
        self.sess = sess

        ####patients folder name
        # Training patients: every entry under dcm_path that is not a zip
        # archive and is not listed in args.test_patient_no.
        self.train_patient_no = [
            d.split('/')[-1] for d in glob(args.dcm_path + '/*')
            if ('zip' not in d)
            & (d.split('/')[-1] not in args.test_patient_no)
        ]
        self.test_patient_no = args.test_patient_no

        #save directory
        # Checkpoints and logs are grouped per test-patient set, relative to
        # the current working directory.
        self.p_info = '_'.join(self.test_patient_no)
        self.checkpoint_dir = os.path.join('.', args.checkpoint_dir,
                                           self.p_info)
        self.log_dir = os.path.join('.', 'logs', self.p_info)
        print(
            'directory check!!\ncheckpoint : {}\ntensorboard_logs : {}'.format(
                self.checkpoint_dir, self.log_dir))

        #module
        self.discriminator = md.discriminator
        self.generator = md.generator

        #network options
        # Options bundle passed to every generator/discriminator call;
        # is_training is True only in the 'train' phase.
        OPTIONS = namedtuple(
            'OPTIONS', 'gf_dim glf_dim df_dim \
                              img_channel is_training')
        self.options = OPTIONS._make((args.ngf, args.nglf, args.ndf,
                                      args.img_channel, args.phase == 'train'))
        """
        load images
        """
        print('data load... dicom -> numpy')
        self.image_loader = ut.DCMDataLoader(args.dcm_path, args.LDCT_path, args.NDCT_path, \
             image_size = args.whole_size, patch_size = args.patch_size, depth = args.img_channel,
             image_max = args.img_vmax, image_min = args.img_vmin, batch_size = args.batch_size, \
             is_unpair = args.unpair, model = args.model)

        self.test_image_loader = ut.DCMDataLoader(args.dcm_path, args.LDCT_path, args.NDCT_path,\
             image_size = args.whole_size, patch_size = args.patch_size, depth = args.img_channel,
             image_max = args.img_vmax, image_min = args.img_vmin, batch_size = args.batch_size, \
             is_unpair = args.unpair, model = args.model)

        t1 = time.time()
        if args.phase == 'train':
            self.image_loader(self.train_patient_no)
            self.test_image_loader(self.test_patient_no)
            print(
                'data load complete !!!, {}\nN_train : {}, N_test : {}'.format(
                    time.time() - t1, len(self.image_loader.LDCT_image_name),
                    len(self.test_image_loader.LDCT_image_name)))
            # Training patches come from the loader's input pipeline rather
            # than from placeholders. NOTE(review): presumably queue/dataset
            # tensors bounded by args.end_epoch - confirm in ut.DCMDataLoader.
            [self.patch_X, self.patch_Y
             ] = self.image_loader.input_pipeline(self.sess, args.patch_size,
                                                  args.end_epoch)
        else:
            self.test_image_loader(self.test_patient_no)
            print('data load complete !!!, {}, N_test : {}'.format(
                time.time() - t1, len(self.test_image_loader.LDCT_image_name)))

            # Outside training the patch inputs are fed explicitly.
            self.patch_X = tf.placeholder(
                tf.float32,
                [None, args.patch_size, args.patch_size, args.img_channel],
                name='LDCT')
            self.patch_Y = tf.placeholder(
                tf.float32,
                [None, args.patch_size, args.patch_size, args.img_channel],
                name='NDCT')
        """
        build model
        """
        #### image placehold(for test)
        # Whole-image inputs, used only by the image and PSNR summaries below.
        self.test_X = tf.placeholder(
            tf.float32,
            [None, args.whole_size, args.whole_size, args.img_channel],
            name='X')
        self.test_Y = tf.placeholder(
            tf.float32,
            [None, args.whole_size, args.whole_size, args.img_channel],
            name='Y')

        #### Generator & Discriminator
        #Generator
        # The third argument appears to be the variable-reuse flag: the first
        # call in each scope creates its variables, later calls share them.
        self.G_X = self.generator(self.patch_X,
                                  self.options,
                                  False,
                                  name="generatorX2Y")
        self.F_GX = self.generator(self.G_X,
                                   self.options,
                                   False,
                                   name="generatorY2X")
        self.F_Y = self.generator(self.patch_Y,
                                  self.options,
                                  True,
                                  name="generatorY2X")
        self.G_FY = self.generator(self.F_Y,
                                   self.options,
                                   True,
                                   name="generatorX2Y")

        # Identity mappings: G applied to Y and F applied to X.
        self.G_Y = self.generator(self.patch_Y,
                                  self.options,
                                  True,
                                  name="generatorX2Y")  #G : x->y
        self.F_X = self.generator(self.patch_X,
                                  self.options,
                                  True,
                                  name="generatorY2X")  #F : y->X

        #Discriminator
        # discriminatorY scores G outputs and real Y; discriminatorX scores
        # F outputs and real X.
        self.D_GX = self.discriminator(self.G_X,
                                       self.options,
                                       reuse=False,
                                       name="discriminatorY")
        self.D_FY = self.discriminator(self.F_Y,
                                       self.options,
                                       reuse=False,
                                       name="discriminatorX")
        self.D_Y = self.discriminator(self.patch_Y,
                                      self.options,
                                      reuse=True,
                                      name="discriminatorY")
        self.D_X = self.discriminator(self.patch_X,
                                      self.options,
                                      reuse=True,
                                      name="discriminatorX")

        #### Loss
        #generator loss
        # Total = two adversarial terms (D(fake) pushed towards 1) plus the
        # cycle and identity terms, each weighted by its args coefficient.
        self.cycle_loss = md.cycle_loss(self.patch_X, self.F_GX, self.patch_Y,
                                        self.G_FY, args.L1_lambda)
        self.identity_loss = md.identity_loss(self.patch_X, self.G_Y,
                                              self.patch_Y, self.F_X,
                                              args.L1_gamma)
        self.G_loss_X2Y = md.least_square(self.D_GX, tf.ones_like(self.D_GX))
        self.G_loss_Y2X = md.least_square(self.D_FY, tf.ones_like(self.D_FY))

        self.G_loss = self.G_loss_X2Y + self.G_loss_Y2X + self.cycle_loss + self.identity_loss

        #dicriminator loss
        # Real scores pushed towards 1, scores of the current generator
        # outputs towards 0. The fake terms depend on the generators too, but
        # d_optim only updates d_vars (var_list below).
        self.D_loss_patch_Y = md.least_square(self.D_Y, tf.ones_like(self.D_Y))
        self.D_loss_patch_GX = md.least_square(self.D_GX,
                                               tf.zeros_like(self.D_GX))
        self.D_loss_patch_X = md.least_square(self.D_X, tf.ones_like(self.D_X))
        self.D_loss_patch_FY = md.least_square(self.D_FY,
                                               tf.zeros_like(self.D_FY))

        self.D_loss_Y = (self.D_loss_patch_Y + self.D_loss_patch_GX)
        self.D_loss_X = (self.D_loss_patch_X + self.D_loss_patch_FY)
        self.D_loss = (self.D_loss_X + self.D_loss_Y) / 2

        #### variable list
        # Split by scope-name substring so each optimizer updates one player.
        t_vars = tf.trainable_variables()
        self.d_vars = [var for var in t_vars if 'discriminator' in var.name]
        self.g_vars = [var for var in t_vars if 'generator' in var.name]

        #### optimizer
        # Learning rate is fed at run time through self.lr.
        self.lr = tf.placeholder(tf.float32, None, name='learning_rate')
        self.d_optim = tf.train.AdamOptimizer(self.lr, beta1=args.beta1) \
            .minimize(self.D_loss, var_list=self.d_vars)
        self.g_optim = tf.train.AdamOptimizer(self.lr, beta1=args.beta1) \
            .minimize(self.G_loss, var_list=self.g_vars)
        """
        Summary
        """
        #### loss summary
        #generator
        self.G_loss_sum = tf.summary.scalar("1_G_loss",
                                            self.G_loss,
                                            family='Generator_loss')
        self.cycle_loss_sum = tf.summary.scalar("2_cycle_loss",
                                                self.cycle_loss,
                                                family='Generator_loss')
        self.identity_loss_sum = tf.summary.scalar("3_identity_loss",
                                                   self.identity_loss,
                                                   family='Generator_loss')
        self.G_loss_X2Y_sum = tf.summary.scalar("4_G_loss_X2Y",
                                                self.G_loss_X2Y,
                                                family='Generator_loss')
        self.G_loss_Y2X_sum = tf.summary.scalar("5_G_loss_Y2X",
                                                self.G_loss_Y2X,
                                                family='Generator_loss')
        self.g_sum = tf.summary.merge([
            self.G_loss_sum, self.cycle_loss_sum, self.identity_loss_sum,
            self.G_loss_X2Y_sum, self.G_loss_Y2X_sum
        ])

        #discriminator
        # NOTE(review): "2_D_loss_Y" logs only the real-Y term
        # (D_loss_patch_Y), not the full D_loss_Y.
        self.D_loss_sum = tf.summary.scalar("1_D_loss",
                                            self.D_loss,
                                            family='Discriminator_loss')
        self.D_loss_Y_sum = tf.summary.scalar("2_D_loss_Y",
                                              self.D_loss_patch_Y,
                                              family='Discriminator_loss')
        self.D_loss_GX_sum = tf.summary.scalar("3_D_loss_GX",
                                               self.D_loss_patch_GX,
                                               family='Discriminator_loss')
        self.d_sum = tf.summary.merge(
            [self.D_loss_sum, self.D_loss_Y_sum, self.D_loss_GX_sum])

        #### image summary
        # Each image summary shows input | target | output side by side
        # (concatenated along axis 2).
        self.test_G_X = self.generator(self.test_X,
                                       self.options,
                                       True,
                                       name="generatorX2Y")
        self.train_img_summary = tf.concat(
            [self.patch_X, self.patch_Y, self.G_X], axis=2)
        self.summary_image_1 = tf.summary.image('1_train_patch_image',
                                                self.train_img_summary)
        self.test_img_summary = tf.concat(
            [self.test_X, self.test_Y, self.test_G_X], axis=2)
        self.summary_image_2 = tf.summary.image('2_test_whole_image',
                                                self.test_img_summary)

        #### psnr summary
        # The '-1 ~ 1' notes suggest the last argument (2) is the pixel value
        # range - confirm against ut.tf_psnr.
        self.summary_psnr_ldct = tf.summary.scalar("1_psnr_LDCT",
                                                   ut.tf_psnr(
                                                       self.test_X,
                                                       self.test_Y, 2),
                                                   family='PSNR')  #-1 ~ 1
        self.summary_psnr_result = tf.summary.scalar("2_psnr_output",
                                                     ut.tf_psnr(
                                                         self.test_Y,
                                                         self.test_G_X, 2),
                                                     family='PSNR')  #-1 ~ 1
        self.summary_psnr = tf.summary.merge(
            [self.summary_psnr_ldct, self.summary_psnr_result])

        #model saver
        # max_to_keep=None keeps every checkpoint written.
        self.saver = tf.train.Saver(max_to_keep=None)

        print('--------------------------------------------\n# of parameters : {} '.\
             format(np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()])))
Exemplo n.º 3
0
    def __init__(self, sess, args):
        """Load the DICOM data and build the WGAN-VGG denoising graph.

        A generator (``modules.generator``) maps LDCT patches ``z_i`` to
        denoised patches. A critic (``modules.discriminator``) is trained
        with the WGAN loss plus a gradient penalty, and the generator is
        trained with a VGG19 perceptual loss plus the adversarial term.

        Args:
            sess: session object, stored on ``self.sess``.
            args: parsed options. Read here: ``dcm_path``, ``test_patient_no``,
                ``result``, ``checkpoint_dir``, ``log_dir``,
                ``pretrained_vgg``, ``LDCT_path``, ``NDCT_path``,
                ``whole_size``, ``patch_size``, ``img_channel``, ``trun_max``,
                ``trun_min``, ``is_unpair``, ``augument``, ``norm``,
                ``phase``, ``lambda_`` (gradient-penalty weight),
                ``lambda_1`` (perceptual-loss weight), ``alpha``, ``beta1``
                and ``beta2``.

        Side effects:
            Loads images through ``ut.DCMDataLoader``, adds networks, losses,
            summaries and Adam optimizers to the default TensorFlow graph,
            and prints progress messages.
        """
        self.sess = sess

        #### patient folder names
        # Every entry under dcm_path that is not a zip archive and is not a
        # listed test patient is used for training.
        self.train_patient_no = [
            d.split('/')[-1] for d in glob(args.dcm_path + '/*')
            if 'zip' not in d and d.split('/')[-1] not in args.test_patient_no
        ]
        self.test_patient_no = args.test_patient_no

        #### save directories: one sub-directory per test-patient set
        self.p_info = '_'.join(self.test_patient_no)
        self.checkpoint_dir = os.path.join(args.result, args.checkpoint_dir, self.p_info)
        self.log_dir = os.path.join(args.result, args.log_dir, self.p_info)
        print('directory check!!\ncheckpoint : {}\ntensorboard_logs : {}'.format(
            self.checkpoint_dir, self.log_dir))

        #### set modules (generator, discriminator, vgg net)
        self.g_net = modules.generator
        self.d_net = modules.discriminator
        self.vgg = modules.Vgg19(vgg_path=args.pretrained_vgg)

        #### load images
        print('data load... dicom -> numpy')
        self.image_loader = ut.DCMDataLoader(
            args.dcm_path, args.LDCT_path, args.NDCT_path,
            image_size=args.whole_size, patch_size=args.patch_size,
            depth=args.img_channel, image_max=args.trun_max,
            image_min=args.trun_min, is_unpair=args.is_unpair,
            augument=args.augument, norm=args.norm)

        self.test_image_loader = ut.DCMDataLoader(
            args.dcm_path, args.LDCT_path, args.NDCT_path,
            image_size=args.whole_size, patch_size=args.patch_size,
            depth=args.img_channel, image_max=args.trun_max,
            image_min=args.trun_min, is_unpair=args.is_unpair,
            augument=args.augument, norm=args.norm)

        t1 = time.time()
        if args.phase == 'train':
            self.image_loader(self.train_patient_no)
            self.test_image_loader(self.test_patient_no)
            print('data load complete !!!, {}\nN_train : {}, N_test : {}'.format(
                time.time() - t1, len(self.image_loader.LDCT_image_name),
                len(self.test_image_loader.LDCT_image_name)))
        else:
            self.test_image_loader(self.test_patient_no)
            print('data load complete !!!, {}, N_test : {}'.format(
                time.time() - t1, len(self.test_image_loader.LDCT_image_name)))

        #### placeholders: patch pairs (z = LDCT, x = NDCT) and whole images
        patch_shape = [None, args.patch_size, args.patch_size, args.img_channel]
        whole_shape = [1, args.whole_size, args.whole_size, args.img_channel]
        # NOTE(review): both patch placeholders carry the name 'whole_LDCT'
        # (TensorFlow uniquifies the op names). Kept as-is in case other code
        # looks these tensors up by name.
        self.z_i = tf.placeholder(tf.float32, patch_shape, name='whole_LDCT')
        self.x_i = tf.placeholder(tf.float32, patch_shape, name='whole_LDCT')
        self.whole_z = tf.placeholder(tf.float32, whole_shape, name='whole_LDCT')
        self.whole_x = tf.placeholder(tf.float32, whole_shape, name='whole_NDCT')

        #### generate & discriminate
        # The first call in each scope passes reuse=False to create the
        # variables; later calls use the module's default reuse setting.
        self.G_zi = self.g_net(self.z_i, reuse=False)
        self.G_whole_zi = self.g_net(self.whole_z)

        self.D_xi = self.d_net(self.x_i, reuse=False)
        self.D_G_zi = self.d_net(self.G_zi)

        #### gradient penalty (WGAN-GP)
        # Interpolate real and generated patches with one mixing coefficient
        # per sample, then penalise each sample's critic-gradient L2 norm for
        # deviating from 1.
        batch_size = tf.shape(self.x_i)[0]
        self.epsilon = tf.random_uniform([batch_size, 1, 1, 1], 0.0, 1.0)
        self.x_hat = self.epsilon * self.x_i + (1 - self.epsilon) * self.G_zi
        self.D_x_hat = self.d_net(self.x_hat)
        self.grad_x_hat = tf.gradients(self.D_x_hat, self.x_hat)[0]
        # The norm must cover every non-batch axis of the 4-D [N, H, W, C]
        # gradient. Reducing axis=1 alone gives per-column norms, not one
        # norm per sample.
        self.grad_x_hat_l2 = tf.sqrt(
            tf.reduce_sum(tf.square(self.grad_x_hat), axis=[1, 2, 3]))
        self.gradient_penalty = tf.square(self.grad_x_hat_l2 - 1.0)

        #### perceptual loss
        # The single-channel images are tiled to 3 channels for VGG19.
        self.G_zi_3c = tf.concat([self.G_zi] * 3, axis=3)
        self.xi_3c = tf.concat([self.x_i] * 3, axis=3)
        # NOTE(review): this normalises by the 3-channel input patch size
        # (w*h*d), not by the VGG feature-map size. Confirm this is intended.
        [w, h, d] = self.G_zi_3c.get_shape().as_list()[1:]
        self.vgg_perc_loss = tf.reduce_mean(
            tf.sqrt(tf.reduce_sum(tf.square(
                self.vgg.extract_feature(self.G_zi_3c)
                - self.vgg.extract_feature(self.xi_3c)))) / (w * h * d))

        #### discriminator (critic) loss: WGAN loss + weighted gradient penalty
        d_loss = tf.reduce_mean(self.D_G_zi) - tf.reduce_mean(self.D_xi)
        grad_penal = args.lambda_ * tf.reduce_mean(self.gradient_penalty)
        self.D_loss = d_loss + grad_penal
        #### generator loss: weighted perceptual loss + adversarial term
        self.G_loss = args.lambda_1 * self.vgg_perc_loss - tf.reduce_mean(self.D_G_zi)

        #### variable lists, split by scope name
        t_vars = tf.trainable_variables()
        self.d_vars = [var for var in t_vars if 'discriminator' in var.name]
        self.g_vars = [var for var in t_vars if 'generator' in var.name]

        #### loss summaries
        self.summary_vgg_perc_loss = tf.summary.scalar("1_PerceptualLoss_VGG", self.vgg_perc_loss)
        self.summary_d_loss_all = tf.summary.scalar("2_DiscriminatorLoss_WGAN", self.D_loss)
        self.summary_d_loss_1 = tf.summary.scalar("3_D_loss_disc", d_loss)
        self.summary_d_loss_2 = tf.summary.scalar("4_D_loss_gradient_penalty", grad_penal)
        self.summary_g_loss = tf.summary.scalar("GeneratorLoss", self.G_loss)
        self.summary_all_loss = tf.summary.merge([
            self.summary_vgg_perc_loss, self.summary_d_loss_all,
            self.summary_d_loss_1, self.summary_d_loss_2, self.summary_g_loss])

        #### PSNR summaries on whole images
        # The '0 ~ 1' notes suggest the last argument (1) is the pixel value
        # range - confirm against ut.tf_psnr.
        self.summary_psnr_ldct = tf.summary.scalar(
            "1_psnr_LDCT", ut.tf_psnr(self.whole_z, self.whole_x, 1), family='PSNR')  # 0 ~ 1
        self.summary_psnr_result = tf.summary.scalar(
            "2_psnr_output", ut.tf_psnr(self.whole_x, self.G_whole_zi, 1), family='PSNR')  # 0 ~ 1
        self.summary_psnr = tf.summary.merge([self.summary_psnr_ldct, self.summary_psnr_result])

        #### image summaries: input | target | output, side by side (axis 2)
        # The training summary shows only the first patch of the batch.
        self.check_img_summary = tf.concat([tf.expand_dims(self.z_i[0], axis=0),
                                            tf.expand_dims(self.x_i[0], axis=0),
                                            tf.expand_dims(self.G_zi[0], axis=0)], axis=2)
        self.summary_train_image = tf.summary.image('0_train_image', self.check_img_summary)
        self.whole_img_summary = tf.concat([self.whole_z, self.whole_x, self.G_whole_zi], axis=2)
        self.summary_image = tf.summary.image('1_whole_image', self.whole_img_summary)

        #### optimizers
        # Run any ops registered in UPDATE_OPS (e.g. batch-norm statistics)
        # before each optimizer step.
        self.d_adam, self.g_adam = None, None
        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            self.d_adam = tf.train.AdamOptimizer(
                learning_rate=args.alpha, beta1=args.beta1, beta2=args.beta2
            ).minimize(self.D_loss, var_list=self.d_vars)
            self.g_adam = tf.train.AdamOptimizer(
                learning_rate=args.alpha, beta1=args.beta1, beta2=args.beta2
            ).minimize(self.G_loss, var_list=self.g_vars)

        #### model saver (keeps every checkpoint)
        self.saver = tf.train.Saver(max_to_keep=None)

        print('--------------------------------------------\n# of parameters : {} '.format(
            np.sum([np.prod(v.get_shape().as_list())
                    for v in tf.trainable_variables()])))
Exemplo n.º 4
0
    def __init__(self, sess, args):
        """Load the DICOM data and build the ``modules.redcnn`` denoising graph.

        The network maps LDCT patches ``X`` to outputs that are trained
        against NDCT patches ``Y`` with a mean-squared-error loss.

        Args:
            sess: session object, stored on ``self.sess``.
            args: parsed options. Read here: ``dcm_path``, ``test_patient_no``,
                ``result``, ``checkpoint_dir``, ``log_dir``, ``LDCT_path``,
                ``NDCT_path``, ``whole_size``, ``patch_size``,
                ``img_channel``, ``trun_max``, ``trun_min``, ``is_unpair``,
                ``augument``, ``norm``, ``phase``, ``alpha`` (initial
                learning rate), ``num_iter`` and ``decay_rate``.

        Side effects:
            Loads images through ``ut.DCMDataLoader``, adds the network, loss,
            optimizer, step counter and summaries to the default TensorFlow
            graph, and prints progress messages.
        """
        self.sess = sess

        ####patients folder name
        # NOTE: attribute names are spelled 'patent' (sic); kept unchanged
        # because other methods may reference them.
        # Training patients: every entry under dcm_path that is not a zip
        # archive and is not listed in args.test_patient_no.
        self.train_patent_no = [
            d.split('/')[-1] for d in glob(args.dcm_path + '/*')
            if ('zip' not in d)
            & (d.split('/')[-1] not in args.test_patient_no)
        ]
        self.test_patent_no = args.test_patient_no

        #save directory
        # One sub-directory per test-patient set.
        self.p_info = '_'.join(self.test_patent_no)
        self.checkpoint_dir = os.path.join(args.result, args.checkpoint_dir,
                                           self.p_info)
        self.log_dir = os.path.join(args.result, args.log_dir, self.p_info)
        print(
            'directory check!!\ncheckpoint : {}\ntensorboard_logs : {}'.format(
                self.checkpoint_dir, self.log_dir))

        #### set modules
        self.red_cnn = modules.redcnn
        """
        load images
        """
        print('data load... dicom -> numpy')
        self.image_loader = ut.DCMDataLoader(\
              args.dcm_path, args.LDCT_path, args.NDCT_path, \
             image_size = args.whole_size, patch_size = args.patch_size, \
             depth = args.img_channel, image_max = args.trun_max, image_min = args.trun_min,\
             is_unpair = args.is_unpair, augument = args.augument, norm = args.norm)

        self.test_image_loader = ut.DCMDataLoader(\
             args.dcm_path, args.LDCT_path, args.NDCT_path,\
             image_size = args.whole_size, patch_size = args.patch_size, \
              depth = args.img_channel, image_max = args.trun_max, image_min = args.trun_min,\
             is_unpair = args.is_unpair, augument = args.augument, norm = args.norm)

        # Training data is loaded only in the 'train' phase; the test set is
        # always loaded.
        t1 = time.time()
        if args.phase == 'train':
            self.image_loader(self.train_patent_no)
            self.test_image_loader(self.test_patent_no)
            print(
                'data load complete !!!, {}\nN_train : {}, N_test : {}'.format(
                    time.time() - t1, len(self.image_loader.LDCT_image_name),
                    len(self.test_image_loader.LDCT_image_name)))
        else:
            self.test_image_loader(self.test_patent_no)
            print('data load complete !!!, {}, N_test : {}'.format(
                time.time() - t1, len(self.test_image_loader.LDCT_image_name)))
        """
        build model
        """
        # Patch placeholders (variable batch) and whole-image placeholders
        # (batch of exactly 1).
        self.X = tf.placeholder(
            tf.float32,
            [None, args.patch_size, args.patch_size, args.img_channel],
            name='LDCT')
        self.Y = tf.placeholder(
            tf.float32,
            [None, args.patch_size, args.patch_size, args.img_channel],
            name='NDCT')
        self.whole_X = tf.placeholder(
            tf.float32,
            [1, args.whole_size, args.whole_size, args.img_channel],
            name='whole_LDCT')
        self.whole_Y = tf.placeholder(
            tf.float32,
            [1, args.whole_size, args.whole_size, args.img_channel],
            name='whole_NDCT')

        #### denoised images
        # The first call passes reuse=False to create the variables; the
        # whole-image call uses the module's default reuse setting.
        self.output_img = self.red_cnn(self.X, reuse=False)
        self.WHOLE_output_img = self.red_cnn(self.whole_X)

        #### loss
        # Mean squared error between the NDCT target and the network output.
        self.loss = tf.reduce_mean(
            tf.squared_difference(self.Y, self.output_img))

        #### trainable variable list
        self.t_vars = tf.trainable_variables()

        #### optimizer
        # Staircase exponential decay: the learning rate is multiplied by
        # args.decay_rate every args.num_iter steps. minimize() increments
        # global_step once per update.
        self.global_step = tf.Variable(0, trainable=False)
        self.learning_rate = tf.train.exponential_decay(args.alpha,
                                                        self.global_step,
                                                        args.num_iter,
                                                        args.decay_rate,
                                                        staircase=True)
        self.optimizer = tf.train.AdamOptimizer(
            learning_rate=self.learning_rate).minimize(
                self.loss, var_list=self.t_vars, global_step=self.global_step)
        """
        summary
        """
        #loss summary
        self.summary_loss = tf.summary.scalar("loss", self.loss)
        #psnr summary
        # The '0 ~ 1' notes suggest the last argument (1) is the pixel value
        # range - confirm against ut.tf_psnr.
        self.summary_psnr_ldct = tf.summary.scalar("1_psnr_LDCT",
                                                   ut.tf_psnr(
                                                       self.whole_Y,
                                                       self.whole_X, 1),
                                                   family='PSNR')  # 0 ~ 1
        self.summary_psnr_result = tf.summary.scalar(
            "2_psnr_output",
            ut.tf_psnr(self.whole_Y, self.WHOLE_output_img, 1),
            family='PSNR')  # 0 ~ 1
        self.summary_psnr = tf.summary.merge(
            [self.summary_psnr_ldct, self.summary_psnr_result])

        #image summary
        # input | target | output side by side (axis 2). The training summary
        # shows only the first patch of the batch.
        self.check_img_summary = tf.concat([tf.expand_dims(self.X[0], axis=0), \
                                    tf.expand_dims(self.Y[0], axis=0), \
                                    tf.expand_dims(self.output_img[0], axis=0)], axis = 2)
        self.summary_train_image = tf.summary.image('0_train_image',
                                                    self.check_img_summary)
        self.whole_img_summary = tf.concat(
            [self.whole_X, self.whole_Y, self.WHOLE_output_img], axis=2)
        self.summary_image = tf.summary.image('1_whole_image',
                                              self.whole_img_summary)

        #model saver
        # max_to_keep=None keeps every checkpoint written.
        self.saver = tf.train.Saver(max_to_keep=None)

        print('--------------------------------------------\n# of parameters : {} '.\
              format(np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()])))
Exemplo n.º 5
0
    def __init__(self, sess, args):
        """Load CT data and build the `modules.redcnn` denoising graph.

        Args:
            sess: TensorFlow session. In the 'train' phase it is passed to
                ``self.image_loader.input_pipeline``, which returns the
                training input/target tensors.
            args: parsed option namespace. Fields read here: ``dcm_path``,
                ``test_patient_no``, ``LDCT_path``, ``NDCT_path``,
                ``whole_size``, ``patch_size``, ``img_channel``,
                ``img_vmax``, ``img_vmin``, ``batch_size``, ``model``,
                ``phase``, ``num_iter``, ``alpha`` and ``mayo_roi``.

        In the 'train' phase ``self.X`` / ``self.Y`` are the tensors returned
        by the loader's input pipeline. In any other phase they are feedable
        placeholders named 'LDCT' / 'NDCT'. The network output is trained
        against ``self.Y`` with a mean-squared-error loss using Adam.
        TensorBoard summaries for the loss, PSNR and sample images are
        created, together with a ``tf.train.Saver``.
        """
        self.sess = sess

        #### patients folder name
        # Training patients are the entries under `dcm_path` whose path does
        # not contain 'zip' and whose *folder name* is not a test patient.
        # `test_patient_no` is handed to the loader exactly like the folder
        # names built here, so the basename (not the full glob path) must be
        # compared. Comparing the full path never matched, which silently
        # put the test patients into the training set.
        # (The `patent` attribute spelling is kept for existing callers.)
        self.train_patent_no = [
            d.split('/')[-1] for d in glob(args.dcm_path + '/*')
            if ('zip' not in d) and (d.split('/')[-1] not in args.test_patient_no)
        ]
        self.test_patent_no = args.test_patient_no

        #### set modules
        self.red_cnn = modules.redcnn
        """
        load images
        """
        print('data load... dicom -> numpy')
        # Both loaders share the same configuration; they differ only in
        # which patient list they are later called with.
        self.image_loader = ut.DCMDataLoader(
            args.dcm_path, args.LDCT_path, args.NDCT_path,
            image_size=args.whole_size, patch_size=args.patch_size,
            depth=args.img_channel,
            image_max=args.img_vmax, image_min=args.img_vmin,
            batch_size=args.batch_size, model=args.model)

        self.test_image_loader = ut.DCMDataLoader(
            args.dcm_path, args.LDCT_path, args.NDCT_path,
            image_size=args.whole_size, patch_size=args.patch_size,
            depth=args.img_channel,
            image_max=args.img_vmax, image_min=args.img_vmin,
            batch_size=args.batch_size, model=args.model)

        t1 = time.time()
        if args.phase == 'train':
            self.image_loader(self.train_patent_no)
            self.test_image_loader(self.test_patent_no)
            print(
                'data load complete !!!, {}\nN_train : {}, N_test : {}'.format(
                    time.time() - t1, len(self.image_loader.LDCT_image_name),
                    len(self.test_image_loader.LDCT_image_name)))
            # Training inputs/targets come straight from the loader's
            # pipeline instead of being fed through placeholders.
            [self.X, self.Y
             ] = self.image_loader.input_pipeline(self.sess, args.patch_size,
                                                  args.num_iter)
        else:
            self.test_image_loader(self.test_patent_no)
            print('data load complete !!!, {}, N_test : {}'.format(
                time.time() - t1, len(self.test_image_loader.LDCT_image_name)))
            # Outside training, patch inputs/targets are fed explicitly.
            self.X = tf.placeholder(
                tf.float32,
                [None, args.patch_size, args.patch_size, args.img_channel],
                name='LDCT')
            self.Y = tf.placeholder(
                tf.float32,
                [None, args.patch_size, args.patch_size, args.img_channel],
                name='NDCT')
        """
        build model
        """
        # Single whole-size image pair (batch of 1), used for the PSNR and
        # whole-image summaries below.
        self.whole_X = tf.placeholder(
            tf.float32,
            [1, args.whole_size, args.whole_size, args.img_channel],
            name='whole_LDCT')
        self.whole_Y = tf.placeholder(
            tf.float32,
            [1, args.whole_size, args.whole_size, args.img_channel],
            name='whole_NDCT')

        #### denoised images
        # The first call passes reuse=False. The later calls rely on
        # `modules.redcnn`'s default `reuse` to share those weights
        # (presumably True -- confirm in `modules`).
        self.output_img = self.red_cnn(self.X, reuse=False)
        self.WHOLE_output_img = self.red_cnn(self.whole_X)

        #### loss
        self.loss = tf.reduce_mean(
            tf.squared_difference(self.Y, self.output_img))

        #### trainable variable list
        self.t_vars = tf.trainable_variables()

        #### optimizer
        self.optimizer = tf.train.AdamOptimizer(
            learning_rate=args.alpha).minimize(self.loss, var_list=self.t_vars)
        """
        summary
        """
        #loss summary
        self.summary_loss = tf.summary.scalar("loss", self.loss)
        #psnr summary
        # PSNR of the raw LDCT input and of the network output, both against
        # the NDCT target. The third argument (1) is presumably the peak
        # intensity of images normalized to [0, 1] -- confirm in ut.tf_psnr.
        self.summary_psnr_ldct = tf.summary.scalar("1_psnr_LDCT",
                                                   ut.tf_psnr(
                                                       self.whole_Y,
                                                       self.whole_X, 1),
                                                   family='PSNR')  # 0 ~ 1
        self.summary_psnr_result = tf.summary.scalar(
            "2_psnr_output",
            ut.tf_psnr(self.whole_Y, self.WHOLE_output_img, 1),
            family='PSNR')  # 0 ~ 1
        self.summary_psnr = tf.summary.merge(
            [self.summary_psnr_ldct, self.summary_psnr_result])

        #image summary
        # input | target | output, concatenated along axis 2 (the second
        # spatial axis of the NHWC tensors declared above).
        self.check_img_summary = tf.concat(
            [tf.expand_dims(self.X[0], axis=0),
             tf.expand_dims(self.Y[0], axis=0),
             tf.expand_dims(self.output_img[0], axis=0)], axis=2)
        self.summary_train_image = tf.summary.image('0_train_image',
                                                    self.check_img_summary)
        self.whole_img_summary = tf.concat(
            [self.whole_X, self.whole_Y, self.WHOLE_output_img], axis=2)
        self.summary_image = tf.summary.image('1_whole_image',
                                              self.whole_img_summary)

        #ROI summary
        if args.mayo_roi:
            # Fixed 128x128 region-of-interest crops, fed at summary time.
            self.ROI_X = tf.placeholder(tf.float32,
                                        [None, 128, 128, args.img_channel],
                                        name='ROI_X')
            self.ROI_Y = tf.placeholder(tf.float32,
                                        [None, 128, 128, args.img_channel],
                                        name='ROI_Y')
            self.ROI_output = self.red_cnn(self.ROI_X)

            self.ROI_real_img_summary = tf.concat(
                [self.ROI_X, self.ROI_Y, self.ROI_output], axis=2)
            # Two summary ops over the same tensor, presumably so two
            # different ROIs can be fed and logged under separate tags.
            self.summary_ROI_image_1 = tf.summary.image(
                '2_ROI_image_1', self.ROI_real_img_summary)
            self.summary_ROI_image_2 = tf.summary.image(
                '3_ROI_image_2', self.ROI_real_img_summary)

        #model saver
        # max_to_keep=None keeps every checkpoint that is written.
        self.saver = tf.train.Saver(max_to_keep=None)

        print('--------------------------------------------\n# of parameters : {} '.format(
            np.sum([np.prod(v.get_shape().as_list())
                    for v in tf.trainable_variables()])))
Exemplo n.º 6
0
    def __init__(self, args):
        """Prepare output directories, load CT data and build the GAN models.

        Args:
            args: parsed option namespace. Fields read here: ``taskID``,
                ``checkpoint_dir``, ``ngf``, ``nglf``, ``ndf``,
                ``img_channel``, ``phase``, ``data_path``, ``whole_size``,
                ``patch_size``, ``img_vmax``, ``img_vmin``, ``batch_size``,
                ``extension``, ``train_patient_no_A``, ``train_patient_no_B``,
                ``test_patient_no_A``, ``test_patient_no_B``, ``lr``,
                ``beta1`` and ``beta2``.

        Sets up, in order:
        - the checkpoint and TensorBoard directory paths;
        - the network ``options`` tuple;
        - the data sets (training patch sets plus whole-image test sets in
          the 'train' phase, only the test sets otherwise);
        - two generators (``generatorX2Y`` / ``generatorY2X``) and two
          discriminators;
        - a ``tf.train.Checkpoint`` holding them and their Adam optimizers;
        - a summary file writer.
        """
        # save directory
        # Use args.taskID when it is truthy; otherwise generate a new ID.
        if args.taskID:
            self.taskID = args.taskID
        else:
            self.taskID = ut.TaskID_Generator()
        # Checkpoints: <checkpoint_dir>/<taskID>
        # TensorBoard logs: <checkpoint_dir>/<taskID>_tb
        self.checkpoint_dir = os.path.join(args.checkpoint_dir, self.taskID)
        self.log_dir = os.path.join(args.checkpoint_dir, self.taskID + '_tb')
        print(
            'directory check!!\ncheckpoint : {}\ntensorboard_logs : {}'.format(
                self.checkpoint_dir, self.log_dir))

        # network options
        # Bundle of settings handed to every generator/discriminator. The
        # *_dim fields presumably set network widths (confirm in md).
        # is_training is True only in the 'train' phase.
        OPTIONS = namedtuple(
            'OPTIONS', 'gf_dim glf_dim df_dim \
                              img_channel is_training')
        self.options = OPTIONS._make((args.ngf, args.nglf, args.ndf,
                                      args.img_channel, args.phase == 'train'))
        """
        load images
        """
        print('data load... dicom -> numpy')

        t1 = time.time()
        if args.phase == 'train':
            # Training uses two loaders:
            # - training patients -> patch sets
            # - test patients -> whole-image sets
            self.train_image_loader = ut.DCMDataLoader(
                args.data_path,
                image_size=args.whole_size,
                patch_size=args.patch_size,
                image_max=args.img_vmax,
                image_min=args.img_vmin,
                batch_size=args.batch_size,
                extension=args.extension)
            self.test_image_loader = ut.DCMDataLoader(
                args.data_path,
                image_size=args.whole_size,
                patch_size=args.patch_size,
                image_max=args.img_vmax,
                image_min=args.img_vmin,
                batch_size=args.batch_size,
                extension=args.extension)
            # Each loader is called with one patient list per domain (A, B).
            self.train_image_loader(args.train_patient_no_A,
                                    args.train_patient_no_B)
            self.test_image_loader(args.test_patient_no_A,
                                   args.test_patient_no_B)
            self.patch_X_set, self.patch_Y_set = self.train_image_loader.get_train_set(
                args.patch_size)
            self.whole_X_set, self.whole_Y_set = self.test_image_loader.get_test_set(
            )
            print('data load complete !!!, {}\n'.format(time.time() - t1))
            print('N_train : {}, N_test : {}'.format(
                self.train_image_loader.LDCT_images_size,
                self.test_image_loader.LDCT_images_size))
        else:
            # NOTE(review): only this loader receives `phase`; the
            # training-phase loaders above rely on DCMDataLoader's default
            # for it -- confirm that is intended.
            self.test_image_loader = ut.DCMDataLoader(
                args.data_path,
                image_size=args.whole_size,
                patch_size=args.patch_size,
                image_max=args.img_vmax,
                image_min=args.img_vmin,
                batch_size=args.batch_size,
                extension=args.extension,
                phase=args.phase)
            self.test_image_loader(args.test_patient_no_A,
                                   args.test_patient_no_B)
            self.whole_X_set, self.whole_Y_set = self.test_image_loader.get_test_set(
            )
            print('data load complete !!!, {}, N_test : {}'.format(
                time.time() - t1, self.test_image_loader.LDCT_images_size))
        """
        build model
        """
        # The models are built for patch-sized inputs when training and for
        # whole-image inputs otherwise.
        if args.phase == 'train':
            input_shape = (args.patch_size, args.patch_size, args.img_channel)
        else:
            input_shape = (args.whole_size, args.whole_size, args.img_channel)
        # Generator
        # G maps X -> Y and F maps Y -> X (per their names).
        self.generator_G = md.generator(input_shape,
                                        self.options,
                                        name="generatorX2Y")
        self.generator_F = md.generator(input_shape,
                                        self.options,
                                        name="generatorY2X")
        # Discriminator
        # One discriminator per domain.
        self.discriminator_X = md.discriminator(input_shape,
                                                self.options,
                                                name="discriminatorX")
        self.discriminator_Y = md.discriminator(input_shape,
                                                self.options,
                                                name="discriminatorY")
        """
        set check point
        """
        # The checkpoint tracks:
        # - an int64 step counter;
        # - all four models;
        # - one Adam optimizer for the generators and one for the
        #   discriminators, both built from args.lr/beta1/beta2.
        # The optimizers are reachable only through self.ckpt.
        self.ckpt = tf.train.Checkpoint(
            step=tf.Variable(0, dtype=tf.int64),
            generator_G=self.generator_G,
            generator_F=self.generator_F,
            discriminator_X=self.discriminator_X,
            discriminator_Y=self.discriminator_Y,
            generator_optimizer=tf.keras.optimizers.Adam(learning_rate=args.lr,
                                                         beta_1=args.beta1,
                                                         beta_2=args.beta2),
            discriminator_optimizer=tf.keras.optimizers.Adam(
                learning_rate=args.lr, beta_1=args.beta1, beta_2=args.beta2))
        """
        Summary writer (TensorBoard)
        """
        self.writer = tf.summary.create_file_writer(self.log_dir)