Example 1
    def __init__(self, sess, args):
        self.sess = sess

        #### patients folder name
        self.train_patient_no = [
            d.split('/')[-1] for d in glob(args.dcm_path + '/*')
            if ('zip' not in d)
            and (d.split('/')[-1] not in args.test_patient_no)
        ]
        self.test_patient_no = args.test_patient_no

        #save directory
        self.p_info = '_'.join(self.test_patient_no)
        self.checkpoint_dir = os.path.join('.', args.checkpoint_dir,
                                           self.p_info)
        self.log_dir = os.path.join('.', 'logs', self.p_info)
        print(
            'directory check!!\ncheckpoint : {}\ntensorboard_logs : {}'.format(
                self.checkpoint_dir, self.log_dir))

        #module
        self.discriminator = md.discriminator
        self.generator = md.generator

        #network options
        OPTIONS = namedtuple(
            'OPTIONS', 'gf_dim glf_dim df_dim img_channel is_training')
        self.options = OPTIONS._make((args.ngf, args.nglf, args.ndf,
                                      args.img_channel, args.phase == 'train'))
        """
        load images
        """
        print('data load... dicom -> numpy')
        self.image_loader = ut.DCMDataLoader(
            args.dcm_path, args.LDCT_path, args.NDCT_path,
            image_size=args.whole_size, patch_size=args.patch_size,
            depth=args.img_channel, image_max=args.img_vmax,
            image_min=args.img_vmin, batch_size=args.batch_size,
            is_unpair=args.unpair, model=args.model)

        self.test_image_loader = ut.DCMDataLoader(
            args.dcm_path, args.LDCT_path, args.NDCT_path,
            image_size=args.whole_size, patch_size=args.patch_size,
            depth=args.img_channel, image_max=args.img_vmax,
            image_min=args.img_vmin, batch_size=args.batch_size,
            is_unpair=args.unpair, model=args.model)

        t1 = time.time()
        if args.phase == 'train':
            self.image_loader(self.train_patient_no)
            self.test_image_loader(self.test_patient_no)
            print(
                'data load complete !!!, {}\nN_train : {}, N_test : {}'.format(
                    time.time() - t1, len(self.image_loader.LDCT_image_name),
                    len(self.test_image_loader.LDCT_image_name)))
            [self.patch_X, self.patch_Y
             ] = self.image_loader.input_pipeline(self.sess, args.patch_size,
                                                  args.end_epoch)
        else:
            self.test_image_loader(self.test_patient_no)
            print('data load complete !!!, {}, N_test : {}'.format(
                time.time() - t1, len(self.test_image_loader.LDCT_image_name)))

            self.patch_X = tf.placeholder(
                tf.float32,
                [None, args.patch_size, args.patch_size, args.img_channel],
                name='LDCT')
            self.patch_Y = tf.placeholder(
                tf.float32,
                [None, args.patch_size, args.patch_size, args.img_channel],
                name='NDCT')
        """
        build model
        """
        #### image placeholders (for test)
        self.test_X = tf.placeholder(
            tf.float32,
            [None, args.whole_size, args.whole_size, args.img_channel],
            name='X')
        self.test_Y = tf.placeholder(
            tf.float32,
            [None, args.whole_size, args.whole_size, args.img_channel],
            name='Y')

        #### Generator & Discriminator
        #Generator
        self.G_X = self.generator(self.patch_X,
                                  self.options,
                                  False,
                                  name="generatorX2Y")
        self.F_GX = self.generator(self.G_X,
                                   self.options,
                                   False,
                                   name="generatorY2X")
        self.F_Y = self.generator(self.patch_Y,
                                  self.options,
                                  True,
                                  name="generatorY2X")
        self.G_FY = self.generator(self.F_Y,
                                   self.options,
                                   True,
                                   name="generatorX2Y")

        self.G_Y = self.generator(self.patch_Y,
                                  self.options,
                                  True,
                                  name="generatorX2Y")  # G : X -> Y
        self.F_X = self.generator(self.patch_X,
                                  self.options,
                                  True,
                                  name="generatorY2X")  # F : Y -> X

        #Discriminator
        self.D_GX = self.discriminator(self.G_X,
                                       self.options,
                                       reuse=False,
                                       name="discriminatorY")
        self.D_FY = self.discriminator(self.F_Y,
                                       self.options,
                                       reuse=False,
                                       name="discriminatorX")
        self.D_Y = self.discriminator(self.patch_Y,
                                      self.options,
                                      reuse=True,
                                      name="discriminatorY")
        self.D_X = self.discriminator(self.patch_X,
                                      self.options,
                                      reuse=True,
                                      name="discriminatorX")

        #### Loss
        #generator loss
        self.cycle_loss = md.cycle_loss(self.patch_X, self.F_GX, self.patch_Y,
                                        self.G_FY, args.L1_lambda)
        self.identity_loss = md.identity_loss(self.patch_X, self.G_Y,
                                              self.patch_Y, self.F_X,
                                              args.L1_gamma)
        self.G_loss_X2Y = md.least_square(self.D_GX, tf.ones_like(self.D_GX))
        self.G_loss_Y2X = md.least_square(self.D_FY, tf.ones_like(self.D_FY))

        self.G_loss = self.G_loss_X2Y + self.G_loss_Y2X + self.cycle_loss + self.identity_loss

        #discriminator loss
        self.D_loss_patch_Y = md.least_square(self.D_Y, tf.ones_like(self.D_Y))
        self.D_loss_patch_GX = md.least_square(self.D_GX,
                                               tf.zeros_like(self.D_GX))
        self.D_loss_patch_X = md.least_square(self.D_X, tf.ones_like(self.D_X))
        self.D_loss_patch_FY = md.least_square(self.D_FY,
                                               tf.zeros_like(self.D_FY))

        self.D_loss_Y = (self.D_loss_patch_Y + self.D_loss_patch_GX)
        self.D_loss_X = (self.D_loss_patch_X + self.D_loss_patch_FY)
        self.D_loss = (self.D_loss_X + self.D_loss_Y) / 2

        #### variable list
        t_vars = tf.trainable_variables()
        self.d_vars = [var for var in t_vars if 'discriminator' in var.name]
        self.g_vars = [var for var in t_vars if 'generator' in var.name]

        #### optimizer
        self.lr = tf.placeholder(tf.float32, None, name='learning_rate')
        self.d_optim = tf.train.AdamOptimizer(self.lr, beta1=args.beta1) \
            .minimize(self.D_loss, var_list=self.d_vars)
        self.g_optim = tf.train.AdamOptimizer(self.lr, beta1=args.beta1) \
            .minimize(self.G_loss, var_list=self.g_vars)
        """
        Summary
        """
        #### loss summary
        #generator
        self.G_loss_sum = tf.summary.scalar("1_G_loss",
                                            self.G_loss,
                                            family='Generator_loss')
        self.cycle_loss_sum = tf.summary.scalar("2_cycle_loss",
                                                self.cycle_loss,
                                                family='Generator_loss')
        self.identity_loss_sum = tf.summary.scalar("3_identity_loss",
                                                   self.identity_loss,
                                                   family='Generator_loss')
        self.G_loss_X2Y_sum = tf.summary.scalar("4_G_loss_X2Y",
                                                self.G_loss_X2Y,
                                                family='Generator_loss')
        self.G_loss_Y2X_sum = tf.summary.scalar("5_G_loss_Y2X",
                                                self.G_loss_Y2X,
                                                family='Generator_loss')
        self.g_sum = tf.summary.merge([
            self.G_loss_sum, self.cycle_loss_sum, self.identity_loss_sum,
            self.G_loss_X2Y_sum, self.G_loss_Y2X_sum
        ])

        #discriminator
        self.D_loss_sum = tf.summary.scalar("1_D_loss",
                                            self.D_loss,
                                            family='Discriminator_loss')
        self.D_loss_Y_sum = tf.summary.scalar("2_D_loss_Y",
                                              self.D_loss_patch_Y,
                                              family='Discriminator_loss')
        self.D_loss_GX_sum = tf.summary.scalar("3_D_loss_GX",
                                               self.D_loss_patch_GX,
                                               family='Discriminator_loss')
        self.d_sum = tf.summary.merge(
            [self.D_loss_sum, self.D_loss_Y_sum, self.D_loss_GX_sum])

        #### image summary
        self.test_G_X = self.generator(self.test_X,
                                       self.options,
                                       True,
                                       name="generatorX2Y")
        self.train_img_summary = tf.concat(
            [self.patch_X, self.patch_Y, self.G_X], axis=2)
        self.summary_image_1 = tf.summary.image('1_train_patch_image',
                                                self.train_img_summary)
        self.test_img_summary = tf.concat(
            [self.test_X, self.test_Y, self.test_G_X], axis=2)
        self.summary_image_2 = tf.summary.image('2_test_whole_image',
                                                self.test_img_summary)

        #### psnr summary
        self.summary_psnr_ldct = tf.summary.scalar("1_psnr_LDCT",
                                                   ut.tf_psnr(
                                                       self.test_X,
                                                       self.test_Y, 2),
                                                   family='PSNR')  #-1 ~ 1
        self.summary_psnr_result = tf.summary.scalar("2_psnr_output",
                                                     ut.tf_psnr(
                                                         self.test_Y,
                                                         self.test_G_X, 2),
                                                     family='PSNR')  #-1 ~ 1
        self.summary_psnr = tf.summary.merge(
            [self.summary_psnr_ldct, self.summary_psnr_result])

        #model saver
        self.saver = tf.train.Saver(max_to_keep=None)

        print('--------------------------------------------\n# of parameters : {} '.\
             format(np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()])))
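The snippet above calls several loss helpers from `md` that are not shown. A minimal sketch of what the call sites imply they compute (an LSGAN least-squares term, plus CycleGAN cycle-consistency and identity terms); the actual module may differ:

import tensorflow as tf

def least_square(logits, labels):
    # LSGAN objective: mean squared error between D outputs and targets
    return tf.reduce_mean(tf.squared_difference(logits, labels))

def cycle_loss(X, F_GX, Y, G_FY, l1_lambda):
    # cycle consistency: X -> G -> F should reconstruct X, and Y likewise
    return l1_lambda * (tf.reduce_mean(tf.abs(X - F_GX)) +
                        tf.reduce_mean(tf.abs(Y - G_FY)))

def identity_loss(X, G_Y, Y, F_X, l1_gamma):
    # identity mapping: G(Y) should stay close to Y, F(X) close to X
    return l1_gamma * (tf.reduce_mean(tf.abs(G_Y - Y)) +
                       tf.reduce_mean(tf.abs(F_X - X)))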
Example 2
    def __init__(self, sess, args):
        self.sess = sess

        #### patients folder name
        self.train_patient_no = [
            d.split('/')[-1] for d in glob(args.dcm_path + '/*')
            if ('zip' not in d)
            and (d.split('/')[-1] not in args.test_patient_no)
        ]
        self.test_patient_no = args.test_patient_no

        #save directory
        self.p_info = '_'.join(self.test_patient_no)
        self.checkpoint_dir = os.path.join(args.result, args.checkpoint_dir,
                                           self.p_info)
        self.log_dir = os.path.join(args.result, args.log_dir, self.p_info)
        print('directory check!!\ncheckpoint : {}\ntensorboard_logs : {}'.format(\
            self.checkpoint_dir, self.log_dir))

        #module
        self.discriminator = md.discriminator
        self.generator = md.generator
        """
        load images
        """
        print('data load... dicom -> numpy')
        self.image_loader = ut.DCMDataLoader(
            args.dcm_path, args.LDCT_path, args.NDCT_path,
            image_size=args.whole_size, patch_size=args.patch_size,
            depth=args.img_channel, image_max=args.trun_max,
            image_min=args.trun_min, is_unpair=args.is_unpair,
            augument=args.augument, norm=args.norm)

        self.test_image_loader = ut.DCMDataLoader(
            args.dcm_path, args.LDCT_path, args.NDCT_path,
            image_size=args.whole_size, patch_size=args.patch_size,
            depth=args.img_channel, image_max=args.trun_max,
            image_min=args.trun_min, is_unpair=args.is_unpair,
            augument=args.augument, norm=args.norm)

        t1 = time.time()
        if args.phase == 'train':
            self.image_loader(self.train_patient_no)
            self.test_image_loader(self.test_patient_no)
            print('data load complete !!!, {}\nN_train : {}, N_test : {}'.format(\
                time.time() - t1, len(self.image_loader.LDCT_image_name), \
                len(self.test_image_loader.LDCT_image_name)))

        else:
            self.test_image_loader(self.test_patient_no)
            print('data load complete !!!, {}, N_test : {}'.format(\
                time.time() - t1, len(self.test_image_loader.LDCT_image_name)))
        """
        build model
        """
        self.real_X = tf.placeholder(
            tf.float32,
            [None, args.patch_size, args.patch_size, args.img_channel],
            name='LDCT')
        self.real_Y = tf.placeholder(
            tf.float32,
            [None, args.patch_size, args.patch_size, args.img_channel],
            name='NDCT')
        self.sample_GX = tf.placeholder(
            tf.float32,
            [None, args.patch_size, args.patch_size, args.img_channel],
            name='G_LDCT')
        self.sample_FY = tf.placeholder(
            tf.float32,
            [None, args.patch_size, args.patch_size, args.img_channel],
            name='F_NDCT')

        # distinct names avoid colliding with the patch placeholders above
        self.whole_X = tf.placeholder(
            tf.float32,
            [1, args.whole_size, args.whole_size, args.img_channel],
            name='whole_LDCT')
        self.whole_Y = tf.placeholder(
            tf.float32,
            [1, args.whole_size, args.whole_size, args.img_channel],
            name='whole_NDCT')

        #### Generator & Discriminator
        #Generator
        self.G_X = self.generator(args, self.real_X, False, \
            name="generatorX2Y")
        self.F_GX = self.generator(args, self.G_X, False, \
            name="generatorY2X")
        self.F_Y = self.generator(args, self.real_Y, True, \
            name="generatorY2X")
        self.G_FY = self.generator(args, self.F_Y, True, \
            name="generatorX2Y")

        self.G_Y = self.generator(args, self.real_Y, True, \
        name="generatorX2Y")
        self.F_X = self.generator(args, self.real_X, True, \
        name="generatorY2X")

        #Discriminator
        self.D_GX = self.discriminator(args, self.G_X, reuse=False, \
            name="discriminatorY")
        self.D_FY = self.discriminator(args, self.F_Y, reuse=False, \
            name="discriminatorX")
        self.D_sample_GX = self.discriminator(args, self.sample_GX, \
                  reuse=True, name="discriminatorY") #for discriminator loss
        self.D_sample_FY = self.discriminator(args, self.sample_FY, \
                  reuse=True, name="discriminatorX") #for discriminator loss
        self.D_Y = self.discriminator(args, self.real_Y, reuse=True, \
            name="discriminatorY")
        self.D_X = self.discriminator(args, self.real_X, reuse=True, \
            name="discriminatorX")

        #### Loss
        #generator loss
        self.G_loss_X2Y = md.least_square(\
            self.D_GX, tf.ones_like(self.D_GX))
        self.G_loss_Y2X = md.least_square(\
            self.D_FY, tf.ones_like(self.D_FY))

        self.G_loss = self.G_loss_X2Y + self.G_loss_Y2X

        if args.cycle_loss:
            self.cycle_loss = md.cycle_loss(\
                self.real_X, self.F_GX, self.real_Y, self.G_FY, args.L1_lambda)
            self.G_loss += self.cycle_loss

        if args.ident_loss:
            self.identity_loss = md.identity_loss(\
                self.real_X, self.G_Y, self.real_Y, self.F_X, args.L1_gamma)
            self.G_loss += self.identity_loss

        if args.resid_loss:
            self.residual_loss = md.residual_loss(\
                      self.real_X, self.G_X, self.F_GX, self.real_Y, \
                      self.F_Y,  self.G_FY, args.L1_delta)
            self.G_loss += self.residual_loss

        #discriminator loss
        self.D_loss_real_Y = md.least_square(self.D_Y, tf.ones_like(self.D_Y))
        # score the discriminator's output on the pooled sample, not the raw image
        self.D_loss_GX = md.least_square(self.D_sample_GX,
                                         tf.zeros_like(self.D_sample_GX))
        self.D_loss_real_X = md.least_square(self.D_X, tf.ones_like(self.D_X))
        self.D_loss_FY = md.least_square(self.D_sample_FY,
                                         tf.zeros_like(self.D_sample_FY))
        self.D_loss_Y = (self.D_loss_real_Y + self.D_loss_GX)
        self.D_loss_X = (self.D_loss_real_X + self.D_loss_FY)
        self.D_loss = (self.D_loss_X + self.D_loss_Y) / 2

        #### variable list
        t_vars = tf.trainable_variables()
        self.d_vars = [var for var in t_vars if 'discriminator' in var.name]
        self.g_vars = [var for var in t_vars if 'generator' in var.name]

        #### optimizer
        self.lr = tf.placeholder(tf.float32, None, name='learning_rate')
        self.d_optim = tf.train.AdamOptimizer(self.lr, beta1=args.beta1) \
            .minimize(self.D_loss, var_list=self.d_vars)
        self.g_optim = tf.train.AdamOptimizer(self.lr, beta1=args.beta1) \
            .minimize(self.G_loss, var_list=self.g_vars)
        """
        Summary
        """
        #### loss summary
        #generator
        self.G_loss_sum = tf.summary.scalar("1_G_loss", self.G_loss, \
            family = 'Generator_loss')
        self.G_loss_X2Y_sum = tf.summary.scalar("2_G_loss_X2Y", self.G_loss_X2Y, \
            family = 'Generator_loss')
        self.G_loss_Y2X_sum = tf.summary.scalar("3_G_loss_Y2X", self.G_loss_Y2X, \
            family = 'Generator_loss')

        generator_loss_list = [
            self.G_loss_sum, self.G_loss_X2Y_sum, self.G_loss_Y2X_sum
        ]
        if args.cycle_loss:
            self.cycle_loss_sum = tf.summary.scalar("4_cycle_loss", self.cycle_loss, \
            family = 'Generator_loss')
            generator_loss_list += [self.cycle_loss_sum]
        if args.ident_loss:
            self.identity_loss_sum = tf.summary.scalar("5_identity_loss", self.identity_loss, \
            family = 'Generator_loss')
            generator_loss_list += [self.identity_loss_sum]
        if args.resid_loss:
            self.residual_loss_sum = tf.summary.scalar("6_residual_loss", self.residual_loss, \
            family = 'Generator_loss')
            generator_loss_list += [self.residual_loss_sum]

        self.g_sum = tf.summary.merge(generator_loss_list)

        #discriminator
        self.D_loss_sum = tf.summary.scalar("1_D_loss", self.D_loss, \
            family = 'Discriminator_loss')
        self.D_loss_Y_sum = tf.summary.scalar("2_D_loss_Y", self.D_loss_real_Y, \
            family = 'Discriminator_loss')
        self.D_loss_GX_sum = tf.summary.scalar("3_D_loss_GX", self.D_loss_GX, \
            family = 'Discriminator_loss')
        self.d_sum = tf.summary.merge(\
            [self.D_loss_sum, self.D_loss_Y_sum, self.D_loss_GX_sum])

        #### image summary
        self.test_G_X = self.generator(\
            args, self.whole_X, True, name="generatorX2Y")
        self.train_img_summary = tf.concat(\
            [self.real_X, self.real_Y, self.G_X], axis = 2)
        self.summary_image_1 = tf.summary.image('1_train_image', \
            self.train_img_summary)
        self.test_img_summary = tf.concat(\
            [self.whole_X, self.whole_Y, self.test_G_X], axis = 2)
        self.summary_image_2 = tf.summary.image('2_test_image', \
            self.test_img_summary)

        #### psnr summary
        self.summary_psnr_ldct = tf.summary.scalar("1_psnr_LDCT", \
        ut.tf_psnr(self.whole_X, self.whole_Y, 2), family = 'PSNR')  #-1 ~ 1
        self.summary_psnr_result = tf.summary.scalar("2_psnr_output", \
        ut.tf_psnr(self.whole_Y, self.test_G_X, 2), family = 'PSNR')  #-1 ~ 1
        self.summary_psnr = tf.summary.merge([self.summary_psnr_ldct, \
            self.summary_psnr_result])

        #model saver
        self.saver = tf.train.Saver(max_to_keep=None)

        #image pool
        self.pool = md.ImagePool(args.max_size)

        print('--------------------------------------------\n# of parameters : {} '.\
             format(np.sum([np.prod(v.get_shape().as_list()) \
                for v in tf.trainable_variables()])))
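Example 2 discriminates previously generated patches fed back through `self.sample_GX`/`self.sample_FY` and keeps an `md.ImagePool(args.max_size)`. The pool class is not shown; a sketch of the standard CycleGAN history buffer it presumably implements, assuming numpy image batches (an assumption, not the repository's actual code):

import numpy as np

class ImagePool:
    def __init__(self, maxsize=50):
        self.maxsize = maxsize
        self.images = []

    def __call__(self, image):
        if self.maxsize <= 0:          # pool disabled: always use the fresh fake
            return image
        if len(self.images) < self.maxsize:
            self.images.append(image)
            return image
        if np.random.rand() > 0.5:     # half the time, swap in an older fake
            idx = np.random.randint(self.maxsize)
            old = self.images[idx].copy()
            self.images[idx] = image
            return old
        return image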
Example 3
    def __init__(self, sess, args):
        self.sess = sess

        #### patients folder name
        self.train_patient_no = [
            d.split('/')[-1] for d in glob(args.dcm_path + '/*')
            if ('zip' not in d)
            and (d.split('/')[-1] not in args.test_patient_no)
        ]
        self.test_patient_no = args.test_patient_no

        #save directory
        self.p_info = '_'.join(self.test_patient_no)
        self.checkpoint_dir = os.path.join(args.result, args.checkpoint_dir,
                                           self.p_info)
        self.log_dir = os.path.join(args.result, args.log_dir, self.p_info)
        print(
            'directory check!!\ncheckpoint : {}\ntensorboard_logs : {}'.format(
                self.checkpoint_dir, self.log_dir))

        #### set modules
        self.red_cnn = modules.redcnn
        """
        load images
        """
        print('data load... dicom -> numpy')
        self.image_loader = ut.DCMDataLoader(
            args.dcm_path, args.LDCT_path, args.NDCT_path,
            image_size=args.whole_size, patch_size=args.patch_size,
            depth=args.img_channel, image_max=args.trun_max,
            image_min=args.trun_min, is_unpair=args.is_unpair,
            augument=args.augument, norm=args.norm)

        self.test_image_loader = ut.DCMDataLoader(
            args.dcm_path, args.LDCT_path, args.NDCT_path,
            image_size=args.whole_size, patch_size=args.patch_size,
            depth=args.img_channel, image_max=args.trun_max,
            image_min=args.trun_min, is_unpair=args.is_unpair,
            augument=args.augument, norm=args.norm)

        t1 = time.time()
        if args.phase == 'train':
            self.image_loader(self.train_patient_no)
            self.test_image_loader(self.test_patient_no)
            print(
                'data load complete !!!, {}\nN_train : {}, N_test : {}'.format(
                    time.time() - t1, len(self.image_loader.LDCT_image_name),
                    len(self.test_image_loader.LDCT_image_name)))
        else:
            self.test_image_loader(self.test_patient_no)
            print('data load complete !!!, {}, N_test : {}'.format(
                time.time() - t1, len(self.test_image_loader.LDCT_image_name)))
        """
        build model
        """
        self.X = tf.placeholder(
            tf.float32,
            [None, args.patch_size, args.patch_size, args.img_channel],
            name='LDCT')
        self.Y = tf.placeholder(
            tf.float32,
            [None, args.patch_size, args.patch_size, args.img_channel],
            name='NDCT')
        self.whole_X = tf.placeholder(
            tf.float32,
            [1, args.whole_size, args.whole_size, args.img_channel],
            name='whole_LDCT')
        self.whole_Y = tf.placeholder(
            tf.float32,
            [1, args.whole_size, args.whole_size, args.img_channel],
            name='whole_NDCT')

        #### denoised images
        self.output_img = self.red_cnn(self.X, reuse=False)
        self.WHOLE_output_img = self.red_cnn(self.whole_X)

        #### loss
        self.loss = tf.reduce_mean(
            tf.squared_difference(self.Y, self.output_img))

        #### trainable variable list
        self.t_vars = tf.trainable_variables()

        #### optimizer
        self.global_step = tf.Variable(0, trainable=False)
        self.learning_rate = tf.train.exponential_decay(args.alpha,
                                                        self.global_step,
                                                        args.num_iter,
                                                        args.decay_rate,
                                                        staircase=True)
        self.optimizer = tf.train.AdamOptimizer(
            learning_rate=self.learning_rate).minimize(
                self.loss, var_list=self.t_vars, global_step=self.global_step)
        """
        summary
        """
        #loss summary
        self.summary_loss = tf.summary.scalar("loss", self.loss)
        #psnr summary
        self.summary_psnr_ldct = tf.summary.scalar("1_psnr_LDCT",
                                                   ut.tf_psnr(
                                                       self.whole_Y,
                                                       self.whole_X, 1),
                                                   family='PSNR')  # 0 ~ 1
        self.summary_psnr_result = tf.summary.scalar(
            "2_psnr_output",
            ut.tf_psnr(self.whole_Y, self.WHOLE_output_img, 1),
            family='PSNR')  # 0 ~ 1
        self.summary_psnr = tf.summary.merge(
            [self.summary_psnr_ldct, self.summary_psnr_result])

        #image summary
        self.check_img_summary = tf.concat([tf.expand_dims(self.X[0], axis=0), \
                                    tf.expand_dims(self.Y[0], axis=0), \
                                    tf.expand_dims(self.output_img[0], axis=0)], axis = 2)
        self.summary_train_image = tf.summary.image('0_train_image',
                                                    self.check_img_summary)
        self.whole_img_summary = tf.concat(
            [self.whole_X, self.whole_Y, self.WHOLE_output_img], axis=2)
        self.summary_image = tf.summary.image('1_whole_image',
                                              self.whole_img_summary)

        #model saver
        self.saver = tf.train.Saver(max_to_keep=None)

        print('--------------------------------------------\n# of parameters : {} '.\
              format(np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()])))
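All of these examples score image quality with `ut.tf_psnr(a, b, peak)`, whose definition is not included. A plausible reconstruction, assuming the third argument is the dynamic range of the images (2 for data in [-1, 1], 1 for data in [0, 1], matching the inline comments):

import tensorflow as tf

def tf_psnr(img1, img2, peak):
    # PSNR = 10 * log10(peak^2 / MSE); TF1 has no log10, so use ln(x)/ln(10)
    mse = tf.reduce_mean(tf.squared_difference(img1, img2))
    return 10.0 * tf.log(peak ** 2 / mse) / tf.log(10.0)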
Example 4
    def __init__(self, sess, args):
        self.sess = sess    
        
        #### patients folder name
        self.train_patient_no = [
            d.split('/')[-1] for d in glob(args.dcm_path + '/*')
            if ('zip' not in d)
            and (d.split('/')[-1] not in args.test_patient_no)
        ]
        self.test_patient_no = args.test_patient_no


        #save directory
        self.p_info = '_'.join(self.test_patient_no)
        self.checkpoint_dir = os.path.join(args.result, args.checkpoint_dir, self.p_info)
        self.log_dir = os.path.join(args.result, args.log_dir,  self.p_info)
        print('directory check!!\ncheckpoint : {}\ntensorboard_logs : {}'.format(self.checkpoint_dir, self.log_dir))

        #### set modules (generator, discriminator, vgg net)
        self.g_net = modules.generator
        self.d_net = modules.discriminator
        self.vgg = modules.Vgg19(vgg_path = args.pretrained_vgg) 
        
        """
        load images
        """
        print('data load... dicom -> numpy') 
        self.image_loader = ut.DCMDataLoader(
            args.dcm_path, args.LDCT_path, args.NDCT_path,
            image_size=args.whole_size, patch_size=args.patch_size,
            depth=args.img_channel, image_max=args.trun_max,
            image_min=args.trun_min, is_unpair=args.is_unpair,
            augument=args.augument, norm=args.norm)

        self.test_image_loader = ut.DCMDataLoader(
            args.dcm_path, args.LDCT_path, args.NDCT_path,
            image_size=args.whole_size, patch_size=args.patch_size,
            depth=args.img_channel, image_max=args.trun_max,
            image_min=args.trun_min, is_unpair=args.is_unpair,
            augument=args.augument, norm=args.norm)
        

        t1 = time.time()
        if args.phase == 'train':
            self.image_loader(self.train_patient_no)
            self.test_image_loader(self.test_patient_no)
            print('data load complete !!!, {}\nN_train : {}, N_test : {}'.format(time.time() - t1, len(self.image_loader.LDCT_image_name), len(self.test_image_loader.LDCT_image_name)))
        else:
            self.test_image_loader(self.test_patient_no)
            print('data load complete !!!, {}, N_test : {}'.format(time.time() - t1, len(self.test_image_loader.LDCT_image_name)))
            

        """
        build model
        """
        #### image placeholders (patch image, whole image)
        self.z_i = tf.placeholder(tf.float32, [None, args.patch_size, args.patch_size, args.img_channel], name='LDCT')
        self.x_i = tf.placeholder(tf.float32, [None, args.patch_size, args.patch_size, args.img_channel], name='NDCT')
        self.whole_z = tf.placeholder(tf.float32, [1, args.whole_size, args.whole_size, args.img_channel], name='whole_LDCT')
        self.whole_x = tf.placeholder(tf.float32, [1, args.whole_size, args.whole_size, args.img_channel], name='whole_NDCT')

        #### generate & discriminate
        #generated images
        self.G_zi = self.g_net(self.z_i, reuse = False)
        self.G_whole_zi = self.g_net(self.whole_z)

        #discriminate
        self.D_xi = self.d_net(self.x_i, reuse = False)
        self.D_G_zi= self.d_net(self.G_zi)

        #### loss define
        #gradient penalty
        self.epsilon = tf.random_uniform([], 0.0, 1.0)
        self.x_hat = self.epsilon * self.x_i + (1 - self.epsilon) * self.G_zi
        self.D_x_hat = self.d_net(self.x_hat)
        self.grad_x_hat = tf.gradients(self.D_x_hat, self.x_hat)[0]
        # per-sample gradient norm over all feature axes (NHWC), not just axis 1
        self.grad_x_hat_l2 = tf.sqrt(
            tf.reduce_sum(tf.square(self.grad_x_hat), axis=[1, 2, 3]))
        self.gradient_penalty = tf.square(self.grad_x_hat_l2 - 1.0)

        #perceptual loss
        self.G_zi_3c = tf.concat([self.G_zi]*3, axis=3)
        self.xi_3c = tf.concat([self.x_i]*3, axis=3)
        [w, h, d] = self.G_zi_3c.get_shape().as_list()[1:]
        self.vgg_perc_loss = tf.reduce_mean(tf.sqrt(tf.reduce_sum(tf.square((self.vgg.extract_feature(self.G_zi_3c) -  self.vgg.extract_feature(self.xi_3c))))) / (w*h*d))

        #discriminator loss (WGAN loss)
        d_loss = tf.reduce_mean(self.D_G_zi) - tf.reduce_mean(self.D_xi)
        grad_penal = args.lambda_ * tf.reduce_mean(self.gradient_penalty)
        self.D_loss = d_loss + grad_penal
        #generator loss
        self.G_loss = args.lambda_1 * self.vgg_perc_loss - tf.reduce_mean(self.D_G_zi)


        #### variable list
        t_vars = tf.trainable_variables()
        self.d_vars = [var for var in t_vars if 'discriminator' in var.name]
        self.g_vars = [var for var in t_vars if 'generator' in var.name]

        """
        summary
        """
        #loss summary
        self.summary_vgg_perc_loss = tf.summary.scalar("1_PerceptualLoss_VGG", self.vgg_perc_loss)
        self.summary_d_loss_all = tf.summary.scalar("2_DiscriminatorLoss_WGAN", self.D_loss)
        self.summary_d_loss_1 = tf.summary.scalar("3_D_loss_disc", d_loss)
        self.summary_d_loss_2 = tf.summary.scalar("4_D_loss_gradient_penalty", grad_penal)
        self.summary_g_loss = tf.summary.scalar("GeneratorLoss", self.G_loss)
        self.summary_all_loss = tf.summary.merge([self.summary_vgg_perc_loss, self.summary_d_loss_all, self.summary_d_loss_1, self.summary_d_loss_2, self.summary_g_loss])
            
        #psnr summary
        self.summary_psnr_ldct = tf.summary.scalar("1_psnr_LDCT", ut.tf_psnr(self.whole_z, self.whole_x, 1), family = 'PSNR')  # 0 ~ 1
        self.summary_psnr_result = tf.summary.scalar("2_psnr_output", ut.tf_psnr(self.whole_x, self.G_whole_zi, 1), family = 'PSNR')  # 0 ~ 1
        self.summary_psnr = tf.summary.merge([self.summary_psnr_ldct, self.summary_psnr_result])
        
 
        #image summary
        self.check_img_summary = tf.concat([tf.expand_dims(self.z_i[0], axis=0), \
                                            tf.expand_dims(self.x_i[0], axis=0), \
                                            tf.expand_dims(self.G_zi[0], axis=0)], axis = 2)        
        self.summary_train_image = tf.summary.image('0_train_image', self.check_img_summary)                                    
        self.whole_img_summary = tf.concat([self.whole_z, self.whole_x, self.G_whole_zi], axis = 2)        
        self.summary_image = tf.summary.image('1_whole_image', self.whole_img_summary)
        
        #### optimizer
        self.d_adam, self.g_adam = None, None
        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            self.d_adam = tf.train.AdamOptimizer(
                learning_rate=args.alpha, beta1=args.beta1,
                beta2=args.beta2).minimize(self.D_loss, var_list=self.d_vars)
            self.g_adam = tf.train.AdamOptimizer(
                learning_rate=args.alpha, beta1=args.beta1,
                beta2=args.beta2).minimize(self.G_loss, var_list=self.g_vars)
                
        #model saver
        self.saver = tf.train.Saver(max_to_keep=None)    

        print('--------------------------------------------\n# of parameters : {} '.\
              format(np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()])))
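The gradient penalty above draws a single scalar `epsilon` for the whole batch; the WGAN-GP paper (Gulrajani et al., 2017) interpolates per sample and takes one gradient norm per sample. A minimal self-contained sketch of that canonical form, assuming NHWC batches and a critic callable `d_net`:

import tensorflow as tf

def gradient_penalty(d_net, real, fake, lam=10.0):
    batch = tf.shape(real)[0]
    eps = tf.random_uniform([batch, 1, 1, 1], 0.0, 1.0)  # one mix per sample
    x_hat = eps * real + (1.0 - eps) * fake
    grads = tf.gradients(d_net(x_hat), x_hat)[0]
    # per-sample gradient norm over all feature axes
    norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]) + 1e-12)
    return lam * tf.reduce_mean(tf.square(norm - 1.0))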
Example 5
    def __init__(self, sess, args):
        self.sess = sess

        #### patients folder name
        self.train_patient_no = [
            d.split('/')[-1] for d in glob(args.dcm_path + '/*')
            # compare the basename, as in the other examples: d is a full path,
            # so `d not in args.test_patient_no` would never filter anything
            if ('zip' not in d)
            and (d.split('/')[-1] not in args.test_patient_no)
        ]
        self.test_patient_no = args.test_patient_no

        #### set modules
        self.red_cnn = modules.redcnn
        """
        load images
        """
        print('data load... dicom -> numpy')
        self.image_loader = ut.DCMDataLoader(
            args.dcm_path, args.LDCT_path, args.NDCT_path,
            image_size=args.whole_size, patch_size=args.patch_size,
            depth=args.img_channel, image_max=args.img_vmax,
            image_min=args.img_vmin, batch_size=args.batch_size,
            model=args.model)

        self.test_image_loader = ut.DCMDataLoader(
            args.dcm_path, args.LDCT_path, args.NDCT_path,
            image_size=args.whole_size, patch_size=args.patch_size,
            depth=args.img_channel, image_max=args.img_vmax,
            image_min=args.img_vmin, batch_size=args.batch_size,
            model=args.model)

        t1 = time.time()
        if args.phase == 'train':
            self.image_loader(self.train_patient_no)
            self.test_image_loader(self.test_patient_no)
            print(
                'data load complete !!!, {}\nN_train : {}, N_test : {}'.format(
                    time.time() - t1, len(self.image_loader.LDCT_image_name),
                    len(self.test_image_loader.LDCT_image_name)))
            [self.X, self.Y
             ] = self.image_loader.input_pipeline(self.sess, args.patch_size,
                                                  args.num_iter)
        else:
            self.test_image_loader(self.test_patient_no)
            print('data load complete !!!, {}, N_test : {}'.format(
                time.time() - t1, len(self.test_image_loader.LDCT_image_name)))
            self.X = tf.placeholder(
                tf.float32,
                [None, args.patch_size, args.patch_size, args.img_channel],
                name='LDCT')
            self.Y = tf.placeholder(
                tf.float32,
                [None, args.patch_size, args.patch_size, args.img_channel],
                name='NDCT')
        """
        build model
        """
        self.whole_X = tf.placeholder(
            tf.float32,
            [1, args.whole_size, args.whole_size, args.img_channel],
            name='whole_LDCT')
        self.whole_Y = tf.placeholder(
            tf.float32,
            [1, args.whole_size, args.whole_size, args.img_channel],
            name='whole_NDCT')

        #### denoised images
        self.output_img = self.red_cnn(self.X, reuse=False)
        self.WHOLE_output_img = self.red_cnn(self.whole_X)

        #### loss
        self.loss = tf.reduce_mean(
            tf.squared_difference(self.Y, self.output_img))

        #### trainable variable list
        self.t_vars = tf.trainable_variables()

        #### optimizer
        self.optimizer = tf.train.AdamOptimizer(
            learning_rate=args.alpha).minimize(self.loss, var_list=self.t_vars)
        """
        summary
        """
        #loss summary
        self.summary_loss = tf.summary.scalar("loss", self.loss)
        #psnr summary
        self.summary_psnr_ldct = tf.summary.scalar("1_psnr_LDCT",
                                                   ut.tf_psnr(
                                                       self.whole_Y,
                                                       self.whole_X, 1),
                                                   family='PSNR')  # 0 ~ 1
        self.summary_psnr_result = tf.summary.scalar(
            "2_psnr_output",
            ut.tf_psnr(self.whole_Y, self.WHOLE_output_img, 1),
            family='PSNR')  # 0 ~ 1
        self.summary_psnr = tf.summary.merge(
            [self.summary_psnr_ldct, self.summary_psnr_result])

        #image summary
        self.check_img_summary = tf.concat([tf.expand_dims(self.X[0], axis=0), \
                                    tf.expand_dims(self.Y[0], axis=0), \
                                    tf.expand_dims(self.output_img[0], axis=0)], axis = 2)
        self.summary_train_image = tf.summary.image('0_train_image',
                                                    self.check_img_summary)
        self.whole_img_summary = tf.concat(
            [self.whole_X, self.whole_Y, self.WHOLE_output_img], axis=2)
        self.summary_image = tf.summary.image('1_whole_image',
                                              self.whole_img_summary)

        #ROI summary
        if args.mayo_roi:
            self.ROI_X = tf.placeholder(tf.float32,
                                        [None, 128, 128, args.img_channel],
                                        name='ROI_X')
            self.ROI_Y = tf.placeholder(tf.float32,
                                        [None, 128, 128, args.img_channel],
                                        name='ROI_Y')
            self.ROI_output = self.red_cnn(self.ROI_X)

            self.ROI_real_img_summary = tf.concat(
                [self.ROI_X, self.ROI_Y, self.ROI_output], axis=2)
            self.summary_ROI_image_1 = tf.summary.image(
                '2_ROI_image_1', self.ROI_real_img_summary)
            self.summary_ROI_image_2 = tf.summary.image(
                '3_ROI_image_2', self.ROI_real_img_summary)

        #model saver
        self.saver = tf.train.Saver(max_to_keep=None)

        print('--------------------------------------------\n# of parameters : {} '.\
              format(np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()])))
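`input_pipeline` in the train branch suggests the data arrives through TF1 queue runners, so a driver loop has to start and stop them around `sess.run`. A hypothetical driver (the `optimizer` and `loss` attributes match the model above; the function itself is an assumption, not code from the repository):

import tensorflow as tf

def train(model, num_iter, log_every=100):
    model.sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=model.sess, coord=coord)
    try:
        for step in range(num_iter):
            _, loss = model.sess.run([model.optimizer, model.loss])
            if step % log_every == 0:
                print('step {}, loss {:.6f}'.format(step, loss))
    finally:
        coord.request_stop()
        coord.join(threads)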
Example 6
        def check_test_sample(step):
            # take arbitrary sample in test_set
            buffer_size = 50
            sample_whole_X = tf.constant(
                list(
                    self.whole_X_set.shuffle(buffer_size).take(
                        1).as_numpy_iterator())[0])
            sample_whole_Y = tf.constant(
                list(
                    self.whole_Y_set.shuffle(buffer_size).take(
                        1).as_numpy_iterator())[0])
            G_X = self.generator_G(sample_whole_X, training=False)
            F_Y = self.generator_F(sample_whole_Y, training=False)

            with self.writer.as_default():
                with tf.name_scope("PSNR"):
                    tf.summary.scalar(name="1_psnr",
                                      step=step,
                                      data=ut.tf_psnr(sample_whole_X,
                                                      sample_whole_Y,
                                                      2))  # -1 ~ 1
                    tf.summary.scalar(name="2_psnr_AtoB",
                                      step=step,
                                      data=ut.tf_psnr(sample_whole_Y, G_X, 2))
                    tf.summary.scalar(name="2_psnr_BtoA",
                                      step=step,
                                      data=ut.tf_psnr(sample_whole_X, F_Y, 2))

            # re-scale for Tensorboard
            sample_whole_X = ut.rescale_arr(
                data=sample_whole_X,
                i_min=tf.math.reduce_min(sample_whole_X),
                i_max=tf.math.reduce_max(sample_whole_X),
                o_min=0,
                o_max=255,
                out_dtype=tf.uint8)
            sample_whole_Y = ut.rescale_arr(
                data=sample_whole_Y,
                i_min=tf.math.reduce_min(sample_whole_Y),
                i_max=tf.math.reduce_max(sample_whole_Y),
                o_min=0,
                o_max=255,
                out_dtype=tf.uint8)
            G_X = ut.rescale_arr(data=G_X,
                                 i_min=tf.math.reduce_min(G_X),
                                 i_max=tf.math.reduce_max(G_X),
                                 o_min=0,
                                 o_max=255,
                                 out_dtype=tf.uint8)
            F_Y = ut.rescale_arr(data=F_Y,
                                 i_min=tf.math.reduce_min(F_Y),
                                 i_max=tf.math.reduce_max(F_Y),
                                 o_min=0,
                                 o_max=255,
                                 out_dtype=tf.uint8)

            with self.writer.as_default():
                with tf.name_scope("check_test_sample"):
                    tf.summary.image(name="sample_whole_X",
                                     step=step,
                                     data=sample_whole_X,
                                     max_outputs=1)
                    tf.summary.image(name="sample_whole_Y",
                                     step=step,
                                     data=sample_whole_Y,
                                     max_outputs=1)
                    tf.summary.image(name="G(sample_whole_X)",
                                     step=step,
                                     data=G_X,
                                     max_outputs=1)
                    tf.summary.image(name="F(sample_whole_Y)",
                                     step=step,
                                     data=F_Y,
                                     max_outputs=1)
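`ut.rescale_arr` above linearly maps each tensor to [0, 255] uint8 for TensorBoard display. Its definition is not shown; a sketch consistent with the keyword arguments used in the calls (an assumption):

import tensorflow as tf

def rescale_arr(data, i_min, i_max, o_min, o_max, out_dtype=tf.uint8):
    # linear map [i_min, i_max] -> [o_min, o_max], then cast for display
    data = tf.cast(data, tf.float32)
    unit = (data - tf.cast(i_min, tf.float32)) / (
        tf.cast(i_max, tf.float32) - tf.cast(i_min, tf.float32))
    return tf.cast(unit * (o_max - o_min) + o_min, out_dtype)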
Example 7
    def __init__(self, sess, args):
        self.sess = sess

        #### set modules (generator, discriminator, vgg net)
        self.g_net = nt.generator
        self.d_net = nt.discriminator
        self.vgg = nt.Vgg19(size=args.patch_size, vgg_path=args.pretrained_vgg)
        """
        build model
        """
        assert args.phase in ['train', 'test'], 'phase : train or test'
        if args.phase == 'test':
            self.sample_image_loader = ut.DataLoader(args, sample_ck=True)
            self.whole_z, self.whole_x = self.sample_image_loader.loader()
            self.G_whole_zi = self.g_net(self.whole_z, reuse=False)

        elif args.phase == 'train':
            self.image_loader = ut.DataLoader(args)
            self.sample_image_loader = ut.DataLoader(args, sample_ck=True)

            self.z_i, self.x_i = self.image_loader.loader()
            self.whole_z, self.whole_x = self.sample_image_loader.loader()

            #### generate & discriminate & feature extractor
            #generated images
            self.G_zi = self.g_net(self.z_i, reuse=False)
            self.G_whole_zi = self.g_net(self.whole_z)  #for sample check

            #discriminate
            self.D_G_zi = self.d_net(self.G_zi, reuse=False)
            self.D_xi = self.d_net(self.x_i)

            #make 3-channel img for pretrained_vgg model input
            self.G_zi_3c = tf.concat([self.G_zi] * 3, axis=-1)
            self.xi_3c = tf.concat([self.x_i] * 3, axis=-1)
            self.E_g_zi = self.vgg.extract_feature(self.G_zi_3c)
            self.E_xi = self.vgg.extract_feature(self.xi_3c)
            [w, h, d] = self.G_zi_3c.get_shape().as_list()[1:]

            #### loss define
            #discriminator loss
            self.wgan_d_loss = -tf.reduce_mean(self.D_xi) + tf.reduce_mean(
                self.D_G_zi)
            self.epsilon = tf.random_uniform([], 0.0, 1.0)
            self.x_hat = self.epsilon * self.x_i + (1 -
                                                    self.epsilon) * self.G_zi
            self.D_x_hat = self.d_net(self.x_hat)
            self.grad_x_hat = tf.gradients(self.D_x_hat, self.x_hat)[0]
            # per-sample gradient norm over all feature axes (NHWC)
            self.grad_x_hat_l2 = tf.sqrt(
                tf.reduce_sum(tf.square(self.grad_x_hat), axis=[1, 2, 3]))
            self.grad_penal = args.lambda_ * tf.reduce_mean(
                tf.square(self.grad_x_hat_l2 - 1.0))

            self.D_loss = self.wgan_d_loss + self.grad_penal

            #generator loss
            self.frobenius_norm2 = tf.reduce_sum(
                tf.square(self.E_g_zi - self.E_xi))
            self.vgg_perc_loss = tf.reduce_mean(self.frobenius_norm2 /
                                                (w * h * d))
            self.G_loss = args.lambda_1 * self.vgg_perc_loss - tf.reduce_mean(
                self.D_G_zi)

            #### variable list
            t_vars = tf.trainable_variables()
            self.d_vars = [
                var for var in t_vars if 'discriminator' in var.name
            ]
            self.g_vars = [var for var in t_vars if 'generator' in var.name]
            """
            summary
            """
            #loss summary
            self.summary_vgg_perc_loss = tf.summary.scalar("1_PerceptualLoss_VGG", \
                                                           self.vgg_perc_loss)
            self.summary_d_loss_all = tf.summary.scalar(
                "2_DiscriminatorLoss", self.D_loss)
            self.summary_d_loss_1 = tf.summary.scalar("3_D_loss_wgan",
                                                      self.wgan_d_loss)
            self.summary_d_loss_2 = tf.summary.scalar("4_D_loss_gradient_penalty", \
                                                      self.grad_penal)
            self.summary_g_loss = tf.summary.scalar("GeneratorLoss",
                                                    self.G_loss)
            self.summary_all_loss = tf.summary.merge([self.summary_vgg_perc_loss,\
                                                      self.summary_d_loss_all,\
                                                      self.summary_d_loss_1, \
                                                      self.summary_d_loss_2, \
                                                      self.summary_g_loss])

            #psnr summary
            self.summary_psnr_ldct = tf.summary.scalar("1_psnr_LDCT", \
                ut.tf_psnr(self.whole_z, self.whole_x, \
                           self.sample_image_loader.psnr_range), family = 'PSNR')
            self.summary_psnr_result = tf.summary.scalar("2_psnr_output", \
                ut.tf_psnr(self.whole_x, self.G_whole_zi, \
                           self.sample_image_loader.psnr_range), family = 'PSNR')
            self.summary_psnr = tf.summary.merge(
                [self.summary_psnr_ldct, self.summary_psnr_result])

            #image summary
            self.check_img_summary = tf.concat([tf.expand_dims(self.z_i[0], axis=0), \
                                                tf.expand_dims(self.x_i[0], axis=0), \
                                                tf.expand_dims(self.G_zi[0], axis=0)], \
                                               axis = 2)
            self.summary_train_image = tf.summary.image('0_train_image', \
                                                        self.check_img_summary)
            self.whole_img_summary = tf.concat([self.whole_z, \
                                                self.whole_x, self.G_whole_zi], axis = 2)
            self.summary_image = tf.summary.image('1_whole_image',
                                                  self.whole_img_summary)

            #### optimizer
            with tf.control_dependencies(
                    tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                self.d_adam = tf.train.AdamOptimizer(learning_rate= args.lr, \
                                                     beta1 = args.beta1, \
                                                     beta2 = args.beta2).\
                                    minimize(self.D_loss, var_list = self.d_vars)
                self.g_adam = tf.train.AdamOptimizer(learning_rate= args.lr, \
                                                     beta1 = args.beta1, \
                                                     beta2 = args.beta2).\
                                    minimize(self.G_loss, var_list = self.g_vars)

            print('--------------------------------------------\n# of parameters : {} '.\
                  format(np.sum([np.prod(v.get_shape().as_list()) \
                                     for v in tf.trainable_variables()])))

        #model saver
        self.saver = tf.train.Saver(max_to_keep=None)
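The perceptual term in Example 7 is a squared Frobenius norm of the VGG feature difference; note the snippet normalizes by the volume of the 3-channel input (`G_zi_3c`), whereas the WGAN-VGG formulation is usually stated over the feature-map volume. A standalone sketch of the feature-space version, assuming `extract_feature` returns a 4-D NHWC feature map:

import tensorflow as tf

def vgg_perceptual_loss(feat_gen, feat_real):
    # squared Frobenius distance per sample, normalized by feature volume
    w, h, d = feat_gen.get_shape().as_list()[1:]
    frob2 = tf.reduce_sum(tf.square(feat_gen - feat_real), axis=[1, 2, 3])
    return tf.reduce_mean(frob2) / (w * h * d)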
Example 8
    def __init__(self, sess, args):
        self.sess = sess    

        #### set modules (generator, discriminator)
        self.g_net = nt.generator
        self.d_net = nt.discriminator
        
        """
        build model
        """                       
        assert args.phase in ['train', 'test'], 'phase : train or test'
        if args.phase=='test':
            self.sample_image_loader = ut.DataLoader(args, sample_ck=True)
            self.whole_xi, self.whole_yi = self.sample_image_loader.loader()
            self.G_whole_xi = self.g_net(args, self.whole_xi, reuse=False)

        elif args.phase=='train':
            self.image_loader = ut.DataLoader(args)
            self.sample_image_loader = ut.DataLoader(args, sample_ck=True)

            self.x_i, self.y_i = self.image_loader.loader()
            self.whole_xi, self.whole_yi = self.sample_image_loader.loader()

            #### generate & discriminate & feature extractor
            #generated images
            self.G_xi = self.g_net(args, self.x_i, reuse = False)
            self.G_whole_xi = self.g_net(args, self.whole_xi) #for sample check

            #discriminate
            self.D_xGxi = self.d_net(args, self.x_i, self.G_xi, reuse=False)
            self.D_xyi = self.d_net(args, self.x_i, self.y_i)

 
            #### loss define : L1 + cGAN
            #discriminator loss
            self.EPS = 1e-12 # https://github.com/affinelayer/pix2pix-tensorflow/blob/master/pix2pix.py
            self.D_loss = tf.reduce_mean(-(tf.log(self.D_xyi + self.EPS) + tf.log(1 - self.D_xGxi + self.EPS)))

            #generator loss
            self.gen_loss_GAN = tf.reduce_mean(-tf.log(self.D_xGxi + self.EPS))
            self.gen_loss_L1 = tf.reduce_mean(tf.abs(self.y_i - self.G_xi))
            self.G_loss = args.gan_weight * self.gen_loss_GAN + args.l1_weight * self.gen_loss_L1

            #### variable list
            t_vars = tf.trainable_variables()
            self.d_vars = [var for var in t_vars if 'discriminator' in var.name]
            self.g_vars = [var for var in t_vars if 'generator' in var.name]

            """
            summary
            """
            #loss summary
            self.summary_d_loss = tf.summary.scalar("1_DiscriminatorLoss", self.D_loss)
            self.summary_g_loss = tf.summary.scalar("2_GeneratorLoss", self.G_loss)
            self.summary_d_loss_1 = tf.summary.scalar("3_G_loss_adv", self.gen_loss_GAN)
            self.summary_d_loss_2 = tf.summary.scalar("4_G_loss_L1", self.gen_loss_L1)
            
            self.summary_all_loss = tf.summary.merge([self.summary_d_loss, self.summary_g_loss, self.summary_d_loss_1, self.summary_d_loss_2, ])
                
            #psnr summary
            self.summary_psnr_ldct = tf.summary.scalar("1_psnr_LDCT", \
                ut.tf_psnr(self.whole_xi, self.whole_yi, self.sample_image_loader.psnr_range), family = 'PSNR')
            self.summary_psnr_result = tf.summary.scalar("2_psnr_output", \
                ut.tf_psnr(self.whole_yi, self.G_whole_xi, self.sample_image_loader.psnr_range), family = 'PSNR')
            self.summary_psnr = tf.summary.merge([self.summary_psnr_ldct, self.summary_psnr_result])

            #image summary
            self.check_img_summary = tf.concat([tf.expand_dims(self.x_i[0], axis=0), \
                                                tf.expand_dims(self.y_i[0], axis=0), \
                                                tf.expand_dims(self.G_xi[0], axis=0)], axis = 2)        
            self.summary_train_image = tf.summary.image('0_train_image', self.check_img_summary)                                    
            self.whole_img_summary = tf.concat([self.whole_xi, self.whole_yi, self.G_whole_xi], axis = 2)        
            self.summary_image = tf.summary.image('1_whole_image', self.whole_img_summary)
            
            #### optimizer
            with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                self.d_adam = tf.train.AdamOptimizer(\
                         learning_rate= args.lr, beta1 = args.beta1, \
                         beta2 = args.beta2).minimize(self.D_loss, var_list = self.d_vars)
                self.g_adam = tf.train.AdamOptimizer(\
                         learning_rate= args.lr, beta1 = args.beta1, \
                         beta2 = args.beta2).minimize(self.G_loss, var_list = self.g_vars)

            print('--------------------------------------------\n# of parameters : {} '.\
                  format(np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()])))

                    
        #model saver
        self.saver = tf.train.Saver(max_to_keep=None)    
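`nt.discriminator` in Example 8 receives the LDCT input together with a candidate image, the usual pix2pix conditioning. A sketch of how such a conditional PatchGAN critic is typically wired (layer widths and the network body are assumptions; the real `nt` module is not shown):

import tensorflow as tf

def cond_patch_discriminator(x_cond, y_cand):
    h = tf.concat([x_cond, y_cand], axis=-1)  # condition on the LDCT input
    for filters in (64, 128, 256):
        h = tf.layers.conv2d(h, filters, 4, strides=2, padding='same')
        h = tf.nn.leaky_relu(h, 0.2)
    logits = tf.layers.conv2d(h, 1, 4, strides=1, padding='same')
    return tf.nn.sigmoid(logits)  # per-patch probability, as the log losses expect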
Example 9
    def __init__(self, sess, args):
        self.sess = sess    
        
        #### set network
        self.red_cnn = nt.redcnn
        
        """
        build model
        """
        assert args.phase in ['train', 'test'], 'phase : train or test'
        if args.phase=='test':
            self.sample_image_loader = ut.DataLoader(args, sample_ck=True)
            self.whole_X, self.whole_Y = self.sample_image_loader.loader()
            self.WHOLE_output_img = self.red_cnn(self.whole_X, reuse=False)
        
        elif args.phase=='train':
            self.image_loader = ut.DataLoader(args)
            self.X, self.Y = self.image_loader.loader()
            self.output_img = self.red_cnn(self.X, reuse = False)
            
            #### loss
            self.loss = tf.reduce_mean((self.output_img - self.Y)**2)

            #### trainable variable list
            self.t_vars = tf.trainable_variables()

            #### optimizer
            self.lr = tf.constant(args.start_lr, dtype=tf.float32)
            self.optimizer = tf.train.AdamOptimizer(learning_rate=self.lr).\
                                minimize(self.loss, var_list = self.t_vars)

            """
            summary
            """
            self.sample_image_loader = ut.DataLoader(args, sample_ck=True)
            self.whole_X, self.whole_Y = self.sample_image_loader.loader()
            self.WHOLE_output_img = self.red_cnn(self.whole_X)
            
            #loss summary
            self.summary_loss = tf.summary.scalar("loss", self.loss)
            #psnr summary
            self.summary_psnr_ldct = tf.summary.scalar("1_psnr_LDCT", \
                ut.tf_psnr(self.whole_Y, self.whole_X, \
                           self.sample_image_loader.psnr_range), family = 'PSNR')
            self.summary_psnr_result = tf.summary.scalar("2_psnr_output", \
                ut.tf_psnr(self.whole_Y, self.WHOLE_output_img, \
                           self.sample_image_loader.psnr_range), family = 'PSNR')
            self.summary_psnr = tf.summary.merge([self.summary_psnr_ldct, self.summary_psnr_result])
            
            #image summary
            self.check_img_summary = tf.concat([tf.expand_dims(self.X[0], axis=0), \
                                        tf.expand_dims(self.Y[0], axis=0), \
                                        tf.expand_dims(self.output_img[0], axis=0)], axis = 2) 
            
            self.summary_train_image = tf.summary.image('0_train_image', \
                                                        self.check_img_summary)                                    
            self.whole_img_summary = \
                tf.concat([self.whole_X, self.whole_Y, self.WHOLE_output_img], axis = 2)
            self.summary_image = tf.summary.image('1_whole_image', self.whole_img_summary)
        
            print('--------------------------------------------\n# of parameters : {} '.\
                  format(np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()])))    

        #model saver
        self.saver = tf.train.Saver(max_to_keep=None)
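Every example ends by constructing `tf.train.Saver(max_to_keep=None)`, but the save/restore calls live elsewhere in these repos. A hypothetical checkpoint helper pair for these wrappers (the `model.sess`/`model.saver` attribute names follow the examples; the file layout is an assumption):

import os
import tensorflow as tf

def save_checkpoint(model, checkpoint_dir, step):
    os.makedirs(checkpoint_dir, exist_ok=True)
    model.saver.save(model.sess,
                     os.path.join(checkpoint_dir, 'model.ckpt'),
                     global_step=step)

def load_checkpoint(model, checkpoint_dir):
    latest = tf.train.latest_checkpoint(checkpoint_dir)
    if latest is None:
        return False
    model.saver.restore(model.sess, latest)
    return True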